Omówienie interfejsu PJRT C++ Device API

Tło

PJRT to jednolity interfejs API urządzenia, który chcemy dodać do ekosystemu ML. Długoterminowa wizja jest taka:

  1. Platformy (JAX, TF itp.) będą wywoływać PJRT, która ma implementacje specyficzne dla urządzenia, niewidoczne dla platform;
  2. Każde urządzenie koncentruje się na implementacji interfejsów PJRT API i może być nieprzejrzyste dla platform.

PJRT udostępnia interfejsy API w językach C i C++. Możesz podłączyć się na dowolnym poziomie. Interfejs C++ API używa klas do abstrakcyjnego przedstawiania niektórych koncepcji, ale jest też silniej powiązany z typami danych XLA. Ta strona dotyczy interfejsu C++ API.

Komponenty PJRT

Komponenty PJRT

PjRtClient

Pełne informacje znajdziesz na stronie pjrt_client.h > PjRtClient.

Klienci zarządzają całą komunikacją między urządzeniem a platformą i obejmują wszystkie stany używane w komunikacji. Mają ogólny zestaw interfejsów API do interakcji z wtyczką PJRT oraz są właścicielami urządzeń i przestrzeni pamięci dla danej wtyczki.

PjRtDevice

Pełne informacje: pjrt_client.h > PjRtDevicepjrt_device_description.h

Klasa urządzenia służy do opisywania pojedynczego urządzenia. Urządzenie ma opis, który pomaga określić jego rodzaj (unikalny hash identyfikujący GPU/CPU/xPU) i lokalizację w siatce urządzeń zarówno lokalnie, jak i globalnie.

Urządzenia znają też powiązane z nimi przestrzenie pamięci i klienta, do którego należą.

Urządzenie nie musi znać buforów rzeczywistych danych z nim powiązanych, ale może to ustalić, przeglądając powiązane z nim obszary pamięci.

PjRtMemorySpace

Pełne informacje znajdziesz na stronie pjrt_client.h > PjRtMemorySpace.

Przestrzenie pamięci mogą służyć do opisywania lokalizacji pamięci. Mogą być odpięte i znajdować się w dowolnym miejscu, ale muszą być dostępne z urządzenia. Mogą też być przypięte i muszą znajdować się na konkretnym urządzeniu.

Przestrzenie pamięci znają powiązane z nimi bufory danych i urządzenia, z którymi są powiązane, a także klienta, którego są częścią.

PjRtBuffer

Pełne informacje znajdziesz na stronie pjrt_client.h > PjRtBuffer.

Bufor przechowuje dane na urządzeniu w formacie, który ułatwia pracę w ramach wtyczki, np. atrybut elementów MLIR lub zastrzeżony format tensora. Framework może próbować wysłać dane do urządzenia w formie xla::Literal, czyli jako argument wejściowy modułu, który musi zostać sklonowany (lub wypożyczony) do pamięci urządzenia. Gdy bufor nie jest już potrzebny, framework wywołuje metodę Delete, aby go wyczyścić.

Bufor wie, do jakiej przestrzeni pamięci należy, i może pośrednio określić, które urządzenia mają do niego dostęp, ale niekoniecznie zna swoje urządzenia.

W przypadku komunikacji z frameworkami bufory wiedzą, jak konwertować dane na typ xla::Literal i z niego:

// Literal to Buffer
absl::StatusOr<std::unique_ptr<PjRtBuffer>> BufferFromHostBuffer(...) {...}

// Buffer to Literal
xla::Future<> ToLiteral(xla::MutableLiteralBase* literal) override {...}

Interfejsy API do tworzenia bufora mają semantykę bufora, która określa, czy dane z bufora hosta mogą być udostępniane, kopiowane czy modyfikowane.

Bufor może też działać dłużej niż zakres jego wykonania, jeśli jest przypisany do zmiennej w warstwie struktury x = jit(foo)(10). W takich przypadkach bufor umożliwia tworzenie odwołań zewnętrznych, które zapewniają tymczasowo posiadany wskaźnik do danych przechowywanych w buforze wraz z metadanymi (typ danych lub rozmiary wymiarów) do interpretowania danych bazowych.

PjRtCompiler

