Spesifikasi StableHLO

StabilHLO adalah operasi yang ditetapkan untuk operasi tingkat tinggi (HLO) di mesin machine learning (ML). StabilHLO berfungsi sebagai lapisan portabilitas antara berbagai Framework ML dan compiler ML: framework ML yang menghasilkan program StableHLO kompatibel dengan compiler ML yang menggunakan program StableHLO.

Tujuan kami adalah menyederhanakan dan mempercepat pengembangan ML dengan membuat lebih banyak interoperabilitas antara berbagai framework ML (seperti TensorFlow, JAX, PyTorch) dan compiler ML (seperti XLA dan IREE). Untuk mencapai tujuan tersebut, dokumen menyediakan spesifikasi untuk bahasa pemrograman StableHLO.

Spesifikasi ini memuat tiga bagian utama. Pertama, Bagian Programs menjelaskan struktur program StabilHLO yang terdiri dari fungsi StableHLO yang terdiri dari operasi StableHLO. Dalam struktur tersebut, bagian Ops menentukan semantik operasi individual. Bagian Execution menyediakan semantik untuk semua operasi ini yang dieksekusi bersama dalam sebuah program. Terakhir, Bagian Notasi membahas notasi yang digunakan dalam spesifikasi pendukung.

Untuk melihat spesifikasi dari rilis StableHLO sebelumnya, buka repo di rilis yang diberi tag minat. Misalnya, StableHLO v0.19.0 Spec. Untuk melihat perubahan yang terjadi di setiap bump versi minor StableHLO, lihat log versi di VhloDialect.td.

Program

Program ::= {Func}

Program SttableHLO terdiri dari sejumlah arbitrer fungsi StableHLO. Berikut adalah contoh program dengan fungsi @main yang memiliki 3 input (%image, %weights, dan %bias) dan 1 output. Isi fungsi memiliki 6 operasi.

func.func @main(
  %image: tensor<28x28xf32>,
  %weights: tensor<784x10xf32>,
  %bias: tensor<1x10xf32>
) -> tensor<1x10xf32> {
  %0 = "stablehlo.reshape"(%image) : (tensor<28x28xf32>) -> tensor<1x784xf32>
  %1 = "stablehlo.dot"(%0, %weights) : (tensor<1x784xf32>, tensor<784x10xf32>) -> tensor<1x10xf32>
  %2 = "stablehlo.add"(%1, %bias) : (tensor<1x10xf32>, tensor<1x10xf32>) -> tensor<1x10xf32>
  %3 = "stablehlo.constant"() {value = dense<0.0> : tensor<1x10xf32>} : () -> tensor<1x10xf32>
  %4 = "stablehlo.maximum"(%2, %3) : (tensor<1x10xf32>, tensor<1x10xf32>) -> tensor<1x10xf32>
  "func.return"(%4): (tensor<1x10xf32>) -> ()
}

Fungsi

Func        ::= 'func' '.' 'func' FuncId FuncInputs FuncOutputs '{' FuncBody '}'
FuncInputs  ::= '(' [FuncInput {',' FuncInput}] `)`
FuncInput   ::= ValueId ':' ValueType
FuncOutputs ::= ['->' FuncOutput, {',' FuncOutput}]
FuncOutput  ::= ValueType
FuncBody    ::= {Op}

Fungsi SttableHLO (yang juga disebut fungsi bernama) memiliki ID, input/{i>output<i} dan {i>body<i}. Di masa mendatang, kami berencana memperkenalkan metadata tambahan untuk fungsi guna mencapai kompatibilitas yang lebih baik dengan HLO (#425, #626, #740, #744).

Pengenal

FuncId  ::= '@' letter {letter | digit}
ValueId ::= '%' digit {digit}
          | '%' letter {letter | digit}
letter  ::= 'a' | ... | 'z' | 'A' | ... | 'Z' | '_'
digit   ::= '0' | ... | '9'

ID SttableHLO mirip dengan ID dalam banyak pemrograman bahasa, dengan dua kekhasan: 1) semua pengenal memiliki {i>sigil<i} yang membedakan berbagai jenis pengenal, 2) pengenal nilai dapat lengkap bersifat numerik untuk menyederhanakan pembuatan program StableHLO.

Jenis

Type         ::= ValueType | NonValueType
ValueType    ::= TensorType | QuantizedTensorType | TokenType | TupleType
NonValueType ::= TensorElementType | QuantizedTensorElementType | FunctionType | StringType

Jenis SttableHLO dikategorikan ke dalam jenis nilai (yang juga disebut jenis kelas satu) yang mewakili nilai StabilHLO dan jenis bukan nilai yang menggambarkan elemen program lainnya. Jenis StableHLO mirip dengan jenis pada banyak bahasa pemrograman, dengan keunikan utamanya adalah yang bersifat khusus domain yang memberikan hasil yang tidak biasa (mis. jenis skalar bukan jenis nilai).

TensorType ::= 'tensor' '<' Shape TensorElementType '>'
Shape ::= {DimensionSize 'x'}
DimensionSize ::= digit {digit} | '?'

Jenis tensor mewakili tensor, yaitu array multidimensi. Aset tersebut memiliki bentuk dan jenis elemen, dengan sebuah bentuk merepresentasikan bilangan non-negatif atau ukuran dimensi yang tidak diketahui dalam urutan menaik dimensi (yang juga disebut sumbu) yang diberi nomor dari 0 hingga R-1. Tujuan jumlah dimensi R disebut peringkat. Misalnya, tensor<2x3xf32> adalah jenis tensor dengan bentuk 2x3 dan jenis elemen f32. Model ini memiliki dua dimensi (atau, dengan kata lain, dua sumbu) - dimensi ke-0 dan dimensi ke-1 - yang ukurannya adalah 2 dan 3. Peringkatnya 2.

Bentuk bisa tidak diketahui sebagian atau seluruhnya (dinamis), mis. tensor<?x2xf64> tidak diketahui sebagian dan tensor<?x?xf64> sama sekali tidak diketahui. Dinamis ukuran dimensi ditampilkan menggunakan ?. Bentuk tidak boleh diberi peringkat.

