ภาพรวมของ PJRT C++ Device API

ฉากหลัง

PJRT คือ Device API ที่สอดคล้องกันซึ่งเราต้องการเพิ่มลงในระบบนิเวศ ML วิสัยทัศน์ระยะยาว คือ

  1. เฟรมเวิร์ก (JAX, TF ฯลฯ) จะเรียกใช้ PJRT ซึ่งมีการติดตั้งใช้งานเฉพาะอุปกรณ์ ที่เฟรมเวิร์กมองไม่เห็น
  2. อุปกรณ์แต่ละเครื่องมุ่งเน้นที่การใช้ API ของ PJRT และอาจไม่สามารถมองเห็นเฟรมเวิร์กได้

PJRT มีทั้ง C API และ C++ API การเสียบปลั๊กที่เลเยอร์ใดก็ได้ไม่เป็นไร C++ API ใช้คลาสเพื่อดึงแนวคิดบางอย่างออกมา แต่ก็มีความเชื่อมโยงที่แน่นแฟ้นกับ ประเภทข้อมูล XLA ด้วย หน้านี้จะเน้นที่ C++ API

คอมโพเนนต์ PJRT

คอมโพเนนต์ PJRT

PjRtClient

ดูข้อมูลอ้างอิงทั้งหมดได้ที่ pjrt_client.h > PjRtClient

ไคลเอ็นต์จัดการการสื่อสารทั้งหมดระหว่างอุปกรณ์กับเฟรมเวิร์ก และ แคปซูลสถานะทั้งหมดที่ใช้ในการสื่อสาร โดยมีชุด API ทั่วไป สำหรับการโต้ตอบกับปลั๊กอิน PJRT และเป็นเจ้าของอุปกรณ์และพื้นที่หน่วยความจำ สำหรับปลั๊กอินที่กำหนด

PjRtDevice

ข้อมูลอ้างอิงทั้งหมดที่ pjrt_client.h > PjRtDevice และ pjrt_device_description.h

คลาสอุปกรณ์ใช้เพื่ออธิบายอุปกรณ์เครื่องเดียว อุปกรณ์มีคำอธิบายอุปกรณ์เพื่อช่วยระบุประเภทของอุปกรณ์ (แฮชที่ไม่ซ้ำกันเพื่อระบุ GPU/CPU/xPU) และตำแหน่งภายในตารางกริดของอุปกรณ์ทั้งในเครื่องและทั่วโลก

นอกจากนี้ อุปกรณ์ยังทราบพื้นที่เก็บข้อมูลที่เชื่อมโยงและไคลเอ็นต์ที่เป็นเจ้าของด้วย

อุปกรณ์ไม่จำเป็นต้องรู้บัฟเฟอร์ของข้อมูลจริงที่เชื่อมโยงกับอุปกรณ์ แต่สามารถหาได้โดยดูผ่านพื้นที่หน่วยความจำที่เชื่อมโยง

PjRtMemorySpace

ดูข้อมูลอ้างอิงทั้งหมดได้ที่ pjrt_client.h > PjRtMemorySpace

คุณใช้พื้นที่ความทรงจำเพื่ออธิบายสถานที่ของความทรงจำได้ โดยคุณจะ เลิกปักหมุดและย้ายไปไว้ที่ใดก็ได้ แต่ต้องเข้าถึงได้จากอุปกรณ์ หรือจะ ปักหมุดและต้องอยู่ในอุปกรณ์ที่เฉพาะเจาะจงก็ได้

พื้นที่หน่วยความจำจะทราบบัฟเฟอร์ข้อมูลที่เชื่อมโยงและอุปกรณ์ (พหูพจน์) ที่เชื่อมโยงกับพื้นที่หน่วยความจำ รวมถึงไคลเอ็นต์ที่เป็นส่วนหนึ่งของพื้นที่หน่วยความจำ

PjRtBuffer

ดูข้อมูลอ้างอิงทั้งหมดได้ที่ pjrt_client.h > PjRtBuffer

