StableHLO adalah kumpulan operasi untuk operasi tingkat tinggi (HLO) dalam model machine learning (ML). StableHLO berfungsi sebagai lapisan portabilitas antara berbagai framework ML dan compiler ML: framework ML yang menghasilkan program StableHLO kompatibel dengan compiler ML yang menggunakan program StableHLO.
Sasaran kami adalah menyederhanakan dan mempercepat pengembangan ML dengan menciptakan lebih banyak interoperabilitas antara berbagai framework ML (seperti TensorFlow, JAX, dan PyTorch) serta compiler ML (seperti XLA dan IREE). Untuk mencapai tujuan tersebut, dokumen ini memberikan spesifikasi untuk bahasa pemrograman StableHLO.
Spesifikasi ini memuat tiga bagian utama. Pertama, bagian Program menjelaskan struktur program StableHLO yang terdiri dari fungsi StableHLO yang terdiri dari operasi StableHLO. Dalam struktur tersebut, bagian Ops menentukan semantik setiap operasi. Bagian Execution menyediakan semantik untuk semua operasi ini yang dijalankan bersama dalam program. Terakhir, bagian Notasi membahas notasi yang digunakan di seluruh spesifikasi.
Untuk melihat spesifikasi dari rilis StableHLO sebelumnya, buka repo di rilis yang diberi tag minat. Misalnya, Spesifikasi StableHLO v0.19.0. Untuk melihat perubahan yang terjadi pada setiap peningkatan versi minor StableHLO, lihat log versi di VhloDialect.td.
Program
Program ::= {Func}
Program StableHLO terdiri dari sejumlah fungsi StableHLO yang tidak ditentukan.
Berikut adalah contoh program dengan fungsi @main
yang memiliki 3 input
(%image
, %weights
, dan %bias
) dan 1 output. Isi fungsi memiliki 6 operasi.
func.func @main(
%image: tensor<28x28xf32>,
%weights: tensor<784x10xf32>,
%bias: tensor<1x10xf32>
) -> tensor<1x10xf32> {
%0 = "stablehlo.reshape"(%image) : (tensor<28x28xf32>) -> tensor<1x784xf32>
%1 = "stablehlo.dot"(%0, %weights) : (tensor<1x784xf32>, tensor<784x10xf32>) -> tensor<1x10xf32>
%2 = "stablehlo.add"(%1, %bias) : (tensor<1x10xf32>, tensor<1x10xf32>) -> tensor<1x10xf32>
%3 = "stablehlo.constant"() {value = dense<0.0> : tensor<1x10xf32>} : () -> tensor<1x10xf32>
%4 = "stablehlo.maximum"(%2, %3) : (tensor<1x10xf32>, tensor<1x10xf32>) -> tensor<1x10xf32>
"func.return"(%4): (tensor<1x10xf32>) -> ()
}
Fungsi
Func ::= 'func' '.' 'func' FuncId FuncInputs FuncOutputs '{' FuncBody '}'
FuncInputs ::= '(' [FuncInput {',' FuncInput}] `)`
FuncInput ::= ValueId ':' ValueType
FuncOutputs ::= ['->' FuncOutput, {',' FuncOutput}]
FuncOutput ::= ValueType
FuncBody ::= {Op}
Fungsi StableHLO (yang juga disebut fungsi bernama) memiliki ID, input/output, dan isi. Di masa mendatang, kami berencana untuk memperkenalkan metadata tambahan untuk fungsi guna mencapai kompatibilitas yang lebih baik dengan HLO (#425, #626, #740, #744).
Pengenal
FuncId ::= '@' letter {letter | digit}
ValueId ::= '%' digit {digit}
| '%' letter {letter | digit}
letter ::= 'a' | ... | 'z' | 'A' | ... | 'Z' | '_'
digit ::= '0' | ... | '9'
ID StableHLO mirip dengan ID dalam banyak bahasa pemrograman, dengan dua keunikan: 1) semua ID memiliki sigil yang membedakan berbagai jenis ID, 2) ID nilai dapat sepenuhnya numerik untuk menyederhanakan pembuatan program StableHLO.
Jenis
Type ::= ValueType | NonValueType
ValueType ::= TensorType | QuantizedTensorType | TokenType | TupleType
NonValueType ::= TensorElementType | QuantizedTensorElementType | FunctionType | StringType
Jenis StableHLO dikategorikan ke dalam jenis nilai (yang juga disebut jenis kelas satu) yang mewakili nilai StableHLO dan jenis non-nilai yang menjelaskan elemen program lainnya. Jenis StableHLO mirip dengan jenis dalam banyak bahasa pemrograman, dengan keunikan utamanya adalah sifat khusus domain StableHLO yang menghasilkan beberapa hasil yang tidak biasa (misalnya, jenis skalar bukan merupakan jenis nilai).
TensorType ::= 'tensor' '<' Shape TensorElementType '>'
Shape ::= {DimensionSize 'x'}
DimensionSize ::= digit {digit} | '?'
Jenis tensor merepresentasikan tensor, yaitu array multidimensi. Elemen ini memiliki
bentuk dan jenis elemen, dengan bentuk mewakili ukuran dimensi non-negatif atau
tidak diketahui dalam urutan menaik dari
dimensi yang sesuai (yang juga disebut sumbu) yang diberi nomor dari 0
hingga R-1
. Jumlah
dimensi R
disebut peringkat. Misalnya, tensor<2x3xf32>
adalah
jenis tensor dengan bentuk 2x3
dan jenis elemen f32
. Matriks ini memiliki dua dimensi
(atau, dengan kata lain, dua sumbu) - dimensi ke-0 dan dimensi ke-1 - yang ukurannya
adalah 2 dan 3. Peringkatnya adalah 2.
Bentuk dapat sebagian atau sepenuhnya tidak diketahui (dinamis), misalnya tensor<?x2xf64>
sebagian tidak diketahui dan tensor<?x?xf64>
sepenuhnya tidak diketahui. Ukuran dimensi
dinamis direpresentasikan menggunakan ?
. Bentuk tidak boleh diberi peringkat.
Di masa mendatang, kami berencana untuk mempelajari perluasan jenis tensor di luar ukuran dimensi dan jenis elemen, misalnya, untuk menyertakan tata letak (#629) dan sparsitas (#1078).
QuantizedTensorType ::= 'tensor' '<' Shape QuantizedTensorElementType '>'
QuantizedTensorElementType ::= '!quant.uniform' '<'
QuantizationStorageType
['<' QuantizationStorageMin ':' QuantizationStorageMax '>']
':' QuantizationExpressedType
[':' QuantizationDimension]
',' QuantizationParameters '>'
QuantizationStorageType ::= IntegerType
QuantizationStorageMin ::= IntegerLiteral
QuantizationStorageMax ::= IntegerLiteral
QuantizationExpressedType ::= FloatType
QuantizationDimension ::= IntegerLiteral
QuantizationParameters ::= QuantizationParameter
| '{' QuantizationParameter {',' QuantizationParameter} '}'
QuantizationParameter ::= QuantizationScale [':' QuantizationZeroPoint]
QuantizationScale ::= FloatLiteral
QuantizationZeroPoint ::= IntegerLiteral
Nama | Jenis | Batasan |
---|---|---|
storage_type |
jenis bilangan bulat | (C1-C3), (C8) |
storage_min |
konstanta bilangan bulat | (C1), (C3), (C7) |
storage_max |
konstanta bilangan bulat | (C2), (C3), (C7) |
expressed_type |
jenis floating point | (C4) |
quantization_dimension |
konstanta bilangan bulat opsional | (C10-C12) |
scales |
jumlah variabel konstanta floating point | (C4-C6), (C9), (C10), (C13) |
zero_points |
jumlah variabel konstanta integer | (C7-C9) |
Jenis elemen kuantisasi mewakili nilai bilangan bulat dari jenis penyimpanan dalam
rentang dari storage_min
hingga storage_max
(inklusif) yang sesuai dengan
nilai floating point dari jenis yang dinyatakan. Untuk nilai bilangan bulat i
tertentu,
nilai floating point yang sesuai dengan f
dapat dikomputasi sebagai
f = (i - zero_point) * scale
, dengan scale
dan zero_point
disebut
parameter kuantisasi. storage_min
dan storage_max
bersifat opsional
dalam tata bahasa, tetapi memiliki nilai default min_value(storage_type)
dan
max_value(storage_type)
. Jenis elemen terkuantisasi memiliki
batasan berikut:
- (C1)
type(storage_min) = storage_type
. - (C2)
type(storage_max) = storage_type
. - (C3)
min_value(storage_type) <= storage_min < storage_max <= max_value(storage_type)
. - (C4)
type(scales...) = expressed_type
. - (C5)
0 < scales
. - (C6)
is_finite(scales...)
. - (C7)
storage_min <= zero_points <= storage_max
. - (C8)
type(zero_points...) = storage_type
. - (C9)
size(scales) = size(zero_points)
. - (C10) Jika
is_empty(quantization_dimension)
, makasize(scales) = 1
. - (C11)
0 <= quantization_dimension
.
Saat ini, QuantizationScale
adalah konstanta floating point, tetapi ada
minat yang kuat pada skala berbasis bilangan bulat, yang direpresentasikan dengan pengali dan
pergeseran. Kami berencana untuk mempelajarinya dalam waktu dekat
(#1404).
Ada diskusi yang sedang berlangsung tentang semantik QuantizationZeroPoint
,
termasuk jenis, nilai, dan apakah hanya ada satu atau
berpotensi beberapa titik nol dalam jenis tensor kuantisasi. Berdasarkan
hasil diskusi ini, spesifikasi seputar titik nol dapat berubah
di masa mendatang (#1405).
Diskusi lain yang sedang berlangsung melibatkan semantik QuantizationStorageMin
dan QuantizationStorageMax
untuk menentukan apakah batasan apa pun harus
diterapkan pada nilai ini dan pada nilai tensor kuantisasi
(#1406).
Terakhir, kami berencana untuk mengeksplorasi representasi skala dan titik nol yang tidak diketahui, mirip dengan cara kami berencana untuk mengeksplorasi representasi ukuran dimensi yang tidak diketahui (#1407).
Jenis tensor terkuantisasi merepresentasikan tensor dengan elemen terkuantisasi. Tensor ini sama persis dengan tensor reguler, kecuali elemennya memiliki jenis elemen yang dikuantisasi, bukan jenis elemen reguler.
Dalam tensor kuantisasi, kuantisasi dapat berupa per-tensor, yang berarti memiliki
satu scale
dan zero_point
untuk seluruh tensor atau dapat berupa per-sumbu,
yang berarti memiliki beberapa scales
dan zero_points
, satu pasangan per slice
dimensi tertentu quantization_dimension
. Secara lebih formal, dalam tensor t
dengan kuantisasi per sumbu, ada dim(t, quantization_dimension)
slice
dari quantization_dimension
: t[:, ..., 0, ..., :], t[:, ..., 1, ..., :]
,
dll. Semua elemen dalam slice ke-i
menggunakan scales[i]
dan zero_points[i]
sebagai
parameter kuantisasi. Jenis tensor terkuantisasi memiliki batasan
berikut:
- Untuk kuantisasi per tensor:
- Tidak ada batasan tambahan.
- Untuk kuantisasi per sumbu:
- (C12)
quantization_dimension < rank(self)
. - (C13)
dim(self, quantization_dimension) = size(scales)
.
- (C12)
TokenType ::= 'token'
Jenis token mewakili token, yaitu nilai buram yang dihasilkan dan digunakan oleh beberapa operasi. Token digunakan untuk menerapkan urutan eksekusi pada operasi seperti yang dijelaskan di bagian Eksekusi.
TupleType ::= 'tuple' '<' TupleElementTypes '>'
TupleElementTypes ::= [ValueType {',' ValueType}]
Jenis tuple mewakili tuple, yaitu daftar heterogen. Tuple adalah fitur
lama yang hanya ada untuk kompatibilitas dengan HLO. Di HLO, tuple
digunakan untuk merepresentasikan input dan output variadik. Di StableHLO, input dan
output variadik didukung secara native, dan satu-satunya penggunaan tuple di StableHLO adalah untuk
mewakili HLO ABI secara komprehensif, misalnya T
, tuple<T>
, dan
tuple<tuple<T>>
mungkin secara material berbeda bergantung pada
implementasi tertentu. Di masa mendatang, kami berencana untuk membuat perubahan pada HLO ABI
yang dapat memungkinkan kita menghapus jenis tuple dari StableHLO
(#598).
TensorElementType ::= BooleanType | IntegerType | FloatType | ComplexType
BooleanType ::= 'i1'
IntegerType ::= SignedIntegerType | UnsignedIntegerType
SignedIntegerType ::= 'si2' | 'si4' | 'si8' | 'si16' | 'si32' | 'si64'
UnsignedIntegerType ::= 'ui2' | 'ui4' | 'ui8' | 'ui16' | 'ui32' | 'ui64'
FloatType ::= 'f4E2M1FN' | 'f6E2M3FN' | 'f6E3M2FN' | 'f8E3M4' | 'f8E4M3'
| 'f8E4M3FN' | 'f8E4M3FNUZ' | 'f8E4M3B11FNUZ' | 'f8E5M2'
| 'f8E5M2FNUZ' | 'f8E8M0FNU' | 'bf16' | 'f16' | 'f32' | 'f64'
TensorFloat32 ::= 'tf32'
ComplexType ::= 'complex' '<' ComplexElementType '>'
ComplexElementType ::= 'f32' | 'f64'
Jenis elemen mewakili elemen jenis tensor. Tidak seperti di banyak bahasa
pemrograman, jenis ini bukan kelas satu di StableHLO. Artinya,
program StableHLO tidak dapat langsung merepresentasikan nilai dari jenis ini (akibatnya,
idiomatik untuk merepresentasikan nilai skalar dari jenis T
dengan nilai tensor
dimensi 0 dari jenis tensor<T>
).
- Jenis boolean mewakili nilai boolean
true
danfalse
. - Jenis bilangan bulat dapat bertanda (
si
) atau tidak bertanda (ui
) dan memiliki salah satu lebar bit yang didukung (2
,4
,8
,16
,32
, atau64
). JenissiN
bertanda mewakili nilai bilangan bulat dari-2^(N-1)
hingga2^(N-1)-1
inklusif, dan jenisuiN
tanpa tanda mewakili nilai bilangan bulat dari0
hingga2^N-1
inklusif. - Jenis floating point dapat berupa salah satu dari berikut:
- Bilangan floating point 8-bit
f8E3M4
,f8E4M3
, danf8E5M2
mengikuti konvensi IEEE-754. - Jenis
f8E4M3FN
danf8E5M2
masing-masing sesuai dengan encodingE4M3
danE5M2
dari format FP8 yang dijelaskan dalam Format FP8 untuk Deep Learning. - Jenis
f8E4M3FNUZ
danf8E5M2FNUZ
yang sesuai dengan encodingE4M3
danE5M2
dari format FP8 yang dijelaskan dalam Format Numerik 8-bit untuk Jaringan Neural Dalam. - Jenis
f8E4M3B11FNUZ
yang sesuai dengan encodingE4M3
dari format FP8 yang dijelaskan dalam Pelatihan dan Inferensi Floating Point 8-bit Hybrid (HFP8) untuk Deep Neural Networks. - Jenis
bf16
yang sesuai dengan formatbfloat16
yang dijelaskan dalam BFloat16: Rahasia performa tinggi di Cloud TPU. - Jenis
f16
,f32
, danf64
masing-masing sesuai dengan formatbinary16
("presisi setengah"),binary32
("presisi tunggal"), danbinary64
("presisi ganda") yang dijelaskan dalam standar IEEE 754. - Jenis
tf32
sesuai dengan format TensorFloat32 dan memiliki dukungan terbatas di StableHLO. - Jenis MX (penskalaan mikro)
f4E2M1FN
,f6E2M3FN
,f6E3M2FN
, danf8E8M0FNU
dijelaskan dalam Spesifikasi Format Penskalaan Mikro OCP.
- Bilangan floating point 8-bit
- Jenis kompleks mewakili nilai kompleks yang memiliki bagian riil
dan bagian imajiner dari jenis elemen yang sama. Jenis kompleks
yang didukung adalah
complex<f32>
(kedua bagiannya berjenisf32
) dancomplex<f64>
(kedua bagiannya berjenisf64
).
FunctionType ::= '(' InputTypes ')' '->' '(' OutputTypes ')'
InputTypes ::= [ValueType {',' ValueType}]
OutputTypes ::= [ValueType {',' ValueType}]
Jenis fungsi mewakili fungsi bernama dan anonim. Fungsi ini memiliki
jenis input (daftar jenis di sisi kiri ->
) dan jenis output
(daftar jenis di sisi kanan ->
). Dalam banyak bahasa
pemrograman, jenis fungsi adalah kelas satu, tetapi tidak di StableHLO.
StringType ::= 'string'
Jenis string mewakili urutan byte. Tidak seperti dalam banyak bahasa pemrograman, jenis string bukan kelas pertama di StableHLO dan hanya digunakan untuk menentukan metadata statis untuk elemen program.
Operasi
Operasi StableHLO (yang juga disebut ops) mewakili kumpulan tertutup operasi tingkat tinggi dalam model machine learning. Seperti yang telah dibahas di atas, sintaksis StableHLO sangat terinspirasi oleh MLIR, yang belum tentu merupakan alternatif paling ergonomis, tetapi dapat dibilang paling sesuai dengan sasaran StableHLO untuk membuat lebih banyak interoperabilitas antara framework ML dan compiler ML.
Op ::= [OpOutputs] OpName OpInputs ':' OpSignature
OpName ::= '"' 'stablehlo' '.' OpMnemonic '"'
OpMnemonic ::= 'abs' | 'add' | ...
Operasi StableHLO (yang juga disebut ops) memiliki nama,
input/output, dan tanda tangan. Nama terdiri dari awalan stablehlo.
dan
mnemonik yang secara unik mengidentifikasi salah satu operasi yang didukung. Lihat di bawah untuk mengetahui daftar lengkap semua operasi yang didukung.
OpInputs ::= OpInputValues OpInputFuncs OpInputAttrs
OpInputValues ::= '(' [OpInputValue {',' OpInputValue}] ')'
OpInputValue ::= ValueId
OpInputFuncs ::= ['(' OpInputFunc {',' OpInputFunc} ')']
OpInputAttrs ::= ['{' OpInputAttr {',' OpInputAttr} '}']
OpOutputs ::= [OpOutput {',' OpOutput} '=']
OpOutput ::= ValueId
Ops menggunakan input dan menghasilkan output. Input dikategorikan ke dalam nilai input (dihitung selama eksekusi), fungsi input (disediakan secara statis, karena dalam StableHLO bukan merupakan nilai kelas satu), dan atribut input (juga disediakan secara statis). Jenis input dan output
yang digunakan dan dihasilkan oleh op bergantung pada mnemoninya. Misalnya, operasi add
menggunakan 2 nilai input dan menghasilkan 1 nilai output. Sebagai perbandingan,
op select_and_scatter
menggunakan 3 nilai input, 2 fungsi input, dan
3 atribut input.
OpInputFunc ::= '{' Unused FuncInputs ':' FuncBody '}'
Unused ::= '^' digit {digit}
| '^' letter {letter | digit}
Fungsi input (yang juga disebut fungsi anonim) sangat mirip dengan fungsi bernama, kecuali: 1) tidak memiliki ID (sehingga namanya "anonim"), 2) tidak mendeklarasikan jenis output (jenis output disimpulkan dari operasi return
dalam fungsi).
Sintaksis untuk fungsi input mencakup bagian yang saat ini tidak digunakan (lihat
produksi Unused
di atas) yang ada untuk kompatibilitas dengan MLIR. Di MLIR,
ada konsep "region" yang lebih umum yang dapat memiliki beberapa "blok"
operasi yang terhubung bersama melalui operasi lompat. Blok ini memiliki ID yang sesuai
dengan produksi Unused
, sehingga dapat dibedakan satu sama lain.
StableHLO tidak memiliki operasi lompat, sehingga bagian yang sesuai dari sintaksis MLIR tidak digunakan (tetapi masih ada).
OpInputAttr ::= OpInputAttrName '=' OpInputAttrValue
OpInputAttrName ::= letter {letter | digit}
OpInputAttrValue ::= Constant
Atribut input memiliki nama dan nilai yang merupakan salah satu konstanta yang didukung. Class ini adalah cara utama dalam menentukan metadata statis untuk elemen
program. Misalnya, operasi concatenate
menggunakan atribut dimension
untuk
menentukan dimensi tempat nilai inputnya digabungkan. Demikian pula,
op slice
menggunakan beberapa atribut seperti start_indices
dan limit_indices
untuk menentukan batas yang digunakan untuk memotong nilai input.
Saat ini, program StableHLO di dunia nyata terkadang berisi atribut yang tidak dijelaskan dalam dokumen ini. Di masa mendatang, kami berencana untuk menyerap atribut ini ke dalam opset StableHLO atau melarangnya muncul dalam program StableHLO. Sementara itu, berikut adalah daftar atribut ini:
layout
(#629).mhlo.frontend_attributes
(#628).mhlo.sharding
(#619).output_operand_aliases
(#740).- Metadata lokasi (#594).
OpSignature ::= '(' [ValueType {',' ValueType}] ')' '->' '(' [ValueType {',' ValueType}] ')'
Tanda tangan op terdiri dari jenis semua nilai input (daftar jenis di sisi kiri ->
) dan jenis semua nilai output (daftar jenis di sisi kanan ->
). Sederhananya, jenis input bersifat redundan, dan jenis output hampir selalu redundan (karena untuk sebagian besar operasi StableHLO, jenis output dapat disimpulkan dari input). Meskipun demikian, tanda tangan op sengaja menjadi bagian dari sintaksis StableHLO untuk kompatibilitas dengan MLIR.
Berikut adalah contoh operasi yang mnemoninya adalah select_and_scatter
. Fungsi ini menggunakan 3
nilai input (%operand
, %source
, dan %init_value
), 2 fungsi input,
dan 3 atribut input (window_dimensions
, window_strides
, dan padding
).
Perhatikan bagaimana tanda tangan operasi hanya menyertakan jenis nilai inputnya
(tetapi bukan jenis fungsi dan atribut input yang disediakan secara inline).
%result = "stablehlo.select_and_scatter"(%operand, %source, %init_value) ({
^bb0(%arg0: tensor<i32>, %arg1: tensor<i32>):
%0 = "stablehlo.compare"(%arg0, %arg1) {
comparison_direction = #stablehlo<comparison_direction GE>
} : (tensor<i32>, tensor<i32>) -> tensor<i1>
"stablehlo.return"(%0) : (tensor<i1>) -> ()
}, {
^bb0(%arg0: tensor<i32>, %arg1: tensor<i32>):
%0 = "stablehlo.add"(%arg0, %arg1) : (tensor<i32>, tensor<i32>) -> tensor<i32>
"stablehlo.return"(%0) : (tensor<i32>) -> ()
}) {
window_dimensions = dense<[3, 1]> : tensor<2xi64>,
window_strides = dense<[2, 1]> : tensor<2xi64>,
padding = dense<[[0, 1], [0, 0]]> : tensor<2x2xi64>
} : (tensor<4x2xi32>, tensor<2x2xi32>, tensor<i32>) -> tensor<4x2xi32>
Konstanta
Constant ::= BooleanConstant
| IntegerConstant
| FloatConstant
| ComplexConstant
| TensorConstant
| QuantizedTensorConstant
| StringConstant
| EnumConstant
Konstanta StableHLO memiliki literal dan jenis yang bersama-sama mewakili
nilai StableHLO. Umumnya, jenis ini adalah bagian dari sintaksis konstanta, kecuali
jika tidak ambigu (misalnya, konstanta boolean memiliki jenis i1
secara tidak ambigu,
sedangkan konstanta bilangan bulat dapat memiliki beberapa kemungkinan jenis).
BooleanConstant ::= BooleanLiteral
BooleanLiteral ::= 'true' | 'false'
Konstanta Boolean mewakili nilai boolean true
dan false
. Konstanta Boolean memiliki jenis i1
.
IntegerConstant ::= IntegerLiteral ':' IntegerType
IntegerLiteral ::= ['-' | '+'] DecimalDigits
| ['-' | '+'] '0x' HexadecimalDigits
DecimalDigits ::= decimalDigit {decimalDigit}
HexadecimalDigits ::= hexadecimalDigit {hexadecimalDigit}
decimalDigit ::= '0' | ... | '9'
hexadecimalDigit ::= decimalDigit | 'a' | ... | 'f' | 'A' | ... | 'F'
Konstanta bilangan bulat merepresentasikan nilai bilangan bulat melalui string yang menggunakan notasi desimal atau heksadesimal. Basis lain, misalnya biner atau oktal, tidak didukung. Konstanta bilangan bulat memiliki batasan berikut:
- (C1)
is_wellformed(integer_literal, integer_type)
.
FloatConstant ::= FloatLiteral ':' FloatType
FloatLiteral ::= SignPart IntegerPart FractionalPart ScientificPart
| '0x' [HexadecimalDigits]
SignPart ::= ['-' | '+']
IntegerPart ::= DecimalDigits
FractionalPart ::= ['.' [DecimalDigits]]
ScientificPart ::= [('e' | 'E') ['-' | '+'] DecimalDigits]
Konstanta floating point mewakili nilai floating point melalui string yang menggunakan notasi desimal atau notasi ilmiah. Selain itu, notasi heksadesimal dapat digunakan untuk menentukan secara langsung bit yang mendasarinya dalam format floating point dari jenis yang sesuai. Konstanta floating point memiliki batasan berikut:
- (C1) Jika notasi non-heksadesimal digunakan,
is_wellformed(float_literal, float_type)
. - (C2) Jika notasi heksadesimal digunakan,
size(hexadecimal_digits) = num_bits(float_type) / 4
.
ComplexConstant ::= ComplexLiteral ':' ComplexType
ComplexLiteral ::= '(' RealPart ',' ImaginaryPart ')'
RealPart ::= FloatLiteral
ImaginaryPart ::= FloatLiteral
Konstanta kompleks mewakili nilai kompleks menggunakan daftar bagian real
(didahulukan) dan bagian imajiner (didahulukan). Misalnya,
(1.0, 0.0) : complex<f32>
mewakili 1.0 + 0.0i
, dan
(0.0, 1.0) : complex<f32>
mewakili 0.0 + 1.0i
. Urutan penyimpanan
bagian-bagian ini dalam memori ditentukan oleh implementasi. Konstanta kompleks
memiliki batasan berikut:
- (C1)
is_wellformed(real_part, complex_element_type(complex_type))
. - (C2)
is_wellformed(imaginary_part, complex_element_type(complex_type))
.
TensorConstant ::= TensorLiteral ':' TensorType
TensorLiteral ::= 'dense' '<' (DenseLiteral | ElementLiteral) '>'
DenseLiteral ::= DenseDimension | DenseElements
DenseDimension ::= '[' [DenseLiteral {',' DenseLiteral}] ']'
DenseElements ::= [ElementLiteral {',' ElementLiteral}]
ElementLiteral ::= BooleanLiteral | IntegerLiteral | FloatLiteral | ComplexLiteral
Konstanta tensor merepresentasikan nilai tensor menggunakan daftar bertingkat yang ditentukan melalui
notasi NumPy. Misalnya, dense<[[1, 2, 3], [4, 5, 6]]> : tensor<2x3xi32>
merepresentasikan nilai tensor dengan pemetaan berikut dari indeks ke elemen:
{0, 0} => 1
, {0, 1} => 2
, {0, 2} => 3
, {1, 0} => 4
, {1, 1} => 5
,
{1, 2} => 6
. Urutan penyimpanan elemen ini dalam memori ditentukan oleh
implementasi. Konstanta tensor memiliki batasan berikut:
- (C1)
has_syntax(tensor_literal, element_type(tensor_type))
, dengan:has_syntax(element_literal: Syntax, element_type: Type) = is_wellformed(element_literal, type)
.has_syntax(tensor_literal: List, element_type: Type) = has_syntax(tensor_literal..., element_type)
.
- (C2)
has_shape(tensor_literal, shape(tensor_type))
, dengan:has_shape(element_literal: Syntax, []) = true
.has_shape(tensor_literal: List, shape: List) = size(tensor_literal) = shape[0] and has_shape(tensor_literal..., shape[1:])
.- jika tidak,
false
.
QuantizedTensorConstant ::= QuantizedTensorLiteral ':' QuantizedTensorType
QuantizedTensorLiteral ::= 'dense' '<' (DenseLiteral | ElementLiteral) '>'
Konstanta tensor kuantisasi mewakili nilai tensor kuantisasi menggunakan notasi yang sama dengan konstanta tensor, dengan elemen yang ditentukan sebagai konstanta dari jenis penyimpanannya. Konstanta tensor terkuantisasi memiliki batasan berikut:
- (C1)
has_syntax(quantized_tensor_literal, storage_type(quantized_tensor_type))
. - (C2)
has_shape(quantized_tensor_literal, shape(quantized_tensor_type))
.
StringConstant ::= StringLiteral
StringLiteral ::= '"' {stringCharacter | escapeSequence} '"'
stringCharacter ::= all ASCII characters except '\00', '\01', ... '\1f' and '"'
escapeSequence ::= '\' ('"' | '\' | 'n' | 't' | (hexadecimalDigit hexadecimalDigit))
Literal string terdiri dari byte yang ditentukan menggunakan karakter ASCII dan urutan escape. Byte ini tidak bergantung pada encoding, sehingga interpretasi
byte ini ditentukan oleh implementasi. String literal memiliki jenis string
.
Operasi
abs
Semantik
Melakukan operasi abs element-wise pada tensor operand
dan menghasilkan
tensor result
. Bergantung pada jenis elemen, lakukan tindakan berikut:
- Untuk bilangan bulat bertanda tangan: modulus bilangan bulat.
- Untuk float:
abs
dari IEEE-754. - Untuk bilangan kompleks: modulus kompleks.
- Untuk jenis terkuantisasi:
dequantize_op_quantize(abs, operand, type(result))
.
Input
Label | Nama | Jenis | Batasan |
---|---|---|---|
(I1) | operand |
tensor bilangan bulat bertanda tangan, floating point, atau jenis kompleks atau tensor terkuantisasi per tensor | (C1-C2) |
Output
Nama | Jenis | Batasan |
---|---|---|
result |
tensor dari jenis bilangan bulat bertanda atau floating point atau tensor terkuantisasi per tensor | (C1-C2) |
Batasan
- (C1)
shape(result) = shape(operand)
. - (C2)
baseline_element_type(result)
ditentukan sebagai:complex_element_type(element_type(operand))
jikais_complex(operand)
.baseline_element_type(operand)
jika tidak.
Contoh
// %operand: [-2, 0, 2]
%result = "stablehlo.abs"(%operand) : (tensor<3xi32>) -> tensor<3xi32>
// %result: [2, 0, 2]
tambahkan
Semantik
Melakukan penambahan element-wise dari dua tensor lhs
dan rhs
dan menghasilkan
tensor result
. Bergantung pada jenis elemen, lakukan tindakan berikut:
- Untuk boolean: OR logis.
- Untuk bilangan bulat: penambahan bilangan bulat.
- Untuk float:
addition
dari IEEE-754. - Untuk bilangan kompleks: penambahan kompleks.
- Untuk jenis terkuantisasi:
dequantize_op_quantize(add, lhs, rhs, type(result))
.
Input
Label | Nama | Jenis | Batasan |
---|---|---|---|
(I1) | lhs |
tensor atau tensor terkuantisasi | (C1-C6) |
(I2) | rhs |
tensor atau quantized tensor | (C1-C5), (C7) |
Output
Nama | Jenis | Batasan |
---|---|---|
result |
tensor atau tensor terkuantisasi | (C1-C7) |
Batasan
- Jika operasi menggunakan tensor yang tidak dikuantisasi:
- (C1)
type(lhs) = type(rhs) = type(result)
.
- (C1)
- Jika operasi menggunakan tensor kuantisasi:
- (C2)
is_quantized(lhs) and is_quantized(rhs) and is_quantized(result)
. - (C3)
storage_type(lhs) = storage_type(rhs) = storage_type(result)
. - (C4)
expressed_type(lhs) = expressed_type(rhs) = expressed_type(result)
. - (C5)
(is_per_axis_quantized(lhs) or is_per_axis_quantized(rhs)) = is_per_axis_quantized(result)
. - (C6) Jika
is_per_axis_quantized(lhs)
, makaquantization_dimension(lhs) = quantization_dimension(result)
. - (C7) Jika
is_per_axis_quantized(rhs)
, makaquantization_dimension(rhs) = quantization_dimension(result)
.
- (C2)
Contoh
// %lhs: [[1, 2], [3, 4]]
// %rhs: [[5, 6], [7, 8]]
%result = "stablehlo.add"(%lhs, %rhs) : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
// %result: [[6, 8], [10, 12]]
after_all
Semantik
Memastikan bahwa operasi yang menghasilkan inputs
dieksekusi sebelum
operasi apa pun yang bergantung pada result
. Eksekusi operasi ini tidak melakukan apa pun,
hanya ada untuk menetapkan dependensi data dari result
ke inputs
.
Input
Label | Nama | Jenis |
---|---|---|
(I1) | inputs |
jumlah variabel token |
Output
Nama | Jenis |
---|---|
result |
token |
Contoh
// %input0: !stablehlo.token
// %input1: !stablehlo.token
%result = "stablehlo.after_all"(%input0, %input1) : (!stablehlo.token, !stablehlo.token) -> !stablehlo.token
all_gather
Semantik
Dalam setiap grup proses di petak proses StableHLO, gabungkan nilai
tensor operands
dari setiap proses di sepanjang all_gather_dim
dan hasilkan
tensor results
.
Operasi ini membagi petak proses StableHLO menjadi process_groups
yang
ditentukan sebagai berikut:
cross_replica(replica_groups)
ifchannel_id <= 0 and use_global_device_ids = false
.cross_replica_and_partition(replica_groups)
jikachannel_id > 0 and use_global_device_ids = false
.flattened_ids(replica_groups)
ifchannel_id > 0 and use_global_device_ids = true
.
Setelah itu, dalam setiap process_group
:
operands...@receiver = [operand@sender for sender in process_group]
untuk semuareceiver
diprocess_group
.results...@process = concatenate(operands...@process, all_gather_dim)
untuk semuaprocess
diprocess_group
.
Input
Label | Nama | Jenis | Batasan |
---|---|---|---|
(I1) | operands |
jumlah variabel tensor atau tensor terkuantisasi per tensor | (C1), (C6) |
(I2) | all_gather_dim |
konstanta dari jenis si64 |
(C1), (C6) |
(I3) | replica_groups |
Konstanta tensor 2 dimensi jenis si64 |
(C2-C4) |
(I4) | channel_id |
konstanta dari jenis si64 |
(C5) |
(I5) | use_global_device_ids |
konstanta jenis i1 |
(C5) |
Output
Nama | Jenis | Batasan |
---|---|---|
results |
jumlah variabel tensor atau tensor terkuantisasi per tensor | (C6) |
Batasan
- (C1)
0 <= all_gather_dim < rank(operands...)
. - (C2)
is_unique(replica_groups)
. - (C3)
size(replica_groups)
ditentukan sebagai:num_replicas
jikacross_replica
digunakan.num_replicas
jikacross_replica_and_partition
digunakan.num_processes
jikaflattened_ids
digunakan.
- (C4)
0 <= replica_groups < size(replica_groups)
. - (C5) Jika
use_global_device_ids = true
, makachannel_id > 0
. - (C6)
type(results...) = type(operands...)
kecuali:dim(results..., all_gather_dim) = dim(operands..., all_gather_dim) * dim(process_groups, 1)
.
Contoh
// num_replicas: 2
// num_partitions: 1
// %operand0@(0, 0): [[1, 2], [3, 4]]
// %operand0@(1, 0): [[5, 6], [7, 8]]
// %operand1@(0, 0): [[11, 12], [13, 14]]
// %operand1@(1, 0): [[15, 16], [17, 18]]
%result:2 = "stablehlo.all_gather"(%operand0, %operand1) {
all_gather_dim = 1 : i64,
replica_groups = dense<[[0, 1]]> : tensor<1x2xi64>,
// channel_id = 0
channel_handle = #stablehlo.channel_handle<handle = 0, type = 0>
// use_global_device_ids = false
} : (tensor<2x2xi64>, tensor<2x2xi64>) -> (tensor<2x4xi64>, tensor<2x4xi64>)
// %result0@(0, 0): [[1, 2, 5, 6], [3, 4, 7, 8]]
// %result0@(1, 0): [[1, 2, 5, 6], [3, 4, 7, 8]]
// %result1@(0, 0): [[11, 12, 15, 16], [13, 14, 17, 18]]
// %result1@(1, 0): [[11, 12, 15, 16], [13, 14, 17, 18]]
all_reduce
Semantik
Dalam setiap grup proses di petak proses StableHLO, menerapkan fungsi pengurangan
computation
ke nilai tensor operands
dari setiap proses
dan menghasilkan tensor results
.
Operasi ini membagi petak proses StableHLO menjadi process_groups
yang
ditentukan sebagai berikut:
cross_replica(replica_groups)
ifchannel_id <= 0 and use_global_device_ids = false
.cross_replica_and_partition(replica_groups)
jikachannel_id > 0 and use_global_device_ids = false
.flattened_ids(replica_groups)
ifchannel_id > 0 and use_global_device_ids = true
.
Setelah itu, dalam setiap process_group
:
results...@process[result_index] = exec(schedule)
untuk beberapa hierarki binerschedule
dengan:exec(node)
=computation(exec(node.left), exec(node.right))
.exec(leaf)
=leaf.value
.
schedule
adalah hierarki biner yang ditentukan implementasi yang traversal urutannya adalahto_destination_type(operands...@process_group...[result_index], type(func_inputs(computation)[0]))
.
Input
Label | Nama | Jenis | Batasan |
---|---|---|---|
(I1) | operands |
jumlah variabel tensor atau tensor terkuantisasi per tensor | (C5), (C6) |
(I2) | replica_groups |
jumlah variadik konstanta tensor 1 dimensi jenis si64 |
(C1-C3) |
(I3) | channel_id |
konstanta jenis si64 |
(C4) |
(I4) | use_global_device_ids |
konstanta dari jenis i1 |
(C4) |
(I5) | computation |
fungsi | (C5) |
Output
Nama | Jenis | Batasan |
---|---|---|
results |
jumlah variabel tensor atau tensor terkuantisasi per tensor | (C6-C7) |
Batasan
- (C1)
is_unique(replica_groups)
. - (C2)
size(replica_groups)
ditentukan sebagai:num_replicas
jikacross_replica
digunakan.num_replicas
jikacross_replica_and_partition
digunakan.num_processes
jikaflattened_ids
digunakan.
- (C3)
0 <= replica_groups < size(replica_groups)
. - (C4) Jika
use_global_device_ids = true
, makachannel_id > 0
. - (C5)
computation
memiliki jenis(tensor<E>, tensor<E>) -> (tensor<E>)
, denganis_promotable(element_type(operand), E)
. - (C6)
shape(results...) = shape(operands...)
. - (C7)
element_type(results...) = E
.
Contoh
// num_replicas: 2
// num_partitions: 1
// %operand0@(0, 0): [1, 2, 3, 4]
// %operand0@(1, 0): [5, 6, 7, 8]
// %operand1@(0, 0): [9, 10, 11, 12]
// %operand1@(1, 0): [13, 14, 15, 16]
%result:2 = "stablehlo.all_reduce"(%operand0, %operand0) ({
^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
%0 = "stablehlo.add"(%arg0, %arg1) : (tensor<i64>, tensor<i64>) -> tensor<i64>
"stablehlo.return"(%0) : (tensor<i64>) -> ()
}) {
replica_groups = dense<[[0, 1]]> : tensor<1x2xi64>,
// channel_id = 0
channel_handle = #stablehlo.channel_handle<handle = 0, type = 0>
// use_global_device_ids = false
} : (tensor<4xi64>, tensor<4xi64>) -> (tensor<4xi64>, tensor<4xi64>)
// %result0@(0, 0): [6, 8, 10, 12]
// %result0@(1, 0): [6, 8, 10, 12]
// %result1@(0, 0): [22, 24, 26, 28]
// %result1@(1, 0): [22, 24, 26, 28]
all_to_all
Semantik
Dalam setiap grup proses di petak proses StableHLO, bagi nilai
tensor operands
di sepanjang split_dimension
menjadi beberapa bagian, sebar bagian
yang dibagi di antara proses, gabungkan bagian yang tersebar di sepanjang
concat_dimension
, dan hasilkan tensor results
.
Operasi ini membagi petak proses StableHLO menjadi process_groups
yang
ditentukan sebagai berikut:
cross_replica(replica_groups)
ifchannel_id <= 0
.cross_partition(replica_groups)
ifchannel_id > 0
.
Setelah itu, dalam setiap process_group
:
split_parts...@sender = split(operands...@sender, split_count, split_dimension)
untuk semuasender
diprocess_group
.scattered_parts...@receiver = [split_parts...@sender[receiver_index] for sender in process_group]
denganreceiver_index = process_group.index(receiver)
.results...@process = concatenate(scattered_parts...@process, concat_dimension)
.
Input
Label | Nama | Jenis | Batasan |
---|---|---|---|
(I1) | operands |
jumlah variabel tensor atau tensor terkuantisasi per tensor | (C1-C3), (C9) |
(I2) | split_dimension |
konstanta dari jenis si64 |
(C1), (C2), (C9) |
(I3) | concat_dimension |
konstanta dari jenis si64 |
(C3), (C9) |
(I4) | split_count |
konstanta dari jenis si64 |
(C2), (C4), (C8), (C9) |
(I5) | replica_groups |
Konstanta tensor 2 dimensi dari jenis si64 |
(C5-C8) |
(I6) | channel_id |
konstanta jenis si64 |
Output
Nama | Jenis | Batasan |
---|---|---|
results |
jumlah variabel tensor atau tensor terkuantisasi per tensor | (C9) |
Batasan
- (C1)
0 <= split_dimension < rank(operands...)
. - (C2)
dim(operands..., split_dimension) % split_count = 0
. - (C3)
0 <= concat_dimension < rank(operands...)
. - (C4)
0 < split_count
. - (C5)
is_unique(replica_groups)
. - (C6)
size(replica_groups)
ditentukan sebagai:num_replicas
jikacross_replica
digunakan.num_partitions
jikacross_partition
digunakan.
- (C7)
0 <= replica_groups < size(replica_groups)
. - (C8)
dim(replica_groups, 1) = split_count
. - (C9)
type(results...) = type(operands...)
kecuali, jikasplit_dimension != concat_dimension
:dim(results..., split_dimension) = dim(operands..., split_dimension) / split_count
.dim(results..., concat_dimension) = dim(operands..., concat_dimension) * split_count
.
Contoh
// num_replicas: 2
// num_partitions: 1
// %operand1@(0, 0): [[1, 2, 3, 4],
// [5, 6, 7, 8]]
// %operand1@(1, 0): [[9, 10, 11, 12],
// [13, 14, 15, 16]]
// %operand2@(0, 0): [[17, 18, 19, 20],
// [21, 22, 23, 24]]
// %operand2@(1, 0): [[25, 26, 27, 28],
// [29, 30, 31, 32]]
%result:2 = "stablehlo.all_to_all"(%operand1, %operand2) {
split_dimension = 1 : i64,
concat_dimension = 0 : i64,
split_count = 2 : i64,
replica_groups = dense<[[0, 1]]> : tensor<1x2xi64>
// channel_id = 0
} : (tensor<2x4xi64>, tensor<2x4xi64>) -> (tensor<4x2xi64>, tensor<4x2xi64>)
// %result#0@(0, 0): [[1, 2], [5, 6], [9, 10], [13, 14]]
// %result#0@(1, 0): [[3, 4], [7, 8], [11, 12], [15, 16]]
// %result#1@(0, 0): [[17, 18], [21, 22], [25, 26], [29, 30]]
// %result#1@(1, 0): [[19, 20], [23, 24], [27, 28], [31, 32]]
dan
Semantik
Melakukan AND element-wise dari dua tensor lhs
dan rhs
serta menghasilkan
tensor result
. Bergantung pada jenis elemen, lakukan tindakan berikut:
- Untuk boolean: AND logika.
- Untuk bilangan bulat: bitwise AND.
Input
Label | Nama | Jenis | Batasan |
---|---|---|---|
(I1) | lhs |
tensor dari jenis boolean atau bilangan bulat | (C1) |
(I2) | rhs |
tensor dari jenis boolean atau bilangan bulat | (C1) |
Output
Nama | Jenis | Batasan |
---|---|---|
result |
tensor dari jenis boolean atau bilangan bulat | (C1) |
Batasan
- (C1)
type(lhs) = type(rhs) = type(result)
.
Contoh
// %lhs: [[1, 2], [3, 4]]
// %rhs: [[5, 6], [7, 8]]
%result = "stablehlo.and"(%lhs, %rhs) : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
// %result: [[1, 2], [3, 0]]
atan2
Semantik
Melakukan operasi atan2 element-wise pada tensor lhs
dan rhs
serta menghasilkan
tensor result
. Bergantung pada jenis elemen, lakukan tindakan berikut:
- Untuk float:
atan2
dari IEEE-754. - Untuk bilangan kompleks: atan2 kompleks.
- Untuk jenis kuantisasi:
dequantize_op_quantize(atan2, lhs, rhs, type(result))
.
Input
Label | Nama | Jenis | Batasan |
---|---|---|---|
(I1) | lhs |
tensor floating point atau jenis kompleks atau tensor terkuantisasi per-tensor | (C1) |
(I2) | rhs |
tensor dari jenis floating point atau kompleks atau tensor terkuantisasi per tensor | (C1) |
Output
Nama | Jenis | Batasan |
---|---|---|
result |
tensor dari jenis floating point atau kompleks atau tensor terkuantisasi per tensor | (C1) |
Batasan
- (C1)
baseline_type(lhs) = baseline_type(rhs) = baseline_type(result)
.
Contoh
// %lhs: [0.0, 1.0, -1.0]
// %rhs: [0.0, 0.0, 0.0]
%result = "stablehlo.atan2"(%lhs, %rhs) : (tensor<3xf64>, tensor<3xf64>) -> tensor<3xf64>
// %result: [0.0, 1.57079637, -1.57079637] // [0.0, pi/2, -pi/2]
batch_norm_grad
Semantik
Menghitung gradien beberapa input batch_norm_training
yang melakukan backpropagation
dari grad_output
, dan menghasilkan tensor grad_operand
, grad_scale
, dan grad_offset
. Secara lebih formal, operasi ini dapat dinyatakan sebagai dekomposisi ke
operasi StableHLO yang ada menggunakan sintaksis Python sebagai berikut:
def compute_sum(operand, feature_index):
(sum,) = reduce(
inputs=[operand],
init_values=[constant(0, element_type(operand))],
dimensions=[i for i in range(rank(operand)) if i != feature_index],
body=lambda x, y: add(x, y))
return sum
def compute_mean(operand, feature_index):
sum = compute_sum(operand, feature_index)
divisor = constant(size(operand) / dim(operand, feature_index),
element_type(operand))
divisor_bcast = broadcast_in_dim(divisor, [], type(sum))
return divide(sum, divisor_bcast)
def batch_norm_grad(operand, scale, mean, variance, grad_output, epsilon, feature_index):
# Broadcast inputs to type(operand)
scale_bcast = broadcast_in_dim(scale, [feature_index], type(operand))
mean_bcast = broadcast_in_dim(mean, [feature_index], type(operand))
variance_bcast = broadcast_in_dim(variance, [feature_index], type(operand))
epsilon_bcast = broadcast_in_dim(constant(epsilon, element_type(operand)), [],
type(operand))
# Perform normalization using the provided `mean` and `variance`
# Intermediate values will be useful for computing gradients
centered_operand = subtract(operand, mean_bcast)
stddev = sqrt(add(variance_bcast, epsilon_bcast))
normalized_operand = divide(centered_operand, stddev)
# Use the implementation from batchnorm_expander.cc in XLA
# Temporary variables have exactly the same names as in the C++ code
elements_per_feature = broadcast_in_dim(
constant(divide(size(operand), dim(operand, feature_index)),
element_type(grad_output)),
[], type(operand))
i1 = multiply(grad_output, elements_per_feature)
i2 = broadcast_in_dim(
compute_sum(grad_output, feature_index), [feature_index], type(operand))
i3 = broadcast_in_dim(
compute_sum(multiply(grad_output, centered_operand), feature_index),
[feature_index], type(operand))
i4 = multiply(i3, centered_operand)
i5 = divide(i4, add(variance_bcast, epsilon_bcast))
i6 = subtract(subtract(i1, i2), i5)
grad_operand =
multiply(divide(divide(scale_bcast, stddev), elements_per_feature), i6)
grad_scale =
compute_sum(multiply(grad_output, normalized_operand), feature_index)
grad_offset = compute_sum(grad_output, feature_index)
return grad_operand, grad_scale, grad_offset
Untuk jenis kuantisasi, lakukan
dequantize_batch_norm_grad_or_training_quantize(lambda operand, scale, mean,
variance, grad_output: batch_norm_grad(operand, scale, mean, variance,
grad_output, epsilon, feature_index), operand, scale, mean, variance,
grad_output, type(grad_operand), type(grad_scale), type(feature_index))
.
Input
Label | Nama | Jenis | Batasan |
---|---|---|---|
(I1) | operand |
tensor jenis floating point atau tensor terkuantisasi per tensor | (C1-C3), (C5) |
(I2) | scale |
Tensor 1 dimensi dari jenis floating point atau terkuantisasi per tensor | (C2), (C4), (C5) |
(I3) | mean |
Tensor 1-dimensi dari floating-point atau jenis terkuantisasi per-tensor | (C2), (C4) |
(I4) | variance |
Tensor 1-dimensi dari floating-point atau jenis terkuantisasi per-tensor | (C2), (C4) |
(I5) | grad_output |
tensor jenis floating point atau tensor terkuantisasi per tensor | (C2), (C3) |
(I6) | epsilon |
konstanta dari jenis f32 |
|
(I7) | feature_index |
konstanta dari jenis si64 |
(C1), (C5) |
Output
Nama | Jenis | Batasan |
---|---|---|
grad_operand |
tensor jenis floating point atau tensor terkuantisasi per tensor | (C2), (C3) |
grad_scale |
Tensor 1 dimensi dari jenis floating point atau terkuantisasi per tensor | (C2), (C4) |
grad_offset |
Tensor 1 dimensi dari jenis floating point atau terkuantisasi per tensor | (C2), (C4) |
Batasan
- (C1)
0 <= feature_index < rank(operand)
. - (C2)
operand
,scale
,mean
,variance
,grad_output
,grad_operand
,grad_scale
, dangrad_offset
memilikibaseline_element_type
yang sama. - (C3)
operand
,grad_output
, dangrad_operand
memiliki bentuk yang sama. - (C4)
scale
,mean
,variance
,grad_scale
, dangrad_offset
memiliki bentuk yang sama. - (C5)
size(scale) = dim(operand, feature_index)
.
Contoh
// %operand: [
// [[1.0, 2.0], [3.0, 4.0]],
// [[3.0, 4.0], [1.0, 2.0]]
// ]
// %scale: [1.0, 1.0]
// %mean: [2.0, 3.0]
// %variance: [1.0, 1.0]
// %grad_output: [
// [[0.1, 0.1], [0.1, 0.1]],
// [[0.1, 0.1], [0.1, 0.1]]
// ]
%grad_operand, %grad_scale, %grad_offset =
"stablehlo.batch_norm_grad"(%operand, %scale, %mean, %variance, %grad_output) {
epsilon = 0.0 : f32,
feature_index = 2 : i64
} : (tensor<2x2x2xf64>, tensor<2xf64>, tensor<2xf64>, tensor<2xf64>,
tensor<2x2x2xf64>) -> (tensor<2x2x2xf64>, tensor<2xf64>, tensor<2xf64>)
// %grad_operand: [
// [[0.0, 0.0], [0.0, 0.0]],
// [[0.0, 0.0], [0.0, 0.0]]
// ]
// %grad_scale: [0.0, 0.0]
// %grad_offset: [0.4, 0.4]
batch_norm_inference
Semantik
Menormalisasi tensor operand
di semua dimensi kecuali
dimensi feature_index
dan menghasilkan tensor result
. Secara lebih formal, operasi
ini dapat dinyatakan sebagai dekomposisi ke operasi StableHLO yang ada
menggunakan sintaksis Python sebagai berikut:
def batch_norm_inference(operand, scale, offset, mean, variance, epsilon, feature_index):
# Broadcast inputs to shape(operand)
scale_bcast = broadcast_in_dim(scale, [feature_index], type(operand))
offset_bcast = broadcast_in_dim(offset, [feature_index], type(operand))
mean_bcast = broadcast_in_dim(mean, [feature_index], type(operand))
variance_bcast = broadcast_in_dim(variance, [feature_index], type(operand))
epsilon_bcast = broadcast_in_dim(constant(epsilon, element_type(operand)), [],
type(operand))
# Perform normalization using the provided `mean` and `variance` instead of
# computing them like `batch_norm_training` does.
centered_operand = subtract(operand, mean_bcast)
stddev = sqrt(add(variance_bcast, epsilon_bcast))
normalized_operand = divide(centered_operand, stddev)
return add(multiply(scale_bcast, normalized_operand), offset_bcast)
Untuk jenis kuantisasi, lakukan
dequantize_op_quantize(lambda operand, scale, offset, mean, variance:
batch_norm_inference(operand, scale, offset, mean, variance, epsilon,
feature_index), operand, scale, offset, mean, variance, type(result))
.
Input
Label | Nama | Jenis | Batasan |
---|---|---|---|
(I1) | operand |
tensor jenis floating point atau tensor terkuantisasi per tensor | (C1-C7) |
(I2) | scale |
Tensor 1 dimensi dari jenis floating point atau terkuantisasi per tensor | (C2), (C3) |
(I3) | offset |
Tensor 1-dimensi dari floating-point atau jenis terkuantisasi per-tensor | (C2), (C4) |
(I4) | mean |
Tensor 1 dimensi dari jenis floating point atau terkuantisasi per tensor | (C5) |
(I5) | variance |
Tensor 1 dimensi dari jenis floating point atau terkuantisasi per tensor | (C2), (C6) |
(I6) | epsilon |
konstanta dari jenis f32 |
|
(I7) | feature_index |
konstanta jenis si64 |
(C1), (C3-C6) |
Output
Nama | Jenis | Batasan |
---|---|---|
result |
tensor jenis floating point atau tensor terkuantisasi per tensor | (C2), (C7) |
Batasan
- (C1)
0 <= feature_index < rank(operand)
. - (C2)
operand
,scale
,offset
,mean
,variance
, danresult
memilikibaseline_element_type
yang sama. - (C3)
size(scale) = dim(operand, feature_index)
. - (C4)
size(offset) = dim(operand, feature_index)
. - (C5)
size(mean) = dim(operand, feature_index)
. - (C6)
size(variance) = dim(operand, feature_index)
. - (C7)
baseline_type(operand) = baseline_type(result)
.
Contoh
// %operand: [
// [[1.0, 2.0], [3.0, 4.0]],
// [[3.0, 4.0], [1.0, 2.0]]
// ]
// %scale: [1.0, 1.0]
// %offset: [1.0, 1.0]
// %mean: [2.0, 3.0]
// %variance: [1.0, 1.0]
%result = "stablehlo.batch_norm_inference"(%operand, %scale, %offset, %mean, %variance) {
epsilon = 0.0 : f32,
feature_index = 2 : i64
} : (tensor<2x2x2xf64>, tensor<2xf64>, tensor<2xf64>, tensor<2xf64>, tensor<2xf64>) -> tensor<2x2x2xf64>
// %result: [
// [[0.0, 0.0], [2.0, 2.0]],
// [[2.0, 2.0], [0.0, 0.0]]
// ]
batch_norm_training
Semantik
Menghitung rata-rata dan varian di semua dimensi kecuali untuk dimensi feature_index
dan menormalisasi tensor operand
yang menghasilkan tensor output
, batch_mean
, dan batch_var
. Secara lebih formal, operasi ini dapat dinyatakan sebagai
dekomposisi ke operasi StableHLO yang ada menggunakan sintaksis Python sebagai
berikut:
def compute_mean(operand, feature_index):
(sum,) = reduce(
inputs=[operand],
init_values=[constant(0, element_type(operand))],
dimensions=[i for i in range(rank(operand)) if i != feature_index],
body=lambda x, y: add(x, y))
divisor = constant(size(operand) / dim(operand, feature_index),
element_type(operand))
divisor_bcast = broadcast_in_dim(divisor, [], type(sum))
return divide(sum, divisor_bcast)
def compute_variance(operand, feature_index):
mean = compute_mean(operand, feature_index)
mean_bcast = broadcast_in_dim(mean, [feature_index], type(operand))
centered_operand = subtract(operand, mean_bcast)
return compute_mean(mul(centered_operand, centered_operand), feature_index)
def batch_norm_training(operand, scale, offset, epsilon, feature_index):
mean = compute_mean(operand, feature_index)
variance = compute_variance(operand, feature_index)
return batch_norm_inference(operand, scale, offset, mean, variance, epsilon,
feature_index),
mean, variance
Untuk jenis kuantisasi, lakukan
dequantize_batch_norm_grad_or_training_quantize(lambda operand, scale, offset:
batch_norm_training(operand, scale, offset, epsilon, feature_index), operand,
scale, offset, type(output), type(batch_mean), type(batch_var))
.
Input
Label | Nama | Jenis | Batasan |
---|---|---|---|
(I1) | operand |
tensor jenis floating point atau tensor terkuantisasi per tensor | (C1) |
(I2) | scale |
Tensor 1 dimensi dari floating point atau terkuantisasi per tensor | (C2), (C3) |
(I3) | offset |
Tensor 1 dimensi dari floating point atau terkuantisasi per tensor | (C2), (C4) |
(I4) | epsilon |
konstanta jenis f32 |
(C1), (C3-C6) |
(I5) | feature_index |
konstanta jenis si64 |
(C1), (C3-C6) |
Output
Nama | Jenis | Batasan |
---|---|---|
output |
tensor jenis floating point atau tensor terkuantisasi per tensor | (C7) |
batch_mean |
Tensor 1 dimensi dari floating point atau terkuantisasi per tensor | (C2), (C5) |
batch_var |
Tensor 1-dimensi dari floating-point atau per-tensor terkuantisasi | (C2), (C6) |
Batasan
- (C1)
0 <= feature_index < rank(operand)
. - (C2)
operand
,scale
,offset
,batch_mean
,batch_var
, danoutput
memilikibaseline_element_type
yang sama. - (C3)
size(scale) = dim(operand, feature_index)
. - (C4)
size(offset) = dim(operand, feature_index)
. - (C5)
size(batch_mean) = dim(operand, feature_index)
. - (C6)
size(batch_var) = dim(operand, feature_index)
. - (C7)
baseline_type(output) = baseline_type(operand)
.
Contoh
// %operand: [
// [[1.0, 2.0], [3.0, 4.0]],
// [[3.0, 4.0], [1.0, 2.0]]
// ]
// %scale: [1.0, 1.0]
// %offset: [1.0, 1.0]
%output, %batch_mean, %batch_var = "stablehlo.batch_norm_training"(%operand, %scale, %offset) {
epsilon = 0.0 : f32,
feature_index = 2 : i64
} : (tensor<2x2x2xf64>, tensor<2xf64>, tensor<2xf64>) ->
(tensor<2x2x2xf64>, tensor<2xf64>, tensor<2xf64>)
// %output: [
// [[0.0, 0.0], [2.0, 2.0]],
// [[2.0, 2.0], [0.0, 0.0]]
// ]
// %batch_mean: [2.0, 3.0]
// %batch_var: [1.0, 1.0]
bitcast_convert
Semantik
Melakukan operasi bitcast pada tensor operand
dan menghasilkan tensor result
dengan bit dari seluruh tensor operand
ditafsirkan ulang menggunakan
jenis tensor result
.
Secara lebih formal, dengan E = element_type(operand)
, E' = element_type(result)
,
dan R = rank(operand)
:
- Jika
num_bits(E') < num_bits(E)
,bits(result[i0, ..., iR-1, :]) = bits(operand[i0, ..., iR-1])
. - Jika
num_bits(E') > num_bits(E)
,bits(result[i0, ..., iR-2]) = bits(operand[i0, ..., iR-2, :])
. - Jika
num_bits(E') = num_bits(E)
,bits(result[i0, ..., iR-1]) = bits(operand[i0, ..., iR-1])
.
bits
menampilkan representasi dalam memori dari nilai tertentu, dan perilakunya
ditentukan oleh implementasi karena representasi persis tensor
ditentukan oleh implementasi, dan representasi persis jenis elemen
juga ditentukan oleh implementasi.
Input
Label | Nama | Jenis | Batasan |
---|---|---|---|
(I1) | operand |
tensor atau tensor terkuantisasi | (C1-C2) |
Output
Nama | Jenis | Batasan |
---|---|---|
result |
tensor atau tensor terkuantisasi | (C1-C2) |
Batasan
- (C1) Dengan
E = is_quantized(operand) ? storage_type(operand) : element_type(operand)
,E' = is_quantized(result) ? storage_type(result) : element_type(result)
, danR = rank(operand)
:- Jika
num_bits(E') = num_bits(E)
,shape(result) = shape(operand)
. - Jika
num_bits(E') < num_bits(E)
: rank(result) = R + 1
.dim(result, i) = dim(operand, i)
untuk semua0 <= i < R
.dim(result, R) * num_bits(E') = num_bits(E)
.- Jika
num_bits(E') > num_bits(E)
: rank(result) = R - 1
.dim(result, i) = dim(operand, i)
untuk semua0 <= i < R
.dim(operand, R - 1) * num_bits(E) = num_bits(E')
.
- Jika
- (C2) Jika
is_complex(operand) or is_complex(result)
, makais_complex(operand) and is_complex(result)
.
Contoh
// %operand: 0x0123456789ABCDEF
%result = "stablehlo.bitcast_convert"(%operand) : (tensor<f64>) -> tensor<4xf16>
// %result: [0xCDEF, 0x89AB, 0x4567, 0x0123] // little-endian representation
broadcast_in_dim
Semantik
Memperluas dimensi dan/atau peringkat tensor input dengan menduplikasi data
dalam tensor operand
dan menghasilkan tensor result
. Secara lebih formal,
result[result_index] = operand[operand_index]
dengan untuk semua d
di
axes(operand)
:
operand_index[d] = 0
ifdim(operand, d) = 1
.operand_index[d] = result_index[broadcast_dimensions[d]]
jika tidak.
Input
Label | Nama | Jenis | Batasan |
---|---|---|---|
(I1) | operand |
tensor atau quantized tensor | (C1-C2), (C5-C6) |
(I2) | broadcast_dimensions |
Konstanta tensor 1 dimensi dari jenis si64 |
(C2-C6) |
Output
Nama | Jenis | Batasan |
---|---|---|
result |
tensor atau tensor terkuantisasi | (C1), (C3), (C5-C6) |
Batasan
- (C1)
element_type(result)
diberikan oleh:element_type(operand)
, jika!is_per_axis_quantized(operand)
.element_type(operand)
kecuali bahwaquantization_dimension(operand)
,scales(operand)
, danzero_points(operand)
mungkin berbeda dariquantization_dimension(result)
,scales(result)
, danzero_points(result)
masing-masing.
- (C2)
size(broadcast_dimensions) = rank(operand)
. - (C3)
0 <= broadcast_dimensions < rank(result)
. - (C4)
is_unique(broadcast_dimensions)
. - (C5) Untuk semua
d
diaxes(operand)
:dim(operand, d) = 1
ataudim(operand, d) = dim(result, broadcast_dimensions[d])
.
- (C6) Jika
is_per_axis_quantized(result)
:quantization_dimension(result) = broadcast_dimensions[quantization_dimension(operand)]
.- Jika
dim(operand, quantization_dimension(operand)) = 1
, makascales(result)[i] = scales(operand)[0] and zero_points(result)[i] = zero_points(operand)[0] for i in range(dim(result, quantization_dimension(result)))
.
Contoh
// %operand: [
// [1, 2, 3]
// ]
%result = "stablehlo.broadcast_in_dim"(%operand) {
broadcast_dimensions = array<i64: 2, 1>
} : (tensor<1x3xi32>) -> tensor<2x3x2xi32>
// %result: [
// [
// [1, 1],
// [2, 2],
// [3, 3]
// ],
// [
// [1, 1],
// [2, 2],
// [3, 3]
// ]
// ]
casing
Semantik
Menghasilkan output dari mengeksekusi tepat satu fungsi dari branches
,
bergantung pada nilai index
. Secara lebih formal, result = selected_branch()
dengan:
selected_branch = branches[index]
if0 <= index < size(branches)
.selected_branch = branches[-1]
jika tidak.
Input
Label | Nama | Jenis | Batasan |
---|---|---|---|
(I1) | index |
Tensor 0-dimensi jenis si32 |
|
(I2) | branches |
jumlah fungsi variadik | (C1-C4) |
Output
Nama | Jenis | Batasan |
---|---|---|
results |
jumlah variabel tensor, tensor terkuantisasi, atau token | (C4) |
Batasan
- (C1)
0 < size(branches)
. - (C2)
input_types(branches...) = []
. - (C3)
same(output_types(branches...))
. - (C4)
type(results...) = output_types(branches[0])
.
Contoh
// %index: -1
// %result_branch0: [0, 0]
// %result_branch1: [1, 1]
%result0, %result1 = "stablehlo.case"(%index) ({
"stablehlo.return"(%result_branch0, %result_branch0) : (tensor<2xi64>, tensor<2xi64>) -> ()
}, {
"stablehlo.return"(%result_branch1, %result_branch1) : (tensor<2xi64>, tensor<2xi64>) -> ()
}) : (tensor<i32>) -> (tensor<2xi64>, tensor<2xi64>)
// %result0: [1, 1]
// %result1: [1, 1]
cbrt
Semantik
Melakukan operasi akar kubik element-wise pada tensor operand
dan menghasilkan
tensor result
. Bergantung pada jenis elemen, lakukan tindakan berikut:
- Untuk float:
rootn(x, 3)
dari IEEE-754. - Untuk bilangan kompleks: akar pangkat tiga kompleks.
- Untuk jenis terkuantisasi:
dequantize_op_quantize(cbrt, operand, type(result))
Input
Label | Nama | Jenis | Batasan |
---|---|---|---|
(I1) | operand |
tensor floating point atau jenis kompleks atau tensor terkuantisasi per-tensor | (C1) |
Output
Nama | Jenis | Batasan |
---|---|---|
result |
tensor dari jenis floating point atau kompleks atau tensor terkuantisasi per tensor | (C1) |
Batasan
- (C1)
baseline_type(operand) = baseline_type(result)
.
Contoh
// %operand: [0.0, 1.0, 8.0, 27.0]
%result = "stablehlo.cbrt"(%operand) : (tensor<4xf64>) -> tensor<4xf64>
// %result: [0.0, 1.0, 2.0, 3.0]
ceil
Semantik
Melakukan ceil element-wise dari tensor operand
dan menghasilkan tensor result
.
Mengimplementasikan operasi roundToIntegralTowardPositive
dari spesifikasi
IEEE-754. Untuk jenis kuantisasi, lakukan
dequantize_op_quantize(ceil, operand, type(result))
.
Input
Label | Nama | Jenis | Batasan |
---|---|---|---|
(I1) | operand |
tensor jenis floating point atau tensor terkuantisasi per tensor | (C1) |
Output
Nama | Jenis | Batasan |
---|---|---|
result |
tensor jenis floating point atau tensor terkuantisasi per tensor | (C1) |
Batasan
- (C1)
baseline_type(operand) = baseline_type(result)
.
Contoh
// %operand: [-0.8166, -0.2530, 0.2530, 0.8166, 2.0]
%result = "stablehlo.ceil"(%operand) : (tensor<5xf32>) -> tensor<5xf32>
// %result: [-0.0, -0.0, 1.0, 1.0, 2.0]
cholesky
Semantik
Menghitung dekomposisi Cholesky dari batch matriks.
Secara lebih formal, untuk semua i
di index_space(result)
,
result[i0, ..., iR-3, :, :]
adalah dekomposisi Cholesky dari
a[i0, ..., iR-3, :, :]
, dalam bentuk matriks segitiga bawah
(jika lower
adalah true
) atau matriks segitiga atas (jika lower
adalah false
).
Nilai output di segitiga yang berlawanan, yaitu segitiga atas yang ketat atau segitiga bawah yang ketat, ditentukan oleh implementasi.
Jika ada i
di mana matriks input bukan matriks positif-definit Hermitian, perilakunya tidak ditentukan.
Untuk jenis kuantisasi, lakukan
dequantize_op_quantize(lambda operand: cholesky(operand, lower), a, type(result))
.
Input
Label | Nama | Jenis | Batasan |
---|---|---|---|
(I1) | a |
tensor dari jenis floating point atau kompleks atau tensor terkuantisasi per tensor | (C1-C3) |
(I2) | lower |
Konstanta tensor 0 dimensi jenis i1 |
Output
Nama | Jenis | Batasan |
---|---|---|
result |
tensor dari jenis floating point atau kompleks atau tensor terkuantisasi per tensor | (C1) |
Batasan
- (C1)
baseline_type(a) = baseline_type(result)
. - (C2)
2 <= rank(a)
. - (C3)
dim(a, -2) = dim(a, -1)
.
Contoh
// %a: [
// [1.0, 2.0, 3.0],
// [2.0, 20.0, 26.0],
// [3.0, 26.0, 70.0]
// ]
%result = "stablehlo.cholesky"(%a) {
lower = true
} : (tensor<3x3xf32>) -> tensor<3x3xf64>
// %result: [
// [1.0, 0.0, 0.0],
// [2.0, 4.0, 0.0],
// [3.0, 5.0, 6.0]
// ]
klem
Semantik
Membatasi setiap elemen tensor operand
antara nilai minimum dan maksimum
dan menghasilkan tensor result
. Secara lebih formal, result[result_index] =
minimum(maximum(operand[result_index], min_element), max_element)
,
dengan min_element = rank(min) = 0 ? min[] : min[result_index]
,
max_element = rank(max) = 0 ? max[] : max[result_index]
. Untuk jenis terkuantisasi,
menjalankan dequantize_op_quantize(clamp, min, operand, max, type(result))
.
Memaksakan pengurutan pada angka kompleks melibatkan semantik yang mengejutkan, jadi di masa mendatang kami berencana menghapus dukungan untuk bilangan kompleks untuk operasi ini (#560).
Input
Label | Nama | Jenis | Batasan |
---|---|---|---|
(I1) | min |
tensor atau tensor terkuantisasi per tensor | (C1), (C3) |
(I2) | operand |
tensor atau tensor terkuantisasi per tensor | (C1-C4) |
(I3) | max |
tensor atau tensor terkuantisasi per tensor | (C2), (C3) |
Output
Nama | Jenis | Batasan |
---|---|---|
result |
tensor atau tensor terkuantisasi per tensor | (C4) |
Batasan
- (C1)
rank(min) = 0 or shape(min) = shape(operand)
. - (C2)
rank(max) = 0 or shape(max) = shape(operand)
. - (C3)
baseline_element_type(min) = baseline_element_type(operand) = baseline_element_type(max)
. - (C4)
baseline_type(operand) = baseline_type(result)
.
Contoh
// %min: [5, 10, 15]
// %operand: [3, 13, 23]
// %max: [10, 15, 20]
%result = "stablehlo.clamp"(%min, %operand, %max) : (tensor<3xi32>, tensor<3xi32>, tensor<3xi32>) -> tensor<3xi32>
// %result: [5, 13, 20]
collective_broadcast
Semantik
Dalam setiap grup proses pada petak proses StableHLO, kirim nilai
tensor operand
dari proses sumber ke proses target dan hasilkan
tensor result
.
Operasi tersebut membagi grid proses StableHLO menjadi process_groups
yang
ditentukan sebagai berikut:
cross_replica(replica_groups)
ifchannel_id <= 0
.cross_partition(replica_groups)
jikachannel_id > 0
.
Setelah itu, result@process
diberikan oleh:
operand@process_groups[i, 0]
jika adai
sehingga prosesnya berada diprocess_groups[i]
.broadcast_in_dim(constant(is_quantized(result) ? quantize(0, element_type(result)) : 0, element_type(result)), [], type(result))
jika tidak.
Input
Label | Nama | Jenis | Batasan |
---|---|---|---|
(I1) | operand |
tensor atau tensor terkuantisasi per tensor | (C3) |
(I2) | replica_groups |
jumlah variadik konstanta tensor 1 dimensi jenis si64 |
(C1), (C2) |
(I3) | channel_id |
konstanta dari jenis si64 |
Output
Nama | Jenis | Batasan |
---|---|---|
result |
tensor atau tensor terkuantisasi per tensor | (C3) |
Batasan
- (C1)
is_unique(replica_groups)
. - (C2)
0 <= replica_groups < N
denganN
ditentukan sebagai:num_replicas
jikacross_replica
digunakan.num_partitions
jikacross_partition
digunakan.
- (C3)
type(result) = type(operand)
.
Contoh
// num_replicas: 4
// num_partitions: 1
// %operand@(0, 0): [[1, 2]]
// %operand@(1, 0): [[3, 4]]
// %operand@(2, 0): [[5, 6]]
// %operand@(3, 0): [[7, 8]]
%result = "stablehlo.collective_broadcast"(%operand) {
replica_groups = dense<[[2, 1]]> : tensor<1x2xi64>,
channel_handle = #stablehlo.channel_handle<handle = 0, type = 0>
} : (tensor1x2xi64>) -> tensor<1x2xi64>
// %result@(0, 0): [[0, 0]]
// %result@(1, 0): [[5, 6]]
// %result@(2, 0): [[5, 6]]
// %result@(3, 0): [[0, 0]]
collective_permute
Semantik
Dalam setiap grup proses pada petak proses StableHLO, mengirimkan nilai
tensor operand
dari proses sumber ke proses target dan menghasilkan
tensor result
.
Operasi ini membagi petak proses StableHLO menjadi process_groups
yang
ditentukan sebagai berikut:
cross_replica(source_target_pairs)
jikachannel_id <= 0
.cross_partition(source_target_pairs)
ifchannel_id > 0
.
Setelah itu, result@process
diberikan oleh:
operand@process_groups[i, 0]
, jika adai
sehinggaprocess_groups[i, 1] = process
.broadcast_in_dim(constant(is_quantized(result) ? quantize(0, element_type(result)) : 0, element_type(result)), [], type(result))
jika tidak.
Input
Label | Nama | Jenis | Batasan |
---|---|---|---|
(I1) | operand |
tensor atau tensor terkuantisasi per tensor | (C5) |
(I2) | source_target_pairs |
Konstanta tensor 2 dimensi dari jenis si64 |
(C1-C4) |
(I3) | channel_id |
konstanta dari jenis si64 |
Output
Nama | Jenis | Batasan |
---|---|---|
result |
tensor atau tensor terkuantisasi per tensor | (C1) |
Batasan
- (C1)
dim(source_target_pairs, 1) = 2
. - (C2)
is_unique(source_target_pairs[:, 0])
. - (C3)
is_unique(source_target_pairs[:, 1])
. - (C4)
0 <= source_target_pairs < N
, denganN
ditentukan sebagai:num_replicas
jikacross_replica
digunakan.num_partitions
jikacross_partition
digunakan.
- (C5)
type(result) = type(operand)
.
Contoh
// num_replicas: 3
// num_partitions: 1
// %operand@(0, 0): [[1, 2], [3, 4]]
// %operand@(1, 0): [[5, 6], [7, 8]]
// %operand@(2, 0): [[9, 10], [11, 12]]
%result = "stablehlo.collective_permute"(%operand) {
source_target_pairs = dense<[[0, 1], [1, 2]]> : tensor<2x2xi64>,
channel_handle = #stablehlo.channel_handle<handle = 0, type = 0>
} : (tensor<2x2xi64>) -> tensor<2x2xi64>
//
// %result@(0, 0): [[0, 0], [0, 0]]
// %result@(1, 0): [[1, 2], [3, 4]]
// %result@(2, 0): [[5, 6], [7, 8]]
bandingkan
Semantik
Melakukan perbandingan element-wise dari tensor lhs
dan rhs
sesuai dengan
comparison_direction
dan compare_type
, serta menghasilkan tensor result
.
Nilai comparison_direction
dan compare_type
memiliki semantik
berikut:
Untuk jenis elemen boolean dan bilangan bulat:
EQ
:lhs = rhs
.NE
:lhs != rhs
.GE
:lhs >= rhs
.GT
:lhs > rhs
.LE
:lhs <= rhs
.LT
:lhs < rhs
.
Untuk jenis elemen floating point dengan compare_type = FLOAT
, op menerapkan
operasi IEEE-754 berikut:
EQ
:compareQuietEqual
.NE
:compareQuietNotEqual
.GE
:compareQuietGreaterEqual
.GT
:compareQuietGreater
.LE
:compareQuietLessEqual
.LT
:compareQuietLess
.
Untuk jenis elemen floating point dengan compare_type = TOTALORDER
, op
menggunakan kombinasi operasi totalOrder
dan compareQuietEqual
dari
IEEE-754.
Untuk jenis elemen yang kompleks, perbandingan leksikografis pasangan (real, imag)
dilakukan menggunakan comparison_direction
dan compare_type
yang disediakan.
Menerapkan pengurutan pada bilangan kompleks melibatkan semantik yang mengejutkan,
jadi pada masa mendatang, kami berencana untuk menghapus dukungan untuk bilangan kompleks
saat comparison_direction
adalah GE
, GT
, LE
, atau LT
(#560).
Untuk jenis kuantisasi. melakukan dequantize_compare(lhs, rhs,
comparison_direction)
.
Input
Label | Nama | Jenis | Batasan |
---|---|---|---|
(I1) | lhs |
tensor atau tensor terkuantisasi per-tensor | (C1-C3) |
(I2) | rhs |
tensor atau tensor terkuantisasi per-tensor | (C1-C2) |
(I3) | comparison_direction |
enum EQ , NE , GE , GT , LE , dan LT |
|
(I4) | compare_type |
enum FLOAT , TOTALORDER , SIGNED , dan UNSIGNED |
(C3) |
Output
Nama | Jenis | Batasan |
---|---|---|
result |
tensor dari jenis boolean | (C2) |
Batasan
- (C1)
baseline_element_type(lhs) = baseline_element_type(rhs)
. - (C2)
shape(lhs) = shape(rhs) = shape(result)
. - (C3)
compare_type
ditentukan sebagai:SIGNED
ifis_signed_integer(element_type(lhs))
.UNSIGNED
jikais_unsigned_integer(element_type(lhs)) or is_boolean(element_type(lhs))
.FLOAT
atauTOTALORDER
jikais_float(element_type(lhs))
.FLOAT
jikais_complex(element_type(lhs))
.
Contoh
// %lhs: [1.0, 3.0]
// %rhs: [1.1, 2.9]
%result = "stablehlo.compare"(%lhs, %rhs) {
comparison_direction = #stablehlo<comparison_direction LT>,
compare_type = #stablehlo<comparison_type FLOAT>
} : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xi1>
// %result: [true, false]
kompleks
Semantik
Melakukan konversi elemen ke nilai kompleks dari sepasang nilai riil dan
imaginer, lhs
dan rhs
, serta menghasilkan tensor result
.
Input
Label | Nama | Jenis | Batasan |
---|---|---|---|
(I1) | lhs |
tensor jenis f32 atau f64 |
(C1-C3) |
(I2) | rhs |
tensor jenis f32 atau f64 |
(C1) |
Output
Nama | Jenis | Batasan |
---|---|---|
result |
tensor tipe kompleks | (C2), (C3) |
Batasan
- (C1)
type(lhs) = type(rhs)
. - (C2)
shape(result) = shape(lhs)
. - (C3)
element_type(result)
memiliki jeniscomplex<E>
denganE = element_type(lhs)
.
Contoh
// %lhs: [1.0, 3.0]
// %rhs: [2.0, 4.0]
%result = "stablehlo.complex"(%lhs, %rhs) : (tensor<2xf64>, tensor<2xf64>) -> tensor<2xcomplex<f64>>
// %result: [(1.0, 2.0), (3.0, 4.0)]
gabungan
Semantik
Mengenkapsulasi operasi yang terdiri (tersusun) dari operasi StableHLO lainnya,
menggunakan inputs
dan composite_attributes
, serta menghasilkan results
. Semantik op diterapkan oleh atribut decomposition
. Operasi
composite
dapat diganti dengan dekomposisinya tanpa mengubah semantik
program. Jika menyisipkan dekomposisi tidak memberikan semantik
op yang sama, sebaiknya gunakan custom_call
.
Kolom version
(default-nya adalah 0
) digunakan untuk menunjukkan kapan semantik komposit berubah.
Input
Label | Nama | Jenis |
---|---|---|
(I1) | inputs |
jumlah nilai variadik |
(I2) | name |
konstanta dari jenis string |
(I3) | composite_attributes |
kamus atribut |
(I4) | decomposition |
konstanta jenis string |
(I5) | version |
konstanta jenis si32 |
Output
Nama | Jenis |
---|---|
results |
jumlah nilai variadik |
Batasan
- (C1)
is_namespaced_op_name(name)
- (C2)
is_defined_in_parent_scope(decomposition)
- (C3)
types(inputs...) == input_types(decomposition)
- (C4)
types(results...) == output_types(decomposition)
Contoh
%results = "stablehlo.composite"(%input0, %input1) {
name = "my_namespace.my_op",
composite_attributes = {
my_attribute = "my_value"
},
decomposition = @my_op,
version = 1 : i32
} : (tensor<f32>, tensor<f32>) -> tensor<f32>
concatenate
Semantik
Menggabungkan inputs
di sepanjang dimensi dimension
dalam urutan yang sama dengan argumen
yang diberikan dan menghasilkan tensor result
. Secara lebih formal,
result[i0, ..., id, ..., iR-1] = inputs[k][i0, ..., kd, ..., iR-1]
, dengan:
id = d0 + ... + dk-1 + kd
.d
sama dengandimension
, dand0
, ... adalah ukuran dimensid
dariinputs
.
Input
Label | Nama | Jenis | Batasan |
---|---|---|---|
(I1) | inputs |
jumlah variabel tensor atau tensor terkuantisasi per tensor | (C1-C6) |
(I2) | dimension |
konstanta dari jenis si64 |
(C2), (C4), (C6) |
Output
Nama | Jenis | Batasan |
---|---|---|
result |
tensor atau tensor terkuantisasi per-tensor | (C5-C6) |
Batasan
- (C1)
same(element_type(inputs...))
. - (C2)
same(shape(inputs...))
kecuali untukdim(inputs..., dimension)
. - (C3)
0 < size(inputs)
. - (C4)
0 <= dimension < rank(inputs[0])
. - (C5)
element_type(result) = element_type(inputs[0])
. - (C6)
shape(result) = shape(inputs[0])
kecuali untuk:dim(result, dimension) = dim(inputs[0], dimension) + ...
.
Contoh
// %input0: [[1, 2], [3, 4], [5, 6]]
// %input1: [[7, 8]]
%result = "stablehlo.concatenate"(%input0, %input1) {
dimension = 0 : i64
} : (tensor<3x2xi64>, tensor<1x2xi64>) -> tensor<4x2xi64>
// %result: [[1, 2], [3, 4], [5, 6], [7, 8]]
konstanta
Semantik
Menghasilkan tensor output
dari value
yang konstan.
Input
Label | Nama | Jenis | Batasan |
---|---|---|---|
(I1) | value |
konstanta | (C1) |
Output
Nama | Jenis | Batasan |
---|---|---|
output |
tensor atau tensor terkuantisasi | (C1) |
Batasan
- (C1)
type(value) = type(output)
.
Contoh
%output = "stablehlo.constant"() {
value = dense<[[0.0, 1.0], [2.0, 3.0]]> : tensor<2x2xf32>
} : () -> tensor<2x2xf32>
// %output: [[0.0, 1.0], [2.0, 3.0]]
melakukan konversi
Semantik
Melakukan konversi elemen dari satu jenis elemen ke jenis elemen lainnya pada
tensor operand
dan menghasilkan tensor result
.
Untuk konversi boolean-to-any-supported-type, nilai false
dikonversi menjadi nol, dan nilai true
dikonversi menjadi satu. Untuk
konversi any-supported-type-to-boolean, nilai nol dikonversi menjadi
false
, dan nilai non-nol dikonversi menjadi true
. Lihat di bawah untuk mengetahui cara kerja
ini untuk jenis yang kompleks.
Untuk konversi yang melibatkan bilangan bulat ke bilangan bulat, bilangan bulat ke floating point, atau floating point ke floating point, jika nilai sumber dapat direpresentasikan secara tepat dalam jenis tujuan, nilai hasilnya adalah representasi yang tepat tersebut. Jika tidak, perilakunya adalah TBD (#180).
Untuk konversi yang melibatkan floating-point-to-integer, bagian pecahan akan terpotong. Jika nilai yang dipangkas tidak dapat direpresentasikan dalam jenis tujuan, perilakunya adalah TBD (#180).
Konversi yang melibatkan kompleks ke kompleks mengikuti perilaku yang sama dengan konversi floating point ke floating point untuk mengonversi bagian riil dan imajiner.
Untuk konversi complex-to-any-other-type dan any-other-type-to-complex, nilai imajiner sumber diabaikan atau nilai imajiner tujuan disetel ke nol. Konversi bagian riil mengikuti konversi floating point.
Pada prinsipnya, operasi ini dapat mengekspresikan dekuantisasi (konversi dari
tensor kuantisasi menjadi tensor reguler), kuantisasi (konversi dari
tensor reguler menjadi tensor kuantisasi), dan rekuantisasi (konversi antara tensor
kuantisasi), tetapi saat ini kami memiliki operasi khusus untuk itu -
uniform_dequantize
untuk kasus penggunaan pertama dan uniform_quantize
untuk
kasus penggunaan kedua dan ketiga. Nantinya, kedua operasi ini dapat digabungkan
ke dalam convert
(#1576).
Input
Label | Nama | Jenis | Batasan |
---|---|---|---|
(I1) | operand |
tensor | (C1) |
Output
Nama | Jenis | Batasan |
---|---|---|
result |
tensor | (C1) |
Batasan
- (C1)
shape(operand) = shape(result)
.
Contoh
// %operand: [-1, 0, 1]
%result = "stablehlo.convert"(%operand) : (tensor<3xi64>) -> tensor<3xcomplex<f64>>
// %result: [(-1.0, 0.0), (0.0, 0.0), (1.0, 0.0)]
konvolusi
Semantik
Menghitung perkalian titik antara jendela lhs
dan slice rhs
serta menghasilkan
result
. Diagram berikut menunjukkan cara elemen di result
dihitung dari
lhs
dan rhs
menggunakan contoh konkret.
Secara lebih formal, pertimbangkan penyusunan ulang input berikut dalam hal lhs
agar dapat mengekspresikan jendela lhs
:
lhs_window_dimensions = lhs_shape(dim(lhs, input_batch_dimension), dim(rhs, kernel_spatial_dimensions), dim(lhs, input_feature_dimension))
.lhs_window_strides = lhs_shape(1, window_strides, 1)
.lhs_padding = lhs_shape([0, 0], padding, [0, 0])
.lhs_base_dilations = lhs_shape(1, lhs_dilation, 1)
.lhs_window_dilations = lhs_shape(1, rhs_dilation, 1)
.
Pembingkaian ulang ini menggunakan fungsi bantuan berikut:
lhs_shape(n, hw, c) = permute([n] + hw + [c], [input_batch_dimension] + input_spatial_dimensions + [input_feature_dimension])
.result_shape(n1, hw, c1) = permute([n1] + hw + [c1], [output_batch_dimension] + output_spatial_dimensions + [output_feature_dimension])
.permute([j0, j1, ..., jR-1], permutation) = [i0, i1, ..., iR-1]
denganj[d] = i[permutation[d]]
.
Jika feature_group_count = 1
dan batch_group_count = 1
, maka untuk semua
output_spatial_index
di index_space(dim(result, output_spatial_dimensions...))
,
result[result_shape(:, output_spatial_index, :)] = dot_product
dengan:
padding_value = constant(0, element_type(lhs))
.padded_lhs = pad(lhs, padding_value, lhs_padding[:, 0], lhs_padding[:, 1], lhs_base_dilations - 1)
.lhs_window_start = lhs_shape(0, output_spatial_index, 0) * lhs_window_strides
.lhs_window = slice(padded_lhs, lhs_window_start, lhs_window_start + lhs_window_dimensions, lhs_window_dilations)
.reversed_lhs_window = reverse(lhs_window, [input_spatial_dimensions[dim] for dim in range(size(window_reversal)) if window_reversal[dim] = true])
. Fitur ini tampaknya tidak digunakan, jadi pada masa mendatang kami berencana untuk menghapusnya (#1181).dot_product = dot_general(reversed_lhs_window, rhs, lhs_batching_dimensions=[], lhs_contracting_dimensions=input_spatial_dimensions + [input_feature_dimension], rhs_batching_dimensions=[], rhs_contracting_dimensions=kernel_spatial_dimensions + [kernel_input_feature_dimension])
.
Jika feature_group_count > 1
:
lhses = split(lhs, feature_group_count, input_feature_dimension)
.rhses = split(rhs, feature_group_count, kernel_output_feature_dimension)
.results... = convolution(lhses..., rhses..., ..., feature_group_count=1, ...)
.result = concatenate(results, output_feature_dimension)
.
Jika batch_group_count > 1
:
lhses = split(lhs, batch_group_count, input_batch_dimension)
.rhses = split(rhs, batch_group_count, kernel_output_feature_dimension)
.results... = convolution(lhses..., rhses..., ..., batch_group_count=1, ...)
.result = concatenate(results, output_feature_dimension)
.
Untuk jenis terkuantisasi, menjalankan dequantize_op_quantize(
lambda lhs, rhs: convolution(lhs, rhs, window_strides, padding,
lhs_dilation, rhs_dilation, window_reversal, input_batch_dimension,
input_feature_dimension, input_spatial_dimensions,
kernel_input_feature_dimension, kernel_output_feature_dimension,
kernel_spatial_dimensions, output_batch_dimension,
output_feature_dimension, output_spatial_dimensions,
feature_group_count, batch_group_count, precision_config), lhs, rhs,
type(result))
.
Untuk jenis kuantisasi campuran, lakukan hybrid_dequantize_then_op(
lambda lhs, rhs: convolution(lhs, rhs, window_strides, padding,
lhs_dilation, rhs_dilation, window_reversal, input_batch_dimension,
input_feature_dimension, input_spatial_dimensions,
kernel_input_feature_dimension, kernel_output_feature_dimension,
kernel_spatial_dimensions, output_batch_dimension,
output_feature_dimension, output_spatial_dimensions,
feature_group_count, batch_group_count, precision_config), lhs, rhs)
.
Input
Label | Nama | Jenis | Batasan |
---|---|---|---|
(I1) | lhs |
tensor atau tensor terkuantisasi per-tensor | (C1), (C10-C11), (C14) (C25), (C27-C28), (C31-C32), (C34) |
(I2) | rhs |
tensor atau tensor terkuantisasi | (C1), (C14-C16), (C25), (C27-C29), (C31-C34) |
(I3) | window_strides |
Konstanta tensor 1 dimensi dari jenis si64 |
(C2-C3), (C25) |
(I4) | padding |
Konstanta tensor 2 dimensi dari jenis si64 |
(C4), (C25) |
(I5) | lhs_dilation |
Konstanta tensor 1 dimensi dari jenis si64 |
(C5-C6), (C25) |
(I6) | rhs_dilation |
Konstanta tensor 1 dimensi dari jenis si64 |
(C7-C8), (C25) |
(I7) | window_reversal |
Konstanta tensor 1 dimensi dari jenis i1 |
(C9) |
(I8) | input_batch_dimension |
konstanta dari jenis si64 |
(C10), (C13), (C25) |
(I9) | input_feature_dimension |
konstanta jenis si64 |
(C11), (C13-C14) |
(I10) | input_spatial_dimensions |
Konstanta tensor 1 dimensi jenis si64 |
(C12), (C13), (C25) |
(I11) | kernel_input_feature_dimension |
konstanta jenis si64 |
(C14), (C18) |
(I12) | kernel_output_feature_dimension |
konstanta dari jenis si64 |
(C15-C16), (C18), (C25), (C29) |
(I13) | kernel_spatial_dimensions |
Konstanta tensor 1 dimensi jenis si64 |
(C17-C18), (C25) |
(I14) | output_batch_dimension |
konstanta dari jenis si64 |
(C20), (C25) |
(I15) | output_feature_dimension |
konstanta dari jenis si64 |
(C20), (C25), (C30) |
(I16) | output_spatial_dimensions |
Konstanta tensor 1 dimensi dari jenis si64 |
(C19-C20), (C25) |
(I17) | feature_group_count |
konstanta dari jenis si64 |
(C11), (C14), (C16), (C21), (C23) |
(I18) | batch_group_count |
konstanta dari jenis si64 |
(C10), (C15), (C22), (C23), (C25) |
(I19) | precision_config |
jumlah variabel enum DEFAULT , HIGH , dan HIGHEST |
(C24) |
Output
Nama | Jenis | Batasan |
---|---|---|
result |
tensor atau tensor terkuantisasi | (C25-C28), (C30), (C32-34) |
Batasan
- (C1)
N = rank(lhs) = rank(rhs)
. - (C2)
size(window_strides) = N - 2
. - (C3)
0 < window_strides
. - (C4)
shape(padding) = [N - 2, 2]
. - (C5)
size(lhs_dilation) = N - 2
. - (C6)
0 < lhs_dilation
. - (C7)
size(rhs_dilation) = N - 2
. - (C8)
0 < rhs_dilation
. - (C9)
size(window_reversal) = N - 2
. - (C10)
dim(lhs, input_batch_dimension) % batch_group_count = 0
. - (C11)
dim(lhs, input_feature_dimension) % feature_group_count = 0
. - (C12)
size(input_spatial_dimensions) = N - 2
. - (C13) Dengan
input_dimensions = [input_batch_dimension] + input_spatial_dimensions + [input_feature_dimension]
:is_unique(input_dimensions)
.0 <= input_dimensions < N
.
- (C14)
dim(rhs, kernel_input_feature_dimension) = dim(lhs, input_feature_dimension) / feature_group_count
. - (C15)
dim(rhs, kernel_output_feature_dimension) % batch_group_count = 0
. - (C16)
dim(rhs, kernel_output_feature_dimension) % feature_group_count = 0
. - (C17)
size(kernel_spatial_dimensions) = N - 2
. - (C18) Dengan
kernel_dimensions = kernel_spatial_dimensions + [kernel_input_feature_dimension] + [kernel_output_feature_dimension]
:is_unique(kernel_dimensions)
.0 <= kernel_dimensions < N
.
- (C19)
size(output_spatial_dimensions) = N - 2
. - (C20) Diberikan
output_dimensions = [output_batch_dimension] + output_spatial_dimensions + [output_feature_dimension]
:is_unique(output_dimensions)
.0 <= output_dimensions < N
.
- (C21)
0 < feature_group_count
. - (C22)
0 < batch_group_count
. - (C23)
feature_group_count = 1 or batch_group_count = 1
. - (C24)
size(precision_config) = 2
. - (C25)
dim(result, result_dim)
ditentukan sebagai:dim(lhs, input_batch_dimension) / batch_group_count
ifresult_dim = output_batch_dimension
.dim(rhs, kernel_output_feature_dimension)
jikaresult_dim = output_feature_dimension
.num_windows
jika tidak, dengan:output_spatial_dimensions[spatial_dim] = result_dim
.lhs_dim = input_spatial_dimensions[spatial_dim]
.rhs_dim = kernel_spatial_dimensions[spatial_dim]
.dilated_input_shape[lhs_dim] = dim(lhs, lhs_dim) = 0 ? 0 : (dim(lhs, lhs_dim) - 1) * lhs_dilation[spatial_dim] + 1
.padded_input_shape[lhs_dim] = padding[spatial_dim, 0] + dilated_input_shape[lhs_dim] + padding[spatial_dim, 1]
.dilated_window_shape[lhs_dim] = dim(rhs, rhs_dim) = 0 ? 0 : (dim(rhs, rhs_dim) - 1) * rhs_dilation[spatial_dim] + 1
.is_empty_window[lhs_dim] = padded_input_shape[lhs_dim] = 0 || dilated_window_shape[lhs_dim] > padded_input_shape[lhs_dim]
.num_windows = is_empty_window[lhs_dim] ? 0 : floor((padded_input_shape[lhs_dim] - dilated_window_shape[lhs_dim]) / window_strides[spatial_dim]) + 1
.
- (C26)
rank(result) = N
. - Jika operasi menggunakan tensor yang tidak dikuantisasi:
- (C27)
element_type(lhs) = element_type(rhs) = element_type(result)
.
- (C27)
- Jika operasi menggunakan tensor kuantisasi:
- (C28)
is_quantized(lhs) = is_quantized(result) and is_quantized(rhs)
. - (C29) Jika
is_per_axis_quantized(rhs)
, makaquantization_dimension(rhs) = kernel_output_feature_dimension
. - (C30) Jika
is_per_axis_quantized(result)
, makaquantization_dimension(result) = output_feature_dimension
. - Jika
is_quantized(lhs)
: - (C31)
storage_type(lhs) = storage_type(rhs)
. - (C32)
expressed_type(lhs) = expressed_type(rhs) = expressed_type(result)
. - (C33) Jika
is_per_tensor_quantized(rhs)
, makais_per_tensor_quantized(result)
. - Jika
!is_quantized(lhs)
: - (C34)
element_type(lhs) = expressed_type(rhs) = element_type(result)
.
- (C28)
Contoh
// %lhs: [[
// [
// [1], [2], [5], [6]
// ],
// [
// [3], [4], [7], [8]
// ],
// [
// [10], [11], [14], [15]
// ],
// [
// [12], [13], [16], [17]
// ]
// ]]
//
// %rhs: [
// [[[1]], [[1]], [[1]]],
// [[[1]], [[1]], [[1]]],
// [[[1]], [[1]], [[1]]]
// ]
%result = "stablehlo.convolution"(%lhs, %rhs) {
window_strides = array<i64: 4, 4>,
padding = dense<0> : tensor<2x2xi64>,
lhs_dilation = array<i64: 2, 2>,
rhs_dilation = array<i64: 1, 1>,
window_reversal = array<i1: false, false>,
// In the StableHLO dialect, dimension numbers are encoded via:
// `[<input dimensions>]x[<kernel dimensions>]->[output dimensions]`.
// "b" is batch dimension, "f" is feature dimension,
// "i" is input feature dimension, "o" is output feature dimension,
// "0/1/etc" are spatial dimensions.
dimension_numbers = #stablehlo.conv<[b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f]>,
batch_group_count = 1 : i64,
feature_group_count = 1 : i64,
precision_config = [#stablehlo<precision DEFAULT>, #stablehlo<precision DEFAULT>]
} : (tensor<1x4x4x1xi64>, tensor<3x3x1x1xi64>) -> tensor<1x2x2x1xi64>
// %result: [[
// [[10], [26]],
// [[46], [62]]
// ]]
kosinus
Semantik
Melakukan operasi kosinus berbasis elemen pada tensor operand
dan menghasilkan
tensor result
. Bergantung pada jenis elemen, lakukan tindakan berikut:
- Untuk float:
cos
dari IEEE-754. - Untuk bilangan kompleks: kosinus kompleks.
- Untuk jenis kuantisasi:
dequantize_op_quantize(cosine, operand, type(result))
.
Input
Label | Nama | Jenis | Batasan |
---|---|---|---|
(I1) | operand |
tensor floating point atau jenis kompleks atau tensor terkuantisasi per-tensor | (C1) |
Output
Nama | Jenis | Batasan |
---|---|---|
result |
tensor dari jenis floating point atau kompleks atau tensor terkuantisasi per tensor | (C1) |
Batasan
- (C1)
baseline_type(operand) = baseline_type(result)
.
Contoh
// %operand: [
// [0.0, 1.57079632], // [0, pi/2]
// [3.14159265, 4.71238898] // [pi, 3pi/2]
// ]
%result = "stablehlo.cosine"(%operand) : (tensor<2x2xf32>) -> tensor<2x2xf32>
// %result: [[1.0, 0.0], [-1.0, 0.0]]
count_leading_zeros
Semantik
Melakukan penghitungan element-wise dari jumlah bit nol di awal dalam tensor operand
dan menghasilkan tensor result
.
Input
Label | Nama | Jenis | Batasan |
---|---|---|---|
(I1) | operand |
tensor dari jenis bilangan bulat | (C1) |
Output
Nama | Jenis | Batasan |
---|---|---|
result |
tensor dari jenis bilangan bulat | (C1) |
Batasan
- (C1)
type(operand) = type(result)
.
Contoh
// %operand: [[0, 1], [128, -1]]
%result = "stablehlo.count_leading_zeros"(%operand) : (tensor<2x2xi64>) -> tensor<2x2xi64>
// %result: [[64, 63], [56, 0]]
custom_call
Semantik
Mengenkapsulasi operasi call_target_name
yang ditentukan implementasi yang menggunakan
inputs
dan called_computations
serta menghasilkan results
. has_side_effect
,
backend_config
, dan api_version
dapat digunakan untuk memberikan metadata tambahan
yang ditentukan implementasi.
Saat ini, operasi ini berisi kumpulan metadata yang cukup tidak teratur yang mencerminkan evolusi organik dari operasi yang setara di compiler XLA. Pada masa mendatang, kami berencana untuk menyatukan metadata ini (#741).
Input
Label | Nama | Jenis |
---|---|---|
(I1) | inputs |
jumlah nilai variadik |
(I2) | call_target_name |
konstanta dari jenis string |
(I3) | has_side_effect |
konstanta dari jenis i1 |
(I4) | backend_config |
konstanta jenis string atau kamus atribut |
(I5) | api_version |
konstanta dari jenis si32 |
(I6) | called_computations |
jumlah konstanta variadik jenis string |
Output
Nama | Jenis |
---|---|
results |
jumlah nilai variadis |
Contoh
%results = "stablehlo.custom_call"(%input0) {
call_target_name = "foo",
has_side_effect = false,
backend_config = {bar = 42 : i32},
api_version = 4 : i32,
called_computations = [@foo]
} : (tensor<f64>) -> tensor<f64>
bagi
Semantik
Melakukan pembagian element-wise dari tensor dividen lhs
dan pembagi rhs
dan
menghasilkan tensor result
. Bergantung pada jenis elemen, lakukan tindakan berikut:
- Untuk bilangan bulat: pembagian bilangan bulat yang menghasilkan hasil bagi aljabar dengan bagian pecahan yang dihapus.
- Untuk float:
division
dari IEEE-754. - Untuk bilangan kompleks: pembagian kompleks.
- Untuk jenis kuantisasi:
dequantize_op_quantize(divide, lhs, rhs, type(result))
.
Input
Label | Nama | Jenis | Batasan |
---|---|---|---|
(I1) | lhs |
tensor bilangan bulat, floating point, atau kompleks, atau tensor terkuantisasi per tensor | (C1) |
(I2) | rhs |
tensor bilangan bulat, floating point, atau kompleks, atau tensor terkuantisasi per tensor | (C1) |
Output
Nama | Jenis | Batasan |
---|---|---|
result |
tensor bilangan bulat, floating point, atau jenis kompleks atau tensor terkuantisasi per tensor | (C1) |
Batasan
- (C1)
baseline_type(lhs) = baseline_type(rhs) = baseline_type(result)
.
Contoh
// %lhs: [17.1, -17.1, 17.1, -17.1]
// %rhs: [3.0, 3.0, -3.0, -3.0]
%result = "stablehlo.divide"(%lhs, %rhs) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
// %result: [5.66666651, -5.66666651, -5.66666651, 5.66666651]
dot_general
Semantik
Menghitung perkalian titik antara slice lhs
dan slice rhs
serta menghasilkan
tensor result
.
Secara lebih formal, result[result_index] = dot_product
, dengan:
lhs_result_dimensions = [d for d in axes(lhs) and d not in lhs_batching_dimensions and d not in lhs_contracting_dimensions]
.rhs_result_dimensions = [d for d in axes(rhs) and d not in rhs_batching_dimensions and d not in rhs_contracting_dimensions]
.result_batching_index + result_lhs_index + result_rhs_index = result_index
dengansize(result_batching_index) = size(lhs_batching_dimensions)
,size(result_lhs_index) = size(lhs_result_dimensions)
, dansize(result_rhs_index) = size(rhs_result_dimensions)
.transposed_lhs = transpose(lhs, lhs_batching_dimensions + lhs_result_dimensions + lhs_contracting_dimensions)
.transposed_lhs_slice = slice(transposed_lhs, result_batching_index + result_lhs_index + [:, ..., :])
.reshaped_lhs_slice = reshape(transposed_lhs_slice, dims(lhs, lhs_contracting_dimensions))
.transposed_rhs = transpose(rhs, rhs_batching_dimensions + rhs_result_dimensions + rhs_contracting_dimensions)
.transposed_rhs_slice = slice(transposed_rhs, result_batching_index + result_rhs_index + [:, ..., :])
.reshaped_rhs_slice = reshape(transposed_rhs_slice, dims(rhs, rhs_contracting_dimensions))
.dot_product = reduce( inputs=[multiply(reshaped_lhs_slice, reshaped_rhs_slice)], init_values=[constant(0, element_type(result))], dimensions=range(size(lhs_contracting_dimensions)), body=lambda x, y: add(x, y))
.
Untuk jenis kuantisasi, lakukan dequantize_op_quantize(
lambda lhs, rhs: dot_general(lhs, rhs, lhs_batching_dimensions,
rhs_batching_dimensions, lhs_contracting_dimensions,
rhs_contracting_dimensions, precision_config), lhs, rhs, type(result))
.
Untuk jenis kuantisasi campuran, lakukan hybrid_dequantize_then_op(
lambda lhs, rhs: dot_general(lhs, rhs, lhs_batching_dimensions,
rhs_batching_dimensions, lhs_contracting_dimensions,
rhs_contracting_dimensions, precision_config), lhs, rhs)
.
precision_config
mengontrol kompromi antara kecepatan dan akurasi untuk
komputasi di backend akselerator. Ini dapat berupa salah satu dari berikut ini (saat ini, semantik nilai enum ini tidak ditentukan, tetapi kami berencana untuk mengatasinya di #755):
DEFAULT
: Penghitungan tercepat, tetapi perkiraan yang paling tidak akurat terhadap angka asli.HIGH
: Penghitungan lebih lambat, tetapi perkiraan ke angka asli lebih akurat.HIGHEST
: Penghitungan paling lambat, tetapi perkiraan paling akurat terhadap angka asli.
DotAlgorithm
menentukan properti utama algoritma yang digunakan untuk menerapkan
operasi titik, yang juga menentukan presisi. Jika kolom atribut
algoritma ditetapkan, precision_config
harus berupa DEFAULT
. DotAlgorithms
tidak memiliki nilai default, karena parameter default ditentukan oleh
implementasi. Dengan demikian, semua kolom algoritma titik dapat ditetapkan ke None
untuk menentukan
algoritma titik kosong, yang akan menggunakan nilai precision_config
.
Kolom DotAlgorithm
mencakup:
lhs_precision_type
danrhs_precision_type
, presisi yang digunakan untuk membulatkan LHS dan RHS operasi. Jenis presisi tidak bergantung pada jenis penyimpanan input dan output.accumulation_type
presisi yang digunakan untuk akumulasi.lhs_component_count
,rhs_component_count
, dannum_primitive_operations
berlaku saat kita melakukan algoritma yang menguraikan LHS dan/atau RHS menjadi beberapa komponen dan melakukan beberapa operasi titik "primitif" pada nilai tersebut, biasanya untuk mengemulasi presisi yang lebih tinggi (misalnya Memanfaatkan Jenis Data Kecerdasan Buatan bfloat16 Untuk Komputasi Presisi Tinggi: bf16_6x, tf32). Untuk algoritma tanpa dekomposisi, nilai ini harus ditetapkan ke1
.allow_imprecise_accumulation
untuk menentukan apakah akumulasi dalam presisi yang lebih rendah diizinkan untuk beberapa langkah (misalnya,CUBLASLT_MATMUL_DESC_FAST_ACCUM
).
Contoh atribut DotAlgorithm
:
// Inputs are casted to tf32, and then accumulated in f32:
{lhs_precision_type = tf32,
rhs_precision_type = tf32,
accumulation_type = f32,
lhs_component_count = 1,
rhs_component_count = 1,
num_primitive_operations = 1,
allow_imprecise_accumulation = false}
// bf16_6x: each input is decomposed to 3 bf16 components, then 6 dot operations are done on those components, and the result is accumulated in f32.
{lhs_precision_type = bf16,
rhs_precision_type = bf16,
accumulation_type = f32,
lhs_component_count = 3,
rhs_component_count = 3,
num_primitive_operations = 6,
allow_imprecise_accumulation = false}
// Inputs are (casted to) f8e5m2, and we accumulate in f32, but for some steps we may accumulate in lower precision.
{lhs_precision_type = f8e5m2,
rhs_precision_type = f8e5m2,
accumulation_type = f32,
lhs_component_count = 1,
rhs_component_count = 1,
num_primitive_operations = 1,
allow_imprecise_accumulation = true}
Implementasi bebas untuk memutuskan kombinasi mana yang didukung. Secara umum, tidak dijamin bahwa setiap algoritma didukung di setiap jenis akselerator oleh konsumen StableHLO. Jika algoritma tertentu tidak didukung, error harus ditampilkan, bukan kembali ke alternatif. Verifikasi StableHLO akan memberikan verifikasi upaya terbaik, yang mencegah algoritma yang tidak diketahui didukung di hardware mana pun.
Lihat xla_data.proto > Algorithm
untuk mengetahui beberapa nilai algoritma yang didukung. Tiket #2483 berisi rencana untuk membuat dokumen terpusat tentang algoritma yang didukung oleh backend.
Input
Label | Nama | Jenis | Batasan |
---|---|---|---|
(I1) | lhs |
tensor atau tensor terkuantisasi per tensor | (C5-C6), (C9-C10), (C12-C14), (C17-C18), (C20) |
(I2) | rhs |
tensor atau tensor terkuantisasi | (C7-C10), (C12-C20) |
(I3) | lhs_batching_dimensions |
Konstanta tensor 1 dimensi dari jenis si64 |
(C1), (C3), (C5), (C9), (C12) |
(I4) | rhs_batching_dimensions |
Konstanta tensor 1 dimensi dari jenis si64 |
(C1), (C4), (C7), (C9) |
(I5) | lhs_contracting_dimensions |
Konstanta tensor 1 dimensi dari jenis si64 |
(C2), (C3), (C6), (C10) |
(I6) | rhs_contracting_dimensions |
Konstanta tensor 1 dimensi dari jenis si64 |
(C2), (C4), (C8), (C10), (C16) |
(I7) | precision_config |
jumlah variabel enum DEFAULT , HIGH , dan HIGHEST |
(C11), C21) |
(I8) | lhs_precision_type |
FloatType atau TensorFloat32 | (C21) |
(I9) | rhs_precision_type |
FloatType atau TensorFloat32 | (C21) |
(I10) | accumulation_type |
FloatType atau TensorFloat32 | (C21) |
(I11) | lhs_component_count |
konstanta dari jenis si32 |
(C21), (C22) |
(I12) | rhs_component_count |
konstanta dari jenis si32 |
(C21), C23 |
(I13) | num_primitive_operations |
konstanta dari jenis si32 |
(C21), C24 |
(I14) | allow_imprecise_accumulation |
konstanta dari jenis bool |
(C21) |
Output
Nama | Jenis | Batasan |
---|---|---|
result |
tensor atau quantized tensor | (C12), (C14), (C18-C20) |
Batasan
- (C1)
size(lhs_batching_dimensions) = size(rhs_batching_dimensions)
. - (C2)
size(lhs_contracting_dimensions) = size(rhs_contracting_dimensions)
. - (C3)
is_unique(lhs_batching_dimensions + lhs_contracting_dimensions)
. - (C4)
is_unique(rhs_batching_dimensions + rhs_contracting_dimensions)
. - (C5)
0 <= lhs_batching_dimensions < rank(lhs)
. - (C6)
0 <= lhs_contracting_dimensions < rank(lhs)
. - (C7)
0 <= rhs_batching_dimensions < rank(rhs)
. - (C8)
0 <= rhs_contracting_dimensions < rank(rhs)
. - (C9)
dim(lhs, lhs_batching_dimensions...) = dim(rhs, rhs_batching_dimensions...)
. - (C10)
dim(lhs, lhs_contracting_dimensions...) = dim(rhs, rhs_contracting_dimensions...)
. - (C11)
size(precision_config) = 2
. - (C12)
shape(result) = dim(lhs, lhs_batching_dimensions) + dim(lhs, lhs_result_dimensions) + dim(rhs, rhs_result_dimensions)
. - Jika operasi menggunakan tensor yang tidak dikuantisasi:
- (C13)
element_type(lhs) = element_type(rhs)
.
- (C13)
- Jika operasi menggunakan tensor kuantisasi:
- (C14)
is_quantized(lhs) = is_quantized(result) and is_quantized(rhs)
. - (C15)
zero_points(rhs) = 0
. - (C16) Jika
is_per_axis_quantized(rhs)
, makaquantization_dimension(rhs)
tidak ada dalamrhs_contracting_dimensions
. - Jika
is_quantized(lhs)
: - (C17)
storage_type(lhs) = storage_type(rhs)
. - (C18)
expressed_type(lhs) = expressed_type(rhs) = expressed_type(result)
. - (C19) Jika
is_per_tensor_quantized(rhs)
, makais_per_tensor_quantized(result)
. - Jika
!is_quantized(lhs)
: - (C20)
element_type(lhs) = expressed_type(rhs) = element_type(result)
.
- (C14)
- Jika
!is_empty_algorithm(lhs_precision_type, rhs_precision_type, accumulation_type, lhs_component_count, rhs_component_count, num_primitive_operations allow_imprecise_accumulation)
:- (C21)
precision_config... = DEFAULT
. - (C22)
0 < lhs_component_count
. - (C23)
0 < rhs_component_count
. - (C24)
0 < num_primitive_operations
.
- (C21)
Contoh
// %lhs: [
// [[1, 2],
// [3, 4]],
// [[5, 6],
// [7, 8]]
// ]
// %rhs: [
// [[1, 0],
// [0, 1]],
// [[1, 0],
// [0, 1]]
// ]
%result = "stablehlo.dot_general"(%lhs, %rhs) {
dot_dimension_numbers = #stablehlo.dot<
lhs_batching_dimensions = [0],
rhs_batching_dimensions = [0],
lhs_contracting_dimensions = [2],
rhs_contracting_dimensions = [1]
>,
precision_config = [#stablehlo<precision DEFAULT>, #stablehlo<precision DEFAULT>],
algorithm = #stablehlo.dot_algorithm<
lhs_precision_type = tf32,
rhs_precision_type = tf32,
accumulation_type = f32,
lhs_component_count = 1,
rhs_component_count = 1,
num_primitive_operations = 1,
allow_imprecise_accumulation = false
>
} : (tensor<2x2x2xi64>, tensor<2x2x2xi64>) -> tensor<2x2x2xi64>
// %result: [
// [[1, 2],
// [3, 4]],
// [[5, 6],
// [7, 8]]
// ]
dynamic_broadcast_in_dim
Semantik
Operasi ini secara fungsional identik dengan operasi broadcast_in_dim, tetapi bentuk hasilnya ditetapkan secara dinamis melalui output_dimensions
.
Operasi tersebut juga menerima atribut opsional known_expanding_dimensions
, known_nonexpanding_dimensions
untuk menyatakan pengetahuan statis tentang perilaku dimensi yang diperluas.
Jika tidak ditentukan, semua dimensi diasumsikan dapat diperluas.
Input
Label | Nama | Jenis | Batasan |
---|---|---|---|
(I1) | operand |
tensor atau tensor terkuantisasi | (C1-C2), (C5-C6), (C9) |
(I2) | output_dimensions |
Tensor 1-dimensi dari jenis integer | (C7) |
(I3) | broadcast_dimensions |
Tensor konstan 1 dimensi dari jenis bilangan bulat | (C2-C6) |
(I4) | known_expanding_dimensions |
Tensor konstanta 1 dimensi dari jenis integer | (C8-C9) |
(I5) | known_nonexpanding_dimensions |
Tensor konstan 1 dimensi dari jenis bilangan bulat | (C8-C9) |
Output
Nama | Jenis | Batasan |
---|---|---|
result |
tensor atau tensor terkuantisasi | (C1), (C3), (C5-C7) |
Batasan
- (C1)
element_type(result)
diberikan oleh:element_type(operand)
, jika!is_per_axis_quantized(operand)
.element_type(operand)
kecuali bahwaquantization_dimension(operand)
,scales(operand)
, danzero_points(operand)
mungkin berbeda dariquantization_dimension(result)
,scales(result)
, danzero_points(result)
masing-masing.
- (C2)
size(broadcast_dimensions) = rank(operand)
. - (C3)
0 <= broadcast_dimensions < rank(result)
. - (C4)
is_unique(broadcast_dimensions)
. - (C5) Untuk semua
d
diaxes(operand)
:dim(operand, d) = 1
ataudim(operand, d) = dim(result, broadcast_dimensions[d])
.
- (C6) Jika
is_per_axis_quantized(result)
:quantization_dimension(result) = broadcast_dimensions[quantization_dimension(operand)]
.- Jika
dim(operand, quantization_dimension(operand)) = 1
, makascales(result)[i] = scales(operand)[0] and zero_points(result)[i] = zero_points(operand)[0] for i in range(dim(result, quantization_dimension(result)))
.
- (C7)
size(output_dimensions) = rank(result)
. - (C8)
is_unique(known_expanding_dimensions + known_nonexpanding_dimensions)
. - (C9)
0 <= known_expanding_dimensions < rank(operand)
. - (C10)
0 <= known_nonexpanding_dimensions < rank(operand)
.
Contoh
// %operand: [
// [1, 2, 3]
// ]
%operand = stablehlo.constant dense<[[1, 2, 3]]> : tensor<1x3xi64>
%output_dimensions = stablehlo.constant dense<[2, 3, 2]> : tensor<3xi64>
%result = "stablehlo.dynamic_broadcast_in_dim"(%operand, %output_dimensions) {
broadcast_dimensions = array<i64: 2, 1>,
known_expanding_dimensions = array<i64: 0>,
known_nonexpanding_dimensions = array<i64: 1>
} : (tensor<1x3xi64>, tensor<3xi64>) -> tensor<2x3x2xi64>
// %result: [
// [
// [1, 1],
// [2, 2],
// [3, 3]
// ],
// [
// [1, 1],
// [2, 2],
// [3, 3]
// ]
// ]
dynamic_conv
Semantik
Operasi ini secara fungsional identik dengan
op
konvolusi, tetapi padding ditentukan secara dinamis melalui padding
.
Input
Label | Nama | Jenis | Batasan |
---|---|---|---|
(I1) | lhs |
tensor atau tensor terkuantisasi per-tensor | (C1), (C10-C11), (C14) (C25), (C26-C27), (C30-C31), (C33) |
(I2) | rhs |
tensor atau tensor terkuantisasi | (C1), (C14-C16), (C26-C28), (C30-C33) |
(I3) | padding |
Tensor 2 dimensi dari jenis bilangan bulat | (C4) |
(I4) | window_strides |
Konstanta tensor 1 dimensi dari jenis si64 |
(C2-C3) |
(I5) | lhs_dilation |
Konstanta tensor 1 dimensi jenis si64 |
(C5-C6) |
(I6) | rhs_dilation |
Konstanta tensor 1 dimensi dari jenis si64 |
(C7-C8) |
(I7) | window_reversal |
Konstanta tensor 1 dimensi dari jenis i1 |
(C9) |
(I8) | input_batch_dimension |
konstanta dari jenis si64 |
(C10), (C13) |
(I9) | input_feature_dimension |
konstanta jenis si64 |
(C11), (C13-C14) |
(I10) | input_spatial_dimensions |
Konstanta tensor 1 dimensi dari jenis si64 |
(C12), C13) |
(I11) | kernel_input_feature_dimension |
konstanta jenis si64 |
(C14), (C18) |
(I12) | kernel_output_feature_dimension |
konstanta dari jenis si64 |
(C15-C16), (C18), (C28) |
(I13) | kernel_spatial_dimensions |
Konstanta tensor 1 dimensi dari jenis si64 |
(C17-C18) |
(I14) | output_batch_dimension |
konstanta dari jenis si64 |
(C20) |
(I15) | output_feature_dimension |
konstanta jenis si64 |
(C20), (C29) |
(I16) | output_spatial_dimensions |
Konstanta tensor 1 dimensi jenis si64 |
(C19-C20) |
(I17) | feature_group_count |
konstanta dari jenis si64 |
(C11), (C14), (C16), (C21), (C23) |
(I18) | batch_group_count |
konstanta dari jenis si64 |
(C10), (C15), (C22), (C23) |
(I19) | precision_config |
jumlah variabel enum DEFAULT , HIGH , dan HIGHEST |
(C24) |
Output
Nama | Jenis | Batasan |
---|---|---|
result |
tensor atau tensor terkuantisasi | (C25-C27), (C29), (C31-C33) |
Batasan
- (C1)
N = rank(lhs) = rank(rhs)
. - (C2)
size(window_strides) = N - 2
. - (C3)
0 < window_strides
. - (C4)
shape(padding) = [N - 2, 2]
. - (C5)
size(lhs_dilation) = N - 2
. - (C6)
0 < lhs_dilation
. - (C7)
size(rhs_dilation) = N - 2
. - (C8)
0 < rhs_dilation
. - (C9)
size(window_reversal) = N - 2
. - (C10)
dim(lhs, input_batch_dimension) % batch_group_count = 0
. - (C11)
dim(lhs, input_feature_dimension) % feature_group_count = 0
. - (C12)
size(input_spatial_dimensions) = N - 2
. - (C13) Dengan
input_dimensions = [input_batch_dimension] + input_spatial_dimensions + [input_feature_dimension]
:is_unique(input_dimensions)
.0 <= input_dimensions < N
.
- (C14)
dim(rhs, kernel_input_feature_dimension) = dim(lhs, input_feature_dimension) / feature_group_count
. - (C15)
dim(rhs, kernel_output_feature_dimension) % batch_group_count = 0
. - (C16)
dim(rhs, kernel_output_feature_dimension) % feature_group_count = 0
. - (C17)
size(kernel_spatial_dimensions) = N - 2
. - (C18) Dengan
kernel_dimensions = kernel_spatial_dimensions + [kernel_input_feature_dimension] + [kernel_output_feature_dimension]
:is_unique(kernel_dimensions)
.0 <= kernel_dimensions < N
.
- (C19)
size(output_spatial_dimensions) = N - 2
. - (C20) Diberikan
output_dimensions = [output_batch_dimension] + output_spatial_dimensions + [output_feature_dimension]
:is_unique(output_dimensions)
.0 <= output_dimensions < N
.
- (C21)
0 < feature_group_count
. - (C22)
0 < batch_group_count
. - (C23)
feature_group_count = 1 or batch_group_count = 1
. - (C24)
size(precision_config) = 2
. - (C25)
dim(result, result_dim)
ditentukan sebagai:dim(lhs, input_batch_dimension) / batch_group_count
ifresult_dim = output_batch_dimension
.dim(rhs, kernel_output_feature_dimension)
jikaresult_dim = output_feature_dimension
.num_windows
jika tidak, dengan:output_spatial_dimensions[spatial_dim] = result_dim
.lhs_dim = input_spatial_dimensions[spatial_dim]
.rhs_dim = kernel_spatial_dimensions[spatial_dim]
.dilated_input_shape[lhs_dim] = dim(lhs, lhs_dim) = 0 ? 0 : (dim(lhs, lhs_dim) - 1) * lhs_dilation[spatial_dim] + 1
.padded_input_shape[lhs_dim] = padding[spatial_dim, 0] + dilated_input_shape[lhs_dim] + padding[spatial_dim, 1]
.dilated_window_shape[lhs_dim] = dim(rhs, rhs_dim) = 0 ? 0 : (dim(rhs, rhs_dim) - 1) * rhs_dilation[spatial_dim] + 1
.is_empty_window[lhs_dim] = padded_input_shape[lhs_dim] = 0 || dilated_window_shape[lhs_dim] > padded_input_shape[lhs_dim]
.num_windows = is_empty_window[lhs_dim] ? 0 : floor((padded_input_shape[lhs_dim] - dilated_window_shape[lhs_dim]) / window_strides[spatial_dim]) + 1
.
- (C26)
rank(result) = N
. - Jika operasi menggunakan tensor yang tidak dikuantisasi:
- (C27)
element_type(lhs) = element_type(rhs) = element_type(result)
.
- (C27)
- Jika operasi menggunakan tensor kuantisasi:
- (C28)
is_quantized(lhs) = is_quantized(result) and is_quantized(rhs)
. - (C29) Jika
is_per_axis_quantized(rhs)
, makaquantization_dimension(rhs) = kernel_output_feature_dimension
. - (C30) Jika
is_per_axis_quantized(result)
, makaquantization_dimension(result) = output_feature_dimension
. - Jika
is_quantized(lhs)
: - (C31)
storage_type(lhs) = storage_type(rhs)
. - (C32)
expressed_type(lhs) = expressed_type(rhs) = expressed_type(result)
. - (C33) Jika
is_per_tensor_quantized(rhs)
, makais_per_tensor_quantized(result)
. - Jika
!is_quantized(lhs)
: - (C34)
element_type(lhs) = expressed_type(rhs) = element_type(result)
.
- (C28)
Contoh
// %lhs: [[
// [[1], [2], [5], [6]],
// [[3], [4], [7], [8]],
// [[10], [11], [14], [15]],
// [[12], [13], [16], [17]]
// ]]
//
// %rhs: [
// [[[1]], [[1]], [[1]]],
// [[[1]], [[1]], [[1]]],
// [[[1]], [[1]], [[1]]]
// ]
// %padding: [[1, 1],
// [1, 1]]
%result = "stablehlo.dynamic_conv"(%lhs, %rhs, %padding) {
window_strides = array<i64: 4, 4>,
lhs_dilation = array<i64: 2, 2>,
rhs_dilation = array<i64: 1, 1>,
window_reversal = array<i1: false, false>,
dimension_numbers = #stablehlo.conv<raw
input_batch_dimension = 0,
input_feature_dimension = 3,
input_spatial_dimensions = [0, 1],
kernel_input_feature_dimension = 2,
kernel_output_feature_dimension = 3,
kernel_spatial_dimensions = [0, 1],
output_batch_dimension = 0,
output_feature_dimension = 3,
output_spatial_dimensions = [1, 2]
>,
feature_group_count = 1 : i64,
batch_group_count = 1 : i64,
precision_config = [#stablehlo<precision DEFAULT>, #stablehlo<precision DEFAULT>]
} : (tensor<1x4x4x1xi64>, tensor<3x3x1x1xi64>, tensor<2x2xi64>) -> tensor<1x2x2x1xi64>
// %result: [[
// [[1], [5]],
// [[10], [14]]
// ]]
dynamic_gather
Semantik
Operasi ini secara fungsional identik dengan operasi mengumpulkan, dengan slice_sizes
yang ditentukan secara dinamis sebagai nilai.
Input
Label | Nama | Jenis | Batasan |
---|---|---|---|
(I1) | operand |
tensor atau tensor terkuantisasi per tensor | (C1), (C7), (C10-C12), (C14) |
(I2) | start_indices |
tensor tipe integer | (C2), (C3), (C13) |
(I3) | slice_sizes |
Tensor 1 dimensi dari jenis bilangan bulat | (C8), (C11-C13) |
(I4) | offset_dims |
Konstanta tensor 1 dimensi dari jenis si64 |
(C1), (C4-C5), (C13) |
(I5) | collapsed_slice_dims |
Konstanta tensor 1 dimensi dari jenis si64 |
(C1), (C6-C8), (C13) |
(I6) | start_index_map |
Konstanta tensor 1 dimensi jenis si64 |
(C3), (C9), (C10) |
(I7) | index_vector_dim |
konstanta dari jenis si64 |
(C2), (C3), (C13) |
(I8) | indices_are_sorted |
konstanta jenis i1 |
Output
Nama | Jenis | Batasan |
---|---|---|
result |
tensor atau tensor terkuantisasi per-tensor | (C5), (C13-C14) |
Batasan
- (C1)
rank(operand) = size(offset_dims) + size(collapsed_slice_dims)
. - (C2)
0 <= index_vector_dim <= rank(start_indices)
. - (C3)
size(start_index_map) = index_vector_dim < rank(start_indices) ? dim(start_indices, index_vector_dim) : 1
. - (C4)
is_unique(offset_dims) and is_sorted(offset_dims)
. - (C5)
0 <= offset_dims < rank(result)
. - (C6)
is_unique(collapsed_slice_dims) and is_sorted(collapsed_slice_dims)
. - (C7)
0 <= collapsed_slice_dims < rank(operand)
. - (C8)
slice_sizes[collapsed_slice_dims...] <= 1
. - (C9)
is_unique(start_index_map)
. - (C10)
0 <= start_index_map < rank(operand)
. - (C11)
size(slice_sizes) = rank(operand)
. - (C12)
0 <= slice_sizes <= shape(operand)
. - (C13)
shape(result) = combine(batch_dim_sizes, offset_dim_sizes)
dengan:batch_dim_sizes = shape(start_indices)
, kecuali ukuran dimensistart_indices
yang sesuai denganindex_vector_dim
tidak disertakan.offset_dim_sizes = shape(slice_sizes)
, kecuali ukuran dimensi dislice_sizes
yang sesuai dengancollapsed_slice_dims
tidak disertakan.combine
menempatkanbatch_dim_sizes
pada sumbu yang sesuai denganbatch_dims
danoffset_dim_sizes
pada sumbu yang sesuai denganoffset_dims
.
- (C14)
element_type(operand) = element_type(result)
.
Contoh
// %operand: [
// [[1, 2], [3, 4], [5, 6], [7, 8]],
// [[9, 10],[11, 12], [13, 14], [15, 16]],
// [[17, 18], [19, 20], [21, 22], [23, 24]]
// ]
// %start_indices: [
// [[0, 0], [1, 0], [2, 1]],
// [[0, 1], [1, 1], [0, 2]]
// ]
// %slize_sizes: [1, 2, 2]
%result = "stablehlo.dynamic_gather"(%operand, %start_indices, %slize_sizes) {
dimension_numbers = #stablehlo.gather<
offset_dims = [2, 3],
collapsed_slice_dims = [0],
start_index_map = [1, 0],
index_vector_dim = 2>,
indices_are_sorted = false
} : (tensor<3x4x2xi64>, tensor<2x3x2xi64>, tensor<3xi64>) -> tensor<2x3x2x2xi64>
// %result: [
// [
// [[1, 2], [3, 4]],
// [[3, 4], [5, 6]],
// [[13, 14], [15, 16]]
// ],
// [
// [[9, 10], [11, 12]],
// [[11, 12], [13, 14]],
// [[17, 18], [19, 20]]
// ]
// ]
dynamic_iota
Semantik
Operasi ini secara fungsional identik dengan operasi iota, tetapi bentuk hasilnya ditetapkan secara dinamis melalui output_shape
.
Input
Label | Nama | Jenis | Batasan |
---|---|---|---|
(I1) | output_shape |
Tensor 1-dimensi dari jenis integer | (C1), (C2) |
(I2) | iota_dimension |
si64 |
(C1) |
Output
Nama | Jenis | Batasan |
---|---|---|
result |
tensor bilangan bulat, floating point, atau jenis kompleks atau tensor terkuantisasi per tensor | (C2) |
Batasan
- (C1)
0 <= iota_dimension < size(output_shape)
. - (C2)
rank(result) = size(output_shape)
.
Contoh
%output_shape = stablehlo.constant dense<[4, 5]> : tensor<2xi64>
%result = "stablehlo.dynamic_iota"(%output_shape) {
iota_dimension = 0 : i64
} : (tensor<2xi64>) -> tensor<4x5xi64>
// %result: [
// [0, 0, 0, 0, 0],
// [1, 1, 1, 1, 1],
// [2, 2, 2, 2, 2],
// [3, 3, 3, 3, 3]
// ]
dynamic_pad
Semantik
Operasi ini secara fungsional identik dengan
op
pad, tetapi dengan edge_padding_low
, edge_padding_high
, dan interior_padding
ditentukan secara dinamis sebagai nilai.
Input
Label | Nama | Jenis | Batasan |
---|---|---|---|
(I1) | operand |
tensor atau tensor terkuantisasi per tensor | (C1), (C2), (C4) |
(I2) | padding_value |
Tensor 0 dimensi atau tensor terkuantisasi per tensor | (C1) |
(I3) | edge_padding_low |
Tensor 1 dimensi dari jenis bilangan bulat | (C1), C4 |
(I4) | edge_padding_high |
Tensor 1 dimensi dari jenis bilangan bulat | (C1), (C4) |
(I5) | interior_padding |
Tensor 1 dimensi dari jenis bilangan bulat | (C2-C4) |
Output
Nama | Jenis | Batasan |
---|---|---|
result |
tensor atau tensor terkuantisasi per tensor | (C3-C6) |
Batasan
- (C1)
element_type(operand) = element_type(padding_value) = element_type(result)
. - (C2)
size(edge_padding_low) = size(edge_padding_high) = size(interior_padding) = rank(operand)
. - (C3)
0 <= interior_padding
. - (C4)
shape(result) = shape(operand) + edge_padding_low + max(shape(operand) - 1, 0) * interior_padding + edge_padding_high
.
Contoh
// %operand: [
// [1, 2, 3],
// [4, 5, 6]
// ]
// %padding_value: 0
// %edge_padding_low: [0, 1]
// %edge_padding_high: [2, 1]
// %interior_padding: [1, 2]
%result = "stablehlo.dynamic_pad"(%operand, %padding_value,
%edge_padding_low, %edge_padding_high, %interior_padding
) : (tensor<2x3xi64>, tensor<i64>, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor<5x9xi64>
// %result: [
// [0, 1, 0, 0, 2, 0, 0, 3, 0],
// [0, 0, 0, 0, 0, 0, 0, 0, 0],
// [0, 4, 0, 0, 5, 0, 0, 6, 0],
// [0, 0, 0, 0, 0, 0, 0, 0, 0],
// [0, 0, 0, 0, 0, 0, 0, 0, 0]
// ]
dynamic_reshape
Semantik
Operasi ini secara fungsional identik dengan
operasi
reshape, tetapi bentuk hasilnya ditentukan secara dinamis melalui output_shape
.
Input
Label | Nama | Jenis | Batasan |
---|---|---|---|
(I1) | operand |
tensor atau tensor terkuantisasi | (C1-C3) |
(I2) | output_shape |
Tensor 1 dimensi dari jenis bilangan bulat | (C4) |
Output
Nama | Jenis | Batasan |
---|---|---|
result |
tensor atau tensor terkuantisasi | (C1-C4) |
Batasan
- (C1)
element_type(result)
diberikan oleh:element_type(operand)
, jika!is_per_axis_quantized(operand)
.element_type(operand)
kecuali bahwaquantization_dimension(operand)
danquantization_dimension(result)
mungkin berbeda, jika tidak.
- (C2)
size(operand) = size(result)
. - (C3) Jika
is_per_axis_quantized(operand)
:reduce(dims(operand, [0, 1, ..., quantization_dimension(operand) - 1]), init_values=1, dimensions=[0], body=lambda x, y: x * y) = reduce(dims(result, [0, 1, ..., quantization_dimension(result) - 1]), init_values=1, dimensions=[0], body=lambda x, y: x * y)
.dim(operand, quantization_dimension(operand)) = dim(result, quantization_dimension(result))
.reduce(dims(operand, [quantization_dimension(operand) + 1, ..., rank(operand) - 1]), init_values=1, dimensions=[0], body=lambda x, y: x * y) = reduce(dims(result, [quantization_dimension(result) + 1, ..., rank(result) - 1]), init_values=1, dimensions=[0], body=lambda x, y: x * y)
.
- (C4)
size(output_shape) = rank(result)
.
Contoh
// %operand: [[1, 2, 3], [4, 5, 6]]
// %output_shape: [3, 2]
%result = "stablehlo.dynamic_reshape"(%operand, %output_shape) : (tensor<2x3xi64>, tensor<2xi64>) -> tensor<3x2xi64>
// %result: [[1, 2], [3, 4], [5, 6]]
dynamic_slice
Semantik
Mengekstrak slice dari operand
menggunakan indeks awal yang dihitung secara dinamis
dan menghasilkan tensor result
. start_indices
berisi indeks awal
slice untuk setiap dimensi yang dapat disesuaikan, dan slice_sizes
berisi ukuran slice untuk setiap dimensi. Secara lebih formal,
result[result_index] = operand[operand_index]
dengan:
adjusted_start_indices = clamp(0, start_indices, shape(operand) - slice_sizes)
.operand_index = adjusted_start_indices + result_index
.
Input
Label | Nama | Jenis | Batasan |
---|---|---|---|
(I1) | operand |
tensor atau tensor terkuantisasi per tensor | (C1), (C2), (C4) |
(I2) | start_indices |
jumlah variabel dari tensor 0 dimensi dengan jenis bilangan bulat | (C2), (C3) |
(I3) | slice_sizes |
Konstanta tensor 1 dimensi dari jenis si64 |
(C2), (C4), (C5) |
Output
Nama | Jenis | Batasan |
---|---|---|
result |
tensor atau tensor terkuantisasi per tensor | (C1), (C5) |
Batasan
- (C1)
element_type(operand) = element_type(result)
. - (C2)
size(start_indices) = size(slice_sizes) = rank(operand)
. - (C3)
same(type(start_indices...))
. - (C4)
0 <= slice_sizes <= shape(operand)
. - (C5)
shape(result) = slice_sizes
.
Contoh
// %operand: [
// [0, 0, 1, 1],
// [0, 0, 1, 1],
// [0, 0, 0, 0],
// [0, 0, 0, 0]
// ]
// %start_indices0: -1
// %start_indices1: 3
%result = "stablehlo.dynamic_slice"(%operand, %start_indices0, %start_indices1) {
slice_sizes = array<i64: 2, 2>
} : (tensor<4x4xi32>, tensor<i64>, tensor<i64>) -> tensor<2x2xi32>
// %result: [
// [1, 1],
// [1, 1]
// ]
dynamic_update_slice
Semantik
Menghasilkan tensor result
yang sama dengan tensor operand
, kecuali bahwa
slice yang dimulai dari start_indices
diperbarui dengan nilai di update
.
Secara lebih formal, result[result_index]
didefinisikan sebagai:
update[update_index]
jika0 <= update_index < shape(update)
di mana:adjusted_start_indices = clamp(0, start_indices, shape(operand) - shape(update))
.update_index = result_index - adjusted_start_indices
.
operand[result_index]
sebaliknya.
Input
Label | Nama | Jenis | Batasan |
---|---|---|---|
(I1) | operand |
tensor atau tensor terkuantisasi per tensor | (C1-C4), (C6) |
(I2) | update |
tensor atau tensor terkuantisasi per-tensor | (C2), (C3), (C6) |
(I3) | start_indices |
jumlah variabel dari tensor 0 dimensi dengan jenis bilangan bulat | (C4), (C5) |
Output
Nama | Jenis | Batasan |
---|---|---|
result |
tensor atau tensor terkuantisasi per tensor | (C1) |
Batasan
- (C1)
type(operand) = type(result)
. - (C2)
element_type(update) = element_type(operand)
. - (C3)
rank(update) = rank(operand)
. - (C4)
size(start_indices) = rank(operand)
. - (C5)
same(type(start_indices...))
. - (C6)
0 <= shape(update) <= shape(operand)
.
Contoh
// %operand: [
// [1, 1, 0, 0],
// [1, 1, 0, 0],
// [1, 1, 1, 1],
// [1, 1, 1, 1]
// ]
// %update: [
// [1, 1],
// [1, 1]
// ]
// %start_indices0: -1
// %start_indices1: 3
%result = "stablehlo.dynamic_update_slice"(%operand, %update, %start_indices0, %start_indices1)
: (tensor<4x4xi32>, tensor<2x2xi32>, tensor<i64>, tensor<i64>) -> tensor<4x4xi32>
// %result: [
// [1, 1, 1, 1],
// [1, 1, 1, 1],
// [1, 1, 1, 1],
// [1, 1, 1, 1]
// ]
berpangkat
Semantik
Melakukan operasi eksponensial element-wise pada tensor operand
dan menghasilkan
tensor result
. Bergantung pada jenis elemen, lakukan tindakan berikut:
- Untuk float:
exp
dari IEEE-754. - Untuk bilangan kompleks: eksponensial kompleks.
- Untuk jenis kuantisasi:
dequantize_op_quantize(exponential, operand, type(result))
.
Input
Label | Nama | Jenis | Batasan |
---|---|---|---|
(I1) | operand |
tensor floating point atau jenis kompleks atau tensor terkuantisasi per-tensor | (C1) |
Output
Nama | Jenis | Batasan |
---|---|---|
result |
tensor dari jenis floating point atau kompleks atau tensor terkuantisasi per tensor | (C1) |
Batasan
- (C1)
baseline_type(operand) = baseline_type(result)
.
Contoh
// %operand: [[0.0, 1.0], [2.0, 3.0]]
%result = "stablehlo.exponential"(%operand) : (tensor<2x2xf64>) -> tensor<2x2xf64>
// %result: [[1.0, 2.7182818284590451], [7.3890560989306504, 20.085536923187668]]
eksponensial_minus_satu
Semantik
Menjalankan eksponensial element-wise dikurangi satu operasi pada tensor operand
dan
menghasilkan tensor result
. Bergantung pada jenis elemen, lakukan tindakan berikut:
- Untuk float:
expm1
dari IEEE-754. - Untuk bilangan kompleks: eksponensial kompleks dikurangi satu.
- Untuk jenis kuantisasi:
dequantize_op_quantize(exponential_minus_one, operand, type(result))
.
Input
Label | Nama | Jenis | Batasan |
---|---|---|---|
(I1) | operand |
tensor floating point atau jenis kompleks atau tensor terkuantisasi per-tensor | (C1) |
Output
Nama | Jenis | Batasan |
---|---|---|
result |
tensor dari jenis floating point atau kompleks atau tensor terkuantisasi per tensor | (C1) |
Batasan
- (C1)
baseline_type(operand) = baseline_type(result)
.
Contoh
// %operand: [0.0, 1.0]
%result = "stablehlo.exponential_minus_one"(%operand) : (tensor<2xf64>) -> tensor<2xf64>
// %result: [0.0, 1.71828187]
fft
Semantik
Melakukan transformasi Fourier maju dan balik untuk input/output real dan kompleks.
fft_type
adalah salah satu dari berikut ini:
FFT
: Meneruskan FFT kompleks ke kompleks.IFFT
: FFT kompleks-ke-kompleks terbalik.RFFT
: Meneruskan FFT real-to-complex.IRFFT
: FFT real-ke-kompleks terbalik (yaitu yang kompleks, menampilkan nilai nyata).
Secara lebih formal, dengan fungsi fft
yang menggunakan tensor 1 dimensi dari
jenis kompleks sebagai input, menghasilkan tensor 1 dimensi dari jenis yang sama sebagai
output dan menghitung transformasi Fourier diskret:
Untuk fft_type = FFT
, result
ditentukan sebagai hasil akhir dari serangkaian komputasi
L dengan L = size(fft_length)
. Misalnya, untuk L = 3
:
result1[i0, ..., :] = fft(operand[i0, ..., :])
.result2[i0, ..., :, iR-1] = fft(result1[i0, ..., :, iR-1])
.result[i0, ..., :, iR-2, iR-1] = fft(result2[i0, ..., :, iR-2, iR-1])
.
Selanjutnya, mengingat fungsi ifft
yang memiliki tanda tangan jenis yang sama dan menghitung invers dari fft
:
Untuk fft_type = IFFT
, result
ditentukan sebagai invers komputasi
untuk fft_type = FFT
. Misalnya, untuk L = 3
:
result1[i0, ..., :, iR-2, iR-1] = ifft(operand[i0, ..., :, iR-2, iR-1])
.result2[i0, ..., :, iR-1] = ifft(result1[i0, ..., :, iR-1])
.result[i0, ..., :] = ifft(result2[i0, ..., :])
.
Selain itu, dengan fungsi rfft
yang menggunakan tensor 1 dimensi dari
jenis floating point, menghasilkan tensor 1 dimensi dari jenis kompleks dari
semantik floating point yang sama dan berfungsi sebagai berikut:
rfft(real_operand) = truncated_result
di manacomplex_operand... = (real_operand..., 0.0)
.complex_result = fft(complex_operand)
.truncated_result = complex_result[:(rank(complex_result) / 2 + 1)]
.
(Saat transformasi Fourier diskret dihitung untuk operand nyata, elemen
N/2 + 1
pertama dari hasil akan menentukan sisa hasil secara tidak ambigu,
sehingga hasil rfft
akan terpotong untuk menghindari komputasi elemen yang redundan).
Untuk fft_type = RFFT
, result
ditentukan sebagai hasil akhir dari serangkaian komputasi
L dengan L = size(fft_length)
. Misalnya, untuk L = 3
:
result1[i0, ..., :] = rfft(operand[i0, ..., :])
.result2[i0, ..., :, iR-1] = fft(result1[i0, ..., :, iR-1])
.result[i0, ..., :, iR-2, iR-1] = fft(result2[i0, ..., :, iR-2, iR-1])
.
Terakhir, dengan fungsi irfft
yang memiliki tanda tangan jenis yang sama dan menghitung invers dari rfft
:
Untuk fft_type = IRFFT
, result
ditentukan sebagai invers komputasi
untuk fft_type = RFFT
. Misalnya, untuk L = 3
:
result1[i0, ..., :, iR-2, iR-1] = ifft(operand[i0, ..., :, iR-2, iR-1])
.result2[i0, ..., :, iR-1] = ifft(result1[i0, ..., :, iR-1])
.result[i0, ..., :] = irfft(result2[i0, ..., :])
.
Input
Label | Nama | Jenis | Batasan |
---|---|---|---|
(I1) | operand |
tensor dari jenis floating point atau kompleks | (C1), (C2), (C4), (C5) |
(I2) | fft_type |
enum FFT , IFFT , RFFT , dan IRFFT |
(C2), (C5) |
(I3) | fft_length |
Konstanta tensor 1 dimensi dari jenis si64 |
(C1), (C3), (C4) |
Output
Nama | Jenis | Batasan |
---|---|---|
result |
tensor floating point atau jenis kompleks | (C2), (C4), (C5) |
Batasan
- (C1)
size(fft_length) <= rank(operand)
. - (C2) Hubungan antara jenis elemen
operand
danresult
bervariasi:- Jika
fft_type = FFT
,element_type(operand)
, danelement_type(result)
memiliki jenis kompleks yang sama. - Jika
fft_type = IFFT
,element_type(operand)
, danelement_type(result)
memiliki jenis kompleks yang sama. - Jika
fft_type = RFFT
,element_type(operand)
adalah jenis floating point danelement_type(result)
adalah jenis kompleks dari semantik floating point yang sama. - Jika
fft_type = IRFFT
,element_type(operand)
adalah jenis kompleks danelement_type(result)
adalah jenis floating point dari semantik floating point yang sama.
- Jika
- (C3)
1 <= size(fft_length) <= 3
. - (C4) Jika di antara
operand
danresult
, ada tensorreal
dari jenis floating point, makashape(real)[-size(fft_length):] = fft_length
. - (C5)
shape(result) = shape(operand)
kecuali untuk:- Jika
fft_type = RFFT
,dim(result, -1) = dim(operand, -1) = 0 ? 0 : dim(operand, -1) / 2 + 1
. - Jika
fft_type = IRFFT
,dim(operand, -1) = dim(result, -1) = 0 ? 0 : dim(result, -1) / 2 + 1
.
- Jika
Contoh
// %operand: [(1.0, 0.0), (0.0, 0.0), (0.0, 0.0), (0.0, 0.0)]
%result = "stablehlo.fft"(%operand) {
fft_type = #stablehlo<fft_type FFT>,
fft_length = array<i64: 4>
} : (tensor<4xcomplex<f32>>) -> tensor<4xcomplex<f32>>
// %result: [(1.0, 0.0), (1.0, 0.0), (1.0, 0.0), (1.0, 0.0)]
lantai
Semantik
Melakukan floor element-wise dari tensor operand
dan menghasilkan tensor result
.
Mengimplementasikan operasi roundToIntegralTowardNegative
dari spesifikasi
IEEE-754. Untuk jenis kuantisasi, lakukan
dequantize_op_quantize(floor, operand, type(result))
.
Input
Label | Nama | Jenis | Batasan |
---|---|---|---|
(I1) | operand |
tensor jenis floating point atau tensor terkuantisasi per tensor | (C1) |
Output
Nama | Jenis | Batasan |
---|---|---|
result |
tensor jenis floating point atau tensor terkuantisasi per tensor | (C1) |
Batasan
- (C1)
baseline_type(operand) = baseline_type(result)
.
Contoh
// %operand: [-0.8166, -0.2530, 0.2530, 0.8166, 2.0]
%result = "stablehlo.floor"(%operand) : (tensor<5xf32>) -> tensor<5xf32>
// %result: [-1.0, -1.0, 0.0, 0.0, 2.0]
mengumpulkan
Semantik
Mengumpulkan slice dari tensor operand
dari offset yang ditentukan dalam start_indices
dan menghasilkan tensor result
.
Diagram berikut menunjukkan cara elemen di result
memetakan elemen di
operand
menggunakan contoh konkret. Diagram ini mengambil beberapa contoh indeks result
dan menjelaskan secara mendetail indeks operand
mana yang sesuai dengannya.
Secara lebih formal, result[result_index] = operand[operand_index]
dengan:
batch_dims = [d for d in axes(result) and d not in offset_dims]
.batch_index = result_index[batch_dims...]
.start_index
ditentukan sebagai:start_indices[bi0, ..., :, ..., biN]
denganbi
adalah elemen individual dalambatch_index
dan:
disisipkan pada indeksindex_vector_dim
, jikaindex_vector_dim
<rank(start_indices)
.[start_indices[batch_index]]
jika tidak.
- Untuk
d_operand
diaxes(operand)
,full_start_index[d_operand] = clamp(start_index[d_start], 0, dim(operand, d_operand) - slice_sizes[d_operand])
ifd_operand = start_index_map[d_start]
.full_start_index[d_operand] = 0
jika tidak.
- Untuk
d_operand
diaxes(operand)
,full_batching_index[d_operand] = batch_index[d_start - (d_start < index_vector_dim ? 0 : 1)]
jikad_operand = operand_batching_dims[i_batching]
dand_start = start_indices_batching_dims[i_batching]
.full_batching_index[d_operand] = 0
jika tidak.
offset_index = result_index[offset_dims...]
.full_offset_index = [oi0, ..., 0, ..., oiN]
denganoi
adalah elemen individual dioffset_index
, dan0
disisipkan pada indeks daricollapsed_slice_dims
danoperand_batching_dims
.operand_index = full_start_index + full_batching_index + full_offset_index
.
Jika indices_are_sorted
adalah true
, implementasi dapat mengasumsikan bahwa
start_indices
diurutkan sehubungan dengan start_index_map
, jika tidak,
perilaku tidak ditentukan. Secara lebih formal, untuk semua i1 < i2
dari indices(result)
,
full_start_index(i1) <= full_start_index(i2)
.
Input
Label | Nama | Jenis | Batasan |
---|---|---|---|
(I1) | operand |
tensor atau tensor terkuantisasi per-tensor | (C1), (C8), (C11), (C17), (C19-C21), (C23) |
(I2) | start_indices |
tensor tipe integer | (C2-C3), (C14), (C17), (C22) |
(I3) | offset_dims |
Konstanta tensor 1 dimensi jenis si64 |
(C1), (C4-C5), (C22) |
(I4) | collapsed_slice_dims |
Konstanta tensor 1 dimensi dari jenis si64 |
(C1), (C6-C9), (C22) |
(I5) | operand_batching_dims |
Konstanta tensor 1 dimensi dari jenis si64 |
(C1), (C6), (C10-C12), (C16-C18), (C22) |
(I6) | start_indices_batching_dims |
Konstanta tensor 1 dimensi jenis si64 |
(C13-C17) |
(I7) | start_index_map |
Konstanta tensor 1 dimensi dari jenis si64 |
(C3), (C18-C19) |
(I8) | index_vector_dim |
konstanta dari jenis si64 |
(C2-C3), (C15), (C22) |
(I9) | slice_sizes |
Konstanta tensor 1 dimensi dari jenis si64 |
(C9), (C12), (C20-C22) |
(I10) | indices_are_sorted |
konstanta dari jenis i1 |
Output
Nama | Jenis | Batasan |
---|---|---|
result |
tensor atau tensor terkuantisasi per tensor | (C5), (C22-C23) |
Batasan
- (C1)
rank(operand) = size(offset_dims) + size(collapsed_slice_dims) + size(operand_batching_dims)
. - (C2)
0 <= index_vector_dim <= rank(start_indices)
. - (C3)
size(start_index_map) = index_vector_dim < rank(start_indices) ? dim(start_indices, index_vector_dim) : 1
. - (C4)
is_unique(offset_dims) and is_sorted(offset_dims)
. - (C5)
0 <= offset_dims < rank(result)
. - (C6)
is_unique(concatenate(collapsed_slice_dims, operand_batching_dims))
- (C7)
is_sorted(collapsed_slice_dims)
. - (C8)
0 <= collapsed_slice_dims < rank(operand)
. - (C9)
slice_sizes[collapsed_slice_dims...] <= 1
. - (C10)
is_sorted(operand_batching_dims)
. - (C11)
0 <= operand_batching_dims < rank(operand)
. - (C12)
slice_sizes[operand_batching_dims...] <= 1
. - (C13)
is_unique(start_indices_batching_dims)
. - (C14)
0 <= start_indices_batching_dims < rank(start_indices)
. - (C15)
index_vector_dim not in start_indices_batching_dims
. - (C16)
size(operand_batching_dims) == size(start_indices_batching_dims)
. - (C17)
dim(operand, operand_batching_dims...) = dim(start_indices, start_indices_batching_dims...)
. - (C18)
is_unique(concatenate(start_index_map, operand_batching_dims))
. - (C19)
0 <= start_index_map < rank(operand)
. - (C20)
size(slice_sizes) = rank(operand)
. - (C21)
0 <= slice_sizes <= shape(operand)
. - (C22)
shape(result) = combine(batch_dim_sizes, offset_dim_sizes)
dengan:batch_dim_sizes = shape(start_indices)
, tetapi ukuran dimensistart_indices
yang sesuai denganindex_vector_dim
tidak disertakan.offset_dim_sizes = slice_sizes
, kecuali ukuran dimensi dislice_sizes
yang sesuai dengancollapsed_slice_dims
danoperand_batching_dims
tidak disertakan.combine
menempatkanbatch_dim_sizes
pada sumbu yang sesuai denganbatch_dims
danoffset_dim_sizes
pada sumbu yang sesuai denganoffset_dims
.
- (C23)
element_type(operand) = element_type(result)
.
Contoh
// %operand: [
// [
// [[1, 2], [3, 4], [5, 6], [7, 8]],
// [[9, 10],[11, 12], [13, 14], [15, 16]],
// [[17, 18], [19, 20], [21, 22], [23, 24]]
// ],
// [
// [[25, 26], [27, 28], [29, 30], [31, 32]],
// [[33, 34], [35, 36], [37, 38], [39, 40]],
// [[41, 42], [43, 44], [45, 46], [47, 48]]
// ]
// ]
// %start_indices: [
// [
// [[0, 0], [1, 0], [2, 1]],
// [[0, 1], [1, 1], [0, 9]]
// ],
// [
// [[0, 0], [2, 1], [2, 2]],
// [[1, 2], [0, 1], [1, 0]]
// ]
// ]
%result = "stablehlo.gather"(%operand, %start_indices) {
dimension_numbers = #stablehlo.gather<
offset_dims = [3, 4],
collapsed_slice_dims = [1],
operand_batching_dims = [0],
start_indices_batching_dims = [1],
start_index_map = [2, 1],
index_vector_dim = 3>,
slice_sizes = array<i64: 1, 1, 2, 2>,
indices_are_sorted = false
} : (tensor<2x3x4x2xi32>, tensor<2x2x3x2xi64>) -> tensor<2x2x3x2x2xi32>
// %result: [
// [
// [
// [[1, 2], [3, 4]],
// [[3, 4], [5, 6]],
// [[13, 14], [15, 16]]
// ],
// [
// [[33, 34], [35, 36]],
// [[35, 36], [37, 38]],
// [[41, 42], [43, 44]]
// ]
// ],
// [
// [
// [[1, 2], [3, 4]],
// [[13, 14], [15, 16]],
// [[21, 22], [23, 24]]
// ],
// [
// [[43, 44], [45, 46]],
// [[33, 34], [35, 36]],
// [[27, 28], [29, 30]]
// ]
// ]
// ]
get_dimension_size
Semantik
Menghasilkan ukuran dimension
yang diberikan dari operand
. Secara lebih formal,
result = dim(operand, dimension)
. Semantik hanya berkaitan dengan komponen
bentuk jenis. Jenis elemen dapat berupa apa saja.
Input
Label | Nama | Jenis | Batasan |
---|---|---|---|
(I1) | operand |
tensor atau tensor terkuantisasi | (C1) |
(I2) | dimension |
konstanta dari jenis si64 |
(C1) |
Output
Nama | Jenis |
---|---|
result |
Tensor 0 dimensi dari jenis si32 |
Batasan
- (C1)
0 <= dimension < rank(operand)
.
Contoh
// %operand: [[1, 2, 3], [4, 5, 6]]
%result = "stablehlo.get_dimension_size"(%operand) {
dimension = 1 : i64
} : (tensor<2x3xi64>) -> tensor<i32>
// %result: 3
get_tuple_element
Semantik
Mengekstrak elemen pada posisi index
tuple operand
dan menghasilkan
result
. Secara lebih formal, result = operand[index]
.
Input
Label | Nama | Jenis | Batasan |
---|---|---|---|
(I1) | operand |
tuple | (C1), (C2) |
(I2) | index |
konstanta dari jenis si32 |
(C1), C2 |
Output
Nama | Jenis | Batasan |
---|---|---|
result |
jenis apa pun yang didukung | (C2) |
Batasan
- (C1)
0 <= index < size(operand)
. - (C2)
type(result) = tuple_element_types(operand)[index]
.
Contoh
// %operand: ([1.0, 2.0], (3))
index = 0 : i32
} : (tuple<tensor<2xf32>, tuple<tensor<i32>>>) -> tensor<2xf32>
// %result: [1.0, 2.0]
jika
Semantik
Menghasilkan output dari menjalankan tepat satu fungsi dari true_branch
atau
false_branch
, bergantung pada nilai pred
. Secara lebih formal, result =
pred ? true_branch() : false_branch()
.
Input
Label | Nama | Jenis | Batasan |
---|---|---|---|
(I1) | pred |
Tensor 0 dimensi dari jenis i1 |
|
(I2) | true_branch |
fungsi | (C1-C3) |
(I3) | false_branch |
fungsi | (C1), (C2) |
Output
Nama | Jenis | Batasan |
---|---|---|
results |
jumlah variabel tensor, tensor terkuantisasi, atau token | (C3) |
Batasan
- (C1)
input_types(true_branch) = input_types(false_branch) = []
. - (C2)
output_types(true_branch) = output_types(false_branch)
. - (C3)
type(results...) = output_types(true_branch)
.
Contoh
// %result_true_branch: 10
// %result_false_branch: 11
// %pred: true
%result = "stablehlo.if"(%pred) ({
"stablehlo.return"(%result_true_branch) : (tensor<i32>) -> ()
}, {
"stablehlo.return"(%result_false_branch) : (tensor<i32>) -> ()
}) : (tensor<i1>) -> tensor<i32>
// %result: 10
imag
Semantik
Mengekstrak bagian imajiner, per elemen, dari operand
dan menghasilkan
tensor result
. Secara lebih formal, untuk setiap elemen x
:
imag(x) = is_complex(x) ? imaginary_part(x) :
constant(0, element_type(result))
.
Input
Label | Nama | Jenis | Batasan |
---|---|---|---|
(I1) | operand |
tensor dari jenis floating point atau kompleks | (C1), C2 |
Output
Nama | Jenis | Batasan |
---|---|---|
result |
tensor jenis floating point | (C1), (C2) |
Batasan
- (C1)
shape(result) = shape(operand)
. - (C2)
element_type(result)
ditentukan sebagai:complex_element_type(element_type(operand))
ifis_complex(operand)
.element_type(operand)
jika tidak.
Contoh
// %operand: [(1.0, 2.0), (3.0, 4.0)]
%result = "stablehlo.imag"(%operand) : (tensor<2xcomplex<f32>>) -> tensor<2xf32>
// %result: [2.0, 4.0]
dalam feed
Semantik
Membaca data dari infeed dan menghasilkan results
.
Semantik infeed_config
ditentukan oleh implementasi.
results
terdiri dari nilai payload yang muncul terlebih dahulu dan token yang muncul
terakhir. Pada masa mendatang, kami berencana untuk membagi payload dan token menjadi dua
output terpisah untuk meningkatkan kejelasan
(#670).
Input
Label | Nama | Jenis |
---|---|---|
(I1) | token |
token |
(I2) | infeed_config |
konstanta dari jenis string |
Output
Nama | Jenis | Batasan |
---|---|---|
results |
jumlah tensor, tensor atau token terkuantisasi | (C1-C3) |
Batasan
- (C1)
0 < size(results)
. - (C2)
is_empty(result[:-1])
atauis_tensor(type(results[:-1]))
. - (C3)
is_token(type(results[-1]))
.
Contoh
// %token: !stablehlo.token
// infeed_queue[0]: [[1, 2], [3, 4]]
// infeed_queue[1]: [[5, 6], [7, 8]]
%results0:2 = "stablehlo.infeed"(%token) {
infeed_config = ""
} : (!stablehlo.token) -> (tensor<2x2xi64>, !stablehlo.token)
// results0#0: [[1, 2], [3, 4]]
%results1:2 = "stablehlo.infeed"(%token) {
infeed_config = ""
} : (!stablehlo.token) -> (tensor<2x2xi64>, !stablehlo.token)
// results1#0: [[5, 6], [7, 8]]
iota
Semantik
Mengisi tensor output
dengan nilai dalam urutan menaik mulai dari nol
di sepanjang dimensi iota_dimension
. Secara lebih formal,
output[output_index] = constant(is_quantized(output) ?
quantize(output_index[iota_dimension], element_type(output)) :
output_index[iota_dimension], element_type(output))
.
Input
Label | Nama | Jenis | Batasan |
---|---|---|---|
(I1) | iota_dimension |
si64 |
(C1) |
Output
Nama | Jenis | Batasan |
---|---|---|
output |
tensor bilangan bulat, floating point, atau jenis kompleks atau tensor terkuantisasi per tensor | (C1) |
Batasan
- (C1)
0 <= iota_dimension < rank(output)
.
Contoh
%output = "stablehlo.iota"() {
iota_dimension = 0 : i64
} : () -> tensor<4x5xi32>
// %output: [
// [0, 0, 0, 0, 0],
// [1, 1, 1, 1, 1],
// [2, 2, 2, 2, 2],
// [3, 3, 3, 3, 3]
// ]
%output = "stablehlo.iota"() {
iota_dimension = 1 : i64
} : () -> tensor<4x5xi32>
// %output: [
// [0, 1, 2, 3, 4],
// [0, 1, 2, 3, 4],
// [0, 1, 2, 3, 4],
// [0, 1, 2, 3, 4]
// ]
is_finite
Semantik
Melakukan pemeriksaan elemen apakah nilai dalam x
terbatas (yaitu bukan
+Inf, -Inf, atau NaN) dan menghasilkan tensor y
. Mengimplementasikan operasi isFinite
dari spesifikasi IEEE-754. Untuk jenis kuantisasi, hasilnya
selalu true
.
Input
Label | Nama | Jenis | Batasan |
---|---|---|---|
(I1) | x |
tensor jenis floating point atau tensor terkuantisasi per tensor | (C1) |
Output
Nama | Jenis | Batasan |
---|---|---|
y |
tensor dari jenis boolean | (C1) |
Batasan
- (C1)
shape(x) = shape(y)
.
Contoh
// Logical values: -Inf, +Inf, NaN, ...
// %x: [0xFFF0000000000000, 0x7FF0000000000000, 0x7FF8000000000000, -10.0, -0.0, 0.0, 10.0]
%y = "stablehlo.is_finite"(%x) : (tensor<7xf64) -> tensor<7xi1>
// %y: [false, false, false, true, true, true, true]
log
Semantik
Melakukan operasi logaritma element-wise pada tensor operand
dan menghasilkan
tensor result
. Bergantung pada jenis elemen, melakukan hal berikut:
- Untuk float:
log
dari IEEE-754. - Untuk bilangan kompleks: logaritma kompleks.
- Untuk jenis kuantisasi:
dequantize_op_quantize(log, operand, type(result))
.
Input
Label | Nama | Jenis | Batasan |
---|---|---|---|
(I1) | operand |
tensor floating point atau jenis kompleks atau tensor terkuantisasi per-tensor | (C1) |
Output
Nama | Jenis | Batasan |
---|---|---|
result |
tensor dari jenis floating point atau kompleks atau tensor terkuantisasi per tensor | (C1) |
Batasan
- (C1)
baseline_type(operand) = baseline_type(result)
.
Contoh
// %operand: [[1.0, 2.0], [3.0, 4.0]]
%result = "stablehlo.log"(%operand) : (tensor<2x2xf64>) -> tensor<2x2xf64>
// %result: [[0.0, 0.69314718055994529], [1.0986122886681098, 1.3862943611198906]]
log_plus_one
Semantik
Melakukan logaritma element-wise plus satu operasi pada tensor operand
dan
menghasilkan tensor result
. Bergantung pada jenis elemen, melakukan hal berikut:
- Untuk float:
logp1
dari IEEE-754. - Untuk bilangan kompleks: logaritma kompleks ditambah satu.
- Untuk jenis kuantisasi:
dequantize_op_quantize(log_plus_one, operand, type(result))
.
Input
Label | Nama | Jenis | Batasan |
---|---|---|---|
(I1) | operand |
tensor floating point atau jenis kompleks atau tensor terkuantisasi per-tensor | (C1) |
Output
Nama | Jenis | Batasan |
---|---|---|
result |
tensor dari jenis floating point atau kompleks atau tensor terkuantisasi per tensor | (C1) |
Batasan
- (C1)
baseline_type(operand) = baseline_type(result)
.
Contoh
// %operand: [0.0, -0.999, 7.0, 6.38905621, 15.0]
%result = "stablehlo.log_plus_one"(%operand) : (tensor<5xf64>) -> tensor<5xf64>
// %result: [0.0, -6.90776825, 2.07944155, 2.0, 2.77258873]
logistik
Semantik
Melakukan operasi logistik element-wise pada tensor operand
dan menghasilkan
tensor result
. Bergantung pada jenis elemen, lakukan tindakan berikut:
- Untuk float:
division(1, addition(1, exp(-x)))
dari IEEE-754. - Untuk bilangan kompleks: logistik kompleks.
- Untuk jenis terkuantisasi:
dequantize_op_quantize(logistic, operand, type(result))
.
Input
Label | Nama | Jenis | Batasan |
---|---|---|---|
(I1) | operand |
tensor floating point atau jenis kompleks atau tensor terkuantisasi per-tensor | (C1) |
Output
Nama | Jenis | Batasan |
---|---|---|
result |
tensor dari jenis floating point atau kompleks atau tensor terkuantisasi per tensor | (C1) |
Batasan
- (C1)
baseline_type(operand) = baseline_type(result)
.
Contoh
// %operand: [[0.0, 1.0], [2.0, 3.0]]
%result = "stablehlo.logistic"(%operand) : (tensor<2x2xf64>) -> tensor<2x2xf64>
// %result: [[0.5, 0.73105858], [0.88079708, 0.95257413]]
peta
Semantik
Menerapkan fungsi peta computation
ke inputs
di sepanjang dimensions
dan
menghasilkan tensor result
.
Secara lebih formal, result[result_index] = computation(inputs...[result_index])
.
Input
Label | Nama | Jenis | Batasan |
---|---|---|---|
(I1) | inputs |
jumlah variabel tensor atau tensor terkuantisasi per tensor | (C1-C4) |
(I2) | dimensions |
Konstanta tensor 1 dimensi dari jenis si64 |
(C3) |
(I3) | computation |
fungsi | (C4) |
Output
Nama | Jenis | Batasan |
---|---|---|
result |
tensor atau tensor terkuantisasi per-tensor | (C1), (C4) |
Batasan
- (C1)
shape(inputs...) = shape(result)
. - (C2)
0 < size(inputs) = N
. - (C3)
dimensions = range(rank(inputs[0]))
. - (C4)
computation
memiliki jenis(tensor<E0>, ..., tensor<EN-1>) -> tensor<E'>
denganEi = element_type(inputs[i])
danE' = element_type(result)
.
Contoh
// %input0: [[0, 1], [2, 3]]
// %input1: [[4, 5], [6, 7]]
%result = "stablehlo.map"(%input0, %input1) ({
^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
%0 = stablehlo.multiply %arg0, %arg1 : tensor<i64>
stablehlo.return %0 : tensor<i64>
}) {
dimensions = array<i64: 0, 1>
} : (tensor<2x2xi64>, tensor<2x2xi64>) -> tensor<2x2xi64>
// %result: [[0, 5], [12, 21]]
maksimum
Semantik
Melakukan operasi maksimum element-wise pada tensor lhs
dan rhs
serta menghasilkan
tensor result
. Bergantung pada jenis elemen, melakukan hal berikut:
- Untuk boolean: OR logika.
- Untuk bilangan bulat: maksimum bilangan bulat.
- Untuk float:
maximum
dari IEEE-754. - Untuk bilangan kompleks: maksimum leksikografis untuk pasangan
(real, imaginary)
. Memaksakan pengurutan pada angka kompleks melibatkan semantik yang mengejutkan, jadi di masa mendatang kami berencana menghapus dukungan untuk bilangan kompleks untuk operasi ini (#560). - Untuk jenis kuantisasi:
dequantize_op_quantize(maximum, lhs, rhs, type(result))
.
Input
Label | Nama | Jenis | Batasan |
---|---|---|---|
(I1) | lhs |
tensor atau tensor terkuantisasi per-tensor | (C1) |
(I2) | rhs |
tensor atau tensor terkuantisasi per tensor | (C1) |
Output
Nama | Jenis | Batasan |
---|---|---|
result |
tensor atau tensor terkuantisasi per tensor | (C1) |
Batasan
- (C1)
baseline_type(lhs) = baseline_type(rhs) = baseline_type(result)
.
Contoh
// %lhs: [[1, 2], [7, 8]]
// %rhs: [[5, 6], [3, 4]]
%result = "stablehlo.maximum"(%lhs, %rhs) : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
// %result: [[5, 6], [7, 8]]
minimum
Semantik
Melakukan operasi min element-wise pada tensor lhs
dan rhs
serta menghasilkan
tensor result
. Bergantung pada jenis elemen, lakukan tindakan berikut:
- Untuk boolean: AND logika.
- Untuk bilangan bulat: minimum bilangan bulat.
- Untuk float:
minimum
dari IEEE-754. - Untuk bilangan kompleks: minimum leksikografis untuk pasangan
(real, imaginary)
. Menerapkan pengurutan pada bilangan kompleks melibatkan semantik yang mengejutkan, jadi pada masa mendatang, kami berencana untuk menghapus dukungan untuk bilangan kompleks untuk operasi ini (#560). - Untuk jenis terkuantisasi:
dequantize_op_quantize(minimum, lhs, rhs, type(result))
.
Input
Label | Nama | Jenis | Batasan |
---|---|---|---|
(I1) | lhs |
tensor atau tensor terkuantisasi per-tensor | (C1) |
(I2) | rhs |
tensor atau tensor terkuantisasi per tensor | (C1) |
Output
Nama | Jenis | Batasan |
---|---|---|
result |
tensor atau tensor terkuantisasi per tensor | (C1) |
Batasan
- (C1)
baseline_type(lhs) = baseline_type(rhs) = baseline_type(result)
.
Contoh
// %lhs: [[1, 2], [7, 8]]
// %rhs: [[5, 6], [3, 4]]
%result = "stablehlo.minimum"(%lhs, %rhs) : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
// %result: [[1, 2], [3, 4]]
memperbanyak
Semantik
Melakukan produk element-wise dari dua tensor lhs
dan rhs
dan menghasilkan
tensor result
. Bergantung pada jenis elemen, lakukan tindakan berikut:
- Untuk boolean: AND logika.
- Untuk bilangan bulat: perkalian bilangan bulat.
- Untuk float:
multiplication
dari IEEE-754. - Untuk bilangan kompleks: perkalian kompleks.
- Untuk jenis terkuantisasi:
dequantize_op_quantize(multiply, lhs, rhs, type(result))
.
Input
Label | Nama | Jenis | Batasan |
---|---|---|---|
(I1) | lhs |
tensor atau tensor terkuantisasi per-tensor | (C1) |
(I2) | rhs |
tensor atau tensor terkuantisasi per tensor | (C1) |
Output
Nama | Jenis | Batasan |
---|---|---|
result |
tensor atau tensor terkuantisasi per tensor | (C1) |
Batasan
- (C1)
baseline_type(operand) = baseline_type(result)
.
Contoh
// %lhs: [[1, 2], [3, 4]]
// %rhs: [[5, 6], [7, 8]]
%result = "stablehlo.multiply"(%lhs, %rhs) : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
// %result: [[5, 12], [21, 32]]
negasi
Semantik
Melakukan negasi element-wise dari tensor operand
dan menghasilkan tensor
result
. Bergantung pada jenis elemen, lakukan tindakan berikut:
- Untuk bilangan bulat bertanda: negasi bilangan bulat.
- Untuk bilangan bulat tanpa tanda: bitcast ke bilangan bulat bertanda, negasi bilangan bulat, bitcast kembali ke bilangan bulat tanpa tanda.
- Untuk float:
negate
dari IEEE-754. - Untuk bilangan kompleks: negasi kompleks.
- Untuk jenis kuantisasi:
dequantize_op_quantize(negate, operand, type(result))
.
Input
Label | Nama | Jenis | Batasan |
---|---|---|---|
(I1) | operand |
tensor bilangan bulat, floating point, atau jenis kompleks atau tensor terkuantisasi per tensor | (C1) |
Output
Nama | Jenis | Batasan |
---|---|---|
result |
tensor bilangan bulat, floating point, atau jenis kompleks atau tensor terkuantisasi per tensor | (C1) |
Batasan
- (C1)
baseline_type(operand) = baseline_type(result)
.
Contoh
// Negation operation with integer Tensors
// %operand: [0, -2]
%result = "stablehlo.negate"(%operand) : (tensor<2xi32>) -> tensor<2xi32>
// %result: [0, 2]
// Negation operation with with complex tensors
// %operand: (2.5, 0.0)
%result = "stablehlo.negate"(%operand) : (tensor<1xcomplex<f32>>) -> tensor<1xcomplex<f32>>
// %result: [-2.5, -0.0]
tidak
Semantik
Melakukan NOT element-wise dari tensor operand
dan menghasilkan tensor result
.
Bergantung pada jenis elemen, lakukan tindakan berikut:
- Untuk boolean: logical NOT.
- Untuk bilangan bulat: bitwise NOT.
Argumen
Nama | Jenis | Batasan |
---|---|---|
operand |
tensor dari jenis boolean atau bilangan bulat | (C1) |
Output
Nama | Jenis | Batasan |
---|---|---|
result |
tensor dari jenis boolean atau bilangan bulat | (C1) |
Batasan
- (C1)
type(operand) = type(result)
.
Contoh
// Bitwise operation with with integer tensors
// %operand: [[1, 2], [3, 4]]
%result = "stablehlo.not"(%operand) : (tensor<2x2xi32>) -> tensor<2x2xi32>
// %result: [[-2, -3], [-4, -5]]
// Bitwise operation with with boolean tensors
// %operand: [true, false]
%result = "stablehlo.not"(%operand) : (tensor<2xi1>) -> tensor<2xi1>
// %result: [false, true]
optimization_barrier
Semantik
Memastikan bahwa operasi yang menghasilkan operand
dieksekusi sebelum
operasi apa pun yang bergantung pada result
dan mencegah transformasi compiler
memindahkan operasi melintasi penghalang. Selain itu, operasinya adalah
identitas, yaitu result = operand
.
Argumen
Nama | Jenis | Batasan |
---|---|---|
operand |
jumlah variabel tensor, tensor atau token terkuantisasi per tensor | (C1) |
Output
Nama | Jenis | Batasan |
---|---|---|
result |
jumlah variabel tensor, tensor atau token terkuantisasi per tensor | (C1) |
Batasan
- (C1)
type(operand...) = type(result...)
.
Contoh
// %operand0: 0.0
// %operand1: 1.0
%result0, %result1 = "stablehlo.optimization_barrier"(%operand0, %operand1) : (tensor<f32>, tensor<f32>) -> (tensor<f32>, tensor<f32>)
// %result0: 0.0
// %result1: 1.0
atau
Semantik
Melakukan OR element-wise dari dua tensor lhs
dan rhs
dan menghasilkan tensor
result
. Bergantung pada jenis elemen, melakukan hal berikut:
- Untuk boolean: OR logika.
- Untuk bilangan bulat: bitwise OR.
Input
Label | Nama | Jenis | Batasan |
---|---|---|---|
(I1) | lhs |
tensor dari jenis bilangan bulat atau boolean | (C1) |
(I2) | rhs |
tensor berjenis integer atau boolean | (C1) |
Output
Nama | Jenis | Batasan |
---|---|---|
result |
tensor dari jenis bilangan bulat atau boolean | (C1) |
Batasan
- (C1)
type(lhs) = type(rhs) = type(result)
.
Contoh
// Bitwise operation with with integer tensors
// %lhs: [[1, 2], [3, 4]]
// %rhs: [[5, 6], [7, 8]]
%result = "stablehlo.or"(%lhs, %rhs) : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
// %result: [[5, 6], [7, 12]]
// Logical operation with with boolean tensors
// %lhs: [[false, false], [true, true]]
// %rhs: [[false, true], [false, true]]
%result = "stablehlo.or"(%lhs, %rhs) : (tensor<2x2xi1>, tensor<2x2xi1>) -> tensor<2x2xi1>
// %result: [[false, true], [true, true]]
outfeed
Semantik
Menulis inputs
ke feed keluar dan menghasilkan token result
.
Semantik outfeed_config
ditentukan oleh implementasi.
Input
Label | Nama | Jenis |
---|---|---|
(I1) | inputs |
jumlah variabel tensor atau tensor terkuantisasi |
(I2) | token |
token |
(I3) | outfeed_config |
konstanta dari jenis string |
Output
Nama | Jenis |
---|---|
result |
token |
Contoh
%result = "stablehlo.outfeed"(%input0, %token) {
outfeed_config = ""
} : (tensor<2x2x2xi64>, !stablehlo.token) -> !stablehlo.token
bantalan
Semantik
Memperluas operand
dengan padding di sekitar tensor serta di antara elemen
tensor dengan padding_value
yang diberikan.
edge_padding_low
dan edge_padding_high
menentukan jumlah padding yang ditambahkan
di bagian bawah (di samping indeks 0) dan bagian atas (di samping indeks tertinggi)
dari setiap dimensi. Jumlah padding dapat negatif, dengan
nilai absolut padding negatif menunjukkan jumlah elemen yang akan dihapus
dari dimensi yang ditentukan.
interior_padding
menentukan jumlah padding yang ditambahkan di antara dua elemen
di setiap dimensi yang mungkin tidak negatif. Padding interior terjadi sebelum padding tepi sehingga padding tepi negatif akan menghapus elemen dari
operand dengan padding bagian dalam.
Secara lebih formal, result[result_index]
didefinisikan sebagai:
operand[operand_index]
jikaresult_index = edge_padding_low + operand_index * (interior_padding + 1)
.padding_value
sebaliknya.
Input
Label | Nama | Jenis | Batasan |
---|---|---|---|
(I1) | operand |
tensor atau tensor terkuantisasi per tensor | (C1), (C2), (C4) |
(I2) | padding_value |
Tensor 0 dimensi atau tensor terkuantisasi per tensor | (C1) |
(I3) | edge_padding_low |
Konstanta tensor 1 dimensi jenis si64 |
(C1), (C4) |
(I4) | edge_padding_high |
Konstanta tensor 1 dimensi dari jenis si64 |
(C1), (C4) |
(I5) | interior_padding |
Konstanta tensor 1 dimensi dari jenis si64 |
(C2-C4) |
Output
Nama | Jenis | Batasan |
---|---|---|
result |
tensor atau tensor terkuantisasi per tensor | (C3-C6) |
Batasan
- (C1)
element_type(operand) = element_type(padding_value) = element_type(result)
. - (C2)
size(edge_padding_low) = size(edge_padding_high) = size(interior_padding) = rank(operand)
. - (C3)
0 <= interior_padding
. - (C4)
shape(result) = shape(operand) + edge_padding_low + max(shape(operand) - 1, 0) * interior_padding + edge_padding_high
.
Contoh
// %operand: [
// [1, 2, 3],
// [4, 5, 6]
// ]
// %padding_value: 0
%result = "stablehlo.pad"(%operand, %padding_value) {
edge_padding_low = array<i64: 0, 1>,
edge_padding_high = array<i64: 2, 1>,
interior_padding = array<i64: 1, 2>
} : (tensor<2x3xi32>, tensor<i32>) -> tensor<5x9xi32>
// %result: [
// [0, 1, 0, 0, 2, 0, 0, 3, 0],
// [0, 0, 0, 0, 0, 0, 0, 0, 0],
// [0, 4, 0, 0, 5, 0, 0, 6, 0],
// [0, 0, 0, 0, 0, 0, 0, 0, 0],
// [0, 0, 0, 0, 0, 0, 0, 0, 0]
// ]
partition_id
Semantik
Menghasilkan partition_id
dari proses saat ini.
Output
Nama | Jenis |
---|---|
result |
Tensor 0-dimensi jenis ui32 |
Contoh
%result = "stablehlo.partition_id"() : () -> tensor<ui32>
popcnt
Semantik
Melakukan penghitungan elemen per elemen dari jumlah bit yang ditetapkan dalam tensor operand
dan menghasilkan tensor result
.
Input
Label | Nama | Jenis | Batasan |
---|---|---|---|
(I1) | operand |
tensor dari jenis bilangan bulat | (C1) |
Output
Nama | Jenis | Batasan |
---|---|---|
result |
tensor dari jenis bilangan bulat | (C1) |
Batasan
- (C1)
type(operand) = type(result)
.
Contoh
// %operand: [0, 1, 2, 127]
%result = "stablehlo.popcnt"(%operand) : (tensor<4xi64>) -> tensor<4xi64>
// %result: [0, 1, 1, 7]
daya
Semantik
Melakukan eksponensial element-wise tensor lhs
dengan tensor rhs
dan
menghasilkan tensor result
. Bergantung pada jenis elemen, melakukan hal berikut:
- Untuk bilangan bulat: eksponensial bilangan bulat.
- Untuk float:
pow
dari IEEE-754. - Untuk bilangan kompleks: eksponensial kompleks.
- Untuk jenis terkuantisasi:
dequantize_op_quantize(power, lhs, rhs, type(result))
.
Input
Label | Nama | Jenis | Batasan |
---|---|---|---|
(I1) | lhs |
tensor bilangan bulat, floating point, atau jenis kompleks atau tensor terkuantisasi per tensor | (C1) |
(I2) | rhs |
tensor bilangan bulat, floating point, atau jenis kompleks atau tensor terkuantisasi per tensor | (C1) |
Output
Nama | Jenis | Batasan |
---|---|---|
result |
tensor bilangan bulat, floating point, atau jenis kompleks atau tensor terkuantisasi per tensor | (C1) |
Batasan
- (C1)
baseline_type(operand) = baseline_type(result)
.
Contoh
// %lhs: [-2.0, -0.0, -36.0, 5.0, 3.0, 10000.0]
// %rhs: [2.0, 2.0, 1.1, 2.0, -1.0, 10.0]
%result = "stablehlo.power"(%lhs, %rhs) : (tensor<6xf64>, tensor<6xf64>) -> tensor<6xf64>
// %result: [4.0, 0.0, -nan, 25.0, 0.333333343, inf]
real
Semantik
Mengekstrak bagian riil, per elemen, dari operand
dan menghasilkan tensor
result
. Secara lebih formal, untuk setiap elemen x
:
real(x) = is_complex(x) ? real_part(x) : x
.
Input
Label | Nama | Jenis | Batasan |
---|---|---|---|
(I1) | operand |
tensor dari jenis floating point atau kompleks | (C1), C2 |
Output
Nama | Jenis | Batasan |
---|---|---|
result |
tensor jenis floating point | (C1), (C2) |
Batasan
- (C1)
shape(result) = shape(operand)
. - (C2)
element_type(result)
ditentukan sebagai:complex_element_type(element_type(operand))
ifis_complex(operand)
.element_type(operand)
jika tidak.
Contoh
// %operand: [(1.0, 2.0), (3.0, 4.0)]
%result = "stablehlo.real"(%operand) : (tensor<2xcomplex<f32>>) -> tensor<2xf32>
// %result: [1.0, 3.0]
recv
Semantik
Menerima data dari saluran dengan channel_id
dan menghasilkan results
.
Jika is_host_transfer
adalah true
, operasi akan mentransfer data dari
host. Jika tidak, data akan ditransfer dari perangkat lain. Artinya,
ditentukan oleh implementasi. Flag ini menduplikasi informasi yang diberikan di
channel_type
, sehingga pada masa mendatang kami berencana untuk hanya menyimpan salah satunya
(#666).
results
terdiri dari nilai payload yang muncul terlebih dahulu dan token yang muncul terakhir. Pada masa mendatang, kami berencana untuk membagi payload dan token menjadi dua
output terpisah untuk meningkatkan kejelasan
(#670).
Input
Label | Nama | Jenis | Batasan |
---|---|---|---|
(I1) | token |
token |
(C4) |
(I2) | channel_id |
konstanta dari jenis si64 |
|
(I3) | channel_type |
enum DEVICE_TO_DEVICE dan HOST_TO_DEVICE |
(C1) |
(I4) | is_host_transfer |
konstanta jenis i1 |
(C1) |
Output
Nama | Jenis | Batasan |
---|---|---|
results |
jumlah tensor, tensor atau token terkuantisasi | (C2-C4) |
Batasan
- (C1)
channel_type
didefinisikan sebagai:HOST_TO_DEVICE
jikais_host_transfer = true
,DEVICE_TO_DEVICE
jika tidak.
- (C2)
0 < size(results)
. - (C3)
is_empty(result[:-1])
atauis_tensor(type(results[:-1]))
. - (C4)
is_token(type(results[-1]))
.
Contoh
%results0, %results1 = "stablehlo.recv"(%token) {
channel_handle = #stablehlo.channel_handle<handle = 1, type = 3>,
is_host_transfer = true
} : (!stablehlo.token) -> (tensor<2x2xi64>, !stablehlo.token)
reduce
Semantik
Menerapkan fungsi pengurangan body
ke inputs
dan init_values
di sepanjang
dimensions
dan menghasilkan tensor results
.
Urutan pengurangan ditentukan oleh implementasi, yang berarti bahwa body
dan
init_values
harus membentuk monoid untuk menjamin bahwa operasi menghasilkan
hasil yang sama untuk semua input di semua implementasi. Namun, kondisi ini
tidak berlaku untuk banyak pengurangan populer. Misalnya, penambahan floating point untuk
body
dan nol untuk init_values
sebenarnya tidak membentuk monoid karena
penambahan floating point tidak asosiatif.
Secara lebih formal, results...[j0, ..., jR-1] = reduce(input_slices_converted)
dengan:
input_slices = inputs...[j0, ..., :, ..., jR-1]
, dengan:
disisipkan didimensions
.input_slices_converted = to_destination_type(input_slices..., type(func_inputs(body)[:len(func_inputs(body))//2])...)
.init_values_converted = to_destination_type(init_values..., type(func_inputs(body)[len(func_inputs(body))//2:])...)
.reduce(input_slices_converted) = exec(schedule)
untuk beberapa hierarki binerschedule
dengan:exec(node) = body(exec(node.left), exec(node.right))
.exec(leaf) = leaf.value
.
schedule
adalah hierarki biner penuh yang ditentukan oleh implementasi yang traversal berurutannya terdiri dari:- Nilai
input_slices_converted...[index]
, untuk semuaindex
diindex_space(input_slices_converted)
dalam urutan leksikografis menaikindex
. - Di sela-sela jumlah
init_values_converted
yang ditentukan implementasi pada posisi yang ditentukan implementasi.
- Nilai
Input
Label | Nama | Jenis | Batasan |
---|---|---|---|
(I1) | inputs |
jumlah variabel tensor atau tensor terkuantisasi per tensor | (C1-C4), (C6), (C7) |
(I2) | init_values |
jumlah variadik tensor 0 dimensi atau tensor terkuantisasi per tensor | (C2), (C3) |
(I3) | dimensions |
Konstanta tensor 1 dimensi dari jenis si64 |
(C4), (C5), (C7) |
(I4) | body |
fungsi | (C6) |
Output
Nama | Jenis | Batasan |
---|---|---|
results |
jumlah variabel tensor atau tensor terkuantisasi per tensor | (C3), (C7), (C8) |
Batasan
- (C1)
same(shape(inputs...))
. - (C2)
element_type(inputs...) = element_type(init_values...)
. - (C3)
0 < size(inputs) = size(init_values) = size(results) = N
. - (C4)
0 <= dimensions < rank(inputs[0])
. - (C5)
is_unique(dimensions)
. - (C6)
body
memiliki jenis(tensor<E0>, ..., tensor<EN-1>, tensor<E0>, ...,
tensor<EN-1>) -> (tensor<E0>, ..., tensor<EN-1>)
denganis_promotable(element_type(inputs[i]), Ei)
. - (C7)
shape(results...) = shape(inputs...)
, kecuali ukuran dimensiinputs...
yang sesuai dengandimensions
tidak disertakan. - (C8)
element_type(results[i]) = Ei
untuk semuai
di[0,N)
.
Contoh
// %input = [[0, 1, 2, 3, 4, 5]]
// %init_value = 0
%result = "stablehlo.reduce"(%input, %init_value) ({
^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
%0 = "stablehlo.add"(%arg0, %arg1) : (tensor<i64>, tensor<i64>) -> tensor<i64>
"stablehlo.return"(%0) : (tensor<i64>) -> ()
}) {
dimensions = array<i64: 1>
} : (tensor<1x6xi64>, tensor<i64>) -> tensor<1xi64>
// %result = [15]
reduce_precision
Semantik
Melakukan konversi operand
berdasarkan elemen ke jenis floating point lain
yang menggunakan exponent_bits
dan mantissa_bits
, serta kembali ke jenis
floating point asli dan menghasilkan tensor output
.
Secara lebih formal:
- Bit mantisa dari nilai asli diperbarui untuk membulatkan nilai
asli ke nilai terdekat yang dapat direpresentasikan dengan
mantissa_bits
menggunakan semantikroundToIntegralTiesToEven
. - Kemudian, jika
mantissa_bits
lebih kecil dari jumlah bit mantissa dari nilai asli, bit mantissa akan terpotong menjadimantissa_bits
. - Kemudian, jika bit eksponen dari hasil perantara tidak sesuai dengan
rentang yang disediakan oleh
exponent_bits
, hasil perantara akan meluap ke tak terbatas menggunakan tanda asli atau underflow ke nol menggunakan tanda asli. - Untuk jenis kuantisasi, lakukan
dequantize_op_quantize( lambda operand: reduce_precision(operand, exponent_bits, mantissa_bits), operand, type(result))
.
Input
Label | Nama | Jenis | Batasan |
---|---|---|---|
(I1) | operand |
tensor jenis floating point atau tensor terkuantisasi per tensor | (C1) |
(I2) | exponent_bits |
konstanta dari jenis si32 |
(C2) |
(I3) | mantissa_bits |
konstanta jenis si32 |
(C3) |
Output
Nama | Jenis | Batasan |
---|---|---|
output |
tensor jenis floating point atau tensor terkuantisasi per tensor | (C1) |
Batasan
- (C1)
baseline_type(operand) = baseline_type(output)
. - (C2)
1 <= exponent_bits
. - (C3)
0 <= mantissa_bits
.
Contoh
// Logical values: +Inf, NaN, +Denormal, 0.0, 65519.0, 65520.0
// %operand: [0x7FF0000000000000, 0x7FFFFFFFFFFFFFFF, 0x0000000000000001, 0.0, 65519.0, 65520.0]
%output = "stablehlo.reduce_precision"(%operand) {
exponent_bits = 5 : i32,
mantissa_bits = 10 : i32
} : (tensor<6xf64>) -> tensor<6xf64>
// Logical values: +Inf, NaN, 0.0, 0.0, 65504.0, +Inf
// %output: [0x7FF0000000000000, 0x7FFFFFFFFFFFFFFF, 0.0, 0.0, 65504.0, 0x7FF0000000000000]
reduce_scatter
Semantik
Dalam setiap grup proses di petak proses StableHLO, melakukan pengurangan,
menggunakan computations
, pada nilai tensor operand
dari setiap proses,
membagi hasil pengurangan di sepanjang scatter_dimension
menjadi beberapa bagian, dan menyebarkan
bagian yang terpisah di antara proses untuk menghasilkan result
.
Operasi ini membagi petak proses StableHLO menjadi process_groups
yang
ditentukan sebagai berikut:
cross_replica(replica_groups)
ifchannel_id <= 0 and use_global_device_ids = false
.cross_replica_and_partition(replica_groups)
jikachannel_id > 0 and use_global_device_ids = false
.flattened_ids(replica_groups)
ifchannel_id > 0 and use_global_device_ids = true
.
Setelah itu, dalam setiap process_group
:
reduced_value = all_reduce(operand, replica_groups, channel_id, use_global_device_ids, computation)
.parts@sender = split(reduced_value@sender, dim(process_groups, 1), scatter_dimension)
.result@receiver = parts@sender[receiver_index]
untuk semuasender
diprocess_group
, denganreceiver_index = process_group.index(receiver)
.
Input
Label | Nama | Jenis | Batasan |
---|---|---|---|
(I1) | operand |
tensor atau tensor terkuantisasi per tensor | (C1), (C2), (C7), (C8) |
(I2) | scatter_dimension |
konstanta dari jenis si64 |
(C1), (C2), (C8) |
(I3) | replica_groups |
Konstanta tensor 2 dimensi dari jenis si64 |
(C3-C5) |
(I4) | channel_id |
konstanta jenis si64 |
(C6) |
(I5) | use_global_device_ids |
konstanta dari jenis i1 |
(C6) |
(I6) | computation |
fungsi | (C7) |
Output
Nama | Jenis | Batasan |
---|---|---|
result |
tensor atau tensor terkuantisasi per tensor | (C8-C9) |
Batasan
- (C1)
dim(operand, scatter_dimension) % dim(process_groups, 1) = 0
. - (C2)
0 <= scatter_dimension < rank(operand)
. - (C3)
is_unique(replica_groups)
. - (C4)
size(replica_groups)
ditentukan sebagai:num_replicas
jikacross_replica
digunakan.num_replicas
jikacross_replica_and_partition
digunakan.num_processes
jikaflattened_ids
digunakan.
- (C5)
0 <= replica_groups < size(replica_groups)
. - (C6) Jika
use_global_device_ids = true
, makachannel_id > 0
. - (C7)
computation
memiliki jenis(tensor<E>, tensor<E>) -> (tensor<E>)
denganis_promotable(element_type(operand), E)
. - (C8)
shape(result) = shape(operand)
kecuali:dim(result, scatter_dimension) = dim(operand, scatter_dimension) / dim(process_groups, 1)
.
- (C9)
element_type(result) = E
.
Contoh
// num_replicas: 2
// num_partitions: 1
// %operand@(0, 0): [[1, 2, 3, 4],
// [5, 6, 7, 8]]
// %operand@(1, 0): [[9, 10, 11, 12],
// [13, 14, 15, 16]]
%result = "stablehlo.reduce_scatter"(%operand) ({
^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
%0 = "stablehlo.add"(%arg0, %arg1) : (tensor<i64>, tensor<i64>) -> tensor<i64>
"stablehlo.return"(%0) : (tensor<i64>) -> ()
}) {
scatter_dimension = 1 : i64,
replica_groups = dense<[[0, 1]]> : tensor<1x2xi64>,
channel_handle = #stablehlo.channel_handle<handle = 0, type = 0>
} : (tensor<2x4xi64>) -> tensor<2x2xi64>
//
// %result@(0, 0): [[10, 12],
// [18, 20]]
// %result@(1, 0): [[14, 16],
// [22, 24]]
reduce_window
Semantik
Menerapkan fungsi pengurangan body
ke jendela inputs
dan init_values
dan menghasilkan results
.
Diagram berikut menunjukkan cara elemen pada results...
dihitung dari
inputs...
menggunakan contoh konkret.
Secara lebih formal,
results...[result_index] = reduce(windows, init_values, axes(inputs...), body)
(lihat reduce) dengan:
padded_inputs = pad(inputs..., init_values..., padding[:, 0], padding[:, 1], base_dilations - 1)
.window_start = result_index * window_strides
.window_end = window_start + (window_dimensions - 1) * window_dilations + 1
.windows = slice(padded_inputs..., window_start, window_end, window_dilations)
.
Input
Label | Nama | Jenis | Batasan |
---|---|---|---|
(I1) | inputs |
jumlah variabel tensor atau tensor terkuantisasi per tensor | (C1-C4), (C6), (C8), (C10), (C12), (C13), (C15) |
(I2) | init_values |
jumlah variadik tensor 0 dimensi atau tensor terkuantisasi per tensor | (C1), (C13) |
(I3) | window_dimensions |
Konstanta tensor 1 dimensi dari jenis si64 |
(C4), (C5), (C15) |
(I4) | window_strides |
Konstanta tensor 1 dimensi dari jenis si64 |
(C6), (C7), (C15) |
(I5) | base_dilations |
Konstanta tensor 1 dimensi dari jenis si64 |
(C8), (C9), (C15) |
(I6) | window_dilations |
Konstanta tensor 1 dimensi dari jenis si64 |
(C10), (C11), (C15) |
(I7) | padding |
Konstanta tensor 2 dimensi dari jenis si64 |
(C12), (C15) |
(I8) | body |
fungsi | (C13) |
Output
Nama | Jenis | Batasan |
---|---|---|
results |
jumlah variabel tensor atau tensor terkuantisasi per tensor | (C1), (C14-C16) |
Batasan
- (C1)
0 < size(inputs) = size(init_values) = size(results) = N
. - (C2)
same(shape(inputs...))
. - (C3)
element_type(inputs...) = element_type(init_values...)
. - (C4)
size(window_dimensions) = rank(inputs[0])
. - (C5)
0 < window_dimensions
. - (C6)
size(window_strides) = rank(inputs[0])
. - (C7)
0 < window_strides
. - (C8)
size(base_dilations) = rank(inputs[0])
. - (C9)
0 < base_dilations
. - (C10)
size(window_dilations) = rank(inputs[0])
. - (C11)
0 < window_dilations
. - (C12)
shape(padding) = [rank(inputs[0]), 2]
. - (C13)
body
memiliki jenis(tensor<E0>, ..., tensor<EN-1>, tensor<E0>, ...,
tensor<EN-1>) -> (tensor<E0>, ..., tensor<EN-1>)
denganis_promotable(element_type(inputs[i]), Ei)
. - (C14)
same(shape(results...))
. - (C15)
shape(results[0]) = num_windows
dengan:dilated_input_shape = shape(inputs[0]) = 0 ? 0 : (shape(inputs[0]) - 1) * base_dilations + 1
.padded_input_shape = padding[:, 0] + dilated_input_shape + padding[:, 1]
.dilated_window_shape = (window_dimensions - 1) * window_dilations + 1
.is_empty_window = padded_input_shape = 0 || dilated_window_shape > padded_input_shape
.num_windows = is_empty_window ? 0 : floor((padded_input_shape - dilated_window_shape) / window_strides) + 1
.
- (C16)
element_type(results[i]) = Ei
untuk semuai
di[0,N)
.
Contoh
// %input = [[1, 2], [3, 4], [5, 6]]
// %init_value = 0
%result = "stablehlo.reduce_window"(%input, %init_value) ({
^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
%0 = "stablehlo.add"(%arg0, %arg1) : (tensor<i64>, tensor<i64>) -> tensor<i64>
"stablehlo.return"(%0) : (tensor<i64>) -> ()
}) {
window_dimensions = array<i64: 2, 1>,
window_strides = array<i64: 4, 1>,
base_dilations = array<i64: 2, 1>,
window_dilations = array<i64: 3, 1>,
padding = dense<[[2, 1], [0, 0]]> : tensor<2x2xi64>
} : (tensor<3x2xi64>, tensor<i64>) -> tensor<2x2xi64>
// %result = [[0, 0], [3, 4]]
sisanya
Semantik
Melakukan sisa elemen dari tensor dividen lhs
dan pembagi rhs
dan
menghasilkan tensor result
.
Secara lebih formal, tanda hasil diambil dari dividen, dan
nilai absolut hasilnya selalu kurang dari nilai absolut pembagi.
Sisa dihitung sebagai lhs - d * rhs
, dengan d
diberikan oleh:
- Untuk bilangan bulat:
stablehlo.divide(lhs, rhs)
. - Untuk float:
division(lhs, rhs)
dari IEEE-754 dengan atribut pembulatanroundTowardZero
. - Untuk bilangan kompleks: TBD (#997).
- Untuk jenis kuantisasi:
dequantize_op_quantize(remainder, lhs, rhs, type(result))
.
Untuk jenis elemen floating point, operasi ini berbeda dengan
operasi remainder
dari spesifikasi IEEE-754 dengan d
adalah nilai integral
yang paling dekat dengan nilai persis lhs/rhs
dengan ikatan ke genap.
Input
Label | Nama | Jenis | Batasan |
---|---|---|---|
(I1) | lhs |
tensor bilangan bulat, floating point, atau kompleks, atau tensor terkuantisasi per tensor | (C1) |
(I2) | rhs |
tensor bilangan bulat, floating point, atau kompleks, atau tensor terkuantisasi per tensor | (C1) |
Output
Nama | Jenis | Batasan |
---|---|---|
result |
tensor integer, floating-point, atau tipe kompleks atau tensor terkuantisasi per-tensor | (C1) |
Batasan
- (C1)
baseline_type(operand) = baseline_type(result)
.
Contoh
// %lhs: [17, -17, 17, -17]
// %rhs: [3, 3, -3, -3]
%result = "stablehlo.remainder"(%lhs, %rhs) : (tensor<4xi64>, tensor<4xi64>) -> tensor<4xi64>
// %result: [2, -2, 2, -2]
replica_id
Semantik
Menghasilkan replica_id
dari proses saat ini.
Output
Nama | Jenis |
---|---|
result |
Tensor 0-dimensi jenis ui32 |
Contoh
%result = "stablehlo.replica_id"() : () -> tensor<ui32>
membentuk ulang
Semantik
Melakukan pembentukan ulang tensor operand
menjadi tensor result
. Secara konseptual, hal ini
sama dengan mempertahankan representasi kanonis yang sama, tetapi berpotensi mengubah
bentuknya, misalnya dari tensor<2x3xf32>
menjadi tensor<3x2xf32>
atau tensor<6xf32>
.
Secara lebih formal, result[result_index] = operand[operand_index]
dengan
result_index
dan operand_index
memiliki posisi yang sama dalam pengurutan
leksikal index_space(result)
dan index_space(operand)
.
Input
Label | Nama | Jenis | Batasan |
---|---|---|---|
(I1) | operand |
tensor atau tensor terkuantisasi | (C1-C3) |
Output
Nama | Jenis | Batasan |
---|---|---|
result |
tensor atau quantized tensor | (C1-C3) |
Batasan
- (C1)
element_type(result)
diberikan oleh:element_type(operand)
, jika!is_per_axis_quantized(operand)
.element_type(operand)
kecuali bahwaquantization_dimension(operand)
danquantization_dimension(result)
mungkin berbeda, jika tidak.
- (C2)
size(operand) = size(result)
. - (C3) Jika
is_per_axis_quantized(operand)
:reduce(dims(operand, [0, 1, ..., quantization_dimension(operand) - 1]), init_values=1, dimensions=[0], body=lambda x, y: x * y) = reduce(dims(result, [0, 1, ..., quantization_dimension(result) - 1]), init_values=1, dimensions=[0], body=lambda x, y: x * y)
.dim(operand, quantization_dimension(operand)) = dim(result, quantization_dimension(result))
.reduce(dims(operand, [quantization_dimension(operand) + 1, ..., rank(operand) - 1]), init_values=1, dimensions=[0], body=lambda x, y: x * y) = reduce(dims(result, [quantization_dimension(result) + 1, ..., rank(result) - 1]), init_values=1, dimensions=[0], body=lambda x, y: x * y)
.
Contoh
// %operand: [[1, 2, 3], [4, 5, 6]]
%result = "stablehlo.reshape"(%operand) : (tensor<2x3xi32>) -> tensor<3x2xi32>
// %result: [[1, 2], [3, 4], [5, 6]]
balik
Semantik
Membalik urutan elemen dalam operand
di sepanjang dimensions
yang ditentukan
dan menghasilkan tensor result
. Secara lebih formal,
result[result_index] = operand[operand_index]
dengan:
operand_index[d] = dim(result, d) - result_index[d] - 1
jikad
didimensions
.operand_index[d] = result_index[d]
sebaliknya.
Input
Label | Nama | Jenis | Batasan |
---|---|---|---|
(I1) | operand |
tensor atau tensor terkuantisasi per tensor | (C1), (C3) |
(I2) | dimensions |
Konstanta tensor 1 dimensi dari jenis si64 |
(C2), (C3) |
Output
Nama | Jenis | Batasan |
---|---|---|
result |
tensor atau tensor terkuantisasi per tensor | (C1), (C3) |
Batasan
- (C1)
type(operand) = type(result)
. - (C2)
is_unique(dimensions)
. - (C3)
0 <= dimensions < rank(result)
.
Contoh
// %operand = [[1, 2], [3, 4], [5, 6]]
%result = "stablehlo.reverse"(%operand) {
dimensions = array<i64: 1>
} : (tensor<3x2xi32>) -> tensor<3x2xi32>
// %result: [[2, 1], [4, 3], [6, 5]]
rng
Semantik
Menghasilkan angka acak menggunakan algoritma rng_distribution
dan menghasilkan
tensor result
dari bentuk shape
tertentu.
Jika rng_distribution = UNIFORM
, angka acak akan dihasilkan
mengikuti distribusi seragam selama interval [a, b)
. Jika a >= b
,
perilaku tidak ditentukan.
Jika rng_distribution = NORMAL
, angka acak akan dihasilkan
mengikuti distribusi normal dengan mean = a
dan simpangan baku = b
.
Jika b < 0
, perilaku tidak ditentukan.
Cara persis pembuatan angka acak ditentukan oleh implementasi. Misalnya, status tersebut mungkin bersifat deterministik atau tidak, dan mungkin menggunakan status tersembunyi atau tidak.
Dalam diskusi dengan banyak pemangku kepentingan, operasi ini tampaknya tidak digunakan lagi secara efektif, jadi di masa mendatang kami berencana untuk menghapusnya (#597).
Input
Label | Nama | Jenis | Batasan |
---|---|---|---|
(I1) | a |
Tensor 0-dimensi dari jenis integer, boolean, atau floating point | (C1), (C2) |
(I2) | b |
Tensor 0 dimensi dari jenis bilangan bulat, boolean, atau floating point | (C1), (C2) |
(I3) | shape |
Konstanta tensor 1 dimensi dari jenis si64 |
(C3) |
(I4) | rng_distribution |
enum UNIFORM dan NORMAL |
(C2) |
Output
Nama | Jenis | Batasan |
---|---|---|
result |
tensor bilangan bulat, boolean, atau jenis floating point | (C1-C3) |
Batasan
- (C1)
element_type(a) = element_type(b) = element_type(result)
. - (C2) Jika
rng_distribution = NORMAL
, makais_float(a)
. - (C3)
shape(result) = shape
.
Contoh
// %a = 0
// %b = 2
// %shape = [3, 3]
%result = "stablehlo.rng"(%a, %b, %shape) {
rng_distribution = #stablehlo<rng_distribution UNIFORM>
} : (tensor<i32>, tensor<i32>, tensor<2xi64>) -> tensor<3x3xi32>
// %result: [
// [1, 0, 1],
// [1, 1, 1],
// [0, 0, 0]
// ]
rng_bit_generator
Semantik
Menampilkan output
yang diisi dengan bit acak seragam dan status output yang diperbarui
output_state
menggunakan algoritma generator angka pseudorandom rng_algorithm
dengan status awal initial_state
. Output dijamin merupakan
fungsi deterministik dari initial_state
, tetapi tidak dijamin bersifat
deterministik di antara implementasi.
rng_algorithm
adalah salah satu dari berikut ini:
DEFAULT
: Algoritma yang ditentukan implementasi.THREE_FRY
: Varian algoritma Threefry yang ditentukan implementasi.*PHILOX
: Varian algoritma Philox yang ditentukan implementasi.*
* Lihat: Salmon et al. SC 2011. Angka acak paralel: semudah 1, 2, 3.
Input
Label | Nama | Jenis | Batasan |
---|---|---|---|
(I1) | rng_algorithm |
enum DEFAULT , THREE_FRY , dan PHILOX |
(C2) |
(I2) | initial_state |
Tensor 1 dimensi dari jenis ui64 |
(C1), C2 |
Output
Nama | Jenis | Batasan |
---|---|---|
output_state |
Tensor 1 dimensi dari jenis ui64 |
(C1) |
output |
tensor dari jenis bilangan bulat atau floating point |
Batasan
- (C1)
type(initial_state) = type(output_state)
. - (C2)
size(initial_state)
ditentukan sebagai:- ditentukan oleh implementasi jika
rng_algorithm = DEFAULT
. 2
ifrng_algorithm = THREE_FRY
.2
atau3
jikarng_algorithm = PHILOX
.
- ditentukan oleh implementasi jika
Contoh
// %initial_state: [1, 2]
%output_state, %output = "stablehlo.rng_bit_generator"(%initial_state) {
rng_algorithm = #stablehlo<rng_algorithm THREE_FRY>
} : (tensor<2xui64>) -> (tensor<2xui64>, tensor<2x2xui64>)
// %output_state: [1, 6]
// %output: [
// [9236835810183407956, 16087790271692313299],
// [18212823393184779219, 2658481902456610144]
// ]
round_nearest_afz
Semantik
Melakukan pembulatan element-wise ke bilangan bulat terdekat, memisahkan dari nol, pada tensor operand
dan menghasilkan tensor result
. Menerapkan
operasi roundToIntegralTiesToAway
dari spesifikasi IEEE-754. Untuk
jenis terkuantisasi, menjalankan
dequantize_op_quantize(round_nearest_afz, operand, type(result))
.
Input
Label | Nama | Jenis | Batasan |
---|---|---|---|
(I1) | operand |
tensor jenis floating point atau tensor terkuantisasi per tensor | (C1) |
Output
Nama | Jenis | Batasan |
---|---|---|
result |
tensor jenis floating point atau tensor terkuantisasi per tensor | (C1) |
Batasan
- (C1)
baseline_type(operand) = baseline_type(result)
.
Contoh
// %operand = [-2.5, 0.4, 0.5, 0.6, 2.5]
%result = "stablehlo.round_nearest_afz"(%operand) : (tensor<5xf64>) -> tensor<5xf64>
// %result: [-3.0, 0.0, 1.0, 1.0, 3.0]
round_nearest_even
Semantik
Melakukan pembulatan elemen ke bilangan bulat terdekat, memecahkan ikatan
ke bilangan bulat genap, pada tensor operand
dan menghasilkan tensor
result
. Mengimplementasikan operasi roundToIntegralTiesToEven
dari spesifikasi
IEEE-754. Untuk jenis kuantisasi, lakukan
dequantize_op_quantize(round_nearest_even, operand, type(result))
.
Input
Label | Nama | Jenis | Batasan |
---|---|---|---|
(I1) | operand |
tensor jenis floating point atau tensor terkuantisasi per tensor | (C1) |
Output
Nama | Jenis | Batasan |
---|---|---|
result |
tensor jenis floating point atau tensor terkuantisasi per tensor | (C1) |
Batasan
- (C1)
baseline_type(operand) = baseline_type(result)
.
Contoh
// %operand = [-2.5, 0.4, 0.5, 0.6, 2.5]
%result = "stablehlo.round_nearest_even"(%operand) : (tensor<5xf64>) -> tensor<5xf64>
// %result: [-2.0, 0.0, 0.0, 1.0, 2.0]
rsqrt
Semantik
Melakukan operasi akar kuadrat reciprocal element-wise pada tensor operand
dan
menghasilkan tensor result
. Bergantung pada jenis elemen, lakukan tindakan berikut:
- Untuk float:
rSqrt
dari IEEE-754. - Untuk bilangan kompleks: akar kuadrat kebalikan kompleks.
- Untuk jenis kuantisasi:
dequantize_op_quantize(rsqrt, operand, type(result))
.
Input
Label | Nama | Jenis | Batasan |
---|---|---|---|
(I1) | operand |
tensor floating point atau jenis kompleks atau tensor terkuantisasi per-tensor | (C1) |
Output
Nama | Jenis | Batasan |
---|---|---|
result |
tensor dari jenis floating point atau kompleks atau tensor terkuantisasi per tensor | (C1) |
Batasan
- (C1)
baseline_type(operand) = baseline_type(result)
.
Contoh
// %operand: [[1.0, 4.0], [9.0, 25.0]]
%result = "stablehlo.rsqrt"(%operand) : (tensor<2x2xf32>) -> tensor<2x2xf32>
// %result: [[1.0, 0.5], [0.33333343, 0.2]]
menyebar
Semantik
Menghasilkan tensor results
yang sama dengan tensor inputs
, kecuali beberapa slice yang ditentukan oleh scatter_indices
diperbarui dengan nilai updates
menggunakan update_computation
.
Diagram berikut menunjukkan cara elemen di updates...
memetakan elemen di
results...
menggunakan contoh konkret. Diagram ini memilih beberapa contoh
indeks updates...
dan menjelaskan secara mendetail indeks results...
yang
sesuai.
Secara lebih formal, untuk semua update_index
di index_space(updates[0])
:
update_scatter_dims = [d for d in axes(updates[0]) and d not in update_window_dims]
.update_scatter_index = update_index[update_scatter_dims...]
.start_index
ditentukan sebagai:scatter_indices[si0, ..., :, ..., siN]
dengansi
adalah elemen individual diupdate_scatter_index
dan:
disisipkan pada indeksindex_vector_dim
, jikaindex_vector_dim
<rank(scatter_indices)
.[scatter_indices[update_scatter_index]]
jika tidak.
- Untuk
d_input
diaxes(inputs[0])
,full_start_index[d_input] = start_index[d_start]
jikad_input = scatter_dims_to_operand_dims[d_start]
.full_start_index[d_input] = 0
jika tidak.
- Untuk
d_input
diaxes(inputs[0])
,full_batching_index[d_input] = update_scatter_index[d_start - (d_start < index_vector_dim ? 0 : 1)]
jikad_input = input_batching_dims[i_batching]
dand_start = scatter_indices_batching_dims[i_batching]
.full_batching_index[d_input] = 0
jika tidak.
update_window_index = update_index[update_window_dims...]
.full_window_index = [wi0, ..., 0, ..., wiN]
denganwi
adalah elemen individual diupdate_window_index
, dan0
disisipkan pada indeks dariinserted_window_dims
daninput_batching_dims
.result_index = full_start_index + full_batching_index + full_window_index
.
Dengan demikian, results = exec(schedule, inputs)
, dengan:
schedule
adalah permutasiindex_space(updates[0])
yang ditentukan oleh implementasi.exec([update_index, ...], results) = exec([...], updated_results)
dengan:- Jika
result_index
berada dalam batas untukshape(results...)
updates_converted = to_destination_type( updates...[update_index], type(func_inputs(update_computation) [len(func_inputs(update_computation))//2:])... )
updated_values = update_computation(results...[result_index], updates_converted)
updated_results
adalah salinanresults
denganresults...[result_index]
ditetapkan keupdated_values...
.- Atau
updated_results = results
.
- Jika
exec([], results) = results
.
Jika indices_are_sorted
adalah true
, implementasi dapat mengasumsikan bahwa
scatter_indices
diurutkan sehubungan dengan scatter_dims_to_operand_dims
,
jika tidak, perilaku tidak ditentukan. Secara lebih formal, untuk semua i1 < i2
dari
indices(result)
, full_start_index(i1)
<= full_start_index(i2)
.
Jika unique_indices
adalah true
, implementasi dapat mengasumsikan bahwa semua
indeks result_index
yang tersebar bersifat unik. Jika unique_indices
adalah
true
, tetapi indeks yang disebar tidak unik, perilakunya
tidak ditentukan.
Input
Label | Nama | Jenis | Batasan |
---|---|---|---|
(I1) | inputs |
jumlah tensor atau tensor terkuantisasi per-tensor | (C1), (C2), (C4-C6), (C11), (C13), (C18), (C21), (C23-C24) |
(I2) | scatter_indices |
tensor dari jenis bilangan bulat | (C4), (C15), (C19), (C22) |
(I3) | updates |
jumlah variabel tensor atau tensor terkuantisasi per tensor | (C3-C6), (C8) |
(I4) | update_window_dims |
Konstanta tensor 1 dimensi dari jenis si64 |
(C2), (C4), (C7-C8) |
(I5) | inserted_window_dims |
Konstanta tensor 1 dimensi jenis si64 |
(C2), (C4), (C9-C11) |
(I6) | input_batching_dims |
Konstanta tensor 1 dimensi dari jenis si64 |
(C2), (C4), (C9), (C12-13), (C17-18), (C20) |
(I7) | scatter_indices_batching_dims |
Konstanta tensor 1 dimensi dari jenis si64 |
(C14-C18) |
(I8) | scatter_dims_to_operand_dims |
Konstanta tensor 1 dimensi dari jenis si64 |
(C19-C21) |
(I9) | index_vector_dim |
konstanta dari jenis si64 |
(C4), (C16), (C19), (C22) |
(I10) | indices_are_sorted |
konstanta dari jenis i1 |
|
(I11) | unique_indices |
konstanta dari jenis i1 |
|
(I12) | update_computation |
fungsi | (C23) |
Output
Nama | Jenis | Batasan |
---|---|---|
results |
jumlah variabel tensor atau tensor terkuantisasi per tensor | (C24-C25) |
Batasan
- (C1)
same(shape(inputs...))
. - (C2) `rank(inputs[0]) = size(update_window_dims) + size(inserted_window_dims)
- size(input_batching_dims)`.
- (C3)
same(shape(updates...))
. - (C4)
shape(updates[0]) = combine(update_scatter_dim_sizes, update_window_dim_sizes)
dengan:update_scatter_dim_sizes = shape(scatter_indices)
, tetapi ukuran dimensiscatter_indices
yang sesuai denganindex_vector_dim
tidak disertakan.update_window_dim_sizes <= shape(inputs[0])
, kecuali bahwa ukuran dimensi diinputs[0]
yang sesuai denganinserted_window_dims
daninput_batching_dims
tidak disertakan.combine
menempatkanupdate_scatter_dim_sizes
pada sumbu yang sesuai denganupdate_scatter_dims
danupdate_window_dim_sizes
pada sumbu yang sesuai denganupdate_window_dims
.
- (C5)
0 < size(inputs) = size(updates) = N
. - (C6)
element_type(updates...) = element_type(inputs...)
. - (C7)
is_unique(update_window_dims) and is_sorted(update_window_dims)
. - (C8)
0 <= update_window_dims < rank(updates[0])
. - (C9)
is_unique(concatenate(inserted_window_dims, input_batching_dims))
- (C10)
is_sorted(inserted_window_dims)
. - (C11)
0 <= inserted_window_dims < rank(inputs[0])
. - (C12)
is_sorted(input_batching_dims)
. - (C13)
0 <= input_batching_dims < rank(inputs[0]))
. - (C14)
is_unique(scatter_indices_batching_dims)
. - (C15)
0 <= scatter_indices_batching_dims < rank(scatter_indices)
. - (C16)
index_vector_dim not in scatter_indices_batching_dims
. - (C17)
size(input_batching_dims) == size(scatter_indices_batching_dims)
. - (C18)
dim(inputs[0], input_batching_dims...) = dim(scatter_indices, scatter_indices_batching_dims...)
. - (C19)
size(scatter_dims_to_operand_dims) = index_vector_dim < rank(scatter_indices) ? dim(scatter_indices, index_vector_dim) : 1
. - (C20)
is_unique(concatenate(scatter_dims_to_operand_dims, input_batching_dims))
. - (C21)
0 <= scatter_dims_to_operand_dims < rank(inputs[0])
. - (C22)
0 <= index_vector_dim <= rank(scatter_indices)
. - (C23)
update_computation
memiliki jenis(tensor<E0>, ..., tensor<EN-1>, tensor<E0>, ..., tensor<EN-1>) -> (tensor<E0>, ..., tensor<EN-1>)
, denganis_promotable(element_type(inputs[i]), Ei)
. - (C24)
shape(inputs...) = shape(results...)
. - (C25)
element_type(results[i]) = Ei
untuk semuai
di[0,N)
.
Contoh
// %input: [
// [
// [[1, 2], [3, 4], [5, 6], [7, 8]],
// [[9, 10],[11, 12], [13, 14], [15, 16]],
// [[17, 18], [19, 20], [21, 22], [23, 24]]
// ],
// [
// [[25, 26], [27, 28], [29, 30], [31, 32]],
// [[33, 34], [35, 36], [37, 38], [39, 40]],
// [[41, 42], [43, 44], [45, 46], [47, 48]]
// ]
// ]
// %scatter_indices: [
// [
// [[0, 0], [1, 0], [2, 1]],
// [[0, 1], [1, 1], [0, 9]]
// ],
// [
// [[0, 0], [2, 1], [2, 2]],
// [[1, 2], [0, 1], [1, 0]]
// ]
// ]
// %update: [
// [
// [[1, 1], [1, 1], [1, 1]],
// [[1, 1], [1, 1], [1, 1]]
// ],
// [
// [[1, 1], [1, 1], [1, 1]],
// [[1, 1], [1, 1], [1, 1]]
// ]
// ]
%result = "stablehlo.scatter"(%input, %scatter_indices, %update) ({
^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
%0 = "stablehlo.add"(%arg0, %arg1) : (tensor<i64>, tensor<i64>) -> tensor<i64>
"stablehlo.return"(%0) : (tensor<i64>) -> ()
}) {
scatter_dimension_numbers = #stablehlo.scatter<
update_window_dims = [3, 4],
inserted_window_dims = [1],
input_batching_dims = [0],
scatter_indices_batching_dims = [1],
scatter_dims_to_operand_dims = [2, 1],
index_vector_dim = 3>,
indices_are_sorted = false,
unique_indices = false
} : (tensor<2x3x4x2xi64>, tensor<2x2x3x2xi64>, tensor<2x2x3x2x2xi64>) -> tensor<2x3x4x2xi64>
// %result: [
// [
// [[3, 4], [6, 7], [6, 7], [7, 8]],
// [[9, 10],[11, 12], [15, 16], [17, 18]],
// [[17, 18], [19, 20], [22, 23], [24, 25]]
// ],
// [
// [[25, 26], [28, 29], [30, 31], [31, 32]],
// [[35, 36], [38, 39], [38, 39], [39, 40]],
// [[41, 42], [44, 45], [46, 47], [47, 48]]
// ]
// ]
pilih
Semantik
Menghasilkan tensor result
dengan setiap elemen dipilih dari tensor on_true
atau
on_false
berdasarkan nilai elemen pred
yang sesuai.
Secara lebih formal, result[result_index] = pred_element ? on_true[result_index] :
on_false[result_index]
, dengan pred_element = rank(pred) = 0 ? pred[] :
pred[result_index]
. Untuk jenis terkuantisasi, menjalankan
dequantize_select_quantize(pred, on_true, on_false, type(result))
.
Input
Label | Nama | Jenis | Batasan |
---|---|---|---|
(I1) | pred |
tensor jenis i1 |
(C1) |
(I2) | on_true |
tensor atau tensor terkuantisasi per-tensor | (C1-C2) |
(I3) | on_false |
tensor atau tensor terkuantisasi per tensor | (C2) |
Output
Nama | Jenis | Batasan |
---|---|---|
result |
tensor atau tensor terkuantisasi per tensor | (C2) |
Batasan
- (C1)
rank(pred) = 0 or shape(pred) = shape(on_true)
. - (C2)
baseline_type(on_true) = baseline_type(on_false) = baseline_type(result)
.
Contoh
// %pred: [[false, true], [true, false]]
// %on_true: [[1, 2], [3, 4]]
// %on_false: [[5, 6], [7, 8]]
%result = "stablehlo.select"(%pred, %on_true, %on_false) : (tensor<2x2xi1>, tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
// %result: [[5, 2], [3, 8]]
select_and_scatter
Semantik
Menyebarkan nilai dari tensor source
menggunakan scatter
berdasarkan
hasil reduce_window
dari tensor input
menggunakan select
dan menghasilkan
tensor result
.
Diagram berikut menunjukkan cara elemen di result
dihitung dari
operand
dan source
menggunakan contoh konkret.
Secara lebih formal:
selected_values = reduce_window_without_init(...)
dengan input berikut:inputs = [operand].
window_dimensions
,window_strides
, danpadding
yang digunakan apa adanya.base_dilations = windows_dilations = 1
.body
ditentukan sebagai:
def body(arg0: tensor<E>, arg1: tensor<E>) -> tensor<E>: return select(arg0, arg1) ? arg0 : arg1;
dengan
E = element_type(operand)
, danreduce_window_without_init
berfungsi persis sepertireduce_window
, kecuali bahwaschedule
darireduce
yang mendasarinya (lihat reduce) tidak menyertakan nilai init. Saat ini, tidak ditentukan apa yang akan terjadi jika jendela yang sesuai tidak memiliki nilai (#731).result[result_index] = reduce([source_values], [init_value], [0], scatter)
dengan:source_values = [source[source_index] for source_index in source_indices]
.selected_index(source_index) = operand_index
jikaselected_values[source_index]
memiliki elemenoperand
darioperand_index
.source_indices = [source_index for source_index in indices(source) if selected_index(source_index) = result_index]
.
Input
Label | Nama | Jenis | Batasan |
---|---|---|---|
(I1) | operand |
tensor atau tensor terkuantisasi per-tensor | (C1-C4), (C6), (C8-C11) |
(I2) | source |
tensor atau tensor terkuantisasi per tensor | (C1), (C2) |
(I3) | init_value |
Tensor 0 dimensi atau tensor terkuantisasi per tensor | (C3) |
(I4) | window_dimensions |
Konstanta tensor 1 dimensi dari jenis si64 |
(C2), (C4), (C5) |
(I5) | window_strides |
Konstanta tensor 1 dimensi dari jenis si64 |
(C2), (C6), (C7) |
(I6) | padding |
Konstanta tensor 2 dimensi dari jenis si64 |
(C2), (C8) |
(I7) | select |
fungsi | (C9) |
(I8) | scatter |
fungsi | (C10) |
Output
Nama | Jenis | Batasan |
---|---|---|
result |
tensor atau tensor terkuantisasi per tensor | (C11-C12) |
Batasan
- (C1)
element_type(operand) = element_type(source)
. - (C2)
shape(source) = num_windows
dengan:padded_operand_shape = padding[:, 0] + shape(operand) + padding[:, 1]
.is_empty_window = padded_operand_shape = 0 || window_dimensions > padded_operand_shape
.num_windows = is_empty_window ? 0 : floor((padded_operand_shape - window_dimensions) / window_strides) + 1
.
- (C3)
element_type(init_value) = element_type(operand)
. - (C4)
size(window_dimensions) = rank(operand)
. - (C5)
0 < window_dimensions
. - (C6)
size(window_strides) = rank(operand)
. - (C7)
0 < window_strides
. - (C8)
shape(padding) = [rank(operand), 2]
. - (C9)
select
memiliki jenis(tensor<E>, tensor<E>) -> tensor<i1>
denganE = element_type(operand)
. - (C10)
scatter
memiliki jenis(tensor<E>, tensor<E>) -> tensor<E>
denganis_promotable(element_type(operand), E)
. - (C11)
shape(operand) = shape(result)
. - (C12)
element_type(result) = E
.
Contoh
// %operand: [[1, 5], [2, 5], [3, 6], [4, 4]]
// %source: [[5, 6], [7, 8]]
// %init_value: 0
%result = "stablehlo.select_and_scatter"(%operand, %source, %init_value) ({
^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
%0 = "stablehlo.compare"(%arg0, %arg1) {
comparison_direction = #stablehlo<comparison_direction GE>
} : (tensor<i64>, tensor<i64>) -> tensor<i1>
"stablehlo.return"(%0) : (tensor<i1>) -> ()
}, {
^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
%0 = "stablehlo.add"(%arg0, %arg1) : (tensor<i64>, tensor<i64>) -> tensor<i64>
"stablehlo.return"(%0) : (tensor<i64>) -> ()
}) {
window_dimensions = array<i64: 3, 1>,
window_strides = array<i64: 2, 1>,
padding = dense<[[0, 1], [0, 0]]> : tensor<2x2xi64>
} : (tensor<4x2xi64>, tensor<2x2xi64>, tensor<i64>) -> tensor<4x2xi64>
// %result: [[0, 0], [0, 0], [5, 14], [7, 0]]
kirim
Semantik
Mengirim inputs
ke saluran channel_id
dan menghasilkan token result
.
Jika is_host_transfer
adalah true
, operasi akan mentransfer data ke
host. Jika tidak, data akan ditransfer ke perangkat lain. Artinya,
ditentukan oleh implementasi. Flag ini menduplikasi informasi yang diberikan di
channel_type
, sehingga di masa mendatang kami berencana untuk menyimpan salah satunya saja
(#666).
Input
Label | Nama | Jenis | Batasan |
---|---|---|---|
(I1) | inputs |
jumlah variabel tensor atau tensor terkuantisasi | |
(I2) | token |
token |
|
(I3) | channel_id |
konstanta dari jenis si64 |
|
(I4) | channel_type |
enum DEVICE_TO_DEVICE dan DEVICE_TO_HOST |
(C1) |
(I5) | is_host_transfer |
konstanta dari jenis i1 |
(C1) |
Output
Nama | Jenis |
---|---|
result |
token |
Batasan
- (C1)
channel_type
didefinisikan sebagai:DEVICE_TO_HOST
ifis_host_transfer = true
,DEVICE_TO_DEVICE
sebaliknya.
Contoh
%result = "stablehlo.send"(%operand, %token) {
channel_handle = #stablehlo.channel_handle<handle = 1, type = 2>,
is_host_transfer = true
} : (tensor<2x2xi64>, !stablehlo.token) -> !stablehlo.token
shift_left
Semantik
Melakukan operasi pergeseran kiri element-wise pada tensor lhs
dengan jumlah
bit rhs
dan menghasilkan tensor result
.
Input
Label | Nama | Jenis | Batasan |
---|---|---|---|
(I1) | lhs |
tensor dari jenis bilangan bulat | (C1) |
(I2) | rhs |
tensor dari jenis bilangan bulat | (C1) |
Output
Nama | Jenis | Batasan |
---|---|---|
result |
tensor dari jenis bilangan bulat | (C1) |
Batasan
- (C1)
type(lhs) = type(rhs) = type(result)
.
Contoh
// %lhs: [-1, 0, 1]
// %rhs: [1, 2, 3]
%result = "stablehlo.shift_left"(%lhs, %rhs): (tensor<3xi64>, tensor<3xi64>) -> tensor<3xi64>
// %result: [-2, 0, 8]
shift_right_arithmetic
Semantik
Melakukan operasi pergeseran kanan aritmetika element-wise pada tensor lhs
dengan
jumlah bit rhs
dan menghasilkan tensor result
.
Input
Label | Nama | Jenis | Batasan |
---|---|---|---|
(I1) | lhs |
tensor dari jenis bilangan bulat | (C1) |
(I2) | rhs |
tensor dari jenis bilangan bulat | (C1) |
Output
Nama | Jenis | Batasan |
---|---|---|
result |
tensor dari jenis bilangan bulat | (C1) |
Batasan
- (C1)
type(lhs) = type(rhs) = type(result)
.
Contoh
// %lhs: [-1, 0, 8]
// %rhs: [1, 2, 3]
%result = "stablehlo.shift_right_arithmetic"(%lhs, %rhs): (tensor<3xi64>, tensor<3xi64>) -> tensor<3xi64>
// %result: [-1, 0, 1]
shift_right_logical
Semantik
Melakukan operasi shift kanan logis berbasis element pada tensor lhs
berdasarkan jumlah bit rhs
dan menghasilkan tensor result
.
Input
Label | Nama | Jenis | Batasan |
---|---|---|---|
(I1) | lhs |
tensor dari jenis bilangan bulat | (C1) |
(I2) | rhs |
tensor dari jenis bilangan bulat | (C1) |
Output
Nama | Jenis | Batasan |
---|---|---|
result |
tensor dari jenis bilangan bulat | (C1) |
Batasan
- (C1)
type(lhs) = type(rhs) = type(result)
.
Contoh
// %lhs: [-1, 0, 8]
// %rhs: [1, 2, 3]
%result = "stablehlo.shift_right_logical"(%lhs, %rhs): (tensor<3xi64>, tensor<3xi64>) -> tensor<3xi64>
// %result: [9223372036854775807, 0, 1]
tanda
Semantik
Menampilkan tanda element-wise operand
dan menghasilkan tensor result
.
Secara lebih formal, untuk setiap elemen x
, semantik dapat dinyatakan menggunakan
sintaksis Python sebagai berikut:
def sign(x):
if is_integer(x):
if compare(x, 0, LT, SIGNED): return -1
if compare(x, 0, EQ, SIGNED): return 0
return 1
elif is_float(x):
if is_nan(x): return NaN
if compare(x, -0.0, EQ, FLOAT): return -0.0
if compare(x, +0.0, EQ, FLOAT): return +0.0
if compare(x, 0.0, LT, FLOAT): return -1.0
return 1.0
elif is_complex(x):
if is_nan(real(x)) or is_nan(imag(x)): return (NaN, NaN)
if compare(x, (0.0, 0.0), EQ, FLOAT): return (0.0, 0.0)
return divide(x, convert(abs(x), type(x)))
Untuk jenis kuantisasi, lakukan
dequantize_op_quantize(sign, operand, type(result))
.
Input
Label | Nama | Jenis | Batasan |
---|---|---|---|
(I1) | operand |
tensor bilangan bulat bertanda tangan, floating point, atau jenis kompleks atau tensor terkuantisasi per tensor | (C1) |
Output
Nama | Jenis | Batasan |
---|---|---|
result |
tensor bilangan bulat bertanda tangan, floating point, atau tipe kompleks atau tensor terkuantisasi per-tensor | (C1) |
Batasan
- (C1)
baseline_type(operand) = baseline_type(result)
.
Contoh
// Logical values: +NaN, -1.0, -0.0, +0.0, 1.0
// operand: [0x7FFFFFFFFFFFFFFF, -1.0, -0.0, 0.0, 1.0]
%result = "stablehlo.sign"(%operand) : (tensor<5xf64>) -> tensor<5xf64>
// Logical values: +NaN, -1.0, -0.0, +0.0, 1.0
// %result: [0x7FFFFFFFFFFFFFFF, -1.0, -0.0, 0.0, 1.0]
sinus
Semantik
Melakukan operasi sinus berdasarkan elemen pada tensor operand
dan menghasilkan tensor
result
. Bergantung pada jenis elemen, melakukan hal berikut:
- Untuk float:
sin
dari IEEE-754. - Untuk bilangan kompleks: sinus kompleks.
- Untuk jenis terkuantisasi:
dequantize_op_quantize(sine, operand, type(result))
.
Input
Label | Nama | Jenis | Batasan |
---|---|---|---|
(I1) | operand |
tensor floating point atau jenis kompleks atau tensor terkuantisasi per-tensor | (C1) |
Output
Nama | Jenis | Batasan |
---|---|---|
result |
tensor dari jenis floating point atau kompleks atau tensor terkuantisasi per tensor | (C1) |
Batasan
- (C1)
baseline_type(operand) = baseline_type(result)
.
Contoh
// %operand: [
// [0.0, 1.57079632], // [0, pi/2]
// [3.14159265, 4.71238898] // [pi, 3pi/2]
// ]
%result = "stablehlo.sine"(%operand) : (tensor<2x2xf32>) -> tensor<2x2xf32>
// %result: [[0.0, 1.0], [0.0, -1.0]]
slice
Semantik
Mengekstrak slice dari operand
menggunakan indeks awal yang dihitung secara statis
dan menghasilkan tensor result
. start_indices
berisi indeks awal
slice untuk setiap dimensi, limit_indices
berisi indeks akhir
(eksklusif) untuk slice untuk setiap dimensi, dan strides
berisi stride
untuk setiap dimensi.
Secara lebih formal, result[result_index] = operand[operand_index]
dengan
operand_index = start_indices + result_index * strides
.
Input
Label | Nama | Jenis | Batasan |
---|---|---|---|
(I1) | operand |
tensor atau tensor terkuantisasi per tensor | (C1-C3), (C5) |
(I2) | start_indices |
Konstanta tensor 1 dimensi dari jenis si64 |
(C2), (C3), (C5) |
(I3) | limit_indices |
Konstanta tensor 1 dimensi dari jenis si64 |
(C2), (C3), (C5) |
(I4) | strides |
Konstanta tensor 1 dimensi dari jenis si64 |
(C2), (C4) |
Output
Nama | Jenis | Batasan |
---|---|---|
result |
tensor atau tensor terkuantisasi per tensor | (C1), (C5) |
Batasan
- (C1)
element_type(operand) = element_type(result)
. - (C2)
size(start_indices) = size(limit_indices) = size(strides) = rank(operand)
. - (C3)
0 <= start_indices <= limit_indices <= shape(operand)
. - (C4)
0 < strides
. - (C5)
shape(result) = ceil((limit_indices - start_indices) / strides)
.
Contoh
// %operand: [
// [0, 0, 0, 0],
// [0, 0, 1, 1],
// [0, 0, 1, 1]
// ]
%result = "stablehlo.slice"(%operand) {
start_indices = array<i64: 1, 2>,
limit_indices = array<i64: 3, 4>,
strides = array<i64: 1, 1>
} : (tensor<3x4xi64>) -> tensor<2x2xi64>
// % result: [
// [1, 1],
// [1, 1]
// ]
mengurutkan
Semantik
Mengurutkan slice inputs
1 dimensi di sepanjang dimensi dimension
secara bersamaan,
sesuai dengan comparator
dan menghasilkan results
.
Tidak seperti input serupa dalam operasi lain, dimension
mengizinkan nilai negatif,
dengan semantik yang dijelaskan di bawah. Di masa mendatang, hal ini mungkin tidak diizinkan
karena alasan konsistensi
(#1377).
Jika is_stable
bernilai benar, pengurutan akan stabil, yaitu urutan relatif
elemen yang dianggap sama oleh pembanding akan dipertahankan. Untuk kasus
saat ada satu input, dua elemen e1
dan e2
dianggap
sama oleh pembanding jika dan hanya jika
comparator(e1, e2) = comparator(e2, e1) = false
. Lihat formalisasi di bawah
untuk mengetahui cara generalisasi ini ke beberapa input.
Secara lebih formal, untuk semua result_index
di index_space(results[0])
:
adjusted_dimension = dimension >= 0 ? dimension : rank(inputs[0]) + dimension
.result_slice = [ri0, ..., :, ..., riR-1]
denganriN
adalah elemen individual diresult_index
, dan:
disisipkan padaadjusted_dimension
.inputs_together = (inputs[0]..., ..., inputs[N-1]...)
.results_together[result_slice] = sort(inputs_together[result_slice], comparator_together)
.- dengan
sort
mengurutkan slice 1 dimensi dalam urutan non-menurun yang mengharapkancomparator_together
menampilkantrue
jika argumen sisi kiri lebih kecil dari argumen kedua sisi kanan. def comparator_together(lhs_together, rhs_together): args = [] for (lhs_el, rhs_el) in zip(lhs_together, rhs_together): args.append(lhs_el) args.append(rhs_el) return comparator(*args)
(results[0]..., ..., results[N-1]...) = results_together
.
Input
Label | Nama | Jenis | Batasan |
---|---|---|---|
(I1) | inputs |
jumlah variabel tensor atau tensor terkuantisasi per tensor | (C1-C5) |
(I2) | dimension |
konstanta dari jenis si64 |
(C4) |
(I3) | is_stable |
konstanta dari jenis i1 |
|
(I4) | comparator |
fungsi | (C5) |
Output
Nama | Jenis | Batasan |
---|---|---|
results |
jumlah tensor atau tensor terkuantisasi per-tensor | (C2), (C3) |
Batasan
- (C1)
0 < size(inputs)
. - (C2)
type(inputs...) = type(results...)
. - (C3)
same(shape(inputs...) + shape(results...))
. - (C4)
-R <= dimension < R
, denganR = rank(inputs[0])
. - (C5)
comparator
memiliki jenis(tensor<E1>, tensor<E1>, ..., tensor<EN-1>, tensor<EN-1>) -> tensor<i1>
, denganEi = element_type(inputs[i])
.
Contoh
// %input0 = [[1, 2, 3], [3, 2, 1]]
// %input1 = [[3, 2, 1], [1, 2, 3]]
%result0, %result1 = "stablehlo.sort"(%input0, %input1) ({
^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>, %arg2: tensor<i64>, %arg3: tensor<i64>):
%predicate = "stablehlo.compare"(%arg0, %arg1) {
comparison_direction = #stablehlo<comparison_direction GT>
} : (tensor<i64>, tensor<i64>) -> tensor<i1>
"stablehlo.return"(%predicate) : (tensor<i1>) -> ()
}) {
dimension = 0 : i64,
is_stable = true
} : (tensor<2x3xi64>, tensor<2x3xi64>) -> (tensor<2x3xi64>, tensor<2x3xi64>)
// %result0 = [[3, 2, 3], [1, 2, 1]]
// %result1 = [[1, 2, 1], [3, 2, 3]]
sqrt
Semantik
Melakukan operasi akar kuadrat element-wise pada tensor operand
dan menghasilkan
tensor result
. Bergantung pada jenis elemen, lakukan tindakan berikut:
- Untuk float:
squareRoot
dari IEEE-754. - Untuk bilangan kompleks: akar kuadrat kompleks.
- Untuk jenis kuantisasi:
dequantize_op_quantize(sqrt, operand, type(result))
.
Input
Label | Nama | Jenis | Batasan |
---|---|---|---|
(I1) | operand |
tensor floating point atau jenis kompleks atau tensor terkuantisasi per-tensor | (C1) |
Output
Nama | Jenis | Batasan |
---|---|---|
result |
tensor dari jenis floating point atau kompleks atau tensor terkuantisasi per tensor | (C1) |
Batasan
- (C1)
baseline_type(operand) = baseline_type(result)
.
Contoh
// %operand: [[0.0, 1.0], [4.0, 9.0]]
%result = "stablehlo.sqrt"(%operand) : (tensor<2x2xf32>) -> tensor<2x2xf32>
// %result: [[0.0, 1.0], [2.0, 3.0]]
kurangi
Semantik
Melakukan pengurangan berbasis elemen dari dua tensor lhs
dan rhs
, serta menghasilkan
tensor result
. Bergantung pada jenis elemen, melakukan hal berikut:
- Untuk bilangan bulat: pengurangan bilangan bulat.
- Untuk float:
subtraction
dari IEEE-754. - Untuk bilangan kompleks: pengurangan kompleks.
- Untuk jenis kuantisasi:
dequantize_op_quantize(subtract, lhs, rhs, type(result))
.
Input
Label | Nama | Jenis | Batasan |
---|---|---|---|
(I1) | lhs |
tensor bilangan bulat, floating point, atau jenis kompleks atau tensor terkuantisasi per tensor | (C1) |
(I2) | rhs |
tensor bilangan bulat, floating point, atau jenis kompleks atau tensor terkuantisasi per tensor | (C1) |
Output
Nama | Jenis | Batasan |
---|---|---|
result |
tensor bilangan bulat, floating point, atau jenis kompleks atau tensor terkuantisasi per tensor | (C1) |
Batasan
- (C1)
baseline_type(lhs) = baseline_type(rhs) = baseline_type(result)
.
Contoh
// %lhs: [[6, 8], [10, 12]]
// %rhs: [[5, 6], [7, 8]]
%result = "stablehlo.subtract"(%lhs, %rhs) : (tensor<2x2xf32>, tensor<2x2xf32>) -> (tensor<2x2xf32>)
// %result: [[1, 2], [3, 4]]
tan
Semantik
Melakukan operasi tangen element-wise pada tensor operand
dan menghasilkan
tensor result
. Bergantung pada jenis elemen, lakukan tindakan berikut:
- Untuk float:
tan
dari IEEE-754. - Untuk bilangan kompleks: tangen kompleks.
- Untuk jenis terkuantisasi:
dequantize_op_quantize(tan, operand, type(result))
.
Input
Label | Nama | Jenis | Batasan |
---|---|---|---|
(I1) | operand |
tensor floating point atau jenis kompleks atau tensor terkuantisasi per-tensor | (C1) |
Output
Nama | Jenis | Batasan |
---|---|---|
result |
tensor dari jenis floating point atau kompleks atau tensor terkuantisasi per tensor | (C1) |
Batasan
- (C1)
baseline_type(operand) = baseline_type(result)
.
Contoh
// %operand: [
// [0.0, 1.57079632], // [0, pi/2]
// [3.14159265, 4.71238898] // [pi, 3pi/2]
// ]
%result = "stablehlo.tan"(%operand) : (tensor<2x2xf64>) -> tensor<2x2xf64>
// %result: [
// [0.0, 1.63312e+16],
// [0.0, 5.44375e+15]
// ]
tanh
Semantik
Melakukan operasi tangen hiperbolik element-wise pada tensor operand
dan
menghasilkan tensor result
. Bergantung pada jenis elemen, lakukan tindakan berikut:
- Untuk float:
tanh
dari IEEE-754. - Untuk bilangan kompleks: tangen hiperbolik kompleks.
- Untuk jenis terkuantisasi:
dequantize_op_quantize(tanh, operand, type(result))
.
Input
Label | Nama | Jenis | Batasan |
---|---|---|---|
(I1) | operand |
tensor floating point atau jenis kompleks atau tensor terkuantisasi per-tensor | (C1) |
Output
Nama | Jenis | Batasan |
---|---|---|
result |
tensor dari jenis floating point atau kompleks atau tensor terkuantisasi per tensor | (C1) |
Batasan
- (C1)
baseline_type(operand) = baseline_type(result)
.
Contoh
// %operand: [-1.0, 0.0, 1.0]
%result = "stablehlo.tanh"(%operand) : (tensor<3xf32>) -> tensor<3xf32>
// %result: [-0.76159416, 0.0, 0.76159416]
transpose
Semantik
Mengurutkan ulang dimensi tensor operand
menggunakan permutation
dan menghasilkan
tensor result
. Secara lebih formal, result[result_index] = operand[operand_index]
dengan result_index[d] = operand_index[permutation[d]]
.
Input
Label | Nama | Jenis | Batasan |
---|---|---|---|
(I1) | operand |
tensor atau tensor terkuantisasi | (C1-C4) |
(I2) | permutation |
Konstanta tensor 1 dimensi dari jenis si64 |
(C2-C4) |
Output
Nama | Jenis | Batasan |
---|---|---|
result |
tensor atau tensor terkuantisasi | (C1), (C3-C4) |
Batasan
- (C1)
element_type(result)
diberikan oleh:element_type(operand)
, jika!is_per_axis_quantized(operand)
.element_type(operand)
kecuali jikaquantization_dimension(operand)
danquantization_dimension(result)
mungkin berbeda.
- (C2)
permutation
adalah permutasi darirange(rank(operand))
. - (C3)
shape(result) = dim(operand, permutation...)
. - (C4) Jika
is_per_axis_quantized(result)
, makaquantization_dimension(operand) = permutation(quantization_dimension(result))
.
Contoh
// %operand: [
// [[1,2], [3,4], [5,6]],
// [[7,8], [9,10], [11,12]]
// ]
%result = "stablehlo.transpose"(%operand) {
permutation = array<i64: 2, 1, 0>
} : (tensor<2x3x2xi32>) -> tensor<2x3x2xi32>
// %result: [
// [[1,7], [3,9], [5,11]],
// [[2,8], [4,10], [6,12]]
// ]
triangular_solve
Semantik
Menyelesaikan batch sistem persamaan linear dengan matriks koefisien segitiga bawah atau atas.
Secara lebih formal, dengan a
dan b
, result[i0, ..., iR-3, :, :]
adalah solusi
untuk op(a[i0, ..., iR-3, :, :]) * x = b[i0, ..., iR-3, :, :]
jika left_side
adalah
true
atau x * op(a[i0, ..., iR-3, :, :]) = b[i0, ..., iR-3, :, :]
jika
left_side
adalah false
, yang menyelesaikan variabel x
dengan op(a)
ditentukan
oleh transpose_a
, yang dapat berupa salah satu dari hal berikut:
NO_TRANSPOSE
: Menjalankan operasi menggunakana
sebagaimana adanya.TRANSPOSE
: Melakukan operasi pada transposea
.ADJOINT
: Melakukan operasi pada transpos konjugata
.
Data input hanya dibaca dari segitiga bawah a
, jika lower
adalah true
atau
segitiga atas a
, jika tidak. Data output ditampilkan dalam segitiga yang sama;
nilai dalam segitiga lainnya ditentukan oleh implementasi.
Jika unit_diagonal
benar, implementasi dapat mengasumsikan bahwa elemen diagonal
a
sama dengan 1, jika tidak, perilakunya tidak ditentukan.
Untuk jenis kuantisasi, lakukan
dequantize_op_quantize(lambda x, y: triangular_solve(x, y, left_side, lower,
unit_diagonal, transpose_a), a, b, type(result))
.
Input
Label | Nama | Jenis | Batasan |
---|---|---|---|
(I1) | a |
tensor dari jenis floating point atau kompleks atau tensor terkuantisasi per tensor | (C1-C3) |
(I2) | b |
tensor floating point atau jenis kompleks atau tensor terkuantisasi per-tensor | (C1-C4) |
(I3) | left_side |
konstanta dari jenis i1 |
(C3) |
(I4) | lower |
konstanta dari jenis i1 |
|
(I5) | unit_diagonal |
konstanta dari jenis i1 |
|
(I6) | transpose_a |
enum NO_TRANSPOSE , TRANSPOSE , dan ADJOINT |
Output
Nama | Jenis | Batasan |
---|---|---|
result |
tensor dari jenis floating point atau kompleks atau tensor terkuantisasi per tensor | (C1) |
Batasan
- (C1)
baseline_element_type(a) = baseline_element_type(b)
. - (C2)
2 <= rank(a) = rank(b) = R
. - (C3) Hubungan antara
shape(a)
danshape(b)
ditentukan sebagai berikut:shape(a)[:-3] = shape(b)[:-3]
.dim(a, -2) = dim(a, -1) = dim(b, left_side ? -2 : -1)
.
- (C4)
baseline_type(b) = baseline_type(result)
.
Contoh
// %a = [
// [1.0, 0.0, 0.0],
// [2.0, 4.0, 0.0],
// [3.0, 5.0, 6.0]
// ]
// %b = [
// [2.0, 0.0, 0.0],
// [4.0, 8.0, 0.0],
// [6.0, 10.0, 12.0]
// ]
%result = "stablehlo.triangular_solve"(%a, %b) {
left_side = true,
lower = true,
unit_diagonal = false,
transpose_a = #stablehlo<transpose NO_TRANSPOSE>
} : (tensor<3x3xf32>, tensor<3x3xf32>) -> tensor<3x3xf32>
// %result: [
// [2.0, 0.0, 0.0],
// [0.0, 2.0, 0.0],
// [0.0, 0.0, 2.0]
// ]
tuple
Semantik
Menghasilkan tuple result
dari nilai val
.
Input
Label | Nama | Jenis | Batasan |
---|---|---|---|
(I1) | val |
jumlah nilai variadik | (C1) |
Output
Nama | Jenis | Batasan |
---|---|---|
result |
tuple | (C1) |
Batasan
- (C1)
result
memiliki jenistuple<E0, ..., EN-1>
denganEi = type(val[i])
.
Contoh
// %val0: [1.0, 2.0]
// %val1: (3)
%result = "stablehlo.tuple"(%val0, %val1) : (tensor<2xf32>, tuple<tensor<i32>>) -> tuple<tensor<2xf32>, tuple<tensor<i32>>>
// %result: ([1.0, 2.0], (3))
uniform_dequantize
Semantik
Melakukan konversi element-wise dari tensor terkuantisasi operand
ke
tensor floating point result
sesuai dengan parameter kuantisasi yang ditentukan
oleh jenis operand
.
Secara lebih formal, result = dequantize(operand)
.
Input
Label | Nama | Jenis | Batasan |
---|---|---|---|
(I1) | operand |
tensor terkuantisasi | (C1), (C2) |
Output
Nama | Jenis | Batasan |
---|---|---|
result |
tensor jenis floating point | (C1), (C2) |
Batasan
- (C1)
shape(operand) = shape(result)
. - (C2)
element_type(result) = expressed_type(operand)
.
Contoh
// %operand: [10, 10]
%result = "stablehlo.uniform_dequantize"(%operand) : (tensor<2x!quant.uniform<i8:f32:0, {0.1:-30,0.5:-20}>>) -> tensor<2xf32>
// %result: [4.0, 15.0]
uniform_quantize
Semantik
Melakukan konversi element-wise dari tensor floating point atau tensor terkuantisasi
operand
ke tensor terkuantisasi result
sesuai dengan parameter
kuantisasi yang ditentukan oleh jenis result
.
Secara lebih formal,
- Jika
is_float(operand)
:result = quantize(operand, type(result))
.
- Jika
is_quantized(operand)
:float_result = dequantize(operand)
.result = quantize(float_result, type(result))
.
Input
Label | Nama | Jenis | Batasan |
---|---|---|---|
(I1) | operand |
tensor dari jenis floating point atau terkuantisasi | (C1), (C2) |
Output
Nama | Jenis | Batasan |
---|---|---|
result |
tensor terkuantisasi | (C1), C2 |
Batasan
- (C1)
shape(operand) = shape(result)
. - (C2)
expressed_type(result) = is_float(operand) ? element_type(operand) : expressed_type(operand)
.
Contoh
// %operand: [4.0, 15.0]
%result = "stablehlo.uniform_quantize"(%operand) : (tensor<2xf32>) -> tensor<2x!quant.uniform<i8:f32:0, {0.1:-30,0.5:-20}>>
// %result: [10, 10]
// %operand: [10, 10]
%result = "stablehlo.uniform_quantize"(%operand) : (tensor<2x!quant.uniform<i8:f32:0, {0.1:-30,0.5:-20}>>) -> tensor<2x!quant.uniform<i8:f32:0, {0.1:-20,0.2:-30}>>
// %result: [20, 45]
while
Semantik
Menghasilkan output dari mengeksekusi fungsi body
0 kali atau lebih saat
fungsi cond
menghasilkan output true
. Secara lebih formal, semantik dapat dinyatakan
menggunakan sintaksis Python sebagai berikut:
internal_state = operand
while cond(*internal_state):
internal_state = body(*internal_state)
results = internal_state
Perilaku loop tanpa batas adalah TBD (#383).
Input
Label | Nama | Jenis | Batasan |
---|---|---|---|
(I1) | operand |
jumlah variabel tensor, tensor terkuantisasi, atau token | (C1-C3) |
(I2) | cond |
fungsi | (C1) |
(I3) | body |
fungsi | (C2) |
Output
Nama | Jenis | Batasan |
---|---|---|
results |
jumlah variabel tensor, tensor terkuantisasi, atau token | (C3) |
Batasan
- (C1)
cond
memiliki jenis(T0, ..., TN-1) -> tensor<i1>
, denganTi = type(operand[i])
. - (C2)
body
memiliki jenis(T0, ..., TN-1) -> (T0, ..., TN-1)
, denganTi = type(operand[i])
. - (C3)
type(results...) = type(operand...)
.
Contoh
// %init_i: 1
// %init_sum: 0
// %one: 1
// %ten: 10
%results0, %results1 = "stablehlo.while"(%init_i, %init_sum) ({
^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
%cond = "stablehlo.compare"(%arg0, %ten) {
comparison_direction = #stablehlo<comparison_direction LT>
} : (tensor<i64>, tensor<i64>) -> tensor<i1>
stablehlo.return %cond : tensor<i1>
}, {
^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
%new_sum = stablehlo.add %arg1, %one : tensor<i64>
%new_i = stablehlo.add %arg0, %one : tensor<i64>
stablehlo.return %new_i, %new_sum : tensor<i64>, tensor<i64>
}) : (tensor<i64>, tensor<i64>) -> (tensor<i64>, tensor<i64>)
// %results0: 10
// %results1: 10
xor
Semantik
Melakukan XOR element-wise dari dua tensor lhs
dan rhs
serta menghasilkan tensor
result
. Bergantung pada jenis elemen, lakukan tindakan berikut:
- Untuk boolean: XOR logis.
- Untuk bilangan bulat: bitwise XOR.
Input
Label | Nama | Jenis | Batasan |
---|---|---|---|
(I1) | lhs |
tensor dari jenis boolean atau bilangan bulat | (C1) |
(I2) | rhs |
tensor dari jenis boolean atau bilangan bulat | (C1) |
Output
Nama | Jenis | Batasan |
---|---|---|
result |
tensor dari jenis boolean atau bilangan bulat | (C1) |
Batasan
- (C1)
type(lhs) = type(rhs) = type(result)
.
Contoh
// Bitwise operation with with integer tensors
// %lhs: [[1, 2], [3, 4]]
// %rhs: [[5, 6], [7, 8]]
%result = "stablehlo.xor"(%lhs, %rhs) : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
// %result: [[4, 4], [4, 12]]
// Logical operation with with boolean tensors
// %lhs: [[false, false], [true, true]]
// %rhs: [[false, true], [false, true]]
%result = "stablehlo.xor"(%lhs, %rhs) : (tensor<2x2xi1>, tensor<2x2xi1>) -> tensor<2x2xi1>
// %result: [[false, true], [true, false]]
Interop Dialek
Saat ini, program StableHLO di dunia nyata terkadang berisi operasi yang tidak ditentukan oleh StableHLO.
Modul, Fungsi, Panggilan, dan Pengembalian
StableHLO menggunakan operasi MLIR upstream untuk ModuleOp, FuncOp, CallOp, dan ReturnOp. Hal ini dilakukan untuk interop yang lebih baik dengan mesin MLIR yang ada, karena banyak pass berguna yang ditulis dengan menargetkan FuncOp dan ModuleOp, dan banyak pipeline kompilasi memperkirakan operasi ini akan ada. Jaminan kompatibilitas penuh diterapkan ke operasi ini. Jika ada perubahan pada operasi ini dengan cara yang tidak kompatibel (yaitu penghapusan), padanan StableHLO akan ditambahkan untuk mempertahankan kompatibilitas.
CHLO
Opset CHLO berisi operasi tingkat lebih tinggi yang terurai menjadi StableHLO. Saat ini tidak ada jaminan kompatibilitas untuk CHLO. Untuk jaminan kompatibilitas, chlo-legalize-to-stablehlo pass harus digunakan sebelum serialisasi.
Operasi Bentuk
Ini adalah kasus penggunaan umum di komunitas untuk menggunakan operasi tertentu dari dialek
MLIR inti dalam program StableHLO dinamis untuk melakukan komputasi bentuk.
Biasanya, ini mencakup operasi dialek shape
seperti shape_of
atau num_elements
, operasi dialek tensor
seperti dim
atau from_elements
, dan jenis index
bawaan.
Dynamism RFC > O2
menunjukkan bahwa hal ini berada di luar cakupan, tetapi beberapa dukungan untuk jenis index
disertakan untuk tujuan interop. Tidak ada jaminan kompatibilitas untuk operasi atau jenis
ini. Kartu shape-legalize-to-stablehlo
dapat digunakan untuk mengonversi operasi ini menjadi operasi StableHLO yang didukung sepenuhnya.
Operasi yang Tidak Digunakan Lagi
Ada beberapa operasi StableHLO yang diwarisi dari MHLO yang tidak digunakan lagi dan akan dihapus dari StableHLO. Detail lengkap tentang penghapusan ini dapat ditemukan di Pembersihan StableHLO v1.0 #2283. Masalah pelacak untuk penghentian ini adalah #2340.
Operasi ini termasuk dalam beberapa kategori:
- Kategori "Not in HLO" dari operasi StableHLO - awalnya merupakan bagian dari opset StableHLO, tetapi kemudian dianggap tidak cocok:
broadcast
,create_token
,cross-replica-sum
,dot
,einsum
,torch_index_select
,unary_einsum
(#3). - Operasi yang tidak digunakan - Operasi ini mungkin pernah berguna pada suatu saat, tetapi operasi tersebut
belum dikembangkan sepenuhnya, atau pipeline yang menggunakan operasi ini telah
difaktorkan ulang sehingga tidak memerlukannya lagi. Ini termasuk
map
,tuple
(#598), perbandinganget_tuple_element
,rng
,complex
#560, dan konvolusiwindow_reversal
(#1181).
Beberapa operasi ini dapat dihapus dengan mudah karena dapat diekspresikan menggunakan
operasi yang sudah ada (broadcast
, create_token
, cross-replica-sum
, dot
,
unary_einsum
) dan akan dihapus setelah periode kompatibilitas yang ada
berlalu (6 bulan). Ops lainnya masih dipelajari untuk dihapus (perbandingan
einsum
,
get_tuple_element
, map
, rng
torch_index_select
, tuple
, complex
, window_reversal
). Menunggu masukan komunitas,
ops ini akan dihapus, atau ditambahkan ke spesifikasi dengan dukungan penuh. Hingga
ops futures ini diketahui, ops futures tersebut hanya dijamin kompatibilitasnya selama 6 bulan.
Eksekusi
Eksekusi berurutan
Program StableHLO dijalankan dengan memberikan nilai input ke fungsi main
dan menghitung nilai output. Nilai output fungsi dihitung dengan
menjalankan grafik operasi yang berakar pada operasi return
yang sesuai.
Urutan eksekusi ditentukan oleh implementasi selama selaras dengan
alur data, yaitu jika operasi dieksekusi sebelum penggunaannya. Di StableHLO, semua
operasi yang menghasilkan efek samping menggunakan satu token dan menghasilkan satu token (beberapa token dapat
dimultipleks menjadi satu token melalui after_all
), sehingga urutan eksekusi efek samping
juga selaras dengan alur data. Misalnya, dalam program di bawah
ada dua kemungkinan urutan eksekusi: %0
→ %1
→ %2
→ return
dan
%1
→ %0
→ %2
→ return
.
func.func @main() -> tensor<f64> {
%0 = stablehlo.constant dense<1.0> : tensor<f64>
%1 = stablehlo.constant dense<2.0> : tensor<f64>
%2 = stablehlo.add %0, %1 : tensor<f64>
return %2 : tensor<f64>
}
Secara lebih formal, proses StableHLO adalah kombinasi dari:
1) program StableHLO, 2) status operasi (belum dieksekusi,
sudah dieksekusi), dan 3) nilai perantara yang sedang diproses.
Proses dimulai dengan nilai input untuk fungsi main
, dilanjutkan
melalui grafik operasi yang memperbarui status operasi dan nilai perantara,
dan diakhiri dengan nilai output. Formasi lebih lanjut akan ditentukan
(#484).
Eksekusi paralel
Program StableHLO dapat dijalankan secara paralel, yang diatur ke dalam petak proses 2D
num_replicas
oleh num_partitions
yang keduanya memiliki jenis ui32
.
Dalam petak proses StableHLO, num_replicas * num_partitions
proses
StableHLO dieksekusi secara bersamaan. Setiap proses memiliki
process_id = (replica_id, partition_id)
unik, dengan
replica_id
di replica_ids = range(num_replicas)
dan
partition_id
di partition_ids = range(num_partitions)
yang keduanya memiliki
jenis ui32
.
Ukuran petak proses diketahui secara statis untuk setiap program (di
masa mendatang, kami berencana menjadikannya bagian eksplisit dari program StableHLO
#650), dan posisi
dalam petak proses diketahui secara statis untuk setiap proses. Setiap proses memiliki
akses ke posisinya dalam petak proses melalui operasi replica_id
dan
partition_id
.
Dalam petak proses, semua program dapat sama (dalam gaya "Program tunggal, Beberapa Data"), semua dapat berbeda (dalam gaya "Beberapa Program, Beberapa Data"), atau sesuatu di antaranya. Di masa mendatang, kami berencana memperkenalkan dukungan untuk idiom lain guna menentukan program StableHLO paralel, termasuk GSPMD (#619).
Dalam petak proses, proses sebagian besar independen satu sama lain - proses tersebut memiliki status operasi terpisah, nilai input/antara/output terpisah dan sebagian besar operasi dijalankan secara terpisah di antara proses, dengan pengecualian sejumlah kecil operasi kolektif yang dijelaskan di bawah.
Mengingat bahwa eksekusi sebagian besar operasi hanya menggunakan nilai dari proses
yang sama, biasanya menyebutkan nilai ini berdasarkan namanya menjadi tidak ambigu.
Namun, saat mendeskripsikan semantik operasi kolektif, hal itu tidak memadai, dan
yang menghasilkan notasi name@process_id
untuk merujuk ke nilai name
dalam proses tertentu. (Dari perspektif tersebut, name
yang tidak memenuhi syarat dapat
dilihat sebagai singkatan untuk name@(replica_id(), partition_id())
).
Urutan eksekusi di seluruh proses ditentukan oleh implementasi, kecuali untuk sinkronisasi yang diperkenalkan oleh komunikasi titik ke titik dan operasi kolektif seperti yang dijelaskan di bawah ini.
Komunikasi titik ke titik
Proses StableHLO dapat berkomunikasi satu sama lain melalui
saluran StableHLO. Channel diwakili oleh ID positif jenis
si64
. Melalui berbagai operasi, Anda dapat mengirim nilai ke saluran dan
menerimanya dari saluran.
Formalisasi lebih lanjut, misalnya asal ID saluran ini, cara program memprosesnya, dan jenis sinkronisasi yang diperkenalkan olehnya, masih akan ditentukan (#484).
Komunikasi streaming
Setiap proses StableHLO memiliki akses ke dua antarmuka streaming:
- Infeed yang dapat dibaca.
- Outfeed yang dapat ditulis.
Tidak seperti saluran, yang digunakan untuk berkomunikasi antar-proses sehingga memiliki proses di kedua ujungnya, feed infeed dan outfeed memiliki penerapan akhir yang ditentukan.
Formalisasi lebih lanjut, misalnya, bagaimana komunikasi streaming memengaruhi urutan eksekusi dan jenis sinkronisasi yang diperkenalkan olehnya, adalah TBD (#484).
Operasi kolektif
Ada enam operasi kolektif di StableHLO: all_gather
, all_reduce
,
all_to_all
, collective_broadcast
, collective_permute
, dan
reduce_scatter
. Semua operasi ini membagi proses dalam petak proses
StableHLO menjadi grup proses StableHLO dan mengeksekusi komputasi bersama dalam
setiap grup proses, secara independen dari grup proses lainnya.
Dalam setiap grup proses, operasi kolektif dapat menyebabkan penghalang sinkronisasi. Formalisasi lebih lanjut, misalnya menguraikan kapan tepatnya sinkronisasi ini terjadi, bagaimana prosesnya sampai pada hambatan ini, dan apa yang terjadi jika tidak, akan ditentukan (#484).
Jika grup proses melibatkan komunikasi lintas partisi, yaitu ada
proses dalam grup proses yang ID partisinya berbeda, maka eksekusi
operasi kolektif memerlukan saluran, dan operasi kolektif harus menyediakan
channel_id
positif dari jenis si64
. Komunikasi lintas replika tidak
membutuhkan saluran.
Komputasi yang dilakukan oleh operasi kolektif bersifat khusus untuk setiap operasi dan dijelaskan di setiap bagian operasi di atas. Namun, strategi yang digunakan untuk membagi petak proses menjadi grup proses dibagikan di antara operasi ini dan dijelaskan di bagian ini. Secara lebih formal, StableHLO mendukung empat strategi berikut.
cross_replica
Hanya komunikasi lintas replika yang terjadi dalam setiap grup proses. Strategi
ini menggunakan replica_groups
- daftar daftar ID replika - dan menghitung
produk Kartesius replica_groups
dengan partition_ids
. replica_groups
harus memiliki elemen unik dan mencakup semua replica_ids
. Secara lebih formal, menggunakan sintaksis Python:
def cross_replica(replica_groups: List[List[ReplicaId]]) -> List[List[ProcessId]]:
for replica_group in replica_groups:
for partition_id in partition_ids:
process_group = []
for replica_id in replica_group:
process_group.append((replica_id, partition_id))
yield process_group
Misalnya, untuk replica_groups = [[0, 1], [2, 3]]
dan num_partitions = 2
,
cross_replica
akan menghasilkan
[[(0, 0), (1, 0)], [(0, 1), (1, 1)], [(2, 0), (3, 0)], [(2, 1), (3, 1)]]
.
cross_partition
Hanya komunikasi lintas partisi yang terjadi dalam setiap grup proses. Strategi
ini menggunakan partition_groups
- daftar daftar ID partisi - dan
menghitung produk Kartesius partition_groups
dengan replica_ids
.
partition_groups
harus memiliki elemen unik dan mencakup semua partition_ids
.
Secara lebih formal, menggunakan sintaksis Python:
def cross_partition(partition_groups: List[List[PartitionId]]) -> List[List[ProcessId]]:
for partition_group in partition_groups:
for replica_id in replica_ids:
process_group = []
for partition_id in partition_group:
process_group.append((replica_id, partition_id))
yield process_group
Misalnya, untuk partition_groups = [[0, 1]]
dan num_replicas = 4
,
cross_partition
akan menghasilkan
[[(0, 0), (0, 1)], [(1, 0), (1, 1)], [(2, 0), (2, 1)], [(3, 0), (3, 1)]]
.
cross_replica_and_partition
Komunikasi lintas-replikasi dan lintas-partisi dapat terjadi dalam setiap
grup proses. Strategi ini menggunakan replica_groups
- daftar daftar
ID replika - dan menghitung produk Kartesius dari setiap replica_group
dengan
partition_ids
. replica_groups
harus memiliki elemen unik dan mencakup semua
replica_ids
. Secara lebih formal, menggunakan sintaksis Python:
def cross_replica_and_partition(replica_groups: List[List[ReplicaId]]) -> List[List[ProcessId]]:
for replica_group in replica_groups:
process_group = []
for partition_id in partition_ids:
for replica_id in replica_group:
process_group.append((replica_id, partition_id))
yield process_group
Misalnya, untuk replica_groups = [[0, 1], [2, 3]]
dan num_partitions = 2
,
cross_replica_and_partition
akan menghasilkan
[[(0, 0), (1, 0), (0, 1), (1, 1)], [(2, 0), (3, 0), (2, 1), (3, 1)]]
.
flattened_ids
Strategi ini menggunakan flattened_id_groups
- daftar ID proses yang "disatukan" dalam bentuk replica_id * num_partitions + partition_id
- dan mengubahnya menjadi ID proses. flattened_id_groups
harus memiliki elemen unik
dan mencakup semua process_ids
. Secara lebih formal, menggunakan sintaksis Python:
def flattened_ids(flattened_id_groups: List[List[ui32]]) -> List[List[ProcessId]]:
for flattened_id_group in flattened_id_groups:
process_group = []
for flattened_id in flattened_id_group:
replica_id = flattened_id // num_partitions
partition_id = flattened_id % num_partitions
process_group.append((replica_id, partition_id))
yield process_group
Misalnya, untuk flattened_id_groups = [[0, 1, 2, 3], [4, 5, 6, 7]]
,
num_replicas = 4
, dan num_partitions = 2
, flattened_ids
akan menghasilkan
[[(0, 0), (0, 1), (1, 0), (1, 1)], [(2, 0), (2, 1), (3, 0), (3, 1)]]
.
Akurasi
Saat ini, StableHLO tidak memberikan jaminan tentang akurasi numerik, tetapi hal ini dapat berubah pada masa mendatang (#1156).
Semantik eksekusi operasi kuantisasi
Penafsiran operasi StableHLO yang dikuantisasi dapat bervariasi, bergantung pada persyaratan dan kemampuan hardware. Misalnya, beberapa hardware dapat memilih untuk menafsirkan operasi kuantisasi menggunakan strategi "dequantize, perform floating-point operation, and finally quantize". Yang lain dapat melakukan seluruh komputasi dengan aritmetika bilangan bulat. Oleh karena itu, interpretasi operasi StableHLO yang dikuantisasi ditentukan secara eksklusif oleh penerapan tertentu. Penafsiran kuantisasi campuran (#1575) harus didasarkan pada semantiknya seperti yang ditentukan dalam spesifikasi (melalui 1792).
Error
Program StableHLO divalidasi melalui serangkaian batasan yang ekstensif untuk operasi individual, yang mengesampingkan banyak class error sebelum runtime. Namun, kondisi error masih mungkin terjadi, misalnya melalui overflow bilangan bulat, akses keluar batas, dll. Kecuali jika dinyatakan secara eksplisit, semua error ini akan menghasilkan perilaku yang ditentukan implementasi, tetapi hal ini dapat berubah di masa mendatang (#1157).
Pengecualian floating point
Sebagai pengecualian untuk aturan ini, pengecualian floating point dalam program StableHLO
memiliki perilaku yang ditentukan dengan baik. Operasi yang menghasilkan pengecualian yang ditentukan oleh
standar IEEE-754 (operasi tidak valid, pembagian dengan nol, overflow, underflow, atau
pengecualian yang tidak tepat) menghasilkan hasil default (seperti yang ditentukan dalam standar) dan
melanjutkan eksekusi tanpa menaikkan flag status yang sesuai; mirip dengan
penanganan pengecualian raiseNoFlag
dari standar. Pengecualian untuk operasi nonstandar (misalnya, aritmetika kompleks dan fungsi transendental tertentu) ditentukan oleh implementasi.
Ketidakcocokan bentuk
StabilHLO mendukung tensor berbentuk dinamis. Namun, bentuk harus sesuai pada runtime, jika tidak, perilakunya tidak ditentukan. StabilHLO tidak secara eksplisit menyediakan operasi yang dapat menyatakan bahwa tensor memiliki bentuk tertentu saat runtime. Produsen bertanggung jawab untuk membuat kode yang benar.
Sebagai contoh spesifik, program di bawah valid. Namun, saat runtime, bentuk
%arg0
dan %arg1
harus sama. Jika tidak, perilaku program tidak ditentukan:
func.func @foo(%arg0: tensor<?xi32>, %arg1: tensor<?xi32>) -> tensor<?xi32> {
%0 = stablehlo.add %arg0, %arg1 : tensor<?xi32>
return %0 : tensor<?xi32>
}
Notasi
Untuk menjelaskan sintaksis, dokumen ini menggunakan ragam ISO yang dimodifikasi dari sintaksis EBNF (ISO/IEC 14977:1996, Wikipedia), dengan dua modifikasi: 1) aturan ditentukan menggunakan ::=
, bukan =
,
2) penyambungan dinyatakan menggunakan juxtaposition, bukan ,
.
Untuk mendeskripsikan semantik (yaitu dalam bagian "Jenis", "Konstanta", dan "Ops"), kami menggunakan formula yang didasarkan pada sintaksis Python yang diperluas dengan dukungan untuk mengekspresikan operasi array secara ringkas seperti yang dijelaskan di bawah. Hal ini berfungsi dengan baik untuk cuplikan kode kecil, tetapi dalam kasus yang jarang terjadi saat cuplikan kode yang lebih besar diperlukan, kita menggunakan sintaksis Python vanilla yang selalu diperkenalkan secara eksplisit.
Formula
Mari kita pelajari cara kerja formula berdasarkan contoh dari spesifikasi
dot_general
. Salah satu batasan untuk operasi ini terlihat seperti berikut:
dim(lhs, lhs_batching_dimensions...) = dim(rhs, rhs_batching_dimensions...)
.
Nama yang digunakan dalam formula ini berasal dari dua sumber: 1) fungsi global,
yaitu dim
, 2) definisi anggota elemen program yang sesuai, yaitu
input lhs
, lhs_batching_dimensions
, rhs
, dan rhs_batching_dimensions
yang ditentukan di bagian "Input" di dot_general
.
Seperti yang disebutkan di atas, sintaksis formula ini berbasis Python dengan beberapa ekstensi yang berorientasi pada ringkasan. Untuk memahami formula tersebut, mari kita ubah menjadi sintaks vanilla Python.
A) Dalam formula ini, kita menggunakan =
untuk merepresentasikan kesetaraan, sehingga langkah pertama
untuk mendapatkan sintaksis Python adalah mengganti =
dengan ==
, sebagai berikut:
dim(lhs, lhs_batching_dimensions...) == dim(rhs, rhs_batching_dimensions...)
.
B) Selain itu, formula ini mendukung elipsis (...
) yang mengubah ekspresi skalar
menjadi ekspresi tensor. Singkatnya, f(xs...)
secara kasar berarti "untuk setiap
x
skalar dalam tensor xs
, komputasikan f(x)
skalar, lalu tampilkan semua
hasil skalar ini bersama-sama sebagai hasil tensor". Dalam sintaksis Python vanilla,
formula contoh kita berubah menjadi:
[dim(lhs, dim1) for dim1 in lhs_batching_dimensions] ==
[dim(rhs, dim2) for dim2 in rhs_batching_dimensions]
.
Berkat elipsis, Anda sering kali dapat menghindari bekerja pada tingkat skalar individual. Namun, dalam beberapa kasus rumit, sintaksis semi-informal tingkat rendah
dapat digunakan seperti dalam formula start_indices[bi0, ..., :, ..., biN]
dari spesifikasi gather
. Untuk mempersingkat, kami tidak
memberikan formalisme yang tepat untuk menerjemahkan sintaksis tersebut ke Python vanilla, dengan
harapan bahwa sintaksis tersebut masih dapat dipahami secara intuitif berdasarkan kasus per kasus.
Beri tahu kami jika beberapa formula tertentu terlihat buram, dan kami akan mencoba
meningkatkannya.
Selain itu, Anda akan melihat bahwa formula menggunakan elipsis untuk memperluas semua jenis daftar, termasuk tensor, daftar tensor (yang misalnya dapat muncul dari jumlah variabel tensor), dll. Ini adalah area lain tempat kita tidak memberikan formalisme yang tepat (misalnya, daftar bahkan bukan bagian dari sistem jenis StableHLO) dan sebagai gantinya mengandalkan pemahaman intuitif.
C) Kendaraan notasi penting terakhir yang kami gunakan adalah penyiaran implisit. Meskipun opset StableHLO tidak mendukung siaran implisit, formula mendukungnya, juga untuk layanan ringkas. Singkatnya, jika skalar digunakan dalam konteks yang mengharapkan tensor, skalar akan disiarkan ke bentuk yang diharapkan.
Untuk melanjutkan contoh dot_general
, berikut batasan lainnya:
0 <= lhs_batching_dimensions < rank(lhs)
. Seperti yang ditentukan dalam spesifikasi
dot_general
, lhs_batching_dimensions
adalah tensor, tetapi 0
dan
rank(lhs)
adalah skalar. Setelah kita menerapkan siaran implisit, formulanya akan
menjadi [0, ..., 0] <= lhs_batching_dimensions < [rank(lhs), ..., rank(lhs)]
.
Saat diterapkan ke operasi dot_general
tertentu, formula ini akan
dievaluasi menjadi tensor boolean. Jika formula digunakan sebagai batasan,
batasan akan berlaku jika formula dievaluasi menjadi true
atau tensor yang
hanya memiliki elemen true
.
Nama
Dalam formula, cakupan leksikonis mencakup: 1) fungsi global, 2) definisi anggota,
3) definisi lokal. Daftar fungsi global diberikan di bawah ini. Daftar definisi elemen bergantung pada elemen program tempat notasi diterapkan:
- Untuk operasi, definisi anggota menyertakan nama yang diperkenalkan di bagian "Input" dan "Output".
- Untuk hal lainnya, definisi anggota mencakup bagian struktural dari
elemen program, yang diberi nama sesuai dengan non-terminal EBNF yang sesuai. Sebagian besar
waktu, nama bagian struktural ini diperoleh dengan mengonversi
nama non-terminal ke snake case (misalnya,
IntegerLiteral
=>integer_literal
), tetapi terkadang nama disingkat dalam proses (misalnya,QuantizationStorageType
=>storage_type
) dalam hal ini nama diperkenalkan secara eksplisit mirip dengan bagian "Input"/"Output" dalam spesifikasi operasi. - Selain itu, definisi anggota selalu menyertakan
self
untuk merujuk ke elemen program yang sesuai.
Nilai
Saat dievaluasi, formula akan berfungsi dengan jenis nilai berikut:
1) Value
(nilai sebenarnya, misalnya dense<[[1, 2], [3, 4]]> : tensor<2x2xi32>
;
nilainya selalu diketahui),
2) Placeholder
(nilai mendatang, misalnya lhs
, rhs
, atau result
; nilai sebenarnya
belum diketahui, hanya jenisnya yang diketahui),
3) Type
(jenis seperti yang ditentukan di bagian "Jenis"),
4) Function
(fungsi global seperti yang ditentukan di bagian "Fungsi").
Bergantung pada konteksnya, nama mungkin merujuk ke nilai yang berbeda. Lebih
khususnya, bagian "Semantik" untuk operasi (dan yang setara untuk elemen
program lainnya) menentukan logika runtime, sehingga semua input tersedia sebagai Value
.
Sebaliknya, bagian "Batasan" untuk operasi (dan yang setara) menentukan logika "waktu kompilasi", yaitu sesuatu yang biasanya dieksekusi sebelum runtime, sehingga hanya input konstan yang tersedia sebagai Value
dan input lainnya hanya tersedia sebagai Placeholder
.
Nama | Di "Semantik" | Di "Batasan" |
---|---|---|
Fungsi global | Function |
Function |
Input konstan | Value |
Value |
Input non-konstanta | Value |
Placeholder |
Output | Value |
Placeholder |
Definisi lokal | Bergantung pada definisi | Bergantung pada definisi |
Mari kita lihat contoh operasi transpose
:
%result = "stablehlo.transpose"(%operand) {
permutation = dense<[2, 1, 0]> : tensor<3xi64>
} : (tensor<2x3x2xi32>) -> tensor<2x3x2xi32>
Untuk operasi ini, permutation
adalah konstanta, sehingga tersedia sebagai Value
dalam semantik dan batasan. Sebaliknya, operand
dan result
tersedia sebagai Value
dalam semantik, tetapi hanya sebagai Placeholder
dalam batasan.
Fungsi
Konstruksi jenis
Tidak ada fungsi yang dapat digunakan untuk membuat jenis. Sebagai gantinya, kita langsung
menggunakan sintaksis jenis karena biasanya lebih ringkas. Misalnya,
(tensor<E>, tensor<E>) -> (tensor<E>)
, bukan function_type(
[tensor_type([], E), tensor_type([], E)], [tensor_type([], E)])
.
Fungsi pada jenis
element_type
ditentukan pada jenis tensor dan jenis tensor terkuantisasi, dan masing-masing menampilkan bagianTensorElementType
atauQuantizedTensorElementType
dariTensorType
atauQuantizedTensorType
yang sesuai.
def element_type(x: Value | Placeholder | Type):
if type(x) == TensorType:
return tensor_element_type(x)
if type(x) == QuantizedTensorType:
return quantized_tensor_element_type(x)
if type(x) is not Type:
return element_type(type(x))
is_per_axis_quantized(x: Value | Placeholder | Type) -> Value
adalah pintasan untukis_quantized(x) and quantization_dimension(x) is not None
.is_per_tensor_quantized(x: Value | Placeholder | Type) -> Value
adalah pintasanis_quantized(x) and quantization_dimension(x) is None
.is_promotable(x: Type, y: Type) -> bool
memeriksa apakah jenisx
dapat dipromosikan ke jenisy
. Jikax
dany
adalahQuantizedTensorElementType
, promosi hanya diterapkan kestorage_type
. Versi promosi khusus ini saat ini digunakan dalam konteks komputasi pengurangan (lihat RFC untuk mengetahui detail selengkapnya).
def is_promotable(x: Type, y: Type) -> Value:
is_same_type = (is_bool(x) and is_bool(y)) or
(is_integer(x) and is_integer(y)) or (is_float(x) and is_float(y)) or
(is_complex(x) and is_complex(y)) or
(is_quantized(x) and is_quantized(y) and expressed_type(x) = expressed_type(y))
if is_same_type == False:
return False
if is_integer(x) or is_float(x):
return bitwidth(x) <= bitwidth(y)
if is_complex(x):
return bitwidth(element_type(x)) <= bitwidth(element_type(y))
if is_quantized(x):
return bitwidth(storage_type(x)) <= bitwidth(storage_type(y))
return false
is_quantized(x: Value | Placeholder | Type) -> Value
adalah pintasan untukis_quantized_tensor_element_type(x)
.is_type_name(x: Value | Placeholder | Type) -> Value
. Tersedia untuk semua jenis. Misalnya,is_float(x)
menampilkantrue
jikax
adalahFloatType
. Jikax
adalah nilai atau placeholder, fungsi ini adalah pintasan untukis_type_name(type(x))
.max_value(x: Type) -> Value
menampilkan nilai maksimumTensorElementType
. Jikax
bukanTensorElementType
,None
akan ditampilkan.min_value(x: Type) -> Value
menampilkan nilai minimum yang memungkinkan dariTensorElementType
. Jikax
bukanTensorElementType
,None
akan ditampilkan.member_name(x: Value | Placeholder | Type) -> Any
. Tersedia untuk semua definisi anggotamember_name
dari semua jenis. Misalnya,tensor_element_type(x)
menampilkan bagianTensorElementType
dariTensorType
yang sesuai. Jikax
adalah nilai atau placeholder, fungsi ini adalah pintasan untukmember_name(type(x))
. Jikax
bukan jenis yang memiliki anggota yang sesuai, atau nilai atau placeholder dari jenis tersebut, tampilkanNone
.is_empty_algorithm(*args: Type)
memeriksa apakah semua kolom algoritma titik ditetapkan keNone
. Hal ini diperlukan karena algoritma titik memiliki perilaku default yang ditentukan implementasi, sehingga penentuan nilai default akan salah.
Konstruksi nilai
operation_name(*xs: Value | Type) -> Value
. Tersedia untuk semua operasi. Misalnya,add(lhs, rhs)
mengambil dua nilai tensorlhs
danrhs
dan menampilkan output evaluasi operasiadd
dengan input ini. Untuk beberapa operasi, misalnyabroadcast_in_dim
, jenis outputnya adalah "bearing beban", yaitu yang diperlukan untuk mengevaluasi operasi. Dalam hal ini, fungsi mengambil jenis ini sebagai argumen.
Fungsi pada nilai
Semua operator dan fungsi Python tersedia. Misalnya, notasi subscription dan slicing dari Python tersedia untuk mengindeks ke dalam tensor, tensor kuantisasi, dan tuple.
to_destination_type(x: Value, destination_type: Type) -> Value
ditentukan pada tensor dan menampilkan nilaix
yang dikonversi berdasarkantype(x)
dandestination_type
sebagai berikut:
def to_destination_type(x: Value, destination_type: Type) -> Value:
if type(x) == destination_type:
return x
if is_quantized(destination_type):
if is_quantized(type(x)):
return quantize(x, destination_type)
assert is_float(type(x))
return quantize(x, destination_type)
if is_quantized(type(x)):
assert destination_type = expressed_type(type(x))
return dequantize(type(x))
return convert(x, destination_type)
Ada diskusi awal tentang penggabungan operasi convert
, uniform_quantize
, dan
uniform_dequantize
(#1576).
Setelah penggabungan, kita tidak memerlukan fungsi di atas dan dapat menggunakan nama operasi
untuk convert
.
is_nan(x: Value) -> Value
ditentukan pada tensor dan menampilkantrue
jika semua elemenx
adalahNaN
ataufalse
jika tidak. Jikax
bukan tensor, akan menampilkanNone
.is_sorted(x: Value) -> Value
ditentukan pada tensor dan menampilkantrue
jika elemenx
diurutkan dalam urutan menaik sehubungan dengan urutan leksikografis menaik dari indeksnya, ataufalse
jika tidak. Jikax
bukan tensor,None
akan ditampilkan.is_unique(x: Value) -> Value
ditentukan pada tensor dan menampilkantrue
jikax
tidak memiliki elemen duplikat ataufalse
jika tidak. Jikax
bukan tensor,None
akan ditampilkan.member_name(x: Value) -> Any
ditentukan untuk semua definisi anggotamember_name
dari semua nilai. Misalnya,real_part(x)
menampilkan bagianRealPart
dariComplexConstant
yang sesuai. Jikax
bukan nilai yang memiliki anggota yang sesuai,None
akan ditampilkan.same(x: Value) -> Value
ditentukan pada tensor dan menampilkantrue
jika elemenx
semuanya sama satu sama lain ataufalse
jika tidak. Jika tensor tidak memiliki elemen, hal itu dianggap sebagai "semua sama satu sama lain", yaitu fungsi menampilkantrue
. Jikax
bukan tensor,None
akan ditampilkan.split(x: Value, num_results: Value, axis: Value) -> Value
ditentukan pada tensor dan menampilkan slicenum_results
darix
di sepanjang sumbuaxis
. Jikax
bukan tensor ataudim(x, axis) % num_results != 0
,None
akan ditampilkan.is_defined_in_parent_scope(x: Value) -> Value
ditentukan pada string dan menampilkantrue
jikax
adalah nama fungsi yang ditentukan dalam cakupan yang sama dengan fungsi induk dari operasi yang relevan.is_namespaced_op_name(x: Value) -> Value
ditentukan pada string dan menampilkantrue
jikax
adalah nama operasi yang valid, yaitu mengikuti ekspresi reguler berikut:[a-zA-Z][a-zA-Z0-9_]*([.][a-zA-Z0-9_$]+)+
Komputasi bentuk
axes(x: Value | Placeholder | Type) -> Value
adalah pintasan untukrange(rank(x))
.dim(x: Value | Placeholder | Type, axis: Value) -> Value
adalah pintasan untukshape(x)[axis]
.dims(x: Value | Placeholder | Type, axes: List) -> List
adalah pintasan untuklist(map(lambda axis: dim(x, axis), axes))
.index_space(x: Value | Placeholder | Type) -> Value
ditentukan pada tensor dan menampilkan indekssize(x)
untukTensorType
yang sesuai yang diurutkan dalam urutan leksikografis menaik, yaitu[0, ..., 0]
,[0, ..., 1]
, ...,shape(x) - 1
. Jikax
bukan jenis tensor, jenis tensor kuantisasi, atau nilai atau placeholder dari salah satu jenis ini,None
akan ditampilkan.rank(x: Value | Placeholder | Type) -> Value
adalah pintasan untuksize(shape(x))
.shape(x: Value | Placeholder | Type) -> Value
ditentukan di bagian "Fungsi pada jenis" melaluimember_name
.size(x: Value | Placeholder | Type) -> Value
adalah pintasan untukreduce(lambda x, y: x * y, shape(x))
.
Komputasi kuantisasi
def baseline_element_type(x: Value | Placeholder | Type) -> Type
adalah pintasanelement_type(baseline_type(x))
.baseline_type
ditentukan pada jenis tensor dan jenis tensor terkuantisasi, serta mengubahnya menjadi "dasar pengukuran", yaitu jenis dengan bentuk yang sama tetapi dengan parameter kuantisasi jenis elemen yang direset ke nilai default. Hal ini digunakan sebagai trik praktis untuk membandingkan jenis tensor dan tensor terkuantisasi secara seragam, yang cukup sering diperlukan. Untuk jenis kuantisasi, hal ini memungkinkan perbandingan jenis yang mengabaikan parameter kuantisasi, yaitu,shape
,storage_type
,expressed_type
,storage_min
,storage_max
, danquantization_dimension
(untuk jenis kuantisasi per sumbu) harus cocok, tetapiscales
danzero points
dapat berbeda.
def baseline_type(x: Value | Placeholder | Type) -> Type:
if type(x) == TensorType:
return x
if type(x) == QuantizedTensorType:
element_type = quantized_tensor_element_type(x)
baseline_element_type = QuantizedTensorElementType(
storage_type = storage_type(element_type),
storage_min = storage_min(element_type),
storage_max = storage_max(element_type),
expressed_type = expressed_type(element_type),
quantization_dimension = quantization_dimension(element_type),
scales = [constant(1.0, expressed_type(element_type))] * dim(x, quantization_dimension(element_type)),
zero_points = [constant(0, storage_type(element_type))] * dim(x, quantization_dimension(element_type)))
return QuantizedTensorType(shape(x), baseline_element_type)
if type(x) is not Type:
return baseline_element_type(type(x))
dequantize
ditentukan pada jenis tensor terkuantisasi dan mengubahnya menjadi jenis tensor floating point. Hal ini terjadi melalui konversi elemen kuantisasi yang mewakili nilai bilangan bulat dari jenis penyimpanan menjadi nilai floating point yang sesuai dari jenis yang dinyatakan menggunakan titik nol dan skala yang terkait dengan jenis elemen kuantisasi.
def compute_zero_points(quantized_type, result_type):
if is_per_tensor_quantized(quantized_type):
return broadcast_in_dim(constant(zero_point(quantized_type), storage_type(quantized_type)), [], result_type)
if is_per_axis_quantized(quantized_type):
for i in index_space(result_type):
d = quantization_dimension(quantized_type)
zero_points[i] = zero_points(quantized_type)[i[d]]
return zero_points
def compute_scales(quantized_type, result_type):
if is_per_tensor_quantized(quantized_type):
return broadcast_in_dim(constant(scale(quantized_type), expressed_type(quantized_type)), [],
type(result_type))
if is_per_axis_quantized(quantized_type):
for i in index_space(result_type):
d = quantization_dimension(quantized_type)
scales[i] = scales(quantized_type)[i[d]]
return scales
def dequantize(x: Value) -> Value:
assert is_quantized(x)
x_storage = bitcast_convert(x, storage_type(x))
x_storage_sub = x_storage - compute_zero_points(type(x), type(x_storage))
x_expressed_sub = convert(x_storage_sub, expressed_type(x))
return x_expressed_sub * compute_scales(type(x), type(x_expressed_sub))
quantize
ditentukan pada jenis tensor floating point dan mengubahnya menjadi jenis tensor terkuantisasi. Hal ini terjadi melalui konversi nilai floating point dari jenis yang dinyatakan menjadi nilai bilangan bulat yang sesuai dari jenis penyimpanan menggunakan titik nol dan skala yang terkait dengan jenis elemen yang dikuantisasi.
def quantize(x: Value, result_type: Type) -> Value:
assert is_float(x) and is_quantized(result_type)
zero_points = compute_zero_points(result_type, TensorType(shape(x), storage_type(result_type)))
converted_zero_points = convert(zero_points, expressed_type(result_type))
converted_min = convert(storage_min(result_type), expressed_type(result_type))
converted_max = convert(storage_max(result_type), expressed_type(result_type))
x_scaled = x / compute_scales(result_type, type(x))
x_scaled_add_zp = x_scaled + converted_zero_points
x_clamped = clamp(converted_min, x_scaled_add_zp, converted_max)
x_rounded = round_nearest_even(x_clamped)
return convert(x_rounded, result_type)
dequantize_op_quantize
digunakan untuk menentukan komputasi berbasis elemen pada tensor terkuantisasi. Fungsi ini melakukan dekuantisasi, yaitu mengubah elemen kuantisasi menjadi jenis yang dinyatakan, lalu melakukan operasi, lalu melakukan kuantisasi, yaitu mengubah hasil kembali menjadi jenis penyimpanannya. Saat ini, fungsi ini hanya berfungsi untuk kuantisasi per tensor. Kuantifikasi per sumbu sedang dalam proses (#1574).
def dequantize_op_quantize(op, *inputs_and_output_type):
inputs = inputs_and_output_type[:-1]
output_type = inputs_and_output_type[-1]
float_inputs = map(dequantize, inputs)
float_result = op(*float_inputs)
return quantize(float_result, output_type)
def dequantize_batch_norm_grad_or_training_quantize(op, *inputs_and_output_types):
inputs = inputs_and_output_type[:-3]
float_inputs = map(dequantize, inputs)
float_results = op(*float_inputs)
return map(quantize, float_results, inputs_and_output_type[-3:])
def dequantize_compare(lhs, rhs, comparison_direction):
float_lhs = dequantize(lhs)
float_rhs = dequantize(rhs)
return compare(float_lhs, float_rhs, comparison_direction, FLOAT)
def dequantize_select_quantize(pred, on_true, on_false, output_type):
float_on_true = dequantize(on_true)
float_on_false = dequantize(on_false)
float_result = select(pred, float_on_true, float_on_false)
return quantize(float_result, output_type)
hybrid_dequantize_then_op
digunakan untuk menentukan kuantisasi khusus bobot untuk operasi campuran yang menerima lhs dalam floating point dan rhs dalam jenis kuantisasi. Fungsi ini mendekuantisasikan input yang dikuantisasikan ke dalam jenis yang dinyatakan dan melakukan komputasi dalam float. Jenis elemen tensor lhs float dan jenis yang dinyatakan dari tensor rhs kuantisasi harus identik.
def hybrid_dequantize_then_op(op, lhs, rhs):
assert(is_float(lhs) and is_quantized(rhs) and element_type(lhs) == expressed_type(rhs))
return op(lhs, dequantize(rhs))
Komputasi petak
cross_partition(replica_groups: Value) -> Value
. Lihat bagian "cross_replica" di atas.cross_replica(replica_groups: Value) -> Value
. Lihat bagian "cross_replica" di atas.cross_replica_and_partition(replica_groups: Value) -> Value
. Lihat bagian "cross_replica_and_partition" di atas.flattened_ids(replica_groups: Value) -> Value
. Lihat bagian "flattened_ids" di atas.
Dinamisme
Nilai StableHLO dapat memiliki ukuran dimensi dinamis, misalnya tensor<?xi64>
.
Namun, nilai StableHLO tidak boleh memiliki jumlah dimensi dinamis (dinamisme
tanpa peringkat, misalnya tensor<*xi64>
). Operand dan hasil diizinkan untuk menggunakan ukuran
dimensi dinamis, meskipun ada batasan pada ukuran. Batasan akan
diverifikasi secara statis jika memungkinkan, jika tidak, batasan akan ditangguhkan ke runtime dan
ketidakcocokan akan menyebabkan perilaku yang tidak ditentukan. Lihat contoh berikut.
Ketidakcocokan bentuk untuk operasi elementwise unary
Pertimbangkan program mainan berikut:
func.func @foo(%arg0: tensor<?xf64>) {
%0 = stablehlo.abs %arg0 : (tensor<?xf64>) -> tensor<2xf64>
return
}
Program semacam itu tidak biasa, karena biasanya kita mengetahui bentuk
hasil, tetapi tidak mengetahui bentuk input. Meskipun demikian, ini adalah program StableHLO
yang valid. Operasi abs
dalam program ini tidak dapat divalidasi secara statis karena bentuk operand yang tepat tidak diketahui. Namun, bentuk
pasti kompatibel, dan ini dapat diperiksa secara statis: ?
dapat berubah
menjadi 2
saat runtime, dan tidak akan ada masalah. Namun, ?
juga
dapat berupa beberapa bilangan bulat lainnya, dalam hal ini perilakunya tidak ditentukan.
Perhatikan bahwa jika ukuran dimensi bersifat dinamis dalam hasil, tidak mungkin ada perilaku yang tidak ditentukan. Memang, tidak ada ukuran "yang diharapkan", sehingga tidak akan ada ketidakcocokan.
Ketidakcocokan bentuk untuk operasi elementwise biner
Pertimbangkan program mainan berikut:
func.func @foo(%arg0: tensor<?xf64>, %arg1: tensor<?xf64>) {
%0 = stablehlo.add %arg0, %arg0 : (tensor<?xf64>, tensor<?xf64>) -> tensor<?xf64>
return
}
Dalam hal operasi elementwise biner, bentuk input dan hasil harus sesuai saat runtime. Pada waktu kompilasi, dimensi statis harus sama, jika tidak, dimensi tersebut hanya perlu kompatibel. Jika ada dimensi apa pun yang bersifat dinamis dalam input, mungkin ada perilaku yang tidak ditentukan saat runtime, karena ukuran dinamis mungkin tidak cocok dengan ukuran yang sesuai dalam operand lain (baik statis maupun dinamis). Jika semua input bersifat statis, hasil yang dinamis atau tidak tidak akan menjadi masalah: dimensi yang diketahui secara statis akan diperiksa secara statis, dan dimensi dinamis tidak akan memberlakukan batasan apa pun.
Ketidakcocokan bentuk untuk operasi yang mengambil bentuk output-nya sebagai operand
Pertimbangkan program mainan berikut:
func.func @foo(%arg0: tensor<2xi32>) {
%0 = stablehlo.dynamic_iota %arg0, dim = 0 : (tensor<2xi32>) -> tensor<3x4xi64>
return
}
Nilai dalam operand bentuk saat runtime harus cocok dengan bentuk hasilnya.
Jika tidak, perilakunya tidak ditentukan. Artinya, saat runtime, %arg0
harus memiliki
nilai dense<[3, 4]> : tensor<2xi32>
. Jika operand bentuk konstan, hal ini
dapat diverifikasi secara statis. Jika bentuk hasilnya sepenuhnya dinamis, tidak
akan ada ketidakcocokan.