Обзор API устройств PJRT C++

Фон

PJRT — это унифицированный API устройств, который мы хотим добавить в экосистему машинного обучения. Долгосрочная цель заключается в следующем:

  1. Фреймворки (JAX, TF и ​​т. д.) будут вызывать PJRT, который имеет реализации, специфичные для устройств, которые непрозрачны для фреймворков;
  2. Каждое устройство ориентировано на реализацию API PJRT и может быть непрозрачным для фреймворков.

PJRT предлагает API как для C, так и для C++. Подключение к любому из этих уровней допустимо. API C++ использует классы для абстрагирования некоторых концепций, но при этом имеет более тесную связь с типами данных XLA. Эта страница посвящена API C++.

Компоненты PJRT

Компоненты PJRT

PjRtClient

Полная ссылка: pjrt_client.h > PjRtClient .

Клиенты управляют всем взаимодействием между устройством и фреймворком и инкапсулируют все состояния, используемые в этом взаимодействии. Они располагают универсальным набором API для взаимодействия с плагином PJRT и владеют устройствами и пространствами памяти для данного плагина.

PjRtDevice

Полные ссылки на pjrt_client.h > PjRtDevice и pjrt_device_description.h

Класс устройства используется для описания отдельного устройства. Устройство имеет описание, помогающее определить его тип (уникальный хеш для идентификации графического процессора/процессора/xPU) и местоположение в сети устройств как локально, так и глобально.

Устройства также знают связанные с ними области памяти и клиента, которому они принадлежат.

Устройство не обязательно знает буферы реальных данных, связанных с ним, но оно может выяснить это, просматривая связанные с ним области памяти.

PjRtMemorySpace

Полная ссылка: pjrt_client.h > PjRtMemorySpace .

Пространства памяти можно использовать для описания местоположения памяти. Они могут быть либо незакреплёнными и свободно располагаться где угодно, но при этом быть доступными с устройства, либо закреплёнными и должны располагаться на определённом устройстве.

Пространства памяти знают связанные с ними буферы данных и устройства (во множественном числе), с которыми связано пространство памяти, а также клиента, частью которого оно является.

PjRtBuffer

Полная ссылка на pjrt_client.h > PjRtBuffer .

Буфер хранит данные на устройстве в формате, удобном для работы внутри плагина, например, в формате элементов MLIR attr или в собственном тензорном формате. Фреймворк может попытаться отправить данные на устройство в виде xla::Literal , то есть в качестве входного аргумента для модуля, который должен быть клонирован (или заимствован) в память устройства. Когда буфер больше не нужен, фреймворк вызывает метод Delete для очистки.

Буфер знает пространство памяти, частью которого он является, и транзитивно может выяснить, какие устройства могут получить к нему доступ, но буферы не обязательно знают свои устройства.

Для взаимодействия с фреймворками буферы умеют преобразовываться в тип xla::Literal и обратно:

// Literal to Buffer
absl::StatusOr<std::unique_ptr<PjRtBuffer>> BufferFromHostBuffer(...) {...}

// Buffer to Literal
xla::Future<> ToLiteral(xla::MutableLiteralBase* literal) override {...}

API для создания буфера имеют семантику буфера , которая помогает определить, можно ли совместно использовать, копировать или изменять литеральные данные из буфера хоста.

Наконец, буферу может потребоваться существовать дольше, чем область его выполнения, если он назначен переменной в слое фреймворка x = jit(foo)(10) ; в этих случаях буферы позволяют создавать внешние ссылки, которые предоставляют временно принадлежащий указатель на данные, хранящиеся в буфере, вместе с метаданными (размеры dtype и dim) для интерпретации базовых данных.

PjRtCompiler

Полная ссылка на pjrt_compiler.h > PjRtCompiler .

Класс PjRtCompiler предоставляет полезную информацию о реализации бэкендов XLA, но не является обязательным для реализации плагином. Теоретически, задача PjRtCompiler или метода PjRtClient::Compile — принять входной модуль и вернуть PjRtLoadedExecutable .

