Обзор API устройств PJRT C++

Фон

PJRT — это унифицированный API-интерфейс устройства, который мы хотим добавить в экосистему машинного обучения. Долгосрочное видение заключается в следующем:

  1. Фреймворки (JAX, TF и ​​т. д.) будут вызывать PJRT, который имеет реализации для конкретных устройств, непрозрачные для фреймворков;
  2. Каждое устройство ориентировано на реализацию API-интерфейсов PJRT и может быть непрозрачным для платформ.

PJRT предлагает API C и C++ API. Подключение на любом уровне допустимо, API C++ использует классы для абстрагирования некоторых концепций, но также имеет более сильную связь с типами данных XLA. На этой странице основное внимание уделяется API C++.

Компоненты ПЖРТ

Компоненты ПЖРТ

ПджРТКлиент

Полная ссылка на pjrt_client.h > PjRtClient .

Клиенты управляют всей связью между устройством и платформой и инкапсулируют все состояние, используемое в связи. У них есть общий набор API для взаимодействия с плагином PJRT, и они владеют устройствами и пространствами памяти для данного плагина.

PjRtDevice

Полные ссылки в pjrt_client.h > PjRtDevice и pjrt_device_description.h

Класс устройства используется для описания одного устройства. Устройство имеет описание устройства, помогающее определить его тип (уникальный хэш для идентификации графического процессора/ЦП/xPU) и местоположение в сетке устройств как локально, так и глобально.

Устройства также знают связанные с ними области памяти и клиента, которым они принадлежат.

Устройство не обязательно знает буферы фактических данных, связанных с ним, но оно может выяснить это, просматривая связанные с ним области памяти.

PjRtMemoryПространство

Полная ссылка на pjrt_client.h > PjRtMemorySpace .

Пространства памяти можно использовать для описания местоположения памяти. Они могут быть либо откреплены и могут свободно размещаться где угодно, но быть доступны с устройства, либо они могут быть закреплены и должны находиться на конкретном устройстве.

Пространства памяти знают связанные с ними буферы данных и устройства (множественное число), с которыми связано пространство памяти, а также клиент, частью которого оно является.

PjRTBuffer

Полная ссылка на pjrt_client.h > PjRtBuffer .

Буфер хранит данные на устройстве в каком-либо формате, с которым будет легко работать внутри плагина, например, attr элементов MLIR или собственный тензорный формат. Платформа может попытаться отправить данные устройству в форме xla::Literal , т. е. в качестве входного аргумента модуля, который необходимо клонировать (или заимствовать) в память устройства. Как только буфер больше не нужен, платформа вызывает метод Delete для очистки.

Буфер знает пространство памяти, частью которого он является, и может транзитивно определить, какие устройства могут получить к нему доступ, но буферы не обязательно знают свои устройства.

Для взаимодействия с платформами буферы умеют конвертировать в тип xla::Literal и обратно:

// Literal to Buffer
absl::StatusOr<std::unique_ptr<PjRtBuffer>> BufferFromHostBuffer(...) {...}

// Buffer to Literal
xla::PjRtFuture<> ToLiteral(xla::MutableLiteralBase* literal) override {...}

API-интерфейсы для создания буфера имеют семантику буфера , которая помогает определять, можно ли использовать литеральные данные из буфера хоста совместно, копировать или изменять.

Наконец, буфер может длиться дольше, чем область его выполнения, если он назначен переменной на уровне платформы x = jit(foo)(10) , в этих случаях буферы позволяют создавать внешние ссылки, которые предоставляют временно принадлежащий указатель. к данным, хранящимся в буфере, вместе с метаданными (размеры dtype/dim) для интерпретации базовых данных.

PjRTCompiler

Полная ссылка на pjrt_compiler.h > PjRtCompiler .

