Arka plan
PJRT, ML ekosistemine eklemek istediğimiz tek tip Device API'dir. Uzun vadeli vizyon şu şekildedir: (1) Çerçeveler (JAX, TF vb.) çerçeveler için opak olan cihaza özel uygulamaları olan PJRT'yi çağırır; (2) her cihaz, PJRT API'lerini uygulamaya odaklanır ve çerçeveler için opak olabilir.
Bu belgede, PJRT ile entegrasyon ve JAX ile PJRT entegrasyonunun nasıl test edileceği ile ilgili öneriler ele alınmaktadır.
PJRT entegrasyonu
1. Adım: PJRT C API arayüzünü uygulayın
Seçenek A: PJRT C API'yi doğrudan uygulayabilirsiniz.
B seçeneği: xla deposunda C++ koduna göre derleme yapabiliyorsanız (çatal veya bazel aracılığıyla) PJRT C++ API'sini uygulayıp C→C++ sarmalayıcıyı da kullanabilirsiniz:
- Temel PJRT istemcisinden (ve ilgili PJRT sınıflarından) devralan bir C++ PJRT istemcisini uygulayın. C++ PJRT istemcisi için bazı örnekler şunlardır: pjrt_stream_executor_client.h, tfrt_cpu_pjrt_client.h.
- C++ PJRT istemcisinin parçası olmayan birkaç C API yöntemini uygulayın:
- PJRT_Client_Create Aşağıda bazı örnek kod verilmiştir (
GetPluginPjRtClient
hizmetinin yukarıda uygulanan bir C++ PJRT istemcisi döndürdüğü varsayılır):
- PJRT_Client_Create Aşağıda bazı örnek kod verilmiştir (
#include "third_party/tensorflow/compiler/xla/pjrt/c/pjrt_c_api_wrapper_impl.h"
namespace my_plugin {
PJRT_Error* PJRT_Client_Create(PJRT_Client_Create_Args* args) {
std::unique_ptr<xla::PjRtClient> client = GetPluginPjRtClient();
args->client = pjrt::CreateWrapperClient(std::move(client));
return nullptr;
}
} // namespace my_plugin
PJRT_Client_Create öğesinin, çerçeveden iletilen seçenekleri alabileceğini unutmayın. Bir GPU istemcisinin bu özelliği nasıl kullandığına ilişkin bir örneği burada bulabilirsiniz.
- [İsteğe bağlı] PJRT_TopologyDescription_Create.
- [İsteğe Bağlı] PJRT_Plugin_Initialize. Bu, diğer işlevler çağrılmadan önce çerçeve tarafından çağrılacak tek seferlik bir eklenti kurulumudur.
- [İsteğe Bağlı] PJRT_Plugin_Attributes.
Sarmalayıcı ile kalan C API'lerini uygulamanız gerekmez.
2. Adım: GetPjRtApi'yi uygulayın
PJRT C API uygulamalarına yönelik işlev işaretçileri içeren bir PJRT_Api*
döndüren bir GetPjRtApi
yöntemi uygulamanız gerekiyor. Aşağıda, sarmalayıcı üzerinden uygulandığı varsayılan bir örnek verilmiştir (pjrt_c_api_cpu.cc'ye benzer şekilde):
const PJRT_Api* GetPjrtApi() {
static const PJRT_Api pjrt_api =
pjrt::CreatePjrtApi(my_plugin::PJRT_Client_Create);
return &pjrt_api;
}
3. Adım: C API uygulamalarını test etme
Temel PJRT C API davranışlarıyla ilgili küçük bir test grubu çalıştırmak için RegisterPjRtCApiTestFactory'yi çağırabilirsiniz.
JAX'tan PJRT eklentisi nasıl kullanılır?
1. Adım: JAX'i kurun
JAX'ı gece kullanabilirsiniz
pip install --pre -U jaxlib -f <a href="https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html">https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html</a>
pip install git+<a href="https://github.com/google/jax">https://github.com/google/jax</a>
veya kaynaktan JAX oluşturun.
Şimdilik jaxlib sürümünü PJRT C API sürümüyle eşleştirmeniz gerekiyor. Genellikle, eklentinizi oluştururken kullandığınız TF kaydı ile aynı gün içindeki jaxlib gece sürümünü kullanmanız yeterlidir, ör.
pip install --pre -U jaxlib==0.4.2.dev20230103 -f <a href="https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html">https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html</a>
Kaynaktan tam olarak kullandığınız XLA kaydında da bir jaxlib oluşturabilirsiniz (instructions).
ABI uyumluluğunu yakında desteklemeye başlayacağız.
2. Adım: jax_plugins ad alanını kullanın veya giriş_noktasını ayarlayın
Eklentinizin JAX tarafından keşfedilmesi için iki seçenek vardır.
- Ad alanı paketlerini kullanma (ref).
jax_plugins
ad alanı paketi altında genel olarak benzersiz bir modül tanımlayın (yani birjax_plugins
dizini oluşturup modülünüzü onun altında tanımlayın). Aşağıda örnek bir dizin yapısı verilmiştir:
jax_plugins/
my_plugin/
__init__.py
my_plugin.so
- Paket meta verilerini kullanma (ref). pyproject.toml veya setup.py aracılığıyla bir paket oluşturuyorsanız
jax_plugins
grubu altında tam modül adınıza işaret eden bir giriş noktası ekleyerek eklenti modülünüzün adını tanıtın. pyproject.toml veya setup.py üzerinden alınan bir örneği burada görebilirsiniz:
# use pyproject.toml
[project.entry-points.'jax_plugins']
my_plugin = 'my_plugin'
# use setup.py
entry_points={
"jax_plugins": [
"my_plugin = my_plugin",
],
}
2. Seçenek kullanılarak openxla-pjrt-plugin'in nasıl uygulandığına dair örnekler: https://github.com/openxla/openxla-pjrt-plugin/pull/119, https://github.com/openxla/openxla-pjrt-plugin/pull/120
3. Adım: initialize() yöntemini uygulayın
Eklentiyi kaydetmek için python modülünüzde bir initialize() yöntemi uygulamanız gerekir. Örneğin:
import os
import jax._src.xla_bridge as xb
def initialize():
path = os.path.join(os.path.dirname(__file__), 'my_plugin.so')
xb.register_plugin('my_plugin', priority=500, library_path=path, options=None)
xla_bridge.register_plugin
uygulamasını nasıl kullanacağınızı öğrenmek için lütfen buraya göz atın. Bu yöntem şu anda gizli bir yöntemdir. Gelecekte herkese açık bir API yayınlanacaktır.
Eklentinin kayıtlı olduğunu doğrulamak ve yüklenememesi durumunda hata bildirmek için aşağıdaki satırı çalıştırabilirsiniz.
jax.config.update("jax_platforms", "my_plugin")
JAX'ın birden fazla arka ucu/eklentisi olabilir. Eklentinizin varsayılan arka uç olarak kullanıldığından emin olmak için birkaç seçenek vardır:
- 1. seçenek:
jax.config.update("jax_platforms", "my_plugin")
uygulamasını programın başında çalıştırma. - 2. seçenek: ENV'yi
JAX_PLATFORMS=my_plugin
ayarlayın. - 3. Seçenek: xb.register_plugin'i çağırırken yeterince yüksek bir öncelik belirleyin (varsayılan değer, diğer mevcut arka uçlardan daha yüksek olan 400'dür). En yüksek önceliğe sahip arka ucun yalnızca
JAX_PLATFORMS=''
durumunda kullanılacağını unutmayın.JAX_PLATFORMS
öğesinin varsayılan değeri''
olsa da bazen bu değerin üzerine yazılabilir.
JAX ile nasıl test yapılır?
Deneyebileceğiniz bazı temel test durumları:
# JAX 1+1
print(jax.numpy.add(1, 1))
# => 2
# jit
print(jax.jit(lambda x: x * 2)(1.))
# => 2.0
# pmap
arr = jax.numpy.arange(jax.device_count()) print(jax.pmap(lambda x: x +
jax.lax.psum(x, 'i'), axis_name='i')(arr))
# single device: [0]
# 4 devices: [6 7 8 9]
(Yakında eklentinize karşı jax birimi testlerini çalıştırmaya ilişkin talimatlar ekleyeceğiz!)
Örnek: JAX CUDA eklentisi
- Sarmalayıcı (pjrt_c_api_gpu.h) üzerinden PJRT C API uygulaması.
- Paketin giriş noktasını ayarlayın (setup.py).
- Bir initialize() yöntemini uygulayın (__init__.py).
- CUDA için tüm jax testleriyle test edilebilir. ```