StableHLO の仕様

StableHLO は、機械学習(ML)モデルの高レベル演算(HLO)用の演算セットです。StableHLO は、さまざまな ML フレームワークと ML コンパイラ間のポータビリティ レイヤとして機能します。StableHLO プログラムを生成する ML フレームワークは、StableHLO プログラムを使用する ML コンパイラと互換性があります。

Google の目標は、さまざまな ML フレームワーク(TensorFlow、JAX、PyTorch など)と ML コンパイラ(XLA、IREE など)との間の相互運用性を強化することで、ML 開発を簡素化し加速することです。そのために、このドキュメントでは StableHLO プログラミング言語の仕様について説明します。

この仕様には 3 つの主要セクションがあります。まず、プログラム セクションでは、StableHLO オペレーションで構成される StableHLO 関数で構成される StableHLO プログラムの構造について説明します。Ops セクションでは、その構造内で個々のオペレーションのセマンティクスを指定します。Execution セクションでは、プログラム内で一緒に実行されるすべてのオペレーションのセマンティクスが提供されます。最後に、表記のセクションで、この仕様全体で使用される表記について説明します。

プログラム

Program ::= {Func}

StableHLO プログラムは、任意の数の StableHLO 関数で構成されています。以下は、3 つの入力(%image%weights%bias)と 1 つの出力を持つ関数 @main のプログラムの例です。関数の本体には 6 つの演算があります。

func.func @main(
  %image: tensor<28x28xf32>,
  %weights: tensor<784x10xf32>,
  %bias: tensor<1x10xf32>
) -> tensor<1x10xf32> {
  %0 = "stablehlo.reshape"(%image) : (tensor<28x28xf32>) -> tensor<1x784xf32>
  %1 = "stablehlo.dot"(%0, %weights) : (tensor<1x784xf32>, tensor<784x10xf32>) -> tensor<1x10xf32>
  %2 = "stablehlo.add"(%1, %bias) : (tensor<1x10xf32>, tensor<1x10xf32>) -> tensor<1x10xf32>
  %3 = "stablehlo.constant"() { value = dense<0.0> : tensor<1x10xf32> } : () -> tensor<1x10xf32>
  %4 = "stablehlo.maximum"(%2, %3) : (tensor<1x10xf32>, tensor<1x10xf32>) -> tensor<1x10xf32>
  "func.return"(%4): (tensor<1x10xf32>) -> ()
}

関数

Func        ::= 'func' '.' 'func' FuncId FuncInputs FuncOutputs '{' FuncBody '}'
FuncInputs  ::= '(' [FuncInput {',' FuncInput}] `)`
FuncInput   ::= '%' ValueId ':' ValueType
FuncOutputs ::= ['->' FuncOutput, {',' FuncOutput}]
FuncOutput  ::= ValueType
FuncBody    ::= {Op}

StableHLO 関数(名前付き関数とも呼ばれます)には、識別子、入出力、本体があります。将来的には、HLO との互換性を高めるために、関数に追加のメタデータを導入する予定です(#425#626#740#744)。

識別子

FuncId  ::= '@' letter {letter | digit}
ValueId ::= '%' digit {digit}
          | '%' letter {letter | digit}
letter  ::= 'a' | ... | 'z' | 'A' | ... | 'Z' | '_'
digit   ::= '0' | ... | '9'

StableHLO 識別子は多くのプログラミング言語の識別子に似ていますが、1)すべての識別子に異なる種類の識別子を区別するシジルがある点、2)値識別子に完全に数値を使用して StableHLO プログラムの生成を簡素化する点が 2 つあります。

Type         ::= ValueType | NonValueType
ValueType    ::= TensorType | QuantizedTensorType | TokenType | TupleType
NonValueType ::= TensorElementType | QuantizedTensorElementType | FunctionType | StringType

StableHLO 型は、StableHLO 値を表す値の型(ファースト クラス型)と、他のプログラム要素を記述する非値型に分類されます。StableHLO 型は多くのプログラミング言語の型に似ていますが、主な特徴は StableHLO のドメイン固有の性質であり、通常とは異なる結果(スカラー型が値の型ではないなど)があります。

TensorType ::= 'tensor' '<' Shape TensorElementType '>'
Shape ::= {DimensionSize 'x'}
DimensionSize ::= digit {digit}

テンソル型は、テンソル、つまり多次元配列を表します。2 つの要素にはシェイプ要素タイプがあり、シェイプは負でないディメンションのサイズを、対応するディメンション(軸とも呼ばれます)0 から R-1 まで)の昇順で表します。ディメンションの数 R はランクと呼ばれます。たとえば、tensor<2x3xf32> は形状 2x3 と要素型 f32 のテンソル型です。このレイヤには、0 番目のディメンションと 1 番目のディメンションの 2 つのディメンション(つまり、2 つの軸)があり、サイズは 2 と 3 です。その順位は 2 です。

