Spesifikasi StableHLO

StableHLO adalah kumpulan operasi untuk operasi tingkat tinggi (HLO) dalam model machine learning (ML). StableHLO berfungsi sebagai lapisan portabilitas antara berbagai framework ML dan compiler ML: framework ML yang menghasilkan program StableHLO kompatibel dengan compiler ML yang menggunakan program StableHLO.

Sasaran kami adalah menyederhanakan dan mempercepat pengembangan ML dengan menciptakan lebih banyak interoperabilitas antara berbagai framework ML (seperti TensorFlow, JAX, dan PyTorch) serta compiler ML (seperti XLA dan IREE). Untuk mencapai tujuan tersebut, dokumen ini memberikan spesifikasi untuk bahasa pemrograman StableHLO.

Spesifikasi ini memuat tiga bagian utama. Pertama, bagian Program menjelaskan struktur program StableHLO yang terdiri dari fungsi StableHLO yang terdiri dari operasi StableHLO. Dalam struktur tersebut, bagian Ops menentukan semantik setiap operasi. Bagian Execution menyediakan semantik untuk semua operasi ini yang dijalankan bersama dalam program. Terakhir, bagian Notasi membahas notasi yang digunakan di seluruh spesifikasi.

Untuk melihat spesifikasi dari rilis StableHLO sebelumnya, buka repo di rilis yang diberi tag minat. Misalnya, Spesifikasi StableHLO v0.19.0. Untuk melihat perubahan yang terjadi pada setiap peningkatan versi minor StableHLO, lihat log versi di VhloDialect.td.

Program

Program ::= {Func}

Program StableHLO terdiri dari sejumlah fungsi StableHLO yang tidak ditentukan. Berikut adalah contoh program dengan fungsi @main yang memiliki 3 input (%image, %weights, dan %bias) dan 1 output. Isi fungsi memiliki 6 operasi.

func.func @main(
  %image: tensor<28x28xf32>,
  %weights: tensor<784x10xf32>,
  %bias: tensor<1x10xf32>
) -> tensor<1x10xf32> {
  %0 = "stablehlo.reshape"(%image) : (tensor<28x28xf32>) -> tensor<1x784xf32>
  %1 = "stablehlo.dot"(%0, %weights) : (tensor<1x784xf32>, tensor<784x10xf32>) -> tensor<1x10xf32>
  %2 = "stablehlo.add"(%1, %bias) : (tensor<1x10xf32>, tensor<1x10xf32>) -> tensor<1x10xf32>
  %3 = "stablehlo.constant"() {value = dense<0.0> : tensor<1x10xf32>} : () -> tensor<1x10xf32>
  %4 = "stablehlo.maximum"(%2, %3) : (tensor<1x10xf32>, tensor<1x10xf32>) -> tensor<1x10xf32>
  "func.return"(%4): (tensor<1x10xf32>) -> ()
}

Fungsi

Func        ::= 'func' '.' 'func' FuncId FuncInputs FuncOutputs '{' FuncBody '}'
FuncInputs  ::= '(' [FuncInput {',' FuncInput}] `)`
FuncInput   ::= ValueId ':' ValueType
FuncOutputs ::= ['->' FuncOutput, {',' FuncOutput}]
FuncOutput  ::= ValueType
FuncBody    ::= {Op}

Fungsi StableHLO (yang juga disebut fungsi bernama) memiliki ID, input/output, dan isi. Di masa mendatang, kami berencana untuk memperkenalkan metadata tambahan untuk fungsi guna mencapai kompatibilitas yang lebih baik dengan HLO (#425, #626, #740, #744).

Pengenal

FuncId  ::= '@' letter {letter | digit}
ValueId ::= '%' digit {digit}
          | '%' letter {letter | digit}
letter  ::= 'a' | ... | 'z' | 'A' | ... | 'Z' | '_'
digit   ::= '0' | ... | '9'

ID StableHLO mirip dengan ID dalam banyak bahasa pemrograman, dengan dua keunikan: 1) semua ID memiliki sigil yang membedakan berbagai jenis ID, 2) ID nilai dapat sepenuhnya numerik untuk menyederhanakan pembuatan program StableHLO.

Jenis

Type         ::= ValueType | NonValueType
ValueType    ::= TensorType | QuantizedTensorType | TokenType | TupleType
NonValueType ::= TensorElementType | QuantizedTensorElementType | FunctionType | StringType

Jenis StableHLO dikategorikan ke dalam jenis nilai (yang juga disebut jenis kelas satu) yang mewakili nilai StableHLO dan jenis non-nilai yang menjelaskan elemen program lainnya. Jenis StableHLO mirip dengan jenis dalam banyak bahasa pemrograman, dengan keunikan utamanya adalah sifat khusus domain StableHLO yang menghasilkan beberapa hasil yang tidak biasa (misalnya, jenis skalar bukan merupakan jenis nilai).

TensorType ::= 'tensor' '<' Shape TensorElementType '>'
Shape ::= {DimensionSize 'x'}
DimensionSize ::= digit {digit} | '?'

Jenis tensor merepresentasikan tensor, yaitu array multidimensi. Elemen ini memiliki bentuk dan jenis elemen, dengan bentuk mewakili ukuran dimensi non-negatif atau tidak diketahui dalam urutan menaik dari dimensi yang sesuai (yang juga disebut sumbu) yang diberi nomor dari 0 hingga R-1. Jumlah dimensi R disebut peringkat. Misalnya, tensor<2x3xf32> adalah jenis tensor dengan bentuk 2x3 dan jenis elemen f32. Matriks ini memiliki dua dimensi (atau, dengan kata lain, dua sumbu) - dimensi ke-0 dan dimensi ke-1 - yang ukurannya adalah 2 dan 3. Peringkatnya adalah 2.

Bentuk dapat sebagian atau sepenuhnya tidak diketahui (dinamis), misalnya tensor<?x2xf64> sebagian tidak diketahui dan tensor<?x?xf64> sepenuhnya tidak diketahui. Ukuran dimensi dinamis direpresentasikan menggunakan ?. Bentuk tidak boleh diberi peringkat.

