StableHLO でのダイナミズム

動的処理の現在の状態は、Dynamism RFC でより正式に説明されています。このページでは、RFC の概要と、動的プログラムを操作するための重要な API とツールについて説明します。

Dynamism の用語とサポートの概要

まず、このドキュメントで使用するいくつかの用語と、StableHLO でのサポートについて簡単に説明します。

動的な次元

動的ディメンションとは、ディメンションのサイズが不明なディメンションのことです。StableHLO では、動的ディメンションを ?tensor<16x?xf32> など)で表します。

制限付きダイナミズム

制限付きダイナミズムとは、値の上限が既知のダイナミック ディメンションのことです。これは通常、実行時にテンソルのパディングに役立ちます。StableHLO では、制限付きの動的性質をテンソル エンコードとして #stablehlo.bounds を使用して表します。つまり、1 つの動的ディメンションが 16 に制限され、もう 1 つの動的ディメンションが制限なしのランク 2 テンソルは tensor<?x?xf32, #stablehlo.bounds<16, ?>> として表すことができます。

StableHLO は制限付きの動的性を表すことができますが、フレームワークのサポートは限定的です。TensorFlow で始まり、PyTorch/XLA で一部サポートされています。

無限のダイナミズム

無制限のダイナミズムは、名前が示すように、サイズの上限が不明な動的ディメンションを指します。このタイプのダイナミズムは、JAX、PyTorch/XLA、TF をサポートする StableHLO で非常に一般的であり、動的バッチサイズまたはシーケンス長のモデルのエクスポートによく使用されます。

StableHLO では、この形式の動的性(tensor<?x?xf32>)の境界エンコードを単に省略します。

シェイプのポリモーフィズム

シェイプ ポリモーフィズムは、JAX から継承した用語です。

シェイプ ポリモーフィズムには、次の 2 つの重要な影響があります。

  1. プログラムの動的な部分はすべて、入力引数に起因します。
  2. すべてのダイナミズムはテンソルの形状のみに関連し、データに依存しません。

これらの 2 つのルールにより、プログラムの静的シェイプが判明すると、動的プログラムを取得して、コンパイルの静的プログラムに完全に精製できます(「動的プログラムの精製のためのコンパイラ パス」をご覧ください)。

通常、形状ポリモーフィズムでは無制限の動的性が使用されます。既知の引数の形状から完全に静的なプログラムを作成できる場合は、値をバウンディングする方法について推測する必要はありません。

データ依存のダイナミズム

データ依存のダイナミズムとは、テンソル内のデータに関連する動的ディメンション サイズを指します。標準的な例は、テンソル値で 0 であるすべての要素のインデックスを返す nonzeros 関数です。データを評価しないと形状はわかりませんが、多くの場合、制限付きの動的性を使用してコンパイルでき、潜在的な出力テンソルのサイズに追加のメモリを使用できます。

データ依存の動的オペレーションの多くは、制限付きの動的性を使用してモデル化できます。この場合、テンソルサイズの上限が指定され、ハードウェアは通常、テンソル パディングを使用してこれを実装します。現在、PyTorch/XLA と TensorFlow ではデータ依存のダイナミズムがサポートされていますが、JAX は現在、データ依存のダイナミズムにつながるオペレーションをトレースしていません。

動的ディメンションを含むプログラムをエクスポートする

動的なバッチサイズまたはシーケンス長を使用してプログラムをエクスポートする方法については、StableHLO のチュートリアルをご覧ください。

動的プログラムを改良するためのコンパイラ パス

ダイナミズム パス パイプラインを削除

シェイプの調整に役立つパスがいくつかあります。これらはすべて、パス パイプライン createStablehloRemoveDynamismPipeline にバンドルされています。

void createStablehloRemoveDynamismPipeline(OpPassManager &pm,
                                           TypeRange refinedTypes);

ダイナミズムを調整するための個別のパス

個別のパスは、シェイプの改良に役立つ傾向があります。

最新情報と例については、リンク先のドキュメントをご覧ください。

例: ダイナミズムの利点と使い方

ダイナミズムには多くの用途がありますが、ここでは主にシェイプ ポリモーフィズムの一般的なユースケース(柔軟なエクスポートされたモデル表現の作成)に焦点を当てます。これは通常、動的バッチサイズまたはシーケンス長の表現に使用されます。

静的 add_one モデル

これを説明するために、次のシンプルな add_one モデルを使用します。

def add_one(x):
  return x + 1

tensor<4xf32> を使用してトレースすると、次の StableHLO プログラムが生成されます。