PjRtExecutable / PjRtLoadedExecutable

Полная ссылка на pjrt_executable.h > PjRtExecutable и pjrt_client.h > PjRtLoadedExecutable .

PjRtExecutable знает, как брать скомпилированный артефакт и параметры выполнения, а также сериализовать/десериализовать их, чтобы исполняемый файл можно было сохранить и загрузить по мере необходимости.

PjRtLoadedExecutable — это скомпилированный в памяти исполняемый файл, готовый к выполнению входных аргументов. Он является подклассом PjRtExecutable .

Взаимодействие с исполняемыми файлами осуществляется через один из методов Execute клиента:

// Execute on addressable devices
absl::StatusOr<std::vector<std::vector<std::unique_ptr<PjRtBuffer>>>>
Execute(absl::Span<const std::vector<PjRtBuffer*>> argument_handles, ...) {...}

// Execute assigned replica/partition on the specified device
absl::StatusOr<std::vector<std::unique_ptr<PjRtBuffer>>>
ExecuteSharded(absl::Span<PjRtBuffer* const> argument_handles,
              PjRtDevice* device, ...) {...}

// Execute on specified device, single replica / partition
absl::StatusOr<std::vector<std::unique_ptr<PjRtBuffer>>>
ExecutePortable(absl::Span<PjRtBuffer* const> argument_handles,
                PjRtDevice* device, ...) {...}

Перед вызовом Execute фреймворк передаст все необходимые данные в PjRtBuffers принадлежащие исполняющему клиенту, но возвращаемые фреймворку для использования. Эти буферы затем передаются в качестве аргументов методу Execute .

Концепции PJRT

Фьючерсы и асинхронные вычисления

Если какая-либо часть плагина реализована асинхронно, она должна правильно реализовывать будущие события.

Рассмотрим следующую программу:

@jax.jit
def foo(x): return x + 1

x = foo(1)
# [...] other logic not using `x`
print(x + 1)

Асинхронный плагин сможет поставить в очередь вычисление x и немедленно вернуть буфер, который ещё не готов к чтению, но в процессе выполнения он будет заполнен. Выполнение может продолжиться для постановки в очередь необходимых вычислений после x , не требующих x , включая выполнение на других устройствах PJRT. Как только значение x станет необходимым, выполнение будет заблокировано до тех пор, пока буфер не объявит себя готовым через future-объект, возвращаемый GetReadyFuture .

Фьючерсы могут быть полезны для определения того, когда объект станет доступен, включая устройства и буферы.

Расширенные концепции

Выход за рамки реализации базовых API расширит возможности JAX, которые может использовать плагин. Все эти функции являются опциональными, то есть в типичном рабочем процессе JIT и выполнения они будут работать и без них. Однако для конвейера высокого качества, вероятно, следует продумать степень поддержки любой из этих функций, поддерживаемых API PJRT:

  • Пространства памяти
  • Индивидуальные макеты
  • Коммуникационные операции, такие как отправка/прием
  • Разгрузка хоста
  • Шардинг

Типичная связь между фреймворком PJRT и устройством

Пример журнала

Ниже представлен журнал методов, вызываемых для загрузки плагина PJRT и выполнения y = jax.jit(lambda x: jnp.power(x, jnp.int32(2)))(1) . В данном случае мы регистрируем взаимодействие JAX с плагином StableHLO Reference PJRT.

Пример журнала

//////////////////////////////////
// Load the plugin
//////////////////////////////////

