XLA: Эмиттеры графического процессора

Существует три способа создания кода для HLO в XLA:GPU.

изображение

  1. Замена HLO на пользовательские вызовы внешних библиотек, например NVidia cuBLAS , cuDNN .
  2. Тайлинг HLO на уровне блоков и последующее использование OpenAI Triton .
  3. Использование эмиттеров XLA для постепенного понижения HLO до LLVM IR.

Этот документ посвящен эмиттерам XLA:GPU.

Генератор кода на основе героев

В XLA:GPU имеется 7 типов эмиттеров. Каждый тип эмиттера соответствует «герою» слияния, то есть наиболее важной операции в объединенных вычислениях, которая формирует генерацию кода для всего слияния.

изображение

Например, эмиттер транспонирования будет выбран, если в слиянии есть HloTransposeInstruction , который требует использования общей памяти для улучшения шаблонов чтения и записи памяти. Эмиттер сокращения генерирует сокращения, используя перемешивание и общую память. Эмиттер контура является эмиттером по умолчанию. Если у фьюжн нет героя, для которого у нас есть специальный эмиттер, то будет использоваться петлевой эмиттер.

Общий обзор

Код состоит из следующих больших строительных блоков:

  • Разделитель вычислений — разделение вычислений HLO Fusion на функции.
  • Эмиттеры — преобразование секционированного слияния HLO в MLIR (диалекты xla_gpu , tensor , arith , math , scf )
  • Конвейер компиляции — оптимизирует и понижает IR до LLVM.

изображение

Разделение

См. Compution_partitioner.h .

Неэлементные инструкции HLO не всегда могут выполняться вместе. Рассмотрим следующий график HLO:

     param
       |
      log
      |  \
      |  transpose
      |  /
      add

Если мы выпустим это в одной функции, доступ к log будет осуществляться по двум разным индексам для каждого элемента add . Старые эмиттеры решают эту проблему, генерируя log дважды. Для данного конкретного графа это не проблема, но при наличии нескольких разбиений размер кода растет в геометрической прогрессии.

Здесь мы решаем эту проблему, разбивая граф на части, которые можно безопасно создавать как одну функцию. Критерии:

  • Инструкции, которые имеют только одного пользователя, можно безопасно отправлять вместе со своим пользователем.
  • Инструкции, которые имеют несколько пользователей, можно безопасно создавать вместе со своими пользователями, если все пользователи обращаются к ним через одни и те же индексы.

В приведенном выше примере add и tranpose обращаются к разным индексам log , поэтому выдавать его вместе с ними небезопасно.

Таким образом, граф разделен на три функции (каждая из которых содержит только одну инструкцию).

То же самое применимо к следующему примеру с slice и pad add .

изображение

Элементарная эмиссия

См. elemental_hlo_to_mlir.h .

Элементарная эмиссия создает циклы и математические/арифметические операции для HloInstructions . По большей части это просто, но здесь происходит несколько интересных вещей.

Индексирующие преобразования

Некоторые инструкции ( transpose , broadcast , reshape , slice , reverse и некоторые другие) представляют собой чисто преобразования индексов: чтобы создать элемент результата, нам нужно создать какой-то другой элемент входных данных. Для этого мы можем повторно использовать indexing_anaанализ XLA, в котором есть функции для создания сопоставления вывода и ввода для инструкции.

Например, для transpose из [20,40] в [40,20] будет создана следующая карта индексации (одно аффинное выражение для каждого входного измерения; d0 и d1 — выходные измерения):

  (d0, d1) -> d1
  (d0, d1) -> d0

Таким образом, для этих инструкций чистого преобразования индекса мы можем просто получить карту, применить ее к выходным индексам и создать входные данные по результирующему индексу.

Аналогичным образом, операция pad использует карты индексации и ограничения для большей части реализации. pad также является индексирующим преобразованием с некоторыми дополнительными проверками, чтобы увидеть, возвращаем ли мы элемент ввода или значение заполнения.

Кортежи

Мы не поддерживаем внутренние tuple . Мы также не поддерживаем вывод вложенных кортежей. Все графики XLA, использующие эти функции, можно преобразовать в графики, которые этого не делают.

Собирать

Мы поддерживаем только канонические сборки, созданные с помощью gather_simplifier .

Функции подграфа

Для подграфа вычисления с параметрами %p0 до %p_n и корней подграфа с рангом r и типами элементов ( e0 до e_m ) мы используем следующую сигнатуру функции MLIR:

