インタープリタの設計

データモデル

StableHLO プログラムはテンソルを使った計算 (n 次元配列)。現在のモデルでは、 クラス TensorTensor オブジェクトの基盤となるストレージ クラス。 detail::Buffer は、テンソルの mlir::ShapedType と テンソルの変更可能な blob を表す mlir::HeapAsmResourceBlob オブジェクト 連続したバイト配列として 主要な順序 detail::Buffer オブジェクトは、メモリ管理を簡素化するために参照カウントされます。

テンソルの個々の要素は、Element クラスを使用して表現されます。 APIntAPFloat、または ストレージの場合は pair<APFloat,APFloat>。最後の 1 つは要素を保存するために使用します 使用できます。

Tensor には、個々の要素を操作する次の API があります。

  • Element Tensor::get(llvm::ArrayRef<int64_t> index): 多次元インデックス index の個々のテンソル要素を Element とします。 渡されます。
  • void Tensor::set(llvm::ArrayRef<int64_t> index, Element element);: Element オブジェクトの element を多次元のテンソルに更新する インデックス index

通訳の仕組み

インタープリタへの入力関数は、

SmallVector<Tensor> eval(func::FuncOp func, ArrayRef<Tensor> args);

これにより、次の処理が行われます。

  1. func の SSA 引数とそれに関連するランタイム Tensor を追跡します シンボル テーブル マップ M を使用して、args で提供されます。
  2. func 内の各演算について、SSACFG の順序で次を実行します。 <ph type="x-smartling-placeholder">
      </ph>
    • op で eval を呼び出す。演算の各 SSA オペランドについて、 ランタイム値を M から eval 呼び出しの引数として渡します。
    • 演算の SSA 結果と M での評価値を追跡します。

(2)で説明した演算レベルの eval は、 意味を持たせることができます。stablehlo::AddOp の例を次に示します。 この例では、lhs テンソルと rhs テンソルの個々の要素はペアワイズです。 Element オブジェクトとして抽出され、それらが追加されます。加算の結果、 Element オブジェクト。最終的な result テンソルに格納されます。

Tensor eval(AddOp op, const Tensor &lhs, const Tensor &rhs) {
  Tensor result(op.getType());

  for (auto it = result.index_begin(); it != result.index_end(); ++it)
    result.set(*it, lhs.get(*it) + rhs.get(*it));

  return result;
}

全体として、インタープリタの設計は、 eval 関数の実装は、個々のオペレーションに StableHLO のリファレンス実装として機能します。たとえば、 eval をテンプレート関数として定義し、要素型でパラメータ化する さまざまな要素タイプがどのように処理されるかを Element::operator+ など。eval の実装を簡素化します。

コンスタント フォールディングにインタープリタを使用する

インタープリタ メカニズムを使用して、定数オペランドを持つ演算を折りたたみ可能 使用できます。次のコード スニペットは、実装の概念を示しています。 浮動小数点型のオペランドで stablehlo::AddOp を折りたたむ場合:

OpFoldResult AddOp::fold(FoldAdaptor adaptor) {
  auto attrs = adaptor.getOperands();
  DenseElementsAttr lhsData = dyn_cast<DenseElementsAttr>(attrs[0]);
  DenseElementsAttr rhsData = dyn_cast<DenseElementsAttr>(attrs[1]);
  if (!lhsData || !rhsData) return {};

  auto lhs = Tensor(lhsData);
  auto rhs = Tensor(rhsData);
  auto result = eval(*this, lhs, rhs);

  SmallVector<APFloat> values;
  for (auto i = 0; i < result.getNumElements(); ++i) {
    Element element = result.get(i);
    values.push_back(cast<FloatAttr>(element.getValue()).getValue());
  }

  return DenseElementsAttr::get(result.getType(), values);
}

現時点では、 StableHLO のフォルダは実装する予定がないため、折りたたまれます。 ただし将来的には、 MHLO での折りたたみについて説明します。この時点で、コード スニペットのエルゴノミクスを改善します。 (たとえば、定数オペランドを Tensor オブジェクトを作成し、Tensor の結果を OpFoldResult に展開します)。

StableHLO インタープリタのテスト

インタープリタは、入力として(A)StableHLO プログラム、(B)データ値を受け取り、 プログラムにフィードされ、出力データ値を生成します。この出力データの値が、 ユーザーから提供された期待データ値と照らし合わせます。データ値(B)は、 stablehlo.constant オペレーションを使用してプログラム内にハードコードされています。「 インタープリタは入力プログラムを評価します。テスト対象のオペレーションの出力 チェック(例: check.expect_eqcheck.expect_almost_eq)を介してチェックされるため、 表示されます。check.expect_eqcheck.expect_eq_const はビット単位のチェック サポートされている型、check.expect_almost_eq、および check.expect_almost_eq_const は、許容範囲内にあるほぼ等価かどうかを確認します。 に記載されているように、浮動小数点数型と複合型に対するテスト ガイドライン(G6)に記載されています。

// CHECK-LABEL: Evaluated results of function: add_op_test_ui4
func.func @add_op_test_ui4() {
  %0 = stablehlo.constant dense<[0, 2]> : tensor<2xui4>
  %1 = stablehlo.constant dense<[15, 3]> : tensor<2xui4>
  %2 = stablehlo.add %0, %1 : tensor<2xui4>
  check.expect_eq_const %2, [15, 5] : tensor<2xui4>
  func.return
}

