Dinamisisme di StableHLO

Status dinamisme saat ini secara lebih formal dieja dalam RFC Dinamis. Halaman ini akan memberikan ringkasan tingkat tinggi tentang RFC serta membahas API dan alat penting untuk berinteraksi dengan program dinamis.

Ringkasan Dukungan & Terminologi Dinamisme

Pertama, untuk membahas beberapa istilah yang akan muncul dalam dokumen ini, serta pengantar singkat tentang dukungan mereka di StableHLO:

Dimensi dinamis

Dimensi dinamis mengacu pada dimensi yang ukuran dimensinya tidak diketahui. Di StableHLO, kita merepresentasikan dimensi dinamis menggunakan ?, yaitu tensor<16x?xf32>.

Dinamika terbatas

Dinamisme terbatas mengacu pada dimensi dinamis yang nilainya memiliki batas atas yang diketahui. Umumnya ini berguna untuk padding pada tensor selama eksekusi. Di StableHLO, kita merepresentasikan dinamisme terbatas menggunakan #stablehlo.bounds sebagai enkode tensor, yaitu tensor peringkat 2 dengan satu dimensi dinamis yang dibatasi pada 16 dan yang lainnya tanpa batas dapat direpresentasikan sebagai tensor<?x?xf32, #stablehlo.bounds<16, ?>>.

StableHLO dapat merepresentasikan dinamisme terbatas, tetapi ada dukungan framework yang terbatas, yang berasal dari TensorFlow, dan dengan beberapa dukungan di PyTorch/XLA.

Dinamika tanpa batas

Dinamika tanpa batas seperti namanya mengacu pada dimensi dinamis tanpa batas ukuran yang diketahui. Jenis dinamisme ini sangat umum di StableHLO, dengan dukungan JAX, PyTorch/XLA, dan TF, yang sering digunakan untuk mengekspor model dengan ukuran batch atau panjang urutan dinamis.

Di StableHLO, kita cukup menghapus encoding batas untuk bentuk dinamisme ini, yaitu tensor<?x?xf32>.

Polimorfisme bentuk

Polimorfisme bentuk adalah istilah yang kita warisi dari JAX.

Ada dua implikasi utama untuk membentuk polimorfisme:

  1. Semua dinamisme dalam program ini dilacak kembali ke argumen inputnya.
  2. Semua dinamisme hanya berkaitan dengan bentuk tensor, yaitu tidak bergantung pada data.

Dengan dua aturan ini, setelah bentuk statis program diketahui, kita dapat mengambil program dinamis dan sepenuhnya meningkatkannya menjadi program statis untuk kompilasi (lihat "Proses compiler untuk meningkatkan program dinamis").

Umumnya, polimorfisme bentuk menggunakan dinamisme tanpa batas, jika bentuk argumen yang diketahui dapat mengarah ke program yang sepenuhnya statis, tidak perlu menebak cara membatasi nilai.

Dinamika yang bergantung pada data

Dinamika yang bergantung pada data mengacu pada ukuran dimensi dinamis yang berkaitan dengan data di dalam tensor. Contoh kanonis adalah fungsi nonzeros yang menampilkan indeks semua elemen yang merupakan 0 dalam nilai tensor. Bentuk ini tidak dapat diketahui tanpa mengevaluasi data, tetapi sering kali dapat dikompilasi menggunakan dinamisme terbatas, yang menghabiskan memori tambahan pada ukuran tensor output yang potensial.

Banyak operasi dinamis yang bergantung pada data dapat dimodelkan menggunakan dinamisme terbatas, dengan batas atas pada ukuran tensor ditentukan, dan hardware umumnya akan menerapkan hal ini melalui padding tensor. Saat ini, ada beberapa dukungan untuk dinamisme yang bergantung pada data di PyTorch/XLA dan TensorFlow, tetapi JAX saat ini tidak melacak operasi yang menyebabkan dinamisme yang bergantung pada data.

Mengekspor program dengan dimensi dinamis

Lihat tutorial StableHLO kami untuk mengetahui informasi tentang cara mengekspor program dengan ukuran batch dinamis atau panjang urutan:

Lulusan compiler untuk meningkatkan kualitas program dinamis

Menghapus pipeline kartu dinamisme

Ada beberapa kartu yang berguna untuk meningkatkan kualitas bentuk, dan semuanya dipaketkan dalam pipeline kartu createStablehloRemoveDynamismPipeline:

void createStablehloRemoveDynamismPipeline(OpPassManager &pm,
                                           TypeRange refinedTypes);

Tiket individual untuk meningkatkan dinamisme

Secara terpisah, kartu yang cenderung berguna untuk penyempurnaan bentuk adalah:

Lihat dokumentasi tertaut untuk mengetahui informasi dan contoh terbaru.

Contoh: Apa manfaat dinamisme, dan bagaimana cara menggunakannya?

Dinamisme memiliki banyak kegunaan. Di sini, kita akan berfokus pada kasus penggunaan umum untuk Polimorfisme Bentuk - membuat representasi model yang diekspor secara fleksibel, umumnya digunakan untuk merepresentasikan ukuran batch dinamis atau panjang urutan.

