StableHLO の仕様

StableHLO は、機械学習(ML)モデルの高レベル オペレーション(HLO)用のオペレーション セットです。StableHLO は、異なる ML フレームワークと ML コンパイラ間のポータビリティ レイヤとして機能します。StableHLO プログラムを生成する ML フレームワークは、StableHLO プログラムを使用する ML コンパイラと互換性があります。

Google の目標は、さまざまな ML フレームワーク(TensorFlow、JAX、PyTorch など)と ML コンパイラ(XLA、IREE など)間の相互運用性を高めることで、ML 開発を簡素化し、加速させることです。このドキュメントでは、その目的のために StableHLO プログラミング言語の仕様を説明します。

この仕様には、3 つの主要なセクションがあります。まず、プログラムのセクションでは、StableHLO 関数で構成される StableHLO プログラムの構造について説明します。StableHLO 関数自体は StableHLO オペレーションで構成されます。この構造内の [Ops] セクションには、個々のオペレーションのセマンティクスを指定します。[実行] セクションには、プログラム内で一緒に実行されるこれらのオペレーションのすべてのセマンティクスが示されています。最後に、表記セクションでは、仕様全体で使用される表記について説明します。

StableHLO の以前のリリースの仕様を表示するには、目的のタグ付きリリースでリポジトリを開きます。たとえば、StableHLO v0.19.0 仕様。StableHLO のマイナー バージョンの増加ごとに発生した変更を確認するには、VhloDialect.td のバージョンログをご覧ください。

プログラム

Program ::= {Func}

StableHLO プログラムは、任意の数の StableHLO 関数で構成されます。以下は、3 つの入力(%image%weights%bias)と 1 つの出力を持つ関数 @main を含むプログラムの例です。関数本体には 6 つのオペレーションがあります。

func.func @main(
  %image: tensor<28x28xf32>,
  %weights: tensor<784x10xf32>,
  %bias: tensor<1x10xf32>
) -> tensor<1x10xf32> {
  %0 = "stablehlo.reshape"(%image) : (tensor<28x28xf32>) -> tensor<1x784xf32>
  %1 = "stablehlo.dot"(%0, %weights) : (tensor<1x784xf32>, tensor<784x10xf32>) -> tensor<1x10xf32>
  %2 = "stablehlo.add"(%1, %bias) : (tensor<1x10xf32>, tensor<1x10xf32>) -> tensor<1x10xf32>
  %3 = "stablehlo.constant"() {value = dense<0.0> : tensor<1x10xf32>} : () -> tensor<1x10xf32>
  %4 = "stablehlo.maximum"(%2, %3) : (tensor<1x10xf32>, tensor<1x10xf32>) -> tensor<1x10xf32>
  "func.return"(%4): (tensor<1x10xf32>) -> ()
}

関数

Func        ::= 'func' '.' 'func' FuncId FuncInputs FuncOutputs '{' FuncBody '}'
FuncInputs  ::= '(' [FuncInput {',' FuncInput}] `)`
FuncInput   ::= ValueId ':' ValueType
FuncOutputs ::= ['->' FuncOutput, {',' FuncOutput}]
FuncOutput  ::= ValueType
FuncBody    ::= {Op}

StableHLO 関数(名前付き関数とも呼ばれます)には、識別子、入出力、本文があります。今後、HLO との互換性を高めるために、関数に追加のメタデータを導入する予定です(#425#626#740#744)。

識別子

FuncId  ::= '@' letter {letter | digit}
ValueId ::= '%' digit {digit}
          | '%' letter {letter | digit}
letter  ::= 'a' | ... | 'z' | 'A' | ... | 'Z' | '_'
digit   ::= '0' | ... | '9'

StableHLO 識別子は多くのプログラミング言語の識別子に似ていますが、2 つの特殊性があります。1)すべての識別子には、さまざまな種類の識別子を区別するシグルが付いています。2)値識別子は完全に数値にすることができ、StableHLO プログラムの生成を簡素化できます。

Type         ::= ValueType | NonValueType
ValueType    ::= TensorType | QuantizedTensorType | TokenType | TupleType
NonValueType ::= TensorElementType | QuantizedTensorElementType | FunctionType | StringType

StableHLO 型は、StableHLO 値を表す値の型(ファーストクラス型とも呼ばれます)と、他のプログラム要素を記述する値以外の型に分類されます。StableHLO 型は多くのプログラミング言語の型に似ていますが、主な特徴は StableHLO のドメイン固有の性質であり、異常な結果をもたらします(例: スカラー型は値型ではありません)。

TensorType ::= 'tensor' '<' Shape TensorElementType '>'
Shape ::= {DimensionSize 'x'}
DimensionSize ::= digit {digit} | '?'

テンソル型はテンソル、つまり多次元配列を表します。シェイプ要素型があります。シェイプは、0R-1 の番号が付けられた対応するディメンションとも呼ばれます)の昇順で、正または不明のディメンション サイズを表します。ディメンション R の数はランクと呼ばれます。たとえば、tensor<2x3xf32> は形状が 2x3 で要素型が f32 のテンソル型です。0 番目のディメンションと 1 番目のディメンションの 2 つのディメンション(つまり 2 つの軸)があり、サイズは 2 と 3 です。ランクは 2 です。

シェイプは、部分的または完全に未知(動的)にできます。たとえば、tensor<?x2xf64> は部分的に未知、tensor<?x?xf64> は完全に未知です。動的ディメンションのサイズは ? を使用して表されます。シェイプのランク付けを解除することはできません。