I client_cpp_pjrt.cc:55] StablehloReferencePjrtClient(0x23bac400)
I device.cc:53] StablehloReferenceDeviceDescription(0x23bac4f8)
I device.cc:104] StablehloReferenceDevice(0x23bac4e0)
I device.cc:123] client(0x23bac4e0)
I device.cc:123] client(0x23bac4e0)
I client_cpp_pjrt.cc:71] process_index(0x23bac400)
I client_cpp_pjrt.cc:67] platform_name(0x23bac400)
I device.cc:143] AttachDefaultMemorySpace(0x23bac4e0)
I client_cpp_pjrt.cc:67] platform_name(0x23bac400)
I client_cpp_pjrt.cc:86] devices(0x23bac400)
I client_cpp_pjrt.cc:81] addressable_device_count(0x23bac400)
I device.cc:168] description(0x23bac4e0)
I device.cc:168] description(0x23bac4e0)
I device.cc:86] Attributes(0x23bac4f8)
I device.cc:128] IsAddressable(0x23bac4e0)
I device.cc:168] description(0x23bac4e0)
I device.cc:61] process_index(0x23bac4f8)
I device.cc:123] client(0x23bac4e0)
I client_cpp_pjrt.cc:71] process_index(0x23bac400)
I client_cpp_pjrt.cc:81] addressable_device_count(0x23bac400)
I client_cpp_pjrt.cc:95] memory_spaces(0x23bac400)
I device.cc:128] IsAddressable(0x23bac4e0)
I device.cc:168] description(0x23bac4e0)
I device.cc:61] process_index(0x23bac4f8)
I device.cc:123] client(0x23bac4e0)
I client_cpp_pjrt.cc:71] process_index(0x23bac400)
I device.cc:148] memory_spaces(0x23bac4e0)
Creating PJRT Client from client
I client_cpp_pjrt.cc:108] platform_version(0x23bac400)
I client_cpp_pjrt.cc:67] platform_name(0x23bac400)
I device.cc:57] id(0x23bac4f8)
I device.cc:70] device_kind(0x23bac4f8)
I device.cc:70] device_kind(0x23bac4f8)
I device.cc:80] ToString(0x23bac4f8)
I device.cc:80] ToString(0x23bac4f8)
I device.cc:75] DebugString(0x23bac4f8)
I device.cc:75] DebugString(0x23bac4f8)
I device.cc:61] process_index(0x23bac4f8)
I device.cc:128] IsAddressable(0x23bac4e0)
I device.cc:168] description(0x23bac4e0)
I device.cc:61] process_index(0x23bac4f8)
I device.cc:123] client(0x23bac4e0)
I client_cpp_pjrt.cc:71] process_index(0x23bac400)
I device.cc:153] default_memory_space(0x23bac4e0)
I client_cpp_pjrt.cc:71] process_index(0x23bac400)

//////////////////////////////////
// RUN: `y = jax.jit(lambda x: jnp.power(x, jnp.int32(2)))(1)`
//////////////////////////////////

I executable.cc:309] num_partitions(0x240bab70)
I executable.cc:305] num_replicas(0x240bab70)
I executable.cc:309] num_partitions(0x240bab70)
I client_cpp_pjrt.cc:233] BufferFromHostBuffer(0x23bac400)
I buffer.cc:285] CreateMlirBufferFromLiteral
I buffer.cc:98] CreateFromLiteral
I buffer.cc:99] CreateFromLiteral: s32[] 2
I buffer.cc:64] MlirPjrtBuffer(0x240bb050)
I buffer.cc:102] CreateFromLiteral -> 0x240bb050
I buffer.cc:158] device(0x240bb050)
I buffer.cc:154] memory_space(0x240bb050)
I buffer.cc:154] memory_space(0x240bb050)
I executable.cc:328] GetHloModules(0x240bab70)
I executable.cc:240] Execute(0x240bab70)
I executable.cc:197] ExecuteWithReferenceInterpreter(0x240bab70)
I buffer.cc:303] GetAttributeFromBuffer
I buffer.cc:229] IsDeleted(0x240bb050)
I buffer.cc:311] GetAttributeFromBuffer(0x240bb050) -> dense<2> : tensor<i32>
I executable.cc:205] EvalModule:
module @jit attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
  func.func public @main(%arg0: tensor<i32> {mhlo.layout_mode = "default"}) -> (tensor<i32> {jax.result_info = "", mhlo.layout_mode = "default"}) {
    // ...
    return %3 : tensor<i32>
  }
}
I executable.cc:206] Inputs: [dense<2> : tensor<i32>]
I executable.cc:213] Results: [dense<2> : tensor<i32>]
I device.cc:153] default_memory_space(0x23bac4e0)
I buffer.cc:291] CreateMlirBufferFromAttribute
I buffer.cc:116] CreateFromAttribute
I buffer.cc:64] MlirPjrtBuffer(0x22cea630)
I buffer.cc:122] CreateFromAttribute(dense<2> : tensor<i32>) -> 0x22cea630

