Tło
PJRT to jednolity interfejs Device API, który chcemy dodać do ekosystemu ML. Nasze długoterminowe cele to:
- Frameworki (JAX, TF itp.) będą wywoływać PJRT, który ma implementacje zależne od urządzenia, które są nieprzezroczyste dla frameworków;
- Każde urządzenie skupia się na implementacji interfejsów API PJRT i może być nieprzejrzyste dla frameworków.
PJRT udostępnia interfejsy API C i C++. Podłączenie na dowolną warstwę jest normalne. Interfejs API C++ używa klas do wyodrębnienia niektórych koncepcji, ale ma też lepsze powiązania z typami danych XLA. Ta strona dotyczy interfejsu C++ API.
Komponenty PJRT
PjRtClient
Pełne informacje znajdziesz w artykule pjrt_client.h > PjRtClient
.
Klienty zarządzają całą komunikacją między urządzeniem a platformą i uwzględniają wszystkie stany używane w komunikacji. Ma on ogólny zestaw interfejsów API do interakcji z wtyczką PJRT oraz kontroluje urządzenia i miejsce na dane dla danej wtyczki.
PjRtDevice
Pełne informacje w dokumentach pjrt_client.h > PjRtDevice
i pjrt_device_description.h
Klasa urządzenia służy do opisania pojedynczego urządzenia. Urządzenie ma opis, który pomaga zidentyfikować jego rodzaj (unikalny skrót do identyfikacji GPU/CPU/xPU) oraz lokalizację w siatce urządzeń zarówno lokalnie, jak i globalnie.
Urządzenia znają też powiązane z nimi miejsca w pamięci oraz klienta, do którego należą.
Urządzenie nie musi znać buforów rzeczywistych danych powiązanych z nim, ale może je znaleźć, przeglądając powiązane obszary pamięci.
PjRtMemorySpace
Pełne informacje znajdziesz tutaj: pjrt_client.h > PjRtMemorySpace
.
Pamięć może być opisywana za pomocą przestrzeni pamięci. Możesz je odpiąć, aby były dostępne na dowolnym urządzeniu, lub przypiąć, aby były dostępne tylko na określonym urządzeniu.
Pamięci mają informacje o powiązanych z nimi buforach danych oraz urządzeniach (w liczbie mnogiej), z którymi są powiązane, a także o kliencie, którego są częścią.
PjRtBuffer
Pełne informacje znajdziesz na pjrt_client.h > PjRtBuffer
.
Bufor przechowuje dane na urządzeniu w takim formacie, z którym łatwo będzie pracować w pluginie, np. w formacie atrybutów elementów MLIR lub zastrzeżonym formacie tensora.
Framework może próbować wysłać dane do urządzenia w postaci xla::Literal
, czyli argumentu wejściowego modułu, który musi zostać sklonowany (lub wypożyczony) do pamięci urządzenia. Gdy bufor nie jest już potrzebny, framework wywołuje metodę Delete
, aby go wyczyścić.
Bufor wie, ile miejsca w pamięci zajmuje, i przechodnie może określić, które urządzenia mają do niego dostęp, ale bufory nie zawsze wiedzą o ich urządzeniach.
W przypadku komunikacji z ramami bufory wiedzą, jak konwertować dane do typu xla::Literal
i z niego:
// Literal to Buffer
absl::StatusOr<std::unique_ptr<PjRtBuffer>> BufferFromHostBuffer(...) {...}
// Buffer to Literal
xla::PjRtFuture<> ToLiteral(xla::MutableLiteralBase* literal) override {...}
Interfejsy API służące do tworzenia bufora mają semantykę bufora, która określa, czy można udostępniać, kopiować lub zmieniać dane dosłowne z bufora hosta.
Na koniec należy pamiętać, że bufor może być potrzebny dłużej niż w ramach jego wykonania, jeśli jest przypisany do zmiennej w warstwie ramowej x = jit(foo)(10)
. W takich przypadkach bufory umożliwiają tworzenie odwołań zewnętrznych, które zapewniają tymczasowo posiadany wskaźnik do danych przechowywanych przez bufor wraz z metadanymi (typ danych / rozmiary wymiarów) do interpretacji danych źródłowych.
PjRtCompiler
Pełne informacje znajdziesz na pjrt_compiler.h > PjRtCompiler
.
Klasa PjRtCompiler
zawiera przydatne informacje o implementacji XLA na zapleczu, ale nie jest niezbędna do implementacji wtyczki. W teorii obowiązkiem metody PjRtCompiler
, czyli metody PjRtClient::Compile
, jest pobranie modułu wejściowego i zwrócenie PjRtLoadedExecutable
.
PjRtExecutable / PjRtLoadedExecutable
Pełne informacje znajdziesz w pjrt_executable.h > PjRtExecutable
i pjrt_client.h > PjRtLoadedExecutable
.
PjRtExecutable
wie, jak pobrać skompilowany artefakt i opcje wykonania oraz zserializować/zserializować je, aby móc przechowywać i ładować plik wykonywalny zgodnie z potrzebami.
PjRtLoadedExecutable
to plik wykonywalny skompilowany w pamięci, który jest gotowy do wykonania z danymi wejściowymi. Jest to podklasa klasy PjRtExecutable
.
Pliki wykonywalne są łączone z jedną z metod Execute
klienta:
// Execute on addressable devices
absl::StatusOr<std::vector<std::vector<std::unique_ptr<PjRtBuffer>>>>
Execute(absl::Span<const std::vector<PjRtBuffer*>> argument_handles, ...) {...}
// Execute assigned replica/partition on the specified device
absl::StatusOr<std::vector<std::unique_ptr<PjRtBuffer>>>
ExecuteSharded(absl::Span<PjRtBuffer* const> argument_handles,
PjRtDevice* device, ...) {...}
// Execute on specified device, single replica / partition
absl::StatusOr<std::vector<std::unique_ptr<PjRtBuffer>>>
ExecutePortable(absl::Span<PjRtBuffer* const> argument_handles,
PjRtDevice* device, ...) {...}
Przed wywołaniem Execute
framework przekaże wszystkie wymagane dane do PjRtBuffers
należącego do klienta wykonującego kod, ale zwróci je do frameworku do celów informacyjnych. Bufory te są następnie udostępniane jako argumenty metody Execute
.
Pojęcia związane z PJRT
Przyszłe transakcje terminowe na PjRt i obliczenia asynchroniczne
Jeśli jakakolwiek część wtyczki jest implementowana asynchronicznie, musi odpowiednio implementować przyszłość.
Rozważ następujący program:
@jax.jit
def foo(x): return x + 1
x = foo(1)
# [...] other logic not using `x`
print(x + 1)
Wtyczka asynchroniczna może ustawić obliczenia x
w kole i natychmiast zwrócić bufor, który nie jest jeszcze gotowy do odczytu, ale zostanie wypełniony podczas wykonywania. Po zakończeniu x
można kontynuować kolejkowanie niezbędnych obliczeń, które nie wymagają x
, w tym na innych urządzeniach PJRT. Gdy zajdzie potrzeba wartości x
, wykonanie zostanie zablokowane, dopóki bufor nie zadeklaruje, że jest gotowy, co nastąpi za pomocą przyszłej wartości zwracanej przez GetReadyFuture
.
Przyszłość pomaga określić, kiedy obiekt staje się dostępny, w tym urządzenia i bufory.
Zaawansowane koncepcje
Rozszerzenie implementacji interfejsów podstawowych API zwiększy liczbę funkcji JAX, z których może korzystać wtyczka. Wszystkie te funkcje są opcjonalne w tym sensie, że typowy przepływ pracy z wykorzystaniem kompilacji Just-In-Time i wykonania będzie działał bez nich, ale w przypadku łańcucha przetwarzania o jakości produkcyjnej należy wziąć pod uwagę stopień obsługi tych funkcji obsługiwanych przez interfejsy PJRT:
- Miejsca w pamięci
- Układy niestandardowe
- Operacje komunikacyjne, np. wysyłanie/odbieranie
- Przenoszenie obciążenia na hosta
- Podział na fragmenty
Typowa komunikacja między PJRT a urządzeniem
Przykładowy zapis
Poniżej znajduje się dziennik metod wywoływanych w celu załadowania wtyczki PJRT i wykonania y = jax.jit(lambda x: jnp.power(x, jnp.int32(2)))(1)
. W tym przypadku rejestrujemy JAX podczas interakcji z wtyczką StableHLO Reference PJRT.
Przykładowy dziennik
////////////////////////////////// // Load the plugin ////////////////////////////////// I client_cpp_pjrt.cc:55] StablehloReferencePjrtClient(0x23bac400) I device.cc:53] StablehloReferenceDeviceDescription(0x23bac4f8) I device.cc:104] StablehloReferenceDevice(0x23bac4e0) I device.cc:123] client(0x23bac4e0) I device.cc:123] client(0x23bac4e0) I client_cpp_pjrt.cc:71] process_index(0x23bac400) I client_cpp_pjrt.cc:67] platform_name(0x23bac400) I device.cc:143] AttachDefaultMemorySpace(0x23bac4e0) I client_cpp_pjrt.cc:67] platform_name(0x23bac400) I client_cpp_pjrt.cc:86] devices(0x23bac400) I client_cpp_pjrt.cc:81] addressable_device_count(0x23bac400) I device.cc:168] description(0x23bac4e0) I device.cc:168] description(0x23bac4e0) I device.cc:86] Attributes(0x23bac4f8) I device.cc:128] IsAddressable(0x23bac4e0) I device.cc:168] description(0x23bac4e0) I device.cc:61] process_index(0x23bac4f8) I device.cc:123] client(0x23bac4e0) I client_cpp_pjrt.cc:71] process_index(0x23bac400) I client_cpp_pjrt.cc:81] addressable_device_count(0x23bac400) I client_cpp_pjrt.cc:95] memory_spaces(0x23bac400) I device.cc:128] IsAddressable(0x23bac4e0) I device.cc:168] description(0x23bac4e0) I device.cc:61] process_index(0x23bac4f8) I device.cc:123] client(0x23bac4e0) I client_cpp_pjrt.cc:71] process_index(0x23bac400) I device.cc:148] memory_spaces(0x23bac4e0) Creating PJRT Client from client I client_cpp_pjrt.cc:108] platform_version(0x23bac400) I client_cpp_pjrt.cc:67] platform_name(0x23bac400) I device.cc:57] id(0x23bac4f8) I device.cc:70] device_kind(0x23bac4f8) I device.cc:70] device_kind(0x23bac4f8) I device.cc:80] ToString(0x23bac4f8) I device.cc:80] ToString(0x23bac4f8) I device.cc:75] DebugString(0x23bac4f8) I device.cc:75] DebugString(0x23bac4f8) I device.cc:61] process_index(0x23bac4f8) I device.cc:128] IsAddressable(0x23bac4e0) I device.cc:168] description(0x23bac4e0) I device.cc:61] process_index(0x23bac4f8) I device.cc:123] client(0x23bac4e0) I client_cpp_pjrt.cc:71] process_index(0x23bac400) I device.cc:153] default_memory_space(0x23bac4e0) I client_cpp_pjrt.cc:71] process_index(0x23bac400) ////////////////////////////////// // RUN: `y = jax.jit(lambda x: jnp.power(x, jnp.int32(2)))(1)` ////////////////////////////////// I executable.cc:309] num_partitions(0x240bab70) I executable.cc:305] num_replicas(0x240bab70) I executable.cc:309] num_partitions(0x240bab70) I client_cpp_pjrt.cc:233] BufferFromHostBuffer(0x23bac400) I buffer.cc:285] CreateMlirBufferFromLiteral I buffer.cc:98] CreateFromLiteral I buffer.cc:99] CreateFromLiteral: s32[] 2 I buffer.cc:64] MlirPjrtBuffer(0x240bb050) I buffer.cc:102] CreateFromLiteral -> 0x240bb050 I buffer.cc:158] device(0x240bb050) I buffer.cc:154] memory_space(0x240bb050) I buffer.cc:154] memory_space(0x240bb050) I executable.cc:328] GetHloModules(0x240bab70) I executable.cc:240] Execute(0x240bab70) I executable.cc:197] ExecuteWithReferenceInterpreter(0x240bab70) I buffer.cc:303] GetAttributeFromBuffer I buffer.cc:229] IsDeleted(0x240bb050) I buffer.cc:311] GetAttributeFromBuffer(0x240bb050) -> dense<2> : tensor<i32> I executable.cc:205] EvalModule: module @jit attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} { func.func public @main(%arg0: tensor<i32> {mhlo.layout_mode = "default"}) -> (tensor<i32> {jax.result_info = "", mhlo.layout_mode = "default"}) { // ... return %3 : tensor<i32> } } I executable.cc:206] Inputs: [dense<2> : tensor<i32>] I executable.cc:213] Results: [dense<2> : tensor<i32>] I device.cc:153] default_memory_space(0x23bac4e0) I buffer.cc:291] CreateMlirBufferFromAttribute I buffer.cc:116] CreateFromAttribute I buffer.cc:64] MlirPjrtBuffer(0x22cea630) I buffer.cc:122] CreateFromAttribute(dense<2> : tensor<i32>) -> 0x22cea630 ////////////////////////////////// // RUN: `print(y)` ////////////////////////////////// I buffer.cc:263] GetReadyFuture(0x22cea630) I buffer.cc:264] GetReadyFuture(0x22cea630) I buffer.cc:154] memory_space(0x22cea630) I buffer.cc:154] memory_space(0x22cea630) I buffer.cc:158] device(0x22cea630) I buffer.cc:158] device(0x22cea630) I buffer.cc:154] memory_space(0x22cea630) I buffer.cc:154] memory_space(0x22cea630) I buffer.cc:229] IsDeleted(0x22cea630) I buffer.cc:129] on_device_shape(0x22cea630) I buffer.cc:129] on_device_shape(0x22cea630) I buffer.cc:129] on_device_shape(0x22cea630) I buffer.cc:158] device(0x22cea630) I buffer.cc:154] memory_space(0x22cea630) I buffer.cc:154] memory_space(0x22cea630) I client_cpp_pjrt.cc:71] process_index(0x23bac400) I buffer.cc:229] IsDeleted(0x22cea630) I buffer.cc:129] on_device_shape(0x22cea630) I buffer.cc:129] on_device_shape(0x22cea630) I buffer.cc:269] IsOnCpu(0x22cea630) # Returns true, allows external references. I buffer.cc:129] on_device_shape(0x22cea630) I buffer.cc:129] on_device_shape(0x22cea630) I buffer.cc:129] on_device_shape(0x22cea630) I buffer.cc:129] on_device_shape(0x22cea630) I buffer.cc:129] on_device_shape(0x22cea630) I buffer.cc:168] AcquireExternalReference(0x22cea630) I buffer.cc:73] MlirClonedExternalReference(0x2404d560) I buffer.cc:303] GetAttributeFromBuffer I buffer.cc:229] IsDeleted(0x22cea630) I buffer.cc:311] GetAttributeFromBuffer(0x22cea630) -> dense<2> : tensor<i32> I buffer.cc:291] CreateMlirBufferFromAttribute I buffer.cc:116] CreateFromAttribute I buffer.cc:64] MlirPjrtBuffer(0x240bb050) I buffer.cc:122] CreateFromAttribute(dense<2> : tensor<i32>) -> 0x240bb050 I buffer.cc:168] AcquireExternalReference(0x22cea630) I buffer.cc:73] MlirClonedExternalReference(0x240b6010) I buffer.cc:303] GetAttributeFromBuffer I buffer.cc:229] IsDeleted(0x22cea630) I buffer.cc:311] GetAttributeFromBuffer(0x22cea630) -> dense<2> : tensor<i32> I buffer.cc:291] CreateMlirBufferFromAttribute I buffer.cc:116] CreateFromAttribute I buffer.cc:64] MlirPjrtBuffer(0x23b2db60) I buffer.cc:122] CreateFromAttribute(dense<2> : tensor<i32>) -> 0x23b2db60 I buffer.cc:263] GetReadyFuture(0x22cea630) I buffer.cc:264] GetReadyFuture(0x22cea630)