如要瞭解動態性的現況,請參閱動態性 RFC。本頁面將概要說明 RFC,並討論與動態程式互動的重要 API 和工具。
動態性術語和支援總覽
首先,我們將介紹這份文件中的幾個詞彙,並簡要說明 StableHLO 對這些詞彙的支援情形:
動態維度
動態尺寸是指尺寸大小不明的任何尺寸。
在 StableHLO 中,我們使用 ? 代表動態維度,也就是 tensor<16x?xf32>。
有限動態
有界動態是指值有已知上限的動態維度。一般來說,這有助於在執行期間填補張量。
在 StableHLO 中,我們使用 #stablehlo.bounds 做為張量編碼來表示有界動態,也就是說,一個動態維度以 16 為界,另一個沒有界限的等級 2 張量可以表示為 tensor<?x?xf32, #stablehlo.bounds<16, ?>>。
StableHLO 可以表示有界動態,但架構支援有限,源自 TensorFlow,且在 PyTorch/XLA 中提供部分支援。
不受限的動態性
如名稱所示,無界動態是指大小沒有已知界限的動態維度。這類動態性在 StableHLO 中非常常見,且支援 JAX、PyTorch/XLA 和 TF,通常用於匯出具有動態批次大小或序列長度的模型。
在 StableHLO 中,我們只會省略這種動態形式的界線編碼,也就是 tensor<?x?xf32>。
形狀多態
形狀多態性是我們從 JAX 沿用的術語。
形狀多型有兩項重要影響:
- 程式中的所有動態性都可追溯至輸入引數。
- 所有動態性都只與張量形狀有關,也就是說,與資料無關。
有了這兩項規則,一旦程式的靜態形狀已知,我們就能取得動態程式,並將其完全精簡為編譯用的靜態程式 (請參閱「Compiler passes for refining dynamic programs」)。
一般來說,形狀多型會使用無界動態性,如果已知引數形狀可產生完全靜態的程式,就不需要猜測如何繫結值。
取決於資料的動態性
資料相關動態是指與張量內資料相關的動態維度大小。標準範例是 nonzeros 函式,該函式會傳回張量值中所有 0 元素的索引。評估資料後才能得知形狀,但通常可使用有界動態性編譯,在潛在輸出張量大小上耗用額外記憶體。
許多與資料相關的動態作業都可以使用有界動態機制來模擬,也就是指定張量大小上限,而硬體通常會透過張量填補來實作這項機制。目前 PyTorch/XLA 和 TensorFlow 支援部分資料相關動態功能,但 JAX 目前不會追蹤導致資料相關動態功能的作業。
匯出含有動態維度的節目
如要瞭解如何匯出具有動態批次大小或序列長度的程式,請參閱 StableHLO 教學課程:
編譯器會傳遞內容,以精簡動態程式
移除動態範圍傳遞管道
有幾個實用的通道可供調整形狀,方便的是,這些通道都已整合到通道管道 createStablehloRemoveDynamismPipeline 中:
void createStablehloRemoveDynamismPipeline(OpPassManager &pm,
TypeRange refinedTypes);
個別通道,用於調整動態
就個別而言,有助於形狀細修的通道包括:
stablehlo-refine-arguments,以具體張量型別取代輸入引數。stablehlo-refine-shapes,將新的輸入引數形狀資訊傳播至整個程式。stablehlo-canonicalize-dynamism,將動態作業替換為靜態變體。stablehlo-check-shape-assertions,檢查及移除形狀斷言自訂呼叫。
如需最新資訊和範例,請參閱連結的說明文件。
範例:動態功能有何用處?如何使用?
動態性有許多用途,這裡主要著重於形狀多型性的常見用途,也就是建立彈性匯出模型表示法,通常用於表示動態批次大小或序列長度。
靜態 add_one 模型
我們會使用下列簡單的 add_one 模型來示範:
def add_one(x):
return x + 1
使用 tensor<4xf32> 追蹤時,我們會取得下列 StableHLO 程式:
// File: add_one.mlir
func.func @add_one(%arg0: tensor<4xf32>) -> tensor<4xf32> {
%cst = stablehlo.constant dense<1.000000e+00> : tensor<4xf32>
%0 = stablehlo.add %arg0, %cst : tensor<4xf32>
return %0 : tensor<4xf32>
}
這個模型僅適用於具有 tensor<4xf32> 形狀的輸入引數。如果我們變更批次大小或序列長度,就必須重新追蹤原始碼並重新降級至 StableHLO,而且我們不保證仍可存取原始碼!
動態 add_one 模型
這時形狀多型動態功能就能派上用場。反之,JAX 和 PyTorch/XLA 可以發出 add_one 模型,其中包含動態有效的 IR,可廣播常數來配合動態輸入形狀,如下所示:
// File: add_one_dynamic.mlir
func.func public @main(%arg0: tensor<?xf32>) -> tensor<?xf32> {
%cst = stablehlo.constant dense<1.0> : tensor<f32>
%0 = stablehlo.get_dimension_size %arg0, dim = 0 : (tensor<?xf32>) -> tensor<i32>
%1 = stablehlo.reshape %0 : (tensor<i32>) -> tensor<1xi32>
%2 = stablehlo.dynamic_broadcast_in_dim %cst, %1, dims = [] : (tensor<f32>, tensor<1xi32>) -> tensor<?xf32>
%3 = stablehlo.add %arg0, %2 : tensor<?xf32>
return %3 : tensor<?xf32>
}
這個模型表示法彈性更高,可延後指定批次大小或序列長度等值。這個模型可部署在支援動態形狀的平台 (例如 AI Edge),也可以使用本說明文件提及的動態傳遞進行調整。
微調動態模型
舉例來說,下列傳遞順序可完全精簡這個程式:
stablehlo-opt add_one_dynamic.mlir \
--stablehlo-refine-arguments='types=tensor<16xf32>' \
--stablehlo-refine-shapes \
--stablehlo-canonicalize-dynamism
程式會逐步轉換,如下所示:
// After stablehlo-refine-arguments: Inputs updated, shapes not propagated
func.func public @main(%arg0: tensor<16xf32>) -> tensor<?xf32> {
%c = stablehlo.constant dense<16> : tensor<1xi64>
%0 = stablehlo.custom_call @stablehlo.shape_refinement_operand_wrapper(%arg0, %c) {indices_of_shape_operands = dense<1> : tensor<1xi64>} : (tensor<16xf32>, tensor<1xi64>) -> tensor<?xf32>
...
%3 = stablehlo.dynamic_broadcast_in_dim %cst, %2, dims = [] : (tensor<f32>, tensor<1xi32>) -> tensor<?xf32>
%4 = stablehlo.add %0, %3 : tensor<?xf32>
return %4 : tensor<?xf32>
}
// After stablehlo-refine-shapes: Shapes propagated, dynamic ops still exist
func.func public @main(%arg0: tensor<16xf32>) -> tensor<16xf32> {
%cst = stablehlo.constant dense<1.000000e+00> : tensor<f32>
%c = stablehlo.constant dense<16> : tensor<1xi32>
%0 = stablehlo.dynamic_broadcast_in_dim %cst, %c, dims = [] : (tensor<f32>, tensor<1xi32>) -> tensor<16xf32>
%1 = stablehlo.add %arg0, %0 : tensor<16xf32>
return %1 : tensor<16xf32>
}
// After stablehlo-canonicalize-dynamism: Dynamic ops replaced with static ops
func.func public @main(%arg0: tensor<16xf32>) -> tensor<16xf32> {
%cst = stablehlo.constant dense<1.000000e+00> : tensor<f32>
%0 = stablehlo.broadcast_in_dim %cst, dims = [] : (tensor<f32>) -> tensor<16xf32>
%1 = stablehlo.add %arg0, %0 : tensor<16xf32>
return %1 : tensor<16xf32>
}
// (Bonus) Use ` --stablehlo-aggressive-simplification` pass to canonicalize the
// constant broadcast, leaving us with the original static program in this case.
func.func public @main(%arg0: tensor<16xf32>) -> tensor<16xf32> {
%cst = stablehlo.constant dense<1.000000e+00> : tensor<16xf32>
%0 = stablehlo.add %arg0, %cst : tensor<16xf32>
return %0 : tensor<16xf32>
}