Omówienie interfejsu PJRT C++ Device API

Tło

PJRT to jednolity interfejs Device API, który chcemy dodać do ekosystemu ML. Nasze długoterminowe cele to:

  1. Frameworki (JAX, TF itp.) będą wywoływać PJRT, który ma implementacje zależne od urządzenia, które są nieprzezroczyste dla frameworków;
  2. Każde urządzenie skupia się na implementacji interfejsów API PJRT i może być nieprzejrzyste dla frameworków.

PJRT udostępnia interfejsy API C i C++. Podłączenie się na dowolnym poziomie jest OK. Interfejs API w C++ używa klas do abstrakcji niektórych pojęć, ale ma też silniejsze powiązania z typami danych XLA. Ta strona dotyczy interfejsu C++ API.

Komponenty PJRT

Komponenty PJRT

PjRtClient

Pełne informacje znajdziesz tutaj: pjrt_client.h > PjRtClient.

Klienty zarządzają całą komunikacją między urządzeniem a ramami, a także otaczają stan używany w komunikacji. Ma on ogólny zestaw interfejsów API do interakcji z wtyczką PJRT oraz kontroluje urządzenia i miejsce na dane dla danej wtyczki.

PjRtDevice

Pełne informacje w dokumentach pjrt_client.h > PjRtDevicepjrt_device_description.h

Klasa urządzenia służy do opisania pojedynczego urządzenia. Urządzenie ma opis, który pomaga zidentyfikować jego rodzaj (unikalny skrót do identyfikacji GPU/CPU/xPU) oraz lokalizację w siatce urządzeń zarówno lokalnie, jak i globalnie.

Urządzenia znają też powiązane z nimi przestrzenie pamięci i klienta, które je obsługują.

Urządzenie nie musi znać buforów rzeczywistych danych powiązanych z nim, ale może je znaleźć, przeglądając powiązane obszary pamięci.

PjRtMemorySpace

Pełne informacje znajdziesz tutaj: pjrt_client.h > PjRtMemorySpace.

Pamięć może być opisywana za pomocą przestrzeni pamięci. Możesz je odpiąć, aby były dostępne na dowolnym urządzeniu, lub przypiąć, aby były dostępne tylko na określonym urządzeniu.

Pamięci mają informacje o powiązanych z nimi buforach danych oraz urządzeniach (w liczbie mnogiej), z którymi są powiązane, a także o kliencie, którego są częścią.

PjRtBuffer

Pełne informacje znajdziesz tutaj: pjrt_client.h > PjRtBuffer.

Bufor przechowuje dane na urządzeniu w takim formacie, z którym łatwo będzie pracować w pluginie, np. w formacie atrybutów elementów MLIR lub zastrzeżonym formacie tensora. Framework może próbować wysłać dane do urządzenia w postaci xla::Literal, czyli argumentu wejściowego modułu, który musi zostać sklonowany (lub wypożyczony) do pamięci urządzenia. Gdy bufor nie jest już potrzebny, framework wywołuje metodę Delete, aby go wyczyścić.

Bufor zna przestrzeń pamięci, której jest częścią, i może pośrednio określić, które urządzenia mogą uzyskać do niej dostęp, ale niekoniecznie zna te urządzenia.

W przypadku komunikacji z ramami bufory wiedzą, jak konwertować dane do typu xla::Literal i z niego:

// Literal to Buffer
absl::StatusOr<std::unique_ptr<PjRtBuffer>> BufferFromHostBuffer(...) {...}

// Buffer to Literal
xla::PjRtFuture<> ToLiteral(xla::MutableLiteralBase* literal) override {...}

Interfejsy API służące do tworzenia bufora mają semantykę bufora, która określa, czy można udostępniać, kopiować lub zmieniać dane dosłowne z bufora hosta.

Na koniec należy pamiętać, że bufor może być potrzebny dłużej niż w ramach jego wykonania, jeśli jest przypisany do zmiennej w warstwie ramowej x = jit(foo)(10). W takich przypadkach bufory umożliwiają tworzenie odwołań zewnętrznych, które zapewniają tymczasowo posiadany wskaźnik do danych przechowywanych przez bufor wraz z metadanymi (typ danych / rozmiary wymiarów) do interpretacji danych źródłowych.