(%p0: tensor<...>, %p1: tensor<...>, ..., %pn: tensor<...>,
 %i0: index, %i1: index, ..., %i_r-1: index) -> (e0, ..., e_m)

То есть у нас есть один входной тензор для каждого вычислительного параметра, один входной индекс для каждого измерения выходного сигнала и один результат для каждого выходного сигнала.

Чтобы выпустить функцию, мы просто используем элементный эмиттер, указанный выше, и рекурсивно выдаем его операнды, пока не достигнем края подграфа. Затем мы: выдаем tensor.extract для параметров или выдаем func.call для других подграфов.

Функция входа

Каждый тип эмиттера отличается тем, как он генерирует функцию входа, то есть функцию для героя. Функция входа отличается от функций, описанных выше, поскольку она не имеет индексов в качестве входных данных (только идентификаторы потока и блока) и фактически должна куда-то записывать выходные данные. Для эмиттера цикла это довольно просто, но эмиттеры транспонирования и редукции имеют нетривиальную логику записи.

Подпись входного вычисления:

(%p0: tensor<...>, ..., %pn: tensor<...>,
 %r0: tensor<...>, ..., %rn: tensor<...>) -> (tensor<...>, ..., tensor<...>)

Здесь, как и раньше, %pn — это параметры вычислений, а %rn — это результаты вычислений. Вычисление входа принимает результаты в виде тензоров, tensor.insert обновляет их, а затем возвращает их. Никакое другое использование выходных тензоров не допускается.

Конвейер компиляции

Петлевой излучатель

См. Loop_mlir.h .

Давайте изучим наиболее важные этапы конвейера компиляции MLIR с использованием HLO для функции GELU.

изображение

Это вычисление HLO имеет только поэлементные операции, константы и широковещательные сообщения. Он будет излучаться с помощью петлевого излучателя.

Преобразование MLIR

После преобразования в MLIR мы получаем xla_gpu.loop , который зависит от %thread_id_x и %block_id_x и определяет цикл, который линейно обходит все элементы вывода, чтобы гарантировать объединенную запись.

На каждой итерации этого цикла мы вызываем

   %pure_call = xla_gpu.pure_call @gelu(%input, %dim0, %dim1, %dim2)
      : (tensor<6x512x4096xbf16>, index, index, index) -> bf16

для вычисления элементов корневой операции. Обратите внимание, что у нас есть только одна описанная функция для @gelu , поскольку средство разделения не обнаружило тензор, имеющий 2 или более различных шаблонов доступа.

#map = #xla_gpu.indexing_map<"(th_x, bl_x)[vector_index] -> ("
 "bl_x floordiv 4096, (bl_x floordiv 8) mod 512, (bl_x mod 8) * 512 + th_x * 4 + vector_index),"
 "domain: th_x in [0, 127], bl_x in [0, 24575], vector_index in [0, 3]">

func.func @main(%input: tensor<6x512x4096xbf16> , %output: tensor<6x512x4096xbf16>)
   -> tensor<6x512x4096xbf16> {
 %thread_id_x = gpu.thread_id  x {xla.range = [0 : index, 127 : index]}
 %block_id_x = gpu.block_id  x {xla.range = [0 : index, 24575 : index]}

 %xla_loop = xla_gpu.loop (%thread_id_x, %block_id_x)[%vector_index] -> (%dim0, %dim1, %dim2)
     in #map iter_args(%iter = %output) -> (tensor<6x512x4096xbf16>) {
   %pure_call = xla_gpu.pure_call @gelu(%input, %dim0, %dim1, %dim2)
      : (tensor<6x512x4096xbf16>, index, index, index) -> bf16
   %inserted = tensor.insert %pure_call into %iter[%dim0, %dim1, %dim2] : tensor<6x512x4096xbf16>
   xla_gpu.yield %inserted : tensor<6x512x4096xbf16>
 }
 return %xla_loop : tensor<6x512x4096xbf16>
}

