Ringkasan Arsitektur XLA:GPU

Pengantar

XLA adalah compiler khusus domain hardware dan framework untuk aljabar linier, yang menawarkan performa terbaik di kelasnya. JAX, TF, Pytorch, dan lainnya menggunakan XLA dengan mengonversi input pengguna ke StableHLO (“operasi tingkat tinggi”: serangkaian ~100 instruksi berbentuk statis seperti penambahan, pengurangan, matmul, dll.), yang kemudian XLA menghasilkan kode yang dioptimalkan untuk berbagai backend:

Selama eksekusi, framework memanggil API runtime PJRT, yang memungkinkan framework melakukan operasi “mengisi buffer yang ditentukan menggunakan program StableHLO tertentu di perangkat tertentu”.

Pipeline XLA:GPU

XLA:GPU menggunakan kombinasi pemancar “asli” (PTX, melalui LLVM) dan pemancar TritonIR untuk menghasilkan kernel GPU berperforma tinggi (warna biru menunjukkan komponen 3P):

Contoh Menjalankan: JAX

Untuk menggambarkan pipeline, mari kita mulai dengan contoh yang berjalan di JAX, yang menghitung matmul yang dikombinasikan dengan perkalian dengan konstanta dan negasi:

def f(a, b):
    return -((a @ b) * 0.125)

Kita dapat memeriksa HLO yang dihasilkan oleh fungsi:

M = 1024
K = 512
N = 2048
key = jax.random.PRNGKey(1701)
a = jax.random.randint(key, (M, K), dtype=jax.numpy.int8, minval=0, maxval=255)
b = jax.random.normal(key, (K, N), dtype=jax.dtypes.bfloat16)

print(jax.xla_computation(f)(a, b).as_hlo_text())

yang menghasilkan:

HloModule xla_computation_f, entry_computation_layout={(s8[1024,512]{1,0}, bf16[512,2048]{1,0})->(bf16[1024,2048]{1,0})}

ENTRY main.10 {
  Arg_0.1 = s8[1024,512]{1,0} parameter(0)
  convert.5 = bf16[1024,512]{1,0} convert(Arg_0.1)
  Arg_1.2 = bf16[512,2048]{1,0} parameter(1)
  dot.6 = bf16[1024,2048]{1,0} dot(convert.5, Arg_1.2), lhs_contracting_dims={1}, rhs_contracting_dims={0}
  constant.3 = bf16[] constant(0.125)
  broadcast.4 = bf16[1024,2048]{1,0} broadcast(constant.3), dimensions={}
  multiply.7 = bf16[1024,2048]{1,0} multiply(dot.6, broadcast.4)
  ROOT negate.8 = bf16[1024,2048]{1,0} negate(multiply.7)
}

Kita juga dapat memvisualisasikan komputasi HLO input, menggunakan jax.xla_computation(f)(a, b).as_hlo_dot_graph():

Pengoptimalan di HLO: Komponen Utama

Sejumlah pengoptimalan penting terjadi di HLO, sebagai penulisan ulang HLO->HLO.

Partisi SPMD

Partisi SPMD XLA, seperti yang dijelaskan dalam GSPMD: General and Scalable Parallelization for MLComputation Graphs, menggunakan HLO dengan anotasi sharding (dihasilkan misalnya oleh jax.pjit), dan menghasilkan HLO yang di-shard yang kemudian dapat berjalan di sejumlah host dan perangkat. Selain partisi, SPMD mencoba mengoptimalkan HLO untuk jadwal eksekusi yang optimal, komputasi dan komunikasi yang tumpang-tindih antar-node.

Contoh

Pertimbangkan untuk memulai dari program JAX sederhana yang di-shard di dua perangkat:

# Defines a mesh with two axes called ‘x’ and ‘y’,
# sharded across two devices: first and second CPU.
with jax.sharding.Mesh(
      [['cpu:0', 'cpu:1']], ('x', 'y')):

    @pjit
    def f(a, b):
        out = -((a @ b) * 0.125)
        # Shard output matrix access across ‘x’
        # and ‘y’ respectively. Generates ‘Sharding’
        # custom call.
        out = with_sharding_constraint(
          out, jax.lax.PartitionSpec('x', 'y'))
        return out

