ที่มา
PJRT คือ Device API แบบเดียวกันที่เราต้องการเพิ่มลงในระบบนิเวศ ML วิสัยทัศน์ระยะยาวคือ (1) เฟรมเวิร์ก (JAX, TF เป็นต้น) จะเรียก PJRT ซึ่งจะมีการใช้งานเฉพาะอุปกรณ์ที่ไม่ชัดเจนเฟรมเวิร์ก (2) อุปกรณ์แต่ละเครื่องมุ่งเน้นที่การใช้ PJRT API และอาจไม่ชัดเจนว่าเฟรมเวิร์กมีอะไรบ้าง
เอกสารนี้มุ่งเน้นที่คำแนะนำเกี่ยวกับวิธีผสานรวมกับ PJRT และวิธีทดสอบการผสานรวม PJRT กับ JAX
วิธีผสานรวมกับ PJRT
ขั้นตอนที่ 1: ใช้อินเทอร์เฟซ PJRT C API
ตัวเลือก ก: คุณใช้ PJRT C API ได้โดยตรง
ตัวเลือก B: หากคุณสามารถสร้างโดยใช้โค้ด C++ ใน Xla Repo (ผ่าน Forking หรือ Bazel) ได้ คุณจะสามารถใช้งาน PJRT C++ API และใช้ C→C++ Wrapper
- ใช้ไคลเอ็นต์ PJRT ของ C++ ที่รับค่ามาจากไคลเอ็นต์ PJRT พื้นฐาน (และคลาส PJRT ที่เกี่ยวข้อง) ตัวอย่างไคลเอ็นต์ C++ PJRT มีดังนี้ pjrt_stream_executor_client.h, tfrt_cpu_pjrt_client.h
- ใช้เมธอด API แบบ C บางรายการที่ไม่ได้เป็นส่วนหนึ่งของไคลเอ็นต์ C++ PJRT ดังนี้
- PJRT_Client_Create ด้านล่างนี้เป็นตัวอย่างโค้ดเทียม (สมมติว่า
GetPluginPjRtClient
แสดงผลไคลเอ็นต์ C++ PJRT ที่ใช้ด้านบน)
- PJRT_Client_Create ด้านล่างนี้เป็นตัวอย่างโค้ดเทียม (สมมติว่า
#include "third_party/tensorflow/compiler/xla/pjrt/c/pjrt_c_api_wrapper_impl.h"
namespace my_plugin {
PJRT_Error* PJRT_Client_Create(PJRT_Client_Create_Args* args) {
std::unique_ptr<xla::PjRtClient> client = GetPluginPjRtClient();
args->client = pjrt::CreateWrapperClient(std::move(client));
return nullptr;
}
} // namespace my_plugin
โปรดทราบว่า PJRT_Client_Create จะใช้ตัวเลือกที่ส่งจากเฟรมเวิร์กได้ ดูตัวอย่างวิธีที่ไคลเอ็นต์ GPU ใช้ฟีเจอร์นี้ได้ที่นี่
- [ไม่บังคับ] PJRT_TopologyDescription_Create
- [ไม่บังคับ] PJRT_Plugin_Initialize นี่คือการตั้งค่าปลั๊กอินแบบครั้งเดียว ซึ่งเฟรมเวิร์กจะเรียกใช้ก่อนที่จะเรียกใช้ฟังก์ชันอื่นๆ
- [ไม่บังคับ] PJRT_Plugin_Attributes
เมื่อใช้ wrapper คุณไม่จําเป็นต้องใช้ C API ที่เหลืออยู่
ขั้นตอนที่ 2: ใช้งาน GetPjRtApi
คุณต้องใช้เมธอด GetPjRtApi
ซึ่งแสดงผล PJRT_Api*
ที่มีตัวชี้ฟังก์ชันไปยังการใช้งาน PJRT C API ด้านล่างนี้เป็นตัวอย่างที่สมมติว่ามีการใช้งานผ่าน Wrapper (คล้ายกับ pjrt_c_api_cpu.cc)
const PJRT_Api* GetPjrtApi() {
static const PJRT_Api pjrt_api =
pjrt::CreatePjrtApi(my_plugin::PJRT_Client_Create);
return &pjrt_api;
}
ขั้นตอนที่ 3: ทดสอบการใช้งาน API ของ C
คุณเรียกใช้ RegisterPjRtCApiTestFactory เพื่อเรียกใช้การทดสอบชุดเล็กๆ สำหรับลักษณะการทำงานพื้นฐานของ PJRT C API ได้
วิธีใช้ปลั๊กอิน PJRT จาก JAX
ขั้นตอนที่ 1: ตั้งค่า JAX
คุณสามารถใช้ JAX ต่อคืน
pip install --pre -U jaxlib -f <a href="https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html">https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html</a>
pip install git+<a href="https://github.com/google/jax">https://github.com/google/jax</a>
หรือ สร้าง JAX จากซอร์ส
สำหรับตอนนี้ คุณจำเป็นต้องจับคู่เวอร์ชัน jaxlib กับเวอร์ชัน PJRT C API โดยทั่วไปแล้วการใช้เวอร์ชัน jaxlib แบบ Nightly จากวันเดียวกับ TFคอมมิต ที่คุณกำลังสร้างปลั๊กอินด้วยก็เพียงพอแล้ว เช่น
pip install --pre -U jaxlib==0.4.2.dev20230103 -f <a href="https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html">https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html</a>
นอกจากนี้ คุณยังสร้าง jaxlib จากแหล่งที่มาตาม XLA ที่คอมมิตแบบที่คุณกำลังสร้างได้ด้วย (instructions)
เราจะเริ่มรองรับความเข้ากันได้กับ ABI เร็วๆ นี้
ขั้นตอนที่ 2: ใช้เนมสเปซ jax_plugins หรือตั้งค่า entry_point
มี 2 ตัวเลือกในการให้ JAX ค้นพบปลั๊กอินของคุณ
- การใช้แพ็กเกจเนมสเปซ (ref) กำหนดโมดูลที่ไม่ซ้ำกันทั่วโลกภายใต้แพ็กเกจเนมสเปซ
jax_plugins
(เช่น เพียงสร้างไดเรกทอรีjax_plugins
และกำหนดโมดูลของคุณด้านล่าง) ตัวอย่างโครงสร้างไดเรกทอรีมีดังนี้
jax_plugins/
my_plugin/
__init__.py
my_plugin.so
- การใช้ข้อมูลเมตาของแพ็กเกจ (ref) หากสร้างแพ็กเกจผ่าน pyproject.toml หรือ Setup.py ให้โฆษณาชื่อโมดูลปลั๊กอินของคุณโดยการใส่จุดเข้าถึงภายใต้กลุ่ม
jax_plugins
ซึ่งชี้ไปยังชื่อโมดูลแบบเต็ม ต่อไปนี้เป็นตัวอย่างผ่าน pyproject.toml หรือ Setup.py:
# use pyproject.toml
[project.entry-points.'jax_plugins']
my_plugin = 'my_plugin'
# use setup.py
entry_points={
"jax_plugins": [
"my_plugin = my_plugin",
],
}
ตัวอย่างการติดตั้งใช้งาน openxla-pjrt-plugin โดยใช้ตัวเลือกที่ 2: https://github.com/openxla/openxla-pjrt-plugin/pull/119, https://github.com/openxla/openxla-pjrt-plugin/pull/120
ขั้นตอนที่ 3: ใช้เมธอด initialize()
คุณต้องใช้เมธอด initialize() ในโมดูล Python เพื่อลงทะเบียนปลั๊กอิน เช่น
import os
import jax._src.xla_bridge as xb
def initialize():
path = os.path.join(os.path.dirname(__file__), 'my_plugin.so')
xb.register_plugin('my_plugin', priority=500, library_path=path, options=None)
โปรดไปที่นี่เกี่ยวกับวิธีใช้ xla_bridge.register_plugin
ปัจจุบันวิธีนี้เป็นวิธีส่วนตัว API สาธารณะจะเปิดตัวในอนาคต
คุณสามารถเรียกใช้บรรทัดด้านล่างเพื่อยืนยันว่ามีการลงทะเบียนปลั๊กอินแล้ว และแสดงข้อผิดพลาดหากโหลดไม่ได้
jax.config.update("jax_platforms", "my_plugin")
JAX อาจมีแบ็กเอนด์/ปลั๊กอินหลายรายการ คุณมีตัวเลือก 2-3 อย่างที่จะช่วยให้มั่นใจว่าระบบใช้ปลั๊กอินของคุณเป็นแบ็กเอนด์เริ่มต้นดังนี้
- ตัวเลือกที่ 1: เรียกใช้
jax.config.update("jax_platforms", "my_plugin")
ในช่วงเริ่มต้นของโปรแกรม - ตัวเลือกที่ 2: ตั้งค่า ENV
JAX_PLATFORMS=my_plugin
- ตัวเลือกที่ 3: กำหนดลำดับความสำคัญสูงพอเมื่อเรียกใช้ xb.register_plugin (ค่าเริ่มต้นคือ 400 ซึ่งสูงกว่าแบ็กเอนด์อื่นๆ ที่มีอยู่) โปรดทราบว่าระบบจะใช้แบ็กเอนด์ที่มีลำดับความสำคัญสูงสุดเมื่อ
JAX_PLATFORMS=''
เท่านั้น ค่าเริ่มต้นของJAX_PLATFORMS
คือ''
แต่บางครั้งอาจมีการเขียนทับ
วิธีทดสอบด้วย JAX
ตัวอย่างกรอบการทดสอบพื้นฐานที่ควรลองใช้มีดังนี้
# JAX 1+1
print(jax.numpy.add(1, 1))
# => 2
# jit
print(jax.jit(lambda x: x * 2)(1.))
# => 2.0
# pmap
arr = jax.numpy.arange(jax.device_count()) print(jax.pmap(lambda x: x +
jax.lax.psum(x, 'i'), axis_name='i')(arr))
# single device: [0]
# 4 devices: [6 7 8 9]
(เราจะเพิ่มวิธีการเรียกใช้การทดสอบหน่วย jax กับปลั๊กอินของคุณเร็วๆ นี้)
ตัวอย่าง: ปลั๊กอิน JAX CUDA
- การติดตั้งใช้งาน PJRT C API ผ่าน Wrapper (pjrt_c_api_gpu.h)
- ตั้งค่าจุดแรกเข้าสำหรับแพ็กเกจ (setup.py)
- ใช้เมธอด initialize() (__init__.py)
- สามารถทดสอบด้วยการทดสอบ jax สำหรับ CUDA ```