StableHLO 中的動態性

如要瞭解動態性的現況,請參閱動態性 RFC。本頁面將概要說明 RFC,並討論與動態程式互動的重要 API 和工具。

動態性術語和支援總覽

首先,我們將介紹這份文件中的幾個詞彙,並簡要說明 StableHLO 對這些詞彙的支援情形:

動態維度

動態尺寸是指尺寸大小不明的任何尺寸。 在 StableHLO 中,我們使用 ? 代表動態維度,也就是 tensor<16x?xf32>

有限動態

有界動態是指值有已知上限的動態維度。一般來說,這有助於在執行期間填補張量。 在 StableHLO 中,我們使用 #stablehlo.bounds 做為張量編碼來表示有界動態,也就是說,一個動態維度以 16 為界,另一個沒有界限的等級 2 張量可以表示為 tensor<?x?xf32, #stablehlo.bounds<16, ?>>

StableHLO 可以表示有界動態,但架構支援有限,源自 TensorFlow,且在 PyTorch/XLA 中提供部分支援。

不受限的動態性

如名稱所示,無界動態是指大小沒有已知界限的動態維度。這類動態性在 StableHLO 中非常常見,且支援 JAX、PyTorch/XLA 和 TF,通常用於匯出具有動態批次大小或序列長度的模型。

在 StableHLO 中,我們只會省略這種動態形式的界線編碼,也就是 tensor<?x?xf32>

形狀多態

形狀多態性是我們從 JAX 沿用的術語

形狀多型有兩項重要影響:

  1. 程式中的所有動態性都可追溯至輸入引數。
  2. 所有動態性都只與張量形狀有關,也就是說,與資料無關。

有了這兩項規則,一旦程式的靜態形狀已知,我們就能取得動態程式,並將其完全精簡為編譯用的靜態程式 (請參閱「Compiler passes for refining dynamic programs」)。

一般來說,形狀多型會使用無界動態性,如果已知引數形狀可產生完全靜態的程式,就不需要猜測如何繫結值。

取決於資料的動態性

資料相關動態是指與張量內資料相關的動態維度大小。標準範例是 nonzeros 函式,該函式會傳回張量值中所有 0 元素的索引。評估資料後才能得知形狀,但通常可使用有界動態性編譯,在潛在輸出張量大小上耗用額外記憶體。

許多與資料相關的動態作業都可以使用有界動態機制來模擬,也就是指定張量大小上限,而硬體通常會透過張量填補來實作這項機制。目前 PyTorch/XLA 和 TensorFlow 支援部分資料相關動態功能,但 JAX 目前不會追蹤導致資料相關動態功能的作業。

匯出含有動態維度的節目

如要瞭解如何匯出具有動態批次大小或序列長度的程式,請參閱 StableHLO 教學課程:

編譯器會傳遞內容,以精簡動態程式

移除動態範圍傳遞管道

有幾個實用的通道可供調整形狀,方便的是,這些通道都已整合到通道管道 createStablehloRemoveDynamismPipeline 中:

void createStablehloRemoveDynamismPipeline(OpPassManager &pm,
                                           TypeRange refinedTypes);

個別通道,用於調整動態

就個別而言,有助於形狀細修的通道包括:

如需最新資訊和範例,請參閱連結的說明文件。

範例:動態功能有何用處?如何使用?

動態性有許多用途,這裡主要著重於形狀多型性的常見用途,也就是建立彈性匯出模型表示法,通常用於表示動態批次大小或序列長度。

靜態 add_one 模型

我們會使用下列簡單的 add_one 模型來示範:

def add_one(x):
  return x + 1

使用 tensor<4xf32> 追蹤時,我們會取得下列 StableHLO 程式:

// File: add_one.mlir
func.func @add_one(%arg0: tensor<4xf32>) -> tensor<4xf32> {
  %cst = stablehlo.constant dense<1.000000e+00> : tensor<4xf32>
  %0 = stablehlo.add %arg0, %cst : tensor<4xf32>
  return %0 : tensor<4xf32>
}

