StableHLO 中的動態性

動態思想的目前狀態在 Dynamism RFC 中更正式地說明,此頁面將提供 RFC 的概要總覽,並討論與動態程式互動的重要 API 和工具。

Dynamism 用語與支援總覽

首先,我們將說明本文件中會出現的幾個詞彙,以及 StableHLO 支援這些詞彙的簡要介紹:

動態維度

動態維度是指尺寸大小不明的任何維度。在 StableHLO 中,我們會使用 ? 表示動態維度,也就是 tensor<16x?xf32>

有限動態

受限的動態維度是指值具有已知上限的動態維度。這通常可用於在執行期間填充張量。在 StableHLO 中,我們使用 #stablehlo.bounds 做為張量編碼,表示受限的動態性,也就是等級 2 張量,其中一個動態維度的上限為 16,另一個沒有上限的維度可表示為 tensor<?x?xf32, #stablehlo.bounds<16, ?>>

StableHLO 能夠代表受限的變數,但 TensorFlow 衍生的架構有限,且某些 PyTorch/XLA 支援有限。

無限動態

顧名思義,無限動態指的是沒有已知大小上限的動態維度。這種命名型態在 StableHLO 中非常常見,而 JAX、PyTorch/XLA 和 TF 支援通常用於匯出具有動態批量或序列長度的模型。

在 StableHLO 中,我們只會採用這種型態的邊界編碼,即 tensor<?x?xf32>

形狀多態性

形狀多態性是從 JAX 繼承的術語

形塑多態性的兩個重要意涵如下:

  1. 程式中的所有變數都會追溯回其輸入引數。
  2. 所有動態功能都只與張量的形狀相關,也就是不依賴資料。

運用這兩項規則後,一旦知道程式的靜態形狀,我們就能採用動態程式,並將其完整修正為靜態程式編譯 (請參閱「修正動態程式的編譯器票證」一節)。

一般來說,形狀多態性會使用無限的動態性,如果已知的引數形狀可導致完全靜態的程式,就不需要猜測如何繫結值。

資料相關動態

資料相關動態性是指與張量內資料相關的動態維度大小。標準範例是 nonzeros 函式,可傳回張量值中所有 0 元素的索引。必須評估資料才能知道形狀,但通常可以使用受限的變數進行編譯,並將額外記憶體用於潛在的輸出張量大小。

許多資料相關的動態運算可使用受限動態功能進行建模,其中會指定張量大小的上限,而硬體通常會透過張量填充來實作這項功能。目前,PyTorch/XLA 和 TensorFlow 支援部分資料相關動態功能,但 JAX 目前不會追蹤導致資料相關動態功能的運算。

匯出含有動態維度的節目

請參閱 StableHLO 教學課程,以瞭解如何匯出具有動態批量或序列長度的程式:

用來改良動態程式的編譯器憑證

移除動態傳遞管道

有幾個實用的傳遞可用於精修形狀,而且這些傳遞都方便地整合在傳遞管道 createStablehloRemoveDynamismPipeline 中:

void createStablehloRemoveDynamismPipeline(OpPassManager &pm,
                                           TypeRange refinedTypes);

個別票證可用於精進動態功能

個別來說,以下是形狀精修時常用的處理程序:

如需最新資訊和範例,請參閱連結的說明文件。

範例:動態化有什麼用處?我該如何使用?

動態性有許多用途,我們在此主要著重於形狀多態性的常見用途,也就是建立靈活的匯出模型表示法,通常用於表示動態批次大小或序列長度。

靜態 add_one 模型

我們將使用以下簡單的 add_one 模型來說明這項功能:

def add_one(x):
  return x + 1

使用 tensor<4xf32> 進行追蹤時,我們會取得下列 StableHLO 程式:

// File: add_one.mlir
func.func @add_one(%arg0: tensor<4xf32>) -> tensor<4xf32> {
  %cst = stablehlo.constant dense<1.000000e+00> : tensor<4xf32>
  %0 = stablehlo.add %arg0, %cst : tensor<4xf32>
  return %0 : tensor<4xf32>
}