บัฟเฟอร์จะเก็บข้อมูลในอุปกรณ์ในรูปแบบที่ทำงานได้ง่าย ภายในปลั๊กอิน เช่น แอตทริบิวต์ขององค์ประกอบ MLIR หรือรูปแบบเทนเซอร์ที่เป็นกรรมสิทธิ์ เฟรมเวิร์กอาจพยายามส่งข้อมูลไปยังอุปกรณ์ในรูปแบบของ xla::Literal กล่าวคือ สำหรับอาร์กิวเมนต์อินพุตไปยังโมดูล ซึ่งต้องโคลน (หรือยืม) ไปยัง หน่วยความจำของอุปกรณ์ เมื่อไม่จำเป็นต้องใช้บัฟเฟอร์อีกต่อไป เฟรมเวิร์กจะเรียกใช้Deleteเมธอด เพื่อล้างข้อมูล

บัฟเฟอร์จะทราบพื้นที่หน่วยความจำที่เป็นส่วนหนึ่งของบัฟเฟอร์ และสามารถระบุได้ว่า อุปกรณ์ใดเข้าถึงบัฟเฟอร์ได้ แต่บัฟเฟอร์ไม่จำเป็นต้องรู้จักอุปกรณ์ของตนเอง

สำหรับการสื่อสารกับเฟรมเวิร์ก บัฟเฟอร์จะทราบวิธีแปลงเป็นและจากประเภท xla::Literal ดังนี้

// Literal to Buffer
absl::StatusOr<std::unique_ptr<PjRtBuffer>> BufferFromHostBuffer(...) {...}

// Buffer to Literal
xla::Future<> ToLiteral(xla::MutableLiteralBase* literal) override {...}

API สำหรับการสร้างบัฟเฟอร์มีความหมายของบัฟเฟอร์ ซึ่งช่วยกำหนดว่าสามารถแชร์ คัดลอก หรือเปลี่ยนแปลงข้อมูลตามตัวอักษรจากบัฟเฟอร์โฮสต์ได้หรือไม่

สุดท้ายนี้ บัฟเฟอร์อาจต้องมีอายุการใช้งานนานกว่าขอบเขตการดำเนินการ หากมีการ กำหนดให้กับตัวแปรในเลเยอร์เฟรมเวิร์ก x = jit(foo)(10) ในกรณีเหล่านี้ บัฟเฟอร์จะช่วยสร้างการอ้างอิงภายนอกซึ่งให้พอยน์เตอร์ที่เป็นเจ้าของชั่วคราว ไปยังข้อมูลที่บัฟเฟอร์เก็บไว้ พร้อมกับข้อมูลเมตา (dtype / ขนาดมิติ) สำหรับการตีความข้อมูลพื้นฐาน

PjRtCompiler

ดูข้อมูลอ้างอิงทั้งหมดได้ที่ pjrt_compiler.h > PjRtCompiler

คลาส PjRtCompiler มีรายละเอียดการใช้งานที่เป็นประโยชน์สำหรับแบ็กเอนด์ XLA แต่ไม่จำเป็นสำหรับปลั๊กอินในการใช้งาน ในทางทฤษฎีแล้ว ความรับผิดชอบของ PjRtCompiler หรือเมธอด PjRtClient::Compile คือการรับโมดูลอินพุตและส่งคืน PjRtLoadedExecutable

PjRtExecutable / PjRtLoadedExecutable

ดูข้อมูลอ้างอิงทั้งหมดได้ที่ pjrt_executable.h > PjRtExecutable และ pjrt_client.h > PjRtLoadedExecutable

PjRtExecutableรู้วิธีใช้ชิ้นงานที่คอมไพล์แล้วและตัวเลือกการดำเนินการ และทำการซีเรียลไลซ์/ดีซีเรียลไลซ์เพื่อให้จัดเก็บและโหลดไฟล์ที่เรียกใช้งานได้ตาม ต้องการ

PjRtLoadedExecutable คือไฟล์ปฏิบัติการที่คอมไพล์แล้วในหน่วยความจำซึ่งพร้อม รับอาร์กิวเมนต์อินพุตเพื่อดำเนินการ โดยเป็นคลาสย่อยของ PjRtExecutable

โดยจะมีการเชื่อมต่อกับไฟล์ที่เรียกใช้งานได้ผ่านหนึ่งในเมธอด Execute ของไคลเอ็นต์

