StableHLO は、ML モデルの高レベル オペレーション(HLO)のオペレーション セットです。StableHLO は、さまざまな ML フレームワークと ML コンパイラ間の移植性レイヤとして機能します。StableHLO プログラムを生成する ML フレームワークは、StableHLO プログラムを使用する ML コンパイラと互換性があります。
Google の目標は、さまざまな ML フレームワーク(TensorFlow、JAX、PyTorch など)と ML コンパイラ(XLA、IREE など)間の相互運用性を高めることで、ML 開発を簡素化し、高速化することです。このため、このドキュメントでは StableHLO プログラミング言語の仕様について説明します。
この仕様には、主に 3 つのセクションがあります。まず、プログラム セクションでは、StableHLO 関数で構成される StableHLO プログラムの構造について説明します。StableHLO 関数自体は StableHLO オペレーションで構成されます。この構造内の Ops セクションでは、個々のオペレーションのセマンティクスを指定します。実行セクションでは、プログラム内で同時に実行されるこれらのオペレーションのセマンティクスについて説明します。最後に、表記セクションでは、仕様全体で使用される表記について説明します。
StableHLO の以前のリリースの仕様を表示するには、目的の タグ付きリリースでリポジトリを開きます。たとえば、StableHLO v0.19.0 仕様などです。StableHLO の各マイナー バージョン アップで発生した変更を確認するには、VhloDialect.td のバージョンログを参照してください。
プログラム
Program ::= {Func}
StableHLO プログラムは、任意の数の StableHLO 関数で構成されます。次の例は、3 つの入力(%image、%weights、%bias)と 1 つの出力を持つ関数 @main を含むプログラムの例です。関数の本体には 6 つのオペレーションがあります。
func.func @main(
%image: tensor<28x28xf32>,
%weights: tensor<784x10xf32>,
%bias: tensor<1x10xf32>
) -> tensor<1x10xf32> {
%0 = "stablehlo.reshape"(%image) : (tensor<28x28xf32>) -> tensor<1x784xf32>
%1 = "stablehlo.dot"(%0, %weights) : (tensor<1x784xf32>, tensor<784x10xf32>) -> tensor<1x10xf32>
%2 = "stablehlo.add"(%1, %bias) : (tensor<1x10xf32>, tensor<1x10xf32>) -> tensor<1x10xf32>
%3 = "stablehlo.constant"() {value = dense<0.0> : tensor<1x10xf32>} : () -> tensor<1x10xf32>
%4 = "stablehlo.maximum"(%2, %3) : (tensor<1x10xf32>, tensor<1x10xf32>) -> tensor<1x10xf32>
"func.return"(%4): (tensor<1x10xf32>) -> ()
}
関数
Func ::= 'func' '.' 'func' FuncId FuncInputs FuncOutputs '{' FuncBody '}'
FuncInputs ::= '(' [FuncInput {',' FuncInput}] `)`
FuncInput ::= ValueId ':' ValueType
FuncOutputs ::= ['->' FuncOutput, {',' FuncOutput}]
FuncOutput ::= ValueType
FuncBody ::= {Op}
StableHLO 関数(名前付き関数とも呼ばれます)には、識別子、入力/出力、本文があります。今後、関数にメタデータを追加して、HLO との互換性を高める予定です(#425、#626、#740、#744)。
識別子
FuncId ::= '@' letter {letter | digit}
ValueId ::= '%' digit {digit}
| '%' letter {letter | digit}
letter ::= 'a' | ... | 'z' | 'A' | ... | 'Z' | '_'
digit ::= '0' | ... | '9'
StableHLO 識別子は、多くのプログラミング言語の識別子と似ていますが、次の 2 つの特殊性があります。1)すべての識別子には、さまざまな種類の識別子を区別するシジルがあります。2)値識別子は完全に数値にすることができ、StableHLO プログラムの生成を簡素化できます。
型
Type ::= ValueType | NonValueType
ValueType ::= TensorType | QuantizedTensorType | TokenType | TupleType | BufferType
NonValueType ::= TensorElementType | QuantizedTensorElementType | FunctionType | StringType
StableHLO 型は、StableHLO 値を表す値型(ファーストクラス型とも呼ばれます)と、他のプログラム要素を記述する非値型に分類されます。StableHLO の型は多くのプログラミング言語の型と似ていますが、StableHLO のドメイン固有の性質により、いくつかの特殊な結果が生じます(スカラー型は値型ではありません)。
TensorType ::= 'tensor' '<' Shape TensorElementType '>'
Shape ::= {DimensionSize 'x'}
DimensionSize ::= digit {digit} | '?'
テンソル型は、テンソル(多次元配列)を表します。これらにはシェイプと要素の型があります。シェイプは、対応するディメンション(軸とも呼ばれます)の昇順で、0 から R-1 までの番号が付けられた非負または不明なディメンション サイズを表します。ディメンションの数 R は、ランクと呼ばれます。たとえば、tensor<2x3xf32> は、形状 2x3 と要素型 f32 を持つテンソル型です。サイズが 2 と 3 の 2 つのディメンション(つまり、2 つの軸)- 0 番目のディメンションと 1 番目のディメンションがあります。ランクは 2 です。
シェイプは部分的に不明(動的)な場合と完全に不明な場合があります。たとえば、tensor<?x2xf64> は部分的に不明で、tensor<?x?xf64> は完全に不明です。動的ディメンションのサイズは ? を使用して表されます。シェイプのランク付けを解除することはできません。
将来的には、テンソル型を次元サイズや要素型を超えて拡張し、レイアウト(#629)やスパース性(#1078)などを含めることを検討しています。
QuantizedTensorType ::= 'tensor' '<' Shape QuantizedTensorElementType '>'
QuantizedTensorElementType ::= '!quant.uniform' '<'
QuantizationStorageType
['<' QuantizationStorageMin ':' QuantizationStorageMax '>']
':' QuantizationExpressedType
[':' QuantizationDimension]
',' QuantizationParameters '>'
QuantizationStorageType ::= IntegerType
QuantizationStorageMin ::= IntegerLiteral
QuantizationStorageMax ::= IntegerLiteral
QuantizationExpressedType ::= FloatType
QuantizationDimension ::= IntegerLiteral
QuantizationParameters ::= QuantizationParameter
| '{' QuantizationParameter {',' QuantizationParameter} '}'
QuantizationParameter ::= QuantizationScale [':' QuantizationZeroPoint]
QuantizationScale ::= FloatLiteral
QuantizationZeroPoint ::= IntegerLiteral
| 名前 | タイプ | 制約 |
|---|---|---|
storage_type |
整数型 | (C1 ~ C3)、(C8) |
storage_min |
整数定数 | (C1)、(C3)、(C7) |
storage_max |
整数定数 | (C2)、(C3)、(C7) |
expressed_type |
浮動小数点型 | (C4) |
quantization_dimension |
オプションの整数定数 | (C10-C12) |
scales |
可変個数の浮動小数点定数 | (C4-C6)、(C9)、(C10)、(C13) |
zero_points |
可変個数の整数定数 | (C7-C9) |
量子化された要素タイプは、ストレージ タイプの整数値を storage_min から storage_max(両端を含む)の範囲で表します。これは、表現されたタイプの浮動小数点値に対応します。指定された整数値 i に対して、対応する浮動小数点値 f は f = (i - zero_point) * scale として計算できます。ここで、scale と zero_point は量子化パラメータと呼ばれます。storage_min と storage_max は文法では省略可能ですが、デフォルト値はそれぞれ min_value(storage_type) と max_value(storage_type) です。量子化された要素の型には次の制約があります。
- (C1)
type(storage_min) = storage_type。 - (C2)
type(storage_max) = storage_type。 - (C3)
min_value(storage_type) <= storage_min < storage_max <= max_value(storage_type)。 - (C4)
type(scales...) = expressed_type。 - (C5)
0 < scales。 - (C6)
is_finite(scales...)。 - (C7)
storage_min <= zero_points <= storage_max。 - (C8)
type(zero_points...) = storage_type。 - (C9)
size(scales) = size(zero_points)。 - (C10)
is_empty(quantization_dimension)の場合はsize(scales) = 1です。 - (C11)
0 <= quantization_dimension。
現時点では、QuantizationScale は浮動小数点定数ですが、乗数とシフトで表される整数ベースのスケールに対する関心が高まっています。この件については、近いうちに検討する予定です(#1404)。
QuantizationZeroPoint のセマンティクス(型、値、量子化されたテンソル型にゼロポイントを 1 つだけ含めることができるか、複数含めることができるかなど)については、現在も議論が続いています。この議論の結果に基づいて、ゼロポイントに関する仕様は今後変更される可能性があります(#1405)。
もう 1 つの継続中の議論は、QuantizationStorageMin と QuantizationStorageMax のセマンティクスに関するものです。これらの値と量子化されたテンソルの値に制約を課すべきかどうかを判断します(#1406)。
最後に、不明なディメンション サイズの表現を検討しているのと同様に(#1407)、不明なスケールとゼロポイントの表現を検討する予定です。
量子化テンソル型は、量子化された要素を持つテンソルを表します。これらのテンソルは、要素の型が通常の要素の型ではなく量子化された要素の型である点を除き、通常のテンソルとまったく同じです。
量子化テンソルでは、量子化はテンソル全体に対して 1 つの scale と zero_point を持つ「テンソル単位」にすることも、特定のディメンション quantization_dimension のスライスごとに 1 つのペアを持つ複数の scales と zero_points を持つ「軸単位」にすることもできます。より正式には、軸ごとの量子化を含むテンソル t では、quantization_dimension の dim(t, quantization_dimension) スライス(t[:, ..., 0, ..., :], t[:, ..., 1, ..., :] など)があります。i 番目のスライス内のすべての要素は、量子化パラメータとして scales[i] と zero_points[i] を使用します。量子化テンソル型には次の制約があります。
- テンソルごとの量子化の場合:
- 追加の制約はありません。
- 軸ごとの量子化の場合:
- (C12)
quantization_dimension < rank(self)。 - (C13)
dim(self, quantization_dimension) = size(scales)。
- (C12)
TokenType ::= 'token'
トークン型は、トークン(一部のオペレーションで生成および使用される不透明な値)を表します。トークンは、実行セクションで説明されているように、オペレーションに実行順序を適用するために使用されます。
TupleType ::= 'tuple' '<' TupleElementTypes '>'
TupleElementTypes ::= [ValueType {',' ValueType}]
バッファ型はバッファを表します。たとえば、XLA では、バッファは一貫したストレージを持つ多次元配列です。テンソル型と同様に、バッファ型にはシェイプと要素型があります。シェイプは、対応するディメンション(軸とも呼ばれます)の昇順で、0 から R-1 までの番号が付けられた非負または不明なディメンション サイズを表します。次元の数 R はランクと呼ばれます。たとえば、memref<2x3xf32> は、形状 2x3 と要素型 f32 を持つバッファ型です。サイズが 2 と 3 の 2 つのディメンション(つまり、2 つの軸)である 0 番目のディメンションと 1 番目のディメンションがあります。ランクは 2 です。
バッファは custom_call から CreateBuffer または Pin を使用して割り当て、custom_call から Unpin を使用して割り当て解除できます。custom_call オペレーションのみがバッファ内のコンテンツを読み書きできます。詳しくは、custom_call をご覧ください。
タプル型は、タプル(異種リスト)を表します。タプルは、HLO との互換性のためだけに存在するレガシー機能です。HLO では、タプルは可変長引数の入力と出力を表すために使用されます。StableHLO では、可変長入力と出力がネイティブでサポートされています。StableHLO でタプルが使用されるのは、HLO ABI を包括的に表現する場合のみです。たとえば、T、tuple<T>、tuple<tuple<T>> は、特定の実装によって大きく異なる場合があります。今後、HLO ABI を変更する予定です。これにより、StableHLO からタプル型を削除できる可能性があります(#598)。
TensorElementType ::= BooleanType | IntegerType | FloatType | ComplexType
BooleanType ::= 'i1'
IntegerType ::= SignedIntegerType | UnsignedIntegerType
SignedIntegerType ::= 'si2' | 'si4' | 'si8' | 'si16' | 'si32' | 'si64'
UnsignedIntegerType ::= 'ui2' | 'ui4' | 'ui8' | 'ui16' | 'ui32' | 'ui64'
FloatType ::= 'f4E2M1FN' | 'f6E2M3FN' | 'f6E3M2FN' | 'f8E3M4' | 'f8E4M3'
| 'f8E4M3FN' | 'f8E4M3FNUZ' | 'f8E4M3B11FNUZ' | 'f8E5M2'
| 'f8E5M2FNUZ' | 'f8E8M0FNU' | 'bf16' | 'f16' | 'f32' | 'f64'
TensorFloat32 ::= 'tf32'
ComplexType ::= 'complex' '<' ComplexElementType '>'
ComplexElementType ::= 'f32' | 'f64'
要素の型は、テンソル型の要素を表します。多くのプログラミング言語とは異なり、これらの型は StableHLO ではファーストクラスではありません。つまり、StableHLO プログラムではこれらの型の値を直接表現できません(そのため、T 型のスカラー値を tensor<T> 型の 0 次元テンソル値で表現するのが一般的です)。
- ブール型は、ブール値
trueとfalseを表します。 - 整数型は、符号付き(
si)または符号なし(ui)のいずれかで、サポートされているビット幅(2、4、8、16、32、64)のいずれかになります。符号付きsiN型は-2^(N-1)から2^(N-1)-1までの整数値を表し、符号なしuiN型は0から2^N-1までの整数値を表します。 - 浮動小数点型は、次のいずれかになります。
f8E3M4、f8E4M3、f8E5M2は、IEEE-754 規則に従った 8 ビットの浮動小数点数です。- ディープ ラーニング用の FP8 形式で説明されている FP8 形式の
E4M3エンコードとE5M2エンコードにそれぞれ対応するf8E4M3FN型とf8E5M2型。 - ディープ ニューラル ネットワーク用の 8 ビット数値形式で説明されている FP8 形式の
E4M3エンコードとE5M2エンコードに対応するf8E4M3FNUZ型とf8E5M2FNUZ型。 - ディープ ニューラル ネットワークのハイブリッド 8 ビット浮動小数点(HFP8)トレーニングと推論で説明されている FP8 形式の
E4M3エンコードに対応するf8E4M3B11FNUZタイプ。 - BFloat16: Cloud TPU で高パフォーマンスを発揮させる秘訣で説明されている
bfloat16形式に対応するbf16型。 f16、f32、f64の各型は、それぞれ IEEE 754 標準で説明されているbinary16(「半精度」)、binary32(「単精度」)、binary64(「倍精度」)の各形式に対応しています。tf32型は TensorFloat32 形式に対応しており、StableHLO でのサポートは限定的です。- OCP マイクロ スケーリング形式の仕様で説明されている
f4E2M1FN、f6E2M3FN、f6E3M2FN、f8E8M0FNUMX(マイクロ スケーリング)タイプ。
- 複素型は、同じ要素型の実部と虚部を持つ複素数値を表します。サポートされている複合型は、
complex<f32>(両方の部分がf32型)とcomplex<f64>(両方の部分がf64型)です。
FunctionType ::= '(' InputTypes ')' '->' '(' OutputTypes ')'
InputTypes ::= [ValueType {',' ValueType}]
OutputTypes ::= [ValueType {',' ValueType}]
関数型は、名前付き関数と匿名関数の両方を表します。これらには、入力型(-> の左側の型のリスト)と出力型(-> の右側の型のリスト)があります。多くのプログラミング言語では、関数型はファーストクラスですが、StableHLO ではそうではありません。
StringType ::= 'string'
String 型はバイトのシーケンスを表します。多くのプログラミング言語とは異なり、StableHLO では文字列型はファーストクラスではなく、プログラム要素の静的メタデータを指定するためにのみ使用されます。
運用
StableHLO オペレーション(ops とも呼ばれます)は、ML モデルの高レベル オペレーションのクローズド セットを表します。前述のように、StableHLO の構文は MLIR に大きく影響を受けています。これは必ずしも最も人間工学的な代替手段ではありませんが、ML フレームワークと ML コンパイラ間の相互運用性を高めるという StableHLO の目標に最も適していると言えます。
Op ::= [OpOutputs] OpName OpInputs ':' OpSignature
OpName ::= '"' 'stablehlo' '.' OpMnemonic '"'
OpMnemonic ::= 'abs' | 'add' | ...
StableHLO オペレーション(ops とも呼ばれます)には、名前、入力/出力、シグネチャがあります。名前は、stablehlo. プレフィックスと、サポートされているオペレーションの 1 つを一意に識別するニーモニックで構成されます。サポートされているすべてのオペレーションの包括的なリストについては、以下をご覧ください。
OpInputs ::= OpInputValues OpInputFuncs OpInputAttrs
OpInputValues ::= '(' [OpInputValue {',' OpInputValue}] ')'
OpInputValue ::= ValueId
OpInputFuncs ::= ['(' OpInputFunc {',' OpInputFunc} ')']
OpInputAttrs ::= ['{' OpInputAttr {',' OpInputAttr} '}']
OpOutputs ::= [OpOutput {',' OpOutput} '=']
OpOutput ::= ValueId
オペレーションは入力を消費し、出力を生成します。入力は、入力値(実行中に計算)、入力関数(StableHLO では関数はファーストクラスの値ではないため静的に提供)、入力属性(これも静的に提供)に分類されます。オペレーションによって使用および生成される入力と出力の種類は、そのニーモニックによって異なります。たとえば、add オペレーションは 2 つの入力値を消費し、1 つの出力値を生成します。これに対し、select_and_scatter オペレーションは 3 つの入力値、2 つの入力関数、3 つの入力属性を使用します。
OpInputFunc ::= '{' Unused FuncInputs ':' FuncBody '}'
Unused ::= '^' digit {digit}
| '^' letter {letter | digit}
入力関数(匿名関数とも呼ばれます)は、名前付き関数とよく似ていますが、次の点が異なります。1) 識別子がない(そのため「匿名」という名前が付けられています)。2) 出力型を宣言しない(出力型は関数内の return 演算から推論されます)。
入力関数の構文には、MLIR との互換性のために現在未使用の部分が含まれています(上記の Unused 生成を参照)。MLIR には、ジャンプ オペレーションを介して接続された複数のオペレーション「ブロック」を持つことができる「リージョン」というより一般的な概念があります。これらのブロックには Unused 本番環境に対応する ID があり、互いに区別できます。StableHLO にはジャンプ オペレーションがないため、MLIR 構文の対応する部分は使用されません(ただし、まだ存在します)。
OpInputAttr ::= OpInputAttrName '=' OpInputAttrValue
OpInputAttrName ::= letter {letter | digit}
OpInputAttrValue ::= Constant
入力属性には、名前と、サポートされている定数のいずれかの値があります。これらは、プログラム要素の静的メタデータを指定する主な方法です。たとえば、concatenate オペレーションは、属性 dimension を使用して、入力値を連結するディメンションを指定します。同様に、slice op は start_indices や limit_indices などの複数の属性を使用して、入力値をスライスするために使用される境界を指定します。
現在、実際の StableHLO プログラムには、このドキュメントで説明されていない属性が含まれていることがあります。今後、これらの属性を StableHLO opset に吸収するか、StableHLO プログラムに表示されないようにする予定です。それまでの間、これらの属性のリストを以下に示します。
layout(#629)。mhlo.frontend_attributes(#628)。mhlo.sharding(#619)。output_operand_aliases(#740)。- 位置情報メタデータ(#594)。
OpSignature ::= '(' [ValueType {',' ValueType}] ')' '->' '(' [ValueType {',' ValueType}] ')'
Op シグネチャは、すべての入力値の型(-> の左側の型のリスト)とすべての出力値の型(-> の右側の型のリスト)で構成されます。厳密に言うと、入力型は冗長であり、出力型もほとんどの場合冗長です(ほとんどの StableHLO オペレーションでは、出力型は入力から推論できるため)。ただし、MLIR との互換性を確保するため、op シグネチャは StableHLO 構文の一部として意図的に残されています。
以下に、ニーモニックが select_and_scatter の op の例を示します。3 つの入力値(%operand、%source、%init_value)、2 つの入力関数、3 つの入力属性(window_dimensions、window_strides、padding)を使用します。op のシグネチャには入力値の型のみが含まれていることに注意してください(入力関数と属性の型はインラインで指定されます)。
%result = "stablehlo.select_and_scatter"(%operand, %source, %init_value) ({
^bb0(%arg0: tensor<i32>, %arg1: tensor<i32>):
%0 = "stablehlo.compare"(%arg0, %arg1) {
comparison_direction = #stablehlo<comparison_direction GE>
} : (tensor<i32>, tensor<i32>) -> tensor<i1>
"stablehlo.return"(%0) : (tensor<i1>) -> ()
}, {
^bb0(%arg0: tensor<i32>, %arg1: tensor<i32>):
%0 = "stablehlo.add"(%arg0, %arg1) : (tensor<i32>, tensor<i32>) -> tensor<i32>
"stablehlo.return"(%0) : (tensor<i32>) -> ()
}) {
window_dimensions = dense<[3, 1]> : tensor<2xi64>,
window_strides = dense<[2, 1]> : tensor<2xi64>,
padding = dense<[[0, 1], [0, 0]]> : tensor<2x2xi64>
} : (tensor<4x2xi32>, tensor<2x2xi32>, tensor<i32>) -> tensor<4x2xi32>
定数
Constant ::= BooleanConstant
| IntegerConstant
| FloatConstant
| ComplexConstant
| TensorConstant
| QuantizedTensorConstant
| StringConstant
| EnumConstant
StableHLO 定数には、StableHLO 値を表すリテラルと型があります。通常、型は定数構文の一部です。ただし、型が一意に決まる場合(ブール定数の型は i1 であり、整数定数の型は複数考えられる場合など)は除きます。
BooleanConstant ::= BooleanLiteral
BooleanLiteral ::= 'true' | 'false'
ブール定数は、ブール値 true と false を表します。ブール定数の型は i1 です。
IntegerConstant ::= IntegerLiteral ':' IntegerType
IntegerLiteral ::= ['-' | '+'] DecimalDigits
| ['-' | '+'] '0x' HexadecimalDigits
DecimalDigits ::= decimalDigit {decimalDigit}
HexadecimalDigits ::= hexadecimalDigit {hexadecimalDigit}
decimalDigit ::= '0' | ... | '9'
hexadecimalDigit ::= decimalDigit | 'a' | ... | 'f' | 'A' | ... | 'F'
整数定数は、10 進数または 16 進数表記を使用する文字列を介して整数値を表します。他の基数(2 進数や 8 進数など)はサポートされていません。整数定数には次の制約があります。
- (C1)
is_wellformed(integer_literal, integer_type)。
FloatConstant ::= FloatLiteral ':' FloatType
FloatLiteral ::= SignPart IntegerPart FractionalPart ScientificPart
| '0x' [HexadecimalDigits]
SignPart ::= ['-' | '+']
IntegerPart ::= DecimalDigits
FractionalPart ::= ['.' [DecimalDigits]]
ScientificPart ::= [('e' | 'E') ['-' | '+'] DecimalDigits]
浮動小数点定数は、10 進数または科学的記数法を使用する文字列を介して浮動小数点値を表します。また、16 進表記を使用して、対応する型の浮動小数点形式の基になるビットを直接指定することもできます。浮動小数点定数には次の制約があります。
- (C1)16 進数以外の表記が使用されている場合、
is_wellformed(float_literal, float_type)。 - (C2)16 進表記を使用する場合は、
size(hexadecimal_digits) = num_bits(float_type) / 4。
ComplexConstant ::= ComplexLiteral ':' ComplexType
ComplexLiteral ::= '(' RealPart ',' ImaginaryPart ')'
RealPart ::= FloatLiteral
ImaginaryPart ::= FloatLiteral
複素定数は、実数部(最初)と虚数部(2 番目)のリストを使用して複素値を表します。たとえば、(1.0, 0.0) : complex<f32> は 1.0 + 0.0i を表し、(0.0, 1.0) : complex<f32> は 0.0 + 1.0i を表します。これらの部分がメモリに保存される順序は、実装定義です。複素定数には次の制約があります。
- (C1)
is_wellformed(real_part, complex_element_type(complex_type))。 - (C2)
is_wellformed(imaginary_part, complex_element_type(complex_type))。
TensorConstant ::= TensorLiteral ':' TensorType
TensorLiteral ::= 'dense' '<' (DenseLiteral | ElementLiteral) '>'
DenseLiteral ::= DenseDimension | DenseElements
DenseDimension ::= '[' [DenseLiteral {',' DenseLiteral}] ']'
DenseElements ::= [ElementLiteral {',' ElementLiteral}]
ElementLiteral ::= BooleanLiteral | IntegerLiteral | FloatLiteral | ComplexLiteral
Tensor 定数は、NumPy 表記で指定されたネストされたリストを使用してテンソル値を表します。たとえば、dense<[[1, 2, 3], [4, 5, 6]]> : tensor<2x3xi32> は、インデックスから要素への次のマッピングを持つテンソル値を表します。{0, 0} => 1、{0, 1} => 2、{0, 2} => 3、{1, 0} => 4、{1, 1} => 5、{1, 2} => 6。これらの要素がメモリに保存される順序は、実装定義です。テンソル定数には次の制約があります。
- (C1)
has_syntax(tensor_literal, element_type(tensor_type))。ここで、has_syntax(element_literal: Syntax, element_type: Type) = is_wellformed(element_literal, type)。has_syntax(tensor_literal: List, element_type: Type) = has_syntax(tensor_literal..., element_type)。
- (C2)
has_shape(tensor_literal, shape(tensor_type))。ここで、has_shape(element_literal: Syntax, []) = true。has_shape(tensor_literal: List, shape: List) = size(tensor_literal) = shape[0] and has_shape(tensor_literal..., shape[1:])。- それ以外の場合は
false。
QuantizedTensorConstant ::= QuantizedTensorLiteral ':' QuantizedTensorType
QuantizedTensorLiteral ::= 'dense' '<' (DenseLiteral | ElementLiteral) '>'
量子化テンソル定数は、テンソル定数と同じ表記を使用して量子化テンソル値を表します。要素は、ストレージ タイプの定数として指定されます。量子化されたテンソル定数には次の制約があります。
- (C1)
has_syntax(quantized_tensor_literal, storage_type(quantized_tensor_type))。 - (C2)
has_shape(quantized_tensor_literal, shape(quantized_tensor_type))。
StringConstant ::= StringLiteral
StringLiteral ::= '"' {stringCharacter | escapeSequence} '"'
stringCharacter ::= all ASCII characters except '\00', '\01', ... '\1f' and '"'
escapeSequence ::= '\' ('"' | '\' | 'n' | 't' | (hexadecimalDigit hexadecimalDigit))
文字列リテラルは、ASCII 文字とエスケープ シーケンスを使用して指定されたバイトで構成されます。エンコードに依存しないため、これらのバイトの解釈は実装定義です。文字列リテラルの型は string です。
運用
abs
セマンティクス
operand テンソルに対して要素ごとの絶対値演算を実行し、result テンソルを生成します。要素のタイプに応じて、次の処理を行います。
- 符号付き整数の場合: 整数剰余。
- 浮動小数点の場合: IEEE-754 の
abs。 - 複素数の場合: 複素数の絶対値。
- 量子化された型の場合:
dequantize_op_quantize(abs, operand, type(result))。
入力
| ラベル | 名前 | タイプ | 制約 |
|---|---|---|---|
| (I1) | operand |
符号付き整数型、浮動小数点型、複素数型のテンソル、またはテンソルごとの量子化テンソル | (C1-C2) |
出力
| 名前 | タイプ | 制約 |
|---|---|---|
result |
符号付き整数型または浮動小数点型のテンソル、またはテンソルごとの量子化テンソル | (C1-C2) |
制約
- (C1)
shape(result) = shape(operand)。 - (C2)
baseline_element_type(result)は次のように定義されます。is_complex(operand)の場合、complex_element_type(element_type(operand))。- それ以外の場合は
baseline_element_type(operand)。
例
// %operand: [-2, 0, 2]
%result = "stablehlo.abs"(%operand)< : (t>ens>or3xi32<) - t>ensor3xi32
// %result: [2, 0, 2]
追加
セマンティクス
2 つのテンソル lhs と rhs の要素ごとの加算を行い、result テンソルを生成します。要素のタイプに応じて、次の処理を行います。
- ブール値の場合: 論理 OR。
- 整数の場合: 整数の加算。
- 浮動小数点の場合: IEEE-754 の
addition。 - 複素数の場合: 複素数の加算。
- 量子化された型の場合:
dequantize_op_quantize(add, lhs, rhs, type(result))。
入力
| ラベル | 名前 | タイプ | 制約 |
|---|---|---|---|
| (I1) | lhs |
テンソルまたは量子化テンソル | (C1-C6) |
| (I2) | rhs |
テンソルまたは量子化テンソル | (C1 ~ C5)、(C7) |
出力
| 名前 | タイプ | 制約 |
|---|---|---|
result |
テンソルまたは量子化テンソル | (C1 ~ C7) |
制約
- オペレーションが量子化されていないテンソルを使用する場合:
- (C1)
type(lhs) = type(rhs) = type(result)。
- (C1)
- オペレーションで量子化テンソルを使用する場合:
- (C2)
is_quantized(lhs) and is_quantized(rhs) and is_quantized(result)。 - (C3)
storage_type(lhs) = storage_type(rhs) = storage_type(result)。 - (C4)
expressed_type(lhs) = expressed_type(rhs) = expressed_type(result)。 - (C5)
(is_per_axis_quantized(lhs) or is_per_axis_quantized(rhs)) = is_per_axis_quantized(result)。 - (C6)
is_per_axis_quantized(lhs)の場合はquantization_dimension(lhs) = quantization_dimension(result)。 - (C7)
is_per_axis_quantized(rhs)の場合はquantization_dimension(rhs) = quantization_dimension(result)です。
- (C2)
例
// %lhs: [[1, 2], [3, 4]]
// %rhs: [[5, 6], [7, 8]]
%result = "stablehlo.add"(%lhs, %rhs)< : (ten>sor2x2xi<32, ten>sor>2x2xi32<) - ten>sor2x2xi32
// %result: [[6, 8], [10, 12]]
after_all
セマンティクス
inputs を生成するオペレーションが、result に依存するオペレーションの前に実行されるようにします。このオペレーションの実行は何も行いません。result から inputs へのデータ依存関係を確立するためにのみ存在します。
入力
| ラベル | 名前 | タイプ |
|---|---|---|
| (I1) | inputs |
可変個数の token |
出力
| 名前 | タイプ |
|---|---|
result |
token |
例
// %input0: !stablehlo.token
// %input1: !stablehlo.token
%result = "stablehlo.after_all"(%input0, %input1) : (!stablehlo.token, !stablehl>o.token) - !stablehlo.token
all_gather
セマンティクス
StableHLO プロセス グリッドの各プロセス グループ内で、各プロセスの operands テンソルの値を all_gather_dim に沿って連結し、results テンソルを生成します。
このオペレーションは、StableHLO プロセス グリッドを process_groups に分割します。これは次のように定義されます。
channel_id <= 0 and use_global_device_ids = falseの場合、cross_replica(replica_groups)。channel_id > 0 and use_global_device_ids = falseの場合、cross_replica_and_partition(replica_groups)。channel_id > 0 and use_global_device_ids = trueの場合、flattened_ids(replica_groups)。
その後、各 process_group 内で次の操作を行います。
process_group内のすべてのreceiverのoperands...@receiver = [operand@sender for sender in process_group]。process_group内のすべてのprocessのresults...@process = concatenate(operands...@process, all_gather_dim)。
入力
| ラベル | 名前 | タイプ | 制約 |
|---|---|---|---|
| (I1) | operands |
可変数のテンソルまたはテンソルごとの量子化テンソル | (C1)、(C6) |
| (I2) | all_gather_dim |
si64 型の定数 |
(C1)、(C6) |
| (I3) | replica_groups |
型 si64 の 2 次元テンソル定数 |
(C2-C4) |
| (I4) | channel_id |
si64 型の定数 |
(C5) |
| (I5) | use_global_device_ids |
i1 型の定数 |
(C5) |
出力
| 名前 | タイプ | 制約 |
|---|---|---|
results |
可変数のテンソルまたはテンソルごとの量子化テンソル | (C6) |
制約
- (C1)
0 <= all_gather_dim < rank(operands...)。 - (C2)
is_unique(replica_groups)。 - (C3)
size(replica_groups)は次のように定義されます。cross_replicaを使用する場合はnum_replicas。cross_replica_and_partitionを使用する場合はnum_replicas。flattened_idsを使用する場合はnum_processes。
- (C4)
0 <= replica_groups < size(replica_groups)。 - (C5)
use_global_device_ids = trueの場合はchannel_id > 0です。 - (C6)
type(results...) = type(operands...)ただし、dim(results..., all_gather_dim) = dim(operands..., all_gather_dim) * dim(process_groups, 1)。
例
// num_replicas: 2
// num_partitions: 1
// %operand0@(0, 0): [[1, 2], [3, 4]]
// %operand0@(1, 0): [[5, 6], [7, 8]]
// %operand1@(0, 0): [[11, 12], [13, 14]]
// %operand1@(1, 0): [[15, 16], [17, 18]]
%result:2 = "stablehlo.all_gather"(%operand0, %operand1) {
all_gather_dim = 1 : i64,
replica_grou<ps = den>se[[0, 1]<] : ten>sor1x2xi64,
// channel_id = 0
channel_handle = #stablehlo.chan<nel_handlehandle = 0>, type = 0
// use_global_device_ids = false
}< : (ten>sor2x2xi<64, ten>sor>2x2xi64)< - (ten>sor2x4xi<64, ten>sor2x4xi64)
// %result0@(0, 0): [[1, 2, 5, 6], [3, 4, 7, 8]]
// %result0@(1, 0): [[1, 2, 5, 6], [3, 4, 7, 8]]
// %result1@(0, 0): [[11, 12, 15, 16], [13, 14, 17, 18]]
// %result1@(1, 0): [[11, 12, 15, 16], [13, 14, 17, 18]]
all_reduce
セマンティクス
StableHLO プロセス グリッドの各プロセス グループ内で、各プロセスの operands テンソルの値に削減関数 computation を適用し、results テンソルを生成します。
このオペレーションは、StableHLO プロセス グリッドを process_groups に分割します。これは次のように定義されます。
channel_id <= 0 and use_global_device_ids = falseの場合、cross_replica(replica_groups)。channel_id > 0 and use_global_device_ids = falseの場合、cross_replica_and_partition(replica_groups)。channel_id > 0 and use_global_device_ids = trueの場合、flattened_ids(replica_groups)。
その後、各 process_group 内で次の操作を行います。
results...@process[result_index] = exec(schedule)(二分木schedule)ここで、exec(node)=computation(exec(node.left), exec(node.right))。exec(leaf)=leaf.value。
scheduleは、インオーダー トラバーサルがto_destination_type(operands...@process_group...[result_index], type(func_inputs(computation)[0]))である実装定義のバイナリツリーです。
入力
| ラベル | 名前 | タイプ | 制約 |
|---|---|---|---|
| (I1) | operands |
可変数のテンソルまたはテンソルごとの量子化テンソル | (C5)、(C6) |
| (I2) | replica_groups |
si64 型の 1 次元テンソル定数の可変長引数 |
(C1 ~ C3) |
| (I3) | channel_id |
si64 型の定数 |
(C4) |
| (I4) | use_global_device_ids |
i1 型の定数 |
(C4) |
| (I5) | computation |
関数 | (C5) |
出力
| 名前 | タイプ | 制約 |
|---|---|---|
results |
可変数のテンソルまたはテンソルごとの量子化テンソル | (C6 ~ C7) |
制約
- (C1)
is_unique(replica_groups)。 - (C2)
size(replica_groups)は次のように定義されます。cross_replicaを使用する場合はnum_replicas。cross_replica_and_partitionを使用する場合はnum_replicas。flattened_idsを使用する場合はnum_processes。
- (C3)
0 <= replica_groups < size(replica_groups)。 - (C4)
use_global_device_ids = trueの場合はchannel_id > 0。 - (C5)
computationは(tensor<E>, tensor<E>) -> (tensor<E>)型です(is_promotable(element_type(operand), E))。 - (C6)
shape(results...) = shape(operands...)。 - (C7)
element_type(results...) = E。
例
// num_replicas: 2
// num_partitions: 1
// %operand0@(0, 0): [1, 2, 3, 4]
// %operand0@(1, 0): [5, 6, 7, 8]
// %operand1@(0, 0): [9, 10, 11, 12]
// %operand1@(1, 0): [13, 14, 15, 16]
%result:2 = "stablehlo.all_reduce"(%operand0, %operand0) ({
^bb0(%ar<g0:> tensori64, %ar<g1:> tensori64):
%0 = "stablehlo.add"(%ar<g0,> %arg1) <: (>ten>sori64,< te>nsori64) - tensori64
"stable<hlo>.re>turn"(%0) : (tensori64) - ()<
}) {
>replica_g<roups => dense[[0, 1]] : tensor1x2xi64,
// channel_id = 0
channel_hand<le = #stablehlo.chan>nel_handlehandle = 0, type = 0
// use_global_<devic>e_ids = <false>
} >: (tenso<r4xi6>4, tenso<r4xi6>4) - (tensor4xi64, tensor4xi64)
// %result0@(0, 0): [6, 8, 10, 12]
// %result0@(1, 0): [6, 8, 10, 12]
// %result1@(0, 0): [22, 24, 26, 28]
// %result1@(1, 0): [22, 24, 26, 28]
all_to_all
セマンティクス
StableHLO プロセス グリッドの各プロセス グループ内で、operands テンソルの値を split_dimension に沿って分割し、分割された部分をプロセス間で分散させ、分散された部分を concat_dimension に沿って連結して results テンソルを生成します。このオペレーションは、StableHLO プロセス グリッドを process_groups に分割します。これは次のように定義されます。
channel_id <= 0の場合、cross_replica(replica_groups)。channel_id > 0の場合、cross_partition(replica_groups)。
その後、各 process_group 内で次の操作を行います。
process_group内のすべてのsenderのsplit_parts...@sender = split(operands...@sender, split_count, split_dimension)。receiver_index = process_group.index(receiver)の場合:scattered_parts...@receiver = [split_parts...@sender[receiver_index] for sender in process_group]results...@process = concatenate(scattered_parts...@process, concat_dimension)。
入力
| ラベル | 名前 | タイプ | 制約 |
|---|---|---|---|
| (I1) | operands |
可変数のテンソルまたはテンソルごとの量子化テンソル | (C1 ~ C3)、(C9) |
| (I2) | split_dimension |
si64 型の定数 |
(C1)、(C2)、(C9) |
| (I3) | concat_dimension |
si64 型の定数 |
(C3)、(C9) |
| (I4) | split_count |
si64 型の定数 |
(C2)、(C4)、(C8)、(C9) |
| (I5) | replica_groups |
型 si64 の 2 次元テンソル定数 |
(C5-C8) |
| (I6) | channel_id |
si64 型の定数 |
出力
| 名前 | タイプ | 制約 |
|---|---|---|
results |
可変数のテンソルまたはテンソルごとの量子化テンソル | (C9) |
制約
- (C1)
0 <= split_dimension < rank(operands...)。 - (C2)
dim(operands..., split_dimension) % split_count = 0。 - (C3)
0 <= concat_dimension < rank(operands...)。 - (C4)
0 < split_count。 - (C5)
is_unique(replica_groups)。 - (C6)
size(replica_groups)は次のように定義されます。cross_replicaを使用する場合はnum_replicas。cross_partitionを使用する場合はnum_partitions。
- (C7)
0 <= replica_groups < size(replica_groups)。 - (C8)
dim(replica_groups, 1) = split_count。 - (C9)
type(results...) = type(operands...)。ただし、split_dimension != concat_dimensionの場合は除く。dim(results..., split_dimension) = dim(operands..., split_dimension) / split_count。dim(results..., concat_dimension) = dim(operands..., concat_dimension) * split_count。
例
// num_replicas: 2
// num_partitions: 1
// %operand1@(0, 0): [[1, 2, 3, 4],
// [5, 6, 7, 8]]
// %operand1@(1, 0): [[9, 10, 11, 12],
// [13, 14, 15, 16]]
// %operand2@(0, 0): [[17, 18, 19, 20],
// [21, 22, 23, 24]]
// %operand2@(1, 0): [[25, 26, 27, 28],
// [29, 30, 31, 32]]
%result:2 = "stablehlo.all_to_all"(%operand1, %operand2) {
split_dimension = 1 : i64,
concat_dimension = 0 : i64,
split_count = 2 : i64,
replica_grou<ps = den>se[[0, 1]<] : ten>sor1x2xi64
// channel_id = 0
}< : (ten>sor2x4xi<64, ten>sor>2x4xi64)< - (ten>sor4x2xi<64, ten>sor4x2xi64)
// %result#0@(0, 0): [[1, 2], [5, 6], [9, 10], [13, 14]]
// %result#0@(1, 0): [[3, 4], [7, 8], [11, 12], [15, 16]]
// %result#1@(0, 0): [[17, 18], [21, 22], [25, 26], [29, 30]]
// %result#1@(1, 0): [[19, 20], [23, 24], [27, 28], [31, 32]]
と
セマンティクス
2 つのテンソル lhs と rhs の要素ごとの AND を実行し、result テンソルを生成します。要素のタイプに応じて、次の処理を行います。
- ブール値の場合: 論理 AND。
- 整数の場合: ビット単位の AND。
入力
| ラベル | 名前 | タイプ | 制約 |
|---|---|---|---|
| (I1) | lhs |
ブール値または整数型のテンソル | (C1) |
| (I2) | rhs |
ブール値または整数型のテンソル | (C1) |
出力
| 名前 | タイプ | 制約 |
|---|---|---|
result |
ブール値または整数型のテンソル | (C1) |
制約
- (C1)
type(lhs) = type(rhs) = type(result)。
例
// %lhs: [[1, 2], [3, 4]]
// %rhs: [[5, 6], [7, 8]]
%result = "stablehlo.and"(%lhs, %rhs)< : (ten>sor2x2xi<32, ten>sor>2x2xi32<) - ten>sor2x2xi32
// %result: [[1, 2], [3, 0]]
atan2
セマンティクス
lhs テンソルと rhs テンソルに対して要素ごとの atan2 演算を実行し、result テンソルを生成します。要素のタイプに応じて、次の処理を行います。
- 浮動小数点の場合: IEEE-754 の
atan2。 - 複素数の場合は、複素数の atan2。
- 量子化された型の場合:
dequantize_op_quantize(atan2, lhs, rhs, type(result))。
入力
| ラベル | 名前 | タイプ | 制約 |
|---|---|---|---|
| (I1) | lhs |
浮動小数点型または複素数型のテンソル、またはテンソルごとの量子化テンソル | (C1) |
| (I2) | rhs |
浮動小数点型または複素数型のテンソル、またはテンソルごとの量子化テンソル | (C1) |
出力
| 名前 | タイプ | 制約 |
|---|---|---|
result |
浮動小数点型または複素数型のテンソル、またはテンソルごとの量子化テンソル | (C1) |
制約
- (C1)
baseline_type(lhs) = baseline_type(rhs) = baseline_type(result)。
例
// %lhs: [0.0, 1.0, -1.0]
// %rhs: [0.0, 0.0, 0.0]
%result = "stablehlo.atan2"(%lhs, %rhs)< : (t>ensor3xf<64, t>ens>or3xf64<) - t>ensor3xf64
// %result: [0.0, 1.57079637, -1.57079637] // [0.0, pi/2, -pi/2]
batch_norm_grad
セマンティクス
grad_output からバックプロパゲーションして batch_norm_training の複数の入力の勾配を計算し、grad_operand、grad_scale、grad_offset テンソルを生成します。より正式には、このオペレーションは、次のように Python 構文を使用して既存の StableHLO オペレーションへの分解として表現できます。
def compute_sum(operand, feature_index):
(sum,) = reduce(
inputs=[operand],
init_values=[constant(0, element_type(operand))],
dimensions=[i for i in range(rank(operand)) if i != feature_index],
body=lambda x, y: add(x, y))
return sum
def compute_mean(operand, feature_index):
sum = compute_sum(operand, feature_index)
divisor = constant(size(operand) / dim(operand, feature_index),
element_type(operand))
divisor_bcast = broadcast_in_dim(divisor, [], type(sum))
return divide(sum, divisor_bcast)
def batch_norm_grad(operand, scale, mean, variance, grad_output, epsilon, feature_index):
# Broadcast inputs to type(operand)
scale_bcast = broadcast_in_dim(scale, [feature_index], type(operand))
mean_bcast = broadcast_in_dim(mean, [feature_index], type(operand))
variance_bcast = broadcast_in_dim(variance, [feature_index], type(operand))
epsilon_bcast = broadcast_in_dim(constant(epsilon, element_type(operand)), [],
type(operand))
# Perform normalization using the provided `mean` and `variance`
# Intermediate values will be useful for computing gradients
centered_operand = subtract(operand, mean_bcast)
stddev = sqrt(add(variance_bcast, epsilon_bcast))
normalized_operand = divide(centered_operand, stddev)
# Use the implementation from batchnorm_expander.cc in XLA
# Temporary variables have exactly the same names as in the C++ code
elements_per_feature = broadcast_in_dim(
constant(divide(size(operand), dim(operand, feature_index)),
element_type(grad_output)),
[], type(operand))
i1 = multiply(grad_output, elements_per_feature)
i2 = broadcast_in_dim(
compute_sum(grad_output, feature_index), [feature_index], type(operand))
i3 = broadcast_in_dim(
compute_sum(multiply(grad_output, centered_operand), feature_index),
[feature_index], type(operand))
i4 = multiply(i3, centered_operand)
i5 = divide(i4, add(variance_bcast, epsilon_bcast))
i6 = subtract(subtract(i1, i2), i5)
grad_operand =
multiply(divide(divide(scale_bcast, stddev), elements_per_feature), i6)
grad_scale =
compute_sum(multiply(grad_output, normalized_operand), feature_index)
grad_offset = compute_sum(grad_output, feature_index)
return grad_operand, grad_scale, grad_offset
量子化された型の場合、dequantize_batch_norm_grad_or_training_quantize(lambda operand, scale, mean,
variance, grad_output: batch_norm_grad(operand, scale, mean, variance,
grad_output, epsilon, feature_index), operand, scale, mean, variance,
grad_output, type(grad_operand), type(grad_scale), type(feature_index)) を実行します。
入力
| ラベル | 名前 | タイプ | 制約 |
|---|---|---|---|
| (I1) | operand |
浮動小数点型のテンソルまたはテンソルごとの量子化テンソル | (C1-C3)、(C5) |
| (I2) | scale |
浮動小数点型またはテンソルごとの量子化型の 1 次元テンソル | (C2)、(C4)、(C5) |
| (I3) | mean |
浮動小数点型またはテンソルごとの量子化型の 1 次元テンソル | (C2)、(C4) |
| (I4) | variance |
浮動小数点型またはテンソルごとの量子化型の 1 次元テンソル | (C2)、(C4) |
| (I5) | grad_output |
浮動小数点型のテンソルまたはテンソルごとの量子化テンソル | (C2)、(C3) |
| (I6) | epsilon |
f32 型の定数 |
|
| (I7) | feature_index |
si64 型の定数 |
(C1)、(C5) |
出力
| 名前 | タイプ | 制約 |
|---|---|---|
grad_operand |
浮動小数点型のテンソルまたはテンソルごとの量子化テンソル | (C2)、(C3) |
grad_scale |
浮動小数点型またはテンソルごとの量子化型の 1 次元テンソル | (C2)、(C4) |
grad_offset |
浮動小数点型またはテンソルごとの量子化型の 1 次元テンソル | (C2)、(C4) |
制約
- (C1)
0 <= feature_index < rank(operand)。 - (C2)
operand、scale、mean、variance、grad_output、grad_operand、grad_scale、grad_offsetは同じbaseline_element_typeを持ちます。 - (C3)
operand、grad_output、grad_operandの形状が同じである。 - (C4)
scale、mean、variance、grad_scale、grad_offsetの形状は同じです。 - (C5)
size(scale) = dim(operand, feature_index)。
例
// %operand: [
// [[1.0, 2.0], [3.0, 4.0]],
// [[3.0, 4.0], [1.0, 2.0]]
// ]
// %scale: [1.0, 1.0]
// %mean: [2.0, 3.0]
// %variance: [1.0, 1.0]
// %grad_output: [
// [[0.1, 0.1], [0.1, 0.1]],
// [[0.1, 0.1], [0.1, 0.1]]
// ]
%grad_operand, %grad_scale, %grad_offset =
"stablehlo.batch_norm_grad"(%operand, %scale, %mean, %variance, %grad_output) {
epsilon = 0.0 : f32,
feature_index = 2 : i64
}< : (tenso>r2x2x2xf<64, t>ensor2xf<64, t>ensor2xf<64, t>ensor2xf64,
< tenso>r2x>2x2xf64)< - (tenso>r2x2x2xf<64, t>ensor2xf<64, t>ensor2xf64)
// %grad_operand: [
// [[0.0, 0.0], [0.0, 0.0]],
// [[0.0, 0.0], [0.0, 0.0]]
// ]
// %grad_scale: [0.0, 0.0]
// %grad_offset: [0.4, 0.4]
batch_norm_inference
セマンティクス
feature_index ディメンションを除くすべてのディメンションで operand テンソルを正規化し、result テンソルを生成します。より正式には、このオペレーションは、次のように Python 構文を使用して既存の StableHLO オペレーションへの分解として表現できます。
def batch_norm_inference(operand, scale, offset, mean, variance, epsilon, feature_index):
# Broadcast inputs to shape(operand)
scale_bcast = broadcast_in_dim(scale, [feature_index], type(operand))
offset_bcast = broadcast_in_dim(offset, [feature_index], type(operand))
mean_bcast = broadcast_in_dim(mean, [feature_index], type(operand))
variance_bcast = broadcast_in_dim(variance, [feature_index], type(operand))
epsilon_bcast = broadcast_in_dim(constant(epsilon, element_type(operand)), [],
type(operand))
# Perform normalization using the provided `mean` and `variance` instead of
# computing them like `batch_norm_training` does.
centered_operand = subtract(operand, mean_bcast)
stddev = sqrt(add(variance_bcast, epsilon_bcast))
normalized_operand = divide(centered_operand, stddev)
return add(multiply(scale_bcast, normalized_operand), offset_bcast)
量子化された型の場合、dequantize_op_quantize(lambda operand, scale, offset, mean, variance:
batch_norm_inference(operand, scale, offset, mean, variance, epsilon,
feature_index), operand, scale, offset, mean, variance, type(result)) を実行します。
入力
| ラベル | 名前 | タイプ | 制約 |
|---|---|---|---|
| (I1) | operand |
浮動小数点型のテンソルまたはテンソルごとの量子化テンソル | (C1 ~ C7) |
| (I2) | scale |
浮動小数点型またはテンソルごとの量子化型の 1 次元テンソル | (C2)、(C3) |
| (I3) | offset |
浮動小数点型またはテンソルごとの量子化型の 1 次元テンソル | (C2)、(C4) |
| (I4) | mean |
浮動小数点型またはテンソルごとの量子化型の 1 次元テンソル | (C5) |
| (I5) | variance |
浮動小数点型またはテンソルごとの量子化型の 1 次元テンソル | (C2)、(C6) |
| (I6) | epsilon |
f32 型の定数 |
|
| (I7) | feature_index |
si64 型の定数 |
(C1)、(C3 ~ C6) |
出力
| 名前 | タイプ | 制約 |
|---|---|---|
result |
浮動小数点型のテンソルまたはテンソルごとの量子化テンソル | (C2)、(C7) |
制約
- (C1)
0 <= feature_index < rank(operand)。 - (C2)
operand、scale、offset、mean、variance、resultは同じbaseline_element_typeを持ちます。 - (C3)
size(scale) = dim(operand, feature_index)。 - (C4)
size(offset) = dim(operand, feature_index)。 - (C5)
size(mean) = dim(operand, feature_index)。 - (C6)
size(variance) = dim(operand, feature_index)。 - (C7)
baseline_type(operand) = baseline_type(result)。
例
// %operand: [
// [[1.0, 2.0], [3.0, 4.0]],
// [[3.0, 4.0], [1.0, 2.0]]
// ]
// %scale: [1.0, 1.0]
// %offset: [1.0, 1.0]
// %mean: [2.0, 3.0]
// %variance: [1.0, 1.0]
%result = "stablehlo.batch_norm_inference"(%operand, %scale, %offset, %mean, %variance) {
epsilon = 0.0 : f32,
feature_index = 2 : i64
}< : (tenso>r2x2x2xf<64, t>ensor2xf<64, t>ensor2xf<64, t>ensor2xf<64, t>ens>or2xf64<) - tenso>r2x2x2xf64
// %result: [
// [[0.0, 0.0], [2.0, 2.0]],
// [[2.0, 2.0], [0.0, 0.0]]
// ]
batch_norm_training
セマンティクス
feature_index ディメンションを除くすべてのディメンションの平均と分散を計算し、operand テンソルを正規化して output、batch_mean、batch_var テンソルを生成します。より正式には、このオペレーションは、次のように Python 構文を使用して既存の StableHLO オペレーションへの分解として表現できます。
def compute_mean(operand, feature_index):
(sum,) = reduce(
inputs=[operand],
init_values=[constant(0, element_type(operand))],
dimensions=[i for i in range(rank(operand)) if i != feature_index],
body=lambda x, y: add(x, y))
divisor = constant(size(operand) / dim(operand, feature_index),
element_type(operand))
divisor_bcast = broadcast_in_dim(divisor, [], type(sum))
return divide(sum, divisor_bcast)
def compute_variance(operand, feature_index):
mean = compute_mean(operand, feature_index)
mean_bcast = broadcast_in_dim(mean, [feature_index], type(operand))
centered_operand = subtract(operand, mean_bcast)
return compute_mean(mul(centered_operand, centered_operand), feature_index)
def batch_norm_training(operand, scale, offset, epsilon, feature_index):
mean = compute_mean(operand, feature_index)
variance = compute_variance(operand, feature_index)
return batch_norm_inference(operand, scale, offset, mean, variance, epsilon,
feature_index),
mean, variance
量子化された型の場合、dequantize_batch_norm_grad_or_training_quantize(lambda operand, scale, offset:
batch_norm_training(operand, scale, offset, epsilon, feature_index), operand,
scale, offset, type(output), type(batch_mean), type(batch_var)) を実行します。
入力
| ラベル | 名前 | タイプ | 制約 |
|---|---|---|---|
| (I1) | operand |
浮動小数点型のテンソルまたはテンソルごとの量子化テンソル | (C1) |
| (I2) | scale |
浮動小数点またはテンソルごとの量子化の 1 次元テンソル | (C2)、(C3) |
| (I3) | offset |
浮動小数点またはテンソルごとの量子化の 1 次元テンソル | (C2)、(C4) |
| (I4) | epsilon |
f32 型の定数 |
(C1)、(C3 ~ C6) |
| (I5) | feature_index |
si64 型の定数 |
(C1)、(C3 ~ C6) |
出力
| 名前 | タイプ | 制約 |
|---|---|---|
output |
浮動小数点型のテンソルまたはテンソルごとの量子化テンソル | (C7) |
batch_mean |
浮動小数点またはテンソルごとの量子化の 1 次元テンソル | (C2)、(C5) |
batch_var |
浮動小数点またはテンソルごとの量子化の 1 次元テンソル | (C2)、(C6) |
制約
- (C1)
0 <= feature_index < rank(operand)。 - (C2)
operand、scale、offset、batch_mean、batch_var、outputは同じbaseline_element_typeを持ちます。 - (C3)
size(scale) = dim(operand, feature_index)。 - (C4)
size(offset) = dim(operand, feature_index)。 - (C5)
size(batch_mean) = dim(operand, feature_index)。 - (C6)
size(batch_var) = dim(operand, feature_index)。 - (C7)
baseline_type(output) = baseline_type(operand)。
例
// %operand: [
// [[1.0, 2.0], [3.0, 4.0]],
// [[3.0, 4.0], [1.0, 2.0]]
// ]
// %scale: [1.0, 1.0]
// %offset: [1.0, 1.0]
%output, %batch_mean, %batch_var = "stablehlo.batch_norm_training"(%operand, %scale, %offset) {
epsilon = 0.0 : f32,
feature_index = 2 : i64
}< : (tenso>r2x2x2xf<64, t>ensor2xf<64, t>ens>or2xf64) -
< (tenso>r2x2x2xf<64, t>ensor2xf<64, t>ensor2xf64)
// %output: [
// [[0.0, 0.0], [2.0, 2.0]],
// [[2.0, 2.0], [0.0, 0.0]]
// ]
// %batch_mean: [2.0, 3.0]
// %batch_var: [1.0, 1.0]
bitcast_convert
セマンティクス
operand テンソルに対してビットキャスト オペレーションを実行し、operand テンソル全体のビットが result テンソルの型を使用して再解釈される result テンソルを生成します。
より正式には、E = element_type(operand)、E' = element_type(result)、R = rank(operand) がある場合:
num_bits(E') < num_bits(E)の場合は、bits(result[i0, ..., iR-1, :]) = bits(operand[i0, ..., iR-1])。num_bits(E') > num_bits(E)の場合は、bits(result[i0, ..., iR-2]) = bits(operand[i0, ..., iR-2, :])。num_bits(E') = num_bits(E)の場合は、bits(result[i0, ..., iR-1]) = bits(operand[i0, ..., iR-1])。
bits は、指定された値のメモリ内表現を返します。テンソルの正確な表現は実装定義であり、要素型の正確な表現も実装定義であるため、その動作は実装定義です。
入力
| ラベル | 名前 | タイプ | 制約 |
|---|---|---|---|
| (I1) | operand |
テンソルまたは量子化テンソル | (C1-C2) |
出力
| 名前 | タイプ | 制約 |
|---|---|---|
result |
テンソルまたは量子化テンソル | (C1-C2) |
制約
- (C1)
E = is_quantized(operand) ? storage_type(operand) : element_type(operand)、E' = is_quantized(result) ? storage_type(result) : element_type(result)、R = rank(operand)がある場合:num_bits(E') = num_bits(E)の場合は、shape(result) = shape(operand)。num_bits(E') < num_bits(E)の場合:rank(result) = R + 1。dim(result, i) = dim(operand, i)(すべての0 <= i < R)。dim(result, R) * num_bits(E') = num_bits(E)。num_bits(E') > num_bits(E)の場合:rank(result) = R - 1。dim(result, i) = dim(operand, i)(すべての0 <= i < R)。dim(operand, R - 1) * num_bits(E) = num_bits(E')。
- (C2)
is_complex(operand) or is_complex(result)の場合はis_complex(operand) and is_complex(result)。
例
// %operand: 0x0123456789ABCDEF
%result = "stablehlo.bitcast_convert"(%operand)< : >(te>nsorf64<) - t>ensor4xf16
// %result: [0xCDEF, 0x89AB, 0x4567, 0x0123] // little-endian representation
broadcast_in_dim
セマンティクス
operand テンソルのデータを複製して、入力テンソルのディメンションとランクを拡張し、result テンソルを生成します。より正式には、result[result_index] = operand[operand_index] です。ここで、axes(operand) のすべての d について、次のようになります。
dim(operand, d) = 1の場合、operand_index[d] = 0。- それ以外の場合は
operand_index[d] = result_index[broadcast_dimensions[d]]。
入力
| ラベル | 名前 | タイプ | 制約 |
|---|---|---|---|
| (I1) | operand |
テンソルまたは量子化テンソル | (C1-C2)、(C5-C6) |
| (I2) | broadcast_dimensions |
si64 型の 1 次元テンソル定数 |
(C2-C6) |
出力
| 名前 | タイプ | 制約 |
|---|---|---|
result |
テンソルまたは量子化テンソル | (C1)、(C3)、(C5-C6) |
制約
- (C1)
element_type(result)は次のように計算されます。!is_per_axis_quantized(operand)の場合、element_type(operand)。element_type(operand)。ただし、quantization_dimension(operand)、scales(operand)、zero_points(operand)はそれぞれquantization_dimension(result)、scales(result)、zero_points(result)と異なる場合があります。
- (C2)
size(broadcast_dimensions) = rank(operand)。 - (C3)
0 <= broadcast_dimensions < rank(result)。 - (C4)
is_unique(broadcast_dimensions)。 - (C5)
axes(operand)内のすべてのdについて:dim(operand, d) = 1またはdim(operand, d) = dim(result, broadcast_dimensions[d])。
- (C6)
is_per_axis_quantized(result)の場合:quantization_dimension(result) = broadcast_dimensions[quantization_dimension(operand)]。dim(operand, quantization_dimension(operand)) = 1の場合はscales(result)[i] = scales(operand)[0] and zero_points(result)[i] = zero_points(operand)[0] for i in range(dim(result, quantization_dimension(result)))。
例
// %operand: [
// [1, 2, 3]
// ]
%result = "stablehlo.broadcast_in_dim"(%operand) {
broadcast_dimensio<ns = arra>yi64: 2, 1
}< : (ten>sor>1x3xi32<) - tenso>r2x3x2xi32
// %result: [
// [
// [1, 1],
// [2, 2],
// [3, 3]
// ],
// [
// [1, 1],
// [2, 2],
// [3, 3]
// ]
// ]
ケース
セマンティクス
index の値に応じて、branches から 1 つの関数を実行した結果を出力します。より正式には、result = selected_branch() です。ここで、
0 <= index < size(branches)の場合、selected_branch = branches[index]。- それ以外の場合は
selected_branch = branches[-1]。
入力
| ラベル | 名前 | タイプ | 制約 |
|---|---|---|---|
| (I1) | index |
型 si32 の 0 次元テンソル |
|
| (I2) | branches |
可変長引数の関数 | (C1 ~ C4) |
出力
| 名前 | タイプ | 制約 |
|---|---|---|
results |
可変数のテンソル、量子化テンソル、トークン | (C4) |
制約
- (C1)
0 < size(branches)。 - (C2)
input_types(branches...) = []。 - (C3)
same(output_types(branches...))。 - (C4)
type(results...) = output_types(branches[0])。
例
// %index: -1
// %result_branch0: [0, 0]
// %result_branch1: [1, 1]
%result0, %result1 = "stablehlo.case"(%index) ({
"stablehlo.return"(%result_branch0, %resul<t_bra>nch0) : <(tens>or2>xi64, tensor2xi64) - ()
}, {
"stablehlo.return"(%result_branc<h1, %>result_b<ranch>1) >: (tensor2xi64, <ten>sor>2xi64) -< ()
}>) : (ten<sori3>2) - (tensor2xi64, tensor2xi64)
// %result0: [1, 1]
// %result1: [1, 1]
cbrt
セマンティクス
operand テンソルに対して要素ごとの立方根演算を実行し、result テンソルを生成します。要素のタイプに応じて、次の処理を行います。
- 浮動小数点の場合: IEEE-754 の
rootn(x, 3)。 - 複素数の場合: 複素数の立方根。
- 量子化された型の場合:
dequantize_op_quantize(cbrt, operand, type(result))
入力
| ラベル | 名前 | タイプ | 制約 |
|---|---|---|---|
| (I1) | operand |
浮動小数点型または複素数型のテンソル、またはテンソルごとの量子化テンソル | (C1) |
出力
| 名前 | タイプ | 制約 |
|---|---|---|
result |
浮動小数点型または複素数型のテンソル、またはテンソルごとの量子化テンソル | (C1) |
制約
- (C1)
baseline_type(operand) = baseline_type(result)。
例
// %operand: [0.0, 1.0, 8.0, 27.0]
%result = "stablehlo.cbrt"(%operand)< : (t>ens>or4xf64<) - t>ensor4xf64
// %result: [0.0, 1.0, 2.0, 3.0]
ceil
セマンティクス
operand テンソルの要素ごとの天井関数を実行し、result テンソルを生成します。IEEE-754 仕様の roundToIntegralTowardPositive オペレーションを実装します。量子化された型の場合、dequantize_op_quantize(ceil, operand, type(result)) を実行します。
入力
| ラベル | 名前 | タイプ | 制約 |
|---|---|---|---|
| (I1) | operand |
浮動小数点型のテンソルまたはテンソルごとの量子化テンソル | (C1) |
出力
| 名前 | タイプ | 制約 |
|---|---|---|
result |
浮動小数点型のテンソルまたはテンソルごとの量子化テンソル | (C1) |
制約
- (C1)
baseline_type(operand) = baseline_type(result)。
例
// %operand: [-0.8166, -0.2530, 0.2530, 0.8166, 2.0]
%result = "stablehlo.ceil"(%operand)< : (t>ens>or5xf32<) - t>ensor5xf32
// %result: [-0.0, -0.0, 1.0, 1.0, 2.0]
cholesky
セマンティクス
行列のバッチのコレスキー分解を計算します。
より正式には、index_space(result) のすべての i について、result[i0, ..., iR-3, :, :] は a[i0, ..., iR-3, :, :] のコレスキー分解であり、下三角行列(lower が true の場合)または上三角行列(lower が false の場合)のいずれかの形式です。反対側の三角形(それぞれ厳密な上三角または厳密な下三角)の出力値は実装定義です。
入力行列がエルミート正定値行列ではない i が存在する場合、動作は未定義です。
量子化された型の場合、dequantize_op_quantize(lambda operand: cholesky(operand, lower), a, type(result)) を実行します。
入力
| ラベル | 名前 | タイプ | 制約 |
|---|---|---|---|
| (I1) | a |
浮動小数点型または複素数型のテンソル、またはテンソルごとの量子化テンソル | (C1 ~ C3) |
| (I2) | lower |
i1 型の定数 |
出力
| 名前 | タイプ | 制約 |
|---|---|---|
result |
浮動小数点型または複素数型のテンソル、またはテンソルごとの量子化テンソル | (C1) |
制約
- (C1)
baseline_type(a) = baseline_type(result)。 - (C2)
2 <= rank(a)。 - (C3)
dim(a, -2) = dim(a, -1)。
例
// %a: [
// [1.0, 2.0, 3.0],
// [2.0, 20.0, 26.0],
// [3.0, 26.0, 70.0]
// ]
%result = "stablehlo.cholesky"(%a) {
lower = true
}< : (ten>sor>3x3xf32<) - ten>sor3x3xf64
// %result: [
// [1.0, 0.0, 0.0],
// [2.0, 4.0, 0.0],
// [3.0, 5.0, 6.0]
// ]
クランプ
セマンティクス
operand テンソルのすべての要素を最小値と最大値の間にクランプし、result テンソルを生成します。より正式には、result[result_index] =
minimum(maximum(operand[result_index], min_element), max_element)(min_element = rank(min) = 0 ? min[] : min[result_index]、max_element = rank(max) = 0 ? max[] : max[result_index])。量子化された型の場合、dequantize_op_quantize(clamp, min, operand, max, type(result)) を実行します。
複素数に順序を課すには驚くべきセマンティクスが伴うため、今後、このオペレーションの複素数のサポートを削除する予定です(#560)。
入力
| ラベル | 名前 | タイプ | 制約 |
|---|---|---|---|
| (I1) | min |
テンソルまたはテンソルごとの量子化テンソル | (C1)、(C3) |
| (I2) | operand |
テンソルまたはテンソルごとの量子化テンソル | (C1 ~ C4) |
| (I3) | max |
テンソルまたはテンソルごとの量子化テンソル | (C2)、(C3) |
出力
| 名前 | タイプ | 制約 |
|---|---|---|
result |
テンソルまたはテンソルごとの量子化テンソル | (C4) |
制約
- (C1)
rank(min) = 0 or shape(min) = shape(operand)。 - (C2)
rank(max) = 0 or shape(max) = shape(operand)。 - (C3)
baseline_element_type(min) = baseline_element_type(operand) = baseline_element_type(max)。 - (C4)
baseline_type(operand) = baseline_type(result)。
例
// %min: [5, 10, 15]
// %operand: [3, 13, 23]
// %max: [10, 15, 20]
%result = "stablehlo.clamp"(%min, %operand, %max)< : (t>ensor3xi<32, t>ensor3xi<32, t>ens>or3xi32<) - t>ensor3xi32
// %result: [5, 13, 20]
collective_broadcast
セマンティクス
StableHLO プロセス グリッドの各プロセス グループ内で、ソース プロセスからターゲット プロセスに operand テンソルの値を送信し、result テンソルを生成します。
このオペレーションは、StableHLO プロセス グリッドを process_groups に分割します。これは次のように定義されます。
channel_id <= 0の場合、cross_replica(replica_groups)。channel_id > 0の場合、cross_partition(replica_groups)。
その後、result@process は次のように計算されます。
- プロセスが
process_groups[i]にあるようなiが存在する場合はoperand@process_groups[i, 0]。 - それ以外の場合は
broadcast_in_dim(constant(is_quantized(result) ? quantize(0, element_type(result)) : 0, element_type(result)), [], type(result))。
入力
| ラベル | 名前 | タイプ | 制約 |
|---|---|---|---|
| (I1) | operand |
テンソルまたはテンソルごとの量子化テンソル | (C3) |
| (I2) | replica_groups |
si64 型の 1 次元テンソル定数の可変長引数 |
(C1)、(C2) |
| (I3) | channel_id |
si64 型の定数 |
出力
| 名前 | タイプ | 制約 |
|---|---|---|
result |
テンソルまたはテンソルごとの量子化テンソル | (C3) |
制約
- (C1)
is_unique(replica_groups)。 - (C2)
0 <= replica_groups < N。ここで、Nは次のように定義されます。cross_replicaを使用する場合はnum_replicas。cross_partitionを使用する場合はnum_partitions。
- (C3)
type(result) = type(operand)。
例
// num_replicas: 4
// num_partitions: 1
// %operand@(0, 0): [[1, 2]]
// %operand@(1, 0): [[3, 4]]
// %operand@(2, 0): [[5, 6]]
// %operand@(3, 0): [[7, 8]]
%result = "stablehlo.collective_broadcast"(%operand) {
replica_grou<ps = den>se[[2, 1]<] : ten>sor1x2xi64,
channel_handle = #stablehlo.chan<nel_handlehandle = 0>, type = 0
} : (ten>sor>1x2xi64<) - ten>sor1x2xi64
// %result@(0, 0): [[0, 0]]
// %result@(1, 0): [[5, 6]]
// %result@(2, 0): [[5, 6]]
// %result@(3, 0): [[0, 0]]
collective_permute
セマンティクス
StableHLO プロセス グリッド内の各プロセス グループ内で、ソース プロセスからターゲット プロセスに operand テンソルの値を送信し、result テンソルを生成します。
このオペレーションは、StableHLO プロセス グリッドを process_groups に分割します。これは次のように定義されます。
channel_id <= 0の場合、cross_replica(source_target_pairs)。channel_id > 0の場合、cross_partition(source_target_pairs)。
その後、result@process は次のように計算されます。
process_groups[i, 1] = processとなるiが存在する場合はoperand@process_groups[i, 0]。- それ以外の場合は
broadcast_in_dim(constant(is_quantized(result) ? quantize(0, element_type(result)) : 0, element_type(result)), [], type(result))。
入力
| ラベル | 名前 | タイプ | 制約 |
|---|---|---|---|
| (I1) | operand |
テンソルまたはテンソルごとの量子化テンソル | (C5) |
| (I2) | source_target_pairs |
型 si64 の 2 次元テンソル定数 |
(C1 ~ C4) |
| (I3) | channel_id |
si64 型の定数 |
出力
| 名前 | タイプ | 制約 |
|---|---|---|
result |
テンソルまたはテンソルごとの量子化テンソル | (C1) |
制約
- (C1)
dim(source_target_pairs, 1) = 2。 - (C2)
is_unique(source_target_pairs[:, 0])。 - (C3)
is_unique(source_target_pairs[:, 1])。 - (C4)
0 <= source_target_pairs < N。ここで、Nは次のように定義されます。cross_replicaを使用する場合はnum_replicas。cross_partitionを使用する場合はnum_partitions。
- (C5)
type(result) = type(operand)。
例
// num_replicas: 3
// num_partitions: 1
// %operand@(0, 0): [[1, 2], [3, 4]]
// %operand@(1, 0): [[5, 6], [7, 8]]
// %operand@(2, 0): [[9, 10], [11, 12]]
%result = "stablehlo.collective_permute"(%operand) {
source_target_pai<rs = dense[[0, 1>], [1, 2]<] : ten>sor2x2xi64,
channel_handle = #stablehlo.chan<nel_handlehandle = 0>, type = 0
}< : (ten>sor>2x2xi64<) - ten>sor2x2xi64
//
// %result@(0, 0): [[0, 0], [0, 0]]
// %result@(1, 0): [[1, 2], [3, 4]]
// %result@(2, 0): [[5, 6], [7, 8]]
比較
セマンティクス
comparison_direction と compare_type に従って lhs テンソルと rhs テンソルの要素ごとの比較を行い、result テンソルを生成します。
comparison_direction と compare_type の値は次の意味を持ちます。
ブール値と整数型の要素の場合:
EQ:lhs = rhsNE:lhs != rhsGE:lhs >= rhsGT:lhs > rhsLE:lhs <= rhsLT:lhs < rhs
compare_type = FLOAT を含む浮動小数点要素型の場合、op は次の IEEE-754 オペレーションを実装します。
EQ:compareQuietEqualNE:compareQuietNotEqualGE:compareQuietGreaterEqualGT:compareQuietGreaterLE:compareQuietLessEqualLT:compareQuietLess
compare_type = TOTALORDER を含む浮動小数点要素型の場合、op は IEEE-754 の totalOrder オペレーションと compareQuietEqual オペレーションの組み合わせを使用します。
複雑な要素タイプの場合、指定された comparison_direction と compare_type を使用して (real, imag) ペアの辞書式比較が行われます。複素数に順序を課すには、驚くべきセマンティクスが伴います。そのため、今後、comparison_direction が GE、GT、LE、または LT の場合、複素数のサポートを削除する予定です(#560)。
量子化された型の場合、dequantize_compare(lhs, rhs,
comparison_direction) を実行します。
入力
| ラベル | 名前 | タイプ | 制約 |
|---|---|---|---|
| (I1) | lhs |
テンソルまたはテンソルごとの量子化テンソル | (C1 ~ C3) |
| (I2) | rhs |
テンソルまたはテンソルごとの量子化テンソル | (C1-C2) |
| (I3) | comparison_direction |
EQ、NE、GE、GT、LE、LT の列挙型 |
|
| (I4) | compare_type |
FLOAT、TOTALORDER、SIGNED、UNSIGNED の列挙型 |
(C3) |
出力
| 名前 | タイプ | 制約 |
|---|---|---|
result |
ブール型のテンソル | (C2) |
制約
- (C1)
baseline_element_type(lhs) = baseline_element_type(rhs)。 - (C2)
shape(lhs) = shape(rhs) = shape(result)。 - (C3)
compare_typeは次のように定義されます。is_signed_integer(element_type(lhs))の場合、SIGNED。is_unsigned_integer(element_type(lhs)) or is_boolean(element_type(lhs))の場合、UNSIGNED。is_float(element_type(lhs))の場合はFLOATまたはTOTALORDER。is_complex(element_type(lhs))の場合、FLOAT。
例
// %lhs: [1.0, 3.0]
// %rhs: [1.1, 2.9]
%result = "stablehlo.compare"(%lhs, %rhs) {
comparison_direction = <#stablehlocomparison_di>rection LT,
compare_type = <#stablehlocomparison_>type FLOAT
}< : (t>ensor2xf<32, t>ens>or2xf32<) - >tensor2xi1
// %result: [true, false]
複雑
セマンティクス
実数値と虚数値のペア lhs と rhs から複素数値への要素ごとの変換を行い、result テンソルを生成します。
入力
| ラベル | 名前 | タイプ | 制約 |
|---|---|---|---|
| (I1) | lhs |
f32 型または f64 型のテンソル |
(C1 ~ C3) |
| (I2) | rhs |
f32 型または f64 型のテンソル |
(C1) |
出力
| 名前 | タイプ | 制約 |
|---|---|---|
result |
複合型のテンソル | (C2)、(C3) |
制約
- (C1)
type(lhs) = type(rhs)。 - (C2)
shape(result) = shape(lhs)。 - (C3)
element_type(result)はcomplex<E>型で、E = element_type(lhs)です。
例
// %lhs: [1.0, 3.0]
// %rhs: [2.0, 4.0]
%result = "stablehlo.complex"(%lhs, %rhs)< : (t>ensor2xf<64, t>ens>or2xf64<) - tenso<r2x>>complexf64
// %result: [(1.0, 2.0), (3.0, 4.0)]
複合
セマンティクス
他の StableHLO オペレーションで構成されたオペレーションをカプセル化し、inputs と composite_attributes を受け取って results を生成します。op のセマンティクスは decomposition 属性によって実装されます。composite オペレーションは、プログラムのセマンティクスを変更することなく、その分解に置き換えることができます。分解をインライン化しても同じ op セマンティクスが得られない場合は、custom_call を使用することをおすすめします。
version フィールド(デフォルトは 0)は、コンポジットのセマンティクスが変更されたタイミングを示すために使用されます。
入力
| ラベル | 名前 | タイプ |
|---|---|---|
| (I1) | inputs |
可変個数の値 |
| (I2) | name |
string 型の定数 |
| (I3) | composite_attributes |
属性ディクショナリ |
| (I4) | decomposition |
string 型の定数 |
| (I5) | version |
si32 型の定数 |
出力
| 名前 | タイプ |
|---|---|
results |
可変個数の値 |
制約
- (C1)
is_namespaced_op_name(name) - (C2)
is_defined_in_parent_scope(decomposition) - (C3)
types(inputs...) == input_types(decomposition) - (C4)
types(results...) == output_types(decomposition)
例
%results = "stablehlo.composite"(%input0, %input1) {
name = "my_namespace.my_op",
composite_attributes = {
my_attribute = "my_value"
},
decomposition = @my_op,
< ve>rsion = <1 :> i3>2
} : (<ten>sorf32, tensorf32) - tensorf32
concatenate
セマンティクス
指定された引数と同じ順序で dimension ディメンションに沿って inputs を連結し、result テンソルを生成します。より正式には、result[i0, ..., id, ..., iR-1] = inputs[k][i0, ..., kd, ..., iR-1]。ここで、
id = d0 + ... + dk-1 + kd。dはdimensionと等しく、d0、... はinputsのd番目のディメンション サイズです。
入力
| ラベル | 名前 | タイプ | 制約 |
|---|---|---|---|
| (I1) | inputs |
可変数のテンソルまたはテンソルごとの量子化テンソル | (C1-C6) |
| (I2) | dimension |
si64 型の定数 |
(C2)、(C4)、(C6) |
出力
| 名前 | タイプ | 制約 |
|---|---|---|
result |
テンソルまたはテンソルごとの量子化テンソル | (C5-C6) |
制約
- (C1)
same(element_type(inputs...))。 - (C2)
same(shape(inputs...))(dim(inputs..., dimension)以外)。 - (C3)
0 < size(inputs)。 - (C4)
0 <= dimension < rank(inputs[0])。 - (C5)
element_type(result) = element_type(inputs[0])。 - (C6)
shape(result) = shape(inputs[0])ただし、次の場合は除く。dim(result, dimension) = dim(inputs[0], dimension) + ...。
例
// %input0: [[1, 2], [3, 4], [5, 6]]
// %input1: [[7, 8]]
%result = "stablehlo.concatenate"(%input0, %input1) {
dimension = 0 : i64
}< : (ten>sor3x2xi<64, ten>sor>1x2xi64<) - ten>sor4x2xi64
// %result: [[1, 2], [3, 4], [5, 6], [7, 8]]
定数
セマンティクス
定数 value から output テンソルを生成します。
入力
| ラベル | 名前 | タイプ | 制約 |
|---|---|---|---|
| (I1) | value |
定数 | (C1) |
出力
| 名前 | タイプ | 制約 |
|---|---|---|
output |
テンソルまたは量子化テンソル | (C1) |
制約
- (C1)
type(value) = type(output)。
例
%output = "stablehlo.constant"() {
val<ue = dense[[0.0, 1.0], [>2.0, 3.0]<] : ten>sor2x2xf3>2
} : (<) - ten>sor2x2xf32
// %output: [[0.0, 1.0], [2.0, 3.0]]
コンバージョン
セマンティクス
operand テンソルで要素ごとに要素型を変換し、result テンソルを生成します。
boolean-to-any-supported-type 変換の場合、値 false は 0 に変換され、値 true は 1 に変換されます。any-supported-type-to-boolean 変換の場合、ゼロ値は false に変換され、ゼロ以外の値は true に変換されます。複合型の処理方法については、以下をご覧ください。
整数から整数、整数から浮動小数点数、浮動小数点数から浮動小数点数の変換では、ソース値を宛先型で正確に表すことができる場合、結果値はその正確な表現になります。それ以外の場合の動作は未定です(#180)。
floating-point-to-integerへの変換では、小数部分が切り捨てられます。切り捨てられた値を宛先型で表現できない場合、動作は未定です(#180)。
複素数から複素数への変換では、実部と虚部の変換に 浮動小数点から浮動小数点への変換と同じ動作が適用されます。
complex-to-any-other-type 変換と any-other-type-to-complex 変換では、それぞれ、変換元の虚数部は無視されるか、変換先の虚数部がゼロに設定されます。実部の変換は、浮動小数点変換に従います。
原則として、このオペレーションは逆量子化(量子化テンソルから通常のテンソルへの変換)、量子化(通常のテンソルから量子化テンソルへの変換)、再量子化(量子化テンソル間の変換)を表すことができますが、現時点では、最初のユースケースには uniform_dequantize、2 番目と 3 番目のユースケースには uniform_quantize という専用のオペレーションがあります。今後、この 2 つのオペレーションは convert に統合される可能性があります(#1576)。
入力
| ラベル | 名前 | タイプ | 制約 |
|---|---|---|---|
| (I1) | operand |
テンソル | (C1) |
出力
| 名前 | タイプ | 制約 |
|---|---|---|
result |
テンソル | (C1) |
制約
- (C1)
shape(operand) = shape(result)。
例
// %operand: [-1, 0, 1]
%result = "stablehlo.convert"(%operand)< : (t>ens>or3xi64<) - tenso<r3x>>complexf64
// %result: [(-1.0, 0.0), (0.0, 0.0), (1.0, 0.0)]
畳み込み
セマンティクス
lhs のウィンドウと rhs のスライス間のドット積を計算し、result を生成します。次の図は、具体的な例を使用して、result の要素が lhs と rhs からどのように計算されるかを示しています。
より正式には、lhs のウィンドウを表現できるように、lhs の観点から入力の再フレーミングを検討します。
lhs_window_dimensions = lhs_shape(dim(lhs, input_batch_dimension), dim(rhs, kernel_spatial_dimensions), dim(lhs, input_feature_dimension))。lhs_window_strides = lhs_shape(1, window_strides, 1)lhs_padding = lhs_shape([0, 0], padding, [0, 0])lhs_base_dilations = lhs_shape(1, lhs_dilation, 1)lhs_window_dilations = lhs_shape(1, rhs_dilation, 1)。
このリフレーミングでは、次のヘルパー関数を使用します。
lhs_shape(n, hw, c) = permute([n] + hw + [c], [input_batch_dimension] + input_spatial_dimensions + [input_feature_dimension])。result_shape(n1, hw, c1) = permute([n1] + hw + [c1], [output_batch_dimension] + output_spatial_dimensions + [output_feature_dimension])。j[d] = i[permutation[d]]の場合:permute([j0, j1, ..., jR-1], permutation) = [i0, i1, ..., iR-1]
feature_group_count = 1 かつ batch_group_count = 1 の場合、index_space(dim(result, output_spatial_dimensions...)) のすべての output_spatial_index について、result[result_shape(:, output_spatial_index, :)] = dot_product(ここで、
padding_value = constant(0, element_type(lhs))。padded_lhs = pad(lhs, padding_value, lhs_padding[:, 0], lhs_padding[:, 1], lhs_base_dilations - 1)lhs_window_start = lhs_shape(0, output_spatial_index, 0) * lhs_window_strideslhs_window = slice(padded_lhs, lhs_window_start, lhs_window_start + lhs_window_dimensions, lhs_window_dilations)。reversed_lhs_window = reverse(lhs_window, [input_spatial_dimensions[dim] for dim in range(size(window_reversal)) if window_reversal[dim] = true])。 この機能は使用されていないようなので、今後削除する予定です(#1181)。dot_product = dot_general(reversed_lhs_window, rhs, lhs_batching_dimensions=[], lhs_contracting_dimensions=input_spatial_dimensions + [input_feature_dimension], rhs_batching_dimensions=[], rhs_contracting_dimensions=kernel_spatial_dimensions + [kernel_input_feature_dimension])。
feature_group_count > 1 の場合:
lhses = split(lhs, feature_group_count, input_feature_dimension)。rhses = split(rhs, feature_group_count, kernel_output_feature_dimension)results... = convolution(lhses..., rhses..., ..., feature_group_count=1, ...)result = concatenate(results, output_feature_dimension)。
batch_group_count > 1 の場合:
lhses = split(lhs, batch_group_count, input_batch_dimension)。rhses = split(rhs, batch_group_count, kernel_output_feature_dimension)results... = convolution(lhses..., rhses..., ..., batch_group_count=1, ...)。result = concatenate(results, output_feature_dimension)。
量子化された型の場合、dequantize_op_quantize(
lambda lhs, rhs: convolution(lhs, rhs, window_strides, padding,
lhs_dilation, rhs_dilation, window_reversal, input_batch_dimension,
input_feature_dimension, input_spatial_dimensions,
kernel_input_feature_dimension, kernel_output_feature_dimension,
kernel_spatial_dimensions, output_batch_dimension,
output_feature_dimension, output_spatial_dimensions,
feature_group_count, batch_group_count, precision_config), lhs, rhs,
type(result)) を実行します。
ハイブリッド量子化型の場合、hybrid_dequantize_then_op(
lambda lhs, rhs: convolution(lhs, rhs, window_strides, padding,
lhs_dilation, rhs_dilation, window_reversal, input_batch_dimension,
input_feature_dimension, input_spatial_dimensions,
kernel_input_feature_dimension, kernel_output_feature_dimension,
kernel_spatial_dimensions, output_batch_dimension,
output_feature_dimension, output_spatial_dimensions,
feature_group_count, batch_group_count, precision_config), lhs, rhs) を実行します。
入力
| ラベル | 名前 | タイプ | 制約 |
|---|---|---|---|
| (I1) | lhs |
テンソルまたはテンソルごとの量子化テンソル | (C1)、(C10-C11)、(C14)、(C25)、(C27-C28)、(C31-C32)、(C34) |
| (I2) | rhs |
テンソルまたは量子化テンソル | (C1)、(C14-C16)、(C25)、(C27-C29)、(C31-C34) |
| (I3) | window_strides |
si64 型の 1 次元テンソル定数 |
(C2-C3)、(C25) |
| (I4) | padding |
型 si64 の 2 次元テンソル定数 |
(C4)、(C25) |
| (I5) | lhs_dilation |
si64 型の 1 次元テンソル定数 |
(C5-C6)、(C25) |
| (I6) | rhs_dilation |
si64 型の 1 次元テンソル定数 |
(C7-C8)、(C25) |
| (I7) | window_reversal |
i1 型の 1 次元テンソル定数 |
(C9) |
| (I8) | input_batch_dimension |
si64 型の定数 |
(C10)、(C13)、(C25) |
| (I9) | input_feature_dimension |
si64 型の定数 |
(C11)、(C13-C14) |
| (I10) | input_spatial_dimensions |
si64 型の 1 次元テンソル定数 |
(C12)、(C13)、(C25) |
| (I11) | kernel_input_feature_dimension |
si64 型の定数 |
(C14)、(C18) |
| (I12) | kernel_output_feature_dimension |
si64 型の定数 |
(C15-C16)、(C18)、(C25)、(C29) |
| (I13) | kernel_spatial_dimensions |
si64 型の 1 次元テンソル定数 |
(C17-C18)、(C25) |
| (I14) | output_batch_dimension |
si64 型の定数 |
(C20)、(C25) |
| (I15) | output_feature_dimension |
si64 型の定数 |
(C20)、(C25)、(C30) |
| (I16) | output_spatial_dimensions |
si64 型の 1 次元テンソル定数 |
(C19-C20)、(C25) |
| (I17) | feature_group_count |
si64 型の定数 |
(C11)、(C14)、(C16)、(C21)、(C23) |
| (I18) | batch_group_count |
si64 型の定数 |
(C10)、(C15)、(C22)、(C23)、(C25) |
| (I19) | precision_config |
DEFAULT、HIGH、HIGHEST の列挙型の可変長引数 |
(C24) |
出力
| 名前 | タイプ | 制約 |
|---|---|---|
result |
テンソルまたは量子化テンソル | (C25-C28)、(C30)、(C32-34) |
制約
- (C1)
N = rank(lhs) = rank(rhs)。 - (C2)
size(window_strides) = N - 2。 - (C3)
0 < window_strides。 - (C4)
shape(padding) = [N - 2, 2]。 - (C5)
size(lhs_dilation) = N - 2。 - (C6)
0 < lhs_dilation。 - (C7)
size(rhs_dilation) = N - 2。 - (C8)
0 < rhs_dilation。 - (C9)
size(window_reversal) = N - 2。 - (C10)
dim(lhs, input_batch_dimension) % batch_group_count = 0。 - (C11)
dim(lhs, input_feature_dimension) % feature_group_count = 0。 - (C12)
size(input_spatial_dimensions) = N - 2。 - (C13)
input_dimensions = [input_batch_dimension] + input_spatial_dimensions + [input_feature_dimension]が与えられた場合:is_unique(input_dimensions)。0 <= input_dimensions < N。
- (C14)
dim(rhs, kernel_input_feature_dimension) = dim(lhs, input_feature_dimension) / feature_group_count。 - (C15)
dim(rhs, kernel_output_feature_dimension) % batch_group_count = 0。 - (C16)
dim(rhs, kernel_output_feature_dimension) % feature_group_count = 0。 - (C17)
size(kernel_spatial_dimensions) = N - 2。 - (C18)
kernel_dimensions = kernel_spatial_dimensions + [kernel_input_feature_dimension] + [kernel_output_feature_dimension]が与えられた場合:is_unique(kernel_dimensions)。0 <= kernel_dimensions < N。
- (C19)
size(output_spatial_dimensions) = N - 2。 - (C20)
output_dimensions = [output_batch_dimension] + output_spatial_dimensions + [output_feature_dimension]が与えられた場合:is_unique(output_dimensions)。0 <= output_dimensions < N。
- (C21)
0 < feature_group_count。 - (C22)
0 < batch_group_count。 - (C23)
feature_group_count = 1 or batch_group_count = 1。 - (C24)
size(precision_config) = 2。 - (C25)
dim(result, result_dim)は次のように定義されます。result_dim = output_batch_dimensionの場合、dim(lhs, input_batch_dimension) / batch_group_count。result_dim = output_feature_dimensionの場合、dim(rhs, kernel_output_feature_dimension)。- それ以外の場合は
num_windows。ここで output_spatial_dimensions[spatial_dim] = result_dim。lhs_dim = input_spatial_dimensions[spatial_dim]rhs_dim = kernel_spatial_dimensions[spatial_dim]dilated_input_shape[lhs_dim] = dim(lhs, lhs_dim) = 0 ? 0 : (dim(lhs, lhs_dim) - 1) * lhs_dilation[spatial_dim] + 1padded_input_shape[lhs_dim] = padding[spatial_dim, 0] + dilated_input_shape[lhs_dim] + padding[spatial_dim, 1]dilated_window_shape[lhs_dim] = dim(rhs, rhs_dim) = 0 ? 0 : (dim(rhs, rhs_dim) - 1) * rhs_dilation[spatial_dim] + 1is_empty_window[lhs_dim] = padded_input_shape[lhs_dim] = 0 || dilated_window_shape[lhs_dim] > padded_input_shape[lhs_dim]num_windows = is_empty_window[lhs_dim] ? 0 : floor((padded_input_shape[lhs_dim] - dilated_window_shape[lhs_dim]) / window_strides[spatial_dim]) + 1。
- (C26)
rank(result) = N。 - オペレーションが量子化されていないテンソルを使用する場合:
- (C27)
element_type(lhs) = element_type(rhs) = element_type(result)。
- (C27)
- オペレーションで量子化テンソルを使用する場合:
- (C28)
is_quantized(lhs) = is_quantized(result) and is_quantized(rhs)。 - (C29)
is_per_axis_quantized(rhs)の場合はquantization_dimension(rhs) = kernel_output_feature_dimension。 - (C30)
is_per_axis_quantized(result)の場合はquantization_dimension(result) = output_feature_dimension。 is_quantized(lhs)の場合:- (C31)
storage_type(lhs) = storage_type(rhs)。 - (C32)
expressed_type(lhs) = expressed_type(rhs) = expressed_type(result)。 - (C33)
is_per_tensor_quantized(rhs)の場合はis_per_tensor_quantized(result)。 !is_quantized(lhs)の場合:- (C34)
element_type(lhs) = expressed_type(rhs) = element_type(result)。
- (C28)
例
// %lhs: [[
// [
// [1], [2], [5], [6]
// ],
// [
// [3], [4], [7], [8]
// ],
// [
// [10], [11], [14], [15]
// ],
// [
// [12], [13], [16], [17]
// ]
// ]]
//
// %rhs: [
// [[[1]], [[1]], [[1]]],
// [[[1]], [[1]], [[1]]],
// [[[1]], [[1]], [[1]]]
// ]
%result = "stablehlo.convolution"(%lhs, %rhs) {
window_strid<es = arra>yi64: 4, 4,
paddi<n>g = dense<0 : ten>sor2x2xi64,
lhs_dilati<on = arra>yi64: 2, 2,
rhs_dilati<on = arra>yi64: 1, 1,
window_revers<al = arrayi1: fa>lse, false,
// In the StableHLO dialect, dimension numbers are encoded vi<a:
// `[input >dim<ensions]x[kernel >di>mensions]-[output dimensions]`.
// "b" is batch dimension, "f" is feature dimension,
// "i" is input feature dimension, "o" is output feature dimension,
// "0/1/etc" a<re spatial dimensions.
d>imension_num>bers = #stablehlo.conv[b, 0, 1, f]x[0, 1, i, o]-[b, 0, 1, f],
batch_group_count = 1 : i64,
fea<ture_group_count >= 1 : i64,
< precision_config> = [#stablehl<oprecision >DEFAULT,< #stablehlo>pre>cision <DEFAULT]
} >: (tensor1x4x4x1xi64, tensor3x3x1x1xi64) - tensor1x2x2x1xi64
// %result: [[
// [[10], [26]],
// [[46], [62]]
// ]]
コサイン
セマンティクス
operand テンソルに対して要素ごとのコサイン演算を実行し、result テンソルを生成します。要素のタイプに応じて、次の処理を行います。
- 浮動小数点の場合: IEEE-754 の
cos。 - 複素数の場合は、複素余弦。
- 量子化された型の場合:
dequantize_op_quantize(cosine, operand, type(result))。
入力
| ラベル | 名前 | タイプ | 制約 |
|---|---|---|---|
| (I1) | operand |
浮動小数点型または複素数型のテンソル、またはテンソルごとの量子化テンソル | (C1) |
出力
| 名前 | タイプ | 制約 |
|---|---|---|
result |
浮動小数点型または複素数型のテンソル、またはテンソルごとの量子化テンソル | (C1) |
制約
- (C1)
baseline_type(operand) = baseline_type(result)。
例
// %operand: [
// [0.0, 1.57079632], // [0, pi/2]
// [3.14159265, 4.71238898] // [pi, 3pi/2]
// ]
%result = "stablehlo.cosine"(%operand)< : (ten>sor>2x2xf32<) - ten>sor2x2xf32
// %result: [[1.0, 0.0], [-1.0, 0.0]]
count_leading_zeros
セマンティクス
operand テンソルの先頭のゼロビット数を要素ごとにカウントし、result テンソルを生成します。
入力
| ラベル | 名前 | タイプ | 制約 |
|---|---|---|---|
| (I1) | operand |
整数型のテンソル | (C1) |
出力
| 名前 | タイプ | 制約 |
|---|---|---|
result |
整数型のテンソル | (C1) |
制約
- (C1)
type(operand) = type(result)。
例
// %operand: [[0, 1], [128, -1]]
%result = "stablehlo.count_leading_zeros"(%operand)< : (ten>sor>2x2xi64<) - ten>sor2x2xi64
// %result: [[64, 63], [56, 0]]
custom_call
セマンティクス
inputs と called_computations を受け取り、results を生成する実装定義のオペレーション call_target_name をカプセル化します。has_side_effect、backend_config、api_version を使用して、実装定義の追加メタデータを提供できます。
現時点では、このオペレーションには、XLA コンパイラの対応するオペレーションの有機的な進化を反映した、かなり整理されていないメタデータのコレクションが含まれています。今後、このメタデータを統合する予定です(#741)。
入力
| ラベル | 名前 | タイプ |
|---|---|---|
| (I1) | inputs |
可変個数の値 |
| (I2) | call_target_name |
string 型の定数 |
| (I3) | has_side_effect |
i1 型の定数 |
| (I4) | backend_config |
string 型の定数または属性ディクショナリ |
| (I5) | api_version |
si32 型の定数 |
| (I6) | called_computations |
string 型の可変長定数 |
| (I7) | output_operand_aliases |
出力とオペランドでエイリアシング部分を指定する |
出力
| 名前 | タイプ |
|---|---|
results |
可変個数の値 |
(XLA GPU サポート)特別な custom_call ターゲット
buffer 型に関連する 3 つの特別な call_target_name があります。CreateBuffer は初期化されていない buffer を作成し、Pin は初期化された buffer を作成し、Unpin は buffer を割り当て解除して buffer のコンテンツを返します。
%uninitialized_buffer = "stablehlo.custom_call"() {
call_target_name = "CreateBuffer",
api_version> = 4 : <i32,
>} : () - memref4xf64
%initialized_buffer = "stablehlo.custom_call"(%init_value) {
call_target_name = "Pin&quo<t;,
> ap>i_versi<on = >4 : i32,
} : (tensor4xf64) - memref4xf64
%dealloc_buffer = "stablehlo.custom_call"(%initialized_buffer) {
call_target_na<me = >&qu>ot;Unpi<n&quo>t;,
api_version = 4 : i32,
} : (memref4xf64) - tensor4xf64
エイリアス
一部の custom_call オペレーションでは、出力の一部とオペランドの一部が同じメモリを共有する必要があります。これは output_operand_aliases で表すことができます。エイリアス ペア表現は、出力部分を表す出力タプル インデックスのリストと、オペランド部分を表すオペランド インデックスとオペランド タプル インデックスのリストで構成されます。対応する型が tuple 型でない場合、出力またはオペランドのタプル インデックスのリストは空になります。また、任意のネストされたタプル型に対して任意の長さになります。これは、XLA エイリアス表現に似ています。
エイリアス ペアの出力部分と入力部分は同じ型でなければなりません。CreateBuffer、Pin、Unpin への呼び出しではない custom_call オペレーションの場合、buffer オペランドはエイリアスのペアに 1 つだけ指定でき、buffer 出力はエイリアスのペアに 1 つ指定する必要があります。
例
%results = "stablehlo.custom_call"(%input0) {
call_target_name = "foo",
has_side_effect = false,
backend_config = {bar = 42 : i32},
api_version = 4 : i32,
called_computations <= [>@fo>o]
} : <(te>nsorf64) - tensorf64
%updated_buffer = "stablehlo.custom_call"(%buffer) {
call_target_name = "Update",
api_version = 4 : i32,
output_operand_aliases< = [
#stablehlo.output_operand_aliasoutput_tuple_indices = [],
operand_ind>ex = 0,
< oper>and>_tuple_<indic>es = []]
} : (memref4xf64) - memref4xf64
divide
セマンティクス
被除数 lhs テンソルと除数 rhs テンソルの要素ごとの除算を行い、result テンソルを生成します。要素のタイプに応じて、次の処理を行います。
- 整数の場合: 整数除算。代数商が生成され、小数部は破棄されます。
- 浮動小数点の場合: IEEE-754 の
division。 - 複素数の場合: 複素数の除算。
- 量子化された型の場合:
dequantize_op_quantize(divide, lhs, rhs, type(result))。
入力
| ラベル | 名前 | タイプ | 制約 |
|---|---|---|---|
| (I1) | lhs |
整数型、浮動小数点型、複合型、またはテンソルごとの量子化テンソルのテンソル | (C1) |
| (I2) | rhs |
整数型、浮動小数点型、複合型、またはテンソルごとの量子化テンソルのテンソル | (C1) |
出力
| 名前 | タイプ | 制約 |
|---|---|---|
result |
整数型、浮動小数点型、複素数型のテンソル、またはテンソルごとの量子化テンソル | (C1) |
制約
- (C1)
baseline_type(lhs) = baseline_type(rhs) = baseline_type(result)。
例
// %lhs: [17.1, -17.1, 17.1, -17.1]
// %rhs: [3.0, 3.0, -3.0, -3.0]
%result = "stablehlo.divide"(%lhs, %rhs)< : (t>ensor4xf<32, t>ens>or4xf32<) - t>ensor4xf32
// %result: [5.66666651, -5.66666651, -5.66666651, 5.66666651]
dot_general
セマンティクス
lhs のスライスと rhs のスライスのドット積を計算し、result テンソルを生成します。
より正式には、result[result_index] = dot_product。ここで、
lhs_result_dimensions = [d for d in axes(lhs) and d not in lhs_batching_dimensions and d not in lhs_contracting_dimensions]。rhs_result_dimensions = [d for d in axes(rhs) and d not in rhs_batching_dimensions and d not in rhs_contracting_dimensions]。result_batching_index + result_lhs_index + result_rhs_index = result_indexここで、size(result_batching_index) = size(lhs_batching_dimensions)、size(result_lhs_index) = size(lhs_result_dimensions)、size(result_rhs_index) = size(rhs_result_dimensions)。transposed_lhs = transpose(lhs, lhs_batching_dimensions + lhs_result_dimensions + lhs_contracting_dimensions)。transposed_lhs_slice = slice(transposed_lhs, result_batching_index + result_lhs_index + [:, ..., :])reshaped_lhs_slice = reshape(transposed_lhs_slice, dims(lhs, lhs_contracting_dimensions))transposed_rhs = transpose(rhs, rhs_batching_dimensions + rhs_result_dimensions + rhs_contracting_dimensions)transposed_rhs_slice = slice(transposed_rhs, result_batching_index + result_rhs_index + [:, ..., :])reshaped_rhs_slice = reshape(transposed_rhs_slice, dims(rhs, rhs_contracting_dimensions))。dot_product = reduce( inputs=[multiply(reshaped_lhs_slice, reshaped_rhs_slice)], init_values=[constant(0, element_type(result))], dimensions=range(size(lhs_contracting_dimensions)), body=lambda x, y: add(x, y))。
量子化された型の場合、dequantize_op_quantize(
lambda lhs, rhs: dot_general(lhs, rhs, lhs_batching_dimensions,
rhs_batching_dimensions, lhs_contracting_dimensions,
rhs_contracting_dimensions, precision_config), lhs, rhs, type(result)) を実行します。
ハイブリッド量子化型の場合、hybrid_dequantize_then_op(
lambda lhs, rhs: dot_general(lhs, rhs, lhs_batching_dimensions,
rhs_batching_dimensions, lhs_contracting_dimensions,
rhs_contracting_dimensions, precision_config), lhs, rhs) を実行します。
precision_config は、アクセラレータ バックエンドでの計算の速度と精度のトレードオフを制御します。次のいずれかになります(現時点では、これらの列挙型の値のセマンティクスは十分に指定されていませんが、#755 で対応する予定です)。
DEFAULT: 最速の計算ですが、元の数値に対する近似の精度は最も低くなります。HIGH: 計算は遅くなりますが、元の数値により近い近似値が得られます。HIGHEST: 最も遅い計算ですが、元の数値に最も正確な近似値が得られます。
DotAlgorithm は、ドット演算の実装に使用されるアルゴリズムの主なプロパティを定義します。また、精度も定義します。アルゴリズム属性フィールドが設定されている場合、precision_config は DEFAULT にする必要があります。デフォルト パラメータは実装で定義されるため、DotAlgorithms にはデフォルト値がありません。そのため、すべてのドット アルゴリズム フィールドを None に設定して、空のドット アルゴリズムを指定できます。この場合、代わりに precision_config 値が使用されます。
DotAlgorithm フィールドには次のものがあります。
lhs_precision_typeとrhs_precision_type: オペレーションの LHS と RHS が丸められる精度。精度型は、入力と出力のストレージ型とは独立しています。accumulation_type: 累積に使用される精度。lhs_component_count、rhs_component_count、num_primitive_operationsは、LHS や RHS を複数のコンポーネントに分解し、それらの値に対して複数の「プリミティブ」ドット演算を行うアルゴリズムを実行する場合に適用されます。通常、これはより高い精度をエミュレートするために行われます(例: Leveraging the bfloat16 Artificial Intelligence Datatype For Higher-Precision Computations: bf16_6x tf32_3x など)。分解のないアルゴリズムの場合、これらの値は1に設定する必要があります。allow_imprecise_accumulation: 一部のステップ(CUBLASLT_MATMUL_DESC_FAST_ACCUMなど)で低精度の累積が許可されるかどうかを指定します。
DotAlgorithm 属性の例:
// Inputs are casted to tf32, and then accumulated in f32:
{lhs_precision_type = tf32,
rhs_precision_type = tf32,
accumulation_type = f32,
lhs_component_count = 1,
rhs_component_count = 1,
num_primitive_operations = 1,
allow_imprecise_accumulation = false}
// bf16_6x: each input is decomposed to 3 bf16 components, then 6 dot operations are done on those components, and the result is accumulated in f32.
{lhs_precision_type = bf16,
rhs_precision_type = bf16,
accumulation_type = f32,
lhs_component_count = 3,
rhs_component_count = 3,
num_primitive_operations = 6,
allow_imprecise_accumulation = false}
// Inputs are (casted to) f8e5m2, and we accumulate in f32, but for some steps we may accumulate in lower precision.
{lhs_precision_type = f8e5m2,
rhs_precision_type = f8e5m2,
accumulation_type = f32,
lhs_component_count = 1,
rhs_component_count = 1,
num_primitive_operations = 1,
allow_imprecise_accumulation = true}
どの組み合わせがサポートされるかは、実装によって異なります。一般に、StableHLO のコンシューマーによって、各アルゴリズムが各アクセラレータ タイプでサポートされるとは限りません。指定されたアルゴリズムがサポートされていない場合は、代替アルゴリズムにフォールバックするのではなく、エラーを発生させる必要があります。StableHLO 検証では、ベスト エフォートで検証が行われ、どのハードウェアでもサポートされていないことがわかっているアルゴリズムが防止されます。
サポートされているアルゴリズムの値については、xla_data.proto > Algorithm をご覧ください。チケット #2483 には、バックエンドでサポートされているアルゴリズムに関する一元化されたドキュメントを作成する計画が記載されています。
入力
| ラベル | 名前 | タイプ | 制約 |
|---|---|---|---|
| (I1) | lhs |
テンソルまたはテンソルごとの量子化テンソル | (C5-C6)、(C9-C10)、(C12-C14)、(C17-C18)、(C20) |
| (I2) | rhs |
テンソルまたは量子化テンソル | (C7-C10)、(C12-C20) |
| (I3) | lhs_batching_dimensions |
si64 型の 1 次元テンソル定数 |
(C1)、(C3)、(C5)、(C9)、(C12) |
| (I4) | rhs_batching_dimensions |
si64 型の 1 次元テンソル定数 |
(C1)、(C4)、(C7)、(C9) |
| (I5) | lhs_contracting_dimensions |
si64 型の 1 次元テンソル定数 |
(C2)、(C3)、(C6)、(C10) |
| (I6) | rhs_contracting_dimensions |
si64 型の 1 次元テンソル定数 |
(C2)、(C4)、(C8)、(C10)、(C16) |
| (I7) | precision_config |
DEFAULT、HIGH、HIGHEST の列挙型の可変長引数 |
(C11)、(C21) |
| (I8) | lhs_precision_type |
FloatType または TensorFloat32 | (C21) |
| (I9) | rhs_precision_type |
FloatType または TensorFloat32 | (C21) |
| (I10) | accumulation_type |
FloatType または TensorFloat32 | (C21) |
| (I11) | lhs_component_count |
si32 型の定数 |
(C21)、(C22) |
| (I12) | rhs_component_count |
si32 型の定数 |
(C21)、(C23) |
| (I13) | num_primitive_operations |
si32 型の定数 |
(C21)、(C24) |
| (I14) | allow_imprecise_accumulation |
bool 型の定数 |
(C21) |
出力
| 名前 | タイプ | 制約 |
|---|---|---|
result |
テンソルまたは量子化テンソル | (C12)、(C14)、(C18-C20) |
制約
- (C1)
size(lhs_batching_dimensions) = size(rhs_batching_dimensions)。 - (C2)
size(lhs_contracting_dimensions) = size(rhs_contracting_dimensions)。 - (C3)
is_unique(lhs_batching_dimensions + lhs_contracting_dimensions)。 - (C4)
is_unique(rhs_batching_dimensions + rhs_contracting_dimensions)。 - (C5)
0 <= lhs_batching_dimensions < rank(lhs)。 - (C6)
0 <= lhs_contracting_dimensions < rank(lhs)。 - (C7)
0 <= rhs_batching_dimensions < rank(rhs)。 - (C8)
0 <= rhs_contracting_dimensions < rank(rhs)。 - (C9)
dim(lhs, lhs_batching_dimensions...) = dim(rhs, rhs_batching_dimensions...)。 - (C10)
dim(lhs, lhs_contracting_dimensions...) = dim(rhs, rhs_contracting_dimensions...)。 - (C11)
size(precision_config) = 2。 - (C12)
shape(result) = dim(lhs, lhs_batching_dimensions) + dim(lhs, lhs_result_dimensions) + dim(rhs, rhs_result_dimensions)。 - オペレーションが量子化されていないテンソルを使用する場合:
- (C13)
element_type(lhs) = element_type(rhs)。
- (C13)
- オペレーションで量子化テンソルを使用する場合:
- (C14)
is_quantized(lhs) = is_quantized(result) and is_quantized(rhs)。 - (C15)
zero_points(rhs) = 0。 - (C16)
is_per_axis_quantized(rhs)の場合、quantization_dimension(rhs)はrhs_contracting_dimensionsに含まれません。 is_quantized(lhs)の場合:- (C17)
storage_type(lhs) = storage_type(rhs)。 - (C18)
expressed_type(lhs) = expressed_type(rhs) = expressed_type(result)。 - (C19)
is_per_tensor_quantized(rhs)の場合、is_per_tensor_quantized(result)。 !is_quantized(lhs)の場合:- (C20)
element_type(lhs) = expressed_type(rhs) = element_type(result)。
- (C14)
!is_empty_algorithm(lhs_precision_type, rhs_precision_type, accumulation_type, lhs_component_count, rhs_component_count, num_primitive_operations allow_imprecise_accumulation)の場合:- (C21)
precision_config... = DEFAULT。 - (C22)
0 < lhs_component_count。 - (C23)
0 < rhs_component_count。 - (C24)
0 < num_primitive_operations。
- (C21)
例
// %lhs: [
// [[1, 2],
// [3, 4]],
// [[5, 6],
// [7, 8]]
// ]
// %rhs: [
// [[1, 0],
// [0, 1]],
// [[1, 0],
// [0, 1]]
// ]
%result = "stablehlo.dot_general"(%lhs, %rhs) {
dot_dimension_numbers = #sta<blehlo.dot
lhs_batching_dimensions = [0],
rhs_batching_dimensions = [0],
lhs_contracting_dimensions = [2],
rhs_contracting_dimension>s = [1]
,
precision_config = [<#stablehloprecisi>on DEFAULT, <#stablehloprecisi>on DEFAULT],
algorithm = #stablehlo.dot<_algorithm
lhs_precision_type = tf32,
rhs_precision_type = tf32,
accumulation_type = f32,
lhs_component_count = 1,
rhs_component_count = 1,
num_primitive_operations = 1,
allow_imprecise_accumulation >= false
}< : (tenso>r2x2x2xi<64, tenso>r2x>2x2xi64<) - tenso>r2x2x2xi64
// %result: [
// [[1, 2],
// [3, 4]],
// [[5, 6],
// [7, 8]]
// ]
dynamic_broadcast_in_dim
セマンティクス
このオペレーションは broadcast_in_dim オペレーションと機能的に同じですが、結果の形状は output_dimensions を介して動的に指定されます。
このオペレーションでは、ディメンションの展開動作に関する静的な知識を表すオプションの属性 known_expanding_dimensions、known_nonexpanding_dimensions も受け入れます。指定しない場合、すべてのディメンションが拡張可能と見なされます。
入力
| ラベル | 名前 | タイプ | 制約 |
|---|---|---|---|
| (I1) | operand |
テンソルまたは量子化テンソル | (C1-C2)、(C5-C6)、(C9) |
| (I2) | output_dimensions |
整数型の 1 次元テンソル | (C7) |
| (I3) | broadcast_dimensions |
整数型の 1 次元定数テンソル | (C2-C6) |
| (I4) | known_expanding_dimensions |
整数型の 1 次元定数テンソル | (C8-C9) |
| (I5) | known_nonexpanding_dimensions |
整数型の 1 次元定数テンソル | (C8-C9) |
出力
| 名前 | タイプ | 制約 |
|---|---|---|
result |
テンソルまたは量子化テンソル | (C1)、(C3)、(C5 ~ C7) |
制約
- (C1)
element_type(result)は次のように計算されます。!is_per_axis_quantized(operand)の場合、element_type(operand)。element_type(operand)。ただし、quantization_dimension(operand)、scales(operand)、zero_points(operand)はそれぞれquantization_dimension(result)、scales(result)、zero_points(result)と異なる場合があります。
- (C2)
size(broadcast_dimensions) = rank(operand)。 - (C3)
0 <= broadcast_dimensions < rank(result)。 - (C4)
is_unique(broadcast_dimensions)。 - (C5)
axes(operand)内のすべてのdについて:dim(operand, d) = 1またはdim(operand, d) = dim(result, broadcast_dimensions[d])。
- (C6)
is_per_axis_quantized(result)の場合:quantization_dimension(result) = broadcast_dimensions[quantization_dimension(operand)]。dim(operand, quantization_dimension(operand)) = 1の場合はscales(result)[i] = scales(operand)[0] and zero_points(result)[i] = zero_points(operand)[0] for i in range(dim(result, quantization_dimension(result)))。
- (C7)
size(output_dimensions) = rank(result)。 - (C8)
is_unique(known_expanding_dimensions + known_nonexpanding_dimensions)。 - (C9)
0 <= known_expanding_dimensions < rank(operand)。 - (C10)
0 <= known_nonexpanding_dimensions < rank(operand)。
例
// %operand: [
// [1, 2, 3]
// ]
%operand = stablehlo.constant dense<[[1, 2, 3]]> : tensor<1x3xi64>
%output_dimensions = stablehlo.constant dense<[2, 3, 2]> : tensor<3xi64>
%result = "stablehlo.dynamic_broadcast_in_dim"(%operand, %output_dimensions) {
broadcast_dimensio<ns = arra>yi64: 2, 1,
known_expanding_dimensio<ns = a>rrayi64: 0,
known_nonexpanding_dimensio<ns = a>rrayi64: 1
}< : (ten>sor1x3xi<64, t>ens>or3xi64<) - tenso>r2x3x2xi64
// %result: [
// [
// [1, 1],
// [2, 2],
// [3, 3]
// ],
// [
// [1, 1],
// [2, 2],
// [3, 3]
// ]
// ]
dynamic_conv
セマンティクス
このオペレーションは、畳み込みオペレーションと機能的に同じですが、パディングは padding を介して動的に指定されます。
入力
| ラベル | 名前 | タイプ | 制約 |
|---|---|---|---|
| (I1) | lhs |
テンソルまたはテンソルごとの量子化テンソル | (C1)、(C10-C11)、(C14)、(C25)、(C26-C27)、(C30-C31)、(C33) |
| (I2) | rhs |
テンソルまたは量子化テンソル | (C1)、(C14 ~ C16)、(C26 ~ C28)、(C30 ~ C33) |
| (I3) | padding |
整数型の 2 次元テンソル | (C4) |
| (I4) | window_strides |
si64 型の 1 次元テンソル定数 |
(C2-C3) |
| (I5) | lhs_dilation |
si64 型の 1 次元テンソル定数 |
(C5-C6) |
| (I6) | rhs_dilation |
si64 型の 1 次元テンソル定数 |
(C7-C8) |
| (I7) | window_reversal |
i1 型の 1 次元テンソル定数 |
(C9) |
| (I8) | input_batch_dimension |
si64 型の定数 |
(C10)、(C13) |
| (I9) | input_feature_dimension |
si64 型の定数 |
(C11)、(C13-C14) |
| (I10) | input_spatial_dimensions |
si64 型の 1 次元テンソル定数 |
(C12)、(C13) |
| (I11) | kernel_input_feature_dimension |
si64 型の定数 |
(C14)、(C18) |
| (I12) | kernel_output_feature_dimension |
si64 型の定数 |
(C15-C16)、(C18)、(C28) |
| (I13) | kernel_spatial_dimensions |
si64 型の 1 次元テンソル定数 |
(C17-C18) |
| (I14) | output_batch_dimension |
si64 型の定数 |
(C20) |
| (I15) | output_feature_dimension |
si64 型の定数 |
(C20)、(C29) |
| (I16) | output_spatial_dimensions |
si64 型の 1 次元テンソル定数 |
(C19 ~ C20) |
| (I17) | feature_group_count |
si64 型の定数 |
(C11)、(C14)、(C16)、(C21)、(C23) |
| (I18) | batch_group_count |
si64 型の定数 |
(C10)、(C15)、(C22)、(C23) |
| (I19) | precision_config |
DEFAULT、HIGH、HIGHEST の列挙型の可変長引数 |
(C24) |
出力
| 名前 | タイプ | 制約 |
|---|---|---|
result |
テンソルまたは量子化テンソル | (C25-C27)、(C29)、(C31-C33) |
制約
- (C1)
N = rank(lhs) = rank(rhs)。 - (C2)
size(window_strides) = N - 2。 - (C3)
0 < window_strides。 - (C4)
shape(padding) = [N - 2, 2]。 - (C5)
size(lhs_dilation) = N - 2。 - (C6)
0 < lhs_dilation。 - (C7)
size(rhs_dilation) = N - 2。 - (C8)
0 < rhs_dilation。 - (C9)
size(window_reversal) = N - 2。 - (C10)
dim(lhs, input_batch_dimension) % batch_group_count = 0。 - (C11)
dim(lhs, input_feature_dimension) % feature_group_count = 0。 - (C12)
size(input_spatial_dimensions) = N - 2。 - (C13)
input_dimensions = [input_batch_dimension] + input_spatial_dimensions + [input_feature_dimension]が与えられた場合:is_unique(input_dimensions)。0 <= input_dimensions < N。
- (C14)
dim(rhs, kernel_input_feature_dimension) = dim(lhs, input_feature_dimension) / feature_group_count。 - (C15)
dim(rhs, kernel_output_feature_dimension) % batch_group_count = 0。 - (C16)
dim(rhs, kernel_output_feature_dimension) % feature_group_count = 0。 - (C17)
size(kernel_spatial_dimensions) = N - 2。 - (C18)
kernel_dimensions = kernel_spatial_dimensions + [kernel_input_feature_dimension] + [kernel_output_feature_dimension]が与えられた場合:is_unique(kernel_dimensions)。0 <= kernel_dimensions < N。
- (C19)
size(output_spatial_dimensions) = N - 2。 - (C20)
output_dimensions = [output_batch_dimension] + output_spatial_dimensions + [output_feature_dimension]が与えられた場合:is_unique(output_dimensions)。0 <= output_dimensions < N。
- (C21)
0 < feature_group_count。 - (C22)
0 < batch_group_count。 - (C23)
feature_group_count = 1 or batch_group_count = 1。 - (C24)
size(precision_config) = 2。 - (C25)
dim(result, result_dim)は次のように定義されます。result_dim = output_batch_dimensionの場合、dim(lhs, input_batch_dimension) / batch_group_count。result_dim = output_feature_dimensionの場合、dim(rhs, kernel_output_feature_dimension)。- それ以外の場合は
num_windows。ここで output_spatial_dimensions[spatial_dim] = result_dim。lhs_dim = input_spatial_dimensions[spatial_dim]rhs_dim = kernel_spatial_dimensions[spatial_dim]dilated_input_shape[lhs_dim] = dim(lhs, lhs_dim) = 0 ? 0 : (dim(lhs, lhs_dim) - 1) * lhs_dilation[spatial_dim] + 1padded_input_shape[lhs_dim] = padding[spatial_dim, 0] + dilated_input_shape[lhs_dim] + padding[spatial_dim, 1]dilated_window_shape[lhs_dim] = dim(rhs, rhs_dim) = 0 ? 0 : (dim(rhs, rhs_dim) - 1) * rhs_dilation[spatial_dim] + 1is_empty_window[lhs_dim] = padded_input_shape[lhs_dim] = 0 || dilated_window_shape[lhs_dim] > padded_input_shape[lhs_dim]num_windows = is_empty_window[lhs_dim] ? 0 : floor((padded_input_shape[lhs_dim] - dilated_window_shape[lhs_dim]) / window_strides[spatial_dim]) + 1。
- (C26)
rank(result) = N。 - オペレーションが量子化されていないテンソルを使用する場合:
- (C27)
element_type(lhs) = element_type(rhs) = element_type(result)。
- (C27)
- オペレーションで量子化テンソルを使用する場合:
- (C28)
is_quantized(lhs) = is_quantized(result) and is_quantized(rhs)。 - (C29)
is_per_axis_quantized(rhs)の場合はquantization_dimension(rhs) = kernel_output_feature_dimension。 - (C30)
is_per_axis_quantized(result)の場合はquantization_dimension(result) = output_feature_dimension。 is_quantized(lhs)の場合:- (C31)
storage_type(lhs) = storage_type(rhs)。 - (C32)
expressed_type(lhs) = expressed_type(rhs) = expressed_type(result)。 - (C33)
is_per_tensor_quantized(rhs)の場合はis_per_tensor_quantized(result)。 !is_quantized(lhs)の場合:- (C34)
element_type(lhs) = expressed_type(rhs) = element_type(result)。
- (C28)
例
// %lhs: [[
// [[1], [2], [5], [6]],
// [[3], [4], [7], [8]],
// [[10], [11], [14], [15]],
// [[12], [13], [16], [17]]
// ]]
//
// %rhs: [
// [[[1]], [[1]], [[1]]],
// [[[1]], [[1]], [[1]]],
// [[[1]], [[1]], [[1]]]
// ]
// %padding: [[1, 1],
// [1, 1]]
%result = "stablehlo.dynamic_conv"(%lhs, %rhs, %padding) {
window_strid<es = arra>yi64: 4, 4,
lhs_dilati<on = arra>yi64: 2, 2,
rhs_dilati<on = arra>yi64: 1, 1,
window_revers<al = arrayi1: fa>lse, false,
dimension_numbers = #stab<lehlo.convraw
input_batch_dimension = 0,
input_feature_dimension = 3,
input_spatial_dimensions = [0, 1],
kernel_input_feature_dimension = 2,
kernel_output_feature_dimension = 3,
kernel_spatial_dimensions = [0, 1],
output_batch_dimension = 0,
output_feature_dimension = 3,
output_spatial_dimensions => [1, 2]
,
feature_group_count = 1 : i64,
batch_group_count = 1 : i64,
precision_config = [<#stablehloprecisi>on DEFAULT, <#stablehloprecisi>on DEFAULT]
}< : (tensor1>x4x4x1xi<64, tensor3>x3x1x1xi<64, ten>sor>2x2xi64<) - tensor1>x2x2x1xi64
// %result: [[
// [[1], [5]],
// [[10], [14]]
// ]]
dynamic_gather
セマンティクス
このオペレーションは、gather op と機能的に同じで、slice_sizes が値として動的に指定されます。
入力
| ラベル | 名前 | タイプ | 制約 |
|---|---|---|---|
| (I1) | operand |
テンソルまたはテンソルごとの量子化テンソル | (C1)、(C7)、(C10-C12)、(C14) |
| (I2) | start_indices |
整数型のテンソル | (C2)、(C3)、(C13) |
| (I3) | slice_sizes |
整数型の 1 次元テンソル | (C8)、(C11-C13) |
| (I4) | offset_dims |
si64 型の 1 次元テンソル定数 |
(C1)、(C4-C5)、(C13) |
| (I5) | collapsed_slice_dims |
si64 型の 1 次元テンソル定数 |
(C1)、(C6-C8)、(C13) |
| (I6) | start_index_map |
si64 型の 1 次元テンソル定数 |
(C3)、(C9)、(C10) |
| (I7) | index_vector_dim |
si64 型の定数 |
(C2)、(C3)、(C13) |
| (I8) | indices_are_sorted |
i1 型の定数 |
出力
| 名前 | タイプ | 制約 |
|---|---|---|
result |
テンソルまたはテンソルごとの量子化テンソル | (C5)、(C13-C14) |
制約
- (C1)
rank(operand) = size(offset_dims) + size(collapsed_slice_dims)。 - (C2)
0 <= index_vector_dim <= rank(start_indices)。 - (C3)
size(start_index_map) = index_vector_dim < rank(start_indices) ? dim(start_indices, index_vector_dim) : 1。 - (C4)
is_unique(offset_dims) and is_sorted(offset_dims)。 - (C5)
0 <= offset_dims < rank(result)。 - (C6)
is_unique(collapsed_slice_dims) and is_sorted(collapsed_slice_dims)。 - (C7)
0 <= collapsed_slice_dims < rank(operand)。 - (C8)
slice_sizes[collapsed_slice_dims...] <= 1。 - (C9)
is_unique(start_index_map)。 - (C10)
0 <= start_index_map < rank(operand)。 - (C11)
size(slice_sizes) = rank(operand)。 - (C12)
0 <= slice_sizes <= shape(operand)。 - (C13)
shape(result) = combine(batch_dim_sizes, offset_dim_sizes)。ここで、batch_dim_sizes = shape(start_indices)。ただし、index_vector_dimに対応するstart_indicesのディメンション サイズは含まれません。offset_dim_sizes = shape(slice_sizes)。ただし、collapsed_slice_dimsに対応するslice_sizesのディメンション サイズは含まれません。combineは、batch_dim_sizesをbatch_dimsに対応する軸に配置し、offset_dim_sizesをoffset_dimsに対応する軸に配置します。
- (C14)
element_type(operand) = element_type(result)。
例
// %operand: [
// [[1, 2], [3, 4], [5, 6], [7, 8]],
// [[9, 10],[11, 12], [13, 14], [15, 16]],
// [[17, 18], [19, 20], [21, 22], [23, 24]]
// ]
// %start_indices: [
// [[0, 0], [1, 0], [2, 1]],
// [[0, 1], [1, 1], [0, 2]]
// ]
// %slize_sizes: [1, 2, 2]
%result = "stablehlo.dynamic_gather"(%operand, %start_indices, %slize_sizes) {
dimension_numbers = #stable<hlo.gather
offset_dims = [2, 3],
collapsed_slice_dims = [0],
start_index_map = [1, 0],
index_vect>or_dim = 2,
indices_are_sorted = false
}< : (tenso>r3x4x2xi<64, tenso>r2x3x2xi<64, t>ens>or3xi64<) - tensor2>x3x2x2xi64
// %result: [
// [
// [[1, 2], [3, 4]],
// [[3, 4], [5, 6]],
// [[13, 14], [15, 16]]
// ],
// [
// [[9, 10], [11, 12]],
// [[11, 12], [13, 14]],
// [[17, 18], [19, 20]]
// ]
// ]
dynamic_iota
セマンティクス
このオペレーションは iota オペレーションと機能的に同じですが、結果の形状は output_shape を介して動的に指定されます。
入力
| ラベル | 名前 | タイプ | 制約 |
|---|---|---|---|
| (I1) | output_shape |
整数型の 1 次元テンソル | (C1)、(C2) |
| (I2) | iota_dimension |
si64 |
(C1) |
出力
| 名前 | タイプ | 制約 |
|---|---|---|
result |
整数型、浮動小数点型、複素数型のテンソル、またはテンソルごとの量子化テンソル | (C2) |
制約
- (C1)
0 <= iota_dimension < size(output_shape)。 - (C2)
rank(result) = size(output_shape)。
例
%output_shape = stablehlo.constant dense<[4, 5]> : tensor<2xi64>
%result = "stablehlo.dynamic_iota"(%output_shape) {
iota_dimension = 0 : i64
}< : (t>ens>or2xi64<) - ten>sor4x5xi64
// %result: [
// [0, 0, 0, 0, 0],
// [1, 1, 1, 1, 1],
// [2, 2, 2, 2, 2],
// [3, 3, 3, 3, 3]
// ]
dynamic_pad
セマンティクス
このオペレーションは pad オペレーションと機能的に同じですが、edge_padding_low、edge_padding_high、interior_padding が値として動的に指定されます。
入力
| ラベル | 名前 | タイプ | 制約 |
|---|---|---|---|
| (I1) | operand |
テンソルまたはテンソルごとの量子化テンソル | (C1)、(C2)、(C4) |
| (I2) | padding_value |
0 次元テンソルまたはテンソルごとの量子化テンソル | (C1) |
| (I3) | edge_padding_low |
整数型の 1 次元テンソル | (C1)、(C4) |
| (I4) | edge_padding_high |
整数型の 1 次元テンソル | (C1)、(C4) |
| (I5) | interior_padding |
整数型の 1 次元テンソル | (C2-C4) |
出力
| 名前 | タイプ | 制約 |
|---|---|---|
result |
テンソルまたはテンソルごとの量子化テンソル | (C3-C6) |
制約
- (C1)
element_type(operand) = element_type(padding_value) = element_type(result)。 - (C2)
size(edge_padding_low) = size(edge_padding_high) = size(interior_padding) = rank(operand)。 - (C3)
0 <= interior_padding。 - (C4)
shape(result) = shape(operand) + edge_padding_low + max(shape(operand) - 1, 0) * interior_padding + edge_padding_high。
例
// %operand: [
// [1, 2, 3],
// [4, 5, 6]
// ]
// %padding_value: 0
// %edge_padding_low: [0, 1]
// %edge_padding_high: [2, 1]
// %interior_padding: [1, 2]
%result = "stablehlo.dynamic_pad"(%operand, %padding_value,
%edge_padding_low, %edge_padding_high, %interior_padding
)< : (ten>sor2x3xi<64,> tensori<64, t>ensor2xi<64, t>ensor2xi<64, t>ens>or2xi64<) - ten>sor5x9xi64
// %result: [
// [0, 1, 0, 0, 2, 0, 0, 3, 0],
// [0, 0, 0, 0, 0, 0, 0, 0, 0],
// [0, 4, 0, 0, 5, 0, 0, 6, 0],
// [0, 0, 0, 0, 0, 0, 0, 0, 0],
// [0, 0, 0, 0, 0, 0, 0, 0, 0]
// ]
dynamic_reshape
セマンティクス
このオペレーションは reshape オペレーションと機能的に同じですが、結果の形状は output_shape を介して動的に指定されます。
入力
| ラベル | 名前 | タイプ | 制約 |
|---|---|---|---|
| (I1) | operand |
テンソルまたは量子化テンソル | (C1 ~ C3) |
| (I2) | output_shape |
整数型の 1 次元テンソル | (C4) |
出力
| 名前 | タイプ | 制約 |
|---|---|---|
result |
テンソルまたは量子化テンソル | (C1 ~ C4) |
制約
- (C1)
element_type(result)は次のように計算されます。!is_per_axis_quantized(operand)の場合、element_type(operand)。element_type(operand)。ただし、quantization_dimension(operand)とquantization_dimension(result)は異なる場合があります。
- (C2)
size(operand) = size(result)。 - (C3)
is_per_axis_quantized(operand)の場合:reduce(dims(operand, [0, 1, ..., quantization_dimension(operand) - 1]), init_values=1, dimensions=[0], body=lambda x, y: x * y) = reduce(dims(result, [0, 1, ..., quantization_dimension(result) - 1]), init_values=1, dimensions=[0], body=lambda x, y: x * y)。dim(operand, quantization_dimension(operand)) = dim(result, quantization_dimension(result))reduce(dims(operand, [quantization_dimension(operand) + 1, ..., rank(operand) - 1]), init_values=1, dimensions=[0], body=lambda x, y: x * y) = reduce(dims(result, [quantization_dimension(result) + 1, ..., rank(result) - 1]), init_values=1, dimensions=[0], body=lambda x, y: x * y)。
- (C4)
size(output_shape) = rank(result)。
例
// %operand: [[1, 2, 3], [4, 5, 6]]
// %output_shape: [3, 2]
%result = "stablehlo.dynamic_reshape"(%operand, %output_shape)< : (ten>sor2x3xi<64, t>ens>or2xi64<) - ten>sor3x2xi64
// %result: [[1, 2], [3, 4], [5, 6]]
dynamic_slice
セマンティクス
動的に計算された開始インデックスを使用して operand からスライスを抽出し、result テンソルを生成します。start_indices には、調整の対象となる各ディメンションのスライスの開始インデックスが含まれ、slice_sizes には、各ディメンションのスライスのサイズが含まれます。より正式には、result[result_index] = operand[operand_index] です。ここで:
adjusted_start_indices = clamp(0, start_indices, shape(operand) - slice_sizes)。operand_index = adjusted_start_indices + result_index。
入力
| ラベル | 名前 | タイプ | 制約 |
|---|---|---|---|
| (I1) | operand |
テンソルまたはテンソルごとの量子化テンソル | (C1)、(C2)、(C4) |
| (I2) | start_indices |
整数型の 0 次元テンソルの可変長 | (C2)、(C3) |
| (I3) | slice_sizes |
si64 型の 1 次元テンソル定数 |
(C2)、(C4)、(C5) |
出力
| 名前 | タイプ | 制約 |
|---|---|---|
result |
テンソルまたはテンソルごとの量子化テンソル | (C1)、(C5) |
制約
- (C1)
element_type(operand) = element_type(result)。 - (C2)
size(start_indices) = size(slice_sizes) = rank(operand)。 - (C3)
same(type(start_indices...))。 - (C4)
0 <= slice_sizes <= shape(operand)。 - (C5)
shape(result) = slice_sizes。
例
// %operand: [
// [0, 0, 1, 1],
// [0, 0, 1, 1],
// [0, 0, 0, 0],
// [0, 0, 0, 0]
// ]
// %start_indices0: -1
// %start_indices1: 3
%result = "stablehlo.dynamic_slice"(%operand, %start_indices0, %start_indices1) {
slice_siz<es = arra>yi64: 2, 2
}< : (ten>sor4x4xi<32,> tensori<64,> te>nsori64<) - ten>sor2x2xi32
// %result: [
// [1, 1],
// [1, 1]
// ]
dynamic_update_slice
セマンティクス
start_indices で始まるスライスが update の値で更新される点を除き、operand テンソルと同じ result テンソルを生成します。より正式には、result[result_index] は次のように定義されます。
update[update_index]:0 <= update_index < shape(update)の場合。ここで、adjusted_start_indices = clamp(0, start_indices, shape(operand) - shape(update))。update_index = result_index - adjusted_start_indices。
- それ以外の場合は
operand[result_index]。
入力
| ラベル | 名前 | タイプ | 制約 |
|---|---|---|---|
| (I1) | operand |
テンソルまたはテンソルごとの量子化テンソル | (C1-C4)、(C6) |
| (I2) | update |
テンソルまたはテンソルごとの量子化テンソル | (C2)、(C3)、(C6) |
| (I3) | start_indices |
整数型の 0 次元テンソルの可変長 | (C4)、(C5) |
出力
| 名前 | タイプ | 制約 |
|---|---|---|
result |
テンソルまたはテンソルごとの量子化テンソル | (C1) |
制約
- (C1)
type(operand) = type(result)。 - (C2)
element_type(update) = element_type(operand)。 - (C3)
rank(update) = rank(operand)。 - (C4)
size(start_indices) = rank(operand)。 - (C5)
same(type(start_indices...))。 - (C6)
0 <= shape(update) <= shape(operand)。
例
// %operand: [
// [1, 1, 0, 0],
// [1, 1, 0, 0],
// [1, 1, 1, 1],
// [1, 1, 1, 1]
// ]
// %update: [
// [1, 1],
// [1, 1]
// ]
// %start_indices0: -1
// %start_indices1: 3
%result = "stablehlo.dynamic_update_slice"(%operand, %update, %start_indices0, %start_indices1)
< : (ten>sor4x4xi<32, ten>sor2x2xi<32,> tensori<64,> te>nsori64<) - ten>sor4x4xi32
// %result: [
// [1, 1, 1, 1],
// [1, 1, 1, 1],
// [1, 1, 1, 1],
// [1, 1, 1, 1]
// ]
指数
セマンティクス
operand テンソルに対して要素ごとの指数演算を行い、result テンソルを生成します。要素のタイプに応じて、次の処理を行います。
- 浮動小数点の場合: IEEE-754 の
exp。 - 複素数の場合: 複素指数。
- 量子化された型の場合:
dequantize_op_quantize(exponential, operand, type(result))。
入力
| ラベル | 名前 | タイプ | 制約 |
|---|---|---|---|
| (I1) | operand |
浮動小数点型または複素数型のテンソル、またはテンソルごとの量子化テンソル | (C1) |
出力
| 名前 | タイプ | 制約 |
|---|---|---|
result |
浮動小数点型または複素数型のテンソル、またはテンソルごとの量子化テンソル | (C1) |
制約
- (C1)
baseline_type(operand) = baseline_type(result)。
例
// %operand: [[0.0, 1.0], [2.0, 3.0]]
%result = "stablehlo.exponential"(%operand)< : (ten>sor>2x2xf64<) - ten>sor2x2xf64
// %result: [[1.0, 2.7182818284590451], [7.3890560989306504, 20.085536923187668]]
exponential_minus_one
セマンティクス
operand テンソルに対して要素ごとの指数関数マイナス 1 演算を実行し、result テンソルを生成します。要素のタイプに応じて、次の処理を行います。
- 浮動小数点の場合: IEEE-754 の
expm1。 - 複素数の場合: 複素指数関数から 1 を引いた値。
- 量子化された型の場合:
dequantize_op_quantize(exponential_minus_one, operand, type(result))。
入力
| ラベル | 名前 | タイプ | 制約 |
|---|---|---|---|
| (I1) | operand |
浮動小数点型または複素数型のテンソル、またはテンソルごとの量子化テンソル | (C1) |
出力
| 名前 | タイプ | 制約 |
|---|---|---|
result |
浮動小数点型または複素数型のテンソル、またはテンソルごとの量子化テンソル | (C1) |
制約
- (C1)
baseline_type(operand) = baseline_type(result)。
例
// %operand: [0.0, 1.0]
%result = "stablehlo.exponential_minus_one"(%operand)< : (t>ens>or2xf64<) - t>ensor2xf64
// %result: [0.0, 1.71828187]
fft
セマンティクス
実数と複素数の入力/出力に対して、順フーリエ変換と逆フーリエ変換を行います。
fft_type は、次のいずれかです。
FFT: 複素数から複素数への FFT を転送します。IFFT: 逆複素数から複素数への FFT。RFFT: 実数から複素数への FFT を転送します。IRFFT: 実数から複素数への逆 FFT(つまり、複素数を取得して実数を返します)。
より正式には、複素型の 1 次元テンソルを入力として受け取り、同じ型の 1 次元テンソルを出力として生成し、離散フーリエ変換を計算する関数 fft が与えられた場合:
fft_type = FFT の場合、result は L = size(fft_length) の一連の L 計算の最終結果として定義されます。たとえば、L = 3 の場合は次のようになります。
result1[i0, ..., :] = fft(operand[i0, ..., :])。result2[i0, ..., :, iR-1] = fft(result1[i0, ..., :, iR-1])result[i0, ..., :, iR-2, iR-1] = fft(result2[i0, ..., :, iR-2, iR-1])。
さらに、同じ型シグネチャを持ち、fft の逆関数を計算する関数 ifft があるとします。
fft_type = IFFT の場合、result は fft_type = FFT の計算の逆として定義されます。たとえば、L = 3 の場合は次のようになります。
result1[i0, ..., :, iR-2, iR-1] = ifft(operand[i0, ..., :, iR-2, iR-1])。result2[i0, ..., :, iR-1] = ifft(result1[i0, ..., :, iR-1])result[i0, ..., :] = ifft(result2[i0, ..., :])。
さらに、浮動小数点型の 1 次元テンソルを受け取り、同じ浮動小数点セマンティクスの複素数型の 1 次元テンソルを生成し、次のように動作する関数 rfft があるとします。
rfft(real_operand) = truncated_resultここでcomplex_operand... = (real_operand..., 0.0)。complex_result = fft(complex_operand)truncated_result = complex_result[:(rank(complex_result) / 2 + 1)]。
(実数オペランドに対して離散フーリエ変換が計算される場合、結果の最初の N/2 + 1 要素は結果の残りの部分を一意に定義するため、冗長な要素の計算を避けるために rfft の結果は切り捨てられます)。
fft_type = RFFT の場合、result は L = size(fft_length) の一連の L 計算の最終結果として定義されます。たとえば、L = 3 の場合は次のようになります。
result1[i0, ..., :] = rfft(operand[i0, ..., :])。result2[i0, ..., :, iR-1] = fft(result1[i0, ..., :, iR-1])result[i0, ..., :, iR-2, iR-1] = fft(result2[i0, ..., :, iR-2, iR-1])。
最後に、同じ型シグネチャを持ち、rfft の逆関数を計算する関数 irfft を指定します。
fft_type = IRFFT の場合、result は fft_type = RFFT の計算の逆として定義されます。たとえば、L = 3 の場合は次のようになります。
result1[i0, ..., :, iR-2, iR-1] = ifft(operand[i0, ..., :, iR-2, iR-1])。result2[i0, ..., :, iR-1] = ifft(result1[i0, ..., :, iR-1])result[i0, ..., :] = irfft(result2[i0, ..., :])。
入力
| ラベル | 名前 | タイプ | 制約 |
|---|---|---|---|
| (I1) | operand |
浮動小数点型または複合型のテンソル | (C1)、(C2)、(C4)、(C5) |
| (I2) | fft_type |
FFT、IFFT、RFFT、IRFFT の列挙型 |
(C2)、(C5) |
| (I3) | fft_length |
si64 型の 1 次元テンソル定数 |
(C1)、(C3)、(C4) |
出力
| 名前 | タイプ | 制約 |
|---|---|---|
result |
浮動小数点型または複合型のテンソル | (C2)、(C4)、(C5) |
制約
- (C1)
size(fft_length) <= rank(operand)。 - (C2)
operand要素型とresult要素型の関係は次のとおりです。fft_type = FFT、element_type(operand)、element_type(result)の複合型が同じ場合。fft_type = IFFT、element_type(operand)、element_type(result)の複合型が同じ場合。fft_type = RFFT、element_type(operand)が浮動小数点型で、element_type(result)が同じ浮動小数点セマンティクスの複合型の場合。fft_type = IRFFT、element_type(operand)が複合型で、element_type(result)が同じ浮動小数点セマンティクスの浮動小数点型の場合。
- (C3)
1 <= size(fft_length) <= 3。 - (C4)
operandとresultの中に浮動小数点型のテンソルrealがある場合、shape(real)[-size(fft_length):] = fft_length。 - (C5)
shape(result) = shape(operand)(ただし、次のものを除く)。fft_type = RFFTの場合は、dim(result, -1) = dim(operand, -1) = 0 ? 0 : dim(operand, -1) / 2 + 1。fft_type = IRFFTの場合は、dim(operand, -1) = dim(result, -1) = 0 ? 0 : dim(result, -1) / 2 + 1。
例
// %operand: [(1.0, 0.0), (0.0, 0.0), (0.0, 0.0), (0.0, 0.0)]
%result = "stablehlo.fft"(%operand) {
fft_type = <#stablehloff>t_type FFT,
fft_leng<th = a>rrayi64: 4
}< : (tenso<r4x>>com>plexf32<) - tenso<r4x>>complexf32
// %result: [(1.0, 0.0), (1.0, 0.0), (1.0, 0.0), (1.0, 0.0)]
床
セマンティクス
operand テンソルの要素ごとの床関数を実行し、result テンソルを生成します。IEEE-754 仕様の roundToIntegralTowardNegative オペレーションを実装します。量子化された型の場合、dequantize_op_quantize(floor, operand, type(result)) を実行します。
入力
| ラベル | 名前 | タイプ | 制約 |
|---|---|---|---|
| (I1) | operand |
浮動小数点型のテンソルまたはテンソルごとの量子化テンソル | (C1) |
出力
| 名前 | タイプ | 制約 |
|---|---|---|
result |
浮動小数点型のテンソルまたはテンソルごとの量子化テンソル | (C1) |
制約
- (C1)
baseline_type(operand) = baseline_type(result)。
例
// %operand: [-0.8166, -0.2530, 0.2530, 0.8166, 2.0]
%result = "stablehlo.floor"(%operand)< : (t>ens>or5xf32<) - t>ensor5xf32
// %result: [-1.0, -1.0, 0.0, 0.0, 2.0]
gather
セマンティクス
start_indices で指定されたオフセットから operand テンソルからスライスを収集し、result テンソルを生成します。
次の図は、具体的な例を使用して、result の要素が operand の要素にどのようにマッピングされるかを示しています。この図では、いくつかの result インデックスの例を取り上げ、それらがどの operand インデックスに対応するかを詳しく説明しています。
より正式には、result[result_index] = operand[operand_index] です。ここで:
batch_dims = [d for d in axes(result) and d not in offset_dims]。batch_index = result_index[batch_dims...]。start_indexは次のように定義されます。start_indices[bi0, ..., :, ..., biN](biはbatch_indexの個々の要素)で、index_vector_dim<rank(start_indices)の場合、:はインデックスindex_vector_dimに挿入されます。- それ以外の場合は
[start_indices[batch_index]]。
axes(operand)のd_operandの場合:d_operand = start_index_map[d_start]の場合、full_start_index[d_operand] = clamp(start_index[d_start], 0, dim(operand, d_operand) - slice_sizes[d_operand])。- それ以外の場合は
full_start_index[d_operand] = 0。
axes(operand)のd_operandの場合:d_operand = operand_batching_dims[i_batching]とd_start = start_indices_batching_dims[i_batching]の場合、full_batching_index[d_operand] = batch_index[d_start - (d_start < index_vector_dim ? 0 : 1)]。- それ以外の場合は
full_batching_index[d_operand] = 0。
offset_index = result_index[offset_dims...]。full_offset_index = [oi0, ..., 0, ..., oiN]。ここで、oiはoffset_indexの個々の要素であり、0はcollapsed_slice_dimsとoperand_batching_dimsのインデックスに挿入されます。operand_index = full_start_index + full_batching_index + full_offset_index。
indices_are_sorted が true の場合、実装は start_indices が start_index_map に関してソートされていると想定できます。それ以外の場合、動作は未定義です。より正式には、indices(result)、full_start_index(i1) <= full_start_index(i2) のすべての i1 < i2 について、
入力
| ラベル | 名前 | タイプ | 制約 |
|---|---|---|---|
| (I1) | operand |
テンソルまたはテンソルごとの量子化テンソル | (C1)、(C8)、(C11)、(C17)、(C19 ~ C21)、(C23) |
| (I2) | start_indices |
整数型のテンソル | (C2-C3)、(C14)、(C17)、(C22) |
| (I3) | offset_dims |
si64 型の 1 次元テンソル定数 |
(C1)、(C4-C5)、(C22) |
| (I4) | collapsed_slice_dims |
si64 型の 1 次元テンソル定数 |
(C1)、(C6 ~ C9)、(C22) |
| (I5) | operand_batching_dims |
si64 型の 1 次元テンソル定数 |
(C1)、(C6)、(C10-C12)、(C16-C18)、(C22) |
| (I6) | start_indices_batching_dims |
si64 型の 1 次元テンソル定数 |
(C13-C17) |
| (I7) | start_index_map |
si64 型の 1 次元テンソル定数 |
(C3)、(C18-C19) |
| (I8) | index_vector_dim |
si64 型の定数 |
(C2-C3)、(C15)、(C22) |
| (I9) | slice_sizes |
si64 型の 1 次元テンソル定数 |
(C9)、(C12)、(C20-C22) |
| (I10) | indices_are_sorted |
i1 型の定数 |
出力
| 名前 | タイプ | 制約 |
|---|---|---|
result |
テンソルまたはテンソルごとの量子化テンソル | (C5)、(C22 ~ C23) |
制約
- (C1)
rank(operand) = size(offset_dims) + size(collapsed_slice_dims) + size(operand_batching_dims)。 - (C2)
0 <= index_vector_dim <= rank(start_indices)。 - (C3)
size(start_index_map) = index_vector_dim < rank(start_indices) ? dim(start_indices, index_vector_dim) : 1。 - (C4)
is_unique(offset_dims) and is_sorted(offset_dims)。 - (C5)
0 <= offset_dims < rank(result)。 - (C6)
is_unique(concatenate(collapsed_slice_dims, operand_batching_dims)) - (C7)
is_sorted(collapsed_slice_dims)。 - (C8)
0 <= collapsed_slice_dims < rank(operand)。 - (C9)
slice_sizes[collapsed_slice_dims...] <= 1。 - (C10)
is_sorted(operand_batching_dims)。 - (C11)
0 <= operand_batching_dims < rank(operand)。 - (C12)
slice_sizes[operand_batching_dims...] <= 1。 - (C13)
is_unique(start_indices_batching_dims)。 - (C14)
0 <= start_indices_batching_dims < rank(start_indices)。 - (C15)
index_vector_dim not in start_indices_batching_dims。 - (C16)
size(operand_batching_dims) == size(start_indices_batching_dims)。 - (C17)
dim(operand, operand_batching_dims...) = dim(start_indices, start_indices_batching_dims...)。 - (C18)
is_unique(concatenate(start_index_map, operand_batching_dims))。 - (C19)
0 <= start_index_map < rank(operand)。 - (C20)
size(slice_sizes) = rank(operand)。 - (C21)
0 <= slice_sizes <= shape(operand)。 - (C22)
shape(result) = combine(batch_dim_sizes, offset_dim_sizes)ここで:batch_dim_sizes = shape(start_indices)。ただし、index_vector_dimに対応するstart_indicesのディメンション サイズは含まれません。offset_dim_sizes = slice_sizes。ただし、collapsed_slice_dimsとoperand_batching_dimsに対応するslice_sizesのディメンション サイズは含まれません。combineは、batch_dim_sizesをbatch_dimsに対応する軸に配置し、offset_dim_sizesをoffset_dimsに対応する軸に配置します。
- (C23)
element_type(operand) = element_type(result)。
例
// %operand: [
// [
// [[1, 2], [3, 4], [5, 6], [7, 8]],
// [[9, 10],[11, 12], [13, 14], [15, 16]],
// [[17, 18], [19, 20], [21, 22], [23, 24]]
// ],
// [
// [[25, 26], [27, 28], [29, 30], [31, 32]],
// [[33, 34], [35, 36], [37, 38], [39, 40]],
// [[41, 42], [43, 44], [45, 46], [47, 48]]
// ]
// ]
// %start_indices: [
// [
// [[0, 0], [1, 0], [2, 1]],
// [[0, 1], [1, 1], [0, 9]]
// ],
// [
// [[0, 0], [2, 1], [2, 2]],
// [[1, 2], [0, 1], [1, 0]]
// ]
// ]
%result = "stablehlo.gather"(%operand, %start_indices) {
dimension_numbers = #stable<hlo.gather
offset_dims = [3, 4],
collapsed_slice_dims = [1],
operand_batching_dims = [0],
start_indices_batching_dims = [1],
start_index_map = [2, 1],
index_vect>or_dim = 3,
slice_siz<es = arrayi64: >1, 1, 2, 2,
indices_are_sorted = false
}< : (tensor2>x3x4x2xi<32, tensor2>x2x>3x2xi64<) - tensor2x2>x3x2x2xi32
// %result: [
// [
// [
// [[1, 2], [3, 4]],
// [[3, 4], [5, 6]],
// [[13, 14], [15, 16]]
// ],
// [
// [[33, 34], [35, 36]],
// [[35, 36], [37, 38]],
// [[41, 42], [43, 44]]
// ]
// ],
// [
// [
// [[1, 2], [3, 4]],
// [[13, 14], [15, 16]],
// [[21, 22], [23, 24]]
// ],
// [
// [[43, 44], [45, 46]],
// [[33, 34], [35, 36]],
// [[27, 28], [29, 30]]
// ]
// ]
// ]
get_dimension_size
セマンティクス
指定された operand の dimension のサイズを生成します。より正式には、result = dim(operand, dimension) です。セマンティクスは、型の形状コンポーネントのみに関係します。要素の型は任意です。
入力
| ラベル | 名前 | タイプ | 制約 |
|---|---|---|---|
| (I1) | operand |
テンソルまたは量子化テンソル | (C1) |
| (I2) | dimension |
si64 型の定数 |
(C1) |
出力
| 名前 | タイプ |
|---|---|
result |
型 si32 の 0 次元テンソル |
制約
- (C1)
0 <= dimension < rank(operand)。
例
// %operand: [[1, 2, 3], [4, 5, 6]]
%result = "stablehlo.get_dimension_size"(%operand) {
dimension = 1 : i64
}< : (ten>sor>2x3xi64<) -> tensori32
// %result: 3
get_tuple_element
セマンティクス
operand タプルの index 位置にある要素を抽出し、result を生成します。より正式には、result = operand[index] です。
入力
| ラベル | 名前 | タイプ | 制約 |
|---|---|---|---|
| (I1) | operand |
tuple | (C1)、(C2) |
| (I2) | index |
si32 型の定数 |
(C1)、(C2) |
出力
| 名前 | タイプ | 制約 |
|---|---|---|
result |
任意の値 | (C2) |
制約
- (C1)
0 <= index < size(operand)。 - (C2)
type(result) = tuple_element_types(operand)[index]。
例
// %operand: ([1.0, 2.0], (3))
%result = "stablehlo.get_tuple_element"(<%operand) {index >= 0 : i32<} : (t<uplet>ensor2x<f64, t<upl>>>ete>nsori64<) - t>ensor2xf64
// %result: [1.0, 2.0]
~なら
セマンティクス
pred の値に応じて、true_branch または false_branch のいずれか 1 つの関数を実行した結果を出力します。より正式には、result =
pred ? true_branch() : false_branch() です。
入力
| ラベル | 名前 | タイプ | 制約 |
|---|---|---|---|
| (I1) | pred |
型 i1 の 0 次元テンソル |
|
| (I2) | true_branch |
関数 | (C1 ~ C3) |
| (I3) | false_branch |
関数 | (C1)、(C2) |
出力
| 名前 | タイプ | 制約 |
|---|---|---|
results |
可変数のテンソル、量子化テンソル、トークン | (C3) |
制約
- (C1)
input_types(true_branch) = input_types(false_branch) = []。 - (C2)
output_types(true_branch) = output_types(false_branch)。 - (C3)
type(results...) = output_types(true_branch)。
例
// %result_true_branch: 10
// %result_false_branch: 11
// %pred: true
%result = "stablehlo.if"(%pred) ({
"stablehlo.return"(%result_tr<ue_>bra>nch) : (tensori32) - ()
}, {
"stablehlo.return"(%<res>ult>_false_branch) :< (>ten>sori32)< - >()
}) : (tensori1) - tensori32
// %result: 10
imag
セマンティクス
operand から要素ごとに虚数部を抽出し、result テンソルを生成します。より正式には、各要素 x について、imag(x) = is_complex(x) ? imaginary_part(x) :
constant(0, element_type(result)) となります。
入力
| ラベル | 名前 | タイプ | 制約 |
|---|---|---|---|
| (I1) | operand |
浮動小数点型または複合型のテンソル | (C1)、(C2) |
出力
| 名前 | タイプ | 制約 |
|---|---|---|
result |
浮動小数点型のテンソル | (C1)、(C2) |
制約
- (C1)
shape(result) = shape(operand)。 - (C2)
element_type(result)は次のように定義されます。is_complex(operand)の場合、complex_element_type(element_type(operand))。- それ以外の場合は
element_type(operand)。
例
// %operand: [(1.0, 2.0), (3.0, 4.0)]
%result = "stablehlo.imag"(%operand)< : (tenso<r2x>>com>plexf32<) - t>ensor2xf32
// %result: [2.0, 4.0]
infeed
セマンティクス
インフィードからデータを読み取り、results を生成します。
infeed_config のセマンティクスは実装で定義されます。
results は、最初にペイロード値、最後にトークンで構成されます。今後、ペイロードとトークンを 2 つの別々の出力に分割して、明確さを向上させる予定です(#670)。
入力
| ラベル | 名前 | タイプ |
|---|---|---|
| (I1) | token |
token |
| (I2) | infeed_config |
string 型の定数 |
出力
| 名前 | タイプ | 制約 |
|---|---|---|
results |
可変数のテンソル、量子化テンソル、トークン | (C1 ~ C3) |
制約
- (C1)
0 < size(results)。 - (C2)
is_empty(result[:-1])またはis_tensor(type(results[:-1]))。 - (C3)
is_token(type(results[-1]))。
例
// %token: !stablehlo.token
// infeed_queue[0]: [[1, 2], [3, 4]]
// infeed_queue[1]: [[5, 6], [7, 8]]
%results0:2 = "stablehlo.infeed"(%token) {
infeed_config = ""
} : >(!stable<hlo.tok>en) - (tensor2x2xi64, !stablehlo.token)
// results0#0: [[1, 2], [3, 4]]
%results1:2 = "stablehlo.infeed"(%token) {
infeed_config> = "<;">
} : (!stablehlo.token) - (tensor2x2xi64, !stablehlo.token)
// results1#0: [[5, 6], [7, 8]]
iota
セマンティクス
iota_dimension ディメンションに沿って 0 から始まる昇順の値で output テンソルを埋めます。より正式には、
output[output_index] = constant(is_quantized(output) ?
quantize(output_index[iota_dimension], element_type(output)) :
output_index[iota_dimension], element_type(output))。
入力
| ラベル | 名前 | タイプ | 制約 |
|---|---|---|---|
| (I1) | iota_dimension |
si64 |
(C1) |
出力
| 名前 | タイプ | 制約 |
|---|---|---|
output |
整数型、浮動小数点型、複素数型のテンソル、またはテンソルごとの量子化テンソル | (C1) |
制約
- (C1)
0 <= iota_dimension < rank(output)。
例
%output = "stablehlo.iota"() {
iota_dimension = 0 : i6>4
} : (<) - ten>sor4x5xi32
// %output: [
// [0, 0, 0, 0, 0],
// [1, 1, 1, 1, 1],
// [2, 2, 2, 2, 2],
// [3, 3, 3, 3, 3]
// ]
%output = "stablehlo.iota"() {
iota_dimensio>n = 1 :< i64
} >: () - tensor4x5xi32
// %output: [
// [0, 1, 2, 3, 4],
// [0, 1, 2, 3, 4],
// [0, 1, 2, 3, 4],
// [0, 1, 2, 3, 4]
// ]
is_finite
セマンティクス
x の値が有限(+Inf、-Inf、NaN のいずれでもない)かどうかを要素ごとにチェックし、y テンソルを生成します。IEEE-754 仕様の isFinite オペレーションを実装します。量子化された型の場合、結果は常に true になります。
入力
| ラベル | 名前 | タイプ | 制約 |
|---|---|---|---|
| (I1) | x |
浮動小数点型のテンソルまたはテンソルごとの量子化テンソル | (C1) |
出力
| 名前 | タイプ | 制約 |
|---|---|---|
y |
ブール型のテンソル | (C1) |
制約
- (C1)
shape(x) = shape(y)。
例
// Logical values: -Inf, +Inf, NaN, ...
// %x: [0xFFF0000000000000, 0x7FF0000000000000, 0x7FF8000000000000, -10.0, -0.0, 0.0, 10.0]
%y = "stablehlo.is_finite"(%x)< : (tens>or7xf64<) - >tensor7xi1
// %y: [false, false, false, true, true, true, true]
log
セマンティクス
operand テンソルに対して要素ごとの対数演算を行い、result テンソルを生成します。要素のタイプに応じて、次の処理を行います。
- 浮動小数点の場合: IEEE-754 の
log。 - 複素数の場合: 複素対数。
- 量子化された型の場合:
dequantize_op_quantize(log, operand, type(result))。
入力
| ラベル | 名前 | タイプ | 制約 |
|---|---|---|---|
| (I1) | operand |
浮動小数点型または複素数型のテンソル、またはテンソルごとの量子化テンソル | (C1) |
出力
| 名前 | タイプ | 制約 |
|---|---|---|
result |
浮動小数点型または複素数型のテンソル、またはテンソルごとの量子化テンソル | (C1) |
制約
- (C1)
baseline_type(operand) = baseline_type(result)。
例
// %operand: [[1.0, 2.0], [3.0, 4.0]]
%result = "stablehlo.log"(%operand)< : (ten>sor>2x2xf64<) - ten>sor2x2xf64
// %result: [[0.0, 0.69314718055994529], [1.0986122886681098, 1.3862943611198906]]
log_plus_one
セマンティクス
operand テンソルに対して要素ごとの対数プラス 1 演算を行い、result テンソルを生成します。要素のタイプに応じて、次の処理を行います。
- 浮動小数点の場合: IEEE-754 の
logp1。 - 複素数:
complex(log(hypot(real(x) + 1, imag(x))), atan2(imag(x), real(x) + 1)) - 量子化された型の場合:
dequantize_op_quantize(log_plus_one, operand, type(result))。
入力
| ラベル | 名前 | タイプ | 制約 |
|---|---|---|---|
| (I1) | operand |
浮動小数点型または複素数型のテンソル、またはテンソルごとの量子化テンソル | (C1) |
出力
| 名前 | タイプ | 制約 |
|---|---|---|
result |
浮動小数点型または複素数型のテンソル、またはテンソルごとの量子化テンソル | (C1) |
制約
- (C1)
baseline_type(operand) = baseline_type(result)。
例
// %operand: [0.0, -0.999, 7.0, 6.38905621, 15.0]
%result = "stablehlo.log_plus_one"(%operand)< : (t>ens>or5xf64<) - t>ensor5xf64
// %result: [0.0, -6.90776825, 2.07944155, 2.0, 2.77258873]
ロジスティック
セマンティクス
operand テンソルに対して要素ごとのロジスティック演算を行い、result テンソルを生成します。要素のタイプに応じて、次の処理を行います。
- 浮動小数点の場合: IEEE-754 の
division(1, addition(1, exp(-x)))。 - 複素数の場合: 複素ロジスティック。
- 量子化された型の場合:
dequantize_op_quantize(logistic, operand, type(result))。
入力
| ラベル | 名前 | タイプ | 制約 |
|---|---|---|---|
| (I1) | operand |
浮動小数点型または複素数型のテンソル、またはテンソルごとの量子化テンソル | (C1) |
出力
| 名前 | タイプ | 制約 |
|---|---|---|
result |
浮動小数点型または複素数型のテンソル、またはテンソルごとの量子化テンソル | (C1) |
制約
- (C1)
baseline_type(operand) = baseline_type(result)。
例
// %operand: [[0.0, 1.0], [2.0, 3.0]]
%result = "stablehlo.logistic"(%operand)< : (ten>sor>2x2xf64<) - ten>sor2x2xf64
// %result: [[0.5, 0.73105858], [0.88079708, 0.95257413]]
地図
セマンティクス
マップ関数 computation を dimensions に沿って inputs に適用し、result テンソルを生成します。
より正式には、result[result_index] = computation(inputs...[result_index])。
入力
| ラベル | 名前 | タイプ | 制約 |
|---|---|---|---|
| (I1) | inputs |
可変数のテンソルまたはテンソルごとの量子化テンソル | (C1 ~ C4) |
| (I2) | dimensions |
si64 型の 1 次元テンソル定数 |
(C3) |
| (I3) | computation |
関数 | (C4) |
出力
| 名前 | タイプ | 制約 |
|---|---|---|
result |
テンソルまたはテンソルごとの量子化テンソル | (C1)、(C4) |
制約
- (C1)
shape(inputs...) = shape(result)。 - (C2)
0 < size(inputs) = N。 - (C3)
dimensions = range(rank(inputs[0]))。 - (C4)
computationの型は(tensor<E0>, ..., tensor<EN-1>) -> tensor<E'>で、Ei = element_type(inputs[i])とE' = element_type(result)です。
例
// %input0: [[0, 1], [2, 3]]
// %input1: [[4, 5], [6, 7]]
%result = "stablehlo.map"(%input0, %input1) ({
^bb0(%ar<g0:> tensori64, %ar<g1:> tensori64):
%0 = stablehlo.multiply %arg0, %arg<1 :> tensori64
stablehlo.return %<0 :> tensori64
}) {
dimensio<ns = arra>yi64: 0, 1
}< : (ten>sor2x2xi<64, ten>sor>2x2xi64<) - ten>sor2x2xi64
// %result: [[0, 5], [12, 21]]
最大
セマンティクス
テンソル lhs と rhs に対して要素ごとの最大値演算を実行し、result テンソルを生成します。要素のタイプに応じて、次の処理を行います。
- ブール値の場合: 論理 OR。
- 整数の場合: 整数の最大値。
- 浮動小数点の場合: IEEE-754 の
maximum。 - 複素数の場合:
(real, imaginary)ペアの辞書式順序の最大値。複素数に順序を課すには驚くべきセマンティクスが伴うため、今後、このオペレーションの複素数のサポートを削除する予定です(#560)。 - 量子化された型の場合:
dequantize_op_quantize(maximum, lhs, rhs, type(result))。
入力
| ラベル | 名前 | タイプ | 制約 |
|---|---|---|---|
| (I1) | lhs |
テンソルまたはテンソルごとの量子化テンソル | (C1) |
| (I2) | rhs |
テンソルまたはテンソルごとの量子化テンソル | (C1) |
出力
| 名前 | タイプ | 制約 |
|---|---|---|
result |
テンソルまたはテンソルごとの量子化テンソル | (C1) |
制約
- (C1)
baseline_type(lhs) = baseline_type(rhs) = baseline_type(result)。
例
// %lhs: [[1, 2], [7, 8]]
// %rhs: [[5, 6], [3, 4]]
%result = "stablehlo.maximum"(%lhs, %rhs)< : (ten>sor2x2xi<32, ten>sor>2x2xi32<) - ten>sor2x2xi32
// %result: [[5, 6], [7, 8]]
最小
セマンティクス
テンソル lhs と rhs に対して要素ごとの最小値演算を実行し、result テンソルを生成します。要素のタイプに応じて、次の処理を行います。
- ブール値の場合: 論理 AND。
- 整数の場合: 最小整数。
- 浮動小数点の場合: IEEE-754 の
minimum。 - 複素数の場合:
(real, imaginary)ペアの辞書式順序の最小値。複素数に順序を課すには驚くべきセマンティクスが伴うため、今後、このオペレーションの複素数のサポートを削除する予定です(#560)。 - 量子化された型の場合:
dequantize_op_quantize(minimum, lhs, rhs, type(result))。
入力
| ラベル | 名前 | タイプ | 制約 |
|---|---|---|---|
| (I1) | lhs |
テンソルまたはテンソルごとの量子化テンソル | (C1) |
| (I2) | rhs |
テンソルまたはテンソルごとの量子化テンソル | (C1) |
出力
| 名前 | タイプ | 制約 |
|---|---|---|
result |
テンソルまたはテンソルごとの量子化テンソル | (C1) |
制約
- (C1)
baseline_type(lhs) = baseline_type(rhs) = baseline_type(result)。
例
// %lhs: [[1, 2], [7, 8]]
// %rhs: [[5, 6], [3, 4]]
%result = "stablehlo.minimum"(%lhs, %rhs)< : (ten>sor2x2xi<32, ten>sor>2x2xi32<) - ten>sor2x2xi32
// %result: [[1, 2], [3, 4]]
乗算
セマンティクス
2 つのテンソル lhs と rhs の要素ごとの積を計算し、result テンソルを生成します。要素のタイプに応じて、次の処理を行います。
- ブール値の場合: 論理 AND。
- 整数の場合: 整数の乗算。
- 浮動小数点の場合: IEEE-754 の
multiplication。 - 複素数の場合: 複素数の乗算。
- 量子化された型の場合:
dequantize_op_quantize(multiply, lhs, rhs, type(result))。
入力
| ラベル | 名前 | タイプ | 制約 |
|---|---|---|---|
| (I1) | lhs |
テンソルまたはテンソルごとの量子化テンソル | (C1) |
| (I2) | rhs |
テンソルまたはテンソルごとの量子化テンソル | (C1) |
出力
| 名前 | タイプ | 制約 |
|---|---|---|
result |
テンソルまたはテンソルごとの量子化テンソル | (C1) |
制約
- (C1)
baseline_type(operand) = baseline_type(result)。
例
// %lhs: [[1, 2], [3, 4]]
// %rhs: [[5, 6], [7, 8]]
%result = "stablehlo.multiply"(%lhs, %rhs)< : (ten>sor2x2xi<32, ten>sor>2x2xi32<) - ten>sor2x2xi32
// %result: [[5, 12], [21, 32]]
negate
セマンティクス
operand テンソルの要素ごとの否定を実行し、result テンソルを生成します。要素のタイプに応じて、次の処理を行います。
- 符号付き整数の場合: 整数の否定。
- 符号なし整数の場合: 符号付き整数へのビットキャスト、整数の否定、符号なし整数へのビットキャスト。
- 浮動小数点の場合: IEEE-754 の
negate。 - 複素数の場合: 複素数の否定。
- 量子化された型の場合:
dequantize_op_quantize(negate, operand, type(result))。
入力
| ラベル | 名前 | タイプ | 制約 |
|---|---|---|---|
| (I1) | operand |
整数型、浮動小数点型、複素数型のテンソル、またはテンソルごとの量子化テンソル | (C1) |
出力
| 名前 | タイプ | 制約 |
|---|---|---|
result |
整数型、浮動小数点型、複素数型のテンソル、またはテンソルごとの量子化テンソル | (C1) |
制約
- (C1)
baseline_type(operand) = baseline_type(result)。
例
// Negation operation with integer Tensors
// %operand: [0, -2]
%result = "stablehlo.negate"(%operand)< : (t>ens>or2xi32<) - t>ensor2xi32
// %result: [0, 2]
// Negation operation with with complex tensors
// %operand: (2.5, 0.0)
%result = "stablehlo.negate"<(%operand<) :>> (t>ensor1x<complexf3<2) >>- tensor1xcomplexf32
// %result: [-2.5, -0.0]
いない
セマンティクス
テンソル operand の要素ごとの NOT を実行し、result テンソルを生成します。要素のタイプに応じて、次の処理を行います。
- ブール値の場合: 論理 NOT。
- 整数の場合: ビット演算 NOT。
引数
| 名前 | タイプ | 制約 |
|---|---|---|
operand |
ブール値または整数型のテンソル | (C1) |
出力
| 名前 | タイプ | 制約 |
|---|---|---|
result |
ブール値または整数型のテンソル | (C1) |
制約
- (C1)
type(operand) = type(result)。
例
// Bitwise operation with with integer tensors
// %operand: [[1, 2], [3, 4]]
%result = "stablehlo.not"(%operand)< : (ten>sor>2x2xi32<) - ten>sor2x2xi32
// %result: [[-2, -3], [-4, -5]]
// Bitwise operation with with boolean tensors
// %operand: [true, false]
%result = "stablehlo.not"<(%op>era>nd) : (<tens>or2xi1) - tensor2xi1
// %result: [false, true]
optimization_barrier
セマンティクス
operand を生成するオペレーションが result に依存するオペレーションの前に実行されるようにし、コンパイラ変換によってオペレーションがバリアを越えて移動されるのを防ぎます。それ以外の場合、オペレーションは ID(result = operand)です。
引数
| 名前 | タイプ | 制約 |
|---|---|---|
operand |
可変数のテンソル、テンソルごとの量子化テンソルまたはトークン | (C1) |
出力
| 名前 | タイプ | 制約 |
|---|---|---|
result |
可変数のテンソル、テンソルごとの量子化テンソルまたはトークン | (C1) |
制約
- (C1)
type(operand...) = type(result...)。
例
// %operand0: 0.0
// %operand1: 1.0
%result0, %result1 = "stablehlo.optimization_barrier"(%operand0, %operand1)< : >(tensorf<32,> te>nsorf32)< - >(tensorf<32,> tensorf32)
// %result0: 0.0
// %result1: 1.0
または
セマンティクス
2 つのテンソル lhs と rhs の要素ごとの OR を実行し、result テンソルを生成します。要素のタイプに応じて、次の処理を行います。
- ブール値の場合: 論理 OR。
- 整数の場合: ビット単位の OR。
入力
| ラベル | 名前 | タイプ | 制約 |
|---|---|---|---|
| (I1) | lhs |
整数型またはブール型のテンソル | (C1) |
| (I2) | rhs |
整数型またはブール型のテンソル | (C1) |
出力
| 名前 | タイプ | 制約 |
|---|---|---|
result |
整数型またはブール型のテンソル | (C1) |
制約
- (C1)
type(lhs) = type(rhs) = type(result)。
例
// Bitwise operation with with integer tensors
// %lhs: [[1, 2], [3, 4]]
// %rhs: [[5, 6], [7, 8]]
%result = "stablehlo.or"(%lhs, %rhs)< : (ten>sor2x2xi<32, ten>sor>2x2xi32<) - ten>sor2x2xi32
// %result: [[5, 6], [7, 12]]
// Logical operation with with boolean tensors
// %lhs: [[false, false], [true, true]]
// %rhs: [[false, true], [false, true]]
%result = "stablehlo.or"(%<lhs, %>rhs) : (<tensor>2x2>xi1, te<nsor2x>2xi1) - tensor2x2xi1
// %result: [[false, true], [true, true]]
アウトフィード
セマンティクス
inputs をアウトフィードに書き込み、result トークンを生成します。
outfeed_config のセマンティクスは実装で定義されます。
入力
| ラベル | 名前 | タイプ |
|---|---|---|
| (I1) | inputs |
可変数のテンソルまたは量子化テンソル |
| (I2) | token |
token |
| (I3) | outfeed_config |
string 型の定数 |
出力
| 名前 | タイプ |
|---|---|
result |
token |
例
%result = "stablehlo.outfeed"(%input0, %token) {
outfeed_config = &quo<t;"<>/span>
} : (tensor2x2x2xi64,> !stablehlo.token) - !stablehlo.token
pad
セマンティクス
指定された padding_value を使用して、テンソルの周囲とテンソルの要素間のパディングで operand を拡張します。
edge_padding_low と edge_padding_high は、各ディメンションのローエンド(インデックス 0 の隣)とハイエンド(最も高いインデックスの隣)に追加されるパディングの量をそれぞれ指定します。パディングの量は負の値にできます。負のパディングの絶対値は、指定されたディメンションから削除する要素の数を示します。
interior_padding は、各ディメンションの 2 つの要素間に追加されるパディングの量を指定します。負の値は指定できません。内側のパディングはエッジ パディングの前に発生するため、負のエッジ パディングは内側のパディングが適用されたオペランドから要素を削除します。
より正式には、result[result_index] は次のように定義されます。
result_index = edge_padding_low + operand_index * (interior_padding + 1)の場合、operand[operand_index]。- それ以外の場合は
padding_value。
入力
| ラベル | 名前 | タイプ | 制約 |
|---|---|---|---|
| (I1) | operand |
テンソルまたはテンソルごとの量子化テンソル | (C1)、(C2)、(C4) |
| (I2) | padding_value |
0 次元テンソルまたはテンソルごとの量子化テンソル | (C1) |
| (I3) | edge_padding_low |
si64 型の 1 次元テンソル定数 |
(C1)、(C4) |
| (I4) | edge_padding_high |
si64 型の 1 次元テンソル定数 |
(C1)、(C4) |
| (I5) | interior_padding |
si64 型の 1 次元テンソル定数 |
(C2-C4) |
出力
| 名前 | タイプ | 制約 |
|---|---|---|
result |
テンソルまたはテンソルごとの量子化テンソル | (C3-C6) |
制約
- (C1)
element_type(operand) = element_type(padding_value) = element_type(result)。 - (C2)
size(edge_padding_low) = size(edge_padding_high) = size(interior_padding) = rank(operand)。 - (C3)
0 <= interior_padding。 - (C4)
shape(result) = shape(operand) + edge_padding_low + max(shape(operand) - 1, 0) * interior_padding + edge_padding_high。
例
// %operand: [
// [1, 2, 3],
// [4, 5, 6]
// ]
// %padding_value: 0
%result = "stablehlo.pad"(%operand, %padding_value) {
edge_padding_l<ow = arra>yi64: 0, 1,
edge_padding_hi<gh = arra>yi64: 2, 1,
interior_paddi<ng = arra>yi64: 1, 2
}< : (ten>sor2x3xi<32,> te>nsori32<) - ten>sor5x9xi32
// %result: [
// [0, 1, 0, 0, 2, 0, 0, 3, 0],
// [0, 0, 0, 0, 0, 0, 0, 0, 0],
// [0, 4, 0, 0, 5, 0, 0, 6, 0],
// [0, 0, 0, 0, 0, 0, 0, 0, 0],
// [0, 0, 0, 0, 0, 0, 0, 0, 0]
// ]
partition_id
セマンティクス
現在のプロセスの partition_id を生成します。
出力
| 名前 | タイプ |
|---|---|
result |
型 ui32 の 0 次元テンソル |
例
%result = "stablehlo.partition_id">;() : (<) - >tensorui32
popcnt
セマンティクス
operand テンソルに設定されているビット数を要素ごとにカウントし、result テンソルを生成します。
入力
| ラベル | 名前 | タイプ | 制約 |
|---|---|---|---|
| (I1) | operand |
整数型のテンソル | (C1) |
出力
| 名前 | タイプ | 制約 |
|---|---|---|
result |
整数型のテンソル | (C1) |
制約
- (C1)
type(operand) = type(result)。
例
// %operand: [0, 1, 2, 127]
%result = "stablehlo.popcnt"(%operand)< : (t>ens>or4xi64<) - t>ensor4xi64
// %result: [0, 1, 1, 7]
電力
セマンティクス
lhs テンソルと rhs テンソルの要素ごとのべき乗を計算し、result テンソルを生成します。要素のタイプに応じて、次の処理を行います。
- 整数の場合: 整数のべき乗。
- 浮動小数点の場合: IEEE-754 の
pow。 - 複素数の場合: 複素数のべき乗。
- 量子化された型の場合:
dequantize_op_quantize(power, lhs, rhs, type(result))。
入力
| ラベル | 名前 | タイプ | 制約 |
|---|---|---|---|
| (I1) | lhs |
整数型、浮動小数点型、複素数型のテンソル、またはテンソルごとの量子化テンソル | (C1) |
| (I2) | rhs |
整数型、浮動小数点型、複素数型のテンソル、またはテンソルごとの量子化テンソル | (C1) |
出力
| 名前 | タイプ | 制約 |
|---|---|---|
result |
整数型、浮動小数点型、複素数型のテンソル、またはテンソルごとの量子化テンソル | (C1) |
制約
- (C1)
baseline_type(operand) = baseline_type(result)。
例
// %lhs: [-2.0, -0.0, -36.0, 5.0, 3.0, 10000.0]
// %rhs: [2.0, 2.0, 1.1, 2.0, -1.0, 10.0]
%result = "stablehlo.power"(%lhs, %rhs)< : (t>ensor6xf<64, t>ens>or6xf64<) - t>ensor6xf64
// %result: [4.0, 0.0, -nan, 25.0, 0.333333343, inf]
real
セマンティクス
operand から要素ごとに実部を抽出し、result テンソルを生成します。より正式には、各要素 x について、real(x) = is_complex(x) ? real_part(x) : x となります。
入力
| ラベル | 名前 | タイプ | 制約 |
|---|---|---|---|
| (I1) | operand |
浮動小数点型または複合型のテンソル | (C1)、(C2) |
出力
| 名前 | タイプ | 制約 |
|---|---|---|
result |
浮動小数点型のテンソル | (C1)、(C2) |
制約
- (C1)
shape(result) = shape(operand)。 - (C2)
element_type(result)は次のように定義されます。is_complex(operand)の場合、complex_element_type(element_type(operand))。- それ以外の場合は
element_type(operand)。
例
// %operand: [(1.0, 2.0), (3.0, 4.0)]
%result = "stablehlo.real"(%operand)< : (tenso<r2x>>com>plexf32<) - t>ensor2xf32
// %result: [1.0, 3.0]
受信
セマンティクス
channel_id を含むチャネルからデータを受信し、results を生成します。
is_host_transfer が true の場合、オペレーションはホストからデータを転送します。それ以外の場合は、source_target_pairs の値に基づいて別のデバイスからデータを転送します。このフラグは channel_type で提供される情報を複製するため、将来的にはいずれか 1 つのみを残す予定です(#666)。is_host_transfer = false で、source_target_pairs が None または空の場合、動作は未定義とみなされます。
results は、最初にペイロード値、最後にトークンで構成されます。今後、ペイロードとトークンを 2 つの別々の出力に分割して、明確さを向上させる予定です(#670)。
入力
| ラベル | 名前 | タイプ | 制約 |
|---|---|---|---|
| (I1) | token |
token |
|
| (I2) | channel_id |
si64 型の定数 |
|
| (I3) | channel_type |
DEVICE_TO_DEVICE と DEVICE_TO_HOST の列挙型 |
(C5) |
| (I4) | is_host_transfer |
i1 型の定数 |
(C5-C6) |
| (I5) | source_target_pairs |
型 si64 の 2 次元テンソル定数 |
(C1-C4)、(C6) |
出力
| 名前 | タイプ | 制約 |
|---|---|---|
results |
可変数のテンソル、量子化テンソル、トークン | (C2-C4) |
制約
- (C1)
dim(source_target_pairs, 1) = 2。 - (C2)
is_unique(source_target_pairs[:, 0])。 - (C3)
is_unique(source_target_pairs[:, 1])。 - (C4)
0 <= source_target_pairs < N。ここで、Nは次のように定義されます。cross_replicaを使用する場合はnum_replicas。cross_partitionを使用する場合はnum_partitions。
- (C5)
channel_typeは次のように定義されます。is_host_transfer = trueの場合、DEVICE_TO_HOST- それ以外の場合は
DEVICE_TO_DEVICE。
例
%results0, %results1 = "stablehlo.recv"(%token) {
channel_handle = #stablehlo.chan<nel_handlehandle = 0>, type = 1,
is_host_transfer = false,
source_target_pai<rs = dense[[0, 1>], [1, 2]<] : ten>sor2x2xi64
} : (!stablehl>o.token)< - (ten>sor2x2xi64, !stablehlo.token)
reduce
セマンティクス
リダクション関数 body を dimensions に沿って inputs と init_values に適用し、results テンソルを生成します。
削減の順序は実装定義です。つまり、body と init_values は、すべての実装のすべての入力に対して同じ結果が生成されることを保証するために、モノイドを形成する必要があります。ただし、この条件は多くの一般的な削減には当てはまりません。たとえば、body の浮動小数点加算と init_values のゼロは、浮動小数点加算が結合的ではないため、実際にはモノイドを形成しません。
より正式には、results...[j0, ..., jR-1] = reduce(input_slices_converted) です。ここで:
input_slices = inputs...[j0, ..., :, ..., jR-1]: ここで、:はdimensionsに挿入されます。input_slices_converted = to_destination_type(input_slices..., type(func_inputs(body)[:len(func_inputs(body))//2])...)。init_values_converted = to_destination_type(init_values..., type(func_inputs(body)[len(func_inputs(body))//2:])...)。reduce(input_slices_converted) = exec(schedule)(二分木schedule)ここで、exec(node) = body(exec(node.left), exec(node.right))。exec(leaf) = leaf.value。
scheduleは実装定義の完全なバイナリツリーであり、そのインオーダー トラバーサルは次のようになります。indexの辞書順の昇順でindex_space(input_slices_converted)のすべてのindexのinput_slices_converted...[index]値。- 実装定義の位置に実装定義の量の
init_values_convertedが散在しています。
入力
| ラベル | 名前 | タイプ | 制約 |
|---|---|---|---|
| (I1) | inputs |
可変数のテンソルまたはテンソルごとの量子化テンソル | (C1 ~ C4)、(C6)、(C7) |
| (I2) | init_values |
0 次元テンソルまたはテンソルごとの量子化テンソルの可変個数 | (C2)、(C3) |
| (I3) | dimensions |
si64 型の 1 次元テンソル定数 |
(C4)、(C5)、(C7) |
| (I4) | body |
関数 | (C6) |
出力
| 名前 | タイプ | 制約 |
|---|---|---|
results |
可変数のテンソルまたはテンソルごとの量子化テンソル | (C3)、(C7)、(C8) |
制約
- (C1)
same(shape(inputs...))。 - (C2)
element_type(inputs...) = element_type(init_values...)。 - (C3)
0 < size(inputs) = size(init_values) = size(results) = N。 - (C4)
0 <= dimensions < rank(inputs[0])。 - (C5)
is_unique(dimensions)。 - (C6)
bodyの型は(tensor<E0>, ..., tensor<EN-1>, tensor<E0>, ...,tensor<EN-1>) -> (tensor<E0>, ..., tensor<EN-1>)で、is_promotable(element_type(inputs[i]), Ei)です。 - (C7)
shape(results...) = shape(inputs...)。ただし、dimensionsに対応するinputs...のディメンション サイズは含まれません。 - (C8)
[0,N)内のすべてのiについてelement_type(results[i]) = Ei。
例
// %input = [[0, 1, 2, 3, 4, 5]]
// %init_value = 0
%result = "stablehlo.reduce"(%input, %init_value) ({
^bb0(%ar<g0:> tensori64, %ar<g1:> tensori64):
%0 = "stablehlo.add"(%ar<g0,> %arg1) <: (>ten>sori64,< te>nsori64) - tensori64
"stable<hlo>.re>turn"(%0) : (tensori64) <- ()
}>) {
dimens<ions = >arrayi64<: 1>
} >: (tens<or1x6>xi64, tensori64) - tensor1xi64
// %result = [15]
reduce_precision
セマンティクス
operand の要素ごとの変換を、exponent_bits と mantissa_bits を使用する別の浮動小数点型に対して実行し、元の浮動小数点型に戻して output テンソルを生成します。
より正式には:
- 元の値の仮数部が更新され、
roundToIntegralTiesToEvenセマンティクスを使用して、元の値がmantissa_bitsで表現可能な最も近い値に丸められます。 - 次に、
mantissa_bitsが元の値の仮数部のビット数より小さい場合、仮数部のビットはmantissa_bitsに切り捨てられます。 - 次に、中間結果の指数ビットが
exponent_bitsで指定された範囲に収まらない場合、中間結果は元の符号を使用して無限大にオーバーフローするか、元の符号を使用してゼロにアンダーフローします。 - 量子化された型の場合、
dequantize_op_quantize( lambda operand: reduce_precision(operand, exponent_bits, mantissa_bits), operand, type(result))を実行します。
入力
| ラベル | 名前 | タイプ | 制約 |
|---|---|---|---|
| (I1) | operand |
浮動小数点型のテンソルまたはテンソルごとの量子化テンソル | (C1) |
| (I2) | exponent_bits |
si32 型の定数 |
(C2) |
| (I3) | mantissa_bits |
si32 型の定数 |
(C3) |
出力
| 名前 | タイプ | 制約 |
|---|---|---|
output |
浮動小数点型のテンソルまたはテンソルごとの量子化テンソル | (C1) |
制約
- (C1)
baseline_type(operand) = baseline_type(output)。 - (C2)
1 <= exponent_bits。 - (C3)
0 <= mantissa_bits。
例
// Logical values: +Inf, NaN, +Denormal, 0.0, 65519.0, 65520.0
// %operand: [0x7FF0000000000000, 0x7FFFFFFFFFFFFFFF, 0x0000000000000001, 0.0, 65519.0, 65520.0]
%output = "stablehlo.reduce_precision"(%operand) {
exponent_bits = 5 : i32,
mantissa_bits = 10 : i32
}< : (t>ens>or6xf64<) - t>ensor6xf64
// Logical values: +Inf, NaN, 0.0, 0.0, 65504.0, +Inf
// %output: [0x7FF0000000000000, 0x7FFFFFFFFFFFFFFF, 0.0, 0.0, 65504.0, 0x7FF0000000000000]
reduce_scatter
セマンティクス
StableHLO プロセス グリッドの各プロセス グループ内で、各プロセスの operand テンソルの値に対して computations を使用して削減を実行し、削減結果を scatter_dimension に沿って分割して、分割された部分をプロセス間で分散して result を生成します。
このオペレーションは、StableHLO プロセス グリッドを process_groups に分割します。これは次のように定義されます。
channel_id <= 0 and use_global_device_ids = falseの場合、cross_replica(replica_groups)。channel_id > 0 and use_global_device_ids = falseの場合、cross_replica_and_partition(replica_groups)。channel_id > 0 and use_global_device_ids = trueの場合、flattened_ids(replica_groups)。
その後、各 process_group 内で次の操作を行います。
reduced_value = all_reduce(operand, replica_groups, channel_id, use_global_device_ids, computation)。parts@sender = split(reduced_value@sender, dim(process_groups, 1), scatter_dimension)。process_group内のすべてのsenderについてresult@receiver = parts@sender[receiver_index]。ここで、receiver_index = process_group.index(receiver)。
入力
| ラベル | 名前 | タイプ | 制約 |
|---|---|---|---|
| (I1) | operand |
テンソルまたはテンソルごとの量子化テンソル | (C1)、(C2)、(C7)、(C8) |
| (I2) | scatter_dimension |
si64 型の定数 |
(C1)、(C2)、(C8) |
| (I3) | replica_groups |
型 si64 の 2 次元テンソル定数 |
(C3 ~ C5) |
| (I4) | channel_id |
si64 型の定数 |
(C6) |
| (I5) | use_global_device_ids |
i1 型の定数 |
(C6) |
| (I6) | computation |
関数 | (C7) |
出力
| 名前 | タイプ | 制約 |
|---|---|---|
result |
テンソルまたはテンソルごとの量子化テンソル | (C8-C9) |
制約
- (C1)
dim(operand, scatter_dimension) % dim(process_groups, 1) = 0。 - (C2)
0 <= scatter_dimension < rank(operand)。 - (C3)
is_unique(replica_groups)。 - (C4)
size(replica_groups)は次のように定義されます。cross_replicaを使用する場合はnum_replicas。cross_replica_and_partitionを使用する場合はnum_replicas。flattened_idsを使用する場合はnum_processes。
- (C5)
0 <= replica_groups < size(replica_groups)。 - (C6)
use_global_device_ids = trueの場合はchannel_id > 0。 - (C7)
computationは(tensor<E>, tensor<E>) -> (tensor<E>)型で、is_promotable(element_type(operand), E)です。 - (C8)
shape(result) = shape(operand)(ただし、dim(result, scatter_dimension) = dim(operand, scatter_dimension) / dim(process_groups, 1)。
- (C9)
element_type(result) = E。
例
// num_replicas: 2
// num_partitions: 1
// %operand@(0, 0): [[1, 2, 3, 4],
// [5, 6, 7, 8]]
// %operand@(1, 0): [[9, 10, 11, 12],
// [13, 14, 15, 16]]
%result = "stablehlo.reduce_scatter"(%operand) ({
^bb0(%ar<g0:> tensori64, %ar<g1:> tensori64):
%0 = "stablehlo.add"(%ar<g0,> %arg1) <: (>ten>sori64,< te>nsori64) - tensori64
"stable<hlo>.re>turn"(%0) : (tensori64) - ()
}) {
scatter_dimension = 1 :< i64,
>replica_g<roups => dense[[0, 1]] : tensor1x2xi64,
channel_hand<le = #stablehlo.chan>nel_handleha<ndle = >0, >type = <0
} : (>tensor2x4xi64) - tensor2x2xi64
//
// %result@(0, 0): [[10, 12],
// [18, 20]]
// %result@(1, 0): [[14, 16],
// [22, 24]]
reduce_window
セマンティクス
削減関数 body を inputs と init_values のウィンドウに適用し、results を生成します。
次の図は、具体的な例を使用して、results... の要素が inputs... からどのように計算されるかを示しています。
より正式には、results...[result_index] = reduce(windows, init_values, axes(inputs...), body)(reduce を参照)。ここで、
padded_inputs = pad(inputs..., init_values..., padding[:, 0], padding[:, 1], base_dilations - 1)。window_start = result_index * window_strideswindow_end = window_start + (window_dimensions - 1) * window_dilations + 1windows = slice(padded_inputs..., window_start, window_end, window_dilations)。
入力
| ラベル | 名前 | タイプ | 制約 |
|---|---|---|---|
| (I1) | inputs |
可変数のテンソルまたはテンソルごとの量子化テンソル | (C1-C4)、(C6)、(C8)、(C10)、(C12)、(C13)、(C15) |
| (I2) | init_values |
0 次元テンソルまたはテンソルごとの量子化テンソルの可変個数 | (C1)、(C13) |
| (I3) | window_dimensions |
si64 型の 1 次元テンソル定数 |
(C4)、(C5)、(C15) |
| (I4) | window_strides |
si64 型の 1 次元テンソル定数 |
(C6)、(C7)、(C15) |
| (I5) | base_dilations |
si64 型の 1 次元テンソル定数 |
(C8)、(C9)、(C15) |
| (I6) | window_dilations |
si64 型の 1 次元テンソル定数 |
(C10)、(C11)、(C15) |
| (I7) | padding |
型 si64 の 2 次元テンソル定数 |
(C12)、(C15) |
| (I8) | body |
関数 | (C13) |
出力
| 名前 | タイプ | 制約 |
|---|---|---|
results |
可変数のテンソルまたはテンソルごとの量子化テンソル | (C1)、(C14 ~ C16) |
制約
- (C1)
0 < size(inputs) = size(init_values) = size(results) = N。 - (C2)
same(shape(inputs...))。 - (C3)
element_type(inputs...) = element_type(init_values...)。 - (C4)
size(window_dimensions) = rank(inputs[0])。 - (C5)
0 < window_dimensions。 - (C6)
size(window_strides) = rank(inputs[0])。 - (C7)
0 < window_strides。 - (C8)
size(base_dilations) = rank(inputs[0])。 - (C9)
0 < base_dilations。 - (C10)
size(window_dilations) = rank(inputs[0])。 - (C11)
0 < window_dilations。 - (C12)
shape(padding) = [rank(inputs[0]), 2]。 - (C13)
bodyの型は(tensor<E0>, ..., tensor<EN-1>, tensor<E0>, ...,tensor<EN-1>) -> (tensor<E0>, ..., tensor<EN-1>)で、is_promotable(element_type(inputs[i]), Ei)。 - (C14)
same(shape(results...))。 - (C15)
shape(results[0]) = num_windowsここで、dilated_input_shape = shape(inputs[0]) = 0 ? 0 : (shape(inputs[0]) - 1) * base_dilations + 1。padded_input_shape = padding[:, 0] + dilated_input_shape + padding[:, 1]dilated_window_shape = (window_dimensions - 1) * window_dilations + 1is_empty_window = padded_input_shape = 0 || dilated_window_shape > padded_input_shapenum_windows = is_empty_window ? 0 : floor((padded_input_shape - dilated_window_shape) / window_strides) + 1。
- (C16)
[0,N)内のすべてのiのelement_type(results[i]) = Ei。
例
// %input = [[1, 2], [3, 4], [5, 6]]
// %init_value = 0
%result = "stablehlo.reduce_window"(%input, %init_value) ({
^bb0(%ar<g0:> tensori64, %ar<g1:> tensori64):
%0 = "stablehlo.add"(%ar<g0,> %arg1) <: (>ten>sori64,< te>nsori64) - tensori64
"stable<hlo>.re>turn"(%0) : (tensori64) - ()
})< {
wind>ow_dimensions = arrayi64: <2, 1,
w>indow_strides = arrayi64: <4, 1,
b>ase_dilations = arrayi64: 2,< 1,
win>dow_dilations = arr<ayi64: 3, 1,
p>adding = <dense[[>2, 1], [0, 0<]] : te>nsor2x2x<i64>
} >: (tens<or3x2xi>64, tensori64) - tensor2x2xi64
// %result = [[0, 0], [3, 4]]
余り
セマンティクス
被除数 lhs テンソルと除数 rhs テンソルの要素ごとの剰余を計算し、result テンソルを生成します。
より正式には、結果の符号は被除数から取得され、結果の絶対値は常に除数の絶対値よりも小さくなります。剰余は lhs - d * rhs として計算されます。ここで、d は次のように求められます。
- 整数の場合:
stablehlo.divide(lhs, rhs)。 - 浮動小数点の場合: 丸め属性
roundTowardZeroを持つ IEEE-754 のdivision(lhs, rhs)。 - 複素数: 未定(#997)。
- 量子化された型の場合:
dequantize_op_quantize(remainder, lhs, rhs, type(result))。
浮動小数点要素型の場合、このオペレーションは IEEE-754 仕様の remainder オペレーションとは対照的です。このオペレーションでは、d は lhs/rhs の正確な値に最も近い整数値で、偶数に丸められます。
入力
| ラベル | 名前 | タイプ | 制約 |
|---|---|---|---|
| (I1) | lhs |
整数型、浮動小数点型、複合型、またはテンソルごとの量子化テンソルのテンソル | (C1) |
| (I2) | rhs |
整数型、浮動小数点型、複合型、またはテンソルごとの量子化テンソルのテンソル | (C1) |
出力
| 名前 | タイプ | 制約 |
|---|---|---|
result |
整数型、浮動小数点型、複合型、またはテンソルごとの量子化テンソルのテンソル | (C1) |
制約
- (C1)
baseline_type(operand) = baseline_type(result)。
例
// %lhs: [17, -17, 17, -17]
// %rhs: [3, 3, -3, -3]
%result = "stablehlo.remainder"(%lhs, %rhs)< : (t>ensor4xi<64, t>ens>or4xi64<) - t>ensor4xi64
// %result: [2, -2, 2, -2]
replica_id
セマンティクス
現在のプロセスの replica_id を生成します。
出力
| 名前 | タイプ |
|---|---|
result |
型 ui32 の 0 次元テンソル |
例
%result = "stablehlo.replica_id">;() : (<) - >tensorui32
reshape
セマンティクス
operand テンソルを result テンソルに再形成します。概念的には、正規表現は同じまま、形状のみを変更する(tensor<2x3xf32> から tensor<3x2xf32> または tensor<6xf32> など)ことになります。
より正式には、result[result_index] = operand[operand_index] です。ここで、result_index と operand_index は、index_space(result) と index_space(operand) の辞書式順序で同じ位置にあります。
入力
| ラベル | 名前 | タイプ | 制約 |
|---|---|---|---|
| (I1) | operand |
テンソルまたは量子化テンソル | (C1 ~ C3) |
出力
| 名前 | タイプ | 制約 |
|---|---|---|
result |
テンソルまたは量子化テンソル | (C1 ~ C3) |
制約
- (C1)
element_type(result)は次のように計算されます。!is_per_axis_quantized(operand)の場合、element_type(operand)。element_type(operand)。ただし、quantization_dimension(operand)とquantization_dimension(result)は異なる場合があります。
- (C2)
size(operand) = size(result)。 - (C3)
is_per_axis_quantized(operand)の場合:reduce(dims(operand, [0, 1, ..., quantization_dimension(operand) - 1]), init_values=1, dimensions=[0], body=lambda x, y: x * y) = reduce(dims(result, [0, 1, ..., quantization_dimension(result) - 1]), init_values=1, dimensions=[0], body=lambda x, y: x * y)。dim(operand, quantization_dimension(operand)) = dim(result, quantization_dimension(result))reduce(dims(operand, [quantization_dimension(operand) + 1, ..., rank(operand) - 1]), init_values=1, dimensions=[0], body=lambda x, y: x * y) = reduce(dims(result, [quantization_dimension(result) + 1, ..., rank(result) - 1]), init_values=1, dimensions=[0], body=lambda x, y: x * y)。
例
// %operand: [[1, 2, 3], [4, 5, 6]]
%result = "stablehlo.reshape"(%operand)< : (ten>sor>2x3xi32<) - ten>sor3x2xi32
// %result: [[1, 2], [3, 4], [5, 6]]
reverse
セマンティクス
指定された dimensions に沿って operand 内の要素の順序を逆にして、result テンソルを生成します。より正式には、result[result_index] = operand[operand_index] です。ここで:
dimensionsのdの場合、operand_index[d] = dim(result, d) - result_index[d] - 1。- それ以外の場合は
operand_index[d] = result_index[d]。
入力
| ラベル | 名前 | タイプ | 制約 |
|---|---|---|---|
| (I1) | operand |
テンソルまたはテンソルごとの量子化テンソル | (C1)、(C3) |
| (I2) | dimensions |
si64 型の 1 次元テンソル定数 |
(C2)、(C3) |
出力
| 名前 | タイプ | 制約 |
|---|---|---|
result |
テンソルまたはテンソルごとの量子化テンソル | (C1)、(C3) |
制約
- (C1)
type(operand) = type(result)。 - (C2)
is_unique(dimensions)。 - (C3)
0 <= dimensions < rank(result)。
例
// %operand = [[1, 2], [3, 4], [5, 6]]
%result = "stablehlo.reverse"(%operand) {
dimensio<ns = a>rrayi64: 1
}< : (ten>sor>3x2xi32<) - ten>sor3x2xi32
// %result: [[2, 1], [4, 3], [6, 5]]
rng
セマンティクス
rng_distribution アルゴリズムを使用して乱数を生成し、指定された形状 shape の result テンソルを生成します。
rng_distribution = UNIFORM の場合、区間 [a, b) の一様分布に従って乱数が生成されます。a >= b の場合、動作は未定義です。
rng_distribution = NORMAL の場合、平均 = a、標準偏差 = b の正規分布に従って乱数が生成されます。b < 0 の場合、動作は未定義です。
乱数の生成方法は実装で定義されます。たとえば、決定論的である場合とそうでない場合があり、隠れ状態を使用する場合とそうでない場合があります。
多くの関係者との会話の中で、この op は事実上非推奨になっていることが明らかになったため、今後削除することを検討しています(#597)。
入力
| ラベル | 名前 | タイプ | 制約 |
|---|---|---|---|
| (I1) | a |
整数、ブール値、または浮動小数点型の 0 次元テンソル | (C1)、(C2) |
| (I2) | b |
整数、ブール値、または浮動小数点型の 0 次元テンソル | (C1)、(C2) |
| (I3) | shape |
si64 型の 1 次元テンソル定数 |
(C3) |
| (I4) | rng_distribution |
UNIFORM と NORMAL の列挙型 |
(C2) |
出力
| 名前 | タイプ | 制約 |
|---|---|---|
result |
整数、ブール値、浮動小数点型のテンソル | (C1 ~ C3) |
制約
- (C1)
element_type(a) = element_type(b) = element_type(result)。 - (C2)
rng_distribution = NORMALの場合はis_float(a)です。 - (C3)
shape(result) = shape。
例
// %a = 0
// %b = 2
// %shape = [3, 3]
%result = "stablehlo.rng"(%a, %b, %shape) {
rng_distribution = <#stablehlorng_distributi>on UNIFORM
}< : >(tensori<32,> tensori<32, t>ens>or2xi64<) - ten>sor3x3xi32
// %result: [
// [1, 0, 1],
// [1, 1, 1],
// [0, 0, 0]
// ]
rng_bit_generator
セマンティクス
初期状態 initial_state が与えられた場合、疑似乱数生成関数アルゴリズム rng_algorithm を使用して、一様ランダム ビットで埋められた output と更新された出力状態 output_state を返します。出力は initial_state の決定論的関数であることが保証されますが、実装間で決定論的であることは保証されません。
rng_algorithm は、次のいずれかです。
DEFAULT: 実装定義のアルゴリズム。THREE_FRY: Threefry アルゴリズムの実装定義バリアント。*PHILOX: Philox アルゴリズムの実装定義バリアント。*
* 参照: Salmon et al. SC 2011. 並列乱数: 1、2、3 のように簡単です。
入力
| ラベル | 名前 | タイプ | 制約 |
|---|---|---|---|
| (I1) | rng_algorithm |
DEFAULT、THREE_FRY、PHILOX の列挙型 |
(C2) |
| (I2) | initial_state |
ui64 型の 1 次元テンソル |
(C1)、(C2) |
出力
| 名前 | タイプ | 制約 |
|---|---|---|
output_state |
ui64 型の 1 次元テンソル |
(C1) |
output |
整数型または浮動小数点型のテンソル |
制約
- (C1)
type(initial_state) = type(output_state)。 - (C2)
size(initial_state)は次のように定義されます。rng_algorithm = DEFAULTの場合は実装定義。rng_algorithm = THREE_FRYの場合、2。rng_algorithm = PHILOXの場合は2または3。
例
// %initial_state: [1, 2]
%output_state, %output = "stablehlo.rng_bit_generator"(%initial_state) {
rng_algorithm = <#stablehlorng_algorithm> THREE_FRY
}< : (te>nso>r2xui64)< - (te>nsor2xui<64, tens>or2x2xui64)
// %output_state: [1, 6]
// %output: [
// [9236835810183407956, 16087790271692313299],
// [18212823393184779219, 2658481902456610144]
// ]
round_nearest_afz
セマンティクス
operand テンソルに対して、要素ごとに最も近い整数への丸めを行い、ゼロから離れるようにタイブレークを行い、result テンソルを生成します。IEEE-754 仕様の roundToIntegralTiesToAway 演算を実装します。量子化された型の場合、dequantize_op_quantize(round_nearest_afz, operand, type(result)) を実行します。
入力
| ラベル | 名前 | タイプ | 制約 |
|---|---|---|---|
| (I1) | operand |
浮動小数点型のテンソルまたはテンソルごとの量子化テンソル | (C1) |
出力
| 名前 | タイプ | 制約 |
|---|---|---|
result |
浮動小数点型のテンソルまたはテンソルごとの量子化テンソル | (C1) |
制約
- (C1)
baseline_type(operand) = baseline_type(result)。
例
// %operand = [-2.5, 0.4, 0.5, 0.6, 2.5]
%result = "stablehlo.round_nearest_afz"(%operand)< : (t>ens>or5xf64<) - t>ensor5xf64
// %result: [-3.0, 0.0, 1.0, 1.0, 3.0]
round_nearest_even
セマンティクス
operand テンソルに対して、最も近い整数への要素ごとの丸めを行い、偶数への丸めを行い、result テンソルを生成します。IEEE-754 仕様の roundToIntegralTiesToEven オペレーションを実装します。量子化された型の場合、dequantize_op_quantize(round_nearest_even, operand, type(result)) を実行します。
入力
| ラベル | 名前 | タイプ | 制約 |
|---|---|---|---|
| (I1) | operand |
浮動小数点型のテンソルまたはテンソルごとの量子化テンソル | (C1) |
出力
| 名前 | タイプ | 制約 |
|---|---|---|
result |
浮動小数点型のテンソルまたはテンソルごとの量子化テンソル | (C1) |
制約
- (C1)
baseline_type(operand) = baseline_type(result)。
例
// %operand = [-2.5, 0.4, 0.5, 0.6, 2.5]
%result = "stablehlo.round_nearest_even"(%operand)< : (t>ens>or5xf64<) - t>ensor5xf64
// %result: [-2.0, 0.0, 0.0, 1.0, 2.0]
rsqrt
セマンティクス
operand テンソルに対して要素ごとの逆平方根演算を実行し、result テンソルを生成します。要素のタイプに応じて、次の処理を行います。
- 浮動小数点の場合: IEEE-754 の
rSqrt。 - 複素数の場合: 複素数の逆数の平方根。
- 量子化された型の場合:
dequantize_op_quantize(rsqrt, operand, type(result))。
入力
| ラベル | 名前 | タイプ | 制約 |
|---|---|---|---|
| (I1) | operand |
浮動小数点型または複素数型のテンソル、またはテンソルごとの量子化テンソル | (C1) |
出力
| 名前 | タイプ | 制約 |
|---|---|---|
result |
浮動小数点型または複素数型のテンソル、またはテンソルごとの量子化テンソル | (C1) |
制約
- (C1)
baseline_type(operand) = baseline_type(result)。
例
// %operand: [[1.0, 4.0], [9.0, 25.0]]
%result = "stablehlo.rsqrt"(%operand)< : (ten>sor>2x2xf32<) - ten>sor2x2xf32
// %result: [[1.0, 0.5], [0.33333343, 0.2]]
散布図
セマンティクス
scatter_indices で指定された複数のスライスが update_computation を使用して updates の値で更新される点を除き、inputs テンソルと等しい results テンソルを生成します。
次の図は、具体的な例を使用して、updates... の要素が results... の要素にどのようにマッピングされるかを示しています。この図では、いくつかの updates... インデックスの例を取り上げ、それらがどの results... インデックスに対応するかを詳しく説明しています。
より正式には、index_space(updates[0]) のすべての update_index について、次のようになります。
update_scatter_dims = [d for d in axes(updates[0]) and d not in update_window_dims]。update_scatter_index = update_index[update_scatter_dims...]。start_indexは次のように定義されます。scatter_indices[si0, ..., :, ..., siN](siはupdate_scatter_indexの個々の要素)で、index_vector_dim<rank(scatter_indices)の場合、:がindex_vector_dimインデックスに挿入されます。- それ以外の場合は
[scatter_indices[update_scatter_index]]。
axes(inputs[0])のd_inputの場合:d_input = scatter_dims_to_operand_dims[d_start]の場合、full_start_index[d_input] = start_index[d_start]。- それ以外の場合は
full_start_index[d_input] = 0。
axes(inputs[0])のd_inputの場合:d_input = input_batching_dims[i_batching]とd_start = scatter_indices_batching_dims[i_batching]の場合、full_batching_index[d_input] = update_scatter_index[d_start - (d_start < index_vector_dim ? 0 : 1)]。- それ以外の場合は
full_batching_index[d_input] = 0。
update_window_index = update_index[update_window_dims...]。full_window_index = [wi0, ..., 0, ..., wiN]。ここで、wiはupdate_window_indexの個々の要素であり、0はinserted_window_dimsとinput_batching_dimsのインデックスに挿入されます。result_index = full_start_index + full_batching_index + full_window_index。
この場合、results = exec(schedule, inputs) は次のようになります。
scheduleは、index_space(updates[0])の実装定義の順列です。exec([update_index, ...], results) = exec([...], updated_results)ここで:result_indexがshape(results...)の範囲内にある場合updates_converted = to_destination_type( updates...[update_index], type(func_inputs(update_computation) [len(func_inputs(update_computation))//2:])... )updated_values = update_computation(results...[result_index], updates_converted)updated_resultsは、results...[result_index]がupdated_values...に設定されたresultsのコピーです。- それ以外の場合
updated_results = results。
exec([], results) = results。
indices_are_sorted が true の場合、実装は scatter_indices が scatter_dims_to_operand_dims に関してソートされていると想定できます。それ以外の場合の動作は未定義です。より正式には、indices(result) のすべての i1 < i2 について、full_start_index(i1) <= full_start_index(i2) です。
unique_indices が true の場合、実装では、分散されるすべての result_index インデックスが一意であると想定できます。unique_indices が true で、分散先のインデックスが一意でない場合、動作は未定義です。
入力
| ラベル | 名前 | タイプ | 制約 |
|---|---|---|---|
| (I1) | inputs |
可変数のテンソルまたはテンソルごとの量子化テンソル | (C1)、(C2)、(C4-C6)、(C11)、(C13)、(C18)、(C21)、(C23-C24) |
| (I2) | scatter_indices |
整数型のテンソル | (C4)、(C15)、(C19)、(C22) |
| (I3) | updates |
可変数のテンソルまたはテンソルごとの量子化テンソル | (C3-C6)、(C8) |
| (I4) | update_window_dims |
si64 型の 1 次元テンソル定数 |
(C2)、(C4)、(C7-C8) |
| (I5) | inserted_window_dims |
si64 型の 1 次元テンソル定数 |
(C2)、(C4)、(C9-C11) |
| (I6) | input_batching_dims |
si64 型の 1 次元テンソル定数 |
(C2)、(C4)、(C9)、(C12-13)、(C17-18)、(C20) |
| (I7) | scatter_indices_batching_dims |
si64 型の 1 次元テンソル定数 |
(C14-C18) |
| (I8) | scatter_dims_to_operand_dims |
si64 型の 1 次元テンソル定数 |
(C19-C21) |
| (I9) | index_vector_dim |
si64 型の定数 |
(C4)、(C16)、(C19)、(C22) |
| (I10) | indices_are_sorted |
i1 型の定数 |
|
| (I11) | unique_indices |
i1 型の定数 |
|
| (I12) | update_computation |
関数 | (C23) |
出力
| 名前 | タイプ | 制約 |
|---|---|---|
results |
可変数のテンソルまたはテンソルごとの量子化テンソル | (C24-C25) |
制約
- (C1)
same(shape(inputs...))。 - (C2)
rank(inputs[0]) = size(update_window_dims) + size(inserted_window_dims) + size(input_batching_dims)。 - (C3)
same(shape(updates...))。 - (C4)
shape(updates[0]) = combine(update_scatter_dim_sizes, update_window_dim_sizes)(ここで、update_scatter_dim_sizes = shape(scatter_indices)。ただし、index_vector_dimに対応するscatter_indicesのディメンション サイズは含まれません。update_window_dim_sizes <= shape(inputs[0])。ただし、inserted_window_dimsとinput_batching_dimsに対応するinputs[0]のディメンション サイズは含まれません。combineは、update_scatter_dim_sizesをupdate_scatter_dimsに対応する軸に配置し、update_window_dim_sizesをupdate_window_dimsに対応する軸に配置します。
- (C5)
0 < size(inputs) = size(updates) = N。 - (C6)
element_type(updates...) = element_type(inputs...)。 - (C7)
is_unique(update_window_dims) and is_sorted(update_window_dims)。 - (C8)
0 <= update_window_dims < rank(updates[0])。 - (C9)
is_unique(concatenate(inserted_window_dims, input_batching_dims)) - (C10)
is_sorted(inserted_window_dims)。 - (C11)
0 <= inserted_window_dims < rank(inputs[0])。 - (C12)
is_sorted(input_batching_dims)。 - (C13)
0 <= input_batching_dims < rank(inputs[0]))。 - (C14)
is_unique(scatter_indices_batching_dims)。 - (C15)
0 <= scatter_indices_batching_dims < rank(scatter_indices)。 - (C16)
index_vector_dim not in scatter_indices_batching_dims。 - (C17)
size(input_batching_dims) == size(scatter_indices_batching_dims)。 - (C18)
dim(inputs[0], input_batching_dims...) = dim(scatter_indices, scatter_indices_batching_dims...)。 - (C19)
size(scatter_dims_to_operand_dims) = index_vector_dim < rank(scatter_indices) ? dim(scatter_indices, index_vector_dim) : 1。 - (C20)
is_unique(concatenate(scatter_dims_to_operand_dims, input_batching_dims))。 - (C21)
0 <= scatter_dims_to_operand_dims < rank(inputs[0])。 - (C22)
0 <= index_vector_dim <= rank(scatter_indices)。 - (C23)
update_computationの型は(tensor<E0>, ..., tensor<EN-1>, tensor<E0>, ..., tensor<EN-1>) -> (tensor<E0>, ..., tensor<EN-1>)で、is_promotable(element_type(inputs[i]), Ei)。 - (C24)
shape(inputs...) = shape(results...)。 - (C25)
[0,N)内のすべてのiに対してelement_type(results[i]) = Ei。
例
// %input: [
// [
// [[1, 2], [3, 4], [5, 6], [7, 8]],
// [[9, 10],[11, 12], [13, 14], [15, 16]],
// [[17, 18], [19, 20], [21, 22], [23, 24]]
// ],
// [
// [[25, 26], [27, 28], [29, 30], [31, 32]],
// [[33, 34], [35, 36], [37, 38], [39, 40]],
// [[41, 42], [43, 44], [45, 46], [47, 48]]
// ]
// ]
// %scatter_indices: [
// [
// [[0, 0], [1, 0], [2, 1]],
// [[0, 1], [1, 1], [0, 9]]
// ],
// [
// [[0, 0], [2, 1], [2, 2]],
// [[1, 2], [0, 1], [1, 0]]
// ]
// ]
// %update: [
// [
// [[[1, 1], [1, 1]], [[1, 1], [1, 1]], [[1, 1], [1, 1]]],
// [[[1, 1], [1, 1]], [[1, 1], [1, 1]], [[1, 1], [1, 1]]]
// ],
// [
// [[[1, 1], [1, 1]], [[1, 1], [1, 1]], [[1, 1], [1, 1]]],
// [[[1, 1], [1, 1]], [[1, 1], [1, 1]], [[1, 1], [1, 1]]]
// ]
// ]
%result = "stablehlo.scatter"(%input, %scatter_indices, %update) ({
^bb0(%ar<g0:> tensori64, %ar<g1:> tensori64):
%0 = "stablehlo.add"(%ar<g0,> %arg1) <: (>ten>sori64,< te>nsori64) - tensori64
"stable<hlo>.re>turn"(%0) : (tensori64) - ()
}) {
scatter_dimensio<n_numbers = #stablehlo.scatter
update_window_dims = [3, 4],
inserted_window_dims = [1],
input_batching_dims = [0],
scatter_indices_batching_dims = [1],
scatter_dims_to_operand_dims = [2>, 1],
index_vector_dim = 3,
indices_are_sorted = false,
uniq<ue_indices >= false
<} : (tensor>2x3x4x2x<i64, tensor2x>2x3>x2xi64,< tensor2x2x>3x2x2xi64) - tensor2x3x4x2xi64
// %result: [
// [
// [[3, 4], [6, 7], [6, 7], [7, 8]],
// [[9, 10],[11, 12], [15, 16], [17, 18]],
// [[17, 18], [19, 20], [22, 23], [24, 25]]
// ],
// [
// [[25, 26], [28, 29], [30, 31], [31, 32]],
// [[35, 36], [38, 39], [38, 39], [39, 40]],
// [[41, 42], [44, 45], [46, 47], [47, 48]]
// ]
// ]
選択
セマンティクス
pred の対応する要素の値に基づいて、各要素が on_true テンソルまたは on_false テンソルから選択される result テンソルを生成します。より正式には、result[result_index] = pred_element ? on_true[result_index] :
on_false[result_index](pred_element = rank(pred) = 0 ? pred[] :
pred[result_index])です。量子化された型の場合、dequantize_select_quantize(pred, on_true, on_false, type(result)) を実行します。
入力
| ラベル | 名前 | タイプ | 制約 |
|---|---|---|---|
| (I1) | pred |
i1 型のテンソル |
(C1) |
| (I2) | on_true |
テンソルまたはテンソルごとの量子化テンソル | (C1-C2) |
| (I3) | on_false |
テンソルまたはテンソルごとの量子化テンソル | (C2) |
出力
| 名前 | タイプ | 制約 |
|---|---|---|
result |
テンソルまたはテンソルごとの量子化テンソル | (C2) |
制約
- (C1)
rank(pred) = 0 or shape(pred) = shape(on_true)。 - (C2)
baseline_type(on_true) = baseline_type(on_false) = baseline_type(result)。
例
// %pred: [[false, true], [true, false]]
// %on_true: [[1, 2], [3, 4]]
// %on_false: [[5, 6], [7, 8]]
%result = "stablehlo.select"(%pred, %on_true, %on_false)< : (te>nsor2x2x<i1, ten>sor2x2xi<32, ten>sor>2x2xi32<) - ten>sor2x2xi32
// %result: [[5, 2], [3, 8]]
select_and_scatter
セマンティクス
select を使用して input テンソルの reduce_window の結果に基づいて scatter を使用して source テンソルの値を分散し、result テンソルを生成します。
次の図は、具体的な例を使用して、result の要素が operand と source からどのように計算されるかを示しています。
より正式には、次のようになります。
次の入力を使用した
selected_values = reduce_window_without_init(...):inputs = [operand].window_dimensions、window_strides、paddingはそのまま使用されます。base_dilations = windows_dilations = 1。bodyは次のように定義されます。
def body(arg0: tensor<E>, arg1: tensor<E>) -> tensor<E>: return select(arg0, arg1) ? arg0 : arg1;ここで、
E = element_type(operand)とreduce_window_without_initはreduce_windowとまったく同じように機能しますが、基盤となるreduceのschedule(reduce を参照)には初期値が含まれていません。対応するウィンドウに値がない場合の動作は、現在未指定です(#731)。result[result_index] = reduce([source_values], [init_value], [0], scatter)ここで:source_values = [source[source_index] for source_index in source_indices]。selected_values[source_index]にoperand_indexのoperand要素がある場合はselected_index(source_index) = operand_index。source_indices = [source_index for source_index in indices(source) if selected_index(source_index) = result_index]。
入力
| ラベル | 名前 | タイプ | 制約 |
|---|---|---|---|
| (I1) | operand |
テンソルまたはテンソルごとの量子化テンソル | (C1-C4)、(C6)、(C8-C11) |
| (I2) | source |
テンソルまたはテンソルごとの量子化テンソル | (C1)、(C2) |
| (I3) | init_value |
0 次元テンソルまたはテンソルごとの量子化テンソル | (C3) |
| (I4) | window_dimensions |
si64 型の 1 次元テンソル定数 |
(C2)、(C4)、(C5) |
| (I5) | window_strides |
si64 型の 1 次元テンソル定数 |
(C2)、(C6)、(C7) |
| (I6) | padding |
型 si64 の 2 次元テンソル定数 |
(C2)、(C8) |
| (I7) | select |
関数 | (C9) |
| (I8) | scatter |
関数 | (C10) |
出力
| 名前 | タイプ | 制約 |
|---|---|---|
result |
テンソルまたはテンソルごとの量子化テンソル | (C11-C12) |
制約
- (C1)
element_type(operand) = element_type(source)。 - (C2)
shape(source) = num_windowsここで:padded_operand_shape = padding[:, 0] + shape(operand) + padding[:, 1]。is_empty_window = padded_operand_shape = 0 || window_dimensions > padded_operand_shapenum_windows = is_empty_window ? 0 : floor((padded_operand_shape - window_dimensions) / window_strides) + 1。
- (C3)
element_type(init_value) = element_type(operand)。 - (C4)
size(window_dimensions) = rank(operand)。 - (C5)
0 < window_dimensions。 - (C6)
size(window_strides) = rank(operand)。 - (C7)
0 < window_strides。 - (C8)
shape(padding) = [rank(operand), 2]。 - (C9)
selectは(tensor<E>, tensor<E>) -> tensor<i1>型で、E = element_type(operand)です。 - (C10)
scatterの型は(tensor<E>, tensor<E>) -> tensor<E>で、is_promotable(element_type(operand), E)。 - (C11)
shape(operand) = shape(result)。 - (C12)
element_type(result) = E。
例
// %operand: [[1, 5], [2, 5], [3, 6], [4, 4]]
// %source: [[5, 6], [7, 8]]
// %init_value: 0
%result = "stablehlo.select_and_scatter"(%operand, %source, %init_value) ({
^bb0(%ar<g0:> tensori64, %ar<g1:> tensori64):
%0 = "stablehlo.compare"(%arg0, %arg1) {
comparison_di<rection = #stablehlocom>parison_directio<n G>E
} <: (>ten>sori64,< t>ensori64) - tensori1
"stable<hl>o.r>eturn"(%0) : (tensori1) <- (>)
}, {
^bb0(%<arg>0: tensori64, %arg1: tensori64):
%0 = "sta<ble>hlo.add&<quo>t;(>%arg0, <%ar>g1) : (tensori64, tensori64) - tensor<i64>
> "stablehlo.return"(%0) :< (tensori>64) - ()
}) {
window_dim<ensions => arrayi64: 3, 1,
<window_strides => arrayi64<: 2, 1,>
padding =< dense[>[0, 1], <[0, 0]]> : tenso<r2x>2xi>64
} : <(tensor>4x2xi64, tensor2x2xi64, tensori64) - tensor4x2xi64
// %result: [[0, 0], [0, 0], [5, 14], [7, 0]]
送信
セマンティクス
チャンネル channel_id に inputs を送信します。入力は、source_target_pairs で指定された順序で他のデバイスに送信されます。このオペレーションは result トークンを生成します。
is_host_transfer が true の場合、オペレーションはデータをホストに転送します。それ以外の場合は、source_target_pairs の値に基づいて別のデバイスにデータを転送します。このフラグは channel_type で提供される情報を複製するため、将来的にはいずれか 1 つのみを残す予定です(#666)。is_host_transfer = false で、source_target_pairs が None または空の場合、動作は未定義とみなされます。
入力
| ラベル | 名前 | タイプ | 制約 |
|---|---|---|---|
| (I1) | inputs |
可変数のテンソルまたは量子化テンソル | |
| (I2) | token |
token |
|
| (I3) | channel_id |
si64 型の定数 |
|
| (I4) | channel_type |
DEVICE_TO_DEVICE と DEVICE_TO_HOST の列挙型 |
(C5) |
| (I5) | is_host_transfer |
i1 型の定数 |
(C5-C6) |
| (I6) | source_target_pairs |
型 si64 の 2 次元テンソル定数 |
(C1-C4)、(C6) |
出力
| 名前 | タイプ |
|---|---|
result |
token |
制約
- (C1)
dim(source_target_pairs, 1) = 2。 - (C2)
is_unique(source_target_pairs[:, 0])。 - (C3)
is_unique(source_target_pairs[:, 1])。 - (C4)
0 <= source_target_pairs < N。ここで、Nは次のように定義されます。cross_replicaを使用する場合はnum_replicas。cross_partitionを使用する場合はnum_partitions。
- (C5)
channel_typeは次のように定義されます。is_host_transfer = trueの場合、DEVICE_TO_HOST- それ以外の場合は
DEVICE_TO_DEVICE。
例
%result = "stablehlo.send"(%operand, %token) {
channel_handle = #stablehlo.chan<nel_handlehandle = 0>, type = 1,
is_host_transfer = false,
source_target_pai<rs = dense[[0, 1>], [1, 2]<] : ten>sor2x2xi64
}< : (ten>sor2x2xi64, !stablehl>o.token) - !stablehlo.token
shift_left
セマンティクス
lhs テンソルに対して rhs ビット数だけ要素ごとの左シフト演算を行い、result テンソルを生成します。
入力
| ラベル | 名前 | タイプ | 制約 |
|---|---|---|---|
| (I1) | lhs |
整数型のテンソル | (C1) |
| (I2) | rhs |
整数型のテンソル | (C1) |
出力
| 名前 | タイプ | 制約 |
|---|---|---|
result |
整数型のテンソル | (C1) |
制約
- (C1)
type(lhs) = type(rhs) = type(result)。
例
// %lhs: [-1, 0, 1]
// %rhs: [1, 2, 3]
%result = "stablehlo.shift_left"(%lhs, %rhs<): (t>ensor3xi<64, t>ens>or3xi64<) - t>ensor3xi64
// %result: [-2, 0, 8]
shift_right_arithmetic
セマンティクス
lhs テンソルに対して rhs ビット数の要素ごとの算術右シフト演算を行い、result テンソルを生成します。
入力
| ラベル | 名前 | タイプ | 制約 |
|---|---|---|---|
| (I1) | lhs |
整数型のテンソル | (C1) |
| (I2) | rhs |
整数型のテンソル | (C1) |
出力
| 名前 | タイプ | 制約 |
|---|---|---|
result |
整数型のテンソル | (C1) |
制約
- (C1)
type(lhs) = type(rhs) = type(result)。
例
// %lhs: [-1, 0, 8]
// %rhs: [1, 2, 3]
%result = "stablehlo.shift_right_arithmetic"(%lhs, %rhs<): (t>ensor3xi<64, t>ens>or3xi64<) - t>ensor3xi64
// %result: [-1, 0, 1]
shift_right_logical
セマンティクス
lhs テンソルに対して rhs ビット数の要素ごとの論理右シフト演算を行い、result テンソルを生成します。
入力
| ラベル | 名前 | タイプ | 制約 |
|---|---|---|---|
| (I1) | lhs |
整数型のテンソル | (C1) |
| (I2) | rhs |
整数型のテンソル | (C1) |
出力
| 名前 | タイプ | 制約 |
|---|---|---|
result |
整数型のテンソル | (C1) |
制約
- (C1)
type(lhs) = type(rhs) = type(result)。
例
// %lhs: [-1, 0, 8]
// %rhs: [1, 2, 3]
%result = "stablehlo.shift_right_logical"(%lhs, %rhs<): (t>ensor3xi<64, t>ens>or3xi64<) - t>ensor3xi64
// %result: [9223372036854775807, 0, 1]
署名
セマンティクス
operand の要素ごとの符号を返し、result テンソルを生成します。より正式には、各要素 x のセマンティクスは、次のように Python 構文を使用して表現できます。
def sign(x):
if is_integer(x):
if compare(x, 0, LT, SIGNED): return -1
if compare(x, 0, EQ, SIGNED): return 0
return 1
elif is_float(x):
if is_nan(x): return NaN
if compare(x, -0.0, EQ, FLOAT): return -0.0
if compare(x, +0.0, EQ, FLOAT): return +0.0
if compare(x, 0.0, LT, FLOAT): return -1.0
return 1.0
elif is_complex(x):
if is_nan(real(x)) or is_nan(imag(x)): return (NaN, NaN)
if compare(x, (0.0, 0.0), EQ, FLOAT): return (0.0, 0.0)
return divide(x, convert(abs(x), type(x)))
量子化された型の場合、dequantize_op_quantize(sign, operand, type(result)) を実行します。
入力
| ラベル | 名前 | タイプ | 制約 |
|---|---|---|---|
| (I1) | operand |
符号付き整数型、浮動小数点型、複素数型のテンソル、またはテンソルごとの量子化テンソル | (C1) |
出力
| 名前 | タイプ | 制約 |
|---|---|---|
result |
符号付き整数型、浮動小数点型、複素数型のテンソル、またはテンソルごとの量子化テンソル | (C1) |
制約
- (C1)
baseline_type(operand) = baseline_type(result)。
例
// Logical values: +NaN, -1.0, -0.0, +0.0, 1.0
// operand: [0x7FFFFFFFFFFFFFFF, -1.0, -0.0, 0.0, 1.0]
%result = "stablehlo.sign"(%operand)< : (t>ens>or5xf64<) - t>ensor5xf64
// Logical values: +NaN, -1.0, -0.0, +0.0, 1.0
// %result: [0x7FFFFFFFFFFFFFFF, -1.0, -0.0, 0.0, 1.0]
サイン
セマンティクス
operand テンソルに対して要素ごとの正弦演算を実行し、result テンソルを生成します。要素のタイプに応じて、次の処理を行います。
- 浮動小数点の場合: IEEE-754 の
sin。 - 複素数の場合は、複素数の正弦。
- 量子化された型の場合:
dequantize_op_quantize(sine, operand, type(result))。
入力
| ラベル | 名前 | タイプ | 制約 |
|---|---|---|---|
| (I1) | operand |
浮動小数点型または複素数型のテンソル、またはテンソルごとの量子化テンソル | (C1) |
出力
| 名前 | タイプ | 制約 |
|---|---|---|
result |
浮動小数点型または複素数型のテンソル、またはテンソルごとの量子化テンソル | (C1) |
制約
- (C1)
baseline_type(operand) = baseline_type(result)。
例
// %operand: [
// [0.0, 1.57079632], // [0, pi/2]
// [3.14159265, 4.71238898] // [pi, 3pi/2]
// ]
%result = "stablehlo.sine"(%operand)< : (ten>sor>2x2xf32<) - ten>sor2x2xf32
// %result: [[0.0, 1.0], [0.0, -1.0]]
slice
セマンティクス
静的に計算された開始インデックスを使用して operand からスライスを抽出し、result テンソルを生成します。start_indices には各ディメンションのスライスの開始インデックスが含まれ、limit_indices には各ディメンションのスライスの終了インデックス(排他的)が含まれ、strides には各ディメンションのストライドが含まれます。
より正式には、result[result_index] = operand[operand_index](operand_index = start_indices + result_index * strides)。
入力
| ラベル | 名前 | タイプ | 制約 |
|---|---|---|---|
| (I1) | operand |
テンソルまたはテンソルごとの量子化テンソル | (C1-C3)、(C5) |
| (I2) | start_indices |
si64 型の 1 次元テンソル定数 |
(C2)、(C3)、(C5) |
| (I3) | limit_indices |
si64 型の 1 次元テンソル定数 |
(C2)、(C3)、(C5) |
| (I4) | strides |
si64 型の 1 次元テンソル定数 |
(C2)、(C4) |
出力
| 名前 | タイプ | 制約 |
|---|---|---|
result |
テンソルまたはテンソルごとの量子化テンソル | (C1)、(C5) |
制約
- (C1)
element_type(operand) = element_type(result)。 - (C2)
size(start_indices) = size(limit_indices) = size(strides) = rank(operand)。 - (C3)
0 <= start_indices <= limit_indices <= shape(operand)。 - (C4)
0 < strides。 - (C5)
shape(result) = ceil((limit_indices - start_indices) / strides)。
例
// %operand: [
// [0, 0, 0, 0],
// [0, 0, 1, 1],
// [0, 0, 1, 1]
// ]
%result = "stablehlo.slice"(%operand) {
start_indic<es = arra>yi64: 1, 2,
limit_indic<es = arra>yi64: 3, 4,
strid<es = arra>yi64: 1, 1
}< : (ten>sor>3x4xi64<) - ten>sor2x2xi64
// % result: [
// [1, 1],
// [1, 1]
// ]
並べ替え
セマンティクス
comparator に従って、inputs の 1 次元スライスを dimension ディメンションに沿って一緒に並べ替え、results を生成します。
他のオペレーションの同様の入力とは異なり、dimension では負の値が許可されます。セマンティクスは次のとおりです。今後、一貫性の理由から、この動作は禁止される可能性があります(#1377)。
is_stable が true の場合、並べ替えは安定しています。つまり、コンパレータによって等しいと見なされる要素の相対順序が保持されます。入力が 1 つの場合、2 つの要素 e1 と e2 は、comparator(e1, e2) = comparator(e2, e1) = false の場合にのみコンパレータによって等しいと見なされます。複数の入力に一般化する方法については、以下の形式化をご覧ください。
より正式には、index_space(results[0]) のすべての result_index について、次のようになります。
adjusted_dimension = dimension >= 0 ? dimension : rank(inputs[0]) + dimension。result_slice = [ri0, ..., :, ..., riR-1](riNはresult_indexの個々の要素、:はadjusted_dimensionに挿入されます)。inputs_together = (inputs[0]..., ..., inputs[N-1]...)。results_together[result_slice] = sort(inputs_together[result_slice], comparator_together)。- ここで、
sortは 1 次元スライスを昇順で並べ替えます。comparator_togetherは、左側の引数が右側の 2 番目の引数より小さい場合にtrueを返します。 def comparator_together(lhs_together, rhs_together): args = [] for (lhs_el, rhs_el) in zip(lhs_together, rhs_together): args.append(lhs_el) args.append(rhs_el) return comparator(*args)(results[0]..., ..., results[N-1]...) = results_together。
入力
| ラベル | 名前 | タイプ | 制約 |
|---|---|---|---|
| (I1) | inputs |
可変数のテンソルまたはテンソルごとの量子化テンソル | (C1 ~ C5) |
| (I2) | dimension |
si64 型の定数 |
(C4) |
| (I3) | is_stable |
i1 型の定数 |
|
| (I4) | comparator |
関数 | (C5) |
出力
| 名前 | タイプ | 制約 |
|---|---|---|
results |
可変数のテンソルまたはテンソルごとの量子化テンソル | (C2)、(C3) |
制約
- (C1)
0 < size(inputs)。 - (C2)
type(inputs...) = type(results...)。 - (C3)
same(shape(inputs...) + shape(results...))。 - (C4)
-R <= dimension < R(R = rank(inputs[0]))。 - (C5)
comparatorの型は(tensor<E1>, tensor<E1>, ..., tensor<EN-1>, tensor<EN-1>) -> tensor<i1>で、Ei = element_type(inputs[i])です。
例
// %input0 = [[1, 2, 3], [3, 2, 1]]
// %input1 = [[3, 2, 1], [1, 2, 3]]
%result0, %result1 = "stablehlo.sort"(%input0, %input1) ({
^bb0(%ar<g0:> tensori64, %ar<g1:> tensori64, %ar<g2:> tensori64, %ar<g3:> tensori64):
%predicate = "stablehlo.compare"(%arg0, %arg1) {
comparison_di<rection = #stablehlocom>parison_directio<n G>T
} <: (>ten>sori64,< t>ensori64) - tensori1
"stablehlo.retu<rn>&qu>ot;(%predicate) : (tensori1) - ()
}) {
dimension = 0 : i64,
< is_st>able = t<rue
} :> (t>ensor2x3<xi64, t>ensor2x3<xi64) -> (tensor2x3xi64, tensor2x3xi64)
// %result0 = [[3, 2, 3], [1, 2, 1]]
// %result1 = [[1, 2, 1], [3, 2, 3]]
sqrt
セマンティクス
operand テンソルに対して要素ごとの平方根演算を実行し、result テンソルを生成します。要素のタイプに応じて、次の処理を行います。
- 浮動小数点の場合: IEEE-754 の
squareRoot。 - 複素数の場合: 複素数の平方根。
- 量子化された型の場合:
dequantize_op_quantize(sqrt, operand, type(result))。
入力
| ラベル | 名前 | タイプ | 制約 |
|---|---|---|---|
| (I1) | operand |
浮動小数点型または複素数型のテンソル、またはテンソルごとの量子化テンソル | (C1) |
出力
| 名前 | タイプ | 制約 |
|---|---|---|
result |
浮動小数点型または複素数型のテンソル、またはテンソルごとの量子化テンソル | (C1) |
制約
- (C1)
baseline_type(operand) = baseline_type(result)。
例
// %operand: [[0.0, 1.0], [4.0, 9.0]]
%result = "stablehlo.sqrt"(%operand)< : (ten>sor>2x2xf32<) - ten>sor2x2xf32
// %result: [[0.0, 1.0], [2.0, 3.0]]
subtract
セマンティクス
2 つのテンソル lhs と rhs の要素ごとの減算を行い、result テンソルを生成します。要素のタイプに応じて、次の処理を行います。
- 整数の場合: 整数の減算。
- 浮動小数点の場合: IEEE-754 の
subtraction。 - 複素数の場合: 複素数の減算。
- 量子化された型の場合:
dequantize_op_quantize(subtract, lhs, rhs, type(result))。
入力
| ラベル | 名前 | タイプ | 制約 |
|---|---|---|---|
| (I1) | lhs |
整数型、浮動小数点型、複素数型のテンソル、またはテンソルごとの量子化テンソル | (C1) |
| (I2) | rhs |
整数型、浮動小数点型、複素数型のテンソル、またはテンソルごとの量子化テンソル | (C1) |
出力
| 名前 | タイプ | 制約 |
|---|---|---|
result |
整数型、浮動小数点型、複素数型のテンソル、またはテンソルごとの量子化テンソル | (C1) |
制約
- (C1)
baseline_type(lhs) = baseline_type(rhs) = baseline_type(result)。
例
// %lhs: [[6, 8], [10, 12]]
// %rhs: [[5, 6], [7, 8]]
%result = "stablehlo.subtract"(%lhs, %rhs)< : (ten>sor2x2xf<32, ten>sor>2x2xf32)< - (ten>sor2x2xf32)
// %result: [[1, 2], [3, 4]]
tan
セマンティクス
operand テンソルに対して要素ごとの正接演算を実行し、result テンソルを生成します。要素のタイプに応じて、次の処理を行います。
- 浮動小数点の場合: IEEE-754 の
tan。 - 複素数の場合は、複素数のタンジェント。
- 量子化された型の場合:
dequantize_op_quantize(tan, operand, type(result))。
入力
| ラベル | 名前 | タイプ | 制約 |
|---|---|---|---|
| (I1) | operand |
浮動小数点型または複素数型のテンソル、またはテンソルごとの量子化テンソル | (C1) |
出力
| 名前 | タイプ | 制約 |
|---|---|---|
result |
浮動小数点型または複素数型のテンソル、またはテンソルごとの量子化テンソル | (C1) |
制約
- (C1)
baseline_type(operand) = baseline_type(result)。
例
// %operand: [
// [0.0, 1.57079632], // [0, pi/2]
// [3.14159265, 4.71238898] // [pi, 3pi/2]
// ]
%result = "stablehlo.tan"(%operand)< : (ten>sor>2x2xf64<) - ten>sor2x2xf64
// %result: [
// [0.0, 1.63312e+16],
// [0.0, 5.44375e+15]
// ]
tanh
セマンティクス
operand テンソルに対して要素ごとの双曲線正接演算を行い、result テンソルを生成します。要素のタイプに応じて、次の処理を行います。
- 浮動小数点の場合: IEEE-754 の
tanh。 - 複素数の場合: 複素数の双曲線正接。
- 量子化された型の場合:
dequantize_op_quantize(tanh, operand, type(result))。
入力
| ラベル | 名前 | タイプ | 制約 |
|---|---|---|---|
| (I1) | operand |
浮動小数点型または複素数型のテンソル、またはテンソルごとの量子化テンソル | (C1) |
出力
| 名前 | タイプ | 制約 |
|---|---|---|
result |
浮動小数点型または複素数型のテンソル、またはテンソルごとの量子化テンソル | (C1) |
制約
- (C1)
baseline_type(operand) = baseline_type(result)。
例
// %operand: [-1.0, 0.0, 1.0]
%result = "stablehlo.tanh"(%operand)< : (t>ens>or3xf32<) - t>ensor3xf32
// %result: [-0.76159416, 0.0, 0.76159416]
transpose
セマンティクス
permutation を使用して operand テンソルのディメンションを置換し、result テンソルを生成します。より正式には、result[result_index] = operand[operand_index](result_index[d] = operand_index[permutation[d]])。
入力
| ラベル | 名前 | タイプ | 制約 |
|---|---|---|---|
| (I1) | operand |
テンソルまたは量子化テンソル | (C1 ~ C4) |
| (I2) | permutation |
si64 型の 1 次元テンソル定数 |
(C2-C4) |
出力
| 名前 | タイプ | 制約 |
|---|---|---|
result |
テンソルまたは量子化テンソル | (C1)、(C3 ~ C4) |
制約
- (C1)
element_type(result)は次のように計算されます。!is_per_axis_quantized(operand)の場合、element_type(operand)。element_type(operand)。ただし、quantization_dimension(operand)とquantization_dimension(result)は異なる場合があります。
- (C2)
permutationはrange(rank(operand))の順列です。 - (C3)
shape(result) = dim(operand, permutation...)。 - (C4)
is_per_axis_quantized(result)の場合はquantization_dimension(operand) = permutation(quantization_dimension(result))。
例
// %operand: [
// [[1,2], [3,4], [5,6]],
// [[7,8], [9,10], [11,12]]
// ]
%result = "stablehlo.transpose"(%operand) {
permutati<on = arrayi6>4: 2, 1, 0
}< : (tenso>r2x>3x2xi32<) - tenso>r2x3x2xi32
// %result: [
// [[1,7], [3,9], [5,11]],
// [[2,8], [4,10], [6,12]]
// ]
triangular_solve
セマンティクス
下三角または上三角の係数行列を持つ連立一次方程式のバッチを解きます。
より正式には、a と b が与えられたとき、left_side が true の場合は result[i0, ..., iR-3, :, :] が op(a[i0, ..., iR-3, :, :]) * x = b[i0, ..., iR-3, :, :] の解となり、left_side が false の場合は result[i0, ..., iR-3, :, :] が op(a[i0, ..., iR-3, :, :]) * x = b[i0, ..., iR-3, :, :] の解となります。ここで、変数 x を解きます。op(a) は transpose_a によって決定されます。transpose_a は次のいずれかになります。x * op(a[i0, ..., iR-3, :, :]) = b[i0, ..., iR-3, :, :]
NO_TRANSPOSE:aをそのまま使用してオペレーションを実行します。TRANSPOSE:aの転置に対して演算を実行します。ADJOINT:aの共役転置に対して演算を実行します。
lower が true の場合は a の下三角、それ以外の場合は a の上三角からのみ入力データを読み取ります。出力データは同じ三角形で返されます。もう一方の三角形の値は実装定義です。
unit_diagonal が true の場合、実装は a の対角要素が 1 に等しいと想定できます。それ以外の場合、動作は未定義です。
量子化された型の場合、dequantize_op_quantize(lambda x, y: triangular_solve(x, y, left_side, lower,
unit_diagonal, transpose_a), a, b, type(result)) を実行します。
入力
| ラベル | 名前 | タイプ | 制約 |
|---|---|---|---|
| (I1) | a |
浮動小数点型または複素数型のテンソル、またはテンソルごとの量子化テンソル | (C1 ~ C3) |
| (I2) | b |
浮動小数点型または複素数型のテンソル、またはテンソルごとの量子化テンソル | (C1 ~ C4) |
| (I3) | left_side |
i1 型の定数 |
(C3) |
| (I4) | lower |
i1 型の定数 |
|
| (I5) | unit_diagonal |
i1 型の定数 |
|
| (I6) | transpose_a |
NO_TRANSPOSE、TRANSPOSE、ADJOINT の列挙型 |
出力
| 名前 | タイプ | 制約 |
|---|---|---|
result |
浮動小数点型または複素数型のテンソル、またはテンソルごとの量子化テンソル | (C1) |
制約
- (C1)
baseline_element_type(a) = baseline_element_type(b)。 - (C2)
2 <= rank(a) = rank(b) = R。 - (C3)
shape(a)とshape(b)の関係は次のように定義されます。shape(a)[:-3] = shape(b)[:-3]。dim(a, -2) = dim(a, -1) = dim(b, left_side ? -2 : -1)。
- (C4)
baseline_type(b) = baseline_type(result)。
例
// %a = [
// [1.0, 0.0, 0.0],
// [2.0, 4.0, 0.0],
// [3.0, 5.0, 6.0]
// ]
// %b = [
// [2.0, 0.0, 0.0],
// [4.0, 8.0, 0.0],
// [6.0, 10.0, 12.0]
// ]
%result = "stablehlo.triangular_solve"(%a, %b) {
left_side = true,
lower = true,
unit_diagonal = false,
transpose_a = <#stablehlotranspose NO>_TRANSPOSE
}< : (ten>sor3x3xf<32, ten>sor>3x3xf32<) - ten>sor3x3xf32
// %result: [
// [2.0, 0.0, 0.0],
// [0.0, 2.0, 0.0],
// [0.0, 0.0, 2.0]
// ]
tuple
セマンティクス
値 val から result タプルを生成します。
入力
| ラベル | 名前 | タイプ | 制約 |
|---|---|---|---|
| (I1) | val |
可変個数の値 | (C1) |
出力
| 名前 | タイプ | 制約 |
|---|---|---|
result |
tuple | (C1) |
制約
- (C1)
resultはtuple<E0, ..., EN-1>型で、Ei = type(val[i])です。
例
// %val0: memref[1.0, 2.0]
// %val1: (3)
%result = "stablehlo.tuple"(%val0, %val1)< : (m>emref2x<f32, t<upl>>ete>nsori3<2) - t<uplem>emref2x<f32, t<upl>>>etensori32
// %result: (memref[1.0, 2.0], (3))
uniform_dequantize
セマンティクス
operand 型で定義された量子化パラメータに従って、量子化テンソル operand の要素ごとの変換を浮動小数点テンソル result に行います。
より正式には、result = dequantize(operand)。
入力
| ラベル | 名前 | タイプ | 制約 |
|---|---|---|---|
| (I1) | operand |
量子化テンソル | (C1)、(C2) |
出力
| 名前 | タイプ | 制約 |
|---|---|---|
result |
浮動小数点型のテンソル | (C1)、(C2) |
制約
- (C1)
shape(operand) = shape(result)。 - (C2)
element_type(result) = expressed_type(operand)。
例
// %operand: [10, 10]
%result = "stablehlo.uniform_dequantize"(%operand)< : (tensor2x!qua<nt.uniformi8:f32:0, {0.1:-3>>0,0>.5:-20}<) - t>ensor2xf32
// %result: [4.0, 15.0]
uniform_quantize
セマンティクス
result 型で定義された量子化パラメータに従って、浮動小数点テンソルまたは量子化テンソル operand の要素ごとの変換を量子化テンソル result に行います。
より正式には、
is_float(operand)の場合:result = quantize(operand, type(result))。
is_quantized(operand)の場合:float_result = dequantize(operand)。result = quantize(float_result, type(result))。
入力
| ラベル | 名前 | タイプ | 制約 |
|---|---|---|---|
| (I1) | operand |
浮動小数点型または量子化型のテンソル | (C1)、(C2) |
出力
| 名前 | タイプ | 制約 |
|---|---|---|
result |
量子化テンソル | (C1)、(C2) |
制約
- (C1)
shape(operand) = shape(result)。 - (C2)
expressed_type(result) = is_float(operand) ? element_type(operand) : expressed_type(operand)。
例
// %operand: [4.0, 15.0]
%result = "stablehlo.uniform_quantize"(%operand)< : (t>ens>or2xf32<) - tensor2x!qua<nt.uniformi8:f32:0, {0.1:-3>>0,0.5:-20}
// %result: [10, 10]
// %operand: [10, 10]
%result = "stablehlo.uniform_quantize"<(%operand) : (te<nsor2x!quant.uniformi8:f32:>>0, >{0.1:-3<0,0.5:-20}) - te<nsor2x!quant.uniformi8:f32:>>0, {0.1:-20,0.2:-30}
// %result: [20, 45]
しばらく
セマンティクス
cond 関数が true を出力している間、body 関数を 0 回以上実行した結果を出力します。より正式には、セマンティクスは Python 構文を使用して次のように表現できます。
internal_state = operand
while cond(*internal_state):
internal_state = body(*internal_state)
results = internal_state
無限ループの動作は未定です(#383)。
入力
| ラベル | 名前 | タイプ | 制約 |
|---|---|---|---|
| (I1) | operand |
可変個数の値 | (C1 ~ C3) |
| (I2) | cond |
関数 | (C1) |
| (I3) | body |
関数 | (C2) |
出力
| 名前 | タイプ | 制約 |
|---|---|---|
results |
可変個数の値 | (C3) |
制約
- (C1)
condの型は(T0, ..., TN-1) -> tensor<i1>で、Ti = type(operand[i])です。 - (C2)
bodyの型は(T0, ..., TN-1) -> (T0, ..., TN-1)で、Ti = type(operand[i])です。 - (C3)
type(results...) = type(operand...)。
例
// %init_i: 1
// %init_sum: 0
// %one: 1
// %ten: 10
%results0, %results1 = "stablehlo.while"(%init_i, %init_sum) ({
^bb0(%ar<g0:> tensori64, %ar<g1:> tensori64):
%cond = "stablehlo.compare"(%arg0, %ten) {
comparison_di<rection = #stablehlocom>parison_directio<n L>T
} <: (>ten>sori64,< t>ensori64) - tensori1
stablehlo.r<et>urn %cond : tensori1
}, {
< ^>bb0(%arg0: tens<ori>64, %arg1: tensori64):
%new_sum = stablehlo.add <%ar>g1, %one : tensori64
%new_i = stablehlo.add <%ar>g0, %one : tensori64
stablehlo.return %new_<i, >%new_sum< : >tensori64, te<nso>ri64
}) <: (>ten>sori64, <ten>sori64) <- (>tensori64, tensori64)
// %results0: 10
// %results1: 10
xor
セマンティクス
2 つのテンソル lhs と rhs の要素ごとの XOR を実行し、result テンソルを生成します。要素のタイプに応じて、次の処理を行います。
- ブール値の場合: 論理 XOR。
- 整数の場合: ビット単位の XOR。
入力
| ラベル | 名前 | タイプ | 制約 |
|---|---|---|---|
| (I1) | lhs |
ブール値または整数型のテンソル | (C1) |
| (I2) | rhs |
ブール値または整数型のテンソル | (C1) |
出力
| 名前 | タイプ | 制約 |
|---|---|---|
result |
ブール値または整数型のテンソル | (C1) |
制約
- (C1)
type(lhs) = type(rhs) = type(result)。
例
// Bitwise operation with with integer tensors
// %lhs: [[1, 2], [3, 4]]
// %rhs: [[5, 6], [7, 8]]
%result = "stablehlo.xor"(%lhs, %rhs)< : (ten>sor2x2xi<32, ten>sor>2x2xi32<) - ten>sor2x2xi32
// %result: [[4, 4], [4, 12]]
// Logical operation with with boolean tensors
// %lhs: [[false, false], [true, true]]
// %rhs: [[false, true], [false, true]]
%result = "stablehlo.xor"(%<lhs, %>rhs) : (<tensor>2x2>xi1, te<nsor2x>2xi1) - tensor2x2xi1
// %result: [[false, true], [true, false]]
言語の相互運用
現在、StableHLO プログラムには、StableHLO で定義されていないオペレーションが含まれていることがあります。
モジュール、関数、呼び出し、戻り値
StableHLO は、ModuleOp、FuncOp、CallOp、ReturnOp にアップストリーム MLIR オペレーションを使用します。これは、既存の MLIR メカニズムとの相互運用性を高めるために行われました。多くの有用なパスは FuncOp と ModuleOp を対象として記述されており、多くのコンパイル パイプラインはこれらのオペレーションが存在することを想定しています。これらのオペレーションには完全な互換性保証が適用されます。これらのオペレーションが互換性のない方法で変更された場合(削除など)、互換性を維持するために StableHLO の同等のオペレーションが追加されます。
CHLO
CHLO opset には、StableHLO に分解される高レベルのオペレーションが含まれています。現在、CHLO の互換性は保証されていません。互換性を保証するには、シリアル化の前に chlo-legalize-to-stablehlo パスを使用する必要があります。
シェイプ オペレーション
コミュニティでは、コア MLIR 言語の特定のオペレーションを動的 StableHLO プログラムで使用して、形状計算を行うのが一般的なユースケースです。通常、これには shape_of や num_elements などの shape 言語オペレーション、dim や from_elements などの tensor 言語オペレーション、組み込みの index 型が含まれます。
Dynamism RFC > O2 では、これらは対象外とされていますが、相互運用性のために index 型のサポートが一部含まれています。これらのオペレーションまたは型には互換性の保証はありません。shape-legalize-to-stablehlo パスを使用して、これらのオペレーションを完全にサポートされている StableHLO オペレーションに変換できます。
非推奨のオペレーション
MHLO から継承された StableHLO オペレーションがいくつかあり、これらは非推奨となり、StableHLO から削除される予定です。これらの削除の詳細については、StableHLO v1.0 クリーンアップ #2283 をご覧ください。これらの非推奨のトラッカーの問題は #2340 です。
これらのオペレーションは、いくつかのカテゴリに分類されます。
- StableHLO オペレーションの「Not in HLO」カテゴリ - これらは当初 StableHLO opset の一部でしたが、後で適合しないと判断されました:
broadcast、create_token、cross-replica-sum、dot、einsum、torch_index_select、unary_einsum(#3)。 - 未使用のオペレーション - これらのオペレーションは、以前は有用だった可能性がありますが、オペレーションの開発が不十分であったか、これらのオペレーションを使用するパイプラインがリファクタリングされて、これらのオペレーションが不要になった可能性があります。これには、
map、tuple(#598)、get_tuple_element、rng、complexの比較 #560、畳み込みwindow_reversal(#1181)が含まれます。
これらのオペレーションの一部は、既存のオペレーション(broadcast、create_token、cross-replica-sum、dot、unary_einsum)を使用して表現できるため、簡単に削除できます。既存の互換性ウィンドウ(6 か月)が経過すると、削除されます。その他は、削除に向けて検討中です(einsum、get_tuple_element、map、rng、torch_index_select、tuple、complex の比較、window_reversal)。コミュニティからのフィードバックを待って、これらのオペレーションは削除されるか、完全なサポート付きで仕様に追加されます。これらのオペレーションの将来が明らかになるまでは、互換性が保証されるのは 6 か月間のみです。
実行
順次実行
StableHLO プログラムは、main 関数に入力値を指定して出力値を計算することで実行されます。関数の出力値は、対応する return オペレーションをルートとするオペレーションのグラフを実行することで計算されます。
実行順序は、データフローと一致する限り(つまり、オペレーションが使用前に実行される場合)、実装定義です。StableHLO では、すべての副作用のあるオペレーションが 1 つのトークンを消費し、1 つのトークンを生成します(複数のトークンは after_all を介して 1 つのトークンに多重化できます)。そのため、副作用の実行順序もデータフローと一致します。たとえば、次のプログラムでは、%0 → %1 → %2 → return と %1 → %0 → %2 → return の 2 つの実行順序が考えられます。
func.func @main() -> tensor<f64> {
%0 = stablehlo.constant dense<1.0> : tensor<f64>
%1 = stablehlo.constant dense<2.0> : tensor<f64>
%2 = stablehlo.add %0, %1 : tensor<f64>
return %2 : tensor<f64>
}
より正式には、StableHLO プロセスは、1)StableHLO プログラム、2)オペレーション ステータス(まだ実行されていない、すでに実行されている)、3)プロセスが処理している中間値の組み合わせです。プロセスは main 関数への入力値で始まり、オペレーションのグラフを介してオペレーションのステータスと中間値を更新し、出力値で終了します。さらなる形式化は未定です(#484)。
並列実行
StableHLO プログラムは並列実行でき、num_replicas × num_partitions の 2D プロセス グリッドに編成されます。num_replicas と num_partitions の両方の型は ui32 です。
StableHLO プロセス グリッドでは、num_replicas * num_partitions 個の StableHLO プロセスが同時に実行されています。各プロセスには一意の process_id = (replica_id, partition_id) があります。replica_ids = range(num_replicas) の replica_id と partition_ids = range(num_partitions) の partition_id はどちらも ui32 型です。
プロセス グリッドのサイズはすべてのプログラムで静的にわかっています(将来的には、StableHLO プログラムの明示的な一部にする予定です #650)。プロセス グリッド内の位置はすべてのプロセスで静的にわかっています。各プロセスは、replica_id オペレーションと partition_id オペレーションを介して、プロセス グリッド内の位置にアクセスできます。
プロセス グリッド内では、プログラムはすべて同じ(「単一プログラム、複数データ」スタイル)にすることも、すべて異なる(「複数プログラム、複数データ」スタイル)にすることも、その中間にすることもできます。今後、GSPMD(#619)など、並列 StableHLO プログラムを定義する他のイディオムのサポートを導入する予定です。
プロセス グリッド内では、プロセスはほとんど互いに独立しています。プロセスには個別のオペレーション ステータス、個別の入力/中間/出力値があり、ほとんどのオペレーションはプロセス間で個別に実行されます。ただし、後述する少数の集合オペレーションは例外です。
ほとんどのオペレーションの実行では同じプロセスからの値のみが使用されるため、通常はこれらの値を名前で参照しても曖昧さはありません。ただし、集合演算のセマンティクスを記述する場合は、これでは不十分です。そのため、特定のプロセス内の値 name を参照する表記 name@process_id が使用されます。(その観点からすると、修飾されていない name は name@(replica_id(), partition_id()) の省略形と見なすことができます)。
プロセス間の実行順序は、以下で説明するポイントツーポイント通信と集団演算によって導入される同期を除き、実装定義です。
ポイントツーポイント通信
StableHLO プロセスは、StableHLO チャネルを介して相互に通信できます。チャネルは、si64 型の正の ID で表されます。さまざまなオペレーションを通じて、チャネルに値を送信したり、チャネルから値を受信したりできます。
これらのチャンネル ID の取得元、プロセス プログラムがそれらを認識する方法、それらによって導入される同期の種類など、さらなる形式化は未定です(#484)。
ストリーミング通信
すべての StableHLO プロセスは、次の 2 つのストリーミング インターフェースにアクセスできます。
- 読み取り可能なインフィード。
- 書き込み可能なアウトフィード。
プロセス間の通信に使用されるチャネルとは異なり、インフィードとアウトフィードのもう一方の端は実装定義です。
ストリーミング通信が実行順序にどのように影響するか、どのような同期が導入されるかなどのさらなる形式化は、未定です(#484)。
Collective ops
StableHLO には、all_gather、all_reduce、all_to_all、collective_broadcast、collective_permute、reduce_scatter の 6 つの集合演算があります。これらのオペレーションはすべて、StableHLO プロセス グリッド内のプロセスを StableHLO プロセス グループに分割し、各プロセス グループ内で他のプロセス グループとは独立して結合計算を実行します。
各プロセス グループ内で、集合演算は同期バリアを導入する可能性があります。この同期がいつ発生するのか、プロセスがこのバリアにどのように到達するのか、到達しない場合はどうなるのかなど、さらなる形式化は未定です(#484)。
プロセス グループにパーティション間の通信が含まれている場合(パーティション ID が異なるプロセスがプロセス グループに含まれている場合)、集合演算の実行にはチャネルが必要であり、集合演算は si64 型の正の channel_id を提供する必要があります。クロスレプリカ通信にチャネルは必要ありません。
集合演算によって実行される計算は個々の演算に固有のものであり、上記の個々の演算のセクションで説明されています。ただし、プロセス グリッドをプロセス グループに分割する戦略は、これらのオペレーション間で共有され、このセクションで説明します。より正式には、StableHLO は次の 4 つの戦略をサポートしています。
cross_replica
プロセス グループ内では、クロスレプリカ通信のみが行われます。この戦略では、replica_groups(レプリカ ID のリストのリスト)を取得し、replica_groups と partition_ids のデカルト積を計算します。replica_groups には一意の要素があり、すべての replica_ids をカバーする必要があります。より正式には、Python 構文を使用して次のように記述します。
def cross_replica(replica_groups: List[List[ReplicaId]]) -> List[List[ProcessId]]:
for replica_group in replica_groups:
for partition_id in partition_ids:
process_group = []
for replica_id in replica_group:
process_group.append((replica_id, partition_id))
yield process_group
たとえば、replica_groups = [[0, 1], [2, 3]] と num_partitions = 2 の場合、cross_replica は [[(0, 0), (1, 0)], [(0, 1), (1, 1)], [(2, 0), (3, 0)], [(2, 1), (3, 1)]] を生成します。
cross_partition
各プロセス グループ内では、パーティション間の通信のみが行われます。この戦略は、パーティション ID のリストのリストである partition_groups を受け取り、partition_groups と replica_ids のデカルト積を計算します。partition_groups には一意の要素があり、すべての partition_ids をカバーする必要があります。より正式には、Python 構文を使用して次のように記述します。
def cross_partition(partition_groups: List[List[PartitionId]]) -> List[List[ProcessId]]:
for partition_group in partition_groups:
for replica_id in replica_ids:
process_group = []
for partition_id in partition_group:
process_group.append((replica_id, partition_id))
yield process_group
たとえば、partition_groups = [[0, 1]] と num_replicas = 4 の場合、cross_partition は [[(0, 0), (0, 1)], [(1, 0), (1, 1)], [(2, 0), (2, 1)], [(3, 0), (3, 1)]] を生成します。
cross_replica_and_partition
クロスレプリカ通信とクロスパーティション通信の両方が、各プロセス グループ内で発生する可能性があります。この戦略は、レプリカ ID のリストのリストである replica_groups を受け取り、各 replica_group の直積を partition_ids で計算します。replica_groups には一意の要素があり、すべての replica_ids をカバーする必要があります。より正式には、Python 構文を使用して次のように記述します。
def cross_replica_and_partition(replica_groups: List[List[ReplicaId]]) -> List[List[ProcessId]]:
for replica_group in replica_groups:
process_group = []
for partition_id in partition_ids:
for replica_id in replica_group:
process_group.append((replica_id, partition_id))
yield process_group
たとえば、replica_groups = [[0, 1], [2, 3]] と num_partitions = 2 の場合、cross_replica_and_partition は [[(0, 0), (1, 0), (0, 1), (1, 1)], [(2, 0), (3, 0), (2, 1), (3, 1)]] を生成します。
flattened_ids
この戦略は、flattened_id_groups(replica_id * num_partitions + partition_id の形式の「フラット化」されたプロセス ID のリストのリスト)を受け取り、プロセス ID に変換します。flattened_id_groups には一意の要素があり、すべての process_ids をカバーする必要があります。より正式には、Python 構文を使用して次のように記述します。
def flattened_ids(flattened_id_groups: List[List[ui32]]) -> List[List[ProcessId]]:
for flattened_id_group in flattened_id_groups:
process_group = []
for flattened_id in flattened_id_group:
replica_id = flattened_id // num_partitions
partition_id = flattened_id % num_partitions
process_group.append((replica_id, partition_id))
yield process_group
たとえば、flattened_id_groups = [[0, 1, 2, 3], [4, 5, 6, 7]]、num_replicas = 4、num_partitions = 2 の場合、flattened_ids は [[(0, 0), (0, 1), (1, 0), (1, 1)], [(2, 0), (2, 1), (3, 0), (3, 1)]] を生成します。
精度
現時点では、StableHLO は数値の精度を保証していませんが、今後変更される可能性があります(#1156)。
量子化されたオペレーションの実行セマンティクス
量子化された StableHLO オペレーションの解釈は、ハードウェアの要件と機能によって異なる場合があります。たとえば、一部のハードウェアでは、「逆量子化、浮動小数点演算の実行、量子化」という戦略を使用して量子化されたオペレーションを解釈する場合があります。整数演算で計算全体を実行する場合があります。したがって、量子化された StableHLO オペレーションの解釈は、特定の実装によってのみ決定されます。ハイブリッド量子化(#1575)の解釈は、仕様(1792 経由)で規定されているセマンティクスに基づく必要があります。
エラー
StableHLO プログラムは、個々のオペレーションに対する広範な制約セットによって検証されます。これにより、実行前に多くのクラスのエラーが排除されます。ただし、整数オーバーフローや範囲外アクセスなどにより、エラー条件が発生する可能性はあります。明示的に呼び出されない限り、これらのエラーはすべて実装定義の動作になりますが、これは将来変更される可能性があります(#1157)。
浮動小数点例外
このルールの例外として、StableHLO プログラムの浮動小数点例外には明確に定義された動作があります。IEEE-754 標準で定義された例外(無効なオペレーション、ゼロ除算、オーバーフロー、アンダーフロー、不正確な例外)が発生するオペレーションは、デフォルトの結果(標準で定義)を生成し、対応するステータス フラグを発生させることなく実行を継続します。これは、標準の raiseNoFlag 例外処理に似ています。非標準の演算(複雑な算術演算や特定の超越関数など)の例外は、実装定義です。
形状の不一致
StableHLO は動的形状のテンソルをサポートしています。ただし、形状は実行時に一致している必要があります。一致していない場合、動作は未定義になります。StableHLO は、実行時にテンソルが特定の形状を持つことをアサートできる op を明示的に提供していません。正しいコードを生成する責任はプロデューサーにあります。
具体的な例として、次のプログラムは有効です。ただし、実行時には %arg0 と %arg1 の正確な形状が同じである必要があります。そうでない場合、プログラムの動作は未定義になります。
func.func @foo(%arg0: tensor<?xi32>, %arg1: tensor<?xi32>) -> tensor<?xi32> {
%0 = stablehlo.add %arg0, %arg1 : tensor<?xi32>
return %0 : tensor<?xi32>
}
Notation
構文を記述するために、このドキュメントでは、EBNF 構文の修正された ISO フレーバー(ISO/IEC 14977:1996、Wikipedia)を使用します。2 つの変更があります。1)ルールは = ではなく ::= を使用して定義されます。
2)連結は , ではなく並置で表現されます。
セマンティクス(「型」、「定数」、「演算」セクション内)を記述するために、Python 構文をベースにした数式を使用しています。この数式は、以下で説明するように、配列演算を簡潔に表現するためのサポートが追加されています。これは小さなコード スニペットには適していますが、大きなコード スニペットが必要になるまれなケースでは、常に明示的に導入されるバニラ Python 構文を使用します。
数式
dot_general 仕様の例に基づいて、数式の仕組みを見てみましょう。このオペレーションの制約の 1 つは dim(lhs, lhs_batching_dimensions...) = dim(rhs, rhs_batching_dimensions...) です。
この数式で使用される名前は、1)グローバル関数(dim など)、2)対応するプログラム要素のメンバー定義(dot_general の「入力」セクションで定義された lhs、lhs_batching_dimensions、rhs、rhs_batching_dimensions の入力など)の 2 つのソースから取得されます。
前述のとおり、この数式の構文は Python ベースですが、簡潔さを重視した拡張機能がいくつかあります。この数式を理解するために、標準の Python 構文に変換してみましょう。
A)これらの式では、等価性を表すために = を使用しています。Python 構文を取得する最初のステップは、次のように = を == に置き換えることです。dim(lhs, lhs_batching_dimensions...) == dim(rhs, rhs_batching_dimensions...)。
B)また、これらの式は、スカラー式をテンソル式に変換する省略記号(...)をサポートしています。簡単に言うと、f(xs...) は「テンソル xs の各スカラー x について、スカラー f(x) を計算し、これらのスカラー結果をすべてテンソル結果としてまとめて返す」という意味です。標準の Python 構文では、この数式は [dim(lhs, dim1) for dim1 in lhs_batching_dimensions] ==
[dim(rhs, dim2) for dim2 in rhs_batching_dimensions] になります。
省略記号を使用すると、個々のスカラー レベルで作業する必要がなくなることがよくあります。ただし、複雑なケースでは、gather 仕様の start_indices[bi0, ..., :, ..., biN] 式のように、下位レベルの準形式構文が使用されることがあります。簡潔にするため、このような構文をプレーンな Python に変換するための正確な形式主義は提供しません。ケースバイケースで直感的に理解できることを期待しています。特定の数式がわかりにくい場合は、改善を試みますのでお知らせください。
また、数式では、テンソル、テンソルのリスト(可変個数のテンソルから生じる可能性があるものなど)など、あらゆる種類のリストを拡張するために省略記号が使用されています。これも、正確な形式主義が提供されていない(リストは StableHLO 型システムの一部ですらない)領域であり、直感的な理解に依存しています。
C)最後に、暗黙的なブロードキャストという注目すべき表記法を紹介します。StableHLO opset は暗黙的なブロードキャストをサポートしていませんが、簡潔にするために、数式はサポートしています。簡単に言うと、テンソルが想定されるコンテキストでスカラーが使用されると、スカラーは想定される形状にブロードキャストされます。
dot_general の例を続けると、別の制約は 0 <= lhs_batching_dimensions < rank(lhs) になります。dot_general 仕様で定義されているように、lhs_batching_dimensions はテンソルですが、0 と rank(lhs) はスカラーです。暗黙的なブロードキャストを適用すると、式は [0, ..., 0] <= lhs_batching_dimensions < [rank(lhs), ..., rank(lhs)] になります。
特定の dot_general オペレーションに適用すると、この式はブール値のテンソルに評価されます。式が制約として使用される場合、式が true または true 要素のみを含むテンソルと評価されると、制約が保持されます。
名前
数式では、レキシカル スコープに 1) グローバル関数、2) メンバー定義が含まれます。
3)ローカル定義。グローバル関数のリストを以下に示します。要素定義のリストは、表記が適用されるプログラム要素によって異なります。
- オペレーションの場合、メンバー定義には「入力」セクションと「出力」セクションで導入された名前が含まれます。
- それ以外の場合、メンバー定義には、対応する EBNF 非終端記号にちなんで名付けられたプログラム要素の構造部分が含まれます。ほとんどの場合、これらの構造部分の名前は、非終端記号の名前をスネークケースに変換することで取得されます(例:
IntegerLiteral=>integer_literal)。ただし、場合によっては、その過程で名前が省略されることがあります(例:QuantizationStorageType=>storage_type)。その場合は、オペレーション仕様の「入力」 / 「出力」セクションと同様に、名前が明示的に導入されます。 - また、メンバー定義には、対応するプログラム要素を参照する
selfが常に含まれます。
値
数式は、評価時に次の種類の値を使用します。1)Value(実際の値。例: dense<[[1, 2], [3, 4]]> : tensor<2x2xi32>。常に型がわかっています)、2)Placeholder(将来の値。例: lhs、rhs、result。実際の値はまだわからず、型のみがわかっています)、3)Type(「型」セクションで定義されている型)、4)Function(「関数」セクションで定義されているグローバル関数)。
コンテキストによっては、名前が異なる値を参照している場合があります。具体的には、ops の「セマンティクス」セクション(他のプログラム要素の同等のセクション)でランタイム ロジックが定義されているため、すべての入力は Value として使用できます。一方、演算(および同等のもの)の「制約」セクションでは、「コンパイル時」ロジック(通常は実行前に実行されるもの)が定義されるため、Value として使用できるのは定数入力のみで、その他の入力は Placeholder としてのみ使用できます。
| 名前 | 「セマンティクス」 | 「制約」 |
|---|---|---|
| グローバル機能 | Function |
Function |
| 定数入力 | Value |
Value |
| 非定数入力 | Value |
Placeholder |
| 出力 | Value |
Placeholder |
| ローカル定義 | 定義によって異なる | 定義によって異なる |
transpose オペレーションの例を考えてみましょう。
%result = "stablehlo.transpose"(%operand) {
permutati<on = dens>e[2, 1, 0<] : t>ensor3xi64
}< : (tenso>r2x>3x2xi32<) - tenso>r2x3x2xi32
このオペレーションでは、permutation は定数であるため、セマンティクスと制約の両方で Value として使用できます。一方、operand と result は、セマンティクスでは Value として使用できますが、制約では Placeholder としてのみ使用できます。
関数
型の構築
型を構築するために使用できる関数はありません。代わりに、通常はより簡潔な型構文を直接使用します。例: function_type(
[tensor_type([], E), tensor_type([], E)], [tensor_type([], E)]) ではなく (tensor<E>, tensor<E>) -> (tensor<E>)。
型に対する関数
element_typeはテンソル型と量子化テンソル型で定義され、それぞれ対応するTensorTypeまたはQuantizedTensorTypeのTensorElementTypeまたはQuantizedTensorElementType部分を返します。
def element_type(x: Value | Placeholder | Type):
if type(x) == TensorType:
return tensor_element_type(x)
if type(x) == QuantizedTensorType:
return quantized_tensor_element_type(x)
if type(x) is not Type:
return element_type(type(x))
is_per_axis_quantized(x: Value | Placeholder | Type) -> Valueはis_quantized(x) and quantization_dimension(x) is not Noneのショートカットです。is_per_tensor_quantized(x: Value | Placeholder | Type) -> Valueはis_quantized(x) and quantization_dimension(x) is Noneのショートカットです。is_promotable(x: Type, y: Type) -> boolは、型xを型yに昇格できるかどうかを確認します。xとyがQuantizedTensorElementTypeの場合、プロモーションはstorage_typeにのみ適用されます。この特定のバージョンのプロモーションは、現在、削減計算のコンテキストで使用されています(詳細については、RFC を参照してください)。
def is_promotable(x: Type, y: Type) -> Value:
is_same_type = (is_bool(x) and is_bool(y)) or
(is_integer(x) and is_integer(y)) or (is_float(x) and is_float(y)) or
(is_complex(x) and is_complex(y)) or
(is_quantized(x) and is_quantized(y) and expressed_type(x) = expressed_type(y))
if is_same_type == False:
return False
if is_integer(x) or is_float(x):
return bitwidth(x) <= bitwidth(y)
if is_complex(x):
return bitwidth(element_type(x)) <= bitwidth(element_type(y))
if is_quantized(x):
return bitwidth(storage_type(x)) <= bitwidth(storage_type(y))
return false
is_quantized(x: Value | Placeholder | Type) -> Valueはis_quantized_tensor_element_type(x)のショートカットです。is_type_name(x: Value | Placeholder | Type) -> Value。すべてのタイプで使用できます。たとえば、xがFloatTypeの場合、is_float(x)はtrueを返します。xが値またはプレースホルダの場合、この関数はis_type_name(type(x))のショートカットです。max_value(x: Type) -> ValueはTensorElementTypeの最大値を返します。xがTensorElementTypeでない場合は、Noneを返します。min_value(x: Type) -> Valueは、TensorElementTypeの最小値を返します。xがTensorElementTypeでない場合は、Noneを返します。member_name(x: Value | Placeholder | Type) -> Any。すべての型のすべてのメンバー定義member_nameで使用できます。たとえば、tensor_element_type(x)は対応するTensorTypeのTensorElementType部分を返します。xが値またはプレースホルダの場合、この関数はmember_name(type(x))のショートカットです。xが適切なメンバーを持つ型、またはそのような型の値やプレースホルダでない場合は、Noneを返します。is_empty_algorithm(*args: Type)は、すべてのドット アルゴリズム フィールドがNoneに設定されているかどうかを確認します。ドット アルゴリズムには実装で定義されたデフォルトの動作があるため、デフォルト値を指定すると正しくありません。
値の構築
operation_name(*xs: Value | Type) -> Value。すべてのオペレーションで使用できます。たとえば、add(lhs, rhs)は 2 つのテンソル値lhsとrhsを受け取り、これらの入力でaddオペレーションを評価した結果を返します。broadcast_in_dimなどの一部のオペレーションでは、出力の型は「耐荷重」です。つまり、オペレーションの評価に必要です。この場合、関数はこれらの型を引数として受け取ります。
値の関数
Python のすべての演算子と関数を使用できます。たとえば、Python の サブスクリプション表記と スライス表記の両方を使用して、テンソル、量子化テンソル、タプルにインデックスを付けることができます。
to_destination_type(x: Value, destination_type: Type) -> Valueはテンソルで定義され、type(x)とdestination_typeに基づいてxの変換値を次のように返します。
def to_destination_type(x: Value, destination_type: Type) -> Value:
if type(x) == destination_type:
return x
if is_quantized(destination_type):
if is_quantized(type(x)):
return quantize(x, destination_type)
assert is_float(type(x))
return quantize(x, destination_type)
if is_quantized(type(x)):
assert destination_type = expressed_type(type(x))
return dequantize(type(x))
return convert(x, destination_type)
convert、uniform_quantize、uniform_dequantize オペレーションの統合に関する初期の議論があります(#1576)。マージ後は上記の関数は不要になり、代わりに convert のオペレーション名を使用できます。
is_nan(x: Value) -> Valueはテンソルで定義され、xのすべての要素がNaNの場合はtrueを返し、それ以外の場合はfalseを返します。xがテンソルでない場合は、Noneを返します。is_sorted(x: Value) -> Valueはテンソルで定義され、xの要素がインデックスの昇順の辞書式順序で昇順に並べ替えられている場合はtrueを返し、それ以外の場合はfalseを返します。xがテンソルでない場合は、Noneを返します。is_unique(x: Value) -> Valueはテンソルで定義され、xに重複する要素がない場合はtrueを返し、それ以外の場合はfalseを返します。xがテンソルでない場合は、Noneを返します。member_name(x: Value) -> Anyは、すべての値のすべてのメンバー定義member_nameに対して定義されます。たとえば、real_part(x)は対応するComplexConstantのRealPart部分を返します。xが適切なメンバーを持つ値でない場合、Noneを返します。same(x: Value) -> Valueはテンソルで定義され、xの要素がすべて等しい場合はtrueを返し、それ以外の場合はfalseを返します。テンソルに要素がない場合、「すべてが互いに等しい」と見なされ、関数はtrueを返します。xがテンソルでない場合は、Noneを返します。split(x: Value, num_results: Value, axis: Value) -> Valueはテンソルで定義され、軸axisに沿ってxのnum_resultsスライスを返します。xがテンソルまたはdim(x, axis) % num_results != 0でない場合は、Noneを返します。is_defined_in_parent_scope(x: Value) -> Valueは文字列で定義され、xが関連する op の親関数と同じスコープで定義された関数の名前である場合はtrueを返します。is_namespaced_op_name(x: Value) -> Valueは文字列で定義され、xが有効な op 名(つまり、次の正規表現[a-zA-Z][a-zA-Z0-9_]*([.][a-zA-Z0-9_$]+)+を尊重している)の場合にtrueを返します。
形状の計算
axes(x: Value | Placeholder | Type) -> Valueはrange(rank(x))のショートカットです。dim(x: Value | Placeholder | Type, axis: Value) -> Valueはshape(x)[axis]のショートカットです。dims(x: Value | Placeholder | Type, axes: List) -> Listはlist(map(lambda axis: dim(x, axis), axes))のショートカットです。index_space(x: Value | Placeholder | Type) -> Valueはテンソルで定義され、対応するTensorTypeのsize(x)インデックスを辞書順の昇順([0, ..., 0]、[0, ..., 1]、...、shape(x) - 1)で返します。xがテンソル型、量子化テンソル型、またはこれらの型のいずれかの値またはプレースホルダでない場合、Noneを返します。rank(x: Value | Placeholder | Type) -> Valueはsize(shape(x))のショートカットです。shape(x: Value | Placeholder | Type) -> Valueは、member_nameを介して「型に対する関数」セクションで定義されています。size(x: Value | Placeholder | Type) -> Valueはreduce(lambda x, y: x * y, shape(x))のショートカットです。
量子化の計算
def baseline_element_type(x: Value | Placeholder | Type) -> Typeはelement_type(baseline_type(x))のショートカットです。baseline_typeは、テンソル型と量子化テンソル型で定義され、それらを「ベースライン」、つまり同じ形状で要素型の量子化パラメータがデフォルト値にリセットされた型に変換します。これは、テンソルと量子化テンソルの両方の型を均一に比較するための便利なトリックとして使用されます。これは、かなり頻繁に必要になります。量子化された型の場合、これにより、量子化パラメータを無視して型を比較できます。つまり、shape、storage_type、expressed_type、storage_min、storage_max、quantization_dimension(軸ごとの量子化された型の場合)はすべて一致する必要がありますが、scalesとzero pointsは異なる場合があります。
def baseline_type(x: Value | Placeholder | Type) -> Type:
if type(x) == TensorType:
return x
if type(x) == QuantizedTensorType:
element_type = quantized_tensor_element_type(x)
baseline_element_type = QuantizedTensorElementType(
storage_type = storage_type(element_type),
storage_min = storage_min(element_type),
storage_max = storage_max(element_type),
expressed_type = expressed_type(element_type),
quantization_dimension = quantization_dimension(element_type),
scales = [constant(1.0, expressed_type(element_type))] * dim(x, quantization_dimension(element_type)),
zero_points = [constant(0, storage_type(element_type))] * dim(x, quantization_dimension(element_type)))
return QuantizedTensorType(shape(x), baseline_element_type)
if type(x) is not Type:
return baseline_element_type(type(x))
dequantizeは量子化テンソル型で定義され、それらを浮動小数点テンソル型に変換します。これは、量子化された要素の型に関連付けられたゼロ点とスケールを使用して、ストレージ型の整数値を表す量子化された要素を、表現型の対応する浮動小数点値に変換することで行われます。
def compute_zero_points(quantized_type, result_type):
if is_per_tensor_quantized(quantized_type):
return broadcast_in_dim(constant(zero_point(quantized_type), storage_type(quantized_type)), [], result_type)
if is_per_axis_quantized(quantized_type):
for i in index_space(result_type):
d = quantization_dimension(quantized_type)
zero_points[i] = zero_points(quantized_type)[i[d]]
return zero_points
def compute_scales(quantized_type, result_type):
if is_per_tensor_quantized(quantized_type):
return broadcast_in_dim(constant(scale(quantized_type), expressed_type(quantized_type)), [],
type(result_type))
if is_per_axis_quantized(quantized_type):
for i in index_space(result_type):
d = quantization_dimension(quantized_type)
scales[i] = scales(quantized_type)[i[d]]
return scales
def dequantize(x: Value) -> Value:
assert is_quantized(x)
x_storage = bitcast_convert(x, storage_type(x))
x_storage_sub = x_storage - compute_zero_points(type(x), type(x_storage))
x_expressed_sub = convert(x_storage_sub, expressed_type(x))
return x_expressed_sub * compute_scales(type(x), type(x_expressed_sub))
quantizeは浮動小数点テンソル型で定義され、それらを量子化テンソル型に変換します。これは、量子化された要素タイプに関連付けられたゼロ点とスケールを使用して、表現された型の浮動小数点値をストレージ型の対応する整数値に変換することで行われます。
def quantize(x: Value, result_type: Type) -> Value:
assert is_float(x) and is_quantized(result_type)
zero_points = compute_zero_points(result_type, TensorType(shape(x), storage_type(result_type)))
converted_zero_points = convert(zero_points, expressed_type(result_type))
converted_min = convert(storage_min(result_type), expressed_type(result_type))
converted_max = convert(storage_max(result_type), expressed_type(result_type))
x_scaled = x / compute_scales(result_type, type(x))
x_scaled_add_zp = x_scaled + converted_zero_points
x_clamped = clamp(converted_min, x_scaled_add_zp, converted_max)
x_rounded = round_nearest_even(x_clamped)
return convert(x_rounded, result_type)
dequantize_op_quantizeは、量子化されたテンソルに対する要素ごとの計算を指定するために使用されます。逆量子化(量子化された要素を表現型に変換)してから演算を実行し、量子化(結果をストレージ型に戻す)します。現時点では、この関数はテンソルごとの量子化でのみ機能します。軸ごとの量子化は開発中です(#1574)。
def dequantize_op_quantize(op, *inputs_and_output_type):
inputs = inputs_and_output_type[:-1]
output_type = inputs_and_output_type[-1]
float_inputs = map(dequantize, inputs)
float_result = op(*float_inputs)
return quantize(float_result, output_type)
def dequantize_batch_norm_grad_or_training_quantize(op, *inputs_and_output_types):
inputs = inputs_and_output_type[:-3]
float_inputs = map(dequantize, inputs)
float_results = op(*float_inputs)
return map(quantize, float_results, inputs_and_output_type[-3:])
def dequantize_compare(lhs, rhs, comparison_direction):
float_lhs = dequantize(lhs)
float_rhs = dequantize(rhs)
return compare(float_lhs, float_rhs, comparison_direction, FLOAT)
def dequantize_select_quantize(pred, on_true, on_false, output_type):
float_on_true = dequantize(on_true)
float_on_false = dequantize(on_false)
float_result = select(pred, float_on_true, float_on_false)
return quantize(float_result, output_type)
hybrid_dequantize_then_opは、lhs を浮動小数点型で、rhs を量子化型で受け入れるハイブリッド オペレーションの重みのみの量子化を指定するために使用されます。量子化された入力を表現型に逆量子化し、浮動小数点演算を実行します。浮動小数点 lhs テンソルの要素型と量子化された rhs テンソルの表現型は同じである必要があります。
def hybrid_dequantize_then_op(op, lhs, rhs):
assert(is_float(lhs) and is_quantized(rhs) and element_type(lhs) == expressed_type(rhs))
return op(lhs, dequantize(rhs))
グリッド コンピューティング
cross_partition(replica_groups: Value) -> Value。上記の「cross_replica」セクションをご覧ください。cross_replica(replica_groups: Value) -> Value。上記の「cross_replica」セクションをご覧ください。cross_replica_and_partition(replica_groups: Value) -> Value。上記の「cross_replica_and_partition」セクションをご覧ください。flattened_ids(replica_groups: Value) -> Value。上記の「flattened_ids」セクションをご覧ください。
ダイナミズム
StableHLO 値は、tensor<?xi64> などの動的ディメンション サイズを持つことができます。ただし、StableHLO 値は動的な次元数(ランクなしの動的性、例: tensor<*xi64>)を持つことはできません。オペランドと結果は、サイズに制約がある場合でも、動的な次元サイズを使用できます。制約は、可能であれば静的に検証されます。それ以外の場合は、実行時に延期され、不一致があると未定義の動作が発生します。例については以下をご覧ください。
単項要素ごとのオペレーションのシェイプの不一致
次の簡単なプログラムを考えてみましょう。
func.func @foo(%arg0: tensor<?xf64>) {
%0 = stablehlo.abs %arg0 : (tensor<?xf64>) -> tensor<2xf64>
return
}
このようなプログラムは一般的ではありません。結果の形状はわかっているが、入力の形状はわからないという状況は一般的ではないからです。ただし、これは有効な StableHLO プログラムです。オペランドの正確な形状が不明なため、このプログラムで abs オペレーションを静的に検証することはできません。ただし、形状は確実に互換性があり、これは静的にチェックできます。? は実行時に 2 になる可能性があり、問題は発生しません。ただし、? が他の整数になる可能性もあります。その場合の動作は未定義です。
結果でディメンション サイズが動的である場合、未定義の動作は発生しません。実際には「想定される」サイズがないため、不一致は発生しません。
バイナリ要素ごとのオペレーションのシェイプの不一致
次の簡単なプログラムを考えてみましょう。
func.func @foo(%arg0: tensor<?xf64>, %arg1: tensor<?xf64>) {
%0 = stablehlo.add %arg0, %arg0 : (tensor<?xf64>, tensor<?xf64>) -> tensor<?xf64>
return
}
バイナリ要素ごとの演算の場合、入力と結果の形状は実行時に一致している必要があります。コンパイル時には、静的ディメンションは等しくなければなりませんが、それ以外の場合は互換性があるだけで構いません。入力に動的なディメンションがある場合、実行時に未定義の動作が発生する可能性があります。これは、動的なサイズが他のオペランド(静的または動的)の対応するサイズと一致しない可能性があるためです。すべての入力が静的である場合、結果が動的かどうかは関係ありません。静的に既知のディメンションは静的にチェックされ、動的ディメンションには制約が課されません。
出力シェイプをオペランドとして受け取るオペレーションのシェイプの不一致
次の簡単なプログラムを考えてみましょう。
func.func @foo(%arg0: tensor<2xi32>) {
%0 = stablehlo.dynamic_iota %arg0, dim = 0 : (tensor<2xi32>) -> tensor<3x4xi64>
return
}
実行時のシェイプ オペランドの値は、結果のシェイプと一致する必要があります。一致しない場合、動作は未定義になります。つまり、実行時に %arg0 の値は dense<[3, 4]> : tensor<2xi32> になります。形状オペランドが定数の場合、これは静的に検証できます。結果の形状が完全に動的である場合、不一致は発生しません。