배경
PJRT는 ML 생태계에 추가하고자 하는 균일한 기기 API입니다. 장기 비전은 다음과 같습니다.
- 프레임워크 (JAX, TF 등)는 프레임워크에 불투명한 기기별 구현이 있는 PJRT를 호출합니다.
- 각 기기는 PJRT API 구현에 중점을 두며 프레임워크에 불투명할 수 있습니다.
PJRT는 C API와 C++ API를 모두 제공합니다. 어느 레이어에 연결해도 괜찮습니다. C++ API는 클래스를 사용하여 일부 개념을 추상화하지만 XLA 데이터 유형과도 더 강력한 관계를 갖습니다. 이 페이지에서는 C++ API를 중점적으로 다룹니다.
PJRT 구성요소
PjRtClient
pjrt_client.h > PjRtClient에서 전체 참조를 확인하세요.
클라이언트는 기기와 프레임워크 간의 모든 통신을 관리하고 통신에 사용되는 모든 상태를 캡슐화합니다. PJRT 플러그인과 상호작용하기 위한 일반 API 세트가 있으며 특정 플러그인의 기기와 메모리 공간을 소유합니다.
PjRtDevice
pjrt_client.h > PjRtDevice 및 pjrt_device_description.h의 전체 참조
기기 클래스는 단일 기기를 설명하는 데 사용됩니다. 기기에는 기기의 종류 (GPU/CPU/xPU를 식별하는 고유 해시)와 기기 그리드 내 위치를 로컬 및 전역으로 식별하는 데 도움이 되는 기기 설명이 있습니다.
기기는 연결된 메모리 공간과 소유한 클라이언트도 알고 있습니다.
기기는 연결된 실제 데이터의 버퍼를 알 필요는 없지만 연결된 메모리 공간을 살펴봄으로써 이를 파악할 수 있습니다.
PjRtMemorySpace
pjrt_client.h > PjRtMemorySpace에서 전체 참조를 확인하세요.
메모리 공간은 메모리의 위치를 설명하는 데 사용할 수 있습니다. 이러한 앱은 고정 해제되어 기기에서 액세스할 수 있는 한 어디에나 있을 수 있으며, 고정되어 특정 기기에 있어야 할 수도 있습니다.
메모리 공간은 연결된 데이터 버퍼와 메모리 공간이 연결된 기기 (복수) 및 메모리 공간이 속한 클라이언트를 알고 있습니다.
PjRtBuffer
pjrt_client.h > PjRtBuffer에서 전체 참조를 확인하세요.
버퍼는 플러그인 내에서 쉽게 사용할 수 있는 형식(예: MLIR 요소 속성 또는 독점 텐서 형식)으로 기기에 데이터를 보유합니다.
프레임워크는 xla::Literal 형식으로 기기에 데이터를 전송하려고 할 수 있습니다. 즉, 모듈의 입력 인수에 대해 클론 (또는 차용)되어야 기기의 메모리에 저장됩니다. 버퍼가 더 이상 필요하지 않으면 프레임워크에서 Delete 메서드를 호출하여 정리합니다.
버퍼는 자신이 속한 메모리 공간을 알고 있으며, 이를 통해 버퍼에 액세스할 수 있는 기기를 추이적으로 파악할 수 있지만 버퍼가 기기를 반드시 아는 것은 아닙니다.
프레임워크와 통신하기 위해 버퍼는 xla::Literal 유형으로 변환하는 방법을 알고 있습니다.
// Literal to Buffer
absl::StatusOr<std::unique_ptr<PjRtBuffer>> BufferFromHostBuffer(...) {...}
// Buffer to Literal
xla::Future<> ToLiteral(xla::MutableLiteralBase* literal) override {...}
버퍼 생성 API에는 호스트 버퍼의 리터럴 데이터를 공유하거나 복사하거나 변경할 수 있는지 여부를 지정하는 데 도움이 되는 버퍼 시맨틱이 있습니다.
마지막으로 프레임워크 레이어 x = jit(foo)(10)의 변수에 할당된 경우 버퍼가 실행 범위보다 오래 지속되어야 할 수 있습니다. 이러한 경우 버퍼를 사용하면 버퍼가 보유한 데이터에 대한 임시 소유 포인터와 기본 데이터를 해석하기 위한 메타데이터 (dtype / dim 크기)를 제공하는 외부 참조를 빌드할 수 있습니다.
PjRtCompiler
pjrt_compiler.h > PjRtCompiler에서 전체 참조를 확인하세요.
PjRtCompiler 클래스는 XLA 백엔드에 유용한 구현 세부정보를 제공하지만 플러그인이 구현할 필요는 없습니다. 이론적으로 PjRtCompiler 또는 PjRtClient::Compile 메서드의 책임은 입력 모듈을 가져와 PjRtLoadedExecutable를 반환하는 것입니다.
PjRtExecutable / PjRtLoadedExecutable
pjrt_executable.h > PjRtExecutable 및 pjrt_client.h > PjRtLoadedExecutable에서 전체 참조를 확인하세요.
PjRtExecutable는 컴파일된 아티팩트와 실행 옵션을 가져와 실행 파일을 저장하고 필요에 따라 로드할 수 있도록 직렬화/역직렬화하는 방법을 알고 있습니다.
PjRtLoadedExecutable는 실행할 입력 인수를 사용할 준비가 된 메모리 내 컴파일된 실행 파일이며 PjRtExecutable의 서브클래스입니다.
실행 파일은 클라이언트의 Execute 메서드 중 하나를 통해 인터페이스됩니다.
// Execute on addressable devices
absl::StatusOr<std::vector<std::vector<std::unique_ptr<PjRtBuffer>>>>
Execute(absl::Span<const std::vector<PjRtBuffer*>> argument_handles, ...) {...}
// Execute assigned replica/partition on the specified device
absl::StatusOr<std::vector<std::unique_ptr<PjRtBuffer>>>
ExecuteSharded(absl::Span<PjRtBuffer* const> argument_handles,
PjRtDevice* device, ...) {...}
// Execute on specified device, single replica / partition
absl::StatusOr<std::vector<std::unique_ptr<PjRtBuffer>>>
ExecutePortable(absl::Span<PjRtBuffer* const> argument_handles,
PjRtDevice* device, ...) {...}
프레임워크는 Execute를 호출하기 전에 실행 클라이언트가 소유하지만 프레임워크가 참조할 수 있도록 반환된 PjRtBuffers에 필요한 모든 데이터를 전송합니다. 이러한 버퍼는 Execute 메서드에 인수로 제공됩니다.
PJRT 개념
Future 및 비동기 계산
플러그인의 일부가 비동기식으로 구현된 경우 퓨처를 적절하게 구현해야 합니다.
다음 프로그램을 고려해 보세요.
@jax.jit
def foo(x): return x + 1
x = foo(1)
# [...] other logic not using `x`
print(x + 1)
비동기 플러그인은 계산 x를 큐에 추가하고 아직 읽을 준비가 되지 않은 버퍼를 즉시 반환할 수 있지만 실행 시 버퍼가 채워집니다. 실행은 x 이후에 다른 PJRT 기기에서의 실행을 비롯하여 x가 필요하지 않은 필요한 계산을 계속 대기열에 추가할 수 있습니다. x 값이 필요하면 버퍼가 GetReadyFuture에서 반환된 future를 통해 준비되었다고 선언할 때까지 실행이 차단됩니다.
퓨처는 기기 및 버퍼를 비롯한 객체를 사용할 수 있는 시점을 확인하는 데 유용합니다.
고급 개념
기본 API 구현을 넘어 확장하면 플러그인에서 사용할 수 있는 JAX 기능이 확장됩니다. 이러한 기능은 모두 선택사항입니다. 일반적인 JIT 및 실행 워크플로는 이러한 기능 없이도 작동하지만 프로덕션 품질 파이프라인의 경우 PJRT API에서 지원하는 이러한 기능의 지원 정도를 고려해야 합니다.
- 메모리 공간
- 맞춤 레이아웃
- 전송/수신과 같은 통신 작업
- 호스트 오프로드
- 샤딩
일반적인 PJRT 프레임워크-기기 통신
로그 예
다음은 PJRT 플러그인을 로드하고 y = jax.jit(lambda x: jnp.power(x, jnp.int32(2)))(1)를 실행하기 위해 호출된 메서드의 로그입니다. 이 경우 StableHLO 참조 PJRT 플러그인과 상호작용하는 JAX를 로깅합니다.
로그 예시
////////////////////////////////// // Load the plugin ////////////////////////////////// I client_cpp_pjrt.cc:55] StablehloReferencePjrtClient(0x23bac400) I device.cc:53] StablehloReferenceDeviceDescription(0x23bac4f8) I device.cc:104] StablehloReferenceDevice(0x23bac4e0) I device.cc:123] client(0x23bac4e0) I device.cc:123] client(0x23bac4e0) I client_cpp_pjrt.cc:71] process_index(0x23bac400) I client_cpp_pjrt.cc:67] platform_name(0x23bac400) I device.cc:143] AttachDefaultMemorySpace(0x23bac4e0) I client_cpp_pjrt.cc:67] platform_name(0x23bac400) I client_cpp_pjrt.cc:86] devices(0x23bac400) I client_cpp_pjrt.cc:81] addressable_device_count(0x23bac400) I device.cc:168] description(0x23bac4e0) I device.cc:168] description(0x23bac4e0) I device.cc:86] Attributes(0x23bac4f8) I device.cc:128] IsAddressable(0x23bac4e0) I device.cc:168] description(0x23bac4e0) I device.cc:61] process_index(0x23bac4f8) I device.cc:123] client(0x23bac4e0) I client_cpp_pjrt.cc:71] process_index(0x23bac400) I client_cpp_pjrt.cc:81] addressable_device_count(0x23bac400) I client_cpp_pjrt.cc:95] memory_spaces(0x23bac400) I device.cc:128] IsAddressable(0x23bac4e0) I device.cc:168] description(0x23bac4e0) I device.cc:61] process_index(0x23bac4f8) I device.cc:123] client(0x23bac4e0) I client_cpp_pjrt.cc:71] process_index(0x23bac400) I device.cc:148] memory_spaces(0x23bac4e0) Creating PJRT Client from client I client_cpp_pjrt.cc:108] platform_version(0x23bac400) I client_cpp_pjrt.cc:67] platform_name(0x23bac400) I device.cc:57] id(0x23bac4f8) I device.cc:70] device_kind(0x23bac4f8) I device.cc:70] device_kind(0x23bac4f8) I device.cc:80] ToString(0x23bac4f8) I device.cc:80] ToString(0x23bac4f8) I device.cc:75] DebugString(0x23bac4f8) I device.cc:75] DebugString(0x23bac4f8) I device.cc:61] process_index(0x23bac4f8) I device.cc:128] IsAddressable(0x23bac4e0) I device.cc:168] description(0x23bac4e0) I device.cc:61] process_index(0x23bac4f8) I device.cc:123] client(0x23bac4e0) I client_cpp_pjrt.cc:71] process_index(0x23bac400) I device.cc:153] default_memory_space(0x23bac4e0) I client_cpp_pjrt.cc:71] process_index(0x23bac400) ////////////////////////////////// // RUN: `y = jax.jit(lambda x: jnp.power(x, jnp.int32(2)))(1)` ////////////////////////////////// I executable.cc:309] num_partitions(0x240bab70) I executable.cc:305] num_replicas(0x240bab70) I executable.cc:309] num_partitions(0x240bab70) I client_cpp_pjrt.cc:233] BufferFromHostBuffer(0x23bac400) I buffer.cc:285] CreateMlirBufferFromLiteral I buffer.cc:98] CreateFromLiteral I buffer.cc:99] CreateFromLiteral: s32[] 2 I buffer.cc:64] MlirPjrtBuffer(0x240bb050) I buffer.cc:102] CreateFromLiteral -> 0x240bb050 I buffer.cc:158] device(0x240bb050) I buffer.cc:154] memory_space(0x240bb050) I buffer.cc:154] memory_space(0x240bb050) I executable.cc:328] GetHloModules(0x240bab70) I executable.cc:240] Execute(0x240bab70) I executable.cc:197] ExecuteWithReferenceInterpreter(0x240bab70) I buffer.cc:303] GetAttributeFromBuffer I buffer.cc:229] IsDeleted(0x240bb050) I buffer.cc:311] GetAttributeFromBuffer(0x240bb050) -> dense<2> : tensor<i32> I executable.cc:205] EvalModule: module @jit attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} { func.func public @main(%arg0: tensor<i32> {mhlo.layout_mode = "default"}) -> (tensor<i32> {jax.result_info = "", mhlo.layout_mode = "default"}) { // ... return %3 : tensor<i32> } } I executable.cc:206] Inputs: [dense<2> : tensor<i32>] I executable.cc:213] Results: [dense<2> : tensor<i32>] I device.cc:153] default_memory_space(0x23bac4e0) I buffer.cc:291] CreateMlirBufferFromAttribute I buffer.cc:116] CreateFromAttribute I buffer.cc:64] MlirPjrtBuffer(0x22cea630) I buffer.cc:122] CreateFromAttribute(dense<2> : tensor<i32>) -> 0x22cea630 ////////////////////////////////// // RUN: `print(y)` ////////////////////////////////// I buffer.cc:263] GetReadyFuture(0x22cea630) I buffer.cc:264] GetReadyFuture(0x22cea630) I buffer.cc:154] memory_space(0x22cea630) I buffer.cc:154] memory_space(0x22cea630) I buffer.cc:158] device(0x22cea630) I buffer.cc:158] device(0x22cea630) I buffer.cc:154] memory_space(0x22cea630) I buffer.cc:154] memory_space(0x22cea630) I buffer.cc:229] IsDeleted(0x22cea630) I buffer.cc:129] on_device_shape(0x22cea630) I buffer.cc:129] on_device_shape(0x22cea630) I buffer.cc:129] on_device_shape(0x22cea630) I buffer.cc:158] device(0x22cea630) I buffer.cc:154] memory_space(0x22cea630) I buffer.cc:154] memory_space(0x22cea630) I client_cpp_pjrt.cc:71] process_index(0x23bac400) I buffer.cc:229] IsDeleted(0x22cea630) I buffer.cc:129] on_device_shape(0x22cea630) I buffer.cc:129] on_device_shape(0x22cea630) I buffer.cc:269] IsOnCpu(0x22cea630) # Returns true, allows external references. I buffer.cc:129] on_device_shape(0x22cea630) I buffer.cc:129] on_device_shape(0x22cea630) I buffer.cc:129] on_device_shape(0x22cea630) I buffer.cc:129] on_device_shape(0x22cea630) I buffer.cc:129] on_device_shape(0x22cea630) I buffer.cc:168] AcquireExternalReference(0x22cea630) I buffer.cc:73] MlirClonedExternalReference(0x2404d560) I buffer.cc:303] GetAttributeFromBuffer I buffer.cc:229] IsDeleted(0x22cea630) I buffer.cc:311] GetAttributeFromBuffer(0x22cea630) -> dense<2> : tensor<i32> I buffer.cc:291] CreateMlirBufferFromAttribute I buffer.cc:116] CreateFromAttribute I buffer.cc:64] MlirPjrtBuffer(0x240bb050) I buffer.cc:122] CreateFromAttribute(dense<2> : tensor<i32>) -> 0x240bb050 I buffer.cc:168] AcquireExternalReference(0x22cea630) I buffer.cc:73] MlirClonedExternalReference(0x240b6010) I buffer.cc:303] GetAttributeFromBuffer I buffer.cc:229] IsDeleted(0x22cea630) I buffer.cc:311] GetAttributeFromBuffer(0x22cea630) -> dense<2> : tensor<i32> I buffer.cc:291] CreateMlirBufferFromAttribute I buffer.cc:116] CreateFromAttribute I buffer.cc:64] MlirPjrtBuffer(0x23b2db60) I buffer.cc:122] CreateFromAttribute(dense<2> : tensor<i32>) -> 0x23b2db60 I buffer.cc:263] GetReadyFuture(0x22cea630) I buffer.cc:264] GetReadyFuture(0x22cea630)