StableHLO は、機械学習(ML)モデルの高レベル オペレーション(HLO)用のオペレーション セットです。StableHLO は、異なる ML フレームワークと ML コンパイラ間のポータビリティ レイヤとして機能します。StableHLO プログラムを生成する ML フレームワークは、StableHLO プログラムを使用する ML コンパイラと互換性があります。
Google の目標は、さまざまな ML フレームワーク(TensorFlow、JAX、PyTorch など)と ML コンパイラ(XLA、IREE など)間の相互運用性を高めることで、ML 開発を簡素化し、加速させることです。このドキュメントでは、その目的のために StableHLO プログラミング言語の仕様を説明します。
この仕様には、3 つの主要なセクションがあります。まず、プログラムのセクションでは、StableHLO 関数で構成される StableHLO プログラムの構造について説明します。StableHLO 関数自体は StableHLO オペレーションで構成されます。この構造内の [Ops] セクションには、個々のオペレーションのセマンティクスを指定します。[実行] セクションには、プログラム内で一緒に実行されるこれらのオペレーションのすべてのセマンティクスが示されています。最後に、表記セクションでは、仕様全体で使用される表記について説明します。
StableHLO の以前のリリースの仕様を表示するには、目的のタグ付きリリースでリポジトリを開きます。たとえば、StableHLO v0.19.0 仕様。StableHLO のマイナー バージョンの増加ごとに発生した変更を確認するには、VhloDialect.td のバージョンログをご覧ください。
プログラム
Program ::= {Func}
StableHLO プログラムは、任意の数の StableHLO 関数で構成されます。以下は、3 つの入力(%image、%weights、%bias)と 1 つの出力を持つ関数 @main を含むプログラムの例です。関数本体には 6 つのオペレーションがあります。
func.func @main(
  %image: tensor<28x28xf32>,
  %weights: tensor<784x10xf32>,
  %bias: tensor<1x10xf32>
) -> tensor<1x10xf32> {
  %0 = "stablehlo.reshape"(%image) : (tensor<28x28xf32>) -> tensor<1x784xf32>
  %1 = "stablehlo.dot"(%0, %weights) : (tensor<1x784xf32>, tensor<784x10xf32>) -> tensor<1x10xf32>
  %2 = "stablehlo.add"(%1, %bias) : (tensor<1x10xf32>, tensor<1x10xf32>) -> tensor<1x10xf32>
  %3 = "stablehlo.constant"() {value = dense<0.0> : tensor<1x10xf32>} : () -> tensor<1x10xf32>
  %4 = "stablehlo.maximum"(%2, %3) : (tensor<1x10xf32>, tensor<1x10xf32>) -> tensor<1x10xf32>
  "func.return"(%4): (tensor<1x10xf32>) -> ()
}
関数
Func        ::= 'func' '.' 'func' FuncId FuncInputs FuncOutputs '{' FuncBody '}'
FuncInputs  ::= '(' [FuncInput {',' FuncInput}] `)`
FuncInput   ::= ValueId ':' ValueType
FuncOutputs ::= ['->' FuncOutput, {',' FuncOutput}]
FuncOutput  ::= ValueType
FuncBody    ::= {Op}
StableHLO 関数(名前付き関数とも呼ばれます)には、識別子、入出力、本文があります。今後、HLO との互換性を高めるために、関数に追加のメタデータを導入する予定です(#425、#626、#740、#744)。
識別子
FuncId  ::= '@' letter {letter | digit}
ValueId ::= '%' digit {digit}
          | '%' letter {letter | digit}
letter  ::= 'a' | ... | 'z' | 'A' | ... | 'Z' | '_'
digit   ::= '0' | ... | '9'
StableHLO 識別子は多くのプログラミング言語の識別子に似ていますが、2 つの特殊性があります。1)すべての識別子には、さまざまな種類の識別子を区別するシグルが付いています。2)値識別子は完全に数値にすることができ、StableHLO プログラムの生成を簡素化できます。
型
Type         ::= ValueType | NonValueType
ValueType    ::= TensorType | QuantizedTensorType | TokenType | TupleType
NonValueType ::= TensorElementType | QuantizedTensorElementType | FunctionType | StringType
StableHLO 型は、StableHLO 値を表す値の型(ファーストクラス型とも呼ばれます)と、他のプログラム要素を記述する値以外の型に分類されます。StableHLO 型は多くのプログラミング言語の型に似ていますが、主な特徴は StableHLO のドメイン固有の性質であり、異常な結果をもたらします(例: スカラー型は値型ではありません)。
TensorType ::= 'tensor' '<' Shape TensorElementType '>'
Shape ::= {DimensionSize 'x'}
DimensionSize ::= digit {digit} | '?'
テンソル型はテンソル、つまり多次元配列を表します。シェイプと要素型があります。シェイプは、0~R-1 の番号が付けられた対応するディメンション(軸とも呼ばれます)の昇順で、正または不明のディメンション サイズを表します。ディメンション R の数はランクと呼ばれます。たとえば、tensor<2x3xf32> は形状が 2x3 で要素型が f32 のテンソル型です。0 番目のディメンションと 1 番目のディメンションの 2 つのディメンション(つまり 2 つの軸)があり、サイズは 2 と 3 です。ランクは 2 です。
シェイプは、部分的または完全に未知(動的)にできます。たとえば、tensor<?x2xf64> は部分的に未知、tensor<?x?xf64> は完全に未知です。動的ディメンションのサイズは ? を使用して表されます。シェイプのランク付けを解除することはできません。
今後、テンソル型をディメンション サイズと要素型以外にも拡張し、レイアウト(#629)やスパース性(#1078)などを含める予定です。
QuantizedTensorType ::= 'tensor' '<' Shape QuantizedTensorElementType '>'
QuantizedTensorElementType ::= '!quant.uniform' '<'
                  QuantizationStorageType
                  ['<' QuantizationStorageMin ':' QuantizationStorageMax '>']
                  ':' QuantizationExpressedType
                  [':' QuantizationDimension]
                  ',' QuantizationParameters '>'
QuantizationStorageType ::= IntegerType
QuantizationStorageMin ::= IntegerLiteral
QuantizationStorageMax ::= IntegerLiteral
QuantizationExpressedType ::= FloatType
QuantizationDimension ::= IntegerLiteral
QuantizationParameters ::= QuantizationParameter
                         | '{' QuantizationParameter {',' QuantizationParameter} '}'
QuantizationParameter ::= QuantizationScale [':' QuantizationZeroPoint]
QuantizationScale ::= FloatLiteral
QuantizationZeroPoint ::= IntegerLiteral
| 名前 | タイプ | 制約 | 
|---|---|---|
| storage_type | 整数型 | (C1 ~ C3)、(C8) | 
| storage_min | 整数定数 | (C1)、(C3)、(C7) | 
| storage_max | 整数定数 | (C2)、(C3)、(C7) | 
| expressed_type | floating-point-type | (C4) | 
| quantization_dimension | 省略可能な整数定数 | (C10-C12) | 
| scales | 浮動小数点定数の可変数 | (C4 ~ C6)、(C9)、(C10)、(C13) | 
| zero_points | 可変長整数定数 | (C7 ~ C9) | 
量子化要素の型は、表現型の浮動小数点値に対応する、storage_min~storage_max(両端を含む)の範囲のストレージ型の整数値を表します。指定された整数値 i について、対応する浮動小数点値 f は f = (i - zero_point) * scale として計算できます。ここで、scale と zero_point は量子化パラメータと呼ばれます。storage_min と storage_max は文法では省略可能ですが、デフォルト値はそれぞれ min_value(storage_type) と max_value(storage_type) です。量子化要素タイプには次の制約があります。
- (C1)type(storage_min) = storage_type。
- (C2)type(storage_max) = storage_type。
- (C3)min_value(storage_type) <= storage_min < storage_max <= max_value(storage_type)。
- (C4)type(scales...) = expressed_type。
- (C5)0 < scales。
- (C6)is_finite(scales...)。
- (C7)storage_min <= zero_points <= storage_max。
- (C8)type(zero_points...) = storage_type。
- (C9)size(scales) = size(zero_points)。
- (C10)is_empty(quantization_dimension)の場合はsize(scales) = 1です。
- (C11)0 <= quantization_dimension。
現在、QuantizationScale は浮動小数点定数ですが、乗数とシフトで表される整数ベースのスケールに強い関心があります。近い将来、この問題を調査する予定です(#1404)。
QuantizationZeroPoint のセマンティクス(型、値、量子化されたテンソル型にゼロポイントが 1 つだけか複数ある可能性があるかなど)については、現在も議論が続いています。このディスカッションの結果に基づいて、ゼロポイントに関する仕様は今後変更される可能性があります(#1405)。
現在進行中のもう一つの議論では、QuantizationStorageMin と QuantizationStorageMax のセマンティクスについて、これらの値と量子化テンソルの値に制約を適用すべきかどうかを判断しています(#1406)。
最後に、未知のディメンション サイズの表現(#1407)と同様に、未知のスケールとゼロポイントの表現を検討する予定です。
量子化テンソル型は、量子化された要素を持つテンソルを表します。これらのテンソルは、通常の要素型ではなく、量子化された要素型を持つ点を除き、通常のテンソルとまったく同じです。
量子化されたテンソルでは、量子化はテンソル単位で行うことができます。つまり、テンソル全体に 1 つの scale と zero_point を設定します。または、軸単位で行うこともできます。つまり、特定のディメンション quantization_dimension のスライスごとに 1 組の scales と zero_points を設定します。より正式には、軸ごとの量子化を使用するテンソル t には、quantization_dimension の dim(t, quantization_dimension) スライス(t[:, ..., 0, ..., :], t[:, ..., 1, ..., :] など)があります。i 番目のスライスのすべての要素は、量子化パラメータとして scales[i] と zero_points[i] を使用します。量子化テンソル型には次の制約があります。
- テンソルごとの量子化の場合:
- 追加の制約はありません。
 
- 軸ごとの量子化の場合:
- (C12)quantization_dimension < rank(self)。
- (C13)dim(self, quantization_dimension) = size(scales)。
 
- (C12)
TokenType ::= 'token'
トークン型は、トークン(一部のオペレーションによって生成および消費される不透明な値)を表します。トークンは、実行セクションで説明されているように、オペレーションに実行順序を適用するために使用されます。
TupleType ::= 'tuple' '<' TupleElementTypes '>'
TupleElementTypes ::= [ValueType {',' ValueType}]
タプル型は、タプル(異種リスト)を表します。タプルは、HLO との互換性を確保するためにのみ存在するレガシー機能です。HLO では、タプルを使用して可変長の入力と出力を表します。StableHLO では、可変長の入力と出力がネイティブにサポートされています。StableHLO でタプルを使用するのは、HLO ABI を包括的に表す場合のみです。たとえば、T、tuple<T>、tuple<tuple<T>> は特定の実装によって大きく異なる場合があります。今後、HLO ABI を変更し、StableHLO からタプル型を削除する予定です(#598)。
TensorElementType ::= BooleanType | IntegerType | FloatType | ComplexType
BooleanType ::= 'i1'
IntegerType ::= SignedIntegerType | UnsignedIntegerType
SignedIntegerType ::= 'si2' | 'si4' | 'si8' | 'si16' | 'si32' | 'si64'
UnsignedIntegerType ::= 'ui2' | 'ui4' | 'ui8' | 'ui16' | 'ui32' | 'ui64'
FloatType ::= 'f4E2M1FN' | 'f6E2M3FN' | 'f6E3M2FN' | 'f8E3M4' | 'f8E4M3'
            | 'f8E4M3FN' | 'f8E4M3FNUZ' | 'f8E4M3B11FNUZ' | 'f8E5M2'
            | 'f8E5M2FNUZ' | 'f8E8M0FNU' | 'bf16' | 'f16' | 'f32' | 'f64'
TensorFloat32 ::= 'tf32'
ComplexType ::= 'complex' '<' ComplexElementType '>'
ComplexElementType ::= 'f32' | 'f64'
要素型は、テンソル型の要素を表します。多くのプログラミング言語とは異なり、これらの型は StableHLO ではファーストクラスではありません。つまり、StableHLO プログラムでは、これらの型の値を直接表現できません(その結果、T 型のスカラー値を tensor<T> 型の 0 次元テンソル値で表現するのが一般的です)。
- ブール値型は、ブール値 trueとfalseを表します。
- 整数型は、符号付き(si)または符号なし(ui)のいずれかであり、サポートされているビット幅(2、4、8、16、32、64)のいずれかです。符号付きsiN型は、-2^(N-1)~2^(N-1)-1の整数値を表し、符号なしuiN型は、0~2^N-1の整数値を表します。
- 浮動小数点型は次のいずれかになります。
- f8E3M4、- f8E4M3、- f8E5M2: IEEE-754 規則に従う 8 ビットの浮動小数点数。
- f8E4M3FN型と- f8E5M2型は、ディープラーニング用の FP8 形式で説明されている FP8 形式の- E4M3エンコードと- E5M2エンコードに対応しています。
- f8E4M3FNUZ型と- f8E5M2FNUZ型は、ディープ ニューラル ネットワーク用の 8 ビット数値形式で説明されている FP8 形式の- E4M3エンコードと- E5M2エンコードに対応しています。
- ハイブリッド 8 ビット浮動小数点数(HFP8)トレーニングとディープ ニューラル ネットワークの推論で説明されている FP8 形式の E4M3エンコードに対応するf8E4M3B11FNUZタイプ。
- BFloat16: Cloud TPU で高パフォーマンスを発揮させる秘訣で説明されている bfloat16形式に対応するbf16型。
- f16、- f32、- f64型は、IEEE 754 標準で説明されている- binary16(「半精度」)、- binary32(「単精度」)、- binary64(「倍精度」)の各形式に対応しています。
- tf32型は TensorFloat32 形式に対応しており、StableHLO では限定的にサポートされています。
- OCP マイクロスケーリング形式の仕様で説明されている f4E2M1FN、f6E2M3FN、f6E3M2FN、f8E8M0FNUMX(マイクロスケーリング)タイプ。
 
- 複素型は、同じ要素型の実部と虚部を持つ複素値を表します。サポートされている複合型は complex<f32>(どちらもf32型)とcomplex<f64>(どちらもf64型)です。
FunctionType ::= '(' InputTypes ')' '->' '(' OutputTypes ')'
InputTypes ::= [ValueType {',' ValueType}]
OutputTypes ::= [ValueType {',' ValueType}]
関数型は、名前付き関数と匿名関数の両方を表します。入力型(-> の左側の型のリスト)と出力の型(-> の右側の型のリスト)があります。多くのプログラミング言語では関数型がファースト クラスですが、StableHLO ではそうではありません。
StringType ::= 'string'
String 型はバイトのシーケンスを表します。多くのプログラミング言語とは異なり、StableHLO では文字列型はファーストクラスではなく、プログラム要素の静的メタデータを指定する場合にのみ使用されます。
運用
StableHLO オペレーション(オペレーションとも呼ばれる)は、ML モデルのクローズド セットの大まかなオペレーションを表します。前述のように、StableHLO 構文は MLIR に大きく影響を受けています。これは必ずしも最も人間工学的な代替手段ではありませんが、ML フレームワークと ML コンパイラ間の相互運用性を高めるという StableHLO の目標には最適です。
Op            ::= [OpOutputs] OpName OpInputs ':' OpSignature
OpName        ::= '"' 'stablehlo' '.' OpMnemonic '"'
OpMnemonic    ::= 'abs' | 'add' | ...
StableHLO オペレーション(オペレーションとも呼ばれる)には、名前、入出力、署名があります。この名前は、stablehlo. 接頭辞と、サポートされているいずれかのオペレーションを一意に識別するニーモニックで構成されます。サポートされているすべてのオペレーションの一覧については、以下をご覧ください。
OpInputs        ::= OpInputValues OpInputFuncs OpInputAttrs
OpInputValues   ::= '(' [OpInputValue {',' OpInputValue}] ')'
OpInputValue    ::= ValueId
OpInputFuncs    ::= ['(' OpInputFunc {',' OpInputFunc} ')']
OpInputAttrs    ::= ['{' OpInputAttr {',' OpInputAttr} '}']
OpOutputs       ::= [OpOutput {',' OpOutput} '=']
OpOutput        ::= ValueId
オペレーションは入力を使用し、出力を生成します。入力は、入力値(実行時に計算される)、入力関数(StableHLO では関数がファーストクラス値ではないため、静的に提供される)、入力属性(これも静的に提供される)に分類されます。op によって消費、生成される入力と出力の種類は、そのニモニックによって異なります。たとえば、add オペレーションは 2 つの入力値を使用し、1 つの出力値を生成します。これに対し、select_and_scatter op は 3 つの入力値、2 つの入力関数、3 つの入力属性を使用します。
OpInputFunc ::= '{' Unused FuncInputs ':' FuncBody '}'
Unused      ::= '^' digit {digit}
              | '^' letter {letter | digit}
入力関数(別名: 匿名関数)は、名前付き関数と非常によく似ていますが、1)識別子がない(「匿名」という名前の由来)ことと、2)出力型を宣言しない(出力型は関数内の return オペレーションから推論される)点が異なります。
入力関数の構文には、MLIR との互換性を確保するために現在使用されていない部分(上記の Unused 生産を参照)が含まれています。MLIR には、ジャンプ演算子で接続された複数の演算子の「ブロック」を持つことができる、より一般的な「リージョン」のコンセプトがあります。これらのブロックには、Unused 本番環境に対応する ID が付いているため、ブロック同士を区別できます。StableHLO にはジャンプ演算がないため、MLIR 構文の対応する部分は使用されません(ただし、まだ存在します)。
OpInputAttr      ::= OpInputAttrName '=' OpInputAttrValue
OpInputAttrName  ::= letter {letter | digit}
OpInputAttrValue ::= Constant
入力属性には、名前と値があります。値は、サポートされている定数の一つです。プログラム要素の静的メタデータを指定する主な方法です。たとえば concatenate 演算では、属性 dimension を使用して、入力値を連結するディメンションを指定します。同様に、slice 演算は start_indices や limit_indices などの複数の属性を使用して、入力値をスライスするために使用される境界を指定します。
現時点では、実際の StableHLO プログラムに、このドキュメントで説明されていない属性が含まれていることがあります。今後、これらの属性を StableHLO オペセットに吸収するか、StableHLO プログラムに出現しないようにする予定です。それまでは以下の属性のリストをご覧ください。
- layout(#629)。
- mhlo.frontend_attributes(#628)。
- mhlo.sharding(#619)。
- output_operand_aliases(#740)。
- 位置情報メタデータ(#594)。
OpSignature ::= '(' [ValueType {',' ValueType}] ')' '->' '(' [ValueType {',' ValueType}] ')'
オペレーション シグネチャは、すべての入力値の型(-> の左側の型のリスト)とすべての出力値の型(-> の右側の型のリスト)で構成されます。厳密に言えば、入力型は冗長であり、出力型もほとんどの場合冗長です(ほとんどの StableHLO オペレーションでは、出力型は入力から推測できるため)。ただし、op シグネチャは、MLIR との互換性を確保するために、意図的に StableHLO 構文の一部になっています。
以下に、メモニクスが select_and_scatter の演算子の例を示します。3 つの入力値(%operand、%source、%init_value)、2 つの入力関数、3 つの入力属性(window_dimensions、window_strides、padding)を使用します。op のシグネチャには入力値の型のみが含まれます(インラインで提供される入力関数と属性の型は含まれません)。
%result = "stablehlo.select_and_scatter"(%operand, %source, %init_value) ({
  ^bb0(%arg0: tensor<i32>, %arg1: tensor<i32>):
    %0 = "stablehlo.compare"(%arg0, %arg1) {
      comparison_direction = #stablehlo<comparison_direction GE>
    } : (tensor<i32>, tensor<i32>) -> tensor<i1>
    "stablehlo.return"(%0) : (tensor<i1>) -> ()
}, {
  ^bb0(%arg0: tensor<i32>, %arg1: tensor<i32>):
    %0 = "stablehlo.add"(%arg0, %arg1) : (tensor<i32>, tensor<i32>) -> tensor<i32>
    "stablehlo.return"(%0) : (tensor<i32>) -> ()
}) {
  window_dimensions = dense<[3, 1]> : tensor<2xi64>,
  window_strides = dense<[2, 1]> : tensor<2xi64>,
  padding = dense<[[0, 1], [0, 0]]> : tensor<2x2xi64>
} : (tensor<4x2xi32>, tensor<2x2xi32>, tensor<i32>) -> tensor<4x2xi32>
定数
Constant ::= BooleanConstant
           | IntegerConstant
           | FloatConstant
           | ComplexConstant
           | TensorConstant
           | QuantizedTensorConstant
           | StringConstant
           | EnumConstant
StableHLO 定数には、StableHLO 値を表すリテラルと型があります。一般に、型は定数の構文の一部ですが、明確な場合を除きます(たとえば、ブール値定数は明確に型 i1 ですが、整数定数には複数の型が考えられます)。
BooleanConstant ::= BooleanLiteral
BooleanLiteral  ::= 'true' | 'false'
ブール定数は、ブール値 true と false を表します。ブール値定数は i1 型です。
IntegerConstant   ::= IntegerLiteral ':' IntegerType
IntegerLiteral    ::= ['-' | '+'] DecimalDigits
                    | ['-' | '+'] '0x' HexadecimalDigits
DecimalDigits     ::= decimalDigit {decimalDigit}
HexadecimalDigits ::= hexadecimalDigit {hexadecimalDigit}
decimalDigit      ::= '0' | ... | '9'
hexadecimalDigit  ::= decimalDigit | 'a' | ... | 'f' | 'A' | ... | 'F'
整数定数は、10 進数または 16 進数表記の文字列を使用して整数値を表します。バイナリやオクタルなどの他の基数はサポートされていません。整数定数には次の制約があります。
- (C1)is_wellformed(integer_literal, integer_type)。
FloatConstant  ::= FloatLiteral ':' FloatType
FloatLiteral   ::= SignPart IntegerPart FractionalPart ScientificPart
                 | '0x' [HexadecimalDigits]
SignPart       ::= ['-' | '+']
IntegerPart    ::= DecimalDigits
FractionalPart ::= ['.' [DecimalDigits]]
ScientificPart ::= [('e' | 'E') ['-' | '+'] DecimalDigits]
浮動小数点定数は、10 進数または科学的記数法を使用する文字列で浮動小数点値を表します。また、16 進数表記を使用して、対応する型の浮動小数点形式で基盤となるビットを直接指定することもできます。浮動小数点定数には次の制約があります。
- (C1)16 進数以外の表記が使用されている場合、is_wellformed(float_literal, float_type)。
- (C2)16 進数表記を使用している場合、size(hexadecimal_digits) = num_bits(float_type) / 4。
ComplexConstant ::= ComplexLiteral ':' ComplexType
ComplexLiteral  ::= '(' RealPart ',' ImaginaryPart ')'
RealPart        ::= FloatLiteral
ImaginaryPart   ::= FloatLiteral
複素定数は、実部(先頭)と虚部(末尾)のリストを使用して複素値を表します。たとえば、(1.0, 0.0) : complex<f32> は 1.0 + 0.0i を表し、(0.0, 1.0) : complex<f32> は 0.0 + 1.0i を表します。これらのパーツをメモリに保存する順序は実装で定義します。複素定数には次の制約があります。
- (C1)is_wellformed(real_part, complex_element_type(complex_type))。
- (C2)is_wellformed(imaginary_part, complex_element_type(complex_type))。
TensorConstant ::= TensorLiteral ':' TensorType
TensorLiteral  ::= 'dense' '<' (DenseLiteral | ElementLiteral) '>'
DenseLiteral   ::= DenseDimension | DenseElements
DenseDimension ::= '[' [DenseLiteral {',' DenseLiteral}] ']'
DenseElements  ::= [ElementLiteral {',' ElementLiteral}]
ElementLiteral ::= BooleanLiteral | IntegerLiteral | FloatLiteral | ComplexLiteral
テンソル定数は、NumPy の表記法で指定されるネストされたリストを使用してテンソル値を表します。たとえば、dense<[[1, 2, 3], [4, 5, 6]]> : tensor<2x3xi32> は、インデックスから要素への次のマッピングを持つテンソル値を表します。{0, 0} => 1、{0, 1} => 2、{0, 2} => 3、{1, 0} => 4、{1, 1} => 5、{1, 2} => 6。これらの要素がメモリに格納される順序は実装によって定義されます。テンソル定数には次の制約があります。
- (C1)has_syntax(tensor_literal, element_type(tensor_type))。ここで、- has_syntax(element_literal: Syntax, element_type: Type) = is_wellformed(element_literal, type)。
- has_syntax(tensor_literal: List, element_type: Type) = has_syntax(tensor_literal..., element_type)。
 
- (C2)has_shape(tensor_literal, shape(tensor_type))。次に例を示します。- has_shape(element_literal: Syntax, []) = true。
- has_shape(tensor_literal: List, shape: List) = size(tensor_literal) = shape[0] and has_shape(tensor_literal..., shape[1:])。
- それ以外の場合は false。
 
QuantizedTensorConstant ::= QuantizedTensorLiteral ':' QuantizedTensorType
QuantizedTensorLiteral  ::= 'dense' '<' (DenseLiteral | ElementLiteral) '>'
量子化テンソル定数は、テンソル定数と同じ表記を使用して量子化テンソル値を表します。要素はストレージ タイプの定数として指定されます。量子化されたテンソル定数には次の制約があります。
- (C1)has_syntax(quantized_tensor_literal, storage_type(quantized_tensor_type))。
- (C2)has_shape(quantized_tensor_literal, shape(quantized_tensor_type))。
StringConstant  ::= StringLiteral
StringLiteral   ::= '"' {stringCharacter | escapeSequence} '"'
stringCharacter ::= all ASCII characters except '\00', '\01', ... '\1f' and '"'
escapeSequence  ::= '\' ('"' | '\' | 'n' | 't' | (hexadecimalDigit hexadecimalDigit))
文字列リテラルは、ASCII 文字とエスケープ シーケンスを使用して指定されたバイトで構成されます。エンコードに依存しないため、これらのバイトの解釈は実装で定義されます。文字列リテラルの型は string です。
運用
abs
セマンティクス
operand テンソルに対して要素ごとの絶対値演算を実行し、result テンソルを生成します。要素のタイプに応じて、次の処理を行います。
- 符号付き整数の場合: 整数の剰余。
- 浮動小数点数: IEEE-754 の abs。
- 複素数の場合: 複素数モジュラス。
- 量子化された型の場合: dequantize_op_quantize(abs, operand, type(result))。
入力
| ラベル | 名前 | タイプ | 制約 | 
|---|---|---|---|
| (I1) | operand | 符号付き整数型、浮動小数点型、複合型のテンソル、またはテンソルごとの量子化テンソル | (C1 ~ C2) | 
出力
| 名前 | タイプ | 制約 | 
|---|---|---|
| result | 符号付き整数型または浮動小数点型のテンソル、またはテンソルごとの量子化テンソル | (C1-C2) | 
制約
- (C1)shape(result) = shape(operand)。
- (C2)baseline_element_type(result)は次のように定義されます。- is_complex(operand)の場合、- complex_element_type(element_type(operand))。
- そうでない場合は baseline_element_type(operand)。
 
例
// %operand: [-2, 0, 2]
%result = "stablehlo.abs"(%operand) : (tensor<3xi32>) -> tensor<3xi32>
// %result: [2, 0, 2]
追加
セマンティクス
2 つのテンソル lhs と rhs の要素ごとの加算を行い、result テンソルを生成します。要素のタイプに応じて、次の処理を行います。
- ブール値の場合: 論理 OR。
- 整数の場合: 整数の加算。
- 浮動小数点数の場合: IEEE-754 の addition。
- 複素数の場合: 複素数加算。
- 量子化された型の場合: dequantize_op_quantize(add, lhs, rhs, type(result))。
入力
| ラベル | 名前 | タイプ | 制約 | 
|---|---|---|---|
| (I1) | lhs | テンソルまたは量子化テンソル | (C1~C6) | 
| (I2) | rhs | テンソルまたは量子化テンソル | (C1 ~ C5)、(C7) | 
出力
| 名前 | タイプ | 制約 | 
|---|---|---|
| result | テンソルまたは量子化テンソル | (C1~C7) | 
制約
- オペレーションで量子化されていないテンソルを使用する場合:
- (C1)type(lhs) = type(rhs) = type(result)。
 
- (C1)
- オペレーションで量子化テンソルを使用する場合:
- (C2)is_quantized(lhs) and is_quantized(rhs) and is_quantized(result)。
- (C3)storage_type(lhs) = storage_type(rhs) = storage_type(result)。
- (C4)expressed_type(lhs) = expressed_type(rhs) = expressed_type(result)。
- (C5)(is_per_axis_quantized(lhs) or is_per_axis_quantized(rhs)) = is_per_axis_quantized(result)。
- (C6)is_per_axis_quantized(lhs)の場合はquantization_dimension(lhs) = quantization_dimension(result)です。
- (C7)is_per_axis_quantized(rhs)の場合はquantization_dimension(rhs) = quantization_dimension(result)です。
 
- (C2)
例
// %lhs: [[1, 2], [3, 4]]
// %rhs: [[5, 6], [7, 8]]
%result = "stablehlo.add"(%lhs, %rhs) : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
// %result: [[6, 8], [10, 12]]
after_all
セマンティクス
inputs を生成するオペレーションが、result に依存するオペレーションの前に実行されるようにします。このオペレーションを実行しても何も行われません。このオペレーションは、result から inputs へのデータ依存関係を確立するためにのみ存在します。
入力
| ラベル | 名前 | タイプ | 
|---|---|---|
| (I1) | inputs | 可変長の tokenの数 | 
出力
| 名前 | タイプ | 
|---|---|
| result | token | 
例
// %input0: !stablehlo.token
// %input1: !stablehlo.token
%result = "stablehlo.after_all"(%input0, %input1) : (!stablehlo.token, !stablehlo.token) -> !stablehlo.token
all_gather
セマンティクス
StableHLO プロセス グリッドの各プロセス グループ内で、各プロセスの operands テンソルの値を all_gather_dim に沿って連結し、results テンソルを生成します。
このオペレーションでは、StableHLO プロセス グリッドが、次のように定義された process_groups に分割されます。
- channel_id <= 0 and use_global_device_ids = falseの場合、- cross_replica(replica_groups)
- channel_id > 0 and use_global_device_ids = falseの場合、- cross_replica_and_partition(replica_groups)
- channel_id > 0 and use_global_device_ids = trueの場合、- flattened_ids(replica_groups)
その後、各 process_group 内で次の操作を行います。
- process_groupのすべての- receiverに対して- operands...@receiver = [operand@sender for sender in process_group]。
- process_groupのすべての- processに対して- results...@process = concatenate(operands...@process, all_gather_dim)。
入力
| ラベル | 名前 | タイプ | 制約 | 
|---|---|---|---|
| (I1) | operands | テンソルの可変数またはテンソルごとの量子化テンソル | (C1)、(C6) | 
| (I2) | all_gather_dim | si64型の定数 | (C1)、(C6) | 
| (I3) | replica_groups | si64型の 2 次元のテンソル定数 | (C2 ~ C4) | 
| (I4) | channel_id | si64型の定数 | (C5) | 
| (I5) | use_global_device_ids | i1型の定数 | (C5) | 
出力
| 名前 | タイプ | 制約 | 
|---|---|---|
| results | テンソルの可変数またはテンソルごとの量子化テンソル | (C6) | 
制約
- (C1)0 <= all_gather_dim < rank(operands...)。
- (C2)is_unique(replica_groups)。
- (C3)size(replica_groups)は次のように定義されます。- cross_replicaを使用する場合は- num_replicas。
- cross_replica_and_partitionを使用する場合は- num_replicas。
- flattened_idsを使用する場合は- num_processes。
 
- (C4)0 <= replica_groups < size(replica_groups)。
- (C5)use_global_device_ids = trueの場合、channel_id > 0。
- (C6)type(results...) = type(operands...)(ただし、次の場合は除く)- dim(results..., all_gather_dim) = dim(operands..., all_gather_dim) * dim(process_groups, 1)。
 
例
// num_replicas: 2
// num_partitions: 1
// %operand0@(0, 0): [[1, 2], [3, 4]]
// %operand0@(1, 0): [[5, 6], [7, 8]]
// %operand1@(0, 0): [[11, 12], [13, 14]]
// %operand1@(1, 0): [[15, 16], [17, 18]]
%result:2 = "stablehlo.all_gather"(%operand0, %operand1) {
  all_gather_dim = 1 : i64,
  replica_groups = dense<[[0, 1]]> : tensor<1x2xi64>,
  // channel_id = 0
  channel_handle = #stablehlo.channel_handle<handle = 0, type = 0>
  // use_global_device_ids = false
} : (tensor<2x2xi64>, tensor<2x2xi64>) -> (tensor<2x4xi64>, tensor<2x4xi64>)
// %result0@(0, 0): [[1, 2, 5, 6], [3, 4, 7, 8]]
// %result0@(1, 0): [[1, 2, 5, 6], [3, 4, 7, 8]]
// %result1@(0, 0): [[11, 12, 15, 16], [13, 14, 17, 18]]
// %result1@(1, 0): [[11, 12, 15, 16], [13, 14, 17, 18]]
all_reduce
セマンティクス
StableHLO プロセス グリッドの各プロセス グループ内で、各プロセスの operands テンソルの値にリダクション関数 computation を適用し、results テンソルを生成します。
このオペレーションは、StableHLO プロセス グリッドを次のように定義された process_groups に分割します。
- channel_id <= 0 and use_global_device_ids = falseの場合、- cross_replica(replica_groups)
- channel_id > 0 and use_global_device_ids = falseの場合、- cross_replica_and_partition(replica_groups)
- channel_id > 0 and use_global_device_ids = trueの場合、- flattened_ids(replica_groups)
その後、各 process_group 内で次の操作を行います。
- results...@process[result_index] = exec(schedule)は任意のバイナリ ツリーです。- scheduleは次のとおりです。- exec(node)=- computation(exec(node.left), exec(node.right))。
- exec(leaf)=- leaf.value。
 
- scheduleは、順序付き走査が- to_destination_type(operands...@process_group...[result_index], type(func_inputs(computation)[0]))である実装定義のバイナリツリーです。
入力
| ラベル | 名前 | タイプ | 制約 | 
|---|---|---|---|
| (I1) | operands | テンソルの可変数またはテンソルごとの量子化テンソル | (C5)、(C6) | 
| (I2) | replica_groups | si64型の 1 次元テンソル定数の可変数 | (C1 ~ C3) | 
| (I3) | channel_id | si64型の定数 | (C4) | 
| (I4) | use_global_device_ids | i1型の定数 | (C4) | 
| (I5) | computation | 関数 | (C5) | 
出力
| 名前 | タイプ | 制約 | 
|---|---|---|
| results | テンソルの可変数またはテンソルごとの量子化テンソル | (C6~C7) | 
制約
- (C1)is_unique(replica_groups)。
- (C2)size(replica_groups)は次のように定義されます。- cross_replicaを使用する場合は- num_replicas。
- cross_replica_and_partitionを使用する場合は- num_replicas。
- flattened_idsを使用する場合は- num_processes。
 
- (C3)0 <= replica_groups < size(replica_groups)。
- (C4)use_global_device_ids = trueの場合、channel_id > 0。
- (C5)computationの型は(tensor<E>, tensor<E>) -> (tensor<E>)で、is_promotable(element_type(operand), E)です。
- (C6)shape(results...) = shape(operands...)。
- (C7)element_type(results...) = E。
例
// num_replicas: 2
// num_partitions: 1
// %operand0@(0, 0): [1, 2, 3, 4]
// %operand0@(1, 0): [5, 6, 7, 8]
// %operand1@(0, 0): [9, 10, 11, 12]
// %operand1@(1, 0): [13, 14, 15, 16]
%result:2 = "stablehlo.all_reduce"(%operand0, %operand0) ({
  ^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
    %0 = "stablehlo.add"(%arg0, %arg1) : (tensor<i64>, tensor<i64>) -> tensor<i64>
    "stablehlo.return"(%0) : (tensor<i64>) -> ()
}) {
  replica_groups = dense<[[0, 1]]> : tensor<1x2xi64>,
  // channel_id = 0
  channel_handle = #stablehlo.channel_handle<handle = 0, type = 0>
  // use_global_device_ids = false
} : (tensor<4xi64>, tensor<4xi64>) -> (tensor<4xi64>, tensor<4xi64>)
// %result0@(0, 0): [6, 8, 10, 12]
// %result0@(1, 0): [6, 8, 10, 12]
// %result1@(0, 0): [22, 24, 26, 28]
// %result1@(1, 0): [22, 24, 26, 28]
all_to_all
セマンティクス
StableHLO プロセス グリッド内の各プロセス グループ内で、split_dimension に沿って operands テンソルの値を分割し、分割された部分をプロセス間で分散し、分散された部分を concat_dimension に沿って連結して results テンソルを生成します。このオペレーションでは、StableHLO プロセス グリッドが、次のように定義された process_groups に分割されます。
- channel_id <= 0の場合、- cross_replica(replica_groups)。
- channel_id > 0の場合は- cross_partition(replica_groups)。
その後、各 process_group 内で次の操作を行います。
- process_group内のすべての- senderに対して- split_parts...@sender = split(operands...@sender, split_count, split_dimension)。
- scattered_parts...@receiver = [split_parts...@sender[receiver_index] for sender in process_group]は- receiver_index = process_group.index(receiver)です。
- results...@process = concatenate(scattered_parts...@process, concat_dimension)。
入力
| ラベル | 名前 | タイプ | 制約 | 
|---|---|---|---|
| (I1) | operands | テンソルの可変数またはテンソルごとの量子化テンソル | (C1 ~ C3)、(C9) | 
| (I2) | split_dimension | si64型の定数 | (C1)、(C2)、(C9) | 
| (I3) | concat_dimension | si64型の定数 | (C3)、(C9) | 
| (I4) | split_count | si64型の定数 | (C2)、(C4)、(C8)、(C9) | 
| (I5) | replica_groups | si64型の 2 次元テンソル定数 | (C5~C8) | 
| (I6) | channel_id | si64型の定数 | 
出力
| 名前 | タイプ | 制約 | 
|---|---|---|
| results | テンソルの可変数またはテンソルごとの量子化テンソル | (C9) | 
制約
- (C1)0 <= split_dimension < rank(operands...)。
- (C2)dim(operands..., split_dimension) % split_count = 0。
- (C3)0 <= concat_dimension < rank(operands...)。
- (C4)0 < split_count。
- (C5)is_unique(replica_groups)。
- (C6)size(replica_groups)は次のように定義されます。- cross_replicaを使用する場合は- num_replicas。
- cross_partitionを使用する場合は- num_partitions。
 
- (C7)0 <= replica_groups < size(replica_groups)。
- (C8)dim(replica_groups, 1) = split_count。
- (C9)type(results...) = type(operands...)(ただし、split_dimension != concat_dimensionの場合を除く):- dim(results..., split_dimension) = dim(operands..., split_dimension) / split_count。
- dim(results..., concat_dimension) = dim(operands..., concat_dimension) * split_count。
 
例
// num_replicas: 2
// num_partitions: 1
// %operand1@(0, 0): [[1, 2, 3, 4],
//                    [5, 6, 7, 8]]
// %operand1@(1, 0): [[9, 10, 11, 12],
//                    [13, 14, 15, 16]]
// %operand2@(0, 0): [[17, 18, 19, 20],
//                    [21, 22, 23, 24]]
// %operand2@(1, 0): [[25, 26, 27, 28],
//                    [29, 30, 31, 32]]
%result:2 = "stablehlo.all_to_all"(%operand1, %operand2) {
  split_dimension = 1 : i64,
  concat_dimension = 0 : i64,
  split_count = 2 : i64,
  replica_groups = dense<[[0, 1]]> : tensor<1x2xi64>
  // channel_id = 0
} : (tensor<2x4xi64>, tensor<2x4xi64>) -> (tensor<4x2xi64>, tensor<4x2xi64>)
// %result#0@(0, 0): [[1, 2], [5, 6], [9, 10], [13, 14]]
// %result#0@(1, 0): [[3, 4], [7, 8], [11, 12], [15, 16]]
// %result#1@(0, 0): [[17, 18], [21, 22], [25, 26], [29, 30]]
// %result#1@(1, 0): [[19, 20], [23, 24], [27, 28], [31, 32]]
と
セマンティクス
2 つのテンソル lhs と rhs の要素ごとの AND を実行し、result テンソルを生成します。要素のタイプに応じて、次の処理を行います。
- ブール値の場合: 論理 AND。
- 整数の場合: ビット演算 AND。
入力
| ラベル | 名前 | タイプ | 制約 | 
|---|---|---|---|
| (I1) | lhs | ブール型または整数型のテンソル | (C1) | 
| (I2) | rhs | ブール値型または整数型のテンソル | (C1) | 
出力
| 名前 | タイプ | 制約 | 
|---|---|---|
| result | ブール型または整数型のテンソル | (C1) | 
制約
- (C1)type(lhs) = type(rhs) = type(result)。
例
// %lhs: [[1, 2], [3, 4]]
// %rhs: [[5, 6], [7, 8]]
%result = "stablehlo.and"(%lhs, %rhs) : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
// %result: [[1, 2], [3, 0]]
atan2
セマンティクス
lhs テンソルと rhs テンソルの要素ごとの atan2 演算を実行し、result テンソルを生成します。要素のタイプに応じて、次の処理を行います。
- 浮動小数点数の場合: IEEE-754 の atan2。
- 複素数の場合: 複素 atan2。
- 量子化された型の場合: dequantize_op_quantize(atan2, lhs, rhs, type(result))。
入力
| ラベル | 名前 | タイプ | 制約 | 
|---|---|---|---|
| (I1) | lhs | 浮動小数点型または複合型のテンソル、またはテンソルごとの量子化テンソル | (C1) | 
| (I2) | rhs | 浮動小数点型または複合型のテンソル、またはテンソルごとの量子化テンソル | (C1) | 
出力
| 名前 | タイプ | 制約 | 
|---|---|---|
| result | 浮動小数点型または複合型のテンソル、またはテンソルごとの量子化テンソル | (C1) | 
制約
- (C1)baseline_type(lhs) = baseline_type(rhs) = baseline_type(result)。
例
// %lhs: [0.0, 1.0, -1.0]
// %rhs: [0.0, 0.0, 0.0]
%result = "stablehlo.atan2"(%lhs, %rhs) : (tensor<3xf64>, tensor<3xf64>) -> tensor<3xf64>
// %result: [0.0, 1.57079637, -1.57079637] // [0.0, pi/2, -pi/2]
batch_norm_grad
セマンティクス
grad_output からバックプロパゲートされる batch_norm_training の複数の入力の勾配を計算し、grad_operand、grad_scale、grad_offset テンソルを生成します。より正式には、このオペレーションは、次のように Python 構文を使用して既存の StableHLO オペレーションへの分解として表現できます。
def compute_sum(operand, feature_index):
  (sum,) = reduce(
      inputs=[operand],
      init_values=[constant(0, element_type(operand))],
      dimensions=[i for i in range(rank(operand)) if i != feature_index],
      body=lambda x, y: add(x, y))
  return sum
def compute_mean(operand, feature_index):
  sum = compute_sum(operand, feature_index)
  divisor = constant(size(operand) / dim(operand, feature_index),
                     element_type(operand))
  divisor_bcast = broadcast_in_dim(divisor, [], type(sum))
  return divide(sum, divisor_bcast)
def batch_norm_grad(operand, scale, mean, variance, grad_output, epsilon, feature_index):
  # Broadcast inputs to type(operand)
  scale_bcast = broadcast_in_dim(scale, [feature_index], type(operand))
  mean_bcast = broadcast_in_dim(mean, [feature_index], type(operand))
  variance_bcast = broadcast_in_dim(variance, [feature_index], type(operand))
  epsilon_bcast = broadcast_in_dim(constant(epsilon, element_type(operand)), [],
                                   type(operand))
  # Perform normalization using the provided `mean` and `variance`
  # Intermediate values will be useful for computing gradients
  centered_operand = subtract(operand, mean_bcast)
  stddev = sqrt(add(variance_bcast, epsilon_bcast))
  normalized_operand = divide(centered_operand, stddev)
  # Use the implementation from batchnorm_expander.cc in XLA
  # Temporary variables have exactly the same names as in the C++ code
  elements_per_feature = broadcast_in_dim(
      constant(divide(size(operand), dim(operand, feature_index)),
               element_type(grad_output)),
      [], type(operand))
  i1 = multiply(grad_output, elements_per_feature)
  i2 = broadcast_in_dim(
      compute_sum(grad_output, feature_index), [feature_index], type(operand))
  i3 = broadcast_in_dim(
      compute_sum(multiply(grad_output, centered_operand), feature_index),
      [feature_index], type(operand))
  i4 = multiply(i3, centered_operand)
  i5 = divide(i4, add(variance_bcast, epsilon_bcast))
  i6 = subtract(subtract(i1, i2), i5)
  grad_operand =
      multiply(divide(divide(scale_bcast, stddev), elements_per_feature), i6)
  grad_scale =
      compute_sum(multiply(grad_output, normalized_operand), feature_index)
  grad_offset = compute_sum(grad_output, feature_index)
  return grad_operand, grad_scale, grad_offset
量子化された型の場合は、dequantize_batch_norm_grad_or_training_quantize(lambda operand, scale, mean,
variance, grad_output: batch_norm_grad(operand, scale, mean, variance,
grad_output, epsilon, feature_index), operand, scale, mean, variance,
grad_output, type(grad_operand), type(grad_scale), type(feature_index)) を実行します。
入力
| ラベル | 名前 | タイプ | 制約 | 
|---|---|---|---|
| (I1) | operand | 浮動小数点型のテンソルまたはテンソルごとの量子化テンソル | (C1~C3)、(C5) | 
| (I2) | scale | 浮動小数点型またはテンソルごとの量子化型の 1 次元テンソル | (C2)、(C4)、(C5) | 
| (I3) | mean | 浮動小数点型またはテンソルごとの量子化型の 1 次元テンソル | (C2)、(C4) | 
| (I4) | variance | 浮動小数点型またはテンソルごとの量子化型の 1 次元テンソル | (C2)、(C4) | 
| (I5) | grad_output | 浮動小数点型のテンソルまたはテンソルごとの量子化テンソル | (C2)、(C3) | 
| (I6) | epsilon | f32型の定数 | |
| (I7) | feature_index | si64型の定数 | (C1)、(C5) | 
出力
| 名前 | タイプ | 制約 | 
|---|---|---|
| grad_operand | 浮動小数点型のテンソルまたはテンソルごとの量子化テンソル | (C2)、(C3) | 
| grad_scale | 浮動小数点数またはテンソルごとの量子化型の 1 次元テンソル | (C2)、(C4) | 
| grad_offset | 浮動小数点型またはテンソルごとの量子化型の 1 次元テンソル | (C2)、(C4) | 
制約
- (C1)0 <= feature_index < rank(operand)。
- (C2)operand、scale、mean、variance、grad_output、grad_operand、grad_scale、grad_offsetのbaseline_element_typeが同じである。
- (C3)operand、grad_output、grad_operandは同じ形状です。
- (C4)scale、mean、variance、grad_scale、grad_offsetの形状が同じである。
- (C5)size(scale) = dim(operand, feature_index)。
例
// %operand: [
//            [[1.0, 2.0], [3.0, 4.0]],
//            [[3.0, 4.0], [1.0, 2.0]]
//           ]
// %scale: [1.0, 1.0]
// %mean: [2.0, 3.0]
// %variance: [1.0, 1.0]
// %grad_output: [
//                [[0.1, 0.1], [0.1, 0.1]],
//                [[0.1, 0.1], [0.1, 0.1]]
//               ]
%grad_operand, %grad_scale, %grad_offset =
"stablehlo.batch_norm_grad"(%operand, %scale, %mean, %variance, %grad_output) {
  epsilon = 0.0 : f32,
  feature_index = 2 : i64
} : (tensor<2x2x2xf64>, tensor<2xf64>, tensor<2xf64>, tensor<2xf64>,
     tensor<2x2x2xf64>) -> (tensor<2x2x2xf64>, tensor<2xf64>, tensor<2xf64>)
// %grad_operand: [
//                 [[0.0, 0.0], [0.0, 0.0]],
//                 [[0.0, 0.0], [0.0, 0.0]]
//                ]
// %grad_scale:  [0.0, 0.0]
// %grad_offset: [0.4, 0.4]
batch_norm_inference
セマンティクス
feature_index 次元を除くすべての次元で operand テンソルを正規化し、result テンソルを生成します。より正式には、このオペレーションは、Python 構文を使用して既存の StableHLO オペレーションへの分解として次のように表現できます。
def batch_norm_inference(operand, scale, offset, mean, variance, epsilon, feature_index):
  # Broadcast inputs to shape(operand)
  scale_bcast = broadcast_in_dim(scale, [feature_index], type(operand))
  offset_bcast = broadcast_in_dim(offset, [feature_index], type(operand))
  mean_bcast = broadcast_in_dim(mean, [feature_index], type(operand))
  variance_bcast = broadcast_in_dim(variance, [feature_index], type(operand))
  epsilon_bcast = broadcast_in_dim(constant(epsilon, element_type(operand)), [],
                                   type(operand))
  # Perform normalization using the provided `mean` and `variance` instead of
  # computing them like `batch_norm_training` does.
  centered_operand = subtract(operand, mean_bcast)
  stddev = sqrt(add(variance_bcast, epsilon_bcast))
  normalized_operand = divide(centered_operand, stddev)
  return add(multiply(scale_bcast, normalized_operand), offset_bcast)
量子化された型の場合は、dequantize_op_quantize(lambda operand, scale, offset, mean, variance:
batch_norm_inference(operand, scale, offset, mean, variance, epsilon,
feature_index), operand, scale, offset, mean, variance, type(result)) を実行します。
入力
| ラベル | 名前 | タイプ | 制約 | 
|---|---|---|---|
| (I1) | operand | 浮動小数点型のテンソルまたはテンソルごとの量子化テンソル | (C1~C7) | 
| (I2) | scale | 浮動小数点数またはテンソルごとの量子化型の 1 次元テンソル | (C2)、(C3) | 
| (I3) | offset | 浮動小数点型またはテンソルごとの量子化型の 1 次元テンソル | (C2)、(C4) | 
| (I4) | mean | 浮動小数点数またはテンソルごとの量子化型の 1 次元テンソル | (C5) | 
| (I5) | variance | 浮動小数点数またはテンソルごとの量子化型の 1 次元テンソル | (C2)、(C6) | 
| (I6) | epsilon | f32型の定数 | |
| (I7) | feature_index | si64型の定数 | (C1)、(C3 ~ C6) | 
出力
| 名前 | タイプ | 制約 | 
|---|---|---|
| result | 浮動小数点型のテンソルまたはテンソルごとの量子化テンソル | (C2)、(C7) | 
制約
- (C1)0 <= feature_index < rank(operand)。
- (C2)operand、scale、offset、mean、variance、resultには同じbaseline_element_typeがあります。
- (C3)size(scale) = dim(operand, feature_index)。
- (C4)size(offset) = dim(operand, feature_index)。
- (C5)size(mean) = dim(operand, feature_index)。
- (C6)size(variance) = dim(operand, feature_index)。
- (C7)baseline_type(operand) = baseline_type(result)。
例
// %operand: [
//            [[1.0, 2.0], [3.0, 4.0]],
//            [[3.0, 4.0], [1.0, 2.0]]
//           ]
// %scale: [1.0, 1.0]
// %offset: [1.0, 1.0]
// %mean: [2.0, 3.0]
// %variance: [1.0, 1.0]
%result = "stablehlo.batch_norm_inference"(%operand, %scale, %offset, %mean, %variance) {
  epsilon = 0.0 : f32,
  feature_index = 2 : i64
} : (tensor<2x2x2xf64>, tensor<2xf64>, tensor<2xf64>, tensor<2xf64>, tensor<2xf64>) -> tensor<2x2x2xf64>
// %result: [
//           [[0.0, 0.0], [2.0, 2.0]],
//           [[2.0, 2.0], [0.0, 0.0]]
//          ]
batch_norm_training
セマンティクス
feature_index ディメンション以外のすべてのディメンションで平均と分散を計算し、operand テンソルを正規化して output、batch_mean、batch_var テンソルを生成します。より正式には、このオペレーションは、Python 構文を使用して既存の StableHLO オペレーションへの分解として次のように表現できます。
def compute_mean(operand, feature_index):
  (sum,) = reduce(
      inputs=[operand],
      init_values=[constant(0, element_type(operand))],
      dimensions=[i for i in range(rank(operand)) if i != feature_index],
      body=lambda x, y: add(x, y))
  divisor = constant(size(operand) / dim(operand, feature_index),
                     element_type(operand))
  divisor_bcast = broadcast_in_dim(divisor, [], type(sum))
  return divide(sum, divisor_bcast)
def compute_variance(operand, feature_index):
  mean = compute_mean(operand, feature_index)
  mean_bcast = broadcast_in_dim(mean, [feature_index], type(operand))
  centered_operand = subtract(operand, mean_bcast)
  return compute_mean(mul(centered_operand, centered_operand), feature_index)
def batch_norm_training(operand, scale, offset, epsilon, feature_index):
  mean = compute_mean(operand, feature_index)
  variance = compute_variance(operand, feature_index)
  return batch_norm_inference(operand, scale, offset, mean, variance, epsilon,
                              feature_index),
         mean, variance
量子化された型の場合は、dequantize_batch_norm_grad_or_training_quantize(lambda operand, scale, offset:
batch_norm_training(operand, scale, offset, epsilon, feature_index), operand,
scale, offset, type(output), type(batch_mean), type(batch_var)) を実行します。
入力
| ラベル | 名前 | タイプ | 制約 | 
|---|---|---|---|
| (I1) | operand | 浮動小数点型のテンソルまたはテンソルごとの量子化テンソル | (C1) | 
| (I2) | scale | 浮動小数点またはテンソルごとの量子化された 1 次元テンソル | (C2)、(C3) | 
| (I3) | offset | 浮動小数点またはテンソルごとの量子化された 1 次元テンソル | (C2)、(C4) | 
| (I4) | epsilon | f32型の定数 | (C1)、(C3 ~ C6) | 
| (I5) | feature_index | si64型の定数 | (C1)、(C3 ~ C6) | 
出力
| 名前 | タイプ | 制約 | 
|---|---|---|
| output | 浮動小数点型のテンソルまたはテンソルごとの量子化テンソル | (C7) | 
| batch_mean | 浮動小数点数またはテンソルごとに量子化された 1 次元テンソル | (C2)、(C5) | 
| batch_var | 浮動小数点数またはテンソルごとに量子化された 1 次元テンソル | (C2)、(C6) | 
制約
- (C1)0 <= feature_index < rank(operand)。
- (C2)operand、scale、offset、batch_mean、batch_var、outputは同じbaseline_element_typeを持ちます。
- (C3)size(scale) = dim(operand, feature_index)。
- (C4)size(offset) = dim(operand, feature_index)。
- (C5)size(batch_mean) = dim(operand, feature_index)。
- (C6)size(batch_var) = dim(operand, feature_index)。
- (C7)baseline_type(output) = baseline_type(operand)。
例
// %operand: [
//            [[1.0, 2.0], [3.0, 4.0]],
//            [[3.0, 4.0], [1.0, 2.0]]
//           ]
// %scale: [1.0, 1.0]
// %offset: [1.0, 1.0]
%output, %batch_mean, %batch_var = "stablehlo.batch_norm_training"(%operand, %scale, %offset) {
  epsilon = 0.0 : f32,
  feature_index = 2 : i64
} : (tensor<2x2x2xf64>, tensor<2xf64>, tensor<2xf64>) ->
    (tensor<2x2x2xf64>, tensor<2xf64>, tensor<2xf64>)
// %output: [
//           [[0.0, 0.0], [2.0, 2.0]],
//           [[2.0, 2.0], [0.0, 0.0]]
//          ]
// %batch_mean: [2.0, 3.0]
// %batch_var: [1.0, 1.0]
bitcast_convert
セマンティクス
operand テンソルにビットキャスト演算を実行し、result テンソルを生成します。ここで、operand テンソル全体のビットは result テンソルの型を使用して再解釈されます。
より正式には、E = element_type(operand)、E' = element_type(result)、R = rank(operand) を指定すると、次のようになります。
- num_bits(E') < num_bits(E)の場合は- bits(result[i0, ..., iR-1, :]) = bits(operand[i0, ..., iR-1])。
- num_bits(E') > num_bits(E)の場合、- bits(result[i0, ..., iR-2]) = bits(operand[i0, ..., iR-2, :])。
- num_bits(E') = num_bits(E)の場合は- bits(result[i0, ..., iR-1]) = bits(operand[i0, ..., iR-1])。
bits は、指定された値のメモリ内表現を返します。テンソルの正確な表現は実装定義であり、要素型の正確な表現も実装定義であるため、その動作は実装定義です。
入力
| ラベル | 名前 | タイプ | 制約 | 
|---|---|---|---|
| (I1) | operand | テンソルまたは量子化テンソル | (C1 ~ C2) | 
出力
| 名前 | タイプ | 制約 | 
|---|---|---|
| result | テンソルまたは量子化テンソル | (C1 ~ C2) | 
制約
- (C1)E = is_quantized(operand) ? storage_type(operand) : element_type(operand)、E' = is_quantized(result) ? storage_type(result) : element_type(result)、R = rank(operand)が指定されている場合:- num_bits(E') = num_bits(E)の場合は- shape(result) = shape(operand)。
- num_bits(E') < num_bits(E)の場合:
- rank(result) = R + 1。
- dim(result, i) = dim(operand, i)はすべての- 0 <= i < Rに適用されます。
- dim(result, R) * num_bits(E') = num_bits(E)。
- num_bits(E') > num_bits(E)の場合:
- rank(result) = R - 1。
- dim(result, i) = dim(operand, i)はすべての- 0 <= i < Rに適用されます。
- dim(operand, R - 1) * num_bits(E) = num_bits(E')。
 
- (C2)is_complex(operand) or is_complex(result)の場合はis_complex(operand) and is_complex(result)です。
例
// %operand: 0x0123456789ABCDEF
%result = "stablehlo.bitcast_convert"(%operand) : (tensor<f64>) -> tensor<4xf16>
// %result: [0xCDEF, 0x89AB, 0x4567, 0x0123] // little-endian representation
broadcast_in_dim
セマンティクス
operand テンソルにデータを複製することで入力テンソルの次元やランクを展開し、result テンソルを生成します。より正式には、result[result_index] = operand[operand_index] で、axes(operand) 内のすべての d について次のようにします。
- dim(operand, d) = 1の場合、- operand_index[d] = 0。
- それ以外の場合は operand_index[d] = result_index[broadcast_dimensions[d]]。
入力
| ラベル | 名前 | タイプ | 制約 | 
|---|---|---|---|
| (I1) | operand | テンソルまたは量子化テンソル | (C1-C2)、(C5-C6) | 
| (I2) | broadcast_dimensions | si64型の 1 次元テンソル定数 | (C2~C6) | 
出力
| 名前 | タイプ | 制約 | 
|---|---|---|
| result | テンソルまたは量子化テンソル | (C1)、(C3)、(C5~C6) | 
制約
- (C1)element_type(result)は次のように定義されます。- element_type(operand)(- !is_per_axis_quantized(operand)の場合)。
- element_type(operand)(ただし、- quantization_dimension(operand)、- scales(operand)、- zero_points(operand)は、それぞれ- quantization_dimension(result)、- scales(result)、- zero_points(result)と異なる場合があります)。
 
- (C2)size(broadcast_dimensions) = rank(operand)。
- (C3)0 <= broadcast_dimensions < rank(result)。
- (C4)is_unique(broadcast_dimensions)。
- (C5)axes(operand)内のすべてのdについて:- dim(operand, d) = 1または
- dim(operand, d) = dim(result, broadcast_dimensions[d])。
 
- (C6)is_per_axis_quantized(result)の場合:- quantization_dimension(result) = broadcast_dimensions[quantization_dimension(operand)]。
- dim(operand, quantization_dimension(operand)) = 1の場合は- scales(result)[i] = scales(operand)[0] and zero_points(result)[i] = zero_points(operand)[0] for i in range(dim(result, quantization_dimension(result)))です。
 
例
// %operand: [
//            [1, 2, 3]
//           ]
%result = "stablehlo.broadcast_in_dim"(%operand) {
  broadcast_dimensions = array<i64: 2, 1>
} : (tensor<1x3xi32>) -> tensor<2x3x2xi32>
// %result: [
//            [
//             [1, 1],
//             [2, 2],
//             [3, 3]
//            ],
//            [
//             [1, 1],
//             [2, 2],
//             [3, 3]
//            ]
//          ]
ケース
セマンティクス
index の値に応じて、branches から関数を 1 つだけ実行することで、出力を生成します。より正式には、result = selected_branch() のようにします。
- 0 <= index < size(branches)の場合、- selected_branch = branches[index]。
- それ以外の場合は selected_branch = branches[-1]。
入力
| ラベル | 名前 | タイプ | 制約 | 
|---|---|---|---|
| (I1) | index | si32型の 0 次元テンソル | |
| (I2) | branches | 関数の可変数 | (C1~C4) | 
出力
| 名前 | タイプ | 制約 | 
|---|---|---|
| results | テンソル、量子化テンソル、またはトークンの可変数 | (C4) | 
制約
- (C1)0 < size(branches)。
- (C2)input_types(branches...) = []。
- (C3)same(output_types(branches...))。
- (C4)type(results...) = output_types(branches[0])。
例
// %index: -1
// %result_branch0: [0, 0]
// %result_branch1: [1, 1]
%result0, %result1 = "stablehlo.case"(%index) ({
  "stablehlo.return"(%result_branch0, %result_branch0) : (tensor<2xi64>, tensor<2xi64>) -> ()
}, {
  "stablehlo.return"(%result_branch1, %result_branch1) : (tensor<2xi64>, tensor<2xi64>) -> ()
}) : (tensor<i32>) -> (tensor<2xi64>, tensor<2xi64>)
// %result0: [1, 1]
// %result1: [1, 1]
CBRT
セマンティクス
operand テンソルに対して要素単位の立方根演算を実行し、result テンソルを生成します。要素のタイプに応じて、次の操作を行います。
- 浮動小数点数の場合: IEEE-754 の rootn(x, 3)。
- 複素数の場合: 複素立方根。
- 量子化型の場合: dequantize_op_quantize(cbrt, operand, type(result))
入力
| ラベル | 名前 | タイプ | 制約 | 
|---|---|---|---|
| (I1) | operand | 浮動小数点型または複合型のテンソル、またはテンソルごとの量子化テンソル | (C1) | 
出力
| 名前 | タイプ | 制約 | 
|---|---|---|
| result | 浮動小数点型または複合型のテンソル、またはテンソルごとの量子化テンソル | (C1) | 
制約
- (C1)baseline_type(operand) = baseline_type(result)。
例
// %operand: [0.0, 1.0, 8.0, 27.0]
%result = "stablehlo.cbrt"(%operand) : (tensor<4xf64>) -> tensor<4xf64>
// %result: [0.0, 1.0, 2.0, 3.0]
ceil
セマンティクス
operand テンソルの要素単位の ceil を実行し、result テンソルを生成します。IEEE-754 仕様の roundToIntegralTowardPositive オペレーションを実装します。量子化された型の場合は、dequantize_op_quantize(ceil, operand, type(result)) を実行します。
入力
| ラベル | 名前 | タイプ | 制約 | 
|---|---|---|---|
| (I1) | operand | 浮動小数点型のテンソルまたはテンソルごとの量子化テンソル | (C1) | 
出力
| 名前 | タイプ | 制約 | 
|---|---|---|
| result | 浮動小数点型のテンソルまたはテンソルごとの量子化テンソル | (C1) | 
制約
- (C1)baseline_type(operand) = baseline_type(result)。
例
// %operand: [-0.8166, -0.2530, 0.2530, 0.8166, 2.0]
%result = "stablehlo.ceil"(%operand) : (tensor<5xf32>) -> tensor<5xf32>
// %result: [-0.0, -0.0, 1.0, 1.0, 2.0]
cholesky
セマンティクス
行列のバッチの Cholesky 分解を計算します。
より正式には、index_space(result) のすべての i について、result[i0, ..., iR-3, :, :] は a[i0, ..., iR-3, :, :] の Cholesky 分解であり、下三角(lower が true の場合)または上三角(lower が false の場合)行列のいずれかになります。反対の三角形(厳密な上三角形または厳密な下三角形)の出力値は、実装で定義されます。
入力行列がエルミート正定値行列ではない i が存在する場合、動作は未定義です。
量子化された型の場合は、dequantize_op_quantize(lambda operand: cholesky(operand, lower), a, type(result)) を実行します。
入力
| ラベル | 名前 | タイプ | 制約 | 
|---|---|---|---|
| (I1) | a | 浮動小数点型または複合型のテンソル、またはテンソルごとの量子化テンソル | (C1~C3) | 
| (I2) | lower | 型 i1の 0 次元テンソル定数 | 
出力
| 名前 | タイプ | 制約 | 
|---|---|---|
| result | 浮動小数点型または複合型のテンソル、またはテンソルごとの量子化テンソル | (C1) | 
制約
- (C1)baseline_type(a) = baseline_type(result)。
- (C2)2 <= rank(a)。
- (C3)dim(a, -2) = dim(a, -1)。
例
// %a: [
//      [1.0, 2.0, 3.0],
//      [2.0, 20.0, 26.0],
//      [3.0, 26.0, 70.0]
//     ]
%result = "stablehlo.cholesky"(%a) {
  lower = true
} : (tensor<3x3xf32>) -> tensor<3x3xf64>
// %result: [
//           [1.0, 0.0, 0.0],
//           [2.0, 4.0, 0.0],
//           [3.0, 5.0, 6.0]
//          ]
クランプ
セマンティクス
operand テンソルのすべての要素を最小値と最大値の間にクランプし、result テンソルを生成します。より正式には、result[result_index] =
minimum(maximum(operand[result_index], min_element), max_element)(min_element = rank(min) = 0 ? min[] : min[result_index]、max_element = rank(max) = 0 ? max[] : max[result_index])です。量子化された型の場合は、dequantize_op_quantize(clamp, min, operand, max, type(result)) を実行します。
複素数の順序付けには意外なセマンティクスが伴うため、将来的には、この演算での複素数のサポートは終了する予定です(#560)。
入力
| ラベル | 名前 | タイプ | 制約 | 
|---|---|---|---|
| (I1) | min | テンソルまたはテンソルごとの量子化テンソル | (C1)、(C3) | 
| (I2) | operand | テンソルまたはテンソルごとの量子化テンソル | (C1~C4) | 
| (I3) | max | テンソルまたはテンソルごとの量子化テンソル | (C2)、(C3) | 
出力
| 名前 | タイプ | 制約 | 
|---|---|---|
| result | テンソルまたはテンソルごとの量子化テンソル | (C4) | 
制約
- (C1)rank(min) = 0 or shape(min) = shape(operand)。
- (C2)rank(max) = 0 or shape(max) = shape(operand)。
- (C3)baseline_element_type(min) = baseline_element_type(operand) = baseline_element_type(max)。
- (C4)baseline_type(operand) = baseline_type(result)。
例
// %min: [5, 10, 15]
// %operand: [3, 13, 23]
// %max: [10, 15, 20]
%result = "stablehlo.clamp"(%min, %operand, %max) : (tensor<3xi32>, tensor<3xi32>, tensor<3xi32>) -> tensor<3xi32>
// %result: [5, 13, 20]
collective_broadcast
セマンティクス
StableHLO プロセス グリッドの各プロセス グループ内で、operand テンソルの値をソースプロセスからターゲット プロセスに送信し、result テンソルを生成します。
このオペレーションでは、StableHLO プロセス グリッドが、次のように定義された process_groups に分割されます。
- channel_id <= 0の場合は- cross_replica(replica_groups)。
- channel_id > 0の場合、- cross_partition(replica_groups)。
その後、result@process は次のようになります。
- operand@process_groups[i, 0]: プロセスが- process_groups[i]にあるような- iが存在する場合。
- それ以外の場合は broadcast_in_dim(constant(is_quantized(result) ? quantize(0, element_type(result)) : 0, element_type(result)), [], type(result))。
入力
| ラベル | 名前 | タイプ | 制約 | 
|---|---|---|---|
| (I1) | operand | テンソルまたはテンソルごとの量子化テンソル | (C3) | 
| (I2) | replica_groups | si64型の 1 次元テンソル定数の可変数 | (C1)、(C2) | 
| (I3) | channel_id | si64型の定数 | 
出力
| 名前 | タイプ | 制約 | 
|---|---|---|
| result | テンソルまたはテンソルごとの量子化テンソル | (C3) | 
制約
- (C1)is_unique(replica_groups)。
- (C2)0 <= replica_groups < N。ここで、Nは次のように定義されます。- cross_replicaを使用する場合は- num_replicas。
- cross_partitionを使用する場合は- num_partitions。
 
- (C3)type(result) = type(operand)。
例
// num_replicas: 4
// num_partitions: 1
// %operand@(0, 0): [[1, 2]]
// %operand@(1, 0): [[3, 4]]
// %operand@(2, 0): [[5, 6]]
// %operand@(3, 0): [[7, 8]]
%result = "stablehlo.collective_broadcast"(%operand) {
  replica_groups = dense<[[2, 1]]> : tensor<1x2xi64>,
  channel_handle = #stablehlo.channel_handle<handle = 0, type = 0>
} : (tensor1x2xi64>) -> tensor<1x2xi64>
// %result@(0, 0): [[0, 0]]
// %result@(1, 0): [[5, 6]]
// %result@(2, 0): [[5, 6]]
// %result@(3, 0): [[0, 0]]
collective_permute
セマンティクス
StableHLO プロセス グリッド内の各プロセス グループ内で、operand テンソルの値をソースプロセスからターゲット プロセスに送信し、result テンソルを生成します。
このオペレーションは、StableHLO プロセス グリッドを次のように定義された process_groups に分割します。
- channel_id <= 0の場合、- cross_replica(source_target_pairs)。
- channel_id > 0の場合は- cross_partition(source_target_pairs)。
その後、result@process は次のようになります。
- operand@process_groups[i, 0]。- process_groups[i, 1] = processのような- iが存在する場合。
- それ以外の場合は broadcast_in_dim(constant(is_quantized(result) ? quantize(0, element_type(result)) : 0, element_type(result)), [], type(result))。
入力
| ラベル | 名前 | タイプ | 制約 | 
|---|---|---|---|
| (I1) | operand | テンソルまたはテンソルごとの量子化テンソル | (C5) | 
| (I2) | source_target_pairs | si64型の 2 次元のテンソル定数 | (C1~C4) | 
| (I3) | channel_id | si64型の定数 | 
出力
| 名前 | タイプ | 制約 | 
|---|---|---|
| result | テンソルまたはテンソルごとの量子化テンソル | (C1) | 
制約
- (C1)dim(source_target_pairs, 1) = 2。
- (C2)is_unique(source_target_pairs[:, 0])。
- (C3)is_unique(source_target_pairs[:, 1])。
- (C4)0 <= source_target_pairs < N。ここで、Nは次のように定義されます。- cross_replicaを使用する場合は- num_replicas。
- cross_partitionを使用する場合は- num_partitions。
 
- (C5)type(result) = type(operand)。
例
// num_replicas: 3
// num_partitions: 1
// %operand@(0, 0): [[1, 2], [3, 4]]
// %operand@(1, 0): [[5, 6], [7, 8]]
// %operand@(2, 0): [[9, 10], [11, 12]]
%result = "stablehlo.collective_permute"(%operand) {
  source_target_pairs = dense<[[0, 1], [1, 2]]> : tensor<2x2xi64>,
  channel_handle = #stablehlo.channel_handle<handle = 0, type = 0>
} : (tensor<2x2xi64>) -> tensor<2x2xi64>
//
// %result@(0, 0): [[0, 0], [0, 0]]
// %result@(1, 0): [[1, 2], [3, 4]]
// %result@(2, 0): [[5, 6], [7, 8]]
比較
セマンティクス
comparison_direction と compare_type に従って lhs テンソルと rhs テンソルの要素ごとの比較を行い、result テンソルを生成します。
comparison_direction と compare_type の値のセマンティクスは次のとおりです。
ブール値と整数の要素型の場合:
- EQ:- lhs = rhs
- NE:- lhs != rhs
- GE:- lhs >= rhs
- GT:- lhs > rhs
- LE:- lhs <= rhs
- LT:- lhs < rhs
compare_type = FLOAT の浮動小数点要素型の場合、op は次の IEEE-754 演算を実装します。
- EQ:- compareQuietEqual
- NE:- compareQuietNotEqual
- GE:- compareQuietGreaterEqual
- GT:- compareQuietGreater
- LE:- compareQuietLessEqual
- LT:- compareQuietLess
compare_type = TOTALORDER を含む浮動小数点要素型の場合、op は IEEE-754 の totalOrder オペレーションと compareQuietEqual オペレーションの組み合わせを使用します。
複雑な要素型の場合、指定された comparison_direction と compare_type を使用して、(real, imag) ペアの辞書順による比較が行われます。複素数に順序付けを適用すると、予期しないセマンティクスが発生するため、comparison_direction が GE、GT、LE、または LT の場合、複素数のサポートを削除する予定です(#560)。
量子化された型の場合、dequantize_compare(lhs, rhs,
comparison_direction) を実行します。
入力
| ラベル | 名前 | タイプ | 制約 | 
|---|---|---|---|
| (I1) | lhs | テンソルまたはテンソルごとの量子化テンソル | (C1 ~ C3) | 
| (I2) | rhs | テンソルまたはテンソルごとの量子化テンソル | (C1-C2) | 
| (I3) | comparison_direction | EQ、NE、GE、GT、LE、LTの列挙型 | |
| (I4) | compare_type | FLOAT、TOTALORDER、SIGNED、UNSIGNEDの列挙型 | (C3) | 
出力
| 名前 | タイプ | 制約 | 
|---|---|---|
| result | ブール型のテンソル | (C2) | 
制約
- (C1)baseline_element_type(lhs) = baseline_element_type(rhs)。
- (C2)shape(lhs) = shape(rhs) = shape(result)。
- (C3)compare_typeは次のように定義されます。- is_signed_integer(element_type(lhs))の場合は- SIGNED。
- is_unsigned_integer(element_type(lhs)) or is_boolean(element_type(lhs))の場合は- UNSIGNED。
- is_float(element_type(lhs))の場合は- FLOATまたは- TOTALORDER。
- is_complex(element_type(lhs))の場合は- FLOAT。
 
例
// %lhs: [1.0, 3.0]
// %rhs: [1.1, 2.9]
%result = "stablehlo.compare"(%lhs, %rhs) {
  comparison_direction = #stablehlo<comparison_direction LT>,
  compare_type = #stablehlo<comparison_type FLOAT>
} : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xi1>
// %result: [true, false]
複雑
セマンティクス
実数値と虚数値のペア(lhs と rhs)から複合値に要素単位で変換し、result テンソルを生成します。
入力
| ラベル | 名前 | タイプ | 制約 | 
|---|---|---|---|
| (I1) | lhs | f32型またはf64型のテンソル | (C1~C3) | 
| (I2) | rhs | f32型またはf64型のテンソル | (C1) | 
出力
| 名前 | タイプ | 制約 | 
|---|---|---|
| result | 複合型テンソル | (C2)、(C3) | 
制約
- (C1)type(lhs) = type(rhs)。
- (C2)shape(result) = shape(lhs)。
- (C3)element_type(result)の型はcomplex<E>で、E = element_type(lhs)です。
例
// %lhs: [1.0, 3.0]
// %rhs: [2.0, 4.0]
%result = "stablehlo.complex"(%lhs, %rhs) : (tensor<2xf64>, tensor<2xf64>) -> tensor<2xcomplex<f64>>
// %result: [(1.0, 2.0), (3.0, 4.0)]
複合
セマンティクス
他の StableHLO オペレーションで構成された(コンポーズされた)オペレーションをカプセル化し、inputs と composite_attributes を受け取って results を生成します。op のセマンティクスは decomposition 属性によって実装されます。composite オペレーションは、プログラムのセマンティクスを変更せずに分解に置き換えることができます。分解をインライン化しても同じオペレーションのセマンティクスが得られない場合は、custom_call を使用することをおすすめします。
version フィールド(デフォルトは 0)は、コンポジットのセマンティクスが変更されたときに使用されます。
入力
| ラベル | 名前 | タイプ | 
|---|---|---|
| (I1) | inputs | 値の可変数 | 
| (I2) | name | string型の定数 | 
| (I3) | composite_attributes | 属性辞書 | 
| (I4) | decomposition | string型の定数 | 
| (I5) | version | si32型の定数 | 
出力
| 名前 | タイプ | 
|---|---|
| results | 値の可変数 | 
制約
- (C1)is_namespaced_op_name(name)
- (C2)is_defined_in_parent_scope(decomposition)
- (C3)types(inputs...) == input_types(decomposition)
- (C4)types(results...) == output_types(decomposition)
例
%results = "stablehlo.composite"(%input0, %input1) {
  name = "my_namespace.my_op",
  composite_attributes = {
    my_attribute = "my_value"
  },
  decomposition = @my_op,
  version = 1 : i32
} : (tensor<f32>, tensor<f32>) -> tensor<f32>
concatenate
セマンティクス
指定された引数と同じ順序で dimension 次元に沿って inputs を連結し、result テンソルを生成します。より正式には、result[i0, ..., id, ..., iR-1] = inputs[k][i0, ..., kd, ..., iR-1] です。ここで、
- id = d0 + ... + dk-1 + kd。
- dは- dimensionに等しく、- d0は- inputsの- d番目のディメンション サイズです。
入力
| ラベル | 名前 | タイプ | 制約 | 
|---|---|---|---|
| (I1) | inputs | テンソルの可変数またはテンソルごとの量子化テンソル | (C1~C6) | 
| (I2) | dimension | si64型の定数 | (C2)、(C4)、(C6) | 
出力
| 名前 | タイプ | 制約 | 
|---|---|---|
| result | テンソルまたはテンソルごとの量子化テンソル | (C5-C6) | 
制約
- (C1)same(element_type(inputs...))。
- (C2)dim(inputs..., dimension)を除くsame(shape(inputs...))。
- (C3)0 < size(inputs)。
- (C4)0 <= dimension < rank(inputs[0])。
- (C5)element_type(result) = element_type(inputs[0])。
- (C6)shape(result) = shape(inputs[0])(ただし、次の場合は除く)。- dim(result, dimension) = dim(inputs[0], dimension) + ...。
 
例
// %input0: [[1, 2], [3, 4], [5, 6]]
// %input1: [[7, 8]]
%result = "stablehlo.concatenate"(%input0, %input1) {
  dimension = 0 : i64
} : (tensor<3x2xi64>, tensor<1x2xi64>) -> tensor<4x2xi64>
// %result: [[1, 2], [3, 4], [5, 6], [7, 8]]
定数
セマンティクス
定数 value から output テンソルを生成します。
入力
| ラベル | 名前 | タイプ | 制約 | 
|---|---|---|---|
| (I1) | value | 定数 | (C1) | 
出力
| 名前 | タイプ | 制約 | 
|---|---|---|
| output | テンソルまたは量子化テンソル | (C1) | 
制約
- (C1)type(value) = type(output)。
例
%output = "stablehlo.constant"() {
  value = dense<[[0.0, 1.0], [2.0, 3.0]]> : tensor<2x2xf32>
} : () -> tensor<2x2xf32>
// %output: [[0.0, 1.0], [2.0, 3.0]]
コンバージョン
セマンティクス
operand テンソルで要素型を要素ごとに変換し、result テンソルを生成します。
boolean-to-any-supported-typeでは、値 false はゼロに変換され、値 true は 1 に変換されます。any-supported-type-to-booleanへの変換の場合、ゼロ値は false に変換され、ゼロ以外の値は true に変換されます。複雑な型での動作については以下をご覧ください。
integer-to-integer、integer-to-floating-point、floating-point-to-floating-point を含む変換では、ソース値が宛先の型で正確に表現できる場合、結果値はそのまま表現されます。それ以外の場合、動作は未定です(#180)。
浮動小数点数から整数への変換では、小数部分は切り捨てられます。切り捨てられた値を宛先型で表現できない場合、動作は未定です(#180)。
複素数から複素数への変換は、実部と虚部の変換における浮動小数点数から浮動小数点数への変換と同じ動作になります。
複素数から他の型への変換と他の型から複素数への変換では、それぞれソースの虚数値が無視されるか、宛先の虚数値がゼロになります。実部の変換は浮動小数点変換に従います。
原則として、このオペレーションはデクォンタイズ(量子化テンソルから通常のテンソルへの変換)、量子化(通常のテンソルから量子化テンソルへの変換)、再量子化(量子化テンソル間の変換)を表現できますが、現時点では専用のオペレーションがあります。最初のユースケースの場合は uniform_dequantize、2 つ目のユースケースと 3 つ目のユースケースの場合は uniform_quantize です。今後、これらの 2 つのオペレーションは convert に統合される可能性があります(#1576)。
入力
| ラベル | 名前 | タイプ | 制約 | 
|---|---|---|---|
| (I1) | operand | テンソル | (C1) | 
出力
| 名前 | タイプ | 制約 | 
|---|---|---|
| result | テンソル | (C1) | 
制約
- (C1)shape(operand) = shape(result)。
例
// %operand: [-1, 0, 1]
%result = "stablehlo.convert"(%operand) : (tensor<3xi64>) -> tensor<3xcomplex<f64>>
// %result: [(-1.0, 0.0), (0.0, 0.0), (1.0, 0.0)]
畳み込み
セマンティクス
lhs のウィンドウと rhs のスライス間のドット積を計算し、result を生成します。次の図は、具体的な例を使用して、result の要素が lhs と rhs からどのように計算されるかを示しています。
より正式には、lhs のウィンドウを表現できるようにするために、次のように lhs の観点から入力をフレーミングすることを検討します。
- lhs_window_dimensions = lhs_shape(dim(lhs, input_batch_dimension), dim(rhs, kernel_spatial_dimensions), dim(lhs, input_feature_dimension))。
- lhs_window_strides = lhs_shape(1, window_strides, 1)
- lhs_padding = lhs_shape([0, 0], padding, [0, 0])
- lhs_base_dilations = lhs_shape(1, lhs_dilation, 1)
- lhs_window_dilations = lhs_shape(1, rhs_dilation, 1)。
このフレーム変更では、次のヘルパー関数を使用します。
- lhs_shape(n, hw, c) = permute([n] + hw + [c], [input_batch_dimension] + input_spatial_dimensions + [input_feature_dimension])。
- result_shape(n1, hw, c1) = permute([n1] + hw + [c1], [output_batch_dimension] + output_spatial_dimensions + [output_feature_dimension])。
- permute([j0, j1, ..., jR-1], permutation) = [i0, i1, ..., iR-1]は- j[d] = i[permutation[d]]です。
feature_group_count = 1 かつ batch_group_count = 1 の場合、index_space(dim(result, output_spatial_dimensions...)) 内のすべての output_spatial_index について、result[result_shape(:, output_spatial_index, :)] = dot_product です。ここで、
- padding_value = constant(0, element_type(lhs))。
- padded_lhs = pad(lhs, padding_value, lhs_padding[:, 0], lhs_padding[:, 1], lhs_base_dilations - 1)
- lhs_window_start = lhs_shape(0, output_spatial_index, 0) * lhs_window_strides
- lhs_window = slice(padded_lhs, lhs_window_start, lhs_window_start + lhs_window_dimensions, lhs_window_dilations)。
- reversed_lhs_window = reverse(lhs_window, [input_spatial_dimensions[dim] for dim in range(size(window_reversal)) if window_reversal[dim] = true])。この機能は使用されていないと思われるため、今後削除する予定です(#1181)。
- dot_product = dot_general(reversed_lhs_window, rhs, lhs_batching_dimensions=[], lhs_contracting_dimensions=input_spatial_dimensions + [input_feature_dimension], rhs_batching_dimensions=[], rhs_contracting_dimensions=kernel_spatial_dimensions + [kernel_input_feature_dimension])。
feature_group_count > 1 の場合:
- lhses = split(lhs, feature_group_count, input_feature_dimension)。
- rhses = split(rhs, feature_group_count, kernel_output_feature_dimension)
- results... = convolution(lhses..., rhses..., ..., feature_group_count=1, ...)
- result = concatenate(results, output_feature_dimension)。
batch_group_count > 1 の場合:
- lhses = split(lhs, batch_group_count, input_batch_dimension)。
- rhses = split(rhs, batch_group_count, kernel_output_feature_dimension)
- results... = convolution(lhses..., rhses..., ..., batch_group_count=1, ...)。
- result = concatenate(results, output_feature_dimension)。
量子化された型の場合は、dequantize_op_quantize(
    lambda lhs, rhs: convolution(lhs, rhs, window_strides, padding,
        lhs_dilation, rhs_dilation, window_reversal, input_batch_dimension,
        input_feature_dimension, input_spatial_dimensions,
        kernel_input_feature_dimension, kernel_output_feature_dimension,
        kernel_spatial_dimensions, output_batch_dimension,
        output_feature_dimension, output_spatial_dimensions,
        feature_group_count, batch_group_count, precision_config), lhs, rhs,
        type(result)) を実行します。
ハイブリッド量子化タイプの場合は、hybrid_dequantize_then_op(
    lambda lhs, rhs: convolution(lhs, rhs, window_strides, padding,
        lhs_dilation, rhs_dilation, window_reversal, input_batch_dimension,
        input_feature_dimension, input_spatial_dimensions,
        kernel_input_feature_dimension, kernel_output_feature_dimension,
        kernel_spatial_dimensions, output_batch_dimension,
        output_feature_dimension, output_spatial_dimensions,
        feature_group_count, batch_group_count, precision_config), lhs, rhs) を実行します。
入力
| ラベル | 名前 | タイプ | 制約 | 
|---|---|---|---|
| (I1) | lhs | テンソルまたはテンソルごとの量子化テンソル | (C1)、(C10~C11)、(C14)、(C25)、(C27~C28)、(C31~C32)、(C34) | 
| (I2) | rhs | テンソルまたは量子化テンソル | (C1)、(C14 ~ C16)、(C25)、(C27 ~ C29)、(C31 ~ C34) | 
| (I3) | window_strides | si64型の 1 次元テンソル定数 | (C2~C3)、(C25) | 
| (I4) | padding | si64型の 2 次元テンソル定数 | (C4)、(C25) | 
| (I5) | lhs_dilation | si64型の 1 次元テンソル定数 | (C5 ~ C6)、(C25) | 
| (I6) | rhs_dilation | si64型の 1 次元のテンソル定数 | (C7 ~ C8)、(C25) | 
| (I7) | window_reversal | i1型の 1 次元テンソル定数 | (C9) | 
| (I8) | input_batch_dimension | si64型の定数 | (C10)、(C13)、(C25) | 
| (I9) | input_feature_dimension | si64型の定数 | (C11)、(C13~C14) | 
| (I10) | input_spatial_dimensions | si64型の 1 次元のテンソル定数 | (C12)、(C13)、(C25) | 
| (I11) | kernel_input_feature_dimension | si64型の定数 | (C14)、(C18) | 
| (I12) | kernel_output_feature_dimension | si64型の定数 | (C15 ~ C16)、(C18)、(C25)、(C29) | 
| (I13) | kernel_spatial_dimensions | si64型の 1 次元テンソル定数 | (C17~C18)、(C25) | 
| (I14) | output_batch_dimension | si64型の定数 | (C20)、(C25) | 
| (I15) | output_feature_dimension | si64型の定数 | (C20)、(C25)、(C30) | 
| (I16) | output_spatial_dimensions | si64型の 1 次元テンソル定数 | (C19 ~ C20)、(C25) | 
| (I17) | feature_group_count | si64型の定数 | (C11)、(C14)、(C16)、(C21)、(C23) | 
| (I18) | batch_group_count | si64型の定数 | (C10)、(C15)、(C22)、(C23)、(C25) | 
| (I19) | precision_config | DEFAULT、HIGH、HIGHESTの可変列挙型の数 | (C24) | 
出力
| 名前 | タイプ | 制約 | 
|---|---|---|
| result | テンソルまたは量子化テンソル | (C25 ~ C28)、(C30)、(C32 ~ 34) | 
制約
- (C1)N = rank(lhs) = rank(rhs)。
- (C2)size(window_strides) = N - 2。
- (C3)0 < window_strides。
- (C4)shape(padding) = [N - 2, 2]。
- (C5)size(lhs_dilation) = N - 2。
- (C6)0 < lhs_dilation。
- (C7)size(rhs_dilation) = N - 2。
- (C8)0 < rhs_dilation。
- (C9)size(window_reversal) = N - 2。
- (C10)dim(lhs, input_batch_dimension) % batch_group_count = 0。
- (C11)dim(lhs, input_feature_dimension) % feature_group_count = 0。
- (C12)size(input_spatial_dimensions) = N - 2。
- (C13)input_dimensions = [input_batch_dimension] + input_spatial_dimensions + [input_feature_dimension]が与えられている場合:- is_unique(input_dimensions)。
- 0 <= input_dimensions < N。
 
- (C14)dim(rhs, kernel_input_feature_dimension) = dim(lhs, input_feature_dimension) / feature_group_count。
- (C15)dim(rhs, kernel_output_feature_dimension) % batch_group_count = 0。
- (C16)dim(rhs, kernel_output_feature_dimension) % feature_group_count = 0。
- (C17)size(kernel_spatial_dimensions) = N - 2。
- (C18)kernel_dimensions = kernel_spatial_dimensions + [kernel_input_feature_dimension] + [kernel_output_feature_dimension]が与えられている場合:- is_unique(kernel_dimensions)。
- 0 <= kernel_dimensions < N。
 
- (C19)size(output_spatial_dimensions) = N - 2。
- (C20)output_dimensions = [output_batch_dimension] + output_spatial_dimensions + [output_feature_dimension]が与えられている場合:- is_unique(output_dimensions)。
- 0 <= output_dimensions < N。
 
- (C21)0 < feature_group_count。
- (C22)0 < batch_group_count。
- (C23)feature_group_count = 1 or batch_group_count = 1。
- (C24)size(precision_config) = 2。
- (C25)dim(result, result_dim)は次のように定義されます。- result_dim = output_batch_dimensionの場合、- dim(lhs, input_batch_dimension) / batch_group_count。
- result_dim = output_feature_dimensionの場合、- dim(rhs, kernel_output_feature_dimension)。
- それ以外の場合は num_windows。ここで:
- output_spatial_dimensions[spatial_dim] = result_dim。
- lhs_dim = input_spatial_dimensions[spatial_dim]
- rhs_dim = kernel_spatial_dimensions[spatial_dim]
- dilated_input_shape[lhs_dim] = dim(lhs, lhs_dim) = 0 ? 0 : (dim(lhs, lhs_dim) - 1) * lhs_dilation[spatial_dim] + 1
- padded_input_shape[lhs_dim] = padding[spatial_dim, 0] + dilated_input_shape[lhs_dim] + padding[spatial_dim, 1]
- dilated_window_shape[lhs_dim] = dim(rhs, rhs_dim) = 0 ? 0 : (dim(rhs, rhs_dim) - 1) * rhs_dilation[spatial_dim] + 1
- is_empty_window[lhs_dim] = padded_input_shape[lhs_dim] = 0 || dilated_window_shape[lhs_dim] > padded_input_shape[lhs_dim]
- num_windows = is_empty_window[lhs_dim] ? 0 : floor((padded_input_shape[lhs_dim] - dilated_window_shape[lhs_dim]) / window_strides[spatial_dim]) + 1。
 
- (C26)rank(result) = N。
- オペレーションで量子化されていないテンソルを使用する場合:
- (C27)element_type(lhs) = element_type(rhs) = element_type(result)。
 
- (C27)
- オペレーションで量子化されたテンソルを使用する場合:
- (C28)is_quantized(lhs) = is_quantized(result) and is_quantized(rhs)。
- (C29)is_per_axis_quantized(rhs)の場合、quantization_dimension(rhs) = kernel_output_feature_dimension。
- (C30)is_per_axis_quantized(result)の場合、quantization_dimension(result) = output_feature_dimension。
- is_quantized(lhs)の場合:
- (C31)storage_type(lhs) = storage_type(rhs)。
- (C32)expressed_type(lhs) = expressed_type(rhs) = expressed_type(result)。
- (C33)is_per_tensor_quantized(rhs)の場合はis_per_tensor_quantized(result)です。
- !is_quantized(lhs)の場合:
- (C34)element_type(lhs) = expressed_type(rhs) = element_type(result)。
 
- (C28)
例
// %lhs: [[
//        [
//          [1], [2], [5], [6]
//        ],
//        [
//          [3], [4], [7], [8]
//        ],
//        [
//          [10], [11], [14], [15]
//        ],
//        [
//          [12], [13], [16], [17]
//        ]
//      ]]
//
// %rhs: [
//        [[[1]], [[1]], [[1]]],
//        [[[1]], [[1]], [[1]]],
//        [[[1]], [[1]], [[1]]]
//       ]
%result = "stablehlo.convolution"(%lhs, %rhs) {
  window_strides = array<i64: 4, 4>,
  padding = dense<0> : tensor<2x2xi64>,
  lhs_dilation = array<i64: 2, 2>,
  rhs_dilation = array<i64: 1, 1>,
  window_reversal = array<i1: false, false>,
  // In the StableHLO dialect, dimension numbers are encoded via:
  // `[<input dimensions>]x[<kernel dimensions>]->[output dimensions]`.
  // "b" is batch dimension, "f" is feature dimension,
  // "i" is input feature dimension, "o" is output feature dimension,
  // "0/1/etc" are spatial dimensions.
  dimension_numbers = #stablehlo.conv<[b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f]>,
  batch_group_count = 1 : i64,
  feature_group_count = 1 : i64,
  precision_config = [#stablehlo<precision DEFAULT>, #stablehlo<precision DEFAULT>]
} : (tensor<1x4x4x1xi64>, tensor<3x3x1x1xi64>) -> tensor<1x2x2x1xi64>
// %result: [[
//            [[10], [26]],
//            [[46], [62]]
//          ]]
コサイン
セマンティクス
operand テンソルに対して要素単位のコサイン演算を実行し、result テンソルを生成します。要素のタイプに応じて、次の操作を行います。
- 浮動小数点数の場合: IEEE-754 の cos。
- 複素数の場合: 複素余弦。
- 量子化型の場合: dequantize_op_quantize(cosine, operand, type(result))。
入力
| ラベル | 名前 | タイプ | 制約 | 
|---|---|---|---|
| (I1) | operand | 浮動小数点型または複合型のテンソル、またはテンソルごとの量子化テンソル | (C1) | 
出力
| 名前 | タイプ | 制約 | 
|---|---|---|
| result | 浮動小数点型または複合型のテンソル、またはテンソルごとの量子化テンソル | (C1) | 
制約
- (C1)baseline_type(operand) = baseline_type(result)。
例
// %operand: [
//            [0.0, 1.57079632],       // [0, pi/2]
//            [3.14159265, 4.71238898] // [pi, 3pi/2]
//           ]
%result = "stablehlo.cosine"(%operand) : (tensor<2x2xf32>) -> tensor<2x2xf32>
// %result: [[1.0, 0.0], [-1.0, 0.0]]
count_leading_zeros
セマンティクス
operand テンソルの先頭のゼロビット数を要素ごとにカウントし、result テンソルを生成します。
入力
| ラベル | 名前 | タイプ | 制約 | 
|---|---|---|---|
| (I1) | operand | 整数型のテンソル | (C1) | 
出力
| 名前 | タイプ | 制約 | 
|---|---|---|
| result | 整数型のテンソル | (C1) | 
制約
- (C1)type(operand) = type(result)。
例
// %operand: [[0, 1], [128, -1]]
%result = "stablehlo.count_leading_zeros"(%operand) : (tensor<2x2xi64>) -> tensor<2x2xi64>
// %result: [[64, 63], [56, 0]]
custom_call
セマンティクス
inputs と called_computations を受け取って results を生成する、実装定義のオペレーション call_target_name をカプセル化します。has_side_effect、backend_config、api_version を使用して、実装定義のメタデータを追加できます。
現時点では、このオペレーションにはかなり整理されていないメタデータのコレクションが含まれています。これは、XLA コンパイラでの対応するオペレーションの有機的な進化を反映しています。今後、このメタデータを統合する予定です(#741)。
入力
| ラベル | 名前 | タイプ | 
|---|---|---|
| (I1) | inputs | 値の可変数 | 
| (I2) | call_target_name | string型の定数 | 
| (I3) | has_side_effect | i1型の定数 | 
| (I4) | backend_config | string型の定数または属性辞書 | 
| (I5) | api_version | si32型の定数 | 
| (I6) | called_computations | string型の可変長定数 | 
出力
| 名前 | タイプ | 
|---|---|
| results | 可変長の値の数 | 
例
%results = "stablehlo.custom_call"(%input0) {
  call_target_name = "foo",
  has_side_effect = false,
  backend_config = {bar = 42 : i32},
  api_version = 4 : i32,
  called_computations = [@foo]
} : (tensor<f64>) -> tensor<f64>
割る
セマンティクス
除数 lhs テンソルと除算子 rhs テンソルの要素ごとの除算を行い、result テンソルを生成します。要素のタイプに応じて、次の処理を行います。
- 整数の場合: 小数部を破棄して代数的な商を生成する整数除算。
- 浮動小数点数の場合: IEEE-754 の division。
- 複素数の場合: 複素除算。
- 量子化された型の場合:
- dequantize_op_quantize(divide, lhs, rhs, type(result))。
 
入力
| ラベル | 名前 | タイプ | 制約 | 
|---|---|---|---|
| (I1) | lhs | 整数型、浮動小数点型、複合型のテンソル、またはテンソルごとの量子化テンソル | (C1) | 
| (I2) | rhs | 整数型、浮動小数点型、複合型のテンソル、またはテンソルごとの量子化テンソル | (C1) | 
出力
| 名前 | タイプ | 制約 | 
|---|---|---|
| result | 整数型、浮動小数点型、複合型のテンソル、またはテンソルごとの量子化テンソル | (C1) | 
制約
- (C1)baseline_type(lhs) = baseline_type(rhs) = baseline_type(result)。
例
// %lhs: [17.1, -17.1, 17.1, -17.1]
// %rhs: [3.0, 3.0, -3.0, -3.0]
%result = "stablehlo.divide"(%lhs, %rhs) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
// %result: [5.66666651, -5.66666651, -5.66666651, 5.66666651]
dot_general
セマンティクス
lhs のスライスと rhs のスライスの間のドット積を計算し、result テンソルを生成します。
正式には result[result_index] = dot_product です。ここで、
- lhs_result_dimensions = [d for d in axes(lhs) and d not in lhs_batching_dimensions and d not in lhs_contracting_dimensions]。
- rhs_result_dimensions = [d for d in axes(rhs) and d not in rhs_batching_dimensions and d not in rhs_contracting_dimensions]。
- result_batching_index + result_lhs_index + result_rhs_index = result_index(- size(result_batching_index) = size(lhs_batching_dimensions)、- size(result_lhs_index) = size(lhs_result_dimensions)、- size(result_rhs_index) = size(rhs_result_dimensions))。
- transposed_lhs = transpose(lhs, lhs_batching_dimensions + lhs_result_dimensions + lhs_contracting_dimensions)。
- transposed_lhs_slice = slice(transposed_lhs, result_batching_index + result_lhs_index + [:, ..., :])
- reshaped_lhs_slice = reshape(transposed_lhs_slice, dims(lhs, lhs_contracting_dimensions))
- transposed_rhs = transpose(rhs, rhs_batching_dimensions + rhs_result_dimensions + rhs_contracting_dimensions)
- transposed_rhs_slice = slice(transposed_rhs, result_batching_index + result_rhs_index + [:, ..., :])
- reshaped_rhs_slice = reshape(transposed_rhs_slice, dims(rhs, rhs_contracting_dimensions))。
- dot_product = reduce( inputs=[multiply(reshaped_lhs_slice, reshaped_rhs_slice)], init_values=[constant(0, element_type(result))], dimensions=range(size(lhs_contracting_dimensions)), body=lambda x, y: add(x, y))。
量子化された型の場合は、dequantize_op_quantize(
    lambda lhs, rhs: dot_general(lhs, rhs, lhs_batching_dimensions,
        rhs_batching_dimensions, lhs_contracting_dimensions,
        rhs_contracting_dimensions, precision_config), lhs, rhs, type(result)) を実行します。
ハイブリッド量子化タイプの場合は、hybrid_dequantize_then_op(
    lambda lhs, rhs: dot_general(lhs, rhs, lhs_batching_dimensions,
        rhs_batching_dimensions, lhs_contracting_dimensions,
        rhs_contracting_dimensions, precision_config), lhs, rhs) を実行します。
precision_config は、アクセラレータ バックエンドでの計算の速度と精度のトレードオフを制御します。これは次のいずれかです(現時点では、これらの列挙型値のセマンティクスは指定されていません。#755 で対応する予定です)。
- DEFAULT: 計算速度は最も速いが、元の数値に対する近似値の精度は最も低い。
- HIGH: 計算は遅くなりますが、元の数値に近い値をより正確に求めることができます。
- HIGHEST: 計算に最も時間がかかりますが、元の数値に最も近い近似値が得られます。
DotAlgorithm は、ドット演算の実装に使用されるアルゴリズムの主なプロパティを定義します。これにより、精度も定義されます。アルゴリズム属性フィールドが設定されている場合、precision_config は DEFAULT にする必要があります。デフォルト パラメータは実装で定義されるため、DotAlgorithms にデフォルト値はありません。そのため、すべてのドット アルゴリズム フィールドを None に設定して、空のドット アルゴリズムを指定し、代わりに precision_config 値を使用できます。
DotAlgorithm フィールドには次のフィールドが含まれます。
- lhs_precision_typeと- rhs_precision_type: オペレーションの LHS と RHS が丸められる精度。精度タイプは、入力と出力のストレージ タイプとは独立しています。
- accumulation_type累積に使用される精度。
- lhs_component_count、- rhs_component_count、- num_primitive_operationsは、LHS や RHS を複数のコンポーネントに分解し、それらの値に対して複数の「プリミティブ」ドット演算を実行するアルゴリズムを実行する場合に適用されます。通常は、より高い精度をエミュレートします(例: 高精度の計算に bfloat16 AI データ型を活用: btf12_6x3x3)。分解のないアルゴリズムの場合、これらの値は- 1に設定する必要があります。
- allow_imprecise_accumulation: 一部のステップで低精度の累積が許可されるかどうかを指定します(- CUBLASLT_MATMUL_DESC_FAST_ACCUMなど)。
DotAlgorithm 属性の例:
// Inputs are casted to tf32, and then accumulated in f32:
{lhs_precision_type = tf32,
 rhs_precision_type = tf32,
 accumulation_type = f32,
 lhs_component_count = 1,
 rhs_component_count = 1,
 num_primitive_operations = 1,
 allow_imprecise_accumulation = false}
// bf16_6x: each input is decomposed to 3 bf16 components, then 6 dot operations are done on those components, and the result is accumulated in f32.
{lhs_precision_type = bf16,
 rhs_precision_type = bf16,
 accumulation_type = f32,
 lhs_component_count = 3,
 rhs_component_count = 3,
 num_primitive_operations = 6,
 allow_imprecise_accumulation = false}
// Inputs are (casted to) f8e5m2, and we accumulate in f32, but for some steps we may accumulate in lower precision.
{lhs_precision_type = f8e5m2,
 rhs_precision_type = f8e5m2,
 accumulation_type = f32,
 lhs_component_count = 1,
 rhs_component_count = 1,
 num_primitive_operations = 1,
 allow_imprecise_accumulation = true}
どの組み合わせをサポートするかは、実装が決定します。一般に、各アルゴリズムが StableHLO のコンシューマによって各アクセラレータ タイプでサポートされているとは限りません。特定のアルゴリズムがサポートされていない場合は、代替のアルゴリズムにフォールバックするのではなく、エラーが発生します。StableHLO 検証ではベスト エフォート検証が提供され、どのハードウェアでもサポートされていないアルゴリズムが使用されるのを防ぎます。
サポートされているアルゴリズムの値については、xla_data.proto > Algorithm をご覧ください。チケット #2483 には、バックエンドによってサポートされているアルゴリズムに関する一元化されたドキュメントを作成する計画が記載されています。
入力
| ラベル | 名前 | タイプ | 制約 | 
|---|---|---|---|
| (I1) | lhs | テンソルまたはテンソルごとの量子化テンソル | (C5 ~ C6)、(C9 ~ C10)、(C12 ~ C14)、(C17 ~ C18)、(C20) | 
| (I2) | rhs | テンソルまたは量子化テンソル | (C7~C10)、(C12~C20) | 
| (I3) | lhs_batching_dimensions | si64型の 1 次元テンソル定数 | (C1)、(C3)、(C5)、(C9)、(C12) | 
| (I4) | rhs_batching_dimensions | si64型の 1 次元テンソル定数 | (C1)、(C4)、(C7)、(C9) | 
| (I5) | lhs_contracting_dimensions | si64型の 1 次元テンソル定数 | (C2)、(C3)、(C6)、(C10) | 
| (I6) | rhs_contracting_dimensions | si64型の 1 次元のテンソル定数 | (C2)、(C4)、(C8)、(C10)、(C16) | 
| (I7) | precision_config | DEFAULT、HIGH、HIGHESTの可変長の数値型 | (C11)、(C21) | 
| (I8) | lhs_precision_type | FloatType または TensorFloat32 | (C21) | 
| (I9) | rhs_precision_type | FloatType または TensorFloat32 | (C21) | 
| (I10) | accumulation_type | FloatType または TensorFloat32 | (C21) | 
| (I11) | lhs_component_count | si32型の定数 | (C21)、(C22) | 
| (I12) | rhs_component_count | si32型の定数 | (C21)、(C23) | 
| (I13) | num_primitive_operations | si32型の定数 | (C21)、(C24) | 
| (I14) | allow_imprecise_accumulation | bool型の定数 | (C21) | 
出力
| 名前 | タイプ | 制約 | 
|---|---|---|
| result | テンソルまたは量子化テンソル | (C12)、(C14)、(C18 ~ C20) | 
制約
- (C1)size(lhs_batching_dimensions) = size(rhs_batching_dimensions)。
- (C2)size(lhs_contracting_dimensions) = size(rhs_contracting_dimensions)。
- (C3)is_unique(lhs_batching_dimensions + lhs_contracting_dimensions)。
- (C4)is_unique(rhs_batching_dimensions + rhs_contracting_dimensions)。
- (C5)0 <= lhs_batching_dimensions < rank(lhs)。
- (C6)0 <= lhs_contracting_dimensions < rank(lhs)。
- (C7)0 <= rhs_batching_dimensions < rank(rhs)。
- (C8)0 <= rhs_contracting_dimensions < rank(rhs)。
- (C9)dim(lhs, lhs_batching_dimensions...) = dim(rhs, rhs_batching_dimensions...)。
- (C10)dim(lhs, lhs_contracting_dimensions...) = dim(rhs, rhs_contracting_dimensions...)。
- (C11)size(precision_config) = 2。
- (C12)shape(result) = dim(lhs, lhs_batching_dimensions) + dim(lhs, lhs_result_dimensions) + dim(rhs, rhs_result_dimensions)。
- オペレーションで量子化されていないテンソルを使用する場合:
- (C13)element_type(lhs) = element_type(rhs)。
 
- (C13)
- オペレーションで量子化されたテンソルを使用する場合:
- (C14)is_quantized(lhs) = is_quantized(result) and is_quantized(rhs)。
- (C15)zero_points(rhs) = 0。
- (C16)is_per_axis_quantized(rhs)の場合、quantization_dimension(rhs)がrhs_contracting_dimensionsにない。
- is_quantized(lhs)の場合:
- (C17)storage_type(lhs) = storage_type(rhs)。
- (C18)expressed_type(lhs) = expressed_type(rhs) = expressed_type(result)。
- (C19)is_per_tensor_quantized(rhs)の場合、is_per_tensor_quantized(result)。
- !is_quantized(lhs)の場合:
- (C20)element_type(lhs) = expressed_type(rhs) = element_type(result)。
 
- (C14)
- !is_empty_algorithm(lhs_precision_type, rhs_precision_type, accumulation_type, lhs_component_count, rhs_component_count, num_primitive_operations allow_imprecise_accumulation)の場合:- (C21)precision_config... = DEFAULT。
- (C22)0 < lhs_component_count。
- (C23)0 < rhs_component_count。
- (C24)0 < num_primitive_operations。
 
- (C21)
例
// %lhs: [
//        [[1, 2],
//         [3, 4]],
//        [[5, 6],
//         [7, 8]]
//       ]
// %rhs: [
//        [[1, 0],
//         [0, 1]],
//        [[1, 0],
//         [0, 1]]
//       ]
%result = "stablehlo.dot_general"(%lhs, %rhs) {
  dot_dimension_numbers = #stablehlo.dot<
    lhs_batching_dimensions = [0],
    rhs_batching_dimensions = [0],
    lhs_contracting_dimensions = [2],
    rhs_contracting_dimensions = [1]
  >,
  precision_config = [#stablehlo<precision DEFAULT>, #stablehlo<precision DEFAULT>],
  algorithm = #stablehlo.dot_algorithm<
    lhs_precision_type = tf32,
    rhs_precision_type = tf32,
    accumulation_type = f32,
    lhs_component_count = 1,
    rhs_component_count = 1,
    num_primitive_operations = 1,
    allow_imprecise_accumulation = false
  >
} : (tensor<2x2x2xi64>, tensor<2x2x2xi64>) -> tensor<2x2x2xi64>
// %result: [
//           [[1, 2],
//            [3, 4]],
//           [[5, 6],
//            [7, 8]]
//          ]
dynamic_broadcast_in_dim
セマンティクス
このオペレーションは、broadcast_in_dim オペレーションと機能的には同じですが、結果の形状は output_dimensions を介して動的に指定されます。
このオペレーションでは、オプションの属性 known_expanding_dimensions、known_nonexpanding_dimensions も受け入れ、ディメンションの展開動作に関する静的な知識を表現します。指定しない場合は、すべてのディメンションが拡張可能であると見なされます。
入力
| ラベル | 名前 | タイプ | 制約 | 
|---|---|---|---|
| (I1) | operand | テンソルまたは量子化テンソル | (C1-C2)、(C5-C6)、(C9) | 
| (I2) | output_dimensions | 整数型の 1 次元テンソル | (C7) | 
| (I3) | broadcast_dimensions | 整数型の 1 次元定数テンソル | (C2~C6) | 
| (I4) | known_expanding_dimensions | 整数型の 1 次元定数テンソル | (C8~C9) | 
| (I5) | known_nonexpanding_dimensions | 整数型の 1 次元定数テンソル | (C8~C9) | 
出力
| 名前 | タイプ | 制約 | 
|---|---|---|
| result | テンソルまたは量子化テンソル | (C1)、(C3)、(C5~C7) | 
制約
- (C1)element_type(result)は次のように定義されます。- element_type(operand)(- !is_per_axis_quantized(operand)の場合)。
- element_type(operand)(ただし、- quantization_dimension(operand)、- scales(operand)、- zero_points(operand)は、それぞれ- quantization_dimension(result)、- scales(result)、- zero_points(result)と異なる場合があります)。
 
- (C2)size(broadcast_dimensions) = rank(operand)。
- (C3)0 <= broadcast_dimensions < rank(result)。
- (C4)is_unique(broadcast_dimensions)。
- (C5)axes(operand)内のすべてのdについて:- dim(operand, d) = 1または
- dim(operand, d) = dim(result, broadcast_dimensions[d])。
 
- (C6)is_per_axis_quantized(result)の場合:- quantization_dimension(result) = broadcast_dimensions[quantization_dimension(operand)]。
- dim(operand, quantization_dimension(operand)) = 1の場合は- scales(result)[i] = scales(operand)[0] and zero_points(result)[i] = zero_points(operand)[0] for i in range(dim(result, quantization_dimension(result)))です。
 
- (C7)size(output_dimensions) = rank(result)。
- (C8)is_unique(known_expanding_dimensions + known_nonexpanding_dimensions)。
- (C9)0 <= known_expanding_dimensions < rank(operand)。
- (C10)0 <= known_nonexpanding_dimensions < rank(operand)。
例
// %operand: [
//            [1, 2, 3]
//           ]
%operand = stablehlo.constant dense<[[1, 2, 3]]> : tensor<1x3xi64>
%output_dimensions = stablehlo.constant dense<[2, 3, 2]> : tensor<3xi64>
%result = "stablehlo.dynamic_broadcast_in_dim"(%operand, %output_dimensions) {
  broadcast_dimensions = array<i64: 2, 1>,
  known_expanding_dimensions = array<i64: 0>,
  known_nonexpanding_dimensions = array<i64: 1>
} : (tensor<1x3xi64>, tensor<3xi64>) -> tensor<2x3x2xi64>
// %result: [
//            [
//             [1, 1],
//             [2, 2],
//             [3, 3]
//            ],
//            [
//             [1, 1],
//             [2, 2],
//             [3, 3]
//            ]
//          ]
dynamic_conv
セマンティクス
このオペレーションは、畳み込みオペレーションと機能的に同じですが、パディングは padding を介して動的に指定されます。
入力
| ラベル | 名前 | タイプ | 制約 | 
|---|---|---|---|
| (I1) | lhs | テンソルまたはテンソルごとの量子化テンソル | (C1)、(C10 ~ C11)、(C14)(C25)、(C26 ~ C27)、(C30 ~ C31)、(C33) | 
| (I2) | rhs | テンソルまたは量子化テンソル | (C1)、(C14 ~ C16)、(C26 ~ C28)、(C30 ~ C33) | 
| (I3) | padding | 整数型の 2 次元テンソル | (C4) | 
| (I4) | window_strides | si64型の 1 次元テンソル定数 | (C2 ~ C3) | 
| (I5) | lhs_dilation | si64型の 1 次元のテンソル定数 | (C5-C6) | 
| (I6) | rhs_dilation | si64型の 1 次元テンソル定数 | (C7 ~ C8) | 
| (I7) | window_reversal | i1型の 1 次元テンソル定数 | (C9) | 
| (I8) | input_batch_dimension | si64型の定数 | (C10)、(C13) | 
| (I9) | input_feature_dimension | si64型の定数 | (C11)、(C13~C14) | 
| (I10) | input_spatial_dimensions | si64型の 1 次元テンソル定数 | (C12)、(C13) | 
| (I11) | kernel_input_feature_dimension | si64型の定数 | (C14)、(C18) | 
| (I12) | kernel_output_feature_dimension | si64型の定数 | (C15~C16)、(C18)、(C28) | 
| (I13) | kernel_spatial_dimensions | si64型の 1 次元テンソル定数 | (C17~C18) | 
| (I14) | output_batch_dimension | si64型の定数 | (C20) | 
| (I15) | output_feature_dimension | si64型の定数 | (C20)、(C29) | 
| (I16) | output_spatial_dimensions | si64型の 1 次元のテンソル定数 | (C19 ~ C20) | 
| (I17) | feature_group_count | si64型の定数 | (C11)、(C14)、(C16)、(C21)、(C23) | 
| (I18) | batch_group_count | si64型の定数 | (C10)、(C15)、(C22)、(C23) | 
| (I19) | precision_config | DEFAULT、HIGH、HIGHESTの可変列挙型の数 | (C24) | 
出力
| 名前 | タイプ | 制約 | 
|---|---|---|
| result | テンソルまたは量子化テンソル | (C25~C27)、(C29)、(C31~C33) | 
制約
- (C1)N = rank(lhs) = rank(rhs)。
- (C2)size(window_strides) = N - 2。
- (C3)0 < window_strides。
- (C4)shape(padding) = [N - 2, 2]。
- (C5)size(lhs_dilation) = N - 2。
- (C6)0 < lhs_dilation。
- (C7)size(rhs_dilation) = N - 2。
- (C8)0 < rhs_dilation。
- (C9)size(window_reversal) = N - 2。
- (C10)dim(lhs, input_batch_dimension) % batch_group_count = 0。
- (C11)dim(lhs, input_feature_dimension) % feature_group_count = 0。
- (C12)size(input_spatial_dimensions) = N - 2。
- (C13)input_dimensions = [input_batch_dimension] + input_spatial_dimensions + [input_feature_dimension]が与えられている場合:- is_unique(input_dimensions)。
- 0 <= input_dimensions < N。
 
- (C14)dim(rhs, kernel_input_feature_dimension) = dim(lhs, input_feature_dimension) / feature_group_count。
- (C15)dim(rhs, kernel_output_feature_dimension) % batch_group_count = 0。
- (C16)dim(rhs, kernel_output_feature_dimension) % feature_group_count = 0。
- (C17)size(kernel_spatial_dimensions) = N - 2。
- (C18)kernel_dimensions = kernel_spatial_dimensions + [kernel_input_feature_dimension] + [kernel_output_feature_dimension]が与えられている場合:- is_unique(kernel_dimensions)。
- 0 <= kernel_dimensions < N。
 
- (C19)size(output_spatial_dimensions) = N - 2。
- (C20)output_dimensions = [output_batch_dimension] + output_spatial_dimensions + [output_feature_dimension]が与えられている場合:- is_unique(output_dimensions)。
- 0 <= output_dimensions < N。
 
- (C21)0 < feature_group_count。
- (C22)0 < batch_group_count。
- (C23)feature_group_count = 1 or batch_group_count = 1。
- (C24)size(precision_config) = 2。
- (C25)dim(result, result_dim)は次のように定義されます。- result_dim = output_batch_dimensionの場合、- dim(lhs, input_batch_dimension) / batch_group_count。
- result_dim = output_feature_dimensionの場合、- dim(rhs, kernel_output_feature_dimension)。
- それ以外の場合は num_windows。ここで:
- output_spatial_dimensions[spatial_dim] = result_dim。
- lhs_dim = input_spatial_dimensions[spatial_dim]
- rhs_dim = kernel_spatial_dimensions[spatial_dim]
- dilated_input_shape[lhs_dim] = dim(lhs, lhs_dim) = 0 ? 0 : (dim(lhs, lhs_dim) - 1) * lhs_dilation[spatial_dim] + 1
- padded_input_shape[lhs_dim] = padding[spatial_dim, 0] + dilated_input_shape[lhs_dim] + padding[spatial_dim, 1]
- dilated_window_shape[lhs_dim] = dim(rhs, rhs_dim) = 0 ? 0 : (dim(rhs, rhs_dim) - 1) * rhs_dilation[spatial_dim] + 1
- is_empty_window[lhs_dim] = padded_input_shape[lhs_dim] = 0 || dilated_window_shape[lhs_dim] > padded_input_shape[lhs_dim]
- num_windows = is_empty_window[lhs_dim] ? 0 : floor((padded_input_shape[lhs_dim] - dilated_window_shape[lhs_dim]) / window_strides[spatial_dim]) + 1。
 
- (C26)rank(result) = N。
- オペレーションで量子化されていないテンソルを使用する場合:
- (C27)element_type(lhs) = element_type(rhs) = element_type(result)。
 
- (C27)
- オペレーションで量子化されたテンソルを使用する場合:
- (C28)is_quantized(lhs) = is_quantized(result) and is_quantized(rhs)。
- (C29)is_per_axis_quantized(rhs)の場合、quantization_dimension(rhs) = kernel_output_feature_dimension。
- (C30)is_per_axis_quantized(result)の場合、quantization_dimension(result) = output_feature_dimension。
- is_quantized(lhs)の場合:
- (C31)storage_type(lhs) = storage_type(rhs)。
- (C32)expressed_type(lhs) = expressed_type(rhs) = expressed_type(result)。
- (C33)is_per_tensor_quantized(rhs)の場合はis_per_tensor_quantized(result)です。
- !is_quantized(lhs)の場合:
- (C34)element_type(lhs) = expressed_type(rhs) = element_type(result)。
 
- (C28)
例
// %lhs: [[
//        [[1], [2], [5], [6]],
//        [[3], [4], [7], [8]],
//        [[10], [11], [14], [15]],
//        [[12], [13], [16], [17]]
//      ]]
//
// %rhs: [
//         [[[1]], [[1]], [[1]]],
//         [[[1]], [[1]], [[1]]],
//         [[[1]], [[1]], [[1]]]
//        ]
// %padding: [[1, 1],
//            [1, 1]]
%result = "stablehlo.dynamic_conv"(%lhs, %rhs, %padding) {
  window_strides = array<i64: 4, 4>,
  lhs_dilation = array<i64: 2, 2>,
  rhs_dilation = array<i64: 1, 1>,
  window_reversal = array<i1: false, false>,
  dimension_numbers = #stablehlo.conv<raw
    input_batch_dimension = 0,
    input_feature_dimension = 3,
    input_spatial_dimensions = [0, 1],
    kernel_input_feature_dimension = 2,
    kernel_output_feature_dimension = 3,
    kernel_spatial_dimensions = [0, 1],
    output_batch_dimension = 0,
    output_feature_dimension = 3,
    output_spatial_dimensions = [1, 2]
  >,
  feature_group_count = 1 : i64,
  batch_group_count = 1 : i64,
  precision_config = [#stablehlo<precision DEFAULT>, #stablehlo<precision DEFAULT>]
} : (tensor<1x4x4x1xi64>, tensor<3x3x1x1xi64>, tensor<2x2xi64>) -> tensor<1x2x2x1xi64>
// %result: [[
//            [[1], [5]],
//            [[10], [14]]
//          ]]
dynamic_gather
セマンティクス
このオペレーションは、slice_sizes が値として動的に指定される gather オペレーションと機能的に同じです。
入力
| ラベル | 名前 | タイプ | 制約 | 
|---|---|---|---|
| (I1) | operand | テンソルまたはテンソルごとの量子化テンソル | (C1)、(C7)、(C10~C12)、(C14) | 
| (I2) | start_indices | 整数型のテンソル | (C2)、(C3)、(C13) | 
| (I3) | slice_sizes | 整数型の 1 次元テンソル | (C8)、(C11~C13) | 
| (I4) | offset_dims | si64型の 1 次元のテンソル定数 | (C1)、(C4 ~ C5)、(C13) | 
| (I5) | collapsed_slice_dims | si64型の 1 次元テンソル定数 | (C1)、(C6~C8)、(C13) | 
| (I6) | start_index_map | si64型の 1 次元のテンソル定数 | (C3)、(C9)、(C10) | 
| (I7) | index_vector_dim | si64型の定数 | (C2)、(C3)、(C13) | 
| (I8) | indices_are_sorted | i1型の定数 | 
出力
| 名前 | タイプ | 制約 | 
|---|---|---|
| result | テンソルまたはテンソルごとの量子化テンソル | (C5)、(C13 ~ C14) | 
制約
- (C1)rank(operand) = size(offset_dims) + size(collapsed_slice_dims)。
- (C2)0 <= index_vector_dim <= rank(start_indices)。
- (C3)size(start_index_map) = index_vector_dim < rank(start_indices) ? dim(start_indices, index_vector_dim) : 1。
- (C4)is_unique(offset_dims) and is_sorted(offset_dims)。
- (C5)0 <= offset_dims < rank(result)。
- (C6)is_unique(collapsed_slice_dims) and is_sorted(collapsed_slice_dims)。
- (C7)0 <= collapsed_slice_dims < rank(operand)。
- (C8)slice_sizes[collapsed_slice_dims...] <= 1。
- (C9)is_unique(start_index_map)。
- (C10)0 <= start_index_map < rank(operand)。
- (C11)size(slice_sizes) = rank(operand)。
- (C12)0 <= slice_sizes <= shape(operand)。
- (C13)shape(result) = combine(batch_dim_sizes, offset_dim_sizes)ここで:- batch_dim_sizes = shape(start_indices)とほぼ同じですが、- index_vector_dimに対応する- start_indicesのディメンション サイズは含まれません。
- offset_dim_sizes = shape(slice_sizes)。ただし、- collapsed_slice_dimsに対応する- slice_sizesのディメンション サイズは含まれません。
- combineは、- batch_dimsに対応する軸に- batch_dim_sizesを配置し、- offset_dimsに対応する軸に- offset_dim_sizesを配置します。
 
- (C14)element_type(operand) = element_type(result)。
例
// %operand: [
//            [[1, 2], [3, 4], [5, 6], [7, 8]],
//            [[9, 10],[11, 12], [13, 14], [15, 16]],
//            [[17, 18], [19, 20], [21, 22], [23, 24]]
//           ]
// %start_indices: [
//                  [[0, 0], [1, 0], [2, 1]],
//                  [[0, 1], [1, 1], [0, 2]]
//                 ]
// %slize_sizes: [1, 2, 2]
%result = "stablehlo.dynamic_gather"(%operand, %start_indices, %slize_sizes) {
  dimension_numbers = #stablehlo.gather<
    offset_dims = [2, 3],
    collapsed_slice_dims = [0],
    start_index_map = [1, 0],
    index_vector_dim = 2>,
  indices_are_sorted = false
} : (tensor<3x4x2xi64>, tensor<2x3x2xi64>, tensor<3xi64>) -> tensor<2x3x2x2xi64>
// %result: [
//            [
//              [[1, 2], [3, 4]],
//              [[3, 4], [5, 6]],
//              [[13, 14], [15, 16]]
//            ],
//            [
//              [[9, 10], [11, 12]],
//              [[11, 12], [13, 14]],
//              [[17, 18], [19, 20]]
//            ]
//          ]
dynamic_iota
セマンティクス
この演算は iota op と機能的に同じですが、結果の形状は output_shape によって動的に指定されます。
入力
| ラベル | 名前 | タイプ | 制約 | 
|---|---|---|---|
| (I1) | output_shape | 整数型の 1 次元テンソル | (C1)、(C2) | 
| (I2) | iota_dimension | si64 | (C1) | 
出力
| 名前 | タイプ | 制約 | 
|---|---|---|
| result | 整数型、浮動小数点型、複合型のテンソル、またはテンソルごとの量子化テンソル | (C2) | 
制約
- (C1)0 <= iota_dimension < size(output_shape)。
- (C2)rank(result) = size(output_shape)。
例
%output_shape = stablehlo.constant dense<[4, 5]> : tensor<2xi64>
%result = "stablehlo.dynamic_iota"(%output_shape) {
  iota_dimension = 0 : i64
} : (tensor<2xi64>) -> tensor<4x5xi64>
// %result: [
//           [0, 0, 0, 0, 0],
//           [1, 1, 1, 1, 1],
//           [2, 2, 2, 2, 2],
//           [3, 3, 3, 3, 3]
//          ]
dynamic_pad
セマンティクス
このオペレーションは、pad オペレーションと機能的には同じですが、edge_padding_low、edge_padding_high、interior_padding が値として動的に指定されます。
入力
| ラベル | 名前 | タイプ | 制約 | 
|---|---|---|---|
| (I1) | operand | テンソルまたはテンソルごとの量子化テンソル | (C1)、(C2)、(C4) | 
| (I2) | padding_value | 0 次元テンソルまたはテンソルごとの量子化テンソル | (C1) | 
| (I3) | edge_padding_low | 整数型の 1 次元テンソル | (C1)、(C4) | 
| (I4) | edge_padding_high | 整数型の 1 次元テンソル | (C1)、(C4) | 
| (I5) | interior_padding | 整数型の 1 次元テンソル | (C2-C4) | 
出力
| 名前 | タイプ | 制約 | 
|---|---|---|
| result | テンソルまたはテンソルごとの量子化テンソル | (C3-C6) | 
制約
- (C1)element_type(operand) = element_type(padding_value) = element_type(result)。
- (C2)size(edge_padding_low) = size(edge_padding_high) = size(interior_padding) = rank(operand)。
- (C3)0 <= interior_padding。
- (C4)shape(result) = shape(operand) + edge_padding_low + max(shape(operand) - 1, 0) * interior_padding + edge_padding_high。
例
// %operand: [
//            [1, 2, 3],
//            [4, 5, 6]
//           ]
// %padding_value: 0
// %edge_padding_low: [0, 1]
// %edge_padding_high: [2, 1]
// %interior_padding: [1, 2]
%result = "stablehlo.dynamic_pad"(%operand, %padding_value,
  %edge_padding_low, %edge_padding_high, %interior_padding
) : (tensor<2x3xi64>, tensor<i64>, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor<5x9xi64>
// %result: [
//           [0, 1, 0, 0, 2, 0, 0, 3, 0],
//           [0, 0, 0, 0, 0, 0, 0, 0, 0],
//           [0, 4, 0, 0, 5, 0, 0, 6, 0],
//           [0, 0, 0, 0, 0, 0, 0, 0, 0],
//           [0, 0, 0, 0, 0, 0, 0, 0, 0]
//          ]
dynamic_reshape
セマンティクス
このオペレーションは、reshape オペレーションと機能的には同じですが、結果の形状は output_shape を介して動的に指定されます。
入力
| ラベル | 名前 | タイプ | 制約 | 
|---|---|---|---|
| (I1) | operand | テンソルまたは量子化テンソル | (C1~C3) | 
| (I2) | output_shape | 整数型の 1 次元テンソル | (C4) | 
出力
| 名前 | タイプ | 制約 | 
|---|---|---|
| result | テンソルまたは量子化テンソル | (C1~C4) | 
制約
- (C1)element_type(result)は次のように定義されます。- element_type(operand)(- !is_per_axis_quantized(operand)の場合)。
- element_type(operand)。ただし、- quantization_dimension(operand)と- quantization_dimension(result)が異なる場合があります。
 
- (C2)size(operand) = size(result)。
- (C3)is_per_axis_quantized(operand)の場合:- reduce(dims(operand, [0, 1, ..., quantization_dimension(operand) - 1]), init_values=1, dimensions=[0], body=lambda x, y: x * y) = reduce(dims(result, [0, 1, ..., quantization_dimension(result) - 1]), init_values=1, dimensions=[0], body=lambda x, y: x * y)。
- dim(operand, quantization_dimension(operand)) = dim(result, quantization_dimension(result))
- reduce(dims(operand, [quantization_dimension(operand) + 1, ..., rank(operand) - 1]), init_values=1, dimensions=[0], body=lambda x, y: x * y) = reduce(dims(result, [quantization_dimension(result) + 1, ..., rank(result) - 1]), init_values=1, dimensions=[0], body=lambda x, y: x * y)。
 
- (C4)size(output_shape) = rank(result)。
例
// %operand: [[1, 2, 3], [4, 5, 6]]
// %output_shape: [3, 2]
%result = "stablehlo.dynamic_reshape"(%operand, %output_shape) : (tensor<2x3xi64>, tensor<2xi64>) -> tensor<3x2xi64>
// %result: [[1, 2], [3, 4], [5, 6]]
dynamic_slice
セマンティクス
動的に計算された開始インデックスを使用して operand からスライスを取り出し、result テンソルを生成します。start_indices には、調整の対象となる各ディメンションのスライスの開始インデックスが含まれ、slice_sizes には各ディメンションのスライスのサイズが含まれます。より正式には、result[result_index] = operand[operand_index] です。ここで:
- adjusted_start_indices = clamp(0, start_indices, shape(operand) - slice_sizes)。
- operand_index = adjusted_start_indices + result_index。
入力
| ラベル | 名前 | タイプ | 制約 | 
|---|---|---|---|
| (I1) | operand | テンソルまたはテンソルごとの量子化テンソル | (C1)、(C2)、(C4) | 
| (I2) | start_indices | 整数型の 0 次元テンソルの可変長の数 | (C2)、(C3) | 
| (I3) | slice_sizes | si64型の 1 次元テンソル定数 | (C2)、(C4)、(C5) | 
出力
| 名前 | タイプ | 制約 | 
|---|---|---|
| result | テンソルまたはテンソルごとの量子化テンソル | (C1)、(C5) | 
制約
- (C1)element_type(operand) = element_type(result)。
- (C2)size(start_indices) = size(slice_sizes) = rank(operand)。
- (C3)same(type(start_indices...))。
- (C4)0 <= slice_sizes <= shape(operand)。
- (C5)shape(result) = slice_sizes。
例
// %operand: [
//            [0, 0, 1, 1],
//            [0, 0, 1, 1],
//            [0, 0, 0, 0],
//            [0, 0, 0, 0]
//           ]
// %start_indices0: -1
// %start_indices1: 3
%result = "stablehlo.dynamic_slice"(%operand, %start_indices0, %start_indices1) {
  slice_sizes = array<i64: 2, 2>
} : (tensor<4x4xi32>, tensor<i64>, tensor<i64>) -> tensor<2x2xi32>
// %result: [
//           [1, 1],
//           [1, 1]
//          ]
dynamic_update_slice
セマンティクス
start_indices で始まるスライスが update の値で更新される点を除き、operand テンソルと同じ result テンソルを生成します。より正式には、result[result_index] は次のように定義されます。
- update[update_index]:- 0 <= update_index < shape(update)の場合。条件は次のとおりです。- adjusted_start_indices = clamp(0, start_indices, shape(operand) - shape(update))。
- update_index = result_index - adjusted_start_indices。
 
- それ以外の場合は operand[result_index]。
入力
| ラベル | 名前 | タイプ | 制約 | 
|---|---|---|---|
| (I1) | operand | テンソルまたはテンソルごとの量子化テンソル | (C1 ~ C4)、(C6) | 
| (I2) | update | テンソルまたはテンソルごとの量子化テンソル | (C2)、(C3)、(C6) | 
| (I3) | start_indices | 整数型の 0 次元テンソルの可変長の数 | (C4)、(C5) | 
出力
| 名前 | タイプ | 制約 | 
|---|---|---|
| result | テンソルまたはテンソルごとの量子化テンソル | (C1) | 
制約
- (C1)type(operand) = type(result)。
- (C2)element_type(update) = element_type(operand)。
- (C3)rank(update) = rank(operand)。
- (C4)size(start_indices) = rank(operand)。
- (C5)same(type(start_indices...))。
- (C6)0 <= shape(update) <= shape(operand)。
例
// %operand: [
//            [1, 1, 0, 0],
//            [1, 1, 0, 0],
//            [1, 1, 1, 1],
//            [1, 1, 1, 1]
//           ]
// %update: [
//           [1, 1],
//           [1, 1]
//          ]
// %start_indices0: -1
// %start_indices1: 3
%result = "stablehlo.dynamic_update_slice"(%operand, %update, %start_indices0, %start_indices1)
  : (tensor<4x4xi32>, tensor<2x2xi32>, tensor<i64>, tensor<i64>) -> tensor<4x4xi32>
// %result: [
//           [1, 1, 1, 1],
//           [1, 1, 1, 1],
//           [1, 1, 1, 1],
//           [1, 1, 1, 1]
//          ]
指数
セマンティクス
operand テンソルの要素ごとの指数演算を実行し、result テンソルを生成します。要素のタイプに応じて、次の処理を行います。
- 浮動小数点数の場合: IEEE-754 の exp。
- 複素数の場合: 複素指数関数。
- 量子化された型の場合: dequantize_op_quantize(exponential, operand, type(result))。
入力
| ラベル | 名前 | タイプ | 制約 | 
|---|---|---|---|
| (I1) | operand | 浮動小数点型または複合型のテンソル、またはテンソルごとの量子化テンソル | (C1) | 
出力
| 名前 | タイプ | 制約 | 
|---|---|---|
| result | 浮動小数点型または複合型のテンソル、またはテンソルごとの量子化テンソル | (C1) | 
制約
- (C1)baseline_type(operand) = baseline_type(result)。
例
// %operand: [[0.0, 1.0], [2.0, 3.0]]
%result = "stablehlo.exponential"(%operand) : (tensor<2x2xf64>) -> tensor<2x2xf64>
// %result: [[1.0, 2.7182818284590451], [7.3890560989306504, 20.085536923187668]]
exponential_minus_one
セマンティクス
operand テンソルに対して要素ごとの指数マイナス 1 演算を実行し、result テンソルを生成します。要素のタイプに応じて、次の操作を行います。
- 浮動小数点数: IEEE-754 の expm1。
- 複素数の場合: 複素指数 - 1。
- 量子化された型の場合: dequantize_op_quantize(exponential_minus_one, operand, type(result))。
入力
| ラベル | 名前 | タイプ | 制約 | 
|---|---|---|---|
| (I1) | operand | 浮動小数点型または複合型のテンソル、またはテンソルごとの量子化テンソル | (C1) | 
出力
| 名前 | タイプ | 制約 | 
|---|---|---|
| result | 浮動小数点型または複合型のテンソル、またはテンソルごとの量子化テンソル | (C1) | 
制約
- (C1)baseline_type(operand) = baseline_type(result)。
例
// %operand: [0.0, 1.0]
%result = "stablehlo.exponential_minus_one"(%operand) : (tensor<2xf64>) -> tensor<2xf64>
// %result: [0.0, 1.71828187]
fft
セマンティクス
実数と複素数の入力 / 出力の正規化と逆正規化を行います。
fft_type は、次のいずれかです。
- FFT: 複雑な FFT を転送します。
- IFFT: 複素数から複素数への逆 FFT。
- RFFT: 実数から複素数への FFT を前方処理します。
- IRFFT: 実数から複素数への逆 FFT(複素数を受け取り、実数を返します)。
より形式的には、複雑な型の 1 次元テンソルを入力として受け取る関数 fft は、出力と同じ型の 1 次元テンソルを生成し、離散フーリエ変換を計算します。
fft_type = FFT の場合、result は、L = size(fft_length) の連続した L 計算の最終結果として定義されます。たとえば、L = 3 の場合は次のようになります。
- result1[i0, ..., :] = fft(operand[i0, ..., :])。
- result2[i0, ..., :, iR-1] = fft(result1[i0, ..., :, iR-1])
- result[i0, ..., :, iR-2, iR-1] = fft(result2[i0, ..., :, iR-2, iR-1])。
さらに、同じ型シグネチャを持ち、fft の逆数を計算する関数 ifft があるとします。
fft_type = IFFT の場合、result は fft_type = FFT の計算の逆数として定義されます。たとえば、L = 3 の場合は次のようになります。
- result1[i0, ..., :, iR-2, iR-1] = ifft(operand[i0, ..., :, iR-2, iR-1])。
- result2[i0, ..., :, iR-1] = ifft(result1[i0, ..., :, iR-1])
- result[i0, ..., :] = ifft(result2[i0, ..., :])。
さらに、浮動小数点型の 1 次元テンソルを取る関数 rfft は、同じ浮動小数点セマンティクスの複合型の 1 次元テンソルを生成し、次のように動作します。
- rfft(real_operand) = truncated_result
- complex_operand... = (real_operand..., 0.0)。
- complex_result = fft(complex_operand)
- truncated_result = complex_result[:(rank(complex_result) / 2 + 1)]。
(実数オペランドに対して離散フーリエ変換が計算される場合、結果の最初の N/2 + 1 要素が結果の残りの部分を明確に定義するため、冗長な要素の計算を避けるために rfft の結果は切り捨てられます)。
fft_type = RFFT の場合、result は、L = size(fft_length) の連続した L 計算の最終結果として定義されます。たとえば、L = 3 の場合は次のようになります。
- result1[i0, ..., :] = rfft(operand[i0, ..., :])。
- result2[i0, ..., :, iR-1] = fft(result1[i0, ..., :, iR-1])
- result[i0, ..., :, iR-2, iR-1] = fft(result2[i0, ..., :, iR-2, iR-1])。
最後に、同じ型のシグネチャを持つ関数 irfft が与えられ、rfft の逆数を計算します。
fft_type = IRFFT の場合、result は fft_type = RFFT の計算の逆数として定義されます。たとえば、L = 3 の場合は次のようになります。
- result1[i0, ..., :, iR-2, iR-1] = ifft(operand[i0, ..., :, iR-2, iR-1])。
- result2[i0, ..., :, iR-1] = ifft(result1[i0, ..., :, iR-1])
- result[i0, ..., :] = irfft(result2[i0, ..., :])。
入力
| ラベル | 名前 | タイプ | 制約 | 
|---|---|---|---|
| (I1) | operand | 浮動小数点型または複合型のテンソル | (C1)、(C2)、(C4)、(C5) | 
| (I2) | fft_type | FFT、IFFT、RFFT、IRFFTの列挙型 | (C2)、(C5) | 
| (I3) | fft_length | si64型の 1 次元のテンソル定数 | (C1)、(C3)、(C4) | 
出力
| 名前 | タイプ | 制約 | 
|---|---|---|
| result | 浮動小数点数または複素数のテンソル | (C2)、(C4)、(C5) | 
制約
- (C1)size(fft_length) <= rank(operand)。
- (C2)operand要素とresult要素型の関係はさまざまです。- fft_type = FFT、- element_type(operand)、- element_type(result)が同じ複合型の場合。
- fft_type = IFFT、- element_type(operand)、- element_type(result)が同じ複合型の場合。
- fft_type = RFFTの場合、- element_type(operand)は浮動小数点型で、- element_type(result)は同じ浮動小数点セマンティクスの複合型です。
- fft_type = IRFFTの場合、- element_type(operand)は複合型であり、- element_type(result)は同じ浮動小数点セマンティクスの浮動小数点型です。
 
- (C3)1 <= size(fft_length) <= 3。
- (C4)operandとresultの間に浮動小数点型のテンサーrealがある場合、shape(real)[-size(fft_length):] = fft_length。
- (C5)次の点を除き shape(result) = shape(operand)。- fft_type = RFFTの場合は- dim(result, -1) = dim(operand, -1) = 0 ? 0 : dim(operand, -1) / 2 + 1。
- fft_type = IRFFTの場合は- dim(operand, -1) = dim(result, -1) = 0 ? 0 : dim(result, -1) / 2 + 1。
 
例
// %operand: [(1.0, 0.0), (0.0, 0.0), (0.0, 0.0), (0.0, 0.0)]
%result = "stablehlo.fft"(%operand) {
  fft_type = #stablehlo<fft_type FFT>,
  fft_length = array<i64: 4>
} : (tensor<4xcomplex<f32>>) -> tensor<4xcomplex<f32>>
// %result: [(1.0, 0.0), (1.0, 0.0), (1.0, 0.0), (1.0, 0.0)]
floor
セマンティクス
operand テンソルの要素ごとの床演算を行い、result テンソルを生成します。IEEE-754 仕様の roundToIntegralTowardNegative オペレーションを実装します。量子化型の場合は、dequantize_op_quantize(floor, operand, type(result)) を実行します。
入力
| ラベル | 名前 | タイプ | 制約 | 
|---|---|---|---|
| (I1) | operand | 浮動小数点型のテンソルまたはテンソルごとの量子化テンソル | (C1) | 
出力
| 名前 | タイプ | 制約 | 
|---|---|---|
| result | 浮動小数点型のテンソルまたはテンソルごとの量子化テンソル | (C1) | 
制約
- (C1)baseline_type(operand) = baseline_type(result)。
例
// %operand: [-0.8166, -0.2530, 0.2530, 0.8166, 2.0]
%result = "stablehlo.floor"(%operand) : (tensor<5xf32>) -> tensor<5xf32>
// %result: [-1.0, -1.0, 0.0, 0.0, 2.0]
集める
セマンティクス
start_indices で指定されたオフセットから operand テンソルのスライスを収集し、result テンソルを生成します。
次の図は、result の要素が operand の要素にどのようにマッピングされるかを、具体的な例を使用して示しています。この図では、いくつかの result インデックスの例を選択し、それらが対応する operand インデックスについて詳しく説明しています。
より正式には、result[result_index] = operand[operand_index] です。ここで:
- batch_dims = [d for d in axes(result) and d not in offset_dims]。
- batch_index = result_index[batch_dims...]。
- start_indexは次のように定義されます。- start_indices[bi0, ..., :, ..., biN]。ここで、- biは- batch_index内の個々の要素で、- :は- index_vector_dimインデックスに挿入されます(- index_vector_dim<- rank(start_indices)の場合)。
- それ以外の場合は [start_indices[batch_index]]。
 
- axes(operand)の- d_operandは、- d_operand = start_index_map[d_start]の場合、- full_start_index[d_operand] = clamp(start_index[d_start], 0, dim(operand, d_operand) - slice_sizes[d_operand])
- そうでない場合は full_start_index[d_operand] = 0。
 
- axes(operand)の- d_operandは、- full_batching_index[d_operand] = batch_index[d_start - (d_start < index_vector_dim ? 0 : 1)]- d_operand = operand_batching_dims[i_batching]と- d_start = start_indices_batching_dims[i_batching]の場合。
- それ以外の場合は full_batching_index[d_operand] = 0。
 
- offset_index = result_index[offset_dims...]。
- full_offset_index = [oi0, ..., 0, ..., oiN]。ここで、- oiは- offset_indexの個々の要素であり、- 0は- collapsed_slice_dimsと- operand_batching_dimsのインデックスに挿入されます。
- operand_index = full_start_index + full_batching_index + full_offset_index。
indices_are_sorted が true の場合、実装では start_indices が start_index_map に関して並べ替えられていると想定できます。それ以外の場合、動作は未定義です。より正式には、indices(result) のすべての i1 < i2 について、full_start_index(i1) <= full_start_index(i2) です。
入力
| ラベル | 名前 | タイプ | 制約 | 
|---|---|---|---|
| (I1) | operand | テンソルまたはテンソルごとの量子化テンソル | (C1)、(C8)、(C11)、(C17)、(C19 ~ C21)、(C23) | 
| (I2) | start_indices | 整数型のテンソル | (C2~C3)、(C14)、(C17)、(C22) | 
| (I3) | offset_dims | si64型の 1 次元テンソル定数 | (C1)、(C4 ~ C5)、(C22) | 
| (I4) | collapsed_slice_dims | si64型の 1 次元のテンソル定数 | (C1)、(C6 ~ C9)、(C22) | 
| (I5) | operand_batching_dims | si64型の 1 次元テンソル定数 | (C1)、(C6)、(C10 ~ C12)、(C16 ~ C18)、(C22) | 
| (I6) | start_indices_batching_dims | si64型の 1 次元テンソル定数 | (C13 ~ C17) | 
| (I7) | start_index_map | si64型の 1 次元テンソル定数 | (C3)、(C18~C19) | 
| (I8) | index_vector_dim | si64型の定数 | (C2 ~ C3)、(C15)、(C22) | 
| (I9) | slice_sizes | si64型の 1 次元のテンソル定数 | (C9)、(C12)、(C20~C22) | 
| (I10) | indices_are_sorted | i1型の定数 | 
出力
| 名前 | タイプ | 制約 | 
|---|---|---|
| result | テンソルまたはテンソルごとの量子化テンソル | (C5)、(C22 ~ C23) | 
制約
- (C1)rank(operand) = size(offset_dims) + size(collapsed_slice_dims) + size(operand_batching_dims)。
- (C2)0 <= index_vector_dim <= rank(start_indices)。
- (C3)size(start_index_map) = index_vector_dim < rank(start_indices) ? dim(start_indices, index_vector_dim) : 1。
- (C4)is_unique(offset_dims) and is_sorted(offset_dims)。
- (C5)0 <= offset_dims < rank(result)。
- (C6)is_unique(concatenate(collapsed_slice_dims, operand_batching_dims))
- (C7)is_sorted(collapsed_slice_dims)。
- (C8)0 <= collapsed_slice_dims < rank(operand)。
- (C9)slice_sizes[collapsed_slice_dims...] <= 1。
- (C10)is_sorted(operand_batching_dims)。
- (C11)0 <= operand_batching_dims < rank(operand)。
- (C12)slice_sizes[operand_batching_dims...] <= 1。
- (C13)is_unique(start_indices_batching_dims)。
- (C14)0 <= start_indices_batching_dims < rank(start_indices)。
- (C15)index_vector_dim not in start_indices_batching_dims。
- (C16)size(operand_batching_dims) == size(start_indices_batching_dims)。
- (C17)dim(operand, operand_batching_dims...) = dim(start_indices, start_indices_batching_dims...)。
- (C18)is_unique(concatenate(start_index_map, operand_batching_dims))。
- (C19)0 <= start_index_map < rank(operand)。
- (C20)size(slice_sizes) = rank(operand)。
- (C21)0 <= slice_sizes <= shape(operand)。
- (C22)shape(result) = combine(batch_dim_sizes, offset_dim_sizes)。ここで:- batch_dim_sizes = shape(start_indices)とほぼ同じですが、- index_vector_dimに対応する- start_indicesのディメンション サイズは含まれません。
- offset_dim_sizes = slice_sizes。ただし、- collapsed_slice_dimsと- operand_batching_dimsに対応する- slice_sizesのディメンション サイズは含まれません。
- combineは、- batch_dimsに対応する軸に- batch_dim_sizesを配置し、- offset_dimsに対応する軸に- offset_dim_sizesを配置します。
 
- (C23)element_type(operand) = element_type(result)。
例
// %operand: [
//            [
//             [[1, 2], [3, 4], [5, 6], [7, 8]],
//             [[9, 10],[11, 12], [13, 14], [15, 16]],
//             [[17, 18], [19, 20], [21, 22], [23, 24]]
//            ],
//            [
//             [[25, 26], [27, 28], [29, 30], [31, 32]],
//             [[33, 34], [35, 36], [37, 38], [39, 40]],
//             [[41, 42], [43, 44], [45, 46], [47, 48]]
//            ]
//           ]
// %start_indices: [
//                  [
//                   [[0, 0], [1, 0], [2, 1]],
//                   [[0, 1], [1, 1], [0, 9]]
//                  ],
//                  [
//                   [[0, 0], [2, 1], [2, 2]],
//                   [[1, 2], [0, 1], [1, 0]]
//                  ]
//                 ]
%result = "stablehlo.gather"(%operand, %start_indices) {
  dimension_numbers = #stablehlo.gather<
    offset_dims = [3, 4],
    collapsed_slice_dims = [1],
    operand_batching_dims = [0],
    start_indices_batching_dims = [1],
    start_index_map = [2, 1],
    index_vector_dim = 3>,
  slice_sizes = array<i64: 1, 1, 2, 2>,
  indices_are_sorted = false
} : (tensor<2x3x4x2xi32>, tensor<2x2x3x2xi64>) -> tensor<2x2x3x2x2xi32>
// %result: [
//           [
//            [
//             [[1, 2], [3, 4]],
//             [[3, 4], [5, 6]],
//             [[13, 14], [15, 16]]
//            ],
//            [
//             [[33, 34], [35, 36]],
//             [[35, 36], [37, 38]],
//             [[41, 42], [43, 44]]
//            ]
//           ],
//           [
//            [
//             [[1, 2], [3, 4]],
//             [[13, 14], [15, 16]],
//             [[21, 22], [23, 24]]
//            ],
//            [
//             [[43, 44], [45, 46]],
//             [[33, 34], [35, 36]],
//             [[27, 28], [29, 30]]
//            ]
//           ]
//          ]
get_dimension_size
セマンティクス
operand の指定された dimension のサイズを生成します。正式には result = dim(operand, dimension) です。セマンティクスは、型のシェイプ コンポーネントにのみ関係します。要素の型は任意です。
入力
| ラベル | 名前 | タイプ | 制約 | 
|---|---|---|---|
| (I1) | operand | テンソルまたは量子化テンソル | (C1) | 
| (I2) | dimension | si64型の定数 | (C1) | 
出力
| 名前 | タイプ | 
|---|---|
| result | si32型の 0 次元テンソル | 
制約
- (C1)0 <= dimension < rank(operand)。
例
// %operand: [[1, 2, 3], [4, 5, 6]]
%result = "stablehlo.get_dimension_size"(%operand) {
  dimension = 1 : i64
} : (tensor<2x3xi64>) -> tensor<i32>
// %result: 3
get_tuple_element
セマンティクス
operand タプルの index 位置にある要素を抽出して result を生成します。正式には result = operand[index] です。
入力
| ラベル | 名前 | タイプ | 制約 | 
|---|---|---|---|
| (I1) | operand | tuple | (C1)、(C2) | 
| (I2) | index | si32型の定数 | (C1)、(C2) | 
出力
| 名前 | タイプ | 制約 | 
|---|---|---|
| result | サポートされている任意のタイプ | (C2) | 
制約
- (C1)0 <= index < size(operand)。
- (C2)type(result) = tuple_element_types(operand)[index]。
例
// %operand: ([1.0, 2.0], (3))
  index = 0 : i32
} : (tuple<tensor<2xf32>, tuple<tensor<i32>>>) -> tensor<2xf32>
// %result: [1.0, 2.0]
~なら
セマンティクス
pred の値に応じて、true_branch または false_branch から関数を 1 つだけ実行することで、出力を生成します。(より正式には、result =
pred ? true_branch() : false_branch())。
入力
| ラベル | 名前 | タイプ | 制約 | 
|---|---|---|---|
| (I1) | pred | i1型の 0 次元テンソル | |
| (I2) | true_branch | 関数 | (C1~C3) | 
| (I3) | false_branch | 関数 | (C1)、(C2) | 
出力
| 名前 | タイプ | 制約 | 
|---|---|---|
| results | テンソル、量子化テンソル、トークンの可変数 | (C3) | 
制約
- (C1)input_types(true_branch) = input_types(false_branch) = []。
- (C2)output_types(true_branch) = output_types(false_branch)。
- (C3)type(results...) = output_types(true_branch)。
例
// %result_true_branch: 10
// %result_false_branch: 11
// %pred: true
%result = "stablehlo.if"(%pred) ({
  "stablehlo.return"(%result_true_branch) : (tensor<i32>) -> ()
}, {
  "stablehlo.return"(%result_false_branch) : (tensor<i32>) -> ()
}) : (tensor<i1>) -> tensor<i32>
// %result: 10
imag
セマンティクス
operand から虚部を要素ごとに抽出し、result テンソルを生成します。より正式には、各要素 x について imag(x) = is_complex(x) ? imaginary_part(x) :
constant(0, element_type(result)) です。
入力
| ラベル | 名前 | タイプ | 制約 | 
|---|---|---|---|
| (I1) | operand | 浮動小数点数または複素数のテンソル | (C1)、(C2) | 
出力
| 名前 | タイプ | 制約 | 
|---|---|---|
| result | 浮動小数点型のテンソル | (C1)、(C2) | 
制約
- (C1)shape(result) = shape(operand)。
- (C2)element_type(result)は次のように定義されます。- is_complex(operand)の場合、- complex_element_type(element_type(operand))。
- それ以外の場合は element_type(operand)。
 
例
// %operand: [(1.0, 2.0), (3.0, 4.0)]
%result = "stablehlo.imag"(%operand) : (tensor<2xcomplex<f32>>) -> tensor<2xf32>
// %result: [2.0, 4.0]
インフィード
セマンティクス
インフィードからデータを読み取り、results を生成します。
infeed_config のセマンティクスは実装で定義されます。
results は、先頭にペイロード値、最後にトークンで構成されます。今後、明確性を高めるために、ペイロードとトークンを 2 つの個別の出力に分割する予定です(#670)。
入力
| ラベル | 名前 | タイプ | 
|---|---|---|
| (I1) | token | token | 
| (I2) | infeed_config | string型の定数 | 
出力
| 名前 | タイプ | 制約 | 
|---|---|---|
| results | テンソル、量子化テンソル、またはトークンの可変数 | (C1~C3) | 
制約
- (C1)0 < size(results)。
- (C2)is_empty(result[:-1])またはis_tensor(type(results[:-1]))。
- (C3)is_token(type(results[-1]))。
例
// %token: !stablehlo.token
// infeed_queue[0]: [[1, 2], [3, 4]]
// infeed_queue[1]: [[5, 6], [7, 8]]
%results0:2 = "stablehlo.infeed"(%token) {
  infeed_config = ""
} : (!stablehlo.token) -> (tensor<2x2xi64>, !stablehlo.token)
// results0#0: [[1, 2], [3, 4]]
%results1:2 = "stablehlo.infeed"(%token) {
  infeed_config = ""
} : (!stablehlo.token) -> (tensor<2x2xi64>, !stablehlo.token)
// results1#0: [[5, 6], [7, 8]]
iota
セマンティクス
output テンソルを、iota_dimension ディメンションに沿って 0 から順に増加する値で埋めます。より正式には、
output[output_index] = constant(is_quantized(output) ?
quantize(output_index[iota_dimension], element_type(output)) :
output_index[iota_dimension], element_type(output))。
入力
| ラベル | 名前 | タイプ | 制約 | 
|---|---|---|---|
| (I1) | iota_dimension | si64 | (C1) | 
出力
| 名前 | タイプ | 制約 | 
|---|---|---|
| output | 整数型、浮動小数点型、複合型のテンソル、またはテンソルごとの量子化テンソル | (C1) | 
制約
- (C1)0 <= iota_dimension < rank(output)。
例
%output = "stablehlo.iota"() {
  iota_dimension = 0 : i64
} : () -> tensor<4x5xi32>
// %output: [
//           [0, 0, 0, 0, 0],
//           [1, 1, 1, 1, 1],
//           [2, 2, 2, 2, 2],
//           [3, 3, 3, 3, 3]
//          ]
%output = "stablehlo.iota"() {
  iota_dimension = 1 : i64
} : () -> tensor<4x5xi32>
// %output: [
//           [0, 1, 2, 3, 4],
//           [0, 1, 2, 3, 4],
//           [0, 1, 2, 3, 4],
//           [0, 1, 2, 3, 4]
//          ]
is_finite
セマンティクス
x の値が有限(+Inf、-Inf、NaN のいずれでもない)かどうかを要素ごとにチェックし、y テンソルを生成します。IEEE-754 仕様の isFinite オペレーションを実装します。量子化された型の場合、結果は常に true です。
入力
| ラベル | 名前 | タイプ | 制約 | 
|---|---|---|---|
| (I1) | x | 浮動小数点型のテンソルまたはテンソルごとの量子化テンソル | (C1) | 
出力
| 名前 | タイプ | 制約 | 
|---|---|---|
| y | ブール型のテンソル | (C1) | 
制約
- (C1)shape(x) = shape(y)。
例
// Logical values: -Inf, +Inf, NaN, ...
// %x: [0xFFF0000000000000, 0x7FF0000000000000, 0x7FF8000000000000, -10.0, -0.0, 0.0, 10.0]
%y = "stablehlo.is_finite"(%x) : (tensor<7xf64) -> tensor<7xi1>
// %y: [false, false, false, true, true, true, true]
log
セマンティクス
operand テンソルの要素ごとのログ演算を実行し、result テンソルを生成します。要素のタイプに応じて、次の処理を行います。
- 浮動小数点数の場合: IEEE-754 の log。
- 複素数の場合: 複素対数。
- 量子化型の場合: dequantize_op_quantize(log, operand, type(result))。
入力
| ラベル | 名前 | タイプ | 制約 | 
|---|---|---|---|
| (I1) | operand | 浮動小数点型または複合型のテンソル、またはテンソルごとの量子化テンソル | (C1) | 
出力
| 名前 | タイプ | 制約 | 
|---|---|---|
| result | 浮動小数点型または複合型のテンソル、またはテンソルごとの量子化テンソル | (C1) | 
制約
- (C1)baseline_type(operand) = baseline_type(result)。
例
// %operand: [[1.0, 2.0], [3.0, 4.0]]
%result = "stablehlo.log"(%operand) : (tensor<2x2xf64>) -> tensor<2x2xf64>
// %result: [[0.0, 0.69314718055994529], [1.0986122886681098, 1.3862943611198906]]
log_plus_one
セマンティクス
operand テンソルに対して要素単位の対数に 1 演算を加え、result テンソルを生成します。要素のタイプに応じて、次の処理を行います。
- 浮動小数点数の場合: IEEE-754 の logp1。
- 複素数の場合: 複素対数に 1 を加算します。
- 量子化された型の場合: dequantize_op_quantize(log_plus_one, operand, type(result))。
入力
| ラベル | 名前 | タイプ | 制約 | 
|---|---|---|---|
| (I1) | operand | 浮動小数点型または複合型のテンソル、またはテンソルごとの量子化テンソル | (C1) | 
出力
| 名前 | タイプ | 制約 | 
|---|---|---|
| result | 浮動小数点型または複合型のテンソル、またはテンソルごとの量子化テンソル | (C1) | 
制約
- (C1)baseline_type(operand) = baseline_type(result)。
例
// %operand: [0.0, -0.999, 7.0, 6.38905621, 15.0]
%result = "stablehlo.log_plus_one"(%operand) : (tensor<5xf64>) -> tensor<5xf64>
// %result: [0.0, -6.90776825, 2.07944155, 2.0, 2.77258873]
ロジスティクス
セマンティクス
operand テンソルの要素ごとのロジスティック演算を実行し、result テンソルを生成します。要素のタイプに応じて、次の操作を行います。
- 浮動小数点数の場合: IEEE-754 の division(1, addition(1, exp(-x)))。
- 複素数の場合: 複素ロジスティック。
- 量子化された型の場合: dequantize_op_quantize(logistic, operand, type(result))。
入力
| ラベル | 名前 | タイプ | 制約 | 
|---|---|---|---|
| (I1) | operand | 浮動小数点型または複合型のテンソル、またはテンソルごとの量子化テンソル | (C1) | 
出力
| 名前 | タイプ | 制約 | 
|---|---|---|
| result | 浮動小数点型または複合型のテンソル、またはテンソルごとの量子化テンソル | (C1) | 
制約
- (C1)baseline_type(operand) = baseline_type(result)。
例
// %operand: [[0.0, 1.0], [2.0, 3.0]]
%result = "stablehlo.logistic"(%operand) : (tensor<2x2xf64>) -> tensor<2x2xf64>
// %result: [[0.5, 0.73105858], [0.88079708, 0.95257413]]
地図
セマンティクス
dimensions に沿ってマッピング関数 computation を inputs に適用し、result テンソルを生成します。
正式には result[result_index] = computation(inputs...[result_index]) です。
入力
| ラベル | 名前 | タイプ | 制約 | 
|---|---|---|---|
| (I1) | inputs | テンソルの可変数またはテンソルごとの量子化テンソル | (C1 ~ C4) | 
| (I2) | dimensions | si64型の 1 次元テンソル定数 | (C3) | 
| (I3) | computation | 関数 | (C4) | 
出力
| 名前 | タイプ | 制約 | 
|---|---|---|
| result | テンソルまたはテンソルごとの量子化テンソル | (C1)、(C4) | 
制約
- (C1)shape(inputs...) = shape(result)。
- (C2)0 < size(inputs) = N。
- (C3)dimensions = range(rank(inputs[0]))。
- (C4)computationの型は(tensor<E0>, ..., tensor<EN-1>) -> tensor<E'>で、Ei = element_type(inputs[i])とE' = element_type(result)です。
例
// %input0: [[0, 1], [2, 3]]
// %input1: [[4, 5], [6, 7]]
%result = "stablehlo.map"(%input0, %input1) ({
  ^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
    %0 = stablehlo.multiply %arg0, %arg1 : tensor<i64>
    stablehlo.return %0 : tensor<i64>
}) {
  dimensions = array<i64: 0, 1>
} : (tensor<2x2xi64>, tensor<2x2xi64>) -> tensor<2x2xi64>
// %result: [[0, 5], [12, 21]]
最大
セマンティクス
テンソル lhs と rhs に対して要素ごとの最大演算を実行し、result テンソルを生成します。要素のタイプに応じて、次の処理を行います。
- ブール値の場合: 論理 OR。
- 整数の場合: 整数の最大値。
- 浮動小数点数の場合: IEEE-754 の maximum。
- 複素数の場合: (real, imaginary)ペアの辞書順の最大値。複素数の順序付けには意外なセマンティクスが伴うため、将来的には、この演算での複素数のサポートは終了する予定です(#560)。
- 量子化された型の場合:
- dequantize_op_quantize(maximum, lhs, rhs, type(result))。
 
入力
| ラベル | 名前 | タイプ | 制約 | 
|---|---|---|---|
| (I1) | lhs | テンソルまたはテンソルごとの量子化テンソル | (C1) | 
| (I2) | rhs | テンソルまたはテンソルごとの量子化テンソル | (C1) | 
出力
| 名前 | タイプ | 制約 | 
|---|---|---|
| result | テンソルまたはテンソルごとの量子化テンソル | (C1) | 
制約
- (C1)baseline_type(lhs) = baseline_type(rhs) = baseline_type(result)。
例
// %lhs: [[1, 2], [7, 8]]
// %rhs: [[5, 6], [3, 4]]
%result = "stablehlo.maximum"(%lhs, %rhs) : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
// %result: [[5, 6], [7, 8]]
最小
セマンティクス
テンソル lhs と rhs に対して要素ごとの最小演算を実行し、result テンソルを生成します。要素のタイプに応じて、次の処理を行います。
- ブール値の場合: 論理 AND。
- 整数の場合: 整数の最小値。
- 浮動小数点数の場合: IEEE-754 の minimum。
- 複素数の場合: (real, imaginary)ペアの辞書順最小値。複素数に順序付けを適用すると、意図しないセマンティクスが発生するため、今後、このオペレーションでの複素数のサポートを削除する予定です(#560)。
- 量子化型の場合:
- dequantize_op_quantize(minimum, lhs, rhs, type(result))。
 
入力
| ラベル | 名前 | タイプ | 制約 | 
|---|---|---|---|
| (I1) | lhs | テンソルまたはテンソルごとの量子化テンソル | (C1) | 
| (I2) | rhs | テンソルまたはテンソルごとの量子化テンソル | (C1) | 
出力
| 名前 | タイプ | 制約 | 
|---|---|---|
| result | テンソルまたはテンソルごとの量子化テンソル | (C1) | 
制約
- (C1)baseline_type(lhs) = baseline_type(rhs) = baseline_type(result)。
例
// %lhs: [[1, 2], [7, 8]]
// %rhs: [[5, 6], [3, 4]]
%result = "stablehlo.minimum"(%lhs, %rhs) : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
// %result: [[1, 2], [3, 4]]
乗算
セマンティクス
2 つのテンソル lhs と rhs の要素ごとの積を実行し、result テンソルを生成します。要素のタイプに応じて、次の操作を行います。
- ブール値の場合: 論理 AND。
- 整数の場合: 整数の乗算。
- 浮動小数点数の場合: IEEE-754 の multiplication。
- 複素数の場合: 複素数乗算。
- 量子化された型の場合:
- dequantize_op_quantize(multiply, lhs, rhs, type(result))。
 
入力
| ラベル | 名前 | タイプ | 制約 | 
|---|---|---|---|
| (I1) | lhs | テンソルまたはテンソルごとの量子化テンソル | (C1) | 
| (I2) | rhs | テンソルまたはテンソルごとの量子化テンソル | (C1) | 
出力
| 名前 | タイプ | 制約 | 
|---|---|---|
| result | テンソルまたはテンソルごとの量子化テンソル | (C1) | 
制約
- (C1)baseline_type(operand) = baseline_type(result)。
例
// %lhs: [[1, 2], [3, 4]]
// %rhs: [[5, 6], [7, 8]]
%result = "stablehlo.multiply"(%lhs, %rhs) : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
// %result: [[5, 12], [21, 32]]
negate
セマンティクス
operand テンソルの要素ごとの否定を行い、result テンソルを生成します。要素のタイプに応じて、次の操作を行います。
- 符号付き整数の場合: 整数の否定。
- 符号なし整数の場合: 符号付き整数へのビットキャスト、整数の否定、符号なし整数へのビットキャスト。
- 浮動小数点数: IEEE-754 の negate。
- 複素数の場合: 複素数を否定する。
- 量子化された型の場合: dequantize_op_quantize(negate, operand, type(result))。
入力
| ラベル | 名前 | タイプ | 制約 | 
|---|---|---|---|
| (I1) | operand | 整数型、浮動小数点型、複合型のテンソル、またはテンソルごとの量子化テンソル | (C1) | 
出力
| 名前 | タイプ | 制約 | 
|---|---|---|
| result | 整数型、浮動小数点型、複合型のテンソル、またはテンソルごとの量子化テンソル | (C1) | 
制約
- (C1)baseline_type(operand) = baseline_type(result)。
例
// Negation operation with integer Tensors
// %operand: [0, -2]
%result = "stablehlo.negate"(%operand) : (tensor<2xi32>) -> tensor<2xi32>
// %result: [0, 2]
// Negation operation with with complex tensors
// %operand: (2.5, 0.0)
%result = "stablehlo.negate"(%operand) : (tensor<1xcomplex<f32>>) -> tensor<1xcomplex<f32>>
// %result: [-2.5, -0.0]
いない
セマンティクス
テンソル operand の要素ごとの NOT を実行し、result テンソルを生成します。要素のタイプに応じて、次の処理を行います。
- ブール値の場合: 論理 NOT。
- 整数の場合: ビット演算 NOT。
引数
| 名前 | タイプ | 制約 | 
|---|---|---|
| operand | ブール型または整数型のテンソル | (C1) | 
出力
| 名前 | タイプ | 制約 | 
|---|---|---|
| result | ブール型または整数型のテンソル | (C1) | 
制約
- (C1)type(operand) = type(result)。
例
// Bitwise operation with with integer tensors
// %operand: [[1, 2], [3, 4]]
%result = "stablehlo.not"(%operand) : (tensor<2x2xi32>) -> tensor<2x2xi32>
// %result: [[-2, -3], [-4, -5]]
// Bitwise operation with with boolean tensors
// %operand: [true, false]
%result = "stablehlo.not"(%operand) : (tensor<2xi1>) -> tensor<2xi1>
// %result: [false, true]
optimization_barrier
セマンティクス
operand を生成するオペレーションが result に依存するオペレーションの前に実行され、コンパイラ変換によってオペレーションがバリアを超えて移動されないようにします。それ以外の場合、オペレーションは ID(result = operand など)です。
引数
| 名前 | タイプ | 制約 | 
|---|---|---|
| operand | テンソルの可変数、テンソルごとの量子化テンソルまたはトークン | (C1) | 
出力
| 名前 | タイプ | 制約 | 
|---|---|---|
| result | テンソルの可変数、テンソルごとの量子化テンソルまたはトークン | (C1) | 
制約
- (C1)type(operand...) = type(result...)。
例
// %operand0: 0.0
// %operand1: 1.0
%result0, %result1 = "stablehlo.optimization_barrier"(%operand0, %operand1) : (tensor<f32>, tensor<f32>) -> (tensor<f32>, tensor<f32>)
// %result0: 0.0
// %result1: 1.0
または
セマンティクス
2 つのテンソル lhs と rhs の要素ごとの OR を実行し、result テンソルを生成します。要素のタイプに応じて、次の操作を行います。
- ブール値の場合: 論理 OR。
- 整数の場合: ビット演算 OR。
入力
| ラベル | 名前 | タイプ | 制約 | 
|---|---|---|---|
| (I1) | lhs | 整数型またはブール型のテンソル | (C1) | 
| (I2) | rhs | 整数型またはブール型のテンソル | (C1) | 
出力
| 名前 | タイプ | 制約 | 
|---|---|---|
| result | 整数型またはブール型のテンソル | (C1) | 
制約
- (C1)type(lhs) = type(rhs) = type(result)。
例
// Bitwise operation with with integer tensors
// %lhs: [[1, 2], [3, 4]]
// %rhs: [[5, 6], [7, 8]]
%result = "stablehlo.or"(%lhs, %rhs) : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
// %result: [[5, 6], [7, 12]]
// Logical operation with with boolean tensors
// %lhs: [[false, false], [true, true]]
// %rhs: [[false, true], [false, true]]
%result = "stablehlo.or"(%lhs, %rhs) : (tensor<2x2xi1>, tensor<2x2xi1>) -> tensor<2x2xi1>
// %result: [[false, true], [true, true]]
アウトフィード
セマンティクス
inputs をアウトフィードに書き込み、result トークンを生成します。
outfeed_config のセマンティクスは実装で定義されます。
入力
| ラベル | 名前 | タイプ | 
|---|---|---|
| (I1) | inputs | テンソルまたは量子化テンソルの可変数 | 
| (I2) | token | token | 
| (I3) | outfeed_config | string型の定数 | 
出力
| 名前 | タイプ | 
|---|---|
| result | token | 
例
%result = "stablehlo.outfeed"(%input0, %token) {
  outfeed_config = ""
} : (tensor<2x2x2xi64>, !stablehlo.token) -> !stablehlo.token
パッド
セマンティクス
テンソルの周囲と、指定された padding_value を使用してテンソルの要素の間にパディングすることで、operand を拡張します。
edge_padding_low と edge_padding_high は、各ディメンションの下端(インデックス 0 の隣)と上端(最大インデックスの隣)に追加されるパディングの量をそれぞれ指定します。パディングの量は負の値にすることができます。負のパディングの絶対値は、指定されたディメンションから削除する要素の数を示します。
interior_padding は、各ディメンション内の任意の 2 つの要素間に追加されるパディングの量を指定します。この値は負の値にすることはできません。内部パディングはエッジ パディングの前に行われるため、負のエッジ パディングを行うと、内部パディングされたオペランドから要素が削除されます。
より正式には、result[result_index] は次のように定義されます。
- result_index = edge_padding_low + operand_index * (interior_padding + 1)の場合、- operand[operand_index]。
- それ以外の場合は padding_value。
入力
| ラベル | 名前 | タイプ | 制約 | 
|---|---|---|---|
| (I1) | operand | テンソルまたはテンソルごとの量子化テンソル | (C1)、(C2)、(C4) | 
| (I2) | padding_value | 0 次元テンソルまたはテンソルごとの量子化テンソル | (C1) | 
| (I3) | edge_padding_low | si64型の 1 次元テンソル定数 | (C1)、(C4) | 
| (I4) | edge_padding_high | si64型の 1 次元のテンソル定数 | (C1)、(C4) | 
| (I5) | interior_padding | si64型の 1 次元テンソル定数 | (C2 ~ C4) | 
出力
| 名前 | タイプ | 制約 | 
|---|---|---|
| result | テンソルまたはテンソルごとの量子化テンソル | (C3-C6) | 
制約
- (C1)element_type(operand) = element_type(padding_value) = element_type(result)。
- (C2)size(edge_padding_low) = size(edge_padding_high) = size(interior_padding) = rank(operand)。
- (C3)0 <= interior_padding。
- (C4)shape(result) = shape(operand) + edge_padding_low + max(shape(operand) - 1, 0) * interior_padding + edge_padding_high。
例
// %operand: [
//            [1, 2, 3],
//            [4, 5, 6]
//           ]
// %padding_value: 0
%result = "stablehlo.pad"(%operand, %padding_value) {
  edge_padding_low = array<i64: 0, 1>,
  edge_padding_high = array<i64: 2, 1>,
  interior_padding = array<i64: 1, 2>
} : (tensor<2x3xi32>, tensor<i32>) -> tensor<5x9xi32>
// %result: [
//           [0, 1, 0, 0, 2, 0, 0, 3, 0],
//           [0, 0, 0, 0, 0, 0, 0, 0, 0],
//           [0, 4, 0, 0, 5, 0, 0, 6, 0],
//           [0, 0, 0, 0, 0, 0, 0, 0, 0],
//           [0, 0, 0, 0, 0, 0, 0, 0, 0]
//          ]
partition_id
セマンティクス
現在のプロセスの partition_id を生成します。
出力
| 名前 | タイプ | 
|---|---|
| result | ui32型の 0 次元テンソル | 
例
%result = "stablehlo.partition_id"() : () -> tensor<ui32>
popcnt
セマンティクス
operand テンソルに設定されているビット数を要素ごとにカウントし、result テンソルを生成します。
入力
| ラベル | 名前 | タイプ | 制約 | 
|---|---|---|---|
| (I1) | operand | 整数型のテンソル | (C1) | 
出力
| 名前 | タイプ | 制約 | 
|---|---|---|
| result | 整数型のテンソル | (C1) | 
制約
- (C1)type(operand) = type(result)。
例
// %operand: [0, 1, 2, 127]
%result = "stablehlo.popcnt"(%operand) : (tensor<4xi64>) -> tensor<4xi64>
// %result: [0, 1, 1, 7]
電力
セマンティクス
lhs テンソルと rhs テンソルの要素ごとの指数演算を行い、result テンソルを生成します。要素のタイプに応じて、次の処理を行います。
- 整数の場合: 整数のべき乗。
- 浮動小数点数の場合: IEEE-754 の pow。
- 複素数の場合: 複素指数関数。
- 量子化された型の場合: dequantize_op_quantize(power, lhs, rhs, type(result))。
入力
| ラベル | 名前 | タイプ | 制約 | 
|---|---|---|---|
| (I1) | lhs | 整数型、浮動小数点型、複合型のテンソル、またはテンソルごとの量子化テンソル | (C1) | 
| (I2) | rhs | 整数、浮動小数点、複合型のテンソル、またはテンソルごとの量子化テンソル | (C1) | 
出力
| 名前 | タイプ | 制約 | 
|---|---|---|
| result | 整数型、浮動小数点型、複合型のテンソル、またはテンソルごとの量子化テンソル | (C1) | 
制約
- (C1)baseline_type(operand) = baseline_type(result)。
例
// %lhs: [-2.0, -0.0, -36.0, 5.0, 3.0, 10000.0]
// %rhs: [2.0, 2.0, 1.1, 2.0, -1.0, 10.0]
%result = "stablehlo.power"(%lhs, %rhs) : (tensor<6xf64>, tensor<6xf64>) -> tensor<6xf64>
// %result: [4.0, 0.0, -nan, 25.0, 0.333333343, inf]
real
セマンティクス
operand から要素ごとに実部を抽出し、result テンソルを生成します。より正式には、各要素 x について real(x) = is_complex(x) ? real_part(x) : x です。
入力
| ラベル | 名前 | タイプ | 制約 | 
|---|---|---|---|
| (I1) | operand | 浮動小数点数または複素数のテンソル | (C1)、(C2) | 
出力
| 名前 | タイプ | 制約 | 
|---|---|---|
| result | 浮動小数点型のテンソル | (C1)、(C2) | 
制約
- (C1)shape(result) = shape(operand)。
- (C2)element_type(result)は次のように定義されます。- is_complex(operand)の場合、- complex_element_type(element_type(operand))。
- それ以外の場合は element_type(operand)。
 
例
// %operand: [(1.0, 2.0), (3.0, 4.0)]
%result = "stablehlo.real"(%operand) : (tensor<2xcomplex<f32>>) -> tensor<2xf32>
// %result: [1.0, 3.0]
受信
セマンティクス
channel_id を使用してチャネルからデータを受信し、results を生成します。
is_host_transfer が true の場合、オペレーションはホストからデータを転送します。そうでない場合は、別のデバイスからデータを転送します。具体的な意味は実装によって異なります。このフラグは channel_type で提供される情報と重複するため、今後はどちらか一方のみを保持する予定です(#666)。
results は、先頭にペイロード値、最後にトークンで構成されます。今後、明確性を高めるために、ペイロードとトークンを 2 つの個別の出力に分割する予定です(#670)。
入力
| ラベル | 名前 | タイプ | 制約 | 
|---|---|---|---|
| (I1) | token | token | (C4) | 
| (I2) | channel_id | si64型の定数 | |
| (I3) | channel_type | DEVICE_TO_DEVICEとHOST_TO_DEVICEの列挙型 | (C1) | 
| (I4) | is_host_transfer | i1型の定数 | (C1) | 
出力
| 名前 | タイプ | 制約 | 
|---|---|---|
| results | テンソル、量子化テンソル、トークンの可変数 | (C2 ~ C4) | 
制約
- (C1)channel_typeは次のように定義されます。- is_host_transfer = trueの場合、- HOST_TO_DEVICE
- それ以外の場合は DEVICE_TO_DEVICE。
 
- (C2)0 < size(results)。
- (C3)is_empty(result[:-1])またはis_tensor(type(results[:-1]))。
- (C4)is_token(type(results[-1]))。
例
%results0, %results1 = "stablehlo.recv"(%token) {
  channel_handle = #stablehlo.channel_handle<handle = 1, type = 3>,
  is_host_transfer = true
} : (!stablehlo.token) -> (tensor<2x2xi64>, !stablehlo.token)
reduce
セマンティクス
dimensions に沿ってリダクション関数 body を inputs と init_values に適用し、results テンソルを生成します。
減算の順序は実装で定義されます。つまり、すべての実装ですべての入力に対して同じ結果が得られるように、body と init_values はモノイドを形成する必要があります。ただし、この条件は一般的な多くの減少では成立しません。たとえば、body の浮動小数点加算と init_values のゼロは、浮動小数点加算が結合的ではないため、実際にはモノイドを形成しません。
より正式には、results...[j0, ..., jR-1] = reduce(input_slices_converted) です。ここで:
- input_slices = inputs...[j0, ..., :, ..., jR-1]。- :は- dimensionsに挿入されます。
- input_slices_converted = to_destination_type(input_slices..., type(func_inputs(body)[:len(func_inputs(body))//2])...)。
- init_values_converted = to_destination_type(init_values..., type(func_inputs(body)[len(func_inputs(body))//2:])...)。
- reduce(input_slices_converted) = exec(schedule)は任意のバイナリ ツリーです。- scheduleは次のとおりです。- exec(node) = body(exec(node.left), exec(node.right))。
- exec(leaf) = leaf.value。
 
- scheduleは実装定義の完全二分木で、順序付き走査は次のとおりです。- input_slices_converted...[index]値。- index_space(input_slices_converted)内のすべての- indexは、- indexの昇順の辞書順で指定します。
- 実装で定義された位置に、実装で定義された量の init_values_convertedが点在します。
 
入力
| ラベル | 名前 | タイプ | 制約 | 
|---|---|---|---|
| (I1) | inputs | テンソルの可変数またはテンソルごとの量子化テンソル | (C1~C4)、(C6)、(C7) | 
| (I2) | init_values | 0 次元テンソルまたはテンソルごとの量子化テンソルの可変数 | (C2)、(C3) | 
| (I3) | dimensions | si64型の 1 次元テンソル定数 | (C4)、(C5)、(C7) | 
| (I4) | body | 関数 | (C6) | 
出力
| 名前 | タイプ | 制約 | 
|---|---|---|
| results | テンソルの可変数またはテンソルごとの量子化テンソル | (C3)、(C7)、(C8) | 
制約
- (C1)same(shape(inputs...))。
- (C2)element_type(inputs...) = element_type(init_values...)。
- (C3)0 < size(inputs) = size(init_values) = size(results) = N。
- (C4)0 <= dimensions < rank(inputs[0])。
- (C5)is_unique(dimensions)。
- (C6)bodyの型は(tensor<E0>, ..., tensor<EN-1>, tensor<E0>, ...,tensor<EN-1>) -> (tensor<E0>, ..., tensor<EN-1>)で、is_promotable(element_type(inputs[i]), Ei)です。
- (C7)shape(results...) = shape(inputs...)。ただし、dimensionsに対応するinputs...のディメンション サイズは含まれません。
- (C8)[0,N)のすべてのiに対してelement_type(results[i]) = Ei。
例
// %input = [[0, 1, 2, 3, 4, 5]]
// %init_value = 0
%result = "stablehlo.reduce"(%input, %init_value) ({
  ^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
    %0 = "stablehlo.add"(%arg0, %arg1) : (tensor<i64>, tensor<i64>) -> tensor<i64>
    "stablehlo.return"(%0) : (tensor<i64>) -> ()
}) {
  dimensions = array<i64: 1>
} : (tensor<1x6xi64>, tensor<i64>) -> tensor<1xi64>
// %result = [15]
reduce_precision
セマンティクス
operand を exponent_bits と mantissa_bits を使用する別の浮動小数点型に要素単位で変換し、元の浮動小数点型に戻して output テンソルを生成します。
より正式には、次のように記述します。
- 元の値のマンシンダビットが更新され、roundToIntegralTiesToEvenセマンティクスを使用して、元の値をmantissa_bitsで表せる最も近い値に丸められます。
- 次に、mantissa_bitsが元の値のマンシッサ ビット数より小さい場合は、マンシッサ ビットがmantissa_bitsに切り捨てられます。
- 次に、中間結果の指数ビットが exponent_bitsで指定された範囲に収まらない場合、中間結果は元の符号を使用して無限大までオーバーフローするか、元の符号を使用してゼロにアンダーフローします。
- 量子化された型の場合は、dequantize_op_quantize( lambda operand: reduce_precision(operand, exponent_bits, mantissa_bits), operand, type(result))を実行します。
入力
| ラベル | 名前 | タイプ | 制約 | 
|---|---|---|---|
| (I1) | operand | 浮動小数点型のテンソルまたはテンソルごとの量子化テンソル | (C1) | 
| (I2) | exponent_bits | si32型の定数 | (C2) | 
| (I3) | mantissa_bits | si32型の定数 | (C3) | 
出力
| 名前 | タイプ | 制約 | 
|---|---|---|
| output | 浮動小数点型のテンソルまたはテンソルごとの量子化テンソル | (C1) | 
制約
- (C1)baseline_type(operand) = baseline_type(output)。
- (C2)1 <= exponent_bits。
- (C3)0 <= mantissa_bits。
例
// Logical values: +Inf, NaN, +Denormal, 0.0, 65519.0, 65520.0
// %operand: [0x7FF0000000000000, 0x7FFFFFFFFFFFFFFF, 0x0000000000000001, 0.0, 65519.0, 65520.0]
%output = "stablehlo.reduce_precision"(%operand) {
  exponent_bits = 5 : i32,
  mantissa_bits = 10 : i32
} : (tensor<6xf64>) -> tensor<6xf64>
// Logical values: +Inf, NaN, 0.0, 0.0, 65504.0, +Inf
// %output: [0x7FF0000000000000, 0x7FFFFFFFFFFFFFFF, 0.0, 0.0, 65504.0, 0x7FF0000000000000]
reduce_scatter
セマンティクス
StableHLO プロセス グリッド内の各プロセス グループ内で、各プロセスの operand テンソルの値に対して computations を使用して減算を行い、scatter_dimension に沿って減算結果を分割し、分割された部分をプロセス間で分散して result を生成します。
このオペレーションでは、StableHLO プロセス グリッドが、次のように定義された process_groups に分割されます。
- channel_id <= 0 and use_global_device_ids = falseの場合、- cross_replica(replica_groups)
- channel_id > 0 and use_global_device_ids = falseの場合、- cross_replica_and_partition(replica_groups)
- channel_id > 0 and use_global_device_ids = trueの場合、- flattened_ids(replica_groups)
その後、各 process_group 内で次の操作を行います。
- reduced_value = all_reduce(operand, replica_groups, channel_id, use_global_device_ids, computation)。
- parts@sender = split(reduced_value@sender, dim(process_groups, 1), scatter_dimension)。
- process_group内のすべての- senderに対して- result@receiver = parts@sender[receiver_index](- receiver_index = process_group.index(receiver))。
入力
| ラベル | 名前 | タイプ | 制約 | 
|---|---|---|---|
| (I1) | operand | テンソルまたはテンソルごとの量子化テンソル | (C1)、(C2)、(C7)、(C8) | 
| (I2) | scatter_dimension | si64型の定数 | (C1)、(C2)、(C8) | 
| (I3) | replica_groups | si64型の 2 次元テンソル定数 | (C3~C5) | 
| (I4) | channel_id | si64型の定数 | (C6) | 
| (I5) | use_global_device_ids | i1型の定数 | (C6) | 
| (I6) | computation | 関数 | (C7) | 
出力
| 名前 | タイプ | 制約 | 
|---|---|---|
| result | テンソルまたはテンソルごとの量子化テンソル | (C8~C9) | 
制約
- (C1)dim(operand, scatter_dimension) % dim(process_groups, 1) = 0。
- (C2)0 <= scatter_dimension < rank(operand)。
- (C3)is_unique(replica_groups)。
- (C4)size(replica_groups)は次のように定義されます。- cross_replicaを使用する場合は- num_replicas。
- cross_replica_and_partitionを使用する場合は- num_replicas。
- flattened_idsを使用する場合は- num_processes。
 
- (C5)0 <= replica_groups < size(replica_groups)。
- (C6)use_global_device_ids = trueの場合はchannel_id > 0です。
- (C7)computationの型は(tensor<E>, tensor<E>) -> (tensor<E>)です。is_promotable(element_type(operand), E)
- (C8)shape(result) = shape(operand)(ただし、次を除く)- dim(result, scatter_dimension) = dim(operand, scatter_dimension) / dim(process_groups, 1)。
 
- (C9)element_type(result) = E。
例
// num_replicas: 2
// num_partitions: 1
// %operand@(0, 0): [[1, 2, 3, 4],
//                   [5, 6, 7, 8]]
// %operand@(1, 0): [[9, 10, 11, 12],
//                   [13, 14, 15, 16]]
%result = "stablehlo.reduce_scatter"(%operand) ({
  ^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
  %0 = "stablehlo.add"(%arg0, %arg1) : (tensor<i64>, tensor<i64>) -> tensor<i64>
  "stablehlo.return"(%0) : (tensor<i64>) -> ()
}) {
  scatter_dimension = 1 : i64,
  replica_groups = dense<[[0, 1]]> : tensor<1x2xi64>,
  channel_handle = #stablehlo.channel_handle<handle = 0, type = 0>
} : (tensor<2x4xi64>) -> tensor<2x2xi64>
//
// %result@(0, 0): [[10, 12],
//                  [18, 20]]
// %result@(1, 0): [[14, 16],
//                  [22, 24]]
reduce_window
セマンティクス
inputs と init_values のウィンドウに削減関数 body を適用し、results を生成します。
次の図は、具体的な例を使用して results... の要素が inputs... から計算される方法を示しています。
より正式には、results...[result_index] = reduce(windows, init_values, axes(inputs...), body)(reduce を参照)で、次のように定義されます。
- padded_inputs = pad(inputs..., init_values..., padding[:, 0], padding[:, 1], base_dilations - 1)。
- window_start = result_index * window_strides
- window_end = window_start + (window_dimensions - 1) * window_dilations + 1
- windows = slice(padded_inputs..., window_start, window_end, window_dilations)。
入力
| ラベル | 名前 | タイプ | 制約 | 
|---|---|---|---|
| (I1) | inputs | テンソルの可変数またはテンソルごとの量子化テンソル | (C1~C4)、(C6)、(C8)、(C10)、(C12)、(C13)、(C15) | 
| (I2) | init_values | 0 次元テンソルまたはテンソルごとの量子化テンソルの可変数 | (C1)、(C13) | 
| (I3) | window_dimensions | si64型の 1 次元テンソル定数 | (C4)、(C5)、(C15) | 
| (I4) | window_strides | si64型の 1 次元テンソル定数 | (C6)、(C7)、(C15) | 
| (I5) | base_dilations | si64型の 1 次元のテンソル定数 | (C8)、(C9)、(C15) | 
| (I6) | window_dilations | si64型の 1 次元テンソル定数 | (C10)、(C11)、(C15) | 
| (I7) | padding | si64型の 2 次元テンソル定数 | (C12)、(C15) | 
| (I8) | body | 関数 | (C13) | 
出力
| 名前 | タイプ | 制約 | 
|---|---|---|
| results | テンソルの可変数またはテンソルごとの量子化テンソル | (C1)、(C14~C16) | 
制約
- (C1)0 < size(inputs) = size(init_values) = size(results) = N。
- (C2)same(shape(inputs...))。
- (C3)element_type(inputs...) = element_type(init_values...)。
- (C4)size(window_dimensions) = rank(inputs[0])。
- (C5)0 < window_dimensions。
- (C6)size(window_strides) = rank(inputs[0])。
- (C7)0 < window_strides。
- (C8)size(base_dilations) = rank(inputs[0])。
- (C9)0 < base_dilations。
- (C10)size(window_dilations) = rank(inputs[0])。
- (C11)0 < window_dilations。
- (C12)shape(padding) = [rank(inputs[0]), 2]。
- (C13)bodyの型は(tensor<E0>, ..., tensor<EN-1>, tensor<E0>, ...,tensor<EN-1>) -> (tensor<E0>, ..., tensor<EN-1>)で、is_promotable(element_type(inputs[i]), Ei)です。
- (C14)same(shape(results...))。
- (C15)shape(results[0]) = num_windowsここで:- dilated_input_shape = shape(inputs[0]) = 0 ? 0 : (shape(inputs[0]) - 1) * base_dilations + 1。
- padded_input_shape = padding[:, 0] + dilated_input_shape + padding[:, 1]
- dilated_window_shape = (window_dimensions - 1) * window_dilations + 1
- is_empty_window = padded_input_shape = 0 || dilated_window_shape > padded_input_shape
- num_windows = is_empty_window ? 0 : floor((padded_input_shape - dilated_window_shape) / window_strides) + 1。
 
- (C16)[0,N)内のすべてのiに対してelement_type(results[i]) = Ei。
例
// %input = [[1, 2], [3, 4], [5, 6]]
// %init_value = 0
%result = "stablehlo.reduce_window"(%input, %init_value) ({
  ^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
    %0 = "stablehlo.add"(%arg0, %arg1) : (tensor<i64>, tensor<i64>) -> tensor<i64>
    "stablehlo.return"(%0) : (tensor<i64>) -> ()
}) {
  window_dimensions = array<i64: 2, 1>,
  window_strides = array<i64: 4, 1>,
  base_dilations = array<i64: 2, 1>,
  window_dilations = array<i64: 3, 1>,
  padding = dense<[[2, 1], [0, 0]]> : tensor<2x2xi64>
} : (tensor<3x2xi64>, tensor<i64>) -> tensor<2x2xi64>
// %result = [[0, 0], [3, 4]]
残り
セマンティクス
被除数 lhs テンソルと除数 rhs テンソルの要素ごとの余りを実行し、result テンソルを生成します。
より正式には、結果の符号は被除数から取得され、結果の絶対値は常に除数の絶対値よりも小さくなります。残りは lhs - d * rhs として計算されます。ここで、d は次のように定義されます。
- 整数の場合: stablehlo.divide(lhs, rhs)。
- 浮動小数点数の場合: 丸め属性 roundTowardZeroを持つ IEEE-754 のdivision(lhs, rhs)。
- 複素数の場合: 未定(#997)。
- 量子化された型の場合:
- dequantize_op_quantize(remainder, lhs, rhs, type(result))。
 
浮動小数点要素型の場合、このオペレーションは IEEE-754 仕様の remainder オペレーションとは対照的です。d は、lhs/rhs の正確な値に最も近い整数値で、偶数に等しい値です。
入力
| ラベル | 名前 | タイプ | 制約 | 
|---|---|---|---|
| (I1) | lhs | 整数型、浮動小数点型、複合型のテンソル、またはテンソルごとの量子化テンソル | (C1) | 
| (I2) | rhs | 整数型、浮動小数点型、複合型のテンソル、またはテンソルごとの量子化テンソル | (C1) | 
出力
| 名前 | タイプ | 制約 | 
|---|---|---|
| result | 整数型、浮動小数点型、複合型のテンソル、またはテンソルごとの量子化テンソル | (C1) | 
制約
- (C1)baseline_type(operand) = baseline_type(result)。
例
// %lhs: [17, -17, 17, -17]
// %rhs: [3, 3, -3, -3]
%result = "stablehlo.remainder"(%lhs, %rhs) : (tensor<4xi64>, tensor<4xi64>) -> tensor<4xi64>
// %result: [2, -2, 2, -2]
replica_id
セマンティクス
現在のプロセスの replica_id を生成します。
出力
| 名前 | タイプ | 
|---|---|
| result | ui32型の 0 次元テンソル | 
例
%result = "stablehlo.replica_id"() : () -> tensor<ui32>
reshape
セマンティクス
operand テンソルを result テンソルに再フォーマットします。概念的には、同じ正規表現を維持しながら、形状を変更する可能性(tensor<2x3xf32> から tensor<3x2xf32> または tensor<6xf32> への変更など)があります。
より正式には、result[result_index] = operand[operand_index] で、result_index と operand_index は index_space(result) と index_space(operand) の辞書順で同じ位置にあります。
入力
| ラベル | 名前 | タイプ | 制約 | 
|---|---|---|---|
| (I1) | operand | テンソルまたは量子化テンソル | (C1~C3) | 
出力
| 名前 | タイプ | 制約 | 
|---|---|---|
| result | テンソルまたは量子化テンソル | (C1~C3) | 
制約
- (C1)element_type(result)は次のように定義されます。- element_type(operand)(- !is_per_axis_quantized(operand)の場合)。
- element_type(operand)。ただし、- quantization_dimension(operand)と- quantization_dimension(result)が異なる場合があります。
 
- (C2)size(operand) = size(result)。
- (C3)is_per_axis_quantized(operand)の場合:- reduce(dims(operand, [0, 1, ..., quantization_dimension(operand) - 1]), init_values=1, dimensions=[0], body=lambda x, y: x * y) = reduce(dims(result, [0, 1, ..., quantization_dimension(result) - 1]), init_values=1, dimensions=[0], body=lambda x, y: x * y)。
- dim(operand, quantization_dimension(operand)) = dim(result, quantization_dimension(result))
- reduce(dims(operand, [quantization_dimension(operand) + 1, ..., rank(operand) - 1]), init_values=1, dimensions=[0], body=lambda x, y: x * y) = reduce(dims(result, [quantization_dimension(result) + 1, ..., rank(result) - 1]), init_values=1, dimensions=[0], body=lambda x, y: x * y)。
 
例
// %operand: [[1, 2, 3], [4, 5, 6]]
%result = "stablehlo.reshape"(%operand) : (tensor<2x3xi32>) -> tensor<3x2xi32>
// %result: [[1, 2], [3, 4], [5, 6]]
reverse
セマンティクス
指定された dimensions に沿って operand の要素の順序を逆にして、result テンソルを生成します。より正式には、result[result_index] = operand[operand_index] で次のようにします。
- dimensionsの- dの場合、- operand_index[d] = dim(result, d) - result_index[d] - 1
- そうでない場合は operand_index[d] = result_index[d]。
入力
| ラベル | 名前 | タイプ | 制約 | 
|---|---|---|---|
| (I1) | operand | テンソルまたはテンソルごとの量子化テンソル | (C1)、(C3) | 
| (I2) | dimensions | si64型の 1 次元テンソル定数 | (C2)、(C3) | 
出力
| 名前 | タイプ | 制約 | 
|---|---|---|
| result | テンソルまたはテンソルごとの量子化テンソル | (C1)、(C3) | 
制約
- (C1)type(operand) = type(result)。
- (C2)is_unique(dimensions)。
- (C3)0 <= dimensions < rank(result)。
例
// %operand = [[1, 2], [3, 4], [5, 6]]
%result = "stablehlo.reverse"(%operand) {
  dimensions = array<i64: 1>
} : (tensor<3x2xi32>) -> tensor<3x2xi32>
// %result: [[2, 1], [4, 3], [6, 5]]
rng
セマンティクス
rng_distribution アルゴリズムを使用して乱数を生成し、指定された形状 shape の result テンソルを生成します。
rng_distribution = UNIFORM の場合、乱数は [a, b) の区間で均一分布に従って生成されます。a >= b の場合、動作は未定義です。
rng_distribution = NORMAL の場合、乱数は平均 = a、標準偏差 = b の正規分布に従って生成されます。b < 0 の場合、動作は未定義です。
乱数の生成方法は実装で定義されます。たとえば、確定的である場合もあれば、そうでない場合もあります。また、非表示の状態を使用する場合もあれば、使用しないこともあります。
多くの関係者と話し合った結果、このオペレーションは事実上非推奨であることが判明したため、今後は削除を検討する予定です(#597)。
入力
| ラベル | 名前 | タイプ | 制約 | 
|---|---|---|---|
| (I1) | a | 整数型、ブール型、浮動小数点型の 0 次元テンソル | (C1)、(C2) | 
| (I2) | b | 整数型、ブール型、浮動小数点型の 0 次元テンソル | (C1)、(C2) | 
| (I3) | shape | si64型の 1 次元テンソル定数 | (C3) | 
| (I4) | rng_distribution | UNIFORMとNORMALの列挙型 | (C2) | 
出力
| 名前 | タイプ | 制約 | 
|---|---|---|
| result | 整数型、ブール型、浮動小数点型のテンソル | (C1~C3) | 
制約
- (C1)element_type(a) = element_type(b) = element_type(result)。
- (C2)rng_distribution = NORMALの場合はis_float(a)です。
- (C3)shape(result) = shape。
例
// %a = 0
// %b = 2
// %shape = [3, 3]
%result = "stablehlo.rng"(%a, %b, %shape) {
  rng_distribution = #stablehlo<rng_distribution UNIFORM>
} : (tensor<i32>, tensor<i32>, tensor<2xi64>) -> tensor<3x3xi32>
// %result: [
//           [1, 0, 1],
//           [1, 1, 1],
//           [0, 0, 0]
//          ]
rng_bit_generator
セマンティクス
初期状態 initial_state に基づいて、疑似乱数生成アルゴリズム rng_algorithm を使用して、均一なランダムビットで満たされた output と更新された出力状態 output_state を返します。出力は initial_state の確定的関数であることが保証されますが、実装間で確定的であるとは限りません。
rng_algorithm は、次のいずれかです。
- DEFAULT: 実装定義のアルゴリズム。
- THREE_FRY: Threefry アルゴリズムの実装定義バリアント。*
- PHILOX: Philox アルゴリズムの実装定義バリアント。*
* 参照: Salmon et al. SC 2011. 並列乱数: 1、2、3 で簡単に生成。
入力
| ラベル | 名前 | タイプ | 制約 | 
|---|---|---|---|
| (I1) | rng_algorithm | DEFAULT、THREE_FRY、PHILOXの列挙型 | (C2) | 
| (I2) | initial_state | ui64型の 1 次元テンソル | (C1)、(C2) | 
出力
| 名前 | タイプ | 制約 | 
|---|---|---|
| output_state | ui64型の 1 次元テンソル | (C1) | 
| output | 整数型または浮動小数点型のテンソル | 
制約
- (C1)type(initial_state) = type(output_state)。
- (C2)size(initial_state)は次のように定義されます。- rng_algorithm = DEFAULTの場合は実装定義。
- rng_algorithm = THREE_FRYの場合、- 2。
- rng_algorithm = PHILOXの場合は- 2または- 3。
 
例
// %initial_state: [1, 2]
%output_state, %output = "stablehlo.rng_bit_generator"(%initial_state) {
  rng_algorithm = #stablehlo<rng_algorithm THREE_FRY>
} : (tensor<2xui64>) -> (tensor<2xui64>, tensor<2x2xui64>)
// %output_state: [1, 6]
// %output: [
//           [9236835810183407956, 16087790271692313299],
//           [18212823393184779219, 2658481902456610144]
//          ]
round_nearest_afz
セマンティクス
operand テンソルに対して要素ごとの四捨五入を行い、同値の場合はゼロから離れた値を返します。これにより、result テンソルが生成されます。IEEE-754 仕様の roundToIntegralTiesToAway オペレーションを実装します。量子化された型の場合は、dequantize_op_quantize(round_nearest_afz, operand, type(result)) を実行します。
入力
| ラベル | 名前 | タイプ | 制約 | 
|---|---|---|---|
| (I1) | operand | 浮動小数点型のテンソルまたはテンソルごとの量子化テンソル | (C1) | 
出力
| 名前 | タイプ | 制約 | 
|---|---|---|
| result | 浮動小数点型のテンソルまたはテンソルごとの量子化テンソル | (C1) | 
制約
- (C1)baseline_type(operand) = baseline_type(result)。
例
// %operand = [-2.5, 0.4, 0.5, 0.6, 2.5]
%result = "stablehlo.round_nearest_afz"(%operand) : (tensor<5xf64>) -> tensor<5xf64>
// %result: [-3.0, 0.0, 1.0, 1.0, 3.0]
round_nearest_even
セマンティクス
テンソルで operand テンソルで、最も近い整数への要素単位の丸めを実行し、偶数の整数への結合を解除して、result テンソルを生成します。IEEE-754 仕様の roundToIntegralTiesToEven オペレーションを実装します。量子化された型の場合は、dequantize_op_quantize(round_nearest_even, operand, type(result)) を実行します。
入力
| ラベル | 名前 | タイプ | 制約 | 
|---|---|---|---|
| (I1) | operand | 浮動小数点型のテンソルまたはテンソルごとの量子化テンソル | (C1) | 
出力
| 名前 | タイプ | 制約 | 
|---|---|---|
| result | 浮動小数点型のテンソルまたはテンソルごとの量子化テンソル | (C1) | 
制約
- (C1)baseline_type(operand) = baseline_type(result)。
例
// %operand = [-2.5, 0.4, 0.5, 0.6, 2.5]
%result = "stablehlo.round_nearest_even"(%operand) : (tensor<5xf64>) -> tensor<5xf64>
// %result: [-2.0, 0.0, 0.0, 1.0, 2.0]
rsqrt
セマンティクス
operand テンソルに対して要素単位の逆平方根演算を実行し、result テンソルを生成します。要素のタイプに応じて、次の処理を行います。
- 浮動小数点数の場合: IEEE-754 の rSqrt。
- 複素数の場合: 複素数の逆平方根。
- 量子化された型の場合: dequantize_op_quantize(rsqrt, operand, type(result))。
入力
| ラベル | 名前 | タイプ | 制約 | 
|---|---|---|---|
| (I1) | operand | 浮動小数点型または複合型のテンソル、またはテンソルごとの量子化テンソル | (C1) | 
出力
| 名前 | タイプ | 制約 | 
|---|---|---|
| result | 浮動小数点型または複合型のテンソル、またはテンソルごとの量子化テンソル | (C1) | 
制約
- (C1)baseline_type(operand) = baseline_type(result)。
例
// %operand: [[1.0, 4.0], [9.0, 25.0]]
%result = "stablehlo.rsqrt"(%operand) : (tensor<2x2xf32>) -> tensor<2x2xf32>
// %result: [[1.0, 0.5], [0.33333343, 0.2]]
散布
セマンティクス
inputs テンソルと同じ results テンソルを生成します。ただし、scatter_indices で指定された複数のスライスが、update_computation を使用して値 updates で更新されます。
次の図は、updates... の要素が results... の要素にどのようにマッピングされるかを、具体的な例を使用して示しています。この図では、いくつかの updates... インデックスの例を選択して、それらが対応する results... インデックスの詳細について説明しています。
より正式に、index_space(updates[0]) 内のすべての update_index について、次のように定義します。
- update_scatter_dims = [d for d in axes(updates[0]) and d not in update_window_dims]。
- update_scatter_index = update_index[update_scatter_dims...]。
- start_indexは次のように定義されます。- scatter_indices[si0, ..., :, ..., siN]。ここで、- siは- update_scatter_indexの個々の要素であり、- index_vector_dim<- rank(scatter_indices)の場合、- :は- index_vector_dimインデックスに挿入されます。
- そうでない場合は [scatter_indices[update_scatter_index]]。
 
- axes(inputs[0])の- d_inputは、- d_input = scatter_dims_to_operand_dims[d_start]の場合、- full_start_index[d_input] = start_index[d_start]。
- それ以外の場合は full_start_index[d_input] = 0。
 
- axes(inputs[0])の- d_inputは、- d_input = input_batching_dims[i_batching]かつ- d_start = scatter_indices_batching_dims[i_batching]の場合は- full_batching_index[d_input] = update_scatter_index[d_start - (d_start < index_vector_dim ? 0 : 1)]。
- それ以外の場合は full_batching_index[d_input] = 0。
 
- update_window_index = update_index[update_window_dims...]。
- full_window_index = [wi0, ..., 0, ..., wiN]。ここで、- wiは- update_window_indexの個々の要素で、- 0は- inserted_window_dimsと- input_batching_dimsのインデックスに挿入されます。
- result_index = full_start_index + full_batching_index + full_window_index。
results = exec(schedule, inputs)。ここで、
- scheduleは、- index_space(updates[0])の実装定義された並べ替えです。
- exec([update_index, ...], results) = exec([...], updated_results)ここで:- result_indexが- shape(results...)の範囲内にある場合
- updates_converted = to_destination_type( updates...[update_index], type(func_inputs(update_computation) [len(func_inputs(update_computation))//2:])... )
- updated_values = update_computation(results...[result_index], updates_converted)
- updated_resultsは、- results...[result_index]が- updated_values...に設定された- resultsのコピーです。
- それ以外の場合
- updated_results = results。
 
- exec([], results) = results。
indices_are_sorted が true の場合、実装では scatter_indices が scatter_dims_to_operand_dims に関して並べ替えられていると想定できます。それ以外の場合、動作は未定義です。より正式には、indices(result) のすべての i1 < i2 について、full_start_index(i1) <= full_start_index(i2) です。
unique_indices が true の場合、実装は分散されているすべての result_index インデックスが一意であると想定できます。unique_indices が true で、分散先のインデックスが一意でない場合、動作は未定義です。
入力
| ラベル | 名前 | タイプ | 制約 | 
|---|---|---|---|
| (I1) | inputs | テンソルの可変数またはテンソルごとの量子化テンソル | (C1)、(C2)、(C4 ~ C6)、(C11)、(C13)、(C18)、(C21)、(C23 ~ C24) | 
| (I2) | scatter_indices | 整数型のテンソル | (C4)、(C15)、(C19)、(C22) | 
| (I3) | updates | テンソルの可変数またはテンソルごとの量子化テンソル | (C3-C6)、(C8) | 
| (I4) | update_window_dims | si64型の 1 次元のテンソル定数 | (C2)、(C4)、(C7~C8) | 
| (I5) | inserted_window_dims | si64型の 1 次元テンソル定数 | (C2)、(C4)、(C9 ~ C11) | 
| (I6) | input_batching_dims | si64型の 1 次元テンソル定数 | (C2)、(C4)、(C9)、(C12-13)、(C17-18)、(C20) | 
| (I7) | scatter_indices_batching_dims | si64型の 1 次元テンソル定数 | (C14 ~ C18) | 
| (I8) | scatter_dims_to_operand_dims | si64型の 1 次元テンソル定数 | (C19 ~ C21) | 
| (I9) | index_vector_dim | si64型の定数 | (C4)、(C16)、(C19)、(C22) | 
| (I10) | indices_are_sorted | i1型の定数 | |
| (I11) | unique_indices | i1型の定数 | |
| (I12) | update_computation | 関数 | (C23) | 
出力
| 名前 | タイプ | 制約 | 
|---|---|---|
| results | テンソルの可変数またはテンソルごとの量子化テンソル | (C24 ~ C25) | 
制約
- (C1)same(shape(inputs...))。
- (C2)`rank(inputs[0]) = size(update_window_dims) + size(inserted_window_dims)
- size(input_batching_dims)`.
 
- (C3)same(shape(updates...))。
- (C4)shape(updates[0]) = combine(update_scatter_dim_sizes, update_window_dim_sizes)ここで:- update_scatter_dim_sizes = shape(scatter_indices)。ただし、- index_vector_dimに対応する- scatter_indicesのディメンション サイズは含まれません。
- update_window_dim_sizes <= shape(inputs[0])。ただし、- inserted_window_dimsと- input_batching_dimsに対応する- inputs[0]のディメンション サイズは含まれません。
- combineは、- update_scatter_dim_sizesを- update_scatter_dimsに対応する軸に配置し、- update_window_dim_sizesを- update_window_dimsに対応する軸に配置します。
 
- (C5)0 < size(inputs) = size(updates) = N。
- (C6)element_type(updates...) = element_type(inputs...)。
- (C7)is_unique(update_window_dims) and is_sorted(update_window_dims)。
- (C8)0 <= update_window_dims < rank(updates[0])。
- (C9)is_unique(concatenate(inserted_window_dims, input_batching_dims))
- (C10)is_sorted(inserted_window_dims)。
- (C11)0 <= inserted_window_dims < rank(inputs[0])。
- (C12)is_sorted(input_batching_dims)。
- (C13)0 <= input_batching_dims < rank(inputs[0]))。
- (C14)is_unique(scatter_indices_batching_dims)。
- (C15)0 <= scatter_indices_batching_dims < rank(scatter_indices)。
- (C16)index_vector_dim not in scatter_indices_batching_dims。
- (C17)size(input_batching_dims) == size(scatter_indices_batching_dims)。
- (C18)dim(inputs[0], input_batching_dims...) = dim(scatter_indices, scatter_indices_batching_dims...)。
- (C19)size(scatter_dims_to_operand_dims) = index_vector_dim < rank(scatter_indices) ? dim(scatter_indices, index_vector_dim) : 1。
- (C20)is_unique(concatenate(scatter_dims_to_operand_dims, input_batching_dims))。
- (C21)0 <= scatter_dims_to_operand_dims < rank(inputs[0])。
- (C22)0 <= index_vector_dim <= rank(scatter_indices)。
- (C23)update_computationの型は(tensor<E0>, ..., tensor<EN-1>, tensor<E0>, ..., tensor<EN-1>) -> (tensor<E0>, ..., tensor<EN-1>)です。ここで、is_promotable(element_type(inputs[i]), Ei)。
- (C24)shape(inputs...) = shape(results...)。
- (C25)[0,N)内のすべてのiに対してelement_type(results[i]) = Ei。
例
// %input: [
//          [
//           [[1, 2], [3, 4], [5, 6], [7, 8]],
//           [[9, 10],[11, 12], [13, 14], [15, 16]],
//           [[17, 18], [19, 20], [21, 22], [23, 24]]
//          ],
//          [
//           [[25, 26], [27, 28], [29, 30], [31, 32]],
//           [[33, 34], [35, 36], [37, 38], [39, 40]],
//           [[41, 42], [43, 44], [45, 46], [47, 48]]
//          ]
//         ]
// %scatter_indices: [
//                    [
//                     [[0, 0], [1, 0], [2, 1]],
//                     [[0, 1], [1, 1], [0, 9]]
//                    ],
//                    [
//                     [[0, 0], [2, 1], [2, 2]],
//                     [[1, 2], [0, 1], [1, 0]]
//                    ]
//                   ]
// %update: [
//           [
//            [[1, 1], [1, 1], [1, 1]],
//            [[1, 1], [1, 1], [1, 1]]
//           ],
//           [
//            [[1, 1], [1, 1], [1, 1]],
//            [[1, 1], [1, 1], [1, 1]]
//           ]
//          ]
%result = "stablehlo.scatter"(%input, %scatter_indices, %update) ({
  ^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
    %0 = "stablehlo.add"(%arg0, %arg1) : (tensor<i64>, tensor<i64>) -> tensor<i64>
    "stablehlo.return"(%0) : (tensor<i64>) -> ()
}) {
  scatter_dimension_numbers = #stablehlo.scatter<
    update_window_dims = [3, 4],
    inserted_window_dims = [1],
    input_batching_dims = [0],
    scatter_indices_batching_dims = [1],
    scatter_dims_to_operand_dims = [2, 1],
    index_vector_dim = 3>,
  indices_are_sorted = false,
  unique_indices = false
} : (tensor<2x3x4x2xi64>, tensor<2x2x3x2xi64>, tensor<2x2x3x2x2xi64>) -> tensor<2x3x4x2xi64>
// %result: [
//           [
//            [[3, 4], [6, 7], [6, 7], [7, 8]],
//            [[9, 10],[11, 12], [15, 16], [17, 18]],
//            [[17, 18], [19, 20], [22, 23], [24, 25]]
//           ],
//           [
//            [[25, 26], [28, 29], [30, 31], [31, 32]],
//            [[35, 36], [38, 39], [38, 39], [39, 40]],
//            [[41, 42], [44, 45], [46, 47], [47, 48]]
//           ]
//          ]
選択
セマンティクス
result テンソルを生成します。各要素は、pred の対応する要素の値に基づいて on_true テンソルまたは on_false テンソルから選択されます。正式には result[result_index] = pred_element ? on_true[result_index] :
on_false[result_index] です。ここで pred_element = rank(pred) = 0 ? pred[] :
pred[result_index] は量子化された型の場合は、dequantize_select_quantize(pred, on_true, on_false, type(result)) を実行します。
入力
| ラベル | 名前 | タイプ | 制約 | 
|---|---|---|---|
| (I1) | pred | i1型のテンソル | (C1) | 
| (I2) | on_true | テンソルまたはテンソルごとの量子化テンソル | (C1-C2) | 
| (I3) | on_false | テンソルまたはテンソルごとの量子化テンソル | (C2) | 
出力
| 名前 | タイプ | 制約 | 
|---|---|---|
| result | テンソルまたはテンソルごとの量子化テンソル | (C2) | 
制約
- (C1)rank(pred) = 0 or shape(pred) = shape(on_true)。
- (C2)baseline_type(on_true) = baseline_type(on_false) = baseline_type(result)。
例
// %pred: [[false, true], [true, false]]
// %on_true: [[1, 2], [3, 4]]
// %on_false: [[5, 6], [7, 8]]
%result = "stablehlo.select"(%pred, %on_true, %on_false) : (tensor<2x2xi1>, tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
// %result: [[5, 2], [3, 8]]
select_and_scatter
セマンティクス
select を使用した input テンソルの reduce_window の結果に基づいて、scatter を使用して source テンソルから値を散布し、result テンソルを生成します。
次の図は、具体的な例を使用して、result の要素が operand と source からどのように計算されるかを示しています。
よりフォーマルな表現:
- selected_values = reduce_window_without_init(...)は、次の入力に置き換えます。- inputs = [operand].
- window_dimensions、- window_strides、- paddingはそのまま使用されます。
- base_dilations = windows_dilations = 1。
- bodyは次のように定義されます。
 - def body(arg0: tensor<E>, arg1: tensor<E>) -> tensor<E>: return select(arg0, arg1) ? arg0 : arg1;- ここで、 - E = element_type(operand)と- reduce_window_without_initは- reduce_windowとまったく同じように機能しますが、基になる- reduceの- schedule(削減を参照)に init 値が含まれないことが異なります。対応するウィンドウに値がない場合の動作は現在未定義です(#731)。
- result[result_index] = reduce([source_values], [init_value], [0], scatter)ここで:- source_values = [source[source_index] for source_index in source_indices]。
- selected_index(source_index) = operand_index:- selected_values[source_index]に- operand_indexの- operand要素がある場合。
- source_indices = [source_index for source_index in indices(source) if selected_index(source_index) = result_index]。
 
入力
| ラベル | 名前 | タイプ | 制約 | 
|---|---|---|---|
| (I1) | operand | テンソルまたはテンソルごとの量子化テンソル | (C1~C4)、(C6)、(C8~C11) | 
| (I2) | source | テンソルまたはテンソルごとの量子化テンソル | (C1)、(C2) | 
| (I3) | init_value | 0 次元テンソルまたはテンソルごとの量子化テンソル | (C3) | 
| (I4) | window_dimensions | si64型の 1 次元のテンソル定数 | (C2)、(C4)、(C5) | 
| (I5) | window_strides | si64型の 1 次元のテンソル定数 | (C2)、(C6)、(C7) | 
| (I6) | padding | si64型の 2 次元のテンソル定数 | (C2)、(C8) | 
| (I7) | select | 関数 | (C9) | 
| (I8) | scatter | 関数 | (C10) | 
出力
| 名前 | タイプ | 制約 | 
|---|---|---|
| result | テンソルまたはテンソルごとの量子化テンソル | (C11-C12) | 
制約
- (C1)element_type(operand) = element_type(source)。
- (C2)shape(source) = num_windows。ここで:- padded_operand_shape = padding[:, 0] + shape(operand) + padding[:, 1]。
- is_empty_window = padded_operand_shape = 0 || window_dimensions > padded_operand_shape
- num_windows = is_empty_window ? 0 : floor((padded_operand_shape - window_dimensions) / window_strides) + 1。
 
- (C3)element_type(init_value) = element_type(operand)。
- (C4)size(window_dimensions) = rank(operand)。
- (C5)0 < window_dimensions。
- (C6)size(window_strides) = rank(operand)。
- (C7)0 < window_strides。
- (C8)shape(padding) = [rank(operand), 2]。
- (C9)selectの型は(tensor<E>, tensor<E>) -> tensor<i1>(E = element_type(operand))です。
- (C10)scatterの型は(tensor<E>, tensor<E>) -> tensor<E>です。is_promotable(element_type(operand), E)
- (C11)shape(operand) = shape(result)。
- (C12)element_type(result) = E。
例
// %operand: [[1, 5], [2, 5], [3, 6], [4, 4]]
// %source: [[5, 6], [7, 8]]
// %init_value: 0
%result = "stablehlo.select_and_scatter"(%operand, %source, %init_value) ({
  ^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
    %0 = "stablehlo.compare"(%arg0, %arg1) {
      comparison_direction = #stablehlo<comparison_direction GE>
    } : (tensor<i64>, tensor<i64>) -> tensor<i1>
    "stablehlo.return"(%0) : (tensor<i1>) -> ()
}, {
  ^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
    %0 = "stablehlo.add"(%arg0, %arg1) : (tensor<i64>, tensor<i64>) -> tensor<i64>
    "stablehlo.return"(%0) : (tensor<i64>) -> ()
}) {
  window_dimensions = array<i64: 3, 1>,
  window_strides = array<i64: 2, 1>,
  padding = dense<[[0, 1], [0, 0]]> : tensor<2x2xi64>
} : (tensor<4x2xi64>, tensor<2x2xi64>, tensor<i64>) -> tensor<4x2xi64>
// %result: [[0, 0], [0, 0], [5, 14], [7, 0]]
送信
セマンティクス
inputs をチャンネル channel_id に送信し、result トークンを生成します。
is_host_transfer が true の場合、オペレーションはホストにデータを転送します。そうでない場合は、別のデバイスにデータを転送します。つまり、実装は定義されます。このフラグは channel_type で提供される情報と重複するため、今後はどちらか一方のみを保持する予定です(#666)。
入力
| ラベル | 名前 | タイプ | 制約 | 
|---|---|---|---|
| (I1) | inputs | テンソルまたは量子化テンソルの可変数 | |
| (I2) | token | token | |
| (I3) | channel_id | si64型の定数 | |
| (I4) | channel_type | DEVICE_TO_DEVICEとDEVICE_TO_HOSTの列挙型 | (C1) | 
| (I5) | is_host_transfer | i1型の定数 | (C1) | 
出力
| 名前 | タイプ | 
|---|---|
| result | token | 
制約
- (C1)channel_typeは次のように定義されます。- is_host_transfer = trueの場合、- DEVICE_TO_HOST
- それ以外の場合は DEVICE_TO_DEVICE。
 
例
%result = "stablehlo.send"(%operand, %token) {
  channel_handle = #stablehlo.channel_handle<handle = 1, type = 2>,
  is_host_transfer = true
} : (tensor<2x2xi64>, !stablehlo.token) -> !stablehlo.token
shift_left
セマンティクス
lhs テンソルを rhs ビット分要素ごとに左シフトし、result テンソルを生成します。
入力
| ラベル | 名前 | タイプ | 制約 | 
|---|---|---|---|
| (I1) | lhs | 整数型のテンソル | (C1) | 
| (I2) | rhs | 整数型のテンソル | (C1) | 
出力
| 名前 | タイプ | 制約 | 
|---|---|---|
| result | 整数型のテンソル | (C1) | 
制約
- (C1)type(lhs) = type(rhs) = type(result)。
例
// %lhs: [-1, 0, 1]
// %rhs: [1, 2, 3]
%result = "stablehlo.shift_left"(%lhs, %rhs): (tensor<3xi64>, tensor<3xi64>) -> tensor<3xi64>
// %result: [-2, 0, 8]
shift_right_arithmetic
セマンティクス
lhs テンソルに対して rhs ビット数だけ要素単位の右シフト演算を実行し、result テンソルを生成します。
入力
| ラベル | 名前 | タイプ | 制約 | 
|---|---|---|---|
| (I1) | lhs | 整数型のテンソル | (C1) | 
| (I2) | rhs | 整数型のテンソル | (C1) | 
出力
| 名前 | タイプ | 制約 | 
|---|---|---|
| result | 整数型のテンソル | (C1) | 
制約
- (C1)type(lhs) = type(rhs) = type(result)。
例
// %lhs: [-1, 0, 8]
// %rhs: [1, 2, 3]
%result = "stablehlo.shift_right_arithmetic"(%lhs, %rhs): (tensor<3xi64>, tensor<3xi64>) -> tensor<3xi64>
// %result: [-1, 0, 1]
shift_right_logical
セマンティクス
lhs テンソルを rhs ビット数だけ要素ごとに論理右シフトし、result テンソルを生成します。
入力
| ラベル | 名前 | タイプ | 制約 | 
|---|---|---|---|
| (I1) | lhs | 整数型のテンソル | (C1) | 
| (I2) | rhs | 整数型のテンソル | (C1) | 
出力
| 名前 | タイプ | 制約 | 
|---|---|---|
| result | 整数型のテンソル | (C1) | 
制約
- (C1)type(lhs) = type(rhs) = type(result)。
例
// %lhs: [-1, 0, 8]
// %rhs: [1, 2, 3]
%result = "stablehlo.shift_right_logical"(%lhs, %rhs): (tensor<3xi64>, tensor<3xi64>) -> tensor<3xi64>
// %result: [9223372036854775807, 0, 1]
署名
セマンティクス
要素ごとの operand の符号を返して、result テンソルを生成します。より正式には、各要素 x のセマンティクスは、次のように Python 構文を使用して表現できます。
def sign(x):
  if is_integer(x):
    if compare(x, 0, LT, SIGNED): return -1
    if compare(x, 0, EQ, SIGNED): return 0
    return 1
  elif is_float(x):
    if is_nan(x): return NaN
    if compare(x, -0.0, EQ, FLOAT): return -0.0
    if compare(x, +0.0, EQ, FLOAT): return +0.0
    if compare(x, 0.0, LT, FLOAT): return -1.0
    return 1.0
  elif is_complex(x):
    if is_nan(real(x)) or is_nan(imag(x)): return (NaN, NaN)
    if compare(x, (0.0, 0.0), EQ, FLOAT): return (0.0, 0.0)
    return divide(x, convert(abs(x), type(x)))
量子化型の場合は、dequantize_op_quantize(sign, operand, type(result)) を実行します。
入力
| ラベル | 名前 | タイプ | 制約 | 
|---|---|---|---|
| (I1) | operand | 符号付き整数、浮動小数点数、複素型のテンソル、またはテンソルごとの量子化テンソル | (C1) | 
出力
| 名前 | タイプ | 制約 | 
|---|---|---|
| result | 符号付き整数型、浮動小数点型、複合型のテンソル、またはテンソルごとの量子化テンソル | (C1) | 
制約
- (C1)baseline_type(operand) = baseline_type(result)。
例
// Logical values: +NaN, -1.0, -0.0, +0.0, 1.0
// operand: [0x7FFFFFFFFFFFFFFF, -1.0, -0.0, 0.0, 1.0]
%result = "stablehlo.sign"(%operand) : (tensor<5xf64>) -> tensor<5xf64>
// Logical values: +NaN, -1.0, -0.0, +0.0, 1.0
// %result: [0x7FFFFFFFFFFFFFFF, -1.0, -0.0, 0.0, 1.0]
サイン
セマンティクス
operand テンソルで要素ごとの正弦演算を実行し、result テンソルを生成します。要素のタイプに応じて、次の操作を行います。
- 浮動小数点数の場合: IEEE-754 の sin。
- 複素数の場合: 複素正弦。
- 量子化された型の場合: dequantize_op_quantize(sine, operand, type(result))。
入力
| ラベル | 名前 | タイプ | 制約 | 
|---|---|---|---|
| (I1) | operand | 浮動小数点型または複合型のテンソル、またはテンソルごとの量子化テンソル | (C1) | 
出力
| 名前 | タイプ | 制約 | 
|---|---|---|
| result | 浮動小数点型または複合型のテンソル、またはテンソルごとの量子化テンソル | (C1) | 
制約
- (C1)baseline_type(operand) = baseline_type(result)。
例
// %operand: [
//            [0.0, 1.57079632],       // [0, pi/2]
//            [3.14159265, 4.71238898] // [pi, 3pi/2]
//           ]
%result = "stablehlo.sine"(%operand) : (tensor<2x2xf32>) -> tensor<2x2xf32>
// %result: [[0.0, 1.0], [0.0, -1.0]]
slice
セマンティクス
静的に計算された開始インデックスを使用して operand からスライスを取り出し、result テンソルを生成します。start_indices には各ディメンションのスライスの開始インデックスが、limit_indices には各ディメンションのスライスの終了インデックス(これを含まない)が含まれ、strides には各ディメンションのストライドが含まれます。
より正式には、result[result_index] = operand[operand_index] で operand_index = start_indices + result_index * strides。
入力
| ラベル | 名前 | タイプ | 制約 | 
|---|---|---|---|
| (I1) | operand | テンソルまたはテンソルごとの量子化テンソル | (C1 ~ C3)、(C5) | 
| (I2) | start_indices | si64型の 1 次元テンソル定数 | (C2)、(C3)、(C5) | 
| (I3) | limit_indices | si64型の 1 次元テンソル定数 | (C2)、(C3)、(C5) | 
| (I4) | strides | si64型の 1 次元テンソル定数 | (C2)、(C4) | 
出力
| 名前 | タイプ | 制約 | 
|---|---|---|
| result | テンソルまたはテンソルごとの量子化テンソル | (C1)、(C5) | 
制約
- (C1)element_type(operand) = element_type(result)。
- (C2)size(start_indices) = size(limit_indices) = size(strides) = rank(operand)。
- (C3)0 <= start_indices <= limit_indices <= shape(operand)。
- (C4)0 < strides。
- (C5)shape(result) = ceil((limit_indices - start_indices) / strides)。
例
// %operand: [
//            [0, 0, 0, 0],
//            [0, 0, 1, 1],
//            [0, 0, 1, 1]
//           ]
%result = "stablehlo.slice"(%operand) {
  start_indices = array<i64: 1, 2>,
  limit_indices = array<i64: 3, 4>,
  strides = array<i64: 1, 1>
} : (tensor<3x4xi64>) -> tensor<2x2xi64>
// % result: [
//            [1, 1],
//            [1, 1]
//           ]
並べ替え
セマンティクス
comparator に従って、ディメンション dimension に沿って inputs の 1 次元スライスを並べ替え、results を生成します。
他のオペレーションの同様の入力とは異なり、dimension では負の値を使用できます。セマンティクスは次のとおりです。今後、一貫性のため、この操作が禁止される可能性があります(#1377)。
is_stable が true の場合、並べ替えは安定します。つまり、比較演算子によって等しいとみなされる要素の相対順序は保持されます。入力が 1 つの場合、2 つの要素 e1 と e2 は、comparator(e1, e2) = comparator(e2, e1) = false の場合にのみ、比較演算子によって等しいと見なされます。これが複数の入力に一般化される仕組みについては、以下の形式化をご覧ください。
より正式に、index_space(results[0]) 内のすべての result_index について、次のように定義します。
- adjusted_dimension = dimension >= 0 ? dimension : rank(inputs[0]) + dimension。
- result_slice = [ri0, ..., :, ..., riR-1]。ここで、- riNは- result_indexの個々の要素であり、- :は- adjusted_dimensionに挿入されます。
- inputs_together = (inputs[0]..., ..., inputs[N-1]...)。
- results_together[result_slice] = sort(inputs_together[result_slice], comparator_together)。
- ここで、sortは 1 次元スライスを降順以外の順序で並べ替えます。これは、左側の引数が右側の 2 番目の引数より小さい場合にcomparator_togetherがtrueを返すことを想定しています。
- def comparator_together(lhs_together, rhs_together): args = [] for (lhs_el, rhs_el) in zip(lhs_together, rhs_together): args.append(lhs_el) args.append(rhs_el) return comparator(*args)
- (results[0]..., ..., results[N-1]...) = results_together。
入力
| ラベル | 名前 | タイプ | 制約 | 
|---|---|---|---|
| (I1) | inputs | テンソルの可変数またはテンソルごとの量子化テンソル | (C1 ~ C5) | 
| (I2) | dimension | si64型の定数 | (C4) | 
| (I3) | is_stable | i1型の定数 | |
| (I4) | comparator | 関数 | (C5) | 
出力
| 名前 | タイプ | 制約 | 
|---|---|---|
| results | テンソルの可変数またはテンソルごとの量子化テンソル | (C2)、(C3) | 
制約
- (C1)0 < size(inputs)。
- (C2)type(inputs...) = type(results...)。
- (C3)same(shape(inputs...) + shape(results...))。
- (C4)-R <= dimension < R(R = rank(inputs[0]))。
- (C5)comparatorの型は(tensor<E1>, tensor<E1>, ..., tensor<EN-1>, tensor<EN-1>) -> tensor<i1>です。ここで、Ei = element_type(inputs[i])。
例
// %input0 = [[1, 2, 3], [3, 2, 1]]
// %input1 = [[3, 2, 1], [1, 2, 3]]
%result0, %result1 = "stablehlo.sort"(%input0, %input1) ({
  ^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>, %arg2: tensor<i64>, %arg3: tensor<i64>):
    %predicate = "stablehlo.compare"(%arg0, %arg1) {
      comparison_direction = #stablehlo<comparison_direction GT>
    } : (tensor<i64>, tensor<i64>) -> tensor<i1>
    "stablehlo.return"(%predicate) : (tensor<i1>) -> ()
}) {
  dimension = 0 : i64,
  is_stable = true
} : (tensor<2x3xi64>, tensor<2x3xi64>) -> (tensor<2x3xi64>, tensor<2x3xi64>)
// %result0 = [[3, 2, 3], [1, 2, 1]]
// %result1 = [[1, 2, 1], [3, 2, 3]]
sqrt
セマンティクス
operand テンソルで要素ごとの平方根演算を実行し、result テンソルを生成します。要素のタイプに応じて、次の操作を行います。
- 浮動小数点数の場合: IEEE-754 の squareRoot。
- 複素数の場合: 複素平方根。
- 量子化された型の場合: dequantize_op_quantize(sqrt, operand, type(result))。
入力
| ラベル | 名前 | タイプ | 制約 | 
|---|---|---|---|
| (I1) | operand | 浮動小数点型または複合型のテンソル、またはテンソルごとの量子化テンソル | (C1) | 
出力
| 名前 | タイプ | 制約 | 
|---|---|---|
| result | 浮動小数点型または複合型のテンソル、またはテンソルごとの量子化テンソル | (C1) | 
制約
- (C1)baseline_type(operand) = baseline_type(result)。
例
// %operand: [[0.0, 1.0], [4.0, 9.0]]
%result = "stablehlo.sqrt"(%operand) : (tensor<2x2xf32>) -> tensor<2x2xf32>
// %result: [[0.0, 1.0], [2.0, 3.0]]
subtract
セマンティクス
2 つのテンソル lhs と rhs の要素ごとの減算を行い、result テンソルを生成します。要素のタイプに応じて、次の処理を行います。
- 整数の場合: 整数の減算。
- 浮動小数点数: IEEE-754 の subtraction。
- 複素数の場合: 複素減算。
- 量子化された型の場合:
- dequantize_op_quantize(subtract, lhs, rhs, type(result))。
 
入力
| ラベル | 名前 | タイプ | 制約 | 
|---|---|---|---|
| (I1) | lhs | 整数型、浮動小数点型、複合型のテンソル、またはテンソルごとの量子化テンソル | (C1) | 
| (I2) | rhs | 整数、浮動小数点、複合型のテンソル、またはテンソルごとの量子化テンソル | (C1) | 
出力
| 名前 | タイプ | 制約 | 
|---|---|---|
| result | 整数型、浮動小数点型、複合型のテンソル、またはテンソルごとの量子化テンソル | (C1) | 
制約
- (C1)baseline_type(lhs) = baseline_type(rhs) = baseline_type(result)。
例
// %lhs: [[6, 8], [10, 12]]
// %rhs: [[5, 6], [7, 8]]
%result = "stablehlo.subtract"(%lhs, %rhs) : (tensor<2x2xf32>, tensor<2x2xf32>) -> (tensor<2x2xf32>)
// %result: [[1, 2], [3, 4]]
tan
セマンティクス
operand テンソルに対して要素ごとのタンジェント演算を実行し、result テンソルを生成します。要素のタイプに応じて、次の操作を行います。
- 浮動小数点数の場合: IEEE-754 の tan。
- 複素数の場合: 複素タンジェント。
- 量子化型の場合: dequantize_op_quantize(tan, operand, type(result))。
入力
| ラベル | 名前 | タイプ | 制約 | 
|---|---|---|---|
| (I1) | operand | 浮動小数点型または複合型のテンソル、またはテンソルごとの量子化テンソル | (C1) | 
出力
| 名前 | タイプ | 制約 | 
|---|---|---|
| result | 浮動小数点型または複合型のテンソル、またはテンソルごとの量子化テンソル | (C1) | 
制約
- (C1)baseline_type(operand) = baseline_type(result)。
例
// %operand: [
//            [0.0, 1.57079632],       // [0, pi/2]
//            [3.14159265, 4.71238898] // [pi, 3pi/2]
//           ]
%result = "stablehlo.tan"(%operand) : (tensor<2x2xf64>) -> tensor<2x2xf64>
// %result: [
//           [0.0, 1.63312e+16],
//           [0.0, 5.44375e+15]
//          ]
tanh
セマンティクス
operand テンソルに対して要素単位の双曲線正接演算を実行し、result テンソルを生成します。要素のタイプに応じて、次の操作を行います。
- 浮動小数点数の場合: IEEE-754 の tanh。
- 複素数の場合: 複素双曲線正接。
- 量子化された型の場合:
- dequantize_op_quantize(tanh, operand, type(result))。
 
入力
| ラベル | 名前 | タイプ | 制約 | 
|---|---|---|---|
| (I1) | operand | 浮動小数点型または複合型のテンソル、またはテンソルごとの量子化テンソル | (C1) | 
出力
| 名前 | タイプ | 制約 | 
|---|---|---|
| result | 浮動小数点型または複合型のテンソル、またはテンソルごとの量子化テンソル | (C1) | 
制約
- (C1)baseline_type(operand) = baseline_type(result)。
例
// %operand: [-1.0, 0.0, 1.0]
%result = "stablehlo.tanh"(%operand) : (tensor<3xf32>) -> tensor<3xf32>
// %result: [-0.76159416, 0.0, 0.76159416]
行 / 列の入れ替え
セマンティクス
permutation を使用して operand テンソルの次元を並べ替え、result テンソルを生成します。より正式には、result[result_index] = operand[operand_index] で、result_index[d] = operand_index[permutation[d]] です。
入力
| ラベル | 名前 | タイプ | 制約 | 
|---|---|---|---|
| (I1) | operand | テンソルまたは量子化テンソル | (C1~C4) | 
| (I2) | permutation | si64型の 1 次元テンソル定数 | (C2 ~ C4) | 
出力
| 名前 | タイプ | 制約 | 
|---|---|---|
| result | テンソルまたは量子化テンソル | (C1)、(C3 ~ C4) | 
制約
- (C1)element_type(result)は次のように定義されます。- element_type(operand)(- !is_per_axis_quantized(operand)の場合)。
- element_type(operand)。ただし、- quantization_dimension(operand)と- quantization_dimension(result)が異なる場合があります。
 
- (C2)permutationはrange(rank(operand))の並べ替えです。
- (C3)shape(result) = dim(operand, permutation...)。
- (C4)is_per_axis_quantized(result)の場合はquantization_dimension(operand) = permutation(quantization_dimension(result))です。
例
// %operand: [
//            [[1,2], [3,4], [5,6]],
//            [[7,8], [9,10], [11,12]]
//           ]
%result = "stablehlo.transpose"(%operand) {
  permutation = array<i64: 2, 1, 0>
} : (tensor<2x3x2xi32>) -> tensor<2x3x2xi32>
// %result: [
//           [[1,7], [3,9], [5,11]],
//           [[2,8], [4,10], [6,12]]
//          ]
triangular_solve
セマンティクス
下三角または上三角係数行列を持つ連立一次方程式のバッチを解きます。
より正式には、a と b が指定されている場合、result[i0, ..., iR-3, :, :] は left_side が true の場合の op(a[i0, ..., iR-3, :, :]) * x = b[i0, ..., iR-3, :, :] の解、または left_side が false の場合の x * op(a[i0, ..., iR-3, :, :]) = b[i0, ..., iR-3, :, :] です。変数 x は、op(a) が transpose_a によって決定されます。transpose_a は次のいずれかです。
- NO_TRANSPOSE:- aをそのまま使用してオペレーションを実行します。
- TRANSPOSE:- aの転置でオペレーションを実行します。
- ADJOINT:- aの共役転置に対して演算を実行します。
入力データは、lower が true の場合、a の下三角からのみ読み取られます。それ以外の場合は、a の上三角から読み取られます。出力データは同じ三角形に返されます。もう 1 つの三角形の値は実装で定義されます。
unit_diagonal が true の場合、実装は a の対角要素が 1 に等しいと想定できます。それ以外の場合は、動作は未定義です。
量子化された型の場合は、dequantize_op_quantize(lambda x, y: triangular_solve(x, y, left_side, lower,
unit_diagonal, transpose_a), a, b, type(result)) を実行します。
入力
| ラベル | 名前 | タイプ | 制約 | 
|---|---|---|---|
| (I1) | a | 浮動小数点型または複合型のテンソル、またはテンソルごとの量子化テンソル | (C1~C3) | 
| (I2) | b | 浮動小数点型または複合型のテンソル、またはテンソルごとの量子化テンソル | (C1~C4) | 
| (I3) | left_side | i1型の定数 | (C3) | 
| (I4) | lower | i1型の定数 | |
| (I5) | unit_diagonal | i1型の定数 | |
| (I6) | transpose_a | NO_TRANSPOSE、TRANSPOSE、ADJOINTの列挙型 | 
出力
| 名前 | タイプ | 制約 | 
|---|---|---|
| result | 浮動小数点型または複合型のテンソル、またはテンソルごとの量子化テンソル | (C1) | 
制約
- (C1)baseline_element_type(a) = baseline_element_type(b)。
- (C2)2 <= rank(a) = rank(b) = R。
- (C3)shape(a)とshape(b)の関係は次のように定義されます。- shape(a)[:-3] = shape(b)[:-3]。
- dim(a, -2) = dim(a, -1) = dim(b, left_side ? -2 : -1)。
 
- (C4)baseline_type(b) = baseline_type(result)。
例
// %a = [
//       [1.0, 0.0, 0.0],
//       [2.0, 4.0, 0.0],
//       [3.0, 5.0, 6.0]
//      ]
// %b = [
//       [2.0, 0.0, 0.0],
//       [4.0, 8.0, 0.0],
//       [6.0, 10.0, 12.0]
//      ]
%result = "stablehlo.triangular_solve"(%a, %b) {
  left_side = true,
  lower = true,
  unit_diagonal = false,
  transpose_a = #stablehlo<transpose NO_TRANSPOSE>
} : (tensor<3x3xf32>, tensor<3x3xf32>) -> tensor<3x3xf32>
// %result: [
//           [2.0, 0.0, 0.0],
//           [0.0, 2.0, 0.0],
//           [0.0, 0.0, 2.0]
//          ]
tuple
セマンティクス
値 val から result タプルを生成します。
入力
| ラベル | 名前 | タイプ | 制約 | 
|---|---|---|---|
| (I1) | val | 可変長の値の数 | (C1) | 
出力
| 名前 | タイプ | 制約 | 
|---|---|---|
| result | tuple | (C1) | 
制約
- (C1)resultの型はtuple<E0, ..., EN-1>で、Ei = type(val[i])です。
例
// %val0: [1.0, 2.0]
// %val1: (3)
%result = "stablehlo.tuple"(%val0, %val1) : (tensor<2xf32>, tuple<tensor<i32>>) -> tuple<tensor<2xf32>, tuple<tensor<i32>>>
// %result: ([1.0, 2.0], (3))
uniform_dequantize
セマンティクス
operand 型で定義された量子化パラメータに従って、量子化テンソル operand を浮動小数点テンソル result に要素ごとに変換します。
正式には result = dequantize(operand) です。
入力
| ラベル | 名前 | タイプ | 制約 | 
|---|---|---|---|
| (I1) | operand | 量子化テンソル | (C1)、(C2) | 
出力
| 名前 | タイプ | 制約 | 
|---|---|---|
| result | 浮動小数点型のテンソル | (C1)、(C2) | 
制約
- (C1)shape(operand) = shape(result)。
- (C2)element_type(result) = expressed_type(operand)。
例
// %operand: [10, 10]
%result = "stablehlo.uniform_dequantize"(%operand) : (tensor<2x!quant.uniform<i8:f32:0, {0.1:-30,0.5:-20}>>) -> tensor<2xf32>
// %result: [4.0, 15.0]
uniform_quantize
セマンティクス
result 型で定義された量子化パラメータに従って、浮動小数点テンソルまたは量子化テンソル operand を量子化テンソル result に要素単位で変換します。
よりフォーマルな表現では
- is_float(operand)の場合:- result = quantize(operand, type(result))。
 
- is_quantized(operand)の場合:- float_result = dequantize(operand)。
- result = quantize(float_result, type(result))。
 
入力
| ラベル | 名前 | タイプ | 制約 | 
|---|---|---|---|
| (I1) | operand | 浮動小数点型または量子化型のテンソル | (C1)、(C2) | 
出力
| 名前 | タイプ | 制約 | 
|---|---|---|
| result | 量子化テンソル | (C1)、(C2) | 
制約
- (C1)shape(operand) = shape(result)。
- (C2)expressed_type(result) = is_float(operand) ? element_type(operand) : expressed_type(operand)。
例
// %operand: [4.0, 15.0]
%result = "stablehlo.uniform_quantize"(%operand) : (tensor<2xf32>) -> tensor<2x!quant.uniform<i8:f32:0, {0.1:-30,0.5:-20}>>
// %result: [10, 10]
// %operand: [10, 10]
%result = "stablehlo.uniform_quantize"(%operand) : (tensor<2x!quant.uniform<i8:f32:0, {0.1:-30,0.5:-20}>>) -> tensor<2x!quant.uniform<i8:f32:0, {0.1:-20,0.2:-30}>>
// %result: [20, 45]
しばらく
セマンティクス
cond 関数が true を出力している間に、body 関数を 0 回以上実行すると、出力が生成されます。より正式には、セマンティクスは Python 構文を使用して次のように表現できます。
internal_state = operand
while cond(*internal_state):
  internal_state = body(*internal_state)
results = internal_state
無限ループの動作は未定です(#383)。
入力
| ラベル | 名前 | タイプ | 制約 | 
|---|---|---|---|
| (I1) | operand | テンソル、量子化テンソル、トークンの可変数 | (C1~C3) | 
| (I2) | cond | 関数 | (C1) | 
| (I3) | body | 関数 | (C2) | 
出力
| 名前 | タイプ | 制約 | 
|---|---|---|
| results | テンソル、量子化テンソル、トークンの可変数 | (C3) | 
制約
- (C1)condの型は(T0, ..., TN-1) -> tensor<i1>です。ここで、Ti = type(operand[i])。
- (C2)bodyの型は(T0, ..., TN-1) -> (T0, ..., TN-1)です。ここで、Ti = type(operand[i])。
- (C3)type(results...) = type(operand...)。
例
// %init_i: 1
// %init_sum: 0
// %one: 1
// %ten: 10
%results0, %results1 = "stablehlo.while"(%init_i, %init_sum) ({
  ^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
    %cond = "stablehlo.compare"(%arg0, %ten) {
      comparison_direction = #stablehlo<comparison_direction LT>
    } : (tensor<i64>, tensor<i64>) -> tensor<i1>
    stablehlo.return %cond : tensor<i1>
  }, {
  ^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
    %new_sum = stablehlo.add %arg1, %one : tensor<i64>
    %new_i = stablehlo.add %arg0, %one : tensor<i64>
    stablehlo.return %new_i, %new_sum : tensor<i64>, tensor<i64>
}) : (tensor<i64>, tensor<i64>) -> (tensor<i64>, tensor<i64>)
// %results0: 10
// %results1: 10
xor
セマンティクス
2 つのテンソル lhs と rhs の要素ごとの XOR を実行し、result テンソルを生成します。要素のタイプに応じて、次の操作を行います。
- ブール値の場合: 論理 XOR。
- 整数の場合: ビット単位の XOR。
入力
| ラベル | 名前 | タイプ | 制約 | 
|---|---|---|---|
| (I1) | lhs | ブール型または整数型のテンソル | (C1) | 
| (I2) | rhs | ブール値型または整数型のテンソル | (C1) | 
出力
| 名前 | タイプ | 制約 | 
|---|---|---|
| result | ブール型または整数型のテンソル | (C1) | 
制約
- (C1)type(lhs) = type(rhs) = type(result)。
例
// Bitwise operation with with integer tensors
// %lhs: [[1, 2], [3, 4]]
// %rhs: [[5, 6], [7, 8]]
%result = "stablehlo.xor"(%lhs, %rhs) : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
// %result: [[4, 4], [4, 12]]
// Logical operation with with boolean tensors
// %lhs: [[false, false], [true, true]]
// %rhs: [[false, true], [false, true]]
%result = "stablehlo.xor"(%lhs, %rhs) : (tensor<2x2xi1>, tensor<2x2xi1>) -> tensor<2x2xi1>
// %result: [[false, true], [true, false]]
方言の相互運用
現時点では、StableHLO プログラムには、StableHLO で定義されていないオペレーションが含まれている場合があります。
モジュール、関数、呼び出し、リターン
StableHLO は、ModuleOp、FuncOp、CallOp、ReturnOp にアップストリーム MLIR オペレーションを使用します。これは、既存の MLIR 機構との相互運用性を高めるために行われました。これは、FuncOp と ModuleOp をターゲットとする多くの有用なパスが記述されており、多くのコンパイル パイプラインは、これらの op が存在することを想定しているためです。これらの op には、完全な互換性の保証が適用されます。これらのオペレーションが互換性のない方法で変更された場合(削除など)、互換性を維持するために StableHLO と同等のオペレーションが追加されます。
CHLO
CHLO オペセットには、StableHLO に分解される上位レベルのオペレーションが含まれています。現在のところ、CHLO の互換性は保証されていません。互換性を保証するには、シリアル化の前に chlo-legalize-to-stablehlo パスを使用する必要があります。
シェイプ オペレーション
動的な StableHLO プログラムでコア MLIR 言語の特定のオペレーションを使用して、シェイプ計算を実行することは、コミュニティで一般的なユースケースです。一般的には、shape_of や num_elements などの shape 言語演算、dim や from_elements などの tensor 言語演算、組み込みの index 型演算があります。
Dynamism RFC > O2 では、これらはサポート範囲外とされています。ただし、相互運用性のために index 型の一部がサポートされています。これらの演算や型には互換性は保証されません。shape-legalize-to-stablehlo パスを使用して、これらのオペレーションを完全にサポートされている StableHLO オペレーションに変換できます。
非推奨のオペレーション
MHLO から継承された StableHLO オペレーションがいくつかあります。これらは非推奨で、StableHLO から移行される予定です。これらの削除の詳細については、StableHLO v1.0 のクリーンアップ #2283 をご覧ください。これらのサポート終了に関するトラッカーの問題は #2340 です。
これらのオペレーションは、次のカテゴリに分類されます。
- StableHLO オペレーションの「HLO に含まれない」カテゴリ - 最初は StableHLO オペセットの一部でしたが、後で適切ではないと判断されました。broadcast、create_token、cross-replica-sum、dot、einsum、torch_index_select、unary_einsum(#3)。
- 使用されていないオペレーション - これらのオペレーションは、ある時点では有用だったかもしれませんが、オペレーションが十分に開発されていないか、これらのオペレーションを使用するパイプラインがリファクタリングされ、オペレーションが不要になっています。これには、map、tuple(#598)、get_tuple_element、rng、complexの比較 #560、畳み込みwindow_reversal(#1181)が含まれます。
これらのオペレーションの一部は、既存のオペレーション(broadcast、create_token、cross-replica-sum、dot、unary_einsum)を使用して表現できるため、簡単に削除できます。これらのオペレーションは、既存の互換性期間(6 か月)が経過すると削除されます。その他のオペレーションは、削除の検討中です(einsum、get_tuple_element、map、rng torch_index_select、tuple、complex 比較、window_reversal)。コミュニティからのフィードバックに基づいて、これらのオペレーションは削除されるか、完全なサポートで仕様に追加されます。これらのオペレーションの将来が判明するまでは、互換性が保証されるのは 6 か月間のみです。
実行
順次実行
StableHLO プログラムを実行するには、main 関数に入力値を指定し、出力値を計算します。関数の出力値は、対応する return オペレーションにルートを持つオペレーションのグラフを実行することで計算されます。
実行順序は、データフローに沿っている限り(つまり、オペレーションが使用前に実行される場合)、実装で定義されます。StableHLO では、副作用のある演算はすべて 1 つのトークンを消費し、1 つのトークンを生成します(複数のトークンは after_all を介して 1 つのトークンに多重化できます)。そのため、副作用の実行順序もデータフローに沿っています。たとえば、次のプログラムでは、%0 → %1 → %2 → return と %1 → %0 → %2 → return の 2 つの実行順序が可能です。
func.func @main() -> tensor<f64> {
  %0 = stablehlo.constant dense<1.0> : tensor<f64>
  %1 = stablehlo.constant dense<2.0> : tensor<f64>
  %2 = stablehlo.add %0, %1 : tensor<f64>
  return %2 : tensor<f64>
}
より正式には、StableHLO プロセスは、1)StableHLO プログラム、2)オペレーションのステータス(まだ実行されていない、すでに実行されている)、3)プロセスが処理中の中間値を組み合わせたものです。このプロセスは、main 関数への入力値から始まり、オペレーションのステータスと中間値を更新するオペレーションのグラフを経て、出力値で終了します。形式化は未定です(#484)。
並列実行
StableHLO プログラムは並行して実行でき、どちらも ui32 型の num_partitions で num_replicas の 2D プロセス グリッドに編成できます。
StableHLO プロセス グリッドでは、StableHLO プロセスの num_replicas * num_partitions が同時に実行されています。各プロセスには一意の process_id = (replica_id, partition_id) があります。replica_ids = range(num_replicas) の replica_id と partition_ids = range(num_partitions) の partition_id はどちらも ui32 型です。
プロセスグリッドのサイズはプログラムごとに静的に把握され(将来的には StableHLO プログラム #650 で明示的な一部にする予定です)、プロセスグリッド内の位置はプロセスごとに静的に把握されます。各プロセスは、replica_id オペレーションと partition_id オペレーションを介して、プロセス グリッド内の位置にアクセスできます。
プロセス グリッド内で、プログラムはすべて同じ(「単一プログラム、複数のデータ」スタイル)にすることも、すべて異なる(「複数プログラム、複数のデータ」スタイル)にすることも、その中間にも設定できます。今後、GSPMD(#619)など、並列 StableHLO プログラムを定義する他のイディオムのサポートを導入する予定です。
プロセス グリッド内のプロセスは、ほとんどが互いに独立しています。オペレーションのステータス、入力値、中間値、出力値はそれぞれ別々で、ほとんどのオペレーションはプロセス間で別々に実行されます(後述する少数の集約オペレーションを除く)。
ほとんどのオペレーションの実行では同じプロセスの値のみが使用されるため、通常、これらの値を名前で参照してもあいまいさはありません。ただし、集約オペレーションのセマンティクスを記述する場合、これは不十分です。そのため、特定のプロセス内の値 name を参照する記号 name@process_id が使用されます。(この観点から、修飾されていない name は name@(replica_id(), partition_id()) の省略形と見なすことができます)。
プロセス間の実行順序は、以下で説明するように、ポイントツーポイント通信と集約オペレーションによって導入される同期を除き、実装で定義されます。
ポイントツーポイント通信
StableHLO プロセスは、StableHLO チャネルを介して相互に通信できます。チャネルは、si64 型の正の ID で表されます。さまざまなオペレーションを使用して、値をチャネルに送信したり、チャネルから受信したりできます。
これらのチャンネル ID の取得元、プロセス プログラムが認識する方法、プロセスによって行われる同期の種類など、さらなる形式化は未定(#484)です。
ストリーミング通信
すべての StableHLO プロセスは、次の 2 つのストリーミング インターフェースにアクセスできます。
- 読み取り可能なInfeed。
- 書き込み可能なアウトフィードの。
プロセス間の通信に使用され、両端にプロセスがあるチャネルとは異なり、インフィードとアウトフィードでは、もう一方の端は実装で定義されます。
ストリーミング通信が実行順序に与える影響や、それによって行われる同期の種類など、さらなる形式化は未定です(#484)。
集団オペレーション
StableHLO には、all_gather、all_reduce、all_to_all、collective_broadcast、collective_permute、reduce_scatter の 6 つのグループ演算があります。これらのオペレーションはすべて、StableHLO プロセス グリッド内のプロセスを StableHLO プロセス グループに分割し、他のプロセス グループとは独立して、各プロセス グループ内で共同計算を実行します。
各プロセス グループ内で、集約オペレーションによって同期バリアが発生する可能性があります。この同期が正確にいつ行われるか、プロセスがこの障壁に到達する正確な方法、到達しなかった場合にどうなるかなど、さらなる形式化は未定です(#484)。
プロセス グループにパーティション間の通信が含まれる場合(パーティション ID が異なるプロセスがプロセス グループ内にある場合)、集約オペレーションの実行にはチャネルが必要であり、集約オペレーションは si64 型の正の channel_id を提供する必要があります。レプリカ間の通信にはチャネルは必要ありません。
集約オペレーションによって実行される計算は個々のオペレーションに固有であり、上記の個々のオペレーションのセクションで説明しています。ただし、プロセス グリッドをプロセス グループに分割する戦略はこれらのオペレーション間で共有され、このセクションで説明します。より正式には、StableHLO は次の 4 つの戦略をサポートしています。
cross_replica
各プロセス グループ内では、レプリカ間の通信のみが発生します。この戦略では、replica_groups(レプリカ ID のリストのリストのリスト)を受け取り、replica_groups と partition_ids の直積を計算します。replica_groups には一意の要素があり、すべての replica_ids をカバーする必要があります。より正式には、Python 構文を使用します。
def cross_replica(replica_groups: List[List[ReplicaId]]) -> List[List[ProcessId]]:
  for replica_group in replica_groups:
    for partition_id in partition_ids:
      process_group = []
      for replica_id in replica_group:
        process_group.append((replica_id, partition_id))
      yield process_group
たとえば、replica_groups = [[0, 1], [2, 3]] と num_partitions = 2 の場合、cross_replica は [[(0, 0), (1, 0)], [(0, 1), (1, 1)], [(2, 0), (3, 0)], [(2, 1), (3, 1)]] を生成します。
cross_partition
各プロセス グループ内で行われる通信は、パーティション間の通信のみです。この戦略は、partition_groups(パーティション ID のリストのリスト)を受け取り、replica_ids による partition_groups のデカルト積を計算します。partition_groups には一意の要素があり、すべての partition_ids をカバーする必要があります。より正式には、Python 構文を使用して次のようにします。
def cross_partition(partition_groups: List[List[PartitionId]]) -> List[List[ProcessId]]:
  for partition_group in partition_groups:
    for replica_id in replica_ids:
      process_group = []
      for partition_id in partition_group:
        process_group.append((replica_id, partition_id))
      yield process_group
たとえば、partition_groups = [[0, 1]] と num_replicas = 4 の場合、cross_partition は [[(0, 0), (0, 1)], [(1, 0), (1, 1)], [(2, 0), (2, 1)], [(3, 0), (3, 1)]] を生成します。
cross_replica_and_partition
各プロセス グループ内で、レプリカ間通信とパーティション間通信の両方が発生する可能性があります。この戦略は、レプリカ ID のリストのリストである replica_groups を受け取り、partition_ids によって各 replica_group のデカルト積を計算します。replica_groups は一意の要素を持ち、すべての replica_ids をカバーする必要があります。より正式には、Python 構文を使用して次のようにします。
def cross_replica_and_partition(replica_groups: List[List[ReplicaId]]) -> List[List[ProcessId]]:
  for replica_group in replica_groups:
    process_group = []
    for partition_id in partition_ids:
      for replica_id in replica_group:
        process_group.append((replica_id, partition_id))
    yield process_group
たとえば、replica_groups = [[0, 1], [2, 3]] と num_partitions = 2 の場合、cross_replica_and_partition は [[(0, 0), (1, 0), (0, 1), (1, 1)], [(2, 0), (3, 0), (2, 1), (3, 1)]] を生成します。
flattened_ids
この戦略では、flattened_id_groups(replica_id * num_partitions + partition_id 形式の「フラット化」されたプロセス ID のリスト)を受け取り、プロセス ID に変換します。flattened_id_groups は一意の要素を持ち、すべての process_ids をカバーする必要があります。より正式には、Python 構文を使用して次のようにします。
def flattened_ids(flattened_id_groups: List[List[ui32]]) -> List[List[ProcessId]]:
  for flattened_id_group in flattened_id_groups:
    process_group = []
    for flattened_id in flattened_id_group:
      replica_id = flattened_id // num_partitions
      partition_id = flattened_id % num_partitions
      process_group.append((replica_id, partition_id))
    yield process_group
たとえば、flattened_id_groups = [[0, 1, 2, 3], [4, 5, 6, 7]]、num_replicas = 4、num_partitions = 2 の場合、flattened_ids は [[(0, 0), (0, 1), (1, 0), (1, 1)], [(2, 0), (2, 1), (3, 0), (3, 1)]] を生成します。
精度
現時点では、StableHLO は数値の精度を保証していませんが、今後変更される可能性があります(#1156)。
量子化オペレーションの実行セマンティクス
量子化された StableHLO 演算の解釈は、ハードウェアの要件と機能によって異なる場合があります。たとえば、一部のハードウェアでは、「逆量子化し、浮動小数点演算を実行して、最終的に量子化」戦略を使用して量子化演算を解釈します。他の関数では、整数演算で計算全体を実行することもあります。したがって、量子化された StableHLO オペレーションの解釈は、特定の実装によってのみ決まります。ハイブリッド量子化(#1575)の解釈は、仕様(1792 を介して)で規定されているそのセマンティクスに基づく必要があります。
エラー
StableHLO プログラムは、個々のオペレーションに対する広範な制約セットによって検証され、実行前に多くのエラークラスが除外されます。ただし、整数オーバーフローや境界外アクセスなどのエラー条件は引き続き発生する可能性があります。明示的に示されていない限り、これらのエラーはすべて実装定義の動作になりますが、今後変更される可能性があります(#1157)。
浮動小数点例外
このルールの例外として、StableHLO プログラムの浮動小数点例外には明確に定義された動作があります。IEEE-754 標準で定義されている例外(無効なオペレーション、ゼロ除算、オーバーフロー、アンダーフロー、不正確な例外)が発生するオペレーションは、(標準で定義されているように)デフォルトの結果を生成し、対応するステータス フラグを発生させずに実行を続行します。これは、標準の raiseNoFlag 例外処理に似ています。標準以外の演算(複素演算や特定の超越関数など)の例外は、実装で定義されます。
形状の不一致
StableHLO は、動的に形状を変更できるテンソルをサポートしています。ただし、実行時にシェイプが一致する必要があります。一致しない場合、動作は未定義になります。StableHLO では、実行時にテンソルが特定の形状であることをアサートできる op が明示的に提供されていません。正しいコードを生成するのは、プロデューサーの責任です。
具体的な例として、以下のプログラムは有効です。ただし、実行時には %arg0 と %arg1 の正確な形状が同じである必要があります。一致していない場合、プログラムの動作は未定義になります。
func.func @foo(%arg0: tensor<?xi32>, %arg1: tensor<?xi32>) -> tensor<?xi32> {
    %0 = stablehlo.add %arg0, %arg1 : tensor<?xi32>
    return %0 : tensor<?xi32>
}
Notation
このドキュメントでは、構文の記述に、EBNF 構文の ISO フレーバー(ISO/IEC 14977:1996、Wikipedia)を変更して使用しています。変更点は 2 つあります。1)ルールは = ではなく ::= を使用して定義されます。
2)連結は , ではなく並置を使用して表されます。
セマンティクスの記述(「型」、「定数」、「オペレーション」セクション内)では、Python 構文に基づく式を使用しています。この式は、以下で説明するように、配列オペレーションを簡潔に表現できるように拡張されています。これは小さなコード スニペットには適していますが、大きなコード スニペットが必要なまれなケースでは、常に明示的に導入される標準の Python 構文を使用します。
数式
dot_general 仕様の例に基づいて、数式の仕組みを見てみましょう。このオペレーションの制約の 1 つは次のとおりです。dim(lhs, lhs_batching_dimensions...) = dim(rhs, rhs_batching_dimensions...)。
この式で使用される名前は、2 つのソースから取得されます。1 つはグローバル関数(dim など)、もう 1 つは対応するプログラム要素のメンバー定義(dot_general の [入力] セクションで定義された lhs、lhs_batching_dimensions、rhs、rhs_batching_dimensions の入力)です。
前述のとおり、この式の構文は Python ベースで、簡潔さを重視した拡張機能がいくつかあります。式を理解するために 通常の Python 構文に変換しましょう
A)これらの数式では、= を使用して等号を表しています。Python 構文を取得するための最初のステップは、= を == に置き換えることです。dim(lhs, lhs_batching_dimensions...) == dim(rhs, rhs_batching_dimensions...) のようにします。
B)また、これらの式では、スカラー式をテンソル式に変換する楕円記号(...)もサポートされています。簡単に言うと、f(xs...) は「テンソル xs 内のスカラー x ごとにスカラー f(x) を計算し、これらのスカラー結果をすべてテンソル結果として返す」ことを意味します。通常の Python 構文の場合、数式の例は [dim(lhs, dim1) for dim1 in lhs_batching_dimensions] ==
[dim(rhs, dim2) for dim2 in rhs_batching_dimensions] になります。
楕円記法を使用すると、個々のスカラーレベルで作業する必要がなくなります。ただし、場合によっては、gather 仕様の start_indices[bi0, ..., :, ..., biN] 数式のように、下位レベルの準非形式構文が使用されることがあります。簡潔にするため、このような構文を標準の Python に変換するための正確な形式主義は提供していません。ケースごとに直感的に理解できるようにすることを目的としています。特定の式がわかりにくい場合はお知らせください。改善を検討いたします。
また、式では楕円を使用して、テンソル、テンソルのリスト(可変長のテンソルから発生する可能性があるものなど)など、あらゆる種類のリストを展開します。これは、正確な形式主義を提供しない(リストは StableHLO 型システムの一部ではないなど)もう 1 つの領域であり、直感的な理解に依存しています。
C)最後に、暗黙的なブロードキャストという注目すべき記法について説明します。StableHLO opset は暗黙的なブロードキャストをサポートしていませんが、簡潔さを提供する目的で数式はサポートしています。簡単に言うと、テンソルが予想されるコンテキストでスカラーを使用すると、スカラーは期待される形状にブロードキャストされます。
dot_general の例を続けると、別の制約 0 <= lhs_batching_dimensions < rank(lhs) があります。dot_general 仕様で定義されているように、lhs_batching_dimensions はテンソルですが、0 と rank(lhs) はどちらもスカラーです。暗黙的なブロードキャストを適用すると、式は [0, ..., 0] <= lhs_batching_dimensions < [rank(lhs), ..., rank(lhs)] になります。
特定の dot_general 演算に適用されると、この数式はブール値のテンソルに評価されます。数式を制約として使用する場合、数式が true と評価された場合、または true 要素のみを含むテンソルと評価された場合に制約が適用されます。
名前
数式では、1)グローバル関数、2)メンバー定義が、レキシカル スコープに含まれます。
3)ローカル定義グローバル関数のリストを以下に示します。要素定義のリストは、表記が適用されるプログラム要素によって異なります。
- オペレーションの場合、メンバー定義には「入力」セクションと「出力」セクションで説明した名前が含まれます。
- それ以外の場合、メンバー定義には、対応する EBNF 非終端記号にちなんで命名された、プログラム要素の構造部分が含まれます。ほとんどの場合、これらの構造部分の名前は、非終端記号の名前をスネークケースに変換することで取得されます(例: IntegerLiteral=>integer_literal)。ただし、このプロセスで名前が省略されることもあります(例:QuantizationStorageType=>storage_type)。その場合は、オペレーション仕様の [入力] / [出力] セクションと同様に、名前が明示的に導入されます。
- また、メンバー定義には常に selfが含まれ、対応するプログラム要素を参照します。
値
数式の評価では、次の種類の値が使用されます。1)Value(実際の値、例: dense<[[1, 2], [3, 4]]> : tensor<2x2xi32>。型は常に判明しています)、2)Placeholder(将来の値、例: lhs、rhs、result。実際の値はまだ判明しておらず、型のみが判明しています)、3)Type([型] セクションで定義されている型)、4)Function([関数] セクションで定義されているグローバル関数)。
コンテキストによっては、名前が異なる値を参照している場合があります。具体的には、オペレーションの「セマンティクス」セクション(および他のプログラム要素の同等のセクション)でランタイム ロジックが定義されているため、すべての入力を Value として使用できます。これに対して、op(および同等のもの)の「制約」セクションでは「コンパイル時」ロジック、つまり通常はランタイム前に実行されるロジックを定義しています。そのため、定数入力のみが Value として使用でき、その他の入力は Placeholder としてのみ利用できます。
| 名前 | 「Semantics」 | [制約] で | 
|---|---|---|
| グローバル機能 | Function | Function | 
| 定数入力 | Value | Value | 
| 定数以外の入力 | Value | Placeholder | 
| 出力 | Value | Placeholder | 
| ローカル定義 | 定義によって異なる | 定義によって異なる | 
transpose オペレーションの例を見てみましょう。
%result = "stablehlo.transpose"(%operand) {
  permutation = dense<[2, 1, 0]> : tensor<3xi64>
} : (tensor<2x3x2xi32>) -> tensor<2x3x2xi32>
このオペレーションでは、permutation は定数であるため、セマンティクスと制約の両方で Value として使用できます。一方、operand と result は、セマンティクスでは Value として使用できますが、制約では Placeholder としてのみ使用できます。
関数
型の構築
型の作成に使用できる関数はありません。通常は、より簡潔なタイプ構文を直接使用します。たとえば、function_type(
[tensor_type([], E), tensor_type([], E)], [tensor_type([], E)]) ではなく (tensor<E>, tensor<E>) -> (tensor<E>) です。
型の関数
- element_typeはテンソル型と量子化テンソル型で定義され、それぞれ対応する- TensorTypeまたは- QuantizedTensorTypeの- TensorElementTypeまたは- QuantizedTensorElementType部分を返します。
def element_type(x: Value | Placeholder | Type):
 if type(x) == TensorType:
    return tensor_element_type(x)
  if type(x) == QuantizedTensorType:
    return quantized_tensor_element_type(x)
  if type(x) is not Type:
    return element_type(type(x))
- is_per_axis_quantized(x: Value | Placeholder | Type) -> Valueは- is_quantized(x) and quantization_dimension(x) is not Noneのショートカットです。
- is_per_tensor_quantized(x: Value | Placeholder | Type) -> Valueは- is_quantized(x) and quantization_dimension(x) is Noneのショートカットです。
- is_promotable(x: Type, y: Type) -> boolは、- x型を- y型に昇格できるかどうかを確認します。- xと- yが- QuantizedTensorElementTypeの場合、プロモーションは- storage_typeにのみ適用されます。この特定のバージョンのプロモーションは、現在、削減計算のコンテキストで使用されています(詳細については、RFC をご覧ください)。
def is_promotable(x: Type, y: Type) -> Value:
  is_same_type = (is_bool(x) and is_bool(y)) or
    (is_integer(x) and is_integer(y)) or (is_float(x) and is_float(y)) or
    (is_complex(x) and is_complex(y)) or
    (is_quantized(x) and is_quantized(y) and expressed_type(x) = expressed_type(y))
  if is_same_type == False:
    return False
  if is_integer(x) or is_float(x):
    return bitwidth(x) <= bitwidth(y)
  if is_complex(x):
    return bitwidth(element_type(x)) <= bitwidth(element_type(y))
  if is_quantized(x):
    return bitwidth(storage_type(x)) <= bitwidth(storage_type(y))
  return false
- is_quantized(x: Value | Placeholder | Type) -> Valueは- is_quantized_tensor_element_type(x)のショートカットです。
- is_type_name(x: Value | Placeholder | Type) -> Value。すべてのタイプで使用できます。たとえば、- xが- FloatTypeの場合、- is_float(x)は- trueを返します。- xが値またはプレースホルダの場合、この関数は- is_type_name(type(x))のショートカットになります。
- max_value(x: Type) -> Valueは- TensorElementTypeの最大値を返します。- xが- TensorElementTypeでない場合、- Noneを返します。
- min_value(x: Type) -> Valueは、- TensorElementTypeの可能な最小値を返します。- xが- TensorElementTypeでない場合、- Noneを返します。
- member_name(x: Value | Placeholder | Type) -> Any。すべてのタイプのすべてのメンバー定義- member_nameで使用できます。たとえば、- tensor_element_type(x)は対応する- TensorTypeの- TensorElementType部分を返します。- xが値またはプレースホルダの場合、この関数は- member_name(type(x))のショートカットです。- xが適切なメンバーを持つ型、またはそのような型の値またはプレースホルダでない場合、- Noneを返します。
- is_empty_algorithm(*args: Type)は、すべてのドット アルゴリズム フィールドが- Noneに設定されているかどうかを確認します。これは、ドット アルゴリズムの実装でデフォルトの動作が定義されているため、デフォルト値を指定すると正しくありません。
値の構成
- operation_name(*xs: Value | Type) -> Value。すべてのオペレーションで使用できます。たとえば、- add(lhs, rhs)は 2 つのテンソル値- lhsと- rhsを受け取り、これらの入力で- add演算を評価した結果を返します。- broadcast_in_dimなどの一部のオペレーションでは、出力のタイプが「負荷を負う」タイプです。つまり、オペレーションの評価に必要です。この場合、関数はこれらの型を引数として受け取ります。
値の関数
- Python のすべての演算子と関数を使用できます。たとえば、Python のサブスクリプションとスライスの両方の表記を使用して、テンソル、量子化テンソル、タプルにインデックスを付けることができます。 
- to_destination_type(x: Value, destination_type: Type) -> Valueはテンソルで定義され、次のように- type(x)と- destination_typeに基づいて- xの変換値を返します。
def to_destination_type(x: Value, destination_type: Type) -> Value:
  if type(x) == destination_type:
    return x
  if is_quantized(destination_type):
    if is_quantized(type(x)):
      return quantize(x, destination_type)
    assert is_float(type(x))
    return quantize(x, destination_type)
  if is_quantized(type(x)):
    assert destination_type = expressed_type(type(x))
    return dequantize(type(x))
  return convert(x, destination_type)
convert、uniform_quantize、uniform_dequantize オペレーションのマージについては、早期の議論があります(#1576)。
マージ後、上記の関数は不要になり、代わりに convert のオペレーション名を使用できます。
- is_nan(x: Value) -> Valueはテンソルで定義され、- xのすべての要素が- NaNの場合は- trueを返します。それ以外の場合は- falseを返します。- xがテンソルでない場合、- Noneを返します。
- is_sorted(x: Value) -> Valueはテンソルで定義され、- xの要素がインデックスの辞書順の昇順で並べ替えられている場合は- trueを返し、そうでない場合は- falseを返します。- xがテンソルでない場合は、- Noneを返します。
- is_unique(x: Value) -> Valueはテンソルで定義され、- xに重複する要素がない場合- trueを返します。それ以外の場合は- falseを返します。- xがテンソルでない場合は、- Noneを返します。
- member_name(x: Value) -> Anyは、すべての値のすべてのメンバー定義- member_nameに対して定義されます。たとえば、- real_part(x)は対応する- ComplexConstantの- RealPart部分を返します。- xが適切なメンバーを含む値でない場合、- Noneを返します。
- same(x: Value) -> Valueはテンソルに対して定義され、- xの要素がすべて同じであれば- trueを返し、それ以外の場合は- falseを返します。テンソルに要素がない場合、それは「すべて等しい」と見なされます。つまり、関数は- trueを返します。- xがテンソルでない場合は、- Noneを返します。
- split(x: Value, num_results: Value, axis: Value) -> Valueはテンソルで定義され、軸- axisに沿って- xの- num_resultsスライスを返します。- xがテンソルまたは- dim(x, axis) % num_results != 0でない場合は、- Noneを返します。
- is_defined_in_parent_scope(x: Value) -> Valueは文字列で定義され、- xが関連するオペレーションの親関数と同じスコープで定義された関数の名前である場合、- trueを返します。
- is_namespaced_op_name(x: Value) -> Valueは文字列で定義され、- xが有効なオペレーション名(- [a-zA-Z][a-zA-Z0-9_]*([.][a-zA-Z0-9_$]+)+という正規表現に従う)である場合は- trueを返します。
シェイプの計算
- axes(x: Value | Placeholder | Type) -> Valueは- range(rank(x))のショートカットです。
- dim(x: Value | Placeholder | Type, axis: Value) -> Valueは- shape(x)[axis]のショートカットです。
- dims(x: Value | Placeholder | Type, axes: List) -> Listは- list(map(lambda axis: dim(x, axis), axes))のショートカットです。
- index_space(x: Value | Placeholder | Type) -> Valueはテンソルで定義され、対応する- TensorTypeの- size(x)インデックスを辞書順で昇順(- [0, ..., 0]、- [0, ..., 1]、...、- shape(x) - 1)で返します。- xがテンソル型、量子化テンソル型、値、またはこれらの型のプレースホルダでない場合、- Noneを返します。
- rank(x: Value | Placeholder | Type) -> Valueは- size(shape(x))のショートカットです。
- shape(x: Value | Placeholder | Type) -> Valueは、「型の関数」セクションで- member_nameを介して定義されます。
- size(x: Value | Placeholder | Type) -> Valueは- reduce(lambda x, y: x * y, shape(x))のショートカットです。
量子化計算
- def baseline_element_type(x: Value | Placeholder | Type) -> Typeは- element_type(baseline_type(x))のショートカットです。
- baseline_typeはテンソル型と量子化テンソル型で定義され、それらを「ベースライン」に変換します。ベースラインとは、同じ形状で、要素型の量子化パラメータがデフォルト値にリセットされた型です。これは、テンソルと量子化テンソルの両方のタイプを均一に比較するための便利なトリックとして使用されます。これは非常に頻繁に必要になります。量子化タイプの場合、量子化パラメータを無視してタイプを比較できます。つまり、- shape、- storage_type、- expressed_type、- storage_min、- storage_max、- quantization_dimension(軸ごとの量子化タイプの場合)はすべて一致する必要がありますが、- scalesと- zero pointsは異なる場合があります。
def baseline_type(x: Value | Placeholder | Type) -> Type:
  if type(x) == TensorType:
    return x
  if type(x) == QuantizedTensorType:
    element_type = quantized_tensor_element_type(x)
    baseline_element_type = QuantizedTensorElementType(
      storage_type = storage_type(element_type),
      storage_min = storage_min(element_type),
      storage_max = storage_max(element_type),
      expressed_type = expressed_type(element_type),
      quantization_dimension = quantization_dimension(element_type),
      scales = [constant(1.0, expressed_type(element_type))] * dim(x, quantization_dimension(element_type)),
      zero_points = [constant(0, storage_type(element_type))] * dim(x, quantization_dimension(element_type)))
    return QuantizedTensorType(shape(x), baseline_element_type)
  if type(x) is not Type:
    return baseline_element_type(type(x))
- dequantizeは量子化テンソル型に対して定義され、浮動小数点テンソル型に変換します。これは、量子化要素型に関連付けられたゼロ点とスケールを使用して、ストレージ型の整数値を表す量子化要素を、表現型の対応する浮動小数点値に変換することで行われます。
def compute_zero_points(quantized_type, result_type):
  if is_per_tensor_quantized(quantized_type):
    return broadcast_in_dim(constant(zero_point(quantized_type), storage_type(quantized_type)), [], result_type)
  if is_per_axis_quantized(quantized_type):
    for i in index_space(result_type):
      d = quantization_dimension(quantized_type)
      zero_points[i] = zero_points(quantized_type)[i[d]]
    return zero_points
def compute_scales(quantized_type, result_type):
  if is_per_tensor_quantized(quantized_type):
    return broadcast_in_dim(constant(scale(quantized_type), expressed_type(quantized_type)), [],
            type(result_type))
  if is_per_axis_quantized(quantized_type):
    for i in index_space(result_type):
      d = quantization_dimension(quantized_type)
      scales[i] = scales(quantized_type)[i[d]]
    return scales
def dequantize(x: Value) -> Value:
  assert is_quantized(x)
  x_storage = bitcast_convert(x, storage_type(x))
  x_storage_sub = x_storage - compute_zero_points(type(x), type(x_storage))
  x_expressed_sub = convert(x_storage_sub, expressed_type(x))
  return x_expressed_sub * compute_scales(type(x), type(x_expressed_sub))
- quantizeは浮動小数点テンソル型で定義され、量子化されたテンソル型に変換します。これは、量子化された要素型に関連付けられたゼロ点とスケールを使用して、表現型の浮動小数点値をストレージ型の対応する整数値に変換することで行われます。
def quantize(x: Value, result_type: Type) -> Value:
  assert is_float(x) and is_quantized(result_type)
  zero_points = compute_zero_points(result_type, TensorType(shape(x), storage_type(result_type)))
  converted_zero_points = convert(zero_points, expressed_type(result_type))
  converted_min = convert(storage_min(result_type), expressed_type(result_type))
  converted_max = convert(storage_max(result_type), expressed_type(result_type))
  x_scaled = x / compute_scales(result_type, type(x))
  x_scaled_add_zp = x_scaled + converted_zero_points
  x_clamped = clamp(converted_min, x_scaled_add_zp, converted_max)
  x_rounded = round_nearest_even(x_clamped)
  return convert(x_rounded, result_type)
- dequantize_op_quantizeは、量子化テンソルの要素単位の計算を指定するために使用されます。デクォンタイズ(量子化された要素を表現型に変換)してから演算を実行し、量子化(結果をストレージ タイプに戻す)します。現時点では、この関数はテンソルごとの量子化でのみ機能します。軸ごとの量子化は現在開発中です(#1574)。
def dequantize_op_quantize(op, *inputs_and_output_type):
  inputs = inputs_and_output_type[:-1]
  output_type = inputs_and_output_type[-1]
  float_inputs = map(dequantize, inputs)
  float_result = op(*float_inputs)
  return quantize(float_result, output_type)
def dequantize_batch_norm_grad_or_training_quantize(op, *inputs_and_output_types):
  inputs = inputs_and_output_type[:-3]
  float_inputs = map(dequantize, inputs)
  float_results = op(*float_inputs)
  return map(quantize, float_results, inputs_and_output_type[-3:])
def dequantize_compare(lhs, rhs, comparison_direction):
  float_lhs = dequantize(lhs)
  float_rhs = dequantize(rhs)
  return compare(float_lhs, float_rhs, comparison_direction, FLOAT)
def dequantize_select_quantize(pred, on_true, on_false, output_type):
  float_on_true = dequantize(on_true)
  float_on_false = dequantize(on_false)
  float_result = select(pred, float_on_true, float_on_false)
  return quantize(float_result, output_type)
- hybrid_dequantize_then_opは、浮動小数点型の lhs と量子化された型の rhs を受け入れるハイブリッド演算の重みのみの量子化を指定するために使用されます。量子化された入力を表現型にデクォンタイズし、浮動小数点数で計算を行います。浮動小数点の左辺テンソルの要素型と、量子化された右辺テンソルの表現型は同じである必要があります。
def hybrid_dequantize_then_op(op, lhs, rhs):
  assert(is_float(lhs) and is_quantized(rhs) and element_type(lhs) == expressed_type(rhs))
  return op(lhs, dequantize(rhs))
グリッド計算
- cross_partition(replica_groups: Value) -> Value。上記の「cross_replica」セクションをご覧ください。
- cross_replica(replica_groups: Value) -> Value。上記の「cross_replica」セクションをご覧ください。
- cross_replica_and_partition(replica_groups: Value) -> Value。上記の「cross_replica_and_partition」セクションをご覧ください。
- flattened_ids(replica_groups: Value) -> Value。上記の「flattened_ids」セクションをご覧ください。
ダイナミズム
StableHLO 値には、動的なディメンション サイズ(tensor<?xi64> など)を指定できます。ただし、StableHLO 値には動的ディメンション数(ランクなしの動的性、tensor<*xi64> など)を指定できません。オペランドと結果では、サイズに制約がある場合でも、動的ディメンション サイズを使用できます。制約は可能であれば静的に検証されます。検証できない場合は実行時に延期され、不一致により未定義の動作が発生します。例については以下をご覧ください。
単項要素ごとのオペレーションのシェイプの不一致
次のサンプル プログラムを考えてみましょう。
func.func @foo(%arg0: tensor<?xf64>) {
  %0 = stablehlo.abs %arg0 : (tensor<?xf64>) -> tensor<2xf64>
  return
}
このようなプログラムは珍しいものです。結果の形状はわかっても、入力の形状がわからないことは一般的ではありません。ただし、これは有効な StableHLO プログラムです。このプログラムの abs オペレーションを静的に検証することはできません。オペランドの正確な形状が不明であるためです。ただし、シェイプは確かに互換性があり、これは静的に確認できます。? が実行時に 2 になっても問題はありません。ただし、? が他の整数になる可能性もあります。その場合、動作は未定義になります。
結果でディメンション サイズが動的である場合、未定義の動作は発生しません。実際、サイズは「想定」されていないため、不一致は発生しません。
バイナリ要素ごとのオペレーションのシェイプの不一致
次のおもちゃプログラムについて考えてみましょう。
func.func @foo(%arg0: tensor<?xf64>, %arg1: tensor<?xf64>) {
  %0 = stablehlo.add %arg0, %arg0 : (tensor<?xf64>, tensor<?xf64>) -> tensor<?xf64>
  return
}
バイナリ要素ごとのオペレーションの場合、入力と結果の形状は実行時に一致している必要があります。コンパイル時に静的ディメンションは同じである必要があります。それ以外の場合は、互換性がある必要があります。入力でいずれかのディメンションが動的である場合、動的サイズが他のオペランドの対応するサイズ(静的または動的)と一致しないため、ランタイムで未定義の動作が発生する可能性があります。すべての入力が静的である場合、結果が動的かどうかは重要ではありません。静的に既知のディメンションは静的にチェックされ、動的ディメンションには制約が適用されません。
出力シェイプをオペランドとして取るオペレーションのシェイプの不一致
次のサンプル プログラムを考えてみましょう。
func.func @foo(%arg0: tensor<2xi32>) {
  %0 = stablehlo.dynamic_iota %arg0, dim = 0 : (tensor<2xi32>) -> tensor<3x4xi64>
  return
}
実行時にシェイプ オペランドの値は、結果のシェイプと一致する必要があります。一致しない場合、動作は未定義になります。つまり、実行時に %arg0 の値は dense<[3, 4]> : tensor<2xi32> にする必要があります。シェイプ オペランドが定数の場合は、静的に検証できます。結果の形状が完全に動的であれば、不一致が生じることはありません。