ภาพรวมของ PJRT C++ Device API

ข้อมูลเบื้องต้น

PJRT เป็น Device API มาตรฐานที่เราต้องการเพิ่มลงในระบบนิเวศ ML วิสัยทัศน์ระยะยาวคือ

  1. เฟรมเวิร์ก (JAX, TF ฯลฯ) จะเรียกใช้ PJRT ซึ่งมีการใช้งานเฉพาะอุปกรณ์ที่เฟรมเวิร์กมองไม่เห็น
  2. อุปกรณ์แต่ละเครื่องมุ่งเน้นที่การใช้ PJRT API และอาจไม่รองรับเฟรมเวิร์ก

PJRT มีทั้ง C API และ C++ API การเชื่อมต่อที่เลเยอร์ใดเลเยอร์หนึ่งนั้นใช้ได้ C++ API ใช้คลาสเพื่อแยกแนวคิดบางอย่างออก แต่ก็มีความสัมพันธ์กับประเภทข้อมูล XLA มากกว่า หน้านี้จะเน้นที่ C++ API

คอมโพเนนต์ PJRT

คอมโพเนนต์ PJRT

PjRtClient

ดูข้อมูลอ้างอิงฉบับเต็มได้ที่ pjrt_client.h > PjRtClient

ไคลเอ็นต์จะจัดการการสื่อสารทั้งหมดระหว่างอุปกรณ์กับเฟรมเวิร์ก และรวมสถานะทั้งหมดที่ใช้ในการสื่อสาร โดยจะมีชุด API ทั่วไปสําหรับการโต้ตอบกับปลั๊กอิน PJRT และเป็นเจ้าของอุปกรณ์และพื้นที่หน่วยความจําสําหรับปลั๊กอินหนึ่งๆ

PjRtDevice

ดูข้อมูลอ้างอิงฉบับเต็มได้ที่ pjrt_client.h > PjRtDevice และ pjrt_device_description.h

คลาสอุปกรณ์ใช้อธิบายอุปกรณ์เครื่องเดียว อุปกรณ์มีคำอธิบายอุปกรณ์เพื่อช่วยระบุประเภท (แฮชที่ไม่ซ้ำกันเพื่อระบุ GPU/CPU/xPU) และตำแหน่งภายในตารางกริดของอุปกรณ์ทั้งในระดับท้องถิ่นและทั่วโลก

อุปกรณ์ยังทราบพื้นที่หน่วยความจำที่เชื่อมโยงและลูกค้าที่เป็นเจ้าของพื้นที่หน่วยความจำนั้นด้วย

อุปกรณ์ไม่จําเป็นต้องทราบบัฟเฟอร์ของข้อมูลจริงที่เชื่อมโยงกับอุปกรณ์ แต่สามารถหาข้อมูลดังกล่าวได้โดยดูจากพื้นที่หน่วยความจําที่เกี่ยวข้อง

PjRtMemorySpace

ดูข้อมูลอ้างอิงฉบับเต็มได้ที่ pjrt_client.h > PjRtMemorySpace

พื้นที่หน่วยความจําสามารถใช้เพื่ออธิบายตําแหน่งของหน่วยความจํา โดยคุณสามารถเลิกปักหมุดและวางไว้ที่ใดก็ได้ แต่ต้องเข้าถึงได้จากอุปกรณ์ หรือจะปักหมุดไว้และวางไว้ในอุปกรณ์ที่ต้องการก็ได้

พื้นที่หน่วยความจำจะทราบบัฟเฟอร์ข้อมูลที่เชื่อมโยงอยู่ รวมถึงอุปกรณ์ (หลายเครื่อง) ที่เชื่อมโยงกับพื้นที่หน่วยความจำนั้น และไคลเอ็นต์ที่เป็นส่วนหนึ่งของพื้นที่หน่วยความจำ

PjRtBuffer

ดูข้อมูลอ้างอิงฉบับเต็มได้ที่ pjrt_client.h > PjRtBuffer

