PJRT eklentisi entegrasyonu

Bu dokümanda, PJRT ile entegrasyon ve PJRT entegrasyonunun JAX ile nasıl test edileceğine dair önerilere odaklanılmaktadır.

PJRT ile entegrasyon

1. adım: PJRT C API arayüzünü uygulayın

A seçeneği: PJRT C API'yi doğrudan uygulayabilirsiniz.

B seçeneği: xla deposundaki C++ koduna göre derleme yapabiliyorsanız (çatallama veya bazel aracılığıyla) PJRT C++ API'sini de uygulayabilir ve C→C++ sarmalayıcısını kullanabilirsiniz:

  1. Temel PJRT istemcisini (ve ilgili PJRT sınıflarını) devralan bir C++ PJRT istemcisi uygulayın. C++ PJRT istemcisine örnek olarak pjrt_stream_executor_client.h ve tfrt_cpu_pjrt_client.h verilebilir.
  2. C++ PJRT istemcisinin parçası olmayan birkaç C API yöntemini uygulayın:
    • PJRT_Client_Create. Aşağıda, örnek bir sözde kod verilmiştir (GetPluginPjRtClient'ün yukarıda uygulanan bir C++ PJRT istemcisi döndürdüğü varsayılmıştır):
#include "third_party/tensorflow/compiler/xla/pjrt/c/pjrt_c_api_wrapper_impl.h"

namespace my_plugin {
PJRT_Error* PJRT_Client_Create(PJRT_Client_Create_Args* args) {
  std::unique_ptr<xla::PjRtClient> client = GetPluginPjRtClient();
  args->client = pjrt::CreateWrapperClient(std::move(client));
  return nullptr;
}
}  // namespace my_plugin

PJRT_Client_Create'in çerçeveden iletilen seçenekleri alabileceğini unutmayın. GPU istemcisinin bu özelliği nasıl kullandığına dair bir örnek burada verilmiştir.

Sarmalayıcı ile kalan C API'lerini uygulamanız gerekmez.

2. Adım: GetPjRtApi'yi uygulayın

PJRT C API uygulamalarında işlev işaretçileri içeren bir PJRT_Api* döndüren bir GetPjRtApi yöntemi uygulamanız gerekir. Aşağıda, sarmalayıcı üzerinden uygulandığı varsayılan bir örnek verilmiştir (pjrt_c_api_cpu.cc'ye benzer):

const PJRT_Api* GetPjrtApi() {
  static const PJRT_Api pjrt_api =
      pjrt::CreatePjrtApi(my_plugin::PJRT_Client_Create);
  return &pjrt_api;
}

3. Adım: C API uygulamalarını test edin

Temel PJRT C API davranışları için küçük bir test grubu çalıştırmak üzere RegisterPjRtCApiTestFactory'yi çağırabilirsiniz.

JAX'ten PJRT eklentisi kullanma

1. adım: JAX'i ayarlayın

JAX gecelik sürümünü kullanabilirsiniz

pip install --pre -U jaxlib -f <a href="https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html">https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html</a>

pip install git+<a href="https://github.com/google/jax">https://github.com/google/jax</a>

veya JAX'i kaynaktan derleyin.

Şu anda jaxlib sürümünü PJRT C API sürümüyle eşleştirmeniz gerekiyor. Genellikle, eklentinizi derlediğiniz TF taahhütüyle aynı güne ait bir jaxlib gecelik sürümü kullanmak yeterlidir. Örneğin:

pip install --pre -U jaxlib==0.4.2.dev20230103 -f <a href="https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html">https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html</a>

Ayrıca, tam olarak derlediğiniz XLA taahhütünde kaynaktan bir jaxlib oluşturabilirsiniz (talimatlar).

ABI uyumluluğunu yakında desteklemeye başlayacağız.

2. Adım: jax_plugins ad alanını kullanın veya entry_point değerini ayarlayın

Eklentinizin JAX tarafından keşfedilmesi için iki seçenek vardır.

  1. Ad alanı paketlerini kullanma (ref). jax_plugins ad alanı paketi altında küresel olarak benzersiz bir modül tanımlayın (yani bir jax_plugins dizini oluşturup modülünüzü bunun altında tanımlayın). Aşağıda bir dizin yapısı örneği verilmiştir:
jax_plugins/
  my_plugin/
    __init__.py
    my_plugin.so
  1. Paket meta verilerini kullanma (ref). pyproject.toml veya setup.py aracılığıyla paket oluşturuyorsanız jax_plugins grubuna, tam modül adınızı gösteren bir giriş noktası ekleyerek eklenti modül adınızın reklamını yapın. pyproject.toml veya setup.py üzerinden bir örnek aşağıda verilmiştir:
# use pyproject.toml
[project.entry-points.'jax_plugins']
my_plugin = 'my_plugin'

# use setup.py
entry_points={
  "jax_plugins": [
    "my_plugin = my_plugin",
  ],
}

openxla-pjrt-plugin'in 2. seçenek kullanılarak nasıl uygulandığına dair örnekler: https://github.com/openxla/openxla-pjrt-plugin/pull/119, https://github.com/openxla/openxla-pjrt-plugin/pull/120

3. adım: initialize() yöntemini uygulayın

Eklentiyi kaydetmek için Python modülünüzde bir initialize() yöntemi uygulamanız gerekir. Örneğin:

import os
import jax._src.xla_bridge as xb

def initialize():
  path = os.path.join(os.path.dirname(__file__), 'my_plugin.so')
  xb.register_plugin('my_plugin', priority=500, library_path=path, options=None)

xla_bridge.register_plugin'nin nasıl kullanılacağı hakkında daha fazla bilgi için lütfen burayı inceleyin. Bu yöntem şu anda gizlidir. Gelecekte herkese açık bir API yayınlanacaktır.

Eklentinin kaydedilip kaydedilmediğini doğrulamak ve yüklenememesi durumunda hata oluşturmak için aşağıdaki satırı çalıştırabilirsiniz.

jax.config.update("jax_platforms", "my_plugin")

JAX'in birden fazla arka ucu/eklentisi olabilir. Eklentinizin varsayılan arka uç olarak kullanılmasını sağlamak için birkaç seçenek vardır:

  • 1. seçenek: Programın başında jax.config.update("jax_platforms", "my_plugin") komutunu çalıştırın.
  • 2. seçenek: ENV JAX_PLATFORMS=my_plugin olarak ayarlayın.
  • 3. Seçenek: xb.register_plugin işlevini çağırırken yeterince yüksek bir öncelik ayarlayın (varsayılan değer 400'dür ve mevcut diğer arka uçlardan daha yüksektir). En yüksek öncelikli arka uç yalnızca JAX_PLATFORMS='' olduğunda kullanılacaktır. JAX_PLATFORMS politikasının varsayılan değeri '''tır ancak bazen bu değerin üzerine yazılır.

JAX ile test etme

Deneyebileceğiniz bazı temel test durumları:

# JAX 1+1
print(jax.numpy.add(1, 1))
# => 2

# jit
print(jax.jit(lambda x: x * 2)(1.))
# => 2.0

# pmap

arr = jax.numpy.arange(jax.device_count()) print(jax.pmap(lambda x: x +
jax.lax.psum(x, 'i'), axis_name='i')(arr))

# single device: [0]

# 4 devices: [6 7 8 9]

(Jax birim testlerini eklentinize karşı çalıştırma talimatlarını yakında ekleyeceğiz.)

PJRT eklentileri hakkında daha fazla örnek için PJRT Örnekleri başlıklı makaleyi inceleyin.