Класс PjRtCompiler предоставляет полезные сведения о реализации для серверных частей XLA, но не является обязательным для реализации плагина. Теоретически ответственность PjRtCompiler или метода PjRtClient::Compile заключается в том, чтобы взять входной модуль и вернуть PjRtLoadedExecutable .

PjRtExecutable / PjRtLoadedExecutable

Полная ссылка на pjrt_executable.h > PjRtExecutable и pjrt_client.h > PjRtLoadedExecutable .

PjRtExecutable знает, как взять скомпилированный артефакт и параметры выполнения и сериализовать/десериализовать их, чтобы исполняемый файл можно было сохранить и загрузить по мере необходимости.

PjRtLoadedExecutable — это исполняемый файл, скомпилированный в памяти, который готов к выполнению входных аргументов. Это подкласс PjRtExecutable .

Исполняемые файлы взаимодействуют через один из клиентских методов Execute :

// Execute on addressable devices
absl::StatusOr<std::vector<std::vector<std::unique_ptr<PjRtBuffer>>>>
Execute(absl::Span<const std::vector<PjRtBuffer*>> argument_handles, ...) {...}

// Execute assigned replica/partition on the specified device
absl::StatusOr<std::vector<std::unique_ptr<PjRtBuffer>>>
ExecuteSharded(absl::Span<PjRtBuffer* const> argument_handles,
              PjRtDevice* device, ...) {...}

// Execute on specified device, single replica / partition
absl::StatusOr<std::vector<std::unique_ptr<PjRtBuffer>>>
ExecutePortable(absl::Span<PjRtBuffer* const> argument_handles,
                PjRtDevice* device, ...) {...}

Перед вызовом Execute платформа передаст все необходимые данные в PjRtBuffers принадлежащие исполняющему клиенту, но возвращаемые платформе для ссылки. Эти буферы затем предоставляются в качестве аргументов методу Execute .

Концепции PJRT

PjRtFutures и асинхронные вычисления

Если какая-либо часть плагина реализована асинхронно, она должна правильно реализовывать фьючерсы.

Рассмотрим следующую программу:

@jax.jit
def foo(x): return x + 1

x = foo(1)
# [...] other logic not using `x`
print(x + 1)

Плагин async сможет поставить в очередь вычисление x и немедленно вернуть буфер, который еще не готов к чтению, но выполнение заполнит его. Выполнение может продолжать ставить в очередь необходимые вычисления после x , которые не требуют x , включая выполнение на других устройствах PJRT. Как только потребуется значение x , выполнение будет заблокировано до тех пор, пока буфер не объявит себя готовым через будущее, возвращаемое GetReadyFuture .

Фьючерсы могут быть полезны для определения того, когда объект станет доступным, включая устройства и буферы.

Расширенные концепции

Выход за рамки реализации базовых API расширит возможности JAX, которые могут использоваться плагином. Все эти функции являются дополнительными в том смысле, что в типичном рабочем процессе JIT и выполнения они будут работать без них, но для конвейера производственного качества, вероятно, следует подумать о степени поддержки любой из этих функций, поддерживаемых API-интерфейсами PJRT:

  • Пространства памяти
  • Пользовательские макеты
  • Операции связи, такие как отправка/получение
  • Разгрузка хоста
  • Шардинг

Типичная связь между платформой PJRT и устройством

Пример журнала

Ниже приведен журнал методов, вызываемых для загрузки плагина PJRT и выполнения y = jax.jit(lambda x: jnp.power(x, jnp.int32(2)))(1) . В этом случае мы регистрируем взаимодействие JAX с плагином StableHLO Reference PJRT.

Пример журнала

//////////////////////////////////
// Load the plugin
//////////////////////////////////