Pełne informacje znajdziesz na stronie pjrt_compiler.h > PjRtCompiler.

Klasa PjRtCompiler zawiera przydatne szczegóły implementacji backendów XLA, ale nie jest niezbędna do implementacji wtyczki. Teoretycznie zadaniem PjRtCompiler lub metody PjRtClient::Compile jest przyjęcie modułu wejściowego i zwrócenie PjRtLoadedExecutable.

PjRtExecutable / PjRtLoadedExecutable

Pełne informacje znajdziesz na stronach pjrt_executable.h > PjRtExecutablepjrt_client.h > PjRtLoadedExecutable.

PjRtExecutable wie, jak pobrać skompilowany artefakt i opcje wykonania oraz serializować i deserializować je, aby plik wykonywalny można było w razie potrzeby przechowywać i wczytywać;

PjRtLoadedExecutable to skompilowany w pamięci plik wykonywalny, który jest gotowy do wykonania z argumentami wejściowymi. Jest to podklasa klasy PjRtExecutable.

Z plikami wykonywalnymi można się komunikować za pomocą jednej z metod Execute klienta:

// Execute on addressable devices
absl::StatusOr<std::vector<std::vector<std::unique_ptr<PjRtBuffer>>>>
Execute(absl::Span<const std::vector<PjRtBuffer*>> argument_handles, ...) {...}

// Execute assigned replica/partition on the specified device
absl::StatusOr<std::vector<std::unique_ptr<PjRtBuffer>>>
ExecuteSharded(absl::Span<PjRtBuffer* const> argument_handles,
              PjRtDevice* device, ...) {...}

// Execute on specified device, single replica / partition
absl::StatusOr<std::vector<std::unique_ptr<PjRtBuffer>>>
ExecutePortable(absl::Span<PjRtBuffer* const> argument_handles,
                PjRtDevice* device, ...) {...}

Przed wywołaniem funkcji Execute platforma przekaże wszystkie wymagane dane do usługi PjRtBuffers należącej do klienta wykonującego działanie, ale zwróci je do platformy w celu odwołania się do nich. Te bufory są następnie przekazywane jako argumenty do metody Execute.

Pojęcia związane z PJRT

Obiekty Future i obliczenia asynchroniczne

Jeśli jakakolwiek część wtyczki jest zaimplementowana asynchronicznie, musi prawidłowo implementować obiekty Future.

Rozważmy ten program:

@jax.jit
def foo(x): return x + 1

x = foo(1)
# [...] other logic not using `x`
print(x + 1)

Wtyczka asynchroniczna może umieścić obliczenia w kolejce x i natychmiast zwrócić bufor, który nie jest jeszcze gotowy do odczytu, ale wykonanie go wypełni. Po wykonaniu operacji x można nadal umieszczać w kolejce niezbędne obliczenia, które nie wymagają operacji x, w tym obliczenia na innych urządzeniach PJRT. Gdy wartość x jest potrzebna, wykonywanie zostanie zablokowane, dopóki bufor nie zadeklaruje gotowości za pomocą obiektu Future zwróconego przez GetReadyFuture.

Obiekty Future mogą być przydatne do określania, kiedy obiekt staje się dostępny, w tym urządzenia i bufory.

Zaawansowane koncepcje

Wykraczanie poza implementację podstawowych interfejsów API rozszerzy funkcje JAX, które mogą być używane przez wtyczkę. Są to funkcje opcjonalne, ponieważ typowy przepływ pracy JIT i wykonania będzie działać bez nich, ale w przypadku potoku o jakości produkcyjnej warto zastanowić się nad stopniem obsługi dowolnej z tych funkcji obsługiwanych przez interfejsy API PJRT:

  • Przestrzenie pamięci
  • Układy niestandardowe
  • Operacje komunikacyjne, takie jak wysyłanie i odbieranie
  • Przenoszenie zadań na hosta
  • Fragmentacja

Typowa komunikacja między urządzeniem a platformą PJRT

Przykładowy dziennik

Poniżej znajduje się dziennik metod wywoływanych w celu wczytania wtyczki PJRT i wykonania funkcji y = jax.jit(lambda x: jnp.power(x, jnp.int32(2)))(1). W tym przypadku rejestrujemy interakcję JAX z wtyczką PJRT StableHLO Reference.