Di masa mendatang, kami berencana untuk mempelajari perluasan jenis tensor di luar ukuran dimensi dan jenis elemen, misalnya, untuk menyertakan tata letak (#629) dan sparsitas (#1078).

QuantizedTensorType ::= 'tensor' '<' Shape QuantizedTensorElementType '>'
QuantizedTensorElementType ::= '!quant.uniform' '<'
                  QuantizationStorageType
                  ['<' QuantizationStorageMin ':' QuantizationStorageMax '>']
                  ':' QuantizationExpressedType
                  [':' QuantizationDimension]
                  ',' QuantizationParameters '>'
QuantizationStorageType ::= IntegerType
QuantizationStorageMin ::= IntegerLiteral
QuantizationStorageMax ::= IntegerLiteral
QuantizationExpressedType ::= FloatType
QuantizationDimension ::= IntegerLiteral
QuantizationParameters ::= QuantizationParameter
                         | '{' QuantizationParameter {',' QuantizationParameter} '}'
QuantizationParameter ::= QuantizationScale [':' QuantizationZeroPoint]
QuantizationScale ::= FloatLiteral
QuantizationZeroPoint ::= IntegerLiteral
Nama Jenis Batasan
storage_type jenis bilangan bulat (C1-C3), (C8)
storage_min konstanta bilangan bulat (C1), (C3), (C7)
storage_max konstanta bilangan bulat (C2), (C3), (C7)
expressed_type jenis floating point (C4)
quantization_dimension konstanta bilangan bulat opsional (C10-C12)
scales jumlah variabel konstanta floating point (C4-C6), (C9), (C10), (C13)
zero_points jumlah variabel konstanta integer (C7-C9)

Jenis elemen kuantisasi mewakili nilai bilangan bulat dari jenis penyimpanan dalam rentang dari storage_min hingga storage_max (inklusif) yang sesuai dengan nilai floating point dari jenis yang dinyatakan. Untuk nilai bilangan bulat i tertentu, nilai floating point yang sesuai dengan f dapat dikomputasi sebagai f = (i - zero_point) * scale, dengan scale dan zero_point disebut parameter kuantisasi. storage_min dan storage_max bersifat opsional dalam tata bahasa, tetapi memiliki nilai default min_value(storage_type) dan max_value(storage_type). Jenis elemen terkuantisasi memiliki batasan berikut:

  • (C1) type(storage_min) = storage_type.
  • (C2) type(storage_max) = storage_type.
  • (C3) min_value(storage_type) <= storage_min < storage_max <= max_value(storage_type).
  • (C4) type(scales...) = expressed_type.
  • (C5) 0 < scales.
  • (C6) is_finite(scales...).
  • (C7) storage_min <= zero_points <= storage_max.
  • (C8) type(zero_points...) = storage_type.
  • (C9) size(scales) = size(zero_points).
  • (C10) Jika is_empty(quantization_dimension), maka size(scales) = 1.
  • (C11) 0 <= quantization_dimension.

Saat ini, QuantizationScale adalah konstanta floating point, tetapi ada minat yang kuat pada skala berbasis bilangan bulat, yang direpresentasikan dengan pengali dan pergeseran. Kami berencana untuk mempelajarinya dalam waktu dekat (#1404).

Ada diskusi yang sedang berlangsung tentang semantik QuantizationZeroPoint, termasuk jenis, nilai, dan apakah hanya ada satu atau berpotensi beberapa titik nol dalam jenis tensor kuantisasi. Berdasarkan hasil diskusi ini, spesifikasi seputar titik nol dapat berubah di masa mendatang (#1405).

Diskusi lain yang sedang berlangsung melibatkan semantik QuantizationStorageMin dan QuantizationStorageMax untuk menentukan apakah batasan apa pun harus diterapkan pada nilai ini dan pada nilai tensor kuantisasi (#1406).

Terakhir, kami berencana untuk mengeksplorasi representasi skala dan titik nol yang tidak diketahui, mirip dengan cara kami berencana untuk mengeksplorasi representasi ukuran dimensi yang tidak diketahui (#1407).

Jenis tensor terkuantisasi merepresentasikan tensor dengan elemen terkuantisasi. Tensor ini sama persis dengan tensor reguler, kecuali elemennya memiliki jenis elemen yang dikuantisasi, bukan jenis elemen reguler.

Dalam tensor kuantisasi, kuantisasi dapat berupa per-tensor, yang berarti memiliki satu scale dan zero_point untuk seluruh tensor atau dapat berupa per-sumbu, yang berarti memiliki beberapa scales dan zero_points, satu pasangan per slice dimensi tertentu quantization_dimension. Secara lebih formal, dalam tensor t dengan kuantisasi per sumbu, ada dim(t, quantization_dimension) slice dari quantization_dimension: t[:, ..., 0, ..., :], t[:, ..., 1, ..., :], dll. Semua elemen dalam slice ke-i menggunakan scales[i] dan zero_points[i] sebagai parameter kuantisasi. Jenis tensor terkuantisasi memiliki batasan berikut:

  • Untuk kuantisasi per tensor:
    • Tidak ada batasan tambahan.
  • Untuk kuantisasi per sumbu:
    • (C12) quantization_dimension < rank(self).
    • (C13) dim(self, quantization_dimension) = size(scales).
TokenType ::= 'token'

Jenis token mewakili token, yaitu nilai buram yang dihasilkan dan digunakan oleh beberapa operasi. Token digunakan untuk menerapkan urutan eksekusi pada operasi seperti yang dijelaskan di bagian Eksekusi.

TupleType ::= 'tuple' '<' TupleElementTypes '>'
TupleElementTypes ::= [ValueType {',' ValueType}]

Jenis tuple mewakili tuple, yaitu daftar heterogen. Tuple adalah fitur lama yang hanya ada untuk kompatibilitas dengan HLO. Di HLO, tuple digunakan untuk merepresentasikan input dan output variadik. Di StableHLO, input dan output variadik didukung secara native, dan satu-satunya penggunaan tuple di StableHLO adalah untuk mewakili HLO ABI secara komprehensif, misalnya T, tuple<T>, dan tuple<tuple<T>> mungkin secara material berbeda bergantung pada implementasi tertentu. Di masa mendatang, kami berencana untuk membuat perubahan pada HLO ABI yang dapat memungkinkan kita menghapus jenis tuple dari StableHLO (#598).

TensorElementType ::= BooleanType | IntegerType | FloatType | ComplexType
BooleanType ::= 'i1'
IntegerType ::= SignedIntegerType | UnsignedIntegerType
SignedIntegerType ::= 'si2' | 'si4' | 'si8' | 'si16' | 'si32' | 'si64'
UnsignedIntegerType ::= 'ui2' | 'ui4' | 'ui8' | 'ui16' | 'ui32' | 'ui64'
FloatType ::= 'f4E2M1FN' | 'f6E2M3FN' | 'f6E3M2FN' | 'f8E3M4' | 'f8E4M3'
            | 'f8E4M3FN' | 'f8E4M3FNUZ' | 'f8E4M3B11FNUZ' | 'f8E5M2'
            | 'f8E5M2FNUZ' | 'f8E8M0FNU' | 'bf16' | 'f16' | 'f32' | 'f64'
TensorFloat32 ::= 'tf32'
ComplexType ::= 'complex' '<' ComplexElementType '>'
ComplexElementType ::= 'f32' | 'f64'

Jenis elemen mewakili elemen jenis tensor. Tidak seperti di banyak bahasa pemrograman, jenis ini bukan kelas satu di StableHLO. Artinya, program StableHLO tidak dapat langsung merepresentasikan nilai dari jenis ini (akibatnya, idiomatik untuk merepresentasikan nilai skalar dari jenis T dengan nilai tensor dimensi 0 dari jenis tensor<T>).

  • Jenis boolean mewakili nilai boolean true dan false.
  • Jenis bilangan bulat dapat bertanda (si) atau tidak bertanda (ui) dan memiliki salah satu lebar bit yang didukung (2, 4, 8, 16, 32, atau 64). Jenis siN bertanda mewakili nilai bilangan bulat dari -2^(N-1) hingga 2^(N-1)-1 inklusif, dan jenis uiN tanpa tanda mewakili nilai bilangan bulat dari 0 hingga 2^N-1 inklusif.
  • Jenis floating point dapat berupa salah satu dari berikut:
  • Jenis kompleks mewakili nilai kompleks yang memiliki bagian riil dan bagian imajiner dari jenis elemen yang sama. Jenis kompleks yang didukung adalah complex<f32> (kedua bagiannya berjenis f32) dan complex<f64> (kedua bagiannya berjenis f64).
FunctionType ::= '(' InputTypes ')' '->' '(' OutputTypes ')'
InputTypes ::= [ValueType {',' ValueType}]
OutputTypes ::= [ValueType {',' ValueType}]

Jenis fungsi mewakili fungsi bernama dan anonim. Fungsi ini memiliki jenis input (daftar jenis di sisi kiri ->) dan jenis output (daftar jenis di sisi kanan ->). Dalam banyak bahasa pemrograman, jenis fungsi adalah kelas satu, tetapi tidak di StableHLO.

StringType ::= 'string'

Jenis string mewakili urutan byte. Tidak seperti dalam banyak bahasa pemrograman, jenis string bukan kelas pertama di StableHLO dan hanya digunakan untuk menentukan metadata statis untuk elemen program.

Operasi

Operasi StableHLO (yang juga disebut ops) mewakili kumpulan tertutup operasi tingkat tinggi dalam model machine learning. Seperti yang telah dibahas di atas, sintaksis StableHLO sangat terinspirasi oleh MLIR, yang belum tentu merupakan alternatif paling ergonomis, tetapi dapat dibilang paling sesuai dengan sasaran StableHLO untuk membuat lebih banyak interoperabilitas antara framework ML dan compiler ML.

Op            ::= [OpOutputs] OpName OpInputs ':' OpSignature
OpName        ::= '"' 'stablehlo' '.' OpMnemonic '"'
OpMnemonic    ::= 'abs' | 'add' | ...

Operasi StableHLO (yang juga disebut ops) memiliki nama, input/output, dan tanda tangan. Nama terdiri dari awalan stablehlo. dan mnemonik yang secara unik mengidentifikasi salah satu operasi yang didukung. Lihat di bawah untuk mengetahui daftar lengkap semua operasi yang didukung.

OpInputs        ::= OpInputValues OpInputFuncs OpInputAttrs
OpInputValues   ::= '(' [OpInputValue {',' OpInputValue}] ')'
OpInputValue    ::= ValueId
OpInputFuncs    ::= ['(' OpInputFunc {',' OpInputFunc} ')']
OpInputAttrs    ::= ['{' OpInputAttr {',' OpInputAttr} '}']
OpOutputs       ::= [OpOutput {',' OpOutput} '=']
OpOutput        ::= ValueId

Ops menggunakan input dan menghasilkan output. Input dikategorikan ke dalam nilai input (dihitung selama eksekusi), fungsi input (disediakan secara statis, karena dalam StableHLO bukan merupakan nilai kelas satu), dan atribut input (juga disediakan secara statis). Jenis input dan output yang digunakan dan dihasilkan oleh op bergantung pada mnemoninya. Misalnya, operasi add menggunakan 2 nilai input dan menghasilkan 1 nilai output. Sebagai perbandingan, op select_and_scatter menggunakan 3 nilai input, 2 fungsi input, dan 3 atribut input.

OpInputFunc ::= '{' Unused FuncInputs ':' FuncBody '}'
Unused      ::= '^' digit {digit}
              | '^' letter {letter | digit}

Fungsi input (yang juga disebut fungsi anonim) sangat mirip dengan fungsi bernama, kecuali: 1) tidak memiliki ID (sehingga namanya "anonim"), 2) tidak mendeklarasikan jenis output (jenis output disimpulkan dari operasi return dalam fungsi).

Sintaksis untuk fungsi input mencakup bagian yang saat ini tidak digunakan (lihat produksi Unused di atas) yang ada untuk kompatibilitas dengan MLIR. Di MLIR, ada konsep "region" yang lebih umum yang dapat memiliki beberapa "blok" operasi yang terhubung bersama melalui operasi lompat. Blok ini memiliki ID yang sesuai dengan produksi Unused, sehingga dapat dibedakan satu sama lain. StableHLO tidak memiliki operasi lompat, sehingga bagian yang sesuai dari sintaksis MLIR tidak digunakan (tetapi masih ada).

OpInputAttr      ::= OpInputAttrName '=' OpInputAttrValue
OpInputAttrName  ::= letter {letter | digit}
OpInputAttrValue ::= Constant

Atribut input memiliki nama dan nilai yang merupakan salah satu konstanta yang didukung. Class ini adalah cara utama dalam menentukan metadata statis untuk elemen program. Misalnya, operasi concatenate menggunakan atribut dimension untuk menentukan dimensi tempat nilai inputnya digabungkan. Demikian pula, op slice menggunakan beberapa atribut seperti start_indices dan limit_indices untuk menentukan batas yang digunakan untuk memotong nilai input.

Saat ini, program StableHLO di dunia nyata terkadang berisi atribut yang tidak dijelaskan dalam dokumen ini. Di masa mendatang, kami berencana untuk menyerap atribut ini ke dalam opset StableHLO atau melarangnya muncul dalam program StableHLO. Sementara itu, berikut adalah daftar atribut ini:

  • layout (#629).
  • mhlo.frontend_attributes (#628).
  • mhlo.sharding (#619).
  • output_operand_aliases (#740).
  • Metadata lokasi (#594).
OpSignature ::= '(' [ValueType {',' ValueType}] ')' '->' '(' [ValueType {',' ValueType}] ')'

Tanda tangan op terdiri dari jenis semua nilai input (daftar jenis di sisi kiri ->) dan jenis semua nilai output (daftar jenis di sisi kanan ->). Sederhananya, jenis input bersifat redundan, dan jenis output hampir selalu redundan (karena untuk sebagian besar operasi StableHLO, jenis output dapat disimpulkan dari input). Meskipun demikian, tanda tangan op sengaja menjadi bagian dari sintaksis StableHLO untuk kompatibilitas dengan MLIR.

Berikut adalah contoh operasi yang mnemoninya adalah select_and_scatter. Fungsi ini menggunakan 3 nilai input (%operand, %source, dan %init_value), 2 fungsi input, dan 3 atribut input (window_dimensions, window_strides, dan padding). Perhatikan bagaimana tanda tangan operasi hanya menyertakan jenis nilai inputnya (tetapi bukan jenis fungsi dan atribut input yang disediakan secara inline).

%result = "stablehlo.select_and_scatter"(%operand, %source, %init_value) ({
  ^bb0(%arg0: tensor<i32>, %arg1: tensor<i32>):
    %0 = "stablehlo.compare"(%arg0, %arg1) {
      comparison_direction = #stablehlo<comparison_direction GE>
    } : (tensor<i32>, tensor<i32>) -> tensor<i1>
    "stablehlo.return"(%0) : (tensor<i1>) -> ()
}, {
  ^bb0(%arg0: tensor<i32>, %arg1: tensor<i32>):
    %0 = "stablehlo.add"(%arg0, %arg1) : (tensor<i32>, tensor<i32>) -> tensor<i32>
    "stablehlo.return"(%0) : (tensor<i32>) -> ()
}) {
  window_dimensions = dense<[3, 1]> : tensor<2xi64>,
  window_strides = dense<[2, 1]> : tensor<2xi64>,
  padding = dense<[[0, 1], [0, 0]]> : tensor<2x2xi64>
} : (tensor<4x2xi32>, tensor<2x2xi32>, tensor<i32>) -> tensor<4x2xi32>

Konstanta

Constant ::= BooleanConstant
           | IntegerConstant
           | FloatConstant
           | ComplexConstant
           | TensorConstant
           | QuantizedTensorConstant
           | StringConstant
           | EnumConstant

Konstanta StableHLO memiliki literal dan jenis yang bersama-sama mewakili nilai StableHLO. Umumnya, jenis ini adalah bagian dari sintaksis konstanta, kecuali jika tidak ambigu (misalnya, konstanta boolean memiliki jenis i1 secara tidak ambigu, sedangkan konstanta bilangan bulat dapat memiliki beberapa kemungkinan jenis).

BooleanConstant ::= BooleanLiteral
BooleanLiteral  ::= 'true' | 'false'

Konstanta Boolean mewakili nilai boolean true dan false. Konstanta Boolean memiliki jenis i1.

IntegerConstant   ::= IntegerLiteral ':' IntegerType
IntegerLiteral    ::= ['-' | '+'] DecimalDigits
                    | ['-' | '+'] '0x' HexadecimalDigits
DecimalDigits     ::= decimalDigit {decimalDigit}
HexadecimalDigits ::= hexadecimalDigit {hexadecimalDigit}
decimalDigit      ::= '0' | ... | '9'
hexadecimalDigit  ::= decimalDigit | 'a' | ... | 'f' | 'A' | ... | 'F'

Konstanta bilangan bulat merepresentasikan nilai bilangan bulat melalui string yang menggunakan notasi desimal atau heksadesimal. Basis lain, misalnya biner atau oktal, tidak didukung. Konstanta bilangan bulat memiliki batasan berikut:

  • (C1) is_wellformed(integer_literal, integer_type).
FloatConstant  ::= FloatLiteral ':' FloatType
FloatLiteral   ::= SignPart IntegerPart FractionalPart ScientificPart
                 | '0x' [HexadecimalDigits]
SignPart       ::= ['-' | '+']
IntegerPart    ::= DecimalDigits
FractionalPart ::= ['.' [DecimalDigits]]
ScientificPart ::= [('e' | 'E') ['-' | '+'] DecimalDigits]

Konstanta floating point mewakili nilai floating point melalui string yang menggunakan notasi desimal atau notasi ilmiah. Selain itu, notasi heksadesimal dapat digunakan untuk menentukan secara langsung bit yang mendasarinya dalam format floating point dari jenis yang sesuai. Konstanta floating point memiliki batasan berikut:

  • (C1) Jika notasi non-heksadesimal digunakan, is_wellformed(float_literal, float_type).
  • (C2) Jika notasi heksadesimal digunakan, size(hexadecimal_digits) = num_bits(float_type) / 4.
ComplexConstant ::= ComplexLiteral ':' ComplexType
ComplexLiteral  ::= '(' RealPart ',' ImaginaryPart ')'
RealPart        ::= FloatLiteral
ImaginaryPart   ::= FloatLiteral

Konstanta kompleks mewakili nilai kompleks menggunakan daftar bagian real (didahulukan) dan bagian imajiner (didahulukan). Misalnya, (1.0, 0.0) : complex<f32> mewakili 1.0 + 0.0i, dan (0.0, 1.0) : complex<f32> mewakili 0.0 + 1.0i. Urutan penyimpanan bagian-bagian ini dalam memori ditentukan oleh implementasi. Konstanta kompleks memiliki batasan berikut:

  • (C1) is_wellformed(real_part, complex_element_type(complex_type)).
  • (C2) is_wellformed(imaginary_part, complex_element_type(complex_type)).
TensorConstant ::= TensorLiteral ':' TensorType
TensorLiteral  ::= 'dense' '<' (DenseLiteral | ElementLiteral) '>'
DenseLiteral   ::= DenseDimension | DenseElements
DenseDimension ::= '[' [DenseLiteral {',' DenseLiteral}] ']'
DenseElements  ::= [ElementLiteral {',' ElementLiteral}]
ElementLiteral ::= BooleanLiteral | IntegerLiteral | FloatLiteral | ComplexLiteral

Konstanta tensor merepresentasikan nilai tensor menggunakan daftar bertingkat yang ditentukan melalui notasi NumPy. Misalnya, dense<[[1, 2, 3], [4, 5, 6]]> : tensor<2x3xi32> merepresentasikan nilai tensor dengan pemetaan berikut dari indeks ke elemen: {0, 0} => 1, {0, 1} => 2, {0, 2} => 3, {1, 0} => 4, {1, 1} => 5, {1, 2} => 6. Urutan penyimpanan elemen ini dalam memori ditentukan oleh implementasi. Konstanta tensor memiliki batasan berikut:

  • (C1) has_syntax(tensor_literal, element_type(tensor_type)), dengan:
    • has_syntax(element_literal: Syntax, element_type: Type) = is_wellformed(element_literal, type).
    • has_syntax(tensor_literal: List, element_type: Type) = has_syntax(tensor_literal..., element_type).
  • (C2) has_shape(tensor_literal, shape(tensor_type)), dengan:
    • has_shape(element_literal: Syntax, []) = true.
    • has_shape(tensor_literal: List, shape: List) = size(tensor_literal) = shape[0] and has_shape(tensor_literal..., shape[1:]).
    • jika tidak, false.
QuantizedTensorConstant ::= QuantizedTensorLiteral ':' QuantizedTensorType
QuantizedTensorLiteral  ::= 'dense' '<' (DenseLiteral | ElementLiteral) '>'

Konstanta tensor kuantisasi mewakili nilai tensor kuantisasi menggunakan notasi yang sama dengan konstanta tensor, dengan elemen yang ditentukan sebagai konstanta dari jenis penyimpanannya. Konstanta tensor terkuantisasi memiliki batasan berikut:

  • (C1) has_syntax(quantized_tensor_literal, storage_type(quantized_tensor_type)).
  • (C2) has_shape(quantized_tensor_literal, shape(quantized_tensor_type)).
StringConstant  ::= StringLiteral
StringLiteral   ::= '"' {stringCharacter | escapeSequence} '"'
stringCharacter ::= all ASCII characters except '\00', '\01', ... '\1f' and '"'
escapeSequence  ::= '\' ('"' | '\' | 'n' | 't' | (hexadecimalDigit hexadecimalDigit))

Literal string terdiri dari byte yang ditentukan menggunakan karakter ASCII dan urutan escape. Byte ini tidak bergantung pada encoding, sehingga interpretasi byte ini ditentukan oleh implementasi. String literal memiliki jenis string.

Operasi

abs

Semantik

Melakukan operasi abs element-wise pada tensor operand dan menghasilkan tensor result. Bergantung pada jenis elemen, lakukan tindakan berikut:

  • Untuk bilangan bulat bertanda tangan: modulus bilangan bulat.
  • Untuk float: abs dari IEEE-754.
  • Untuk bilangan kompleks: modulus kompleks.
  • Untuk jenis terkuantisasi: dequantize_op_quantize(abs, operand, type(result)).

Input

Label Nama Jenis Batasan
(I1) operand tensor bilangan bulat bertanda tangan, floating point, atau jenis kompleks atau tensor terkuantisasi per tensor (C1-C2)

Output

Nama Jenis Batasan
result tensor dari jenis bilangan bulat bertanda atau floating point atau tensor terkuantisasi per tensor (C1-C2)

Batasan

  • (C1) shape(result) = shape(operand).
  • (C2) baseline_element_type(result) ditentukan sebagai:
    • complex_element_type(element_type(operand)) jika is_complex(operand).
    • baseline_element_type(operand) jika tidak.

Contoh

// %operand: [-2, 0, 2]
%result = "stablehlo.abs"(%operand) : (tensor<3xi32>) -> tensor<3xi32>
// %result: [2, 0, 2]

 Contoh Lainnya

tambahkan

Semantik

Melakukan penambahan element-wise dari dua tensor lhs dan rhs dan menghasilkan tensor result. Bergantung pada jenis elemen, lakukan tindakan berikut:

  • Untuk boolean: OR logis.
  • Untuk bilangan bulat: penambahan bilangan bulat.
  • Untuk float: addition dari IEEE-754.
  • Untuk bilangan kompleks: penambahan kompleks.
  • Untuk jenis terkuantisasi: dequantize_op_quantize(add, lhs, rhs, type(result)).

Input

Label Nama Jenis Batasan
(I1) lhs tensor atau tensor terkuantisasi (C1-C6)
(I2) rhs tensor atau quantized tensor (C1-C5), (C7)

Output

Nama Jenis Batasan
result tensor atau tensor terkuantisasi (C1-C7)

Batasan

  • Jika operasi menggunakan tensor yang tidak dikuantisasi:
    • (C1) type(lhs) = type(rhs) = type(result).
  • Jika operasi menggunakan tensor kuantisasi:
    • (C2) is_quantized(lhs) and is_quantized(rhs) and is_quantized(result).
    • (C3) storage_type(lhs) = storage_type(rhs) = storage_type(result).
    • (C4) expressed_type(lhs) = expressed_type(rhs) = expressed_type(result).
    • (C5) (is_per_axis_quantized(lhs) or is_per_axis_quantized(rhs)) = is_per_axis_quantized(result).
    • (C6) Jika is_per_axis_quantized(lhs), maka quantization_dimension(lhs) = quantization_dimension(result).
    • (C7) Jika is_per_axis_quantized(rhs), maka quantization_dimension(rhs) = quantization_dimension(result).

Contoh

// %lhs: [[1, 2], [3, 4]]
// %rhs: [[5, 6], [7, 8]]
%result = "stablehlo.add"(%lhs, %rhs) : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
// %result: [[6, 8], [10, 12]]

Contoh Lainnya

after_all

Semantik

Memastikan bahwa operasi yang menghasilkan inputs dieksekusi sebelum operasi apa pun yang bergantung pada result. Eksekusi operasi ini tidak melakukan apa pun, hanya ada untuk menetapkan dependensi data dari result ke inputs.

Input

Label Nama Jenis
(I1) inputs jumlah variabel token

Output

Nama Jenis
result token

Contoh

// %input0: !stablehlo.token
// %input1: !stablehlo.token
%result = "stablehlo.after_all"(%input0, %input1) : (!stablehlo.token, !stablehlo.token) -> !stablehlo.token

 Contoh Lainnya

all_gather

Semantik

Dalam setiap grup proses di petak proses StableHLO, gabungkan nilai tensor operands dari setiap proses di sepanjang all_gather_dim dan hasilkan tensor results.

Operasi ini membagi petak proses StableHLO menjadi process_groups yang ditentukan sebagai berikut:

  • cross_replica(replica_groups) if channel_id <= 0 and use_global_device_ids = false.
  • cross_replica_and_partition(replica_groups) jika channel_id > 0 and use_global_device_ids = false.
  • flattened_ids(replica_groups) if channel_id > 0 and use_global_device_ids = true.

Setelah itu, dalam setiap process_group:

  • operands...@receiver = [operand@sender for sender in process_group] untuk semua receiver di process_group.
  • results...@process = concatenate(operands...@process, all_gather_dim) untuk semua process di process_group.

Input

Label Nama Jenis Batasan
(I1) operands jumlah variabel tensor atau tensor terkuantisasi per tensor (C1), (C6)
(I2) all_gather_dim konstanta dari jenis si64 (C1), (C6)
(I3) replica_groups Konstanta tensor 2 dimensi jenis si64 (C2-C4)
(I4) channel_id konstanta dari jenis si64 (C5)
(I5) use_global_device_ids konstanta jenis i1 (C5)

Output

Nama Jenis Batasan
results jumlah variabel tensor atau tensor terkuantisasi per tensor (C6)

Batasan

  • (C1) 0 <= all_gather_dim < rank(operands...).
  • (C2) is_unique(replica_groups).
  • (C3) size(replica_groups) ditentukan sebagai:
    • num_replicas jika cross_replica digunakan.
    • num_replicas jika cross_replica_and_partition digunakan.
    • num_processes jika flattened_ids digunakan.
  • (C4) 0 <= replica_groups < size(replica_groups).
  • (C5) Jika use_global_device_ids = true, maka channel_id > 0.
  • (C6) type(results...) = type(operands...) kecuali:
    • dim(results..., all_gather_dim) = dim(operands..., all_gather_dim) * dim(process_groups, 1).

Contoh

// num_replicas: 2
// num_partitions: 1
// %operand0@(0, 0): [[1, 2], [3, 4]]
// %operand0@(1, 0): [[5, 6], [7, 8]]
// %operand1@(0, 0): [[11, 12], [13, 14]]
// %operand1@(1, 0): [[15, 16], [17, 18]]
%result:2 = "stablehlo.all_gather"(%operand0, %operand1) {
  all_gather_dim = 1 : i64,
  replica_groups = dense<[[0, 1]]> : tensor<1x2xi64>,
  // channel_id = 0
  channel_handle = #stablehlo.channel_handle<handle = 0, type = 0>
  // use_global_device_ids = false
} : (tensor<2x2xi64>, tensor<2x2xi64>) -> (tensor<2x4xi64>, tensor<2x4xi64>)
// %result0@(0, 0): [[1, 2, 5, 6], [3, 4, 7, 8]]
// %result0@(1, 0): [[1, 2, 5, 6], [3, 4, 7, 8]]
// %result1@(0, 0): [[11, 12, 15, 16], [13, 14, 17, 18]]
// %result1@(1, 0): [[11, 12, 15, 16], [13, 14, 17, 18]]

 Contoh Lainnya

all_reduce

Semantik

Dalam setiap grup proses di petak proses StableHLO, menerapkan fungsi pengurangan computation ke nilai tensor operands dari setiap proses dan menghasilkan tensor results.

Operasi ini membagi petak proses StableHLO menjadi process_groups yang ditentukan sebagai berikut:

  • cross_replica(replica_groups) if channel_id <= 0 and use_global_device_ids = false.
  • cross_replica_and_partition(replica_groups) jika channel_id > 0 and use_global_device_ids = false.
  • flattened_ids(replica_groups) if channel_id > 0 and use_global_device_ids = true.

Setelah itu, dalam setiap process_group:

  • results...@process[result_index] = exec(schedule) untuk beberapa hierarki biner schedule dengan:
    • exec(node) = computation(exec(node.left), exec(node.right)).
    • exec(leaf) = leaf.value.
  • schedule adalah hierarki biner yang ditentukan implementasi yang traversal urutannya adalah to_destination_type(operands...@process_group...[result_index], type(func_inputs(computation)[0])).

Input

Label Nama Jenis Batasan
(I1) operands jumlah variabel tensor atau tensor terkuantisasi per tensor (C5), (C6)
(I2) replica_groups jumlah variadik konstanta tensor 1 dimensi jenis si64 (C1-C3)
(I3) channel_id konstanta jenis si64 (C4)
(I4) use_global_device_ids konstanta dari jenis i1 (C4)
(I5) computation fungsi (C5)

Output

Nama Jenis Batasan
results jumlah variabel tensor atau tensor terkuantisasi per tensor (C6-C7)

Batasan

  • (C1) is_unique(replica_groups).
  • (C2) size(replica_groups) ditentukan sebagai:
    • num_replicas jika cross_replica digunakan.
    • num_replicas jika cross_replica_and_partition digunakan.
    • num_processes jika flattened_ids digunakan.
  • (C3) 0 <= replica_groups < size(replica_groups).
  • (C4) Jika use_global_device_ids = true, maka channel_id > 0.
  • (C5) computation memiliki jenis (tensor<E>, tensor<E>) -> (tensor<E>), dengan is_promotable(element_type(operand), E).
  • (C6) shape(results...) = shape(operands...).
  • (C7) element_type(results...) = E.

Contoh

// num_replicas: 2
// num_partitions: 1
// %operand0@(0, 0): [1, 2, 3, 4]
// %operand0@(1, 0): [5, 6, 7, 8]
// %operand1@(0, 0): [9, 10, 11, 12]
// %operand1@(1, 0): [13, 14, 15, 16]
%result:2 = "stablehlo.all_reduce"(%operand0, %operand0) ({
  ^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
    %0 = "stablehlo.add"(%arg0, %arg1) : (tensor<i64>, tensor<i64>) -> tensor<i64>
    "stablehlo.return"(%0) : (tensor<i64>) -> ()
}) {
  replica_groups = dense<[[0, 1]]> : tensor<1x2xi64>,
  // channel_id = 0
  channel_handle = #stablehlo.channel_handle<handle = 0, type = 0>
  // use_global_device_ids = false
} : (tensor<4xi64>, tensor<4xi64>) -> (tensor<4xi64>, tensor<4xi64>)
// %result0@(0, 0): [6, 8, 10, 12]
// %result0@(1, 0): [6, 8, 10, 12]
// %result1@(0, 0): [22, 24, 26, 28]
// %result1@(1, 0): [22, 24, 26, 28]

 Contoh Lainnya

all_to_all

Semantik

all_to_all

Dalam setiap grup proses di petak proses StableHLO, bagi nilai tensor operands di sepanjang split_dimension menjadi beberapa bagian, sebar bagian yang dibagi di antara proses, gabungkan bagian yang tersebar di sepanjang concat_dimension, dan hasilkan tensor results. Operasi ini membagi petak proses StableHLO menjadi process_groups yang ditentukan sebagai berikut:

  • cross_replica(replica_groups) if channel_id <= 0.
  • cross_partition(replica_groups) if channel_id > 0.

Setelah itu, dalam setiap process_group:

  • split_parts...@sender = split(operands...@sender, split_count, split_dimension) untuk semua sender di process_group.
  • scattered_parts...@receiver = [split_parts...@sender[receiver_index] for sender in process_group] dengan receiver_index = process_group.index(receiver).
  • results...@process = concatenate(scattered_parts...@process, concat_dimension).

Input

Label Nama Jenis Batasan
(I1) operands jumlah variabel tensor atau tensor terkuantisasi per tensor (C1-C3), (C9)
(I2) split_dimension konstanta dari jenis si64 (C1), (C2), (C9)
(I3) concat_dimension konstanta dari jenis si64 (C3), (C9)
(I4) split_count konstanta dari jenis si64 (C2), (C4), (C8), (C9)
(I5) replica_groups Konstanta tensor 2 dimensi dari jenis si64 (C5-C8)
(I6) channel_id konstanta jenis si64

Output

Nama Jenis Batasan
results jumlah variabel tensor atau tensor terkuantisasi per tensor (C9)

Batasan

  • (C1) 0 <= split_dimension < rank(operands...).
  • (C2) dim(operands..., split_dimension) % split_count = 0.
  • (C3) 0 <= concat_dimension < rank(operands...).
  • (C4) 0 < split_count.
  • (C5) is_unique(replica_groups).
  • (C6) size(replica_groups) ditentukan sebagai:
    • num_replicas jika cross_replica digunakan.
    • num_partitions jika cross_partition digunakan.
  • (C7) 0 <= replica_groups < size(replica_groups).
  • (C8) dim(replica_groups, 1) = split_count.
  • (C9) type(results...) = type(operands...) kecuali, jika split_dimension != concat_dimension:
    • dim(results..., split_dimension) = dim(operands..., split_dimension) / split_count.
    • dim(results..., concat_dimension) = dim(operands..., concat_dimension) * split_count.

Contoh

// num_replicas: 2
// num_partitions: 1
// %operand1@(0, 0): [[1, 2, 3, 4],
//                    [5, 6, 7, 8]]
// %operand1@(1, 0): [[9, 10, 11, 12],
//                    [13, 14, 15, 16]]
// %operand2@(0, 0): [[17, 18, 19, 20],
//                    [21, 22, 23, 24]]
// %operand2@(1, 0): [[25, 26, 27, 28],
//                    [29, 30, 31, 32]]
%result:2 = "stablehlo.all_to_all"(%operand1, %operand2) {
  split_dimension = 1 : i64,
  concat_dimension = 0 : i64,
  split_count = 2 : i64,
  replica_groups = dense<[[0, 1]]> : tensor<1x2xi64>
  // channel_id = 0
} : (tensor<2x4xi64>, tensor<2x4xi64>) -> (tensor<4x2xi64>, tensor<4x2xi64>)
// %result#0@(0, 0): [[1, 2], [5, 6], [9, 10], [13, 14]]
// %result#0@(1, 0): [[3, 4], [7, 8], [11, 12], [15, 16]]
// %result#1@(0, 0): [[17, 18], [21, 22], [25, 26], [29, 30]]
// %result#1@(1, 0): [[19, 20], [23, 24], [27, 28], [31, 32]]

 Contoh Lainnya

dan

Semantik

Melakukan AND element-wise dari dua tensor lhs dan rhs serta menghasilkan tensor result. Bergantung pada jenis elemen, lakukan tindakan berikut:

  • Untuk boolean: AND logika.
  • Untuk bilangan bulat: bitwise AND.

Input

Label Nama Jenis Batasan
(I1) lhs tensor dari jenis boolean atau bilangan bulat (C1)
(I2) rhs tensor dari jenis boolean atau bilangan bulat (C1)

Output

Nama Jenis Batasan
result tensor dari jenis boolean atau bilangan bulat (C1)

Batasan

  • (C1) type(lhs) = type(rhs) = type(result).

Contoh

// %lhs: [[1, 2], [3, 4]]
// %rhs: [[5, 6], [7, 8]]
%result = "stablehlo.and"(%lhs, %rhs) : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
// %result: [[1, 2], [3, 0]]

 Contoh Lainnya

atan2

Semantik

Melakukan operasi atan2 element-wise pada tensor lhs dan rhs serta menghasilkan tensor result. Bergantung pada jenis elemen, lakukan tindakan berikut:

  • Untuk float: atan2 dari IEEE-754.
  • Untuk bilangan kompleks: atan2 kompleks.
  • Untuk jenis kuantisasi: dequantize_op_quantize(atan2, lhs, rhs, type(result)).

Input

Label Nama Jenis Batasan
(I1) lhs tensor floating point atau jenis kompleks atau tensor terkuantisasi per-tensor (C1)
(I2) rhs tensor dari jenis floating point atau kompleks atau tensor terkuantisasi per tensor (C1)

Output

Nama Jenis Batasan
result tensor dari jenis floating point atau kompleks atau tensor terkuantisasi per tensor (C1)

Batasan

  • (C1) baseline_type(lhs) = baseline_type(rhs) = baseline_type(result).

Contoh

// %lhs: [0.0, 1.0, -1.0]
// %rhs: [0.0, 0.0, 0.0]
%result = "stablehlo.atan2"(%lhs, %rhs) : (tensor<3xf64>, tensor<3xf64>) -> tensor<3xf64>
// %result: [0.0, 1.57079637, -1.57079637] // [0.0, pi/2, -pi/2]

 Contoh Lainnya

batch_norm_grad

Semantik

Menghitung gradien beberapa input batch_norm_training yang melakukan backpropagation dari grad_output, dan menghasilkan tensor grad_operand, grad_scale, dan grad_offset. Secara lebih formal, operasi ini dapat dinyatakan sebagai dekomposisi ke operasi StableHLO yang ada menggunakan sintaksis Python sebagai berikut:

def compute_sum(operand, feature_index):
  (sum,) = reduce(
      inputs=[operand],
      init_values=[constant(0, element_type(operand))],
      dimensions=[i for i in range(rank(operand)) if i != feature_index],
      body=lambda x, y: add(x, y))
  return sum

def compute_mean(operand, feature_index):
  sum = compute_sum(operand, feature_index)
  divisor = constant(size(operand) / dim(operand, feature_index),
                     element_type(operand))
  divisor_bcast = broadcast_in_dim(divisor, [], type(sum))
  return divide(sum, divisor_bcast)

def batch_norm_grad(operand, scale, mean, variance, grad_output, epsilon, feature_index):
  # Broadcast inputs to type(operand)
  scale_bcast = broadcast_in_dim(scale, [feature_index], type(operand))
  mean_bcast = broadcast_in_dim(mean, [feature_index], type(operand))
  variance_bcast = broadcast_in_dim(variance, [feature_index], type(operand))
  epsilon_bcast = broadcast_in_dim(constant(epsilon, element_type(operand)), [],
                                   type(operand))

  # Perform normalization using the provided `mean` and `variance`
  # Intermediate values will be useful for computing gradients
  centered_operand = subtract(operand, mean_bcast)
  stddev = sqrt(add(variance_bcast, epsilon_bcast))
  normalized_operand = divide(centered_operand, stddev)

  # Use the implementation from batchnorm_expander.cc in XLA
  # Temporary variables have exactly the same names as in the C++ code
  elements_per_feature = broadcast_in_dim(
      constant(divide(size(operand), dim(operand, feature_index)),
               element_type(grad_output)),
      [], type(operand))
  i1 = multiply(grad_output, elements_per_feature)
  i2 = broadcast_in_dim(
      compute_sum(grad_output, feature_index), [feature_index], type(operand))
  i3 = broadcast_in_dim(
      compute_sum(multiply(grad_output, centered_operand), feature_index),
      [feature_index], type(operand))
  i4 = multiply(i3, centered_operand)
  i5 = divide(i4, add(variance_bcast, epsilon_bcast))
  i6 = subtract(subtract(i1, i2), i5)

  grad_operand =
      multiply(divide(divide(scale_bcast, stddev), elements_per_feature), i6)
  grad_scale =
      compute_sum(multiply(grad_output, normalized_operand), feature_index)
  grad_offset = compute_sum(grad_output, feature_index)

  return grad_operand, grad_scale, grad_offset

Untuk jenis kuantisasi, lakukan dequantize_batch_norm_grad_or_training_quantize(lambda operand, scale, mean, variance, grad_output: batch_norm_grad(operand, scale, mean, variance, grad_output, epsilon, feature_index), operand, scale, mean, variance, grad_output, type(grad_operand), type(grad_scale), type(feature_index)).

Input

Label Nama Jenis Batasan
(I1) operand tensor jenis floating point atau tensor terkuantisasi per tensor (C1-C3), (C5)
(I2) scale Tensor 1 dimensi dari jenis floating point atau terkuantisasi per tensor (C2), (C4), (C5)
(I3) mean Tensor 1-dimensi dari floating-point atau jenis terkuantisasi per-tensor (C2), (C4)
(I4) variance Tensor 1-dimensi dari floating-point atau jenis terkuantisasi per-tensor (C2), (C4)
(I5) grad_output tensor jenis floating point atau tensor terkuantisasi per tensor (C2), (C3)
(I6) epsilon konstanta dari jenis f32
(I7) feature_index konstanta dari jenis si64 (C1), (C5)

Output

Nama Jenis Batasan
grad_operand tensor jenis floating point atau tensor terkuantisasi per tensor (C2), (C3)
grad_scale Tensor 1 dimensi dari jenis floating point atau terkuantisasi per tensor (C2), (C4)
grad_offset Tensor 1 dimensi dari jenis floating point atau terkuantisasi per tensor (C2), (C4)

Batasan

  • (C1) 0 <= feature_index < rank(operand).
  • (C2) operand, scale, mean, variance, grad_output, grad_operand, grad_scale, dan grad_offset memiliki baseline_element_type yang sama.
  • (C3) operand, grad_output, dan grad_operand memiliki bentuk yang sama.
  • (C4) scale, mean, variance, grad_scale, dan grad_offset memiliki bentuk yang sama.
  • (C5) size(scale) = dim(operand, feature_index).

Contoh

// %operand: [
//            [[1.0, 2.0], [3.0, 4.0]],
//            [[3.0, 4.0], [1.0, 2.0]]
//           ]
// %scale: [1.0, 1.0]
// %mean: [2.0, 3.0]
// %variance: [1.0, 1.0]
// %grad_output: [
//                [[0.1, 0.1], [0.1, 0.1]],
//                [[0.1, 0.1], [0.1, 0.1]]
//               ]
%grad_operand, %grad_scale, %grad_offset =
"stablehlo.batch_norm_grad"(%operand, %scale, %mean, %variance, %grad_output) {
  epsilon = 0.0 : f32,
  feature_index = 2 : i64
} : (tensor<2x2x2xf64>, tensor<2xf64>, tensor<2xf64>, tensor<2xf64>,
     tensor<2x2x2xf64>) -> (tensor<2x2x2xf64>, tensor<2xf64>, tensor<2xf64>)
// %grad_operand: [
//                 [[0.0, 0.0], [0.0, 0.0]],
//                 [[0.0, 0.0], [0.0, 0.0]]
//                ]
// %grad_scale:  [0.0, 0.0]
// %grad_offset: [0.4, 0.4]

batch_norm_inference

Semantik

Menormalisasi tensor operand di semua dimensi kecuali dimensi feature_index dan menghasilkan tensor result. Secara lebih formal, operasi ini dapat dinyatakan sebagai dekomposisi ke operasi StableHLO yang ada menggunakan sintaksis Python sebagai berikut:

def batch_norm_inference(operand, scale, offset, mean, variance, epsilon, feature_index):
  # Broadcast inputs to shape(operand)
  scale_bcast = broadcast_in_dim(scale, [feature_index], type(operand))
  offset_bcast = broadcast_in_dim(offset, [feature_index], type(operand))
  mean_bcast = broadcast_in_dim(mean, [feature_index], type(operand))
  variance_bcast = broadcast_in_dim(variance, [feature_index], type(operand))
  epsilon_bcast = broadcast_in_dim(constant(epsilon, element_type(operand)), [],
                                   type(operand))

  # Perform normalization using the provided `mean` and `variance` instead of
  # computing them like `batch_norm_training` does.
  centered_operand = subtract(operand, mean_bcast)
  stddev = sqrt(add(variance_bcast, epsilon_bcast))
  normalized_operand = divide(centered_operand, stddev)
  return add(multiply(scale_bcast, normalized_operand), offset_bcast)

Untuk jenis kuantisasi, lakukan dequantize_op_quantize(lambda operand, scale, offset, mean, variance: batch_norm_inference(operand, scale, offset, mean, variance, epsilon, feature_index), operand, scale, offset, mean, variance, type(result)).

Input

Label Nama Jenis Batasan
(I1) operand tensor jenis floating point atau tensor terkuantisasi per tensor (C1-C7)
(I2) scale Tensor 1 dimensi dari jenis floating point atau terkuantisasi per tensor (C2), (C3)
(I3) offset Tensor 1-dimensi dari floating-point atau jenis terkuantisasi per-tensor (C2), (C4)
(I4) mean Tensor 1 dimensi dari jenis floating point atau terkuantisasi per tensor (C5)
(I5) variance Tensor 1 dimensi dari jenis floating point atau terkuantisasi per tensor (C2), (C6)
(I6) epsilon konstanta dari jenis f32
(I7) feature_index konstanta jenis si64 (C1), (C3-C6)

Output

Nama Jenis Batasan
result tensor jenis floating point atau tensor terkuantisasi per tensor (C2), (C7)

Batasan

  • (C1) 0 <= feature_index < rank(operand).
  • (C2) operand, scale, offset, mean, variance, dan result memiliki baseline_element_type yang sama.
  • (C3) size(scale) = dim(operand, feature_index).
  • (C4) size(offset) = dim(operand, feature_index).
  • (C5) size(mean) = dim(operand, feature_index).
  • (C6) size(variance) = dim(operand, feature_index).
  • (C7) baseline_type(operand) = baseline_type(result).

Contoh

// %operand: [
//            [[1.0, 2.0], [3.0, 4.0]],
//            [[3.0, 4.0], [1.0, 2.0]]
//           ]
// %scale: [1.0, 1.0]
// %offset: [1.0, 1.0]
// %mean: [2.0, 3.0]
// %variance: [1.0, 1.0]
%result = "stablehlo.batch_norm_inference"(%operand, %scale, %offset, %mean, %variance) {
  epsilon = 0.0 : f32,
  feature_index = 2 : i64
} : (tensor<2x2x2xf64>, tensor<2xf64>, tensor<2xf64>, tensor<2xf64>, tensor<2xf64>) -> tensor<2x2x2xf64>
// %result: [
//           [[0.0, 0.0], [2.0, 2.0]],
//           [[2.0, 2.0], [0.0, 0.0]]
//          ]

batch_norm_training

Semantik

Menghitung rata-rata dan varian di semua dimensi kecuali untuk dimensi feature_index dan menormalisasi tensor operand yang menghasilkan tensor output, batch_mean, dan batch_var. Secara lebih formal, operasi ini dapat dinyatakan sebagai dekomposisi ke operasi StableHLO yang ada menggunakan sintaksis Python sebagai berikut:

def compute_mean(operand, feature_index):
  (sum,) = reduce(
      inputs=[operand],
      init_values=[constant(0, element_type(operand))],
      dimensions=[i for i in range(rank(operand)) if i != feature_index],
      body=lambda x, y: add(x, y))
  divisor = constant(size(operand) / dim(operand, feature_index),
                     element_type(operand))
  divisor_bcast = broadcast_in_dim(divisor, [], type(sum))
  return divide(sum, divisor_bcast)

def compute_variance(operand, feature_index):
  mean = compute_mean(operand, feature_index)
  mean_bcast = broadcast_in_dim(mean, [feature_index], type(operand))
  centered_operand = subtract(operand, mean_bcast)
  return compute_mean(mul(centered_operand, centered_operand), feature_index)

def batch_norm_training(operand, scale, offset, epsilon, feature_index):
  mean = compute_mean(operand, feature_index)
  variance = compute_variance(operand, feature_index)
  return batch_norm_inference(operand, scale, offset, mean, variance, epsilon,
                              feature_index),
         mean, variance

Untuk jenis kuantisasi, lakukan dequantize_batch_norm_grad_or_training_quantize(lambda operand, scale, offset: batch_norm_training(operand, scale, offset, epsilon, feature_index), operand, scale, offset, type(output), type(batch_mean), type(batch_var)).

Input

Label Nama Jenis Batasan
(I1) operand tensor jenis floating point atau tensor terkuantisasi per tensor (C1)
(I2) scale Tensor 1 dimensi dari floating point atau terkuantisasi per tensor (C2), (C3)
(I3) offset Tensor 1 dimensi dari floating point atau terkuantisasi per tensor (C2), (C4)
(I4) epsilon konstanta jenis f32 (C1), (C3-C6)
(I5) feature_index konstanta jenis si64 (C1), (C3-C6)

Output

Nama Jenis Batasan
output tensor jenis floating point atau tensor terkuantisasi per tensor (C7)
batch_mean Tensor 1 dimensi dari floating point atau terkuantisasi per tensor (C2), (C5)
batch_var Tensor 1-dimensi dari floating-point atau per-tensor terkuantisasi (C2), (C6)

Batasan

  • (C1) 0 <= feature_index < rank(operand).
  • (C2) operand, scale, offset, batch_mean, batch_var, dan output memiliki baseline_element_type yang sama.
  • (C3) size(scale) = dim(operand, feature_index).
  • (C4) size(offset) = dim(operand, feature_index).
  • (C5) size(batch_mean) = dim(operand, feature_index).
  • (C6) size(batch_var) = dim(operand, feature_index).
  • (C7) baseline_type(output) = baseline_type(operand).

Contoh

// %operand: [
//            [[1.0, 2.0], [3.0, 4.0]],
//            [[3.0, 4.0], [1.0, 2.0]]
//           ]
// %scale: [1.0, 1.0]
// %offset: [1.0, 1.0]
%output, %batch_mean, %batch_var = "stablehlo.batch_norm_training"(%operand, %scale, %offset) {
  epsilon = 0.0 : f32,
  feature_index = 2 : i64
} : (tensor<2x2x2xf64>, tensor<2xf64>, tensor<2xf64>) ->
    (tensor<2x2x2xf64>, tensor<2xf64>, tensor<2xf64>)
// %output: [
//           [[0.0, 0.0], [2.0, 2.0]],
//           [[2.0, 2.0], [0.0, 0.0]]
//          ]
// %batch_mean: [2.0, 3.0]
// %batch_var: [1.0, 1.0]

bitcast_convert

Semantik

Melakukan operasi bitcast pada tensor operand dan menghasilkan tensor result dengan bit dari seluruh tensor operand ditafsirkan ulang menggunakan jenis tensor result.

Secara lebih formal, dengan E = element_type(operand), E' = element_type(result), dan R = rank(operand):

  • Jika num_bits(E') < num_bits(E), bits(result[i0, ..., iR-1, :]) = bits(operand[i0, ..., iR-1]).
  • Jika num_bits(E') > num_bits(E), bits(result[i0, ..., iR-2]) = bits(operand[i0, ..., iR-2, :]).
  • Jika num_bits(E') = num_bits(E), bits(result[i0, ..., iR-1]) = bits(operand[i0, ..., iR-1]).

bits menampilkan representasi dalam memori dari nilai tertentu, dan perilakunya ditentukan oleh implementasi karena representasi persis tensor ditentukan oleh implementasi, dan representasi persis jenis elemen juga ditentukan oleh implementasi.

Input

Label Nama Jenis Batasan
(I1) operand tensor atau tensor terkuantisasi (C1-C2)

Output

Nama Jenis Batasan
result tensor atau tensor terkuantisasi (C1-C2)

Batasan

  • (C1) Dengan E = is_quantized(operand) ? storage_type(operand) : element_type(operand), E' = is_quantized(result) ? storage_type(result) : element_type(result), dan R = rank(operand):
    • Jika num_bits(E') = num_bits(E), shape(result) = shape(operand).
    • Jika num_bits(E') < num_bits(E):
    • rank(result) = R + 1.
    • dim(result, i) = dim(operand, i) untuk semua 0 <= i < R.
    • dim(result, R) * num_bits(E') = num_bits(E).
    • Jika num_bits(E') > num_bits(E):
    • rank(result) = R - 1.
    • dim(result, i) = dim(operand, i) untuk semua 0 <= i < R.
    • dim(operand, R - 1) * num_bits(E) = num_bits(E').
  • (C2) Jika is_complex(operand) or is_complex(result), maka is_complex(operand) and is_complex(result).

Contoh

// %operand: 0x0123456789ABCDEF
%result = "stablehlo.bitcast_convert"(%operand) : (tensor<f64>) -> tensor<4xf16>
// %result: [0xCDEF, 0x89AB, 0x4567, 0x0123] // little-endian representation

 Contoh Lainnya

broadcast_in_dim

Semantik

Memperluas dimensi dan/atau peringkat tensor input dengan menduplikasi data dalam tensor operand dan menghasilkan tensor result. Secara lebih formal, result[result_index] = operand[operand_index] dengan untuk semua d di axes(operand):

  • operand_index[d] = 0 if dim(operand, d) = 1.
  • operand_index[d] = result_index[broadcast_dimensions[d]] jika tidak.

Input

Label Nama Jenis Batasan
(I1) operand tensor atau quantized tensor (C1-C2), (C5-C6)
(I2) broadcast_dimensions Konstanta tensor 1 dimensi dari jenis si64 (C2-C6)

Output

Nama Jenis Batasan
result tensor atau tensor terkuantisasi (C1), (C3), (C5-C6)

Batasan

  • (C1) element_type(result) diberikan oleh:
    • element_type(operand), jika !is_per_axis_quantized(operand).
    • element_type(operand) kecuali bahwa quantization_dimension(operand), scales(operand), dan zero_points(operand) mungkin berbeda dari quantization_dimension(result), scales(result), dan zero_points(result) masing-masing.
  • (C2) size(broadcast_dimensions) = rank(operand).
  • (C3) 0 <= broadcast_dimensions < rank(result).
  • (C4) is_unique(broadcast_dimensions).
  • (C5) Untuk semua d di axes(operand):
    • dim(operand, d) = 1 atau
    • dim(operand, d) = dim(result, broadcast_dimensions[d]).
  • (C6) Jika is_per_axis_quantized(result):
    • quantization_dimension(result) = broadcast_dimensions[quantization_dimension(operand)].
    • Jika dim(operand, quantization_dimension(operand)) = 1, maka scales(result)[i] = scales(operand)[0] and zero_points(result)[i] = zero_points(operand)[0] for i in range(dim(result, quantization_dimension(result))).

Contoh

// %operand: [
//            [1, 2, 3]
//           ]
%result = "stablehlo.broadcast_in_dim"(%operand) {
  broadcast_dimensions = array<i64: 2, 1>
} : (tensor<1x3xi32>) -> tensor<2x3x2xi32>
// %result: [
//            [
//             [1, 1],
//             [2, 2],
//             [3, 3]
//            ],
//            [
//             [1, 1],
//             [2, 2],
//             [3, 3]
//            ]
//          ]

Contoh Lainnya

casing

Semantik

Menghasilkan output dari mengeksekusi tepat satu fungsi dari branches, bergantung pada nilai index. Secara lebih formal, result = selected_branch() dengan:

  • selected_branch = branches[index] if 0 <= index < size(branches).
  • selected_branch = branches[-1] jika tidak.

Input

Label Nama Jenis Batasan
(I1) index Tensor 0-dimensi jenis si32
(I2) branches jumlah fungsi variadik (C1-C4)

Output

Nama Jenis Batasan
results jumlah variabel tensor, tensor terkuantisasi, atau token (C4)

Batasan

  • (C1) 0 < size(branches).
  • (C2) input_types(branches...) = [].
  • (C3) same(output_types(branches...)).
  • (C4) type(results...) = output_types(branches[0]).

Contoh

// %index: -1
// %result_branch0: [0, 0]
// %result_branch1: [1, 1]
%result0, %result1 = "stablehlo.case"(%index) ({
  "stablehlo.return"(%result_branch0, %result_branch0) : (tensor<2xi64>, tensor<2xi64>) -> ()
}, {
  "stablehlo.return"(%result_branch1, %result_branch1) : (tensor<2xi64>, tensor<2xi64>) -> ()
}) : (tensor<i32>) -> (tensor<2xi64>, tensor<2xi64>)
// %result0: [1, 1]
// %result1: [1, 1]

 Contoh Lainnya

cbrt

Semantik

Melakukan operasi akar kubik element-wise pada tensor operand dan menghasilkan tensor result. Bergantung pada jenis elemen, lakukan tindakan berikut:

  • Untuk float: rootn(x, 3) dari IEEE-754.
  • Untuk bilangan kompleks: akar pangkat tiga kompleks.
  • Untuk jenis terkuantisasi: dequantize_op_quantize(cbrt, operand, type(result))

Input

Label Nama Jenis Batasan
(I1) operand tensor floating point atau jenis kompleks atau tensor terkuantisasi per-tensor (C1)

Output

Nama Jenis Batasan
result tensor dari jenis floating point atau kompleks atau tensor terkuantisasi per tensor (C1)

Batasan

  • (C1) baseline_type(operand) = baseline_type(result).

Contoh

// %operand: [0.0, 1.0, 8.0, 27.0]
%result = "stablehlo.cbrt"(%operand) : (tensor<4xf64>) -> tensor<4xf64>
// %result: [0.0, 1.0, 2.0, 3.0]

 Contoh Lainnya

ceil

Semantik

Melakukan ceil element-wise dari tensor operand dan menghasilkan tensor result. Mengimplementasikan operasi roundToIntegralTowardPositive dari spesifikasi IEEE-754. Untuk jenis kuantisasi, lakukan dequantize_op_quantize(ceil, operand, type(result)).

Input

Label Nama Jenis Batasan
(I1) operand tensor jenis floating point atau tensor terkuantisasi per tensor (C1)

Output

Nama Jenis Batasan
result tensor jenis floating point atau tensor terkuantisasi per tensor (C1)

Batasan

  • (C1) baseline_type(operand) = baseline_type(result).

Contoh

// %operand: [-0.8166, -0.2530, 0.2530, 0.8166, 2.0]
%result = "stablehlo.ceil"(%operand) : (tensor<5xf32>) -> tensor<5xf32>
// %result: [-0.0, -0.0, 1.0, 1.0, 2.0]

 Contoh Lainnya

cholesky

Semantik

Menghitung dekomposisi Cholesky dari batch matriks.

Secara lebih formal, untuk semua i di index_space(result), result[i0, ..., iR-3, :, :] adalah dekomposisi Cholesky dari a[i0, ..., iR-3, :, :], dalam bentuk matriks segitiga bawah (jika lower adalah true) atau matriks segitiga atas (jika lower adalah false). Nilai output di segitiga yang berlawanan, yaitu segitiga atas yang ketat atau segitiga bawah yang ketat, ditentukan oleh implementasi.

Jika ada i di mana matriks input bukan matriks positif-definit Hermitian, perilakunya tidak ditentukan.

Untuk jenis kuantisasi, lakukan dequantize_op_quantize(lambda operand: cholesky(operand, lower), a, type(result)).

Input

Label Nama Jenis Batasan
(I1) a tensor dari jenis floating point atau kompleks atau tensor terkuantisasi per tensor (C1-C3)
(I2) lower Konstanta tensor 0 dimensi jenis i1

Output

Nama Jenis Batasan
result tensor dari jenis floating point atau kompleks atau tensor terkuantisasi per tensor (C1)

Batasan

  • (C1) baseline_type(a) = baseline_type(result).
  • (C2) 2 <= rank(a).
  • (C3) dim(a, -2) = dim(a, -1).

Contoh

// %a: [
//      [1.0, 2.0, 3.0],
//      [2.0, 20.0, 26.0],
//      [3.0, 26.0, 70.0]
//     ]
%result = "stablehlo.cholesky"(%a) {
  lower = true
} : (tensor<3x3xf32>) -> tensor<3x3xf64>
// %result: [
//           [1.0, 0.0, 0.0],
//           [2.0, 4.0, 0.0],
//           [3.0, 5.0, 6.0]
//          ]

klem

Semantik

Membatasi setiap elemen tensor operand antara nilai minimum dan maksimum dan menghasilkan tensor result. Secara lebih formal, result[result_index] = minimum(maximum(operand[result_index], min_element), max_element), dengan min_element = rank(min) = 0 ? min[] : min[result_index], max_element = rank(max) = 0 ? max[] : max[result_index]. Untuk jenis terkuantisasi, menjalankan dequantize_op_quantize(clamp, min, operand, max, type(result)).

Memaksakan pengurutan pada angka kompleks melibatkan semantik yang mengejutkan, jadi di masa mendatang kami berencana menghapus dukungan untuk bilangan kompleks untuk operasi ini (#560).

Input

Label Nama Jenis Batasan
(I1) min tensor atau tensor terkuantisasi per tensor (C1), (C3)
(I2) operand tensor atau tensor terkuantisasi per tensor (C1-C4)
(I3) max tensor atau tensor terkuantisasi per tensor (C2), (C3)

Output

Nama Jenis Batasan
result tensor atau tensor terkuantisasi per tensor (C4)

Batasan

  • (C1) rank(min) = 0 or shape(min) = shape(operand).
  • (C2) rank(max) = 0 or shape(max) = shape(operand).
  • (C3) baseline_element_type(min) = baseline_element_type(operand) = baseline_element_type(max).
  • (C4) baseline_type(operand) = baseline_type(result).

Contoh

// %min: [5, 10, 15]
// %operand: [3, 13, 23]
// %max: [10, 15, 20]
%result = "stablehlo.clamp"(%min, %operand, %max) : (tensor<3xi32>, tensor<3xi32>, tensor<3xi32>) -> tensor<3xi32>
// %result: [5, 13, 20]

 Contoh Lainnya

collective_broadcast

Semantik

Dalam setiap grup proses pada petak proses StableHLO, kirim nilai tensor operand dari proses sumber ke proses target dan hasilkan tensor result.

Operasi tersebut membagi grid proses StableHLO menjadi process_groups yang ditentukan sebagai berikut:

  • cross_replica(replica_groups) if channel_id <= 0.
  • cross_partition(replica_groups) jika channel_id > 0.

Setelah itu, result@process diberikan oleh:

  • operand@process_groups[i, 0] jika ada i sehingga prosesnya berada di process_groups[i].
  • broadcast_in_dim(constant(is_quantized(result) ? quantize(0, element_type(result)) : 0, element_type(result)), [], type(result)) jika tidak.

Input

Label Nama Jenis Batasan
(I1) operand tensor atau tensor terkuantisasi per tensor (C3)
(I2) replica_groups jumlah variadik konstanta tensor 1 dimensi jenis si64 (C1), (C2)
(I3) channel_id konstanta dari jenis si64

Output

Nama Jenis Batasan
result tensor atau tensor terkuantisasi per tensor (C3)

Batasan

  • (C1) is_unique(replica_groups).
  • (C2) 0 <= replica_groups < N dengan N ditentukan sebagai:
    • num_replicas jika cross_replica digunakan.
    • num_partitions jika cross_partition digunakan.
  • (C3) type(result) = type(operand).

Contoh

// num_replicas: 4
// num_partitions: 1
// %operand@(0, 0): [[1, 2]]
// %operand@(1, 0): [[3, 4]]
// %operand@(2, 0): [[5, 6]]
// %operand@(3, 0): [[7, 8]]
%result = "stablehlo.collective_broadcast"(%operand) {
  replica_groups = dense<[[2, 1]]> : tensor<1x2xi64>,
  channel_handle = #stablehlo.channel_handle<handle = 0, type = 0>
} : (tensor1x2xi64>) -> tensor<1x2xi64>
// %result@(0, 0): [[0, 0]]
// %result@(1, 0): [[5, 6]]
// %result@(2, 0): [[5, 6]]
// %result@(3, 0): [[0, 0]]

collective_permute

Semantik

Dalam setiap grup proses pada petak proses StableHLO, mengirimkan nilai tensor operand dari proses sumber ke proses target dan menghasilkan tensor result.

Operasi ini membagi petak proses StableHLO menjadi process_groups yang ditentukan sebagai berikut:

  • cross_replica(source_target_pairs) jika channel_id <= 0.
  • cross_partition(source_target_pairs) if channel_id > 0.

Setelah itu, result@process diberikan oleh:

  • operand@process_groups[i, 0], jika ada i sehingga process_groups[i, 1] = process.
  • broadcast_in_dim(constant(is_quantized(result) ? quantize(0, element_type(result)) : 0, element_type(result)), [], type(result)) jika tidak.

Input

Label Nama Jenis Batasan
(I1) operand tensor atau tensor terkuantisasi per tensor (C5)
(I2) source_target_pairs Konstanta tensor 2 dimensi dari jenis si64 (C1-C4)
(I3) channel_id konstanta dari jenis si64

Output

Nama Jenis Batasan
result tensor atau tensor terkuantisasi per tensor (C1)

Batasan

  • (C1) dim(source_target_pairs, 1) = 2.
  • (C2) is_unique(source_target_pairs[:, 0]).
  • (C3) is_unique(source_target_pairs[:, 1]).
  • (C4) 0 <= source_target_pairs < N, dengan N ditentukan sebagai:
    • num_replicas jika cross_replica digunakan.
    • num_partitions jika cross_partition digunakan.
  • (C5) type(result) = type(operand).

Contoh

// num_replicas: 3
// num_partitions: 1
// %operand@(0, 0): [[1, 2], [3, 4]]
// %operand@(1, 0): [[5, 6], [7, 8]]
// %operand@(2, 0): [[9, 10], [11, 12]]
%result = "stablehlo.collective_permute"(%operand) {
  source_target_pairs = dense<[[0, 1], [1, 2]]> : tensor<2x2xi64>,
  channel_handle = #stablehlo.channel_handle<handle = 0, type = 0>
} : (tensor<2x2xi64>) -> tensor<2x2xi64>
//
// %result@(0, 0): [[0, 0], [0, 0]]
// %result@(1, 0): [[1, 2], [3, 4]]
// %result@(2, 0): [[5, 6], [7, 8]]

 Contoh Lainnya

bandingkan

Semantik

Melakukan perbandingan element-wise dari tensor lhs dan rhs sesuai dengan comparison_direction dan compare_type, serta menghasilkan tensor result.

Nilai comparison_direction dan compare_type memiliki semantik berikut:

Untuk jenis elemen boolean dan bilangan bulat:

  • EQ: lhs = rhs.
  • NE: lhs != rhs.
  • GE: lhs >= rhs.
  • GT: lhs > rhs.
  • LE: lhs <= rhs.
  • LT: lhs < rhs.

Untuk jenis elemen floating point dengan compare_type = FLOAT, op menerapkan operasi IEEE-754 berikut:

  • EQ: compareQuietEqual.
  • NE: compareQuietNotEqual.
  • GE: compareQuietGreaterEqual.
  • GT: compareQuietGreater.
  • LE: compareQuietLessEqual.
  • LT: compareQuietLess.

Untuk jenis elemen floating point dengan compare_type = TOTALORDER, op menggunakan kombinasi operasi totalOrder dan compareQuietEqual dari IEEE-754.

Untuk jenis elemen yang kompleks, perbandingan leksikografis pasangan (real, imag) dilakukan menggunakan comparison_direction dan compare_type yang disediakan. Menerapkan pengurutan pada bilangan kompleks melibatkan semantik yang mengejutkan, jadi pada masa mendatang, kami berencana untuk menghapus dukungan untuk bilangan kompleks saat comparison_direction adalah GE, GT, LE, atau LT (#560).

Untuk jenis kuantisasi. melakukan dequantize_compare(lhs, rhs, comparison_direction).

Input

Label Nama Jenis Batasan
(I1) lhs tensor atau tensor terkuantisasi per-tensor (C1-C3)
(I2) rhs tensor atau tensor terkuantisasi per-tensor (C1-C2)
(I3) comparison_direction enum EQ, NE, GE, GT, LE, dan LT
(I4) compare_type enum FLOAT, TOTALORDER, SIGNED, dan UNSIGNED (C3)

Output

Nama Jenis Batasan
result tensor dari jenis boolean (C2)

Batasan

  • (C1) baseline_element_type(lhs) = baseline_element_type(rhs).
  • (C2) shape(lhs) = shape(rhs) = shape(result).
  • (C3) compare_type ditentukan sebagai:
    • SIGNED if is_signed_integer(element_type(lhs)).
    • UNSIGNED jika is_unsigned_integer(element_type(lhs)) or is_boolean(element_type(lhs)).
    • FLOAT atau TOTALORDER jika is_float(element_type(lhs)).
    • FLOAT jika is_complex(element_type(lhs)).

Contoh

// %lhs: [1.0, 3.0]
// %rhs: [1.1, 2.9]
%result = "stablehlo.compare"(%lhs, %rhs) {
  comparison_direction = #stablehlo<comparison_direction LT>,
  compare_type = #stablehlo<comparison_type FLOAT>
} : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xi1>
// %result: [true, false]

 Contoh Lainnya

kompleks

Semantik

Melakukan konversi elemen ke nilai kompleks dari sepasang nilai riil dan imaginer, lhs dan rhs, serta menghasilkan tensor result.

Input

Label Nama Jenis Batasan
(I1) lhs tensor jenis f32 atau f64 (C1-C3)
(I2) rhs tensor jenis f32 atau f64 (C1)

Output

Nama Jenis Batasan
result tensor tipe kompleks (C2), (C3)

Batasan

  • (C1) type(lhs) = type(rhs).
  • (C2) shape(result) = shape(lhs).
  • (C3) element_type(result) memiliki jenis complex<E> dengan E = element_type(lhs).

Contoh

// %lhs: [1.0, 3.0]
// %rhs: [2.0, 4.0]
%result = "stablehlo.complex"(%lhs, %rhs) : (tensor<2xf64>, tensor<2xf64>) -> tensor<2xcomplex<f64>>
// %result: [(1.0, 2.0), (3.0, 4.0)]

 Contoh Lainnya

gabungan

Semantik

Mengenkapsulasi operasi yang terdiri (tersusun) dari operasi StableHLO lainnya, menggunakan inputs dan composite_attributes, serta menghasilkan results. Semantik op diterapkan oleh atribut decomposition. Operasi composite dapat diganti dengan dekomposisinya tanpa mengubah semantik program. Jika menyisipkan dekomposisi tidak memberikan semantik op yang sama, sebaiknya gunakan custom_call.

Kolom version (default-nya adalah 0) digunakan untuk menunjukkan kapan semantik komposit berubah.

Input

Label Nama Jenis
(I1) inputs jumlah nilai variadik
(I2) name konstanta dari jenis string
(I3) composite_attributes kamus atribut
(I4) decomposition konstanta jenis string
(I5) version konstanta jenis si32

Output

Nama Jenis
results jumlah nilai variadik

Batasan

  • (C1) is_namespaced_op_name(name)
  • (C2) is_defined_in_parent_scope(decomposition)
  • (C3) types(inputs...) == input_types(decomposition)
  • (C4) types(results...) == output_types(decomposition)

Contoh

%results = "stablehlo.composite"(%input0, %input1) {
  name = "my_namespace.my_op",
  composite_attributes = {
    my_attribute = "my_value"
  },
  decomposition = @my_op,
  version = 1 : i32
} : (tensor<f32>, tensor<f32>) -> tensor<f32>

 Contoh Lainnya

concatenate

Semantik

Menggabungkan inputs di sepanjang dimensi dimension dalam urutan yang sama dengan argumen yang diberikan dan menghasilkan tensor result. Secara lebih formal, result[i0, ..., id, ..., iR-1] = inputs[k][i0, ..., kd, ..., iR-1], dengan:

  1. id = d0 + ... + dk-1 + kd.
  2. d sama dengan dimension, dan d0, ... adalah ukuran dimensi d dari inputs.

Input

Label Nama Jenis Batasan
(I1) inputs jumlah variabel tensor atau tensor terkuantisasi per tensor (C1-C6)
(I2) dimension konstanta dari jenis si64 (C2), (C4), (C6)

Output

Nama Jenis Batasan
result tensor atau tensor terkuantisasi per-tensor (C5-C6)

Batasan

  • (C1) same(element_type(inputs...)).
  • (C2) same(shape(inputs...)) kecuali untuk dim(inputs..., dimension).
  • (C3) 0 < size(inputs).
  • (C4) 0 <= dimension < rank(inputs[0]).
  • (C5) element_type(result) = element_type(inputs[0]).
  • (C6) shape(result) = shape(inputs[0]) kecuali untuk:
    • dim(result, dimension) = dim(inputs[0], dimension) + ....

Contoh

// %input0: [[1, 2], [3, 4], [5, 6]]
// %input1: [[7, 8]]
%result = "stablehlo.concatenate"(%input0, %input1) {
  dimension = 0 : i64
} : (tensor<3x2xi64>, tensor<1x2xi64>) -> tensor<4x2xi64>
// %result: [[1, 2], [3, 4], [5, 6], [7, 8]]

 Contoh Lainnya

konstanta

Semantik

Menghasilkan tensor output dari value yang konstan.

Input

Label Nama Jenis Batasan
(I1) value konstanta (C1)

Output

Nama Jenis Batasan
output tensor atau tensor terkuantisasi (C1)

Batasan

  • (C1) type(value) = type(output).

Contoh

%output = "stablehlo.constant"() {
  value = dense<[[0.0, 1.0], [2.0, 3.0]]> : tensor<2x2xf32>
} : () -> tensor<2x2xf32>
// %output: [[0.0, 1.0], [2.0, 3.0]]

 Contoh Lainnya

melakukan konversi

Semantik

Melakukan konversi elemen dari satu jenis elemen ke jenis elemen lainnya pada tensor operand dan menghasilkan tensor result.

Untuk konversi boolean-to-any-supported-type, nilai false dikonversi menjadi nol, dan nilai true dikonversi menjadi satu. Untuk konversi any-supported-type-to-boolean, nilai nol dikonversi menjadi false, dan nilai non-nol dikonversi menjadi true. Lihat di bawah untuk mengetahui cara kerja ini untuk jenis yang kompleks.

Untuk konversi yang melibatkan bilangan bulat ke bilangan bulat, bilangan bulat ke floating point, atau floating point ke floating point, jika nilai sumber dapat direpresentasikan secara tepat dalam jenis tujuan, nilai hasilnya adalah representasi yang tepat tersebut. Jika tidak, perilakunya adalah TBD (#180).

Untuk konversi yang melibatkan floating-point-to-integer, bagian pecahan akan terpotong. Jika nilai yang dipangkas tidak dapat direpresentasikan dalam jenis tujuan, perilakunya adalah TBD (#180).

Konversi yang melibatkan kompleks ke kompleks mengikuti perilaku yang sama dengan konversi floating point ke floating point untuk mengonversi bagian riil dan imajiner.

Untuk konversi complex-to-any-other-type dan any-other-type-to-complex, nilai imajiner sumber diabaikan atau nilai imajiner tujuan disetel ke nol. Konversi bagian riil mengikuti konversi floating point.

Pada prinsipnya, operasi ini dapat mengekspresikan dekuantisasi (konversi dari tensor kuantisasi menjadi tensor reguler), kuantisasi (konversi dari tensor reguler menjadi tensor kuantisasi), dan rekuantisasi (konversi antara tensor kuantisasi), tetapi saat ini kami memiliki operasi khusus untuk itu - uniform_dequantize untuk kasus penggunaan pertama dan uniform_quantize untuk kasus penggunaan kedua dan ketiga. Nantinya, kedua operasi ini dapat digabungkan ke dalam convert (#1576).

Input

Label Nama Jenis Batasan
(I1) operand tensor (C1)

Output

Nama Jenis Batasan
result tensor (C1)

Batasan

  • (C1) shape(operand) = shape(result).

Contoh

// %operand: [-1, 0, 1]
%result = "stablehlo.convert"(%operand) : (tensor<3xi64>) -> tensor<3xcomplex<f64>>
// %result: [(-1.0, 0.0), (0.0, 0.0), (1.0, 0.0)]

 Contoh Lainnya

konvolusi

Semantik

Menghitung perkalian titik antara jendela lhs dan slice rhs serta menghasilkan result. Diagram berikut menunjukkan cara elemen di result dihitung dari lhs dan rhs menggunakan contoh konkret.

konvolusi

Secara lebih formal, pertimbangkan penyusunan ulang input berikut dalam hal lhs agar dapat mengekspresikan jendela lhs:

  • lhs_window_dimensions = lhs_shape(dim(lhs, input_batch_dimension), dim(rhs, kernel_spatial_dimensions), dim(lhs, input_feature_dimension)).
  • lhs_window_strides = lhs_shape(1, window_strides, 1).
  • lhs_padding = lhs_shape([0, 0], padding, [0, 0]).
  • lhs_base_dilations = lhs_shape(1, lhs_dilation, 1).
  • lhs_window_dilations = lhs_shape(1, rhs_dilation, 1).

Pembingkaian ulang ini menggunakan fungsi bantuan berikut:

  • lhs_shape(n, hw, c) = permute([n] + hw + [c], [input_batch_dimension] + input_spatial_dimensions + [input_feature_dimension]).
  • result_shape(n1, hw, c1) = permute([n1] + hw + [c1], [output_batch_dimension] + output_spatial_dimensions + [output_feature_dimension]).
  • permute([j0, j1, ..., jR-1], permutation) = [i0, i1, ..., iR-1] dengan j[d] = i[permutation[d]].

Jika feature_group_count = 1 dan batch_group_count = 1, maka untuk semua output_spatial_index di index_space(dim(result, output_spatial_dimensions...)), result[result_shape(:, output_spatial_index, :)] = dot_product dengan:

  • padding_value = constant(0, element_type(lhs)).
  • padded_lhs = pad(lhs, padding_value, lhs_padding[:, 0], lhs_padding[:, 1], lhs_base_dilations - 1).
  • lhs_window_start = lhs_shape(0, output_spatial_index, 0) * lhs_window_strides.
  • lhs_window = slice(padded_lhs, lhs_window_start, lhs_window_start + lhs_window_dimensions, lhs_window_dilations).
  • reversed_lhs_window = reverse(lhs_window, [input_spatial_dimensions[dim] for dim in range(size(window_reversal)) if window_reversal[dim] = true]). Fitur ini tampaknya tidak digunakan, jadi pada masa mendatang kami berencana untuk menghapusnya (#1181).
  • dot_product = dot_general(reversed_lhs_window, rhs, lhs_batching_dimensions=[], lhs_contracting_dimensions=input_spatial_dimensions + [input_feature_dimension], rhs_batching_dimensions=[], rhs_contracting_dimensions=kernel_spatial_dimensions + [kernel_input_feature_dimension]).

Jika feature_group_count > 1:

  • lhses = split(lhs, feature_group_count, input_feature_dimension).
  • rhses = split(rhs, feature_group_count, kernel_output_feature_dimension).
  • results... = convolution(lhses..., rhses..., ..., feature_group_count=1, ...).
  • result = concatenate(results, output_feature_dimension).

Jika batch_group_count > 1:

  • lhses = split(lhs, batch_group_count, input_batch_dimension).
  • rhses = split(rhs, batch_group_count, kernel_output_feature_dimension).
  • results... = convolution(lhses..., rhses..., ..., batch_group_count=1, ...).
  • result = concatenate(results, output_feature_dimension).

Untuk jenis terkuantisasi, menjalankan dequantize_op_quantize( lambda lhs, rhs: convolution(lhs, rhs, window_strides, padding, lhs_dilation, rhs_dilation, window_reversal, input_batch_dimension, input_feature_dimension, input_spatial_dimensions, kernel_input_feature_dimension, kernel_output_feature_dimension, kernel_spatial_dimensions, output_batch_dimension, output_feature_dimension, output_spatial_dimensions, feature_group_count, batch_group_count, precision_config), lhs, rhs, type(result)).

Untuk jenis kuantisasi campuran, lakukan hybrid_dequantize_then_op( lambda lhs, rhs: convolution(lhs, rhs, window_strides, padding, lhs_dilation, rhs_dilation, window_reversal, input_batch_dimension, input_feature_dimension, input_spatial_dimensions, kernel_input_feature_dimension, kernel_output_feature_dimension, kernel_spatial_dimensions, output_batch_dimension, output_feature_dimension, output_spatial_dimensions, feature_group_count, batch_group_count, precision_config), lhs, rhs).

Input

Label Nama Jenis Batasan
(I1) lhs tensor atau tensor terkuantisasi per-tensor (C1), (C10-C11), (C14) (C25), (C27-C28), (C31-C32), (C34)
(I2) rhs tensor atau tensor terkuantisasi (C1), (C14-C16), (C25), (C27-C29), (C31-C34)
(I3) window_strides Konstanta tensor 1 dimensi dari jenis si64 (C2-C3), (C25)
(I4) padding Konstanta tensor 2 dimensi dari jenis si64 (C4), (C25)
(I5) lhs_dilation Konstanta tensor 1 dimensi dari jenis si64 (C5-C6), (C25)
(I6) rhs_dilation Konstanta tensor 1 dimensi dari jenis si64 (C7-C8), (C25)
(I7) window_reversal Konstanta tensor 1 dimensi dari jenis i1 (C9)
(I8) input_batch_dimension konstanta dari jenis si64 (C10), (C13), (C25)
(I9) input_feature_dimension konstanta jenis si64 (C11), (C13-C14)
(I10) input_spatial_dimensions Konstanta tensor 1 dimensi jenis si64 (C12), (C13), (C25)
(I11) kernel_input_feature_dimension konstanta jenis si64 (C14), (C18)
(I12) kernel_output_feature_dimension konstanta dari jenis si64 (C15-C16), (C18), (C25), (C29)
(I13) kernel_spatial_dimensions Konstanta tensor 1 dimensi jenis si64 (C17-C18), (C25)
(I14) output_batch_dimension konstanta dari jenis si64 (C20), (C25)
(I15) output_feature_dimension konstanta dari jenis si64 (C20), (C25), (C30)
(I16) output_spatial_dimensions Konstanta tensor 1 dimensi dari jenis si64 (C19-C20), (C25)
(I17) feature_group_count konstanta dari jenis si64 (C11), (C14), (C16), (C21), (C23)
(I18) batch_group_count konstanta dari jenis si64 (C10), (C15), (C22), (C23), (C25)
(I19) precision_config jumlah variabel enum DEFAULT, HIGH, dan HIGHEST (C24)

Output

Nama Jenis Batasan
result tensor atau tensor terkuantisasi (C25-C28), (C30), (C32-34)

Batasan

  • (C1) N = rank(lhs) = rank(rhs).
  • (C2) size(window_strides) = N - 2.
  • (C3) 0 < window_strides.
  • (C4) shape(padding) = [N - 2, 2].
  • (C5) size(lhs_dilation) = N - 2.
  • (C6) 0 < lhs_dilation.
  • (C7) size(rhs_dilation) = N - 2.
  • (C8) 0 < rhs_dilation.
  • (C9) size(window_reversal) = N - 2.
  • (C10) dim(lhs, input_batch_dimension) % batch_group_count = 0.
  • (C11) dim(lhs, input_feature_dimension) % feature_group_count = 0.
  • (C12) size(input_spatial_dimensions) = N - 2.
  • (C13) Dengan input_dimensions = [input_batch_dimension] + input_spatial_dimensions + [input_feature_dimension]:
    • is_unique(input_dimensions).
    • 0 <= input_dimensions < N.
  • (C14) dim(rhs, kernel_input_feature_dimension) = dim(lhs, input_feature_dimension) / feature_group_count.
  • (C15) dim(rhs, kernel_output_feature_dimension) % batch_group_count = 0.
  • (C16) dim(rhs, kernel_output_feature_dimension) % feature_group_count = 0.
  • (C17) size(kernel_spatial_dimensions) = N - 2.
  • (C18) Dengan kernel_dimensions = kernel_spatial_dimensions + [kernel_input_feature_dimension] + [kernel_output_feature_dimension]:
    • is_unique(kernel_dimensions).
    • 0 <= kernel_dimensions < N.
  • (C19) size(output_spatial_dimensions) = N - 2.
  • (C20) Diberikan output_dimensions = [output_batch_dimension] + output_spatial_dimensions + [output_feature_dimension]:
    • is_unique(output_dimensions).
    • 0 <= output_dimensions < N.
  • (C21) 0 < feature_group_count.
  • (C22) 0 < batch_group_count.
  • (C23) feature_group_count = 1 or batch_group_count = 1.
  • (C24) size(precision_config) = 2.
  • (C25) dim(result, result_dim) ditentukan sebagai:
    • dim(lhs, input_batch_dimension) / batch_group_count if result_dim = output_batch_dimension.
    • dim(rhs, kernel_output_feature_dimension) jika result_dim = output_feature_dimension.
    • num_windows jika tidak, dengan:
    • output_spatial_dimensions[spatial_dim] = result_dim.
    • lhs_dim = input_spatial_dimensions[spatial_dim].
    • rhs_dim = kernel_spatial_dimensions[spatial_dim].
    • dilated_input_shape[lhs_dim] = dim(lhs, lhs_dim) = 0 ? 0 : (dim(lhs, lhs_dim) - 1) * lhs_dilation[spatial_dim] + 1.
    • padded_input_shape[lhs_dim] = padding[spatial_dim, 0] + dilated_input_shape[lhs_dim] + padding[spatial_dim, 1].
    • dilated_window_shape[lhs_dim] = dim(rhs, rhs_dim) = 0 ? 0 : (dim(rhs, rhs_dim) - 1) * rhs_dilation[spatial_dim] + 1.
    • is_empty_window[lhs_dim] = padded_input_shape[lhs_dim] = 0 || dilated_window_shape[lhs_dim] > padded_input_shape[lhs_dim].
    • num_windows = is_empty_window[lhs_dim] ? 0 : floor((padded_input_shape[lhs_dim] - dilated_window_shape[lhs_dim]) / window_strides[spatial_dim]) + 1.
  • (C26) rank(result) = N.
  • Jika operasi menggunakan tensor yang tidak dikuantisasi:
    • (C27) element_type(lhs) = element_type(rhs) = element_type(result).
  • Jika operasi menggunakan tensor kuantisasi:
    • (C28) is_quantized(lhs) = is_quantized(result) and is_quantized(rhs).
    • (C29) Jika is_per_axis_quantized(rhs), maka quantization_dimension(rhs) = kernel_output_feature_dimension.
    • (C30) Jika is_per_axis_quantized(result), maka quantization_dimension(result) = output_feature_dimension.
    • Jika is_quantized(lhs):
    • (C31) storage_type(lhs) = storage_type(rhs).
    • (C32) expressed_type(lhs) = expressed_type(rhs) = expressed_type(result).
    • (C33) Jika is_per_tensor_quantized(rhs), maka is_per_tensor_quantized(result).
    • Jika !is_quantized(lhs):
    • (C34) element_type(lhs) = expressed_type(rhs) = element_type(result).

Contoh

// %lhs: [[
//        [
//          [1], [2], [5], [6]
//        ],
//        [
//          [3], [4], [7], [8]
//        ],
//        [
//          [10], [11], [14], [15]
//        ],
//        [
//          [12], [13], [16], [17]
//        ]
//      ]]
//
// %rhs: [
//        [[[1]], [[1]], [[1]]],
//        [[[1]], [[1]], [[1]]],
//        [[[1]], [[1]], [[1]]]
//       ]
%result = "stablehlo.convolution"(%lhs, %rhs) {
  window_strides = array<i64: 4, 4>,
  padding = dense<0> : tensor<2x2xi64>,
  lhs_dilation = array<i64: 2, 2>,
  rhs_dilation = array<i64: 1, 1>,
  window_reversal = array<i1: false, false>,
  // In the StableHLO dialect, dimension numbers are encoded via:
  // `[<input dimensions>]x[<kernel dimensions>]->[output dimensions]`.
  // "b" is batch dimension, "f" is feature dimension,
  // "i" is input feature dimension, "o" is output feature dimension,
  // "0/1/etc" are spatial dimensions.
  dimension_numbers = #stablehlo.conv<[b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f]>,
  batch_group_count = 1 : i64,
  feature_group_count = 1 : i64,
  precision_config = [#stablehlo<precision DEFAULT>, #stablehlo<precision DEFAULT>]
} : (tensor<1x4x4x1xi64>, tensor<3x3x1x1xi64>) -> tensor<1x2x2x1xi64>
// %result: [[
//            [[10], [26]],
//            [[46], [62]]
//          ]]

Contoh Lainnya

kosinus

Semantik

Melakukan operasi kosinus berbasis elemen pada tensor operand dan menghasilkan tensor result. Bergantung pada jenis elemen, lakukan tindakan berikut:

  • Untuk float: cos dari IEEE-754.
  • Untuk bilangan kompleks: kosinus kompleks.
  • Untuk jenis kuantisasi: dequantize_op_quantize(cosine, operand, type(result)).

Input

Label Nama Jenis Batasan
(I1) operand tensor floating point atau jenis kompleks atau tensor terkuantisasi per-tensor (C1)

Output

Nama Jenis Batasan
result tensor dari jenis floating point atau kompleks atau tensor terkuantisasi per tensor (C1)

Batasan

  • (C1) baseline_type(operand) = baseline_type(result).

Contoh

// %operand: [
//            [0.0, 1.57079632],       // [0, pi/2]
//            [3.14159265, 4.71238898] // [pi, 3pi/2]
//           ]
%result = "stablehlo.cosine"(%operand) : (tensor<2x2xf32>) -> tensor<2x2xf32>
// %result: [[1.0, 0.0], [-1.0, 0.0]]

 Contoh Lainnya

count_leading_zeros

Semantik

Melakukan penghitungan element-wise dari jumlah bit nol di awal dalam tensor operand dan menghasilkan tensor result.

Input

Label Nama Jenis Batasan
(I1) operand tensor dari jenis bilangan bulat (C1)

Output

Nama Jenis Batasan
result tensor dari jenis bilangan bulat (C1)

Batasan

  • (C1) type(operand) = type(result).

Contoh

// %operand: [[0, 1], [128, -1]]
%result = "stablehlo.count_leading_zeros"(%operand) : (tensor<2x2xi64>) -> tensor<2x2xi64>
// %result: [[64, 63], [56, 0]]

 Contoh Lainnya

custom_call

Semantik

Mengenkapsulasi operasi call_target_name yang ditentukan implementasi yang menggunakan inputs dan called_computations serta menghasilkan results. has_side_effect, backend_config, dan api_version dapat digunakan untuk memberikan metadata tambahan yang ditentukan implementasi.

Saat ini, operasi ini berisi kumpulan metadata yang cukup tidak teratur yang mencerminkan evolusi organik dari operasi yang setara di compiler XLA. Pada masa mendatang, kami berencana untuk menyatukan metadata ini (#741).

Input

Label Nama Jenis
(I1) inputs jumlah nilai variadik
(I2) call_target_name konstanta dari jenis string
(I3) has_side_effect konstanta dari jenis i1
(I4) backend_config konstanta jenis string atau kamus atribut
(I5) api_version konstanta dari jenis si32
(I6) called_computations jumlah konstanta variadik jenis string

Output

Nama Jenis
results jumlah nilai variadis

Contoh

%results = "stablehlo.custom_call"(%input0) {
  call_target_name = "foo",
  has_side_effect = false,
  backend_config = {bar = 42 : i32},
  api_version = 4 : i32,
  called_computations = [@foo]
} : (tensor<f64>) -> tensor<f64>

bagi

Semantik

Melakukan pembagian element-wise dari tensor dividen lhs dan pembagi rhs dan menghasilkan tensor result. Bergantung pada jenis elemen, lakukan tindakan berikut:

  • Untuk bilangan bulat: pembagian bilangan bulat yang menghasilkan hasil bagi aljabar dengan bagian pecahan yang dihapus.
  • Untuk float: division dari IEEE-754.
  • Untuk bilangan kompleks: pembagian kompleks.
  • Untuk jenis kuantisasi:
    • dequantize_op_quantize(divide, lhs, rhs, type(result)).

Input

Label Nama Jenis Batasan
(I1) lhs tensor bilangan bulat, floating point, atau kompleks, atau tensor terkuantisasi per tensor (C1)
(I2) rhs tensor bilangan bulat, floating point, atau kompleks, atau tensor terkuantisasi per tensor (C1)

Output

Nama Jenis Batasan
result tensor bilangan bulat, floating point, atau jenis kompleks atau tensor terkuantisasi per tensor (C1)

Batasan

  • (C1) baseline_type(lhs) = baseline_type(rhs) = baseline_type(result).

Contoh

// %lhs: [17.1, -17.1, 17.1, -17.1]
// %rhs: [3.0, 3.0, -3.0, -3.0]
%result = "stablehlo.divide"(%lhs, %rhs) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
// %result: [5.66666651, -5.66666651, -5.66666651, 5.66666651]

 Contoh Lainnya

dot_general

Semantik

Menghitung perkalian titik antara slice lhs dan slice rhs serta menghasilkan tensor result.

Secara lebih formal, result[result_index] = dot_product, dengan:

  • lhs_result_dimensions = [d for d in axes(lhs) and d not in lhs_batching_dimensions and d not in lhs_contracting_dimensions].
  • rhs_result_dimensions = [d for d in axes(rhs) and d not in rhs_batching_dimensions and d not in rhs_contracting_dimensions].
  • result_batching_index + result_lhs_index + result_rhs_index = result_index dengan size(result_batching_index) = size(lhs_batching_dimensions), size(result_lhs_index) = size(lhs_result_dimensions), dan size(result_rhs_index) = size(rhs_result_dimensions).
  • transposed_lhs = transpose(lhs, lhs_batching_dimensions + lhs_result_dimensions + lhs_contracting_dimensions).
  • transposed_lhs_slice = slice(transposed_lhs, result_batching_index + result_lhs_index + [:, ..., :]).
  • reshaped_lhs_slice = reshape(transposed_lhs_slice, dims(lhs, lhs_contracting_dimensions)).
  • transposed_rhs = transpose(rhs, rhs_batching_dimensions + rhs_result_dimensions + rhs_contracting_dimensions).
  • transposed_rhs_slice = slice(transposed_rhs, result_batching_index + result_rhs_index + [:, ..., :]).
  • reshaped_rhs_slice = reshape(transposed_rhs_slice, dims(rhs, rhs_contracting_dimensions)).
  • dot_product = reduce( inputs=[multiply(reshaped_lhs_slice, reshaped_rhs_slice)], init_values=[constant(0, element_type(result))], dimensions=range(size(lhs_contracting_dimensions)), body=lambda x, y: add(x, y)).

Untuk jenis kuantisasi, lakukan dequantize_op_quantize( lambda lhs, rhs: dot_general(lhs, rhs, lhs_batching_dimensions, rhs_batching_dimensions, lhs_contracting_dimensions, rhs_contracting_dimensions, precision_config), lhs, rhs, type(result)).

Untuk jenis kuantisasi campuran, lakukan hybrid_dequantize_then_op( lambda lhs, rhs: dot_general(lhs, rhs, lhs_batching_dimensions, rhs_batching_dimensions, lhs_contracting_dimensions, rhs_contracting_dimensions, precision_config), lhs, rhs).

precision_config mengontrol kompromi antara kecepatan dan akurasi untuk komputasi di backend akselerator. Ini dapat berupa salah satu dari berikut ini (saat ini, semantik nilai enum ini tidak ditentukan, tetapi kami berencana untuk mengatasinya di #755):

  • DEFAULT: Penghitungan tercepat, tetapi perkiraan yang paling tidak akurat terhadap angka asli.
  • HIGH: Penghitungan lebih lambat, tetapi perkiraan ke angka asli lebih akurat.
  • HIGHEST: Penghitungan paling lambat, tetapi perkiraan paling akurat terhadap angka asli.

DotAlgorithm menentukan properti utama algoritma yang digunakan untuk menerapkan operasi titik, yang juga menentukan presisi. Jika kolom atribut algoritma ditetapkan, precision_config harus berupa DEFAULT. DotAlgorithms tidak memiliki nilai default, karena parameter default ditentukan oleh implementasi. Dengan demikian, semua kolom algoritma titik dapat ditetapkan ke None untuk menentukan algoritma titik kosong, yang akan menggunakan nilai precision_config.

Kolom DotAlgorithm mencakup:

  • lhs_precision_type dan rhs_precision_type, presisi yang digunakan untuk membulatkan LHS dan RHS operasi. Jenis presisi tidak bergantung pada jenis penyimpanan input dan output.
  • accumulation_type presisi yang digunakan untuk akumulasi.
  • lhs_component_count, rhs_component_count, dan num_primitive_operations berlaku saat kita melakukan algoritma yang menguraikan LHS dan/atau RHS menjadi beberapa komponen dan melakukan beberapa operasi titik "primitif" pada nilai tersebut, biasanya untuk mengemulasi presisi yang lebih tinggi (misalnya Memanfaatkan Jenis Data Kecerdasan Buatan bfloat16 Untuk Komputasi Presisi Tinggi: bf16_6x, tf32). Untuk algoritma tanpa dekomposisi, nilai ini harus ditetapkan ke 1.
  • allow_imprecise_accumulation untuk menentukan apakah akumulasi dalam presisi yang lebih rendah diizinkan untuk beberapa langkah (misalnya, CUBLASLT_MATMUL_DESC_FAST_ACCUM).

Contoh atribut DotAlgorithm:

// Inputs are casted to tf32, and then accumulated in f32:
{lhs_precision_type = tf32,
 rhs_precision_type = tf32,
 accumulation_type = f32,
 lhs_component_count = 1,
 rhs_component_count = 1,
 num_primitive_operations = 1,
 allow_imprecise_accumulation = false}


// bf16_6x: each input is decomposed to 3 bf16 components, then 6 dot operations are done on those components, and the result is accumulated in f32.
{lhs_precision_type = bf16,
 rhs_precision_type = bf16,
 accumulation_type = f32,
 lhs_component_count = 3,
 rhs_component_count = 3,
 num_primitive_operations = 6,
 allow_imprecise_accumulation = false}


// Inputs are (casted to) f8e5m2, and we accumulate in f32, but for some steps we may accumulate in lower precision.
{lhs_precision_type = f8e5m2,
 rhs_precision_type = f8e5m2,
 accumulation_type = f32,
 lhs_component_count = 1,
 rhs_component_count = 1,
 num_primitive_operations = 1,
 allow_imprecise_accumulation = true}

Implementasi bebas untuk memutuskan kombinasi mana yang didukung. Secara umum, tidak dijamin bahwa setiap algoritma didukung di setiap jenis akselerator oleh konsumen StableHLO. Jika algoritma tertentu tidak didukung, error harus ditampilkan, bukan kembali ke alternatif. Verifikasi StableHLO akan memberikan verifikasi upaya terbaik, yang mencegah algoritma yang tidak diketahui didukung di hardware mana pun.

Lihat xla_data.proto > Algorithm untuk mengetahui beberapa nilai algoritma yang didukung. Tiket #2483 berisi rencana untuk membuat dokumen terpusat tentang algoritma yang didukung oleh backend.

Input

Label Nama Jenis Batasan
(I1) lhs tensor atau tensor terkuantisasi per tensor (C5-C6), (C9-C10), (C12-C14), (C17-C18), (C20)
(I2) rhs tensor atau tensor terkuantisasi (C7-C10), (C12-C20)
(I3) lhs_batching_dimensions Konstanta tensor 1 dimensi dari jenis si64 (C1), (C3), (C5), (C9), (C12)
(I4) rhs_batching_dimensions Konstanta tensor 1 dimensi dari jenis si64 (C1), (C4), (C7), (C9)
(I5) lhs_contracting_dimensions Konstanta tensor 1 dimensi dari jenis si64 (C2), (C3), (C6), (C10)
(I6) rhs_contracting_dimensions Konstanta tensor 1 dimensi dari jenis si64 (C2), (C4), (C8), (C10), (C16)
(I7) precision_config jumlah variabel enum DEFAULT, HIGH, dan HIGHEST (C11), C21)
(I8) lhs_precision_type FloatType atau TensorFloat32 (C21)
(I9) rhs_precision_type FloatType atau TensorFloat32 (C21)
(I10) accumulation_type FloatType atau TensorFloat32 (C21)
(I11) lhs_component_count konstanta dari jenis si32 (C21), (C22)
(I12) rhs_component_count konstanta dari jenis si32 (C21), C23
(I13) num_primitive_operations konstanta dari jenis si32 (C21), C24
(I14) allow_imprecise_accumulation konstanta dari jenis bool (C21)

Output

Nama Jenis Batasan
result tensor atau quantized tensor (C12), (C14), (C18-C20)

Batasan

  • (C1) size(lhs_batching_dimensions) = size(rhs_batching_dimensions).
  • (C2) size(lhs_contracting_dimensions) = size(rhs_contracting_dimensions).
  • (C3) is_unique(lhs_batching_dimensions + lhs_contracting_dimensions).
  • (C4) is_unique(rhs_batching_dimensions + rhs_contracting_dimensions).
  • (C5) 0 <= lhs_batching_dimensions < rank(lhs).
  • (C6) 0 <= lhs_contracting_dimensions < rank(lhs).
  • (C7) 0 <= rhs_batching_dimensions < rank(rhs).
  • (C8) 0 <= rhs_contracting_dimensions < rank(rhs).
  • (C9) dim(lhs, lhs_batching_dimensions...) = dim(rhs, rhs_batching_dimensions...).
  • (C10) dim(lhs, lhs_contracting_dimensions...) = dim(rhs, rhs_contracting_dimensions...).
  • (C11) size(precision_config) = 2.
  • (C12) shape(result) = dim(lhs, lhs_batching_dimensions) + dim(lhs, lhs_result_dimensions) + dim(rhs, rhs_result_dimensions).
  • Jika operasi menggunakan tensor yang tidak dikuantisasi:
    • (C13) element_type(lhs) = element_type(rhs).
  • Jika operasi menggunakan tensor kuantisasi:
    • (C14) is_quantized(lhs) = is_quantized(result) and is_quantized(rhs).
    • (C15) zero_points(rhs) = 0.
    • (C16) Jika is_per_axis_quantized(rhs), maka quantization_dimension(rhs) tidak ada dalam rhs_contracting_dimensions.
    • Jika is_quantized(lhs):
    • (C17) storage_type(lhs) = storage_type(rhs).
    • (C18) expressed_type(lhs) = expressed_type(rhs) = expressed_type(result).
    • (C19) Jika is_per_tensor_quantized(rhs), maka is_per_tensor_quantized(result).
    • Jika !is_quantized(lhs):
    • (C20) element_type(lhs) = expressed_type(rhs) = element_type(result).
  • Jika !is_empty_algorithm(lhs_precision_type, rhs_precision_type, accumulation_type, lhs_component_count, rhs_component_count, num_primitive_operations allow_imprecise_accumulation):
    • (C21) precision_config... = DEFAULT.
    • (C22) 0 < lhs_component_count.
    • (C23) 0 < rhs_component_count.
    • (C24) 0 < num_primitive_operations.

Contoh

// %lhs: [
//        [[1, 2],
//         [3, 4]],
//        [[5, 6],
//         [7, 8]]
//       ]
// %rhs: [
//        [[1, 0],
//         [0, 1]],
//        [[1, 0],
//         [0, 1]]
//       ]
%result = "stablehlo.dot_general"(%lhs, %rhs) {
  dot_dimension_numbers = #stablehlo.dot<
    lhs_batching_dimensions = [0],
    rhs_batching_dimensions = [0],
    lhs_contracting_dimensions = [2],
    rhs_contracting_dimensions = [1]
  >,
  precision_config = [#stablehlo<precision DEFAULT>, #stablehlo<precision DEFAULT>],
  algorithm = #stablehlo.dot_algorithm<
    lhs_precision_type = tf32,
    rhs_precision_type = tf32,
    accumulation_type = f32,
    lhs_component_count = 1,
    rhs_component_count = 1,
    num_primitive_operations = 1,
    allow_imprecise_accumulation = false
  >
} : (tensor<2x2x2xi64>, tensor<2x2x2xi64>) -> tensor<2x2x2xi64>
// %result: [
//           [[1, 2],
//            [3, 4]],
//           [[5, 6],
//            [7, 8]]
//          ]

Contoh Lainnya

dynamic_broadcast_in_dim

Semantik

Operasi ini secara fungsional identik dengan operasi broadcast_in_dim, tetapi bentuk hasilnya ditetapkan secara dinamis melalui output_dimensions.

Operasi tersebut juga menerima atribut opsional known_expanding_dimensions, known_nonexpanding_dimensions untuk menyatakan pengetahuan statis tentang perilaku dimensi yang diperluas. Jika tidak ditentukan, semua dimensi diasumsikan dapat diperluas.

Input

Label Nama Jenis Batasan
(I1) operand tensor atau tensor terkuantisasi (C1-C2), (C5-C6), (C9)
(I2) output_dimensions Tensor 1-dimensi dari jenis integer (C7)
(I3) broadcast_dimensions Tensor konstan 1 dimensi dari jenis bilangan bulat (C2-C6)
(I4) known_expanding_dimensions Tensor konstanta 1 dimensi dari jenis integer (C8-C9)
(I5) known_nonexpanding_dimensions Tensor konstan 1 dimensi dari jenis bilangan bulat (C8-C9)

Output

Nama Jenis Batasan
result tensor atau tensor terkuantisasi (C1), (C3), (C5-C7)

Batasan

  • (C1) element_type(result) diberikan oleh:
    • element_type(operand), jika !is_per_axis_quantized(operand).
    • element_type(operand) kecuali bahwa quantization_dimension(operand), scales(operand), dan zero_points(operand) mungkin berbeda dari quantization_dimension(result), scales(result), dan zero_points(result) masing-masing.
  • (C2) size(broadcast_dimensions) = rank(operand).
  • (C3) 0 <= broadcast_dimensions < rank(result).
  • (C4) is_unique(broadcast_dimensions).
  • (C5) Untuk semua d di axes(operand):
    • dim(operand, d) = 1 atau
    • dim(operand, d) = dim(result, broadcast_dimensions[d]).
  • (C6) Jika is_per_axis_quantized(result):
    • quantization_dimension(result) = broadcast_dimensions[quantization_dimension(operand)].
    • Jika dim(operand, quantization_dimension(operand)) = 1, maka scales(result)[i] = scales(operand)[0] and zero_points(result)[i] = zero_points(operand)[0] for i in range(dim(result, quantization_dimension(result))).
  • (C7) size(output_dimensions) = rank(result).
  • (C8) is_unique(known_expanding_dimensions + known_nonexpanding_dimensions).
  • (C9) 0 <= known_expanding_dimensions < rank(operand).
  • (C10) 0 <= known_nonexpanding_dimensions < rank(operand).

Contoh

// %operand: [
//            [1, 2, 3]
//           ]
%operand = stablehlo.constant dense<[[1, 2, 3]]> : tensor<1x3xi64>
%output_dimensions = stablehlo.constant dense<[2, 3, 2]> : tensor<3xi64>
%result = "stablehlo.dynamic_broadcast_in_dim"(%operand, %output_dimensions) {
  broadcast_dimensions = array<i64: 2, 1>,
  known_expanding_dimensions = array<i64: 0>,
  known_nonexpanding_dimensions = array<i64: 1>
} : (tensor<1x3xi64>, tensor<3xi64>) -> tensor<2x3x2xi64>
// %result: [
//            [
//             [1, 1],
//             [2, 2],
//             [3, 3]
//            ],
//            [
//             [1, 1],
//             [2, 2],
//             [3, 3]
//            ]
//          ]

Contoh Lainnya

dynamic_conv

Semantik

Operasi ini secara fungsional identik dengan op konvolusi, tetapi padding ditentukan secara dinamis melalui padding.

Input

Label Nama Jenis Batasan
(I1) lhs tensor atau tensor terkuantisasi per-tensor (C1), (C10-C11), (C14) (C25), (C26-C27), (C30-C31), (C33)
(I2) rhs tensor atau tensor terkuantisasi (C1), (C14-C16), (C26-C28), (C30-C33)
(I3) padding Tensor 2 dimensi dari jenis bilangan bulat (C4)
(I4) window_strides Konstanta tensor 1 dimensi dari jenis si64 (C2-C3)
(I5) lhs_dilation Konstanta tensor 1 dimensi jenis si64 (C5-C6)
(I6) rhs_dilation Konstanta tensor 1 dimensi dari jenis si64 (C7-C8)
(I7) window_reversal Konstanta tensor 1 dimensi dari jenis i1 (C9)
(I8) input_batch_dimension konstanta dari jenis si64 (C10), (C13)
(I9) input_feature_dimension konstanta jenis si64 (C11), (C13-C14)
(I10) input_spatial_dimensions Konstanta tensor 1 dimensi dari jenis si64 (C12), C13)
(I11) kernel_input_feature_dimension konstanta jenis si64 (C14), (C18)
(I12) kernel_output_feature_dimension konstanta dari jenis si64 (C15-C16), (C18), (C28)
(I13) kernel_spatial_dimensions Konstanta tensor 1 dimensi dari jenis si64 (C17-C18)
(I14) output_batch_dimension konstanta dari jenis si64 (C20)
(I15) output_feature_dimension konstanta jenis si64 (C20), (C29)
(I16) output_spatial_dimensions Konstanta tensor 1 dimensi jenis si64 (C19-C20)
(I17) feature_group_count konstanta dari jenis si64 (C11), (C14), (C16), (C21), (C23)
(I18) batch_group_count konstanta dari jenis si64 (C10), (C15), (C22), (C23)
(I19) precision_config jumlah variabel enum DEFAULT, HIGH, dan HIGHEST (C24)

Output

Nama Jenis Batasan
result tensor atau tensor terkuantisasi (C25-C27), (C29), (C31-C33)

Batasan

  • (C1) N = rank(lhs) = rank(rhs).
  • (C2) size(window_strides) = N - 2.
  • (C3) 0 < window_strides.
  • (C4) shape(padding) = [N - 2, 2].
  • (C5) size(lhs_dilation) = N - 2.
  • (C6) 0 < lhs_dilation.
  • (C7) size(rhs_dilation) = N - 2.
  • (C8) 0 < rhs_dilation.
  • (C9) size(window_reversal) = N - 2.
  • (C10) dim(lhs, input_batch_dimension) % batch_group_count = 0.
  • (C11) dim(lhs, input_feature_dimension) % feature_group_count = 0.
  • (C12) size(input_spatial_dimensions) = N - 2.
  • (C13) Dengan input_dimensions = [input_batch_dimension] + input_spatial_dimensions + [input_feature_dimension]:
    • is_unique(input_dimensions).
    • 0 <= input_dimensions < N.
  • (C14) dim(rhs, kernel_input_feature_dimension) = dim(lhs, input_feature_dimension) / feature_group_count.
  • (C15) dim(rhs, kernel_output_feature_dimension) % batch_group_count = 0.
  • (C16) dim(rhs, kernel_output_feature_dimension) % feature_group_count = 0.
  • (C17) size(kernel_spatial_dimensions) = N - 2.
  • (C18) Dengan kernel_dimensions = kernel_spatial_dimensions + [kernel_input_feature_dimension] + [kernel_output_feature_dimension]:
    • is_unique(kernel_dimensions).
    • 0 <= kernel_dimensions < N.
  • (C19) size(output_spatial_dimensions) = N - 2.
  • (C20) Diberikan output_dimensions = [output_batch_dimension] + output_spatial_dimensions + [output_feature_dimension]:
    • is_unique(output_dimensions).
    • 0 <= output_dimensions < N.
  • (C21) 0 < feature_group_count.
  • (C22) 0 < batch_group_count.
  • (C23) feature_group_count = 1 or batch_group_count = 1.
  • (C24) size(precision_config) = 2.
  • (C25) dim(result, result_dim) ditentukan sebagai:
    • dim(lhs, input_batch_dimension) / batch_group_count if result_dim = output_batch_dimension.
    • dim(rhs, kernel_output_feature_dimension) jika result_dim = output_feature_dimension.
    • num_windows jika tidak, dengan:
    • output_spatial_dimensions[spatial_dim] = result_dim.
    • lhs_dim = input_spatial_dimensions[spatial_dim].
    • rhs_dim = kernel_spatial_dimensions[spatial_dim].
    • dilated_input_shape[lhs_dim] = dim(lhs, lhs_dim) = 0 ? 0 : (dim(lhs, lhs_dim) - 1) * lhs_dilation[spatial_dim] + 1.
    • padded_input_shape[lhs_dim] = padding[spatial_dim, 0] + dilated_input_shape[lhs_dim] + padding[spatial_dim, 1].
    • dilated_window_shape[lhs_dim] = dim(rhs, rhs_dim) = 0 ? 0 : (dim(rhs, rhs_dim) - 1) * rhs_dilation[spatial_dim] + 1.
    • is_empty_window[lhs_dim] = padded_input_shape[lhs_dim] = 0 || dilated_window_shape[lhs_dim] > padded_input_shape[lhs_dim].
    • num_windows = is_empty_window[lhs_dim] ? 0 : floor((padded_input_shape[lhs_dim] - dilated_window_shape[lhs_dim]) / window_strides[spatial_dim]) + 1.
  • (C26) rank(result) = N.
  • Jika operasi menggunakan tensor yang tidak dikuantisasi:
    • (C27) element_type(lhs) = element_type(rhs) = element_type(result).
  • Jika operasi menggunakan tensor kuantisasi:
    • (C28) is_quantized(lhs) = is_quantized(result) and is_quantized(rhs).
    • (C29) Jika is_per_axis_quantized(rhs), maka quantization_dimension(rhs) = kernel_output_feature_dimension.
    • (C30) Jika is_per_axis_quantized(result), maka quantization_dimension(result) = output_feature_dimension.
    • Jika is_quantized(lhs):
    • (C31) storage_type(lhs) = storage_type(rhs).
    • (C32) expressed_type(lhs) = expressed_type(rhs) = expressed_type(result).
    • (C33) Jika is_per_tensor_quantized(rhs), maka is_per_tensor_quantized(result).
    • Jika !is_quantized(lhs):
    • (C34) element_type(lhs) = expressed_type(rhs) = element_type(result).

Contoh

// %lhs: [[
//        [[1], [2], [5], [6]],
//        [[3], [4], [7], [8]],
//        [[10], [11], [14], [15]],
//        [[12], [13], [16], [17]]
//      ]]
//
// %rhs: [
//         [[[1]], [[1]], [[1]]],
//         [[[1]], [[1]], [[1]]],
//         [[[1]], [[1]], [[1]]]
//        ]
// %padding: [[1, 1],
//            [1, 1]]
%result = "stablehlo.dynamic_conv"(%lhs, %rhs, %padding) {
  window_strides = array<i64: 4, 4>,
  lhs_dilation = array<i64: 2, 2>,
  rhs_dilation = array<i64: 1, 1>,
  window_reversal = array<i1: false, false>,
  dimension_numbers = #stablehlo.conv<raw
    input_batch_dimension = 0,
    input_feature_dimension = 3,
    input_spatial_dimensions = [0, 1],
    kernel_input_feature_dimension = 2,
    kernel_output_feature_dimension = 3,
    kernel_spatial_dimensions = [0, 1],
    output_batch_dimension = 0,
    output_feature_dimension = 3,
    output_spatial_dimensions = [1, 2]
  >,
  feature_group_count = 1 : i64,
  batch_group_count = 1 : i64,
  precision_config = [#stablehlo<precision DEFAULT>, #stablehlo<precision DEFAULT>]
} : (tensor<1x4x4x1xi64>, tensor<3x3x1x1xi64>, tensor<2x2xi64>) -> tensor<1x2x2x1xi64>
// %result: [[
//            [[1], [5]],
//            [[10], [14]]
//          ]]

 Contoh Lainnya

dynamic_gather

Semantik

Operasi ini secara fungsional identik dengan operasi mengumpulkan, dengan slice_sizes yang ditentukan secara dinamis sebagai nilai.

Input

Label Nama Jenis Batasan
(I1) operand tensor atau tensor terkuantisasi per tensor (C1), (C7), (C10-C12), (C14)
(I2) start_indices tensor tipe integer (C2), (C3), (C13)
(I3) slice_sizes Tensor 1 dimensi dari jenis bilangan bulat (C8), (C11-C13)
(I4) offset_dims Konstanta tensor 1 dimensi dari jenis si64 (C1), (C4-C5), (C13)
(I5) collapsed_slice_dims Konstanta tensor 1 dimensi dari jenis si64 (C1), (C6-C8), (C13)
(I6) start_index_map Konstanta tensor 1 dimensi jenis si64 (C3), (C9), (C10)
(I7) index_vector_dim konstanta dari jenis si64 (C2), (C3), (C13)
(I8) indices_are_sorted konstanta jenis i1

Output

Nama Jenis Batasan
result tensor atau tensor terkuantisasi per-tensor (C5), (C13-C14)

Batasan

  • (C1) rank(operand) = size(offset_dims) + size(collapsed_slice_dims).
  • (C2) 0 <= index_vector_dim <= rank(start_indices).
  • (C3) size(start_index_map) = index_vector_dim < rank(start_indices) ? dim(start_indices, index_vector_dim) : 1.
  • (C4) is_unique(offset_dims) and is_sorted(offset_dims).
  • (C5) 0 <= offset_dims < rank(result).
  • (C6) is_unique(collapsed_slice_dims) and is_sorted(collapsed_slice_dims).
  • (C7) 0 <= collapsed_slice_dims < rank(operand).
  • (C8) slice_sizes[collapsed_slice_dims...] <= 1.
  • (C9) is_unique(start_index_map).
  • (C10) 0 <= start_index_map < rank(operand).
  • (C11) size(slice_sizes) = rank(operand).
  • (C12) 0 <= slice_sizes <= shape(operand).
  • (C13) shape(result) = combine(batch_dim_sizes, offset_dim_sizes) dengan:
    • batch_dim_sizes = shape(start_indices), kecuali ukuran dimensi start_indices yang sesuai dengan index_vector_dim tidak disertakan.
    • offset_dim_sizes = shape(slice_sizes), kecuali ukuran dimensi di slice_sizes yang sesuai dengan collapsed_slice_dims tidak disertakan.
    • combine menempatkan batch_dim_sizes pada sumbu yang sesuai dengan batch_dims dan offset_dim_sizes pada sumbu yang sesuai dengan offset_dims.
  • (C14) element_type(operand) = element_type(result).

Contoh

// %operand: [
//            [[1, 2], [3, 4], [5, 6], [7, 8]],
//            [[9, 10],[11, 12], [13, 14], [15, 16]],
//            [[17, 18], [19, 20], [21, 22], [23, 24]]
//           ]
// %start_indices: [
//                  [[0, 0], [1, 0], [2, 1]],
//                  [[0, 1], [1, 1], [0, 2]]
//                 ]
// %slize_sizes: [1, 2, 2]
%result = "stablehlo.dynamic_gather"(%operand, %start_indices, %slize_sizes) {
  dimension_numbers = #stablehlo.gather<
    offset_dims = [2, 3],
    collapsed_slice_dims = [0],
    start_index_map = [1, 0],
    index_vector_dim = 2>,
  indices_are_sorted = false
} : (tensor<3x4x2xi64>, tensor<2x3x2xi64>, tensor<3xi64>) -> tensor<2x3x2x2xi64>
// %result: [
//            [
//              [[1, 2], [3, 4]],
//              [[3, 4], [5, 6]],
//              [[13, 14], [15, 16]]
//            ],
//            [
//              [[9, 10], [11, 12]],
//              [[11, 12], [13, 14]],
//              [[17, 18], [19, 20]]
//            ]
//          ]

 Contoh Lainnya

dynamic_iota

Semantik

Operasi ini secara fungsional identik dengan operasi iota, tetapi bentuk hasilnya ditetapkan secara dinamis melalui output_shape.

Input

Label Nama Jenis Batasan
(I1) output_shape Tensor 1-dimensi dari jenis integer (C1), (C2)
(I2) iota_dimension si64 (C1)

Output

Nama Jenis Batasan
result tensor bilangan bulat, floating point, atau jenis kompleks atau tensor terkuantisasi per tensor (C2)

Batasan

  • (C1) 0 <= iota_dimension < size(output_shape).
  • (C2) rank(result) = size(output_shape).

Contoh

%output_shape = stablehlo.constant dense<[4, 5]> : tensor<2xi64>
%result = "stablehlo.dynamic_iota"(%output_shape) {
  iota_dimension = 0 : i64
} : (tensor<2xi64>) -> tensor<4x5xi64>
// %result: [
//           [0, 0, 0, 0, 0],
//           [1, 1, 1, 1, 1],
//           [2, 2, 2, 2, 2],
//           [3, 3, 3, 3, 3]
//          ]

 Contoh Lainnya

dynamic_pad

Semantik

Operasi ini secara fungsional identik dengan op pad, tetapi dengan edge_padding_low, edge_padding_high, dan interior_padding ditentukan secara dinamis sebagai nilai.

Input

Label Nama Jenis Batasan
(I1) operand tensor atau tensor terkuantisasi per tensor (C1), (C2), (C4)
(I2) padding_value Tensor 0 dimensi atau tensor terkuantisasi per tensor (C1)
(I3) edge_padding_low Tensor 1 dimensi dari jenis bilangan bulat (C1), C4
(I4) edge_padding_high Tensor 1 dimensi dari jenis bilangan bulat (C1), (C4)
(I5) interior_padding Tensor 1 dimensi dari jenis bilangan bulat (C2-C4)

Output

Nama Jenis Batasan
result tensor atau tensor terkuantisasi per tensor (C3-C6)

Batasan

  • (C1) element_type(operand) = element_type(padding_value) = element_type(result).
  • (C2) size(edge_padding_low) = size(edge_padding_high) = size(interior_padding) = rank(operand).
  • (C3) 0 <= interior_padding.
  • (C4) shape(result) = shape(operand) + edge_padding_low + max(shape(operand) - 1, 0) * interior_padding + edge_padding_high.

Contoh

// %operand: [
//            [1, 2, 3],
//            [4, 5, 6]
//           ]
// %padding_value: 0
// %edge_padding_low: [0, 1]
// %edge_padding_high: [2, 1]
// %interior_padding: [1, 2]
%result = "stablehlo.dynamic_pad"(%operand, %padding_value,
  %edge_padding_low, %edge_padding_high, %interior_padding
) : (tensor<2x3xi64>, tensor<i64>, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor<5x9xi64>
// %result: [
//           [0, 1, 0, 0, 2, 0, 0, 3, 0],
//           [0, 0, 0, 0, 0, 0, 0, 0, 0],
//           [0, 4, 0, 0, 5, 0, 0, 6, 0],
//           [0, 0, 0, 0, 0, 0, 0, 0, 0],
//           [0, 0, 0, 0, 0, 0, 0, 0, 0]
//          ]

 Contoh Lainnya

dynamic_reshape

Semantik

Operasi ini secara fungsional identik dengan operasi reshape, tetapi bentuk hasilnya ditentukan secara dinamis melalui output_shape.

Input

Label Nama Jenis Batasan
(I1) operand tensor atau tensor terkuantisasi (C1-C3)
(I2) output_shape Tensor 1 dimensi dari jenis bilangan bulat (C4)

Output

Nama Jenis Batasan
result tensor atau tensor terkuantisasi (C1-C4)

Batasan

  • (C1) element_type(result) diberikan oleh:
    • element_type(operand), jika !is_per_axis_quantized(operand).
    • element_type(operand) kecuali bahwa quantization_dimension(operand) dan quantization_dimension(result) mungkin berbeda, jika tidak.
  • (C2) size(operand) = size(result).
  • (C3) Jika is_per_axis_quantized(operand):
    • reduce(dims(operand, [0, 1, ..., quantization_dimension(operand) - 1]), init_values=1, dimensions=[0], body=lambda x, y: x * y) = reduce(dims(result, [0, 1, ..., quantization_dimension(result) - 1]), init_values=1, dimensions=[0], body=lambda x, y: x * y).
    • dim(operand, quantization_dimension(operand)) = dim(result, quantization_dimension(result)).
    • reduce(dims(operand, [quantization_dimension(operand) + 1, ..., rank(operand) - 1]), init_values=1, dimensions=[0], body=lambda x, y: x * y) = reduce(dims(result, [quantization_dimension(result) + 1, ..., rank(result) - 1]), init_values=1, dimensions=[0], body=lambda x, y: x * y).
  • (C4) size(output_shape) = rank(result).

Contoh

// %operand: [[1, 2, 3], [4, 5, 6]]
// %output_shape: [3, 2]
%result = "stablehlo.dynamic_reshape"(%operand, %output_shape) : (tensor<2x3xi64>, tensor<2xi64>) -> tensor<3x2xi64>
// %result: [[1, 2], [3, 4], [5, 6]]

 Contoh Lainnya

dynamic_slice

Semantik

Mengekstrak slice dari operand menggunakan indeks awal yang dihitung secara dinamis dan menghasilkan tensor result. start_indices berisi indeks awal slice untuk setiap dimensi yang dapat disesuaikan, dan slice_sizes berisi ukuran slice untuk setiap dimensi. Secara lebih formal, result[result_index] = operand[operand_index] dengan:

  • adjusted_start_indices = clamp(0, start_indices, shape(operand) - slice_sizes).
  • operand_index = adjusted_start_indices + result_index.

Input

Label Nama Jenis Batasan
(I1) operand tensor atau tensor terkuantisasi per tensor (C1), (C2), (C4)
(I2) start_indices jumlah variabel dari tensor 0 dimensi dengan jenis bilangan bulat (C2), (C3)
(I3) slice_sizes Konstanta tensor 1 dimensi dari jenis si64 (C2), (C4), (C5)

Output

Nama Jenis Batasan
result tensor atau tensor terkuantisasi per tensor (C1), (C5)

Batasan

  • (C1) element_type(operand) = element_type(result).
  • (C2) size(start_indices) = size(slice_sizes) = rank(operand).
  • (C3) same(type(start_indices...)).
  • (C4) 0 <= slice_sizes <= shape(operand).
  • (C5) shape(result) = slice_sizes.

Contoh

// %operand: [
//            [0, 0, 1, 1],
//            [0, 0, 1, 1],
//            [0, 0, 0, 0],
//            [0, 0, 0, 0]
//           ]
// %start_indices0: -1
// %start_indices1: 3
%result = "stablehlo.dynamic_slice"(%operand, %start_indices0, %start_indices1) {
  slice_sizes = array<i64: 2, 2>
} : (tensor<4x4xi32>, tensor<i64>, tensor<i64>) -> tensor<2x2xi32>
// %result: [
//           [1, 1],
//           [1, 1]
//          ]

Contoh Lainnya

dynamic_update_slice

Semantik

Menghasilkan tensor result yang sama dengan tensor operand, kecuali bahwa slice yang dimulai dari start_indices diperbarui dengan nilai di update. Secara lebih formal, result[result_index] didefinisikan sebagai:

  • update[update_index] jika 0 <= update_index < shape(update) di mana:
    • adjusted_start_indices = clamp(0, start_indices, shape(operand) - shape(update)).
    • update_index = result_index - adjusted_start_indices.
  • operand[result_index] sebaliknya.

Input

Label Nama Jenis Batasan
(I1) operand tensor atau tensor terkuantisasi per tensor (C1-C4), (C6)
(I2) update tensor atau tensor terkuantisasi per-tensor (C2), (C3), (C6)
(I3) start_indices jumlah variabel dari tensor 0 dimensi dengan jenis bilangan bulat (C4), (C5)

Output

Nama Jenis Batasan
result tensor atau tensor terkuantisasi per tensor (C1)

Batasan

  • (C1) type(operand) = type(result).
  • (C2) element_type(update) = element_type(operand).
  • (C3) rank(update) = rank(operand).
  • (C4) size(start_indices) = rank(operand).
  • (C5) same(type(start_indices...)).
  • (C6) 0 <= shape(update) <= shape(operand).

Contoh

// %operand: [
//            [1, 1, 0, 0],
//            [1, 1, 0, 0],
//            [1, 1, 1, 1],
//            [1, 1, 1, 1]
//           ]
// %update: [
//           [1, 1],
//           [1, 1]
//          ]
// %start_indices0: -1
// %start_indices1: 3
%result = "stablehlo.dynamic_update_slice"(%operand, %update, %start_indices0, %start_indices1)
  : (tensor<4x4xi32>, tensor<2x2xi32>, tensor<i64>, tensor<i64>) -> tensor<4x4xi32>
// %result: [
//           [1, 1, 1, 1],
//           [1, 1, 1, 1],
//           [1, 1, 1, 1],
//           [1, 1, 1, 1]
//          ]

Contoh Lainnya

berpangkat

Semantik

Melakukan operasi eksponensial element-wise pada tensor operand dan menghasilkan tensor result. Bergantung pada jenis elemen, lakukan tindakan berikut:

  • Untuk float: exp dari IEEE-754.
  • Untuk bilangan kompleks: eksponensial kompleks.
  • Untuk jenis kuantisasi: dequantize_op_quantize(exponential, operand, type(result)).

Input

Label Nama Jenis Batasan
(I1) operand tensor floating point atau jenis kompleks atau tensor terkuantisasi per-tensor (C1)

Output

Nama Jenis Batasan
result tensor dari jenis floating point atau kompleks atau tensor terkuantisasi per tensor (C1)

Batasan

  • (C1) baseline_type(operand) = baseline_type(result).

Contoh

// %operand: [[0.0, 1.0], [2.0, 3.0]]
%result = "stablehlo.exponential"(%operand) : (tensor<2x2xf64>) -> tensor<2x2xf64>
// %result: [[1.0, 2.7182818284590451], [7.3890560989306504, 20.085536923187668]]

 Contoh Lainnya

eksponensial_minus_satu

Semantik

Menjalankan eksponensial element-wise dikurangi satu operasi pada tensor operand dan menghasilkan tensor result. Bergantung pada jenis elemen, lakukan tindakan berikut:

  • Untuk float: expm1 dari IEEE-754.
  • Untuk bilangan kompleks: eksponensial kompleks dikurangi satu.
  • Untuk jenis kuantisasi: dequantize_op_quantize(exponential_minus_one, operand, type(result)).

Input

Label Nama Jenis Batasan
(I1) operand tensor floating point atau jenis kompleks atau tensor terkuantisasi per-tensor (C1)

Output

Nama Jenis Batasan
result tensor dari jenis floating point atau kompleks atau tensor terkuantisasi per tensor (C1)

Batasan

  • (C1) baseline_type(operand) = baseline_type(result).

Contoh

// %operand: [0.0, 1.0]
%result = "stablehlo.exponential_minus_one"(%operand) : (tensor<2xf64>) -> tensor<2xf64>
// %result: [0.0, 1.71828187]

 Contoh Lainnya

fft

Semantik

Melakukan transformasi Fourier maju dan balik untuk input/output real dan kompleks.

fft_type adalah salah satu dari berikut ini:

  • FFT: Meneruskan FFT kompleks ke kompleks.
  • IFFT: FFT kompleks-ke-kompleks terbalik.
  • RFFT: Meneruskan FFT real-to-complex.
  • IRFFT: FFT real-ke-kompleks terbalik (yaitu yang kompleks, menampilkan nilai nyata).

Secara lebih formal, dengan fungsi fft yang menggunakan tensor 1 dimensi dari jenis kompleks sebagai input, menghasilkan tensor 1 dimensi dari jenis yang sama sebagai output dan menghitung transformasi Fourier diskret:

Untuk fft_type = FFT, result ditentukan sebagai hasil akhir dari serangkaian komputasi L dengan L = size(fft_length). Misalnya, untuk L = 3:

  • result1[i0, ..., :] = fft(operand[i0, ..., :]).
  • result2[i0, ..., :, iR-1] = fft(result1[i0, ..., :, iR-1]).
  • result[i0, ..., :, iR-2, iR-1] = fft(result2[i0, ..., :, iR-2, iR-1]).

Selanjutnya, mengingat fungsi ifft yang memiliki tanda tangan jenis yang sama dan menghitung invers dari fft:

Untuk fft_type = IFFT, result ditentukan sebagai invers komputasi untuk fft_type = FFT. Misalnya, untuk L = 3:

  • result1[i0, ..., :, iR-2, iR-1] = ifft(operand[i0, ..., :, iR-2, iR-1]).
  • result2[i0, ..., :, iR-1] = ifft(result1[i0, ..., :, iR-1]).
  • result[i0, ..., :] = ifft(result2[i0, ..., :]).

Selain itu, dengan fungsi rfft yang menggunakan tensor 1 dimensi dari jenis floating point, menghasilkan tensor 1 dimensi dari jenis kompleks dari semantik floating point yang sama dan berfungsi sebagai berikut:

  • rfft(real_operand) = truncated_result di mana
  • complex_operand... = (real_operand..., 0.0).
  • complex_result = fft(complex_operand).
  • truncated_result = complex_result[:(rank(complex_result) / 2 + 1)].

(Saat transformasi Fourier diskret dihitung untuk operand nyata, elemen N/2 + 1 pertama dari hasil akan menentukan sisa hasil secara tidak ambigu, sehingga hasil rfft akan terpotong untuk menghindari komputasi elemen yang redundan).

Untuk fft_type = RFFT, result ditentukan sebagai hasil akhir dari serangkaian komputasi L dengan L = size(fft_length). Misalnya, untuk L = 3:

  • result1[i0, ..., :] = rfft(operand[i0, ..., :]).
  • result2[i0, ..., :, iR-1] = fft(result1[i0, ..., :, iR-1]).
  • result[i0, ..., :, iR-2, iR-1] = fft(result2[i0, ..., :, iR-2, iR-1]).

Terakhir, dengan fungsi irfft yang memiliki tanda tangan jenis yang sama dan menghitung invers dari rfft:

Untuk fft_type = IRFFT, result ditentukan sebagai invers komputasi untuk fft_type = RFFT. Misalnya, untuk L = 3:

  • result1[i0, ..., :, iR-2, iR-1] = ifft(operand[i0, ..., :, iR-2, iR-1]).
  • result2[i0, ..., :, iR-1] = ifft(result1[i0, ..., :, iR-1]).
  • result[i0, ..., :] = irfft(result2[i0, ..., :]).

Input

Label Nama Jenis Batasan
(I1) operand tensor dari jenis floating point atau kompleks (C1), (C2), (C4), (C5)
(I2) fft_type enum FFT, IFFT, RFFT, dan IRFFT (C2), (C5)
(I3) fft_length Konstanta tensor 1 dimensi dari jenis si64 (C1), (C3), (C4)

Output

Nama Jenis Batasan
result tensor floating point atau jenis kompleks (C2), (C4), (C5)

Batasan

  • (C1) size(fft_length) <= rank(operand).
  • (C2) Hubungan antara jenis elemen operand dan result bervariasi:
    • Jika fft_type = FFT, element_type(operand), dan element_type(result) memiliki jenis kompleks yang sama.
    • Jika fft_type = IFFT, element_type(operand), dan element_type(result) memiliki jenis kompleks yang sama.
    • Jika fft_type = RFFT, element_type(operand) adalah jenis floating point dan element_type(result) adalah jenis kompleks dari semantik floating point yang sama.
    • Jika fft_type = IRFFT, element_type(operand) adalah jenis kompleks dan element_type(result) adalah jenis floating point dari semantik floating point yang sama.
  • (C3) 1 <= size(fft_length) <= 3.
  • (C4) Jika di antara operand dan result, ada tensor real dari jenis floating point, maka shape(real)[-size(fft_length):] = fft_length.
  • (C5) shape(result) = shape(operand) kecuali untuk:
    • Jika fft_type = RFFT, dim(result, -1) = dim(operand, -1) = 0 ? 0 : dim(operand, -1) / 2 + 1.
    • Jika fft_type = IRFFT, dim(operand, -1) = dim(result, -1) = 0 ? 0 : dim(result, -1) / 2 + 1.

Contoh

// %operand: [(1.0, 0.0), (0.0, 0.0), (0.0, 0.0), (0.0, 0.0)]
%result = "stablehlo.fft"(%operand) {
  fft_type = #stablehlo<fft_type FFT>,
  fft_length = array<i64: 4>
} : (tensor<4xcomplex<f32>>) -> tensor<4xcomplex<f32>>
// %result: [(1.0, 0.0), (1.0, 0.0), (1.0, 0.0), (1.0, 0.0)]

lantai

Semantik

Melakukan floor element-wise dari tensor operand dan menghasilkan tensor result. Mengimplementasikan operasi roundToIntegralTowardNegative dari spesifikasi IEEE-754. Untuk jenis kuantisasi, lakukan dequantize_op_quantize(floor, operand, type(result)).

Input

Label Nama Jenis Batasan
(I1) operand tensor jenis floating point atau tensor terkuantisasi per tensor (C1)

Output

Nama Jenis Batasan
result tensor jenis floating point atau tensor terkuantisasi per tensor (C1)

Batasan

  • (C1) baseline_type(operand) = baseline_type(result).

Contoh

// %operand: [-0.8166, -0.2530, 0.2530, 0.8166, 2.0]
%result = "stablehlo.floor"(%operand) : (tensor<5xf32>) -> tensor<5xf32>
// %result: [-1.0, -1.0, 0.0, 0.0, 2.0]

 Contoh Lainnya

mengumpulkan

Semantik

Mengumpulkan slice dari tensor operand dari offset yang ditentukan dalam start_indices dan menghasilkan tensor result.

Diagram berikut menunjukkan cara elemen di result memetakan elemen di operand menggunakan contoh konkret. Diagram ini mengambil beberapa contoh indeks result dan menjelaskan secara mendetail indeks operand mana yang sesuai dengannya.

mengumpulkan

Secara lebih formal, result[result_index] = operand[operand_index] dengan:

  • batch_dims = [d for d in axes(result) and d not in offset_dims].
  • batch_index = result_index[batch_dims...].
  • start_index ditentukan sebagai:
    • start_indices[bi0, ..., :, ..., biN] dengan bi adalah elemen individual dalam batch_index dan : disisipkan pada indeks index_vector_dim, jika index_vector_dim < rank(start_indices).
    • [start_indices[batch_index]] jika tidak.
  • Untuk d_operand di axes(operand),
    • full_start_index[d_operand] = clamp(start_index[d_start], 0, dim(operand, d_operand) - slice_sizes[d_operand]) if d_operand = start_index_map[d_start].
    • full_start_index[d_operand] = 0 jika tidak.
  • Untuk d_operand di axes(operand),
    • full_batching_index[d_operand] = batch_index[d_start - (d_start < index_vector_dim ? 0 : 1)] jika d_operand = operand_batching_dims[i_batching] dan d_start = start_indices_batching_dims[i_batching].
    • full_batching_index[d_operand] = 0 jika tidak.
  • offset_index = result_index[offset_dims...].
  • full_offset_index = [oi0, ..., 0, ..., oiN] dengan oi adalah elemen individual di offset_index, dan 0 disisipkan pada indeks dari collapsed_slice_dims dan operand_batching_dims.
  • operand_index = full_start_index + full_batching_index + full_offset_index.

Jika indices_are_sorted adalah true, implementasi dapat mengasumsikan bahwa start_indices diurutkan sehubungan dengan start_index_map, jika tidak, perilaku tidak ditentukan. Secara lebih formal, untuk semua i1 < i2 dari indices(result), full_start_index(i1) <= full_start_index(i2).

Input

Label Nama Jenis Batasan
(I1) operand tensor atau tensor terkuantisasi per-tensor (C1), (C8), (C11), (C17), (C19-C21), (C23)
(I2) start_indices tensor tipe integer (C2-C3), (C14), (C17), (C22)
(I3) offset_dims Konstanta tensor 1 dimensi jenis si64 (C1), (C4-C5), (C22)
(I4) collapsed_slice_dims Konstanta tensor 1 dimensi dari jenis si64 (C1), (C6-C9), (C22)
(I5) operand_batching_dims Konstanta tensor 1 dimensi dari jenis si64 (C1), (C6), (C10-C12), (C16-C18), (C22)
(I6) start_indices_batching_dims Konstanta tensor 1 dimensi jenis si64 (C13-C17)
(I7) start_index_map Konstanta tensor 1 dimensi dari jenis si64 (C3), (C18-C19)
(I8) index_vector_dim konstanta dari jenis si64 (C2-C3), (C15), (C22)
(I9) slice_sizes Konstanta tensor 1 dimensi dari jenis si64 (C9), (C12), (C20-C22)
(I10) indices_are_sorted konstanta dari jenis i1

Output

Nama Jenis Batasan
result tensor atau tensor terkuantisasi per tensor (C5), (C22-C23)

Batasan

  • (C1) rank(operand) = size(offset_dims) + size(collapsed_slice_dims) + size(operand_batching_dims).
  • (C2) 0 <= index_vector_dim <= rank(start_indices).
  • (C3) size(start_index_map) = index_vector_dim < rank(start_indices) ? dim(start_indices, index_vector_dim) : 1.
  • (C4) is_unique(offset_dims) and is_sorted(offset_dims).
  • (C5) 0 <= offset_dims < rank(result).
  • (C6) is_unique(concatenate(collapsed_slice_dims, operand_batching_dims))
  • (C7) is_sorted(collapsed_slice_dims).
  • (C8) 0 <= collapsed_slice_dims < rank(operand).
  • (C9) slice_sizes[collapsed_slice_dims...] <= 1.
  • (C10) is_sorted(operand_batching_dims).
  • (C11) 0 <= operand_batching_dims < rank(operand).
  • (C12) slice_sizes[operand_batching_dims...] <= 1.
  • (C13) is_unique(start_indices_batching_dims).
  • (C14) 0 <= start_indices_batching_dims < rank(start_indices).
  • (C15) index_vector_dim not in start_indices_batching_dims.
  • (C16) size(operand_batching_dims) == size(start_indices_batching_dims).
  • (C17) dim(operand, operand_batching_dims...) = dim(start_indices, start_indices_batching_dims...).
  • (C18) is_unique(concatenate(start_index_map, operand_batching_dims)).
  • (C19) 0 <= start_index_map < rank(operand).
  • (C20) size(slice_sizes) = rank(operand).
  • (C21) 0 <= slice_sizes <= shape(operand).
  • (C22) shape(result) = combine(batch_dim_sizes, offset_dim_sizes) dengan:
    • batch_dim_sizes = shape(start_indices), tetapi ukuran dimensi start_indices yang sesuai dengan index_vector_dim tidak disertakan.
    • offset_dim_sizes = slice_sizes, kecuali ukuran dimensi di slice_sizes yang sesuai dengan collapsed_slice_dims dan operand_batching_dims tidak disertakan.
    • combine menempatkan batch_dim_sizes pada sumbu yang sesuai dengan batch_dims dan offset_dim_sizes pada sumbu yang sesuai dengan offset_dims.
  • (C23) element_type(operand) = element_type(result).

Contoh

// %operand: [
//            [
//             [[1, 2], [3, 4], [5, 6], [7, 8]],
//             [[9, 10],[11, 12], [13, 14], [15, 16]],
//             [[17, 18], [19, 20], [21, 22], [23, 24]]
//            ],
//            [
//             [[25, 26], [27, 28], [29, 30], [31, 32]],
//             [[33, 34], [35, 36], [37, 38], [39, 40]],
//             [[41, 42], [43, 44], [45, 46], [47, 48]]
//            ]
//           ]
// %start_indices: [
//                  [
//                   [[0, 0], [1, 0], [2, 1]],
//                   [[0, 1], [1, 1], [0, 9]]
//                  ],
//                  [
//                   [[0, 0], [2, 1], [2, 2]],
//                   [[1, 2], [0, 1], [1, 0]]
//                  ]
//                 ]
%result = "stablehlo.gather"(%operand, %start_indices) {
  dimension_numbers = #stablehlo.gather<
    offset_dims = [3, 4],
    collapsed_slice_dims = [1],
    operand_batching_dims = [0],
    start_indices_batching_dims = [1],
    start_index_map = [2, 1],
    index_vector_dim = 3>,
  slice_sizes = array<i64: 1, 1, 2, 2>,
  indices_are_sorted = false
} : (tensor<2x3x4x2xi32>, tensor<2x2x3x2xi64>) -> tensor<2x2x3x2x2xi32>
// %result: [
//           [
//            [
//             [[1, 2], [3, 4]],
//             [[3, 4], [5, 6]],
//             [[13, 14], [15, 16]]
//            ],
//            [
//             [[33, 34], [35, 36]],
//             [[35, 36], [37, 38]],
//             [[41, 42], [43, 44]]
//            ]
//           ],
//           [
//            [
//             [[1, 2], [3, 4]],
//             [[13, 14], [15, 16]],
//             [[21, 22], [23, 24]]
//            ],
//            [
//             [[43, 44], [45, 46]],
//             [[33, 34], [35, 36]],
//             [[27, 28], [29, 30]]
//            ]
//           ]
//          ]

 Contoh Lainnya

get_dimension_size

Semantik

Menghasilkan ukuran dimension yang diberikan dari operand. Secara lebih formal, result = dim(operand, dimension). Semantik hanya berkaitan dengan komponen bentuk jenis. Jenis elemen dapat berupa apa saja.

Input

Label Nama Jenis Batasan
(I1) operand tensor atau tensor terkuantisasi (C1)
(I2) dimension konstanta dari jenis si64 (C1)

Output

Nama Jenis
result Tensor 0 dimensi dari jenis si32

Batasan

  • (C1) 0 <= dimension < rank(operand).

Contoh

// %operand: [[1, 2, 3], [4, 5, 6]]
%result = "stablehlo.get_dimension_size"(%operand) {
  dimension = 1 : i64
} : (tensor<2x3xi64>) -> tensor<i32>
// %result: 3

 Contoh Lainnya

get_tuple_element

Semantik

Mengekstrak elemen pada posisi index tuple operand dan menghasilkan result. Secara lebih formal, result = operand[index].

Input

Label Nama Jenis Batasan
(I1) operand tuple (C1), (C2)
(I2) index konstanta dari jenis si32 (C1), C2

Output

Nama Jenis Batasan
result jenis apa pun yang didukung (C2)

Batasan

  • (C1) 0 <= index < size(operand).
  • (C2) type(result) = tuple_element_types(operand)[index].

Contoh

// %operand: ([1.0, 2.0], (3))
  index = 0 : i32
} : (tuple<tensor<2xf32>, tuple<tensor<i32>>>) -> tensor<2xf32>
// %result: [1.0, 2.0]

 Contoh Lainnya

jika

Semantik

Menghasilkan output dari menjalankan tepat satu fungsi dari true_branch atau false_branch, bergantung pada nilai pred. Secara lebih formal, result = pred ? true_branch() : false_branch().

Input

Label Nama Jenis Batasan
(I1) pred Tensor 0 dimensi dari jenis i1
(I2) true_branch fungsi (C1-C3)
(I3) false_branch fungsi (C1), (C2)

Output

Nama Jenis Batasan
results jumlah variabel tensor, tensor terkuantisasi, atau token (C3)

Batasan

  • (C1) input_types(true_branch) = input_types(false_branch) = [].
  • (C2) output_types(true_branch) = output_types(false_branch).
  • (C3) type(results...) = output_types(true_branch).

Contoh

// %result_true_branch: 10
// %result_false_branch: 11
// %pred: true
%result = "stablehlo.if"(%pred) ({
  "stablehlo.return"(%result_true_branch) : (tensor<i32>) -> ()
}, {
  "stablehlo.return"(%result_false_branch) : (tensor<i32>) -> ()
}) : (tensor<i1>) -> tensor<i32>
// %result: 10

 Contoh Lainnya

imag

Semantik

Mengekstrak bagian imajiner, per elemen, dari operand dan menghasilkan tensor result. Secara lebih formal, untuk setiap elemen x: imag(x) = is_complex(x) ? imaginary_part(x) : constant(0, element_type(result)).

Input

Label Nama Jenis Batasan
(I1) operand tensor dari jenis floating point atau kompleks (C1), C2

Output

Nama Jenis Batasan
result tensor jenis floating point (C1), (C2)

Batasan

  • (C1) shape(result) = shape(operand).
  • (C2) element_type(result) ditentukan sebagai:
    • complex_element_type(element_type(operand)) if is_complex(operand).
    • element_type(operand) jika tidak.

Contoh

// %operand: [(1.0, 2.0), (3.0, 4.0)]
%result = "stablehlo.imag"(%operand) : (tensor<2xcomplex<f32>>) -> tensor<2xf32>
// %result: [2.0, 4.0]

Contoh Lainnya

dalam feed

Semantik

Membaca data dari infeed dan menghasilkan results.

Semantik infeed_config ditentukan oleh implementasi.

results terdiri dari nilai payload yang muncul terlebih dahulu dan token yang muncul terakhir. Pada masa mendatang, kami berencana untuk membagi payload dan token menjadi dua output terpisah untuk meningkatkan kejelasan (#670).

Input

Label Nama Jenis
(I1) token token
(I2) infeed_config konstanta dari jenis string

Output

Nama Jenis Batasan
results jumlah tensor, tensor atau token terkuantisasi (C1-C3)

Batasan

  • (C1) 0 < size(results).
  • (C2) is_empty(result[:-1]) atau is_tensor(type(results[:-1])).
  • (C3) is_token(type(results[-1])).

Contoh

// %token: !stablehlo.token
// infeed_queue[0]: [[1, 2], [3, 4]]
// infeed_queue[1]: [[5, 6], [7, 8]]
%results0:2 = "stablehlo.infeed"(%token) {
  infeed_config = ""
} : (!stablehlo.token) -> (tensor<2x2xi64>, !stablehlo.token)
// results0#0: [[1, 2], [3, 4]]
%results1:2 = "stablehlo.infeed"(%token) {
  infeed_config = ""
} : (!stablehlo.token) -> (tensor<2x2xi64>, !stablehlo.token)
// results1#0: [[5, 6], [7, 8]]

 Contoh Lainnya

iota

Semantik

Mengisi tensor output dengan nilai dalam urutan menaik mulai dari nol di sepanjang dimensi iota_dimension. Secara lebih formal,

output[output_index] = constant(is_quantized(output) ? quantize(output_index[iota_dimension], element_type(output)) : output_index[iota_dimension], element_type(output)).

Input

Label Nama Jenis Batasan
(I1) iota_dimension si64 (C1)

Output

Nama Jenis Batasan
output tensor bilangan bulat, floating point, atau jenis kompleks atau tensor terkuantisasi per tensor (C1)

Batasan

  • (C1) 0 <= iota_dimension < rank(output).

Contoh

%output = "stablehlo.iota"() {
  iota_dimension = 0 : i64
} : () -> tensor<4x5xi32>
// %output: [
//           [0, 0, 0, 0, 0],
//           [1, 1, 1, 1, 1],
//           [2, 2, 2, 2, 2],
//           [3, 3, 3, 3, 3]
//          ]

%output = "stablehlo.iota"() {
  iota_dimension = 1 : i64
} : () -> tensor<4x5xi32>
// %output: [
//           [0, 1, 2, 3, 4],
//           [0, 1, 2, 3, 4],
//           [0, 1, 2, 3, 4],
//           [0, 1, 2, 3, 4]
//          ]

 Contoh Lainnya

is_finite

Semantik

Melakukan pemeriksaan elemen apakah nilai dalam x terbatas (yaitu bukan +Inf, -Inf, atau NaN) dan menghasilkan tensor y. Mengimplementasikan operasi isFinite dari spesifikasi IEEE-754. Untuk jenis kuantisasi, hasilnya selalu true.

Input

Label Nama Jenis Batasan
(I1) x tensor jenis floating point atau tensor terkuantisasi per tensor (C1)

Output

Nama Jenis Batasan
y tensor dari jenis boolean (C1)

Batasan

  • (C1) shape(x) = shape(y).

Contoh

// Logical values: -Inf, +Inf, NaN, ...
// %x: [0xFFF0000000000000, 0x7FF0000000000000, 0x7FF8000000000000, -10.0, -0.0, 0.0, 10.0]
%y = "stablehlo.is_finite"(%x) : (tensor<7xf64) -> tensor<7xi1>
// %y: [false, false, false, true, true, true, true]

 Contoh Lainnya

log

Semantik

Melakukan operasi logaritma element-wise pada tensor operand dan menghasilkan tensor result. Bergantung pada jenis elemen, melakukan hal berikut:

  • Untuk float: log dari IEEE-754.
  • Untuk bilangan kompleks: logaritma kompleks.
  • Untuk jenis kuantisasi: dequantize_op_quantize(log, operand, type(result)).

Input

Label Nama Jenis Batasan
(I1) operand tensor floating point atau jenis kompleks atau tensor terkuantisasi per-tensor (C1)

Output

Nama Jenis Batasan
result tensor dari jenis floating point atau kompleks atau tensor terkuantisasi per tensor (C1)

Batasan

  • (C1) baseline_type(operand) = baseline_type(result).

Contoh

// %operand: [[1.0, 2.0], [3.0, 4.0]]
%result = "stablehlo.log"(%operand) : (tensor<2x2xf64>) -> tensor<2x2xf64>
// %result: [[0.0, 0.69314718055994529], [1.0986122886681098, 1.3862943611198906]]

 Contoh Lainnya

log_plus_one

Semantik

Melakukan logaritma element-wise plus satu operasi pada tensor operand dan menghasilkan tensor result. Bergantung pada jenis elemen, melakukan hal berikut:

  • Untuk float: logp1 dari IEEE-754.
  • Untuk bilangan kompleks: logaritma kompleks ditambah satu.
  • Untuk jenis kuantisasi: dequantize_op_quantize(log_plus_one, operand, type(result)).

Input

Label Nama Jenis Batasan
(I1) operand tensor floating point atau jenis kompleks atau tensor terkuantisasi per-tensor (C1)

Output

Nama Jenis Batasan
result tensor dari jenis floating point atau kompleks atau tensor terkuantisasi per tensor (C1)

Batasan

  • (C1) baseline_type(operand) = baseline_type(result).

Contoh

// %operand: [0.0, -0.999, 7.0, 6.38905621, 15.0]
%result = "stablehlo.log_plus_one"(%operand) : (tensor<5xf64>) -> tensor<5xf64>
// %result: [0.0, -6.90776825, 2.07944155, 2.0, 2.77258873]

 Contoh Lainnya

logistik

Semantik

Melakukan operasi logistik element-wise pada tensor operand dan menghasilkan tensor result. Bergantung pada jenis elemen, lakukan tindakan berikut:

  • Untuk float: division(1, addition(1, exp(-x))) dari IEEE-754.
  • Untuk bilangan kompleks: logistik kompleks.
  • Untuk jenis terkuantisasi: dequantize_op_quantize(logistic, operand, type(result)).

Input

Label Nama Jenis Batasan
(I1) operand tensor floating point atau jenis kompleks atau tensor terkuantisasi per-tensor (C1)

Output

Nama Jenis Batasan
result tensor dari jenis floating point atau kompleks atau tensor terkuantisasi per tensor (C1)

Batasan

  • (C1) baseline_type(operand) = baseline_type(result).

Contoh

// %operand: [[0.0, 1.0], [2.0, 3.0]]
%result = "stablehlo.logistic"(%operand) : (tensor<2x2xf64>) -> tensor<2x2xf64>
// %result: [[0.5, 0.73105858], [0.88079708, 0.95257413]]

 Contoh Lainnya

peta

Semantik

Menerapkan fungsi peta computation ke inputs di sepanjang dimensions dan menghasilkan tensor result.

Secara lebih formal, result[result_index] = computation(inputs...[result_index]).

Input

Label Nama Jenis Batasan
(I1) inputs jumlah variabel tensor atau tensor terkuantisasi per tensor (C1-C4)
(I2) dimensions Konstanta tensor 1 dimensi dari jenis si64 (C3)
(I3) computation fungsi (C4)

Output

Nama Jenis Batasan
result tensor atau tensor terkuantisasi per-tensor (C1), (C4)

Batasan

  • (C1) shape(inputs...) = shape(result).
  • (C2) 0 < size(inputs) = N.
  • (C3) dimensions = range(rank(inputs[0])).
  • (C4) computation memiliki jenis (tensor<E0>, ..., tensor<EN-1>) -> tensor<E'> dengan Ei = element_type(inputs[i]) dan E' = element_type(result).

Contoh

// %input0: [[0, 1], [2, 3]]
// %input1: [[4, 5], [6, 7]]
%result = "stablehlo.map"(%input0, %input1) ({
  ^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
    %0 = stablehlo.multiply %arg0, %arg1 : tensor<i64>
    stablehlo.return %0 : tensor<i64>
}) {
  dimensions = array<i64: 0, 1>
} : (tensor<2x2xi64>, tensor<2x2xi64>) -> tensor<2x2xi64>
// %result: [[0, 5], [12, 21]]

Contoh Lainnya

maksimum

Semantik

Melakukan operasi maksimum element-wise pada tensor lhs dan rhs serta menghasilkan tensor result. Bergantung pada jenis elemen, melakukan hal berikut:

  • Untuk boolean: OR logika.
  • Untuk bilangan bulat: maksimum bilangan bulat.
  • Untuk float: maximum dari IEEE-754.
  • Untuk bilangan kompleks: maksimum leksikografis untuk pasangan (real, imaginary). Memaksakan pengurutan pada angka kompleks melibatkan semantik yang mengejutkan, jadi di masa mendatang kami berencana menghapus dukungan untuk bilangan kompleks untuk operasi ini (#560).
  • Untuk jenis kuantisasi:
    • dequantize_op_quantize(maximum, lhs, rhs, type(result)).

Input

Label Nama Jenis Batasan
(I1) lhs tensor atau tensor terkuantisasi per-tensor (C1)
(I2) rhs tensor atau tensor terkuantisasi per tensor (C1)

Output

Nama Jenis Batasan
result tensor atau tensor terkuantisasi per tensor (C1)

Batasan

  • (C1) baseline_type(lhs) = baseline_type(rhs) = baseline_type(result).

Contoh

// %lhs: [[1, 2], [7, 8]]
// %rhs: [[5, 6], [3, 4]]
%result = "stablehlo.maximum"(%lhs, %rhs) : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
// %result: [[5, 6], [7, 8]]

 Contoh Lainnya

minimum

Semantik

Melakukan operasi min element-wise pada tensor lhs dan rhs serta menghasilkan tensor result. Bergantung pada jenis elemen, lakukan tindakan berikut:

  • Untuk boolean: AND logika.
  • Untuk bilangan bulat: minimum bilangan bulat.
  • Untuk float: minimum dari IEEE-754.
  • Untuk bilangan kompleks: minimum leksikografis untuk pasangan (real, imaginary). Menerapkan pengurutan pada bilangan kompleks melibatkan semantik yang mengejutkan, jadi pada masa mendatang, kami berencana untuk menghapus dukungan untuk bilangan kompleks untuk operasi ini (#560).
  • Untuk jenis terkuantisasi:
    • dequantize_op_quantize(minimum, lhs, rhs, type(result)).

Input

Label Nama Jenis Batasan
(I1) lhs tensor atau tensor terkuantisasi per-tensor (C1)
(I2) rhs tensor atau tensor terkuantisasi per tensor (C1)

Output

Nama Jenis Batasan
result tensor atau tensor terkuantisasi per tensor (C1)

Batasan

  • (C1) baseline_type(lhs) = baseline_type(rhs) = baseline_type(result).

Contoh

// %lhs: [[1, 2], [7, 8]]
// %rhs: [[5, 6], [3, 4]]
%result = "stablehlo.minimum"(%lhs, %rhs) : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
// %result: [[1, 2], [3, 4]]

 Contoh Lainnya

memperbanyak

Semantik

Melakukan produk element-wise dari dua tensor lhs dan rhs dan menghasilkan tensor result. Bergantung pada jenis elemen, lakukan tindakan berikut:

  • Untuk boolean: AND logika.
  • Untuk bilangan bulat: perkalian bilangan bulat.
  • Untuk float: multiplication dari IEEE-754.
  • Untuk bilangan kompleks: perkalian kompleks.
  • Untuk jenis terkuantisasi:
    • dequantize_op_quantize(multiply, lhs, rhs, type(result)).

Input

Label Nama Jenis Batasan
(I1) lhs tensor atau tensor terkuantisasi per-tensor (C1)
(I2) rhs tensor atau tensor terkuantisasi per tensor (C1)

Output

Nama Jenis Batasan
result tensor atau tensor terkuantisasi per tensor (C1)

Batasan

  • (C1) baseline_type(operand) = baseline_type(result).

Contoh

// %lhs: [[1, 2], [3, 4]]
// %rhs: [[5, 6], [7, 8]]
%result = "stablehlo.multiply"(%lhs, %rhs) : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
// %result: [[5, 12], [21, 32]]

 Contoh Lainnya

negasi

Semantik

Melakukan negasi element-wise dari tensor operand dan menghasilkan tensor result. Bergantung pada jenis elemen, lakukan tindakan berikut:

  • Untuk bilangan bulat bertanda: negasi bilangan bulat.
  • Untuk bilangan bulat tanpa tanda: bitcast ke bilangan bulat bertanda, negasi bilangan bulat, bitcast kembali ke bilangan bulat tanpa tanda.
  • Untuk float: negate dari IEEE-754.
  • Untuk bilangan kompleks: negasi kompleks.
  • Untuk jenis kuantisasi: dequantize_op_quantize(negate, operand, type(result)).

Input

Label Nama Jenis Batasan
(I1) operand tensor bilangan bulat, floating point, atau jenis kompleks atau tensor terkuantisasi per tensor (C1)

Output

Nama Jenis Batasan
result tensor bilangan bulat, floating point, atau jenis kompleks atau tensor terkuantisasi per tensor (C1)

Batasan

  • (C1) baseline_type(operand) = baseline_type(result).

Contoh

// Negation operation with integer Tensors
// %operand: [0, -2]
%result = "stablehlo.negate"(%operand) : (tensor<2xi32>) -> tensor<2xi32>
// %result: [0, 2]

// Negation operation with with complex tensors
// %operand: (2.5, 0.0)
%result = "stablehlo.negate"(%operand) : (tensor<1xcomplex<f32>>) -> tensor<1xcomplex<f32>>
// %result: [-2.5, -0.0]

Contoh Lainnya

tidak

Semantik

Melakukan NOT element-wise dari tensor operand dan menghasilkan tensor result. Bergantung pada jenis elemen, lakukan tindakan berikut:

  • Untuk boolean: logical NOT.
  • Untuk bilangan bulat: bitwise NOT.

Argumen

Nama Jenis Batasan
operand tensor dari jenis boolean atau bilangan bulat (C1)

Output

Nama Jenis Batasan
result tensor dari jenis boolean atau bilangan bulat (C1)

Batasan

  • (C1) type(operand) = type(result).

Contoh

// Bitwise operation with with integer tensors
// %operand: [[1, 2], [3, 4]]
%result = "stablehlo.not"(%operand) : (tensor<2x2xi32>) -> tensor<2x2xi32>
// %result: [[-2, -3], [-4, -5]]

// Bitwise operation with with boolean tensors
// %operand: [true, false]
%result = "stablehlo.not"(%operand) : (tensor<2xi1>) -> tensor<2xi1>
// %result: [false, true]

 Contoh Lainnya

optimization_barrier

Semantik

Memastikan bahwa operasi yang menghasilkan operand dieksekusi sebelum operasi apa pun yang bergantung pada result dan mencegah transformasi compiler memindahkan operasi melintasi penghalang. Selain itu, operasinya adalah identitas, yaitu result = operand.

Argumen

Nama Jenis Batasan
operand jumlah variabel tensor, tensor atau token terkuantisasi per tensor (C1)

Output

Nama Jenis Batasan
result jumlah variabel tensor, tensor atau token terkuantisasi per tensor (C1)

Batasan

  • (C1) type(operand...) = type(result...).

Contoh

// %operand0: 0.0
// %operand1: 1.0
%result0, %result1 = "stablehlo.optimization_barrier"(%operand0, %operand1) : (tensor<f32>, tensor<f32>) -> (tensor<f32>, tensor<f32>)
// %result0: 0.0
// %result1: 1.0

 Contoh Lainnya

atau

Semantik

Melakukan OR element-wise dari dua tensor lhs dan rhs dan menghasilkan tensor result. Bergantung pada jenis elemen, melakukan hal berikut:

  • Untuk boolean: OR logika.
  • Untuk bilangan bulat: bitwise OR.

Input

Label Nama Jenis Batasan
(I1) lhs tensor dari jenis bilangan bulat atau boolean (C1)
(I2) rhs tensor berjenis integer atau boolean (C1)

Output

Nama Jenis Batasan
result tensor dari jenis bilangan bulat atau boolean (C1)

Batasan

  • (C1) type(lhs) = type(rhs) = type(result).

Contoh

// Bitwise operation with with integer tensors
// %lhs: [[1, 2], [3, 4]]
// %rhs: [[5, 6], [7, 8]]
%result = "stablehlo.or"(%lhs, %rhs) : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
// %result: [[5, 6], [7, 12]]

// Logical operation with with boolean tensors
// %lhs: [[false, false], [true, true]]
// %rhs: [[false, true], [false, true]]
%result = "stablehlo.or"(%lhs, %rhs) : (tensor<2x2xi1>, tensor<2x2xi1>) -> tensor<2x2xi1>
// %result: [[false, true], [true, true]]

Contoh Lainnya

outfeed

Semantik

Menulis inputs ke feed keluar dan menghasilkan token result.

Semantik outfeed_config ditentukan oleh implementasi.

Input

Label Nama Jenis
(I1) inputs jumlah variabel tensor atau tensor terkuantisasi
(I2) token token
(I3) outfeed_config konstanta dari jenis string

Output

Nama Jenis
result token

Contoh

%result = "stablehlo.outfeed"(%input0, %token) {
  outfeed_config = ""
} : (tensor<2x2x2xi64>, !stablehlo.token) -> !stablehlo.token

Contoh Lainnya

bantalan

Semantik

Memperluas operand dengan padding di sekitar tensor serta di antara elemen tensor dengan padding_value yang diberikan.

edge_padding_low dan edge_padding_high menentukan jumlah padding yang ditambahkan di bagian bawah (di samping indeks 0) dan bagian atas (di samping indeks tertinggi) dari setiap dimensi. Jumlah padding dapat negatif, dengan nilai absolut padding negatif menunjukkan jumlah elemen yang akan dihapus dari dimensi yang ditentukan.

interior_padding menentukan jumlah padding yang ditambahkan di antara dua elemen di setiap dimensi yang mungkin tidak negatif. Padding interior terjadi sebelum padding tepi sehingga padding tepi negatif akan menghapus elemen dari operand dengan padding bagian dalam.

Secara lebih formal, result[result_index] didefinisikan sebagai:

  • operand[operand_index] jika result_index = edge_padding_low + operand_index * (interior_padding + 1).
  • padding_value sebaliknya.

Input

Label Nama Jenis Batasan
(I1) operand tensor atau tensor terkuantisasi per tensor (C1), (C2), (C4)
(I2) padding_value Tensor 0 dimensi atau tensor terkuantisasi per tensor (C1)
(I3) edge_padding_low Konstanta tensor 1 dimensi jenis si64 (C1), (C4)
(I4) edge_padding_high Konstanta tensor 1 dimensi dari jenis si64 (C1), (C4)
(I5) interior_padding Konstanta tensor 1 dimensi dari jenis si64 (C2-C4)

Output

Nama Jenis Batasan
result tensor atau tensor terkuantisasi per tensor (C3-C6)

Batasan

  • (C1) element_type(operand) = element_type(padding_value) = element_type(result).
  • (C2) size(edge_padding_low) = size(edge_padding_high) = size(interior_padding) = rank(operand).
  • (C3) 0 <= interior_padding.
  • (C4) shape(result) = shape(operand) + edge_padding_low + max(shape(operand) - 1, 0) * interior_padding + edge_padding_high.

Contoh

// %operand: [
//            [1, 2, 3],
//            [4, 5, 6]
//           ]
// %padding_value: 0
%result = "stablehlo.pad"(%operand, %padding_value) {
  edge_padding_low = array<i64: 0, 1>,
  edge_padding_high = array<i64: 2, 1>,
  interior_padding = array<i64: 1, 2>
} : (tensor<2x3xi32>, tensor<i32>) -> tensor<5x9xi32>
// %result: [
//           [0, 1, 0, 0, 2, 0, 0, 3, 0],
//           [0, 0, 0, 0, 0, 0, 0, 0, 0],
//           [0, 4, 0, 0, 5, 0, 0, 6, 0],
//           [0, 0, 0, 0, 0, 0, 0, 0, 0],
//           [0, 0, 0, 0, 0, 0, 0, 0, 0]
//          ]

 Contoh Lainnya

partition_id

Semantik

Menghasilkan partition_id dari proses saat ini.

Output

Nama Jenis
result Tensor 0-dimensi jenis ui32

Contoh

%result = "stablehlo.partition_id"() : () -> tensor<ui32>

 Contoh Lainnya

popcnt

Semantik

Melakukan penghitungan elemen per elemen dari jumlah bit yang ditetapkan dalam tensor operand dan menghasilkan tensor result.

Input

Label Nama Jenis Batasan
(I1) operand tensor dari jenis bilangan bulat (C1)

Output

Nama Jenis Batasan
result tensor dari jenis bilangan bulat (C1)

Batasan

  • (C1) type(operand) = type(result).

Contoh

// %operand: [0, 1, 2, 127]
%result = "stablehlo.popcnt"(%operand) : (tensor<4xi64>) -> tensor<4xi64>
// %result: [0, 1, 1, 7]

 Contoh Lainnya

daya

Semantik

Melakukan eksponensial element-wise tensor lhs dengan tensor rhs dan menghasilkan tensor result. Bergantung pada jenis elemen, melakukan hal berikut:

  • Untuk bilangan bulat: eksponensial bilangan bulat.
  • Untuk float: pow dari IEEE-754.
  • Untuk bilangan kompleks: eksponensial kompleks.
  • Untuk jenis terkuantisasi: dequantize_op_quantize(power, lhs, rhs, type(result)).

Input

Label Nama Jenis Batasan
(I1) lhs tensor bilangan bulat, floating point, atau jenis kompleks atau tensor terkuantisasi per tensor (C1)
(I2) rhs tensor bilangan bulat, floating point, atau jenis kompleks atau tensor terkuantisasi per tensor (C1)

Output

Nama Jenis Batasan
result tensor bilangan bulat, floating point, atau jenis kompleks atau tensor terkuantisasi per tensor (C1)

Batasan

  • (C1) baseline_type(operand) = baseline_type(result).

Contoh

// %lhs: [-2.0, -0.0, -36.0, 5.0, 3.0, 10000.0]
// %rhs: [2.0, 2.0, 1.1, 2.0, -1.0, 10.0]
%result = "stablehlo.power"(%lhs, %rhs) : (tensor<6xf64>, tensor<6xf64>) -> tensor<6xf64>
// %result: [4.0, 0.0, -nan, 25.0, 0.333333343, inf]

Contoh Lainnya

real

Semantik

Mengekstrak bagian riil, per elemen, dari operand dan menghasilkan tensor result. Secara lebih formal, untuk setiap elemen x: real(x) = is_complex(x) ? real_part(x) : x.

Input

Label Nama Jenis Batasan
(I1) operand tensor dari jenis floating point atau kompleks (C1), C2

Output

Nama Jenis Batasan
result tensor jenis floating point (C1), (C2)

Batasan

  • (C1) shape(result) = shape(operand).
  • (C2) element_type(result) ditentukan sebagai:
    • complex_element_type(element_type(operand)) if is_complex(operand).
    • element_type(operand) jika tidak.

Contoh

// %operand: [(1.0, 2.0), (3.0, 4.0)]
%result = "stablehlo.real"(%operand) : (tensor<2xcomplex<f32>>) -> tensor<2xf32>
// %result: [1.0, 3.0]

 Contoh Lainnya

recv

Semantik

Menerima data dari saluran dengan channel_id dan menghasilkan results.

Jika is_host_transfer adalah true, operasi akan mentransfer data dari host. Jika tidak, data akan ditransfer dari perangkat lain. Artinya, ditentukan oleh implementasi. Flag ini menduplikasi informasi yang diberikan di channel_type, sehingga pada masa mendatang kami berencana untuk hanya menyimpan salah satunya (#666).

results terdiri dari nilai payload yang muncul terlebih dahulu dan token yang muncul terakhir. Pada masa mendatang, kami berencana untuk membagi payload dan token menjadi dua output terpisah untuk meningkatkan kejelasan (#670).

Input

Label Nama Jenis Batasan
(I1) token token (C4)
(I2) channel_id konstanta dari jenis si64
(I3) channel_type enum DEVICE_TO_DEVICE dan HOST_TO_DEVICE (C1)
(I4) is_host_transfer konstanta jenis i1 (C1)

Output

Nama Jenis Batasan
results jumlah tensor, tensor atau token terkuantisasi (C2-C4)

Batasan

  • (C1) channel_type didefinisikan sebagai:
    • HOST_TO_DEVICE jika is_host_transfer = true,
    • DEVICE_TO_DEVICE jika tidak.
  • (C2) 0 < size(results).
  • (C3) is_empty(result[:-1]) atau is_tensor(type(results[:-1])).
  • (C4) is_token(type(results[-1])).

Contoh

%results0, %results1 = "stablehlo.recv"(%token) {
  channel_handle = #stablehlo.channel_handle<handle = 1, type = 3>,
  is_host_transfer = true
} : (!stablehlo.token) -> (tensor<2x2xi64>, !stablehlo.token)

 Contoh Lainnya

reduce

Semantik

Menerapkan fungsi pengurangan body ke inputs dan init_values di sepanjang dimensions dan menghasilkan tensor results.

Urutan pengurangan ditentukan oleh implementasi, yang berarti bahwa body dan init_values harus membentuk monoid untuk menjamin bahwa operasi menghasilkan hasil yang sama untuk semua input di semua implementasi. Namun, kondisi ini tidak berlaku untuk banyak pengurangan populer. Misalnya, penambahan floating point untuk body dan nol untuk init_values sebenarnya tidak membentuk monoid karena penambahan floating point tidak asosiatif.

Secara lebih formal, results...[j0, ..., jR-1] = reduce(input_slices_converted) dengan:

  • input_slices = inputs...[j0, ..., :, ..., jR-1], dengan : disisipkan di dimensions.
  • input_slices_converted = to_destination_type(input_slices..., type(func_inputs(body)[:len(func_inputs(body))//2])...).
  • init_values_converted = to_destination_type(init_values..., type(func_inputs(body)[len(func_inputs(body))//2:])...).
  • reduce(input_slices_converted) = exec(schedule) untuk beberapa hierarki biner schedule dengan:
    • exec(node) = body(exec(node.left), exec(node.right)).
    • exec(leaf) = leaf.value.
  • schedule adalah hierarki biner penuh yang ditentukan oleh implementasi yang traversal berurutannya terdiri dari:
    • Nilai input_slices_converted...[index], untuk semua index di index_space(input_slices_converted) dalam urutan leksikografis menaik index.
    • Di sela-sela jumlah init_values_converted yang ditentukan implementasi pada posisi yang ditentukan implementasi.

Input

Label Nama Jenis Batasan
(I1) inputs jumlah variabel tensor atau tensor terkuantisasi per tensor (C1-C4), (C6), (C7)
(I2) init_values jumlah variadik tensor 0 dimensi atau tensor terkuantisasi per tensor (C2), (C3)
(I3) dimensions Konstanta tensor 1 dimensi dari jenis si64 (C4), (C5), (C7)
(I4) body fungsi (C6)

Output

Nama Jenis Batasan
results jumlah variabel tensor atau tensor terkuantisasi per tensor (C3), (C7), (C8)

Batasan

  • (C1) same(shape(inputs...)).
  • (C2) element_type(inputs...) = element_type(init_values...).
  • (C3) 0 < size(inputs) = size(init_values) = size(results) = N.
  • (C4) 0 <= dimensions < rank(inputs[0]).
  • (C5) is_unique(dimensions).
  • (C6) body memiliki jenis (tensor<E0>, ..., tensor<EN-1>, tensor<E0>, ..., tensor<EN-1>) -> (tensor<E0>, ..., tensor<EN-1>) dengan is_promotable(element_type(inputs[i]), Ei).
  • (C7) shape(results...) = shape(inputs...), kecuali ukuran dimensi inputs... yang sesuai dengan dimensions tidak disertakan.
  • (C8) element_type(results[i]) = Ei untuk semua i di [0,N).

Contoh

// %input = [[0, 1, 2, 3, 4, 5]]
// %init_value = 0
%result = "stablehlo.reduce"(%input, %init_value) ({
  ^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
    %0 = "stablehlo.add"(%arg0, %arg1) : (tensor<i64>, tensor<i64>) -> tensor<i64>
    "stablehlo.return"(%0) : (tensor<i64>) -> ()
}) {
  dimensions = array<i64: 1>
} : (tensor<1x6xi64>, tensor<i64>) -> tensor<1xi64>
// %result = [15]

 Contoh Lainnya

reduce_precision

Semantik

Melakukan konversi operand berdasarkan elemen ke jenis floating point lain yang menggunakan exponent_bits dan mantissa_bits, serta kembali ke jenis floating point asli dan menghasilkan tensor output.

Secara lebih formal:

  • Bit mantisa dari nilai asli diperbarui untuk membulatkan nilai asli ke nilai terdekat yang dapat direpresentasikan dengan mantissa_bits menggunakan semantik roundToIntegralTiesToEven.
  • Kemudian, jika mantissa_bits lebih kecil dari jumlah bit mantissa dari nilai asli, bit mantissa akan terpotong menjadi mantissa_bits.
  • Kemudian, jika bit eksponen dari hasil perantara tidak sesuai dengan rentang yang disediakan oleh exponent_bits, hasil perantara akan meluap ke tak terbatas menggunakan tanda asli atau underflow ke nol menggunakan tanda asli.
  • Untuk jenis kuantisasi, lakukan dequantize_op_quantize( lambda operand: reduce_precision(operand, exponent_bits, mantissa_bits), operand, type(result)).

Input

Label Nama Jenis Batasan
(I1) operand tensor jenis floating point atau tensor terkuantisasi per tensor (C1)
(I2) exponent_bits konstanta dari jenis si32 (C2)
(I3) mantissa_bits konstanta jenis si32 (C3)

Output

Nama Jenis Batasan
output tensor jenis floating point atau tensor terkuantisasi per tensor (C1)

Batasan

  • (C1) baseline_type(operand) = baseline_type(output).
  • (C2) 1 <= exponent_bits.
  • (C3) 0 <= mantissa_bits.

Contoh

// Logical values: +Inf, NaN, +Denormal, 0.0, 65519.0, 65520.0
// %operand: [0x7FF0000000000000, 0x7FFFFFFFFFFFFFFF, 0x0000000000000001, 0.0, 65519.0, 65520.0]
%output = "stablehlo.reduce_precision"(%operand) {
  exponent_bits = 5 : i32,
  mantissa_bits = 10 : i32
} : (tensor<6xf64>) -> tensor<6xf64>
// Logical values: +Inf, NaN, 0.0, 0.0, 65504.0, +Inf
// %output: [0x7FF0000000000000, 0x7FFFFFFFFFFFFFFF, 0.0, 0.0, 65504.0, 0x7FF0000000000000]

 Contoh Lainnya

reduce_scatter

Semantik

reduce_scatter

Dalam setiap grup proses di petak proses StableHLO, melakukan pengurangan, menggunakan computations, pada nilai tensor operand dari setiap proses, membagi hasil pengurangan di sepanjang scatter_dimension menjadi beberapa bagian, dan menyebarkan bagian yang terpisah di antara proses untuk menghasilkan result.

Operasi ini membagi petak proses StableHLO menjadi process_groups yang ditentukan sebagai berikut:

  • cross_replica(replica_groups) if channel_id <= 0 and use_global_device_ids = false.
  • cross_replica_and_partition(replica_groups) jika channel_id > 0 and use_global_device_ids = false.
  • flattened_ids(replica_groups) if channel_id > 0 and use_global_device_ids = true.

Setelah itu, dalam setiap process_group:

  • reduced_value = all_reduce(operand, replica_groups, channel_id, use_global_device_ids, computation).
  • parts@sender = split(reduced_value@sender, dim(process_groups, 1), scatter_dimension).
  • result@receiver = parts@sender[receiver_index] untuk semua sender di process_group, dengan receiver_index = process_group.index(receiver).

Input

Label Nama Jenis Batasan
(I1) operand tensor atau tensor terkuantisasi per tensor (C1), (C2), (C7), (C8)
(I2) scatter_dimension konstanta dari jenis si64 (C1), (C2), (C8)
(I3) replica_groups Konstanta tensor 2 dimensi dari jenis si64 (C3-C5)
(I4) channel_id konstanta jenis si64 (C6)
(I5) use_global_device_ids konstanta dari jenis i1 (C6)
(I6) computation fungsi (C7)

Output

Nama Jenis Batasan
result tensor atau tensor terkuantisasi per tensor (C8-C9)

Batasan

  • (C1) dim(operand, scatter_dimension) % dim(process_groups, 1) = 0.
  • (C2) 0 <= scatter_dimension < rank(operand).
  • (C3) is_unique(replica_groups).
  • (C4) size(replica_groups) ditentukan sebagai:
    • num_replicas jika cross_replica digunakan.
    • num_replicas jika cross_replica_and_partition digunakan.
    • num_processes jika flattened_ids digunakan.
  • (C5) 0 <= replica_groups < size(replica_groups).
  • (C6) Jika use_global_device_ids = true, maka channel_id > 0.
  • (C7) computation memiliki jenis (tensor<E>, tensor<E>) -> (tensor<E>) dengan is_promotable(element_type(operand), E).
  • (C8) shape(result) = shape(operand) kecuali:
    • dim(result, scatter_dimension) = dim(operand, scatter_dimension) / dim(process_groups, 1).
  • (C9) element_type(result) = E.

Contoh

// num_replicas: 2
// num_partitions: 1
// %operand@(0, 0): [[1, 2, 3, 4],
//                   [5, 6, 7, 8]]
// %operand@(1, 0): [[9, 10, 11, 12],
//                   [13, 14, 15, 16]]
%result = "stablehlo.reduce_scatter"(%operand) ({
  ^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
  %0 = "stablehlo.add"(%arg0, %arg1) : (tensor<i64>, tensor<i64>) -> tensor<i64>
  "stablehlo.return"(%0) : (tensor<i64>) -> ()
}) {
  scatter_dimension = 1 : i64,
  replica_groups = dense<[[0, 1]]> : tensor<1x2xi64>,
  channel_handle = #stablehlo.channel_handle<handle = 0, type = 0>
} : (tensor<2x4xi64>) -> tensor<2x2xi64>
//
// %result@(0, 0): [[10, 12],
//                  [18, 20]]
// %result@(1, 0): [[14, 16],
//                  [22, 24]]

 Contoh Lainnya

reduce_window

Semantik

Menerapkan fungsi pengurangan body ke jendela inputs dan init_values dan menghasilkan results.

Diagram berikut menunjukkan cara elemen pada results... dihitung dari inputs... menggunakan contoh konkret.

reduce_window

Secara lebih formal, results...[result_index] = reduce(windows, init_values, axes(inputs...), body) (lihat reduce) dengan:

  • padded_inputs = pad(inputs..., init_values..., padding[:, 0], padding[:, 1], base_dilations - 1).
  • window_start = result_index * window_strides.
  • window_end = window_start + (window_dimensions - 1) * window_dilations + 1.
  • windows = slice(padded_inputs..., window_start, window_end, window_dilations).

Input

Label Nama Jenis Batasan
(I1) inputs jumlah variabel tensor atau tensor terkuantisasi per tensor (C1-C4), (C6), (C8), (C10), (C12), (C13), (C15)
(I2) init_values jumlah variadik tensor 0 dimensi atau tensor terkuantisasi per tensor (C1), (C13)
(I3) window_dimensions Konstanta tensor 1 dimensi dari jenis si64 (C4), (C5), (C15)
(I4) window_strides Konstanta tensor 1 dimensi dari jenis si64 (C6), (C7), (C15)
(I5) base_dilations Konstanta tensor 1 dimensi dari jenis si64 (C8), (C9), (C15)
(I6) window_dilations Konstanta tensor 1 dimensi dari jenis si64 (C10), (C11), (C15)
(I7) padding Konstanta tensor 2 dimensi dari jenis si64 (C12), (C15)
(I8) body fungsi (C13)

Output

Nama Jenis Batasan
results jumlah variabel tensor atau tensor terkuantisasi per tensor (C1), (C14-C16)

Batasan

  • (C1) 0 < size(inputs) = size(init_values) = size(results) = N.
  • (C2) same(shape(inputs...)).
  • (C3) element_type(inputs...) = element_type(init_values...).
  • (C4) size(window_dimensions) = rank(inputs[0]).
  • (C5) 0 < window_dimensions.
  • (C6) size(window_strides) = rank(inputs[0]).
  • (C7) 0 < window_strides.
  • (C8) size(base_dilations) = rank(inputs[0]).
  • (C9) 0 < base_dilations.
  • (C10) size(window_dilations) = rank(inputs[0]).
  • (C11) 0 < window_dilations.
  • (C12) shape(padding) = [rank(inputs[0]), 2].
  • (C13) body memiliki jenis (tensor<E0>, ..., tensor<EN-1>, tensor<E0>, ..., tensor<EN-1>) -> (tensor<E0>, ..., tensor<EN-1>) dengan is_promotable(element_type(inputs[i]), Ei).
  • (C14) same(shape(results...)).
  • (C15) shape(results[0]) = num_windows dengan:
    • dilated_input_shape = shape(inputs[0]) = 0 ? 0 : (shape(inputs[0]) - 1) * base_dilations + 1.
    • padded_input_shape = padding[:, 0] + dilated_input_shape + padding[:, 1].
    • dilated_window_shape = (window_dimensions - 1) * window_dilations + 1.
    • is_empty_window = padded_input_shape = 0 || dilated_window_shape > padded_input_shape.
    • num_windows = is_empty_window ? 0 : floor((padded_input_shape - dilated_window_shape) / window_strides) + 1.
  • (C16) element_type(results[i]) = Ei untuk semua i di [0,N).

Contoh

// %input = [[1, 2], [3, 4], [5, 6]]
// %init_value = 0
%result = "stablehlo.reduce_window"(%input, %init_value) ({
  ^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
    %0 = "stablehlo.add"(%arg0, %arg1) : (tensor<i64>, tensor<i64>) -> tensor<i64>
    "stablehlo.return"(%0) : (tensor<i64>) -> ()
}) {
  window_dimensions = array<i64: 2, 1>,
  window_strides = array<i64: 4, 1>,
  base_dilations = array<i64: 2, 1>,
  window_dilations = array<i64: 3, 1>,
  padding = dense<[[2, 1], [0, 0]]> : tensor<2x2xi64>
} : (tensor<3x2xi64>, tensor<i64>) -> tensor<2x2xi64>
// %result = [[0, 0], [3, 4]]

 Contoh Lainnya

sisanya

Semantik

Melakukan sisa elemen dari tensor dividen lhs dan pembagi rhs dan menghasilkan tensor result.

Secara lebih formal, tanda hasil diambil dari dividen, dan nilai absolut hasilnya selalu kurang dari nilai absolut pembagi. Sisa dihitung sebagai lhs - d * rhs, dengan d diberikan oleh:

  • Untuk bilangan bulat: stablehlo.divide(lhs, rhs).
  • Untuk float: division(lhs, rhs) dari IEEE-754 dengan atribut pembulatan roundTowardZero.
  • Untuk bilangan kompleks: TBD (#997).
  • Untuk jenis kuantisasi:
    • dequantize_op_quantize(remainder, lhs, rhs, type(result)).

Untuk jenis elemen floating point, operasi ini berbeda dengan operasi remainder dari spesifikasi IEEE-754 dengan d adalah nilai integral yang paling dekat dengan nilai persis lhs/rhs dengan ikatan ke genap.

Input

Label Nama Jenis Batasan
(I1) lhs tensor bilangan bulat, floating point, atau kompleks, atau tensor terkuantisasi per tensor (C1)
(I2) rhs tensor bilangan bulat, floating point, atau kompleks, atau tensor terkuantisasi per tensor (C1)

Output

Nama Jenis Batasan
result tensor integer, floating-point, atau tipe kompleks atau tensor terkuantisasi per-tensor (C1)

Batasan

  • (C1) baseline_type(operand) = baseline_type(result).

Contoh

// %lhs: [17, -17, 17, -17]
// %rhs: [3, 3, -3, -3]
%result = "stablehlo.remainder"(%lhs, %rhs) : (tensor<4xi64>, tensor<4xi64>) -> tensor<4xi64>
// %result: [2, -2, 2, -2]

Contoh Lainnya

replica_id

Semantik

Menghasilkan replica_id dari proses saat ini.

Output

Nama Jenis
result Tensor 0-dimensi jenis ui32

Contoh

%result = "stablehlo.replica_id"() : () -> tensor<ui32>

 Contoh Lainnya

membentuk ulang

Semantik

Melakukan pembentukan ulang tensor operand menjadi tensor result. Secara konseptual, hal ini sama dengan mempertahankan representasi kanonis yang sama, tetapi berpotensi mengubah bentuknya, misalnya dari tensor<2x3xf32> menjadi tensor<3x2xf32> atau tensor<6xf32>.

Secara lebih formal, result[result_index] = operand[operand_index] dengan result_index dan operand_index memiliki posisi yang sama dalam pengurutan leksikal index_space(result) dan index_space(operand).

Input

Label Nama Jenis Batasan
(I1) operand tensor atau tensor terkuantisasi (C1-C3)

Output

Nama Jenis Batasan
result tensor atau quantized tensor (C1-C3)

Batasan

  • (C1) element_type(result) diberikan oleh:
    • element_type(operand), jika !is_per_axis_quantized(operand).
    • element_type(operand) kecuali bahwa quantization_dimension(operand) dan quantization_dimension(result) mungkin berbeda, jika tidak.
  • (C2) size(operand) = size(result).
  • (C3) Jika is_per_axis_quantized(operand):
    • reduce(dims(operand, [0, 1, ..., quantization_dimension(operand) - 1]), init_values=1, dimensions=[0], body=lambda x, y: x * y) = reduce(dims(result, [0, 1, ..., quantization_dimension(result) - 1]), init_values=1, dimensions=[0], body=lambda x, y: x * y).
    • dim(operand, quantization_dimension(operand)) = dim(result, quantization_dimension(result)).
    • reduce(dims(operand, [quantization_dimension(operand) + 1, ..., rank(operand) - 1]), init_values=1, dimensions=[0], body=lambda x, y: x * y) = reduce(dims(result, [quantization_dimension(result) + 1, ..., rank(result) - 1]), init_values=1, dimensions=[0], body=lambda x, y: x * y).

Contoh

// %operand: [[1, 2, 3], [4, 5, 6]]
%result = "stablehlo.reshape"(%operand) : (tensor<2x3xi32>) -> tensor<3x2xi32>
// %result: [[1, 2], [3, 4], [5, 6]]

 Contoh Lainnya

balik

Semantik

Membalik urutan elemen dalam operand di sepanjang dimensions yang ditentukan dan menghasilkan tensor result. Secara lebih formal, result[result_index] = operand[operand_index] dengan:

  • operand_index[d] = dim(result, d) - result_index[d] - 1 jika d di dimensions.
  • operand_index[d] = result_index[d] sebaliknya.

Input

Label Nama Jenis Batasan
(I1) operand tensor atau tensor terkuantisasi per tensor (C1), (C3)
(I2) dimensions Konstanta tensor 1 dimensi dari jenis si64 (C2), (C3)

Output

Nama Jenis Batasan
result tensor atau tensor terkuantisasi per tensor (C1), (C3)

Batasan

  • (C1) type(operand) = type(result).
  • (C2) is_unique(dimensions).
  • (C3) 0 <= dimensions < rank(result).

Contoh

// %operand = [[1, 2], [3, 4], [5, 6]]
%result = "stablehlo.reverse"(%operand) {
  dimensions = array<i64: 1>
} : (tensor<3x2xi32>) -> tensor<3x2xi32>
// %result: [[2, 1], [4, 3], [6, 5]]

 Contoh Lainnya

rng

Semantik

Menghasilkan angka acak menggunakan algoritma rng_distribution dan menghasilkan tensor result dari bentuk shape tertentu.

Jika rng_distribution = UNIFORM, angka acak akan dihasilkan mengikuti distribusi seragam selama interval [a, b). Jika a >= b, perilaku tidak ditentukan.

Jika rng_distribution = NORMAL, angka acak akan dihasilkan mengikuti distribusi normal dengan mean = a dan simpangan baku = b. Jika b < 0, perilaku tidak ditentukan.

Cara persis pembuatan angka acak ditentukan oleh implementasi. Misalnya, status tersebut mungkin bersifat deterministik atau tidak, dan mungkin menggunakan status tersembunyi atau tidak.

Dalam diskusi dengan banyak pemangku kepentingan, operasi ini tampaknya tidak digunakan lagi secara efektif, jadi di masa mendatang kami berencana untuk menghapusnya (#597).

Input

Label Nama Jenis Batasan
(I1) a Tensor 0-dimensi dari jenis integer, boolean, atau floating point (C1), (C2)
(I2) b Tensor 0 dimensi dari jenis bilangan bulat, boolean, atau floating point (C1), (C2)
(I3) shape Konstanta tensor 1 dimensi dari jenis si64 (C3)
(I4) rng_distribution enum UNIFORM dan NORMAL (C2)

Output

Nama Jenis Batasan
result tensor bilangan bulat, boolean, atau jenis floating point (C1-C3)

Batasan

  • (C1) element_type(a) = element_type(b) = element_type(result).
  • (C2) Jika rng_distribution = NORMAL, maka is_float(a).
  • (C3) shape(result) = shape.

Contoh

// %a = 0
// %b = 2
// %shape = [3, 3]
%result = "stablehlo.rng"(%a, %b, %shape) {
  rng_distribution = #stablehlo<rng_distribution UNIFORM>
} : (tensor<i32>, tensor<i32>, tensor<2xi64>) -> tensor<3x3xi32>
// %result: [
//           [1, 0, 1],
//           [1, 1, 1],
//           [0, 0, 0]
//          ]

rng_bit_generator

Semantik

Menampilkan output yang diisi dengan bit acak seragam dan status output yang diperbarui output_state menggunakan algoritma generator angka pseudorandom rng_algorithm dengan status awal initial_state. Output dijamin merupakan fungsi deterministik dari initial_state, tetapi tidak dijamin bersifat deterministik di antara implementasi.

rng_algorithm adalah salah satu dari berikut ini:

  • DEFAULT: Algoritma yang ditentukan implementasi.
  • THREE_FRY: Varian algoritma Threefry yang ditentukan implementasi.*
  • PHILOX: Varian algoritma Philox yang ditentukan implementasi.*

* Lihat: Salmon et al. SC 2011. Angka acak paralel: semudah 1, 2, 3.

Input

Label Nama Jenis Batasan
(I1) rng_algorithm enum DEFAULT, THREE_FRY, dan PHILOX (C2)
(I2) initial_state Tensor 1 dimensi dari jenis ui64 (C1), C2

Output

Nama Jenis Batasan
output_state Tensor 1 dimensi dari jenis ui64 (C1)
output tensor dari jenis bilangan bulat atau floating point

Batasan

  • (C1) type(initial_state) = type(output_state).
  • (C2) size(initial_state) ditentukan sebagai:
    • ditentukan oleh implementasi jika rng_algorithm = DEFAULT.
    • 2 if rng_algorithm = THREE_FRY.
    • 2 atau 3 jika rng_algorithm = PHILOX.

Contoh

// %initial_state: [1, 2]
%output_state, %output = "stablehlo.rng_bit_generator"(%initial_state) {
  rng_algorithm = #stablehlo<rng_algorithm THREE_FRY>
} : (tensor<2xui64>) -> (tensor<2xui64>, tensor<2x2xui64>)
// %output_state: [1, 6]
// %output: [
//           [9236835810183407956, 16087790271692313299],
//           [18212823393184779219, 2658481902456610144]
//          ]

round_nearest_afz

Semantik

Melakukan pembulatan element-wise ke bilangan bulat terdekat, memisahkan dari nol, pada tensor operand dan menghasilkan tensor result. Menerapkan operasi roundToIntegralTiesToAway dari spesifikasi IEEE-754. Untuk jenis terkuantisasi, menjalankan dequantize_op_quantize(round_nearest_afz, operand, type(result)).

Input

Label Nama Jenis Batasan
(I1) operand tensor jenis floating point atau tensor terkuantisasi per tensor (C1)

Output

Nama Jenis Batasan
result tensor jenis floating point atau tensor terkuantisasi per tensor (C1)

Batasan

  • (C1) baseline_type(operand) = baseline_type(result).

Contoh

// %operand = [-2.5, 0.4, 0.5, 0.6, 2.5]
%result = "stablehlo.round_nearest_afz"(%operand) : (tensor<5xf64>) -> tensor<5xf64>
// %result: [-3.0, 0.0, 1.0, 1.0, 3.0]

Contoh Lainnya

round_nearest_even

Semantik

Melakukan pembulatan elemen ke bilangan bulat terdekat, memecahkan ikatan ke bilangan bulat genap, pada tensor operand dan menghasilkan tensor result. Mengimplementasikan operasi roundToIntegralTiesToEven dari spesifikasi IEEE-754. Untuk jenis kuantisasi, lakukan dequantize_op_quantize(round_nearest_even, operand, type(result)).

Input

Label Nama Jenis Batasan
(I1) operand tensor jenis floating point atau tensor terkuantisasi per tensor (C1)

Output

Nama Jenis Batasan
result tensor jenis floating point atau tensor terkuantisasi per tensor (C1)

Batasan

  • (C1) baseline_type(operand) = baseline_type(result).

Contoh

// %operand = [-2.5, 0.4, 0.5, 0.6, 2.5]
%result = "stablehlo.round_nearest_even"(%operand) : (tensor<5xf64>) -> tensor<5xf64>
// %result: [-2.0, 0.0, 0.0, 1.0, 2.0]

 Contoh Lainnya

rsqrt

Semantik

Melakukan operasi akar kuadrat reciprocal element-wise pada tensor operand dan menghasilkan tensor result. Bergantung pada jenis elemen, lakukan tindakan berikut:

  • Untuk float: rSqrt dari IEEE-754.
  • Untuk bilangan kompleks: akar kuadrat kebalikan kompleks.
  • Untuk jenis kuantisasi: dequantize_op_quantize(rsqrt, operand, type(result)).

Input

Label Nama Jenis Batasan
(I1) operand tensor floating point atau jenis kompleks atau tensor terkuantisasi per-tensor (C1)

Output

Nama Jenis Batasan
result tensor dari jenis floating point atau kompleks atau tensor terkuantisasi per tensor (C1)

Batasan

  • (C1) baseline_type(operand) = baseline_type(result).

Contoh

// %operand: [[1.0, 4.0], [9.0, 25.0]]
%result = "stablehlo.rsqrt"(%operand) : (tensor<2x2xf32>) -> tensor<2x2xf32>
// %result: [[1.0, 0.5], [0.33333343, 0.2]]

 Contoh Lainnya

menyebar

Semantik

Menghasilkan tensor results yang sama dengan tensor inputs, kecuali beberapa slice yang ditentukan oleh scatter_indices diperbarui dengan nilai updates menggunakan update_computation.

Diagram berikut menunjukkan cara elemen di updates... memetakan elemen di results... menggunakan contoh konkret. Diagram ini memilih beberapa contoh indeks updates... dan menjelaskan secara mendetail indeks results... yang sesuai.

{i>scatter <i}(memencar)

Secara lebih formal, untuk semua update_index di index_space(updates[0]):

  • update_scatter_dims = [d for d in axes(updates[0]) and d not in update_window_dims].
  • update_scatter_index = update_index[update_scatter_dims...].
  • start_index ditentukan sebagai:
    • scatter_indices[si0, ..., :, ..., siN] dengan si adalah elemen individual di update_scatter_index dan : disisipkan pada indeks index_vector_dim, jika index_vector_dim < rank(scatter_indices).
    • [scatter_indices[update_scatter_index]] jika tidak.
  • Untuk d_input di axes(inputs[0]),
    • full_start_index[d_input] = start_index[d_start] jika d_input = scatter_dims_to_operand_dims[d_start].
    • full_start_index[d_input] = 0 jika tidak.
  • Untuk d_input di axes(inputs[0]),
    • full_batching_index[d_input] = update_scatter_index[d_start - (d_start < index_vector_dim ? 0 : 1)] jika d_input = input_batching_dims[i_batching] dan d_start = scatter_indices_batching_dims[i_batching].
    • full_batching_index[d_input] = 0 jika tidak.
  • update_window_index = update_index[update_window_dims...].
  • full_window_index = [wi0, ..., 0, ..., wiN] dengan wi adalah elemen individual di update_window_index, dan 0 disisipkan pada indeks dari inserted_window_dims dan input_batching_dims.
  • result_index = full_start_index + full_batching_index + full_window_index.

Dengan demikian, results = exec(schedule, inputs), dengan:

  • schedule adalah permutasi index_space(updates[0]) yang ditentukan oleh implementasi.
  • exec([update_index, ...], results) = exec([...], updated_results) dengan:
    • Jika result_index berada dalam batas untuk shape(results...)
    • updates_converted = to_destination_type( updates...[update_index], type(func_inputs(update_computation) [len(func_inputs(update_computation))//2:])... )
    • updated_values = update_computation(results...[result_index], updates_converted)
    • updated_results adalah salinan results dengan results...[result_index] ditetapkan ke updated_values....
    • Atau
    • updated_results = results.
  • exec([], results) = results.

Jika indices_are_sorted adalah true, implementasi dapat mengasumsikan bahwa scatter_indices diurutkan sehubungan dengan scatter_dims_to_operand_dims, jika tidak, perilaku tidak ditentukan. Secara lebih formal, untuk semua i1 < i2 dari indices(result), full_start_index(i1) <= full_start_index(i2).

Jika unique_indices adalah true, implementasi dapat mengasumsikan bahwa semua indeks result_index yang tersebar bersifat unik. Jika unique_indices adalah true, tetapi indeks yang disebar tidak unik, perilakunya tidak ditentukan.

Input

Label Nama Jenis Batasan
(I1) inputs jumlah tensor atau tensor terkuantisasi per-tensor (C1), (C2), (C4-C6), (C11), (C13), (C18), (C21), (C23-C24)
(I2) scatter_indices tensor dari jenis bilangan bulat (C4), (C15), (C19), (C22)
(I3) updates jumlah variabel tensor atau tensor terkuantisasi per tensor (C3-C6), (C8)
(I4) update_window_dims Konstanta tensor 1 dimensi dari jenis si64 (C2), (C4), (C7-C8)
(I5) inserted_window_dims Konstanta tensor 1 dimensi jenis si64 (C2), (C4), (C9-C11)
(I6) input_batching_dims Konstanta tensor 1 dimensi dari jenis si64 (C2), (C4), (C9), (C12-13), (C17-18), (C20)
(I7) scatter_indices_batching_dims Konstanta tensor 1 dimensi dari jenis si64 (C14-C18)
(I8) scatter_dims_to_operand_dims Konstanta tensor 1 dimensi dari jenis si64 (C19-C21)
(I9) index_vector_dim konstanta dari jenis si64 (C4), (C16), (C19), (C22)
(I10) indices_are_sorted konstanta dari jenis i1
(I11) unique_indices konstanta dari jenis i1
(I12) update_computation fungsi (C23)

Output

Nama Jenis Batasan
results jumlah variabel tensor atau tensor terkuantisasi per tensor (C24-C25)

Batasan

  • (C1) same(shape(inputs...)).
  • (C2) `rank(inputs[0]) = size(update_window_dims) + size(inserted_window_dims)
    • size(input_batching_dims)`.
  • (C3) same(shape(updates...)).
  • (C4) shape(updates[0]) = combine(update_scatter_dim_sizes, update_window_dim_sizes) dengan:
    • update_scatter_dim_sizes = shape(scatter_indices), tetapi ukuran dimensi scatter_indices yang sesuai dengan index_vector_dim tidak disertakan.
    • update_window_dim_sizes <= shape(inputs[0]), kecuali bahwa ukuran dimensi di inputs[0] yang sesuai dengan inserted_window_dims dan input_batching_dims tidak disertakan.
    • combine menempatkan update_scatter_dim_sizes pada sumbu yang sesuai dengan update_scatter_dims dan update_window_dim_sizes pada sumbu yang sesuai dengan update_window_dims.
  • (C5) 0 < size(inputs) = size(updates) = N.
  • (C6) element_type(updates...) = element_type(inputs...).
  • (C7) is_unique(update_window_dims) and is_sorted(update_window_dims).
  • (C8) 0 <= update_window_dims < rank(updates[0]).
  • (C9) is_unique(concatenate(inserted_window_dims, input_batching_dims))
  • (C10) is_sorted(inserted_window_dims).
  • (C11) 0 <= inserted_window_dims < rank(inputs[0]).
  • (C12) is_sorted(input_batching_dims).
  • (C13) 0 <= input_batching_dims < rank(inputs[0])).
  • (C14) is_unique(scatter_indices_batching_dims).
  • (C15) 0 <= scatter_indices_batching_dims < rank(scatter_indices).
  • (C16) index_vector_dim not in scatter_indices_batching_dims.
  • (C17) size(input_batching_dims) == size(scatter_indices_batching_dims).
  • (C18) dim(inputs[0], input_batching_dims...) = dim(scatter_indices, scatter_indices_batching_dims...).
  • (C19) size(scatter_dims_to_operand_dims) = index_vector_dim < rank(scatter_indices) ? dim(scatter_indices, index_vector_dim) : 1.
  • (C20) is_unique(concatenate(scatter_dims_to_operand_dims, input_batching_dims)).
  • (C21) 0 <= scatter_dims_to_operand_dims < rank(inputs[0]).
  • (C22) 0 <= index_vector_dim <= rank(scatter_indices).
  • (C23) update_computation memiliki jenis (tensor<E0>, ..., tensor<EN-1>, tensor<E0>, ..., tensor<EN-1>) -> (tensor<E0>, ..., tensor<EN-1>), dengan is_promotable(element_type(inputs[i]), Ei).
  • (C24) shape(inputs...) = shape(results...).
  • (C25) element_type(results[i]) = Ei untuk semua i di [0,N).

Contoh

// %input: [
//          [
//           [[1, 2], [3, 4], [5, 6], [7, 8]],
//           [[9, 10],[11, 12], [13, 14], [15, 16]],
//           [[17, 18], [19, 20], [21, 22], [23, 24]]
//          ],
//          [
//           [[25, 26], [27, 28], [29, 30], [31, 32]],
//           [[33, 34], [35, 36], [37, 38], [39, 40]],
//           [[41, 42], [43, 44], [45, 46], [47, 48]]
//          ]
//         ]
// %scatter_indices: [
//                    [
//                     [[0, 0], [1, 0], [2, 1]],
//                     [[0, 1], [1, 1], [0, 9]]
//                    ],
//                    [
//                     [[0, 0], [2, 1], [2, 2]],
//                     [[1, 2], [0, 1], [1, 0]]
//                    ]
//                   ]
// %update: [
//           [
//            [[1, 1], [1, 1], [1, 1]],
//            [[1, 1], [1, 1], [1, 1]]
//           ],
//           [
//            [[1, 1], [1, 1], [1, 1]],
//            [[1, 1], [1, 1], [1, 1]]
//           ]
//          ]
%result = "stablehlo.scatter"(%input, %scatter_indices, %update) ({
  ^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
    %0 = "stablehlo.add"(%arg0, %arg1) : (tensor<i64>, tensor<i64>) -> tensor<i64>
    "stablehlo.return"(%0) : (tensor<i64>) -> ()
}) {
  scatter_dimension_numbers = #stablehlo.scatter<
    update_window_dims = [3, 4],
    inserted_window_dims = [1],
    input_batching_dims = [0],
    scatter_indices_batching_dims = [1],
    scatter_dims_to_operand_dims = [2, 1],
    index_vector_dim = 3>,
  indices_are_sorted = false,
  unique_indices = false
} : (tensor<2x3x4x2xi64>, tensor<2x2x3x2xi64>, tensor<2x2x3x2x2xi64>) -> tensor<2x3x4x2xi64>
// %result: [
//           [
//            [[3, 4], [6, 7], [6, 7], [7, 8]],
//            [[9, 10],[11, 12], [15, 16], [17, 18]],
//            [[17, 18], [19, 20], [22, 23], [24, 25]]
//           ],
//           [
//            [[25, 26], [28, 29], [30, 31], [31, 32]],
//            [[35, 36], [38, 39], [38, 39], [39, 40]],
//            [[41, 42], [44, 45], [46, 47], [47, 48]]
//           ]
//          ]

 Contoh Lainnya

pilih

Semantik

Menghasilkan tensor result dengan setiap elemen dipilih dari tensor on_true atau on_false berdasarkan nilai elemen pred yang sesuai. Secara lebih formal, result[result_index] = pred_element ? on_true[result_index] : on_false[result_index], dengan pred_element = rank(pred) = 0 ? pred[] : pred[result_index]. Untuk jenis terkuantisasi, menjalankan dequantize_select_quantize(pred, on_true, on_false, type(result)).

Input

Label Nama Jenis Batasan
(I1) pred tensor jenis i1 (C1)
(I2) on_true tensor atau tensor terkuantisasi per-tensor (C1-C2)
(I3) on_false tensor atau tensor terkuantisasi per tensor (C2)

Output

Nama Jenis Batasan
result tensor atau tensor terkuantisasi per tensor (C2)

Batasan

  • (C1) rank(pred) = 0 or shape(pred) = shape(on_true).
  • (C2) baseline_type(on_true) = baseline_type(on_false) = baseline_type(result).

Contoh

// %pred: [[false, true], [true, false]]
// %on_true: [[1, 2], [3, 4]]
// %on_false: [[5, 6], [7, 8]]
%result = "stablehlo.select"(%pred, %on_true, %on_false) : (tensor<2x2xi1>, tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
// %result: [[5, 2], [3, 8]]

 Contoh Lainnya

select_and_scatter

Semantik

Menyebarkan nilai dari tensor source menggunakan scatter berdasarkan hasil reduce_window dari tensor input menggunakan select dan menghasilkan tensor result.

Diagram berikut menunjukkan cara elemen di result dihitung dari operand dan source menggunakan contoh konkret.

select_and_scatter

Secara lebih formal:

  • selected_values = reduce_window_without_init(...) dengan input berikut:

    • inputs = [operand].
    • window_dimensions, window_strides, dan padding yang digunakan apa adanya.
    • base_dilations = windows_dilations = 1.
    • body ditentukan sebagai:
    def body(arg0: tensor<E>, arg1: tensor<E>) -> tensor<E>:
      return select(arg0, arg1) ? arg0 : arg1;
    

    dengan E = element_type(operand), dan reduce_window_without_init berfungsi persis seperti reduce_window, kecuali bahwa schedule dari reduce yang mendasarinya (lihat reduce) tidak menyertakan nilai init. Saat ini, tidak ditentukan apa yang akan terjadi jika jendela yang sesuai tidak memiliki nilai (#731).

  • result[result_index] = reduce([source_values], [init_value], [0], scatter) dengan:

    • source_values = [source[source_index] for source_index in source_indices].
    • selected_index(source_index) = operand_index jika selected_values[source_index] memiliki elemen operand dari operand_index.
    • source_indices = [source_index for source_index in indices(source) if selected_index(source_index) = result_index].

Input

Label Nama Jenis Batasan
(I1) operand tensor atau tensor terkuantisasi per-tensor (C1-C4), (C6), (C8-C11)
(I2) source tensor atau tensor terkuantisasi per tensor (C1), (C2)
(I3) init_value Tensor 0 dimensi atau tensor terkuantisasi per tensor (C3)
(I4) window_dimensions Konstanta tensor 1 dimensi dari jenis si64 (C2), (C4), (C5)
(I5) window_strides Konstanta tensor 1 dimensi dari jenis si64 (C2), (C6), (C7)
(I6) padding Konstanta tensor 2 dimensi dari jenis si64 (C2), (C8)
(I7) select fungsi (C9)
(I8) scatter fungsi (C10)

Output

Nama Jenis Batasan
result tensor atau tensor terkuantisasi per tensor (C11-C12)

Batasan

  • (C1) element_type(operand) = element_type(source).
  • (C2) shape(source) = num_windows dengan:
    • padded_operand_shape = padding[:, 0] + shape(operand) + padding[:, 1].
    • is_empty_window = padded_operand_shape = 0 || window_dimensions > padded_operand_shape.
    • num_windows = is_empty_window ? 0 : floor((padded_operand_shape - window_dimensions) / window_strides) + 1.
  • (C3) element_type(init_value) = element_type(operand).
  • (C4) size(window_dimensions) = rank(operand).
  • (C5) 0 < window_dimensions.
  • (C6) size(window_strides) = rank(operand).
  • (C7) 0 < window_strides.
  • (C8) shape(padding) = [rank(operand), 2].
  • (C9) select memiliki jenis (tensor<E>, tensor<E>) -> tensor<i1> dengan E = element_type(operand).
  • (C10) scatter memiliki jenis (tensor<E>, tensor<E>) -> tensor<E> dengan is_promotable(element_type(operand), E).
  • (C11) shape(operand) = shape(result).
  • (C12) element_type(result) = E.

Contoh

// %operand: [[1, 5], [2, 5], [3, 6], [4, 4]]
// %source: [[5, 6], [7, 8]]
// %init_value: 0
%result = "stablehlo.select_and_scatter"(%operand, %source, %init_value) ({
  ^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
    %0 = "stablehlo.compare"(%arg0, %arg1) {
      comparison_direction = #stablehlo<comparison_direction GE>
    } : (tensor<i64>, tensor<i64>) -> tensor<i1>
    "stablehlo.return"(%0) : (tensor<i1>) -> ()
}, {
  ^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
    %0 = "stablehlo.add"(%arg0, %arg1) : (tensor<i64>, tensor<i64>) -> tensor<i64>
    "stablehlo.return"(%0) : (tensor<i64>) -> ()
}) {
  window_dimensions = array<i64: 3, 1>,
  window_strides = array<i64: 2, 1>,
  padding = dense<[[0, 1], [0, 0]]> : tensor<2x2xi64>
} : (tensor<4x2xi64>, tensor<2x2xi64>, tensor<i64>) -> tensor<4x2xi64>
// %result: [[0, 0], [0, 0], [5, 14], [7, 0]]

 Contoh Lainnya

kirim

Semantik

Mengirim inputs ke saluran channel_id dan menghasilkan token result.

Jika is_host_transfer adalah true, operasi akan mentransfer data ke host. Jika tidak, data akan ditransfer ke perangkat lain. Artinya, ditentukan oleh implementasi. Flag ini menduplikasi informasi yang diberikan di channel_type, sehingga di masa mendatang kami berencana untuk menyimpan salah satunya saja (#666).

Input

Label Nama Jenis Batasan
(I1) inputs jumlah variabel tensor atau tensor terkuantisasi
(I2) token token
(I3) channel_id konstanta dari jenis si64
(I4) channel_type enum DEVICE_TO_DEVICE dan DEVICE_TO_HOST (C1)
(I5) is_host_transfer konstanta dari jenis i1 (C1)

Output

Nama Jenis
result token

Batasan

  • (C1) channel_type didefinisikan sebagai:
    • DEVICE_TO_HOST if is_host_transfer = true,
    • DEVICE_TO_DEVICE sebaliknya.

Contoh

%result = "stablehlo.send"(%operand, %token) {
  channel_handle = #stablehlo.channel_handle<handle = 1, type = 2>,
  is_host_transfer = true
} : (tensor<2x2xi64>, !stablehlo.token) -> !stablehlo.token

 Contoh Lainnya

shift_left

Semantik

Melakukan operasi pergeseran kiri element-wise pada tensor lhs dengan jumlah bit rhs dan menghasilkan tensor result.

Input

Label Nama Jenis Batasan
(I1) lhs tensor dari jenis bilangan bulat (C1)
(I2) rhs tensor dari jenis bilangan bulat (C1)

Output

Nama Jenis Batasan
result tensor dari jenis bilangan bulat (C1)

Batasan

  • (C1) type(lhs) = type(rhs) = type(result).

Contoh

// %lhs: [-1, 0, 1]
// %rhs: [1, 2, 3]
%result = "stablehlo.shift_left"(%lhs, %rhs): (tensor<3xi64>, tensor<3xi64>) -> tensor<3xi64>
// %result: [-2, 0, 8]

 Contoh Lainnya

shift_right_arithmetic

Semantik

Melakukan operasi pergeseran kanan aritmetika element-wise pada tensor lhs dengan jumlah bit rhs dan menghasilkan tensor result.

Input

Label Nama Jenis Batasan
(I1) lhs tensor dari jenis bilangan bulat (C1)
(I2) rhs tensor dari jenis bilangan bulat (C1)

Output

Nama Jenis Batasan
result tensor dari jenis bilangan bulat (C1)

Batasan

  • (C1) type(lhs) = type(rhs) = type(result).

Contoh

// %lhs: [-1, 0, 8]
// %rhs: [1, 2, 3]
%result = "stablehlo.shift_right_arithmetic"(%lhs, %rhs): (tensor<3xi64>, tensor<3xi64>) -> tensor<3xi64>
// %result: [-1, 0, 1]

 Contoh Lainnya

shift_right_logical

Semantik

Melakukan operasi shift kanan logis berbasis element pada tensor lhs berdasarkan jumlah bit rhs dan menghasilkan tensor result.

Input

Label Nama Jenis Batasan
(I1) lhs tensor dari jenis bilangan bulat (C1)
(I2) rhs tensor dari jenis bilangan bulat (C1)

Output

Nama Jenis Batasan
result tensor dari jenis bilangan bulat (C1)

Batasan

  • (C1) type(lhs) = type(rhs) = type(result).

Contoh

// %lhs: [-1, 0, 8]
// %rhs: [1, 2, 3]
%result = "stablehlo.shift_right_logical"(%lhs, %rhs): (tensor<3xi64>, tensor<3xi64>) -> tensor<3xi64>
// %result: [9223372036854775807, 0, 1]

Contoh Lainnya

tanda

Semantik

Menampilkan tanda element-wise operand dan menghasilkan tensor result. Secara lebih formal, untuk setiap elemen x, semantik dapat dinyatakan menggunakan sintaksis Python sebagai berikut:

def sign(x):
  if is_integer(x):
    if compare(x, 0, LT, SIGNED): return -1
    if compare(x, 0, EQ, SIGNED): return 0
    return 1
  elif is_float(x):
    if is_nan(x): return NaN
    if compare(x, -0.0, EQ, FLOAT): return -0.0
    if compare(x, +0.0, EQ, FLOAT): return +0.0
    if compare(x, 0.0, LT, FLOAT): return -1.0
    return 1.0
  elif is_complex(x):
    if is_nan(real(x)) or is_nan(imag(x)): return (NaN, NaN)
    if compare(x, (0.0, 0.0), EQ, FLOAT): return (0.0, 0.0)
    return divide(x, convert(abs(x), type(x)))

Untuk jenis kuantisasi, lakukan dequantize_op_quantize(sign, operand, type(result)).

Input

Label Nama Jenis Batasan
(I1) operand tensor bilangan bulat bertanda tangan, floating point, atau jenis kompleks atau tensor terkuantisasi per tensor (C1)

Output

Nama Jenis Batasan
result tensor bilangan bulat bertanda tangan, floating point, atau tipe kompleks atau tensor terkuantisasi per-tensor (C1)

Batasan

  • (C1) baseline_type(operand) = baseline_type(result).

Contoh

// Logical values: +NaN, -1.0, -0.0, +0.0, 1.0
// operand: [0x7FFFFFFFFFFFFFFF, -1.0, -0.0, 0.0, 1.0]
%result = "stablehlo.sign"(%operand) : (tensor<5xf64>) -> tensor<5xf64>
// Logical values: +NaN, -1.0, -0.0, +0.0, 1.0
// %result: [0x7FFFFFFFFFFFFFFF, -1.0, -0.0, 0.0, 1.0]

 Contoh Lainnya

sinus

Semantik

Melakukan operasi sinus berdasarkan elemen pada tensor operand dan menghasilkan tensor result. Bergantung pada jenis elemen, melakukan hal berikut:

  • Untuk float: sin dari IEEE-754.
  • Untuk bilangan kompleks: sinus kompleks.
  • Untuk jenis terkuantisasi: dequantize_op_quantize(sine, operand, type(result)).

Input

Label Nama Jenis Batasan
(I1) operand tensor floating point atau jenis kompleks atau tensor terkuantisasi per-tensor (C1)

Output

Nama Jenis Batasan
result tensor dari jenis floating point atau kompleks atau tensor terkuantisasi per tensor (C1)

Batasan

  • (C1) baseline_type(operand) = baseline_type(result).

Contoh

// %operand: [
//            [0.0, 1.57079632],       // [0, pi/2]
//            [3.14159265, 4.71238898] // [pi, 3pi/2]
//           ]
%result = "stablehlo.sine"(%operand) : (tensor<2x2xf32>) -> tensor<2x2xf32>
// %result: [[0.0, 1.0], [0.0, -1.0]]

 Contoh Lainnya

slice

Semantik

Mengekstrak slice dari operand menggunakan indeks awal yang dihitung secara statis dan menghasilkan tensor result. start_indices berisi indeks awal slice untuk setiap dimensi, limit_indices berisi indeks akhir (eksklusif) untuk slice untuk setiap dimensi, dan strides berisi stride untuk setiap dimensi.

Secara lebih formal, result[result_index] = operand[operand_index] dengan operand_index = start_indices + result_index * strides.

Input

Label Nama Jenis Batasan
(I1) operand tensor atau tensor terkuantisasi per tensor (C1-C3), (C5)
(I2) start_indices Konstanta tensor 1 dimensi dari jenis si64 (C2), (C3), (C5)
(I3) limit_indices Konstanta tensor 1 dimensi dari jenis si64 (C2), (C3), (C5)
(I4) strides Konstanta tensor 1 dimensi dari jenis si64 (C2), (C4)

Output

Nama Jenis Batasan
result tensor atau tensor terkuantisasi per tensor (C1), (C5)

Batasan

  • (C1) element_type(operand) = element_type(result).
  • (C2) size(start_indices) = size(limit_indices) = size(strides) = rank(operand).
  • (C3) 0 <= start_indices <= limit_indices <= shape(operand).
  • (C4) 0 < strides.
  • (C5) shape(result) = ceil((limit_indices - start_indices) / strides).

Contoh

// %operand: [
//            [0, 0, 0, 0],
//            [0, 0, 1, 1],
//            [0, 0, 1, 1]
//           ]
%result = "stablehlo.slice"(%operand) {
  start_indices = array<i64: 1, 2>,
  limit_indices = array<i64: 3, 4>,
  strides = array<i64: 1, 1>
} : (tensor<3x4xi64>) -> tensor<2x2xi64>
// % result: [
//            [1, 1],
//            [1, 1]
//           ]

 Contoh Lainnya

mengurutkan

Semantik

Mengurutkan slice inputs 1 dimensi di sepanjang dimensi dimension secara bersamaan, sesuai dengan comparator dan menghasilkan results.

Tidak seperti input serupa dalam operasi lain, dimension mengizinkan nilai negatif, dengan semantik yang dijelaskan di bawah. Di masa mendatang, hal ini mungkin tidak diizinkan karena alasan konsistensi (#1377).

Jika is_stable bernilai benar, pengurutan akan stabil, yaitu urutan relatif elemen yang dianggap sama oleh pembanding akan dipertahankan. Untuk kasus saat ada satu input, dua elemen e1 dan e2 dianggap sama oleh pembanding jika dan hanya jika comparator(e1, e2) = comparator(e2, e1) = false. Lihat formalisasi di bawah untuk mengetahui cara generalisasi ini ke beberapa input.

Secara lebih formal, untuk semua result_index di index_space(results[0]):

  • adjusted_dimension = dimension >= 0 ? dimension : rank(inputs[0]) + dimension.
  • result_slice = [ri0, ..., :, ..., riR-1] dengan riN adalah elemen individual di result_index, dan : disisipkan pada adjusted_dimension.
  • inputs_together = (inputs[0]..., ..., inputs[N-1]...).
  • results_together[result_slice] = sort(inputs_together[result_slice], comparator_together).
  • dengan sort mengurutkan slice 1 dimensi dalam urutan non-menurun yang mengharapkan comparator_together menampilkan true jika argumen sisi kiri lebih kecil dari argumen kedua sisi kanan.
  • def comparator_together(lhs_together, rhs_together):
      args = []
      for (lhs_el, rhs_el) in zip(lhs_together, rhs_together):
        args.append(lhs_el)
        args.append(rhs_el)
      return comparator(*args)
    
  • (results[0]..., ..., results[N-1]...) = results_together.

Input

Label Nama Jenis Batasan
(I1) inputs jumlah variabel tensor atau tensor terkuantisasi per tensor (C1-C5)
(I2) dimension konstanta dari jenis si64 (C4)
(I3) is_stable konstanta dari jenis i1
(I4) comparator fungsi (C5)

Output

Nama Jenis Batasan
results jumlah tensor atau tensor terkuantisasi per-tensor (C2), (C3)

Batasan

  • (C1) 0 < size(inputs).
  • (C2) type(inputs...) = type(results...).
  • (C3) same(shape(inputs...) + shape(results...)).
  • (C4) -R <= dimension < R, dengan R = rank(inputs[0]).
  • (C5) comparator memiliki jenis (tensor<E1>, tensor<E1>, ..., tensor<EN-1>, tensor<EN-1>) -> tensor<i1>, dengan Ei = element_type(inputs[i]).

Contoh

// %input0 = [[1, 2, 3], [3, 2, 1]]
// %input1 = [[3, 2, 1], [1, 2, 3]]
%result0, %result1 = "stablehlo.sort"(%input0, %input1) ({
  ^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>, %arg2: tensor<i64>, %arg3: tensor<i64>):
    %predicate = "stablehlo.compare"(%arg0, %arg1) {
      comparison_direction = #stablehlo<comparison_direction GT>
    } : (tensor<i64>, tensor<i64>) -> tensor<i1>
    "stablehlo.return"(%predicate) : (tensor<i1>) -> ()
}) {
  dimension = 0 : i64,
  is_stable = true
} : (tensor<2x3xi64>, tensor<2x3xi64>) -> (tensor<2x3xi64>, tensor<2x3xi64>)
// %result0 = [[3, 2, 3], [1, 2, 1]]
// %result1 = [[1, 2, 1], [3, 2, 3]]

 Contoh Lainnya

sqrt

Semantik

Melakukan operasi akar kuadrat element-wise pada tensor operand dan menghasilkan tensor result. Bergantung pada jenis elemen, lakukan tindakan berikut:

  • Untuk float: squareRoot dari IEEE-754.
  • Untuk bilangan kompleks: akar kuadrat kompleks.
  • Untuk jenis kuantisasi: dequantize_op_quantize(sqrt, operand, type(result)).

Input

Label Nama Jenis Batasan
(I1) operand tensor floating point atau jenis kompleks atau tensor terkuantisasi per-tensor (C1)

Output

Nama Jenis Batasan
result tensor dari jenis floating point atau kompleks atau tensor terkuantisasi per tensor (C1)

Batasan

  • (C1) baseline_type(operand) = baseline_type(result).

Contoh

// %operand: [[0.0, 1.0], [4.0, 9.0]]
%result = "stablehlo.sqrt"(%operand) : (tensor<2x2xf32>) -> tensor<2x2xf32>
// %result: [[0.0, 1.0], [2.0, 3.0]]

 Contoh Lainnya

kurangi

Semantik

Melakukan pengurangan berbasis elemen dari dua tensor lhs dan rhs, serta menghasilkan tensor result. Bergantung pada jenis elemen, melakukan hal berikut:

  • Untuk bilangan bulat: pengurangan bilangan bulat.
  • Untuk float: subtraction dari IEEE-754.
  • Untuk bilangan kompleks: pengurangan kompleks.
  • Untuk jenis kuantisasi:
    • dequantize_op_quantize(subtract, lhs, rhs, type(result)).

Input

Label Nama Jenis Batasan
(I1) lhs tensor bilangan bulat, floating point, atau jenis kompleks atau tensor terkuantisasi per tensor (C1)
(I2) rhs tensor bilangan bulat, floating point, atau jenis kompleks atau tensor terkuantisasi per tensor (C1)

Output

Nama Jenis Batasan
result tensor bilangan bulat, floating point, atau jenis kompleks atau tensor terkuantisasi per tensor (C1)

Batasan

  • (C1) baseline_type(lhs) = baseline_type(rhs) = baseline_type(result).

Contoh

// %lhs: [[6, 8], [10, 12]]
// %rhs: [[5, 6], [7, 8]]
%result = "stablehlo.subtract"(%lhs, %rhs) : (tensor<2x2xf32>, tensor<2x2xf32>) -> (tensor<2x2xf32>)
// %result: [[1, 2], [3, 4]]

 Contoh Lainnya

tan

Semantik

Melakukan operasi tangen element-wise pada tensor operand dan menghasilkan tensor result. Bergantung pada jenis elemen, lakukan tindakan berikut:

  • Untuk float: tan dari IEEE-754.
  • Untuk bilangan kompleks: tangen kompleks.
  • Untuk jenis terkuantisasi: dequantize_op_quantize(tan, operand, type(result)).

Input

Label Nama Jenis Batasan
(I1) operand tensor floating point atau jenis kompleks atau tensor terkuantisasi per-tensor (C1)

Output

Nama Jenis Batasan
result tensor dari jenis floating point atau kompleks atau tensor terkuantisasi per tensor (C1)

Batasan

  • (C1) baseline_type(operand) = baseline_type(result).

Contoh

// %operand: [
//            [0.0, 1.57079632],       // [0, pi/2]
//            [3.14159265, 4.71238898] // [pi, 3pi/2]
//           ]
%result = "stablehlo.tan"(%operand) : (tensor<2x2xf64>) -> tensor<2x2xf64>
// %result: [
//           [0.0, 1.63312e+16],
//           [0.0, 5.44375e+15]
//          ]

 Contoh Lainnya

tanh

Semantik

Melakukan operasi tangen hiperbolik element-wise pada tensor operand dan menghasilkan tensor result. Bergantung pada jenis elemen, lakukan tindakan berikut:

  • Untuk float: tanh dari IEEE-754.
  • Untuk bilangan kompleks: tangen hiperbolik kompleks.
  • Untuk jenis terkuantisasi:
    • dequantize_op_quantize(tanh, operand, type(result)).

Input

Label Nama Jenis Batasan
(I1) operand tensor floating point atau jenis kompleks atau tensor terkuantisasi per-tensor (C1)

Output

Nama Jenis Batasan
result tensor dari jenis floating point atau kompleks atau tensor terkuantisasi per tensor (C1)

Batasan

  • (C1) baseline_type(operand) = baseline_type(result).

Contoh

// %operand: [-1.0, 0.0, 1.0]
%result = "stablehlo.tanh"(%operand) : (tensor<3xf32>) -> tensor<3xf32>
// %result: [-0.76159416, 0.0, 0.76159416]

 Contoh Lainnya

transpose

Semantik

Mengurutkan ulang dimensi tensor operand menggunakan permutation dan menghasilkan tensor result. Secara lebih formal, result[result_index] = operand[operand_index] dengan result_index[d] = operand_index[permutation[d]].

Input

Label Nama Jenis Batasan
(I1) operand tensor atau tensor terkuantisasi (C1-C4)
(I2) permutation Konstanta tensor 1 dimensi dari jenis si64 (C2-C4)

Output

Nama Jenis Batasan
result tensor atau tensor terkuantisasi (C1), (C3-C4)

Batasan

  • (C1) element_type(result) diberikan oleh:
    • element_type(operand), jika !is_per_axis_quantized(operand).
    • element_type(operand) kecuali jika quantization_dimension(operand) dan quantization_dimension(result) mungkin berbeda.
  • (C2) permutation adalah permutasi dari range(rank(operand)).
  • (C3) shape(result) = dim(operand, permutation...).
  • (C4) Jika is_per_axis_quantized(result), maka quantization_dimension(operand) = permutation(quantization_dimension(result)).

Contoh

// %operand: [
//            [[1,2], [3,4], [5,6]],
//            [[7,8], [9,10], [11,12]]
//           ]
%result = "stablehlo.transpose"(%operand) {
  permutation = array<i64: 2, 1, 0>
} : (tensor<2x3x2xi32>) -> tensor<2x3x2xi32>
// %result: [
//           [[1,7], [3,9], [5,11]],
//           [[2,8], [4,10], [6,12]]
//          ]

 Contoh Lainnya

triangular_solve

Semantik

Menyelesaikan batch sistem persamaan linear dengan matriks koefisien segitiga bawah atau atas.

Secara lebih formal, dengan a dan b, result[i0, ..., iR-3, :, :] adalah solusi untuk op(a[i0, ..., iR-3, :, :]) * x = b[i0, ..., iR-3, :, :] jika left_side adalah true atau x * op(a[i0, ..., iR-3, :, :]) = b[i0, ..., iR-3, :, :] jika left_side adalah false, yang menyelesaikan variabel x dengan op(a) ditentukan oleh transpose_a, yang dapat berupa salah satu dari hal berikut:

  • NO_TRANSPOSE: Menjalankan operasi menggunakan a sebagaimana adanya.
  • TRANSPOSE: Melakukan operasi pada transpose a.
  • ADJOINT: Melakukan operasi pada transpos konjugat a.

Data input hanya dibaca dari segitiga bawah a, jika lower adalah true atau segitiga atas a, jika tidak. Data output ditampilkan dalam segitiga yang sama; nilai dalam segitiga lainnya ditentukan oleh implementasi.

Jika unit_diagonal benar, implementasi dapat mengasumsikan bahwa elemen diagonal a sama dengan 1, jika tidak, perilakunya tidak ditentukan.

Untuk jenis kuantisasi, lakukan dequantize_op_quantize(lambda x, y: triangular_solve(x, y, left_side, lower, unit_diagonal, transpose_a), a, b, type(result)).

Input

Label Nama Jenis Batasan
(I1) a tensor dari jenis floating point atau kompleks atau tensor terkuantisasi per tensor (C1-C3)
(I2) b tensor floating point atau jenis kompleks atau tensor terkuantisasi per-tensor (C1-C4)
(I3) left_side konstanta dari jenis i1 (C3)
(I4) lower konstanta dari jenis i1
(I5) unit_diagonal konstanta dari jenis i1
(I6) transpose_a enum NO_TRANSPOSE, TRANSPOSE, dan ADJOINT

Output

Nama Jenis Batasan
result tensor dari jenis floating point atau kompleks atau tensor terkuantisasi per tensor (C1)

Batasan

  • (C1) baseline_element_type(a) = baseline_element_type(b).
  • (C2) 2 <= rank(a) = rank(b) = R.
  • (C3) Hubungan antara shape(a) dan shape(b) ditentukan sebagai berikut:
    • shape(a)[:-3] = shape(b)[:-3].
    • dim(a, -2) = dim(a, -1) = dim(b, left_side ? -2 : -1).
  • (C4) baseline_type(b) = baseline_type(result).

Contoh

// %a = [
//       [1.0, 0.0, 0.0],
//       [2.0, 4.0, 0.0],
//       [3.0, 5.0, 6.0]
//      ]
// %b = [
//       [2.0, 0.0, 0.0],
//       [4.0, 8.0, 0.0],
//       [6.0, 10.0, 12.0]
//      ]
%result = "stablehlo.triangular_solve"(%a, %b) {
  left_side = true,
  lower = true,
  unit_diagonal = false,
  transpose_a = #stablehlo<transpose NO_TRANSPOSE>
} : (tensor<3x3xf32>, tensor<3x3xf32>) -> tensor<3x3xf32>
// %result: [
//           [2.0, 0.0, 0.0],
//           [0.0, 2.0, 0.0],
//           [0.0, 0.0, 2.0]
//          ]

tuple

Semantik

Menghasilkan tuple result dari nilai val.

Input

Label Nama Jenis Batasan
(I1) val jumlah nilai variadik (C1)

Output

Nama Jenis Batasan
result tuple (C1)

Batasan

  • (C1) result memiliki jenis tuple<E0, ..., EN-1> dengan Ei = type(val[i]).

Contoh

// %val0: [1.0, 2.0]
// %val1: (3)
%result = "stablehlo.tuple"(%val0, %val1) : (tensor<2xf32>, tuple<tensor<i32>>) -> tuple<tensor<2xf32>, tuple<tensor<i32>>>
// %result: ([1.0, 2.0], (3))

 Contoh Lainnya

uniform_dequantize

Semantik

Melakukan konversi element-wise dari tensor terkuantisasi operand ke tensor floating point result sesuai dengan parameter kuantisasi yang ditentukan oleh jenis operand.

Secara lebih formal, result = dequantize(operand).

Input

Label Nama Jenis Batasan
(I1) operand tensor terkuantisasi (C1), (C2)

Output

Nama Jenis Batasan
result tensor jenis floating point (C1), (C2)

Batasan

  • (C1) shape(operand) = shape(result).
  • (C2) element_type(result) = expressed_type(operand).

Contoh

// %operand: [10, 10]
%result = "stablehlo.uniform_dequantize"(%operand) : (tensor<2x!quant.uniform<i8:f32:0, {0.1:-30,0.5:-20}>>) -> tensor<2xf32>
// %result: [4.0, 15.0]

uniform_quantize

Semantik

Melakukan konversi element-wise dari tensor floating point atau tensor terkuantisasi operand ke tensor terkuantisasi result sesuai dengan parameter kuantisasi yang ditentukan oleh jenis result.

Secara lebih formal,

  • Jika is_float(operand):
    • result = quantize(operand, type(result)).
  • Jika is_quantized(operand):
    • float_result = dequantize(operand).
    • result = quantize(float_result, type(result)).

Input

Label Nama Jenis Batasan
(I1) operand tensor dari jenis floating point atau terkuantisasi (C1), (C2)

Output

Nama Jenis Batasan
result tensor terkuantisasi (C1), C2

Batasan

  • (C1) shape(operand) = shape(result).
  • (C2) expressed_type(result) = is_float(operand) ? element_type(operand) : expressed_type(operand).

Contoh

// %operand: [4.0, 15.0]
%result = "stablehlo.uniform_quantize"(%operand) : (tensor<2xf32>) -> tensor<2x!quant.uniform<i8:f32:0, {0.1:-30,0.5:-20}>>
// %result: [10, 10]

// %operand: [10, 10]
%result = "stablehlo.uniform_quantize"(%operand) : (tensor<2x!quant.uniform<i8:f32:0, {0.1:-30,0.5:-20}>>) -> tensor<2x!quant.uniform<i8:f32:0, {0.1:-20,0.2:-30}>>
// %result: [20, 45]

while

Semantik

Menghasilkan output dari mengeksekusi fungsi body 0 kali atau lebih saat fungsi cond menghasilkan output true. Secara lebih formal, semantik dapat dinyatakan menggunakan sintaksis Python sebagai berikut:

internal_state = operand
while cond(*internal_state):
  internal_state = body(*internal_state)
results = internal_state

Perilaku loop tanpa batas adalah TBD (#383).

Input

Label Nama Jenis Batasan
(I1) operand jumlah variabel tensor, tensor terkuantisasi, atau token (C1-C3)
(I2) cond fungsi (C1)
(I3) body fungsi (C2)

Output

Nama Jenis Batasan
results jumlah variabel tensor, tensor terkuantisasi, atau token (C3)

Batasan

  • (C1) cond memiliki jenis (T0, ..., TN-1) -> tensor<i1>, dengan Ti = type(operand[i]).
  • (C2) body memiliki jenis (T0, ..., TN-1) -> (T0, ..., TN-1), dengan Ti = type(operand[i]).
  • (C3) type(results...) = type(operand...).

Contoh

// %init_i: 1
// %init_sum: 0
// %one: 1
// %ten: 10
%results0, %results1 = "stablehlo.while"(%init_i, %init_sum) ({
  ^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
    %cond = "stablehlo.compare"(%arg0, %ten) {
      comparison_direction = #stablehlo<comparison_direction LT>
    } : (tensor<i64>, tensor<i64>) -> tensor<i1>
    stablehlo.return %cond : tensor<i1>
  }, {
  ^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
    %new_sum = stablehlo.add %arg1, %one : tensor<i64>
    %new_i = stablehlo.add %arg0, %one : tensor<i64>
    stablehlo.return %new_i, %new_sum : tensor<i64>, tensor<i64>
}) : (tensor<i64>, tensor<i64>) -> (tensor<i64>, tensor<i64>)
// %results0: 10
// %results1: 10

Contoh Lainnya

xor

Semantik

Melakukan XOR element-wise dari dua tensor lhs dan rhs serta menghasilkan tensor result. Bergantung pada jenis elemen, lakukan tindakan berikut:

  • Untuk boolean: XOR logis.
  • Untuk bilangan bulat: bitwise XOR.

Input

Label Nama Jenis Batasan
(I1) lhs tensor dari jenis boolean atau bilangan bulat (C1)
(I2) rhs tensor dari jenis boolean atau bilangan bulat (C1)

Output

Nama Jenis Batasan
result tensor dari jenis boolean atau bilangan bulat (C1)

Batasan

  • (C1) type(lhs) = type(rhs) = type(result).

Contoh

// Bitwise operation with with integer tensors
// %lhs: [[1, 2], [3, 4]]
// %rhs: [[5, 6], [7, 8]]
%result = "stablehlo.xor"(%lhs, %rhs) : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
// %result: [[4, 4], [4, 12]]

// Logical operation with with boolean tensors
// %lhs: [[false, false], [true, true]]
// %rhs: [[false, true], [false, true]]
%result = "stablehlo.xor"(%lhs, %rhs) : (tensor<2x2xi1>, tensor<2x2xi1>) -> tensor<2x2xi1>
// %result: [[false, true], [true, false]]

 Contoh Lainnya

Interop Dialek

Saat ini, program StableHLO di dunia nyata terkadang berisi operasi yang tidak ditentukan oleh StableHLO.

Modul, Fungsi, Panggilan, dan Pengembalian

StableHLO menggunakan operasi MLIR upstream untuk ModuleOp, FuncOp, CallOp, dan ReturnOp. Hal ini dilakukan untuk interop yang lebih baik dengan mesin MLIR yang ada, karena banyak pass berguna yang ditulis dengan menargetkan FuncOp dan ModuleOp, dan banyak pipeline kompilasi memperkirakan operasi ini akan ada. Jaminan kompatibilitas penuh diterapkan ke operasi ini. Jika ada perubahan pada operasi ini dengan cara yang tidak kompatibel (yaitu penghapusan), padanan StableHLO akan ditambahkan untuk mempertahankan kompatibilitas.

CHLO

Opset CHLO berisi operasi tingkat lebih tinggi yang terurai menjadi StableHLO. Saat ini tidak ada jaminan kompatibilitas untuk CHLO. Untuk jaminan kompatibilitas, chlo-legalize-to-stablehlo pass harus digunakan sebelum serialisasi.

Operasi Bentuk

Ini adalah kasus penggunaan umum di komunitas untuk menggunakan operasi tertentu dari dialek MLIR inti dalam program StableHLO dinamis untuk melakukan komputasi bentuk. Biasanya, ini mencakup operasi dialek shape seperti shape_of atau num_elements, operasi dialek tensor seperti dim atau from_elements, dan jenis index bawaan.

Dynamism RFC > O2 menunjukkan bahwa hal ini berada di luar cakupan, tetapi beberapa dukungan untuk jenis index disertakan untuk tujuan interop. Tidak ada jaminan kompatibilitas untuk operasi atau jenis ini. Kartu shape-legalize-to-stablehlo dapat digunakan untuk mengonversi operasi ini menjadi operasi StableHLO yang didukung sepenuhnya.

Operasi yang Tidak Digunakan Lagi

Ada beberapa operasi StableHLO yang diwarisi dari MHLO yang tidak digunakan lagi dan akan dihapus dari StableHLO. Detail lengkap tentang penghapusan ini dapat ditemukan di Pembersihan StableHLO v1.0 #2283. Masalah pelacak untuk penghentian ini adalah #2340.

Operasi ini termasuk dalam beberapa kategori:

  • Kategori "Not in HLO" dari operasi StableHLO - awalnya merupakan bagian dari opset StableHLO, tetapi kemudian dianggap tidak cocok: broadcast, create_token, cross-replica-sum, dot, einsum, torch_index_select, unary_einsum (#3).
  • Operasi yang tidak digunakan - Operasi ini mungkin pernah berguna pada suatu saat, tetapi operasi tersebut belum dikembangkan sepenuhnya, atau pipeline yang menggunakan operasi ini telah difaktorkan ulang sehingga tidak memerlukannya lagi. Ini termasuk map, tuple (#598), perbandingan get_tuple_element, rng, complex #560, dan konvolusi window_reversal (#1181).

Beberapa operasi ini dapat dihapus dengan mudah karena dapat diekspresikan menggunakan operasi yang sudah ada (broadcast, create_token, cross-replica-sum, dot, unary_einsum) dan akan dihapus setelah periode kompatibilitas yang ada berlalu (6 bulan). Ops lainnya masih dipelajari untuk dihapus (perbandingan einsum, get_tuple_element, map, rng torch_index_select, tuple, complex, window_reversal). Menunggu masukan komunitas, ops ini akan dihapus, atau ditambahkan ke spesifikasi dengan dukungan penuh. Hingga ops futures ini diketahui, ops futures tersebut hanya dijamin kompatibilitasnya selama 6 bulan.

Eksekusi

Eksekusi berurutan

Program StableHLO dijalankan dengan memberikan nilai input ke fungsi main dan menghitung nilai output. Nilai output fungsi dihitung dengan menjalankan grafik operasi yang berakar pada operasi return yang sesuai.

Urutan eksekusi ditentukan oleh implementasi selama selaras dengan alur data, yaitu jika operasi dieksekusi sebelum penggunaannya. Di StableHLO, semua operasi yang menghasilkan efek samping menggunakan satu token dan menghasilkan satu token (beberapa token dapat dimultipleks menjadi satu token melalui after_all), sehingga urutan eksekusi efek samping juga selaras dengan alur data. Misalnya, dalam program di bawah ada dua kemungkinan urutan eksekusi: %0%1%2return dan %1%0%2return.

func.func @main() -> tensor<f64> {
  %0 = stablehlo.constant dense<1.0> : tensor<f64>
  %1 = stablehlo.constant dense<2.0> : tensor<f64>
  %2 = stablehlo.add %0, %1 : tensor<f64>
  return %2 : tensor<f64>
}

Secara lebih formal, proses StableHLO adalah kombinasi dari: 1) program StableHLO, 2) status operasi (belum dieksekusi, sudah dieksekusi), dan 3) nilai perantara yang sedang diproses. Proses dimulai dengan nilai input untuk fungsi main, dilanjutkan melalui grafik operasi yang memperbarui status operasi dan nilai perantara, dan diakhiri dengan nilai output. Formasi lebih lanjut akan ditentukan (#484).

Eksekusi paralel

Program StableHLO dapat dijalankan secara paralel, yang diatur ke dalam petak proses 2D num_replicas oleh num_partitions yang keduanya memiliki jenis ui32.

Dalam petak proses StableHLO, num_replicas * num_partitions proses StableHLO dieksekusi secara bersamaan. Setiap proses memiliki process_id = (replica_id, partition_id) unik, dengan replica_id di replica_ids = range(num_replicas) dan partition_id di partition_ids = range(num_partitions) yang keduanya memiliki jenis ui32.

Ukuran petak proses diketahui secara statis untuk setiap program (di masa mendatang, kami berencana menjadikannya bagian eksplisit dari program StableHLO #650), dan posisi dalam petak proses diketahui secara statis untuk setiap proses. Setiap proses memiliki akses ke posisinya dalam petak proses melalui operasi replica_id dan partition_id.

Dalam petak proses, semua program dapat sama (dalam gaya "Program tunggal, Beberapa Data"), semua dapat berbeda (dalam gaya "Beberapa Program, Beberapa Data"), atau sesuatu di antaranya. Di masa mendatang, kami berencana memperkenalkan dukungan untuk idiom lain guna menentukan program StableHLO paralel, termasuk GSPMD (#619).

Dalam petak proses, proses sebagian besar independen satu sama lain - proses tersebut memiliki status operasi terpisah, nilai input/antara/output terpisah dan sebagian besar operasi dijalankan secara terpisah di antara proses, dengan pengecualian sejumlah kecil operasi kolektif yang dijelaskan di bawah.

Mengingat bahwa eksekusi sebagian besar operasi hanya menggunakan nilai dari proses yang sama, biasanya menyebutkan nilai ini berdasarkan namanya menjadi tidak ambigu. Namun, saat mendeskripsikan semantik operasi kolektif, hal itu tidak memadai, dan yang menghasilkan notasi name@process_id untuk merujuk ke nilai name dalam proses tertentu. (Dari perspektif tersebut, name yang tidak memenuhi syarat dapat dilihat sebagai singkatan untuk name@(replica_id(), partition_id())).

Urutan eksekusi di seluruh proses ditentukan oleh implementasi, kecuali untuk sinkronisasi yang diperkenalkan oleh komunikasi titik ke titik dan operasi kolektif seperti yang dijelaskan di bawah ini.

Komunikasi titik ke titik

Proses StableHLO dapat berkomunikasi satu sama lain melalui saluran StableHLO. Channel diwakili oleh ID positif jenis si64. Melalui berbagai operasi, Anda dapat mengirim nilai ke saluran dan menerimanya dari saluran.

Formalisasi lebih lanjut, misalnya asal ID saluran ini, cara program memprosesnya, dan jenis sinkronisasi yang diperkenalkan olehnya, masih akan ditentukan (#484).

Komunikasi streaming

Setiap proses StableHLO memiliki akses ke dua antarmuka streaming:

  • Infeed yang dapat dibaca.
  • Outfeed yang dapat ditulis.

Tidak seperti saluran, yang digunakan untuk berkomunikasi antar-proses sehingga memiliki proses di kedua ujungnya, feed infeed dan outfeed memiliki penerapan akhir yang ditentukan.

Formalisasi lebih lanjut, misalnya, bagaimana komunikasi streaming memengaruhi urutan eksekusi dan jenis sinkronisasi yang diperkenalkan olehnya, adalah TBD (#484).

Operasi kolektif

Ada enam operasi kolektif di StableHLO: all_gather, all_reduce, all_to_all, collective_broadcast, collective_permute, dan reduce_scatter. Semua operasi ini membagi proses dalam petak proses StableHLO menjadi grup proses StableHLO dan mengeksekusi komputasi bersama dalam setiap grup proses, secara independen dari grup proses lainnya.

Dalam setiap grup proses, operasi kolektif dapat menyebabkan penghalang sinkronisasi. Formalisasi lebih lanjut, misalnya menguraikan kapan tepatnya sinkronisasi ini terjadi, bagaimana prosesnya sampai pada hambatan ini, dan apa yang terjadi jika tidak, akan ditentukan (#484).

Jika grup proses melibatkan komunikasi lintas partisi, yaitu ada proses dalam grup proses yang ID partisinya berbeda, maka eksekusi operasi kolektif memerlukan saluran, dan operasi kolektif harus menyediakan channel_id positif dari jenis si64. Komunikasi lintas replika tidak membutuhkan saluran.

Komputasi yang dilakukan oleh operasi kolektif bersifat khusus untuk setiap operasi dan dijelaskan di setiap bagian operasi di atas. Namun, strategi yang digunakan untuk membagi petak proses menjadi grup proses dibagikan di antara operasi ini dan dijelaskan di bagian ini. Secara lebih formal, StableHLO mendukung empat strategi berikut.

cross_replica

Hanya komunikasi lintas replika yang terjadi dalam setiap grup proses. Strategi ini menggunakan replica_groups - daftar daftar ID replika - dan menghitung produk Kartesius replica_groups dengan partition_ids. replica_groups harus memiliki elemen unik dan mencakup semua replica_ids. Secara lebih formal, menggunakan sintaksis Python:

def cross_replica(replica_groups: List[List[ReplicaId]]) -> List[List[ProcessId]]:
  for replica_group in replica_groups:
    for partition_id in partition_ids:
      process_group = []
      for replica_id in replica_group:
        process_group.append((replica_id, partition_id))
      yield process_group

Misalnya, untuk replica_groups = [[0, 1], [2, 3]] dan num_partitions = 2, cross_replica akan menghasilkan [[(0, 0), (1, 0)], [(0, 1), (1, 1)], [(2, 0), (3, 0)], [(2, 1), (3, 1)]].

cross_partition

Hanya komunikasi lintas partisi yang terjadi dalam setiap grup proses. Strategi ini menggunakan partition_groups - daftar daftar ID partisi - dan menghitung produk Kartesius partition_groups dengan replica_ids. partition_groups harus memiliki elemen unik dan mencakup semua partition_ids. Secara lebih formal, menggunakan sintaksis Python:

def cross_partition(partition_groups: List[List[PartitionId]]) -> List[List[ProcessId]]:
  for partition_group in partition_groups:
    for replica_id in replica_ids:
      process_group = []
      for partition_id in partition_group:
        process_group.append((replica_id, partition_id))
      yield process_group

Misalnya, untuk partition_groups = [[0, 1]] dan num_replicas = 4, cross_partition akan menghasilkan [[(0, 0), (0, 1)], [(1, 0), (1, 1)], [(2, 0), (2, 1)], [(3, 0), (3, 1)]].

cross_replica_and_partition

Komunikasi lintas-replikasi dan lintas-partisi dapat terjadi dalam setiap grup proses. Strategi ini menggunakan replica_groups - daftar daftar ID replika - dan menghitung produk Kartesius dari setiap replica_group dengan partition_ids. replica_groups harus memiliki elemen unik dan mencakup semua replica_ids. Secara lebih formal, menggunakan sintaksis Python:

def cross_replica_and_partition(replica_groups: List[List[ReplicaId]]) -> List[List[ProcessId]]:
  for replica_group in replica_groups:
    process_group = []
    for partition_id in partition_ids:
      for replica_id in replica_group:
        process_group.append((replica_id, partition_id))
    yield process_group

Misalnya, untuk replica_groups = [[0, 1], [2, 3]] dan num_partitions = 2, cross_replica_and_partition akan menghasilkan [[(0, 0), (1, 0), (0, 1), (1, 1)], [(2, 0), (3, 0), (2, 1), (3, 1)]].

flattened_ids

Strategi ini menggunakan flattened_id_groups - daftar ID proses yang "disatukan" dalam bentuk replica_id * num_partitions + partition_id - dan mengubahnya menjadi ID proses. flattened_id_groups harus memiliki elemen unik dan mencakup semua process_ids. Secara lebih formal, menggunakan sintaksis Python:

def flattened_ids(flattened_id_groups: List[List[ui32]]) -> List[List[ProcessId]]:
  for flattened_id_group in flattened_id_groups:
    process_group = []
    for flattened_id in flattened_id_group:
      replica_id = flattened_id // num_partitions
      partition_id = flattened_id % num_partitions
      process_group.append((replica_id, partition_id))
    yield process_group

Misalnya, untuk flattened_id_groups = [[0, 1, 2, 3], [4, 5, 6, 7]], num_replicas = 4, dan num_partitions = 2, flattened_ids akan menghasilkan [[(0, 0), (0, 1), (1, 0), (1, 1)], [(2, 0), (2, 1), (3, 0), (3, 1)]].

Akurasi

Saat ini, StableHLO tidak memberikan jaminan tentang akurasi numerik, tetapi hal ini dapat berubah pada masa mendatang (#1156).

Semantik eksekusi operasi kuantisasi

Penafsiran operasi StableHLO yang dikuantisasi dapat bervariasi, bergantung pada persyaratan dan kemampuan hardware. Misalnya, beberapa hardware dapat memilih untuk menafsirkan operasi kuantisasi menggunakan strategi "dequantize, perform floating-point operation, and finally quantize". Yang lain dapat melakukan seluruh komputasi dengan aritmetika bilangan bulat. Oleh karena itu, interpretasi operasi StableHLO yang dikuantisasi ditentukan secara eksklusif oleh penerapan tertentu. Penafsiran kuantisasi campuran (#1575) harus didasarkan pada semantiknya seperti yang ditentukan dalam spesifikasi (melalui 1792).

Error

Program StableHLO divalidasi melalui serangkaian batasan yang ekstensif untuk operasi individual, yang mengesampingkan banyak class error sebelum runtime. Namun, kondisi error masih mungkin terjadi, misalnya melalui overflow bilangan bulat, akses keluar batas, dll. Kecuali jika dinyatakan secara eksplisit, semua error ini akan menghasilkan perilaku yang ditentukan implementasi, tetapi hal ini dapat berubah di masa mendatang (#1157).

Pengecualian floating point

Sebagai pengecualian untuk aturan ini, pengecualian floating point dalam program StableHLO memiliki perilaku yang ditentukan dengan baik. Operasi yang menghasilkan pengecualian yang ditentukan oleh standar IEEE-754 (operasi tidak valid, pembagian dengan nol, overflow, underflow, atau pengecualian yang tidak tepat) menghasilkan hasil default (seperti yang ditentukan dalam standar) dan melanjutkan eksekusi tanpa menaikkan flag status yang sesuai; mirip dengan penanganan pengecualian raiseNoFlag dari standar. Pengecualian untuk operasi nonstandar (misalnya, aritmetika kompleks dan fungsi transendental tertentu) ditentukan oleh implementasi.

Ketidakcocokan bentuk

StabilHLO mendukung tensor berbentuk dinamis. Namun, bentuk harus sesuai pada runtime, jika tidak, perilakunya tidak ditentukan. StabilHLO tidak secara eksplisit menyediakan operasi yang dapat menyatakan bahwa tensor memiliki bentuk tertentu saat runtime. Produsen bertanggung jawab untuk membuat kode yang benar.

Sebagai contoh spesifik, program di bawah valid. Namun, saat runtime, bentuk %arg0 dan %arg1 harus sama. Jika tidak, perilaku program tidak ditentukan:

func.func @foo(%arg0: tensor<?xi32>, %arg1: tensor<?xi32>) -> tensor<?xi32> {
    %0 = stablehlo.add %arg0, %arg1 : tensor<?xi32>
    return %0 : tensor<?xi32>
}

Notasi

Untuk menjelaskan sintaksis, dokumen ini menggunakan ragam ISO yang dimodifikasi dari sintaksis EBNF (ISO/IEC 14977:1996, Wikipedia), dengan dua modifikasi: 1) aturan ditentukan menggunakan ::=, bukan =,

2) penyambungan dinyatakan menggunakan juxtaposition, bukan ,.

Untuk mendeskripsikan semantik (yaitu dalam bagian "Jenis", "Konstanta", dan "Ops"), kami menggunakan formula yang didasarkan pada sintaksis Python yang diperluas dengan dukungan untuk mengekspresikan operasi array secara ringkas seperti yang dijelaskan di bawah. Hal ini berfungsi dengan baik untuk cuplikan kode kecil, tetapi dalam kasus yang jarang terjadi saat cuplikan kode yang lebih besar diperlukan, kita menggunakan sintaksis Python vanilla yang selalu diperkenalkan secara eksplisit.

Formula

Mari kita pelajari cara kerja formula berdasarkan contoh dari spesifikasi dot_general. Salah satu batasan untuk operasi ini terlihat seperti berikut: dim(lhs, lhs_batching_dimensions...) = dim(rhs, rhs_batching_dimensions...).

Nama yang digunakan dalam formula ini berasal dari dua sumber: 1) fungsi global, yaitu dim, 2) definisi anggota elemen program yang sesuai, yaitu input lhs, lhs_batching_dimensions, rhs, dan rhs_batching_dimensions yang ditentukan di bagian "Input" di dot_general.

Seperti yang disebutkan di atas, sintaksis formula ini berbasis Python dengan beberapa ekstensi yang berorientasi pada ringkasan. Untuk memahami formula tersebut, mari kita ubah menjadi sintaks vanilla Python.

A) Dalam formula ini, kita menggunakan = untuk merepresentasikan kesetaraan, sehingga langkah pertama untuk mendapatkan sintaksis Python adalah mengganti = dengan ==, sebagai berikut: dim(lhs, lhs_batching_dimensions...) == dim(rhs, rhs_batching_dimensions...).

B) Selain itu, formula ini mendukung elipsis (...) yang mengubah ekspresi skalar menjadi ekspresi tensor. Singkatnya, f(xs...) secara kasar berarti "untuk setiap x skalar dalam tensor xs, komputasikan f(x) skalar, lalu tampilkan semua hasil skalar ini bersama-sama sebagai hasil tensor". Dalam sintaksis Python vanilla, formula contoh kita berubah menjadi: [dim(lhs, dim1) for dim1 in lhs_batching_dimensions] == [dim(rhs, dim2) for dim2 in rhs_batching_dimensions].

Berkat elipsis, Anda sering kali dapat menghindari bekerja pada tingkat skalar individual. Namun, dalam beberapa kasus rumit, sintaksis semi-informal tingkat rendah dapat digunakan seperti dalam formula start_indices[bi0, ..., :, ..., biN] dari spesifikasi gather. Untuk mempersingkat, kami tidak memberikan formalisme yang tepat untuk menerjemahkan sintaksis tersebut ke Python vanilla, dengan harapan bahwa sintaksis tersebut masih dapat dipahami secara intuitif berdasarkan kasus per kasus. Beri tahu kami jika beberapa formula tertentu terlihat buram, dan kami akan mencoba meningkatkannya.

Selain itu, Anda akan melihat bahwa formula menggunakan elipsis untuk memperluas semua jenis daftar, termasuk tensor, daftar tensor (yang misalnya dapat muncul dari jumlah variabel tensor), dll. Ini adalah area lain tempat kita tidak memberikan formalisme yang tepat (misalnya, daftar bahkan bukan bagian dari sistem jenis StableHLO) dan sebagai gantinya mengandalkan pemahaman intuitif.

C) Kendaraan notasi penting terakhir yang kami gunakan adalah penyiaran implisit. Meskipun opset StableHLO tidak mendukung siaran implisit, formula mendukungnya, juga untuk layanan ringkas. Singkatnya, jika skalar digunakan dalam konteks yang mengharapkan tensor, skalar akan disiarkan ke bentuk yang diharapkan.

Untuk melanjutkan contoh dot_general, berikut batasan lainnya: 0 <= lhs_batching_dimensions < rank(lhs). Seperti yang ditentukan dalam spesifikasi dot_general, lhs_batching_dimensions adalah tensor, tetapi 0 dan rank(lhs) adalah skalar. Setelah kita menerapkan siaran implisit, formulanya akan menjadi [0, ..., 0] <= lhs_batching_dimensions < [rank(lhs), ..., rank(lhs)].

Saat diterapkan ke operasi dot_general tertentu, formula ini akan dievaluasi menjadi tensor boolean. Jika formula digunakan sebagai batasan, batasan akan berlaku jika formula dievaluasi menjadi true atau tensor yang hanya memiliki elemen true.

Nama

Dalam formula, cakupan leksikonis mencakup: 1) fungsi global, 2) definisi anggota,

3) definisi lokal. Daftar fungsi global diberikan di bawah ini. Daftar definisi elemen bergantung pada elemen program tempat notasi diterapkan:

  • Untuk operasi, definisi anggota menyertakan nama yang diperkenalkan di bagian "Input" dan "Output".
  • Untuk hal lainnya, definisi anggota mencakup bagian struktural dari elemen program, yang diberi nama sesuai dengan non-terminal EBNF yang sesuai. Sebagian besar waktu, nama bagian struktural ini diperoleh dengan mengonversi nama non-terminal ke snake case (misalnya, IntegerLiteral => integer_literal), tetapi terkadang nama disingkat dalam proses (misalnya, QuantizationStorageType => storage_type) dalam hal ini nama diperkenalkan secara eksplisit mirip dengan bagian "Input"/"Output" dalam spesifikasi operasi.
  • Selain itu, definisi anggota selalu menyertakan self untuk merujuk ke elemen program yang sesuai.

Nilai

Saat dievaluasi, formula akan berfungsi dengan jenis nilai berikut: 1) Value (nilai sebenarnya, misalnya dense<[[1, 2], [3, 4]]> : tensor<2x2xi32>; nilainya selalu diketahui), 2) Placeholder (nilai mendatang, misalnya lhs, rhs, atau result; nilai sebenarnya belum diketahui, hanya jenisnya yang diketahui), 3) Type (jenis seperti yang ditentukan di bagian "Jenis"), 4) Function (fungsi global seperti yang ditentukan di bagian "Fungsi").

Bergantung pada konteksnya, nama mungkin merujuk ke nilai yang berbeda. Lebih khususnya, bagian "Semantik" untuk operasi (dan yang setara untuk elemen program lainnya) menentukan logika runtime, sehingga semua input tersedia sebagai Value. Sebaliknya, bagian "Batasan" untuk operasi (dan yang setara) menentukan logika "waktu kompilasi", yaitu sesuatu yang biasanya dieksekusi sebelum runtime, sehingga hanya input konstan yang tersedia sebagai Value dan input lainnya hanya tersedia sebagai Placeholder.

Nama Di "Semantik" Di "Batasan"
Fungsi global Function Function
Input konstan Value Value
Input non-konstanta Value Placeholder
Output Value Placeholder
Definisi lokal Bergantung pada definisi Bergantung pada definisi

Mari kita lihat contoh operasi transpose:

%result = "stablehlo.transpose"(%operand) {
  permutation = dense<[2, 1, 0]> : tensor<3xi64>
} : (tensor<2x3x2xi32>) -> tensor<2x3x2xi32>

Untuk operasi ini, permutation adalah konstanta, sehingga tersedia sebagai Value dalam semantik dan batasan. Sebaliknya, operand dan result tersedia sebagai Value dalam semantik, tetapi hanya sebagai Placeholder dalam batasan.

Fungsi

Konstruksi jenis

Tidak ada fungsi yang dapat digunakan untuk membuat jenis. Sebagai gantinya, kita langsung menggunakan sintaksis jenis karena biasanya lebih ringkas. Misalnya, (tensor<E>, tensor<E>) -> (tensor<E>), bukan function_type( [tensor_type([], E), tensor_type([], E)], [tensor_type([], E)]).

Fungsi pada jenis

  • element_type ditentukan pada jenis tensor dan jenis tensor terkuantisasi, dan masing-masing menampilkan bagian TensorElementType atau QuantizedTensorElementType dari TensorType atau QuantizedTensorType yang sesuai.
def element_type(x: Value | Placeholder | Type):
 if type(x) == TensorType:
    return tensor_element_type(x)
  if type(x) == QuantizedTensorType:
    return quantized_tensor_element_type(x)
  if type(x) is not Type:
    return element_type(type(x))
  • is_per_axis_quantized(x: Value | Placeholder | Type) -> Value adalah pintasan untuk is_quantized(x) and quantization_dimension(x) is not None.

  • is_per_tensor_quantized(x: Value | Placeholder | Type) -> Value adalah pintasan is_quantized(x) and quantization_dimension(x) is None.

  • is_promotable(x: Type, y: Type) -> bool memeriksa apakah jenis x dapat dipromosikan ke jenis y. Jika x dan y adalah QuantizedTensorElementType, promosi hanya diterapkan ke storage_type. Versi promosi khusus ini saat ini digunakan dalam konteks komputasi pengurangan (lihat RFC untuk mengetahui detail selengkapnya).

def is_promotable(x: Type, y: Type) -> Value:
  is_same_type = (is_bool(x) and is_bool(y)) or
    (is_integer(x) and is_integer(y)) or (is_float(x) and is_float(y)) or
    (is_complex(x) and is_complex(y)) or
    (is_quantized(x) and is_quantized(y) and expressed_type(x) = expressed_type(y))

  if is_same_type == False:
    return False

  if is_integer(x) or is_float(x):
    return bitwidth(x) <= bitwidth(y)

  if is_complex(x):
    return bitwidth(element_type(x)) <= bitwidth(element_type(y))

  if is_quantized(x):
    return bitwidth(storage_type(x)) <= bitwidth(storage_type(y))

  return false
  • is_quantized(x: Value | Placeholder | Type) -> Value adalah pintasan untuk is_quantized_tensor_element_type(x).

  • is_type_name(x: Value | Placeholder | Type) -> Value. Tersedia untuk semua jenis. Misalnya, is_float(x) menampilkan true jika x adalah FloatType. Jika x adalah nilai atau placeholder, fungsi ini adalah pintasan untuk is_type_name(type(x)).

  • max_value(x: Type) -> Value menampilkan nilai maksimum TensorElementType. Jika x bukan TensorElementType, None akan ditampilkan.

  • min_value(x: Type) -> Value menampilkan nilai minimum yang memungkinkan dari TensorElementType. Jika x bukan TensorElementType, None akan ditampilkan.

  • member_name(x: Value | Placeholder | Type) -> Any. Tersedia untuk semua definisi anggota member_name dari semua jenis. Misalnya, tensor_element_type(x) menampilkan bagian TensorElementType dari TensorType yang sesuai. Jika x adalah nilai atau placeholder, fungsi ini adalah pintasan untuk member_name(type(x)). Jika x bukan jenis yang memiliki anggota yang sesuai, atau nilai atau placeholder dari jenis tersebut, tampilkan None.

  • is_empty_algorithm(*args: Type) memeriksa apakah semua kolom algoritma titik ditetapkan ke None. Hal ini diperlukan karena algoritma titik memiliki perilaku default yang ditentukan implementasi, sehingga penentuan nilai default akan salah.

Konstruksi nilai

  • operation_name(*xs: Value | Type) -> Value. Tersedia untuk semua operasi. Misalnya, add(lhs, rhs) mengambil dua nilai tensor lhs dan rhs dan menampilkan output evaluasi operasi add dengan input ini. Untuk beberapa operasi, misalnya broadcast_in_dim, jenis outputnya adalah "bearing beban", yaitu yang diperlukan untuk mengevaluasi operasi. Dalam hal ini, fungsi mengambil jenis ini sebagai argumen.

Fungsi pada nilai

  • Semua operator dan fungsi Python tersedia. Misalnya, notasi subscription dan slicing dari Python tersedia untuk mengindeks ke dalam tensor, tensor kuantisasi, dan tuple.

  • to_destination_type(x: Value, destination_type: Type) -> Value ditentukan pada tensor dan menampilkan nilai x yang dikonversi berdasarkan type(x) dan destination_type sebagai berikut:

def to_destination_type(x: Value, destination_type: Type) -> Value:
  if type(x) == destination_type:
    return x

  if is_quantized(destination_type):
    if is_quantized(type(x)):
      return quantize(x, destination_type)
    assert is_float(type(x))
    return quantize(x, destination_type)

  if is_quantized(type(x)):
    assert destination_type = expressed_type(type(x))
    return dequantize(type(x))

  return convert(x, destination_type)

Ada diskusi awal tentang penggabungan operasi convert, uniform_quantize, dan uniform_dequantize (#1576). Setelah penggabungan, kita tidak memerlukan fungsi di atas dan dapat menggunakan nama operasi untuk convert.

  • is_nan(x: Value) -> Value ditentukan pada tensor dan menampilkan true jika semua elemen x adalah NaN atau false jika tidak. Jika x bukan tensor, akan menampilkan None.

  • is_sorted(x: Value) -> Value ditentukan pada tensor dan menampilkan true jika elemen x diurutkan dalam urutan menaik sehubungan dengan urutan leksikografis menaik dari indeksnya, atau false jika tidak. Jika x bukan tensor, None akan ditampilkan.

  • is_unique(x: Value) -> Value ditentukan pada tensor dan menampilkan true jika x tidak memiliki elemen duplikat atau false jika tidak. Jika x bukan tensor, None akan ditampilkan.

  • member_name(x: Value) -> Any ditentukan untuk semua definisi anggota member_name dari semua nilai. Misalnya, real_part(x) menampilkan bagian RealPart dari ComplexConstant yang sesuai. Jika x bukan nilai yang memiliki anggota yang sesuai, None akan ditampilkan.

  • same(x: Value) -> Value ditentukan pada tensor dan menampilkan true jika elemen x semuanya sama satu sama lain atau false jika tidak. Jika tensor tidak memiliki elemen, hal itu dianggap sebagai "semua sama satu sama lain", yaitu fungsi menampilkan true. Jika x bukan tensor, None akan ditampilkan.

  • split(x: Value, num_results: Value, axis: Value) -> Value ditentukan pada tensor dan menampilkan slice num_results dari x di sepanjang sumbu axis. Jika x bukan tensor atau dim(x, axis) % num_results != 0, None akan ditampilkan.

  • is_defined_in_parent_scope(x: Value) -> Value ditentukan pada string dan menampilkan true jika x adalah nama fungsi yang ditentukan dalam cakupan yang sama dengan fungsi induk dari operasi yang relevan.

  • is_namespaced_op_name(x: Value) -> Value ditentukan pada string dan menampilkan true jika x adalah nama operasi yang valid, yaitu mengikuti ekspresi reguler berikut: [a-zA-Z][a-zA-Z0-9_]*([.][a-zA-Z0-9_$]+)+

Komputasi bentuk

  • axes(x: Value | Placeholder | Type) -> Value adalah pintasan untuk range(rank(x)).

  • dim(x: Value | Placeholder | Type, axis: Value) -> Value adalah pintasan untuk shape(x)[axis].

  • dims(x: Value | Placeholder | Type, axes: List) -> List adalah pintasan untuk list(map(lambda axis: dim(x, axis), axes)).

  • index_space(x: Value | Placeholder | Type) -> Value ditentukan pada tensor dan menampilkan indeks size(x) untuk TensorType yang sesuai yang diurutkan dalam urutan leksikografis menaik, yaitu [0, ..., 0], [0, ..., 1], ..., shape(x) - 1. Jika x bukan jenis tensor, jenis tensor kuantisasi, atau nilai atau placeholder dari salah satu jenis ini, None akan ditampilkan.

  • rank(x: Value | Placeholder | Type) -> Value adalah pintasan untuk size(shape(x)).

  • shape(x: Value | Placeholder | Type) -> Value ditentukan di bagian "Fungsi pada jenis" melalui member_name.

  • size(x: Value | Placeholder | Type) -> Value adalah pintasan untuk reduce(lambda x, y: x * y, shape(x)).

Komputasi kuantisasi

  • def baseline_element_type(x: Value | Placeholder | Type) -> Type adalah pintasan element_type(baseline_type(x)).

  • baseline_type ditentukan pada jenis tensor dan jenis tensor terkuantisasi, serta mengubahnya menjadi "dasar pengukuran", yaitu jenis dengan bentuk yang sama tetapi dengan parameter kuantisasi jenis elemen yang direset ke nilai default. Hal ini digunakan sebagai trik praktis untuk membandingkan jenis tensor dan tensor terkuantisasi secara seragam, yang cukup sering diperlukan. Untuk jenis kuantisasi, hal ini memungkinkan perbandingan jenis yang mengabaikan parameter kuantisasi, yaitu, shape, storage_type, expressed_type, storage_min, storage_max, dan quantization_dimension (untuk jenis kuantisasi per sumbu) harus cocok, tetapi scales dan zero points dapat berbeda.

def baseline_type(x: Value | Placeholder | Type) -> Type:
  if type(x) == TensorType:
    return x
  if type(x) == QuantizedTensorType:
    element_type = quantized_tensor_element_type(x)
    baseline_element_type = QuantizedTensorElementType(
      storage_type = storage_type(element_type),
      storage_min = storage_min(element_type),
      storage_max = storage_max(element_type),
      expressed_type = expressed_type(element_type),
      quantization_dimension = quantization_dimension(element_type),
      scales = [constant(1.0, expressed_type(element_type))] * dim(x, quantization_dimension(element_type)),
      zero_points = [constant(0, storage_type(element_type))] * dim(x, quantization_dimension(element_type)))
    return QuantizedTensorType(shape(x), baseline_element_type)
  if type(x) is not Type:
    return baseline_element_type(type(x))
  • dequantize ditentukan pada jenis tensor terkuantisasi dan mengubahnya menjadi jenis tensor floating point. Hal ini terjadi melalui konversi elemen kuantisasi yang mewakili nilai bilangan bulat dari jenis penyimpanan menjadi nilai floating point yang sesuai dari jenis yang dinyatakan menggunakan titik nol dan skala yang terkait dengan jenis elemen kuantisasi.
def compute_zero_points(quantized_type, result_type):
  if is_per_tensor_quantized(quantized_type):
    return broadcast_in_dim(constant(zero_point(quantized_type), storage_type(quantized_type)), [], result_type)
  if is_per_axis_quantized(quantized_type):
    for i in index_space(result_type):
      d = quantization_dimension(quantized_type)
      zero_points[i] = zero_points(quantized_type)[i[d]]
    return zero_points

def compute_scales(quantized_type, result_type):
  if is_per_tensor_quantized(quantized_type):
    return broadcast_in_dim(constant(scale(quantized_type), expressed_type(quantized_type)), [],
            type(result_type))
  if is_per_axis_quantized(quantized_type):
    for i in index_space(result_type):
      d = quantization_dimension(quantized_type)
      scales[i] = scales(quantized_type)[i[d]]
    return scales

def dequantize(x: Value) -> Value:
  assert is_quantized(x)
  x_storage = bitcast_convert(x, storage_type(x))
  x_storage_sub = x_storage - compute_zero_points(type(x), type(x_storage))
  x_expressed_sub = convert(x_storage_sub, expressed_type(x))
  return x_expressed_sub * compute_scales(type(x), type(x_expressed_sub))
  • quantize ditentukan pada jenis tensor floating point dan mengubahnya menjadi jenis tensor terkuantisasi. Hal ini terjadi melalui konversi nilai floating point dari jenis yang dinyatakan menjadi nilai bilangan bulat yang sesuai dari jenis penyimpanan menggunakan titik nol dan skala yang terkait dengan jenis elemen yang dikuantisasi.
def quantize(x: Value, result_type: Type) -> Value:
  assert is_float(x) and is_quantized(result_type)
  zero_points = compute_zero_points(result_type, TensorType(shape(x), storage_type(result_type)))
  converted_zero_points = convert(zero_points, expressed_type(result_type))
  converted_min = convert(storage_min(result_type), expressed_type(result_type))
  converted_max = convert(storage_max(result_type), expressed_type(result_type))

  x_scaled = x / compute_scales(result_type, type(x))
  x_scaled_add_zp = x_scaled + converted_zero_points
  x_clamped = clamp(converted_min, x_scaled_add_zp, converted_max)
  x_rounded = round_nearest_even(x_clamped)
  return convert(x_rounded, result_type)
  • dequantize_op_quantize digunakan untuk menentukan komputasi berbasis elemen pada tensor terkuantisasi. Fungsi ini melakukan dekuantisasi, yaitu mengubah elemen kuantisasi menjadi jenis yang dinyatakan, lalu melakukan operasi, lalu melakukan kuantisasi, yaitu mengubah hasil kembali menjadi jenis penyimpanannya. Saat ini, fungsi ini hanya berfungsi untuk kuantisasi per tensor. Kuantifikasi per sumbu sedang dalam proses (#1574).
def dequantize_op_quantize(op, *inputs_and_output_type):
  inputs = inputs_and_output_type[:-1]
  output_type = inputs_and_output_type[-1]

  float_inputs = map(dequantize, inputs)
  float_result = op(*float_inputs)
  return quantize(float_result, output_type)

def dequantize_batch_norm_grad_or_training_quantize(op, *inputs_and_output_types):
  inputs = inputs_and_output_type[:-3]
  float_inputs = map(dequantize, inputs)
  float_results = op(*float_inputs)
  return map(quantize, float_results, inputs_and_output_type[-3:])

def dequantize_compare(lhs, rhs, comparison_direction):
  float_lhs = dequantize(lhs)
  float_rhs = dequantize(rhs)
  return compare(float_lhs, float_rhs, comparison_direction, FLOAT)

def dequantize_select_quantize(pred, on_true, on_false, output_type):
  float_on_true = dequantize(on_true)
  float_on_false = dequantize(on_false)
  float_result = select(pred, float_on_true, float_on_false)
  return quantize(float_result, output_type)
  • hybrid_dequantize_then_op digunakan untuk menentukan kuantisasi khusus bobot untuk operasi campuran yang menerima lhs dalam floating point dan rhs dalam jenis kuantisasi. Fungsi ini mendekuantisasikan input yang dikuantisasikan ke dalam jenis yang dinyatakan dan melakukan komputasi dalam float. Jenis elemen tensor lhs float dan jenis yang dinyatakan dari tensor rhs kuantisasi harus identik.
def hybrid_dequantize_then_op(op, lhs, rhs):
  assert(is_float(lhs) and is_quantized(rhs) and element_type(lhs) == expressed_type(rhs))
  return op(lhs, dequantize(rhs))

Komputasi petak

  • cross_partition(replica_groups: Value) -> Value. Lihat bagian "cross_replica" di atas.

  • cross_replica(replica_groups: Value) -> Value. Lihat bagian "cross_replica" di atas.

  • cross_replica_and_partition(replica_groups: Value) -> Value. Lihat bagian "cross_replica_and_partition" di atas.

  • flattened_ids(replica_groups: Value) -> Value. Lihat bagian "flattened_ids" di atas.

Dinamisme

Nilai StableHLO dapat memiliki ukuran dimensi dinamis, misalnya tensor<?xi64>. Namun, nilai StableHLO tidak boleh memiliki jumlah dimensi dinamis (dinamisme tanpa peringkat, misalnya tensor<*xi64>). Operand dan hasil diizinkan untuk menggunakan ukuran dimensi dinamis, meskipun ada batasan pada ukuran. Batasan akan diverifikasi secara statis jika memungkinkan, jika tidak, batasan akan ditangguhkan ke runtime dan ketidakcocokan akan menyebabkan perilaku yang tidak ditentukan. Lihat contoh berikut.

Ketidakcocokan bentuk untuk operasi elementwise unary

Pertimbangkan program mainan berikut:

func.func @foo(%arg0: tensor<?xf64>) {
  %0 = stablehlo.abs %arg0 : (tensor<?xf64>) -> tensor<2xf64>
  return
}

Program semacam itu tidak biasa, karena biasanya kita mengetahui bentuk hasil, tetapi tidak mengetahui bentuk input. Meskipun demikian, ini adalah program StableHLO yang valid. Operasi abs dalam program ini tidak dapat divalidasi secara statis karena bentuk operand yang tepat tidak diketahui. Namun, bentuk pasti kompatibel, dan ini dapat diperiksa secara statis: ? dapat berubah menjadi 2 saat runtime, dan tidak akan ada masalah. Namun, ? juga dapat berupa beberapa bilangan bulat lainnya, dalam hal ini perilakunya tidak ditentukan.

Perhatikan bahwa jika ukuran dimensi bersifat dinamis dalam hasil, tidak mungkin ada perilaku yang tidak ditentukan. Memang, tidak ada ukuran "yang diharapkan", sehingga tidak akan ada ketidakcocokan.

Ketidakcocokan bentuk untuk operasi elementwise biner

Pertimbangkan program mainan berikut:

func.func @foo(%arg0: tensor<?xf64>, %arg1: tensor<?xf64>) {
  %0 = stablehlo.add %arg0, %arg0 : (tensor<?xf64>, tensor<?xf64>) -> tensor<?xf64>
  return
}

Dalam hal operasi elementwise biner, bentuk input dan hasil harus sesuai saat runtime. Pada waktu kompilasi, dimensi statis harus sama, jika tidak, dimensi tersebut hanya perlu kompatibel. Jika ada dimensi apa pun yang bersifat dinamis dalam input, mungkin ada perilaku yang tidak ditentukan saat runtime, karena ukuran dinamis mungkin tidak cocok dengan ukuran yang sesuai dalam operand lain (baik statis maupun dinamis). Jika semua input bersifat statis, hasil yang dinamis atau tidak tidak akan menjadi masalah: dimensi yang diketahui secara statis akan diperiksa secara statis, dan dimensi dinamis tidak akan memberlakukan batasan apa pun.

Ketidakcocokan bentuk untuk operasi yang mengambil bentuk output-nya sebagai operand

Pertimbangkan program mainan berikut:

func.func @foo(%arg0: tensor<2xi32>) {
  %0 = stablehlo.dynamic_iota %arg0, dim = 0 : (tensor<2xi32>) -> tensor<3x4xi64>
  return
}

Nilai dalam operand bentuk saat runtime harus cocok dengan bentuk hasilnya. Jika tidak, perilakunya tidak ditentukan. Artinya, saat runtime, %arg0 harus memiliki nilai dense<[3, 4]> : tensor<2xi32>. Jika operand bentuk konstan, hal ini dapat diverifikasi secara statis. Jika bentuk hasilnya sepenuhnya dinamis, tidak akan ada ketidakcocokan.