PjRtCompiler

Pełne informacje znajdziesz tutaj: pjrt_compiler.h > PjRtCompiler.

Klasa PjRtCompiler zawiera przydatne informacje o implementacji XLA na zapleczu, ale nie jest niezbędna do implementacji wtyczki. Teoretycznie PjRtCompiler, czyli metoda PjRtClient::Compile, ma za zadanie przetwarzanie modułu wejściowego i zwracanie wartości PjRtLoadedExecutable.

PjRtExecutable / PjRtLoadedExecutable

Pełne informacje znajdziesz w pjrt_executable.h > PjRtExecutablepjrt_client.h > PjRtLoadedExecutable.

PjRtExecutable wie, jak pobrać skompilowany artefakt i opcje wykonania oraz zakodować je lub zdekodować, aby można było przechowywać i wczytywać plik wykonywalny w razie potrzeby.

PjRtLoadedExecutable to plik wykonywalny skompilowany w pamięci, który jest gotowy do wykonania z danymi wejściowymi. Jest to podklasa klasy PjRtExecutable.

Pliki wykonywalne są łączone z jedną z metod Execute klienta:

// Execute on addressable devices
absl::StatusOr<std::vector<std::vector<std::unique_ptr<PjRtBuffer>>>>
Execute(absl::Span<const std::vector<PjRtBuffer*>> argument_handles, ...) {...}

// Execute assigned replica/partition on the specified device
absl::StatusOr<std::vector<std::unique_ptr<PjRtBuffer>>>
ExecuteSharded(absl::Span<PjRtBuffer* const> argument_handles,
              PjRtDevice* device, ...) {...}

// Execute on specified device, single replica / partition
absl::StatusOr<std::vector<std::unique_ptr<PjRtBuffer>>>
ExecutePortable(absl::Span<PjRtBuffer* const> argument_handles,
                PjRtDevice* device, ...) {...}

Przed wywołaniem Execute framework przekaże wszystkie wymagane dane do PjRtBuffers należącego do klienta wykonującego kod, ale zwróci je do frameworku do celów informacyjnych. Te bufory są następnie przekazywane jako argumenty metody Execute.

PJRT Concepts

PjRtFutures i obliczenia asynchroniczne

Jeśli jakakolwiek część wtyczki jest implementowana asynchronicznie, musi odpowiednio implementować przyszłość.

Rozważmy ten program:

@jax.jit
def foo(x): return x + 1

x = foo(1)
# [...] other logic not using `x`
print(x + 1)

Wtyczka asynchroniczna może ustawić obliczenia x w kole i natychmiast zwrócić bufor, który nie jest jeszcze gotowy do odczytu, ale zostanie wypełniony podczas wykonywania. Po zakończeniu x można kontynuować kolejkowanie niezbędnych obliczeń, które nie wymagają x, w tym na innych urządzeniach PJRT. Gdy zajdzie potrzeba wartości x, wykonanie zostanie zablokowane, dopóki bufor nie zadeklaruje, że jest gotowy, co nastąpi za pomocą przyszłej wartości zwracanej przez GetReadyFuture.

Futures mogą być przydatne do określania, kiedy obiekt stanie się dostępny, w tym urządzenia i bufory.

Zaawansowane koncepcje

Rozszerzenie implementacji interfejsów podstawowych API zwiększy liczbę funkcji JAX, z których może korzystać wtyczka. Wszystkie te funkcje są opcjonalne w tym sensie, że typowy przepływ pracy z wykorzystaniem kompilacji Just-In-Time i wykonania będzie działał bez nich, ale w przypadku łańcucha przetwarzania o jakości produkcyjnej należy wziąć pod uwagę stopień obsługi tych funkcji obsługiwanych przez interfejsy PJRT:

  • Pamięć na urządzeniach
  • Układy niestandardowe
  • operacje związane z komunikacją, np. wysyłanie/odbieranie.
  • Przenoszenie obciążenia na hosta
  • Podział na fragmenty

Typowa komunikacja między PJRT a urządzeniem

Przykładowy dziennik

