PJRT C++ Device API 개요

배경

PJRT는 ML 생태계에 추가할 통일된 기기 API입니다. 장기 비전은 다음과 같습니다.

  1. 프레임워크 (JAX, TF 등)는 프레임워크에 불투명한 기기별 구현이 있는 PJRT를 호출합니다.
  2. 각 기기는 PJRT API 구현에 중점을 두며 프레임워크에 대해 불투명할 수 있습니다.

PJRT는 C API와 C++ API를 모두 제공합니다. 어느 레이어에 연결해도 괜찮습니다. C++ API는 클래스를 사용하여 일부 개념을 추상화하지만 XLA 데이터 유형과 더 밀접한 관련이 있습니다. 이 페이지에서는 C++ API에 중점을 둡니다.

PJRT 구성요소

PJRT 구성요소

PjRtClient

자세한 내용은 pjrt_client.h > PjRtClient를 참고하세요.

클라이언트는 기기와 프레임워크 간의 모든 통신을 관리하고 통신에 사용되는 모든 상태를 캡슐화합니다. PJRT 플러그인과 상호작용하는 API의 일반 집합이 있으며, 특정 플러그인의 기기와 메모리 공간을 소유합니다.

PjRtDevice

전체 참조는 pjrt_client.h > PjRtDevicepjrt_device_description.h에서 확인

기기 클래스는 단일 기기를 설명하는 데 사용됩니다. 기기에는 기기의 종류 (GPU/CPU/xPU를 식별하는 고유 해시)를 식별하는 데 도움이 되는 기기 설명과 로컬 및 전 세계 기기 그리드 내 위치가 있습니다.

기기는 연결된 메모리 공간과 소유한 클라이언트도 알고 있습니다.

기기는 이와 연결된 실제 데이터의 버퍼를 반드시 알 필요는 없지만 연결된 메모리 공간을 통해 이를 파악할 수 있습니다.

PjRtMemorySpace

자세한 내용은 pjrt_client.h > PjRtMemorySpace를 참고하세요.

메모리 공간을 사용하여 메모리의 위치를 설명할 수 있습니다. 이러한 항목은 고정 해제할 수 있고 어디에서나 자유롭게 사용할 수 있지만 기기에서 액세스할 수 있습니다. 또는 항목을 고정하여 특정 기기에 유지해야 할 수도 있습니다.

메모리 공간은 연결된 데이터 버퍼와 메모리 공간이 연결된 기기 (복수)는 물론 그 일부인 클라이언트를 알고 있습니다.

PjRtBuffer

자세한 내용은 pjrt_client.h > PjRtBuffer를 참고하세요.

버퍼는 MLIR 요소 속성 또는 독점 텐서 형식과 같이 플러그인 내에서 쉽게 작업할 수 있는 형식으로 기기의 데이터를 보관합니다. 프레임워크는 클론 (또는 빌림)해야 하는 모듈의 입력 인수의 경우 xla::Literal 형식으로 데이터를 기기에 전송하려고 시도할 수 있습니다. 버퍼가 더 이상 필요하지 않으면 프레임워크에서 Delete 메서드를 호출하여 정리합니다.

버퍼는 자신이 속한 메모리 공간을 알고 있으며, 어떤 기기가 메모리 공간에 액세스할 수 있는지 추론할 수 있지만, 버퍼가 반드시 기기를 알 필요는 없습니다.

프레임워크와 통신하기 위해 버퍼는 xla::Literal 유형으로 변환하거나 이러한 유형에서 변환하는 방법을 알고 있습니다.

// Literal to Buffer
absl::StatusOr<std::unique_ptr<PjRtBuffer>> BufferFromHostBuffer(...) {...}

// Buffer to Literal
xla::PjRtFuture<> ToLiteral(xla::MutableLiteralBase* literal) override {...}

버퍼를 만드는 API에는 호스트 버퍼의 리터럴 데이터를 공유하거나 복사하거나 변형할 수 있는지 지정하는 데 도움이 되는 버퍼 시맨틱이 있습니다.