Przykładowy log

//////////////////////////////////
// Load the plugin
//////////////////////////////////

I client_cpp_pjrt.cc:55] StablehloReferencePjrtClient(0x23bac400)
I device.cc:53] StablehloReferenceDeviceDescription(0x23bac4f8)
I device.cc:104] StablehloReferenceDevice(0x23bac4e0)
I device.cc:123] client(0x23bac4e0)
I device.cc:123] client(0x23bac4e0)
I client_cpp_pjrt.cc:71] process_index(0x23bac400)
I client_cpp_pjrt.cc:67] platform_name(0x23bac400)
I device.cc:143] AttachDefaultMemorySpace(0x23bac4e0)
I client_cpp_pjrt.cc:67] platform_name(0x23bac400)
I client_cpp_pjrt.cc:86] devices(0x23bac400)
I client_cpp_pjrt.cc:81] addressable_device_count(0x23bac400)
I device.cc:168] description(0x23bac4e0)
I device.cc:168] description(0x23bac4e0)
I device.cc:86] Attributes(0x23bac4f8)
I device.cc:128] IsAddressable(0x23bac4e0)
I device.cc:168] description(0x23bac4e0)
I device.cc:61] process_index(0x23bac4f8)
I device.cc:123] client(0x23bac4e0)
I client_cpp_pjrt.cc:71] process_index(0x23bac400)
I client_cpp_pjrt.cc:81] addressable_device_count(0x23bac400)
I client_cpp_pjrt.cc:95] memory_spaces(0x23bac400)
I device.cc:128] IsAddressable(0x23bac4e0)
I device.cc:168] description(0x23bac4e0)
I device.cc:61] process_index(0x23bac4f8)
I device.cc:123] client(0x23bac4e0)
I client_cpp_pjrt.cc:71] process_index(0x23bac400)
I device.cc:148] memory_spaces(0x23bac4e0)
Creating PJRT Client from client
I client_cpp_pjrt.cc:108] platform_version(0x23bac400)
I client_cpp_pjrt.cc:67] platform_name(0x23bac400)
I device.cc:57] id(0x23bac4f8)
I device.cc:70] device_kind(0x23bac4f8)
I device.cc:70] device_kind(0x23bac4f8)
I device.cc:80] ToString(0x23bac4f8)
I device.cc:80] ToString(0x23bac4f8)
I device.cc:75] DebugString(0x23bac4f8)
I device.cc:75] DebugString(0x23bac4f8)
I device.cc:61] process_index(0x23bac4f8)
I device.cc:128] IsAddressable(0x23bac4e0)
I device.cc:168] description(0x23bac4e0)
I device.cc:61] process_index(0x23bac4f8)
I device.cc:123] client(0x23bac4e0)
I client_cpp_pjrt.cc:71] process_index(0x23bac400)
I device.cc:153] default_memory_space(0x23bac4e0)
I client_cpp_pjrt.cc:71] process_index(0x23bac400)

//////////////////////////////////
// RUN: `y = jax.jit(lambda x: jnp.power(x, jnp.int32(2)))(1)`
//////////////////////////////////

