翻譯模式

資料模型

StableHLO 程式是對張量 (n 維陣列) 進行的計算。在目前模型中,這些張量會使用類別 Tensor 來實作。Tensor 物件的基礎儲存空間級別 detail::Buffer 會儲存張量的 mlir::ShapedType 以及 mlir::HeapAsmResourceBlob 物件,該物件代表張量資料的可變動 blob,並以從最大到最小的順序排列為連續位元組陣列。detail::Buffer 物件會參照數量,以簡化記憶體管理。

張量的個別元素會使用 Element 類別來表示,該類別會使用擁有 APIntAPFloatpair<APFloat,APFloat> 之一的獨立聯集做為儲存空間。最後一種用於儲存複雜類型的元素。

Tensor 具有下列 API,可以與其個別元素互動:

  • Element Tensor::get(llvm::ArrayRef<int64_t> index):將多維度索引 index 中的個別張量元素擷取為 Element 物件。
  • void Tensor::set(llvm::ArrayRef<int64_t> index, Element element);:將 Element 物件 element 更新為多維度索引 index 中的張量。

翻譯模式的運作方式

解譯器的進入函式如下:

SmallVector<Tensor> eval(func::FuncOp func, ArrayRef<Tensor> args);

這會執行以下動作:

  1. 使用符號表對應 M.,追蹤 func 的 SSA 引數及其相關聯的執行階段 Tensor 值 (位於 args 提供)。
  2. 針對 func 內的每個運算,以 SSACFG 順序進行:
    • 在運算上叫用 eval。針對運算的每個 SSA 運算元,請從 M 擷取其執行階段值,以做為引數提供給 eval 叫用。
    • 追蹤該運算的 SSA 結果,並以 M 為單位的評估值。

(2) 中提及的運算層級 eval 負責實作運算的執行語意。以下為 stablehlo::AddOp 的範例。在範例中,lhsrhs 張量的個別元素會配對擷取為 Element 物件,然後新增這些元素。新增的結果 (Element 物件) 會儲存在最後的 result 張量中。

Tensor eval(AddOp op, const Tensor &lhs, const Tensor &rhs) {
  Tensor result(op.getType());

  for (auto it = result.index_begin(); it != result.index_end(); ++it)
    result.set(*it, lhs.get(*it) + rhs.get(*it));

  return result;
}

整體而言,解譯器的設計適合用於個別作業的 eval 函式實作易讀性,因為其設計旨在做為 StableHLO 的參考實作。例如,我們不會將 eval 定義為範本函式並將其參數化為元素類型,而是封裝在 Element::operator+ 中處理不同元素類型的方式等詳細資料,從而簡化 eval 的實作。

使用解譯器處理固定折疊

我們可以使用解譯器機制,以常數運算元值折疊運算。以下程式碼片段說明使用浮點型運算元折疊 stablehlo::AddOp 實作:

OpFoldResult AddOp::fold(FoldAdaptor adaptor) {
  auto attrs = adaptor.getOperands();
  DenseElementsAttr lhsData = attrs[0].dyn_cast<DenseElementsAttr>();
  DenseElementsAttr rhsData = attrs[1].dyn_cast<DenseElementsAttr>();
  if (!lhsData || !rhsData) return {};

  auto lhs = Tensor(lhsData);
  auto rhs = Tensor(rhsData);
  auto result = eval(*this, lhs, rhs);

  SmallVector<APFloat> values;
  for (auto i = 0; i < result.getNumElements(); ++i) {
    Element element = result.get(i);
    values.push_back(element.getValue().cast<FloatAttr>().getValue());
  }

  return DenseElementsAttr::get(result.getType(), values);
}

我們目前並未積極將解譯器整合至持續折疊中,因為我們並沒有打算為 StableHLO 實作資料夾。但我們未來計劃利用解譯器處理 MHLO 中持續折疊的情形,屆時將會改善上述程式碼片段的人體工學 (例如,我們有可將常數運算封裝至 Tensor 物件並將 Tensor 結果封裝為 OpFoldResult 的輔助函式)。

測試 StableHLO 解譯器

解譯器接受輸入 (A) StableHLO 程式,以及 (B) 資料值要提供給程式,然後產生輸出資料值,這個值會與使用者提供的預期資料值進行比對。資料值 (B) 在程式本身中使用 stablehlo.constant 運算,以硬式編碼的方式寫入。解譯器會評估輸入程式。受測試的運算輸出會透過檢查 (例如 check.expect_eqcheck.expect_almost_eq) 檢查,如下所示。check.expect_eqcheck.expect_eq_const 會檢查任何支援的類型是否有位元等性,check.expect_almost_eqcheck.expect_almost_eq_const 會檢查容許條件內是否接近相等性 (如測試指南 (G6) 中所述),針對浮點和複雜類型進行檢查。