บัฟเฟอร์จะเก็บข้อมูลในอุปกรณ์ในรูปแบบที่ใช้งานได้ง่ายภายในปลั๊กอิน เช่น แอตทริบิวต์องค์ประกอบ MLIR หรือรูปแบบ Tensor ที่เป็นกรรมสิทธิ์ เฟรมเวิร์กอาจพยายามส่งข้อมูลไปยังอุปกรณ์ในรูปแบบ xla::Literal นั่นคือสำหรับอาร์กิวเมนต์อินพุตไปยังโมดูล ซึ่งต้องโคลน (หรือยืม) ไปยังหน่วยความจำของอุปกรณ์ เมื่อไม่จำเป็นต้องใช้บัฟเฟอร์อีกต่อไป เฟรมเวิร์กจะเรียกใช้เมธอด Delete เพื่อล้างข้อมูล

บัฟเฟอร์จะทราบพื้นที่หน่วยความจำที่ตนเป็นส่วนหนึ่งของ และสามารถระบุอุปกรณ์ที่เข้าถึงบัฟเฟอร์ได้ แต่บัฟเฟอร์ไม่จำเป็นต้องทราบอุปกรณ์ของตน

สําหรับการสื่อสารกับเฟรมเวิร์ก บัฟเฟอร์จะรู้วิธีแปลงจากและกลับเป็นประเภท xla::Literal ต่อไปนี้

// Literal to Buffer
absl::StatusOr<std::unique_ptr<PjRtBuffer>> BufferFromHostBuffer(...) {...}

// Buffer to Literal
xla::PjRtFuture<> ToLiteral(xla::MutableLiteralBase* literal) override {...}

API สำหรับการสร้างบัฟเฟอร์มีความหมายของบัฟเฟอร์ซึ่งช่วยกำหนดว่าจะแชร์หรือคัดลอกหรือเปลี่ยนรูปแบบข้อมูลตัวอักษรจากบัฟเฟอร์ของโฮสต์ได้หรือไม่

สุดท้าย บัฟเฟอร์อาจต้องมีอายุการใช้งานนานกว่าขอบเขตของการดำเนินการ หากมีการกําหนดค่าให้กับตัวแปรในเลเยอร์เฟรมเวิร์ก x = jit(foo)(10) ในกรณีเหล่านี้ บัฟเฟอร์จะอนุญาตให้สร้างการอ้างอิงภายนอกซึ่งให้พอยน์เตอร์ที่เป็นเจ้าของชั่วคราวไปยังข้อมูลที่บัฟเฟอร์เก็บไว้ พร้อมกับข้อมูลเมตา (dtype / ขนาดมิติข้อมูล) สำหรับการตีความข้อมูลพื้นฐาน

PjRtCompiler

ดูข้อมูลอ้างอิงฉบับเต็มได้ที่ pjrt_compiler.h > PjRtCompiler

คลาส PjRtCompiler ให้รายละเอียดการใช้งานที่เป็นประโยชน์สําหรับแบ็กเอนด์ XLA แต่ไม่จำเป็นต้องใช้กับปลั๊กอิน ในทางทฤษฎีแล้ว ความรับผิดชอบของ PjRtCompiler หรือเมธอด PjRtClient::Compile คือการนำโมดูลอินพุตมาแสดงผลเป็น PjRtLoadedExecutable

PjRtExecutable / PjRtLoadedExecutable

ดูข้อมูลอ้างอิงฉบับเต็มได้ที่ pjrt_executable.h > PjRtExecutable และ pjrt_client.h > PjRtLoadedExecutable

PjRtExecutable จะทราบวิธีใช้อาร์ติแฟกต์ที่คอมไพล์แล้วและตัวเลือกการดําเนินการ และจัดรูปแบบ/แยกรูปแบบอาร์ติแฟกต์ดังกล่าวเพื่อให้จัดเก็บและโหลดไฟล์ปฏิบัติการได้ตามต้องการ

PjRtLoadedExecutable คือไฟล์ปฏิบัติการที่คอมไพล์ไว้ในหน่วยความจำซึ่งพร้อมที่จะดำเนินการกับอาร์กิวเมนต์อินพุต และเป็นคลาสย่อยของ PjRtExecutable

ไฟล์ปฏิบัติการจะติดต่อผ่านเมธอด Execute ของไคลเอ็นต์วิธีใดวิธีหนึ่งต่อไปนี้