I client_cpp_pjrt.cc:55] StablehloReferencePjrtClient(0x23bac400)
I device.cc:53] StablehloReferenceDeviceDescription(0x23bac4f8)
I device.cc:104] StablehloReferenceDevice(0x23bac4e0)
I device.cc:123] client(0x23bac4e0)
I device.cc:123] client(0x23bac4e0)
I client_cpp_pjrt.cc:71] process_index(0x23bac400)
I client_cpp_pjrt.cc:67] platform_name(0x23bac400)
I device.cc:143] AttachDefaultMemorySpace(0x23bac4e0)
I client_cpp_pjrt.cc:67] platform_name(0x23bac400)
I client_cpp_pjrt.cc:86] devices(0x23bac400)
I client_cpp_pjrt.cc:81] addressable_device_count(0x23bac400)
I device.cc:168] description(0x23bac4e0)
I device.cc:168] description(0x23bac4e0)
I device.cc:86] Attributes(0x23bac4f8)
I device.cc:128] IsAddressable(0x23bac4e0)
I device.cc:168] description(0x23bac4e0)
I device.cc:61] process_index(0x23bac4f8)
I device.cc:123] client(0x23bac4e0)
I client_cpp_pjrt.cc:71] process_index(0x23bac400)
I client_cpp_pjrt.cc:81] addressable_device_count(0x23bac400)
I client_cpp_pjrt.cc:95] memory_spaces(0x23bac400)
I device.cc:128] IsAddressable(0x23bac4e0)
I device.cc:168] description(0x23bac4e0)
I device.cc:61] process_index(0x23bac4f8)
I device.cc:123] client(0x23bac4e0)
I client_cpp_pjrt.cc:71] process_index(0x23bac400)
I device.cc:148] memory_spaces(0x23bac4e0)
Creating PJRT Client from client
I client_cpp_pjrt.cc:108] platform_version(0x23bac400)
I client_cpp_pjrt.cc:67] platform_name(0x23bac400)
I device.cc:57] id(0x23bac4f8)
I device.cc:70] device_kind(0x23bac4f8)
I device.cc:70] device_kind(0x23bac4f8)
I device.cc:80] ToString(0x23bac4f8)
I device.cc:80] ToString(0x23bac4f8)
I device.cc:75] DebugString(0x23bac4f8)
I device.cc:75] DebugString(0x23bac4f8)
I device.cc:61] process_index(0x23bac4f8)
I device.cc:128] IsAddressable(0x23bac4e0)
I device.cc:168] description(0x23bac4e0)
I device.cc:61] process_index(0x23bac4f8)
I device.cc:123] client(0x23bac4e0)
I client_cpp_pjrt.cc:71] process_index(0x23bac400)
I device.cc:153] default_memory_space(0x23bac4e0)
I client_cpp_pjrt.cc:71] process_index(0x23bac400)

//////////////////////////////////
// RUN: `y = jax.jit(lambda x: jnp.power(x, jnp.int32(2)))(1)`
//////////////////////////////////

I executable.cc:309] num_partitions(0x240bab70)
I executable.cc:305] num_replicas(0x240bab70)
I executable.cc:309] num_partitions(0x240bab70)
I client_cpp_pjrt.cc:233] BufferFromHostBuffer(0x23bac400)
I buffer.cc:285] CreateMlirBufferFromLiteral
I buffer.cc:98] CreateFromLiteral
I buffer.cc:99] CreateFromLiteral: s32[] 2
I buffer.cc:64] MlirPjrtBuffer(0x240bb050)
I buffer.cc:102] CreateFromLiteral -> 0x240bb050
I buffer.cc:158] device(0x240bb050)
I buffer.cc:154] memory_space(0x240bb050)
I buffer.cc:154] memory_space(0x240bb050)
I executable.cc:328] GetHloModules(0x240bab70)
I executable.cc:240] Execute(0x240bab70)
I executable.cc:197] ExecuteWithReferenceInterpreter(0x240bab70)
I buffer.cc:303] GetAttributeFromBuffer
I buffer.cc:229] IsDeleted(0x240bb050)
I buffer.cc:311] GetAttributeFromBuffer(0x240bb050) -> dense<2> : tensor<i32>
I executable.cc:205] EvalModule:
module @jit attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
  func.func public @main(%arg0: tensor<i32> {mhlo.layout_mode = "default"}) -> (tensor<i32> {jax.result_info = "", mhlo.layout_mode = "default"}) {
    // ...
    return %3 : tensor<i32>
  }
}
I executable.cc:206] Inputs: [dense<2> : tensor<i32>]
I executable.cc:213] Results: [dense<2> : tensor<i32>]
I device.cc:153] default_memory_space(0x23bac4e0)
I buffer.cc:291] CreateMlirBufferFromAttribute
I buffer.cc:116] CreateFromAttribute
I buffer.cc:64] MlirPjrtBuffer(0x22cea630)
I buffer.cc:122] CreateFromAttribute(dense<2> : tensor<i32>) -> 0x22cea630