func.func private @gelu(%arg0: tensor<6x512x4096xbf16>, %i: index, %j: index, %k: index) -> bf16 {
  %cst = arith.constant 5.000000e-01 : bf16
  %cst_0 = arith.constant 1.000000e+00 : bf16
  %cst_1 = arith.constant 7.968750e-01 : bf16
  %cst_2 = arith.constant 4.467770e-02 : bf16
  %extracted = tensor.extract %arg0[%i, %j, %k] : tensor<6x512x4096xbf16>
  %0 = arith.mulf %extracted, %extracted : bf16
  %1 = arith.mulf %0, %extracted : bf16
  %2 = arith.mulf %1, %cst_2 : bf16
  %3 = arith.addf %extracted, %2 : bf16
  %4 = arith.mulf %3, %cst_1 : bf16
  %5 = math.tanh %4 : bf16
  %6 = arith.addf %5, %cst_0 : bf16
  %7 = arith.mulf %6, %cst : bf16
  %8 = arith.mulf %extracted, %7 : bf16
  return %8 : bf16
}

Инлайнер

После встраивания @gelu мы получаем одну функцию @main . Может случиться так, что одна и та же функция вызывается дважды и более. В этом случае мы не встраиваем. Более подробную информацию о правилах встраивания можно найти в xla_gpu_dialect.cc .

func.func @main(%arg0: tensor<6x512x4096xbf16>, %arg1: tensor<6x512x4096xbf16>) -> tensor<6x512x4096xbf16> {
 ...
  %thread_id_x = gpu.thread_id  x {xla.range = [0 : index, 127 : index]}
  %block_id_x = gpu.block_id  x {xla.range = [0 : index, 24575 : index]}

  %xla_loop = xla_gpu.loop (%thread_id_x, %block_id_x)[%vector_index] -> (%dim0, %dim1, %dim2)
      in #map iter_args(%iter = %output) -> (tensor<6x512x4096xbf16>) {
    %extracted = tensor.extract %input[%dim0, %dim1, %dim2] : tensor<6x512x4096xbf16>
    %0 = arith.mulf %extracted, %extracted : bf16
    %1 = arith.mulf %0, %extracted : bf16
    %2 = arith.mulf %1, %cst : bf16
    %3 = arith.addf %extracted, %2 : bf16
    %4 = arith.mulf %3, %cst_0 : bf16
    %5 = math.tanh %4 : bf16
    %6 = arith.addf %5, %cst_1 : bf16
    %7 = arith.mulf %6, %cst_2 : bf16
    %8 = arith.mulf %extracted, %7 : bf16
    %inserted = tensor.insert %8 into %iter[%dim0, %dim1, %dim2] : tensor<6x512x4096xbf16>
    xla_gpu.yield %inserted : tensor<6x512x4096xbf16>
  }
  return %xla_loop : tensor<6x512x4096xbf16>
}

Преобразование xla_gpu в scf

См. нижний_xla_gpu_to_scf.cc .

xla_gpu.loop представляет собой гнездо циклов с проверкой границ внутри. Если переменные индукции цикла выходят за пределы области отображения индексации, то эта итерация пропускается. Это означает, что цикл преобразуется в одну или несколько вложенных операций scf.for с scf.if внутри.

%xla_loop = scf.for %vector_index = %c0 to %c4 step %c1 iter_args(%iter = %output) -> (tensor<6x512x4096xbf16>) {
   %2 = arith.cmpi sge, %thread_id_x, %c0 : index
   %3 = arith.cmpi sle, %thread_id_x, %c127 : index
   %4 = arith.andi %2, %3 : i1
   %5 = arith.cmpi sge, %block_id_x, %c0 : index
   %6 = arith.cmpi sle, %block_id_x, %c24575 : index
   %7 = arith.andi %5, %6 : i1
   %inbounds = arith.andi %4, %7 : i1
   %9 = scf.if %inbounds -> (tensor<6x512x4096xbf16>) {
     %dim0 = xla_gpu.apply_indexing #map(%thread_id_x,  %block_id_x)[%vector_index]
     %dim1 = xla_gpu.apply_indexing #map1(%thread_id_x, %block_id_x)[%vector_index]
     %dim2 = xla_gpu.apply_indexing #map2(%thread_id_x, %block_id_x)[%vector_index]
     %extracted = tensor.extract %input[%dim0, %dim1, %dim2] : tensor<6x512x4096xbf16>
     // ... more arithmetic operations
     %29 = arith.mulf %extracted, %28 : bf16
     %inserted = tensor.insert %29 into %iter[%dim0, %dim1, %dim2] : tensor<6x512x4096xbf16>
     scf.yield %inserted : tensor<6x512x4096xbf16>
   } else {
     scf.yield %iter : tensor<6x512x4096xbf16>
   }
   scf.yield %9 : tensor<6x512x4096xbf16>
 }