// Execute on addressable devices
absl::StatusOr<std::vector<std::vector<std::unique_ptr<PjRtBuffer>>>>
Execute(absl::Span<const std::vector<PjRtBuffer*>> argument_handles, ...) {...}

// Execute assigned replica/partition on the specified device
absl::StatusOr<std::vector<std::unique_ptr<PjRtBuffer>>>
ExecuteSharded(absl::Span<PjRtBuffer* const> argument_handles,
              PjRtDevice* device, ...) {...}

// Execute on specified device, single replica / partition
absl::StatusOr<std::vector<std::unique_ptr<PjRtBuffer>>>
ExecutePortable(absl::Span<PjRtBuffer* const> argument_handles,
                PjRtDevice* device, ...) {...}

ก่อนเรียก Execute เฟรมเวิร์กจะโอนข้อมูลที่จำเป็นทั้งหมดไปยัง PjRtBuffers ที่เป็นของไคลเอ็นต์ที่ดำเนินการ แต่ส่งคืนเพื่อให้เฟรมเวิร์กใช้อ้างอิง จากนั้นระบบจะส่งบัฟเฟอร์เหล่านี้เป็นอาร์กิวเมนต์ไปยังเมธอด Execute

แนวคิด PJRT

PjRtFutures และการประมวลผลแบบไม่พร้อมกัน

หากมีการใช้ส่วนใดส่วนหนึ่งของปลั๊กอินแบบไม่พร้อมกัน ต้องใช้ฟิวเจอร์อย่างถูกต้อง

ลองพิจารณาโปรแกรมต่อไปนี้

@jax.jit
def foo(x): return x + 1

x = foo(1)
# [...] other logic not using `x`
print(x + 1)

ปลั๊กอินแบบแอสซิงค์จะจัดคิวการคํานวณ x และแสดงผลบัฟเฟอร์ทันที ซึ่งยังไม่พร้อมให้อ่าน แต่การดําเนินการจะสร้างบัฟเฟอร์ขึ้นมา การดำเนินการสามารถจัดคิวการคํานวณที่จําเป็นหลังจาก x ต่อไปได้ ซึ่งไม่จําเป็นต้องใช้ x รวมถึงการดำเนินการในอุปกรณ์ PJRT เครื่องอื่นๆ เมื่อต้องการค่าของ x ระบบจะบล็อกการดำเนินการจนกว่าบัฟเฟอร์จะประกาศว่าพร้อมแล้วผ่านค่าที่ GetReadyFuture แสดงผลในอนาคต

ฟิวเจอร์สมีประโยชน์ในการระบุว่าออบเจ็กต์จะพร้อมใช้งานเมื่อใด ซึ่งรวมถึงอุปกรณ์และบัฟเฟอร์

แนวคิดขั้นสูง

การใช้ API พื้นฐานเพิ่มเติมจะขยายฟีเจอร์ของ JAX ที่ปลั๊กอินใช้ได้ ฟีเจอร์เหล่านี้เป็นฟีเจอร์ที่ต้องเลือกใช้ทั้งหมด ในแง่ที่ว่า JIT ทั่วไปและเวิร์กโฟลว์การดำเนินการจะทำงานได้โดยไม่ต้องใช้ฟีเจอร์เหล่านี้ แต่สำหรับไปป์ไลน์คุณภาพระดับเวอร์ชันที่ใช้งานจริง คุณควรพิจารณาระดับการรองรับฟีเจอร์ต่อไปนี้ที่ PJRT API รองรับ

  • พื้นที่หน่วยความจำ
  • การจัดวางที่กำหนดเอง
  • การดำเนินการสื่อสาร เช่น ส่ง/รับ
  • การย้ายข้อมูลไปยังโฮสต์
  • การจัดสรร

การสื่อสารระหว่างเฟรมเวิร์ก PJRT กับอุปกรณ์ทั่วไป

ตัวอย่างบันทึก

ต่อไปนี้เป็นบันทึกของเมธอดที่เรียกให้โหลดปลั๊กอิน PJRT และดำเนินการ y = jax.jit(lambda x: jnp.power(x, jnp.int32(2)))(1) ในกรณีนี้ เราบันทึก JAX ที่โต้ตอบกับปลั๊กอิน PJRT อ้างอิง StableHLO