// CHECK-LABEL: Evaluated results of function: add_op_test_ui4
func.func @add_op_test_ui4() {
  %0 = stablehlo.constant dense<[0, 2]> : tensor<2xui4>
  %1 = stablehlo.constant dense<[15, 3]> : tensor<2xui4>
  %2 = stablehlo.add %0, %1 : tensor<2xui4>
  check.expect_eq_const %2, [15, 5] : tensor<2xui4>
  func.return
}

測試公用程式 stablehlo-translate --interpret (程式碼) 會負責剖析程式,解譯每個函式,包括構成函式的作業。我們有專屬的測試套件,其中包含多項測試,可針對每個 StableHLO Op 執行各種執行階段行為。您可以在這裡找到測試 (例如 interpret_*.mlir)。

檢測指引

(G1) 是否需要針對每次運算測試所有支援的類型?

我們可以搭配下列規則決定:

  1. 在實作運算時,如果對應的 eval 函式中有程式碼可處理特定類型的,那麼務必讓測試涵蓋該類型。舉例來說,add 運算可以使用專屬程式碼來處理整數、布林值、浮點和複雜類型,因此每個類型類別都需要測試一個測試。

  2. 如果在對應的 eval 函式中統一處理一組類型,則只要對所有類型執行單一測試即可。舉例來說,在 add 運算中,所有整數類型的變化版本 (si4u4si8u8 等) 都會以類似 llvm::APInt API 的方式處理,因此我們可以略過為每個變化版本新增測試,而改為新增一個代表性測試。為避免選擇代表性明確,我們建議使用下列規範:

    • 如果所有類型都統一處理、具有相同的原始類型 (例如所有類型都是整數、浮點或複雜類型),請選擇位元寬度上限的類型。
    • 如果所有類型都統一處理、混用原始類型,請選擇採用以下原始類型,並按照偏好順序遞減選擇:整數、浮點、布林值、複雜值。

(G2) 如何決定涵蓋運算行為所需的測試數量?

我們的目標是透過最少次數的測試,全面涵蓋運算的解譯器邏輯 (即所有邊角案例)。盡可能減少測試次數對於可維護性至關重要。減少的測試越少,檢查結果就越容易,也更容易確認整個測試涵蓋了該運算。因此,我們預期大多數較簡單的作業最終只會進行一項測試。如果某些充分原因的全方位涵蓋範圍不切實際,則只需停止在 >= 90% 即可。在提取要求審查期間,會視個案情況裁定。

(G3) 如何新增解譯器基礎架構的測試?

解譯器基礎架構大部分都很直接,可新增至我們的信任基礎。唯一的複雜部分是各種類型從基礎解譯器儲存空間封裝及解壓縮的方式。如 (G1) 所述,我們只會測試那些處理不同運算類型的運算。由於封裝/解壓縮程式碼可能會對應至整數/浮點類型的不同變化版本,因此可能無法在測試期間完整涵蓋。為確保完整涵蓋率,我們可以選擇 constant 等運算,以便支援所有 StableHLO 元素類型,並編寫詳盡的測試。

(G4) 如果運算的實作依附於其他作業,是否應該為後者編寫測試?

不可以。舉例來說,batch_norm_grad 的實作可根據 dividesubtractmultiply 等。測試前者時,我們應避免測試後者。

(G5) 我們是否應該撰寫測試,以執行實作定義 / 未定義的行為?

我們不應編寫測試實作已定義或未定義作業的測試。執行實作定義的行為會示範轉譯器的本機行為,而不應一般化這些行為。執行未定義行為的測試無法參與運算的瞭解。

(G6) 在編寫浮點類型的測試時,需要在檢查中指定預期結果的精確度嗎?

針對基本運算 (加減、乘法、除法和平方),按照 IEEE 規格實作應在數學實際結果的 0.5 ULP 內提供四捨五入的結果。也就是說,我們可以看到這些作業的預期結果會相差最多 1 ULP。然而,這可能不適用於已定義精確度保證的傳輸函式 (sinecosine等) (原因)。

目前實作使用的「one-size-fits-all」容許值是 0.0001。以下範例示範上述容許條件的實際運作情形。

func.func @check_tolerance() {
  %0 = stablehlo.constant dense<0.2> : tensor<f32>

  // The following check succeeds as %0 is almost equal to the provided
  // constant modulo the tolerance, mentioned above.
  check.expect_almost_eq_const %0, dense<0.19999> : tensor<f32>

  // The following check fails as %0 is not bitwise equal to the provided
  // constant.
  check.expect_eq_const %0, dense<0.19999> : tensor<f32>

  func.return
}

這只是測試 StableHLO 運算數值準確率的第一步。目前,在 StableHLO 規格的不足區域,我們目前正在根據我們在實務上使用 StableHLO 的經驗及相關人員的意見回饋,找出 #1156 的問題。隨著這項作業的進行,我們就會據此更新基礎架構。

(G7) 測試的程式設計風格是否有任何問題?

  1. 請務必使用輸入/輸出的實際名稱,而非預設為 SSA 值 (例如 %0、%1 等)。
  2. 確認測試採用經過整理的格式 (如果有的話)。

(G8) 我們是否應在規格中加入已提供的範例?是 (為完整測試)。