Bu dokümanda, PJRT ile entegrasyon ve PJRT entegrasyonunun JAX ile nasıl test edileceğine dair önerilere odaklanılmaktadır.
PJRT ile entegrasyon
1. adım: PJRT C API arayüzünü uygulayın
A seçeneği: PJRT C API'yi doğrudan uygulayabilirsiniz.
B seçeneği: xla deposundaki C++ koduna göre derleme yapabiliyorsanız (çatallama veya bazel aracılığıyla) PJRT C++ API'sini de uygulayabilir ve C→C++ sarmalayıcısını kullanabilirsiniz:
- Temel PJRT istemcisini (ve ilgili PJRT sınıflarını) devralan bir C++ PJRT istemcisi uygulayın. C++ PJRT istemcisine örnek olarak pjrt_stream_executor_client.h ve tfrt_cpu_pjrt_client.h verilebilir.
- C++ PJRT istemcisinin parçası olmayan birkaç C API yöntemini uygulayın:
- PJRT_Client_Create. Aşağıda, örnek bir sözde kod verilmiştir (
GetPluginPjRtClient
'ün yukarıda uygulanan bir C++ PJRT istemcisi döndürdüğü varsayılmıştır):
- PJRT_Client_Create. Aşağıda, örnek bir sözde kod verilmiştir (
#include "third_party/tensorflow/compiler/xla/pjrt/c/pjrt_c_api_wrapper_impl.h"
namespace my_plugin {
PJRT_Error* PJRT_Client_Create(PJRT_Client_Create_Args* args) {
std::unique_ptr<xla::PjRtClient> client = GetPluginPjRtClient();
args->client = pjrt::CreateWrapperClient(std::move(client));
return nullptr;
}
} // namespace my_plugin
PJRT_Client_Create'in çerçeveden iletilen seçenekleri alabileceğini unutmayın. GPU istemcisinin bu özelliği nasıl kullandığına dair bir örnek burada verilmiştir.
- [İsteğe bağlı] PJRT_TopologyDescription_Create.
- [İsteğe bağlı] PJRT_Plugin_Initialize. Bu, diğer işlevler çağrılmadan önce çerçeve tarafından çağrılacak tek seferlik bir eklenti kurulumudur.
- [İsteğe bağlı] PJRT_Plugin_Attributes.
Sarmalayıcı ile kalan C API'lerini uygulamanız gerekmez.
2. Adım: GetPjRtApi'yi uygulayın
PJRT C API uygulamalarında işlev işaretçileri içeren bir PJRT_Api*
döndüren bir GetPjRtApi
yöntemi uygulamanız gerekir. Aşağıda, sarmalayıcı üzerinden uygulandığı varsayılan bir örnek verilmiştir (pjrt_c_api_cpu.cc'ye benzer):
const PJRT_Api* GetPjrtApi() {
static const PJRT_Api pjrt_api =
pjrt::CreatePjrtApi(my_plugin::PJRT_Client_Create);
return &pjrt_api;
}
3. Adım: C API uygulamalarını test edin
Temel PJRT C API davranışları için küçük bir test grubu çalıştırmak üzere RegisterPjRtCApiTestFactory'yi çağırabilirsiniz.
JAX'ten PJRT eklentisi kullanma
1. adım: JAX'i ayarlayın
JAX gecelik sürümünü kullanabilirsiniz
pip install --pre -U jaxlib -f <a href="https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html">https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html</a>
pip install git+<a href="https://github.com/google/jax">https://github.com/google/jax</a>
veya JAX'i kaynaktan derleyin.
Şu anda jaxlib sürümünü PJRT C API sürümüyle eşleştirmeniz gerekiyor. Genellikle, eklentinizi derlediğiniz TF taahhütüyle aynı güne ait bir jaxlib gecelik sürümü kullanmak yeterlidir. Örneğin:
pip install --pre -U jaxlib==0.4.2.dev20230103 -f <a href="https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html">https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html</a>
Ayrıca, tam olarak derlediğiniz XLA taahhütünde kaynaktan bir jaxlib oluşturabilirsiniz (talimatlar).
ABI uyumluluğunu yakında desteklemeye başlayacağız.
2. Adım: jax_plugins ad alanını kullanın veya entry_point değerini ayarlayın
Eklentinizin JAX tarafından keşfedilmesi için iki seçenek vardır.
- Ad alanı paketlerini kullanma (ref).
jax_plugins
ad alanı paketi altında küresel olarak benzersiz bir modül tanımlayın (yani birjax_plugins
dizini oluşturup modülünüzü bunun altında tanımlayın). Aşağıda bir dizin yapısı örneği verilmiştir:
jax_plugins/
my_plugin/
__init__.py
my_plugin.so
- Paket meta verilerini kullanma (ref). pyproject.toml veya setup.py aracılığıyla paket oluşturuyorsanız
jax_plugins
grubuna, tam modül adınızı gösteren bir giriş noktası ekleyerek eklenti modül adınızın reklamını yapın. pyproject.toml veya setup.py üzerinden bir örnek aşağıda verilmiştir:
# use pyproject.toml
[project.entry-points.'jax_plugins']
my_plugin = 'my_plugin'
# use setup.py
entry_points={
"jax_plugins": [
"my_plugin = my_plugin",
],
}
openxla-pjrt-plugin'in 2. seçenek kullanılarak nasıl uygulandığına dair örnekler: https://github.com/openxla/openxla-pjrt-plugin/pull/119, https://github.com/openxla/openxla-pjrt-plugin/pull/120
3. adım: initialize() yöntemini uygulayın
Eklentiyi kaydetmek için Python modülünüzde bir initialize() yöntemi uygulamanız gerekir. Örneğin:
import os
import jax._src.xla_bridge as xb
def initialize():
path = os.path.join(os.path.dirname(__file__), 'my_plugin.so')
xb.register_plugin('my_plugin', priority=500, library_path=path, options=None)
xla_bridge.register_plugin
'nin nasıl kullanılacağı hakkında daha fazla bilgi için lütfen burayı inceleyin. Bu yöntem şu anda gizlidir. Gelecekte herkese açık bir API yayınlanacaktır.
Eklentinin kaydedilip kaydedilmediğini doğrulamak ve yüklenememesi durumunda hata oluşturmak için aşağıdaki satırı çalıştırabilirsiniz.
jax.config.update("jax_platforms", "my_plugin")
JAX'in birden fazla arka ucu/eklentisi olabilir. Eklentinizin varsayılan arka uç olarak kullanılmasını sağlamak için birkaç seçenek vardır:
- 1. seçenek: Programın başında
jax.config.update("jax_platforms", "my_plugin")
komutunu çalıştırın. - 2. seçenek: ENV
JAX_PLATFORMS=my_plugin
olarak ayarlayın. - 3. Seçenek: xb.register_plugin işlevini çağırırken yeterince yüksek bir öncelik ayarlayın (varsayılan değer 400'dür ve mevcut diğer arka uçlardan daha yüksektir). En yüksek öncelikli arka uç yalnızca
JAX_PLATFORMS=''
olduğunda kullanılacaktır.JAX_PLATFORMS
politikasının varsayılan değeri''
'tır ancak bazen bu değerin üzerine yazılır.
JAX ile test etme
Deneyebileceğiniz bazı temel test durumları:
# JAX 1+1
print(jax.numpy.add(1, 1))
# => 2
# jit
print(jax.jit(lambda x: x * 2)(1.))
# => 2.0
# pmap
arr = jax.numpy.arange(jax.device_count()) print(jax.pmap(lambda x: x +
jax.lax.psum(x, 'i'), axis_name='i')(arr))
# single device: [0]
# 4 devices: [6 7 8 9]
(Jax birim testlerini eklentinize karşı çalıştırma talimatlarını yakında ekleyeceğiz.)
PJRT eklentileri hakkında daha fazla örnek için PJRT Örnekleri başlıklı makaleyi inceleyin.