Введение
XLA — это компилятор для линейной алгебры, зависящий от оборудования и фреймворка, предлагающий лучшую в своем классе производительность. JAX, TF, Pytorch и другие используют XLA, преобразуя пользовательский ввод в набор операций StableHLO («операция высокого уровня»: набор из ~100 статически сформированных инструкций, таких как сложение, вычитание, matmul и т. д.), из которого XLA создает оптимизированный код для различных бэкэндов:
Во время выполнения фреймворки вызывают API среды выполнения PJRT , который позволяет фреймворкам выполнять операцию «заполнять указанные буферы, используя заданную программу StableHLO на определенном устройстве».
XLA: конвейер графического процессора
XLA:GPU использует комбинацию «собственных» (PTX, через LLVM) излучателей и излучателей TritonIR для генерации высокопроизводительных ядер графического процессора (синий цвет обозначает компоненты 3P):
Пример выполнения: JAX
Чтобы проиллюстрировать конвейер, начнем с работающего примера в JAX, который вычисляет matmul в сочетании с умножением на константу и отрицанием:
def f(a, b):
return -((a @ b) * 0.125)
Мы можем проверить HLO, сгенерированный функцией:
M = 1024
K = 512
N = 2048
key = jax.random.PRNGKey(1701)
a = jax.random.randint(key, (M, K), dtype=jax.numpy.int8, minval=0, maxval=255)
b = jax.random.normal(key, (K, N), dtype=jax.dtypes.bfloat16)
print(jax.xla_computation(f)(a, b).as_hlo_text())
который генерирует:
HloModule xla_computation_f, entry_computation_layout={(s8[1024,512]{1,0}, bf16[512,2048]{1,0})->(bf16[1024,2048]{1,0})}
ENTRY main.10 {
Arg_0.1 = s8[1024,512]{1,0} parameter(0)
convert.5 = bf16[1024,512]{1,0} convert(Arg_0.1)
Arg_1.2 = bf16[512,2048]{1,0} parameter(1)
dot.6 = bf16[1024,2048]{1,0} dot(convert.5, Arg_1.2), lhs_contracting_dims={1}, rhs_contracting_dims={0}
constant.3 = bf16[] constant(0.125)
broadcast.4 = bf16[1024,2048]{1,0} broadcast(constant.3), dimensions={}
multiply.7 = bf16[1024,2048]{1,0} multiply(dot.6, broadcast.4)
ROOT negate.8 = bf16[1024,2048]{1,0} negate(multiply.7)
}
Мы также можем визуализировать входное вычисление HLO, используя jax.xla_computation(f)(a, b).as_hlo_dot_graph()
:
Оптимизации HLO: ключевые компоненты
В HLO происходит ряд важных оптимизационных проходов, когда происходит перезапись HLO->HLO.
Разделитель SPMD
Разделитель XLA SPMD, как описано в GSPMD: General and Scalable Parallelization for MLComputation Graphs , потребляет HLO с аннотациями шардинга (созданными, например, jax.pjit
) и создает шардированный HLO, который затем может работать на нескольких хостах и устройствах. Помимо разделения, SPMD пытается оптимизировать HLO для оптимального графика выполнения, перекрывая вычисления и связь между узлами.
Пример
Рассмотрим пример простой программы JAX, распределенной по двум устройствам:
# Defines a mesh with two axes called ‘x’ and ‘y’,
# sharded across two devices: first and second CPU.
with jax.sharding.Mesh(
[['cpu:0', 'cpu:1']], ('x', 'y')):
@pjit
def f(a, b):
out = -((a @ b) * 0.125)
# Shard output matrix access across ‘x’
# and ‘y’ respectively. Generates ‘Sharding’
# custom call.
out = with_sharding_constraint(
out, jax.lax.PartitionSpec('x', 'y'))
return out
# Random inputs to call our function.
a = jax.random.randint(key, (1024, 512), jnp.int8)
b = jax.random.normal(key, (512, 2048), jnp.float32)
print(f.lower(a, b).compiler_ir())
Визуализируя это, аннотации шардинга представлены в виде пользовательских вызовов:
Чтобы проверить, как разделитель SPMD расширяет пользовательский вызов, мы можем посмотреть на HLO после оптимизаций:
print(f.lower(np.ones((8, 8)).compile().as_text())
Что генерирует HLO с коллективом:
Задание по макету
HLO разделяет логическую форму и физическую компоновку (как тензоры располагаются в памяти). Например, матрица f32[32, 64]
может быть представлена либо в строковом, либо в столбцовом порядке, представляемом как {1,0}
или {0,1}
соответственно. В общем случае компоновка представляется как часть формы, показывая перестановку по числу измерений, указывающих физическую компоновку в памяти.
Для каждой операции, присутствующей в HLO, проход назначения макета выбирает оптимальный макет (например, NHWC для свертки по Амперу). Например, операция matmul int8xint8->int32
предпочитает макет {0,1}
для RHS вычисления. Аналогично, «транспонирование», вставленное пользователем, игнорируется и кодируется как изменение макета.
Затем макеты распространяются по графу, а конфликты между макетами или в конечных точках графа материализуются как операции copy
, которые выполняют физическую транспозицию. Например, начиная с графа
Выполнив задание макета, мы видим следующие макеты и вставленную операцию copy
:
Слияние
Fusion — это единственная важнейшая оптимизация XLA, которая группирует несколько операций (например, сложение в возведение в степень в matmul) в одно ядро. Поскольку многие рабочие нагрузки GPU, как правило, привязаны к памяти, Fusion значительно ускоряет выполнение, избегая записи промежуточных тензоров в HBM и последующего их обратного чтения, и вместо этого передает их либо в регистры, либо в общую память.
Объединенные инструкции HLO объединяются в единое вычисление слияния, которое устанавливает следующие инварианты:
В HBM не реализовано промежуточное хранилище внутри слияния (всё должно передаваться либо через регистры, либо через общую память).
Слияние всегда компилируется только в одно ядро графического процессора.
Оптимизации HLO на примере выполнения
Мы можем проверить HLO после оптимизации с помощью jax.jit(f).lower(a, b).compile().as_text()
и убедиться, что было сгенерировано одно слияние:
HloModule jit_f, is_scheduled=true, entry_computation_layout={(s8[3,2]{1,0}, bf16[2,3]{1,0})->bf16[3,3]{1,0} }, allow_spmd_sharding_propagation_to_output={true}
%triton_gemm_dot.6_computation (parameter_0: s8[3,2], parameter_1: bf16[2,3]) -> bf16[3,3] {
%parameter_0 = s8[3,2]{1,0} parameter(0)
%convert.0 = bf16[3,2]{1,0} convert(s8[3,2]{1,0} %parameter_0)
%parameter_1 = bf16[2,3]{1,0} parameter(1)
%dot.0 = bf16[3,3]{1,0} dot(bf16[3,2]{1,0} %convert.0, bf16[2,3]{1,0} %parameter_1), lhs_contracting_dims={1}, rhs_contracting_dims={0}
%convert.1 = f32[3,3]{1,0} convert(bf16[3,3]{1,0} %dot.0)
%constant_0 = bf16[] constant(0.125)
%broadcast.0 = bf16[3,3]{1,0} broadcast(bf16[] %constant_0), dimensions={}
%convert.2 = f32[3,3]{1,0} convert(bf16[3,3]{1,0} %broadcast.0)
%multiply.0 = f32[3,3]{1,0} multiply(f32[3,3]{1,0} %convert.1, f32[3,3]{1,0} %convert.2)
%negate.0 = f32[3,3]{1,0} negate(f32[3,3]{1,0} %multiply.0)
ROOT %convert.6 = bf16[3,3]{1,0} convert(f32[3,3]{1,0} %negate.0)
}
ENTRY %main.9 (Arg_0.1: s8[3,2], Arg_1.2: bf16[2,3]) -> bf16[3,3] {
%Arg_1.2 = bf16[2,3]{1,0} parameter(1), sharding={replicated}
%Arg_0.1 = s8[3,2]{1,0} parameter(0), sharding={replicated}
ROOT %triton_gemm_dot.6 = bf16[3,3]{1,0} fusion(s8[3,2]{1,0} %Arg_0.1, bf16[2,3]{1,0} %Arg_1.2), kind=kCustom, calls=%triton_gemm_dot.6_computation, backend_config={"kind":"__triton_gemm","triton_gemm_config":{"block_m":"64","block_n":"64","block_k":"64","split_k":"1","num_stages":"2","num_warps":"4"} }
}
Обратите внимание, что fusion backend_config
сообщает нам, что Triton будет использоваться в качестве стратегии генерации кода, и указывает выбранную разбивку на блоки.
Мы также можем визуализировать полученный модуль:
Назначение буфера и планирование
Проход назначения буфера учитывает информацию о форме и нацелен на создание оптимального распределения буфера для программы, минимизируя объем потребляемой промежуточной памяти. В отличие от TF или немедленного (некомпилированного) выполнения PyTorch, где распределитель памяти не знает граф заранее, планировщик XLA может «заглянуть в будущее» и создать оптимальный график вычислений.
Компилятор Backend: выбор кода и библиотеки
Для каждой инструкции HLO в вычислениях XLA выбирает, запускать ли ее с использованием библиотеки, связанной со средой выполнения, или кодировать ее в PTX.
Выбор библиотеки
Для многих распространенных операций XLA:GPU использует библиотеки быстрой производительности от NVIDIA, такие как cuBLAS, cuDNN и NCCL. Библиотеки имеют преимущество проверенной быстрой производительности, но часто исключают сложные возможности слияния.
Прямая генерация кода
Бэкэнд XLA:GPU напрямую генерирует высокопроизводительный LLVM IR для ряда операций (сокращение, транспонирование и т. д.).
Генерация кода Triton
Для более сложных слияний, включающих умножение матриц или softmax, XLA:GPU использует Triton в качестве слоя генерации кода. Слияния HLO преобразуются в TritonIR (диалект MLIR, который служит входом для Triton), выбирают параметры тайлинга и вызывают Triton для генерации PTX:
Мы наблюдали, как полученный код очень хорошо работает на Ampere, с производительностью, близкой к линии крыши, при правильно настроенных размерах плитки.
Время выполнения
XLA Runtime преобразует полученную последовательность вызовов ядра CUDA и вызовов библиотеки в RuntimeIR (диалект MLIR в XLA), на котором выполняется извлечение графа CUDA. Граф CUDA все еще находится в стадии разработки, в настоящее время поддерживаются только некоторые узлы. После извлечения границ графа CUDA RuntimeIR компилируется через LLVM в исполняемый файл CPU, который затем может быть сохранен или передан для компиляции Ahead-Of-Time.