//////////////////////////////////
// RUN: `print(y)`
//////////////////////////////////

I buffer.cc:263] GetReadyFuture(0x22cea630)
I buffer.cc:264] GetReadyFuture(0x22cea630)
I buffer.cc:154] memory_space(0x22cea630)
I buffer.cc:154] memory_space(0x22cea630)
I buffer.cc:158] device(0x22cea630)
I buffer.cc:158] device(0x22cea630)
I buffer.cc:154] memory_space(0x22cea630)
I buffer.cc:154] memory_space(0x22cea630)
I buffer.cc:229] IsDeleted(0x22cea630)
I buffer.cc:129] on_device_shape(0x22cea630)
I buffer.cc:129] on_device_shape(0x22cea630)
I buffer.cc:129] on_device_shape(0x22cea630)
I buffer.cc:158] device(0x22cea630)
I buffer.cc:154] memory_space(0x22cea630)
I buffer.cc:154] memory_space(0x22cea630)
I client_cpp_pjrt.cc:71] process_index(0x23bac400)
I buffer.cc:229] IsDeleted(0x22cea630)
I buffer.cc:129] on_device_shape(0x22cea630)
I buffer.cc:129] on_device_shape(0x22cea630)
I buffer.cc:269] IsOnCpu(0x22cea630) # Returns true, allows external references.
I buffer.cc:129] on_device_shape(0x22cea630)
I buffer.cc:129] on_device_shape(0x22cea630)
I buffer.cc:129] on_device_shape(0x22cea630)
I buffer.cc:129] on_device_shape(0x22cea630)
I buffer.cc:129] on_device_shape(0x22cea630)
I buffer.cc:168] AcquireExternalReference(0x22cea630)
I buffer.cc:73] MlirClonedExternalReference(0x2404d560)
I buffer.cc:303] GetAttributeFromBuffer
I buffer.cc:229] IsDeleted(0x22cea630)
I buffer.cc:311] GetAttributeFromBuffer(0x22cea630) -> dense<2> : tensor<i32>
I buffer.cc:291] CreateMlirBufferFromAttribute
I buffer.cc:116] CreateFromAttribute
I buffer.cc:64] MlirPjrtBuffer(0x240bb050)
I buffer.cc:122] CreateFromAttribute(dense<2> : tensor<i32>) -> 0x240bb050
I buffer.cc:168] AcquireExternalReference(0x22cea630)
I buffer.cc:73] MlirClonedExternalReference(0x240b6010)
I buffer.cc:303] GetAttributeFromBuffer
I buffer.cc:229] IsDeleted(0x22cea630)
I buffer.cc:311] GetAttributeFromBuffer(0x22cea630) -> dense<2> : tensor<i32>
I buffer.cc:291] CreateMlirBufferFromAttribute
I buffer.cc:116] CreateFromAttribute
I buffer.cc:64] MlirPjrtBuffer(0x23b2db60)
I buffer.cc:122] CreateFromAttribute(dense<2> : tensor<i32>) -> 0x23b2db60
I buffer.cc:263] GetReadyFuture(0x22cea630)
I buffer.cc:264] GetReadyFuture(0x22cea630)