Poniżej znajduje się dziennik metod wywoływanych w celu załadowania wtyczki PJRT i wykonania y = jax.jit(lambda x: jnp.power(x, jnp.int32(2)))(1). W tym przypadku JAX rejestruje interakcje z wtyczką PJRT dla StableHLO.

Przykładowy dziennik

//////////////////////////////////
// Load the plugin
//////////////////////////////////

I client_cpp_pjrt.cc:55] StablehloReferencePjrtClient(0x23bac400)
I device.cc:53] StablehloReferenceDeviceDescription(0x23bac4f8)
I device.cc:104] StablehloReferenceDevice(0x23bac4e0)
I device.cc:123] client(0x23bac4e0)
I device.cc:123] client(0x23bac4e0)
I client_cpp_pjrt.cc:71] process_index(0x23bac400)
I client_cpp_pjrt.cc:67] platform_name(0x23bac400)
I device.cc:143] AttachDefaultMemorySpace(0x23bac4e0)
I client_cpp_pjrt.cc:67] platform_name(0x23bac400)
I client_cpp_pjrt.cc:86] devices(0x23bac400)
I client_cpp_pjrt.cc:81] addressable_device_count(0x23bac400)
I device.cc:168] description(0x23bac4e0)
I device.cc:168] description(0x23bac4e0)
I device.cc:86] Attributes(0x23bac4f8)
I device.cc:128] IsAddressable(0x23bac4e0)
I device.cc:168] description(0x23bac4e0)
I device.cc:61] process_index(0x23bac4f8)
I device.cc:123] client(0x23bac4e0)
I client_cpp_pjrt.cc:71] process_index(0x23bac400)
I client_cpp_pjrt.cc:81] addressable_device_count(0x23bac400)
I client_cpp_pjrt.cc:95] memory_spaces(0x23bac400)
I device.cc:128] IsAddressable(0x23bac4e0)
I device.cc:168] description(0x23bac4e0)
I device.cc:61] process_index(0x23bac4f8)
I device.cc:123] client(0x23bac4e0)
I client_cpp_pjrt.cc:71] process_index(0x23bac400)
I device.cc:148] memory_spaces(0x23bac4e0)
Creating PJRT Client from client
I client_cpp_pjrt.cc:108] platform_version(0x23bac400)
I client_cpp_pjrt.cc:67] platform_name(0x23bac400)
I device.cc:57] id(0x23bac4f8)
I device.cc:70] device_kind(0x23bac4f8)
I device.cc:70] device_kind(0x23bac4f8)
I device.cc:80] ToString(0x23bac4f8)
I device.cc:80] ToString(0x23bac4f8)
I device.cc:75] DebugString(0x23bac4f8)
I device.cc:75] DebugString(0x23bac4f8)
I device.cc:61] process_index(0x23bac4f8)
I device.cc:128] IsAddressable(0x23bac4e0)
I device.cc:168] description(0x23bac4e0)
I device.cc:61] process_index(0x23bac4f8)
I device.cc:123] client(0x23bac4e0)
I client_cpp_pjrt.cc:71] process_index(0x23bac400)
I device.cc:153] default_memory_space(0x23bac4e0)
I client_cpp_pjrt.cc:71] process_index(0x23bac400)

//////////////////////////////////
// RUN: `y = jax.jit(lambda x: jnp.power(x, jnp.int32(2)))(1)`
//////////////////////////////////