這個模型適用於具有 tensor<4xf32> 形狀的輸入引數。如果我們變更批次大小或序列長度,就必須重新追蹤原始碼並重新降級至 StableHLO,而且我們不保證仍可存取原始碼!

動態 add_one 模型

這時形狀多型動態功能就能派上用場。反之,JAX 和 PyTorch/XLA 可以發出 add_one 模型,其中包含動態有效的 IR,可廣播常數來配合動態輸入形狀,如下所示:

// File: add_one_dynamic.mlir
func.func public @main(%arg0: tensor<?xf32>) -> tensor<?xf32> {
  %cst = stablehlo.constant dense<1.0> : tensor<f32>
  %0 = stablehlo.get_dimension_size %arg0, dim = 0 : (tensor<?xf32>) -> tensor<i32>
  %1 = stablehlo.reshape %0 : (tensor<i32>) -> tensor<1xi32>
  %2 = stablehlo.dynamic_broadcast_in_dim %cst, %1, dims = [] : (tensor<f32>, tensor<1xi32>) -> tensor<?xf32>
  %3 = stablehlo.add %arg0, %2 : tensor<?xf32>
  return %3 : tensor<?xf32>
}

這個模型表示法彈性更高,可延後指定批次大小或序列長度等值。這個模型可部署在支援動態形狀的平台 (例如 AI Edge),也可以使用本說明文件提及的動態傳遞進行調整。

微調動態模型

舉例來說,下列傳遞順序可完全精簡這個程式:

stablehlo-opt add_one_dynamic.mlir \
  --stablehlo-refine-arguments='types=tensor<16xf32>' \
  --stablehlo-refine-shapes \
  --stablehlo-canonicalize-dynamism

程式會逐步轉換,如下所示:

// After stablehlo-refine-arguments: Inputs updated, shapes not propagated
func.func public @main(%arg0: tensor<16xf32>) -> tensor<?xf32> {
  %c = stablehlo.constant dense<16> : tensor<1xi64>
  %0 = stablehlo.custom_call @stablehlo.shape_refinement_operand_wrapper(%arg0, %c) {indices_of_shape_operands = dense<1> : tensor<1xi64>} : (tensor<16xf32>, tensor<1xi64>) -> tensor<?xf32>
  ...
  %3 = stablehlo.dynamic_broadcast_in_dim %cst, %2, dims = [] : (tensor<f32>, tensor<1xi32>) -> tensor<?xf32>
  %4 = stablehlo.add %0, %3 : tensor<?xf32>
  return %4 : tensor<?xf32>
}

// After stablehlo-refine-shapes: Shapes propagated, dynamic ops still exist
func.func public @main(%arg0: tensor<16xf32>) -> tensor<16xf32> {
  %cst = stablehlo.constant dense<1.000000e+00> : tensor<f32>
  %c = stablehlo.constant dense<16> : tensor<1xi32>
  %0 = stablehlo.dynamic_broadcast_in_dim %cst, %c, dims = [] : (tensor<f32>, tensor<1xi32>) -> tensor<16xf32>
  %1 = stablehlo.add %arg0, %0 : tensor<16xf32>
  return %1 : tensor<16xf32>
}

// After stablehlo-canonicalize-dynamism: Dynamic ops replaced with static ops
func.func public @main(%arg0: tensor<16xf32>) -> tensor<16xf32> {
  %cst = stablehlo.constant dense<1.000000e+00> : tensor<f32>
  %0 = stablehlo.broadcast_in_dim %cst, dims = [] : (tensor<f32>) -> tensor<16xf32>
  %1 = stablehlo.add %arg0, %0 : tensor<16xf32>
  return %1 : tensor<16xf32>
}

// (Bonus) Use ` --stablehlo-aggressive-simplification` pass to canonicalize the
// constant broadcast, leaving us with the original static program in this case.
func.func public @main(%arg0: tensor<16xf32>) -> tensor<16xf32> {
  %cst = stablehlo.constant dense<1.000000e+00> : tensor<16xf32>
  %0 = stablehlo.add %arg0, %cst : tensor<16xf32>
  return %0 : tensor<16xf32>
}