배경
PJRT는 ML 생태계에 추가할 통일된 기기 API입니다. 장기 비전은 다음과 같습니다.
- 프레임워크 (JAX, TF 등)는 프레임워크에 불투명한 기기별 구현이 있는 PJRT를 호출합니다.
- 각 기기는 PJRT API 구현에 중점을 두며 프레임워크에 대해 불투명할 수 있습니다.
PJRT는 C API와 C++ API를 모두 제공합니다. 어느 레이어에서든 연결해도 됩니다. C++ API는 클래스를 사용하여 일부 개념을 추상화하지만 XLA 데이터 유형과도 더 밀접한 관계가 있습니다. 이 페이지에서는 C++ API에 중점을 둡니다.
PJRT 구성요소
PjRtClient
자세한 내용은 pjrt_client.h > PjRtClient
를 참고하세요.
클라이언트는 기기와 프레임워크 간의 모든 통신을 관리하고 통신에 사용되는 모든 상태를 캡슐화합니다. PJRT 플러그인과 상호작용하는 API의 일반 집합이 있으며, 특정 플러그인의 기기와 메모리 공간을 소유합니다.
PjRtDevice
전체 참조는 pjrt_client.h > PjRtDevice
및 pjrt_device_description.h
에서 확인
기기 클래스는 단일 기기를 설명하는 데 사용됩니다. 기기에는 기기의 종류 (GPU/CPU/xPU를 식별하는 고유 해시)를 식별하는 데 도움이 되는 기기 설명과 로컬 및 전 세계 기기 그리드 내 위치가 있습니다.
기기는 연결된 메모리 공간과 소유한 클라이언트도 알고 있습니다.
기기는 이와 연결된 실제 데이터의 버퍼를 반드시 알 필요는 없지만 연결된 메모리 공간을 통해 이를 파악할 수 있습니다.
PjRtMemorySpace
자세한 내용은 pjrt_client.h > PjRtMemorySpace
를 참고하세요.
메모리 공간은 메모리 위치를 설명하는 데 사용할 수 있습니다. 이러한 위치는 고정 해제할 수 있으며 어디에나 있을 수 있지만 기기에서 액세스할 수 있습니다. 또는 고정할 수 있으며 특정 기기에 있어야 합니다.
메모리 공간은 연결된 데이터 버퍼와 메모리 공간이 연결된 기기 (복수)는 물론 그 일부인 클라이언트를 알고 있습니다.
PjRtBuffer
자세한 내용은 pjrt_client.h > PjRtBuffer
를 참고하세요.
버퍼는 MLIR 요소 속성 또는 독점 텐서 형식과 같이 플러그인 내에서 쉽게 작업할 수 있는 형식으로 기기의 데이터를 보관합니다.
프레임워크는 클론 (또는 빌림)해야 하는 모듈의 입력 인수의 경우 xla::Literal
형식으로 데이터를 기기에 전송하려고 시도할 수 있습니다. 버퍼가 더 이상 필요하지 않으면 프레임워크에서 Delete
메서드를 호출하여 정리합니다.
버퍼는 자신이 속한 메모리 공간을 알고 있으며, 어떤 기기가 메모리 공간에 액세스할 수 있는지 추론할 수 있지만, 버퍼가 반드시 기기를 알 필요는 없습니다.
프레임워크와 통신하기 위해 버퍼는 xla::Literal
유형으로 변환하는 방법을 알고 있습니다.
// Literal to Buffer
absl::StatusOr<std::unique_ptr<PjRtBuffer>> BufferFromHostBuffer(...) {...}
// Buffer to Literal
xla::PjRtFuture<> ToLiteral(xla::MutableLiteralBase* literal) override {...}
버퍼를 만드는 API에는 호스트 버퍼의 리터럴 데이터를 공유하거나 복사하거나 변형할 수 있는지 지정하는 데 도움이 되는 버퍼 시맨틱이 있습니다.
마지막으로, 버퍼가 프레임워크 레이어 x = jit(foo)(10)
의 변수에 할당된 경우 버퍼가 실행 범위보다 오래 지속되어야 할 수 있습니다. 이 경우 버퍼는 버퍼에 보관된 데이터에 대한 임시 소유 포인터와 기본 데이터를 해석하기 위한 메타데이터 (dtype / dim 크기)를 제공하는 외부 참조를 빌드할 수 있습니다.
PjRtCompiler
자세한 내용은 pjrt_compiler.h > PjRtCompiler
를 참고하세요.
PjRtCompiler
클래스는 XLA 백엔드에 유용한 구현 세부정보를 제공하지만 플러그인이 구현할 필요는 없습니다. 이론적으로 PjRtCompiler
또는 PjRtClient::Compile
메서드의 책임은 입력 모듈을 사용하여 PjRtLoadedExecutable
를 반환하는 것입니다.
PjRtExecutable / PjRtLoadedExecutable
pjrt_executable.h > PjRtExecutable
및 pjrt_client.h > PjRtLoadedExecutable
에서 전체 참조를 확인하세요.
PjRtExecutable
는 컴파일된 아티팩트와 실행 옵션을 가져와서 직렬화/역직렬화하여 실행 파일을 필요에 따라 저장하고 로드하는 방법을 알고 있습니다.
PjRtLoadedExecutable
는 입력 인수를 실행할 준비가 된 메모리 내 컴파일된 실행 파일이며 PjRtExecutable
의 서브클래스입니다.
실행 파일은 클라이언트의 Execute
메서드 중 하나를 통해 인터페이스됩니다.
// Execute on addressable devices
absl::StatusOr<std::vector<std::vector<std::unique_ptr<PjRtBuffer>>>>
Execute(absl::Span<const std::vector<PjRtBuffer*>> argument_handles, ...) {...}
// Execute assigned replica/partition on the specified device
absl::StatusOr<std::vector<std::unique_ptr<PjRtBuffer>>>
ExecuteSharded(absl::Span<PjRtBuffer* const> argument_handles,
PjRtDevice* device, ...) {...}
// Execute on specified device, single replica / partition
absl::StatusOr<std::vector<std::unique_ptr<PjRtBuffer>>>
ExecutePortable(absl::Span<PjRtBuffer* const> argument_handles,
PjRtDevice* device, ...) {...}
Execute
를 호출하기 전에 프레임워크는 실행 중인 클라이언트가 소유하지만 프레임워크가 참조하도록 반환된 PjRtBuffers
에 필요한 모든 데이터를 전송합니다. 그런 다음 이러한 버퍼가 Execute
메서드에 인수로 제공됩니다.
PJRT 개념
PjRtFutures 및 비동기 계산
플러그인의 일부가 비동기식으로 구현되는 경우 future를 올바르게 구현해야 합니다.
다음 프로그램을 고려해 보세요.
@jax.jit
def foo(x): return x + 1
x = foo(1)
# [...] other logic not using `x`
print(x + 1)
비동기 플러그인은 계산 x
를 큐에 추가하고 아직 읽을 준비가 되지 않은 버퍼를 즉시 반환할 수 있지만 실행 시 이를 채웁니다. 실행은 x
후에도 다른 PJRT 기기에서 실행을 포함하여 x
가 필요하지 않은 필요한 계산을 계속 큐에 추가할 수 있습니다. x
의 값이 필요하면 버퍼가 GetReadyFuture
에서 반환된 future를 통해 준비되었음을 선언할 때까지 실행이 차단됩니다.
Futures는 기기 및 버퍼를 비롯하여 객체를 사용할 수 있는 시기를 확인하는 데 유용할 수 있습니다.
고급 개념
기본 API 구현을 넘어 확장하면 플러그인에서 사용할 수 있는 JAX의 기능이 확장됩니다. 이러한 기능은 모두 선택사항입니다. 일반적인 JIT 및 실행 워크플로는 이러한 기능 없이도 작동하지만 프로덕션 품질 파이프라인의 경우 PJRT API에서 지원하는 이러한 기능의 지원 정도를 고려해야 합니다.
- 메모리 공간
- 맞춤 레이아웃
- send/recv와 같은 통신 작업
- 호스트 오프로드
- 샤딩
일반적인 PJRT 프레임워크-기기 통신
로그 예시
다음은 PJRT 플러그인을 로드하고 y = jax.jit(lambda x: jnp.power(x, jnp.int32(2)))(1)
를 실행하기 위해 호출된 메서드의 로그입니다. 이 경우 StableHLO 참조 PJRT 플러그인과 상호작용하는 JAX를 로깅합니다.
로그 예시
////////////////////////////////// // Load the plugin ////////////////////////////////// I client_cpp_pjrt.cc:55] StablehloReferencePjrtClient(0x23bac400) I device.cc:53] StablehloReferenceDeviceDescription(0x23bac4f8) I device.cc:104] StablehloReferenceDevice(0x23bac4e0) I device.cc:123] client(0x23bac4e0) I device.cc:123] client(0x23bac4e0) I client_cpp_pjrt.cc:71] process_index(0x23bac400) I client_cpp_pjrt.cc:67] platform_name(0x23bac400) I device.cc:143] AttachDefaultMemorySpace(0x23bac4e0) I client_cpp_pjrt.cc:67] platform_name(0x23bac400) I client_cpp_pjrt.cc:86] devices(0x23bac400) I client_cpp_pjrt.cc:81] addressable_device_count(0x23bac400) I device.cc:168] description(0x23bac4e0) I device.cc:168] description(0x23bac4e0) I device.cc:86] Attributes(0x23bac4f8) I device.cc:128] IsAddressable(0x23bac4e0) I device.cc:168] description(0x23bac4e0) I device.cc:61] process_index(0x23bac4f8) I device.cc:123] client(0x23bac4e0) I client_cpp_pjrt.cc:71] process_index(0x23bac400) I client_cpp_pjrt.cc:81] addressable_device_count(0x23bac400) I client_cpp_pjrt.cc:95] memory_spaces(0x23bac400) I device.cc:128] IsAddressable(0x23bac4e0) I device.cc:168] description(0x23bac4e0) I device.cc:61] process_index(0x23bac4f8) I device.cc:123] client(0x23bac4e0) I client_cpp_pjrt.cc:71] process_index(0x23bac400) I device.cc:148] memory_spaces(0x23bac4e0) Creating PJRT Client from client I client_cpp_pjrt.cc:108] platform_version(0x23bac400) I client_cpp_pjrt.cc:67] platform_name(0x23bac400) I device.cc:57] id(0x23bac4f8) I device.cc:70] device_kind(0x23bac4f8) I device.cc:70] device_kind(0x23bac4f8) I device.cc:80] ToString(0x23bac4f8) I device.cc:80] ToString(0x23bac4f8) I device.cc:75] DebugString(0x23bac4f8) I device.cc:75] DebugString(0x23bac4f8) I device.cc:61] process_index(0x23bac4f8) I device.cc:128] IsAddressable(0x23bac4e0) I device.cc:168] description(0x23bac4e0) I device.cc:61] process_index(0x23bac4f8) I device.cc:123] client(0x23bac4e0) I client_cpp_pjrt.cc:71] process_index(0x23bac400) I device.cc:153] default_memory_space(0x23bac4e0) I client_cpp_pjrt.cc:71] process_index(0x23bac400) ////////////////////////////////// // RUN: `y = jax.jit(lambda x: jnp.power(x, jnp.int32(2)))(1)` ////////////////////////////////// I executable.cc:309] num_partitions(0x240bab70) I executable.cc:305] num_replicas(0x240bab70) I executable.cc:309] num_partitions(0x240bab70) I client_cpp_pjrt.cc:233] BufferFromHostBuffer(0x23bac400) I buffer.cc:285] CreateMlirBufferFromLiteral I buffer.cc:98] CreateFromLiteral I buffer.cc:99] CreateFromLiteral: s32[] 2 I buffer.cc:64] MlirPjrtBuffer(0x240bb050) I buffer.cc:102] CreateFromLiteral -> 0x240bb050 I buffer.cc:158] device(0x240bb050) I buffer.cc:154] memory_space(0x240bb050) I buffer.cc:154] memory_space(0x240bb050) I executable.cc:328] GetHloModules(0x240bab70) I executable.cc:240] Execute(0x240bab70) I executable.cc:197] ExecuteWithReferenceInterpreter(0x240bab70) I buffer.cc:303] GetAttributeFromBuffer I buffer.cc:229] IsDeleted(0x240bb050) I buffer.cc:311] GetAttributeFromBuffer(0x240bb050) -> dense<2> : tensor<i32> I executable.cc:205] EvalModule: module @jit attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} { func.func public @main(%arg0: tensor<i32> {mhlo.layout_mode = "default"}) -> (tensor<i32> {jax.result_info = "", mhlo.layout_mode = "default"}) { // ... return %3 : tensor<i32> } } I executable.cc:206] Inputs: [dense<2> : tensor<i32>] I executable.cc:213] Results: [dense<2> : tensor<i32>] I device.cc:153] default_memory_space(0x23bac4e0) I buffer.cc:291] CreateMlirBufferFromAttribute I buffer.cc:116] CreateFromAttribute I buffer.cc:64] MlirPjrtBuffer(0x22cea630) I buffer.cc:122] CreateFromAttribute(dense<2> : tensor<i32>) -> 0x22cea630 ////////////////////////////////// // RUN: `print(y)` ////////////////////////////////// I buffer.cc:263] GetReadyFuture(0x22cea630) I buffer.cc:264] GetReadyFuture(0x22cea630) I buffer.cc:154] memory_space(0x22cea630) I buffer.cc:154] memory_space(0x22cea630) I buffer.cc:158] device(0x22cea630) I buffer.cc:158] device(0x22cea630) I buffer.cc:154] memory_space(0x22cea630) I buffer.cc:154] memory_space(0x22cea630) I buffer.cc:229] IsDeleted(0x22cea630) I buffer.cc:129] on_device_shape(0x22cea630) I buffer.cc:129] on_device_shape(0x22cea630) I buffer.cc:129] on_device_shape(0x22cea630) I buffer.cc:158] device(0x22cea630) I buffer.cc:154] memory_space(0x22cea630) I buffer.cc:154] memory_space(0x22cea630) I client_cpp_pjrt.cc:71] process_index(0x23bac400) I buffer.cc:229] IsDeleted(0x22cea630) I buffer.cc:129] on_device_shape(0x22cea630) I buffer.cc:129] on_device_shape(0x22cea630) I buffer.cc:269] IsOnCpu(0x22cea630) # Returns true, allows external references. I buffer.cc:129] on_device_shape(0x22cea630) I buffer.cc:129] on_device_shape(0x22cea630) I buffer.cc:129] on_device_shape(0x22cea630) I buffer.cc:129] on_device_shape(0x22cea630) I buffer.cc:129] on_device_shape(0x22cea630) I buffer.cc:168] AcquireExternalReference(0x22cea630) I buffer.cc:73] MlirClonedExternalReference(0x2404d560) I buffer.cc:303] GetAttributeFromBuffer I buffer.cc:229] IsDeleted(0x22cea630) I buffer.cc:311] GetAttributeFromBuffer(0x22cea630) -> dense<2> : tensor<i32> I buffer.cc:291] CreateMlirBufferFromAttribute I buffer.cc:116] CreateFromAttribute I buffer.cc:64] MlirPjrtBuffer(0x240bb050) I buffer.cc:122] CreateFromAttribute(dense<2> : tensor<i32>) -> 0x240bb050 I buffer.cc:168] AcquireExternalReference(0x22cea630) I buffer.cc:73] MlirClonedExternalReference(0x240b6010) I buffer.cc:303] GetAttributeFromBuffer I buffer.cc:229] IsDeleted(0x22cea630) I buffer.cc:311] GetAttributeFromBuffer(0x22cea630) -> dense<2> : tensor<i32> I buffer.cc:291] CreateMlirBufferFromAttribute I buffer.cc:116] CreateFromAttribute I buffer.cc:64] MlirPjrtBuffer(0x23b2db60) I buffer.cc:122] CreateFromAttribute(dense<2> : tensor<i32>) -> 0x23b2db60 I buffer.cc:263] GetReadyFuture(0x22cea630) I buffer.cc:264] GetReadyFuture(0x22cea630)