I executable.cc:309] num_partitions(0x240bab70)
I executable.cc:305] num_replicas(0x240bab70)
I executable.cc:309] num_partitions(0x240bab70)
I client_cpp_pjrt.cc:233] BufferFromHostBuffer(0x23bac400)
I buffer.cc:285] CreateMlirBufferFromLiteral
I buffer.cc:98] CreateFromLiteral
I buffer.cc:99] CreateFromLiteral: s32[] 2
I buffer.cc:64] MlirPjrtBuffer(0x240bb050)
I buffer.cc:102] CreateFromLiteral -> 0x240bb050
I buffer.cc:158] device(0x240bb050)
I buffer.cc:154] memory_space(0x240bb050)
I buffer.cc:154] memory_space(0x240bb050)
I executable.cc:328] GetHloModules(0x240bab70)
I executable.cc:240] Execute(0x240bab70)
I executable.cc:197] ExecuteWithReferenceInterpreter(0x240bab70)
I buffer.cc:303] GetAttributeFromBuffer
I buffer.cc:229] IsDeleted(0x240bb050)
I buffer.cc:311] GetAttributeFromBuffer(0x240bb050) -> dense<2> : tensor<i32>
I executable.cc:205] EvalModule:
module @jit attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
  func.func public @main(%arg0: tensor<i32> {mhlo.layout_mode = "default"}) -> (tensor<i32> {jax.result_info = "", mhlo.layout_mode = "default"}) {
    // ...
    return %3 : tensor<i32>
  }
}
I executable.cc:206] Inputs: [dense<2> : tensor<i32>]
I executable.cc:213] Results: [dense<2> : tensor<i32>]
I device.cc:153] default_memory_space(0x23bac4e0)
I buffer.cc:291] CreateMlirBufferFromAttribute
I buffer.cc:116] CreateFromAttribute
I buffer.cc:64] MlirPjrtBuffer(0x22cea630)
I buffer.cc:122] CreateFromAttribute(dense<2> : tensor<i32>) -> 0x22cea630

//////////////////////////////////
// RUN: `print(y)`
//////////////////////////////////

I buffer.cc:263] GetReadyFuture(0x22cea630)
I buffer.cc:264] GetReadyFuture(0x22cea630)
I buffer.cc:154] memory_space(0x22cea630)
I buffer.cc:154] memory_space(0x22cea630)
I buffer.cc:158] device(0x22cea630)
I buffer.cc:158] device(0x22cea630)
I buffer.cc:154] memory_space(0x22cea630)
I buffer.cc:154] memory_space(0x22cea630)
I buffer.cc:229] IsDeleted(0x22cea630)
I buffer.cc:129] on_device_shape(0x22cea630)
I buffer.cc:129] on_device_shape(0x22cea630)
I buffer.cc:129] on_device_shape(0x22cea630)
I buffer.cc:158] device(0x22cea630)
I buffer.cc:154] memory_space(0x22cea630)
I buffer.cc:154] memory_space(0x22cea630)
I client_cpp_pjrt.cc:71] process_index(0x23bac400)
I buffer.cc:229] IsDeleted(0x22cea630)
I buffer.cc:129] on_device_shape(0x22cea630)
I buffer.cc:129] on_device_shape(0x22cea630)
I buffer.cc:269] IsOnCpu(0x22cea630) # Returns true, allows external references.
I buffer.cc:129] on_device_shape(0x22cea630)
I buffer.cc:129] on_device_shape(0x22cea630)
I buffer.cc:129] on_device_shape(0x22cea630)
I buffer.cc:129] on_device_shape(0x22cea630)
I buffer.cc:129] on_device_shape(0x22cea630)
I buffer.cc:168] AcquireExternalReference(0x22cea630)
I buffer.cc:73] MlirClonedExternalReference(0x2404d560)
I buffer.cc:303] GetAttributeFromBuffer
I buffer.cc:229] IsDeleted(0x22cea630)
I buffer.cc:311] GetAttributeFromBuffer(0x22cea630) -> dense<2> : tensor<i32>
I buffer.cc:291] CreateMlirBufferFromAttribute
I buffer.cc:116] CreateFromAttribute
I buffer.cc:64] MlirPjrtBuffer(0x240bb050)
I buffer.cc:122] CreateFromAttribute(dense<2> : tensor<i32>) -> 0x240bb050
I buffer.cc:168] AcquireExternalReference(0x22cea630)
I buffer.cc:73] MlirClonedExternalReference(0x240b6010)
I buffer.cc:303] GetAttributeFromBuffer
I buffer.cc:229] IsDeleted(0x22cea630)
I buffer.cc:311] GetAttributeFromBuffer(0x22cea630) -> dense<2> : tensor<i32>
I buffer.cc:291] CreateMlirBufferFromAttribute
I buffer.cc:116] CreateFromAttribute
I buffer.cc:64] MlirPjrtBuffer(0x23b2db60)
I buffer.cc:122] CreateFromAttribute(dense<2> : tensor<i32>) -> 0x23b2db60
I buffer.cc:263] GetReadyFuture(0x22cea630)
I buffer.cc:264] GetReadyFuture(0x22cea630)