Сгладить тензоры

См. Flatten_tensors.cc .

Тензоры Nd проецируются на 1D. Это упростит векторизацию и переход к LLVM, поскольку каждый доступ к тензору теперь соответствует тому, как данные выравниваются в памяти.

#map = #xla_gpu.indexing_map<"(th_x, bl_x, vector_index) -> (th_x * 4 + bl_x * 512 + vector_index),"
 "domain: th_x in [0, 127], bl_x in [0, 24575], vector_index in [0, 3]">

func.func @main(%input: tensor<12582912xbf16>, %output: tensor<12582912xbf16>) -> tensor<12582912xbf16> {
 %xla_loop = scf.for %vector_index = %c0 to %c4 step %c1 iter_args(%iter = %output) -> (tensor<12582912xbf16>) {
   %dim = xla_gpu.apply_indexing #map(%thread_id_x, %block_id_x, %vector_index)
   %extracted = tensor.extract %input[%dim] : tensor<12582912xbf16>
   %2 = arith.mulf %extracted, %extracted : bf16
   %3 = arith.mulf %2, %extracted : bf16
   %4 = arith.mulf %3, %cst_2 : bf16
   %5 = arith.addf %extracted, %4 : bf16
   %6 = arith.mulf %5, %cst_1 : bf16
   %7 = math.tanh %6 : bf16
   %8 = arith.addf %7, %cst_0 : bf16
   %9 = arith.mulf %8, %cst : bf16
   %10 = arith.mulf %extracted, %9 : bf16
   %inserted = tensor.insert %10 into %iter[%dim] : tensor<12582912xbf16>
   scf.yield %inserted : tensor<12582912xbf16>
 }
 return %xla_loop : tensor<12582912xbf16>
}

Векторизация

См. Vectorize_loads_stores.cc .

Этот проход анализирует индексы в операциях tensor.extract и tensor.insert , и если они создаются с помощью xla_gpu.apply_indexing , который обращается к элементам, смежным с %vector_index , и доступ выровнен, то tensor.extract преобразуется в vector.transfer_read и вытащили из петли.

В этом конкретном случае существует карта индексации (th_x, bl_x, vector_index) -> (th_x * 4 + bl_x * 512 + vector_index) используемая для вычисления элементов для извлечения и вставки в цикл scf.for от 0 до 4. Поэтому как tensor.extract , так и tensor.insert могут быть векторизованы.

func.func @main(%input: tensor<12582912xbf16>, %output: tensor<12582912xbf16>) -> tensor<12582912xbf16> {
 %vector_0 = arith.constant dense<0.000000e+00> : vector<4xbf16>
 %0 = xla_gpu.apply_indexing #map(%thread_id_x, %block_id_x, %c0)
 %2 = vector.transfer_read %input[%0], %cst {in_bounds = [true]} : tensor<12582912xbf16>, vector<4xbf16>
 %xla_loop:2 = scf.for %vector_index = %c0 to %c4 step %c1
     iter_args(%iter = %output, %iter_vector = %vector_0) -> (tensor<12582912xbf16>, vector<4xbf16>) {
   %5 = vector.extract %2[%vector_index] : bf16 from vector<4xbf16>
   %6 = arith.mulf %5, %5 : bf16
   %7 = arith.mulf %6, %5 : bf16
   %8 = arith.mulf %7, %cst_4 : bf16
   %9 = arith.addf %5, %8 : bf16
   %10 = arith.mulf %9, %cst_3 : bf16
   %11 = math.tanh %10 : bf16
   %12 = arith.addf %11, %cst_2 : bf16
   %13 = arith.mulf %12, %cst_1 : bf16
   %14 = arith.mulf %5, %13 : bf16
   %15 = vector.insert %14, %iter_vector [%vector_index] : bf16 into vector<4xbf16>
   scf.yield %iter, %15 : tensor<12582912xbf16>, vector<4xbf16>
 }
 %4 = vector.transfer_write %xla_loop#1, %output[%0] {in_bounds = [true]}
     : vector<4xbf16>, tensor<12582912xbf16>
 return %4 : tensor<12582912xbf16>
}

Разворачивание цикла

См. оптимизированный_loops.cc .

Развертывание цикла находит циклы scf.for , которые можно развернуть. В этом случае цикл по элементам вектора исчезает.

