ข้อมูลเบื้องต้น
PJRT คือ Uniform Device API ที่เราต้องการเพิ่มลงในระบบนิเวศ ML วิสัยทัศน์ระยะยาวคือ
- เฟรมเวิร์ก (JAX, TF ฯลฯ) จะเรียกใช้ PJRT ซึ่งมีการใช้งานเฉพาะอุปกรณ์ที่เฟรมเวิร์กมองไม่เห็น
- อุปกรณ์แต่ละเครื่องมุ่งเน้นที่การใช้ PJRT API และอาจไม่รองรับเฟรมเวิร์ก
PJRT มีทั้ง C API และ C++ API การเชื่อมต่อที่เลเยอร์ใดเลเยอร์หนึ่งนั้นใช้ได้ C++ API ใช้คลาสเพื่อแยกแนวคิดบางอย่างออก แต่ก็มีความสัมพันธ์กับประเภทข้อมูล XLA มากกว่า หน้านี้จะเน้นที่ C++ API
คอมโพเนนต์ PJRT
PjRtClient
ดูข้อมูลอ้างอิงฉบับเต็มได้ที่ pjrt_client.h > PjRtClient
ไคลเอ็นต์จะจัดการการสื่อสารทั้งหมดระหว่างอุปกรณ์กับเฟรมเวิร์ก และรวมสถานะทั้งหมดที่ใช้ในการสื่อสาร โดยจะมีชุด API ทั่วไปสําหรับการโต้ตอบกับปลั๊กอิน PJRT และเป็นเจ้าของอุปกรณ์และพื้นที่หน่วยความจําสําหรับปลั๊กอินหนึ่งๆ
PjRtDevice
ดูข้อมูลอ้างอิงฉบับเต็มได้ที่ pjrt_client.h > PjRtDevice
และ pjrt_device_description.h
คลาสอุปกรณ์ใช้อธิบายอุปกรณ์เครื่องเดียว อุปกรณ์มีคำอธิบายอุปกรณ์เพื่อช่วยระบุประเภท (แฮชที่ไม่ซ้ำกันเพื่อระบุ GPU/CPU/xPU) และตำแหน่งภายในตารางกริดของอุปกรณ์ทั้งในระดับท้องถิ่นและทั่วโลก
อุปกรณ์ยังทราบพื้นที่หน่วยความจำที่เชื่อมโยงและลูกค้าที่เป็นเจ้าของพื้นที่หน่วยความจำนั้นด้วย
อุปกรณ์ไม่จําเป็นต้องทราบบัฟเฟอร์ของข้อมูลจริงที่เชื่อมโยงกับอุปกรณ์ แต่สามารถหาข้อมูลดังกล่าวได้โดยดูจากพื้นที่หน่วยความจําที่เกี่ยวข้อง
PjRtMemorySpace
ดูข้อมูลอ้างอิงฉบับเต็มได้ที่ pjrt_client.h > PjRtMemorySpace
พื้นที่หน่วยความจำสามารถใช้เพื่ออธิบายตำแหน่งของหน่วยความจำได้ โดยคุณสามารถเลิกปักหมุดและวางไว้ที่ใดก็ได้ แต่ต้องเข้าถึงได้จากอุปกรณ์ หรือจะปักหมุดไว้และวางไว้ในอุปกรณ์ที่ต้องการก็ได้
พื้นที่หน่วยความจำจะทราบบัฟเฟอร์ข้อมูลที่เชื่อมโยงอยู่ รวมถึงอุปกรณ์ (หลายเครื่อง) ที่เชื่อมโยงกับพื้นที่หน่วยความจำนั้น และไคลเอ็นต์ที่เป็นส่วนหนึ่งของพื้นที่หน่วยความจำ
PjRtBuffer
ดูข้อมูลฉบับเต็มได้ที่ pjrt_client.h > PjRtBuffer
บัฟเฟอร์จะเก็บข้อมูลในอุปกรณ์ในรูปแบบที่ใช้งานได้ง่ายภายในปลั๊กอิน เช่น แอตทริบิวต์องค์ประกอบ MLIR หรือรูปแบบ Tensor ที่เป็นกรรมสิทธิ์
เฟรมเวิร์กอาจพยายามส่งข้อมูลไปยังอุปกรณ์ในรูปแบบ xla::Literal
นั่นคือสำหรับอาร์กิวเมนต์อินพุตไปยังโมดูล ซึ่งต้องโคลน (หรือยืม) ไปยังหน่วยความจำของอุปกรณ์ เมื่อไม่จำเป็นต้องใช้บัฟเฟอร์อีกต่อไป เฟรมเวิร์กจะเรียกใช้เมธอด Delete
เพื่อล้างข้อมูล
บัฟเฟอร์จะรู้พื้นที่หน่วยความจำที่ตนเองเป็นส่วนประกอบ และโดยจะคิดหาว่าอุปกรณ์ใดเข้าถึงได้ แต่บัฟเฟอร์ไม่อาจรู้อุปกรณ์ของตนเสมอไป
สําหรับการสื่อสารกับเฟรมเวิร์ก บัฟเฟอร์จะรู้วิธีแปลงจากและกลับเป็นประเภท xla::Literal
ต่อไปนี้
// Literal to Buffer
absl::StatusOr<std::unique_ptr<PjRtBuffer>> BufferFromHostBuffer(...) {...}
// Buffer to Literal
xla::PjRtFuture<> ToLiteral(xla::MutableLiteralBase* literal) override {...}
API สําหรับสร้างบัฟเฟอร์จะมีความหมายของบัฟเฟอร์ ซึ่งช่วยระบุว่าข้อมูลตามตัวอักษรจากบัฟเฟอร์ของโฮสต์สามารถแชร์หรือคัดลอก หรือเปลี่ยนแปลงหรือไม่
สุดท้าย บัฟเฟอร์อาจต้องมีอายุการใช้งานนานกว่าขอบเขตของการดำเนินการ หากมีการกําหนดค่าให้กับตัวแปรในเลเยอร์เฟรมเวิร์ก x = jit(foo)(10)
ในกรณีเหล่านี้ บัฟเฟอร์จะอนุญาตให้สร้างการอ้างอิงภายนอกซึ่งให้พอยน์เตอร์ที่เป็นเจ้าของชั่วคราวไปยังข้อมูลที่บัฟเฟอร์เก็บไว้ พร้อมกับข้อมูลเมตา (dtype / ขนาดมิติข้อมูล) สำหรับการตีความข้อมูลพื้นฐาน
PjRtCompiler
ดูข้อมูลอ้างอิงฉบับเต็มได้ที่ pjrt_compiler.h > PjRtCompiler
คลาส PjRtCompiler
ให้รายละเอียดการใช้งานที่เป็นประโยชน์สําหรับแบ็กเอนด์ XLA แต่ไม่จำเป็นต้องใช้กับปลั๊กอิน ในทางทฤษฎีแล้ว ความรับผิดชอบของ PjRtCompiler
หรือเมธอด PjRtClient::Compile
คือการนำโมดูลอินพุตมาแสดงผลเป็น PjRtLoadedExecutable
PjRtExecutable / PjRtLoadedExecutable
ดูข้อมูลฉบับเต็มได้ที่ pjrt_executable.h > PjRtExecutable
และ pjrt_client.h > PjRtLoadedExecutable
PjRtExecutable
จะทราบวิธีใช้อาร์ติแฟกต์ที่คอมไพล์แล้วและตัวเลือกการดําเนินการ และจัดรูปแบบ/แยกรูปแบบอาร์ติแฟกต์ดังกล่าวเพื่อให้จัดเก็บและโหลดไฟล์ปฏิบัติการได้ตามต้องการ
PjRtLoadedExecutable
คือไฟล์ปฏิบัติการที่คอมไพล์ไว้ในหน่วยความจำซึ่งพร้อมที่จะดำเนินการกับอาร์กิวเมนต์อินพุต และเป็นคลาสย่อยของ PjRtExecutable
ไฟล์ปฏิบัติการจะติดต่อผ่านเมธอด Execute
ของไคลเอ็นต์วิธีใดวิธีหนึ่งต่อไปนี้
// Execute on addressable devices
absl::StatusOr<std::vector<std::vector<std::unique_ptr<PjRtBuffer>>>>
Execute(absl::Span<const std::vector<PjRtBuffer*>> argument_handles, ...) {...}
// Execute assigned replica/partition on the specified device
absl::StatusOr<std::vector<std::unique_ptr<PjRtBuffer>>>
ExecuteSharded(absl::Span<PjRtBuffer* const> argument_handles,
PjRtDevice* device, ...) {...}
// Execute on specified device, single replica / partition
absl::StatusOr<std::vector<std::unique_ptr<PjRtBuffer>>>
ExecutePortable(absl::Span<PjRtBuffer* const> argument_handles,
PjRtDevice* device, ...) {...}
ก่อนที่จะเรียกใช้ Execute
เฟรมเวิร์กจะโอนข้อมูลที่จำเป็นทั้งหมดไปยัง PjRtBuffers
ที่เป็นของไคลเอ็นต์ที่กำลังดำเนินการ แต่ส่งคืนเพื่อให้เฟรมเวิร์กที่ใช้อ้างอิง จากนั้นระบบจะส่งบัฟเฟอร์เหล่านี้เป็นอาร์กิวเมนต์ไปยังเมธอด Execute
แนวคิด PJRT
PjRtFutures และการประมวลผลแบบไม่พร้อมกัน
หากมีการใช้ส่วนใดส่วนหนึ่งของปลั๊กอินแบบไม่พร้อมกัน ต้องใช้ฟิวเจอร์อย่างถูกต้อง
ลองพิจารณาโปรแกรมต่อไปนี้
@jax.jit
def foo(x): return x + 1
x = foo(1)
# [...] other logic not using `x`
print(x + 1)
ปลั๊กอินอะซิงโครนัสจะสามารถจัดคิวการคำนวณ x
ได้ และแสดงผลบัฟเฟอร์ที่ยังไม่พร้อมอ่านทันที แต่คำสั่งจะแสดงขึ้นมา การดำเนินการสามารถจัดคิวการคํานวณที่จําเป็นหลังจาก x
ต่อไปได้ ซึ่งไม่จําเป็นต้องใช้ x
รวมถึงการดำเนินการในอุปกรณ์ PJRT เครื่องอื่นๆ เมื่อต้องการค่าของ x
ระบบจะบล็อกการดำเนินการจนกว่าบัฟเฟอร์จะประกาศว่าพร้อมแล้วผ่านค่าที่ GetReadyFuture
แสดงผลในอนาคต
ฟิวเจอร์สมีประโยชน์ในการระบุว่าออบเจ็กต์จะพร้อมใช้งานเมื่อใด ซึ่งรวมถึงอุปกรณ์และบัฟเฟอร์
แนวคิดขั้นสูง
การใช้ API พื้นฐานเพิ่มเติมจะขยายฟีเจอร์ของ JAX ที่ปลั๊กอินใช้ได้ ฟีเจอร์เหล่านี้เป็นฟีเจอร์ที่ต้องเลือกใช้ทั้งหมด ในแง่ที่ว่า JIT ทั่วไปและเวิร์กโฟลว์การดำเนินการจะทำงานได้โดยไม่ต้องใช้ฟีเจอร์เหล่านี้ แต่สำหรับไปป์ไลน์คุณภาพระดับเวอร์ชันที่ใช้งานจริง คุณควรพิจารณาระดับการรองรับฟีเจอร์ต่อไปนี้ที่ PJRT API รองรับ
- พื้นที่หน่วยความจำ
- การจัดวางที่กำหนดเอง
- การดำเนินการสื่อสาร เช่น ส่ง/รับ
- การลดภาระงานของโฮสต์
- การจัดสรร
การสื่อสารระหว่างเฟรมเวิร์ก PJRT กับอุปกรณ์ทั่วไป
บันทึกตัวอย่าง
ต่อไปนี้เป็นบันทึกของเมธอดที่เรียกให้โหลดปลั๊กอิน PJRT และเรียกใช้ y = jax.jit(lambda x: jnp.power(x, jnp.int32(2)))(1)
ในกรณีนี้ เราบันทึก JAX ที่โต้ตอบกับปลั๊กอิน PJRT อ้างอิง StableHLO
บันทึกตัวอย่าง
////////////////////////////////// // Load the plugin ////////////////////////////////// I client_cpp_pjrt.cc:55] StablehloReferencePjrtClient(0x23bac400) I device.cc:53] StablehloReferenceDeviceDescription(0x23bac4f8) I device.cc:104] StablehloReferenceDevice(0x23bac4e0) I device.cc:123] client(0x23bac4e0) I device.cc:123] client(0x23bac4e0) I client_cpp_pjrt.cc:71] process_index(0x23bac400) I client_cpp_pjrt.cc:67] platform_name(0x23bac400) I device.cc:143] AttachDefaultMemorySpace(0x23bac4e0) I client_cpp_pjrt.cc:67] platform_name(0x23bac400) I client_cpp_pjrt.cc:86] devices(0x23bac400) I client_cpp_pjrt.cc:81] addressable_device_count(0x23bac400) I device.cc:168] description(0x23bac4e0) I device.cc:168] description(0x23bac4e0) I device.cc:86] Attributes(0x23bac4f8) I device.cc:128] IsAddressable(0x23bac4e0) I device.cc:168] description(0x23bac4e0) I device.cc:61] process_index(0x23bac4f8) I device.cc:123] client(0x23bac4e0) I client_cpp_pjrt.cc:71] process_index(0x23bac400) I client_cpp_pjrt.cc:81] addressable_device_count(0x23bac400) I client_cpp_pjrt.cc:95] memory_spaces(0x23bac400) I device.cc:128] IsAddressable(0x23bac4e0) I device.cc:168] description(0x23bac4e0) I device.cc:61] process_index(0x23bac4f8) I device.cc:123] client(0x23bac4e0) I client_cpp_pjrt.cc:71] process_index(0x23bac400) I device.cc:148] memory_spaces(0x23bac4e0) Creating PJRT Client from client I client_cpp_pjrt.cc:108] platform_version(0x23bac400) I client_cpp_pjrt.cc:67] platform_name(0x23bac400) I device.cc:57] id(0x23bac4f8) I device.cc:70] device_kind(0x23bac4f8) I device.cc:70] device_kind(0x23bac4f8) I device.cc:80] ToString(0x23bac4f8) I device.cc:80] ToString(0x23bac4f8) I device.cc:75] DebugString(0x23bac4f8) I device.cc:75] DebugString(0x23bac4f8) I device.cc:61] process_index(0x23bac4f8) I device.cc:128] IsAddressable(0x23bac4e0) I device.cc:168] description(0x23bac4e0) I device.cc:61] process_index(0x23bac4f8) I device.cc:123] client(0x23bac4e0) I client_cpp_pjrt.cc:71] process_index(0x23bac400) I device.cc:153] default_memory_space(0x23bac4e0) I client_cpp_pjrt.cc:71] process_index(0x23bac400) ////////////////////////////////// // RUN: `y = jax.jit(lambda x: jnp.power(x, jnp.int32(2)))(1)` ////////////////////////////////// I executable.cc:309] num_partitions(0x240bab70) I executable.cc:305] num_replicas(0x240bab70) I executable.cc:309] num_partitions(0x240bab70) I client_cpp_pjrt.cc:233] BufferFromHostBuffer(0x23bac400) I buffer.cc:285] CreateMlirBufferFromLiteral I buffer.cc:98] CreateFromLiteral I buffer.cc:99] CreateFromLiteral: s32[] 2 I buffer.cc:64] MlirPjrtBuffer(0x240bb050) I buffer.cc:102] CreateFromLiteral -> 0x240bb050 I buffer.cc:158] device(0x240bb050) I buffer.cc:154] memory_space(0x240bb050) I buffer.cc:154] memory_space(0x240bb050) I executable.cc:328] GetHloModules(0x240bab70) I executable.cc:240] Execute(0x240bab70) I executable.cc:197] ExecuteWithReferenceInterpreter(0x240bab70) I buffer.cc:303] GetAttributeFromBuffer I buffer.cc:229] IsDeleted(0x240bb050) I buffer.cc:311] GetAttributeFromBuffer(0x240bb050) -> dense<2> : tensor<i32> I executable.cc:205] EvalModule: module @jit attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} { func.func public @main(%arg0: tensor<i32> {mhlo.layout_mode = "default"}) -> (tensor<i32> {jax.result_info = "", mhlo.layout_mode = "default"}) { // ... return %3 : tensor<i32> } } I executable.cc:206] Inputs: [dense<2> : tensor<i32>] I executable.cc:213] Results: [dense<2> : tensor<i32>] I device.cc:153] default_memory_space(0x23bac4e0) I buffer.cc:291] CreateMlirBufferFromAttribute I buffer.cc:116] CreateFromAttribute I buffer.cc:64] MlirPjrtBuffer(0x22cea630) I buffer.cc:122] CreateFromAttribute(dense<2> : tensor<i32>) -> 0x22cea630 ////////////////////////////////// // RUN: `print(y)` ////////////////////////////////// I buffer.cc:263] GetReadyFuture(0x22cea630) I buffer.cc:264] GetReadyFuture(0x22cea630) I buffer.cc:154] memory_space(0x22cea630) I buffer.cc:154] memory_space(0x22cea630) I buffer.cc:158] device(0x22cea630) I buffer.cc:158] device(0x22cea630) I buffer.cc:154] memory_space(0x22cea630) I buffer.cc:154] memory_space(0x22cea630) I buffer.cc:229] IsDeleted(0x22cea630) I buffer.cc:129] on_device_shape(0x22cea630) I buffer.cc:129] on_device_shape(0x22cea630) I buffer.cc:129] on_device_shape(0x22cea630) I buffer.cc:158] device(0x22cea630) I buffer.cc:154] memory_space(0x22cea630) I buffer.cc:154] memory_space(0x22cea630) I client_cpp_pjrt.cc:71] process_index(0x23bac400) I buffer.cc:229] IsDeleted(0x22cea630) I buffer.cc:129] on_device_shape(0x22cea630) I buffer.cc:129] on_device_shape(0x22cea630) I buffer.cc:269] IsOnCpu(0x22cea630) # Returns true, allows external references. I buffer.cc:129] on_device_shape(0x22cea630) I buffer.cc:129] on_device_shape(0x22cea630) I buffer.cc:129] on_device_shape(0x22cea630) I buffer.cc:129] on_device_shape(0x22cea630) I buffer.cc:129] on_device_shape(0x22cea630) I buffer.cc:168] AcquireExternalReference(0x22cea630) I buffer.cc:73] MlirClonedExternalReference(0x2404d560) I buffer.cc:303] GetAttributeFromBuffer I buffer.cc:229] IsDeleted(0x22cea630) I buffer.cc:311] GetAttributeFromBuffer(0x22cea630) -> dense<2> : tensor<i32> I buffer.cc:291] CreateMlirBufferFromAttribute I buffer.cc:116] CreateFromAttribute I buffer.cc:64] MlirPjrtBuffer(0x240bb050) I buffer.cc:122] CreateFromAttribute(dense<2> : tensor<i32>) -> 0x240bb050 I buffer.cc:168] AcquireExternalReference(0x22cea630) I buffer.cc:73] MlirClonedExternalReference(0x240b6010) I buffer.cc:303] GetAttributeFromBuffer I buffer.cc:229] IsDeleted(0x22cea630) I buffer.cc:311] GetAttributeFromBuffer(0x22cea630) -> dense<2> : tensor<i32> I buffer.cc:291] CreateMlirBufferFromAttribute I buffer.cc:116] CreateFromAttribute I buffer.cc:64] MlirPjrtBuffer(0x23b2db60) I buffer.cc:122] CreateFromAttribute(dense<2> : tensor<i32>) -> 0x23b2db60 I buffer.cc:263] GetReadyFuture(0x22cea630) I buffer.cc:264] GetReadyFuture(0x22cea630)