การผสานรวมปลั๊กอิน PJRT

เอกสารนี้จะเน้นที่คําแนะนําเกี่ยวกับวิธีผสานรวมกับ PJRT และวิธีทดสอบการผสานรวม PJRT กับ JAX

วิธีผสานรวมกับ PJRT

ขั้นตอนที่ 1: ใช้อินเทอร์เฟซ PJRT C API

ตัวเลือก ก: คุณใช้ PJRT C API ได้โดยตรง

ตัวเลือก ข: หากสามารถบิลด์กับโค้ด C++ ใน repo xla (ผ่านการแยกหรือ Bazel) คุณจะใช้ PJRT C++ API และใช้ Wrapper C→C++ ได้ด้วย โดยทำดังนี้

  1. ใช้ไคลเอ็นต์ PJRT ของ C++ ที่รับค่ามาจากไคลเอ็นต์ PJRT พื้นฐาน (และคลาส PJRT ที่เกี่ยวข้อง) ตัวอย่างไคลเอ็นต์ PJRT ของ C++ ได้แก่ pjrt_stream_executor_client.h, tfrt_cpu_pjrt_client.h
  2. ใช้เมธอด C API 2-3 รายการที่ไม่ได้เป็นส่วนหนึ่งของไคลเอ็นต์ PJRT ของ C++ ดังนี้
    • PJRT_Client_Create ด้านล่างนี้คือตัวอย่างโค้ดจำลอง (สมมติว่า GetPluginPjRtClient แสดงผลไคลเอ็นต์ PJRT ของ C++ ที่ติดตั้งใช้งานด้านบน)
#include "third_party/tensorflow/compiler/xla/pjrt/c/pjrt_c_api_wrapper_impl.h"

namespace my_plugin {
PJRT_Error* PJRT_Client_Create(PJRT_Client_Create_Args* args) {
  std::unique_ptr<xla::PjRtClient> client = GetPluginPjRtClient();
  args->client = pjrt::CreateWrapperClient(std::move(client));
  return nullptr;
}
}  // namespace my_plugin

หมายเหตุ PJRT_Client_Create สามารถใช้ตัวเลือกที่ส่งมาจากเฟรมเวิร์ก ที่นี่เป็นตัวอย่างวิธีที่ไคลเอ็นต์ GPU ใช้ฟีเจอร์นี้

  • [ไม่บังคับ] PJRT_TopologyDescription_Create
  • [ไม่บังคับ] PJRT_Plugin_Initialize นี่เป็นการตั้งค่าปลั๊กอินแบบครั้งเดียว ซึ่งเฟรมเวิร์กจะเรียกใช้ก่อนเรียกใช้ฟังก์ชันอื่นๆ
  • [ไม่บังคับ] PJRT_Plugin_Attributes

เมื่อใช้ Wrapper คุณไม่จำเป็นต้องใช้ C API ที่เหลือ

ขั้นตอนที่ 2: ติดตั้งใช้งาน GetPjRtApi

คุณต้องติดตั้งใช้งานเมธอด GetPjRtApi ซึ่งแสดงผล PJRT_Api* ที่มีตัวชี้ฟังก์ชันไปยังการติดตั้งใช้งาน PJRT C API ด้านล่างนี้คือตัวอย่างที่สมมติว่าใช้งานผ่าน Wrapper (คล้ายกับ pjrt_c_api_cpu.cc)

const PJRT_Api* GetPjrtApi() {
  static const PJRT_Api pjrt_api =
      pjrt::CreatePjrtApi(my_plugin::PJRT_Client_Create);
  return &pjrt_api;
}

ขั้นตอนที่ 3: ทดสอบการติดตั้งใช้งาน C API

คุณสามารถเรียก RegisterPjRtCApiTestFactory เพื่อเรียกใช้ชุดการทดสอบขนาดเล็กสําหรับลักษณะการทํางานพื้นฐานของ PJRT C API

วิธีใช้ปลั๊กอิน PJRT จาก JAX

ขั้นตอนที่ 1: ตั้งค่า JAX

คุณจะใช้ JAX เวอร์ชันรายวันก็ได้

pip install --pre -U jaxlib -f <a href="https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html">https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html</a>

pip install git+<a href="https://github.com/google/jax">https://github.com/google/jax</a>

หรือสร้าง JAX จากซอร์ส

ในระหว่างนี้ คุณต้องจับคู่เวอร์ชัน jaxlib กับเวอร์ชัน PJRT C API โดยทั่วไปแล้ว การใช้ jaxlib เวอร์ชันรายวันจากวันเดียวกับการคอมมิต TF ที่คุณใช้สร้างปลั๊กอินก็เพียงพอแล้ว เช่น

pip install --pre -U jaxlib==0.4.2.dev20230103 -f <a href="https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html">https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html</a>