これは、寸法サイズが静的に認識される静的シェイプのサポートを定義します。将来的には、寸法サイズが部分的または完全に不明な動的シェイプのサポートも導入する予定です(#8)。さらに、レイアウト(#629)やスパース(#1078)など、ディメンションサイズや要素タイプを超えたテンソル型を拡張することも計画しています。

QuantizedTensorType ::= 'tensor' '<' Shape QuantizedTensorElementType '>'
QuantizedTensorElementType ::= '!quant.uniform' '<'
                  QuantizationStorageType
                  ['<' QuantizationStorageMin ':' QuantizationStorageMax '>']
                  ':' QuantizationExpressedType
                  [':' QuantizationDimension]
                  ',' QuantizationParameters '>'
QuantizationStorageType ::= IntegerType
QuantizationStorageMin ::= IntegerConstant
QuantizationStorageMax ::= IntegerConstant
QuantizationExpressedType ::= FloatType
QuantizationDimension ::= IntegerConstant
QuantizationParameters ::= QuantizationParameter
                         | '{' QuantizationParameter {',' QuantizationParameter} '}'
QuantizationParameter ::= QuantizationScale ':' QuantizationZeroPoint
QuantizationScale ::= FloatConstant
QuantizationZeroPoint ::= IntegerConstant
名前 タイプ 制約
storage_type 整数型 (C1 ~ C4)、(C9)
storage_min 整数定数 (C2)、(C4)、(C8)
storage_max 整数定数 (C3)、(C4)、(C8)
expressed_type 浮動小数点型 (C1)、(C5)
quantization_dimension オプションの整数定数 (C11 ~ C13)
scales 浮動小数点定数の可変数 (C5 ~ C7)、(C10)、(C11)、(C13)
zero_points 整数定数の可変数 (C8 ~ C10)

量子化された要素タイプは、表現された型の浮動小数点値に対応する、storage_minstorage_max(両端を含む)の範囲のストレージ タイプの整数値を表します。ある整数値 i に対して、対応する浮動小数点値 ff = (i - zero_point) * scale として計算できます。ここで、scalezero_point量子化パラメータと呼ばれます。storage_minstorage_max は文法では省略可能ですが、デフォルト値はそれぞれ min_value(storage_type)max_value(storage_type) です。量子化された要素タイプには次の制約があります。

  • (C1)num_bits(storage_type) < num_bits(expressed_type)
  • (C2)type(storage_min) = storage_type
  • (C3)type(storage_max) = storage_type
  • (C4)min_value(storage_type) <= storage_min < storage_max <= max_value(storage_type)
  • (C5)type(scales...) = expressed_type
  • (C6)0 < scales
  • (C7)is_finite(scales...)
  • (C8)storage_min <= zero_points <= storage_max
  • (C9)type(zero_points...) = storage_type
  • (C10)size(scales) = size(zero_points)
  • (C11)is_empty(quantization_dimension) の場合、size(scales) = 1
  • (C12)0 <= quantization_dimension

現時点では、QuantizationScale は浮動小数点定数ですが、乗数とシフトで表される整数ベースのスケールに強い関心があります。近い将来、この機能を検証する予定です(#1404)。

QuantizationZeroPoint のセマンティクスについては、現在も議論が続いています。たとえば、型や値のほか、量子化テンソル型にゼロ点を 1 つだけ存在できるかどうか、あるいは複数のゼロ点が存在する可能性があるかどうかなどです。以上の検討結果を踏まえ、今後、ゼロポイントの仕様を変更する可能性があります(#1405)。

現在進行中の他の議論では、QuantizationStorageMinQuantizationStorageMax のセマンティクスを取り上げ、これらの値と量子化テンソルの値になんらかの制約を課す必要があるかどうかを判断しています(#1406)。

最後に、不明なディメンション サイズを表現する方法(#1407)と同様に、未知のスケールとゼロポイントの表現について検討します。

量子化テンソル型は、量子化された要素を持つテンソルを表します。これらのテンソルは通常のテンソルとまったく同じですが、その要素に通常の要素型ではなく、量子化された要素型があります。

量子化テンソルでは、量子化はテンソルごとに 1 つ(テンソル全体で 1 つの scalezero_point を持つ)の場合もあれば、複数の scaleszero_points(特定のディメンション quantization_dimension のスライスごとに 1 つのペア)を持つことを意味します。より正式には、軸ごとの量子化を行うテンソル t には、quantization_dimensiondim(t, quantization_dimension) スライス(t[:, ..., 0, ..., :], t[:, ..., 1, ..., :] など)があります。i スライス内のすべての要素は、量子化パラメータとして scales[i]zero_points[i] を使用します。量子化テンソル型には次の制約があります。

  • テンソルごとの量子化の場合:
    • 追加の制約はありません。
  • 軸ごとの量子化の場合:
    • (C12)quantization_dimension < rank(self)
    • (C13)dim(self, quantization_dimension) = size(scales)
TokenType ::= 'token'

トークンタイプはトークンを表します。つまり、なんらかのオペレーションによって生成、消費される不透明な値を表します。トークンは、実行セクションで説明されているように、オペレーションに実行順序を適用するために使用されます。

TupleType ::= 'tuple' '<' TupleElementTypes '>'
TupleElementTypes ::= [ValueType {',' ValueType}]

タプル型は、タプル(異種リスト)を表します。タプルは、HLO との互換性を維持するためにのみ存在するレガシー機能です。HLO では、タプルを使用して可変長の入力と出力を表します。StableHLO では、さまざまな入出力がネイティブにサポートされており、StableHLO でタプルを使用するのは、HLO ABI を包括的に表現することだけです。たとえば、Ttuple<T>tuple<tuple<T>> は特定の実装によって実質的に異なる場合があります。将来的には、HLO ABI に変更を加えて、StableHLO からタプル型を削除できるようにする予定です(#598)。

TensorElementType ::= BooleanType | IntegerType | FloatType | ComplexType
BooleanType ::= 'i1'
IntegerType ::= SignedIntegerType | UnsignedIntegerType
SignedIntegerType ::= 'si4' | 'si8' | 'si16' | 'si32' | 'si64'
UnsignedIntegerType ::= 'ui4' | 'ui8' | 'ui16' | 'ui32' | 'ui64'
FloatType ::= 'f8E4M3FN' | 'f8E5M2' | 'f8E4M3FNUZ' | 'f8E5M2FNUZ'
            | 'f8E4M3B11FNUZ' | 'bf16' | 'f16' | 'f32' | 'f64'
ComplexType ::= 'complex' '<' ComplexElementType '>'
ComplexElementType ::= 'f32' | 'f64'

要素型はテンソル型の要素を表します。多くのプログラミング言語とは異なり、これらの型は StableHLO のファースト クラスではありません。つまり、StableHLO プログラムでは、これらの型の値を直接表すことはできません(そのため、T 型のスカラー値を tensor<T> 型の 0 次元のテンソル値で表すのが慣用的です)。

FunctionType ::= '(' InputTypes ')' '->' '(' OutputTypes ')'
InputTypes ::= [ValueType {',' ValueType}]
OutputTypes ::= [ValueType {',' ValueType}]

関数型は、名前付き関数と匿名関数の両方を表します。入力型(-> の左側の型のリスト)と出力型(-> の右側の型のリスト)があります。多くのプログラミング言語では関数型がファースト クラスですが、StableHLO ではそうではありません。

StringType ::= 'string'

文字列型は、バイトのシーケンスを表します。多くのプログラミング言語とは異なり、文字列型は StableHLO ではトップクラスではなく、プログラム要素の静的メタデータを指定するためにのみ使用されます。

運用

StableHLO 演算演算とも呼ばれる)は、ML モデルにおける高度な演算のクローズド セットを表します。前述のように、StableHLO 構文は MLIR に大きくヒントを得たものですが、MLIR は必ずしも最も人間工学に基づく代替手段ではありませんが、ML フレームワークと ML コンパイラ間の相互運用性を促進するという StableHLO の目標にはおそらく最適です。

Op            ::= [OpOutputs] OpName OpInputs ':' OpSignature
OpName        ::= '"' 'stablehlo' '.' OpMnemonic '"'
OpMnemonic    ::= 'abs' | 'add' | ...

StableHLO オペレーション(オペレーションとも呼ばれる)には、名前、入出力、シグネチャがあります。この名前は、stablehlo. 接頭辞と、サポートされている演算のいずれかを一意に識別するニーモニックで構成されます。サポートされているすべてのオペレーションの包括的なリストについては、以下をご覧ください。

現時点では、実際の StableHLO プログラムにはこのドキュメントで説明されていないオペレーションが含まれている場合があります。将来的には、これらの演算を StableHLO オペレーション セットに取り込むか、StableHLO プログラムに表示されないようにする予定です。それまでの間、オペレーションのリストを以下に示します。

  • builtin.modulefunc.funcfunc.callfunc.return#425)。
  • chlo オペレーション(#602)。
  • StableHLO 演算の「Not HLO」カテゴリ - 当初は StableHLO オペセットに含まれていましたが、後で適切に一致しないとみなされました(broadcastcreate_tokencross-replica-sumdoteinsumtorch_index_selectunary_einsum)。(#3)。
  • StableHLO オペレーションの「Dynamism」カテゴリ - MHLO からブートストラップされましたが、まだ特定されていません。compute_reshape_shapecstr_reshapabledynamic_broadcast_in_dimdynamic_convdynamic_gatherdynamic_iotadynamic_paddynamic_reshapereal_dynamic_sliceset_dimension_size#8)。
  • arithshapetensor オペレーションを含むシェイプ計算(#8)。
OpInputs        ::= OpInputValues OpInputFuncs OpInputAttrs
OpInputValues   ::= '(' [OpInputValue {',' OpInputValue}] ')'
OpInputValue    ::= ValueId
OpInputFuncs    ::= ['(' OpInputFunc {',' OpInputFunc} ')']
OpInputAttrs    ::= ['{' OpInputAttr {',' OpInputAttr} '}']
OpOutputs       ::= [OpOutput {',' OpOutput} '=']
OpOutput        ::= ValueId

オペレーションは入力を使用して出力を生成します。入力は、入力値(実行中に計算)、入力関数(StableHLO ではファーストクラスの値ではないため、静的に提供)、入力属性(これも静的に提供)に分類されます。演算によって使用、生成される入出力の種類は、そのニーモニックによって異なります。たとえば、add 演算は 2 つの入力値を使用し、1 つの出力値を生成します。これに対して、select_and_scatter 演算は 3 つの入力値、2 つの入力関数、3 つの入力属性を使用します。

OpInputFunc ::= '{' Unused FuncInputs ':' FuncBody '}'
Unused      ::= '^' digit {digit}
              | '^' letter {letter | digit}

入力関数(匿名関数とも呼ばれます)は名前付き関数とよく似ていますが、1)識別子がない(したがって「匿名」という名前になります)、2)出力型を宣言しません(出力型は関数内の return 演算から推測されます)。

入力関数の構文には、現在使用されていない部分(上記の Unused の生成を参照)が含まれています。これは、MLIR との互換性を維持するために用意されています。MLIR には、ジャンプ オペレーションを介して接続されたオペレーションの複数の「ブロック」を持つことができる「リージョン」という、より一般的なコンセプトがあります。これらのブロックには、Unused のプロダクションに対応する ID があるため、相互に区別できます。StableHLO にはジャンプ オペレーションがないため、MLIR 構文の対応する部分は使用されません(ただし、引き続き存在します)。

OpInputAttr      ::= OpInputAttrName '=' OpInputAttrValue
OpInputAttrName  ::= letter {letter | digit}
OpInputAttrValue ::= Constant

入力属性には、名前と値(サポートされている定数のいずれか)があります。プログラム要素の静的メタデータを指定する主な方法です。たとえば、concatenate 演算では属性 dimension を使用して、入力値を連結するディメンションを指定します。同様に、slice 演算は start_indiceslimit_indices などの複数の属性を使用して、入力値のスライスに使用する境界を指定します。

現時点では、実際の StableHLO プログラムには、このドキュメントでは説明されていない属性が含まれている場合があります。将来的には、これらの属性を StableHLO オペレーション セットに取り込むか、StableHLO プログラムに表示されないようにする予定です。それまでの間、属性のリストは次のとおりです。

  • layout#629)。
  • mhlo.frontend_attributes#628)。
  • mhlo.sharding#619)。
  • output_operand_aliases#740)。
  • 位置情報メタデータ(#594)。
OpSignature ::= '(' [ValueType {',' ValueType}] ')' '->' '(' [ValueType {',' ValueType}] ')'

演算シグネチャは、すべての入力値の型(-> の左側の型のリスト)と、すべての出力値の型(-> の右側の型のリスト)で構成されます。厳密に言えば、入力型は冗長であり、出力型もほぼ常に冗長です(ほとんどの StableHLO 演算では、出力型は入力から推測できるため)。それにもかかわらず、演算シグネチャは、MLIR との互換性を確保するために StableHLO 構文の一部になっています。

以下は、ニーモニックが select_and_scatter の op の例です。3 つの入力値(%operand%source%init_value)、2 つの入力関数、3 つの入力属性(window_dimensionswindow_stridespadding)を使用します。演算のシグネチャには入力値の型のみが含まれます(インラインで提供される入力関数と属性の型は含まれません)。

%result = "stablehlo.select_and_scatter"(%operand, %source, %init_value) ({
  ^bb0(%arg0: tensor<i32>, %arg1: tensor<i32>):
    %0 = "stablehlo.compare"(%arg0, %arg1) {
      comparison_direction = #stablehlo<comparison_direction GE>
    } : (tensor<i32>, tensor<i32>) -> tensor<i1>
    "stablehlo.return"(%0) : (tensor<i1>) -> ()
}, {
  ^bb0(%arg0: tensor<i32>, %arg1: tensor<i32>):
    %0 = "stablehlo.add"(%arg0, %arg1) : (tensor<i32>, tensor<i32>) -> tensor<i32>
    "stablehlo.return"(%0) : (tensor<i32>) -> ()
}) {
  window_dimensions = dense<[3, 1]> : tensor<2xi64>,
  window_strides = dense<[2, 1]> : tensor<2xi64>,
  padding = dense<[[0, 1], [0, 0]]> : tensor<2x2xi64>
} : (tensor<4x2xi32>, tensor<2x2xi32>, tensor<i32>) -> tensor<4x2xi32>

定数

Constant ::= BooleanConstant
           | IntegerConstant
           | FloatConstant
           | ComplexConstant
           | TensorConstant
           | QuantizedTensorConstant
           | StringConstant
           | EnumConstant

StableHLO 定数には、StableHLO 値を表すリテラルと型があります。通常、型は、あいまいでないこと(ブール値定数の型が i1 で、整数定数の型が複数の場合があるなど)を除き、定数構文の一部になります。

BooleanConstant ::= BooleanLiteral
BooleanLiteral  ::= 'true' | 'false'

ブール定数は、ブール値 truefalse を表します。ブール値定数の型は i1 です。

IntegerConstant   ::= IntegerLiteral ':' IntegerType
IntegerLiteral    ::= ['-' | '+'] DecimalDigits
                    | ['-' | '+'] '0x' HexadecimalDigits
DecimalDigits     ::= decimalDigit {decimalDigit}
HexadecimalDigits ::= hexadecimalDigit {hexadecimalDigit}
decimalDigit      ::= '0' | ... | '9'
hexadecimalDigit  ::= decimalDigit | 'a' | ... | 'f' | 'A' | ... | 'F'

整数定数は、10 進数または 16 進数の表記を使用した文字列で整数値を表します。他の底(2 進数や 8 進数など)はサポートされていません。整数定数には次の制約があります。

  • (C1)is_wellformed(integer_literal, integer_type)
FloatConstant  ::= FloatLiteral ':' FloatType
FloatLiteral   ::= SignPart IntegerPart FractionalPart ScientificPart
                 | '0x' [HexadecimalDigits]
SignPart       ::= ['-' | '+']
IntegerPart    ::= DecimalDigits
FractionalPart ::= ['.' [DecimalDigits]]
ScientificPart ::= [('e' | 'E') ['-' | '+'] DecimalDigits]

浮動小数点定数は、10 進数または科学表記を使用する文字列で浮動小数点値を表します。さらに、16 進数を使用して、対応する型の浮動小数点形式で、基になるビットを直接指定することもできます。浮動小数点定数には次の制約があります。

  • (C1)16 進数以外の表記を使用する場合、is_wellformed(float_literal, float_type)
  • (C2)16 進数表記を使用する場合、size(hexadecimal_digits) = num_bits(float_type) / 4
ComplexConstant ::= ComplexLiteral ':' ComplexType
ComplexLiteral  ::= '(' RealPart ',' ImaginaryPart ')'
RealPart        ::= FloatLiteral
ImaginaryPart   ::= FloatLiteral

複素定数は、実数部(最初に来る)と虚数部(次に来る)のリストを使用して複素数を表します。たとえば、(1.0, 0.0) : complex<f32>1.0 + 0.0i を表し、(0.0, 1.0) : complex<f32>0.0 + 1.0i を表します。これらのパートをメモリに保存する順序は実装で定義します。複雑な定数には、次の制約があります。

  • (C1)is_wellformed(real_part, complex_element_type(complex_type))
  • (C2)is_wellformed(imaginary_part, complex_element_type(complex_type))
TensorConstant ::= TensorLiteral ':' TensorType
TensorLiteral  ::= 'dense' '<' (DenseLiteral | ElementLiteral) '>'
DenseLiteral   ::= DenseDimension | DenseElements
DenseDimension ::= '[' [DenseLiteral {',' DenseLiteral}] ']'
DenseElements  ::= [ElementLiteral {',' ElementLiteral}]
ElementLiteral ::= BooleanLiteral | IntegerLiteral | FloatLiteral | ComplexLiteral

テンソル定数は、NumPy 表記で指定されたネストされたリストを使用してテンソル値を表します。たとえば、dense<[[1, 2, 3], [4, 5, 6]]> : tensor<2x3xi32> は、インデックスから要素へのマッピングが {0, 0} => 1{0, 1} => 2{0, 2} => 3{1, 0} => 4{1, 1} => 5{1, 2} => 6 であるテンソル値を表します。これらの要素をメモリに保存する順序は実装で定義します。テンソル定数には次の制約があります。

  • (C1)has_syntax(tensor_literal, element_type(tensor_type))。ここで:
    • has_syntax(element_literal: Syntax, element_type: Type) = is_wellformed(element_literal, type).
    • has_syntax(tensor_literal: List, element_type: Type) = has_syntax(tensor_literal..., element_type).
  • (C2)has_shape(tensor_literal, shape(tensor_type))。ここで:
    • has_shape(element_literal: Syntax, []) = true.
    • has_shape(tensor_literal: List, shape: List) = size(tensor_literal) = shape[0] and has_shape(tensor_literal..., shape[1:]).
    • それ以外の場合は false
QuantizedTensorConstant ::= QuantizedTensorLiteral ':' QuantizedTensorType
QuantizedTensorLiteral  ::= 'dense' '<' (DenseLiteral | ElementLiteral) '>'

量子化テンソル定数は、テンソル定数と同じ表記を使用して量子化テンソル値を表現し、要素はストレージ型の定数として指定されます。量子化テンソル定数には次の制約があります。

  • (C1)has_syntax(quantized_tensor_literal, storage_type(quantized_tensor_type))
  • (C2)has_shape(quantized_tensor_literal, shape(quantized_tensor_type))
StringConstant  ::= StringLiteral
StringLiteral   ::= '"' {stringCharacter | escapeSequence} '"'
stringCharacter ::= all ASCII characters except '\00', '\01', ... '\1f' and '"'
escapeSequence  ::= '\' ('"' | '\' | 'n' | 't' | (hexadecimalDigit hexadecimalDigit))

文字列リテラルは、ASCII 文字とエスケープ シーケンスを使用して指定されたバイトで構成されます。エンコードに依存しないため、これらのバイトの解釈は実装で定義します。文字列リテラルの型は string です。

Ops

abs

セマンティクス

operand テンソルに対して要素ごとの abs 演算を実行し、result テンソルを生成します。要素のタイプに応じて、次の操作を行います。

  • 符号付き整数の場合: 整数のモジュラス。
  • 浮動小数点数の場合: IEEE-754 の abs
  • 複素数の場合: 複素モジュラス。
  • 量子化型の場合: dequantize_op_quantize(abs, operand, type(result))

入力

ラベル 名前 タイプ 制約
(I1) operand 符号付き整数、浮動小数点、複合型、テンソルごとの量子化テンソル (C1 ~ C2)

出力

名前 タイプ 制約
result 符号付き整数、浮動小数点型、テンソルごとの量子化テンソル (C1 ~ C2)

制約

  • (C1)shape(result) = shape(operand)
  • (C2)baseline_element_type(result) が次のように定義されます。
    • is_complex(operand) の場合は complex_element_type(element_type(operand))
    • それ以外の場合は baseline_element_type(operand)

// %operand: [-2, 0, 2]
%result = "stablehlo.abs"(%operand) : (tensor<3xi32>) -> tensor<3xi32>
// %result: [2, 0, 2]

その他の例

add

セマンティクス

2 つのテンソル lhsrhs を要素ごとに加算し、result テンソルを生成します。要素のタイプに応じて、次の操作を行います。

  • ブール値の場合: 論理 OR。
  • 整数の場合: 整数の加算。
  • 浮動小数点数の場合: IEEE-754 の addition
  • 複素数の場合: 複雑な加算。
  • 量子化型の場合: dequantize_op_quantize(add, lhs, rhs, type(result))

入力

ラベル 名前 タイプ 制約
(I1) lhs テンソルまたはテンソルごとの量子化テンソル (C1)
(I2) rhs テンソルまたはテンソルごとの量子化テンソル (C1)

出力

名前 タイプ 制約
result テンソルまたはテンソルごとの量子化テンソル (C1)

制約

  • (C1)baseline_type(lhs) = baseline_type(rhs) = baseline_type(result)

// %lhs: [[1, 2], [3, 4]]
// %rhs: [[5, 6], [7, 8]]
%result = "stablehlo.add"(%lhs, %rhs) : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
// %result: [[6, 8], [10, 12]]

その他の例

after_all

セマンティクス

inputs を生成するオペレーションが、result に依存するオペレーションの前に実行されるようにします。このオペレーションを実行しても何も起こりません。result から inputs へのデータ依存関係を確立するためにのみ存在します。

入力

ラベル 名前 タイプ
(I1) inputs token の変数数

出力

名前 タイプ
result token

// %input0: !stablehlo.token
// %input1: !stablehlo.token
%result = "stablehlo.after_all"(%input0, %input1) : (!stablehlo.token, !stablehlo.token) -> !stablehlo.token

その他の例

all_gather

セマンティクス

StableHLO プロセス グリッドの各プロセス グループ内で、各プロセスの operand テンソルの値を all_gather_dim に沿って連結し、result テンソルを生成します。

このオペレーションでは、StableHLO プロセス グリッドが次のように定義される process_groups に分割されます。

  • cross_replica(replica_groups)channel_id <= 0 and use_global_device_ids = false の場合)。
  • cross_replica_and_partition(replica_groups)channel_id > 0 and use_global_device_ids = false の場合)。
  • flattened_ids(replica_groups)channel_id > 0 and use_global_device_ids = true の場合)。

その後、各 process_group 内で次のことを行います。

  • process_group のすべての receiver に対して operands@receiver = [operand@sender for sender in process_group]
  • process_group のすべての process に対して result@process = concatenate(operands@process, all_gather_dim)

入力

ラベル 名前 タイプ 制約
(I1) operand テンソルまたはテンソルごとの量子化テンソル (C1)、(C6)
(I2) all_gather_dim si64 型の定数 (C1)、(C6)
(I3) replica_groups si64 型の 2 次元テンソル定数 (C2 ~ C4)
(I4) channel_id si64 型の定数 (C5)
(I5) use_global_device_ids i1 型の定数 (C5)

出力

名前 タイプ 制約
result テンソルまたはテンソルごとの量子化テンソル (C6)

制約

  • (C1)0 <= all_gather_dim < rank(operand)
  • (C2)is_unique(replica_groups)
  • (C3)size(replica_groups) が次のように定義されます。
    • cross_replica を使用する場合は num_replicas
    • cross_replica_and_partition を使用する場合は num_replicas
    • flattened_ids を使用する場合は num_processes
  • (C4)0 <= replica_groups < size(replica_groups)
  • (C5)use_global_device_ids = true の場合、channel_id > 0
  • (C6)type(result) = type(operand) 例外:
    • dim(result, all_gather_dim) = dim(operand, all_gather_dim) * dim(process_groups, 1).

// num_replicas: 2
// num_partitions: 1
// %operand@(0, 0): [[1, 2], [3, 4]]
// %operand@(1, 0): [[5, 6], [7, 8]]
%result = "stablehlo.all_gather"(%operand) {
  all_gather_dim = 1 : i64,
  replica_groups = dense<[[0, 1]]> : tensor<1x2xi64>,
  // channel_id = 0
  channel_handle = #stablehlo.channel_handle<handle = 0, type = 0>
  // use_global_device_ids = false
} : (tensor<2x2xi64>) -> tensor<2x4xi64>
// %result@(0, 0): [[1, 2, 5, 6], [3, 4, 7, 8]]
// %result@(1, 0): [[1, 2, 5, 6], [3, 4, 7, 8]]

その他の例

all_reduce

セマンティクス

StableHLO プロセス グリッドの各プロセス グループ内で、各プロセスの operand テンソルの値にリダクション関数 computation を適用し、result テンソルを生成します。

このオペレーションでは、StableHLO プロセス グリッドが次のように定義される process_groups に分割されます。

  • cross_replica(replica_groups)channel_id <= 0 and use_global_device_ids = false の場合)。
  • cross_replica_and_partition(replica_groups)channel_id > 0 and use_global_device_ids = false の場合)。
  • flattened_ids(replica_groups)channel_id > 0 and use_global_device_ids = true の場合)。

その後、各 process_group 内で次のことを行います。

  • バイナリツリー schedule の場合は result@process[result_index] = exec(schedule)。ここで、:
    • exec(node) = computation(exec(node.left), exec(node.right))
    • exec(leaf) = leaf.value
  • schedule は、順序走査が to_destination_type(operands@process_group...[result_index], type(func_inputs(computation)[0])) である実装定義のバイナリツリーです。

入力

ラベル 名前 タイプ 制約
(I1) operand テンソルまたはテンソルごとの量子化テンソル (C5)、(C6)
(I2) replica_groups si64 型の 1 次元テンソル定数の可変数 (C1 ~ C3)
(I3) channel_id si64 型の定数 (C4)
(I4) use_global_device_ids i1 型の定数 (C4)
(I5) computation 機能 (C5)

出力

名前 タイプ 制約
result テンソルまたはテンソルごとの量子化テンソル (C6 ~ C7)

制約

  • (C1)is_unique(replica_groups)
  • (C2)size(replica_groups) が次のように定義されます。
    • cross_replica を使用する場合は num_replicas
    • cross_replica_and_partition を使用する場合は num_replicas
    • flattened_ids を使用する場合は num_processes
  • (C3)0 <= replica_groups < size(replica_groups)
  • (C4)use_global_device_ids = true の場合、channel_id > 0
  • (C5)computation の型は (tensor<E>, tensor<E>) -> (tensor<E>) で、is_promotable(element_type(operand), E) です。
  • (C6)shape(result) = shape(operand)
  • (C7)element_type(result) = E

// num_replicas: 2
// num_partitions: 1
// %operand@(0, 0): [1, 2, 3, 4]
// %operand@(1, 0): [5, 6, 7, 8]
%result = "stablehlo.all_reduce"(%operand) ({
  ^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
    %0 = "stablehlo.add"(%arg0, %arg1) : (tensor<i64>, tensor<i64>) -> tensor<i64>
    "stablehlo.return"(%0) : (tensor<i64>) -> ()
}) {
  replica_groups = dense<[[0, 1]]> : tensor<1x2xi64>,
  channel_handle = #stablehlo.channel_handle<handle = 0, type = 0>
} : (tensor<i64>) -> tensor<i64>
// %result@(0, 0): [6, 8, 10, 12]
// %result@(1, 0): [6, 8, 10, 12]

その他の例

all_to_all

セマンティクス

StableHLO プロセス グリッドの各プロセス グループ内で、operand テンソルの値を split_dimension に沿って一部に分割し、分割された部分をプロセス間で分散させて concat_dimension に沿って連結し、result テンソルを生成します。

このオペレーションでは、StableHLO プロセス グリッドが次のように定義される process_groups に分割されます。

  • channel_id <= 0 の場合は cross_replica(replica_groups)
  • channel_id > 0 の場合は cross_partition(replica_groups)

その後、各 process_group 内で次のことを行います。

  • process_group のすべての sender に対して split_parts@sender = split(operand@sender, split_count, split_dimension)
  • scattered_parts@receiver = [split_parts@sender[receiver_index] for sender in process_group](ここで、receiver_index = process_group.index(receiver))。
  • result@process = concatenate(scattered_parts@process, concat_dimension).

入力

ラベル 名前 タイプ 制約
(I1) operand テンソルまたはテンソルごとの量子化テンソル (C1 ~ C3)、(C9)
(I2) split_dimension si64 型の定数 (C1)、(C2)、(C9)
(I3) concat_dimension si64 型の定数 (C3)、(C9)
(I4) split_count si64 型の定数 (C2)、(C4)、(C8)、(C9)
(I5) replica_groups si64 型の 2 次元テンソル定数 (C5 ~ C8)
(I6) channel_id si64 型の定数

出力

名前 タイプ 制約
result テンソルまたはテンソルごとの量子化テンソル (C9)

制約

  • (C1)0 <= split_dimension < rank(operand)
  • (C2)dim(operand, split_dimension) % split_count = 0
  • (C3)0 <= concat_dimension < rank(operand)
  • (C4)0 < split_count
  • (C5)is_unique(replica_groups)
  • (C6)size(replica_groups) は次のように定義されます。
    • cross_replica を使用する場合は num_replicas
    • cross_partition を使用する場合は num_partitions
  • (C7)0 <= replica_groups < size(replica_groups)
  • (C8)dim(replica_groups, 1) = split_count
  • (C9)type(result) = type(operand)(ただし、以下を除く):
    • dim(result, split_dimension) = dim(operand, split_dimension) / split_count.
    • dim(result, concat_dimension) = dim(operand, concat_dimension) * split_count.

// num_replicas: 2
// num_partitions: 1
// %operand@(0, 0): [[1, 2, 3, 4],
//                   [5, 6, 7, 8]]
// %operand@(1, 0): [[9, 10, 11, 12],
//                   [13, 14, 15, 16]]
%result = "stablehlo.all_to_all"(%operand) {
  split_dimension = 1 : i64,
  concat_dimension = 0 : i64,
  split_count = 2 : i64,
  replica_groups = dense<[[0, 1]]> : tensor<1x2xi64>
} : (tensor<2x4xi64>) -> tensor<4x2xi64>
// %result@(0, 0): [[1, 2],
//                  [5, 6],
//                  [9, 10],
//                  [13, 14]]
// %result@(1, 0): [[3, 4],
//                  [7, 8],
//                  [11, 12],
//                  [15, 16]]

その他の例

and

セマンティクス

2 つのテンソル lhsrhs の要素ごとの AND を実行し、result テンソルを生成します。要素のタイプに応じて、次の操作を行います。

  • ブール値の場合: 論理 AND。
  • 整数の場合: ビット演算 AND。

入力

ラベル 名前 タイプ 制約
(I1) lhs ブール値または整数型のテンソル (C1)
(I2) rhs ブール値または整数型のテンソル (C1)

出力

名前 タイプ 制約
result ブール値または整数型のテンソル (C1)

制約

  • (C1)type(lhs) = type(rhs) = type(result)

// %lhs: [[1, 2], [3, 4]]
// %rhs: [[5, 6], [7, 8]]
%result = "stablehlo.and"(%lhs, %rhs) : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
// %result: [[1, 2], [3, 0]]

atan2

セマンティクス

lhs テンソルと rhs テンソルに対して要素ごとの atan2 演算を実行し、result テンソルを生成します。要素のタイプに応じて、次の操作を行います。

  • 浮動小数点数の場合: IEEE-754 の atan2
  • 複素数の場合: 複素 atan2。
  • 量子化型の場合: dequantize_op_quantize(atan2, lhs, rhs, type(result))

入力

ラベル 名前 タイプ 制約
(I1) lhs 浮動小数点型または複素数のテンソル、またはテンソルごとの量子化テンソル (C1)
(I2) rhs 浮動小数点型または複素数のテンソル、またはテンソルごとの量子化テンソル (C1)

出力

名前 タイプ 制約
result 浮動小数点型または複素数のテンソル、またはテンソルごとの量子化テンソル (C1)

制約

  • (C1)baseline_type(lhs) = baseline_type(rhs) = baseline_type(result)

// %lhs: [0.0, 1.0, -1.0]
// %rhs: [0.0, 0.0, 0.0]
%result = "stablehlo.atan2"(%lhs, %rhs) : (tensor<3xf64>, tensor<3xf64>) -> tensor<3xf64>
// %result: [0.0, 1.57079637, -1.57079637] // [0.0, pi/2, -pi/2]

その他の例

batch_norm_grad

セマンティクス

grad_output から逆伝播する複数の batch_norm_training 入力の勾配を計算し、grad_operandgrad_scalegrad_offset テンソルを生成します。より正式には、このオペレーションは、次のように Python 構文を使用して既存の StableHLO オペレーションへの分解として表すことができます。

def compute_sum(operand, feature_index):
  (sum,) = reduce(
      inputs=[operand],
      init_values=[constant(0, element_type(operand))],
      dimensions=[i for i in range(rank(operand)) if i != feature_index],
      body=lambda x, y: add(x, y))
  return sum

def compute_mean(operand, feature_index):
  sum = compute_sum(operand, feature_index)
  divisor = constant(size(operand) / dim(operand, feature_index),
                     element_type(operand))
  divisor_bcast = broadcast_in_dim(divisor, [], type(sum))
  return divide(sum, divisor_bcast)

def batch_norm_grad(operand, scale, mean, variance, grad_output, epsilon, feature_index):
  # Broadcast inputs to type(operand)
  scale_bcast = broadcast_in_dim(scale, [feature_index], type(operand))
  mean_bcast = broadcast_in_dim(mean, [feature_index], type(operand))
  variance_bcast = broadcast_in_dim(variance, [feature_index], type(operand))
  epsilon_bcast = broadcast_in_dim(constant(epsilon, element_type(operand)), [],
                                   type(operand))

  # Perform normalization using the provided `mean` and `variance`
  # Intermediate values will be useful for computing gradients
  centered_operand = subtract(operand, mean_bcast)
  stddev = sqrt(add(variance_bcast, epsilon_bcast))
  normalized_operand = divide(centered_operand, stddev)

  # Use the implementation from batchnorm_expander.cc in XLA
  # Temporary variables have exactly the same names as in the C++ code
  elements_per_feature = broadcast_in_dim(
      constant(divide(size(operand), dim(operand, feature_index)),
               element_type(grad_output)),
      [], type(operand))
  i1 = multiply(grad_output, elements_per_feature)
  i2 = broadcast_in_dim(
      compute_sum(grad_output, feature_index), [feature_index], type(operand))
  i3 = broadcast_in_dim(
      compute_sum(multiply(grad_output, centered_operand), feature_index),
      [feature_index], type(operand))
  i4 = multiply(i3, centered_operand)
  i5 = divide(i4, add(variance_bcast, epsilon_bcast))
  i6 = subtract(subtract(i1, i2), i5)

  grad_operand =
      multiply(divide(divide(scale_bcast, stddev), elements_per_feature), i6)
  grad_scale =
      compute_sum(multiply(grad_output, normalized_operand), feature_index)
  grad_offset = compute_sum(grad_output, feature_index)

  return grad_operand, grad_scale, grad_offset

量子化型の場合、dequantize_batch_norm_grad_or_training_quantize(lambda operand, scale, mean, variance, grad_output: batch_norm_grad(operand, scale, mean, variance, grad_output, epsilon, feature_index), operand, scale, mean, variance, grad_output, type(grad_operand), type(grad_scale), type(feature_index)) を実行します。

入力

ラベル 名前 タイプ 制約
(I1) operand 浮動小数点型またはテンソルごとの量子化テンソル (C1 ~ C3)、(C5)
(I2) scale 浮動小数点数またはテンソルごとの量子化型の 1 次元テンソル (C2)、(C4)、(C5)
(I3) mean 浮動小数点数またはテンソルごとの量子化型の 1 次元テンソル (C2)、(C4)
(I4) variance 浮動小数点数またはテンソルごとの量子化型の 1 次元テンソル (C2)、(C4)
(I5) grad_output 浮動小数点型またはテンソルごとの量子化テンソル (C2)、(C3)
(I6) epsilon f32 型の定数
(I7) feature_index si64 型の定数 (C1)、(C5)

出力

名前 タイプ 制約
grad_operand 浮動小数点型またはテンソルごとの量子化テンソル (C2)、(C3)
grad_scale 浮動小数点数またはテンソルごとの量子化型の 1 次元テンソル (C2)、(C4)
grad_offset 浮動小数点数またはテンソルごとの量子化型の 1 次元テンソル (C2)、(C4)

制約

  • (C1)0 <= feature_index < rank(operand)
  • (C2)operandscalemeanvariancegrad_outputgrad_operandgrad_scalegrad_offsetbaseline_element_type が同じ。
  • (C3)operandgrad_outputgrad_operand が同じ形状である。
  • (C4)scalemeanvariancegrad_scalegrad_offset が同じ形状である。
  • (C5)size(scale) = dim(operand, feature_index)

// %operand: [
//            [[1.0, 2.0], [3.0, 4.0]],
//            [[3.0, 4.0], [1.0, 2.0]]
//           ]
// %scale: [1.0, 1.0]
// %mean: [2.0, 3.0]
// %variance: [1.0, 1.0]
// %grad_output: [
//                [[0.1, 0.1], [0.1, 0.1]],
//                [[0.1, 0.1], [0.1, 0.1]]
//               ]
%grad_operand, %grad_scale, %grad_offset =
"stablehlo.batch_norm_grad"(%operand, %scale, %mean, %variance, %grad_output) {
  epsilon = 0.0 : f32,
  feature_index = 2 : i64
} : (tensor<2x2x2xf64>, tensor<2xf64>, tensor<2xf64>, tensor<2xf64>,
     tensor<2x2x2xf64>) -> (tensor<2x2x2xf64>, tensor<2xf64>, tensor<2xf64>)
// %grad_operand: [
//                 [[0.0, 0.0], [0.0, 0.0]],
//                 [[0.0, 0.0], [0.0, 0.0]]
//                ]
// %grad_scale:  [0.0, 0.0]
// %grad_offset: [0.4, 0.4]

batch_norm_inference

セマンティクス

feature_index ディメンションを除くすべてのディメンションで operand テンソルを正規化し、result テンソルを生成します。より正式には、このオペレーションは、次のように Python 構文を使用して既存の StableHLO オペレーションへの分解として表すことができます。

def batch_norm_inference(operand, scale, offset, mean, variance, epsilon, feature_index):
  # Broadcast inputs to shape(operand)
  scale_bcast = broadcast_in_dim(scale, [feature_index], type(operand))
  offset_bcast = broadcast_in_dim(offset, [feature_index], type(operand))
  mean_bcast = broadcast_in_dim(mean, [feature_index], type(operand))
  variance_bcast = broadcast_in_dim(variance, [feature_index], type(operand))
  epsilon_bcast = broadcast_in_dim(constant(epsilon, element_type(operand)), [],
                                   type(operand))

  # Perform normalization using the provided `mean` and `variance` instead of
  # computing them like `batch_norm_training` does.
  centered_operand = subtract(operand, mean_bcast)
  stddev = sqrt(add(variance_bcast, epsilon_bcast))
  normalized_operand = divide(centered_operand, stddev)
  return add(multiply(scale_bcast, normalized_operand), offset_bcast)

量子化型の場合、dequantize_op_quantize(lambda operand, scale, offset, mean, variance: batch_norm_inference(operand, scale, offset, mean, variance, epsilon, feature_index), operand, scale, offset, mean, variance, type(result)) を実行します。

入力

ラベル 名前 タイプ 制約
(I1) operand 浮動小数点型またはテンソルごとの量子化テンソル (C1 ~ C7)
(I2) scale 浮動小数点数またはテンソルごとの量子化型の 1 次元テンソル (C2)、(C3)
(I3) offset 浮動小数点数またはテンソルごとの量子化型の 1 次元テンソル (C2)、(C4)
(I4) mean 浮動小数点数またはテンソルごとの量子化型の 1 次元テンソル (C5)
(I5) variance 浮動小数点数またはテンソルごとの量子化型の 1 次元テンソル (C2)、(C6)
(I6) epsilon f32 型の定数
(I7) feature_index si64 型の定数 (C1)、(C3 ~ C6)

出力

名前 タイプ 制約
result 浮動小数点型またはテンソルごとの量子化テンソル (C2)、(C7)

制約

  • (C1)0 <= feature_index < rank(operand)
  • (C2)operandscaleoffsetmeanvarianceresultbaseline_element_type が同じ。
  • (C3)size(scale) = dim(operand, feature_index)
  • (C4)size(offset) = dim(operand, feature_index)
  • (C5)size(mean) = dim(operand, feature_index)
  • (C6)size(variance) = dim(operand, feature_index)
  • (C7)baseline_type(operand) = baseline_type(result)

// %operand: [
//            [[1.0, 2.0], [3.0, 4.0]],
//            [[3.0, 4.0], [1.0, 2.0]]
//           ]
// %scale: [1.0, 1.0]
// %offset: [1.0, 1.0]
// %mean: [2.0, 3.0]
// %variance: [1.0, 1.0]
%result = "stablehlo.batch_norm_inference"(%operand, %scale, %offset, %mean, %variance) {
  epsilon = 0.0 : f32,
  feature_index = 2 : i64
} : (tensor<2x2x2xf64>, tensor<2xf64>, tensor<2xf64>, tensor<2xf64>, tensor<2xf64>) -> tensor<2x2x2xf64>
// %result: [
//           [[0.0, 0.0], [2.0, 2.0]],
//           [[2.0, 2.0], [0.0, 0.0]]
//          ]

batch_norm_training

セマンティクス

feature_index ディメンションを除くすべての次元の平均と分散を計算し、outputbatch_meanbatch_var テンソルを生成する operand テンソルを正規化します。より正式には、このオペレーションは、次のように Python 構文を使用して既存の StableHLO オペレーションへの分解として表すことができます。

def compute_mean(operand, feature_index):
  (sum,) = reduce(
      inputs=[operand],
      init_values=[constant(0, element_type(operand))],
      dimensions=[i for i in range(rank(operand)) if i != feature_index],
      body=lambda x, y: add(x, y))
  divisor = constant(size(operand) / dim(operand, feature_index),
                     element_type(operand))
  divisor_bcast = broadcast_in_dim(divisor, [], type(sum))
  return divide(sum, divisor_bcast)

def compute_variance(operand, feature_index):
  mean = compute_mean(operand, feature_index)
  mean_bcast = broadcast_in_dim(mean, [feature_index], type(operand))
  centered_operand = subtract(operand, mean_bcast)
  return compute_mean(mul(centered_operand, centered_operand), feature_index)

def batch_norm_training(operand, scale, offset, epsilon, feature_index):
  mean = compute_mean(operand, feature_index)
  variance = compute_variance(operand, feature_index)
  return batch_norm_inference(operand, scale, offset, mean, variance, epsilon,
                              feature_index),
         mean, variance

量子化型の場合、dequantize_batch_norm_grad_or_training_quantize(lambda operand, scale, offset: batch_norm_training(operand, scale, offset, epsilon, feature_index), operand, scale, offset, type(output), type(batch_mean), type(batch_var)) を実行します。

入力

ラベル 名前 タイプ 制約
(I1) operand 浮動小数点型またはテンソルごとの量子化テンソル (C1)
(I2) scale 浮動小数点数またはテンソルごとの量子化された 1 次元テンソル (C2)、(C3)
(I3) offset 浮動小数点数またはテンソルごとの量子化された 1 次元テンソル (C2)、(C4)
(I4) epsilon f32 型の定数 (C1)、(C3 ~ C6)
(I5) feature_index si64 型の定数 (C1)、(C3 ~ C6)

出力

名前 タイプ 制約
output 浮動小数点型またはテンソルごとの量子化テンソル (C7)
batch_mean 浮動小数点数またはテンソルごとの量子化された 1 次元テンソル (C2)、(C5)
batch_var 浮動小数点数またはテンソルごとの量子化された 1 次元テンソル (C2)、(C6)

制約

  • (C1)0 <= feature_index < rank(operand)
  • (C2)operandscaleoffsetbatch_meanbatch_varoutputbaseline_element_type が同じ。
  • (C3)size(scale) = dim(operand, feature_index)
  • (C4)size(offset) = dim(operand, feature_index)
  • (C5)size(batch_mean) = dim(operand, feature_index)
  • (C6)size(batch_var) = dim(operand, feature_index)
  • (C7)baseline_type(output) = baseline_type(operand)

// %operand: [
//            [[1.0, 2.0], [3.0, 4.0]],
//            [[3.0, 4.0], [1.0, 2.0]]
//           ]
// %scale: [1.0, 1.0]
// %offset: [1.0, 1.0]
%output, %batch_mean, %batch_var = "stablehlo.batch_norm_training"(%operand, %scale, %offset) {
  epsilon = 0.0 : f32,
  feature_index = 2 : i64
} : (tensor<2x2x2xf64>, tensor<2xf64>, tensor<2xf64>) ->
    (tensor<2x2x2xf64>, tensor<2xf64>, tensor<2xf64>)
// %output: [
//           [[0.0, 0.0], [2.0, 2.0]],
//           [[2.0, 2.0], [0.0, 0.0]]
//          ]
// %batch_mean: [2.0, 3.0]
// %batch_var: [1.0, 1.0]

bitcast_convert

セマンティクス

operand テンソルに対してビットキャスト演算を実行し、operand テンソル全体のビットが result テンソルの型を使用して再解釈された result テンソルを生成します。

E = element_type(operand)E' = element_type(result)R = rank(operand) を指定して、より正式に次のように記述します。

  • num_bits(E') < num_bits(E) の場合、bits(result[i0, ..., iR-1, :]) = bits(operand[i0, ..., iR-1])
  • num_bits(E') > num_bits(E) の場合、bits(result[i0, ..., iR-2]) = bits(operand[i0, ..., iR-2, :])
  • num_bits(E') = num_bits(E) の場合、bits(result[i0, ..., iR-1]) = bits(operand[i0, ..., iR-1])

bits は指定された値のメモリ内表現を返します。テンソルの正確な表現は実装によって定義され、要素の型の正確な表現は実装によって定義されるため、その動作は実装が定義されています。

入力

ラベル 名前 タイプ 制約
(I1) operand テンソルまたは量子化テンソル (C1 ~ C2)

出力

名前 タイプ 制約
result テンソルまたは量子化テンソル (C1 ~ C2)

制約

  • (C1)E = is_quantized(operand) ? storage_type(operand) : element_type(operand)E' = is_quantized(result) ? storage_type(result) : element_type(result)R = rank(operand) がある場合:
    • num_bits(E') = num_bits(E) の場合は、shape(result) = shape(operand)
    • num_bits(E') < num_bits(E) の場合:
    • rank(result) = R + 1.
    • すべての 0 <= i < R に対して dim(result, i) = dim(operand, i)
    • dim(result, R) * num_bits(E') = num_bits(E).
    • num_bits(E') > num_bits(E) の場合:
    • rank(result) = R - 1.
    • すべての 0 <= i < R に対して dim(result, i) = dim(operand, i)
    • dim(operand, R - 1) * num_bits(E) = num_bits(E').
  • (C2)is_complex(operand) or is_complex(result) の場合、is_complex(operand) and is_complex(result)

// %operand: 0x0123456789ABCDEF
%result = "stablehlo.bitcast_convert"(%operand) : (tensor<f64>) -> tensor<4xf16>
// %result: [0xCDEF, 0x89AB, 0x4567, 0x0123] // little-endian representation

その他の例

broadcast_in_dim

セマンティクス

operand テンソルにデータを複製することにより入力テンソルの次元や階数を拡張し、result テンソルを生成します。より正式には、result[result_index] = operand[operand_index]axes(operand) 内のすべての d について次のように記述します。

  • dim(operand, d) = 1 の場合は operand_index[d] = 0
  • それ以外の場合は operand_index[d] = result_index[broadcast_dimensions[d]]

入力

ラベル 名前 タイプ 制約
(I1) operand テンソルまたは量子化テンソル (C1 ~ C2)、(C5 ~ C6)
(I2) broadcast_dimensions si64 型の 1 次元テンソル定数 (C2 ~ C6)

出力

名前 タイプ 制約
result テンソルまたは量子化テンソル (C1)、(C3)、(C5 ~ C6)

制約

  • (C1)element_type(result) は次のように表されます。
    • element_type(operand)!is_per_axis_quantized(operand) の場合)。
    • element_type(operand)。ただし、quantization_dimension(operand)scales(operand)zero_points(operand) は、quantization_dimension(result)scales(result)zero_points(result) の応答と異なる場合があります。
  • (C2)size(broadcast_dimensions) = rank(operand)
  • (C3)0 <= broadcast_dimensions < rank(result)
  • (C4)is_unique(broadcast_dimensions)
  • (C5)axes(operand) 内のすべての d に対して:
    • dim(operand, d) = 1 または
    • dim(operand, d) = dim(result, broadcast_dimensions[d]).
  • (C6)is_per_axis_quantized(result) の場合:
    • quantization_dimension(result) = broadcast_dimensions[quantization_dimension(operand)].
    • dim(operand, quantization_dimension(operand)) = 1 の場合は、scales(result)[i] = scales(operand)[0] and zero_points(result)[i] = zero_points(operand)[0] for i in range(dim(result, quantization_dimension(result))) です。

// %operand: [
//            [1, 2, 3]
//           ]
%result = "stablehlo.broadcast_in_dim"(%operand) {
  broadcast_dimensions = array<i64: 2, 1>
} : (tensor<1x3xi32>) -> tensor<2x3x2xi32>
// %result: [
//            [
//             [1, 1],
//             [2, 2],
//             [3, 3]
//            ],
//            [
//             [1, 1],
//             [2, 2],
//             [3, 3]
//            ]
//          ]

その他の例

ケース

セマンティクス

index の値に応じて、branches から関数を 1 つだけ実行して出力を生成します。より正式には result = selected_branch() です。ここで、

  • 0 <= index < size(branches) の場合は selected_branch = branches[index]
  • それ以外の場合は selected_branch = branches[-1]

入力

ラベル 名前 タイプ 制約
(I1) index si32 型の 0 次元テンソル
(I2) branches 可変関数数 (C1 ~ C4)

出力

名前 タイプ 制約
results テンソル、量子化テンソル、トークンの可変数 (C4)

制約

  • (C1)0 < size(branches)
  • (C2)input_types(branches...) = []
  • (C3)same(output_types(branches...))
  • (C4)type(results...) = output_types(branches[0])

// %index: -1
// %result_branch0: [0, 0]
// %result_branch1: [1, 1]
%result0, %result1 = "stablehlo.case"(%index) ({
  "stablehlo.return"(%result_branch0, %result_branch0) : (tensor<2xi64>, tensor<2xi64>) -> ()
}, {
  "stablehlo.return"(%result_branch1, %result_branch1) : (tensor<2xi64>, tensor<2xi64>) -> ()
}) : (tensor<i32>) -> (tensor<2xi64>, tensor<2xi64>)
// %result0: [1, 1]
// %result1: [1, 1]

その他の例

cbrt

セマンティクス

operand テンソルに対して要素ごとの立方根演算を実行し、result テンソルを生成します。要素のタイプに応じて、次の操作を行います。

  • 浮動小数点数の場合: IEEE-754 の rootn(x, 3)
  • 複素数の場合: 複素立方根。
  • 量子化型の場合: dequantize_op_quantize(cbrt, operand, type(result))

入力

ラベル 名前 タイプ 制約
(I1) operand 浮動小数点型または複素数のテンソル、またはテンソルごとの量子化テンソル (C1)

出力

名前 タイプ 制約
result 浮動小数点型または複素数のテンソル、またはテンソルごとの量子化テンソル (C1)

制約

  • (C1)baseline_type(operand) = baseline_type(result)

// %operand: [0.0, 1.0, 8.0, 27.0]
%result = "stablehlo.cbrt"(%operand) : (tensor<4xf64>) -> tensor<4xf64>
// %result: [0.0, 1.0, 2.0, 3.0]

その他の例

ceil

セマンティクス

operand テンソルの要素ごとに ceil を実行し、result テンソルを生成します。IEEE-754 仕様の roundToIntegralTowardPositive オペレーションを実装します。量子化型の場合、dequantize_op_quantize(ceil, operand, type(result)) を実行します。

入力

ラベル 名前 タイプ 制約
(I1) operand 浮動小数点型またはテンソルごとの量子化テンソル (C1)

出力

名前 タイプ 制約
result 浮動小数点型またはテンソルごとの量子化テンソル (C1)

制約

  • (C1)baseline_type(operand) = baseline_type(result)

// %operand: [-0.8166, -0.2530, 0.2530, 0.8166, 2.0]
%result = "stablehlo.ceil"(%operand) : (tensor<5xf32>) -> tensor<5xf32>
// %result: [-0.0, -0.0, 1.0, 1.0, 2.0]

その他の例

コレスキー

セマンティクス

一連の行列のコレスキー分解を計算します。

より正式には、index_space(result) のすべての i について、result[i0, ..., iR-3, :, :] は下三角行列(lowertrue の場合)または上三角行列(lowerfalse の場合)のいずれかの形式で、a[i0, ..., iR-3, :, :] のコレスキー分解になります。反対側の三角形の出力値(対応する厳密な上三角形または厳密な下方三角形)は実装で定義されます。

入力行列がエルミート正定行列でない i が存在する場合、動作は未定義になります。

量子化型の場合、dequantize_op_quantize(lambda operand: cholesky(operand, lower), a, type(result)) を実行します。

入力

ラベル 名前 タイプ 制約
(I1) a 浮動小数点型または複素数のテンソル、またはテンソルごとの量子化テンソル (C1 ~ C3)
(I2) lower i1 型の 0 次元テンソル定数

出力

名前 タイプ 制約
result 浮動小数点型または複素数のテンソル、またはテンソルごとの量子化テンソル (C1)

制約

  • (C1)baseline_type(a) = baseline_type(result)
  • (C2)2 <= rank(a)
  • (C3)dim(a, -2) = dim(a, -1)

// %a: [
//      [1.0, 2.0, 3.0],
//      [2.0, 20.0, 26.0],
//      [3.0, 26.0, 70.0]
//     ]
%result = "stablehlo.cholesky"(%a) {
  lower = true
} : (tensor<3x3xf32>) -> tensor<3x3xf64>
// %result: [
//           [1.0, 0.0, 0.0],
//           [2.0, 4.0, 0.0],
//           [3.0, 5.0, 6.0]
//          ]

クランプ

セマンティクス

operand テンソルのすべての要素を最小値と最大値の間にクロックして、result テンソルを生成します。より正式には result[result_index] = minimum(maximum(operand[result_index], min_element), max_element) です(min_element = rank(min) = 0 ? min[] : min[result_index]max_element = rank(max) = 0 ? max[] : max[result_index])。量子化型の場合、dequantize_op_quantize(clamp, min, operand, max, type(result)) を実行します。

複素数に順序付けを適用すると、意外なセマンティクスも必要になります。そのため、今後、この演算では複素数のサポートを削除する予定です(#560)。

入力

ラベル 名前 タイプ 制約
(I1) min テンソルまたはテンソルごとの量子化テンソル (C1)、(C3)
(I2) operand テンソルまたはテンソルごとの量子化テンソル (C1 ~ C4)
(I3) max テンソルまたはテンソルごとの量子化テンソル (C2)、(C3)

出力

名前 タイプ 制約
result テンソルまたはテンソルごとの量子化テンソル (C4)

制約

  • (C1)rank(min) = 0 or shape(min) = shape(operand)
  • (C2)rank(max) = 0 or shape(max) = shape(operand)
  • (C3)baseline_element_type(min) = baseline_element_type(operand) = baseline_element_type(max)
  • (C4)baseline_type(operand) = baseline_type(result)

// %min: [5, 10, 15]
// %operand: [3, 13, 23]
// %max: [10, 15, 20]
%result = "stablehlo.clamp"(%min, %operand, %max) : (tensor<3xi32>, tensor<3xi32>, tensor<3xi32>) -> tensor<3xi32>
// %result: [5, 13, 20]

その他の例

collective_broadcast

セマンティクス

StableHLO プロセス グリッドの各プロセス グループ内で、operand テンソルの値をソースプロセスからターゲット プロセスに送信し、result テンソルを生成します。

このオペレーションでは、StableHLO プロセス グリッドが次のように定義される process_groups に分割されます。

  • channel_id <= 0 の場合は cross_replica(replica_groups)
  • channel_id > 0 の場合は cross_partition(replica_groups)

その後、result@process は次のように与えられます。

  • operand@process_groups[i, 0](プロセスが process_groups[i] にあるような i が存在する場合)。
  • broadcast_in_dim(constant(0, element_type(result)), [], type(result)) に制限されます。

入力

ラベル 名前 タイプ 制約
(I1) operand テンソル (C3)
(I2) replica_groups si64 型の 1 次元テンソル定数の可変数 (C1)、(C2)
(I3) channel_id si64 型の定数

出力

名前 タイプ 制約
result テンソル (C3)

制約

  • (C1)is_unique(replica_groups)
  • (C2)0 <= replica_groups < NN は次のように定義されます。
    • cross_replica を使用する場合は num_replicas
    • cross_partition を使用する場合は num_partitions
  • (C3)type(result) = type(operand)

// num_replicas: 4
// num_partitions: 1
// %operand@(0, 0): [[1, 2]]
// %operand@(1, 0): [[3, 4]]
// %operand@(2, 0): [[5, 6]]
// %operand@(3, 0): [[7, 8]]
%result = "stablehlo.collective_broadcast"(%operand) {
  replica_groups = dense<[[2, 1]]> : tensor<1x2xi64>,
  channel_handle = #stablehlo.channel_handle<handle = 0, type = 0>
} : (tensor1x2xi64>) -> tensor<1x2xi64>
// %result@(0, 0): [[0, 0]]
// %result@(1, 0): [[5, 6]]
// %result@(2, 0): [[5, 6]]
// %result@(3, 0): [[0, 0]]

collective_permute

セマンティクス

StableHLO プロセス グリッドの各プロセス グループ内で、operand テンソルの値をソースプロセスからターゲット プロセスに送信し、result テンソルを生成します。

このオペレーションでは、StableHLO プロセス グリッドが次のように定義される process_groups に分割されます。

  • channel_id <= 0 の場合は cross_replica(source_target_pairs)
  • channel_id > 0 の場合は cross_partition(source_target_pairs)

その後、result@process は次のように与えられます。

  • operand@process_groups[i, 0]process_groups[i, 1] = process のような i が存在する場合)。
  • broadcast_in_dim(constant(is_quantized(result) ? quantize(0, element_type(result)) : 0, element_type(result)), [], type(result)) に制限されます。

入力

ラベル 名前 タイプ 制約
(I1) operand テンソルまたはテンソルごとの量子化テンソル (C5)
(I2) source_target_pairs si64 型の 2 次元テンソル定数 (C1 ~ C4)
(I3) channel_id si64 型の定数

出力

名前 タイプ 制約
result テンソルまたはテンソルごとの量子化テンソル (C1)

制約

  • (C1)dim(source_target_pairs, 1) = 2
  • (C2)is_unique(source_target_pairs[:, 0])
  • (C3)is_unique(source_target_pairs[:, 1])
  • (C4)0 <= source_target_pairs < NN は次のように定義されます。
    • cross_replica を使用する場合は num_replicas
    • cross_partition を使用する場合は num_partitions
  • (C5)type(result) = type(operand)

// num_replicas: 3
// num_partitions: 1
// %operand@(0, 0): [[1, 2], [3, 4]]
// %operand@(1, 0): [[5, 6], [7, 8]]
// %operand@(2, 0): [[9, 10], [11, 12]]
%result = "stablehlo.collective_permute"(%operand) {
  source_target_pairs = dense<[[0, 1], [1, 2]]> : tensor<2x2xi64>,
  channel_handle = #stablehlo.channel_handle<handle = 0, type = 0>
} : (tensor<2x2xi64>) -> tensor<2x2xi64>
//
// %result@(0, 0): [[0, 0], [0, 0]]
// %result@(1, 0): [[1, 2], [3, 4]]
// %result@(2, 0): [[5, 6], [7, 8]]

その他の例

compare

セマンティクス

comparison_directioncompare_type に従って、lhs テンソルと rhs テンソルの要素単位の比較を行い、result テンソルを生成します。

comparison_directioncompare_type の値のセマンティクスは次のとおりです。

ブール値と整数の要素タイプの場合:

  • EQ: lhs = rhs
  • NE: lhs != rhs
  • GE: lhs >= rhs
  • GT: lhs > rhs
  • LE: lhs <= rhs
  • LT: lhs < rhs

compare_type = FLOAT の浮動小数点要素型の場合、演算は次の IEEE-754 演算を実装します。

  • EQ: compareQuietEqual
  • NE: compareQuietNotEqual
  • GE: compareQuietGreaterEqual
  • GT: compareQuietGreater
  • LE: compareQuietLessEqual
  • LT: compareQuietLess

compare_type = TOTALORDER の浮動小数点要素型の場合、演算は IEEE-754 の totalOrder 演算と compareQuietEqual 演算を組み合わせて使用します。この機能は使用されていないようですので、今後削除する予定です(#584)。

複合要素タイプの場合、(real, imag) ペアの辞書順の比較は、指定された comparison_directioncompare_type を使用して行われます。複素数に順序付けを課すと、意外なセマンティクスが登場するため、今後、comparison_directionGEGTLELT の場合に複素数のサポートを削除する予定です(#560)。

量子化型の場合、dequantize_compare(lhs, rhs, comparison_direction) を実行します。

入力

ラベル 名前 タイプ 制約
(I1) lhs テンソルまたはテンソルごとの量子化テンソル (C1 ~ C3)
(I2) rhs テンソルまたはテンソルごとの量子化テンソル (C1 ~ C2)
(I3) comparison_direction EQNEGEGTLELT の列挙型
(I4) compare_type FLOATTOTALORDERSIGNEDUNSIGNED の列挙型 (C3)

出力

名前 タイプ 制約
result ブール型のテンソル (C2)

制約

  • (C1)baseline_element_type(lhs) = baseline_element_type(rhs)
  • (C2)shape(lhs) = shape(rhs) = shape(result)
  • (C3)compare_type が次のように定義されます。
    • is_signed_integer(element_type(lhs)) の場合は SIGNED
    • is_unsigned_integer(element_type(lhs)) or is_boolean(element_type(lhs)) の場合は UNSIGNED
    • FLOAT または TOTALORDERis_float(element_type(lhs)) の場合)。
    • is_complex(element_type(lhs)) の場合は FLOAT

// %lhs: [1.0, 3.0]
// %rhs: [1.1, 2.9]
%result = "stablehlo.compare"(%lhs, %rhs) {
  comparison_direction = #stablehlo<comparison_direction LT>,
  compare_type = #stablehlo<comparison_type FLOAT>
} : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xi1>
// %result: [true, false]

その他の例

複雑

セマンティクス

実数値と虚数値のペア lhsrhs から複素値に要素ごとに変換し、result テンソルを生成します。

入力

ラベル 名前 タイプ 制約
(I1) lhs f32 型または f64 型のテンソル (C1 ~ C3)
(I2) rhs f32 型または f64 型のテンソル (C1)

出力

名前 タイプ 制約
result 複素型のテンソル (C2)、(C3)

制約

  • (C1)type(lhs) = type(rhs)
  • (C2)shape(result) = shape(lhs)
  • (C3)element_type(result) の型は complex<E> で、E = element_type(lhs) です。

// %lhs: [1.0, 3.0]
// %rhs: [2.0, 4.0]
%result = "stablehlo.complex"(%lhs, %rhs) : (tensor<2xf64>, tensor<2xf64>) -> tensor<2xcomplex<f64>>
// %result: [(1.0, 2.0), (3.0, 4.0)]

その他の例

concatenate

セマンティクス

指定された引数と同じ順序で inputsdimension ディメンションに沿って連結し、result テンソルを生成します。より正式には result[i0, ..., id, ..., iR-1] = inputs[k][i0, ..., kd, ..., iR-1] です。

  1. id = d0 + ... + dk-1 + kd.
  2. ddimension と等しく、d0 は、inputsd 番目の次元サイズです。

入力

ラベル 名前 タイプ 制約
(I1) inputs テンソル数またはテンソルごとの量子化テンソルの可変数 (C1 ~ C6)
(I2) dimension si64 型の定数 (C2)、(C4)、(C6)

出力

名前 タイプ 制約
result テンソルまたはテンソルごとの量子化テンソル (C5 ~ C6)

制約

  • (C1)same(element_type(inputs...))
  • (C2)same(shape(inputs...))dim(inputs..., dimension) を除く)。
  • (C3)0 < size(inputs)
  • (C4)0 <= dimension < rank(inputs[0])
  • (C5)element_type(result) = element_type(inputs[0])
  • (C6)shape(result) = shape(inputs[0]) 例外:
    • dim(result, dimension) = dim(inputs[0], dimension) + ....

// %input0: [[1, 2], [3, 4], [5, 6]]
// %input1: [[7, 8]]
%result = "stablehlo.concatenate"(%input0, %input1) {
  dimension = 0 : i64
} : (tensor<3x2xi64>, tensor<1x2xi64>) -> tensor<4x2xi64>
// %result: [[1, 2], [3, 4], [5, 6], [7, 8]]

その他の例

定数

セマンティクス

定数 value から output テンソルを生成します。

入力

ラベル 名前 タイプ 制約
(I1) value 定数 (C1)

出力

名前 タイプ 制約
output テンソルまたは量子化テンソル (C1)

制約

  • (C1)type(value) = type(output)

%output = "stablehlo.constant"() {
  value = dense<[[0.0, 1.0], [2.0, 3.0]]> : tensor<2x2xf32>
} : () -> tensor<2x2xf32>
// %output: [[0.0, 1.0], [2.0, 3.0]]

その他の例

コンバージョン

セマンティクス

operand テンソルで要素タイプから別の要素タイプへの要素単位の変換を実行し、result テンソルを生成します。

ブール値から任意のサポートされている型への変換では、値 false は 0 に変換され、値 true は 1 に変換されます。boolean-to-any-supported-typeany-supported-type-to-booleanへの変換では、ゼロの値は false に変換され、ゼロ以外の値は true に変換されます。複雑な型での仕組みについては、下記をご覧ください。

integer-to-integerinteger-to-floating-pointfloating-point-to-floating-point が関係する変換において、ソース値を変換先の型で正確に表すことができる場合、結果の値は正確に表現されます。それ以外の場合、動作は未定です(#180)。

floating-point-to-integerへの変換を含む変換では、小数部が切り捨てられます。切り捨てられた値をデスティネーション タイプで表現できない場合、動作は未定です(#180)。

複素数から複素数への変換を含む変換は、実数部と虚数部を変換する場合と同じ動作になります。

「complex-to-any-other-type」complex-to-any-other-typeコンバージョンと「any-other-type-to-complex」complex-to-any-other-typeコンバージョンの場合、ソースの虚偽値は無視され、デスティネーションの虚数はゼロになります。実数部の変換は、浮動小数点変換に従います。

原理上、この演算は逆量子化(量子化テンソルから正規テンソルへの変換)、量子化(正規テンソルから量子テンソルへの変換)、再量子化(量子化テンソル間の変換)を表現できますが、現時点では、最初のユースケースでは uniform_dequantize、2 番目と 3 番目のユースケースでは uniform_quantize を使用しています。今後、これら 2 つの演算は convert に統合される可能性があります(#1576)。

入力

ラベル 名前 タイプ 制約
(I1) operand テンソル (C1)

出力

名前 タイプ 制約
result テンソル (C1)

制約

  • (C1)shape(operand) = shape(result)

// %operand: [-1, 0, 1]
%result = "stablehlo.convert"(%operand) : (tensor<3xi64>) -> tensor<3xcomplex<f64>>
// %result: [(-1.0, 0.0), (0.0, 0.0), (1.0, 0.0)]

その他の例

畳み込み

セマンティクス

lhs のウィンドウと rhs のスライスの間のドット積を計算し、result を生成します。次の図は、具体的な例を使用して、result の要素が lhsrhs から計算される方法を示しています。

より正式には、lhs のウィンドウを表現できるように、lhs の観点から入力を次のように再フレーミングすることを検討してください。

  • lhs_window_dimensions = lhs_shape(dim(lhs, input_batch_dimension), dim(rhs, kernel_spatial_dimensions), dim(lhs, input_feature_dimension)).
  • lhs_window_strides = lhs_shape(1, window_strides, 1).
  • lhs_padding = lhs_shape([0, 0], padding, [0, 0]).
  • lhs_base_dilations = lhs_shape(1, lhs_dilation, 1).
  • lhs_window_dilations = lhs_shape(1, rhs_dilation, 1).

この再構成では、次のヘルパー関数を使用します。

  • lhs_shape(n, hw, c) = permute([n] + hw + [c], [input_batch_dimension] + input_spatial_dimensions + [input_feature_dimension]).
  • result_shape(n1, hw, c1) = permute([n1] + hw + [c1], [output_batch_dimension] + output_spatial_dimensions + [output_feature_dimension]).
  • permute([j0, j1, ..., jR-1], permutation) = [i0, i1, ..., iR-1]j[d] = i[permutation[d]])。

feature_group_count = 1 かつ batch_group_count = 1 の場合は、index_space(dim(result, output_spatial_dimensions...)) 内のすべての output_spatial_index に対して、result[result_shape(:, output_spatial_index, :)] = dot_product は次のようになります。

  • padding_value = constant(0, element_type(lhs)).
  • padded_lhs = pad(lhs, padding_value, lhs_padding[:, 0], lhs_padding[:, 1], lhs_base_dilations - 1).
  • lhs_window_start = lhs_shape(0, output_spatial_index, 0) * lhs_window_strides.
  • lhs_window = slice(padded_lhs, lhs_window_start, lhs_window_start + lhs_window_dimensions, lhs_window_dilations).
  • reversed_lhs_window = reverse(lhs_window, [input_spatial_dimensions[dim] for dim in range(size(window_reversal)) if window_reversal[dim] = true])。この機能は使用されていないようです。そのため、今後削除する予定です(#1181)。
  • dot_product = dot_general(reversed_lhs_window, rhs, lhs_batching_dimensions=[], lhs_contracting_dimensions=input_spatial_dimensions + [input_feature_dimension], rhs_batching_dimensions=[], rhs_contracting_dimensions=kernel_spatial_dimensions + [kernel_input_feature_dimension]).

feature_group_count > 1 の場合:

  • lhses = split(lhs, feature_group_count, input_feature_dimension).
  • rhses = split(rhs, feature_group_count, kernel_output_feature_dimension).
  • results... = convolution(lhses..., rhses..., ..., feature_group_count=1, ...).
  • result = concatenate(results, output_feature_dimension).

batch_group_count > 1 の場合:

  • lhses = split(lhs, batch_group_count, input_batch_dimension).
  • rhses = split(rhs, batch_group_count, kernel_output_feature_dimension).
  • results... = convolution(lhses..., rhses..., ..., batch_group_count=1, ...).
  • result = concatenate(results, output_feature_dimension)

量子化型の場合、dequantize_op_quantize( lambda lhs, rhs: convolution(lhs, rhs, window_strides, padding, lhs_dilation, rhs_dilation, window_reversal, input_batch_dimension, input_feature_dimension, input_spatial_dimensions, kernel_input_feature_dimension, kernel_output_feature_dimension, kernel_spatial_dimensions, output_batch_dimension, output_feature_dimension, output_spatial_dimensions, feature_group_count, batch_group_count, precision_config), lhs, rhs, type(result)) を実行します。

入力

ラベル 名前 タイプ 制約
(I1) lhs テンソルまたはテンソルごとの量子化テンソル (C1)、(C10 ~ C11)、(C14)(C25)、(C27 ~ C30)
(I2) rhs テンソルまたは量子化テンソル (C1)、(C14 ~ C16)、(C25)、(C27 ~ C32)
(I3) window_strides si64 型の 1 次元テンソル定数 (C2 ~ C3)、(C25)
(I4) padding si64 型の 2 次元テンソル定数 (C4)、(C25)
(I5) lhs_dilation si64 型の 1 次元テンソル定数 (C5 ~ C6)、(C25)
(I6) rhs_dilation si64 型の 1 次元テンソル定数 (C7 ~ C8)、(C25)
(I7) window_reversal i1 型の 1 次元テンソル定数 (C9)
(I8) input_batch_dimension si64 型の定数 (C10)、(C13)、(C25)
(I9) input_feature_dimension si64 型の定数 (C11)、(C13 ~ C14)
(I10) input_spatial_dimensions si64 型の 1 次元テンソル定数 (C12)、(C13)、(C25)
(I11) kernel_input_feature_dimension si64 型の定数 (C14)、(C18)
(I12) kernel_output_feature_dimension si64 型の定数 (C15 ~ C16)、(C18)、(C25)、(C32)
(I13) kernel_spatial_dimensions si64 型の 1 次元テンソル定数 (C17 ~ C18)、(C25)
(I14) output_batch_dimension si64 型の定数 (C20)、(C25)
(I15) output_feature_dimension si64 型の定数 (C20)、(C25)、(C33)
(I16) output_spatial_dimensions si64 型の 1 次元テンソル定数 (C19 ~ C20)、(C25)
(I17) feature_group_count si64 型の定数 (C11)、(C14)、(C16)、(C21)、(C23)
(I18) batch_group_count si64 型の定数 (C10)、(C15)、(C22)、(C23)、(C25)
(I19) precision_config DEFAULTHIGHHIGHEST の可変列挙型 (C24)

出力

名前 タイプ 制約
result テンソルまたは量子化テンソル (C25 ~ C28)、(C30 ~ C31)、(C33)

制約

  • (C1)N = rank(lhs) = rank(rhs)
  • (C2)size(window_strides) = N - 2
  • (C3)0 < window_strides
  • (C4)shape(padding) = [N - 2, 2]
  • (C5)size(lhs_dilation) = N - 2
  • (C6)0 < lhs_dilation
  • (C7)size(rhs_dilation) = N - 2
  • (C8)0 < rhs_dilation
  • (C9)size(window_reversal) = N - 2
  • (C10)dim(lhs, input_batch_dimension) % batch_group_count = 0
  • (C11)dim(lhs, input_feature_dimension) % feature_group_count = 0
  • (C12)size(input_spatial_dimensions) = N - 2
  • (C13)input_dimensions = [input_batch_dimension] + input_spatial_dimensions + [input_feature_dimension] の場合:
    • is_unique(input_dimensions).
    • 0 <= input_dimensions < N.
  • (C14)dim(rhs, kernel_input_feature_dimension) = dim(lhs, input_feature_dimension) / feature_group_count
  • (C15)dim(rhs, kernel_output_feature_dimension) % batch_group_count = 0
  • (C16)dim(rhs, kernel_output_feature_dimension) % feature_group_count = 0
  • (C17)size(kernel_spatial_dimensions) = N - 2
  • (C18)kernel_dimensions = kernel_spatial_dimensions + [kernel_input_feature_dimension] + [kernel_output_feature_dimension] の場合:
    • is_unique(kernel_dimensions).
    • 0 <= kernel_dimensions < N.
  • (C19)size(output_spatial_dimensions) = N - 2
  • (C20)output_dimensions = [output_batch_dimension] + output_spatial_dimensions + [output_feature_dimension] の場合:
    • is_unique(output_dimensions).
    • 0 <= output_dimensions < N.
  • (C21)0 < feature_group_count
  • (C22)0 < batch_group_count
  • (C23)feature_group_count = 1 or batch_group_count = 1
  • (C24)size(precision_config) = 2
  • (C25)dim(result, result_dim) は次のように定義されます。
    • result_dim = output_batch_dimension の場合は dim(lhs, input_batch_dimension) / batch_group_count
    • result_dim = output_feature_dimension の場合は dim(rhs, kernel_output_feature_dimension)
    • それ以外の場合は num_windows。ここで、
    • output_spatial_dimensions[spatial_dim] = result_dim.
    • lhs_dim = input_spatial_dimensions[spatial_dim].
    • rhs_dim = kernel_spatial_dimensions[spatial_dim].
    • dilated_input_shape[lhs_dim] = dim(lhs, lhs_dim) = 0 ? 0 : (dim(lhs, lhs_dim) - 1) * lhs_dilation[spatial_dim] + 1.
    • padded_input_shape[lhs_dim] = padding[spatial_dim, 0] + dilated_input_shape[lhs_dim] + padding[spatial_dim, 1].
    • dilated_window_shape[lhs_dim] = dim(rhs, rhs_dim) = 0 ? 0 : (dim(rhs, rhs_dim) - 1) * rhs_dilation[spatial_dim] + 1.
    • is_empty_window[lhs_dim] = padded_input_shape[lhs_dim] = 0 || dilated_window_shape[lhs_dim] > padded_input_shape[lhs_dim].
    • num_windows = is_empty_window[lhs_dim] ? 0 : floor((padded_input_shape[lhs_dim] - dilated_window_shape[lhs_dim]) / window_strides[spatial_dim]) + 1.
  • (C26)rank(result) = N
  • 演算で量子化されていないテンソルを使用する場合:
    • (C27)element_type(lhs) = element_type(rhs) = element_type(result)
  • 演算に量子化テンソルを使用する場合:
    • (C28)is_quantized_tensor(lhs) and is_quantized_tensor(rhs) and is_quantized_tensor(result)
    • (C29)storage_type(lhs) = storage_type(rhs)
    • (C30)expressed_type(lhs) = expressed_type(rhs) = expressed_type(result)
    • (C31)is_per_tensor_quantized(rhs) の場合、is_per_tensor_quantized(result)
    • (C32)is_per_axis_quantized(rhs) の場合、quantization_dimension(rhs) = kernel_output_feature_dimension
    • (C33)is_per_axis_quantized(result) の場合、quantization_dimension(result) = output_feature_dimension

// %lhs: [[
//        [
//          [1], [2], [5], [6]
//        ],
//        [
//          [3], [4], [7], [8]
//        ],
//        [
//          [10], [11], [14], [15]
//        ],
//        [
//          [12], [13], [16], [17]
//        ]
//      ]]
//
// %rhs : [
//         [[[1]], [[1]], [[1]]],
//         [[[1]], [[1]], [[1]]],
//         [[[1]], [[1]], [[1]]]
//        ]
%result = "stablehlo.convolution"(%lhs, %rhs) {
  window_strides = dense<4> : tensor<2xi64>,
  padding = dense<0> : tensor<2x2xi64>,
  lhs_dilation = dense<2> : tensor<2xi64>,
  rhs_dilation = dense<1> : tensor<2xi64>,
  window_reversal = dense<false> : tensor<2xi1>,
  // In the StableHLO dialect, dimension numbers are encoded via:
  // `[<input dimensions>]x[<kernel dimensions>]->[output dimensions]`.
  // "b" is batch dimension, "f" is feature dimension,
  // "i" is input feature dimension, "o" is output feature dimension,
  // "0/1/etc" are spatial dimensions.
  dimension_numbers = #stablehlo.conv<[b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f]>,
  feature_group_count = 1 : i64,
  batch_group_count = 1 : i64,
  precision_config = [#stablehlo<precision DEFAULT>, #stablehlo<precision DEFAULT>]
} : (tensor<1x4x4x1xi32>, tensor<3x3x1x1xi32>) -> tensor<1x2x2x1xi32>
// %result: [[
//            [[10], [26]],
//            [[46], [62]]
//          ]]

コサイン

セマンティクス

operand テンソルに対して要素ごとのコサイン演算を実行し、result テンソルを生成します。要素のタイプに応じて、次の操作を行います。

  • 浮動小数点数の場合: IEEE-754 の cos
  • 複素数の場合: 複素コサイン。
  • 量子化型の場合: dequantize_op_quantize(cosine, operand, type(result))

入力

ラベル 名前 タイプ 制約
(I1) operand 浮動小数点型または複素数のテンソル、またはテンソルごとの量子化テンソル (C1)

出力

名前 タイプ 制約
result 浮動小数点型または複素数のテンソル、またはテンソルごとの量子化テンソル (C1)

制約

  • (C1)baseline_type(operand) = baseline_type(result)

// %operand: [
//            [0.0, 1.57079632],       // [0, pi/2]
//            [3.14159265, 4.71238898] // [pi, 3pi/2]
//           ]
%result = "stablehlo.cosine"(%operand) : (tensor<2x2xf32>) -> tensor<2x2xf32>
// %result: [[1.0, 0.0], [-1.0, 0.0]]

その他の例

count_leading_zeros

セマンティクス

operand テンソル内の先頭のゼロビット数を要素ごとにカウントし、result テンソルを生成します。

入力

ラベル 名前 タイプ 制約
(I1) operand 整数型のテンソル (C1)

出力

名前 タイプ 制約
result 整数型のテンソル (C1)

制約

  • (C1)type(operand) = type(result)

// %operand: [[0, 1], [128, -1]]
%result = "stablehlo.count_leading_zeros"(%operand) : (tensor<2x2xi64>) -> tensor<2x2xi64>
// %result: [[64, 63], [56, 0]]

その他の例

custom_call

セマンティクス

inputscalled_computations を受け取って results を生成する実装定義のオペレーション call_target_name をカプセル化します。has_side_effectbackend_configapi_version を使用して、実装定義の追加のメタデータを指定できます。

現時点で、このオペレーションには、XLA コンパイラにおける対応オペレーションの有機的進化を反映した、かなり整理されたメタデータのコレクションが含まれています。今後、このメタデータを統合する予定です(#741)。

入力

ラベル 名前 タイプ
(I1) inputs 可変長の値
(I2) call_target_name string 型の定数
(I3) has_side_effect i1 型の定数
(I4) backend_config string 型の定数
(I5) api_version si32 型の定数
(I6) called_computations string 型の定数の可変数

出力

名前 タイプ
results 可変長の値

%results = "stablehlo.custom_call"(%input0) {
  call_target_name = "foo",
  has_side_effect = false,
  backend_config = "bar",
  api_version = 1 : i32,
  called_computations = [@foo]
} : (tensor<f64>) -> tensor<f64>

割る

セマンティクス

被除数 lhs と除数 rhs テンソルを要素ごとに除算し、result テンソルを生成します。要素のタイプに応じて、次の操作を行います。

  • 整数の場合: 小数部を破棄した代数の商を生成する整数除算。
  • 浮動小数点数の場合: IEEE-754 の division
  • 複素数の場合: 複素除算。
  • 量子化型の場合:
    • dequantize_op_quantize(divide, lhs, rhs, type(result)).

入力

ラベル 名前 タイプ 制約
(I1) lhs 整数、浮動小数点、複合型、テンソルごとの量子化テンソル (C1)
(I2) rhs 整数、浮動小数点、複合型、テンソルごとの量子化テンソル (C1)

出力

名前 タイプ 制約
result 整数、浮動小数点、複合型のテンソル、またはテンソルごとの量子化テンソル (C1)

制約

  • (C1)baseline_type(lhs) = baseline_type(rhs) = baseline_type(result)

// %lhs: [17.1, -17.1, 17.1, -17.1]
// %rhs: [3.0, 3.0, -3.0, -3.0]
%result = "stablehlo.divide"(%lhs, %rhs) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
// %result: [5.66666651, -5.66666651, -5.66666651, 5.66666651]

その他の例

dot_general

セマンティクス

lhs のスライスと rhs のスライス間のドット積を計算し、result テンソルを生成します。

より正式には result[result_index] = dot_product です。ここで:

  • lhs_result_dimensions = [d for d in axes(lhs) and d not in lhs_batching_dimensions and d not in lhs_contracting_dimensions].
  • rhs_result_dimensions = [d for d in axes(rhs) and d not in rhs_batching_dimensions and d not in rhs_contracting_dimensions].
  • result_batching_index + result_lhs_index + result_rhs_index = result_index。ここで、size(result_batching_index) = size(lhs_batching_dimensions)size(result_lhs_index) = size(lhs_result_dimensions)size(result_rhs_index) = size(rhs_result_dimensions)
  • transposed_lhs = transpose(lhs, lhs_batching_dimensions + lhs_result_dimensions + lhs_contracting_dimensions).
  • transposed_lhs_slice = slice(transposed_lhs, result_batching_index + result_lhs_index + [:, ..., :]).
  • reshaped_lhs_slice = reshape(transposed_lhs_slice, dims(lhs, lhs_contracting_dimensions)).
  • transposed_rhs = transpose(rhs, rhs_batching_dimensions + rhs_result_dimensions + rhs_contracting_dimensions).
  • transposed_rhs_slice = slice(transposed_rhs, result_batching_index + result_rhs_index + [:, ..., :]).
  • reshaped_rhs_slice = reshape(transposed_rhs_slice, dims(rhs, rhs_contracting_dimensions)).
  • dot_product = reduce( inputs=[multiply(reshaped_lhs_slice, reshaped_rhs_slice)], init_values=[constant(0, element_type(result))], dimensions=range(size(lhs_contracting_dimensions)), body=lambda x, y: add(x, y))

量子化型の場合、dequantize_op_quantize( lambda lhs, rhs: dot_general(lhs, rhs, lhs_batching_dimensions, rhs_batching_dimensions, lhs_contracting_dimensions, rhs_contracting_dimensions, precision_config), lhs, rhs, type(result)) を実行します。

これは、テンソルごとの量子化のセマンティクスを指定するだけです。軸ごとの量子化が進行中です(#1574)。将来的には、ハイブリッド量子化のサポートを追加する可能性があります(#1575)。

precision_config は、アクセラレータ バックエンドでのコンピューティングの速度と精度のトレードオフを制御します。これは次のいずれかになります(現時点では、これらの列挙値のセマンティクスは規定されていませんが、#755 で対処する予定です)。

  • DEFAULT: 計算は最も高速ですが、元の数値への近似精度は最も低くなります。
  • HIGH: 計算速度は遅くなりますが、元の数値への近似値がより正確になります。
  • HIGHEST: 計算は最も遅くなりますが、元の数値への近似値が最も高くなります。

入力

ラベル 名前 タイプ 制約
(I1) lhs テンソルまたはテンソルごとの量子化テンソル (C5 ~ C6)、(C9 ~ C10)、(C12 ~ C16)
(I2) rhs テンソルまたはテンソルごとの量子化テンソル (C7 ~ C10)、(C12)
(I3) lhs_batching_dimensions si64 型の 1 次元テンソル定数 (C1)、(C3)、(C5)、(C9)、(C12)
(I4) rhs_batching_dimensions si64 型の 1 次元テンソル定数 (C1)、(C4)、(C7)、(C9)
(I5) lhs_contracting_dimensions si64 型の 1 次元テンソル定数 (C2)、(C3)、(C6)、(C10)
(I6) rhs_contracting_dimensions si64 型の 1 次元テンソル定数 (C2)、(C4)、(C8)、(C10)
(I7) precision_config DEFAULTHIGHHIGHEST の可変列挙型 (C11)

出力

名前 タイプ 制約
result テンソルまたはテンソルごとの量子化テンソル (C12)、(C14)、(C16)

制約

  • (C1)size(lhs_batching_dimensions) = size(rhs_batching_dimensions)
  • (C2)size(lhs_contracting_dimensions) = size(rhs_contracting_dimensions)
  • (C3)is_unique(lhs_batching_dimensions + lhs_contracting_dimensions)
  • (C4)is_unique(rhs_batching_dimensions + rhs_contracting_dimensions)
  • (C5)0 <= lhs_batching_dimensions < rank(lhs)
  • (C6)0 <= lhs_contracting_dimensions < rank(lhs)
  • (C7)0 <= rhs_batching_dimensions < rank(rhs)
  • (C8)0 <= rhs_contracting_dimensions < rank(rhs)
  • (C9)dim(lhs, lhs_batching_dimensions...) = dim(rhs, rhs_batching_dimensions...)
  • (C10)dim(lhs, lhs_contracting_dimensions...) = dim(rhs, rhs_contracting_dimensions...)
  • (C11)size(precision_config) = 2
  • (C12)shape(result) = dim(lhs, lhs_batching_dimensions) + dim(lhs, lhs_result_dimensions) + dim(rhs, rhs_result_dimensions)
  • 演算で量子化されていないテンソルを使用する場合:
    • (C13)element_type(lhs) = element_type(rhs)
  • 演算に量子化テンソルを使用する場合:
    • (C14)is_quantized(lhs) and is_quantized(rhs) and is_quantized(result)
    • (C15)storage_type(lhs) = storage_type(rhs)
    • (C16)expressed_type(lhs) = expressed_type(rhs) = expressed_type(result)
    • (C17)zero_points(rhs) = 0

// %lhs: [
//        [[1, 2],
//         [3, 4]],
//        [[5, 6],
//         [7, 8]]
//       ]
// %rhs: [
//        [[1, 0],
//         [0, 1]],
//        [[1, 0],
//         [0, 1]]
//       ]
%result = "stablehlo.dot_general"(%lhs, %rhs) {
  dot_dimension_numbers = #stablehlo.dot<
    lhs_batching_dimensions = [0],
    rhs_batching_dimensions = [0],
    lhs_contracting_dimensions = [2],
    rhs_contracting_dimensions = [1]
  >,
  precision_config = [#stablehlo<precision DEFAULT>, #stablehlo<precision DEFAULT>]
} : (tensor<2x2x2xi64>, tensor<2x2x2xi64>) -> tensor<2x2x2xi64>
// %result: [
//           [[1, 2],
//            [3, 4]],
//           [[5, 6],
//            [7, 8]]
//          ]

その他の例

dynamic_slice

セマンティクス

動的に計算された開始インデックスを使用して operand からスライスを抽出し、result テンソルを生成します。start_indices には調整対象となる各ディメンションのスライスの開始インデックスが含まれ、slice_sizes には各ディメンションのスライスのサイズが含まれます。より正式には result[result_index] = operand[operand_index] です。ここで、

  • adjusted_start_indices = clamp(0, start_indices, shape(operand) - slice_sizes).
  • operand_index = adjusted_start_indices + result_index.

入力

ラベル 名前 タイプ 制約
(I1) operand テンソルまたはテンソルごとの量子化テンソル (C1)、(C2)、(C4)
(I2) start_indices 整数型の 0 次元テンソルの可変数 (C2)、(C3)
(I3) slice_sizes si64 型の 1 次元テンソル定数 (C2)、(C4)、(C5)

出力

名前 タイプ 制約
result テンソルまたはテンソルごとの量子化テンソル (C1)、(C5)

制約

  • (C1)element_type(operand) = element_type(result)
  • (C2)size(start_indices) = size(slice_sizes) = rank(operand)
  • (C3)same(type(start_indices...))
  • (C4)0 <= slice_sizes <= shape(operand)
  • (C5)shape(result) = slice_sizes

// %operand: [
//            [0, 0, 1, 1],
//            [0, 0, 1, 1],
//            [0, 0, 0, 0],
//            [0, 0, 0, 0]
//           ]
// %start_indices0: -1
// %start_indices1: 3
%result = "stablehlo.dynamic_slice"(%operand, %start_indices0, %start_indices1) {
  slice_sizes = dense<[2, 2]> : tensor<2xi64>
} : (tensor<4x4xi32>, tensor<i64>, tensor<i64>) -> tensor<2x2xi32>
// %result: [
//           [1, 1],
//           [1, 1]
//          ]

その他の例

dynamic_update_slice

セマンティクス

start_indices で始まるスライスが update の値で更新されることを除けば、operand テンソルと等しい result テンソルを生成します。より正式には、result[result_index] は次のように定義されます。

  • 0 <= update_index < shape(update) の場合は update[update_index]。ここで、
    • adjusted_start_indices = clamp(0, start_indices, shape(operand) - shape(update)).
    • update_index = result_index - adjusted_start_indices.
  • それ以外の場合は operand[result_index]

入力

ラベル 名前 タイプ 制約
(I1) operand テンソルまたはテンソルごとの量子化テンソル (C1 ~ C4)、(C6)
(I2) update テンソルまたはテンソルごとの量子化テンソル (C2)、(C3)、(C6)
(I3) start_indices 整数型の 0 次元テンソルの可変数 (C4)、(C5)

出力

名前 タイプ 制約
result テンソルまたはテンソルごとの量子化テンソル (C1)

制約

  • (C1)type(operand) = type(result)
  • (C2)element_type(update) = element_type(operand)
  • (C3)rank(update) = rank(operand)
  • (C4)size(start_indices) = rank(operand)
  • (C5)same(type(start_indices...))
  • (C6)0 <= shape(update) <= shape(operand)

// %operand: [
//            [1, 1, 0, 0],
//            [1, 1, 0, 0],
//            [1, 1, 1, 1],
//            [1, 1, 1, 1]
//           ]
// %update: [
//           [1, 1],
//           [1, 1]
//          ]
// %start_indices0: -1
// %start_indices1: 3
%result = "stablehlo.dynamic_update_slice"(%operand, %update, %start_indices0, %start_indices1)
  : (tensor<4x4xi32>, tensor<2x2xi32>, tensor<i64>, tensor<i64>) -> tensor<4x4xi32>
// %result: [
//           [1, 1, 1, 1],
//           [1, 1, 1, 1],
//           [1, 1, 1, 1],
//           [1, 1, 1, 1]
//          ]

その他の例

指数

セマンティクス

operand テンソルに対して要素単位の指数演算を実行し、result テンソルを生成します。要素のタイプに応じて、次の操作を行います。

  • 浮動小数点数の場合: IEEE-754 の exp
  • 複素数の場合: 複素指数。
  • 量子化型の場合: dequantize_op_quantize(exponential, operand, type(result))

入力

ラベル 名前 タイプ 制約
(I1) operand 浮動小数点型または複素数のテンソル、またはテンソルごとの量子化テンソル (C1)

出力

名前 タイプ 制約
result 浮動小数点型または複素数のテンソル、またはテンソルごとの量子化テンソル (C1)

制約

  • (C1)baseline_type(operand) = baseline_type(result)

// %operand: [[0.0, 1.0], [2.0, 3.0]]
%result = "stablehlo.exponential"(%operand) : (tensor<2x2xf64>) -> tensor<2x2xf64>
// %result: [[1.0, 2.7182818284590451], [7.3890560989306504, 20.085536923187668]]

その他の例

exponential_minus_one

セマンティクス

operand テンソルに対して要素単位の指数から 1 を引いた演算を実行し、result テンソルを生成します。要素のタイプに応じて、次の操作を行います。

  • 浮動小数点数の場合: IEEE-754 の expm1
  • 複素数の場合: 複素指数 - 1
  • 量子化型の場合: dequantize_op_quantize(exponential_minus_one, operand, type(result))

入力

ラベル 名前 タイプ 制約
(I1) operand 浮動小数点型または複素数のテンソル、またはテンソルごとの量子化テンソル (C1)

出力

名前 タイプ 制約
result 浮動小数点型または複素数のテンソル、またはテンソルごとの量子化テンソル (C1)

制約

  • (C1)baseline_type(operand) = baseline_type(result)

// %operand: [0.0, 1.0]
%result = "stablehlo.exponential_minus_one"(%operand) : (tensor<2xf64>) -> tensor<2xf64>
// %result: [0.0, 1.71828187]

その他の例

FFT

セマンティクス

実数と複雑な入出力に対して順フーリエ変換と逆フーリエ変換を実行します。

fft_type は、次のいずれかです。

  • FFT: 複素数から複素数への FFT を転送します。
  • IFFT: 複素数から複素数への逆 FFT です。
  • RFFT: 実数から複素数への FFT 転送を行います。
  • IRFFT: 実数から複素数への逆 FFT です(複素数を取り、実数を返します)。

より正式には、複雑な型の 1 次元テンソルを入力として受け取る関数 fft で、出力と同じ型の 1 次元テンソルを生成し、離散フーリエ変換を計算します。

fft_type = FFT の場合、result は一連の L 演算の最終結果として定義されます。ただし、L = size(fft_length) となります。たとえば、L = 3 の場合は次のようになります。

  • result1[i0, ..., :] = fft(operand[i0, ..., :]).
  • result2[i0, ..., :, iR-1] = fft(result1[i0, ..., :, iR-1]).
  • result[i0, ..., :, iR-2, iR-1] = fft(result2[i0, ..., :, iR-2, iR-1]).

さらに、同じ型署名を持ち、fft の逆数を計算する関数 ifft があるとします。

fft_type = IFFT の場合、resultfft_type = FFT の計算の逆数として定義されます。たとえば、L = 3 の場合は次のようになります。

  • result1[i0, ..., :, iR-2, iR-1] = ifft(operand[i0, ..., :, iR-2, iR-1]).
  • result2[i0, ..., :, iR-1] = ifft(result1[i0, ..., :, iR-1]).
  • result[i0, ..., :] = ifft(result2[i0, ..., :]).

さらに、浮動小数点型の 1 次元テンソルを受け取る関数 rfft を指定すると、同じ浮動小数点セマンティクスの複雑な型の 1 次元テンソルが生成され、次のように動作します。

  • rfft(real_operand) = truncated_result ここで
  • complex_operand... = (real_operand..., 0.0).
  • complex_result = fft(complex_operand).
  • truncated_result = complex_result[:(rank(complex_result) / 2 + 1)].

(実際のオペランドに対して離散フーリエ変換が計算される場合、結果の最初の N/2 + 1 要素によって結果の残りの部分が明確に定義されるため、冗長な要素の計算を回避するために rfft の結果は切り捨てられます)。

fft_type = RFFT の場合、result は一連の L 演算の最終結果として定義されます。ただし、L = size(fft_length) となります。たとえば、L = 3 の場合は次のようになります。

  • result1[i0, ..., :] = rfft(operand[i0, ..., :]).
  • result2[i0, ..., :, iR-1] = fft(result1[i0, ..., :, iR-1]).
  • result[i0, ..., :, iR-2, iR-1] = fft(result2[i0, ..., :, iR-2, iR-1]).

最後に、同じ型署名を持ち、rfft の逆数を計算する関数 irfft があるとします。

fft_type = IRFFT の場合、resultfft_type = RFFT の計算の逆数として定義されます。たとえば、L = 3 の場合は次のようになります。

  • result1[i0, ..., :, iR-2, iR-1] = ifft(operand[i0, ..., :, iR-2, iR-1]).
  • result2[i0, ..., :, iR-1] = ifft(result1[i0, ..., :, iR-1]).
  • result[i0, ..., :] = irfft(result2[i0, ..., :]).

入力

ラベル 名前 タイプ 制約
(I1) operand 浮動小数点型または複素数のテンソル (C1)、(C2)、(C4)、(C5)
(I2) fft_type FFTIFFTRFFTIRFFT の列挙型 (C2)、(C5)
(I3) fft_length si64 型の 1 次元テンソル定数 (C1)、(C3)、(C4)

出力

名前 タイプ 制約
result 浮動小数点型または複素数のテンソル (C2)、(C4)、(C5)

制約

  • (C1)size(fft_length) <= rank(operand)
  • (C2)operand 要素と result 要素タイプの関係はさまざまです。
    • fft_type = FFTelement_type(operand)element_type(result) の複合型は同じです。
    • fft_type = IFFTelement_type(operand)element_type(result) の複合型は同じです。
    • fft_type = RFFT の場合、element_type(operand) は浮動小数点型であり、element_type(result) は同じ浮動小数点セマンティクスの複合型です。
    • fft_type = IRFFT の場合、element_type(operand) は複合型であり、element_type(result) は同じ浮動小数点セマンティクスの浮動小数点型です。
  • (C3)1 <= size(fft_length) <= 3
  • (C4)operandresult の間に、浮動小数点型のテンソル real があり、次に shape(real)[-size(fft_length):] = fft_length があります。
  • (C5)shape(result) = shape(operand)(以下を除く):
    • fft_type = RFFT の場合、dim(result, -1) = dim(operand, -1) = 0 ? 0 : dim(operand, -1) / 2 + 1
    • fft_type = IRFFT の場合、dim(operand, -1) = dim(result, -1) = 0 ? 0 : dim(result, -1) / 2 + 1

// %operand: [(1.0, 0.0), (0.0, 0.0), (0.0, 0.0), (0.0, 0.0)]
%result = "stablehlo.fft"(%operand) {
  fft_type = #stablehlo<fft_type FFT>,
  fft_length = dense<4> : tensor<1xi64>
} : (tensor<4xcomplex<f32>>) -> tensor<4xcomplex<f32>>
// %result: [(1.0, 0.0), (1.0, 0.0), (1.0, 0.0), (1.0, 0.0)]

floor

セマンティクス

operand テンソルの要素ごとの下限を実行し、result テンソルを生成します。IEEE-754 仕様の roundToIntegralTowardNegative オペレーションを実装します。量子化型の場合、dequantize_op_quantize(floor, operand, type(result)) を実行します。

入力

ラベル 名前 タイプ 制約
(I1) operand 浮動小数点型またはテンソルごとの量子化テンソル (C1)

出力

名前 タイプ 制約
result 浮動小数点型またはテンソルごとの量子化テンソル (C1)

制約

  • (C1)baseline_type(operand) = baseline_type(result)

// %operand: [-0.8166, -0.2530, 0.2530, 0.8166, 2.0]
%result = "stablehlo.floor"(%operand) : (tensor<5xf32>) -> tensor<5xf32>
// %result: [-1.0, -1.0, 0.0, 0.0, 2.0]

その他の例

収集

セマンティクス

start_indices で指定されたオフセットから operand テンソルからスライスを収集し、result テンソルを生成します。

次の図は、具体的な例を使用して、result の要素が operand の要素にどのようにマッピングされるかを示しています。この図では、いくつかの result インデックスを選択し、それらがどの operand インデックスに対応しているかを詳しく説明しています。

result[result_index] = operand[operand_index] の形式は次のとおりです。

  • batch_dims = [d for d in axes(result) and d not in offset_dims].
  • batch_index = result_index[batch_dims...].
  • start_index は次のように定義されます。
    • start_indices[bi0, ..., :, ..., biN]。ここで、bibatch_index の個々の要素であり、:index_vector_dim < rank(start_indices) の場合に index_vector_dim インデックスに挿入されます。
    • それ以外の場合は [start_indices[batch_index]]
  • axes(operand)d_operand の場合:
    • full_start_index[d_operand] = clamp(start_index[d_start], 0, dim(operand, d_operand) - slice_sizes[d_operand])d_operand = start_index_map[d_start] の場合)。
    • それ以外の場合は full_start_index[d_operand] = 0
  • offset_index = result_index[offset_dims...].
  • full_offset_index = [oi0, ..., 0, ..., oiN]。ここで、oioffset_index の個々の要素であり、0collapsed_slice_dims のインデックスに挿入されます。
  • operand_index = full_start_index + full_offset_index

indices_are_sortedtrue の場合、実装では start_indicesstart_index_map を基準に並べ替えられていると想定できます。それ以外の場合は、動作が未定義になります。より正式には、indices(result) からのすべての i1 < i2 について、full_start_index(i1) <= full_start_index(i2) となります。

入力

ラベル 名前 タイプ 制約
(I1) operand テンソルまたはテンソルごとの量子化テンソル (C1)、(C7)、(C10 ~ C12)、(C14)
(I2) start_indices 整数型のテンソル (C2)、(C3)、(C13)
(I3) offset_dims si64 型の 1 次元テンソル定数 (C1)、(C4 ~ C5)、(C13)
(I4) collapsed_slice_dims si64 型の 1 次元テンソル定数 (C1)、(C6 ~ C8)、(C13)
(I5) start_index_map si64 型の 1 次元テンソル定数 (C3)、(C9)、(C10)
(I6) index_vector_dim si64 型の定数 (C2)、(C3)、(C13)
(I7) slice_sizes si64 型の 1 次元テンソル定数 (C8)、(C11 ~ C13)
(I8) indices_are_sorted i1 型の定数

出力

名前 タイプ 制約
result テンソルまたはテンソルごとの量子化テンソル (C5)、(C13 ~ C14)

制約

  • (C1)rank(operand) = size(offset_dims) + size(collapsed_slice_dims)
  • (C2)0 <= index_vector_dim <= rank(start_indices)
  • (C3)size(start_index_map) = index_vector_dim < rank(start_indices) ? dim(start_indices, index_vector_dim) : 1
  • (C4)is_unique(offset_dims) and is_sorted(offset_dims)
  • (C5)0 <= offset_dims < rank(result)
  • (C6)is_unique(collapsed_slice_dims) and is_sorted(collapsed_slice_dims)
  • (C7)0 <= collapsed_slice_dims < rank(operand)
  • (C8)slice_sizes[collapsed_slice_dims...] <= 1
  • (C9)is_unique(start_index_map)
  • (C10)0 <= start_index_map < rank(operand)
  • (C11)size(slice_sizes) = rank(operand)
  • (C12)0 <= slice_sizes <= shape(operand)
  • (C13)shape(result) = combine(batch_dim_sizes, offset_dim_sizes) ここで:
    • batch_dim_sizes = shape(start_indices)。ただし、index_vector_dim に対応する start_indices のディメンション サイズは含まれません。
    • offset_dim_sizes = shape(slice_sizes)。ただし、collapsed_slice_dims に対応する slice_sizes のディメンション サイズは含まれません。
    • combine は、batch_dims に対応する軸と offset_dim_sizes に対応する軸で offset_dims に対応する軸に batch_dim_sizes を配置します。
  • (C14)element_type(operand) = element_type(result)

// %operand: [
//            [[1, 2], [3, 4], [5, 6], [7, 8]],
//            [[9, 10],[11, 12], [13, 14], [15, 16]],
//            [[17, 18], [19, 20], [21, 22], [23, 24]]
//           ]
// %start_indices: [
//                  [[0, 0], [1, 0], [2, 1]],
//                  [[0, 1], [1, 1], [0, 2]]
//                 ]
%result = "stablehlo.gather"(%operand, %start_indices) {
  dimension_numbers = #stablehlo.gather<
    offset_dims = [2, 3],
    collapsed_slice_dims = [0],
    start_index_map = [1, 0],
    index_vector_dim = 2>,
  slice_sizes = dense<[1, 2, 2]> : tensor<3xi64>,
  indices_are_sorted = false
} : (tensor<3x4x2xi32>, tensor<2x3x2xi64>) -> tensor<2x3x2x2xi32>
// %result: [
//            [
//              [[1, 2], [3, 4]],
//              [[3, 4], [5, 6]],
//              [[13, 14], [15, 16]]
//            ],
//            [
//              [[9, 10], [11, 12]],
//              [[11, 12], [13, 14]],
//              [[17, 18], [19, 20]]
//            ]
//          ]

その他の例

get_dimension_size

セマンティクス

operand の指定された dimension のサイズを生成します。より正式には、result = dim(operand, dimension) です。

入力

ラベル 名前 タイプ 制約
(I1) operand テンソル (C1)
(I2) dimension si64 型の定数 (C1)

出力

名前 タイプ
result si32 型の 0 次元テンソル

制約

  • (C1)0 <= dimension < rank(operand)

// %operand: [[1, 2, 3], [4, 5, 6]]
%result = "stablehlo.get_dimension_size"(%operand) {
  dimension = 1 : i64
} : (tensor<2x3xi64>) -> tensor<i32>
// %result: 3

その他の例

get_tuple_element

セマンティクス

operand タプルの index 位置の要素を抽出し、result を生成します。より正式には、result = operand[index] です。

入力

ラベル 名前 タイプ 制約
(I1) operand tuple (C1)、(C2)
(I2) index si32 型の定数 (C1)、(C2)

出力

名前 タイプ 制約
result サポートされている任意のタイプ (C2)

制約

  • (C1)0 <= index < size(operand)
  • (C2)type(result) = tuple_element_types(operand)[index]

// %operand: ([1.0, 2.0], (3))
%result = "stablehlo.get_tuple_element"(%operand) {
  index = 0 : i32
} : (tuple<tensor<2xf32>, tuple<tensor<i32>>>) -> tensor<2xf32>
// %result: [1.0, 2.0]

その他の例

if

セマンティクス

pred の値に応じて、true_branch または false_branch から関数を 1 つだけ実行して出力を生成します。より正式には、result = pred ? true_branch() : false_branch() です。

入力

ラベル 名前 タイプ 制約
(I1) pred i1 型の 0 次元テンソル
(I2) true_branch 機能 (C1 ~ C3)
(I3) false_branch 機能 (C1)、(C2)

出力

名前 タイプ 制約
results テンソル、量子化テンソル、トークンの可変数 (C3)

制約

  • (C1)input_types(true_branch) = input_types(false_branch) = []
  • (C2)output_types(true_branch) = output_types(false_branch)
  • (C3)type(results...) = output_types(true_branch)

// %result_true_branch: 10
// %result_false_branch: 11
// %pred: true
%result = "stablehlo.if"(%pred) ({
  "stablehlo.return"(%result_true_branch) : (tensor<i32>) -> ()
}, {
  "stablehlo.return"(%result_false_branch) : (tensor<i32>) -> ()
}) : (tensor<i1>) -> tensor<i32>
// %result: 10

その他の例

画像

セマンティクス

operand から虚数部を要素ごとに抽出し、result テンソルを生成します。より正式には、各要素 x に対して imag(x) = is_complex(x) ? imaginary_part(x) : constant(0, element_type(result)) と記述します。

入力

ラベル 名前 タイプ 制約
(I1) operand 浮動小数点型または複素数のテンソル (C1)、(C2)

出力

名前 タイプ 制約
result 浮動小数点型のテンソル (C1)、(C2)

制約

  • (C1)shape(result) = shape(operand)
  • (C2)element_type(result) が次のように定義されます。
    • is_complex(operand) の場合は complex_element_type(element_type(operand))
    • それ以外の場合は element_type(operand)

// %operand: [(1.0, 2.0), (3.0, 4.0)]
%result = "stablehlo.imag"(%operand) : (tensor<2xcomplex<f32>>) -> tensor<2xf32>
// %result: [2.0, 4.0]

その他の例

インフィード

セマンティクス

インフィードからデータを読み取り、results を生成します。

infeed_config のセマンティクスは実装で定義します。

results は、最初に来たペイロード値と最後に来るトークンで構成されます。今後、明確にするために、ペイロードとトークンを 2 つの別々の出力に分割する予定です(#670)。

入力

ラベル 名前 タイプ
(I1) token token
(I2) infeed_config string 型の定数

出力

名前 タイプ 制約
results テンソル、量子化テンソル、トークンの可変数 (C1 ~ C3)

制約

  • (C1)0 < size(results)
  • (C2)is_empty(result[:-1]) または is_tensor(type(results[:-1]))
  • (C3)is_token(type(results[-1]))

// %token: !stablehlo.token
// infeed_queue[0]: [[1, 2], [3, 4]]
// infeed_queue[1]: [[5, 6], [7, 8]]
%results0:2 = "stablehlo.infeed"(%token) {
  infeed_config = ""
} : (!stablehlo.token) -> (tensor<2x2xi64>, !stablehlo.token)
// results0#0: [[1, 2], [3, 4]]
%results1:2 = "stablehlo.infeed"(%token) {
  infeed_config = ""
} : (!stablehlo.token) -> (tensor<2x2xi64>, !stablehlo.token)
// results1#0: [[5, 6], [7, 8]]

その他の例

イオタ

セマンティクス

iota_dimension ディメンションに沿って、0 から順に昇順の値で output テンソルを埋めます。より正式に,

output[result_index] = constant(is_quantized(output) ? quantize(result_index[iota_dimension], element_type(output)) : result_index[iota_dimension], element_type(output)).

入力

ラベル 名前 タイプ 制約
(I1) iota_dimension si64 (C1)

出力

名前 タイプ 制約
output 整数、浮動小数点、複合型、テンソルごとの量子化テンソル (C1)

制約

  • (C1)0 <= iota_dimension < rank(output)

%output = "stablehlo.iota"() {
  iota_dimension = 0 : i64
} : () -> tensor<4x5xi32>
// %output: [
//           [0, 0, 0, 0, 0],
//           [1, 1, 1, 1, 1],
//           [2, 2, 2, 2, 2],
//           [3, 3, 3, 3, 3]
//          ]

%output = "stablehlo.iota"() {
  iota_dimension = 1 : i64
} : () -> tensor<4x5xi32>
// %output: [
//           [0, 1, 2, 3, 4],
//           [0, 1, 2, 3, 4],
//           [0, 1, 2, 3, 4],
//           [0, 1, 2, 3, 4]
//          ]

その他の例

is_finite

セマンティクス

x の値が有限(つまり、+Inf、-Inf、NaN ではない)かどうかを要素ごとにチェックし、y テンソルを生成します。IEEE-754 仕様の isFinite オペレーションを実装します。量子化型の場合、結果は常に true になります。

入力

ラベル 名前 タイプ 制約
(I1) x 浮動小数点型またはテンソルごとの量子化テンソル (C1)

出力

名前 タイプ 制約
y ブール型のテンソル (C1)

制約

  • (C1)shape(x) = shape(y)

// Logical values: -Inf, +Inf, NaN, ...
// %x: [0xFFF0000000000000, 0x7FF0000000000000, 0x7FF8000000000000, -10.0, -0.0, 0.0, 10.0]
%y = "stablehlo.is_finite"(%x) : (tensor<7xf64) -> tensor<7xi1>
// %y: [false, false, false, true, true, true, true]

その他の例

log

セマンティクス

operand テンソルに対して要素ごとの対数演算を実行し、result テンソルを生成します。要素のタイプに応じて、次の操作を行います。

  • 浮動小数点数の場合: IEEE-754 の log
  • 複素数の場合: 複素対数。
  • 量子化型の場合: dequantize_op_quantize(log, operand, type(result))

入力

ラベル 名前 タイプ 制約
(I1) operand 浮動小数点型または複素数のテンソル、またはテンソルごとの量子化テンソル (C1)

出力

名前 タイプ 制約
result 浮動小数点型または複素数のテンソル、またはテンソルごとの量子化テンソル (C1)

制約

  • (C1)baseline_type(operand) = baseline_type(result)

// %operand: [[1.0, 2.0], [3.0, 4.0]]
%result = "stablehlo.log"(%operand) : (tensor<2x2xf64>) -> tensor<2x2xf64>
// %result: [[0.0, 0.69314718055994529], [1.0986122886681098, 1.3862943611198906]]

その他の例

log_plus_one

セマンティクス

operand テンソルに対して要素単位の対数と 1 の演算を実行し、result テンソルを生成します。要素のタイプに応じて、次の操作を行います。

  • 浮動小数点数の場合: IEEE-754 の logp1
  • 複素数の場合: 複素対数 + 1
  • 量子化型の場合: dequantize_op_quantize(log_plus_one, operand, type(result))

入力

ラベル 名前 タイプ 制約
(I1) operand 浮動小数点型または複素数のテンソル、またはテンソルごとの量子化テンソル (C1)

出力

名前 タイプ 制約
result 浮動小数点型または複素数のテンソル、またはテンソルごとの量子化テンソル (C1)

制約

  • (C1)baseline_type(operand) = baseline_type(result)

// %operand: [0.0, -0.999, 7.0, 6.38905621, 15.0]
%result = "stablehlo.log_plus_one"(%operand) : (tensor<5xf64>) -> tensor<5xf64>
// %result: [0.0, -6.90776825, 2.07944155, 2.0, 2.77258873]

その他の例

ロジスティック

セマンティクス

operand テンソルに対して要素ごとのロジスティック演算を実行し、result テンソルを生成します。要素のタイプに応じて、次の操作を行います。

  • 浮動小数点数の場合: IEEE-754 の division(1, addition(1, exp(-x)))
  • 複素数の場合: 複雑なロジスティック。
  • 量子化型の場合: dequantize_op_quantize(logistic, operand, type(result))

入力

ラベル 名前 タイプ 制約
(I1) operand 浮動小数点型または複素数のテンソル、またはテンソルごとの量子化テンソル (C1)

出力

名前 タイプ 制約
result 浮動小数点型または複素数のテンソル、またはテンソルごとの量子化テンソル (C1)

制約

  • (C1)baseline_type(operand) = baseline_type(result)

// %operand: [[0.0, 1.0], [2.0, 3.0]]
%result = "stablehlo.logistic"(%operand) : (tensor<2x2xf64>) -> tensor<2x2xf64>
// %result: [[0.5, 0.73105858], [0.88079708, 0.95257413]]

その他の例

地図

セマンティクス

マップ関数 computationdimensions に沿って inputs に適用し、result テンソルを生成します。

より正式には、result[result_index] = computation(inputs...[result_index]) です。 dimensions は現在使用されておらず、今後削除される可能性があります(#487)。

入力

ラベル 名前 タイプ 制約
(I1) inputs テンソル数またはテンソルごとの量子化テンソルの可変数 (C1 ~ C4)
(I2) dimensions si64 型の 1 次元テンソル定数 (C3)
(I3) computation 機能 (C4)

出力

名前 タイプ 制約
result テンソルまたはテンソルごとの量子化テンソル (C1)、(C4)

制約

  • (C1)shape(inputs...) = shape(result)
  • (C2)0 < size(inputs) = N
  • (C3)dimensions = range(rank(inputs[0]))
  • (C4)computation(tensor<E0>, ..., tensor<EN-1>) -> tensor<E'> の型で、Ei = element_type(inputs[i])E' = element_type(result) です。

// %input0: [[0, 1], [2, 3]]
// %input1: [[4, 5], [6, 7]]
%result = "stablehlo.map"(%input0, %input1) ({
  ^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
    %0 = stablehlo.multiply %arg0, %arg1 : tensor<i64>
    stablehlo.return %0 : tensor<i64>
}) {
  dimensions = dense<[0, 1]> : tensor<2xi64>
} : (tensor<2x2xi64>, tensor<2x2xi64>) -> tensor<2x2xi64>
// %result: [[0, 5], [12, 21]]

その他の例

最大

セマンティクス

テンソル lhsrhs に対して要素単位の最大演算を実行し、result テンソルを生成します。要素のタイプに応じて、次の操作を行います。

  • ブール値の場合: 論理 OR。
  • 整数の場合: 整数の最大値。
  • 浮動小数点数の場合: IEEE-754 の maximum
  • 複素数の場合: (real, imaginary) ペアの辞書順の最大値。複素数に順序付けを適用すると、意外なセマンティクスも必要になります。そのため、今後、この演算では複素数のサポートを削除する予定です(#560)。
  • 量子化型の場合:
    • dequantize_op_quantize(maximum, lhs, rhs, type(result)).

入力

ラベル 名前 タイプ 制約
(I1) lhs テンソルまたはテンソルごとの量子化テンソル (C1)
(I2) rhs テンソルまたはテンソルごとの量子化テンソル (C1)

出力

名前 タイプ 制約
result テンソルまたはテンソルごとの量子化テンソル (C1)

制約

  • (C1)baseline_type(lhs) = baseline_type(rhs) = baseline_type(result)

// %lhs: [[1, 2], [7, 8]]
// %rhs: [[5, 6], [3, 4]]
%result = "stablehlo.maximum"(%lhs, %rhs) : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
// %result: [[5, 6], [7, 8]]

その他の例

最小

セマンティクス

テンソル lhsrhs に対して要素ごとの最小演算を実行し、result テンソルを生成します。要素のタイプに応じて、次の操作を行います。

  • ブール値の場合: 論理 AND。
  • 整数の場合: 整数の最小値。
  • 浮動小数点数の場合: IEEE-754 の minimum
  • 複素数の場合: (real, imaginary) ペアの辞書順の最小値。複素数に順序付けを適用すると、意外なセマンティクスも必要になります。そのため、今後、この演算では複素数のサポートを削除する予定です(#560)。
  • 量子化型の場合:
    • dequantize_op_quantize(minimum, lhs, rhs, type(result)).

入力

ラベル 名前 タイプ 制約
(I1) lhs テンソルまたはテンソルごとの量子化テンソル (C1)
(I2) rhs テンソルまたはテンソルごとの量子化テンソル (C1)

出力

名前 タイプ 制約
result テンソルまたはテンソルごとの量子化テンソル (C1)

制約

  • (C1)baseline_type(lhs) = baseline_type(rhs) = baseline_type(result)

// %lhs: [[1, 2], [7, 8]]
// %rhs: [[5, 6], [3, 4]]
%result = "stablehlo.minimum"(%lhs, %rhs) : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
// %result: [[1, 2], [3, 4]]

その他の例

掛ける

セマンティクス

2 つのテンソル lhsrhs の要素単位で積を実行し、result テンソルを生成します。要素のタイプに応じて、次の操作を行います。

  • ブール値の場合: 論理 AND。
  • 整数の場合: 整数の乗算。
  • 浮動小数点数の場合: IEEE-754 の multiplication
  • 複素数の場合: 複素乗算。
  • 量子化型の場合:
    • dequantize_op_quantize(multiply, lhs, rhs, type(result)).

入力

ラベル 名前 タイプ 制約
(I1) lhs テンソルまたはテンソルごとの量子化テンソル (C1)
(I2) rhs テンソルまたはテンソルごとの量子化テンソル (C1)

出力

名前 タイプ 制約
result テンソルまたはテンソルごとの量子化テンソル (C1)

制約

  • (C1)baseline_type(operand) = baseline_type(result)

// %lhs: [[1, 2], [3, 4]]
// %rhs: [[5, 6], [7, 8]]
%result = "stablehlo.multiply"(%lhs, %rhs) : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
// %result: [[5, 12], [21, 32]]

その他の例

negate

セマンティクス

operand テンソルの要素単位で否定を行い、result テンソルを生成します。要素のタイプに応じて、次の操作を行います。

  • 符号付き整数の場合: 整数の否定。
  • 符号なし整数の場合: 符号付き整数へのビットキャスト、整数の否定、符号なし整数へのビットキャスト
  • 浮動小数点数の場合: IEEE-754 の negate
  • 複素数の場合: 複素否定。
  • 量子化型の場合: dequantize_op_quantize(negate, operand, type(result))

入力

ラベル 名前 タイプ 制約
(I1) operand 整数、浮動小数点、複合型のテンソル、またはテンソルごとの量子化テンソル (C1)

出力

名前 タイプ 制約
result 整数、浮動小数点、複合型のテンソル、またはテンソルごとの量子化テンソル (C1)

制約

  • (C1)baseline_type(operand) = baseline_type(result)

// Negation operation with integer Tensors
// %operand: [0, -2]
%result = "stablehlo.negate"(%operand) : (tensor<2xi32>) -> tensor<2xi32>
// %result: [0, 2]

// Negation operation with with complex tensors
// %operand: (2.5, 0.0)
%result = "stablehlo.negate"(%operand) : (tensor<1xcomplex<f32>>) -> tensor<1xcomplex<f32>>
// %result: [-2.5, -0.0]

その他の例

いない

セマンティクス

テンソル operand の要素ごとの NOT を実行し、result テンソルを生成します。 要素のタイプに応じて、次の操作を行います。

  • ブール値の場合: 論理 NOT。
  • 整数の場合: ビット演算 NOT。

引数

名前 タイプ 制約
operand ブール値または整数型のテンソル (C1)

出力

名前 タイプ 制約
result ブール値または整数型のテンソル (C1)

制約

  • (C1)type(operand) = type(result)

// Bitwise operation with with integer tensors
// %operand: [[1, 2], [3, 4]]
%result = "stablehlo.not"(%operand) : (tensor<2x2xi32>) -> tensor<2x2xi32>
// %result: [[-2, -3], [-4, -5]]

// Bitwise operation with with boolean tensors
// %operand: [true, false]
%result = "stablehlo.not"(%operand) : (tensor<2xi1>) -> tensor<2xi1>
// %result: [false, true]

optimization_barrier

セマンティクス

operand を生成するオペレーションが、result に依存するオペレーションよりも前に実行されるようにし、コンパイラ変換がバリアを越えてオペレーションを移動しないようにします。それ以外は、オペレーションは ID(result = operand)です。

引数

名前 タイプ 制約
operand テンソルの可変数、テンソルごとの量子化テンソル、トークン (C1)

出力

名前 タイプ 制約
result テンソルの可変数、テンソルごとの量子化テンソル、トークン (C1)

制約

  • (C1)type(operand...) = type(result...)

// %operand0: 0.0
// %operand1: 1.0
%result0, %result1 = "stablehlo.optimization_barrier"(%operand0, %operand1) : (tensor<f32>, tensor<f32>) -> (tensor<f32>, tensor<f32>)
// %result0: 0.0
// %result1: 1.0

その他の例

or

セマンティクス

2 つのテンソル lhsrhs の要素単位 OR を演算し、result テンソルを生成します。要素のタイプに応じて、次の操作を行います。

  • ブール値の場合: 論理 OR。
  • 整数の場合: ビット演算 OR。

入力

ラベル 名前 タイプ 制約
(I1) lhs 整数型またはブール型のテンソル (C1)
(I2) rhs 整数型またはブール型のテンソル (C1)

出力

名前 タイプ 制約
result 整数型またはブール型のテンソル (C1)

制約

  • (C1)type(lhs) = type(rhs) = type(result)

// Bitwise operation with with integer tensors
// %lhs: [[1, 2], [3, 4]]
// %rhs: [[5, 6], [7, 8]]
%result = "stablehlo.or"(%lhs, %rhs) : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
// %result: [[5, 6], [7, 12]]

// Logical operation with with boolean tensors
// %lhs: [[false, false], [true, true]]
// %rhs: [[false, true], [false, true]]
%result = "stablehlo.or"(%lhs, %rhs) : (tensor<2x2xi1>, tensor<2x2xi1>) -> tensor<2x2xi1>
// %result: [[false, true], [true, true]]

アウトフィード

セマンティクス

inputs をアウトフィードに書き込み、result トークンを生成します。

outfeed_config のセマンティクスは実装で定義します。

入力

ラベル 名前 タイプ
(I1) inputs テンソルまたは量子化テンソルの可変数
(I2) token token
(I3) outfeed_config string 型の定数

出力

名前 タイプ
result token

%result = "stablehlo.outfeed"(%inputs0, %token) {
  outfeed_config = ""
} : (tensor<2x2x2xi64>, !stablehlo.token) -> !stablehlo.token

その他の例

パッド

セマンティクス

テンソルの周囲および指定された padding_value でテンソルの要素間にパディングを追加することで、operand を拡張します。

edge_padding_lowedge_padding_high は、各ディメンションのローエンド(インデックス 0 の次)とハイエンド(最も高いインデックスの次)に追加するパディングの量をそれぞれ指定します。パディングの量は負の数にすることができます。負のパディングの絶対値は、指定したディメンションから削除する要素の数を示します。

interior_padding は、各ディメンションの任意の 2 つの要素間に追加されるパディングの量を指定します。負の値は使用できません。内部パディングはエッジ パディングの前に行われるため、負エッジ パディングにより内部パディング オペランドから要素が削除されます。

より正式には、result[result_index] は次のように定義されます。

  • result_index = edge_padding_low + operand_index * (interior_padding + 1) の場合は operand[operand_index]
  • それ以外の場合は padding_value

入力

ラベル 名前 タイプ 制約
(I1) operand テンソルまたはテンソルごとの量子化テンソル (C1)、(C2)、(C4)
(I2) padding_value 0 次元テンソルまたはテンソルごとの量子化テンソル (C1)
(I3) edge_padding_low si64 型の 1 次元テンソル定数 (C1)、(C4)
(I4) edge_padding_high si64 型の 1 次元テンソル定数 (C1)、(C4)
(I5) interior_padding si64 型の 1 次元テンソル定数 (C2 ~ C4)

出力

名前 タイプ 制約
result テンソルまたはテンソルごとの量子化テンソル (C3 ~ C6)

制約

  • (C1)element_type(operand) = element_type(padding_value) = element_type(result)
  • (C2)size(edge_padding_low) = size(edge_padding_high) = size(interior_padding) = rank(operand)
  • (C3)0 <= interior_padding
  • (C4)shape(result) = shape(operand) + edge_padding_low + max(shape(operand) - 1, 0) * interior_padding + edge_padding_high

// %operand: [
//            [1, 2, 3],
//            [4, 5, 6]
//           ]
// %padding_value: 0
%result = "stablehlo.pad"(%operand, %padding_value) {
  edge_padding_low = dense<[0, 1]> : tensor<2xi64>,
  edge_padding_high = dense<[2, 1]> : tensor<2xi64>,
  interior_padding = dense<[1, 2]> : tensor<2xi64>
} : (tensor<2x3xi32>, tensor<i32>) -> tensor<5x9xi32>
// %result: [
//           [0, 1, 0, 0, 2, 0, 0, 3, 0],
//           [0, 0, 0, 0, 0, 0, 0, 0, 0],
//           [0, 4, 0, 0, 5, 0, 0, 6, 0],
//           [0, 0, 0, 0, 0, 0, 0, 0, 0],
//           [0, 0, 0, 0, 0, 0, 0, 0, 0]
//          ]

その他の例

partition_id

セマンティクス

現在のプロセスの partition_id を生成します。

出力

名前 タイプ
result ui32 型の 0 次元テンソル

%result = "stablehlo.partition_id"() : () -> tensor<ui32>

その他の例

ポップコンテンツ

セマンティクス

operand テンソルに設定されたビット数の要素を要素ごとにカウントし、result テンソルを生成します。

入力

ラベル 名前 タイプ 制約
(I1) operand 整数型のテンソル (C1)

出力

名前 タイプ 制約
result 整数型のテンソル (C1)

制約

  • (C1)type(operand) = type(result)

// %operand: [0, 1, 2, 127]
%result = "stablehlo.popcnt"(%operand) : (tensor<4xi64>) -> tensor<4xi64>
// %result: [0, 1, 1, 7]

その他の例

電力

セマンティクス

rhs テンソルで lhs テンソルを要素ごとにべき乗し、result テンソルを生成します。要素のタイプに応じて、次の操作を行います。

  • 整数の場合: 整数のべき乗。
  • 浮動小数点数の場合: IEEE-754 の pow
  • 複素数の場合: 複素べき乗。
  • 量子化型の場合: dequantize_op_quantize(power, lhs, rhs, type(result))

入力

ラベル 名前 タイプ 制約
(I1) lhs 整数、浮動小数点、複合型のテンソル、またはテンソルごとの量子化テンソル (C1)
(I2) rhs 整数、浮動小数点、複合型のテンソル、またはテンソルごとの量子化テンソル (C1)

出力

名前 タイプ 制約
result 整数、浮動小数点、複合型のテンソル、またはテンソルごとの量子化テンソル (C1)

制約

  • (C1)baseline_type(operand) = baseline_type(result)

// %lhs: [-2.0, -0.0, -36.0, 5.0, 3.0, 10000.0]
// %rhs: [2.0, 2.0, 1.1, 2.0, -1.0, 10.0]
%result = "stablehlo.power"(%lhs, %rhs) : (tensor<6xf64>, tensor<6xf64>) -> tensor<6xf64>
// %result: [4.0, 0.0, -nan, 25.0, 0.333333343, inf]

その他の例

real

セマンティクス

operand から実部を要素ごとに抽出し、result テンソルを生成します。より正式には、各要素 x に対して real(x) = is_complex(x) ? real_part(x) : x と記述します。

入力

ラベル 名前 タイプ 制約
(I1) operand 浮動小数点型または複素数のテンソル (C1)、(C2)

出力

名前 タイプ 制約
result 浮動小数点型のテンソル (C1)、(C2)

制約

  • (C1)shape(result) = shape(operand)
  • (C2)element_type(result) が次のように定義されます。
    • is_complex(operand) の場合は complex_element_type(element_type(operand))
    • それ以外の場合は element_type(operand)

// %operand: [(1.0, 2.0), (3.0, 4.0)]
%result = "stablehlo.real"(%operand) : (tensor<2xcomplex<f32>>) -> tensor<2xf32>
// %result: [1.0, 3.0]

その他の例

受信

セマンティクス

channel_id でチャネルからデータを受信し、results を生成します。

is_host_transfertrue の場合、オペレーションはホストからデータを転送する。それ以外の場合は、別のデバイスからデータが転送されます。これは、実装が定義されることを意味します。このフラグは channel_type で指定された情報が重複しているため、将来的にはそのうち 1 つのみを保持する予定です(#666)。

results は、最初に来たペイロード値と最後に来るトークンで構成されます。今後、明確にするために、ペイロードとトークンを 2 つの別々の出力に分割する予定です(#670)。

入力

ラベル 名前 タイプ 制約
(I1) token token (C4)
(I2) channel_id si64 型の定数
(I3) channel_type DEVICE_TO_DEVICEHOST_TO_DEVICE の列挙型 (C1)
(I4) is_host_transfer i1 型の定数 (C1)

出力

名前 タイプ 制約
results テンソル、量子化テンソル、トークンの可変数 (C2 ~ C4)

制約

  • (C1)channel_type が次のように定義されます。
    • is_host_transfer = true の場合は HOST_TO_DEVICE
    • それ以外の場合は DEVICE_TO_DEVICE
  • (C2)0 < size(results)
  • (C3)is_empty(result[:-1]) または is_tensor(type(results[:-1]))
  • (C4)is_token(type(results[-1]))

%results0, %results1 = "stablehlo.recv"(%token) {
  channel_handle = #stablehlo.channel_handle<handle = 1, type = 3>,
  is_host_transfer = true
} : (!stablehlo.token) -> (tensor<2x2xi64>, !stablehlo.token)

その他の例

reduce

セマンティクス

dimensions に沿ってリダクション関数 bodyinputsinit_values に適用し、results テンソルを生成します。

削減の順序は実装で定義します。つまり、bodyinit_values は、演算がすべての実装のすべての入力に対して同じ結果を生成することを保証するためにモノイドを形成する必要があります。ただし、この条件は、一般的な削減の多くでは当てはまりません。たとえば、body の浮動小数点の加算と init_values のゼロは、実際にはモノイドを形成しません。これは、浮動小数点数の加算は結合的ではないためです。

results...[j0, ..., jR-1] = reduce(input_slices_converted) の形式は次のとおりです。

  • input_slices = inputs...[j0, ..., :, ..., jR-1]: ここで、:dimensions に挿入されます。
  • input_slices_converted = to_destination_type(input_slices..., type(func_inputs(body)[:len(func_inputs(body))//2])...).
  • init_values_converted = to_destination_type(init_values..., type(func_inputs(body)[len(func_inputs(body))//2:])...).
  • バイナリツリー schedule の場合は reduce(input_slices_converted) = exec(schedule)。ここで、:
    • exec(node) = body(exec(node.left), exec(node.right)).
    • exec(leaf) = leaf.value.
  • schedule は実装定義の完全なバイナリツリーで、インオーダー トラバーサルは次のものから構成されます。
    • index_space(input_slices_converted) 内のすべての index に対する input_slices_converted...[index] 値。辞書順の index の昇順で行われます。
    • 実装で定義された位置に、実装が定義した量の init_values_converted が混在します。

入力

ラベル 名前 タイプ 制約
(I1) inputs テンソル数またはテンソルごとの量子化テンソルの可変数 (C1 ~ C4)、(C6)、(C7)
(I2) init_values 0 次元テンソルまたはテンソルごとの量子化テンソルの可変数 (C2)、(C3)
(I3) dimensions si64 型の 1 次元テンソル定数 (C4)、(C5)、(C7)
(I4) body 機能 (C6)

出力

名前 タイプ 制約
results テンソル数またはテンソルごとの量子化テンソルの可変数 (C3)、(C7)、(C8)

制約

  • (C1)same(shape(inputs...))
  • (C2)element_type(inputs...) = element_type(init_values...)
  • (C3)0 < size(inputs) = size(init_values) = size(results) = N
  • (C4)0 <= dimensions < rank(inputs[0])
  • (C5)is_unique(dimensions)
  • (C6)body の型は (tensor<E0>, ..., tensor<EN-1>, tensor<E0>, ...,tensor<EN-1>) -> (tensor<E0>, ..., tensor<EN-1>)is_promotable(element_type(inputs[i]), Ei) です。
  • (C7)shape(results...) = shape(inputs...)。ただし、dimensions に対応する inputs... のディメンション サイズは含まれません。
  • (C8)[0,N) のすべての i に対して element_type(results[i]) = Ei

// %input = [[0, 1, 2, 3, 4, 5]]
// %init_value = 0
%result = "stablehlo.reduce"(%input, %init_value) ({
  ^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
    %0 = "stablehlo.add"(%arg0, %arg1) : (tensor<i64>, tensor<i64>) -> tensor<i64>
    "stablehlo.return"(%0) : (tensor<i64>) -> ()
}) {
  dimensions = dense<1> : tensor<1xi64>
} : (tensor<1x6xi64>, tensor<i64>) -> tensor<1xi64>
// %result = [15]

その他の例

reduce_precision

セマンティクス

operand の要素単位で exponent_bitsmantissa_bits を使用する別の浮動小数点型に変換し、元の浮動小数点型に戻して output テンソルを生成します。

より正式に:

  • 元の値の仮数ビットが更新され、roundToIntegralTiesToEven セマンティクスを使用して mantissa_bits で表現できる最も近い値に元の値が丸められます。
  • mantissa_bits が元の値の仮数のビット数より小さい場合、仮数のビットは mantissa_bits に切り捨てられます。
  • 次に、中間結果の指数ビットが exponent_bits で指定された範囲に収まらない場合、中間結果は元の符号を使用して無限にオーバーフローするか、元の符号を使用して 0 にアンダーフローされます。
  • 量子化型の場合、dequantize_op_quantize( lambda operand: reduce_precision(operand, exponent_bits, mantissa_bits), operand, type(result)) を実行します。

入力

ラベル 名前 タイプ 制約
(I1) operand 浮動小数点型またはテンソルごとの量子化テンソル (C1)
(I2) exponent_bits si32 型の定数 (C2)
(I3) mantissa_bits si32 型の定数 (C3)

出力

名前 タイプ 制約
output 浮動小数点型またはテンソルごとの量子化テンソル (C1)

制約

  • (C1)baseline_type(operand) = baseline_type(output)
  • (C2)1 <= exponent_bits
  • (C3)0 <= mantissa_bits

// Logical values: +Inf, NaN, +Denormal, 0.0, 65519.0, 65520.0
// %operand: [0x7FF0000000000000, 0x7FFFFFFFFFFFFFFF, 0x0000000000000001, 0.0, 65519.0, 65520.0]
%output = "stablehlo.reduce_precision"(%operand) {
  exponent_bits = 5 : i32,
  mantissa_bits = 10 : i32
} : (tensor<6xf64>) -> tensor<6xf64>
// Logical values: +Inf, NaN, 0.0, 0.0, 65504.0, +Inf
// %output: [0x7FF0000000000000, 0x7FFFFFFFFFFFFFFF, 0.0, 0.0, 65504.0, 0x7FF0000000000000]

その他の例

reduce_scatter

セマンティクス

StableHLO プロセス グリッド内の各プロセス グループ内で、computations を使用して各プロセスの operand テンソルの値に対してリダクションを実行し、リダクション結果を scatter_dimension に沿って一部に分割し、分割したパーツをプロセス間で分散させて result を生成します。

このオペレーションでは、StableHLO プロセス グリッドが次のように定義される process_groups に分割されます。

  • cross_replica(replica_groups)channel_id <= 0 and use_global_device_ids = false の場合)。
  • cross_replica_and_partition(replica_groups)channel_id > 0 and use_global_device_ids = false の場合)。
  • flattened_ids(replica_groups)channel_id > 0 and use_global_device_ids = true の場合)。

その後、各 process_group 内で次のことを行います。

  • reduced_value = all_reduce(operand, replica_groups, channel_id, use_global_device_ids, computation).
  • parts@sender = split(reduced_value@sender, dim(process_groups, 1), scatter_dimension).
  • process_group 内のすべての sender に対する result@receiver = parts@sender[receiver_index]receiver_index = process_group.index(receiver))。

入力

ラベル 名前 タイプ 制約
(I1) operand テンソルまたはテンソルごとの量子化テンソル (C1)、(C2)、(C7)、(C8)
(I2) scatter_dimension si64 型の定数 (C1)、(C2)、(C8)
(I3) replica_groups si64 型の 2 次元テンソル定数 (C3 ~ C5)
(I4) channel_id si64 型の定数 (C6)
(I5) use_global_device_ids i1 型の定数 (C6)
(I6) computation 機能 (C7)

出力

名前 タイプ 制約
result テンソルまたはテンソルごとの量子化テンソル (C8 ~ C9)

制約

  • (C1)dim(operand, scatter_dimension) % dim(process_groups, 1) = 0
  • (C2)0 <= scatter_dimension < rank(operand)
  • (C3)is_unique(replica_groups)
  • (C4)size(replica_groups) は次のように定義されます。
    • cross_replica を使用する場合は num_replicas
    • cross_replica_and_partition を使用する場合は num_replicas
    • flattened_ids を使用する場合は num_processes
  • (C5)0 <= replica_groups < size(replica_groups)
  • (C6)use_global_device_ids = true の場合、channel_id > 0
  • (C7)computation の型は (tensor<E>, tensor<E>) -> (tensor<E>) で、is_promotable(element_type(operand), E)
  • (C8)shape(result) = shape(operand) 例外:
    • dim(result, scatter_dimension) = dim(operand, scatter_dimension) / dim(process_groups, 1).
  • (C9)element_type(result) = E

// num_replicas: 2
// num_partitions: 1
// %operand@(0, 0): [[1, 2, 3, 4],
//                   [5, 6, 7, 8]]
// %operand@(1, 0): [[9, 10, 11, 12],
//                   [13, 14, 15, 16]]
%result = "stablehlo.reduce_scatter"(%operand) ({
  ^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
  %0 = "stablehlo.add"(%arg0, %arg1) : (tensor<i64>, tensor<i64>) -> tensor<i64>
  "stablehlo.return"(%0) : (tensor<i64>) -> ()
}) {
  scatter_dimension = 1 : i64,
  replica_groups = dense<[[0, 1]]> : tensor<1x2xi64>,
  channel_handle = #stablehlo.channel_handle<handle = 0, type = 0>
} : (tensor<2x4xi64>) -> tensor<2x2xi64>
//
// %result@(0, 0): [[10, 12],
//                  [18, 20]]
// %result@(1, 0): [[14, 16],
//                  [22, 24]]

その他の例

reduce_window

セマンティクス

inputsinit_values のウィンドウにリダクション関数 body を適用し、results を生成します。

次の図は、具体的な例を使用して results... の要素が inputs... から計算される方法を示しています。

より正式には、results...[result_index] = reduce(windows, init_values, axes(inputs...), body)reduce を参照)で以下を行います。

  • padded_inputs = pad(inputs..., init_values..., padding[:, 0], padding[:, 1], base_dilations - 1).
  • window_start = result_index * window_strides.
  • window_end = window_start + (window_dimensions - 1) * window_dilations + 1.
  • windows = slice(padded_inputs..., window_start, window_end, window_dilations).

入力

ラベル 名前 タイプ 制約
(I1) inputs テンソル数またはテンソルごとの量子化テンソルの可変数 (C1 ~ C4)、(C6)、(C8)、(C10)、(C12)、(C13)、(C15)
(I2) init_values 0 次元テンソルまたはテンソルごとの量子化テンソルの可変数 (C1)、(C13)
(I3) window_dimensions si64 型の 1 次元テンソル定数 (C4)、(C5)、(C15)
(I4) window_strides si64 型の 1 次元テンソル定数 (C6)、(C7)、(C15)
(I5) base_dilations si64 型の 1 次元テンソル定数 (C8)、(C9)、(C15)
(I6) window_dilations si64 型の 1 次元テンソル定数 (C10)、(C11)、(C15)
(I7) padding si64 型の 2 次元テンソル定数 (C12)、(C15)
(I8) body 機能 (C13)

出力

名前 タイプ 制約
results テンソル数またはテンソルごとの量子化テンソルの可変数 (C1)、(C14 ~ C16)

制約

  • (C1)0 < size(inputs) = size(init_values) = size(results) = N
  • (C2)same(shape(inputs...))
  • (C3)element_type(inputs...) = element_type(init_values...)
  • (C4)size(window_dimensions) = rank(inputs[0])
  • (C5)0 < window_dimensions
  • (C6)size(window_strides) = rank(inputs[0])
  • (C7)0 < window_strides
  • (C8)size(base_dilations) = rank(inputs[0])
  • (C9)0 < base_dilations
  • (C10)size(window_dilations) = rank(inputs[0])
  • (C11)0 < window_dilations
  • (C12)shape(padding) = [rank(inputs[0]), 2]
  • (C13)body の型は (tensor<E0>, ..., tensor<EN-1>, tensor<E0>, ...,tensor<EN-1>) -> (tensor<E0>, ..., tensor<EN-1>)is_promotable(element_type(inputs[i]), Ei) です。
  • (C14)same(shape(results...))
  • (C15)shape(results[0]) = num_windows ここで:
    • dilated_input_shape = shape(inputs[0]) = 0 ? 0 : (shape(inputs[0]) - 1) * base_dilations + 1.
    • padded_input_shape = padding[:, 0] + dilated_input_shape + padding[:, 1].
    • dilated_window_shape = (window_dimensions - 1) * window_dilations + 1.
    • is_empty_window = padded_input_shape = 0 || dilated_window_shape > padded_input_shape.
    • num_windows = is_empty_window ? 0 : floor((padded_input_shape - dilated_window_shape) / window_strides) + 1.
  • (C16)[0,N) のすべての i に対する element_type(results[i]) = Ei

// %input = [[1, 2], [3, 4], [5, 6]]
// %init_value = 0
%result = "stablehlo.reduce_window"(%input, %init_value) ({
  ^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
    %0 = "stablehlo.add"(%arg0, %arg1) : (tensor<i64>, tensor<i64>) -> tensor<i64>
    "stablehlo.return"(%0) : (tensor<i64>) -> ()
}) {
  window_dimensions = dense<[2, 1]> : tensor<2xi64>,
  window_strides = dense<[4, 1]> : tensor<2xi64>,
  base_dilations = dense<[2, 1]> : tensor<2xi64>,
  window_dilations = dense<[3, 1]> : tensor<2xi64>,
  padding = dense<[[2, 1], [0, 0]]> : tensor<2x2xi64>
} : (tensor<3x2xi64>, tensor<i64>) -> tensor<2x2xi64>
// %result = [[0, 0], [3, 4]]

その他の例

残り

セマンティクス

被除数 lhs と除数 rhs テンソルの要素単位の余りを実行し、result テンソルを生成します。

より正式には、被除数から結果の符号が取得され、結果の絶対値は常に除数の絶対値よりも小さくなります。余りは lhs - d * rhs として計算されます。d は次のように計算されます。

  • 整数の場合: stablehlo.divide(lhs, rhs)
  • 浮動小数点数の場合: IEEE-754 の division(lhs, rhs)(丸め属性 roundTowardZero)。
  • 複素数: 未定(#997)。
  • 量子化型の場合:
    • dequantize_op_quantize(remainder, lhs, rhs, type(result)).

浮動小数点要素タイプの場合、この演算は IEEE-754 仕様の remainder 演算とは対照的です。ここで dlhs/rhs の正確な値に最も近い整数値で、偶数に関連しています。

入力

ラベル 名前 タイプ 制約
(I1) lhs 整数、浮動小数点、複合型、テンソルごとの量子化テンソル (C1)
(I2) rhs 整数、浮動小数点、複合型、テンソルごとの量子化テンソル (C1)

出力

名前 タイプ 制約
result 整数、浮動小数点、複合型、テンソルごとの量子化テンソル (C1)

制約

  • (C1)baseline_type(operand) = baseline_type(result)

// %lhs: [17, -17, 17, -17]
// %rhs: [3, 3, -3, -3]
%result = "stablehlo.remainder"(%lhs, %rhs) : (tensor<4xi64>, tensor<4xi64>) -> tensor<4xi64>
// %result: [2, -2, 2, -2]

その他の例

replica_id

セマンティクス

現在のプロセスの replica_id を生成します。

出力

名前 タイプ
result ui32 型の 0 次元テンソル

%result = "stablehlo.replica_id"() : () -> tensor<ui32>

その他の例

reshape

セマンティクス

operand テンソルを result テンソルに変更します。概念上は、同じ正規表現を維持しつつ、場合によっては形を変更する(tensor<2x3xf32> から tensor<3x2xf32>tensor<6xf32> など)ことを意味します。

より正式には result[result_index] = operand[operand_index] です。ここで、result_indexoperand_index は、index_space(result)index_space(operand) の辞書順での位置が同じになります。

入力

ラベル 名前 タイプ 制約
(I1) operand テンソルまたは量子化テンソル (C1 ~ C3)

出力

名前 タイプ 制約
result テンソルまたは量子化テンソル (C1 ~ C3)

制約

  • (C1)element_type(result) は次のように表されます。
    • element_type(operand)!is_per_axis_quantized(operand) の場合)。
    • element_type(operand)。ただし、quantization_dimension(operand)quantization_dimension(result) は異なる場合があります。
  • (C2)size(operand) = size(result)
  • (C3)is_per_axis_quantized(operand) の場合:
    • reduce(dims(operand, [0, 1, ..., quantization_dimension(operand) - 1]), init_values=1, dimensions=[0], body=lambda x, y: x * y) = reduce(dims(result, [0, 1, ..., quantization_dimension(result) - 1]), init_values=1, dimensions=[0], body=lambda x, y: x * y).
    • dim(operand, quantization_dimension(operand)) = dim(result, quantization_dimension(result)).
    • reduce(dims(operand, [quantization_dimension(operand) + 1, ..., rank(operand) - 1]), init_values=1, dimensions=[0], body=lambda x, y: x * y) = reduce(dims(result, [quantization_dimension(result) + 1, ..., rank(result) - 1]), init_values=1, dimensions=[0], body=lambda x, y: x * y).

// %operand: [[1, 2, 3], [4, 5, 6]]]
%result = "stablehlo.reshape"(%operand) : (tensor<2x3xi32>) -> tensor<3x2xi32>
// %result: [[1, 2], [3, 4], [5, 6]]

その他の例

reverse

セマンティクス

指定された dimensions に沿って operand 内の要素の順序を逆にして、result テンソルを生成します。より正式には result[result_index] = operand[operand_index] です。ここで、

  • dimensionsd の場合、operand_index[d] = dim(result, d) - result_index[d] - 1
  • それ以外の場合は operand_index[d] = result_index[d]

入力

ラベル 名前 タイプ 制約
(I1) operand テンソルまたはテンソルごとの量子化テンソル (C1)、(C3)
(I2) dimensions si64 型の 1 次元テンソル定数 (C2)、(C3)

出力

名前 タイプ 制約
result テンソルまたはテンソルごとの量子化テンソル (C1)、(C3)

制約

  • (C1)type(operand) = type(result)
  • (C2)is_unique(dimensions)
  • (C3)0 <= dimensions < rank(result)

// %operand = [[1, 2], [3, 4], [5, 6]]
%result = "stablehlo.reverse"(%operand) {
  dimensions = dense<1> : tensor<1xi64>
} : (tensor<3x2xi32>) -> tensor<3x2xi32>
// %result: [[2, 1], [4, 3], [6, 5]]

その他の例

RNG

セマンティクス

rng_distribution アルゴリズムを使用して乱数を生成し、指定された形状 shaperesult テンソルを生成します。

rng_distribution = UNIFORM の場合、区間 [a, b) の一様分布に従って乱数が生成されます。a >= b の場合、動作は未定義です。

rng_distribution = NORMAL の場合、乱数は平均 = a、標準偏差 = b の正規分布に従って生成されます。b < 0 の場合、動作は未定義です。

乱数が生成される正確な方法は、実装で定義されます。たとえば、確定的である場合もあれば、そうでない場合もあります。また、隠れ状態を使用する場合とされない場合もあります。

多くの関係者との話し合いの結果、この op は実質的に非推奨となったため、今後削除を検討する予定です(#597)。

入力

ラベル 名前 タイプ 制約
(I1) a 整数、ブール値、浮動小数点型の 0 次元テンソル (C1)、(C2)
(I2) b 整数、ブール値、浮動小数点型の 0 次元テンソル (C1)、(C2)
(I3) shape si64 型の 1 次元テンソル定数 (C3)
(I4) rng_distribution UNIFORMNORMAL の列挙型 (C2)

出力

名前 タイプ 制約
result 整数、ブール値、浮動小数点型のテンソル (C1 ~ C3)

制約

  • (C1)element_type(a) = element_type(b) = element_type(result)
  • (C2)rng_distribution = NORMAL の場合、is_float(a)
  • (C3)shape(result) = shape

// %a = 0
// %b = 2
// %shape = [3, 3]
%result = "stablehlo.rng"(%a, %b, %shape) {
  rng_distribution = #stablehlo<rng_distribution UNIFORM>
} : (tensor<i32>, tensor<i32>, tensor<2xi64>) -> tensor<3x3xi32>
// %result: [
//           [1, 0, 1],
//           [1, 1, 1],
//           [0, 0, 0]
//          ]

rng_bit_generator

セマンティクス

初期状態 initial_state を前提として、擬似乱数ジェネレータ アルゴリズム rng_algorithm を使用して、均一なランダムビットが入力された output と更新された出力状態 output_state を返します。出力は initial_state の決定論的関数であることが保証されますが、実装間の確定的関数は保証されません。

rng_algorithm は、次のいずれかです。

  • DEFAULT: 実装で定義されているアルゴリズム。
  • THREE_FRY: Threefry アルゴリズムの実装定義のバリアント。*
  • PHILOX: 実装で定義されている Philox アルゴリズムのバリアント。*

* 参照: Salmon et al. SC 2011. 並行乱数: 1、2、3 と同じくらい簡単です。

入力

ラベル 名前 タイプ 制約
(I1) rng_algorithm DEFAULTTHREE_FRYPHILOX の列挙型 (C2)
(I2) initial_state ui64 型の 1 次元テンソル (C1)、(C2)

出力

名前 タイプ 制約
output_state ui64 型の 1 次元テンソル (C1)
output 整数型または浮動小数点型のテンソル

制約

  • (C1)type(initial_state) = type(output_state)
  • (C2)size(initial_state) が次のように定義されます。
    • rng_algorithm = DEFAULT の場合、実装で定義されます。
    • rng_algorithm = THREE_FRY の場合は 2
    • 2 または 3rng_algorithm = PHILOX の場合)。

// %initial_state: [1, 2]
%output_state, %output = "stablehlo.rng_bit_generator"(%initial_state) {
  rng_algorithm = #stablehlo<rng_algorithm THREE_FRY>
} : (tensor<2xui64>) -> (tensor<2xui64>, tensor<2x2xui64>)
// %output_state: [1, 6]
// %output: [
//           [9236835810183407956, 16087790271692313299],
//           [18212823393184779219, 2658481902456610144]
//          ]

round_nearest_afz

セマンティクス

operand テンソルで最も近い整数に向けて要素ごとの丸めを実行し、ゼロからの同点を解消して、result テンソルを生成します。IEEE-754 仕様の roundToIntegralTiesToAway オペレーションを実装します。量子化型の場合は dequantize_op_quantize(round_nearest_afz, operand, type(result)) を実行します。

入力

ラベル 名前 タイプ 制約
(I1) operand 浮動小数点型またはテンソルごとの量子化テンソル (C1)

出力

名前 タイプ 制約
result 浮動小数点型またはテンソルごとの量子化テンソル (C1)

制約

  • (C1)baseline_type(operand) = baseline_type(result)

// %operand = [-2.5, 0.4, 0.5, 0.6, 2.5]
%result = "stablehlo.round_nearest_afz"(%operand) : (tensor<5xf64>) -> tensor<5xf64>
// %result: [-3.0, 0.0, 1.0, 1.0, 3.0]

その他の例

round_nearest_even

セマンティクス

operand テンソルで最も近い整数に向けて要素ごとの丸めを実行し、偶数の整数とのタイ関係を解消して、result テンソルを生成します。IEEE-754 仕様の roundToIntegralTiesToEven オペレーションを実装します。量子化型の場合、dequantize_op_quantize(round_nearest_even, operand, type(result)) を実行します。

入力

ラベル 名前 タイプ 制約
(I1) operand 浮動小数点型またはテンソルごとの量子化テンソル (C1)

出力

名前 タイプ 制約
result 浮動小数点型またはテンソルごとの量子化テンソル (C1)

制約

  • (C1)baseline_type(operand) = baseline_type(result)

// %operand = [-2.5, 0.4, 0.5, 0.6, 2.5]
%result = "stablehlo.round_nearest_even"(%operand) : (tensor<5xf64>) -> tensor<5xf64>
// %result: [-2.0, 0.0, 0.0, 1.0, 2.0]

その他の例

rsqrt

セマンティクス

operand テンソルに対して要素単位の逆平方根演算を実行し、result テンソルを生成します。要素のタイプに応じて、次の操作を行います。

  • 浮動小数点数の場合: IEEE-754 の rSqrt
  • 複素数の場合: 複素逆数の平方根。
  • 量子化型の場合: dequantize_op_quantize(rsqrt, operand, type(result))

入力

ラベル 名前 タイプ 制約
(I1) operand 浮動小数点型または複素数のテンソル、またはテンソルごとの量子化テンソル (C1)

出力

名前 タイプ 制約
result 浮動小数点型または複素数のテンソル、またはテンソルごとの量子化テンソル (C1)

制約

  • (C1)baseline_type(operand) = baseline_type(result)

// %operand: [[1.0, 4.0], [9.0, 25.0]]
%result = "stablehlo.rsqrt"(%operand) : (tensor<2x2xf32>) -> tensor<2x2xf32>
// %result: [[1.0, 0.5], [0.33333343, 0.2]]

その他の例

scatter

セマンティクス

inputs テンソルと等しい results テンソルを生成します。ただし、scatter_indices で指定された複数のスライスは、update_computation を使用して updates 値で更新されます。

次の図は、具体的な例を使用して、updates... の要素が results... の要素にどのようにマッピングされるかを示しています。この図では、いくつかの updates... インデックスを選択し、それらがどの results... インデックスに対応しているかを詳しく説明しています。

より正式には、index_space(updates[0]) のすべての update_index については次のようになります。

  • update_scatter_dims = [d for d in axes(updates[0]) and d not in update_window_dims].
  • update_scatter_index = update_index[update_scatter_dims...].
  • start_index は次のように定義されます。
    • scatter_indices[si0, ..., :, ..., siN]。ここで、siupdate_scatter_index の個々の要素であり、:index_vector_dim < rank(scatter_indices) の場合に index_vector_dim インデックスに挿入されます。
    • それ以外の場合は [scatter_indices[update_scatter_index]]
  • axes(inputs[0])d_input の場合:
    • d_input = scatter_dims_to_operand_dims[d_start] の場合は full_start_index[d_input] = start_index[d_start]
    • それ以外の場合は full_start_index[d_input] = 0
  • update_window_index = update_index[update_window_dims...].
  • full_window_index = [wi0, ..., 0, ..., wiN]。ここで、wiupdate_window_index の個々の要素であり、0inserted_window_dims のインデックスに挿入されます。
  • result_index = full_start_index + full_window_index.

そのため、results = exec(schedule, inputs) は次のとおりです。

  • schedule は、実装で定義されている index_space(updates[0]) の組み合わせです。
  • exec([update_index, ...], results) = exec([...], updated_results) ここで:
    • result_indexshape(results...) の範囲内にある場合
    • updates_converted = to_destination_type( updates...[update_index], type(func_inputs(update_computation) [len(func_inputs(update_computation))//2:])... )
    • updated_values = update_computation(results...[result_index], updates_converted)
    • updated_results は、results...[result_index]updated_values... に設定された results のコピーです。
    • それ以外の場合
    • updated_results = results.
  • exec([], results) = results.

indices_are_sortedtrue の場合、実装では scatter_indicesscatter_dims_to_operand_dims を基準に並べ替えられていると想定できます。そうでない場合、動作は未定義になります。より正式には、indices(result) からのすべての i1 < i2 について、full_start_index(i1) <= full_start_index(i2) となります。

unique_indicestrue の場合、実装では分散されているすべての result_index インデックスが一意であると想定できます。unique_indicestrue で、分散されているインデックスが一意でない場合、動作は未定義です。

入力

ラベル 名前 タイプ 制約
(I1) inputs テンソル数またはテンソルごとの量子化テンソルの可変数 (C1)、(C2)、(C4 ~ C6)、(C10)、(C13)、(C15 ~ C16)
(I2) scatter_indices 整数型のテンソル (C4)、(C11)、(C14)
(I3) updates テンソル数またはテンソルごとの量子化テンソルの可変数 (C3 ~ C6)、(C8)
(I4) update_window_dims si64 型の 1 次元テンソル定数 (C2)、(C4)、(C7)、(C8)
(I5) inserted_window_dims si64 型の 1 次元テンソル定数 (C2)、(C4)、(C9)、(C10)
(I6) scatter_dims_to_operand_dims si64 型の 1 次元テンソル定数 (C11 ~ C13)
(I7) index_vector_dim si64 型の定数 (C4)、(C11)、(C14)
(I8) indices_are_sorted i1 型の定数
(I9) unique_indices i1 型の定数
(I10) update_computation 機能 (C15)

出力

名前 タイプ 制約
results テンソル数またはテンソルごとの量子化テンソルの可変数 (C15 ~ C17)

制約

  • (C1)same(shape(inputs...))
  • (C2)rank(inputs[0]) = size(update_window_dims) + size(inserted_window_dims)
  • (C3)same(shape(updates...))
  • (C4)shape(updates[0]) = combine(update_scatter_dim_sizes, update_window_dim_sizes)。ここで:
    • update_scatter_dim_sizes = shape(scatter_indices)。ただし、index_vector_dim に対応する scatter_indices のディメンション サイズは含まれません。
    • update_window_dim_sizes <= shape(inputs[0])。ただし、inserted_window_dims に対応する inputs[0] のディメンション サイズは含まれません。
    • combine は、update_scatter_dim_sizesupdate_scatter_dims に対応する軸に配置します。また、update_window_dim_sizesupdate_window_dims に対応する軸に配置します。
  • (C5)0 < size(inputs) = size(updates) = N
  • (C6)element_type(updates...) = element_type(inputs...)
  • (C7)is_unique(update_window_dims) and is_sorted(update_window_dims)
  • (C8)0 <= update_window_dims < rank(updates[0])
  • (C9)is_unique(inserted_window_dims) and is_sorted(update_window_dims)
  • (C10)0 <= inserted_window_dims < rank(inputs[0])
  • (C11)size(scatter_dims_to_operand_dims) = index_vector_dim < rank(scatter_indices) ? dim(scatter_indices, index_vector_dim) : 1
  • (C12)is_unique(scatter_dims_to_operand_dims)
  • (C13)0 <= scatter_dims_to_operand_dims < rank(inputs[0])
  • (C14)0 <= index_vector_dim <= rank(scatter_indices)
  • (C15)update_computation の型は (tensor<E0>, ..., tensor<EN-1>, tensor<E0>, ..., tensor<EN-1>) -> (tensor<E0>, ..., tensor<EN-1>) で、is_promotable(element_type(inputs[i]), Ei) です。
  • (C16)shape(inputs...) = shape(results...)
  • (C17)[0,N) のすべての i に対して element_type(results[i]) = Ei

// %input: [
//          [[1, 2], [3, 4], [5, 6], [7, 8]],
//          [[9, 10], [11, 12], [13, 14], [15, 16]],
//          [[17, 18], [19, 20], [21, 22], [23, 24]]
//         ]
// %scatter_indices: [[[0, 2], [1, 0], [2, 1]], [[0, 1], [1, 0], [0, 9]]]
// %update: [
//           [[[1, 1], [1, 1]], [[1, 1], [1, 1]], [[1, 1], [1, 1]]],
//           [[[1, 1], [1, 1]], [[1, 1], [1, 1]], [[1, 1], [1, 1]]]
//          ]
%result = "stablehlo.scatter"(%input, %scatter_indices, %update) ({
  ^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
    %0 = "stablehlo.add"(%arg0, %arg1) : (tensor<i64>, tensor<i64>) -> tensor<i64>
    "stablehlo.return"(%0) : (tensor<i64>) -> ()
}) {
  scatter_dimension_numbers = #stablehlo.scatter<
    update_window_dims = [2, 3],
    inserted_window_dims = [0],
    scatter_dims_to_operand_dims = [1, 0],
    index_vector_dim = 2>,
  indices_are_sorted = false,
  unique_indices = false
} : (tensor<3x4x2xi64>, tensor<2x3x2xi64>, tensor<2x3x2x2xi64>) -> tensor<3x4x2xi64>
// %result: [
//           [[1, 2], [5, 6], [7, 8], [7, 8]],
//           [[10, 11], [12, 13], [14, 15], [16, 17]],
//           [[18, 19], [20, 21], [21, 22], [23, 24]]
//          ]

その他の例

select

セマンティクス

result テンソルを生成します。各要素は、pred の対応する要素の値に基づいて on_true テンソルまたは on_false テンソルから選択されます。より正式には result[result_index] = pred_element ? on_true[result_index] : on_false[result_index]pred_element = rank(pred) = 0 ? pred[] : pred[result_index])。量子化型の場合、dequantize_select_quantize(pred, on_true, on_false, type(result)) を実行します。

入力

ラベル 名前 タイプ 制約
(I1) pred i1 型のテンソル (C1)
(I2) on_true テンソルまたはテンソルごとの量子化テンソル (C1 ~ C2)
(I3) on_false テンソルまたはテンソルごとの量子化テンソル (C2)

出力

名前 タイプ 制約
result テンソルまたはテンソルごとの量子化テンソル (C2)

制約

  • (C1)rank(pred) = 0 or shape(pred) = shape(on_true)
  • (C2)baseline_type(on_true) = baseline_type(on_false) = baseline_type(result)

// %pred: [[false, true], [true, false]]
// %on_true: [[1, 2], [3, 4]]
// %on_false: [[5, 6], [7, 8]]
%result = "stablehlo.select"(%pred, %on_true, %on_false) : (tensor<2x2xi1>, tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
// %result: [[5, 2], [3, 8]]

その他の例

select_and_scatter

セマンティクス

select を使用して input テンソルの reduce_window の結果に基づいて scatter を使用して source テンソルから値を分散し、result テンソルを生成します。

次の図は、具体的な例を使用して、result の要素が operandsource から計算される方法を示しています。

より正式に:

  • selected_values = reduce_window_without_init(...) は、次の入力に置き換えます。

    • `inputs = [オペランド]。
    • window_dimensionswindow_stridespadding はそのまま使用されます。
    • base_dilations = windows_dilations = 1.
    • body は次のように定義されます。
    def body(arg0: tensor<E>, arg1: tensor<E>) -> tensor<E>:
      return select(arg0, arg1) ? arg0 : arg1;
    

    ここで、E = element_type(operand)reduce_window_without_initreduce_window とまったく同じように機能しますが、基盤となる reduceリデュースを参照)の schedule に init 値が含まれていない点が異なります。現時点では、対応するウィンドウに値がない場合の動作は特定されていません(#731)。

  • result[result_index] = reduce([source_values], [init_value], [0], scatter) ここで

    • source_values = [source[source_index] for source_index in source_indices].
    • selected_values[source_index]operand_indexoperand 要素がある場合は selected_index(source_index) = operand_index です。
    • source_indices = [source_index for source_index in indices(source) if selected_index(source_index) = result_index].

入力

ラベル 名前 タイプ 制約
(I1) operand テンソルまたはテンソルごとの量子化テンソル (C1 ~ C4)、(C6)、(C8 ~ C11)
(I2) source テンソルまたはテンソルごとの量子化テンソル (C1)、(C2)
(I3) init_value 0 次元テンソルまたはテンソルごとの量子化テンソル (C3)
(I4) window_dimensions si64 型の 1 次元テンソル定数 (C2)、(C4)、(C5)
(I5) window_strides si64 型の 1 次元テンソル定数 (C2)、(C6)、(C7)
(I6) padding si64 型の 2 次元テンソル定数 (C2)、(C8)
(I7) select 機能 (C9)
(I8) scatter 機能 (C10)

出力

名前 タイプ 制約
result テンソルまたはテンソルごとの量子化テンソル (C11 ~ C12)

制約

  • (C1)element_type(operand) = element_type(source)
  • (C2)shape(source) = num_windows。ここで:
    • padded_operand_shape = padding[:, 0] + shape(operand) + padding[:, 1].
    • is_empty_window = padded_operand_shape = 0 || window_dimensions > padded_operand_shape.
    • num_windows = is_empty_window ? 0 : floor((padded_operand_shape - window_dimensions) / window_strides) + 1.
  • (C3)element_type(init_value) = element_type(operand)
  • (C4)size(window_dimensions) = rank(operand)
  • (C5)0 < window_dimensions
  • (C6)size(window_strides) = rank(operand)
  • (C7)0 < window_strides
  • (C8)shape(padding) = [rank(operand), 2]
  • (C9)select の型は (tensor<E>, tensor<E>) -> tensor<i1> で、E = element_type(operand) です。
  • (C10)scatter の型は (tensor<E>, tensor<E>) -> tensor<E> で、is_promotable(element_type(operand), E) です。
  • (C11)shape(operand) = shape(result)
  • (C12)element_type(result) = E

// %operand: [[1, 5], [2, 5], [3, 6], [4, 4]]
// %source: [[5, 6], [7, 8]]
// %init_value: 0
%result = "stablehlo.select_and_scatter"(%operand, %source, %init_value) ({
  ^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
    %0 = "stablehlo.compare"(%arg0, %arg1) {
      comparison_direction = #stablehlo<comparison_direction GE>
    } : (tensor<i64>, tensor<i64>) -> tensor<i1>
    "stablehlo.return"(%0) : (tensor<i1>) -> ()
}, {
  ^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
    %0 = "stablehlo.add"(%arg0, %arg1) : (tensor<i64>, tensor<i64>) -> tensor<i64>
    "stablehlo.return"(%0) : (tensor<i64>) -> ()
}) {
  window_dimensions = dense<[3, 1]> : tensor<2xi64>,
  window_strides = dense<[2, 1]> : tensor<2xi64>,
  padding = dense<[[0, 1], [0, 0]]> : tensor<2x2xi64>
} : (tensor<4x2xi64>, tensor<2x2xi64>, tensor<i64>) -> tensor<4x2xi64>
// %result: [[0, 0], [0, 0], [5, 14], [7, 0]]

その他の例

送信

セマンティクス

inputs をチャネル channel_id に送信し、result トークンを生成します。

is_host_transfertrue の場合、オペレーションはデータをホストに転送します。それ以外の場合は、別のデバイスにデータを転送します。これは、実装が定義されることを意味します。このフラグは channel_type で指定された情報が重複しているため、将来的にはそのうち 1 つのみを保持する予定です(#666)。

入力

ラベル 名前 タイプ 制約
(I1) inputs テンソルまたは量子化テンソルの可変数
(I2) token token
(I3) channel_id si64 型の定数
(I4) channel_type DEVICE_TO_DEVICEDEVICE_TO_HOST の列挙型 (C1)
(I5) is_host_transfer i1 型の定数 (C1)

出力

名前 タイプ
result token

制約

  • (C1)channel_type が次のように定義されます。
    • is_host_transfer = true の場合は DEVICE_TO_HOST
    • それ以外の場合は DEVICE_TO_DEVICE

%result = "stablehlo.send"(%operand, %token) {
  channel_handle = #stablehlo.channel_handle<handle = 1, type = 2>,
  is_host_transfer = true
} : (tensor<2x2xi64>, !stablehlo.token) -> !stablehlo.token

その他の例

shift_left

セマンティクス

lhs テンソルで rhs のビット数だけ要素単位の左シフト演算を実行し、result テンソルを生成します。

入力

ラベル 名前 タイプ 制約
(I1) lhs 整数型のテンソル (C1)
(I2) rhs 整数型のテンソル (C1)

出力

名前 タイプ 制約
result 整数型のテンソル (C1)

制約

  • (C1)type(lhs) = type(rhs) = type(result)

// %lhs: [-1, 0, 1]
// %rhs: [1, 2, 3]
%result = "stablehlo.shift_left"(%lhs, %rhs): (tensor<3xi64>, tensor<3xi64>) -> tensor<3xi64>
// %result: [-2, 0, 8]

その他の例

shift_right_arithmetic

セマンティクス

lhs テンソルに rhs のビット数だけ要素単位の右シフト演算を実行し、result テンソルを生成します。

入力

ラベル 名前 タイプ 制約
(I1) lhs 整数型のテンソル (C1)
(I2) rhs 整数型のテンソル (C1)

出力

名前 タイプ 制約
result 整数型のテンソル (C1)

制約

  • (C1)type(lhs) = type(rhs) = type(result)

// %lhs: [-1, 0, 8]
// %rhs: [1, 2, 3]
%result = "stablehlo.shift_right_arithmetic"(%lhs, %rhs): (tensor<3xi64>, tensor<3xi64>) -> tensor<3xi64>
// %result: [-1, 0, 1]

その他の例

shift_right_logical

セマンティクス

lhs テンソルに rhs のビット数だけ要素単位の論理右シフト演算を実行し、result テンソルを生成します。

入力

ラベル 名前 タイプ 制約
(I1) lhs 整数型のテンソル (C1)
(I2) rhs 整数型のテンソル (C1)

出力

名前 タイプ 制約
result 整数型のテンソル (C1)

制約

  • (C1)type(lhs) = type(rhs) = type(result)

// %lhs: [-1, 0, 8]
// %rhs: [1, 2, 3]
%result = "stablehlo.shift_right_logical"(%lhs, %rhs): (tensor<3xi64>, tensor<3xi64>) -> tensor<3xi64>
// %result: [9223372036854775807, 0, 1]

その他の例

揃えます。

セマンティクス

operand の符号を要素ごとに返し、result テンソルを生成します。 より正式には、要素 x ごとに、次のように Python 構文を使用してセマンティクスを表現できます。

def sign(x):
  if is_integer(x):
    if compare(x, 0, LT, SIGNED): return -1
    if compare(x, 0, EQ, SIGNED): return 0
    return 1
  elif is_float(x):
    if is_nan(x): return NaN
    if compare(x, -0.0, EQ, FLOAT): return -0.0
    if compare(x, +0.0, EQ, FLOAT): return +0.0
    if compare(x, 0.0, LT, FLOAT): return -1.0
    return 1.0
  elif is_complex(x):
    if is_nan(real(x)) or is_nan(imag(x)): return (NaN, NaN)
    if compare(x, (0.0, 0.0), EQ, FLOAT): return (0.0, 0.0)
    return divide(x, convert(abs(x), type(x)))

量子化型の場合、dequantize_op_quantize(sign, operand, type(result)) を実行します。

入力

ラベル 名前 タイプ 制約
(I1) operand 符号付き整数、浮動小数点、複合型、テンソルごとの量子化テンソル (C1)

出力

名前 タイプ 制約
result 符号付き整数、浮動小数点、複合型、テンソルごとの量子化テンソル (C1)

制約

  • (C1)baseline_type(operand) = baseline_type(result)

// Logical values: +NaN, -1.0, -0.0, +0.0, 1.0
// operand: [0x7FFFFFFFFFFFFFFF, -1.0, -0.0, 0.0, 1.0]
%result = "stablehlo.sign"(%operand) : (tensor<5xf64>) -> tensor<5xf64>
// Logical values: +NaN, -1.0, -0.0, +0.0, 1.0
// %result: [0x7FFFFFFFFFFFFFFF, -1.0, -0.0, 0.0, 1.0]

その他の例

サイン

セマンティクス

operand テンソルに対して要素ごとの正弦演算を実行し、result テンソルを生成します。要素のタイプに応じて、次の操作を行います。

  • 浮動小数点数の場合: IEEE-754 の sin
  • 複素数の場合: 複素正弦。
  • 量子化型の場合: dequantize_op_quantize(sine, operand, type(result))

入力

ラベル 名前 タイプ 制約
(I1) operand 浮動小数点型または複素数のテンソル、またはテンソルごとの量子化テンソル (C1)

出力

名前 タイプ 制約
result 浮動小数点型または複素数のテンソル、またはテンソルごとの量子化テンソル (C1)

制約

  • (C1)baseline_type(operand) = baseline_type(result)

// %operand: [
//            [0.0, 1.57079632],       // [0, pi/2]
//            [3.14159265, 4.71238898] // [pi, 3pi/2]
//           ]
%result = "stablehlo.sine"(%operand) : (tensor<2x2xf32>) -> tensor<2x2xf32>
// %result: [[0.0, 1.0], [0.0, -1.0]]

その他の例

スライス

セマンティクス

静的に計算された開始インデックスを使用して operand からスライスを抽出し、result テンソルを生成します。start_indices には各ディメンションのスライスの開始インデックスが含まれ、limit_indices には各ディメンションのスライスの終了インデックス(排他的)が含まれ、strides には各ディメンションのストライドが含まれます。

より正式には result[result_index] = operand[operand_index] です。ここで、operand_index = start_indices + result_index * strides となります。

入力

ラベル 名前 タイプ 制約
(I1) operand テンソルまたはテンソルごとの量子化テンソル (C1 ~ C3)、(C5)
(I2) start_indices si64 型の 1 次元テンソル定数 (C2)、(C3)、(C5)
(I3) limit_indices si64 型の 1 次元テンソル定数 (C2)、(C3)、(C5)
(I4) strides si64 型の 1 次元テンソル定数 (C2)、(C4)

出力

名前 タイプ 制約
result テンソルまたはテンソルごとの量子化テンソル (C1)、(C5)

制約

  • (C1)element_type(operand) = element_type(result)
  • (C2)size(start_indices) = size(limit_indices) = size(strides) = rank(operand)
  • (C3)0 <= start_indices <= limit_indices <= shape(operand)
  • (C4)0 < strides
  • (C5)shape(result) = ceil((limit_indices - start_indices) / strides)

// %operand: [
//            [0, 0, 0, 0],
//            [0, 0, 1, 1],
//            [0, 0, 1, 1]
//           ]
%result = "stablehlo.slice"(%operand) {
  start_indices = dense<[1, 2]> : tensor<2xi64>,
  limit_indices = dense<[3, 4]> : tensor<2xi64>,
  strides = dense<1> : tensor<2xi64>
} : (tensor<3x4xi64>) -> tensor<2x2xi64>
// % result: [
//            [1, 1],
//            [1, 1]
//           ]

その他の例

sort

セマンティクス

inputs の 1 次元スライスを、comparator に従ってディメンション dimension に沿ってまとめて並べ替え、results を生成します。

他の演算の同様の入力とは異なり、dimension では負の値を使用できます。セマンティクスは以下のとおりです。将来的には、整合性の理由から許可されなくなります(#1377)。

is_stable が true の場合、並べ替えは安定します。つまり、コンパレータによって等しいとみなされる要素の相対的な順序が保持されます。入力が 1 つの場合、comparator(e1, e2) = comparator(e2, e1) = false の場合にのみ、コンパレータによって 2 つの要素 e1e2 が等しいとみなされます。これが複数の入力に対してどのように一般化されるかについては、以下の形式化をご覧ください。

より正式には、index_space(results[0]) のすべての result_index については次のようになります。

  • adjusted_dimension = dimension >= 0 ? dimension : rank(inputs[0]) + dimension.
  • result_slice = [ri0, ..., :, ..., riR-1]。ここで、riNresult_index の個々の要素であり、:adjusted_dimension に挿入されます。
  • inputs_together = (inputs[0]..., ..., inputs[N-1]...).
  • results_together[result_slice] = sort(inputs_together[result_slice], comparator_together).
  • ここで、sort は、1 次元スライスを非降順で並べ替えます。左側の引数が右側の 2 番目の引数よりも小さい場合は、comparator_togethertrue を返すと想定します。
  • def comparator_together(lhs_together, rhs_together):
      args = []
      for (lhs_el, rhs_el) in zip(lhs_together, rhs_together):
        args.append(lhs_el)
        args.append(rhs_el)
      return comparator(*args)
    
  • (results[0]..., ..., results[N-1]...) = results_together.

入力

ラベル 名前 タイプ 制約
(I1) inputs テンソル数またはテンソルごとの量子化テンソルの可変数 (C1 ~ C5)
(I2) dimension si64 型の定数 (C4)
(I3) is_stable i1 型の定数
(I4) comparator 機能 (C5)

出力

名前 タイプ 制約
results テンソル数またはテンソルごとの量子化テンソルの可変数 (C2)、(C3)

制約

  • (C1)0 < size(inputs)
  • (C2)type(inputs...) = type(results...)
  • (C3)same(shape(inputs...) + shape(results...))
  • (C4)-R <= dimension < R、ここで R = rank(inputs[0])
  • (C5)comparator の型は (tensor<E1>, tensor<E1>, ..., tensor<EN-1>, tensor<EN-1>) -> tensor<i1> で、Ei = element_type(inputs[i]) です。

// %input0 = [[1, 2, 3], [3, 2, 1]]
// %input1 = [[3, 2, 1], [1, 2, 3]]
%result0, %result1 = "stablehlo.sort"(%input0, %input1) ({
  ^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>, %arg2: tensor<i64>, %arg3: tensor<i64>):
    %predicate = "stablehlo.compare"(%arg0, %arg1) {
      comparison_direction = #stablehlo<comparison_direction GT>
    } : (tensor<i64>, tensor<i64>) -> tensor<i1>
    "stablehlo.return"(%predicate) : (tensor<i1>) -> ()
}) {
  dimension = 0 : i64,
  is_stable = true
} : (tensor<2x3xi64>, tensor<2x3xi64>) -> (tensor<2x3xi64>, tensor<2x3xi64>)
// %result0 = [[3, 2, 3], [1, 2, 1]]
// %result1 = [[1, 2, 1], [3, 2, 3]]

その他の例

sqrt

セマンティクス

operand テンソルに対して要素単位の平方根演算を実行し、result テンソルを生成します。要素のタイプに応じて、次の操作を行います。

  • 浮動小数点数の場合: IEEE-754 の squareRoot
  • 複素数の場合: 複素平方根
  • 量子化型の場合: dequantize_op_quantize(sqrt, operand, type(result))

入力

ラベル 名前 タイプ 制約
(I1) operand 浮動小数点型または複素数のテンソル、またはテンソルごとの量子化テンソル (C1)

出力

名前 タイプ 制約
result 浮動小数点型または複素数のテンソル、またはテンソルごとの量子化テンソル (C1)

制約

  • (C1)baseline_type(operand) = baseline_type(result)

// %operand: [[0.0, 1.0], [4.0, 9.0]]
%result = "stablehlo.sqrt"(%operand) : (tensor<2x2xf32>) -> tensor<2x2xf32>
// %result: [[0.0, 1.0], [2.0, 3.0]]

その他の例

subtract

セマンティクス

2 つのテンソル lhsrhs を要素ごとに減算し、result テンソルを生成します。要素のタイプに応じて、次の操作を行います。

  • 整数の場合: 整数の減算。
  • 浮動小数点数の場合: IEEE-754 の subtraction
  • 複素数の場合: 複素減算。
  • 量子化型の場合:
    • dequantize_op_quantize(subtract, lhs, rhs, type(result)).

入力

ラベル 名前 タイプ 制約
(I1) lhs 整数、浮動小数点、複合型のテンソル、またはテンソルごとの量子化テンソル (C1)
(I2) rhs 整数、浮動小数点、複合型のテンソル、またはテンソルごとの量子化テンソル (C1)

出力

名前 タイプ 制約
result 整数、浮動小数点、複合型のテンソル、またはテンソルごとの量子化テンソル (C1)

制約

  • (C1)baseline_type(lhs) = baseline_type(rhs) = baseline_type(result)

// %lhs: [[6, 8], [10, 12]]
// %rhs: [[5, 6], [7, 8]]
%result = "stablehlo.subtract"(%lhs, %rhs) : (tensor<2x2xf32>, tensor<2x2xf32>) -> (tensor<2x2xf32>)
// %result: [[1, 2], [3, 4]]

その他の例

tanh

セマンティクス

operand テンソルに対して要素ごとの双曲線正接演算を実行し、result テンソルを生成します。要素のタイプに応じて、次の操作を行います。

  • 浮動小数点数の場合: IEEE-754 の tanh
  • 複素数の場合: 複素双曲線正接。
  • 量子化型の場合:
    • dequantize_op_quantize(tanh, operand, type(result)).

入力

ラベル 名前 タイプ 制約
(I1) operand 浮動小数点型または複素数のテンソル、またはテンソルごとの量子化テンソル (C1)

出力

名前 タイプ 制約
result 浮動小数点型または複素数のテンソル、またはテンソルごとの量子化テンソル (C1)

制約

  • (C1)baseline_type(operand) = baseline_type(result)

// %operand: [-1.0, 0.0, 1.0]
%result = "stablehlo.tanh"(%operand) : (tensor<3xf32>) -> tensor<3xf32>
// %result: [-0.76159416, 0.0, 0.76159416]

その他の例

転置

セマンティクス

permutation を使用して operand テンソルの次元を並べ替えて、result テンソルを生成します。より正式には result[result_index] = operand[operand_index] です(result_index[d] = operand_index[permutation[d]])。

入力

ラベル 名前 タイプ 制約
(I1) operand テンソルまたは量子化テンソル (C1 ~ C4)
(I2) permutation si64 型の 1 次元テンソル定数 (C2 ~ C4)

出力

名前 タイプ 制約
result テンソルまたは量子化テンソル (C1)、(C3 ~ C4)

制約

  • (C1)element_type(result) は次のように表されます。
    • element_type(operand)!is_per_axis_quantized(operand) の場合)。
    • element_type(operand)。ただし、quantization_dimension(operand)quantization_dimension(result) は異なる場合があります。
  • (C2)permutationrange(rank(operand)) の組み合わせです。
  • (C3)shape(result) = dim(operand, permutation...)
  • (C4)is_per_axis_quantized(result) の場合、quantization_dimension(operand) = permutation(quantization_dimension(result))

// %operand: [
//            [[1,2], [3,4], [5,6]],
//            [[7,8], [9,10], [11,12]]
//           ]
%result = "stablehlo.transpose"(%operand) {
  permutation = dense<[2, 1, 0]> : tensor<3xi64>
} : (tensor<2x3x2xi32>) -> tensor<2x3x2xi32>
// %result: [
//           [[1,7], [3,9], [5,11]],
//           [[2,8], [4,10], [6,12]]
//          ]

その他の例

triangular_solve

セマンティクス

下位三角係数行列または上位三角係数行列を持つ連立一次方程式系のバッチを解きます。

より正式には、ab から、result[i0, ..., iR-3, :, :] は、left_sidetrue のときは op(a[i0, ..., iR-3, :, :]) * x = b[i0, ..., iR-3, :, :] を、left_sidefalse のときは x * op(a[i0, ..., iR-3, :, :]) = b[i0, ..., iR-3, :, :] の解答です。op(a)transpose_a によって決まる変数 x を求めます。これは次のいずれかになります。

  • NO_TRANSPOSE: a を使用してオペレーションをそのまま実行します。
  • TRANSPOSE: a の転置に対する演算を実行します。
  • ADJOINT: a の共役転置に対して演算を実行します。

lowertrue の場合、入力データは a の下側の三角形からのみ読み取られ、それ以外の場合は a の上側の三角形からのみ読み取られます。出力データは同じ三角形に返されます。他方の三角形の値は実装で定義します。

unit_diagonal が true の場合、実装では a の対角要素が 1 に等しいと想定できます。それ以外の場合は、動作は未定義となります。

量子化型の場合、dequantize_op_quantize(lambda x, y: triangular_solve(x, y, left_side, lower, unit_diagonal, transpose_a), a, b, type(result)) を実行します。

入力

ラベル 名前 タイプ 制約
(I1) a 浮動小数点型または複素数のテンソル、またはテンソルごとの量子化テンソル (C1 ~ C3)
(I2) b 浮動小数点型または複素数のテンソル、またはテンソルごとの量子化テンソル (C1 ~ C4)
(I3) left_side i1 型の定数 (C3)
(I4) lower i1 型の定数
(I5) unit_diagonal i1 型の定数
(I6) transpose_a NO_TRANSPOSETRANSPOSEADJOINT の列挙型

出力

名前 タイプ 制約
result 浮動小数点型または複素数のテンソル、またはテンソルごとの量子化テンソル (C1)

制約

  • (C1)baseline_element_type(a) = baseline_element_type(b)
  • (C2)2 <= rank(a) = rank(b) = R
  • (C3)shape(a)shape(b) の関係は、次のように定義されます。
    • shape(a)[:-3] = shape(b)[:-3].
    • dim(a, -2) = dim(a, -1) = dim(b, left_side ? -2 : -1).
  • (C4)baseline_type(b) = baseline_type(result)

// %a = [
//       [1.0, 0.0, 0.0],
//       [2.0, 4.0, 0.0],
//       [3.0, 5.0, 6.0]
//      ]
// %b = [
//       [2.0, 0.0, 0.0],
//       [4.0, 8.0, 0.0],
//       [6.0, 10.0, 12.0]
//      ]
%result = "stablehlo.triangular_solve"(%a, %b) {
  left_side = true,
  lower = true,
  unit_diagonal = false,
  transpose_a = #stablehlo<transpose NO_TRANSPOSE>
} : (tensor<3x3xf32>, tensor<3x3xf32>) -> tensor<3x3xf32>
// %result: [
//           [2.0, 0.0, 0.0],
//           [0.0, 2.0, 0.0],
//           [0.0, 0.0, 2.0]
//          ]

tuple

セマンティクス

val から result タプルを生成します。

入力

ラベル 名前 タイプ 制約
(I1) val 可変長の値 (C1)

出力

名前 タイプ 制約
result tuple (C1)

制約

  • (C1)result の型は tuple<E0, ..., EN-1> で、Ei = type(val[i]) です。

// %val0: [1.0, 2.0]
// %val1: (3)
%result = "stablehlo.tuple"(%val0, %val1) : (tensor<2xf32>, tuple<tensor<i32>>) -> tuple<tensor<2xf32>, tuple<tensor<i32>>>
// %result: ([1.0, 2.0], (3))

その他の例

uniform_dequantize

セマンティクス

operand 型で定義された量子化パラメータに従って、量子化テンソル operand を浮動小数点テンソル result に要素ごとに変換します。

より正式には、result = dequantize(operand) です。

入力

ラベル 名前 タイプ 制約
(I1) operand 量子化テンソル (C1)、(C2)

出力

名前 タイプ 制約
result 浮動小数点型のテンソル (C1)、(C2)

制約

  • (C1)shape(operand) = shape(result)
  • (C2)element_type(result) = expressed_type(operand)

// %operand: [10, 10]
%result = "stablehlo.uniform_dequantize"(%operand) : (tensor<2x!quant.uniform<i8:f32:0, {0.1:-30,0.5:-20}>>) -> tensor<2xf32>
// %result: [4.0, 15.0]

uniform_quantize

セマンティクス

result 型で定義された量子化パラメータに従って、浮動小数点テンソルまたは量子化テンソル operand を量子化テンソル result に要素ごとに変換します。

より正式に,

  • is_float(operand) の場合:
    • result = quantize(operand, type(result)).
  • is_quantized(operand) の場合:
    • float_result = dequantize(operand).
    • result = quantize(float_result, type(result)).

入力

ラベル 名前 タイプ 制約
(I1) operand 浮動小数点型または量子化型のテンソル (C1)、(C2)

出力

名前 タイプ 制約
result 量子化テンソル (C1)、(C2)

制約

  • (C1)shape(operand) = shape(result)
  • (C2)expressed_type(result) = is_float(operand) ? element_type(operand) : expressed_type(operand)

// %operand: [4.0, 15.0]
%result = "stablehlo.uniform_quantize"(%operand) : (tensor<2xf32>) -> tensor<2x!quant.uniform<i8:f32:0, {0.1:-30,0.5:-20}>>
// %result: [10, 10]

// %operand: [10, 10]
%result = "stablehlo.uniform_quantize"(%operand) : (tensor<2x!quant.uniform<i8:f32:0, {0.1:-30,0.5:-20}>>) -> tensor<2x!quant.uniform<i8:f32:0, {0.1:-20,0.2:-30}>>
// %result: [20, 45]

しばらく

セマンティクス

cond 関数は true を出力する間、body 関数を 0 回以上実行すると出力を生成します。より正式には、Python 構文を使用してセマンティクスを表現すると、次のようになります。

internal_state = operand
while cond(*internal_state):
  internal_state = body(*internal_state)
results = internal_state

無限ループの動作は未定です(#383)。

入力

ラベル 名前 タイプ 制約
(I1) operand テンソル、量子化テンソル、トークンの可変数 (C1 ~ C3)
(I2) cond 機能 (C1)
(I3) body 機能 (C2)

出力

名前 タイプ 制約
results テンソル、量子化テンソル、トークンの可変数 (C3)

制約

  • (C1)cond の型は (T0, ..., TN-1) -> tensor<i1> で、Ti = type(operand[i]) です。
  • (C2)body の型は (T0, ..., TN-1) -> (T0, ..., TN-1) で、Ti = type(operand[i]) です。
  • (C3)type(results...) = type(operand...)

// %init_i: 1
// %init_sum: 0
// %one: 1
// %ten: 10
%results0, %results1 = "stablehlo.while"(%init_i, %init_sum) ({
  ^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
    %cond = "stablehlo.compare"(%arg0, %ten) {
      comparison_direction = #stablehlo<comparison_direction LT>
    } : (tensor<i64>, tensor<i64>) -> tensor<i1>
    stablehlo.return %cond : tensor<i1>
  }, {
  ^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
    %new_sum = stablehlo.add %arg1, %one : tensor<i64>
    %new_i = stablehlo.add %arg0, %one : tensor<i64>
    stablehlo.return %new_i, %new_sum : tensor<i64>, tensor<i64>
}) : (tensor<i64>, tensor<i64>) -> (tensor<i64>, tensor<i64>)
// %results0: 10
// %results1: 10

その他の例

XOR

セマンティクス

2 つのテンソル lhsrhs の要素ごとの XOR を実行し、result テンソルを生成します。要素のタイプに応じて、次の操作を行います。

  • ブール値の場合: 論理 XOR 演算。
  • 整数の場合: ビット演算 XOR 演算。

入力

ラベル 名前 タイプ 制約
(I1) lhs ブール値または整数型のテンソル (C1)
(I2) rhs ブール値または整数型のテンソル (C1)

出力

名前 タイプ 制約
result ブール値または整数型のテンソル (C1)

制約

  • (C1)type(lhs) = type(rhs) = type(result)

// Bitwise operation with with integer tensors
// %lhs: [[1, 2], [3, 4]]
// %rhs: [[5, 6], [7, 8]]
%result = "stablehlo.xor"(%lhs, %rhs) : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
// %result: [[4, 4], [4, 12]]

// Logical operation with with boolean tensors
// %lhs: [[false, false], [true, true]]
// %rhs: [[false, true], [false, true]]
%result = "stablehlo.xor"(%lhs, %rhs) : (tensor<2x2xi1>, tensor<2x2xi1>) -> tensor<2x2xi1>
// %result: [[false, true], [true, false]]

実行

順次実行

StableHLO プログラムを実行するには、main 関数に入力値を指定し、出力値を計算します。関数の出力値は、対応する return 演算に根差した演算のグラフを実行することで計算されます。

実行順序は、Dataflow と整合している限り(つまり、オペレーションが使用される前に実行される場合)、実装によって定義されます。StableHLO では、すべての副作用オペレーションが 1 つのトークンを消費して 1 つのトークンを生成します(after_all を介して複数のトークンを 1 つのトークンに多重化できます)。そのため、副作用の実行順序も Dataflow と一致します。上のサンプル プログラムの実行順序は、%0%1%2%3%4return または %3%0%1%2%4return です。

より正式には、StableHLO プロセスは、1)StableHLO プログラム、2)オペレーションのステータス(まだ実行されていない、すでに実行されている)、3)プロセスが取り組んでいる中間値の組み合わせで構成されます。プロセスは、main 関数への入力値から始まり、オペレーション ステータスと中間値を更新するオペレーションのグラフを進め、出力値で終了します。さらなる形式化は未定です(#484)。

並列実行

StableHLO プログラムは並列実行でき、num_replicasnum_partitions の 2D プロセス グリッドにまとめられます。どちらも ui32 型です。

StableHLO プロセス グリッドでは、num_replicas * num_partitions 個の StableHLO プロセスが同時に実行されています。各プロセスには一意の process_id = (replica_id, partition_id) があり、replica_ids = range(num_replicas)replica_idpartition_ids = range(num_partitions)partition_id の型は ui32 です。

プロセス グリッドのサイズは、すべてのプログラムで静的に認識され(今後、StableHLO プログラム #650 に明示的に組み込まれる予定です)、プロセス グリッド内の位置はすべてのプロセスで静的に認識されます。各プロセスは、replica_id オペレーションと partition_id オペレーションを介してプロセス グリッド内の位置にアクセスできます。

プロセス グリッド内では、プログラムはすべて同じもの(「Single Program, Multiple Data」スタイル)にすることも、「Multiple Program, Multiple Data」スタイルでもすべて異なっていてもかまいません。また、それらの間に別のものを含めることもできます。将来的には、GSPMD(#619)など、StableHLO 並列プログラムを定義するためのその他のイディオムのサポートを導入する予定です。

プロセス グリッド内では、プロセスのほとんどが相互に独立しています。プロセスには個別のオペレーションのステータスがあり、個別の入力/中間/出力の値があります。ほとんどのオペレーションは、以下で説明する少数の集合オペレーションを除き、プロセス間で別々に実行されます。

ほとんどのオペレーションの実行は同じプロセスの値のみが使用されるため、通常はこれらの値を名前で参照できます。ただし、集団演算のセマンティクスを記述する場合、それでは不十分であり、特定のプロセス内で値 name を参照する name@process_id という表記が発生します。(このような観点から、不修飾 namename@(replica_id(), partition_id()) の省略形として表示できます)。

プロセス間の実行順序は、以下で説明するように、ポイントツーポイント通信と集団演算によって導入される同期を除き、実装によって定義されます。

ポイントツーポイント通信

StableHLO プロセスは、StableHLO チャネルを介して相互に通信できます。チャネルは si64 型の正の ID で表されます。さまざまな演算により、チャネルに値を送信し、チャネルから値を受け取ることができます。

これらのチャネル ID のソース、プロセス プログラムがそれらを認識する方法、それらによってどのような同期が導入されるかなど、さらなる形式化は未定です(#484)。

ストリーミング通信

すべての StableHLO プロセスは、次の 2 つのストリーミング インターフェースにアクセスできます。

  • 読み取り可能なインフィード
  • 書き込み可能なアウトフィード

プロセス間で通信するために使用するチャネルで、両端にプロセスがあるのに対して、インフィードとアウトフィードはもう一方の端が実装で定義されます。

ストリーミング通信が実行順序に与える影響や、それによって導入される同期の種類など、さらなる形式化は未定です(#484)。

共同運用

StableHLO には、all_gatherall_reduceall_to_allcollective_broadcastcollective_permutereduce_scatter の 6 つの集合演算があります。これらのオペレーションはすべて、StableHLO プロセス グリッド内のプロセスを StableHLO プロセス グループに分割し、各プロセス グループ内で他のプロセス グループとは独立して共同計算を実行します。

各プロセス グループ内で、集合オペレーションにより同期の障壁が生じる場合があります。この同期がいつ行われるか、プロセスがこの境界にどのように到達し、行われなかった場合にどうなるかなど、さらなる形式化は未定です(#484)。

プロセス グループにパーティション間通信が含まれる場合、つまり、プロセス グループにパーティション ID が異なるプロセスがある場合、集団演算の実行にはチャネルが必要で、集団演算は型 si64 の正の channel_id を提供する必要があります。レプリカ間の通信にチャネルは必要ありません。

集団演算によって実行される計算は、個々の演算に固有であり、上記の個々の演算に関するセクションで説明されています。ただし、プロセス グリッドがプロセス グループに分割される方法は、これらのオペレーション間で共有されるため、このセクションではその方法について説明します。正式には、StableHLO は次の 4 つの戦略をサポートしています。

cross_replica

各プロセス グループ内ではレプリカ間の通信のみが行われます。この戦略は、レプリカ ID のリストである replica_groups を受け取り、partition_ids による replica_groups のデカルト積を計算します。replica_groups は一意の要素を持ち、すべての replica_ids を網羅する必要があります。より正式には、Python 構文を使用して次のように記述します。

def cross_replica(replica_groups: List[List[ReplicaId]]) -> List[List[ProcessId]]:
  for replica_group in replica_groups:
    for partition_id in partition_ids:
      process_group = []
      for replica_id in replica_group:
        process_group.append((replica_id, partition_id))
      yield process_group

たとえば、replica_groups = [[0, 1], [2, 3]]num_partitions = 2 の場合、cross_replica[[(0, 0), (1, 0)], [(0, 1), (1, 1)], [(2, 0), (3, 0)], [(2, 1), (3, 1)]] を生成します。

cross_partition

各プロセス グループ内ではパーティション間の通信のみが行われます。この戦略は、パーティション ID のリストである partition_groups を受け取り、replica_ids による partition_groups のデカルト積を計算します。partition_groups は一意の要素を持ち、すべての partition_ids を網羅する必要があります。 より正式には、Python 構文を使用して次のように記述します。

def cross_partition(partition_groups: List[List[PartitionId]]) -> List[List[ProcessId]]:
  for partition_group in partition_groups:
    for replica_id in replica_ids:
      process_group = []
      for partition_id in partition_group:
        process_group.append((replica_id, partition_id))
      yield process_group

たとえば、partition_groups = [[0, 1]]num_replicas = 4 の場合、cross_partition[[(0, 0), (0, 1)], [(1, 0), (1, 1)], [(2, 0), (2, 1)], [(3, 0), (3, 1)]] を生成します。

cross_replica_and_partition

レプリカ間通信とパーティション間通信の両方を各プロセス グループ内で行うことができます。この戦略は、replica_groups(レプリカ ID のリストのリスト)を取り、partition_ids によって各 replica_group のデカルト積を計算します。replica_groups は一意の要素を持ち、すべての replica_ids を網羅する必要があります。より正式には、Python 構文を使用して次のように記述します。

def cross_replica_and_partition(replica_groups: List[List[ReplicaId]]) -> List[List[ProcessId]]:
  for replica_group in replica_groups:
    process_group = []
    for partition_id in partition_ids:
      for replica_id in replica_group:
        process_group.append((replica_id, partition_id))
    yield process_group

たとえば、replica_groups = [[0, 1], [2, 3]]num_partitions = 2 の場合、cross_replica_and_partition[[(0, 0), (1, 0), (0, 1), (1, 1)], [(2, 0), (3, 0), (2, 1), (3, 1)]] を生成します。

flattened_ids

この戦略は、flattened_id_groupsreplica_id * num_partitions + partition_id 形式の「フラット化された」プロセス ID のリスト)を取り、プロセス ID に変換します。flattened_id_groups は一意の要素を持ち、すべての process_ids を網羅する必要があります。より正式には、Python 構文を使用して次のように記述します。

def flattened_ids(flattened_id_groups: List[List[ui32]]) -> List[List[ProcessId]]:
  for flattened_id_group in flattened_id_groups:
    process_group = []
    for flattened_id in flattened_id_group:
      replica_id = flattened_id // num_partitions
      partition_id = flattened_id % num_partitions
      process_group.append((replica_id, partition_id))
    yield process_group

たとえば、flattened_id_groups = [[0, 1, 2, 3], [4, 5, 6, 7]]num_replicas = 4num_partitions = 2 の場合、flattened_ids[[(0, 0), (0, 1), (1, 0), (1, 1)], [(2, 0), (2, 1), (3, 0), (3, 1)]] を生成します。

精度

現時点では、StableHLO では数値の正確性に関する保証はありませんが、今後変更される可能性があります(#1156)。

エラー

StableHLO プログラムは、実行前に個々のオペレーションの広範な制約を通じて検証されるため、多くのクラスのエラーが除外されます。ただし、整数オーバーフローや境界外のアクセスなどによるエラー条件は引き続き発生する可能性があります。明示的に指定しない限り、これらのエラーはすべて実装定義の動作になりますが、将来変更される可能性があります(#1157)。

このルールの例外として、StableHLO プログラムの浮動小数点例外には、明確に定義された動作があります。IEEE-754 標準で定義されている例外(無効な演算、ゼロ除算、オーバーフロー、アンダーフロー、不正確例外)が発生する演算では、デフォルトの結果(この標準で定義されている)が生成され、対応するステータス フラグを発生させることなく実行が続行されます。標準の raiseNoFlag 例外処理と同様です。非標準の演算(複雑な算術関数や特定の超越関数など)の例外は実装で定義されます。

Notation

構文の説明のため、このドキュメントでは EBNF 構文の修正版 ISO フレーバー(ISO/IEC 14977:1996Wikipedia)を使用していますが、次の 2 つの点が変更されています。1)ルールの定義には = ではなく ::= を使用し、

2)連結は、, ではなく並置を使用して表現されます。

セマンティクスの記述(「型」、「定数」、「演算」の各セクション内)には、以下に説明する配列演算を簡潔に表現するためのサポートで拡張された Python 構文に基づく数式を使用しています。これは小さなコード スニペットには有効ですが、大規模なコード スニペットが必要なまれなケースでは、常に明示的に導入される標準の Python 構文を使用します。

数式

dot_general 仕様の例に基づいて、数式の仕組みを見ていきましょう。このオペレーションの制約の一つは、dim(lhs, lhs_batching_dimensions...) = dim(rhs, rhs_batching_dimensions...) のようになります。

この式で使用される名前には、1)グローバル関数、2)対応するプログラム要素のメンバー定義(dot_general の「入力」セクションで定義された lhslhs_batching_dimensionsrhsrhs_batching_dimensions 入力)の 2 つのソースが使用されます。dim

前述のように、この式の構文は Python ベースであり、簡潔さを重視した拡張機能があります。数式を理解するため、標準の Python 構文に式を変換してみましょう。

A)これらの式では、= を使用して等式を表すため、Python 構文を取得するための最初のステップは、dim(lhs, lhs_batching_dimensions...) == dim(rhs, rhs_batching_dimensions...) のように === に置き換えることです。

B)また、これらの式では、スカラー式をテンソル式に変換する省略記号(...)がサポートされています。要約すると、f(xs...) とは大まかに「テンソル xs 内のスカラー x ごとにスカラー f(x) を計算し、これらのスカラー結果をすべて一緒にテンソル結果として返す」ことを意味します。標準の Python 構文では、式の例は [dim(lhs, dim1) for dim1 in lhs_batching_dimensions] == [dim(rhs, dim2) for dim2 in rhs_batching_dimensions] になります。

省略記号を使うことで、個々のスカラー レベルでの作業を回避できることがよくあります。ただし、複雑なケースでは、gather 仕様の start_indices[bi0, ..., :, ..., biN] 式のように、低レベルの準非公式構文が使用されることがあります。簡潔にするために、このような構文を標準 Python に変換するための正確な形式はありませんが、個々の状況に応じて直感的に理解できることを期待しています。特定の数式が不透明に見える場合はお知らせください。改善に努めてまいります。

また、数式では、テンソルやテンソルのリスト(さまざまなテンソルから生じる可能性のあるもの)など、あらゆる種類のリストを拡張するために省略記号を使用します。これも正確なフォーマル性を提供しないもう一つの領域です(リストは StableHLO 型システムの一部でさえ理解できないなど)。

C)最後に注目すべき表記手段は暗黙的ブロードキャストです。StableHLO 演算セットは暗黙的なブロードキャストをサポートしていませんが、簡潔さのために式ではサポートしています。要約すると、テンソルが想定されるコンテキストでスカラーを使用すると、スカラーは想定される形状にブロードキャストされます。

dot_general の例を続けるには、別の制約として 0 <= lhs_batching_dimensions < rank(lhs) を使用します。dot_general 仕様で定義されているように、lhs_batching_dimensions はテンソルですが、0rank(lhs) はどちらもスカラーです。非明示的ブロードキャストを適用すると、式は [0, ..., 0] <= lhs_batching_dimensions < [rank(lhs), ..., rank(lhs)] になります。

特定の dot_general 演算に適用すると、この式はブール値のテンソルと評価されます。式を制約として使用する場合、式が true と評価された場合、または true 要素のみを持つテンソルに評価された場合に制約が保持されます。

名前

数式の字句スコープには、1)グローバル関数、2)メンバー定義、

3)ローカル定義。グローバル関数のリストを以下に示します。要素定義のリストは、表記が適用されるプログラム要素によって異なります。

  • オペレーションの場合、メンバー定義には「入力」セクションと「出力」セクションで導入された名前が含まれます。
  • それ以外については、メンバー定義にプログラム要素の構造的部分が含まれ、対応する EBNF の非終端名にちなんで名が付けられています。ほとんどの場合、これらの構造部分の名前は、非終端名をスネークケースに変換することで取得できます(例: IntegerLiteral => integer_literal)。ただし、プロセス内で名前が省略されることもあります(例: QuantizationStorageType => storage_type)。その場合は、オペレーションの「入力」セクションまたは「出力仕様」セクションと同様に、名前が明示的に紹介されます。
  • また、メンバー定義には、対応するプログラム要素を参照する self が常に含まれています。

Values

数式が評価されると、次の型の値が使用されます。1)Valuedense<[[1, 2], [3, 4]]> : tensor<2x2xi32> などの実際の値。型は常に知っている)、2)Placeholderlhsrhsresult などの将来の値。実際の値はまだ不明で、型のみがわかっている)、3)Type(型のセクションで定義された型)、4)Function(関数で定義されているグローバル関数)。

コンテキストによっては、名前が参照する値が異なる場合があります。具体的には、演算の「セマンティクス」セクション(および他のプログラム要素の同等のもの)でランタイム ロジックが定義されており、すべての入力を Value として使用できます。これに対して、演算(および同等のもの)の「制約」セクションでは、「コンパイル時」のロジック、つまり、通常は実行前に実行されるものが定義されています。そのため、定数入力のみが Value として使用でき、その他の入力は Placeholder としてのみ使用できます。

名前 「セマンティクス」での操作 [制約] で
グローバル機能 Function Function
定数入力 Value Value
非定数入力 Value Placeholder
出力 Value Placeholder
ローカルの定義 定義によって異なる 定義によって異なる

transpose オペレーションの例を考えてみましょう。

%result = "stablehlo.transpose"(%operand) {
  permutation = dense<[2, 1, 0]> : tensor<3xi64>
} : (tensor<2x3x2xi32>) -> tensor<2x3x2xi32>

このオペレーションでは、permutation は定数であるため、セマンティクスと制約の両方で Value として使用できます。これに対して、operandresult はセマンティクスでは Value として利用できますが、制約では Placeholder としてのみ使用できます。

関数

タイプの構築

型の作成に使用できる関数はありません。通常はより簡潔であるため、代わりに型構文を直接使用します。たとえば、function_type( [tensor_type([], E), tensor_type([], E)], [tensor_type([], E)]) ではなく (tensor<E>, tensor<E>) -> (tensor<E>) です。

型の関数

  • element_type は、テンソル型と量子化テンソル型で定義され、それぞれ対応する TensorType または QuantizedTensorTypeTensorElementType 部分または QuantizedTensorElementType 部分で返されます。
def element_type(x: Value | Placeholder | Type):
 if type(x) == TensorType:
    return tensor_element_type(x)
  if type(x) == QuantizedTensorType:
    return quantized_tensor_element_type(x)
  if type(x) is not Type:
    return element_type(type(x))
  • is_per_axis_quantized(x: Value | Placeholder | Type) -> Valueis_quantized(x) and quantization_dimension(x) is not None のショートカットです。

  • is_per_tensor_quantized(x: Value | Placeholder | Type) -> Valueis_quantized(x) and quantization_dimension(x) is None のショートカットです。

  • is_promotable(x: Type, y: Type) -> bool は、タイプ x をタイプ y に昇格できるかどうかを確認します。xyQuantizedTensorElementType の場合、プロモーションは storage_type にのみ適用されます。現在、この特定のバージョンのプロモーションは、削減の計算のコンテキストで使用されます(詳しくは、RFC をご覧ください)。

def is_promotable(x: Type, y: Type) -> Value:
  is_same_type = (is_bool(x) and is_bool(y)) or
    (is_integer(x) and is_integer(y)) or (is_float(x) and is_float(y)) or
    (is_complex(x) and is_complex(y)) or
    (is_quantized(x) and is_quantized(y) and expressed_type(x) = expressed_type(y))

  if is_same_type == False:
    return False

  if is_integer(x) or is_float(x):
    return bitwidth(x) <= bitwidth(y)

  if is_complex(x):
    return bitwidth(element_type(x)) <= bitwidth(element_type(y))

  if is_quantized(x):
    return bitwidth(storage_type(x)) <= bitwidth(storage_type(y))

  return false
  • is_quantized(x: Value | Placeholder | Type) -> Valueis_quantized_tensor_element_type(x) のショートカットです。

  • is_type_name(x: Value | Placeholder | Type) -> Value。すべてのタイプで使用できます。たとえば、xFloatType の場合、is_float(x)true を返します。x が値またはプレースホルダの場合、この関数は is_type_name(type(x)) のショートカットです。

  • max_value(x: Type) -> Value は、TensorElementType の最大値を返します。xTensorElementType でない場合は、None を返します。

  • min_value(x: Type) -> Value は、TensorElementType の可能な最小値を返します。xTensorElementType でない場合は、None を返します。

  • member_name(x: Value | Placeholder | Type) -> Any。すべてのタイプのすべてのメンバー定義 member_name で使用できます。たとえば、tensor_element_type(x) は、対応する TensorTypeTensorElementType 部分を返します。x が値またはプレースホルダの場合、この関数は member_name(type(x)) のショートカットです。x が、適切なメンバーを持たない型、またはそのような型の値またはプレースホルダを持たない場合、None を返します。

値の構成

  • operation_name(*xs: Value | Type) -> Value。すべてのオペレーションで使用できます。たとえば、add(lhs, rhs)lhsrhs の 2 つのテンソル値を受け取り、これらの入力で add 演算を評価した結果を返します。broadcast_in_dim などの一部のオペレーションでは、その出力のタイプは「耐荷重性」であり、オペレーションの評価に必要です。この場合、関数はこれらの型を引数として受け取ります。

値の関数

  • Python のすべての演算子と関数が使用できます。たとえば、Python では、サブスクリプション表記とスライス表記の両方を使用して、テンソル、量子化テンソル、タプルへのインデックスを作成できます。

  • to_destination_type(x: Value, destination_type: Type) -> Value はテンソルで定義され、次のように type(x)destination_type に基づいて x の変換された値を返します。

def to_destination_type(x: Value, destination_type: Type) -> Value:
  if type(x) == destination_type:
    return x

  if is_quantized(destination_type):
    if is_quantized(type(x)):
      return quantize(x, destination_type)
    assert is_float(type(x))
    return quantize(x, destination_type)

  if is_quantized(type(x)):
    assert destination_type = expressed_type(type(x))
    return dequantize(type(x))

  return convert(x, destination_type)

convertuniform_quantizeuniform_dequantize オペレーションのマージについては、早期に議論されています(#1576)。マージ後は、上記の関数は不要になり、代わりに convert のオペレーション名を使用できます。

  • is_nan(x: Value) -> Value はテンソルで定義され、x のすべての要素が NaN であれば true を、それ以外の場合は false を返します。x がテンソルでない場合、None を返します。

  • is_sorted(x: Value) -> Value はテンソルに対して定義され、x の要素がインデックスの辞書順の昇順で並べ替えられている場合は true を返し、それ以外の場合は false を返します。x がテンソルでない場合、None を返します。

  • is_unique(x: Value) -> Value はテンソルで定義され、x に重複する要素がない場合は true を返し、それ以外の場合は false を返します。x がテンソルでない場合、None を返します。

  • member_name(x: Value) -> Any は、すべての値のすべてのメンバー定義 member_name に対して定義されます。たとえば、real_part(x) は、対応する ComplexConstantRealPart 部分を返します。x が適切なメンバーを持つ値でない場合は、None を返します。

  • same(x: Value) -> Value はテンソルで定義され、x の要素がすべて互いに等しい場合は true を返し、そうでない場合は false を返します。テンソルに要素がない場合、「すべて等しい」と見なされます。つまり、関数は true を返します。x がテンソルでない場合、None を返します。

  • split(x: Value, num_results: Value, axis: Value) -> Value はテンソルで定義され、axis 軸上の xnum_results スライスを返します。x がテンソルまたは dim(x, axis) % num_results != 0 でない場合、None を返します。

形状計算

  • axes(x: Value | Placeholder | Type) -> Valuerange(rank(x)) のショートカットです。

  • dim(x: Value | Placeholder | Type, axis: Value) -> Valueshape(x)[axis] のショートカットです。

  • dims(x: Value | Placeholder | Type, axes: List) -> Listlist(map(lambda axis: dim(x, axis), axes)) のショートカットです。

  • index_space(x: Value | Placeholder | Type) -> Value はテンソルで定義され、対応する TensorTypesize(x) インデックスを辞書順([0, ..., 0][0, ..., 1]、...、shape(x) - 1)の昇順に並べ替えて返します。x がテンソル型、量子化テンソル型、値またはこれらの型のプレースホルダではない場合、None を返します。

  • rank(x: Value | Placeholder | Type) -> Valuesize(shape(x)) のショートカットです。

  • shape(x: Value | Placeholder | Type) -> Value は、member_name を介して「型の関数」セクションで定義されます。

  • size(x: Value | Placeholder | Type) -> Valuereduce(lambda x, y: x * y, shape(x)) のショートカットです。

量子化計算

  • def baseline_element_type(x: Value | Placeholder | Type) -> Typeelement_type(baseline_type(x)) のショートカットです。

  • baseline_type は、テンソル型と量子化テンソル型に対して定義され、「ベースライン」に変換されます。つまり、同じ形状でも要素型の量子化パラメータがデフォルト値にリセットされた型です。これは、テンソルと量子化テンソルの型を均一に比較する便利な手法として使用されます。これは頻繁に必要になります。量子化型の場合、量子化パラメータを無視して型を比較できます。つまり、shapestorage_typeexpressed_typestorage_minstorage_maxquantization_dimension(軸ごとの量子化型の場合)はすべて一致している必要がありますが、scaleszero points は異なる場合があります。

def baseline_type(x: Value | Placeholder | Type) -> Type:
  if type(x) == TensorType:
    return x
  if type(x) == QuantizedTensorType:
    element_type = quantized_tensor_element_type(x)
    baseline_element_type = QuantizedTensorElementType(
      storage_type = storage_type(element_type),
      storage_min = storage_min(element_type),
      storage_max = storage_max(element_type),
      expressed_type = expressed_type(element_type),
      quantization_dimension = quantization_dimension(element_type),
      scales = [constant(1.0, expressed_type(element_type))] * dim(x, quantization_dimension(element_type)),
      zero_points = [constant(0, storage_type(element_type))] * dim(x, quantization_dimension(element_type)))
    return QuantizedTensorType(shape(x), baseline_element_type)
  if type(x) is not Type:
    return baseline_element_type(type(x))
  • dequantize は量子化テンソル型に対して定義され、それらを浮動小数点テンソル型に変換します。これは、量子化された要素タイプに関連付けられたゼロ点とスケールを使用して、ストレージ タイプの整数値を表す量子化要素を、表現された型の対応する浮動小数点値に変換することによって行われます。
def compute_zero_points(quantized_type, result_type):
  if is_per_tensor_quantized(quantized_type):
    return broadcast_in_dim(constant(zero_point(quantized_type), storage_type(quantized_type)), [], result_type)
  if is_per_axis_quantized(quantized_type):
    for i in index_space(result_type):
      d = quantization_dimension(quantized_type)
      zero_points[i] = zero_points(quantized_type)[i[d]]
    return zero_points

def compute_scales(quantized_type, result_type):
  if is_per_tensor_quantized(quantized_type):
    return broadcast_in_dim(constant(scale(quantized_type), expressed_type(quantized_type)), [],
            type(result_type))
  if is_per_axis_quantized(quantized_type):
    for i in index_space(result_type):
      d = quantization_dimension(quantized_type)
      scales[i] = scales(quantized_type)[i[d]]
    return scales

def dequantize(x: Value) -> Value:
  assert is_quantized(x)
  x_storage = bitcast_convert(x, storage_type(x))
  x_storage_sub = x_storage - compute_zero_points(type(x), type(x_storage))
  x_expressed_sub = convert(x_storage_sub, expressed_type(x))
  return x_expressed_sub * compute_scales(type(x), type(x_expressed_sub))
  • quantize は浮動小数点テンソル型で定義され、それを量子化テンソル型に変換します。これは、量子化された要素型に関連付けられたゼロポイントとスケールを使用して、表現された型の浮動小数点値を対応するストレージ タイプの整数値に変換することによって行われます。
def quantize(x: Value, type: Type) -> Value:
  assert is_float(x) and is_quantized(type)
  x_expressed_rounded = round_nearest_even(x / compute_scales(type, type(x)))
  x_storage_rounded = convert(x_expressed_rounded, storage_type(type))
  x_storage_add = x_storage_rounded + compute_zero_points(type, type(x_storage_rounded))
  x_storage = clamp(storage_min(type), x_storage_add, storage_max(type))
  return bitcast_convert(x_storage, type)
  • dequantize_op_quantize は、量子化テンソルの要素単位の計算を指定するために使用します。逆量子化(量子化された要素を表現された型に変換)してから演算を実行してから、量子化します。つまり、その結果をストレージの型に戻します。現時点では、この関数はテンソルごとの量子化でのみ機能します。軸ごとの量子化が進行中です(#1574)。
def dequantize_op_quantize(op, *inputs_and_output_type):
  inputs = inputs_and_output_type[:-1]
  output_type = inputs_and_output_type[-1]

  float_inputs = map(dequantize, inputs)
  float_result = op(*float_inputs)
  return quantize(float_result, output_type)

def dequantize_batch_norm_grad_or_training_quantize(op, *inputs_and_output_types):
  inputs = inputs_and_output_type[:-3]
  float_inputs = map(dequantize, inputs)
  float_results = op(*float_inputs)
  return map(quantize, float_results, inputs_and_output_type[-3:])

def dequantize_compare(lhs, rhs, comparison_direction):
  float_lhs = dequantize(lhs)
  float_rhs = dequantize(rhs)
  return compare(float_lhs, float_rhs, comparison_direction, FLOAT)

def dequantize_select_quantize(pred, on_true, on_false, output_type):
  float_on_true = dequantize(on_true)
  float_on_false = dequantize(on_false)
  float_result = select(pred, float_on_true, float_on_false)
  return quantize(float_result, output_type)

グリッド計算

  • cross_partition(replica_groups: Value) -> Value。上記の「cross_replica」セクションをご覧ください。

  • cross_replica(replica_groups: Value) -> Value。上記の「cross_replica」セクションをご覧ください。

  • cross_replica_and_partition(replica_groups: Value) -> Value。上記の「cross_replica_and_partition」セクションをご覧ください。

  • flattened_ids(replica_groups: Value) -> Value。上記の「flattened_ids」セクションをご覧ください。