마지막으로, 버퍼가 프레임워크 레이어 x = jit(foo)(10)의 변수에 할당된 경우 버퍼가 실행 범위보다 오래 지속되어야 할 수 있습니다. 이 경우 버퍼는 버퍼에 보관된 데이터에 대한 임시 소유 포인터와 기본 데이터를 해석하기 위한 메타데이터 (dtype / dim 크기)를 제공하는 외부 참조를 빌드할 수 있습니다.

PjRtCompiler

자세한 내용은 pjrt_compiler.h > PjRtCompiler를 참고하세요.

PjRtCompiler 클래스는 XLA 백엔드에 유용한 구현 세부정보를 제공하지만 플러그인이 구현할 필요는 없습니다. 이론적으로 PjRtCompiler 또는 PjRtClient::Compile 메서드의 책임은 입력 모듈을 사용하여 PjRtLoadedExecutable를 반환하는 것입니다.

PjRtExecutable / PjRtLoadedExecutable

pjrt_executable.h > PjRtExecutablepjrt_client.h > PjRtLoadedExecutable에서 전체 참조를 확인하세요.

PjRtExecutable는 컴파일된 아티팩트와 실행 옵션을 가져와서 직렬화/역직렬화하여 실행 파일을 필요에 따라 저장하고 로드하는 방법을 알고 있습니다.

PjRtLoadedExecutable는 입력 인수를 실행할 준비가 된 메모리 내 컴파일된 실행 파일이며 PjRtExecutable의 서브클래스입니다.

실행 파일은 클라이언트의 Execute 메서드 중 하나를 통해 인터페이스됩니다.

// Execute on addressable devices
absl::StatusOr<std::vector<std::vector<std::unique_ptr<PjRtBuffer>>>>
Execute(absl::Span<const std::vector<PjRtBuffer*>> argument_handles, ...) {...}

// Execute assigned replica/partition on the specified device
absl::StatusOr<std::vector<std::unique_ptr<PjRtBuffer>>>
ExecuteSharded(absl::Span<PjRtBuffer* const> argument_handles,
              PjRtDevice* device, ...) {...}

// Execute on specified device, single replica / partition
absl::StatusOr<std::vector<std::unique_ptr<PjRtBuffer>>>
ExecutePortable(absl::Span<PjRtBuffer* const> argument_handles,
                PjRtDevice* device, ...) {...}

Execute를 호출하기 전에 프레임워크는 실행 중인 클라이언트가 소유하지만 프레임워크가 참조하도록 반환된 PjRtBuffers에 필요한 모든 데이터를 전송합니다. 그런 다음 이러한 버퍼는 Execute 메서드에 인수로 제공됩니다.

PJRT 개념

PjRtFutures 및 비동기 계산

플러그인의 일부가 비동기식으로 구현되는 경우 future를 올바르게 구현해야 합니다.

다음 프로그램을 살펴보겠습니다.

@jax.jit
def foo(x): return x + 1

x = foo(1)
# [...] other logic not using `x`
print(x + 1)

비동기 플러그인은 계산 x를 큐에 추가하고 아직 읽을 준비가 되지 않은 버퍼를 즉시 반환할 수 있지만 실행 시 이를 채웁니다. 실행은 x 후에도 다른 PJRT 기기에서 실행을 포함하여 x가 필요하지 않은 필요한 계산을 계속 큐에 추가할 수 있습니다. x의 값이 필요하면 버퍼가 GetReadyFuture에서 반환된 future를 통해 준비되었음을 선언할 때까지 실행이 차단됩니다.

Futures는 기기 및 버퍼를 비롯하여 객체를 사용할 수 있는 시기를 확인하는 데 유용할 수 있습니다.

고급 개념

기본 API 구현 이상으로 확장하면 플러그인에서 사용할 수 있는 JAX의 기능이 확장됩니다. 이러한 기능은 모두 선택사항입니다. 일반적인 JIT 및 실행 워크플로는 이러한 기능 없이도 작동하지만 프로덕션 품질 파이프라인의 경우 PJRT API에서 지원하는 이러한 기능의 지원 정도를 고려해야 합니다.

  • 메모리 공간
  • 맞춤 레이아웃
  • 전송/수신과 같은 커뮤니케이션 작업
  • 호스트 오프로드
  • 샤딩

일반적인 PJRT 프레임워크-기기 통신

로그 예