// Execute on addressable devices
absl::StatusOr<std::vector<std::vector<std::unique_ptr<PjRtBuffer>>>>
Execute(absl::Span<const std::vector<PjRtBuffer*>> argument_handles, ...) {...}

// Execute assigned replica/partition on the specified device
absl::StatusOr<std::vector<std::unique_ptr<PjRtBuffer>>>
ExecuteSharded(absl::Span<PjRtBuffer* const> argument_handles,
              PjRtDevice* device, ...) {...}

// Execute on specified device, single replica / partition
absl::StatusOr<std::vector<std::unique_ptr<PjRtBuffer>>>
ExecutePortable(absl::Span<PjRtBuffer* const> argument_handles,
                PjRtDevice* device, ...) {...}

ก่อนที่จะเรียกใช้ Execute เฟรมเวิร์กจะโอนข้อมูลที่จำเป็นทั้งหมดไปยัง PjRtBuffers ที่ไคลเอ็นต์ที่เรียกใช้เป็นเจ้าของ แต่จะส่งคืนให้เฟรมเวิร์ก อ้างอิง จากนั้นจะมีการระบุบัฟเฟอร์เหล่านี้เป็นอาร์กิวเมนต์ให้กับเมธอด Execute

แนวคิด PJRT

การคำนวณแบบ Futures และ Async

หากมีการใช้ส่วนใดส่วนหนึ่งของปลั๊กอินแบบไม่พร้อมกัน ปลั๊กอินต้องใช้ฟิวเจอร์อย่างถูกต้อง

ลองพิจารณาโปรแกรมต่อไปนี้

@jax.jit
def foo(x): return x + 1

x = foo(1)
# [...] other logic not using `x`
print(x + 1)

ปลั๊กอินแบบไม่พร้อมกันจะสามารถจัดคิวการคำนวณ x และส่งคืนบัฟเฟอร์ที่ยังไม่พร้อมอ่านได้ทันที แต่การดำเนินการจะเติมข้อมูลในบัฟเฟอร์ การดำเนินการจะยังคงจัดคิวการคำนวณที่จำเป็นหลังจาก x ซึ่งไม่จำเป็นต้องใช้ x รวมถึงการดำเนินการในอุปกรณ์ PJRT อื่นๆ เมื่อต้องการค่าของ x การดำเนินการจะถูกบล็อกจนกว่าบัฟเฟอร์จะประกาศว่าพร้อมผ่านฟิวเจอร์ที่ส่งคืนโดย GetReadyFuture

Future มีประโยชน์ในการพิจารณาว่าออบเจ็กต์จะพร้อมใช้งานเมื่อใด ซึ่งรวมถึงอุปกรณ์และบัฟเฟอร์

แนวคิดขั้นสูง

การขยายขอบเขตไปไกลกว่าการใช้ API พื้นฐานจะช่วยขยายฟีเจอร์ของ JAX ที่ปลั๊กอินใช้ได้ ฟีเจอร์เหล่านี้เป็นฟีเจอร์ที่ต้องเลือกใช้ทั้งหมดในแง่ที่ว่าเวิร์กโฟลว์ JIT และการดำเนินการทั่วไปจะทำงานได้โดยไม่มีฟีเจอร์เหล่านี้ แต่สำหรับไปป์ไลน์คุณภาพระดับการใช้งานจริง คุณควรพิจารณาถึงระดับการรองรับฟีเจอร์ใดๆ ต่อไปนี้ที่ API ของ PJRT รองรับ

  • พื้นที่ความทรงจำ
  • การจัดวางที่กำหนดเอง
  • การดำเนินการสื่อสาร เช่น ส่ง/รับ
  • การออฟโหลดโฮสต์
  • การแบ่งส่วน

การสื่อสารระหว่างเฟรมเวิร์ก PJRT กับอุปกรณ์โดยทั่วไป

ตัวอย่างบันทึก

ต่อไปนี้คือบันทึกของเมธอดที่เรียกใช้เพื่อโหลดปลั๊กอิน PJRT และ เรียกใช้ y = jax.jit(lambda x: jnp.power(x, jnp.int32(2)))(1) ในกรณีนี้ เราจะบันทึกการโต้ตอบของ JAX กับปลั๊กอิน PJRT อ้างอิงของ StableHLO

