通訳の設計

データモデル

StableHLO プログラムはテンソル(n 次元配列)に対する計算であり、現在のモデルでは Tensor クラスを使用して実装されています。Tensor オブジェクトの基盤となるストレージ クラス detail::Buffer には、テンソルの mlir::ShapedType と、テンソルデータの可変 blob を表す mlir::HeapAsmResourceBlob オブジェクトが、大からマイナーの順に連続したバイト配列として格納されます。detail::Buffer オブジェクトは、メモリ管理を簡素化するために参照カウントされます。

テンソルの個々の要素は Element クラスを使用して表現されます。このクラスは、APIntAPFloatpair<APFloat,APFloat> のいずれかをストレージに保持する識別共用体を使用します。後者は複雑な型を持つ要素を格納するために使用されます。

Tensor には、個々の要素を操作する次の API があります。

  • Element Tensor::get(llvm::ArrayRef<int64_t> index): 多次元インデックス index にある個々のテンソル要素を Element オブジェクトとして抽出します。
  • void Tensor::set(llvm::ArrayRef<int64_t> index, Element element);: Element オブジェクト element を多次元インデックス index のテンソルに更新します。

インタープリタの仕組み

インタープリタへの entry 関数は、

SmallVector<Tensor> eval(func::FuncOp func, ArrayRef<Tensor> args);

これは次の処理を行います。

  1. シンボル テーブル マップ M を使用して、func の SSA 引数と、args で提供される関連するランタイム Tensor 値を追跡します。
  2. func 内の各演算について、SSACFG の順序で次のように設定します。
    • op に対して eval を呼び出します。op の SSA オペランドごとに、そのランタイム値を M から抽出し、eval 呼び出しの引数として提供されます。
    • 演算の SSA 結果と評価された値を M で追跡します。

(2)に記載されている演算レベルの eval は、演算の実行セマンティクスを実装します。以下に、stablehlo::AddOp の例を示します。この例では、lhs テンソルと rhs テンソルの個々の要素が Element オブジェクトとしてペア単位で抽出され、その後追加されます。加算の結果である Element オブジェクトは最終的な result テンソルに格納されます。

Tensor eval(AddOp op, const Tensor &lhs, const Tensor &rhs) {
  Tensor result(op.getType());

  for (auto it = result.index_begin(); it != result.index_end(); ++it)
    result.set(*it, lhs.get(*it) + rhs.get(*it));

  return result;
}

全体として、インタープリタの設計は、StableHLO のリファレンス実装として機能することを意図しているため、個々の演算の eval 関数の実装を読みやすくするために最適化されています。たとえば、eval をテンプレート関数として定義し、要素タイプでパラメータ化する代わりに、さまざまな要素タイプの処理に関する詳細を Element::operator+ などでカプセル化して、eval の実装を簡素化しています。

定数フォールディングにインタープリタを使用する

インタープリタのメカニズムを使用して、定数オペランド値を持つ演算を折りたたみます。次のコード スニペットは、浮動小数点型オペランドで stablehlo::AddOp を折りたたむための実装のアイデアを示しています。

OpFoldResult AddOp::fold(FoldAdaptor adaptor) {
  auto attrs = adaptor.getOperands();
  DenseElementsAttr lhsData = attrs[0].dyn_cast<DenseElementsAttr>();
  DenseElementsAttr rhsData = attrs[1].dyn_cast<DenseElementsAttr>();
  if (!lhsData || !rhsData) return {};

  auto lhs = Tensor(lhsData);
  auto rhs = Tensor(rhsData);
  auto result = eval(*this, lhs, rhs);

  SmallVector<APFloat> values;
  for (auto i = 0; i < result.getNumElements(); ++i) {
    Element element = result.get(i);
    values.push_back(element.getValue().cast<FloatAttr>().getValue());
  }

  return DenseElementsAttr::get(result.getType(), values);
}

現時点では、StableHLO 用のフォルダを実装する予定がないため、インタープリタを定数折りたたみに統合するための積極的な取り組みは行っていません。ただし将来的には、MHLO での定数折りたたみにインタープリタを活用することを計画しており、その時点で上記のコード スニペットのエルゴノミクスを改善します(たとえば、定数オペランドを Tensor オブジェクトにパックし、Tensor の結果を OpFoldResult に展開するヘルパー関数を用意できます)。

StableHLO インタープリタのテスト

インタープリタは、(A)StableHLO プログラムと(B)プログラムにフィードされるデータ値を入力として受け取り、出力データ値を生成します。出力データ値は、ユーザーが指定した想定データ値と照合されます。データ値(B)は、stablehlo.constant 演算を使用してプログラム自体にハードコードされます。インタープリタは入力プログラムを評価します。テスト対象の演算の出力は、以下に示すようにチェック(check.expect_eqcheck.expect_almost_eq など)によってチェックされます。check.expect_eqcheck.expect_eq_const はサポートされているすべての型についてビット単位等価かどうかを確認します。check.expect_almost_eqcheck.expect_almost_eq_const は、浮動小数点型と複雑な型について、テスト ガイドライン(G6)で説明されている許容範囲内にほぼ等価かどうかを確認します。

// CHECK-LABEL: Evaluated results of function: add_op_test_ui4
func.func @add_op_test_ui4() {
  %0 = stablehlo.constant dense<[0, 2]> : tensor<2xui4>
  %1 = stablehlo.constant dense<[15, 3]> : tensor<2xui4>
  %2 = stablehlo.add %0, %1 : tensor<2xui4>
  check.expect_eq_const %2, [15, 5] : tensor<2xui4>
  func.return
}