I executable.cc:309] num_partitions(0x240bab70)
I executable.cc:305] num_replicas(0x240bab70)
I executable.cc:309] num_partitions(0x240bab70)
I client_cpp_pjrt.cc:233] BufferFromHostBuffer(0x23bac400)
I buffer.cc:285] CreateMlirBufferFromLiteral
I buffer.cc:98] CreateFromLiteral
I buffer.cc:99] CreateFromLiteral: s32[] 2
I buffer.cc:64] MlirPjrtBuffer(0x240bb050)
I buffer.cc:102] CreateFromLiteral -> 0x240bb050
I buffer.cc:158] device(0x240bb050)
I buffer.cc:154] memory_space(0x240bb050)
I buffer.cc:154] memory_space(0x240bb050)
I executable.cc:328] GetHloModules(0x240bab70)
I executable.cc:240] Execute(0x240bab70)
I executable.cc:197] ExecuteWithReferenceInterpreter(0x240bab70)
I buffer.cc:303] GetAttributeFromBuffer
I buffer.cc:229] IsDeleted(0x240bb050)
I buffer.cc:311] GetAttributeFromBuffer(0x240bb050) -> dense<2> : tensor<i32>
I executable.cc:205] EvalModule:
module @jit attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
  func.func public @main(%arg0: tensor<i32> {mhlo.layout_mode = "default"}) -> (tensor<i32> {jax.result_info = "", mhlo.layout_mode = "default"}) {
    // ...
    return %3 : tensor<i32>
  }
}
I executable.cc:206] Inputs: [dense<2> : tensor<i32>]
I executable.cc:213] Results: [dense<2> : tensor<i32>]
I device.cc:153] default_memory_space(0x23bac4e0)
I buffer.cc:291] CreateMlirBufferFromAttribute
I buffer.cc:116] CreateFromAttribute
I buffer.cc:64] MlirPjrtBuffer(0x22cea630)
I buffer.cc:122] CreateFromAttribute(dense<2> : tensor<i32>) -> 0x22cea630

//////////////////////////////////
// RUN: `print(y)`
//////////////////////////////////

I buffer.cc:263] GetReadyFuture(0x22cea630)
I buffer.cc:264] GetReadyFuture(0x22cea630)
I buffer.cc:154] memory_space(0x22cea630)
I buffer.cc:154] memory_space(0x22cea630)
I buffer.cc:158] device(0x22cea630)
I buffer.cc:158] device(0x22cea630)
I buffer.cc:154] memory_space(0x22cea630)
I buffer.cc:154] memory_space(0x22cea630)
I buffer.cc:229] IsDeleted(0x22cea630)
I buffer.cc:129] on_device_shape(0x22cea630)
I buffer.cc:129] on_device_shape(0x22cea630)
I buffer.cc:129] on_device_shape(0x22cea630)
I buffer.cc:158] device(0x22cea630)
I buffer.cc:154] memory_space(0x22cea630)
I buffer.cc:154] memory_space(0x22cea630)
I client_cpp_pjrt.cc:71] process_index(0x23bac400)
I buffer.cc:229] IsDeleted(0x22cea630)
I buffer.cc:129] on_device_shape(0x22cea630)
I buffer.cc:129] on_device_shape(0x22cea630)
I buffer.cc:269] IsOnCpu(0x22cea630) # Returns true, allows external references.
I buffer.cc:129] on_device_shape(0x22cea630)
I buffer.cc:129] on_device_shape(0x22cea630)
I buffer.cc:129] on_device_shape(0x22cea630)
I buffer.cc:129] on_device_shape(0x22cea630)
I buffer.cc:129] on_device_shape(0x22cea630)
I buffer.cc:168] AcquireExternalReference(0x22cea630)
I buffer.cc:73] MlirClonedExternalReference(0x2404d560)
I buffer.cc:303] GetAttributeFromBuffer
I buffer.cc:229] IsDeleted(0x22cea630)
I buffer.cc:311] GetAttributeFromBuffer(0x22cea630) -> dense<2> : tensor<i32>
I buffer.cc:291] CreateMlirBufferFromAttribute
I buffer.cc:116] CreateFromAttribute
I buffer.cc:64] MlirPjrtBuffer(0x240bb050)
I buffer.cc:122] CreateFromAttribute(dense<2> : tensor<i32>) -> 0x240bb050
I buffer.cc:168] AcquireExternalReference(0x22cea630)
I buffer.cc:73] MlirClonedExternalReference(0x240b6010)
I buffer.cc:303] GetAttributeFromBuffer
I buffer.cc:229] IsDeleted(0x22cea630)
I buffer.cc:311] GetAttributeFromBuffer(0x22cea630) -> dense<2> : tensor<i32>
I buffer.cc:291] CreateMlirBufferFromAttribute
I buffer.cc:116] CreateFromAttribute
I buffer.cc:64] MlirPjrtBuffer(0x23b2db60)
I buffer.cc:122] CreateFromAttribute(dense<2> : tensor<i32>) -> 0x23b2db60
I buffer.cc:263] GetReadyFuture(0x22cea630)
I buffer.cc:264] GetReadyFuture(0x22cea630)