Фон
PJRT — это унифицированный API устройств, который мы хотим добавить в экосистему машинного обучения. Долгосрочная цель заключается в следующем:
- Фреймворки (JAX, TF и т. д.) будут вызывать PJRT, который имеет реализации, специфичные для устройств, которые непрозрачны для фреймворков;
- Каждое устройство ориентировано на реализацию API PJRT и может быть непрозрачным для фреймворков.
PJRT предлагает API как для C, так и для C++. Подключение к любому из этих уровней допустимо. API C++ использует классы для абстрагирования некоторых концепций, но при этом имеет более тесную связь с типами данных XLA. Эта страница посвящена API C++.
Компоненты PJRT
PjRtClient
Полная ссылка: pjrt_client.h > PjRtClient .
Клиенты управляют всем взаимодействием между устройством и фреймворком и инкапсулируют все состояния, используемые в этом взаимодействии. Они располагают универсальным набором API для взаимодействия с плагином PJRT и владеют устройствами и пространствами памяти для данного плагина.
PjRtDevice
Полные ссылки на pjrt_client.h > PjRtDevice и pjrt_device_description.h
Класс устройства используется для описания отдельного устройства. Устройство имеет описание, помогающее определить его тип (уникальный хеш для идентификации графического процессора/процессора/xPU) и местоположение в сети устройств как локально, так и глобально.
Устройства также знают связанные с ними области памяти и клиента, которому они принадлежат.
Устройство не обязательно знает буферы реальных данных, связанных с ним, но оно может выяснить это, просматривая связанные с ним области памяти.
PjRtMemorySpace
Полная ссылка: pjrt_client.h > PjRtMemorySpace .
Пространства памяти можно использовать для описания местоположения памяти. Они могут быть либо незакреплёнными и свободно располагаться где угодно, но при этом быть доступными с устройства, либо закреплёнными и должны располагаться на определённом устройстве.
Пространства памяти знают связанные с ними буферы данных и устройства (во множественном числе), с которыми связано пространство памяти, а также клиента, частью которого оно является.
PjRtBuffer
Полная ссылка на pjrt_client.h > PjRtBuffer .
Буфер хранит данные на устройстве в формате, удобном для работы внутри плагина, например, в формате элементов MLIR attr или в собственном тензорном формате. Фреймворк может попытаться отправить данные на устройство в виде xla::Literal , то есть в качестве входного аргумента для модуля, который должен быть клонирован (или заимствован) в память устройства. Когда буфер больше не нужен, фреймворк вызывает метод Delete для очистки.
Буфер знает пространство памяти, частью которого он является, и транзитивно может выяснить, какие устройства могут получить к нему доступ, но буферы не обязательно знают свои устройства.
Для взаимодействия с фреймворками буферы умеют преобразовываться в тип xla::Literal и обратно:
// Literal to Buffer
absl::StatusOr<std::unique_ptr<PjRtBuffer>> BufferFromHostBuffer(...) {...}
// Buffer to Literal
xla::Future<> ToLiteral(xla::MutableLiteralBase* literal) override {...}
API для создания буфера имеют семантику буфера , которая помогает определить, можно ли совместно использовать, копировать или изменять литеральные данные из буфера хоста.
Наконец, буферу может потребоваться существовать дольше, чем область его выполнения, если он назначен переменной в слое фреймворка x = jit(foo)(10) ; в этих случаях буферы позволяют создавать внешние ссылки, которые предоставляют временно принадлежащий указатель на данные, хранящиеся в буфере, вместе с метаданными (размеры dtype и dim) для интерпретации базовых данных.
PjRtCompiler
Полная ссылка на pjrt_compiler.h > PjRtCompiler .
Класс PjRtCompiler предоставляет полезную информацию о реализации бэкендов XLA, но не является обязательным для реализации плагином. Теоретически, задача PjRtCompiler или метода PjRtClient::Compile — принять входной модуль и вернуть PjRtLoadedExecutable .
PjRtExecutable / PjRtLoadedExecutable
Полная ссылка на pjrt_executable.h > PjRtExecutable и pjrt_client.h > PjRtLoadedExecutable .
PjRtExecutable знает, как брать скомпилированный артефакт и параметры выполнения, а также сериализовать/десериализовать их, чтобы исполняемый файл можно было сохранить и загрузить по мере необходимости.
PjRtLoadedExecutable — это скомпилированный в памяти исполняемый файл, готовый к выполнению входных аргументов. Он является подклассом PjRtExecutable .
Взаимодействие с исполняемыми файлами осуществляется через один из методов Execute клиента:
// Execute on addressable devices
absl::StatusOr<std::vector<std::vector<std::unique_ptr<PjRtBuffer>>>>
Execute(absl::Span<const std::vector<PjRtBuffer*>> argument_handles, ...) {...}
// Execute assigned replica/partition on the specified device
absl::StatusOr<std::vector<std::unique_ptr<PjRtBuffer>>>
ExecuteSharded(absl::Span<PjRtBuffer* const> argument_handles,
PjRtDevice* device, ...) {...}
// Execute on specified device, single replica / partition
absl::StatusOr<std::vector<std::unique_ptr<PjRtBuffer>>>
ExecutePortable(absl::Span<PjRtBuffer* const> argument_handles,
PjRtDevice* device, ...) {...}
Перед вызовом Execute фреймворк передаст все необходимые данные в PjRtBuffers принадлежащие исполняющему клиенту, но возвращаемые фреймворку для использования. Эти буферы затем передаются в качестве аргументов методу Execute .
Концепции PJRT
Фьючерсы и асинхронные вычисления
Если какая-либо часть плагина реализована асинхронно, она должна правильно реализовывать будущие события.
Рассмотрим следующую программу:
@jax.jit
def foo(x): return x + 1
x = foo(1)
# [...] other logic not using `x`
print(x + 1)
Асинхронный плагин сможет поставить в очередь вычисление x и немедленно вернуть буфер, который ещё не готов к чтению, но в процессе выполнения он будет заполнен. Выполнение может продолжиться для постановки в очередь необходимых вычислений после x , не требующих x , включая выполнение на других устройствах PJRT. Как только значение x станет необходимым, выполнение будет заблокировано до тех пор, пока буфер не объявит себя готовым через future-объект, возвращаемый GetReadyFuture .
Фьючерсы могут быть полезны для определения того, когда объект станет доступен, включая устройства и буферы.
Расширенные концепции
Выход за рамки реализации базовых API расширит возможности JAX, которые может использовать плагин. Все эти функции являются опциональными, то есть в типичном рабочем процессе JIT и выполнения они будут работать и без них. Однако для конвейера высокого качества, вероятно, следует продумать степень поддержки любой из этих функций, поддерживаемых API PJRT:
- Пространства памяти
- Индивидуальные макеты
- Коммуникационные операции, такие как отправка/прием
- Разгрузка хоста
- Шардинг
Типичная связь между фреймворком PJRT и устройством
Пример журнала
Ниже представлен журнал методов, вызываемых для загрузки плагина PJRT и выполнения y = jax.jit(lambda x: jnp.power(x, jnp.int32(2)))(1) . В данном случае мы регистрируем взаимодействие JAX с плагином StableHLO Reference PJRT.
Пример журнала
//////////////////////////////////
// Load the plugin
//////////////////////////////////
I client_cpp_pjrt.cc:55] StablehloReferencePjrtClient(0x23bac400)
I device.cc:53] StablehloReferenceDeviceDescription(0x23bac4f8)
I device.cc:104] StablehloReferenceDevice(0x23bac4e0)
I device.cc:123] client(0x23bac4e0)
I device.cc:123] client(0x23bac4e0)
I client_cpp_pjrt.cc:71] process_index(0x23bac400)
I client_cpp_pjrt.cc:67] platform_name(0x23bac400)
I device.cc:143] AttachDefaultMemorySpace(0x23bac4e0)
I client_cpp_pjrt.cc:67] platform_name(0x23bac400)
I client_cpp_pjrt.cc:86] devices(0x23bac400)
I client_cpp_pjrt.cc:81] addressable_device_count(0x23bac400)
I device.cc:168] description(0x23bac4e0)
I device.cc:168] description(0x23bac4e0)
I device.cc:86] Attributes(0x23bac4f8)
I device.cc:128] IsAddressable(0x23bac4e0)
I device.cc:168] description(0x23bac4e0)
I device.cc:61] process_index(0x23bac4f8)
I device.cc:123] client(0x23bac4e0)
I client_cpp_pjrt.cc:71] process_index(0x23bac400)
I client_cpp_pjrt.cc:81] addressable_device_count(0x23bac400)
I client_cpp_pjrt.cc:95] memory_spaces(0x23bac400)
I device.cc:128] IsAddressable(0x23bac4e0)
I device.cc:168] description(0x23bac4e0)
I device.cc:61] process_index(0x23bac4f8)
I device.cc:123] client(0x23bac4e0)
I client_cpp_pjrt.cc:71] process_index(0x23bac400)
I device.cc:148] memory_spaces(0x23bac4e0)
Creating PJRT Client from client
I client_cpp_pjrt.cc:108] platform_version(0x23bac400)
I client_cpp_pjrt.cc:67] platform_name(0x23bac400)
I device.cc:57] id(0x23bac4f8)
I device.cc:70] device_kind(0x23bac4f8)
I device.cc:70] device_kind(0x23bac4f8)
I device.cc:80] ToString(0x23bac4f8)
I device.cc:80] ToString(0x23bac4f8)
I device.cc:75] DebugString(0x23bac4f8)
I device.cc:75] DebugString(0x23bac4f8)
I device.cc:61] process_index(0x23bac4f8)
I device.cc:128] IsAddressable(0x23bac4e0)
I device.cc:168] description(0x23bac4e0)
I device.cc:61] process_index(0x23bac4f8)
I device.cc:123] client(0x23bac4e0)
I client_cpp_pjrt.cc:71] process_index(0x23bac400)
I device.cc:153] default_memory_space(0x23bac4e0)
I client_cpp_pjrt.cc:71] process_index(0x23bac400)
//////////////////////////////////
// RUN: `y = jax.jit(lambda x: jnp.power(x, jnp.int32(2)))(1)`
//////////////////////////////////
I executable.cc:309] num_partitions(0x240bab70)
I executable.cc:305] num_replicas(0x240bab70)
I executable.cc:309] num_partitions(0x240bab70)
I client_cpp_pjrt.cc:233] BufferFromHostBuffer(0x23bac400)
I buffer.cc:285] CreateMlirBufferFromLiteral
I buffer.cc:98] CreateFromLiteral
I buffer.cc:99] CreateFromLiteral: s32[] 2
I buffer.cc:64] MlirPjrtBuffer(0x240bb050)
I buffer.cc:102] CreateFromLiteral -> 0x240bb050
I buffer.cc:158] device(0x240bb050)
I buffer.cc:154] memory_space(0x240bb050)
I buffer.cc:154] memory_space(0x240bb050)
I executable.cc:328] GetHloModules(0x240bab70)
I executable.cc:240] Execute(0x240bab70)
I executable.cc:197] ExecuteWithReferenceInterpreter(0x240bab70)
I buffer.cc:303] GetAttributeFromBuffer
I buffer.cc:229] IsDeleted(0x240bb050)
I buffer.cc:311] GetAttributeFromBuffer(0x240bb050) -> dense<2> : tensor<i32>
I executable.cc:205] EvalModule:
module @jit attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
func.func public @main(%arg0: tensor<i32> {mhlo.layout_mode = "default"}) -> (tensor<i32> {jax.result_info = "", mhlo.layout_mode = "default"}) {
// ...
return %3 : tensor<i32>
}
}
I executable.cc:206] Inputs: [dense<2> : tensor<i32>]
I executable.cc:213] Results: [dense<2> : tensor<i32>]
I device.cc:153] default_memory_space(0x23bac4e0)
I buffer.cc:291] CreateMlirBufferFromAttribute
I buffer.cc:116] CreateFromAttribute
I buffer.cc:64] MlirPjrtBuffer(0x22cea630)
I buffer.cc:122] CreateFromAttribute(dense<2> : tensor<i32>) -> 0x22cea630
//////////////////////////////////
// RUN: `print(y)`
//////////////////////////////////
I buffer.cc:263] GetReadyFuture(0x22cea630)
I buffer.cc:264] GetReadyFuture(0x22cea630)
I buffer.cc:154] memory_space(0x22cea630)
I buffer.cc:154] memory_space(0x22cea630)
I buffer.cc:158] device(0x22cea630)
I buffer.cc:158] device(0x22cea630)
I buffer.cc:154] memory_space(0x22cea630)
I buffer.cc:154] memory_space(0x22cea630)
I buffer.cc:229] IsDeleted(0x22cea630)
I buffer.cc:129] on_device_shape(0x22cea630)
I buffer.cc:129] on_device_shape(0x22cea630)
I buffer.cc:129] on_device_shape(0x22cea630)
I buffer.cc:158] device(0x22cea630)
I buffer.cc:154] memory_space(0x22cea630)
I buffer.cc:154] memory_space(0x22cea630)
I client_cpp_pjrt.cc:71] process_index(0x23bac400)
I buffer.cc:229] IsDeleted(0x22cea630)
I buffer.cc:129] on_device_shape(0x22cea630)
I buffer.cc:129] on_device_shape(0x22cea630)
I buffer.cc:269] IsOnCpu(0x22cea630) # Returns true, allows external references.
I buffer.cc:129] on_device_shape(0x22cea630)
I buffer.cc:129] on_device_shape(0x22cea630)
I buffer.cc:129] on_device_shape(0x22cea630)
I buffer.cc:129] on_device_shape(0x22cea630)
I buffer.cc:129] on_device_shape(0x22cea630)
I buffer.cc:168] AcquireExternalReference(0x22cea630)
I buffer.cc:73] MlirClonedExternalReference(0x2404d560)
I buffer.cc:303] GetAttributeFromBuffer
I buffer.cc:229] IsDeleted(0x22cea630)
I buffer.cc:311] GetAttributeFromBuffer(0x22cea630) -> dense<2> : tensor<i32>
I buffer.cc:291] CreateMlirBufferFromAttribute
I buffer.cc:116] CreateFromAttribute
I buffer.cc:64] MlirPjrtBuffer(0x240bb050)
I buffer.cc:122] CreateFromAttribute(dense<2> : tensor<i32>) -> 0x240bb050
I buffer.cc:168] AcquireExternalReference(0x22cea630)
I buffer.cc:73] MlirClonedExternalReference(0x240b6010)
I buffer.cc:303] GetAttributeFromBuffer
I buffer.cc:229] IsDeleted(0x22cea630)
I buffer.cc:311] GetAttributeFromBuffer(0x22cea630) -> dense<2> : tensor<i32>
I buffer.cc:291] CreateMlirBufferFromAttribute
I buffer.cc:116] CreateFromAttribute
I buffer.cc:64] MlirPjrtBuffer(0x23b2db60)
I buffer.cc:122] CreateFromAttribute(dense<2> : tensor<i32>) -> 0x23b2db60
I buffer.cc:263] GetReadyFuture(0x22cea630)
I buffer.cc:264] GetReadyFuture(0x22cea630)