บันทึกตัวอย่าง

//////////////////////////////////
// Load the plugin
//////////////////////////////////

I client_cpp_pjrt.cc:55] StablehloReferencePjrtClient(0x23bac400)
I device.cc:53] StablehloReferenceDeviceDescription(0x23bac4f8)
I device.cc:104] StablehloReferenceDevice(0x23bac4e0)
I device.cc:123] client(0x23bac4e0)
I device.cc:123] client(0x23bac4e0)
I client_cpp_pjrt.cc:71] process_index(0x23bac400)
I client_cpp_pjrt.cc:67] platform_name(0x23bac400)
I device.cc:143] AttachDefaultMemorySpace(0x23bac4e0)
I client_cpp_pjrt.cc:67] platform_name(0x23bac400)
I client_cpp_pjrt.cc:86] devices(0x23bac400)
I client_cpp_pjrt.cc:81] addressable_device_count(0x23bac400)
I device.cc:168] description(0x23bac4e0)
I device.cc:168] description(0x23bac4e0)
I device.cc:86] Attributes(0x23bac4f8)
I device.cc:128] IsAddressable(0x23bac4e0)
I device.cc:168] description(0x23bac4e0)
I device.cc:61] process_index(0x23bac4f8)
I device.cc:123] client(0x23bac4e0)
I client_cpp_pjrt.cc:71] process_index(0x23bac400)
I client_cpp_pjrt.cc:81] addressable_device_count(0x23bac400)
I client_cpp_pjrt.cc:95] memory_spaces(0x23bac400)
I device.cc:128] IsAddressable(0x23bac4e0)
I device.cc:168] description(0x23bac4e0)
I device.cc:61] process_index(0x23bac4f8)
I device.cc:123] client(0x23bac4e0)
I client_cpp_pjrt.cc:71] process_index(0x23bac400)
I device.cc:148] memory_spaces(0x23bac4e0)
Creating PJRT Client from client
I client_cpp_pjrt.cc:108] platform_version(0x23bac400)
I client_cpp_pjrt.cc:67] platform_name(0x23bac400)
I device.cc:57] id(0x23bac4f8)
I device.cc:70] device_kind(0x23bac4f8)
I device.cc:70] device_kind(0x23bac4f8)
I device.cc:80] ToString(0x23bac4f8)
I device.cc:80] ToString(0x23bac4f8)
I device.cc:75] DebugString(0x23bac4f8)
I device.cc:75] DebugString(0x23bac4f8)
I device.cc:61] process_index(0x23bac4f8)
I device.cc:128] IsAddressable(0x23bac4e0)
I device.cc:168] description(0x23bac4e0)
I device.cc:61] process_index(0x23bac4f8)
I device.cc:123] client(0x23bac4e0)
I client_cpp_pjrt.cc:71] process_index(0x23bac400)
I device.cc:153] default_memory_space(0x23bac4e0)
I client_cpp_pjrt.cc:71] process_index(0x23bac400)

//////////////////////////////////
// RUN: `y = jax.jit(lambda x: jnp.power(x, jnp.int32(2)))(1)`
//////////////////////////////////