นอกจากนี้ คุณยังสร้าง jaxlib จากซอร์สโค้ดในคอมมิต XLA ที่คุณกำลังสร้างได้ด้วย (วิธีการ)

เราจะเริ่มรองรับความเข้ากันได้ของ ABI ในเร็วๆ นี้

ขั้นตอนที่ 2: ใช้เนมสเปซ jax_plugins หรือตั้งค่า entry_point

JAX จะค้นพบปลั๊กอินของคุณได้ 2 วิธีดังนี้

  1. การใช้แพ็กเกจเนมสเปซ (ref) กําหนดโมดูลที่ไม่ซ้ำกันทั่วโลกภายในแพ็กเกจเนมสเปซ jax_plugins (กล่าวคือ สร้างไดเรกทอรี jax_plugins และกำหนดโมดูลด้านล่าง) ตัวอย่างโครงสร้างไดเรกทอรีมีดังนี้
jax_plugins/
  my_plugin/
    __init__.py
    my_plugin.so
  1. การใช้ข้อมูลเมตาของแพ็กเกจ (ref) หากสร้างแพ็กเกจผ่าน pyproject.toml หรือ setup.py ให้แสดงชื่อโมดูลของปลั๊กอินโดยใส่จุดแรกเข้าในกลุ่ม jax_plugins ซึ่งชี้ไปยังชื่อโมดูลแบบเต็ม ตัวอย่างผ่าน pyproject.toml หรือ setup.py มีดังนี้
# use pyproject.toml
[project.entry-points.'jax_plugins']
my_plugin = 'my_plugin'

# use setup.py
entry_points={
  "jax_plugins": [
    "my_plugin = my_plugin",
  ],
}

ตัวอย่างการใช้งาน openxla-pjrt-plugin โดยใช้ตัวเลือกที่ 2 มีดังนี้ https://github.com/openxla/openxla-pjrt-plugin/pull/119, https://github.com/openxla/openxla-pjrt-plugin/pull/120

ขั้นตอนที่ 3: ใช้เมธอด initialize()

คุณต้องติดตั้งใช้งานเมธอด initialize() ในโมดูล Python เพื่อลงทะเบียนปลั๊กอิน เช่น

import os
import jax._src.xla_bridge as xb

def initialize():
  path = os.path.join(os.path.dirname(__file__), 'my_plugin.so')
  xb.register_plugin('my_plugin', priority=500, library_path=path, options=None)

โปรดดูวิธีใช้ xla_bridge.register_plugin ที่นี่ ซึ่งปัจจุบันเป็นเมธอดส่วนตัว เราจะเปิดตัว API สาธารณะในอนาคต

คุณสามารถเรียกใช้บรรทัดด้านล่างเพื่อยืนยันว่ามีการลงทะเบียนปลั๊กอินและแสดงข้อผิดพลาดหากโหลดไม่ได้

jax.config.update("jax_platforms", "my_plugin")

JAX อาจมีแบ็กเอนด์/ปลั๊กอินหลายรายการ มีตัวเลือก 2-3 รายการเพื่อให้แน่ใจว่าระบบจะใช้ปลั๊กอินของคุณเป็นแบ็กเอนด์เริ่มต้น

  • ตัวเลือกที่ 1: เรียกใช้ jax.config.update("jax_platforms", "my_plugin") ในช่วงต้นของโปรแกรม
  • ตัวเลือกที่ 2: set ENV JAX_PLATFORMS=my_plugin
  • ตัวเลือกที่ 3: ตั้งค่าลําดับความสําคัญให้สูงพอเมื่อเรียกใช้ xb.register_plugin (ค่าเริ่มต้นคือ 400 ซึ่งสูงกว่าแบ็กเอนด์ที่มีอยู่อื่นๆ) โปรดทราบว่าระบบจะใช้แบ็กเอนด์ที่มีลำดับความสำคัญสูงสุดเฉพาะเมื่อ JAX_PLATFORMS='' ค่าเริ่มต้นของ JAX_PLATFORMS คือ '' แต่บางครั้งระบบจะเขียนทับค่านี้

วิธีทดสอบด้วย JAX

กรอบการทดสอบพื้นฐานที่ควรลองมีดังนี้

# JAX 1+1
print(jax.numpy.add(1, 1))
# => 2

# jit
print(jax.jit(lambda x: x * 2)(1.))
# => 2.0

# pmap

arr = jax.numpy.arange(jax.device_count()) print(jax.pmap(lambda x: x +
jax.lax.psum(x, 'i'), axis_name='i')(arr))

# single device: [0]

# 4 devices: [6 7 8 9]

(เราจะเพิ่มวิธีการเรียกใช้การทดสอบหน่วย jax กับปลั๊กอินของคุณในเร็วๆ นี้)

ดูตัวอย่างเพิ่มเติมของปลั๊กอิน PJRT ได้ที่ตัวอย่าง PJRT