Dinamisisme di StableHLO

Status dinamisme saat ini dijelaskan secara lebih formal dalam RFC Dinamisme, halaman ini akan memberikan ringkasan tingkat tinggi RFC dan membahas API serta alat penting untuk berinteraksi dengan program dinamis.

Terminologi Dinamisme & Ringkasan Dukungan

Pertama, untuk membahas beberapa istilah yang akan muncul dalam dokumen ini, serta pengantar singkat tentang dukungan StableHLO:

Dimensi dinamis

Dimensi dinamis mengacu pada dimensi apa pun yang ukuran dimensinya tidak diketahui. Di StableHLO, kita merepresentasikan dimensi dinamis menggunakan ?, yaitu tensor<16x?xf32>.

Dinamisme terikat

Dinamisme terbatas mengacu pada dimensi dinamis yang nilainya memiliki batas atas yang diketahui. Umumnya, hal ini berguna untuk menambahkan padding pada tensor selama eksekusi. Di StableHLO, kita merepresentasikan dinamisme terbatas menggunakan #stablehlo.bounds sebagai encoding tensor, yaitu tensor berperingkat 2 dengan satu dimensi dinamis yang dibatasi pada 16 dan dimensi lainnya tanpa batas dapat direpresentasikan sebagai tensor<?x?xf32, #stablehlo.bounds<16, ?>>.

StableHLO dapat merepresentasikan dinamisme terbatas, tetapi dukungan framework-nya terbatas, berasal dari TensorFlow, dan dengan beberapa dukungan di PyTorch/XLA.

Dinamisme tanpa batas

Dinamisme tanpa batas, seperti namanya, mengacu pada dimensi dinamis tanpa batas ukuran yang diketahui. Jenis dinamisme ini sangat umum di StableHLO, dengan dukungan JAX, PyTorch/XLA, dan TF, sering digunakan untuk mengekspor model dengan ukuran batch atau panjang urutan dinamis.

Di StableHLO, kita cukup menghilangkan encoding batas untuk bentuk dinamisme ini, yaitu tensor<?x?xf32>.

Polimorfisme bentuk

Polimorfisme bentuk adalah istilah yang kami warisi dari JAX.

Ada dua implikasi utama untuk membentuk polimorfisme:

  1. Semua dinamisme dalam program dapat dilacak kembali ke argumen inputnya.
  2. Semua dinamisme hanya berkaitan dengan bentuk tensor, yaitu tidak bergantung pada data.

Dengan kedua aturan ini, setelah bentuk statis suatu program diketahui, kita dapat mengambil program dinamis dan sepenuhnya menyempurnakannya menjadi program statis untuk kompilasi (lihat "Compiler passes for refining dynamic programs").

Umumnya, polimorfisme bentuk menggunakan dinamisme tanpa batas. Jika bentuk argumen yang diketahui dapat menghasilkan program yang sepenuhnya statis, tidak perlu menebak cara membatasi nilai.

Dinamisme yang bergantung pada data

Dinamisme yang bergantung pada data mengacu pada ukuran dimensi dinamis yang berkaitan dengan data di dalam tensor. Contoh kanonis adalah fungsi nonzeros yang menampilkan indeks semua elemen yang 0 dalam nilai tensor. Bentuknya tidak dapat diketahui tanpa mengevaluasi data, tetapi sering kali dapat dikompilasi menggunakan dinamisme terbatas, dengan membelanjakan memori tambahan untuk ukuran tensor output potensial.

Banyak operasi dinamis yang bergantung pada data dapat dimodelkan menggunakan dinamisme terbatas, dengan batas atas ukuran tensor ditentukan, dan hardware umumnya akan menerapkan hal ini melalui padding tensor. Saat ini ada beberapa dukungan untuk dinamisme yang bergantung pada data di PyTorch/XLA dan TensorFlow, tetapi JAX saat ini tidak melacak operasi yang menyebabkan dinamisme yang bergantung pada data.

Mengekspor program dengan dimensi dinamis

Lihat tutorial StableHLO kami untuk mengetahui informasi tentang cara mengekspor program dengan ukuran batch atau panjang urutan dinamis:

Compiler meneruskan untuk menyempurnakan program dinamis

Menghapus pipeline kartu dinamisme

Ada beberapa kartu yang berguna untuk menyempurnakan bentuk, dan semuanya dikelompokkan dalam pipeline kartu createStablehloRemoveDynamismPipeline:

void createStablehloRemoveDynamismPipeline(OpPassManager &pm,
                                           TypeRange refinedTypes);

Setiap kartu untuk menyempurnakan dinamisme

Secara terpisah, lintasan yang cenderung berguna untuk penyempurnaan bentuk adalah:

Lihat dokumentasi terkait untuk mengetahui informasi dan contoh terbaru.

Contoh: Apa manfaat dinamis, dan bagaimana cara menggunakannya?

Dinamisme memiliki banyak kegunaan. Di sini, kita akan berfokus pada kasus penggunaan umum untuk Polimorfisme Bentuk - membuat representasi model yang diekspor secara fleksibel, yang umumnya digunakan untuk merepresentasikan ukuran batch dinamis atau panjang urutan.