Di masa mendatang, kami berencana mempelajari jenis tensor yang diperluas di luar ukuran dimensi dan jenis elemen, misalnya, untuk menyertakan tata letak (#629) dan ketersebaran (#1078).

QuantizedTensorType ::= 'tensor' '<' Shape QuantizedTensorElementType '>'
QuantizedTensorElementType ::= '!quant.uniform' '<'
                  QuantizationStorageType
                  ['<' QuantizationStorageMin ':' QuantizationStorageMax '>']
                  ':' QuantizationExpressedType
                  [':' QuantizationDimension]
                  ',' QuantizationParameters '>'
QuantizationStorageType ::= IntegerType
QuantizationStorageMin ::= IntegerConstant
QuantizationStorageMax ::= IntegerConstant
QuantizationExpressedType ::= FloatType
QuantizationDimension ::= IntegerConstant
QuantizationParameters ::= QuantizationParameter
                         | '{' QuantizationParameter {',' QuantizationParameter} '}'
QuantizationParameter ::= QuantizationScale ':' QuantizationZeroPoint
QuantizationScale ::= FloatConstant
QuantizationZeroPoint ::= IntegerConstant
Nama Jenis Batasan
storage_type jenis bilangan bulat (C1-C3), (C8)
storage_min konstanta bilangan bulat (C1), (C3), (C7)
storage_max konstanta bilangan bulat (C2), (C3), (C7)
expressed_type jenis floating point (C4)
quantization_dimension konstanta bilangan bulat opsional (C10-C12)
scales jumlah variabel konstanta floating point (C4-C6), (C9), (C10), (C13)
zero_points jumlah variabel konstanta bilangan bulat (C7-C9)

Jenis elemen terkuantisasi mewakili nilai bilangan bulat dari jenis penyimpanan di rentang dari storage_min hingga storage_max (inklusif) yang sesuai dengan nilai floating point dari jenis yang dinyatakan. Untuk nilai bilangan bulat i yang diberikan, nilai floating point f yang sesuai dapat dikomputasi sebagai f = (i - zero_point) * scale, yang mana scale dan zero_point dipanggil parameter kuantisasi. storage_min dan storage_max bersifat opsional dalam tata bahasa, tetapi memiliki nilai default min_value(storage_type) dan max_value(storage_type). Jenis elemen terkuantisasi memiliki batasan berikut:

  • (C1) type(storage_min) = storage_type.
  • (C2) type(storage_max) = storage_type.
  • (C3) min_value(storage_type) <= storage_min < storage_max <= max_value(storage_type).
  • (C4) type(scales...) = expressed_type.
  • (C5) 0 < scales.
  • (C6) is_finite(scales...).
  • (C7) storage_min <= zero_points <= storage_max.
  • (C8) type(zero_points...) = storage_type.
  • (C9) size(scales) = size(zero_points).
  • (C10) Jika is_empty(quantization_dimension), maka size(scales) = 1.
  • (C11) 0 <= quantization_dimension.

Saat ini, QuantizationScale adalah konstanta floating point, tetapi ada kepentingan yang kuat dalam skala berbasis bilangan bulat, diwakili dengan pengganda dan {i>shift<i}. Kami berencana untuk menjelajahi fitur ini dalam waktu dekat (#1404).

Ada diskusi berkelanjutan tentang semantik QuantizationZeroPoint, termasuk jenisnya, nilai-nilainya, dan apakah mungkin hanya ada satu atau dan mungkin beberapa titik nol dalam tipe tensor terkuantisasi. Berdasarkan hasil diskusi ini, spesifikasi seputar titik nol dapat berubah pada masa mendatang (#1405).

Diskusi berkelanjutan lainnya melibatkan semantik QuantizationStorageMin dan QuantizationStorageMax untuk menentukan apakah ada batasan dikenakan pada nilai-nilai ini dan pada nilai-nilai tensor terkuantisasi (#1406).

Terakhir, kita berencana untuk merepresentasikan skala yang tidak diketahui dan nol poin, sama seperti perencanaan kita dalam mengeksplorasi cara untuk merepresentasikan ukuran dimensi (#1407).

Jenis tensor terkuantisasi mewakili tensor dengan elemen terkuantisasi. Ini tensor sama persis dengan tensor reguler, hanya saja memiliki jenis elemen terkuantisasi, bukan jenis elemen reguler.

Pada tensor terkuantisasi, kuantisasi bisa berupa per-tensor, artinya satu scale dan zero_point untuk seluruh tensor atau dapat berupa per-axis, artinya, memiliki beberapa scales dan zero_points, satu pasang per irisan dimensi tertentu quantization_dimension. Secara lebih formal, dalam tensor t dengan kuantisasi per sumbu, terdapat dim(t, quantization_dimension) irisan dari quantization_dimension: t[:, ..., 0, ..., :], t[:, ..., 1, ..., :], dll. Semua elemen dalam slice i menggunakan scales[i] dan zero_points[i] sebagai parameter kuantisasinya. Jenis tensor terkuantisasi memiliki batasan:

  • Untuk kuantisasi per-tensor:
    • Tidak ada batasan tambahan.
  • Untuk kuantisasi per sumbu:
    • (C12) quantization_dimension < rank(self).
    • (C13) dim(self, quantization_dimension) = size(scales).
TokenType ::= 'token'

Jenis token mewakili token, yaitu nilai buram yang dihasilkan dan dipakai oleh beberapa operasi. Token digunakan untuk menerapkan urutan eksekusi pada operasi seperti yang dijelaskan di bagian Eksekusi.

TupleType ::= 'tuple' '<' TupleElementTypes '>'
TupleElementTypes ::= [ValueType {',' ValueType}]

Jenis tuple mewakili tuple, yaitu daftar heterogen. Tuple adalah warisan yang hanya ada untuk kompatibilitas dengan HLO. Di HLO, tuple adalah digunakan untuk merepresentasikan input dan {i>output <i}variadic. Dalam StableHLO, input variadic dan didukung secara native, dan satu-satunya penggunaan tuple di StableHLO adalah untuk secara komprehensif merepresentasikan ABI HLO, mis. T, tuple<T>, dan tuple<tuple<T>> mungkin secara signifikan berbeda bergantung pada terlepas dari implementasi layanan. Di masa mendatang, kami berencana melakukan perubahan pada HLO ABI sehingga kami dapat menghapus jenis tuple dari StableHLO (#598).

TensorElementType ::= BooleanType | IntegerType | FloatType | ComplexType
BooleanType ::= 'i1'
IntegerType ::= SignedIntegerType | UnsignedIntegerType
SignedIntegerType ::= 'si2' | 'si4' | 'si8' | 'si16' | 'si32' | 'si64'
UnsignedIntegerType ::= 'ui2' | 'ui4' | 'ui8' | 'ui16' | 'ui32' | 'ui64'
FloatType ::= 'f8E4M3FN' | 'f8E5M2' | 'f8E4M3FNUZ' | 'f8E5M2FNUZ'
            | 'f8E4M3B11FNUZ' | 'bf16' | 'f16' | 'f32' | 'f64'
TensorFloat32 ::= 'tf32'
ComplexType ::= 'complex' '<' ComplexElementType '>'
ComplexElementType ::= 'f32' | 'f64'

Jenis elemen mewakili elemen jenis tensor. Tidak seperti di banyak pemrograman bahasa, jenis ini bukan kelas satu di StableHLO. Hal ini berarti bahwa Program StabilHLO tidak dapat langsung merepresentasikan nilai jenis ini (akibatnya, merepresentasikan nilai skalar jenis T dengan tensor 0 dimensi nilai dari jenis tensor<T>).

  • Jenis boolean mewakili nilai boolean true dan false.
  • Jenis bilangan bulat dapat ditandatangani (si) atau tidak ditandatangani (ui) dan memiliki salah satu lebar bit yang didukung (2, 4, 8, 16, 32, atau 64). Jenis siN bertanda tangan mewakili nilai bilangan bulat dari -2^(N-1) hingga 2^(N-1)-1 jenis uiN inklusif, dan tidak ditandatangani mewakili nilai bilangan bulat dari 0 hingga 2^N-1 inklusif.
  • Jenis floating point dapat berupa salah satu dari hal berikut:
  • Jenis kompleks merepresentasikan nilai kompleks yang memiliki bagian nyata dan bagian imajiner dari jenis elemen yang sama. Kompleks yang didukung adalah complex<f32> (kedua bagian dari jenis f32) dan complex<f64> (kedua bagian adalah jenis f64).
FunctionType ::= '(' InputTypes ')' '->' '(' OutputTypes ')'
InputTypes ::= [ValueType {',' ValueType}]
OutputTypes ::= [ValueType {',' ValueType}]

Jenis fungsi mewakili fungsi bernama dan anonim. Mereka memiliki jenis input (daftar jenis di sebelah kiri ->) dan jenis output (daftar jenis di sisi kanan ->). Dalam banyak pemrograman bahasa, jenis fungsi adalah class satu, tetapi tidak di StableHLO.

StringType ::= 'string'

Jenis string mewakili urutan byte. Tidak seperti di banyak pemrograman bahasa, jenis string bukan kelas satu di StableHLO dan hanya digunakan untuk menentukan {i>metadata<i} statis untuk elemen program.

Operasi

Operasi SttableHLO (yang juga disebut ops) mewakili kumpulan tertutup operasi tingkat tinggi dalam model machine learning. Seperti yang dibahas di atas, Sintaksis StableHLO sangat terinspirasi oleh MLIR, yang belum tentu merupakan ergonomis, tetapi mungkin lebih cocok untuk tujuan StableHLO sehingga meningkatkan interoperabilitas antara framework ML dan compiler ML.

Op            ::= [OpOutputs] OpName OpInputs ':' OpSignature
OpName        ::= '"' 'stablehlo' '.' OpMnemonic '"'
OpMnemonic    ::= 'abs' | 'add' | ...

Operasi SttableHLO (yang juga disebut ops) memiliki nama, input/{i>output<i} dan tanda tangan. Nama terdiri dari awalan stablehlo. dan mnemonic yang secara unik mengidentifikasi salah satu operasi yang didukung. Lihat di bawah untuk daftar lengkap dari semua operasi yang didukung.

OpInputs        ::= OpInputValues OpInputFuncs OpInputAttrs
OpInputValues   ::= '(' [OpInputValue {',' OpInputValue}] ')'
OpInputValue    ::= ValueId
OpInputFuncs    ::= ['(' OpInputFunc {',' OpInputFunc} ')']
OpInputAttrs    ::= ['{' OpInputAttr {',' OpInputAttr} '}']
OpOutputs       ::= [OpOutput {',' OpOutput} '=']
OpOutput        ::= ValueId

Operasi menggunakan input dan menghasilkan output. Input dikategorikan ke dalam nilai input (dihitung selama eksekusi), fungsi input (disediakan secara statis, karena dalam StableHLO bukan merupakan nilai kelas satu) dan atribut input (juga disediakan secara statis). Jenis input dan output dikonsumsi dan diproduksi oleh operasi bergantung pada mnemoniknya. Misalnya, add op menggunakan 2 nilai input dan menghasilkan 1 nilai output. Sebagai perbandingan, Operasi select_and_scatter menggunakan 3 nilai input, 2 fungsi input, dan 3 atribut input.

OpInputFunc ::= '{' Unused FuncInputs ':' FuncBody '}'
Unused      ::= '^' digit {digit}
              | '^' letter {letter | digit}

Fungsi input (yang juga disebut fungsi anonim) sangat mirip dengan fungsi bernama kecuali bahwa: 1) mereka tidak memiliki ID (oleh karena itu nama "anonim"), 2) mereka tidak mendeklarasikan jenis output (jenis output disimpulkan dari operasi return dalam fungsi).

Sintaks untuk fungsi input menyertakan bagian yang saat ini tidak digunakan (lihat Unused di atas) yang tersedia untuk kompatibilitas dengan MLIR. Di MLIR, ada konsep yang lebih umum tentang "kawasan" yang dapat memiliki beberapa "blok" operasi yang terhubung bersama melalui operasi lompatan. Blok ini memiliki ID yang sesuai ke produksi Unused, agar dapat dibedakan satu sama lain. StabilHLO tidak memiliki operasi lompatan, jadi bagian yang sesuai dari sintaks MLIR adalah tidak digunakan (tetapi masih ada).

OpInputAttr      ::= OpInputAttrName '=' OpInputAttrValue
OpInputAttrName  ::= letter {letter | digit}
OpInputAttrValue ::= Constant

Atribut input memiliki nama dan nilai yang merupakan salah satu atribut yang didukung konstanta. Mereka adalah cara utama untuk menentukan metadata statis untuk yang kurang penting. Misalnya, operasi concatenate menggunakan atribut dimension untuk menentukan dimensi yang menggabungkan nilai inputnya. Demikian pula, Operasi slice menggunakan beberapa atribut seperti start_indices dan limit_indices untuk menentukan batas yang digunakan untuk memotong nilai input.

Saat ini, program StableHLO di luar organisasi terkadang berisi atribut yang tidak dijelaskan dalam dokumen ini. Di masa mendatang, kami berencana menyerap atribut ini ke dalam opset StableHLO atau melarangnya yang muncul di program StableHLO. Sementara itu, berikut daftar atribut:

  • layout (#629).
  • mhlo.frontend_attributes (#628).
  • mhlo.sharding (#619).
  • output_operand_aliases (#740).
  • Metadata lokasi (#594).
OpSignature ::= '(' [ValueType {',' ValueType}] ')' '->' '(' [ValueType {',' ValueType}] ')'

Tanda tangan op terdiri dari jenis semua nilai input (daftar jenis di di sisi kiri ->) dan jenis semua nilai output (daftar jenis di sisi kanan ->). Sebenarnya, tipe input tipe output hampir selalu redundan (karena untuk sebagian besar operasi StableHLO, jenis output dapat disimpulkan dari input). Meskipun demikian, op signature sengaja dijadikan bagian dari sintaksis StableHLO untuk kompatibilitas dengan MLIR.

Di bawah ini adalah contoh operasi yang mnemoniknya adalah select_and_scatter. Menggunakan 3 nilai input (%operand, %source, dan %init_value), 2 fungsi input dan 3 atribut input (window_dimensions, window_strides, dan padding). Perhatikan bagaimana tanda tangan operasi hanya menyertakan jenis nilai inputnya (tetapi bukan jenis fungsi dan atribut input yang disediakan secara inline).

%result = "stablehlo.select_and_scatter"(%operand, %source, %init_value) ({
  ^bb0(%arg0: tensor<i32>, %arg1: tensor<i32>):
    %0 = "stablehlo.compare"(%arg0, %arg1) {
      comparison_direction = #stablehlo<comparison_direction GE>
    } : (tensor<i32>, tensor<i32>) -> tensor<i1>
    "stablehlo.return"(%0) : (tensor<i1>) -> ()
}, {
  ^bb0(%arg0: tensor<i32>, %arg1: tensor<i32>):
    %0 = "stablehlo.add"(%arg0, %arg1) : (tensor<i32>, tensor<i32>) -> tensor<i32>
    "stablehlo.return"(%0) : (tensor<i32>) -> ()
}) {
  window_dimensions = dense<[3, 1]> : tensor<2xi64>,
  window_strides = dense<[2, 1]> : tensor<2xi64>,
  padding = dense<[[0, 1], [0, 0]]> : tensor<2x2xi64>
} : (tensor<4x2xi32>, tensor<2x2xi32>, tensor<i32>) -> tensor<4x2xi32>

Konstanta

Constant ::= BooleanConstant
           | IntegerConstant
           | FloatConstant
           | ComplexConstant
           | TensorConstant
           | QuantizedTensorConstant
           | StringConstant
           | EnumConstant

Konstanta SttableHLO memiliki literal dan jenis yang bersama-sama mewakili nilai StableHLO. Umumnya, jenis ini adalah bagian dari sintaks konstan, kecuali jika tidak ambigu (misalnya, konstanta boolean yang jelas memiliki jenis i1, sedangkan konstanta bilangan bulat dapat memiliki beberapa kemungkinan jenis).

BooleanConstant ::= BooleanLiteral
BooleanLiteral  ::= 'true' | 'false'

Konstanta Boolean mewakili nilai boolean true dan false. Boolean konstanta memiliki jenis i1.

IntegerConstant   ::= IntegerLiteral ':' IntegerType
IntegerLiteral    ::= ['-' | '+'] DecimalDigits
                    | ['-' | '+'] '0x' HexadecimalDigits
DecimalDigits     ::= decimalDigit {decimalDigit}
HexadecimalDigits ::= hexadecimalDigit {hexadecimalDigit}
decimalDigit      ::= '0' | ... | '9'
hexadecimalDigit  ::= decimalDigit | 'a' | ... | 'f' | 'A' | ... | 'F'

Konstanta bilangan bulat mewakili nilai bilangan bulat melalui string yang menggunakan bilangan desimal atau notasi heksadesimal. Basis lainnya, misalnya biner atau oktal, tidak didukung. Konstanta bilangan bulat memiliki batasan berikut:

  • (C1) is_wellformed(integer_literal, integer_type).
FloatConstant  ::= FloatLiteral ':' FloatType
FloatLiteral   ::= SignPart IntegerPart FractionalPart ScientificPart
                 | '0x' [HexadecimalDigits]
SignPart       ::= ['-' | '+']
IntegerPart    ::= DecimalDigits
FractionalPart ::= ['.' [DecimalDigits]]
ScientificPart ::= [('e' | 'E') ['-' | '+'] DecimalDigits]

Konstanta floating point merepresentasikan nilai floating point melalui string yang menggunakan desimal atau notasi ilmiah. Selain itu, notasi heksadesimal dapat digunakan untuk menentukan bit yang mendasarinya secara langsung dalam format floating point jenis yang sesuai. Konstanta floating point memiliki batasan berikut:

  • (C1) Jika notasi non-heksadesimal digunakan, is_wellformed(float_literal, float_type).
  • (C2) Jika notasi heksadesimal digunakan, size(hexadecimal_digits) = num_bits(float_type) / 4.
ComplexConstant ::= ComplexLiteral ':' ComplexType
ComplexLiteral  ::= '(' RealPart ',' ImaginaryPart ')'
RealPart        ::= FloatLiteral
ImaginaryPart   ::= FloatLiteral

Konstanta kompleks merepresentasikan nilai kompleks menggunakan daftar bagian nyata (ditempatkan lebih dahulu) dan bagian imajiner (ditempatkan kedua). Misalnya, (1.0, 0.0) : complex<f32> mewakili 1.0 + 0.0i, dan (0.0, 1.0) : complex<f32> mewakili 0.0 + 1.0i. Urutan dari yang kemudian disimpan di memori adalah yang ditentukan implementasinya. Konstanta kompleks memiliki batasan berikut:

  • (C1) is_wellformed(real_part, complex_element_type(complex_type)).
  • (C2) is_wellformed(imaginary_part, complex_element_type(complex_type)).
TensorConstant ::= TensorLiteral ':' TensorType
TensorLiteral  ::= 'dense' '<' (DenseLiteral | ElementLiteral) '>'
DenseLiteral   ::= DenseDimension | DenseElements
DenseDimension ::= '[' [DenseLiteral {',' DenseLiteral}] ']'
DenseElements  ::= [ElementLiteral {',' ElementLiteral}]
ElementLiteral ::= BooleanLiteral | IntegerLiteral | FloatLiteral | ComplexLiteral

Konstanta Tensor mewakili nilai tensor menggunakan daftar bertingkat yang ditentukan melalui Notasi NumPy. Misalnya, dense<[[1, 2, 3], [4, 5, 6]]> : tensor<2x3xi32> mewakili nilai tensor dengan pemetaan berikut dari indeks ke elemen: {0, 0} => 1, {0, 1} => 2, {0, 2} => 3, {1, 0} => 4, {1, 1} => 5, {1, 2} => 6. Urutan elemen-elemen ini kemudian disimpan dalam memori adalah ditentukan oleh implementasi. Konstanta tensor memiliki batasan berikut:

  • (C1) has_syntax(tensor_literal, element_type(tensor_type)), dengan kondisi:
    • has_syntax(element_literal: Syntax, element_type: Type) = is_wellformed(element_literal, type).
    • has_syntax(tensor_literal: List, element_type: Type) = has_syntax(tensor_literal..., element_type).
  • (C2) has_shape(tensor_literal, shape(tensor_type)), dengan kondisi:
    • has_shape(element_literal: Syntax, []) = true.
    • has_shape(tensor_literal: List, shape: List) = size(tensor_literal) = shape[0] and has_shape(tensor_literal..., shape[1:]).
    • jika tidak, false.
QuantizedTensorConstant ::= QuantizedTensorLiteral ':' QuantizedTensorType
QuantizedTensorLiteral  ::= 'dense' '<' (DenseLiteral | ElementLiteral) '>'

Konstanta tensor terkuantisasi mewakili nilai tensor terkuantisasi menggunakan notasi sebagai konstanta tensor, dengan elemen yang ditetapkan sebagai konstanta jenis penyimpanan. Konstanta tensor terkuantisasi memiliki batasan berikut:

  • (C1) has_syntax(quantized_tensor_literal, storage_type(quantized_tensor_type)).
  • (C2) has_shape(quantized_tensor_literal, shape(quantized_tensor_type)).
StringConstant  ::= StringLiteral
StringLiteral   ::= '"' {stringCharacter | escapeSequence} '"'
stringCharacter ::= all ASCII characters except '\00', '\01', ... '\1f' and '"'
escapeSequence  ::= '\' ('"' | '\' | 'n' | 't' | (hexadecimalDigit hexadecimalDigit))

Literal string terdiri dari byte yang ditentukan menggunakan karakter ASCII dan urutan escape. Model ini bersifat agnostik encoding, jadi interpretasinya byte ditentukan oleh implementasi. Literal string memiliki jenis string.

Operasi

abs

Semantik

Melakukan operasi abs berbasis element-wise pada tensor operand dan menghasilkan result tensor. Bergantung pada jenis elemen, melakukan hal berikut:

  • Untuk bilangan bulat bertanda tangan: modulus bilangan bulat.
  • Untuk float: abs dari IEEE-754.
  • Untuk bilangan kompleks: modulus kompleks.
  • Untuk jenis terkuantisasi: dequantize_op_quantize(abs, operand, type(result)).

Input

Label Nama Jenis Batasan
(I1) operand tensor bilangan bulat bertanda tangan, floating point, atau tipe kompleks atau tensor terkuantisasi per-tensor (C1-C2)

Output

Nama Jenis Batasan
result tensor bilangan bulat bertanda tangan atau jenis floating point atau tensor terkuantisasi per-tensor (C1-C2)

Batasan

  • (C1) shape(result) = shape(operand).
  • (C2) baseline_element_type(result) didefinisikan sebagai:
    • complex_element_type(element_type(operand)) jika is_complex(operand).
    • baseline_element_type(operand) sebaliknya.

Contoh

// %operand: [-2, 0, 2]
%result = "stablehlo.abs"(%operand) : (tensor<3xi32>) -> tensor<3xi32>
// %result: [2, 0, 2]

Contoh Lainnya

tambahkan

Semantik

Melakukan penambahan berdasarkan elemen pada dua tensor lhs dan rhs serta menghasilkan Tensor result. Bergantung pada jenis elemen, melakukan hal berikut:

  • Untuk boolean: OR logis.
  • Untuk bilangan bulat: penambahan bilangan bulat.
  • Untuk float: addition dari IEEE-754.
  • Untuk bilangan kompleks: penambahan kompleks.
  • Untuk jenis terkuantisasi: dequantize_op_quantize(add, lhs, rhs, type(result)).

Input

Label Nama Jenis Batasan
(I1) lhs tensor atau quantized tensor (C1-C6)
(I2) rhs tensor atau quantized tensor (C1-C5), (C7)

Output

Nama Jenis Batasan
result tensor atau quantized tensor (C1-C7)

Batasan

  • Jika operasi tersebut menggunakan tensor non-kuantisasi:
    • (C1) type(lhs) = type(rhs) = type(result).
  • Jika operasi tersebut menggunakan tensor terkuantisasi:
    • (C2) is_quantized(lhs) and is_quantized(rhs) and is_quantized(result).
    • (C3) storage_type(lhs) = storage_type(rhs) = storage_type(result).
    • (C4) expressed_type(lhs) = expressed_type(rhs) = expressed_type(result).
    • (C5) (is_per_axis_quantized(lhs) or is_per_axis_quantized(rhs)) = is_per_axis_quantized(result).
    • (C6) Jika is_per_axis_quantized(lhs), maka quantization_dimension(lhs) = quantization_dimension(result).
    • (C7) Jika is_per_axis_quantized(rhs), maka quantization_dimension(rhs) = quantization_dimension(result).

Contoh

// %lhs: [[1, 2], [3, 4]]
// %rhs: [[5, 6], [7, 8]]
%result = "stablehlo.add"(%lhs, %rhs) : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
// %result: [[6, 8], [10, 12]]

Contoh Lainnya

after_all

Semantik

Memastikan operasi yang menghasilkan inputs dijalankan sebelum operasi yang bergantung pada result. Eksekusi operasi ini tidak melakukan apa pun, fungsi tersebut hanya ada untuk membuat dependensi data dari result hingga inputs.

Input

Label Nama Jenis
(I1) inputs bilangan variadis token

Output

Nama Jenis
result token

Contoh

// %input0: !stablehlo.token
// %input1: !stablehlo.token
%result = "stablehlo.after_all"(%input0, %input1) : (!stablehlo.token, !stablehlo.token) -> !stablehlo.token

Contoh Lainnya

all_gather

Semantik

Dalam setiap grup proses di grid proses StableHLO, merangkai nilai-nilai tensor operands dari setiap proses di sepanjang all_gather_dim dan menghasilkan Tensor results.

Operasi tersebut membagi grid proses StableHLO menjadi process_groups yang didefinisikan sebagai berikut:

  • cross_replica(replica_groups) jika channel_id <= 0 and use_global_device_ids = false.
  • cross_replica_and_partition(replica_groups) jika channel_id > 0 and use_global_device_ids = false.
  • flattened_ids(replica_groups) jika channel_id > 0 and use_global_device_ids = true.

Setelah itu, dalam setiap process_group:

  • operands...@receiver = [operand@sender for sender in process_group] untuk semua receiver dalam process_group.
  • results...@process = concatenate(operands...@process, all_gather_dim) untuk semua process dalam process_group.

Input

Label Nama Jenis Batasan
(I1) operands jumlah tensor atau tensor terkuantisasi per-tensor (C1), C6
(I2) all_gather_dim konstanta jenis si64 (C1), C6
(I3) replica_groups Konstanta tensor 2 dimensi jenis si64 (C2-C4)
(I4) channel_id konstanta jenis si64 (C5)
(I5) use_global_device_ids konstanta jenis i1 (C5)

Output

Nama Jenis Batasan
results jumlah tensor atau tensor terkuantisasi per-tensor (C6)

Batasan

  • (C1) 0 <= all_gather_dim < rank(operands...).
  • (C2) is_unique(replica_groups).
  • (C3) size(replica_groups) didefinisikan sebagai:
    • num_replicas jika cross_replica digunakan.
    • num_replicas jika cross_replica_and_partition digunakan.
    • num_processes jika flattened_ids digunakan.
  • (C4) 0 <= replica_groups < size(replica_groups).
  • (C5) Jika use_global_device_ids = true, maka channel_id > 0.
  • (C6) type(results...) = type(operands...) kecuali:
    • dim(results..., all_gather_dim) = dim(operands..., all_gather_dim) * dim(process_groups, 1).

Contoh

// num_replicas: 2
// num_partitions: 1
// %operand0@(0, 0): [[1, 2], [3, 4]]
// %operand0@(1, 0): [[5, 6], [7, 8]]
// %operand1@(0, 0): [[11, 12], [13, 14]]
// %operand1@(1, 0): [[15, 16], [17, 18]]
%result:2 = "stablehlo.all_gather"(%operand0, %operand1) {
  all_gather_dim = 1 : i64,
  replica_groups = dense<[[0, 1]]> : tensor<1x2xi64>,
  // channel_id = 0
  channel_handle = #stablehlo.channel_handle<handle = 0, type = 0>
  // use_global_device_ids = false
} : (tensor<2x2xi64>, tensor<2x2xi64>) -> (tensor<2x4xi64>, tensor<2x4xi64>)
// %result0@(0, 0): [[1, 2, 5, 6], [3, 4, 7, 8]]
// %result0@(1, 0): [[1, 2, 5, 6], [3, 4, 7, 8]]
// %result1@(0, 0): [[11, 12, 15, 16], [13, 14, 17, 18]]
// %result1@(1, 0): [[11, 12, 15, 16], [13, 14, 17, 18]]

Contoh Lainnya

all_reduce

Semantik

Dalam setiap grup proses pada grid proses StableHLO, menerapkan pengurangan fungsi computation dengan nilai tensor operands dari setiap proses dan menghasilkan tensor results.

Operasi tersebut membagi grid proses StableHLO menjadi process_groups yang didefinisikan sebagai berikut:

  • cross_replica(replica_groups) jika channel_id <= 0 and use_global_device_ids = false.
  • cross_replica_and_partition(replica_groups) jika channel_id > 0 and use_global_device_ids = false.
  • flattened_ids(replica_groups) jika channel_id > 0 and use_global_device_ids = true.

Setelah itu, dalam setiap process_group:

  • results...@process[result_index] = exec(schedule) untuk beberapa hierarki biner schedule dengan:
    • exec(node) = computation(exec(node.left), exec(node.right)).
    • exec(leaf) = leaf.value.
  • schedule adalah hierarki biner yang ditentukan penerapan yang sesuai urutan traversal adalah to_destination_type(operands...@process_group...[result_index], type(func_inputs(computation)[0])).

Input

Label Nama Jenis Batasan
(I1) operands jumlah tensor atau tensor terkuantisasi per-tensor (C5), (C6)
(I2) replica_groups jumlah variadik konstanta tensor 1 dimensi jenis si64 (C1-C3)
(I3) channel_id konstanta jenis si64 (C4)
(I4) use_global_device_ids konstanta jenis i1 (C4)
(I5) computation fungsi (C5)

Output

Nama Jenis Batasan
results jumlah tensor atau tensor terkuantisasi per-tensor (C6-C7)

Batasan

  • (C1) is_unique(replica_groups).
  • (C2) size(replica_groups) didefinisikan sebagai:
    • num_replicas jika cross_replica digunakan.
    • num_replicas jika cross_replica_and_partition digunakan.
    • num_processes jika flattened_ids digunakan.
  • (C3) 0 <= replica_groups < size(replica_groups).
  • (C4) Jika use_global_device_ids = true, maka channel_id > 0.
  • (C5) computation memiliki jenis (tensor<E>, tensor<E>) -> (tensor<E>), dengan is_promotable(element_type(operand), E).
  • (C6) shape(results...) = shape(operands...).
  • (C7) element_type(results...) = E.

Contoh

// num_replicas: 2
// num_partitions: 1
// %operand0@(0, 0): [1, 2, 3, 4]
// %operand0@(1, 0): [5, 6, 7, 8]
// %operand1@(0, 0): [9, 10, 11, 12]
// %operand1@(1, 0): [13, 14, 15, 16]
%result:2 = "stablehlo.all_reduce"(%operand0, %operand0) ({
  ^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
    %0 = "stablehlo.add"(%arg0, %arg1) : (tensor<i64>, tensor<i64>) -> tensor<i64>
    "stablehlo.return"(%0) : (tensor<i64>) -> ()
}) {
  replica_groups = dense<[[0, 1]]> : tensor<1x2xi64>,
  // channel_id = 0
  channel_handle = #stablehlo.channel_handle<handle = 0, type = 0>
  // use_global_device_ids = false
} : (tensor<4xi64>, tensor<4xi64>) -> (tensor<4xi64>, tensor<4xi64>)
// %result0@(0, 0): [6, 8, 10, 12]
// %result0@(1, 0): [6, 8, 10, 12]
// %result1@(0, 0): [22, 24, 26, 28]
// %result1@(1, 0): [22, 24, 26, 28]

Contoh Lainnya

all_to_all

Semantik

all_to_all

Dalam setiap grup proses dalam {i>process grid<i} StableHLO, membagi nilai tensor operands di sepanjang split_dimension menjadi beberapa bagian, sehingga akan menyebarkan bagian di antara proses, menyambungkan bagian yang tersebar di concat_dimension dan menghasilkan tensor results. Operasi tersebut membagi grid proses StableHLO menjadi process_groups yang didefinisikan sebagai berikut:

  • cross_replica(replica_groups) jika channel_id <= 0.
  • cross_partition(replica_groups) jika channel_id > 0.

Setelah itu, dalam setiap process_group:

  • split_parts...@sender = split(operands...@sender, split_count, split_dimension) untuk semua sender di process_group.
  • scattered_parts...@receiver = [split_parts...@sender[receiver_index] for sender in process_group] di mana receiver_index = process_group.index(receiver).
  • results...@process = concatenate(scattered_parts...@process, concat_dimension).

Input

Label Nama Jenis Batasan
(I1) operands jumlah tensor atau tensor terkuantisasi per-tensor (C1-C3), (C9)
(I2) split_dimension konstanta jenis si64 (C1), (C2), (C9)
(I3) concat_dimension konstanta jenis si64 (C3), (C9)
(I4) split_count konstanta jenis si64 (C2), (C4), (C8), (C9)
(I5) replica_groups Konstanta tensor 2 dimensi jenis si64 (C5-C8)
(I6) channel_id konstanta jenis si64

Output

Nama Jenis Batasan
results jumlah tensor atau tensor terkuantisasi per-tensor (C9)

Batasan

  • (C1) 0 <= split_dimension < rank(operands...).
  • (C2) dim(operands..., split_dimension) % split_count = 0.
  • (C3) 0 <= concat_dimension < rank(operands...).
  • (C4) 0 < split_count.
  • (C5) is_unique(replica_groups).
  • (C6) size(replica_groups) ditentukan sebagai:
    • num_replicas jika cross_replica digunakan.
    • num_partitions jika cross_partition digunakan.
  • (C7) 0 <= replica_groups < size(replica_groups).
  • (C8) dim(replica_groups, 1) = split_count.
  • (C9) type(results...) = type(operands...) kecuali, jika split_dimension != concat_dimension:
    • dim(results..., split_dimension) = dim(operands..., split_dimension) / split_count.
    • dim(results..., concat_dimension) = dim(operands..., concat_dimension) * split_count.

Contoh

// num_replicas: 2
// num_partitions: 1
// %operand1@(0, 0): [[1, 2, 3, 4],
//                    [5, 6, 7, 8]]
// %operand1@(1, 0): [[9, 10, 11, 12],
//                    [13, 14, 15, 16]]
// %operand2@(0, 0): [[17, 18, 19, 20],
//                    [21, 22, 23, 24]]
// %operand2@(1, 0): [[25, 26, 27, 28],
//                    [29, 30, 31, 32]]
%result:2 = "stablehlo.all_to_all"(%operand1, %operand2) {
  split_dimension = 1 : i64,
  concat_dimension = 0 : i64,
  split_count = 2 : i64,
  replica_groups = dense<[[0, 1]]> : tensor<1x2xi64>
  // channel_id = 0
} : (tensor<2x4xi64>, tensor<2x4xi64>) -> (tensor<4x2xi64>, tensor<4x2xi64>)
// %result#0@(0, 0): [[1, 2], [5, 6], [9, 10], [13, 14]]
// %result#0@(1, 0): [[3, 4], [7, 8], [11, 12], [15, 16]]
// %result#1@(0, 0): [[17, 18], [21, 22], [25, 26], [29, 30]]
// %result#1@(1, 0): [[19, 20], [23, 24], [27, 28], [31, 32]]

Contoh Lainnya

dan

Semantik

Melakukan AND berdasarkan elemen dari dua tensor lhs dan rhs serta menghasilkan result tensor. Bergantung pada jenis elemen, melakukan hal berikut:

  • Untuk boolean: logika AND.
  • Untuk bilangan bulat: bitwise AND.

Input

Label Nama Jenis Batasan
(I1) lhs tensor jenis boolean atau bilangan bulat (C1)
(I2) rhs tensor jenis boolean atau bilangan bulat (C1)

Output

Nama Jenis Batasan
result tensor jenis boolean atau bilangan bulat (C1)

Batasan

  • (C1) type(lhs) = type(rhs) = type(result).

Contoh

// %lhs: [[1, 2], [3, 4]]
// %rhs: [[5, 6], [7, 8]]
%result = "stablehlo.and"(%lhs, %rhs) : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
// %result: [[1, 2], [3, 0]]

Contoh Lainnya

atan2

Semantik

Menjalankan operasi atan2 berdasarkan elemen pada tensor lhs dan rhs serta menghasilkan Tensor result. Bergantung pada jenis elemen, melakukan hal berikut:

  • Untuk float: atan2 dari IEEE-754.
  • Untuk bilangan kompleks: atan2 kompleks.
  • Untuk jenis terkuantisasi: dequantize_op_quantize(atan2, lhs, rhs, type(result)).

Input

Label Nama Jenis Batasan
(I1) lhs tensor floating point atau tipe kompleks atau tensor terkuantisasi per-tensor (C1)
(I2) rhs tensor floating point atau tipe kompleks atau tensor terkuantisasi per-tensor (C1)

Output

Nama Jenis Batasan
result tensor floating point atau tipe kompleks atau tensor terkuantisasi per-tensor (C1)

Batasan

  • (C1) baseline_type(lhs) = baseline_type(rhs) = baseline_type(result).

Contoh

// %lhs: [0.0, 1.0, -1.0]
// %rhs: [0.0, 0.0, 0.0]
%result = "stablehlo.atan2"(%lhs, %rhs) : (tensor<3xf64>, tensor<3xf64>) -> tensor<3xf64>
// %result: [0.0, 1.57079637, -1.57079637] // [0.0, pi/2, -pi/2]

Contoh Lainnya

batch_norm_grad

Semantik

Menghitung gradien beberapa input propagasi mundur batch_norm_training dari grad_output, dan menghasilkan grad_operand, grad_scale, dan grad_offset tensor. Secara lebih formal, operasi ini bisa dinyatakan sebagai dekomposisi menjadi operasi StabilHLO yang ada menggunakan sintaksis Python sebagai berikut:

def compute_sum(operand, feature_index):
  (sum,) = reduce(
      inputs=[operand],
      init_values=[constant(0, element_type(operand))],
      dimensions=[i for i in range(rank(operand)) if i != feature_index],
      body=lambda x, y: add(x, y))
  return sum

def compute_mean(operand, feature_index):
  sum = compute_sum(operand, feature_index)
  divisor = constant(size(operand) / dim(operand, feature_index),
                     element_type(operand))
  divisor_bcast = broadcast_in_dim(divisor, [], type(sum))
  return divide(sum, divisor_bcast)

def batch_norm_grad(operand, scale, mean, variance, grad_output, epsilon, feature_index):
  # Broadcast inputs to type(operand)
  scale_bcast = broadcast_in_dim(scale, [feature_index], type(operand))
  mean_bcast = broadcast_in_dim(mean, [feature_index], type(operand))
  variance_bcast = broadcast_in_dim(variance, [feature_index], type(operand))
  epsilon_bcast = broadcast_in_dim(constant(epsilon, element_type(operand)), [],
                                   type(operand))

  # Perform normalization using the provided `mean` and `variance`
  # Intermediate values will be useful for computing gradients
  centered_operand = subtract(operand, mean_bcast)
  stddev = sqrt(add(variance_bcast, epsilon_bcast))
  normalized_operand = divide(centered_operand, stddev)

  # Use the implementation from batchnorm_expander.cc in XLA
  # Temporary variables have exactly the same names as in the C++ code
  elements_per_feature = broadcast_in_dim(
      constant(divide(size(operand), dim(operand, feature_index)),
               element_type(grad_output)),
      [], type(operand))
  i1 = multiply(grad_output, elements_per_feature)
  i2 = broadcast_in_dim(
      compute_sum(grad_output, feature_index), [feature_index], type(operand))
  i3 = broadcast_in_dim(
      compute_sum(multiply(grad_output, centered_operand), feature_index),
      [feature_index], type(operand))
  i4 = multiply(i3, centered_operand)
  i5 = divide(i4, add(variance_bcast, epsilon_bcast))
  i6 = subtract(subtract(i1, i2), i5)

  grad_operand =
      multiply(divide(divide(scale_bcast, stddev), elements_per_feature), i6)
  grad_scale =
      compute_sum(multiply(grad_output, normalized_operand), feature_index)
  grad_offset = compute_sum(grad_output, feature_index)

  return grad_operand, grad_scale, grad_offset

Untuk jenis terkuantisasi, menjalankan dequantize_batch_norm_grad_or_training_quantize(lambda operand, scale, mean, variance, grad_output: batch_norm_grad(operand, scale, mean, variance, grad_output, epsilon, feature_index), operand, scale, mean, variance, grad_output, type(grad_operand), type(grad_scale), type(feature_index)).

Input

Label Nama Jenis Batasan
(I1) operand tensor tipe floating point atau tensor terkuantisasi per-tensor (C1-C3), (C5)
(I2) scale Tensor 1-dimensi dari floating-point atau jenis terkuantisasi per-tensor (C2), (C4), (C5)
(I3) mean Tensor 1-dimensi dari floating-point atau jenis terkuantisasi per-tensor (C2), (C4)
(I4) variance Tensor 1-dimensi dari floating-point atau jenis terkuantisasi per-tensor (C2), (C4)
(I5) grad_output tensor tipe floating point atau tensor terkuantisasi per-tensor (C2), (C3)
(I6) epsilon konstanta jenis f32
(I7) feature_index konstanta jenis si64 (C1), (C5)

Output

Nama Jenis Batasan
grad_operand tensor tipe floating point atau tensor terkuantisasi per-tensor (C2), (C3)
grad_scale Tensor 1-dimensi dari floating-point atau jenis terkuantisasi per-tensor (C2), (C4)
grad_offset Tensor 1-dimensi dari floating-point atau jenis terkuantisasi per-tensor (C2), (C4)

Batasan

  • (C1) 0 <= feature_index < rank(operand).
  • (C2) operand, scale, mean, variance, grad_output, grad_operand, grad_scale dan grad_offset memiliki baseline_element_type yang sama.
  • (C3) operand, grad_output, dan grad_operand memiliki bentuk yang sama.
  • (C4) scale, mean, variance, grad_scale, dan grad_offset memiliki bentuk yang sama.
  • (C5) size(scale) = dim(operand, feature_index).

Contoh

// %operand: [
//            [[1.0, 2.0], [3.0, 4.0]],
//            [[3.0, 4.0], [1.0, 2.0]]
//           ]
// %scale: [1.0, 1.0]
// %mean: [2.0, 3.0]
// %variance: [1.0, 1.0]
// %grad_output: [
//                [[0.1, 0.1], [0.1, 0.1]],
//                [[0.1, 0.1], [0.1, 0.1]]
//               ]
%grad_operand, %grad_scale, %grad_offset =
"stablehlo.batch_norm_grad"(%operand, %scale, %mean, %variance, %grad_output) {
  epsilon = 0.0 : f32,
  feature_index = 2 : i64
} : (tensor<2x2x2xf64>, tensor<2xf64>, tensor<2xf64>, tensor<2xf64>,
     tensor<2x2x2xf64>) -> (tensor<2x2x2xf64>, tensor<2xf64>, tensor<2xf64>)
// %grad_operand: [
//                 [[0.0, 0.0], [0.0, 0.0]],
//                 [[0.0, 0.0], [0.0, 0.0]]
//                ]
// %grad_scale:  [0.0, 0.0]
// %grad_offset: [0.4, 0.4]

batch_norm_inference

Semantik

Menormalkan tensor operand di semua dimensi kecuali untuk feature_index dan menghasilkan tensor result. Secara lebih formal, operasi dapat dinyatakan sebagai dekomposisi untuk operasi StableHLO yang ada menggunakan sintaks Python sebagai berikut:

def batch_norm_inference(operand, scale, offset, mean, variance, epsilon, feature_index):
  # Broadcast inputs to shape(operand)
  scale_bcast = broadcast_in_dim(scale, [feature_index], type(operand))
  offset_bcast = broadcast_in_dim(offset, [feature_index], type(operand))
  mean_bcast = broadcast_in_dim(mean, [feature_index], type(operand))
  variance_bcast = broadcast_in_dim(variance, [feature_index], type(operand))
  epsilon_bcast = broadcast_in_dim(constant(epsilon, element_type(operand)), [],
                                   type(operand))

  # Perform normalization using the provided `mean` and `variance` instead of
  # computing them like `batch_norm_training` does.
  centered_operand = subtract(operand, mean_bcast)
  stddev = sqrt(add(variance_bcast, epsilon_bcast))
  normalized_operand = divide(centered_operand, stddev)
  return add(multiply(scale_bcast, normalized_operand), offset_bcast)

Untuk jenis terkuantisasi, menjalankan dequantize_op_quantize(lambda operand, scale, offset, mean, variance: batch_norm_inference(operand, scale, offset, mean, variance, epsilon, feature_index), operand, scale, offset, mean, variance, type(result)).

Input

Label Nama Jenis Batasan
(I1) operand tensor tipe floating point atau tensor terkuantisasi per-tensor (C1-C7)
(I2) scale Tensor 1-dimensi dari floating-point atau jenis terkuantisasi per-tensor (C2), (C3)
(I3) offset Tensor 1-dimensi dari floating-point atau jenis terkuantisasi per-tensor (C2), (C4)
(I4) mean Tensor 1-dimensi dari floating-point atau jenis terkuantisasi per-tensor (C5)
(I5) variance Tensor 1-dimensi dari floating-point atau jenis terkuantisasi per-tensor (C2), (C6)
(I6) epsilon konstanta jenis f32
(I7) feature_index konstanta jenis si64 (C1), C3-C6

Output

Nama Jenis Batasan
result tensor tipe floating point atau tensor terkuantisasi per-tensor (C2), (C7)

Batasan

  • (C1) 0 <= feature_index < rank(operand).
  • (C2) operand, scale, offset, mean, variance, dan result memiliki baseline_element_type yang sama.
  • (C3) size(scale) = dim(operand, feature_index).
  • (C4) size(offset) = dim(operand, feature_index).
  • (C5) size(mean) = dim(operand, feature_index).
  • (C6) size(variance) = dim(operand, feature_index).
  • (C7) baseline_type(operand) = baseline_type(result).

Contoh

// %operand: [
//            [[1.0, 2.0], [3.0, 4.0]],
//            [[3.0, 4.0], [1.0, 2.0]]
//           ]
// %scale: [1.0, 1.0]
// %offset: [1.0, 1.0]
// %mean: [2.0, 3.0]
// %variance: [1.0, 1.0]
%result = "stablehlo.batch_norm_inference"(%operand, %scale, %offset, %mean, %variance) {
  epsilon = 0.0 : f32,
  feature_index = 2 : i64
} : (tensor<2x2x2xf64>, tensor<2xf64>, tensor<2xf64>, tensor<2xf64>, tensor<2xf64>) -> tensor<2x2x2xf64>
// %result: [
//           [[0.0, 0.0], [2.0, 2.0]],
//           [[2.0, 2.0], [0.0, 0.0]]
//          ]

batch_norm_training

Semantik

Menghitung rata-rata dan varians di semua dimensi kecuali untuk feature_index dan menormalisasi tensor operand yang menghasilkan output, batch_mean dan tensor batch_var. Secara lebih formal, operasi ini dapat dinyatakan sebagai dekomposisi ke operasi StableHLO yang ada menggunakan sintaksis Python sebagai berikut ini:

def compute_mean(operand, feature_index):
  (sum,) = reduce(
      inputs=[operand],
      init_values=[constant(0, element_type(operand))],
      dimensions=[i for i in range(rank(operand)) if i != feature_index],
      body=lambda x, y: add(x, y))
  divisor = constant(size(operand) / dim(operand, feature_index),
                     element_type(operand))
  divisor_bcast = broadcast_in_dim(divisor, [], type(sum))
  return divide(sum, divisor_bcast)

def compute_variance(operand, feature_index):
  mean = compute_mean(operand, feature_index)
  mean_bcast = broadcast_in_dim(mean, [feature_index], type(operand))
  centered_operand = subtract(operand, mean_bcast)
  return compute_mean(mul(centered_operand, centered_operand), feature_index)

def batch_norm_training(operand, scale, offset, epsilon, feature_index):
  mean = compute_mean(operand, feature_index)
  variance = compute_variance(operand, feature_index)
  return batch_norm_inference(operand, scale, offset, mean, variance, epsilon,
                              feature_index),
         mean, variance

Untuk jenis terkuantisasi, menjalankan dequantize_batch_norm_grad_or_training_quantize(lambda operand, scale, offset: batch_norm_training(operand, scale, offset, epsilon, feature_index), operand, scale, offset, type(output), type(batch_mean), type(batch_var)).

Input

Label Nama Jenis Batasan
(I1) operand tensor tipe floating point atau tensor terkuantisasi per-tensor (C1)
(I2) scale Tensor 1-dimensi dari floating-point atau per-tensor terkuantisasi (C2), (C3)
(I3) offset Tensor 1-dimensi dari floating-point atau per-tensor terkuantisasi (C2), (C4)
(I4) epsilon konstanta jenis f32 (C1), C3-C6
(I5) feature_index konstanta jenis si64 (C1), C3-C6

Output

Nama Jenis Batasan
output tensor tipe floating point atau tensor terkuantisasi per-tensor (C7)
batch_mean Tensor 1-dimensi dari floating-point atau per-tensor terkuantisasi (C2), (C5)
batch_var Tensor 1-dimensi dari floating-point atau per-tensor terkuantisasi (C2), (C6)

Batasan

  • (C1) 0 <= feature_index < rank(operand).
  • (C2) operand, scale, offset, batch_mean, batch_var, dan output memiliki baseline_element_type yang sama.
  • (C3) size(scale) = dim(operand, feature_index).
  • (C4) size(offset) = dim(operand, feature_index).
  • (C5) size(batch_mean) = dim(operand, feature_index).
  • (C6) size(batch_var) = dim(operand, feature_index).
  • (C7) baseline_type(output) = baseline_type(operand).

Contoh

// %operand: [
//            [[1.0, 2.0], [3.0, 4.0]],
//            [[3.0, 4.0], [1.0, 2.0]]
//           ]
// %scale: [1.0, 1.0]
// %offset: [1.0, 1.0]
%output, %batch_mean, %batch_var = "stablehlo.batch_norm_training"(%operand, %scale, %offset) {
  epsilon = 0.0 : f32,
  feature_index = 2 : i64
} : (tensor<2x2x2xf64>, tensor<2xf64>, tensor<2xf64>) ->
    (tensor<2x2x2xf64>, tensor<2xf64>, tensor<2xf64>)
// %output: [
//           [[0.0, 0.0], [2.0, 2.0]],
//           [[2.0, 2.0], [0.0, 0.0]]
//          ]
// %batch_mean: [2.0, 3.0]
// %batch_var: [1.0, 1.0]

bitcast_convert

Semantik

Melakukan operasi bitcast pada tensor operand dan menghasilkan tensor result dengan bit dari seluruh tensor operand diinterpretasikan ulang menggunakan dari tensor result.

Secara lebih formal, mengingat E = element_type(operand), E' = element_type(result), dan R = rank(operand):

  • Jika num_bits(E') < num_bits(E), bits(result[i0, ..., iR-1, :]) = bits(operand[i0, ..., iR-1]).
  • Jika num_bits(E') > num_bits(E), bits(result[i0, ..., iR-2]) = bits(operand[i0, ..., iR-2, :]).
  • Jika num_bits(E') = num_bits(E), bits(result[i0, ..., iR-1]) = bits(operand[i0, ..., iR-1]).

bits menampilkan representasi dalam memori dari nilai tertentu, beserta perilakunya adalah implementasi-ditentukan karena representasi yang tepat dari tensor yang ditentukan implementasinya, dan representasi yang tepat dari tipe elemen adalah serta definisi implementasi.

Input

Label Nama Jenis Batasan
(I1) operand tensor atau quantized tensor (C1-C2)

Output

Nama Jenis Batasan
result tensor atau quantized tensor (C1-C2)

Batasan

  • (C1) Dengan mempertimbangkan E = is_quantized(operand) ? storage_type(operand) : element_type(operand), E' = is_quantized(result) ? storage_type(result) : element_type(result), dan R = rank(operand):
    • Jika num_bits(E') = num_bits(E), shape(result) = shape(operand).
    • Jika num_bits(E') < num_bits(E):
    • rank(result) = R + 1.
    • dim(result, i) = dim(operand, i) untuk semua 0 <= i < R.
    • dim(result, R) * num_bits(E') = num_bits(E).
    • Jika num_bits(E') > num_bits(E):
    • rank(result) = R - 1.
    • dim(result, i) = dim(operand, i) untuk semua 0 <= i < R.
    • dim(operand, R - 1) * num_bits(E) = num_bits(E').
  • (C2) Jika is_complex(operand) or is_complex(result), maka is_complex(operand) and is_complex(result).

Contoh

// %operand: 0x0123456789ABCDEF
%result = "stablehlo.bitcast_convert"(%operand) : (tensor<f64>) -> tensor<4xf16>
// %result: [0xCDEF, 0x89AB, 0x4567, 0x0123] // little-endian representation

Contoh Lainnya

broadcast_in_dim

Semantik

Memperluas dimensi dan/atau peringkat tensor input dengan menduplikasi data pada tensor operand dan menghasilkan tensor result. Secara lebih formal, result[result_index] = operand[operand_index] untuk semua d di axes(operand):

  • operand_index[d] = 0 jika dim(operand, d) = 1.
  • operand_index[d] = result_index[broadcast_dimensions[d]] sebaliknya.

Input

Label Nama Jenis Batasan
(I1) operand tensor atau quantized tensor (C1-C2), (C5-C6)
(I2) broadcast_dimensions Konstanta tensor 1 dimensi jenis si64 (C2-C6)

Output

Nama Jenis Batasan
result tensor atau quantized tensor (C1), (C3), (C5-C6)

Batasan

  • (C1) element_type(result) diberikan oleh:
    • element_type(operand), jika !is_per_axis_quantized(operand).
    • element_type(operand) kecuali quantization_dimension(operand), scales(operand), dan zero_points(operand) mungkin berbeda dengan quantization_dimension(result), scales(result), dan zero_points(result) resp., jika tidak.
  • (C2) size(broadcast_dimensions) = rank(operand).
  • (C3) 0 <= broadcast_dimensions < rank(result).
  • (C4) is_unique(broadcast_dimensions).
  • (C5) Untuk semua d di axes(operand):
    • dim(operand, d) = 1 atau
    • dim(operand, d) = dim(result, broadcast_dimensions[d]).
  • (C6) Jika is_per_axis_quantized(result):
    • quantization_dimension(result) = broadcast_dimensions[quantization_dimension(operand)].
    • Jika dim(operand, quantization_dimension(operand)) = 1, maka scales(result)[i] = scales(operand)[0] and zero_points(result)[i] = zero_points(operand)[0] for i in range(dim(result, quantization_dimension(result))).

Contoh

// %operand: [
//            [1, 2, 3]
//           ]
%result = "stablehlo.broadcast_in_dim"(%operand) {
  broadcast_dimensions = array<i64: 2, 1>
} : (tensor<1x3xi32>) -> tensor<2x3x2xi32>
// %result: [
//            [
//             [1, 1],
//             [2, 2],
//             [3, 3]
//            ],
//            [
//             [1, 1],
//             [2, 2],
//             [3, 3]
//            ]
//          ]

Contoh Lainnya

casing

Semantik

Menghasilkan output dari mengeksekusi tepat satu fungsi dari branches bergantung pada nilai index. Secara lebih formal, result = selected_branch() dalam hal ini:

  • selected_branch = branches[index] jika 0 <= index < size(branches).
  • selected_branch = branches[-1] sebaliknya.

Input

Label Nama Jenis Batasan
(I1) index Tensor 0-dimensi jenis si32
(I2) branches jumlah fungsi variadik (C1-C4)

Output

Nama Jenis Batasan
results jumlah tensor, tensor atau token terkuantisasi (C4)

Batasan

  • (C1) 0 < size(branches).
  • (C2) input_types(branches...) = [].
  • (C3) same(output_types(branches...)).
  • (C4) type(results...) = output_types(branches[0]).

Contoh

// %index: -1
// %result_branch0: [0, 0]
// %result_branch1: [1, 1]
%result0, %result1 = "stablehlo.case"(%index) ({
  "stablehlo.return"(%result_branch0, %result_branch0) : (tensor<2xi64>, tensor<2xi64>) -> ()
}, {
  "stablehlo.return"(%result_branch1, %result_branch1) : (tensor<2xi64>, tensor<2xi64>) -> ()
}) : (tensor<i32>) -> (tensor<2xi64>, tensor<2xi64>)
// %result0: [1, 1]
// %result1: [1, 1]

Contoh Lainnya

Cbrt

Semantik

Melakukan operasi root kubik berdasarkan elemen pada tensor operand dan menghasilkan Tensor result. Bergantung pada jenis elemen, melakukan hal berikut:

  • Untuk float: rootn(x, 3) dari IEEE-754.
  • Untuk bilangan kompleks: akar kubik kompleks.
  • Untuk jenis terkuantisasi: dequantize_op_quantize(cbrt, operand, type(result))

Input

Label Nama Jenis Batasan
(I1) operand tensor floating point atau tipe kompleks atau tensor terkuantisasi per-tensor (C1)

Output

Nama Jenis Batasan
result tensor floating point atau tipe kompleks atau tensor terkuantisasi per-tensor (C1)

Batasan

  • (C1) baseline_type(operand) = baseline_type(result).

Contoh

// %operand: [0.0, 1.0, 8.0, 27.0]
%result = "stablehlo.cbrt"(%operand) : (tensor<4xf64>) -> tensor<4xf64>
// %result: [0.0, 1.0, 2.0, 3.0]

Contoh Lainnya

ceil

Semantik

Melakukan ceil berbasis elemen dari tensor operand dan menghasilkan tensor result. Menerapkan operasi roundToIntegralTowardPositive dari IEEE-754 spesifikasi pendukung. Untuk jenis terkuantisasi, menjalankan dequantize_op_quantize(ceil, operand, type(result)).

Input

Label Nama Jenis Batasan
(I1) operand tensor tipe floating point atau tensor terkuantisasi per-tensor (C1)

Output

Nama Jenis Batasan
result tensor tipe floating point atau tensor terkuantisasi per-tensor (C1)

Batasan

  • (C1) baseline_type(operand) = baseline_type(result).

Contoh

// %operand: [-0.8166, -0.2530, 0.2530, 0.8166, 2.0]
%result = "stablehlo.ceil"(%operand) : (tensor<5xf32>) -> tensor<5xf32>
// %result: [-0.0, -0.0, 1.0, 1.0, 2.0]

Contoh Lainnya

Cholesky

Semantik

Menghitung dekomposisi Cholesky dari sekumpulan matriks.

Secara lebih formal, untuk semua i di index_space(result), result[i0, ..., iR-3, :, :] adalah dekomposisi Cholesky dari a[i0, ..., iR-3, :, :], dalam bentuk segitiga rendah (jika lower adalah true) atau matriks segitiga atas (jika lower adalah false). Nilai {i>output<i} dalam segitiga yang berlawanan, yaitu segitiga atas yang tegas atau segitiga bawah yang ketat, yang juga ditentukan oleh implementasi.

Jika ada i yang matriks inputnya bukan positif-definit Hermitian matriks, maka perilakunya tidak terdefinisi.

Untuk jenis terkuantisasi, menjalankan dequantize_op_quantize(lambda operand: cholesky(operand, lower), a, type(result)).

Input

Label Nama Jenis Batasan
(I1) a tensor floating point atau tipe kompleks atau tensor terkuantisasi per-tensor (C1-C3)
(I2) lower Konstanta tensor 0 dimensi jenis i1

Output

Nama Jenis Batasan
result tensor floating point atau tipe kompleks atau tensor terkuantisasi per-tensor (C1)

Batasan

  • (C1) baseline_type(a) = baseline_type(result).
  • (C2) 2 <= rank(a).
  • (C3) dim(a, -2) = dim(a, -1).

Contoh

// %a: [
//      [1.0, 2.0, 3.0],
//      [2.0, 20.0, 26.0],
//      [3.0, 26.0, 70.0]
//     ]
%result = "stablehlo.cholesky"(%a) {
  lower = true
} : (tensor<3x3xf32>) -> tensor<3x3xf64>
// %result: [
//           [1.0, 0.0, 0.0],
//           [2.0, 4.0, 0.0],
//           [3.0, 5.0, 6.0]
//          ]

klem

Semantik

Menjepit setiap elemen tensor operand di antara nilai minimum dan maksimum dan menghasilkan tensor result. Secara lebih formal, result[result_index] = minimum(maximum(operand[result_index], min_element), max_element), dengan min_element = rank(min) = 0 ? min[] : min[result_index], max_element = rank(max) = 0 ? max[] : max[result_index]. Untuk jenis terkuantisasi, menjalankan dequantize_op_quantize(clamp, min, operand, max, type(result)).

Memaksakan urutan pada bilangan kompleks melibatkan semantik yang mengejutkan, jadi di masa mendatang kami berencana menghapus dukungan untuk bilangan kompleks untuk operasi ini (#560).

Input

Label Nama Jenis Batasan
(I1) min tensor atau tensor terkuantisasi per-tensor (C1), C3
(I2) operand tensor atau tensor terkuantisasi per-tensor (C1-C4)
(I3) max tensor atau tensor terkuantisasi per-tensor (C2), (C3)

Output

Nama Jenis Batasan
result tensor atau tensor terkuantisasi per-tensor (C4)

Batasan

  • (C1) rank(min) = 0 or shape(min) = shape(operand).
  • (C2) rank(max) = 0 or shape(max) = shape(operand).
  • (C3) baseline_element_type(min) = baseline_element_type(operand) = baseline_element_type(max).
  • (C4) baseline_type(operand) = baseline_type(result).

Contoh

// %min: [5, 10, 15]
// %operand: [3, 13, 23]
// %max: [10, 15, 20]
%result = "stablehlo.clamp"(%min, %operand, %max) : (tensor<3xi32>, tensor<3xi32>, tensor<3xi32>) -> tensor<3xi32>
// %result: [5, 13, 20]

Contoh Lainnya

collective_broadcast

Semantik

Dalam setiap grup proses di grid proses StableHLO, kirim nilai proses Tensor operand dari proses sumber ke proses target dan menghasilkan Tensor result.

Operasi tersebut membagi grid proses StableHLO menjadi process_groups yang didefinisikan sebagai berikut:

  • cross_replica(replica_groups) jika channel_id <= 0.
  • cross_partition(replica_groups) jika channel_id > 0.

Setelah itu, result@process diberikan oleh:

  • operand@process_groups[i, 0] jika ada i sehingga prosesnya menjadi di process_groups[i].
  • broadcast_in_dim(constant(is_quantized(result) ? quantize(0, element_type(result)) : 0, element_type(result)), [], type(result)) sebaliknya.

Input

Label Nama Jenis Batasan
(I1) operand tensor atau tensor terkuantisasi per-tensor (C3)
(I2) replica_groups jumlah variadik konstanta tensor 1 dimensi jenis si64 (C1), C2
(I3) channel_id konstanta jenis si64

Output

Nama Jenis Batasan
result tensor atau tensor terkuantisasi per-tensor (C3)

Batasan

  • (C1) is_unique(replica_groups).
  • (C2) 0 <= replica_groups < N dengan N ditentukan sebagai:
    • num_replicas jika cross_replica digunakan.
    • num_partitions jika cross_partition digunakan.
  • (C3) type(result) = type(operand).

Contoh

// num_replicas: 4
// num_partitions: 1
// %operand@(0, 0): [[1, 2]]
// %operand@(1, 0): [[3, 4]]
// %operand@(2, 0): [[5, 6]]
// %operand@(3, 0): [[7, 8]]
%result = "stablehlo.collective_broadcast"(%operand) {
  replica_groups = dense<[[2, 1]]> : tensor<1x2xi64>,
  channel_handle = #stablehlo.channel_handle<handle = 0, type = 0>
} : (tensor1x2xi64>) -> tensor<1x2xi64>
// %result@(0, 0): [[0, 0]]
// %result@(1, 0): [[5, 6]]
// %result@(2, 0): [[5, 6]]
// %result@(3, 0): [[0, 0]]

collective_permute

Semantik

Dalam setiap grup proses pada grid proses StableHLO, mengirimkan nilai Tensor operand dari proses sumber ke proses target dan menghasilkan Tensor result.

Operasi tersebut membagi grid proses StableHLO menjadi process_groups yang didefinisikan sebagai berikut:

  • cross_replica(source_target_pairs) jika channel_id <= 0.
  • cross_partition(source_target_pairs) jika channel_id > 0.

Setelah itu, result@process diberikan oleh:

  • operand@process_groups[i, 0], jika ada i sehingga process_groups[i, 1] = process.
  • broadcast_in_dim(constant(is_quantized(result) ? quantize(0, element_type(result)) : 0, element_type(result)), [], type(result)) sebaliknya.

Input

Label Nama Jenis Batasan
(I1) operand tensor atau tensor terkuantisasi per-tensor (C5)
(I2) source_target_pairs Konstanta tensor 2 dimensi jenis si64 (C1-C4)
(I3) channel_id konstanta jenis si64

Output

Nama Jenis Batasan
result tensor atau tensor terkuantisasi per-tensor (C1)

Batasan

  • (C1) dim(source_target_pairs, 1) = 2.
  • (C2) is_unique(source_target_pairs[:, 0]).
  • (C3) is_unique(source_target_pairs[:, 1]).
  • (C4) 0 <= source_target_pairs < N, dengan N ditentukan sebagai:
    • num_replicas jika cross_replica digunakan.
    • num_partitions jika cross_partition digunakan.
  • (C5) type(result) = type(operand).

Contoh

// num_replicas: 3
// num_partitions: 1
// %operand@(0, 0): [[1, 2], [3, 4]]
// %operand@(1, 0): [[5, 6], [7, 8]]
// %operand@(2, 0): [[9, 10], [11, 12]]
%result = "stablehlo.collective_permute"(%operand) {
  source_target_pairs = dense<[[0, 1], [1, 2]]> : tensor<2x2xi64>,
  channel_handle = #stablehlo.channel_handle<handle = 0, type = 0>
} : (tensor<2x2xi64>) -> tensor<2x2xi64>
//
// %result@(0, 0): [[0, 0], [0, 0]]
// %result@(1, 0): [[1, 2], [3, 4]]
// %result@(2, 0): [[5, 6], [7, 8]]

Contoh Lainnya

bandingkan

Semantik

Melakukan perbandingan tensor lhs dan rhs berdasarkan elemen sesuai dengan comparison_direction dan compare_type, serta menghasilkan tensor result.

Nilai comparison_direction dan compare_type memiliki hal berikut semantik:

Untuk jenis elemen boolean dan bilangan bulat:

  • EQ: lhs = rhs.
  • NE: lhs != rhs.
  • GE: lhs >= rhs.
  • GT: lhs > rhs.
  • LE: lhs <= rhs.
  • LT: lhs < rhs.

Untuk jenis elemen floating point dengan compare_type = FLOAT, op akan mengimplementasikan operasi IEEE-754 berikut:

  • EQ: compareQuietEqual.
  • NE: compareQuietNotEqual.
  • GE: compareQuietGreaterEqual.
  • GT: compareQuietGreater.
  • LE: compareQuietLessEqual.
  • LT: compareQuietLess.

Untuk jenis elemen floating point dengan compare_type = TOTALORDER, op menggunakan kombinasi operasi totalOrder dan compareQuietEqual dari IEEE-754.

Untuk jenis elemen kompleks, perbandingan leksikografis pasangan (real, imag) dijalankan menggunakan comparison_direction dan compare_type yang disediakan. Memaksakan urutan pada bilangan kompleks melibatkan semantik yang mengejutkan, jadi di masa mendatang kami berencana menghapus dukungan untuk bilangan kompleks jika comparison_direction adalah GE, GT, LE, atau LT (#560).

Untuk jenis terkuantisasi. menjalankan dequantize_compare(lhs, rhs, comparison_direction).

Input

Label Nama Jenis Batasan
(I1) lhs tensor atau tensor terkuantisasi per-tensor (C1-C3)
(I2) rhs tensor atau tensor terkuantisasi per-tensor (C1-C2)
(I3) comparison_direction enum EQ, NE, GE, GT, LE, dan LT
(I4) compare_type enum FLOAT, TOTALORDER, SIGNED, dan UNSIGNED (C3)

Output

Nama Jenis Batasan
result tensor jenis boolean (C2)

Batasan

  • (C1) baseline_element_type(lhs) = baseline_element_type(rhs).
  • (C2) shape(lhs) = shape(rhs) = shape(result).
  • (C3) compare_type didefinisikan sebagai:
    • SIGNED jika is_signed_integer(element_type(lhs)).
    • UNSIGNED jika is_unsigned_integer(element_type(lhs)) or is_boolean(element_type(lhs)).
    • FLOAT atau TOTALORDER jika is_float(element_type(lhs)).
    • FLOAT jika is_complex(element_type(lhs)).

Contoh

// %lhs: [1.0, 3.0]
// %rhs: [1.1, 2.9]
%result = "stablehlo.compare"(%lhs, %rhs) {
  comparison_direction = #stablehlo<comparison_direction LT>,
  compare_type = #stablehlo<comparison_type FLOAT>
} : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xi1>
// %result: [true, false]

Contoh Lainnya

kompleks

Semantik

Melakukan konversi {i>element<i}-{i>wise<i} ke nilai kompleks dari pasangan real dan nilai imajiner, lhs dan rhs, serta menghasilkan tensor result.

Input

Label Nama Jenis Batasan
(I1) lhs tensor jenis f32 atau f64 (C1-C3)
(I2) rhs tensor jenis f32 atau f64 (C1)

Output

Nama Jenis Batasan
result tensor tipe kompleks (C2), (C3)

Batasan

  • (C1) type(lhs) = type(rhs).
  • (C2) shape(result) = shape(lhs).
  • (C3) element_type(result) memiliki jenis complex<E>, dengan E = element_type(lhs).

Contoh

// %lhs: [1.0, 3.0]
// %rhs: [2.0, 4.0]
%result = "stablehlo.complex"(%lhs, %rhs) : (tensor<2xf64>, tensor<2xf64>) -> tensor<2xcomplex<f64>>
// %result: [(1.0, 2.0), (3.0, 4.0)]

Contoh Lainnya

gabungan

Semantik

Mengenkapsulasi operasi yang terdiri (tersusun) dari operasi StableHLO lainnya, mengambil inputs dan composite_attributes, serta memproduksi results. Tujuan semantik op diimplementasikan oleh atribut decomposition. Tujuan op composite dapat diganti dengan dekomposisinya tanpa mengubah program semantik. Dalam kasus saat penyisipan dekomposisi tidak memberikan hasil semantik operasi, sebaiknya gunakan custom_call.

Kolom version (default-nya adalah 0) digunakan untuk menunjukkan kapan kolom semantik berubah.

Input

Label Nama Jenis
(I1) inputs jumlah nilai variadis
(I2) name konstanta jenis string
(I3) composite_attributes kamus atribut
(I4) decomposition konstanta jenis string
(I5) version konstanta jenis si32

Output

Nama Jenis
results jumlah nilai variadis

Batasan

  • (C1) is_namespaced_op_name(name)
  • (C2) is_defined_in_parent_scope(decomposition)
  • (C3) types(inputs...) == input_types(decomposition)
  • (C4) types(results...) == output_types(decomposition)

Contoh

%results = "stablehlo.composite"(%input0, %input1) {
  name = "my_namespace.my_op",
  composite_attributes = {
    my_attribute = "my_value"
  },
  decomposition = @my_op,
  version = 1 : i32
} : (tensor<f32>, tensor<f32>) -> tensor<f32>

Contoh Lainnya

concatenate

Semantik

Menggabungkan inputs di sepanjang dimensi dimension dalam urutan yang sama seperti yang ditentukan argumen dan menghasilkan tensor result. Secara lebih formal, result[i0, ..., id, ..., iR-1] = inputs[k][i0, ..., kd, ..., iR-1], dengan:

  1. id = d0 + ... + dk-1 + kd.
  2. d sama dengan dimension, dan d0, ... adalah ukuran dimensi ke-d dari inputs.

Input

Label Nama Jenis Batasan
(I1) inputs jumlah tensor atau tensor terkuantisasi per-tensor (C1-C6)
(I2) dimension konstanta jenis si64 (C2), (C4), (C6)

Output

Nama Jenis Batasan
result tensor atau tensor terkuantisasi per-tensor (C5-C6)

Batasan

  • (C1) same(element_type(inputs...)).
  • (C2) same(shape(inputs...)) kecuali untuk dim(inputs..., dimension).
  • (C3) 0 < size(inputs).
  • (C4) 0 <= dimension < rank(inputs[0]).
  • (C5) element_type(result) = element_type(inputs[0]).
  • (C6) shape(result) = shape(inputs[0]) kecuali untuk:
    • dim(result, dimension) = dim(inputs[0], dimension) + ....

Contoh

// %input0: [[1, 2], [3, 4], [5, 6]]
// %input1: [[7, 8]]
%result = "stablehlo.concatenate"(%input0, %input1) {
  dimension = 0 : i64
} : (tensor<3x2xi64>, tensor<1x2xi64>) -> tensor<4x2xi64>
// %result: [[1, 2], [3, 4], [5, 6], [7, 8]]

Contoh Lainnya

konstanta

Semantik

Menghasilkan tensor output dari value yang konstan.

Input

Label Nama Jenis Batasan
(I1) value konstanta (C1)

Output

Nama Jenis Batasan
output tensor atau quantized tensor (C1)

Batasan

  • (C1) type(value) = type(output).

Contoh

%output = "stablehlo.constant"() {
  value = dense<[[0.0, 1.0], [2.0, 3.0]]> : tensor<2x2xf32>
} : () -> tensor<2x2xf32>
// %output: [[0.0, 1.0], [2.0, 3.0]]

Contoh Lainnya

melakukan konversi

Semantik

Melakukan konversi berbasis elemen dari satu jenis elemen ke jenis elemen lainnya pada Tensor operand dan menghasilkan tensor result.

Untuk konversi boolean-to-any-supported-type, nilai false adalah dikonversi menjadi nol, dan nilai true dikonversi menjadi satu. Sebagai any-supported-type-to-boolean, nilai nol dikonversi menjadi false, dan nilai selain nol dikonversi menjadi true. Lihat di bawah untuk mengetahui cara bekerja untuk tipe-tipe yang kompleks.

Untuk konversi yang melibatkan integer-ke-bilangan bulat, bilangan bulat-ke-floating-point atau floating-point-to-floating-point, jika nilai sumber bisa tepat direpresentasikan dalam jenis tujuan, nilai hasil adalah nilai persis merepresentasinya. Jika tidak, perilaku akan ditentukan nanti (#180).

Untuk konversi yang melibatkan floating-point-to-integer, bagian pecahannya adalah terpotong. Jika nilai yang terpotong tidak dapat ditampilkan dalam jenis tujuan, perilakunya akan ditentukan nanti (#180).

Konversi yang melibatkan kompleks-ke-kompleks mengikuti perilaku yang sama dari konversi floating-point-to-floating-point untuk mengonversi real time dan bagian imajiner.

Untuk konversi complex-to-any-other-type dan complex-to-any-other-type, nilai imajiner sumber diabaikan atau nilai imajiner tujuannya adalah masing-masing menjadi nol. Konversi dari bagian riil mengikuti konversi floating point.

Pada prinsipnya, operasi ini dapat mengekspresikan dekuantisasi (konversi dari tensor terkuantisasi ke tensor reguler), kuantisasi (konversi dari tensor ke tensor terkuantisasi) dan rekuantisasi (konversi antara tensor), tetapi saat ini kami memiliki operasi khusus untuk itu - uniform_dequantize untuk kasus penggunaan pertama dan uniform_quantize untuk kasus penggunaan pertama kedua dan ketiga. Di masa mendatang, kedua operasi ini dapat digabungkan ke dalam convert (#1576).

Input

Label Nama Jenis Batasan
(I1) operand Tensor (C1)

Output

Nama Jenis Batasan
result Tensor (C1)

Batasan

  • (C1) shape(operand) = shape(result).

Contoh

// %operand: [-1, 0, 1]
%result = "stablehlo.convert"(%operand) : (tensor<3xi64>) -> tensor<3xcomplex<f64>>
// %result: [(-1.0, 0.0), (0.0, 0.0), (1.0, 0.0)]

Contoh Lainnya

konvolusi

Semantik

Menghitung produk titik antara jendela lhs dan irisan rhs serta menghasilkan result. Diagram berikut menunjukkan cara penghitungan elemen pada result lhs dan rhs menggunakan contoh konkret.

konvolusi

Secara lebih formal, pertimbangkan framing ulang input berikut dalam hal lhs agar dapat mengekspresikan jendela lhs:

  • lhs_window_dimensions = lhs_shape(dim(lhs, input_batch_dimension), dim(rhs, kernel_spatial_dimensions), dim(lhs, input_feature_dimension)).
  • lhs_window_strides = lhs_shape(1, window_strides, 1).
  • lhs_padding = lhs_shape([0, 0], padding, [0, 0]).
  • lhs_base_dilations = lhs_shape(1, lhs_dilation, 1).
  • lhs_window_dilations = lhs_shape(1, rhs_dilation, 1).

Penyesuaian frame ini menggunakan fungsi bantuan berikut:

  • lhs_shape(n, hw, c) = permute([n] + hw + [c], [input_batch_dimension] + input_spatial_dimensions + [input_feature_dimension]).
  • result_shape(n1, hw, c1) = permute([n1] + hw + [c1], [output_batch_dimension] + output_spatial_dimensions + [output_feature_dimension]).
  • permute([j0, j1, ..., jR-1], permutation) = [i0, i1, ..., iR-1] dengan j[d] = i[permutation[d]].

Jika feature_group_count = 1 dan batch_group_count = 1, maka untuk semua output_spatial_index di index_space(dim(result, output_spatial_dimensions...)), result[result_shape(:, output_spatial_index, :)] = dot_product dengan:

  • padding_value = constant(0, element_type(lhs)).
  • padded_lhs = pad(lhs, padding_value, lhs_padding[:, 0], lhs_padding[:, 1], lhs_base_dilations - 1).
  • lhs_window_start = lhs_shape(0, output_spatial_index, 0) * lhs_window_strides.
  • lhs_window = slice(padded_lhs, lhs_window_start, lhs_window_start + lhs_window_dimensions, lhs_window_dilations).
  • reversed_lhs_window = reverse(lhs_window, [input_spatial_dimensions[dim] for dim in range(size(window_reversal)) if window_reversal[dim] = true]). Fitur ini tampaknya tidak digunakan, jadi di masa mendatang kami berencana untuk menghapusnya (#1181).
  • dot_product = dot_general(reversed_lhs_window, rhs, lhs_batching_dimensions=[], lhs_contracting_dimensions=input_spatial_dimensions + [input_feature_dimension], rhs_batching_dimensions=[], rhs_contracting_dimensions=kernel_spatial_dimensions + [kernel_input_feature_dimension]).

Jika feature_group_count > 1:

  • lhses = split(lhs, feature_group_count, input_feature_dimension).
  • rhses = split(rhs, feature_group_count, kernel_output_feature_dimension).
  • results... = convolution(lhses..., rhses..., ..., feature_group_count=1, ...).
  • result = concatenate(results, output_feature_dimension).

Jika batch_group_count > 1:

  • lhses = split(lhs, batch_group_count, input_batch_dimension).
  • rhses = split(rhs, batch_group_count, kernel_output_feature_dimension).
  • results... = convolution(lhses..., rhses..., ..., batch_group_count=1, ...).
  • result = concatenate(results, output_feature_dimension).

Untuk jenis terkuantisasi, menjalankan dequantize_op_quantize( lambda lhs, rhs: convolution(lhs, rhs, window_strides, padding, lhs_dilation, rhs_dilation, window_reversal, input_batch_dimension, input_feature_dimension, input_spatial_dimensions, kernel_input_feature_dimension, kernel_output_feature_dimension, kernel_spatial_dimensions, output_batch_dimension, output_feature_dimension, output_spatial_dimensions, feature_group_count, batch_group_count, precision_config), lhs, rhs, type(result)).

Untuk jenis terkuantisasi hybrid, menjalankan hybrid_dequantize_then_op( lambda lhs, rhs: convolution(lhs, rhs, window_strides, padding, lhs_dilation, rhs_dilation, window_reversal, input_batch_dimension, input_feature_dimension, input_spatial_dimensions, kernel_input_feature_dimension, kernel_output_feature_dimension, kernel_spatial_dimensions, output_batch_dimension, output_feature_dimension, output_spatial_dimensions, feature_group_count, batch_group_count, precision_config), lhs, rhs).

Input

Label Nama Jenis Batasan
(I1) lhs tensor atau tensor terkuantisasi per-tensor (C1), (C10-C11), (C14) (C25), (C27-C28), (C31-C32), (C34)
(I2) rhs tensor atau quantized tensor (C1), (C14-C16), (C25), (C27-C29), (C31-C34)
(I3) window_strides Konstanta tensor 1 dimensi jenis si64 (C2-C3), (C25)
(I4) padding Konstanta tensor 2 dimensi jenis si64 (C4), (C25)
(I5) lhs_dilation Konstanta tensor 1 dimensi jenis si64 (C5-C6), (C25)
(I6) rhs_dilation Konstanta tensor 1 dimensi jenis si64 (C7-C8), (C25)
(I7) window_reversal Konstanta tensor 1 dimensi jenis i1 (C9)
(I8) input_batch_dimension konstanta jenis si64 (C10), (C13), (C25)
(I9) input_feature_dimension konstanta jenis si64 (C11), (C13-C14)
(I10) input_spatial_dimensions Konstanta tensor 1 dimensi jenis si64 (C12), (C13), (C25)
(I11) kernel_input_feature_dimension konstanta jenis si64 (C14), (C18)
(I12) kernel_output_feature_dimension konstanta jenis si64 (C15-C16), (C18), (C25), (C29)
(I13) kernel_spatial_dimensions Konstanta tensor 1 dimensi jenis si64 (C17-C18), (C25)
(I14) output_batch_dimension konstanta jenis si64 (C20), (C25)
(I15) output_feature_dimension konstanta jenis si64 (C20), (C25), (C30)
(I16) output_spatial_dimensions Konstanta tensor 1 dimensi jenis si64 (C19-C20), (C25)
(I17) feature_group_count konstanta jenis si64 (C11), (C14), (C16), (C21), (C23)
(I18) batch_group_count konstanta jenis si64 (C10), (C15), (C22), (C23), (C25)
(I19) precision_config jumlah variadik enum DEFAULT, HIGH, dan HIGHEST (C24)

Output

Nama Jenis Batasan
result tensor atau quantized tensor (C25-C28), (C30), (C32-34)

Batasan

  • (C1) N = rank(lhs) = rank(rhs).
  • (C2) size(window_strides) = N - 2.
  • (C3) 0 < window_strides.
  • (C4) shape(padding) = [N - 2, 2].
  • (C5) size(lhs_dilation) = N - 2.
  • (C6) 0 < lhs_dilation.
  • (C7) size(rhs_dilation) = N - 2.
  • (C8) 0 < rhs_dilation.
  • (C9) size(window_reversal) = N - 2.
  • (C10) dim(lhs, input_batch_dimension) % batch_group_count = 0.
  • (C11) dim(lhs, input_feature_dimension) % feature_group_count = 0.
  • (C12) size(input_spatial_dimensions) = N - 2.
  • (C13) Dengan mempertimbangkan input_dimensions = [input_batch_dimension] + input_spatial_dimensions + [input_feature_dimension]:
    • is_unique(input_dimensions).
    • 0 <= input_dimensions < N.
  • (C14) dim(rhs, kernel_input_feature_dimension) = dim(lhs, input_feature_dimension) / feature_group_count.
  • (C15) dim(rhs, kernel_output_feature_dimension) % batch_group_count = 0.
  • (C16) dim(rhs, kernel_output_feature_dimension) % feature_group_count = 0.
  • (C17) size(kernel_spatial_dimensions) = N - 2.
  • (C18) Dengan mempertimbangkan kernel_dimensions = kernel_spatial_dimensions + [kernel_input_feature_dimension] + [kernel_output_feature_dimension]:
    • is_unique(kernel_dimensions).
    • 0 <= kernel_dimensions < N.
  • (C19) size(output_spatial_dimensions) = N - 2.
  • (C20) Dengan mempertimbangkan output_dimensions = [output_batch_dimension] + output_spatial_dimensions + [output_feature_dimension]:
    • is_unique(output_dimensions).
    • 0 <= output_dimensions < N.
  • (C21) 0 < feature_group_count.
  • (C22) 0 < batch_group_count.
  • (C23) feature_group_count = 1 or batch_group_count = 1.
  • (C24) size(precision_config) = 2.
  • (C25) dim(result, result_dim) didefinisikan sebagai:
    • dim(lhs, input_batch_dimension) / batch_group_count jika result_dim = output_batch_dimension.
    • dim(rhs, kernel_output_feature_dimension) jika result_dim = output_feature_dimension.
    • num_windows jika tidak, jika:
    • output_spatial_dimensions[spatial_dim] = result_dim.
    • lhs_dim = input_spatial_dimensions[spatial_dim].
    • rhs_dim = kernel_spatial_dimensions[spatial_dim].
    • dilated_input_shape[lhs_dim] = dim(lhs, lhs_dim) = 0 ? 0 : (dim(lhs, lhs_dim) - 1) * lhs_dilation[spatial_dim] + 1.
    • padded_input_shape[lhs_dim] = padding[spatial_dim, 0] + dilated_input_shape[lhs_dim] + padding[spatial_dim, 1].
    • dilated_window_shape[lhs_dim] = dim(rhs, rhs_dim) = 0 ? 0 : (dim(rhs, rhs_dim) - 1) * rhs_dilation[spatial_dim] + 1.
    • is_empty_window[lhs_dim] = padded_input_shape[lhs_dim] = 0 || dilated_window_shape[lhs_dim] > padded_input_shape[lhs_dim].
    • num_windows = is_empty_window[lhs_dim] ? 0 : floor((padded_input_shape[lhs_dim] - dilated_window_shape[lhs_dim]) / window_strides[spatial_dim]) + 1.
  • (C26) rank(result) = N.
  • Jika operasi tersebut menggunakan tensor non-kuantisasi:
    • (C27) element_type(lhs) = element_type(rhs) = element_type(result).
  • Jika operasi tersebut menggunakan tensor terkuantisasi:
    • (C28) is_quantized(lhs) = is_quantized(result) and is_quantized(rhs).
    • (C29) Jika is_per_axis_quantized(rhs), lalu quantization_dimension(rhs) = kernel_output_feature_dimension.
    • (C30) Jika is_per_axis_quantized(result), maka quantization_dimension(result) = output_feature_dimension.
    • Jika is_quantized(lhs):
    • (C31) storage_type(lhs) = storage_type(rhs).
    • (C32) expressed_type(lhs) = expressed_type(rhs) = expressed_type(result).
    • (C33) Jika is_per_tensor_quantized(rhs), maka is_per_tensor_quantized(result).
    • Jika !is_quantized(lhs):
    • (C34) element_type(lhs) = expressed_type(rhs) = element_type(result).

Contoh

// %lhs: [[
//        [
//          [1], [2], [5], [6]
//        ],
//        [
//          [3], [4], [7], [8]
//        ],
//        [
//          [10], [11], [14], [15]
//        ],
//        [
//          [12], [13], [16], [17]
//        ]
//      ]]
//
// %rhs: [
//        [[[1]], [[1]], [[1]]],
//        [[[1]], [[1]], [[1]]],
//        [[[1]], [[1]], [[1]]]
//       ]
%result = "stablehlo.convolution"(%lhs, %rhs) {
  window_strides = array<i64: 4, 4>,
  padding = dense<0> : tensor<2x2xi64>,
  lhs_dilation = array<i64: 2, 2>,
  rhs_dilation = array<i64: 1, 1>,
  window_reversal = array<i1: false, false>,
  // In the StableHLO dialect, dimension numbers are encoded via:
  // `[<input dimensions>]x[<kernel dimensions>]->[output dimensions]`.
  // "b" is batch dimension, "f" is feature dimension,
  // "i" is input feature dimension, "o" is output feature dimension,
  // "0/1/etc" are spatial dimensions.
  dimension_numbers = #stablehlo.conv<[b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f]>,
  batch_group_count = 1 : i64,
  feature_group_count = 1 : i64,
  precision_config = [#stablehlo<precision DEFAULT>, #stablehlo<precision DEFAULT>]
} : (tensor<1x4x4x1xi64>, tensor<3x3x1x1xi64>) -> tensor<1x2x2x1xi64>
// %result: [[
//            [[10], [26]],
//            [[46], [62]]
//          ]]

Contoh Lainnya

kosinus

Semantik

Melakukan operasi kosinus element-wise pada tensor operand dan menghasilkan Tensor result. Bergantung pada jenis elemen, melakukan hal berikut:

  • Untuk float: cos dari IEEE-754.
  • Untuk bilangan kompleks: kosinus kompleks.
  • Untuk jenis terkuantisasi: dequantize_op_quantize(cosine, operand, type(result)).

Input

Label Nama Jenis Batasan
(I1) operand tensor floating point atau tipe kompleks atau tensor terkuantisasi per-tensor (C1)

Output

Nama Jenis Batasan
result tensor floating point atau tipe kompleks atau tensor terkuantisasi per-tensor (C1)

Batasan

  • (C1) baseline_type(operand) = baseline_type(result).

Contoh

// %operand: [
//            [0.0, 1.57079632],       // [0, pi/2]
//            [3.14159265, 4.71238898] // [pi, 3pi/2]
//           ]
%result = "stablehlo.cosine"(%operand) : (tensor<2x2xf32>) -> tensor<2x2xf32>
// %result: [[1.0, 0.0], [-1.0, 0.0]]

Contoh Lainnya

count_leading_zeros

Semantik

Melakukan penghitungan berbasis elemen dari jumlah bit nol di awal dalam operand dan menghasilkan tensor result.

Input

Label Nama Jenis Batasan
(I1) operand tensor tipe integer (C1)

Output

Nama Jenis Batasan
result tensor tipe integer (C1)

Batasan

  • (C1) type(operand) = type(result).

Contoh

// %operand: [[0, 1], [128, -1]]
%result = "stablehlo.count_leading_zeros"(%operand) : (tensor<2x2xi64>) -> tensor<2x2xi64>
// %result: [[64, 63], [56, 0]]

Contoh Lainnya

custom_call

Semantik

Mengenkapsulasi call_target_name operasi yang ditentukan implementasi yang menggunakan inputs dan called_computations serta menghasilkan results. has_side_effect, backend_config dan api_version dapat digunakan untuk memberikan metadata yang ditentukan implementasi.

Saat ini, operasi ini berisi kumpulan data yang tidak terorganisir {i>metadata <i}yang mencerminkan evolusi organik dari operasi kompilator XLA. Di masa mendatang, kami berencana untuk menyatukan metadata ini (#741).

Input

Label Nama Jenis
(I1) inputs jumlah nilai variadis
(I2) call_target_name konstanta jenis string
(I3) has_side_effect konstanta jenis i1
(I4) backend_config konstanta jenis string atau kamus atribut
(I5) api_version konstanta jenis si32
(I6) called_computations jumlah konstanta variadik jenis string

Output

Nama Jenis
results jumlah nilai variadis

Contoh

%results = "stablehlo.custom_call"(%input0) {
  call_target_name = "foo",
  has_side_effect = false,
  backend_config = {bar = 42 : i32},
  api_version = 4 : i32,
  called_computations = [@foo]
} : (tensor<f64>) -> tensor<f64>

bagi

Semantik

Melakukan pembagian menurut elemen dari tensor lhs dan pembagi rhs serta menghasilkan tensor result. Bergantung pada jenis elemen, melakukan hal berikut:

  • Untuk bilangan bulat: pembagian bilangan bulat yang menghasilkan hasil bagi aljabar dengan bagian pecahan dibuang.
  • Untuk float: division dari IEEE-754.
  • Untuk bilangan kompleks: pembagian kompleks.
  • Untuk jenis terkuantisasi:
    • dequantize_op_quantize(divide, lhs, rhs, type(result)).

Input

Label Nama Jenis Batasan
(I1) lhs tensor integer, floating-point, atau tipe kompleks atau tensor terkuantisasi per-tensor (C1)
(I2) rhs tensor integer, floating-point, atau tipe kompleks atau tensor terkuantisasi per-tensor (C1)

Output

Nama Jenis Batasan
result tensor bilangan bulat, floating point, atau tipe kompleks atau tensor terkuantisasi per-tensor (C1)

Batasan

  • (C1) baseline_type(lhs) = baseline_type(rhs) = baseline_type(result).

Contoh

// %lhs: [17.1, -17.1, 17.1, -17.1]
// %rhs: [3.0, 3.0, -3.0, -3.0]
%result = "stablehlo.divide"(%lhs, %rhs) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
// %result: [5.66666651, -5.66666651, -5.66666651, 5.66666651]

Contoh Lainnya

dot_general

Semantik

Menghitung produk titik di antara irisan lhs dan irisan rhs serta menghasilkan Tensor result.

Secara lebih formal, result[result_index] = dot_product, dengan:

  • lhs_result_dimensions = [d for d in axes(lhs) and d not in lhs_batching_dimensions and d not in lhs_contracting_dimensions].
  • rhs_result_dimensions = [d for d in axes(rhs) and d not in rhs_batching_dimensions and d not in rhs_contracting_dimensions].
  • result_batching_index + result_lhs_index + result_rhs_index = result_index dengan size(result_batching_index) = size(lhs_batching_dimensions), size(result_lhs_index) = size(lhs_result_dimensions) dan size(result_rhs_index) = size(rhs_result_dimensions).
  • transposed_lhs = transpose(lhs, lhs_batching_dimensions + lhs_result_dimensions + lhs_contracting_dimensions).
  • transposed_lhs_slice = slice(transposed_lhs, result_batching_index + result_lhs_index + [:, ..., :]).
  • reshaped_lhs_slice = reshape(transposed_lhs_slice, dims(lhs, lhs_contracting_dimensions)).
  • transposed_rhs = transpose(rhs, rhs_batching_dimensions + rhs_result_dimensions + rhs_contracting_dimensions).
  • transposed_rhs_slice = slice(transposed_rhs, result_batching_index + result_rhs_index + [:, ..., :]).
  • reshaped_rhs_slice = reshape(transposed_rhs_slice, dims(rhs, rhs_contracting_dimensions)).
  • dot_product = reduce( inputs=[multiply(reshaped_lhs_slice, reshaped_rhs_slice)], init_values=[constant(0, element_type(result))], dimensions=range(size(lhs_contracting_dimensions)), body=lambda x, y: add(x, y)).

Untuk jenis terkuantisasi, menjalankan dequantize_op_quantize( lambda lhs, rhs: dot_general(lhs, rhs, lhs_batching_dimensions, rhs_batching_dimensions, lhs_contracting_dimensions, rhs_contracting_dimensions, precision_config), lhs, rhs, type(result)).

Untuk jenis terkuantisasi hybrid, menjalankan hybrid_dequantize_then_op( lambda lhs, rhs: dot_general(lhs, rhs, lhs_batching_dimensions, rhs_batching_dimensions, lhs_contracting_dimensions, rhs_contracting_dimensions, precision_config), lhs, rhs).

precision_config mengontrol keseimbangan antara kecepatan dan akurasi untuk komputasi pada backend akselerator. Kolom ini dapat berupa salah satu dari hal berikut (di semantik, semantik dari nilai enum ini kurang ditentukan, namun kita berencana untuk mengatasinya dalam #755):

  • DEFAULT: Penghitungan tercepat, tetapi perkiraan yang paling tidak akurat ke nomor aslinya.
  • HIGH: Penghitungan lebih lambat, tetapi perkiraan yang lebih akurat terhadap nomor aslinya.
  • HIGHEST: Penghitungan paling lambat, tetapi perkiraan paling akurat ke nomor aslinya.

DotAlgorithm menentukan properti utama algoritma yang digunakan untuk menerapkan operasi titik, yang juga menentukan presisi. Jika atribut algoritma kolom ditetapkan, maka precision_config harus DEFAULT. DotAlgorithms tidak memiliki nilai default, karena parameter default adalah implementasi didefinisikan. Dengan demikian, semua kolom algoritme titik dapat ditetapkan ke None untuk menentukan algoritma titik kosong, yang akan menggunakan nilai precision_config.

Kolom DotAlgorithm mencakup:

  • lhs_precision_type dan rhs_precision_type, presisi yang ditentukan oleh LHS dan RHS operasi dibulatkan ke. Jenis presisi tidak bergantung pada jenis penyimpanan input dan output-nya.
  • accumulation_type presisi yang digunakan untuk akumulasi.
  • lhs_component_count, rhs_component_count, dan num_primitive_operations terapkan saat kita melakukan algoritma yang menguraikan LHS dan/atau RHS menjadi banyak komponen dan melakukan beberapa “primitif” operasi dot pada nilai - biasanya untuk mengemulasi presisi yang lebih tinggi (mis. Memanfaatkan Jenis Data Kecerdasan Buatan bfloat16 untuk Komputasi Presisi Lebih Tinggi: bf16_6x tf32_3x, dll.). Untuk algoritma tanpa dekomposisi, nilai ini harus ditetapkan ke 1.
  • allow_imprecise_accumulation untuk menentukan apakah akumulasi dalam presisi lebih rendah diizinkan untuk beberapa langkah (mis. CUBLASLT_MATMUL_DESC_FAST_ACCUM).

Contoh atribut DotAlgorithm:

// Inputs are casted to tf32, and then accumulated in f32:
{lhs_precision_type = tf32,
 rhs_precision_type = tf32,
 accumulation_type = f32,
 lhs_component_count = 1,
 rhs_component_count = 1,
 num_primitive_operations = 1,
 allow_imprecise_accumulation = false}


// bf16_6x: each input is decomposed to 3 bf16 components, then 6 dot operations are done on those components, and the result is accumulated in f32.
{lhs_precision_type = bf16,
 rhs_precision_type = bf16,
 accumulation_type = f32,
 lhs_component_count = 3,
 rhs_component_count = 3,
 num_primitive_operations = 6,
 allow_imprecise_accumulation = false}


// Inputs are (casted to) f8e5m2, and we accumulate in f32, but for some steps we may accumulate in lower precision.
{lhs_precision_type = f8e5m2,
 rhs_precision_type = f8e5m2,
 accumulation_type = f32,
 lhs_component_count = 1,
 rhs_component_count = 1,
 num_primitive_operations = 1,
 allow_imprecise_accumulation = true}

Implementasi bebas untuk memutuskan kombinasi mana yang didukung. Di beberapa umum, tidak ada jaminan bahwa setiap algoritma didukung pada setiap jenis akselerator oleh konsumen StableHLO. Jika algoritma tertentu tidak didukung, maka kesalahan akan dimunculkan, bukan untuk kembali ke alternatif. Verifikasi StabilHLO akan memberikan upaya verifikasi terbaik, mencegah algoritme yang tidak diketahui didukung pada perangkat keras apa pun.

Lihat xla_data.proto > Algorithm untuk beberapa nilai algoritma yang didukung. Tiket #2483 berisi rencana untuk membuat dokumen terpusat pada algoritma yang didukung oleh backend.

Input

Label Nama Jenis Batasan
(I1) lhs tensor atau tensor terkuantisasi per-tensor (C5-C6), (C9-C10), (C12-C14), (C17-C18), (C20)
(I2) rhs tensor atau quantized tensor (C7-C10), (C12-C20)
(I3) lhs_batching_dimensions Konstanta tensor 1 dimensi jenis si64 (C1), (C3), (C5), (C9), (C12)
(I4) rhs_batching_dimensions Konstanta tensor 1 dimensi jenis si64 (C1), (C4), (C7), (C9)
(I5) lhs_contracting_dimensions Konstanta tensor 1 dimensi jenis si64 (C2), (C3), (C6), (C10)
(I6) rhs_contracting_dimensions Konstanta tensor 1 dimensi jenis si64 (C2), (C4), (C8), (C10), (C16)
(I7) precision_config jumlah variadik enum DEFAULT, HIGH, dan HIGHEST (C11), C21
(I8) lhs_precision_type FloatType atau TensorFloat32 (C21)
(I9) rhs_precision_type FloatType atau TensorFloat32 (C21)
(I10) accumulation_type FloatType atau TensorFloat32 (C21)
(I11) lhs_component_count konstanta jenis si32 (C21), C22
(I12) rhs_component_count konstanta jenis si32 (C21), C23
(I13) num_primitive_operations konstanta jenis si32 (C21), C24
(I14) allow_imprecise_accumulation konstanta jenis bool (C21)

Output

Nama Jenis Batasan
result tensor atau quantized tensor (C12), (C14), (C18—C20)

Batasan

  • (C1) size(lhs_batching_dimensions) = size(rhs_batching_dimensions).
  • (C2) size(lhs_contracting_dimensions) = size(rhs_contracting_dimensions).
  • (C3) is_unique(lhs_batching_dimensions + lhs_contracting_dimensions).
  • (C4) is_unique(rhs_batching_dimensions + rhs_contracting_dimensions).
  • (C5) 0 <= lhs_batching_dimensions < rank(lhs).
  • (C6) 0 <= lhs_contracting_dimensions < rank(lhs).
  • (C7) 0 <= rhs_batching_dimensions < rank(rhs).
  • (C8) 0 <= rhs_contracting_dimensions < rank(rhs).
  • (C9) dim(lhs, lhs_batching_dimensions...) = dim(rhs, rhs_batching_dimensions...).
  • (C10) dim(lhs, lhs_contracting_dimensions...) = dim(rhs, rhs_contracting_dimensions...).
  • (C11) size(precision_config) = 2.
  • (C12) shape(result) = dim(lhs, lhs_batching_dimensions) + dim(lhs, lhs_result_dimensions) + dim(rhs, rhs_result_dimensions).
  • Jika operasi tersebut menggunakan tensor non-kuantisasi:
    • (C13) element_type(lhs) = element_type(rhs).
  • Jika operasi tersebut menggunakan tensor terkuantisasi:
    • (C14) is_quantized(lhs) = is_quantized(result) and is_quantized(rhs).
    • (C15) zero_points(rhs) = 0.
    • (C16) Jika is_per_axis_quantized(rhs), maka quantization_dimension(rhs) tidak ada di rhs_contracting_dimensions.
    • Jika is_quantized(lhs):
    • (C17) storage_type(lhs) = storage_type(rhs).
    • (C18) expressed_type(lhs) = expressed_type(rhs) = expressed_type(result).
    • (C19) Jika is_per_tensor_quantized(rhs), maka is_per_tensor_quantized(result).
    • Jika !is_quantized(lhs):
    • (C20) element_type(lhs) = expressed_type(rhs) = element_type(result).
  • Jika !is_empty_algorithm(lhs_precision_type, rhs_precision_type, accumulation_type, lhs_component_count, rhs_component_count, num_primitive_operations allow_imprecise_accumulation):
    • (C21) precision_config... = DEFAULT.
    • (C22) 0 < lhs_component_count.
    • (C23) 0 < rhs_component_count.
    • (C24) 0 < num_primitive_operations.

Contoh

// %lhs: [
//        [[1, 2],
//         [3, 4]],
//        [[5, 6],
//         [7, 8]]
//       ]
// %rhs: [
//        [[1, 0],
//         [0, 1]],
//        [[1, 0],
//         [0, 1]]
//       ]
%result = "stablehlo.dot_general"(%lhs, %rhs) {
  dot_dimension_numbers = #stablehlo.dot<
    lhs_batching_dimensions = [0],
    rhs_batching_dimensions = [0],
    lhs_contracting_dimensions = [2],
    rhs_contracting_dimensions = [1]
  >,
  precision_config = [#stablehlo<precision DEFAULT>, #stablehlo<precision DEFAULT>],
  algorithm = #stablehlo.dot_algorithm<
    lhs_precision_type = tf32,
    rhs_precision_type = tf32,
    accumulation_type = f32,
    lhs_component_count = 1,
    rhs_component_count = 1,
    num_primitive_operations = 1,
    allow_imprecise_accumulation = false
  >
} : (tensor<2x2x2xi64>, tensor<2x2x2xi64>) -> tensor<2x2x2xi64>
// %result: [
//           [[1, 2],
//            [3, 4]],
//           [[5, 6],
//            [7, 8]]
//          ]

Contoh Lainnya

dynamic_broadcast_in_dim

Semantik

Operasi ini secara fungsional identik dengan broadcast_in_dim berjalan, tetapi bentuk hasilnya ditentukan secara dinamis melalui output_dimensions.

Operasi tersebut juga menerima atribut opsional known_expanding_dimensions, known_non_expanding_dimensions untuk mengekspresikan pengetahuan statis tentang perilaku dimensi yang meluas. Jika tidak ditentukan, semua dimensi diasumsikan mungkin dapat diperluas.

Input

Label Nama Jenis Batasan
(I1) operand tensor atau quantized tensor (C1-C2), (C5-C6), (C9)
(I2) output_dimensions Tensor 1-dimensi dari jenis integer (C7)
(I3) broadcast_dimensions Tensor konstanta 1 dimensi dari jenis integer (C2-C6)
(I4) known_expanding_dimensions Tensor konstanta 1 dimensi dari jenis integer (C8-C9)
(I5) known_non_expanding_dimensions Tensor konstanta 1 dimensi dari jenis integer (C8-C9)

Output

Nama Jenis Batasan
result tensor atau quantized tensor (C1), (C3), (C5-C7)

Batasan

  • (C1) element_type(result) diberikan oleh:
    • element_type(operand), jika !is_per_axis_quantized(operand).
    • element_type(operand) kecuali quantization_dimension(operand), scales(operand), dan zero_points(operand) mungkin berbeda dengan quantization_dimension(result), scales(result), dan zero_points(result) resp., jika tidak.
  • (C2) size(broadcast_dimensions) = rank(operand).
  • (C3) 0 <= broadcast_dimensions < rank(result).
  • (C4) is_unique(broadcast_dimensions).
  • (C5) Untuk semua d di axes(operand):
    • dim(operand, d) = 1 atau
    • dim(operand, d) = dim(result, broadcast_dimensions[d]).
  • (C6) Jika is_per_axis_quantized(result):
    • quantization_dimension(result) = broadcast_dimensions[quantization_dimension(operand)].
    • Jika dim(operand, quantization_dimension(operand)) = 1, maka scales(result)[i] = scales(operand)[0] and zero_points(result)[i] = zero_points(operand)[0] for i in range(dim(result, quantization_dimension(result))).
  • (C7) size(output_dimensions) = rank(result).
  • (C8) is_unique(known_expanding_dimensions + known_non_expanding_dimensions).
  • (C9) 0 <= known_expanding_dimensions < rank(operand).
  • (C10) 0 <= known_non_expanding_dimensions < rank(operand).

Contoh

// %operand: [
//            [1, 2, 3]
//           ]
%operand = stablehlo.constant dense<[[1, 2, 3]]> : tensor<1x3xi64>
%output_dimensions = stablehlo.constant dense<[2, 3, 2]> : tensor<3xi64>
%result = "stablehlo.dynamic_broadcast_in_dim"(%operand, %output_dimensions) {
  broadcast_dimensions = array<i64: 2, 1>,
  known_expanding_dimensions = array<i64: 0>,
  known_non_expanding_dimensions = array<i64: 1>
} : (tensor<1x3xi64>, tensor<3xi64>) -> tensor<2x3x2xi64>
// %result: [
//            [
//             [1, 1],
//             [2, 2],
//             [3, 3]
//            ],
//            [
//             [1, 1],
//             [2, 2],
//             [3, 3]
//            ]
//          ]

Contoh Lainnya

dynamic_conv

Semantik

Operasi ini secara fungsional identik dengan konvolusi op, tetapi padding ditentukan secara dinamis melalui padding.

Input

Label Nama Jenis Batasan
(I1) lhs tensor atau tensor terkuantisasi per-tensor (C1), (C10-C11), (C14) (C25), (C26-C27), (C30-C31), (C33)
(I2) rhs tensor atau quantized tensor (C1), (C14-C16), (C26-C28), (C30-C33)
(I3) padding Tensor 2 dimensi dari tipe integer (C4)
(I4) window_strides Konstanta tensor 1 dimensi jenis si64 (C2-C3)
(I5) lhs_dilation Konstanta tensor 1 dimensi jenis si64 (C5-C6)
(I6) rhs_dilation Konstanta tensor 1 dimensi jenis si64 (C7-C8)
(I7) window_reversal Konstanta tensor 1 dimensi jenis i1 (C9)
(I8) input_batch_dimension konstanta jenis si64 (C10), C13)
(I9) input_feature_dimension konstanta jenis si64 (C11), (C13-C14)
(I10) input_spatial_dimensions Konstanta tensor 1 dimensi jenis si64 (C12), C13)
(I11) kernel_input_feature_dimension konstanta jenis si64 (C14), (C18)
(I12) kernel_output_feature_dimension konstanta jenis si64 (C15-C16), (C18), (C28)
(I13) kernel_spatial_dimensions Konstanta tensor 1 dimensi jenis si64 (C17-C18)
(I14) output_batch_dimension konstanta jenis si64 (C20)
(I15) output_feature_dimension konstanta jenis si64 (C20), (C29)
(I16) output_spatial_dimensions Konstanta tensor 1 dimensi jenis si64 (C19-C20)
(I17) feature_group_count konstanta jenis si64 (C11), (C14), (C16), (C21), (C23)
(I18) batch_group_count konstanta jenis si64 (C10), (C15), (C22), (C23)
(I19) precision_config jumlah variadik enum DEFAULT, HIGH, dan HIGHEST (C24)

Output

Nama Jenis Batasan
result tensor atau quantized tensor (C25-C27), (C29), (C31-C33)

Batasan

  • (C1) N = rank(lhs) = rank(rhs).
  • (C2) size(window_strides) = N - 2.
  • (C3) 0 < window_strides.
  • (C4) shape(padding) = [N - 2, 2].
  • (C5) size(lhs_dilation) = N - 2.
  • (C6) 0 < lhs_dilation.
  • (C7) size(rhs_dilation) = N - 2.
  • (C8) 0 < rhs_dilation.
  • (C9) size(window_reversal) = N - 2.
  • (C10) dim(lhs, input_batch_dimension) % batch_group_count = 0.
  • (C11) dim(lhs, input_feature_dimension) % feature_group_count = 0.
  • (C12) size(input_spatial_dimensions) = N - 2.
  • (C13) Dengan mempertimbangkan input_dimensions = [input_batch_dimension] + input_spatial_dimensions + [input_feature_dimension]:
    • is_unique(input_dimensions).
    • 0 <= input_dimensions < N.
  • (C14) dim(rhs, kernel_input_feature_dimension) = dim(lhs, input_feature_dimension) / feature_group_count.
  • (C15) dim(rhs, kernel_output_feature_dimension) % batch_group_count = 0.
  • (C16) dim(rhs, kernel_output_feature_dimension) % feature_group_count = 0.
  • (C17) size(kernel_spatial_dimensions) = N - 2.
  • (C18) Dengan mempertimbangkan kernel_dimensions = kernel_spatial_dimensions + [kernel_input_feature_dimension] + [kernel_output_feature_dimension]:
    • is_unique(kernel_dimensions).
    • 0 <= kernel_dimensions < N.
  • (C19) size(output_spatial_dimensions) = N - 2.
  • (C20) Dengan mempertimbangkan output_dimensions = [output_batch_dimension] + output_spatial_dimensions + [output_feature_dimension]:
    • is_unique(output_dimensions).
    • 0 <= output_dimensions < N.
  • (C21) 0 < feature_group_count.
  • (C22) 0 < batch_group_count.
  • (C23) feature_group_count = 1 or batch_group_count = 1.
  • (C24) size(precision_config) = 2.
  • (C25) dim(result, result_dim) didefinisikan sebagai:
    • dim(lhs, input_batch_dimension) / batch_group_count jika result_dim = output_batch_dimension.
    • dim(rhs, kernel_output_feature_dimension) jika result_dim = output_feature_dimension.
    • num_windows jika tidak, jika:
    • output_spatial_dimensions[spatial_dim] = result_dim.
    • lhs_dim = input_spatial_dimensions[spatial_dim].
    • rhs_dim = kernel_spatial_dimensions[spatial_dim].
    • dilated_input_shape[lhs_dim] = dim(lhs, lhs_dim) = 0 ? 0 : (dim(lhs, lhs_dim) - 1) * lhs_dilation[spatial_dim] + 1.
    • padded_input_shape[lhs_dim] = padding[spatial_dim, 0] + dilated_input_shape[lhs_dim] + padding[spatial_dim, 1].
    • dilated_window_shape[lhs_dim] = dim(rhs, rhs_dim) = 0 ? 0 : (dim(rhs, rhs_dim) - 1) * rhs_dilation[spatial_dim] + 1.
    • is_empty_window[lhs_dim] = padded_input_shape[lhs_dim] = 0 || dilated_window_shape[lhs_dim] > padded_input_shape[lhs_dim].
    • num_windows = is_empty_window[lhs_dim] ? 0 : floor((padded_input_shape[lhs_dim] - dilated_window_shape[lhs_dim]) / window_strides[spatial_dim]) + 1.
  • (C26) rank(result) = N.
  • Jika operasi tersebut menggunakan tensor non-kuantisasi:
    • (C27) element_type(lhs) = element_type(rhs) = element_type(result).
  • Jika operasi tersebut menggunakan tensor terkuantisasi:
    • (C28) is_quantized(lhs) = is_quantized(result) and is_quantized(rhs).
    • (C29) Jika is_per_axis_quantized(rhs), lalu quantization_dimension(rhs) = kernel_output_feature_dimension.
    • (C30) Jika is_per_axis_quantized(result), maka quantization_dimension(result) = output_feature_dimension.
    • Jika is_quantized(lhs):
    • (C31) storage_type(lhs) = storage_type(rhs).
    • (C32) expressed_type(lhs) = expressed_type(rhs) = expressed_type(result).
    • (C33) Jika is_per_tensor_quantized(rhs), maka is_per_tensor_quantized(result).
    • Jika !is_quantized(lhs):
    • (C34) element_type(lhs) = expressed_type(rhs) = element_type(result).

Contoh

// %lhs: [[
//        [[1], [2], [5], [6]],
//        [[3], [4], [7], [8]],
//        [[10], [11], [14], [15]],
//        [[12], [13], [16], [17]]
//      ]]
//
// %rhs: [
//         [[[1]], [[1]], [[1]]],
//         [[[1]], [[1]], [[1]]],
//         [[[1]], [[1]], [[1]]]
//        ]
// %padding: [[1, 1],
//            [1, 1]]
%result = "stablehlo.dynamic_conv"(%lhs, %rhs, %padding) {
  window_strides = array<i64: 4, 4>,
  lhs_dilation = array<i64: 2, 2>,
  rhs_dilation = array<i64: 1, 1>,
  window_reversal = array<i1: false, false>,
  dimension_numbers = #stablehlo.conv<raw
    input_batch_dimension = 0,
    input_feature_dimension = 3,
    input_spatial_dimensions = [0, 1],
    kernel_input_feature_dimension = 2,
    kernel_output_feature_dimension = 3,
    kernel_spatial_dimensions = [0, 1],
    output_batch_dimension = 0,
    output_feature_dimension = 3,
    output_spatial_dimensions = [1, 2]
  >,
  feature_group_count = 1 : i64,
  batch_group_count = 1 : i64,
  precision_config = [#stablehlo<precision DEFAULT>, #stablehlo<precision DEFAULT>]
} : (tensor<1x4x4x1xi64>, tensor<3x3x1x1xi64>, tensor<2x2xi64>) -> tensor<1x2x2x1xi64>
// %result: [[
//            [[1], [5]],
//            [[10], [14]]
//          ]]

Contoh Lainnya

dynamic_gather

Semantik

Operasi ini secara fungsional identik dengan kumpulkan op, dengan slice_sizes yang ditentukan secara dinamis sebagai nilai.

Input

Label Nama Jenis Batasan
(I1) operand tensor atau tensor terkuantisasi per-tensor (C1), (C7), (C10-C12), (C14)
(I2) start_indices tensor tipe integer (C2), (C3), (C13)
(I3) slice_sizes Tensor 1-dimensi dari jenis integer (C8), (C11-C13)
(I4) offset_dims Konstanta tensor 1 dimensi jenis si64 (C1), (C4-C5), (C13)
(I5) collapsed_slice_dims Konstanta tensor 1 dimensi jenis si64 (C1), (C6-C8), (C13)
(I6) start_index_map Konstanta tensor 1 dimensi jenis si64 (C3), (C9), (C10)
(I7) index_vector_dim konstanta jenis si64 (C2), (C3), (C13)
(I8) indices_are_sorted konstanta jenis i1

Output

Nama Jenis Batasan
result tensor atau tensor terkuantisasi per-tensor (C5), (C13-C14)

Batasan

  • (C1) rank(operand) = size(offset_dims) + size(collapsed_slice_dims).
  • (C2) 0 <= index_vector_dim <= rank(start_indices).
  • (C3) size(start_index_map) = index_vector_dim < rank(start_indices) ? dim(start_indices, index_vector_dim) : 1.
  • (C4) is_unique(offset_dims) and is_sorted(offset_dims).
  • (C5) 0 <= offset_dims < rank(result).
  • (C6) is_unique(collapsed_slice_dims) and is_sorted(collapsed_slice_dims).
  • (C7) 0 <= collapsed_slice_dims < rank(operand).
  • (C8) slice_sizes[collapsed_slice_dims...] <= 1.
  • (C9) is_unique(start_index_map).
  • (C10) 0 <= start_index_map < rank(operand).
  • (C11) size(slice_sizes) = rank(operand).
  • (C12) 0 <= slice_sizes <= shape(operand).
  • (C13) shape(result) = combine(batch_dim_sizes, offset_dim_sizes) jika:
    • batch_dim_sizes = shape(start_indices), kecuali bahwa ukuran dimensi dari start_indices yang sesuai dengan index_vector_dim tidak disertakan.
    • offset_dim_sizes = shape(slice_sizes) kecuali bahwa ukuran dimensi di slice_sizes yang sesuai dengan collapsed_slice_dims tidak disertakan.
    • combine menempatkan batch_dim_sizes pada sumbu yang sesuai dengan batch_dims dan offset_dim_sizes pada sumbu yang sesuai dengan offset_dims.
  • (C14) element_type(operand) = element_type(result).

Contoh

// %operand: [
//            [[1, 2], [3, 4], [5, 6], [7, 8]],
//            [[9, 10],[11, 12], [13, 14], [15, 16]],
//            [[17, 18], [19, 20], [21, 22], [23, 24]]
//           ]
// %start_indices: [
//                  [[0, 0], [1, 0], [2, 1]],
//                  [[0, 1], [1, 1], [0, 2]]
//                 ]
// %slize_sizes: [1, 2, 2]
%result = "stablehlo.dynamic_gather"(%operand, %start_indices, %slize_sizes) {
  dimension_numbers = #stablehlo.gather<
    offset_dims = [2, 3],
    collapsed_slice_dims = [0],
    start_index_map = [1, 0],
    index_vector_dim = 2>,
  indices_are_sorted = false
} : (tensor<3x4x2xi64>, tensor<2x3x2xi64>, tensor<3xi64>) -> tensor<2x3x2x2xi64>
// %result: [
//            [
//              [[1, 2], [3, 4]],
//              [[3, 4], [5, 6]],
//              [[13, 14], [15, 16]]
//            ],
//            [
//              [[9, 10], [11, 12]],
//              [[11, 12], [13, 14]],
//              [[17, 18], [19, 20]]
//            ]
//          ]

Contoh Lainnya

dynamic_iota

Semantik

Operasi ini secara fungsional identik dengan Iota berjalan, tetapi bentuk hasilnya ditentukan secara dinamis melalui output_shape.

Input

Label Nama Jenis Batasan
(I1) output_shape Tensor 1-dimensi dari jenis integer (C1), C2
(I2) iota_dimension si64 (C1)

Output

Nama Jenis Batasan
result tensor bilangan bulat, floating point, atau tipe kompleks atau tensor terkuantisasi per-tensor (C2)

Batasan

  • (C1) 0 <= iota_dimension < size(output_shape).
  • (C2) rank(result) = size(output_shape).

Contoh

%output_shape = stablehlo.constant dense<[4, 5]> : tensor<2xi64>
%result = "stablehlo.dynamic_iota"(%output_shape) {
  iota_dimension = 0 : i64
} : (tensor<2xi64>) -> tensor<4x5xi64>
// %result: [
//           [0, 0, 0, 0, 0],
//           [1, 1, 1, 1, 1],
//           [2, 2, 2, 2, 2],
//           [3, 3, 3, 3, 3]
//          ]

Contoh Lainnya

dynamic_pad

Semantik

Operasi ini secara fungsional identik dengan pad operasi, tetapi dengan edge_padding_low, edge_padding_high, dan interior_padding ditetapkan secara dinamis sebagai nilai.

Input

Label Nama Jenis Batasan
(I1) operand tensor atau tensor terkuantisasi per-tensor (C1), (C2), (C4)
(I2) padding_value Tensor 0-dimensi atau per-tensor quantized tensor (C1)
(I3) edge_padding_low Tensor 1-dimensi dari jenis integer (C1), C4
(I4) edge_padding_high Tensor 1-dimensi dari jenis integer (C1), C4
(I5) interior_padding Tensor 1-dimensi dari jenis integer (C2-C4)

Output

Nama Jenis Batasan
result tensor atau tensor terkuantisasi per-tensor (C3-C6)

Batasan

  • (C1) element_type(operand) = element_type(padding_value) = element_type(result).
  • (C2) size(edge_padding_low) = size(edge_padding_high) = size(interior_padding) = rank(operand).
  • (C3) 0 <= interior_padding.
  • (C4) shape(result) = shape(operand) + edge_padding_low + max(shape(operand) - 1, 0) * interior_padding + edge_padding_high.

Contoh

// %operand: [
//            [1, 2, 3],
//            [4, 5, 6]
//           ]
// %padding_value: 0
// %edge_padding_low: [0, 1]
// %edge_padding_high: [2, 1]
// %interior_padding: [1, 2]
%result = "stablehlo.dynamic_pad"(%operand, %padding_value,
  %edge_padding_low, %edge_padding_high, %interior_padding
) : (tensor<2x3xi64>, tensor<i64>, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor<5x9xi64>
// %result: [
//           [0, 1, 0, 0, 2, 0, 0, 3, 0],
//           [0, 0, 0, 0, 0, 0, 0, 0, 0],
//           [0, 4, 0, 0, 5, 0, 0, 6, 0],
//           [0, 0, 0, 0, 0, 0, 0, 0, 0],
//           [0, 0, 0, 0, 0, 0, 0, 0, 0]
//          ]

Contoh Lainnya

dynamic_reshape

Semantik

Operasi ini secara fungsional identik dengan membentuk ulang berjalan, tetapi bentuk hasilnya ditentukan secara dinamis melalui output_shape.

Input

Label Nama Jenis Batasan
(I1) operand tensor atau quantized tensor (C1-C3)
(I2) output_shape Tensor 1-dimensi dari jenis integer (C4)

Output

Nama Jenis Batasan
result tensor atau quantized tensor (C1-C4)

Batasan

  • (C1) element_type(result) diberikan oleh:
    • element_type(operand), jika !is_per_axis_quantized(operand).
    • element_type(operand) kecuali quantization_dimension(operand) dan quantization_dimension(result) dapat berbeda, jika tidak.
  • (C2) size(operand) = size(result).
  • (C3) Jika is_per_axis_quantized(operand):
    • reduce(dims(operand, [0, 1, ..., quantization_dimension(operand) - 1]), init_values=1, dimensions=[0], body=lambda x, y: x * y) = reduce(dims(result, [0, 1, ..., quantization_dimension(result) - 1]), init_values=1, dimensions=[0], body=lambda x, y: x * y).
    • dim(operand, quantization_dimension(operand)) = dim(result, quantization_dimension(result)).
    • reduce(dims(operand, [quantization_dimension(operand) + 1, ..., rank(operand) - 1]), init_values=1, dimensions=[0], body=lambda x, y: x * y) = reduce(dims(result, [quantization_dimension(result) + 1, ..., rank(result) - 1]), init_values=1, dimensions=[0], body=lambda x, y: x * y).
  • (C4) size(output_shape) = rank(result).

Contoh

// %operand: [[1, 2, 3], [4, 5, 6]]
// %output_shape: [3, 2]
%result = "stablehlo.dynamic_reshape"(%operand, %output_shape) : (tensor<2x3xi64>, tensor<2xi64>) -> tensor<3x2xi64>
// %result: [[1, 2], [3, 4], [5, 6]]

Contoh Lainnya

dynamic_slice

Semantik

Mengekstrak slice dari operand menggunakan indeks awal yang dikomputasi secara dinamis dan menghasilkan tensor result. start_indices berisi indeks awal dari irisan untuk setiap dimensi yang dapat disesuaikan dengan potensi penyesuaian, dan slice_sizes berisi ukuran irisan untuk setiap dimensi. Secara lebih formal, result[result_index] = operand[operand_index] dalam hal ini:

  • adjusted_start_indices = clamp(0, start_indices, shape(operand) - slice_sizes).
  • operand_index = adjusted_start_indices + result_index.

Input

Label Nama Jenis Batasan
(I1) operand tensor atau tensor terkuantisasi per-tensor (C1), (C2), (C4)
(I2) start_indices jumlah variadic tensor 0-dimensi dari jenis integer (C2), (C3)
(I3) slice_sizes Konstanta tensor 1 dimensi jenis si64 (C2), (C4), (C5)

Output

Nama Jenis Batasan
result tensor atau tensor terkuantisasi per-tensor (C1), (C5)

Batasan

  • (C1) element_type(operand) = element_type(result).
  • (C2) size(start_indices) = size(slice_sizes) = rank(operand).
  • (C3) same(type(start_indices...)).
  • (C4) 0 <= slice_sizes <= shape(operand).
  • (C5) shape(result) = slice_sizes.

Contoh

// %operand: [
//            [0, 0, 1, 1],
//            [0, 0, 1, 1],
//            [0, 0, 0, 0],
//            [0, 0, 0, 0]
//           ]
// %start_indices0: -1
// %start_indices1: 3
%result = "stablehlo.dynamic_slice"(%operand, %start_indices0, %start_indices1) {
  slice_sizes = array<i64: 2, 2>
} : (tensor<4x4xi32>, tensor<i64>, tensor<i64>) -> tensor<2x2xi32>
// %result: [
//           [1, 1],
//           [1, 1]
//          ]

Contoh Lainnya

dynamic_update_slice

Semantik

Menghasilkan tensor result yang sama dengan tensor operand, kecuali bahwa slice yang dimulai dari start_indices akan diupdate dengan nilai di update. Secara lebih formal, result[result_index] didefinisikan sebagai:

  • update[update_index] jika 0 <= update_index < shape(update) di mana:
    • adjusted_start_indices = clamp(0, start_indices, shape(operand) - shape(update)).
    • update_index = result_index - adjusted_start_indices.
  • operand[result_index] sebaliknya.

Input

Label Nama Jenis Batasan
(I1) operand tensor atau tensor terkuantisasi per-tensor (C1-C4), (C6)
(I2) update tensor atau tensor terkuantisasi per-tensor (C2), (C3), (C6)
(I3) start_indices jumlah variadic tensor 0-dimensi dari jenis integer (C4), (C5)

Output

Nama Jenis Batasan
result tensor atau tensor terkuantisasi per-tensor (C1)

Batasan

  • (C1) type(operand) = type(result).
  • (C2) element_type(update) = element_type(operand).
  • (C3) rank(update) = rank(operand).
  • (C4) size(start_indices) = rank(operand).
  • (C5) same(type(start_indices...)).
  • (C6) 0 <= shape(update) <= shape(operand).

Contoh

// %operand: [
//            [1, 1, 0, 0],
//            [1, 1, 0, 0],
//            [1, 1, 1, 1],
//            [1, 1, 1, 1]
//           ]
// %update: [
//           [1, 1],
//           [1, 1]
//          ]
// %start_indices0: -1
// %start_indices1: 3
%result = "stablehlo.dynamic_update_slice"(%operand, %update, %start_indices0, %start_indices1)
  : (tensor<4x4xi32>, tensor<2x2xi32>, tensor<i64>, tensor<i64>) -> tensor<4x4xi32>
// %result: [
//           [1, 1, 1, 1],
//           [1, 1, 1, 1],
//           [1, 1, 1, 1],
//           [1, 1, 1, 1]
//          ]

Contoh Lainnya

berpangkat

Semantik

Melakukan operasi eksponensial setiap elemen pada tensor operand dan menghasilkan Tensor result. Bergantung pada jenis elemen, melakukan hal berikut:

  • Untuk float: exp dari IEEE-754.
  • Untuk bilangan kompleks: eksponensial kompleks.
  • Untuk jenis terkuantisasi: dequantize_op_quantize(exponential, operand, type(result)).

Input

Label Nama Jenis Batasan
(I1) operand tensor floating point atau tipe kompleks atau tensor terkuantisasi per-tensor (C1)

Output

Nama Jenis Batasan
result tensor floating point atau tipe kompleks atau tensor terkuantisasi per-tensor (C1)

Batasan

  • (C1) baseline_type(operand) = baseline_type(result).

Contoh

// %operand: [[0.0, 1.0], [2.0, 3.0]]
%result = "stablehlo.exponential"(%operand) : (tensor<2x2xf64>) -> tensor<2x2xf64>
// %result: [[1.0, 2.7182818284590451], [7.3890560989306504, 20.085536923187668]]

Contoh Lainnya

exponential_minus_one

Semantik

Menjalankan eksponensial element-wise dikurangi satu operasi pada tensor operand dan menghasilkan tensor result. Bergantung pada jenis elemen, melakukan hal berikut:

  • Untuk float: expm1 dari IEEE-754.
  • Untuk bilangan kompleks: eksponensial kompleks dikurangi satu.
  • Untuk jenis terkuantisasi: dequantize_op_quantize(exponential_minus_one, operand, type(result)).

Input

Label Nama Jenis Batasan
(I1) operand tensor floating point atau tipe kompleks atau tensor terkuantisasi per-tensor (C1)

Output

Nama Jenis Batasan
result tensor floating point atau tipe kompleks atau tensor terkuantisasi per-tensor (C1)

Batasan

  • (C1) baseline_type(operand) = baseline_type(result).

Contoh

// %operand: [0.0, 1.0]
%result = "stablehlo.exponential_minus_one"(%operand) : (tensor<2xf64>) -> tensor<2xf64>
// %result: [0.0, 1.71828187]

Contoh Lainnya

FFT

Semantik

Melakukan transformasi Fourier maju dan terbalik untuk real dan kompleks input/output lainnya.

fft_type adalah salah satu dari berikut ini:

  • FFT: Meneruskan FFT kompleks ke kompleks.
  • IFFT: FFT kompleks ke kompleks terbalik.
  • RFFT: Meneruskan FFT real-ke-kompleks.
  • IRFFT: FFT real-ke-kompleks terbalik (yaitu yang kompleks, menampilkan nilai nyata).

Secara lebih formal, dengan fungsi fft yang menggunakan tensor 1 dimensi tipe kompleks sebagai input, menghasilkan tensor 1-dimensi dari tipe yang sama dengan dan menghitung transformasi Fourier diskrit:

Untuk fft_type = FFT, result ditentukan sebagai hasil akhir deret L komputasi di mana L = size(fft_length). Misalnya, untuk L = 3:

  • result1[i0, ..., :] = fft(operand[i0, ..., :]).
  • result2[i0, ..., :, iR-1] = fft(result1[i0, ..., :, iR-1]).
  • result[i0, ..., :, iR-2, iR-1] = fft(result2[i0, ..., :, iR-2, iR-1]).

Selain itu, dengan fungsi ifft yang memiliki jenis tanda tangan dan menghitung invers fft:

Untuk fft_type = IFFT, result didefinisikan sebagai kebalikan dari komputasi untuk fft_type = FFT. Misalnya, untuk L = 3:

  • result1[i0, ..., :, iR-2, iR-1] = ifft(operand[i0, ..., :, iR-2, iR-1]).
  • result2[i0, ..., :, iR-1] = ifft(result1[i0, ..., :, iR-1]).
  • result[i0, ..., :] = ifft(result2[i0, ..., :]).

Selanjutnya, dengan fungsi rfft yang menggunakan tensor 1 dimensi tipe floating point, menghasilkan tensor 1 dimensi dari jenis semantik floating point yang sama dan bekerja sebagai berikut:

  • rfft(real_operand) = truncated_result di mana
  • complex_operand... = (real_operand..., 0.0).
  • complex_result = fft(complex_operand).
  • truncated_result = complex_result[:(rank(complex_result) / 2 + 1)].

(Ketika transformasi Fourier diskrit dihitung untuk operand riil, yang pertama Elemen N/2 + 1 hasil secara jelas menentukan sisa hasil, sehingga hasil rfft terpotong untuk menghindari komputasi elemen yang redundan).

Untuk fft_type = RFFT, result ditentukan sebagai hasil akhir deret L komputasi di mana L = size(fft_length). Misalnya, untuk L = 3:

  • result1[i0, ..., :] = rfft(operand[i0, ..., :]).
  • result2[i0, ..., :, iR-1] = fft(result1[i0, ..., :, iR-1]).
  • result[i0, ..., :, iR-2, iR-1] = fft(result2[i0, ..., :, iR-2, iR-1]).

Terakhir, dengan fungsi irfft yang memiliki tanda tangan dan tanda tangan jenis yang sama menghitung invers rfft:

Untuk fft_type = IRFFT, result didefinisikan sebagai kebalikan dari komputasi untuk fft_type = RFFT. Misalnya, untuk L = 3:

  • result1[i0, ..., :, iR-2, iR-1] = ifft(operand[i0, ..., :, iR-2, iR-1]).
  • result2[i0, ..., :, iR-1] = ifft(result1[i0, ..., :, iR-1]).
  • result[i0, ..., :] = irfft(result2[i0, ..., :]).

Input

Label Nama Jenis Batasan
(I1) operand tensor floating point atau jenis kompleks (C1), (C2), (C4), (C5)
(I2) fft_type enum FFT, IFFT, RFFT, dan IRFFT (C2), (C5)
(I3) fft_length Konstanta tensor 1 dimensi jenis si64 (C1), (C3), (C4)

Output

Nama Jenis Batasan
result tensor floating point atau jenis kompleks (C2), (C4), (C5)

Batasan

  • (C1) size(fft_length) <= rank(operand).
  • (C2) Hubungan antara jenis elemen operand dan result bervariasi:
    • Jika fft_type = FFT, element_type(operand), dan element_type(result) memiliki jenis kompleks yang sama.
    • Jika fft_type = IFFT, element_type(operand), dan element_type(result) memiliki jenis kompleks yang sama.
    • Jika fft_type = RFFT, element_type(operand) adalah jenis floating point dan element_type(result) adalah jenis kompleks dari floating point yang sama semantik.
    • Jika fft_type = IRFFT, element_type(operand) adalah jenis kompleks dan element_type(result) adalah jenis floating point dari floating point yang sama semantik.
  • (C3) 1 <= size(fft_length) <= 3.
  • (C4) Jika di antara operand dan result, ada tensor real dari jenis floating point, kemudian shape(real)[-size(fft_length):] = fft_length.
  • (C5) shape(result) = shape(operand) kecuali untuk:
    • Jika fft_type = RFFT, dim(result, -1) = dim(operand, -1) = 0 ? 0 : dim(operand, -1) / 2 + 1.
    • Jika fft_type = IRFFT, dim(operand, -1) = dim(result, -1) = 0 ? 0 : dim(result, -1) / 2 + 1.

Contoh

// %operand: [(1.0, 0.0), (0.0, 0.0), (0.0, 0.0), (0.0, 0.0)]
%result = "stablehlo.fft"(%operand) {
  fft_type = #stablehlo<fft_type FFT>,
  fft_length = array<i64: 4>
} : (tensor<4xcomplex<f32>>) -> tensor<4xcomplex<f32>>
// %result: [(1.0, 0.0), (1.0, 0.0), (1.0, 0.0), (1.0, 0.0)]

floor

Semantik

Melakukan nilai minimum berdasarkan elemen pada tensor operand dan menghasilkan tensor result. Menerapkan operasi roundToIntegralTowardNegative dari IEEE-754 spesifikasi pendukung. Untuk jenis terkuantisasi, menjalankan dequantize_op_quantize(floor, operand, type(result)).

Input

Label Nama Jenis Batasan
(I1) operand tensor tipe floating point atau tensor terkuantisasi per-tensor (C1)

Output

Nama Jenis Batasan
result tensor tipe floating point atau tensor terkuantisasi per-tensor (C1)

Batasan

  • (C1) baseline_type(operand) = baseline_type(result).

Contoh

// %operand: [-0.8166, -0.2530, 0.2530, 0.8166, 2.0]
%result = "stablehlo.floor"(%operand) : (tensor<5xf32>) -> tensor<5xf32>
// %result: [-1.0, -1.0, 0.0, 0.0, 2.0]

Contoh Lainnya

kumpulkan

Semantik

Mengumpulkan slice dari tensor operand dari offset yang ditentukan dalam start_indices dan menghasilkan tensor result.

Diagram berikut menunjukkan cara elemen di result dipetakan pada elemen di operand menggunakan contoh konkret. Diagram ini memilih beberapa contoh result indeks dan menjelaskan secara mendetail indeks operand mana yang sesuai dengannya.

kumpulkan

Secara lebih formal, result[result_index] = operand[operand_index] dengan:

  • batch_dims = [d for d in axes(result) and d not in offset_dims].
  • batch_index = result_index[batch_dims...].
  • start_index ditentukan sebagai:
    • start_indices[bi0, ..., :, ..., biN] dengan bi merupakan elemen individual di batch_index dan : dimasukkan pada indeks index_vector_dim, jika index_vector_dim rank(start_indices).
    • [start_indices[batch_index]] sebaliknya.
  • Untuk d_operand di axes(operand),
    • full_start_index[d_operand] = clamp(start_index[d_start], 0, dim(operand, d_operand) - slice_sizes[d_operand]) jika d_operand = start_index_map[d_start].
    • full_start_index[d_operand] = 0 sebaliknya.
  • Untuk d_operand di axes(operand),
    • full_batching_index[d_operand] = batch_index[d_start - (d_start < index_vector_dim ? 0 : 1)] jika d_operand = operand_batching_dims[i_batching] dan d_start = start_indices_batching_dims[i_batching].
    • full_batching_index[d_operand] = 0 sebaliknya.
  • offset_index = result_index[offset_dims...].
  • full_offset_index = [oi0, ..., 0, ..., oiN] dengan oi merupakan individu di offset_index, dan 0 disisipkan pada indeks dari collapsed_slice_dims dan operand_batching_dims.
  • operand_index = full_start_index + full_batching_index + full_offset_index.

Jika indices_are_sorted adalah true, implementasi dapat mengasumsikan bahwa start_indices diurutkan sehubungan dengan start_index_map, jika tidak, perilakunya tidak terdefinisi. Secara lebih formal, untuk semua i1 < i2 dari indices(result), full_start_index(i1) <= full_start_index(i2).

Input

Label Nama Jenis Batasan
(I1) operand tensor atau tensor terkuantisasi per-tensor (C1), (C8), (C11), (C17), (C19-C21), (C23)
(I2) start_indices tensor tipe integer (C2-C3), (C14), (C17), (C22)
(I3) offset_dims Konstanta tensor 1 dimensi jenis si64 (C1), (C4-C5), (C22)
(I4) collapsed_slice_dims Konstanta tensor 1 dimensi jenis si64 (C1), (C6-C9), (C22)
(I5) operand_batching_dims Konstanta tensor 1 dimensi jenis si64 (C1), (C6), (C10-C12), (C16-C18), (C22)
(I6) start_indices_batching_dims Konstanta tensor 1 dimensi jenis si64 (C13-C17)
(I7) start_index_map Konstanta tensor 1 dimensi jenis si64 (C3), (C18-C19)
(I8) index_vector_dim konstanta jenis si64 (C2-C3), (C15), (C22)
(I9) slice_sizes Konstanta tensor 1 dimensi jenis si64 (C9), (C12), (C20-C22)
(I10) indices_are_sorted konstanta jenis i1

Output

Nama Jenis Batasan
result tensor atau tensor terkuantisasi per-tensor (C5), (C22-C23)

Batasan

  • (C1) rank(operand) = size(offset_dims) + size(collapsed_slice_dims) + size(operand_batching_dims).
  • (C2) 0 <= index_vector_dim <= rank(start_indices).
  • (C3) size(start_index_map) = index_vector_dim < rank(start_indices) ? dim(start_indices, index_vector_dim) : 1.
  • (C4) is_unique(offset_dims) and is_sorted(offset_dims).
  • (C5) 0 <= offset_dims < rank(result).
  • (C6) is_unique(concatenate(collapsed_slice_dims, operand_batching_dims))
  • (C7) is_sorted(collapsed_slice_dims).
  • (C8) 0 <= collapsed_slice_dims < rank(operand).
  • (C9) slice_sizes[collapsed_slice_dims...] <= 1.
  • (C10) is_sorted(operand_batching_dims).
  • (C11) 0 <= operand_batching_dims < rank(operand).
  • (C12) slice_sizes[operand_batching_dims...] <= 1.
  • (C13) is_unique(start_indices_batching_dims).
  • (C14) 0 <= start_indices_batching_dims < rank(start_indices).
  • (C15) index_vector_dim not in start_indices_batching_dims.
  • (C16) size(operand_batching_dims) == size(start_indices_batching_dims).
  • (C17) dim(operand, operand_batching_dims...) = dim(start_indices, start_indices_batching_dims...).
  • (C18) is_unique(concatenate(start_index_map, operand_batching_dims)).
  • (C19) 0 <= start_index_map < rank(operand).
  • (C20) size(slice_sizes) = rank(operand).
  • (C21) 0 <= slice_sizes <= shape(operand).
  • (C22) shape(result) = combine(batch_dim_sizes, offset_dim_sizes) jika:
    • batch_dim_sizes = shape(start_indices), kecuali bahwa ukuran dimensi dari start_indices yang sesuai dengan index_vector_dim tidak disertakan.
    • offset_dim_sizes = slice_sizes kecuali bahwa ukuran dimensi di slice_sizes yang sesuai dengan collapsed_slice_dims dan operand_batching_dims tidak termasuk.
    • combine menempatkan batch_dim_sizes pada sumbu yang sesuai dengan batch_dims dan offset_dim_sizes pada sumbu yang sesuai dengan offset_dims.
  • (C23) element_type(operand) = element_type(result).

Contoh

// %operand: [
//            [
//             [[1, 2], [3, 4], [5, 6], [7, 8]],
//             [[9, 10],[11, 12], [13, 14], [15, 16]],
//             [[17, 18], [19, 20], [21, 22], [23, 24]]
//            ],
//            [
//             [[25, 26], [27, 28], [29, 30], [31, 32]],
//             [[33, 34], [35, 36], [37, 38], [39, 40]],
//             [[41, 42], [43, 44], [45, 46], [47, 48]]
//            ]
//           ]
// %start_indices: [
//                  [
//                   [[0, 0], [1, 0], [2, 1]],
//                   [[0, 1], [1, 1], [0, 9]]
//                  ],
//                  [
//                   [[0, 0], [2, 1], [2, 2]],
//                   [[1, 2], [0, 1], [1, 0]]
//                  ]
//                 ]
%result = "stablehlo.gather"(%operand, %start_indices) {
  dimension_numbers = #stablehlo.gather<
    offset_dims = [3, 4],
    collapsed_slice_dims = [1],
    operand_batching_dims = [0],
    start_indices_batching_dims = [1],
    start_index_map = [2, 1],
    index_vector_dim = 3>,
  slice_sizes = array<i64: 1, 1, 2, 2>,
  indices_are_sorted = false
} : (tensor<2x3x4x2xi32>, tensor<2x2x3x2xi64>) -> tensor<2x2x3x2x2xi32>
// %result: [
//           [
//            [
//             [[1, 2], [3, 4]],
//             [[3, 4], [5, 6]],
//             [[13, 14], [15, 16]]
//            ],
//            [
//             [[33, 34], [35, 36]],
//             [[35, 36], [37, 38]],
//             [[41, 42], [43, 44]]
//            ]
//           ],
//           [
//            [
//             [[1, 2], [3, 4]],
//             [[13, 14], [15, 16]],
//             [[21, 22], [23, 24]]
//            ],
//            [
//             [[43, 44], [45, 46]],
//             [[33, 34], [35, 36]],
//             [[27, 28], [29, 30]]
//            ]
//           ]
//          ]

Contoh Lainnya

get_dimension_size

Semantik

Menghasilkan ukuran dimension yang ditentukan dari operand. Secara lebih formal, result = dim(operand, dimension). Semantik hanya memperhitungkan bentuk komponen jenis ini. Tipe elemen bisa berupa apa saja.

Input

Label Nama Jenis Batasan
(I1) operand tensor atau quantized tensor (C1)
(I2) dimension konstanta jenis si64 (C1)

Output

Nama Jenis
result Tensor 0-dimensi jenis si32

Batasan

  • (C1) 0 <= dimension < rank(operand).

Contoh

// %operand: [[1, 2, 3], [4, 5, 6]]
%result = "stablehlo.get_dimension_size"(%operand) {
  dimension = 1 : i64
} : (tensor<2x3xi64>) -> tensor<i32>
// %result: 3

Contoh Lainnya

get_tuple_element

Semantik

Mengekstrak elemen pada posisi index tuple operand dan menghasilkan result. Secara lebih formal, result = operand[index].

Input

Label Nama Jenis Batasan
(I1) operand tuple (C1), C2
(I2) index konstanta jenis si32 (C1), C2

Output

Nama Jenis Batasan
result semua jenis yang didukung (C2)

Batasan

  • (C1) 0 <= index < size(operand).
  • (C2) type(result) = tuple_element_types(operand)[index].

Contoh

// %operand: ([1.0, 2.0], (3))
  index = 0 : i32
} : (tuple<tensor<2xf32>, tuple<tensor<i32>>>) -> tensor<2xf32>
// %result: [1.0, 2.0]

Contoh Lainnya

jika

Semantik

Menghasilkan output dari mengeksekusi tepat satu fungsi dari true_branch atau false_branch bergantung pada nilai pred. Secara lebih formal, result = pred ? true_branch() : false_branch().

Input

Label Nama Jenis Batasan
(I1) pred Tensor 0-dimensi jenis i1
(I2) true_branch fungsi (C1-C3)
(I3) false_branch fungsi (C1), C2

Output

Nama Jenis Batasan
results jumlah tensor, tensor atau token terkuantisasi (C3)

Batasan

  • (C1) input_types(true_branch) = input_types(false_branch) = [].
  • (C2) output_types(true_branch) = output_types(false_branch).
  • (C3) type(results...) = output_types(true_branch).

Contoh

// %result_true_branch: 10
// %result_false_branch: 11
// %pred: true
%result = "stablehlo.if"(%pred) ({
  "stablehlo.return"(%result_true_branch) : (tensor<i32>) -> ()
}, {
  "stablehlo.return"(%result_false_branch) : (tensor<i32>) -> ()
}) : (tensor<i1>) -> tensor<i32>
// %result: 10

Contoh Lainnya

gambar

Semantik

Mengekstrak bagian imajiner, berdasarkan elemen, dari operand dan menghasilkan Tensor result. Secara lebih formal, untuk setiap elemen x: imag(x) = is_complex(x) ? imaginary_part(x) : constant(0, element_type(result)).

Input

Label Nama Jenis Batasan
(I1) operand tensor floating point atau jenis kompleks (C1), C2

Output

Nama Jenis Batasan
result tensor jenis floating point (C1), C2

Batasan

  • (C1) shape(result) = shape(operand).
  • (C2) element_type(result) didefinisikan sebagai:
    • complex_element_type(element_type(operand)) jika is_complex(operand).
    • element_type(operand) sebaliknya.

Contoh

// %operand: [(1.0, 2.0), (3.0, 4.0)]
%result = "stablehlo.imag"(%operand) : (tensor<2xcomplex<f32>>) -> tensor<2xf32>
// %result: [2.0, 4.0]

Contoh Lainnya

dalam feed

Semantik

Membaca data dari dalam feed dan menghasilkan results.

Semantik infeed_config adalah ditentukan oleh implementasi.

results terdiri dari nilai payload yang muncul terlebih dahulu dan token yang muncul terakhir. Di masa mendatang, kami berencana untuk membagi payload dan token menjadi dua {i>output<i} terpisah untuk meningkatkan kejelasan (#670).

Input

Label Nama Jenis
(I1) token token
(I2) infeed_config konstanta jenis string

Output

Nama Jenis Batasan
results jumlah tensor, tensor atau token terkuantisasi (C1-C3)

Batasan

  • (C1) 0 < size(results).
  • (C2) is_empty(result[:-1]) atau is_tensor(type(results[:-1])).
  • (C3) is_token(type(results[-1])).

Contoh

// %token: !stablehlo.token
// infeed_queue[0]: [[1, 2], [3, 4]]
// infeed_queue[1]: [[5, 6], [7, 8]]
%results0:2 = "stablehlo.infeed"(%token) {
  infeed_config = ""
} : (!stablehlo.token) -> (tensor<2x2xi64>, !stablehlo.token)
// results0#0: [[1, 2], [3, 4]]
%results1:2 = "stablehlo.infeed"(%token) {
  infeed_config = ""
} : (!stablehlo.token) -> (tensor<2x2xi64>, !stablehlo.token)
// results1#0: [[5, 6], [7, 8]]

Contoh Lainnya

Iota

Semantik

Mengisi tensor output dengan nilai dalam urutan yang meningkat mulai dari nol di sepanjang dimensi iota_dimension. Secara lebih formal,

output[output_index] = constant(is_quantized(output) ? quantize(output_index[iota_dimension], element_type(output)) : output_index[iota_dimension], element_type(output)).

Input

Label Nama Jenis Batasan
(I1) iota_dimension si64 (C1)

Output

Nama Jenis Batasan
output tensor bilangan bulat, floating point, atau tipe kompleks atau tensor terkuantisasi per-tensor (C1)

Batasan

  • (C1) 0 <= iota_dimension < rank(output).

Contoh

%output = "stablehlo.iota"() {
  iota_dimension = 0 : i64
} : () -> tensor<4x5xi32>
// %output: [
//           [0, 0, 0, 0, 0],
//           [1, 1, 1, 1, 1],
//           [2, 2, 2, 2, 2],
//           [3, 3, 3, 3, 3]
//          ]

%output = "stablehlo.iota"() {
  iota_dimension = 1 : i64
} : () -> tensor<4x5xi32>
// %output: [
//           [0, 1, 2, 3, 4],
//           [0, 1, 2, 3, 4],
//           [0, 1, 2, 3, 4],
//           [0, 1, 2, 3, 4]
//          ]

Contoh Lainnya

is_finite

Semantik

Melakukan pemeriksaan berdasarkan elemen apakah nilai dalam x terbatas (yaitu bukan +Inf, -Inf, atau NaN) dan menghasilkan tensor y. Mengimplementasikan isFinite operasi dari spesifikasi IEEE-754. Untuk jenis terkuantisasi, hasilnya adalah selalu true.

Input

Label Nama Jenis Batasan
(I1) x tensor tipe floating point atau tensor terkuantisasi per-tensor (C1)

Output

Nama Jenis Batasan
y tensor jenis boolean (C1)

Batasan

  • (C1) shape(x) = shape(y).

Contoh

// Logical values: -Inf, +Inf, NaN, ...
// %x: [0xFFF0000000000000, 0x7FF0000000000000, 0x7FF8000000000000, -10.0, -0.0, 0.0, 10.0]
%y = "stablehlo.is_finite"(%x) : (tensor<7xf64) -> tensor<7xi1>
// %y: [false, false, false, true, true, true, true]

Contoh Lainnya

log

Semantik

Melakukan operasi logaritma element-wise pada tensor operand dan menghasilkan Tensor result. Bergantung pada jenis elemen, melakukan hal berikut:

  • Untuk float: log dari IEEE-754.
  • Untuk bilangan kompleks: logaritma kompleks.
  • Untuk jenis terkuantisasi: dequantize_op_quantize(log, operand, type(result)).

Input

Label Nama Jenis Batasan
(I1) operand tensor floating point atau tipe kompleks atau tensor terkuantisasi per-tensor (C1)

Output

Nama Jenis Batasan
result tensor floating point atau tipe kompleks atau tensor terkuantisasi per-tensor (C1)

Batasan

  • (C1) baseline_type(operand) = baseline_type(result).

Contoh

// %operand: [[1.0, 2.0], [3.0, 4.0]]
%result = "stablehlo.log"(%operand) : (tensor<2x2xf64>) -> tensor<2x2xf64>
// %result: [[0.0, 0.69314718055994529], [1.0986122886681098, 1.3862943611198906]]

Contoh Lainnya

log_plus_one

Semantik

Melakukan logaritma element-wise ditambah satu operasi pada tensor operand dan menghasilkan tensor result. Bergantung pada jenis elemen, melakukan hal berikut:

  • Untuk float: logp1 dari IEEE-754.
  • Untuk bilangan kompleks: logaritma kompleks ditambah satu.
  • Untuk jenis terkuantisasi: dequantize_op_quantize(log_plus_one, operand, type(result)).

Input

Label Nama Jenis Batasan
(I1) operand tensor floating point atau tipe kompleks atau tensor terkuantisasi per-tensor (C1)

Output

Nama Jenis Batasan
result tensor floating point atau tipe kompleks atau tensor terkuantisasi per-tensor (C1)

Batasan

  • (C1) baseline_type(operand) = baseline_type(result).

Contoh

// %operand: [0.0, -0.999, 7.0, 6.38905621, 15.0]
%result = "stablehlo.log_plus_one"(%operand) : (tensor<5xf64>) -> tensor<5xf64>
// %result: [0.0, -6.90776825, 2.07944155, 2.0, 2.77258873]

Contoh Lainnya

logistik

Semantik

Melakukan operasi logistik berdasarkan elemen pada tensor operand dan menghasilkan Tensor result. Bergantung pada jenis elemen, melakukan hal berikut:

  • Untuk float: division(1, addition(1, exp(-x))) dari IEEE-754.
  • Untuk bilangan kompleks: logistik kompleks.
  • Untuk jenis terkuantisasi: dequantize_op_quantize(logistic, operand, type(result)).

Input

Label Nama Jenis Batasan
(I1) operand tensor floating point atau tipe kompleks atau tensor terkuantisasi per-tensor (C1)

Output

Nama Jenis Batasan
result tensor floating point atau tipe kompleks atau tensor terkuantisasi per-tensor (C1)

Batasan

  • (C1) baseline_type(operand) = baseline_type(result).

Contoh

// %operand: [[0.0, 1.0], [2.0, 3.0]]
%result = "stablehlo.logistic"(%operand) : (tensor<2x2xf64>) -> tensor<2x2xf64>
// %result: [[0.5, 0.73105858], [0.88079708, 0.95257413]]

Contoh Lainnya

peta

Semantik

Menerapkan fungsi peta computation ke inputs di sepanjang dimensions dan menghasilkan tensor result.

Secara lebih formal, result[result_index] = computation(inputs...[result_index]).

Input

Label Nama Jenis Batasan
(I1) inputs jumlah tensor atau tensor terkuantisasi per-tensor (C1-C4)
(I2) dimensions Konstanta tensor 1 dimensi jenis si64 (C3)
(I3) computation fungsi (C4)

Output

Nama Jenis Batasan
result tensor atau tensor terkuantisasi per-tensor (C1), C4

Batasan

  • (C1) shape(inputs...) = shape(result).
  • (C2) 0 < size(inputs) = N.
  • (C3) dimensions = range(rank(inputs[0])).
  • (C4) computation memiliki jenis (tensor<E0>, ..., tensor<EN-1>) -> tensor<E'> dengan Ei = element_type(inputs[i]) dan E' = element_type(result).

Contoh

// %input0: [[0, 1], [2, 3]]
// %input1: [[4, 5], [6, 7]]
%result = "stablehlo.map"(%input0, %input1) ({
  ^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
    %0 = stablehlo.multiply %arg0, %arg1 : tensor<i64>
    stablehlo.return %0 : tensor<i64>
}) {
  dimensions = array<i64: 0, 1>
} : (tensor<2x2xi64>, tensor<2x2xi64>) -> tensor<2x2xi64>
// %result: [[0, 5], [12, 21]]

Contoh Lainnya

maksimum

Semantik

Melakukan operasi maksimum berbasis elemen pada tensor lhs dan rhs serta menghasilkan Tensor result. Bergantung pada jenis elemen, melakukan hal berikut:

  • Untuk boolean: OR logis.
  • Untuk bilangan bulat: maksimum bilangan bulat.
  • Untuk float: maximum dari IEEE-754.
  • Untuk bilangan kompleks: maksimum leksikografis untuk pasangan (real, imaginary). Memaksakan urutan pada bilangan kompleks melibatkan semantik yang mengejutkan, jadi di masa mendatang kami berencana menghapus dukungan untuk bilangan kompleks untuk operasi ini (#560).
  • Untuk jenis terkuantisasi:
    • dequantize_op_quantize(maximum, lhs, rhs, type(result)).

Input

Label Nama Jenis Batasan
(I1) lhs tensor atau tensor terkuantisasi per-tensor (C1)
(I2) rhs tensor atau tensor terkuantisasi per-tensor (C1)

Output

Nama Jenis Batasan
result tensor atau tensor terkuantisasi per-tensor (C1)

Batasan

  • (C1) baseline_type(lhs) = baseline_type(rhs) = baseline_type(result).

Contoh

// %lhs: [[1, 2], [7, 8]]
// %rhs: [[5, 6], [3, 4]]
%result = "stablehlo.maximum"(%lhs, %rhs) : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
// %result: [[5, 6], [7, 8]]

Contoh Lainnya

minimum

Semantik

Melakukan operasi min menurut elemen pada tensor lhs dan rhs, serta menghasilkan Tensor result. Bergantung pada jenis elemen, melakukan hal berikut:

  • Untuk boolean: logika AND.
  • Untuk bilangan bulat: minimum bilangan bulat.
  • Untuk float: minimum dari IEEE-754.
  • Untuk bilangan kompleks: minimum leksikografis untuk pasangan (real, imaginary). Memaksakan urutan pada bilangan kompleks melibatkan semantik yang mengejutkan, jadi di masa mendatang kami berencana menghapus dukungan untuk bilangan kompleks untuk operasi ini (#560).
  • Untuk jenis terkuantisasi:
    • dequantize_op_quantize(minimum, lhs, rhs, type(result)).

Input

Label Nama Jenis Batasan
(I1) lhs tensor atau tensor terkuantisasi per-tensor (C1)
(I2) rhs tensor atau tensor terkuantisasi per-tensor (C1)

Output

Nama Jenis Batasan
result tensor atau tensor terkuantisasi per-tensor (C1)

Batasan

  • (C1) baseline_type(lhs) = baseline_type(rhs) = baseline_type(result).

Contoh

// %lhs: [[1, 2], [7, 8]]
// %rhs: [[5, 6], [3, 4]]
%result = "stablehlo.minimum"(%lhs, %rhs) : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
// %result: [[1, 2], [3, 4]]

Contoh Lainnya

memperbanyak

Semantik

Menjalankan perkalian berbasis elemen dari dua tensor lhs dan rhs serta menghasilkan Tensor result. Bergantung pada jenis elemen, melakukan hal berikut:

  • Untuk boolean: logika AND.
  • Untuk bilangan bulat: perkalian bilangan bulat.
  • Untuk float: multiplication dari IEEE-754.
  • Untuk bilangan kompleks: perkalian kompleks.
  • Untuk jenis terkuantisasi:
    • dequantize_op_quantize(multiply, lhs, rhs, type(result)).

Input

Label Nama Jenis Batasan
(I1) lhs tensor atau tensor terkuantisasi per-tensor (C1)
(I2) rhs tensor atau tensor terkuantisasi per-tensor (C1)

Output

Nama Jenis Batasan
result tensor atau tensor terkuantisasi per-tensor (C1)

Batasan

  • (C1) baseline_type(operand) = baseline_type(result).

Contoh

// %lhs: [[1, 2], [3, 4]]
// %rhs: [[5, 6], [7, 8]]
%result = "stablehlo.multiply"(%lhs, %rhs) : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
// %result: [[5, 12], [21, 32]]

Contoh Lainnya

negasi

Semantik

Melakukan negasi berdasarkan elemen dari tensor operand dan menghasilkan result tensor. Bergantung pada jenis elemen, melakukan hal berikut:

  • Untuk bilangan bulat bertanda tangan: negasi bilangan bulat.
  • Untuk bilangan bulat yang tidak ditandatangani: bitcast ke integer bertanda tangan, negasi bilangan bulat, bitcast kembali ke bilangan bulat yang tidak ditandatangani.
  • Untuk float: negate dari IEEE-754.
  • Untuk bilangan kompleks: negasi kompleks.
  • Untuk jenis terkuantisasi: dequantize_op_quantize(negate, operand, type(result)).

Input

Label Nama Jenis Batasan
(I1) operand tensor bilangan bulat, floating point, atau tipe kompleks atau tensor terkuantisasi per-tensor (C1)

Output

Nama Jenis Batasan
result tensor bilangan bulat, floating point, atau tipe kompleks atau tensor terkuantisasi per-tensor (C1)

Batasan

  • (C1) baseline_type(operand) = baseline_type(result).

Contoh

// Negation operation with integer Tensors
// %operand: [0, -2]
%result = "stablehlo.negate"(%operand) : (tensor<2xi32>) -> tensor<2xi32>
// %result: [0, 2]

// Negation operation with with complex tensors
// %operand: (2.5, 0.0)
%result = "stablehlo.negate"(%operand) : (tensor<1xcomplex<f32>>) -> tensor<1xcomplex<f32>>
// %result: [-2.5, -0.0]

Contoh Lainnya

tidak

Semantik

Melakukan NOT tanpa elemen pada tensor operand dan menghasilkan tensor result. Bergantung pada jenis elemen, melakukan hal berikut:

  • Untuk boolean: logical NOT.
  • Untuk bilangan bulat: bitwise NOT.

Argumen

Nama Jenis Batasan
operand tensor jenis boolean atau bilangan bulat (C1)

Output

Nama Jenis Batasan
result tensor jenis boolean atau bilangan bulat (C1)

Batasan

  • (C1) type(operand) = type(result).

Contoh

// Bitwise operation with with integer tensors
// %operand: [[1, 2], [3, 4]]
%result = "stablehlo.not"(%operand) : (tensor<2x2xi32>) -> tensor<2x2xi32>
// %result: [[-2, -3], [-4, -5]]

// Bitwise operation with with boolean tensors
// %operand: [true, false]
%result = "stablehlo.not"(%operand) : (tensor<2xi1>) -> tensor<2xi1>
// %result: [false, true]

Contoh Lainnya

optimization_barrier

Semantik

Memastikan operasi yang menghasilkan operand dijalankan sebelum operasi yang bergantung pada result dan mencegah transformasi compiler agar tidak memindahkan operasi melintasi penghalang. Selain itu, operasinya identitas, yaitu result = operand.

Argumen

Nama Jenis Batasan
operand jumlah tensor, tensor atau token terkuantisasi per-tensor (C1)

Output

Nama Jenis Batasan
result jumlah tensor, tensor atau token terkuantisasi per-tensor (C1)

Batasan

  • (C1) type(operand...) = type(result...).

Contoh

// %operand0: 0.0
// %operand1: 1.0
%result0, %result1 = "stablehlo.optimization_barrier"(%operand0, %operand1) : (tensor<f32>, tensor<f32>) -> (tensor<f32>, tensor<f32>)
// %result0: 0.0
// %result1: 1.0

Contoh Lainnya

atau

Semantik

Melakukan OR berdasarkan elemen dari dua tensor lhs dan rhs serta menghasilkan result tensor. Bergantung pada jenis elemen, melakukan hal berikut:

  • Untuk boolean: OR logis.
  • Untuk bilangan bulat: bitwise OR.

Input

Label Nama Jenis Batasan
(I1) lhs tensor integer atau tipe boolean (C1)
(I2) rhs tensor integer atau tipe boolean (C1)

Output

Nama Jenis Batasan
result tensor integer atau tipe boolean (C1)

Batasan

  • (C1) type(lhs) = type(rhs) = type(result).

Contoh

// Bitwise operation with with integer tensors
// %lhs: [[1, 2], [3, 4]]
// %rhs: [[5, 6], [7, 8]]
%result = "stablehlo.or"(%lhs, %rhs) : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
// %result: [[5, 6], [7, 12]]

// Logical operation with with boolean tensors
// %lhs: [[false, false], [true, true]]
// %rhs: [[false, true], [false, true]]
%result = "stablehlo.or"(%lhs, %rhs) : (tensor<2x2xi1>, tensor<2x2xi1>) -> tensor<2x2xi1>
// %result: [[false, true], [true, true]]

Contoh Lainnya

outfeed

Semantik

Menulis inputs ke outfeed dan menghasilkan token result.

Semantik outfeed_config adalah ditentukan oleh implementasi.

Input

Label Nama Jenis
(I1) inputs jumlah tensor atau tensor terkuantisasi
(I2) token token
(I3) outfeed_config konstanta jenis string

Output

Nama Jenis
result token

Contoh

%result = "stablehlo.outfeed"(%input0, %token) {
  outfeed_config = ""
} : (tensor<2x2x2xi64>, !stablehlo.token) -> !stablehlo.token

Contoh Lainnya

alas

Semantik

Memperluas operand dengan padding di sekitar tensor serta di antara elemen tensor dengan padding_value yang diberikan.

edge_padding_low dan edge_padding_high menentukan jumlah padding yang ditambahkan di kelas bawah (di samping indeks 0) dan kelas atas (di samping indeks tertinggi) dalam masing-masing dimensi. Jumlah padding bisa negatif, di mana nilai absolut padding negatif menunjukkan jumlah elemen yang akan dihapus dari dimensi yang ditentukan.

interior_padding menentukan jumlah padding yang ditambahkan di antara dua padding di setiap dimensi yang mungkin tidak negatif. Padding interior terjadi sebelum pelapis tepi sehingga pelapis tepi negatif akan membuang elemen dari operand dengan padding interior.

Secara lebih formal, result[result_index] didefinisikan sebagai:

  • operand[operand_index] jika result_index = edge_padding_low + operand_index * (interior_padding + 1).
  • padding_value sebaliknya.

Input

Label Nama Jenis Batasan
(I1) operand tensor atau tensor terkuantisasi per-tensor (C1), (C2), (C4)
(I2) padding_value Tensor 0-dimensi atau per-tensor quantized tensor (C1)
(I3) edge_padding_low Konstanta tensor 1 dimensi jenis si64 (C1), C4
(I4) edge_padding_high Konstanta tensor 1 dimensi jenis si64 (C1), C4
(I5) interior_padding Konstanta tensor 1 dimensi jenis si64 (C2-C4)

Output

Nama Jenis Batasan
result tensor atau tensor terkuantisasi per-tensor (C3-C6)

Batasan

  • (C1) element_type(operand) = element_type(padding_value) = element_type(result).
  • (C2) size(edge_padding_low) = size(edge_padding_high) = size(interior_padding) = rank(operand).
  • (C3) 0 <= interior_padding.
  • (C4) shape(result) = shape(operand) + edge_padding_low + max(shape(operand) - 1, 0) * interior_padding + edge_padding_high.

Contoh

// %operand: [
//            [1, 2, 3],
//            [4, 5, 6]
//           ]
// %padding_value: 0
%result = "stablehlo.pad"(%operand, %padding_value) {
  edge_padding_low = array<i64: 0, 1>,
  edge_padding_high = array<i64: 2, 1>,
  interior_padding = array<i64: 1, 2>
} : (tensor<2x3xi32>, tensor<i32>) -> tensor<5x9xi32>
// %result: [
//           [0, 1, 0, 0, 2, 0, 0, 3, 0],
//           [0, 0, 0, 0, 0, 0, 0, 0, 0],
//           [0, 4, 0, 0, 5, 0, 0, 6, 0],
//           [0, 0, 0, 0, 0, 0, 0, 0, 0],
//           [0, 0, 0, 0, 0, 0, 0, 0, 0]
//          ]

Contoh Lainnya

partition_id

Semantik

Menghasilkan partition_id dari proses saat ini.

Output

Nama Jenis
result Tensor 0-dimensi jenis ui32

Contoh

%result = "stablehlo.partition_id"() : () -> tensor<ui32>

Contoh Lainnya

Popcnt

Semantik

Melakukan penghitungan jumlah bit berdasarkan elemen pada tensor operand dan menghasilkan tensor result.

Input

Label Nama Jenis Batasan
(I1) operand tensor tipe integer (C1)

Output

Nama Jenis Batasan
result tensor tipe integer (C1)

Batasan

  • (C1) type(operand) = type(result).

Contoh

// %operand: [0, 1, 2, 127]
%result = "stablehlo.popcnt"(%operand) : (tensor<4xi64>) -> tensor<4xi64>
// %result: [0, 1, 1, 7]

Contoh Lainnya

daya

Semantik

Melakukan eksponensial element-wise dari tensor lhs dengan tensor rhs dan menghasilkan tensor result. Bergantung pada jenis elemen, melakukan hal berikut:

  • Untuk bilangan bulat: eksponensial bilangan bulat.
  • Untuk float: pow dari IEEE-754.
  • Untuk bilangan kompleks: eksponensial kompleks.
  • Untuk jenis terkuantisasi: dequantize_op_quantize(power, lhs, rhs, type(result)).

Input

Label Nama Jenis Batasan
(I1) lhs tensor bilangan bulat, floating point, atau tipe kompleks atau tensor terkuantisasi per-tensor (C1)
(I2) rhs tensor bilangan bulat, floating point, atau tipe kompleks atau tensor terkuantisasi per-tensor (C1)

Output

Nama Jenis Batasan
result tensor bilangan bulat, floating point, atau tipe kompleks atau tensor terkuantisasi per-tensor (C1)

Batasan

  • (C1) baseline_type(operand) = baseline_type(result).

Contoh

// %lhs: [-2.0, -0.0, -36.0, 5.0, 3.0, 10000.0]
// %rhs: [2.0, 2.0, 1.1, 2.0, -1.0, 10.0]
%result = "stablehlo.power"(%lhs, %rhs) : (tensor<6xf64>, tensor<6xf64>) -> tensor<6xf64>
// %result: [4.0, 0.0, -nan, 25.0, 0.333333343, inf]

Contoh Lainnya

real

Semantik

Mengekstrak bagian asli, berdasarkan elemen, dari operand dan menghasilkan result tensor. Secara lebih formal, untuk setiap elemen x: real(x) = is_complex(x) ? real_part(x) : x.

Input

Label Nama Jenis Batasan
(I1) operand tensor floating point atau jenis kompleks (C1), C2

Output

Nama Jenis Batasan
result tensor jenis floating point (C1), C2

Batasan

  • (C1) shape(result) = shape(operand).
  • (C2) element_type(result) didefinisikan sebagai:
    • complex_element_type(element_type(operand)) jika is_complex(operand).
    • element_type(operand) sebaliknya.

Contoh

// %operand: [(1.0, 2.0), (3.0, 4.0)]
%result = "stablehlo.real"(%operand) : (tensor<2xcomplex<f32>>) -> tensor<2xf32>
// %result: [1.0, 3.0]

Contoh Lainnya

rekomendasi

Semantik

Menerima data dari saluran dengan channel_id dan menghasilkan results.

Jika is_host_transfer adalah true, operasi tersebut akan mentransfer data dari {i>host<i}. Jika tidak, sistem akan mentransfer data dari perangkat lain. Artinya adalah ditentukan oleh implementasi. Penanda ini menduplikasi informasi yang diberikan di channel_type, jadi di masa mendatang kami berencana untuk menyimpan salah satunya saja (#666).

results terdiri dari nilai payload yang muncul terlebih dahulu dan token yang muncul terakhir. Di masa mendatang, kami berencana untuk membagi payload dan token menjadi dua {i>output<i} terpisah untuk meningkatkan kejelasan (#670).

Input

Label Nama Jenis Batasan
(I1) token token (C4)
(I2) channel_id konstanta jenis si64
(I3) channel_type enum DEVICE_TO_DEVICE dan HOST_TO_DEVICE (C1)
(I4) is_host_transfer konstanta jenis i1 (C1)

Output

Nama Jenis Batasan
results jumlah tensor, tensor atau token terkuantisasi (C2-C4)

Batasan

  • (C1) channel_type didefinisikan sebagai:
    • HOST_TO_DEVICE jika is_host_transfer = true,
    • DEVICE_TO_DEVICE sebaliknya.
  • (C2) 0 < size(results).
  • (C3) is_empty(result[:-1]) atau is_tensor(type(results[:-1])).
  • (C4) is_token(type(results[-1])).

Contoh

%results0, %results1 = "stablehlo.recv"(%token) {
  channel_handle = #stablehlo.channel_handle<handle = 1, type = 3>,
  is_host_transfer = true
} : (!stablehlo.token) -> (tensor<2x2xi64>, !stablehlo.token)

Contoh Lainnya

reduce

Semantik

Menerapkan fungsi pengurangan body ke inputs dan init_values di sepanjang dimensions dan menghasilkan tensor results.

Urutan pengurangan ditentukan oleh implementasi, yang berarti bahwa body dan init_values harus membentuk monoid untuk menjamin bahwa operasi tersebut menghasilkan hasil yang sama untuk semua input pada semua implementasi. Namun, kondisi ini tidak dapat digunakan untuk banyak pengurangan populer. Mis. penambahan floating point untuk body dan nol untuk init_values tidak benar-benar membentuk monoid karena penambahan floating point tidak asosiatif.

Secara lebih formal, results...[j0, ..., jR-1] = reduce(input_slices_converted) dengan:

  • input_slices = inputs...[j0, ..., :, ..., jR-1], tempat : disisipkan pukul dimensions.
  • input_slices_converted = to_destination_type(input_slices..., type(func_inputs(body)[:len(func_inputs(body))//2])...).
  • init_values_converted = to_destination_type(init_values..., type(func_inputs(body)[len(func_inputs(body))//2:])...).
  • reduce(input_slices_converted) = exec(schedule) untuk beberapa hierarki biner schedule dengan:
    • exec(node) = body(exec(node.left), exec(node.right)).
    • exec(leaf) = leaf.value.
  • schedule adalah hierarki biner lengkap yang ditentukan implementasi yang sesuai urutan traversal terdiri dari:
    • Nilai input_slices_converted...[index], untuk semua index di index_space(input_slices_converted) dalam urutan leksikografis menaik dari index.
    • Diselingi dengan sejumlah init_values_converted di posisi yang ditentukan penerapan.

Input

Label Nama Jenis Batasan
(I1) inputs jumlah tensor atau tensor terkuantisasi per-tensor (C1-C4), (C6), (C7)
(I2) init_values jumlah variadik tensor 0-dimensi atau tensor terkuantisasi per-tensor (C2), (C3)
(I3) dimensions Konstanta tensor 1 dimensi jenis si64 (C4), (C5), (C7)
(I4) body fungsi (C6)

Output

Nama Jenis Batasan
results jumlah tensor atau tensor terkuantisasi per-tensor (C3), (C7), (C8)

Batasan

  • (C1) same(shape(inputs...)).
  • (C2) element_type(inputs...) = element_type(init_values...).
  • (C3) 0 < size(inputs) = size(init_values) = size(results) = N.
  • (C4) 0 <= dimensions < rank(inputs[0]).
  • (C5) is_unique(dimensions).
  • (C6) body memiliki jenis (tensor<E0>, ..., tensor<EN-1>, tensor<E0>, ..., tensor<EN-1>) -> (tensor<E0>, ..., tensor<EN-1>) di mana is_promotable(element_type(inputs[i]), Ei).
  • (C7) shape(results...) = shape(inputs...) kecuali bahwa dimensi ukuran inputs... yang sesuai dengan dimensions tidak disertakan.
  • (C8) element_type(results[i]) = Ei untuk semua i di [0,N).

Contoh

// %input = [[0, 1, 2, 3, 4, 5]]
// %init_value = 0
%result = "stablehlo.reduce"(%input, %init_value) ({
  ^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
    %0 = "stablehlo.add"(%arg0, %arg1) : (tensor<i64>, tensor<i64>) -> tensor<i64>
    "stablehlo.return"(%0) : (tensor<i64>) -> ()
}) {
  dimensions = array<i64: 1>
} : (tensor<1x6xi64>, tensor<i64>) -> tensor<1xi64>
// %result = [15]

Contoh Lainnya

reduce_precision

Semantik

Melakukan konversi operand berbasis elemen ke jenis floating point lainnya yang menggunakan exponent_bits dan mantissa_bits serta kembali ke versi aslinya jenis floating point dan menghasilkan tensor output.

Secara lebih formal:

  • Bit mantissa nilai asli diperbarui untuk membulatkan nilai asli ke nilai terdekat yang dapat diwakili dengan mantissa_bits menggunakan Semantik roundToIntegralTiesToEven.
  • Kemudian, jika mantissa_bits lebih kecil dari jumlah bit mantissa nilai asli, bit mantissa akan dipotong menjadi mantissa_bits.
  • Lalu, jika bit eksponen dari hasil perantara tidak sesuai dengan yang disediakan oleh exponent_bits, hasil perantara akan melebihi tak terhingga menggunakan tanda awal atau {i>underflows<i} ke nol menggunakan tanda aslinya.
  • Untuk jenis terkuantisasi, menjalankan dequantize_op_quantize( lambda operand: reduce_precision(operand, exponent_bits, mantissa_bits), operand, type(result)).

Input

Label Nama Jenis Batasan
(I1) operand tensor tipe floating point atau tensor terkuantisasi per-tensor (C1)
(I2) exponent_bits konstanta jenis si32 (C2)
(I3) mantissa_bits konstanta jenis si32 (C3)

Output

Nama Jenis Batasan
output tensor tipe floating point atau tensor terkuantisasi per-tensor (C1)

Batasan

  • (C1) baseline_type(operand) = baseline_type(output).
  • (C2) 1 <= exponent_bits.
  • (C3) 0 <= mantissa_bits.

Contoh

// Logical values: +Inf, NaN, +Denormal, 0.0, 65519.0, 65520.0
// %operand: [0x7FF0000000000000, 0x7FFFFFFFFFFFFFFF, 0x0000000000000001, 0.0, 65519.0, 65520.0]
%output = "stablehlo.reduce_precision"(%operand) {
  exponent_bits = 5 : i32,
  mantissa_bits = 10 : i32
} : (tensor<6xf64>) -> tensor<6xf64>
// Logical values: +Inf, NaN, 0.0, 0.0, 65504.0, +Inf
// %output: [0x7FF0000000000000, 0x7FFFFFFFFFFFFFFF, 0.0, 0.0, 65504.0, 0x7FF0000000000000]

Contoh Lainnya

reduce_scatter

Semantik

reduce_scatter

Dalam setiap grup proses di {i>process grid<i} StabilHLO, melakukan pengurangan, menggunakan computations, terhadap nilai tensor operand dari setiap proses, membagi hasil pengurangan di sepanjang scatter_dimension menjadi beberapa bagian, dan sebar bagian yang terpisah di antara proses untuk menghasilkan result.

Operasi tersebut membagi grid proses StableHLO menjadi process_groups yang didefinisikan sebagai berikut:

  • cross_replica(replica_groups) jika channel_id <= 0 and use_global_device_ids = false.
  • cross_replica_and_partition(replica_groups) jika channel_id > 0 and use_global_device_ids = false.
  • flattened_ids(replica_groups) jika channel_id > 0 and use_global_device_ids = true.

Setelah itu, dalam setiap process_group:

  • reduced_value = all_reduce(operand, replica_groups, channel_id, use_global_device_ids, computation).
  • parts@sender = split(reduced_value@sender, dim(process_groups, 1), scatter_dimension).
  • result@receiver = parts@sender[receiver_index] untuk semua sender di process_group, dengan receiver_index = process_group.index(receiver).

Input

Label Nama Jenis Batasan
(I1) operand tensor atau tensor terkuantisasi per-tensor (C1), (C2), (C7), (C8)
(I2) scatter_dimension konstanta jenis si64 (C1), C2 (C8), dan (C8)
(I3) replica_groups Konstanta tensor 2 dimensi jenis si64 (C3-C5)
(I4) channel_id konstanta jenis si64 (C6)
(I5) use_global_device_ids konstanta jenis i1 (C6)
(I6) computation fungsi (C7)

Output

Nama Jenis Batasan
result tensor atau tensor terkuantisasi per-tensor (C8-C9)

Batasan

  • (C1) dim(operand, scatter_dimension) % dim(process_groups, 1) = 0.
  • (C2) 0 <= scatter_dimension < rank(operand).
  • (C3) is_unique(replica_groups).
  • (C4) size(replica_groups) ditentukan sebagai:
    • num_replicas jika cross_replica digunakan.
    • num_replicas jika cross_replica_and_partition digunakan.
    • num_processes jika flattened_ids digunakan.
  • (C5) 0 <= replica_groups < size(replica_groups).
  • (C6) Jika use_global_device_ids = true, maka channel_id > 0.
  • (C7) computation memiliki jenis (tensor<E>, tensor<E>) -> (tensor<E>), dengan is_promotable(element_type(operand), E).
  • (C8) shape(result) = shape(operand) kecuali:
    • dim(result, scatter_dimension) = dim(operand, scatter_dimension) / dim(process_groups, 1).
  • (C9) element_type(result) = E.

Contoh

// num_replicas: 2
// num_partitions: 1
// %operand@(0, 0): [[1, 2, 3, 4],
//                   [5, 6, 7, 8]]
// %operand@(1, 0): [[9, 10, 11, 12],
//                   [13, 14, 15, 16]]
%result = "stablehlo.reduce_scatter"(%operand) ({
  ^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
  %0 = "stablehlo.add"(%arg0, %arg1) : (tensor<i64>, tensor<i64>) -> tensor<i64>
  "stablehlo.return"(%0) : (tensor<i64>) -> ()
}) {
  scatter_dimension = 1 : i64,
  replica_groups = dense<[[0, 1]]> : tensor<1x2xi64>,
  channel_handle = #stablehlo.channel_handle<handle = 0, type = 0>
} : (tensor<2x4xi64>) -> tensor<2x2xi64>
//
// %result@(0, 0): [[10, 12],
//                  [18, 20]]
// %result@(1, 0): [[14, 16],
//                  [22, 24]]

Contoh Lainnya

reduce_window

Semantik

Menerapkan fungsi pengurangan body ke jendela inputs dan init_values dan menghasilkan results.

Diagram berikut menunjukkan cara penghitungan elemen pada results... inputs... menggunakan contoh konkret.

reduce_window

Secara lebih formal, results...[result_index] = reduce(windows, init_values, axes(inputs...), body) (lihat mengurangi) jika:

  • padded_inputs = pad(inputs..., init_values..., padding[:, 0], padding[:, 1], base_dilations - 1).
  • window_start = result_index * window_strides.
  • window_end = window_start + (window_dimensions - 1) * window_dilations + 1.
  • windows = slice(padded_inputs..., window_start, window_end, window_dilations).

Input

Label Nama Jenis Batasan
(I1) inputs jumlah tensor atau tensor terkuantisasi per-tensor (C1-C4), (C6), (C8), (C10), (C12), (C13), (C15)
(I2) init_values jumlah variadik tensor 0-dimensi atau tensor terkuantisasi per-tensor (C1), C13
(I3) window_dimensions Konstanta tensor 1 dimensi jenis si64 (C4), (C5), (C15)
(I4) window_strides Konstanta tensor 1 dimensi jenis si64 (C6), (C7), (C15)
(I5) base_dilations Konstanta tensor 1 dimensi jenis si64 (C8), (C9), (C15)
(I6) window_dilations Konstanta tensor 1 dimensi jenis si64 (C10), (C11), (C15)
(I7) padding Konstanta tensor 2 dimensi jenis si64 (C12), (C15)
(I8) body fungsi (C13)

Output

Nama Jenis Batasan
results jumlah tensor atau tensor terkuantisasi per-tensor (C1), (C14-C16)

Batasan

  • (C1) 0 < size(inputs) = size(init_values) = size(results) = N.
  • (C2) same(shape(inputs...)).
  • (C3) element_type(inputs...) = element_type(init_values...).
  • (C4) size(window_dimensions) = rank(inputs[0]).
  • (C5) 0 < window_dimensions.
  • (C6) size(window_strides) = rank(inputs[0]).
  • (C7) 0 < window_strides.
  • (C8) size(base_dilations) = rank(inputs[0]).
  • (C9) 0 < base_dilations.
  • (C10) size(window_dilations) = rank(inputs[0]).
  • (C11) 0 < window_dilations.
  • (C12) shape(padding) = [rank(inputs[0]), 2].
  • (C13) body memiliki jenis (tensor<E0>, ..., tensor<EN-1>, tensor<E0>, ..., tensor<EN-1>) -> (tensor<E0>, ..., tensor<EN-1>) di mana is_promotable(element_type(inputs[i]), Ei).
  • (C14) same(shape(results...)).
  • (C15) shape(results[0]) = num_windows jika:
    • dilated_input_shape = shape(inputs[0]) = 0 ? 0 : (shape(inputs[0]) - 1) * base_dilations + 1.
    • padded_input_shape = padding[:, 0] + dilated_input_shape + padding[:, 1].
    • dilated_window_shape = (window_dimensions - 1) * window_dilations + 1.
    • is_empty_window = padded_input_shape = 0 || dilated_window_shape > padded_input_shape.
    • num_windows = is_empty_window ? 0 : floor((padded_input_shape - dilated_window_shape) / window_strides) + 1.
  • (C16) element_type(results[i]) = Ei untuk semua i di [0,N).

Contoh

// %input = [[1, 2], [3, 4], [5, 6]]
// %init_value = 0
%result = "stablehlo.reduce_window"(%input, %init_value) ({
  ^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
    %0 = "stablehlo.add"(%arg0, %arg1) : (tensor<i64>, tensor<i64>) -> tensor<i64>
    "stablehlo.return"(%0) : (tensor<i64>) -> ()
}) {
  window_dimensions = array<i64: 2, 1>,
  window_strides = array<i64: 4, 1>,
  base_dilations = array<i64: 2, 1>,
  window_dilations = array<i64: 3, 1>,
  padding = dense<[[2, 1], [0, 0]]> : tensor<2x2xi64>
} : (tensor<3x2xi64>, tensor<i64>) -> tensor<2x2xi64>
// %result = [[0, 0], [3, 4]]

Contoh Lainnya

sisa

Semantik

Melakukan sisa tensor dividen lhs dan pembagi rhs serta menghasilkan tensor result.

Secara lebih formal, tanda hasil diambil dari pembagian, dan nilai absolut hasil selalu lebih kecil dari nilai absolut pembagi. Sisanya dihitung sebagai lhs - d * rhs, dengan d diberikan oleh:

  • Untuk bilangan bulat: stablehlo.divide(lhs, rhs).
  • Untuk float: division(lhs, rhs) dari IEEE-754 dengan atribut pembulatan roundTowardZero.
  • Untuk bilangan kompleks: TBD (#997).
  • Untuk jenis terkuantisasi:
    • dequantize_op_quantize(remainder, lhs, rhs, type(result)).

Untuk jenis elemen floating point, operasi ini berbeda dengan Operasi remainder dari spesifikasi IEEE-754 dengan d sebagai nilai integral terdekat dengan nilai pasti lhs/rhs yang memiliki nilai yang sama.

Input

Label Nama Jenis Batasan
(I1) lhs tensor integer, floating-point, atau tipe kompleks atau tensor terkuantisasi per-tensor (C1)
(I2) rhs tensor integer, floating-point, atau tipe kompleks atau tensor terkuantisasi per-tensor (C1)

Output

Nama Jenis Batasan
result tensor integer, floating-point, atau tipe kompleks atau tensor terkuantisasi per-tensor (C1)

Batasan

  • (C1) baseline_type(operand) = baseline_type(result).

Contoh

// %lhs: [17, -17, 17, -17]
// %rhs: [3, 3, -3, -3]
%result = "stablehlo.remainder"(%lhs, %rhs) : (tensor<4xi64>, tensor<4xi64>) -> tensor<4xi64>
// %result: [2, -2, 2, -2]

Contoh Lainnya

replica_id

Semantik

Menghasilkan replica_id dari proses saat ini.

Output

Nama Jenis
result Tensor 0-dimensi jenis ui32

Contoh

%result = "stablehlo.replica_id"() : () -> tensor<ui32>

Contoh Lainnya

bentuk ulang

Semantik

Melakukan pembentukan ulang tensor operand ke tensor result. Secara konseptual, sama dengan mempertahankan representasi kanonis yang sama, tetapi berpotensi mengubah bentuknya, misalnya dari tensor<2x3xf32> hingga tensor<3x2xf32> atau tensor<6xf32>.

Secara lebih formal, result[result_index] = operand[operand_index] di mana result_index dan operand_index memiliki posisi yang sama dalam leksikografis urutan index_space(result) dan index_space(operand).

Input

Label Nama Jenis Batasan
(I1) operand tensor atau quantized tensor (C1-C3)

Output

Nama Jenis Batasan
result tensor atau quantized tensor (C1-C3)

Batasan

  • (C1) element_type(result) diberikan oleh:
    • element_type(operand), jika !is_per_axis_quantized(operand).
    • element_type(operand) kecuali quantization_dimension(operand) dan quantization_dimension(result) dapat berbeda, jika tidak.
  • (C2) size(operand) = size(result).
  • (C3) Jika is_per_axis_quantized(operand):
    • reduce(dims(operand, [0, 1, ..., quantization_dimension(operand) - 1]), init_values=1, dimensions=[0], body=lambda x, y: x * y) = reduce(dims(result, [0, 1, ..., quantization_dimension(result) - 1]), init_values=1, dimensions=[0], body=lambda x, y: x * y).
    • dim(operand, quantization_dimension(operand)) = dim(result, quantization_dimension(result)).
    • reduce(dims(operand, [quantization_dimension(operand) + 1, ..., rank(operand) - 1]), init_values=1, dimensions=[0], body=lambda x, y: x * y) = reduce(dims(result, [quantization_dimension(result) + 1, ..., rank(result) - 1]), init_values=1, dimensions=[0], body=lambda x, y: x * y).

Contoh

// %operand: [[1, 2, 3], [4, 5, 6]]
%result = "stablehlo.reshape"(%operand) : (tensor<2x3xi32>) -> tensor<3x2xi32>
// %result: [[1, 2], [3, 4], [5, 6]]

Contoh Lainnya

balik

Semantik

Membalik urutan elemen dalam operand di sepanjang dimensions yang ditentukan dan menghasilkan tensor result. Secara lebih formal, result[result_index] = operand[operand_index] dalam hal ini:

  • operand_index[d] = dim(result, d) - result_index[d] - 1 jika d dalam dimensions.
  • operand_index[d] = result_index[d] sebaliknya.

Input

Label Nama Jenis Batasan
(I1) operand tensor atau tensor terkuantisasi per-tensor (C1), C3
(I2) dimensions Konstanta tensor 1 dimensi jenis si64 (C2), (C3)

Output

Nama Jenis Batasan
result tensor atau tensor terkuantisasi per-tensor (C1), C3

Batasan

  • (C1) type(operand) = type(result).
  • (C2) is_unique(dimensions).
  • (C3) 0 <= dimensions < rank(result).

Contoh

// %operand = [[1, 2], [3, 4], [5, 6]]
%result = "stablehlo.reverse"(%operand) {
  dimensions = array<i64: 1>
} : (tensor<3x2xi32>) -> tensor<3x2xi32>
// %result: [[2, 1], [4, 3], [6, 5]]

Contoh Lainnya

rng

Semantik

Menghasilkan angka acak menggunakan algoritma rng_distribution dan menghasilkan Tensor result dari bentuk tertentu shape.

Jika rng_distribution = UNIFORM, angka acak akan dibuat setelah distribusi seragam selama interval [a, b). Jika a >= b, perilakunya tidak terdefinisi.

Jika rng_distribution = NORMAL, angka acak akan dibuat mengikuti distribusi normal dengan rata-rata = a dan deviasi standar = b. Jika b < 0, perilaku tidak ditentukan.

Cara persis bagaimana angka acak dibuat ditentukan oleh implementasi. Sebagai misalnya, mereka mungkin determenistik, dan mereka mungkin atau tidak status tersembunyi.

Dalam percakapan dengan banyak pemangku kepentingan, operasi ini telah muncul seefektif tidak digunakan lagi, jadi di masa mendatang kami berencana untuk mencoba menghapusnya (#597).

Input

Label Nama Jenis Batasan
(I1) a Tensor 0-dimensi dari jenis integer, boolean, atau floating point (C1), C2
(I2) b Tensor 0-dimensi dari jenis integer, boolean, atau floating point (C1), C2
(I3) shape Konstanta tensor 1 dimensi jenis si64 (C3)
(I4) rng_distribution enum UNIFORM dan NORMAL (C2)

Output

Nama Jenis Batasan
result tensor bilangan bulat, boolean, atau jenis floating point (C1-C3)

Batasan

  • (C1) element_type(a) = element_type(b) = element_type(result).
  • (C2) Jika rng_distribution = NORMAL, maka is_float(a).
  • (C3) shape(result) = shape.

Contoh

// %a = 0
// %b = 2
// %shape = [3, 3]
%result = "stablehlo.rng"(%a, %b, %shape) {
  rng_distribution = #stablehlo<rng_distribution UNIFORM>
} : (tensor<i32>, tensor<i32>, tensor<2xi64>) -> tensor<3x3xi32>
// %result: [
//           [1, 0, 1],
//           [1, 1, 1],
//           [0, 0, 0]
//          ]

rng_bit_generator

Semantik

Menampilkan output yang diisi dengan bit acak yang seragam dan status output yang diperbarui output_state menggunakan algoritma generator angka pseudorandom rng_algorithm diberi status awal initial_state. Outputnya dijamin akan fungsi determenistik initial_state, tetapi tidak dijamin akan determenistik di antara implementasi.

rng_algorithm adalah salah satu dari berikut ini:

  • DEFAULT: Algoritma yang ditentukan implementasi.
  • THREE_FRY: Varian yang ditentukan berdasarkan implementasi dari algoritma Threefry.*
  • PHILOX: Varian algoritma Philox yang ditentukan berdasarkan implementasi.*

* Lihat: Salmon et al. SC 2011. Bilangan acak paralel: semudah 1, 2, 3.

Input

Label Nama Jenis Batasan
(I1) rng_algorithm enum DEFAULT, THREE_FRY, dan PHILOX (C2)
(I2) initial_state Tensor 1 dimensi dari jenis ui64 (C1), C2

Output

Nama Jenis Batasan
output_state Tensor 1 dimensi dari jenis ui64 (C1)
output tensor jenis bilangan bulat atau floating point

Batasan

  • (C1) type(initial_state) = type(output_state).
  • (C2) size(initial_state) didefinisikan sebagai:
    • yang ditentukan implementasi jika rng_algorithm = DEFAULT.
    • 2 jika rng_algorithm = THREE_FRY.
    • 2 atau 3 jika rng_algorithm = PHILOX.

Contoh

// %initial_state: [1, 2]
%output_state, %output = "stablehlo.rng_bit_generator"(%initial_state) {
  rng_algorithm = #stablehlo<rng_algorithm THREE_FRY>
} : (tensor<2xui64>) -> (tensor<2xui64>, tensor<2x2xui64>)
// %output_state: [1, 6]
// %output: [
//           [9236835810183407956, 16087790271692313299],
//           [18212823393184779219, 2658481902456610144]
//          ]

round_nearest_afz

Semantik

Melakukan pembulatan berbasis elemen ke bilangan bulat terdekat, memisahkan ikatan dari nol, pada tensor operand dan menghasilkan tensor result. Implementasi operasi roundToIntegralTiesToAway dari spesifikasi IEEE-754. Sebagai jenis terkuantisasi, melakukan dequantize_op_quantize(round_nearest_afz, operand, type(result)).

Input

Label Nama Jenis Batasan
(I1) operand tensor tipe floating point atau tensor terkuantisasi per-tensor (C1)

Output

Nama Jenis Batasan
result tensor tipe floating point atau tensor terkuantisasi per-tensor (C1)

Batasan

  • (C1) baseline_type(operand) = baseline_type(result).

Contoh

// %operand = [-2.5, 0.4, 0.5, 0.6, 2.5]
%result = "stablehlo.round_nearest_afz"(%operand) : (tensor<5xf64>) -> tensor<5xf64>
// %result: [-3.0, 0.0, 1.0, 1.0, 3.0]

Contoh Lainnya

round_nearest_even

Semantik

Melakukan pembulatan berbasis elemen ke bilangan bulat terdekat, yang memutus ikatan terhadap bilangan bulat genap, pada tensor operand dan menghasilkan result tensor. Menerapkan operasi roundToIntegralTiesToEven dari IEEE-754 spesifikasi pendukung. Untuk jenis terkuantisasi, menjalankan dequantize_op_quantize(round_nearest_even, operand, type(result)).

Input

Label Nama Jenis Batasan
(I1) operand tensor tipe floating point atau tensor terkuantisasi per-tensor (C1)

Output

Nama Jenis Batasan
result tensor tipe floating point atau tensor terkuantisasi per-tensor (C1)

Batasan

  • (C1) baseline_type(operand) = baseline_type(result).

Contoh

// %operand = [-2.5, 0.4, 0.5, 0.6, 2.5]
%result = "stablehlo.round_nearest_even"(%operand) : (tensor<5xf64>) -> tensor<5xf64>
// %result: [-2.0, 0.0, 0.0, 1.0, 2.0]

Contoh Lainnya

rsqrt

Semantik

Melakukan operasi root kuadrat timbal balik berdasarkan elemen pada tensor operand dan menghasilkan tensor result. Bergantung pada jenis elemen, melakukan hal berikut:

  • Untuk float: rSqrt dari IEEE-754.
  • Untuk bilangan kompleks: akar kuadrat timbal balik kompleks.
  • Untuk jenis terkuantisasi: dequantize_op_quantize(rsqrt, operand, type(result)).

Input

Label Nama Jenis Batasan
(I1) operand tensor floating point atau tipe kompleks atau tensor terkuantisasi per-tensor (C1)

Output

Nama Jenis Batasan
result tensor floating point atau tipe kompleks atau tensor terkuantisasi per-tensor (C1)

Batasan

  • (C1) baseline_type(operand) = baseline_type(result).

Contoh

// %operand: [[1.0, 4.0], [9.0, 25.0]]
%result = "stablehlo.rsqrt"(%operand) : (tensor<2x2xf32>) -> tensor<2x2xf32>
// %result: [[1.0, 0.5], [0.33333343, 0.2]]

Contoh Lainnya

{i>scatter <i}(memencar)

Semantik

Menghasilkan tensor results yang sama dengan tensor inputs, kecuali bahwa beberapa irisan yang ditentukan oleh scatter_indices diperbarui dengan nilai updates menggunakan update_computation.

Diagram berikut menunjukkan cara elemen di updates... dipetakan pada elemen di results... menggunakan contoh konkret. Diagram ini memilih beberapa contoh updates... membuat indeks dan menjelaskan secara mendetail results... mana yang menghitungnya yang sesuai.

{i>scatter <i}(memencar)

Secara lebih formal, untuk semua update_index di index_space(updates[0]):

  • update_scatter_dims = [d for d in axes(updates[0]) and d not in update_window_dims].
  • update_scatter_index = update_index[update_scatter_dims...].
  • start_index ditentukan sebagai:
    • scatter_indices[si0, ..., :, ..., siN] dengan si merupakan individu elemen dalam update_scatter_index dan : dimasukkan di indeks index_vector_dim, jika index_vector_dim < rank(scatter_indices).
    • [scatter_indices[update_scatter_index]] sebaliknya.
  • Untuk d_input di axes(inputs[0]),
    • full_start_index[d_input] = start_index[d_start] jika d_input = scatter_dims_to_operand_dims[d_start].
    • full_start_index[d_input] = 0 sebaliknya.
  • Untuk d_input di axes(inputs[0]),
    • full_batching_index[d_input] = update_scatter_index[d_start - (d_start < index_vector_dim ? 0 : 1)] jika d_input = input_batching_dims[i_batching] dan d_start = scatter_indices_batching_dims[i_batching].
    • full_batching_index[d_input] = 0 sebaliknya.
  • update_window_index = update_index[update_window_dims...].
  • full_window_index = [wi0, ..., 0, ..., wiN] dengan wi merupakan individu di update_window_index, dan 0 disisipkan pada indeks dari inserted_window_dims dan input_batching_dims.
  • result_index = full_start_index + full_batching_index + full_window_index.

Oleh karena itu, results = exec(schedule, inputs), dalam hal:

  • schedule adalah permutasi yang ditentukan implementasi dari index_space(updates[0]).
  • exec([update_index, ...], results) = exec([...], updated_results) dalam hal ini:
    • Jika result_index termasuk dalam batas shape(results...)
    • updates_converted = to_destination_type( updates...[update_index], type(func_inputs(update_computation) [len(func_inputs(update_computation))//2:])... )
    • updated_values = update_computation(results...[result_index], updates_converted)
    • updated_results adalah salinan results dengan results...[result_index] disetel ke updated_values....
    • Atau
    • updated_results = results.
  • exec([], results) = results.

Jika indices_are_sorted adalah true, implementasi dapat mengasumsikan bahwa scatter_indices diurutkan sehubungan dengan scatter_dims_to_operand_dims, jika tidak, perilaku tidak terdefinisi. Secara lebih formal, untuk semua i1 < i2 dari indices(result), full_start_index(i1) <= full_start_index(i2).

Jika unique_indices adalah true, implementasi dapat mengasumsikan bahwa semua Indeks result_index yang tersebar bersifat unik. Jika unique_indices adalah true, tetapi indeks yang tersebar tidak unik maka perilakunya adalah tidak terdefinisi.

Input

Label Nama Jenis Batasan
(I1) inputs jumlah tensor atau tensor terkuantisasi per-tensor (C1), (C2), (C4-C6), (C11), (C13), (C18), (C21), (C23-C24)
(I2) scatter_indices tensor tipe integer (C4), (C15), (C19), (C22)
(I3) updates jumlah tensor atau tensor terkuantisasi per-tensor (C3-C6), (C8)
(I4) update_window_dims Konstanta tensor 1 dimensi jenis si64 (C2), (C4), (C7-C8)
(I5) inserted_window_dims Konstanta tensor 1 dimensi jenis si64 (C2), (C4), (C9-C11)
(I6) input_batching_dims Konstanta tensor 1 dimensi jenis si64 (C2), (C4), (C9), (C12-13), (C17-18), (C20)
(I7) scatter_indices_batching_dims Konstanta tensor 1 dimensi jenis si64 (C14-C18)
(I8) scatter_dims_to_operand_dims Konstanta tensor 1 dimensi jenis si64 (C19-C21)
(I9) index_vector_dim konstanta jenis si64 (C4), (C16), (C19), (C22)
(I10) indices_are_sorted konstanta jenis i1
(I11) unique_indices konstanta jenis i1
(I12) update_computation fungsi (C23)

Output

Nama Jenis Batasan
results jumlah tensor atau tensor terkuantisasi per-tensor (C24-C25)

Batasan

  • (C1) same(shape(inputs...)).
  • (C2) `rank(inputs[0]) = size(update_window_dims) + ukuran(inserted_window_dims)
    • size(input_batching_dims)`.
  • (C3) same(shape(updates...)).
  • (C4) shape(updates[0]) = combine(update_scatter_dim_sizes, update_window_dim_sizes) jika:
    • update_scatter_dim_sizes = shape(scatter_indices) kecuali bahwa ukuran dimensi scatter_indices yang sesuai dengan index_vector_dim tidak termasuk.
    • update_window_dim_sizes <= shape(inputs[0]) kecuali bahwa ukuran dimensi di inputs[0] yang sesuai dengan inserted_window_dims dan input_batching_dims tidak termasuk.
    • combine menempatkan update_scatter_dim_sizes pada sumbu yang sesuai dengan update_scatter_dims dan update_window_dim_sizes pada sumbu yang sesuai ke update_window_dims.
  • (C5) 0 < size(inputs) = size(updates) = N.
  • (C6) element_type(updates...) = element_type(inputs...).
  • (C7) is_unique(update_window_dims) and is_sorted(update_window_dims).
  • (C8) 0 <= update_window_dims < rank(updates[0]).
  • (C9) is_unique(concatenate(inserted_window_dims, input_batching_dims))
  • (C10) is_sorted(inserted_window_dims).
  • (C11) 0 <= inserted_window_dims < rank(inputs[0]).
  • (C12) is_sorted(input_batching_dims).
  • (C13) 0 <= input_batching_dims < rank(inputs[0])).
  • (C14) is_unique(scatter_indices_batching_dims).
  • (C15) 0 <= scatter_indices_batching_dims < rank(scatter_indices).
  • (C16) index_vector_dim not in scatter_indices_batching_dims.
  • (C17) size(input_batching_dims) == size(scatter_indices_batching_dims).
  • (C18) dim(inputs[0], input_batching_dims...) = dim(scatter_indices, scatter_indices_batching_dims...).
  • (C19) size(scatter_dims_to_operand_dims) = index_vector_dim < rank(scatter_indices) ? dim(scatter_indices, index_vector_dim) : 1.
  • (C20) is_unique(concatenate(scatter_dims_to_operand_dims, input_batching_dims)).
  • (C21) 0 <= scatter_dims_to_operand_dims < rank(inputs[0]).
  • (C22) 0 <= index_vector_dim <= rank(scatter_indices).
  • (C23) update_computation memiliki jenis (tensor<E0>, ..., tensor<EN-1>, tensor<E0>, ..., tensor<EN-1>) -> (tensor<E0>, ..., tensor<EN-1>), dengan is_promotable(element_type(inputs[i]), Ei).
  • (C24) shape(inputs...) = shape(results...).
  • (C25) element_type(results[i]) = Ei untuk semua i di [0,N).

Contoh

// %input: [
//          [
//           [[1, 2], [3, 4], [5, 6], [7, 8]],
//           [[9, 10],[11, 12], [13, 14], [15, 16]],
//           [[17, 18], [19, 20], [21, 22], [23, 24]]
//          ],
//          [
//           [[25, 26], [27, 28], [29, 30], [31, 32]],
//           [[33, 34], [35, 36], [37, 38], [39, 40]],
//           [[41, 42], [43, 44], [45, 46], [47, 48]]
//          ]
//         ]
// %scatter_indices: [
//                    [
//                     [[0, 0], [1, 0], [2, 1]],
//                     [[0, 1], [1, 1], [0, 9]]
//                    ],
//                    [
//                     [[0, 0], [2, 1], [2, 2]],
//                     [[1, 2], [0, 1], [1, 0]]
//                    ]
//                   ]
// %update: [
//           [
//            [[1, 1], [1, 1], [1, 1]],
//            [[1, 1], [1, 1], [1, 1]]
//           ],
//           [
//            [[1, 1], [1, 1], [1, 1]],
//            [[1, 1], [1, 1], [1, 1]]
//           ]
//          ]
%result = "stablehlo.scatter"(%input, %scatter_indices, %update) ({
  ^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
    %0 = "stablehlo.add"(%arg0, %arg1) : (tensor<i64>, tensor<i64>) -> tensor<i64>
    "stablehlo.return"(%0) : (tensor<i64>) -> ()
}) {
  scatter_dimension_numbers = #stablehlo.scatter<
    update_window_dims = [3, 4],
    inserted_window_dims = [1],
    input_batching_dims = [0],
    scatter_indices_batching_dims = [1],
    scatter_dims_to_operand_dims = [2, 1],
    index_vector_dim = 3>,
  indices_are_sorted = false,
  unique_indices = false
} : (tensor<2x3x4x2xi64>, tensor<2x2x3x2xi64>, tensor<2x2x3x2x2xi64>) -> tensor<2x3x4x2xi64>
// %result: [
//           [
//            [[3, 4], [6, 7], [6, 7], [7, 8]],
//            [[9, 10],[11, 12], [15, 16], [17, 18]],
//            [[17, 18], [19, 20], [22, 23], [24, 25]]
//           ],
//           [
//            [[25, 26], [28, 29], [30, 31], [31, 32]],
//            [[35, 36], [38, 39], [38, 39], [39, 40]],
//            [[41, 42], [44, 45], [46, 47], [47, 48]]
//           ]
//          ]

Contoh Lainnya

pilih

Semantik

Menghasilkan tensor result dengan setiap elemen dipilih dari on_true atau Tensor on_false berdasarkan nilai elemen pred yang sesuai. Secara lebih formal, result[result_index] = pred_element ? on_true[result_index] : on_false[result_index], dengan pred_element = rank(pred) = 0 ? pred[] : pred[result_index]. Untuk jenis terkuantisasi, menjalankan dequantize_select_quantize(pred, on_true, on_false, type(result)).

Input

Label Nama Jenis Batasan
(I1) pred tensor jenis i1 (C1)
(I2) on_true tensor atau tensor terkuantisasi per-tensor (C1-C2)
(I3) on_false tensor atau tensor terkuantisasi per-tensor (C2)

Output

Nama Jenis Batasan
result tensor atau tensor terkuantisasi per-tensor (C2)

Batasan

  • (C1) rank(pred) = 0 or shape(pred) = shape(on_true).
  • (C2) baseline_type(on_true) = baseline_type(on_false) = baseline_type(result).

Contoh

// %pred: [[false, true], [true, false]]
// %on_true: [[1, 2], [3, 4]]
// %on_false: [[5, 6], [7, 8]]
%result = "stablehlo.select"(%pred, %on_true, %on_false) : (tensor<2x2xi1>, tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
// %result: [[5, 2], [3, 8]]

Contoh Lainnya

select_and_scatter

Semantik

Sebarkan nilai dari tensor source menggunakan scatter berdasarkan hasil reduce_window dari tensor input menggunakan select dan menghasilkan sebuah tensor result.

Diagram berikut menunjukkan cara penghitungan elemen pada result operand dan source menggunakan contoh konkret.

select_and_scatter

Secara lebih formal:

  • selected_values = reduce_window_without_init(...) dengan input berikut:

    • inputs = [operand].
    • window_dimensions, window_strides, dan padding yang digunakan sebagaimana adanya.
    • base_dilations = windows_dilations = 1.
    • body ditentukan sebagai:
    def body(arg0: tensor<E>, arg1: tensor<E>) -> tensor<E>:
      return select(arg0, arg1) ? arg0 : arg1;
    

    tempat E = element_type(operand), dan reduce_window_without_init bekerja persis seperti reduce_window, kecuali bahwa schedule reduce (lihat mengurangi) tidak menyertakan nilai init. Saat ini tidak menentukan apa yang akan terjadi jika jendela yang sesuai tidak memiliki nilai (#731).

  • result[result_index] = reduce([source_values], [init_value], [0], scatter) dalam hal ini:

    • source_values = [source[source_index] for source_index in source_indices].
    • selected_index(source_index) = operand_index jika selected_values[source_index] memiliki elemen operand dari operand_index.
    • source_indices = [source_index for source_index in indices(source) if selected_index(source_index) = result_index].

Input

Label Nama Jenis Batasan
(I1) operand tensor atau tensor terkuantisasi per-tensor (C1-C4), (C6), (C8-C11)
(I2) source tensor atau tensor terkuantisasi per-tensor (C1), C2
(I3) init_value Tensor 0-dimensi atau per-tensor quantized tensor (C3)
(I4) window_dimensions Konstanta tensor 1 dimensi jenis si64 (C2), (C4), (C5)
(I5) window_strides Konstanta tensor 1 dimensi jenis si64 (C2), (C6), (C7)
(I6) padding Konstanta tensor 2 dimensi jenis si64 (C2), (C8)
(I7) select fungsi (C9)
(I8) scatter fungsi (C10)

Output

Nama Jenis Batasan
result tensor atau tensor terkuantisasi per-tensor (C11-C12)

Batasan

  • (C1) element_type(operand) = element_type(source).
  • (C2) shape(source) = num_windows jika:
    • padded_operand_shape = padding[:, 0] + shape(operand) + padding[:, 1].
    • is_empty_window = padded_operand_shape = 0 || window_dimensions > padded_operand_shape.
    • num_windows = is_empty_window ? 0 : floor((padded_operand_shape - window_dimensions) / window_strides) + 1.
  • (C3) element_type(init_value) = element_type(operand).
  • (C4) size(window_dimensions) = rank(operand).
  • (C5) 0 < window_dimensions.
  • (C6) size(window_strides) = rank(operand).
  • (C7) 0 < window_strides.
  • (C8) shape(padding) = [rank(operand), 2].
  • (C9) select memiliki jenis (tensor<E>, tensor<E>) -> tensor<i1>, dengan E = element_type(operand).
  • (C10) scatter memiliki jenis (tensor<E>, tensor<E>) -> tensor<E>, dengan is_promotable(element_type(operand), E).
  • (C11) shape(operand) = shape(result).
  • (C12) element_type(result) = E.

Contoh

// %operand: [[1, 5], [2, 5], [3, 6], [4, 4]]
// %source: [[5, 6], [7, 8]]
// %init_value: 0
%result = "stablehlo.select_and_scatter"(%operand, %source, %init_value) ({
  ^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
    %0 = "stablehlo.compare"(%arg0, %arg1) {
      comparison_direction = #stablehlo<comparison_direction GE>
    } : (tensor<i64>, tensor<i64>) -> tensor<i1>
    "stablehlo.return"(%0) : (tensor<i1>) -> ()
}, {
  ^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
    %0 = "stablehlo.add"(%arg0, %arg1) : (tensor<i64>, tensor<i64>) -> tensor<i64>
    "stablehlo.return"(%0) : (tensor<i64>) -> ()
}) {
  window_dimensions = array<i64: 3, 1>,
  window_strides = array<i64: 2, 1>,
  padding = dense<[[0, 1], [0, 0]]> : tensor<2x2xi64>
} : (tensor<4x2xi64>, tensor<2x2xi64>, tensor<i64>) -> tensor<4x2xi64>
// %result: [[0, 0], [0, 0], [5, 14], [7, 0]]

Contoh Lainnya

kirim

Semantik

Mengirim inputs ke saluran channel_id dan menghasilkan token result.

Jika is_host_transfer adalah true, operasi tersebut akan mentransfer data ke {i>host<i}. Jika tidak, data akan ditransfer ke perangkat lain. Artinya adalah ditentukan oleh implementasi. Penanda ini menduplikasi informasi yang disediakan di channel_type, jadi di masa mendatang kami berencana untuk menyimpan salah satunya saja (#666).

Input

Label Nama Jenis Batasan
(I1) inputs jumlah tensor atau tensor terkuantisasi
(I2) token token
(I3) channel_id konstanta jenis si64
(I4) channel_type enum DEVICE_TO_DEVICE dan DEVICE_TO_HOST (C1)
(I5) is_host_transfer konstanta jenis i1 (C1)

Output

Nama Jenis
result token

Batasan

  • (C1) channel_type didefinisikan sebagai:
    • DEVICE_TO_HOST jika is_host_transfer = true,
    • DEVICE_TO_DEVICE sebaliknya.

Contoh

%result = "stablehlo.send"(%operand, %token) {
  channel_handle = #stablehlo.channel_handle<handle = 1, type = 2>,
  is_host_transfer = true
} : (tensor<2x2xi64>, !stablehlo.token) -> !stablehlo.token

Contoh Lainnya

shift_left

Semantik

Melakukan operasi shift kiri berbasis element pada tensor lhs berdasarkan angka rhs bit dan menghasilkan tensor result.

Input

Label Nama Jenis Batasan
(I1) lhs tensor tipe integer (C1)
(I2) rhs tensor tipe integer (C1)

Output

Nama Jenis Batasan
result tensor tipe integer (C1)

Batasan

  • (C1) type(lhs) = type(rhs) = type(result).

Contoh

// %lhs: [-1, 0, 1]
// %rhs: [1, 2, 3]
%result = "stablehlo.shift_left"(%lhs, %rhs): (tensor<3xi64>, tensor<3xi64>) -> tensor<3xi64>
// %result: [-2, 0, 8]

Contoh Lainnya

shift_right_arithmetic

Semantik

Melakukan operasi shift kanan aritmetika element-wise pada tensor lhs dengan rhs jumlah bit dan menghasilkan tensor result.

Input

Label Nama Jenis Batasan
(I1) lhs tensor tipe integer (C1)
(I2) rhs tensor tipe integer (C1)

Output

Nama Jenis Batasan
result tensor tipe integer (C1)

Batasan

  • (C1) type(lhs) = type(rhs) = type(result).

Contoh

// %lhs: [-1, 0, 8]
// %rhs: [1, 2, 3]
%result = "stablehlo.shift_right_arithmetic"(%lhs, %rhs): (tensor<3xi64>, tensor<3xi64>) -> tensor<3xi64>
// %result: [-1, 0, 1]

Contoh Lainnya

shift_right_logical

Semantik

Melakukan operasi shift kanan logis element-wise pada tensor lhs dengan rhs jumlah bit dan menghasilkan tensor result.

Input

Label Nama Jenis Batasan
(I1) lhs tensor tipe integer (C1)
(I2) rhs tensor tipe integer (C1)

Output

Nama Jenis Batasan
result tensor tipe integer (C1)

Batasan

  • (C1) type(lhs) = type(rhs) = type(result).

Contoh

// %lhs: [-1, 0, 8]
// %rhs: [1, 2, 3]
%result = "stablehlo.shift_right_logical"(%lhs, %rhs): (tensor<3xi64>, tensor<3xi64>) -> tensor<3xi64>
// %result: [9223372036854775807, 0, 1]

Contoh Lainnya

tanda

Semantik

Menampilkan tanda element-wise operand dan menghasilkan tensor result. Secara lebih formal, untuk setiap elemen x, semantik dapat dinyatakan menggunakan Sintaksis Python sebagai berikut:

def sign(x):
  if is_integer(x):
    if compare(x, 0, LT, SIGNED): return -1
    if compare(x, 0, EQ, SIGNED): return 0
    return 1
  elif is_float(x):
    if is_nan(x): return NaN
    if compare(x, -0.0, EQ, FLOAT): return -0.0
    if compare(x, +0.0, EQ, FLOAT): return +0.0
    if compare(x, 0.0, LT, FLOAT): return -1.0
    return 1.0
  elif is_complex(x):
    if is_nan(real(x)) or is_nan(imag(x)): return (NaN, NaN)
    if compare(x, (0.0, 0.0), EQ, FLOAT): return (0.0, 0.0)
    return divide(x, convert(abs(x), type(x)))

Untuk jenis terkuantisasi, menjalankan dequantize_op_quantize(sign, operand, type(result)).

Input

Label Nama Jenis Batasan
(I1) operand tensor bilangan bulat bertanda tangan, floating point, atau tipe kompleks atau tensor terkuantisasi per-tensor (C1)

Output

Nama Jenis Batasan
result tensor bilangan bulat bertanda tangan, floating point, atau tipe kompleks atau tensor terkuantisasi per-tensor (C1)

Batasan

  • (C1) baseline_type(operand) = baseline_type(result).

Contoh

// Logical values: +NaN, -1.0, -0.0, +0.0, 1.0
// operand: [0x7FFFFFFFFFFFFFFF, -1.0, -0.0, 0.0, 1.0]
%result = "stablehlo.sign"(%operand) : (tensor<5xf64>) -> tensor<5xf64>
// Logical values: +NaN, -1.0, -0.0, +0.0, 1.0
// %result: [0x7FFFFFFFFFFFFFFF, -1.0, -0.0, 0.0, 1.0]

Contoh Lainnya

sinus

Semantik

Melakukan operasi sinus element-wise pada tensor operand dan menghasilkan result tensor. Bergantung pada jenis elemen, melakukan hal berikut:

  • Untuk float: sin dari IEEE-754.
  • Untuk bilangan kompleks: sinus kompleks.
  • Untuk jenis terkuantisasi: dequantize_op_quantize(sine, operand, type(result)).

Input

Label Nama Jenis Batasan
(I1) operand tensor floating point atau tipe kompleks atau tensor terkuantisasi per-tensor (C1)

Output

Nama Jenis Batasan
result tensor floating point atau tipe kompleks atau tensor terkuantisasi per-tensor (C1)

Batasan

  • (C1) baseline_type(operand) = baseline_type(result).

Contoh

// %operand: [
//            [0.0, 1.57079632],       // [0, pi/2]
//            [3.14159265, 4.71238898] // [pi, 3pi/2]
//           ]
%result = "stablehlo.sine"(%operand) : (tensor<2x2xf32>) -> tensor<2x2xf32>
// %result: [[0.0, 1.0], [0.0, -1.0]]

Contoh Lainnya

slice

Semantik

Mengekstrak slice dari operand menggunakan indeks awal yang dihitung secara statis dan menghasilkan tensor result. start_indices berisi indeks awal dari irisan untuk setiap dimensi, limit_indices berisi indeks akhir (eksklusif) untuk irisan untuk setiap dimensi, dan strides berisi langkah untuk setiap dimensi.

Secara lebih formal, result[result_index] = operand[operand_index] di mana operand_index = start_indices + result_index * strides.

Input

Label Nama Jenis Batasan
(I1) operand tensor atau tensor terkuantisasi per-tensor (C1-C3), (C5)
(I2) start_indices Konstanta tensor 1 dimensi jenis si64 (C2), (C3), (C5)
(I3) limit_indices Konstanta tensor 1 dimensi jenis si64 (C2), (C3), (C5)
(I4) strides Konstanta tensor 1 dimensi jenis si64 (C2), (C4)

Output

Nama Jenis Batasan
result tensor atau tensor terkuantisasi per-tensor (C1), (C5)

Batasan

  • (C1) element_type(operand) = element_type(result).
  • (C2) size(start_indices) = size(limit_indices) = size(strides) = rank(operand).
  • (C3) 0 <= start_indices <= limit_indices <= shape(operand).
  • (C4) 0 < strides.
  • (C5) shape(result) = ceil((limit_indices - start_indices) / strides).

Contoh

// %operand: [
//            [0, 0, 0, 0],
//            [0, 0, 1, 1],
//            [0, 0, 1, 1]
//           ]
%result = "stablehlo.slice"(%operand) {
  start_indices = array<i64: 1, 2>,
  limit_indices = array<i64: 3, 4>,
  strides = array<i64: 1, 1>
} : (tensor<3x4xi64>) -> tensor<2x2xi64>
// % result: [
//            [1, 1],
//            [1, 1]
//           ]

Contoh Lainnya

mengurutkan

Semantik

Mengurutkan irisan 1 dimensi inputs di sepanjang dimensi dimension bersama-sama, menurut comparator dan menghasilkan results.

Tidak seperti input serupa dalam operasi lain, dimension memungkinkan nilai negatif, dengan semantik yang dijelaskan di bawah ini. Pada masa mendatang, tindakan ini mungkin tidak diizinkan untuk alasan konsistensi (#1377).

Jika is_stable bernilai benar, berarti penyortirannya stabil, yaitu, urutan relatif elemen yang dianggap setara dengan pembanding dipertahankan. Untuk kasus ini jika ada input tunggal, dua elemen e1 dan e2 dianggap sama dengan pembanding jika dan hanya jika comparator(e1, e2) = comparator(e2, e1) = false. Lihat formalisasi di bawah ini untuk mengetahui bagaimana generalisasi ke beberapa input.

Secara lebih formal, untuk semua result_index di index_space(results[0]):

  • adjusted_dimension = dimension >= 0 ? dimension : rank(inputs[0]) + dimension.
  • result_slice = [ri0, ..., :, ..., riR-1] dengan riN merupakan individu elemen di result_index, dan : disisipkan pada adjusted_dimension.
  • inputs_together = (inputs[0]..., ..., inputs[N-1]...).
  • results_together[result_slice] = sort(inputs_together[result_slice], comparator_together).
  • dengan sort mengurutkan irisan 1 dimensi dalam urutan yang tidak menurun seperti bahwa comparator_together akan menampilkan true jika argumen di sisi kiri adalah lebih kecil daripada argumen kedua sebelah kanan.
  • def comparator_together(lhs_together, rhs_together):
      args = []
      for (lhs_el, rhs_el) in zip(lhs_together, rhs_together):
        args.append(lhs_el)
        args.append(rhs_el)
      return comparator(*args)
    
  • (results[0]..., ..., results[N-1]...) = results_together.

Input

Label Nama Jenis Batasan
(I1) inputs jumlah tensor atau tensor terkuantisasi per-tensor (C1-C5)
(I2) dimension konstanta jenis si64 (C4)
(I3) is_stable konstanta jenis i1
(I4) comparator fungsi (C5)

Output

Nama Jenis Batasan
results jumlah tensor atau tensor terkuantisasi per-tensor (C2), (C3)

Batasan

  • (C1) 0 < size(inputs).
  • (C2) type(inputs...) = type(results...).
  • (C3) same(shape(inputs...) + shape(results...)).
  • (C4) -R <= dimension < R, dengan R = rank(inputs[0]).
  • (C5) comparator memiliki jenis (tensor<E1>, tensor<E1>, ..., tensor<EN-1>, tensor<EN-1>) -> tensor<i1>, dengan Ei = element_type(inputs[i]).

Contoh

// %input0 = [[1, 2, 3], [3, 2, 1]]
// %input1 = [[3, 2, 1], [1, 2, 3]]
%result0, %result1 = "stablehlo.sort"(%input0, %input1) ({
  ^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>, %arg2: tensor<i64>, %arg3: tensor<i64>):
    %predicate = "stablehlo.compare"(%arg0, %arg1) {
      comparison_direction = #stablehlo<comparison_direction GT>
    } : (tensor<i64>, tensor<i64>) -> tensor<i1>
    "stablehlo.return"(%predicate) : (tensor<i1>) -> ()
}) {
  dimension = 0 : i64,
  is_stable = true
} : (tensor<2x3xi64>, tensor<2x3xi64>) -> (tensor<2x3xi64>, tensor<2x3xi64>)
// %result0 = [[3, 2, 3], [1, 2, 1]]
// %result1 = [[1, 2, 1], [3, 2, 3]]

Contoh Lainnya

sqrt

Semantik

Melakukan operasi root kuadrat berdasarkan elemen pada tensor operand dan menghasilkan Tensor result. Bergantung pada jenis elemen, melakukan hal berikut:

  • Untuk float: squareRoot dari IEEE-754.
  • Untuk bilangan kompleks: akar kuadrat kompleks.
  • Untuk jenis terkuantisasi: dequantize_op_quantize(sqrt, operand, type(result)).

Input

Label Nama Jenis Batasan
(I1) operand tensor floating point atau tipe kompleks atau tensor terkuantisasi per-tensor (C1)

Output

Nama Jenis Batasan
result tensor floating point atau tipe kompleks atau tensor terkuantisasi per-tensor (C1)

Batasan

  • (C1) baseline_type(operand) = baseline_type(result).

Contoh

// %operand: [[0.0, 1.0], [4.0, 9.0]]
%result = "stablehlo.sqrt"(%operand) : (tensor<2x2xf32>) -> tensor<2x2xf32>
// %result: [[0.0, 1.0], [2.0, 3.0]]

Contoh Lainnya

kurangi

Semantik

Melakukan pengurangan berbasis elemen dari dua tensor lhs dan rhs serta menghasilkan Tensor result. Bergantung pada jenis elemen, melakukan hal berikut:

  • Untuk bilangan bulat: pengurangan bilangan bulat.
  • Untuk float: subtraction dari IEEE-754.
  • Untuk bilangan kompleks: pengurangan kompleks.
  • Untuk jenis terkuantisasi:
    • dequantize_op_quantize(subtract, lhs, rhs, type(result)).

Input

Label Nama Jenis Batasan
(I1) lhs tensor bilangan bulat, floating point, atau tipe kompleks atau tensor terkuantisasi per-tensor (C1)
(I2) rhs tensor bilangan bulat, floating point, atau tipe kompleks atau tensor terkuantisasi per-tensor (C1)

Output

Nama Jenis Batasan
result tensor bilangan bulat, floating point, atau tipe kompleks atau tensor terkuantisasi per-tensor (C1)

Batasan

  • (C1) baseline_type(lhs) = baseline_type(rhs) = baseline_type(result).

Contoh

// %lhs: [[6, 8], [10, 12]]
// %rhs: [[5, 6], [7, 8]]
%result = "stablehlo.subtract"(%lhs, %rhs) : (tensor<2x2xf32>, tensor<2x2xf32>) -> (tensor<2x2xf32>)
// %result: [[1, 2], [3, 4]]

Contoh Lainnya

tan

Semantik

Melakukan operasi tangen berbasis elemen pada tensor operand dan menghasilkan Tensor result. Bergantung pada jenis elemen, melakukan hal berikut:

  • Untuk float: tan dari IEEE-754.
  • Untuk bilangan kompleks: tangen kompleks.
  • Untuk jenis terkuantisasi: dequantize_op_quantize(tan, operand, type(result)).

Input

Label Nama Jenis Batasan
(I1) operand tensor floating point atau tipe kompleks atau tensor terkuantisasi per-tensor (C1)

Output

Nama Jenis Batasan
result tensor floating point atau tipe kompleks atau tensor terkuantisasi per-tensor (C1)

Batasan

  • (C1) baseline_type(operand) = baseline_type(result).

Contoh

// %operand: [
//            [0.0, 1.57079632],       // [0, pi/2]
//            [3.14159265, 4.71238898] // [pi, 3pi/2]
//           ]
%result = "stablehlo.tan"(%operand) : (tensor<2x2xf64>) -> tensor<2x2xf64>
// %result: [
//           [0.0, 1.63312e+16],
//           [0.0, 5.44375e+15]
//          ]

Contoh Lainnya

Tanh

Semantik

Melakukan operasi tangen hiperbolik berbasis elemen pada tensor operand dan menghasilkan tensor result. Bergantung pada jenis elemen, melakukan hal berikut:

  • Untuk float: tanh dari IEEE-754.
  • Untuk bilangan kompleks: tangen hiperbolik kompleks.
  • Untuk jenis terkuantisasi:
    • dequantize_op_quantize(tanh, operand, type(result)).

Input

Label Nama Jenis Batasan
(I1) operand tensor floating point atau tipe kompleks atau tensor terkuantisasi per-tensor (C1)

Output

Nama Jenis Batasan
result tensor floating point atau tipe kompleks atau tensor terkuantisasi per-tensor (C1)

Batasan

  • (C1) baseline_type(operand) = baseline_type(result).

Contoh

// %operand: [-1.0, 0.0, 1.0]
%result = "stablehlo.tanh"(%operand) : (tensor<3xf32>) -> tensor<3xf32>
// %result: [-0.76159416, 0.0, 0.76159416]

Contoh Lainnya

{i>transpose<i}

Semantik

Mengubah dimensi tensor operand menggunakan permutation dan menghasilkan Tensor result. Secara lebih formal, result[result_index] = operand[operand_index] dengan result_index[d] = operand_index[permutation[d]].

Input

Label Nama Jenis Batasan
(I1) operand tensor atau quantized tensor (C1-C4)
(I2) permutation Konstanta tensor 1 dimensi jenis si64 (C2-C4)

Output

Nama Jenis Batasan
result tensor atau quantized tensor (C1), C3-C4

Batasan

  • (C1) element_type(result) diberikan oleh:
    • element_type(operand), jika !is_per_axis_quantized(operand).
    • element_type(operand) kecuali quantization_dimension(operand) dan quantization_dimension(result) dapat berbeda, jika tidak.
  • (C2) permutation adalah permutasi dari range(rank(operand)).
  • (C3) shape(result) = dim(operand, permutation...).
  • (C4) Jika is_per_axis_quantized(result), maka quantization_dimension(operand) = permutation(quantization_dimension(result)).

Contoh

// %operand: [
//            [[1,2], [3,4], [5,6]],
//            [[7,8], [9,10], [11,12]]
//           ]
%result = "stablehlo.transpose"(%operand) {
  permutation = array<i64: 2, 1, 0>
} : (tensor<2x3x2xi32>) -> tensor<2x3x2xi32>
// %result: [
//           [[1,7], [3,9], [5,11]],
//           [[2,8], [4,10], [6,12]]
//          ]

Contoh Lainnya

triangular_solve

Semantik

Menyelesaikan kumpulan sistem persamaan linear dengan segitiga bawah atau atas matriks koefisien.

Secara lebih formal, mengingat a dan b, result[i0, ..., iR-3, :, :] adalah solusinya ke op(a[i0, ..., iR-3, :, :]) * x = b[i0, ..., iR-3, :, :] saat left_side adalah true atau x * op(a[i0, ..., iR-3, :, :]) = b[i0, ..., iR-3, :, :] saat left_side adalah false, menyelesaikan variabel x tempat op(a) ditentukan paling lambat transpose_a, yang dapat berupa salah satu dari hal berikut:

  • NO_TRANSPOSE: Menjalankan operasi menggunakan a sebagaimana adanya.
  • TRANSPOSE: Melakukan operasi pada transposisi a.
  • ADJOINT: Melakukan operasi pada transposisi konjugasi a.

Data input hanya dibaca dari segitiga bawah a, jika lower adalah true atau segitiga atas a, jika sebaliknya. Data output ditampilkan dalam segitiga yang sama; nilai dalam segitiga lainnya merupakan manifestasi implementasi.

Jika unit_diagonal bernilai benar, implementasi dapat mengasumsikan bahwa ukuran diagonal elemen a sama dengan 1, jika tidak, perilaku tidak terdefinisi.

Untuk jenis terkuantisasi, menjalankan dequantize_op_quantize(lambda x, y: triangular_solve(x, y, left_side, lower, unit_diagonal, transpose_a), a, b, type(result)).

Input

Label Nama Jenis Batasan
(I1) a tensor floating point atau tipe kompleks atau tensor terkuantisasi per-tensor (C1-C3)
(I2) b tensor floating point atau tipe kompleks atau tensor terkuantisasi per-tensor (C1-C4)
(I3) left_side konstanta jenis i1 (C3)
(I4) lower konstanta jenis i1
(I5) unit_diagonal konstanta jenis i1
(I6) transpose_a enum NO_TRANSPOSE, TRANSPOSE, dan ADJOINT

Output

Nama Jenis Batasan
result tensor floating point atau tipe kompleks atau tensor terkuantisasi per-tensor (C1)

Batasan

  • (C1) baseline_element_type(a) = baseline_element_type(b).
  • (C2) 2 <= rank(a) = rank(b) = R.
  • (C3) Hubungan antara shape(a) dan shape(b) didefinisikan sebagai berikut:
    • shape(a)[:-3] = shape(b)[:-3].
    • dim(a, -2) = dim(a, -1) = dim(b, left_side ? -2 : -1).
  • (C4) baseline_type(b) = baseline_type(result).

Contoh

// %a = [
//       [1.0, 0.0, 0.0],
//       [2.0, 4.0, 0.0],
//       [3.0, 5.0, 6.0]
//      ]
// %b = [
//       [2.0, 0.0, 0.0],
//       [4.0, 8.0, 0.0],
//       [6.0, 10.0, 12.0]
//      ]
%result = "stablehlo.triangular_solve"(%a, %b) {
  left_side = true,
  lower = true,
  unit_diagonal = false,
  transpose_a = #stablehlo<transpose NO_TRANSPOSE>
} : (tensor<3x3xf32>, tensor<3x3xf32>) -> tensor<3x3xf32>
// %result: [
//           [2.0, 0.0, 0.0],
//           [0.0, 2.0, 0.0],
//           [0.0, 0.0, 2.0]
//          ]

tuple

Semantik

Menghasilkan tuple result dari nilai val.

Input

Label Nama Jenis Batasan
(I1) val jumlah nilai variadis (C1)

Output

Nama Jenis Batasan
result tuple (C1)

Batasan

  • (C1) result memiliki jenis tuple<E0, ..., EN-1> dengan Ei = type(val[i]).

Contoh

// %val0: [1.0, 2.0]
// %val1: (3)
%result = "stablehlo.tuple"(%val0, %val1) : (tensor<2xf32>, tuple<tensor<i32>>) -> tuple<tensor<2xf32>, tuple<tensor<i32>>>
// %result: ([1.0, 2.0], (3))

Contoh Lainnya

uniform_dequantize

Semantik

Melakukan konversi berbasis elemen dari tensor terkuantisasi operand ke tensor floating point result sesuai dengan parameter kuantisasi yang ditentukan oleh jenis operand.

Secara lebih formal, result = dequantize(operand).

Input

Label Nama Jenis Batasan
(I1) operand tensor terkuantisasi (C1), C2

Output

Nama Jenis Batasan
result tensor jenis floating point (C1), C2

Batasan

  • (C1) shape(operand) = shape(result).
  • (C2) element_type(result) = expressed_type(operand).

Contoh

// %operand: [10, 10]
%result = "stablehlo.uniform_dequantize"(%operand) : (tensor<2x!quant.uniform<i8:f32:0, {0.1:-30,0.5:-20}>>) -> tensor<2xf32>
// %result: [4.0, 15.0]

uniform_quantize

Semantik

Melakukan konversi berbasis elemen dari tensor floating point atau tensor terkuantisasi operand ke tensor terkuantisasi result sesuai dengan kuantisasi parameter yang ditentukan oleh jenis result.

Secara lebih formal,

  • Jika is_float(operand):
    • result = quantize(operand, type(result)).
  • Jika is_quantized(operand):
    • float_result = dequantize(operand).
    • result = quantize(float_result, type(result)).

Input

Label Nama Jenis Batasan
(I1) operand tensor floating point atau tipe terkuantisasi (C1), C2

Output

Nama Jenis Batasan
result tensor terkuantisasi (C1), C2

Batasan

  • (C1) shape(operand) = shape(result).
  • (C2) expressed_type(result) = is_float(operand) ? element_type(operand) : expressed_type(operand).

Contoh

// %operand: [4.0, 15.0]
%result = "stablehlo.uniform_quantize"(%operand) : (tensor<2xf32>) -> tensor<2x!quant.uniform<i8:f32:0, {0.1:-30,0.5:-20}>>
// %result: [10, 10]

// %operand: [10, 10]
%result = "stablehlo.uniform_quantize"(%operand) : (tensor<2x!quant.uniform<i8:f32:0, {0.1:-30,0.5:-20}>>) -> tensor<2x!quant.uniform<i8:f32:0, {0.1:-20,0.2:-30}>>
// %result: [20, 45]

sementara

Semantik

Menghasilkan output dari eksekusi fungsi body 0 kali atau lebih saat Fungsi cond menghasilkan true. Secara lebih formal, semantik dapat dinyatakan menggunakan sintaks Python sebagai berikut:

internal_state = operand
while cond(*internal_state):
  internal_state = body(*internal_state)
results = internal_state

Perilaku loop tanpa batas akan ditentukan nanti (#383).

Input

Label Nama Jenis Batasan
(I1) operand jumlah tensor, tensor atau token terkuantisasi (C1-C3)
(I2) cond fungsi (C1)
(I3) body fungsi (C2)

Output

Nama Jenis Batasan
results jumlah tensor, tensor atau token terkuantisasi (C3)

Batasan

  • (C1) cond memiliki jenis (T0, ..., TN-1) -> tensor<i1>, dengan Ti = type(operand[i]).
  • (C2) body memiliki jenis (T0, ..., TN-1) -> (T0, ..., TN-1), dengan Ti = type(operand[i]).
  • (C3) type(results...) = type(operand...).

Contoh

// %init_i: 1
// %init_sum: 0
// %one: 1
// %ten: 10
%results0, %results1 = "stablehlo.while"(%init_i, %init_sum) ({
  ^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
    %cond = "stablehlo.compare"(%arg0, %ten) {
      comparison_direction = #stablehlo<comparison_direction LT>
    } : (tensor<i64>, tensor<i64>) -> tensor<i1>
    stablehlo.return %cond : tensor<i1>
  }, {
  ^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
    %new_sum = stablehlo.add %arg1, %one : tensor<i64>
    %new_i = stablehlo.add %arg0, %one : tensor<i64>
    stablehlo.return %new_i, %new_sum : tensor<i64>, tensor<i64>
}) : (tensor<i64>, tensor<i64>) -> (tensor<i64>, tensor<i64>)
// %results0: 10
// %results1: 10

Contoh Lainnya

Xor

Semantik

Melakukan XOR berbasis elemen dari dua tensor lhs dan rhs serta menghasilkan result tensor. Bergantung pada jenis elemen, melakukan hal berikut:

  • Untuk boolean: XOR logis.
  • Untuk bilangan bulat: bitwise XOR.

Input

Label Nama Jenis Batasan
(I1) lhs tensor jenis boolean atau bilangan bulat (C1)
(I2) rhs tensor jenis boolean atau bilangan bulat (C1)

Output

Nama Jenis Batasan
result tensor jenis boolean atau bilangan bulat (C1)

Batasan

  • (C1) type(lhs) = type(rhs) = type(result).

Contoh

// Bitwise operation with with integer tensors
// %lhs: [[1, 2], [3, 4]]
// %rhs: [[5, 6], [7, 8]]
%result = "stablehlo.xor"(%lhs, %rhs) : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
// %result: [[4, 4], [4, 12]]

// Logical operation with with boolean tensors
// %lhs: [[false, false], [true, true]]
// %rhs: [[false, true], [false, true]]
%result = "stablehlo.xor"(%lhs, %rhs) : (tensor<2x2xi1>, tensor<2x2xi1>) -> tensor<2x2xi1>
// %result: [[false, true], [true, false]]

Contoh Lainnya

Interop Dialek

Saat ini, program StableHLO di luar jaringan terkadang berisi operasi yang tidak ditentukan oleh StableHLO.

Modul, Fungsi, Panggilan, dan Kembali

StableHLO menggunakan operasi MLIR upstream untuk ModuleOp, FuncOp, CallOp, dan ReturnOp. Hal ini dilakukan untuk interop yang lebih baik dengan mesin MLIR yang ada, karena kartu berguna yang menargetkan FuncOp dan ModuleOp, serta banyak kompilasi pipeline mengharapkan adanya operasi ini. Jaminan kompatibilitas penuh yang diterapkan pada operasi ini. Jika ada yang berubah pada operasi ini dalam dengan cara yang tidak kompatibel (yaitu penghapusan), padanan StableHLO akan ditambahkan untuk mempertahankan kompatibilitas mundur.

CHLO

Opset CHLO berisi operasi tingkat lebih tinggi yang terurai menjadi StableHLO. Saat ini tidak ada jaminan kompatibilitas untuk CHLO. Untuk kompatibilitas jaminan, pass chlo-legalize-to-stablehlo harus digunakan sebelum serialisasi.

Operasi Bentuk

Praktik umum dalam komunitas adalah menggunakan operasi tertentu dari Dialek MLIR dalam program StableHLO dinamis untuk melakukan komputasi bentuk. Umumnya, ini mencakup dialek shape operasi seperti shape_of atau num_elements, tensor dialek operasi seperti dim atau from_elements, dan jenis index bawaan.

Dynamism RFC > O2 menunjukkan ini sebagai di luar cakupan, namun beberapa dukungan untuk jenis index disertakan untuk tujuan interop. Tidak ada jaminan kompatibilitas untuk template ini operasi atau jenis. Opsi shape-legalize-to-stablehlo dapat digunakan untuk mengonversi operasi ini menjadi operasi StableHLO yang didukung penuh.

Operasi yang Tidak Digunakan Lagi

Ada beberapa operasi StableHLO yang diturunkan dari MHLO yang tidak digunakan lagi dan akan keluar dari StableHLO. Detail lengkap tentang penghapusan dapat ditemukan di StableHLO v1.0 Cleanup #2283. Masalah pelacak untuk penghentian ini adalah #2340.

Operasi ini termasuk dalam beberapa kategori:

  • "Tidak di HLO" kategori operasi StabilHLO - mereka awalnya adalah bagian dari opset StableHLO tetapi kemudian dianggap tidak cocok dengan baik: broadcast, create_token, cross-replica-sum, dot, einsum, torch_index_select, unary_einsum (#3).
  • Operasi yang tidak digunakan - Operasi ini mungkin berguna pada titik tertentu, tetapi operasinya belum dikembangkan, atau {i>pipelines<i} yang menggunakan operasi ini telah difaktorkan ulang sehingga tidak diperlukan lagi. Ini mencakup map, tuple (#598), get_tuple_element, rng, complex membandingkan #560, dan konvolusi window_reversal (#1181).

Beberapa dari operasi ini dapat dihapus dengan mudah karena mereka dapat dinyatakan menggunakan operasi yang ada (broadcast, create_token, cross-replica-sum, dot, unary_einsum) dan akan dihapus setelah jendela kompatibilitas yang ada izin lewat (6 bulan). URL lainnya masih dalam proses peninjauan untuk dihapus (einsum, get_tuple_element, map, rng torch_index_select, tuple, complex perbandingan, window_reversal). Menunggu masukan dari komunitas, operasi ini akan dihapus, atau ditambahkan ke spesifikasi dengan dukungan penuh. Sampai operasi berjangka ini diketahui, mereka hanya dijamin kompatibilitasnya selama 6 bulan.

Eksekusi

Eksekusi berurutan

Program StableHLO dijalankan dengan memberikan nilai input ke fungsi main dan menghitung nilai output. Nilai {i>output<i} dari sebuah fungsi dihitung oleh mengeksekusi grafik operasi yang di-root pada operasi return yang sesuai.

Urutan eksekusi ditentukan oleh implementasi selama sesuai dengan dataflow, yaitu jika operasi dijalankan sebelum digunakan. Di StabilHLO, semua dan operasi {i>side-effecting<i} memakai satu token dan menghasilkan satu token (beberapa token dapat di-multiplex menjadi satu token melalui after_all), sehingga urutan eksekusi dari sisi juga selaras dengan dataflow. Misalnya, dalam program di bawah ini ada dua kemungkinan perintah eksekusi: %0%1%2return dan %1%0%2return.

func.func @main() -> tensor<f64> {
  %0 = stablehlo.constant dense<1.0> : tensor<f64>
  %1 = stablehlo.constant dense<2.0> : tensor<f64>
  %2 = stablehlo.add %0, %1 : tensor<f64>
  return %2 : tensor<f64>
}

Secara lebih formal, proses SttableHLO adalah kombinasi dari: 1) program StableHLO, 2) status operasi (belum dieksekusi, sudah dieksekusi), dan 3) nilai antara yang sedang dikerjakan oleh proses. Proses ini dimulai dengan nilai input ke fungsi main, dilanjutkan hingga grafik operasi yang memperbarui status operasi dan nilai perantara, serta diakhiri dengan nilai output. Formalisasi lebih lanjut akan ditentukan nanti (#484).

Eksekusi paralel

Program StableHLO dapat dijalankan secara paralel, yang disusun ke dalam grid proses 2D dari num_replicas oleh num_partitions yang keduanya memiliki jenis ui32.

Di petak proses SttableHLO, num_replicas * num_partitions dari StableHLO proses dijalankan secara bersamaan. Setiap proses memiliki process_id = (replica_id, partition_id), dengan replica_id dalam replica_ids = range(num_replicas) dan partition_id di partition_ids = range(num_partitions) yang keduanya memiliki ketik ui32.

Ukuran grid proses diketahui secara statis untuk setiap program (dalam di masa mendatang, kami berencana untuk menjadikannya bagian eksplisit dari program StableHLO #650), dan posisi di dalam {i>process grid <i}dikenal secara statis untuk setiap proses. Setiap proses memiliki akses ke posisinya dalam petak proses melalui replica_id dan partition_id op.

Di dalam {i>process grid<i}, semua program bisa jadi sama (dalam kolom “{i>Single<i} {i>Program<i}, {i>Multiple Data<i}” berbeda), semuanya bisa berbeda (dalam "Multiple Program, Multi Data" {i>style<i}) atau sesuatu di antaranya. Di masa mendatang, kami berencana memperkenalkan dukungan untuk idiom lain dalam mendefinisikan program StableHLO paralel, termasuk GSPMD (#619).

Di dalam kisi proses, proses sebagian besar independen satu sama lain - keduanya memiliki status operasi terpisah, nilai input/menengah/output terpisah dan sebagian besar operasi dijalankan secara terpisah antar proses, dengan pengecualian untuk sejumlah kecil operasi kolektif yang dijelaskan di bawah ini.

Mengingat bahwa eksekusi sebagian besar operasi hanya menggunakan nilai dari data, biasanya tidak ambigu untuk merujuk pada nilai-nilai ini dengan namanya. Namun, ketika menjelaskan semantik operasi kolektif, itu tidak cukup, dan yang memunculkan notasi name@process_id untuk merujuk ke nilai name dalam proses tertentu. (Dari perspektif tersebut, name yang tidak memenuhi syarat dapat dilihat sebagai singkatan untuk name@(replica_id(), partition_id())).

Urutan eksekusi di seluruh proses adalah implementasi yang ditentukan, kecuali untuk sinkronisasi yang diperkenalkan oleh komunikasi point-to-point dan operasi kolektif sebagaimana dijelaskan di bawah ini.

Komunikasi titik ke titik

Proses StabilHLO dapat berkomunikasi satu sama lain melalui Saluran SttableHLO. Saluran diwakili oleh ID positif jenis si64. Melalui berbagai operasi, adalah mungkin untuk mengirim nilai ke saluran dan menerimanya dari saluran.

Formalisasi lebih lanjut, mis. dari mana ID channel ini berasal, bagaimana proses yang menyadarinya dan jenis sinkronisasi apa yang diperkenalkan oleh mereka, akan ditentukan nanti (#484).

Komunikasi streaming

Setiap proses StableHLO memiliki akses ke dua antarmuka streaming:

  • Infeed yang dapat dibaca.
  • Outfeed yang dapat ditulis.

Tidak seperti saluran, yang digunakan untuk berkomunikasi antar proses dan karenanya memiliki proses di kedua ujungnya, feed infeed dan outfeed memiliki yang ditentukan oleh implementasi akhir.

Formalisasi lebih lanjut, mis. bagaimana komunikasi streaming memengaruhi eksekusi dan jenis sinkronisasi yang dilakukannya, akan ditentukan nanti (#484).

Operasi kolektif

Ada enam operasi kolektif di StableHLO: all_gather, all_reduce, all_to_all, collective_broadcast, collective_permute, dan reduce_scatter. Semua operasi ini membagi proses dalam proses StableHLO grid ke dalam grup proses SttableHLO dan jalankan komputasi gabungan dalam setiap grup proses, terpisah dari grup proses lainnya.

Dalam setiap grup proses, operasi kolektif dapat memperkenalkan sinkronisasi hambatan. Formalisasi lebih lanjut, mis. dan menjelaskan kapan sinkronisasi terjadi, bagaimana sebenarnya proses sampai pada penghalang ini, dan apa yang terjadi jika tidak dilakukan, akan ditentukan nanti (#484).

Jika grup proses melibatkan komunikasi lintas partisi, yaitu ada proses dalam grup proses yang ID partisinya berbeda, lalu dieksekusi operasi kolektif membutuhkan saluran, dan operasi kolektif harus menyediakan channel_id positif dari jenis si64. Komunikasi lintas replika tidak perlu saluran TV Anda.

Komputasi yang dilakukan oleh operasi kolektif bersifat khusus untuk operasi individual dan dijelaskan di bagian pengoperasian individual di atas. Namun, strategi dengan di mana grid proses dibagi menjadi grup proses yang dibagikan di antara operasi ini di bagian ini dan dijelaskan dalam bagian ini. Secara lebih formal, StableHLO mendukung mengikuti empat strategi berikut.

cross_replica

Hanya komunikasi lintas replika yang terjadi dalam setiap grup proses. Ini menggunakan replica_groups - daftar ID replika - dan komputasi hasil Kartesius dari replica_groups oleh partition_ids. replica_groups harus memiliki elemen yang unik dan mencakup semua replica_ids. Secara lebih formal, menggunakan {i>Syntax<i} Python:

def cross_replica(replica_groups: List[List[ReplicaId]]) -> List[List[ProcessId]]:
  for replica_group in replica_groups:
    for partition_id in partition_ids:
      process_group = []
      for replica_id in replica_group:
        process_group.append((replica_id, partition_id))
      yield process_group

Misalnya, untuk replica_groups = [[0, 1], [2, 3]] dan num_partitions = 2, cross_replica akan menghasilkan [[(0, 0), (1, 0)], [(0, 1), (1, 1)], [(2, 0), (3, 0)], [(2, 1), (3, 1)]].

cross_partition

Hanya komunikasi lintas partisi yang terjadi dalam setiap grup proses. Ini menggunakan partition_groups - daftar ID partisi - dan menghitung produk Kartesius partition_groups dari replica_ids. partition_groups harus memiliki elemen yang unik dan mencakup semua partition_ids. Secara lebih formal, menggunakan sintaksis Python:

def cross_partition(partition_groups: List[List[PartitionId]]) -> List[List[ProcessId]]:
  for partition_group in partition_groups:
    for replica_id in replica_ids:
      process_group = []
      for partition_id in partition_group:
        process_group.append((replica_id, partition_id))
      yield process_group

Misalnya, untuk partition_groups = [[0, 1]] dan num_replicas = 4, cross_partition akan menghasilkan [[(0, 0), (0, 1)], [(1, 0), (1, 1)], [(2, 0), (2, 1)], [(3, 0), (3, 1)]].

cross_replica_and_partition

Baik komunikasi lintas replika maupun lintas partisi dapat terjadi di dalam setiap {i>process group<i}. Strategi ini menggunakan replica_groups - daftar yang berisi ID replika - dan menghitung produk Kartesius dari setiap replica_group dengan partition_ids. replica_groups harus memiliki elemen yang unik dan mencakup semua replica_ids. Secara lebih formal, menggunakan sintaksis Python:

def cross_replica_and_partition(replica_groups: List[List[ReplicaId]]) -> List[List[ProcessId]]:
  for replica_group in replica_groups:
    process_group = []
    for partition_id in partition_ids:
      for replica_id in replica_group:
        process_group.append((replica_id, partition_id))
    yield process_group

Misalnya, untuk replica_groups = [[0, 1], [2, 3]] dan num_partitions = 2, cross_replica_and_partition akan menghasilkan [[(0, 0), (1, 0), (0, 1), (1, 1)], [(2, 0), (3, 0), (2, 1), (3, 1)]].

flattened_ids

Strategi ini menggunakan flattened_id_groups - daftar daftar "diratakan" ID proses dalam bentuk replica_id * num_partitions + partition_id - dan mengubahnya menjadi ID proses. flattened_id_groups harus memiliki elemen yang unik dan mencakup semua process_ids. Secara lebih formal, menggunakan sintaksis Python:

def flattened_ids(flattened_id_groups: List[List[ui32]]) -> List[List[ProcessId]]:
  for flattened_id_group in flattened_id_groups:
    process_group = []
    for flattened_id in flattened_id_group:
      replica_id = flattened_id // num_partitions
      partition_id = flattened_id % num_partitions
      process_group.append((replica_id, partition_id))
    yield process_group

Misalnya, untuk flattened_id_groups = [[0, 1, 2, 3], [4, 5, 6, 7]], num_replicas = 4 dan num_partitions = 2, flattened_ids akan menghasilkan [[(0, 0), (0, 1), (1, 0), (1, 1)], [(2, 0), (2, 1), (3, 0), (3, 1)]].

Akurasi

Saat ini, StableHLO tidak memberikan jaminan tentang akurasi numerik, tapi ini bisa berubah pada masa mendatang (#1156).

Semantik eksekusi operasi terkuantisasi

Penafsiran operasi StableHLO terkuantisasi dapat bervariasi bergantung pada persyaratan dan kemampuan perangkat kerasnya. Misalnya, beberapa perangkat keras mungkin memilih untuk menafsirkan operasi terkuantisasi menggunakan metode "{i>dequantize, do floating-point<i} operasi, dan terakhir "mengkuantisasikan" strategi. Orang lain dapat melakukan seluruh komputasi dengan aritmatika bilangan bulat. Karenanya, interpretasi dari operasi StableHLO terkuantisasi ditentukan secara eksklusif oleh terlepas dari implementasi layanan. Penafsiran kuantisasi hibrida (#1575) harus didasarkan pada semantik seperti yang ditentukan dalam spesifikasi (melalui 1792).

Error

Program StabilHLO divalidasi melalui berbagai kendala untuk operasi individu, yang mengatur banyak kelas {i>error<i} sebelum waktu proses. Namun, kondisi error masih mungkin terjadi, misalnya melalui overflow bilangan bulat, akses di luar batas, dll. Kecuali dinyatakan lain secara eksplisit, semua error ini menghasilkan perilaku yang ditentukan implementasi, tetapi ini dapat berubah dalam mendatang (#1157).

Pengecualian floating point

Sebagai pengecualian untuk aturan ini, pengecualian floating point dalam program StableHLO memiliki perilaku yang jelas. Operasi yang menghasilkan pengecualian yang ditentukan oleh Standar IEEE-754 (operasi tidak valid, pembagian dengan nol, overflow, underflow, atau pengecualian tidak tepat) menghasilkan hasil default (seperti yang didefinisikan dalam standar) dan melanjutkan eksekusi tanpa menaikkan tanda status yang sesuai; mirip dengan raiseNoFlag penanganan pengecualian dari standar. Pengecualian untuk non-standar operasi (mis. aritmatika kompleks dan fungsi transendental tertentu) merupakan ditentukan oleh implementasi.

Ketidakcocokan bentuk

StableHLO mendukung tensor berbentuk dinamis. Namun, bentuk harus sesuai dengan runtime, jika tidak, perilaku tidak ditentukan. StabilHLO tidak secara eksplisit menyediakan operasi yang dapat menyatakan bahwa tensor memiliki bentuk tertentu saat runtime. Membuat kode yang benar adalah tanggung jawab produsen.

Sebagai contoh spesifik, program di bawah valid. Namun, saat runtime, bentuk %arg0 dan %arg1 yang sama harus sama. Jika tidak, perilaku program tidak terdefinisi:

func.func @foo(%arg0: tensor<?xi32>, %arg1: tensor<?xi32>) -> tensor<?xi32> {
    %0 = stablehlo.add %arg0, %arg1 : tensor<?xi32>
    return %0 : tensor<?xi32>
}

Notasi

Untuk menjelaskan sintaksis, dokumen ini menggunakan ragam ISO EBNF yang dimodifikasi standar (ISO/IEC 14977:1996, Wikipedia), dengan dua modifikasi: 1) aturan ditentukan menggunakan ::=, bukan =,

2) penyambungan dinyatakan menggunakan penjajaran, bukan ,.

Untuk menjelaskan semantik (yaitu dalam bagian "Types", "Constants", dan "Ops"), kami menggunakan formula yang didasarkan pada sintaksis Python, yang diperluas dengan untuk mengekspresikan operasi array secara ringkas seperti yang dijelaskan di bawah ini. Ini berfungsi dengan baik untuk cuplikan kode berukuran kecil, tetapi dalam kasus yang jarang terjadi, ketika cuplikan kode yang lebih besar diperlukan, kita menggunakan {i>syntax<i} vanilla Python yang selalu diperkenalkan secara eksplisit.

Formula

Mari kita pelajari cara kerja formula berdasarkan contoh dari dot_general spesifikasi pendukung. Salah satu batasan untuk operasi ini terlihat seperti berikut: dim(lhs, lhs_batching_dimensions...) = dim(rhs, rhs_batching_dimensions...).

Nama-nama yang digunakan dalam formula ini berasal dari dua sumber: 1) fungsi global, yaitu dim, 2) definisi anggota dari elemen program yang sesuai, yaitu Input lhs, lhs_batching_dimensions, rhs, dan rhs_batching_dimensions yang ditentukan dalam "Input" dari dot_general.

Seperti yang disebutkan di atas, {i>syntax<i} formula ini berbasis Python dengan beberapa ekstensi yang berorientasi keringkasan. Untuk memahami formulanya, mari kita ubah menjadi sintaks vanilla Python.

A) Dalam formula ini, kita menggunakan = untuk mewakili kesetaraan. Jadi, langkah pertama untuk mendapatkan sintaksis Python, mengganti = dengan ==, seperti berikut: dim(lhs, lhs_batching_dimensions...) == dim(rhs, rhs_batching_dimensions...).

B) Selain itu, rumus ini mendukung elipsis (...) yang mengubah ekspresi skalar menjadi ekspresi tensor. Singkatnya, f(xs...) secara kasar berarti "untuk setiap x skalar pada tensor xs, hitung f(x) skalar, lalu tampilkan semuanya hasil skalar ini bersama-sama sebagai hasil tensor". Dalam sintaks{i> <i}vanilla Python, contoh formula kita berubah menjadi: [dim(lhs, dim1) for dim1 in lhs_batching_dimensions] == [dim(rhs, dim2) for dim2 in rhs_batching_dimensions].

Berkat elips, sering kali kita bisa menghindari bekerja pada tingkat skalar individu. Namun, dalam beberapa kasus yang rumit, tingkat semi-informal di tingkat dapat digunakan seperti di formula start_indices[bi0, ..., :, ..., biN] dari spesifikasi gather. Untuk memberikan keringkasan, kami tidak menyediakan formalisme yang tepat untuk menerjemahkan sintaks tersebut ke vanilla Python, berharap bahwa hal itu masih dapat dipahami secara intuitif berdasarkan kasus per kasus. Harap beri tahu kami jika beberapa formula tertentu terlihat tidak tembus pandang, dan kami akan mencoba memperbaikinya.

Anda juga akan memperhatikan bahwa formula menggunakan elips untuk memperluas segala macam daftar, termasuk tensor, daftar tensor (yang misalnya dapat muncul dari jumlah tensor), dll. Ini adalah area lain di mana kami tidak memberikan formalisme (mis., daftar bahkan bukan bagian dari sistem jenis StableHLO) dan mengandalkan pemahaman intuitif.

C) Kendaraan notasi penting terakhir yang kami gunakan adalah penyiaran. Meskipun {i>opset<i} StableHLO tidak mendukung penyiaran implisit, formula dilakukan, juga untuk membuat keringkasan. Secara singkat, jika skalar digunakan dalam konteks di mana tensor diharapkan, skalarnya disiarkan ke sesuai dengan bentuk yang diinginkan.

Untuk melanjutkan contoh dot_general, berikut batasan lainnya: 0 <= lhs_batching_dimensions < rank(lhs). Seperti yang ditentukan dalam dot_general spesifikasi, lhs_batching_dimensions adalah tensor, tetapi 0 dan rank(lhs) adalah skalar. Setelah kita menerapkan siaran implisit, formulanya akan menjadi [0, ..., 0] <= lhs_batching_dimensions < [rank(lhs), ..., rank(lhs)].

Ketika diterapkan ke operasi dot_general tertentu, formula ini akan mengevaluasi ke tensor boolean. Ketika formula digunakan sebagai {i>constraints<i}, berlaku jika formula bernilai true atau ke tensor yang hanya memiliki true elemen.

Nama

Dalam formula, cakupan leksikal meliputi: 1) fungsi global, 2) definisi anggota,

3) definisi lokal. Daftar fungsi global disediakan di bawah ini. Daftar definisi elemen tergantung pada elemen program yang notasinya diterapkan ke:

  • Untuk operasi, definisi anggota menyertakan nama yang diperkenalkan dalam "Input" dan "Output" bagian.
  • Untuk yang lainnya, definisi anggota mencakup bagian struktural dari , dinamai menurut non-terminal EBNF yang sesuai. Sebagian besar waktu, nama-nama bagian struktural ini diperoleh dengan mengubah nama non-terminal untuk snake case (misalnya IntegerLiteral => integer_literal), tetapi terkadang nama disingkat dalam prosesnya (mis. QuantizationStorageType => storage_type), dalam hal ini nama tersebut diperkenalkan secara eksplisit mirip dengan "Input" / "Output" bagian yang beroperasi spesifikasi produk.
  • Selain itu, definisi anggota selalu menyertakan self untuk merujuk ke elemen program yang sesuai.

Nilai

Saat dievaluasi, formula menggunakan jenis nilai berikut: 1) Value (nilai sebenarnya, misalnya, dense<[[1, 2], [3, 4]]> : tensor<2x2xi32>; mereka selalu tahu tipenya), 2) Placeholder (nilai mendatang, misalnya lhs, rhs, atau result; nilai sebenarnya nilainya belum diketahui, hanya jenisnya yang diketahui), 3) Type (jenis seperti yang ditentukan di bagian "Types"), 4) Function (fungsi global seperti yang didefinisikan di bagian "Fungsi").

Bergantung pada konteksnya, nama mungkin merujuk pada nilai yang berbeda. Selengkapnya khususnya, "Semantik" untuk operasi (dan yang setara untuk program lain ) menentukan logika runtime, sehingga semua input tersedia sebagai Value. Sebaliknya, daftar “{i>Constraints<i}” untuk operasi (dan yang setara) mendefinisikan "waktu kompilasi" logika, yaitu sesuatu yang biasanya dieksekusi sebelum {i>runtime<i}, jadi hanya input konstan yang tersedia sebagai Value dan input lainnya yang hanya tersedia sebagai Placeholder.

Nama Di "Semantik" Di bagian "Constraints"
Fungsi global Function Function
Input konstan Value Value
Input tidak konstan Value Placeholder
Output Value Placeholder
Definisi lokal Tergantung definisi Tergantung definisi

Mari kita pertimbangkan contoh operasi transpose:

%result = "stablehlo.transpose"(%operand) {
  permutation = dense<[2, 1, 0]> : tensor<3xi64>
} : (tensor<2x3x2xi32>) -> tensor<2x3x2xi32>

Untuk operasi ini, permutation adalah konstanta, sehingga tersedia sebagai Value dalam semantik dan batasan. Sebaliknya, operand dan result adalah tersedia sebagai Value dalam semantik, tetapi hanya sebagai Placeholder dalam batasan.

Fungsi

Konstruksi jenis

Tidak ada fungsi yang dapat digunakan untuk membuat jenis. Sebaliknya, kita secara langsung menggunakan sintaks jenis karena biasanya lebih ringkas. Mis. (tensor<E>, tensor<E>) -> (tensor<E>), bukan function_type( [tensor_type([], E), tensor_type([], E)], [tensor_type([], E)]).

Fungsi pada jenis

  • element_type ditentukan berdasarkan jenis tensor dan jenis tensor terkuantisasi, serta menampilkan TensorElementType atau QuantizedTensorElementType bagian dari TensorType atau QuantizedTensorType yang sesuai.
def element_type(x: Value | Placeholder | Type):
 if type(x) == TensorType:
    return tensor_element_type(x)
  if type(x) == QuantizedTensorType:
    return quantized_tensor_element_type(x)
  if type(x) is not Type:
    return element_type(type(x))
  • is_per_axis_quantized(x: Value | Placeholder | Type) -> Value adalah pintasan untuk is_quantized(x) and quantization_dimension(x) is not None.

  • is_per_tensor_quantized(x: Value | Placeholder | Type) -> Value adalah untuk is_quantized(x) and quantization_dimension(x) is None.

  • is_promotable(x: Type, y: Type) -> bool memeriksa apakah jenis x dapat dipromosikan untuk mengetik y. Jika x dan y adalah QuantizedTensorElementType, promosi hanya diterapkan untuk storage_type. Versi promosi khusus ini saat ini digunakan dalam konteks komputasi pengurangan (lihat RFC untuk detail selengkapnya).

def is_promotable(x: Type, y: Type) -> Value:
  is_same_type = (is_bool(x) and is_bool(y)) or
    (is_integer(x) and is_integer(y)) or (is_float(x) and is_float(y)) or
    (is_complex(x) and is_complex(y)) or
    (is_quantized(x) and is_quantized(y) and expressed_type(x) = expressed_type(y))

  if is_same_type == False:
    return False

  if is_integer(x) or is_float(x):
    return bitwidth(x) <= bitwidth(y)

  if is_complex(x):
    return bitwidth(element_type(x)) <= bitwidth(element_type(y))

  if is_quantized(x):
    return bitwidth(storage_type(x)) <= bitwidth(storage_type(y))

  return false
  • is_quantized(x: Value | Placeholder | Type) -> Value adalah pintasan untuk is_quantized_tensor_element_type(x).

  • is_type_name(x: Value | Placeholder | Type) -> Value. Tersedia untuk semua jenis datanya. Misalnya, is_float(x) menampilkan true jika x adalah FloatType. Jika x adalah nilai atau placeholder, fungsi ini adalah pintasan untuk is_type_name(type(x)).

  • max_value(x: Type) -> Value menampilkan nilai maksimum TensorElementType. Jika x bukan TensorElementType, None akan ditampilkan.

  • min_value(x: Type) -> Value menampilkan nilai minimum yang memungkinkan TensorElementType. Jika x bukan TensorElementType, None akan ditampilkan.

  • member_name(x: Value | Placeholder | Type) -> Any. Tersedia untuk semua anggota definisi member_name dari semua jenis. Misalnya, tensor_element_type(x) menampilkan bagian TensorElementType dari TensorType yang sesuai. Jika x adalah nilai atau placeholder, fungsi ini adalah pintasan untuk member_name(type(x)). Jika x bukan jenis yang memiliki anggota yang sesuai, atau nilai atau placeholder dari jenis tersebut, akan menampilkan None.

  • is_empty_algorithm(*args: Type) memeriksa apakah semua kolom algoritma titik telah ditetapkan ke None. Ini diperlukan karena algoritma titik telah menentukan implementasi perilaku default, jadi menentukan nilai {i>default<i} akan menjadi tidak benar.

Konstruksi nilai

  • operation_name(*xs: Value | Type) -> Value. Tersedia untuk semua operasi. Misalnya, add(lhs, rhs) menggunakan dua nilai tensor lhs dan rhs, lalu menampilkan output dari mengevaluasi operasi add dengan input ini. Untuk beberapa operasi, misalnya broadcast_in_dim, jenis output-nya "beban", yaitu yang diperlukan untuk mengevaluasi operasi. Dalam hal ini, fungsi menganggap jenis ini sebagai argumen.

Fungsi pada nilai

  • Semua operator dan fungsi Python tersedia. Mis. keduanya langganan dan mengiris notasi dari Python tersedia untuk diindeks menjadi tensor, tensor terkuantisasi dan tuple.

  • to_destination_type(x: Value, destination_type: Type) -> Value ditentukan di dan menampilkan nilai yang dikonversi dari x berdasarkan type(x) dan destination_type sebagai berikut:

def to_destination_type(x: Value, destination_type: Type) -> Value:
  if type(x) == destination_type:
    return x

  if is_quantized(destination_type):
    if is_quantized(type(x)):
      return quantize(x, destination_type)
    assert is_float(type(x))
    return quantize(x, destination_type)

  if is_quantized(type(x)):
    assert destination_type = expressed_type(type(x))
    return dequantize(type(x))

  return convert(x, destination_type)

Ada diskusi awal tentang penggabungan convert, uniform_quantize, dan Operasi uniform_dequantize (#1576). Setelah penggabungan, kita tidak memerlukan fungsi di atas dan dapat menggunakan nama operasi untuk convert.

  • is_nan(x: Value) -> Value ditentukan pada tensor dan menampilkan true jika semua elemen x adalah NaN atau false jika tidak. Jika x bukan tensor, akan menampilkan None.

  • is_sorted(x: Value) -> Value ditentukan pada tensor dan menampilkan true jika elemen x diurutkan dalam urutan menaik sehubungan dengan urutan naik urutan leksikografis indeksnya atau false jika tidak. Jika x bukan , menghasilkan None.

  • is_unique(x: Value) -> Value ditentukan pada tensor dan menampilkan true jika x tidak memiliki elemen duplikat atau false. Jika x bukan tensor, akan menampilkan None.

  • member_name(x: Value) -> Any ditentukan untuk semua definisi anggota member_name dari semua nilai. Misalnya, real_part(x) menampilkan RealPart dari ComplexConstant yang sesuai. Jika x bukan nilai yang memiliki anggota yang sesuai, akan menampilkan None.

  • same(x: Value) -> Value ditentukan pada tensor dan menampilkan true jika elemen x semuanya sama satu sama lain atau false jika tidak. Jika tensor tidak memiliki elemen, yang dihitung sebagai “semua sama satu sama lain”, yaitu menampilkan true. Jika x bukan tensor, tampilkan None.

  • split(x: Value, num_results: Value, axis: Value) -> Value ditentukan di tensor dan menampilkan irisan num_results dari x di sepanjang sumbu axis. Jika x bukan tensor atau dim(x, axis) % num_results != 0, maka None akan ditampilkan.

  • is_defined_in_parent_scope(x: Value) -> Value ditentukan pada string dan menampilkan true jika x adalah nama fungsi yang ditentukan dalam cakupan yang sama sebagai fungsi induk dari op yang relevan.

  • is_namespaced_op_name(x: Value) -> Value ditentukan pada string dan ditampilkan true jika x adalah nama op yang valid, itu mengikuti metode reguler berikut ekspresi: [a-zA-Z][a-zA-Z0-9_]*([.][a-zA-Z0-9_$]+)+

Komputasi bentuk

  • axes(x: Value | Placeholder | Type) -> Value adalah pintasan untuk range(rank(x)).

  • dim(x: Value | Placeholder | Type, axis: Value) -> Value adalah pintasan untuk shape(x)[axis].

  • dims(x: Value | Placeholder | Type, axes: List) -> List adalah pintasan untuk list(map(lambda axis: dim(x, axis), axes)).

  • index_space(x: Value | Placeholder | Type) -> Value ditentukan pada tensor dan menampilkan indeks size(x) untuk TensorType yang sesuai yang diurutkan dalam urutan leksikografis menaik, yaitu [0, ..., 0], [0, ..., 1], ..., shape(x) - 1. Jika x bukan jenis tensor, jenis tensor terkuantisasi, atau nilai atau placeholder dari salah satu jenis ini, akan menampilkan None.

  • rank(x: Value | Placeholder | Type) -> Value adalah pintasan untuk size(shape(x)).

  • shape(x: Value | Placeholder | Type) -> Value ditentukan di bagian "Functions tentang jenis" melalui member_name.

  • size(x: Value | Placeholder | Type) -> Value adalah pintasan untuk reduce(lambda x, y: x * y, shape(x)).

Komputasi kuantisasi

  • def baseline_element_type(x: Value | Placeholder | Type) -> Type adalah untuk element_type(baseline_type(x)).

  • baseline_type ditentukan berdasarkan jenis tensor dan jenis tensor terkuantisasi, serta mengubahnya ke "baseline", yaitu tipe dengan bentuk yang sama tetapi dengan parameter kuantisasi jenis elemen direset ke nilai default. Ini adalah digunakan sebagai trik praktis untuk membandingkan tipe tensor dan tensor terkuantisasi seragam, yang sering dibutuhkan. Untuk jenis terkuantisasi, hal ini memungkinkan membandingkan jenis yang mengabaikan parameter kuantisasi, yaitu shape, storage_type, expressed_type, storage_min, storage_max, dan quantization_dimension (untuk jenis terkuantisasi per sumbu) harus semuanya cocok, tetapi scales dan zero points mungkin berbeda.

def baseline_type(x: Value | Placeholder | Type) -> Type:
  if type(x) == TensorType:
    return x
  if type(x) == QuantizedTensorType:
    element_type = quantized_tensor_element_type(x)
    baseline_element_type = QuantizedTensorElementType(
      storage_type = storage_type(element_type),
      storage_min = storage_min(element_type),
      storage_max = storage_max(element_type),
      expressed_type = expressed_type(element_type),
      quantization_dimension = quantization_dimension(element_type),
      scales = [constant(1.0, expressed_type(element_type))] * dim(x, quantization_dimension(element_type)),
      zero_points = [constant(0, storage_type(element_type))] * dim(x, quantization_dimension(element_type)))
    return QuantizedTensorType(shape(x), baseline_element_type)
  if type(x) is not Type:
    return baseline_element_type(type(x))
  • dequantize ditentukan pada jenis tensor terkuantisasi dan mengubahnya menjadi tipe tensor floating point. Hal ini terjadi melalui konversi elemen terkuantisasi yang mewakili nilai integer dari jenis penyimpanan ke dalam kolom nilai floating point dari jenis yang dinyatakan menggunakan titik nol dan skala yang terkait dengan jenis elemen terkuantisasi.
def compute_zero_points(quantized_type, result_type):
  if is_per_tensor_quantized(quantized_type):
    return broadcast_in_dim(constant(zero_point(quantized_type), storage_type(quantized_type)), [], result_type)
  if is_per_axis_quantized(quantized_type):
    for i in index_space(result_type):
      d = quantization_dimension(quantized_type)
      zero_points[i] = zero_points(quantized_type)[i[d]]
    return zero_points

def compute_scales(quantized_type, result_type):
  if is_per_tensor_quantized(quantized_type):
    return broadcast_in_dim(constant(scale(quantized_type), expressed_type(quantized_type)), [],
            type(result_type))
  if is_per_axis_quantized(quantized_type):
    for i in index_space(result_type):
      d = quantization_dimension(quantized_type)
      scales[i] = scales(quantized_type)[i[d]]
    return scales

def dequantize(x: Value) -> Value:
  assert is_quantized(x)
  x_storage = bitcast_convert(x, storage_type(x))
  x_storage_sub = x_storage - compute_zero_points(type(x), type(x_storage))
  x_expressed_sub = convert(x_storage_sub, expressed_type(x))
  return x_expressed_sub * compute_scales(type(x), type(x_expressed_sub))
  • quantize ditentukan pada jenis tensor floating point dan mengubahnya menjadi jenis tensor terkuantisasi. Hal ini terjadi melalui konversi nilai floating point dari jenis yang dinyatakan menjadi nilai bilangan bulat yang sesuai dari jenis penyimpanan menggunakan titik dan skala nol yang terkait dengan jenis elemen terkuantisasi.
def quantize(x: Value, result_type: Type) -> Value:
  assert is_float(x) and is_quantized(result_type)
  zero_points = compute_zero_points(result_type, TensorType(shape(x), storage_type(result_type)))
  converted_zero_points = convert(zero_points, expressed_type(result_type))
  converted_min = convert(storage_min(result_type), expressed_type(result_type))
  converted_max = convert(storage_max(result_type), expressed_type(result_type))

  x_scaled = x / compute_scales(result_type, type(x))
  x_scaled_add_zp = x_scaled + converted_zero_points
  x_clamped = clamp(converted_min, x_scaled_add_zp, converted_max)
  x_rounded = round_nearest_even(x_clamped)
  return convert(x_rounded, result_type)
  • dequantize_op_quantize digunakan untuk menentukan komputasi {i>element<i}-{i>wise<i} di tensor terkuantisasi. Model ini mendekuantisasi, yaitu mengubah elemen terkuantisasi menjadi tipe yang dinyatakan, lalu melakukan operasi, dan kemudian mengkuantisasi, yaitu mengubah kembali ke jenis penyimpanan mereka. Saat ini, fungsi tersebut hanya bekerja untuk kuantisasi per-tensor. Kuantisasi per sumbu sedang dalam proses (#1574).
def dequantize_op_quantize(op, *inputs_and_output_type):
  inputs = inputs_and_output_type[:-1]
  output_type = inputs_and_output_type[-1]

  float_inputs = map(dequantize, inputs)
  float_result = op(*float_inputs)
  return quantize(float_result, output_type)

def dequantize_batch_norm_grad_or_training_quantize(op, *inputs_and_output_types):
  inputs = inputs_and_output_type[:-3]
  float_inputs = map(dequantize, inputs)
  float_results = op(*float_inputs)
  return map(quantize, float_results, inputs_and_output_type[-3:])

def dequantize_compare(lhs, rhs, comparison_direction):
  float_lhs = dequantize(lhs)
  float_rhs = dequantize(rhs)
  return compare(float_lhs, float_rhs, comparison_direction, FLOAT)

def dequantize_select_quantize(pred, on_true, on_false, output_type):
  float_on_true = dequantize(on_true)
  float_on_false = dequantize(on_false)
  float_result = select(pred, float_on_true, float_on_false)
  return quantize(float_result, output_type)
  • hybrid_dequantize_then_op digunakan untuk menentukan kuantisasi hanya bobot untuk operasi hybrid yang menerima lhs dalam floating point dan rhs dalam jenis terkuantisasi. Ini mendekuantisasi input terkuantisasi ke dalam jenis yang dinyatakan dan melakukan komputasi dalam {i>float<i}. Jenis elemen tensor float lhs dan jenis rhs terkuantisasi yang dinyatakan Tensor harus identik.
def hybrid_dequantize_then_op(op, lhs, rhs):
  assert(is_float(lhs) and is_quantized(rhs) and element_type(lhs) == expressed_type(rhs))
  return op(lhs, dequantize(rhs))

Komputasi grid

  • cross_partition(replica_groups: Value) -> Value. Lihat "cross_replica" di atas.

  • cross_replica(replica_groups: Value) -> Value. Lihat "cross_replica" di atas.

  • cross_replica_and_partition(replica_groups: Value) -> Value. Lihat &quot;cross_replica_and_partition&quot; di atas.

  • flattened_ids(replica_groups: Value) -> Value. Lihat "flattened_id" di atas.

Dinamisme

Nilai StabilHLO dapat memiliki ukuran dimensi dinamis, mis. tensor<?xi64>. Namun, nilai StableHLO tidak boleh memiliki jumlah dimensi dinamis (tidak memiliki peringkat dinamis, misalnya tensor<*xi64>). Operand dan hasil diizinkan untuk menggunakan ukuran dimensi, meskipun ada batasan ukurannya. Batasan akan diverifikasi secara statis jika memungkinkan, jika tidak, mereka ditangguhkan untuk runtime dan ketidakcocokan data akan menyebabkan perilaku yang tidak terdefinisi. Lihat contoh berikut.

Ketidakcocokan bentuk untuk operasi elementwise unary

Pertimbangkan program mainan berikut:

func.func @foo(%arg0: tensor<?xf64>) {
  %0 = stablehlo.abs %arg0 : (tensor<?xf64>) -> tensor<2xf64>
  return
}

Program semacam itu tidak biasa, karena tidak umum mengetahui bentuk hasil, tetapi bukan bentuk inputnya. Meskipun demikian, ini adalah StableHLO yang valid program ini. Operasi abs tidak dapat divalidasi secara statis dalam , karena bentuk operand yang tepat tidak diketahui. Namun, bentuk memang kompatibel, dan ini dapat diperiksa secara statis: ? dapat berubah menjadi 2 saat runtime, dan tidak akan ada masalah. Namun, ? dapat juga berubah menjadi beberapa integer lain, dalam hal ini perilakunya tidak terdefinisi.

Perhatikan bahwa jika ukuran dimensi dinamis dalam hasil, tidak boleh ada perilaku yang tidak terdefinisi. Sebenarnya, tidak ada kata "yang diharapkan" sehingga tidak bisa ada tidak cocok.

Ketidakcocokan bentuk untuk operasi elementwise biner

Pertimbangkan program mainan berikut:

func.func @foo(%arg0: tensor<?xf64>, %arg1: tensor<?xf64>) {
  %0 = stablehlo.add %arg0, %arg0 : (tensor<?xf64>, tensor<?xf64>) -> tensor<?xf64>
  return
}

Dalam hal operasi elementwise biner, bentuk input dan hasil harus disetujui pada runtime. Pada waktu kompilasi, dimensi statis harus sama, jika tidak, mereka hanya harus kompatibel. Jika salah satu dimensi dinamis dalam input, mungkin ada dimensi yang tidak ditentukan perilaku saat runtime, karena ukuran dinamis mungkin tidak cocok dengan di operand lain (baik itu statis maupun dinamis). Jika semua input statis, maka apakah hasilnya dinamis atau tidak, akan menjadi masalah: secara statis dimensi yang diketahui akan diperiksa secara statis, dan dimensi dinamis tidak menerapkan batasan apa pun.

Ketidakcocokan bentuk untuk operasi yang mengambil bentuk output-nya sebagai operand

Pertimbangkan program mainan berikut:

func.func @foo(%arg0: tensor<2xi32>) {
  %0 = stablehlo.dynamic_iota %arg0, dim = 0 : (tensor<2xi32>) -> tensor<3x4xi64>
  return
}

Nilai dalam operand bentuk saat runtime harus cocok dengan bentuk hasil, jika tidak, perilaku tidak terdefinisi. Artinya, pada runtime %arg0 harus memiliki senilai dense<[3, 4]> : tensor<2xi32>. Jika operand bentuk konstan, ini dapat diverifikasi secara statis. Jika bentuk hasil sepenuhnya dinamis, maka ada tidak boleh berupa ketidakcocokan.