func.func @main(%input: tensor<12582912xbf16>, %arg1: tensor<12582912xbf16>) -> tensor<12582912xbf16> {

  %cst_0 = arith.constant dense<0.000000e+00> : vector<4xbf16>
  %dim = xla_gpu.apply_indexing #map(%thread_id_x, %block_id_x, %c0)
  %2 = vector.transfer_read %input[%dim], %cst {in_bounds = [true]} : tensor<12582912xbf16>, vector<4xbf16>
  %3 = vector.extract %2[%c0] : bf16 from vector<4xbf16>
  ...
  %13 = vector.insert %12, %cst_0 [%c0] : bf16 into vector<4xbf16>
  %14 = vector.extract %2[%c1] : bf16 from vector<4xbf16>
  ...
  %24 = vector.insert %23, %13 [%c1] : bf16 into vector<4xbf16>
  %25 = vector.extract %2[%c2] : bf16 from vector<4xbf16>
  ...
  %35 = vector.insert %34, %24 [%c2] : bf16 into vector<4xbf16>
  %36 = vector.extract %2[%c3] : bf16 from vector<4xbf16>
  ...
  %46 = vector.insert %45, %35 [%c3] : bf16 into vector<4xbf16>
  %47 = vector.transfer_write %46, %arg1[%dim] {in_bounds = [true]} : vector<4xbf16>, tensor<12582912xbf16>
  return %47 : tensor<12582912xbf16>
}

Преобразование в LLVM

В основном мы используем стандартные понижения LLVM, но есть и несколько специальных приемов. Мы не можем использовать понижения memref для тензоров, поскольку мы не буферизуем IR и наш ABI несовместим с memref ABI. Вместо этого у нас есть кастомный понижение напрямую из тензоров в LLVM .

  • Понижение тензоров осуществляется в файле low_tensors.cc . tensor.extract понижается до llvm.load , tensor.insert до llvm.store очевидным образом.
  • propagate_slice_indices и merge_pointers_to_same_slice вместе реализуют детали назначения буфера и ABI XLA: если два тензора совместно используют один и тот же срез буфера, они передаются только один раз. Эти проходы дедуплицируют аргументы функции.
llvm.func @__nv_tanhf(f32) -> f32
llvm.func @main(%arg0: !llvm.ptr, %arg1: !llvm.ptr) {
  %11 = nvvm.read.ptx.sreg.tid.x : i32
  %12 = nvvm.read.ptx.sreg.ctaid.x : i32
  %13 = llvm.mul %11, %1 : i32
  %14 = llvm.mul %12, %0 : i32
  %15 = llvm.add %13, %14 : i32
  %16 = llvm.getelementptr inbounds %arg0[%15] : (!llvm.ptr, i32) -> !llvm.ptr, bf16
  %17 = llvm.load %16 invariant : !llvm.ptr -> vector<4xbf16>
  %18 = llvm.extractelement %17[%2 : i32] : vector<4xbf16>
  %19 = llvm.fmul %18, %18  : bf16
  %20 = llvm.fmul %19, %18  : bf16
  %21 = llvm.fmul %20, %4  : bf16
  %22 = llvm.fadd %18, %21  : bf16
  %23 = llvm.fmul %22, %5  : bf16
  %24 = llvm.fpext %23 : bf16 to f32
  %25 = llvm.call @__nv_tanhf(%24) : (f32) -> f32
  %26 = llvm.fptrunc %25 : f32 to bf16
  %27 = llvm.fadd %26, %6  : bf16
  %28 = llvm.fmul %27, %7  : bf16
  %29 = llvm.fmul %18, %28  : bf16
  %30 = llvm.insertelement %29, %8[%2 : i32] : vector<4xbf16>
  ...
}

Транспонировать эмиттер

Давайте рассмотрим немного более сложный пример.

изображение

Эмиттер транспонирования отличается от эмиттера цикла только тем, как генерируется функция входа.