다음은 PJRT 플러그인을 로드하고 y = jax.jit(lambda x: jnp.power(x, jnp.int32(2)))(1)를 실행하기 위해 호출된 메서드의 로그입니다. 이 경우 StableHLO 참조 PJRT 플러그인과 상호작용하는 JAX를 로깅합니다.

로그 예시

//////////////////////////////////
// Load the plugin
//////////////////////////////////

I client_cpp_pjrt.cc:55] StablehloReferencePjrtClient(0x23bac400)
I device.cc:53] StablehloReferenceDeviceDescription(0x23bac4f8)
I device.cc:104] StablehloReferenceDevice(0x23bac4e0)
I device.cc:123] client(0x23bac4e0)
I device.cc:123] client(0x23bac4e0)
I client_cpp_pjrt.cc:71] process_index(0x23bac400)
I client_cpp_pjrt.cc:67] platform_name(0x23bac400)
I device.cc:143] AttachDefaultMemorySpace(0x23bac4e0)
I client_cpp_pjrt.cc:67] platform_name(0x23bac400)
I client_cpp_pjrt.cc:86] devices(0x23bac400)
I client_cpp_pjrt.cc:81] addressable_device_count(0x23bac400)
I device.cc:168] description(0x23bac4e0)
I device.cc:168] description(0x23bac4e0)
I device.cc:86] Attributes(0x23bac4f8)
I device.cc:128] IsAddressable(0x23bac4e0)
I device.cc:168] description(0x23bac4e0)
I device.cc:61] process_index(0x23bac4f8)
I device.cc:123] client(0x23bac4e0)
I client_cpp_pjrt.cc:71] process_index(0x23bac400)
I client_cpp_pjrt.cc:81] addressable_device_count(0x23bac400)
I client_cpp_pjrt.cc:95] memory_spaces(0x23bac400)
I device.cc:128] IsAddressable(0x23bac4e0)
I device.cc:168] description(0x23bac4e0)
I device.cc:61] process_index(0x23bac4f8)
I device.cc:123] client(0x23bac4e0)
I client_cpp_pjrt.cc:71] process_index(0x23bac400)
I device.cc:148] memory_spaces(0x23bac4e0)
Creating PJRT Client from client
I client_cpp_pjrt.cc:108] platform_version(0x23bac400)
I client_cpp_pjrt.cc:67] platform_name(0x23bac400)
I device.cc:57] id(0x23bac4f8)
I device.cc:70] device_kind(0x23bac4f8)
I device.cc:70] device_kind(0x23bac4f8)
I device.cc:80] ToString(0x23bac4f8)
I device.cc:80] ToString(0x23bac4f8)
I device.cc:75] DebugString(0x23bac4f8)
I device.cc:75] DebugString(0x23bac4f8)
I device.cc:61] process_index(0x23bac4f8)
I device.cc:128] IsAddressable(0x23bac4e0)
I device.cc:168] description(0x23bac4e0)
I device.cc:61] process_index(0x23bac4f8)
I device.cc:123] client(0x23bac4e0)
I client_cpp_pjrt.cc:71] process_index(0x23bac400)
I device.cc:153] default_memory_space(0x23bac4e0)
I client_cpp_pjrt.cc:71] process_index(0x23bac400)

//////////////////////////////////
// RUN: `y = jax.jit(lambda x: jnp.power(x, jnp.int32(2)))(1)`
//////////////////////////////////