Model add_one statis

Kita akan menggunakan model add_one sederhana berikut untuk menunjukkan hal ini:

def add_one(x):
  return x + 1

Saat dilacak menggunakan tensor<4xf32>, kita akan mendapatkan program StableHLO berikut:

// File: add_one.mlir
func.func @add_one(%arg0: tensor<4xf32>) -> tensor<4xf32> {
  %cst = stablehlo.constant dense<1.000000e+00> : tensor<4xf32>
  %0 = stablehlo.add %arg0, %cst : tensor<4xf32>
  return %0 : tensor<4xf32>
}

Model ini hanya akan berfungsi untuk argumen input yang memiliki bentuk tensor<4xf32>. Jika pernah mengubah ukuran batch atau panjang urutan, kita harus melacak ulang kode sumber dan menurunkan kembali ke StableHLO, dan tidak ada jaminan bahwa kita masih memiliki akses ke kode sumber.

Model add_one dinamis

Di sinilah dinamisme polimorfik bentuk berperan. Sebagai gantinya, JAX dan PyTorch/XLA dapat memunculkan model add_one dengan IR yang valid secara dinamis yang akan menyiarkan konstanta agar cocok dengan bentuk input dinamis sebagai berikut:

// File: add_one_dynamic.mlir
func.func public @main(%arg0: tensor<?xf32>) -> tensor<?xf32> {
  %cst = stablehlo.constant dense<1.0> : tensor<f32>
  %0 = stablehlo.get_dimension_size %arg0, dim = 0 : (tensor<?xf32>) -> tensor<i32>
  %1 = stablehlo.reshape %0 : (tensor<i32>) -> tensor<1xi32>
  %2 = stablehlo.dynamic_broadcast_in_dim %cst, %1, dims = [] : (tensor<f32>, tensor<1xi32>) -> tensor<?xf32>
  %3 = stablehlo.add %arg0, %2 : tensor<?xf32>
  return %3 : tensor<?xf32>
}

Representasi model ini jauh lebih fleksibel, dan memungkinkan spesifikasi nilai yang ditangguhkan seperti ukuran batch atau panjang urutan. Model ini dapat di-deploy di platform dengan dukungan bentuk dinamis (seperti AI Edge), atau dapat ditingkatkan menggunakan kartu dinamisme yang disebutkan dalam dokumentasi ini.

Meningkatkan kualitas model dinamis

Misalnya, pengurutan pass berikut dapat sepenuhnya menyempurnakan program ini:

stablehlo-opt add_one_dynamic.mlir \
  --stablehlo-refine-arguments='types=tensor<16xf32>' \
  --stablehlo-refine-shapes \
  --stablehlo-canonicalize-dynamism

Secara bertahap, berikut cara program diubah:

// After stablehlo-refine-arguments: Inputs updated, shapes not propagated
func.func public @main(%arg0: tensor<16xf32>) -> tensor<?xf32> {
  %c = stablehlo.constant dense<16> : tensor<1xi64>
  %0 = stablehlo.custom_call @stablehlo.shape_refinement_operand_wrapper(%arg0, %c) {indices_of_shape_operands = dense<1> : tensor<1xi64>} : (tensor<16xf32>, tensor<1xi64>) -> tensor<?xf32>
  ...
  %3 = stablehlo.dynamic_broadcast_in_dim %cst, %2, dims = [] : (tensor<f32>, tensor<1xi32>) -> tensor<?xf32>
  %4 = stablehlo.add %0, %3 : tensor<?xf32>
  return %4 : tensor<?xf32>
}

// After stablehlo-refine-shapes: Shapes propagated, dynamic ops still exist
func.func public @main(%arg0: tensor<16xf32>) -> tensor<16xf32> {
  %cst = stablehlo.constant dense<1.000000e+00> : tensor<f32>
  %c = stablehlo.constant dense<16> : tensor<1xi32>
  %0 = stablehlo.dynamic_broadcast_in_dim %cst, %c, dims = [] : (tensor<f32>, tensor<1xi32>) -> tensor<16xf32>
  %1 = stablehlo.add %arg0, %0 : tensor<16xf32>
  return %1 : tensor<16xf32>
}

// After stablehlo-canonicalize-dynamism: Dynamic ops replaced with static ops
func.func public @main(%arg0: tensor<16xf32>) -> tensor<16xf32> {
  %cst = stablehlo.constant dense<1.000000e+00> : tensor<f32>
  %0 = stablehlo.broadcast_in_dim %cst, dims = [] : (tensor<f32>) -> tensor<16xf32>
  %1 = stablehlo.add %arg0, %0 : tensor<16xf32>
  return %1 : tensor<16xf32>
}

// (Bonus) Use ` --stablehlo-aggressive-simplification` pass to canonicalize the
// constant broadcast, leaving us with the original static program in this case.
func.func public @main(%arg0: tensor<16xf32>) -> tensor<16xf32> {
  %cst = stablehlo.constant dense<1.000000e+00> : tensor<16xf32>
  %0 = stablehlo.add %arg0, %cst : tensor<16xf32>
  return %0 : tensor<16xf32>
}