StableHLO は、機械学習(ML)モデルの高レベル オペレーション(HLO)用のオペレーション セットです。StableHLO は、異なる ML フレームワークと ML コンパイラ間のポータビリティ レイヤとして機能します。StableHLO プログラムを生成する ML フレームワークは、StableHLO プログラムを使用する ML コンパイラと互換性があります。
Google の目標は、さまざまな ML フレームワーク(TensorFlow、JAX、PyTorch など)と ML コンパイラ(XLA、IREE など)間の相互運用性を高めることで、ML 開発を簡素化し、加速させることです。このドキュメントでは、その目的のために StableHLO プログラミング言語の仕様を説明します。
この仕様には、3 つの主要なセクションがあります。まず、プログラムのセクションでは、StableHLO 関数で構成される StableHLO プログラムの構造について説明します。StableHLO 関数自体は StableHLO オペレーションで構成されます。この構造内の [Ops] セクションには、個々のオペレーションのセマンティクスを指定します。[実行] セクションには、プログラム内で一緒に実行されるこれらのオペレーションのすべてのセマンティクスが示されています。最後に、表記セクションでは、仕様全体で使用される表記について説明します。
StableHLO の以前のリリースの仕様を表示するには、目的のタグ付きリリースでリポジトリを開きます。たとえば、StableHLO v0.19.0 仕様。StableHLO のマイナー バージョンの増加ごとに発生した変更を確認するには、VhloDialect.td のバージョンログをご覧ください。
プログラム
Program ::= {Func}
StableHLO プログラムは、任意の数の StableHLO 関数で構成されます。以下は、3 つの入力(%image
、%weights
、%bias
)と 1 つの出力を持つ関数 @main
を含むプログラムの例です。関数本体には 6 つのオペレーションがあります。
func.func @main(
%image: tensor<28x28xf32>,
%weights: tensor<784x10xf32>,
%bias: tensor<1x10xf32>
) -> tensor<1x10xf32> {
%0 = "stablehlo.reshape"(%image) : (tensor<28x28xf32>) -> tensor<1x784xf32>
%1 = "stablehlo.dot"(%0, %weights) : (tensor<1x784xf32>, tensor<784x10xf32>) -> tensor<1x10xf32>
%2 = "stablehlo.add"(%1, %bias) : (tensor<1x10xf32>, tensor<1x10xf32>) -> tensor<1x10xf32>
%3 = "stablehlo.constant"() {value = dense<0.0> : tensor<1x10xf32>} : () -> tensor<1x10xf32>
%4 = "stablehlo.maximum"(%2, %3) : (tensor<1x10xf32>, tensor<1x10xf32>) -> tensor<1x10xf32>
"func.return"(%4): (tensor<1x10xf32>) -> ()
}
関数
Func ::= 'func' '.' 'func' FuncId FuncInputs FuncOutputs '{' FuncBody '}'
FuncInputs ::= '(' [FuncInput {',' FuncInput}] `)`
FuncInput ::= ValueId ':' ValueType
FuncOutputs ::= ['->' FuncOutput, {',' FuncOutput}]
FuncOutput ::= ValueType
FuncBody ::= {Op}
StableHLO 関数(名前付き関数とも呼ばれます)には、識別子、入出力、本文があります。今後、HLO との互換性を高めるために、関数に追加のメタデータを導入する予定です(#425、#626、#740、#744)。
識別子
FuncId ::= '@' letter {letter | digit}
ValueId ::= '%' digit {digit}
| '%' letter {letter | digit}
letter ::= 'a' | ... | 'z' | 'A' | ... | 'Z' | '_'
digit ::= '0' | ... | '9'
StableHLO 識別子は多くのプログラミング言語の識別子に似ていますが、2 つの特殊性があります。1)すべての識別子には、さまざまな種類の識別子を区別するシグルが付いています。2)値識別子は完全に数値にすることができ、StableHLO プログラムの生成を簡素化できます。
型
Type ::= ValueType | NonValueType
ValueType ::= TensorType | QuantizedTensorType | TokenType | TupleType
NonValueType ::= TensorElementType | QuantizedTensorElementType | FunctionType | StringType
StableHLO 型は、StableHLO 値を表す値の型(ファーストクラス型とも呼ばれます)と、他のプログラム要素を記述する値以外の型に分類されます。StableHLO 型は多くのプログラミング言語の型に似ていますが、主な特徴は StableHLO のドメイン固有の性質であり、異常な結果をもたらします(例: スカラー型は値型ではありません)。
TensorType ::= 'tensor' '<' Shape TensorElementType '>'
Shape ::= {DimensionSize 'x'}
DimensionSize ::= digit {digit} | '?'
テンソル型はテンソル、つまり多次元配列を表します。シェイプと要素型があります。シェイプは、0
~R-1
の番号が付けられた対応するディメンション(軸とも呼ばれます)の昇順で、正または不明のディメンション サイズを表します。ディメンション R
の数はランクと呼ばれます。たとえば、tensor<2x3xf32>
は形状が 2x3
で要素型が f32
のテンソル型です。0 番目のディメンションと 1 番目のディメンションの 2 つのディメンション(つまり 2 つの軸)があり、サイズは 2 と 3 です。ランクは 2 です。
シェイプは、部分的または完全に未知(動的)にできます。たとえば、tensor<?x2xf64>
は部分的に未知、tensor<?x?xf64>
は完全に未知です。動的ディメンションのサイズは ?
を使用して表されます。シェイプのランク付けを解除することはできません。
今後、テンソル型をディメンション サイズと要素型以外にも拡張し、レイアウト(#629)やスパース性(#1078)などを含める予定です。
QuantizedTensorType ::= 'tensor' '<' Shape QuantizedTensorElementType '>'
QuantizedTensorElementType ::= '!quant.uniform' '<'
QuantizationStorageType
['<' QuantizationStorageMin ':' QuantizationStorageMax '>']
':' QuantizationExpressedType
[':' QuantizationDimension]
',' QuantizationParameters '>'
QuantizationStorageType ::= IntegerType
QuantizationStorageMin ::= IntegerLiteral
QuantizationStorageMax ::= IntegerLiteral
QuantizationExpressedType ::= FloatType
QuantizationDimension ::= IntegerLiteral
QuantizationParameters ::= QuantizationParameter
| '{' QuantizationParameter {',' QuantizationParameter} '}'
QuantizationParameter ::= QuantizationScale [':' QuantizationZeroPoint]
QuantizationScale ::= FloatLiteral
QuantizationZeroPoint ::= IntegerLiteral
名前 | タイプ | 制約 |
---|---|---|
storage_type |
整数型 | (C1 ~ C3)、(C8) |
storage_min |
整数定数 | (C1)、(C3)、(C7) |
storage_max |
整数定数 | (C2)、(C3)、(C7) |
expressed_type |
floating-point-type | (C4) |
quantization_dimension |
省略可能な整数定数 | (C10-C12) |
scales |
浮動小数点定数の可変数 | (C4 ~ C6)、(C9)、(C10)、(C13) |
zero_points |
可変長整数定数 | (C7 ~ C9) |
量子化要素の型は、表現型の浮動小数点値に対応する、storage_min
~storage_max
(両端を含む)の範囲のストレージ型の整数値を表します。指定された整数値 i
について、対応する浮動小数点値 f
は f = (i - zero_point) * scale
として計算できます。ここで、scale
と zero_point
は量子化パラメータと呼ばれます。storage_min
と storage_max
は文法では省略可能ですが、デフォルト値はそれぞれ min_value(storage_type)
と max_value(storage_type)
です。量子化要素タイプには次の制約があります。
- (C1)
type(storage_min) = storage_type
。 - (C2)
type(storage_max) = storage_type
。 - (C3)
min_value(storage_type) <= storage_min < storage_max <= max_value(storage_type)
。 - (C4)
type(scales...) = expressed_type
。 - (C5)
0 < scales
。 - (C6)
is_finite(scales...)
。 - (C7)
storage_min <= zero_points <= storage_max
。 - (C8)
type(zero_points...) = storage_type
。 - (C9)
size(scales) = size(zero_points)
。 - (C10)
is_empty(quantization_dimension)
の場合はsize(scales) = 1
です。 - (C11)
0 <= quantization_dimension
。
現在、QuantizationScale
は浮動小数点定数ですが、乗数とシフトで表される整数ベースのスケールに強い関心があります。近い将来、この問題を調査する予定です(#1404)。
QuantizationZeroPoint
のセマンティクス(型、値、量子化されたテンソル型にゼロポイントが 1 つだけか複数ある可能性があるかなど)については、現在も議論が続いています。このディスカッションの結果に基づいて、ゼロポイントに関する仕様は今後変更される可能性があります(#1405)。
現在進行中のもう一つの議論では、QuantizationStorageMin
と QuantizationStorageMax
のセマンティクスについて、これらの値と量子化テンソルの値に制約を適用すべきかどうかを判断しています(#1406)。
最後に、未知のディメンション サイズの表現(#1407)と同様に、未知のスケールとゼロポイントの表現を検討する予定です。
量子化テンソル型は、量子化された要素を持つテンソルを表します。これらのテンソルは、通常の要素型ではなく、量子化された要素型を持つ点を除き、通常のテンソルとまったく同じです。
量子化されたテンソルでは、量子化はテンソル単位で行うことができます。つまり、テンソル全体に 1 つの scale
と zero_point
を設定します。または、軸単位で行うこともできます。つまり、特定のディメンション quantization_dimension
のスライスごとに 1 組の scales
と zero_points
を設定します。より正式には、軸ごとの量子化を使用するテンソル t
には、quantization_dimension
の dim(t, quantization_dimension)
スライス(t[:, ..., 0, ..., :], t[:, ..., 1, ..., :]
など)があります。i
番目のスライスのすべての要素は、量子化パラメータとして scales[i]
と zero_points[i]
を使用します。量子化テンソル型には次の制約があります。
- テンソルごとの量子化の場合:
- 追加の制約はありません。
- 軸ごとの量子化の場合:
- (C12)
quantization_dimension < rank(self)
。 - (C13)
dim(self, quantization_dimension) = size(scales)
。
- (C12)
TokenType ::= 'token'
トークン型は、トークン(一部のオペレーションによって生成および消費される不透明な値)を表します。トークンは、実行セクションで説明されているように、オペレーションに実行順序を適用するために使用されます。
TupleType ::= 'tuple' '<' TupleElementTypes '>'
TupleElementTypes ::= [ValueType {',' ValueType}]
タプル型は、タプル(異種リスト)を表します。タプルは、HLO との互換性を確保するためにのみ存在するレガシー機能です。HLO では、タプルを使用して可変長の入力と出力を表します。StableHLO では、可変長の入力と出力がネイティブにサポートされています。StableHLO でタプルを使用するのは、HLO ABI を包括的に表す場合のみです。たとえば、T
、tuple<T>
、tuple<tuple<T>>
は特定の実装によって大きく異なる場合があります。今後、HLO ABI を変更し、StableHLO からタプル型を削除する予定です(#598)。
TensorElementType ::= BooleanType | IntegerType | FloatType | ComplexType
BooleanType ::= 'i1'
IntegerType ::= SignedIntegerType | UnsignedIntegerType
SignedIntegerType ::= 'si2' | 'si4' | 'si8' | 'si16' | 'si32' | 'si64'
UnsignedIntegerType ::= 'ui2' | 'ui4' | 'ui8' | 'ui16' | 'ui32' | 'ui64'
FloatType ::= 'f4E2M1FN' | 'f6E2M3FN' | 'f6E3M2FN' | 'f8E3M4' | 'f8E4M3'
| 'f8E4M3FN' | 'f8E4M3FNUZ' | 'f8E4M3B11FNUZ' | 'f8E5M2'
| 'f8E5M2FNUZ' | 'f8E8M0FNU' | 'bf16' | 'f16' | 'f32' | 'f64'
TensorFloat32 ::= 'tf32'
ComplexType ::= 'complex' '<' ComplexElementType '>'
ComplexElementType ::= 'f32' | 'f64'
要素型は、テンソル型の要素を表します。多くのプログラミング言語とは異なり、これらの型は StableHLO ではファーストクラスではありません。つまり、StableHLO プログラムでは、これらの型の値を直接表現できません(その結果、T
型のスカラー値を tensor<T>
型の 0 次元テンソル値で表現するのが一般的です)。
- ブール値型は、ブール値
true
とfalse
を表します。 - 整数型は、符号付き(
si
)または符号なし(ui
)のいずれかであり、サポートされているビット幅(2
、4
、8
、16
、32
、64
)のいずれかです。符号付きsiN
型は、-2^(N-1)
~2^(N-1)-1
の整数値を表し、符号なしuiN
型は、0
~2^N-1
の整数値を表します。 - 浮動小数点型は次のいずれかになります。
f8E3M4
、f8E4M3
、f8E5M2
: IEEE-754 規則に従う 8 ビットの浮動小数点数。f8E4M3FN
型とf8E5M2
型は、ディープラーニング用の FP8 形式で説明されている FP8 形式のE4M3
エンコードとE5M2
エンコードに対応しています。f8E4M3FNUZ
型とf8E5M2FNUZ
型は、ディープ ニューラル ネットワーク用の 8 ビット数値形式で説明されている FP8 形式のE4M3
エンコードとE5M2
エンコードに対応しています。- ハイブリッド 8 ビット浮動小数点数(HFP8)トレーニングとディープ ニューラル ネットワークの推論で説明されている FP8 形式の
E4M3
エンコードに対応するf8E4M3B11FNUZ
タイプ。 - BFloat16: Cloud TPU で高パフォーマンスを発揮させる秘訣で説明されている
bfloat16
形式に対応するbf16
型。 f16
、f32
、f64
型は、IEEE 754 標準で説明されているbinary16
(「半精度」)、binary32
(「単精度」)、binary64
(「倍精度」)の各形式に対応しています。tf32
型は TensorFloat32 形式に対応しており、StableHLO では限定的にサポートされています。- OCP マイクロスケーリング形式の仕様で説明されている
f4E2M1FN
、f6E2M3FN
、f6E3M2FN
、f8E8M0FNU
MX(マイクロスケーリング)タイプ。
- 複素型は、同じ要素型の実部と虚部を持つ複素値を表します。サポートされている複合型は
complex<f32>
(どちらもf32
型)とcomplex<f64>
(どちらもf64
型)です。
FunctionType ::= '(' InputTypes ')' '->' '(' OutputTypes ')'
InputTypes ::= [ValueType {',' ValueType}]
OutputTypes ::= [ValueType {',' ValueType}]
関数型は、名前付き関数と匿名関数の両方を表します。入力型(->
の左側の型のリスト)と出力の型(->
の右側の型のリスト)があります。多くのプログラミング言語では関数型がファースト クラスですが、StableHLO ではそうではありません。
StringType ::= 'string'
String 型はバイトのシーケンスを表します。多くのプログラミング言語とは異なり、StableHLO では文字列型はファーストクラスではなく、プログラム要素の静的メタデータを指定する場合にのみ使用されます。
運用
StableHLO オペレーション(オペレーションとも呼ばれる)は、ML モデルのクローズド セットの大まかなオペレーションを表します。前述のように、StableHLO 構文は MLIR に大きく影響を受けています。これは必ずしも最も人間工学的な代替手段ではありませんが、ML フレームワークと ML コンパイラ間の相互運用性を高めるという StableHLO の目標には最適です。
Op ::= [OpOutputs] OpName OpInputs ':' OpSignature
OpName ::= '"' 'stablehlo' '.' OpMnemonic '"'
OpMnemonic ::= 'abs' | 'add' | ...
StableHLO オペレーション(オペレーションとも呼ばれる)には、名前、入出力、署名があります。この名前は、stablehlo.
接頭辞と、サポートされているいずれかのオペレーションを一意に識別するニーモニックで構成されます。サポートされているすべてのオペレーションの一覧については、以下をご覧ください。
OpInputs ::= OpInputValues OpInputFuncs OpInputAttrs
OpInputValues ::= '(' [OpInputValue {',' OpInputValue}] ')'
OpInputValue ::= ValueId
OpInputFuncs ::= ['(' OpInputFunc {',' OpInputFunc} ')']
OpInputAttrs ::= ['{' OpInputAttr {',' OpInputAttr} '}']
OpOutputs ::= [OpOutput {',' OpOutput} '=']
OpOutput ::= ValueId
オペレーションは入力を使用し、出力を生成します。入力は、入力値(実行時に計算される)、入力関数(StableHLO では関数がファーストクラス値ではないため、静的に提供される)、入力属性(これも静的に提供される)に分類されます。op によって消費、生成される入力と出力の種類は、そのニモニックによって異なります。たとえば、add
オペレーションは 2 つの入力値を使用し、1 つの出力値を生成します。これに対し、select_and_scatter
op は 3 つの入力値、2 つの入力関数、3 つの入力属性を使用します。
OpInputFunc ::= '{' Unused FuncInputs ':' FuncBody '}'
Unused ::= '^' digit {digit}
| '^' letter {letter | digit}
入力関数(別名: 匿名関数)は、名前付き関数と非常によく似ていますが、1)識別子がない(「匿名」という名前の由来)ことと、2)出力型を宣言しない(出力型は関数内の return
オペレーションから推論される)点が異なります。
入力関数の構文には、MLIR との互換性を確保するために現在使用されていない部分(上記の Unused
生産を参照)が含まれています。MLIR には、ジャンプ演算子で接続された複数の演算子の「ブロック」を持つことができる、より一般的な「リージョン」のコンセプトがあります。これらのブロックには、Unused
本番環境に対応する ID が付いているため、ブロック同士を区別できます。StableHLO にはジャンプ演算がないため、MLIR 構文の対応する部分は使用されません(ただし、まだ存在します)。
OpInputAttr ::= OpInputAttrName '=' OpInputAttrValue
OpInputAttrName ::= letter {letter | digit}
OpInputAttrValue ::= Constant
入力属性には、名前と値があります。値は、サポートされている定数の一つです。プログラム要素の静的メタデータを指定する主な方法です。たとえば concatenate
演算では、属性 dimension
を使用して、入力値を連結するディメンションを指定します。同様に、slice
演算は start_indices
や limit_indices
などの複数の属性を使用して、入力値をスライスするために使用される境界を指定します。
現時点では、実際の StableHLO プログラムに、このドキュメントで説明されていない属性が含まれていることがあります。今後、これらの属性を StableHLO オペセットに吸収するか、StableHLO プログラムに出現しないようにする予定です。それまでは以下の属性のリストをご覧ください。
layout
(#629)。mhlo.frontend_attributes
(#628)。mhlo.sharding
(#619)。output_operand_aliases
(#740)。- 位置情報メタデータ(#594)。
OpSignature ::= '(' [ValueType {',' ValueType}] ')' '->' '(' [ValueType {',' ValueType}] ')'
オペレーション シグネチャは、すべての入力値の型(->
の左側の型のリスト)とすべての出力値の型(->
の右側の型のリスト)で構成されます。厳密に言えば、入力型は冗長であり、出力型もほとんどの場合冗長です(ほとんどの StableHLO オペレーションでは、出力型は入力から推測できるため)。ただし、op シグネチャは、MLIR との互換性を確保するために、意図的に StableHLO 構文の一部になっています。
以下に、メモニクスが select_and_scatter
の演算子の例を示します。3 つの入力値(%operand
、%source
、%init_value
)、2 つの入力関数、3 つの入力属性(window_dimensions
、window_strides
、padding
)を使用します。op のシグネチャには入力値の型のみが含まれます(インラインで提供される入力関数と属性の型は含まれません)。
%result = "stablehlo.select_and_scatter"(%operand, %source, %init_value) ({
^bb0(%arg0: tensor<i32>, %arg1: tensor<i32>):
%0 = "stablehlo.compare"(%arg0, %arg1) {
comparison_direction = #stablehlo<comparison_direction GE>
} : (tensor<i32>, tensor<i32>) -> tensor<i1>
"stablehlo.return"(%0) : (tensor<i1>) -> ()
}, {
^bb0(%arg0: tensor<i32>, %arg1: tensor<i32>):
%0 = "stablehlo.add"(%arg0, %arg1) : (tensor<i32>, tensor<i32>) -> tensor<i32>
"stablehlo.return"(%0) : (tensor<i32>) -> ()
}) {
window_dimensions = dense<[3, 1]> : tensor<2xi64>,
window_strides = dense<[2, 1]> : tensor<2xi64>,
padding = dense<[[0, 1], [0, 0]]> : tensor<2x2xi64>
} : (tensor<4x2xi32>, tensor<2x2xi32>, tensor<i32>) -> tensor<4x2xi32>
定数
Constant ::= BooleanConstant
| IntegerConstant
| FloatConstant
| ComplexConstant
| TensorConstant
| QuantizedTensorConstant
| StringConstant
| EnumConstant
StableHLO 定数には、StableHLO 値を表すリテラルと型があります。一般に、型は定数の構文の一部ですが、明確な場合を除きます(たとえば、ブール値定数は明確に型 i1
ですが、整数定数には複数の型が考えられます)。
BooleanConstant ::= BooleanLiteral
BooleanLiteral ::= 'true' | 'false'
ブール定数は、ブール値 true
と false
を表します。ブール値定数は i1
型です。
IntegerConstant ::= IntegerLiteral ':' IntegerType
IntegerLiteral ::= ['-' | '+'] DecimalDigits
| ['-' | '+'] '0x' HexadecimalDigits
DecimalDigits ::= decimalDigit {decimalDigit}
HexadecimalDigits ::= hexadecimalDigit {hexadecimalDigit}
decimalDigit ::= '0' | ... | '9'
hexadecimalDigit ::= decimalDigit | 'a' | ... | 'f' | 'A' | ... | 'F'
整数定数は、10 進数または 16 進数表記の文字列を使用して整数値を表します。バイナリやオクタルなどの他の基数はサポートされていません。整数定数には次の制約があります。
- (C1)
is_wellformed(integer_literal, integer_type)
。
FloatConstant ::= FloatLiteral ':' FloatType
FloatLiteral ::= SignPart IntegerPart FractionalPart ScientificPart
| '0x' [HexadecimalDigits]
SignPart ::= ['-' | '+']
IntegerPart ::= DecimalDigits
FractionalPart ::= ['.' [DecimalDigits]]
ScientificPart ::= [('e' | 'E') ['-' | '+'] DecimalDigits]
浮動小数点定数は、10 進数または科学的記数法を使用する文字列で浮動小数点値を表します。また、16 進数表記を使用して、対応する型の浮動小数点形式で基盤となるビットを直接指定することもできます。浮動小数点定数には次の制約があります。
- (C1)16 進数以外の表記が使用されている場合、
is_wellformed(float_literal, float_type)
。 - (C2)16 進数表記を使用している場合、
size(hexadecimal_digits) = num_bits(float_type) / 4
。
ComplexConstant ::= ComplexLiteral ':' ComplexType
ComplexLiteral ::= '(' RealPart ',' ImaginaryPart ')'
RealPart ::= FloatLiteral
ImaginaryPart ::= FloatLiteral
複素定数は、実部(先頭)と虚部(末尾)のリストを使用して複素値を表します。たとえば、(1.0, 0.0) : complex<f32>
は 1.0 + 0.0i
を表し、(0.0, 1.0) : complex<f32>
は 0.0 + 1.0i
を表します。これらのパーツをメモリに保存する順序は実装で定義します。複素定数には次の制約があります。
- (C1)
is_wellformed(real_part, complex_element_type(complex_type))
。 - (C2)
is_wellformed(imaginary_part, complex_element_type(complex_type))
。
TensorConstant ::= TensorLiteral ':' TensorType
TensorLiteral ::= 'dense' '<' (DenseLiteral | ElementLiteral) '>'
DenseLiteral ::= DenseDimension | DenseElements
DenseDimension ::= '[' [DenseLiteral {',' DenseLiteral}] ']'
DenseElements ::= [ElementLiteral {',' ElementLiteral}]
ElementLiteral ::= BooleanLiteral | IntegerLiteral | FloatLiteral | ComplexLiteral
テンソル定数は、NumPy の表記法で指定されるネストされたリストを使用してテンソル値を表します。たとえば、dense<[[1, 2, 3], [4, 5, 6]]> : tensor<2x3xi32>
は、インデックスから要素への次のマッピングを持つテンソル値を表します。{0, 0} => 1
、{0, 1} => 2
、{0, 2} => 3
、{1, 0} => 4
、{1, 1} => 5
、{1, 2} => 6
。これらの要素がメモリに格納される順序は実装によって定義されます。テンソル定数には次の制約があります。
- (C1)
has_syntax(tensor_literal, element_type(tensor_type))
。ここで、has_syntax(element_literal: Syntax, element_type: Type) = is_wellformed(element_literal, type)
。has_syntax(tensor_literal: List, element_type: Type) = has_syntax(tensor_literal..., element_type)
。
- (C2)
has_shape(tensor_literal, shape(tensor_type))
。次に例を示します。has_shape(element_literal: Syntax, []) = true
。has_shape(tensor_literal: List, shape: List) = size(tensor_literal) = shape[0] and has_shape(tensor_literal..., shape[1:])
。- それ以外の場合は
false
。
QuantizedTensorConstant ::= QuantizedTensorLiteral ':' QuantizedTensorType
QuantizedTensorLiteral ::= 'dense' '<' (DenseLiteral | ElementLiteral) '>'
量子化テンソル定数は、テンソル定数と同じ表記を使用して量子化テンソル値を表します。要素はストレージ タイプの定数として指定されます。量子化されたテンソル定数には次の制約があります。
- (C1)
has_syntax(quantized_tensor_literal, storage_type(quantized_tensor_type))
。 - (C2)
has_shape(quantized_tensor_literal, shape(quantized_tensor_type))
。
StringConstant ::= StringLiteral
StringLiteral ::= '"' {stringCharacter | escapeSequence} '"'
stringCharacter ::= all ASCII characters except '\00', '\01', ... '\1f' and '"'
escapeSequence ::= '\' ('"' | '\' | 'n' | 't' | (hexadecimalDigit hexadecimalDigit))
文字列リテラルは、ASCII 文字とエスケープ シーケンスを使用して指定されたバイトで構成されます。エンコードに依存しないため、これらのバイトの解釈は実装で定義されます。文字列リテラルの型は string
です。
運用
abs
セマンティクス
operand
テンソルに対して要素ごとの絶対値演算を実行し、result
テンソルを生成します。要素のタイプに応じて、次の処理を行います。
- 符号付き整数の場合: 整数の剰余。
- 浮動小数点数: IEEE-754 の
abs
。 - 複素数の場合: 複素数モジュラス。
- 量子化された型の場合:
dequantize_op_quantize(abs, operand, type(result))
。
入力
ラベル | 名前 | タイプ | 制約 |
---|---|---|---|
(I1) | operand |
符号付き整数型、浮動小数点型、複合型のテンソル、またはテンソルごとの量子化テンソル | (C1 ~ C2) |
出力
名前 | タイプ | 制約 |
---|---|---|
result |
符号付き整数型または浮動小数点型のテンソル、またはテンソルごとの量子化テンソル | (C1-C2) |
制約
- (C1)
shape(result) = shape(operand)
。 - (C2)
baseline_element_type(result)
は次のように定義されます。is_complex(operand)
の場合、complex_element_type(element_type(operand))
。- そうでない場合は
baseline_element_type(operand)
。
例
// %operand: [-2, 0, 2]
%result = "stablehlo.abs"(%operand) : (tensor<3xi32>) -> tensor<3xi32>
// %result: [2, 0, 2]
追加
セマンティクス
2 つのテンソル lhs
と rhs
の要素ごとの加算を行い、result
テンソルを生成します。要素のタイプに応じて、次の処理を行います。
- ブール値の場合: 論理 OR。
- 整数の場合: 整数の加算。
- 浮動小数点数の場合: IEEE-754 の
addition
。 - 複素数の場合: 複素数加算。
- 量子化された型の場合:
dequantize_op_quantize(add, lhs, rhs, type(result))
。
入力
ラベル | 名前 | タイプ | 制約 |
---|---|---|---|
(I1) | lhs |
テンソルまたは量子化テンソル | (C1~C6) |
(I2) | rhs |
テンソルまたは量子化テンソル | (C1 ~ C5)、(C7) |
出力
名前 | タイプ | 制約 |
---|---|---|
result |
テンソルまたは量子化テンソル | (C1~C7) |
制約
- オペレーションで量子化されていないテンソルを使用する場合:
- (C1)
type(lhs) = type(rhs) = type(result)
。
- (C1)
- オペレーションで量子化テンソルを使用する場合:
- (C2)
is_quantized(lhs) and is_quantized(rhs) and is_quantized(result)
。 - (C3)
storage_type(lhs) = storage_type(rhs) = storage_type(result)
。 - (C4)
expressed_type(lhs) = expressed_type(rhs) = expressed_type(result)
。 - (C5)
(is_per_axis_quantized(lhs) or is_per_axis_quantized(rhs)) = is_per_axis_quantized(result)
。 - (C6)
is_per_axis_quantized(lhs)
の場合はquantization_dimension(lhs) = quantization_dimension(result)
です。 - (C7)
is_per_axis_quantized(rhs)
の場合はquantization_dimension(rhs) = quantization_dimension(result)
です。
- (C2)
例
// %lhs: [[1, 2], [3, 4]]
// %rhs: [[5, 6], [7, 8]]
%result = "stablehlo.add"(%lhs, %rhs) : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
// %result: [[6, 8], [10, 12]]
after_all
セマンティクス
inputs
を生成するオペレーションが、result
に依存するオペレーションの前に実行されるようにします。このオペレーションを実行しても何も行われません。このオペレーションは、result
から inputs
へのデータ依存関係を確立するためにのみ存在します。
入力
ラベル | 名前 | タイプ |
---|---|---|
(I1) | inputs |
可変長の token の数 |
出力
名前 | タイプ |
---|---|
result |
token |
例
// %input0: !stablehlo.token
// %input1: !stablehlo.token
%result = "stablehlo.after_all"(%input0, %input1) : (!stablehlo.token, !stablehlo.token) -> !stablehlo.token
all_gather
セマンティクス
StableHLO プロセス グリッドの各プロセス グループ内で、各プロセスの operands
テンソルの値を all_gather_dim
に沿って連結し、results
テンソルを生成します。
このオペレーションでは、StableHLO プロセス グリッドが、次のように定義された process_groups
に分割されます。
channel_id <= 0 and use_global_device_ids = false
の場合、cross_replica(replica_groups)
channel_id > 0 and use_global_device_ids = false
の場合、cross_replica_and_partition(replica_groups)
channel_id > 0 and use_global_device_ids = true
の場合、flattened_ids(replica_groups)
その後、各 process_group
内で次の操作を行います。
process_group
のすべてのreceiver
に対してoperands...@receiver = [operand@sender for sender in process_group]
。process_group
のすべてのprocess
に対してresults...@process = concatenate(operands...@process, all_gather_dim)
。
入力
ラベル | 名前 | タイプ | 制約 |
---|---|---|---|
(I1) | operands |
テンソルの可変数またはテンソルごとの量子化テンソル | (C1)、(C6) |
(I2) | all_gather_dim |
si64 型の定数 |
(C1)、(C6) |
(I3) | replica_groups |
si64 型の 2 次元のテンソル定数 |
(C2 ~ C4) |
(I4) | channel_id |
si64 型の定数 |
(C5) |
(I5) | use_global_device_ids |
i1 型の定数 |
(C5) |
出力
名前 | タイプ | 制約 |
---|---|---|
results |
テンソルの可変数またはテンソルごとの量子化テンソル | (C6) |
制約
- (C1)
0 <= all_gather_dim < rank(operands...)
。 - (C2)
is_unique(replica_groups)
。 - (C3)
size(replica_groups)
は次のように定義されます。cross_replica
を使用する場合はnum_replicas
。cross_replica_and_partition
を使用する場合はnum_replicas
。flattened_ids
を使用する場合はnum_processes
。
- (C4)
0 <= replica_groups < size(replica_groups)
。 - (C5)
use_global_device_ids = true
の場合、channel_id > 0
。 - (C6)
type(results...) = type(operands...)
(ただし、次の場合は除く)dim(results..., all_gather_dim) = dim(operands..., all_gather_dim) * dim(process_groups, 1)
。
例
// num_replicas: 2
// num_partitions: 1
// %operand0@(0, 0): [[1, 2], [3, 4]]
// %operand0@(1, 0): [[5, 6], [7, 8]]
// %operand1@(0, 0): [[11, 12], [13, 14]]
// %operand1@(1, 0): [[15, 16], [17, 18]]
%result:2 = "stablehlo.all_gather"(%operand0, %operand1) {
all_gather_dim = 1 : i64,
replica_groups = dense<[[0, 1]]> : tensor<1x2xi64>,
// channel_id = 0
channel_handle = #stablehlo.channel_handle<handle = 0, type = 0>
// use_global_device_ids = false
} : (tensor<2x2xi64>, tensor<2x2xi64>) -> (tensor<2x4xi64>, tensor<2x4xi64>)
// %result0@(0, 0): [[1, 2, 5, 6], [3, 4, 7, 8]]
// %result0@(1, 0): [[1, 2, 5, 6], [3, 4, 7, 8]]
// %result1@(0, 0): [[11, 12, 15, 16], [13, 14, 17, 18]]
// %result1@(1, 0): [[11, 12, 15, 16], [13, 14, 17, 18]]
all_reduce
セマンティクス
StableHLO プロセス グリッドの各プロセス グループ内で、各プロセスの operands
テンソルの値にリダクション関数 computation
を適用し、results
テンソルを生成します。
このオペレーションは、StableHLO プロセス グリッドを次のように定義された process_groups
に分割します。
channel_id <= 0 and use_global_device_ids = false
の場合、cross_replica(replica_groups)
channel_id > 0 and use_global_device_ids = false
の場合、cross_replica_and_partition(replica_groups)
channel_id > 0 and use_global_device_ids = true
の場合、flattened_ids(replica_groups)
その後、各 process_group
内で次の操作を行います。
results...@process[result_index] = exec(schedule)
は任意のバイナリ ツリーです。schedule
は次のとおりです。exec(node)
=computation(exec(node.left), exec(node.right))
。exec(leaf)
=leaf.value
。
schedule
は、順序付き走査がto_destination_type(operands...@process_group...[result_index], type(func_inputs(computation)[0]))
である実装定義のバイナリツリーです。
入力
ラベル | 名前 | タイプ | 制約 |
---|---|---|---|
(I1) | operands |
テンソルの可変数またはテンソルごとの量子化テンソル | (C5)、(C6) |
(I2) | replica_groups |
si64 型の 1 次元テンソル定数の可変数 |
(C1 ~ C3) |
(I3) | channel_id |
si64 型の定数 |
(C4) |
(I4) | use_global_device_ids |
i1 型の定数 |
(C4) |
(I5) | computation |
関数 | (C5) |
出力
名前 | タイプ | 制約 |
---|---|---|
results |
テンソルの可変数またはテンソルごとの量子化テンソル | (C6~C7) |
制約
- (C1)
is_unique(replica_groups)
。 - (C2)
size(replica_groups)
は次のように定義されます。cross_replica
を使用する場合はnum_replicas
。cross_replica_and_partition
を使用する場合はnum_replicas
。flattened_ids
を使用する場合はnum_processes
。
- (C3)
0 <= replica_groups < size(replica_groups)
。 - (C4)
use_global_device_ids = true
の場合、channel_id > 0
。 - (C5)
computation
の型は(tensor<E>, tensor<E>) -> (tensor<E>)
で、is_promotable(element_type(operand), E)
です。 - (C6)
shape(results...) = shape(operands...)
。 - (C7)
element_type(results...) = E
。
例
// num_replicas: 2
// num_partitions: 1
// %operand0@(0, 0): [1, 2, 3, 4]
// %operand0@(1, 0): [5, 6, 7, 8]
// %operand1@(0, 0): [9, 10, 11, 12]
// %operand1@(1, 0): [13, 14, 15, 16]
%result:2 = "stablehlo.all_reduce"(%operand0, %operand0) ({
^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
%0 = "stablehlo.add"(%arg0, %arg1) : (tensor<i64>, tensor<i64>) -> tensor<i64>
"stablehlo.return"(%0) : (tensor<i64>) -> ()
}) {
replica_groups = dense<[[0, 1]]> : tensor<1x2xi64>,
// channel_id = 0
channel_handle = #stablehlo.channel_handle<handle = 0, type = 0>
// use_global_device_ids = false
} : (tensor<4xi64>, tensor<4xi64>) -> (tensor<4xi64>, tensor<4xi64>)
// %result0@(0, 0): [6, 8, 10, 12]
// %result0@(1, 0): [6, 8, 10, 12]
// %result1@(0, 0): [22, 24, 26, 28]
// %result1@(1, 0): [22, 24, 26, 28]
all_to_all
セマンティクス
StableHLO プロセス グリッド内の各プロセス グループ内で、split_dimension
に沿って operands
テンソルの値を分割し、分割された部分をプロセス間で分散し、分散された部分を concat_dimension
に沿って連結して results
テンソルを生成します。このオペレーションでは、StableHLO プロセス グリッドが、次のように定義された process_groups
に分割されます。
channel_id <= 0
の場合、cross_replica(replica_groups)
。channel_id > 0
の場合はcross_partition(replica_groups)
。
その後、各 process_group
内で次の操作を行います。
process_group
内のすべてのsender
に対してsplit_parts...@sender = split(operands...@sender, split_count, split_dimension)
。scattered_parts...@receiver = [split_parts...@sender[receiver_index] for sender in process_group]
はreceiver_index = process_group.index(receiver)
です。results...@process = concatenate(scattered_parts...@process, concat_dimension)
。
入力
ラベル | 名前 | タイプ | 制約 |
---|---|---|---|
(I1) | operands |
テンソルの可変数またはテンソルごとの量子化テンソル | (C1 ~ C3)、(C9) |
(I2) | split_dimension |
si64 型の定数 |
(C1)、(C2)、(C9) |
(I3) | concat_dimension |
si64 型の定数 |
(C3)、(C9) |
(I4) | split_count |
si64 型の定数 |
(C2)、(C4)、(C8)、(C9) |
(I5) | replica_groups |
si64 型の 2 次元テンソル定数 |
(C5~C8) |
(I6) | channel_id |
si64 型の定数 |
出力
名前 | タイプ | 制約 |
---|---|---|
results |
テンソルの可変数またはテンソルごとの量子化テンソル | (C9) |
制約
- (C1)
0 <= split_dimension < rank(operands...)
。 - (C2)
dim(operands..., split_dimension) % split_count = 0
。 - (C3)
0 <= concat_dimension < rank(operands...)
。 - (C4)
0 < split_count
。 - (C5)
is_unique(replica_groups)
。 - (C6)
size(replica_groups)
は次のように定義されます。cross_replica
を使用する場合はnum_replicas
。cross_partition
を使用する場合はnum_partitions
。
- (C7)
0 <= replica_groups < size(replica_groups)
。 - (C8)
dim(replica_groups, 1) = split_count
。 - (C9)
type(results...) = type(operands...)
(ただし、split_dimension != concat_dimension
の場合を除く):dim(results..., split_dimension) = dim(operands..., split_dimension) / split_count
。dim(results..., concat_dimension) = dim(operands..., concat_dimension) * split_count
。
例
// num_replicas: 2
// num_partitions: 1
// %operand1@(0, 0): [[1, 2, 3, 4],
// [5, 6, 7, 8]]
// %operand1@(1, 0): [[9, 10, 11, 12],
// [13, 14, 15, 16]]
// %operand2@(0, 0): [[17, 18, 19, 20],
// [21, 22, 23, 24]]
// %operand2@(1, 0): [[25, 26, 27, 28],
// [29, 30, 31, 32]]
%result:2 = "stablehlo.all_to_all"(%operand1, %operand2) {
split_dimension = 1 : i64,
concat_dimension = 0 : i64,
split_count = 2 : i64,
replica_groups = dense<[[0, 1]]> : tensor<1x2xi64>
// channel_id = 0
} : (tensor<2x4xi64>, tensor<2x4xi64>) -> (tensor<4x2xi64>, tensor<4x2xi64>)
// %result#0@(0, 0): [[1, 2], [5, 6], [9, 10], [13, 14]]
// %result#0@(1, 0): [[3, 4], [7, 8], [11, 12], [15, 16]]
// %result#1@(0, 0): [[17, 18], [21, 22], [25, 26], [29, 30]]
// %result#1@(1, 0): [[19, 20], [23, 24], [27, 28], [31, 32]]
と
セマンティクス
2 つのテンソル lhs
と rhs
の要素ごとの AND を実行し、result
テンソルを生成します。要素のタイプに応じて、次の処理を行います。
- ブール値の場合: 論理 AND。
- 整数の場合: ビット演算 AND。
入力
ラベル | 名前 | タイプ | 制約 |
---|---|---|---|
(I1) | lhs |
ブール型または整数型のテンソル | (C1) |
(I2) | rhs |
ブール値型または整数型のテンソル | (C1) |
出力
名前 | タイプ | 制約 |
---|---|---|
result |
ブール型または整数型のテンソル | (C1) |
制約
- (C1)
type(lhs) = type(rhs) = type(result)
。
例
// %lhs: [[1, 2], [3, 4]]
// %rhs: [[5, 6], [7, 8]]
%result = "stablehlo.and"(%lhs, %rhs) : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
// %result: [[1, 2], [3, 0]]
atan2
セマンティクス
lhs
テンソルと rhs
テンソルの要素ごとの atan2 演算を実行し、result
テンソルを生成します。要素のタイプに応じて、次の処理を行います。
- 浮動小数点数の場合: IEEE-754 の
atan2
。 - 複素数の場合: 複素 atan2。
- 量子化された型の場合:
dequantize_op_quantize(atan2, lhs, rhs, type(result))
。
入力
ラベル | 名前 | タイプ | 制約 |
---|---|---|---|
(I1) | lhs |
浮動小数点型または複合型のテンソル、またはテンソルごとの量子化テンソル | (C1) |
(I2) | rhs |
浮動小数点型または複合型のテンソル、またはテンソルごとの量子化テンソル | (C1) |
出力
名前 | タイプ | 制約 |
---|---|---|
result |
浮動小数点型または複合型のテンソル、またはテンソルごとの量子化テンソル | (C1) |
制約
- (C1)
baseline_type(lhs) = baseline_type(rhs) = baseline_type(result)
。
例
// %lhs: [0.0, 1.0, -1.0]
// %rhs: [0.0, 0.0, 0.0]
%result = "stablehlo.atan2"(%lhs, %rhs) : (tensor<3xf64>, tensor<3xf64>) -> tensor<3xf64>
// %result: [0.0, 1.57079637, -1.57079637] // [0.0, pi/2, -pi/2]
batch_norm_grad
セマンティクス
grad_output
からバックプロパゲートされる batch_norm_training
の複数の入力の勾配を計算し、grad_operand
、grad_scale
、grad_offset
テンソルを生成します。より正式には、このオペレーションは、次のように Python 構文を使用して既存の StableHLO オペレーションへの分解として表現できます。
def compute_sum(operand, feature_index):
(sum,) = reduce(
inputs=[operand],
init_values=[constant(0, element_type(operand))],
dimensions=[i for i in range(rank(operand)) if i != feature_index],
body=lambda x, y: add(x, y))
return sum
def compute_mean(operand, feature_index):
sum = compute_sum(operand, feature_index)
divisor = constant(size(operand) / dim(operand, feature_index),
element_type(operand))
divisor_bcast = broadcast_in_dim(divisor, [], type(sum))
return divide(sum, divisor_bcast)
def batch_norm_grad(operand, scale, mean, variance, grad_output, epsilon, feature_index):
# Broadcast inputs to type(operand)
scale_bcast = broadcast_in_dim(scale, [feature_index], type(operand))
mean_bcast = broadcast_in_dim(mean, [feature_index], type(operand))
variance_bcast = broadcast_in_dim(variance, [feature_index], type(operand))
epsilon_bcast = broadcast_in_dim(constant(epsilon, element_type(operand)), [],
type(operand))
# Perform normalization using the provided `mean` and `variance`
# Intermediate values will be useful for computing gradients
centered_operand = subtract(operand, mean_bcast)
stddev = sqrt(add(variance_bcast, epsilon_bcast))
normalized_operand = divide(centered_operand, stddev)
# Use the implementation from batchnorm_expander.cc in XLA
# Temporary variables have exactly the same names as in the C++ code
elements_per_feature = broadcast_in_dim(
constant(divide(size(operand), dim(operand, feature_index)),
element_type(grad_output)),
[], type(operand))
i1 = multiply(grad_output, elements_per_feature)
i2 = broadcast_in_dim(
compute_sum(grad_output, feature_index), [feature_index], type(operand))
i3 = broadcast_in_dim(
compute_sum(multiply(grad_output, centered_operand), feature_index),
[feature_index], type(operand))
i4 = multiply(i3, centered_operand)
i5 = divide(i4, add(variance_bcast, epsilon_bcast))
i6 = subtract(subtract(i1, i2), i5)
grad_operand =
multiply(divide(divide(scale_bcast, stddev), elements_per_feature), i6)
grad_scale =
compute_sum(multiply(grad_output, normalized_operand), feature_index)
grad_offset = compute_sum(grad_output, feature_index)
return grad_operand, grad_scale, grad_offset
量子化された型の場合は、dequantize_batch_norm_grad_or_training_quantize(lambda operand, scale, mean,
variance, grad_output: batch_norm_grad(operand, scale, mean, variance,
grad_output, epsilon, feature_index), operand, scale, mean, variance,
grad_output, type(grad_operand), type(grad_scale), type(feature_index))
を実行します。
入力
ラベル | 名前 | タイプ | 制約 |
---|---|---|---|
(I1) | operand |
浮動小数点型のテンソルまたはテンソルごとの量子化テンソル | (C1~C3)、(C5) |
(I2) | scale |
浮動小数点型またはテンソルごとの量子化型の 1 次元テンソル | (C2)、(C4)、(C5) |
(I3) | mean |
浮動小数点型またはテンソルごとの量子化型の 1 次元テンソル | (C2)、(C4) |
(I4) | variance |
浮動小数点型またはテンソルごとの量子化型の 1 次元テンソル | (C2)、(C4) |
(I5) | grad_output |
浮動小数点型のテンソルまたはテンソルごとの量子化テンソル | (C2)、(C3) |
(I6) | epsilon |
f32 型の定数 |
|
(I7) | feature_index |
si64 型の定数 |
(C1)、(C5) |
出力
名前 | タイプ | 制約 |
---|---|---|
grad_operand |
浮動小数点型のテンソルまたはテンソルごとの量子化テンソル | (C2)、(C3) |
grad_scale |
浮動小数点数またはテンソルごとの量子化型の 1 次元テンソル | (C2)、(C4) |
grad_offset |
浮動小数点型またはテンソルごとの量子化型の 1 次元テンソル | (C2)、(C4) |
制約
- (C1)
0 <= feature_index < rank(operand)
。 - (C2)
operand
、scale
、mean
、variance
、grad_output
、grad_operand
、grad_scale
、grad_offset
のbaseline_element_type
が同じである。 - (C3)
operand
、grad_output
、grad_operand
は同じ形状です。 - (C4)
scale
、mean
、variance
、grad_scale
、grad_offset
の形状が同じである。 - (C5)
size(scale) = dim(operand, feature_index)
。
例
// %operand: [
// [[1.0, 2.0], [3.0, 4.0]],
// [[3.0, 4.0], [1.0, 2.0]]
// ]
// %scale: [1.0, 1.0]
// %mean: [2.0, 3.0]
// %variance: [1.0, 1.0]
// %grad_output: [
// [[0.1, 0.1], [0.1, 0.1]],
// [[0.1, 0.1], [0.1, 0.1]]
// ]
%grad_operand, %grad_scale, %grad_offset =
"stablehlo.batch_norm_grad"(%operand, %scale, %mean, %variance, %grad_output) {
epsilon = 0.0 : f32,
feature_index = 2 : i64
} : (tensor<2x2x2xf64>, tensor<2xf64>, tensor<2xf64>, tensor<2xf64>,
tensor<2x2x2xf64>) -> (tensor<2x2x2xf64>, tensor<2xf64>, tensor<2xf64>)
// %grad_operand: [
// [[0.0, 0.0], [0.0, 0.0]],
// [[0.0, 0.0], [0.0, 0.0]]
// ]
// %grad_scale: [0.0, 0.0]
// %grad_offset: [0.4, 0.4]
batch_norm_inference
セマンティクス
feature_index
次元を除くすべての次元で operand
テンソルを正規化し、result
テンソルを生成します。より正式には、このオペレーションは、Python 構文を使用して既存の StableHLO オペレーションへの分解として次のように表現できます。
def batch_norm_inference(operand, scale, offset, mean, variance, epsilon, feature_index):
# Broadcast inputs to shape(operand)
scale_bcast = broadcast_in_dim(scale, [feature_index], type(operand))
offset_bcast = broadcast_in_dim(offset, [feature_index], type(operand))
mean_bcast = broadcast_in_dim(mean, [feature_index], type(operand))
variance_bcast = broadcast_in_dim(variance, [feature_index], type(operand))
epsilon_bcast = broadcast_in_dim(constant(epsilon, element_type(operand)), [],
type(operand))
# Perform normalization using the provided `mean` and `variance` instead of
# computing them like `batch_norm_training` does.
centered_operand = subtract(operand, mean_bcast)
stddev = sqrt(add(variance_bcast, epsilon_bcast))
normalized_operand = divide(centered_operand, stddev)
return add(multiply(scale_bcast, normalized_operand), offset_bcast)
量子化された型の場合は、dequantize_op_quantize(lambda operand, scale, offset, mean, variance:
batch_norm_inference(operand, scale, offset, mean, variance, epsilon,
feature_index), operand, scale, offset, mean, variance, type(result))
を実行します。
入力
ラベル | 名前 | タイプ | 制約 |
---|---|---|---|
(I1) | operand |
浮動小数点型のテンソルまたはテンソルごとの量子化テンソル | (C1~C7) |
(I2) | scale |
浮動小数点数またはテンソルごとの量子化型の 1 次元テンソル | (C2)、(C3) |
(I3) | offset |
浮動小数点型またはテンソルごとの量子化型の 1 次元テンソル | (C2)、(C4) |
(I4) | mean |
浮動小数点数またはテンソルごとの量子化型の 1 次元テンソル | (C5) |
(I5) | variance |
浮動小数点数またはテンソルごとの量子化型の 1 次元テンソル | (C2)、(C6) |
(I6) | epsilon |
f32 型の定数 |
|
(I7) | feature_index |
si64 型の定数 |
(C1)、(C3 ~ C6) |
出力
名前 | タイプ | 制約 |
---|---|---|
result |
浮動小数点型のテンソルまたはテンソルごとの量子化テンソル | (C2)、(C7) |
制約
- (C1)
0 <= feature_index < rank(operand)
。 - (C2)
operand
、scale
、offset
、mean
、variance
、result
には同じbaseline_element_type
があります。 - (C3)
size(scale) = dim(operand, feature_index)
。 - (C4)
size(offset) = dim(operand, feature_index)
。 - (C5)
size(mean) = dim(operand, feature_index)
。 - (C6)
size(variance) = dim(operand, feature_index)
。 - (C7)
baseline_type(operand) = baseline_type(result)
。
例
// %operand: [
// [[1.0, 2.0], [3.0, 4.0]],
// [[3.0, 4.0], [1.0, 2.0]]
// ]
// %scale: [1.0, 1.0]
// %offset: [1.0, 1.0]
// %mean: [2.0, 3.0]
// %variance: [1.0, 1.0]
%result = "stablehlo.batch_norm_inference"(%operand, %scale, %offset, %mean, %variance) {
epsilon = 0.0 : f32,
feature_index = 2 : i64
} : (tensor<2x2x2xf64>, tensor<2xf64>, tensor<2xf64>, tensor<2xf64>, tensor<2xf64>) -> tensor<2x2x2xf64>
// %result: [
// [[0.0, 0.0], [2.0, 2.0]],
// [[2.0, 2.0], [0.0, 0.0]]
// ]
batch_norm_training
セマンティクス
feature_index
ディメンション以外のすべてのディメンションで平均と分散を計算し、operand
テンソルを正規化して output
、batch_mean
、batch_var
テンソルを生成します。より正式には、このオペレーションは、Python 構文を使用して既存の StableHLO オペレーションへの分解として次のように表現できます。
def compute_mean(operand, feature_index):
(sum,) = reduce(
inputs=[operand],
init_values=[constant(0, element_type(operand))],
dimensions=[i for i in range(rank(operand)) if i != feature_index],
body=lambda x, y: add(x, y))
divisor = constant(size(operand) / dim(operand, feature_index),
element_type(operand))
divisor_bcast = broadcast_in_dim(divisor, [], type(sum))
return divide(sum, divisor_bcast)
def compute_variance(operand, feature_index):
mean = compute_mean(operand, feature_index)
mean_bcast = broadcast_in_dim(mean, [feature_index], type(operand))
centered_operand = subtract(operand, mean_bcast)
return compute_mean(mul(centered_operand, centered_operand), feature_index)
def batch_norm_training(operand, scale, offset, epsilon, feature_index):
mean = compute_mean(operand, feature_index)
variance = compute_variance(operand, feature_index)
return batch_norm_inference(operand, scale, offset, mean, variance, epsilon,
feature_index),
mean, variance
量子化された型の場合は、dequantize_batch_norm_grad_or_training_quantize(lambda operand, scale, offset:
batch_norm_training(operand, scale, offset, epsilon, feature_index), operand,
scale, offset, type(output), type(batch_mean), type(batch_var))
を実行します。
入力
ラベル | 名前 | タイプ | 制約 |
---|---|---|---|
(I1) | operand |
浮動小数点型のテンソルまたはテンソルごとの量子化テンソル | (C1) |
(I2) | scale |
浮動小数点またはテンソルごとの量子化された 1 次元テンソル | (C2)、(C3) |
(I3) | offset |
浮動小数点またはテンソルごとの量子化された 1 次元テンソル | (C2)、(C4) |
(I4) | epsilon |
f32 型の定数 |
(C1)、(C3 ~ C6) |
(I5) | feature_index |
si64 型の定数 |
(C1)、(C3 ~ C6) |
出力
名前 | タイプ | 制約 |
---|---|---|
output |
浮動小数点型のテンソルまたはテンソルごとの量子化テンソル | (C7) |
batch_mean |
浮動小数点数またはテンソルごとに量子化された 1 次元テンソル | (C2)、(C5) |
batch_var |
浮動小数点数またはテンソルごとに量子化された 1 次元テンソル | (C2)、(C6) |
制約
- (C1)
0 <= feature_index < rank(operand)
。 - (C2)
operand
、scale
、offset
、batch_mean
、batch_var
、output
は同じbaseline_element_type
を持ちます。 - (C3)
size(scale) = dim(operand, feature_index)
。 - (C4)
size(offset) = dim(operand, feature_index)
。 - (C5)
size(batch_mean) = dim(operand, feature_index)
。 - (C6)
size(batch_var) = dim(operand, feature_index)
。 - (C7)
baseline_type(output) = baseline_type(operand)
。
例
// %operand: [
// [[1.0, 2.0], [3.0, 4.0]],
// [[3.0, 4.0], [1.0, 2.0]]
// ]
// %scale: [1.0, 1.0]
// %offset: [1.0, 1.0]
%output, %batch_mean, %batch_var = "stablehlo.batch_norm_training"(%operand, %scale, %offset) {
epsilon = 0.0 : f32,
feature_index = 2 : i64
} : (tensor<2x2x2xf64>, tensor<2xf64>, tensor<2xf64>) ->
(tensor<2x2x2xf64>, tensor<2xf64>, tensor<2xf64>)
// %output: [
// [[0.0, 0.0], [2.0, 2.0]],
// [[2.0, 2.0], [0.0, 0.0]]
// ]
// %batch_mean: [2.0, 3.0]
// %batch_var: [1.0, 1.0]
bitcast_convert
セマンティクス
operand
テンソルにビットキャスト演算を実行し、result
テンソルを生成します。ここで、operand
テンソル全体のビットは result
テンソルの型を使用して再解釈されます。
より正式には、E = element_type(operand)
、E' = element_type(result)
、R = rank(operand)
を指定すると、次のようになります。
num_bits(E') < num_bits(E)
の場合はbits(result[i0, ..., iR-1, :]) = bits(operand[i0, ..., iR-1])
。num_bits(E') > num_bits(E)
の場合、bits(result[i0, ..., iR-2]) = bits(operand[i0, ..., iR-2, :])
。num_bits(E') = num_bits(E)
の場合はbits(result[i0, ..., iR-1]) = bits(operand[i0, ..., iR-1])
。
bits
は、指定された値のメモリ内表現を返します。テンソルの正確な表現は実装定義であり、要素型の正確な表現も実装定義であるため、その動作は実装定義です。
入力
ラベル | 名前 | タイプ | 制約 |
---|---|---|---|
(I1) | operand |
テンソルまたは量子化テンソル | (C1 ~ C2) |
出力
名前 | タイプ | 制約 |
---|---|---|
result |
テンソルまたは量子化テンソル | (C1 ~ C2) |
制約
- (C1)
E = is_quantized(operand) ? storage_type(operand) : element_type(operand)
、E' = is_quantized(result) ? storage_type(result) : element_type(result)
、R = rank(operand)
が指定されている場合:num_bits(E') = num_bits(E)
の場合はshape(result) = shape(operand)
。num_bits(E') < num_bits(E)
の場合:rank(result) = R + 1
。dim(result, i) = dim(operand, i)
はすべての0 <= i < R
に適用されます。dim(result, R) * num_bits(E') = num_bits(E)
。num_bits(E') > num_bits(E)
の場合:rank(result) = R - 1
。dim(result, i) = dim(operand, i)
はすべての0 <= i < R
に適用されます。dim(operand, R - 1) * num_bits(E) = num_bits(E')
。
- (C2)
is_complex(operand) or is_complex(result)
の場合はis_complex(operand) and is_complex(result)
です。
例
// %operand: 0x0123456789ABCDEF
%result = "stablehlo.bitcast_convert"(%operand) : (tensor<f64>) -> tensor<4xf16>
// %result: [0xCDEF, 0x89AB, 0x4567, 0x0123] // little-endian representation
broadcast_in_dim
セマンティクス
operand
テンソルにデータを複製することで入力テンソルの次元やランクを展開し、result
テンソルを生成します。より正式には、result[result_index] = operand[operand_index]
で、axes(operand)
内のすべての d
について次のようにします。
dim(operand, d) = 1
の場合、operand_index[d] = 0
。- それ以外の場合は
operand_index[d] = result_index[broadcast_dimensions[d]]
。
入力
ラベル | 名前 | タイプ | 制約 |
---|---|---|---|
(I1) | operand |
テンソルまたは量子化テンソル | (C1-C2)、(C5-C6) |
(I2) | broadcast_dimensions |
si64 型の 1 次元テンソル定数 |
(C2~C6) |
出力
名前 | タイプ | 制約 |
---|---|---|
result |
テンソルまたは量子化テンソル | (C1)、(C3)、(C5~C6) |
制約
- (C1)
element_type(result)
は次のように定義されます。element_type(operand)
(!is_per_axis_quantized(operand)
の場合)。element_type(operand)
(ただし、quantization_dimension(operand)
、scales(operand)
、zero_points(operand)
は、それぞれquantization_dimension(result)
、scales(result)
、zero_points(result)
と異なる場合があります)。
- (C2)
size(broadcast_dimensions) = rank(operand)
。 - (C3)
0 <= broadcast_dimensions < rank(result)
。 - (C4)
is_unique(broadcast_dimensions)
。 - (C5)
axes(operand)
内のすべてのd
について:dim(operand, d) = 1
またはdim(operand, d) = dim(result, broadcast_dimensions[d])
。
- (C6)
is_per_axis_quantized(result)
の場合:quantization_dimension(result) = broadcast_dimensions[quantization_dimension(operand)]
。dim(operand, quantization_dimension(operand)) = 1
の場合はscales(result)[i] = scales(operand)[0] and zero_points(result)[i] = zero_points(operand)[0] for i in range(dim(result, quantization_dimension(result)))
です。
例
// %operand: [
// [1, 2, 3]
// ]
%result = "stablehlo.broadcast_in_dim"(%operand) {
broadcast_dimensions = array<i64: 2, 1>
} : (tensor<1x3xi32>) -> tensor<2x3x2xi32>
// %result: [
// [
// [1, 1],
// [2, 2],
// [3, 3]
// ],
// [
// [1, 1],
// [2, 2],
// [3, 3]
// ]
// ]
ケース
セマンティクス
index
の値に応じて、branches
から関数を 1 つだけ実行することで、出力を生成します。より正式には、result = selected_branch()
のようにします。
0 <= index < size(branches)
の場合、selected_branch = branches[index]
。- それ以外の場合は
selected_branch = branches[-1]
。
入力
ラベル | 名前 | タイプ | 制約 |
---|---|---|---|
(I1) | index |
si32 型の 0 次元テンソル |
|
(I2) | branches |
関数の可変数 | (C1~C4) |
出力
名前 | タイプ | 制約 |
---|---|---|
results |
テンソル、量子化テンソル、またはトークンの可変数 | (C4) |
制約
- (C1)
0 < size(branches)
。 - (C2)
input_types(branches...) = []
。 - (C3)
same(output_types(branches...))
。 - (C4)
type(results...) = output_types(branches[0])
。
例
// %index: -1
// %result_branch0: [0, 0]
// %result_branch1: [1, 1]
%result0, %result1 = "stablehlo.case"(%index) ({
"stablehlo.return"(%result_branch0, %result_branch0) : (tensor<2xi64>, tensor<2xi64>) -> ()
}, {
"stablehlo.return"(%result_branch1, %result_branch1) : (tensor<2xi64>, tensor<2xi64>) -> ()
}) : (tensor<i32>) -> (tensor<2xi64>, tensor<2xi64>)
// %result0: [1, 1]
// %result1: [1, 1]
CBRT
セマンティクス
operand
テンソルに対して要素単位の立方根演算を実行し、result
テンソルを生成します。要素のタイプに応じて、次の操作を行います。
- 浮動小数点数の場合: IEEE-754 の
rootn(x, 3)
。 - 複素数の場合: 複素立方根。
- 量子化型の場合:
dequantize_op_quantize(cbrt, operand, type(result))
入力
ラベル | 名前 | タイプ | 制約 |
---|---|---|---|
(I1) | operand |
浮動小数点型または複合型のテンソル、またはテンソルごとの量子化テンソル | (C1) |
出力
名前 | タイプ | 制約 |
---|---|---|
result |
浮動小数点型または複合型のテンソル、またはテンソルごとの量子化テンソル | (C1) |
制約
- (C1)
baseline_type(operand) = baseline_type(result)
。
例
// %operand: [0.0, 1.0, 8.0, 27.0]
%result = "stablehlo.cbrt"(%operand) : (tensor<4xf64>) -> tensor<4xf64>
// %result: [0.0, 1.0, 2.0, 3.0]
ceil
セマンティクス
operand
テンソルの要素単位の ceil を実行し、result
テンソルを生成します。IEEE-754 仕様の roundToIntegralTowardPositive
オペレーションを実装します。量子化された型の場合は、dequantize_op_quantize(ceil, operand, type(result))
を実行します。
入力
ラベル | 名前 | タイプ | 制約 |
---|---|---|---|
(I1) | operand |
浮動小数点型のテンソルまたはテンソルごとの量子化テンソル | (C1) |
出力
名前 | タイプ | 制約 |
---|---|---|
result |
浮動小数点型のテンソルまたはテンソルごとの量子化テンソル | (C1) |
制約
- (C1)
baseline_type(operand) = baseline_type(result)
。
例
// %operand: [-0.8166, -0.2530, 0.2530, 0.8166, 2.0]
%result = "stablehlo.ceil"(%operand) : (tensor<5xf32>) -> tensor<5xf32>
// %result: [-0.0, -0.0, 1.0, 1.0, 2.0]
cholesky
セマンティクス
行列のバッチの Cholesky 分解を計算します。
より正式には、index_space(result)
のすべての i
について、result[i0, ..., iR-3, :, :]
は a[i0, ..., iR-3, :, :]
の Cholesky 分解であり、下三角(lower
が true
の場合)または上三角(lower
が false
の場合)行列のいずれかになります。反対の三角形(厳密な上三角形または厳密な下三角形)の出力値は、実装で定義されます。
入力行列がエルミート正定値行列ではない i
が存在する場合、動作は未定義です。
量子化された型の場合は、dequantize_op_quantize(lambda operand: cholesky(operand, lower), a, type(result))
を実行します。
入力
ラベル | 名前 | タイプ | 制約 |
---|---|---|---|
(I1) | a |
浮動小数点型または複合型のテンソル、またはテンソルごとの量子化テンソル | (C1~C3) |
(I2) | lower |
型 i1 の 0 次元テンソル定数 |
出力
名前 | タイプ | 制約 |
---|---|---|
result |
浮動小数点型または複合型のテンソル、またはテンソルごとの量子化テンソル | (C1) |
制約
- (C1)
baseline_type(a) = baseline_type(result)
。 - (C2)
2 <= rank(a)
。 - (C3)
dim(a, -2) = dim(a, -1)
。
例
// %a: [
// [1.0, 2.0, 3.0],
// [2.0, 20.0, 26.0],
// [3.0, 26.0, 70.0]
// ]
%result = "stablehlo.cholesky"(%a) {
lower = true
} : (tensor<3x3xf32>) -> tensor<3x3xf64>
// %result: [
// [1.0, 0.0, 0.0],
// [2.0, 4.0, 0.0],
// [3.0, 5.0, 6.0]
// ]
クランプ
セマンティクス
operand
テンソルのすべての要素を最小値と最大値の間にクランプし、result
テンソルを生成します。より正式には、result[result_index] =
minimum(maximum(operand[result_index], min_element), max_element)
(min_element = rank(min) = 0 ? min[] : min[result_index]
、max_element = rank(max) = 0 ? max[] : max[result_index]
)です。量子化された型の場合は、dequantize_op_quantize(clamp, min, operand, max, type(result))
を実行します。
複素数の順序付けには意外なセマンティクスが伴うため、将来的には、この演算での複素数のサポートは終了する予定です(#560)。
入力
ラベル | 名前 | タイプ | 制約 |
---|---|---|---|
(I1) | min |
テンソルまたはテンソルごとの量子化テンソル | (C1)、(C3) |
(I2) | operand |
テンソルまたはテンソルごとの量子化テンソル | (C1~C4) |
(I3) | max |
テンソルまたはテンソルごとの量子化テンソル | (C2)、(C3) |
出力
名前 | タイプ | 制約 |
---|---|---|
result |
テンソルまたはテンソルごとの量子化テンソル | (C4) |
制約
- (C1)
rank(min) = 0 or shape(min) = shape(operand)
。 - (C2)
rank(max) = 0 or shape(max) = shape(operand)
。 - (C3)
baseline_element_type(min) = baseline_element_type(operand) = baseline_element_type(max)
。 - (C4)
baseline_type(operand) = baseline_type(result)
。
例
// %min: [5, 10, 15]
// %operand: [3, 13, 23]
// %max: [10, 15, 20]
%result = "stablehlo.clamp"(%min, %operand, %max) : (tensor<3xi32>, tensor<3xi32>, tensor<3xi32>) -> tensor<3xi32>
// %result: [5, 13, 20]
collective_broadcast
セマンティクス
StableHLO プロセス グリッドの各プロセス グループ内で、operand
テンソルの値をソースプロセスからターゲット プロセスに送信し、result
テンソルを生成します。
このオペレーションでは、StableHLO プロセス グリッドが、次のように定義された process_groups
に分割されます。
channel_id <= 0
の場合はcross_replica(replica_groups)
。channel_id > 0
の場合、cross_partition(replica_groups)
。
その後、result@process
は次のようになります。
operand@process_groups[i, 0]
: プロセスがprocess_groups[i]
にあるようなi
が存在する場合。- それ以外の場合は
broadcast_in_dim(constant(is_quantized(result) ? quantize(0, element_type(result)) : 0, element_type(result)), [], type(result))
。
入力
ラベル | 名前 | タイプ | 制約 |
---|---|---|---|
(I1) | operand |
テンソルまたはテンソルごとの量子化テンソル | (C3) |
(I2) | replica_groups |
si64 型の 1 次元テンソル定数の可変数 |
(C1)、(C2) |
(I3) | channel_id |
si64 型の定数 |
出力
名前 | タイプ | 制約 |
---|---|---|
result |
テンソルまたはテンソルごとの量子化テンソル | (C3) |
制約
- (C1)
is_unique(replica_groups)
。 - (C2)
0 <= replica_groups < N
。ここで、N
は次のように定義されます。cross_replica
を使用する場合はnum_replicas
。cross_partition
を使用する場合はnum_partitions
。
- (C3)
type(result) = type(operand)
。
例
// num_replicas: 4
// num_partitions: 1
// %operand@(0, 0): [[1, 2]]
// %operand@(1, 0): [[3, 4]]
// %operand@(2, 0): [[5, 6]]
// %operand@(3, 0): [[7, 8]]
%result = "stablehlo.collective_broadcast"(%operand) {
replica_groups = dense<[[2, 1]]> : tensor<1x2xi64>,
channel_handle = #stablehlo.channel_handle<handle = 0, type = 0>
} : (tensor1x2xi64>) -> tensor<1x2xi64>
// %result@(0, 0): [[0, 0]]
// %result@(1, 0): [[5, 6]]
// %result@(2, 0): [[5, 6]]
// %result@(3, 0): [[0, 0]]
collective_permute
セマンティクス
StableHLO プロセス グリッド内の各プロセス グループ内で、operand
テンソルの値をソースプロセスからターゲット プロセスに送信し、result
テンソルを生成します。
このオペレーションは、StableHLO プロセス グリッドを次のように定義された process_groups
に分割します。
channel_id <= 0
の場合、cross_replica(source_target_pairs)
。channel_id > 0
の場合はcross_partition(source_target_pairs)
。
その後、result@process
は次のようになります。
operand@process_groups[i, 0]
。process_groups[i, 1] = process
のようなi
が存在する場合。- それ以外の場合は
broadcast_in_dim(constant(is_quantized(result) ? quantize(0, element_type(result)) : 0, element_type(result)), [], type(result))
。
入力
ラベル | 名前 | タイプ | 制約 |
---|---|---|---|
(I1) | operand |
テンソルまたはテンソルごとの量子化テンソル | (C5) |
(I2) | source_target_pairs |
si64 型の 2 次元のテンソル定数 |
(C1~C4) |
(I3) | channel_id |
si64 型の定数 |
出力
名前 | タイプ | 制約 |
---|---|---|
result |
テンソルまたはテンソルごとの量子化テンソル | (C1) |
制約
- (C1)
dim(source_target_pairs, 1) = 2
。 - (C2)
is_unique(source_target_pairs[:, 0])
。 - (C3)
is_unique(source_target_pairs[:, 1])
。 - (C4)
0 <= source_target_pairs < N
。ここで、N
は次のように定義されます。cross_replica
を使用する場合はnum_replicas
。cross_partition
を使用する場合はnum_partitions
。
- (C5)
type(result) = type(operand)
。
例
// num_replicas: 3
// num_partitions: 1
// %operand@(0, 0): [[1, 2], [3, 4]]
// %operand@(1, 0): [[5, 6], [7, 8]]
// %operand@(2, 0): [[9, 10], [11, 12]]
%result = "stablehlo.collective_permute"(%operand) {
source_target_pairs = dense<[[0, 1], [1, 2]]> : tensor<2x2xi64>,
channel_handle = #stablehlo.channel_handle<handle = 0, type = 0>
} : (tensor<2x2xi64>) -> tensor<2x2xi64>
//
// %result@(0, 0): [[0, 0], [0, 0]]
// %result@(1, 0): [[1, 2], [3, 4]]
// %result@(2, 0): [[5, 6], [7, 8]]
比較
セマンティクス
comparison_direction
と compare_type
に従って lhs
テンソルと rhs
テンソルの要素ごとの比較を行い、result
テンソルを生成します。
comparison_direction
と compare_type
の値のセマンティクスは次のとおりです。
ブール値と整数の要素型の場合:
EQ
:lhs = rhs
NE
:lhs != rhs
GE
:lhs >= rhs
GT
:lhs > rhs
LE
:lhs <= rhs
LT
:lhs < rhs
compare_type = FLOAT
の浮動小数点要素型の場合、op は次の IEEE-754 演算を実装します。
EQ
:compareQuietEqual
NE
:compareQuietNotEqual
GE
:compareQuietGreaterEqual
GT
:compareQuietGreater
LE
:compareQuietLessEqual
LT
:compareQuietLess
compare_type = TOTALORDER
を含む浮動小数点要素型の場合、op は IEEE-754 の totalOrder
オペレーションと compareQuietEqual
オペレーションの組み合わせを使用します。
複雑な要素型の場合、指定された comparison_direction
と compare_type
を使用して、(real, imag)
ペアの辞書順による比較が行われます。複素数に順序付けを適用すると、予期しないセマンティクスが発生するため、comparison_direction
が GE
、GT
、LE
、または LT
の場合、複素数のサポートを削除する予定です(#560)。
量子化された型の場合、dequantize_compare(lhs, rhs,
comparison_direction)
を実行します。
入力
ラベル | 名前 | タイプ | 制約 |
---|---|---|---|
(I1) | lhs |
テンソルまたはテンソルごとの量子化テンソル | (C1 ~ C3) |
(I2) | rhs |
テンソルまたはテンソルごとの量子化テンソル | (C1-C2) |
(I3) | comparison_direction |
EQ 、NE 、GE 、GT 、LE 、LT の列挙型 |
|
(I4) | compare_type |
FLOAT 、TOTALORDER 、SIGNED 、UNSIGNED の列挙型 |
(C3) |
出力
名前 | タイプ | 制約 |
---|---|---|
result |
ブール型のテンソル | (C2) |
制約
- (C1)
baseline_element_type(lhs) = baseline_element_type(rhs)
。 - (C2)
shape(lhs) = shape(rhs) = shape(result)
。 - (C3)
compare_type
は次のように定義されます。is_signed_integer(element_type(lhs))
の場合はSIGNED
。is_unsigned_integer(element_type(lhs)) or is_boolean(element_type(lhs))
の場合はUNSIGNED
。is_float(element_type(lhs))
の場合はFLOAT
またはTOTALORDER
。is_complex(element_type(lhs))
の場合はFLOAT
。
例
// %lhs: [1.0, 3.0]
// %rhs: [1.1, 2.9]
%result = "stablehlo.compare"(%lhs, %rhs) {
comparison_direction = #stablehlo<comparison_direction LT>,
compare_type = #stablehlo<comparison_type FLOAT>
} : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xi1>
// %result: [true, false]
複雑
セマンティクス
実数値と虚数値のペア(lhs
と rhs
)から複合値に要素単位で変換し、result
テンソルを生成します。
入力
ラベル | 名前 | タイプ | 制約 |
---|---|---|---|
(I1) | lhs |
f32 型または f64 型のテンソル |
(C1~C3) |
(I2) | rhs |
f32 型または f64 型のテンソル |
(C1) |
出力
名前 | タイプ | 制約 |
---|---|---|
result |
複合型テンソル | (C2)、(C3) |
制約
- (C1)
type(lhs) = type(rhs)
。 - (C2)
shape(result) = shape(lhs)
。 - (C3)
element_type(result)
の型はcomplex<E>
で、E = element_type(lhs)
です。
例
// %lhs: [1.0, 3.0]
// %rhs: [2.0, 4.0]
%result = "stablehlo.complex"(%lhs, %rhs) : (tensor<2xf64>, tensor<2xf64>) -> tensor<2xcomplex<f64>>
// %result: [(1.0, 2.0), (3.0, 4.0)]
複合
セマンティクス
他の StableHLO オペレーションで構成された(コンポーズされた)オペレーションをカプセル化し、inputs
と composite_attributes
を受け取って results
を生成します。op のセマンティクスは decomposition
属性によって実装されます。composite
オペレーションは、プログラムのセマンティクスを変更せずに分解に置き換えることができます。分解をインライン化しても同じオペレーションのセマンティクスが得られない場合は、custom_call
を使用することをおすすめします。
version
フィールド(デフォルトは 0
)は、コンポジットのセマンティクスが変更されたときに使用されます。
入力
ラベル | 名前 | タイプ |
---|---|---|
(I1) | inputs |
値の可変数 |
(I2) | name |
string 型の定数 |
(I3) | composite_attributes |
属性辞書 |
(I4) | decomposition |
string 型の定数 |
(I5) | version |
si32 型の定数 |
出力
名前 | タイプ |
---|---|
results |
値の可変数 |
制約
- (C1)
is_namespaced_op_name(name)
- (C2)
is_defined_in_parent_scope(decomposition)
- (C3)
types(inputs...) == input_types(decomposition)
- (C4)
types(results...) == output_types(decomposition)
例
%results = "stablehlo.composite"(%input0, %input1) {
name = "my_namespace.my_op",
composite_attributes = {
my_attribute = "my_value"
},
decomposition = @my_op,
version = 1 : i32
} : (tensor<f32>, tensor<f32>) -> tensor<f32>
concatenate
セマンティクス
指定された引数と同じ順序で dimension
次元に沿って inputs
を連結し、result
テンソルを生成します。より正式には、result[i0, ..., id, ..., iR-1] = inputs[k][i0, ..., kd, ..., iR-1]
です。ここで、
id = d0 + ... + dk-1 + kd
。d
はdimension
に等しく、d0
はinputs
のd
番目のディメンション サイズです。
入力
ラベル | 名前 | タイプ | 制約 |
---|---|---|---|
(I1) | inputs |
テンソルの可変数またはテンソルごとの量子化テンソル | (C1~C6) |
(I2) | dimension |
si64 型の定数 |
(C2)、(C4)、(C6) |
出力
名前 | タイプ | 制約 |
---|---|---|
result |
テンソルまたはテンソルごとの量子化テンソル | (C5-C6) |
制約
- (C1)
same(element_type(inputs...))
。 - (C2)
dim(inputs..., dimension)
を除くsame(shape(inputs...))
。 - (C3)
0 < size(inputs)
。 - (C4)
0 <= dimension < rank(inputs[0])
。 - (C5)
element_type(result) = element_type(inputs[0])
。 - (C6)
shape(result) = shape(inputs[0])
(ただし、次の場合は除く)。dim(result, dimension) = dim(inputs[0], dimension) + ...
。
例
// %input0: [[1, 2], [3, 4], [5, 6]]
// %input1: [[7, 8]]
%result = "stablehlo.concatenate"(%input0, %input1) {
dimension = 0 : i64
} : (tensor<3x2xi64>, tensor<1x2xi64>) -> tensor<4x2xi64>
// %result: [[1, 2], [3, 4], [5, 6], [7, 8]]
定数
セマンティクス
定数 value
から output
テンソルを生成します。
入力
ラベル | 名前 | タイプ | 制約 |
---|---|---|---|
(I1) | value |
定数 | (C1) |
出力
名前 | タイプ | 制約 |
---|---|---|
output |
テンソルまたは量子化テンソル | (C1) |
制約
- (C1)
type(value) = type(output)
。
例
%output = "stablehlo.constant"() {
value = dense<[[0.0, 1.0], [2.0, 3.0]]> : tensor<2x2xf32>
} : () -> tensor<2x2xf32>
// %output: [[0.0, 1.0], [2.0, 3.0]]
コンバージョン
セマンティクス
operand
テンソルで要素型を要素ごとに変換し、result
テンソルを生成します。
boolean-to-any-supported-typeでは、値 false
はゼロに変換され、値 true
は 1 に変換されます。any-supported-type-to-booleanへの変換の場合、ゼロ値は false
に変換され、ゼロ以外の値は true
に変換されます。複雑な型での動作については以下をご覧ください。
integer-to-integer、integer-to-floating-point、floating-point-to-floating-point を含む変換では、ソース値が宛先の型で正確に表現できる場合、結果値はそのまま表現されます。それ以外の場合、動作は未定です(#180)。
浮動小数点数から整数への変換では、小数部分は切り捨てられます。切り捨てられた値を宛先型で表現できない場合、動作は未定です(#180)。
複素数から複素数への変換は、実部と虚部の変換における浮動小数点数から浮動小数点数への変換と同じ動作になります。
複素数から他の型への変換と他の型から複素数への変換では、それぞれソースの虚数値が無視されるか、宛先の虚数値がゼロになります。実部の変換は浮動小数点変換に従います。
原則として、このオペレーションはデクォンタイズ(量子化テンソルから通常のテンソルへの変換)、量子化(通常のテンソルから量子化テンソルへの変換)、再量子化(量子化テンソル間の変換)を表現できますが、現時点では専用のオペレーションがあります。最初のユースケースの場合は uniform_dequantize
、2 つ目のユースケースと 3 つ目のユースケースの場合は uniform_quantize
です。今後、これらの 2 つのオペレーションは convert
に統合される可能性があります(#1576)。
入力
ラベル | 名前 | タイプ | 制約 |
---|---|---|---|
(I1) | operand |
テンソル | (C1) |
出力
名前 | タイプ | 制約 |
---|---|---|
result |
テンソル | (C1) |
制約
- (C1)
shape(operand) = shape(result)
。
例
// %operand: [-1, 0, 1]
%result = "stablehlo.convert"(%operand) : (tensor<3xi64>) -> tensor<3xcomplex<f64>>
// %result: [(-1.0, 0.0), (0.0, 0.0), (1.0, 0.0)]
畳み込み
セマンティクス
lhs
のウィンドウと rhs
のスライス間のドット積を計算し、result
を生成します。次の図は、具体的な例を使用して、result
の要素が lhs
と rhs
からどのように計算されるかを示しています。
より正式には、lhs
のウィンドウを表現できるようにするために、次のように lhs
の観点から入力をフレーミングすることを検討します。
lhs_window_dimensions = lhs_shape(dim(lhs, input_batch_dimension), dim(rhs, kernel_spatial_dimensions), dim(lhs, input_feature_dimension))
。lhs_window_strides = lhs_shape(1, window_strides, 1)
lhs_padding = lhs_shape([0, 0], padding, [0, 0])
lhs_base_dilations = lhs_shape(1, lhs_dilation, 1)
lhs_window_dilations = lhs_shape(1, rhs_dilation, 1)
。
このフレーム変更では、次のヘルパー関数を使用します。
lhs_shape(n, hw, c) = permute([n] + hw + [c], [input_batch_dimension] + input_spatial_dimensions + [input_feature_dimension])
。result_shape(n1, hw, c1) = permute([n1] + hw + [c1], [output_batch_dimension] + output_spatial_dimensions + [output_feature_dimension])
。permute([j0, j1, ..., jR-1], permutation) = [i0, i1, ..., iR-1]
はj[d] = i[permutation[d]]
です。
feature_group_count = 1
かつ batch_group_count = 1
の場合、index_space(dim(result, output_spatial_dimensions...))
内のすべての output_spatial_index
について、result[result_shape(:, output_spatial_index, :)] = dot_product
です。ここで、
padding_value = constant(0, element_type(lhs))
。padded_lhs = pad(lhs, padding_value, lhs_padding[:, 0], lhs_padding[:, 1], lhs_base_dilations - 1)
lhs_window_start = lhs_shape(0, output_spatial_index, 0) * lhs_window_strides
lhs_window = slice(padded_lhs, lhs_window_start, lhs_window_start + lhs_window_dimensions, lhs_window_dilations)
。reversed_lhs_window = reverse(lhs_window, [input_spatial_dimensions[dim] for dim in range(size(window_reversal)) if window_reversal[dim] = true])
。この機能は使用されていないと思われるため、今後削除する予定です(#1181)。dot_product = dot_general(reversed_lhs_window, rhs, lhs_batching_dimensions=[], lhs_contracting_dimensions=input_spatial_dimensions + [input_feature_dimension], rhs_batching_dimensions=[], rhs_contracting_dimensions=kernel_spatial_dimensions + [kernel_input_feature_dimension])
。
feature_group_count > 1
の場合:
lhses = split(lhs, feature_group_count, input_feature_dimension)
。rhses = split(rhs, feature_group_count, kernel_output_feature_dimension)
results... = convolution(lhses..., rhses..., ..., feature_group_count=1, ...)
result = concatenate(results, output_feature_dimension)
。
batch_group_count > 1
の場合:
lhses = split(lhs, batch_group_count, input_batch_dimension)
。rhses = split(rhs, batch_group_count, kernel_output_feature_dimension)
results... = convolution(lhses..., rhses..., ..., batch_group_count=1, ...)
。result = concatenate(results, output_feature_dimension)
。
量子化された型の場合は、dequantize_op_quantize(
lambda lhs, rhs: convolution(lhs, rhs, window_strides, padding,
lhs_dilation, rhs_dilation, window_reversal, input_batch_dimension,
input_feature_dimension, input_spatial_dimensions,
kernel_input_feature_dimension, kernel_output_feature_dimension,
kernel_spatial_dimensions, output_batch_dimension,
output_feature_dimension, output_spatial_dimensions,
feature_group_count, batch_group_count, precision_config), lhs, rhs,
type(result))
を実行します。
ハイブリッド量子化タイプの場合は、hybrid_dequantize_then_op(
lambda lhs, rhs: convolution(lhs, rhs, window_strides, padding,
lhs_dilation, rhs_dilation, window_reversal, input_batch_dimension,
input_feature_dimension, input_spatial_dimensions,
kernel_input_feature_dimension, kernel_output_feature_dimension,
kernel_spatial_dimensions, output_batch_dimension,
output_feature_dimension, output_spatial_dimensions,
feature_group_count, batch_group_count, precision_config), lhs, rhs)
を実行します。
入力
ラベル | 名前 | タイプ | 制約 |
---|---|---|---|
(I1) | lhs |
テンソルまたはテンソルごとの量子化テンソル | (C1)、(C10~C11)、(C14)、(C25)、(C27~C28)、(C31~C32)、(C34) |
(I2) | rhs |
テンソルまたは量子化テンソル | (C1)、(C14 ~ C16)、(C25)、(C27 ~ C29)、(C31 ~ C34) |
(I3) | window_strides |
si64 型の 1 次元テンソル定数 |
(C2~C3)、(C25) |
(I4) | padding |
si64 型の 2 次元テンソル定数 |
(C4)、(C25) |
(I5) | lhs_dilation |
si64 型の 1 次元テンソル定数 |
(C5 ~ C6)、(C25) |
(I6) | rhs_dilation |
si64 型の 1 次元のテンソル定数 |
(C7 ~ C8)、(C25) |
(I7) | window_reversal |
i1 型の 1 次元テンソル定数 |
(C9) |
(I8) | input_batch_dimension |
si64 型の定数 |
(C10)、(C13)、(C25) |
(I9) | input_feature_dimension |
si64 型の定数 |
(C11)、(C13~C14) |
(I10) | input_spatial_dimensions |
si64 型の 1 次元のテンソル定数 |
(C12)、(C13)、(C25) |
(I11) | kernel_input_feature_dimension |
si64 型の定数 |
(C14)、(C18) |
(I12) | kernel_output_feature_dimension |
si64 型の定数 |
(C15 ~ C16)、(C18)、(C25)、(C29) |
(I13) | kernel_spatial_dimensions |
si64 型の 1 次元テンソル定数 |
(C17~C18)、(C25) |
(I14) | output_batch_dimension |
si64 型の定数 |
(C20)、(C25) |
(I15) | output_feature_dimension |
si64 型の定数 |
(C20)、(C25)、(C30) |
(I16) | output_spatial_dimensions |
si64 型の 1 次元テンソル定数 |
(C19 ~ C20)、(C25) |
(I17) | feature_group_count |
si64 型の定数 |
(C11)、(C14)、(C16)、(C21)、(C23) |
(I18) | batch_group_count |
si64 型の定数 |
(C10)、(C15)、(C22)、(C23)、(C25) |
(I19) | precision_config |
DEFAULT 、HIGH 、HIGHEST の可変列挙型の数 |
(C24) |
出力
名前 | タイプ | 制約 |
---|---|---|
result |
テンソルまたは量子化テンソル | (C25 ~ C28)、(C30)、(C32 ~ 34) |
制約
- (C1)
N = rank(lhs) = rank(rhs)
。 - (C2)
size(window_strides) = N - 2
。 - (C3)
0 < window_strides
。 - (C4)
shape(padding) = [N - 2, 2]
。 - (C5)
size(lhs_dilation) = N - 2
。 - (C6)
0 < lhs_dilation
。 - (C7)
size(rhs_dilation) = N - 2
。 - (C8)
0 < rhs_dilation
。 - (C9)
size(window_reversal) = N - 2
。 - (C10)
dim(lhs, input_batch_dimension) % batch_group_count = 0
。 - (C11)
dim(lhs, input_feature_dimension) % feature_group_count = 0
。 - (C12)
size(input_spatial_dimensions) = N - 2
。 - (C13)
input_dimensions = [input_batch_dimension] + input_spatial_dimensions + [input_feature_dimension]
が与えられている場合:is_unique(input_dimensions)
。0 <= input_dimensions < N
。
- (C14)
dim(rhs, kernel_input_feature_dimension) = dim(lhs, input_feature_dimension) / feature_group_count
。 - (C15)
dim(rhs, kernel_output_feature_dimension) % batch_group_count = 0
。 - (C16)
dim(rhs, kernel_output_feature_dimension) % feature_group_count = 0
。 - (C17)
size(kernel_spatial_dimensions) = N - 2
。 - (C18)
kernel_dimensions = kernel_spatial_dimensions + [kernel_input_feature_dimension] + [kernel_output_feature_dimension]
が与えられている場合:is_unique(kernel_dimensions)
。0 <= kernel_dimensions < N
。
- (C19)
size(output_spatial_dimensions) = N - 2
。 - (C20)
output_dimensions = [output_batch_dimension] + output_spatial_dimensions + [output_feature_dimension]
が与えられている場合:is_unique(output_dimensions)
。0 <= output_dimensions < N
。
- (C21)
0 < feature_group_count
。 - (C22)
0 < batch_group_count
。 - (C23)
feature_group_count = 1 or batch_group_count = 1
。 - (C24)
size(precision_config) = 2
。 - (C25)
dim(result, result_dim)
は次のように定義されます。result_dim = output_batch_dimension
の場合、dim(lhs, input_batch_dimension) / batch_group_count
。result_dim = output_feature_dimension
の場合、dim(rhs, kernel_output_feature_dimension)
。- それ以外の場合は
num_windows
。ここで: output_spatial_dimensions[spatial_dim] = result_dim
。lhs_dim = input_spatial_dimensions[spatial_dim]
rhs_dim = kernel_spatial_dimensions[spatial_dim]
dilated_input_shape[lhs_dim] = dim(lhs, lhs_dim) = 0 ? 0 : (dim(lhs, lhs_dim) - 1) * lhs_dilation[spatial_dim] + 1
padded_input_shape[lhs_dim] = padding[spatial_dim, 0] + dilated_input_shape[lhs_dim] + padding[spatial_dim, 1]
dilated_window_shape[lhs_dim] = dim(rhs, rhs_dim) = 0 ? 0 : (dim(rhs, rhs_dim) - 1) * rhs_dilation[spatial_dim] + 1
is_empty_window[lhs_dim] = padded_input_shape[lhs_dim] = 0 || dilated_window_shape[lhs_dim] > padded_input_shape[lhs_dim]
num_windows = is_empty_window[lhs_dim] ? 0 : floor((padded_input_shape[lhs_dim] - dilated_window_shape[lhs_dim]) / window_strides[spatial_dim]) + 1
。
- (C26)
rank(result) = N
。 - オペレーションで量子化されていないテンソルを使用する場合:
- (C27)
element_type(lhs) = element_type(rhs) = element_type(result)
。
- (C27)
- オペレーションで量子化されたテンソルを使用する場合:
- (C28)
is_quantized(lhs) = is_quantized(result) and is_quantized(rhs)
。 - (C29)
is_per_axis_quantized(rhs)
の場合、quantization_dimension(rhs) = kernel_output_feature_dimension
。 - (C30)
is_per_axis_quantized(result)
の場合、quantization_dimension(result) = output_feature_dimension
。 is_quantized(lhs)
の場合:- (C31)
storage_type(lhs) = storage_type(rhs)
。 - (C32)
expressed_type(lhs) = expressed_type(rhs) = expressed_type(result)
。 - (C33)
is_per_tensor_quantized(rhs)
の場合はis_per_tensor_quantized(result)
です。 !is_quantized(lhs)
の場合:- (C34)
element_type(lhs) = expressed_type(rhs) = element_type(result)
。
- (C28)
例
// %lhs: [[
// [
// [1], [2], [5], [6]
// ],
// [
// [3], [4], [7], [8]
// ],
// [
// [10], [11], [14], [15]
// ],
// [
// [12], [13], [16], [17]
// ]
// ]]
//
// %rhs: [
// [[[1]], [[1]], [[1]]],
// [[[1]], [[1]], [[1]]],
// [[[1]], [[1]], [[1]]]
// ]
%result = "stablehlo.convolution"(%lhs, %rhs) {
window_strides = array<i64: 4, 4>,
padding = dense<0> : tensor<2x2xi64>,
lhs_dilation = array<i64: 2, 2>,
rhs_dilation = array<i64: 1, 1>,
window_reversal = array<i1: false, false>,
// In the StableHLO dialect, dimension numbers are encoded via:
// `[<input dimensions>]x[<kernel dimensions>]->[output dimensions]`.
// "b" is batch dimension, "f" is feature dimension,
// "i" is input feature dimension, "o" is output feature dimension,
// "0/1/etc" are spatial dimensions.
dimension_numbers = #stablehlo.conv<[b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f]>,
batch_group_count = 1 : i64,
feature_group_count = 1 : i64,
precision_config = [#stablehlo<precision DEFAULT>, #stablehlo<precision DEFAULT>]
} : (tensor<1x4x4x1xi64>, tensor<3x3x1x1xi64>) -> tensor<1x2x2x1xi64>
// %result: [[
// [[10], [26]],
// [[46], [62]]
// ]]
コサイン
セマンティクス
operand
テンソルに対して要素単位のコサイン演算を実行し、result
テンソルを生成します。要素のタイプに応じて、次の操作を行います。
- 浮動小数点数の場合: IEEE-754 の
cos
。 - 複素数の場合: 複素余弦。
- 量子化型の場合:
dequantize_op_quantize(cosine, operand, type(result))
。
入力
ラベル | 名前 | タイプ | 制約 |
---|---|---|---|
(I1) | operand |
浮動小数点型または複合型のテンソル、またはテンソルごとの量子化テンソル | (C1) |
出力
名前 | タイプ | 制約 |
---|---|---|
result |
浮動小数点型または複合型のテンソル、またはテンソルごとの量子化テンソル | (C1) |
制約
- (C1)
baseline_type(operand) = baseline_type(result)
。
例
// %operand: [
// [0.0, 1.57079632], // [0, pi/2]
// [3.14159265, 4.71238898] // [pi, 3pi/2]
// ]
%result = "stablehlo.cosine"(%operand) : (tensor<2x2xf32>) -> tensor<2x2xf32>
// %result: [[1.0, 0.0], [-1.0, 0.0]]
count_leading_zeros
セマンティクス
operand
テンソルの先頭のゼロビット数を要素ごとにカウントし、result
テンソルを生成します。
入力
ラベル | 名前 | タイプ | 制約 |
---|---|---|---|
(I1) | operand |
整数型のテンソル | (C1) |
出力
名前 | タイプ | 制約 |
---|---|---|
result |
整数型のテンソル | (C1) |
制約
- (C1)
type(operand) = type(result)
。
例
// %operand: [[0, 1], [128, -1]]
%result = "stablehlo.count_leading_zeros"(%operand) : (tensor<2x2xi64>) -> tensor<2x2xi64>
// %result: [[64, 63], [56, 0]]
custom_call
セマンティクス
inputs
と called_computations
を受け取って results
を生成する、実装定義のオペレーション call_target_name
をカプセル化します。has_side_effect
、backend_config
、api_version
を使用して、実装定義のメタデータを追加できます。
現時点では、このオペレーションにはかなり整理されていないメタデータのコレクションが含まれています。これは、XLA コンパイラでの対応するオペレーションの有機的な進化を反映しています。今後、このメタデータを統合する予定です(#741)。
入力
ラベル | 名前 | タイプ |
---|---|---|
(I1) | inputs |
値の可変数 |
(I2) | call_target_name |
string 型の定数 |
(I3) | has_side_effect |
i1 型の定数 |
(I4) | backend_config |
string 型の定数または属性辞書 |
(I5) | api_version |
si32 型の定数 |
(I6) | called_computations |
string 型の可変長定数 |
出力
名前 | タイプ |
---|---|
results |
可変長の値の数 |
例
%results = "stablehlo.custom_call"(%input0) {
call_target_name = "foo",
has_side_effect = false,
backend_config = {bar = 42 : i32},
api_version = 4 : i32,
called_computations = [@foo]
} : (tensor<f64>) -> tensor<f64>
割る
セマンティクス
除数 lhs
テンソルと除算子 rhs
テンソルの要素ごとの除算を行い、result
テンソルを生成します。要素のタイプに応じて、次の処理を行います。
- 整数の場合: 小数部を破棄して代数的な商を生成する整数除算。
- 浮動小数点数の場合: IEEE-754 の
division
。 - 複素数の場合: 複素除算。
- 量子化された型の場合:
dequantize_op_quantize(divide, lhs, rhs, type(result))
。
入力
ラベル | 名前 | タイプ | 制約 |
---|---|---|---|
(I1) | lhs |
整数型、浮動小数点型、複合型のテンソル、またはテンソルごとの量子化テンソル | (C1) |
(I2) | rhs |
整数型、浮動小数点型、複合型のテンソル、またはテンソルごとの量子化テンソル | (C1) |
出力
名前 | タイプ | 制約 |
---|---|---|
result |
整数型、浮動小数点型、複合型のテンソル、またはテンソルごとの量子化テンソル | (C1) |
制約
- (C1)
baseline_type(lhs) = baseline_type(rhs) = baseline_type(result)
。
例
// %lhs: [17.1, -17.1, 17.1, -17.1]
// %rhs: [3.0, 3.0, -3.0, -3.0]
%result = "stablehlo.divide"(%lhs, %rhs) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
// %result: [5.66666651, -5.66666651, -5.66666651, 5.66666651]
dot_general
セマンティクス
lhs
のスライスと rhs
のスライスの間のドット積を計算し、result
テンソルを生成します。
正式には result[result_index] = dot_product
です。ここで、
lhs_result_dimensions = [d for d in axes(lhs) and d not in lhs_batching_dimensions and d not in lhs_contracting_dimensions]
。rhs_result_dimensions = [d for d in axes(rhs) and d not in rhs_batching_dimensions and d not in rhs_contracting_dimensions]
。result_batching_index + result_lhs_index + result_rhs_index = result_index
(size(result_batching_index) = size(lhs_batching_dimensions)
、size(result_lhs_index) = size(lhs_result_dimensions)
、size(result_rhs_index) = size(rhs_result_dimensions)
)。transposed_lhs = transpose(lhs, lhs_batching_dimensions + lhs_result_dimensions + lhs_contracting_dimensions)
。transposed_lhs_slice = slice(transposed_lhs, result_batching_index + result_lhs_index + [:, ..., :])
reshaped_lhs_slice = reshape(transposed_lhs_slice, dims(lhs, lhs_contracting_dimensions))
transposed_rhs = transpose(rhs, rhs_batching_dimensions + rhs_result_dimensions + rhs_contracting_dimensions)
transposed_rhs_slice = slice(transposed_rhs, result_batching_index + result_rhs_index + [:, ..., :])
reshaped_rhs_slice = reshape(transposed_rhs_slice, dims(rhs, rhs_contracting_dimensions))
。dot_product = reduce( inputs=[multiply(reshaped_lhs_slice, reshaped_rhs_slice)], init_values=[constant(0, element_type(result))], dimensions=range(size(lhs_contracting_dimensions)), body=lambda x, y: add(x, y))
。
量子化された型の場合は、dequantize_op_quantize(
lambda lhs, rhs: dot_general(lhs, rhs, lhs_batching_dimensions,
rhs_batching_dimensions, lhs_contracting_dimensions,
rhs_contracting_dimensions, precision_config), lhs, rhs, type(result))
を実行します。
ハイブリッド量子化タイプの場合は、hybrid_dequantize_then_op(
lambda lhs, rhs: dot_general(lhs, rhs, lhs_batching_dimensions,
rhs_batching_dimensions, lhs_contracting_dimensions,
rhs_contracting_dimensions, precision_config), lhs, rhs)
を実行します。
precision_config
は、アクセラレータ バックエンドでの計算の速度と精度のトレードオフを制御します。これは次のいずれかです(現時点では、これらの列挙型値のセマンティクスは指定されていません。#755 で対応する予定です)。
DEFAULT
: 計算速度は最も速いが、元の数値に対する近似値の精度は最も低い。HIGH
: 計算は遅くなりますが、元の数値に近い値をより正確に求めることができます。HIGHEST
: 計算に最も時間がかかりますが、元の数値に最も近い近似値が得られます。
DotAlgorithm
は、ドット演算の実装に使用されるアルゴリズムの主なプロパティを定義します。これにより、精度も定義されます。アルゴリズム属性フィールドが設定されている場合、precision_config
は DEFAULT
にする必要があります。デフォルト パラメータは実装で定義されるため、DotAlgorithms
にデフォルト値はありません。そのため、すべてのドット アルゴリズム フィールドを None
に設定して、空のドット アルゴリズムを指定し、代わりに precision_config
値を使用できます。
DotAlgorithm
フィールドには次のフィールドが含まれます。
lhs_precision_type
とrhs_precision_type
: オペレーションの LHS と RHS が丸められる精度。精度タイプは、入力と出力のストレージ タイプとは独立しています。accumulation_type
累積に使用される精度。lhs_component_count
、rhs_component_count
、num_primitive_operations
は、LHS や RHS を複数のコンポーネントに分解し、それらの値に対して複数の「プリミティブ」ドット演算を実行するアルゴリズムを実行する場合に適用されます。通常は、より高い精度をエミュレートします(例: 高精度の計算に bfloat16 AI データ型を活用: btf12_6x3x3)。分解のないアルゴリズムの場合、これらの値は1
に設定する必要があります。allow_imprecise_accumulation
: 一部のステップで低精度の累積が許可されるかどうかを指定します(CUBLASLT_MATMUL_DESC_FAST_ACCUM
など)。
DotAlgorithm
属性の例:
// Inputs are casted to tf32, and then accumulated in f32:
{lhs_precision_type = tf32,
rhs_precision_type = tf32,
accumulation_type = f32,
lhs_component_count = 1,
rhs_component_count = 1,
num_primitive_operations = 1,
allow_imprecise_accumulation = false}
// bf16_6x: each input is decomposed to 3 bf16 components, then 6 dot operations are done on those components, and the result is accumulated in f32.
{lhs_precision_type = bf16,
rhs_precision_type = bf16,
accumulation_type = f32,
lhs_component_count = 3,
rhs_component_count = 3,
num_primitive_operations = 6,
allow_imprecise_accumulation = false}
// Inputs are (casted to) f8e5m2, and we accumulate in f32, but for some steps we may accumulate in lower precision.
{lhs_precision_type = f8e5m2,
rhs_precision_type = f8e5m2,
accumulation_type = f32,
lhs_component_count = 1,
rhs_component_count = 1,
num_primitive_operations = 1,
allow_imprecise_accumulation = true}
どの組み合わせをサポートするかは、実装が決定します。一般に、各アルゴリズムが StableHLO のコンシューマによって各アクセラレータ タイプでサポートされているとは限りません。特定のアルゴリズムがサポートされていない場合は、代替のアルゴリズムにフォールバックするのではなく、エラーが発生します。StableHLO 検証ではベスト エフォート検証が提供され、どのハードウェアでもサポートされていないアルゴリズムが使用されるのを防ぎます。
サポートされているアルゴリズムの値については、xla_data.proto > Algorithm
をご覧ください。チケット #2483 には、バックエンドによってサポートされているアルゴリズムに関する一元化されたドキュメントを作成する計画が記載されています。
入力
ラベル | 名前 | タイプ | 制約 |
---|---|---|---|
(I1) | lhs |
テンソルまたはテンソルごとの量子化テンソル | (C5 ~ C6)、(C9 ~ C10)、(C12 ~ C14)、(C17 ~ C18)、(C20) |
(I2) | rhs |
テンソルまたは量子化テンソル | (C7~C10)、(C12~C20) |
(I3) | lhs_batching_dimensions |
si64 型の 1 次元テンソル定数 |
(C1)、(C3)、(C5)、(C9)、(C12) |
(I4) | rhs_batching_dimensions |
si64 型の 1 次元テンソル定数 |
(C1)、(C4)、(C7)、(C9) |
(I5) | lhs_contracting_dimensions |
si64 型の 1 次元テンソル定数 |
(C2)、(C3)、(C6)、(C10) |
(I6) | rhs_contracting_dimensions |
si64 型の 1 次元のテンソル定数 |
(C2)、(C4)、(C8)、(C10)、(C16) |
(I7) | precision_config |
DEFAULT 、HIGH 、HIGHEST の可変長の数値型 |
(C11)、(C21) |
(I8) | lhs_precision_type |
FloatType または TensorFloat32 | (C21) |
(I9) | rhs_precision_type |
FloatType または TensorFloat32 | (C21) |
(I10) | accumulation_type |
FloatType または TensorFloat32 | (C21) |
(I11) | lhs_component_count |
si32 型の定数 |
(C21)、(C22) |
(I12) | rhs_component_count |
si32 型の定数 |
(C21)、(C23) |
(I13) | num_primitive_operations |
si32 型の定数 |
(C21)、(C24) |
(I14) | allow_imprecise_accumulation |
bool 型の定数 |
(C21) |
出力
名前 | タイプ | 制約 |
---|---|---|
result |
テンソルまたは量子化テンソル | (C12)、(C14)、(C18 ~ C20) |
制約
- (C1)
size(lhs_batching_dimensions) = size(rhs_batching_dimensions)
。 - (C2)
size(lhs_contracting_dimensions) = size(rhs_contracting_dimensions)
。 - (C3)
is_unique(lhs_batching_dimensions + lhs_contracting_dimensions)
。 - (C4)
is_unique(rhs_batching_dimensions + rhs_contracting_dimensions)
。 - (C5)
0 <= lhs_batching_dimensions < rank(lhs)
。 - (C6)
0 <= lhs_contracting_dimensions < rank(lhs)
。 - (C7)
0 <= rhs_batching_dimensions < rank(rhs)
。 - (C8)
0 <= rhs_contracting_dimensions < rank(rhs)
。 - (C9)
dim(lhs, lhs_batching_dimensions...) = dim(rhs, rhs_batching_dimensions...)
。 - (C10)
dim(lhs, lhs_contracting_dimensions...) = dim(rhs, rhs_contracting_dimensions...)
。 - (C11)
size(precision_config) = 2
。 - (C12)
shape(result) = dim(lhs, lhs_batching_dimensions) + dim(lhs, lhs_result_dimensions) + dim(rhs, rhs_result_dimensions)
。 - オペレーションで量子化されていないテンソルを使用する場合:
- (C13)
element_type(lhs) = element_type(rhs)
。
- (C13)
- オペレーションで量子化されたテンソルを使用する場合:
- (C14)
is_quantized(lhs) = is_quantized(result) and is_quantized(rhs)
。 - (C15)
zero_points(rhs) = 0
。 - (C16)
is_per_axis_quantized(rhs)
の場合、quantization_dimension(rhs)
がrhs_contracting_dimensions
にない。 is_quantized(lhs)
の場合:- (C17)
storage_type(lhs) = storage_type(rhs)
。 - (C18)
expressed_type(lhs) = expressed_type(rhs) = expressed_type(result)
。 - (C19)
is_per_tensor_quantized(rhs)
の場合、is_per_tensor_quantized(result)
。 !is_quantized(lhs)
の場合:- (C20)
element_type(lhs) = expressed_type(rhs) = element_type(result)
。
- (C14)
!is_empty_algorithm(lhs_precision_type, rhs_precision_type, accumulation_type, lhs_component_count, rhs_component_count, num_primitive_operations allow_imprecise_accumulation)
の場合:- (C21)
precision_config... = DEFAULT
。 - (C22)
0 < lhs_component_count
。 - (C23)
0 < rhs_component_count
。 - (C24)
0 < num_primitive_operations
。
- (C21)
例
// %lhs: [
// [[1, 2],
// [3, 4]],
// [[5, 6],
// [7, 8]]
// ]
// %rhs: [
// [[1, 0],
// [0, 1]],
// [[1, 0],
// [0, 1]]
// ]
%result = "stablehlo.dot_general"(%lhs, %rhs) {
dot_dimension_numbers = #stablehlo.dot<
lhs_batching_dimensions = [0],
rhs_batching_dimensions = [0],
lhs_contracting_dimensions = [2],
rhs_contracting_dimensions = [1]
>,
precision_config = [#stablehlo<precision DEFAULT>, #stablehlo<precision DEFAULT>],
algorithm = #stablehlo.dot_algorithm<
lhs_precision_type = tf32,
rhs_precision_type = tf32,
accumulation_type = f32,
lhs_component_count = 1,
rhs_component_count = 1,
num_primitive_operations = 1,
allow_imprecise_accumulation = false
>
} : (tensor<2x2x2xi64>, tensor<2x2x2xi64>) -> tensor<2x2x2xi64>
// %result: [
// [[1, 2],
// [3, 4]],
// [[5, 6],
// [7, 8]]
// ]
dynamic_broadcast_in_dim
セマンティクス
このオペレーションは、broadcast_in_dim オペレーションと機能的には同じですが、結果の形状は output_dimensions
を介して動的に指定されます。
このオペレーションでは、オプションの属性 known_expanding_dimensions
、known_nonexpanding_dimensions
も受け入れ、ディメンションの展開動作に関する静的な知識を表現します。指定しない場合は、すべてのディメンションが拡張可能であると見なされます。
入力
ラベル | 名前 | タイプ | 制約 |
---|---|---|---|
(I1) | operand |
テンソルまたは量子化テンソル | (C1-C2)、(C5-C6)、(C9) |
(I2) | output_dimensions |
整数型の 1 次元テンソル | (C7) |
(I3) | broadcast_dimensions |
整数型の 1 次元定数テンソル | (C2~C6) |
(I4) | known_expanding_dimensions |
整数型の 1 次元定数テンソル | (C8~C9) |
(I5) | known_nonexpanding_dimensions |
整数型の 1 次元定数テンソル | (C8~C9) |
出力
名前 | タイプ | 制約 |
---|---|---|
result |
テンソルまたは量子化テンソル | (C1)、(C3)、(C5~C7) |
制約
- (C1)
element_type(result)
は次のように定義されます。element_type(operand)
(!is_per_axis_quantized(operand)
の場合)。element_type(operand)
(ただし、quantization_dimension(operand)
、scales(operand)
、zero_points(operand)
は、それぞれquantization_dimension(result)
、scales(result)
、zero_points(result)
と異なる場合があります)。
- (C2)
size(broadcast_dimensions) = rank(operand)
。 - (C3)
0 <= broadcast_dimensions < rank(result)
。 - (C4)
is_unique(broadcast_dimensions)
。 - (C5)
axes(operand)
内のすべてのd
について:dim(operand, d) = 1
またはdim(operand, d) = dim(result, broadcast_dimensions[d])
。
- (C6)
is_per_axis_quantized(result)
の場合:quantization_dimension(result) = broadcast_dimensions[quantization_dimension(operand)]
。dim(operand, quantization_dimension(operand)) = 1
の場合はscales(result)[i] = scales(operand)[0] and zero_points(result)[i] = zero_points(operand)[0] for i in range(dim(result, quantization_dimension(result)))
です。
- (C7)
size(output_dimensions) = rank(result)
。 - (C8)
is_unique(known_expanding_dimensions + known_nonexpanding_dimensions)
。 - (C9)
0 <= known_expanding_dimensions < rank(operand)
。 - (C10)
0 <= known_nonexpanding_dimensions < rank(operand)
。
例
// %operand: [
// [1, 2, 3]
// ]
%operand = stablehlo.constant dense<[[1, 2, 3]]> : tensor<1x3xi64>
%output_dimensions = stablehlo.constant dense<[2, 3, 2]> : tensor<3xi64>
%result = "stablehlo.dynamic_broadcast_in_dim"(%operand, %output_dimensions) {
broadcast_dimensions = array<i64: 2, 1>,
known_expanding_dimensions = array<i64: 0>,
known_nonexpanding_dimensions = array<i64: 1>
} : (tensor<1x3xi64>, tensor<3xi64>) -> tensor<2x3x2xi64>
// %result: [
// [
// [1, 1],
// [2, 2],
// [3, 3]
// ],
// [
// [1, 1],
// [2, 2],
// [3, 3]
// ]
// ]
dynamic_conv
セマンティクス
このオペレーションは、畳み込みオペレーションと機能的に同じですが、パディングは padding
を介して動的に指定されます。
入力
ラベル | 名前 | タイプ | 制約 |
---|---|---|---|
(I1) | lhs |
テンソルまたはテンソルごとの量子化テンソル | (C1)、(C10 ~ C11)、(C14)(C25)、(C26 ~ C27)、(C30 ~ C31)、(C33) |
(I2) | rhs |
テンソルまたは量子化テンソル | (C1)、(C14 ~ C16)、(C26 ~ C28)、(C30 ~ C33) |
(I3) | padding |
整数型の 2 次元テンソル | (C4) |
(I4) | window_strides |
si64 型の 1 次元テンソル定数 |
(C2 ~ C3) |
(I5) | lhs_dilation |
si64 型の 1 次元のテンソル定数 |
(C5-C6) |
(I6) | rhs_dilation |
si64 型の 1 次元テンソル定数 |
(C7 ~ C8) |
(I7) | window_reversal |
i1 型の 1 次元テンソル定数 |
(C9) |
(I8) | input_batch_dimension |
si64 型の定数 |
(C10)、(C13) |
(I9) | input_feature_dimension |
si64 型の定数 |
(C11)、(C13~C14) |
(I10) | input_spatial_dimensions |
si64 型の 1 次元テンソル定数 |
(C12)、(C13) |
(I11) | kernel_input_feature_dimension |
si64 型の定数 |
(C14)、(C18) |
(I12) | kernel_output_feature_dimension |
si64 型の定数 |
(C15~C16)、(C18)、(C28) |
(I13) | kernel_spatial_dimensions |
si64 型の 1 次元テンソル定数 |
(C17~C18) |
(I14) | output_batch_dimension |
si64 型の定数 |
(C20) |
(I15) | output_feature_dimension |
si64 型の定数 |
(C20)、(C29) |
(I16) | output_spatial_dimensions |
si64 型の 1 次元のテンソル定数 |
(C19 ~ C20) |
(I17) | feature_group_count |
si64 型の定数 |
(C11)、(C14)、(C16)、(C21)、(C23) |
(I18) | batch_group_count |
si64 型の定数 |
(C10)、(C15)、(C22)、(C23) |
(I19) | precision_config |
DEFAULT 、HIGH 、HIGHEST の可変列挙型の数 |
(C24) |
出力
名前 | タイプ | 制約 |
---|---|---|
result |
テンソルまたは量子化テンソル | (C25~C27)、(C29)、(C31~C33) |
制約
- (C1)
N = rank(lhs) = rank(rhs)
。 - (C2)
size(window_strides) = N - 2
。 - (C3)
0 < window_strides
。 - (C4)
shape(padding) = [N - 2, 2]
。 - (C5)
size(lhs_dilation) = N - 2
。 - (C6)
0 < lhs_dilation
。 - (C7)
size(rhs_dilation) = N - 2
。 - (C8)
0 < rhs_dilation
。 - (C9)
size(window_reversal) = N - 2
。 - (C10)
dim(lhs, input_batch_dimension) % batch_group_count = 0
。 - (C11)
dim(lhs, input_feature_dimension) % feature_group_count = 0
。 - (C12)
size(input_spatial_dimensions) = N - 2
。 - (C13)
input_dimensions = [input_batch_dimension] + input_spatial_dimensions + [input_feature_dimension]
が与えられている場合:is_unique(input_dimensions)
。0 <= input_dimensions < N
。
- (C14)
dim(rhs, kernel_input_feature_dimension) = dim(lhs, input_feature_dimension) / feature_group_count
。 - (C15)
dim(rhs, kernel_output_feature_dimension) % batch_group_count = 0
。 - (C16)
dim(rhs, kernel_output_feature_dimension) % feature_group_count = 0
。 - (C17)
size(kernel_spatial_dimensions) = N - 2
。 - (C18)
kernel_dimensions = kernel_spatial_dimensions + [kernel_input_feature_dimension] + [kernel_output_feature_dimension]
が与えられている場合:is_unique(kernel_dimensions)
。0 <= kernel_dimensions < N
。
- (C19)
size(output_spatial_dimensions) = N - 2
。 - (C20)
output_dimensions = [output_batch_dimension] + output_spatial_dimensions + [output_feature_dimension]
が与えられている場合:is_unique(output_dimensions)
。0 <= output_dimensions < N
。
- (C21)
0 < feature_group_count
。 - (C22)
0 < batch_group_count
。 - (C23)
feature_group_count = 1 or batch_group_count = 1
。 - (C24)
size(precision_config) = 2
。 - (C25)
dim(result, result_dim)
は次のように定義されます。result_dim = output_batch_dimension
の場合、dim(lhs, input_batch_dimension) / batch_group_count
。result_dim = output_feature_dimension
の場合、dim(rhs, kernel_output_feature_dimension)
。- それ以外の場合は
num_windows
。ここで: output_spatial_dimensions[spatial_dim] = result_dim
。lhs_dim = input_spatial_dimensions[spatial_dim]
rhs_dim = kernel_spatial_dimensions[spatial_dim]
dilated_input_shape[lhs_dim] = dim(lhs, lhs_dim) = 0 ? 0 : (dim(lhs, lhs_dim) - 1) * lhs_dilation[spatial_dim] + 1
padded_input_shape[lhs_dim] = padding[spatial_dim, 0] + dilated_input_shape[lhs_dim] + padding[spatial_dim, 1]
dilated_window_shape[lhs_dim] = dim(rhs, rhs_dim) = 0 ? 0 : (dim(rhs, rhs_dim) - 1) * rhs_dilation[spatial_dim] + 1
is_empty_window[lhs_dim] = padded_input_shape[lhs_dim] = 0 || dilated_window_shape[lhs_dim] > padded_input_shape[lhs_dim]
num_windows = is_empty_window[lhs_dim] ? 0 : floor((padded_input_shape[lhs_dim] - dilated_window_shape[lhs_dim]) / window_strides[spatial_dim]) + 1
。
- (C26)
rank(result) = N
。 - オペレーションで量子化されていないテンソルを使用する場合:
- (C27)
element_type(lhs) = element_type(rhs) = element_type(result)
。
- (C27)
- オペレーションで量子化されたテンソルを使用する場合:
- (C28)
is_quantized(lhs) = is_quantized(result) and is_quantized(rhs)
。 - (C29)
is_per_axis_quantized(rhs)
の場合、quantization_dimension(rhs) = kernel_output_feature_dimension
。 - (C30)
is_per_axis_quantized(result)
の場合、quantization_dimension(result) = output_feature_dimension
。 is_quantized(lhs)
の場合:- (C31)
storage_type(lhs) = storage_type(rhs)
。 - (C32)
expressed_type(lhs) = expressed_type(rhs) = expressed_type(result)
。 - (C33)
is_per_tensor_quantized(rhs)
の場合はis_per_tensor_quantized(result)
です。 !is_quantized(lhs)
の場合:- (C34)
element_type(lhs) = expressed_type(rhs) = element_type(result)
。
- (C28)
例
// %lhs: [[
// [[1], [2], [5], [6]],
// [[3], [4], [7], [8]],
// [[10], [11], [14], [15]],
// [[12], [13], [16], [17]]
// ]]
//
// %rhs: [
// [[[1]], [[1]], [[1]]],
// [[[1]], [[1]], [[1]]],
// [[[1]], [[1]], [[1]]]
// ]
// %padding: [[1, 1],
// [1, 1]]
%result = "stablehlo.dynamic_conv"(%lhs, %rhs, %padding) {
window_strides = array<i64: 4, 4>,
lhs_dilation = array<i64: 2, 2>,
rhs_dilation = array<i64: 1, 1>,
window_reversal = array<i1: false, false>,
dimension_numbers = #stablehlo.conv<raw
input_batch_dimension = 0,
input_feature_dimension = 3,
input_spatial_dimensions = [0, 1],
kernel_input_feature_dimension = 2,
kernel_output_feature_dimension = 3,
kernel_spatial_dimensions = [0, 1],
output_batch_dimension = 0,
output_feature_dimension = 3,
output_spatial_dimensions = [1, 2]
>,
feature_group_count = 1 : i64,
batch_group_count = 1 : i64,
precision_config = [#stablehlo<precision DEFAULT>, #stablehlo<precision DEFAULT>]
} : (tensor<1x4x4x1xi64>, tensor<3x3x1x1xi64>, tensor<2x2xi64>) -> tensor<1x2x2x1xi64>
// %result: [[
// [[1], [5]],
// [[10], [14]]
// ]]
dynamic_gather
セマンティクス
このオペレーションは、slice_sizes
が値として動的に指定される gather オペレーションと機能的に同じです。
入力
ラベル | 名前 | タイプ | 制約 |
---|---|---|---|
(I1) | operand |
テンソルまたはテンソルごとの量子化テンソル | (C1)、(C7)、(C10~C12)、(C14) |
(I2) | start_indices |
整数型のテンソル | (C2)、(C3)、(C13) |
(I3) | slice_sizes |
整数型の 1 次元テンソル | (C8)、(C11~C13) |
(I4) | offset_dims |
si64 型の 1 次元のテンソル定数 |
(C1)、(C4 ~ C5)、(C13) |
(I5) | collapsed_slice_dims |
si64 型の 1 次元テンソル定数 |
(C1)、(C6~C8)、(C13) |
(I6) | start_index_map |
si64 型の 1 次元のテンソル定数 |
(C3)、(C9)、(C10) |
(I7) | index_vector_dim |
si64 型の定数 |
(C2)、(C3)、(C13) |
(I8) | indices_are_sorted |
i1 型の定数 |
出力
名前 | タイプ | 制約 |
---|---|---|
result |
テンソルまたはテンソルごとの量子化テンソル | (C5)、(C13 ~ C14) |
制約
- (C1)
rank(operand) = size(offset_dims) + size(collapsed_slice_dims)
。 - (C2)
0 <= index_vector_dim <= rank(start_indices)
。 - (C3)
size(start_index_map) = index_vector_dim < rank(start_indices) ? dim(start_indices, index_vector_dim) : 1
。 - (C4)
is_unique(offset_dims) and is_sorted(offset_dims)
。 - (C5)
0 <= offset_dims < rank(result)
。 - (C6)
is_unique(collapsed_slice_dims) and is_sorted(collapsed_slice_dims)
。 - (C7)
0 <= collapsed_slice_dims < rank(operand)
。 - (C8)
slice_sizes[collapsed_slice_dims...] <= 1
。 - (C9)
is_unique(start_index_map)
。 - (C10)
0 <= start_index_map < rank(operand)
。 - (C11)
size(slice_sizes) = rank(operand)
。 - (C12)
0 <= slice_sizes <= shape(operand)
。 - (C13)
shape(result) = combine(batch_dim_sizes, offset_dim_sizes)
ここで:batch_dim_sizes = shape(start_indices)
とほぼ同じですが、index_vector_dim
に対応するstart_indices
のディメンション サイズは含まれません。offset_dim_sizes = shape(slice_sizes)
。ただし、collapsed_slice_dims
に対応するslice_sizes
のディメンション サイズは含まれません。combine
は、batch_dims
に対応する軸にbatch_dim_sizes
を配置し、offset_dims
に対応する軸にoffset_dim_sizes
を配置します。
- (C14)
element_type(operand) = element_type(result)
。
例
// %operand: [
// [[1, 2], [3, 4], [5, 6], [7, 8]],
// [[9, 10],[11, 12], [13, 14], [15, 16]],
// [[17, 18], [19, 20], [21, 22], [23, 24]]
// ]
// %start_indices: [
// [[0, 0], [1, 0], [2, 1]],
// [[0, 1], [1, 1], [0, 2]]
// ]
// %slize_sizes: [1, 2, 2]
%result = "stablehlo.dynamic_gather"(%operand, %start_indices, %slize_sizes) {
dimension_numbers = #stablehlo.gather<
offset_dims = [2, 3],
collapsed_slice_dims = [0],
start_index_map = [1, 0],
index_vector_dim = 2>,
indices_are_sorted = false
} : (tensor<3x4x2xi64>, tensor<2x3x2xi64>, tensor<3xi64>) -> tensor<2x3x2x2xi64>
// %result: [
// [
// [[1, 2], [3, 4]],
// [[3, 4], [5, 6]],
// [[13, 14], [15, 16]]
// ],
// [
// [[9, 10], [11, 12]],
// [[11, 12], [13, 14]],
// [[17, 18], [19, 20]]
// ]
// ]
dynamic_iota
セマンティクス
この演算は iota op と機能的に同じですが、結果の形状は output_shape
によって動的に指定されます。
入力
ラベル | 名前 | タイプ | 制約 |
---|---|---|---|
(I1) | output_shape |
整数型の 1 次元テンソル | (C1)、(C2) |
(I2) | iota_dimension |
si64 |
(C1) |
出力
名前 | タイプ | 制約 |
---|---|---|
result |
整数型、浮動小数点型、複合型のテンソル、またはテンソルごとの量子化テンソル | (C2) |
制約
- (C1)
0 <= iota_dimension < size(output_shape)
。 - (C2)
rank(result) = size(output_shape)
。
例
%output_shape = stablehlo.constant dense<[4, 5]> : tensor<2xi64>
%result = "stablehlo.dynamic_iota"(%output_shape) {
iota_dimension = 0 : i64
} : (tensor<2xi64>) -> tensor<4x5xi64>
// %result: [
// [0, 0, 0, 0, 0],
// [1, 1, 1, 1, 1],
// [2, 2, 2, 2, 2],
// [3, 3, 3, 3, 3]
// ]
dynamic_pad
セマンティクス
このオペレーションは、pad オペレーションと機能的には同じですが、edge_padding_low
、edge_padding_high
、interior_padding
が値として動的に指定されます。
入力
ラベル | 名前 | タイプ | 制約 |
---|---|---|---|
(I1) | operand |
テンソルまたはテンソルごとの量子化テンソル | (C1)、(C2)、(C4) |
(I2) | padding_value |
0 次元テンソルまたはテンソルごとの量子化テンソル | (C1) |
(I3) | edge_padding_low |
整数型の 1 次元テンソル | (C1)、(C4) |
(I4) | edge_padding_high |
整数型の 1 次元テンソル | (C1)、(C4) |
(I5) | interior_padding |
整数型の 1 次元テンソル | (C2-C4) |
出力
名前 | タイプ | 制約 |
---|---|---|
result |
テンソルまたはテンソルごとの量子化テンソル | (C3-C6) |
制約
- (C1)
element_type(operand) = element_type(padding_value) = element_type(result)
。 - (C2)
size(edge_padding_low) = size(edge_padding_high) = size(interior_padding) = rank(operand)
。 - (C3)
0 <= interior_padding
。 - (C4)
shape(result) = shape(operand) + edge_padding_low + max(shape(operand) - 1, 0) * interior_padding + edge_padding_high
。
例
// %operand: [
// [1, 2, 3],
// [4, 5, 6]
// ]
// %padding_value: 0
// %edge_padding_low: [0, 1]
// %edge_padding_high: [2, 1]
// %interior_padding: [1, 2]
%result = "stablehlo.dynamic_pad"(%operand, %padding_value,
%edge_padding_low, %edge_padding_high, %interior_padding
) : (tensor<2x3xi64>, tensor<i64>, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor<5x9xi64>
// %result: [
// [0, 1, 0, 0, 2, 0, 0, 3, 0],
// [0, 0, 0, 0, 0, 0, 0, 0, 0],
// [0, 4, 0, 0, 5, 0, 0, 6, 0],
// [0, 0, 0, 0, 0, 0, 0, 0, 0],
// [0, 0, 0, 0, 0, 0, 0, 0, 0]
// ]
dynamic_reshape
セマンティクス
このオペレーションは、reshape オペレーションと機能的には同じですが、結果の形状は output_shape
を介して動的に指定されます。
入力
ラベル | 名前 | タイプ | 制約 |
---|---|---|---|
(I1) | operand |
テンソルまたは量子化テンソル | (C1~C3) |
(I2) | output_shape |
整数型の 1 次元テンソル | (C4) |
出力
名前 | タイプ | 制約 |
---|---|---|
result |
テンソルまたは量子化テンソル | (C1~C4) |
制約
- (C1)
element_type(result)
は次のように定義されます。element_type(operand)
(!is_per_axis_quantized(operand)
の場合)。element_type(operand)
。ただし、quantization_dimension(operand)
とquantization_dimension(result)
が異なる場合があります。
- (C2)
size(operand) = size(result)
。 - (C3)
is_per_axis_quantized(operand)
の場合:reduce(dims(operand, [0, 1, ..., quantization_dimension(operand) - 1]), init_values=1, dimensions=[0], body=lambda x, y: x * y) = reduce(dims(result, [0, 1, ..., quantization_dimension(result) - 1]), init_values=1, dimensions=[0], body=lambda x, y: x * y)
。dim(operand, quantization_dimension(operand)) = dim(result, quantization_dimension(result))
reduce(dims(operand, [quantization_dimension(operand) + 1, ..., rank(operand) - 1]), init_values=1, dimensions=[0], body=lambda x, y: x * y) = reduce(dims(result, [quantization_dimension(result) + 1, ..., rank(result) - 1]), init_values=1, dimensions=[0], body=lambda x, y: x * y)
。
- (C4)
size(output_shape) = rank(result)
。
例
// %operand: [[1, 2, 3], [4, 5, 6]]
// %output_shape: [3, 2]
%result = "stablehlo.dynamic_reshape"(%operand, %output_shape) : (tensor<2x3xi64>, tensor<2xi64>) -> tensor<3x2xi64>
// %result: [[1, 2], [3, 4], [5, 6]]
dynamic_slice
セマンティクス
動的に計算された開始インデックスを使用して operand
からスライスを取り出し、result
テンソルを生成します。start_indices
には、調整の対象となる各ディメンションのスライスの開始インデックスが含まれ、slice_sizes
には各ディメンションのスライスのサイズが含まれます。より正式には、result[result_index] = operand[operand_index]
です。ここで:
adjusted_start_indices = clamp(0, start_indices, shape(operand) - slice_sizes)
。operand_index = adjusted_start_indices + result_index
。
入力
ラベル | 名前 | タイプ | 制約 |
---|---|---|---|
(I1) | operand |
テンソルまたはテンソルごとの量子化テンソル | (C1)、(C2)、(C4) |
(I2) | start_indices |
整数型の 0 次元テンソルの可変長の数 | (C2)、(C3) |
(I3) | slice_sizes |
si64 型の 1 次元テンソル定数 |
(C2)、(C4)、(C5) |
出力
名前 | タイプ | 制約 |
---|---|---|
result |
テンソルまたはテンソルごとの量子化テンソル | (C1)、(C5) |
制約
- (C1)
element_type(operand) = element_type(result)
。 - (C2)
size(start_indices) = size(slice_sizes) = rank(operand)
。 - (C3)
same(type(start_indices...))
。 - (C4)
0 <= slice_sizes <= shape(operand)
。 - (C5)
shape(result) = slice_sizes
。
例
// %operand: [
// [0, 0, 1, 1],
// [0, 0, 1, 1],
// [0, 0, 0, 0],
// [0, 0, 0, 0]
// ]
// %start_indices0: -1
// %start_indices1: 3
%result = "stablehlo.dynamic_slice"(%operand, %start_indices0, %start_indices1) {
slice_sizes = array<i64: 2, 2>
} : (tensor<4x4xi32>, tensor<i64>, tensor<i64>) -> tensor<2x2xi32>
// %result: [
// [1, 1],
// [1, 1]
// ]
dynamic_update_slice
セマンティクス
start_indices
で始まるスライスが update
の値で更新される点を除き、operand
テンソルと同じ result
テンソルを生成します。より正式には、result[result_index]
は次のように定義されます。
update[update_index]
:0 <= update_index < shape(update)
の場合。条件は次のとおりです。adjusted_start_indices = clamp(0, start_indices, shape(operand) - shape(update))
。update_index = result_index - adjusted_start_indices
。
- それ以外の場合は
operand[result_index]
。
入力
ラベル | 名前 | タイプ | 制約 |
---|---|---|---|
(I1) | operand |
テンソルまたはテンソルごとの量子化テンソル | (C1 ~ C4)、(C6) |
(I2) | update |
テンソルまたはテンソルごとの量子化テンソル | (C2)、(C3)、(C6) |
(I3) | start_indices |
整数型の 0 次元テンソルの可変長の数 | (C4)、(C5) |
出力
名前 | タイプ | 制約 |
---|---|---|
result |
テンソルまたはテンソルごとの量子化テンソル | (C1) |
制約
- (C1)
type(operand) = type(result)
。 - (C2)
element_type(update) = element_type(operand)
。 - (C3)
rank(update) = rank(operand)
。 - (C4)
size(start_indices) = rank(operand)
。 - (C5)
same(type(start_indices...))
。 - (C6)
0 <= shape(update) <= shape(operand)
。
例
// %operand: [
// [1, 1, 0, 0],
// [1, 1, 0, 0],
// [1, 1, 1, 1],
// [1, 1, 1, 1]
// ]
// %update: [
// [1, 1],
// [1, 1]
// ]
// %start_indices0: -1
// %start_indices1: 3
%result = "stablehlo.dynamic_update_slice"(%operand, %update, %start_indices0, %start_indices1)
: (tensor<4x4xi32>, tensor<2x2xi32>, tensor<i64>, tensor<i64>) -> tensor<4x4xi32>
// %result: [
// [1, 1, 1, 1],
// [1, 1, 1, 1],
// [1, 1, 1, 1],
// [1, 1, 1, 1]
// ]
指数
セマンティクス
operand
テンソルの要素ごとの指数演算を実行し、result
テンソルを生成します。要素のタイプに応じて、次の処理を行います。
- 浮動小数点数の場合: IEEE-754 の
exp
。 - 複素数の場合: 複素指数関数。
- 量子化された型の場合:
dequantize_op_quantize(exponential, operand, type(result))
。
入力
ラベル | 名前 | タイプ | 制約 |
---|---|---|---|
(I1) | operand |
浮動小数点型または複合型のテンソル、またはテンソルごとの量子化テンソル | (C1) |
出力
名前 | タイプ | 制約 |
---|---|---|
result |
浮動小数点型または複合型のテンソル、またはテンソルごとの量子化テンソル | (C1) |
制約
- (C1)
baseline_type(operand) = baseline_type(result)
。
例
// %operand: [[0.0, 1.0], [2.0, 3.0]]
%result = "stablehlo.exponential"(%operand) : (tensor<2x2xf64>) -> tensor<2x2xf64>
// %result: [[1.0, 2.7182818284590451], [7.3890560989306504, 20.085536923187668]]
exponential_minus_one
セマンティクス
operand
テンソルに対して要素ごとの指数マイナス 1 演算を実行し、result
テンソルを生成します。要素のタイプに応じて、次の操作を行います。
- 浮動小数点数: IEEE-754 の
expm1
。 - 複素数の場合: 複素指数 - 1。
- 量子化された型の場合:
dequantize_op_quantize(exponential_minus_one, operand, type(result))
。
入力
ラベル | 名前 | タイプ | 制約 |
---|---|---|---|
(I1) | operand |
浮動小数点型または複合型のテンソル、またはテンソルごとの量子化テンソル | (C1) |
出力
名前 | タイプ | 制約 |
---|---|---|
result |
浮動小数点型または複合型のテンソル、またはテンソルごとの量子化テンソル | (C1) |
制約
- (C1)
baseline_type(operand) = baseline_type(result)
。
例
// %operand: [0.0, 1.0]
%result = "stablehlo.exponential_minus_one"(%operand) : (tensor<2xf64>) -> tensor<2xf64>
// %result: [0.0, 1.71828187]
fft
セマンティクス
実数と複素数の入力 / 出力の正規化と逆正規化を行います。
fft_type
は、次のいずれかです。
FFT
: 複雑な FFT を転送します。IFFT
: 複素数から複素数への逆 FFT。RFFT
: 実数から複素数への FFT を前方処理します。IRFFT
: 実数から複素数への逆 FFT(複素数を受け取り、実数を返します)。
より形式的には、複雑な型の 1 次元テンソルを入力として受け取る関数 fft
は、出力と同じ型の 1 次元テンソルを生成し、離散フーリエ変換を計算します。
fft_type = FFT
の場合、result
は、L = size(fft_length)
の連続した L 計算の最終結果として定義されます。たとえば、L = 3
の場合は次のようになります。
result1[i0, ..., :] = fft(operand[i0, ..., :])
。result2[i0, ..., :, iR-1] = fft(result1[i0, ..., :, iR-1])
result[i0, ..., :, iR-2, iR-1] = fft(result2[i0, ..., :, iR-2, iR-1])
。
さらに、同じ型シグネチャを持ち、fft
の逆数を計算する関数 ifft
があるとします。
fft_type = IFFT
の場合、result
は fft_type = FFT
の計算の逆数として定義されます。たとえば、L = 3
の場合は次のようになります。
result1[i0, ..., :, iR-2, iR-1] = ifft(operand[i0, ..., :, iR-2, iR-1])
。result2[i0, ..., :, iR-1] = ifft(result1[i0, ..., :, iR-1])
result[i0, ..., :] = ifft(result2[i0, ..., :])
。
さらに、浮動小数点型の 1 次元テンソルを取る関数 rfft
は、同じ浮動小数点セマンティクスの複合型の 1 次元テンソルを生成し、次のように動作します。
rfft(real_operand) = truncated_result
complex_operand... = (real_operand..., 0.0)
。complex_result = fft(complex_operand)
truncated_result = complex_result[:(rank(complex_result) / 2 + 1)]
。
(実数オペランドに対して離散フーリエ変換が計算される場合、結果の最初の N/2 + 1
要素が結果の残りの部分を明確に定義するため、冗長な要素の計算を避けるために rfft
の結果は切り捨てられます)。
fft_type = RFFT
の場合、result
は、L = size(fft_length)
の連続した L 計算の最終結果として定義されます。たとえば、L = 3
の場合は次のようになります。
result1[i0, ..., :] = rfft(operand[i0, ..., :])
。result2[i0, ..., :, iR-1] = fft(result1[i0, ..., :, iR-1])
result[i0, ..., :, iR-2, iR-1] = fft(result2[i0, ..., :, iR-2, iR-1])
。
最後に、同じ型のシグネチャを持つ関数 irfft
が与えられ、rfft
の逆数を計算します。
fft_type = IRFFT
の場合、result
は fft_type = RFFT
の計算の逆数として定義されます。たとえば、L = 3
の場合は次のようになります。
result1[i0, ..., :, iR-2, iR-1] = ifft(operand[i0, ..., :, iR-2, iR-1])
。result2[i0, ..., :, iR-1] = ifft(result1[i0, ..., :, iR-1])
result[i0, ..., :] = irfft(result2[i0, ..., :])
。
入力
ラベル | 名前 | タイプ | 制約 |
---|---|---|---|
(I1) | operand |
浮動小数点型または複合型のテンソル | (C1)、(C2)、(C4)、(C5) |
(I2) | fft_type |
FFT 、IFFT 、RFFT 、IRFFT の列挙型 |
(C2)、(C5) |
(I3) | fft_length |
si64 型の 1 次元のテンソル定数 |
(C1)、(C3)、(C4) |
出力
名前 | タイプ | 制約 |
---|---|---|
result |
浮動小数点数または複素数のテンソル | (C2)、(C4)、(C5) |
制約
- (C1)
size(fft_length) <= rank(operand)
。 - (C2)
operand
要素とresult
要素型の関係はさまざまです。fft_type = FFT
、element_type(operand)
、element_type(result)
が同じ複合型の場合。fft_type = IFFT
、element_type(operand)
、element_type(result)
が同じ複合型の場合。fft_type = RFFT
の場合、element_type(operand)
は浮動小数点型で、element_type(result)
は同じ浮動小数点セマンティクスの複合型です。fft_type = IRFFT
の場合、element_type(operand)
は複合型であり、element_type(result)
は同じ浮動小数点セマンティクスの浮動小数点型です。
- (C3)
1 <= size(fft_length) <= 3
。 - (C4)
operand
とresult
の間に浮動小数点型のテンサーreal
がある場合、shape(real)[-size(fft_length):] = fft_length
。 - (C5)次の点を除き
shape(result) = shape(operand)
。fft_type = RFFT
の場合はdim(result, -1) = dim(operand, -1) = 0 ? 0 : dim(operand, -1) / 2 + 1
。fft_type = IRFFT
の場合はdim(operand, -1) = dim(result, -1) = 0 ? 0 : dim(result, -1) / 2 + 1
。
例
// %operand: [(1.0, 0.0), (0.0, 0.0), (0.0, 0.0), (0.0, 0.0)]
%result = "stablehlo.fft"(%operand) {
fft_type = #stablehlo<fft_type FFT>,
fft_length = array<i64: 4>
} : (tensor<4xcomplex<f32>>) -> tensor<4xcomplex<f32>>
// %result: [(1.0, 0.0), (1.0, 0.0), (1.0, 0.0), (1.0, 0.0)]
floor
セマンティクス
operand
テンソルの要素ごとの床演算を行い、result
テンソルを生成します。IEEE-754 仕様の roundToIntegralTowardNegative
オペレーションを実装します。量子化型の場合は、dequantize_op_quantize(floor, operand, type(result))
を実行します。
入力
ラベル | 名前 | タイプ | 制約 |
---|---|---|---|
(I1) | operand |
浮動小数点型のテンソルまたはテンソルごとの量子化テンソル | (C1) |
出力
名前 | タイプ | 制約 |
---|---|---|
result |
浮動小数点型のテンソルまたはテンソルごとの量子化テンソル | (C1) |
制約
- (C1)
baseline_type(operand) = baseline_type(result)
。
例
// %operand: [-0.8166, -0.2530, 0.2530, 0.8166, 2.0]
%result = "stablehlo.floor"(%operand) : (tensor<5xf32>) -> tensor<5xf32>
// %result: [-1.0, -1.0, 0.0, 0.0, 2.0]
集める
セマンティクス
start_indices
で指定されたオフセットから operand
テンソルのスライスを収集し、result
テンソルを生成します。
次の図は、result
の要素が operand
の要素にどのようにマッピングされるかを、具体的な例を使用して示しています。この図では、いくつかの result
インデックスの例を選択し、それらが対応する operand
インデックスについて詳しく説明しています。
より正式には、result[result_index] = operand[operand_index]
です。ここで:
batch_dims = [d for d in axes(result) and d not in offset_dims]
。batch_index = result_index[batch_dims...]
。start_index
は次のように定義されます。start_indices[bi0, ..., :, ..., biN]
。ここで、bi
はbatch_index
内の個々の要素で、:
はindex_vector_dim
インデックスに挿入されます(index_vector_dim
<rank(start_indices)
の場合)。- それ以外の場合は
[start_indices[batch_index]]
。
axes(operand)
のd_operand
は、d_operand = start_index_map[d_start]
の場合、full_start_index[d_operand] = clamp(start_index[d_start], 0, dim(operand, d_operand) - slice_sizes[d_operand])
- そうでない場合は
full_start_index[d_operand] = 0
。
axes(operand)
のd_operand
は、full_batching_index[d_operand] = batch_index[d_start - (d_start < index_vector_dim ? 0 : 1)]
d_operand = operand_batching_dims[i_batching]
とd_start = start_indices_batching_dims[i_batching]
の場合。- それ以外の場合は
full_batching_index[d_operand] = 0
。
offset_index = result_index[offset_dims...]
。full_offset_index = [oi0, ..., 0, ..., oiN]
。ここで、oi
はoffset_index
の個々の要素であり、0
はcollapsed_slice_dims
とoperand_batching_dims
のインデックスに挿入されます。operand_index = full_start_index + full_batching_index + full_offset_index
。
indices_are_sorted
が true
の場合、実装では start_indices
が start_index_map
に関して並べ替えられていると想定できます。それ以外の場合、動作は未定義です。より正式には、indices(result)
のすべての i1 < i2
について、full_start_index(i1) <= full_start_index(i2)
です。
入力
ラベル | 名前 | タイプ | 制約 |
---|---|---|---|
(I1) | operand |
テンソルまたはテンソルごとの量子化テンソル | (C1)、(C8)、(C11)、(C17)、(C19 ~ C21)、(C23) |
(I2) | start_indices |
整数型のテンソル | (C2~C3)、(C14)、(C17)、(C22) |
(I3) | offset_dims |
si64 型の 1 次元テンソル定数 |
(C1)、(C4 ~ C5)、(C22) |
(I4) | collapsed_slice_dims |
si64 型の 1 次元のテンソル定数 |
(C1)、(C6 ~ C9)、(C22) |
(I5) | operand_batching_dims |
si64 型の 1 次元テンソル定数 |
(C1)、(C6)、(C10 ~ C12)、(C16 ~ C18)、(C22) |
(I6) | start_indices_batching_dims |
si64 型の 1 次元テンソル定数 |
(C13 ~ C17) |
(I7) | start_index_map |
si64 型の 1 次元テンソル定数 |
(C3)、(C18~C19) |
(I8) | index_vector_dim |
si64 型の定数 |
(C2 ~ C3)、(C15)、(C22) |
(I9) | slice_sizes |
si64 型の 1 次元のテンソル定数 |
(C9)、(C12)、(C20~C22) |
(I10) | indices_are_sorted |
i1 型の定数 |
出力
名前 | タイプ | 制約 |
---|---|---|
result |
テンソルまたはテンソルごとの量子化テンソル | (C5)、(C22 ~ C23) |
制約
- (C1)
rank(operand) = size(offset_dims) + size(collapsed_slice_dims) + size(operand_batching_dims)
。 - (C2)
0 <= index_vector_dim <= rank(start_indices)
。 - (C3)
size(start_index_map) = index_vector_dim < rank(start_indices) ? dim(start_indices, index_vector_dim) : 1
。 - (C4)
is_unique(offset_dims) and is_sorted(offset_dims)
。 - (C5)
0 <= offset_dims < rank(result)
。 - (C6)
is_unique(concatenate(collapsed_slice_dims, operand_batching_dims))
- (C7)
is_sorted(collapsed_slice_dims)
。 - (C8)
0 <= collapsed_slice_dims < rank(operand)
。 - (C9)
slice_sizes[collapsed_slice_dims...] <= 1
。 - (C10)
is_sorted(operand_batching_dims)
。 - (C11)
0 <= operand_batching_dims < rank(operand)
。 - (C12)
slice_sizes[operand_batching_dims...] <= 1
。 - (C13)
is_unique(start_indices_batching_dims)
。 - (C14)
0 <= start_indices_batching_dims < rank(start_indices)
。 - (C15)
index_vector_dim not in start_indices_batching_dims
。 - (C16)
size(operand_batching_dims) == size(start_indices_batching_dims)
。 - (C17)
dim(operand, operand_batching_dims...) = dim(start_indices, start_indices_batching_dims...)
。 - (C18)
is_unique(concatenate(start_index_map, operand_batching_dims))
。 - (C19)
0 <= start_index_map < rank(operand)
。 - (C20)
size(slice_sizes) = rank(operand)
。 - (C21)
0 <= slice_sizes <= shape(operand)
。 - (C22)
shape(result) = combine(batch_dim_sizes, offset_dim_sizes)
。ここで:batch_dim_sizes = shape(start_indices)
とほぼ同じですが、index_vector_dim
に対応するstart_indices
のディメンション サイズは含まれません。offset_dim_sizes = slice_sizes
。ただし、collapsed_slice_dims
とoperand_batching_dims
に対応するslice_sizes
のディメンション サイズは含まれません。combine
は、batch_dims
に対応する軸にbatch_dim_sizes
を配置し、offset_dims
に対応する軸にoffset_dim_sizes
を配置します。
- (C23)
element_type(operand) = element_type(result)
。
例
// %operand: [
// [
// [[1, 2], [3, 4], [5, 6], [7, 8]],
// [[9, 10],[11, 12], [13, 14], [15, 16]],
// [[17, 18], [19, 20], [21, 22], [23, 24]]
// ],
// [
// [[25, 26], [27, 28], [29, 30], [31, 32]],
// [[33, 34], [35, 36], [37, 38], [39, 40]],
// [[41, 42], [43, 44], [45, 46], [47, 48]]
// ]
// ]
// %start_indices: [
// [
// [[0, 0], [1, 0], [2, 1]],
// [[0, 1], [1, 1], [0, 9]]
// ],
// [
// [[0, 0], [2, 1], [2, 2]],
// [[1, 2], [0, 1], [1, 0]]
// ]
// ]
%result = "stablehlo.gather"(%operand, %start_indices) {
dimension_numbers = #stablehlo.gather<
offset_dims = [3, 4],
collapsed_slice_dims = [1],
operand_batching_dims = [0],
start_indices_batching_dims = [1],
start_index_map = [2, 1],
index_vector_dim = 3>,
slice_sizes = array<i64: 1, 1, 2, 2>,
indices_are_sorted = false
} : (tensor<2x3x4x2xi32>, tensor<2x2x3x2xi64>) -> tensor<2x2x3x2x2xi32>
// %result: [
// [
// [
// [[1, 2], [3, 4]],
// [[3, 4], [5, 6]],
// [[13, 14], [15, 16]]
// ],
// [
// [[33, 34], [35, 36]],
// [[35, 36], [37, 38]],
// [[41, 42], [43, 44]]
// ]
// ],
// [
// [
// [[1, 2], [3, 4]],
// [[13, 14], [15, 16]],
// [[21, 22], [23, 24]]
// ],
// [
// [[43, 44], [45, 46]],
// [[33, 34], [35, 36]],
// [[27, 28], [29, 30]]
// ]
// ]
// ]
get_dimension_size
セマンティクス
operand
の指定された dimension
のサイズを生成します。正式には result = dim(operand, dimension)
です。セマンティクスは、型のシェイプ コンポーネントにのみ関係します。要素の型は任意です。
入力
ラベル | 名前 | タイプ | 制約 |
---|---|---|---|
(I1) | operand |
テンソルまたは量子化テンソル | (C1) |
(I2) | dimension |
si64 型の定数 |
(C1) |
出力
名前 | タイプ |
---|---|
result |
si32 型の 0 次元テンソル |
制約
- (C1)
0 <= dimension < rank(operand)
。
例
// %operand: [[1, 2, 3], [4, 5, 6]]
%result = "stablehlo.get_dimension_size"(%operand) {
dimension = 1 : i64
} : (tensor<2x3xi64>) -> tensor<i32>
// %result: 3
get_tuple_element
セマンティクス
operand
タプルの index
位置にある要素を抽出して result
を生成します。正式には result = operand[index]
です。
入力
ラベル | 名前 | タイプ | 制約 |
---|---|---|---|
(I1) | operand |
tuple | (C1)、(C2) |
(I2) | index |
si32 型の定数 |
(C1)、(C2) |
出力
名前 | タイプ | 制約 |
---|---|---|
result |
サポートされている任意のタイプ | (C2) |
制約
- (C1)
0 <= index < size(operand)
。 - (C2)
type(result) = tuple_element_types(operand)[index]
。
例
// %operand: ([1.0, 2.0], (3))
index = 0 : i32
} : (tuple<tensor<2xf32>, tuple<tensor<i32>>>) -> tensor<2xf32>
// %result: [1.0, 2.0]
~なら
セマンティクス
pred
の値に応じて、true_branch
または false_branch
から関数を 1 つだけ実行することで、出力を生成します。(より正式には、result =
pred ? true_branch() : false_branch()
)。
入力
ラベル | 名前 | タイプ | 制約 |
---|---|---|---|
(I1) | pred |
i1 型の 0 次元テンソル |
|
(I2) | true_branch |
関数 | (C1~C3) |
(I3) | false_branch |
関数 | (C1)、(C2) |
出力
名前 | タイプ | 制約 |
---|---|---|
results |
テンソル、量子化テンソル、トークンの可変数 | (C3) |
制約
- (C1)
input_types(true_branch) = input_types(false_branch) = []
。 - (C2)
output_types(true_branch) = output_types(false_branch)
。 - (C3)
type(results...) = output_types(true_branch)
。
例
// %result_true_branch: 10
// %result_false_branch: 11
// %pred: true
%result = "stablehlo.if"(%pred) ({
"stablehlo.return"(%result_true_branch) : (tensor<i32>) -> ()
}, {
"stablehlo.return"(%result_false_branch) : (tensor<i32>) -> ()
}) : (tensor<i1>) -> tensor<i32>
// %result: 10
imag
セマンティクス
operand
から虚部を要素ごとに抽出し、result
テンソルを生成します。より正式には、各要素 x
について imag(x) = is_complex(x) ? imaginary_part(x) :
constant(0, element_type(result))
です。
入力
ラベル | 名前 | タイプ | 制約 |
---|---|---|---|
(I1) | operand |
浮動小数点数または複素数のテンソル | (C1)、(C2) |
出力
名前 | タイプ | 制約 |
---|---|---|
result |
浮動小数点型のテンソル | (C1)、(C2) |
制約
- (C1)
shape(result) = shape(operand)
。 - (C2)
element_type(result)
は次のように定義されます。is_complex(operand)
の場合、complex_element_type(element_type(operand))
。- それ以外の場合は
element_type(operand)
。
例
// %operand: [(1.0, 2.0), (3.0, 4.0)]
%result = "stablehlo.imag"(%operand) : (tensor<2xcomplex<f32>>) -> tensor<2xf32>
// %result: [2.0, 4.0]
インフィード
セマンティクス
インフィードからデータを読み取り、results
を生成します。
infeed_config
のセマンティクスは実装で定義されます。
results
は、先頭にペイロード値、最後にトークンで構成されます。今後、明確性を高めるために、ペイロードとトークンを 2 つの個別の出力に分割する予定です(#670)。
入力
ラベル | 名前 | タイプ |
---|---|---|
(I1) | token |
token |
(I2) | infeed_config |
string 型の定数 |
出力
名前 | タイプ | 制約 |
---|---|---|
results |
テンソル、量子化テンソル、またはトークンの可変数 | (C1~C3) |
制約
- (C1)
0 < size(results)
。 - (C2)
is_empty(result[:-1])
またはis_tensor(type(results[:-1]))
。 - (C3)
is_token(type(results[-1]))
。
例
// %token: !stablehlo.token
// infeed_queue[0]: [[1, 2], [3, 4]]
// infeed_queue[1]: [[5, 6], [7, 8]]
%results0:2 = "stablehlo.infeed"(%token) {
infeed_config = ""
} : (!stablehlo.token) -> (tensor<2x2xi64>, !stablehlo.token)
// results0#0: [[1, 2], [3, 4]]
%results1:2 = "stablehlo.infeed"(%token) {
infeed_config = ""
} : (!stablehlo.token) -> (tensor<2x2xi64>, !stablehlo.token)
// results1#0: [[5, 6], [7, 8]]
iota
セマンティクス
output
テンソルを、iota_dimension
ディメンションに沿って 0 から順に増加する値で埋めます。より正式には、
output[output_index] = constant(is_quantized(output) ?
quantize(output_index[iota_dimension], element_type(output)) :
output_index[iota_dimension], element_type(output))
。
入力
ラベル | 名前 | タイプ | 制約 |
---|---|---|---|
(I1) | iota_dimension |
si64 |
(C1) |
出力
名前 | タイプ | 制約 |
---|---|---|
output |
整数型、浮動小数点型、複合型のテンソル、またはテンソルごとの量子化テンソル | (C1) |
制約
- (C1)
0 <= iota_dimension < rank(output)
。
例
%output = "stablehlo.iota"() {
iota_dimension = 0 : i64
} : () -> tensor<4x5xi32>
// %output: [
// [0, 0, 0, 0, 0],
// [1, 1, 1, 1, 1],
// [2, 2, 2, 2, 2],
// [3, 3, 3, 3, 3]
// ]
%output = "stablehlo.iota"() {
iota_dimension = 1 : i64
} : () -> tensor<4x5xi32>
// %output: [
// [0, 1, 2, 3, 4],
// [0, 1, 2, 3, 4],
// [0, 1, 2, 3, 4],
// [0, 1, 2, 3, 4]
// ]
is_finite
セマンティクス
x
の値が有限(+Inf、-Inf、NaN のいずれでもない)かどうかを要素ごとにチェックし、y
テンソルを生成します。IEEE-754 仕様の isFinite
オペレーションを実装します。量子化された型の場合、結果は常に true
です。
入力
ラベル | 名前 | タイプ | 制約 |
---|---|---|---|
(I1) | x |
浮動小数点型のテンソルまたはテンソルごとの量子化テンソル | (C1) |
出力
名前 | タイプ | 制約 |
---|---|---|
y |
ブール型のテンソル | (C1) |
制約
- (C1)
shape(x) = shape(y)
。
例
// Logical values: -Inf, +Inf, NaN, ...
// %x: [0xFFF0000000000000, 0x7FF0000000000000, 0x7FF8000000000000, -10.0, -0.0, 0.0, 10.0]
%y = "stablehlo.is_finite"(%x) : (tensor<7xf64) -> tensor<7xi1>
// %y: [false, false, false, true, true, true, true]
log
セマンティクス
operand
テンソルの要素ごとのログ演算を実行し、result
テンソルを生成します。要素のタイプに応じて、次の処理を行います。
- 浮動小数点数の場合: IEEE-754 の
log
。 - 複素数の場合: 複素対数。
- 量子化型の場合:
dequantize_op_quantize(log, operand, type(result))
。
入力
ラベル | 名前 | タイプ | 制約 |
---|---|---|---|
(I1) | operand |
浮動小数点型または複合型のテンソル、またはテンソルごとの量子化テンソル | (C1) |
出力
名前 | タイプ | 制約 |
---|---|---|
result |
浮動小数点型または複合型のテンソル、またはテンソルごとの量子化テンソル | (C1) |
制約
- (C1)
baseline_type(operand) = baseline_type(result)
。
例
// %operand: [[1.0, 2.0], [3.0, 4.0]]
%result = "stablehlo.log"(%operand) : (tensor<2x2xf64>) -> tensor<2x2xf64>
// %result: [[0.0, 0.69314718055994529], [1.0986122886681098, 1.3862943611198906]]
log_plus_one
セマンティクス
operand
テンソルに対して要素単位の対数に 1 演算を加え、result
テンソルを生成します。要素のタイプに応じて、次の処理を行います。
- 浮動小数点数の場合: IEEE-754 の
logp1
。 - 複素数の場合: 複素対数に 1 を加算します。
- 量子化された型の場合:
dequantize_op_quantize(log_plus_one, operand, type(result))
。
入力
ラベル | 名前 | タイプ | 制約 |
---|---|---|---|
(I1) | operand |
浮動小数点型または複合型のテンソル、またはテンソルごとの量子化テンソル | (C1) |
出力
名前 | タイプ | 制約 |
---|---|---|
result |
浮動小数点型または複合型のテンソル、またはテンソルごとの量子化テンソル | (C1) |
制約
- (C1)
baseline_type(operand) = baseline_type(result)
。
例
// %operand: [0.0, -0.999, 7.0, 6.38905621, 15.0]
%result = "stablehlo.log_plus_one"(%operand) : (tensor<5xf64>) -> tensor<5xf64>
// %result: [0.0, -6.90776825, 2.07944155, 2.0, 2.77258873]
ロジスティクス
セマンティクス
operand
テンソルの要素ごとのロジスティック演算を実行し、result
テンソルを生成します。要素のタイプに応じて、次の操作を行います。
- 浮動小数点数の場合: IEEE-754 の
division(1, addition(1, exp(-x)))
。 - 複素数の場合: 複素ロジスティック。
- 量子化された型の場合:
dequantize_op_quantize(logistic, operand, type(result))
。
入力
ラベル | 名前 | タイプ | 制約 |
---|---|---|---|
(I1) | operand |
浮動小数点型または複合型のテンソル、またはテンソルごとの量子化テンソル | (C1) |
出力
名前 | タイプ | 制約 |
---|---|---|
result |
浮動小数点型または複合型のテンソル、またはテンソルごとの量子化テンソル | (C1) |
制約
- (C1)
baseline_type(operand) = baseline_type(result)
。
例
// %operand: [[0.0, 1.0], [2.0, 3.0]]
%result = "stablehlo.logistic"(%operand) : (tensor<2x2xf64>) -> tensor<2x2xf64>
// %result: [[0.5, 0.73105858], [0.88079708, 0.95257413]]
地図
セマンティクス
dimensions
に沿ってマッピング関数 computation
を inputs
に適用し、result
テンソルを生成します。
正式には result[result_index] = computation(inputs...[result_index])
です。
入力
ラベル | 名前 | タイプ | 制約 |
---|---|---|---|
(I1) | inputs |
テンソルの可変数またはテンソルごとの量子化テンソル | (C1 ~ C4) |
(I2) | dimensions |
si64 型の 1 次元テンソル定数 |
(C3) |
(I3) | computation |
関数 | (C4) |
出力
名前 | タイプ | 制約 |
---|---|---|
result |
テンソルまたはテンソルごとの量子化テンソル | (C1)、(C4) |
制約
- (C1)
shape(inputs...) = shape(result)
。 - (C2)
0 < size(inputs) = N
。 - (C3)
dimensions = range(rank(inputs[0]))
。 - (C4)
computation
の型は(tensor<E0>, ..., tensor<EN-1>) -> tensor<E'>
で、Ei = element_type(inputs[i])
とE' = element_type(result)
です。
例
// %input0: [[0, 1], [2, 3]]
// %input1: [[4, 5], [6, 7]]
%result = "stablehlo.map"(%input0, %input1) ({
^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
%0 = stablehlo.multiply %arg0, %arg1 : tensor<i64>
stablehlo.return %0 : tensor<i64>
}) {
dimensions = array<i64: 0, 1>
} : (tensor<2x2xi64>, tensor<2x2xi64>) -> tensor<2x2xi64>
// %result: [[0, 5], [12, 21]]
最大
セマンティクス
テンソル lhs
と rhs
に対して要素ごとの最大演算を実行し、result
テンソルを生成します。要素のタイプに応じて、次の処理を行います。
- ブール値の場合: 論理 OR。
- 整数の場合: 整数の最大値。
- 浮動小数点数の場合: IEEE-754 の
maximum
。 - 複素数の場合:
(real, imaginary)
ペアの辞書順の最大値。複素数の順序付けには意外なセマンティクスが伴うため、将来的には、この演算での複素数のサポートは終了する予定です(#560)。 - 量子化された型の場合:
dequantize_op_quantize(maximum, lhs, rhs, type(result))
。
入力
ラベル | 名前 | タイプ | 制約 |
---|---|---|---|
(I1) | lhs |
テンソルまたはテンソルごとの量子化テンソル | (C1) |
(I2) | rhs |
テンソルまたはテンソルごとの量子化テンソル | (C1) |
出力
名前 | タイプ | 制約 |
---|---|---|
result |
テンソルまたはテンソルごとの量子化テンソル | (C1) |
制約
- (C1)
baseline_type(lhs) = baseline_type(rhs) = baseline_type(result)
。
例
// %lhs: [[1, 2], [7, 8]]
// %rhs: [[5, 6], [3, 4]]
%result = "stablehlo.maximum"(%lhs, %rhs) : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
// %result: [[5, 6], [7, 8]]
最小
セマンティクス
テンソル lhs
と rhs
に対して要素ごとの最小演算を実行し、result
テンソルを生成します。要素のタイプに応じて、次の処理を行います。
- ブール値の場合: 論理 AND。
- 整数の場合: 整数の最小値。
- 浮動小数点数の場合: IEEE-754 の
minimum
。 - 複素数の場合:
(real, imaginary)
ペアの辞書順最小値。複素数に順序付けを適用すると、意図しないセマンティクスが発生するため、今後、このオペレーションでの複素数のサポートを削除する予定です(#560)。 - 量子化型の場合:
dequantize_op_quantize(minimum, lhs, rhs, type(result))
。
入力
ラベル | 名前 | タイプ | 制約 |
---|---|---|---|
(I1) | lhs |
テンソルまたはテンソルごとの量子化テンソル | (C1) |
(I2) | rhs |
テンソルまたはテンソルごとの量子化テンソル | (C1) |
出力
名前 | タイプ | 制約 |
---|---|---|
result |
テンソルまたはテンソルごとの量子化テンソル | (C1) |
制約
- (C1)
baseline_type(lhs) = baseline_type(rhs) = baseline_type(result)
。
例
// %lhs: [[1, 2], [7, 8]]
// %rhs: [[5, 6], [3, 4]]
%result = "stablehlo.minimum"(%lhs, %rhs) : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
// %result: [[1, 2], [3, 4]]
乗算
セマンティクス
2 つのテンソル lhs
と rhs
の要素ごとの積を実行し、result
テンソルを生成します。要素のタイプに応じて、次の操作を行います。
- ブール値の場合: 論理 AND。
- 整数の場合: 整数の乗算。
- 浮動小数点数の場合: IEEE-754 の
multiplication
。 - 複素数の場合: 複素数乗算。
- 量子化された型の場合:
dequantize_op_quantize(multiply, lhs, rhs, type(result))
。
入力
ラベル | 名前 | タイプ | 制約 |
---|---|---|---|
(I1) | lhs |
テンソルまたはテンソルごとの量子化テンソル | (C1) |
(I2) | rhs |
テンソルまたはテンソルごとの量子化テンソル | (C1) |
出力
名前 | タイプ | 制約 |
---|---|---|
result |
テンソルまたはテンソルごとの量子化テンソル | (C1) |
制約
- (C1)
baseline_type(operand) = baseline_type(result)
。
例
// %lhs: [[1, 2], [3, 4]]
// %rhs: [[5, 6], [7, 8]]
%result = "stablehlo.multiply"(%lhs, %rhs) : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
// %result: [[5, 12], [21, 32]]
negate
セマンティクス
operand
テンソルの要素ごとの否定を行い、result
テンソルを生成します。要素のタイプに応じて、次の操作を行います。
- 符号付き整数の場合: 整数の否定。
- 符号なし整数の場合: 符号付き整数へのビットキャスト、整数の否定、符号なし整数へのビットキャスト。
- 浮動小数点数: IEEE-754 の
negate
。 - 複素数の場合: 複素数を否定する。
- 量子化された型の場合:
dequantize_op_quantize(negate, operand, type(result))
。
入力
ラベル | 名前 | タイプ | 制約 |
---|---|---|---|
(I1) | operand |
整数型、浮動小数点型、複合型のテンソル、またはテンソルごとの量子化テンソル | (C1) |
出力
名前 | タイプ | 制約 |
---|---|---|
result |
整数型、浮動小数点型、複合型のテンソル、またはテンソルごとの量子化テンソル | (C1) |
制約
- (C1)
baseline_type(operand) = baseline_type(result)
。
例
// Negation operation with integer Tensors
// %operand: [0, -2]
%result = "stablehlo.negate"(%operand) : (tensor<2xi32>) -> tensor<2xi32>
// %result: [0, 2]
// Negation operation with with complex tensors
// %operand: (2.5, 0.0)
%result = "stablehlo.negate"(%operand) : (tensor<1xcomplex<f32>>) -> tensor<1xcomplex<f32>>
// %result: [-2.5, -0.0]
いない
セマンティクス
テンソル operand
の要素ごとの NOT を実行し、result
テンソルを生成します。要素のタイプに応じて、次の処理を行います。
- ブール値の場合: 論理 NOT。
- 整数の場合: ビット演算 NOT。
引数
名前 | タイプ | 制約 |
---|---|---|
operand |
ブール型または整数型のテンソル | (C1) |
出力
名前 | タイプ | 制約 |
---|---|---|
result |
ブール型または整数型のテンソル | (C1) |
制約
- (C1)
type(operand) = type(result)
。
例
// Bitwise operation with with integer tensors
// %operand: [[1, 2], [3, 4]]
%result = "stablehlo.not"(%operand) : (tensor<2x2xi32>) -> tensor<2x2xi32>
// %result: [[-2, -3], [-4, -5]]
// Bitwise operation with with boolean tensors
// %operand: [true, false]
%result = "stablehlo.not"(%operand) : (tensor<2xi1>) -> tensor<2xi1>
// %result: [false, true]
optimization_barrier
セマンティクス
operand
を生成するオペレーションが result
に依存するオペレーションの前に実行され、コンパイラ変換によってオペレーションがバリアを超えて移動されないようにします。それ以外の場合、オペレーションは ID(result = operand
など)です。
引数
名前 | タイプ | 制約 |
---|---|---|
operand |
テンソルの可変数、テンソルごとの量子化テンソルまたはトークン | (C1) |
出力
名前 | タイプ | 制約 |
---|---|---|
result |
テンソルの可変数、テンソルごとの量子化テンソルまたはトークン | (C1) |
制約
- (C1)
type(operand...) = type(result...)
。
例
// %operand0: 0.0
// %operand1: 1.0
%result0, %result1 = "stablehlo.optimization_barrier"(%operand0, %operand1) : (tensor<f32>, tensor<f32>) -> (tensor<f32>, tensor<f32>)
// %result0: 0.0
// %result1: 1.0
または
セマンティクス
2 つのテンソル lhs
と rhs
の要素ごとの OR を実行し、result
テンソルを生成します。要素のタイプに応じて、次の操作を行います。
- ブール値の場合: 論理 OR。
- 整数の場合: ビット演算 OR。
入力
ラベル | 名前 | タイプ | 制約 |
---|---|---|---|
(I1) | lhs |
整数型またはブール型のテンソル | (C1) |
(I2) | rhs |
整数型またはブール型のテンソル | (C1) |
出力
名前 | タイプ | 制約 |
---|---|---|
result |
整数型またはブール型のテンソル | (C1) |
制約
- (C1)
type(lhs) = type(rhs) = type(result)
。
例
// Bitwise operation with with integer tensors
// %lhs: [[1, 2], [3, 4]]
// %rhs: [[5, 6], [7, 8]]
%result = "stablehlo.or"(%lhs, %rhs) : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
// %result: [[5, 6], [7, 12]]
// Logical operation with with boolean tensors
// %lhs: [[false, false], [true, true]]
// %rhs: [[false, true], [false, true]]
%result = "stablehlo.or"(%lhs, %rhs) : (tensor<2x2xi1>, tensor<2x2xi1>) -> tensor<2x2xi1>
// %result: [[false, true], [true, true]]
アウトフィード
セマンティクス
inputs
をアウトフィードに書き込み、result
トークンを生成します。
outfeed_config
のセマンティクスは実装で定義されます。
入力
ラベル | 名前 | タイプ |
---|---|---|
(I1) | inputs |
テンソルまたは量子化テンソルの可変数 |
(I2) | token |
token |
(I3) | outfeed_config |
string 型の定数 |
出力
名前 | タイプ |
---|---|
result |
token |
例
%result = "stablehlo.outfeed"(%input0, %token) {
outfeed_config = ""
} : (tensor<2x2x2xi64>, !stablehlo.token) -> !stablehlo.token
パッド
セマンティクス
テンソルの周囲と、指定された padding_value
を使用してテンソルの要素の間にパディングすることで、operand
を拡張します。
edge_padding_low
と edge_padding_high
は、各ディメンションの下端(インデックス 0 の隣)と上端(最大インデックスの隣)に追加されるパディングの量をそれぞれ指定します。パディングの量は負の値にすることができます。負のパディングの絶対値は、指定されたディメンションから削除する要素の数を示します。
interior_padding
は、各ディメンション内の任意の 2 つの要素間に追加されるパディングの量を指定します。この値は負の値にすることはできません。内部パディングはエッジ パディングの前に行われるため、負のエッジ パディングを行うと、内部パディングされたオペランドから要素が削除されます。
より正式には、result[result_index]
は次のように定義されます。
result_index = edge_padding_low + operand_index * (interior_padding + 1)
の場合、operand[operand_index]
。- それ以外の場合は
padding_value
。
入力
ラベル | 名前 | タイプ | 制約 |
---|---|---|---|
(I1) | operand |
テンソルまたはテンソルごとの量子化テンソル | (C1)、(C2)、(C4) |
(I2) | padding_value |
0 次元テンソルまたはテンソルごとの量子化テンソル | (C1) |
(I3) | edge_padding_low |
si64 型の 1 次元テンソル定数 |
(C1)、(C4) |
(I4) | edge_padding_high |
si64 型の 1 次元のテンソル定数 |
(C1)、(C4) |
(I5) | interior_padding |
si64 型の 1 次元テンソル定数 |
(C2 ~ C4) |
出力
名前 | タイプ | 制約 |
---|---|---|
result |
テンソルまたはテンソルごとの量子化テンソル | (C3-C6) |
制約
- (C1)
element_type(operand) = element_type(padding_value) = element_type(result)
。 - (C2)
size(edge_padding_low) = size(edge_padding_high) = size(interior_padding) = rank(operand)
。 - (C3)
0 <= interior_padding
。 - (C4)
shape(result) = shape(operand) + edge_padding_low + max(shape(operand) - 1, 0) * interior_padding + edge_padding_high
。
例
// %operand: [
// [1, 2, 3],
// [4, 5, 6]
// ]
// %padding_value: 0
%result = "stablehlo.pad"(%operand, %padding_value) {
edge_padding_low = array<i64: 0, 1>,
edge_padding_high = array<i64: 2, 1>,
interior_padding = array<i64: 1, 2>
} : (tensor<2x3xi32>, tensor<i32>) -> tensor<5x9xi32>
// %result: [
// [0, 1, 0, 0, 2, 0, 0, 3, 0],
// [0, 0, 0, 0, 0, 0, 0, 0, 0],
// [0, 4, 0, 0, 5, 0, 0, 6, 0],
// [0, 0, 0, 0, 0, 0, 0, 0, 0],
// [0, 0, 0, 0, 0, 0, 0, 0, 0]
// ]
partition_id
セマンティクス
現在のプロセスの partition_id
を生成します。
出力
名前 | タイプ |
---|---|
result |
ui32 型の 0 次元テンソル |
例
%result = "stablehlo.partition_id"() : () -> tensor<ui32>
popcnt
セマンティクス
operand
テンソルに設定されているビット数を要素ごとにカウントし、result
テンソルを生成します。
入力
ラベル | 名前 | タイプ | 制約 |
---|---|---|---|
(I1) | operand |
整数型のテンソル | (C1) |
出力
名前 | タイプ | 制約 |
---|---|---|
result |
整数型のテンソル | (C1) |
制約
- (C1)
type(operand) = type(result)
。
例
// %operand: [0, 1, 2, 127]
%result = "stablehlo.popcnt"(%operand) : (tensor<4xi64>) -> tensor<4xi64>
// %result: [0, 1, 1, 7]
電力
セマンティクス
lhs
テンソルと rhs
テンソルの要素ごとの指数演算を行い、result
テンソルを生成します。要素のタイプに応じて、次の処理を行います。
- 整数の場合: 整数のべき乗。
- 浮動小数点数の場合: IEEE-754 の
pow
。 - 複素数の場合: 複素指数関数。
- 量子化された型の場合:
dequantize_op_quantize(power, lhs, rhs, type(result))
。
入力
ラベル | 名前 | タイプ | 制約 |
---|---|---|---|
(I1) | lhs |
整数型、浮動小数点型、複合型のテンソル、またはテンソルごとの量子化テンソル | (C1) |
(I2) | rhs |
整数、浮動小数点、複合型のテンソル、またはテンソルごとの量子化テンソル | (C1) |
出力
名前 | タイプ | 制約 |
---|---|---|
result |
整数型、浮動小数点型、複合型のテンソル、またはテンソルごとの量子化テンソル | (C1) |
制約
- (C1)
baseline_type(operand) = baseline_type(result)
。
例
// %lhs: [-2.0, -0.0, -36.0, 5.0, 3.0, 10000.0]
// %rhs: [2.0, 2.0, 1.1, 2.0, -1.0, 10.0]
%result = "stablehlo.power"(%lhs, %rhs) : (tensor<6xf64>, tensor<6xf64>) -> tensor<6xf64>
// %result: [4.0, 0.0, -nan, 25.0, 0.333333343, inf]
real
セマンティクス
operand
から要素ごとに実部を抽出し、result
テンソルを生成します。より正式には、各要素 x
について real(x) = is_complex(x) ? real_part(x) : x
です。
入力
ラベル | 名前 | タイプ | 制約 |
---|---|---|---|
(I1) | operand |
浮動小数点数または複素数のテンソル | (C1)、(C2) |
出力
名前 | タイプ | 制約 |
---|---|---|
result |
浮動小数点型のテンソル | (C1)、(C2) |
制約
- (C1)
shape(result) = shape(operand)
。 - (C2)
element_type(result)
は次のように定義されます。is_complex(operand)
の場合、complex_element_type(element_type(operand))
。- それ以外の場合は
element_type(operand)
。
例
// %operand: [(1.0, 2.0), (3.0, 4.0)]
%result = "stablehlo.real"(%operand) : (tensor<2xcomplex<f32>>) -> tensor<2xf32>
// %result: [1.0, 3.0]
受信
セマンティクス
channel_id
を使用してチャネルからデータを受信し、results
を生成します。
is_host_transfer
が true
の場合、オペレーションはホストからデータを転送します。そうでない場合は、別のデバイスからデータを転送します。具体的な意味は実装によって異なります。このフラグは channel_type
で提供される情報と重複するため、今後はどちらか一方のみを保持する予定です(#666)。
results
は、先頭にペイロード値、最後にトークンで構成されます。今後、明確性を高めるために、ペイロードとトークンを 2 つの個別の出力に分割する予定です(#670)。
入力
ラベル | 名前 | タイプ | 制約 |
---|---|---|---|
(I1) | token |
token |
(C4) |
(I2) | channel_id |
si64 型の定数 |
|
(I3) | channel_type |
DEVICE_TO_DEVICE と HOST_TO_DEVICE の列挙型 |
(C1) |
(I4) | is_host_transfer |
i1 型の定数 |
(C1) |
出力
名前 | タイプ | 制約 |
---|---|---|
results |
テンソル、量子化テンソル、トークンの可変数 | (C2 ~ C4) |
制約
- (C1)
channel_type
は次のように定義されます。is_host_transfer = true
の場合、HOST_TO_DEVICE
- それ以外の場合は
DEVICE_TO_DEVICE
。
- (C2)
0 < size(results)
。 - (C3)
is_empty(result[:-1])
またはis_tensor(type(results[:-1]))
。 - (C4)
is_token(type(results[-1]))
。
例
%results0, %results1 = "stablehlo.recv"(%token) {
channel_handle = #stablehlo.channel_handle<handle = 1, type = 3>,
is_host_transfer = true
} : (!stablehlo.token) -> (tensor<2x2xi64>, !stablehlo.token)
reduce
セマンティクス
dimensions
に沿ってリダクション関数 body
を inputs
と init_values
に適用し、results
テンソルを生成します。
減算の順序は実装で定義されます。つまり、すべての実装ですべての入力に対して同じ結果が得られるように、body
と init_values
はモノイドを形成する必要があります。ただし、この条件は一般的な多くの減少では成立しません。たとえば、body
の浮動小数点加算と init_values
のゼロは、浮動小数点加算が結合的ではないため、実際にはモノイドを形成しません。
より正式には、results...[j0, ..., jR-1] = reduce(input_slices_converted)
です。ここで:
input_slices = inputs...[j0, ..., :, ..., jR-1]
。:
はdimensions
に挿入されます。input_slices_converted = to_destination_type(input_slices..., type(func_inputs(body)[:len(func_inputs(body))//2])...)
。init_values_converted = to_destination_type(init_values..., type(func_inputs(body)[len(func_inputs(body))//2:])...)
。reduce(input_slices_converted) = exec(schedule)
は任意のバイナリ ツリーです。schedule
は次のとおりです。exec(node) = body(exec(node.left), exec(node.right))
。exec(leaf) = leaf.value
。
schedule
は実装定義の完全二分木で、順序付き走査は次のとおりです。input_slices_converted...[index]
値。index_space(input_slices_converted)
内のすべてのindex
は、index
の昇順の辞書順で指定します。- 実装で定義された位置に、実装で定義された量の
init_values_converted
が点在します。
入力
ラベル | 名前 | タイプ | 制約 |
---|---|---|---|
(I1) | inputs |
テンソルの可変数またはテンソルごとの量子化テンソル | (C1~C4)、(C6)、(C7) |
(I2) | init_values |
0 次元テンソルまたはテンソルごとの量子化テンソルの可変数 | (C2)、(C3) |
(I3) | dimensions |
si64 型の 1 次元テンソル定数 |
(C4)、(C5)、(C7) |
(I4) | body |
関数 | (C6) |
出力
名前 | タイプ | 制約 |
---|---|---|
results |
テンソルの可変数またはテンソルごとの量子化テンソル | (C3)、(C7)、(C8) |
制約
- (C1)
same(shape(inputs...))
。 - (C2)
element_type(inputs...) = element_type(init_values...)
。 - (C3)
0 < size(inputs) = size(init_values) = size(results) = N
。 - (C4)
0 <= dimensions < rank(inputs[0])
。 - (C5)
is_unique(dimensions)
。 - (C6)
body
の型は(tensor<E0>, ..., tensor<EN-1>, tensor<E0>, ...,
tensor<EN-1>) -> (tensor<E0>, ..., tensor<EN-1>)
で、is_promotable(element_type(inputs[i]), Ei)
です。 - (C7)
shape(results...) = shape(inputs...)
。ただし、dimensions
に対応するinputs...
のディメンション サイズは含まれません。 - (C8)
[0,N)
のすべてのi
に対してelement_type(results[i]) = Ei
。
例
// %input = [[0, 1, 2, 3, 4, 5]]
// %init_value = 0
%result = "stablehlo.reduce"(%input, %init_value) ({
^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
%0 = "stablehlo.add"(%arg0, %arg1) : (tensor<i64>, tensor<i64>) -> tensor<i64>
"stablehlo.return"(%0) : (tensor<i64>) -> ()
}) {
dimensions = array<i64: 1>
} : (tensor<1x6xi64>, tensor<i64>) -> tensor<1xi64>
// %result = [15]
reduce_precision
セマンティクス
operand
を exponent_bits
と mantissa_bits
を使用する別の浮動小数点型に要素単位で変換し、元の浮動小数点型に戻して output
テンソルを生成します。
より正式には、次のように記述します。
- 元の値のマンシンダビットが更新され、
roundToIntegralTiesToEven
セマンティクスを使用して、元の値をmantissa_bits
で表せる最も近い値に丸められます。 - 次に、
mantissa_bits
が元の値のマンシッサ ビット数より小さい場合は、マンシッサ ビットがmantissa_bits
に切り捨てられます。 - 次に、中間結果の指数ビットが
exponent_bits
で指定された範囲に収まらない場合、中間結果は元の符号を使用して無限大までオーバーフローするか、元の符号を使用してゼロにアンダーフローします。 - 量子化された型の場合は、
dequantize_op_quantize( lambda operand: reduce_precision(operand, exponent_bits, mantissa_bits), operand, type(result))
を実行します。
入力
ラベル | 名前 | タイプ | 制約 |
---|---|---|---|
(I1) | operand |
浮動小数点型のテンソルまたはテンソルごとの量子化テンソル | (C1) |
(I2) | exponent_bits |
si32 型の定数 |
(C2) |
(I3) | mantissa_bits |
si32 型の定数 |
(C3) |
出力
名前 | タイプ | 制約 |
---|---|---|
output |
浮動小数点型のテンソルまたはテンソルごとの量子化テンソル | (C1) |
制約
- (C1)
baseline_type(operand) = baseline_type(output)
。 - (C2)
1 <= exponent_bits
。 - (C3)
0 <= mantissa_bits
。
例
// Logical values: +Inf, NaN, +Denormal, 0.0, 65519.0, 65520.0
// %operand: [0x7FF0000000000000, 0x7FFFFFFFFFFFFFFF, 0x0000000000000001, 0.0, 65519.0, 65520.0]
%output = "stablehlo.reduce_precision"(%operand) {
exponent_bits = 5 : i32,
mantissa_bits = 10 : i32
} : (tensor<6xf64>) -> tensor<6xf64>
// Logical values: +Inf, NaN, 0.0, 0.0, 65504.0, +Inf
// %output: [0x7FF0000000000000, 0x7FFFFFFFFFFFFFFF, 0.0, 0.0, 65504.0, 0x7FF0000000000000]
reduce_scatter
セマンティクス
StableHLO プロセス グリッド内の各プロセス グループ内で、各プロセスの operand
テンソルの値に対して computations
を使用して減算を行い、scatter_dimension
に沿って減算結果を分割し、分割された部分をプロセス間で分散して result
を生成します。
このオペレーションでは、StableHLO プロセス グリッドが、次のように定義された process_groups
に分割されます。
channel_id <= 0 and use_global_device_ids = false
の場合、cross_replica(replica_groups)
channel_id > 0 and use_global_device_ids = false
の場合、cross_replica_and_partition(replica_groups)
channel_id > 0 and use_global_device_ids = true
の場合、flattened_ids(replica_groups)
その後、各 process_group
内で次の操作を行います。
reduced_value = all_reduce(operand, replica_groups, channel_id, use_global_device_ids, computation)
。parts@sender = split(reduced_value@sender, dim(process_groups, 1), scatter_dimension)
。process_group
内のすべてのsender
に対してresult@receiver = parts@sender[receiver_index]
(receiver_index = process_group.index(receiver)
)。
入力
ラベル | 名前 | タイプ | 制約 |
---|---|---|---|
(I1) | operand |
テンソルまたはテンソルごとの量子化テンソル | (C1)、(C2)、(C7)、(C8) |
(I2) | scatter_dimension |
si64 型の定数 |
(C1)、(C2)、(C8) |
(I3) | replica_groups |
si64 型の 2 次元テンソル定数 |
(C3~C5) |
(I4) | channel_id |
si64 型の定数 |
(C6) |
(I5) | use_global_device_ids |
i1 型の定数 |
(C6) |
(I6) | computation |
関数 | (C7) |
出力
名前 | タイプ | 制約 |
---|---|---|
result |
テンソルまたはテンソルごとの量子化テンソル | (C8~C9) |
制約
- (C1)
dim(operand, scatter_dimension) % dim(process_groups, 1) = 0
。 - (C2)
0 <= scatter_dimension < rank(operand)
。 - (C3)
is_unique(replica_groups)
。 - (C4)
size(replica_groups)
は次のように定義されます。cross_replica
を使用する場合はnum_replicas
。cross_replica_and_partition
を使用する場合はnum_replicas
。flattened_ids
を使用する場合はnum_processes
。
- (C5)
0 <= replica_groups < size(replica_groups)
。 - (C6)
use_global_device_ids = true
の場合はchannel_id > 0
です。 - (C7)
computation
の型は(tensor<E>, tensor<E>) -> (tensor<E>)
です。is_promotable(element_type(operand), E)
- (C8)
shape(result) = shape(operand)
(ただし、次を除く)dim(result, scatter_dimension) = dim(operand, scatter_dimension) / dim(process_groups, 1)
。
- (C9)
element_type(result) = E
。
例
// num_replicas: 2
// num_partitions: 1
// %operand@(0, 0): [[1, 2, 3, 4],
// [5, 6, 7, 8]]
// %operand@(1, 0): [[9, 10, 11, 12],
// [13, 14, 15, 16]]
%result = "stablehlo.reduce_scatter"(%operand) ({
^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
%0 = "stablehlo.add"(%arg0, %arg1) : (tensor<i64>, tensor<i64>) -> tensor<i64>
"stablehlo.return"(%0) : (tensor<i64>) -> ()
}) {
scatter_dimension = 1 : i64,
replica_groups = dense<[[0, 1]]> : tensor<1x2xi64>,
channel_handle = #stablehlo.channel_handle<handle = 0, type = 0>
} : (tensor<2x4xi64>) -> tensor<2x2xi64>
//
// %result@(0, 0): [[10, 12],
// [18, 20]]
// %result@(1, 0): [[14, 16],
// [22, 24]]
reduce_window
セマンティクス
inputs
と init_values
のウィンドウに削減関数 body
を適用し、results
を生成します。
次の図は、具体的な例を使用して results...
の要素が inputs...
から計算される方法を示しています。
より正式には、results...[result_index] = reduce(windows, init_values, axes(inputs...), body)
(reduce を参照)で、次のように定義されます。
padded_inputs = pad(inputs..., init_values..., padding[:, 0], padding[:, 1], base_dilations - 1)
。window_start = result_index * window_strides
window_end = window_start + (window_dimensions - 1) * window_dilations + 1
windows = slice(padded_inputs..., window_start, window_end, window_dilations)
。
入力
ラベル | 名前 | タイプ | 制約 |
---|---|---|---|
(I1) | inputs |
テンソルの可変数またはテンソルごとの量子化テンソル | (C1~C4)、(C6)、(C8)、(C10)、(C12)、(C13)、(C15) |
(I2) | init_values |
0 次元テンソルまたはテンソルごとの量子化テンソルの可変数 | (C1)、(C13) |
(I3) | window_dimensions |
si64 型の 1 次元テンソル定数 |
(C4)、(C5)、(C15) |
(I4) | window_strides |
si64 型の 1 次元テンソル定数 |
(C6)、(C7)、(C15) |
(I5) | base_dilations |
si64 型の 1 次元のテンソル定数 |
(C8)、(C9)、(C15) |
(I6) | window_dilations |
si64 型の 1 次元テンソル定数 |
(C10)、(C11)、(C15) |
(I7) | padding |
si64 型の 2 次元テンソル定数 |
(C12)、(C15) |
(I8) | body |
関数 | (C13) |
出力
名前 | タイプ | 制約 |
---|---|---|
results |
テンソルの可変数またはテンソルごとの量子化テンソル | (C1)、(C14~C16) |
制約
- (C1)
0 < size(inputs) = size(init_values) = size(results) = N
。 - (C2)
same(shape(inputs...))
。 - (C3)
element_type(inputs...) = element_type(init_values...)
。 - (C4)
size(window_dimensions) = rank(inputs[0])
。 - (C5)
0 < window_dimensions
。 - (C6)
size(window_strides) = rank(inputs[0])
。 - (C7)
0 < window_strides
。 - (C8)
size(base_dilations) = rank(inputs[0])
。 - (C9)
0 < base_dilations
。 - (C10)
size(window_dilations) = rank(inputs[0])
。 - (C11)
0 < window_dilations
。 - (C12)
shape(padding) = [rank(inputs[0]), 2]
。 - (C13)
body
の型は(tensor<E0>, ..., tensor<EN-1>, tensor<E0>, ...,
tensor<EN-1>) -> (tensor<E0>, ..., tensor<EN-1>)
で、is_promotable(element_type(inputs[i]), Ei)
です。 - (C14)
same(shape(results...))
。 - (C15)
shape(results[0]) = num_windows
ここで:dilated_input_shape = shape(inputs[0]) = 0 ? 0 : (shape(inputs[0]) - 1) * base_dilations + 1
。padded_input_shape = padding[:, 0] + dilated_input_shape + padding[:, 1]
dilated_window_shape = (window_dimensions - 1) * window_dilations + 1
is_empty_window = padded_input_shape = 0 || dilated_window_shape > padded_input_shape
num_windows = is_empty_window ? 0 : floor((padded_input_shape - dilated_window_shape) / window_strides) + 1
。
- (C16)
[0,N)
内のすべてのi
に対してelement_type(results[i]) = Ei
。
例
// %input = [[1, 2], [3, 4], [5, 6]]
// %init_value = 0
%result = "stablehlo.reduce_window"(%input, %init_value) ({
^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
%0 = "stablehlo.add"(%arg0, %arg1) : (tensor<i64>, tensor<i64>) -> tensor<i64>
"stablehlo.return"(%0) : (tensor<i64>) -> ()
}) {
window_dimensions = array<i64: 2, 1>,
window_strides = array<i64: 4, 1>,
base_dilations = array<i64: 2, 1>,
window_dilations = array<i64: 3, 1>,
padding = dense<[[2, 1], [0, 0]]> : tensor<2x2xi64>
} : (tensor<3x2xi64>, tensor<i64>) -> tensor<2x2xi64>
// %result = [[0, 0], [3, 4]]
残り
セマンティクス
被除数 lhs
テンソルと除数 rhs
テンソルの要素ごとの余りを実行し、result
テンソルを生成します。
より正式には、結果の符号は被除数から取得され、結果の絶対値は常に除数の絶対値よりも小さくなります。残りは lhs - d * rhs
として計算されます。ここで、d
は次のように定義されます。
- 整数の場合:
stablehlo.divide(lhs, rhs)
。 - 浮動小数点数の場合: 丸め属性
roundTowardZero
を持つ IEEE-754 のdivision(lhs, rhs)
。 - 複素数の場合: 未定(#997)。
- 量子化された型の場合:
dequantize_op_quantize(remainder, lhs, rhs, type(result))
。
浮動小数点要素型の場合、このオペレーションは IEEE-754 仕様の remainder
オペレーションとは対照的です。d
は、lhs/rhs
の正確な値に最も近い整数値で、偶数に等しい値です。
入力
ラベル | 名前 | タイプ | 制約 |
---|---|---|---|
(I1) | lhs |
整数型、浮動小数点型、複合型のテンソル、またはテンソルごとの量子化テンソル | (C1) |
(I2) | rhs |
整数型、浮動小数点型、複合型のテンソル、またはテンソルごとの量子化テンソル | (C1) |
出力
名前 | タイプ | 制約 |
---|---|---|
result |
整数型、浮動小数点型、複合型のテンソル、またはテンソルごとの量子化テンソル | (C1) |
制約
- (C1)
baseline_type(operand) = baseline_type(result)
。
例
// %lhs: [17, -17, 17, -17]
// %rhs: [3, 3, -3, -3]
%result = "stablehlo.remainder"(%lhs, %rhs) : (tensor<4xi64>, tensor<4xi64>) -> tensor<4xi64>
// %result: [2, -2, 2, -2]
replica_id
セマンティクス
現在のプロセスの replica_id
を生成します。
出力
名前 | タイプ |
---|---|
result |
ui32 型の 0 次元テンソル |
例
%result = "stablehlo.replica_id"() : () -> tensor<ui32>
reshape
セマンティクス
operand
テンソルを result
テンソルに再フォーマットします。概念的には、同じ正規表現を維持しながら、形状を変更する可能性(tensor<2x3xf32>
から tensor<3x2xf32>
または tensor<6xf32>
への変更など)があります。
より正式には、result[result_index] = operand[operand_index]
で、result_index
と operand_index
は index_space(result)
と index_space(operand)
の辞書順で同じ位置にあります。
入力
ラベル | 名前 | タイプ | 制約 |
---|---|---|---|
(I1) | operand |
テンソルまたは量子化テンソル | (C1~C3) |
出力
名前 | タイプ | 制約 |
---|---|---|
result |
テンソルまたは量子化テンソル | (C1~C3) |
制約
- (C1)
element_type(result)
は次のように定義されます。element_type(operand)
(!is_per_axis_quantized(operand)
の場合)。element_type(operand)
。ただし、quantization_dimension(operand)
とquantization_dimension(result)
が異なる場合があります。
- (C2)
size(operand) = size(result)
。 - (C3)
is_per_axis_quantized(operand)
の場合:reduce(dims(operand, [0, 1, ..., quantization_dimension(operand) - 1]), init_values=1, dimensions=[0], body=lambda x, y: x * y) = reduce(dims(result, [0, 1, ..., quantization_dimension(result) - 1]), init_values=1, dimensions=[0], body=lambda x, y: x * y)
。dim(operand, quantization_dimension(operand)) = dim(result, quantization_dimension(result))
reduce(dims(operand, [quantization_dimension(operand) + 1, ..., rank(operand) - 1]), init_values=1, dimensions=[0], body=lambda x, y: x * y) = reduce(dims(result, [quantization_dimension(result) + 1, ..., rank(result) - 1]), init_values=1, dimensions=[0], body=lambda x, y: x * y)
。
例
// %operand: [[1, 2, 3], [4, 5, 6]]
%result = "stablehlo.reshape"(%operand) : (tensor<2x3xi32>) -> tensor<3x2xi32>
// %result: [[1, 2], [3, 4], [5, 6]]
reverse
セマンティクス
指定された dimensions
に沿って operand
の要素の順序を逆にして、result
テンソルを生成します。より正式には、result[result_index] = operand[operand_index]
で次のようにします。
dimensions
のd
の場合、operand_index[d] = dim(result, d) - result_index[d] - 1
- そうでない場合は
operand_index[d] = result_index[d]
。
入力
ラベル | 名前 | タイプ | 制約 |
---|---|---|---|
(I1) | operand |
テンソルまたはテンソルごとの量子化テンソル | (C1)、(C3) |
(I2) | dimensions |
si64 型の 1 次元テンソル定数 |
(C2)、(C3) |
出力
名前 | タイプ | 制約 |
---|---|---|
result |
テンソルまたはテンソルごとの量子化テンソル | (C1)、(C3) |
制約
- (C1)
type(operand) = type(result)
。 - (C2)
is_unique(dimensions)
。 - (C3)
0 <= dimensions < rank(result)
。
例
// %operand = [[1, 2], [3, 4], [5, 6]]
%result = "stablehlo.reverse"(%operand) {
dimensions = array<i64: 1>
} : (tensor<3x2xi32>) -> tensor<3x2xi32>
// %result: [[2, 1], [4, 3], [6, 5]]
rng
セマンティクス
rng_distribution
アルゴリズムを使用して乱数を生成し、指定された形状 shape
の result
テンソルを生成します。
rng_distribution = UNIFORM
の場合、乱数は [a, b)
の区間で均一分布に従って生成されます。a >= b
の場合、動作は未定義です。
rng_distribution = NORMAL
の場合、乱数は平均 = a
、標準偏差 = b
の正規分布に従って生成されます。b < 0
の場合、動作は未定義です。
乱数の生成方法は実装で定義されます。たとえば、確定的である場合もあれば、そうでない場合もあります。また、非表示の状態を使用する場合もあれば、使用しないこともあります。
多くの関係者と話し合った結果、このオペレーションは事実上非推奨であることが判明したため、今後は削除を検討する予定です(#597)。
入力
ラベル | 名前 | タイプ | 制約 |
---|---|---|---|
(I1) | a |
整数型、ブール型、浮動小数点型の 0 次元テンソル | (C1)、(C2) |
(I2) | b |
整数型、ブール型、浮動小数点型の 0 次元テンソル | (C1)、(C2) |
(I3) | shape |
si64 型の 1 次元テンソル定数 |
(C3) |
(I4) | rng_distribution |
UNIFORM と NORMAL の列挙型 |
(C2) |
出力
名前 | タイプ | 制約 |
---|---|---|
result |
整数型、ブール型、浮動小数点型のテンソル | (C1~C3) |
制約
- (C1)
element_type(a) = element_type(b) = element_type(result)
。 - (C2)
rng_distribution = NORMAL
の場合はis_float(a)
です。 - (C3)
shape(result) = shape
。
例
// %a = 0
// %b = 2
// %shape = [3, 3]
%result = "stablehlo.rng"(%a, %b, %shape) {
rng_distribution = #stablehlo<rng_distribution UNIFORM>
} : (tensor<i32>, tensor<i32>, tensor<2xi64>) -> tensor<3x3xi32>
// %result: [
// [1, 0, 1],
// [1, 1, 1],
// [0, 0, 0]
// ]
rng_bit_generator
セマンティクス
初期状態 initial_state
に基づいて、疑似乱数生成アルゴリズム rng_algorithm
を使用して、均一なランダムビットで満たされた output
と更新された出力状態 output_state
を返します。出力は initial_state
の確定的関数であることが保証されますが、実装間で確定的であるとは限りません。
rng_algorithm
は、次のいずれかです。
DEFAULT
: 実装定義のアルゴリズム。THREE_FRY
: Threefry アルゴリズムの実装定義バリアント。*PHILOX
: Philox アルゴリズムの実装定義バリアント。*
* 参照: Salmon et al. SC 2011. 並列乱数: 1、2、3 で簡単に生成。
入力
ラベル | 名前 | タイプ | 制約 |
---|---|---|---|
(I1) | rng_algorithm |
DEFAULT 、THREE_FRY 、PHILOX の列挙型 |
(C2) |
(I2) | initial_state |
ui64 型の 1 次元テンソル |
(C1)、(C2) |
出力
名前 | タイプ | 制約 |
---|---|---|
output_state |
ui64 型の 1 次元テンソル |
(C1) |
output |
整数型または浮動小数点型のテンソル |
制約
- (C1)
type(initial_state) = type(output_state)
。 - (C2)
size(initial_state)
は次のように定義されます。rng_algorithm = DEFAULT
の場合は実装定義。rng_algorithm = THREE_FRY
の場合、2
。rng_algorithm = PHILOX
の場合は2
または3
。
例
// %initial_state: [1, 2]
%output_state, %output = "stablehlo.rng_bit_generator"(%initial_state) {
rng_algorithm = #stablehlo<rng_algorithm THREE_FRY>
} : (tensor<2xui64>) -> (tensor<2xui64>, tensor<2x2xui64>)
// %output_state: [1, 6]
// %output: [
// [9236835810183407956, 16087790271692313299],
// [18212823393184779219, 2658481902456610144]
// ]
round_nearest_afz
セマンティクス
operand
テンソルに対して要素ごとの四捨五入を行い、同値の場合はゼロから離れた値を返します。これにより、result
テンソルが生成されます。IEEE-754 仕様の roundToIntegralTiesToAway
オペレーションを実装します。量子化された型の場合は、dequantize_op_quantize(round_nearest_afz, operand, type(result))
を実行します。
入力
ラベル | 名前 | タイプ | 制約 |
---|---|---|---|
(I1) | operand |
浮動小数点型のテンソルまたはテンソルごとの量子化テンソル | (C1) |
出力
名前 | タイプ | 制約 |
---|---|---|
result |
浮動小数点型のテンソルまたはテンソルごとの量子化テンソル | (C1) |
制約
- (C1)
baseline_type(operand) = baseline_type(result)
。
例
// %operand = [-2.5, 0.4, 0.5, 0.6, 2.5]
%result = "stablehlo.round_nearest_afz"(%operand) : (tensor<5xf64>) -> tensor<5xf64>
// %result: [-3.0, 0.0, 1.0, 1.0, 3.0]
round_nearest_even
セマンティクス
テンソルで operand
テンソルで、最も近い整数への要素単位の丸めを実行し、偶数の整数への結合を解除して、result
テンソルを生成します。IEEE-754 仕様の roundToIntegralTiesToEven
オペレーションを実装します。量子化された型の場合は、dequantize_op_quantize(round_nearest_even, operand, type(result))
を実行します。
入力
ラベル | 名前 | タイプ | 制約 |
---|---|---|---|
(I1) | operand |
浮動小数点型のテンソルまたはテンソルごとの量子化テンソル | (C1) |
出力
名前 | タイプ | 制約 |
---|---|---|
result |
浮動小数点型のテンソルまたはテンソルごとの量子化テンソル | (C1) |
制約
- (C1)
baseline_type(operand) = baseline_type(result)
。
例
// %operand = [-2.5, 0.4, 0.5, 0.6, 2.5]
%result = "stablehlo.round_nearest_even"(%operand) : (tensor<5xf64>) -> tensor<5xf64>
// %result: [-2.0, 0.0, 0.0, 1.0, 2.0]
rsqrt
セマンティクス
operand
テンソルに対して要素単位の逆平方根演算を実行し、result
テンソルを生成します。要素のタイプに応じて、次の処理を行います。
- 浮動小数点数の場合: IEEE-754 の
rSqrt
。 - 複素数の場合: 複素数の逆平方根。
- 量子化された型の場合:
dequantize_op_quantize(rsqrt, operand, type(result))
。
入力
ラベル | 名前 | タイプ | 制約 |
---|---|---|---|
(I1) | operand |
浮動小数点型または複合型のテンソル、またはテンソルごとの量子化テンソル | (C1) |
出力
名前 | タイプ | 制約 |
---|---|---|
result |
浮動小数点型または複合型のテンソル、またはテンソルごとの量子化テンソル | (C1) |
制約
- (C1)
baseline_type(operand) = baseline_type(result)
。
例
// %operand: [[1.0, 4.0], [9.0, 25.0]]
%result = "stablehlo.rsqrt"(%operand) : (tensor<2x2xf32>) -> tensor<2x2xf32>
// %result: [[1.0, 0.5], [0.33333343, 0.2]]
散布
セマンティクス
inputs
テンソルと同じ results
テンソルを生成します。ただし、scatter_indices
で指定された複数のスライスが、update_computation
を使用して値 updates
で更新されます。
次の図は、updates...
の要素が results...
の要素にどのようにマッピングされるかを、具体的な例を使用して示しています。この図では、いくつかの updates...
インデックスの例を選択して、それらが対応する results...
インデックスの詳細について説明しています。
より正式に、index_space(updates[0])
内のすべての update_index
について、次のように定義します。
update_scatter_dims = [d for d in axes(updates[0]) and d not in update_window_dims]
。update_scatter_index = update_index[update_scatter_dims...]
。start_index
は次のように定義されます。scatter_indices[si0, ..., :, ..., siN]
。ここで、si
はupdate_scatter_index
の個々の要素であり、index_vector_dim
<rank(scatter_indices)
の場合、:
はindex_vector_dim
インデックスに挿入されます。- そうでない場合は
[scatter_indices[update_scatter_index]]
。
axes(inputs[0])
のd_input
は、d_input = scatter_dims_to_operand_dims[d_start]
の場合、full_start_index[d_input] = start_index[d_start]
。- それ以外の場合は
full_start_index[d_input] = 0
。
axes(inputs[0])
のd_input
は、d_input = input_batching_dims[i_batching]
かつd_start = scatter_indices_batching_dims[i_batching]
の場合はfull_batching_index[d_input] = update_scatter_index[d_start - (d_start < index_vector_dim ? 0 : 1)]
。- それ以外の場合は
full_batching_index[d_input] = 0
。
update_window_index = update_index[update_window_dims...]
。full_window_index = [wi0, ..., 0, ..., wiN]
。ここで、wi
はupdate_window_index
の個々の要素で、0
はinserted_window_dims
とinput_batching_dims
のインデックスに挿入されます。result_index = full_start_index + full_batching_index + full_window_index
。
results = exec(schedule, inputs)
。ここで、
schedule
は、index_space(updates[0])
の実装定義された並べ替えです。exec([update_index, ...], results) = exec([...], updated_results)
ここで:result_index
がshape(results...)
の範囲内にある場合updates_converted = to_destination_type( updates...[update_index], type(func_inputs(update_computation) [len(func_inputs(update_computation))//2:])... )
updated_values = update_computation(results...[result_index], updates_converted)
updated_results
は、results...[result_index]
がupdated_values...
に設定されたresults
のコピーです。- それ以外の場合
updated_results = results
。
exec([], results) = results
。
indices_are_sorted
が true
の場合、実装では scatter_indices
が scatter_dims_to_operand_dims
に関して並べ替えられていると想定できます。それ以外の場合、動作は未定義です。より正式には、indices(result)
のすべての i1 < i2
について、full_start_index(i1)
<= full_start_index(i2)
です。
unique_indices
が true
の場合、実装は分散されているすべての result_index
インデックスが一意であると想定できます。unique_indices
が true
で、分散先のインデックスが一意でない場合、動作は未定義です。
入力
ラベル | 名前 | タイプ | 制約 |
---|---|---|---|
(I1) | inputs |
テンソルの可変数またはテンソルごとの量子化テンソル | (C1)、(C2)、(C4 ~ C6)、(C11)、(C13)、(C18)、(C21)、(C23 ~ C24) |
(I2) | scatter_indices |
整数型のテンソル | (C4)、(C15)、(C19)、(C22) |
(I3) | updates |
テンソルの可変数またはテンソルごとの量子化テンソル | (C3-C6)、(C8) |
(I4) | update_window_dims |
si64 型の 1 次元のテンソル定数 |
(C2)、(C4)、(C7~C8) |
(I5) | inserted_window_dims |
si64 型の 1 次元テンソル定数 |
(C2)、(C4)、(C9 ~ C11) |
(I6) | input_batching_dims |
si64 型の 1 次元テンソル定数 |
(C2)、(C4)、(C9)、(C12-13)、(C17-18)、(C20) |
(I7) | scatter_indices_batching_dims |
si64 型の 1 次元テンソル定数 |
(C14 ~ C18) |
(I8) | scatter_dims_to_operand_dims |
si64 型の 1 次元テンソル定数 |
(C19 ~ C21) |
(I9) | index_vector_dim |
si64 型の定数 |
(C4)、(C16)、(C19)、(C22) |
(I10) | indices_are_sorted |
i1 型の定数 |
|
(I11) | unique_indices |
i1 型の定数 |
|
(I12) | update_computation |
関数 | (C23) |
出力
名前 | タイプ | 制約 |
---|---|---|
results |
テンソルの可変数またはテンソルごとの量子化テンソル | (C24 ~ C25) |
制約
- (C1)
same(shape(inputs...))
。 - (C2)`rank(inputs[0]) = size(update_window_dims) + size(inserted_window_dims)
- size(input_batching_dims)`.
- (C3)
same(shape(updates...))
。 - (C4)
shape(updates[0]) = combine(update_scatter_dim_sizes, update_window_dim_sizes)
ここで:update_scatter_dim_sizes = shape(scatter_indices)
。ただし、index_vector_dim
に対応するscatter_indices
のディメンション サイズは含まれません。update_window_dim_sizes <= shape(inputs[0])
。ただし、inserted_window_dims
とinput_batching_dims
に対応するinputs[0]
のディメンション サイズは含まれません。combine
は、update_scatter_dim_sizes
をupdate_scatter_dims
に対応する軸に配置し、update_window_dim_sizes
をupdate_window_dims
に対応する軸に配置します。
- (C5)
0 < size(inputs) = size(updates) = N
。 - (C6)
element_type(updates...) = element_type(inputs...)
。 - (C7)
is_unique(update_window_dims) and is_sorted(update_window_dims)
。 - (C8)
0 <= update_window_dims < rank(updates[0])
。 - (C9)
is_unique(concatenate(inserted_window_dims, input_batching_dims))
- (C10)
is_sorted(inserted_window_dims)
。 - (C11)
0 <= inserted_window_dims < rank(inputs[0])
。 - (C12)
is_sorted(input_batching_dims)
。 - (C13)
0 <= input_batching_dims < rank(inputs[0]))
。 - (C14)
is_unique(scatter_indices_batching_dims)
。 - (C15)
0 <= scatter_indices_batching_dims < rank(scatter_indices)
。 - (C16)
index_vector_dim not in scatter_indices_batching_dims
。 - (C17)
size(input_batching_dims) == size(scatter_indices_batching_dims)
。 - (C18)
dim(inputs[0], input_batching_dims...) = dim(scatter_indices, scatter_indices_batching_dims...)
。 - (C19)
size(scatter_dims_to_operand_dims) = index_vector_dim < rank(scatter_indices) ? dim(scatter_indices, index_vector_dim) : 1
。 - (C20)
is_unique(concatenate(scatter_dims_to_operand_dims, input_batching_dims))
。 - (C21)
0 <= scatter_dims_to_operand_dims < rank(inputs[0])
。 - (C22)
0 <= index_vector_dim <= rank(scatter_indices)
。 - (C23)
update_computation
の型は(tensor<E0>, ..., tensor<EN-1>, tensor<E0>, ..., tensor<EN-1>) -> (tensor<E0>, ..., tensor<EN-1>)
です。ここで、is_promotable(element_type(inputs[i]), Ei)
。 - (C24)
shape(inputs...) = shape(results...)
。 - (C25)
[0,N)
内のすべてのi
に対してelement_type(results[i]) = Ei
。
例
// %input: [
// [
// [[1, 2], [3, 4], [5, 6], [7, 8]],
// [[9, 10],[11, 12], [13, 14], [15, 16]],
// [[17, 18], [19, 20], [21, 22], [23, 24]]
// ],
// [
// [[25, 26], [27, 28], [29, 30], [31, 32]],
// [[33, 34], [35, 36], [37, 38], [39, 40]],
// [[41, 42], [43, 44], [45, 46], [47, 48]]
// ]
// ]
// %scatter_indices: [
// [
// [[0, 0], [1, 0], [2, 1]],
// [[0, 1], [1, 1], [0, 9]]
// ],
// [
// [[0, 0], [2, 1], [2, 2]],
// [[1, 2], [0, 1], [1, 0]]
// ]
// ]
// %update: [
// [
// [[1, 1], [1, 1], [1, 1]],
// [[1, 1], [1, 1], [1, 1]]
// ],
// [
// [[1, 1], [1, 1], [1, 1]],
// [[1, 1], [1, 1], [1, 1]]
// ]
// ]
%result = "stablehlo.scatter"(%input, %scatter_indices, %update) ({
^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
%0 = "stablehlo.add"(%arg0, %arg1) : (tensor<i64>, tensor<i64>) -> tensor<i64>
"stablehlo.return"(%0) : (tensor<i64>) -> ()
}) {
scatter_dimension_numbers = #stablehlo.scatter<
update_window_dims = [3, 4],
inserted_window_dims = [1],
input_batching_dims = [0],
scatter_indices_batching_dims = [1],
scatter_dims_to_operand_dims = [2, 1],
index_vector_dim = 3>,
indices_are_sorted = false,
unique_indices = false
} : (tensor<2x3x4x2xi64>, tensor<2x2x3x2xi64>, tensor<2x2x3x2x2xi64>) -> tensor<2x3x4x2xi64>
// %result: [
// [
// [[3, 4], [6, 7], [6, 7], [7, 8]],
// [[9, 10],[11, 12], [15, 16], [17, 18]],
// [[17, 18], [19, 20], [22, 23], [24, 25]]
// ],
// [
// [[25, 26], [28, 29], [30, 31], [31, 32]],
// [[35, 36], [38, 39], [38, 39], [39, 40]],
// [[41, 42], [44, 45], [46, 47], [47, 48]]
// ]
// ]
選択
セマンティクス
result
テンソルを生成します。各要素は、pred
の対応する要素の値に基づいて on_true
テンソルまたは on_false
テンソルから選択されます。正式には result[result_index] = pred_element ? on_true[result_index] :
on_false[result_index]
です。ここで pred_element = rank(pred) = 0 ? pred[] :
pred[result_index]
は量子化された型の場合は、dequantize_select_quantize(pred, on_true, on_false, type(result))
を実行します。
入力
ラベル | 名前 | タイプ | 制約 |
---|---|---|---|
(I1) | pred |
i1 型のテンソル |
(C1) |
(I2) | on_true |
テンソルまたはテンソルごとの量子化テンソル | (C1-C2) |
(I3) | on_false |
テンソルまたはテンソルごとの量子化テンソル | (C2) |
出力
名前 | タイプ | 制約 |
---|---|---|
result |
テンソルまたはテンソルごとの量子化テンソル | (C2) |
制約
- (C1)
rank(pred) = 0 or shape(pred) = shape(on_true)
。 - (C2)
baseline_type(on_true) = baseline_type(on_false) = baseline_type(result)
。
例
// %pred: [[false, true], [true, false]]
// %on_true: [[1, 2], [3, 4]]
// %on_false: [[5, 6], [7, 8]]
%result = "stablehlo.select"(%pred, %on_true, %on_false) : (tensor<2x2xi1>, tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
// %result: [[5, 2], [3, 8]]
select_and_scatter
セマンティクス
select
を使用した input
テンソルの reduce_window
の結果に基づいて、scatter
を使用して source
テンソルから値を散布し、result
テンソルを生成します。
次の図は、具体的な例を使用して、result
の要素が operand
と source
からどのように計算されるかを示しています。
よりフォーマルな表現:
selected_values = reduce_window_without_init(...)
は、次の入力に置き換えます。inputs = [operand].
window_dimensions
、window_strides
、padding
はそのまま使用されます。base_dilations = windows_dilations = 1
。body
は次のように定義されます。
def body(arg0: tensor<E>, arg1: tensor<E>) -> tensor<E>: return select(arg0, arg1) ? arg0 : arg1;
ここで、
E = element_type(operand)
とreduce_window_without_init
はreduce_window
とまったく同じように機能しますが、基になるreduce
のschedule
(削減を参照)に init 値が含まれないことが異なります。対応するウィンドウに値がない場合の動作は現在未定義です(#731)。result[result_index] = reduce([source_values], [init_value], [0], scatter)
ここで:source_values = [source[source_index] for source_index in source_indices]
。selected_index(source_index) = operand_index
:selected_values[source_index]
にoperand_index
のoperand
要素がある場合。source_indices = [source_index for source_index in indices(source) if selected_index(source_index) = result_index]
。
入力
ラベル | 名前 | タイプ | 制約 |
---|---|---|---|
(I1) | operand |
テンソルまたはテンソルごとの量子化テンソル | (C1~C4)、(C6)、(C8~C11) |
(I2) | source |
テンソルまたはテンソルごとの量子化テンソル | (C1)、(C2) |
(I3) | init_value |
0 次元テンソルまたはテンソルごとの量子化テンソル | (C3) |
(I4) | window_dimensions |
si64 型の 1 次元のテンソル定数 |
(C2)、(C4)、(C5) |
(I5) | window_strides |
si64 型の 1 次元のテンソル定数 |
(C2)、(C6)、(C7) |
(I6) | padding |
si64 型の 2 次元のテンソル定数 |
(C2)、(C8) |
(I7) | select |
関数 | (C9) |
(I8) | scatter |
関数 | (C10) |
出力
名前 | タイプ | 制約 |
---|---|---|
result |
テンソルまたはテンソルごとの量子化テンソル | (C11-C12) |
制約
- (C1)
element_type(operand) = element_type(source)
。 - (C2)
shape(source) = num_windows
。ここで:padded_operand_shape = padding[:, 0] + shape(operand) + padding[:, 1]
。is_empty_window = padded_operand_shape = 0 || window_dimensions > padded_operand_shape
num_windows = is_empty_window ? 0 : floor((padded_operand_shape - window_dimensions) / window_strides) + 1
。
- (C3)
element_type(init_value) = element_type(operand)
。 - (C4)
size(window_dimensions) = rank(operand)
。 - (C5)
0 < window_dimensions
。 - (C6)
size(window_strides) = rank(operand)
。 - (C7)
0 < window_strides
。 - (C8)
shape(padding) = [rank(operand), 2]
。 - (C9)
select
の型は(tensor<E>, tensor<E>) -> tensor<i1>
(E = element_type(operand)
)です。 - (C10)
scatter
の型は(tensor<E>, tensor<E>) -> tensor<E>
です。is_promotable(element_type(operand), E)
- (C11)
shape(operand) = shape(result)
。 - (C12)
element_type(result) = E
。
例
// %operand: [[1, 5], [2, 5], [3, 6], [4, 4]]
// %source: [[5, 6], [7, 8]]
// %init_value: 0
%result = "stablehlo.select_and_scatter"(%operand, %source, %init_value) ({
^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
%0 = "stablehlo.compare"(%arg0, %arg1) {
comparison_direction = #stablehlo<comparison_direction GE>
} : (tensor<i64>, tensor<i64>) -> tensor<i1>
"stablehlo.return"(%0) : (tensor<i1>) -> ()
}, {
^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
%0 = "stablehlo.add"(%arg0, %arg1) : (tensor<i64>, tensor<i64>) -> tensor<i64>
"stablehlo.return"(%0) : (tensor<i64>) -> ()
}) {
window_dimensions = array<i64: 3, 1>,
window_strides = array<i64: 2, 1>,
padding = dense<[[0, 1], [0, 0]]> : tensor<2x2xi64>
} : (tensor<4x2xi64>, tensor<2x2xi64>, tensor<i64>) -> tensor<4x2xi64>
// %result: [[0, 0], [0, 0], [5, 14], [7, 0]]
送信
セマンティクス
inputs
をチャンネル channel_id
に送信し、result
トークンを生成します。
is_host_transfer
が true
の場合、オペレーションはホストにデータを転送します。そうでない場合は、別のデバイスにデータを転送します。つまり、実装は定義されます。このフラグは channel_type
で提供される情報と重複するため、今後はどちらか一方のみを保持する予定です(#666)。
入力
ラベル | 名前 | タイプ | 制約 |
---|---|---|---|
(I1) | inputs |
テンソルまたは量子化テンソルの可変数 | |
(I2) | token |
token |
|
(I3) | channel_id |
si64 型の定数 |
|
(I4) | channel_type |
DEVICE_TO_DEVICE と DEVICE_TO_HOST の列挙型 |
(C1) |
(I5) | is_host_transfer |
i1 型の定数 |
(C1) |
出力
名前 | タイプ |
---|---|
result |
token |
制約
- (C1)
channel_type
は次のように定義されます。is_host_transfer = true
の場合、DEVICE_TO_HOST
- それ以外の場合は
DEVICE_TO_DEVICE
。
例
%result = "stablehlo.send"(%operand, %token) {
channel_handle = #stablehlo.channel_handle<handle = 1, type = 2>,
is_host_transfer = true
} : (tensor<2x2xi64>, !stablehlo.token) -> !stablehlo.token
shift_left
セマンティクス
lhs
テンソルを rhs
ビット分要素ごとに左シフトし、result
テンソルを生成します。
入力
ラベル | 名前 | タイプ | 制約 |
---|---|---|---|
(I1) | lhs |
整数型のテンソル | (C1) |
(I2) | rhs |
整数型のテンソル | (C1) |
出力
名前 | タイプ | 制約 |
---|---|---|
result |
整数型のテンソル | (C1) |
制約
- (C1)
type(lhs) = type(rhs) = type(result)
。
例
// %lhs: [-1, 0, 1]
// %rhs: [1, 2, 3]
%result = "stablehlo.shift_left"(%lhs, %rhs): (tensor<3xi64>, tensor<3xi64>) -> tensor<3xi64>
// %result: [-2, 0, 8]
shift_right_arithmetic
セマンティクス
lhs
テンソルに対して rhs
ビット数だけ要素単位の右シフト演算を実行し、result
テンソルを生成します。
入力
ラベル | 名前 | タイプ | 制約 |
---|---|---|---|
(I1) | lhs |
整数型のテンソル | (C1) |
(I2) | rhs |
整数型のテンソル | (C1) |
出力
名前 | タイプ | 制約 |
---|---|---|
result |
整数型のテンソル | (C1) |
制約
- (C1)
type(lhs) = type(rhs) = type(result)
。
例
// %lhs: [-1, 0, 8]
// %rhs: [1, 2, 3]
%result = "stablehlo.shift_right_arithmetic"(%lhs, %rhs): (tensor<3xi64>, tensor<3xi64>) -> tensor<3xi64>
// %result: [-1, 0, 1]
shift_right_logical
セマンティクス
lhs
テンソルを rhs
ビット数だけ要素ごとに論理右シフトし、result
テンソルを生成します。
入力
ラベル | 名前 | タイプ | 制約 |
---|---|---|---|
(I1) | lhs |
整数型のテンソル | (C1) |
(I2) | rhs |
整数型のテンソル | (C1) |
出力
名前 | タイプ | 制約 |
---|---|---|
result |
整数型のテンソル | (C1) |
制約
- (C1)
type(lhs) = type(rhs) = type(result)
。
例
// %lhs: [-1, 0, 8]
// %rhs: [1, 2, 3]
%result = "stablehlo.shift_right_logical"(%lhs, %rhs): (tensor<3xi64>, tensor<3xi64>) -> tensor<3xi64>
// %result: [9223372036854775807, 0, 1]
署名
セマンティクス
要素ごとの operand
の符号を返して、result
テンソルを生成します。より正式には、各要素 x
のセマンティクスは、次のように Python 構文を使用して表現できます。
def sign(x):
if is_integer(x):
if compare(x, 0, LT, SIGNED): return -1
if compare(x, 0, EQ, SIGNED): return 0
return 1
elif is_float(x):
if is_nan(x): return NaN
if compare(x, -0.0, EQ, FLOAT): return -0.0
if compare(x, +0.0, EQ, FLOAT): return +0.0
if compare(x, 0.0, LT, FLOAT): return -1.0
return 1.0
elif is_complex(x):
if is_nan(real(x)) or is_nan(imag(x)): return (NaN, NaN)
if compare(x, (0.0, 0.0), EQ, FLOAT): return (0.0, 0.0)
return divide(x, convert(abs(x), type(x)))
量子化型の場合は、dequantize_op_quantize(sign, operand, type(result))
を実行します。
入力
ラベル | 名前 | タイプ | 制約 |
---|---|---|---|
(I1) | operand |
符号付き整数、浮動小数点数、複素型のテンソル、またはテンソルごとの量子化テンソル | (C1) |
出力
名前 | タイプ | 制約 |
---|---|---|
result |
符号付き整数型、浮動小数点型、複合型のテンソル、またはテンソルごとの量子化テンソル | (C1) |
制約
- (C1)
baseline_type(operand) = baseline_type(result)
。
例
// Logical values: +NaN, -1.0, -0.0, +0.0, 1.0
// operand: [0x7FFFFFFFFFFFFFFF, -1.0, -0.0, 0.0, 1.0]
%result = "stablehlo.sign"(%operand) : (tensor<5xf64>) -> tensor<5xf64>
// Logical values: +NaN, -1.0, -0.0, +0.0, 1.0
// %result: [0x7FFFFFFFFFFFFFFF, -1.0, -0.0, 0.0, 1.0]
サイン
セマンティクス
operand
テンソルで要素ごとの正弦演算を実行し、result
テンソルを生成します。要素のタイプに応じて、次の操作を行います。
- 浮動小数点数の場合: IEEE-754 の
sin
。 - 複素数の場合: 複素正弦。
- 量子化された型の場合:
dequantize_op_quantize(sine, operand, type(result))
。
入力
ラベル | 名前 | タイプ | 制約 |
---|---|---|---|
(I1) | operand |
浮動小数点型または複合型のテンソル、またはテンソルごとの量子化テンソル | (C1) |
出力
名前 | タイプ | 制約 |
---|---|---|
result |
浮動小数点型または複合型のテンソル、またはテンソルごとの量子化テンソル | (C1) |
制約
- (C1)
baseline_type(operand) = baseline_type(result)
。
例
// %operand: [
// [0.0, 1.57079632], // [0, pi/2]
// [3.14159265, 4.71238898] // [pi, 3pi/2]
// ]
%result = "stablehlo.sine"(%operand) : (tensor<2x2xf32>) -> tensor<2x2xf32>
// %result: [[0.0, 1.0], [0.0, -1.0]]
slice
セマンティクス
静的に計算された開始インデックスを使用して operand
からスライスを取り出し、result
テンソルを生成します。start_indices
には各ディメンションのスライスの開始インデックスが、limit_indices
には各ディメンションのスライスの終了インデックス(これを含まない)が含まれ、strides
には各ディメンションのストライドが含まれます。
より正式には、result[result_index] = operand[operand_index]
で operand_index = start_indices + result_index * strides
。
入力
ラベル | 名前 | タイプ | 制約 |
---|---|---|---|
(I1) | operand |
テンソルまたはテンソルごとの量子化テンソル | (C1 ~ C3)、(C5) |
(I2) | start_indices |
si64 型の 1 次元テンソル定数 |
(C2)、(C3)、(C5) |
(I3) | limit_indices |
si64 型の 1 次元テンソル定数 |
(C2)、(C3)、(C5) |
(I4) | strides |
si64 型の 1 次元テンソル定数 |
(C2)、(C4) |
出力
名前 | タイプ | 制約 |
---|---|---|
result |
テンソルまたはテンソルごとの量子化テンソル | (C1)、(C5) |
制約
- (C1)
element_type(operand) = element_type(result)
。 - (C2)
size(start_indices) = size(limit_indices) = size(strides) = rank(operand)
。 - (C3)
0 <= start_indices <= limit_indices <= shape(operand)
。 - (C4)
0 < strides
。 - (C5)
shape(result) = ceil((limit_indices - start_indices) / strides)
。
例
// %operand: [
// [0, 0, 0, 0],
// [0, 0, 1, 1],
// [0, 0, 1, 1]
// ]
%result = "stablehlo.slice"(%operand) {
start_indices = array<i64: 1, 2>,
limit_indices = array<i64: 3, 4>,
strides = array<i64: 1, 1>
} : (tensor<3x4xi64>) -> tensor<2x2xi64>
// % result: [
// [1, 1],
// [1, 1]
// ]
並べ替え
セマンティクス
comparator
に従って、ディメンション dimension
に沿って inputs
の 1 次元スライスを並べ替え、results
を生成します。
他のオペレーションの同様の入力とは異なり、dimension
では負の値を使用できます。セマンティクスは次のとおりです。今後、一貫性のため、この操作が禁止される可能性があります(#1377)。
is_stable
が true の場合、並べ替えは安定します。つまり、比較演算子によって等しいとみなされる要素の相対順序は保持されます。入力が 1 つの場合、2 つの要素 e1
と e2
は、comparator(e1, e2) = comparator(e2, e1) = false
の場合にのみ、比較演算子によって等しいと見なされます。これが複数の入力に一般化される仕組みについては、以下の形式化をご覧ください。
より正式に、index_space(results[0])
内のすべての result_index
について、次のように定義します。
adjusted_dimension = dimension >= 0 ? dimension : rank(inputs[0]) + dimension
。result_slice = [ri0, ..., :, ..., riR-1]
。ここで、riN
はresult_index
の個々の要素であり、:
はadjusted_dimension
に挿入されます。inputs_together = (inputs[0]..., ..., inputs[N-1]...)
。results_together[result_slice] = sort(inputs_together[result_slice], comparator_together)
。- ここで、
sort
は 1 次元スライスを降順以外の順序で並べ替えます。これは、左側の引数が右側の 2 番目の引数より小さい場合にcomparator_together
がtrue
を返すことを想定しています。 def comparator_together(lhs_together, rhs_together): args = [] for (lhs_el, rhs_el) in zip(lhs_together, rhs_together): args.append(lhs_el) args.append(rhs_el) return comparator(*args)
(results[0]..., ..., results[N-1]...) = results_together
。
入力
ラベル | 名前 | タイプ | 制約 |
---|---|---|---|
(I1) | inputs |
テンソルの可変数またはテンソルごとの量子化テンソル | (C1 ~ C5) |
(I2) | dimension |
si64 型の定数 |
(C4) |
(I3) | is_stable |
i1 型の定数 |
|
(I4) | comparator |
関数 | (C5) |
出力
名前 | タイプ | 制約 |
---|---|---|
results |
テンソルの可変数またはテンソルごとの量子化テンソル | (C2)、(C3) |
制約
- (C1)
0 < size(inputs)
。 - (C2)
type(inputs...) = type(results...)
。 - (C3)
same(shape(inputs...) + shape(results...))
。 - (C4)
-R <= dimension < R
(R = rank(inputs[0])
)。 - (C5)
comparator
の型は(tensor<E1>, tensor<E1>, ..., tensor<EN-1>, tensor<EN-1>) -> tensor<i1>
です。ここで、Ei = element_type(inputs[i])
。
例
// %input0 = [[1, 2, 3], [3, 2, 1]]
// %input1 = [[3, 2, 1], [1, 2, 3]]
%result0, %result1 = "stablehlo.sort"(%input0, %input1) ({
^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>, %arg2: tensor<i64>, %arg3: tensor<i64>):
%predicate = "stablehlo.compare"(%arg0, %arg1) {
comparison_direction = #stablehlo<comparison_direction GT>
} : (tensor<i64>, tensor<i64>) -> tensor<i1>
"stablehlo.return"(%predicate) : (tensor<i1>) -> ()
}) {
dimension = 0 : i64,
is_stable = true
} : (tensor<2x3xi64>, tensor<2x3xi64>) -> (tensor<2x3xi64>, tensor<2x3xi64>)
// %result0 = [[3, 2, 3], [1, 2, 1]]
// %result1 = [[1, 2, 1], [3, 2, 3]]
sqrt
セマンティクス
operand
テンソルで要素ごとの平方根演算を実行し、result
テンソルを生成します。要素のタイプに応じて、次の操作を行います。
- 浮動小数点数の場合: IEEE-754 の
squareRoot
。 - 複素数の場合: 複素平方根。
- 量子化された型の場合:
dequantize_op_quantize(sqrt, operand, type(result))
。
入力
ラベル | 名前 | タイプ | 制約 |
---|---|---|---|
(I1) | operand |
浮動小数点型または複合型のテンソル、またはテンソルごとの量子化テンソル | (C1) |
出力
名前 | タイプ | 制約 |
---|---|---|
result |
浮動小数点型または複合型のテンソル、またはテンソルごとの量子化テンソル | (C1) |
制約
- (C1)
baseline_type(operand) = baseline_type(result)
。
例
// %operand: [[0.0, 1.0], [4.0, 9.0]]
%result = "stablehlo.sqrt"(%operand) : (tensor<2x2xf32>) -> tensor<2x2xf32>
// %result: [[0.0, 1.0], [2.0, 3.0]]
subtract
セマンティクス
2 つのテンソル lhs
と rhs
の要素ごとの減算を行い、result
テンソルを生成します。要素のタイプに応じて、次の処理を行います。
- 整数の場合: 整数の減算。
- 浮動小数点数: IEEE-754 の
subtraction
。 - 複素数の場合: 複素減算。
- 量子化された型の場合:
dequantize_op_quantize(subtract, lhs, rhs, type(result))
。
入力
ラベル | 名前 | タイプ | 制約 |
---|---|---|---|
(I1) | lhs |
整数型、浮動小数点型、複合型のテンソル、またはテンソルごとの量子化テンソル | (C1) |
(I2) | rhs |
整数、浮動小数点、複合型のテンソル、またはテンソルごとの量子化テンソル | (C1) |
出力
名前 | タイプ | 制約 |
---|---|---|
result |
整数型、浮動小数点型、複合型のテンソル、またはテンソルごとの量子化テンソル | (C1) |
制約
- (C1)
baseline_type(lhs) = baseline_type(rhs) = baseline_type(result)
。
例
// %lhs: [[6, 8], [10, 12]]
// %rhs: [[5, 6], [7, 8]]
%result = "stablehlo.subtract"(%lhs, %rhs) : (tensor<2x2xf32>, tensor<2x2xf32>) -> (tensor<2x2xf32>)
// %result: [[1, 2], [3, 4]]
tan
セマンティクス
operand
テンソルに対して要素ごとのタンジェント演算を実行し、result
テンソルを生成します。要素のタイプに応じて、次の操作を行います。
- 浮動小数点数の場合: IEEE-754 の
tan
。 - 複素数の場合: 複素タンジェント。
- 量子化型の場合:
dequantize_op_quantize(tan, operand, type(result))
。
入力
ラベル | 名前 | タイプ | 制約 |
---|---|---|---|
(I1) | operand |
浮動小数点型または複合型のテンソル、またはテンソルごとの量子化テンソル | (C1) |
出力
名前 | タイプ | 制約 |
---|---|---|
result |
浮動小数点型または複合型のテンソル、またはテンソルごとの量子化テンソル | (C1) |
制約
- (C1)
baseline_type(operand) = baseline_type(result)
。
例
// %operand: [
// [0.0, 1.57079632], // [0, pi/2]
// [3.14159265, 4.71238898] // [pi, 3pi/2]
// ]
%result = "stablehlo.tan"(%operand) : (tensor<2x2xf64>) -> tensor<2x2xf64>
// %result: [
// [0.0, 1.63312e+16],
// [0.0, 5.44375e+15]
// ]
tanh
セマンティクス
operand
テンソルに対して要素単位の双曲線正接演算を実行し、result
テンソルを生成します。要素のタイプに応じて、次の操作を行います。
- 浮動小数点数の場合: IEEE-754 の
tanh
。 - 複素数の場合: 複素双曲線正接。
- 量子化された型の場合:
dequantize_op_quantize(tanh, operand, type(result))
。
入力
ラベル | 名前 | タイプ | 制約 |
---|---|---|---|
(I1) | operand |
浮動小数点型または複合型のテンソル、またはテンソルごとの量子化テンソル | (C1) |
出力
名前 | タイプ | 制約 |
---|---|---|
result |
浮動小数点型または複合型のテンソル、またはテンソルごとの量子化テンソル | (C1) |
制約
- (C1)
baseline_type(operand) = baseline_type(result)
。
例
// %operand: [-1.0, 0.0, 1.0]
%result = "stablehlo.tanh"(%operand) : (tensor<3xf32>) -> tensor<3xf32>
// %result: [-0.76159416, 0.0, 0.76159416]
行 / 列の入れ替え
セマンティクス
permutation
を使用して operand
テンソルの次元を並べ替え、result
テンソルを生成します。より正式には、result[result_index] = operand[operand_index]
で、result_index[d] = operand_index[permutation[d]]
です。
入力
ラベル | 名前 | タイプ | 制約 |
---|---|---|---|
(I1) | operand |
テンソルまたは量子化テンソル | (C1~C4) |
(I2) | permutation |
si64 型の 1 次元テンソル定数 |
(C2 ~ C4) |
出力
名前 | タイプ | 制約 |
---|---|---|
result |
テンソルまたは量子化テンソル | (C1)、(C3 ~ C4) |
制約
- (C1)
element_type(result)
は次のように定義されます。element_type(operand)
(!is_per_axis_quantized(operand)
の場合)。element_type(operand)
。ただし、quantization_dimension(operand)
とquantization_dimension(result)
が異なる場合があります。
- (C2)
permutation
はrange(rank(operand))
の並べ替えです。 - (C3)
shape(result) = dim(operand, permutation...)
。 - (C4)
is_per_axis_quantized(result)
の場合はquantization_dimension(operand) = permutation(quantization_dimension(result))
です。
例
// %operand: [
// [[1,2], [3,4], [5,6]],
// [[7,8], [9,10], [11,12]]
// ]
%result = "stablehlo.transpose"(%operand) {
permutation = array<i64: 2, 1, 0>
} : (tensor<2x3x2xi32>) -> tensor<2x3x2xi32>
// %result: [
// [[1,7], [3,9], [5,11]],
// [[2,8], [4,10], [6,12]]
// ]
triangular_solve
セマンティクス
下三角または上三角係数行列を持つ連立一次方程式のバッチを解きます。
より正式には、a
と b
が指定されている場合、result[i0, ..., iR-3, :, :]
は left_side
が true
の場合の op(a[i0, ..., iR-3, :, :]) * x = b[i0, ..., iR-3, :, :]
の解、または left_side
が false
の場合の x * op(a[i0, ..., iR-3, :, :]) = b[i0, ..., iR-3, :, :]
です。変数 x
は、op(a)
が transpose_a
によって決定されます。transpose_a
は次のいずれかです。
NO_TRANSPOSE
:a
をそのまま使用してオペレーションを実行します。TRANSPOSE
:a
の転置でオペレーションを実行します。ADJOINT
:a
の共役転置に対して演算を実行します。
入力データは、lower
が true
の場合、a
の下三角からのみ読み取られます。それ以外の場合は、a
の上三角から読み取られます。出力データは同じ三角形に返されます。もう 1 つの三角形の値は実装で定義されます。
unit_diagonal
が true の場合、実装は a
の対角要素が 1 に等しいと想定できます。それ以外の場合は、動作は未定義です。
量子化された型の場合は、dequantize_op_quantize(lambda x, y: triangular_solve(x, y, left_side, lower,
unit_diagonal, transpose_a), a, b, type(result))
を実行します。
入力
ラベル | 名前 | タイプ | 制約 |
---|---|---|---|
(I1) | a |
浮動小数点型または複合型のテンソル、またはテンソルごとの量子化テンソル | (C1~C3) |
(I2) | b |
浮動小数点型または複合型のテンソル、またはテンソルごとの量子化テンソル | (C1~C4) |
(I3) | left_side |
i1 型の定数 |
(C3) |
(I4) | lower |
i1 型の定数 |
|
(I5) | unit_diagonal |
i1 型の定数 |
|
(I6) | transpose_a |
NO_TRANSPOSE 、TRANSPOSE 、ADJOINT の列挙型 |
出力
名前 | タイプ | 制約 |
---|---|---|
result |
浮動小数点型または複合型のテンソル、またはテンソルごとの量子化テンソル | (C1) |
制約
- (C1)
baseline_element_type(a) = baseline_element_type(b)
。 - (C2)
2 <= rank(a) = rank(b) = R
。 - (C3)
shape(a)
とshape(b)
の関係は次のように定義されます。shape(a)[:-3] = shape(b)[:-3]
。dim(a, -2) = dim(a, -1) = dim(b, left_side ? -2 : -1)
。
- (C4)
baseline_type(b) = baseline_type(result)
。
例
// %a = [
// [1.0, 0.0, 0.0],
// [2.0, 4.0, 0.0],
// [3.0, 5.0, 6.0]
// ]
// %b = [
// [2.0, 0.0, 0.0],
// [4.0, 8.0, 0.0],
// [6.0, 10.0, 12.0]
// ]
%result = "stablehlo.triangular_solve"(%a, %b) {
left_side = true,
lower = true,
unit_diagonal = false,
transpose_a = #stablehlo<transpose NO_TRANSPOSE>
} : (tensor<3x3xf32>, tensor<3x3xf32>) -> tensor<3x3xf32>
// %result: [
// [2.0, 0.0, 0.0],
// [0.0, 2.0, 0.0],
// [0.0, 0.0, 2.0]
// ]
tuple
セマンティクス
値 val
から result
タプルを生成します。
入力
ラベル | 名前 | タイプ | 制約 |
---|---|---|---|
(I1) | val |
可変長の値の数 | (C1) |
出力
名前 | タイプ | 制約 |
---|---|---|
result |
tuple | (C1) |
制約
- (C1)
result
の型はtuple<E0, ..., EN-1>
で、Ei = type(val[i])
です。
例
// %val0: [1.0, 2.0]
// %val1: (3)
%result = "stablehlo.tuple"(%val0, %val1) : (tensor<2xf32>, tuple<tensor<i32>>) -> tuple<tensor<2xf32>, tuple<tensor<i32>>>
// %result: ([1.0, 2.0], (3))
uniform_dequantize
セマンティクス
operand
型で定義された量子化パラメータに従って、量子化テンソル operand
を浮動小数点テンソル result
に要素ごとに変換します。
正式には result = dequantize(operand)
です。
入力
ラベル | 名前 | タイプ | 制約 |
---|---|---|---|
(I1) | operand |
量子化テンソル | (C1)、(C2) |
出力
名前 | タイプ | 制約 |
---|---|---|
result |
浮動小数点型のテンソル | (C1)、(C2) |
制約
- (C1)
shape(operand) = shape(result)
。 - (C2)
element_type(result) = expressed_type(operand)
。
例
// %operand: [10, 10]
%result = "stablehlo.uniform_dequantize"(%operand) : (tensor<2x!quant.uniform<i8:f32:0, {0.1:-30,0.5:-20}>>) -> tensor<2xf32>
// %result: [4.0, 15.0]
uniform_quantize
セマンティクス
result
型で定義された量子化パラメータに従って、浮動小数点テンソルまたは量子化テンソル operand
を量子化テンソル result
に要素単位で変換します。
よりフォーマルな表現では
is_float(operand)
の場合:result = quantize(operand, type(result))
。
is_quantized(operand)
の場合:float_result = dequantize(operand)
。result = quantize(float_result, type(result))
。
入力
ラベル | 名前 | タイプ | 制約 |
---|---|---|---|
(I1) | operand |
浮動小数点型または量子化型のテンソル | (C1)、(C2) |
出力
名前 | タイプ | 制約 |
---|---|---|
result |
量子化テンソル | (C1)、(C2) |
制約
- (C1)
shape(operand) = shape(result)
。 - (C2)
expressed_type(result) = is_float(operand) ? element_type(operand) : expressed_type(operand)
。
例
// %operand: [4.0, 15.0]
%result = "stablehlo.uniform_quantize"(%operand) : (tensor<2xf32>) -> tensor<2x!quant.uniform<i8:f32:0, {0.1:-30,0.5:-20}>>
// %result: [10, 10]
// %operand: [10, 10]
%result = "stablehlo.uniform_quantize"(%operand) : (tensor<2x!quant.uniform<i8:f32:0, {0.1:-30,0.5:-20}>>) -> tensor<2x!quant.uniform<i8:f32:0, {0.1:-20,0.2:-30}>>
// %result: [20, 45]
しばらく
セマンティクス
cond
関数が true
を出力している間に、body
関数を 0 回以上実行すると、出力が生成されます。より正式には、セマンティクスは Python 構文を使用して次のように表現できます。
internal_state = operand
while cond(*internal_state):
internal_state = body(*internal_state)
results = internal_state
無限ループの動作は未定です(#383)。
入力
ラベル | 名前 | タイプ | 制約 |
---|---|---|---|
(I1) | operand |
テンソル、量子化テンソル、トークンの可変数 | (C1~C3) |
(I2) | cond |
関数 | (C1) |
(I3) | body |
関数 | (C2) |
出力
名前 | タイプ | 制約 |
---|---|---|
results |
テンソル、量子化テンソル、トークンの可変数 | (C3) |
制約
- (C1)
cond
の型は(T0, ..., TN-1) -> tensor<i1>
です。ここで、Ti = type(operand[i])
。 - (C2)
body
の型は(T0, ..., TN-1) -> (T0, ..., TN-1)
です。ここで、Ti = type(operand[i])
。 - (C3)
type(results...) = type(operand...)
。
例
// %init_i: 1
// %init_sum: 0
// %one: 1
// %ten: 10
%results0, %results1 = "stablehlo.while"(%init_i, %init_sum) ({
^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
%cond = "stablehlo.compare"(%arg0, %ten) {
comparison_direction = #stablehlo<comparison_direction LT>
} : (tensor<i64>, tensor<i64>) -> tensor<i1>
stablehlo.return %cond : tensor<i1>
}, {
^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
%new_sum = stablehlo.add %arg1, %one : tensor<i64>
%new_i = stablehlo.add %arg0, %one : tensor<i64>
stablehlo.return %new_i, %new_sum : tensor<i64>, tensor<i64>
}) : (tensor<i64>, tensor<i64>) -> (tensor<i64>, tensor<i64>)
// %results0: 10
// %results1: 10
xor
セマンティクス
2 つのテンソル lhs
と rhs
の要素ごとの XOR を実行し、result
テンソルを生成します。要素のタイプに応じて、次の操作を行います。
- ブール値の場合: 論理 XOR。
- 整数の場合: ビット単位の XOR。
入力
ラベル | 名前 | タイプ | 制約 |
---|---|---|---|
(I1) | lhs |
ブール型または整数型のテンソル | (C1) |
(I2) | rhs |
ブール値型または整数型のテンソル | (C1) |
出力
名前 | タイプ | 制約 |
---|---|---|
result |
ブール型または整数型のテンソル | (C1) |
制約
- (C1)
type(lhs) = type(rhs) = type(result)
。
例
// Bitwise operation with with integer tensors
// %lhs: [[1, 2], [3, 4]]
// %rhs: [[5, 6], [7, 8]]
%result = "stablehlo.xor"(%lhs, %rhs) : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
// %result: [[4, 4], [4, 12]]
// Logical operation with with boolean tensors
// %lhs: [[false, false], [true, true]]
// %rhs: [[false, true], [false, true]]
%result = "stablehlo.xor"(%lhs, %rhs) : (tensor<2x2xi1>, tensor<2x2xi1>) -> tensor<2x2xi1>
// %result: [[false, true], [true, false]]
方言の相互運用
現時点では、StableHLO プログラムには、StableHLO で定義されていないオペレーションが含まれている場合があります。
モジュール、関数、呼び出し、リターン
StableHLO は、ModuleOp、FuncOp、CallOp、ReturnOp にアップストリーム MLIR オペレーションを使用します。これは、既存の MLIR 機構との相互運用性を高めるために行われました。これは、FuncOp と ModuleOp をターゲットとする多くの有用なパスが記述されており、多くのコンパイル パイプラインは、これらの op が存在することを想定しているためです。これらの op には、完全な互換性の保証が適用されます。これらのオペレーションが互換性のない方法で変更された場合(削除など)、互換性を維持するために StableHLO と同等のオペレーションが追加されます。
CHLO
CHLO オペセットには、StableHLO に分解される上位レベルのオペレーションが含まれています。現在のところ、CHLO の互換性は保証されていません。互換性を保証するには、シリアル化の前に chlo-legalize-to-stablehlo パスを使用する必要があります。
シェイプ オペレーション
動的な StableHLO プログラムでコア MLIR 言語の特定のオペレーションを使用して、シェイプ計算を実行することは、コミュニティで一般的なユースケースです。一般的には、shape_of
や num_elements
などの shape
言語演算、dim
や from_elements
などの tensor
言語演算、組み込みの index
型演算があります。
Dynamism RFC > O2 では、これらはサポート範囲外とされています。ただし、相互運用性のために index
型の一部がサポートされています。これらの演算や型には互換性は保証されません。shape-legalize-to-stablehlo パスを使用して、これらのオペレーションを完全にサポートされている StableHLO オペレーションに変換できます。
非推奨のオペレーション
MHLO から継承された StableHLO オペレーションがいくつかあります。これらは非推奨で、StableHLO から移行される予定です。これらの削除の詳細については、StableHLO v1.0 のクリーンアップ #2283 をご覧ください。これらのサポート終了に関するトラッカーの問題は #2340 です。
これらのオペレーションは、次のカテゴリに分類されます。
- StableHLO オペレーションの「HLO に含まれない」カテゴリ - 最初は StableHLO オペセットの一部でしたが、後で適切ではないと判断されました。
broadcast
、create_token
、cross-replica-sum
、dot
、einsum
、torch_index_select
、unary_einsum
(#3)。 - 使用されていないオペレーション - これらのオペレーションは、ある時点では有用だったかもしれませんが、オペレーションが十分に開発されていないか、これらのオペレーションを使用するパイプラインがリファクタリングされ、オペレーションが不要になっています。これには、
map
、tuple
(#598)、get_tuple_element
、rng
、complex
の比較 #560、畳み込みwindow_reversal
(#1181)が含まれます。
これらのオペレーションの一部は、既存のオペレーション(broadcast
、create_token
、cross-replica-sum
、dot
、unary_einsum
)を使用して表現できるため、簡単に削除できます。これらのオペレーションは、既存の互換性期間(6 か月)が経過すると削除されます。その他のオペレーションは、削除の検討中です(einsum
、get_tuple_element
、map
、rng
torch_index_select
、tuple
、complex
比較、window_reversal
)。コミュニティからのフィードバックに基づいて、これらのオペレーションは削除されるか、完全なサポートで仕様に追加されます。これらのオペレーションの将来が判明するまでは、互換性が保証されるのは 6 か月間のみです。
実行
順次実行
StableHLO プログラムを実行するには、main
関数に入力値を指定し、出力値を計算します。関数の出力値は、対応する return
オペレーションにルートを持つオペレーションのグラフを実行することで計算されます。
実行順序は、データフローに沿っている限り(つまり、オペレーションが使用前に実行される場合)、実装で定義されます。StableHLO では、副作用のある演算はすべて 1 つのトークンを消費し、1 つのトークンを生成します(複数のトークンは after_all
を介して 1 つのトークンに多重化できます)。そのため、副作用の実行順序もデータフローに沿っています。たとえば、次のプログラムでは、%0
→ %1
→ %2
→ return
と %1
→ %0
→ %2
→ return
の 2 つの実行順序が可能です。
func.func @main() -> tensor<f64> {
%0 = stablehlo.constant dense<1.0> : tensor<f64>
%1 = stablehlo.constant dense<2.0> : tensor<f64>
%2 = stablehlo.add %0, %1 : tensor<f64>
return %2 : tensor<f64>
}
より正式には、StableHLO プロセスは、1)StableHLO プログラム、2)オペレーションのステータス(まだ実行されていない、すでに実行されている)、3)プロセスが処理中の中間値を組み合わせたものです。このプロセスは、main
関数への入力値から始まり、オペレーションのステータスと中間値を更新するオペレーションのグラフを経て、出力値で終了します。形式化は未定です(#484)。
並列実行
StableHLO プログラムは並行して実行でき、どちらも ui32
型の num_partitions
で num_replicas
の 2D プロセス グリッドに編成できます。
StableHLO プロセス グリッドでは、StableHLO プロセスの num_replicas * num_partitions
が同時に実行されています。各プロセスには一意の process_id = (replica_id, partition_id)
があります。replica_ids = range(num_replicas)
の replica_id
と partition_ids = range(num_partitions)
の partition_id
はどちらも ui32
型です。
プロセスグリッドのサイズはプログラムごとに静的に把握され(将来的には StableHLO プログラム #650 で明示的な一部にする予定です)、プロセスグリッド内の位置はプロセスごとに静的に把握されます。各プロセスは、replica_id
オペレーションと partition_id
オペレーションを介して、プロセス グリッド内の位置にアクセスできます。
プロセス グリッド内で、プログラムはすべて同じ(「単一プログラム、複数のデータ」スタイル)にすることも、すべて異なる(「複数プログラム、複数のデータ」スタイル)にすることも、その中間にも設定できます。今後、GSPMD(#619)など、並列 StableHLO プログラムを定義する他のイディオムのサポートを導入する予定です。
プロセス グリッド内のプロセスは、ほとんどが互いに独立しています。オペレーションのステータス、入力値、中間値、出力値はそれぞれ別々で、ほとんどのオペレーションはプロセス間で別々に実行されます(後述する少数の集約オペレーションを除く)。
ほとんどのオペレーションの実行では同じプロセスの値のみが使用されるため、通常、これらの値を名前で参照してもあいまいさはありません。ただし、集約オペレーションのセマンティクスを記述する場合、これは不十分です。そのため、特定のプロセス内の値 name
を参照する記号 name@process_id
が使用されます。(この観点から、修飾されていない name
は name@(replica_id(), partition_id())
の省略形と見なすことができます)。
プロセス間の実行順序は、以下で説明するように、ポイントツーポイント通信と集約オペレーションによって導入される同期を除き、実装で定義されます。
ポイントツーポイント通信
StableHLO プロセスは、StableHLO チャネルを介して相互に通信できます。チャネルは、si64
型の正の ID で表されます。さまざまなオペレーションを使用して、値をチャネルに送信したり、チャネルから受信したりできます。
これらのチャンネル ID の取得元、プロセス プログラムが認識する方法、プロセスによって行われる同期の種類など、さらなる形式化は未定(#484)です。
ストリーミング通信
すべての StableHLO プロセスは、次の 2 つのストリーミング インターフェースにアクセスできます。
- 読み取り可能なInfeed。
- 書き込み可能なアウトフィードの。
プロセス間の通信に使用され、両端にプロセスがあるチャネルとは異なり、インフィードとアウトフィードでは、もう一方の端は実装で定義されます。
ストリーミング通信が実行順序に与える影響や、それによって行われる同期の種類など、さらなる形式化は未定です(#484)。
集団オペレーション
StableHLO には、all_gather
、all_reduce
、all_to_all
、collective_broadcast
、collective_permute
、reduce_scatter
の 6 つのグループ演算があります。これらのオペレーションはすべて、StableHLO プロセス グリッド内のプロセスを StableHLO プロセス グループに分割し、他のプロセス グループとは独立して、各プロセス グループ内で共同計算を実行します。
各プロセス グループ内で、集約オペレーションによって同期バリアが発生する可能性があります。この同期が正確にいつ行われるか、プロセスがこの障壁に到達する正確な方法、到達しなかった場合にどうなるかなど、さらなる形式化は未定です(#484)。
プロセス グループにパーティション間の通信が含まれる場合(パーティション ID が異なるプロセスがプロセス グループ内にある場合)、集約オペレーションの実行にはチャネルが必要であり、集約オペレーションは si64
型の正の channel_id
を提供する必要があります。レプリカ間の通信にはチャネルは必要ありません。
集約オペレーションによって実行される計算は個々のオペレーションに固有であり、上記の個々のオペレーションのセクションで説明しています。ただし、プロセス グリッドをプロセス グループに分割する戦略はこれらのオペレーション間で共有され、このセクションで説明します。より正式には、StableHLO は次の 4 つの戦略をサポートしています。
cross_replica
各プロセス グループ内では、レプリカ間の通信のみが発生します。この戦略では、replica_groups
(レプリカ ID のリストのリストのリスト)を受け取り、replica_groups
と partition_ids
の直積を計算します。replica_groups
には一意の要素があり、すべての replica_ids
をカバーする必要があります。より正式には、Python 構文を使用します。
def cross_replica(replica_groups: List[List[ReplicaId]]) -> List[List[ProcessId]]:
for replica_group in replica_groups:
for partition_id in partition_ids:
process_group = []
for replica_id in replica_group:
process_group.append((replica_id, partition_id))
yield process_group
たとえば、replica_groups = [[0, 1], [2, 3]]
と num_partitions = 2
の場合、cross_replica
は [[(0, 0), (1, 0)], [(0, 1), (1, 1)], [(2, 0), (3, 0)], [(2, 1), (3, 1)]]
を生成します。
cross_partition
各プロセス グループ内で行われる通信は、パーティション間の通信のみです。この戦略は、partition_groups
(パーティション ID のリストのリスト)を受け取り、replica_ids
による partition_groups
のデカルト積を計算します。partition_groups
には一意の要素があり、すべての partition_ids
をカバーする必要があります。より正式には、Python 構文を使用して次のようにします。
def cross_partition(partition_groups: List[List[PartitionId]]) -> List[List[ProcessId]]:
for partition_group in partition_groups:
for replica_id in replica_ids:
process_group = []
for partition_id in partition_group:
process_group.append((replica_id, partition_id))
yield process_group
たとえば、partition_groups = [[0, 1]]
と num_replicas = 4
の場合、cross_partition
は [[(0, 0), (0, 1)], [(1, 0), (1, 1)], [(2, 0), (2, 1)], [(3, 0), (3, 1)]]
を生成します。
cross_replica_and_partition
各プロセス グループ内で、レプリカ間通信とパーティション間通信の両方が発生する可能性があります。この戦略は、レプリカ ID のリストのリストである replica_groups
を受け取り、partition_ids
によって各 replica_group
のデカルト積を計算します。replica_groups
は一意の要素を持ち、すべての replica_ids
をカバーする必要があります。より正式には、Python 構文を使用して次のようにします。
def cross_replica_and_partition(replica_groups: List[List[ReplicaId]]) -> List[List[ProcessId]]:
for replica_group in replica_groups:
process_group = []
for partition_id in partition_ids:
for replica_id in replica_group:
process_group.append((replica_id, partition_id))
yield process_group
たとえば、replica_groups = [[0, 1], [2, 3]]
と num_partitions = 2
の場合、cross_replica_and_partition
は [[(0, 0), (1, 0), (0, 1), (1, 1)], [(2, 0), (3, 0), (2, 1), (3, 1)]]
を生成します。
flattened_ids
この戦略では、flattened_id_groups
(replica_id * num_partitions + partition_id
形式の「フラット化」されたプロセス ID のリスト)を受け取り、プロセス ID に変換します。flattened_id_groups
は一意の要素を持ち、すべての process_ids
をカバーする必要があります。より正式には、Python 構文を使用して次のようにします。
def flattened_ids(flattened_id_groups: List[List[ui32]]) -> List[List[ProcessId]]:
for flattened_id_group in flattened_id_groups:
process_group = []
for flattened_id in flattened_id_group:
replica_id = flattened_id // num_partitions
partition_id = flattened_id % num_partitions
process_group.append((replica_id, partition_id))
yield process_group
たとえば、flattened_id_groups = [[0, 1, 2, 3], [4, 5, 6, 7]]
、num_replicas = 4
、num_partitions = 2
の場合、flattened_ids
は [[(0, 0), (0, 1), (1, 0), (1, 1)], [(2, 0), (2, 1), (3, 0), (3, 1)]]
を生成します。
精度
現時点では、StableHLO は数値の精度を保証していませんが、今後変更される可能性があります(#1156)。
量子化オペレーションの実行セマンティクス
量子化された StableHLO 演算の解釈は、ハードウェアの要件と機能によって異なる場合があります。たとえば、一部のハードウェアでは、「逆量子化し、浮動小数点演算を実行して、最終的に量子化」戦略を使用して量子化演算を解釈します。他の関数では、整数演算で計算全体を実行することもあります。したがって、量子化された StableHLO オペレーションの解釈は、特定の実装によってのみ決まります。ハイブリッド量子化(#1575)の解釈は、仕様(1792 を介して)で規定されているそのセマンティクスに基づく必要があります。
エラー
StableHLO プログラムは、個々のオペレーションに対する広範な制約セットによって検証され、実行前に多くのエラークラスが除外されます。ただし、整数オーバーフローや境界外アクセスなどのエラー条件は引き続き発生する可能性があります。明示的に示されていない限り、これらのエラーはすべて実装定義の動作になりますが、今後変更される可能性があります(#1157)。
浮動小数点例外
このルールの例外として、StableHLO プログラムの浮動小数点例外には明確に定義された動作があります。IEEE-754 標準で定義されている例外(無効なオペレーション、ゼロ除算、オーバーフロー、アンダーフロー、不正確な例外)が発生するオペレーションは、(標準で定義されているように)デフォルトの結果を生成し、対応するステータス フラグを発生させずに実行を続行します。これは、標準の raiseNoFlag
例外処理に似ています。標準以外の演算(複素演算や特定の超越関数など)の例外は、実装で定義されます。
形状の不一致
StableHLO は、動的に形状を変更できるテンソルをサポートしています。ただし、実行時にシェイプが一致する必要があります。一致しない場合、動作は未定義になります。StableHLO では、実行時にテンソルが特定の形状であることをアサートできる op が明示的に提供されていません。正しいコードを生成するのは、プロデューサーの責任です。
具体的な例として、以下のプログラムは有効です。ただし、実行時には %arg0
と %arg1
の正確な形状が同じである必要があります。一致していない場合、プログラムの動作は未定義になります。
func.func @foo(%arg0: tensor<?xi32>, %arg1: tensor<?xi32>) -> tensor<?xi32> {
%0 = stablehlo.add %arg0, %arg1 : tensor<?xi32>
return %0 : tensor<?xi32>
}
Notation
このドキュメントでは、構文の記述に、EBNF 構文の ISO フレーバー(ISO/IEC 14977:1996、Wikipedia)を変更して使用しています。変更点は 2 つあります。1)ルールは =
ではなく ::=
を使用して定義されます。
2)連結は ,
ではなく並置を使用して表されます。
セマンティクスの記述(「型」、「定数」、「オペレーション」セクション内)では、Python 構文に基づく式を使用しています。この式は、以下で説明するように、配列オペレーションを簡潔に表現できるように拡張されています。これは小さなコード スニペットには適していますが、大きなコード スニペットが必要なまれなケースでは、常に明示的に導入される標準の Python 構文を使用します。
数式
dot_general
仕様の例に基づいて、数式の仕組みを見てみましょう。このオペレーションの制約の 1 つは次のとおりです。dim(lhs, lhs_batching_dimensions...) = dim(rhs, rhs_batching_dimensions...)
。
この式で使用される名前は、2 つのソースから取得されます。1 つはグローバル関数(dim
など)、もう 1 つは対応するプログラム要素のメンバー定義(dot_general
の [入力] セクションで定義された lhs
、lhs_batching_dimensions
、rhs
、rhs_batching_dimensions
の入力)です。
前述のとおり、この式の構文は Python ベースで、簡潔さを重視した拡張機能がいくつかあります。式を理解するために 通常の Python 構文に変換しましょう
A)これらの数式では、=
を使用して等号を表しています。Python 構文を取得するための最初のステップは、=
を ==
に置き換えることです。dim(lhs, lhs_batching_dimensions...) == dim(rhs, rhs_batching_dimensions...)
のようにします。
B)また、これらの式では、スカラー式をテンソル式に変換する楕円記号(...
)もサポートされています。簡単に言うと、f(xs...)
は「テンソル xs
内のスカラー x
ごとにスカラー f(x)
を計算し、これらのスカラー結果をすべてテンソル結果として返す」ことを意味します。通常の Python 構文の場合、数式の例は [dim(lhs, dim1) for dim1 in lhs_batching_dimensions] ==
[dim(rhs, dim2) for dim2 in rhs_batching_dimensions]
になります。
楕円記法を使用すると、個々のスカラーレベルで作業する必要がなくなります。ただし、場合によっては、gather
仕様の start_indices[bi0, ..., :, ..., biN]
数式のように、下位レベルの準非形式構文が使用されることがあります。簡潔にするため、このような構文を標準の Python に変換するための正確な形式主義は提供していません。ケースごとに直感的に理解できるようにすることを目的としています。特定の式がわかりにくい場合はお知らせください。改善を検討いたします。
また、式では楕円を使用して、テンソル、テンソルのリスト(可変長のテンソルから発生する可能性があるものなど)など、あらゆる種類のリストを展開します。これは、正確な形式主義を提供しない(リストは StableHLO 型システムの一部ではないなど)もう 1 つの領域であり、直感的な理解に依存しています。
C)最後に、暗黙的なブロードキャストという注目すべき記法について説明します。StableHLO opset は暗黙的なブロードキャストをサポートしていませんが、簡潔さを提供する目的で数式はサポートしています。簡単に言うと、テンソルが予想されるコンテキストでスカラーを使用すると、スカラーは期待される形状にブロードキャストされます。
dot_general
の例を続けると、別の制約 0 <= lhs_batching_dimensions < rank(lhs)
があります。dot_general
仕様で定義されているように、lhs_batching_dimensions
はテンソルですが、0
と rank(lhs)
はどちらもスカラーです。暗黙的なブロードキャストを適用すると、式は [0, ..., 0] <= lhs_batching_dimensions < [rank(lhs), ..., rank(lhs)]
になります。
特定の dot_general
演算に適用されると、この数式はブール値のテンソルに評価されます。数式を制約として使用する場合、数式が true
と評価された場合、または true
要素のみを含むテンソルと評価された場合に制約が適用されます。
名前
数式では、1)グローバル関数、2)メンバー定義が、レキシカル スコープに含まれます。
3)ローカル定義グローバル関数のリストを以下に示します。要素定義のリストは、表記が適用されるプログラム要素によって異なります。
- オペレーションの場合、メンバー定義には「入力」セクションと「出力」セクションで説明した名前が含まれます。
- それ以外の場合、メンバー定義には、対応する EBNF 非終端記号にちなんで命名された、プログラム要素の構造部分が含まれます。ほとんどの場合、これらの構造部分の名前は、非終端記号の名前をスネークケースに変換することで取得されます(例:
IntegerLiteral
=>integer_literal
)。ただし、このプロセスで名前が省略されることもあります(例:QuantizationStorageType
=>storage_type
)。その場合は、オペレーション仕様の [入力] / [出力] セクションと同様に、名前が明示的に導入されます。 - また、メンバー定義には常に
self
が含まれ、対応するプログラム要素を参照します。
値
数式の評価では、次の種類の値が使用されます。1)Value
(実際の値、例: dense<[[1, 2], [3, 4]]> : tensor<2x2xi32>
。型は常に判明しています)、2)Placeholder
(将来の値、例: lhs
、rhs
、result
。実際の値はまだ判明しておらず、型のみが判明しています)、3)Type
([型] セクションで定義されている型)、4)Function
([関数] セクションで定義されているグローバル関数)。
コンテキストによっては、名前が異なる値を参照している場合があります。具体的には、オペレーションの「セマンティクス」セクション(および他のプログラム要素の同等のセクション)でランタイム ロジックが定義されているため、すべての入力を Value
として使用できます。これに対して、op(および同等のもの)の「制約」セクションでは「コンパイル時」ロジック、つまり通常はランタイム前に実行されるロジックを定義しています。そのため、定数入力のみが Value
として使用でき、その他の入力は Placeholder
としてのみ利用できます。
名前 | 「Semantics」 | [制約] で |
---|---|---|
グローバル機能 | Function |
Function |
定数入力 | Value |
Value |
定数以外の入力 | Value |
Placeholder |
出力 | Value |
Placeholder |
ローカル定義 | 定義によって異なる | 定義によって異なる |
transpose
オペレーションの例を見てみましょう。
%result = "stablehlo.transpose"(%operand) {
permutation = dense<[2, 1, 0]> : tensor<3xi64>
} : (tensor<2x3x2xi32>) -> tensor<2x3x2xi32>
このオペレーションでは、permutation
は定数であるため、セマンティクスと制約の両方で Value
として使用できます。一方、operand
と result
は、セマンティクスでは Value
として使用できますが、制約では Placeholder
としてのみ使用できます。
関数
型の構築
型の作成に使用できる関数はありません。通常は、より簡潔なタイプ構文を直接使用します。たとえば、function_type(
[tensor_type([], E), tensor_type([], E)], [tensor_type([], E)])
ではなく (tensor<E>, tensor<E>) -> (tensor<E>)
です。
型の関数
element_type
はテンソル型と量子化テンソル型で定義され、それぞれ対応するTensorType
またはQuantizedTensorType
のTensorElementType
またはQuantizedTensorElementType
部分を返します。
def element_type(x: Value | Placeholder | Type):
if type(x) == TensorType:
return tensor_element_type(x)
if type(x) == QuantizedTensorType:
return quantized_tensor_element_type(x)
if type(x) is not Type:
return element_type(type(x))
is_per_axis_quantized(x: Value | Placeholder | Type) -> Value
はis_quantized(x) and quantization_dimension(x) is not None
のショートカットです。is_per_tensor_quantized(x: Value | Placeholder | Type) -> Value
はis_quantized(x) and quantization_dimension(x) is None
のショートカットです。is_promotable(x: Type, y: Type) -> bool
は、x
型をy
型に昇格できるかどうかを確認します。x
とy
がQuantizedTensorElementType
の場合、プロモーションはstorage_type
にのみ適用されます。この特定のバージョンのプロモーションは、現在、削減計算のコンテキストで使用されています(詳細については、RFC をご覧ください)。
def is_promotable(x: Type, y: Type) -> Value:
is_same_type = (is_bool(x) and is_bool(y)) or
(is_integer(x) and is_integer(y)) or (is_float(x) and is_float(y)) or
(is_complex(x) and is_complex(y)) or
(is_quantized(x) and is_quantized(y) and expressed_type(x) = expressed_type(y))
if is_same_type == False:
return False
if is_integer(x) or is_float(x):
return bitwidth(x) <= bitwidth(y)
if is_complex(x):
return bitwidth(element_type(x)) <= bitwidth(element_type(y))
if is_quantized(x):
return bitwidth(storage_type(x)) <= bitwidth(storage_type(y))
return false
is_quantized(x: Value | Placeholder | Type) -> Value
はis_quantized_tensor_element_type(x)
のショートカットです。is_type_name(x: Value | Placeholder | Type) -> Value
。すべてのタイプで使用できます。たとえば、x
がFloatType
の場合、is_float(x)
はtrue
を返します。x
が値またはプレースホルダの場合、この関数はis_type_name(type(x))
のショートカットになります。max_value(x: Type) -> Value
はTensorElementType
の最大値を返します。x
がTensorElementType
でない場合、None
を返します。min_value(x: Type) -> Value
は、TensorElementType
の可能な最小値を返します。x
がTensorElementType
でない場合、None
を返します。member_name(x: Value | Placeholder | Type) -> Any
。すべてのタイプのすべてのメンバー定義member_name
で使用できます。たとえば、tensor_element_type(x)
は対応するTensorType
のTensorElementType
部分を返します。x
が値またはプレースホルダの場合、この関数はmember_name(type(x))
のショートカットです。x
が適切なメンバーを持つ型、またはそのような型の値またはプレースホルダでない場合、None
を返します。is_empty_algorithm(*args: Type)
は、すべてのドット アルゴリズム フィールドがNone
に設定されているかどうかを確認します。これは、ドット アルゴリズムの実装でデフォルトの動作が定義されているため、デフォルト値を指定すると正しくありません。
値の構成
operation_name(*xs: Value | Type) -> Value
。すべてのオペレーションで使用できます。たとえば、add(lhs, rhs)
は 2 つのテンソル値lhs
とrhs
を受け取り、これらの入力でadd
演算を評価した結果を返します。broadcast_in_dim
などの一部のオペレーションでは、出力のタイプが「負荷を負う」タイプです。つまり、オペレーションの評価に必要です。この場合、関数はこれらの型を引数として受け取ります。
値の関数
Python のすべての演算子と関数を使用できます。たとえば、Python のサブスクリプションとスライスの両方の表記を使用して、テンソル、量子化テンソル、タプルにインデックスを付けることができます。
to_destination_type(x: Value, destination_type: Type) -> Value
はテンソルで定義され、次のようにtype(x)
とdestination_type
に基づいてx
の変換値を返します。
def to_destination_type(x: Value, destination_type: Type) -> Value:
if type(x) == destination_type:
return x
if is_quantized(destination_type):
if is_quantized(type(x)):
return quantize(x, destination_type)
assert is_float(type(x))
return quantize(x, destination_type)
if is_quantized(type(x)):
assert destination_type = expressed_type(type(x))
return dequantize(type(x))
return convert(x, destination_type)
convert
、uniform_quantize
、uniform_dequantize
オペレーションのマージについては、早期の議論があります(#1576)。
マージ後、上記の関数は不要になり、代わりに convert
のオペレーション名を使用できます。
is_nan(x: Value) -> Value
はテンソルで定義され、x
のすべての要素がNaN
の場合はtrue
を返します。それ以外の場合はfalse
を返します。x
がテンソルでない場合、None
を返します。is_sorted(x: Value) -> Value
はテンソルで定義され、x
の要素がインデックスの辞書順の昇順で並べ替えられている場合はtrue
を返し、そうでない場合はfalse
を返します。x
がテンソルでない場合は、None
を返します。is_unique(x: Value) -> Value
はテンソルで定義され、x
に重複する要素がない場合true
を返します。それ以外の場合はfalse
を返します。x
がテンソルでない場合は、None
を返します。member_name(x: Value) -> Any
は、すべての値のすべてのメンバー定義member_name
に対して定義されます。たとえば、real_part(x)
は対応するComplexConstant
のRealPart
部分を返します。x
が適切なメンバーを含む値でない場合、None
を返します。same(x: Value) -> Value
はテンソルに対して定義され、x
の要素がすべて同じであればtrue
を返し、それ以外の場合はfalse
を返します。テンソルに要素がない場合、それは「すべて等しい」と見なされます。つまり、関数はtrue
を返します。x
がテンソルでない場合は、None
を返します。split(x: Value, num_results: Value, axis: Value) -> Value
はテンソルで定義され、軸axis
に沿ってx
のnum_results
スライスを返します。x
がテンソルまたはdim(x, axis) % num_results != 0
でない場合は、None
を返します。is_defined_in_parent_scope(x: Value) -> Value
は文字列で定義され、x
が関連するオペレーションの親関数と同じスコープで定義された関数の名前である場合、true
を返します。is_namespaced_op_name(x: Value) -> Value
は文字列で定義され、x
が有効なオペレーション名([a-zA-Z][a-zA-Z0-9_]*([.][a-zA-Z0-9_$]+)+
という正規表現に従う)である場合はtrue
を返します。
シェイプの計算
axes(x: Value | Placeholder | Type) -> Value
はrange(rank(x))
のショートカットです。dim(x: Value | Placeholder | Type, axis: Value) -> Value
はshape(x)[axis]
のショートカットです。dims(x: Value | Placeholder | Type, axes: List) -> List
はlist(map(lambda axis: dim(x, axis), axes))
のショートカットです。index_space(x: Value | Placeholder | Type) -> Value
はテンソルで定義され、対応するTensorType
のsize(x)
インデックスを辞書順で昇順([0, ..., 0]
、[0, ..., 1]
、...、shape(x) - 1
)で返します。x
がテンソル型、量子化テンソル型、値、またはこれらの型のプレースホルダでない場合、None
を返します。rank(x: Value | Placeholder | Type) -> Value
はsize(shape(x))
のショートカットです。shape(x: Value | Placeholder | Type) -> Value
は、「型の関数」セクションでmember_name
を介して定義されます。size(x: Value | Placeholder | Type) -> Value
はreduce(lambda x, y: x * y, shape(x))
のショートカットです。
量子化計算
def baseline_element_type(x: Value | Placeholder | Type) -> Type
はelement_type(baseline_type(x))
のショートカットです。baseline_type
はテンソル型と量子化テンソル型で定義され、それらを「ベースライン」に変換します。ベースラインとは、同じ形状で、要素型の量子化パラメータがデフォルト値にリセットされた型です。これは、テンソルと量子化テンソルの両方のタイプを均一に比較するための便利なトリックとして使用されます。これは非常に頻繁に必要になります。量子化タイプの場合、量子化パラメータを無視してタイプを比較できます。つまり、shape
、storage_type
、expressed_type
、storage_min
、storage_max
、quantization_dimension
(軸ごとの量子化タイプの場合)はすべて一致する必要がありますが、scales
とzero points
は異なる場合があります。
def baseline_type(x: Value | Placeholder | Type) -> Type:
if type(x) == TensorType:
return x
if type(x) == QuantizedTensorType:
element_type = quantized_tensor_element_type(x)
baseline_element_type = QuantizedTensorElementType(
storage_type = storage_type(element_type),
storage_min = storage_min(element_type),
storage_max = storage_max(element_type),
expressed_type = expressed_type(element_type),
quantization_dimension = quantization_dimension(element_type),
scales = [constant(1.0, expressed_type(element_type))] * dim(x, quantization_dimension(element_type)),
zero_points = [constant(0, storage_type(element_type))] * dim(x, quantization_dimension(element_type)))
return QuantizedTensorType(shape(x), baseline_element_type)
if type(x) is not Type:
return baseline_element_type(type(x))
dequantize
は量子化テンソル型に対して定義され、浮動小数点テンソル型に変換します。これは、量子化要素型に関連付けられたゼロ点とスケールを使用して、ストレージ型の整数値を表す量子化要素を、表現型の対応する浮動小数点値に変換することで行われます。
def compute_zero_points(quantized_type, result_type):
if is_per_tensor_quantized(quantized_type):
return broadcast_in_dim(constant(zero_point(quantized_type), storage_type(quantized_type)), [], result_type)
if is_per_axis_quantized(quantized_type):
for i in index_space(result_type):
d = quantization_dimension(quantized_type)
zero_points[i] = zero_points(quantized_type)[i[d]]
return zero_points
def compute_scales(quantized_type, result_type):
if is_per_tensor_quantized(quantized_type):
return broadcast_in_dim(constant(scale(quantized_type), expressed_type(quantized_type)), [],
type(result_type))
if is_per_axis_quantized(quantized_type):
for i in index_space(result_type):
d = quantization_dimension(quantized_type)
scales[i] = scales(quantized_type)[i[d]]
return scales
def dequantize(x: Value) -> Value:
assert is_quantized(x)
x_storage = bitcast_convert(x, storage_type(x))
x_storage_sub = x_storage - compute_zero_points(type(x), type(x_storage))
x_expressed_sub = convert(x_storage_sub, expressed_type(x))
return x_expressed_sub * compute_scales(type(x), type(x_expressed_sub))
quantize
は浮動小数点テンソル型で定義され、量子化されたテンソル型に変換します。これは、量子化された要素型に関連付けられたゼロ点とスケールを使用して、表現型の浮動小数点値をストレージ型の対応する整数値に変換することで行われます。
def quantize(x: Value, result_type: Type) -> Value:
assert is_float(x) and is_quantized(result_type)
zero_points = compute_zero_points(result_type, TensorType(shape(x), storage_type(result_type)))
converted_zero_points = convert(zero_points, expressed_type(result_type))
converted_min = convert(storage_min(result_type), expressed_type(result_type))
converted_max = convert(storage_max(result_type), expressed_type(result_type))
x_scaled = x / compute_scales(result_type, type(x))
x_scaled_add_zp = x_scaled + converted_zero_points
x_clamped = clamp(converted_min, x_scaled_add_zp, converted_max)
x_rounded = round_nearest_even(x_clamped)
return convert(x_rounded, result_type)
dequantize_op_quantize
は、量子化テンソルの要素単位の計算を指定するために使用されます。デクォンタイズ(量子化された要素を表現型に変換)してから演算を実行し、量子化(結果をストレージ タイプに戻す)します。現時点では、この関数はテンソルごとの量子化でのみ機能します。軸ごとの量子化は現在開発中です(#1574)。
def dequantize_op_quantize(op, *inputs_and_output_type):
inputs = inputs_and_output_type[:-1]
output_type = inputs_and_output_type[-1]
float_inputs = map(dequantize, inputs)
float_result = op(*float_inputs)
return quantize(float_result, output_type)
def dequantize_batch_norm_grad_or_training_quantize(op, *inputs_and_output_types):
inputs = inputs_and_output_type[:-3]
float_inputs = map(dequantize, inputs)
float_results = op(*float_inputs)
return map(quantize, float_results, inputs_and_output_type[-3:])
def dequantize_compare(lhs, rhs, comparison_direction):
float_lhs = dequantize(lhs)
float_rhs = dequantize(rhs)
return compare(float_lhs, float_rhs, comparison_direction, FLOAT)
def dequantize_select_quantize(pred, on_true, on_false, output_type):
float_on_true = dequantize(on_true)
float_on_false = dequantize(on_false)
float_result = select(pred, float_on_true, float_on_false)
return quantize(float_result, output_type)
hybrid_dequantize_then_op
は、浮動小数点型の lhs と量子化された型の rhs を受け入れるハイブリッド演算の重みのみの量子化を指定するために使用されます。量子化された入力を表現型にデクォンタイズし、浮動小数点数で計算を行います。浮動小数点の左辺テンソルの要素型と、量子化された右辺テンソルの表現型は同じである必要があります。
def hybrid_dequantize_then_op(op, lhs, rhs):
assert(is_float(lhs) and is_quantized(rhs) and element_type(lhs) == expressed_type(rhs))
return op(lhs, dequantize(rhs))
グリッド計算
cross_partition(replica_groups: Value) -> Value
。上記の「cross_replica」セクションをご覧ください。cross_replica(replica_groups: Value) -> Value
。上記の「cross_replica」セクションをご覧ください。cross_replica_and_partition(replica_groups: Value) -> Value
。上記の「cross_replica_and_partition」セクションをご覧ください。flattened_ids(replica_groups: Value) -> Value
。上記の「flattened_ids」セクションをご覧ください。
ダイナミズム
StableHLO 値には、動的なディメンション サイズ(tensor<?xi64>
など)を指定できます。ただし、StableHLO 値には動的ディメンション数(ランクなしの動的性、tensor<*xi64>
など)を指定できません。オペランドと結果では、サイズに制約がある場合でも、動的ディメンション サイズを使用できます。制約は可能であれば静的に検証されます。検証できない場合は実行時に延期され、不一致により未定義の動作が発生します。例については以下をご覧ください。
単項要素ごとのオペレーションのシェイプの不一致
次のサンプル プログラムを考えてみましょう。
func.func @foo(%arg0: tensor<?xf64>) {
%0 = stablehlo.abs %arg0 : (tensor<?xf64>) -> tensor<2xf64>
return
}
このようなプログラムは珍しいものです。結果の形状はわかっても、入力の形状がわからないことは一般的ではありません。ただし、これは有効な StableHLO プログラムです。このプログラムの abs
オペレーションを静的に検証することはできません。オペランドの正確な形状が不明であるためです。ただし、シェイプは確かに互換性があり、これは静的に確認できます。?
が実行時に 2
になっても問題はありません。ただし、?
が他の整数になる可能性もあります。その場合、動作は未定義になります。
結果でディメンション サイズが動的である場合、未定義の動作は発生しません。実際、サイズは「想定」されていないため、不一致は発生しません。
バイナリ要素ごとのオペレーションのシェイプの不一致
次のおもちゃプログラムについて考えてみましょう。
func.func @foo(%arg0: tensor<?xf64>, %arg1: tensor<?xf64>) {
%0 = stablehlo.add %arg0, %arg0 : (tensor<?xf64>, tensor<?xf64>) -> tensor<?xf64>
return
}
バイナリ要素ごとのオペレーションの場合、入力と結果の形状は実行時に一致している必要があります。コンパイル時に静的ディメンションは同じである必要があります。それ以外の場合は、互換性がある必要があります。入力でいずれかのディメンションが動的である場合、動的サイズが他のオペランドの対応するサイズ(静的または動的)と一致しないため、ランタイムで未定義の動作が発生する可能性があります。すべての入力が静的である場合、結果が動的かどうかは重要ではありません。静的に既知のディメンションは静的にチェックされ、動的ディメンションには制約が適用されません。
出力シェイプをオペランドとして取るオペレーションのシェイプの不一致
次のサンプル プログラムを考えてみましょう。
func.func @foo(%arg0: tensor<2xi32>) {
%0 = stablehlo.dynamic_iota %arg0, dim = 0 : (tensor<2xi32>) -> tensor<3x4xi64>
return
}
実行時にシェイプ オペランドの値は、結果のシェイプと一致する必要があります。一致しない場合、動作は未定義になります。つまり、実行時に %arg0
の値は dense<[3, 4]> : tensor<2xi32>
にする必要があります。シェイプ オペランドが定数の場合は、静的に検証できます。結果の形状が完全に動的であれば、不一致が生じることはありません。