I executable.cc:309] num_partitions(0x240bab70)
I executable.cc:305] num_replicas(0x240bab70)
I executable.cc:309] num_partitions(0x240bab70)
I client_cpp_pjrt.cc:233] BufferFromHostBuffer(0x23bac400)
I buffer.cc:285] CreateMlirBufferFromLiteral
I buffer.cc:98] CreateFromLiteral
I buffer.cc:99] CreateFromLiteral: s32[] 2
I buffer.cc:64] MlirPjrtBuffer(0x240bb050)
I buffer.cc:102] CreateFromLiteral -> 0x240bb050
I buffer.cc:158] device(0x240bb050)
I buffer.cc:154] memory_space(0x240bb050)
I buffer.cc:154] memory_space(0x240bb050)
I executable.cc:328] GetHloModules(0x240bab70)
I executable.cc:240] Execute(0x240bab70)
I executable.cc:197] ExecuteWithReferenceInterpreter(0x240bab70)
I buffer.cc:303] GetAttributeFromBuffer
I buffer.cc:229] IsDeleted(0x240bb050)
I buffer.cc:311] GetAttributeFromBuffer(0x240bb050) -> dense<2> : tensor<i32>
I executable.cc:205] EvalModule:
module @jit attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
  func.func public @main(%arg0: tensor<i32> {mhlo.layout_mode = "default"}) -> (tensor<i32> {jax.result_info = "", mhlo.layout_mode = "default"}) {
    // ...
    return %3 : tensor<i32>
  }
}
I executable.cc:206] Inputs: [dense<2> : tensor<i32>]
I executable.cc:213] Results: [dense<2> : tensor<i32>]
I device.cc:153] default_memory_space(0x23bac4e0)
I buffer.cc:291] CreateMlirBufferFromAttribute
I buffer.cc:116] CreateFromAttribute
I buffer.cc:64] MlirPjrtBuffer(0x22cea630)
I buffer.cc:122] CreateFromAttribute(dense<2> : tensor<i32>) -> 0x22cea630

//////////////////////////////////
// RUN: `print(y)`
//////////////////////////////////

I buffer.cc:263] GetReadyFuture(0x22cea630)
I buffer.cc:264] GetReadyFuture(0x22cea630)
I buffer.cc:154] memory_space(0x22cea630)
I buffer.cc:154] memory_space(0x22cea630)
I buffer.cc:158] device(0x22cea630)
I buffer.cc:158] device(0x22cea630)
I buffer.cc:154] memory_space(0x22cea630)
I buffer.cc:154] memory_space(0x22cea630)
I buffer.cc:229] IsDeleted(0x22cea630)
I buffer.cc:129] on_device_shape(0x22cea630)
I buffer.cc:129] on_device_shape(0x22cea630)
I buffer.cc:129] on_device_shape(0x22cea630)
I buffer.cc:158] device(0x22cea630)
I buffer.cc:154] memory_space(0x22cea630)
I buffer.cc:154] memory_space(0x22cea630)
I client_cpp_pjrt.cc:71] process_index(0x23bac400)
I buffer.cc:229] IsDeleted(0x22cea630)
I buffer.cc:129] on_device_shape(0x22cea630)
I buffer.cc:129] on_device_shape(0x22cea630)
I buffer.cc:269] IsOnCpu(0x22cea630) # Returns true, allows external references.
I buffer.cc:129] on_device_shape(0x22cea630)
I buffer.cc:129] on_device_shape(0x22cea630)
I buffer.cc:129] on_device_shape(0x22cea630)
I buffer.cc:129] on_device_shape(0x22cea630)
I buffer.cc:129] on_device_shape(0x22cea630)
I buffer.cc:168] AcquireExternalReference(0x22cea630)
I buffer.cc:73] MlirClonedExternalReference(0x2404d560)
I buffer.cc:303] GetAttributeFromBuffer
I buffer.cc:229] IsDeleted(0x22cea630)
I buffer.cc:311] GetAttributeFromBuffer(0x22cea630) -> dense<2> : tensor<i32>
I buffer.cc:291] CreateMlirBufferFromAttribute
I buffer.cc:116] CreateFromAttribute
I buffer.cc:64] MlirPjrtBuffer(0x240bb050)
I buffer.cc:122] CreateFromAttribute(dense<2> : tensor<i32>) -> 0x240bb050
I buffer.cc:168] AcquireExternalReference(0x22cea630)
I buffer.cc:73] MlirClonedExternalReference(0x240b6010)
I buffer.cc:303] GetAttributeFromBuffer
I buffer.cc:229] IsDeleted(0x22cea630)
I buffer.cc:311] GetAttributeFromBuffer(0x22cea630) -> dense<2> : tensor<i32>
I buffer.cc:291] CreateMlirBufferFromAttribute
I buffer.cc:116] CreateFromAttribute
I buffer.cc:64] MlirPjrtBuffer(0x23b2db60)
I buffer.cc:122] CreateFromAttribute(dense<2> : tensor<i32>) -> 0x23b2db60
I buffer.cc:263] GetReadyFuture(0x22cea630)
I buffer.cc:264] GetReadyFuture(0x22cea630)