型別推斷

StableHLO 最初是從 MHLO 方言啟動,並沿用類型推論的 MHLO 實作。status.md 會追蹤實作進度。

下列建議指南旨在確保為 StableHLO 運算實作高品質驗證器和形狀函式。

建議採行的做法

這些提案適用於重新審視現有實作項目,以及實現新的作業,直到全面涵蓋。

(P1) 使用 StableHLO 規格做為可靠資料來源

「規格」spec是 StableHLO 運算所有驗證器和形狀函式的可靠資料來源。每個運算的現有驗證器和形狀函式都必須重新審視,確保與規格保持一致。請注意,規格文件不斷演進。如果無法取得運算規格,請改用 XLA 實作做為可靠資料來源,包括 xla/service/shape_inference.ccxla/service/hlo_verifier.cc。XLA 實作範圍並不涵蓋不受限的合成,因此,在不界定擴散性 RFC 可供使用之前,我們都會套用常識。

(P2) 充分運用 ODS

ODS 檔案 (例如 StablehloOps.td) 定義了每個運算元/屬性/結果的運算與類型運算,並將進行驗證。因此,針對 ODS 保證的屬性,驗證器或形狀函式中不需要驗證碼。如果與 ODS 重複,請移除重複的驗證碼,因為系統永遠不會觸發這些驗證碼。

我們需要針對 ODS 的限制新增測試嗎?請參閱「制定測試規範」。

(P3) 將驗證碼保存在驗證器和形狀函式中

兩者:

  • 驗證器:由 Op::verify() 實作,且
  • 形狀函式:由 InferTypeOpInterface 等項目實作,例如 Op::inferReturnTypes()Op::inferReturnTypeComponents

可能有用來檢查運算元/屬性/結果的驗證碼初始分割可能如下所示:讓驗證器檢查運算元/屬性,然後讓形狀函式只計算推測結果類型,並檢查與實際結果類型的相容性。不過,實際上這個分割會產生幾個問題:

  • 無需先呼叫驗證器,即可由自動產生的 build() 函式呼叫形狀函式。因此,相關的輸入內容也必須在形狀函式中驗證
  • 重複的程式碼:舉例來說,在驗證器中,我們會對運算元進行某些處理,然後驗證一些中繼結果。接著,做為形狀函式,這些中間結果有助於推斷最終結果。這類中繼結果必須計算兩次。
  • 維護負擔:控制機制的驗證包含兩種不同方法。

解決方法如下:

  1. 針對沒有區域的大多數作業 (例如 PadOp):將所有驗證碼放入形狀函式,並完全捨棄驗證器。

  2. 針對具有區域的運算 (例如 ReduceOp/IfOp;如要查看完整清單,請參閱這裡):自動產生的建構工具不會將區域做為參數使用,因此如果這些建構工具會使用類型推論,系統就會使用空白區域呼叫形狀函式 (請參閱這個範例)。

    1. 如果類型推論不需要區域 (例如 ReduceOp),請將區域相關驗證邏輯置於驗證器中,而非形狀函式。如果程式碼不可行,請複製部分程式碼。

    2. 如果需要類型推論 (IfOp/CaseOp/MapOp) 需要區域,則形狀函式也必須明確驗證區域並非明確空白,即使 ODS 已保證其在運算定義中仍存在。

(P4) 制定測試規範

是否需要新增/維護 ODS 涵蓋的驗證測試?

我們不會這麼做。測試應著重於驗證器和形狀函式,而對 ODS 所做的變更則需要重新審視這個運算。

但請小心缺少的部分:舉例來說,如果運算包含特徵 SameOperandsAndResultShape,而該特徵只會檢查形狀,而不包含元素類型,則運算元/結果的元素類型的驗證程序仍需經過測試。

我們會在哪裡放置驗證器和類型推論的測試?

ops_stablehlo.mlir 包含作業的陽性情況,以及 (至少) 每個驗證錯誤都有 1 項陰性測試。此外,還能檢查推測的傳回類型與實際結果類型是否「相容」 (不同於!)。

infer_stablehlo.mlir 使用 hlo_test_infer.get_return_type_components"(%x):... 行驗證運算的形狀函式是否存在,並檢查推測的類型是否完全符合預期。一般而言,每次運算只能有一個陽性測試。

進行方式

實作或重新造訪運算的驗證器和/或形狀函式時:

  1. 請將所有正案例和負值案例放在 ops_stablehlo.mlir 中。

  2. infer_stablehlo.mlir 中新增一項陽性測試以測試介面。

  3. (選用) 如果運算作業相當複雜,且可能包含大量測試,請考慮在相同資料夾中新增一個名為 verify_<op_name>.mlirverify_<your_topic>.mlir 的測試檔案。