Model add_one statis

Kita akan menggunakan model add_one sederhana berikut untuk mendemonstrasikannya:

def add_one(x):
  return x + 1

Saat dilacak menggunakan tensor<4xf32>, kita akan mendapatkan program StableHLO berikut:

// File: add_one.mlir
func.func @add_one(%arg0: tensor<4xf32>) -> tensor<4xf32> {
  %cst = stablehlo.constant dense<1.000000e+00> : tensor<4xf32>
  %0 = stablehlo.add %arg0, %cst : tensor<4xf32>
  return %0 : tensor<4xf32>
}

Model ini hanya akan berfungsi untuk argumen input yang memiliki bentuk tensor<4xf32>. Jika kita pernah mengubah ukuran batch atau panjang urutan, kita perlu menelusuri ulang kode sumber dan menurunkan ulang ke StableHLO, dan tidak ada jaminan bahwa kita masih memiliki akses ke kode sumber.

Model add_one dinamis

Di sinilah dinamisme polimorfik bentuk berperan. Sebagai gantinya, JAX dan PyTorch/XLA dapat memancarkan model add_one dengan IR yang valid secara dinamis yang akan menyiarkan konstanta agar sesuai dengan bentuk input dinamis sebagai berikut:

// File: add_one_dynamic.mlir
func.func public @main(%arg0: tensor<?xf32>) -> tensor<?xf32> {
  %cst = stablehlo.constant dense<1.0> : tensor<f32>
  %0 = stablehlo.get_dimension_size %arg0, dim = 0 : (tensor<?xf32>) -> tensor<i32>
  %1 = stablehlo.reshape %0 : (tensor<i32>) -> tensor<1xi32>
  %2 = stablehlo.dynamic_broadcast_in_dim %cst, %1, dims = [] : (tensor<f32>, tensor<1xi32>) -> tensor<?xf32>
  %3 = stablehlo.add %arg0, %2 : tensor<?xf32>
  return %3 : tensor<?xf32>
}

Representasi model ini jauh lebih fleksibel, dan memungkinkan spesifikasi nilai yang ditangguhkan seperti ukuran batch atau panjang urutan. Model ini dapat di-deploy di platform dengan dukungan bentuk dinamis (seperti AI Edge), atau dapat disempurnakan menggunakan lintasan dinamisme yang disebutkan dalam dokumentasi ini.

Memperbaiki model dinamis

Misalnya, pengurutan kartu berikut dapat sepenuhnya menyempurnakan program ini:

stablehlo-opt add_one_dynamic.mlir \
  --stablehlo-refine-arguments='types=tensor<16xf32>' \
  --stablehlo-refine-shapes \
  --stablehlo-canonicalize-dynamism

Secara bertahap, program ini akan diubah sebagai berikut:

// After stablehlo-refine-arguments: Inputs updated, shapes not propagated
func.func public @main(%arg0: tensor<16xf32>) -> tensor<?xf32> {
  %c = stablehlo.constant dense<16> : tensor<1xi64>
  %0 = stablehlo.custom_call @stablehlo.shape_refinement_operand_wrapper(%arg0, %c) {indices_of_shape_operands = dense<1> : tensor<1xi64>} : (tensor<16xf32>, tensor<1xi64>) -> tensor<?xf32>
  ...
  %3 = stablehlo.dynamic_broadcast_in_dim %cst, %2, dims = [] : (tensor<f32>, tensor<1xi32>) -> tensor<?xf32>
  %4 = stablehlo.add %0, %3 : tensor<?xf32>
  return %4 : tensor<?xf32>
}

// After stablehlo-refine-shapes: Shapes propagated, dynamic ops still exist
func.func public @main(%arg0: tensor<16xf32>) -> tensor<16xf32> {
  %cst = stablehlo.constant dense<1.000000e+00> : tensor<f32>
  %c = stablehlo.constant dense<16> : tensor<1xi32>
  %0 = stablehlo.dynamic_broadcast_in_dim %cst, %c, dims = [] : (tensor<f32>, tensor<1xi32>) -> tensor<16xf32>
  %1 = stablehlo.add %arg0, %0 : tensor<16xf32>
  return %1 : tensor<16xf32>
}

// After stablehlo-canonicalize-dynamism: Dynamic ops replaced with static ops
func.func public @main(%arg0: tensor<16xf32>) -> tensor<16xf32> {
  %cst = stablehlo.constant dense<1.000000e+00> : tensor<f32>
  %0 = stablehlo.broadcast_in_dim %cst, dims = [] : (tensor<f32>) -> tensor<16xf32>
  %1 = stablehlo.add %arg0, %0 : tensor<16xf32>
  return %1 : tensor<16xf32>
}

// (Bonus) Use ` --stablehlo-aggressive-simplification` pass to canonicalize the
// constant broadcast, leaving us with the original static program in this case.
func.func public @main(%arg0: tensor<16xf32>) -> tensor<16xf32> {
  %cst = stablehlo.constant dense<1.000000e+00> : tensor<16xf32>
  %0 = stablehlo.add %arg0, %cst : tensor<16xf32>
  return %0 : tensor<16xf32>
}