データモデル
StableHLO プログラムはテンソルを使った計算
(n 次元配列)。現在のモデルでは、
クラス Tensor
。Tensor
オブジェクトの基盤となるストレージ クラス。
detail::Buffer
は、テンソルの mlir::ShapedType
と
テンソルの変更可能な blob を表す mlir::HeapAsmResourceBlob
オブジェクト
連続したバイト配列として
主要な順序
detail::Buffer
オブジェクトは、メモリ管理を簡素化するために参照カウントされます。
テンソルの個々の要素は、Element
クラスを使用して表現されます。
APInt
、APFloat
、または
ストレージの場合は pair<APFloat,APFloat>
。最後の 1 つは要素を保存するために使用します
使用できます。
Tensor
には、個々の要素を操作する次の API があります。
Element Tensor::get(llvm::ArrayRef<int64_t> index)
: 多次元インデックスindex
の個々のテンソル要素をElement
とします。 渡されます。void Tensor::set(llvm::ArrayRef<int64_t> index, Element element);
:Element
オブジェクトのelement
を多次元のテンソルに更新する インデックスindex
。
通訳の仕組み
インタープリタへの入力関数は、
SmallVector<Tensor> eval(func::FuncOp func, ArrayRef<Tensor> args);
これにより、次の処理が行われます。
func
の SSA 引数とそれに関連するランタイムTensor
を追跡します シンボル テーブル マップ M を使用して、args
で提供されます。func
内の各演算について、SSACFG の順序で次を実行します。 <ph type="x-smartling-placeholder">- </ph>
- op で
eval
を呼び出す。演算の各 SSA オペランドについて、 ランタイム値を M からeval
呼び出しの引数として渡します。 - 演算の SSA 結果と M での評価値を追跡します。
- op で
(2)で説明した演算レベルの eval
は、
意味を持たせることができます。stablehlo::AddOp
の例を次に示します。
この例では、lhs
テンソルと rhs
テンソルの個々の要素はペアワイズです。
Element
オブジェクトとして抽出され、それらが追加されます。加算の結果、
Element
オブジェクト。最終的な result
テンソルに格納されます。
Tensor eval(AddOp op, const Tensor &lhs, const Tensor &rhs) {
Tensor result(op.getType());
for (auto it = result.index_begin(); it != result.index_end(); ++it)
result.set(*it, lhs.get(*it) + rhs.get(*it));
return result;
}
全体として、インタープリタの設計は、
eval
関数の実装は、個々のオペレーションに
StableHLO のリファレンス実装として機能します。たとえば、
eval
をテンプレート関数として定義し、要素型でパラメータ化する
さまざまな要素タイプがどのように処理されるかを
Element::operator+
など。eval
の実装を簡素化します。
コンスタント フォールディングにインタープリタを使用する
インタープリタ メカニズムを使用して、定数オペランドを持つ演算を折りたたみ可能
使用できます。次のコード スニペットは、実装の概念を示しています。
浮動小数点型のオペランドで stablehlo::AddOp
を折りたたむ場合:
OpFoldResult AddOp::fold(FoldAdaptor adaptor) {
auto attrs = adaptor.getOperands();
DenseElementsAttr lhsData = dyn_cast<DenseElementsAttr>(attrs[0]);
DenseElementsAttr rhsData = dyn_cast<DenseElementsAttr>(attrs[1]);
if (!lhsData || !rhsData) return {};
auto lhs = Tensor(lhsData);
auto rhs = Tensor(rhsData);
auto result = eval(*this, lhs, rhs);
SmallVector<APFloat> values;
for (auto i = 0; i < result.getNumElements(); ++i) {
Element element = result.get(i);
values.push_back(cast<FloatAttr>(element.getValue()).getValue());
}
return DenseElementsAttr::get(result.getType(), values);
}
現時点では、
StableHLO のフォルダは実装する予定がないため、折りたたまれます。
ただし将来的には、
MHLO での折りたたみについて説明します。この時点で、コード スニペットのエルゴノミクスを改善します。
(たとえば、定数オペランドを
Tensor
オブジェクトを作成し、Tensor
の結果を OpFoldResult
に展開します)。
StableHLO インタープリタのテスト
インタープリタは、入力として(A)StableHLO プログラム、(B)データ値を受け取り、
プログラムにフィードされ、出力データ値を生成します。この出力データの値が、
ユーザーから提供された期待データ値と照らし合わせます。データ値(B)は、
stablehlo.constant
オペレーションを使用してプログラム内にハードコードされています。「
インタープリタは入力プログラムを評価します。テスト対象のオペレーションの出力
チェック(例: check.expect_eq
、check.expect_almost_eq
)を介してチェックされるため、
表示されます。check.expect_eq
と check.expect_eq_const
はビット単位のチェック
サポートされている型、check.expect_almost_eq
、および
check.expect_almost_eq_const
は、許容範囲内にあるほぼ等価かどうかを確認します。
に記載されているように、浮動小数点数型と複合型に対するテスト ガイドライン(G6)に記載されています。
// CHECK-LABEL: Evaluated results of function: add_op_test_ui4
func.func @add_op_test_ui4() {
%0 = stablehlo.constant dense<[0, 2]> : tensor<2xui4>
%1 = stablehlo.constant dense<[15, 3]> : tensor<2xui4>
%2 = stablehlo.add %0, %1 : tensor<2xui4>
check.expect_eq_const %2, [15, 5] : tensor<2xui4>
func.return
}
テスト ユーティリティ stablehlo-translate --interpret
(コード)
は、プログラムを解析し、関数に含まれる各関数を解釈します。
関数を構成する演算です。専用のテストスイートがあり
StableHLO オペレーションごとに、さまざまなランタイム動作を実行する複数のテスト。
テストはこちらで確認できます。
テストに関するガイドライン
(G1)すべての演算について、サポートされているすべての型をテストする必要がありますか?
次のルールを組み合わせて使用することで、決定できます。
op の実装中、対応する
eval
にコードが存在する場合 関数で特定の型を処理する場合、テストが不可欠であり、 その種類をカバーしますたとえば、add
演算には、排他的なコードがあります。 整数、ブール値、浮動小数点数、複合型を扱います。 タイプのカテゴリごとに 1 つのテストが必要です。ある一連の型が対応する
eval
関数で均一に処理されている場合、 これらすべての型に対して 1 回のテストで十分です。たとえばadd
演算に対する整数型のすべてのバリアント(si4
、u4
、si8
、u8
) など)は、llvm::APInt
API を使用して同様に処理されるため、スキップできます。 テストを追加し、テストを実行する代わりに 代表的なテストです。担当者を選定する際のあいまいさを避けるため、Google は 次のガイドラインを使用する必要があります。- 均一に処理されるすべての型が同じプリミティブ型である場合 (すべてが整数型、浮動小数点型、複合型である場合)は、 最も大きいビット幅を選択してください。
- すべての型が均一に処理され、プリミティブ型が混在している場合、 次のプリミティブ型を持つものを選択します(降順) 設定: 整数、浮動小数点数、ブール値、複雑。
(G2)ある事象をカバーするために必要なテストの数をどのように決定すればよいか どうなるでしょうか?
目的は、演算のインタープリタのロジックを包括的にカバーすることです。 (すなわち、実装の特殊なケース)を、最小限のテストで実施できます。 テストの回数を最小限に抑えることは、保守性のために重要です。テスト回数が少ないほど 簡単にレビューでき 確実に 包括的に説明しますそのため、よりシンプルな構成のほとんどは テストは 1 つだけになりますなんらかの理由で カバレッジが実用的でない場合は、90% 以上で停止しても問題ありません。この決定は pull リクエストの審査時にケースバイケースで行われます。
(G3)インタープリタ インフラストラクチャ用のテストを追加することはできますか?
インタープリタのインフラストラクチャはほぼシンプルで、
信頼基盤を築く必要があります。重要な点は、さまざまな型がどのように詰め込まれるかです。
基になるインタープリタストレージから展開されます(G1)で説明したとおり、
処理の異なるオペレーションのみを
テストすることになりますあり
異なるコードに対応するパッキング/アンパックコードが
整数型/浮動小数点型のバリアントであるため、
説明します。すべてを網羅するには、constant
のようなオペレーションを選択します。
すべての StableHLO 要素タイプをサポートし、網羅的なテストを作成します。
(G4)op の実装が他の op に依存している場合は、 どうすればよいでしょうか?
いいえ。たとえば、batch_norm_grad
の実装は以下に基づいて行うことができます。
divide
、subtract
、multiply
など。後者のテストは避けるべきです
運用上のオーバーヘッドを削減できます。
(G5)実装定義 / 未定義の状態を再現するためのテストを作成すべきか 行動とはどのようなものか?
実装で定義されている、または適用できる 未定義の動作に対処できます。実装定義の動作のテスト インタープリタのローカル動作を示します。 一般化されます。未定義の動作を実行するテストは、次の要素には寄与しません。 関数の動作を理解します
(G6)浮動小数点型のテストを作成する際に、 期待される結果はチェックで指定する必要がありますか?
基本演算(加算、減算、乗算、除算、
IEEE 仕様に沿った実装では、
四捨五入された結果が、数学的に正確な結果の 0.5 ULP 以内とは言え、私たちは
これらのオペレーションから期待される結果は、
最大 1 つの ULP ですただし、これは超越関数では機能しない場合があります。
精度保証の対象となる(sine
、cosine
など)
実装で定義されている(根拠)。
現在の実装では、「画一的な」アプローチが許容値は 0.0001 です。 次の例は、上記の許容差の実例を示しています。
func.func @check_tolerance() {
%0 = stablehlo.constant dense<0.2> : tensor<f32>
// The following check succeeds as %0 is almost equal to the provided
// constant modulo the tolerance, mentioned above.
check.expect_almost_eq_const %0, dense<0.19999> : tensor<f32>
// The following check fails as %0 is not bitwise equal to the provided
// constant.
check.expect_eq_const %0, dense<0.19999> : tensor<f32>
func.return
}
これは、StableHLO オペレーションの数値精度テストの最初のステップにすぎません。 現時点では、これは StableHLO 仕様で指定されていない領域であり、 現在調査中です。#1156 StableHLO を実際に使用した経験と、 できます。作業が進行したら、インフラストラクチャを更新します。 必要があります。
(G7)テストのコーディング スタイルについて何か問題はありますか?
- デフォルトの名前ではなく、実際の入力/出力名を使用してください。 SSA 値(%0、%1 など)にマッピング
- テストで プリティ プリント形式が使用されていることを確認します(存在する場合)。
(G8)すでに示した例を仕様に含めた方がよいですか? ○(テストの完全性のため)