// File: add_one.mlir
func.func @add_one(%arg0: tensor<4xf32>) -> tensor<4xf32> {
  %cst = stablehlo.constant dense<1.000000e+00> : tensor<4xf32>
  %0 = stablehlo.add %arg0, %cst : tensor<4xf32>
  return %0 : tensor<4xf32>
}

このモデルは、tensor<4xf32> シェイプの入力引数でのみ機能します。バッチサイズやシーケンス長を変更した場合は、ソースコードを再トレースして StableHLO に再変換する必要があります。また、ソースコードにアクセスできる保証もありません。

動的 add_one モデル

ここで、形状多型の動的性が役に立ちます。代わりに、JAX と PyTorch/XLA は、動的に有効な IR を使用して add_one モデルを出力できます。このモデルは、次のように動的な入力シェイプに合わせて定数をブロードキャストします。

// File: add_one_dynamic.mlir
func.func public @main(%arg0: tensor<?xf32>) -> tensor<?xf32> {
  %cst = stablehlo.constant dense<1.0> : tensor<f32>
  %0 = stablehlo.get_dimension_size %arg0, dim = 0 : (tensor<?xf32>) -> tensor<i32>
  %1 = stablehlo.reshape %0 : (tensor<i32>) -> tensor<1xi32>
  %2 = stablehlo.dynamic_broadcast_in_dim %cst, %1, dims = [] : (tensor<f32>, tensor<1xi32>) -> tensor<?xf32>
  %3 = stablehlo.add %arg0, %2 : tensor<?xf32>
  return %3 : tensor<?xf32>
}

このモデル表現ははるかに柔軟で、バッチサイズやシーケンス長などの値を後で指定できます。このモデルは、動的シェイプをサポートするプラットフォーム(AI Edge など)にデプロイできます。また、このドキュメントで説明するダイナミズム パスを使用して、モデルを改良することもできます。

動的モデルの改良

たとえば、次のパス順序でこのプログラムを完全に洗練できます。

stablehlo-opt add_one_dynamic.mlir \
  --stablehlo-refine-arguments='types=tensor<16xf32>' \
  --stablehlo-refine-shapes \
  --stablehlo-canonicalize-dynamism

プログラムは、以下のように段階的に変換されます。

// After stablehlo-refine-arguments: Inputs updated, shapes not propagated
func.func public @main(%arg0: tensor<16xf32>) -> tensor<?xf32> {
  %c = stablehlo.constant dense<16> : tensor<1xi64>
  %0 = stablehlo.custom_call @stablehlo.shape_refinement_operand_wrapper(%arg0, %c) {indices_of_shape_operands = dense<1> : tensor<1xi64>} : (tensor<16xf32>, tensor<1xi64>) -> tensor<?xf32>
  ...
  %3 = stablehlo.dynamic_broadcast_in_dim %cst, %2, dims = [] : (tensor<f32>, tensor<1xi32>) -> tensor<?xf32>
  %4 = stablehlo.add %0, %3 : tensor<?xf32>
  return %4 : tensor<?xf32>
}

// After stablehlo-refine-shapes: Shapes propagated, dynamic ops still exist
func.func public @main(%arg0: tensor<16xf32>) -> tensor<16xf32> {
  %cst = stablehlo.constant dense<1.000000e+00> : tensor<f32>
  %c = stablehlo.constant dense<16> : tensor<1xi32>
  %0 = stablehlo.dynamic_broadcast_in_dim %cst, %c, dims = [] : (tensor<f32>, tensor<1xi32>) -> tensor<16xf32>
  %1 = stablehlo.add %arg0, %0 : tensor<16xf32>
  return %1 : tensor<16xf32>
}

// After stablehlo-canonicalize-dynamism: Dynamic ops replaced with static ops
func.func public @main(%arg0: tensor<16xf32>) -> tensor<16xf32> {
  %cst = stablehlo.constant dense<1.000000e+00> : tensor<f32>
  %0 = stablehlo.broadcast_in_dim %cst, dims = [] : (tensor<f32>) -> tensor<16xf32>
  %1 = stablehlo.add %arg0, %0 : tensor<16xf32>
  return %1 : tensor<16xf32>
}

// (Bonus) Use ` --stablehlo-aggressive-simplification` pass to canonicalize the
// constant broadcast, leaving us with the original static program in this case.
func.func public @main(%arg0: tensor<16xf32>) -> tensor<16xf32> {
  %cst = stablehlo.constant dense<1.000000e+00> : tensor<16xf32>
  %0 = stablehlo.add %arg0, %cst : tensor<16xf32>
  return %0 : tensor<16xf32>
}