ตัวอย่างบันทึก

//////////////////////////////////
// Load the plugin
//////////////////////////////////

I client_cpp_pjrt.cc:55] StablehloReferencePjrtClient(0x23bac400)
I device.cc:53] StablehloReferenceDeviceDescription(0x23bac4f8)
I device.cc:104] StablehloReferenceDevice(0x23bac4e0)
I device.cc:123] client(0x23bac4e0)
I device.cc:123] client(0x23bac4e0)
I client_cpp_pjrt.cc:71] process_index(0x23bac400)
I client_cpp_pjrt.cc:67] platform_name(0x23bac400)
I device.cc:143] AttachDefaultMemorySpace(0x23bac4e0)
I client_cpp_pjrt.cc:67] platform_name(0x23bac400)
I client_cpp_pjrt.cc:86] devices(0x23bac400)
I client_cpp_pjrt.cc:81] addressable_device_count(0x23bac400)
I device.cc:168] description(0x23bac4e0)
I device.cc:168] description(0x23bac4e0)
I device.cc:86] Attributes(0x23bac4f8)
I device.cc:128] IsAddressable(0x23bac4e0)
I device.cc:168] description(0x23bac4e0)
I device.cc:61] process_index(0x23bac4f8)
I device.cc:123] client(0x23bac4e0)
I client_cpp_pjrt.cc:71] process_index(0x23bac400)
I client_cpp_pjrt.cc:81] addressable_device_count(0x23bac400)
I client_cpp_pjrt.cc:95] memory_spaces(0x23bac400)
I device.cc:128] IsAddressable(0x23bac4e0)
I device.cc:168] description(0x23bac4e0)
I device.cc:61] process_index(0x23bac4f8)
I device.cc:123] client(0x23bac4e0)
I client_cpp_pjrt.cc:71] process_index(0x23bac400)
I device.cc:148] memory_spaces(0x23bac4e0)
Creating PJRT Client from client
I client_cpp_pjrt.cc:108] platform_version(0x23bac400)
I client_cpp_pjrt.cc:67] platform_name(0x23bac400)
I device.cc:57] id(0x23bac4f8)
I device.cc:70] device_kind(0x23bac4f8)
I device.cc:70] device_kind(0x23bac4f8)
I device.cc:80] ToString(0x23bac4f8)
I device.cc:80] ToString(0x23bac4f8)
I device.cc:75] DebugString(0x23bac4f8)
I device.cc:75] DebugString(0x23bac4f8)
I device.cc:61] process_index(0x23bac4f8)
I device.cc:128] IsAddressable(0x23bac4e0)
I device.cc:168] description(0x23bac4e0)
I device.cc:61] process_index(0x23bac4f8)
I device.cc:123] client(0x23bac4e0)
I client_cpp_pjrt.cc:71] process_index(0x23bac400)
I device.cc:153] default_memory_space(0x23bac4e0)
I client_cpp_pjrt.cc:71] process_index(0x23bac400)

//////////////////////////////////
// RUN: `y = jax.jit(lambda x: jnp.power(x, jnp.int32(2)))(1)`
//////////////////////////////////