今後、テンソル型をディメンション サイズと要素型以外にも拡張し、レイアウト(#629)やスパース性(#1078)などを含める予定です。

QuantizedTensorType ::= 'tensor' '<' Shape QuantizedTensorElementType '>'
QuantizedTensorElementType ::= '!quant.uniform' '<'
                  QuantizationStorageType
                  ['<' QuantizationStorageMin ':' QuantizationStorageMax '>']
                  ':' QuantizationExpressedType
                  [':' QuantizationDimension]
                  ',' QuantizationParameters '>'
QuantizationStorageType ::= IntegerType
QuantizationStorageMin ::= IntegerLiteral
QuantizationStorageMax ::= IntegerLiteral
QuantizationExpressedType ::= FloatType
QuantizationDimension ::= IntegerLiteral
QuantizationParameters ::= QuantizationParameter
                         | '{' QuantizationParameter {',' QuantizationParameter} '}'
QuantizationParameter ::= QuantizationScale [':' QuantizationZeroPoint]
QuantizationScale ::= FloatLiteral
QuantizationZeroPoint ::= IntegerLiteral
名前 タイプ 制約
storage_type 整数型 (C1 ~ C3)、(C8)
storage_min 整数定数 (C1)、(C3)、(C7)
storage_max 整数定数 (C2)、(C3)、(C7)
expressed_type floating-point-type (C4)
quantization_dimension 省略可能な整数定数 (C10-C12)
scales 浮動小数点定数の可変数 (C4 ~ C6)、(C9)、(C10)、(C13)
zero_points 可変長整数定数 (C7 ~ C9)

量子化要素の型は、表現型の浮動小数点値に対応する、storage_minstorage_max(両端を含む)の範囲のストレージ型の整数値を表します。指定された整数値 i について、対応する浮動小数点値 ff = (i - zero_point) * scale として計算できます。ここで、scalezero_point は量子化パラメータと呼ばれます。storage_minstorage_max は文法では省略可能ですが、デフォルト値はそれぞれ min_value(storage_type)max_value(storage_type) です。量子化要素タイプには次の制約があります。

  • (C1)type(storage_min) = storage_type
  • (C2)type(storage_max) = storage_type
  • (C3)min_value(storage_type) <= storage_min < storage_max <= max_value(storage_type)
  • (C4)type(scales...) = expressed_type
  • (C5)0 < scales
  • (C6)is_finite(scales...)
  • (C7)storage_min <= zero_points <= storage_max
  • (C8)type(zero_points...) = storage_type
  • (C9)size(scales) = size(zero_points)
  • (C10)is_empty(quantization_dimension) の場合は size(scales) = 1 です。
  • (C11)0 <= quantization_dimension

現在、QuantizationScale は浮動小数点定数ですが、乗数とシフトで表される整数ベースのスケールに強い関心があります。近い将来、この問題を調査する予定です(#1404)。

QuantizationZeroPoint のセマンティクス(型、値、量子化されたテンソル型にゼロポイントが 1 つだけか複数ある可能性があるかなど)については、現在も議論が続いています。このディスカッションの結果に基づいて、ゼロポイントに関する仕様は今後変更される可能性があります(#1405)。

現在進行中のもう一つの議論では、QuantizationStorageMinQuantizationStorageMax のセマンティクスについて、これらの値と量子化テンソルの値に制約を適用すべきかどうかを判断しています(#1406)。

最後に、未知のディメンション サイズの表現(#1407)と同様に、未知のスケールとゼロポイントの表現を検討する予定です。

量子化テンソル型は、量子化された要素を持つテンソルを表します。これらのテンソルは、通常の要素型ではなく、量子化された要素型を持つ点を除き、通常のテンソルとまったく同じです。

量子化されたテンソルでは、量子化はテンソル単位で行うことができます。つまり、テンソル全体に 1 つの scalezero_point を設定します。または、軸単位で行うこともできます。つまり、特定のディメンション quantization_dimension のスライスごとに 1 組の scaleszero_points を設定します。より正式には、軸ごとの量子化を使用するテンソル t には、quantization_dimensiondim(t, quantization_dimension) スライス(t[:, ..., 0, ..., :], t[:, ..., 1, ..., :] など)があります。i 番目のスライスのすべての要素は、量子化パラメータとして scales[i]zero_points[i] を使用します。量子化テンソル型には次の制約があります。

  • テンソルごとの量子化の場合:
    • 追加の制約はありません。
  • 軸ごとの量子化の場合:
    • (C12)quantization_dimension < rank(self)
    • (C13)dim(self, quantization_dimension) = size(scales)
TokenType ::= 'token'

トークン型は、トークン(一部のオペレーションによって生成および消費される不透明な値)を表します。トークンは、実行セクションで説明されているように、オペレーションに実行順序を適用するために使用されます。

TupleType ::= 'tuple' '<' TupleElementTypes '>'
TupleElementTypes ::= [ValueType {',' ValueType}]

タプル型は、タプル(異種リスト)を表します。タプルは、HLO との互換性を確保するためにのみ存在するレガシー機能です。HLO では、タプルを使用して可変長の入力と出力を表します。StableHLO では、可変長の入力と出力がネイティブにサポートされています。StableHLO でタプルを使用するのは、HLO ABI を包括的に表す場合のみです。たとえば、Ttuple<T>tuple<tuple<T>> は特定の実装によって大きく異なる場合があります。今後、HLO ABI を変更し、StableHLO からタプル型を削除する予定です(#598)。

TensorElementType ::= BooleanType | IntegerType | FloatType | ComplexType
BooleanType ::= 'i1'
IntegerType ::= SignedIntegerType | UnsignedIntegerType
SignedIntegerType ::= 'si2' | 'si4' | 'si8' | 'si16' | 'si32' | 'si64'
UnsignedIntegerType ::= 'ui2' | 'ui4' | 'ui8' | 'ui16' | 'ui32' | 'ui64'
FloatType ::= 'f4E2M1FN' | 'f6E2M3FN' | 'f6E3M2FN' | 'f8E3M4' | 'f8E4M3'
            | 'f8E4M3FN' | 'f8E4M3FNUZ' | 'f8E4M3B11FNUZ' | 'f8E5M2'
            | 'f8E5M2FNUZ' | 'f8E8M0FNU' | 'bf16' | 'f16' | 'f32' | 'f64'
TensorFloat32 ::= 'tf32'
ComplexType ::= 'complex' '<' ComplexElementType '>'
ComplexElementType ::= 'f32' | 'f64'

要素型は、テンソル型の要素を表します。多くのプログラミング言語とは異なり、これらの型は StableHLO ではファーストクラスではありません。つまり、StableHLO プログラムでは、これらの型の値を直接表現できません(その結果、T 型のスカラー値を tensor<T> 型の 0 次元テンソル値で表現するのが一般的です)。

FunctionType ::= '(' InputTypes ')' '->' '(' OutputTypes ')'
InputTypes ::= [ValueType {',' ValueType}]
OutputTypes ::= [ValueType {',' ValueType}]

関数型は、名前付き関数と匿名関数の両方を表します。入力型(-> の左側の型のリスト)と出力の型(-> の右側の型のリスト)があります。多くのプログラミング言語では関数型がファースト クラスですが、StableHLO ではそうではありません。

StringType ::= 'string'

String 型はバイトのシーケンスを表します。多くのプログラミング言語とは異なり、StableHLO では文字列型はファーストクラスではなく、プログラム要素の静的メタデータを指定する場合にのみ使用されます。

運用

StableHLO オペレーション(オペレーションとも呼ばれる)は、ML モデルのクローズド セットの大まかなオペレーションを表します。前述のように、StableHLO 構文は MLIR に大きく影響を受けています。これは必ずしも最も人間工学的な代替手段ではありませんが、ML フレームワークと ML コンパイラ間の相互運用性を高めるという StableHLO の目標には最適です。

Op            ::= [OpOutputs] OpName OpInputs ':' OpSignature
OpName        ::= '"' 'stablehlo' '.' OpMnemonic '"'
OpMnemonic    ::= 'abs' | 'add' | ...

StableHLO オペレーション(オペレーションとも呼ばれる)には、名前、入出力、署名があります。この名前は、stablehlo. 接頭辞と、サポートされているいずれかのオペレーションを一意に識別するニーモニックで構成されます。サポートされているすべてのオペレーションの一覧については、以下をご覧ください。

OpInputs        ::= OpInputValues OpInputFuncs OpInputAttrs
OpInputValues   ::= '(' [OpInputValue {',' OpInputValue}] ')'
OpInputValue    ::= ValueId
OpInputFuncs    ::= ['(' OpInputFunc {',' OpInputFunc} ')']
OpInputAttrs    ::= ['{' OpInputAttr {',' OpInputAttr} '}']
OpOutputs       ::= [OpOutput {',' OpOutput} '=']
OpOutput        ::= ValueId

オペレーションは入力を使用し、出力を生成します。入力は、入力値(実行時に計算される)、入力関数(StableHLO では関数がファーストクラス値ではないため、静的に提供される)、入力属性(これも静的に提供される)に分類されます。op によって消費、生成される入力と出力の種類は、そのニモニックによって異なります。たとえば、add オペレーションは 2 つの入力値を使用し、1 つの出力値を生成します。これに対し、select_and_scatter op は 3 つの入力値、2 つの入力関数、3 つの入力属性を使用します。

OpInputFunc ::= '{' Unused FuncInputs ':' FuncBody '}'
Unused      ::= '^' digit {digit}
              | '^' letter {letter | digit}

入力関数(別名: 匿名関数)は、名前付き関数と非常によく似ていますが、1)識別子がない(「匿名」という名前の由来)ことと、2)出力型を宣言しない(出力型は関数内の return オペレーションから推論される)点が異なります。

入力関数の構文には、MLIR との互換性を確保するために現在使用されていない部分(上記の Unused 生産を参照)が含まれています。MLIR には、ジャンプ演算子で接続された複数の演算子の「ブロック」を持つことができる、より一般的な「リージョン」のコンセプトがあります。これらのブロックには、Unused 本番環境に対応する ID が付いているため、ブロック同士を区別できます。StableHLO にはジャンプ演算がないため、MLIR 構文の対応する部分は使用されません(ただし、まだ存在します)。

OpInputAttr      ::= OpInputAttrName '=' OpInputAttrValue
OpInputAttrName  ::= letter {letter | digit}
OpInputAttrValue ::= Constant

入力属性には、名前と値があります。値は、サポートされている定数の一つです。プログラム要素の静的メタデータを指定する主な方法です。たとえば concatenate 演算では、属性 dimension を使用して、入力値を連結するディメンションを指定します。同様に、slice 演算は start_indiceslimit_indices などの複数の属性を使用して、入力値をスライスするために使用される境界を指定します。

現時点では、実際の StableHLO プログラムに、このドキュメントで説明されていない属性が含まれていることがあります。今後、これらの属性を StableHLO オペセットに吸収するか、StableHLO プログラムに出現しないようにする予定です。それまでは以下の属性のリストをご覧ください。

  • layout#629)。
  • mhlo.frontend_attributes#628)。
  • mhlo.sharding#619)。
  • output_operand_aliases#740)。
  • 位置情報メタデータ(#594)。
OpSignature ::= '(' [ValueType {',' ValueType}] ')' '->' '(' [ValueType {',' ValueType}] ')'

オペレーション シグネチャは、すべての入力値の型(-> の左側の型のリスト)とすべての出力値の型(-> の右側の型のリスト)で構成されます。厳密に言えば、入力型は冗長であり、出力型もほとんどの場合冗長です(ほとんどの StableHLO オペレーションでは、出力型は入力から推測できるため)。ただし、op シグネチャは、MLIR との互換性を確保するために、意図的に StableHLO 構文の一部になっています。

以下に、メモニクスが select_and_scatter の演算子の例を示します。3 つの入力値(%operand%source%init_value)、2 つの入力関数、3 つの入力属性(window_dimensionswindow_stridespadding)を使用します。op のシグネチャには入力値の型のみが含まれます(インラインで提供される入力関数と属性の型は含まれません)。

%result = "stablehlo.select_and_scatter"(%operand, %source, %init_value) ({
  ^bb0(%arg0: tensor<i32>, %arg1: tensor<i32>):
    %0 = "stablehlo.compare"(%arg0, %arg1) {
      comparison_direction = #stablehlo<comparison_direction GE>
    } : (tensor<i32>, tensor<i32>) -> tensor<i1>
    "stablehlo.return"(%0) : (tensor<i1>) -> ()
}, {
  ^bb0(%arg0: tensor<i32>, %arg1: tensor<i32>):
    %0 = "stablehlo.add"(%arg0, %arg1) : (tensor<i32>, tensor<i32>) -> tensor<i32>
    "stablehlo.return"(%0) : (tensor<i32>) -> ()
}) {
  window_dimensions = dense<[3, 1]> : tensor<2xi64>,
  window_strides = dense<[2, 1]> : tensor<2xi64>,
  padding = dense<[[0, 1], [0, 0]]> : tensor<2x2xi64>
} : (tensor<4x2xi32>, tensor<2x2xi32>, tensor<i32>) -> tensor<4x2xi32>

定数

Constant ::= BooleanConstant
           | IntegerConstant
           | FloatConstant
           | ComplexConstant
           | TensorConstant
           | QuantizedTensorConstant
           | StringConstant
           | EnumConstant

StableHLO 定数には、StableHLO 値を表すリテラルと型があります。一般に、型は定数の構文の一部ですが、明確な場合を除きます(たとえば、ブール値定数は明確に型 i1 ですが、整数定数には複数の型が考えられます)。

BooleanConstant ::= BooleanLiteral
BooleanLiteral  ::= 'true' | 'false'

ブール定数は、ブール値 truefalse を表します。ブール値定数は i1 型です。

IntegerConstant   ::= IntegerLiteral ':' IntegerType
IntegerLiteral    ::= ['-' | '+'] DecimalDigits
                    | ['-' | '+'] '0x' HexadecimalDigits
DecimalDigits     ::= decimalDigit {decimalDigit}
HexadecimalDigits ::= hexadecimalDigit {hexadecimalDigit}
decimalDigit      ::= '0' | ... | '9'
hexadecimalDigit  ::= decimalDigit | 'a' | ... | 'f' | 'A' | ... | 'F'

整数定数は、10 進数または 16 進数表記の文字列を使用して整数値を表します。バイナリやオクタルなどの他の基数はサポートされていません。整数定数には次の制約があります。

  • (C1)is_wellformed(integer_literal, integer_type)
FloatConstant  ::= FloatLiteral ':' FloatType
FloatLiteral   ::= SignPart IntegerPart FractionalPart ScientificPart
                 | '0x' [HexadecimalDigits]
SignPart       ::= ['-' | '+']
IntegerPart    ::= DecimalDigits
FractionalPart ::= ['.' [DecimalDigits]]
ScientificPart ::= [('e' | 'E') ['-' | '+'] DecimalDigits]

浮動小数点定数は、10 進数または科学的記数法を使用する文字列で浮動小数点値を表します。また、16 進数表記を使用して、対応する型の浮動小数点形式で基盤となるビットを直接指定することもできます。浮動小数点定数には次の制約があります。

  • (C1)16 進数以外の表記が使用されている場合、is_wellformed(float_literal, float_type)
  • (C2)16 進数表記を使用している場合、size(hexadecimal_digits) = num_bits(float_type) / 4
ComplexConstant ::= ComplexLiteral ':' ComplexType
ComplexLiteral  ::= '(' RealPart ',' ImaginaryPart ')'
RealPart        ::= FloatLiteral
ImaginaryPart   ::= FloatLiteral

複素定数は、実部(先頭)と虚部(末尾)のリストを使用して複素値を表します。たとえば、(1.0, 0.0) : complex<f32>1.0 + 0.0i を表し、(0.0, 1.0) : complex<f32>0.0 + 1.0i を表します。これらのパーツをメモリに保存する順序は実装で定義します。複素定数には次の制約があります。

  • (C1)is_wellformed(real_part, complex_element_type(complex_type))
  • (C2)is_wellformed(imaginary_part, complex_element_type(complex_type))
TensorConstant ::= TensorLiteral ':' TensorType
TensorLiteral  ::= 'dense' '<' (DenseLiteral | ElementLiteral) '>'
DenseLiteral   ::= DenseDimension | DenseElements
DenseDimension ::= '[' [DenseLiteral {',' DenseLiteral}] ']'
DenseElements  ::= [ElementLiteral {',' ElementLiteral}]
ElementLiteral ::= BooleanLiteral | IntegerLiteral | FloatLiteral | ComplexLiteral

テンソル定数は、NumPy の表記法で指定されるネストされたリストを使用してテンソル値を表します。たとえば、dense<[[1, 2, 3], [4, 5, 6]]> : tensor<2x3xi32> は、インデックスから要素への次のマッピングを持つテンソル値を表します。{0, 0} => 1{0, 1} => 2{0, 2} => 3{1, 0} => 4{1, 1} => 5{1, 2} => 6。これらの要素がメモリに格納される順序は実装によって定義されます。テンソル定数には次の制約があります。

  • (C1)has_syntax(tensor_literal, element_type(tensor_type))。ここで、
    • has_syntax(element_literal: Syntax, element_type: Type) = is_wellformed(element_literal, type)
    • has_syntax(tensor_literal: List, element_type: Type) = has_syntax(tensor_literal..., element_type)
  • (C2)has_shape(tensor_literal, shape(tensor_type))。次に例を示します。
    • has_shape(element_literal: Syntax, []) = true
    • has_shape(tensor_literal: List, shape: List) = size(tensor_literal) = shape[0] and has_shape(tensor_literal..., shape[1:])
    • それ以外の場合は false
QuantizedTensorConstant ::= QuantizedTensorLiteral ':' QuantizedTensorType
QuantizedTensorLiteral  ::= 'dense' '<' (DenseLiteral | ElementLiteral) '>'

量子化テンソル定数は、テンソル定数と同じ表記を使用して量子化テンソル値を表します。要素はストレージ タイプの定数として指定されます。量子化されたテンソル定数には次の制約があります。

  • (C1)has_syntax(quantized_tensor_literal, storage_type(quantized_tensor_type))
  • (C2)has_shape(quantized_tensor_literal, shape(quantized_tensor_type))
StringConstant  ::= StringLiteral
StringLiteral   ::= '"' {stringCharacter | escapeSequence} '"'
stringCharacter ::= all ASCII characters except '\00', '\01', ... '\1f' and '"'
escapeSequence  ::= '\' ('"' | '\' | 'n' | 't' | (hexadecimalDigit hexadecimalDigit))

文字列リテラルは、ASCII 文字とエスケープ シーケンスを使用して指定されたバイトで構成されます。エンコードに依存しないため、これらのバイトの解釈は実装で定義されます。文字列リテラルの型は string です。

運用

abs

セマンティクス

operand テンソルに対して要素ごとの絶対値演算を実行し、result テンソルを生成します。要素のタイプに応じて、次の処理を行います。

  • 符号付き整数の場合: 整数の剰余。
  • 浮動小数点数: IEEE-754 の abs
  • 複素数の場合: 複素数モジュラス。
  • 量子化された型の場合: dequantize_op_quantize(abs, operand, type(result))

入力

ラベル 名前 タイプ 制約
(I1) operand 符号付き整数型、浮動小数点型、複合型のテンソル、またはテンソルごとの量子化テンソル (C1 ~ C2)

出力

名前 タイプ 制約
result 符号付き整数型または浮動小数点型のテンソル、またはテンソルごとの量子化テンソル (C1-C2)

制約

  • (C1)shape(result) = shape(operand)
  • (C2)baseline_element_type(result) は次のように定義されます。
    • is_complex(operand) の場合、complex_element_type(element_type(operand))
    • そうでない場合は baseline_element_type(operand)

// %operand: [-2, 0, 2]
%result = "stablehlo.abs"(%operand) : (tensor<3xi32>) -> tensor<3xi32>
// %result: [2, 0, 2]

その他の例

追加

セマンティクス

2 つのテンソル lhsrhs の要素ごとの加算を行い、result テンソルを生成します。要素のタイプに応じて、次の処理を行います。

  • ブール値の場合: 論理 OR。
  • 整数の場合: 整数の加算。
  • 浮動小数点数の場合: IEEE-754 の addition
  • 複素数の場合: 複素数加算。
  • 量子化された型の場合: dequantize_op_quantize(add, lhs, rhs, type(result))

入力

ラベル 名前 タイプ 制約
(I1) lhs テンソルまたは量子化テンソル (C1~C6)
(I2) rhs テンソルまたは量子化テンソル (C1 ~ C5)、(C7)

出力

名前 タイプ 制約
result テンソルまたは量子化テンソル (C1~C7)

制約

  • オペレーションで量子化されていないテンソルを使用する場合:
    • (C1)type(lhs) = type(rhs) = type(result)
  • オペレーションで量子化テンソルを使用する場合:
    • (C2)is_quantized(lhs) and is_quantized(rhs) and is_quantized(result)
    • (C3)storage_type(lhs) = storage_type(rhs) = storage_type(result)
    • (C4)expressed_type(lhs) = expressed_type(rhs) = expressed_type(result)
    • (C5)(is_per_axis_quantized(lhs) or is_per_axis_quantized(rhs)) = is_per_axis_quantized(result)
    • (C6)is_per_axis_quantized(lhs) の場合は quantization_dimension(lhs) = quantization_dimension(result) です。
    • (C7)is_per_axis_quantized(rhs) の場合は quantization_dimension(rhs) = quantization_dimension(result) です。

// %lhs: [[1, 2], [3, 4]]
// %rhs: [[5, 6], [7, 8]]
%result = "stablehlo.add"(%lhs, %rhs) : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
// %result: [[6, 8], [10, 12]]

その他の例

after_all

セマンティクス

inputs を生成するオペレーションが、result に依存するオペレーションの前に実行されるようにします。このオペレーションを実行しても何も行われません。このオペレーションは、result から inputs へのデータ依存関係を確立するためにのみ存在します。

入力

ラベル 名前 タイプ
(I1) inputs 可変長の token の数

出力

名前 タイプ
result token

// %input0: !stablehlo.token
// %input1: !stablehlo.token
%result = "stablehlo.after_all"(%input0, %input1) : (!stablehlo.token, !stablehlo.token) -> !stablehlo.token

その他の例

all_gather

セマンティクス

StableHLO プロセス グリッドの各プロセス グループ内で、各プロセスの operands テンソルの値を all_gather_dim に沿って連結し、results テンソルを生成します。

このオペレーションでは、StableHLO プロセス グリッドが、次のように定義された process_groups に分割されます。

  • channel_id <= 0 and use_global_device_ids = false の場合、cross_replica(replica_groups)
  • channel_id > 0 and use_global_device_ids = false の場合、cross_replica_and_partition(replica_groups)
  • channel_id > 0 and use_global_device_ids = true の場合、flattened_ids(replica_groups)

その後、各 process_group 内で次の操作を行います。

  • process_group のすべての receiver に対して operands...@receiver = [operand@sender for sender in process_group]
  • process_group のすべての process に対して results...@process = concatenate(operands...@process, all_gather_dim)

入力

ラベル 名前 タイプ 制約
(I1) operands テンソルの可変数またはテンソルごとの量子化テンソル (C1)、(C6)
(I2) all_gather_dim si64 型の定数 (C1)、(C6)
(I3) replica_groups si64 型の 2 次元のテンソル定数 (C2 ~ C4)
(I4) channel_id si64 型の定数 (C5)
(I5) use_global_device_ids i1 型の定数 (C5)

出力

名前 タイプ 制約
results テンソルの可変数またはテンソルごとの量子化テンソル (C6)

制約

  • (C1)0 <= all_gather_dim < rank(operands...)
  • (C2)is_unique(replica_groups)
  • (C3)size(replica_groups) は次のように定義されます。
    • cross_replica を使用する場合は num_replicas
    • cross_replica_and_partition を使用する場合は num_replicas
    • flattened_ids を使用する場合は num_processes
  • (C4)0 <= replica_groups < size(replica_groups)
  • (C5)use_global_device_ids = true の場合、channel_id > 0
  • (C6)type(results...) = type(operands...)(ただし、次の場合は除く)
    • dim(results..., all_gather_dim) = dim(operands..., all_gather_dim) * dim(process_groups, 1)

// num_replicas: 2
// num_partitions: 1
// %operand0@(0, 0): [[1, 2], [3, 4]]
// %operand0@(1, 0): [[5, 6], [7, 8]]
// %operand1@(0, 0): [[11, 12], [13, 14]]
// %operand1@(1, 0): [[15, 16], [17, 18]]
%result:2 = "stablehlo.all_gather"(%operand0, %operand1) {
  all_gather_dim = 1 : i64,
  replica_groups = dense<[[0, 1]]> : tensor<1x2xi64>,
  // channel_id = 0
  channel_handle = #stablehlo.channel_handle<handle = 0, type = 0>
  // use_global_device_ids = false
} : (tensor<2x2xi64>, tensor<2x2xi64>) -> (tensor<2x4xi64>, tensor<2x4xi64>)
// %result0@(0, 0): [[1, 2, 5, 6], [3, 4, 7, 8]]
// %result0@(1, 0): [[1, 2, 5, 6], [3, 4, 7, 8]]
// %result1@(0, 0): [[11, 12, 15, 16], [13, 14, 17, 18]]
// %result1@(1, 0): [[11, 12, 15, 16], [13, 14, 17, 18]]

その他の例

all_reduce

セマンティクス

StableHLO プロセス グリッドの各プロセス グループ内で、各プロセスの operands テンソルの値にリダクション関数 computation を適用し、results テンソルを生成します。

このオペレーションは、StableHLO プロセス グリッドを次のように定義された process_groups に分割します。

  • channel_id <= 0 and use_global_device_ids = false の場合、cross_replica(replica_groups)
  • channel_id > 0 and use_global_device_ids = false の場合、cross_replica_and_partition(replica_groups)
  • channel_id > 0 and use_global_device_ids = true の場合、flattened_ids(replica_groups)

その後、各 process_group 内で次の操作を行います。

  • results...@process[result_index] = exec(schedule) は任意のバイナリ ツリーです。schedule は次のとおりです。
    • exec(node) = computation(exec(node.left), exec(node.right))
    • exec(leaf) = leaf.value
  • schedule は、順序付き走査が to_destination_type(operands...@process_group...[result_index], type(func_inputs(computation)[0])) である実装定義のバイナリツリーです。

入力

ラベル 名前 タイプ 制約
(I1) operands テンソルの可変数またはテンソルごとの量子化テンソル (C5)、(C6)
(I2) replica_groups si64 型の 1 次元テンソル定数の可変数 (C1 ~ C3)
(I3) channel_id si64 型の定数 (C4)
(I4) use_global_device_ids i1 型の定数 (C4)
(I5) computation 関数 (C5)

出力

名前 タイプ 制約
results テンソルの可変数またはテンソルごとの量子化テンソル (C6~C7)

制約

  • (C1)is_unique(replica_groups)
  • (C2)size(replica_groups) は次のように定義されます。
    • cross_replica を使用する場合は num_replicas
    • cross_replica_and_partition を使用する場合は num_replicas
    • flattened_ids を使用する場合は num_processes
  • (C3)0 <= replica_groups < size(replica_groups)
  • (C4)use_global_device_ids = true の場合、channel_id > 0
  • (C5)computation の型は (tensor<E>, tensor<E>) -> (tensor<E>) で、is_promotable(element_type(operand), E) です。
  • (C6)shape(results...) = shape(operands...)
  • (C7)element_type(results...) = E

// num_replicas: 2
// num_partitions: 1
// %operand0@(0, 0): [1, 2, 3, 4]
// %operand0@(1, 0): [5, 6, 7, 8]
// %operand1@(0, 0): [9, 10, 11, 12]
// %operand1@(1, 0): [13, 14, 15, 16]
%result:2 = "stablehlo.all_reduce"(%operand0, %operand0) ({
  ^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
    %0 = "stablehlo.add"(%arg0, %arg1) : (tensor<i64>, tensor<i64>) -> tensor<i64>
    "stablehlo.return"(%0) : (tensor<i64>) -> ()
}) {
  replica_groups = dense<[[0, 1]]> : tensor<1x2xi64>,
  // channel_id = 0
  channel_handle = #stablehlo.channel_handle<handle = 0, type = 0>
  // use_global_device_ids = false
} : (tensor<4xi64>, tensor<4xi64>) -> (tensor<4xi64>, tensor<4xi64>)
// %result0@(0, 0): [6, 8, 10, 12]
// %result0@(1, 0): [6, 8, 10, 12]
// %result1@(0, 0): [22, 24, 26, 28]
// %result1@(1, 0): [22, 24, 26, 28]

その他の例

all_to_all

セマンティクス

all_to_all

StableHLO プロセス グリッド内の各プロセス グループ内で、split_dimension に沿って operands テンソルの値を分割し、分割された部分をプロセス間で分散し、分散された部分を concat_dimension に沿って連結して results テンソルを生成します。このオペレーションでは、StableHLO プロセス グリッドが、次のように定義された process_groups に分割されます。

  • channel_id <= 0 の場合、cross_replica(replica_groups)
  • channel_id > 0 の場合は cross_partition(replica_groups)

その後、各 process_group 内で次の操作を行います。

  • process_group 内のすべての sender に対して split_parts...@sender = split(operands...@sender, split_count, split_dimension)
  • scattered_parts...@receiver = [split_parts...@sender[receiver_index] for sender in process_group]receiver_index = process_group.index(receiver) です。
  • results...@process = concatenate(scattered_parts...@process, concat_dimension)

入力

ラベル 名前 タイプ 制約
(I1) operands テンソルの可変数またはテンソルごとの量子化テンソル (C1 ~ C3)、(C9)
(I2) split_dimension si64 型の定数 (C1)、(C2)、(C9)
(I3) concat_dimension si64 型の定数 (C3)、(C9)
(I4) split_count si64 型の定数 (C2)、(C4)、(C8)、(C9)
(I5) replica_groups si64 型の 2 次元テンソル定数 (C5~C8)
(I6) channel_id si64 型の定数

出力

名前 タイプ 制約
results テンソルの可変数またはテンソルごとの量子化テンソル (C9)

制約

  • (C1)0 <= split_dimension < rank(operands...)
  • (C2)dim(operands..., split_dimension) % split_count = 0
  • (C3)0 <= concat_dimension < rank(operands...)
  • (C4)0 < split_count
  • (C5)is_unique(replica_groups)
  • (C6)size(replica_groups) は次のように定義されます。
    • cross_replica を使用する場合は num_replicas
    • cross_partition を使用する場合は num_partitions
  • (C7)0 <= replica_groups < size(replica_groups)
  • (C8)dim(replica_groups, 1) = split_count
  • (C9)type(results...) = type(operands...)(ただし、split_dimension != concat_dimension の場合を除く):
    • dim(results..., split_dimension) = dim(operands..., split_dimension) / split_count
    • dim(results..., concat_dimension) = dim(operands..., concat_dimension) * split_count

// num_replicas: 2
// num_partitions: 1
// %operand1@(0, 0): [[1, 2, 3, 4],
//                    [5, 6, 7, 8]]
// %operand1@(1, 0): [[9, 10, 11, 12],
//                    [13, 14, 15, 16]]
// %operand2@(0, 0): [[17, 18, 19, 20],
//                    [21, 22, 23, 24]]
// %operand2@(1, 0): [[25, 26, 27, 28],
//                    [29, 30, 31, 32]]
%result:2 = "stablehlo.all_to_all"(%operand1, %operand2) {
  split_dimension = 1 : i64,
  concat_dimension = 0 : i64,
  split_count = 2 : i64,
  replica_groups = dense<[[0, 1]]> : tensor<1x2xi64>
  // channel_id = 0
} : (tensor<2x4xi64>, tensor<2x4xi64>) -> (tensor<4x2xi64>, tensor<4x2xi64>)
// %result#0@(0, 0): [[1, 2], [5, 6], [9, 10], [13, 14]]
// %result#0@(1, 0): [[3, 4], [7, 8], [11, 12], [15, 16]]
// %result#1@(0, 0): [[17, 18], [21, 22], [25, 26], [29, 30]]
// %result#1@(1, 0): [[19, 20], [23, 24], [27, 28], [31, 32]]

その他の例

セマンティクス

2 つのテンソル lhsrhs の要素ごとの AND を実行し、result テンソルを生成します。要素のタイプに応じて、次の処理を行います。

  • ブール値の場合: 論理 AND。
  • 整数の場合: ビット演算 AND。

入力

ラベル 名前 タイプ 制約
(I1) lhs ブール型または整数型のテンソル (C1)
(I2) rhs ブール値型または整数型のテンソル (C1)

出力

名前 タイプ 制約
result ブール型または整数型のテンソル (C1)

制約

  • (C1)type(lhs) = type(rhs) = type(result)

// %lhs: [[1, 2], [3, 4]]
// %rhs: [[5, 6], [7, 8]]
%result = "stablehlo.and"(%lhs, %rhs) : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
// %result: [[1, 2], [3, 0]]

その他の例

atan2

セマンティクス

lhs テンソルと rhs テンソルの要素ごとの atan2 演算を実行し、result テンソルを生成します。要素のタイプに応じて、次の処理を行います。

  • 浮動小数点数の場合: IEEE-754 の atan2
  • 複素数の場合: 複素 atan2。
  • 量子化された型の場合: dequantize_op_quantize(atan2, lhs, rhs, type(result))

入力

ラベル 名前 タイプ 制約
(I1) lhs 浮動小数点型または複合型のテンソル、またはテンソルごとの量子化テンソル (C1)
(I2) rhs 浮動小数点型または複合型のテンソル、またはテンソルごとの量子化テンソル (C1)

出力

名前 タイプ 制約
result 浮動小数点型または複合型のテンソル、またはテンソルごとの量子化テンソル (C1)

制約

  • (C1)baseline_type(lhs) = baseline_type(rhs) = baseline_type(result)

// %lhs: [0.0, 1.0, -1.0]
// %rhs: [0.0, 0.0, 0.0]
%result = "stablehlo.atan2"(%lhs, %rhs) : (tensor<3xf64>, tensor<3xf64>) -> tensor<3xf64>
// %result: [0.0, 1.57079637, -1.57079637] // [0.0, pi/2, -pi/2]

その他の例

batch_norm_grad

セマンティクス

grad_output からバックプロパゲートされる batch_norm_training の複数の入力の勾配を計算し、grad_operandgrad_scalegrad_offset テンソルを生成します。より正式には、このオペレーションは、次のように Python 構文を使用して既存の StableHLO オペレーションへの分解として表現できます。

def compute_sum(operand, feature_index):
  (sum,) = reduce(
      inputs=[operand],
      init_values=[constant(0, element_type(operand))],
      dimensions=[i for i in range(rank(operand)) if i != feature_index],
      body=lambda x, y: add(x, y))
  return sum

def compute_mean(operand, feature_index):
  sum = compute_sum(operand, feature_index)
  divisor = constant(size(operand) / dim(operand, feature_index),
                     element_type(operand))
  divisor_bcast = broadcast_in_dim(divisor, [], type(sum))
  return divide(sum, divisor_bcast)

def batch_norm_grad(operand, scale, mean, variance, grad_output, epsilon, feature_index):
  # Broadcast inputs to type(operand)
  scale_bcast = broadcast_in_dim(scale, [feature_index], type(operand))
  mean_bcast = broadcast_in_dim(mean, [feature_index], type(operand))
  variance_bcast = broadcast_in_dim(variance, [feature_index], type(operand))
  epsilon_bcast = broadcast_in_dim(constant(epsilon, element_type(operand)), [],
                                   type(operand))

  # Perform normalization using the provided `mean` and `variance`
  # Intermediate values will be useful for computing gradients
  centered_operand = subtract(operand, mean_bcast)
  stddev = sqrt(add(variance_bcast, epsilon_bcast))
  normalized_operand = divide(centered_operand, stddev)

  # Use the implementation from batchnorm_expander.cc in XLA
  # Temporary variables have exactly the same names as in the C++ code
  elements_per_feature = broadcast_in_dim(
      constant(divide(size(operand), dim(operand, feature_index)),
               element_type(grad_output)),
      [], type(operand))
  i1 = multiply(grad_output, elements_per_feature)
  i2 = broadcast_in_dim(
      compute_sum(grad_output, feature_index), [feature_index], type(operand))
  i3 = broadcast_in_dim(
      compute_sum(multiply(grad_output, centered_operand), feature_index),
      [feature_index], type(operand))
  i4 = multiply(i3, centered_operand)
  i5 = divide(i4, add(variance_bcast, epsilon_bcast))
  i6 = subtract(subtract(i1, i2), i5)

  grad_operand =
      multiply(divide(divide(scale_bcast, stddev), elements_per_feature), i6)
  grad_scale =
      compute_sum(multiply(grad_output, normalized_operand), feature_index)
  grad_offset = compute_sum(grad_output, feature_index)

  return grad_operand, grad_scale, grad_offset

量子化された型の場合は、dequantize_batch_norm_grad_or_training_quantize(lambda operand, scale, mean, variance, grad_output: batch_norm_grad(operand, scale, mean, variance, grad_output, epsilon, feature_index), operand, scale, mean, variance, grad_output, type(grad_operand), type(grad_scale), type(feature_index)) を実行します。

入力

ラベル 名前 タイプ 制約
(I1) operand 浮動小数点型のテンソルまたはテンソルごとの量子化テンソル (C1~C3)、(C5)
(I2) scale 浮動小数点型またはテンソルごとの量子化型の 1 次元テンソル (C2)、(C4)、(C5)
(I3) mean 浮動小数点型またはテンソルごとの量子化型の 1 次元テンソル (C2)、(C4)
(I4) variance 浮動小数点型またはテンソルごとの量子化型の 1 次元テンソル (C2)、(C4)
(I5) grad_output 浮動小数点型のテンソルまたはテンソルごとの量子化テンソル (C2)、(C3)
(I6) epsilon f32 型の定数
(I7) feature_index si64 型の定数 (C1)、(C5)

出力

名前 タイプ 制約
grad_operand 浮動小数点型のテンソルまたはテンソルごとの量子化テンソル (C2)、(C3)
grad_scale 浮動小数点数またはテンソルごとの量子化型の 1 次元テンソル (C2)、(C4)
grad_offset 浮動小数点型またはテンソルごとの量子化型の 1 次元テンソル (C2)、(C4)

制約

  • (C1)0 <= feature_index < rank(operand)
  • (C2)operandscalemeanvariancegrad_outputgrad_operandgrad_scalegrad_offsetbaseline_element_type が同じである。
  • (C3)operandgrad_outputgrad_operand は同じ形状です。
  • (C4)scalemeanvariancegrad_scalegrad_offset の形状が同じである。
  • (C5)size(scale) = dim(operand, feature_index)

// %operand: [
//            [[1.0, 2.0], [3.0, 4.0]],
//            [[3.0, 4.0], [1.0, 2.0]]
//           ]
// %scale: [1.0, 1.0]
// %mean: [2.0, 3.0]
// %variance: [1.0, 1.0]
// %grad_output: [
//                [[0.1, 0.1], [0.1, 0.1]],
//                [[0.1, 0.1], [0.1, 0.1]]
//               ]
%grad_operand, %grad_scale, %grad_offset =
"stablehlo.batch_norm_grad"(%operand, %scale, %mean, %variance, %grad_output) {
  epsilon = 0.0 : f32,
  feature_index = 2 : i64
} : (tensor<2x2x2xf64>, tensor<2xf64>, tensor<2xf64>, tensor<2xf64>,
     tensor<2x2x2xf64>) -> (tensor<2x2x2xf64>, tensor<2xf64>, tensor<2xf64>)
// %grad_operand: [
//                 [[0.0, 0.0], [0.0, 0.0]],
//                 [[0.0, 0.0], [0.0, 0.0]]
//                ]
// %grad_scale:  [0.0, 0.0]
// %grad_offset: [0.4, 0.4]

batch_norm_inference

セマンティクス

feature_index 次元を除くすべての次元で operand テンソルを正規化し、result テンソルを生成します。より正式には、このオペレーションは、Python 構文を使用して既存の StableHLO オペレーションへの分解として次のように表現できます。

def batch_norm_inference(operand, scale, offset, mean, variance, epsilon, feature_index):
  # Broadcast inputs to shape(operand)
  scale_bcast = broadcast_in_dim(scale, [feature_index], type(operand))
  offset_bcast = broadcast_in_dim(offset, [feature_index], type(operand))
  mean_bcast = broadcast_in_dim(mean, [feature_index], type(operand))
  variance_bcast = broadcast_in_dim(variance, [feature_index], type(operand))
  epsilon_bcast = broadcast_in_dim(constant(epsilon, element_type(operand)), [],
                                   type(operand))

  # Perform normalization using the provided `mean` and `variance` instead of
  # computing them like `batch_norm_training` does.
  centered_operand = subtract(operand, mean_bcast)
  stddev = sqrt(add(variance_bcast, epsilon_bcast))
  normalized_operand = divide(centered_operand, stddev)
  return add(multiply(scale_bcast, normalized_operand), offset_bcast)

量子化された型の場合は、dequantize_op_quantize(lambda operand, scale, offset, mean, variance: batch_norm_inference(operand, scale, offset, mean, variance, epsilon, feature_index), operand, scale, offset, mean, variance, type(result)) を実行します。

入力

ラベル 名前 タイプ 制約
(I1) operand 浮動小数点型のテンソルまたはテンソルごとの量子化テンソル (C1~C7)
(I2) scale 浮動小数点数またはテンソルごとの量子化型の 1 次元テンソル (C2)、(C3)
(I3) offset 浮動小数点型またはテンソルごとの量子化型の 1 次元テンソル (C2)、(C4)
(I4) mean 浮動小数点数またはテンソルごとの量子化型の 1 次元テンソル (C5)
(I5) variance 浮動小数点数またはテンソルごとの量子化型の 1 次元テンソル (C2)、(C6)
(I6) epsilon f32 型の定数
(I7) feature_index si64 型の定数 (C1)、(C3 ~ C6)

出力

名前 タイプ 制約
result 浮動小数点型のテンソルまたはテンソルごとの量子化テンソル (C2)、(C7)

制約

  • (C1)0 <= feature_index < rank(operand)
  • (C2)operandscaleoffsetmeanvarianceresult には同じ baseline_element_type があります。
  • (C3)size(scale) = dim(operand, feature_index)
  • (C4)size(offset) = dim(operand, feature_index)
  • (C5)size(mean) = dim(operand, feature_index)
  • (C6)size(variance) = dim(operand, feature_index)
  • (C7)baseline_type(operand) = baseline_type(result)

// %operand: [
//            [[1.0, 2.0], [3.0, 4.0]],
//            [[3.0, 4.0], [1.0, 2.0]]
//           ]
// %scale: [1.0, 1.0]
// %offset: [1.0, 1.0]
// %mean: [2.0, 3.0]
// %variance: [1.0, 1.0]
%result = "stablehlo.batch_norm_inference"(%operand, %scale, %offset, %mean, %variance) {
  epsilon = 0.0 : f32,
  feature_index = 2 : i64
} : (tensor<2x2x2xf64>, tensor<2xf64>, tensor<2xf64>, tensor<2xf64>, tensor<2xf64>) -> tensor<2x2x2xf64>
// %result: [
//           [[0.0, 0.0], [2.0, 2.0]],
//           [[2.0, 2.0], [0.0, 0.0]]
//          ]

batch_norm_training

セマンティクス

feature_index ディメンション以外のすべてのディメンションで平均と分散を計算し、operand テンソルを正規化して outputbatch_meanbatch_var テンソルを生成します。より正式には、このオペレーションは、Python 構文を使用して既存の StableHLO オペレーションへの分解として次のように表現できます。

def compute_mean(operand, feature_index):
  (sum,) = reduce(
      inputs=[operand],
      init_values=[constant(0, element_type(operand))],
      dimensions=[i for i in range(rank(operand)) if i != feature_index],
      body=lambda x, y: add(x, y))
  divisor = constant(size(operand) / dim(operand, feature_index),
                     element_type(operand))
  divisor_bcast = broadcast_in_dim(divisor, [], type(sum))
  return divide(sum, divisor_bcast)

def compute_variance(operand, feature_index):
  mean = compute_mean(operand, feature_index)
  mean_bcast = broadcast_in_dim(mean, [feature_index], type(operand))
  centered_operand = subtract(operand, mean_bcast)
  return compute_mean(mul(centered_operand, centered_operand), feature_index)

def batch_norm_training(operand, scale, offset, epsilon, feature_index):
  mean = compute_mean(operand, feature_index)
  variance = compute_variance(operand, feature_index)
  return batch_norm_inference(operand, scale, offset, mean, variance, epsilon,
                              feature_index),
         mean, variance

量子化された型の場合は、dequantize_batch_norm_grad_or_training_quantize(lambda operand, scale, offset: batch_norm_training(operand, scale, offset, epsilon, feature_index), operand, scale, offset, type(output), type(batch_mean), type(batch_var)) を実行します。

入力

ラベル 名前 タイプ 制約
(I1) operand 浮動小数点型のテンソルまたはテンソルごとの量子化テンソル (C1)
(I2) scale 浮動小数点またはテンソルごとの量子化された 1 次元テンソル (C2)、(C3)
(I3) offset 浮動小数点またはテンソルごとの量子化された 1 次元テンソル (C2)、(C4)
(I4) epsilon f32 型の定数 (C1)、(C3 ~ C6)
(I5) feature_index si64 型の定数 (C1)、(C3 ~ C6)

出力

名前 タイプ 制約
output 浮動小数点型のテンソルまたはテンソルごとの量子化テンソル (C7)
batch_mean 浮動小数点数またはテンソルごとに量子化された 1 次元テンソル (C2)、(C5)
batch_var 浮動小数点数またはテンソルごとに量子化された 1 次元テンソル (C2)、(C6)

制約

  • (C1)0 <= feature_index < rank(operand)
  • (C2)operandscaleoffsetbatch_meanbatch_varoutput は同じ baseline_element_type を持ちます。
  • (C3)size(scale) = dim(operand, feature_index)
  • (C4)size(offset) = dim(operand, feature_index)
  • (C5)size(batch_mean) = dim(operand, feature_index)
  • (C6)size(batch_var) = dim(operand, feature_index)
  • (C7)baseline_type(output) = baseline_type(operand)

// %operand: [
//            [[1.0, 2.0], [3.0, 4.0]],
//            [[3.0, 4.0], [1.0, 2.0]]
//           ]
// %scale: [1.0, 1.0]
// %offset: [1.0, 1.0]
%output, %batch_mean, %batch_var = "stablehlo.batch_norm_training"(%operand, %scale, %offset) {
  epsilon = 0.0 : f32,
  feature_index = 2 : i64
} : (tensor<2x2x2xf64>, tensor<2xf64>, tensor<2xf64>) ->
    (tensor<2x2x2xf64>, tensor<2xf64>, tensor<2xf64>)
// %output: [
//           [[0.0, 0.0], [2.0, 2.0]],
//           [[2.0, 2.0], [0.0, 0.0]]
//          ]
// %batch_mean: [2.0, 3.0]
// %batch_var: [1.0, 1.0]

bitcast_convert

セマンティクス

operand テンソルにビットキャスト演算を実行し、result テンソルを生成します。ここで、operand テンソル全体のビットは result テンソルの型を使用して再解釈されます。

より正式には、E = element_type(operand)E' = element_type(result)R = rank(operand) を指定すると、次のようになります。

  • num_bits(E') < num_bits(E) の場合は bits(result[i0, ..., iR-1, :]) = bits(operand[i0, ..., iR-1])
  • num_bits(E') > num_bits(E) の場合、bits(result[i0, ..., iR-2]) = bits(operand[i0, ..., iR-2, :])
  • num_bits(E') = num_bits(E) の場合は bits(result[i0, ..., iR-1]) = bits(operand[i0, ..., iR-1])

bits は、指定された値のメモリ内表現を返します。テンソルの正確な表現は実装定義であり、要素型の正確な表現も実装定義であるため、その動作は実装定義です。

入力

ラベル 名前 タイプ 制約
(I1) operand テンソルまたは量子化テンソル (C1 ~ C2)

出力

名前 タイプ 制約
result テンソルまたは量子化テンソル (C1 ~ C2)

制約

  • (C1)E = is_quantized(operand) ? storage_type(operand) : element_type(operand)E' = is_quantized(result) ? storage_type(result) : element_type(result)R = rank(operand) が指定されている場合:
    • num_bits(E') = num_bits(E) の場合は shape(result) = shape(operand)
    • num_bits(E') < num_bits(E) の場合:
    • rank(result) = R + 1
    • dim(result, i) = dim(operand, i) はすべての 0 <= i < R に適用されます。
    • dim(result, R) * num_bits(E') = num_bits(E)
    • num_bits(E') > num_bits(E) の場合:
    • rank(result) = R - 1
    • dim(result, i) = dim(operand, i) はすべての 0 <= i < R に適用されます。
    • dim(operand, R - 1) * num_bits(E) = num_bits(E')
  • (C2)is_complex(operand) or is_complex(result) の場合は is_complex(operand) and is_complex(result) です。

// %operand: 0x0123456789ABCDEF
%result = "stablehlo.bitcast_convert"(%operand) : (tensor<f64>) -> tensor<4xf16>
// %result: [0xCDEF, 0x89AB, 0x4567, 0x0123] // little-endian representation

その他の例

broadcast_in_dim

セマンティクス

operand テンソルにデータを複製することで入力テンソルの次元やランクを展開し、result テンソルを生成します。より正式には、result[result_index] = operand[operand_index] で、axes(operand) 内のすべての d について次のようにします。

  • dim(operand, d) = 1 の場合、operand_index[d] = 0
  • それ以外の場合は operand_index[d] = result_index[broadcast_dimensions[d]]

入力

ラベル 名前 タイプ 制約
(I1) operand テンソルまたは量子化テンソル (C1-C2)、(C5-C6)
(I2) broadcast_dimensions si64 型の 1 次元テンソル定数 (C2~C6)

出力

名前 タイプ 制約
result テンソルまたは量子化テンソル (C1)、(C3)、(C5~C6)

制約

  • (C1)element_type(result) は次のように定義されます。
    • element_type(operand)!is_per_axis_quantized(operand) の場合)。
    • element_type(operand)(ただし、quantization_dimension(operand)scales(operand)zero_points(operand) は、それぞれ quantization_dimension(result)scales(result)zero_points(result) と異なる場合があります)。
  • (C2)size(broadcast_dimensions) = rank(operand)
  • (C3)0 <= broadcast_dimensions < rank(result)
  • (C4)is_unique(broadcast_dimensions)
  • (C5)axes(operand) 内のすべての d について:
    • dim(operand, d) = 1 または
    • dim(operand, d) = dim(result, broadcast_dimensions[d])
  • (C6)is_per_axis_quantized(result) の場合:
    • quantization_dimension(result) = broadcast_dimensions[quantization_dimension(operand)]
    • dim(operand, quantization_dimension(operand)) = 1 の場合は scales(result)[i] = scales(operand)[0] and zero_points(result)[i] = zero_points(operand)[0] for i in range(dim(result, quantization_dimension(result))) です。

// %operand: [
//            [1, 2, 3]
//           ]
%result = "stablehlo.broadcast_in_dim"(%operand) {
  broadcast_dimensions = array<i64: 2, 1>
} : (tensor<1x3xi32>) -> tensor<2x3x2xi32>
// %result: [
//            [
//             [1, 1],
//             [2, 2],
//             [3, 3]
//            ],
//            [
//             [1, 1],
//             [2, 2],
//             [3, 3]
//            ]
//          ]

その他の例

ケース

セマンティクス

index の値に応じて、branches から関数を 1 つだけ実行することで、出力を生成します。より正式には、result = selected_branch() のようにします。

  • 0 <= index < size(branches) の場合、selected_branch = branches[index]
  • それ以外の場合は selected_branch = branches[-1]

入力

ラベル 名前 タイプ 制約
(I1) index si32 型の 0 次元テンソル
(I2) branches 関数の可変数 (C1~C4)

出力

名前 タイプ 制約
results テンソル、量子化テンソル、またはトークンの可変数 (C4)

制約

  • (C1)0 < size(branches)
  • (C2)input_types(branches...) = []
  • (C3)same(output_types(branches...))
  • (C4)type(results...) = output_types(branches[0])

// %index: -1
// %result_branch0: [0, 0]
// %result_branch1: [1, 1]
%result0, %result1 = "stablehlo.case"(%index) ({
  "stablehlo.return"(%result_branch0, %result_branch0) : (tensor<2xi64>, tensor<2xi64>) -> ()
}, {
  "stablehlo.return"(%result_branch1, %result_branch1) : (tensor<2xi64>, tensor<2xi64>) -> ()
}) : (tensor<i32>) -> (tensor<2xi64>, tensor<2xi64>)
// %result0: [1, 1]
// %result1: [1, 1]

その他の例

CBRT

セマンティクス

operand テンソルに対して要素単位の立方根演算を実行し、result テンソルを生成します。要素のタイプに応じて、次の操作を行います。

  • 浮動小数点数の場合: IEEE-754 の rootn(x, 3)
  • 複素数の場合: 複素立方根。
  • 量子化型の場合: dequantize_op_quantize(cbrt, operand, type(result))

入力

ラベル 名前 タイプ 制約
(I1) operand 浮動小数点型または複合型のテンソル、またはテンソルごとの量子化テンソル (C1)

出力

名前 タイプ 制約
result 浮動小数点型または複合型のテンソル、またはテンソルごとの量子化テンソル (C1)

制約

  • (C1)baseline_type(operand) = baseline_type(result)

// %operand: [0.0, 1.0, 8.0, 27.0]
%result = "stablehlo.cbrt"(%operand) : (tensor<4xf64>) -> tensor<4xf64>
// %result: [0.0, 1.0, 2.0, 3.0]

その他の例

ceil

セマンティクス

operand テンソルの要素単位の ceil を実行し、result テンソルを生成します。IEEE-754 仕様の roundToIntegralTowardPositive オペレーションを実装します。量子化された型の場合は、dequantize_op_quantize(ceil, operand, type(result)) を実行します。

入力

ラベル 名前 タイプ 制約
(I1) operand 浮動小数点型のテンソルまたはテンソルごとの量子化テンソル (C1)

出力

名前 タイプ 制約
result 浮動小数点型のテンソルまたはテンソルごとの量子化テンソル (C1)

制約

  • (C1)baseline_type(operand) = baseline_type(result)

// %operand: [-0.8166, -0.2530, 0.2530, 0.8166, 2.0]
%result = "stablehlo.ceil"(%operand) : (tensor<5xf32>) -> tensor<5xf32>
// %result: [-0.0, -0.0, 1.0, 1.0, 2.0]

その他の例

cholesky

セマンティクス

行列のバッチの Cholesky 分解を計算します。

より正式には、index_space(result) のすべての i について、result[i0, ..., iR-3, :, :]a[i0, ..., iR-3, :, :] の Cholesky 分解であり、下三角(lowertrue の場合)または上三角(lowerfalse の場合)行列のいずれかになります。反対の三角形(厳密な上三角形または厳密な下三角形)の出力値は、実装で定義されます。

入力行列がエルミート正定値行列ではない i が存在する場合、動作は未定義です。

量子化された型の場合は、dequantize_op_quantize(lambda operand: cholesky(operand, lower), a, type(result)) を実行します。

入力

ラベル 名前 タイプ 制約
(I1) a 浮動小数点型または複合型のテンソル、またはテンソルごとの量子化テンソル (C1~C3)
(I2) lower i1 の 0 次元テンソル定数

出力

名前 タイプ 制約
result 浮動小数点型または複合型のテンソル、またはテンソルごとの量子化テンソル (C1)

制約

  • (C1)baseline_type(a) = baseline_type(result)
  • (C2)2 <= rank(a)
  • (C3)dim(a, -2) = dim(a, -1)

// %a: [
//      [1.0, 2.0, 3.0],
//      [2.0, 20.0, 26.0],
//      [3.0, 26.0, 70.0]
//     ]
%result = "stablehlo.cholesky"(%a) {
  lower = true
} : (tensor<3x3xf32>) -> tensor<3x3xf64>
// %result: [
//           [1.0, 0.0, 0.0],
//           [2.0, 4.0, 0.0],
//           [3.0, 5.0, 6.0]
//          ]

クランプ

セマンティクス

operand テンソルのすべての要素を最小値と最大値の間にクランプし、result テンソルを生成します。より正式には、result[result_index] = minimum(maximum(operand[result_index], min_element), max_element)min_element = rank(min) = 0 ? min[] : min[result_index]max_element = rank(max) = 0 ? max[] : max[result_index])です。量子化された型の場合は、dequantize_op_quantize(clamp, min, operand, max, type(result)) を実行します。

複素数の順序付けには意外なセマンティクスが伴うため、将来的には、この演算での複素数のサポートは終了する予定です(#560)。

入力

ラベル 名前 タイプ 制約
(I1) min テンソルまたはテンソルごとの量子化テンソル (C1)、(C3)
(I2) operand テンソルまたはテンソルごとの量子化テンソル (C1~C4)
(I3) max テンソルまたはテンソルごとの量子化テンソル (C2)、(C3)

出力

名前 タイプ 制約
result テンソルまたはテンソルごとの量子化テンソル (C4)

制約

  • (C1)rank(min) = 0 or shape(min) = shape(operand)
  • (C2)rank(max) = 0 or shape(max) = shape(operand)
  • (C3)baseline_element_type(min) = baseline_element_type(operand) = baseline_element_type(max)
  • (C4)baseline_type(operand) = baseline_type(result)

// %min: [5, 10, 15]
// %operand: [3, 13, 23]
// %max: [10, 15, 20]
%result = "stablehlo.clamp"(%min, %operand, %max) : (tensor<3xi32>, tensor<3xi32>, tensor<3xi32>) -> tensor<3xi32>
// %result: [5, 13, 20]

その他の例

collective_broadcast

セマンティクス

StableHLO プロセス グリッドの各プロセス グループ内で、operand テンソルの値をソースプロセスからターゲット プロセスに送信し、result テンソルを生成します。

このオペレーションでは、StableHLO プロセス グリッドが、次のように定義された process_groups に分割されます。

  • channel_id <= 0 の場合は cross_replica(replica_groups)
  • channel_id > 0 の場合、cross_partition(replica_groups)

その後、result@process は次のようになります。

  • operand@process_groups[i, 0]: プロセスが process_groups[i] にあるような i が存在する場合。
  • それ以外の場合は broadcast_in_dim(constant(is_quantized(result) ? quantize(0, element_type(result)) : 0, element_type(result)), [], type(result))

入力

ラベル 名前 タイプ 制約
(I1) operand テンソルまたはテンソルごとの量子化テンソル (C3)
(I2) replica_groups si64 型の 1 次元テンソル定数の可変数 (C1)、(C2)
(I3) channel_id si64 型の定数

出力

名前 タイプ 制約
result テンソルまたはテンソルごとの量子化テンソル (C3)

制約

  • (C1)is_unique(replica_groups)
  • (C2)0 <= replica_groups < N。ここで、N は次のように定義されます。
    • cross_replica を使用する場合は num_replicas
    • cross_partition を使用する場合は num_partitions
  • (C3)type(result) = type(operand)

// num_replicas: 4
// num_partitions: 1
// %operand@(0, 0): [[1, 2]]
// %operand@(1, 0): [[3, 4]]
// %operand@(2, 0): [[5, 6]]
// %operand@(3, 0): [[7, 8]]
%result = "stablehlo.collective_broadcast"(%operand) {
  replica_groups = dense<[[2, 1]]> : tensor<1x2xi64>,
  channel_handle = #stablehlo.channel_handle<handle = 0, type = 0>
} : (tensor1x2xi64>) -> tensor<1x2xi64>
// %result@(0, 0): [[0, 0]]
// %result@(1, 0): [[5, 6]]
// %result@(2, 0): [[5, 6]]
// %result@(3, 0): [[0, 0]]

collective_permute

セマンティクス

StableHLO プロセス グリッド内の各プロセス グループ内で、operand テンソルの値をソースプロセスからターゲット プロセスに送信し、result テンソルを生成します。

このオペレーションは、StableHLO プロセス グリッドを次のように定義された process_groups に分割します。

  • channel_id <= 0 の場合、cross_replica(source_target_pairs)
  • channel_id > 0 の場合は cross_partition(source_target_pairs)

その後、result@process は次のようになります。

  • operand@process_groups[i, 0]process_groups[i, 1] = process のような i が存在する場合。
  • それ以外の場合は broadcast_in_dim(constant(is_quantized(result) ? quantize(0, element_type(result)) : 0, element_type(result)), [], type(result))

入力

ラベル 名前 タイプ 制約
(I1) operand テンソルまたはテンソルごとの量子化テンソル (C5)
(I2) source_target_pairs si64 型の 2 次元のテンソル定数 (C1~C4)
(I3) channel_id si64 型の定数

出力

名前 タイプ 制約
result テンソルまたはテンソルごとの量子化テンソル (C1)

制約

  • (C1)dim(source_target_pairs, 1) = 2
  • (C2)is_unique(source_target_pairs[:, 0])
  • (C3)is_unique(source_target_pairs[:, 1])
  • (C4)0 <= source_target_pairs < N。ここで、N は次のように定義されます。
    • cross_replica を使用する場合は num_replicas
    • cross_partition を使用する場合は num_partitions
  • (C5)type(result) = type(operand)

// num_replicas: 3
// num_partitions: 1
// %operand@(0, 0): [[1, 2], [3, 4]]
// %operand@(1, 0): [[5, 6], [7, 8]]
// %operand@(2, 0): [[9, 10], [11, 12]]
%result = "stablehlo.collective_permute"(%operand) {
  source_target_pairs = dense<[[0, 1], [1, 2]]> : tensor<2x2xi64>,
  channel_handle = #stablehlo.channel_handle<handle = 0, type = 0>
} : (tensor<2x2xi64>) -> tensor<2x2xi64>
//
// %result@(0, 0): [[0, 0], [0, 0]]
// %result@(1, 0): [[1, 2], [3, 4]]
// %result@(2, 0): [[5, 6], [7, 8]]

その他の例

比較

セマンティクス

comparison_directioncompare_type に従って lhs テンソルと rhs テンソルの要素ごとの比較を行い、result テンソルを生成します。

comparison_directioncompare_type の値のセマンティクスは次のとおりです。

ブール値と整数の要素型の場合:

  • EQ: lhs = rhs
  • NE: lhs != rhs
  • GE: lhs >= rhs
  • GT: lhs > rhs
  • LE: lhs <= rhs
  • LT: lhs < rhs

compare_type = FLOAT の浮動小数点要素型の場合、op は次の IEEE-754 演算を実装します。

  • EQ: compareQuietEqual
  • NE: compareQuietNotEqual
  • GE: compareQuietGreaterEqual
  • GT: compareQuietGreater
  • LE: compareQuietLessEqual
  • LT: compareQuietLess

compare_type = TOTALORDER を含む浮動小数点要素型の場合、op は IEEE-754 の totalOrder オペレーションと compareQuietEqual オペレーションの組み合わせを使用します。

複雑な要素型の場合、指定された comparison_directioncompare_type を使用して、(real, imag) ペアの辞書順による比較が行われます。複素数に順序付けを適用すると、予期しないセマンティクスが発生するため、comparison_directionGEGTLE、または LT の場合、複素数のサポートを削除する予定です(#560)。

量子化された型の場合、dequantize_compare(lhs, rhs, comparison_direction) を実行します。

入力

ラベル 名前 タイプ 制約
(I1) lhs テンソルまたはテンソルごとの量子化テンソル (C1 ~ C3)
(I2) rhs テンソルまたはテンソルごとの量子化テンソル (C1-C2)
(I3) comparison_direction EQNEGEGTLELT の列挙型
(I4) compare_type FLOATTOTALORDERSIGNEDUNSIGNED の列挙型 (C3)

出力

名前 タイプ 制約
result ブール型のテンソル (C2)

制約

  • (C1)baseline_element_type(lhs) = baseline_element_type(rhs)
  • (C2)shape(lhs) = shape(rhs) = shape(result)
  • (C3)compare_type は次のように定義されます。
    • is_signed_integer(element_type(lhs)) の場合は SIGNED
    • is_unsigned_integer(element_type(lhs)) or is_boolean(element_type(lhs)) の場合は UNSIGNED
    • is_float(element_type(lhs)) の場合は FLOAT または TOTALORDER
    • is_complex(element_type(lhs)) の場合は FLOAT

// %lhs: [1.0, 3.0]
// %rhs: [1.1, 2.9]
%result = "stablehlo.compare"(%lhs, %rhs) {
  comparison_direction = #stablehlo<comparison_direction LT>,
  compare_type = #stablehlo<comparison_type FLOAT>
} : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xi1>
// %result: [true, false]

その他の例

複雑

セマンティクス

実数値と虚数値のペア(lhsrhs)から複合値に要素単位で変換し、result テンソルを生成します。

入力

ラベル 名前 タイプ 制約
(I1) lhs f32 型または f64 型のテンソル (C1~C3)
(I2) rhs f32 型または f64 型のテンソル (C1)

出力

名前 タイプ 制約
result 複合型テンソル (C2)、(C3)

制約

  • (C1)type(lhs) = type(rhs)
  • (C2)shape(result) = shape(lhs)
  • (C3)element_type(result) の型は complex<E> で、E = element_type(lhs) です。

// %lhs: [1.0, 3.0]
// %rhs: [2.0, 4.0]
%result = "stablehlo.complex"(%lhs, %rhs) : (tensor<2xf64>, tensor<2xf64>) -> tensor<2xcomplex<f64>>
// %result: [(1.0, 2.0), (3.0, 4.0)]

その他の例

複合

セマンティクス

他の StableHLO オペレーションで構成された(コンポーズされた)オペレーションをカプセル化し、inputscomposite_attributes を受け取って results を生成します。op のセマンティクスは decomposition 属性によって実装されます。composite オペレーションは、プログラムのセマンティクスを変更せずに分解に置き換えることができます。分解をインライン化しても同じオペレーションのセマンティクスが得られない場合は、custom_call を使用することをおすすめします。

version フィールド(デフォルトは 0)は、コンポジットのセマンティクスが変更されたときに使用されます。

入力

ラベル 名前 タイプ
(I1) inputs 値の可変数
(I2) name string 型の定数
(I3) composite_attributes 属性辞書
(I4) decomposition string 型の定数
(I5) version si32 型の定数

出力

名前 タイプ
results 値の可変数

制約

  • (C1)is_namespaced_op_name(name)
  • (C2)is_defined_in_parent_scope(decomposition)
  • (C3)types(inputs...) == input_types(decomposition)
  • (C4)types(results...) == output_types(decomposition)

%results = "stablehlo.composite"(%input0, %input1) {
  name = "my_namespace.my_op",
  composite_attributes = {
    my_attribute = "my_value"
  },
  decomposition = @my_op,
  version = 1 : i32
} : (tensor<f32>, tensor<f32>) -> tensor<f32>

その他の例

concatenate

セマンティクス

指定された引数と同じ順序で dimension 次元に沿って inputs を連結し、result テンソルを生成します。より正式には、result[i0, ..., id, ..., iR-1] = inputs[k][i0, ..., kd, ..., iR-1] です。ここで、

  1. id = d0 + ... + dk-1 + kd
  2. ddimension に等しく、d0inputsd 番目のディメンション サイズです。

入力

ラベル 名前 タイプ 制約
(I1) inputs テンソルの可変数またはテンソルごとの量子化テンソル (C1~C6)
(I2) dimension si64 型の定数 (C2)、(C4)、(C6)

出力

名前 タイプ 制約
result テンソルまたはテンソルごとの量子化テンソル (C5-C6)

制約

  • (C1)same(element_type(inputs...))
  • (C2)dim(inputs..., dimension) を除く same(shape(inputs...))
  • (C3)0 < size(inputs)
  • (C4)0 <= dimension < rank(inputs[0])
  • (C5)element_type(result) = element_type(inputs[0])
  • (C6)shape(result) = shape(inputs[0])(ただし、次の場合は除く)。
    • dim(result, dimension) = dim(inputs[0], dimension) + ...

// %input0: [[1, 2], [3, 4], [5, 6]]
// %input1: [[7, 8]]
%result = "stablehlo.concatenate"(%input0, %input1) {
  dimension = 0 : i64
} : (tensor<3x2xi64>, tensor<1x2xi64>) -> tensor<4x2xi64>
// %result: [[1, 2], [3, 4], [5, 6], [7, 8]]

その他の例

定数

セマンティクス

定数 value から output テンソルを生成します。

入力

ラベル 名前 タイプ 制約
(I1) value 定数 (C1)

出力

名前 タイプ 制約
output テンソルまたは量子化テンソル (C1)

制約

  • (C1)type(value) = type(output)

%output = "stablehlo.constant"() {
  value = dense<[[0.0, 1.0], [2.0, 3.0]]> : tensor<2x2xf32>
} : () -> tensor<2x2xf32>
// %output: [[0.0, 1.0], [2.0, 3.0]]

その他の例

コンバージョン

セマンティクス

operand テンソルで要素型を要素ごとに変換し、result テンソルを生成します。

boolean-to-any-supported-typeでは、値 false はゼロに変換され、値 true は 1 に変換されます。any-supported-type-to-booleanへの変換の場合、ゼロ値は false に変換され、ゼロ以外の値は true に変換されます。複雑な型での動作については以下をご覧ください。

integer-to-integerinteger-to-floating-pointfloating-point-to-floating-point を含む変換では、ソース値が宛先の型で正確に表現できる場合、結果値はそのまま表現されます。それ以外の場合、動作は未定です(#180)。

浮動小数点数から整数への変換では、小数部分は切り捨てられます。切り捨てられた値を宛先型で表現できない場合、動作は未定です(#180)。

複素数から複素数への変換は、実部と虚部の変換における浮動小数点数から浮動小数点数への変換と同じ動作になります。

複素数から他の型への変換他の型から複素数への変換では、それぞれソースの虚数値が無視されるか、宛先の虚数値がゼロになります。実部の変換は浮動小数点変換に従います。

原則として、このオペレーションはデクォンタイズ(量子化テンソルから通常のテンソルへの変換)、量子化(通常のテンソルから量子化テンソルへの変換)、再量子化(量子化テンソル間の変換)を表現できますが、現時点では専用のオペレーションがあります。最初のユースケースの場合は uniform_dequantize、2 つ目のユースケースと 3 つ目のユースケースの場合は uniform_quantize です。今後、これらの 2 つのオペレーションは convert に統合される可能性があります(#1576)。

入力

ラベル 名前 タイプ 制約
(I1) operand テンソル (C1)

出力

名前 タイプ 制約
result テンソル (C1)

制約

  • (C1)shape(operand) = shape(result)

// %operand: [-1, 0, 1]
%result = "stablehlo.convert"(%operand) : (tensor<3xi64>) -> tensor<3xcomplex<f64>>
// %result: [(-1.0, 0.0), (0.0, 0.0), (1.0, 0.0)]

その他の例

畳み込み

セマンティクス

lhs のウィンドウと rhs のスライス間のドット積を計算し、result を生成します。次の図は、具体的な例を使用して、result の要素が lhsrhs からどのように計算されるかを示しています。

畳み込み

より正式には、lhs のウィンドウを表現できるようにするために、次のように lhs の観点から入力をフレーミングすることを検討します。

  • lhs_window_dimensions = lhs_shape(dim(lhs, input_batch_dimension), dim(rhs, kernel_spatial_dimensions), dim(lhs, input_feature_dimension))
  • lhs_window_strides = lhs_shape(1, window_strides, 1)
  • lhs_padding = lhs_shape([0, 0], padding, [0, 0])
  • lhs_base_dilations = lhs_shape(1, lhs_dilation, 1)
  • lhs_window_dilations = lhs_shape(1, rhs_dilation, 1)

このフレーム変更では、次のヘルパー関数を使用します。

  • lhs_shape(n, hw, c) = permute([n] + hw + [c], [input_batch_dimension] + input_spatial_dimensions + [input_feature_dimension])
  • result_shape(n1, hw, c1) = permute([n1] + hw + [c1], [output_batch_dimension] + output_spatial_dimensions + [output_feature_dimension])
  • permute([j0, j1, ..., jR-1], permutation) = [i0, i1, ..., iR-1]j[d] = i[permutation[d]] です。

feature_group_count = 1 かつ batch_group_count = 1 の場合、index_space(dim(result, output_spatial_dimensions...)) 内のすべての output_spatial_index について、result[result_shape(:, output_spatial_index, :)] = dot_product です。ここで、

  • padding_value = constant(0, element_type(lhs))
  • padded_lhs = pad(lhs, padding_value, lhs_padding[:, 0], lhs_padding[:, 1], lhs_base_dilations - 1)
  • lhs_window_start = lhs_shape(0, output_spatial_index, 0) * lhs_window_strides
  • lhs_window = slice(padded_lhs, lhs_window_start, lhs_window_start + lhs_window_dimensions, lhs_window_dilations)
  • reversed_lhs_window = reverse(lhs_window, [input_spatial_dimensions[dim] for dim in range(size(window_reversal)) if window_reversal[dim] = true])。この機能は使用されていないと思われるため、今後削除する予定です(#1181)。
  • dot_product = dot_general(reversed_lhs_window, rhs, lhs_batching_dimensions=[], lhs_contracting_dimensions=input_spatial_dimensions + [input_feature_dimension], rhs_batching_dimensions=[], rhs_contracting_dimensions=kernel_spatial_dimensions + [kernel_input_feature_dimension])

feature_group_count > 1 の場合:

  • lhses = split(lhs, feature_group_count, input_feature_dimension)
  • rhses = split(rhs, feature_group_count, kernel_output_feature_dimension)
  • results... = convolution(lhses..., rhses..., ..., feature_group_count=1, ...)
  • result = concatenate(results, output_feature_dimension)

batch_group_count > 1 の場合:

  • lhses = split(lhs, batch_group_count, input_batch_dimension)
  • rhses = split(rhs, batch_group_count, kernel_output_feature_dimension)
  • results... = convolution(lhses..., rhses..., ..., batch_group_count=1, ...)
  • result = concatenate(results, output_feature_dimension)

量子化された型の場合は、dequantize_op_quantize( lambda lhs, rhs: convolution(lhs, rhs, window_strides, padding, lhs_dilation, rhs_dilation, window_reversal, input_batch_dimension, input_feature_dimension, input_spatial_dimensions, kernel_input_feature_dimension, kernel_output_feature_dimension, kernel_spatial_dimensions, output_batch_dimension, output_feature_dimension, output_spatial_dimensions, feature_group_count, batch_group_count, precision_config), lhs, rhs, type(result)) を実行します。

ハイブリッド量子化タイプの場合は、hybrid_dequantize_then_op( lambda lhs, rhs: convolution(lhs, rhs, window_strides, padding, lhs_dilation, rhs_dilation, window_reversal, input_batch_dimension, input_feature_dimension, input_spatial_dimensions, kernel_input_feature_dimension, kernel_output_feature_dimension, kernel_spatial_dimensions, output_batch_dimension, output_feature_dimension, output_spatial_dimensions, feature_group_count, batch_group_count, precision_config), lhs, rhs) を実行します。

入力

ラベル 名前 タイプ 制約
(I1) lhs テンソルまたはテンソルごとの量子化テンソル (C1)、(C10~C11)、(C14)、(C25)、(C27~C28)、(C31~C32)、(C34)
(I2) rhs テンソルまたは量子化テンソル (C1)、(C14 ~ C16)、(C25)、(C27 ~ C29)、(C31 ~ C34)
(I3) window_strides si64 型の 1 次元テンソル定数 (C2~C3)、(C25)
(I4) padding si64 型の 2 次元テンソル定数 (C4)、(C25)
(I5) lhs_dilation si64 型の 1 次元テンソル定数 (C5 ~ C6)、(C25)
(I6) rhs_dilation si64 型の 1 次元のテンソル定数 (C7 ~ C8)、(C25)
(I7) window_reversal i1 型の 1 次元テンソル定数 (C9)
(I8) input_batch_dimension si64 型の定数 (C10)、(C13)、(C25)
(I9) input_feature_dimension si64 型の定数 (C11)、(C13~C14)
(I10) input_spatial_dimensions si64 型の 1 次元のテンソル定数 (C12)、(C13)、(C25)
(I11) kernel_input_feature_dimension si64 型の定数 (C14)、(C18)
(I12) kernel_output_feature_dimension si64 型の定数 (C15 ~ C16)、(C18)、(C25)、(C29)
(I13) kernel_spatial_dimensions si64 型の 1 次元テンソル定数 (C17~C18)、(C25)
(I14) output_batch_dimension si64 型の定数 (C20)、(C25)
(I15) output_feature_dimension si64 型の定数 (C20)、(C25)、(C30)
(I16) output_spatial_dimensions si64 型の 1 次元テンソル定数 (C19 ~ C20)、(C25)
(I17) feature_group_count si64 型の定数 (C11)、(C14)、(C16)、(C21)、(C23)
(I18) batch_group_count si64 型の定数 (C10)、(C15)、(C22)、(C23)、(C25)
(I19) precision_config DEFAULTHIGHHIGHEST の可変列挙型の数 (C24)

出力

名前 タイプ 制約
result テンソルまたは量子化テンソル (C25 ~ C28)、(C30)、(C32 ~ 34)

制約

  • (C1)N = rank(lhs) = rank(rhs)
  • (C2)size(window_strides) = N - 2
  • (C3)0 < window_strides
  • (C4)shape(padding) = [N - 2, 2]
  • (C5)size(lhs_dilation) = N - 2
  • (C6)0 < lhs_dilation
  • (C7)size(rhs_dilation) = N - 2
  • (C8)0 < rhs_dilation
  • (C9)size(window_reversal) = N - 2
  • (C10)dim(lhs, input_batch_dimension) % batch_group_count = 0
  • (C11)dim(lhs, input_feature_dimension) % feature_group_count = 0
  • (C12)size(input_spatial_dimensions) = N - 2
  • (C13)input_dimensions = [input_batch_dimension] + input_spatial_dimensions + [input_feature_dimension] が与えられている場合:
    • is_unique(input_dimensions)
    • 0 <= input_dimensions < N
  • (C14)dim(rhs, kernel_input_feature_dimension) = dim(lhs, input_feature_dimension) / feature_group_count
  • (C15)dim(rhs, kernel_output_feature_dimension) % batch_group_count = 0
  • (C16)dim(rhs, kernel_output_feature_dimension) % feature_group_count = 0
  • (C17)size(kernel_spatial_dimensions) = N - 2
  • (C18)kernel_dimensions = kernel_spatial_dimensions + [kernel_input_feature_dimension] + [kernel_output_feature_dimension] が与えられている場合:
    • is_unique(kernel_dimensions)
    • 0 <= kernel_dimensions < N
  • (C19)size(output_spatial_dimensions) = N - 2
  • (C20)output_dimensions = [output_batch_dimension] + output_spatial_dimensions + [output_feature_dimension] が与えられている場合:
    • is_unique(output_dimensions)
    • 0 <= output_dimensions < N
  • (C21)0 < feature_group_count
  • (C22)0 < batch_group_count
  • (C23)feature_group_count = 1 or batch_group_count = 1
  • (C24)size(precision_config) = 2
  • (C25)dim(result, result_dim) は次のように定義されます。
    • result_dim = output_batch_dimension の場合、dim(lhs, input_batch_dimension) / batch_group_count
    • result_dim = output_feature_dimension の場合、dim(rhs, kernel_output_feature_dimension)
    • それ以外の場合は num_windows。ここで:
    • output_spatial_dimensions[spatial_dim] = result_dim
    • lhs_dim = input_spatial_dimensions[spatial_dim]
    • rhs_dim = kernel_spatial_dimensions[spatial_dim]
    • dilated_input_shape[lhs_dim] = dim(lhs, lhs_dim) = 0 ? 0 : (dim(lhs, lhs_dim) - 1) * lhs_dilation[spatial_dim] + 1
    • padded_input_shape[lhs_dim] = padding[spatial_dim, 0] + dilated_input_shape[lhs_dim] + padding[spatial_dim, 1]
    • dilated_window_shape[lhs_dim] = dim(rhs, rhs_dim) = 0 ? 0 : (dim(rhs, rhs_dim) - 1) * rhs_dilation[spatial_dim] + 1
    • is_empty_window[lhs_dim] = padded_input_shape[lhs_dim] = 0 || dilated_window_shape[lhs_dim] > padded_input_shape[lhs_dim]
    • num_windows = is_empty_window[lhs_dim] ? 0 : floor((padded_input_shape[lhs_dim] - dilated_window_shape[lhs_dim]) / window_strides[spatial_dim]) + 1
  • (C26)rank(result) = N
  • オペレーションで量子化されていないテンソルを使用する場合:
    • (C27)element_type(lhs) = element_type(rhs) = element_type(result)
  • オペレーションで量子化されたテンソルを使用する場合:
    • (C28)is_quantized(lhs) = is_quantized(result) and is_quantized(rhs)
    • (C29)is_per_axis_quantized(rhs) の場合、quantization_dimension(rhs) = kernel_output_feature_dimension
    • (C30)is_per_axis_quantized(result) の場合、quantization_dimension(result) = output_feature_dimension
    • is_quantized(lhs) の場合:
    • (C31)storage_type(lhs) = storage_type(rhs)
    • (C32)expressed_type(lhs) = expressed_type(rhs) = expressed_type(result)
    • (C33)is_per_tensor_quantized(rhs) の場合は is_per_tensor_quantized(result) です。
    • !is_quantized(lhs) の場合:
    • (C34)element_type(lhs) = expressed_type(rhs) = element_type(result)

// %lhs: [[
//        [
//          [1], [2], [5], [6]
//        ],
//        [
//          [3], [4], [7], [8]
//        ],
//        [
//          [10], [11], [14], [15]
//        ],
//        [
//          [12], [13], [16], [17]
//        ]
//      ]]
//
// %rhs: [
//        [[[1]], [[1]], [[1]]],
//        [[[1]], [[1]], [[1]]],
//        [[[1]], [[1]], [[1]]]
//       ]
%result = "stablehlo.convolution"(%lhs, %rhs) {
  window_strides = array<i64: 4, 4>,
  padding = dense<0> : tensor<2x2xi64>,
  lhs_dilation = array<i64: 2, 2>,
  rhs_dilation = array<i64: 1, 1>,
  window_reversal = array<i1: false, false>,
  // In the StableHLO dialect, dimension numbers are encoded via:
  // `[<input dimensions>]x[<kernel dimensions>]->[output dimensions]`.
  // "b" is batch dimension, "f" is feature dimension,
  // "i" is input feature dimension, "o" is output feature dimension,
  // "0/1/etc" are spatial dimensions.
  dimension_numbers = #stablehlo.conv<[b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f]>,
  batch_group_count = 1 : i64,
  feature_group_count = 1 : i64,
  precision_config = [#stablehlo<precision DEFAULT>, #stablehlo<precision DEFAULT>]
} : (tensor<1x4x4x1xi64>, tensor<3x3x1x1xi64>) -> tensor<1x2x2x1xi64>
// %result: [[
//            [[10], [26]],
//            [[46], [62]]
//          ]]

その他の例

コサイン

セマンティクス

operand テンソルに対して要素単位のコサイン演算を実行し、result テンソルを生成します。要素のタイプに応じて、次の操作を行います。

  • 浮動小数点数の場合: IEEE-754 の cos
  • 複素数の場合: 複素余弦。
  • 量子化型の場合: dequantize_op_quantize(cosine, operand, type(result))

入力

ラベル 名前 タイプ 制約
(I1) operand 浮動小数点型または複合型のテンソル、またはテンソルごとの量子化テンソル (C1)

出力

名前 タイプ 制約
result 浮動小数点型または複合型のテンソル、またはテンソルごとの量子化テンソル (C1)

制約

  • (C1)baseline_type(operand) = baseline_type(result)

// %operand: [
//            [0.0, 1.57079632],       // [0, pi/2]
//            [3.14159265, 4.71238898] // [pi, 3pi/2]
//           ]
%result = "stablehlo.cosine"(%operand) : (tensor<2x2xf32>) -> tensor<2x2xf32>
// %result: [[1.0, 0.0], [-1.0, 0.0]]

その他の例

count_leading_zeros

セマンティクス

operand テンソルの先頭のゼロビット数を要素ごとにカウントし、result テンソルを生成します。

入力

ラベル 名前 タイプ 制約
(I1) operand 整数型のテンソル (C1)

出力

名前 タイプ 制約
result 整数型のテンソル (C1)

制約

  • (C1)type(operand) = type(result)

// %operand: [[0, 1], [128, -1]]
%result = "stablehlo.count_leading_zeros"(%operand) : (tensor<2x2xi64>) -> tensor<2x2xi64>
// %result: [[64, 63], [56, 0]]

その他の例

custom_call

セマンティクス

inputscalled_computations を受け取って results を生成する、実装定義のオペレーション call_target_name をカプセル化します。has_side_effectbackend_configapi_version を使用して、実装定義のメタデータを追加できます。

現時点では、このオペレーションにはかなり整理されていないメタデータのコレクションが含まれています。これは、XLA コンパイラでの対応するオペレーションの有機的な進化を反映しています。今後、このメタデータを統合する予定です(#741)。

入力

ラベル 名前 タイプ
(I1) inputs 値の可変数
(I2) call_target_name string 型の定数
(I3) has_side_effect i1 型の定数
(I4) backend_config string 型の定数または属性辞書
(I5) api_version si32 型の定数
(I6) called_computations string 型の可変長定数

出力

名前 タイプ
results 可変長の値の数

%results = "stablehlo.custom_call"(%input0) {
  call_target_name = "foo",
  has_side_effect = false,
  backend_config = {bar = 42 : i32},
  api_version = 4 : i32,
  called_computations = [@foo]
} : (tensor<f64>) -> tensor<f64>

割る

セマンティクス

除数 lhs テンソルと除算子 rhs テンソルの要素ごとの除算を行い、result テンソルを生成します。要素のタイプに応じて、次の処理を行います。

  • 整数の場合: 小数部を破棄して代数的な商を生成する整数除算。
  • 浮動小数点数の場合: IEEE-754 の division
  • 複素数の場合: 複素除算。
  • 量子化された型の場合:
    • dequantize_op_quantize(divide, lhs, rhs, type(result))

入力

ラベル 名前 タイプ 制約
(I1) lhs 整数型、浮動小数点型、複合型のテンソル、またはテンソルごとの量子化テンソル (C1)
(I2) rhs 整数型、浮動小数点型、複合型のテンソル、またはテンソルごとの量子化テンソル (C1)

出力

名前 タイプ 制約
result 整数型、浮動小数点型、複合型のテンソル、またはテンソルごとの量子化テンソル (C1)

制約

  • (C1)baseline_type(lhs) = baseline_type(rhs) = baseline_type(result)

// %lhs: [17.1, -17.1, 17.1, -17.1]
// %rhs: [3.0, 3.0, -3.0, -3.0]
%result = "stablehlo.divide"(%lhs, %rhs) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
// %result: [5.66666651, -5.66666651, -5.66666651, 5.66666651]

その他の例

dot_general

セマンティクス

lhs のスライスと rhs のスライスの間のドット積を計算し、result テンソルを生成します。

正式には result[result_index] = dot_product です。ここで、

  • lhs_result_dimensions = [d for d in axes(lhs) and d not in lhs_batching_dimensions and d not in lhs_contracting_dimensions]
  • rhs_result_dimensions = [d for d in axes(rhs) and d not in rhs_batching_dimensions and d not in rhs_contracting_dimensions]
  • result_batching_index + result_lhs_index + result_rhs_index = result_indexsize(result_batching_index) = size(lhs_batching_dimensions)size(result_lhs_index) = size(lhs_result_dimensions)size(result_rhs_index) = size(rhs_result_dimensions))。
  • transposed_lhs = transpose(lhs, lhs_batching_dimensions + lhs_result_dimensions + lhs_contracting_dimensions)
  • transposed_lhs_slice = slice(transposed_lhs, result_batching_index + result_lhs_index + [:, ..., :])
  • reshaped_lhs_slice = reshape(transposed_lhs_slice, dims(lhs, lhs_contracting_dimensions))
  • transposed_rhs = transpose(rhs, rhs_batching_dimensions + rhs_result_dimensions + rhs_contracting_dimensions)
  • transposed_rhs_slice = slice(transposed_rhs, result_batching_index + result_rhs_index + [:, ..., :])
  • reshaped_rhs_slice = reshape(transposed_rhs_slice, dims(rhs, rhs_contracting_dimensions))
  • dot_product = reduce( inputs=[multiply(reshaped_lhs_slice, reshaped_rhs_slice)], init_values=[constant(0, element_type(result))], dimensions=range(size(lhs_contracting_dimensions)), body=lambda x, y: add(x, y))

量子化された型の場合は、dequantize_op_quantize( lambda lhs, rhs: dot_general(lhs, rhs, lhs_batching_dimensions, rhs_batching_dimensions, lhs_contracting_dimensions, rhs_contracting_dimensions, precision_config), lhs, rhs, type(result)) を実行します。

ハイブリッド量子化タイプの場合は、hybrid_dequantize_then_op( lambda lhs, rhs: dot_general(lhs, rhs, lhs_batching_dimensions, rhs_batching_dimensions, lhs_contracting_dimensions, rhs_contracting_dimensions, precision_config), lhs, rhs) を実行します。

precision_config は、アクセラレータ バックエンドでの計算の速度と精度のトレードオフを制御します。これは次のいずれかです(現時点では、これらの列挙型値のセマンティクスは指定されていません。#755 で対応する予定です)。

  • DEFAULT: 計算速度は最も速いが、元の数値に対する近似値の精度は最も低い。
  • HIGH: 計算は遅くなりますが、元の数値に近い値をより正確に求めることができます。
  • HIGHEST: 計算に最も時間がかかりますが、元の数値に最も近い近似値が得られます。

DotAlgorithm は、ドット演算の実装に使用されるアルゴリズムの主なプロパティを定義します。これにより、精度も定義されます。アルゴリズム属性フィールドが設定されている場合、precision_configDEFAULT にする必要があります。デフォルト パラメータは実装で定義されるため、DotAlgorithms にデフォルト値はありません。そのため、すべてのドット アルゴリズム フィールドを None に設定して、空のドット アルゴリズムを指定し、代わりに precision_config 値を使用できます。

DotAlgorithm フィールドには次のフィールドが含まれます。

  • lhs_precision_typerhs_precision_type: オペレーションの LHS と RHS が丸められる精度。精度タイプは、入力と出力のストレージ タイプとは独立しています。
  • accumulation_type 累積に使用される精度。
  • lhs_component_countrhs_component_countnum_primitive_operations は、LHS や RHS を複数のコンポーネントに分解し、それらの値に対して複数の「プリミティブ」ドット演算を実行するアルゴリズムを実行する場合に適用されます。通常は、より高い精度をエミュレートします(例: 高精度の計算に bfloat16 AI データ型を活用: btf12_6x3x3)。分解のないアルゴリズムの場合、これらの値は 1 に設定する必要があります。
  • allow_imprecise_accumulation: 一部のステップで低精度の累積が許可されるかどうかを指定します(CUBLASLT_MATMUL_DESC_FAST_ACCUM など)。

DotAlgorithm 属性の例:

// Inputs are casted to tf32, and then accumulated in f32:
{lhs_precision_type = tf32,
 rhs_precision_type = tf32,
 accumulation_type = f32,
 lhs_component_count = 1,
 rhs_component_count = 1,
 num_primitive_operations = 1,
 allow_imprecise_accumulation = false}


// bf16_6x: each input is decomposed to 3 bf16 components, then 6 dot operations are done on those components, and the result is accumulated in f32.
{lhs_precision_type = bf16,
 rhs_precision_type = bf16,
 accumulation_type = f32,
 lhs_component_count = 3,
 rhs_component_count = 3,
 num_primitive_operations = 6,
 allow_imprecise_accumulation = false}


// Inputs are (casted to) f8e5m2, and we accumulate in f32, but for some steps we may accumulate in lower precision.
{lhs_precision_type = f8e5m2,
 rhs_precision_type = f8e5m2,
 accumulation_type = f32,
 lhs_component_count = 1,
 rhs_component_count = 1,
 num_primitive_operations = 1,
 allow_imprecise_accumulation = true}

どの組み合わせをサポートするかは、実装が決定します。一般に、各アルゴリズムが StableHLO のコンシューマによって各アクセラレータ タイプでサポートされているとは限りません。特定のアルゴリズムがサポートされていない場合は、代替のアルゴリズムにフォールバックするのではなく、エラーが発生します。StableHLO 検証ではベスト エフォート検証が提供され、どのハードウェアでもサポートされていないアルゴリズムが使用されるのを防ぎます。

サポートされているアルゴリズムの値については、xla_data.proto > Algorithm をご覧ください。チケット #2483 には、バックエンドによってサポートされているアルゴリズムに関する一元化されたドキュメントを作成する計画が記載されています。

入力

ラベル 名前 タイプ 制約
(I1) lhs テンソルまたはテンソルごとの量子化テンソル (C5 ~ C6)、(C9 ~ C10)、(C12 ~ C14)、(C17 ~ C18)、(C20)
(I2) rhs テンソルまたは量子化テンソル (C7~C10)、(C12~C20)
(I3) lhs_batching_dimensions si64 型の 1 次元テンソル定数 (C1)、(C3)、(C5)、(C9)、(C12)
(I4) rhs_batching_dimensions si64 型の 1 次元テンソル定数 (C1)、(C4)、(C7)、(C9)
(I5) lhs_contracting_dimensions si64 型の 1 次元テンソル定数 (C2)、(C3)、(C6)、(C10)
(I6) rhs_contracting_dimensions si64 型の 1 次元のテンソル定数 (C2)、(C4)、(C8)、(C10)、(C16)
(I7) precision_config DEFAULTHIGHHIGHEST の可変長の数値型 (C11)、(C21)
(I8) lhs_precision_type FloatType または TensorFloat32 (C21)
(I9) rhs_precision_type FloatType または TensorFloat32 (C21)
(I10) accumulation_type FloatType または TensorFloat32 (C21)
(I11) lhs_component_count si32 型の定数 (C21)、(C22)
(I12) rhs_component_count si32 型の定数 (C21)、(C23)
(I13) num_primitive_operations si32 型の定数 (C21)、(C24)
(I14) allow_imprecise_accumulation bool 型の定数 (C21)

出力

名前 タイプ 制約
result テンソルまたは量子化テンソル (C12)、(C14)、(C18 ~ C20)

制約

  • (C1)size(lhs_batching_dimensions) = size(rhs_batching_dimensions)
  • (C2)size(lhs_contracting_dimensions) = size(rhs_contracting_dimensions)
  • (C3)is_unique(lhs_batching_dimensions + lhs_contracting_dimensions)
  • (C4)is_unique(rhs_batching_dimensions + rhs_contracting_dimensions)
  • (C5)0 <= lhs_batching_dimensions < rank(lhs)
  • (C6)0 <= lhs_contracting_dimensions < rank(lhs)
  • (C7)0 <= rhs_batching_dimensions < rank(rhs)
  • (C8)0 <= rhs_contracting_dimensions < rank(rhs)
  • (C9)dim(lhs, lhs_batching_dimensions...) = dim(rhs, rhs_batching_dimensions...)
  • (C10)dim(lhs, lhs_contracting_dimensions...) = dim(rhs, rhs_contracting_dimensions...)
  • (C11)size(precision_config) = 2
  • (C12)shape(result) = dim(lhs, lhs_batching_dimensions) + dim(lhs, lhs_result_dimensions) + dim(rhs, rhs_result_dimensions)
  • オペレーションで量子化されていないテンソルを使用する場合:
    • (C13)element_type(lhs) = element_type(rhs)
  • オペレーションで量子化されたテンソルを使用する場合:
    • (C14)is_quantized(lhs) = is_quantized(result) and is_quantized(rhs)
    • (C15)zero_points(rhs) = 0
    • (C16)is_per_axis_quantized(rhs) の場合、quantization_dimension(rhs)rhs_contracting_dimensions にない。
    • is_quantized(lhs) の場合:
    • (C17)storage_type(lhs) = storage_type(rhs)
    • (C18)expressed_type(lhs) = expressed_type(rhs) = expressed_type(result)
    • (C19)is_per_tensor_quantized(rhs) の場合、is_per_tensor_quantized(result)
    • !is_quantized(lhs) の場合:
    • (C20)element_type(lhs) = expressed_type(rhs) = element_type(result)
  • !is_empty_algorithm(lhs_precision_type, rhs_precision_type, accumulation_type, lhs_component_count, rhs_component_count, num_primitive_operations allow_imprecise_accumulation) の場合:
    • (C21)precision_config... = DEFAULT
    • (C22)0 < lhs_component_count
    • (C23)0 < rhs_component_count
    • (C24)0 < num_primitive_operations

// %lhs: [
//        [[1, 2],
//         [3, 4]],
//        [[5, 6],
//         [7, 8]]
//       ]
// %rhs: [
//        [[1, 0],
//         [0, 1]],
//        [[1, 0],
//         [0, 1]]
//       ]
%result = "stablehlo.dot_general"(%lhs, %rhs) {
  dot_dimension_numbers = #stablehlo.dot<
    lhs_batching_dimensions = [0],
    rhs_batching_dimensions = [0],
    lhs_contracting_dimensions = [2],
    rhs_contracting_dimensions = [1]
  >,
  precision_config = [#stablehlo<precision DEFAULT>, #stablehlo<precision DEFAULT>],
  algorithm = #stablehlo.dot_algorithm<
    lhs_precision_type = tf32,
    rhs_precision_type = tf32,
    accumulation_type = f32,
    lhs_component_count = 1,
    rhs_component_count = 1,
    num_primitive_operations = 1,
    allow_imprecise_accumulation = false
  >
} : (tensor<2x2x2xi64>, tensor<2x2x2xi64>) -> tensor<2x2x2xi64>
// %result: [
//           [[1, 2],
//            [3, 4]],
//           [[5, 6],
//            [7, 8]]
//          ]

その他の例

dynamic_broadcast_in_dim

セマンティクス

このオペレーションは、broadcast_in_dim オペレーションと機能的には同じですが、結果の形状は output_dimensions を介して動的に指定されます。

このオペレーションでは、オプションの属性 known_expanding_dimensionsknown_nonexpanding_dimensions も受け入れ、ディメンションの展開動作に関する静的な知識を表現します。指定しない場合は、すべてのディメンションが拡張可能であると見なされます。

入力

ラベル 名前 タイプ 制約
(I1) operand テンソルまたは量子化テンソル (C1-C2)、(C5-C6)、(C9)
(I2) output_dimensions 整数型の 1 次元テンソル (C7)
(I3) broadcast_dimensions 整数型の 1 次元定数テンソル (C2~C6)
(I4) known_expanding_dimensions 整数型の 1 次元定数テンソル (C8~C9)
(I5) known_nonexpanding_dimensions 整数型の 1 次元定数テンソル (C8~C9)

出力

名前 タイプ 制約
result テンソルまたは量子化テンソル (C1)、(C3)、(C5~C7)

制約

  • (C1)element_type(result) は次のように定義されます。
    • element_type(operand)!is_per_axis_quantized(operand) の場合)。
    • element_type(operand)(ただし、quantization_dimension(operand)scales(operand)zero_points(operand) は、それぞれ quantization_dimension(result)scales(result)zero_points(result) と異なる場合があります)。
  • (C2)size(broadcast_dimensions) = rank(operand)
  • (C3)0 <= broadcast_dimensions < rank(result)
  • (C4)is_unique(broadcast_dimensions)
  • (C5)axes(operand) 内のすべての d について:
    • dim(operand, d) = 1 または
    • dim(operand, d) = dim(result, broadcast_dimensions[d])
  • (C6)is_per_axis_quantized(result) の場合:
    • quantization_dimension(result) = broadcast_dimensions[quantization_dimension(operand)]
    • dim(operand, quantization_dimension(operand)) = 1 の場合は scales(result)[i] = scales(operand)[0] and zero_points(result)[i] = zero_points(operand)[0] for i in range(dim(result, quantization_dimension(result))) です。
  • (C7)size(output_dimensions) = rank(result)
  • (C8)is_unique(known_expanding_dimensions + known_nonexpanding_dimensions)
  • (C9)0 <= known_expanding_dimensions < rank(operand)
  • (C10)0 <= known_nonexpanding_dimensions < rank(operand)

// %operand: [
//            [1, 2, 3]
//           ]
%operand = stablehlo.constant dense<[[1, 2, 3]]> : tensor<1x3xi64>
%output_dimensions = stablehlo.constant dense<[2, 3, 2]> : tensor<3xi64>
%result = "stablehlo.dynamic_broadcast_in_dim"(%operand, %output_dimensions) {
  broadcast_dimensions = array<i64: 2, 1>,
  known_expanding_dimensions = array<i64: 0>,
  known_nonexpanding_dimensions = array<i64: 1>
} : (tensor<1x3xi64>, tensor<3xi64>) -> tensor<2x3x2xi64>
// %result: [
//            [
//             [1, 1],
//             [2, 2],
//             [3, 3]
//            ],
//            [
//             [1, 1],
//             [2, 2],
//             [3, 3]
//            ]
//          ]

その他の例

dynamic_conv

セマンティクス

このオペレーションは、畳み込みオペレーションと機能的に同じですが、パディングは padding を介して動的に指定されます。

入力

ラベル 名前 タイプ 制約
(I1) lhs テンソルまたはテンソルごとの量子化テンソル (C1)、(C10 ~ C11)、(C14)(C25)、(C26 ~ C27)、(C30 ~ C31)、(C33)
(I2) rhs テンソルまたは量子化テンソル (C1)、(C14 ~ C16)、(C26 ~ C28)、(C30 ~ C33)
(I3) padding 整数型の 2 次元テンソル (C4)
(I4) window_strides si64 型の 1 次元テンソル定数 (C2 ~ C3)
(I5) lhs_dilation si64 型の 1 次元のテンソル定数 (C5-C6)
(I6) rhs_dilation si64 型の 1 次元テンソル定数 (C7 ~ C8)
(I7) window_reversal i1 型の 1 次元テンソル定数 (C9)
(I8) input_batch_dimension si64 型の定数 (C10)、(C13)
(I9) input_feature_dimension si64 型の定数 (C11)、(C13~C14)
(I10) input_spatial_dimensions si64 型の 1 次元テンソル定数 (C12)、(C13)
(I11) kernel_input_feature_dimension si64 型の定数 (C14)、(C18)
(I12) kernel_output_feature_dimension si64 型の定数 (C15~C16)、(C18)、(C28)
(I13) kernel_spatial_dimensions si64 型の 1 次元テンソル定数 (C17~C18)
(I14) output_batch_dimension si64 型の定数 (C20)
(I15) output_feature_dimension si64 型の定数 (C20)、(C29)
(I16) output_spatial_dimensions si64 型の 1 次元のテンソル定数 (C19 ~ C20)
(I17) feature_group_count si64 型の定数 (C11)、(C14)、(C16)、(C21)、(C23)
(I18) batch_group_count si64 型の定数 (C10)、(C15)、(C22)、(C23)
(I19) precision_config DEFAULTHIGHHIGHEST の可変列挙型の数 (C24)

出力

名前 タイプ 制約
result テンソルまたは量子化テンソル (C25~C27)、(C29)、(C31~C33)

制約

  • (C1)N = rank(lhs) = rank(rhs)
  • (C2)size(window_strides) = N - 2
  • (C3)0 < window_strides
  • (C4)shape(padding) = [N - 2, 2]
  • (C5)size(lhs_dilation) = N - 2
  • (C6)0 < lhs_dilation
  • (C7)size(rhs_dilation) = N - 2
  • (C8)0 < rhs_dilation
  • (C9)size(window_reversal) = N - 2
  • (C10)dim(lhs, input_batch_dimension) % batch_group_count = 0
  • (C11)dim(lhs, input_feature_dimension) % feature_group_count = 0
  • (C12)size(input_spatial_dimensions) = N - 2
  • (C13)input_dimensions = [input_batch_dimension] + input_spatial_dimensions + [input_feature_dimension] が与えられている場合:
    • is_unique(input_dimensions)
    • 0 <= input_dimensions < N
  • (C14)dim(rhs, kernel_input_feature_dimension) = dim(lhs, input_feature_dimension) / feature_group_count
  • (C15)dim(rhs, kernel_output_feature_dimension) % batch_group_count = 0
  • (C16)dim(rhs, kernel_output_feature_dimension) % feature_group_count = 0
  • (C17)size(kernel_spatial_dimensions) = N - 2
  • (C18)kernel_dimensions = kernel_spatial_dimensions + [kernel_input_feature_dimension] + [kernel_output_feature_dimension] が与えられている場合:
    • is_unique(kernel_dimensions)
    • 0 <= kernel_dimensions < N
  • (C19)size(output_spatial_dimensions) = N - 2
  • (C20)output_dimensions = [output_batch_dimension] + output_spatial_dimensions + [output_feature_dimension] が与えられている場合:
    • is_unique(output_dimensions)
    • 0 <= output_dimensions < N
  • (C21)0 < feature_group_count
  • (C22)0 < batch_group_count
  • (C23)feature_group_count = 1 or batch_group_count = 1
  • (C24)size(precision_config) = 2
  • (C25)dim(result, result_dim) は次のように定義されます。
    • result_dim = output_batch_dimension の場合、dim(lhs, input_batch_dimension) / batch_group_count
    • result_dim = output_feature_dimension の場合、dim(rhs, kernel_output_feature_dimension)
    • それ以外の場合は num_windows。ここで:
    • output_spatial_dimensions[spatial_dim] = result_dim
    • lhs_dim = input_spatial_dimensions[spatial_dim]
    • rhs_dim = kernel_spatial_dimensions[spatial_dim]
    • dilated_input_shape[lhs_dim] = dim(lhs, lhs_dim) = 0 ? 0 : (dim(lhs, lhs_dim) - 1) * lhs_dilation[spatial_dim] + 1
    • padded_input_shape[lhs_dim] = padding[spatial_dim, 0] + dilated_input_shape[lhs_dim] + padding[spatial_dim, 1]
    • dilated_window_shape[lhs_dim] = dim(rhs, rhs_dim) = 0 ? 0 : (dim(rhs, rhs_dim) - 1) * rhs_dilation[spatial_dim] + 1
    • is_empty_window[lhs_dim] = padded_input_shape[lhs_dim] = 0 || dilated_window_shape[lhs_dim] > padded_input_shape[lhs_dim]
    • num_windows = is_empty_window[lhs_dim] ? 0 : floor((padded_input_shape[lhs_dim] - dilated_window_shape[lhs_dim]) / window_strides[spatial_dim]) + 1
  • (C26)rank(result) = N
  • オペレーションで量子化されていないテンソルを使用する場合:
    • (C27)element_type(lhs) = element_type(rhs) = element_type(result)
  • オペレーションで量子化されたテンソルを使用する場合:
    • (C28)is_quantized(lhs) = is_quantized(result) and is_quantized(rhs)
    • (C29)is_per_axis_quantized(rhs) の場合、quantization_dimension(rhs) = kernel_output_feature_dimension
    • (C30)is_per_axis_quantized(result) の場合、quantization_dimension(result) = output_feature_dimension
    • is_quantized(lhs) の場合:
    • (C31)storage_type(lhs) = storage_type(rhs)
    • (C32)expressed_type(lhs) = expressed_type(rhs) = expressed_type(result)
    • (C33)is_per_tensor_quantized(rhs) の場合は is_per_tensor_quantized(result) です。
    • !is_quantized(lhs) の場合:
    • (C34)element_type(lhs) = expressed_type(rhs) = element_type(result)

// %lhs: [[
//        [[1], [2], [5], [6]],
//        [[3], [4], [7], [8]],
//        [[10], [11], [14], [15]],
//        [[12], [13], [16], [17]]
//      ]]
//
// %rhs: [
//         [[[1]], [[1]], [[1]]],
//         [[[1]], [[1]], [[1]]],
//         [[[1]], [[1]], [[1]]]
//        ]
// %padding: [[1, 1],
//            [1, 1]]
%result = "stablehlo.dynamic_conv"(%lhs, %rhs, %padding) {
  window_strides = array<i64: 4, 4>,
  lhs_dilation = array<i64: 2, 2>,
  rhs_dilation = array<i64: 1, 1>,
  window_reversal = array<i1: false, false>,
  dimension_numbers = #stablehlo.conv<raw
    input_batch_dimension = 0,
    input_feature_dimension = 3,
    input_spatial_dimensions = [0, 1],
    kernel_input_feature_dimension = 2,
    kernel_output_feature_dimension = 3,
    kernel_spatial_dimensions = [0, 1],
    output_batch_dimension = 0,
    output_feature_dimension = 3,
    output_spatial_dimensions = [1, 2]
  >,
  feature_group_count = 1 : i64,
  batch_group_count = 1 : i64,
  precision_config = [#stablehlo<precision DEFAULT>, #stablehlo<precision DEFAULT>]
} : (tensor<1x4x4x1xi64>, tensor<3x3x1x1xi64>, tensor<2x2xi64>) -> tensor<1x2x2x1xi64>
// %result: [[
//            [[1], [5]],
//            [[10], [14]]
//          ]]

その他の例

dynamic_gather

セマンティクス

このオペレーションは、slice_sizes が値として動的に指定される gather オペレーションと機能的に同じです。

入力

ラベル 名前 タイプ 制約
(I1) operand テンソルまたはテンソルごとの量子化テンソル (C1)、(C7)、(C10~C12)、(C14)
(I2) start_indices 整数型のテンソル (C2)、(C3)、(C13)
(I3) slice_sizes 整数型の 1 次元テンソル (C8)、(C11~C13)
(I4) offset_dims si64 型の 1 次元のテンソル定数 (C1)、(C4 ~ C5)、(C13)
(I5) collapsed_slice_dims si64 型の 1 次元テンソル定数 (C1)、(C6~C8)、(C13)
(I6) start_index_map si64 型の 1 次元のテンソル定数 (C3)、(C9)、(C10)
(I7) index_vector_dim si64 型の定数 (C2)、(C3)、(C13)
(I8) indices_are_sorted i1 型の定数

出力

名前 タイプ 制約
result テンソルまたはテンソルごとの量子化テンソル (C5)、(C13 ~ C14)

制約

  • (C1)rank(operand) = size(offset_dims) + size(collapsed_slice_dims)
  • (C2)0 <= index_vector_dim <= rank(start_indices)
  • (C3)size(start_index_map) = index_vector_dim < rank(start_indices) ? dim(start_indices, index_vector_dim) : 1
  • (C4)is_unique(offset_dims) and is_sorted(offset_dims)
  • (C5)0 <= offset_dims < rank(result)
  • (C6)is_unique(collapsed_slice_dims) and is_sorted(collapsed_slice_dims)
  • (C7)0 <= collapsed_slice_dims < rank(operand)
  • (C8)slice_sizes[collapsed_slice_dims...] <= 1
  • (C9)is_unique(start_index_map)
  • (C10)0 <= start_index_map < rank(operand)
  • (C11)size(slice_sizes) = rank(operand)
  • (C12)0 <= slice_sizes <= shape(operand)
  • (C13)shape(result) = combine(batch_dim_sizes, offset_dim_sizes) ここで:
    • batch_dim_sizes = shape(start_indices) とほぼ同じですが、index_vector_dim に対応する start_indices のディメンション サイズは含まれません。
    • offset_dim_sizes = shape(slice_sizes)。ただし、collapsed_slice_dims に対応する slice_sizes のディメンション サイズは含まれません。
    • combine は、batch_dims に対応する軸に batch_dim_sizes を配置し、offset_dims に対応する軸に offset_dim_sizes を配置します。
  • (C14)element_type(operand) = element_type(result)

// %operand: [
//            [[1, 2], [3, 4], [5, 6], [7, 8]],
//            [[9, 10],[11, 12], [13, 14], [15, 16]],
//            [[17, 18], [19, 20], [21, 22], [23, 24]]
//           ]
// %start_indices: [
//                  [[0, 0], [1, 0], [2, 1]],
//                  [[0, 1], [1, 1], [0, 2]]
//                 ]
// %slize_sizes: [1, 2, 2]
%result = "stablehlo.dynamic_gather"(%operand, %start_indices, %slize_sizes) {
  dimension_numbers = #stablehlo.gather<
    offset_dims = [2, 3],
    collapsed_slice_dims = [0],
    start_index_map = [1, 0],
    index_vector_dim = 2>,
  indices_are_sorted = false
} : (tensor<3x4x2xi64>, tensor<2x3x2xi64>, tensor<3xi64>) -> tensor<2x3x2x2xi64>
// %result: [
//            [
//              [[1, 2], [3, 4]],
//              [[3, 4], [5, 6]],
//              [[13, 14], [15, 16]]
//            ],
//            [
//              [[9, 10], [11, 12]],
//              [[11, 12], [13, 14]],
//              [[17, 18], [19, 20]]
//            ]
//          ]

その他の例

dynamic_iota

セマンティクス

この演算は iota op と機能的に同じですが、結果の形状は output_shape によって動的に指定されます。

入力

ラベル 名前 タイプ 制約
(I1) output_shape 整数型の 1 次元テンソル (C1)、(C2)
(I2) iota_dimension si64 (C1)

出力

名前 タイプ 制約
result 整数型、浮動小数点型、複合型のテンソル、またはテンソルごとの量子化テンソル (C2)

制約

  • (C1)0 <= iota_dimension < size(output_shape)
  • (C2)rank(result) = size(output_shape)

%output_shape = stablehlo.constant dense<[4, 5]> : tensor<2xi64>
%result = "stablehlo.dynamic_iota"(%output_shape) {
  iota_dimension = 0 : i64
} : (tensor<2xi64>) -> tensor<4x5xi64>
// %result: [
//           [0, 0, 0, 0, 0],
//           [1, 1, 1, 1, 1],
//           [2, 2, 2, 2, 2],
//           [3, 3, 3, 3, 3]
//          ]

その他の例

dynamic_pad

セマンティクス

このオペレーションは、pad オペレーションと機能的には同じですが、edge_padding_lowedge_padding_highinterior_padding が値として動的に指定されます。

入力

ラベル 名前 タイプ 制約
(I1) operand テンソルまたはテンソルごとの量子化テンソル (C1)、(C2)、(C4)
(I2) padding_value 0 次元テンソルまたはテンソルごとの量子化テンソル (C1)
(I3) edge_padding_low 整数型の 1 次元テンソル (C1)、(C4)
(I4) edge_padding_high 整数型の 1 次元テンソル (C1)、(C4)
(I5) interior_padding 整数型の 1 次元テンソル (C2-C4)

出力

名前 タイプ 制約
result テンソルまたはテンソルごとの量子化テンソル (C3-C6)

制約

  • (C1)element_type(operand) = element_type(padding_value) = element_type(result)
  • (C2)size(edge_padding_low) = size(edge_padding_high) = size(interior_padding) = rank(operand)
  • (C3)0 <= interior_padding
  • (C4)shape(result) = shape(operand) + edge_padding_low + max(shape(operand) - 1, 0) * interior_padding + edge_padding_high

// %operand: [
//            [1, 2, 3],
//            [4, 5, 6]
//           ]
// %padding_value: 0
// %edge_padding_low: [0, 1]
// %edge_padding_high: [2, 1]
// %interior_padding: [1, 2]
%result = "stablehlo.dynamic_pad"(%operand, %padding_value,
  %edge_padding_low, %edge_padding_high, %interior_padding
) : (tensor<2x3xi64>, tensor<i64>, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor<5x9xi64>
// %result: [
//           [0, 1, 0, 0, 2, 0, 0, 3, 0],
//           [0, 0, 0, 0, 0, 0, 0, 0, 0],
//           [0, 4, 0, 0, 5, 0, 0, 6, 0],
//           [0, 0, 0, 0, 0, 0, 0, 0, 0],
//           [0, 0, 0, 0, 0, 0, 0, 0, 0]
//          ]

その他の例

dynamic_reshape

セマンティクス

このオペレーションは、reshape オペレーションと機能的には同じですが、結果の形状は output_shape を介して動的に指定されます。

入力

ラベル 名前 タイプ 制約
(I1) operand テンソルまたは量子化テンソル (C1~C3)
(I2) output_shape 整数型の 1 次元テンソル (C4)

出力

名前 タイプ 制約
result テンソルまたは量子化テンソル (C1~C4)

制約

  • (C1)element_type(result) は次のように定義されます。
    • element_type(operand)!is_per_axis_quantized(operand) の場合)。
    • element_type(operand)。ただし、quantization_dimension(operand)quantization_dimension(result) が異なる場合があります。
  • (C2)size(operand) = size(result)
  • (C3)is_per_axis_quantized(operand) の場合:
    • reduce(dims(operand, [0, 1, ..., quantization_dimension(operand) - 1]), init_values=1, dimensions=[0], body=lambda x, y: x * y) = reduce(dims(result, [0, 1, ..., quantization_dimension(result) - 1]), init_values=1, dimensions=[0], body=lambda x, y: x * y)
    • dim(operand, quantization_dimension(operand)) = dim(result, quantization_dimension(result))
    • reduce(dims(operand, [quantization_dimension(operand) + 1, ..., rank(operand) - 1]), init_values=1, dimensions=[0], body=lambda x, y: x * y) = reduce(dims(result, [quantization_dimension(result) + 1, ..., rank(result) - 1]), init_values=1, dimensions=[0], body=lambda x, y: x * y)
  • (C4)size(output_shape) = rank(result)

// %operand: [[1, 2, 3], [4, 5, 6]]
// %output_shape: [3, 2]
%result = "stablehlo.dynamic_reshape"(%operand, %output_shape) : (tensor<2x3xi64>, tensor<2xi64>) -> tensor<3x2xi64>
// %result: [[1, 2], [3, 4], [5, 6]]

その他の例

dynamic_slice

セマンティクス

動的に計算された開始インデックスを使用して operand からスライスを取り出し、result テンソルを生成します。start_indices には、調整の対象となる各ディメンションのスライスの開始インデックスが含まれ、slice_sizes には各ディメンションのスライスのサイズが含まれます。より正式には、result[result_index] = operand[operand_index] です。ここで:

  • adjusted_start_indices = clamp(0, start_indices, shape(operand) - slice_sizes)
  • operand_index = adjusted_start_indices + result_index

入力

ラベル 名前 タイプ 制約
(I1) operand テンソルまたはテンソルごとの量子化テンソル (C1)、(C2)、(C4)
(I2) start_indices 整数型の 0 次元テンソルの可変長の数 (C2)、(C3)
(I3) slice_sizes si64 型の 1 次元テンソル定数 (C2)、(C4)、(C5)

出力

名前 タイプ 制約
result テンソルまたはテンソルごとの量子化テンソル (C1)、(C5)

制約

  • (C1)element_type(operand) = element_type(result)
  • (C2)size(start_indices) = size(slice_sizes) = rank(operand)
  • (C3)same(type(start_indices...))
  • (C4)0 <= slice_sizes <= shape(operand)
  • (C5)shape(result) = slice_sizes

// %operand: [
//            [0, 0, 1, 1],
//            [0, 0, 1, 1],
//            [0, 0, 0, 0],
//            [0, 0, 0, 0]
//           ]
// %start_indices0: -1
// %start_indices1: 3
%result = "stablehlo.dynamic_slice"(%operand, %start_indices0, %start_indices1) {
  slice_sizes = array<i64: 2, 2>
} : (tensor<4x4xi32>, tensor<i64>, tensor<i64>) -> tensor<2x2xi32>
// %result: [
//           [1, 1],
//           [1, 1]
//          ]

その他の例

dynamic_update_slice

セマンティクス

start_indices で始まるスライスが update の値で更新される点を除き、operand テンソルと同じ result テンソルを生成します。より正式には、result[result_index] は次のように定義されます。

  • update[update_index]: 0 <= update_index < shape(update) の場合。条件は次のとおりです。
    • adjusted_start_indices = clamp(0, start_indices, shape(operand) - shape(update))
    • update_index = result_index - adjusted_start_indices
  • それ以外の場合は operand[result_index]

入力

ラベル 名前 タイプ 制約
(I1) operand テンソルまたはテンソルごとの量子化テンソル (C1 ~ C4)、(C6)
(I2) update テンソルまたはテンソルごとの量子化テンソル (C2)、(C3)、(C6)
(I3) start_indices 整数型の 0 次元テンソルの可変長の数 (C4)、(C5)

出力

名前 タイプ 制約
result テンソルまたはテンソルごとの量子化テンソル (C1)

制約

  • (C1)type(operand) = type(result)
  • (C2)element_type(update) = element_type(operand)
  • (C3)rank(update) = rank(operand)
  • (C4)size(start_indices) = rank(operand)
  • (C5)same(type(start_indices...))
  • (C6)0 <= shape(update) <= shape(operand)

// %operand: [
//            [1, 1, 0, 0],
//            [1, 1, 0, 0],
//            [1, 1, 1, 1],
//            [1, 1, 1, 1]
//           ]
// %update: [
//           [1, 1],
//           [1, 1]
//          ]
// %start_indices0: -1
// %start_indices1: 3
%result = "stablehlo.dynamic_update_slice"(%operand, %update, %start_indices0, %start_indices1)
  : (tensor<4x4xi32>, tensor<2x2xi32>, tensor<i64>, tensor<i64>) -> tensor<4x4xi32>
// %result: [
//           [1, 1, 1, 1],
//           [1, 1, 1, 1],
//           [1, 1, 1, 1],
//           [1, 1, 1, 1]
//          ]

その他の例

指数

セマンティクス

operand テンソルの要素ごとの指数演算を実行し、result テンソルを生成します。要素のタイプに応じて、次の処理を行います。

  • 浮動小数点数の場合: IEEE-754 の exp
  • 複素数の場合: 複素指数関数。
  • 量子化された型の場合: dequantize_op_quantize(exponential, operand, type(result))

入力

ラベル 名前 タイプ 制約
(I1) operand 浮動小数点型または複合型のテンソル、またはテンソルごとの量子化テンソル (C1)

出力

名前 タイプ 制約
result 浮動小数点型または複合型のテンソル、またはテンソルごとの量子化テンソル (C1)

制約

  • (C1)baseline_type(operand) = baseline_type(result)

// %operand: [[0.0, 1.0], [2.0, 3.0]]
%result = "stablehlo.exponential"(%operand) : (tensor<2x2xf64>) -> tensor<2x2xf64>
// %result: [[1.0, 2.7182818284590451], [7.3890560989306504, 20.085536923187668]]

その他の例

exponential_minus_one

セマンティクス

operand テンソルに対して要素ごとの指数マイナス 1 演算を実行し、result テンソルを生成します。要素のタイプに応じて、次の操作を行います。

  • 浮動小数点数: IEEE-754 の expm1
  • 複素数の場合: 複素指数 - 1。
  • 量子化された型の場合: dequantize_op_quantize(exponential_minus_one, operand, type(result))

入力

ラベル 名前 タイプ 制約
(I1) operand 浮動小数点型または複合型のテンソル、またはテンソルごとの量子化テンソル (C1)

出力

名前 タイプ 制約
result 浮動小数点型または複合型のテンソル、またはテンソルごとの量子化テンソル (C1)

制約

  • (C1)baseline_type(operand) = baseline_type(result)

// %operand: [0.0, 1.0]
%result = "stablehlo.exponential_minus_one"(%operand) : (tensor<2xf64>) -> tensor<2xf64>
// %result: [0.0, 1.71828187]

その他の例

fft

セマンティクス

実数と複素数の入力 / 出力の正規化と逆正規化を行います。

fft_type は、次のいずれかです。

  • FFT: 複雑な FFT を転送します。
  • IFFT: 複素数から複素数への逆 FFT。
  • RFFT: 実数から複素数への FFT を前方処理します。
  • IRFFT: 実数から複素数への逆 FFT(複素数を受け取り、実数を返します)。

より形式的には、複雑な型の 1 次元テンソルを入力として受け取る関数 fft は、出力と同じ型の 1 次元テンソルを生成し、離散フーリエ変換を計算します。

fft_type = FFT の場合、result は、L = size(fft_length) の連続した L 計算の最終結果として定義されます。たとえば、L = 3 の場合は次のようになります。

  • result1[i0, ..., :] = fft(operand[i0, ..., :])
  • result2[i0, ..., :, iR-1] = fft(result1[i0, ..., :, iR-1])
  • result[i0, ..., :, iR-2, iR-1] = fft(result2[i0, ..., :, iR-2, iR-1])

さらに、同じ型シグネチャを持ち、fft の逆数を計算する関数 ifft があるとします。

fft_type = IFFT の場合、resultfft_type = FFT の計算の逆数として定義されます。たとえば、L = 3 の場合は次のようになります。

  • result1[i0, ..., :, iR-2, iR-1] = ifft(operand[i0, ..., :, iR-2, iR-1])
  • result2[i0, ..., :, iR-1] = ifft(result1[i0, ..., :, iR-1])
  • result[i0, ..., :] = ifft(result2[i0, ..., :])

さらに、浮動小数点型の 1 次元テンソルを取る関数 rfft は、同じ浮動小数点セマンティクスの複合型の 1 次元テンソルを生成し、次のように動作します。

  • rfft(real_operand) = truncated_result
  • complex_operand... = (real_operand..., 0.0)
  • complex_result = fft(complex_operand)
  • truncated_result = complex_result[:(rank(complex_result) / 2 + 1)]

(実数オペランドに対して離散フーリエ変換が計算される場合、結果の最初の N/2 + 1 要素が結果の残りの部分を明確に定義するため、冗長な要素の計算を避けるために rfft の結果は切り捨てられます)。

fft_type = RFFT の場合、result は、L = size(fft_length) の連続した L 計算の最終結果として定義されます。たとえば、L = 3 の場合は次のようになります。

  • result1[i0, ..., :] = rfft(operand[i0, ..., :])
  • result2[i0, ..., :, iR-1] = fft(result1[i0, ..., :, iR-1])
  • result[i0, ..., :, iR-2, iR-1] = fft(result2[i0, ..., :, iR-2, iR-1])

最後に、同じ型のシグネチャを持つ関数 irfft が与えられ、rfft の逆数を計算します。

fft_type = IRFFT の場合、resultfft_type = RFFT の計算の逆数として定義されます。たとえば、L = 3 の場合は次のようになります。

  • result1[i0, ..., :, iR-2, iR-1] = ifft(operand[i0, ..., :, iR-2, iR-1])
  • result2[i0, ..., :, iR-1] = ifft(result1[i0, ..., :, iR-1])
  • result[i0, ..., :] = irfft(result2[i0, ..., :])

入力

ラベル 名前 タイプ 制約
(I1) operand 浮動小数点型または複合型のテンソル (C1)、(C2)、(C4)、(C5)
(I2) fft_type FFTIFFTRFFTIRFFT の列挙型 (C2)、(C5)
(I3) fft_length si64 型の 1 次元のテンソル定数 (C1)、(C3)、(C4)

出力

名前 タイプ 制約
result 浮動小数点数または複素数のテンソル (C2)、(C4)、(C5)

制約

  • (C1)size(fft_length) <= rank(operand)
  • (C2)operand 要素と result 要素型の関係はさまざまです。
    • fft_type = FFTelement_type(operand)element_type(result) が同じ複合型の場合。
    • fft_type = IFFTelement_type(operand)element_type(result) が同じ複合型の場合。
    • fft_type = RFFT の場合、element_type(operand) は浮動小数点型で、element_type(result) は同じ浮動小数点セマンティクスの複合型です。
    • fft_type = IRFFT の場合、element_type(operand) は複合型であり、element_type(result) は同じ浮動小数点セマンティクスの浮動小数点型です。
  • (C3)1 <= size(fft_length) <= 3
  • (C4)operandresult の間に浮動小数点型のテンサー real がある場合、shape(real)[-size(fft_length):] = fft_length
  • (C5)次の点を除き shape(result) = shape(operand)
    • fft_type = RFFT の場合は dim(result, -1) = dim(operand, -1) = 0 ? 0 : dim(operand, -1) / 2 + 1
    • fft_type = IRFFT の場合は dim(operand, -1) = dim(result, -1) = 0 ? 0 : dim(result, -1) / 2 + 1

// %operand: [(1.0, 0.0), (0.0, 0.0), (0.0, 0.0), (0.0, 0.0)]
%result = "stablehlo.fft"(%operand) {
  fft_type = #stablehlo<fft_type FFT>,
  fft_length = array<i64: 4>
} : (tensor<4xcomplex<f32>>) -> tensor<4xcomplex<f32>>
// %result: [(1.0, 0.0), (1.0, 0.0), (1.0, 0.0), (1.0, 0.0)]

floor

セマンティクス

operand テンソルの要素ごとの床演算を行い、result テンソルを生成します。IEEE-754 仕様の roundToIntegralTowardNegative オペレーションを実装します。量子化型の場合は、dequantize_op_quantize(floor, operand, type(result)) を実行します。

入力

ラベル 名前 タイプ 制約
(I1) operand 浮動小数点型のテンソルまたはテンソルごとの量子化テンソル (C1)

出力

名前 タイプ 制約
result 浮動小数点型のテンソルまたはテンソルごとの量子化テンソル (C1)

制約

  • (C1)baseline_type(operand) = baseline_type(result)

// %operand: [-0.8166, -0.2530, 0.2530, 0.8166, 2.0]
%result = "stablehlo.floor"(%operand) : (tensor<5xf32>) -> tensor<5xf32>
// %result: [-1.0, -1.0, 0.0, 0.0, 2.0]

その他の例

集める

セマンティクス

start_indices で指定されたオフセットから operand テンソルのスライスを収集し、result テンソルを生成します。

次の図は、result の要素が operand の要素にどのようにマッピングされるかを、具体的な例を使用して示しています。この図では、いくつかの result インデックスの例を選択し、それらが対応する operand インデックスについて詳しく説明しています。

収集

より正式には、result[result_index] = operand[operand_index] です。ここで:

  • batch_dims = [d for d in axes(result) and d not in offset_dims]
  • batch_index = result_index[batch_dims...]
  • start_index は次のように定義されます。
    • start_indices[bi0, ..., :, ..., biN]。ここで、bibatch_index 内の個々の要素で、:index_vector_dim インデックスに挿入されます(index_vector_dim < rank(start_indices) の場合)。
    • それ以外の場合は [start_indices[batch_index]]
  • axes(operand)d_operand は、
    • d_operand = start_index_map[d_start] の場合、full_start_index[d_operand] = clamp(start_index[d_start], 0, dim(operand, d_operand) - slice_sizes[d_operand])
    • そうでない場合は full_start_index[d_operand] = 0
  • axes(operand)d_operand は、
    • full_batching_index[d_operand] = batch_index[d_start - (d_start < index_vector_dim ? 0 : 1)] d_operand = operand_batching_dims[i_batching]d_start = start_indices_batching_dims[i_batching] の場合。
    • それ以外の場合は full_batching_index[d_operand] = 0
  • offset_index = result_index[offset_dims...]
  • full_offset_index = [oi0, ..., 0, ..., oiN]。ここで、oioffset_index の個々の要素であり、0collapsed_slice_dimsoperand_batching_dims のインデックスに挿入されます。
  • operand_index = full_start_index + full_batching_index + full_offset_index

indices_are_sortedtrue の場合、実装では start_indicesstart_index_map に関して並べ替えられていると想定できます。それ以外の場合、動作は未定義です。より正式には、indices(result) のすべての i1 < i2 について、full_start_index(i1) <= full_start_index(i2) です。

入力

ラベル 名前 タイプ 制約
(I1) operand テンソルまたはテンソルごとの量子化テンソル (C1)、(C8)、(C11)、(C17)、(C19 ~ C21)、(C23)
(I2) start_indices 整数型のテンソル (C2~C3)、(C14)、(C17)、(C22)
(I3) offset_dims si64 型の 1 次元テンソル定数 (C1)、(C4 ~ C5)、(C22)
(I4) collapsed_slice_dims si64 型の 1 次元のテンソル定数 (C1)、(C6 ~ C9)、(C22)
(I5) operand_batching_dims si64 型の 1 次元テンソル定数 (C1)、(C6)、(C10 ~ C12)、(C16 ~ C18)、(C22)
(I6) start_indices_batching_dims si64 型の 1 次元テンソル定数 (C13 ~ C17)
(I7) start_index_map si64 型の 1 次元テンソル定数 (C3)、(C18~C19)
(I8) index_vector_dim si64 型の定数 (C2 ~ C3)、(C15)、(C22)
(I9) slice_sizes si64 型の 1 次元のテンソル定数 (C9)、(C12)、(C20~C22)
(I10) indices_are_sorted i1 型の定数

出力

名前 タイプ 制約
result テンソルまたはテンソルごとの量子化テンソル (C5)、(C22 ~ C23)

制約

  • (C1)rank(operand) = size(offset_dims) + size(collapsed_slice_dims) + size(operand_batching_dims)
  • (C2)0 <= index_vector_dim <= rank(start_indices)
  • (C3)size(start_index_map) = index_vector_dim < rank(start_indices) ? dim(start_indices, index_vector_dim) : 1
  • (C4)is_unique(offset_dims) and is_sorted(offset_dims)
  • (C5)0 <= offset_dims < rank(result)
  • (C6)is_unique(concatenate(collapsed_slice_dims, operand_batching_dims))
  • (C7)is_sorted(collapsed_slice_dims)
  • (C8)0 <= collapsed_slice_dims < rank(operand)
  • (C9)slice_sizes[collapsed_slice_dims...] <= 1
  • (C10)is_sorted(operand_batching_dims)
  • (C11)0 <= operand_batching_dims < rank(operand)
  • (C12)slice_sizes[operand_batching_dims...] <= 1
  • (C13)is_unique(start_indices_batching_dims)
  • (C14)0 <= start_indices_batching_dims < rank(start_indices)
  • (C15)index_vector_dim not in start_indices_batching_dims
  • (C16)size(operand_batching_dims) == size(start_indices_batching_dims)
  • (C17)dim(operand, operand_batching_dims...) = dim(start_indices, start_indices_batching_dims...)
  • (C18)is_unique(concatenate(start_index_map, operand_batching_dims))
  • (C19)0 <= start_index_map < rank(operand)
  • (C20)size(slice_sizes) = rank(operand)
  • (C21)0 <= slice_sizes <= shape(operand)
  • (C22)shape(result) = combine(batch_dim_sizes, offset_dim_sizes)。ここで:
    • batch_dim_sizes = shape(start_indices) とほぼ同じですが、index_vector_dim に対応する start_indices のディメンション サイズは含まれません。
    • offset_dim_sizes = slice_sizes。ただし、collapsed_slice_dimsoperand_batching_dims に対応する slice_sizes のディメンション サイズは含まれません。
    • combine は、batch_dims に対応する軸に batch_dim_sizes を配置し、offset_dims に対応する軸に offset_dim_sizes を配置します。
  • (C23)element_type(operand) = element_type(result)

// %operand: [
//            [
//             [[1, 2], [3, 4], [5, 6], [7, 8]],
//             [[9, 10],[11, 12], [13, 14], [15, 16]],
//             [[17, 18], [19, 20], [21, 22], [23, 24]]
//            ],
//            [
//             [[25, 26], [27, 28], [29, 30], [31, 32]],
//             [[33, 34], [35, 36], [37, 38], [39, 40]],
//             [[41, 42], [43, 44], [45, 46], [47, 48]]
//            ]
//           ]
// %start_indices: [
//                  [
//                   [[0, 0], [1, 0], [2, 1]],
//                   [[0, 1], [1, 1], [0, 9]]
//                  ],
//                  [
//                   [[0, 0], [2, 1], [2, 2]],
//                   [[1, 2], [0, 1], [1, 0]]
//                  ]
//                 ]
%result = "stablehlo.gather"(%operand, %start_indices) {
  dimension_numbers = #stablehlo.gather<
    offset_dims = [3, 4],
    collapsed_slice_dims = [1],
    operand_batching_dims = [0],
    start_indices_batching_dims = [1],
    start_index_map = [2, 1],
    index_vector_dim = 3>,
  slice_sizes = array<i64: 1, 1, 2, 2>,
  indices_are_sorted = false
} : (tensor<2x3x4x2xi32>, tensor<2x2x3x2xi64>) -> tensor<2x2x3x2x2xi32>
// %result: [
//           [
//            [
//             [[1, 2], [3, 4]],
//             [[3, 4], [5, 6]],
//             [[13, 14], [15, 16]]
//            ],
//            [
//             [[33, 34], [35, 36]],
//             [[35, 36], [37, 38]],
//             [[41, 42], [43, 44]]
//            ]
//           ],
//           [
//            [
//             [[1, 2], [3, 4]],
//             [[13, 14], [15, 16]],
//             [[21, 22], [23, 24]]
//            ],
//            [
//             [[43, 44], [45, 46]],
//             [[33, 34], [35, 36]],
//             [[27, 28], [29, 30]]
//            ]
//           ]
//          ]

その他の例

get_dimension_size

セマンティクス

operand の指定された dimension のサイズを生成します。正式には result = dim(operand, dimension) です。セマンティクスは、型のシェイプ コンポーネントにのみ関係します。要素の型は任意です。

入力

ラベル 名前 タイプ 制約
(I1) operand テンソルまたは量子化テンソル (C1)
(I2) dimension si64 型の定数 (C1)

出力

名前 タイプ
result si32 型の 0 次元テンソル

制約

  • (C1)0 <= dimension < rank(operand)

// %operand: [[1, 2, 3], [4, 5, 6]]
%result = "stablehlo.get_dimension_size"(%operand) {
  dimension = 1 : i64
} : (tensor<2x3xi64>) -> tensor<i32>
// %result: 3

その他の例

get_tuple_element

セマンティクス

operand タプルの index 位置にある要素を抽出して result を生成します。正式には result = operand[index] です。

入力

ラベル 名前 タイプ 制約
(I1) operand tuple (C1)、(C2)
(I2) index si32 型の定数 (C1)、(C2)

出力

名前 タイプ 制約
result サポートされている任意のタイプ (C2)

制約

  • (C1)0 <= index < size(operand)
  • (C2)type(result) = tuple_element_types(operand)[index]

// %operand: ([1.0, 2.0], (3))
  index = 0 : i32
} : (tuple<tensor<2xf32>, tuple<tensor<i32>>>) -> tensor<2xf32>
// %result: [1.0, 2.0]

その他の例

~なら

セマンティクス

pred の値に応じて、true_branch または false_branch から関数を 1 つだけ実行することで、出力を生成します。(より正式には、result = pred ? true_branch() : false_branch())。

入力

ラベル 名前 タイプ 制約
(I1) pred i1 型の 0 次元テンソル
(I2) true_branch 関数 (C1~C3)
(I3) false_branch 関数 (C1)、(C2)

出力

名前 タイプ 制約
results テンソル、量子化テンソル、トークンの可変数 (C3)

制約

  • (C1)input_types(true_branch) = input_types(false_branch) = []
  • (C2)output_types(true_branch) = output_types(false_branch)
  • (C3)type(results...) = output_types(true_branch)

// %result_true_branch: 10
// %result_false_branch: 11
// %pred: true
%result = "stablehlo.if"(%pred) ({
  "stablehlo.return"(%result_true_branch) : (tensor<i32>) -> ()
}, {
  "stablehlo.return"(%result_false_branch) : (tensor<i32>) -> ()
}) : (tensor<i1>) -> tensor<i32>
// %result: 10

その他の例

imag

セマンティクス

operand から虚部を要素ごとに抽出し、result テンソルを生成します。より正式には、各要素 x について imag(x) = is_complex(x) ? imaginary_part(x) : constant(0, element_type(result)) です。

入力

ラベル 名前 タイプ 制約
(I1) operand 浮動小数点数または複素数のテンソル (C1)、(C2)

出力

名前 タイプ 制約
result 浮動小数点型のテンソル (C1)、(C2)

制約

  • (C1)shape(result) = shape(operand)
  • (C2)element_type(result) は次のように定義されます。
    • is_complex(operand) の場合、complex_element_type(element_type(operand))
    • それ以外の場合は element_type(operand)

// %operand: [(1.0, 2.0), (3.0, 4.0)]
%result = "stablehlo.imag"(%operand) : (tensor<2xcomplex<f32>>) -> tensor<2xf32>
// %result: [2.0, 4.0]

その他の例

インフィード

セマンティクス

インフィードからデータを読み取り、results を生成します。

infeed_config のセマンティクスは実装で定義されます。

results は、先頭にペイロード値、最後にトークンで構成されます。今後、明確性を高めるために、ペイロードとトークンを 2 つの個別の出力に分割する予定です(#670)。

入力

ラベル 名前 タイプ
(I1) token token
(I2) infeed_config string 型の定数

出力

名前 タイプ 制約
results テンソル、量子化テンソル、またはトークンの可変数 (C1~C3)

制約

  • (C1)0 < size(results)
  • (C2)is_empty(result[:-1]) または is_tensor(type(results[:-1]))
  • (C3)is_token(type(results[-1]))

// %token: !stablehlo.token
// infeed_queue[0]: [[1, 2], [3, 4]]
// infeed_queue[1]: [[5, 6], [7, 8]]
%results0:2 = "stablehlo.infeed"(%token) {
  infeed_config = ""
} : (!stablehlo.token) -> (tensor<2x2xi64>, !stablehlo.token)
// results0#0: [[1, 2], [3, 4]]
%results1:2 = "stablehlo.infeed"(%token) {
  infeed_config = ""
} : (!stablehlo.token) -> (tensor<2x2xi64>, !stablehlo.token)
// results1#0: [[5, 6], [7, 8]]

その他の例

iota

セマンティクス

output テンソルを、iota_dimension ディメンションに沿って 0 から順に増加する値で埋めます。より正式には、

output[output_index] = constant(is_quantized(output) ? quantize(output_index[iota_dimension], element_type(output)) : output_index[iota_dimension], element_type(output))

入力

ラベル 名前 タイプ 制約
(I1) iota_dimension si64 (C1)

出力

名前 タイプ 制約
output 整数型、浮動小数点型、複合型のテンソル、またはテンソルごとの量子化テンソル (C1)

制約

  • (C1)0 <= iota_dimension < rank(output)

%output = "stablehlo.iota"() {
  iota_dimension = 0 : i64
} : () -> tensor<4x5xi32>
// %output: [
//           [0, 0, 0, 0, 0],
//           [1, 1, 1, 1, 1],
//           [2, 2, 2, 2, 2],
//           [3, 3, 3, 3, 3]
//          ]

%output = "stablehlo.iota"() {
  iota_dimension = 1 : i64
} : () -> tensor<4x5xi32>
// %output: [
//           [0, 1, 2, 3, 4],
//           [0, 1, 2, 3, 4],
//           [0, 1, 2, 3, 4],
//           [0, 1, 2, 3, 4]
//          ]

その他の例

is_finite

セマンティクス

x の値が有限(+Inf、-Inf、NaN のいずれでもない)かどうかを要素ごとにチェックし、y テンソルを生成します。IEEE-754 仕様の isFinite オペレーションを実装します。量子化された型の場合、結果は常に true です。

入力

ラベル 名前 タイプ 制約
(I1) x 浮動小数点型のテンソルまたはテンソルごとの量子化テンソル (C1)

出力

名前 タイプ 制約
y ブール型のテンソル (C1)

制約

  • (C1)shape(x) = shape(y)

// Logical values: -Inf, +Inf, NaN, ...
// %x: [0xFFF0000000000000, 0x7FF0000000000000, 0x7FF8000000000000, -10.0, -0.0, 0.0, 10.0]
%y = "stablehlo.is_finite"(%x) : (tensor<7xf64) -> tensor<7xi1>
// %y: [false, false, false, true, true, true, true]

その他の例

log

セマンティクス

operand テンソルの要素ごとのログ演算を実行し、result テンソルを生成します。要素のタイプに応じて、次の処理を行います。

  • 浮動小数点数の場合: IEEE-754 の log
  • 複素数の場合: 複素対数。
  • 量子化型の場合: dequantize_op_quantize(log, operand, type(result))

入力

ラベル 名前 タイプ 制約
(I1) operand 浮動小数点型または複合型のテンソル、またはテンソルごとの量子化テンソル (C1)

出力

名前 タイプ 制約
result 浮動小数点型または複合型のテンソル、またはテンソルごとの量子化テンソル (C1)

制約

  • (C1)baseline_type(operand) = baseline_type(result)

// %operand: [[1.0, 2.0], [3.0, 4.0]]
%result = "stablehlo.log"(%operand) : (tensor<2x2xf64>) -> tensor<2x2xf64>
// %result: [[0.0, 0.69314718055994529], [1.0986122886681098, 1.3862943611198906]]

その他の例

log_plus_one

セマンティクス

operand テンソルに対して要素単位の対数に 1 演算を加え、result テンソルを生成します。要素のタイプに応じて、次の処理を行います。

  • 浮動小数点数の場合: IEEE-754 の logp1
  • 複素数の場合: 複素対数に 1 を加算します。
  • 量子化された型の場合: dequantize_op_quantize(log_plus_one, operand, type(result))

入力

ラベル 名前 タイプ 制約
(I1) operand 浮動小数点型または複合型のテンソル、またはテンソルごとの量子化テンソル (C1)

出力

名前 タイプ 制約
result 浮動小数点型または複合型のテンソル、またはテンソルごとの量子化テンソル (C1)

制約

  • (C1)baseline_type(operand) = baseline_type(result)

// %operand: [0.0, -0.999, 7.0, 6.38905621, 15.0]
%result = "stablehlo.log_plus_one"(%operand) : (tensor<5xf64>) -> tensor<5xf64>
// %result: [0.0, -6.90776825, 2.07944155, 2.0, 2.77258873]

その他の例

ロジスティクス

セマンティクス

operand テンソルの要素ごとのロジスティック演算を実行し、result テンソルを生成します。要素のタイプに応じて、次の操作を行います。

  • 浮動小数点数の場合: IEEE-754 の division(1, addition(1, exp(-x)))
  • 複素数の場合: 複素ロジスティック。
  • 量子化された型の場合: dequantize_op_quantize(logistic, operand, type(result))

入力

ラベル 名前 タイプ 制約
(I1) operand 浮動小数点型または複合型のテンソル、またはテンソルごとの量子化テンソル (C1)

出力

名前 タイプ 制約
result 浮動小数点型または複合型のテンソル、またはテンソルごとの量子化テンソル (C1)

制約

  • (C1)baseline_type(operand) = baseline_type(result)

// %operand: [[0.0, 1.0], [2.0, 3.0]]
%result = "stablehlo.logistic"(%operand) : (tensor<2x2xf64>) -> tensor<2x2xf64>
// %result: [[0.5, 0.73105858], [0.88079708, 0.95257413]]

その他の例

地図

セマンティクス

dimensions に沿ってマッピング関数 computationinputs に適用し、result テンソルを生成します。

正式には result[result_index] = computation(inputs...[result_index]) です。

入力

ラベル 名前 タイプ 制約
(I1) inputs テンソルの可変数またはテンソルごとの量子化テンソル (C1 ~ C4)
(I2) dimensions si64 型の 1 次元テンソル定数 (C3)
(I3) computation 関数 (C4)

出力

名前 タイプ 制約
result テンソルまたはテンソルごとの量子化テンソル (C1)、(C4)

制約

  • (C1)shape(inputs...) = shape(result)
  • (C2)0 < size(inputs) = N
  • (C3)dimensions = range(rank(inputs[0]))
  • (C4)computation の型は (tensor<E0>, ..., tensor<EN-1>) -> tensor<E'> で、Ei = element_type(inputs[i])E' = element_type(result) です。

// %input0: [[0, 1], [2, 3]]
// %input1: [[4, 5], [6, 7]]
%result = "stablehlo.map"(%input0, %input1) ({
  ^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
    %0 = stablehlo.multiply %arg0, %arg1 : tensor<i64>
    stablehlo.return %0 : tensor<i64>
}) {
  dimensions = array<i64: 0, 1>
} : (tensor<2x2xi64>, tensor<2x2xi64>) -> tensor<2x2xi64>
// %result: [[0, 5], [12, 21]]

その他の例

最大

セマンティクス

テンソル lhsrhs に対して要素ごとの最大演算を実行し、result テンソルを生成します。要素のタイプに応じて、次の処理を行います。

  • ブール値の場合: 論理 OR。
  • 整数の場合: 整数の最大値。
  • 浮動小数点数の場合: IEEE-754 の maximum
  • 複素数の場合: (real, imaginary) ペアの辞書順の最大値。複素数の順序付けには意外なセマンティクスが伴うため、将来的には、この演算での複素数のサポートは終了する予定です(#560)。
  • 量子化された型の場合:
    • dequantize_op_quantize(maximum, lhs, rhs, type(result))

入力

ラベル 名前 タイプ 制約
(I1) lhs テンソルまたはテンソルごとの量子化テンソル (C1)
(I2) rhs テンソルまたはテンソルごとの量子化テンソル (C1)

出力

名前 タイプ 制約
result テンソルまたはテンソルごとの量子化テンソル (C1)

制約

  • (C1)baseline_type(lhs) = baseline_type(rhs) = baseline_type(result)

// %lhs: [[1, 2], [7, 8]]
// %rhs: [[5, 6], [3, 4]]
%result = "stablehlo.maximum"(%lhs, %rhs) : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
// %result: [[5, 6], [7, 8]]

その他の例

最小

セマンティクス

テンソル lhsrhs に対して要素ごとの最小演算を実行し、result テンソルを生成します。要素のタイプに応じて、次の処理を行います。

  • ブール値の場合: 論理 AND。
  • 整数の場合: 整数の最小値。
  • 浮動小数点数の場合: IEEE-754 の minimum
  • 複素数の場合: (real, imaginary) ペアの辞書順最小値。複素数に順序付けを適用すると、意図しないセマンティクスが発生するため、今後、このオペレーションでの複素数のサポートを削除する予定です(#560)。
  • 量子化型の場合:
    • dequantize_op_quantize(minimum, lhs, rhs, type(result))

入力

ラベル 名前 タイプ 制約
(I1) lhs テンソルまたはテンソルごとの量子化テンソル (C1)
(I2) rhs テンソルまたはテンソルごとの量子化テンソル (C1)

出力

名前 タイプ 制約
result テンソルまたはテンソルごとの量子化テンソル (C1)

制約

  • (C1)baseline_type(lhs) = baseline_type(rhs) = baseline_type(result)

// %lhs: [[1, 2], [7, 8]]
// %rhs: [[5, 6], [3, 4]]
%result = "stablehlo.minimum"(%lhs, %rhs) : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
// %result: [[1, 2], [3, 4]]

その他の例

乗算

セマンティクス

2 つのテンソル lhsrhs の要素ごとの積を実行し、result テンソルを生成します。要素のタイプに応じて、次の操作を行います。

  • ブール値の場合: 論理 AND。
  • 整数の場合: 整数の乗算。
  • 浮動小数点数の場合: IEEE-754 の multiplication
  • 複素数の場合: 複素数乗算。
  • 量子化された型の場合:
    • dequantize_op_quantize(multiply, lhs, rhs, type(result))

入力

ラベル 名前 タイプ 制約
(I1) lhs テンソルまたはテンソルごとの量子化テンソル (C1)
(I2) rhs テンソルまたはテンソルごとの量子化テンソル (C1)

出力

名前 タイプ 制約
result テンソルまたはテンソルごとの量子化テンソル (C1)

制約

  • (C1)baseline_type(operand) = baseline_type(result)

// %lhs: [[1, 2], [3, 4]]
// %rhs: [[5, 6], [7, 8]]
%result = "stablehlo.multiply"(%lhs, %rhs) : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
// %result: [[5, 12], [21, 32]]

その他の例

negate

セマンティクス

operand テンソルの要素ごとの否定を行い、result テンソルを生成します。要素のタイプに応じて、次の操作を行います。

  • 符号付き整数の場合: 整数の否定。
  • 符号なし整数の場合: 符号付き整数へのビットキャスト、整数の否定、符号なし整数へのビットキャスト。
  • 浮動小数点数: IEEE-754 の negate
  • 複素数の場合: 複素数を否定する。
  • 量子化された型の場合: dequantize_op_quantize(negate, operand, type(result))

入力

ラベル 名前 タイプ 制約
(I1) operand 整数型、浮動小数点型、複合型のテンソル、またはテンソルごとの量子化テンソル (C1)

出力

名前 タイプ 制約
result 整数型、浮動小数点型、複合型のテンソル、またはテンソルごとの量子化テンソル (C1)

制約

  • (C1)baseline_type(operand) = baseline_type(result)

// Negation operation with integer Tensors
// %operand: [0, -2]
%result = "stablehlo.negate"(%operand) : (tensor<2xi32>) -> tensor<2xi32>
// %result: [0, 2]

// Negation operation with with complex tensors
// %operand: (2.5, 0.0)
%result = "stablehlo.negate"(%operand) : (tensor<1xcomplex<f32>>) -> tensor<1xcomplex<f32>>
// %result: [-2.5, -0.0]

その他の例

いない

セマンティクス

テンソル operand の要素ごとの NOT を実行し、result テンソルを生成します。要素のタイプに応じて、次の処理を行います。

  • ブール値の場合: 論理 NOT。
  • 整数の場合: ビット演算 NOT。

引数

名前 タイプ 制約
operand ブール型または整数型のテンソル (C1)

出力

名前 タイプ 制約
result ブール型または整数型のテンソル (C1)

制約

  • (C1)type(operand) = type(result)

// Bitwise operation with with integer tensors
// %operand: [[1, 2], [3, 4]]
%result = "stablehlo.not"(%operand) : (tensor<2x2xi32>) -> tensor<2x2xi32>
// %result: [[-2, -3], [-4, -5]]

// Bitwise operation with with boolean tensors
// %operand: [true, false]
%result = "stablehlo.not"(%operand) : (tensor<2xi1>) -> tensor<2xi1>
// %result: [false, true]

その他の例

optimization_barrier

セマンティクス

operand を生成するオペレーションが result に依存するオペレーションの前に実行され、コンパイラ変換によってオペレーションがバリアを超えて移動されないようにします。それ以外の場合、オペレーションは ID(result = operand など)です。

引数

名前 タイプ 制約
operand テンソルの可変数、テンソルごとの量子化テンソルまたはトークン (C1)

出力

名前 タイプ 制約
result テンソルの可変数、テンソルごとの量子化テンソルまたはトークン (C1)

制約

  • (C1)type(operand...) = type(result...)

// %operand0: 0.0
// %operand1: 1.0
%result0, %result1 = "stablehlo.optimization_barrier"(%operand0, %operand1) : (tensor<f32>, tensor<f32>) -> (tensor<f32>, tensor<f32>)
// %result0: 0.0
// %result1: 1.0

その他の例

または

セマンティクス

2 つのテンソル lhsrhs の要素ごとの OR を実行し、result テンソルを生成します。要素のタイプに応じて、次の操作を行います。

  • ブール値の場合: 論理 OR。
  • 整数の場合: ビット演算 OR。

入力

ラベル 名前 タイプ 制約
(I1) lhs 整数型またはブール型のテンソル (C1)
(I2) rhs 整数型またはブール型のテンソル (C1)

出力

名前 タイプ 制約
result 整数型またはブール型のテンソル (C1)

制約

  • (C1)type(lhs) = type(rhs) = type(result)

// Bitwise operation with with integer tensors
// %lhs: [[1, 2], [3, 4]]
// %rhs: [[5, 6], [7, 8]]
%result = "stablehlo.or"(%lhs, %rhs) : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
// %result: [[5, 6], [7, 12]]

// Logical operation with with boolean tensors
// %lhs: [[false, false], [true, true]]
// %rhs: [[false, true], [false, true]]
%result = "stablehlo.or"(%lhs, %rhs) : (tensor<2x2xi1>, tensor<2x2xi1>) -> tensor<2x2xi1>
// %result: [[false, true], [true, true]]

その他の例

アウトフィード

セマンティクス

inputs をアウトフィードに書き込み、result トークンを生成します。

outfeed_config のセマンティクスは実装で定義されます。

入力

ラベル 名前 タイプ
(I1) inputs テンソルまたは量子化テンソルの可変数
(I2) token token
(I3) outfeed_config string 型の定数

出力

名前 タイプ
result token

%result = "stablehlo.outfeed"(%input0, %token) {
  outfeed_config = ""
} : (tensor<2x2x2xi64>, !stablehlo.token) -> !stablehlo.token

その他の例

パッド

セマンティクス

テンソルの周囲と、指定された padding_value を使用してテンソルの要素の間にパディングすることで、operand を拡張します。

edge_padding_lowedge_padding_high は、各ディメンションの下端(インデックス 0 の隣)と上端(最大インデックスの隣)に追加されるパディングの量をそれぞれ指定します。パディングの量は負の値にすることができます。負のパディングの絶対値は、指定されたディメンションから削除する要素の数を示します。

interior_padding は、各ディメンション内の任意の 2 つの要素間に追加されるパディングの量を指定します。この値は負の値にすることはできません。内部パディングはエッジ パディングの前に行われるため、負のエッジ パディングを行うと、内部パディングされたオペランドから要素が削除されます。

より正式には、result[result_index] は次のように定義されます。

  • result_index = edge_padding_low + operand_index * (interior_padding + 1) の場合、operand[operand_index]
  • それ以外の場合は padding_value

入力

ラベル 名前 タイプ 制約
(I1) operand テンソルまたはテンソルごとの量子化テンソル (C1)、(C2)、(C4)
(I2) padding_value 0 次元テンソルまたはテンソルごとの量子化テンソル (C1)
(I3) edge_padding_low si64 型の 1 次元テンソル定数 (C1)、(C4)
(I4) edge_padding_high si64 型の 1 次元のテンソル定数 (C1)、(C4)
(I5) interior_padding si64 型の 1 次元テンソル定数 (C2 ~ C4)

出力

名前 タイプ 制約
result テンソルまたはテンソルごとの量子化テンソル (C3-C6)

制約

  • (C1)element_type(operand) = element_type(padding_value) = element_type(result)
  • (C2)size(edge_padding_low) = size(edge_padding_high) = size(interior_padding) = rank(operand)
  • (C3)0 <= interior_padding
  • (C4)shape(result) = shape(operand) + edge_padding_low + max(shape(operand) - 1, 0) * interior_padding + edge_padding_high

// %operand: [
//            [1, 2, 3],
//            [4, 5, 6]
//           ]
// %padding_value: 0
%result = "stablehlo.pad"(%operand, %padding_value) {
  edge_padding_low = array<i64: 0, 1>,
  edge_padding_high = array<i64: 2, 1>,
  interior_padding = array<i64: 1, 2>
} : (tensor<2x3xi32>, tensor<i32>) -> tensor<5x9xi32>
// %result: [
//           [0, 1, 0, 0, 2, 0, 0, 3, 0],
//           [0, 0, 0, 0, 0, 0, 0, 0, 0],
//           [0, 4, 0, 0, 5, 0, 0, 6, 0],
//           [0, 0, 0, 0, 0, 0, 0, 0, 0],
//           [0, 0, 0, 0, 0, 0, 0, 0, 0]
//          ]

その他の例

partition_id

セマンティクス

現在のプロセスの partition_id を生成します。

出力

名前 タイプ
result ui32 型の 0 次元テンソル

%result = "stablehlo.partition_id"() : () -> tensor<ui32>

その他の例

popcnt

セマンティクス

operand テンソルに設定されているビット数を要素ごとにカウントし、result テンソルを生成します。

入力

ラベル 名前 タイプ 制約
(I1) operand 整数型のテンソル (C1)

出力

名前 タイプ 制約
result 整数型のテンソル (C1)

制約

  • (C1)type(operand) = type(result)

// %operand: [0, 1, 2, 127]
%result = "stablehlo.popcnt"(%operand) : (tensor<4xi64>) -> tensor<4xi64>
// %result: [0, 1, 1, 7]

その他の例

電力

セマンティクス

lhs テンソルと rhs テンソルの要素ごとの指数演算を行い、result テンソルを生成します。要素のタイプに応じて、次の処理を行います。

  • 整数の場合: 整数のべき乗。
  • 浮動小数点数の場合: IEEE-754 の pow
  • 複素数の場合: 複素指数関数。
  • 量子化された型の場合: dequantize_op_quantize(power, lhs, rhs, type(result))

入力

ラベル 名前 タイプ 制約
(I1) lhs 整数型、浮動小数点型、複合型のテンソル、またはテンソルごとの量子化テンソル (C1)
(I2) rhs 整数、浮動小数点、複合型のテンソル、またはテンソルごとの量子化テンソル (C1)

出力

名前 タイプ 制約
result 整数型、浮動小数点型、複合型のテンソル、またはテンソルごとの量子化テンソル (C1)

制約

  • (C1)baseline_type(operand) = baseline_type(result)

// %lhs: [-2.0, -0.0, -36.0, 5.0, 3.0, 10000.0]
// %rhs: [2.0, 2.0, 1.1, 2.0, -1.0, 10.0]
%result = "stablehlo.power"(%lhs, %rhs) : (tensor<6xf64>, tensor<6xf64>) -> tensor<6xf64>
// %result: [4.0, 0.0, -nan, 25.0, 0.333333343, inf]

その他の例

real

セマンティクス

operand から要素ごとに実部を抽出し、result テンソルを生成します。より正式には、各要素 x について real(x) = is_complex(x) ? real_part(x) : x です。

入力

ラベル 名前 タイプ 制約
(I1) operand 浮動小数点数または複素数のテンソル (C1)、(C2)

出力

名前 タイプ 制約
result 浮動小数点型のテンソル (C1)、(C2)

制約

  • (C1)shape(result) = shape(operand)
  • (C2)element_type(result) は次のように定義されます。
    • is_complex(operand) の場合、complex_element_type(element_type(operand))
    • それ以外の場合は element_type(operand)

// %operand: [(1.0, 2.0), (3.0, 4.0)]
%result = "stablehlo.real"(%operand) : (tensor<2xcomplex<f32>>) -> tensor<2xf32>
// %result: [1.0, 3.0]

その他の例

受信

セマンティクス

channel_id を使用してチャネルからデータを受信し、results を生成します。

is_host_transfertrue の場合、オペレーションはホストからデータを転送します。そうでない場合は、別のデバイスからデータを転送します。具体的な意味は実装によって異なります。このフラグは channel_type で提供される情報と重複するため、今後はどちらか一方のみを保持する予定です(#666)。

results は、先頭にペイロード値、最後にトークンで構成されます。今後、明確性を高めるために、ペイロードとトークンを 2 つの個別の出力に分割する予定です(#670)。

入力

ラベル 名前 タイプ 制約
(I1) token token (C4)
(I2) channel_id si64 型の定数
(I3) channel_type DEVICE_TO_DEVICEHOST_TO_DEVICE の列挙型 (C1)
(I4) is_host_transfer i1 型の定数 (C1)

出力

名前 タイプ 制約
results テンソル、量子化テンソル、トークンの可変数 (C2 ~ C4)

制約

  • (C1)channel_type は次のように定義されます。
    • is_host_transfer = true の場合、HOST_TO_DEVICE
    • それ以外の場合は DEVICE_TO_DEVICE
  • (C2)0 < size(results)
  • (C3)is_empty(result[:-1]) または is_tensor(type(results[:-1]))
  • (C4)is_token(type(results[-1]))

%results0, %results1 = "stablehlo.recv"(%token) {
  channel_handle = #stablehlo.channel_handle<handle = 1, type = 3>,
  is_host_transfer = true
} : (!stablehlo.token) -> (tensor<2x2xi64>, !stablehlo.token)

その他の例

reduce

セマンティクス

dimensions に沿ってリダクション関数 bodyinputsinit_values に適用し、results テンソルを生成します。

減算の順序は実装で定義されます。つまり、すべての実装ですべての入力に対して同じ結果が得られるように、bodyinit_values はモノイドを形成する必要があります。ただし、この条件は一般的な多くの減少では成立しません。たとえば、body の浮動小数点加算と init_values のゼロは、浮動小数点加算が結合的ではないため、実際にはモノイドを形成しません。

より正式には、results...[j0, ..., jR-1] = reduce(input_slices_converted) です。ここで:

  • input_slices = inputs...[j0, ..., :, ..., jR-1]:dimensions に挿入されます。
  • input_slices_converted = to_destination_type(input_slices..., type(func_inputs(body)[:len(func_inputs(body))//2])...)
  • init_values_converted = to_destination_type(init_values..., type(func_inputs(body)[len(func_inputs(body))//2:])...)
  • reduce(input_slices_converted) = exec(schedule) は任意のバイナリ ツリーです。schedule は次のとおりです。
    • exec(node) = body(exec(node.left), exec(node.right))
    • exec(leaf) = leaf.value
  • schedule は実装定義の完全二分木で、順序付き走査は次のとおりです。
    • input_slices_converted...[index] 値。index_space(input_slices_converted) 内のすべての index は、index の昇順の辞書順で指定します。
    • 実装で定義された位置に、実装で定義された量の init_values_converted が点在します。

入力

ラベル 名前 タイプ 制約
(I1) inputs テンソルの可変数またはテンソルごとの量子化テンソル (C1~C4)、(C6)、(C7)
(I2) init_values 0 次元テンソルまたはテンソルごとの量子化テンソルの可変数 (C2)、(C3)
(I3) dimensions si64 型の 1 次元テンソル定数 (C4)、(C5)、(C7)
(I4) body 関数 (C6)

出力

名前 タイプ 制約
results テンソルの可変数またはテンソルごとの量子化テンソル (C3)、(C7)、(C8)

制約

  • (C1)same(shape(inputs...))
  • (C2)element_type(inputs...) = element_type(init_values...)
  • (C3)0 < size(inputs) = size(init_values) = size(results) = N
  • (C4)0 <= dimensions < rank(inputs[0])
  • (C5)is_unique(dimensions)
  • (C6)body の型は (tensor<E0>, ..., tensor<EN-1>, tensor<E0>, ...,tensor<EN-1>) -> (tensor<E0>, ..., tensor<EN-1>) で、is_promotable(element_type(inputs[i]), Ei) です。
  • (C7)shape(results...) = shape(inputs...)。ただし、dimensions に対応する inputs... のディメンション サイズは含まれません。
  • (C8)[0,N) のすべての i に対して element_type(results[i]) = Ei

// %input = [[0, 1, 2, 3, 4, 5]]
// %init_value = 0
%result = "stablehlo.reduce"(%input, %init_value) ({
  ^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
    %0 = "stablehlo.add"(%arg0, %arg1) : (tensor<i64>, tensor<i64>) -> tensor<i64>
    "stablehlo.return"(%0) : (tensor<i64>) -> ()
}) {
  dimensions = array<i64: 1>
} : (tensor<1x6xi64>, tensor<i64>) -> tensor<1xi64>
// %result = [15]

その他の例

reduce_precision

セマンティクス

operandexponent_bitsmantissa_bits を使用する別の浮動小数点型に要素単位で変換し、元の浮動小数点型に戻して output テンソルを生成します。

より正式には、次のように記述します。

  • 元の値のマンシンダビットが更新され、roundToIntegralTiesToEven セマンティクスを使用して、元の値を mantissa_bits で表せる最も近い値に丸められます。
  • 次に、mantissa_bits が元の値のマンシッサ ビット数より小さい場合は、マンシッサ ビットが mantissa_bits に切り捨てられます。
  • 次に、中間結果の指数ビットが exponent_bits で指定された範囲に収まらない場合、中間結果は元の符号を使用して無限大までオーバーフローするか、元の符号を使用してゼロにアンダーフローします。
  • 量子化された型の場合は、dequantize_op_quantize( lambda operand: reduce_precision(operand, exponent_bits, mantissa_bits), operand, type(result)) を実行します。

入力

ラベル 名前 タイプ 制約
(I1) operand 浮動小数点型のテンソルまたはテンソルごとの量子化テンソル (C1)
(I2) exponent_bits si32 型の定数 (C2)
(I3) mantissa_bits si32 型の定数 (C3)

出力

名前 タイプ 制約
output 浮動小数点型のテンソルまたはテンソルごとの量子化テンソル (C1)

制約

  • (C1)baseline_type(operand) = baseline_type(output)
  • (C2)1 <= exponent_bits
  • (C3)0 <= mantissa_bits

// Logical values: +Inf, NaN, +Denormal, 0.0, 65519.0, 65520.0
// %operand: [0x7FF0000000000000, 0x7FFFFFFFFFFFFFFF, 0x0000000000000001, 0.0, 65519.0, 65520.0]
%output = "stablehlo.reduce_precision"(%operand) {
  exponent_bits = 5 : i32,
  mantissa_bits = 10 : i32
} : (tensor<6xf64>) -> tensor<6xf64>
// Logical values: +Inf, NaN, 0.0, 0.0, 65504.0, +Inf
// %output: [0x7FF0000000000000, 0x7FFFFFFFFFFFFFFF, 0.0, 0.0, 65504.0, 0x7FF0000000000000]

その他の例

reduce_scatter

セマンティクス

reduce_scatter

StableHLO プロセス グリッド内の各プロセス グループ内で、各プロセスの operand テンソルの値に対して computations を使用して減算を行い、scatter_dimension に沿って減算結果を分割し、分割された部分をプロセス間で分散して result を生成します。

このオペレーションでは、StableHLO プロセス グリッドが、次のように定義された process_groups に分割されます。

  • channel_id <= 0 and use_global_device_ids = false の場合、cross_replica(replica_groups)
  • channel_id > 0 and use_global_device_ids = false の場合、cross_replica_and_partition(replica_groups)
  • channel_id > 0 and use_global_device_ids = true の場合、flattened_ids(replica_groups)

その後、各 process_group 内で次の操作を行います。

  • reduced_value = all_reduce(operand, replica_groups, channel_id, use_global_device_ids, computation)
  • parts@sender = split(reduced_value@sender, dim(process_groups, 1), scatter_dimension)
  • process_group 内のすべての sender に対して result@receiver = parts@sender[receiver_index]receiver_index = process_group.index(receiver))。

入力

ラベル 名前 タイプ 制約
(I1) operand テンソルまたはテンソルごとの量子化テンソル (C1)、(C2)、(C7)、(C8)
(I2) scatter_dimension si64 型の定数 (C1)、(C2)、(C8)
(I3) replica_groups si64 型の 2 次元テンソル定数 (C3~C5)
(I4) channel_id si64 型の定数 (C6)
(I5) use_global_device_ids i1 型の定数 (C6)
(I6) computation 関数 (C7)

出力

名前 タイプ 制約
result テンソルまたはテンソルごとの量子化テンソル (C8~C9)

制約

  • (C1)dim(operand, scatter_dimension) % dim(process_groups, 1) = 0
  • (C2)0 <= scatter_dimension < rank(operand)
  • (C3)is_unique(replica_groups)
  • (C4)size(replica_groups) は次のように定義されます。
    • cross_replica を使用する場合は num_replicas
    • cross_replica_and_partition を使用する場合は num_replicas
    • flattened_ids を使用する場合は num_processes
  • (C5)0 <= replica_groups < size(replica_groups)
  • (C6)use_global_device_ids = true の場合は channel_id > 0 です。
  • (C7)computation の型は (tensor<E>, tensor<E>) -> (tensor<E>) です。is_promotable(element_type(operand), E)
  • (C8)shape(result) = shape(operand)(ただし、次を除く)
    • dim(result, scatter_dimension) = dim(operand, scatter_dimension) / dim(process_groups, 1)
  • (C9)element_type(result) = E

// num_replicas: 2
// num_partitions: 1
// %operand@(0, 0): [[1, 2, 3, 4],
//                   [5, 6, 7, 8]]
// %operand@(1, 0): [[9, 10, 11, 12],
//                   [13, 14, 15, 16]]
%result = "stablehlo.reduce_scatter"(%operand) ({
  ^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
  %0 = "stablehlo.add"(%arg0, %arg1) : (tensor<i64>, tensor<i64>) -> tensor<i64>
  "stablehlo.return"(%0) : (tensor<i64>) -> ()
}) {
  scatter_dimension = 1 : i64,
  replica_groups = dense<[[0, 1]]> : tensor<1x2xi64>,
  channel_handle = #stablehlo.channel_handle<handle = 0, type = 0>
} : (tensor<2x4xi64>) -> tensor<2x2xi64>
//
// %result@(0, 0): [[10, 12],
//                  [18, 20]]
// %result@(1, 0): [[14, 16],
//                  [22, 24]]

その他の例

reduce_window

セマンティクス

inputsinit_values のウィンドウに削減関数 body を適用し、results を生成します。

次の図は、具体的な例を使用して results... の要素が inputs... から計算される方法を示しています。

reduce_window

より正式には、results...[result_index] = reduce(windows, init_values, axes(inputs...), body)reduce を参照)で、次のように定義されます。

  • padded_inputs = pad(inputs..., init_values..., padding[:, 0], padding[:, 1], base_dilations - 1)
  • window_start = result_index * window_strides
  • window_end = window_start + (window_dimensions - 1) * window_dilations + 1
  • windows = slice(padded_inputs..., window_start, window_end, window_dilations)

入力

ラベル 名前 タイプ 制約
(I1) inputs テンソルの可変数またはテンソルごとの量子化テンソル (C1~C4)、(C6)、(C8)、(C10)、(C12)、(C13)、(C15)
(I2) init_values 0 次元テンソルまたはテンソルごとの量子化テンソルの可変数 (C1)、(C13)
(I3) window_dimensions si64 型の 1 次元テンソル定数 (C4)、(C5)、(C15)
(I4) window_strides si64 型の 1 次元テンソル定数 (C6)、(C7)、(C15)
(I5) base_dilations si64 型の 1 次元のテンソル定数 (C8)、(C9)、(C15)
(I6) window_dilations si64 型の 1 次元テンソル定数 (C10)、(C11)、(C15)
(I7) padding si64 型の 2 次元テンソル定数 (C12)、(C15)
(I8) body 関数 (C13)

出力

名前 タイプ 制約
results テンソルの可変数またはテンソルごとの量子化テンソル (C1)、(C14~C16)

制約

  • (C1)0 < size(inputs) = size(init_values) = size(results) = N
  • (C2)same(shape(inputs...))
  • (C3)element_type(inputs...) = element_type(init_values...)
  • (C4)size(window_dimensions) = rank(inputs[0])
  • (C5)0 < window_dimensions
  • (C6)size(window_strides) = rank(inputs[0])
  • (C7)0 < window_strides
  • (C8)size(base_dilations) = rank(inputs[0])
  • (C9)0 < base_dilations
  • (C10)size(window_dilations) = rank(inputs[0])
  • (C11)0 < window_dilations
  • (C12)shape(padding) = [rank(inputs[0]), 2]
  • (C13)body の型は (tensor<E0>, ..., tensor<EN-1>, tensor<E0>, ..., tensor<EN-1>) -> (tensor<E0>, ..., tensor<EN-1>) で、is_promotable(element_type(inputs[i]), Ei) です。
  • (C14)same(shape(results...))
  • (C15)shape(results[0]) = num_windows ここで:
    • dilated_input_shape = shape(inputs[0]) = 0 ? 0 : (shape(inputs[0]) - 1) * base_dilations + 1
    • padded_input_shape = padding[:, 0] + dilated_input_shape + padding[:, 1]
    • dilated_window_shape = (window_dimensions - 1) * window_dilations + 1
    • is_empty_window = padded_input_shape = 0 || dilated_window_shape > padded_input_shape
    • num_windows = is_empty_window ? 0 : floor((padded_input_shape - dilated_window_shape) / window_strides) + 1
  • (C16)[0,N) 内のすべての i に対して element_type(results[i]) = Ei

// %input = [[1, 2], [3, 4], [5, 6]]
// %init_value = 0
%result = "stablehlo.reduce_window"(%input, %init_value) ({
  ^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
    %0 = "stablehlo.add"(%arg0, %arg1) : (tensor<i64>, tensor<i64>) -> tensor<i64>
    "stablehlo.return"(%0) : (tensor<i64>) -> ()
}) {
  window_dimensions = array<i64: 2, 1>,
  window_strides = array<i64: 4, 1>,
  base_dilations = array<i64: 2, 1>,
  window_dilations = array<i64: 3, 1>,
  padding = dense<[[2, 1], [0, 0]]> : tensor<2x2xi64>
} : (tensor<3x2xi64>, tensor<i64>) -> tensor<2x2xi64>
// %result = [[0, 0], [3, 4]]

その他の例

残り

セマンティクス

被除数 lhs テンソルと除数 rhs テンソルの要素ごとの余りを実行し、result テンソルを生成します。

より正式には、結果の符号は被除数から取得され、結果の絶対値は常に除数の絶対値よりも小さくなります。残りは lhs - d * rhs として計算されます。ここで、d は次のように定義されます。

  • 整数の場合: stablehlo.divide(lhs, rhs)
  • 浮動小数点数の場合: 丸め属性 roundTowardZero を持つ IEEE-754 の division(lhs, rhs)
  • 複素数の場合: 未定(#997)。
  • 量子化された型の場合:
    • dequantize_op_quantize(remainder, lhs, rhs, type(result))

浮動小数点要素型の場合、このオペレーションは IEEE-754 仕様の remainder オペレーションとは対照的です。d は、lhs/rhs の正確な値に最も近い整数値で、偶数に等しい値です。

入力

ラベル 名前 タイプ 制約
(I1) lhs 整数型、浮動小数点型、複合型のテンソル、またはテンソルごとの量子化テンソル (C1)
(I2) rhs 整数型、浮動小数点型、複合型のテンソル、またはテンソルごとの量子化テンソル (C1)

出力

名前 タイプ 制約
result 整数型、浮動小数点型、複合型のテンソル、またはテンソルごとの量子化テンソル (C1)

制約

  • (C1)baseline_type(operand) = baseline_type(result)

// %lhs: [17, -17, 17, -17]
// %rhs: [3, 3, -3, -3]
%result = "stablehlo.remainder"(%lhs, %rhs) : (tensor<4xi64>, tensor<4xi64>) -> tensor<4xi64>
// %result: [2, -2, 2, -2]

その他の例

replica_id

セマンティクス

現在のプロセスの replica_id を生成します。

出力

名前 タイプ
result ui32 型の 0 次元テンソル

%result = "stablehlo.replica_id"() : () -> tensor<ui32>

その他の例

reshape

セマンティクス

operand テンソルを result テンソルに再フォーマットします。概念的には、同じ正規表現を維持しながら、形状を変更する可能性(tensor<2x3xf32> から tensor<3x2xf32> または tensor<6xf32> への変更など)があります。

より正式には、result[result_index] = operand[operand_index] で、result_indexoperand_indexindex_space(result)index_space(operand) の辞書順で同じ位置にあります。

入力

ラベル 名前 タイプ 制約
(I1) operand テンソルまたは量子化テンソル (C1~C3)

出力

名前 タイプ 制約
result テンソルまたは量子化テンソル (C1~C3)

制約

  • (C1)element_type(result) は次のように定義されます。
    • element_type(operand)!is_per_axis_quantized(operand) の場合)。
    • element_type(operand)。ただし、quantization_dimension(operand)quantization_dimension(result) が異なる場合があります。
  • (C2)size(operand) = size(result)
  • (C3)is_per_axis_quantized(operand) の場合:
    • reduce(dims(operand, [0, 1, ..., quantization_dimension(operand) - 1]), init_values=1, dimensions=[0], body=lambda x, y: x * y) = reduce(dims(result, [0, 1, ..., quantization_dimension(result) - 1]), init_values=1, dimensions=[0], body=lambda x, y: x * y)
    • dim(operand, quantization_dimension(operand)) = dim(result, quantization_dimension(result))
    • reduce(dims(operand, [quantization_dimension(operand) + 1, ..., rank(operand) - 1]), init_values=1, dimensions=[0], body=lambda x, y: x * y) = reduce(dims(result, [quantization_dimension(result) + 1, ..., rank(result) - 1]), init_values=1, dimensions=[0], body=lambda x, y: x * y)

// %operand: [[1, 2, 3], [4, 5, 6]]
%result = "stablehlo.reshape"(%operand) : (tensor<2x3xi32>) -> tensor<3x2xi32>
// %result: [[1, 2], [3, 4], [5, 6]]

その他の例

reverse

セマンティクス

指定された dimensions に沿って operand の要素の順序を逆にして、result テンソルを生成します。より正式には、result[result_index] = operand[operand_index] で次のようにします。

  • dimensionsd の場合、operand_index[d] = dim(result, d) - result_index[d] - 1
  • そうでない場合は operand_index[d] = result_index[d]

入力

ラベル 名前 タイプ 制約
(I1) operand テンソルまたはテンソルごとの量子化テンソル (C1)、(C3)
(I2) dimensions si64 型の 1 次元テンソル定数 (C2)、(C3)

出力

名前 タイプ 制約
result テンソルまたはテンソルごとの量子化テンソル (C1)、(C3)

制約

  • (C1)type(operand) = type(result)
  • (C2)is_unique(dimensions)
  • (C3)0 <= dimensions < rank(result)

// %operand = [[1, 2], [3, 4], [5, 6]]
%result = "stablehlo.reverse"(%operand) {
  dimensions = array<i64: 1>
} : (tensor<3x2xi32>) -> tensor<3x2xi32>
// %result: [[2, 1], [4, 3], [6, 5]]

その他の例

rng

セマンティクス

rng_distribution アルゴリズムを使用して乱数を生成し、指定された形状 shaperesult テンソルを生成します。

rng_distribution = UNIFORM の場合、乱数は [a, b) の区間で均一分布に従って生成されます。a >= b の場合、動作は未定義です。

rng_distribution = NORMAL の場合、乱数は平均 = a、標準偏差 = b の正規分布に従って生成されます。b < 0 の場合、動作は未定義です。

乱数の生成方法は実装で定義されます。たとえば、確定的である場合もあれば、そうでない場合もあります。また、非表示の状態を使用する場合もあれば、使用しないこともあります。

多くの関係者と話し合った結果、このオペレーションは事実上非推奨であることが判明したため、今後は削除を検討する予定です(#597)。

入力

ラベル 名前 タイプ 制約
(I1) a 整数型、ブール型、浮動小数点型の 0 次元テンソル (C1)、(C2)
(I2) b 整数型、ブール型、浮動小数点型の 0 次元テンソル (C1)、(C2)
(I3) shape si64 型の 1 次元テンソル定数 (C3)
(I4) rng_distribution UNIFORMNORMAL の列挙型 (C2)

出力

名前 タイプ 制約
result 整数型、ブール型、浮動小数点型のテンソル (C1~C3)

制約

  • (C1)element_type(a) = element_type(b) = element_type(result)
  • (C2)rng_distribution = NORMAL の場合は is_float(a) です。
  • (C3)shape(result) = shape

// %a = 0
// %b = 2
// %shape = [3, 3]
%result = "stablehlo.rng"(%a, %b, %shape) {
  rng_distribution = #stablehlo<rng_distribution UNIFORM>
} : (tensor<i32>, tensor<i32>, tensor<2xi64>) -> tensor<3x3xi32>
// %result: [
//           [1, 0, 1],
//           [1, 1, 1],
//           [0, 0, 0]
//          ]

rng_bit_generator

セマンティクス

初期状態 initial_state に基づいて、疑似乱数生成アルゴリズム rng_algorithm を使用して、均一なランダムビットで満たされた output と更新された出力状態 output_state を返します。出力は initial_state の確定的関数であることが保証されますが、実装間で確定的であるとは限りません。

rng_algorithm は、次のいずれかです。

  • DEFAULT: 実装定義のアルゴリズム。
  • THREE_FRY: Threefry アルゴリズムの実装定義バリアント。*
  • PHILOX: Philox アルゴリズムの実装定義バリアント。*

* 参照: Salmon et al. SC 2011. 並列乱数: 1、2、3 で簡単に生成。

入力

ラベル 名前 タイプ 制約
(I1) rng_algorithm DEFAULTTHREE_FRYPHILOX の列挙型 (C2)
(I2) initial_state ui64 型の 1 次元テンソル (C1)、(C2)

出力

名前 タイプ 制約
output_state ui64 型の 1 次元テンソル (C1)
output 整数型または浮動小数点型のテンソル

制約

  • (C1)type(initial_state) = type(output_state)
  • (C2)size(initial_state) は次のように定義されます。
    • rng_algorithm = DEFAULT の場合は実装定義。
    • rng_algorithm = THREE_FRY の場合、2
    • rng_algorithm = PHILOX の場合は 2 または 3

// %initial_state: [1, 2]
%output_state, %output = "stablehlo.rng_bit_generator"(%initial_state) {
  rng_algorithm = #stablehlo<rng_algorithm THREE_FRY>
} : (tensor<2xui64>) -> (tensor<2xui64>, tensor<2x2xui64>)
// %output_state: [1, 6]
// %output: [
//           [9236835810183407956, 16087790271692313299],
//           [18212823393184779219, 2658481902456610144]
//          ]

round_nearest_afz

セマンティクス

operand テンソルに対して要素ごとの四捨五入を行い、同値の場合はゼロから離れた値を返します。これにより、result テンソルが生成されます。IEEE-754 仕様の roundToIntegralTiesToAway オペレーションを実装します。量子化された型の場合は、dequantize_op_quantize(round_nearest_afz, operand, type(result)) を実行します。

入力

ラベル 名前 タイプ 制約
(I1) operand 浮動小数点型のテンソルまたはテンソルごとの量子化テンソル (C1)

出力

名前 タイプ 制約
result 浮動小数点型のテンソルまたはテンソルごとの量子化テンソル (C1)

制約

  • (C1)baseline_type(operand) = baseline_type(result)

// %operand = [-2.5, 0.4, 0.5, 0.6, 2.5]
%result = "stablehlo.round_nearest_afz"(%operand) : (tensor<5xf64>) -> tensor<5xf64>
// %result: [-3.0, 0.0, 1.0, 1.0, 3.0]

その他の例

round_nearest_even

セマンティクス

テンソルで operand テンソルで、最も近い整数への要素単位の丸めを実行し、偶数の整数への結合を解除して、result テンソルを生成します。IEEE-754 仕様の roundToIntegralTiesToEven オペレーションを実装します。量子化された型の場合は、dequantize_op_quantize(round_nearest_even, operand, type(result)) を実行します。

入力

ラベル 名前 タイプ 制約
(I1) operand 浮動小数点型のテンソルまたはテンソルごとの量子化テンソル (C1)

出力

名前 タイプ 制約
result 浮動小数点型のテンソルまたはテンソルごとの量子化テンソル (C1)

制約

  • (C1)baseline_type(operand) = baseline_type(result)

// %operand = [-2.5, 0.4, 0.5, 0.6, 2.5]
%result = "stablehlo.round_nearest_even"(%operand) : (tensor<5xf64>) -> tensor<5xf64>
// %result: [-2.0, 0.0, 0.0, 1.0, 2.0]

その他の例

rsqrt

セマンティクス

operand テンソルに対して要素単位の逆平方根演算を実行し、result テンソルを生成します。要素のタイプに応じて、次の処理を行います。

  • 浮動小数点数の場合: IEEE-754 の rSqrt
  • 複素数の場合: 複素数の逆平方根。
  • 量子化された型の場合: dequantize_op_quantize(rsqrt, operand, type(result))

入力

ラベル 名前 タイプ 制約
(I1) operand 浮動小数点型または複合型のテンソル、またはテンソルごとの量子化テンソル (C1)

出力

名前 タイプ 制約
result 浮動小数点型または複合型のテンソル、またはテンソルごとの量子化テンソル (C1)

制約

  • (C1)baseline_type(operand) = baseline_type(result)

// %operand: [[1.0, 4.0], [9.0, 25.0]]
%result = "stablehlo.rsqrt"(%operand) : (tensor<2x2xf32>) -> tensor<2x2xf32>
// %result: [[1.0, 0.5], [0.33333343, 0.2]]

その他の例

散布

セマンティクス

inputs テンソルと同じ results テンソルを生成します。ただし、scatter_indices で指定された複数のスライスが、update_computation を使用して値 updates で更新されます。

次の図は、updates... の要素が results... の要素にどのようにマッピングされるかを、具体的な例を使用して示しています。この図では、いくつかの updates... インデックスの例を選択して、それらが対応する results... インデックスの詳細について説明しています。

散布

より正式に、index_space(updates[0]) 内のすべての update_index について、次のように定義します。

  • update_scatter_dims = [d for d in axes(updates[0]) and d not in update_window_dims]
  • update_scatter_index = update_index[update_scatter_dims...]
  • start_index は次のように定義されます。
    • scatter_indices[si0, ..., :, ..., siN]。ここで、siupdate_scatter_index の個々の要素であり、index_vector_dim < rank(scatter_indices) の場合、:index_vector_dim インデックスに挿入されます。
    • そうでない場合は [scatter_indices[update_scatter_index]]
  • axes(inputs[0])d_input は、
    • d_input = scatter_dims_to_operand_dims[d_start] の場合、full_start_index[d_input] = start_index[d_start]
    • それ以外の場合は full_start_index[d_input] = 0
  • axes(inputs[0])d_input は、
    • d_input = input_batching_dims[i_batching] かつ d_start = scatter_indices_batching_dims[i_batching] の場合は full_batching_index[d_input] = update_scatter_index[d_start - (d_start < index_vector_dim ? 0 : 1)]
    • それ以外の場合は full_batching_index[d_input] = 0
  • update_window_index = update_index[update_window_dims...]
  • full_window_index = [wi0, ..., 0, ..., wiN]。ここで、wiupdate_window_index の個々の要素で、0inserted_window_dimsinput_batching_dims のインデックスに挿入されます。
  • result_index = full_start_index + full_batching_index + full_window_index

results = exec(schedule, inputs)。ここで、

  • schedule は、index_space(updates[0]) の実装定義された並べ替えです。
  • exec([update_index, ...], results) = exec([...], updated_results) ここで:
    • result_indexshape(results...) の範囲内にある場合
    • updates_converted = to_destination_type( updates...[update_index], type(func_inputs(update_computation) [len(func_inputs(update_computation))//2:])... )
    • updated_values = update_computation(results...[result_index], updates_converted)
    • updated_results は、results...[result_index]updated_values... に設定された results のコピーです。
    • それ以外の場合
    • updated_results = results
  • exec([], results) = results

indices_are_sortedtrue の場合、実装では scatter_indicesscatter_dims_to_operand_dims に関して並べ替えられていると想定できます。それ以外の場合、動作は未定義です。より正式には、indices(result) のすべての i1 < i2 について、full_start_index(i1) <= full_start_index(i2) です。

unique_indicestrue の場合、実装は分散されているすべての result_index インデックスが一意であると想定できます。unique_indicestrue で、分散先のインデックスが一意でない場合、動作は未定義です。

入力

ラベル 名前 タイプ 制約
(I1) inputs テンソルの可変数またはテンソルごとの量子化テンソル (C1)、(C2)、(C4 ~ C6)、(C11)、(C13)、(C18)、(C21)、(C23 ~ C24)
(I2) scatter_indices 整数型のテンソル (C4)、(C15)、(C19)、(C22)
(I3) updates テンソルの可変数またはテンソルごとの量子化テンソル (C3-C6)、(C8)
(I4) update_window_dims si64 型の 1 次元のテンソル定数 (C2)、(C4)、(C7~C8)
(I5) inserted_window_dims si64 型の 1 次元テンソル定数 (C2)、(C4)、(C9 ~ C11)
(I6) input_batching_dims si64 型の 1 次元テンソル定数 (C2)、(C4)、(C9)、(C12-13)、(C17-18)、(C20)
(I7) scatter_indices_batching_dims si64 型の 1 次元テンソル定数 (C14 ~ C18)
(I8) scatter_dims_to_operand_dims si64 型の 1 次元テンソル定数 (C19 ~ C21)
(I9) index_vector_dim si64 型の定数 (C4)、(C16)、(C19)、(C22)
(I10) indices_are_sorted i1 型の定数
(I11) unique_indices i1 型の定数
(I12) update_computation 関数 (C23)

出力

名前 タイプ 制約
results テンソルの可変数またはテンソルごとの量子化テンソル (C24 ~ C25)

制約

  • (C1)same(shape(inputs...))
  • (C2)`rank(inputs[0]) = size(update_window_dims) + size(inserted_window_dims)
    • size(input_batching_dims)`.
  • (C3)same(shape(updates...))
  • (C4)shape(updates[0]) = combine(update_scatter_dim_sizes, update_window_dim_sizes) ここで:
    • update_scatter_dim_sizes = shape(scatter_indices)。ただし、index_vector_dim に対応する scatter_indices のディメンション サイズは含まれません。
    • update_window_dim_sizes <= shape(inputs[0])。ただし、inserted_window_dimsinput_batching_dims に対応する inputs[0] のディメンション サイズは含まれません。
    • combine は、update_scatter_dim_sizesupdate_scatter_dims に対応する軸に配置し、update_window_dim_sizesupdate_window_dims に対応する軸に配置します。
  • (C5)0 < size(inputs) = size(updates) = N
  • (C6)element_type(updates...) = element_type(inputs...)
  • (C7)is_unique(update_window_dims) and is_sorted(update_window_dims)
  • (C8)0 <= update_window_dims < rank(updates[0])
  • (C9)is_unique(concatenate(inserted_window_dims, input_batching_dims))
  • (C10)is_sorted(inserted_window_dims)
  • (C11)0 <= inserted_window_dims < rank(inputs[0])
  • (C12)is_sorted(input_batching_dims)
  • (C13)0 <= input_batching_dims < rank(inputs[0]))
  • (C14)is_unique(scatter_indices_batching_dims)
  • (C15)0 <= scatter_indices_batching_dims < rank(scatter_indices)
  • (C16)index_vector_dim not in scatter_indices_batching_dims
  • (C17)size(input_batching_dims) == size(scatter_indices_batching_dims)
  • (C18)dim(inputs[0], input_batching_dims...) = dim(scatter_indices, scatter_indices_batching_dims...)
  • (C19)size(scatter_dims_to_operand_dims) = index_vector_dim < rank(scatter_indices) ? dim(scatter_indices, index_vector_dim) : 1
  • (C20)is_unique(concatenate(scatter_dims_to_operand_dims, input_batching_dims))
  • (C21)0 <= scatter_dims_to_operand_dims < rank(inputs[0])
  • (C22)0 <= index_vector_dim <= rank(scatter_indices)
  • (C23)update_computation の型は (tensor<E0>, ..., tensor<EN-1>, tensor<E0>, ..., tensor<EN-1>) -> (tensor<E0>, ..., tensor<EN-1>) です。ここで、is_promotable(element_type(inputs[i]), Ei)
  • (C24)shape(inputs...) = shape(results...)
  • (C25)[0,N) 内のすべての i に対して element_type(results[i]) = Ei

// %input: [
//          [
//           [[1, 2], [3, 4], [5, 6], [7, 8]],
//           [[9, 10],[11, 12], [13, 14], [15, 16]],
//           [[17, 18], [19, 20], [21, 22], [23, 24]]
//          ],
//          [
//           [[25, 26], [27, 28], [29, 30], [31, 32]],
//           [[33, 34], [35, 36], [37, 38], [39, 40]],
//           [[41, 42], [43, 44], [45, 46], [47, 48]]
//          ]
//         ]
// %scatter_indices: [
//                    [
//                     [[0, 0], [1, 0], [2, 1]],
//                     [[0, 1], [1, 1], [0, 9]]
//                    ],
//                    [
//                     [[0, 0], [2, 1], [2, 2]],
//                     [[1, 2], [0, 1], [1, 0]]
//                    ]
//                   ]
// %update: [
//           [
//            [[1, 1], [1, 1], [1, 1]],
//            [[1, 1], [1, 1], [1, 1]]
//           ],
//           [
//            [[1, 1], [1, 1], [1, 1]],
//            [[1, 1], [1, 1], [1, 1]]
//           ]
//          ]
%result = "stablehlo.scatter"(%input, %scatter_indices, %update) ({
  ^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
    %0 = "stablehlo.add"(%arg0, %arg1) : (tensor<i64>, tensor<i64>) -> tensor<i64>
    "stablehlo.return"(%0) : (tensor<i64>) -> ()
}) {
  scatter_dimension_numbers = #stablehlo.scatter<
    update_window_dims = [3, 4],
    inserted_window_dims = [1],
    input_batching_dims = [0],
    scatter_indices_batching_dims = [1],
    scatter_dims_to_operand_dims = [2, 1],
    index_vector_dim = 3>,
  indices_are_sorted = false,
  unique_indices = false
} : (tensor<2x3x4x2xi64>, tensor<2x2x3x2xi64>, tensor<2x2x3x2x2xi64>) -> tensor<2x3x4x2xi64>
// %result: [
//           [
//            [[3, 4], [6, 7], [6, 7], [7, 8]],
//            [[9, 10],[11, 12], [15, 16], [17, 18]],
//            [[17, 18], [19, 20], [22, 23], [24, 25]]
//           ],
//           [
//            [[25, 26], [28, 29], [30, 31], [31, 32]],
//            [[35, 36], [38, 39], [38, 39], [39, 40]],
//            [[41, 42], [44, 45], [46, 47], [47, 48]]
//           ]
//          ]

その他の例

選択

セマンティクス

result テンソルを生成します。各要素は、pred の対応する要素の値に基づいて on_true テンソルまたは on_false テンソルから選択されます。正式には result[result_index] = pred_element ? on_true[result_index] : on_false[result_index] です。ここで pred_element = rank(pred) = 0 ? pred[] : pred[result_index] は量子化された型の場合は、dequantize_select_quantize(pred, on_true, on_false, type(result)) を実行します。

入力

ラベル 名前 タイプ 制約
(I1) pred i1 型のテンソル (C1)
(I2) on_true テンソルまたはテンソルごとの量子化テンソル (C1-C2)
(I3) on_false テンソルまたはテンソルごとの量子化テンソル (C2)

出力

名前 タイプ 制約
result テンソルまたはテンソルごとの量子化テンソル (C2)

制約

  • (C1)rank(pred) = 0 or shape(pred) = shape(on_true)
  • (C2)baseline_type(on_true) = baseline_type(on_false) = baseline_type(result)

// %pred: [[false, true], [true, false]]
// %on_true: [[1, 2], [3, 4]]
// %on_false: [[5, 6], [7, 8]]
%result = "stablehlo.select"(%pred, %on_true, %on_false) : (tensor<2x2xi1>, tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
// %result: [[5, 2], [3, 8]]

その他の例

select_and_scatter

セマンティクス

select を使用した input テンソルの reduce_window の結果に基づいて、scatter を使用して source テンソルから値を散布し、result テンソルを生成します。

次の図は、具体的な例を使用して、result の要素が operandsource からどのように計算されるかを示しています。

select_and_scatter

よりフォーマルな表現:

  • selected_values = reduce_window_without_init(...) は、次の入力に置き換えます。

    • inputs = [operand].
    • window_dimensionswindow_stridespadding はそのまま使用されます。
    • base_dilations = windows_dilations = 1
    • body は次のように定義されます。
    def body(arg0: tensor<E>, arg1: tensor<E>) -> tensor<E>:
      return select(arg0, arg1) ? arg0 : arg1;
    

    ここで、E = element_type(operand)reduce_window_without_initreduce_window とまったく同じように機能しますが、基になる reduceschedule削減を参照)に init 値が含まれないことが異なります。対応するウィンドウに値がない場合の動作は現在未定義です(#731)。

  • result[result_index] = reduce([source_values], [init_value], [0], scatter) ここで:

    • source_values = [source[source_index] for source_index in source_indices]
    • selected_index(source_index) = operand_index: selected_values[source_index]operand_indexoperand 要素がある場合。
    • source_indices = [source_index for source_index in indices(source) if selected_index(source_index) = result_index]

入力

ラベル 名前 タイプ 制約
(I1) operand テンソルまたはテンソルごとの量子化テンソル (C1~C4)、(C6)、(C8~C11)
(I2) source テンソルまたはテンソルごとの量子化テンソル (C1)、(C2)
(I3) init_value 0 次元テンソルまたはテンソルごとの量子化テンソル (C3)
(I4) window_dimensions si64 型の 1 次元のテンソル定数 (C2)、(C4)、(C5)
(I5) window_strides si64 型の 1 次元のテンソル定数 (C2)、(C6)、(C7)
(I6) padding si64 型の 2 次元のテンソル定数 (C2)、(C8)
(I7) select 関数 (C9)
(I8) scatter 関数 (C10)

出力

名前 タイプ 制約
result テンソルまたはテンソルごとの量子化テンソル (C11-C12)

制約

  • (C1)element_type(operand) = element_type(source)
  • (C2)shape(source) = num_windows。ここで:
    • padded_operand_shape = padding[:, 0] + shape(operand) + padding[:, 1]
    • is_empty_window = padded_operand_shape = 0 || window_dimensions > padded_operand_shape
    • num_windows = is_empty_window ? 0 : floor((padded_operand_shape - window_dimensions) / window_strides) + 1
  • (C3)element_type(init_value) = element_type(operand)
  • (C4)size(window_dimensions) = rank(operand)
  • (C5)0 < window_dimensions
  • (C6)size(window_strides) = rank(operand)
  • (C7)0 < window_strides
  • (C8)shape(padding) = [rank(operand), 2]
  • (C9)select の型は (tensor<E>, tensor<E>) -> tensor<i1>E = element_type(operand))です。
  • (C10)scatter の型は (tensor<E>, tensor<E>) -> tensor<E> です。is_promotable(element_type(operand), E)
  • (C11)shape(operand) = shape(result)
  • (C12)element_type(result) = E

// %operand: [[1, 5], [2, 5], [3, 6], [4, 4]]
// %source: [[5, 6], [7, 8]]
// %init_value: 0
%result = "stablehlo.select_and_scatter"(%operand, %source, %init_value) ({
  ^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
    %0 = "stablehlo.compare"(%arg0, %arg1) {
      comparison_direction = #stablehlo<comparison_direction GE>
    } : (tensor<i64>, tensor<i64>) -> tensor<i1>
    "stablehlo.return"(%0) : (tensor<i1>) -> ()
}, {
  ^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
    %0 = "stablehlo.add"(%arg0, %arg1) : (tensor<i64>, tensor<i64>) -> tensor<i64>
    "stablehlo.return"(%0) : (tensor<i64>) -> ()
}) {
  window_dimensions = array<i64: 3, 1>,
  window_strides = array<i64: 2, 1>,
  padding = dense<[[0, 1], [0, 0]]> : tensor<2x2xi64>
} : (tensor<4x2xi64>, tensor<2x2xi64>, tensor<i64>) -> tensor<4x2xi64>
// %result: [[0, 0], [0, 0], [5, 14], [7, 0]]

その他の例

送信

セマンティクス

inputs をチャンネル channel_id に送信し、result トークンを生成します。

is_host_transfertrue の場合、オペレーションはホストにデータを転送します。そうでない場合は、別のデバイスにデータを転送します。つまり、実装は定義されます。このフラグは channel_type で提供される情報と重複するため、今後はどちらか一方のみを保持する予定です(#666)。

入力

ラベル 名前 タイプ 制約
(I1) inputs テンソルまたは量子化テンソルの可変数
(I2) token token
(I3) channel_id si64 型の定数
(I4) channel_type DEVICE_TO_DEVICEDEVICE_TO_HOST の列挙型 (C1)
(I5) is_host_transfer i1 型の定数 (C1)

出力

名前 タイプ
result token

制約

  • (C1)channel_type は次のように定義されます。
    • is_host_transfer = true の場合、DEVICE_TO_HOST
    • それ以外の場合は DEVICE_TO_DEVICE

%result = "stablehlo.send"(%operand, %token) {
  channel_handle = #stablehlo.channel_handle<handle = 1, type = 2>,
  is_host_transfer = true
} : (tensor<2x2xi64>, !stablehlo.token) -> !stablehlo.token

その他の例

shift_left

セマンティクス

lhs テンソルを rhs ビット分要素ごとに左シフトし、result テンソルを生成します。

入力

ラベル 名前 タイプ 制約
(I1) lhs 整数型のテンソル (C1)
(I2) rhs 整数型のテンソル (C1)

出力

名前 タイプ 制約
result 整数型のテンソル (C1)

制約

  • (C1)type(lhs) = type(rhs) = type(result)

// %lhs: [-1, 0, 1]
// %rhs: [1, 2, 3]
%result = "stablehlo.shift_left"(%lhs, %rhs): (tensor<3xi64>, tensor<3xi64>) -> tensor<3xi64>
// %result: [-2, 0, 8]

その他の例

shift_right_arithmetic

セマンティクス

lhs テンソルに対して rhs ビット数だけ要素単位の右シフト演算を実行し、result テンソルを生成します。

入力

ラベル 名前 タイプ 制約
(I1) lhs 整数型のテンソル (C1)
(I2) rhs 整数型のテンソル (C1)

出力

名前 タイプ 制約
result 整数型のテンソル (C1)

制約

  • (C1)type(lhs) = type(rhs) = type(result)

// %lhs: [-1, 0, 8]
// %rhs: [1, 2, 3]
%result = "stablehlo.shift_right_arithmetic"(%lhs, %rhs): (tensor<3xi64>, tensor<3xi64>) -> tensor<3xi64>
// %result: [-1, 0, 1]

その他の例

shift_right_logical

セマンティクス

lhs テンソルを rhs ビット数だけ要素ごとに論理右シフトし、result テンソルを生成します。

入力

ラベル 名前 タイプ 制約
(I1) lhs 整数型のテンソル (C1)
(I2) rhs 整数型のテンソル (C1)

出力

名前 タイプ 制約
result 整数型のテンソル (C1)

制約

  • (C1)type(lhs) = type(rhs) = type(result)

// %lhs: [-1, 0, 8]
// %rhs: [1, 2, 3]
%result = "stablehlo.shift_right_logical"(%lhs, %rhs): (tensor<3xi64>, tensor<3xi64>) -> tensor<3xi64>
// %result: [9223372036854775807, 0, 1]

その他の例

署名

セマンティクス

要素ごとの operand の符号を返して、result テンソルを生成します。より正式には、各要素 x のセマンティクスは、次のように Python 構文を使用して表現できます。

def sign(x):
  if is_integer(x):
    if compare(x, 0, LT, SIGNED): return -1
    if compare(x, 0, EQ, SIGNED): return 0
    return 1
  elif is_float(x):
    if is_nan(x): return NaN
    if compare(x, -0.0, EQ, FLOAT): return -0.0
    if compare(x, +0.0, EQ, FLOAT): return +0.0
    if compare(x, 0.0, LT, FLOAT): return -1.0
    return 1.0
  elif is_complex(x):
    if is_nan(real(x)) or is_nan(imag(x)): return (NaN, NaN)
    if compare(x, (0.0, 0.0), EQ, FLOAT): return (0.0, 0.0)
    return divide(x, convert(abs(x), type(x)))

量子化型の場合は、dequantize_op_quantize(sign, operand, type(result)) を実行します。

入力

ラベル 名前 タイプ 制約
(I1) operand 符号付き整数、浮動小数点数、複素型のテンソル、またはテンソルごとの量子化テンソル (C1)

出力

名前 タイプ 制約
result 符号付き整数型、浮動小数点型、複合型のテンソル、またはテンソルごとの量子化テンソル (C1)

制約

  • (C1)baseline_type(operand) = baseline_type(result)

// Logical values: +NaN, -1.0, -0.0, +0.0, 1.0
// operand: [0x7FFFFFFFFFFFFFFF, -1.0, -0.0, 0.0, 1.0]
%result = "stablehlo.sign"(%operand) : (tensor<5xf64>) -> tensor<5xf64>
// Logical values: +NaN, -1.0, -0.0, +0.0, 1.0
// %result: [0x7FFFFFFFFFFFFFFF, -1.0, -0.0, 0.0, 1.0]

その他の例

サイン

セマンティクス

operand テンソルで要素ごとの正弦演算を実行し、result テンソルを生成します。要素のタイプに応じて、次の操作を行います。

  • 浮動小数点数の場合: IEEE-754 の sin
  • 複素数の場合: 複素正弦。
  • 量子化された型の場合: dequantize_op_quantize(sine, operand, type(result))

入力

ラベル 名前 タイプ 制約
(I1) operand 浮動小数点型または複合型のテンソル、またはテンソルごとの量子化テンソル (C1)

出力

名前 タイプ 制約
result 浮動小数点型または複合型のテンソル、またはテンソルごとの量子化テンソル (C1)

制約

  • (C1)baseline_type(operand) = baseline_type(result)

// %operand: [
//            [0.0, 1.57079632],       // [0, pi/2]
//            [3.14159265, 4.71238898] // [pi, 3pi/2]
//           ]
%result = "stablehlo.sine"(%operand) : (tensor<2x2xf32>) -> tensor<2x2xf32>
// %result: [[0.0, 1.0], [0.0, -1.0]]

その他の例

slice

セマンティクス

静的に計算された開始インデックスを使用して operand からスライスを取り出し、result テンソルを生成します。start_indices には各ディメンションのスライスの開始インデックスが、limit_indices には各ディメンションのスライスの終了インデックス(これを含まない)が含まれ、strides には各ディメンションのストライドが含まれます。

より正式には、result[result_index] = operand[operand_index]operand_index = start_indices + result_index * strides

入力

ラベル 名前 タイプ 制約
(I1) operand テンソルまたはテンソルごとの量子化テンソル (C1 ~ C3)、(C5)
(I2) start_indices si64 型の 1 次元テンソル定数 (C2)、(C3)、(C5)
(I3) limit_indices si64 型の 1 次元テンソル定数 (C2)、(C3)、(C5)
(I4) strides si64 型の 1 次元テンソル定数 (C2)、(C4)

出力

名前 タイプ 制約
result テンソルまたはテンソルごとの量子化テンソル (C1)、(C5)

制約

  • (C1)element_type(operand) = element_type(result)
  • (C2)size(start_indices) = size(limit_indices) = size(strides) = rank(operand)
  • (C3)0 <= start_indices <= limit_indices <= shape(operand)
  • (C4)0 < strides
  • (C5)shape(result) = ceil((limit_indices - start_indices) / strides)

// %operand: [
//            [0, 0, 0, 0],
//            [0, 0, 1, 1],
//            [0, 0, 1, 1]
//           ]
%result = "stablehlo.slice"(%operand) {
  start_indices = array<i64: 1, 2>,
  limit_indices = array<i64: 3, 4>,
  strides = array<i64: 1, 1>
} : (tensor<3x4xi64>) -> tensor<2x2xi64>
// % result: [
//            [1, 1],
//            [1, 1]
//           ]

その他の例

並べ替え

セマンティクス

comparator に従って、ディメンション dimension に沿って inputs の 1 次元スライスを並べ替え、results を生成します。

他のオペレーションの同様の入力とは異なり、dimension では負の値を使用できます。セマンティクスは次のとおりです。今後、一貫性のため、この操作が禁止される可能性があります(#1377)。

is_stable が true の場合、並べ替えは安定します。つまり、比較演算子によって等しいとみなされる要素の相対順序は保持されます。入力が 1 つの場合、2 つの要素 e1e2 は、comparator(e1, e2) = comparator(e2, e1) = false の場合にのみ、比較演算子によって等しいと見なされます。これが複数の入力に一般化される仕組みについては、以下の形式化をご覧ください。

より正式に、index_space(results[0]) 内のすべての result_index について、次のように定義します。

  • adjusted_dimension = dimension >= 0 ? dimension : rank(inputs[0]) + dimension
  • result_slice = [ri0, ..., :, ..., riR-1]。ここで、riNresult_index の個々の要素であり、:adjusted_dimension に挿入されます。
  • inputs_together = (inputs[0]..., ..., inputs[N-1]...)
  • results_together[result_slice] = sort(inputs_together[result_slice], comparator_together)
  • ここで、sort は 1 次元スライスを降順以外の順序で並べ替えます。これは、左側の引数が右側の 2 番目の引数より小さい場合に comparator_togethertrue を返すことを想定しています。
  • def comparator_together(lhs_together, rhs_together):
      args = []
      for (lhs_el, rhs_el) in zip(lhs_together, rhs_together):
        args.append(lhs_el)
        args.append(rhs_el)
      return comparator(*args)
    
  • (results[0]..., ..., results[N-1]...) = results_together

入力

ラベル 名前 タイプ 制約
(I1) inputs テンソルの可変数またはテンソルごとの量子化テンソル (C1 ~ C5)
(I2) dimension si64 型の定数 (C4)
(I3) is_stable i1 型の定数
(I4) comparator 関数 (C5)

出力

名前 タイプ 制約
results テンソルの可変数またはテンソルごとの量子化テンソル (C2)、(C3)

制約

  • (C1)0 < size(inputs)
  • (C2)type(inputs...) = type(results...)
  • (C3)same(shape(inputs...) + shape(results...))
  • (C4)-R <= dimension < RR = rank(inputs[0]))。
  • (C5)comparator の型は (tensor<E1>, tensor<E1>, ..., tensor<EN-1>, tensor<EN-1>) -> tensor<i1> です。ここで、Ei = element_type(inputs[i])

// %input0 = [[1, 2, 3], [3, 2, 1]]
// %input1 = [[3, 2, 1], [1, 2, 3]]
%result0, %result1 = "stablehlo.sort"(%input0, %input1) ({
  ^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>, %arg2: tensor<i64>, %arg3: tensor<i64>):
    %predicate = "stablehlo.compare"(%arg0, %arg1) {
      comparison_direction = #stablehlo<comparison_direction GT>
    } : (tensor<i64>, tensor<i64>) -> tensor<i1>
    "stablehlo.return"(%predicate) : (tensor<i1>) -> ()
}) {
  dimension = 0 : i64,
  is_stable = true
} : (tensor<2x3xi64>, tensor<2x3xi64>) -> (tensor<2x3xi64>, tensor<2x3xi64>)
// %result0 = [[3, 2, 3], [1, 2, 1]]
// %result1 = [[1, 2, 1], [3, 2, 3]]

その他の例

sqrt

セマンティクス

operand テンソルで要素ごとの平方根演算を実行し、result テンソルを生成します。要素のタイプに応じて、次の操作を行います。

  • 浮動小数点数の場合: IEEE-754 の squareRoot
  • 複素数の場合: 複素平方根。
  • 量子化された型の場合: dequantize_op_quantize(sqrt, operand, type(result))

入力

ラベル 名前 タイプ 制約
(I1) operand 浮動小数点型または複合型のテンソル、またはテンソルごとの量子化テンソル (C1)

出力

名前 タイプ 制約
result 浮動小数点型または複合型のテンソル、またはテンソルごとの量子化テンソル (C1)

制約

  • (C1)baseline_type(operand) = baseline_type(result)

// %operand: [[0.0, 1.0], [4.0, 9.0]]
%result = "stablehlo.sqrt"(%operand) : (tensor<2x2xf32>) -> tensor<2x2xf32>
// %result: [[0.0, 1.0], [2.0, 3.0]]

その他の例

subtract

セマンティクス

2 つのテンソル lhsrhs の要素ごとの減算を行い、result テンソルを生成します。要素のタイプに応じて、次の処理を行います。

  • 整数の場合: 整数の減算。
  • 浮動小数点数: IEEE-754 の subtraction
  • 複素数の場合: 複素減算。
  • 量子化された型の場合:
    • dequantize_op_quantize(subtract, lhs, rhs, type(result))

入力

ラベル 名前 タイプ 制約
(I1) lhs 整数型、浮動小数点型、複合型のテンソル、またはテンソルごとの量子化テンソル (C1)
(I2) rhs 整数、浮動小数点、複合型のテンソル、またはテンソルごとの量子化テンソル (C1)

出力

名前 タイプ 制約
result 整数型、浮動小数点型、複合型のテンソル、またはテンソルごとの量子化テンソル (C1)

制約

  • (C1)baseline_type(lhs) = baseline_type(rhs) = baseline_type(result)

// %lhs: [[6, 8], [10, 12]]
// %rhs: [[5, 6], [7, 8]]
%result = "stablehlo.subtract"(%lhs, %rhs) : (tensor<2x2xf32>, tensor<2x2xf32>) -> (tensor<2x2xf32>)
// %result: [[1, 2], [3, 4]]

その他の例

tan

セマンティクス

operand テンソルに対して要素ごとのタンジェント演算を実行し、result テンソルを生成します。要素のタイプに応じて、次の操作を行います。

  • 浮動小数点数の場合: IEEE-754 の tan
  • 複素数の場合: 複素タンジェント。
  • 量子化型の場合: dequantize_op_quantize(tan, operand, type(result))

入力

ラベル 名前 タイプ 制約
(I1) operand 浮動小数点型または複合型のテンソル、またはテンソルごとの量子化テンソル (C1)

出力

名前 タイプ 制約
result 浮動小数点型または複合型のテンソル、またはテンソルごとの量子化テンソル (C1)

制約

  • (C1)baseline_type(operand) = baseline_type(result)

// %operand: [
//            [0.0, 1.57079632],       // [0, pi/2]
//            [3.14159265, 4.71238898] // [pi, 3pi/2]
//           ]
%result = "stablehlo.tan"(%operand) : (tensor<2x2xf64>) -> tensor<2x2xf64>
// %result: [
//           [0.0, 1.63312e+16],
//           [0.0, 5.44375e+15]
//          ]

その他の例

tanh

セマンティクス

operand テンソルに対して要素単位の双曲線正接演算を実行し、result テンソルを生成します。要素のタイプに応じて、次の操作を行います。

  • 浮動小数点数の場合: IEEE-754 の tanh
  • 複素数の場合: 複素双曲線正接。
  • 量子化された型の場合:
    • dequantize_op_quantize(tanh, operand, type(result))

入力

ラベル 名前 タイプ 制約
(I1) operand 浮動小数点型または複合型のテンソル、またはテンソルごとの量子化テンソル (C1)

出力

名前 タイプ 制約
result 浮動小数点型または複合型のテンソル、またはテンソルごとの量子化テンソル (C1)

制約

  • (C1)baseline_type(operand) = baseline_type(result)

// %operand: [-1.0, 0.0, 1.0]
%result = "stablehlo.tanh"(%operand) : (tensor<3xf32>) -> tensor<3xf32>
// %result: [-0.76159416, 0.0, 0.76159416]

その他の例

行 / 列の入れ替え

セマンティクス

permutation を使用して operand テンソルの次元を並べ替え、result テンソルを生成します。より正式には、result[result_index] = operand[operand_index] で、result_index[d] = operand_index[permutation[d]] です。

入力

ラベル 名前 タイプ 制約
(I1) operand テンソルまたは量子化テンソル (C1~C4)
(I2) permutation si64 型の 1 次元テンソル定数 (C2 ~ C4)

出力

名前 タイプ 制約
result テンソルまたは量子化テンソル (C1)、(C3 ~ C4)

制約

  • (C1)element_type(result) は次のように定義されます。
    • element_type(operand)!is_per_axis_quantized(operand) の場合)。
    • element_type(operand)。ただし、quantization_dimension(operand)quantization_dimension(result) が異なる場合があります。
  • (C2)permutationrange(rank(operand)) の並べ替えです。
  • (C3)shape(result) = dim(operand, permutation...)
  • (C4)is_per_axis_quantized(result) の場合は quantization_dimension(operand) = permutation(quantization_dimension(result)) です。

// %operand: [
//            [[1,2], [3,4], [5,6]],
//            [[7,8], [9,10], [11,12]]
//           ]
%result = "stablehlo.transpose"(%operand) {
  permutation = array<i64: 2, 1, 0>
} : (tensor<2x3x2xi32>) -> tensor<2x3x2xi32>
// %result: [
//           [[1,7], [3,9], [5,11]],
//           [[2,8], [4,10], [6,12]]
//          ]

その他の例

triangular_solve

セマンティクス

下三角または上三角係数行列を持つ連立一次方程式のバッチを解きます。

より正式には、ab が指定されている場合、result[i0, ..., iR-3, :, :]left_sidetrue の場合の op(a[i0, ..., iR-3, :, :]) * x = b[i0, ..., iR-3, :, :] の解、または left_sidefalse の場合の x * op(a[i0, ..., iR-3, :, :]) = b[i0, ..., iR-3, :, :] です。変数 x は、op(a)transpose_a によって決定されます。transpose_a は次のいずれかです。

  • NO_TRANSPOSE: a をそのまま使用してオペレーションを実行します。
  • TRANSPOSE: a の転置でオペレーションを実行します。
  • ADJOINT: a の共役転置に対して演算を実行します。

入力データは、lowertrue の場合、a の下三角からのみ読み取られます。それ以外の場合は、a の上三角から読み取られます。出力データは同じ三角形に返されます。もう 1 つの三角形の値は実装で定義されます。

unit_diagonal が true の場合、実装は a の対角要素が 1 に等しいと想定できます。それ以外の場合は、動作は未定義です。

量子化された型の場合は、dequantize_op_quantize(lambda x, y: triangular_solve(x, y, left_side, lower, unit_diagonal, transpose_a), a, b, type(result)) を実行します。

入力

ラベル 名前 タイプ 制約
(I1) a 浮動小数点型または複合型のテンソル、またはテンソルごとの量子化テンソル (C1~C3)
(I2) b 浮動小数点型または複合型のテンソル、またはテンソルごとの量子化テンソル (C1~C4)
(I3) left_side i1 型の定数 (C3)
(I4) lower i1 型の定数
(I5) unit_diagonal i1 型の定数
(I6) transpose_a NO_TRANSPOSETRANSPOSEADJOINT の列挙型

出力

名前 タイプ 制約
result 浮動小数点型または複合型のテンソル、またはテンソルごとの量子化テンソル (C1)

制約

  • (C1)baseline_element_type(a) = baseline_element_type(b)
  • (C2)2 <= rank(a) = rank(b) = R
  • (C3)shape(a)shape(b) の関係は次のように定義されます。
    • shape(a)[:-3] = shape(b)[:-3]
    • dim(a, -2) = dim(a, -1) = dim(b, left_side ? -2 : -1)
  • (C4)baseline_type(b) = baseline_type(result)

// %a = [
//       [1.0, 0.0, 0.0],
//       [2.0, 4.0, 0.0],
//       [3.0, 5.0, 6.0]
//      ]
// %b = [
//       [2.0, 0.0, 0.0],
//       [4.0, 8.0, 0.0],
//       [6.0, 10.0, 12.0]
//      ]
%result = "stablehlo.triangular_solve"(%a, %b) {
  left_side = true,
  lower = true,
  unit_diagonal = false,
  transpose_a = #stablehlo<transpose NO_TRANSPOSE>
} : (tensor<3x3xf32>, tensor<3x3xf32>) -> tensor<3x3xf32>
// %result: [
//           [2.0, 0.0, 0.0],
//           [0.0, 2.0, 0.0],
//           [0.0, 0.0, 2.0]
//          ]

tuple

セマンティクス

val から result タプルを生成します。

入力

ラベル 名前 タイプ 制約
(I1) val 可変長の値の数 (C1)

出力

名前 タイプ 制約
result tuple (C1)

制約

  • (C1)result の型は tuple<E0, ..., EN-1> で、Ei = type(val[i]) です。

// %val0: [1.0, 2.0]
// %val1: (3)
%result = "stablehlo.tuple"(%val0, %val1) : (tensor<2xf32>, tuple<tensor<i32>>) -> tuple<tensor<2xf32>, tuple<tensor<i32>>>
// %result: ([1.0, 2.0], (3))

その他の例

uniform_dequantize

セマンティクス

operand 型で定義された量子化パラメータに従って、量子化テンソル operand を浮動小数点テンソル result に要素ごとに変換します。

正式には result = dequantize(operand) です。

入力

ラベル 名前 タイプ 制約
(I1) operand 量子化テンソル (C1)、(C2)

出力

名前 タイプ 制約
result 浮動小数点型のテンソル (C1)、(C2)

制約

  • (C1)shape(operand) = shape(result)
  • (C2)element_type(result) = expressed_type(operand)

// %operand: [10, 10]
%result = "stablehlo.uniform_dequantize"(%operand) : (tensor<2x!quant.uniform<i8:f32:0, {0.1:-30,0.5:-20}>>) -> tensor<2xf32>
// %result: [4.0, 15.0]

uniform_quantize

セマンティクス

result 型で定義された量子化パラメータに従って、浮動小数点テンソルまたは量子化テンソル operand を量子化テンソル result に要素単位で変換します。

よりフォーマルな表現では

  • is_float(operand) の場合:
    • result = quantize(operand, type(result))
  • is_quantized(operand) の場合:
    • float_result = dequantize(operand)
    • result = quantize(float_result, type(result))

入力

ラベル 名前 タイプ 制約
(I1) operand 浮動小数点型または量子化型のテンソル (C1)、(C2)

出力

名前 タイプ 制約
result 量子化テンソル (C1)、(C2)

制約

  • (C1)shape(operand) = shape(result)
  • (C2)expressed_type(result) = is_float(operand) ? element_type(operand) : expressed_type(operand)

// %operand: [4.0, 15.0]
%result = "stablehlo.uniform_quantize"(%operand) : (tensor<2xf32>) -> tensor<2x!quant.uniform<i8:f32:0, {0.1:-30,0.5:-20}>>
// %result: [10, 10]

// %operand: [10, 10]
%result = "stablehlo.uniform_quantize"(%operand) : (tensor<2x!quant.uniform<i8:f32:0, {0.1:-30,0.5:-20}>>) -> tensor<2x!quant.uniform<i8:f32:0, {0.1:-20,0.2:-30}>>
// %result: [20, 45]

しばらく

セマンティクス

cond 関数が true を出力している間に、body 関数を 0 回以上実行すると、出力が生成されます。より正式には、セマンティクスは Python 構文を使用して次のように表現できます。

internal_state = operand
while cond(*internal_state):
  internal_state = body(*internal_state)
results = internal_state

無限ループの動作は未定です(#383)。

入力

ラベル 名前 タイプ 制約
(I1) operand テンソル、量子化テンソル、トークンの可変数 (C1~C3)
(I2) cond 関数 (C1)
(I3) body 関数 (C2)

出力

名前 タイプ 制約
results テンソル、量子化テンソル、トークンの可変数 (C3)

制約

  • (C1)cond の型は (T0, ..., TN-1) -> tensor<i1> です。ここで、Ti = type(operand[i])
  • (C2)body の型は (T0, ..., TN-1) -> (T0, ..., TN-1) です。ここで、Ti = type(operand[i])
  • (C3)type(results...) = type(operand...)

// %init_i: 1
// %init_sum: 0
// %one: 1
// %ten: 10
%results0, %results1 = "stablehlo.while"(%init_i, %init_sum) ({
  ^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
    %cond = "stablehlo.compare"(%arg0, %ten) {
      comparison_direction = #stablehlo<comparison_direction LT>
    } : (tensor<i64>, tensor<i64>) -> tensor<i1>
    stablehlo.return %cond : tensor<i1>
  }, {
  ^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
    %new_sum = stablehlo.add %arg1, %one : tensor<i64>
    %new_i = stablehlo.add %arg0, %one : tensor<i64>
    stablehlo.return %new_i, %new_sum : tensor<i64>, tensor<i64>
}) : (tensor<i64>, tensor<i64>) -> (tensor<i64>, tensor<i64>)
// %results0: 10
// %results1: 10

その他の例

xor

セマンティクス

2 つのテンソル lhsrhs の要素ごとの XOR を実行し、result テンソルを生成します。要素のタイプに応じて、次の操作を行います。

  • ブール値の場合: 論理 XOR。
  • 整数の場合: ビット単位の XOR。

入力

ラベル 名前 タイプ 制約
(I1) lhs ブール型または整数型のテンソル (C1)
(I2) rhs ブール値型または整数型のテンソル (C1)

出力

名前 タイプ 制約
result ブール型または整数型のテンソル (C1)

制約

  • (C1)type(lhs) = type(rhs) = type(result)

// Bitwise operation with with integer tensors
// %lhs: [[1, 2], [3, 4]]
// %rhs: [[5, 6], [7, 8]]
%result = "stablehlo.xor"(%lhs, %rhs) : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
// %result: [[4, 4], [4, 12]]

// Logical operation with with boolean tensors
// %lhs: [[false, false], [true, true]]
// %rhs: [[false, true], [false, true]]
%result = "stablehlo.xor"(%lhs, %rhs) : (tensor<2x2xi1>, tensor<2x2xi1>) -> tensor<2x2xi1>
// %result: [[false, true], [true, false]]

その他の例

方言の相互運用

現時点では、StableHLO プログラムには、StableHLO で定義されていないオペレーションが含まれている場合があります。

モジュール、関数、呼び出し、リターン

StableHLO は、ModuleOp、FuncOp、CallOp、ReturnOp にアップストリーム MLIR オペレーションを使用します。これは、既存の MLIR 機構との相互運用性を高めるために行われました。これは、FuncOp と ModuleOp をターゲットとする多くの有用なパスが記述されており、多くのコンパイル パイプラインは、これらの op が存在することを想定しているためです。これらの op には、完全な互換性の保証が適用されます。これらのオペレーションが互換性のない方法で変更された場合(削除など)、互換性を維持するために StableHLO と同等のオペレーションが追加されます。

CHLO

CHLO オペセットには、StableHLO に分解される上位レベルのオペレーションが含まれています。現在のところ、CHLO の互換性は保証されていません。互換性を保証するには、シリアル化の前に chlo-legalize-to-stablehlo パスを使用する必要があります。

シェイプ オペレーション

動的な StableHLO プログラムでコア MLIR 言語の特定のオペレーションを使用して、シェイプ計算を実行することは、コミュニティで一般的なユースケースです。一般的には、shape_ofnum_elements などの shape 言語演算、dimfrom_elements などの tensor 言語演算、組み込みの index 型演算があります。

Dynamism RFC > O2 では、これらはサポート範囲外とされています。ただし、相互運用性のために index 型の一部がサポートされています。これらの演算や型には互換性は保証されません。shape-legalize-to-stablehlo パスを使用して、これらのオペレーションを完全にサポートされている StableHLO オペレーションに変換できます。

非推奨のオペレーション

MHLO から継承された StableHLO オペレーションがいくつかあります。これらは非推奨で、StableHLO から移行される予定です。これらの削除の詳細については、StableHLO v1.0 のクリーンアップ #2283 をご覧ください。これらのサポート終了に関するトラッカーの問題は #2340 です。

これらのオペレーションは、次のカテゴリに分類されます。

  • StableHLO オペレーションの「HLO に含まれない」カテゴリ - 最初は StableHLO オペセットの一部でしたが、後で適切ではないと判断されました。broadcastcreate_tokencross-replica-sumdoteinsumtorch_index_selectunary_einsum#3)。
  • 使用されていないオペレーション - これらのオペレーションは、ある時点では有用だったかもしれませんが、オペレーションが十分に開発されていないか、これらのオペレーションを使用するパイプラインがリファクタリングされ、オペレーションが不要になっています。これには、maptuple#598)、get_tuple_elementrngcomplex の比較 #560、畳み込み window_reversal#1181)が含まれます。

これらのオペレーションの一部は、既存のオペレーション(broadcastcreate_tokencross-replica-sumdotunary_einsum)を使用して表現できるため、簡単に削除できます。これらのオペレーションは、既存の互換性期間(6 か月)が経過すると削除されます。その他のオペレーションは、削除の検討中です(einsumget_tuple_elementmaprng torch_index_selecttuplecomplex 比較、window_reversal)。コミュニティからのフィードバックに基づいて、これらのオペレーションは削除されるか、完全なサポートで仕様に追加されます。これらのオペレーションの将来が判明するまでは、互換性が保証されるのは 6 か月間のみです。

実行

順次実行

StableHLO プログラムを実行するには、main 関数に入力値を指定し、出力値を計算します。関数の出力値は、対応する return オペレーションにルートを持つオペレーションのグラフを実行することで計算されます。

実行順序は、データフローに沿っている限り(つまり、オペレーションが使用前に実行される場合)、実装で定義されます。StableHLO では、副作用のある演算はすべて 1 つのトークンを消費し、1 つのトークンを生成します(複数のトークンは after_all を介して 1 つのトークンに多重化できます)。そのため、副作用の実行順序もデータフローに沿っています。たとえば、次のプログラムでは、%0%1%2return%1%0%2return の 2 つの実行順序が可能です。

func.func @main() -> tensor<f64> {
  %0 = stablehlo.constant dense<1.0> : tensor<f64>
  %1 = stablehlo.constant dense<2.0> : tensor<f64>
  %2 = stablehlo.add %0, %1 : tensor<f64>
  return %2 : tensor<f64>
}

より正式には、StableHLO プロセスは、1)StableHLO プログラム、2)オペレーションのステータス(まだ実行されていない、すでに実行されている)、3)プロセスが処理中の中間値を組み合わせたものです。このプロセスは、main 関数への入力値から始まり、オペレーションのステータスと中間値を更新するオペレーションのグラフを経て、出力値で終了します。形式化は未定です(#484)。

並列実行

StableHLO プログラムは並行して実行でき、どちらも ui32 型の num_partitionsnum_replicas の 2D プロセス グリッドに編成できます。

StableHLO プロセス グリッドでは、StableHLO プロセスの num_replicas * num_partitions が同時に実行されています。各プロセスには一意の process_id = (replica_id, partition_id) があります。replica_ids = range(num_replicas)replica_idpartition_ids = range(num_partitions)partition_id はどちらも ui32 型です。

プロセスグリッドのサイズはプログラムごとに静的に把握され(将来的には StableHLO プログラム #650 で明示的な一部にする予定です)、プロセスグリッド内の位置はプロセスごとに静的に把握されます。各プロセスは、replica_id オペレーションと partition_id オペレーションを介して、プロセス グリッド内の位置にアクセスできます。

プロセス グリッド内で、プログラムはすべて同じ(「単一プログラム、複数のデータ」スタイル)にすることも、すべて異なる(「複数プログラム、複数のデータ」スタイル)にすることも、その中間にも設定できます。今後、GSPMD(#619)など、並列 StableHLO プログラムを定義する他のイディオムのサポートを導入する予定です。

プロセス グリッド内のプロセスは、ほとんどが互いに独立しています。オペレーションのステータス、入力値、中間値、出力値はそれぞれ別々で、ほとんどのオペレーションはプロセス間で別々に実行されます(後述する少数の集約オペレーションを除く)。

ほとんどのオペレーションの実行では同じプロセスの値のみが使用されるため、通常、これらの値を名前で参照してもあいまいさはありません。ただし、集約オペレーションのセマンティクスを記述する場合、これは不十分です。そのため、特定のプロセス内の値 name を参照する記号 name@process_id が使用されます。(この観点から、修飾されていない namename@(replica_id(), partition_id()) の省略形と見なすことができます)。

プロセス間の実行順序は、以下で説明するように、ポイントツーポイント通信と集約オペレーションによって導入される同期を除き、実装で定義されます。

ポイントツーポイント通信

StableHLO プロセスは、StableHLO チャネルを介して相互に通信できます。チャネルは、si64 型の正の ID で表されます。さまざまなオペレーションを使用して、値をチャネルに送信したり、チャネルから受信したりできます。

これらのチャンネル ID の取得元、プロセス プログラムが認識する方法、プロセスによって行われる同期の種類など、さらなる形式化は未定(#484)です。

ストリーミング通信

すべての StableHLO プロセスは、次の 2 つのストリーミング インターフェースにアクセスできます。

  • 読み取り可能なInfeed
  • 書き込み可能なアウトフィードの

プロセス間の通信に使用され、両端にプロセスがあるチャネルとは異なり、インフィードとアウトフィードでは、もう一方の端は実装で定義されます。

ストリーミング通信が実行順序に与える影響や、それによって行われる同期の種類など、さらなる形式化は未定です(#484)。

集団オペレーション

StableHLO には、all_gatherall_reduceall_to_allcollective_broadcastcollective_permutereduce_scatter の 6 つのグループ演算があります。これらのオペレーションはすべて、StableHLO プロセス グリッド内のプロセスを StableHLO プロセス グループに分割し、他のプロセス グループとは独立して、各プロセス グループ内で共同計算を実行します。

各プロセス グループ内で、集約オペレーションによって同期バリアが発生する可能性があります。この同期が正確にいつ行われるか、プロセスがこの障壁に到達する正確な方法、到達しなかった場合にどうなるかなど、さらなる形式化は未定です(#484)。

プロセス グループにパーティション間の通信が含まれる場合(パーティション ID が異なるプロセスがプロセス グループ内にある場合)、集約オペレーションの実行にはチャネルが必要であり、集約オペレーションは si64 型の正の channel_id を提供する必要があります。レプリカ間の通信にはチャネルは必要ありません。

集約オペレーションによって実行される計算は個々のオペレーションに固有であり、上記の個々のオペレーションのセクションで説明しています。ただし、プロセス グリッドをプロセス グループに分割する戦略はこれらのオペレーション間で共有され、このセクションで説明します。より正式には、StableHLO は次の 4 つの戦略をサポートしています。

cross_replica

各プロセス グループ内では、レプリカ間の通信のみが発生します。この戦略では、replica_groups(レプリカ ID のリストのリストのリスト)を受け取り、replica_groupspartition_ids の直積を計算します。replica_groups には一意の要素があり、すべての replica_ids をカバーする必要があります。より正式には、Python 構文を使用します。

def cross_replica(replica_groups: List[List[ReplicaId]]) -> List[List[ProcessId]]:
  for replica_group in replica_groups:
    for partition_id in partition_ids:
      process_group = []
      for replica_id in replica_group:
        process_group.append((replica_id, partition_id))
      yield process_group

たとえば、replica_groups = [[0, 1], [2, 3]]num_partitions = 2 の場合、cross_replica[[(0, 0), (1, 0)], [(0, 1), (1, 1)], [(2, 0), (3, 0)], [(2, 1), (3, 1)]] を生成します。

cross_partition

各プロセス グループ内で行われる通信は、パーティション間の通信のみです。この戦略は、partition_groups(パーティション ID のリストのリスト)を受け取り、replica_ids による partition_groups のデカルト積を計算します。partition_groups には一意の要素があり、すべての partition_ids をカバーする必要があります。より正式には、Python 構文を使用して次のようにします。

def cross_partition(partition_groups: List[List[PartitionId]]) -> List[List[ProcessId]]:
  for partition_group in partition_groups:
    for replica_id in replica_ids:
      process_group = []
      for partition_id in partition_group:
        process_group.append((replica_id, partition_id))
      yield process_group

たとえば、partition_groups = [[0, 1]]num_replicas = 4 の場合、cross_partition[[(0, 0), (0, 1)], [(1, 0), (1, 1)], [(2, 0), (2, 1)], [(3, 0), (3, 1)]] を生成します。

cross_replica_and_partition

各プロセス グループ内で、レプリカ間通信とパーティション間通信の両方が発生する可能性があります。この戦略は、レプリカ ID のリストのリストである replica_groups を受け取り、partition_ids によって各 replica_group のデカルト積を計算します。replica_groups は一意の要素を持ち、すべての replica_ids をカバーする必要があります。より正式には、Python 構文を使用して次のようにします。

def cross_replica_and_partition(replica_groups: List[List[ReplicaId]]) -> List[List[ProcessId]]:
  for replica_group in replica_groups:
    process_group = []
    for partition_id in partition_ids:
      for replica_id in replica_group:
        process_group.append((replica_id, partition_id))
    yield process_group

たとえば、replica_groups = [[0, 1], [2, 3]]num_partitions = 2 の場合、cross_replica_and_partition[[(0, 0), (1, 0), (0, 1), (1, 1)], [(2, 0), (3, 0), (2, 1), (3, 1)]] を生成します。

flattened_ids

この戦略では、flattened_id_groupsreplica_id * num_partitions + partition_id 形式の「フラット化」されたプロセス ID のリスト)を受け取り、プロセス ID に変換します。flattened_id_groups は一意の要素を持ち、すべての process_ids をカバーする必要があります。より正式には、Python 構文を使用して次のようにします。

def flattened_ids(flattened_id_groups: List[List[ui32]]) -> List[List[ProcessId]]:
  for flattened_id_group in flattened_id_groups:
    process_group = []
    for flattened_id in flattened_id_group:
      replica_id = flattened_id // num_partitions
      partition_id = flattened_id % num_partitions
      process_group.append((replica_id, partition_id))
    yield process_group

たとえば、flattened_id_groups = [[0, 1, 2, 3], [4, 5, 6, 7]]num_replicas = 4num_partitions = 2 の場合、flattened_ids[[(0, 0), (0, 1), (1, 0), (1, 1)], [(2, 0), (2, 1), (3, 0), (3, 1)]] を生成します。

精度

現時点では、StableHLO は数値の精度を保証していませんが、今後変更される可能性があります(#1156)。

量子化オペレーションの実行セマンティクス

量子化された StableHLO 演算の解釈は、ハードウェアの要件と機能によって異なる場合があります。たとえば、一部のハードウェアでは、「逆量子化し、浮動小数点演算を実行して、最終的に量子化」戦略を使用して量子化演算を解釈します。他の関数では、整数演算で計算全体を実行することもあります。したがって、量子化された StableHLO オペレーションの解釈は、特定の実装によってのみ決まります。ハイブリッド量子化(#1575)の解釈は、仕様(1792 を介して)で規定されているそのセマンティクスに基づく必要があります。

エラー

StableHLO プログラムは、個々のオペレーションに対する広範な制約セットによって検証され、実行前に多くのエラークラスが除外されます。ただし、整数オーバーフローや境界外アクセスなどのエラー条件は引き続き発生する可能性があります。明示的に示されていない限り、これらのエラーはすべて実装定義の動作になりますが、今後変更される可能性があります(#1157)。

浮動小数点例外

このルールの例外として、StableHLO プログラムの浮動小数点例外には明確に定義された動作があります。IEEE-754 標準で定義されている例外(無効なオペレーション、ゼロ除算、オーバーフロー、アンダーフロー、不正確な例外)が発生するオペレーションは、(標準で定義されているように)デフォルトの結果を生成し、対応するステータス フラグを発生させずに実行を続行します。これは、標準の raiseNoFlag 例外処理に似ています。標準以外の演算(複素演算や特定の超越関数など)の例外は、実装で定義されます。

形状の不一致

StableHLO は、動的に形状を変更できるテンソルをサポートしています。ただし、実行時にシェイプが一致する必要があります。一致しない場合、動作は未定義になります。StableHLO では、実行時にテンソルが特定の形状であることをアサートできる op が明示的に提供されていません。正しいコードを生成するのは、プロデューサーの責任です。

具体的な例として、以下のプログラムは有効です。ただし、実行時には %arg0%arg1 の正確な形状が同じである必要があります。一致していない場合、プログラムの動作は未定義になります。

func.func @foo(%arg0: tensor<?xi32>, %arg1: tensor<?xi32>) -> tensor<?xi32> {
    %0 = stablehlo.add %arg0, %arg1 : tensor<?xi32>
    return %0 : tensor<?xi32>
}

Notation

このドキュメントでは、構文の記述に、EBNF 構文の ISO フレーバー(ISO/IEC 14977:1996Wikipedia)を変更して使用しています。変更点は 2 つあります。1)ルールは = ではなく ::= を使用して定義されます。

2)連結は , ではなく並置を使用して表されます。

セマンティクスの記述(「型」、「定数」、「オペレーション」セクション内)では、Python 構文に基づく式を使用しています。この式は、以下で説明するように、配列オペレーションを簡潔に表現できるように拡張されています。これは小さなコード スニペットには適していますが、大きなコード スニペットが必要なまれなケースでは、常に明示的に導入される標準の Python 構文を使用します。

数式

dot_general 仕様の例に基づいて、数式の仕組みを見てみましょう。このオペレーションの制約の 1 つは次のとおりです。dim(lhs, lhs_batching_dimensions...) = dim(rhs, rhs_batching_dimensions...)

この式で使用される名前は、2 つのソースから取得されます。1 つはグローバル関数(dim など)、もう 1 つは対応するプログラム要素のメンバー定義(dot_general の [入力] セクションで定義された lhslhs_batching_dimensionsrhsrhs_batching_dimensions の入力)です。

前述のとおり、この式の構文は Python ベースで、簡潔さを重視した拡張機能がいくつかあります。式を理解するために 通常の Python 構文に変換しましょう

A)これらの数式では、= を使用して等号を表しています。Python 構文を取得するための最初のステップは、=== に置き換えることです。dim(lhs, lhs_batching_dimensions...) == dim(rhs, rhs_batching_dimensions...) のようにします。

B)また、これらの式では、スカラー式をテンソル式に変換する楕円記号(...)もサポートされています。簡単に言うと、f(xs...) は「テンソル xs 内のスカラー x ごとにスカラー f(x) を計算し、これらのスカラー結果をすべてテンソル結果として返す」ことを意味します。通常の Python 構文の場合、数式の例は [dim(lhs, dim1) for dim1 in lhs_batching_dimensions] == [dim(rhs, dim2) for dim2 in rhs_batching_dimensions] になります。

楕円記法を使用すると、個々のスカラーレベルで作業する必要がなくなります。ただし、場合によっては、gather 仕様の start_indices[bi0, ..., :, ..., biN] 数式のように、下位レベルの準非形式構文が使用されることがあります。簡潔にするため、このような構文を標準の Python に変換するための正確な形式主義は提供していません。ケースごとに直感的に理解できるようにすることを目的としています。特定の式がわかりにくい場合はお知らせください。改善を検討いたします。

また、式では楕円を使用して、テンソル、テンソルのリスト(可変長のテンソルから発生する可能性があるものなど)など、あらゆる種類のリストを展開します。これは、正確な形式主義を提供しない(リストは StableHLO 型システムの一部ではないなど)もう 1 つの領域であり、直感的な理解に依存しています。

C)最後に、暗黙的なブロードキャストという注目すべき記法について説明します。StableHLO opset は暗黙的なブロードキャストをサポートしていませんが、簡潔さを提供する目的で数式はサポートしています。簡単に言うと、テンソルが予想されるコンテキストでスカラーを使用すると、スカラーは期待される形状にブロードキャストされます。

dot_general の例を続けると、別の制約 0 <= lhs_batching_dimensions < rank(lhs) があります。dot_general 仕様で定義されているように、lhs_batching_dimensions はテンソルですが、0rank(lhs) はどちらもスカラーです。暗黙的なブロードキャストを適用すると、式は [0, ..., 0] <= lhs_batching_dimensions < [rank(lhs), ..., rank(lhs)] になります。

特定の dot_general 演算に適用されると、この数式はブール値のテンソルに評価されます。数式を制約として使用する場合、数式が true と評価された場合、または true 要素のみを含むテンソルと評価された場合に制約が適用されます。

名前

数式では、1)グローバル関数、2)メンバー定義が、レキシカル スコープに含まれます。

3)ローカル定義グローバル関数のリストを以下に示します。要素定義のリストは、表記が適用されるプログラム要素によって異なります。

  • オペレーションの場合、メンバー定義には「入力」セクションと「出力」セクションで説明した名前が含まれます。
  • それ以外の場合、メンバー定義には、対応する EBNF 非終端記号にちなんで命名された、プログラム要素の構造部分が含まれます。ほとんどの場合、これらの構造部分の名前は、非終端記号の名前をスネークケースに変換することで取得されます(例: IntegerLiteral => integer_literal)。ただし、このプロセスで名前が省略されることもあります(例: QuantizationStorageType => storage_type)。その場合は、オペレーション仕様の [入力] / [出力] セクションと同様に、名前が明示的に導入されます。
  • また、メンバー定義には常に self が含まれ、対応するプログラム要素を参照します。

数式の評価では、次の種類の値が使用されます。1)Value(実際の値、例: dense<[[1, 2], [3, 4]]> : tensor<2x2xi32>。型は常に判明しています)、2)Placeholder(将来の値、例: lhsrhsresult。実際の値はまだ判明しておらず、型のみが判明しています)、3)Type([型] セクションで定義されている型)、4)Function([関数] セクションで定義されているグローバル関数)。

コンテキストによっては、名前が異なる値を参照している場合があります。具体的には、オペレーションの「セマンティクス」セクション(および他のプログラム要素の同等のセクション)でランタイム ロジックが定義されているため、すべての入力を Value として使用できます。これに対して、op(および同等のもの)の「制約」セクションでは「コンパイル時」ロジック、つまり通常はランタイム前に実行されるロジックを定義しています。そのため、定数入力のみが Value として使用でき、その他の入力は Placeholder としてのみ利用できます。

名前 「Semantics」 [制約] で
グローバル機能 Function Function
定数入力 Value Value
定数以外の入力 Value Placeholder
出力 Value Placeholder
ローカル定義 定義によって異なる 定義によって異なる

transpose オペレーションの例を見てみましょう。

%result = "stablehlo.transpose"(%operand) {
  permutation = dense<[2, 1, 0]> : tensor<3xi64>
} : (tensor<2x3x2xi32>) -> tensor<2x3x2xi32>

このオペレーションでは、permutation は定数であるため、セマンティクスと制約の両方で Value として使用できます。一方、operandresult は、セマンティクスでは Value として使用できますが、制約では Placeholder としてのみ使用できます。

関数

型の構築

型の作成に使用できる関数はありません。通常は、より簡潔なタイプ構文を直接使用します。たとえば、function_type( [tensor_type([], E), tensor_type([], E)], [tensor_type([], E)]) ではなく (tensor<E>, tensor<E>) -> (tensor<E>) です。

型の関数

  • element_type はテンソル型と量子化テンソル型で定義され、それぞれ対応する TensorType または QuantizedTensorTypeTensorElementType または QuantizedTensorElementType 部分を返します。
def element_type(x: Value | Placeholder | Type):
 if type(x) == TensorType:
    return tensor_element_type(x)
  if type(x) == QuantizedTensorType:
    return quantized_tensor_element_type(x)
  if type(x) is not Type:
    return element_type(type(x))
  • is_per_axis_quantized(x: Value | Placeholder | Type) -> Valueis_quantized(x) and quantization_dimension(x) is not None のショートカットです。

  • is_per_tensor_quantized(x: Value | Placeholder | Type) -> Valueis_quantized(x) and quantization_dimension(x) is None のショートカットです。

  • is_promotable(x: Type, y: Type) -> bool は、x 型を y 型に昇格できるかどうかを確認します。xyQuantizedTensorElementType の場合、プロモーションは storage_type にのみ適用されます。この特定のバージョンのプロモーションは、現在、削減計算のコンテキストで使用されています(詳細については、RFC をご覧ください)。

def is_promotable(x: Type, y: Type) -> Value:
  is_same_type = (is_bool(x) and is_bool(y)) or
    (is_integer(x) and is_integer(y)) or (is_float(x) and is_float(y)) or
    (is_complex(x) and is_complex(y)) or
    (is_quantized(x) and is_quantized(y) and expressed_type(x) = expressed_type(y))

  if is_same_type == False:
    return False

  if is_integer(x) or is_float(x):
    return bitwidth(x) <= bitwidth(y)

  if is_complex(x):
    return bitwidth(element_type(x)) <= bitwidth(element_type(y))

  if is_quantized(x):
    return bitwidth(storage_type(x)) <= bitwidth(storage_type(y))

  return false
  • is_quantized(x: Value | Placeholder | Type) -> Valueis_quantized_tensor_element_type(x) のショートカットです。

  • is_type_name(x: Value | Placeholder | Type) -> Value。すべてのタイプで使用できます。たとえば、xFloatType の場合、is_float(x)true を返します。x が値またはプレースホルダの場合、この関数は is_type_name(type(x)) のショートカットになります。

  • max_value(x: Type) -> ValueTensorElementType の最大値を返します。xTensorElementType でない場合、None を返します。

  • min_value(x: Type) -> Value は、TensorElementType の可能な最小値を返します。xTensorElementType でない場合、None を返します。

  • member_name(x: Value | Placeholder | Type) -> Any。すべてのタイプのすべてのメンバー定義 member_name で使用できます。たとえば、tensor_element_type(x) は対応する TensorTypeTensorElementType 部分を返します。x が値またはプレースホルダの場合、この関数は member_name(type(x)) のショートカットです。x が適切なメンバーを持つ型、またはそのような型の値またはプレースホルダでない場合、None を返します。

  • is_empty_algorithm(*args: Type) は、すべてのドット アルゴリズム フィールドが None に設定されているかどうかを確認します。これは、ドット アルゴリズムの実装でデフォルトの動作が定義されているため、デフォルト値を指定すると正しくありません。

値の構成

  • operation_name(*xs: Value | Type) -> Value。すべてのオペレーションで使用できます。たとえば、add(lhs, rhs) は 2 つのテンソル値 lhsrhs を受け取り、これらの入力で add 演算を評価した結果を返します。broadcast_in_dim などの一部のオペレーションでは、出力のタイプが「負荷を負う」タイプです。つまり、オペレーションの評価に必要です。この場合、関数はこれらの型を引数として受け取ります。

値の関数

  • Python のすべての演算子と関数を使用できます。たとえば、Python のサブスクリプションスライスの両方の表記を使用して、テンソル、量子化テンソル、タプルにインデックスを付けることができます。

  • to_destination_type(x: Value, destination_type: Type) -> Value はテンソルで定義され、次のように type(x)destination_type に基づいて x の変換値を返します。

def to_destination_type(x: Value, destination_type: Type) -> Value:
  if type(x) == destination_type:
    return x

  if is_quantized(destination_type):
    if is_quantized(type(x)):
      return quantize(x, destination_type)
    assert is_float(type(x))
    return quantize(x, destination_type)

  if is_quantized(type(x)):
    assert destination_type = expressed_type(type(x))
    return dequantize(type(x))

  return convert(x, destination_type)

convertuniform_quantizeuniform_dequantize オペレーションのマージについては、早期の議論があります(#1576)。 マージ後、上記の関数は不要になり、代わりに convert のオペレーション名を使用できます。

  • is_nan(x: Value) -> Value はテンソルで定義され、x のすべての要素が NaN の場合は true を返します。それ以外の場合は false を返します。x がテンソルでない場合、None を返します。

  • is_sorted(x: Value) -> Value はテンソルで定義され、x の要素がインデックスの辞書順の昇順で並べ替えられている場合は true を返し、そうでない場合は false を返します。x がテンソルでない場合は、None を返します。

  • is_unique(x: Value) -> Value はテンソルで定義され、x に重複する要素がない場合 true を返します。それ以外の場合は false を返します。x がテンソルでない場合は、None を返します。

  • member_name(x: Value) -> Any は、すべての値のすべてのメンバー定義 member_name に対して定義されます。たとえば、real_part(x) は対応する ComplexConstantRealPart 部分を返します。x が適切なメンバーを含む値でない場合、None を返します。

  • same(x: Value) -> Value はテンソルに対して定義され、x の要素がすべて同じであれば true を返し、それ以外の場合は false を返します。テンソルに要素がない場合、それは「すべて等しい」と見なされます。つまり、関数は true を返します。x がテンソルでない場合は、None を返します。

  • split(x: Value, num_results: Value, axis: Value) -> Value はテンソルで定義され、軸 axis に沿って xnum_results スライスを返します。x がテンソルまたは dim(x, axis) % num_results != 0 でない場合は、None を返します。

  • is_defined_in_parent_scope(x: Value) -> Value は文字列で定義され、x が関連するオペレーションの親関数と同じスコープで定義された関数の名前である場合、true を返します。

  • is_namespaced_op_name(x: Value) -> Value は文字列で定義され、x が有効なオペレーション名([a-zA-Z][a-zA-Z0-9_]*([.][a-zA-Z0-9_$]+)+ という正規表現に従う)である場合は true を返します。

シェイプの計算

  • axes(x: Value | Placeholder | Type) -> Valuerange(rank(x)) のショートカットです。

  • dim(x: Value | Placeholder | Type, axis: Value) -> Valueshape(x)[axis] のショートカットです。

  • dims(x: Value | Placeholder | Type, axes: List) -> Listlist(map(lambda axis: dim(x, axis), axes)) のショートカットです。

  • index_space(x: Value | Placeholder | Type) -> Value はテンソルで定義され、対応する TensorTypesize(x) インデックスを辞書順で昇順([0, ..., 0][0, ..., 1]、...、shape(x) - 1)で返します。x がテンソル型、量子化テンソル型、値、またはこれらの型のプレースホルダでない場合、None を返します。

  • rank(x: Value | Placeholder | Type) -> Valuesize(shape(x)) のショートカットです。

  • shape(x: Value | Placeholder | Type) -> Value は、「型の関数」セクションで member_name を介して定義されます。

  • size(x: Value | Placeholder | Type) -> Valuereduce(lambda x, y: x * y, shape(x)) のショートカットです。

量子化計算

  • def baseline_element_type(x: Value | Placeholder | Type) -> Typeelement_type(baseline_type(x)) のショートカットです。

  • baseline_type はテンソル型と量子化テンソル型で定義され、それらを「ベースライン」に変換します。ベースラインとは、同じ形状で、要素型の量子化パラメータがデフォルト値にリセットされた型です。これは、テンソルと量子化テンソルの両方のタイプを均一に比較するための便利なトリックとして使用されます。これは非常に頻繁に必要になります。量子化タイプの場合、量子化パラメータを無視してタイプを比較できます。つまり、shapestorage_typeexpressed_typestorage_minstorage_maxquantization_dimension(軸ごとの量子化タイプの場合)はすべて一致する必要がありますが、scaleszero points は異なる場合があります。

def baseline_type(x: Value | Placeholder | Type) -> Type:
  if type(x) == TensorType:
    return x
  if type(x) == QuantizedTensorType:
    element_type = quantized_tensor_element_type(x)
    baseline_element_type = QuantizedTensorElementType(
      storage_type = storage_type(element_type),
      storage_min = storage_min(element_type),
      storage_max = storage_max(element_type),
      expressed_type = expressed_type(element_type),
      quantization_dimension = quantization_dimension(element_type),
      scales = [constant(1.0, expressed_type(element_type))] * dim(x, quantization_dimension(element_type)),
      zero_points = [constant(0, storage_type(element_type))] * dim(x, quantization_dimension(element_type)))
    return QuantizedTensorType(shape(x), baseline_element_type)
  if type(x) is not Type:
    return baseline_element_type(type(x))
  • dequantize は量子化テンソル型に対して定義され、浮動小数点テンソル型に変換します。これは、量子化要素型に関連付けられたゼロ点とスケールを使用して、ストレージ型の整数値を表す量子化要素を、表現型の対応する浮動小数点値に変換することで行われます。
def compute_zero_points(quantized_type, result_type):
  if is_per_tensor_quantized(quantized_type):
    return broadcast_in_dim(constant(zero_point(quantized_type), storage_type(quantized_type)), [], result_type)
  if is_per_axis_quantized(quantized_type):
    for i in index_space(result_type):
      d = quantization_dimension(quantized_type)
      zero_points[i] = zero_points(quantized_type)[i[d]]
    return zero_points

def compute_scales(quantized_type, result_type):
  if is_per_tensor_quantized(quantized_type):
    return broadcast_in_dim(constant(scale(quantized_type), expressed_type(quantized_type)), [],
            type(result_type))
  if is_per_axis_quantized(quantized_type):
    for i in index_space(result_type):
      d = quantization_dimension(quantized_type)
      scales[i] = scales(quantized_type)[i[d]]
    return scales

def dequantize(x: Value) -> Value:
  assert is_quantized(x)
  x_storage = bitcast_convert(x, storage_type(x))
  x_storage_sub = x_storage - compute_zero_points(type(x), type(x_storage))
  x_expressed_sub = convert(x_storage_sub, expressed_type(x))
  return x_expressed_sub * compute_scales(type(x), type(x_expressed_sub))
  • quantize は浮動小数点テンソル型で定義され、量子化されたテンソル型に変換します。これは、量子化された要素型に関連付けられたゼロ点とスケールを使用して、表現型の浮動小数点値をストレージ型の対応する整数値に変換することで行われます。
def quantize(x: Value, result_type: Type) -> Value:
  assert is_float(x) and is_quantized(result_type)
  zero_points = compute_zero_points(result_type, TensorType(shape(x), storage_type(result_type)))
  converted_zero_points = convert(zero_points, expressed_type(result_type))
  converted_min = convert(storage_min(result_type), expressed_type(result_type))
  converted_max = convert(storage_max(result_type), expressed_type(result_type))

  x_scaled = x / compute_scales(result_type, type(x))
  x_scaled_add_zp = x_scaled + converted_zero_points
  x_clamped = clamp(converted_min, x_scaled_add_zp, converted_max)
  x_rounded = round_nearest_even(x_clamped)
  return convert(x_rounded, result_type)
  • dequantize_op_quantize は、量子化テンソルの要素単位の計算を指定するために使用されます。デクォンタイズ(量子化された要素を表現型に変換)してから演算を実行し、量子化(結果をストレージ タイプに戻す)します。現時点では、この関数はテンソルごとの量子化でのみ機能します。軸ごとの量子化は現在開発中です(#1574)。
def dequantize_op_quantize(op, *inputs_and_output_type):
  inputs = inputs_and_output_type[:-1]
  output_type = inputs_and_output_type[-1]

  float_inputs = map(dequantize, inputs)
  float_result = op(*float_inputs)
  return quantize(float_result, output_type)

def dequantize_batch_norm_grad_or_training_quantize(op, *inputs_and_output_types):
  inputs = inputs_and_output_type[:-3]
  float_inputs = map(dequantize, inputs)
  float_results = op(*float_inputs)
  return map(quantize, float_results, inputs_and_output_type[-3:])

def dequantize_compare(lhs, rhs, comparison_direction):
  float_lhs = dequantize(lhs)
  float_rhs = dequantize(rhs)
  return compare(float_lhs, float_rhs, comparison_direction, FLOAT)

def dequantize_select_quantize(pred, on_true, on_false, output_type):
  float_on_true = dequantize(on_true)
  float_on_false = dequantize(on_false)
  float_result = select(pred, float_on_true, float_on_false)
  return quantize(float_result, output_type)
  • hybrid_dequantize_then_op は、浮動小数点型の lhs と量子化された型の rhs を受け入れるハイブリッド演算の重みのみの量子化を指定するために使用されます。量子化された入力を表現型にデクォンタイズし、浮動小数点数で計算を行います。浮動小数点の左辺テンソルの要素型と、量子化された右辺テンソルの表現型は同じである必要があります。
def hybrid_dequantize_then_op(op, lhs, rhs):
  assert(is_float(lhs) and is_quantized(rhs) and element_type(lhs) == expressed_type(rhs))
  return op(lhs, dequantize(rhs))

グリッド計算

  • cross_partition(replica_groups: Value) -> Value。上記の「cross_replica」セクションをご覧ください。

  • cross_replica(replica_groups: Value) -> Value。上記の「cross_replica」セクションをご覧ください。

  • cross_replica_and_partition(replica_groups: Value) -> Value。上記の「cross_replica_and_partition」セクションをご覧ください。

  • flattened_ids(replica_groups: Value) -> Value。上記の「flattened_ids」セクションをご覧ください。

ダイナミズム

StableHLO 値には、動的なディメンション サイズ(tensor<?xi64> など)を指定できます。ただし、StableHLO 値には動的ディメンション数(ランクなしの動的性、tensor<*xi64> など)を指定できません。オペランドと結果では、サイズに制約がある場合でも、動的ディメンション サイズを使用できます。制約は可能であれば静的に検証されます。検証できない場合は実行時に延期され、不一致により未定義の動作が発生します。例については以下をご覧ください。

単項要素ごとのオペレーションのシェイプの不一致

次のサンプル プログラムを考えてみましょう。

func.func @foo(%arg0: tensor<?xf64>) {
  %0 = stablehlo.abs %arg0 : (tensor<?xf64>) -> tensor<2xf64>
  return
}

このようなプログラムは珍しいものです。結果の形状はわかっても、入力の形状がわからないことは一般的ではありません。ただし、これは有効な StableHLO プログラムです。このプログラムの abs オペレーションを静的に検証することはできません。オペランドの正確な形状が不明であるためです。ただし、シェイプは確かに互換性があり、これは静的に確認できます。? が実行時に 2 になっても問題はありません。ただし、? が他の整数になる可能性もあります。その場合、動作は未定義になります。

結果でディメンション サイズが動的である場合、未定義の動作は発生しません。実際、サイズは「想定」されていないため、不一致は発生しません。

バイナリ要素ごとのオペレーションのシェイプの不一致

次のおもちゃプログラムについて考えてみましょう。

func.func @foo(%arg0: tensor<?xf64>, %arg1: tensor<?xf64>) {
  %0 = stablehlo.add %arg0, %arg0 : (tensor<?xf64>, tensor<?xf64>) -> tensor<?xf64>
  return
}

バイナリ要素ごとのオペレーションの場合、入力と結果の形状は実行時に一致している必要があります。コンパイル時に静的ディメンションは同じである必要があります。それ以外の場合は、互換性がある必要があります。入力でいずれかのディメンションが動的である場合、動的サイズが他のオペランドの対応するサイズ(静的または動的)と一致しないため、ランタイムで未定義の動作が発生する可能性があります。すべての入力が静的である場合、結果が動的かどうかは重要ではありません。静的に既知のディメンションは静的にチェックされ、動的ディメンションには制約が適用されません。

出力シェイプをオペランドとして取るオペレーションのシェイプの不一致

次のサンプル プログラムを考えてみましょう。

func.func @foo(%arg0: tensor<2xi32>) {
  %0 = stablehlo.dynamic_iota %arg0, dim = 0 : (tensor<2xi32>) -> tensor<3x4xi64>
  return
}

実行時にシェイプ オペランドの値は、結果のシェイプと一致する必要があります。一致しない場合、動作は未定義になります。つまり、実行時に %arg0 の値は dense<[3, 4]> : tensor<2xi32> にする必要があります。シェイプ オペランドが定数の場合は、静的に検証できます。結果の形状が完全に動的であれば、不一致が生じることはありません。