//////////////////////////////////
// RUN: `print(y)`
//////////////////////////////////

I buffer.cc:263] GetReadyFuture(0x22cea630)
I buffer.cc:264] GetReadyFuture(0x22cea630)
I buffer.cc:154] memory_space(0x22cea630)
I buffer.cc:154] memory_space(0x22cea630)
I buffer.cc:158] device(0x22cea630)
I buffer.cc:158] device(0x22cea630)
I buffer.cc:154] memory_space(0x22cea630)
I buffer.cc:154] memory_space(0x22cea630)
I buffer.cc:229] IsDeleted(0x22cea630)
I buffer.cc:129] on_device_shape(0x22cea630)
I buffer.cc:129] on_device_shape(0x22cea630)
I buffer.cc:129] on_device_shape(0x22cea630)
I buffer.cc:158] device(0x22cea630)
I buffer.cc:154] memory_space(0x22cea630)
I buffer.cc:154] memory_space(0x22cea630)
I client_cpp_pjrt.cc:71] process_index(0x23bac400)
I buffer.cc:229] IsDeleted(0x22cea630)
I buffer.cc:129] on_device_shape(0x22cea630)
I buffer.cc:129] on_device_shape(0x22cea630)
I buffer.cc:269] IsOnCpu(0x22cea630) # Returns true, allows external references.
I buffer.cc:129] on_device_shape(0x22cea630)
I buffer.cc:129] on_device_shape(0x22cea630)
I buffer.cc:129] on_device_shape(0x22cea630)
I buffer.cc:129] on_device_shape(0x22cea630)
I buffer.cc:129] on_device_shape(0x22cea630)
I buffer.cc:168] AcquireExternalReference(0x22cea630)
I buffer.cc:73] MlirClonedExternalReference(0x2404d560)
I buffer.cc:303] GetAttributeFromBuffer
I buffer.cc:229] IsDeleted(0x22cea630)
I buffer.cc:311] GetAttributeFromBuffer(0x22cea630) -> dense<2> : tensor<i32>
I buffer.cc:291] CreateMlirBufferFromAttribute
I buffer.cc:116] CreateFromAttribute
I buffer.cc:64] MlirPjrtBuffer(0x240bb050)
I buffer.cc:122] CreateFromAttribute(dense<2> : tensor<i32>) -> 0x240bb050
I buffer.cc:168] AcquireExternalReference(0x22cea630)
I buffer.cc:73] MlirClonedExternalReference(0x240b6010)
I buffer.cc:303] GetAttributeFromBuffer
I buffer.cc:229] IsDeleted(0x22cea630)
I buffer.cc:311] GetAttributeFromBuffer(0x22cea630) -> dense<2> : tensor<i32>
I buffer.cc:291] CreateMlirBufferFromAttribute
I buffer.cc:116] CreateFromAttribute
I buffer.cc:64] MlirPjrtBuffer(0x23b2db60)
I buffer.cc:122] CreateFromAttribute(dense<2> : tensor<i32>) -> 0x23b2db60
I buffer.cc:263] GetReadyFuture(0x22cea630)
I buffer.cc:264] GetReadyFuture(0x22cea630)