テスト ユーティリティ stablehlo-translate --interpretコード) は、プログラムを解析し、関数に含まれる各関数を解釈します。 関数を構成する演算です。専用のテストスイートがあり StableHLO オペレーションごとに、さまざまなランタイム動作を実行する複数のテスト。 テストはこちらで確認できます。

テストに関するガイドライン

(G1)すべての演算について、サポートされているすべての型をテストする必要がありますか?

次のルールを組み合わせて使用することで、決定できます。

  1. op の実装中、対応する eval にコードが存在する場合 関数で特定の型を処理する場合、テストが不可欠であり、 その種類をカバーしますたとえば、add 演算には、排他的なコードがあります。 整数、ブール値、浮動小数点数、複合型を扱います。 タイプのカテゴリごとに 1 つのテストが必要です。

  2. ある一連の型が対応する eval 関数で均一に処理されている場合、 これらすべての型に対して 1 回のテストで十分です。たとえば add 演算に対する整数型のすべてのバリアント(si4u4si8u8) など)は、llvm::APInt API を使用して同様に処理されるため、スキップできます。 テストを追加し、テストを実行する代わりに 代表的なテストです。担当者を選定する際のあいまいさを避けるため、Google は 次のガイドラインを使用する必要があります。

    • 均一に処理されるすべての型が同じプリミティブ型である場合 (すべてが整数型、浮動小数点型、複合型である場合)は、 最も大きいビット幅を選択してください。
    • すべての型が均一に処理され、プリミティブ型が混在している場合、 次のプリミティブ型を持つものを選択します(降順) 設定: 整数、浮動小数点数、ブール値、複雑。

(G2)ある事象をカバーするために必要なテストの数をどのように決定すればよいか どうなるでしょうか?

目的は、演算のインタープリタのロジックを包括的にカバーすることです。 (すなわち、実装の特殊なケース)を、最小限のテストで実施できます。 テストの回数を最小限に抑えることは、保守性のために重要です。テスト回数が少ないほど 簡単にレビューでき 確実に 包括的に説明しますそのため、よりシンプルな構成のほとんどは テストは 1 つだけになりますなんらかの理由で カバレッジが実用的でない場合は、90% 以上で停止しても問題ありません。この決定は pull リクエストの審査時にケースバイケースで行われます。

(G3)インタープリタ インフラストラクチャ用のテストを追加することはできますか?

インタープリタのインフラストラクチャはほぼシンプルで、 信頼基盤を築く必要があります。重要な点は、さまざまな型がどのように詰め込まれるかです。 基になるインタープリタストレージから展開されます(G1)で説明したとおり、 処理の異なるオペレーションのみを テストすることになりますあり 異なるコードに対応するパッキング/アンパックコードが 整数型/浮動小数点型のバリアントであるため、 説明します。すべてを網羅するには、constant のようなオペレーションを選択します。 すべての StableHLO 要素タイプをサポートし、網羅的なテストを作成します。

(G4)op の実装が他の op に依存している場合は、 どうすればよいでしょうか?

いいえ。たとえば、batch_norm_grad の実装は以下に基づいて行うことができます。 dividesubtractmultiply など。後者のテストは避けるべきです 運用上のオーバーヘッドを削減できます。

(G5)実装定義 / 未定義の状態を再現するためのテストを作成すべきか 行動とはどのようなものか?

実装で定義されている、または適用できる 未定義の動作に対処できます。実装定義の動作のテスト インタープリタのローカル動作を示します。 一般化されます。未定義の動作を実行するテストは、次の要素には寄与しません。 関数の動作を理解します

(G6)浮動小数点型のテストを作成する際に、 期待される結果はチェックで指定する必要がありますか?

基本演算(加算、減算、乗算、除算、 IEEE 仕様に沿った実装では、 四捨五入された結果が、数学的に正確な結果の 0.5 ULP 以内とは言え、私たちは これらのオペレーションから期待される結果は、 最大 1 つの ULP ですただし、これは超越関数では機能しない場合があります。 精度保証の対象となる(sinecosine など) 実装で定義されている(根拠)。

現在の実装では、「画一的な」アプローチが許容値は 0.0001 です。 次の例は、上記の許容差の実例を示しています。

func.func @check_tolerance() {
  %0 = stablehlo.constant dense<0.2> : tensor<f32>

  // The following check succeeds as %0 is almost equal to the provided
  // constant modulo the tolerance, mentioned above.
  check.expect_almost_eq_const %0, dense<0.19999> : tensor<f32>

  // The following check fails as %0 is not bitwise equal to the provided
  // constant.
  check.expect_eq_const %0, dense<0.19999> : tensor<f32>

  func.return
}

これは、StableHLO オペレーションの数値精度テストの最初のステップにすぎません。 現時点では、これは StableHLO 仕様で指定されていない領域であり、 現在調査中です。#1156 StableHLO を実際に使用した経験と、 できます。作業が進行したら、インフラストラクチャを更新します。 必要があります。

(G7)テストのコーディング スタイルについて何か問題はありますか?

  1. デフォルトの名前ではなく、実際の入力/出力名を使用してください。 SSA 値(%0、%1 など)にマッピング
  2. テストで プリティ プリント形式が使用されていることを確認します(存在する場合)。

(G8)すでに示した例を仕様に含めた方がよいですか? ○(テストの完全性のため)