func.func @transpose(%arg0: tensor<20x160x170xf32>, %arg1: tensor<170x160x20xf32>) -> tensor<170x160x20xf32> {
  %thread_id_x = gpu.thread_id  x {xla.range = [0 : index, 127 : index]}
  %block_id_x = gpu.block_id  x {xla.range = [0 : index, 959 : index]}

  %shmem = xla_gpu.allocate_shared : tensor<32x1x33xf32>
  %xla_loop = xla_gpu.loop (%thread_id_x, %block_id_x)[%i, %j]
      -> (%input_dim0, %input_dim1, %input_dim2, %shmem_dim0, %shmem_dim1, %shmem_dim2)
      in #map iter_args(%iter = %shmem) -> (tensor<32x1x33xf32>) {
    %extracted = tensor.extract %arg0[%input_dim0, %input_dim1, %input_dim2] : tensor<20x160x170xf32>
    %0 = math.exp %extracted : f32
    %inserted = tensor.insert %0 into %iter[%shmem_dim0, %shmem_dim1, %shmem_dim2] : tensor<32x1x33xf32>
    xla_gpu.yield %inserted : tensor<32x1x33xf32>
  }

  %synced_tensor = xla_gpu.sync_threads %xla_loop : tensor<32x1x33xf32>

  %xla_loop_0 = xla_gpu.loop (%thread_id_x %block_id_x)[%i, %j] -> (%dim0, %dim1, %dim2)
      in #map1 iter_args(%iter = %arg1) -> (tensor<170x160x20xf32>) {
    // indexing computations
    %extracted = tensor.extract %synced_tensor[%0, %c0, %1] : tensor<32x1x33xf32>
    %2 = math.absf %extracted : f32
    %inserted = tensor.insert %2 into %iter[%3, %4, %1] : tensor<170x160x20xf32>
    xla_gpu.yield %inserted : tensor<170x160x20xf32>
  }
  return %xla_loop_0 : tensor<170x160x20xf32>
}

В этом случае мы генерируем две операции xla_gpu.loop . Первый выполняет объединенное чтение с ввода и записывает результат в общую память.

Тензор общей памяти создается с помощью xla_gpu.allocate_shared op.

После синхронизации потоков с помощью xla_gpu.sync_threads второй xla_gpu.loop считывает элементы из тензора общей памяти и выполняет объединенную запись на выходные данные.

Репродуктор

Чтобы увидеть IR после каждого прохода конвейера компиляции, можно запустить run_hlo_module с флагом --v=5 .

run_hlo_module --platform=CUDA --xla_disable_all_hlo_passes --reference_platform="" --v=5 /tmp/gelu.hlo

где /tmp/gelu.hlo содержит

HloModule m:

gelu {
  %param = bf16[6,512,4096] parameter(0)
  %constant_0 = bf16[] constant(0.5)
  %bcast_0 = bf16[6,512,4096] broadcast(bf16[] %constant_0), dimensions={}
  %constant_1 = bf16[] constant(1)
  %bcast_1 = bf16[6,512,4096] broadcast(bf16[] %constant_1), dimensions={}
  %constant_2 = bf16[] constant(0.79785)
  %bcast_2 = bf16[6,512,4096] broadcast(bf16[] %constant_2), dimensions={}
  %constant_3 = bf16[] constant(0.044708)
  %bcast_3 = bf16[6,512,4096] broadcast(bf16[] %constant_3), dimensions={}
  %square = bf16[6,512,4096] multiply(bf16[6,512,4096] %param, bf16[6,512,4096] %param)
  %cube = bf16[6,512,4096] multiply(bf16[6,512,4096] %square, bf16[6,512,4096] %param)
  %multiply_3 = bf16[6,512,4096] multiply(bf16[6,512,4096] %cube, bf16[6,512,4096] %bcast_3)
  %add_1 = bf16[6,512,4096] add(bf16[6,512,4096] %param, bf16[6,512,4096] %multiply_3)
  %multiply_2 = bf16[6,512,4096] multiply(bf16[6,512,4096] %add_1, bf16[6,512,4096] %bcast_2)
  %tanh_0 = bf16[6,512,4096] tanh(bf16[6,512,4096] %multiply_2)
  %add_0 = bf16[6,512,4096] add(bf16[6,512,4096] %tanh_0, bf16[6,512,4096] %bcast_1)
  %multiply_1 = bf16[6,512,4096] multiply(bf16[6,512,4096] %add_0, bf16[6,512,4096] %bcast_0)
  ROOT %multiply_0 = bf16[6,512,4096] multiply(bf16[6,512,4096] %param, bf16[6,512,4096] %multiply_1)
}

ENTRY main {
  %param = bf16[6,512,4096] parameter(0)
  ROOT fusion = bf16[6,512,4096] fusion(%param), kind=kLoop, calls=gelu
}