StableHLO adalah set operasi untuk operasi tingkat tinggi (HLO) dalam model machine learning (ML). StableHLO berfungsi sebagai lapisan portabilitas antara berbagai framework ML dan compiler ML: framework ML yang menghasilkan program StableHLO kompatibel dengan compiler ML yang menggunakan program StableHLO.
Tujuan kami adalah menyederhanakan dan mempercepat pengembangan ML dengan menciptakan lebih banyak interoperabilitas antara berbagai framework ML (seperti TensorFlow, JAX, dan PyTorch) dan compiler ML (seperti XLA dan IREE). Untuk itu, dokumen ini memberikan spesifikasi untuk bahasa pemrograman StableHLO.
Spesifikasi ini berisi tiga bagian utama. Pertama, bagian Program menjelaskan struktur program StableHLO yang terdiri dari fungsi StableHLO yang terdiri dari operasi StableHLO. Dalam struktur tersebut, bagian Ops menentukan semantik setiap operasi. Bagian Execution menyediakan semantik untuk semua operasi ini yang dijalankan bersama dalam sebuah program. Terakhir, bagian Notasi membahas notasi yang digunakan di seluruh spesifikasi.
Untuk melihat spesifikasi dari rilis StableHLO sebelumnya, buka repo di rilis yang diberi tag yang diinginkan. Misalnya, Spesifikasi StableHLO v0.19.0. Untuk melihat perubahan yang terjadi pada setiap peningkatan versi minor StableHLO, lihat log versi di VhloDialect.td.
Program
Program ::= {Func}
Program StableHLO terdiri dari sejumlah fungsi StableHLO.
Di bawah ini adalah contoh program dengan fungsi @main yang memiliki 3 input
(%image, %weights, dan %bias) serta 1 output. Isi fungsi
memiliki 6 operasi.
func.func @main(
%image: tensor<28x28xf32>,
%weights: tensor<784x10xf32>,
%bias: tensor<1x10xf32>
) -> tensor<1x10xf32> {
%0 = "stablehlo.reshape"(%image) : (tensor<28x28xf32>) -> tensor<1x784xf32>
%1 = "stablehlo.dot"(%0, %weights) : (tensor<1x784xf32>, tensor<784x10xf32>) -> tensor<1x10xf32>
%2 = "stablehlo.add"(%1, %bias) : (tensor<1x10xf32>, tensor<1x10xf32>) -> tensor<1x10xf32>
%3 = "stablehlo.constant"() {value = dense<0.0> : tensor<1x10xf32>} : () -> tensor<1x10xf32>
%4 = "stablehlo.maximum"(%2, %3) : (tensor<1x10xf32>, tensor<1x10xf32>) -> tensor<1x10xf32>
"func.return"(%4): (tensor<1x10xf32>) -> ()
}
Fungsi
Func ::= 'func' '.' 'func' FuncId FuncInputs FuncOutputs '{' FuncBody '}'
FuncInputs ::= '(' [FuncInput {',' FuncInput}] `)`
FuncInput ::= ValueId ':' ValueType
FuncOutputs ::= ['->' FuncOutput, {',' FuncOutput}]
FuncOutput ::= ValueType
FuncBody ::= {Op}
Fungsi StableHLO (yang juga disebut fungsi bernama) memiliki ID, input/output, dan isi. Pada masa mendatang, kami berencana untuk memperkenalkan metadata tambahan untuk fungsi guna mencapai kompatibilitas yang lebih baik dengan HLO (#425, #626, #740, #744).
ID
FuncId ::= '@' letter {letter | digit}
ValueId ::= '%' digit {digit}
| '%' letter {letter | digit}
letter ::= 'a' | ... | 'z' | 'A' | ... | 'Z' | '_'
digit ::= '0' | ... | '9'
ID StableHLO mirip dengan ID di banyak bahasa pemrograman, dengan dua keunikan: 1) semua ID memiliki sigil yang membedakan berbagai jenis ID, 2) ID nilai dapat berupa angka sepenuhnya untuk menyederhanakan pembuatan program StableHLO.
Jenis
Type ::= ValueType | NonValueType
ValueType ::= TensorType | QuantizedTensorType | TokenType | TupleType | BufferType
NonValueType ::= TensorElementType | QuantizedTensorElementType | FunctionType | StringType
Jenis StableHLO dikategorikan ke dalam jenis nilai (yang juga disebut jenis kelas satu) yang merepresentasikan nilai StableHLO dan jenis non-nilai yang menjelaskan elemen program lainnya. Jenis StableHLO mirip dengan jenis di banyak bahasa pemrograman, dengan keunikan utama adalah sifat khusus domain StableHLO yang menghasilkan beberapa hasil yang tidak biasa (misalnya, jenis skalar bukan jenis nilai).
TensorType ::= 'tensor' '<' Shape TensorElementType '>'
Shape ::= {DimensionSize 'x'}
DimensionSize ::= digit {digit} | '?'
Jenis tensor merepresentasikan tensor, yaitu array multidimensi. Objek ini memiliki
bentuk dan jenis elemen, dengan bentuk yang merepresentasikan ukuran dimensi non-negatif atau tidak diketahui dalam urutan menaik dari dimensi yang sesuai (yang juga disebut sumbu) yang diberi nomor dari 0 hingga R-1. Jumlah dimensi R disebut rank. Misalnya, tensor<2x3xf32> adalah
jenis tensor dengan bentuk 2x3 dan jenis elemen f32. Array ini memiliki dua dimensi
(atau, dengan kata lain, dua sumbu) - dimensi ke-0 dan dimensi ke-1 - yang ukurannya
adalah 2 dan 3. Peringkatnya adalah 2.
Bentuk dapat sebagian atau sepenuhnya tidak diketahui (dinamis), misalnya, tensor<?x2xf64>
sebagian tidak diketahui dan tensor<?x?xf64> sepenuhnya tidak diketahui. Ukuran dimensi
dinamis ditampilkan menggunakan ?. Bentuk tidak dapat dibatalkan peringkatnya.
Pada masa mendatang, kami berencana untuk memperluas jenis tensor di luar ukuran dimensi dan jenis elemen, misalnya, untuk menyertakan tata letak (#629) dan kejarangan (#1078).
QuantizedTensorType ::= 'tensor' '<' Shape QuantizedTensorElementType '>'
QuantizedTensorElementType ::= '!quant.uniform' '<'
QuantizationStorageType
['<' QuantizationStorageMin ':' QuantizationStorageMax '>']
':' QuantizationExpressedType
[':' QuantizationDimension]
',' QuantizationParameters '>'
QuantizationStorageType ::= IntegerType
QuantizationStorageMin ::= IntegerLiteral
QuantizationStorageMax ::= IntegerLiteral
QuantizationExpressedType ::= FloatType
QuantizationDimension ::= IntegerLiteral
QuantizationParameters ::= QuantizationParameter
| '{' QuantizationParameter {',' QuantizationParameter} '}'
QuantizationParameter ::= QuantizationScale [':' QuantizationZeroPoint]
QuantizationScale ::= FloatLiteral
QuantizationZeroPoint ::= IntegerLiteral
| Nama | Jenis | Batasan |
|---|---|---|
storage_type |
integer type | (C1-C3), (C8) |
storage_min |
konstanta bilangan bulat | (C1), (C3), (C7) |
storage_max |
konstanta bilangan bulat | (C2), (C3), (C7) |
expressed_type |
jenis floating point | (C4) |
quantization_dimension |
konstanta bilangan bulat opsional | (C10-C12) |
scales |
sejumlah konstanta floating point variadik | (C4-C6), (C9), (C10), (C13) |
zero_points |
jumlah konstanta bilangan bulat variadik | (C7-C9) |
Jenis elemen terkuantisasi merepresentasikan nilai bilangan bulat dari jenis penyimpanan dalam
rentang dari storage_min hingga storage_max (inklusif) yang sesuai dengan
nilai floating point dari jenis yang dinyatakan. Untuk nilai bilangan bulat i tertentu,
nilai floating point f yang sesuai dapat dihitung sebagai
f = (i - zero_point) * scale, dengan scale dan zero_point disebut
parameter kuantisasi. storage_min dan storage_max bersifat opsional
dalam tata bahasa, tetapi memiliki nilai default min_value(storage_type) dan
max_value(storage_type). Jenis elemen terkuantisasi memiliki batasan berikut:
- (C1)
type(storage_min) = storage_type. - (C2)
type(storage_max) = storage_type. - (C3)
min_value(storage_type) <= storage_min < storage_max <= max_value(storage_type). - (C4)
type(scales...) = expressed_type. - (C5)
0 < scales. - (C6)
is_finite(scales...). - (C7)
storage_min <= zero_points <= storage_max. - (C8)
type(zero_points...) = storage_type. - (C9)
size(scales) = size(zero_points). - (C10) Jika
is_empty(quantization_dimension), makasize(scales) = 1. - (C11)
0 <= quantization_dimension.
Saat ini, QuantizationScale adalah konstanta floating point, tetapi ada minat yang kuat terhadap skala berbasis bilangan bulat, yang diwakili dengan pengali dan pergeseran. Kami berencana untuk menjelajahi fitur ini dalam waktu dekat
(#1404).
Diskusi tentang semantik QuantizationZeroPoint sedang berlangsung, termasuk jenis, nilai, dan apakah hanya ada satu atau beberapa titik nol yang berpotensi ada dalam jenis tensor terkuantisasi. Berdasarkan hasil diskusi ini, spesifikasi seputar nol poin dapat berubah di masa mendatang (#1405).
Diskusi berkelanjutan lainnya melibatkan semantik QuantizationStorageMin dan QuantizationStorageMax untuk menentukan apakah ada batasan yang harus diterapkan pada nilai ini dan pada nilai tensor terkuantisasi (#1406).
Terakhir, kami berencana untuk mengeksplorasi cara menampilkan skala yang tidak diketahui dan titik nol, serupa dengan cara kami berencana untuk mengeksplorasi cara menampilkan ukuran dimensi yang tidak diketahui (#1407).
Jenis tensor terkuantisasi merepresentasikan tensor dengan elemen terkuantisasi. Tensor ini sama persis dengan tensor biasa, kecuali elemennya memiliki jenis elemen terkuantisasi, bukan jenis elemen biasa.
Dalam tensor terkuantisasi, kuantisasi dapat dilakukan per-tensor, yang berarti memiliki
satu scale dan zero_point untuk seluruh tensor atau dapat dilakukan per-axis,
yang berarti memiliki beberapa scales dan zero_points, satu pasangan per slice
dari dimensi quantization_dimension tertentu. Secara lebih formal, dalam tensor t
dengan kuantisasi per sumbu, ada dim(t, quantization_dimension) slice
dari quantization_dimension: t[:, ..., 0, ..., :], t[:, ..., 1, ..., :],
dll. Semua elemen dalam slice ke-i menggunakan scales[i] dan zero_points[i] sebagai
parameter kuantisasinya. Jenis tensor terkuantisasi memiliki batasan berikut:
- Untuk kuantisasi per-tensor:
- Tidak ada batasan tambahan.
- Untuk kuantisasi per sumbu:
- (C12)
quantization_dimension < rank(self). - (C13)
dim(self, quantization_dimension) = size(scales).
- (C12)
TokenType ::= 'token'
Jenis token merepresentasikan token, yaitu nilai buram yang dihasilkan dan digunakan oleh beberapa operasi. Token digunakan untuk memaksakan urutan eksekusi pada operasi seperti yang dijelaskan di bagian Execution.
TupleType ::= 'tuple' '<' TupleElementTypes '>'
TupleElementTypes ::= [ValueType {',' ValueType}]
Jenis buffer merepresentasikan buffer. Misalnya, di XLA, buffer adalah
array multidimensi dengan penyimpanan yang konsisten. Mirip dengan jenis tensor,
jenis buffer memiliki bentuk dan jenis elemen, dengan bentuk yang merepresentasikan
ukuran dimensi non-negatif atau tidak diketahui dalam urutan menaik dari
dimensi yang sesuai (yang juga disebut sumbu) yang diberi nomor dari 0
hingga R-1. Jumlah dimensi R disebut rank. Misalnya,
memref<2x3xf32> adalah jenis buffer dengan bentuk 2x3 dan jenis elemen f32. Array ini memiliki dua dimensi (atau, dengan kata lain, dua sumbu) - dimensi ke-0 dan dimensi ke-1 - yang ukurannya adalah 2 dan 3. Peringkatnya adalah 2.
Buffer dapat dialokasikan menggunakan custom_call hingga CreateBuffer atau Pin dan
dibatalkan alokasinya melalui custom_call hingga Unpin. Hanya operasi custom_call yang dapat membaca dan
menulis konten di dalam buffer. Lihat custom_call untuk mengetahui detail
selengkapnya.
Jenis tuple merepresentasikan tuple, yaitu daftar heterogen. Tuple adalah fitur lama yang hanya ada untuk kompatibilitas dengan HLO. Di HLO, tuple digunakan untuk merepresentasikan input dan output variadik. Di StableHLO, input dan
output variadik didukung secara native, dan satu-satunya penggunaan tuple di StableHLO adalah untuk
merepresentasikan ABI HLO secara komprehensif, misalnya T, tuple<T>, dan
tuple<tuple<T>> mungkin sangat berbeda bergantung pada implementasi tertentu. Pada masa mendatang, kami berencana melakukan perubahan pada ABI HLO
yang dapat memungkinkan kami menghapus jenis tuple dari StableHLO
(#598).
TensorElementType ::= BooleanType | IntegerType | FloatType | ComplexType
BooleanType ::= 'i1'
IntegerType ::= SignedIntegerType | UnsignedIntegerType
SignedIntegerType ::= 'si2' | 'si4' | 'si8' | 'si16' | 'si32' | 'si64'
UnsignedIntegerType ::= 'ui2' | 'ui4' | 'ui8' | 'ui16' | 'ui32' | 'ui64'
FloatType ::= 'f4E2M1FN' | 'f6E2M3FN' | 'f6E3M2FN' | 'f8E3M4' | 'f8E4M3'
| 'f8E4M3FN' | 'f8E4M3FNUZ' | 'f8E4M3B11FNUZ' | 'f8E5M2'
| 'f8E5M2FNUZ' | 'f8E8M0FNU' | 'bf16' | 'f16' | 'f32' | 'f64'
TensorFloat32 ::= 'tf32'
ComplexType ::= 'complex' '<' ComplexElementType '>'
ComplexElementType ::= 'f32' | 'f64'
Jenis elemen merepresentasikan elemen jenis tensor. Tidak seperti di banyak bahasa pemrograman, jenis ini bukan kelas utama di StableHLO. Artinya, program StableHLO tidak dapat secara langsung merepresentasikan nilai jenis ini (sebagai hasilnya, nilai skalar jenis T direpresentasikan secara idiomatis dengan nilai tensor 0 dimensi jenis tensor<T>).
- Jenis boolean merepresentasikan nilai boolean
truedanfalse. - Jenis bilangan bulat dapat ditandai (
si) atau tidak ditandai (ui) dan memiliki salah satu lebar bit yang didukung (2,4,8,16,32, atau64). JenissiNyang ditandai merepresentasikan nilai bilangan bulat dari-2^(N-1)hingga2^(N-1)-1inklusif, dan jenisuiNyang tidak ditandai merepresentasikan nilai bilangan bulat dari0hingga2^N-1inklusif. - Jenis floating point dapat berupa salah satu dari berikut:
- Bilangan floating point 8-bit
f8E3M4,f8E4M3, danf8E5M2yang mengikuti konvensi IEEE-754. - Jenis
f8E4M3FNdanf8E5M2masing-masing sesuai dengan encodingE4M3danE5M2dari format FP8 yang dijelaskan dalam FP8 Formats for Deep Learning. - Jenis
f8E4M3FNUZdanf8E5M2FNUZyang sesuai dengan encodingE4M3danE5M2format FP8 yang dijelaskan dalam Format Numerik 8-bit untuk Jaringan Neural Dalam (Deep Neural Network). - Jenis
f8E4M3B11FNUZyang sesuai dengan encodingE4M3format FP8 yang dijelaskan dalam Pelatihan dan Inferensi Floating Point 8-bit Hibrida (HFP8) untuk Jaringan Neural Dalam. - Jenis
bf16yang sesuai dengan formatbfloat16yang dijelaskan dalam BFloat16: Rahasia performa tinggi di Cloud TPU. - Jenis
f16,f32, danf64yang masing-masing sesuai denganbinary16("presisi setengah"),binary32("presisi tunggal"), danbinary64("presisi ganda") yang dijelaskan dalam standar IEEE 754. - Jenis
tf32sesuai dengan format TensorFloat32 dan memiliki dukungan terbatas di StableHLO. f4E2M1FN,f6E2M3FN,f6E3M2FN, danf8E8M0FNUjenis MX (penskalaan mikro) yang dijelaskan dalam Spesifikasi Format Penskalaan Mikro OCP.
- Bilangan floating point 8-bit
- Jenis kompleks merepresentasikan nilai kompleks yang memiliki bagian riil
dan bagian imajiner dari jenis elemen yang sama. Jenis kompleks yang didukung adalah
complex<f32>(kedua bagian berjenisf32) dancomplex<f64>(kedua bagian berjenisf64).
FunctionType ::= '(' InputTypes ')' '->' '(' OutputTypes ')'
InputTypes ::= [ValueType {',' ValueType}]
OutputTypes ::= [ValueType {',' ValueType}]
Jenis fungsi merepresentasikan fungsi bernama dan anonim. Mereka memiliki
jenis input (daftar jenis di sisi kiri ->) dan jenis output
(daftar jenis di sisi kanan ->). Dalam banyak bahasa pemrograman, jenis fungsi adalah kelas utama, tetapi tidak di StableHLO.
StringType ::= 'string'
Jenis string merepresentasikan urutan byte. Tidak seperti di banyak bahasa pemrograman, jenis string bukan kelas utama di StableHLO dan hanya digunakan untuk menentukan metadata statis untuk elemen program.
Operasi
Operasi StableHLO (yang juga disebut ops) merepresentasikan sekumpulan operasi tingkat tinggi yang tertutup dalam model machine learning. Seperti yang dibahas di atas, sintaksis StableHLO sangat terinspirasi oleh MLIR, yang belum tentu merupakan alternatif paling ergonomis, tetapi bisa dibilang paling cocok untuk tujuan StableHLO dalam menciptakan lebih banyak interoperabilitas antara framework ML dan compiler ML.
Op ::= [OpOutputs] OpName OpInputs ':' OpSignature
OpName ::= '"' 'stablehlo' '.' OpMnemonic '"'
OpMnemonic ::= 'abs' | 'add' | ...
Operasi StableHLO (yang juga disebut ops) memiliki nama,
input/output, dan tanda tangan. Nama terdiri dari awalan stablehlo. dan
mnemonik yang secara unik mengidentifikasi salah satu operasi yang didukung. Lihat di bawah untuk mengetahui daftar lengkap semua operasi yang didukung.
OpInputs ::= OpInputValues OpInputFuncs OpInputAttrs
OpInputValues ::= '(' [OpInputValue {',' OpInputValue}] ')'
OpInputValue ::= ValueId
OpInputFuncs ::= ['(' OpInputFunc {',' OpInputFunc} ')']
OpInputAttrs ::= ['{' OpInputAttr {',' OpInputAttr} '}']
OpOutputs ::= [OpOutput {',' OpOutput} '=']
OpOutput ::= ValueId
Operasi menggunakan input dan menghasilkan output. Input dikategorikan ke dalam
nilai input (dihitung selama eksekusi), fungsi input (disediakan
secara statis, karena dalam fungsi StableHLO bukan nilai kelas satu) dan
atribut input (juga disediakan secara statis). Jenis input dan output yang digunakan dan dihasilkan oleh operasi bergantung pada mnemoniknya. Misalnya, operasi add
menggunakan 2 nilai input dan menghasilkan 1 nilai output. Sebagai perbandingan, operasi
select_and_scatter menggunakan 3 nilai input, 2 fungsi input, dan
3 atribut input.
OpInputFunc ::= '{' Unused FuncInputs ':' FuncBody '}'
Unused ::= '^' digit {digit}
| '^' letter {letter | digit}
Fungsi input (yang juga disebut fungsi anonim) sangat
mirip dengan fungsi bernama, kecuali: 1) fungsi ini tidak memiliki ID (oleh karena itu
disebut "anonim"), 2) fungsi ini tidak mendeklarasikan jenis output (jenis output
disimpulkan dari operasi return dalam fungsi).
Sintaksis untuk fungsi input mencakup bagian yang saat ini tidak digunakan (lihat produksi
Unused di atas) yang ada untuk kompatibilitas dengan MLIR. Di MLIR,
ada konsep "region" yang lebih umum yang dapat memiliki beberapa "blok"
operasi yang terhubung bersama melalui operasi lompatan. Blok ini memiliki ID yang sesuai dengan produksi Unused, sehingga dapat dibedakan satu sama lain.
StableHLO tidak memiliki operasi lompatan, sehingga bagian sintaksis MLIR yang sesuai tidak digunakan (tetapi masih ada).
OpInputAttr ::= OpInputAttrName '=' OpInputAttrValue
OpInputAttrName ::= letter {letter | digit}
OpInputAttrValue ::= Constant
Atribut input memiliki nama dan nilai yang merupakan salah satu konstanta yang didukung. Ini adalah cara utama untuk menentukan metadata statis untuk elemen program. Misalnya, operasi concatenate menggunakan atribut dimension untuk
menentukan dimensi yang digunakan untuk menggabungkan nilai inputnya. Demikian pula,
operasi slice menggunakan beberapa atribut seperti start_indices dan limit_indices
untuk menentukan batas yang digunakan untuk mengiris nilai input.
Saat ini, program StableHLO di luar sana terkadang berisi atribut yang tidak dijelaskan dalam dokumen ini. Ke depannya, kami berencana untuk menggabungkan atribut ini ke dalam opset StableHLO atau melarangnya muncul dalam program StableHLO. Sementara itu, berikut daftar atribut tersebut:
layout(#629).mhlo.frontend_attributes(#628).mhlo.sharding(#619).output_operand_aliases(#740).- Metadata lokasi (#594).
OpSignature ::= '(' [ValueType {',' ValueType}] ')' '->' '(' [ValueType {',' ValueType}] ')'
Tanda tangan op terdiri dari jenis semua nilai input (daftar jenis di sisi kiri ->) dan jenis semua nilai output (daftar jenis di sisi kanan ->). Sebenarnya, jenis input bersifat redundan, dan jenis output juga hampir selalu redundan (karena untuk sebagian besar operasi StableHLO, jenis output dapat disimpulkan dari input). Namun demikian, tanda tangan op sengaja menjadi bagian dari sintaksis StableHLO untuk kompatibilitas dengan MLIR.
Berikut adalah contoh operasi yang mnemoniknya adalah select_and_scatter. Fungsi ini menggunakan 3
nilai input (%operand, %source, dan %init_value), 2 fungsi input
, dan 3 atribut input (window_dimensions, window_strides, dan padding).
Perhatikan cara tanda tangan op hanya menyertakan jenis nilai inputnya
(tetapi bukan jenis fungsi dan atribut input yang disediakan secara inline).
%result = "stablehlo.select_and_scatter"(%operand, %source, %init_value) ({
^bb0(%arg0: tensor<i32>, %arg1: tensor<i32>):
%0 = "stablehlo.compare"(%arg0, %arg1) {
comparison_direction = #stablehlo<comparison_direction GE>
} : (tensor<i32>, tensor<i32>) -> tensor<i1>
"stablehlo.return"(%0) : (tensor<i1>) -> ()
}, {
^bb0(%arg0: tensor<i32>, %arg1: tensor<i32>):
%0 = "stablehlo.add"(%arg0, %arg1) : (tensor<i32>, tensor<i32>) -> tensor<i32>
"stablehlo.return"(%0) : (tensor<i32>) -> ()
}) {
window_dimensions = dense<[3, 1]> : tensor<2xi64>,
window_strides = dense<[2, 1]> : tensor<2xi64>,
padding = dense<[[0, 1], [0, 0]]> : tensor<2x2xi64>
} : (tensor<4x2xi32>, tensor<2x2xi32>, tensor<i32>) -> tensor<4x2xi32>
Konstanta
Constant ::= BooleanConstant
| IntegerConstant
| FloatConstant
| ComplexConstant
| TensorConstant
| QuantizedTensorConstant
| StringConstant
| EnumConstant
Konstanta StableHLO memiliki literal dan jenis yang bersama-sama merepresentasikan
nilai StableHLO. Umumnya, jenis adalah bagian dari sintaksis konstanta, kecuali
jika tidak ambigu (misalnya, konstanta boolean tidak ambigu memiliki jenis i1,
sedangkan konstanta bilangan bulat dapat memiliki beberapa kemungkinan jenis).
BooleanConstant ::= BooleanLiteral
BooleanLiteral ::= 'true' | 'false'
Konstanta boolean merepresentasikan nilai boolean true dan false. Konstanta
Boolean memiliki jenis i1.
IntegerConstant ::= IntegerLiteral ':' IntegerType
IntegerLiteral ::= ['-' | '+'] DecimalDigits
| ['-' | '+'] '0x' HexadecimalDigits
DecimalDigits ::= decimalDigit {decimalDigit}
HexadecimalDigits ::= hexadecimalDigit {hexadecimalDigit}
decimalDigit ::= '0' | ... | '9'
hexadecimalDigit ::= decimalDigit | 'a' | ... | 'f' | 'A' | ... | 'F'
Konstanta bilangan bulat merepresentasikan nilai bilangan bulat melalui string yang menggunakan notasi desimal atau heksadesimal. Basis lainnya, misalnya biner atau oktal, tidak didukung. Konstanta bilangan bulat memiliki batasan berikut:
- (C1)
is_wellformed(integer_literal, integer_type).
FloatConstant ::= FloatLiteral ':' FloatType
FloatLiteral ::= SignPart IntegerPart FractionalPart ScientificPart
| '0x' [HexadecimalDigits]
SignPart ::= ['-' | '+']
IntegerPart ::= DecimalDigits
FractionalPart ::= ['.' [DecimalDigits]]
ScientificPart ::= [('e' | 'E') ['-' | '+'] DecimalDigits]
Konstanta floating point merepresentasikan nilai floating point melalui string yang menggunakan notasi desimal atau ilmiah. Selain itu, notasi heksadesimal dapat digunakan untuk menentukan secara langsung bit yang mendasarinya dalam format floating point dari jenis yang sesuai. Konstanta floating point memiliki batasan berikut:
- (C1) Jika notasi non-heksadesimal digunakan,
is_wellformed(float_literal, float_type). - (C2) Jika notasi heksadesimal digunakan,
size(hexadecimal_digits) = num_bits(float_type) / 4.
ComplexConstant ::= ComplexLiteral ':' ComplexType
ComplexLiteral ::= '(' RealPart ',' ImaginaryPart ')'
RealPart ::= FloatLiteral
ImaginaryPart ::= FloatLiteral
Konstanta kompleks merepresentasikan nilai kompleks menggunakan daftar bagian real
(muncul pertama) dan bagian imajiner (muncul kedua). Misalnya,
(1.0, 0.0) : complex<f32> mewakili 1.0 + 0.0i, dan
(0.0, 1.0) : complex<f32> mewakili 0.0 + 1.0i. Urutan bagian-bagian ini disimpan dalam memori ditentukan oleh implementasi. Konstanta kompleks
memiliki batasan berikut:
- (C1)
is_wellformed(real_part, complex_element_type(complex_type)). - (C2)
is_wellformed(imaginary_part, complex_element_type(complex_type)).
TensorConstant ::= TensorLiteral ':' TensorType
TensorLiteral ::= 'dense' '<' (DenseLiteral | ElementLiteral) '>'
DenseLiteral ::= DenseDimension | DenseElements
DenseDimension ::= '[' [DenseLiteral {',' DenseLiteral}] ']'
DenseElements ::= [ElementLiteral {',' ElementLiteral}]
ElementLiteral ::= BooleanLiteral | IntegerLiteral | FloatLiteral | ComplexLiteral
Konstanta tensor merepresentasikan nilai tensor menggunakan daftar bertingkat yang ditentukan melalui notasi NumPy. Misalnya, dense<[[1, 2, 3], [4, 5, 6]]> : tensor<2x3xi32>
merepresentasikan nilai tensor dengan pemetaan berikut dari indeks ke elemen:
{0, 0} => 1, {0, 1} => 2, {0, 2} => 3, {1, 0} => 4, {1, 1} => 5,
{1, 2} => 6. Urutan elemen ini disimpan dalam memori ditentukan oleh implementasi. Konstanta tensor memiliki batasan berikut:
- (C1)
has_syntax(tensor_literal, element_type(tensor_type)), dengan:has_syntax(element_literal: Syntax, element_type: Type) = is_wellformed(element_literal, type).has_syntax(tensor_literal: List, element_type: Type) = has_syntax(tensor_literal..., element_type).
- (C2)
has_shape(tensor_literal, shape(tensor_type)), dengan:has_shape(element_literal: Syntax, []) = true.has_shape(tensor_literal: List, shape: List) = size(tensor_literal) = shape[0] and has_shape(tensor_literal..., shape[1:]).- jika tidak,
false.
QuantizedTensorConstant ::= QuantizedTensorLiteral ':' QuantizedTensorType
QuantizedTensorLiteral ::= 'dense' '<' (DenseLiteral | ElementLiteral) '>'
Konstanta tensor terkuantisasi merepresentasikan nilai tensor terkuantisasi menggunakan notasi yang sama dengan konstanta tensor, dengan elemen yang ditentukan sebagai konstanta jenis penyimpanan. Konstanta tensor terkuantisasi memiliki batasan berikut:
- (C1)
has_syntax(quantized_tensor_literal, storage_type(quantized_tensor_type)). - (C2)
has_shape(quantized_tensor_literal, shape(quantized_tensor_type)).
StringConstant ::= StringLiteral
StringLiteral ::= '"' {stringCharacter | escapeSequence} '"'
stringCharacter ::= all ASCII characters except '\00', '\01', ... '\1f' and '"'
escapeSequence ::= '\' ('"' | '\' | 'n' | 't' | (hexadecimalDigit hexadecimalDigit))
Literal string terdiri dari byte yang ditentukan menggunakan karakter ASCII dan
urutan escape. Byte ini tidak bergantung pada encoding, sehingga interpretasi byte ini ditentukan oleh implementasi. Literal string memiliki jenis string.
Operasi
abs
Semantik
Melakukan operasi abs per elemen pada tensor operand dan menghasilkan tensor result. Bergantung pada jenis elemen, lakukan tindakan berikut:
- Untuk bilangan bulat bertanda: modulus bilangan bulat.
- Untuk float:
absdari IEEE-754. - Untuk bilangan kompleks: modulus kompleks.
- Untuk jenis yang dikuantisasi:
dequantize_op_quantize(abs, operand, type(result)).
Input
| Label | Nama | Jenis | Batasan |
|---|---|---|---|
| (I1) | operand |
tensor bilangan bulat bertanda tangan, floating point, atau jenis kompleks atau tensor terkuantisasi per-tensor | (C1-C2) |
Output
| Nama | Jenis | Batasan |
|---|---|---|
result |
tensor jenis bilangan bulat bertanda atau floating point atau tensor terkuantisasi per-tensor | (C1-C2) |
Batasan
- (C1)
shape(result) = shape(operand). - (C2)
baseline_element_type(result)ditentukan sebagai:complex_element_type(element_type(operand))ifis_complex(operand).baseline_element_type(operand)jika tidak.
Contoh
// %operand: [-2, 0, 2]
%result = "stablehlo.abs"(%operand) : (tensor<3xi32>) -> tensor<3xi32>
// %result: [2, 0, 2]
tambahkan
Semantik
Melakukan penambahan per elemen dari dua tensor lhs dan rhs serta menghasilkan tensor result. Bergantung pada jenis elemen, lakukan tindakan berikut:
- Untuk boolean: OR logis.
- Untuk bilangan bulat: penambahan bilangan bulat.
- Untuk float:
additiondari IEEE-754. - Untuk bilangan kompleks: penjumlahan kompleks.
- Untuk jenis yang dikuantisasi:
dequantize_op_quantize(add, lhs, rhs, type(result)).
Input
| Label | Nama | Jenis | Batasan |
|---|---|---|---|
| (I1) | lhs |
tensor atau tensor terkuantisasi | (C1-C6) |
| (I2) | rhs |
tensor atau tensor terkuantisasi | (C1-C5), (C7) |
Output
| Nama | Jenis | Batasan |
|---|---|---|
result |
tensor atau tensor terkuantisasi | (C1-C7) |
Batasan
- Jika operasi menggunakan tensor yang tidak dikuantisasi:
- (C1)
type(lhs) = type(rhs) = type(result).
- (C1)
- Jika operasi menggunakan tensor terkuantisasi:
- (C2)
is_quantized(lhs) and is_quantized(rhs) and is_quantized(result). - (C3)
storage_type(lhs) = storage_type(rhs) = storage_type(result). - (C4)
expressed_type(lhs) = expressed_type(rhs) = expressed_type(result). - (C5)
(is_per_axis_quantized(lhs) or is_per_axis_quantized(rhs)) = is_per_axis_quantized(result). - (C6) Jika
is_per_axis_quantized(lhs), makaquantization_dimension(lhs) = quantization_dimension(result). - (C7) Jika
is_per_axis_quantized(rhs), makaquantization_dimension(rhs) = quantization_dimension(result).
- (C2)
Contoh
// %lhs: [[1, 2], [3, 4]]
// %rhs: [[5, 6], [7, 8]]
%result = "stablehlo.add"(%lhs, %rhs) : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
// %result: [[6, 8], [10, 12]]
after_all
Semantik
Memastikan bahwa operasi yang menghasilkan inputs dijalankan sebelum operasi apa pun yang bergantung pada result. Eksekusi operasi ini tidak melakukan apa pun,
operasi ini hanya ada untuk membuat dependensi data dari result ke inputs.
Input
| Label | Nama | Jenis |
|---|---|---|
| (I1) | inputs |
jumlah variadik token |
Output
| Nama | Jenis |
|---|---|
result |
token |
Contoh
// %input0: !stablehlo.token
// %input1: !stablehlo.token
%result = "stablehlo.after_all"(%input0, %input1) : (!stablehlo.token, !stablehlo.token) -> !stablehlo.token
all_gather
Semantik
Dalam setiap grup proses di petak proses StableHLO, menggabungkan nilai
tensor operands dari setiap proses di sepanjang all_gather_dim dan menghasilkan
tensor results.
Operasi ini membagi petak proses StableHLO menjadi process_groups yang ditentukan sebagai berikut:
cross_replica(replica_groups)ifchannel_id <= 0 and use_global_device_ids = false.cross_replica_and_partition(replica_groups)ifchannel_id > 0 and use_global_device_ids = false.flattened_ids(replica_groups)ifchannel_id > 0 and use_global_device_ids = true.
Setelah itu, di dalam setiap process_group:
operands...@receiver = [operand@sender for sender in process_group]untuk semuareceiverdiprocess_group.results...@process = concatenate(operands...@process, all_gather_dim)untuk semuaprocessdiprocess_group.
Input
| Label | Nama | Jenis | Batasan |
|---|---|---|---|
| (I1) | operands |
jumlah tensor variadik atau tensor terkuantisasi per tensor | (C1), (C6) |
| (I2) | all_gather_dim |
konstanta jenis si64 |
(C1), (C6) |
| (I3) | replica_groups |
Konstanta tensor 2 dimensi dari jenis si64 |
(C2-C4) |
| (I4) | channel_id |
konstanta jenis si64 |
(C5) |
| (I5) | use_global_device_ids |
konstanta jenis i1 |
(C5) |
Output
| Nama | Jenis | Batasan |
|---|---|---|
results |
jumlah tensor variadik atau tensor terkuantisasi per tensor | (C6) |
Batasan
- (C1)
0 <= all_gather_dim < rank(operands...). - (C2)
is_unique(replica_groups). - (C3)
size(replica_groups)ditentukan sebagai:num_replicasjikacross_replicadigunakan.num_replicasjikacross_replica_and_partitiondigunakan.num_processesjikaflattened_idsdigunakan.
- (C4)
0 <= replica_groups < size(replica_groups). - (C5) Jika
use_global_device_ids = true, makachannel_id > 0. - (C6)
type(results...) = type(operands...)kecuali:dim(results..., all_gather_dim) = dim(operands..., all_gather_dim) * dim(process_groups, 1).
Contoh
// num_replicas: 2
// num_partitions: 1
// %operand0@(0, 0): [[1, 2], [3, 4]]
// %operand0@(1, 0): [[5, 6], [7, 8]]
// %operand1@(0, 0): [[11, 12], [13, 14]]
// %operand1@(1, 0): [[15, 16], [17, 18]]
%result:2 = "stablehlo.all_gather"(%operand0, %operand1) {
all_gather_dim = 1 : i64,
replica_groups = dense<[[0, 1]]> : tensor<1x2xi64>,
// channel_id = 0
channel_handle = #stablehlo.channel_handle<handle = 0, type = 0>
// use_global_device_ids = false
} : (tensor<2x2xi64>, tensor<2x2xi64>) -> (tensor<2x4xi64>, tensor<2x4xi64>)
// %result0@(0, 0): [[1, 2, 5, 6], [3, 4, 7, 8]]
// %result0@(1, 0): [[1, 2, 5, 6], [3, 4, 7, 8]]
// %result1@(0, 0): [[11, 12, 15, 16], [13, 14, 17, 18]]
// %result1@(1, 0): [[11, 12, 15, 16], [13, 14, 17, 18]]
all_reduce
Semantik
Dalam setiap grup proses di petak proses StableHLO, menerapkan fungsi
reduksi computation ke nilai tensor operands dari setiap proses
dan menghasilkan tensor results.
Operasi ini membagi petak proses StableHLO menjadi process_groups yang ditentukan sebagai berikut:
cross_replica(replica_groups)ifchannel_id <= 0 and use_global_device_ids = false.cross_replica_and_partition(replica_groups)ifchannel_id > 0 and use_global_device_ids = false.flattened_ids(replica_groups)ifchannel_id > 0 and use_global_device_ids = true.
Setelah itu, di dalam setiap process_group:
results...@process[result_index] = exec(schedule)untuk beberapa pohon binerscheduledengan:exec(node)=computation(exec(node.left), exec(node.right)).exec(leaf)=leaf.value.
scheduleadalah pohon biner yang ditentukan implementasinya yang traversal in-order-nya adalahto_destination_type(operands...@process_group...[result_index], type(func_inputs(computation)[0])).
Input
| Label | Nama | Jenis | Batasan |
|---|---|---|---|
| (I1) | operands |
jumlah tensor variadik atau tensor terkuantisasi per tensor | (C5), (C6) |
| (I2) | replica_groups |
jumlah variadik konstanta tensor 1 dimensi berjenis si64 |
(C1-C3) |
| (I3) | channel_id |
konstanta jenis si64 |
(C4) |
| (I4) | use_global_device_ids |
konstanta jenis i1 |
(C4) |
| (I5) | computation |
fungsi | (C5) |
Output
| Nama | Jenis | Batasan |
|---|---|---|
results |
jumlah tensor variadik atau tensor terkuantisasi per tensor | (C6-C7) |
Batasan
- (C1)
is_unique(replica_groups). - (C2)
size(replica_groups)ditentukan sebagai:num_replicasjikacross_replicadigunakan.num_replicasjikacross_replica_and_partitiondigunakan.num_processesjikaflattened_idsdigunakan.
- (C3)
0 <= replica_groups < size(replica_groups). - (C4) Jika
use_global_device_ids = true, makachannel_id > 0. - (C5)
computationmemiliki jenis(tensor<E>, tensor<E>) -> (tensor<E>)denganis_promotable(element_type(operand), E). - (C6)
shape(results...) = shape(operands...). - (C7)
element_type(results...) = E.
Contoh
// num_replicas: 2
// num_partitions: 1
// %operand0@(0, 0): [1, 2, 3, 4]
// %operand0@(1, 0): [5, 6, 7, 8]
// %operand1@(0, 0): [9, 10, 11, 12]
// %operand1@(1, 0): [13, 14, 15, 16]
%result:2 = "stablehlo.all_reduce"(%operand0, %operand0) ({
^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
%0 = "stablehlo.add"(%arg0, %arg1) : (tensor<i64>, tensor<i64>) -> tensor<i64>
"stablehlo.return"(%0) : (tensor<i64>) -> ()
}) {
replica_groups = dense<[[0, 1]]> : tensor<1x2xi64>,
// channel_id = 0
channel_handle = #stablehlo.channel_handle<handle = 0, type = 0>
// use_global_device_ids = false
} : (tensor<4xi64>, tensor<4xi64>) -> (tensor<4xi64>, tensor<4xi64>)
// %result0@(0, 0): [6, 8, 10, 12]
// %result0@(1, 0): [6, 8, 10, 12]
// %result1@(0, 0): [22, 24, 26, 28]
// %result1@(1, 0): [22, 24, 26, 28]
all_to_all
Semantik
Dalam setiap grup proses di petak proses StableHLO, membagi nilai
tensor operands di sepanjang split_dimension menjadi beberapa bagian, menyebarkan bagian
yang dibagi di antara proses, menggabungkan bagian yang disebarkan di sepanjang
concat_dimension, dan menghasilkan tensor results.
Operasi ini membagi petak proses StableHLO menjadi process_groups yang ditentukan sebagai berikut:
cross_replica(replica_groups)ifchannel_id <= 0.cross_partition(replica_groups)ifchannel_id > 0.
Setelah itu, di dalam setiap process_group:
split_parts...@sender = split(operands...@sender, split_count, split_dimension)untuk semuasenderdiprocess_group.scattered_parts...@receiver = [split_parts...@sender[receiver_index] for sender in process_group]di manareceiver_index = process_group.index(receiver).results...@process = concatenate(scattered_parts...@process, concat_dimension).
Input
| Label | Nama | Jenis | Batasan |
|---|---|---|---|
| (I1) | operands |
jumlah tensor variadik atau tensor terkuantisasi per tensor | (C1-C3), (C9) |
| (I2) | split_dimension |
konstanta jenis si64 |
(C1), (C2), (C9) |
| (I3) | concat_dimension |
konstanta jenis si64 |
(C3), (C9) |
| (I4) | split_count |
konstanta jenis si64 |
(C2), (C4), (C8), (C9) |
| (I5) | replica_groups |
Konstanta tensor 2 dimensi dari jenis si64 |
(C5-C8) |
| (I6) | channel_id |
konstanta jenis si64 |
Output
| Nama | Jenis | Batasan |
|---|---|---|
results |
jumlah tensor variadik atau tensor terkuantisasi per tensor | (C9) |
Batasan
- (C1)
0 <= split_dimension < rank(operands...). - (C2)
dim(operands..., split_dimension) % split_count = 0. - (C3)
0 <= concat_dimension < rank(operands...). - (C4)
0 < split_count. - (C5)
is_unique(replica_groups). - (C6)
size(replica_groups)ditentukan sebagai:num_replicasjikacross_replicadigunakan.num_partitionsjikacross_partitiondigunakan.
- (C7)
0 <= replica_groups < size(replica_groups). - (C8)
dim(replica_groups, 1) = split_count. - (C9)
type(results...) = type(operands...)kecuali, jikasplit_dimension != concat_dimension:dim(results..., split_dimension) = dim(operands..., split_dimension) / split_count.dim(results..., concat_dimension) = dim(operands..., concat_dimension) * split_count.
Contoh
// num_replicas: 2
// num_partitions: 1
// %operand1@(0, 0): [[1, 2, 3, 4],
// [5, 6, 7, 8]]
// %operand1@(1, 0): [[9, 10, 11, 12],
// [13, 14, 15, 16]]
// %operand2@(0, 0): [[17, 18, 19, 20],
// [21, 22, 23, 24]]
// %operand2@(1, 0): [[25, 26, 27, 28],
// [29, 30, 31, 32]]
%result:2 = "stablehlo.all_to_all"(%operand1, %operand2) {
split_dimension = 1 : i64,
concat_dimension = 0 : i64,
split_count = 2 : i64,
replica_groups = dense<[[0, 1]]> : tensor<1x2xi64>
// channel_id = 0
} : (tensor<2x4xi64>, tensor<2x4xi64>) -> (tensor<4x2xi64>, tensor<4x2xi64>)
// %result#0@(0, 0): [[1, 2], [5, 6], [9, 10], [13, 14]]
// %result#0@(1, 0): [[3, 4], [7, 8], [11, 12], [15, 16]]
// %result#1@(0, 0): [[17, 18], [21, 22], [25, 26], [29, 30]]
// %result#1@(1, 0): [[19, 20], [23, 24], [27, 28], [31, 32]]
dan
Semantik
Melakukan AND per elemen dari dua tensor lhs dan rhs serta menghasilkan tensor result. Bergantung pada jenis elemen, lakukan tindakan berikut:
- Untuk boolean: AND logis.
- Untuk bilangan bulat: bitwise AND.
Input
| Label | Nama | Jenis | Batasan |
|---|---|---|---|
| (I1) | lhs |
tensor jenis boolean atau bilangan bulat | (C1) |
| (I2) | rhs |
tensor jenis boolean atau bilangan bulat | (C1) |
Output
| Nama | Jenis | Batasan |
|---|---|---|
result |
tensor jenis boolean atau bilangan bulat | (C1) |
Batasan
- (C1)
type(lhs) = type(rhs) = type(result).
Contoh
// %lhs: [[1, 2], [3, 4]]
// %rhs: [[5, 6], [7, 8]]
%result = "stablehlo.and"(%lhs, %rhs) : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
// %result: [[1, 2], [3, 0]]
atan2
Semantik
Melakukan operasi atan2 per elemen pada tensor lhs dan rhs serta menghasilkan tensor result. Bergantung pada jenis elemen, lakukan tindakan berikut:
- Untuk float:
atan2dari IEEE-754. - Untuk bilangan kompleks: atan2 kompleks.
- Untuk jenis yang dikuantisasi:
dequantize_op_quantize(atan2, lhs, rhs, type(result)).
Input
| Label | Nama | Jenis | Batasan |
|---|---|---|---|
| (I1) | lhs |
tensor jenis floating-point atau kompleks atau tensor terkuantisasi per-tensor | (C1) |
| (I2) | rhs |
tensor jenis floating-point atau kompleks atau tensor terkuantisasi per-tensor | (C1) |
Output
| Nama | Jenis | Batasan |
|---|---|---|
result |
tensor jenis floating-point atau kompleks atau tensor terkuantisasi per-tensor | (C1) |
Batasan
- (C1)
baseline_type(lhs) = baseline_type(rhs) = baseline_type(result).
Contoh
// %lhs: [0.0, 1.0, -1.0]
// %rhs: [0.0, 0.0, 0.0]
%result = "stablehlo.atan2"(%lhs, %rhs) : (tensor<3xf64>, tensor<3xf64>) -> tensor<3xf64>
// %result: [0.0, 1.57079637, -1.57079637] // [0.0, pi/2, -pi/2]
batch_norm_grad
Semantik
Menghitung gradien beberapa input batch_norm_training dengan backpropagation
dari grad_output, dan menghasilkan tensor grad_operand, grad_scale, dan grad_offset. Secara lebih formal, operasi ini dapat dinyatakan sebagai dekomposisi ke operasi StableHLO yang ada menggunakan sintaksis Python sebagai berikut:
def compute_sum(operand, feature_index):
(sum,) = reduce(
inputs=[operand],
init_values=[constant(0, element_type(operand))],
dimensions=[i for i in range(rank(operand)) if i != feature_index],
body=lambda x, y: add(x, y))
return sum
def compute_mean(operand, feature_index):
sum = compute_sum(operand, feature_index)
divisor = constant(size(operand) / dim(operand, feature_index),
element_type(operand))
divisor_bcast = broadcast_in_dim(divisor, [], type(sum))
return divide(sum, divisor_bcast)
def batch_norm_grad(operand, scale, mean, variance, grad_output, epsilon, feature_index):
# Broadcast inputs to type(operand)
scale_bcast = broadcast_in_dim(scale, [feature_index], type(operand))
mean_bcast = broadcast_in_dim(mean, [feature_index], type(operand))
variance_bcast = broadcast_in_dim(variance, [feature_index], type(operand))
epsilon_bcast = broadcast_in_dim(constant(epsilon, element_type(operand)), [],
type(operand))
# Perform normalization using the provided `mean` and `variance`
# Intermediate values will be useful for computing gradients
centered_operand = subtract(operand, mean_bcast)
stddev = sqrt(add(variance_bcast, epsilon_bcast))
normalized_operand = divide(centered_operand, stddev)
# Use the implementation from batchnorm_expander.cc in XLA
# Temporary variables have exactly the same names as in the C++ code
elements_per_feature = broadcast_in_dim(
constant(divide(size(operand), dim(operand, feature_index)),
element_type(grad_output)),
[], type(operand))
i1 = multiply(grad_output, elements_per_feature)
i2 = broadcast_in_dim(
compute_sum(grad_output, feature_index), [feature_index], type(operand))
i3 = broadcast_in_dim(
compute_sum(multiply(grad_output, centered_operand), feature_index),
[feature_index], type(operand))
i4 = multiply(i3, centered_operand)
i5 = divide(i4, add(variance_bcast, epsilon_bcast))
i6 = subtract(subtract(i1, i2), i5)
grad_operand =
multiply(divide(divide(scale_bcast, stddev), elements_per_feature), i6)
grad_scale =
compute_sum(multiply(grad_output, normalized_operand), feature_index)
grad_offset = compute_sum(grad_output, feature_index)
return grad_operand, grad_scale, grad_offset
Untuk jenis yang dikuantisasi, melakukan
dequantize_batch_norm_grad_or_training_quantize(lambda operand, scale, mean,
variance, grad_output: batch_norm_grad(operand, scale, mean, variance,
grad_output, epsilon, feature_index), operand, scale, mean, variance,
grad_output, type(grad_operand), type(grad_scale), type(feature_index)).
Input
| Label | Nama | Jenis | Batasan |
|---|---|---|---|
| (I1) | operand |
tensor jenis floating point atau tensor terkuantisasi per-tensor | (C1-C3), (C5) |
| (I2) | scale |
Tensor 1 dimensi dari jenis floating-point atau per-tensor terkuantisasi | (C2), (C4), (C5) |
| (I3) | mean |
Tensor 1 dimensi dari jenis floating-point atau per-tensor terkuantisasi | (C2), (C4) |
| (I4) | variance |
Tensor 1 dimensi dari jenis floating-point atau per-tensor terkuantisasi | (C2), (C4) |
| (I5) | grad_output |
tensor jenis floating point atau tensor terkuantisasi per-tensor | (C2), (C3) |
| (I6) | epsilon |
konstanta jenis f32 |
|
| (I7) | feature_index |
konstanta jenis si64 |
(C1), (C5) |
Output
| Nama | Jenis | Batasan |
|---|---|---|
grad_operand |
tensor jenis floating point atau tensor terkuantisasi per-tensor | (C2), (C3) |
grad_scale |
Tensor 1 dimensi dari jenis floating-point atau per-tensor terkuantisasi | (C2), (C4) |
grad_offset |
Tensor 1 dimensi dari jenis floating-point atau per-tensor terkuantisasi | (C2), (C4) |
Batasan
- (C1)
0 <= feature_index < rank(operand). - (C2)
operand,scale,mean,variance,grad_output,grad_operand,grad_scale, dangrad_offsetmemilikibaseline_element_typeyang sama. - (C3)
operand,grad_output, dangrad_operandmemiliki bentuk yang sama. - (C4)
scale,mean,variance,grad_scale, dangrad_offsetmemiliki bentuk yang sama. - (C5)
size(scale) = dim(operand, feature_index).
Contoh
// %operand: [
// [[1.0, 2.0], [3.0, 4.0]],
// [[3.0, 4.0], [1.0, 2.0]]
// ]
// %scale: [1.0, 1.0]
// %mean: [2.0, 3.0]
// %variance: [1.0, 1.0]
// %grad_output: [
// [[0.1, 0.1], [0.1, 0.1]],
// [[0.1, 0.1], [0.1, 0.1]]
// ]
%grad_operand, %grad_scale, %grad_offset =
"stablehlo.batch_norm_grad"(%operand, %scale, %mean, %variance, %grad_output) {
epsilon = 0.0 : f32,
feature_index = 2 : i64
} : (tensor<2x2x2xf64>, tensor<2xf64>, tensor<2xf64>, tensor<2xf64>,
tensor<2x2x2xf64>) -> (tensor<2x2x2xf64>, tensor<2xf64>, tensor<2xf64>)
// %grad_operand: [
// [[0.0, 0.0], [0.0, 0.0]],
// [[0.0, 0.0], [0.0, 0.0]]
// ]
// %grad_scale: [0.0, 0.0]
// %grad_offset: [0.4, 0.4]
batch_norm_inference
Semantik
Menormalisasi tensor operand di semua dimensi kecuali dimensi
feature_index dan menghasilkan tensor result. Secara lebih formal, operasi ini dapat dinyatakan sebagai dekomposisi ke operasi StableHLO yang ada menggunakan sintaksis Python sebagai berikut:
def batch_norm_inference(operand, scale, offset, mean, variance, epsilon, feature_index):
# Broadcast inputs to shape(operand)
scale_bcast = broadcast_in_dim(scale, [feature_index], type(operand))
offset_bcast = broadcast_in_dim(offset, [feature_index], type(operand))
mean_bcast = broadcast_in_dim(mean, [feature_index], type(operand))
variance_bcast = broadcast_in_dim(variance, [feature_index], type(operand))
epsilon_bcast = broadcast_in_dim(constant(epsilon, element_type(operand)), [],
type(operand))
# Perform normalization using the provided `mean` and `variance` instead of
# computing them like `batch_norm_training` does.
centered_operand = subtract(operand, mean_bcast)
stddev = sqrt(add(variance_bcast, epsilon_bcast))
normalized_operand = divide(centered_operand, stddev)
return add(multiply(scale_bcast, normalized_operand), offset_bcast)
Untuk jenis yang dikuantisasi, melakukan
dequantize_op_quantize(lambda operand, scale, offset, mean, variance:
batch_norm_inference(operand, scale, offset, mean, variance, epsilon,
feature_index), operand, scale, offset, mean, variance, type(result)).
Input
| Label | Nama | Jenis | Batasan |
|---|---|---|---|
| (I1) | operand |
tensor jenis floating point atau tensor terkuantisasi per-tensor | (C1-C7) |
| (I2) | scale |
Tensor 1 dimensi dari jenis floating-point atau per-tensor terkuantisasi | (C2), (C3) |
| (I3) | offset |
Tensor 1 dimensi dari jenis floating-point atau per-tensor terkuantisasi | (C2), (C4) |
| (I4) | mean |
Tensor 1 dimensi dari jenis floating-point atau per-tensor terkuantisasi | (C5) |
| (I5) | variance |
Tensor 1 dimensi dari jenis floating-point atau per-tensor terkuantisasi | (C2), (C6) |
| (I6) | epsilon |
konstanta jenis f32 |
|
| (I7) | feature_index |
konstanta jenis si64 |
(C1), (C3-C6) |
Output
| Nama | Jenis | Batasan |
|---|---|---|
result |
tensor jenis floating point atau tensor terkuantisasi per-tensor | (C2), (C7) |
Batasan
- (C1)
0 <= feature_index < rank(operand). - (C2)
operand,scale,offset,mean,variance, danresultmemilikibaseline_element_typeyang sama. - (C3)
size(scale) = dim(operand, feature_index). - (C4)
size(offset) = dim(operand, feature_index). - (C5)
size(mean) = dim(operand, feature_index). - (C6)
size(variance) = dim(operand, feature_index). - (C7)
baseline_type(operand) = baseline_type(result).
Contoh
// %operand: [
// [[1.0, 2.0], [3.0, 4.0]],
// [[3.0, 4.0], [1.0, 2.0]]
// ]
// %scale: [1.0, 1.0]
// %offset: [1.0, 1.0]
// %mean: [2.0, 3.0]
// %variance: [1.0, 1.0]
%result = "stablehlo.batch_norm_inference"(%operand, %scale, %offset, %mean, %variance) {
epsilon = 0.0 : f32,
feature_index = 2 : i64
} : (tensor<2x2x2xf64>, tensor<2xf64>, tensor<2xf64>, tensor<2xf64>, tensor<2xf64>) -> tensor<2x2x2xf64>
// %result: [
// [[0.0, 0.0], [2.0, 2.0]],
// [[2.0, 2.0], [0.0, 0.0]]
// ]
batch_norm_training
Semantik
Menghitung rata-rata dan varians di semua dimensi kecuali dimensi feature_index
dan menormalisasi tensor operand yang menghasilkan tensor output, batch_mean
dan batch_var. Secara lebih formal, operasi ini dapat dinyatakan sebagai dekomposisi ke operasi StableHLO yang ada menggunakan sintaksis Python sebagai berikut:
def compute_mean(operand, feature_index):
(sum,) = reduce(
inputs=[operand],
init_values=[constant(0, element_type(operand))],
dimensions=[i for i in range(rank(operand)) if i != feature_index],
body=lambda x, y: add(x, y))
divisor = constant(size(operand) / dim(operand, feature_index),
element_type(operand))
divisor_bcast = broadcast_in_dim(divisor, [], type(sum))
return divide(sum, divisor_bcast)
def compute_variance(operand, feature_index):
mean = compute_mean(operand, feature_index)
mean_bcast = broadcast_in_dim(mean, [feature_index], type(operand))
centered_operand = subtract(operand, mean_bcast)
return compute_mean(mul(centered_operand, centered_operand), feature_index)
def batch_norm_training(operand, scale, offset, epsilon, feature_index):
mean = compute_mean(operand, feature_index)
variance = compute_variance(operand, feature_index)
return batch_norm_inference(operand, scale, offset, mean, variance, epsilon,
feature_index),
mean, variance
Untuk jenis yang dikuantisasi, melakukan
dequantize_batch_norm_grad_or_training_quantize(lambda operand, scale, offset:
batch_norm_training(operand, scale, offset, epsilon, feature_index), operand,
scale, offset, type(output), type(batch_mean), type(batch_var)).
Input
| Label | Nama | Jenis | Batasan |
|---|---|---|---|
| (I1) | operand |
tensor jenis floating point atau tensor terkuantisasi per-tensor | (C1) |
| (I2) | scale |
Tensor 1 dimensi dari floating point atau per-tensor terkuantisasi | (C2), (C3) |
| (I3) | offset |
Tensor 1 dimensi dari floating point atau per-tensor terkuantisasi | (C2), (C4) |
| (I4) | epsilon |
konstanta jenis f32 |
(C1), (C3-C6) |
| (I5) | feature_index |
konstanta jenis si64 |
(C1), (C3-C6) |
Output
| Nama | Jenis | Batasan |
|---|---|---|
output |
tensor jenis floating point atau tensor terkuantisasi per-tensor | (C7) |
batch_mean |
Tensor 1 dimensi dari floating point atau per-tensor terkuantisasi | (C2), (C5) |
batch_var |
Tensor 1 dimensi dari floating point atau per-tensor terkuantisasi | (C2), (C6) |
Batasan
- (C1)
0 <= feature_index < rank(operand). - (C2)
operand,scale,offset,batch_mean,batch_var, danoutputmemilikibaseline_element_typeyang sama. - (C3)
size(scale) = dim(operand, feature_index). - (C4)
size(offset) = dim(operand, feature_index). - (C5)
size(batch_mean) = dim(operand, feature_index). - (C6)
size(batch_var) = dim(operand, feature_index). - (C7)
baseline_type(output) = baseline_type(operand).
Contoh
// %operand: [
// [[1.0, 2.0], [3.0, 4.0]],
// [[3.0, 4.0], [1.0, 2.0]]
// ]
// %scale: [1.0, 1.0]
// %offset: [1.0, 1.0]
%output, %batch_mean, %batch_var = "stablehlo.batch_norm_training"(%operand, %scale, %offset) {
epsilon = 0.0 : f32,
feature_index = 2 : i64
} : (tensor<2x2x2xf64>, tensor<2xf64>, tensor<2xf64>) ->
(tensor<2x2x2xf64>, tensor<2xf64>, tensor<2xf64>)
// %output: [
// [[0.0, 0.0], [2.0, 2.0]],
// [[2.0, 2.0], [0.0, 0.0]]
// ]
// %batch_mean: [2.0, 3.0]
// %batch_var: [1.0, 1.0]
bitcast_convert
Semantik
Melakukan operasi bitcast pada tensor operand dan menghasilkan tensor result
di mana bit seluruh tensor operand ditafsirkan ulang menggunakan
jenis tensor result.
Secara lebih formal, dengan kondisi E = element_type(operand), E' = element_type(result),
dan R = rank(operand):
- Jika
num_bits(E') < num_bits(E),bits(result[i0, ..., iR-1, :]) = bits(operand[i0, ..., iR-1]). - Jika
num_bits(E') > num_bits(E),bits(result[i0, ..., iR-2]) = bits(operand[i0, ..., iR-2, :]). - Jika
num_bits(E') = num_bits(E),bits(result[i0, ..., iR-1]) = bits(operand[i0, ..., iR-1]).
bits menampilkan representasi dalam memori dari nilai tertentu, dan perilakunya
ditetapkan oleh implementasi karena representasi tensor yang tepat
ditetapkan oleh implementasi, dan representasi jenis elemen yang tepat juga
ditetapkan oleh implementasi.
Input
| Label | Nama | Jenis | Batasan |
|---|---|---|---|
| (I1) | operand |
tensor atau tensor terkuantisasi | (C1-C2) |
Output
| Nama | Jenis | Batasan |
|---|---|---|
result |
tensor atau tensor terkuantisasi | (C1-C2) |
Batasan
- (C1) Dengan
E = is_quantized(operand) ? storage_type(operand) : element_type(operand),E' = is_quantized(result) ? storage_type(result) : element_type(result), danR = rank(operand):- Jika
num_bits(E') = num_bits(E),shape(result) = shape(operand). - Jika
num_bits(E') < num_bits(E): rank(result) = R + 1.dim(result, i) = dim(operand, i)untuk semua0 <= i < R.dim(result, R) * num_bits(E') = num_bits(E).- Jika
num_bits(E') > num_bits(E): rank(result) = R - 1.dim(result, i) = dim(operand, i)untuk semua0 <= i < R.dim(operand, R - 1) * num_bits(E) = num_bits(E').
- Jika
- (C2) Jika
is_complex(operand) or is_complex(result), makais_complex(operand) and is_complex(result).
Contoh
// %operand: 0x0123456789ABCDEF
%result = "stablehlo.bitcast_convert"(%operand) : (tensor<f64>) -> tensor<4xf16>
// %result: [0xCDEF, 0x89AB, 0x4567, 0x0123] // little-endian representation
broadcast_in_dim
Semantik
Memperluas dimensi dan/atau peringkat tensor input dengan menduplikasi data
dalam tensor operand dan menghasilkan tensor result. Secara lebih formal,
result[result_index] = operand[operand_index] dengan semua d dalam
axes(operand):
operand_index[d] = 0ifdim(operand, d) = 1.operand_index[d] = result_index[broadcast_dimensions[d]]jika tidak.
Input
| Label | Nama | Jenis | Batasan |
|---|---|---|---|
| (I1) | operand |
tensor atau tensor terkuantisasi | (C1-C2), (C5-C6) |
| (I2) | broadcast_dimensions |
Konstanta tensor 1 dimensi berjenis si64 |
(C2-C6) |
Output
| Nama | Jenis | Batasan |
|---|---|---|
result |
tensor atau tensor terkuantisasi | (C1), (C3), (C5-C6) |
Batasan
- (C1)
element_type(result)diberikan oleh:element_type(operand), jika!is_per_axis_quantized(operand).element_type(operand), kecualiquantization_dimension(operand),scales(operand), danzero_points(operand)dapat berbeda dariquantization_dimension(result),scales(result), danzero_points(result)masing-masing, jika tidak.
- (C2)
size(broadcast_dimensions) = rank(operand). - (C3)
0 <= broadcast_dimensions < rank(result). - (C4)
is_unique(broadcast_dimensions). - (C5) Untuk semua
ddiaxes(operand):dim(operand, d) = 1ataudim(operand, d) = dim(result, broadcast_dimensions[d]).
- (C6) Jika
is_per_axis_quantized(result):quantization_dimension(result) = broadcast_dimensions[quantization_dimension(operand)].- Jika
dim(operand, quantization_dimension(operand)) = 1, makascales(result)[i] = scales(operand)[0] and zero_points(result)[i] = zero_points(operand)[0] for i in range(dim(result, quantization_dimension(result))).
Contoh
// %operand: [
// [1, 2, 3]
// ]
%result = "stablehlo.broadcast_in_dim"(%operand) {
broadcast_dimensions = array<i64: 2, 1>
} : (tensor<1x3xi32>) -> tensor<2x3x2xi32>
// %result: [
// [
// [1, 1],
// [2, 2],
// [3, 3]
// ],
// [
// [1, 1],
// [2, 2],
// [3, 3]
// ]
// ]
casing
Semantik
Menghasilkan output dari eksekusi tepat satu fungsi dari branches
bergantung pada nilai index. Secara lebih formal, result = selected_branch()
dengan:
selected_branch = branches[index]if0 <= index < size(branches).selected_branch = branches[-1]jika tidak.
Input
| Label | Nama | Jenis | Batasan |
|---|---|---|---|
| (I1) | index |
Tensor 0 dimensi berjenis si32 |
|
| (I2) | branches |
jumlah fungsi variadik | (C1-C4) |
Output
| Nama | Jenis | Batasan |
|---|---|---|
results |
sejumlah tensor, tensor terkuantisasi, atau token variadik | (C4) |
Batasan
- (C1)
0 < size(branches). - (C2)
input_types(branches...) = []. - (C3)
same(output_types(branches...)). - (C4)
type(results...) = output_types(branches[0]).
Contoh
// %index: -1
// %result_branch0: [0, 0]
// %result_branch1: [1, 1]
%result0, %result1 = "stablehlo.case"(%index) ({
"stablehlo.return"(%result_branch0, %result_branch0) : (tensor<2xi64>, tensor<2xi64>) -> ()
}, {
"stablehlo.return"(%result_branch1, %result_branch1) : (tensor<2xi64>, tensor<2xi64>) -> ()
}) : (tensor<i32>) -> (tensor<2xi64>, tensor<2xi64>)
// %result0: [1, 1]
// %result1: [1, 1]
cbrt
Semantik
Melakukan operasi akar kubik per elemen pada tensor operand dan menghasilkan tensor result. Bergantung pada jenis elemen, lakukan tindakan berikut:
- Untuk float:
rootn(x, 3)dari IEEE-754. - Untuk bilangan kompleks: akar kubik kompleks.
- Untuk jenis yang dikuantisasi:
dequantize_op_quantize(cbrt, operand, type(result))
Input
| Label | Nama | Jenis | Batasan |
|---|---|---|---|
| (I1) | operand |
tensor jenis floating-point atau kompleks atau tensor terkuantisasi per-tensor | (C1) |
Output
| Nama | Jenis | Batasan |
|---|---|---|
result |
tensor jenis floating-point atau kompleks atau tensor terkuantisasi per-tensor | (C1) |
Batasan
- (C1)
baseline_type(operand) = baseline_type(result).
Contoh
// %operand: [0.0, 1.0, 8.0, 27.0]
%result = "stablehlo.cbrt"(%operand) : (tensor<4xf64>) -> tensor<4xf64>
// %result: [0.0, 1.0, 2.0, 3.0]
ceil
Semantik
Melakukan ceil per elemen pada tensor operand dan menghasilkan tensor result.
Menerapkan operasi roundToIntegralTowardPositive dari spesifikasi IEEE-754. Untuk jenis yang dikuantisasi, melakukan
dequantize_op_quantize(ceil, operand, type(result)).
Input
| Label | Nama | Jenis | Batasan |
|---|---|---|---|
| (I1) | operand |
tensor jenis floating point atau tensor terkuantisasi per-tensor | (C1) |
Output
| Nama | Jenis | Batasan |
|---|---|---|
result |
tensor jenis floating point atau tensor terkuantisasi per-tensor | (C1) |
Batasan
- (C1)
baseline_type(operand) = baseline_type(result).
Contoh
// %operand: [-0.8166, -0.2530, 0.2530, 0.8166, 2.0]
%result = "stablehlo.ceil"(%operand) : (tensor<5xf32>) -> tensor<5xf32>
// %result: [-0.0, -0.0, 1.0, 1.0, 2.0]
cholesky
Semantik
Menghitung dekomposisi Cholesky dari sekumpulan matriks.
Secara lebih formal, untuk semua i dalam index_space(result),
result[i0, ..., iR-3, :, :] adalah dekomposisi Cholesky dari
a[i0, ..., iR-3, :, :], dalam bentuk matriks segitiga bawah
(jika lower adalah true) atau segitiga atas (jika lower adalah false).
Nilai output dalam segitiga yang berlawanan, yaitu segitiga atas ketat atau
segitiga bawah ketat, ditentukan oleh implementasi.
Jika ada i di mana matriks input bukan matriks definit positif Hermitian, maka perilakunya tidak ditentukan.
Untuk jenis yang dikuantisasi, melakukan
dequantize_op_quantize(lambda operand: cholesky(operand, lower), a, type(result)).
Input
| Label | Nama | Jenis | Batasan |
|---|---|---|---|
| (I1) | a |
tensor jenis floating-point atau kompleks atau tensor terkuantisasi per-tensor | (C1-C3) |
| (I2) | lower |
konstanta jenis i1 |
Output
| Nama | Jenis | Batasan |
|---|---|---|
result |
tensor jenis floating-point atau kompleks atau tensor terkuantisasi per-tensor | (C1) |
Batasan
- (C1)
baseline_type(a) = baseline_type(result). - (C2)
2 <= rank(a). - (C3)
dim(a, -2) = dim(a, -1).
Contoh
// %a: [
// [1.0, 2.0, 3.0],
// [2.0, 20.0, 26.0],
// [3.0, 26.0, 70.0]
// ]
%result = "stablehlo.cholesky"(%a) {
lower = true
} : (tensor<3x3xf32>) -> tensor<3x3xf64>
// %result: [
// [1.0, 0.0, 0.0],
// [2.0, 4.0, 0.0],
// [3.0, 5.0, 6.0]
// ]
penjepit
Semantik
Mengepaskan setiap elemen tensor operand antara nilai minimum dan maksimum
serta menghasilkan tensor result. Secara lebih formal, result[result_index] =
minimum(maximum(operand[result_index], min_element), max_element),
dengan min_element = rank(min) = 0 ? min[] : min[result_index],
max_element = rank(max) = 0 ? max[] : max[result_index]. Untuk jenis yang dikuantisasi,
melakukan dequantize_op_quantize(clamp, min, operand, max, type(result)).
Menerapkan pengurutan pada bilangan kompleks melibatkan semantik yang mengejutkan, jadi pada masa mendatang kami berencana menghapus dukungan untuk bilangan kompleks untuk operasi ini (#560).
Input
| Label | Nama | Jenis | Batasan |
|---|---|---|---|
| (I1) | min |
tensor atau tensor terkuantisasi per tensor | (C1), (C3) |
| (I2) | operand |
tensor atau tensor terkuantisasi per tensor | (C1-C4) |
| (I3) | max |
tensor atau tensor terkuantisasi per tensor | (C2), (C3) |
Output
| Nama | Jenis | Batasan |
|---|---|---|
result |
tensor atau tensor terkuantisasi per tensor | (C4) |
Batasan
- (C1)
rank(min) = 0 or shape(min) = shape(operand). - (C2)
rank(max) = 0 or shape(max) = shape(operand). - (C3)
baseline_element_type(min) = baseline_element_type(operand) = baseline_element_type(max). - (C4)
baseline_type(operand) = baseline_type(result).
Contoh
// %min: [5, 10, 15]
// %operand: [3, 13, 23]
// %max: [10, 15, 20]
%result = "stablehlo.clamp"(%min, %operand, %max) : (tensor<3xi32>, tensor<3xi32>, tensor<3xi32>) -> tensor<3xi32>
// %result: [5, 13, 20]
collective_broadcast
Semantik
Dalam setiap grup proses di petak proses StableHLO, kirim nilai tensor
operand dari proses sumber ke proses target dan hasilkan tensor
result.
Operasi ini membagi petak proses StableHLO menjadi process_groups yang ditentukan sebagai berikut:
cross_replica(replica_groups)ifchannel_id <= 0.cross_partition(replica_groups)ifchannel_id > 0.
Setelah itu, result@process diberikan oleh:
operand@process_groups[i, 0]jika adaisehingga prosesnya berada diprocess_groups[i].broadcast_in_dim(constant(is_quantized(result) ? quantize(0, element_type(result)) : 0, element_type(result)), [], type(result))sebaliknya.
Input
| Label | Nama | Jenis | Batasan |
|---|---|---|---|
| (I1) | operand |
tensor atau tensor terkuantisasi per tensor | (C3) |
| (I2) | replica_groups |
jumlah variadik konstanta tensor 1 dimensi berjenis si64 |
(C1), (C2) |
| (I3) | channel_id |
konstanta jenis si64 |
Output
| Nama | Jenis | Batasan |
|---|---|---|
result |
tensor atau tensor terkuantisasi per tensor | (C3) |
Batasan
- (C1)
is_unique(replica_groups). - (C2)
0 <= replica_groups < NdenganNditentukan sebagai:num_replicasjikacross_replicadigunakan.num_partitionsjikacross_partitiondigunakan.
- (C3)
type(result) = type(operand).
Contoh
// num_replicas: 4
// num_partitions: 1
// %operand@(0, 0): [[1, 2]]
// %operand@(1, 0): [[3, 4]]
// %operand@(2, 0): [[5, 6]]
// %operand@(3, 0): [[7, 8]]
%result = "stablehlo.collective_broadcast"(%operand) {
replica_groups = dense<[[2, 1]]> : tensor<1x2xi64>,
channel_handle = #stablehlo.channel_handle<handle = 0, type = 0>
} : (tensor1x2xi64>) -> tensor<1x2xi64>
// %result@(0, 0): [[0, 0]]
// %result@(1, 0): [[5, 6]]
// %result@(2, 0): [[5, 6]]
// %result@(3, 0): [[0, 0]]
collective_permute
Semantik
Dalam setiap grup proses di petak proses StableHLO, mengirimkan nilai tensor operand dari proses sumber ke proses target dan menghasilkan tensor result.
Operasi ini membagi petak proses StableHLO menjadi process_groups yang ditentukan sebagai berikut:
cross_replica(source_target_pairs)ifchannel_id <= 0.cross_partition(source_target_pairs)ifchannel_id > 0.
Setelah itu, result@process diberikan oleh:
operand@process_groups[i, 0], jika adaisehinggaprocess_groups[i, 1] = process.broadcast_in_dim(constant(is_quantized(result) ? quantize(0, element_type(result)) : 0, element_type(result)), [], type(result))sebaliknya.
Input
| Label | Nama | Jenis | Batasan |
|---|---|---|---|
| (I1) | operand |
tensor atau tensor terkuantisasi per tensor | (C5) |
| (I2) | source_target_pairs |
Konstanta tensor 2 dimensi dari jenis si64 |
(C1-C4) |
| (I3) | channel_id |
konstanta jenis si64 |
Output
| Nama | Jenis | Batasan |
|---|---|---|
result |
tensor atau tensor terkuantisasi per tensor | (C1) |
Batasan
- (C1)
dim(source_target_pairs, 1) = 2. - (C2)
is_unique(source_target_pairs[:, 0]). - (C3)
is_unique(source_target_pairs[:, 1]). - (C4)
0 <= source_target_pairs < N, denganNditentukan sebagai:num_replicasjikacross_replicadigunakan.num_partitionsjikacross_partitiondigunakan.
- (C5)
type(result) = type(operand).
Contoh
// num_replicas: 3
// num_partitions: 1
// %operand@(0, 0): [[1, 2], [3, 4]]
// %operand@(1, 0): [[5, 6], [7, 8]]
// %operand@(2, 0): [[9, 10], [11, 12]]
%result = "stablehlo.collective_permute"(%operand) {
source_target_pairs = dense<[[0, 1], [1, 2]]> : tensor<2x2xi64>,
channel_handle = #stablehlo.channel_handle<handle = 0, type = 0>
} : (tensor<2x2xi64>) -> tensor<2x2xi64>
//
// %result@(0, 0): [[0, 0], [0, 0]]
// %result@(1, 0): [[1, 2], [3, 4]]
// %result@(2, 0): [[5, 6], [7, 8]]
bandingkan
Semantik
Melakukan perbandingan elemen demi elemen pada tensor lhs dan rhs sesuai dengan
comparison_direction dan compare_type, serta menghasilkan tensor result.
Nilai comparison_direction dan compare_type memiliki semantik
berikut:
Untuk jenis elemen boolean dan bilangan bulat:
EQ:lhs = rhs.NE:lhs != rhs.GE:lhs >= rhs.GT:lhs > rhs.LE:lhs <= rhs.LT:lhs < rhs.
Untuk jenis elemen floating point dengan compare_type = FLOAT, op mengimplementasikan
operasi IEEE-754 berikut:
EQ:compareQuietEqual.NE:compareQuietNotEqual.GE:compareQuietGreaterEqual.GT:compareQuietGreater.LE:compareQuietLessEqual.LT:compareQuietLess.
Untuk jenis elemen floating point dengan compare_type = TOTALORDER, op
menggunakan kombinasi operasi totalOrder dan compareQuietEqual dari
IEEE-754.
Untuk jenis elemen yang kompleks, perbandingan leksikografis pasangan (real, imag)
dilakukan menggunakan comparison_direction dan compare_type yang diberikan.
Pemberlakuan pengurutan pada bilangan kompleks melibatkan semantik yang mengejutkan,
jadi pada masa mendatang kami berencana menghapus dukungan untuk bilangan kompleks
jika comparison_direction adalah GE, GT, LE, atau LT
(#560).
Untuk jenis yang dikuantisasi. Melakukan dequantize_compare(lhs, rhs,
comparison_direction).
Input
| Label | Nama | Jenis | Batasan |
|---|---|---|---|
| (I1) | lhs |
tensor atau tensor terkuantisasi per tensor | (C1-C3) |
| (I2) | rhs |
tensor atau tensor terkuantisasi per tensor | (C1-C2) |
| (I3) | comparison_direction |
enum EQ, NE, GE, GT, LE, dan LT |
|
| (I4) | compare_type |
enum FLOAT, TOTALORDER, SIGNED, dan UNSIGNED |
(C3) |
Output
| Nama | Jenis | Batasan |
|---|---|---|
result |
tensor jenis boolean | (C2) |
Batasan
- (C1)
baseline_element_type(lhs) = baseline_element_type(rhs). - (C2)
shape(lhs) = shape(rhs) = shape(result). - (C3)
compare_typeditentukan sebagai:SIGNEDifis_signed_integer(element_type(lhs)).UNSIGNEDifis_unsigned_integer(element_type(lhs)) or is_boolean(element_type(lhs)).FLOATatauTOTALORDERjikais_float(element_type(lhs)).FLOATifis_complex(element_type(lhs)).
Contoh
// %lhs: [1.0, 3.0]
// %rhs: [1.1, 2.9]
%result = "stablehlo.compare"(%lhs, %rhs) {
comparison_direction = #stablehlo<comparison_direction LT>,
compare_type = #stablehlo<comparison_type FLOAT>
} : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xi1>
// %result: [true, false]
kompleks
Semantik
Melakukan konversi per elemen ke nilai kompleks dari pasangan nilai riil dan imajiner, lhs dan rhs, serta menghasilkan tensor result.
Input
| Label | Nama | Jenis | Batasan |
|---|---|---|---|
| (I1) | lhs |
tensor berjenis f32 atau f64 |
(C1-C3) |
| (I2) | rhs |
tensor berjenis f32 atau f64 |
(C1) |
Output
| Nama | Jenis | Batasan |
|---|---|---|
result |
tensor jenis kompleks | (C2), (C3) |
Batasan
- (C1)
type(lhs) = type(rhs). - (C2)
shape(result) = shape(lhs). - (C3)
element_type(result)memiliki jeniscomplex<E>denganE = element_type(lhs).
Contoh
// %lhs: [1.0, 3.0]
// %rhs: [2.0, 4.0]
%result = "stablehlo.complex"(%lhs, %rhs) : (tensor<2xf64>, tensor<2xf64>) -> tensor<2xcomplex<f64>>
// %result: [(1.0, 2.0), (3.0, 4.0)]
gabungan
Semantik
Mengkapsulasi operasi yang terdiri dari (dikomposisikan) operasi StableHLO lainnya,
mengambil inputs dan composite_attributes serta menghasilkan results. Semantik op diimplementasikan oleh atribut decomposition. Operasi
composite dapat diganti dengan dekomposisinya tanpa mengubah semantik program. Jika penyusunan sebaris dekomposisi tidak memberikan semantik op yang sama, sebaiknya gunakan custom_call.
Kolom version (defaultnya 0) digunakan untuk menunjukkan kapan semantik komposit berubah.
Input
| Label | Nama | Jenis |
|---|---|---|
| (I1) | inputs |
jumlah nilai variadik |
| (I2) | name |
konstanta jenis string |
| (I3) | composite_attributes |
kamus atribut |
| (I4) | decomposition |
konstanta jenis string |
| (I5) | version |
konstanta jenis si32 |
Output
| Nama | Jenis |
|---|---|
results |
jumlah nilai variadik |
Batasan
- (C1)
is_namespaced_op_name(name) - (C2)
is_defined_in_parent_scope(decomposition) - (C3)
types(inputs...) == input_types(decomposition) - (C4)
types(results...) == output_types(decomposition)
Contoh
%results = "stablehlo.composite"(%input0, %input1) {
name = "my_namespace.my_op",
composite_attributes = {
my_attribute = "my_value"
},
decomposition = @my_op,
version = 1 : i32
} : (tensor<f32>, tensor<f32>) -> tensor<f32>
concatenate
Semantik
Menggabungkan inputs di sepanjang dimensi dimension dalam urutan yang sama seperti argumen yang diberikan dan menghasilkan tensor result. Secara lebih formal,
result[i0, ..., id, ..., iR-1] = inputs[k][i0, ..., kd, ..., iR-1], dengan:
id = d0 + ... + dk-1 + kd.dsama dengandimension, dand0, ... adalah ukuran dimensi ke-ddariinputs.
Input
| Label | Nama | Jenis | Batasan |
|---|---|---|---|
| (I1) | inputs |
jumlah tensor variadik atau tensor terkuantisasi per tensor | (C1-C6) |
| (I2) | dimension |
konstanta jenis si64 |
(C2), (C4), (C6) |
Output
| Nama | Jenis | Batasan |
|---|---|---|
result |
tensor atau tensor terkuantisasi per tensor | (C5-C6) |
Batasan
- (C1)
same(element_type(inputs...)). - (C2)
same(shape(inputs...))kecualidim(inputs..., dimension). - (C3)
0 < size(inputs). - (C4)
0 <= dimension < rank(inputs[0]). - (C5)
element_type(result) = element_type(inputs[0]). - (C6)
shape(result) = shape(inputs[0])kecuali:dim(result, dimension) = dim(inputs[0], dimension) + ....
Contoh
// %input0: [[1, 2], [3, 4], [5, 6]]
// %input1: [[7, 8]]
%result = "stablehlo.concatenate"(%input0, %input1) {
dimension = 0 : i64
} : (tensor<3x2xi64>, tensor<1x2xi64>) -> tensor<4x2xi64>
// %result: [[1, 2], [3, 4], [5, 6], [7, 8]]
konstanta
Semantik
Menghasilkan tensor output dari konstanta value.
Input
| Label | Nama | Jenis | Batasan |
|---|---|---|---|
| (I1) | value |
konstanta | (C1) |
Output
| Nama | Jenis | Batasan |
|---|---|---|
output |
tensor atau tensor terkuantisasi | (C1) |
Batasan
- (C1)
type(value) = type(output).
Contoh
%output = "stablehlo.constant"() {
value = dense<[[0.0, 1.0], [2.0, 3.0]]> : tensor<2x2xf32>
} : () -> tensor<2x2xf32>
// %output: [[0.0, 1.0], [2.0, 3.0]]
melakukan konversi
Semantik
Melakukan konversi per elemen dari satu jenis elemen ke jenis elemen lainnya pada
tensor operand dan menghasilkan tensor result.
Untuk konversi boolean-to-any-supported-type, nilai false dikonversi menjadi nol, dan nilai true dikonversi menjadi satu. Untuk konversi
any-supported-type-to-boolean, nilai nol dikonversi menjadi
false, dan nilai bukan nol dikonversi menjadi true. Lihat di bawah untuk mengetahui cara kerja ini untuk jenis yang kompleks.
Untuk konversi yang melibatkan bilangan bulat ke bilangan bulat, bilangan bulat ke floating point atau floating point ke floating point, jika nilai sumber dapat direpresentasikan secara persis dalam jenis tujuan, nilai hasilnya adalah representasi persis tersebut. Jika tidak, perilakunya akan ditentukan kemudian (#180).
Untuk konversi yang melibatkan floating-point-to-integer, bagian pecahan akan dipangkas. Jika nilai yang dipangkas tidak dapat ditampilkan dalam jenis tujuan, perilakunya akan ditentukan (#180).
Konversi yang melibatkan kompleks ke kompleks mengikuti perilaku yang sama dengan konversi floating point ke floating point untuk mengonversi bagian riil dan imajiner.
Untuk konversi complex-to-any-other-type dan any-other-type-to-complex, nilai imajiner sumber diabaikan atau nilai imajiner tujuan dibuat nol. Konversi bagian nyata mengikuti konversi floating point.
Pada prinsipnya, operasi ini dapat mengekspresikan dekuantisasi (konversi dari
tensor terkuantisasi ke tensor reguler), kuantisasi (konversi dari
tensor reguler ke tensor terkuantisasi), dan rekuantisasi (konversi antara tensor
terkuantisasi), tetapi saat ini kita memiliki operasi khusus untuk itu -
uniform_dequantize untuk kasus penggunaan pertama dan uniform_quantize untuk kasus penggunaan
kedua dan ketiga. Pada masa mendatang, kedua operasi ini dapat digabungkan menjadi convert (#1576).
Input
| Label | Nama | Jenis | Batasan |
|---|---|---|---|
| (I1) | operand |
tensor | (C1) |
Output
| Nama | Jenis | Batasan |
|---|---|---|
result |
tensor | (C1) |
Batasan
- (C1)
shape(operand) = shape(result).
Contoh
// %operand: [-1, 0, 1]
%result = "stablehlo.convert"(%operand) : (tensor<3xi64>) -> tensor<3xcomplex<f64>>
// %result: [(-1.0, 0.0), (0.0, 0.0), (1.0, 0.0)]
konvolusi
Semantik
Menghitung perkalian titik antara jendela lhs dan irisan rhs serta menghasilkan
result. Diagram berikut menunjukkan cara elemen di result dihitung dari
lhs dan rhs menggunakan contoh konkret.
Secara lebih formal, pertimbangkan pembingkaian ulang input berikut dalam hal lhs
agar dapat mengekspresikan jendela lhs:
lhs_window_dimensions = lhs_shape(dim(lhs, input_batch_dimension), dim(rhs, kernel_spatial_dimensions), dim(lhs, input_feature_dimension)).lhs_window_strides = lhs_shape(1, window_strides, 1).lhs_padding = lhs_shape([0, 0], padding, [0, 0]).lhs_base_dilations = lhs_shape(1, lhs_dilation, 1).lhs_window_dilations = lhs_shape(1, rhs_dilation, 1).
Pembingkaian ulang ini menggunakan fungsi bantuan berikut:
lhs_shape(n, hw, c) = permute([n] + hw + [c], [input_batch_dimension] + input_spatial_dimensions + [input_feature_dimension]).result_shape(n1, hw, c1) = permute([n1] + hw + [c1], [output_batch_dimension] + output_spatial_dimensions + [output_feature_dimension]).permute([j0, j1, ..., jR-1], permutation) = [i0, i1, ..., iR-1]di manaj[d] = i[permutation[d]].
Jika feature_group_count = 1 dan batch_group_count = 1, maka untuk semua
output_spatial_index di index_space(dim(result, output_spatial_dimensions...)),
result[result_shape(:, output_spatial_index, :)] = dot_product dengan:
padding_value = constant(0, element_type(lhs)).padded_lhs = pad(lhs, padding_value, lhs_padding[:, 0], lhs_padding[:, 1], lhs_base_dilations - 1).lhs_window_start = lhs_shape(0, output_spatial_index, 0) * lhs_window_strides.lhs_window = slice(padded_lhs, lhs_window_start, lhs_window_start + lhs_window_dimensions, lhs_window_dilations).reversed_lhs_window = reverse(lhs_window, [input_spatial_dimensions[dim] for dim in range(size(window_reversal)) if window_reversal[dim] = true]). Fitur ini tampaknya tidak digunakan, jadi pada masa mendatang kami berencana menghapusnya (#1181).dot_product = dot_general(reversed_lhs_window, rhs, lhs_batching_dimensions=[], lhs_contracting_dimensions=input_spatial_dimensions + [input_feature_dimension], rhs_batching_dimensions=[], rhs_contracting_dimensions=kernel_spatial_dimensions + [kernel_input_feature_dimension]).
Jika feature_group_count > 1:
lhses = split(lhs, feature_group_count, input_feature_dimension).rhses = split(rhs, feature_group_count, kernel_output_feature_dimension).results... = convolution(lhses..., rhses..., ..., feature_group_count=1, ...).result = concatenate(results, output_feature_dimension).
Jika batch_group_count > 1:
lhses = split(lhs, batch_group_count, input_batch_dimension).rhses = split(rhs, batch_group_count, kernel_output_feature_dimension).results... = convolution(lhses..., rhses..., ..., batch_group_count=1, ...).result = concatenate(results, output_feature_dimension).
Untuk jenis yang dikuantisasi, melakukan dequantize_op_quantize(
lambda lhs, rhs: convolution(lhs, rhs, window_strides, padding,
lhs_dilation, rhs_dilation, window_reversal, input_batch_dimension,
input_feature_dimension, input_spatial_dimensions,
kernel_input_feature_dimension, kernel_output_feature_dimension,
kernel_spatial_dimensions, output_batch_dimension,
output_feature_dimension, output_spatial_dimensions,
feature_group_count, batch_group_count, precision_config), lhs, rhs,
type(result)).
Untuk jenis terkuantisasi hibrida, melakukan hybrid_dequantize_then_op(
lambda lhs, rhs: convolution(lhs, rhs, window_strides, padding,
lhs_dilation, rhs_dilation, window_reversal, input_batch_dimension,
input_feature_dimension, input_spatial_dimensions,
kernel_input_feature_dimension, kernel_output_feature_dimension,
kernel_spatial_dimensions, output_batch_dimension,
output_feature_dimension, output_spatial_dimensions,
feature_group_count, batch_group_count, precision_config), lhs, rhs).
Input
| Label | Nama | Jenis | Batasan |
|---|---|---|---|
| (I1) | lhs |
tensor atau tensor terkuantisasi per tensor | (C1), (C10-C11), (C14) (C25), (C27-C28), (C31-C32), (C34) |
| (I2) | rhs |
tensor atau tensor terkuantisasi | (C1), (C14-C16), (C25), (C27-C29), (C31-C34) |
| (I3) | window_strides |
Konstanta tensor 1 dimensi berjenis si64 |
(C2-C3), (C25) |
| (I4) | padding |
Konstanta tensor 2 dimensi dari jenis si64 |
(C4), (C25) |
| (I5) | lhs_dilation |
Konstanta tensor 1 dimensi berjenis si64 |
(C5-C6), (C25) |
| (I6) | rhs_dilation |
Konstanta tensor 1 dimensi berjenis si64 |
(C7-C8), (C25) |
| (I7) | window_reversal |
Konstanta tensor 1 dimensi berjenis i1 |
(C9) |
| (I8) | input_batch_dimension |
konstanta jenis si64 |
(C10), (C13), (C25) |
| (I9) | input_feature_dimension |
konstanta jenis si64 |
(C11), (C13-C14) |
| (I10) | input_spatial_dimensions |
Konstanta tensor 1 dimensi berjenis si64 |
(C12), (C13), (C25) |
| (I11) | kernel_input_feature_dimension |
konstanta jenis si64 |
(C14), (C18) |
| (I12) | kernel_output_feature_dimension |
konstanta jenis si64 |
(C15-C16), (C18), (C25), (C29) |
| (I13) | kernel_spatial_dimensions |
Konstanta tensor 1 dimensi berjenis si64 |
(C17-C18), (C25) |
| (I14) | output_batch_dimension |
konstanta jenis si64 |
(C20), (C25) |
| (I15) | output_feature_dimension |
konstanta jenis si64 |
(C20), (C25), (C30) |
| (I16) | output_spatial_dimensions |
Konstanta tensor 1 dimensi berjenis si64 |
(C19-C20), (C25) |
| (I17) | feature_group_count |
konstanta jenis si64 |
(C11), (C14), (C16), (C21), (C23) |
| (I18) | batch_group_count |
konstanta jenis si64 |
(C10), (C15), (C22), (C23), (C25) |
| (I19) | precision_config |
jumlah enum variadik dari DEFAULT, HIGH, dan HIGHEST |
(C24) |
Output
| Nama | Jenis | Batasan |
|---|---|---|
result |
tensor atau tensor terkuantisasi | (C25-C28), (C30), (C32-34) |
Batasan
- (C1)
N = rank(lhs) = rank(rhs). - (C2)
size(window_strides) = N - 2. - (C3)
0 < window_strides. - (C4)
shape(padding) = [N - 2, 2]. - (C5)
size(lhs_dilation) = N - 2. - (C6)
0 < lhs_dilation. - (C7)
size(rhs_dilation) = N - 2. - (C8)
0 < rhs_dilation. - (C9)
size(window_reversal) = N - 2. - (C10)
dim(lhs, input_batch_dimension) % batch_group_count = 0. - (C11)
dim(lhs, input_feature_dimension) % feature_group_count = 0. - (C12)
size(input_spatial_dimensions) = N - 2. - (C13) Diberikan
input_dimensions = [input_batch_dimension] + input_spatial_dimensions + [input_feature_dimension]:is_unique(input_dimensions).0 <= input_dimensions < N.
- (C14)
dim(rhs, kernel_input_feature_dimension) = dim(lhs, input_feature_dimension) / feature_group_count. - (C15)
dim(rhs, kernel_output_feature_dimension) % batch_group_count = 0. - (C16)
dim(rhs, kernel_output_feature_dimension) % feature_group_count = 0. - (C17)
size(kernel_spatial_dimensions) = N - 2. - (C18) Mengingat
kernel_dimensions = kernel_spatial_dimensions + [kernel_input_feature_dimension] + [kernel_output_feature_dimension]:is_unique(kernel_dimensions).0 <= kernel_dimensions < N.
- (C19)
size(output_spatial_dimensions) = N - 2. - (C20) Dengan
output_dimensions = [output_batch_dimension] + output_spatial_dimensions + [output_feature_dimension]:is_unique(output_dimensions).0 <= output_dimensions < N.
- (C21)
0 < feature_group_count. - (C22)
0 < batch_group_count. - (C23)
feature_group_count = 1 or batch_group_count = 1. - (C24)
size(precision_config) = 2. - (C25)
dim(result, result_dim)didefinisikan sebagai:dim(lhs, input_batch_dimension) / batch_group_countifresult_dim = output_batch_dimension.dim(rhs, kernel_output_feature_dimension)ifresult_dim = output_feature_dimension.num_windowsjika tidak, di mana:output_spatial_dimensions[spatial_dim] = result_dim.lhs_dim = input_spatial_dimensions[spatial_dim].rhs_dim = kernel_spatial_dimensions[spatial_dim].dilated_input_shape[lhs_dim] = dim(lhs, lhs_dim) = 0 ? 0 : (dim(lhs, lhs_dim) - 1) * lhs_dilation[spatial_dim] + 1.padded_input_shape[lhs_dim] = padding[spatial_dim, 0] + dilated_input_shape[lhs_dim] + padding[spatial_dim, 1].dilated_window_shape[lhs_dim] = dim(rhs, rhs_dim) = 0 ? 0 : (dim(rhs, rhs_dim) - 1) * rhs_dilation[spatial_dim] + 1.is_empty_window[lhs_dim] = padded_input_shape[lhs_dim] = 0 || dilated_window_shape[lhs_dim] > padded_input_shape[lhs_dim].num_windows = is_empty_window[lhs_dim] ? 0 : floor((padded_input_shape[lhs_dim] - dilated_window_shape[lhs_dim]) / window_strides[spatial_dim]) + 1.
- (C26)
rank(result) = N. - Jika operasi menggunakan tensor yang tidak dikuantisasi:
- (C27)
element_type(lhs) = element_type(rhs) = element_type(result).
- (C27)
- Jika operasi menggunakan tensor terkuantisasi:
- (C28)
is_quantized(lhs) = is_quantized(result) and is_quantized(rhs). - (C29) Jika
is_per_axis_quantized(rhs), makaquantization_dimension(rhs) = kernel_output_feature_dimension. - (C30) Jika
is_per_axis_quantized(result), makaquantization_dimension(result) = output_feature_dimension. - Jika
is_quantized(lhs): - (C31)
storage_type(lhs) = storage_type(rhs). - (C32)
expressed_type(lhs) = expressed_type(rhs) = expressed_type(result). - (C33) Jika
is_per_tensor_quantized(rhs), makais_per_tensor_quantized(result). - Jika
!is_quantized(lhs): - (C34)
element_type(lhs) = expressed_type(rhs) = element_type(result).
- (C28)
Contoh
// %lhs: [[
// [
// [1], [2], [5], [6]
// ],
// [
// [3], [4], [7], [8]
// ],
// [
// [10], [11], [14], [15]
// ],
// [
// [12], [13], [16], [17]
// ]
// ]]
//
// %rhs: [
// [[[1]], [[1]], [[1]]],
// [[[1]], [[1]], [[1]]],
// [[[1]], [[1]], [[1]]]
// ]
%result = "stablehlo.convolution"(%lhs, %rhs) {
window_strides = array<i64: 4, 4>,
padding = dense<0> : tensor<2x2xi64>,
lhs_dilation = array<i64: 2, 2>,
rhs_dilation = array<i64: 1, 1>,
window_reversal = array<i1: false, false>,
// In the StableHLO dialect, dimension numbers are encoded via:
// `[<input dimensions>]x[<kernel dimensions>]->[output dimensions]`.
// "b" is batch dimension, "f" is feature dimension,
// "i" is input feature dimension, "o" is output feature dimension,
// "0/1/etc" are spatial dimensions.
dimension_numbers = #stablehlo.conv<[b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f]>,
batch_group_count = 1 : i64,
feature_group_count = 1 : i64,
precision_config = [#stablehlo<precision DEFAULT>, #stablehlo<precision DEFAULT>]
} : (tensor<1x4x4x1xi64>, tensor<3x3x1x1xi64>) -> tensor<1x2x2x1xi64>
// %result: [[
// [[10], [26]],
// [[46], [62]]
// ]]
kosinus
Semantik
Melakukan operasi kosinus per elemen pada tensor operand dan menghasilkan tensor result. Bergantung pada jenis elemen, lakukan tindakan berikut:
- Untuk float:
cosdari IEEE-754. - Untuk bilangan kompleks: kosinus kompleks.
- Untuk jenis yang dikuantisasi:
dequantize_op_quantize(cosine, operand, type(result)).
Input
| Label | Nama | Jenis | Batasan |
|---|---|---|---|
| (I1) | operand |
tensor jenis floating-point atau kompleks atau tensor terkuantisasi per-tensor | (C1) |
Output
| Nama | Jenis | Batasan |
|---|---|---|
result |
tensor jenis floating-point atau kompleks atau tensor terkuantisasi per-tensor | (C1) |
Batasan
- (C1)
baseline_type(operand) = baseline_type(result).
Contoh
// %operand: [
// [0.0, 1.57079632], // [0, pi/2]
// [3.14159265, 4.71238898] // [pi, 3pi/2]
// ]
%result = "stablehlo.cosine"(%operand) : (tensor<2x2xf32>) -> tensor<2x2xf32>
// %result: [[1.0, 0.0], [-1.0, 0.0]]
count_leading_zeros
Semantik
Melakukan penghitungan per elemen dari jumlah bit nol di awal dalam tensor operand dan menghasilkan tensor result.
Input
| Label | Nama | Jenis | Batasan |
|---|---|---|---|
| (I1) | operand |
tensor jenis bilangan bulat | (C1) |
Output
| Nama | Jenis | Batasan |
|---|---|---|
result |
tensor jenis bilangan bulat | (C1) |
Batasan
- (C1)
type(operand) = type(result).
Contoh
// %operand: [[0, 1], [128, -1]]
%result = "stablehlo.count_leading_zeros"(%operand) : (tensor<2x2xi64>) -> tensor<2x2xi64>
// %result: [[64, 63], [56, 0]]
custom_call
Semantik
Mengkapsulasi operasi call_target_name yang ditentukan implementasinya yang mengambil
inputs dan called_computations serta menghasilkan results. has_side_effect,
backend_config, dan api_version dapat digunakan untuk memberikan metadata tambahan yang ditentukan implementasi.
Saat ini, operasi ini berisi kumpulan metadata yang cukup tidak teratur yang mencerminkan evolusi organik operasi yang setara di compiler XLA. Pada masa mendatang, kami berencana menyatukan metadata ini (#741).
Input
| Label | Nama | Jenis |
|---|---|---|
| (I1) | inputs |
jumlah nilai variadik |
| (I2) | call_target_name |
konstanta jenis string |
| (I3) | has_side_effect |
konstanta jenis i1 |
| (I4) | backend_config |
konstanta jenis string atau kamus atribut |
| (I5) | api_version |
konstanta jenis si32 |
| (I6) | called_computations |
jumlah konstanta variadik berjenis string |
| (I7) | output_operand_aliases |
menentukan bagian alias dalam output dan operand |
Output
| Nama | Jenis |
|---|---|
results |
jumlah nilai variadik |
(Dukungan GPU XLA) Target custom_call khusus
Ada tiga call_target_name khusus yang terkait dengan jenis buffer:
CreateBuffer membuat buffer yang belum diinisialisasi, Pin membuat buffer yang sudah diinisialisasi, dan Unpin membatalkan alokasi buffer serta menampilkan konten buffer.
%uninitialized_buffer = "stablehlo.custom_call"() {
call_target_name = "CreateBuffer",
api_version = 4 : i32,
} : () -> memref<4xf64>
%initialized_buffer = "stablehlo.custom_call"(%init_value) {
call_target_name = "Pin",
api_version = 4 : i32,
} : (tensor<4xf64>) -> memref<4xf64>
%dealloc_buffer = "stablehlo.custom_call"(%initialized_buffer) {
call_target_name = "Unpin",
api_version = 4 : i32,
} : (memref<4xf64>) -> tensor<4xf64>
Alias
Beberapa operasi custom_call mungkin memerlukan bagian dalam output dan bagian dalam
operan untuk berbagi memori yang sama. Hal ini dapat dinyatakan melalui
output_operand_aliases. Representasi pasangan alias terdiri dari daftar indeks tuple output yang merepresentasikan bagian output, dan operand_index beserta daftar indeks tuple operand yang merepresentasikan bagian operand. Daftar output
atau indeks tuple operand kosong jika jenis yang sesuai bukan jenis tuple, dan dapat memiliki panjang yang berubah-ubah untuk jenis tuple yang bersarang secara berubah-ubah. Hal ini mirip dengan representasi alias XLA.
Bagian output dan bagian input dalam pasangan alias harus memiliki jenis yang sama. Untuk
operasi custom_call yang bukan panggilan ke CreateBuffer, Pin, dan Unpin, operand
buffer dapat muncul di paling banyak satu pasangan alias, dan output buffer
harus muncul di satu pasangan alias.
Contoh
%results = "stablehlo.custom_call"(%input0) {
call_target_name = "foo",
has_side_effect = false,
backend_config = {bar = 42 : i32},
api_version = 4 : i32,
called_computations = [@foo]
} : (tensor<f64>) -> tensor<f64>
%updated_buffer = "stablehlo.custom_call"(%buffer) {
call_target_name = "Update",
api_version = 4 : i32,
output_operand_aliases = [
#stablehlo.output_operand_alias<output_tuple_indices = [],
operand_index = 0,
operand_tuple_indices = []>]
} : (memref<4xf64>) -> memref<4xf64>
bagi
Semantik
Melakukan pembagian per elemen tensor dividen lhs dan tensor pembagi rhs serta
menghasilkan tensor result. Bergantung pada jenis elemen, lakukan tindakan berikut:
- Untuk bilangan bulat: pembagian bilangan bulat yang menghasilkan hasil bagi aljabar dengan bagian pecahan yang diabaikan.
- Untuk float:
divisiondari IEEE-754. - Untuk bilangan kompleks: pembagian kompleks.
- Untuk jenis yang dikuantisasi:
dequantize_op_quantize(divide, lhs, rhs, type(result)).
Input
| Label | Nama | Jenis | Batasan |
|---|---|---|---|
| (I1) | lhs |
tensor bilangan bulat, floating point, atau jenis kompleks atau tensor terkuantisasi per-tensor | (C1) |
| (I2) | rhs |
tensor bilangan bulat, floating point, atau jenis kompleks atau tensor terkuantisasi per-tensor | (C1) |
Output
| Nama | Jenis | Batasan |
|---|---|---|
result |
tensor bilangan bulat, floating point, atau jenis kompleks atau tensor terkuantisasi per-tensor | (C1) |
Batasan
- (C1)
baseline_type(lhs) = baseline_type(rhs) = baseline_type(result).
Contoh
// %lhs: [17.1, -17.1, 17.1, -17.1]
// %rhs: [3.0, 3.0, -3.0, -3.0]
%result = "stablehlo.divide"(%lhs, %rhs) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
// %result: [5.66666651, -5.66666651, -5.66666651, 5.66666651]
dot_general
Semantik
Menghitung perkalian titik antara slice lhs dan slice rhs serta menghasilkan tensor result.
Secara lebih formal, result[result_index] = dot_product, dengan:
lhs_result_dimensions = [d for d in axes(lhs) and d not in lhs_batching_dimensions and d not in lhs_contracting_dimensions].rhs_result_dimensions = [d for d in axes(rhs) and d not in rhs_batching_dimensions and d not in rhs_contracting_dimensions].result_batching_index + result_lhs_index + result_rhs_index = result_indexdengansize(result_batching_index) = size(lhs_batching_dimensions),size(result_lhs_index) = size(lhs_result_dimensions)dansize(result_rhs_index) = size(rhs_result_dimensions).transposed_lhs = transpose(lhs, lhs_batching_dimensions + lhs_result_dimensions + lhs_contracting_dimensions).transposed_lhs_slice = slice(transposed_lhs, result_batching_index + result_lhs_index + [:, ..., :]).reshaped_lhs_slice = reshape(transposed_lhs_slice, dims(lhs, lhs_contracting_dimensions)).transposed_rhs = transpose(rhs, rhs_batching_dimensions + rhs_result_dimensions + rhs_contracting_dimensions).transposed_rhs_slice = slice(transposed_rhs, result_batching_index + result_rhs_index + [:, ..., :]).reshaped_rhs_slice = reshape(transposed_rhs_slice, dims(rhs, rhs_contracting_dimensions)).dot_product = reduce( inputs=[multiply(reshaped_lhs_slice, reshaped_rhs_slice)], init_values=[constant(0, element_type(result))], dimensions=range(size(lhs_contracting_dimensions)), body=lambda x, y: add(x, y)).
Untuk jenis yang dikuantisasi, melakukan dequantize_op_quantize(
lambda lhs, rhs: dot_general(lhs, rhs, lhs_batching_dimensions,
rhs_batching_dimensions, lhs_contracting_dimensions,
rhs_contracting_dimensions, precision_config), lhs, rhs, type(result)).
Untuk jenis terkuantisasi hibrida, melakukan hybrid_dequantize_then_op(
lambda lhs, rhs: dot_general(lhs, rhs, lhs_batching_dimensions,
rhs_batching_dimensions, lhs_contracting_dimensions,
rhs_contracting_dimensions, precision_config), lhs, rhs).
precision_config mengontrol kompromi antara kecepatan dan akurasi untuk
komputasi di backend akselerator. Ini dapat berupa salah satu dari nilai berikut (saat ini, semantik nilai enum ini kurang ditentukan, tetapi kami berencana untuk menanganinya di #755):
DEFAULT: Perhitungan tercepat, tetapi perkiraan paling tidak akurat untuk angka asli.HIGH: Perhitungan lebih lambat, tetapi perkiraan yang lebih akurat untuk angka asli.HIGHEST: Perhitungan paling lambat, tetapi perkiraan paling akurat untuk angka asli.
DotAlgorithm menentukan properti utama algoritma yang digunakan untuk menerapkan
operasi titik, yang juga menentukan presisi. Jika kolom atribut algoritma
disetel, maka precision_config harus DEFAULT. DotAlgorithms
tidak memiliki nilai default, karena parameter default ditentukan
oleh penerapan. Dengan demikian, semua kolom algoritma titik dapat disetel ke None untuk menentukan algoritma titik kosong, yang akan menggunakan nilai precision_config.
Kolom DotAlgorithm mencakup:
lhs_precision_typedanrhs_precision_type, presisi yang dibulatkan untuk LHS dan RHS operasi. Jenis presisi tidak bergantung pada jenis penyimpanan input dan output.accumulation_typepresisi yang digunakan untuk akumulasi.lhs_component_count,rhs_component_count, dannum_primitive_operationsberlaku saat kita melakukan algoritma yang menguraikan LHS dan/atau RHS menjadi beberapa komponen dan melakukan beberapa operasi titik "primitif" pada nilai-nilai tersebut - biasanya untuk mengemulasi presisi yang lebih tinggi (misalnya, Memanfaatkan Jenis Data Kecerdasan Buatan bfloat16 untuk Komputasi Presisi Tinggi: bf16_6x tf32_3x, dll.). Untuk algoritma tanpa dekomposisi, nilai ini harus ditetapkan ke1.allow_imprecise_accumulationuntuk menentukan apakah akumulasi dalam presisi yang lebih rendah diizinkan untuk beberapa langkah (misalnya,CUBLASLT_MATMUL_DESC_FAST_ACCUM).
Contoh atribut DotAlgorithm:
// Inputs are casted to tf32, and then accumulated in f32:
{lhs_precision_type = tf32,
rhs_precision_type = tf32,
accumulation_type = f32,
lhs_component_count = 1,
rhs_component_count = 1,
num_primitive_operations = 1,
allow_imprecise_accumulation = false}
// bf16_6x: each input is decomposed to 3 bf16 components, then 6 dot operations are done on those components, and the result is accumulated in f32.
{lhs_precision_type = bf16,
rhs_precision_type = bf16,
accumulation_type = f32,
lhs_component_count = 3,
rhs_component_count = 3,
num_primitive_operations = 6,
allow_imprecise_accumulation = false}
// Inputs are (casted to) f8e5m2, and we accumulate in f32, but for some steps we may accumulate in lower precision.
{lhs_precision_type = f8e5m2,
rhs_precision_type = f8e5m2,
accumulation_type = f32,
lhs_component_count = 1,
rhs_component_count = 1,
num_primitive_operations = 1,
allow_imprecise_accumulation = true}
Implementasi yang akan memutuskan kombinasi mana yang didukung. Secara umum, tidak ada jaminan bahwa setiap algoritma didukung di setiap jenis akselerator oleh konsumen StableHLO. Jika algoritma tertentu tidak didukung, error akan muncul, bukan beralih ke alternatif. Verifikasi StableHLO akan memberikan verifikasi upaya terbaik, mencegah algoritma yang tidak diketahui didukung di hardware mana pun.
Lihat xla_data.proto > Algorithm
untuk mengetahui beberapa nilai algoritma yang didukung. Tiket #2483 mencatat rencana untuk membuat dokumen terpusat tentang algoritma yang didukung oleh backend.
Input
| Label | Nama | Jenis | Batasan |
|---|---|---|---|
| (I1) | lhs |
tensor atau tensor terkuantisasi per tensor | (C5-C6), (C9-C10), (C12-C14), (C17-C18), (C20) |
| (I2) | rhs |
tensor atau tensor terkuantisasi | (C7-C10), (C12-C20) |
| (I3) | lhs_batching_dimensions |
Konstanta tensor 1 dimensi berjenis si64 |
(C1), (C3), (C5), (C9), (C12) |
| (I4) | rhs_batching_dimensions |
Konstanta tensor 1 dimensi berjenis si64 |
(C1), (C4), (C7), (C9) |
| (I5) | lhs_contracting_dimensions |
Konstanta tensor 1 dimensi berjenis si64 |
(C2), (C3), (C6), (C10) |
| (I6) | rhs_contracting_dimensions |
Konstanta tensor 1 dimensi berjenis si64 |
(C2), (C4), (C8), (C10), (C16) |
| (I7) | precision_config |
jumlah enum variadik dari DEFAULT, HIGH, dan HIGHEST |
(C11), (C21) |
| (I8) | lhs_precision_type |
FloatType atau TensorFloat32 | (C21) |
| (I9) | rhs_precision_type |
FloatType atau TensorFloat32 | (C21) |
| (I10) | accumulation_type |
FloatType atau TensorFloat32 | (C21) |
| (I11) | lhs_component_count |
konstanta jenis si32 |
(C21), (C22) |
| (I12) | rhs_component_count |
konstanta jenis si32 |
(C21), (C23) |
| (I13) | num_primitive_operations |
konstanta jenis si32 |
(C21), (C24) |
| (I14) | allow_imprecise_accumulation |
konstanta jenis bool |
(C21) |
Output
| Nama | Jenis | Batasan |
|---|---|---|
result |
tensor atau tensor terkuantisasi | (C12), (C14), (C18-C20) |
Batasan
- (C1)
size(lhs_batching_dimensions) = size(rhs_batching_dimensions). - (C2)
size(lhs_contracting_dimensions) = size(rhs_contracting_dimensions). - (C3)
is_unique(lhs_batching_dimensions + lhs_contracting_dimensions). - (C4)
is_unique(rhs_batching_dimensions + rhs_contracting_dimensions). - (C5)
0 <= lhs_batching_dimensions < rank(lhs). - (C6)
0 <= lhs_contracting_dimensions < rank(lhs). - (C7)
0 <= rhs_batching_dimensions < rank(rhs). - (C8)
0 <= rhs_contracting_dimensions < rank(rhs). - (C9)
dim(lhs, lhs_batching_dimensions...) = dim(rhs, rhs_batching_dimensions...). - (C10)
dim(lhs, lhs_contracting_dimensions...) = dim(rhs, rhs_contracting_dimensions...). - (C11)
size(precision_config) = 2. - (C12)
shape(result) = dim(lhs, lhs_batching_dimensions) + dim(lhs, lhs_result_dimensions) + dim(rhs, rhs_result_dimensions). - Jika operasi menggunakan tensor yang tidak dikuantisasi:
- (C13)
element_type(lhs) = element_type(rhs).
- (C13)
- Jika operasi menggunakan tensor terkuantisasi:
- (C14)
is_quantized(lhs) = is_quantized(result) and is_quantized(rhs). - (C15)
zero_points(rhs) = 0. - (C16) Jika
is_per_axis_quantized(rhs), makaquantization_dimension(rhs)tidak ada dirhs_contracting_dimensions. - Jika
is_quantized(lhs): - (C17)
storage_type(lhs) = storage_type(rhs). - (C18)
expressed_type(lhs) = expressed_type(rhs) = expressed_type(result). - (C19) Jika
is_per_tensor_quantized(rhs), makais_per_tensor_quantized(result). - Jika
!is_quantized(lhs): - (C20)
element_type(lhs) = expressed_type(rhs) = element_type(result).
- (C14)
- Jika
!is_empty_algorithm(lhs_precision_type, rhs_precision_type, accumulation_type, lhs_component_count, rhs_component_count, num_primitive_operations allow_imprecise_accumulation):- (C21)
precision_config... = DEFAULT. - (C22)
0 < lhs_component_count. - (C23)
0 < rhs_component_count. - (C24)
0 < num_primitive_operations.
- (C21)
Contoh
// %lhs: [
// [[1, 2],
// [3, 4]],
// [[5, 6],
// [7, 8]]
// ]
// %rhs: [
// [[1, 0],
// [0, 1]],
// [[1, 0],
// [0, 1]]
// ]
%result = "stablehlo.dot_general"(%lhs, %rhs) {
dot_dimension_numbers = #stablehlo.dot<
lhs_batching_dimensions = [0],
rhs_batching_dimensions = [0],
lhs_contracting_dimensions = [2],
rhs_contracting_dimensions = [1]
>,
precision_config = [#stablehlo<precision DEFAULT>, #stablehlo<precision DEFAULT>],
algorithm = #stablehlo.dot_algorithm<
lhs_precision_type = tf32,
rhs_precision_type = tf32,
accumulation_type = f32,
lhs_component_count = 1,
rhs_component_count = 1,
num_primitive_operations = 1,
allow_imprecise_accumulation = false
>
} : (tensor<2x2x2xi64>, tensor<2x2x2xi64>) -> tensor<2x2x2xi64>
// %result: [
// [[1, 2],
// [3, 4]],
// [[5, 6],
// [7, 8]]
// ]
dynamic_broadcast_in_dim
Semantik
Operasi ini secara fungsional identik dengan
broadcast_in_dim
op, tetapi bentuk hasilnya ditentukan secara dinamis melalui output_dimensions.
Operasi ini juga menerima atribut opsional known_expanding_dimensions, known_nonexpanding_dimensions
untuk menyatakan pengetahuan statis tentang perilaku perluasan dimensi.
Jika tidak ditentukan, semua dimensi diasumsikan dapat diperluas.
Input
| Label | Nama | Jenis | Batasan |
|---|---|---|---|
| (I1) | operand |
tensor atau tensor terkuantisasi | (C1-C2), (C5-C6), (C9) |
| (I2) | output_dimensions |
Tensor 1 dimensi dari jenis bilangan bulat | (C7) |
| (I3) | broadcast_dimensions |
Tensor konstanta 1 dimensi dari jenis bilangan bulat | (C2-C6) |
| (I4) | known_expanding_dimensions |
Tensor konstanta 1 dimensi dari jenis bilangan bulat | (C8-C9) |
| (I5) | known_nonexpanding_dimensions |
Tensor konstanta 1 dimensi dari jenis bilangan bulat | (C8-C9) |
Output
| Nama | Jenis | Batasan |
|---|---|---|
result |
tensor atau tensor terkuantisasi | (C1), (C3), (C5-C7) |
Batasan
- (C1)
element_type(result)diberikan oleh:element_type(operand), jika!is_per_axis_quantized(operand).element_type(operand), kecualiquantization_dimension(operand),scales(operand), danzero_points(operand)dapat berbeda dariquantization_dimension(result),scales(result), danzero_points(result)masing-masing, jika tidak.
- (C2)
size(broadcast_dimensions) = rank(operand). - (C3)
0 <= broadcast_dimensions < rank(result). - (C4)
is_unique(broadcast_dimensions). - (C5) Untuk semua
ddiaxes(operand):dim(operand, d) = 1ataudim(operand, d) = dim(result, broadcast_dimensions[d]).
- (C6) Jika
is_per_axis_quantized(result):quantization_dimension(result) = broadcast_dimensions[quantization_dimension(operand)].- Jika
dim(operand, quantization_dimension(operand)) = 1, makascales(result)[i] = scales(operand)[0] and zero_points(result)[i] = zero_points(operand)[0] for i in range(dim(result, quantization_dimension(result))).
- (C7)
size(output_dimensions) = rank(result). - (C8)
is_unique(known_expanding_dimensions + known_nonexpanding_dimensions). - (C9)
0 <= known_expanding_dimensions < rank(operand). - (C10)
0 <= known_nonexpanding_dimensions < rank(operand).
Contoh
// %operand: [
// [1, 2, 3]
// ]
%operand = stablehlo.constant dense<[[1, 2, 3]]> : tensor<1x3xi64>
%output_dimensions = stablehlo.constant dense<[2, 3, 2]> : tensor<3xi64>
%result = "stablehlo.dynamic_broadcast_in_dim"(%operand, %output_dimensions) {
broadcast_dimensions = array<i64: 2, 1>,
known_expanding_dimensions = array<i64: 0>,
known_nonexpanding_dimensions = array<i64: 1>
} : (tensor<1x3xi64>, tensor<3xi64>) -> tensor<2x3x2xi64>
// %result: [
// [
// [1, 1],
// [2, 2],
// [3, 3]
// ],
// [
// [1, 1],
// [2, 2],
// [3, 3]
// ]
// ]
dynamic_conv
Semantik
Operasi ini secara fungsional identik dengan operasi
konvolusi, tetapi padding ditentukan secara dinamis melalui padding.
Input
| Label | Nama | Jenis | Batasan |
|---|---|---|---|
| (I1) | lhs |
tensor atau tensor terkuantisasi per tensor | (C1), (C10-C11), (C14) (C25), (C26-C27), (C30-C31), (C33) |
| (I2) | rhs |
tensor atau tensor terkuantisasi | (C1), (C14-C16), (C26-C28), (C30-C33) |
| (I3) | padding |
Tensor 2 dimensi dari jenis bilangan bulat | (C4) |
| (I4) | window_strides |
Konstanta tensor 1 dimensi berjenis si64 |
(C2-C3) |
| (I5) | lhs_dilation |
Konstanta tensor 1 dimensi berjenis si64 |
(C5-C6) |
| (I6) | rhs_dilation |
Konstanta tensor 1 dimensi berjenis si64 |
(C7-C8) |
| (I7) | window_reversal |
Konstanta tensor 1 dimensi berjenis i1 |
(C9) |
| (I8) | input_batch_dimension |
konstanta jenis si64 |
(C10), (C13) |
| (I9) | input_feature_dimension |
konstanta jenis si64 |
(C11), (C13-C14) |
| (I10) | input_spatial_dimensions |
Konstanta tensor 1 dimensi berjenis si64 |
(C12), (C13) |
| (I11) | kernel_input_feature_dimension |
konstanta jenis si64 |
(C14), (C18) |
| (I12) | kernel_output_feature_dimension |
konstanta jenis si64 |
(C15-C16), (C18), (C28) |
| (I13) | kernel_spatial_dimensions |
Konstanta tensor 1 dimensi berjenis si64 |
(C17-C18) |
| (I14) | output_batch_dimension |
konstanta jenis si64 |
(C20) |
| (I15) | output_feature_dimension |
konstanta jenis si64 |
(C20), (C29) |
| (I16) | output_spatial_dimensions |
Konstanta tensor 1 dimensi berjenis si64 |
(C19-C20) |
| (I17) | feature_group_count |
konstanta jenis si64 |
(C11), (C14), (C16), (C21), (C23) |
| (I18) | batch_group_count |
konstanta jenis si64 |
(C10), (C15), (C22), (C23) |
| (I19) | precision_config |
jumlah enum variadik dari DEFAULT, HIGH, dan HIGHEST |
(C24) |
Output
| Nama | Jenis | Batasan |
|---|---|---|
result |
tensor atau tensor terkuantisasi | (C25-C27), (C29), (C31-C33) |
Batasan
- (C1)
N = rank(lhs) = rank(rhs). - (C2)
size(window_strides) = N - 2. - (C3)
0 < window_strides. - (C4)
shape(padding) = [N - 2, 2]. - (C5)
size(lhs_dilation) = N - 2. - (C6)
0 < lhs_dilation. - (C7)
size(rhs_dilation) = N - 2. - (C8)
0 < rhs_dilation. - (C9)
size(window_reversal) = N - 2. - (C10)
dim(lhs, input_batch_dimension) % batch_group_count = 0. - (C11)
dim(lhs, input_feature_dimension) % feature_group_count = 0. - (C12)
size(input_spatial_dimensions) = N - 2. - (C13) Diberikan
input_dimensions = [input_batch_dimension] + input_spatial_dimensions + [input_feature_dimension]:is_unique(input_dimensions).0 <= input_dimensions < N.
- (C14)
dim(rhs, kernel_input_feature_dimension) = dim(lhs, input_feature_dimension) / feature_group_count. - (C15)
dim(rhs, kernel_output_feature_dimension) % batch_group_count = 0. - (C16)
dim(rhs, kernel_output_feature_dimension) % feature_group_count = 0. - (C17)
size(kernel_spatial_dimensions) = N - 2. - (C18) Mengingat
kernel_dimensions = kernel_spatial_dimensions + [kernel_input_feature_dimension] + [kernel_output_feature_dimension]:is_unique(kernel_dimensions).0 <= kernel_dimensions < N.
- (C19)
size(output_spatial_dimensions) = N - 2. - (C20) Dengan
output_dimensions = [output_batch_dimension] + output_spatial_dimensions + [output_feature_dimension]:is_unique(output_dimensions).0 <= output_dimensions < N.
- (C21)
0 < feature_group_count. - (C22)
0 < batch_group_count. - (C23)
feature_group_count = 1 or batch_group_count = 1. - (C24)
size(precision_config) = 2. - (C25)
dim(result, result_dim)didefinisikan sebagai:dim(lhs, input_batch_dimension) / batch_group_countifresult_dim = output_batch_dimension.dim(rhs, kernel_output_feature_dimension)ifresult_dim = output_feature_dimension.num_windowsjika tidak, di mana:output_spatial_dimensions[spatial_dim] = result_dim.lhs_dim = input_spatial_dimensions[spatial_dim].rhs_dim = kernel_spatial_dimensions[spatial_dim].dilated_input_shape[lhs_dim] = dim(lhs, lhs_dim) = 0 ? 0 : (dim(lhs, lhs_dim) - 1) * lhs_dilation[spatial_dim] + 1.padded_input_shape[lhs_dim] = padding[spatial_dim, 0] + dilated_input_shape[lhs_dim] + padding[spatial_dim, 1].dilated_window_shape[lhs_dim] = dim(rhs, rhs_dim) = 0 ? 0 : (dim(rhs, rhs_dim) - 1) * rhs_dilation[spatial_dim] + 1.is_empty_window[lhs_dim] = padded_input_shape[lhs_dim] = 0 || dilated_window_shape[lhs_dim] > padded_input_shape[lhs_dim].num_windows = is_empty_window[lhs_dim] ? 0 : floor((padded_input_shape[lhs_dim] - dilated_window_shape[lhs_dim]) / window_strides[spatial_dim]) + 1.
- (C26)
rank(result) = N. - Jika operasi menggunakan tensor yang tidak dikuantisasi:
- (C27)
element_type(lhs) = element_type(rhs) = element_type(result).
- (C27)
- Jika operasi menggunakan tensor terkuantisasi:
- (C28)
is_quantized(lhs) = is_quantized(result) and is_quantized(rhs). - (C29) Jika
is_per_axis_quantized(rhs), makaquantization_dimension(rhs) = kernel_output_feature_dimension. - (C30) Jika
is_per_axis_quantized(result), makaquantization_dimension(result) = output_feature_dimension. - Jika
is_quantized(lhs): - (C31)
storage_type(lhs) = storage_type(rhs). - (C32)
expressed_type(lhs) = expressed_type(rhs) = expressed_type(result). - (C33) Jika
is_per_tensor_quantized(rhs), makais_per_tensor_quantized(result). - Jika
!is_quantized(lhs): - (C34)
element_type(lhs) = expressed_type(rhs) = element_type(result).
- (C28)
Contoh
// %lhs: [[
// [[1], [2], [5], [6]],
// [[3], [4], [7], [8]],
// [[10], [11], [14], [15]],
// [[12], [13], [16], [17]]
// ]]
//
// %rhs: [
// [[[1]], [[1]], [[1]]],
// [[[1]], [[1]], [[1]]],
// [[[1]], [[1]], [[1]]]
// ]
// %padding: [[1, 1],
// [1, 1]]
%result = "stablehlo.dynamic_conv"(%lhs, %rhs, %padding) {
window_strides = array<i64: 4, 4>,
lhs_dilation = array<i64: 2, 2>,
rhs_dilation = array<i64: 1, 1>,
window_reversal = array<i1: false, false>,
dimension_numbers = #stablehlo.conv<raw
input_batch_dimension = 0,
input_feature_dimension = 3,
input_spatial_dimensions = [0, 1],
kernel_input_feature_dimension = 2,
kernel_output_feature_dimension = 3,
kernel_spatial_dimensions = [0, 1],
output_batch_dimension = 0,
output_feature_dimension = 3,
output_spatial_dimensions = [1, 2]
>,
feature_group_count = 1 : i64,
batch_group_count = 1 : i64,
precision_config = [#stablehlo<precision DEFAULT>, #stablehlo<precision DEFAULT>]
} : (tensor<1x4x4x1xi64>, tensor<3x3x1x1xi64>, tensor<2x2xi64>) -> tensor<1x2x2x1xi64>
// %result: [[
// [[1], [5]],
// [[10], [14]]
// ]]
dynamic_gather
Semantik
Operasi ini secara fungsional identik dengan
gather
op, dengan slice_sizes yang ditentukan secara dinamis sebagai nilai.
Input
| Label | Nama | Jenis | Batasan |
|---|---|---|---|
| (I1) | operand |
tensor atau tensor terkuantisasi per tensor | (C1), (C7), (C10-C12), (C14) |
| (I2) | start_indices |
tensor jenis bilangan bulat | (C2), (C3), (C13) |
| (I3) | slice_sizes |
Tensor 1 dimensi dari jenis bilangan bulat | (C8), (C11-C13) |
| (I4) | offset_dims |
Konstanta tensor 1 dimensi berjenis si64 |
(C1), (C4-C5), (C13) |
| (I5) | collapsed_slice_dims |
Konstanta tensor 1 dimensi berjenis si64 |
(C1), (C6-C8), (C13) |
| (I6) | start_index_map |
Konstanta tensor 1 dimensi berjenis si64 |
(C3), (C9), (C10) |
| (I7) | index_vector_dim |
konstanta jenis si64 |
(C2), (C3), (C13) |
| (I8) | indices_are_sorted |
konstanta jenis i1 |
Output
| Nama | Jenis | Batasan |
|---|---|---|
result |
tensor atau tensor terkuantisasi per tensor | (C5), (C13-C14) |
Batasan
- (C1)
rank(operand) = size(offset_dims) + size(collapsed_slice_dims). - (C2)
0 <= index_vector_dim <= rank(start_indices). - (C3)
size(start_index_map) = index_vector_dim < rank(start_indices) ? dim(start_indices, index_vector_dim) : 1. - (C4)
is_unique(offset_dims) and is_sorted(offset_dims). - (C5)
0 <= offset_dims < rank(result). - (C6)
is_unique(collapsed_slice_dims) and is_sorted(collapsed_slice_dims). - (C7)
0 <= collapsed_slice_dims < rank(operand). - (C8)
slice_sizes[collapsed_slice_dims...] <= 1. - (C9)
is_unique(start_index_map). - (C10)
0 <= start_index_map < rank(operand). - (C11)
size(slice_sizes) = rank(operand). - (C12)
0 <= slice_sizes <= shape(operand). - (C13)
shape(result) = combine(batch_dim_sizes, offset_dim_sizes)yang mana:batch_dim_sizes = shape(start_indices)kecuali ukuran dimensi daristart_indicesyang sesuai denganindex_vector_dimtidak disertakan.offset_dim_sizes = shape(slice_sizes)kecuali ukuran dimensi dalamslice_sizesyang sesuai dengancollapsed_slice_dimstidak disertakan.combinemenempatkanbatch_dim_sizespada sumbu yang sesuai denganbatch_dimsdanoffset_dim_sizespada sumbu yang sesuai denganoffset_dims.
- (C14)
element_type(operand) = element_type(result).
Contoh
// %operand: [
// [[1, 2], [3, 4], [5, 6], [7, 8]],
// [[9, 10],[11, 12], [13, 14], [15, 16]],
// [[17, 18], [19, 20], [21, 22], [23, 24]]
// ]
// %start_indices: [
// [[0, 0], [1, 0], [2, 1]],
// [[0, 1], [1, 1], [0, 2]]
// ]
// %slize_sizes: [1, 2, 2]
%result = "stablehlo.dynamic_gather"(%operand, %start_indices, %slize_sizes) {
dimension_numbers = #stablehlo.gather<
offset_dims = [2, 3],
collapsed_slice_dims = [0],
start_index_map = [1, 0],
index_vector_dim = 2>,
indices_are_sorted = false
} : (tensor<3x4x2xi64>, tensor<2x3x2xi64>, tensor<3xi64>) -> tensor<2x3x2x2xi64>
// %result: [
// [
// [[1, 2], [3, 4]],
// [[3, 4], [5, 6]],
// [[13, 14], [15, 16]]
// ],
// [
// [[9, 10], [11, 12]],
// [[11, 12], [13, 14]],
// [[17, 18], [19, 20]]
// ]
// ]
dynamic_iota
Semantik
Operasi ini secara fungsional identik dengan
iota
op, tetapi bentuk hasilnya ditentukan secara dinamis melalui output_shape.
Input
| Label | Nama | Jenis | Batasan |
|---|---|---|---|
| (I1) | output_shape |
Tensor 1 dimensi dari jenis bilangan bulat | (C1), (C2) |
| (I2) | iota_dimension |
si64 |
(C1) |
Output
| Nama | Jenis | Batasan |
|---|---|---|
result |
tensor bilangan bulat, floating point, atau jenis kompleks atau tensor terkuantisasi per-tensor | (C2) |
Batasan
- (C1)
0 <= iota_dimension < size(output_shape). - (C2)
rank(result) = size(output_shape).
Contoh
%output_shape = stablehlo.constant dense<[4, 5]> : tensor<2xi64>
%result = "stablehlo.dynamic_iota"(%output_shape) {
iota_dimension = 0 : i64
} : (tensor<2xi64>) -> tensor<4x5xi64>
// %result: [
// [0, 0, 0, 0, 0],
// [1, 1, 1, 1, 1],
// [2, 2, 2, 2, 2],
// [3, 3, 3, 3, 3]
// ]
dynamic_pad
Semantik
Operasi ini secara fungsional identik dengan
pad
op, tetapi dengan edge_padding_low, edge_padding_high, dan interior_padding
ditentukan secara dinamis sebagai nilai.
Input
| Label | Nama | Jenis | Batasan |
|---|---|---|---|
| (I1) | operand |
tensor atau tensor terkuantisasi per tensor | (C1), (C2), (C4) |
| (I2) | padding_value |
Tensor 0 dimensi atau tensor terkuantisasi per tensor | (C1) |
| (I3) | edge_padding_low |
Tensor 1 dimensi dari jenis bilangan bulat | (C1), (C4) |
| (I4) | edge_padding_high |
Tensor 1 dimensi dari jenis bilangan bulat | (C1), (C4) |
| (I5) | interior_padding |
Tensor 1 dimensi dari jenis bilangan bulat | (C2-C4) |
Output
| Nama | Jenis | Batasan |
|---|---|---|
result |
tensor atau tensor terkuantisasi per tensor | (C3-C6) |
Batasan
- (C1)
element_type(operand) = element_type(padding_value) = element_type(result). - (C2)
size(edge_padding_low) = size(edge_padding_high) = size(interior_padding) = rank(operand). - (C3)
0 <= interior_padding. - (C4)
shape(result) = shape(operand) + edge_padding_low + max(shape(operand) - 1, 0) * interior_padding + edge_padding_high.
Contoh
// %operand: [
// [1, 2, 3],
// [4, 5, 6]
// ]
// %padding_value: 0
// %edge_padding_low: [0, 1]
// %edge_padding_high: [2, 1]
// %interior_padding: [1, 2]
%result = "stablehlo.dynamic_pad"(%operand, %padding_value,
%edge_padding_low, %edge_padding_high, %interior_padding
) : (tensor<2x3xi64>, tensor<i64>, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor<5x9xi64>
// %result: [
// [0, 1, 0, 0, 2, 0, 0, 3, 0],
// [0, 0, 0, 0, 0, 0, 0, 0, 0],
// [0, 4, 0, 0, 5, 0, 0, 6, 0],
// [0, 0, 0, 0, 0, 0, 0, 0, 0],
// [0, 0, 0, 0, 0, 0, 0, 0, 0]
// ]
dynamic_reshape
Semantik
Operasi ini secara fungsional identik dengan
reshape
op, tetapi bentuk hasilnya ditentukan secara dinamis melalui output_shape.
Input
| Label | Nama | Jenis | Batasan |
|---|---|---|---|
| (I1) | operand |
tensor atau tensor terkuantisasi | (C1-C3) |
| (I2) | output_shape |
Tensor 1 dimensi dari jenis bilangan bulat | (C4) |
Output
| Nama | Jenis | Batasan |
|---|---|---|
result |
tensor atau tensor terkuantisasi | (C1-C4) |
Batasan
- (C1)
element_type(result)diberikan oleh:element_type(operand), jika!is_per_axis_quantized(operand).element_type(operand)kecualiquantization_dimension(operand)danquantization_dimension(result)mungkin berbeda.
- (C2)
size(operand) = size(result). - (C3) Jika
is_per_axis_quantized(operand):reduce(dims(operand, [0, 1, ..., quantization_dimension(operand) - 1]), init_values=1, dimensions=[0], body=lambda x, y: x * y) = reduce(dims(result, [0, 1, ..., quantization_dimension(result) - 1]), init_values=1, dimensions=[0], body=lambda x, y: x * y).dim(operand, quantization_dimension(operand)) = dim(result, quantization_dimension(result)).reduce(dims(operand, [quantization_dimension(operand) + 1, ..., rank(operand) - 1]), init_values=1, dimensions=[0], body=lambda x, y: x * y) = reduce(dims(result, [quantization_dimension(result) + 1, ..., rank(result) - 1]), init_values=1, dimensions=[0], body=lambda x, y: x * y).
- (C4)
size(output_shape) = rank(result).
Contoh
// %operand: [[1, 2, 3], [4, 5, 6]]
// %output_shape: [3, 2]
%result = "stablehlo.dynamic_reshape"(%operand, %output_shape) : (tensor<2x3xi64>, tensor<2xi64>) -> tensor<3x2xi64>
// %result: [[1, 2], [3, 4], [5, 6]]
dynamic_slice
Semantik
Mengekstrak slice dari operand menggunakan indeks awal yang dihitung secara dinamis
dan menghasilkan tensor result. start_indices berisi indeks awal
slice untuk setiap dimensi yang dapat disesuaikan, dan slice_sizes
berisi ukuran slice untuk setiap dimensi. Secara lebih formal,
result[result_index] = operand[operand_index] dengan:
adjusted_start_indices = clamp(0, start_indices, shape(operand) - slice_sizes).operand_index = adjusted_start_indices + result_index.
Input
| Label | Nama | Jenis | Batasan |
|---|---|---|---|
| (I1) | operand |
tensor atau tensor terkuantisasi per tensor | (C1), (C2), (C4) |
| (I2) | start_indices |
jumlah variadik tensor 0 dimensi berjenis bilangan bulat | (C2), (C3) |
| (I3) | slice_sizes |
Konstanta tensor 1 dimensi berjenis si64 |
(C2), (C4), (C5) |
Output
| Nama | Jenis | Batasan |
|---|---|---|
result |
tensor atau tensor terkuantisasi per tensor | (C1), (C5) |
Batasan
- (C1)
element_type(operand) = element_type(result). - (C2)
size(start_indices) = size(slice_sizes) = rank(operand). - (C3)
same(type(start_indices...)). - (C4)
0 <= slice_sizes <= shape(operand). - (C5)
shape(result) = slice_sizes.
Contoh
// %operand: [
// [0, 0, 1, 1],
// [0, 0, 1, 1],
// [0, 0, 0, 0],
// [0, 0, 0, 0]
// ]
// %start_indices0: -1
// %start_indices1: 3
%result = "stablehlo.dynamic_slice"(%operand, %start_indices0, %start_indices1) {
slice_sizes = array<i64: 2, 2>
} : (tensor<4x4xi32>, tensor<i64>, tensor<i64>) -> tensor<2x2xi32>
// %result: [
// [1, 1],
// [1, 1]
// ]
dynamic_update_slice
Semantik
Menghasilkan tensor result yang sama dengan tensor operand, kecuali slice yang dimulai dari start_indices diperbarui dengan nilai dalam update.
Secara lebih formal, result[result_index] didefinisikan sebagai:
update[update_index]jika0 <= update_index < shape(update)di mana:adjusted_start_indices = clamp(0, start_indices, shape(operand) - shape(update)).update_index = result_index - adjusted_start_indices.
operand[result_index]jika tidak.
Input
| Label | Nama | Jenis | Batasan |
|---|---|---|---|
| (I1) | operand |
tensor atau tensor terkuantisasi per tensor | (C1-C4), (C6) |
| (I2) | update |
tensor atau tensor terkuantisasi per tensor | (C2), (C3), (C6) |
| (I3) | start_indices |
jumlah variadik tensor 0 dimensi berjenis bilangan bulat | (C4), (C5) |
Output
| Nama | Jenis | Batasan |
|---|---|---|
result |
tensor atau tensor terkuantisasi per tensor | (C1) |
Batasan
- (C1)
type(operand) = type(result). - (C2)
element_type(update) = element_type(operand). - (C3)
rank(update) = rank(operand). - (C4)
size(start_indices) = rank(operand). - (C5)
same(type(start_indices...)). - (C6)
0 <= shape(update) <= shape(operand).
Contoh
// %operand: [
// [1, 1, 0, 0],
// [1, 1, 0, 0],
// [1, 1, 1, 1],
// [1, 1, 1, 1]
// ]
// %update: [
// [1, 1],
// [1, 1]
// ]
// %start_indices0: -1
// %start_indices1: 3
%result = "stablehlo.dynamic_update_slice"(%operand, %update, %start_indices0, %start_indices1)
: (tensor<4x4xi32>, tensor<2x2xi32>, tensor<i64>, tensor<i64>) -> tensor<4x4xi32>
// %result: [
// [1, 1, 1, 1],
// [1, 1, 1, 1],
// [1, 1, 1, 1],
// [1, 1, 1, 1]
// ]
berpangkat
Semantik
Melakukan operasi eksponensial per elemen pada tensor operand dan menghasilkan tensor result. Bergantung pada jenis elemen, lakukan tindakan berikut:
- Untuk float:
expdari IEEE-754. - Untuk bilangan kompleks: eksponensial kompleks.
- Untuk jenis yang dikuantisasi:
dequantize_op_quantize(exponential, operand, type(result)).
Input
| Label | Nama | Jenis | Batasan |
|---|---|---|---|
| (I1) | operand |
tensor jenis floating-point atau kompleks atau tensor terkuantisasi per-tensor | (C1) |
Output
| Nama | Jenis | Batasan |
|---|---|---|
result |
tensor jenis floating-point atau kompleks atau tensor terkuantisasi per-tensor | (C1) |
Batasan
- (C1)
baseline_type(operand) = baseline_type(result).
Contoh
// %operand: [[0.0, 1.0], [2.0, 3.0]]
%result = "stablehlo.exponential"(%operand) : (tensor<2x2xf64>) -> tensor<2x2xf64>
// %result: [[1.0, 2.7182818284590451], [7.3890560989306504, 20.085536923187668]]
exponential_minus_one
Semantik
Melakukan operasi eksponensial minus satu per elemen pada tensor operand dan
menghasilkan tensor result. Bergantung pada jenis elemen, lakukan tindakan berikut:
- Untuk float:
expm1dari IEEE-754. - Untuk bilangan kompleks: eksponensial kompleks dikurangi satu.
- Untuk jenis yang dikuantisasi:
dequantize_op_quantize(exponential_minus_one, operand, type(result)).
Input
| Label | Nama | Jenis | Batasan |
|---|---|---|---|
| (I1) | operand |
tensor jenis floating-point atau kompleks atau tensor terkuantisasi per-tensor | (C1) |
Output
| Nama | Jenis | Batasan |
|---|---|---|
result |
tensor jenis floating-point atau kompleks atau tensor terkuantisasi per-tensor | (C1) |
Batasan
- (C1)
baseline_type(operand) = baseline_type(result).
Contoh
// %operand: [0.0, 1.0]
%result = "stablehlo.exponential_minus_one"(%operand) : (tensor<2xf64>) -> tensor<2xf64>
// %result: [0.0, 1.71828187]
fft
Semantik
Melakukan transformasi Fourier maju dan invers untuk input/output bilangan real dan kompleks.
fft_type adalah salah satu dari berikut ini:
FFT: Meneruskan FFT kompleks ke kompleks.IFFT: FFT kompleks-ke-kompleks invers.RFFT: FFT real ke kompleks.IRFFT: FFT real-to-complex terbalik (yaitu mengambil bilangan kompleks, menampilkan bilangan real).
Secara lebih formal, mengingat fungsi fft yang menggunakan tensor 1 dimensi dari
jenis kompleks sebagai input, menghasilkan tensor 1 dimensi dari jenis yang sama sebagai
output dan menghitung transformasi Fourier diskret:
Untuk fft_type = FFT, result ditentukan sebagai hasil akhir dari serangkaian komputasi L
dengan L = size(fft_length). Misalnya, untuk L = 3:
result1[i0, ..., :] = fft(operand[i0, ..., :]).result2[i0, ..., :, iR-1] = fft(result1[i0, ..., :, iR-1]).result[i0, ..., :, iR-2, iR-1] = fft(result2[i0, ..., :, iR-2, iR-1]).
Selain itu, mengingat fungsi ifft yang memiliki tanda tangan jenis yang sama dan
menghitung inversi fft:
Untuk fft_type = IFFT, result ditentukan sebagai kebalikan dari komputasi
untuk fft_type = FFT. Misalnya, untuk L = 3:
result1[i0, ..., :, iR-2, iR-1] = ifft(operand[i0, ..., :, iR-2, iR-1]).result2[i0, ..., :, iR-1] = ifft(result1[i0, ..., :, iR-1]).result[i0, ..., :] = ifft(result2[i0, ..., :]).
Selain itu, mengingat fungsi rfft yang menggunakan tensor 1 dimensi dari
jenis floating point, menghasilkan tensor 1 dimensi dari jenis kompleks dengan
semantik floating point yang sama dan berfungsi sebagai berikut:
rfft(real_operand) = truncated_resultdi manacomplex_operand... = (real_operand..., 0.0).complex_result = fft(complex_operand).truncated_result = complex_result[:(rank(complex_result) / 2 + 1)].
(Saat transformasi Fourier diskret dihitung untuk operan nyata, N/2 + 1 elemen pertama dari hasil secara jelas menentukan hasil lainnya, sehingga hasil rfft dipangkas untuk menghindari penghitungan elemen yang berlebihan).
Untuk fft_type = RFFT, result ditentukan sebagai hasil akhir dari serangkaian komputasi L
dengan L = size(fft_length). Misalnya, untuk L = 3:
result1[i0, ..., :] = rfft(operand[i0, ..., :]).result2[i0, ..., :, iR-1] = fft(result1[i0, ..., :, iR-1]).result[i0, ..., :, iR-2, iR-1] = fft(result2[i0, ..., :, iR-2, iR-1]).
Terakhir, mengingat fungsi irfft yang memiliki tanda tangan jenis yang sama dan
menghitung inversi rfft:
Untuk fft_type = IRFFT, result ditentukan sebagai kebalikan dari komputasi
untuk fft_type = RFFT. Misalnya, untuk L = 3:
result1[i0, ..., :, iR-2, iR-1] = ifft(operand[i0, ..., :, iR-2, iR-1]).result2[i0, ..., :, iR-1] = ifft(result1[i0, ..., :, iR-1]).result[i0, ..., :] = irfft(result2[i0, ..., :]).
Input
| Label | Nama | Jenis | Batasan |
|---|---|---|---|
| (I1) | operand |
tensor jenis floating point atau kompleks | (C1), (C2), (C4), (C5) |
| (I2) | fft_type |
enum FFT, IFFT, RFFT, dan IRFFT |
(C2), (C5) |
| (I3) | fft_length |
Konstanta tensor 1 dimensi berjenis si64 |
(C1), (C3), (C4) |
Output
| Nama | Jenis | Batasan |
|---|---|---|
result |
tensor jenis floating point atau kompleks | (C2), (C4), (C5) |
Batasan
- (C1)
size(fft_length) <= rank(operand). - (C2) Hubungan antara jenis elemen
operanddanresultbervariasi:- Jika
fft_type = FFT,element_type(operand), danelement_type(result)memiliki jenis kompleks yang sama. - Jika
fft_type = IFFT,element_type(operand), danelement_type(result)memiliki jenis kompleks yang sama. - Jika
fft_type = RFFT,element_type(operand)adalah jenis floating point danelement_type(result)adalah jenis kompleks dengan semantik floating point yang sama. - Jika
fft_type = IRFFT,element_type(operand)adalah jenis kompleks danelement_type(result)adalah jenis floating point dengan semantik floating point yang sama.
- Jika
- (C3)
1 <= size(fft_length) <= 3. - (C4) Jika di antara
operanddanresult, ada tensorrealdari jenis floating-point, makashape(real)[-size(fft_length):] = fft_length. - (C5)
shape(result) = shape(operand)kecuali:- Jika
fft_type = RFFT,dim(result, -1) = dim(operand, -1) = 0 ? 0 : dim(operand, -1) / 2 + 1. - Jika
fft_type = IRFFT,dim(operand, -1) = dim(result, -1) = 0 ? 0 : dim(result, -1) / 2 + 1.
- Jika
Contoh
// %operand: [(1.0, 0.0), (0.0, 0.0), (0.0, 0.0), (0.0, 0.0)]
%result = "stablehlo.fft"(%operand) {
fft_type = #stablehlo<fft_type FFT>,
fft_length = array<i64: 4>
} : (tensor<4xcomplex<f32>>) -> tensor<4xcomplex<f32>>
// %result: [(1.0, 0.0), (1.0, 0.0), (1.0, 0.0), (1.0, 0.0)]
lantai
Semantik
Melakukan floor per elemen tensor operand dan menghasilkan tensor result.
Menerapkan operasi roundToIntegralTowardNegative dari spesifikasi IEEE-754. Untuk jenis yang dikuantisasi, melakukan
dequantize_op_quantize(floor, operand, type(result)).
Input
| Label | Nama | Jenis | Batasan |
|---|---|---|---|
| (I1) | operand |
tensor jenis floating point atau tensor terkuantisasi per-tensor | (C1) |
Output
| Nama | Jenis | Batasan |
|---|---|---|
result |
tensor jenis floating point atau tensor terkuantisasi per-tensor | (C1) |
Batasan
- (C1)
baseline_type(operand) = baseline_type(result).
Contoh
// %operand: [-0.8166, -0.2530, 0.2530, 0.8166, 2.0]
%result = "stablehlo.floor"(%operand) : (tensor<5xf32>) -> tensor<5xf32>
// %result: [-1.0, -1.0, 0.0, 0.0, 2.0]
mengumpulkan
Semantik
Mengumpulkan slice dari tensor operand dari offset yang ditentukan dalam start_indices
dan menghasilkan tensor result.
Diagram berikut menunjukkan cara elemen di result dipetakan ke elemen di
operand menggunakan contoh konkret. Diagram memilih beberapa contoh indeks result
dan menjelaskan secara mendetail indeks operand yang sesuai dengan indeks tersebut.
Secara lebih formal, result[result_index] = operand[operand_index] dengan:
batch_dims = [d for d in axes(result) and d not in offset_dims].batch_index = result_index[batch_dims...].start_indexditentukan sebagai:start_indices[bi0, ..., :, ..., biN]denganbiadalah elemen individual dalambatch_indexdan:disisipkan pada indeksindex_vector_dim, jikaindex_vector_dim<rank(start_indices).[start_indices[batch_index]]jika tidak.
- Untuk
d_operanddiaxes(operand),full_start_index[d_operand] = clamp(start_index[d_start], 0, dim(operand, d_operand) - slice_sizes[d_operand])ifd_operand = start_index_map[d_start].full_start_index[d_operand] = 0jika tidak.
- Untuk
d_operanddiaxes(operand),full_batching_index[d_operand] = batch_index[d_start - (d_start < index_vector_dim ? 0 : 1)]jikad_operand = operand_batching_dims[i_batching]dand_start = start_indices_batching_dims[i_batching].full_batching_index[d_operand] = 0jika tidak.
offset_index = result_index[offset_dims...].full_offset_index = [oi0, ..., 0, ..., oiN]denganoiadalah masing-masing elemen dalamoffset_index, dan0disisipkan pada indeks daricollapsed_slice_dimsdanoperand_batching_dims.operand_index = full_start_index + full_batching_index + full_offset_index.
Jika indices_are_sorted adalah true, implementasi dapat mengasumsikan bahwa
start_indices diurutkan berdasarkan start_index_map, jika tidak, perilaku
tidak ditentukan. Secara lebih formal, untuk semua i1 < i2 dari indices(result),
full_start_index(i1) <= full_start_index(i2).
Input
| Label | Nama | Jenis | Batasan |
|---|---|---|---|
| (I1) | operand |
tensor atau tensor terkuantisasi per tensor | (C1), (C8), (C11), (C17), (C19-C21), (C23) |
| (I2) | start_indices |
tensor jenis bilangan bulat | (C2-C3), (C14), (C17), (C22) |
| (I3) | offset_dims |
Konstanta tensor 1 dimensi berjenis si64 |
(C1), (C4-C5), (C22) |
| (I4) | collapsed_slice_dims |
Konstanta tensor 1 dimensi berjenis si64 |
(C1), (C6-C9), (C22) |
| (I5) | operand_batching_dims |
Konstanta tensor 1 dimensi berjenis si64 |
(C1), (C6), (C10-C12), (C16-C18), (C22) |
| (I6) | start_indices_batching_dims |
Konstanta tensor 1 dimensi berjenis si64 |
(C13-C17) |
| (I7) | start_index_map |
Konstanta tensor 1 dimensi berjenis si64 |
(C3), (C18-C19) |
| (I8) | index_vector_dim |
konstanta jenis si64 |
(C2-C3), (C15), (C22) |
| (I9) | slice_sizes |
Konstanta tensor 1 dimensi berjenis si64 |
(C9), (C12), (C20-C22) |
| (I10) | indices_are_sorted |
konstanta jenis i1 |
Output
| Nama | Jenis | Batasan |
|---|---|---|
result |
tensor atau tensor terkuantisasi per tensor | (C5), (C22-C23) |
Batasan
- (C1)
rank(operand) = size(offset_dims) + size(collapsed_slice_dims) + size(operand_batching_dims). - (C2)
0 <= index_vector_dim <= rank(start_indices). - (C3)
size(start_index_map) = index_vector_dim < rank(start_indices) ? dim(start_indices, index_vector_dim) : 1. - (C4)
is_unique(offset_dims) and is_sorted(offset_dims). - (C5)
0 <= offset_dims < rank(result). - (C6)
is_unique(concatenate(collapsed_slice_dims, operand_batching_dims)) - (C7)
is_sorted(collapsed_slice_dims). - (C8)
0 <= collapsed_slice_dims < rank(operand). - (C9)
slice_sizes[collapsed_slice_dims...] <= 1. - (C10)
is_sorted(operand_batching_dims). - (C11)
0 <= operand_batching_dims < rank(operand). - (C12)
slice_sizes[operand_batching_dims...] <= 1. - (C13)
is_unique(start_indices_batching_dims). - (C14)
0 <= start_indices_batching_dims < rank(start_indices). - (C15)
index_vector_dim not in start_indices_batching_dims. - (C16)
size(operand_batching_dims) == size(start_indices_batching_dims). - (C17)
dim(operand, operand_batching_dims...) = dim(start_indices, start_indices_batching_dims...). - (C18)
is_unique(concatenate(start_index_map, operand_batching_dims)). - (C19)
0 <= start_index_map < rank(operand). - (C20)
size(slice_sizes) = rank(operand). - (C21)
0 <= slice_sizes <= shape(operand). - (C22)
shape(result) = combine(batch_dim_sizes, offset_dim_sizes)dengan:batch_dim_sizes = shape(start_indices)kecuali ukuran dimensi daristart_indicesyang sesuai denganindex_vector_dimtidak disertakan.offset_dim_sizes = slice_sizeskecuali ukuran dimensi dalamslice_sizesyang sesuai dengancollapsed_slice_dimsdanoperand_batching_dimstidak disertakan.combinemenempatkanbatch_dim_sizespada sumbu yang sesuai denganbatch_dimsdanoffset_dim_sizespada sumbu yang sesuai denganoffset_dims.
- (C23)
element_type(operand) = element_type(result).
Contoh
// %operand: [
// [
// [[1, 2], [3, 4], [5, 6], [7, 8]],
// [[9, 10],[11, 12], [13, 14], [15, 16]],
// [[17, 18], [19, 20], [21, 22], [23, 24]]
// ],
// [
// [[25, 26], [27, 28], [29, 30], [31, 32]],
// [[33, 34], [35, 36], [37, 38], [39, 40]],
// [[41, 42], [43, 44], [45, 46], [47, 48]]
// ]
// ]
// %start_indices: [
// [
// [[0, 0], [1, 0], [2, 1]],
// [[0, 1], [1, 1], [0, 9]]
// ],
// [
// [[0, 0], [2, 1], [2, 2]],
// [[1, 2], [0, 1], [1, 0]]
// ]
// ]
%result = "stablehlo.gather"(%operand, %start_indices) {
dimension_numbers = #stablehlo.gather<
offset_dims = [3, 4],
collapsed_slice_dims = [1],
operand_batching_dims = [0],
start_indices_batching_dims = [1],
start_index_map = [2, 1],
index_vector_dim = 3>,
slice_sizes = array<i64: 1, 1, 2, 2>,
indices_are_sorted = false
} : (tensor<2x3x4x2xi32>, tensor<2x2x3x2xi64>) -> tensor<2x2x3x2x2xi32>
// %result: [
// [
// [
// [[1, 2], [3, 4]],
// [[3, 4], [5, 6]],
// [[13, 14], [15, 16]]
// ],
// [
// [[33, 34], [35, 36]],
// [[35, 36], [37, 38]],
// [[41, 42], [43, 44]]
// ]
// ],
// [
// [
// [[1, 2], [3, 4]],
// [[13, 14], [15, 16]],
// [[21, 22], [23, 24]]
// ],
// [
// [[43, 44], [45, 46]],
// [[33, 34], [35, 36]],
// [[27, 28], [29, 30]]
// ]
// ]
// ]
get_dimension_size
Semantik
Menghasilkan ukuran dimension tertentu dari operand. Secara lebih formal,
result = dim(operand, dimension). Semantik hanya berkaitan dengan komponen
bentuk dari jenis. Jenis elemen bisa berupa apa saja.
Input
| Label | Nama | Jenis | Batasan |
|---|---|---|---|
| (I1) | operand |
tensor atau tensor terkuantisasi | (C1) |
| (I2) | dimension |
konstanta jenis si64 |
(C1) |
Output
| Nama | Jenis |
|---|---|
result |
Tensor 0 dimensi berjenis si32 |
Batasan
- (C1)
0 <= dimension < rank(operand).
Contoh
// %operand: [[1, 2, 3], [4, 5, 6]]
%result = "stablehlo.get_dimension_size"(%operand) {
dimension = 1 : i64
} : (tensor<2x3xi64>) -> tensor<i32>
// %result: 3
get_tuple_element
Semantik
Mengekstrak elemen pada posisi index dari tuple operand dan menghasilkan
result. Secara lebih formal, result = operand[index].
Input
| Label | Nama | Jenis | Batasan |
|---|---|---|---|
| (I1) | operand |
tuple | (C1), (C2) |
| (I2) | index |
konstanta jenis si32 |
(C1), (C2) |
Output
| Nama | Jenis | Batasan |
|---|---|---|
result |
nilai apa pun | (C2) |
Batasan
- (C1)
0 <= index < size(operand). - (C2)
type(result) = tuple_element_types(operand)[index].
Contoh
// %operand: ([1.0, 2.0], (3))
%result = "stablehlo.get_tuple_element"(%operand) <{index = 0 : i32}> : (tuple<tensor<2xf64>, tuple<tensor<i64>>>) -> tensor<2xf64>
// %result: [1.0, 2.0]
jika
Semantik
Menghasilkan output dari eksekusi tepat satu fungsi dari true_branch atau
false_branch, bergantung pada nilai pred. Secara lebih formal, result =
pred ? true_branch() : false_branch().
Input
| Label | Nama | Jenis | Batasan |
|---|---|---|---|
| (I1) | pred |
Tensor 0 dimensi berjenis i1 |
|
| (I2) | true_branch |
fungsi | (C1-C3) |
| (I3) | false_branch |
fungsi | (C1), (C2) |
Output
| Nama | Jenis | Batasan |
|---|---|---|
results |
sejumlah tensor, tensor terkuantisasi, atau token variadik | (C3) |
Batasan
- (C1)
input_types(true_branch) = input_types(false_branch) = []. - (C2)
output_types(true_branch) = output_types(false_branch). - (C3)
type(results...) = output_types(true_branch).
Contoh
// %result_true_branch: 10
// %result_false_branch: 11
// %pred: true
%result = "stablehlo.if"(%pred) ({
"stablehlo.return"(%result_true_branch) : (tensor<i32>) -> ()
}, {
"stablehlo.return"(%result_false_branch) : (tensor<i32>) -> ()
}) : (tensor<i1>) -> tensor<i32>
// %result: 10
imag
Semantik
Mengekstrak bagian imajiner, per elemen, dari operand dan menghasilkan tensor
result. Secara lebih formal, untuk setiap elemen x:
imag(x) = is_complex(x) ? imaginary_part(x) :
constant(0, element_type(result)).
Input
| Label | Nama | Jenis | Batasan |
|---|---|---|---|
| (I1) | operand |
tensor jenis floating point atau kompleks | (C1), (C2) |
Output
| Nama | Jenis | Batasan |
|---|---|---|
result |
tensor jenis floating-point | (C1), (C2) |
Batasan
- (C1)
shape(result) = shape(operand). - (C2)
element_type(result)ditentukan sebagai:complex_element_type(element_type(operand))ifis_complex(operand).element_type(operand)jika tidak.
Contoh
// %operand: [(1.0, 2.0), (3.0, 4.0)]
%result = "stablehlo.imag"(%operand) : (tensor<2xcomplex<f32>>) -> tensor<2xf32>
// %result: [2.0, 4.0]
dalam feed
Semantik
Membaca data dari feed dan menghasilkan results.
Semantik infeed_config ditentukan oleh implementasi.
results terdiri dari nilai payload yang muncul terlebih dahulu dan token yang muncul
terakhir. Pada masa mendatang, kami berencana membagi payload dan token menjadi dua
output terpisah untuk meningkatkan kejelasan
(#670).
Input
| Label | Nama | Jenis |
|---|---|---|
| (I1) | token |
token |
| (I2) | infeed_config |
konstanta jenis string |
Output
| Nama | Jenis | Batasan |
|---|---|---|
results |
sejumlah tensor, tensor terkuantisasi, atau token variadik | (C1-C3) |
Batasan
- (C1)
0 < size(results). - (C2)
is_empty(result[:-1])atauis_tensor(type(results[:-1])). - (C3)
is_token(type(results[-1])).
Contoh
// %token: !stablehlo.token
// infeed_queue[0]: [[1, 2], [3, 4]]
// infeed_queue[1]: [[5, 6], [7, 8]]
%results0:2 = "stablehlo.infeed"(%token) {
infeed_config = ""
} : (!stablehlo.token) -> (tensor<2x2xi64>, !stablehlo.token)
// results0#0: [[1, 2], [3, 4]]
%results1:2 = "stablehlo.infeed"(%token) {
infeed_config = ""
} : (!stablehlo.token) -> (tensor<2x2xi64>, !stablehlo.token)
// results1#0: [[5, 6], [7, 8]]
iota
Semantik
Mengisi tensor output dengan nilai dalam urutan menaik yang dimulai dari nol
di sepanjang dimensi iota_dimension. Secara lebih formal,
output[output_index] = constant(is_quantized(output) ?
quantize(output_index[iota_dimension], element_type(output)) :
output_index[iota_dimension], element_type(output)).
Input
| Label | Nama | Jenis | Batasan |
|---|---|---|---|
| (I1) | iota_dimension |
si64 |
(C1) |
Output
| Nama | Jenis | Batasan |
|---|---|---|
output |
tensor bilangan bulat, floating point, atau jenis kompleks atau tensor terkuantisasi per-tensor | (C1) |
Batasan
- (C1)
0 <= iota_dimension < rank(output).
Contoh
%output = "stablehlo.iota"() {
iota_dimension = 0 : i64
} : () -> tensor<4x5xi32>
// %output: [
// [0, 0, 0, 0, 0],
// [1, 1, 1, 1, 1],
// [2, 2, 2, 2, 2],
// [3, 3, 3, 3, 3]
// ]
%output = "stablehlo.iota"() {
iota_dimension = 1 : i64
} : () -> tensor<4x5xi32>
// %output: [
// [0, 1, 2, 3, 4],
// [0, 1, 2, 3, 4],
// [0, 1, 2, 3, 4],
// [0, 1, 2, 3, 4]
// ]
is_finite
Semantik
Melakukan pemeriksaan per elemen apakah nilai dalam x terbatas (yaitu bukan
+Inf, -Inf, atau NaN) dan menghasilkan tensor y. Menerapkan operasi isFinite
dari spesifikasi IEEE-754. Untuk jenis yang dikuantisasi, hasilnya
selalu true.
Input
| Label | Nama | Jenis | Batasan |
|---|---|---|---|
| (I1) | x |
tensor jenis floating point atau tensor terkuantisasi per-tensor | (C1) |
Output
| Nama | Jenis | Batasan |
|---|---|---|
y |
tensor jenis boolean | (C1) |
Batasan
- (C1)
shape(x) = shape(y).
Contoh
// Logical values: -Inf, +Inf, NaN, ...
// %x: [0xFFF0000000000000, 0x7FF0000000000000, 0x7FF8000000000000, -10.0, -0.0, 0.0, 10.0]
%y = "stablehlo.is_finite"(%x) : (tensor<7xf64) -> tensor<7xi1>
// %y: [false, false, false, true, true, true, true]
log
Semantik
Melakukan operasi logaritma per elemen pada tensor operand dan menghasilkan tensor result. Bergantung pada jenis elemen, lakukan tindakan berikut:
- Untuk float:
logdari IEEE-754. - Untuk bilangan kompleks: logaritma kompleks.
- Untuk jenis yang dikuantisasi:
dequantize_op_quantize(log, operand, type(result)).
Input
| Label | Nama | Jenis | Batasan |
|---|---|---|---|
| (I1) | operand |
tensor jenis floating-point atau kompleks atau tensor terkuantisasi per-tensor | (C1) |
Output
| Nama | Jenis | Batasan |
|---|---|---|
result |
tensor jenis floating-point atau kompleks atau tensor terkuantisasi per-tensor | (C1) |
Batasan
- (C1)
baseline_type(operand) = baseline_type(result).
Contoh
// %operand: [[1.0, 2.0], [3.0, 4.0]]
%result = "stablehlo.log"(%operand) : (tensor<2x2xf64>) -> tensor<2x2xf64>
// %result: [[0.0, 0.69314718055994529], [1.0986122886681098, 1.3862943611198906]]
log_plus_one
Semantik
Melakukan operasi logaritma plus satu per elemen pada tensor operand dan
menghasilkan tensor result. Bergantung pada jenis elemen, lakukan tindakan berikut:
- Untuk float:
logp1dari IEEE-754. - Untuk bilangan kompleks:
complex(log(hypot(real(x) + 1, imag(x))), atan2(imag(x), real(x) + 1)) - Untuk jenis yang dikuantisasi:
dequantize_op_quantize(log_plus_one, operand, type(result)).
Input
| Label | Nama | Jenis | Batasan |
|---|---|---|---|
| (I1) | operand |
tensor jenis floating-point atau kompleks atau tensor terkuantisasi per-tensor | (C1) |
Output
| Nama | Jenis | Batasan |
|---|---|---|
result |
tensor jenis floating-point atau kompleks atau tensor terkuantisasi per-tensor | (C1) |
Batasan
- (C1)
baseline_type(operand) = baseline_type(result).
Contoh
// %operand: [0.0, -0.999, 7.0, 6.38905621, 15.0]
%result = "stablehlo.log_plus_one"(%operand) : (tensor<5xf64>) -> tensor<5xf64>
// %result: [0.0, -6.90776825, 2.07944155, 2.0, 2.77258873]
logistik
Semantik
Melakukan operasi logistik per elemen pada tensor operand dan menghasilkan tensor
result. Bergantung pada jenis elemen, lakukan tindakan berikut:
- Untuk float:
division(1, addition(1, exp(-x)))dari IEEE-754. - Untuk bilangan kompleks: logistik kompleks.
- Untuk jenis yang dikuantisasi:
dequantize_op_quantize(logistic, operand, type(result)).
Input
| Label | Nama | Jenis | Batasan |
|---|---|---|---|
| (I1) | operand |
tensor jenis floating-point atau kompleks atau tensor terkuantisasi per-tensor | (C1) |
Output
| Nama | Jenis | Batasan |
|---|---|---|
result |
tensor jenis floating-point atau kompleks atau tensor terkuantisasi per-tensor | (C1) |
Batasan
- (C1)
baseline_type(operand) = baseline_type(result).
Contoh
// %operand: [[0.0, 1.0], [2.0, 3.0]]
%result = "stablehlo.logistic"(%operand) : (tensor<2x2xf64>) -> tensor<2x2xf64>
// %result: [[0.5, 0.73105858], [0.88079708, 0.95257413]]
peta
Semantik
Menerapkan fungsi peta computation ke inputs di sepanjang dimensions dan
menghasilkan tensor result.
Secara lebih formal, result[result_index] = computation(inputs...[result_index]).
Input
| Label | Nama | Jenis | Batasan |
|---|---|---|---|
| (I1) | inputs |
jumlah tensor variadik atau tensor terkuantisasi per tensor | (C1-C4) |
| (I2) | dimensions |
Konstanta tensor 1 dimensi berjenis si64 |
(C3) |
| (I3) | computation |
fungsi | (C4) |
Output
| Nama | Jenis | Batasan |
|---|---|---|
result |
tensor atau tensor terkuantisasi per tensor | (C1), (C4) |
Batasan
- (C1)
shape(inputs...) = shape(result). - (C2)
0 < size(inputs) = N. - (C3)
dimensions = range(rank(inputs[0])). - (C4)
computationmemiliki jenis(tensor<E0>, ..., tensor<EN-1>) -> tensor<E'>denganEi = element_type(inputs[i])danE' = element_type(result).
Contoh
// %input0: [[0, 1], [2, 3]]
// %input1: [[4, 5], [6, 7]]
%result = "stablehlo.map"(%input0, %input1) ({
^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
%0 = stablehlo.multiply %arg0, %arg1 : tensor<i64>
stablehlo.return %0 : tensor<i64>
}) {
dimensions = array<i64: 0, 1>
} : (tensor<2x2xi64>, tensor<2x2xi64>) -> tensor<2x2xi64>
// %result: [[0, 5], [12, 21]]
maksimum
Semantik
Melakukan operasi maks per elemen pada tensor lhs dan rhs serta menghasilkan tensor result. Bergantung pada jenis elemen, lakukan tindakan berikut:
- Untuk boolean: OR logis.
- Untuk bilangan bulat: maksimum bilangan bulat.
- Untuk float:
maximumdari IEEE-754. - Untuk bilangan kompleks: maksimum leksikografis untuk pasangan
(real, imaginary). Menerapkan pengurutan pada bilangan kompleks melibatkan semantik yang mengejutkan, jadi pada masa mendatang kami berencana menghapus dukungan untuk bilangan kompleks untuk operasi ini (#560). - Untuk jenis yang dikuantisasi:
dequantize_op_quantize(maximum, lhs, rhs, type(result)).
Input
| Label | Nama | Jenis | Batasan |
|---|---|---|---|
| (I1) | lhs |
tensor atau tensor terkuantisasi per tensor | (C1) |
| (I2) | rhs |
tensor atau tensor terkuantisasi per tensor | (C1) |
Output
| Nama | Jenis | Batasan |
|---|---|---|
result |
tensor atau tensor terkuantisasi per tensor | (C1) |
Batasan
- (C1)
baseline_type(lhs) = baseline_type(rhs) = baseline_type(result).
Contoh
// %lhs: [[1, 2], [7, 8]]
// %rhs: [[5, 6], [3, 4]]
%result = "stablehlo.maximum"(%lhs, %rhs) : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
// %result: [[5, 6], [7, 8]]
minimum
Semantik
Melakukan operasi min per elemen pada tensor lhs dan rhs serta menghasilkan tensor result. Bergantung pada jenis elemen, lakukan tindakan berikut:
- Untuk boolean: AND logis.
- Untuk bilangan bulat: bilangan bulat minimum.
- Untuk float:
minimumdari IEEE-754. - Untuk bilangan kompleks: minimum leksikografis untuk pasangan
(real, imaginary). Menerapkan pengurutan pada bilangan kompleks melibatkan semantik yang mengejutkan, jadi pada masa mendatang kami berencana menghapus dukungan untuk bilangan kompleks untuk operasi ini (#560). - Untuk jenis yang dikuantisasi:
dequantize_op_quantize(minimum, lhs, rhs, type(result)).
Input
| Label | Nama | Jenis | Batasan |
|---|---|---|---|
| (I1) | lhs |
tensor atau tensor terkuantisasi per tensor | (C1) |
| (I2) | rhs |
tensor atau tensor terkuantisasi per tensor | (C1) |
Output
| Nama | Jenis | Batasan |
|---|---|---|
result |
tensor atau tensor terkuantisasi per tensor | (C1) |
Batasan
- (C1)
baseline_type(lhs) = baseline_type(rhs) = baseline_type(result).
Contoh
// %lhs: [[1, 2], [7, 8]]
// %rhs: [[5, 6], [3, 4]]
%result = "stablehlo.minimum"(%lhs, %rhs) : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
// %result: [[1, 2], [3, 4]]
kali
Semantik
Melakukan perkalian per elemen dari dua tensor lhs dan rhs serta menghasilkan tensor result. Bergantung pada jenis elemen, lakukan tindakan berikut:
- Untuk boolean: AND logis.
- Untuk bilangan bulat: perkalian bilangan bulat.
- Untuk float:
multiplicationdari IEEE-754. - Untuk bilangan kompleks: perkalian kompleks.
- Untuk jenis yang dikuantisasi:
dequantize_op_quantize(multiply, lhs, rhs, type(result)).
Input
| Label | Nama | Jenis | Batasan |
|---|---|---|---|
| (I1) | lhs |
tensor atau tensor terkuantisasi per tensor | (C1) |
| (I2) | rhs |
tensor atau tensor terkuantisasi per tensor | (C1) |
Output
| Nama | Jenis | Batasan |
|---|---|---|
result |
tensor atau tensor terkuantisasi per tensor | (C1) |
Batasan
- (C1)
baseline_type(operand) = baseline_type(result).
Contoh
// %lhs: [[1, 2], [3, 4]]
// %rhs: [[5, 6], [7, 8]]
%result = "stablehlo.multiply"(%lhs, %rhs) : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
// %result: [[5, 12], [21, 32]]
negasi
Semantik
Melakukan negasi per elemen tensor operand dan menghasilkan tensor result. Bergantung pada jenis elemen, lakukan tindakan berikut:
- Untuk bilangan bulat bertanda: negasi bilangan bulat.
- Untuk bilangan bulat tanpa tanda: bitcast ke bilangan bulat bertanda, negasi bilangan bulat, bitcast kembali ke bilangan bulat tanpa tanda.
- Untuk float:
negatedari IEEE-754. - Untuk bilangan kompleks: negasi kompleks.
- Untuk jenis yang dikuantisasi:
dequantize_op_quantize(negate, operand, type(result)).
Input
| Label | Nama | Jenis | Batasan |
|---|---|---|---|
| (I1) | operand |
tensor bilangan bulat, floating point, atau jenis kompleks atau tensor terkuantisasi per-tensor | (C1) |
Output
| Nama | Jenis | Batasan |
|---|---|---|
result |
tensor bilangan bulat, floating point, atau jenis kompleks atau tensor terkuantisasi per-tensor | (C1) |
Batasan
- (C1)
baseline_type(operand) = baseline_type(result).
Contoh
// Negation operation with integer Tensors
// %operand: [0, -2]
%result = "stablehlo.negate"(%operand) : (tensor<2xi32>) -> tensor<2xi32>
// %result: [0, 2]
// Negation operation with with complex tensors
// %operand: (2.5, 0.0)
%result = "stablehlo.negate"(%operand) : (tensor<1xcomplex<f32>>) -> tensor<1xcomplex<f32>>
// %result: [-2.5, -0.0]
tidak
Semantik
Melakukan NOT per elemen tensor operand dan menghasilkan tensor result.
Bergantung pada jenis elemen, lakukan tindakan berikut:
- Untuk boolean: NOT logis.
- Untuk bilangan bulat: bitwise NOT.
Argumen
| Nama | Jenis | Batasan |
|---|---|---|
operand |
tensor jenis boolean atau bilangan bulat | (C1) |
Output
| Nama | Jenis | Batasan |
|---|---|---|
result |
tensor jenis boolean atau bilangan bulat | (C1) |
Batasan
- (C1)
type(operand) = type(result).
Contoh
// Bitwise operation with with integer tensors
// %operand: [[1, 2], [3, 4]]
%result = "stablehlo.not"(%operand) : (tensor<2x2xi32>) -> tensor<2x2xi32>
// %result: [[-2, -3], [-4, -5]]
// Bitwise operation with with boolean tensors
// %operand: [true, false]
%result = "stablehlo.not"(%operand) : (tensor<2xi1>) -> tensor<2xi1>
// %result: [false, true]
optimization_barrier
Semantik
Memastikan bahwa operasi yang menghasilkan operand dieksekusi sebelum operasi apa pun yang bergantung pada result dan mencegah transformasi compiler memindahkan operasi melintasi penghalang. Selain itu, operasinya adalah identitas, yaitu result = operand.
Argumen
| Nama | Jenis | Batasan |
|---|---|---|
operand |
sejumlah tensor variadik, tensor atau token terkuantisasi per tensor | (C1) |
Output
| Nama | Jenis | Batasan |
|---|---|---|
result |
sejumlah tensor variadik, tensor atau token terkuantisasi per tensor | (C1) |
Batasan
- (C1)
type(operand...) = type(result...).
Contoh
// %operand0: 0.0
// %operand1: 1.0
%result0, %result1 = "stablehlo.optimization_barrier"(%operand0, %operand1) : (tensor<f32>, tensor<f32>) -> (tensor<f32>, tensor<f32>)
// %result0: 0.0
// %result1: 1.0
atau
Semantik
Melakukan OR per elemen dari dua tensor lhs dan rhs serta menghasilkan tensor result. Bergantung pada jenis elemen, lakukan tindakan berikut:
- Untuk boolean: OR logis.
- Untuk bilangan bulat: bitwise OR.
Input
| Label | Nama | Jenis | Batasan |
|---|---|---|---|
| (I1) | lhs |
tensor jenis bilangan bulat atau boolean | (C1) |
| (I2) | rhs |
tensor jenis bilangan bulat atau boolean | (C1) |
Output
| Nama | Jenis | Batasan |
|---|---|---|
result |
tensor jenis bilangan bulat atau boolean | (C1) |
Batasan
- (C1)
type(lhs) = type(rhs) = type(result).
Contoh
// Bitwise operation with with integer tensors
// %lhs: [[1, 2], [3, 4]]
// %rhs: [[5, 6], [7, 8]]
%result = "stablehlo.or"(%lhs, %rhs) : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
// %result: [[5, 6], [7, 12]]
// Logical operation with with boolean tensors
// %lhs: [[false, false], [true, true]]
// %rhs: [[false, true], [false, true]]
%result = "stablehlo.or"(%lhs, %rhs) : (tensor<2x2xi1>, tensor<2x2xi1>) -> tensor<2x2xi1>
// %result: [[false, true], [true, true]]
keluar
Semantik
Menulis inputs ke outfeed dan menghasilkan token result.
Semantik outfeed_config ditentukan oleh implementasi.
Input
| Label | Nama | Jenis |
|---|---|---|
| (I1) | inputs |
jumlah tensor atau tensor terkuantisasi yang bervariasi |
| (I2) | token |
token |
| (I3) | outfeed_config |
konstanta jenis string |
Output
| Nama | Jenis |
|---|---|
result |
token |
Contoh
%result = "stablehlo.outfeed"(%input0, %token) {
outfeed_config = ""
} : (tensor<2x2x2xi64>, !stablehlo.token) -> !stablehlo.token
pad
Semantik
Memperluas operand dengan padding di sekitar tensor serta di antara elemen
tensor dengan padding_value yang diberikan.
edge_padding_low dan edge_padding_high menentukan jumlah padding yang ditambahkan di ujung bawah (di samping indeks 0) dan ujung atas (di samping indeks tertinggi) dari setiap dimensi. Jumlah padding dapat negatif, dengan
nilai absolut padding negatif menunjukkan jumlah elemen yang akan dihapus
dari dimensi yang ditentukan.
interior_padding menentukan jumlah padding yang ditambahkan di antara dua
elemen dalam setiap dimensi yang tidak boleh negatif. Padding interior terjadi
sebelum padding tepi sehingga padding tepi negatif akan menghapus elemen dari
operand yang diberi padding interior.
Secara lebih formal, result[result_index] didefinisikan sebagai:
operand[operand_index]ifresult_index = edge_padding_low + operand_index * (interior_padding + 1).padding_valuejika tidak.
Input
| Label | Nama | Jenis | Batasan |
|---|---|---|---|
| (I1) | operand |
tensor atau tensor terkuantisasi per tensor | (C1), (C2), (C4) |
| (I2) | padding_value |
Tensor 0 dimensi atau tensor terkuantisasi per tensor | (C1) |
| (I3) | edge_padding_low |
Konstanta tensor 1 dimensi berjenis si64 |
(C1), (C4) |
| (I4) | edge_padding_high |
Konstanta tensor 1 dimensi berjenis si64 |
(C1), (C4) |
| (I5) | interior_padding |
Konstanta tensor 1 dimensi berjenis si64 |
(C2-C4) |
Output
| Nama | Jenis | Batasan |
|---|---|---|
result |
tensor atau tensor terkuantisasi per tensor | (C3-C6) |
Batasan
- (C1)
element_type(operand) = element_type(padding_value) = element_type(result). - (C2)
size(edge_padding_low) = size(edge_padding_high) = size(interior_padding) = rank(operand). - (C3)
0 <= interior_padding. - (C4)
shape(result) = shape(operand) + edge_padding_low + max(shape(operand) - 1, 0) * interior_padding + edge_padding_high.
Contoh
// %operand: [
// [1, 2, 3],
// [4, 5, 6]
// ]
// %padding_value: 0
%result = "stablehlo.pad"(%operand, %padding_value) {
edge_padding_low = array<i64: 0, 1>,
edge_padding_high = array<i64: 2, 1>,
interior_padding = array<i64: 1, 2>
} : (tensor<2x3xi32>, tensor<i32>) -> tensor<5x9xi32>
// %result: [
// [0, 1, 0, 0, 2, 0, 0, 3, 0],
// [0, 0, 0, 0, 0, 0, 0, 0, 0],
// [0, 4, 0, 0, 5, 0, 0, 6, 0],
// [0, 0, 0, 0, 0, 0, 0, 0, 0],
// [0, 0, 0, 0, 0, 0, 0, 0, 0]
// ]
partition_id
Semantik
Menghasilkan partition_id dari proses saat ini.
Output
| Nama | Jenis |
|---|---|
result |
Tensor 0 dimensi berjenis ui32 |
Contoh
%result = "stablehlo.partition_id"() : () -> tensor<ui32>
popcnt
Semantik
Melakukan penghitungan per elemen jumlah bit yang ditetapkan dalam tensor operand
dan menghasilkan tensor result.
Input
| Label | Nama | Jenis | Batasan |
|---|---|---|---|
| (I1) | operand |
tensor jenis bilangan bulat | (C1) |
Output
| Nama | Jenis | Batasan |
|---|---|---|
result |
tensor jenis bilangan bulat | (C1) |
Batasan
- (C1)
type(operand) = type(result).
Contoh
// %operand: [0, 1, 2, 127]
%result = "stablehlo.popcnt"(%operand) : (tensor<4xi64>) -> tensor<4xi64>
// %result: [0, 1, 1, 7]
daya
Semantik
Melakukan eksponensiasi per elemen tensor lhs dengan tensor rhs dan
menghasilkan tensor result. Bergantung pada jenis elemen, lakukan tindakan berikut:
- Untuk bilangan bulat: eksponensiasi bilangan bulat.
- Untuk float:
powdari IEEE-754. - Untuk bilangan kompleks: eksponensiasi kompleks.
- Untuk jenis yang dikuantisasi:
dequantize_op_quantize(power, lhs, rhs, type(result)).
Input
| Label | Nama | Jenis | Batasan |
|---|---|---|---|
| (I1) | lhs |
tensor bilangan bulat, floating point, atau jenis kompleks atau tensor terkuantisasi per-tensor | (C1) |
| (I2) | rhs |
tensor bilangan bulat, floating point, atau jenis kompleks atau tensor terkuantisasi per-tensor | (C1) |
Output
| Nama | Jenis | Batasan |
|---|---|---|
result |
tensor bilangan bulat, floating point, atau jenis kompleks atau tensor terkuantisasi per-tensor | (C1) |
Batasan
- (C1)
baseline_type(operand) = baseline_type(result).
Contoh
// %lhs: [-2.0, -0.0, -36.0, 5.0, 3.0, 10000.0]
// %rhs: [2.0, 2.0, 1.1, 2.0, -1.0, 10.0]
%result = "stablehlo.power"(%lhs, %rhs) : (tensor<6xf64>, tensor<6xf64>) -> tensor<6xf64>
// %result: [4.0, 0.0, -nan, 25.0, 0.333333343, inf]
real
Semantik
Mengekstrak bagian real, per elemen, dari operand dan menghasilkan tensor result. Secara lebih formal, untuk setiap elemen x:
real(x) = is_complex(x) ? real_part(x) : x.
Input
| Label | Nama | Jenis | Batasan |
|---|---|---|---|
| (I1) | operand |
tensor jenis floating point atau kompleks | (C1), (C2) |
Output
| Nama | Jenis | Batasan |
|---|---|---|
result |
tensor jenis floating-point | (C1), (C2) |
Batasan
- (C1)
shape(result) = shape(operand). - (C2)
element_type(result)ditentukan sebagai:complex_element_type(element_type(operand))ifis_complex(operand).element_type(operand)jika tidak.
Contoh
// %operand: [(1.0, 2.0), (3.0, 4.0)]
%result = "stablehlo.real"(%operand) : (tensor<2xcomplex<f32>>) -> tensor<2xf32>
// %result: [1.0, 3.0]
recv
Semantik
Menerima data dari saluran dengan channel_id dan menghasilkan results.
Jika is_host_transfer adalah true, operasi akan mentransfer data dari
host. Jika tidak, data akan ditransfer dari perangkat lain berdasarkan nilai
source_target_pairs. Flag ini menduplikasi informasi yang diberikan di
channel_type, jadi pada masa mendatang kami berencana untuk hanya menyimpan salah satunya
(#666). Jika is_host_transfer
= false dan source_target_pairs adalah None atau kosong, maka dianggap sebagai
perilaku yang tidak ditentukan.
results terdiri dari nilai payload yang muncul terlebih dahulu dan token yang muncul
terakhir. Pada masa mendatang, kami berencana membagi payload dan token menjadi dua
output terpisah untuk meningkatkan kejelasan
(#670).
Input
| Label | Nama | Jenis | Batasan |
|---|---|---|---|
| (I1) | token |
token |
|
| (I2) | channel_id |
konstanta jenis si64 |
|
| (I3) | channel_type |
enum DEVICE_TO_DEVICE dan DEVICE_TO_HOST |
(C5) |
| (I4) | is_host_transfer |
konstanta jenis i1 |
(C5-C6) |
| (I5) | source_target_pairs |
Konstanta tensor 2 dimensi dari jenis si64 |
(C1-C4), (C6) |
Output
| Nama | Jenis | Batasan |
|---|---|---|
results |
sejumlah tensor, tensor terkuantisasi, atau token variadik | (C2-C4) |
Batasan
- (C1)
dim(source_target_pairs, 1) = 2. - (C2)
is_unique(source_target_pairs[:, 0]). - (C3)
is_unique(source_target_pairs[:, 1]). - (C4)
0 <= source_target_pairs < N, denganNditentukan sebagai:num_replicasjikacross_replicadigunakan.num_partitionsjikacross_partitiondigunakan.
- (C5)
channel_typeditentukan sebagai:DEVICE_TO_HOSTifis_host_transfer = true,DEVICE_TO_DEVICEjika tidak.
Contoh
%results0, %results1 = "stablehlo.recv"(%token) {
channel_handle = #stablehlo.channel_handle<handle = 0, type = 1>,
is_host_transfer = false,
source_target_pairs = dense<[[0, 1], [1, 2]]> : tensor<2x2xi64>
} : (!stablehlo.token) -> (tensor<2x2xi64>, !stablehlo.token)
reduce
Semantik
Menerapkan fungsi pengurangan body ke inputs dan init_values di sepanjang
dimensions dan menghasilkan tensor results.
Urutan pengurangan ditentukan oleh implementasi, yang berarti bahwa body dan
init_values harus membentuk monoid untuk menjamin bahwa operasi menghasilkan
hasil yang sama untuk semua input pada semua implementasi. Namun, kondisi ini tidak berlaku untuk banyak pengurangan populer. Misalnya, penambahan floating point untuk
body dan nol untuk init_values sebenarnya tidak membentuk monoid karena
penambahan floating point tidak bersifat asosiatif.
Secara lebih formal, results...[j0, ..., jR-1] = reduce(input_slices_converted) dengan:
input_slices = inputs...[j0, ..., :, ..., jR-1], dengan:disisipkan didimensions.input_slices_converted = to_destination_type(input_slices..., type(func_inputs(body)[:len(func_inputs(body))//2])...).init_values_converted = to_destination_type(init_values..., type(func_inputs(body)[len(func_inputs(body))//2:])...).reduce(input_slices_converted) = exec(schedule)untuk beberapa pohon binerscheduledengan:exec(node) = body(exec(node.left), exec(node.right)).exec(leaf) = leaf.value.
scheduleadalah pohon biner penuh yang ditentukan implementasinya yang traversal in-order-nya terdiri dari:- Nilai
input_slices_converted...[index], untuk semuaindexdalamindex_space(input_slices_converted)dalam urutan leksikografis menaik dariindex. - Diselingi dengan jumlah
init_values_convertedyang ditentukan implementasi pada posisi yang ditentukan implementasi.
- Nilai
Input
| Label | Nama | Jenis | Batasan |
|---|---|---|---|
| (I1) | inputs |
jumlah tensor variadik atau tensor terkuantisasi per tensor | (C1-C4), (C6), (C7) |
| (I2) | init_values |
jumlah tensor 0 dimensi variadik atau tensor terkuantisasi per tensor | (C2), (C3) |
| (I3) | dimensions |
Konstanta tensor 1 dimensi berjenis si64 |
(C4), (C5), (C7) |
| (I4) | body |
fungsi | (C6) |
Output
| Nama | Jenis | Batasan |
|---|---|---|
results |
jumlah tensor variadik atau tensor terkuantisasi per tensor | (C3), (C7), (C8) |
Batasan
- (C1)
same(shape(inputs...)). - (C2)
element_type(inputs...) = element_type(init_values...). - (C3)
0 < size(inputs) = size(init_values) = size(results) = N. - (C4)
0 <= dimensions < rank(inputs[0]). - (C5)
is_unique(dimensions). - (C6)
bodymemiliki jenis(tensor<E0>, ..., tensor<EN-1>, tensor<E0>, ...,tensor<EN-1>) -> (tensor<E0>, ..., tensor<EN-1>)denganis_promotable(element_type(inputs[i]), Ei). - (C7)
shape(results...) = shape(inputs...)kecuali ukuran dimensiinputs...yang sesuai dengandimensionstidak disertakan. - (C8)
element_type(results[i]) = Eiuntuk semuaidi[0,N).
Contoh
// %input = [[0, 1, 2, 3, 4, 5]]
// %init_value = 0
%result = "stablehlo.reduce"(%input, %init_value) ({
^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
%0 = "stablehlo.add"(%arg0, %arg1) : (tensor<i64>, tensor<i64>) -> tensor<i64>
"stablehlo.return"(%0) : (tensor<i64>) -> ()
}) {
dimensions = array<i64: 1>
} : (tensor<1x6xi64>, tensor<i64>) -> tensor<1xi64>
// %result = [15]
reduce_precision
Semantik
Melakukan konversi per elemen dari operand ke jenis floating point lain
yang menggunakan exponent_bits dan mantissa_bits, lalu kembali ke jenis
floating point asli dan menghasilkan tensor output.
Lebih formal:
- Bit mantisa dari nilai asli diperbarui untuk membulatkan nilai asli ke nilai terdekat yang dapat direpresentasikan dengan
mantissa_bitsmenggunakan semantikroundToIntegralTiesToEven. - Kemudian, jika
mantissa_bitslebih kecil dari jumlah bit mantisa dari nilai asli, bit mantisa akan dipangkas menjadimantissa_bits. - Kemudian, jika bit eksponen hasil perantara tidak sesuai dengan
rentang yang disediakan oleh
exponent_bits, hasil perantara akan meluap ke tak terhingga menggunakan tanda asli atau meluap ke nol menggunakan tanda asli. - Untuk jenis yang dikuantisasi, melakukan
dequantize_op_quantize( lambda operand: reduce_precision(operand, exponent_bits, mantissa_bits), operand, type(result)).
Input
| Label | Nama | Jenis | Batasan |
|---|---|---|---|
| (I1) | operand |
tensor jenis floating point atau tensor terkuantisasi per-tensor | (C1) |
| (I2) | exponent_bits |
konstanta jenis si32 |
(C2) |
| (I3) | mantissa_bits |
konstanta jenis si32 |
(C3) |
Output
| Nama | Jenis | Batasan |
|---|---|---|
output |
tensor jenis floating point atau tensor terkuantisasi per-tensor | (C1) |
Batasan
- (C1)
baseline_type(operand) = baseline_type(output). - (C2)
1 <= exponent_bits. - (C3)
0 <= mantissa_bits.
Contoh
// Logical values: +Inf, NaN, +Denormal, 0.0, 65519.0, 65520.0
// %operand: [0x7FF0000000000000, 0x7FFFFFFFFFFFFFFF, 0x0000000000000001, 0.0, 65519.0, 65520.0]
%output = "stablehlo.reduce_precision"(%operand) {
exponent_bits = 5 : i32,
mantissa_bits = 10 : i32
} : (tensor<6xf64>) -> tensor<6xf64>
// Logical values: +Inf, NaN, 0.0, 0.0, 65504.0, +Inf
// %output: [0x7FF0000000000000, 0x7FFFFFFFFFFFFFFF, 0.0, 0.0, 65504.0, 0x7FF0000000000000]
reduce_scatter
Semantik
Dalam setiap grup proses di petak proses StableHLO, lakukan reduksi,
menggunakan computations, atas nilai tensor operand dari setiap proses,
pisahkan hasil reduksi di sepanjang scatter_dimension menjadi beberapa bagian, dan sebar
bagian yang dipisahkan di antara proses untuk menghasilkan result.
Operasi ini membagi petak proses StableHLO menjadi process_groups yang ditentukan sebagai berikut:
cross_replica(replica_groups)ifchannel_id <= 0 and use_global_device_ids = false.cross_replica_and_partition(replica_groups)ifchannel_id > 0 and use_global_device_ids = false.flattened_ids(replica_groups)ifchannel_id > 0 and use_global_device_ids = true.
Setelah itu, di dalam setiap process_group:
reduced_value = all_reduce(operand, replica_groups, channel_id, use_global_device_ids, computation).parts@sender = split(reduced_value@sender, dim(process_groups, 1), scatter_dimension).result@receiver = parts@sender[receiver_index]untuk semuasenderdalamprocess_group, denganreceiver_index = process_group.index(receiver).
Input
| Label | Nama | Jenis | Batasan |
|---|---|---|---|
| (I1) | operand |
tensor atau tensor terkuantisasi per tensor | (C1), (C2), (C7), (C8) |
| (I2) | scatter_dimension |
konstanta jenis si64 |
(C1), (C2), (C8) |
| (I3) | replica_groups |
Konstanta tensor 2 dimensi dari jenis si64 |
(C3-C5) |
| (I4) | channel_id |
konstanta jenis si64 |
(C6) |
| (I5) | use_global_device_ids |
konstanta jenis i1 |
(C6) |
| (I6) | computation |
fungsi | (C7) |
Output
| Nama | Jenis | Batasan |
|---|---|---|
result |
tensor atau tensor terkuantisasi per tensor | (C8-C9) |
Batasan
- (C1)
dim(operand, scatter_dimension) % dim(process_groups, 1) = 0. - (C2)
0 <= scatter_dimension < rank(operand). - (C3)
is_unique(replica_groups). - (C4)
size(replica_groups)ditentukan sebagai:num_replicasjikacross_replicadigunakan.num_replicasjikacross_replica_and_partitiondigunakan.num_processesjikaflattened_idsdigunakan.
- (C5)
0 <= replica_groups < size(replica_groups). - (C6) Jika
use_global_device_ids = true, makachannel_id > 0. - (C7)
computationmemiliki jenis(tensor<E>, tensor<E>) -> (tensor<E>)denganis_promotable(element_type(operand), E). - (C8)
shape(result) = shape(operand)kecuali:dim(result, scatter_dimension) = dim(operand, scatter_dimension) / dim(process_groups, 1).
- (C9)
element_type(result) = E.
Contoh
// num_replicas: 2
// num_partitions: 1
// %operand@(0, 0): [[1, 2, 3, 4],
// [5, 6, 7, 8]]
// %operand@(1, 0): [[9, 10, 11, 12],
// [13, 14, 15, 16]]
%result = "stablehlo.reduce_scatter"(%operand) ({
^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
%0 = "stablehlo.add"(%arg0, %arg1) : (tensor<i64>, tensor<i64>) -> tensor<i64>
"stablehlo.return"(%0) : (tensor<i64>) -> ()
}) {
scatter_dimension = 1 : i64,
replica_groups = dense<[[0, 1]]> : tensor<1x2xi64>,
channel_handle = #stablehlo.channel_handle<handle = 0, type = 0>
} : (tensor<2x4xi64>) -> tensor<2x2xi64>
//
// %result@(0, 0): [[10, 12],
// [18, 20]]
// %result@(1, 0): [[14, 16],
// [22, 24]]
reduce_window
Semantik
Menerapkan fungsi pengurangan body ke jendela inputs dan init_values
serta menghasilkan results.
Diagram berikut menunjukkan cara elemen di results... dihitung dari
inputs... menggunakan contoh konkret.
Secara lebih formal,
results...[result_index] = reduce(windows, init_values, axes(inputs...), body)
(lihat reduce) dengan:
padded_inputs = pad(inputs..., init_values..., padding[:, 0], padding[:, 1], base_dilations - 1).window_start = result_index * window_strides.window_end = window_start + (window_dimensions - 1) * window_dilations + 1.windows = slice(padded_inputs..., window_start, window_end, window_dilations).
Input
| Label | Nama | Jenis | Batasan |
|---|---|---|---|
| (I1) | inputs |
jumlah tensor variadik atau tensor terkuantisasi per tensor | (C1-C4), (C6), (C8), (C10), (C12), (C13), (C15) |
| (I2) | init_values |
jumlah tensor 0 dimensi variadik atau tensor terkuantisasi per tensor | (C1), (C13) |
| (I3) | window_dimensions |
Konstanta tensor 1 dimensi berjenis si64 |
(C4), (C5), (C15) |
| (I4) | window_strides |
Konstanta tensor 1 dimensi berjenis si64 |
(C6), (C7), (C15) |
| (I5) | base_dilations |
Konstanta tensor 1 dimensi berjenis si64 |
(C8), (C9), (C15) |
| (I6) | window_dilations |
Konstanta tensor 1 dimensi berjenis si64 |
(C10), (C11), (C15) |
| (I7) | padding |
Konstanta tensor 2 dimensi dari jenis si64 |
(C12), (C15) |
| (I8) | body |
fungsi | (C13) |
Output
| Nama | Jenis | Batasan |
|---|---|---|
results |
jumlah tensor variadik atau tensor terkuantisasi per tensor | (C1), (C14-C16) |
Batasan
- (C1)
0 < size(inputs) = size(init_values) = size(results) = N. - (C2)
same(shape(inputs...)). - (C3)
element_type(inputs...) = element_type(init_values...). - (C4)
size(window_dimensions) = rank(inputs[0]). - (C5)
0 < window_dimensions. - (C6)
size(window_strides) = rank(inputs[0]). - (C7)
0 < window_strides. - (C8)
size(base_dilations) = rank(inputs[0]). - (C9)
0 < base_dilations. - (C10)
size(window_dilations) = rank(inputs[0]). - (C11)
0 < window_dilations. - (C12)
shape(padding) = [rank(inputs[0]), 2]. - (C13)
bodymemiliki jenis(tensor<E0>, ..., tensor<EN-1>, tensor<E0>, ...,tensor<EN-1>) -> (tensor<E0>, ..., tensor<EN-1>)denganis_promotable(element_type(inputs[i]), Ei). - (C14)
same(shape(results...)). - (C15)
shape(results[0]) = num_windowsdengan:dilated_input_shape = shape(inputs[0]) = 0 ? 0 : (shape(inputs[0]) - 1) * base_dilations + 1.padded_input_shape = padding[:, 0] + dilated_input_shape + padding[:, 1].dilated_window_shape = (window_dimensions - 1) * window_dilations + 1.is_empty_window = padded_input_shape = 0 || dilated_window_shape > padded_input_shape.num_windows = is_empty_window ? 0 : floor((padded_input_shape - dilated_window_shape) / window_strides) + 1.
- (C16)
element_type(results[i]) = Eiuntuk semuaidi[0,N).
Contoh
// %input = [[1, 2], [3, 4], [5, 6]]
// %init_value = 0
%result = "stablehlo.reduce_window"(%input, %init_value) ({
^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
%0 = "stablehlo.add"(%arg0, %arg1) : (tensor<i64>, tensor<i64>) -> tensor<i64>
"stablehlo.return"(%0) : (tensor<i64>) -> ()
}) {
window_dimensions = array<i64: 2, 1>,
window_strides = array<i64: 4, 1>,
base_dilations = array<i64: 2, 1>,
window_dilations = array<i64: 3, 1>,
padding = dense<[[2, 1], [0, 0]]> : tensor<2x2xi64>
} : (tensor<3x2xi64>, tensor<i64>) -> tensor<2x2xi64>
// %result = [[0, 0], [3, 4]]
sisa bagi
Semantik
Melakukan sisa pembagian per elemen dari tensor pembilang lhs dan tensor pembagi rhs serta
menghasilkan tensor result.
Secara lebih formal, tanda hasil diambil dari dividen, dan nilai absolut hasil selalu kurang dari nilai absolut pembagi.
Sisanya dihitung sebagai lhs - d * rhs, dengan d diberikan oleh:
- Untuk bilangan bulat:
stablehlo.divide(lhs, rhs). - Untuk float:
division(lhs, rhs)dari IEEE-754 dengan atribut pembulatanroundTowardZero. - Untuk bilangan kompleks: TBD (#997).
- Untuk jenis yang dikuantisasi:
dequantize_op_quantize(remainder, lhs, rhs, type(result)).
Untuk jenis elemen floating point, operasi ini berbeda dengan operasi
remainder dari spesifikasi IEEE-754 di mana d adalah nilai integral
terdekat dengan nilai lhs/rhs yang tepat dengan ikatan ke genap.
Input
| Label | Nama | Jenis | Batasan |
|---|---|---|---|
| (I1) | lhs |
tensor bilangan bulat, floating point, atau jenis kompleks atau tensor terkuantisasi per-tensor | (C1) |
| (I2) | rhs |
tensor bilangan bulat, floating point, atau jenis kompleks atau tensor terkuantisasi per-tensor | (C1) |
Output
| Nama | Jenis | Batasan |
|---|---|---|
result |
tensor bilangan bulat, floating point, atau jenis kompleks atau tensor terkuantisasi per-tensor | (C1) |
Batasan
- (C1)
baseline_type(operand) = baseline_type(result).
Contoh
// %lhs: [17, -17, 17, -17]
// %rhs: [3, 3, -3, -3]
%result = "stablehlo.remainder"(%lhs, %rhs) : (tensor<4xi64>, tensor<4xi64>) -> tensor<4xi64>
// %result: [2, -2, 2, -2]
replica_id
Semantik
Menghasilkan replica_id dari proses saat ini.
Output
| Nama | Jenis |
|---|---|
result |
Tensor 0 dimensi berjenis ui32 |
Contoh
%result = "stablehlo.replica_id"() : () -> tensor<ui32>
membentuk ulang
Semantik
Melakukan pembentukan ulang tensor operand menjadi tensor result. Secara konseptual, hal ini sama dengan mempertahankan representasi kanonis yang sama, tetapi berpotensi mengubah bentuknya, misalnya dari tensor<2x3xf32> menjadi tensor<3x2xf32> atau tensor<6xf32>.
Secara lebih formal, result[result_index] = operand[operand_index] dengan
result_index dan operand_index memiliki posisi yang sama dalam urutan
leksikografis index_space(result) dan index_space(operand).
Input
| Label | Nama | Jenis | Batasan |
|---|---|---|---|
| (I1) | operand |
tensor atau tensor terkuantisasi | (C1-C3) |
Output
| Nama | Jenis | Batasan |
|---|---|---|
result |
tensor atau tensor terkuantisasi | (C1-C3) |
Batasan
- (C1)
element_type(result)diberikan oleh:element_type(operand), jika!is_per_axis_quantized(operand).element_type(operand)kecualiquantization_dimension(operand)danquantization_dimension(result)mungkin berbeda.
- (C2)
size(operand) = size(result). - (C3) Jika
is_per_axis_quantized(operand):reduce(dims(operand, [0, 1, ..., quantization_dimension(operand) - 1]), init_values=1, dimensions=[0], body=lambda x, y: x * y) = reduce(dims(result, [0, 1, ..., quantization_dimension(result) - 1]), init_values=1, dimensions=[0], body=lambda x, y: x * y).dim(operand, quantization_dimension(operand)) = dim(result, quantization_dimension(result)).reduce(dims(operand, [quantization_dimension(operand) + 1, ..., rank(operand) - 1]), init_values=1, dimensions=[0], body=lambda x, y: x * y) = reduce(dims(result, [quantization_dimension(result) + 1, ..., rank(result) - 1]), init_values=1, dimensions=[0], body=lambda x, y: x * y).
Contoh
// %operand: [[1, 2, 3], [4, 5, 6]]
%result = "stablehlo.reshape"(%operand) : (tensor<2x3xi32>) -> tensor<3x2xi32>
// %result: [[1, 2], [3, 4], [5, 6]]
balik
Semantik
Membalikkan urutan elemen dalam operand di sepanjang dimensions yang ditentukan
dan menghasilkan tensor result. Secara lebih formal,
result[result_index] = operand[operand_index] dengan:
operand_index[d] = dim(result, d) - result_index[d] - 1ifdindimensions.operand_index[d] = result_index[d]jika tidak.
Input
| Label | Nama | Jenis | Batasan |
|---|---|---|---|
| (I1) | operand |
tensor atau tensor terkuantisasi per tensor | (C1), (C3) |
| (I2) | dimensions |
Konstanta tensor 1 dimensi berjenis si64 |
(C2), (C3) |
Output
| Nama | Jenis | Batasan |
|---|---|---|
result |
tensor atau tensor terkuantisasi per tensor | (C1), (C3) |
Batasan
- (C1)
type(operand) = type(result). - (C2)
is_unique(dimensions). - (C3)
0 <= dimensions < rank(result).
Contoh
// %operand = [[1, 2], [3, 4], [5, 6]]
%result = "stablehlo.reverse"(%operand) {
dimensions = array<i64: 1>
} : (tensor<3x2xi32>) -> tensor<3x2xi32>
// %result: [[2, 1], [4, 3], [6, 5]]
rng
Semantik
Membuat angka acak menggunakan algoritma rng_distribution dan menghasilkan tensor
result dengan bentuk shape tertentu.
Jika rng_distribution = UNIFORM, maka angka acak akan dibuat
mengikuti distribusi seragam selama interval [a, b). Jika a >= b,
perilakunya tidak ditentukan.
Jika rng_distribution = NORMAL, maka bilangan acak dihasilkan
mengikuti distribusi normal dengan rerata = a dan simpangan baku = b.
Jika b < 0, perilakunya tidak ditentukan.
Cara persis pembuatan angka acak ditentukan oleh implementasi. Misalnya, mungkin deterministik atau tidak, dan mungkin menggunakan atau tidak menggunakan status tersembunyi.
Dalam percakapan dengan banyak pemangku kepentingan, op ini telah muncul sebagai tidak digunakan lagi secara efektif, jadi di masa mendatang kami berencana untuk menghapusnya (#597).
Input
| Label | Nama | Jenis | Batasan |
|---|---|---|---|
| (I1) | a |
Tensor 0 dimensi dari jenis bilangan bulat, boolean, atau floating point | (C1), (C2) |
| (I2) | b |
Tensor 0 dimensi dari jenis bilangan bulat, boolean, atau floating point | (C1), (C2) |
| (I3) | shape |
Konstanta tensor 1 dimensi berjenis si64 |
(C3) |
| (I4) | rng_distribution |
enum UNIFORM dan NORMAL |
(C2) |
Output
| Nama | Jenis | Batasan |
|---|---|---|
result |
tensor jenis bilangan bulat, boolean, atau floating point | (C1-C3) |
Batasan
- (C1)
element_type(a) = element_type(b) = element_type(result). - (C2) Jika
rng_distribution = NORMAL, makais_float(a). - (C3)
shape(result) = shape.
Contoh
// %a = 0
// %b = 2
// %shape = [3, 3]
%result = "stablehlo.rng"(%a, %b, %shape) {
rng_distribution = #stablehlo<rng_distribution UNIFORM>
} : (tensor<i32>, tensor<i32>, tensor<2xi64>) -> tensor<3x3xi32>
// %result: [
// [1, 0, 1],
// [1, 1, 1],
// [0, 0, 0]
// ]
rng_bit_generator
Semantik
Menampilkan output yang diisi dengan bit acak seragam dan status output yang diperbarui
output_state menggunakan algoritma generator angka pseudorandom rng_algorithm
dengan status awal initial_state. Output dijamin berupa fungsi deterministik initial_state, tetapi tidak dijamin deterministik antar-penerapan.
rng_algorithm adalah salah satu dari berikut ini:
DEFAULT: Algoritma yang ditentukan implementasinya.THREE_FRY: Varian algoritma Threefry yang ditentukan implementasinya.*PHILOX: Varian algoritma Philox yang ditentukan implementasinya.*
* Lihat: Salmon et al. SC 2011. Angka acak paralel: semudah 1, 2, 3.
Input
| Label | Nama | Jenis | Batasan |
|---|---|---|---|
| (I1) | rng_algorithm |
enum DEFAULT, THREE_FRY, dan PHILOX |
(C2) |
| (I2) | initial_state |
Tensor 1 dimensi berjenis ui64 |
(C1), (C2) |
Output
| Nama | Jenis | Batasan |
|---|---|---|
output_state |
Tensor 1 dimensi berjenis ui64 |
(C1) |
output |
tensor jenis bilangan bulat atau floating point |
Batasan
- (C1)
type(initial_state) = type(output_state). - (C2)
size(initial_state)ditentukan sebagai:- ditetapkan oleh implementasi jika
rng_algorithm = DEFAULT. 2ifrng_algorithm = THREE_FRY.2atau3jikarng_algorithm = PHILOX.
- ditetapkan oleh implementasi jika
Contoh
// %initial_state: [1, 2]
%output_state, %output = "stablehlo.rng_bit_generator"(%initial_state) {
rng_algorithm = #stablehlo<rng_algorithm THREE_FRY>
} : (tensor<2xui64>) -> (tensor<2xui64>, tensor<2x2xui64>)
// %output_state: [1, 6]
// %output: [
// [9236835810183407956, 16087790271692313299],
// [18212823393184779219, 2658481902456610144]
// ]
round_nearest_afz
Semantik
Melakukan pembulatan per elemen ke bilangan bulat terdekat, memisahkan ikatan dari nol, pada tensor operand dan menghasilkan tensor result. Menerapkan
operasi roundToIntegralTiesToAway dari spesifikasi IEEE-754. Untuk
jenis terkuantisasi, melakukan
dequantize_op_quantize(round_nearest_afz, operand, type(result)).
Input
| Label | Nama | Jenis | Batasan |
|---|---|---|---|
| (I1) | operand |
tensor jenis floating point atau tensor terkuantisasi per-tensor | (C1) |
Output
| Nama | Jenis | Batasan |
|---|---|---|
result |
tensor jenis floating point atau tensor terkuantisasi per-tensor | (C1) |
Batasan
- (C1)
baseline_type(operand) = baseline_type(result).
Contoh
// %operand = [-2.5, 0.4, 0.5, 0.6, 2.5]
%result = "stablehlo.round_nearest_afz"(%operand) : (tensor<5xf64>) -> tensor<5xf64>
// %result: [-3.0, 0.0, 1.0, 1.0, 3.0]
round_nearest_even
Semantik
Melakukan pembulatan per elemen ke bilangan bulat terdekat, memecahkan kesamaan
ke bilangan bulat genap, pada tensor operand dan menghasilkan tensor result. Menerapkan operasi roundToIntegralTiesToEven dari spesifikasi IEEE-754. Untuk jenis yang dikuantisasi, melakukan
dequantize_op_quantize(round_nearest_even, operand, type(result)).
Input
| Label | Nama | Jenis | Batasan |
|---|---|---|---|
| (I1) | operand |
tensor jenis floating point atau tensor terkuantisasi per-tensor | (C1) |
Output
| Nama | Jenis | Batasan |
|---|---|---|
result |
tensor jenis floating point atau tensor terkuantisasi per-tensor | (C1) |
Batasan
- (C1)
baseline_type(operand) = baseline_type(result).
Contoh
// %operand = [-2.5, 0.4, 0.5, 0.6, 2.5]
%result = "stablehlo.round_nearest_even"(%operand) : (tensor<5xf64>) -> tensor<5xf64>
// %result: [-2.0, 0.0, 0.0, 1.0, 2.0]
rsqrt
Semantik
Melakukan operasi akar kuadrat timbal balik per elemen pada tensor operand dan menghasilkan tensor result. Bergantung pada jenis elemen, lakukan tindakan berikut:
- Untuk float:
rSqrtdari IEEE-754. - Untuk bilangan kompleks: akar kuadrat timbal balik kompleks.
- Untuk jenis yang dikuantisasi:
dequantize_op_quantize(rsqrt, operand, type(result)).
Input
| Label | Nama | Jenis | Batasan |
|---|---|---|---|
| (I1) | operand |
tensor jenis floating-point atau kompleks atau tensor terkuantisasi per-tensor | (C1) |
Output
| Nama | Jenis | Batasan |
|---|---|---|
result |
tensor jenis floating-point atau kompleks atau tensor terkuantisasi per-tensor | (C1) |
Batasan
- (C1)
baseline_type(operand) = baseline_type(result).
Contoh
// %operand: [[1.0, 4.0], [9.0, 25.0]]
%result = "stablehlo.rsqrt"(%operand) : (tensor<2x2xf32>) -> tensor<2x2xf32>
// %result: [[1.0, 0.5], [0.33333343, 0.2]]
scatter (memencar)
Semantik
Menghasilkan tensor results yang sama dengan tensor inputs, kecuali beberapa slice yang ditentukan oleh scatter_indices diperbarui dengan nilai updates menggunakan update_computation.
Diagram berikut menunjukkan cara elemen di updates... dipetakan ke elemen di
results... menggunakan contoh konkret. Diagram ini memilih beberapa contoh indeks
updates... dan menjelaskan secara mendetail indeks results... yang sesuai
dengan indeks tersebut.
Lebih formal, untuk semua update_index di index_space(updates[0]):
update_scatter_dims = [d for d in axes(updates[0]) and d not in update_window_dims].update_scatter_index = update_index[update_scatter_dims...].start_indexditentukan sebagai:scatter_indices[si0, ..., :, ..., siN]dengansiadalah elemen individu dalamupdate_scatter_indexdan:disisipkan pada indeksindex_vector_dim, jikaindex_vector_dim<rank(scatter_indices).[scatter_indices[update_scatter_index]]jika tidak.
- Untuk
d_inputdiaxes(inputs[0]),full_start_index[d_input] = start_index[d_start]ifd_input = scatter_dims_to_operand_dims[d_start].full_start_index[d_input] = 0jika tidak.
- Untuk
d_inputdiaxes(inputs[0]),full_batching_index[d_input] = update_scatter_index[d_start - (d_start < index_vector_dim ? 0 : 1)]jikad_input = input_batching_dims[i_batching]dand_start = scatter_indices_batching_dims[i_batching].full_batching_index[d_input] = 0jika tidak.
update_window_index = update_index[update_window_dims...].full_window_index = [wi0, ..., 0, ..., wiN]denganwiadalah masing-masing elemen dalamupdate_window_index, dan0disisipkan pada indeks dariinserted_window_dimsdaninput_batching_dims.result_index = full_start_index + full_batching_index + full_window_index.
Dengan demikian, results = exec(schedule, inputs), dengan:
scheduleadalah permutasi yang ditentukan implementasi dariindex_space(updates[0]).exec([update_index, ...], results) = exec([...], updated_results)dengan:- Jika
result_indexberada dalam batas untukshape(results...) updates_converted = to_destination_type( updates...[update_index], type(func_inputs(update_computation) [len(func_inputs(update_computation))//2:])... )updated_values = update_computation(results...[result_index], updates_converted)updated_resultsadalah salinanresultsdenganresults...[result_index]disetel keupdated_values....- Atau
updated_results = results.
- Jika
exec([], results) = results.
Jika indices_are_sorted adalah true, implementasi dapat mengasumsikan bahwa
scatter_indices diurutkan berdasarkan scatter_dims_to_operand_dims,
jika tidak, perilakunya tidak terdefinisi. Lebih formalnya, untuk semua i1 < i2 dari
indices(result), full_start_index(i1) <= full_start_index(i2).
Jika unique_indices adalah true, maka implementasi dapat mengasumsikan bahwa semua
indeks result_index yang tersebar bersifat unik. Jika unique_indices adalah
true, tetapi indeks yang disebarkan tidak unik, maka perilakunya tidak
terdefinisi.
Input
| Label | Nama | Jenis | Batasan |
|---|---|---|---|
| (I1) | inputs |
jumlah tensor variadik atau tensor terkuantisasi per tensor | (C1), (C2), (C4-C6), (C11), (C13), (C18), (C21), (C23-C24) |
| (I2) | scatter_indices |
tensor jenis bilangan bulat | (C4), (C15), (C19), (C22) |
| (I3) | updates |
jumlah tensor variadik atau tensor terkuantisasi per tensor | (C3-C6), (C8) |
| (I4) | update_window_dims |
Konstanta tensor 1 dimensi berjenis si64 |
(C2), (C4), (C7-C8) |
| (I5) | inserted_window_dims |
Konstanta tensor 1 dimensi berjenis si64 |
(C2), (C4), (C9-C11) |
| (I6) | input_batching_dims |
Konstanta tensor 1 dimensi berjenis si64 |
(C2), (C4), (C9), (C12-13), (C17-18), (C20) |
| (I7) | scatter_indices_batching_dims |
Konstanta tensor 1 dimensi berjenis si64 |
(C14-C18) |
| (I8) | scatter_dims_to_operand_dims |
Konstanta tensor 1 dimensi berjenis si64 |
(C19-C21) |
| (I9) | index_vector_dim |
konstanta jenis si64 |
(C4), (C16), (C19), (C22) |
| (I10) | indices_are_sorted |
konstanta jenis i1 |
|
| (I11) | unique_indices |
konstanta jenis i1 |
|
| (I12) | update_computation |
fungsi | (C23) |
Output
| Nama | Jenis | Batasan |
|---|---|---|
results |
jumlah tensor variadik atau tensor terkuantisasi per tensor | (C24-C25) |
Batasan
- (C1)
same(shape(inputs...)). - (C2)
rank(inputs[0]) = size(update_window_dims) + size(inserted_window_dims) + size(input_batching_dims). - (C3)
same(shape(updates...)). - (C4)
shape(updates[0]) = combine(update_scatter_dim_sizes, update_window_dim_sizes)where:update_scatter_dim_sizes = shape(scatter_indices)kecuali ukuran dimensiscatter_indicesyang sesuai denganindex_vector_dimtidak disertakan.update_window_dim_sizes <= shape(inputs[0])kecuali ukuran dimensi dalaminputs[0]yang sesuai denganinserted_window_dimsdaninput_batching_dimstidak disertakan.combinemenempatkanupdate_scatter_dim_sizespada sumbu yang sesuai denganupdate_scatter_dimsdanupdate_window_dim_sizespada sumbu yang sesuai denganupdate_window_dims.
- (C5)
0 < size(inputs) = size(updates) = N. - (C6)
element_type(updates...) = element_type(inputs...). - (C7)
is_unique(update_window_dims) and is_sorted(update_window_dims). - (C8)
0 <= update_window_dims < rank(updates[0]). - (C9)
is_unique(concatenate(inserted_window_dims, input_batching_dims)) - (C10)
is_sorted(inserted_window_dims). - (C11)
0 <= inserted_window_dims < rank(inputs[0]). - (C12)
is_sorted(input_batching_dims). - (C13)
0 <= input_batching_dims < rank(inputs[0])). - (C14)
is_unique(scatter_indices_batching_dims). - (C15)
0 <= scatter_indices_batching_dims < rank(scatter_indices). - (C16)
index_vector_dim not in scatter_indices_batching_dims. - (C17)
size(input_batching_dims) == size(scatter_indices_batching_dims). - (C18)
dim(inputs[0], input_batching_dims...) = dim(scatter_indices, scatter_indices_batching_dims...). - (C19)
size(scatter_dims_to_operand_dims) = index_vector_dim < rank(scatter_indices) ? dim(scatter_indices, index_vector_dim) : 1. - (C20)
is_unique(concatenate(scatter_dims_to_operand_dims, input_batching_dims)). - (C21)
0 <= scatter_dims_to_operand_dims < rank(inputs[0]). - (C22)
0 <= index_vector_dim <= rank(scatter_indices). - (C23)
update_computationmemiliki jenis(tensor<E0>, ..., tensor<EN-1>, tensor<E0>, ..., tensor<EN-1>) -> (tensor<E0>, ..., tensor<EN-1>), denganis_promotable(element_type(inputs[i]), Ei). - (C24)
shape(inputs...) = shape(results...). - (C25)
element_type(results[i]) = Eiuntuk semuaidi[0,N).
Contoh
// %input: [
// [
// [[1, 2], [3, 4], [5, 6], [7, 8]],
// [[9, 10],[11, 12], [13, 14], [15, 16]],
// [[17, 18], [19, 20], [21, 22], [23, 24]]
// ],
// [
// [[25, 26], [27, 28], [29, 30], [31, 32]],
// [[33, 34], [35, 36], [37, 38], [39, 40]],
// [[41, 42], [43, 44], [45, 46], [47, 48]]
// ]
// ]
// %scatter_indices: [
// [
// [[0, 0], [1, 0], [2, 1]],
// [[0, 1], [1, 1], [0, 9]]
// ],
// [
// [[0, 0], [2, 1], [2, 2]],
// [[1, 2], [0, 1], [1, 0]]
// ]
// ]
// %update: [
// [
// [[[1, 1], [1, 1]], [[1, 1], [1, 1]], [[1, 1], [1, 1]]],
// [[[1, 1], [1, 1]], [[1, 1], [1, 1]], [[1, 1], [1, 1]]]
// ],
// [
// [[[1, 1], [1, 1]], [[1, 1], [1, 1]], [[1, 1], [1, 1]]],
// [[[1, 1], [1, 1]], [[1, 1], [1, 1]], [[1, 1], [1, 1]]]
// ]
// ]
%result = "stablehlo.scatter"(%input, %scatter_indices, %update) ({
^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
%0 = "stablehlo.add"(%arg0, %arg1) : (tensor<i64>, tensor<i64>) -> tensor<i64>
"stablehlo.return"(%0) : (tensor<i64>) -> ()
}) {
scatter_dimension_numbers = #stablehlo.scatter<
update_window_dims = [3, 4],
inserted_window_dims = [1],
input_batching_dims = [0],
scatter_indices_batching_dims = [1],
scatter_dims_to_operand_dims = [2, 1],
index_vector_dim = 3>,
indices_are_sorted = false,
unique_indices = false
} : (tensor<2x3x4x2xi64>, tensor<2x2x3x2xi64>, tensor<2x2x3x2x2xi64>) -> tensor<2x3x4x2xi64>
// %result: [
// [
// [[3, 4], [6, 7], [6, 7], [7, 8]],
// [[9, 10],[11, 12], [15, 16], [17, 18]],
// [[17, 18], [19, 20], [22, 23], [24, 25]]
// ],
// [
// [[25, 26], [28, 29], [30, 31], [31, 32]],
// [[35, 36], [38, 39], [38, 39], [39, 40]],
// [[41, 42], [44, 45], [46, 47], [47, 48]]
// ]
// ]
pilih
Semantik
Menghasilkan tensor result yang setiap elemennya dipilih dari tensor on_true atau
on_false berdasarkan nilai elemen pred yang sesuai.
Secara lebih formal, result[result_index] = pred_element ? on_true[result_index] :
on_false[result_index], dengan pred_element = rank(pred) = 0 ? pred[] :
pred[result_index]. Untuk jenis yang dikuantisasi, melakukan
dequantize_select_quantize(pred, on_true, on_false, type(result)).
Input
| Label | Nama | Jenis | Batasan |
|---|---|---|---|
| (I1) | pred |
tensor jenis i1 |
(C1) |
| (I2) | on_true |
tensor atau tensor terkuantisasi per tensor | (C1-C2) |
| (I3) | on_false |
tensor atau tensor terkuantisasi per tensor | (C2) |
Output
| Nama | Jenis | Batasan |
|---|---|---|
result |
tensor atau tensor terkuantisasi per tensor | (C2) |
Batasan
- (C1)
rank(pred) = 0 or shape(pred) = shape(on_true). - (C2)
baseline_type(on_true) = baseline_type(on_false) = baseline_type(result).
Contoh
// %pred: [[false, true], [true, false]]
// %on_true: [[1, 2], [3, 4]]
// %on_false: [[5, 6], [7, 8]]
%result = "stablehlo.select"(%pred, %on_true, %on_false) : (tensor<2x2xi1>, tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
// %result: [[5, 2], [3, 8]]
select_and_scatter
Semantik
Menyebarkan nilai dari tensor source menggunakan scatter berdasarkan
hasil reduce_window dari tensor input menggunakan select dan menghasilkan
tensor result.
Diagram berikut menunjukkan cara elemen di result dihitung dari
operand dan source menggunakan contoh konkret.
Lebih formal:
selected_values = reduce_window_without_init(...)dengan input berikut:inputs = [operand].window_dimensions,window_strides, danpaddingyang digunakan apa adanya.base_dilations = windows_dilations = 1.bodydidefinisikan sebagai:
def body(arg0: tensor<E>, arg1: tensor<E>) -> tensor<E>: return select(arg0, arg1) ? arg0 : arg1;dengan
E = element_type(operand), danreduce_window_without_initberfungsi persis sepertireduce_window, kecuali bahwascheduledarireduceyang mendasarinya (lihat reduce) tidak menyertakan nilai init. Saat ini, tidak ditentukan apa yang terjadi jika jendela yang sesuai tidak memiliki nilai (#731).result[result_index] = reduce([source_values], [init_value], [0], scatter)di mana:source_values = [source[source_index] for source_index in source_indices].selected_index(source_index) = operand_indexjikaselected_values[source_index]memiliki elemenoperanddarioperand_index.source_indices = [source_index for source_index in indices(source) if selected_index(source_index) = result_index].
Input
| Label | Nama | Jenis | Batasan |
|---|---|---|---|
| (I1) | operand |
tensor atau tensor terkuantisasi per tensor | (C1-C4), (C6), (C8-C11) |
| (I2) | source |
tensor atau tensor terkuantisasi per tensor | (C1), (C2) |
| (I3) | init_value |
Tensor 0 dimensi atau tensor terkuantisasi per tensor | (C3) |
| (I4) | window_dimensions |
Konstanta tensor 1 dimensi berjenis si64 |
(C2), (C4), (C5) |
| (I5) | window_strides |
Konstanta tensor 1 dimensi berjenis si64 |
(C2), (C6), (C7) |
| (I6) | padding |
Konstanta tensor 2 dimensi dari jenis si64 |
(C2), (C8) |
| (I7) | select |
fungsi | (C9) |
| (I8) | scatter |
fungsi | (C10) |
Output
| Nama | Jenis | Batasan |
|---|---|---|
result |
tensor atau tensor terkuantisasi per tensor | (C11-C12) |
Batasan
- (C1)
element_type(operand) = element_type(source). - (C2)
shape(source) = num_windowsdengan:padded_operand_shape = padding[:, 0] + shape(operand) + padding[:, 1].is_empty_window = padded_operand_shape = 0 || window_dimensions > padded_operand_shape.num_windows = is_empty_window ? 0 : floor((padded_operand_shape - window_dimensions) / window_strides) + 1.
- (C3)
element_type(init_value) = element_type(operand). - (C4)
size(window_dimensions) = rank(operand). - (C5)
0 < window_dimensions. - (C6)
size(window_strides) = rank(operand). - (C7)
0 < window_strides. - (C8)
shape(padding) = [rank(operand), 2]. - (C9)
selectmemiliki jenis(tensor<E>, tensor<E>) -> tensor<i1>denganE = element_type(operand). - (C10)
scattermemiliki jenis(tensor<E>, tensor<E>) -> tensor<E>denganis_promotable(element_type(operand), E). - (C11)
shape(operand) = shape(result). - (C12)
element_type(result) = E.
Contoh
// %operand: [[1, 5], [2, 5], [3, 6], [4, 4]]
// %source: [[5, 6], [7, 8]]
// %init_value: 0
%result = "stablehlo.select_and_scatter"(%operand, %source, %init_value) ({
^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
%0 = "stablehlo.compare"(%arg0, %arg1) {
comparison_direction = #stablehlo<comparison_direction GE>
} : (tensor<i64>, tensor<i64>) -> tensor<i1>
"stablehlo.return"(%0) : (tensor<i1>) -> ()
}, {
^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
%0 = "stablehlo.add"(%arg0, %arg1) : (tensor<i64>, tensor<i64>) -> tensor<i64>
"stablehlo.return"(%0) : (tensor<i64>) -> ()
}) {
window_dimensions = array<i64: 3, 1>,
window_strides = array<i64: 2, 1>,
padding = dense<[[0, 1], [0, 0]]> : tensor<2x2xi64>
} : (tensor<4x2xi64>, tensor<2x2xi64>, tensor<i64>) -> tensor<4x2xi64>
// %result: [[0, 0], [0, 0], [5, 14], [7, 0]]
kirim
Semantik
Mengirim inputs ke channel channel_id. Input kemudian dikirim ke perangkat lain
dalam urutan yang ditentukan oleh source_target_pairs. Operasi ini menghasilkan token result.
Jika is_host_transfer adalah true, operasi akan mentransfer data ke
host. Jika tidak, data akan ditransfer ke perangkat lain berdasarkan nilai
source_target_pairs. Flag ini menduplikasi informasi yang diberikan di
channel_type, jadi pada masa mendatang kami berencana untuk hanya menyimpan salah satunya
(#666). Jika is_host_transfer
= false dan source_target_pairs adalah None atau kosong, maka dianggap sebagai
perilaku yang tidak ditentukan.
Input
| Label | Nama | Jenis | Batasan |
|---|---|---|---|
| (I1) | inputs |
jumlah tensor atau tensor terkuantisasi yang bervariasi | |
| (I2) | token |
token |
|
| (I3) | channel_id |
konstanta jenis si64 |
|
| (I4) | channel_type |
enum DEVICE_TO_DEVICE dan DEVICE_TO_HOST |
(C5) |
| (I5) | is_host_transfer |
konstanta jenis i1 |
(C5-C6) |
| (I6) | source_target_pairs |
Konstanta tensor 2 dimensi dari jenis si64 |
(C1-C4), (C6) |
Output
| Nama | Jenis |
|---|---|
result |
token |
Batasan
- (C1)
dim(source_target_pairs, 1) = 2. - (C2)
is_unique(source_target_pairs[:, 0]). - (C3)
is_unique(source_target_pairs[:, 1]). - (C4)
0 <= source_target_pairs < N, denganNditentukan sebagai:num_replicasjikacross_replicadigunakan.num_partitionsjikacross_partitiondigunakan.
- (C5)
channel_typeditentukan sebagai:DEVICE_TO_HOSTifis_host_transfer = true,DEVICE_TO_DEVICEjika tidak.
Contoh
%result = "stablehlo.send"(%operand, %token) {
channel_handle = #stablehlo.channel_handle<handle = 0, type = 1>,
is_host_transfer = false,
source_target_pairs = dense<[[0, 1], [1, 2]]> : tensor<2x2xi64>
} : (tensor<2x2xi64>, !stablehlo.token) -> !stablehlo.token
shift_left
Semantik
Melakukan operasi pergeseran kiri per elemen pada tensor lhs dengan jumlah bit rhs dan menghasilkan tensor result.
Input
| Label | Nama | Jenis | Batasan |
|---|---|---|---|
| (I1) | lhs |
tensor jenis bilangan bulat | (C1) |
| (I2) | rhs |
tensor jenis bilangan bulat | (C1) |
Output
| Nama | Jenis | Batasan |
|---|---|---|
result |
tensor jenis bilangan bulat | (C1) |
Batasan
- (C1)
type(lhs) = type(rhs) = type(result).
Contoh
// %lhs: [-1, 0, 1]
// %rhs: [1, 2, 3]
%result = "stablehlo.shift_left"(%lhs, %rhs): (tensor<3xi64>, tensor<3xi64>) -> tensor<3xi64>
// %result: [-2, 0, 8]
shift_right_arithmetic
Semantik
Melakukan operasi pergeseran aritmatika ke kanan berdasarkan elemen pada tensor lhs dengan
rhs jumlah bit dan menghasilkan tensor result.
Input
| Label | Nama | Jenis | Batasan |
|---|---|---|---|
| (I1) | lhs |
tensor jenis bilangan bulat | (C1) |
| (I2) | rhs |
tensor jenis bilangan bulat | (C1) |
Output
| Nama | Jenis | Batasan |
|---|---|---|
result |
tensor jenis bilangan bulat | (C1) |
Batasan
- (C1)
type(lhs) = type(rhs) = type(result).
Contoh
// %lhs: [-1, 0, 8]
// %rhs: [1, 2, 3]
%result = "stablehlo.shift_right_arithmetic"(%lhs, %rhs): (tensor<3xi64>, tensor<3xi64>) -> tensor<3xi64>
// %result: [-1, 0, 1]
shift_right_logical
Semantik
Melakukan operasi pergeseran bit logis menurut elemen pada tensor lhs sebanyak rhs
bit dan menghasilkan tensor result.
Input
| Label | Nama | Jenis | Batasan |
|---|---|---|---|
| (I1) | lhs |
tensor jenis bilangan bulat | (C1) |
| (I2) | rhs |
tensor jenis bilangan bulat | (C1) |
Output
| Nama | Jenis | Batasan |
|---|---|---|
result |
tensor jenis bilangan bulat | (C1) |
Batasan
- (C1)
type(lhs) = type(rhs) = type(result).
Contoh
// %lhs: [-1, 0, 8]
// %rhs: [1, 2, 3]
%result = "stablehlo.shift_right_logical"(%lhs, %rhs): (tensor<3xi64>, tensor<3xi64>) -> tensor<3xi64>
// %result: [9223372036854775807, 0, 1]
tanda
Semantik
Menampilkan tanda elemen operand secara elemen demi elemen dan menghasilkan tensor result.
Secara lebih formal, untuk setiap elemen x, semantiknya dapat dinyatakan menggunakan sintaksis Python sebagai berikut:
def sign(x):
if is_integer(x):
if compare(x, 0, LT, SIGNED): return -1
if compare(x, 0, EQ, SIGNED): return 0
return 1
elif is_float(x):
if is_nan(x): return NaN
if compare(x, -0.0, EQ, FLOAT): return -0.0
if compare(x, +0.0, EQ, FLOAT): return +0.0
if compare(x, 0.0, LT, FLOAT): return -1.0
return 1.0
elif is_complex(x):
if is_nan(real(x)) or is_nan(imag(x)): return (NaN, NaN)
if compare(x, (0.0, 0.0), EQ, FLOAT): return (0.0, 0.0)
return divide(x, convert(abs(x), type(x)))
Untuk jenis yang dikuantisasi, melakukan
dequantize_op_quantize(sign, operand, type(result)).
Input
| Label | Nama | Jenis | Batasan |
|---|---|---|---|
| (I1) | operand |
tensor bilangan bulat bertanda tangan, floating point, atau jenis kompleks atau tensor terkuantisasi per-tensor | (C1) |
Output
| Nama | Jenis | Batasan |
|---|---|---|
result |
tensor bilangan bulat bertanda tangan, floating point, atau jenis kompleks atau tensor terkuantisasi per-tensor | (C1) |
Batasan
- (C1)
baseline_type(operand) = baseline_type(result).
Contoh
// Logical values: +NaN, -1.0, -0.0, +0.0, 1.0
// operand: [0x7FFFFFFFFFFFFFFF, -1.0, -0.0, 0.0, 1.0]
%result = "stablehlo.sign"(%operand) : (tensor<5xf64>) -> tensor<5xf64>
// Logical values: +NaN, -1.0, -0.0, +0.0, 1.0
// %result: [0x7FFFFFFFFFFFFFFF, -1.0, -0.0, 0.0, 1.0]
sinus
Semantik
Melakukan operasi sinus per elemen pada tensor operand dan menghasilkan tensor result. Bergantung pada jenis elemen, lakukan tindakan berikut:
- Untuk float:
sindari IEEE-754. - Untuk bilangan kompleks: sinus kompleks.
- Untuk jenis yang dikuantisasi:
dequantize_op_quantize(sine, operand, type(result)).
Input
| Label | Nama | Jenis | Batasan |
|---|---|---|---|
| (I1) | operand |
tensor jenis floating-point atau kompleks atau tensor terkuantisasi per-tensor | (C1) |
Output
| Nama | Jenis | Batasan |
|---|---|---|
result |
tensor jenis floating-point atau kompleks atau tensor terkuantisasi per-tensor | (C1) |
Batasan
- (C1)
baseline_type(operand) = baseline_type(result).
Contoh
// %operand: [
// [0.0, 1.57079632], // [0, pi/2]
// [3.14159265, 4.71238898] // [pi, 3pi/2]
// ]
%result = "stablehlo.sine"(%operand) : (tensor<2x2xf32>) -> tensor<2x2xf32>
// %result: [[0.0, 1.0], [0.0, -1.0]]
slice
Semantik
Mengekstrak slice dari operand menggunakan indeks awal yang dihitung secara statis
dan menghasilkan tensor result. start_indices berisi indeks awal slice untuk setiap dimensi, limit_indices berisi indeks akhir (eksklusif) slice untuk setiap dimensi, dan strides berisi langkah untuk setiap dimensi.
Secara lebih formal, result[result_index] = operand[operand_index] dengan
operand_index = start_indices + result_index * strides.
Input
| Label | Nama | Jenis | Batasan |
|---|---|---|---|
| (I1) | operand |
tensor atau tensor terkuantisasi per tensor | (C1-C3), (C5) |
| (I2) | start_indices |
Konstanta tensor 1 dimensi berjenis si64 |
(C2), (C3), (C5) |
| (I3) | limit_indices |
Konstanta tensor 1 dimensi berjenis si64 |
(C2), (C3), (C5) |
| (I4) | strides |
Konstanta tensor 1 dimensi berjenis si64 |
(C2), (C4) |
Output
| Nama | Jenis | Batasan |
|---|---|---|
result |
tensor atau tensor terkuantisasi per tensor | (C1), (C5) |
Batasan
- (C1)
element_type(operand) = element_type(result). - (C2)
size(start_indices) = size(limit_indices) = size(strides) = rank(operand). - (C3)
0 <= start_indices <= limit_indices <= shape(operand). - (C4)
0 < strides. - (C5)
shape(result) = ceil((limit_indices - start_indices) / strides).
Contoh
// %operand: [
// [0, 0, 0, 0],
// [0, 0, 1, 1],
// [0, 0, 1, 1]
// ]
%result = "stablehlo.slice"(%operand) {
start_indices = array<i64: 1, 2>,
limit_indices = array<i64: 3, 4>,
strides = array<i64: 1, 1>
} : (tensor<3x4xi64>) -> tensor<2x2xi64>
// % result: [
// [1, 1],
// [1, 1]
// ]
mengurutkan
Semantik
Mengurutkan irisan 1 dimensi dari inputs di sepanjang dimensi dimension bersama-sama,
sesuai dengan comparator dan menghasilkan results.
Tidak seperti input serupa dalam operasi lain, dimension memungkinkan nilai negatif,
dengan semantik yang dijelaskan di bawah. Pada masa mendatang, hal ini mungkin tidak diizinkan
untuk alasan konsistensi
(#1377).
Jika is_stable benar (true), pengurutan akan stabil, yaitu urutan relatif elemen yang dianggap sama oleh pembanding akan dipertahankan. Untuk kasus
dengan satu input, dua elemen e1 dan e2 dianggap
sama oleh pembanding jika dan hanya jika
comparator(e1, e2) = comparator(e2, e1) = false. Lihat formalisasi di bawah
untuk mengetahui cara menggeneralisasi hal ini ke beberapa input.
Lebih formal, untuk semua result_index di index_space(results[0]):
adjusted_dimension = dimension >= 0 ? dimension : rank(inputs[0]) + dimension.result_slice = [ri0, ..., :, ..., riR-1]denganriNadalah elemen individu dalamresult_index, dan:disisipkan diadjusted_dimension.inputs_together = (inputs[0]..., ..., inputs[N-1]...).results_together[result_slice] = sort(inputs_together[result_slice], comparator_together).- dengan
sortmengurutkan slice 1 dimensi dalam urutan tidak menurun, dancomparator_togethermenampilkantruejika argumen sisi kiri kurang dari argumen kedua sisi kanan. def comparator_together(lhs_together, rhs_together): args = [] for (lhs_el, rhs_el) in zip(lhs_together, rhs_together): args.append(lhs_el) args.append(rhs_el) return comparator(*args)(results[0]..., ..., results[N-1]...) = results_together.
Input
| Label | Nama | Jenis | Batasan |
|---|---|---|---|
| (I1) | inputs |
jumlah tensor variadik atau tensor terkuantisasi per tensor | (C1-C5) |
| (I2) | dimension |
konstanta jenis si64 |
(C4) |
| (I3) | is_stable |
konstanta jenis i1 |
|
| (I4) | comparator |
fungsi | (C5) |
Output
| Nama | Jenis | Batasan |
|---|---|---|
results |
jumlah tensor variadik atau tensor terkuantisasi per tensor | (C2), (C3) |
Batasan
- (C1)
0 < size(inputs). - (C2)
type(inputs...) = type(results...). - (C3)
same(shape(inputs...) + shape(results...)). - (C4)
-R <= dimension < R, denganR = rank(inputs[0]). - (C5)
comparatormemiliki jenis(tensor<E1>, tensor<E1>, ..., tensor<EN-1>, tensor<EN-1>) -> tensor<i1>, denganEi = element_type(inputs[i]).
Contoh
// %input0 = [[1, 2, 3], [3, 2, 1]]
// %input1 = [[3, 2, 1], [1, 2, 3]]
%result0, %result1 = "stablehlo.sort"(%input0, %input1) ({
^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>, %arg2: tensor<i64>, %arg3: tensor<i64>):
%predicate = "stablehlo.compare"(%arg0, %arg1) {
comparison_direction = #stablehlo<comparison_direction GT>
} : (tensor<i64>, tensor<i64>) -> tensor<i1>
"stablehlo.return"(%predicate) : (tensor<i1>) -> ()
}) {
dimension = 0 : i64,
is_stable = true
} : (tensor<2x3xi64>, tensor<2x3xi64>) -> (tensor<2x3xi64>, tensor<2x3xi64>)
// %result0 = [[3, 2, 3], [1, 2, 1]]
// %result1 = [[1, 2, 1], [3, 2, 3]]
sqrt
Semantik
Melakukan operasi akar kuadrat per elemen pada tensor operand dan menghasilkan tensor
result. Bergantung pada jenis elemen, lakukan tindakan berikut:
- Untuk float:
squareRootdari IEEE-754. - Untuk bilangan kompleks: akar kuadrat kompleks.
- Untuk jenis yang dikuantisasi:
dequantize_op_quantize(sqrt, operand, type(result)).
Input
| Label | Nama | Jenis | Batasan |
|---|---|---|---|
| (I1) | operand |
tensor jenis floating-point atau kompleks atau tensor terkuantisasi per-tensor | (C1) |
Output
| Nama | Jenis | Batasan |
|---|---|---|
result |
tensor jenis floating-point atau kompleks atau tensor terkuantisasi per-tensor | (C1) |
Batasan
- (C1)
baseline_type(operand) = baseline_type(result).
Contoh
// %operand: [[0.0, 1.0], [4.0, 9.0]]
%result = "stablehlo.sqrt"(%operand) : (tensor<2x2xf32>) -> tensor<2x2xf32>
// %result: [[0.0, 1.0], [2.0, 3.0]]
mengurangi
Semantik
Melakukan pengurangan per elemen dari dua tensor lhs dan rhs serta menghasilkan tensor result. Bergantung pada jenis elemen, lakukan tindakan berikut:
- Untuk bilangan bulat: pengurangan bilangan bulat.
- Untuk float:
subtractiondari IEEE-754. - Untuk bilangan kompleks: pengurangan bilangan kompleks.
- Untuk jenis yang dikuantisasi:
dequantize_op_quantize(subtract, lhs, rhs, type(result)).
Input
| Label | Nama | Jenis | Batasan |
|---|---|---|---|
| (I1) | lhs |
tensor bilangan bulat, floating point, atau jenis kompleks atau tensor terkuantisasi per-tensor | (C1) |
| (I2) | rhs |
tensor bilangan bulat, floating point, atau jenis kompleks atau tensor terkuantisasi per-tensor | (C1) |
Output
| Nama | Jenis | Batasan |
|---|---|---|
result |
tensor bilangan bulat, floating point, atau jenis kompleks atau tensor terkuantisasi per-tensor | (C1) |
Batasan
- (C1)
baseline_type(lhs) = baseline_type(rhs) = baseline_type(result).
Contoh
// %lhs: [[6, 8], [10, 12]]
// %rhs: [[5, 6], [7, 8]]
%result = "stablehlo.subtract"(%lhs, %rhs) : (tensor<2x2xf32>, tensor<2x2xf32>) -> (tensor<2x2xf32>)
// %result: [[1, 2], [3, 4]]
tan
Semantik
Melakukan operasi tangen per elemen pada tensor operand dan menghasilkan tensor result. Bergantung pada jenis elemen, lakukan tindakan berikut:
- Untuk float:
tandari IEEE-754. - Untuk bilangan kompleks: tangen kompleks.
- Untuk jenis yang dikuantisasi:
dequantize_op_quantize(tan, operand, type(result)).
Input
| Label | Nama | Jenis | Batasan |
|---|---|---|---|
| (I1) | operand |
tensor jenis floating-point atau kompleks atau tensor terkuantisasi per-tensor | (C1) |
Output
| Nama | Jenis | Batasan |
|---|---|---|
result |
tensor jenis floating-point atau kompleks atau tensor terkuantisasi per-tensor | (C1) |
Batasan
- (C1)
baseline_type(operand) = baseline_type(result).
Contoh
// %operand: [
// [0.0, 1.57079632], // [0, pi/2]
// [3.14159265, 4.71238898] // [pi, 3pi/2]
// ]
%result = "stablehlo.tan"(%operand) : (tensor<2x2xf64>) -> tensor<2x2xf64>
// %result: [
// [0.0, 1.63312e+16],
// [0.0, 5.44375e+15]
// ]
tanh
Semantik
Melakukan operasi tangen hiperbolik per elemen pada tensor operand dan menghasilkan tensor result. Bergantung pada jenis elemen, lakukan tindakan berikut:
- Untuk float:
tanhdari IEEE-754. - Untuk bilangan kompleks: tangen hiperbolik kompleks.
- Untuk jenis yang dikuantisasi:
dequantize_op_quantize(tanh, operand, type(result)).
Input
| Label | Nama | Jenis | Batasan |
|---|---|---|---|
| (I1) | operand |
tensor jenis floating-point atau kompleks atau tensor terkuantisasi per-tensor | (C1) |
Output
| Nama | Jenis | Batasan |
|---|---|---|
result |
tensor jenis floating-point atau kompleks atau tensor terkuantisasi per-tensor | (C1) |
Batasan
- (C1)
baseline_type(operand) = baseline_type(result).
Contoh
// %operand: [-1.0, 0.0, 1.0]
%result = "stablehlo.tanh"(%operand) : (tensor<3xf32>) -> tensor<3xf32>
// %result: [-0.76159416, 0.0, 0.76159416]
memindahkan
Semantik
Memindahkan dimensi tensor operand menggunakan permutation dan menghasilkan tensor result. Secara lebih formal, result[result_index] = operand[operand_index]
dengan result_index[d] = operand_index[permutation[d]].
Input
| Label | Nama | Jenis | Batasan |
|---|---|---|---|
| (I1) | operand |
tensor atau tensor terkuantisasi | (C1-C4) |
| (I2) | permutation |
Konstanta tensor 1 dimensi berjenis si64 |
(C2-C4) |
Output
| Nama | Jenis | Batasan |
|---|---|---|
result |
tensor atau tensor terkuantisasi | (C1), (C3-C4) |
Batasan
- (C1)
element_type(result)diberikan oleh:element_type(operand), jika!is_per_axis_quantized(operand).element_type(operand)kecualiquantization_dimension(operand)danquantization_dimension(result)mungkin berbeda.
- (C2)
permutationadalah permutasi darirange(rank(operand)). - (C3)
shape(result) = dim(operand, permutation...). - (C4) If
is_per_axis_quantized(result), thenquantization_dimension(operand) = permutation(quantization_dimension(result)).
Contoh
// %operand: [
// [[1,2], [3,4], [5,6]],
// [[7,8], [9,10], [11,12]]
// ]
%result = "stablehlo.transpose"(%operand) {
permutation = array<i64: 2, 1, 0>
} : (tensor<2x3x2xi32>) -> tensor<2x3x2xi32>
// %result: [
// [[1,7], [3,9], [5,11]],
// [[2,8], [4,10], [6,12]]
// ]
triangular_solve
Semantik
Menyelesaikan batch sistem persamaan linear dengan matriks koefisien segitiga bawah atau atas.
Secara lebih formal, mengingat a dan b, result[i0, ..., iR-3, :, :] adalah solusi
untuk op(a[i0, ..., iR-3, :, :]) * x = b[i0, ..., iR-3, :, :] saat left_side adalah
true atau x * op(a[i0, ..., iR-3, :, :]) = b[i0, ..., iR-3, :, :] saat
left_side adalah false, yang memecahkan variabel x saat op(a) ditentukan
oleh transpose_a, yang dapat berupa salah satu dari berikut ini:
NO_TRANSPOSE: Lakukan operasi menggunakanaapa adanya.TRANSPOSE: Lakukan operasi pada transposisia.ADJOINT: Melakukan operasi pada transpose konjugata.
Data input hanya dibaca dari segitiga bawah a, jika lower adalah true atau
segitiga atas a, jika tidak. Data output ditampilkan dalam segitiga yang sama;
nilai dalam segitiga lainnya ditentukan oleh implementasi.
Jika unit_diagonal benar, maka penerapan dapat mengasumsikan bahwa elemen
diagonal a sama dengan 1, jika tidak, perilakunya tidak ditentukan.
Untuk jenis yang dikuantisasi, melakukan
dequantize_op_quantize(lambda x, y: triangular_solve(x, y, left_side, lower,
unit_diagonal, transpose_a), a, b, type(result)).
Input
| Label | Nama | Jenis | Batasan |
|---|---|---|---|
| (I1) | a |
tensor jenis floating-point atau kompleks atau tensor terkuantisasi per-tensor | (C1-C3) |
| (I2) | b |
tensor jenis floating-point atau kompleks atau tensor terkuantisasi per-tensor | (C1-C4) |
| (I3) | left_side |
konstanta jenis i1 |
(C3) |
| (I4) | lower |
konstanta jenis i1 |
|
| (I5) | unit_diagonal |
konstanta jenis i1 |
|
| (I6) | transpose_a |
enum NO_TRANSPOSE, TRANSPOSE, dan ADJOINT |
Output
| Nama | Jenis | Batasan |
|---|---|---|
result |
tensor jenis floating-point atau kompleks atau tensor terkuantisasi per-tensor | (C1) |
Batasan
- (C1)
baseline_element_type(a) = baseline_element_type(b). - (C2)
2 <= rank(a) = rank(b) = R. - (C3) Hubungan antara
shape(a)danshape(b)ditentukan sebagai berikut:shape(a)[:-3] = shape(b)[:-3].dim(a, -2) = dim(a, -1) = dim(b, left_side ? -2 : -1).
- (C4)
baseline_type(b) = baseline_type(result).
Contoh
// %a = [
// [1.0, 0.0, 0.0],
// [2.0, 4.0, 0.0],
// [3.0, 5.0, 6.0]
// ]
// %b = [
// [2.0, 0.0, 0.0],
// [4.0, 8.0, 0.0],
// [6.0, 10.0, 12.0]
// ]
%result = "stablehlo.triangular_solve"(%a, %b) {
left_side = true,
lower = true,
unit_diagonal = false,
transpose_a = #stablehlo<transpose NO_TRANSPOSE>
} : (tensor<3x3xf32>, tensor<3x3xf32>) -> tensor<3x3xf32>
// %result: [
// [2.0, 0.0, 0.0],
// [0.0, 2.0, 0.0],
// [0.0, 0.0, 2.0]
// ]
tuple
Semantik
Menghasilkan tuple result dari nilai val.
Input
| Label | Nama | Jenis | Batasan |
|---|---|---|---|
| (I1) | val |
jumlah nilai variadik | (C1) |
Output
| Nama | Jenis | Batasan |
|---|---|---|
result |
tuple | (C1) |
Batasan
- (C1)
resultmemiliki jenistuple<E0, ..., EN-1>denganEi = type(val[i]).
Contoh
// %val0: memref[1.0, 2.0]
// %val1: (3)
%result = "stablehlo.tuple"(%val0, %val1) : (memref<2xf32>, tuple<tensor<i32>>) -> tuple<memref<2xf32>, tuple<tensor<i32>>>
// %result: (memref[1.0, 2.0], (3))
uniform_dequantize
Semantik
Melakukan konversi per elemen tensor terkuantisasi operand ke tensor floating point result sesuai dengan parameter kuantisasi yang ditentukan oleh jenis operand.
Secara lebih formal, result = dequantize(operand).
Input
| Label | Nama | Jenis | Batasan |
|---|---|---|---|
| (I1) | operand |
tensor terkuantisasi | (C1), (C2) |
Output
| Nama | Jenis | Batasan |
|---|---|---|
result |
tensor jenis floating-point | (C1), (C2) |
Batasan
- (C1)
shape(operand) = shape(result). - (C2)
element_type(result) = expressed_type(operand).
Contoh
// %operand: [10, 10]
%result = "stablehlo.uniform_dequantize"(%operand) : (tensor<2x!quant.uniform<i8:f32:0, {0.1:-30,0.5:-20}>>) -> tensor<2xf32>
// %result: [4.0, 15.0]
uniform_quantize
Semantik
Melakukan konversi per elemen tensor floating point atau tensor terkuantisasi
operand menjadi tensor terkuantisasi result sesuai dengan parameter
kuantisasi yang ditentukan oleh jenis result.
Secara lebih formal,
- Jika
is_float(operand):result = quantize(operand, type(result)).
- Jika
is_quantized(operand):float_result = dequantize(operand).result = quantize(float_result, type(result)).
Input
| Label | Nama | Jenis | Batasan |
|---|---|---|---|
| (I1) | operand |
tensor jenis floating-point atau terkuantisasi | (C1), (C2) |
Output
| Nama | Jenis | Batasan |
|---|---|---|
result |
tensor terkuantisasi | (C1), (C2) |
Batasan
- (C1)
shape(operand) = shape(result). - (C2)
expressed_type(result) = is_float(operand) ? element_type(operand) : expressed_type(operand).
Contoh
// %operand: [4.0, 15.0]
%result = "stablehlo.uniform_quantize"(%operand) : (tensor<2xf32>) -> tensor<2x!quant.uniform<i8:f32:0, {0.1:-30,0.5:-20}>>
// %result: [10, 10]
// %operand: [10, 10]
%result = "stablehlo.uniform_quantize"(%operand) : (tensor<2x!quant.uniform<i8:f32:0, {0.1:-30,0.5:-20}>>) -> tensor<2x!quant.uniform<i8:f32:0, {0.1:-20,0.2:-30}>>
// %result: [20, 45]
saat
Semantik
Menghasilkan output dari eksekusi fungsi body 0 kali atau lebih, sementara fungsi
cond menghasilkan true. Secara lebih formal, semantik dapat dinyatakan
menggunakan sintaksis Python sebagai berikut:
internal_state = operand
while cond(*internal_state):
internal_state = body(*internal_state)
results = internal_state
Perilaku loop tak terbatas akan ditentukan nanti (#383).
Input
| Label | Nama | Jenis | Batasan |
|---|---|---|---|
| (I1) | operand |
jumlah nilai variadik | (C1-C3) |
| (I2) | cond |
fungsi | (C1) |
| (I3) | body |
fungsi | (C2) |
Output
| Nama | Jenis | Batasan |
|---|---|---|
results |
jumlah nilai variadik | (C3) |
Batasan
- (C1)
condmemiliki jenis(T0, ..., TN-1) -> tensor<i1>, denganTi = type(operand[i]). - (C2)
bodymemiliki jenis(T0, ..., TN-1) -> (T0, ..., TN-1), denganTi = type(operand[i]). - (C3)
type(results...) = type(operand...).
Contoh
// %init_i: 1
// %init_sum: 0
// %one: 1
// %ten: 10
%results0, %results1 = "stablehlo.while"(%init_i, %init_sum) ({
^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
%cond = "stablehlo.compare"(%arg0, %ten) {
comparison_direction = #stablehlo<comparison_direction LT>
} : (tensor<i64>, tensor<i64>) -> tensor<i1>
stablehlo.return %cond : tensor<i1>
}, {
^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
%new_sum = stablehlo.add %arg1, %one : tensor<i64>
%new_i = stablehlo.add %arg0, %one : tensor<i64>
stablehlo.return %new_i, %new_sum : tensor<i64>, tensor<i64>
}) : (tensor<i64>, tensor<i64>) -> (tensor<i64>, tensor<i64>)
// %results0: 10
// %results1: 10
xor
Semantik
Melakukan XOR per elemen dari dua tensor lhs dan rhs serta menghasilkan tensor result. Bergantung pada jenis elemen, lakukan tindakan berikut:
- Untuk boolean: XOR logis.
- Untuk bilangan bulat: bitwise XOR.
Input
| Label | Nama | Jenis | Batasan |
|---|---|---|---|
| (I1) | lhs |
tensor jenis boolean atau bilangan bulat | (C1) |
| (I2) | rhs |
tensor jenis boolean atau bilangan bulat | (C1) |
Output
| Nama | Jenis | Batasan |
|---|---|---|
result |
tensor jenis boolean atau bilangan bulat | (C1) |
Batasan
- (C1)
type(lhs) = type(rhs) = type(result).
Contoh
// Bitwise operation with with integer tensors
// %lhs: [[1, 2], [3, 4]]
// %rhs: [[5, 6], [7, 8]]
%result = "stablehlo.xor"(%lhs, %rhs) : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
// %result: [[4, 4], [4, 12]]
// Logical operation with with boolean tensors
// %lhs: [[false, false], [true, true]]
// %rhs: [[false, true], [false, true]]
%result = "stablehlo.xor"(%lhs, %rhs) : (tensor<2x2xi1>, tensor<2x2xi1>) -> tensor<2x2xi1>
// %result: [[false, true], [true, false]]
Interop Dialek
Saat ini, program StableHLO di luar sana terkadang berisi operasi yang tidak ditentukan oleh StableHLO.
Modul, Fungsi, Panggilan, dan Pengembalian
StableHLO menggunakan operasi MLIR upstream untuk ModuleOp, FuncOp, CallOp, dan ReturnOp. Hal ini dilakukan untuk interoperabilitas yang lebih baik dengan mekanisme MLIR yang ada, karena banyak pass yang berguna ditulis dengan menargetkan FuncOp dan ModuleOp, dan banyak pipeline kompilasi mengharapkan kehadiran operasi ini. Jaminan kompatibilitas penuh diterapkan pada operasi ini. Jika ada perubahan yang tidak kompatibel (yaitu penghapusan) pada operasi ini, padanan StableHLO akan ditambahkan untuk mempertahankan kompatibilitas.
CHLO
Opset CHLO berisi operasi tingkat yang lebih tinggi yang didekomposisi ke StableHLO. Saat ini tidak ada jaminan kompatibilitas untuk CHLO. Untuk jaminan kompatibilitas, chlo-legalize-to-stablehlo pass harus digunakan sebelum serialisasi.
Operasi Bentuk
Penggunaan operasi tertentu dari dialek MLIR inti dalam program StableHLO dinamis untuk melakukan komputasi bentuk adalah kasus penggunaan umum dalam komunitas.
Biasanya, ini mencakup operasi shape dialek
seperti shape_of atau num_elements, operasi tensor dialek
seperti dim atau from_elements, dan jenis index bawaan.
RFC Dinamisme > O2
menandai hal ini sebagai di luar cakupan, tetapi beberapa dukungan untuk jenis index disertakan untuk tujuan interoperabilitas. Tidak ada jaminan kompatibilitas untuk operasi atau jenis ini. Penerusan shape-legalize-to-stablehlo
dapat digunakan untuk mengonversi operasi ini menjadi operasi StableHLO yang didukung sepenuhnya.
Operasi yang Tidak Digunakan Lagi
Ada beberapa operasi StableHLO yang diwarisi dari MHLO yang tidak digunakan lagi dan akan dihapus dari StableHLO. Detail lengkap tentang penghapusan ini dapat ditemukan di Pembersihan StableHLO v1.0 #2283. Masalah pelacak untuk penghentian penggunaan ini adalah #2340.
Operasi ini dikelompokkan ke dalam beberapa kategori:
- Kategori "Tidak ada di HLO" untuk operasi StableHLO - awalnya merupakan bagian dari
opset StableHLO, tetapi kemudian dianggap tidak cocok:
broadcast,create_token,cross-replica-sum,dot,einsum,torch_index_select,unary_einsum(#3). - Operasi yang tidak digunakan - Operasi ini mungkin berguna pada suatu waktu, tetapi operasi tersebut kurang dikembangkan, atau pipeline yang menggunakan operasi ini telah difaktorkan ulang sehingga tidak memerlukannya lagi. Ini mencakup perbandingan
map,tuple(#598),get_tuple_element,rng,complex#560, dan konvolusiwindow_reversal(#1181).
Beberapa operasi ini dapat dihapus dengan mudah karena dapat dinyatakan menggunakan
operasi yang ada (broadcast, create_token, cross-replica-sum, dot,
unary_einsum) dan akan dihapus setelah periode kompatibilitas yang ada
berakhir (6 bulan). Yang lainnya masih dalam proses penghapusan (einsum, get_tuple_element, map, rng torch_index_select, tuple, complex perbandingan, window_reversal). Menunggu masukan dari komunitas, operasi ini akan dihapus atau ditambahkan ke spesifikasi dengan dukungan penuh. Hingga masa depan operasi ini diketahui, masa depan operasi ini hanya dijamin kompatibel selama 6 bulan.
Eksekusi
Eksekusi berurutan
Program StableHLO dieksekusi dengan memberikan nilai input ke fungsi main
dan menghitung nilai output. Nilai output fungsi dihitung dengan
mengeksekusi grafik operasi yang berakar pada operasi return yang sesuai.
Urutan eksekusi ditentukan oleh implementasi selama selaras dengan
alur data, yaitu jika operasi dieksekusi sebelum digunakan. Di StableHLO, semua operasi yang menimbulkan efek samping menggunakan satu token dan menghasilkan satu token (beberapa token dapat di-multiplex menjadi satu token melalui after_all), sehingga urutan eksekusi efek samping juga selaras dengan alur data. Misalnya, dalam program di bawah ini
ada dua kemungkinan urutan eksekusi: %0 → %1 → %2 → return dan
%1 → %0 → %2 → return.
func.func @main() -> tensor<f64> {
%0 = stablehlo.constant dense<1.0> : tensor<f64>
%1 = stablehlo.constant dense<2.0> : tensor<f64>
%2 = stablehlo.add %0, %1 : tensor<f64>
return %2 : tensor<f64>
}
Secara lebih formal, proses StableHLO adalah kombinasi dari:
1) program StableHLO, 2) status operasi (belum dieksekusi,
sudah dieksekusi), dan 3) nilai perantara yang sedang diproses.
Proses dimulai dengan nilai input ke fungsi main, berlanjut melalui
grafik operasi yang memperbarui status operasi dan nilai perantara, serta
berakhir dengan nilai output. Formalisasi lebih lanjut akan ditentukan nanti
(#484).
Eksekusi paralel
Program StableHLO dapat dieksekusi secara paralel, yang disusun ke dalam petak proses 2D
num_replicas kali num_partitions yang keduanya memiliki jenis ui32.
Dalam petak proses StableHLO, num_replicas * num_partitions proses StableHLO dijalankan secara bersamaan. Setiap proses memiliki
process_id = (replica_id, partition_id) unik, dengan
replica_id dalam replica_ids = range(num_replicas) dan
partition_id dalam partition_ids = range(num_partitions) yang keduanya memiliki
jenis ui32.
Ukuran petak proses diketahui secara statis untuk setiap program (pada masa mendatang, kami berencana menjadikannya bagian eksplisit dari program StableHLO #650), dan posisi dalam petak proses diketahui secara statis untuk setiap proses. Setiap proses memiliki
akses ke posisinya dalam petak proses melalui operasi replica_id dan
partition_id.
Dalam petak proses, semua program dapat sama (dalam gaya "Satu Program, Beberapa Data"), dapat berbeda (dalam gaya "Beberapa Program, Beberapa Data"), atau di antaranya. Pada masa mendatang, kami berencana untuk memperkenalkan dukungan untuk idiom lain dalam menentukan program StableHLO paralel, termasuk GSPMD (#619).
Dalam petak proses, proses sebagian besar independen satu sama lain - proses memiliki status operasi terpisah, nilai input/perantara/output terpisah dan sebagian besar operasi dijalankan secara terpisah di antara proses, dengan pengecualian sejumlah kecil operasi kolektif yang dijelaskan di bawah.
Mengingat bahwa eksekusi sebagian besar operasi hanya menggunakan nilai dari proses yang sama, biasanya tidak ambigu untuk merujuk ke nilai ini berdasarkan namanya.
Namun, saat mendeskripsikan semantik operasi kolektif, hal itu tidak cukup, dan
hal itu memunculkan notasi name@process_id untuk merujuk ke nilai name
dalam proses tertentu. (Dari perspektif tersebut, name yang tidak memenuhi syarat dapat
dilihat sebagai singkatan dari name@(replica_id(), partition_id())).
Urutan eksekusi di seluruh proses ditentukan oleh implementasi, kecuali untuk sinkronisasi yang diperkenalkan oleh komunikasi point-to-point dan operasi kolektif seperti yang dijelaskan di bawah.
Komunikasi titik ke titik
Proses StableHLO dapat berkomunikasi satu sama lain melalui
channel StableHLO. Channel diwakili oleh ID positif berjenis
si64. Melalui berbagai operasi, nilai dapat dikirim ke channel dan diterima dari channel.
Formalisasi lebih lanjut, misalnya dari mana ID channel ini berasal, bagaimana proses program mengetahuinya, dan jenis sinkronisasi apa yang diperkenalkan olehnya, akan ditentukan kemudian (#484).
Komunikasi streaming
Setiap proses StableHLO memiliki akses ke dua antarmuka streaming:
- Infeed yang dapat dibaca.
- Outfeed yang dapat ditulis.
Tidak seperti channel, yang digunakan untuk berkomunikasi antar-proses dan oleh karena itu memiliki proses di kedua ujungnya, infeeds dan outfeeds memiliki ujung lainnya yang ditentukan implementasinya.
Formalisasi lebih lanjut, misalnya, bagaimana komunikasi streaming memengaruhi urutan eksekusi dan jenis sinkronisasi apa yang diperkenalkan olehnya, akan ditentukan kemudian (#484).
Operasi kolektif
Ada enam operasi kolektif di StableHLO: all_gather, all_reduce,
all_to_all, collective_broadcast, collective_permute, dan
reduce_scatter. Semua operasi ini membagi proses dalam petak proses StableHLO menjadi grup proses StableHLO dan mengeksekusi komputasi gabungan dalam setiap grup proses, secara terpisah dari grup proses lainnya.
Dalam setiap grup proses, operasi kolektif dapat memperkenalkan penghalang sinkronisasi. Formalisasi lebih lanjut, misalnya, menjelaskan kapan sinkronisasi ini terjadi, bagaimana proses mencapai penghalang ini, dan apa yang terjadi jika tidak, akan ditentukan kemudian (#484).
Jika grup proses melibatkan komunikasi lintas partisi, yaitu ada proses dalam grup proses yang ID partisinya berbeda, maka eksekusi operasi kolektif memerlukan saluran, dan operasi kolektif harus memberikan channel_id positif berjenis si64. Komunikasi lintas replika tidak memerlukan
saluran.
Komputasi yang dilakukan oleh operasi kolektif bersifat khusus untuk setiap operasi dan dijelaskan di bagian operasi individual di atas. Namun, strategi yang digunakan untuk membagi petak proses menjadi grup proses digunakan bersama oleh operasi ini dan dijelaskan di bagian ini. Secara lebih formal, StableHLO mendukung empat strategi berikut.
cross_replica
Hanya komunikasi lintas replika yang terjadi dalam setiap grup proses. Strategi
ini mengambil replica_groups - daftar daftar ID replika - dan menghitung
produk Cartesian dari replica_groups dengan partition_ids. replica_groups
harus memiliki elemen unik dan mencakup semua replica_ids. Secara lebih formal, menggunakan
sintaksis Python:
def cross_replica(replica_groups: List[List[ReplicaId]]) -> List[List[ProcessId]]:
for replica_group in replica_groups:
for partition_id in partition_ids:
process_group = []
for replica_id in replica_group:
process_group.append((replica_id, partition_id))
yield process_group
Misalnya, untuk replica_groups = [[0, 1], [2, 3]] dan num_partitions = 2,
cross_replica akan menghasilkan
[[(0, 0), (1, 0)], [(0, 1), (1, 1)], [(2, 0), (3, 0)], [(2, 1), (3, 1)]].
cross_partition
Hanya komunikasi lintas partisi yang terjadi dalam setiap grup proses. Strategi ini menggunakan partition_groups - daftar daftar ID partisi - dan menghitung produk Cartesian dari partition_groups dengan replica_ids.
partition_groups harus memiliki elemen unik dan mencakup semua partition_ids.
Secara lebih formal, menggunakan sintaksis Python:
def cross_partition(partition_groups: List[List[PartitionId]]) -> List[List[ProcessId]]:
for partition_group in partition_groups:
for replica_id in replica_ids:
process_group = []
for partition_id in partition_group:
process_group.append((replica_id, partition_id))
yield process_group
Misalnya, untuk partition_groups = [[0, 1]] dan num_replicas = 4,
cross_partition akan menghasilkan
[[(0, 0), (0, 1)], [(1, 0), (1, 1)], [(2, 0), (2, 1)], [(3, 0), (3, 1)]].
cross_replica_and_partition
Komunikasi lintas replika dan lintas partisi dapat terjadi dalam setiap grup proses. Strategi ini mengambil replica_groups - daftar daftar
ID replika - dan menghitung produk Cartesian dari setiap replica_group menurut
partition_ids. replica_groups harus memiliki elemen unik dan mencakup semua
replica_ids. Secara lebih formal, menggunakan sintaksis Python:
def cross_replica_and_partition(replica_groups: List[List[ReplicaId]]) -> List[List[ProcessId]]:
for replica_group in replica_groups:
process_group = []
for partition_id in partition_ids:
for replica_id in replica_group:
process_group.append((replica_id, partition_id))
yield process_group
Misalnya, untuk replica_groups = [[0, 1], [2, 3]] dan num_partitions = 2,
cross_replica_and_partition akan menghasilkan
[[(0, 0), (1, 0), (0, 1), (1, 1)], [(2, 0), (3, 0), (2, 1), (3, 1)]].
flattened_ids
Strategi ini menggunakan flattened_id_groups - daftar daftar ID proses "diratakan"
dalam bentuk replica_id * num_partitions + partition_id - dan
mengubahnya menjadi ID proses. flattened_id_groups harus memiliki elemen unik
dan mencakup semua process_ids. Secara lebih formal, menggunakan sintaksis Python:
def flattened_ids(flattened_id_groups: List[List[ui32]]) -> List[List[ProcessId]]:
for flattened_id_group in flattened_id_groups:
process_group = []
for flattened_id in flattened_id_group:
replica_id = flattened_id // num_partitions
partition_id = flattened_id % num_partitions
process_group.append((replica_id, partition_id))
yield process_group
Misalnya, untuk flattened_id_groups = [[0, 1, 2, 3], [4, 5, 6, 7]],
num_replicas = 4, dan num_partitions = 2, flattened_ids akan menghasilkan
[[(0, 0), (0, 1), (1, 0), (1, 1)], [(2, 0), (2, 1), (3, 0), (3, 1)]].
Akurasi
Saat ini, StableHLO tidak memberikan jaminan tentang akurasi numerik, tetapi hal ini dapat berubah pada masa mendatang (#1156).
Semantik eksekusi operasi terkuantisasi
Interpretasi operasi StableHLO yang dikuantisasi dapat bervariasi bergantung pada persyaratan dan kemampuan hardware. Misalnya, beberapa hardware dapat memilih untuk menafsirkan operasi terkuantisasi menggunakan strategi "dekuantisasi, lakukan operasi titik mengambang, dan terakhir kuantisasi". Yang lain mungkin melakukan seluruh penghitungan dengan aritmetika bilangan bulat. Oleh karena itu, interpretasi operasi StableHLO yang dikuantisasi ditentukan secara eksklusif oleh implementasi tertentu. Interpretasi kuantisasi hybrid (#1575) harus didasarkan pada semantiknya seperti yang ditetapkan dalam spesifikasi (melalui 1792).
Error
Program StableHLO divalidasi melalui serangkaian batasan yang ekstensif untuk setiap operasi, yang menghilangkan banyak kelas error sebelum waktu proses. Namun, kondisi error masih mungkin terjadi, misalnya melalui overflow bilangan bulat, akses di luar batas, dll. Kecuali dinyatakan secara eksplisit, semua error ini menghasilkan perilaku yang ditentukan implementasi, tetapi hal ini dapat berubah pada masa mendatang (#1157).
Pengecualian floating point
Sebagai pengecualian untuk aturan ini, pengecualian floating point dalam program StableHLO memiliki perilaku yang ditentukan dengan baik. Operasi yang menghasilkan pengecualian yang ditentukan oleh standar IEEE-754 (operasi tidak valid, pembagian dengan nol, overflow, underflow, atau pengecualian tidak tepat) menghasilkan hasil default (sebagaimana ditentukan dalam standar) dan melanjutkan eksekusi tanpa memunculkan tanda status yang sesuai; mirip dengan penanganan pengecualian raiseNoFlag dari standar. Pengecualian untuk operasi nonstandar (misalnya, aritmetika kompleks dan fungsi transendental tertentu) ditentukan oleh implementasi.
Bentuk tidak cocok
StableHLO mendukung tensor berbentuk dinamis. Namun, bentuk harus disetujui saat runtime, jika tidak, perilakunya tidak ditentukan. StableHLO tidak secara eksplisit menyediakan operasi yang dapat menegaskan bahwa tensor memiliki bentuk tertentu saat runtime. Membuat kode yang benar adalah tanggung jawab produser.
Sebagai contoh spesifik, program di bawah ini valid. Namun, saat runtime, bentuk %arg0 dan %arg1 harus sama persis, jika tidak, perilaku program tidak akan ditentukan:
func.func @foo(%arg0: tensor<?xi32>, %arg1: tensor<?xi32>) -> tensor<?xi32> {
%0 = stablehlo.add %arg0, %arg1 : tensor<?xi32>
return %0 : tensor<?xi32>
}
Notasi
Untuk menjelaskan sintaksis, dokumen ini menggunakan sintaksis EBNF versi ISO yang telah dimodifikasi (ISO/IEC 14977:1996,
Wikipedia),
dengan dua modifikasi: 1) aturan ditentukan menggunakan ::=, bukan =,
2) penggabungan dinyatakan menggunakan berdampingan, bukan ,.
Untuk mendeskripsikan semantik (yaitu dalam bagian "Types", "Constants", dan "Ops"), kami menggunakan formula yang didasarkan pada sintaksis Python yang diperluas dengan dukungan untuk secara ringkas mengekspresikan operasi array seperti yang dijelaskan di bawah. Cara ini berfungsi dengan baik untuk cuplikan kode kecil, tetapi dalam kasus yang jarang terjadi ketika cuplikan kode yang lebih besar diperlukan, kami menggunakan sintaksis Python standar yang selalu diperkenalkan secara eksplisit.
Formula
Mari kita pelajari cara kerja formula berdasarkan contoh dari spesifikasi dot_general. Salah satu batasan untuk operasi ini terlihat sebagai berikut:
dim(lhs, lhs_batching_dimensions...) = dim(rhs, rhs_batching_dimensions...).
Nama yang digunakan dalam formula ini berasal dari dua sumber: 1) fungsi global,
yaitu dim, 2) definisi anggota elemen program yang sesuai, yaitu
input lhs, lhs_batching_dimensions, rhs, dan rhs_batching_dimensions
yang ditentukan di bagian "Input" dot_general.
Seperti yang disebutkan di atas, sintaksis formula ini berbasis Python dengan beberapa ekstensi yang berorientasi pada ringkasan. Untuk memahami formula, mari kita ubah formula tersebut menjadi sintaksis Python standar.
A) Dalam formula ini, kita menggunakan = untuk merepresentasikan kesamaan, jadi langkah pertama
untuk mendapatkan sintaksis Python adalah mengganti = dengan ==, sebagai berikut:
dim(lhs, lhs_batching_dimensions...) == dim(rhs, rhs_batching_dimensions...).
B) Selain itu, formula ini mendukung elipsis (...) yang mengubah ekspresi skalar
menjadi ekspresi tensor. Singkatnya, f(xs...) secara kasar berarti "untuk setiap
skalar x dalam tensor xs, hitung skalar f(x), lalu tampilkan semua
hasil skalar ini bersama-sama sebagai hasil tensor". Dalam sintaksis Python standar,
contoh formula kita akan menjadi:
[dim(lhs, dim1) for dim1 in lhs_batching_dimensions] ==
[dim(rhs, dim2) for dim2 in rhs_batching_dimensions].
Berkat elipsis, Anda sering kali dapat menghindari bekerja di tingkat skalar individual. Namun, dalam beberapa kasus yang rumit, sintaksis semi-informal tingkat bawah dapat digunakan seperti dalam formula start_indices[bi0, ..., :, ..., biN] dari spesifikasi gather. Untuk menjaga keringkasan, kami tidak
memberikan formalisme yang tepat untuk menerjemahkan sintaksis tersebut ke Python standar, dengan
harapan bahwa sintaksis tersebut masih dapat dipahami secara intuitif berdasarkan kasus per kasus.
Beri tahu kami jika beberapa formula tertentu terlihat tidak jelas, dan kami akan mencoba memperbaikinya.
Selain itu, Anda akan melihat bahwa formula menggunakan elipsis untuk memperluas semua jenis daftar, termasuk tensor, daftar tensor (yang misalnya dapat muncul dari sejumlah tensor variadik), dll. Ini adalah area lain tempat kita tidak memberikan formalisme yang tepat (misalnya, daftar bahkan bukan bagian dari sistem jenis StableHLO) dan mengandalkan pemahaman intuitif.
C) Sarana notasi penting terakhir yang kita gunakan adalah penyiaran implisit. Meskipun opset StableHLO tidak mendukung penyiaran implisit, rumus tersebut mendukungnya, juga untuk tujuan ringkas. Singkatnya, jika skalar digunakan dalam konteks yang memerlukan tensor, skalar akan di-broadcast ke bentuk yang diharapkan.
Untuk melanjutkan contoh dot_general, berikut batasan lainnya:
0 <= lhs_batching_dimensions < rank(lhs). Seperti yang ditentukan dalam spesifikasi dot_general, lhs_batching_dimensions adalah tensor, tetapi 0 dan rank(lhs) adalah skalar. Setelah kita menerapkan penyiaran implisit, formula akan menjadi [0, ..., 0] <= lhs_batching_dimensions < [rank(lhs), ..., rank(lhs)].
Jika diterapkan pada operasi dot_general tertentu, formula ini akan
dievaluasi menjadi tensor boolean. Jika formula digunakan sebagai batasan, batasan akan berlaku jika formula dievaluasi menjadi true atau tensor yang hanya memiliki elemen true.
Nama
Dalam formula, cakupan leksikal mencakup: 1) fungsi global, 2) definisi anggota,
3) definisi lokal. Daftar fungsi global disediakan di bawah. Daftar definisi elemen bergantung pada elemen program yang diterapkan notasi:
- Untuk operasi, definisi anggota mencakup nama yang diperkenalkan di bagian "Input" dan "Output".
- Untuk hal lainnya, definisi anggota mencakup bagian struktural elemen program, yang dinamai sesuai dengan non-terminal EBNF yang sesuai. Biasanya, nama bagian struktural ini diperoleh dengan mengonversi nama non-terminal ke snake case (misalnya,
IntegerLiteral=>integer_literal), tetapi terkadang nama disingkat dalam prosesnya (misalnya,QuantizationStorageType=>storage_type). Dalam hal ini, nama diperkenalkan secara eksplisit seperti bagian "Input" / "Output" dalam spesifikasi operasi. - Selain itu, definisi anggota selalu menyertakan
selfuntuk merujuk ke elemen program yang sesuai.
Nilai
Saat dievaluasi, formula berfungsi dengan jenis nilai berikut:
1) Value (nilai sebenarnya, misalnya dense<[[1, 2], [3, 4]]> : tensor<2x2xi32>;
nilai ini selalu mengetahui jenisnya),
2) Placeholder (nilai mendatang, misalnya lhs, rhs, atau result; nilai sebenarnya belum diketahui, hanya jenisnya yang diketahui),
3) Type (jenis seperti yang ditentukan di bagian "Jenis"),
4) Function (fungsi global seperti yang ditentukan di bagian "Fungsi").
Bergantung pada konteksnya, nama dapat merujuk pada nilai yang berbeda. Lebih
khususnya, bagian "Semantik" untuk operasi (dan yang setara untuk elemen program lainnya) menentukan logika runtime, sehingga semua input tersedia sebagai Value.
Sebaliknya, bagian "Batasan" untuk operasi (dan yang setara) menentukan logika "waktu kompilasi", yaitu sesuatu yang biasanya dieksekusi sebelum runtime, sehingga hanya input konstan yang tersedia sebagai Value dan input lainnya hanya tersedia sebagai Placeholder.
| Nama | Di "Semantik" | Di "Batasan" |
|---|---|---|
| Fungsi global | Function |
Function |
| Input konstan | Value |
Value |
| Input non-konstan | Value |
Placeholder |
| Output | Value |
Placeholder |
| Definisi lokal | Bergantung pada definisi | Bergantung pada definisi |
Mari kita lihat contoh operasi transpose:
%result = "stablehlo.transpose"(%operand) {
permutation = dense<[2, 1, 0]> : tensor<3xi64>
} : (tensor<2x3x2xi32>) -> tensor<2x3x2xi32>
Untuk operasi ini, permutation adalah konstanta, sehingga tersedia sebagai Value
dalam semantik dan batasan. Sebaliknya, operand dan result tersedia sebagai Value dalam semantik, tetapi hanya sebagai Placeholder dalam batasan.
Fungsi
Konstruksi jenis
Tidak ada fungsi yang dapat digunakan untuk membuat jenis. Sebagai gantinya, kita langsung menggunakan sintaksis jenis karena biasanya lebih ringkas. Misalnya,
(tensor<E>, tensor<E>) -> (tensor<E>), bukan function_type(
[tensor_type([], E), tensor_type([], E)], [tensor_type([], E)]).
Fungsi pada jenis
element_typeditentukan pada jenis tensor dan jenis tensor terkuantisasi, serta masing-masing menampilkan bagianTensorElementTypeatauQuantizedTensorElementTypedariTensorTypeatauQuantizedTensorTypeyang sesuai.
def element_type(x: Value | Placeholder | Type):
if type(x) == TensorType:
return tensor_element_type(x)
if type(x) == QuantizedTensorType:
return quantized_tensor_element_type(x)
if type(x) is not Type:
return element_type(type(x))
is_per_axis_quantized(x: Value | Placeholder | Type) -> Valueadalah pintasan untukis_quantized(x) and quantization_dimension(x) is not None.is_per_tensor_quantized(x: Value | Placeholder | Type) -> Valueadalah pintasan untukis_quantized(x) and quantization_dimension(x) is None.is_promotable(x: Type, y: Type) -> boolmemeriksa apakah jenisxdapat dipromosikan ke jenisy. JikaxdanyadalahQuantizedTensorElementType, promosi hanya diterapkan padastorage_type. Versi promosi khusus ini saat ini digunakan dalam konteks komputasi pengurangan (lihat RFC untuk mengetahui detail selengkapnya).
def is_promotable(x: Type, y: Type) -> Value:
is_same_type = (is_bool(x) and is_bool(y)) or
(is_integer(x) and is_integer(y)) or (is_float(x) and is_float(y)) or
(is_complex(x) and is_complex(y)) or
(is_quantized(x) and is_quantized(y) and expressed_type(x) = expressed_type(y))
if is_same_type == False:
return False
if is_integer(x) or is_float(x):
return bitwidth(x) <= bitwidth(y)
if is_complex(x):
return bitwidth(element_type(x)) <= bitwidth(element_type(y))
if is_quantized(x):
return bitwidth(storage_type(x)) <= bitwidth(storage_type(y))
return false
is_quantized(x: Value | Placeholder | Type) -> Valueadalah pintasan untukis_quantized_tensor_element_type(x).is_type_name(x: Value | Placeholder | Type) -> Value. Tersedia untuk semua jenis. Misalnya,is_float(x)akan menampilkantruejikaxadalahFloatType. Jikaxadalah nilai atau placeholder, fungsi ini adalah pintasan untukis_type_name(type(x)).max_value(x: Type) -> Valuemenampilkan nilai maksimumTensorElementType. JikaxbukanTensorElementType,Noneakan ditampilkan.min_value(x: Type) -> Valuemenampilkan nilai minimum yang mungkin dariTensorElementType. JikaxbukanTensorElementType,Noneakan ditampilkan.member_name(x: Value | Placeholder | Type) -> Any. Tersedia untuk semua definisi anggotamember_namedari semua jenis. Misalnya,tensor_element_type(x)menampilkan bagianTensorElementTypedariTensorTypeyang sesuai. Jikaxadalah nilai atau placeholder, fungsi ini adalah pintasan untukmember_name(type(x)). Jikaxbukan jenis yang memiliki anggota yang sesuai, atau nilai atau placeholder dari jenis tersebut,Noneakan ditampilkan.is_empty_algorithm(*args: Type)memeriksa apakah semua kolom algoritma titik ditetapkan keNone. Hal ini diperlukan karena algoritma titik memiliki perilaku default yang ditentukan implementasinya, sehingga menentukan nilai default akan salah.
Konstruksi nilai
operation_name(*xs: Value | Type) -> Value. Tersedia untuk semua operasi. Misalnya,add(lhs, rhs)mengambil dua nilai tensorlhsdanrhs, lalu menampilkan output evaluasi operasiadddengan input ini. Untuk beberapa operasi, misalnyabroadcast_in_dim, jenis outputnya adalah "load-bearing", yaitu diperlukan untuk mengevaluasi operasi. Dalam hal ini, fungsi mengambil jenis ini sebagai argumen.
Fungsi pada nilai
Semua operator dan fungsi Python tersedia. Misalnya, notasi langganan dan pengirisan dari Python tersedia untuk mengindeks tensor, tensor terkuantisasi dan tuple.
to_destination_type(x: Value, destination_type: Type) -> Valueditentukan pada tensor dan menampilkan nilaixyang dikonversi berdasarkantype(x)dandestination_typesebagai berikut:
def to_destination_type(x: Value, destination_type: Type) -> Value:
if type(x) == destination_type:
return x
if is_quantized(destination_type):
if is_quantized(type(x)):
return quantize(x, destination_type)
assert is_float(type(x))
return quantize(x, destination_type)
if is_quantized(type(x)):
assert destination_type = expressed_type(type(x))
return dequantize(type(x))
return convert(x, destination_type)
Ada diskusi awal tentang menggabungkan operasi convert, uniform_quantize, dan
uniform_dequantize (#1576).
Setelah penggabungan, kita tidak memerlukan fungsi di atas dan dapat menggunakan nama operasi
untuk convert.
is_nan(x: Value) -> Valueditentukan pada tensor dan menampilkantruejika semua elemenxadalahNaNataufalse. Jikaxbukan tensor,Noneakan ditampilkan.is_sorted(x: Value) -> Valueditentukan pada tensor dan menampilkantruejika elemenxdiurutkan dalam urutan menaik sehubungan dengan urutan leksikografis menaik dari indeksnya ataufalsejika tidak. Jikaxbukan tensor,Noneakan ditampilkan.is_unique(x: Value) -> Valueditentukan pada tensor dan menampilkantruejikaxtidak memiliki elemen duplikat ataufalsejika tidak. Jikaxbukan tensor,Noneakan ditampilkan.member_name(x: Value) -> Anyditentukan untuk semua definisi anggotamember_namedari semua nilai. Misalnya,real_part(x)menampilkan bagianRealPartdariComplexConstantyang sesuai. Jikaxbukan nilai yang memiliki anggota yang sesuai,Noneakan ditampilkan.same(x: Value) -> Valueditentukan pada tensor dan menampilkantruejika elemenxsemuanya sama satu sama lain ataufalsejika sebaliknya. Jika tensor tidak memiliki elemen, hal itu dianggap sebagai "semua sama satu sama lain", yaitu fungsi menampilkantrue. Jikaxbukan tensor,Noneakan ditampilkan.split(x: Value, num_results: Value, axis: Value) -> Valueditentukan pada tensor dan menampilkan slicenum_resultsdarixdi sepanjang sumbuaxis. Jikaxbukan tensor ataudim(x, axis) % num_results != 0,Noneakan ditampilkan.is_defined_in_parent_scope(x: Value) -> Valueditentukan pada string dan menampilkantruejikaxadalah nama fungsi yang ditentukan dalam cakupan yang sama dengan fungsi induk dari operasi yang relevan.is_namespaced_op_name(x: Value) -> Valueditentukan pada string dan menampilkantruejikaxadalah nama operasi yang valid, yaitu mematuhi ekspresi reguler berikut:[a-zA-Z][a-zA-Z0-9_]*([.][a-zA-Z0-9_$]+)+
Komputasi bentuk
axes(x: Value | Placeholder | Type) -> Valueadalah pintasan untukrange(rank(x)).dim(x: Value | Placeholder | Type, axis: Value) -> Valueadalah pintasan untukshape(x)[axis].dims(x: Value | Placeholder | Type, axes: List) -> Listadalah pintasan untuklist(map(lambda axis: dim(x, axis), axes)).index_space(x: Value | Placeholder | Type) -> Valueditentukan pada tensor dan menampilkan indekssize(x)untukTensorTypeyang sesuai yang diurutkan dalam urutan leksikografis menaik, yaitu[0, ..., 0],[0, ..., 1], ...,shape(x) - 1. Jikaxbukan jenis tensor, jenis tensor terkuantisasi, atau nilai atau placeholder dari salah satu jenis ini,Noneakan ditampilkan.rank(x: Value | Placeholder | Type) -> Valueadalah pintasan untuksize(shape(x)).shape(x: Value | Placeholder | Type) -> Valueditentukan di bagian "Fungsi pada jenis" melaluimember_name.size(x: Value | Placeholder | Type) -> Valueadalah pintasan untukreduce(lambda x, y: x * y, shape(x)).
Komputasi kuantisasi
def baseline_element_type(x: Value | Placeholder | Type) -> Typeadalah pintasan untukelement_type(baseline_type(x)).baseline_typeditentukan pada jenis tensor dan jenis tensor terkuantisasi serta mengubahnya menjadi "dasar", yaitu jenis dengan bentuk yang sama, tetapi dengan parameter kuantisasi jenis elemen direset ke nilai default. Hal ini digunakan sebagai trik praktis untuk membandingkan jenis tensor dan tensor terkuantisasi secara seragam, yang cukup sering diperlukan. Untuk jenis yang dikuantisasi, hal ini memungkinkan membandingkan jenis dengan mengabaikan parameter kuantisasi, yaitu,shape,storage_type,expressed_type,storage_min,storage_max, danquantization_dimension(untuk jenis yang dikuantisasi per sumbu) harus cocok, tetapiscalesdanzero pointsmungkin berbeda.
def baseline_type(x: Value | Placeholder | Type) -> Type:
if type(x) == TensorType:
return x
if type(x) == QuantizedTensorType:
element_type = quantized_tensor_element_type(x)
baseline_element_type = QuantizedTensorElementType(
storage_type = storage_type(element_type),
storage_min = storage_min(element_type),
storage_max = storage_max(element_type),
expressed_type = expressed_type(element_type),
quantization_dimension = quantization_dimension(element_type),
scales = [constant(1.0, expressed_type(element_type))] * dim(x, quantization_dimension(element_type)),
zero_points = [constant(0, storage_type(element_type))] * dim(x, quantization_dimension(element_type)))
return QuantizedTensorType(shape(x), baseline_element_type)
if type(x) is not Type:
return baseline_element_type(type(x))
dequantizeditentukan pada jenis tensor terkuantisasi dan mengubahnya menjadi jenis tensor floating point. Hal ini terjadi melalui konversi elemen terkuantisasi yang merepresentasikan nilai bilangan bulat dari jenis penyimpanan ke nilai floating point yang sesuai dari jenis yang dinyatakan menggunakan titik nol dan skala yang terkait dengan jenis elemen terkuantisasi.
def compute_zero_points(quantized_type, result_type):
if is_per_tensor_quantized(quantized_type):
return broadcast_in_dim(constant(zero_point(quantized_type), storage_type(quantized_type)), [], result_type)
if is_per_axis_quantized(quantized_type):
for i in index_space(result_type):
d = quantization_dimension(quantized_type)
zero_points[i] = zero_points(quantized_type)[i[d]]
return zero_points
def compute_scales(quantized_type, result_type):
if is_per_tensor_quantized(quantized_type):
return broadcast_in_dim(constant(scale(quantized_type), expressed_type(quantized_type)), [],
type(result_type))
if is_per_axis_quantized(quantized_type):
for i in index_space(result_type):
d = quantization_dimension(quantized_type)
scales[i] = scales(quantized_type)[i[d]]
return scales
def dequantize(x: Value) -> Value:
assert is_quantized(x)
x_storage = bitcast_convert(x, storage_type(x))
x_storage_sub = x_storage - compute_zero_points(type(x), type(x_storage))
x_expressed_sub = convert(x_storage_sub, expressed_type(x))
return x_expressed_sub * compute_scales(type(x), type(x_expressed_sub))
quantizeditentukan pada jenis tensor floating point dan mengubahnya menjadi jenis tensor terkuantisasi. Hal ini terjadi melalui konversi nilai floating point dari jenis yang dinyatakan menjadi nilai bilangan bulat yang sesuai dari jenis penyimpanan menggunakan titik nol dan skala yang terkait dengan jenis elemen terkuantisasi.
def quantize(x: Value, result_type: Type) -> Value:
assert is_float(x) and is_quantized(result_type)
zero_points = compute_zero_points(result_type, TensorType(shape(x), storage_type(result_type)))
converted_zero_points = convert(zero_points, expressed_type(result_type))
converted_min = convert(storage_min(result_type), expressed_type(result_type))
converted_max = convert(storage_max(result_type), expressed_type(result_type))
x_scaled = x / compute_scales(result_type, type(x))
x_scaled_add_zp = x_scaled + converted_zero_points
x_clamped = clamp(converted_min, x_scaled_add_zp, converted_max)
x_rounded = round_nearest_even(x_clamped)
return convert(x_rounded, result_type)
dequantize_op_quantizedigunakan untuk menentukan komputasi per elemen pada tensor terkuantisasi. Dequantisasi, yaitu mengubah elemen terkuantisasi menjadi jenis yang dinyatakan, lalu melakukan operasi, dan kemudian menguantisasi, yaitu mengubah hasil kembali menjadi jenis penyimpanannya. Saat ini, fungsi ini hanya berfungsi untuk kuantisasi per-tensor. Kuantisasi per sumbu sedang dalam proses (#1574).
def dequantize_op_quantize(op, *inputs_and_output_type):
inputs = inputs_and_output_type[:-1]
output_type = inputs_and_output_type[-1]
float_inputs = map(dequantize, inputs)
float_result = op(*float_inputs)
return quantize(float_result, output_type)
def dequantize_batch_norm_grad_or_training_quantize(op, *inputs_and_output_types):
inputs = inputs_and_output_type[:-3]
float_inputs = map(dequantize, inputs)
float_results = op(*float_inputs)
return map(quantize, float_results, inputs_and_output_type[-3:])
def dequantize_compare(lhs, rhs, comparison_direction):
float_lhs = dequantize(lhs)
float_rhs = dequantize(rhs)
return compare(float_lhs, float_rhs, comparison_direction, FLOAT)
def dequantize_select_quantize(pred, on_true, on_false, output_type):
float_on_true = dequantize(on_true)
float_on_false = dequantize(on_false)
float_result = select(pred, float_on_true, float_on_false)
return quantize(float_result, output_type)
hybrid_dequantize_then_opdigunakan untuk menentukan kuantisasi khusus bobot untuk operasi hibrida yang menerima lhs dalam floating point dan rhs dalam jenis yang dikuantisasi. Operasi ini mendekuantisasi input terkuantisasi ke dalam jenis yang dinyatakan dan melakukan komputasi dalam float. Jenis elemen tensor lhs float dan jenis yang dinyatakan dari tensor rhs yang dikuantisasi harus identik.
def hybrid_dequantize_then_op(op, lhs, rhs):
assert(is_float(lhs) and is_quantized(rhs) and element_type(lhs) == expressed_type(rhs))
return op(lhs, dequantize(rhs))
Komputasi petak
cross_partition(replica_groups: Value) -> Value. Lihat bagian "cross_replica" di atas.cross_replica(replica_groups: Value) -> Value. Lihat bagian "cross_replica" di atas.cross_replica_and_partition(replica_groups: Value) -> Value. Lihat bagian "cross_replica_and_partition" di atas.flattened_ids(replica_groups: Value) -> Value. Lihat bagian "flattened_ids" di atas.
Dinamisme
Nilai StableHLO dapat memiliki ukuran dimensi dinamis, misalnya tensor<?xi64>.
Namun, nilai StableHLO tidak dapat memiliki jumlah dimensi yang dinamis (dinamisme tidak berperingkat, misalnya tensor<*xi64>). Operan dan hasil diizinkan menggunakan ukuran dimensi dinamis, meskipun ada batasan pada ukuran. Batasan akan diverifikasi secara statis jika memungkinkan, jika tidak, batasan akan ditangguhkan ke runtime dan ketidakcocokan akan menyebabkan perilaku yang tidak terdefinisi. Lihat contoh berikut.
Ketidakcocokan bentuk untuk operasi elementwise unari
Pertimbangkan program mainan berikut:
func.func @foo(%arg0: tensor<?xf64>) {
%0 = stablehlo.abs %arg0 : (tensor<?xf64>) -> tensor<2xf64>
return
}
Program semacam ini tidak biasa, karena biasanya kita mengetahui bentuk hasilnya, tetapi tidak mengetahui bentuk inputnya. Namun demikian, ini adalah program StableHLO
yang valid. Operasi abs tidak dapat divalidasi secara statis dalam program ini karena bentuk pasti operand tidak diketahui. Namun, bentuknya pasti kompatibel, dan hal ini dapat diperiksa secara statis: ? dapat berubah menjadi 2 saat runtime, dan tidak akan ada masalah. Namun, ? juga bisa
berupa bilangan bulat lain, yang dalam hal ini perilakunya tidak ditentukan.
Perhatikan bahwa jika ukuran dimensi bersifat dinamis dalam hasil, tidak boleh ada perilaku yang tidak ditentukan. Memang, tidak ada ukuran "yang diharapkan", jadi tidak mungkin ada ketidakcocokan.
Ketidakcocokan bentuk untuk operasi biner per elemen
Pertimbangkan program mainan berikut:
func.func @foo(%arg0: tensor<?xf64>, %arg1: tensor<?xf64>) {
%0 = stablehlo.add %arg0, %arg0 : (tensor<?xf64>, tensor<?xf64>) -> tensor<?xf64>
return
}
Untuk operasi biner per elemen, bentuk input dan hasil harus sama saat runtime. Pada waktu kompilasi, dimensi statis harus sama, jika tidak, dimensi tersebut hanya perlu kompatibel. Jika ada dimensi yang dinamis dalam input, maka perilaku yang tidak ditentukan dapat terjadi saat runtime, karena ukuran dinamis mungkin tidak cocok dengan ukuran yang sesuai dalam operand lain (baik statis maupun dinamis). Jika semua input bersifat statis, maka apakah hasilnya dinamis atau tidak tidak menjadi masalah: dimensi yang diketahui secara statis akan diperiksa secara statis, dan dimensi dinamis tidak memaksakan batasan apa pun.
Ketidakcocokan bentuk untuk operasi yang menggunakan bentuk outputnya sebagai operand
Pertimbangkan program mainan berikut:
func.func @foo(%arg0: tensor<2xi32>) {
%0 = stablehlo.dynamic_iota %arg0, dim = 0 : (tensor<2xi32>) -> tensor<3x4xi64>
return
}
Nilai dalam operand bentuk saat runtime harus cocok dengan bentuk hasil,
jika tidak, perilakunya tidak ditentukan. Artinya, saat runtime, %arg0 harus memiliki
nilai dense<[3, 4]> : tensor<2xi32>. Jika operand bentuk adalah konstanta, hal ini dapat diverifikasi secara statis. Jika bentuk hasilnya sepenuhnya dinamis, maka tidak boleh ada ketidakcocokan.