I executable.cc:309] num_partitions(0x240bab70)
I executable.cc:305] num_replicas(0x240bab70)
I executable.cc:309] num_partitions(0x240bab70)
I client_cpp_pjrt.cc:233] BufferFromHostBuffer(0x23bac400)
I buffer.cc:285] CreateMlirBufferFromLiteral
I buffer.cc:98] CreateFromLiteral
I buffer.cc:99] CreateFromLiteral: s32[] 2
I buffer.cc:64] MlirPjrtBuffer(0x240bb050)
I buffer.cc:102] CreateFromLiteral -> 0x240bb050
I buffer.cc:158] device(0x240bb050)
I buffer.cc:154] memory_space(0x240bb050)
I buffer.cc:154] memory_space(0x240bb050)
I executable.cc:328] GetHloModules(0x240bab70)
I executable.cc:240] Execute(0x240bab70)
I executable.cc:197] ExecuteWithReferenceInterpreter(0x240bab70)
I buffer.cc:303] GetAttributeFromBuffer
I buffer.cc:229] IsDeleted(0x240bb050)
I buffer.cc:311] GetAttributeFromBuffer(0x240bb050) -> dense<2> : tensor<i32>
I executable.cc:205] EvalModule:
module @jit attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
  func.func public @main(%arg0: tensor<i32> {mhlo.layout_mode = "default"}) -> (tensor<i32> {jax.result_info = "", mhlo.layout_mode = "default"}) {
    // ...
    return %3 : tensor<i32>
  }
}
I executable.cc:206] Inputs: [dense<2> : tensor<i32>]
I executable.cc:213] Results: [dense<2> : tensor<i32>]
I device.cc:153] default_memory_space(0x23bac4e0)
I buffer.cc:291] CreateMlirBufferFromAttribute
I buffer.cc:116] CreateFromAttribute
I buffer.cc:64] MlirPjrtBuffer(0x22cea630)
I buffer.cc:122] CreateFromAttribute(dense<2> : tensor<i32>) -> 0x22cea630

//////////////////////////////////
// RUN: `print(y)`
//////////////////////////////////

I buffer.cc:263] GetReadyFuture(0x22cea630)
I buffer.cc:264] GetReadyFuture(0x22cea630)
I buffer.cc:154] memory_space(0x22cea630)
I buffer.cc:154] memory_space(0x22cea630)
I buffer.cc:158] device(0x22cea630)
I buffer.cc:158] device(0x22cea630)
I buffer.cc:154] memory_space(0x22cea630)
I buffer.cc:154] memory_space(0x22cea630)
I buffer.cc:229] IsDeleted(0x22cea630)
I buffer.cc:129] on_device_shape(0x22cea630)
I buffer.cc:129] on_device_shape(0x22cea630)
I buffer.cc:129] on_device_shape(0x22cea630)
I buffer.cc:158] device(0x22cea630)
I buffer.cc:154] memory_space(0x22cea630)
I buffer.cc:154] memory_space(0x22cea630)
I client_cpp_pjrt.cc:71] process_index(0x23bac400)
I buffer.cc:229] IsDeleted(0x22cea630)
I buffer.cc:129] on_device_shape(0x22cea630)
I buffer.cc:129] on_device_shape(0x22cea630)
I buffer.cc:269] IsOnCpu(0x22cea630) # Returns true, allows external references.
I buffer.cc:129] on_device_shape(0x22cea630)
I buffer.cc:129] on_device_shape(0x22cea630)
I buffer.cc:129] on_device_shape(0x22cea630)
I buffer.cc:129] on_device_shape(0x22cea630)
I buffer.cc:129] on_device_shape(0x22cea630)
I buffer.cc:168] AcquireExternalReference(0x22cea630)
I buffer.cc:73] MlirClonedExternalReference(0x2404d560)
I buffer.cc:303] GetAttributeFromBuffer
I buffer.cc:229] IsDeleted(0x22cea630)
I buffer.cc:311] GetAttributeFromBuffer(0x22cea630) -> dense<2> : tensor<i32>
I buffer.cc:291] CreateMlirBufferFromAttribute
I buffer.cc:116] CreateFromAttribute
I buffer.cc:64] MlirPjrtBuffer(0x240bb050)
I buffer.cc:122] CreateFromAttribute(dense<2> : tensor<i32>) -> 0x240bb050
I buffer.cc:168] AcquireExternalReference(0x22cea630)
I buffer.cc:73] MlirClonedExternalReference(0x240b6010)
I buffer.cc:303] GetAttributeFromBuffer
I buffer.cc:229] IsDeleted(0x22cea630)
I buffer.cc:311] GetAttributeFromBuffer(0x22cea630) -> dense<2> : tensor<i32>
I buffer.cc:291] CreateMlirBufferFromAttribute
I buffer.cc:116] CreateFromAttribute
I buffer.cc:64] MlirPjrtBuffer(0x23b2db60)
I buffer.cc:122] CreateFromAttribute(dense<2> : tensor<i32>) -> 0x23b2db60
I buffer.cc:263] GetReadyFuture(0x22cea630)
I buffer.cc:264] GetReadyFuture(0x22cea630)