テスト ユーティリティ stablehlo-translate --interpretコード)はプログラムの解析を行い、関数を構成する演算を含む各関数を解釈します。StableHLO オペレーションごとに、さまざまなランタイム動作を実行する複数のテストで構成される専用のテストスイートがあります。テストはこちらで確認できます(例: interpret_*.mlir)。

テストに関するガイドライン

(G1)すべての演算でサポートされている型をすべてテストする必要はありますか?

次のルールを組み合わせて判断できます。

  1. 演算の実装時に、対応する eval 関数に特定の型を処理するコードが存在する場合は、その型をカバーするテストが不可欠です。たとえば、add 演算には整数型、ブール値、浮動小数点型、複合型を処理する排他的なコードがあるため、型のカテゴリごとにテストが 1 つ必要です。

  2. 型のセットが対応する eval 関数で均一に処理されている場合は、すべての型を 1 回テストするだけで十分です。たとえば、add 演算の場合、整数型のすべてのバリアント(si4u4si8u8 など)が llvm::APInt API を使用して同様に処理されます。そのため、各バリアントのテストの追加はスキップし、代わりに 1 つの代表的なテストを追加できます。代表者の選択であいまいさを避けるため、次のガイドラインに従ってください。

    • すべての型が一様に処理され、同じプリミティブ型(すべて整数型、浮動小数点型、複合型など)を持つ場合、ビット幅が最大であるものを選択します。
    • すべての型(一律に処理され、プリミティブ型が混在している場合)は、整数、浮動小数点、ブール値、複素という優先順にプリミティブ型を持つものを選択します。

(G2)演算の動作をカバーするために必要なテストの数はどのように決めるのですか?

目標は、最小限のテストで op のインタープリタのロジック(つまり、実装のすべてのコーナーケース)を包括的にカバーすることです。保守性のためには、テストの数を最小限に抑えることが重要です。テストの数が多ければ多いほど、それらをレビューして、演算を包括的にカバーしていることを確認するのも容易になります。そのため、より単純な演算のほとんどは 1 回のテストで終わることが想定されます。なんらかの妥当な理由で包括的なカバレッジが現実的でない場合は、90% 以上で停止しても問題ありません。これは、pull リクエストの審査時にケースバイケースで判断されます。

(G3)インタープリタ インフラストラクチャのテストを追加するにはどうすればよいですか?

インタープリタのインフラストラクチャはほとんどシンプルであり、Google の信頼ベースに追加できます。重要な点は、さまざまな型を基盤となるインタープリタ ストレージにパッキングし、そこからアンパックする方法のみです。(G1)で説明したように、異なる処理方法の op のみをテストします。そのため、整数型/浮動小数点型のさまざまなバリアントに対応するパッキング/アンパッキング コードが、テスト中に完全にカバーされない可能性があります。完全にカバーするには、すべての StableHLO 要素タイプをサポートする constant のようなオペレーションを選択し、網羅的なテストを作成します。

(G4)演算の実装が他の演算に依存する場合、後者のテストを作成するべきですか?

いいえ。たとえば、batch_norm_grad の実装は dividesubtractmultiply などに基づいて行うことができます。後者のオペレーションをテストする際は、前者をテストすることは避ける必要があります。

(G5)実装で定義されている動作や未定義の動作を実行するためのテストを作成すればよいですか?

演算の実装で定義されている動作または未定義の動作を実行するテストは作成しないでください。実装定義の動作を実行するテストは、一般化すべきではないインタープリタのローカル動作を示しています。未定義の動作を実行するテストは、演算の動作の理解には寄与しません。

(G6)浮動小数点型のテストを作成するにあたって、チェックで期待される結果の精度はどの程度にする必要がありますか?

基本演算(加算、減算、乗算、除算、2 乗)の場合、IEEE 仕様に従った実装では、数学的に正確な結果の 0.5 ULP 以内の丸め結果が得られることが期待されます。とは言え、これらのオペレーションで期待される結果の差異は最大で 1 ULP 程度であると想定できます。ただし、精度の保証が実装によって定義される超越的な関数(sinecosine など)では機能しない場合があります(根拠)。

現在の実装では、1 つの公差値 0.0001 が使用されています。次の例は、上記の tolerance の動作を示しています。

func.func @check_tolerance() {
  %0 = stablehlo.constant dense<0.2> : tensor<f32>

  // The following check succeeds as %0 is almost equal to the provided
  // constant modulo the tolerance, mentioned above.
  check.expect_almost_eq_const %0, dense<0.19999> : tensor<f32>

  // The following check fails as %0 is not bitwise equal to the provided
  // constant.
  check.expect_eq_const %0, dense<0.19999> : tensor<f32>

  func.return
}

これは、StableHLO 演算の数値精度をテストする最初の手順にすぎません。現時点では、StableHLO 仕様においてこの規定は明確ではありません。StableHLO を実際に使用した際の経験と関係者からのフィードバックに基づいて、#1156 の明確化に向けて取り組みが進行中です。この取り組みが進行すると、それに応じてインフラストラクチャが更新されます。

(G7)テストのコーディング スタイルについて何か記載はありますか?

  1. 入出力には、デフォルトの SSA 値(%0、%1 など)ではなく、実際の名前を使用してください。
  2. テストで プリティ プリント形式が存在する場合は、使用していることを確認します。

(G8)仕様ですでに提供されている例を含める必要がありますか? ○(テストの完全性のために必要)。