# Random inputs to call our function.
a = jax.random.randint(key, (1024, 512), jnp.int8)
b = jax.random.normal(key, (512, 2048), jnp.float32)

print(f.lower(a, b).compiler_ir())

Secara visual, anotasi sharding ditampilkan sebagai panggilan kustom:

Untuk memeriksa cara pemartisi SPMD memperluas panggilan kustom, kita dapat melihat HLO setelah pengoptimalan:

print(f.lower(np.ones((8, 8)).compile().as_text())

Yang menghasilkan HLO dengan kolektif:

Penetapan Tata Letak

HLO memisahkan bentuk logis dan tata letak fisik (cara tensor ditata di memori). Misalnya, matriks f32[32, 64] dapat direpresentasikan dalam urutan baris-utama atau kolom-utama, yang masing-masing direpresentasikan sebagai {1,0} atau {0,1}. Secara umum, tata letak direpresentasikan sebagai bagian dari bentuk, yang menunjukkan permutasi atas jumlah dimensi yang menunjukkan tata letak fisik dalam memori.

Untuk setiap operasi yang ada di HLO, langkah Layout Assignment memilih tata letak optimal (misalnya NHWC untuk konvolusi di Ampere). Misalnya, operasi int8xint8->int32 matmul lebih memilih tata letak {0,1} untuk sisi kanan penghitungan. Demikian pula, “transposisi” yang dimasukkan oleh pengguna diabaikan, dan dikodekan sebagai perubahan tata letak.

Tata letak kemudian disebarkan melalui grafik, dan konflik antara tata letak atau di endpoint grafik diwujudkan sebagai operasi copy, yang melakukan transposisi fisik. Misalnya, mulai dari grafik

Dengan menjalankan penetapan tata letak, kita akan melihat tata letak dan operasi copy berikut yang dimasukkan:

Fusion

Fusi adalah pengoptimalan terpenting XLA, yang mengelompokkan beberapa operasi (misalnya, penambahan ke eksponensiasi ke matmul) ke dalam satu kernel. Karena banyak workload GPU cenderung terikat memori, fusi secara dramatis mempercepat eksekusi dengan menghindari penulisan tensor perantara ke HBM dan kemudian membacanya kembali, dan sebagai gantinya, meneruskannya di register atau memori bersama.

Petunjuk HLO gabungan diblokir bersama dalam satu komputasi gabungan, yang menetapkan invarian berikut:

  • Tidak ada penyimpanan perantara di dalam fusi yang diwujudkan dalam HBM (semuanya harus diteruskan melalui register atau memori bersama).

  • Fusi selalu dikompilasi ke tepat satu kernel GPU

Pengoptimalan HLO pada Contoh yang Sedang Berjalan

Kita dapat memeriksa HLO pasca-pengoptimalan menggunakan jax.jit(f).lower(a, b).compile().as_text(), dan memverifikasi bahwa satu fusi telah dibuat:

HloModule jit_f, is_scheduled=true, entry_computation_layout={(s8[3,2]{1,0}, bf16[2,3]{1,0})->bf16[3,3]{1,0} }, allow_spmd_sharding_propagation_to_output={true}

%triton_gemm_dot.6_computation (parameter_0: s8[3,2], parameter_1: bf16[2,3]) -> bf16[3,3] {
  %parameter_0 = s8[3,2]{1,0} parameter(0)
  %convert.0 = bf16[3,2]{1,0} convert(s8[3,2]{1,0} %parameter_0)
  %parameter_1 = bf16[2,3]{1,0} parameter(1)
  %dot.0 = bf16[3,3]{1,0} dot(bf16[3,2]{1,0} %convert.0, bf16[2,3]{1,0} %parameter_1), lhs_contracting_dims={1}, rhs_contracting_dims={0}
  %convert.1 = f32[3,3]{1,0} convert(bf16[3,3]{1,0} %dot.0)
  %constant_0 = bf16[] constant(0.125)
  %broadcast.0 = bf16[3,3]{1,0} broadcast(bf16[] %constant_0), dimensions={}
  %convert.2 = f32[3,3]{1,0} convert(bf16[3,3]{1,0} %broadcast.0)
  %multiply.0 = f32[3,3]{1,0} multiply(f32[3,3]{1,0} %convert.1, f32[3,3]{1,0} %convert.2)
  %negate.0 = f32[3,3]{1,0} negate(f32[3,3]{1,0} %multiply.0)
  ROOT %convert.6 = bf16[3,3]{1,0} convert(f32[3,3]{1,0} %negate.0)
}

ENTRY %main.9 (Arg_0.1: s8[3,2], Arg_1.2: bf16[2,3]) -> bf16[3,3] {
  %Arg_1.2 = bf16[2,3]{1,0} parameter(1), sharding={replicated}
  %Arg_0.1 = s8[3,2]{1,0} parameter(0), sharding={replicated}
  ROOT %triton_gemm_dot.6 = bf16[3,3]{1,0} fusion(s8[3,2]{1,0} %Arg_0.1, bf16[2,3]{1,0} %Arg_1.2), kind=kCustom, calls=%triton_gemm_dot.6_computation, backend_config={"kind":"__triton_gemm","triton_gemm_config":{"block_m":"64","block_n":"64","block_k":"64","split_k":"1","num_stages":"2","num_warps":"4"} }
}

Perhatikan bahwa penggabungan backend_config memberi tahu kita bahwa Triton akan digunakan sebagai strategi pembuatan kode, dan menentukan pengelompokan yang dipilih.

Kita juga dapat memvisualisasikan modul yang dihasilkan:

Penugasan dan Penjadwalan Buffer

Penerusan penetapan buffer mempertimbangkan informasi bentuk, dan bertujuan untuk menghasilkan alokasi buffer yang optimal untuk program, sehingga meminimalkan jumlah memori perantara yang digunakan. Tidak seperti eksekusi mode langsung (non-dikompilasi) TF atau PyTorch, di mana pengalokasi memori tidak mengetahui grafiknya terlebih dahulu, penjadwal XLA dapat “melihat ke masa depan” dan menghasilkan jadwal komputasi yang optimal.

Backend Compiler: Pemilihan Codegen dan Library

Untuk setiap instruksi HLO dalam komputasi, XLA memilih apakah akan menjalankannya menggunakan library yang ditautkan ke runtime, atau membuat kode PTX.

Pilihan Perpustakaan

Untuk banyak operasi umum, XLA:GPU menggunakan library berperforma tinggi dari NVIDIA, seperti cuBLAS, cuDNN, dan NCCL. Library ini memiliki keunggulan performa cepat yang terverifikasi, tetapi sering kali menghalangi peluang fusi yang kompleks.

Pembuatan kode langsung

Backend XLA:GPU menghasilkan LLVM IR berperforma tinggi secara langsung untuk sejumlah operasi (pengurangan, transposisi, dll.).

Pembuatan kode Triton

Untuk fusi yang lebih canggih yang mencakup perkalian matriks atau softmax, XLA:GPU menggunakan Triton sebagai lapisan pembuatan kode. Fusi HLO dikonversi menjadi TritonIR (dialek MLIR yang berfungsi sebagai input ke Triton), memilih parameter pengelompokan dan memanggil Triton untuk pembuatan PTX:

Kami telah mengamati bahwa kode yang dihasilkan berperforma sangat baik di Ampere, dengan performa mendekati batas atas dengan ukuran petak yang disetel dengan benar.

Runtime

XLA Runtime mengonversi urutan panggilan kernel CUDA dan pemanggilan library yang dihasilkan menjadi RuntimeIR (dialek MLIR di XLA), yang kemudian diekstrak grafiknya di CUDA. Grafik CUDA masih dalam proses, hanya beberapa node yang didukung saat ini. Setelah batas grafik CUDA diekstrak, RuntimeIR dikompilasi melalui LLVM ke CPU yang dapat dieksekusi, yang kemudian dapat disimpan atau ditransfer untuk kompilasi Ahead-Of-Time.