I executable.cc:309] num_partitions(0x240bab70)
I executable.cc:305] num_replicas(0x240bab70)
I executable.cc:309] num_partitions(0x240bab70)
I client_cpp_pjrt.cc:233] BufferFromHostBuffer(0x23bac400)
I buffer.cc:285] CreateMlirBufferFromLiteral
I buffer.cc:98] CreateFromLiteral
I buffer.cc:99] CreateFromLiteral: s32[] 2
I buffer.cc:64] MlirPjrtBuffer(0x240bb050)
I buffer.cc:102] CreateFromLiteral -> 0x240bb050
I buffer.cc:158] device(0x240bb050)
I buffer.cc:154] memory_space(0x240bb050)
I buffer.cc:154] memory_space(0x240bb050)
I executable.cc:328] GetHloModules(0x240bab70)
I executable.cc:240] Execute(0x240bab70)
I executable.cc:197] ExecuteWithReferenceInterpreter(0x240bab70)
I buffer.cc:303] GetAttributeFromBuffer
I buffer.cc:229] IsDeleted(0x240bb050)
I buffer.cc:311] GetAttributeFromBuffer(0x240bb050) -> dense<2> : tensor<i32>
I executable.cc:205] EvalModule:
module @jit attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
  func.func public @main(%arg0: tensor<i32> {mhlo.layout_mode = "default"}) -> (tensor<i32> {jax.result_info = "", mhlo.layout_mode = "default"}) {
    // ...
    return %3 : tensor<i32>
  }
}
I executable.cc:206] Inputs: [dense<2> : tensor<i32>]
I executable.cc:213] Results: [dense<2> : tensor<i32>]
I device.cc:153] default_memory_space(0x23bac4e0)
I buffer.cc:291] CreateMlirBufferFromAttribute
I buffer.cc:116] CreateFromAttribute
I buffer.cc:64] MlirPjrtBuffer(0x22cea630)
I buffer.cc:122] CreateFromAttribute(dense<2> : tensor<i32>) -> 0x22cea630

//////////////////////////////////
// RUN: `print(y)`
//////////////////////////////////

I buffer.cc:263] GetReadyFuture(0x22cea630)
I buffer.cc:264] GetReadyFuture(0x22cea630)
I buffer.cc:154] memory_space(0x22cea630)
I buffer.cc:154] memory_space(0x22cea630)
I buffer.cc:158] device(0x22cea630)
I buffer.cc:158] device(0x22cea630)
I buffer.cc:154] memory_space(0x22cea630)
I buffer.cc:154] memory_space(0x22cea630)
I buffer.cc:229] IsDeleted(0x22cea630)
I buffer.cc:129] on_device_shape(0x22cea630)
I buffer.cc:129] on_device_shape(0x22cea630)
I buffer.cc:129] on_device_shape(0x22cea630)
I buffer.cc:158] device(0x22cea630)
I buffer.cc:154] memory_space(0x22cea630)
I buffer.cc:154] memory_space(0x22cea630)
I client_cpp_pjrt.cc:71] process_index(0x23bac400)
I buffer.cc:229] IsDeleted(0x22cea630)
I buffer.cc:129] on_device_shape(0x22cea630)
I buffer.cc:129] on_device_shape(0x22cea630)
I buffer.cc:269] IsOnCpu(0x22cea630) # Returns true, allows external references.
I buffer.cc:129] on_device_shape(0x22cea630)
I buffer.cc:129] on_device_shape(0x22cea630)
I buffer.cc:129] on_device_shape(0x22cea630)
I buffer.cc:129] on_device_shape(0x22cea630)
I buffer.cc:129] on_device_shape(0x22cea630)
I buffer.cc:168] AcquireExternalReference(0x22cea630)
I buffer.cc:73] MlirClonedExternalReference(0x2404d560)
I buffer.cc:303] GetAttributeFromBuffer
I buffer.cc:229] IsDeleted(0x22cea630)
I buffer.cc:311] GetAttributeFromBuffer(0x22cea630) -> dense<2> : tensor<i32>
I buffer.cc:291] CreateMlirBufferFromAttribute
I buffer.cc:116] CreateFromAttribute
I buffer.cc:64] MlirPjrtBuffer(0x240bb050)
I buffer.cc:122] CreateFromAttribute(dense<2> : tensor<i32>) -> 0x240bb050
I buffer.cc:168] AcquireExternalReference(0x22cea630)
I buffer.cc:73] MlirClonedExternalReference(0x240b6010)
I buffer.cc:303] GetAttributeFromBuffer
I buffer.cc:229] IsDeleted(0x22cea630)
I buffer.cc:311] GetAttributeFromBuffer(0x22cea630) -> dense<2> : tensor<i32>
I buffer.cc:291] CreateMlirBufferFromAttribute
I buffer.cc:116] CreateFromAttribute
I buffer.cc:64] MlirPjrtBuffer(0x23b2db60)
I buffer.cc:122] CreateFromAttribute(dense<2> : tensor<i32>) -> 0x23b2db60
I buffer.cc:263] GetReadyFuture(0x22cea630)
I buffer.cc:264] GetReadyFuture(0x22cea630)