這個模型「只」適用於具有 tensor<4xf32> 形狀的輸入引數。如果我們變更了批次大小或序列長度,就必須重新追蹤原始程式碼,並重新降低至 StableHLO,但我們無法保證能繼續存取原始程式碼!

動態 add_one 模型

這時形狀多態性動態處理機制就派上用場。相反地,JAX 和 PyTorch/XLA 可以使用動態有效的 IR 傳送 add_one 模型,並廣播常數以符合動態輸入形狀,如下所示:

// File: add_one_dynamic.mlir
func.func public @main(%arg0: tensor<?xf32>) -> tensor<?xf32> {
  %cst = stablehlo.constant dense<1.0> : tensor<f32>
  %0 = stablehlo.get_dimension_size %arg0, dim = 0 : (tensor<?xf32>) -> tensor<i32>
  %1 = stablehlo.reshape %0 : (tensor<i32>) -> tensor<1xi32>
  %2 = stablehlo.dynamic_broadcast_in_dim %cst, %1, dims = [] : (tensor<f32>, tensor<1xi32>) -> tensor<?xf32>
  %3 = stablehlo.add %arg0, %2 : tensor<?xf32>
  return %3 : tensor<?xf32>
}

這個模型表示法更具彈性,可讓您延後指定批次大小或序列長度等值。這個模型可部署在支援動態形狀的平台 (例如 AI Edge),也可以使用本文件中提及的動態處理程序進行精進。

精進動態模型

舉例來說,下列傳遞順序可充分改善這個程式:

stablehlo-opt add_one_dynamic.mlir \
  --stablehlo-refine-arguments='types=tensor<16xf32>' \
  --stablehlo-refine-shapes \
  --stablehlo-canonicalize-dynamism

逐步轉換程式如下:

// After stablehlo-refine-arguments: Inputs updated, shapes not propagated
func.func public @main(%arg0: tensor<16xf32>) -> tensor<?xf32> {
  %c = stablehlo.constant dense<16> : tensor<1xi64>
  %0 = stablehlo.custom_call @stablehlo.shape_refinement_operand_wrapper(%arg0, %c) {indices_of_shape_operands = dense<1> : tensor<1xi64>} : (tensor<16xf32>, tensor<1xi64>) -> tensor<?xf32>
  ...
  %3 = stablehlo.dynamic_broadcast_in_dim %cst, %2, dims = [] : (tensor<f32>, tensor<1xi32>) -> tensor<?xf32>
  %4 = stablehlo.add %0, %3 : tensor<?xf32>
  return %4 : tensor<?xf32>
}

// After stablehlo-refine-shapes: Shapes propagated, dynamic ops still exist
func.func public @main(%arg0: tensor<16xf32>) -> tensor<16xf32> {
  %cst = stablehlo.constant dense<1.000000e+00> : tensor<f32>
  %c = stablehlo.constant dense<16> : tensor<1xi32>
  %0 = stablehlo.dynamic_broadcast_in_dim %cst, %c, dims = [] : (tensor<f32>, tensor<1xi32>) -> tensor<16xf32>
  %1 = stablehlo.add %arg0, %0 : tensor<16xf32>
  return %1 : tensor<16xf32>
}

// After stablehlo-canonicalize-dynamism: Dynamic ops replaced with static ops
func.func public @main(%arg0: tensor<16xf32>) -> tensor<16xf32> {
  %cst = stablehlo.constant dense<1.000000e+00> : tensor<f32>
  %0 = stablehlo.broadcast_in_dim %cst, dims = [] : (tensor<f32>) -> tensor<16xf32>
  %1 = stablehlo.add %arg0, %0 : tensor<16xf32>
  return %1 : tensor<16xf32>
}

// (Bonus) Use ` --stablehlo-aggressive-simplification` pass to canonicalize the
// constant broadcast, leaving us with the original static program in this case.
func.func public @main(%arg0: tensor<16xf32>) -> tensor<16xf32> {
  %cst = stablehlo.constant dense<1.000000e+00> : tensor<16xf32>
  %0 = stablehlo.add %arg0, %cst : tensor<16xf32>
  return %0 : tensor<16xf32>
}