StabilHLO adalah operasi yang ditetapkan untuk operasi tingkat tinggi (HLO) di mesin machine learning (ML). StabilHLO berfungsi sebagai lapisan portabilitas antara berbagai Framework ML dan compiler ML: framework ML yang menghasilkan program StableHLO kompatibel dengan compiler ML yang menggunakan program StableHLO.
Tujuan kami adalah menyederhanakan dan mempercepat pengembangan ML dengan membuat lebih banyak interoperabilitas antara berbagai framework ML (seperti TensorFlow, JAX, PyTorch) dan compiler ML (seperti XLA dan IREE). Untuk mencapai tujuan tersebut, dokumen menyediakan spesifikasi untuk bahasa pemrograman StableHLO.
Spesifikasi ini memuat tiga bagian utama. Pertama, Bagian Programs menjelaskan struktur program StabilHLO yang terdiri dari fungsi StableHLO yang terdiri dari operasi StableHLO. Dalam struktur tersebut, bagian Ops menentukan semantik operasi individual. Bagian Execution menyediakan semantik untuk semua operasi ini yang dieksekusi bersama dalam sebuah program. Terakhir, Bagian Notasi membahas notasi yang digunakan dalam spesifikasi pendukung.
Untuk melihat spesifikasi dari rilis StableHLO sebelumnya, buka repo di rilis yang diberi tag minat. Misalnya, StableHLO v0.19.0 Spec. Untuk melihat perubahan yang terjadi di setiap bump versi minor StableHLO, lihat log versi di VhloDialect.td.
Program
Program ::= {Func}
Program SttableHLO terdiri dari sejumlah arbitrer fungsi StableHLO.
Berikut adalah contoh program dengan fungsi @main
yang memiliki 3 input
(%image
, %weights
, dan %bias
) dan 1 output. Isi fungsi
memiliki 6 operasi.
func.func @main(
%image: tensor<28x28xf32>,
%weights: tensor<784x10xf32>,
%bias: tensor<1x10xf32>
) -> tensor<1x10xf32> {
%0 = "stablehlo.reshape"(%image) : (tensor<28x28xf32>) -> tensor<1x784xf32>
%1 = "stablehlo.dot"(%0, %weights) : (tensor<1x784xf32>, tensor<784x10xf32>) -> tensor<1x10xf32>
%2 = "stablehlo.add"(%1, %bias) : (tensor<1x10xf32>, tensor<1x10xf32>) -> tensor<1x10xf32>
%3 = "stablehlo.constant"() {value = dense<0.0> : tensor<1x10xf32>} : () -> tensor<1x10xf32>
%4 = "stablehlo.maximum"(%2, %3) : (tensor<1x10xf32>, tensor<1x10xf32>) -> tensor<1x10xf32>
"func.return"(%4): (tensor<1x10xf32>) -> ()
}
Fungsi
Func ::= 'func' '.' 'func' FuncId FuncInputs FuncOutputs '{' FuncBody '}'
FuncInputs ::= '(' [FuncInput {',' FuncInput}] `)`
FuncInput ::= ValueId ':' ValueType
FuncOutputs ::= ['->' FuncOutput, {',' FuncOutput}]
FuncOutput ::= ValueType
FuncBody ::= {Op}
Fungsi SttableHLO (yang juga disebut fungsi bernama) memiliki ID, input/{i>output<i} dan {i>body<i}. Di masa mendatang, kami berencana memperkenalkan metadata tambahan untuk fungsi guna mencapai kompatibilitas yang lebih baik dengan HLO (#425, #626, #740, #744).
Pengenal
FuncId ::= '@' letter {letter | digit}
ValueId ::= '%' digit {digit}
| '%' letter {letter | digit}
letter ::= 'a' | ... | 'z' | 'A' | ... | 'Z' | '_'
digit ::= '0' | ... | '9'
ID SttableHLO mirip dengan ID dalam banyak pemrograman bahasa, dengan dua kekhasan: 1) semua pengenal memiliki {i>sigil<i} yang membedakan berbagai jenis pengenal, 2) pengenal nilai dapat lengkap bersifat numerik untuk menyederhanakan pembuatan program StableHLO.
Jenis
Type ::= ValueType | NonValueType
ValueType ::= TensorType | QuantizedTensorType | TokenType | TupleType
NonValueType ::= TensorElementType | QuantizedTensorElementType | FunctionType | StringType
Jenis SttableHLO dikategorikan ke dalam jenis nilai (yang juga disebut jenis kelas satu) yang mewakili nilai StabilHLO dan jenis bukan nilai yang menggambarkan elemen program lainnya. Jenis StableHLO mirip dengan jenis pada banyak bahasa pemrograman, dengan keunikan utamanya adalah yang bersifat khusus domain yang memberikan hasil yang tidak biasa (mis. jenis skalar bukan jenis nilai).
TensorType ::= 'tensor' '<' Shape TensorElementType '>'
Shape ::= {DimensionSize 'x'}
DimensionSize ::= digit {digit} | '?'
Jenis tensor mewakili tensor, yaitu array multidimensi. Aset tersebut memiliki
bentuk dan jenis elemen, dengan sebuah bentuk merepresentasikan bilangan non-negatif atau
ukuran dimensi yang tidak diketahui dalam urutan menaik
dimensi (yang juga disebut sumbu) yang diberi nomor dari 0
hingga R-1
. Tujuan
jumlah dimensi R
disebut peringkat. Misalnya, tensor<2x3xf32>
adalah
jenis tensor dengan bentuk 2x3
dan jenis elemen f32
. Model ini memiliki dua dimensi
(atau, dengan kata lain, dua sumbu) - dimensi ke-0 dan dimensi ke-1 - yang ukurannya
adalah 2 dan 3. Peringkatnya 2.
Bentuk bisa tidak diketahui sebagian atau seluruhnya (dinamis), mis. tensor<?x2xf64>
tidak diketahui sebagian dan tensor<?x?xf64>
sama sekali tidak diketahui. Dinamis
ukuran dimensi ditampilkan menggunakan ?
. Bentuk tidak boleh diberi peringkat.
Di masa mendatang, kami berencana mempelajari jenis tensor yang diperluas di luar ukuran dimensi dan jenis elemen, misalnya, untuk menyertakan tata letak (#629) dan ketersebaran (#1078).
QuantizedTensorType ::= 'tensor' '<' Shape QuantizedTensorElementType '>'
QuantizedTensorElementType ::= '!quant.uniform' '<'
QuantizationStorageType
['<' QuantizationStorageMin ':' QuantizationStorageMax '>']
':' QuantizationExpressedType
[':' QuantizationDimension]
',' QuantizationParameters '>'
QuantizationStorageType ::= IntegerType
QuantizationStorageMin ::= IntegerConstant
QuantizationStorageMax ::= IntegerConstant
QuantizationExpressedType ::= FloatType
QuantizationDimension ::= IntegerConstant
QuantizationParameters ::= QuantizationParameter
| '{' QuantizationParameter {',' QuantizationParameter} '}'
QuantizationParameter ::= QuantizationScale ':' QuantizationZeroPoint
QuantizationScale ::= FloatConstant
QuantizationZeroPoint ::= IntegerConstant
Nama | Jenis | Batasan |
---|---|---|
storage_type |
jenis bilangan bulat | (C1-C3), (C8) |
storage_min |
konstanta bilangan bulat | (C1), (C3), (C7) |
storage_max |
konstanta bilangan bulat | (C2), (C3), (C7) |
expressed_type |
jenis floating point | (C4) |
quantization_dimension |
konstanta bilangan bulat opsional | (C10-C12) |
scales |
jumlah variabel konstanta floating point | (C4-C6), (C9), (C10), (C13) |
zero_points |
jumlah variabel konstanta bilangan bulat | (C7-C9) |
Jenis elemen terkuantisasi mewakili nilai bilangan bulat dari jenis penyimpanan di
rentang dari storage_min
hingga storage_max
(inklusif) yang sesuai dengan
nilai floating point dari jenis yang dinyatakan. Untuk nilai bilangan bulat i
yang diberikan,
nilai floating point f
yang sesuai dapat dikomputasi sebagai
f = (i - zero_point) * scale
, yang mana scale
dan zero_point
dipanggil
parameter kuantisasi. storage_min
dan storage_max
bersifat opsional
dalam tata bahasa, tetapi memiliki nilai default min_value(storage_type)
dan
max_value(storage_type)
. Jenis elemen terkuantisasi memiliki
batasan berikut:
- (C1)
type(storage_min) = storage_type
. - (C2)
type(storage_max) = storage_type
. - (C3)
min_value(storage_type) <= storage_min < storage_max <= max_value(storage_type)
. - (C4)
type(scales...) = expressed_type
. - (C5)
0 < scales
. - (C6)
is_finite(scales...)
. - (C7)
storage_min <= zero_points <= storage_max
. - (C8)
type(zero_points...) = storage_type
. - (C9)
size(scales) = size(zero_points)
. - (C10) Jika
is_empty(quantization_dimension)
, makasize(scales) = 1
. - (C11)
0 <= quantization_dimension
.
Saat ini, QuantizationScale
adalah konstanta floating point, tetapi ada
kepentingan yang kuat dalam skala berbasis bilangan bulat, diwakili dengan pengganda dan
{i>shift<i}. Kami berencana untuk menjelajahi fitur ini dalam waktu dekat
(#1404).
Ada diskusi berkelanjutan tentang semantik QuantizationZeroPoint
,
termasuk jenisnya, nilai-nilainya, dan apakah mungkin hanya ada satu atau
dan mungkin beberapa titik nol
dalam tipe tensor terkuantisasi. Berdasarkan
hasil diskusi ini, spesifikasi seputar titik nol dapat berubah
pada masa mendatang (#1405).
Diskusi berkelanjutan lainnya melibatkan semantik QuantizationStorageMin
dan QuantizationStorageMax
untuk menentukan apakah ada batasan
dikenakan pada nilai-nilai ini dan pada nilai-nilai tensor terkuantisasi
(#1406).
Terakhir, kita berencana untuk merepresentasikan skala yang tidak diketahui dan nol poin, sama seperti perencanaan kita dalam mengeksplorasi cara untuk merepresentasikan ukuran dimensi (#1407).
Jenis tensor terkuantisasi mewakili tensor dengan elemen terkuantisasi. Ini tensor sama persis dengan tensor reguler, hanya saja memiliki jenis elemen terkuantisasi, bukan jenis elemen reguler.
Pada tensor terkuantisasi, kuantisasi bisa berupa per-tensor, artinya
satu scale
dan zero_point
untuk seluruh tensor atau dapat berupa per-axis,
artinya, memiliki beberapa scales
dan zero_points
, satu pasang per irisan
dimensi tertentu quantization_dimension
. Secara lebih formal, dalam tensor t
dengan kuantisasi per sumbu, terdapat dim(t, quantization_dimension)
irisan
dari quantization_dimension
: t[:, ..., 0, ..., :], t[:, ..., 1, ..., :]
,
dll. Semua elemen dalam slice i
menggunakan scales[i]
dan zero_points[i]
sebagai
parameter kuantisasinya. Jenis tensor terkuantisasi memiliki
batasan:
- Untuk kuantisasi per-tensor:
- Tidak ada batasan tambahan.
- Untuk kuantisasi per sumbu:
- (C12)
quantization_dimension < rank(self)
. - (C13)
dim(self, quantization_dimension) = size(scales)
.
- (C12)
TokenType ::= 'token'
Jenis token mewakili token, yaitu nilai buram yang dihasilkan dan dipakai oleh beberapa operasi. Token digunakan untuk menerapkan urutan eksekusi pada operasi seperti yang dijelaskan di bagian Eksekusi.
TupleType ::= 'tuple' '<' TupleElementTypes '>'
TupleElementTypes ::= [ValueType {',' ValueType}]
Jenis tuple mewakili tuple, yaitu daftar heterogen. Tuple adalah warisan
yang hanya ada untuk kompatibilitas dengan HLO. Di HLO, tuple adalah
digunakan untuk merepresentasikan
input dan {i>output <i}variadic. Dalam StableHLO, input variadic dan
didukung secara native, dan satu-satunya penggunaan tuple di StableHLO adalah untuk
secara komprehensif merepresentasikan ABI HLO, mis. T
, tuple<T>
, dan
tuple<tuple<T>>
mungkin secara signifikan berbeda bergantung pada
terlepas dari implementasi layanan. Di masa mendatang, kami berencana melakukan perubahan pada HLO ABI
sehingga kami dapat menghapus jenis tuple dari StableHLO
(#598).
TensorElementType ::= BooleanType | IntegerType | FloatType | ComplexType
BooleanType ::= 'i1'
IntegerType ::= SignedIntegerType | UnsignedIntegerType
SignedIntegerType ::= 'si2' | 'si4' | 'si8' | 'si16' | 'si32' | 'si64'
UnsignedIntegerType ::= 'ui2' | 'ui4' | 'ui8' | 'ui16' | 'ui32' | 'ui64'
FloatType ::= 'f8E4M3FN' | 'f8E5M2' | 'f8E4M3FNUZ' | 'f8E5M2FNUZ'
| 'f8E4M3B11FNUZ' | 'bf16' | 'f16' | 'f32' | 'f64'
TensorFloat32 ::= 'tf32'
ComplexType ::= 'complex' '<' ComplexElementType '>'
ComplexElementType ::= 'f32' | 'f64'
Jenis elemen mewakili elemen jenis tensor. Tidak seperti di banyak pemrograman
bahasa, jenis ini bukan kelas satu di StableHLO. Hal ini berarti bahwa
Program StabilHLO tidak dapat langsung merepresentasikan nilai jenis ini (akibatnya,
merepresentasikan nilai skalar jenis T
dengan tensor 0 dimensi
nilai dari jenis tensor<T>
).
- Jenis boolean mewakili nilai boolean
true
danfalse
. - Jenis bilangan bulat dapat ditandatangani (
si
) atau tidak ditandatangani (ui
) dan memiliki salah satu lebar bit yang didukung (2
,4
,8
,16
,32
, atau64
). JenissiN
bertanda tangan mewakili nilai bilangan bulat dari-2^(N-1)
hingga2^(N-1)-1
jenisuiN
inklusif, dan tidak ditandatangani mewakili nilai bilangan bulat dari0
hingga2^N-1
inklusif. - Jenis floating point dapat berupa salah satu dari hal berikut:
- Jenis
f8E4M3FN
danf8E5M2
yang sesuai dengan masing-masing EncodingE4M3
danE5M2
untuk format FP8 yang dijelaskan di Format FP8 untuk Deep Learning. - Jenis
f8E4M3FNUZ
danf8E5M2FNUZ
yang sesuai denganE4M3
danE5M2
pengkodean format FP8 yang dijelaskan dalam Format Numerik 8-bit untuk Jaringan Neural Dalam. - Jenis
f8E4M3B11FNUZ
yang sesuai dengan encodingE4M3
untuk format FP8 dijelaskan dalam Pelatihan dan Inferensi Hybrid 8-bit Floating Point (HFP8) untuk Jaringan Neural Dalam. - Jenis
bf16
yang sesuai dengan formatbfloat16
yang dijelaskan di BFloat16: Rahasia performa tinggi pada Cloud TPU. - Jenis
f16
,f32
, danf64
yang sesuai masing-masingbinary16
("presisi setengah"),binary32
("presisi tunggal") dan Formatbinary64
("presisi ganda") yang dijelaskan di standar IEEE 754. - Jenis
tf32
sesuai dengan format TensorFloat32 dan memiliki dukungan terbatas di StableHLO.
- Jenis
- Jenis kompleks merepresentasikan nilai kompleks yang memiliki bagian nyata
dan bagian imajiner dari jenis elemen yang sama. Kompleks yang didukung
adalah
complex<f32>
(kedua bagian dari jenisf32
) dancomplex<f64>
(kedua bagian adalah jenisf64
).
FunctionType ::= '(' InputTypes ')' '->' '(' OutputTypes ')'
InputTypes ::= [ValueType {',' ValueType}]
OutputTypes ::= [ValueType {',' ValueType}]
Jenis fungsi mewakili fungsi bernama dan anonim. Mereka memiliki
jenis input (daftar jenis di sebelah kiri ->
) dan jenis output
(daftar jenis di sisi kanan ->
). Dalam banyak pemrograman
bahasa, jenis fungsi adalah class satu, tetapi tidak di StableHLO.
StringType ::= 'string'
Jenis string mewakili urutan byte. Tidak seperti di banyak pemrograman bahasa, jenis string bukan kelas satu di StableHLO dan hanya digunakan untuk menentukan {i>metadata<i} statis untuk elemen program.
Operasi
Operasi SttableHLO (yang juga disebut ops) mewakili kumpulan tertutup operasi tingkat tinggi dalam model machine learning. Seperti yang dibahas di atas, Sintaksis StableHLO sangat terinspirasi oleh MLIR, yang belum tentu merupakan ergonomis, tetapi mungkin lebih cocok untuk tujuan StableHLO sehingga meningkatkan interoperabilitas antara framework ML dan compiler ML.
Op ::= [OpOutputs] OpName OpInputs ':' OpSignature
OpName ::= '"' 'stablehlo' '.' OpMnemonic '"'
OpMnemonic ::= 'abs' | 'add' | ...
Operasi SttableHLO (yang juga disebut ops) memiliki nama,
input/{i>output<i} dan tanda tangan. Nama terdiri dari awalan stablehlo.
dan
mnemonic yang secara unik mengidentifikasi salah satu operasi yang didukung. Lihat di bawah untuk
daftar lengkap dari semua operasi yang didukung.
OpInputs ::= OpInputValues OpInputFuncs OpInputAttrs
OpInputValues ::= '(' [OpInputValue {',' OpInputValue}] ')'
OpInputValue ::= ValueId
OpInputFuncs ::= ['(' OpInputFunc {',' OpInputFunc} ')']
OpInputAttrs ::= ['{' OpInputAttr {',' OpInputAttr} '}']
OpOutputs ::= [OpOutput {',' OpOutput} '=']
OpOutput ::= ValueId
Operasi menggunakan input dan menghasilkan output. Input dikategorikan ke dalam
nilai input (dihitung selama eksekusi), fungsi input (disediakan
secara statis, karena dalam StableHLO bukan merupakan nilai kelas satu) dan
atribut input (juga disediakan secara statis). Jenis input dan output
dikonsumsi dan diproduksi oleh operasi
bergantung pada mnemoniknya. Misalnya, add
op menggunakan 2 nilai input dan menghasilkan 1 nilai output. Sebagai perbandingan,
Operasi select_and_scatter
menggunakan 3 nilai input, 2 fungsi input, dan
3 atribut input.
OpInputFunc ::= '{' Unused FuncInputs ':' FuncBody '}'
Unused ::= '^' digit {digit}
| '^' letter {letter | digit}
Fungsi input (yang juga disebut fungsi anonim) sangat
mirip dengan fungsi bernama kecuali bahwa: 1) mereka tidak memiliki ID (oleh karena itu
nama "anonim"), 2) mereka tidak mendeklarasikan jenis output (jenis output
disimpulkan dari operasi return
dalam fungsi).
Sintaks untuk fungsi input menyertakan bagian yang saat ini tidak digunakan (lihat
Unused
di atas) yang tersedia untuk kompatibilitas dengan MLIR. Di MLIR,
ada konsep yang lebih umum tentang "kawasan" yang dapat memiliki beberapa "blok"
operasi yang terhubung
bersama melalui operasi lompatan. Blok ini memiliki ID yang sesuai
ke produksi Unused
, agar dapat dibedakan satu sama lain.
StabilHLO tidak memiliki operasi lompatan, jadi bagian yang sesuai dari sintaks MLIR adalah
tidak digunakan (tetapi masih ada).
OpInputAttr ::= OpInputAttrName '=' OpInputAttrValue
OpInputAttrName ::= letter {letter | digit}
OpInputAttrValue ::= Constant
Atribut input memiliki nama dan nilai yang merupakan salah satu atribut yang didukung
konstanta. Mereka adalah cara utama untuk menentukan metadata statis untuk
yang kurang penting. Misalnya, operasi concatenate
menggunakan atribut dimension
untuk
menentukan dimensi yang menggabungkan nilai inputnya. Demikian pula,
Operasi slice
menggunakan beberapa atribut seperti start_indices
dan limit_indices
untuk menentukan batas yang digunakan untuk memotong nilai input.
Saat ini, program StableHLO di luar organisasi terkadang berisi atribut yang tidak dijelaskan dalam dokumen ini. Di masa mendatang, kami berencana menyerap atribut ini ke dalam opset StableHLO atau melarangnya yang muncul di program StableHLO. Sementara itu, berikut daftar atribut:
layout
(#629).mhlo.frontend_attributes
(#628).mhlo.sharding
(#619).output_operand_aliases
(#740).- Metadata lokasi (#594).
OpSignature ::= '(' [ValueType {',' ValueType}] ')' '->' '(' [ValueType {',' ValueType}] ')'
Tanda tangan op terdiri dari jenis semua nilai input (daftar jenis di
di sisi kiri ->
) dan jenis semua nilai output (daftar
jenis di sisi kanan ->
). Sebenarnya, tipe input
tipe output hampir selalu redundan (karena untuk
sebagian besar operasi StableHLO, jenis output dapat disimpulkan dari input). Meskipun demikian, op
signature sengaja dijadikan bagian dari sintaksis StableHLO untuk kompatibilitas dengan MLIR.
Di bawah ini adalah contoh operasi yang mnemoniknya adalah select_and_scatter
. Menggunakan 3
nilai input (%operand
, %source
, dan %init_value
), 2 fungsi input
dan 3 atribut input (window_dimensions
, window_strides
, dan padding
).
Perhatikan bagaimana tanda tangan operasi hanya menyertakan jenis nilai inputnya
(tetapi bukan jenis fungsi dan atribut input yang disediakan secara inline).
%result = "stablehlo.select_and_scatter"(%operand, %source, %init_value) ({
^bb0(%arg0: tensor<i32>, %arg1: tensor<i32>):
%0 = "stablehlo.compare"(%arg0, %arg1) {
comparison_direction = #stablehlo<comparison_direction GE>
} : (tensor<i32>, tensor<i32>) -> tensor<i1>
"stablehlo.return"(%0) : (tensor<i1>) -> ()
}, {
^bb0(%arg0: tensor<i32>, %arg1: tensor<i32>):
%0 = "stablehlo.add"(%arg0, %arg1) : (tensor<i32>, tensor<i32>) -> tensor<i32>
"stablehlo.return"(%0) : (tensor<i32>) -> ()
}) {
window_dimensions = dense<[3, 1]> : tensor<2xi64>,
window_strides = dense<[2, 1]> : tensor<2xi64>,
padding = dense<[[0, 1], [0, 0]]> : tensor<2x2xi64>
} : (tensor<4x2xi32>, tensor<2x2xi32>, tensor<i32>) -> tensor<4x2xi32>
Konstanta
Constant ::= BooleanConstant
| IntegerConstant
| FloatConstant
| ComplexConstant
| TensorConstant
| QuantizedTensorConstant
| StringConstant
| EnumConstant
Konstanta SttableHLO memiliki literal dan jenis yang bersama-sama mewakili
nilai StableHLO. Umumnya, jenis ini adalah bagian dari sintaks konstan, kecuali
jika tidak ambigu (misalnya, konstanta boolean yang jelas memiliki jenis i1
,
sedangkan konstanta bilangan bulat dapat memiliki beberapa kemungkinan jenis).
BooleanConstant ::= BooleanLiteral
BooleanLiteral ::= 'true' | 'false'
Konstanta Boolean mewakili nilai boolean true
dan false
. Boolean
konstanta memiliki jenis i1
.
IntegerConstant ::= IntegerLiteral ':' IntegerType
IntegerLiteral ::= ['-' | '+'] DecimalDigits
| ['-' | '+'] '0x' HexadecimalDigits
DecimalDigits ::= decimalDigit {decimalDigit}
HexadecimalDigits ::= hexadecimalDigit {hexadecimalDigit}
decimalDigit ::= '0' | ... | '9'
hexadecimalDigit ::= decimalDigit | 'a' | ... | 'f' | 'A' | ... | 'F'
Konstanta bilangan bulat mewakili nilai bilangan bulat melalui string yang menggunakan bilangan desimal atau notasi heksadesimal. Basis lainnya, misalnya biner atau oktal, tidak didukung. Konstanta bilangan bulat memiliki batasan berikut:
- (C1)
is_wellformed(integer_literal, integer_type)
.
FloatConstant ::= FloatLiteral ':' FloatType
FloatLiteral ::= SignPart IntegerPart FractionalPart ScientificPart
| '0x' [HexadecimalDigits]
SignPart ::= ['-' | '+']
IntegerPart ::= DecimalDigits
FractionalPart ::= ['.' [DecimalDigits]]
ScientificPart ::= [('e' | 'E') ['-' | '+'] DecimalDigits]
Konstanta floating point merepresentasikan nilai floating point melalui string yang menggunakan desimal atau notasi ilmiah. Selain itu, notasi heksadesimal dapat digunakan untuk menentukan bit yang mendasarinya secara langsung dalam format floating point jenis yang sesuai. Konstanta floating point memiliki batasan berikut:
- (C1) Jika notasi non-heksadesimal digunakan,
is_wellformed(float_literal, float_type)
. - (C2) Jika notasi heksadesimal digunakan,
size(hexadecimal_digits) = num_bits(float_type) / 4
.
ComplexConstant ::= ComplexLiteral ':' ComplexType
ComplexLiteral ::= '(' RealPart ',' ImaginaryPart ')'
RealPart ::= FloatLiteral
ImaginaryPart ::= FloatLiteral
Konstanta kompleks merepresentasikan nilai kompleks menggunakan daftar bagian nyata
(ditempatkan lebih dahulu) dan bagian imajiner (ditempatkan kedua). Misalnya,
(1.0, 0.0) : complex<f32>
mewakili 1.0 + 0.0i
, dan
(0.0, 1.0) : complex<f32>
mewakili 0.0 + 1.0i
. Urutan dari
yang kemudian disimpan di memori
adalah yang ditentukan implementasinya. Konstanta kompleks
memiliki batasan berikut:
- (C1)
is_wellformed(real_part, complex_element_type(complex_type))
. - (C2)
is_wellformed(imaginary_part, complex_element_type(complex_type))
.
TensorConstant ::= TensorLiteral ':' TensorType
TensorLiteral ::= 'dense' '<' (DenseLiteral | ElementLiteral) '>'
DenseLiteral ::= DenseDimension | DenseElements
DenseDimension ::= '[' [DenseLiteral {',' DenseLiteral}] ']'
DenseElements ::= [ElementLiteral {',' ElementLiteral}]
ElementLiteral ::= BooleanLiteral | IntegerLiteral | FloatLiteral | ComplexLiteral
Konstanta Tensor mewakili nilai tensor menggunakan daftar bertingkat yang ditentukan melalui
Notasi NumPy. Misalnya, dense<[[1, 2, 3], [4, 5, 6]]> : tensor<2x3xi32>
mewakili nilai tensor dengan pemetaan berikut dari indeks ke elemen:
{0, 0} => 1
, {0, 1} => 2
, {0, 2} => 3
, {1, 0} => 4
, {1, 1} => 5
,
{1, 2} => 6
. Urutan elemen-elemen ini kemudian disimpan dalam memori adalah
ditentukan oleh implementasi. Konstanta tensor memiliki batasan berikut:
- (C1)
has_syntax(tensor_literal, element_type(tensor_type))
, dengan kondisi:has_syntax(element_literal: Syntax, element_type: Type) = is_wellformed(element_literal, type)
.has_syntax(tensor_literal: List, element_type: Type) = has_syntax(tensor_literal..., element_type)
.
- (C2)
has_shape(tensor_literal, shape(tensor_type))
, dengan kondisi:has_shape(element_literal: Syntax, []) = true
.has_shape(tensor_literal: List, shape: List) = size(tensor_literal) = shape[0] and has_shape(tensor_literal..., shape[1:])
.- jika tidak,
false
.
QuantizedTensorConstant ::= QuantizedTensorLiteral ':' QuantizedTensorType
QuantizedTensorLiteral ::= 'dense' '<' (DenseLiteral | ElementLiteral) '>'
Konstanta tensor terkuantisasi mewakili nilai tensor terkuantisasi menggunakan notasi sebagai konstanta tensor, dengan elemen yang ditetapkan sebagai konstanta jenis penyimpanan. Konstanta tensor terkuantisasi memiliki batasan berikut:
- (C1)
has_syntax(quantized_tensor_literal, storage_type(quantized_tensor_type))
. - (C2)
has_shape(quantized_tensor_literal, shape(quantized_tensor_type))
.
StringConstant ::= StringLiteral
StringLiteral ::= '"' {stringCharacter | escapeSequence} '"'
stringCharacter ::= all ASCII characters except '\00', '\01', ... '\1f' and '"'
escapeSequence ::= '\' ('"' | '\' | 'n' | 't' | (hexadecimalDigit hexadecimalDigit))
Literal string terdiri dari byte yang ditentukan menggunakan karakter ASCII dan
urutan escape. Model ini bersifat agnostik encoding, jadi interpretasinya
byte ditentukan oleh implementasi. Literal string memiliki jenis string
.
Operasi
abs
Semantik
Melakukan operasi abs berbasis element-wise pada tensor operand
dan menghasilkan result
tensor. Bergantung pada jenis elemen, melakukan hal berikut:
- Untuk bilangan bulat bertanda tangan: modulus bilangan bulat.
- Untuk float:
abs
dari IEEE-754. - Untuk bilangan kompleks: modulus kompleks.
- Untuk jenis terkuantisasi:
dequantize_op_quantize(abs, operand, type(result))
.
Input
Label | Nama | Jenis | Batasan |
---|---|---|---|
(I1) | operand |
tensor bilangan bulat bertanda tangan, floating point, atau tipe kompleks atau tensor terkuantisasi per-tensor | (C1-C2) |
Output
Nama | Jenis | Batasan |
---|---|---|
result |
tensor bilangan bulat bertanda tangan atau jenis floating point atau tensor terkuantisasi per-tensor | (C1-C2) |
Batasan
- (C1)
shape(result) = shape(operand)
. - (C2)
baseline_element_type(result)
didefinisikan sebagai:complex_element_type(element_type(operand))
jikais_complex(operand)
.baseline_element_type(operand)
sebaliknya.
Contoh
// %operand: [-2, 0, 2]
%result = "stablehlo.abs"(%operand) : (tensor<3xi32>) -> tensor<3xi32>
// %result: [2, 0, 2]
tambahkan
Semantik
Melakukan penambahan berdasarkan elemen pada dua tensor lhs
dan rhs
serta menghasilkan
Tensor result
. Bergantung pada jenis elemen, melakukan hal berikut:
- Untuk boolean: OR logis.
- Untuk bilangan bulat: penambahan bilangan bulat.
- Untuk float:
addition
dari IEEE-754. - Untuk bilangan kompleks: penambahan kompleks.
- Untuk jenis terkuantisasi:
dequantize_op_quantize(add, lhs, rhs, type(result))
.
Input
Label | Nama | Jenis | Batasan |
---|---|---|---|
(I1) | lhs |
tensor atau quantized tensor | (C1-C6) |
(I2) | rhs |
tensor atau quantized tensor | (C1-C5), (C7) |
Output
Nama | Jenis | Batasan |
---|---|---|
result |
tensor atau quantized tensor | (C1-C7) |
Batasan
- Jika operasi tersebut menggunakan tensor non-kuantisasi:
- (C1)
type(lhs) = type(rhs) = type(result)
.
- (C1)
- Jika operasi tersebut menggunakan tensor terkuantisasi:
- (C2)
is_quantized(lhs) and is_quantized(rhs) and is_quantized(result)
. - (C3)
storage_type(lhs) = storage_type(rhs) = storage_type(result)
. - (C4)
expressed_type(lhs) = expressed_type(rhs) = expressed_type(result)
. - (C5)
(is_per_axis_quantized(lhs) or is_per_axis_quantized(rhs)) = is_per_axis_quantized(result)
. - (C6) Jika
is_per_axis_quantized(lhs)
, makaquantization_dimension(lhs) = quantization_dimension(result)
. - (C7) Jika
is_per_axis_quantized(rhs)
, makaquantization_dimension(rhs) = quantization_dimension(result)
.
- (C2)
Contoh
// %lhs: [[1, 2], [3, 4]]
// %rhs: [[5, 6], [7, 8]]
%result = "stablehlo.add"(%lhs, %rhs) : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
// %result: [[6, 8], [10, 12]]
after_all
Semantik
Memastikan operasi yang menghasilkan inputs
dijalankan sebelum
operasi yang bergantung pada result
. Eksekusi operasi ini tidak melakukan apa pun,
fungsi tersebut hanya ada untuk membuat dependensi data dari result
hingga inputs
.
Input
Label | Nama | Jenis |
---|---|---|
(I1) | inputs |
bilangan variadis token |
Output
Nama | Jenis |
---|---|
result |
token |
Contoh
// %input0: !stablehlo.token
// %input1: !stablehlo.token
%result = "stablehlo.after_all"(%input0, %input1) : (!stablehlo.token, !stablehlo.token) -> !stablehlo.token
all_gather
Semantik
Dalam setiap grup proses di grid proses StableHLO, merangkai nilai-nilai
tensor operands
dari setiap proses di sepanjang all_gather_dim
dan menghasilkan
Tensor results
.
Operasi tersebut membagi grid proses StableHLO menjadi process_groups
yang
didefinisikan sebagai berikut:
cross_replica(replica_groups)
jikachannel_id <= 0 and use_global_device_ids = false
.cross_replica_and_partition(replica_groups)
jikachannel_id > 0 and use_global_device_ids = false
.flattened_ids(replica_groups)
jikachannel_id > 0 and use_global_device_ids = true
.
Setelah itu, dalam setiap process_group
:
operands...@receiver = [operand@sender for sender in process_group]
untuk semuareceiver
dalamprocess_group
.results...@process = concatenate(operands...@process, all_gather_dim)
untuk semuaprocess
dalamprocess_group
.
Input
Label | Nama | Jenis | Batasan |
---|---|---|---|
(I1) | operands |
jumlah tensor atau tensor terkuantisasi per-tensor | (C1), C6 |
(I2) | all_gather_dim |
konstanta jenis si64 |
(C1), C6 |
(I3) | replica_groups |
Konstanta tensor 2 dimensi jenis si64 |
(C2-C4) |
(I4) | channel_id |
konstanta jenis si64 |
(C5) |
(I5) | use_global_device_ids |
konstanta jenis i1 |
(C5) |
Output
Nama | Jenis | Batasan |
---|---|---|
results |
jumlah tensor atau tensor terkuantisasi per-tensor | (C6) |
Batasan
- (C1)
0 <= all_gather_dim < rank(operands...)
. - (C2)
is_unique(replica_groups)
. - (C3)
size(replica_groups)
didefinisikan sebagai:num_replicas
jikacross_replica
digunakan.num_replicas
jikacross_replica_and_partition
digunakan.num_processes
jikaflattened_ids
digunakan.
- (C4)
0 <= replica_groups < size(replica_groups)
. - (C5) Jika
use_global_device_ids = true
, makachannel_id > 0
. - (C6)
type(results...) = type(operands...)
kecuali:dim(results..., all_gather_dim) = dim(operands..., all_gather_dim) * dim(process_groups, 1)
.
Contoh
// num_replicas: 2
// num_partitions: 1
// %operand0@(0, 0): [[1, 2], [3, 4]]
// %operand0@(1, 0): [[5, 6], [7, 8]]
// %operand1@(0, 0): [[11, 12], [13, 14]]
// %operand1@(1, 0): [[15, 16], [17, 18]]
%result:2 = "stablehlo.all_gather"(%operand0, %operand1) {
all_gather_dim = 1 : i64,
replica_groups = dense<[[0, 1]]> : tensor<1x2xi64>,
// channel_id = 0
channel_handle = #stablehlo.channel_handle<handle = 0, type = 0>
// use_global_device_ids = false
} : (tensor<2x2xi64>, tensor<2x2xi64>) -> (tensor<2x4xi64>, tensor<2x4xi64>)
// %result0@(0, 0): [[1, 2, 5, 6], [3, 4, 7, 8]]
// %result0@(1, 0): [[1, 2, 5, 6], [3, 4, 7, 8]]
// %result1@(0, 0): [[11, 12, 15, 16], [13, 14, 17, 18]]
// %result1@(1, 0): [[11, 12, 15, 16], [13, 14, 17, 18]]
all_reduce
Semantik
Dalam setiap grup proses pada grid proses StableHLO, menerapkan pengurangan
fungsi computation
dengan nilai tensor operands
dari setiap proses
dan menghasilkan tensor results
.
Operasi tersebut membagi grid proses StableHLO menjadi process_groups
yang
didefinisikan sebagai berikut:
cross_replica(replica_groups)
jikachannel_id <= 0 and use_global_device_ids = false
.cross_replica_and_partition(replica_groups)
jikachannel_id > 0 and use_global_device_ids = false
.flattened_ids(replica_groups)
jikachannel_id > 0 and use_global_device_ids = true
.
Setelah itu, dalam setiap process_group
:
results...@process[result_index] = exec(schedule)
untuk beberapa hierarki binerschedule
dengan:exec(node)
=computation(exec(node.left), exec(node.right))
.exec(leaf)
=leaf.value
.
schedule
adalah hierarki biner yang ditentukan penerapan yang sesuai urutan traversal adalahto_destination_type(operands...@process_group...[result_index], type(func_inputs(computation)[0]))
.
Input
Label | Nama | Jenis | Batasan |
---|---|---|---|
(I1) | operands |
jumlah tensor atau tensor terkuantisasi per-tensor | (C5), (C6) |
(I2) | replica_groups |
jumlah variadik konstanta tensor 1 dimensi jenis si64 |
(C1-C3) |
(I3) | channel_id |
konstanta jenis si64 |
(C4) |
(I4) | use_global_device_ids |
konstanta jenis i1 |
(C4) |
(I5) | computation |
fungsi | (C5) |
Output
Nama | Jenis | Batasan |
---|---|---|
results |
jumlah tensor atau tensor terkuantisasi per-tensor | (C6-C7) |
Batasan
- (C1)
is_unique(replica_groups)
. - (C2)
size(replica_groups)
didefinisikan sebagai:num_replicas
jikacross_replica
digunakan.num_replicas
jikacross_replica_and_partition
digunakan.num_processes
jikaflattened_ids
digunakan.
- (C3)
0 <= replica_groups < size(replica_groups)
. - (C4) Jika
use_global_device_ids = true
, makachannel_id > 0
. - (C5)
computation
memiliki jenis(tensor<E>, tensor<E>) -> (tensor<E>)
, denganis_promotable(element_type(operand), E)
. - (C6)
shape(results...) = shape(operands...)
. - (C7)
element_type(results...) = E
.
Contoh
// num_replicas: 2
// num_partitions: 1
// %operand0@(0, 0): [1, 2, 3, 4]
// %operand0@(1, 0): [5, 6, 7, 8]
// %operand1@(0, 0): [9, 10, 11, 12]
// %operand1@(1, 0): [13, 14, 15, 16]
%result:2 = "stablehlo.all_reduce"(%operand0, %operand0) ({
^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
%0 = "stablehlo.add"(%arg0, %arg1) : (tensor<i64>, tensor<i64>) -> tensor<i64>
"stablehlo.return"(%0) : (tensor<i64>) -> ()
}) {
replica_groups = dense<[[0, 1]]> : tensor<1x2xi64>,
// channel_id = 0
channel_handle = #stablehlo.channel_handle<handle = 0, type = 0>
// use_global_device_ids = false
} : (tensor<4xi64>, tensor<4xi64>) -> (tensor<4xi64>, tensor<4xi64>)
// %result0@(0, 0): [6, 8, 10, 12]
// %result0@(1, 0): [6, 8, 10, 12]
// %result1@(0, 0): [22, 24, 26, 28]
// %result1@(1, 0): [22, 24, 26, 28]
all_to_all
Semantik
Dalam setiap grup proses dalam {i>process grid<i} StableHLO, membagi nilai
tensor operands
di sepanjang split_dimension
menjadi beberapa bagian, sehingga akan menyebarkan
bagian di antara proses, menyambungkan bagian yang tersebar di
concat_dimension
dan menghasilkan tensor results
.
Operasi tersebut membagi grid proses StableHLO menjadi process_groups
yang
didefinisikan sebagai berikut:
cross_replica(replica_groups)
jikachannel_id <= 0
.cross_partition(replica_groups)
jikachannel_id > 0
.
Setelah itu, dalam setiap process_group
:
split_parts...@sender = split(operands...@sender, split_count, split_dimension)
untuk semuasender
diprocess_group
.scattered_parts...@receiver = [split_parts...@sender[receiver_index] for sender in process_group]
di manareceiver_index = process_group.index(receiver)
.results...@process = concatenate(scattered_parts...@process, concat_dimension)
.
Input
Label | Nama | Jenis | Batasan |
---|---|---|---|
(I1) | operands |
jumlah tensor atau tensor terkuantisasi per-tensor | (C1-C3), (C9) |
(I2) | split_dimension |
konstanta jenis si64 |
(C1), (C2), (C9) |
(I3) | concat_dimension |
konstanta jenis si64 |
(C3), (C9) |
(I4) | split_count |
konstanta jenis si64 |
(C2), (C4), (C8), (C9) |
(I5) | replica_groups |
Konstanta tensor 2 dimensi jenis si64 |
(C5-C8) |
(I6) | channel_id |
konstanta jenis si64 |
Output
Nama | Jenis | Batasan |
---|---|---|
results |
jumlah tensor atau tensor terkuantisasi per-tensor | (C9) |
Batasan
- (C1)
0 <= split_dimension < rank(operands...)
. - (C2)
dim(operands..., split_dimension) % split_count = 0
. - (C3)
0 <= concat_dimension < rank(operands...)
. - (C4)
0 < split_count
. - (C5)
is_unique(replica_groups)
. - (C6)
size(replica_groups)
ditentukan sebagai:num_replicas
jikacross_replica
digunakan.num_partitions
jikacross_partition
digunakan.
- (C7)
0 <= replica_groups < size(replica_groups)
. - (C8)
dim(replica_groups, 1) = split_count
. - (C9)
type(results...) = type(operands...)
kecuali, jikasplit_dimension != concat_dimension
:dim(results..., split_dimension) = dim(operands..., split_dimension) / split_count
.dim(results..., concat_dimension) = dim(operands..., concat_dimension) * split_count
.
Contoh
// num_replicas: 2
// num_partitions: 1
// %operand1@(0, 0): [[1, 2, 3, 4],
// [5, 6, 7, 8]]
// %operand1@(1, 0): [[9, 10, 11, 12],
// [13, 14, 15, 16]]
// %operand2@(0, 0): [[17, 18, 19, 20],
// [21, 22, 23, 24]]
// %operand2@(1, 0): [[25, 26, 27, 28],
// [29, 30, 31, 32]]
%result:2 = "stablehlo.all_to_all"(%operand1, %operand2) {
split_dimension = 1 : i64,
concat_dimension = 0 : i64,
split_count = 2 : i64,
replica_groups = dense<[[0, 1]]> : tensor<1x2xi64>
// channel_id = 0
} : (tensor<2x4xi64>, tensor<2x4xi64>) -> (tensor<4x2xi64>, tensor<4x2xi64>)
// %result#0@(0, 0): [[1, 2], [5, 6], [9, 10], [13, 14]]
// %result#0@(1, 0): [[3, 4], [7, 8], [11, 12], [15, 16]]
// %result#1@(0, 0): [[17, 18], [21, 22], [25, 26], [29, 30]]
// %result#1@(1, 0): [[19, 20], [23, 24], [27, 28], [31, 32]]
dan
Semantik
Melakukan AND berdasarkan elemen dari dua tensor lhs
dan rhs
serta menghasilkan result
tensor. Bergantung pada jenis elemen, melakukan hal berikut:
- Untuk boolean: logika AND.
- Untuk bilangan bulat: bitwise AND.
Input
Label | Nama | Jenis | Batasan |
---|---|---|---|
(I1) | lhs |
tensor jenis boolean atau bilangan bulat | (C1) |
(I2) | rhs |
tensor jenis boolean atau bilangan bulat | (C1) |
Output
Nama | Jenis | Batasan |
---|---|---|
result |
tensor jenis boolean atau bilangan bulat | (C1) |
Batasan
- (C1)
type(lhs) = type(rhs) = type(result)
.
Contoh
// %lhs: [[1, 2], [3, 4]]
// %rhs: [[5, 6], [7, 8]]
%result = "stablehlo.and"(%lhs, %rhs) : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
// %result: [[1, 2], [3, 0]]
atan2
Semantik
Menjalankan operasi atan2 berdasarkan elemen pada tensor lhs
dan rhs
serta menghasilkan
Tensor result
. Bergantung pada jenis elemen, melakukan hal berikut:
- Untuk float:
atan2
dari IEEE-754. - Untuk bilangan kompleks: atan2 kompleks.
- Untuk jenis terkuantisasi:
dequantize_op_quantize(atan2, lhs, rhs, type(result))
.
Input
Label | Nama | Jenis | Batasan |
---|---|---|---|
(I1) | lhs |
tensor floating point atau tipe kompleks atau tensor terkuantisasi per-tensor | (C1) |
(I2) | rhs |
tensor floating point atau tipe kompleks atau tensor terkuantisasi per-tensor | (C1) |
Output
Nama | Jenis | Batasan |
---|---|---|
result |
tensor floating point atau tipe kompleks atau tensor terkuantisasi per-tensor | (C1) |
Batasan
- (C1)
baseline_type(lhs) = baseline_type(rhs) = baseline_type(result)
.
Contoh
// %lhs: [0.0, 1.0, -1.0]
// %rhs: [0.0, 0.0, 0.0]
%result = "stablehlo.atan2"(%lhs, %rhs) : (tensor<3xf64>, tensor<3xf64>) -> tensor<3xf64>
// %result: [0.0, 1.57079637, -1.57079637] // [0.0, pi/2, -pi/2]
batch_norm_grad
Semantik
Menghitung gradien beberapa input propagasi mundur batch_norm_training
dari grad_output
, dan menghasilkan grad_operand
, grad_scale
, dan grad_offset
tensor. Secara lebih formal, operasi ini bisa
dinyatakan sebagai dekomposisi menjadi
operasi StabilHLO yang ada menggunakan sintaksis Python sebagai berikut:
def compute_sum(operand, feature_index):
(sum,) = reduce(
inputs=[operand],
init_values=[constant(0, element_type(operand))],
dimensions=[i for i in range(rank(operand)) if i != feature_index],
body=lambda x, y: add(x, y))
return sum
def compute_mean(operand, feature_index):
sum = compute_sum(operand, feature_index)
divisor = constant(size(operand) / dim(operand, feature_index),
element_type(operand))
divisor_bcast = broadcast_in_dim(divisor, [], type(sum))
return divide(sum, divisor_bcast)
def batch_norm_grad(operand, scale, mean, variance, grad_output, epsilon, feature_index):
# Broadcast inputs to type(operand)
scale_bcast = broadcast_in_dim(scale, [feature_index], type(operand))
mean_bcast = broadcast_in_dim(mean, [feature_index], type(operand))
variance_bcast = broadcast_in_dim(variance, [feature_index], type(operand))
epsilon_bcast = broadcast_in_dim(constant(epsilon, element_type(operand)), [],
type(operand))
# Perform normalization using the provided `mean` and `variance`
# Intermediate values will be useful for computing gradients
centered_operand = subtract(operand, mean_bcast)
stddev = sqrt(add(variance_bcast, epsilon_bcast))
normalized_operand = divide(centered_operand, stddev)
# Use the implementation from batchnorm_expander.cc in XLA
# Temporary variables have exactly the same names as in the C++ code
elements_per_feature = broadcast_in_dim(
constant(divide(size(operand), dim(operand, feature_index)),
element_type(grad_output)),
[], type(operand))
i1 = multiply(grad_output, elements_per_feature)
i2 = broadcast_in_dim(
compute_sum(grad_output, feature_index), [feature_index], type(operand))
i3 = broadcast_in_dim(
compute_sum(multiply(grad_output, centered_operand), feature_index),
[feature_index], type(operand))
i4 = multiply(i3, centered_operand)
i5 = divide(i4, add(variance_bcast, epsilon_bcast))
i6 = subtract(subtract(i1, i2), i5)
grad_operand =
multiply(divide(divide(scale_bcast, stddev), elements_per_feature), i6)
grad_scale =
compute_sum(multiply(grad_output, normalized_operand), feature_index)
grad_offset = compute_sum(grad_output, feature_index)
return grad_operand, grad_scale, grad_offset
Untuk jenis terkuantisasi, menjalankan
dequantize_batch_norm_grad_or_training_quantize(lambda operand, scale, mean,
variance, grad_output: batch_norm_grad(operand, scale, mean, variance,
grad_output, epsilon, feature_index), operand, scale, mean, variance,
grad_output, type(grad_operand), type(grad_scale), type(feature_index))
.
Input
Label | Nama | Jenis | Batasan |
---|---|---|---|
(I1) | operand |
tensor tipe floating point atau tensor terkuantisasi per-tensor | (C1-C3), (C5) |
(I2) | scale |
Tensor 1-dimensi dari floating-point atau jenis terkuantisasi per-tensor | (C2), (C4), (C5) |
(I3) | mean |
Tensor 1-dimensi dari floating-point atau jenis terkuantisasi per-tensor | (C2), (C4) |
(I4) | variance |
Tensor 1-dimensi dari floating-point atau jenis terkuantisasi per-tensor | (C2), (C4) |
(I5) | grad_output |
tensor tipe floating point atau tensor terkuantisasi per-tensor | (C2), (C3) |
(I6) | epsilon |
konstanta jenis f32 |
|
(I7) | feature_index |
konstanta jenis si64 |
(C1), (C5) |
Output
Nama | Jenis | Batasan |
---|---|---|
grad_operand |
tensor tipe floating point atau tensor terkuantisasi per-tensor | (C2), (C3) |
grad_scale |
Tensor 1-dimensi dari floating-point atau jenis terkuantisasi per-tensor | (C2), (C4) |
grad_offset |
Tensor 1-dimensi dari floating-point atau jenis terkuantisasi per-tensor | (C2), (C4) |
Batasan
- (C1)
0 <= feature_index < rank(operand)
. - (C2)
operand
,scale
,mean
,variance
,grad_output
,grad_operand
,grad_scale
dangrad_offset
memilikibaseline_element_type
yang sama. - (C3)
operand
,grad_output
, dangrad_operand
memiliki bentuk yang sama. - (C4)
scale
,mean
,variance
,grad_scale
, dangrad_offset
memiliki bentuk yang sama. - (C5)
size(scale) = dim(operand, feature_index)
.
Contoh
// %operand: [
// [[1.0, 2.0], [3.0, 4.0]],
// [[3.0, 4.0], [1.0, 2.0]]
// ]
// %scale: [1.0, 1.0]
// %mean: [2.0, 3.0]
// %variance: [1.0, 1.0]
// %grad_output: [
// [[0.1, 0.1], [0.1, 0.1]],
// [[0.1, 0.1], [0.1, 0.1]]
// ]
%grad_operand, %grad_scale, %grad_offset =
"stablehlo.batch_norm_grad"(%operand, %scale, %mean, %variance, %grad_output) {
epsilon = 0.0 : f32,
feature_index = 2 : i64
} : (tensor<2x2x2xf64>, tensor<2xf64>, tensor<2xf64>, tensor<2xf64>,
tensor<2x2x2xf64>) -> (tensor<2x2x2xf64>, tensor<2xf64>, tensor<2xf64>)
// %grad_operand: [
// [[0.0, 0.0], [0.0, 0.0]],
// [[0.0, 0.0], [0.0, 0.0]]
// ]
// %grad_scale: [0.0, 0.0]
// %grad_offset: [0.4, 0.4]
batch_norm_inference
Semantik
Menormalkan tensor operand
di semua dimensi kecuali untuk
feature_index
dan menghasilkan tensor result
. Secara lebih formal,
operasi dapat dinyatakan sebagai dekomposisi untuk operasi StableHLO yang ada
menggunakan sintaks Python sebagai berikut:
def batch_norm_inference(operand, scale, offset, mean, variance, epsilon, feature_index):
# Broadcast inputs to shape(operand)
scale_bcast = broadcast_in_dim(scale, [feature_index], type(operand))
offset_bcast = broadcast_in_dim(offset, [feature_index], type(operand))
mean_bcast = broadcast_in_dim(mean, [feature_index], type(operand))
variance_bcast = broadcast_in_dim(variance, [feature_index], type(operand))
epsilon_bcast = broadcast_in_dim(constant(epsilon, element_type(operand)), [],
type(operand))
# Perform normalization using the provided `mean` and `variance` instead of
# computing them like `batch_norm_training` does.
centered_operand = subtract(operand, mean_bcast)
stddev = sqrt(add(variance_bcast, epsilon_bcast))
normalized_operand = divide(centered_operand, stddev)
return add(multiply(scale_bcast, normalized_operand), offset_bcast)
Untuk jenis terkuantisasi, menjalankan
dequantize_op_quantize(lambda operand, scale, offset, mean, variance:
batch_norm_inference(operand, scale, offset, mean, variance, epsilon,
feature_index), operand, scale, offset, mean, variance, type(result))
.
Input
Label | Nama | Jenis | Batasan |
---|---|---|---|
(I1) | operand |
tensor tipe floating point atau tensor terkuantisasi per-tensor | (C1-C7) |
(I2) | scale |
Tensor 1-dimensi dari floating-point atau jenis terkuantisasi per-tensor | (C2), (C3) |
(I3) | offset |
Tensor 1-dimensi dari floating-point atau jenis terkuantisasi per-tensor | (C2), (C4) |
(I4) | mean |
Tensor 1-dimensi dari floating-point atau jenis terkuantisasi per-tensor | (C5) |
(I5) | variance |
Tensor 1-dimensi dari floating-point atau jenis terkuantisasi per-tensor | (C2), (C6) |
(I6) | epsilon |
konstanta jenis f32 |
|
(I7) | feature_index |
konstanta jenis si64 |
(C1), C3-C6 |
Output
Nama | Jenis | Batasan |
---|---|---|
result |
tensor tipe floating point atau tensor terkuantisasi per-tensor | (C2), (C7) |
Batasan
- (C1)
0 <= feature_index < rank(operand)
. - (C2)
operand
,scale
,offset
,mean
,variance
, danresult
memilikibaseline_element_type
yang sama. - (C3)
size(scale) = dim(operand, feature_index)
. - (C4)
size(offset) = dim(operand, feature_index)
. - (C5)
size(mean) = dim(operand, feature_index)
. - (C6)
size(variance) = dim(operand, feature_index)
. - (C7)
baseline_type(operand) = baseline_type(result)
.
Contoh
// %operand: [
// [[1.0, 2.0], [3.0, 4.0]],
// [[3.0, 4.0], [1.0, 2.0]]
// ]
// %scale: [1.0, 1.0]
// %offset: [1.0, 1.0]
// %mean: [2.0, 3.0]
// %variance: [1.0, 1.0]
%result = "stablehlo.batch_norm_inference"(%operand, %scale, %offset, %mean, %variance) {
epsilon = 0.0 : f32,
feature_index = 2 : i64
} : (tensor<2x2x2xf64>, tensor<2xf64>, tensor<2xf64>, tensor<2xf64>, tensor<2xf64>) -> tensor<2x2x2xf64>
// %result: [
// [[0.0, 0.0], [2.0, 2.0]],
// [[2.0, 2.0], [0.0, 0.0]]
// ]
batch_norm_training
Semantik
Menghitung rata-rata dan varians di semua dimensi kecuali untuk feature_index
dan menormalisasi tensor operand
yang menghasilkan output
, batch_mean
dan tensor batch_var
. Secara lebih formal, operasi ini
dapat dinyatakan sebagai
dekomposisi ke operasi StableHLO yang ada menggunakan sintaksis Python sebagai
berikut ini:
def compute_mean(operand, feature_index):
(sum,) = reduce(
inputs=[operand],
init_values=[constant(0, element_type(operand))],
dimensions=[i for i in range(rank(operand)) if i != feature_index],
body=lambda x, y: add(x, y))
divisor = constant(size(operand) / dim(operand, feature_index),
element_type(operand))
divisor_bcast = broadcast_in_dim(divisor, [], type(sum))
return divide(sum, divisor_bcast)
def compute_variance(operand, feature_index):
mean = compute_mean(operand, feature_index)
mean_bcast = broadcast_in_dim(mean, [feature_index], type(operand))
centered_operand = subtract(operand, mean_bcast)
return compute_mean(mul(centered_operand, centered_operand), feature_index)
def batch_norm_training(operand, scale, offset, epsilon, feature_index):
mean = compute_mean(operand, feature_index)
variance = compute_variance(operand, feature_index)
return batch_norm_inference(operand, scale, offset, mean, variance, epsilon,
feature_index),
mean, variance
Untuk jenis terkuantisasi, menjalankan
dequantize_batch_norm_grad_or_training_quantize(lambda operand, scale, offset:
batch_norm_training(operand, scale, offset, epsilon, feature_index), operand,
scale, offset, type(output), type(batch_mean), type(batch_var))
.
Input
Label | Nama | Jenis | Batasan |
---|---|---|---|
(I1) | operand |
tensor tipe floating point atau tensor terkuantisasi per-tensor | (C1) |
(I2) | scale |
Tensor 1-dimensi dari floating-point atau per-tensor terkuantisasi | (C2), (C3) |
(I3) | offset |
Tensor 1-dimensi dari floating-point atau per-tensor terkuantisasi | (C2), (C4) |
(I4) | epsilon |
konstanta jenis f32 |
(C1), C3-C6 |
(I5) | feature_index |
konstanta jenis si64 |
(C1), C3-C6 |
Output
Nama | Jenis | Batasan |
---|---|---|
output |
tensor tipe floating point atau tensor terkuantisasi per-tensor | (C7) |
batch_mean |
Tensor 1-dimensi dari floating-point atau per-tensor terkuantisasi | (C2), (C5) |
batch_var |
Tensor 1-dimensi dari floating-point atau per-tensor terkuantisasi | (C2), (C6) |
Batasan
- (C1)
0 <= feature_index < rank(operand)
. - (C2)
operand
,scale
,offset
,batch_mean
,batch_var
, danoutput
memilikibaseline_element_type
yang sama. - (C3)
size(scale) = dim(operand, feature_index)
. - (C4)
size(offset) = dim(operand, feature_index)
. - (C5)
size(batch_mean) = dim(operand, feature_index)
. - (C6)
size(batch_var) = dim(operand, feature_index)
. - (C7)
baseline_type(output) = baseline_type(operand)
.
Contoh
// %operand: [
// [[1.0, 2.0], [3.0, 4.0]],
// [[3.0, 4.0], [1.0, 2.0]]
// ]
// %scale: [1.0, 1.0]
// %offset: [1.0, 1.0]
%output, %batch_mean, %batch_var = "stablehlo.batch_norm_training"(%operand, %scale, %offset) {
epsilon = 0.0 : f32,
feature_index = 2 : i64
} : (tensor<2x2x2xf64>, tensor<2xf64>, tensor<2xf64>) ->
(tensor<2x2x2xf64>, tensor<2xf64>, tensor<2xf64>)
// %output: [
// [[0.0, 0.0], [2.0, 2.0]],
// [[2.0, 2.0], [0.0, 0.0]]
// ]
// %batch_mean: [2.0, 3.0]
// %batch_var: [1.0, 1.0]
bitcast_convert
Semantik
Melakukan operasi bitcast pada tensor operand
dan menghasilkan tensor result
dengan bit dari seluruh tensor operand
diinterpretasikan ulang menggunakan
dari tensor result
.
Secara lebih formal, mengingat E = element_type(operand)
, E' = element_type(result)
,
dan R = rank(operand)
:
- Jika
num_bits(E') < num_bits(E)
,bits(result[i0, ..., iR-1, :]) = bits(operand[i0, ..., iR-1])
. - Jika
num_bits(E') > num_bits(E)
,bits(result[i0, ..., iR-2]) = bits(operand[i0, ..., iR-2, :])
. - Jika
num_bits(E') = num_bits(E)
,bits(result[i0, ..., iR-1]) = bits(operand[i0, ..., iR-1])
.
bits
menampilkan representasi dalam memori dari nilai tertentu, beserta perilakunya
adalah implementasi-ditentukan karena representasi yang tepat dari tensor
yang ditentukan implementasinya, dan representasi
yang tepat dari tipe elemen adalah
serta definisi implementasi.
Input
Label | Nama | Jenis | Batasan |
---|---|---|---|
(I1) | operand |
tensor atau quantized tensor | (C1-C2) |
Output
Nama | Jenis | Batasan |
---|---|---|
result |
tensor atau quantized tensor | (C1-C2) |
Batasan
- (C1) Dengan mempertimbangkan
E = is_quantized(operand) ? storage_type(operand) : element_type(operand)
,E' = is_quantized(result) ? storage_type(result) : element_type(result)
, danR = rank(operand)
:- Jika
num_bits(E') = num_bits(E)
,shape(result) = shape(operand)
. - Jika
num_bits(E') < num_bits(E)
: rank(result) = R + 1
.dim(result, i) = dim(operand, i)
untuk semua0 <= i < R
.dim(result, R) * num_bits(E') = num_bits(E)
.- Jika
num_bits(E') > num_bits(E)
: rank(result) = R - 1
.dim(result, i) = dim(operand, i)
untuk semua0 <= i < R
.dim(operand, R - 1) * num_bits(E) = num_bits(E')
.
- Jika
- (C2) Jika
is_complex(operand) or is_complex(result)
, makais_complex(operand) and is_complex(result)
.
Contoh
// %operand: 0x0123456789ABCDEF
%result = "stablehlo.bitcast_convert"(%operand) : (tensor<f64>) -> tensor<4xf16>
// %result: [0xCDEF, 0x89AB, 0x4567, 0x0123] // little-endian representation
broadcast_in_dim
Semantik
Memperluas dimensi dan/atau peringkat tensor input dengan menduplikasi data
pada tensor operand
dan menghasilkan tensor result
. Secara lebih formal,
result[result_index] = operand[operand_index]
untuk semua d
di
axes(operand)
:
operand_index[d] = 0
jikadim(operand, d) = 1
.operand_index[d] = result_index[broadcast_dimensions[d]]
sebaliknya.
Input
Label | Nama | Jenis | Batasan |
---|---|---|---|
(I1) | operand |
tensor atau quantized tensor | (C1-C2), (C5-C6) |
(I2) | broadcast_dimensions |
Konstanta tensor 1 dimensi jenis si64 |
(C2-C6) |
Output
Nama | Jenis | Batasan |
---|---|---|
result |
tensor atau quantized tensor | (C1), (C3), (C5-C6) |
Batasan
- (C1)
element_type(result)
diberikan oleh:element_type(operand)
, jika!is_per_axis_quantized(operand)
.element_type(operand)
kecualiquantization_dimension(operand)
,scales(operand)
, danzero_points(operand)
mungkin berbeda denganquantization_dimension(result)
,scales(result)
, danzero_points(result)
resp., jika tidak.
- (C2)
size(broadcast_dimensions) = rank(operand)
. - (C3)
0 <= broadcast_dimensions < rank(result)
. - (C4)
is_unique(broadcast_dimensions)
. - (C5) Untuk semua
d
diaxes(operand)
:dim(operand, d) = 1
ataudim(operand, d) = dim(result, broadcast_dimensions[d])
.
- (C6) Jika
is_per_axis_quantized(result)
:quantization_dimension(result) = broadcast_dimensions[quantization_dimension(operand)]
.- Jika
dim(operand, quantization_dimension(operand)) = 1
, makascales(result)[i] = scales(operand)[0] and zero_points(result)[i] = zero_points(operand)[0] for i in range(dim(result, quantization_dimension(result)))
.
Contoh
// %operand: [
// [1, 2, 3]
// ]
%result = "stablehlo.broadcast_in_dim"(%operand) {
broadcast_dimensions = array<i64: 2, 1>
} : (tensor<1x3xi32>) -> tensor<2x3x2xi32>
// %result: [
// [
// [1, 1],
// [2, 2],
// [3, 3]
// ],
// [
// [1, 1],
// [2, 2],
// [3, 3]
// ]
// ]
casing
Semantik
Menghasilkan output dari mengeksekusi tepat satu fungsi dari branches
bergantung pada nilai index
. Secara lebih formal, result = selected_branch()
dalam hal ini:
selected_branch = branches[index]
jika0 <= index < size(branches)
.selected_branch = branches[-1]
sebaliknya.
Input
Label | Nama | Jenis | Batasan |
---|---|---|---|
(I1) | index |
Tensor 0-dimensi jenis si32 |
|
(I2) | branches |
jumlah fungsi variadik | (C1-C4) |
Output
Nama | Jenis | Batasan |
---|---|---|
results |
jumlah tensor, tensor atau token terkuantisasi | (C4) |
Batasan
- (C1)
0 < size(branches)
. - (C2)
input_types(branches...) = []
. - (C3)
same(output_types(branches...))
. - (C4)
type(results...) = output_types(branches[0])
.
Contoh
// %index: -1
// %result_branch0: [0, 0]
// %result_branch1: [1, 1]
%result0, %result1 = "stablehlo.case"(%index) ({
"stablehlo.return"(%result_branch0, %result_branch0) : (tensor<2xi64>, tensor<2xi64>) -> ()
}, {
"stablehlo.return"(%result_branch1, %result_branch1) : (tensor<2xi64>, tensor<2xi64>) -> ()
}) : (tensor<i32>) -> (tensor<2xi64>, tensor<2xi64>)
// %result0: [1, 1]
// %result1: [1, 1]
Cbrt
Semantik
Melakukan operasi root kubik berdasarkan elemen pada tensor operand
dan menghasilkan
Tensor result
. Bergantung pada jenis elemen, melakukan hal berikut:
- Untuk float:
rootn(x, 3)
dari IEEE-754. - Untuk bilangan kompleks: akar kubik kompleks.
- Untuk jenis terkuantisasi:
dequantize_op_quantize(cbrt, operand, type(result))
Input
Label | Nama | Jenis | Batasan |
---|---|---|---|
(I1) | operand |
tensor floating point atau tipe kompleks atau tensor terkuantisasi per-tensor | (C1) |
Output
Nama | Jenis | Batasan |
---|---|---|
result |
tensor floating point atau tipe kompleks atau tensor terkuantisasi per-tensor | (C1) |
Batasan
- (C1)
baseline_type(operand) = baseline_type(result)
.
Contoh
// %operand: [0.0, 1.0, 8.0, 27.0]
%result = "stablehlo.cbrt"(%operand) : (tensor<4xf64>) -> tensor<4xf64>
// %result: [0.0, 1.0, 2.0, 3.0]
ceil
Semantik
Melakukan ceil berbasis elemen dari tensor operand
dan menghasilkan tensor result
.
Menerapkan operasi roundToIntegralTowardPositive
dari IEEE-754
spesifikasi
pendukung. Untuk jenis terkuantisasi, menjalankan
dequantize_op_quantize(ceil, operand, type(result))
.
Input
Label | Nama | Jenis | Batasan |
---|---|---|---|
(I1) | operand |
tensor tipe floating point atau tensor terkuantisasi per-tensor | (C1) |
Output
Nama | Jenis | Batasan |
---|---|---|
result |
tensor tipe floating point atau tensor terkuantisasi per-tensor | (C1) |
Batasan
- (C1)
baseline_type(operand) = baseline_type(result)
.
Contoh
// %operand: [-0.8166, -0.2530, 0.2530, 0.8166, 2.0]
%result = "stablehlo.ceil"(%operand) : (tensor<5xf32>) -> tensor<5xf32>
// %result: [-0.0, -0.0, 1.0, 1.0, 2.0]
Cholesky
Semantik
Menghitung dekomposisi Cholesky dari sekumpulan matriks.
Secara lebih formal, untuk semua i
di index_space(result)
,
result[i0, ..., iR-3, :, :]
adalah dekomposisi Cholesky dari
a[i0, ..., iR-3, :, :]
, dalam bentuk segitiga rendah
(jika lower
adalah true
) atau matriks segitiga atas (jika lower
adalah false
).
Nilai {i>output<i} dalam segitiga yang berlawanan, yaitu segitiga atas yang tegas atau
segitiga bawah yang ketat, yang juga ditentukan oleh implementasi.
Jika ada i
yang matriks inputnya bukan positif-definit Hermitian
matriks, maka perilakunya tidak terdefinisi.
Untuk jenis terkuantisasi, menjalankan
dequantize_op_quantize(lambda operand: cholesky(operand, lower), a, type(result))
.
Input
Label | Nama | Jenis | Batasan |
---|---|---|---|
(I1) | a |
tensor floating point atau tipe kompleks atau tensor terkuantisasi per-tensor | (C1-C3) |
(I2) | lower |
Konstanta tensor 0 dimensi jenis i1 |
Output
Nama | Jenis | Batasan |
---|---|---|
result |
tensor floating point atau tipe kompleks atau tensor terkuantisasi per-tensor | (C1) |
Batasan
- (C1)
baseline_type(a) = baseline_type(result)
. - (C2)
2 <= rank(a)
. - (C3)
dim(a, -2) = dim(a, -1)
.
Contoh
// %a: [
// [1.0, 2.0, 3.0],
// [2.0, 20.0, 26.0],
// [3.0, 26.0, 70.0]
// ]
%result = "stablehlo.cholesky"(%a) {
lower = true
} : (tensor<3x3xf32>) -> tensor<3x3xf64>
// %result: [
// [1.0, 0.0, 0.0],
// [2.0, 4.0, 0.0],
// [3.0, 5.0, 6.0]
// ]
klem
Semantik
Menjepit setiap elemen tensor operand
di antara nilai minimum dan maksimum
dan menghasilkan tensor result
. Secara lebih formal, result[result_index] =
minimum(maximum(operand[result_index], min_element), max_element)
,
dengan min_element = rank(min) = 0 ? min[] : min[result_index]
,
max_element = rank(max) = 0 ? max[] : max[result_index]
. Untuk jenis terkuantisasi,
menjalankan dequantize_op_quantize(clamp, min, operand, max, type(result))
.
Memaksakan urutan pada bilangan kompleks melibatkan semantik yang mengejutkan, jadi di masa mendatang kami berencana menghapus dukungan untuk bilangan kompleks untuk operasi ini (#560).
Input
Label | Nama | Jenis | Batasan |
---|---|---|---|
(I1) | min |
tensor atau tensor terkuantisasi per-tensor | (C1), C3 |
(I2) | operand |
tensor atau tensor terkuantisasi per-tensor | (C1-C4) |
(I3) | max |
tensor atau tensor terkuantisasi per-tensor | (C2), (C3) |
Output
Nama | Jenis | Batasan |
---|---|---|
result |
tensor atau tensor terkuantisasi per-tensor | (C4) |
Batasan
- (C1)
rank(min) = 0 or shape(min) = shape(operand)
. - (C2)
rank(max) = 0 or shape(max) = shape(operand)
. - (C3)
baseline_element_type(min) = baseline_element_type(operand) = baseline_element_type(max)
. - (C4)
baseline_type(operand) = baseline_type(result)
.
Contoh
// %min: [5, 10, 15]
// %operand: [3, 13, 23]
// %max: [10, 15, 20]
%result = "stablehlo.clamp"(%min, %operand, %max) : (tensor<3xi32>, tensor<3xi32>, tensor<3xi32>) -> tensor<3xi32>
// %result: [5, 13, 20]
collective_broadcast
Semantik
Dalam setiap grup proses di grid proses StableHLO, kirim nilai proses
Tensor operand
dari proses sumber ke proses target dan menghasilkan
Tensor result
.
Operasi tersebut membagi grid proses StableHLO menjadi process_groups
yang
didefinisikan sebagai berikut:
cross_replica(replica_groups)
jikachannel_id <= 0
.cross_partition(replica_groups)
jikachannel_id > 0
.
Setelah itu, result@process
diberikan oleh:
operand@process_groups[i, 0]
jika adai
sehingga prosesnya menjadi diprocess_groups[i]
.broadcast_in_dim(constant(is_quantized(result) ? quantize(0, element_type(result)) : 0, element_type(result)), [], type(result))
sebaliknya.
Input
Label | Nama | Jenis | Batasan |
---|---|---|---|
(I1) | operand |
tensor atau tensor terkuantisasi per-tensor | (C3) |
(I2) | replica_groups |
jumlah variadik konstanta tensor 1 dimensi jenis si64 |
(C1), C2 |
(I3) | channel_id |
konstanta jenis si64 |
Output
Nama | Jenis | Batasan |
---|---|---|
result |
tensor atau tensor terkuantisasi per-tensor | (C3) |
Batasan
- (C1)
is_unique(replica_groups)
. - (C2)
0 <= replica_groups < N
denganN
ditentukan sebagai:num_replicas
jikacross_replica
digunakan.num_partitions
jikacross_partition
digunakan.
- (C3)
type(result) = type(operand)
.
Contoh
// num_replicas: 4
// num_partitions: 1
// %operand@(0, 0): [[1, 2]]
// %operand@(1, 0): [[3, 4]]
// %operand@(2, 0): [[5, 6]]
// %operand@(3, 0): [[7, 8]]
%result = "stablehlo.collective_broadcast"(%operand) {
replica_groups = dense<[[2, 1]]> : tensor<1x2xi64>,
channel_handle = #stablehlo.channel_handle<handle = 0, type = 0>
} : (tensor1x2xi64>) -> tensor<1x2xi64>
// %result@(0, 0): [[0, 0]]
// %result@(1, 0): [[5, 6]]
// %result@(2, 0): [[5, 6]]
// %result@(3, 0): [[0, 0]]
collective_permute
Semantik
Dalam setiap grup proses pada grid proses StableHLO, mengirimkan nilai
Tensor operand
dari proses sumber ke proses target dan menghasilkan
Tensor result
.
Operasi tersebut membagi grid proses StableHLO menjadi process_groups
yang
didefinisikan sebagai berikut:
cross_replica(source_target_pairs)
jikachannel_id <= 0
.cross_partition(source_target_pairs)
jikachannel_id > 0
.
Setelah itu, result@process
diberikan oleh:
operand@process_groups[i, 0]
, jika adai
sehinggaprocess_groups[i, 1] = process
.broadcast_in_dim(constant(is_quantized(result) ? quantize(0, element_type(result)) : 0, element_type(result)), [], type(result))
sebaliknya.
Input
Label | Nama | Jenis | Batasan |
---|---|---|---|
(I1) | operand |
tensor atau tensor terkuantisasi per-tensor | (C5) |
(I2) | source_target_pairs |
Konstanta tensor 2 dimensi jenis si64 |
(C1-C4) |
(I3) | channel_id |
konstanta jenis si64 |
Output
Nama | Jenis | Batasan |
---|---|---|
result |
tensor atau tensor terkuantisasi per-tensor | (C1) |
Batasan
- (C1)
dim(source_target_pairs, 1) = 2
. - (C2)
is_unique(source_target_pairs[:, 0])
. - (C3)
is_unique(source_target_pairs[:, 1])
. - (C4)
0 <= source_target_pairs < N
, denganN
ditentukan sebagai:num_replicas
jikacross_replica
digunakan.num_partitions
jikacross_partition
digunakan.
- (C5)
type(result) = type(operand)
.
Contoh
// num_replicas: 3
// num_partitions: 1
// %operand@(0, 0): [[1, 2], [3, 4]]
// %operand@(1, 0): [[5, 6], [7, 8]]
// %operand@(2, 0): [[9, 10], [11, 12]]
%result = "stablehlo.collective_permute"(%operand) {
source_target_pairs = dense<[[0, 1], [1, 2]]> : tensor<2x2xi64>,
channel_handle = #stablehlo.channel_handle<handle = 0, type = 0>
} : (tensor<2x2xi64>) -> tensor<2x2xi64>
//
// %result@(0, 0): [[0, 0], [0, 0]]
// %result@(1, 0): [[1, 2], [3, 4]]
// %result@(2, 0): [[5, 6], [7, 8]]
bandingkan
Semantik
Melakukan perbandingan tensor lhs
dan rhs
berdasarkan elemen sesuai dengan
comparison_direction
dan compare_type
, serta menghasilkan tensor result
.
Nilai comparison_direction
dan compare_type
memiliki hal berikut
semantik:
Untuk jenis elemen boolean dan bilangan bulat:
EQ
:lhs = rhs
.NE
:lhs != rhs
.GE
:lhs >= rhs
.GT
:lhs > rhs
.LE
:lhs <= rhs
.LT
:lhs < rhs
.
Untuk jenis elemen floating point dengan compare_type = FLOAT
, op akan mengimplementasikan
operasi IEEE-754 berikut:
EQ
:compareQuietEqual
.NE
:compareQuietNotEqual
.GE
:compareQuietGreaterEqual
.GT
:compareQuietGreater
.LE
:compareQuietLessEqual
.LT
:compareQuietLess
.
Untuk jenis elemen floating point dengan compare_type = TOTALORDER
, op
menggunakan kombinasi operasi totalOrder
dan compareQuietEqual
dari
IEEE-754.
Untuk jenis elemen kompleks, perbandingan leksikografis pasangan (real, imag)
dijalankan menggunakan comparison_direction
dan compare_type
yang disediakan.
Memaksakan urutan pada bilangan kompleks
melibatkan semantik yang mengejutkan,
jadi di masa mendatang kami berencana
menghapus dukungan untuk bilangan kompleks
jika comparison_direction
adalah GE
, GT
, LE
, atau LT
(#560).
Untuk jenis terkuantisasi. menjalankan dequantize_compare(lhs, rhs,
comparison_direction)
.
Input
Label | Nama | Jenis | Batasan |
---|---|---|---|
(I1) | lhs |
tensor atau tensor terkuantisasi per-tensor | (C1-C3) |
(I2) | rhs |
tensor atau tensor terkuantisasi per-tensor | (C1-C2) |
(I3) | comparison_direction |
enum EQ , NE , GE , GT , LE , dan LT |
|
(I4) | compare_type |
enum FLOAT , TOTALORDER , SIGNED , dan UNSIGNED |
(C3) |
Output
Nama | Jenis | Batasan |
---|---|---|
result |
tensor jenis boolean | (C2) |
Batasan
- (C1)
baseline_element_type(lhs) = baseline_element_type(rhs)
. - (C2)
shape(lhs) = shape(rhs) = shape(result)
. - (C3)
compare_type
didefinisikan sebagai:SIGNED
jikais_signed_integer(element_type(lhs))
.UNSIGNED
jikais_unsigned_integer(element_type(lhs)) or is_boolean(element_type(lhs))
.FLOAT
atauTOTALORDER
jikais_float(element_type(lhs))
.FLOAT
jikais_complex(element_type(lhs))
.
Contoh
// %lhs: [1.0, 3.0]
// %rhs: [1.1, 2.9]
%result = "stablehlo.compare"(%lhs, %rhs) {
comparison_direction = #stablehlo<comparison_direction LT>,
compare_type = #stablehlo<comparison_type FLOAT>
} : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xi1>
// %result: [true, false]
kompleks
Semantik
Melakukan konversi {i>element<i}-{i>wise<i} ke nilai kompleks dari pasangan real dan
nilai imajiner, lhs
dan rhs
, serta menghasilkan tensor result
.
Input
Label | Nama | Jenis | Batasan |
---|---|---|---|
(I1) | lhs |
tensor jenis f32 atau f64 |
(C1-C3) |
(I2) | rhs |
tensor jenis f32 atau f64 |
(C1) |
Output
Nama | Jenis | Batasan |
---|---|---|
result |
tensor tipe kompleks | (C2), (C3) |
Batasan
- (C1)
type(lhs) = type(rhs)
. - (C2)
shape(result) = shape(lhs)
. - (C3)
element_type(result)
memiliki jeniscomplex<E>
, denganE = element_type(lhs)
.
Contoh
// %lhs: [1.0, 3.0]
// %rhs: [2.0, 4.0]
%result = "stablehlo.complex"(%lhs, %rhs) : (tensor<2xf64>, tensor<2xf64>) -> tensor<2xcomplex<f64>>
// %result: [(1.0, 2.0), (3.0, 4.0)]
gabungan
Semantik
Mengenkapsulasi operasi yang terdiri (tersusun) dari operasi StableHLO lainnya,
mengambil inputs
dan composite_attributes
, serta memproduksi results
. Tujuan
semantik op diimplementasikan oleh atribut decomposition
. Tujuan
op composite
dapat diganti dengan dekomposisinya tanpa mengubah program
semantik. Dalam kasus saat penyisipan dekomposisi tidak memberikan hasil
semantik operasi, sebaiknya gunakan custom_call
.
Kolom version
(default-nya adalah 0
) digunakan untuk menunjukkan kapan kolom
semantik berubah.
Input
Label | Nama | Jenis |
---|---|---|
(I1) | inputs |
jumlah nilai variadis |
(I2) | name |
konstanta jenis string |
(I3) | composite_attributes |
kamus atribut |
(I4) | decomposition |
konstanta jenis string |
(I5) | version |
konstanta jenis si32 |
Output
Nama | Jenis |
---|---|
results |
jumlah nilai variadis |
Batasan
- (C1)
is_namespaced_op_name(name)
- (C2)
is_defined_in_parent_scope(decomposition)
- (C3)
types(inputs...) == input_types(decomposition)
- (C4)
types(results...) == output_types(decomposition)
Contoh
%results = "stablehlo.composite"(%input0, %input1) {
name = "my_namespace.my_op",
composite_attributes = {
my_attribute = "my_value"
},
decomposition = @my_op,
version = 1 : i32
} : (tensor<f32>, tensor<f32>) -> tensor<f32>
concatenate
Semantik
Menggabungkan inputs
di sepanjang dimensi dimension
dalam urutan yang sama seperti yang ditentukan
argumen dan menghasilkan tensor result
. Secara lebih formal,
result[i0, ..., id, ..., iR-1] = inputs[k][i0, ..., kd, ..., iR-1]
, dengan:
id = d0 + ... + dk-1 + kd
.d
sama dengandimension
, dand0
, ... adalah ukuran dimensi ke-d
dariinputs
.
Input
Label | Nama | Jenis | Batasan |
---|---|---|---|
(I1) | inputs |
jumlah tensor atau tensor terkuantisasi per-tensor | (C1-C6) |
(I2) | dimension |
konstanta jenis si64 |
(C2), (C4), (C6) |
Output
Nama | Jenis | Batasan |
---|---|---|
result |
tensor atau tensor terkuantisasi per-tensor | (C5-C6) |
Batasan
- (C1)
same(element_type(inputs...))
. - (C2)
same(shape(inputs...))
kecuali untukdim(inputs..., dimension)
. - (C3)
0 < size(inputs)
. - (C4)
0 <= dimension < rank(inputs[0])
. - (C5)
element_type(result) = element_type(inputs[0])
. - (C6)
shape(result) = shape(inputs[0])
kecuali untuk:dim(result, dimension) = dim(inputs[0], dimension) + ...
.
Contoh
// %input0: [[1, 2], [3, 4], [5, 6]]
// %input1: [[7, 8]]
%result = "stablehlo.concatenate"(%input0, %input1) {
dimension = 0 : i64
} : (tensor<3x2xi64>, tensor<1x2xi64>) -> tensor<4x2xi64>
// %result: [[1, 2], [3, 4], [5, 6], [7, 8]]
konstanta
Semantik
Menghasilkan tensor output
dari value
yang konstan.
Input
Label | Nama | Jenis | Batasan |
---|---|---|---|
(I1) | value |
konstanta | (C1) |
Output
Nama | Jenis | Batasan |
---|---|---|
output |
tensor atau quantized tensor | (C1) |
Batasan
- (C1)
type(value) = type(output)
.
Contoh
%output = "stablehlo.constant"() {
value = dense<[[0.0, 1.0], [2.0, 3.0]]> : tensor<2x2xf32>
} : () -> tensor<2x2xf32>
// %output: [[0.0, 1.0], [2.0, 3.0]]
melakukan konversi
Semantik
Melakukan konversi berbasis elemen dari satu jenis elemen ke jenis elemen lainnya pada
Tensor operand
dan menghasilkan tensor result
.
Untuk konversi boolean-to-any-supported-type, nilai false
adalah
dikonversi menjadi nol, dan nilai true
dikonversi menjadi satu. Sebagai
any-supported-type-to-boolean, nilai nol dikonversi menjadi
false
, dan nilai selain nol dikonversi menjadi true
. Lihat di bawah untuk mengetahui cara
bekerja untuk tipe-tipe
yang kompleks.
Untuk konversi yang melibatkan integer-ke-bilangan bulat, bilangan bulat-ke-floating-point atau floating-point-to-floating-point, jika nilai sumber bisa tepat direpresentasikan dalam jenis tujuan, nilai hasil adalah nilai persis merepresentasinya. Jika tidak, perilaku akan ditentukan nanti (#180).
Untuk konversi yang melibatkan floating-point-to-integer, bagian pecahannya adalah terpotong. Jika nilai yang terpotong tidak dapat ditampilkan dalam jenis tujuan, perilakunya akan ditentukan nanti (#180).
Konversi yang melibatkan kompleks-ke-kompleks mengikuti perilaku yang sama dari konversi floating-point-to-floating-point untuk mengonversi real time dan bagian imajiner.
Untuk konversi complex-to-any-other-type dan complex-to-any-other-type, nilai imajiner sumber diabaikan atau nilai imajiner tujuannya adalah masing-masing menjadi nol. Konversi dari bagian riil mengikuti konversi floating point.
Pada prinsipnya, operasi ini dapat mengekspresikan dekuantisasi (konversi dari
tensor terkuantisasi ke tensor reguler), kuantisasi (konversi dari
tensor ke tensor terkuantisasi) dan rekuantisasi (konversi antara
tensor), tetapi saat ini kami memiliki operasi khusus untuk itu -
uniform_dequantize
untuk kasus penggunaan pertama dan uniform_quantize
untuk kasus penggunaan pertama
kedua dan ketiga. Di masa mendatang, kedua operasi ini dapat digabungkan
ke dalam convert
(#1576).
Input
Label | Nama | Jenis | Batasan |
---|---|---|---|
(I1) | operand |
Tensor | (C1) |
Output
Nama | Jenis | Batasan |
---|---|---|
result |
Tensor | (C1) |
Batasan
- (C1)
shape(operand) = shape(result)
.
Contoh
// %operand: [-1, 0, 1]
%result = "stablehlo.convert"(%operand) : (tensor<3xi64>) -> tensor<3xcomplex<f64>>
// %result: [(-1.0, 0.0), (0.0, 0.0), (1.0, 0.0)]
konvolusi
Semantik
Menghitung produk titik antara jendela lhs
dan irisan rhs
serta menghasilkan
result
. Diagram berikut menunjukkan cara penghitungan elemen pada result
lhs
dan rhs
menggunakan contoh konkret.
Secara lebih formal, pertimbangkan framing ulang input berikut dalam hal lhs
agar dapat mengekspresikan jendela lhs
:
lhs_window_dimensions = lhs_shape(dim(lhs, input_batch_dimension), dim(rhs, kernel_spatial_dimensions), dim(lhs, input_feature_dimension))
.lhs_window_strides = lhs_shape(1, window_strides, 1)
.lhs_padding = lhs_shape([0, 0], padding, [0, 0])
.lhs_base_dilations = lhs_shape(1, lhs_dilation, 1)
.lhs_window_dilations = lhs_shape(1, rhs_dilation, 1)
.
Penyesuaian frame ini menggunakan fungsi bantuan berikut:
lhs_shape(n, hw, c) = permute([n] + hw + [c], [input_batch_dimension] + input_spatial_dimensions + [input_feature_dimension])
.result_shape(n1, hw, c1) = permute([n1] + hw + [c1], [output_batch_dimension] + output_spatial_dimensions + [output_feature_dimension])
.permute([j0, j1, ..., jR-1], permutation) = [i0, i1, ..., iR-1]
denganj[d] = i[permutation[d]]
.
Jika feature_group_count = 1
dan batch_group_count = 1
, maka untuk semua
output_spatial_index
di index_space(dim(result, output_spatial_dimensions...))
,
result[result_shape(:, output_spatial_index, :)] = dot_product
dengan:
padding_value = constant(0, element_type(lhs))
.padded_lhs = pad(lhs, padding_value, lhs_padding[:, 0], lhs_padding[:, 1], lhs_base_dilations - 1)
.lhs_window_start = lhs_shape(0, output_spatial_index, 0) * lhs_window_strides
.lhs_window = slice(padded_lhs, lhs_window_start, lhs_window_start + lhs_window_dimensions, lhs_window_dilations)
.reversed_lhs_window = reverse(lhs_window, [input_spatial_dimensions[dim] for dim in range(size(window_reversal)) if window_reversal[dim] = true])
. Fitur ini tampaknya tidak digunakan, jadi di masa mendatang kami berencana untuk menghapusnya (#1181).dot_product = dot_general(reversed_lhs_window, rhs, lhs_batching_dimensions=[], lhs_contracting_dimensions=input_spatial_dimensions + [input_feature_dimension], rhs_batching_dimensions=[], rhs_contracting_dimensions=kernel_spatial_dimensions + [kernel_input_feature_dimension])
.
Jika feature_group_count > 1
:
lhses = split(lhs, feature_group_count, input_feature_dimension)
.rhses = split(rhs, feature_group_count, kernel_output_feature_dimension)
.results... = convolution(lhses..., rhses..., ..., feature_group_count=1, ...)
.result = concatenate(results, output_feature_dimension)
.
Jika batch_group_count > 1
:
lhses = split(lhs, batch_group_count, input_batch_dimension)
.rhses = split(rhs, batch_group_count, kernel_output_feature_dimension)
.results... = convolution(lhses..., rhses..., ..., batch_group_count=1, ...)
.result = concatenate(results, output_feature_dimension)
.
Untuk jenis terkuantisasi, menjalankan dequantize_op_quantize(
lambda lhs, rhs: convolution(lhs, rhs, window_strides, padding,
lhs_dilation, rhs_dilation, window_reversal, input_batch_dimension,
input_feature_dimension, input_spatial_dimensions,
kernel_input_feature_dimension, kernel_output_feature_dimension,
kernel_spatial_dimensions, output_batch_dimension,
output_feature_dimension, output_spatial_dimensions,
feature_group_count, batch_group_count, precision_config), lhs, rhs,
type(result))
.
Untuk jenis terkuantisasi hybrid, menjalankan hybrid_dequantize_then_op(
lambda lhs, rhs: convolution(lhs, rhs, window_strides, padding,
lhs_dilation, rhs_dilation, window_reversal, input_batch_dimension,
input_feature_dimension, input_spatial_dimensions,
kernel_input_feature_dimension, kernel_output_feature_dimension,
kernel_spatial_dimensions, output_batch_dimension,
output_feature_dimension, output_spatial_dimensions,
feature_group_count, batch_group_count, precision_config), lhs, rhs)
.
Input
Label | Nama | Jenis | Batasan |
---|---|---|---|
(I1) | lhs |
tensor atau tensor terkuantisasi per-tensor | (C1), (C10-C11), (C14) (C25), (C27-C28), (C31-C32), (C34) |
(I2) | rhs |
tensor atau quantized tensor | (C1), (C14-C16), (C25), (C27-C29), (C31-C34) |
(I3) | window_strides |
Konstanta tensor 1 dimensi jenis si64 |
(C2-C3), (C25) |
(I4) | padding |
Konstanta tensor 2 dimensi jenis si64 |
(C4), (C25) |
(I5) | lhs_dilation |
Konstanta tensor 1 dimensi jenis si64 |
(C5-C6), (C25) |
(I6) | rhs_dilation |
Konstanta tensor 1 dimensi jenis si64 |
(C7-C8), (C25) |
(I7) | window_reversal |
Konstanta tensor 1 dimensi jenis i1 |
(C9) |
(I8) | input_batch_dimension |
konstanta jenis si64 |
(C10), (C13), (C25) |
(I9) | input_feature_dimension |
konstanta jenis si64 |
(C11), (C13-C14) |
(I10) | input_spatial_dimensions |
Konstanta tensor 1 dimensi jenis si64 |
(C12), (C13), (C25) |
(I11) | kernel_input_feature_dimension |
konstanta jenis si64 |
(C14), (C18) |
(I12) | kernel_output_feature_dimension |
konstanta jenis si64 |
(C15-C16), (C18), (C25), (C29) |
(I13) | kernel_spatial_dimensions |
Konstanta tensor 1 dimensi jenis si64 |
(C17-C18), (C25) |
(I14) | output_batch_dimension |
konstanta jenis si64 |
(C20), (C25) |
(I15) | output_feature_dimension |
konstanta jenis si64 |
(C20), (C25), (C30) |
(I16) | output_spatial_dimensions |
Konstanta tensor 1 dimensi jenis si64 |
(C19-C20), (C25) |
(I17) | feature_group_count |
konstanta jenis si64 |
(C11), (C14), (C16), (C21), (C23) |
(I18) | batch_group_count |
konstanta jenis si64 |
(C10), (C15), (C22), (C23), (C25) |
(I19) | precision_config |
jumlah variadik enum DEFAULT , HIGH , dan HIGHEST |
(C24) |
Output
Nama | Jenis | Batasan |
---|---|---|
result |
tensor atau quantized tensor | (C25-C28), (C30), (C32-34) |
Batasan
- (C1)
N = rank(lhs) = rank(rhs)
. - (C2)
size(window_strides) = N - 2
. - (C3)
0 < window_strides
. - (C4)
shape(padding) = [N - 2, 2]
. - (C5)
size(lhs_dilation) = N - 2
. - (C6)
0 < lhs_dilation
. - (C7)
size(rhs_dilation) = N - 2
. - (C8)
0 < rhs_dilation
. - (C9)
size(window_reversal) = N - 2
. - (C10)
dim(lhs, input_batch_dimension) % batch_group_count = 0
. - (C11)
dim(lhs, input_feature_dimension) % feature_group_count = 0
. - (C12)
size(input_spatial_dimensions) = N - 2
. - (C13) Dengan mempertimbangkan
input_dimensions = [input_batch_dimension] + input_spatial_dimensions + [input_feature_dimension]
:is_unique(input_dimensions)
.0 <= input_dimensions < N
.
- (C14)
dim(rhs, kernel_input_feature_dimension) = dim(lhs, input_feature_dimension) / feature_group_count
. - (C15)
dim(rhs, kernel_output_feature_dimension) % batch_group_count = 0
. - (C16)
dim(rhs, kernel_output_feature_dimension) % feature_group_count = 0
. - (C17)
size(kernel_spatial_dimensions) = N - 2
. - (C18) Dengan mempertimbangkan
kernel_dimensions = kernel_spatial_dimensions + [kernel_input_feature_dimension] + [kernel_output_feature_dimension]
:is_unique(kernel_dimensions)
.0 <= kernel_dimensions < N
.
- (C19)
size(output_spatial_dimensions) = N - 2
. - (C20) Dengan mempertimbangkan
output_dimensions = [output_batch_dimension] + output_spatial_dimensions + [output_feature_dimension]
:is_unique(output_dimensions)
.0 <= output_dimensions < N
.
- (C21)
0 < feature_group_count
. - (C22)
0 < batch_group_count
. - (C23)
feature_group_count = 1 or batch_group_count = 1
. - (C24)
size(precision_config) = 2
. - (C25)
dim(result, result_dim)
didefinisikan sebagai:dim(lhs, input_batch_dimension) / batch_group_count
jikaresult_dim = output_batch_dimension
.dim(rhs, kernel_output_feature_dimension)
jikaresult_dim = output_feature_dimension
.num_windows
jika tidak, jika:output_spatial_dimensions[spatial_dim] = result_dim
.lhs_dim = input_spatial_dimensions[spatial_dim]
.rhs_dim = kernel_spatial_dimensions[spatial_dim]
.dilated_input_shape[lhs_dim] = dim(lhs, lhs_dim) = 0 ? 0 : (dim(lhs, lhs_dim) - 1) * lhs_dilation[spatial_dim] + 1
.padded_input_shape[lhs_dim] = padding[spatial_dim, 0] + dilated_input_shape[lhs_dim] + padding[spatial_dim, 1]
.dilated_window_shape[lhs_dim] = dim(rhs, rhs_dim) = 0 ? 0 : (dim(rhs, rhs_dim) - 1) * rhs_dilation[spatial_dim] + 1
.is_empty_window[lhs_dim] = padded_input_shape[lhs_dim] = 0 || dilated_window_shape[lhs_dim] > padded_input_shape[lhs_dim]
.num_windows = is_empty_window[lhs_dim] ? 0 : floor((padded_input_shape[lhs_dim] - dilated_window_shape[lhs_dim]) / window_strides[spatial_dim]) + 1
.
- (C26)
rank(result) = N
. - Jika operasi tersebut menggunakan tensor non-kuantisasi:
- (C27)
element_type(lhs) = element_type(rhs) = element_type(result)
.
- (C27)
- Jika operasi tersebut menggunakan tensor terkuantisasi:
- (C28)
is_quantized(lhs) = is_quantized(result) and is_quantized(rhs)
. - (C29) Jika
is_per_axis_quantized(rhs)
, laluquantization_dimension(rhs) = kernel_output_feature_dimension
. - (C30) Jika
is_per_axis_quantized(result)
, makaquantization_dimension(result) = output_feature_dimension
. - Jika
is_quantized(lhs)
: - (C31)
storage_type(lhs) = storage_type(rhs)
. - (C32)
expressed_type(lhs) = expressed_type(rhs) = expressed_type(result)
. - (C33) Jika
is_per_tensor_quantized(rhs)
, makais_per_tensor_quantized(result)
. - Jika
!is_quantized(lhs)
: - (C34)
element_type(lhs) = expressed_type(rhs) = element_type(result)
.
- (C28)
Contoh
// %lhs: [[
// [
// [1], [2], [5], [6]
// ],
// [
// [3], [4], [7], [8]
// ],
// [
// [10], [11], [14], [15]
// ],
// [
// [12], [13], [16], [17]
// ]
// ]]
//
// %rhs: [
// [[[1]], [[1]], [[1]]],
// [[[1]], [[1]], [[1]]],
// [[[1]], [[1]], [[1]]]
// ]
%result = "stablehlo.convolution"(%lhs, %rhs) {
window_strides = array<i64: 4, 4>,
padding = dense<0> : tensor<2x2xi64>,
lhs_dilation = array<i64: 2, 2>,
rhs_dilation = array<i64: 1, 1>,
window_reversal = array<i1: false, false>,
// In the StableHLO dialect, dimension numbers are encoded via:
// `[<input dimensions>]x[<kernel dimensions>]->[output dimensions]`.
// "b" is batch dimension, "f" is feature dimension,
// "i" is input feature dimension, "o" is output feature dimension,
// "0/1/etc" are spatial dimensions.
dimension_numbers = #stablehlo.conv<[b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f]>,
batch_group_count = 1 : i64,
feature_group_count = 1 : i64,
precision_config = [#stablehlo<precision DEFAULT>, #stablehlo<precision DEFAULT>]
} : (tensor<1x4x4x1xi64>, tensor<3x3x1x1xi64>) -> tensor<1x2x2x1xi64>
// %result: [[
// [[10], [26]],
// [[46], [62]]
// ]]
kosinus
Semantik
Melakukan operasi kosinus element-wise pada tensor operand
dan menghasilkan
Tensor result
. Bergantung pada jenis elemen, melakukan hal berikut:
- Untuk float:
cos
dari IEEE-754. - Untuk bilangan kompleks: kosinus kompleks.
- Untuk jenis terkuantisasi:
dequantize_op_quantize(cosine, operand, type(result))
.
Input
Label | Nama | Jenis | Batasan |
---|---|---|---|
(I1) | operand |
tensor floating point atau tipe kompleks atau tensor terkuantisasi per-tensor | (C1) |
Output
Nama | Jenis | Batasan |
---|---|---|
result |
tensor floating point atau tipe kompleks atau tensor terkuantisasi per-tensor | (C1) |
Batasan
- (C1)
baseline_type(operand) = baseline_type(result)
.
Contoh
// %operand: [
// [0.0, 1.57079632], // [0, pi/2]
// [3.14159265, 4.71238898] // [pi, 3pi/2]
// ]
%result = "stablehlo.cosine"(%operand) : (tensor<2x2xf32>) -> tensor<2x2xf32>
// %result: [[1.0, 0.0], [-1.0, 0.0]]
count_leading_zeros
Semantik
Melakukan penghitungan berbasis elemen dari jumlah bit nol di awal dalam operand
dan menghasilkan tensor result
.
Input
Label | Nama | Jenis | Batasan |
---|---|---|---|
(I1) | operand |
tensor tipe integer | (C1) |
Output
Nama | Jenis | Batasan |
---|---|---|
result |
tensor tipe integer | (C1) |
Batasan
- (C1)
type(operand) = type(result)
.
Contoh
// %operand: [[0, 1], [128, -1]]
%result = "stablehlo.count_leading_zeros"(%operand) : (tensor<2x2xi64>) -> tensor<2x2xi64>
// %result: [[64, 63], [56, 0]]
custom_call
Semantik
Mengenkapsulasi call_target_name
operasi yang ditentukan implementasi yang menggunakan
inputs
dan called_computations
serta menghasilkan results
. has_side_effect
,
backend_config
dan api_version
dapat digunakan untuk memberikan
metadata yang ditentukan implementasi.
Saat ini, operasi ini berisi kumpulan data yang tidak terorganisir {i>metadata <i}yang mencerminkan evolusi organik dari operasi kompilator XLA. Di masa mendatang, kami berencana untuk menyatukan metadata ini (#741).
Input
Label | Nama | Jenis |
---|---|---|
(I1) | inputs |
jumlah nilai variadis |
(I2) | call_target_name |
konstanta jenis string |
(I3) | has_side_effect |
konstanta jenis i1 |
(I4) | backend_config |
konstanta jenis string atau kamus atribut |
(I5) | api_version |
konstanta jenis si32 |
(I6) | called_computations |
jumlah konstanta variadik jenis string |
Output
Nama | Jenis |
---|---|
results |
jumlah nilai variadis |
Contoh
%results = "stablehlo.custom_call"(%input0) {
call_target_name = "foo",
has_side_effect = false,
backend_config = {bar = 42 : i32},
api_version = 4 : i32,
called_computations = [@foo]
} : (tensor<f64>) -> tensor<f64>
bagi
Semantik
Melakukan pembagian menurut elemen dari tensor lhs
dan pembagi rhs
serta
menghasilkan tensor result
. Bergantung pada jenis elemen, melakukan hal berikut:
- Untuk bilangan bulat: pembagian bilangan bulat yang menghasilkan hasil bagi aljabar dengan bagian pecahan dibuang.
- Untuk float:
division
dari IEEE-754. - Untuk bilangan kompleks: pembagian kompleks.
- Untuk jenis terkuantisasi:
dequantize_op_quantize(divide, lhs, rhs, type(result))
.
Input
Label | Nama | Jenis | Batasan |
---|---|---|---|
(I1) | lhs |
tensor integer, floating-point, atau tipe kompleks atau tensor terkuantisasi per-tensor | (C1) |
(I2) | rhs |
tensor integer, floating-point, atau tipe kompleks atau tensor terkuantisasi per-tensor | (C1) |
Output
Nama | Jenis | Batasan |
---|---|---|
result |
tensor bilangan bulat, floating point, atau tipe kompleks atau tensor terkuantisasi per-tensor | (C1) |
Batasan
- (C1)
baseline_type(lhs) = baseline_type(rhs) = baseline_type(result)
.
Contoh
// %lhs: [17.1, -17.1, 17.1, -17.1]
// %rhs: [3.0, 3.0, -3.0, -3.0]
%result = "stablehlo.divide"(%lhs, %rhs) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
// %result: [5.66666651, -5.66666651, -5.66666651, 5.66666651]
dot_general
Semantik
Menghitung produk titik di antara irisan lhs
dan irisan rhs
serta menghasilkan
Tensor result
.
Secara lebih formal, result[result_index] = dot_product
, dengan:
lhs_result_dimensions = [d for d in axes(lhs) and d not in lhs_batching_dimensions and d not in lhs_contracting_dimensions]
.rhs_result_dimensions = [d for d in axes(rhs) and d not in rhs_batching_dimensions and d not in rhs_contracting_dimensions]
.result_batching_index + result_lhs_index + result_rhs_index = result_index
dengansize(result_batching_index) = size(lhs_batching_dimensions)
,size(result_lhs_index) = size(lhs_result_dimensions)
dansize(result_rhs_index) = size(rhs_result_dimensions)
.transposed_lhs = transpose(lhs, lhs_batching_dimensions + lhs_result_dimensions + lhs_contracting_dimensions)
.transposed_lhs_slice = slice(transposed_lhs, result_batching_index + result_lhs_index + [:, ..., :])
.reshaped_lhs_slice = reshape(transposed_lhs_slice, dims(lhs, lhs_contracting_dimensions))
.transposed_rhs = transpose(rhs, rhs_batching_dimensions + rhs_result_dimensions + rhs_contracting_dimensions)
.transposed_rhs_slice = slice(transposed_rhs, result_batching_index + result_rhs_index + [:, ..., :])
.reshaped_rhs_slice = reshape(transposed_rhs_slice, dims(rhs, rhs_contracting_dimensions))
.dot_product = reduce( inputs=[multiply(reshaped_lhs_slice, reshaped_rhs_slice)], init_values=[constant(0, element_type(result))], dimensions=range(size(lhs_contracting_dimensions)), body=lambda x, y: add(x, y))
.
Untuk jenis terkuantisasi, menjalankan dequantize_op_quantize(
lambda lhs, rhs: dot_general(lhs, rhs, lhs_batching_dimensions,
rhs_batching_dimensions, lhs_contracting_dimensions,
rhs_contracting_dimensions, precision_config), lhs, rhs, type(result))
.
Untuk jenis terkuantisasi hybrid, menjalankan hybrid_dequantize_then_op(
lambda lhs, rhs: dot_general(lhs, rhs, lhs_batching_dimensions,
rhs_batching_dimensions, lhs_contracting_dimensions,
rhs_contracting_dimensions, precision_config), lhs, rhs)
.
precision_config
mengontrol keseimbangan antara kecepatan dan akurasi untuk
komputasi pada backend akselerator. Kolom ini dapat berupa salah satu dari hal berikut (di
semantik, semantik dari nilai enum ini kurang ditentukan, namun kita
berencana untuk mengatasinya dalam
#755):
DEFAULT
: Penghitungan tercepat, tetapi perkiraan yang paling tidak akurat ke nomor aslinya.HIGH
: Penghitungan lebih lambat, tetapi perkiraan yang lebih akurat terhadap nomor aslinya.HIGHEST
: Penghitungan paling lambat, tetapi perkiraan paling akurat ke nomor aslinya.
DotAlgorithm
menentukan properti utama algoritma yang digunakan untuk menerapkan
operasi titik, yang juga menentukan presisi. Jika atribut algoritma
kolom ditetapkan, maka precision_config
harus DEFAULT
. DotAlgorithms
tidak memiliki nilai default, karena parameter default adalah implementasi
didefinisikan. Dengan demikian, semua kolom algoritme titik dapat ditetapkan ke None
untuk menentukan
algoritma titik kosong, yang akan menggunakan nilai precision_config
.
Kolom DotAlgorithm
mencakup:
lhs_precision_type
danrhs_precision_type
, presisi yang ditentukan oleh LHS dan RHS operasi dibulatkan ke. Jenis presisi tidak bergantung pada jenis penyimpanan input dan output-nya.accumulation_type
presisi yang digunakan untuk akumulasi.lhs_component_count
,rhs_component_count
, dannum_primitive_operations
terapkan saat kita melakukan algoritma yang menguraikan LHS dan/atau RHS menjadi banyak komponen dan melakukan beberapa “primitif” operasi dot pada nilai - biasanya untuk mengemulasi presisi yang lebih tinggi (mis. Memanfaatkan Jenis Data Kecerdasan Buatan bfloat16 untuk Komputasi Presisi Lebih Tinggi: bf16_6x tf32_3x, dll.). Untuk algoritma tanpa dekomposisi, nilai ini harus ditetapkan ke1
.allow_imprecise_accumulation
untuk menentukan apakah akumulasi dalam presisi lebih rendah diizinkan untuk beberapa langkah (mis.CUBLASLT_MATMUL_DESC_FAST_ACCUM
).
Contoh atribut DotAlgorithm
:
// Inputs are casted to tf32, and then accumulated in f32:
{lhs_precision_type = tf32,
rhs_precision_type = tf32,
accumulation_type = f32,
lhs_component_count = 1,
rhs_component_count = 1,
num_primitive_operations = 1,
allow_imprecise_accumulation = false}
// bf16_6x: each input is decomposed to 3 bf16 components, then 6 dot operations are done on those components, and the result is accumulated in f32.
{lhs_precision_type = bf16,
rhs_precision_type = bf16,
accumulation_type = f32,
lhs_component_count = 3,
rhs_component_count = 3,
num_primitive_operations = 6,
allow_imprecise_accumulation = false}
// Inputs are (casted to) f8e5m2, and we accumulate in f32, but for some steps we may accumulate in lower precision.
{lhs_precision_type = f8e5m2,
rhs_precision_type = f8e5m2,
accumulation_type = f32,
lhs_component_count = 1,
rhs_component_count = 1,
num_primitive_operations = 1,
allow_imprecise_accumulation = true}
Implementasi bebas untuk memutuskan kombinasi mana yang didukung. Di beberapa umum, tidak ada jaminan bahwa setiap algoritma didukung pada setiap jenis akselerator oleh konsumen StableHLO. Jika algoritma tertentu tidak didukung, maka kesalahan akan dimunculkan, bukan untuk kembali ke alternatif. Verifikasi StabilHLO akan memberikan upaya verifikasi terbaik, mencegah algoritme yang tidak diketahui didukung pada perangkat keras apa pun.
Lihat xla_data.proto > Algorithm
untuk beberapa nilai algoritma yang didukung. Tiket #2483 berisi rencana untuk membuat
dokumen terpusat pada algoritma yang
didukung oleh backend.
Input
Label | Nama | Jenis | Batasan |
---|---|---|---|
(I1) | lhs |
tensor atau tensor terkuantisasi per-tensor | (C5-C6), (C9-C10), (C12-C14), (C17-C18), (C20) |
(I2) | rhs |
tensor atau quantized tensor | (C7-C10), (C12-C20) |
(I3) | lhs_batching_dimensions |
Konstanta tensor 1 dimensi jenis si64 |
(C1), (C3), (C5), (C9), (C12) |
(I4) | rhs_batching_dimensions |
Konstanta tensor 1 dimensi jenis si64 |
(C1), (C4), (C7), (C9) |
(I5) | lhs_contracting_dimensions |
Konstanta tensor 1 dimensi jenis si64 |
(C2), (C3), (C6), (C10) |
(I6) | rhs_contracting_dimensions |
Konstanta tensor 1 dimensi jenis si64 |
(C2), (C4), (C8), (C10), (C16) |
(I7) | precision_config |
jumlah variadik enum DEFAULT , HIGH , dan HIGHEST |
(C11), C21 |
(I8) | lhs_precision_type |
FloatType atau TensorFloat32 | (C21) |
(I9) | rhs_precision_type |
FloatType atau TensorFloat32 | (C21) |
(I10) | accumulation_type |
FloatType atau TensorFloat32 | (C21) |
(I11) | lhs_component_count |
konstanta jenis si32 |
(C21), C22 |
(I12) | rhs_component_count |
konstanta jenis si32 |
(C21), C23 |
(I13) | num_primitive_operations |
konstanta jenis si32 |
(C21), C24 |
(I14) | allow_imprecise_accumulation |
konstanta jenis bool |
(C21) |
Output
Nama | Jenis | Batasan |
---|---|---|
result |
tensor atau quantized tensor | (C12), (C14), (C18—C20) |
Batasan
- (C1)
size(lhs_batching_dimensions) = size(rhs_batching_dimensions)
. - (C2)
size(lhs_contracting_dimensions) = size(rhs_contracting_dimensions)
. - (C3)
is_unique(lhs_batching_dimensions + lhs_contracting_dimensions)
. - (C4)
is_unique(rhs_batching_dimensions + rhs_contracting_dimensions)
. - (C5)
0 <= lhs_batching_dimensions < rank(lhs)
. - (C6)
0 <= lhs_contracting_dimensions < rank(lhs)
. - (C7)
0 <= rhs_batching_dimensions < rank(rhs)
. - (C8)
0 <= rhs_contracting_dimensions < rank(rhs)
. - (C9)
dim(lhs, lhs_batching_dimensions...) = dim(rhs, rhs_batching_dimensions...)
. - (C10)
dim(lhs, lhs_contracting_dimensions...) = dim(rhs, rhs_contracting_dimensions...)
. - (C11)
size(precision_config) = 2
. - (C12)
shape(result) = dim(lhs, lhs_batching_dimensions) + dim(lhs, lhs_result_dimensions) + dim(rhs, rhs_result_dimensions)
. - Jika operasi tersebut menggunakan tensor non-kuantisasi:
- (C13)
element_type(lhs) = element_type(rhs)
.
- (C13)
- Jika operasi tersebut menggunakan tensor terkuantisasi:
- (C14)
is_quantized(lhs) = is_quantized(result) and is_quantized(rhs)
. - (C15)
zero_points(rhs) = 0
. - (C16) Jika
is_per_axis_quantized(rhs)
, makaquantization_dimension(rhs)
tidak ada dirhs_contracting_dimensions
. - Jika
is_quantized(lhs)
: - (C17)
storage_type(lhs) = storage_type(rhs)
. - (C18)
expressed_type(lhs) = expressed_type(rhs) = expressed_type(result)
. - (C19) Jika
is_per_tensor_quantized(rhs)
, makais_per_tensor_quantized(result)
. - Jika
!is_quantized(lhs)
: - (C20)
element_type(lhs) = expressed_type(rhs) = element_type(result)
.
- (C14)
- Jika
!is_empty_algorithm(lhs_precision_type, rhs_precision_type, accumulation_type, lhs_component_count, rhs_component_count, num_primitive_operations allow_imprecise_accumulation)
:- (C21)
precision_config... = DEFAULT
. - (C22)
0 < lhs_component_count
. - (C23)
0 < rhs_component_count
. - (C24)
0 < num_primitive_operations
.
- (C21)
Contoh
// %lhs: [
// [[1, 2],
// [3, 4]],
// [[5, 6],
// [7, 8]]
// ]
// %rhs: [
// [[1, 0],
// [0, 1]],
// [[1, 0],
// [0, 1]]
// ]
%result = "stablehlo.dot_general"(%lhs, %rhs) {
dot_dimension_numbers = #stablehlo.dot<
lhs_batching_dimensions = [0],
rhs_batching_dimensions = [0],
lhs_contracting_dimensions = [2],
rhs_contracting_dimensions = [1]
>,
precision_config = [#stablehlo<precision DEFAULT>, #stablehlo<precision DEFAULT>],
algorithm = #stablehlo.dot_algorithm<
lhs_precision_type = tf32,
rhs_precision_type = tf32,
accumulation_type = f32,
lhs_component_count = 1,
rhs_component_count = 1,
num_primitive_operations = 1,
allow_imprecise_accumulation = false
>
} : (tensor<2x2x2xi64>, tensor<2x2x2xi64>) -> tensor<2x2x2xi64>
// %result: [
// [[1, 2],
// [3, 4]],
// [[5, 6],
// [7, 8]]
// ]
dynamic_broadcast_in_dim
Semantik
Operasi ini secara fungsional identik dengan
broadcast_in_dim
berjalan, tetapi bentuk hasilnya ditentukan secara dinamis melalui output_dimensions
.
Operasi tersebut juga menerima atribut opsional known_expanding_dimensions
, known_non_expanding_dimensions
untuk mengekspresikan pengetahuan statis tentang perilaku dimensi yang meluas.
Jika tidak ditentukan, semua dimensi diasumsikan mungkin dapat diperluas.
Input
Label | Nama | Jenis | Batasan |
---|---|---|---|
(I1) | operand |
tensor atau quantized tensor | (C1-C2), (C5-C6), (C9) |
(I2) | output_dimensions |
Tensor 1-dimensi dari jenis integer | (C7) |
(I3) | broadcast_dimensions |
Tensor konstanta 1 dimensi dari jenis integer | (C2-C6) |
(I4) | known_expanding_dimensions |
Tensor konstanta 1 dimensi dari jenis integer | (C8-C9) |
(I5) | known_non_expanding_dimensions |
Tensor konstanta 1 dimensi dari jenis integer | (C8-C9) |
Output
Nama | Jenis | Batasan |
---|---|---|
result |
tensor atau quantized tensor | (C1), (C3), (C5-C7) |
Batasan
- (C1)
element_type(result)
diberikan oleh:element_type(operand)
, jika!is_per_axis_quantized(operand)
.element_type(operand)
kecualiquantization_dimension(operand)
,scales(operand)
, danzero_points(operand)
mungkin berbeda denganquantization_dimension(result)
,scales(result)
, danzero_points(result)
resp., jika tidak.
- (C2)
size(broadcast_dimensions) = rank(operand)
. - (C3)
0 <= broadcast_dimensions < rank(result)
. - (C4)
is_unique(broadcast_dimensions)
. - (C5) Untuk semua
d
diaxes(operand)
:dim(operand, d) = 1
ataudim(operand, d) = dim(result, broadcast_dimensions[d])
.
- (C6) Jika
is_per_axis_quantized(result)
:quantization_dimension(result) = broadcast_dimensions[quantization_dimension(operand)]
.- Jika
dim(operand, quantization_dimension(operand)) = 1
, makascales(result)[i] = scales(operand)[0] and zero_points(result)[i] = zero_points(operand)[0] for i in range(dim(result, quantization_dimension(result)))
.
- (C7)
size(output_dimensions) = rank(result)
. - (C8)
is_unique(known_expanding_dimensions + known_non_expanding_dimensions)
. - (C9)
0 <= known_expanding_dimensions < rank(operand)
. - (C10)
0 <= known_non_expanding_dimensions < rank(operand)
.
Contoh
// %operand: [
// [1, 2, 3]
// ]
%operand = stablehlo.constant dense<[[1, 2, 3]]> : tensor<1x3xi64>
%output_dimensions = stablehlo.constant dense<[2, 3, 2]> : tensor<3xi64>
%result = "stablehlo.dynamic_broadcast_in_dim"(%operand, %output_dimensions) {
broadcast_dimensions = array<i64: 2, 1>,
known_expanding_dimensions = array<i64: 0>,
known_non_expanding_dimensions = array<i64: 1>
} : (tensor<1x3xi64>, tensor<3xi64>) -> tensor<2x3x2xi64>
// %result: [
// [
// [1, 1],
// [2, 2],
// [3, 3]
// ],
// [
// [1, 1],
// [2, 2],
// [3, 3]
// ]
// ]
dynamic_conv
Semantik
Operasi ini secara fungsional identik dengan
konvolusi
op, tetapi padding ditentukan secara dinamis melalui padding
.
Input
Label | Nama | Jenis | Batasan |
---|---|---|---|
(I1) | lhs |
tensor atau tensor terkuantisasi per-tensor | (C1), (C10-C11), (C14) (C25), (C26-C27), (C30-C31), (C33) |
(I2) | rhs |
tensor atau quantized tensor | (C1), (C14-C16), (C26-C28), (C30-C33) |
(I3) | padding |
Tensor 2 dimensi dari tipe integer | (C4) |
(I4) | window_strides |
Konstanta tensor 1 dimensi jenis si64 |
(C2-C3) |
(I5) | lhs_dilation |
Konstanta tensor 1 dimensi jenis si64 |
(C5-C6) |
(I6) | rhs_dilation |
Konstanta tensor 1 dimensi jenis si64 |
(C7-C8) |
(I7) | window_reversal |
Konstanta tensor 1 dimensi jenis i1 |
(C9) |
(I8) | input_batch_dimension |
konstanta jenis si64 |
(C10), C13) |
(I9) | input_feature_dimension |
konstanta jenis si64 |
(C11), (C13-C14) |
(I10) | input_spatial_dimensions |
Konstanta tensor 1 dimensi jenis si64 |
(C12), C13) |
(I11) | kernel_input_feature_dimension |
konstanta jenis si64 |
(C14), (C18) |
(I12) | kernel_output_feature_dimension |
konstanta jenis si64 |
(C15-C16), (C18), (C28) |
(I13) | kernel_spatial_dimensions |
Konstanta tensor 1 dimensi jenis si64 |
(C17-C18) |
(I14) | output_batch_dimension |
konstanta jenis si64 |
(C20) |
(I15) | output_feature_dimension |
konstanta jenis si64 |
(C20), (C29) |
(I16) | output_spatial_dimensions |
Konstanta tensor 1 dimensi jenis si64 |
(C19-C20) |
(I17) | feature_group_count |
konstanta jenis si64 |
(C11), (C14), (C16), (C21), (C23) |
(I18) | batch_group_count |
konstanta jenis si64 |
(C10), (C15), (C22), (C23) |
(I19) | precision_config |
jumlah variadik enum DEFAULT , HIGH , dan HIGHEST |
(C24) |
Output
Nama | Jenis | Batasan |
---|---|---|
result |
tensor atau quantized tensor | (C25-C27), (C29), (C31-C33) |
Batasan
- (C1)
N = rank(lhs) = rank(rhs)
. - (C2)
size(window_strides) = N - 2
. - (C3)
0 < window_strides
. - (C4)
shape(padding) = [N - 2, 2]
. - (C5)
size(lhs_dilation) = N - 2
. - (C6)
0 < lhs_dilation
. - (C7)
size(rhs_dilation) = N - 2
. - (C8)
0 < rhs_dilation
. - (C9)
size(window_reversal) = N - 2
. - (C10)
dim(lhs, input_batch_dimension) % batch_group_count = 0
. - (C11)
dim(lhs, input_feature_dimension) % feature_group_count = 0
. - (C12)
size(input_spatial_dimensions) = N - 2
. - (C13) Dengan mempertimbangkan
input_dimensions = [input_batch_dimension] + input_spatial_dimensions + [input_feature_dimension]
:is_unique(input_dimensions)
.0 <= input_dimensions < N
.
- (C14)
dim(rhs, kernel_input_feature_dimension) = dim(lhs, input_feature_dimension) / feature_group_count
. - (C15)
dim(rhs, kernel_output_feature_dimension) % batch_group_count = 0
. - (C16)
dim(rhs, kernel_output_feature_dimension) % feature_group_count = 0
. - (C17)
size(kernel_spatial_dimensions) = N - 2
. - (C18) Dengan mempertimbangkan
kernel_dimensions = kernel_spatial_dimensions + [kernel_input_feature_dimension] + [kernel_output_feature_dimension]
:is_unique(kernel_dimensions)
.0 <= kernel_dimensions < N
.
- (C19)
size(output_spatial_dimensions) = N - 2
. - (C20) Dengan mempertimbangkan
output_dimensions = [output_batch_dimension] + output_spatial_dimensions + [output_feature_dimension]
:is_unique(output_dimensions)
.0 <= output_dimensions < N
.
- (C21)
0 < feature_group_count
. - (C22)
0 < batch_group_count
. - (C23)
feature_group_count = 1 or batch_group_count = 1
. - (C24)
size(precision_config) = 2
. - (C25)
dim(result, result_dim)
didefinisikan sebagai:dim(lhs, input_batch_dimension) / batch_group_count
jikaresult_dim = output_batch_dimension
.dim(rhs, kernel_output_feature_dimension)
jikaresult_dim = output_feature_dimension
.num_windows
jika tidak, jika:output_spatial_dimensions[spatial_dim] = result_dim
.lhs_dim = input_spatial_dimensions[spatial_dim]
.rhs_dim = kernel_spatial_dimensions[spatial_dim]
.dilated_input_shape[lhs_dim] = dim(lhs, lhs_dim) = 0 ? 0 : (dim(lhs, lhs_dim) - 1) * lhs_dilation[spatial_dim] + 1
.padded_input_shape[lhs_dim] = padding[spatial_dim, 0] + dilated_input_shape[lhs_dim] + padding[spatial_dim, 1]
.dilated_window_shape[lhs_dim] = dim(rhs, rhs_dim) = 0 ? 0 : (dim(rhs, rhs_dim) - 1) * rhs_dilation[spatial_dim] + 1
.is_empty_window[lhs_dim] = padded_input_shape[lhs_dim] = 0 || dilated_window_shape[lhs_dim] > padded_input_shape[lhs_dim]
.num_windows = is_empty_window[lhs_dim] ? 0 : floor((padded_input_shape[lhs_dim] - dilated_window_shape[lhs_dim]) / window_strides[spatial_dim]) + 1
.
- (C26)
rank(result) = N
. - Jika operasi tersebut menggunakan tensor non-kuantisasi:
- (C27)
element_type(lhs) = element_type(rhs) = element_type(result)
.
- (C27)
- Jika operasi tersebut menggunakan tensor terkuantisasi:
- (C28)
is_quantized(lhs) = is_quantized(result) and is_quantized(rhs)
. - (C29) Jika
is_per_axis_quantized(rhs)
, laluquantization_dimension(rhs) = kernel_output_feature_dimension
. - (C30) Jika
is_per_axis_quantized(result)
, makaquantization_dimension(result) = output_feature_dimension
. - Jika
is_quantized(lhs)
: - (C31)
storage_type(lhs) = storage_type(rhs)
. - (C32)
expressed_type(lhs) = expressed_type(rhs) = expressed_type(result)
. - (C33) Jika
is_per_tensor_quantized(rhs)
, makais_per_tensor_quantized(result)
. - Jika
!is_quantized(lhs)
: - (C34)
element_type(lhs) = expressed_type(rhs) = element_type(result)
.
- (C28)
Contoh
// %lhs: [[
// [[1], [2], [5], [6]],
// [[3], [4], [7], [8]],
// [[10], [11], [14], [15]],
// [[12], [13], [16], [17]]
// ]]
//
// %rhs: [
// [[[1]], [[1]], [[1]]],
// [[[1]], [[1]], [[1]]],
// [[[1]], [[1]], [[1]]]
// ]
// %padding: [[1, 1],
// [1, 1]]
%result = "stablehlo.dynamic_conv"(%lhs, %rhs, %padding) {
window_strides = array<i64: 4, 4>,
lhs_dilation = array<i64: 2, 2>,
rhs_dilation = array<i64: 1, 1>,
window_reversal = array<i1: false, false>,
dimension_numbers = #stablehlo.conv<raw
input_batch_dimension = 0,
input_feature_dimension = 3,
input_spatial_dimensions = [0, 1],
kernel_input_feature_dimension = 2,
kernel_output_feature_dimension = 3,
kernel_spatial_dimensions = [0, 1],
output_batch_dimension = 0,
output_feature_dimension = 3,
output_spatial_dimensions = [1, 2]
>,
feature_group_count = 1 : i64,
batch_group_count = 1 : i64,
precision_config = [#stablehlo<precision DEFAULT>, #stablehlo<precision DEFAULT>]
} : (tensor<1x4x4x1xi64>, tensor<3x3x1x1xi64>, tensor<2x2xi64>) -> tensor<1x2x2x1xi64>
// %result: [[
// [[1], [5]],
// [[10], [14]]
// ]]
dynamic_gather
Semantik
Operasi ini secara fungsional identik dengan
kumpulkan
op, dengan slice_sizes
yang ditentukan secara dinamis sebagai nilai.
Input
Label | Nama | Jenis | Batasan |
---|---|---|---|
(I1) | operand |
tensor atau tensor terkuantisasi per-tensor | (C1), (C7), (C10-C12), (C14) |
(I2) | start_indices |
tensor tipe integer | (C2), (C3), (C13) |
(I3) | slice_sizes |
Tensor 1-dimensi dari jenis integer | (C8), (C11-C13) |
(I4) | offset_dims |
Konstanta tensor 1 dimensi jenis si64 |
(C1), (C4-C5), (C13) |
(I5) | collapsed_slice_dims |
Konstanta tensor 1 dimensi jenis si64 |
(C1), (C6-C8), (C13) |
(I6) | start_index_map |
Konstanta tensor 1 dimensi jenis si64 |
(C3), (C9), (C10) |
(I7) | index_vector_dim |
konstanta jenis si64 |
(C2), (C3), (C13) |
(I8) | indices_are_sorted |
konstanta jenis i1 |
Output
Nama | Jenis | Batasan |
---|---|---|
result |
tensor atau tensor terkuantisasi per-tensor | (C5), (C13-C14) |
Batasan
- (C1)
rank(operand) = size(offset_dims) + size(collapsed_slice_dims)
. - (C2)
0 <= index_vector_dim <= rank(start_indices)
. - (C3)
size(start_index_map) = index_vector_dim < rank(start_indices) ? dim(start_indices, index_vector_dim) : 1
. - (C4)
is_unique(offset_dims) and is_sorted(offset_dims)
. - (C5)
0 <= offset_dims < rank(result)
. - (C6)
is_unique(collapsed_slice_dims) and is_sorted(collapsed_slice_dims)
. - (C7)
0 <= collapsed_slice_dims < rank(operand)
. - (C8)
slice_sizes[collapsed_slice_dims...] <= 1
. - (C9)
is_unique(start_index_map)
. - (C10)
0 <= start_index_map < rank(operand)
. - (C11)
size(slice_sizes) = rank(operand)
. - (C12)
0 <= slice_sizes <= shape(operand)
. - (C13)
shape(result) = combine(batch_dim_sizes, offset_dim_sizes)
jika:batch_dim_sizes = shape(start_indices)
, kecuali bahwa ukuran dimensi daristart_indices
yang sesuai denganindex_vector_dim
tidak disertakan.offset_dim_sizes = shape(slice_sizes)
kecuali bahwa ukuran dimensi dislice_sizes
yang sesuai dengancollapsed_slice_dims
tidak disertakan.combine
menempatkanbatch_dim_sizes
pada sumbu yang sesuai denganbatch_dims
danoffset_dim_sizes
pada sumbu yang sesuai denganoffset_dims
.
- (C14)
element_type(operand) = element_type(result)
.
Contoh
// %operand: [
// [[1, 2], [3, 4], [5, 6], [7, 8]],
// [[9, 10],[11, 12], [13, 14], [15, 16]],
// [[17, 18], [19, 20], [21, 22], [23, 24]]
// ]
// %start_indices: [
// [[0, 0], [1, 0], [2, 1]],
// [[0, 1], [1, 1], [0, 2]]
// ]
// %slize_sizes: [1, 2, 2]
%result = "stablehlo.dynamic_gather"(%operand, %start_indices, %slize_sizes) {
dimension_numbers = #stablehlo.gather<
offset_dims = [2, 3],
collapsed_slice_dims = [0],
start_index_map = [1, 0],
index_vector_dim = 2>,
indices_are_sorted = false
} : (tensor<3x4x2xi64>, tensor<2x3x2xi64>, tensor<3xi64>) -> tensor<2x3x2x2xi64>
// %result: [
// [
// [[1, 2], [3, 4]],
// [[3, 4], [5, 6]],
// [[13, 14], [15, 16]]
// ],
// [
// [[9, 10], [11, 12]],
// [[11, 12], [13, 14]],
// [[17, 18], [19, 20]]
// ]
// ]
dynamic_iota
Semantik
Operasi ini secara fungsional identik dengan
Iota
berjalan, tetapi bentuk hasilnya ditentukan secara dinamis melalui output_shape
.
Input
Label | Nama | Jenis | Batasan |
---|---|---|---|
(I1) | output_shape |
Tensor 1-dimensi dari jenis integer | (C1), C2 |
(I2) | iota_dimension |
si64 |
(C1) |
Output
Nama | Jenis | Batasan |
---|---|---|
result |
tensor bilangan bulat, floating point, atau tipe kompleks atau tensor terkuantisasi per-tensor | (C2) |
Batasan
- (C1)
0 <= iota_dimension < size(output_shape)
. - (C2)
rank(result) = size(output_shape)
.
Contoh
%output_shape = stablehlo.constant dense<[4, 5]> : tensor<2xi64>
%result = "stablehlo.dynamic_iota"(%output_shape) {
iota_dimension = 0 : i64
} : (tensor<2xi64>) -> tensor<4x5xi64>
// %result: [
// [0, 0, 0, 0, 0],
// [1, 1, 1, 1, 1],
// [2, 2, 2, 2, 2],
// [3, 3, 3, 3, 3]
// ]
dynamic_pad
Semantik
Operasi ini secara fungsional identik dengan
pad
operasi, tetapi dengan edge_padding_low
, edge_padding_high
, dan interior_padding
ditetapkan secara dinamis sebagai nilai.
Input
Label | Nama | Jenis | Batasan |
---|---|---|---|
(I1) | operand |
tensor atau tensor terkuantisasi per-tensor | (C1), (C2), (C4) |
(I2) | padding_value |
Tensor 0-dimensi atau per-tensor quantized tensor | (C1) |
(I3) | edge_padding_low |
Tensor 1-dimensi dari jenis integer | (C1), C4 |
(I4) | edge_padding_high |
Tensor 1-dimensi dari jenis integer | (C1), C4 |
(I5) | interior_padding |
Tensor 1-dimensi dari jenis integer | (C2-C4) |
Output
Nama | Jenis | Batasan |
---|---|---|
result |
tensor atau tensor terkuantisasi per-tensor | (C3-C6) |
Batasan
- (C1)
element_type(operand) = element_type(padding_value) = element_type(result)
. - (C2)
size(edge_padding_low) = size(edge_padding_high) = size(interior_padding) = rank(operand)
. - (C3)
0 <= interior_padding
. - (C4)
shape(result) = shape(operand) + edge_padding_low + max(shape(operand) - 1, 0) * interior_padding + edge_padding_high
.
Contoh
// %operand: [
// [1, 2, 3],
// [4, 5, 6]
// ]
// %padding_value: 0
// %edge_padding_low: [0, 1]
// %edge_padding_high: [2, 1]
// %interior_padding: [1, 2]
%result = "stablehlo.dynamic_pad"(%operand, %padding_value,
%edge_padding_low, %edge_padding_high, %interior_padding
) : (tensor<2x3xi64>, tensor<i64>, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor<5x9xi64>
// %result: [
// [0, 1, 0, 0, 2, 0, 0, 3, 0],
// [0, 0, 0, 0, 0, 0, 0, 0, 0],
// [0, 4, 0, 0, 5, 0, 0, 6, 0],
// [0, 0, 0, 0, 0, 0, 0, 0, 0],
// [0, 0, 0, 0, 0, 0, 0, 0, 0]
// ]
dynamic_reshape
Semantik
Operasi ini secara fungsional identik dengan
membentuk ulang
berjalan, tetapi bentuk hasilnya ditentukan secara dinamis melalui output_shape
.
Input
Label | Nama | Jenis | Batasan |
---|---|---|---|
(I1) | operand |
tensor atau quantized tensor | (C1-C3) |
(I2) | output_shape |
Tensor 1-dimensi dari jenis integer | (C4) |
Output
Nama | Jenis | Batasan |
---|---|---|
result |
tensor atau quantized tensor | (C1-C4) |
Batasan
- (C1)
element_type(result)
diberikan oleh:element_type(operand)
, jika!is_per_axis_quantized(operand)
.element_type(operand)
kecualiquantization_dimension(operand)
danquantization_dimension(result)
dapat berbeda, jika tidak.
- (C2)
size(operand) = size(result)
. - (C3) Jika
is_per_axis_quantized(operand)
:reduce(dims(operand, [0, 1, ..., quantization_dimension(operand) - 1]), init_values=1, dimensions=[0], body=lambda x, y: x * y) = reduce(dims(result, [0, 1, ..., quantization_dimension(result) - 1]), init_values=1, dimensions=[0], body=lambda x, y: x * y)
.dim(operand, quantization_dimension(operand)) = dim(result, quantization_dimension(result))
.reduce(dims(operand, [quantization_dimension(operand) + 1, ..., rank(operand) - 1]), init_values=1, dimensions=[0], body=lambda x, y: x * y) = reduce(dims(result, [quantization_dimension(result) + 1, ..., rank(result) - 1]), init_values=1, dimensions=[0], body=lambda x, y: x * y)
.
- (C4)
size(output_shape) = rank(result)
.
Contoh
// %operand: [[1, 2, 3], [4, 5, 6]]
// %output_shape: [3, 2]
%result = "stablehlo.dynamic_reshape"(%operand, %output_shape) : (tensor<2x3xi64>, tensor<2xi64>) -> tensor<3x2xi64>
// %result: [[1, 2], [3, 4], [5, 6]]
dynamic_slice
Semantik
Mengekstrak slice dari operand
menggunakan indeks awal yang dikomputasi secara dinamis
dan menghasilkan tensor result
. start_indices
berisi indeks awal dari
irisan untuk setiap dimensi yang dapat disesuaikan dengan potensi penyesuaian, dan slice_sizes
berisi ukuran irisan untuk setiap dimensi. Secara lebih formal,
result[result_index] = operand[operand_index]
dalam hal ini:
adjusted_start_indices = clamp(0, start_indices, shape(operand) - slice_sizes)
.operand_index = adjusted_start_indices + result_index
.
Input
Label | Nama | Jenis | Batasan |
---|---|---|---|
(I1) | operand |
tensor atau tensor terkuantisasi per-tensor | (C1), (C2), (C4) |
(I2) | start_indices |
jumlah variadic tensor 0-dimensi dari jenis integer | (C2), (C3) |
(I3) | slice_sizes |
Konstanta tensor 1 dimensi jenis si64 |
(C2), (C4), (C5) |
Output
Nama | Jenis | Batasan |
---|---|---|
result |
tensor atau tensor terkuantisasi per-tensor | (C1), (C5) |
Batasan
- (C1)
element_type(operand) = element_type(result)
. - (C2)
size(start_indices) = size(slice_sizes) = rank(operand)
. - (C3)
same(type(start_indices...))
. - (C4)
0 <= slice_sizes <= shape(operand)
. - (C5)
shape(result) = slice_sizes
.
Contoh
// %operand: [
// [0, 0, 1, 1],
// [0, 0, 1, 1],
// [0, 0, 0, 0],
// [0, 0, 0, 0]
// ]
// %start_indices0: -1
// %start_indices1: 3
%result = "stablehlo.dynamic_slice"(%operand, %start_indices0, %start_indices1) {
slice_sizes = array<i64: 2, 2>
} : (tensor<4x4xi32>, tensor<i64>, tensor<i64>) -> tensor<2x2xi32>
// %result: [
// [1, 1],
// [1, 1]
// ]
dynamic_update_slice
Semantik
Menghasilkan tensor result
yang sama dengan tensor operand
, kecuali bahwa
slice yang dimulai dari start_indices
akan diupdate dengan nilai di update
.
Secara lebih formal, result[result_index]
didefinisikan sebagai:
update[update_index]
jika0 <= update_index < shape(update)
di mana:adjusted_start_indices = clamp(0, start_indices, shape(operand) - shape(update))
.update_index = result_index - adjusted_start_indices
.
operand[result_index]
sebaliknya.
Input
Label | Nama | Jenis | Batasan |
---|---|---|---|
(I1) | operand |
tensor atau tensor terkuantisasi per-tensor | (C1-C4), (C6) |
(I2) | update |
tensor atau tensor terkuantisasi per-tensor | (C2), (C3), (C6) |
(I3) | start_indices |
jumlah variadic tensor 0-dimensi dari jenis integer | (C4), (C5) |
Output
Nama | Jenis | Batasan |
---|---|---|
result |
tensor atau tensor terkuantisasi per-tensor | (C1) |
Batasan
- (C1)
type(operand) = type(result)
. - (C2)
element_type(update) = element_type(operand)
. - (C3)
rank(update) = rank(operand)
. - (C4)
size(start_indices) = rank(operand)
. - (C5)
same(type(start_indices...))
. - (C6)
0 <= shape(update) <= shape(operand)
.
Contoh
// %operand: [
// [1, 1, 0, 0],
// [1, 1, 0, 0],
// [1, 1, 1, 1],
// [1, 1, 1, 1]
// ]
// %update: [
// [1, 1],
// [1, 1]
// ]
// %start_indices0: -1
// %start_indices1: 3
%result = "stablehlo.dynamic_update_slice"(%operand, %update, %start_indices0, %start_indices1)
: (tensor<4x4xi32>, tensor<2x2xi32>, tensor<i64>, tensor<i64>) -> tensor<4x4xi32>
// %result: [
// [1, 1, 1, 1],
// [1, 1, 1, 1],
// [1, 1, 1, 1],
// [1, 1, 1, 1]
// ]
berpangkat
Semantik
Melakukan operasi eksponensial setiap elemen pada tensor operand
dan menghasilkan
Tensor result
. Bergantung pada jenis elemen, melakukan hal berikut:
- Untuk float:
exp
dari IEEE-754. - Untuk bilangan kompleks: eksponensial kompleks.
- Untuk jenis terkuantisasi:
dequantize_op_quantize(exponential, operand, type(result))
.
Input
Label | Nama | Jenis | Batasan |
---|---|---|---|
(I1) | operand |
tensor floating point atau tipe kompleks atau tensor terkuantisasi per-tensor | (C1) |
Output
Nama | Jenis | Batasan |
---|---|---|
result |
tensor floating point atau tipe kompleks atau tensor terkuantisasi per-tensor | (C1) |
Batasan
- (C1)
baseline_type(operand) = baseline_type(result)
.
Contoh
// %operand: [[0.0, 1.0], [2.0, 3.0]]
%result = "stablehlo.exponential"(%operand) : (tensor<2x2xf64>) -> tensor<2x2xf64>
// %result: [[1.0, 2.7182818284590451], [7.3890560989306504, 20.085536923187668]]
exponential_minus_one
Semantik
Menjalankan eksponensial element-wise dikurangi satu operasi pada tensor operand
dan
menghasilkan tensor result
. Bergantung pada jenis elemen, melakukan hal berikut:
- Untuk float:
expm1
dari IEEE-754. - Untuk bilangan kompleks: eksponensial kompleks dikurangi satu.
- Untuk jenis terkuantisasi:
dequantize_op_quantize(exponential_minus_one, operand, type(result))
.
Input
Label | Nama | Jenis | Batasan |
---|---|---|---|
(I1) | operand |
tensor floating point atau tipe kompleks atau tensor terkuantisasi per-tensor | (C1) |
Output
Nama | Jenis | Batasan |
---|---|---|
result |
tensor floating point atau tipe kompleks atau tensor terkuantisasi per-tensor | (C1) |
Batasan
- (C1)
baseline_type(operand) = baseline_type(result)
.
Contoh
// %operand: [0.0, 1.0]
%result = "stablehlo.exponential_minus_one"(%operand) : (tensor<2xf64>) -> tensor<2xf64>
// %result: [0.0, 1.71828187]
FFT
Semantik
Melakukan transformasi Fourier maju dan terbalik untuk real dan kompleks input/output lainnya.
fft_type
adalah salah satu dari berikut ini:
FFT
: Meneruskan FFT kompleks ke kompleks.IFFT
: FFT kompleks ke kompleks terbalik.RFFT
: Meneruskan FFT real-ke-kompleks.IRFFT
: FFT real-ke-kompleks terbalik (yaitu yang kompleks, menampilkan nilai nyata).
Secara lebih formal, dengan fungsi fft
yang menggunakan tensor 1 dimensi
tipe kompleks sebagai input, menghasilkan tensor 1-dimensi dari tipe yang sama dengan
dan menghitung transformasi Fourier diskrit:
Untuk fft_type = FFT
, result
ditentukan sebagai hasil akhir deret L
komputasi di mana L = size(fft_length)
. Misalnya, untuk L = 3
:
result1[i0, ..., :] = fft(operand[i0, ..., :])
.result2[i0, ..., :, iR-1] = fft(result1[i0, ..., :, iR-1])
.result[i0, ..., :, iR-2, iR-1] = fft(result2[i0, ..., :, iR-2, iR-1])
.
Selain itu, dengan fungsi ifft
yang memiliki jenis tanda tangan dan
menghitung invers fft
:
Untuk fft_type = IFFT
, result
didefinisikan sebagai kebalikan dari komputasi
untuk fft_type = FFT
. Misalnya, untuk L = 3
:
result1[i0, ..., :, iR-2, iR-1] = ifft(operand[i0, ..., :, iR-2, iR-1])
.result2[i0, ..., :, iR-1] = ifft(result1[i0, ..., :, iR-1])
.result[i0, ..., :] = ifft(result2[i0, ..., :])
.
Selanjutnya, dengan fungsi rfft
yang menggunakan tensor 1 dimensi
tipe floating point, menghasilkan tensor 1 dimensi dari jenis
semantik floating point yang sama dan bekerja sebagai berikut:
rfft(real_operand) = truncated_result
di manacomplex_operand... = (real_operand..., 0.0)
.complex_result = fft(complex_operand)
.truncated_result = complex_result[:(rank(complex_result) / 2 + 1)]
.
(Ketika transformasi Fourier diskrit dihitung untuk operand riil, yang pertama
Elemen N/2 + 1
hasil secara jelas menentukan sisa hasil,
sehingga hasil rfft
terpotong untuk menghindari komputasi elemen yang redundan).
Untuk fft_type = RFFT
, result
ditentukan sebagai hasil akhir deret L
komputasi di mana L = size(fft_length)
. Misalnya, untuk L = 3
:
result1[i0, ..., :] = rfft(operand[i0, ..., :])
.result2[i0, ..., :, iR-1] = fft(result1[i0, ..., :, iR-1])
.result[i0, ..., :, iR-2, iR-1] = fft(result2[i0, ..., :, iR-2, iR-1])
.
Terakhir, dengan fungsi irfft
yang memiliki
tanda tangan dan tanda tangan jenis yang sama
menghitung invers rfft
:
Untuk fft_type = IRFFT
, result
didefinisikan sebagai kebalikan dari komputasi
untuk fft_type = RFFT
. Misalnya, untuk L = 3
:
result1[i0, ..., :, iR-2, iR-1] = ifft(operand[i0, ..., :, iR-2, iR-1])
.result2[i0, ..., :, iR-1] = ifft(result1[i0, ..., :, iR-1])
.result[i0, ..., :] = irfft(result2[i0, ..., :])
.
Input
Label | Nama | Jenis | Batasan |
---|---|---|---|
(I1) | operand |
tensor floating point atau jenis kompleks | (C1), (C2), (C4), (C5) |
(I2) | fft_type |
enum FFT , IFFT , RFFT , dan IRFFT |
(C2), (C5) |
(I3) | fft_length |
Konstanta tensor 1 dimensi jenis si64 |
(C1), (C3), (C4) |
Output
Nama | Jenis | Batasan |
---|---|---|
result |
tensor floating point atau jenis kompleks | (C2), (C4), (C5) |
Batasan
- (C1)
size(fft_length) <= rank(operand)
. - (C2) Hubungan antara jenis elemen
operand
danresult
bervariasi:- Jika
fft_type = FFT
,element_type(operand)
, danelement_type(result)
memiliki jenis kompleks yang sama. - Jika
fft_type = IFFT
,element_type(operand)
, danelement_type(result)
memiliki jenis kompleks yang sama. - Jika
fft_type = RFFT
,element_type(operand)
adalah jenis floating point danelement_type(result)
adalah jenis kompleks dari floating point yang sama semantik. - Jika
fft_type = IRFFT
,element_type(operand)
adalah jenis kompleks danelement_type(result)
adalah jenis floating point dari floating point yang sama semantik.
- Jika
- (C3)
1 <= size(fft_length) <= 3
. - (C4) Jika di antara
operand
danresult
, ada tensorreal
dari jenis floating point, kemudianshape(real)[-size(fft_length):] = fft_length
. - (C5)
shape(result) = shape(operand)
kecuali untuk:- Jika
fft_type = RFFT
,dim(result, -1) = dim(operand, -1) = 0 ? 0 : dim(operand, -1) / 2 + 1
. - Jika
fft_type = IRFFT
,dim(operand, -1) = dim(result, -1) = 0 ? 0 : dim(result, -1) / 2 + 1
.
- Jika
Contoh
// %operand: [(1.0, 0.0), (0.0, 0.0), (0.0, 0.0), (0.0, 0.0)]
%result = "stablehlo.fft"(%operand) {
fft_type = #stablehlo<fft_type FFT>,
fft_length = array<i64: 4>
} : (tensor<4xcomplex<f32>>) -> tensor<4xcomplex<f32>>
// %result: [(1.0, 0.0), (1.0, 0.0), (1.0, 0.0), (1.0, 0.0)]
floor
Semantik
Melakukan nilai minimum berdasarkan elemen pada tensor operand
dan menghasilkan tensor result
.
Menerapkan operasi roundToIntegralTowardNegative
dari IEEE-754
spesifikasi
pendukung. Untuk jenis terkuantisasi, menjalankan
dequantize_op_quantize(floor, operand, type(result))
.
Input
Label | Nama | Jenis | Batasan |
---|---|---|---|
(I1) | operand |
tensor tipe floating point atau tensor terkuantisasi per-tensor | (C1) |
Output
Nama | Jenis | Batasan |
---|---|---|
result |
tensor tipe floating point atau tensor terkuantisasi per-tensor | (C1) |
Batasan
- (C1)
baseline_type(operand) = baseline_type(result)
.
Contoh
// %operand: [-0.8166, -0.2530, 0.2530, 0.8166, 2.0]
%result = "stablehlo.floor"(%operand) : (tensor<5xf32>) -> tensor<5xf32>
// %result: [-1.0, -1.0, 0.0, 0.0, 2.0]
kumpulkan
Semantik
Mengumpulkan slice dari tensor operand
dari offset yang ditentukan dalam start_indices
dan menghasilkan tensor result
.
Diagram berikut menunjukkan cara elemen di result
dipetakan pada elemen di
operand
menggunakan contoh konkret. Diagram ini memilih beberapa contoh result
indeks dan menjelaskan secara mendetail indeks operand
mana yang sesuai dengannya.
Secara lebih formal, result[result_index] = operand[operand_index]
dengan:
batch_dims = [d for d in axes(result) and d not in offset_dims]
.batch_index = result_index[batch_dims...]
.start_index
ditentukan sebagai:start_indices[bi0, ..., :, ..., biN]
denganbi
merupakan elemen individual dibatch_index
dan:
dimasukkan pada indeksindex_vector_dim
, jikaindex_vector_dim
rank(start_indices)
.[start_indices[batch_index]]
sebaliknya.
- Untuk
d_operand
diaxes(operand)
,full_start_index[d_operand] = clamp(start_index[d_start], 0, dim(operand, d_operand) - slice_sizes[d_operand])
jikad_operand = start_index_map[d_start]
.full_start_index[d_operand] = 0
sebaliknya.
- Untuk
d_operand
diaxes(operand)
,full_batching_index[d_operand] = batch_index[d_start - (d_start < index_vector_dim ? 0 : 1)]
jikad_operand = operand_batching_dims[i_batching]
dand_start = start_indices_batching_dims[i_batching]
.full_batching_index[d_operand] = 0
sebaliknya.
offset_index = result_index[offset_dims...]
.full_offset_index = [oi0, ..., 0, ..., oiN]
denganoi
merupakan individu dioffset_index
, dan0
disisipkan pada indeks daricollapsed_slice_dims
danoperand_batching_dims
.operand_index = full_start_index + full_batching_index + full_offset_index
.
Jika indices_are_sorted
adalah true
, implementasi dapat mengasumsikan bahwa
start_indices
diurutkan sehubungan dengan start_index_map
, jika tidak,
perilakunya tidak terdefinisi. Secara lebih formal, untuk semua i1 < i2
dari indices(result)
,
full_start_index(i1) <= full_start_index(i2)
.
Input
Label | Nama | Jenis | Batasan |
---|---|---|---|
(I1) | operand |
tensor atau tensor terkuantisasi per-tensor | (C1), (C8), (C11), (C17), (C19-C21), (C23) |
(I2) | start_indices |
tensor tipe integer | (C2-C3), (C14), (C17), (C22) |
(I3) | offset_dims |
Konstanta tensor 1 dimensi jenis si64 |
(C1), (C4-C5), (C22) |
(I4) | collapsed_slice_dims |
Konstanta tensor 1 dimensi jenis si64 |
(C1), (C6-C9), (C22) |
(I5) | operand_batching_dims |
Konstanta tensor 1 dimensi jenis si64 |
(C1), (C6), (C10-C12), (C16-C18), (C22) |
(I6) | start_indices_batching_dims |
Konstanta tensor 1 dimensi jenis si64 |
(C13-C17) |
(I7) | start_index_map |
Konstanta tensor 1 dimensi jenis si64 |
(C3), (C18-C19) |
(I8) | index_vector_dim |
konstanta jenis si64 |
(C2-C3), (C15), (C22) |
(I9) | slice_sizes |
Konstanta tensor 1 dimensi jenis si64 |
(C9), (C12), (C20-C22) |
(I10) | indices_are_sorted |
konstanta jenis i1 |
Output
Nama | Jenis | Batasan |
---|---|---|
result |
tensor atau tensor terkuantisasi per-tensor | (C5), (C22-C23) |
Batasan
- (C1)
rank(operand) = size(offset_dims) + size(collapsed_slice_dims) + size(operand_batching_dims)
. - (C2)
0 <= index_vector_dim <= rank(start_indices)
. - (C3)
size(start_index_map) = index_vector_dim < rank(start_indices) ? dim(start_indices, index_vector_dim) : 1
. - (C4)
is_unique(offset_dims) and is_sorted(offset_dims)
. - (C5)
0 <= offset_dims < rank(result)
. - (C6)
is_unique(concatenate(collapsed_slice_dims, operand_batching_dims))
- (C7)
is_sorted(collapsed_slice_dims)
. - (C8)
0 <= collapsed_slice_dims < rank(operand)
. - (C9)
slice_sizes[collapsed_slice_dims...] <= 1
. - (C10)
is_sorted(operand_batching_dims)
. - (C11)
0 <= operand_batching_dims < rank(operand)
. - (C12)
slice_sizes[operand_batching_dims...] <= 1
. - (C13)
is_unique(start_indices_batching_dims)
. - (C14)
0 <= start_indices_batching_dims < rank(start_indices)
. - (C15)
index_vector_dim not in start_indices_batching_dims
. - (C16)
size(operand_batching_dims) == size(start_indices_batching_dims)
. - (C17)
dim(operand, operand_batching_dims...) = dim(start_indices, start_indices_batching_dims...)
. - (C18)
is_unique(concatenate(start_index_map, operand_batching_dims))
. - (C19)
0 <= start_index_map < rank(operand)
. - (C20)
size(slice_sizes) = rank(operand)
. - (C21)
0 <= slice_sizes <= shape(operand)
. - (C22)
shape(result) = combine(batch_dim_sizes, offset_dim_sizes)
jika:batch_dim_sizes = shape(start_indices)
, kecuali bahwa ukuran dimensi daristart_indices
yang sesuai denganindex_vector_dim
tidak disertakan.offset_dim_sizes = slice_sizes
kecuali bahwa ukuran dimensi dislice_sizes
yang sesuai dengancollapsed_slice_dims
danoperand_batching_dims
tidak termasuk.combine
menempatkanbatch_dim_sizes
pada sumbu yang sesuai denganbatch_dims
danoffset_dim_sizes
pada sumbu yang sesuai denganoffset_dims
.
- (C23)
element_type(operand) = element_type(result)
.
Contoh
// %operand: [
// [
// [[1, 2], [3, 4], [5, 6], [7, 8]],
// [[9, 10],[11, 12], [13, 14], [15, 16]],
// [[17, 18], [19, 20], [21, 22], [23, 24]]
// ],
// [
// [[25, 26], [27, 28], [29, 30], [31, 32]],
// [[33, 34], [35, 36], [37, 38], [39, 40]],
// [[41, 42], [43, 44], [45, 46], [47, 48]]
// ]
// ]
// %start_indices: [
// [
// [[0, 0], [1, 0], [2, 1]],
// [[0, 1], [1, 1], [0, 9]]
// ],
// [
// [[0, 0], [2, 1], [2, 2]],
// [[1, 2], [0, 1], [1, 0]]
// ]
// ]
%result = "stablehlo.gather"(%operand, %start_indices) {
dimension_numbers = #stablehlo.gather<
offset_dims = [3, 4],
collapsed_slice_dims = [1],
operand_batching_dims = [0],
start_indices_batching_dims = [1],
start_index_map = [2, 1],
index_vector_dim = 3>,
slice_sizes = array<i64: 1, 1, 2, 2>,
indices_are_sorted = false
} : (tensor<2x3x4x2xi32>, tensor<2x2x3x2xi64>) -> tensor<2x2x3x2x2xi32>
// %result: [
// [
// [
// [[1, 2], [3, 4]],
// [[3, 4], [5, 6]],
// [[13, 14], [15, 16]]
// ],
// [
// [[33, 34], [35, 36]],
// [[35, 36], [37, 38]],
// [[41, 42], [43, 44]]
// ]
// ],
// [
// [
// [[1, 2], [3, 4]],
// [[13, 14], [15, 16]],
// [[21, 22], [23, 24]]
// ],
// [
// [[43, 44], [45, 46]],
// [[33, 34], [35, 36]],
// [[27, 28], [29, 30]]
// ]
// ]
// ]
get_dimension_size
Semantik
Menghasilkan ukuran dimension
yang ditentukan dari operand
. Secara lebih formal,
result = dim(operand, dimension)
. Semantik hanya memperhitungkan bentuk
komponen jenis ini. Tipe elemen bisa berupa apa saja.
Input
Label | Nama | Jenis | Batasan |
---|---|---|---|
(I1) | operand |
tensor atau quantized tensor | (C1) |
(I2) | dimension |
konstanta jenis si64 |
(C1) |
Output
Nama | Jenis |
---|---|
result |
Tensor 0-dimensi jenis si32 |
Batasan
- (C1)
0 <= dimension < rank(operand)
.
Contoh
// %operand: [[1, 2, 3], [4, 5, 6]]
%result = "stablehlo.get_dimension_size"(%operand) {
dimension = 1 : i64
} : (tensor<2x3xi64>) -> tensor<i32>
// %result: 3
get_tuple_element
Semantik
Mengekstrak elemen pada posisi index
tuple operand
dan menghasilkan
result
. Secara lebih formal, result = operand[index]
.
Input
Label | Nama | Jenis | Batasan |
---|---|---|---|
(I1) | operand |
tuple | (C1), C2 |
(I2) | index |
konstanta jenis si32 |
(C1), C2 |
Output
Nama | Jenis | Batasan |
---|---|---|
result |
semua jenis yang didukung | (C2) |
Batasan
- (C1)
0 <= index < size(operand)
. - (C2)
type(result) = tuple_element_types(operand)[index]
.
Contoh
// %operand: ([1.0, 2.0], (3))
index = 0 : i32
} : (tuple<tensor<2xf32>, tuple<tensor<i32>>>) -> tensor<2xf32>
// %result: [1.0, 2.0]
jika
Semantik
Menghasilkan output dari mengeksekusi tepat satu fungsi dari true_branch
atau
false_branch
bergantung pada nilai pred
. Secara lebih formal, result =
pred ? true_branch() : false_branch()
.
Input
Label | Nama | Jenis | Batasan |
---|---|---|---|
(I1) | pred |
Tensor 0-dimensi jenis i1 |
|
(I2) | true_branch |
fungsi | (C1-C3) |
(I3) | false_branch |
fungsi | (C1), C2 |
Output
Nama | Jenis | Batasan |
---|---|---|
results |
jumlah tensor, tensor atau token terkuantisasi | (C3) |
Batasan
- (C1)
input_types(true_branch) = input_types(false_branch) = []
. - (C2)
output_types(true_branch) = output_types(false_branch)
. - (C3)
type(results...) = output_types(true_branch)
.
Contoh
// %result_true_branch: 10
// %result_false_branch: 11
// %pred: true
%result = "stablehlo.if"(%pred) ({
"stablehlo.return"(%result_true_branch) : (tensor<i32>) -> ()
}, {
"stablehlo.return"(%result_false_branch) : (tensor<i32>) -> ()
}) : (tensor<i1>) -> tensor<i32>
// %result: 10
gambar
Semantik
Mengekstrak bagian imajiner, berdasarkan elemen, dari operand
dan menghasilkan
Tensor result
. Secara lebih formal, untuk setiap elemen x
:
imag(x) = is_complex(x) ? imaginary_part(x) :
constant(0, element_type(result))
.
Input
Label | Nama | Jenis | Batasan |
---|---|---|---|
(I1) | operand |
tensor floating point atau jenis kompleks | (C1), C2 |
Output
Nama | Jenis | Batasan |
---|---|---|
result |
tensor jenis floating point | (C1), C2 |
Batasan
- (C1)
shape(result) = shape(operand)
. - (C2)
element_type(result)
didefinisikan sebagai:complex_element_type(element_type(operand))
jikais_complex(operand)
.element_type(operand)
sebaliknya.
Contoh
// %operand: [(1.0, 2.0), (3.0, 4.0)]
%result = "stablehlo.imag"(%operand) : (tensor<2xcomplex<f32>>) -> tensor<2xf32>
// %result: [2.0, 4.0]
dalam feed
Semantik
Membaca data dari dalam feed dan menghasilkan results
.
Semantik infeed_config
adalah ditentukan oleh implementasi.
results
terdiri dari nilai payload yang muncul terlebih dahulu dan token yang muncul
terakhir. Di masa mendatang, kami berencana untuk membagi payload dan token menjadi dua
{i>output<i} terpisah untuk
meningkatkan kejelasan
(#670).
Input
Label | Nama | Jenis |
---|---|---|
(I1) | token |
token |
(I2) | infeed_config |
konstanta jenis string |
Output
Nama | Jenis | Batasan |
---|---|---|
results |
jumlah tensor, tensor atau token terkuantisasi | (C1-C3) |
Batasan
- (C1)
0 < size(results)
. - (C2)
is_empty(result[:-1])
atauis_tensor(type(results[:-1]))
. - (C3)
is_token(type(results[-1]))
.
Contoh
// %token: !stablehlo.token
// infeed_queue[0]: [[1, 2], [3, 4]]
// infeed_queue[1]: [[5, 6], [7, 8]]
%results0:2 = "stablehlo.infeed"(%token) {
infeed_config = ""
} : (!stablehlo.token) -> (tensor<2x2xi64>, !stablehlo.token)
// results0#0: [[1, 2], [3, 4]]
%results1:2 = "stablehlo.infeed"(%token) {
infeed_config = ""
} : (!stablehlo.token) -> (tensor<2x2xi64>, !stablehlo.token)
// results1#0: [[5, 6], [7, 8]]
Iota
Semantik
Mengisi tensor output
dengan nilai dalam urutan yang meningkat mulai dari nol
di sepanjang dimensi iota_dimension
. Secara lebih formal,
output[output_index] = constant(is_quantized(output) ?
quantize(output_index[iota_dimension], element_type(output)) :
output_index[iota_dimension], element_type(output))
.
Input
Label | Nama | Jenis | Batasan |
---|---|---|---|
(I1) | iota_dimension |
si64 |
(C1) |
Output
Nama | Jenis | Batasan |
---|---|---|
output |
tensor bilangan bulat, floating point, atau tipe kompleks atau tensor terkuantisasi per-tensor | (C1) |
Batasan
- (C1)
0 <= iota_dimension < rank(output)
.
Contoh
%output = "stablehlo.iota"() {
iota_dimension = 0 : i64
} : () -> tensor<4x5xi32>
// %output: [
// [0, 0, 0, 0, 0],
// [1, 1, 1, 1, 1],
// [2, 2, 2, 2, 2],
// [3, 3, 3, 3, 3]
// ]
%output = "stablehlo.iota"() {
iota_dimension = 1 : i64
} : () -> tensor<4x5xi32>
// %output: [
// [0, 1, 2, 3, 4],
// [0, 1, 2, 3, 4],
// [0, 1, 2, 3, 4],
// [0, 1, 2, 3, 4]
// ]
is_finite
Semantik
Melakukan pemeriksaan berdasarkan elemen apakah nilai dalam x
terbatas (yaitu bukan
+Inf, -Inf, atau NaN) dan menghasilkan tensor y
. Mengimplementasikan isFinite
operasi dari spesifikasi IEEE-754. Untuk jenis terkuantisasi, hasilnya adalah
selalu true
.
Input
Label | Nama | Jenis | Batasan |
---|---|---|---|
(I1) | x |
tensor tipe floating point atau tensor terkuantisasi per-tensor | (C1) |
Output
Nama | Jenis | Batasan |
---|---|---|
y |
tensor jenis boolean | (C1) |
Batasan
- (C1)
shape(x) = shape(y)
.
Contoh
// Logical values: -Inf, +Inf, NaN, ...
// %x: [0xFFF0000000000000, 0x7FF0000000000000, 0x7FF8000000000000, -10.0, -0.0, 0.0, 10.0]
%y = "stablehlo.is_finite"(%x) : (tensor<7xf64) -> tensor<7xi1>
// %y: [false, false, false, true, true, true, true]
log
Semantik
Melakukan operasi logaritma element-wise pada tensor operand
dan menghasilkan
Tensor result
. Bergantung pada jenis elemen, melakukan hal berikut:
- Untuk float:
log
dari IEEE-754. - Untuk bilangan kompleks: logaritma kompleks.
- Untuk jenis terkuantisasi:
dequantize_op_quantize(log, operand, type(result))
.
Input
Label | Nama | Jenis | Batasan |
---|---|---|---|
(I1) | operand |
tensor floating point atau tipe kompleks atau tensor terkuantisasi per-tensor | (C1) |
Output
Nama | Jenis | Batasan |
---|---|---|
result |
tensor floating point atau tipe kompleks atau tensor terkuantisasi per-tensor | (C1) |
Batasan
- (C1)
baseline_type(operand) = baseline_type(result)
.
Contoh
// %operand: [[1.0, 2.0], [3.0, 4.0]]
%result = "stablehlo.log"(%operand) : (tensor<2x2xf64>) -> tensor<2x2xf64>
// %result: [[0.0, 0.69314718055994529], [1.0986122886681098, 1.3862943611198906]]
log_plus_one
Semantik
Melakukan logaritma element-wise ditambah satu operasi pada tensor operand
dan
menghasilkan tensor result
. Bergantung pada jenis elemen, melakukan hal berikut:
- Untuk float:
logp1
dari IEEE-754. - Untuk bilangan kompleks: logaritma kompleks ditambah satu.
- Untuk jenis terkuantisasi:
dequantize_op_quantize(log_plus_one, operand, type(result))
.
Input
Label | Nama | Jenis | Batasan |
---|---|---|---|
(I1) | operand |
tensor floating point atau tipe kompleks atau tensor terkuantisasi per-tensor | (C1) |
Output
Nama | Jenis | Batasan |
---|---|---|
result |
tensor floating point atau tipe kompleks atau tensor terkuantisasi per-tensor | (C1) |
Batasan
- (C1)
baseline_type(operand) = baseline_type(result)
.
Contoh
// %operand: [0.0, -0.999, 7.0, 6.38905621, 15.0]
%result = "stablehlo.log_plus_one"(%operand) : (tensor<5xf64>) -> tensor<5xf64>
// %result: [0.0, -6.90776825, 2.07944155, 2.0, 2.77258873]
logistik
Semantik
Melakukan operasi logistik berdasarkan elemen pada tensor operand
dan menghasilkan
Tensor result
. Bergantung pada jenis elemen, melakukan hal berikut:
- Untuk float:
division(1, addition(1, exp(-x)))
dari IEEE-754. - Untuk bilangan kompleks: logistik kompleks.
- Untuk jenis terkuantisasi:
dequantize_op_quantize(logistic, operand, type(result))
.
Input
Label | Nama | Jenis | Batasan |
---|---|---|---|
(I1) | operand |
tensor floating point atau tipe kompleks atau tensor terkuantisasi per-tensor | (C1) |
Output
Nama | Jenis | Batasan |
---|---|---|
result |
tensor floating point atau tipe kompleks atau tensor terkuantisasi per-tensor | (C1) |
Batasan
- (C1)
baseline_type(operand) = baseline_type(result)
.
Contoh
// %operand: [[0.0, 1.0], [2.0, 3.0]]
%result = "stablehlo.logistic"(%operand) : (tensor<2x2xf64>) -> tensor<2x2xf64>
// %result: [[0.5, 0.73105858], [0.88079708, 0.95257413]]
peta
Semantik
Menerapkan fungsi peta computation
ke inputs
di sepanjang dimensions
dan
menghasilkan tensor result
.
Secara lebih formal, result[result_index] = computation(inputs...[result_index])
.
Input
Label | Nama | Jenis | Batasan |
---|---|---|---|
(I1) | inputs |
jumlah tensor atau tensor terkuantisasi per-tensor | (C1-C4) |
(I2) | dimensions |
Konstanta tensor 1 dimensi jenis si64 |
(C3) |
(I3) | computation |
fungsi | (C4) |
Output
Nama | Jenis | Batasan |
---|---|---|
result |
tensor atau tensor terkuantisasi per-tensor | (C1), C4 |
Batasan
- (C1)
shape(inputs...) = shape(result)
. - (C2)
0 < size(inputs) = N
. - (C3)
dimensions = range(rank(inputs[0]))
. - (C4)
computation
memiliki jenis(tensor<E0>, ..., tensor<EN-1>) -> tensor<E'>
denganEi = element_type(inputs[i])
danE' = element_type(result)
.
Contoh
// %input0: [[0, 1], [2, 3]]
// %input1: [[4, 5], [6, 7]]
%result = "stablehlo.map"(%input0, %input1) ({
^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
%0 = stablehlo.multiply %arg0, %arg1 : tensor<i64>
stablehlo.return %0 : tensor<i64>
}) {
dimensions = array<i64: 0, 1>
} : (tensor<2x2xi64>, tensor<2x2xi64>) -> tensor<2x2xi64>
// %result: [[0, 5], [12, 21]]
maksimum
Semantik
Melakukan operasi maksimum berbasis elemen pada tensor lhs
dan rhs
serta menghasilkan
Tensor result
. Bergantung pada jenis elemen, melakukan hal berikut:
- Untuk boolean: OR logis.
- Untuk bilangan bulat: maksimum bilangan bulat.
- Untuk float:
maximum
dari IEEE-754. - Untuk bilangan kompleks: maksimum leksikografis untuk pasangan
(real, imaginary)
. Memaksakan urutan pada bilangan kompleks melibatkan semantik yang mengejutkan, jadi di masa mendatang kami berencana menghapus dukungan untuk bilangan kompleks untuk operasi ini (#560). - Untuk jenis terkuantisasi:
dequantize_op_quantize(maximum, lhs, rhs, type(result))
.
Input
Label | Nama | Jenis | Batasan |
---|---|---|---|
(I1) | lhs |
tensor atau tensor terkuantisasi per-tensor | (C1) |
(I2) | rhs |
tensor atau tensor terkuantisasi per-tensor | (C1) |
Output
Nama | Jenis | Batasan |
---|---|---|
result |
tensor atau tensor terkuantisasi per-tensor | (C1) |
Batasan
- (C1)
baseline_type(lhs) = baseline_type(rhs) = baseline_type(result)
.
Contoh
// %lhs: [[1, 2], [7, 8]]
// %rhs: [[5, 6], [3, 4]]
%result = "stablehlo.maximum"(%lhs, %rhs) : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
// %result: [[5, 6], [7, 8]]
minimum
Semantik
Melakukan operasi min menurut elemen pada tensor lhs
dan rhs
, serta menghasilkan
Tensor result
. Bergantung pada jenis elemen, melakukan hal berikut:
- Untuk boolean: logika AND.
- Untuk bilangan bulat: minimum bilangan bulat.
- Untuk float:
minimum
dari IEEE-754. - Untuk bilangan kompleks: minimum leksikografis untuk pasangan
(real, imaginary)
. Memaksakan urutan pada bilangan kompleks melibatkan semantik yang mengejutkan, jadi di masa mendatang kami berencana menghapus dukungan untuk bilangan kompleks untuk operasi ini (#560). - Untuk jenis terkuantisasi:
dequantize_op_quantize(minimum, lhs, rhs, type(result))
.
Input
Label | Nama | Jenis | Batasan |
---|---|---|---|
(I1) | lhs |
tensor atau tensor terkuantisasi per-tensor | (C1) |
(I2) | rhs |
tensor atau tensor terkuantisasi per-tensor | (C1) |
Output
Nama | Jenis | Batasan |
---|---|---|
result |
tensor atau tensor terkuantisasi per-tensor | (C1) |
Batasan
- (C1)
baseline_type(lhs) = baseline_type(rhs) = baseline_type(result)
.
Contoh
// %lhs: [[1, 2], [7, 8]]
// %rhs: [[5, 6], [3, 4]]
%result = "stablehlo.minimum"(%lhs, %rhs) : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
// %result: [[1, 2], [3, 4]]
memperbanyak
Semantik
Menjalankan perkalian berbasis elemen dari dua tensor lhs
dan rhs
serta menghasilkan
Tensor result
. Bergantung pada jenis elemen, melakukan hal berikut:
- Untuk boolean: logika AND.
- Untuk bilangan bulat: perkalian bilangan bulat.
- Untuk float:
multiplication
dari IEEE-754. - Untuk bilangan kompleks: perkalian kompleks.
- Untuk jenis terkuantisasi:
dequantize_op_quantize(multiply, lhs, rhs, type(result))
.
Input
Label | Nama | Jenis | Batasan |
---|---|---|---|
(I1) | lhs |
tensor atau tensor terkuantisasi per-tensor | (C1) |
(I2) | rhs |
tensor atau tensor terkuantisasi per-tensor | (C1) |
Output
Nama | Jenis | Batasan |
---|---|---|
result |
tensor atau tensor terkuantisasi per-tensor | (C1) |
Batasan
- (C1)
baseline_type(operand) = baseline_type(result)
.
Contoh
// %lhs: [[1, 2], [3, 4]]
// %rhs: [[5, 6], [7, 8]]
%result = "stablehlo.multiply"(%lhs, %rhs) : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
// %result: [[5, 12], [21, 32]]
negasi
Semantik
Melakukan negasi berdasarkan elemen dari tensor operand
dan menghasilkan result
tensor. Bergantung pada jenis elemen, melakukan hal berikut:
- Untuk bilangan bulat bertanda tangan: negasi bilangan bulat.
- Untuk bilangan bulat yang tidak ditandatangani: bitcast ke integer bertanda tangan, negasi bilangan bulat, bitcast kembali ke bilangan bulat yang tidak ditandatangani.
- Untuk float:
negate
dari IEEE-754. - Untuk bilangan kompleks: negasi kompleks.
- Untuk jenis terkuantisasi:
dequantize_op_quantize(negate, operand, type(result))
.
Input
Label | Nama | Jenis | Batasan |
---|---|---|---|
(I1) | operand |
tensor bilangan bulat, floating point, atau tipe kompleks atau tensor terkuantisasi per-tensor | (C1) |
Output
Nama | Jenis | Batasan |
---|---|---|
result |
tensor bilangan bulat, floating point, atau tipe kompleks atau tensor terkuantisasi per-tensor | (C1) |
Batasan
- (C1)
baseline_type(operand) = baseline_type(result)
.
Contoh
// Negation operation with integer Tensors
// %operand: [0, -2]
%result = "stablehlo.negate"(%operand) : (tensor<2xi32>) -> tensor<2xi32>
// %result: [0, 2]
// Negation operation with with complex tensors
// %operand: (2.5, 0.0)
%result = "stablehlo.negate"(%operand) : (tensor<1xcomplex<f32>>) -> tensor<1xcomplex<f32>>
// %result: [-2.5, -0.0]
tidak
Semantik
Melakukan NOT tanpa elemen pada tensor operand
dan menghasilkan tensor result
.
Bergantung pada jenis elemen, melakukan hal berikut:
- Untuk boolean: logical NOT.
- Untuk bilangan bulat: bitwise NOT.
Argumen
Nama | Jenis | Batasan |
---|---|---|
operand |
tensor jenis boolean atau bilangan bulat | (C1) |
Output
Nama | Jenis | Batasan |
---|---|---|
result |
tensor jenis boolean atau bilangan bulat | (C1) |
Batasan
- (C1)
type(operand) = type(result)
.
Contoh
// Bitwise operation with with integer tensors
// %operand: [[1, 2], [3, 4]]
%result = "stablehlo.not"(%operand) : (tensor<2x2xi32>) -> tensor<2x2xi32>
// %result: [[-2, -3], [-4, -5]]
// Bitwise operation with with boolean tensors
// %operand: [true, false]
%result = "stablehlo.not"(%operand) : (tensor<2xi1>) -> tensor<2xi1>
// %result: [false, true]
optimization_barrier
Semantik
Memastikan operasi yang menghasilkan operand
dijalankan sebelum
operasi yang bergantung pada result
dan mencegah transformasi compiler
agar tidak memindahkan operasi melintasi penghalang. Selain itu, operasinya
identitas, yaitu result = operand
.
Argumen
Nama | Jenis | Batasan |
---|---|---|
operand |
jumlah tensor, tensor atau token terkuantisasi per-tensor | (C1) |
Output
Nama | Jenis | Batasan |
---|---|---|
result |
jumlah tensor, tensor atau token terkuantisasi per-tensor | (C1) |
Batasan
- (C1)
type(operand...) = type(result...)
.
Contoh
// %operand0: 0.0
// %operand1: 1.0
%result0, %result1 = "stablehlo.optimization_barrier"(%operand0, %operand1) : (tensor<f32>, tensor<f32>) -> (tensor<f32>, tensor<f32>)
// %result0: 0.0
// %result1: 1.0
atau
Semantik
Melakukan OR berdasarkan elemen dari dua tensor lhs
dan rhs
serta menghasilkan result
tensor. Bergantung pada jenis elemen, melakukan hal berikut:
- Untuk boolean: OR logis.
- Untuk bilangan bulat: bitwise OR.
Input
Label | Nama | Jenis | Batasan |
---|---|---|---|
(I1) | lhs |
tensor integer atau tipe boolean | (C1) |
(I2) | rhs |
tensor integer atau tipe boolean | (C1) |
Output
Nama | Jenis | Batasan |
---|---|---|
result |
tensor integer atau tipe boolean | (C1) |
Batasan
- (C1)
type(lhs) = type(rhs) = type(result)
.
Contoh
// Bitwise operation with with integer tensors
// %lhs: [[1, 2], [3, 4]]
// %rhs: [[5, 6], [7, 8]]
%result = "stablehlo.or"(%lhs, %rhs) : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
// %result: [[5, 6], [7, 12]]
// Logical operation with with boolean tensors
// %lhs: [[false, false], [true, true]]
// %rhs: [[false, true], [false, true]]
%result = "stablehlo.or"(%lhs, %rhs) : (tensor<2x2xi1>, tensor<2x2xi1>) -> tensor<2x2xi1>
// %result: [[false, true], [true, true]]
outfeed
Semantik
Menulis inputs
ke outfeed dan menghasilkan token result
.
Semantik outfeed_config
adalah ditentukan oleh implementasi.
Input
Label | Nama | Jenis |
---|---|---|
(I1) | inputs |
jumlah tensor atau tensor terkuantisasi |
(I2) | token |
token |
(I3) | outfeed_config |
konstanta jenis string |
Output
Nama | Jenis |
---|---|
result |
token |
Contoh
%result = "stablehlo.outfeed"(%input0, %token) {
outfeed_config = ""
} : (tensor<2x2x2xi64>, !stablehlo.token) -> !stablehlo.token
alas
Semantik
Memperluas operand
dengan padding di sekitar tensor serta di antara elemen
tensor dengan padding_value
yang diberikan.
edge_padding_low
dan edge_padding_high
menentukan jumlah padding yang ditambahkan
di kelas bawah (di samping indeks 0) dan kelas atas (di samping indeks tertinggi) dalam
masing-masing dimensi. Jumlah padding bisa negatif, di mana
nilai absolut padding negatif menunjukkan jumlah elemen yang akan dihapus
dari dimensi yang ditentukan.
interior_padding
menentukan jumlah padding yang ditambahkan di antara dua padding
di setiap dimensi yang mungkin tidak negatif. Padding interior terjadi
sebelum pelapis tepi sehingga pelapis tepi negatif akan membuang elemen dari
operand dengan padding interior.
Secara lebih formal, result[result_index]
didefinisikan sebagai:
operand[operand_index]
jikaresult_index = edge_padding_low + operand_index * (interior_padding + 1)
.padding_value
sebaliknya.
Input
Label | Nama | Jenis | Batasan |
---|---|---|---|
(I1) | operand |
tensor atau tensor terkuantisasi per-tensor | (C1), (C2), (C4) |
(I2) | padding_value |
Tensor 0-dimensi atau per-tensor quantized tensor | (C1) |
(I3) | edge_padding_low |
Konstanta tensor 1 dimensi jenis si64 |
(C1), C4 |
(I4) | edge_padding_high |
Konstanta tensor 1 dimensi jenis si64 |
(C1), C4 |
(I5) | interior_padding |
Konstanta tensor 1 dimensi jenis si64 |
(C2-C4) |
Output
Nama | Jenis | Batasan |
---|---|---|
result |
tensor atau tensor terkuantisasi per-tensor | (C3-C6) |
Batasan
- (C1)
element_type(operand) = element_type(padding_value) = element_type(result)
. - (C2)
size(edge_padding_low) = size(edge_padding_high) = size(interior_padding) = rank(operand)
. - (C3)
0 <= interior_padding
. - (C4)
shape(result) = shape(operand) + edge_padding_low + max(shape(operand) - 1, 0) * interior_padding + edge_padding_high
.
Contoh
// %operand: [
// [1, 2, 3],
// [4, 5, 6]
// ]
// %padding_value: 0
%result = "stablehlo.pad"(%operand, %padding_value) {
edge_padding_low = array<i64: 0, 1>,
edge_padding_high = array<i64: 2, 1>,
interior_padding = array<i64: 1, 2>
} : (tensor<2x3xi32>, tensor<i32>) -> tensor<5x9xi32>
// %result: [
// [0, 1, 0, 0, 2, 0, 0, 3, 0],
// [0, 0, 0, 0, 0, 0, 0, 0, 0],
// [0, 4, 0, 0, 5, 0, 0, 6, 0],
// [0, 0, 0, 0, 0, 0, 0, 0, 0],
// [0, 0, 0, 0, 0, 0, 0, 0, 0]
// ]
partition_id
Semantik
Menghasilkan partition_id
dari proses saat ini.
Output
Nama | Jenis |
---|---|
result |
Tensor 0-dimensi jenis ui32 |
Contoh
%result = "stablehlo.partition_id"() : () -> tensor<ui32>
Popcnt
Semantik
Melakukan penghitungan jumlah bit berdasarkan elemen pada tensor operand
dan menghasilkan tensor result
.
Input
Label | Nama | Jenis | Batasan |
---|---|---|---|
(I1) | operand |
tensor tipe integer | (C1) |
Output
Nama | Jenis | Batasan |
---|---|---|
result |
tensor tipe integer | (C1) |
Batasan
- (C1)
type(operand) = type(result)
.
Contoh
// %operand: [0, 1, 2, 127]
%result = "stablehlo.popcnt"(%operand) : (tensor<4xi64>) -> tensor<4xi64>
// %result: [0, 1, 1, 7]
daya
Semantik
Melakukan eksponensial element-wise dari tensor lhs
dengan tensor rhs
dan
menghasilkan tensor result
. Bergantung pada jenis elemen, melakukan hal berikut:
- Untuk bilangan bulat: eksponensial bilangan bulat.
- Untuk float:
pow
dari IEEE-754. - Untuk bilangan kompleks: eksponensial kompleks.
- Untuk jenis terkuantisasi:
dequantize_op_quantize(power, lhs, rhs, type(result))
.
Input
Label | Nama | Jenis | Batasan |
---|---|---|---|
(I1) | lhs |
tensor bilangan bulat, floating point, atau tipe kompleks atau tensor terkuantisasi per-tensor | (C1) |
(I2) | rhs |
tensor bilangan bulat, floating point, atau tipe kompleks atau tensor terkuantisasi per-tensor | (C1) |
Output
Nama | Jenis | Batasan |
---|---|---|
result |
tensor bilangan bulat, floating point, atau tipe kompleks atau tensor terkuantisasi per-tensor | (C1) |
Batasan
- (C1)
baseline_type(operand) = baseline_type(result)
.
Contoh
// %lhs: [-2.0, -0.0, -36.0, 5.0, 3.0, 10000.0]
// %rhs: [2.0, 2.0, 1.1, 2.0, -1.0, 10.0]
%result = "stablehlo.power"(%lhs, %rhs) : (tensor<6xf64>, tensor<6xf64>) -> tensor<6xf64>
// %result: [4.0, 0.0, -nan, 25.0, 0.333333343, inf]
real
Semantik
Mengekstrak bagian asli, berdasarkan elemen, dari operand
dan menghasilkan result
tensor. Secara lebih formal, untuk setiap elemen x
:
real(x) = is_complex(x) ? real_part(x) : x
.
Input
Label | Nama | Jenis | Batasan |
---|---|---|---|
(I1) | operand |
tensor floating point atau jenis kompleks | (C1), C2 |
Output
Nama | Jenis | Batasan |
---|---|---|
result |
tensor jenis floating point | (C1), C2 |
Batasan
- (C1)
shape(result) = shape(operand)
. - (C2)
element_type(result)
didefinisikan sebagai:complex_element_type(element_type(operand))
jikais_complex(operand)
.element_type(operand)
sebaliknya.
Contoh
// %operand: [(1.0, 2.0), (3.0, 4.0)]
%result = "stablehlo.real"(%operand) : (tensor<2xcomplex<f32>>) -> tensor<2xf32>
// %result: [1.0, 3.0]
rekomendasi
Semantik
Menerima data dari saluran dengan channel_id
dan menghasilkan results
.
Jika is_host_transfer
adalah true
, operasi tersebut akan mentransfer data dari
{i>host<i}. Jika tidak, sistem akan mentransfer data dari perangkat lain. Artinya adalah
ditentukan oleh implementasi. Penanda ini menduplikasi
informasi yang diberikan di
channel_type
, jadi di masa mendatang kami berencana untuk menyimpan salah satunya saja
(#666).
results
terdiri dari nilai payload yang muncul terlebih dahulu dan token yang muncul
terakhir. Di masa mendatang, kami berencana untuk membagi payload dan token menjadi dua
{i>output<i} terpisah untuk
meningkatkan kejelasan
(#670).
Input
Label | Nama | Jenis | Batasan |
---|---|---|---|
(I1) | token |
token |
(C4) |
(I2) | channel_id |
konstanta jenis si64 |
|
(I3) | channel_type |
enum DEVICE_TO_DEVICE dan HOST_TO_DEVICE |
(C1) |
(I4) | is_host_transfer |
konstanta jenis i1 |
(C1) |
Output
Nama | Jenis | Batasan |
---|---|---|
results |
jumlah tensor, tensor atau token terkuantisasi | (C2-C4) |
Batasan
- (C1)
channel_type
didefinisikan sebagai:HOST_TO_DEVICE
jikais_host_transfer = true
,DEVICE_TO_DEVICE
sebaliknya.
- (C2)
0 < size(results)
. - (C3)
is_empty(result[:-1])
atauis_tensor(type(results[:-1]))
. - (C4)
is_token(type(results[-1]))
.
Contoh
%results0, %results1 = "stablehlo.recv"(%token) {
channel_handle = #stablehlo.channel_handle<handle = 1, type = 3>,
is_host_transfer = true
} : (!stablehlo.token) -> (tensor<2x2xi64>, !stablehlo.token)
reduce
Semantik
Menerapkan fungsi pengurangan body
ke inputs
dan init_values
di sepanjang
dimensions
dan menghasilkan tensor results
.
Urutan pengurangan ditentukan oleh implementasi, yang berarti bahwa body
dan
init_values
harus membentuk monoid untuk menjamin bahwa operasi tersebut menghasilkan
hasil yang sama untuk semua
input pada semua implementasi. Namun, kondisi ini
tidak dapat digunakan untuk banyak pengurangan populer. Mis. penambahan floating point untuk
body
dan nol untuk init_values
tidak benar-benar membentuk monoid karena
penambahan floating point tidak asosiatif.
Secara lebih formal, results...[j0, ..., jR-1] = reduce(input_slices_converted)
dengan:
input_slices = inputs...[j0, ..., :, ..., jR-1]
, tempat:
disisipkan pukuldimensions
.input_slices_converted = to_destination_type(input_slices..., type(func_inputs(body)[:len(func_inputs(body))//2])...)
.init_values_converted = to_destination_type(init_values..., type(func_inputs(body)[len(func_inputs(body))//2:])...)
.reduce(input_slices_converted) = exec(schedule)
untuk beberapa hierarki binerschedule
dengan:exec(node) = body(exec(node.left), exec(node.right))
.exec(leaf) = leaf.value
.
schedule
adalah hierarki biner lengkap yang ditentukan implementasi yang sesuai urutan traversal terdiri dari:- Nilai
input_slices_converted...[index]
, untuk semuaindex
diindex_space(input_slices_converted)
dalam urutan leksikografis menaik dariindex
. - Diselingi dengan sejumlah
init_values_converted
di posisi yang ditentukan penerapan.
- Nilai
Input
Label | Nama | Jenis | Batasan |
---|---|---|---|
(I1) | inputs |
jumlah tensor atau tensor terkuantisasi per-tensor | (C1-C4), (C6), (C7) |
(I2) | init_values |
jumlah variadik tensor 0-dimensi atau tensor terkuantisasi per-tensor | (C2), (C3) |
(I3) | dimensions |
Konstanta tensor 1 dimensi jenis si64 |
(C4), (C5), (C7) |
(I4) | body |
fungsi | (C6) |
Output
Nama | Jenis | Batasan |
---|---|---|
results |
jumlah tensor atau tensor terkuantisasi per-tensor | (C3), (C7), (C8) |
Batasan
- (C1)
same(shape(inputs...))
. - (C2)
element_type(inputs...) = element_type(init_values...)
. - (C3)
0 < size(inputs) = size(init_values) = size(results) = N
. - (C4)
0 <= dimensions < rank(inputs[0])
. - (C5)
is_unique(dimensions)
. - (C6)
body
memiliki jenis(tensor<E0>, ..., tensor<EN-1>, tensor<E0>, ...,
tensor<EN-1>) -> (tensor<E0>, ..., tensor<EN-1>)
di manais_promotable(element_type(inputs[i]), Ei)
. - (C7)
shape(results...) = shape(inputs...)
kecuali bahwa dimensi ukuraninputs...
yang sesuai dengandimensions
tidak disertakan. - (C8)
element_type(results[i]) = Ei
untuk semuai
di[0,N)
.
Contoh
// %input = [[0, 1, 2, 3, 4, 5]]
// %init_value = 0
%result = "stablehlo.reduce"(%input, %init_value) ({
^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
%0 = "stablehlo.add"(%arg0, %arg1) : (tensor<i64>, tensor<i64>) -> tensor<i64>
"stablehlo.return"(%0) : (tensor<i64>) -> ()
}) {
dimensions = array<i64: 1>
} : (tensor<1x6xi64>, tensor<i64>) -> tensor<1xi64>
// %result = [15]
reduce_precision
Semantik
Melakukan konversi operand
berbasis elemen ke jenis floating point lainnya
yang menggunakan exponent_bits
dan mantissa_bits
serta kembali ke versi aslinya
jenis floating point dan menghasilkan tensor output
.
Secara lebih formal:
- Bit mantissa nilai asli diperbarui untuk membulatkan nilai asli
ke nilai terdekat yang dapat diwakili dengan
mantissa_bits
menggunakan SemantikroundToIntegralTiesToEven
. - Kemudian, jika
mantissa_bits
lebih kecil dari jumlah bit mantissa nilai asli, bit mantissa akan dipotong menjadimantissa_bits
. - Lalu, jika bit eksponen dari hasil perantara tidak sesuai dengan
yang disediakan oleh
exponent_bits
, hasil perantara akan melebihi tak terhingga menggunakan tanda awal atau {i>underflows<i} ke nol menggunakan tanda aslinya. - Untuk jenis terkuantisasi, menjalankan
dequantize_op_quantize( lambda operand: reduce_precision(operand, exponent_bits, mantissa_bits), operand, type(result))
.
Input
Label | Nama | Jenis | Batasan |
---|---|---|---|
(I1) | operand |
tensor tipe floating point atau tensor terkuantisasi per-tensor | (C1) |
(I2) | exponent_bits |
konstanta jenis si32 |
(C2) |
(I3) | mantissa_bits |
konstanta jenis si32 |
(C3) |
Output
Nama | Jenis | Batasan |
---|---|---|
output |
tensor tipe floating point atau tensor terkuantisasi per-tensor | (C1) |
Batasan
- (C1)
baseline_type(operand) = baseline_type(output)
. - (C2)
1 <= exponent_bits
. - (C3)
0 <= mantissa_bits
.
Contoh
// Logical values: +Inf, NaN, +Denormal, 0.0, 65519.0, 65520.0
// %operand: [0x7FF0000000000000, 0x7FFFFFFFFFFFFFFF, 0x0000000000000001, 0.0, 65519.0, 65520.0]
%output = "stablehlo.reduce_precision"(%operand) {
exponent_bits = 5 : i32,
mantissa_bits = 10 : i32
} : (tensor<6xf64>) -> tensor<6xf64>
// Logical values: +Inf, NaN, 0.0, 0.0, 65504.0, +Inf
// %output: [0x7FF0000000000000, 0x7FFFFFFFFFFFFFFF, 0.0, 0.0, 65504.0, 0x7FF0000000000000]
reduce_scatter
Semantik
Dalam setiap grup proses di {i>process grid<i}
StabilHLO, melakukan pengurangan,
menggunakan computations
, terhadap nilai tensor operand
dari setiap proses,
membagi hasil pengurangan di sepanjang scatter_dimension
menjadi beberapa bagian, dan sebar
bagian yang terpisah di antara proses
untuk menghasilkan result
.
Operasi tersebut membagi grid proses StableHLO menjadi process_groups
yang
didefinisikan sebagai berikut:
cross_replica(replica_groups)
jikachannel_id <= 0 and use_global_device_ids = false
.cross_replica_and_partition(replica_groups)
jikachannel_id > 0 and use_global_device_ids = false
.flattened_ids(replica_groups)
jikachannel_id > 0 and use_global_device_ids = true
.
Setelah itu, dalam setiap process_group
:
reduced_value = all_reduce(operand, replica_groups, channel_id, use_global_device_ids, computation)
.parts@sender = split(reduced_value@sender, dim(process_groups, 1), scatter_dimension)
.result@receiver = parts@sender[receiver_index]
untuk semuasender
diprocess_group
, denganreceiver_index = process_group.index(receiver)
.
Input
Label | Nama | Jenis | Batasan |
---|---|---|---|
(I1) | operand |
tensor atau tensor terkuantisasi per-tensor | (C1), (C2), (C7), (C8) |
(I2) | scatter_dimension |
konstanta jenis si64 |
(C1), C2 (C8), dan (C8) |
(I3) | replica_groups |
Konstanta tensor 2 dimensi jenis si64 |
(C3-C5) |
(I4) | channel_id |
konstanta jenis si64 |
(C6) |
(I5) | use_global_device_ids |
konstanta jenis i1 |
(C6) |
(I6) | computation |
fungsi | (C7) |
Output
Nama | Jenis | Batasan |
---|---|---|
result |
tensor atau tensor terkuantisasi per-tensor | (C8-C9) |
Batasan
- (C1)
dim(operand, scatter_dimension) % dim(process_groups, 1) = 0
. - (C2)
0 <= scatter_dimension < rank(operand)
. - (C3)
is_unique(replica_groups)
. - (C4)
size(replica_groups)
ditentukan sebagai:num_replicas
jikacross_replica
digunakan.num_replicas
jikacross_replica_and_partition
digunakan.num_processes
jikaflattened_ids
digunakan.
- (C5)
0 <= replica_groups < size(replica_groups)
. - (C6) Jika
use_global_device_ids = true
, makachannel_id > 0
. - (C7)
computation
memiliki jenis(tensor<E>, tensor<E>) -> (tensor<E>)
, denganis_promotable(element_type(operand), E)
. - (C8)
shape(result) = shape(operand)
kecuali:dim(result, scatter_dimension) = dim(operand, scatter_dimension) / dim(process_groups, 1)
.
- (C9)
element_type(result) = E
.
Contoh
// num_replicas: 2
// num_partitions: 1
// %operand@(0, 0): [[1, 2, 3, 4],
// [5, 6, 7, 8]]
// %operand@(1, 0): [[9, 10, 11, 12],
// [13, 14, 15, 16]]
%result = "stablehlo.reduce_scatter"(%operand) ({
^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
%0 = "stablehlo.add"(%arg0, %arg1) : (tensor<i64>, tensor<i64>) -> tensor<i64>
"stablehlo.return"(%0) : (tensor<i64>) -> ()
}) {
scatter_dimension = 1 : i64,
replica_groups = dense<[[0, 1]]> : tensor<1x2xi64>,
channel_handle = #stablehlo.channel_handle<handle = 0, type = 0>
} : (tensor<2x4xi64>) -> tensor<2x2xi64>
//
// %result@(0, 0): [[10, 12],
// [18, 20]]
// %result@(1, 0): [[14, 16],
// [22, 24]]
reduce_window
Semantik
Menerapkan fungsi pengurangan body
ke jendela inputs
dan init_values
dan menghasilkan results
.
Diagram berikut menunjukkan cara penghitungan elemen pada results...
inputs...
menggunakan contoh konkret.
Secara lebih formal,
results...[result_index] = reduce(windows, init_values, axes(inputs...), body)
(lihat mengurangi) jika:
padded_inputs = pad(inputs..., init_values..., padding[:, 0], padding[:, 1], base_dilations - 1)
.window_start = result_index * window_strides
.window_end = window_start + (window_dimensions - 1) * window_dilations + 1
.windows = slice(padded_inputs..., window_start, window_end, window_dilations)
.
Input
Label | Nama | Jenis | Batasan |
---|---|---|---|
(I1) | inputs |
jumlah tensor atau tensor terkuantisasi per-tensor | (C1-C4), (C6), (C8), (C10), (C12), (C13), (C15) |
(I2) | init_values |
jumlah variadik tensor 0-dimensi atau tensor terkuantisasi per-tensor | (C1), C13 |
(I3) | window_dimensions |
Konstanta tensor 1 dimensi jenis si64 |
(C4), (C5), (C15) |
(I4) | window_strides |
Konstanta tensor 1 dimensi jenis si64 |
(C6), (C7), (C15) |
(I5) | base_dilations |
Konstanta tensor 1 dimensi jenis si64 |
(C8), (C9), (C15) |
(I6) | window_dilations |
Konstanta tensor 1 dimensi jenis si64 |
(C10), (C11), (C15) |
(I7) | padding |
Konstanta tensor 2 dimensi jenis si64 |
(C12), (C15) |
(I8) | body |
fungsi | (C13) |
Output
Nama | Jenis | Batasan |
---|---|---|
results |
jumlah tensor atau tensor terkuantisasi per-tensor | (C1), (C14-C16) |
Batasan
- (C1)
0 < size(inputs) = size(init_values) = size(results) = N
. - (C2)
same(shape(inputs...))
. - (C3)
element_type(inputs...) = element_type(init_values...)
. - (C4)
size(window_dimensions) = rank(inputs[0])
. - (C5)
0 < window_dimensions
. - (C6)
size(window_strides) = rank(inputs[0])
. - (C7)
0 < window_strides
. - (C8)
size(base_dilations) = rank(inputs[0])
. - (C9)
0 < base_dilations
. - (C10)
size(window_dilations) = rank(inputs[0])
. - (C11)
0 < window_dilations
. - (C12)
shape(padding) = [rank(inputs[0]), 2]
. - (C13)
body
memiliki jenis(tensor<E0>, ..., tensor<EN-1>, tensor<E0>, ...,
tensor<EN-1>) -> (tensor<E0>, ..., tensor<EN-1>)
di manais_promotable(element_type(inputs[i]), Ei)
. - (C14)
same(shape(results...))
. - (C15)
shape(results[0]) = num_windows
jika:dilated_input_shape = shape(inputs[0]) = 0 ? 0 : (shape(inputs[0]) - 1) * base_dilations + 1
.padded_input_shape = padding[:, 0] + dilated_input_shape + padding[:, 1]
.dilated_window_shape = (window_dimensions - 1) * window_dilations + 1
.is_empty_window = padded_input_shape = 0 || dilated_window_shape > padded_input_shape
.num_windows = is_empty_window ? 0 : floor((padded_input_shape - dilated_window_shape) / window_strides) + 1
.
- (C16)
element_type(results[i]) = Ei
untuk semuai
di[0,N)
.
Contoh
// %input = [[1, 2], [3, 4], [5, 6]]
// %init_value = 0
%result = "stablehlo.reduce_window"(%input, %init_value) ({
^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
%0 = "stablehlo.add"(%arg0, %arg1) : (tensor<i64>, tensor<i64>) -> tensor<i64>
"stablehlo.return"(%0) : (tensor<i64>) -> ()
}) {
window_dimensions = array<i64: 2, 1>,
window_strides = array<i64: 4, 1>,
base_dilations = array<i64: 2, 1>,
window_dilations = array<i64: 3, 1>,
padding = dense<[[2, 1], [0, 0]]> : tensor<2x2xi64>
} : (tensor<3x2xi64>, tensor<i64>) -> tensor<2x2xi64>
// %result = [[0, 0], [3, 4]]
sisa
Semantik
Melakukan sisa tensor dividen lhs
dan pembagi rhs
serta
menghasilkan tensor result
.
Secara lebih formal, tanda hasil diambil
dari pembagian, dan
nilai absolut hasil selalu lebih kecil dari nilai absolut pembagi.
Sisanya dihitung sebagai lhs - d * rhs
, dengan d
diberikan oleh:
- Untuk bilangan bulat:
stablehlo.divide(lhs, rhs)
. - Untuk float:
division(lhs, rhs)
dari IEEE-754 dengan atribut pembulatanroundTowardZero
. - Untuk bilangan kompleks: TBD (#997).
- Untuk jenis terkuantisasi:
dequantize_op_quantize(remainder, lhs, rhs, type(result))
.
Untuk jenis elemen floating point, operasi ini berbeda dengan
Operasi remainder
dari spesifikasi IEEE-754 dengan d
sebagai nilai integral
terdekat dengan nilai pasti lhs/rhs
yang memiliki nilai yang sama.
Input
Label | Nama | Jenis | Batasan |
---|---|---|---|
(I1) | lhs |
tensor integer, floating-point, atau tipe kompleks atau tensor terkuantisasi per-tensor | (C1) |
(I2) | rhs |
tensor integer, floating-point, atau tipe kompleks atau tensor terkuantisasi per-tensor | (C1) |
Output
Nama | Jenis | Batasan |
---|---|---|
result |
tensor integer, floating-point, atau tipe kompleks atau tensor terkuantisasi per-tensor | (C1) |
Batasan
- (C1)
baseline_type(operand) = baseline_type(result)
.
Contoh
// %lhs: [17, -17, 17, -17]
// %rhs: [3, 3, -3, -3]
%result = "stablehlo.remainder"(%lhs, %rhs) : (tensor<4xi64>, tensor<4xi64>) -> tensor<4xi64>
// %result: [2, -2, 2, -2]
replica_id
Semantik
Menghasilkan replica_id
dari proses saat ini.
Output
Nama | Jenis |
---|---|
result |
Tensor 0-dimensi jenis ui32 |
Contoh
%result = "stablehlo.replica_id"() : () -> tensor<ui32>
bentuk ulang
Semantik
Melakukan pembentukan ulang tensor operand
ke tensor result
. Secara konseptual,
sama dengan mempertahankan representasi kanonis yang sama, tetapi berpotensi mengubah
bentuknya, misalnya dari tensor<2x3xf32>
hingga tensor<3x2xf32>
atau tensor<6xf32>
.
Secara lebih formal, result[result_index] = operand[operand_index]
di mana
result_index
dan operand_index
memiliki posisi yang sama dalam leksikografis
urutan index_space(result)
dan index_space(operand)
.
Input
Label | Nama | Jenis | Batasan |
---|---|---|---|
(I1) | operand |
tensor atau quantized tensor | (C1-C3) |
Output
Nama | Jenis | Batasan |
---|---|---|
result |
tensor atau quantized tensor | (C1-C3) |
Batasan
- (C1)
element_type(result)
diberikan oleh:element_type(operand)
, jika!is_per_axis_quantized(operand)
.element_type(operand)
kecualiquantization_dimension(operand)
danquantization_dimension(result)
dapat berbeda, jika tidak.
- (C2)
size(operand) = size(result)
. - (C3) Jika
is_per_axis_quantized(operand)
:reduce(dims(operand, [0, 1, ..., quantization_dimension(operand) - 1]), init_values=1, dimensions=[0], body=lambda x, y: x * y) = reduce(dims(result, [0, 1, ..., quantization_dimension(result) - 1]), init_values=1, dimensions=[0], body=lambda x, y: x * y)
.dim(operand, quantization_dimension(operand)) = dim(result, quantization_dimension(result))
.reduce(dims(operand, [quantization_dimension(operand) + 1, ..., rank(operand) - 1]), init_values=1, dimensions=[0], body=lambda x, y: x * y) = reduce(dims(result, [quantization_dimension(result) + 1, ..., rank(result) - 1]), init_values=1, dimensions=[0], body=lambda x, y: x * y)
.
Contoh
// %operand: [[1, 2, 3], [4, 5, 6]]
%result = "stablehlo.reshape"(%operand) : (tensor<2x3xi32>) -> tensor<3x2xi32>
// %result: [[1, 2], [3, 4], [5, 6]]
balik
Semantik
Membalik urutan elemen dalam operand
di sepanjang dimensions
yang ditentukan
dan menghasilkan tensor result
. Secara lebih formal,
result[result_index] = operand[operand_index]
dalam hal ini:
operand_index[d] = dim(result, d) - result_index[d] - 1
jikad
dalamdimensions
.operand_index[d] = result_index[d]
sebaliknya.
Input
Label | Nama | Jenis | Batasan |
---|---|---|---|
(I1) | operand |
tensor atau tensor terkuantisasi per-tensor | (C1), C3 |
(I2) | dimensions |
Konstanta tensor 1 dimensi jenis si64 |
(C2), (C3) |
Output
Nama | Jenis | Batasan |
---|---|---|
result |
tensor atau tensor terkuantisasi per-tensor | (C1), C3 |
Batasan
- (C1)
type(operand) = type(result)
. - (C2)
is_unique(dimensions)
. - (C3)
0 <= dimensions < rank(result)
.
Contoh
// %operand = [[1, 2], [3, 4], [5, 6]]
%result = "stablehlo.reverse"(%operand) {
dimensions = array<i64: 1>
} : (tensor<3x2xi32>) -> tensor<3x2xi32>
// %result: [[2, 1], [4, 3], [6, 5]]
rng
Semantik
Menghasilkan angka acak menggunakan algoritma rng_distribution
dan menghasilkan
Tensor result
dari bentuk tertentu shape
.
Jika rng_distribution = UNIFORM
, angka acak akan dibuat
setelah distribusi seragam selama interval [a, b)
. Jika a >= b
,
perilakunya tidak terdefinisi.
Jika rng_distribution = NORMAL
, angka acak akan dibuat
mengikuti distribusi normal dengan rata-rata = a
dan deviasi standar = b
.
Jika b < 0
, perilaku tidak ditentukan.
Cara persis bagaimana angka acak dibuat ditentukan oleh implementasi. Sebagai misalnya, mereka mungkin determenistik, dan mereka mungkin atau tidak status tersembunyi.
Dalam percakapan dengan banyak pemangku kepentingan, operasi ini telah muncul seefektif tidak digunakan lagi, jadi di masa mendatang kami berencana untuk mencoba menghapusnya (#597).
Input
Label | Nama | Jenis | Batasan |
---|---|---|---|
(I1) | a |
Tensor 0-dimensi dari jenis integer, boolean, atau floating point | (C1), C2 |
(I2) | b |
Tensor 0-dimensi dari jenis integer, boolean, atau floating point | (C1), C2 |
(I3) | shape |
Konstanta tensor 1 dimensi jenis si64 |
(C3) |
(I4) | rng_distribution |
enum UNIFORM dan NORMAL |
(C2) |
Output
Nama | Jenis | Batasan |
---|---|---|
result |
tensor bilangan bulat, boolean, atau jenis floating point | (C1-C3) |
Batasan
- (C1)
element_type(a) = element_type(b) = element_type(result)
. - (C2) Jika
rng_distribution = NORMAL
, makais_float(a)
. - (C3)
shape(result) = shape
.
Contoh
// %a = 0
// %b = 2
// %shape = [3, 3]
%result = "stablehlo.rng"(%a, %b, %shape) {
rng_distribution = #stablehlo<rng_distribution UNIFORM>
} : (tensor<i32>, tensor<i32>, tensor<2xi64>) -> tensor<3x3xi32>
// %result: [
// [1, 0, 1],
// [1, 1, 1],
// [0, 0, 0]
// ]
rng_bit_generator
Semantik
Menampilkan output
yang diisi dengan bit acak yang seragam dan status output yang diperbarui
output_state
menggunakan algoritma generator angka pseudorandom rng_algorithm
diberi status awal initial_state
. Outputnya dijamin akan
fungsi determenistik initial_state
, tetapi tidak dijamin akan
determenistik di antara implementasi.
rng_algorithm
adalah salah satu dari berikut ini:
DEFAULT
: Algoritma yang ditentukan implementasi.THREE_FRY
: Varian yang ditentukan berdasarkan implementasi dari algoritma Threefry.*PHILOX
: Varian algoritma Philox yang ditentukan berdasarkan implementasi.*
* Lihat: Salmon et al. SC 2011. Bilangan acak paralel: semudah 1, 2, 3.
Input
Label | Nama | Jenis | Batasan |
---|---|---|---|
(I1) | rng_algorithm |
enum DEFAULT , THREE_FRY , dan PHILOX |
(C2) |
(I2) | initial_state |
Tensor 1 dimensi dari jenis ui64 |
(C1), C2 |
Output
Nama | Jenis | Batasan |
---|---|---|
output_state |
Tensor 1 dimensi dari jenis ui64 |
(C1) |
output |
tensor jenis bilangan bulat atau floating point |
Batasan
- (C1)
type(initial_state) = type(output_state)
. - (C2)
size(initial_state)
didefinisikan sebagai:- yang ditentukan implementasi jika
rng_algorithm = DEFAULT
. 2
jikarng_algorithm = THREE_FRY
.2
atau3
jikarng_algorithm = PHILOX
.
- yang ditentukan implementasi jika
Contoh
// %initial_state: [1, 2]
%output_state, %output = "stablehlo.rng_bit_generator"(%initial_state) {
rng_algorithm = #stablehlo<rng_algorithm THREE_FRY>
} : (tensor<2xui64>) -> (tensor<2xui64>, tensor<2x2xui64>)
// %output_state: [1, 6]
// %output: [
// [9236835810183407956, 16087790271692313299],
// [18212823393184779219, 2658481902456610144]
// ]
round_nearest_afz
Semantik
Melakukan pembulatan berbasis elemen ke bilangan bulat terdekat, memisahkan ikatan
dari nol, pada tensor operand
dan menghasilkan tensor result
. Implementasi
operasi roundToIntegralTiesToAway
dari spesifikasi IEEE-754. Sebagai
jenis terkuantisasi, melakukan
dequantize_op_quantize(round_nearest_afz, operand, type(result))
.
Input
Label | Nama | Jenis | Batasan |
---|---|---|---|
(I1) | operand |
tensor tipe floating point atau tensor terkuantisasi per-tensor | (C1) |
Output
Nama | Jenis | Batasan |
---|---|---|
result |
tensor tipe floating point atau tensor terkuantisasi per-tensor | (C1) |
Batasan
- (C1)
baseline_type(operand) = baseline_type(result)
.
Contoh
// %operand = [-2.5, 0.4, 0.5, 0.6, 2.5]
%result = "stablehlo.round_nearest_afz"(%operand) : (tensor<5xf64>) -> tensor<5xf64>
// %result: [-3.0, 0.0, 1.0, 1.0, 3.0]
round_nearest_even
Semantik
Melakukan pembulatan berbasis elemen ke bilangan bulat terdekat, yang memutus ikatan
terhadap bilangan bulat genap, pada tensor operand
dan menghasilkan result
tensor. Menerapkan operasi roundToIntegralTiesToEven
dari IEEE-754
spesifikasi
pendukung. Untuk jenis terkuantisasi, menjalankan
dequantize_op_quantize(round_nearest_even, operand, type(result))
.
Input
Label | Nama | Jenis | Batasan |
---|---|---|---|
(I1) | operand |
tensor tipe floating point atau tensor terkuantisasi per-tensor | (C1) |
Output
Nama | Jenis | Batasan |
---|---|---|
result |
tensor tipe floating point atau tensor terkuantisasi per-tensor | (C1) |
Batasan
- (C1)
baseline_type(operand) = baseline_type(result)
.
Contoh
// %operand = [-2.5, 0.4, 0.5, 0.6, 2.5]
%result = "stablehlo.round_nearest_even"(%operand) : (tensor<5xf64>) -> tensor<5xf64>
// %result: [-2.0, 0.0, 0.0, 1.0, 2.0]
rsqrt
Semantik
Melakukan operasi root kuadrat timbal balik berdasarkan elemen pada tensor operand
dan
menghasilkan tensor result
. Bergantung pada jenis elemen, melakukan hal berikut:
- Untuk float:
rSqrt
dari IEEE-754. - Untuk bilangan kompleks: akar kuadrat timbal balik kompleks.
- Untuk jenis terkuantisasi:
dequantize_op_quantize(rsqrt, operand, type(result))
.
Input
Label | Nama | Jenis | Batasan |
---|---|---|---|
(I1) | operand |
tensor floating point atau tipe kompleks atau tensor terkuantisasi per-tensor | (C1) |
Output
Nama | Jenis | Batasan |
---|---|---|
result |
tensor floating point atau tipe kompleks atau tensor terkuantisasi per-tensor | (C1) |
Batasan
- (C1)
baseline_type(operand) = baseline_type(result)
.
Contoh
// %operand: [[1.0, 4.0], [9.0, 25.0]]
%result = "stablehlo.rsqrt"(%operand) : (tensor<2x2xf32>) -> tensor<2x2xf32>
// %result: [[1.0, 0.5], [0.33333343, 0.2]]
{i>scatter <i}(memencar)
Semantik
Menghasilkan tensor results
yang sama dengan tensor inputs
, kecuali bahwa
beberapa irisan yang ditentukan oleh scatter_indices
diperbarui dengan nilai
updates
menggunakan update_computation
.
Diagram berikut menunjukkan cara elemen di updates...
dipetakan pada elemen di
results...
menggunakan contoh konkret. Diagram ini memilih beberapa contoh
updates...
membuat indeks dan menjelaskan secara mendetail results...
mana yang menghitungnya
yang sesuai.
Secara lebih formal, untuk semua update_index
di index_space(updates[0])
:
update_scatter_dims = [d for d in axes(updates[0]) and d not in update_window_dims]
.update_scatter_index = update_index[update_scatter_dims...]
.start_index
ditentukan sebagai:scatter_indices[si0, ..., :, ..., siN]
dengansi
merupakan individu elemen dalamupdate_scatter_index
dan:
dimasukkan di indeksindex_vector_dim
, jikaindex_vector_dim
<rank(scatter_indices)
.[scatter_indices[update_scatter_index]]
sebaliknya.
- Untuk
d_input
diaxes(inputs[0])
,full_start_index[d_input] = start_index[d_start]
jikad_input = scatter_dims_to_operand_dims[d_start]
.full_start_index[d_input] = 0
sebaliknya.
- Untuk
d_input
diaxes(inputs[0])
,full_batching_index[d_input] = update_scatter_index[d_start - (d_start < index_vector_dim ? 0 : 1)]
jikad_input = input_batching_dims[i_batching]
dand_start = scatter_indices_batching_dims[i_batching]
.full_batching_index[d_input] = 0
sebaliknya.
update_window_index = update_index[update_window_dims...]
.full_window_index = [wi0, ..., 0, ..., wiN]
denganwi
merupakan individu diupdate_window_index
, dan0
disisipkan pada indeks dariinserted_window_dims
daninput_batching_dims
.result_index = full_start_index + full_batching_index + full_window_index
.
Oleh karena itu, results = exec(schedule, inputs)
, dalam hal:
schedule
adalah permutasi yang ditentukan implementasi dariindex_space(updates[0])
.exec([update_index, ...], results) = exec([...], updated_results)
dalam hal ini:- Jika
result_index
termasuk dalam batasshape(results...)
updates_converted = to_destination_type( updates...[update_index], type(func_inputs(update_computation) [len(func_inputs(update_computation))//2:])... )
updated_values = update_computation(results...[result_index], updates_converted)
updated_results
adalah salinanresults
denganresults...[result_index]
disetel keupdated_values...
.- Atau
updated_results = results
.
- Jika
exec([], results) = results
.
Jika indices_are_sorted
adalah true
, implementasi dapat mengasumsikan bahwa
scatter_indices
diurutkan sehubungan dengan scatter_dims_to_operand_dims
,
jika tidak, perilaku
tidak terdefinisi. Secara lebih formal, untuk semua i1 < i2
dari
indices(result)
, full_start_index(i1)
<= full_start_index(i2)
.
Jika unique_indices
adalah true
, implementasi dapat mengasumsikan bahwa semua
Indeks result_index
yang tersebar bersifat unik. Jika unique_indices
adalah
true
, tetapi indeks yang tersebar tidak unik maka perilakunya adalah
tidak terdefinisi.
Input
Label | Nama | Jenis | Batasan |
---|---|---|---|
(I1) | inputs |
jumlah tensor atau tensor terkuantisasi per-tensor | (C1), (C2), (C4-C6), (C11), (C13), (C18), (C21), (C23-C24) |
(I2) | scatter_indices |
tensor tipe integer | (C4), (C15), (C19), (C22) |
(I3) | updates |
jumlah tensor atau tensor terkuantisasi per-tensor | (C3-C6), (C8) |
(I4) | update_window_dims |
Konstanta tensor 1 dimensi jenis si64 |
(C2), (C4), (C7-C8) |
(I5) | inserted_window_dims |
Konstanta tensor 1 dimensi jenis si64 |
(C2), (C4), (C9-C11) |
(I6) | input_batching_dims |
Konstanta tensor 1 dimensi jenis si64 |
(C2), (C4), (C9), (C12-13), (C17-18), (C20) |
(I7) | scatter_indices_batching_dims |
Konstanta tensor 1 dimensi jenis si64 |
(C14-C18) |
(I8) | scatter_dims_to_operand_dims |
Konstanta tensor 1 dimensi jenis si64 |
(C19-C21) |
(I9) | index_vector_dim |
konstanta jenis si64 |
(C4), (C16), (C19), (C22) |
(I10) | indices_are_sorted |
konstanta jenis i1 |
|
(I11) | unique_indices |
konstanta jenis i1 |
|
(I12) | update_computation |
fungsi | (C23) |
Output
Nama | Jenis | Batasan |
---|---|---|
results |
jumlah tensor atau tensor terkuantisasi per-tensor | (C24-C25) |
Batasan
- (C1)
same(shape(inputs...))
. - (C2) `rank(inputs[0]) = size(update_window_dims) + ukuran(inserted_window_dims)
- size(input_batching_dims)`.
- (C3)
same(shape(updates...))
. - (C4)
shape(updates[0]) = combine(update_scatter_dim_sizes, update_window_dim_sizes)
jika:update_scatter_dim_sizes = shape(scatter_indices)
kecuali bahwa ukuran dimensiscatter_indices
yang sesuai denganindex_vector_dim
tidak termasuk.update_window_dim_sizes <= shape(inputs[0])
kecuali bahwa ukuran dimensi diinputs[0]
yang sesuai denganinserted_window_dims
daninput_batching_dims
tidak termasuk.combine
menempatkanupdate_scatter_dim_sizes
pada sumbu yang sesuai denganupdate_scatter_dims
danupdate_window_dim_sizes
pada sumbu yang sesuai keupdate_window_dims
.
- (C5)
0 < size(inputs) = size(updates) = N
. - (C6)
element_type(updates...) = element_type(inputs...)
. - (C7)
is_unique(update_window_dims) and is_sorted(update_window_dims)
. - (C8)
0 <= update_window_dims < rank(updates[0])
. - (C9)
is_unique(concatenate(inserted_window_dims, input_batching_dims))
- (C10)
is_sorted(inserted_window_dims)
. - (C11)
0 <= inserted_window_dims < rank(inputs[0])
. - (C12)
is_sorted(input_batching_dims)
. - (C13)
0 <= input_batching_dims < rank(inputs[0]))
. - (C14)
is_unique(scatter_indices_batching_dims)
. - (C15)
0 <= scatter_indices_batching_dims < rank(scatter_indices)
. - (C16)
index_vector_dim not in scatter_indices_batching_dims
. - (C17)
size(input_batching_dims) == size(scatter_indices_batching_dims)
. - (C18)
dim(inputs[0], input_batching_dims...) = dim(scatter_indices, scatter_indices_batching_dims...)
. - (C19)
size(scatter_dims_to_operand_dims) = index_vector_dim < rank(scatter_indices) ? dim(scatter_indices, index_vector_dim) : 1
. - (C20)
is_unique(concatenate(scatter_dims_to_operand_dims, input_batching_dims))
. - (C21)
0 <= scatter_dims_to_operand_dims < rank(inputs[0])
. - (C22)
0 <= index_vector_dim <= rank(scatter_indices)
. - (C23)
update_computation
memiliki jenis(tensor<E0>, ..., tensor<EN-1>, tensor<E0>, ..., tensor<EN-1>) -> (tensor<E0>, ..., tensor<EN-1>)
, denganis_promotable(element_type(inputs[i]), Ei)
. - (C24)
shape(inputs...) = shape(results...)
. - (C25)
element_type(results[i]) = Ei
untuk semuai
di[0,N)
.
Contoh
// %input: [
// [
// [[1, 2], [3, 4], [5, 6], [7, 8]],
// [[9, 10],[11, 12], [13, 14], [15, 16]],
// [[17, 18], [19, 20], [21, 22], [23, 24]]
// ],
// [
// [[25, 26], [27, 28], [29, 30], [31, 32]],
// [[33, 34], [35, 36], [37, 38], [39, 40]],
// [[41, 42], [43, 44], [45, 46], [47, 48]]
// ]
// ]
// %scatter_indices: [
// [
// [[0, 0], [1, 0], [2, 1]],
// [[0, 1], [1, 1], [0, 9]]
// ],
// [
// [[0, 0], [2, 1], [2, 2]],
// [[1, 2], [0, 1], [1, 0]]
// ]
// ]
// %update: [
// [
// [[1, 1], [1, 1], [1, 1]],
// [[1, 1], [1, 1], [1, 1]]
// ],
// [
// [[1, 1], [1, 1], [1, 1]],
// [[1, 1], [1, 1], [1, 1]]
// ]
// ]
%result = "stablehlo.scatter"(%input, %scatter_indices, %update) ({
^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
%0 = "stablehlo.add"(%arg0, %arg1) : (tensor<i64>, tensor<i64>) -> tensor<i64>
"stablehlo.return"(%0) : (tensor<i64>) -> ()
}) {
scatter_dimension_numbers = #stablehlo.scatter<
update_window_dims = [3, 4],
inserted_window_dims = [1],
input_batching_dims = [0],
scatter_indices_batching_dims = [1],
scatter_dims_to_operand_dims = [2, 1],
index_vector_dim = 3>,
indices_are_sorted = false,
unique_indices = false
} : (tensor<2x3x4x2xi64>, tensor<2x2x3x2xi64>, tensor<2x2x3x2x2xi64>) -> tensor<2x3x4x2xi64>
// %result: [
// [
// [[3, 4], [6, 7], [6, 7], [7, 8]],
// [[9, 10],[11, 12], [15, 16], [17, 18]],
// [[17, 18], [19, 20], [22, 23], [24, 25]]
// ],
// [
// [[25, 26], [28, 29], [30, 31], [31, 32]],
// [[35, 36], [38, 39], [38, 39], [39, 40]],
// [[41, 42], [44, 45], [46, 47], [47, 48]]
// ]
// ]
pilih
Semantik
Menghasilkan tensor result
dengan setiap elemen dipilih dari on_true
atau
Tensor on_false
berdasarkan nilai elemen pred
yang sesuai.
Secara lebih formal, result[result_index] = pred_element ? on_true[result_index] :
on_false[result_index]
, dengan pred_element = rank(pred) = 0 ? pred[] :
pred[result_index]
. Untuk jenis terkuantisasi, menjalankan
dequantize_select_quantize(pred, on_true, on_false, type(result))
.
Input
Label | Nama | Jenis | Batasan |
---|---|---|---|
(I1) | pred |
tensor jenis i1 |
(C1) |
(I2) | on_true |
tensor atau tensor terkuantisasi per-tensor | (C1-C2) |
(I3) | on_false |
tensor atau tensor terkuantisasi per-tensor | (C2) |
Output
Nama | Jenis | Batasan |
---|---|---|
result |
tensor atau tensor terkuantisasi per-tensor | (C2) |
Batasan
- (C1)
rank(pred) = 0 or shape(pred) = shape(on_true)
. - (C2)
baseline_type(on_true) = baseline_type(on_false) = baseline_type(result)
.
Contoh
// %pred: [[false, true], [true, false]]
// %on_true: [[1, 2], [3, 4]]
// %on_false: [[5, 6], [7, 8]]
%result = "stablehlo.select"(%pred, %on_true, %on_false) : (tensor<2x2xi1>, tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
// %result: [[5, 2], [3, 8]]
select_and_scatter
Semantik
Sebarkan nilai dari tensor source
menggunakan scatter
berdasarkan
hasil reduce_window
dari tensor input
menggunakan select
dan menghasilkan
sebuah tensor result
.
Diagram berikut menunjukkan cara penghitungan elemen pada result
operand
dan source
menggunakan contoh konkret.
Secara lebih formal:
selected_values = reduce_window_without_init(...)
dengan input berikut:inputs = [operand].
window_dimensions
,window_strides
, danpadding
yang digunakan sebagaimana adanya.base_dilations = windows_dilations = 1
.body
ditentukan sebagai:
def body(arg0: tensor<E>, arg1: tensor<E>) -> tensor<E>: return select(arg0, arg1) ? arg0 : arg1;
tempat
E = element_type(operand)
, danreduce_window_without_init
bekerja persis sepertireduce_window
, kecuali bahwaschedule
reduce
(lihat mengurangi) tidak menyertakan nilai init. Saat ini tidak menentukan apa yang akan terjadi jika jendela yang sesuai tidak memiliki nilai (#731).result[result_index] = reduce([source_values], [init_value], [0], scatter)
dalam hal ini:source_values = [source[source_index] for source_index in source_indices]
.selected_index(source_index) = operand_index
jikaselected_values[source_index]
memiliki elemenoperand
darioperand_index
.source_indices = [source_index for source_index in indices(source) if selected_index(source_index) = result_index]
.
Input
Label | Nama | Jenis | Batasan |
---|---|---|---|
(I1) | operand |
tensor atau tensor terkuantisasi per-tensor | (C1-C4), (C6), (C8-C11) |
(I2) | source |
tensor atau tensor terkuantisasi per-tensor | (C1), C2 |
(I3) | init_value |
Tensor 0-dimensi atau per-tensor quantized tensor | (C3) |
(I4) | window_dimensions |
Konstanta tensor 1 dimensi jenis si64 |
(C2), (C4), (C5) |
(I5) | window_strides |
Konstanta tensor 1 dimensi jenis si64 |
(C2), (C6), (C7) |
(I6) | padding |
Konstanta tensor 2 dimensi jenis si64 |
(C2), (C8) |
(I7) | select |
fungsi | (C9) |
(I8) | scatter |
fungsi | (C10) |
Output
Nama | Jenis | Batasan |
---|---|---|
result |
tensor atau tensor terkuantisasi per-tensor | (C11-C12) |
Batasan
- (C1)
element_type(operand) = element_type(source)
. - (C2)
shape(source) = num_windows
jika:padded_operand_shape = padding[:, 0] + shape(operand) + padding[:, 1]
.is_empty_window = padded_operand_shape = 0 || window_dimensions > padded_operand_shape
.num_windows = is_empty_window ? 0 : floor((padded_operand_shape - window_dimensions) / window_strides) + 1
.
- (C3)
element_type(init_value) = element_type(operand)
. - (C4)
size(window_dimensions) = rank(operand)
. - (C5)
0 < window_dimensions
. - (C6)
size(window_strides) = rank(operand)
. - (C7)
0 < window_strides
. - (C8)
shape(padding) = [rank(operand), 2]
. - (C9)
select
memiliki jenis(tensor<E>, tensor<E>) -> tensor<i1>
, denganE = element_type(operand)
. - (C10)
scatter
memiliki jenis(tensor<E>, tensor<E>) -> tensor<E>
, denganis_promotable(element_type(operand), E)
. - (C11)
shape(operand) = shape(result)
. - (C12)
element_type(result) = E
.
Contoh
// %operand: [[1, 5], [2, 5], [3, 6], [4, 4]]
// %source: [[5, 6], [7, 8]]
// %init_value: 0
%result = "stablehlo.select_and_scatter"(%operand, %source, %init_value) ({
^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
%0 = "stablehlo.compare"(%arg0, %arg1) {
comparison_direction = #stablehlo<comparison_direction GE>
} : (tensor<i64>, tensor<i64>) -> tensor<i1>
"stablehlo.return"(%0) : (tensor<i1>) -> ()
}, {
^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
%0 = "stablehlo.add"(%arg0, %arg1) : (tensor<i64>, tensor<i64>) -> tensor<i64>
"stablehlo.return"(%0) : (tensor<i64>) -> ()
}) {
window_dimensions = array<i64: 3, 1>,
window_strides = array<i64: 2, 1>,
padding = dense<[[0, 1], [0, 0]]> : tensor<2x2xi64>
} : (tensor<4x2xi64>, tensor<2x2xi64>, tensor<i64>) -> tensor<4x2xi64>
// %result: [[0, 0], [0, 0], [5, 14], [7, 0]]
kirim
Semantik
Mengirim inputs
ke saluran channel_id
dan menghasilkan token result
.
Jika is_host_transfer
adalah true
, operasi tersebut akan mentransfer data ke
{i>host<i}. Jika tidak, data akan ditransfer ke perangkat lain. Artinya adalah
ditentukan oleh implementasi. Penanda ini menduplikasi
informasi yang disediakan di
channel_type
, jadi di masa mendatang kami berencana untuk menyimpan salah satunya saja
(#666).
Input
Label | Nama | Jenis | Batasan |
---|---|---|---|
(I1) | inputs |
jumlah tensor atau tensor terkuantisasi | |
(I2) | token |
token |
|
(I3) | channel_id |
konstanta jenis si64 |
|
(I4) | channel_type |
enum DEVICE_TO_DEVICE dan DEVICE_TO_HOST |
(C1) |
(I5) | is_host_transfer |
konstanta jenis i1 |
(C1) |
Output
Nama | Jenis |
---|---|
result |
token |
Batasan
- (C1)
channel_type
didefinisikan sebagai:DEVICE_TO_HOST
jikais_host_transfer = true
,DEVICE_TO_DEVICE
sebaliknya.
Contoh
%result = "stablehlo.send"(%operand, %token) {
channel_handle = #stablehlo.channel_handle<handle = 1, type = 2>,
is_host_transfer = true
} : (tensor<2x2xi64>, !stablehlo.token) -> !stablehlo.token
shift_left
Semantik
Melakukan operasi shift kiri berbasis element pada tensor lhs
berdasarkan angka rhs
bit dan menghasilkan tensor result
.
Input
Label | Nama | Jenis | Batasan |
---|---|---|---|
(I1) | lhs |
tensor tipe integer | (C1) |
(I2) | rhs |
tensor tipe integer | (C1) |
Output
Nama | Jenis | Batasan |
---|---|---|
result |
tensor tipe integer | (C1) |
Batasan
- (C1)
type(lhs) = type(rhs) = type(result)
.
Contoh
// %lhs: [-1, 0, 1]
// %rhs: [1, 2, 3]
%result = "stablehlo.shift_left"(%lhs, %rhs): (tensor<3xi64>, tensor<3xi64>) -> tensor<3xi64>
// %result: [-2, 0, 8]
shift_right_arithmetic
Semantik
Melakukan operasi shift kanan aritmetika element-wise pada tensor lhs
dengan
rhs
jumlah bit dan menghasilkan tensor result
.
Input
Label | Nama | Jenis | Batasan |
---|---|---|---|
(I1) | lhs |
tensor tipe integer | (C1) |
(I2) | rhs |
tensor tipe integer | (C1) |
Output
Nama | Jenis | Batasan |
---|---|---|
result |
tensor tipe integer | (C1) |
Batasan
- (C1)
type(lhs) = type(rhs) = type(result)
.
Contoh
// %lhs: [-1, 0, 8]
// %rhs: [1, 2, 3]
%result = "stablehlo.shift_right_arithmetic"(%lhs, %rhs): (tensor<3xi64>, tensor<3xi64>) -> tensor<3xi64>
// %result: [-1, 0, 1]
shift_right_logical
Semantik
Melakukan operasi shift kanan logis element-wise pada tensor lhs
dengan rhs
jumlah bit dan menghasilkan tensor result
.
Input
Label | Nama | Jenis | Batasan |
---|---|---|---|
(I1) | lhs |
tensor tipe integer | (C1) |
(I2) | rhs |
tensor tipe integer | (C1) |
Output
Nama | Jenis | Batasan |
---|---|---|
result |
tensor tipe integer | (C1) |
Batasan
- (C1)
type(lhs) = type(rhs) = type(result)
.
Contoh
// %lhs: [-1, 0, 8]
// %rhs: [1, 2, 3]
%result = "stablehlo.shift_right_logical"(%lhs, %rhs): (tensor<3xi64>, tensor<3xi64>) -> tensor<3xi64>
// %result: [9223372036854775807, 0, 1]
tanda
Semantik
Menampilkan tanda element-wise operand
dan menghasilkan tensor result
.
Secara lebih formal, untuk setiap elemen x
, semantik dapat dinyatakan menggunakan
Sintaksis Python sebagai berikut:
def sign(x):
if is_integer(x):
if compare(x, 0, LT, SIGNED): return -1
if compare(x, 0, EQ, SIGNED): return 0
return 1
elif is_float(x):
if is_nan(x): return NaN
if compare(x, -0.0, EQ, FLOAT): return -0.0
if compare(x, +0.0, EQ, FLOAT): return +0.0
if compare(x, 0.0, LT, FLOAT): return -1.0
return 1.0
elif is_complex(x):
if is_nan(real(x)) or is_nan(imag(x)): return (NaN, NaN)
if compare(x, (0.0, 0.0), EQ, FLOAT): return (0.0, 0.0)
return divide(x, convert(abs(x), type(x)))
Untuk jenis terkuantisasi, menjalankan
dequantize_op_quantize(sign, operand, type(result))
.
Input
Label | Nama | Jenis | Batasan |
---|---|---|---|
(I1) | operand |
tensor bilangan bulat bertanda tangan, floating point, atau tipe kompleks atau tensor terkuantisasi per-tensor | (C1) |
Output
Nama | Jenis | Batasan |
---|---|---|
result |
tensor bilangan bulat bertanda tangan, floating point, atau tipe kompleks atau tensor terkuantisasi per-tensor | (C1) |
Batasan
- (C1)
baseline_type(operand) = baseline_type(result)
.
Contoh
// Logical values: +NaN, -1.0, -0.0, +0.0, 1.0
// operand: [0x7FFFFFFFFFFFFFFF, -1.0, -0.0, 0.0, 1.0]
%result = "stablehlo.sign"(%operand) : (tensor<5xf64>) -> tensor<5xf64>
// Logical values: +NaN, -1.0, -0.0, +0.0, 1.0
// %result: [0x7FFFFFFFFFFFFFFF, -1.0, -0.0, 0.0, 1.0]
sinus
Semantik
Melakukan operasi sinus element-wise pada tensor operand
dan menghasilkan result
tensor. Bergantung pada jenis elemen, melakukan hal berikut:
- Untuk float:
sin
dari IEEE-754. - Untuk bilangan kompleks: sinus kompleks.
- Untuk jenis terkuantisasi:
dequantize_op_quantize(sine, operand, type(result))
.
Input
Label | Nama | Jenis | Batasan |
---|---|---|---|
(I1) | operand |
tensor floating point atau tipe kompleks atau tensor terkuantisasi per-tensor | (C1) |
Output
Nama | Jenis | Batasan |
---|---|---|
result |
tensor floating point atau tipe kompleks atau tensor terkuantisasi per-tensor | (C1) |
Batasan
- (C1)
baseline_type(operand) = baseline_type(result)
.
Contoh
// %operand: [
// [0.0, 1.57079632], // [0, pi/2]
// [3.14159265, 4.71238898] // [pi, 3pi/2]
// ]
%result = "stablehlo.sine"(%operand) : (tensor<2x2xf32>) -> tensor<2x2xf32>
// %result: [[0.0, 1.0], [0.0, -1.0]]
slice
Semantik
Mengekstrak slice dari operand
menggunakan indeks awal yang dihitung secara statis
dan menghasilkan tensor result
. start_indices
berisi indeks awal dari
irisan untuk setiap dimensi, limit_indices
berisi indeks akhir
(eksklusif) untuk irisan untuk setiap dimensi, dan strides
berisi langkah
untuk setiap dimensi.
Secara lebih formal, result[result_index] = operand[operand_index]
di mana
operand_index = start_indices + result_index * strides
.
Input
Label | Nama | Jenis | Batasan |
---|---|---|---|
(I1) | operand |
tensor atau tensor terkuantisasi per-tensor | (C1-C3), (C5) |
(I2) | start_indices |
Konstanta tensor 1 dimensi jenis si64 |
(C2), (C3), (C5) |
(I3) | limit_indices |
Konstanta tensor 1 dimensi jenis si64 |
(C2), (C3), (C5) |
(I4) | strides |
Konstanta tensor 1 dimensi jenis si64 |
(C2), (C4) |
Output
Nama | Jenis | Batasan |
---|---|---|
result |
tensor atau tensor terkuantisasi per-tensor | (C1), (C5) |
Batasan
- (C1)
element_type(operand) = element_type(result)
. - (C2)
size(start_indices) = size(limit_indices) = size(strides) = rank(operand)
. - (C3)
0 <= start_indices <= limit_indices <= shape(operand)
. - (C4)
0 < strides
. - (C5)
shape(result) = ceil((limit_indices - start_indices) / strides)
.
Contoh
// %operand: [
// [0, 0, 0, 0],
// [0, 0, 1, 1],
// [0, 0, 1, 1]
// ]
%result = "stablehlo.slice"(%operand) {
start_indices = array<i64: 1, 2>,
limit_indices = array<i64: 3, 4>,
strides = array<i64: 1, 1>
} : (tensor<3x4xi64>) -> tensor<2x2xi64>
// % result: [
// [1, 1],
// [1, 1]
// ]
mengurutkan
Semantik
Mengurutkan irisan 1 dimensi inputs
di sepanjang dimensi dimension
bersama-sama,
menurut comparator
dan menghasilkan results
.
Tidak seperti input serupa dalam operasi lain, dimension
memungkinkan nilai negatif,
dengan semantik yang dijelaskan di bawah ini. Pada masa mendatang, tindakan ini mungkin tidak diizinkan
untuk alasan konsistensi
(#1377).
Jika is_stable
bernilai benar, berarti penyortirannya stabil, yaitu, urutan relatif
elemen yang dianggap setara dengan pembanding dipertahankan. Untuk kasus ini
jika ada input tunggal, dua elemen e1
dan e2
dianggap
sama dengan pembanding jika dan hanya jika
comparator(e1, e2) = comparator(e2, e1) = false
. Lihat formalisasi di bawah ini
untuk mengetahui bagaimana
generalisasi ke beberapa input.
Secara lebih formal, untuk semua result_index
di index_space(results[0])
:
adjusted_dimension = dimension >= 0 ? dimension : rank(inputs[0]) + dimension
.result_slice = [ri0, ..., :, ..., riR-1]
denganriN
merupakan individu elemen diresult_index
, dan:
disisipkan padaadjusted_dimension
.inputs_together = (inputs[0]..., ..., inputs[N-1]...)
.results_together[result_slice] = sort(inputs_together[result_slice], comparator_together)
.- dengan
sort
mengurutkan irisan 1 dimensi dalam urutan yang tidak menurun seperti bahwacomparator_together
akan menampilkantrue
jika argumen di sisi kiri adalah lebih kecil daripada argumen kedua sebelah kanan. def comparator_together(lhs_together, rhs_together): args = [] for (lhs_el, rhs_el) in zip(lhs_together, rhs_together): args.append(lhs_el) args.append(rhs_el) return comparator(*args)
(results[0]..., ..., results[N-1]...) = results_together
.
Input
Label | Nama | Jenis | Batasan |
---|---|---|---|
(I1) | inputs |
jumlah tensor atau tensor terkuantisasi per-tensor | (C1-C5) |
(I2) | dimension |
konstanta jenis si64 |
(C4) |
(I3) | is_stable |
konstanta jenis i1 |
|
(I4) | comparator |
fungsi | (C5) |
Output
Nama | Jenis | Batasan |
---|---|---|
results |
jumlah tensor atau tensor terkuantisasi per-tensor | (C2), (C3) |
Batasan
- (C1)
0 < size(inputs)
. - (C2)
type(inputs...) = type(results...)
. - (C3)
same(shape(inputs...) + shape(results...))
. - (C4)
-R <= dimension < R
, denganR = rank(inputs[0])
. - (C5)
comparator
memiliki jenis(tensor<E1>, tensor<E1>, ..., tensor<EN-1>, tensor<EN-1>) -> tensor<i1>
, denganEi = element_type(inputs[i])
.
Contoh
// %input0 = [[1, 2, 3], [3, 2, 1]]
// %input1 = [[3, 2, 1], [1, 2, 3]]
%result0, %result1 = "stablehlo.sort"(%input0, %input1) ({
^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>, %arg2: tensor<i64>, %arg3: tensor<i64>):
%predicate = "stablehlo.compare"(%arg0, %arg1) {
comparison_direction = #stablehlo<comparison_direction GT>
} : (tensor<i64>, tensor<i64>) -> tensor<i1>
"stablehlo.return"(%predicate) : (tensor<i1>) -> ()
}) {
dimension = 0 : i64,
is_stable = true
} : (tensor<2x3xi64>, tensor<2x3xi64>) -> (tensor<2x3xi64>, tensor<2x3xi64>)
// %result0 = [[3, 2, 3], [1, 2, 1]]
// %result1 = [[1, 2, 1], [3, 2, 3]]
sqrt
Semantik
Melakukan operasi root kuadrat berdasarkan elemen pada tensor operand
dan menghasilkan
Tensor result
. Bergantung pada jenis elemen, melakukan hal berikut:
- Untuk float:
squareRoot
dari IEEE-754. - Untuk bilangan kompleks: akar kuadrat kompleks.
- Untuk jenis terkuantisasi:
dequantize_op_quantize(sqrt, operand, type(result))
.
Input
Label | Nama | Jenis | Batasan |
---|---|---|---|
(I1) | operand |
tensor floating point atau tipe kompleks atau tensor terkuantisasi per-tensor | (C1) |
Output
Nama | Jenis | Batasan |
---|---|---|
result |
tensor floating point atau tipe kompleks atau tensor terkuantisasi per-tensor | (C1) |
Batasan
- (C1)
baseline_type(operand) = baseline_type(result)
.
Contoh
// %operand: [[0.0, 1.0], [4.0, 9.0]]
%result = "stablehlo.sqrt"(%operand) : (tensor<2x2xf32>) -> tensor<2x2xf32>
// %result: [[0.0, 1.0], [2.0, 3.0]]
kurangi
Semantik
Melakukan pengurangan berbasis elemen dari dua tensor lhs
dan rhs
serta menghasilkan
Tensor result
. Bergantung pada jenis elemen, melakukan hal berikut:
- Untuk bilangan bulat: pengurangan bilangan bulat.
- Untuk float:
subtraction
dari IEEE-754. - Untuk bilangan kompleks: pengurangan kompleks.
- Untuk jenis terkuantisasi:
dequantize_op_quantize(subtract, lhs, rhs, type(result))
.
Input
Label | Nama | Jenis | Batasan |
---|---|---|---|
(I1) | lhs |
tensor bilangan bulat, floating point, atau tipe kompleks atau tensor terkuantisasi per-tensor | (C1) |
(I2) | rhs |
tensor bilangan bulat, floating point, atau tipe kompleks atau tensor terkuantisasi per-tensor | (C1) |
Output
Nama | Jenis | Batasan |
---|---|---|
result |
tensor bilangan bulat, floating point, atau tipe kompleks atau tensor terkuantisasi per-tensor | (C1) |
Batasan
- (C1)
baseline_type(lhs) = baseline_type(rhs) = baseline_type(result)
.
Contoh
// %lhs: [[6, 8], [10, 12]]
// %rhs: [[5, 6], [7, 8]]
%result = "stablehlo.subtract"(%lhs, %rhs) : (tensor<2x2xf32>, tensor<2x2xf32>) -> (tensor<2x2xf32>)
// %result: [[1, 2], [3, 4]]
tan
Semantik
Melakukan operasi tangen berbasis elemen pada tensor operand
dan menghasilkan
Tensor result
. Bergantung pada jenis elemen, melakukan hal berikut:
- Untuk float:
tan
dari IEEE-754. - Untuk bilangan kompleks: tangen kompleks.
- Untuk jenis terkuantisasi:
dequantize_op_quantize(tan, operand, type(result))
.
Input
Label | Nama | Jenis | Batasan |
---|---|---|---|
(I1) | operand |
tensor floating point atau tipe kompleks atau tensor terkuantisasi per-tensor | (C1) |
Output
Nama | Jenis | Batasan |
---|---|---|
result |
tensor floating point atau tipe kompleks atau tensor terkuantisasi per-tensor | (C1) |
Batasan
- (C1)
baseline_type(operand) = baseline_type(result)
.
Contoh
// %operand: [
// [0.0, 1.57079632], // [0, pi/2]
// [3.14159265, 4.71238898] // [pi, 3pi/2]
// ]
%result = "stablehlo.tan"(%operand) : (tensor<2x2xf64>) -> tensor<2x2xf64>
// %result: [
// [0.0, 1.63312e+16],
// [0.0, 5.44375e+15]
// ]
Tanh
Semantik
Melakukan operasi tangen hiperbolik berbasis elemen pada tensor operand
dan
menghasilkan tensor result
. Bergantung pada jenis elemen, melakukan hal berikut:
- Untuk float:
tanh
dari IEEE-754. - Untuk bilangan kompleks: tangen hiperbolik kompleks.
- Untuk jenis terkuantisasi:
dequantize_op_quantize(tanh, operand, type(result))
.
Input
Label | Nama | Jenis | Batasan |
---|---|---|---|
(I1) | operand |
tensor floating point atau tipe kompleks atau tensor terkuantisasi per-tensor | (C1) |
Output
Nama | Jenis | Batasan |
---|---|---|
result |
tensor floating point atau tipe kompleks atau tensor terkuantisasi per-tensor | (C1) |
Batasan
- (C1)
baseline_type(operand) = baseline_type(result)
.
Contoh
// %operand: [-1.0, 0.0, 1.0]
%result = "stablehlo.tanh"(%operand) : (tensor<3xf32>) -> tensor<3xf32>
// %result: [-0.76159416, 0.0, 0.76159416]
{i>transpose<i}
Semantik
Mengubah dimensi tensor operand
menggunakan permutation
dan menghasilkan
Tensor result
. Secara lebih formal, result[result_index] = operand[operand_index]
dengan result_index[d] = operand_index[permutation[d]]
.
Input
Label | Nama | Jenis | Batasan |
---|---|---|---|
(I1) | operand |
tensor atau quantized tensor | (C1-C4) |
(I2) | permutation |
Konstanta tensor 1 dimensi jenis si64 |
(C2-C4) |
Output
Nama | Jenis | Batasan |
---|---|---|
result |
tensor atau quantized tensor | (C1), C3-C4 |
Batasan
- (C1)
element_type(result)
diberikan oleh:element_type(operand)
, jika!is_per_axis_quantized(operand)
.element_type(operand)
kecualiquantization_dimension(operand)
danquantization_dimension(result)
dapat berbeda, jika tidak.
- (C2)
permutation
adalah permutasi darirange(rank(operand))
. - (C3)
shape(result) = dim(operand, permutation...)
. - (C4) Jika
is_per_axis_quantized(result)
, makaquantization_dimension(operand) = permutation(quantization_dimension(result))
.
Contoh
// %operand: [
// [[1,2], [3,4], [5,6]],
// [[7,8], [9,10], [11,12]]
// ]
%result = "stablehlo.transpose"(%operand) {
permutation = array<i64: 2, 1, 0>
} : (tensor<2x3x2xi32>) -> tensor<2x3x2xi32>
// %result: [
// [[1,7], [3,9], [5,11]],
// [[2,8], [4,10], [6,12]]
// ]
triangular_solve
Semantik
Menyelesaikan kumpulan sistem persamaan linear dengan segitiga bawah atau atas matriks koefisien.
Secara lebih formal, mengingat a
dan b
, result[i0, ..., iR-3, :, :]
adalah solusinya
ke op(a[i0, ..., iR-3, :, :]) * x = b[i0, ..., iR-3, :, :]
saat left_side
adalah
true
atau x * op(a[i0, ..., iR-3, :, :]) = b[i0, ..., iR-3, :, :]
saat
left_side
adalah false
, menyelesaikan variabel x
tempat op(a)
ditentukan
paling lambat transpose_a
, yang dapat berupa salah satu dari hal berikut:
NO_TRANSPOSE
: Menjalankan operasi menggunakana
sebagaimana adanya.TRANSPOSE
: Melakukan operasi pada transposisia
.ADJOINT
: Melakukan operasi pada transposisi konjugasia
.
Data input hanya dibaca dari segitiga bawah a
, jika lower
adalah true
atau
segitiga atas a
, jika sebaliknya. Data output ditampilkan dalam segitiga yang sama;
nilai dalam segitiga lainnya merupakan manifestasi implementasi.
Jika unit_diagonal
bernilai benar, implementasi dapat mengasumsikan bahwa ukuran diagonal
elemen a
sama dengan 1, jika tidak, perilaku tidak terdefinisi.
Untuk jenis terkuantisasi, menjalankan
dequantize_op_quantize(lambda x, y: triangular_solve(x, y, left_side, lower,
unit_diagonal, transpose_a), a, b, type(result))
.
Input
Label | Nama | Jenis | Batasan |
---|---|---|---|
(I1) | a |
tensor floating point atau tipe kompleks atau tensor terkuantisasi per-tensor | (C1-C3) |
(I2) | b |
tensor floating point atau tipe kompleks atau tensor terkuantisasi per-tensor | (C1-C4) |
(I3) | left_side |
konstanta jenis i1 |
(C3) |
(I4) | lower |
konstanta jenis i1 |
|
(I5) | unit_diagonal |
konstanta jenis i1 |
|
(I6) | transpose_a |
enum NO_TRANSPOSE , TRANSPOSE , dan ADJOINT |
Output
Nama | Jenis | Batasan |
---|---|---|
result |
tensor floating point atau tipe kompleks atau tensor terkuantisasi per-tensor | (C1) |
Batasan
- (C1)
baseline_element_type(a) = baseline_element_type(b)
. - (C2)
2 <= rank(a) = rank(b) = R
. - (C3) Hubungan antara
shape(a)
danshape(b)
didefinisikan sebagai berikut:shape(a)[:-3] = shape(b)[:-3]
.dim(a, -2) = dim(a, -1) = dim(b, left_side ? -2 : -1)
.
- (C4)
baseline_type(b) = baseline_type(result)
.
Contoh
// %a = [
// [1.0, 0.0, 0.0],
// [2.0, 4.0, 0.0],
// [3.0, 5.0, 6.0]
// ]
// %b = [
// [2.0, 0.0, 0.0],
// [4.0, 8.0, 0.0],
// [6.0, 10.0, 12.0]
// ]
%result = "stablehlo.triangular_solve"(%a, %b) {
left_side = true,
lower = true,
unit_diagonal = false,
transpose_a = #stablehlo<transpose NO_TRANSPOSE>
} : (tensor<3x3xf32>, tensor<3x3xf32>) -> tensor<3x3xf32>
// %result: [
// [2.0, 0.0, 0.0],
// [0.0, 2.0, 0.0],
// [0.0, 0.0, 2.0]
// ]
tuple
Semantik
Menghasilkan tuple result
dari nilai val
.
Input
Label | Nama | Jenis | Batasan |
---|---|---|---|
(I1) | val |
jumlah nilai variadis | (C1) |
Output
Nama | Jenis | Batasan |
---|---|---|
result |
tuple | (C1) |
Batasan
- (C1)
result
memiliki jenistuple<E0, ..., EN-1>
denganEi = type(val[i])
.
Contoh
// %val0: [1.0, 2.0]
// %val1: (3)
%result = "stablehlo.tuple"(%val0, %val1) : (tensor<2xf32>, tuple<tensor<i32>>) -> tuple<tensor<2xf32>, tuple<tensor<i32>>>
// %result: ([1.0, 2.0], (3))
uniform_dequantize
Semantik
Melakukan konversi berbasis elemen dari tensor terkuantisasi operand
ke
tensor floating point result
sesuai dengan parameter kuantisasi yang ditentukan
oleh jenis operand
.
Secara lebih formal, result = dequantize(operand)
.
Input
Label | Nama | Jenis | Batasan |
---|---|---|---|
(I1) | operand |
tensor terkuantisasi | (C1), C2 |
Output
Nama | Jenis | Batasan |
---|---|---|
result |
tensor jenis floating point | (C1), C2 |
Batasan
- (C1)
shape(operand) = shape(result)
. - (C2)
element_type(result) = expressed_type(operand)
.
Contoh
// %operand: [10, 10]
%result = "stablehlo.uniform_dequantize"(%operand) : (tensor<2x!quant.uniform<i8:f32:0, {0.1:-30,0.5:-20}>>) -> tensor<2xf32>
// %result: [4.0, 15.0]
uniform_quantize
Semantik
Melakukan konversi berbasis elemen dari tensor floating point atau tensor terkuantisasi
operand
ke tensor terkuantisasi result
sesuai dengan kuantisasi
parameter yang ditentukan oleh jenis result
.
Secara lebih formal,
- Jika
is_float(operand)
:result = quantize(operand, type(result))
.
- Jika
is_quantized(operand)
:float_result = dequantize(operand)
.result = quantize(float_result, type(result))
.
Input
Label | Nama | Jenis | Batasan |
---|---|---|---|
(I1) | operand |
tensor floating point atau tipe terkuantisasi | (C1), C2 |
Output
Nama | Jenis | Batasan |
---|---|---|
result |
tensor terkuantisasi | (C1), C2 |
Batasan
- (C1)
shape(operand) = shape(result)
. - (C2)
expressed_type(result) = is_float(operand) ? element_type(operand) : expressed_type(operand)
.
Contoh
// %operand: [4.0, 15.0]
%result = "stablehlo.uniform_quantize"(%operand) : (tensor<2xf32>) -> tensor<2x!quant.uniform<i8:f32:0, {0.1:-30,0.5:-20}>>
// %result: [10, 10]
// %operand: [10, 10]
%result = "stablehlo.uniform_quantize"(%operand) : (tensor<2x!quant.uniform<i8:f32:0, {0.1:-30,0.5:-20}>>) -> tensor<2x!quant.uniform<i8:f32:0, {0.1:-20,0.2:-30}>>
// %result: [20, 45]
sementara
Semantik
Menghasilkan output dari eksekusi fungsi body
0 kali atau lebih saat
Fungsi cond
menghasilkan true
. Secara lebih formal, semantik dapat dinyatakan
menggunakan sintaks Python sebagai berikut:
internal_state = operand
while cond(*internal_state):
internal_state = body(*internal_state)
results = internal_state
Perilaku loop tanpa batas akan ditentukan nanti (#383).
Input
Label | Nama | Jenis | Batasan |
---|---|---|---|
(I1) | operand |
jumlah tensor, tensor atau token terkuantisasi | (C1-C3) |
(I2) | cond |
fungsi | (C1) |
(I3) | body |
fungsi | (C2) |
Output
Nama | Jenis | Batasan |
---|---|---|
results |
jumlah tensor, tensor atau token terkuantisasi | (C3) |
Batasan
- (C1)
cond
memiliki jenis(T0, ..., TN-1) -> tensor<i1>
, denganTi = type(operand[i])
. - (C2)
body
memiliki jenis(T0, ..., TN-1) -> (T0, ..., TN-1)
, denganTi = type(operand[i])
. - (C3)
type(results...) = type(operand...)
.
Contoh
// %init_i: 1
// %init_sum: 0
// %one: 1
// %ten: 10
%results0, %results1 = "stablehlo.while"(%init_i, %init_sum) ({
^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
%cond = "stablehlo.compare"(%arg0, %ten) {
comparison_direction = #stablehlo<comparison_direction LT>
} : (tensor<i64>, tensor<i64>) -> tensor<i1>
stablehlo.return %cond : tensor<i1>
}, {
^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
%new_sum = stablehlo.add %arg1, %one : tensor<i64>
%new_i = stablehlo.add %arg0, %one : tensor<i64>
stablehlo.return %new_i, %new_sum : tensor<i64>, tensor<i64>
}) : (tensor<i64>, tensor<i64>) -> (tensor<i64>, tensor<i64>)
// %results0: 10
// %results1: 10
Xor
Semantik
Melakukan XOR berbasis elemen dari dua tensor lhs
dan rhs
serta menghasilkan result
tensor. Bergantung pada jenis elemen, melakukan hal berikut:
- Untuk boolean: XOR logis.
- Untuk bilangan bulat: bitwise XOR.
Input
Label | Nama | Jenis | Batasan |
---|---|---|---|
(I1) | lhs |
tensor jenis boolean atau bilangan bulat | (C1) |
(I2) | rhs |
tensor jenis boolean atau bilangan bulat | (C1) |
Output
Nama | Jenis | Batasan |
---|---|---|
result |
tensor jenis boolean atau bilangan bulat | (C1) |
Batasan
- (C1)
type(lhs) = type(rhs) = type(result)
.
Contoh
// Bitwise operation with with integer tensors
// %lhs: [[1, 2], [3, 4]]
// %rhs: [[5, 6], [7, 8]]
%result = "stablehlo.xor"(%lhs, %rhs) : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
// %result: [[4, 4], [4, 12]]
// Logical operation with with boolean tensors
// %lhs: [[false, false], [true, true]]
// %rhs: [[false, true], [false, true]]
%result = "stablehlo.xor"(%lhs, %rhs) : (tensor<2x2xi1>, tensor<2x2xi1>) -> tensor<2x2xi1>
// %result: [[false, true], [true, false]]
Interop Dialek
Saat ini, program StableHLO di luar jaringan terkadang berisi operasi yang tidak ditentukan oleh StableHLO.
Modul, Fungsi, Panggilan, dan Kembali
StableHLO menggunakan operasi MLIR upstream untuk ModuleOp, FuncOp, CallOp, dan ReturnOp. Hal ini dilakukan untuk interop yang lebih baik dengan mesin MLIR yang ada, karena kartu berguna yang menargetkan FuncOp dan ModuleOp, serta banyak kompilasi pipeline mengharapkan adanya operasi ini. Jaminan kompatibilitas penuh yang diterapkan pada operasi ini. Jika ada yang berubah pada operasi ini dalam dengan cara yang tidak kompatibel (yaitu penghapusan), padanan StableHLO akan ditambahkan untuk mempertahankan kompatibilitas mundur.
CHLO
Opset CHLO berisi operasi tingkat lebih tinggi yang terurai menjadi StableHLO. Saat ini tidak ada jaminan kompatibilitas untuk CHLO. Untuk kompatibilitas jaminan, pass chlo-legalize-to-stablehlo harus digunakan sebelum serialisasi.
Operasi Bentuk
Praktik umum dalam komunitas adalah menggunakan operasi tertentu dari
Dialek MLIR dalam program StableHLO dinamis untuk melakukan komputasi bentuk.
Umumnya, ini mencakup dialek shape
operasi seperti shape_of
atau num_elements
, tensor
dialek
operasi seperti dim
atau from_elements
, dan jenis index
bawaan.
Dynamism RFC > O2
menunjukkan ini sebagai di luar cakupan, namun beberapa dukungan untuk jenis index
disertakan untuk tujuan interop. Tidak ada jaminan kompatibilitas untuk template ini
operasi atau jenis. Opsi shape-legalize-to-stablehlo
dapat digunakan untuk mengonversi operasi ini menjadi operasi StableHLO yang didukung penuh.
Operasi yang Tidak Digunakan Lagi
Ada beberapa operasi StableHLO yang diturunkan dari MHLO yang tidak digunakan lagi dan akan keluar dari StableHLO. Detail lengkap tentang penghapusan dapat ditemukan di StableHLO v1.0 Cleanup #2283. Masalah pelacak untuk penghentian ini adalah #2340.
Operasi ini termasuk dalam beberapa kategori:
- "Tidak di HLO" kategori operasi StabilHLO - mereka awalnya adalah bagian dari
opset StableHLO tetapi kemudian dianggap tidak cocok dengan baik:
broadcast
,create_token
,cross-replica-sum
,dot
,einsum
,torch_index_select
,unary_einsum
(#3). - Operasi yang tidak digunakan - Operasi ini mungkin berguna pada titik tertentu, tetapi operasinya
belum dikembangkan, atau {i>pipelines<i}
yang menggunakan operasi ini telah
difaktorkan ulang sehingga tidak diperlukan lagi. Ini mencakup
map
,tuple
(#598),get_tuple_element
,rng
,complex
membandingkan #560, dan konvolusiwindow_reversal
(#1181).
Beberapa dari operasi ini dapat dihapus dengan mudah
karena mereka dapat dinyatakan menggunakan
operasi yang ada (broadcast
, create_token
, cross-replica-sum
, dot
,
unary_einsum
) dan akan dihapus setelah jendela kompatibilitas yang ada
izin lewat (6 bulan). URL lainnya masih dalam proses peninjauan untuk dihapus (einsum
,
get_tuple_element
, map
, rng
torch_index_select
, tuple
, complex
perbandingan, window_reversal
). Menunggu masukan dari komunitas,
operasi ini akan dihapus, atau
ditambahkan ke spesifikasi dengan dukungan penuh. Sampai
operasi berjangka ini diketahui, mereka hanya dijamin kompatibilitasnya selama 6 bulan.
Eksekusi
Eksekusi berurutan
Program StableHLO dijalankan dengan memberikan nilai input ke fungsi main
dan menghitung nilai output. Nilai {i>output<i} dari sebuah fungsi dihitung oleh
mengeksekusi grafik operasi yang di-root pada operasi return
yang sesuai.
Urutan eksekusi ditentukan oleh implementasi selama sesuai dengan
dataflow, yaitu jika operasi dijalankan sebelum digunakan. Di StabilHLO, semua
dan operasi {i>side-effecting<i} memakai satu token dan
menghasilkan satu token (beberapa token dapat
di-multiplex menjadi satu token melalui after_all
), sehingga urutan eksekusi dari sisi
juga selaras dengan dataflow. Misalnya, dalam program di bawah ini
ada dua kemungkinan perintah eksekusi: %0
→ %1
→ %2
→ return
dan
%1
→ %0
→ %2
→ return
.
func.func @main() -> tensor<f64> {
%0 = stablehlo.constant dense<1.0> : tensor<f64>
%1 = stablehlo.constant dense<2.0> : tensor<f64>
%2 = stablehlo.add %0, %1 : tensor<f64>
return %2 : tensor<f64>
}
Secara lebih formal, proses SttableHLO adalah kombinasi dari:
1) program StableHLO, 2) status operasi (belum dieksekusi,
sudah dieksekusi), dan 3) nilai antara yang sedang dikerjakan oleh proses.
Proses ini dimulai dengan nilai input ke fungsi main
, dilanjutkan hingga
grafik operasi yang memperbarui status operasi dan nilai perantara, serta
diakhiri dengan nilai output. Formalisasi lebih lanjut akan ditentukan nanti
(#484).
Eksekusi paralel
Program StableHLO dapat dijalankan secara paralel, yang disusun ke dalam grid proses 2D
dari num_replicas
oleh num_partitions
yang keduanya memiliki jenis ui32
.
Di petak proses SttableHLO, num_replicas * num_partitions
dari StableHLO
proses dijalankan secara bersamaan. Setiap proses memiliki
process_id = (replica_id, partition_id)
, dengan
replica_id
dalam replica_ids = range(num_replicas)
dan
partition_id
di partition_ids = range(num_partitions)
yang keduanya memiliki
ketik ui32
.
Ukuran grid proses diketahui secara statis untuk setiap program (dalam
di masa mendatang, kami berencana untuk menjadikannya bagian eksplisit dari program StableHLO
#650), dan posisi
di dalam {i>process grid <i}dikenal
secara statis untuk setiap proses. Setiap proses memiliki
akses ke posisinya dalam petak proses melalui replica_id
dan
partition_id
op.
Di dalam {i>process grid<i}, semua program bisa jadi sama (dalam kolom “{i>Single<i} {i>Program<i}, {i>Multiple Data<i}” berbeda), semuanya bisa berbeda (dalam "Multiple Program, Multi Data" {i>style<i}) atau sesuatu di antaranya. Di masa mendatang, kami berencana memperkenalkan dukungan untuk idiom lain dalam mendefinisikan program StableHLO paralel, termasuk GSPMD (#619).
Di dalam kisi proses, proses sebagian besar independen satu sama lain - keduanya memiliki status operasi terpisah, nilai input/menengah/output terpisah dan sebagian besar operasi dijalankan secara terpisah antar proses, dengan pengecualian untuk sejumlah kecil operasi kolektif yang dijelaskan di bawah ini.
Mengingat bahwa eksekusi sebagian besar operasi hanya menggunakan nilai dari
data, biasanya tidak ambigu untuk merujuk pada nilai-nilai ini dengan namanya.
Namun, ketika menjelaskan semantik operasi kolektif, itu tidak cukup, dan
yang memunculkan notasi name@process_id
untuk merujuk ke nilai name
dalam proses tertentu. (Dari perspektif tersebut, name
yang tidak memenuhi syarat dapat
dilihat sebagai singkatan untuk name@(replica_id(), partition_id())
).
Urutan eksekusi di seluruh proses adalah implementasi yang ditentukan, kecuali untuk sinkronisasi yang diperkenalkan oleh komunikasi point-to-point dan operasi kolektif sebagaimana dijelaskan di bawah ini.
Komunikasi titik ke titik
Proses StabilHLO dapat
berkomunikasi satu sama lain melalui
Saluran SttableHLO. Saluran diwakili oleh ID positif jenis
si64
. Melalui berbagai operasi, adalah mungkin
untuk mengirim nilai ke saluran dan
menerimanya dari saluran.
Formalisasi lebih lanjut, mis. dari mana ID channel ini berasal, bagaimana proses yang menyadarinya dan jenis sinkronisasi apa yang diperkenalkan oleh mereka, akan ditentukan nanti (#484).
Komunikasi streaming
Setiap proses StableHLO memiliki akses ke dua antarmuka streaming:
- Infeed yang dapat dibaca.
- Outfeed yang dapat ditulis.
Tidak seperti saluran, yang digunakan untuk berkomunikasi antar proses dan karenanya memiliki proses di kedua ujungnya, feed infeed dan outfeed memiliki yang ditentukan oleh implementasi akhir.
Formalisasi lebih lanjut, mis. bagaimana komunikasi streaming memengaruhi eksekusi dan jenis sinkronisasi yang dilakukannya, akan ditentukan nanti (#484).
Operasi kolektif
Ada enam operasi kolektif di StableHLO: all_gather
, all_reduce
,
all_to_all
, collective_broadcast
, collective_permute
, dan
reduce_scatter
. Semua operasi ini membagi proses dalam proses StableHLO
grid ke dalam grup proses SttableHLO dan jalankan komputasi gabungan dalam
setiap grup proses, terpisah
dari grup proses lainnya.
Dalam setiap grup proses, operasi kolektif dapat memperkenalkan sinkronisasi hambatan. Formalisasi lebih lanjut, mis. dan menjelaskan kapan sinkronisasi terjadi, bagaimana sebenarnya proses sampai pada penghalang ini, dan apa yang terjadi jika tidak dilakukan, akan ditentukan nanti (#484).
Jika grup proses melibatkan komunikasi lintas partisi, yaitu ada
proses dalam grup proses yang ID partisinya berbeda, lalu dieksekusi
operasi kolektif membutuhkan saluran, dan operasi kolektif harus menyediakan
channel_id
positif dari jenis si64
. Komunikasi lintas replika
tidak perlu
saluran TV Anda.
Komputasi yang dilakukan oleh operasi kolektif bersifat khusus untuk operasi individual dan dijelaskan di bagian pengoperasian individual di atas. Namun, strategi dengan di mana grid proses dibagi menjadi grup proses yang dibagikan di antara operasi ini di bagian ini dan dijelaskan dalam bagian ini. Secara lebih formal, StableHLO mendukung mengikuti empat strategi berikut.
cross_replica
Hanya komunikasi lintas replika yang terjadi dalam setiap grup proses. Ini
menggunakan replica_groups
- daftar ID replika - dan komputasi
hasil Kartesius dari replica_groups
oleh partition_ids
. replica_groups
harus memiliki elemen yang unik dan mencakup semua replica_ids
. Secara lebih formal, menggunakan
{i>Syntax<i} Python:
def cross_replica(replica_groups: List[List[ReplicaId]]) -> List[List[ProcessId]]:
for replica_group in replica_groups:
for partition_id in partition_ids:
process_group = []
for replica_id in replica_group:
process_group.append((replica_id, partition_id))
yield process_group
Misalnya, untuk replica_groups = [[0, 1], [2, 3]]
dan num_partitions = 2
,
cross_replica
akan menghasilkan
[[(0, 0), (1, 0)], [(0, 1), (1, 1)], [(2, 0), (3, 0)], [(2, 1), (3, 1)]]
.
cross_partition
Hanya komunikasi lintas partisi yang terjadi dalam setiap grup proses. Ini
menggunakan partition_groups
- daftar ID partisi - dan
menghitung produk Kartesius partition_groups
dari replica_ids
.
partition_groups
harus memiliki elemen yang unik dan mencakup semua partition_ids
.
Secara lebih formal, menggunakan sintaksis Python:
def cross_partition(partition_groups: List[List[PartitionId]]) -> List[List[ProcessId]]:
for partition_group in partition_groups:
for replica_id in replica_ids:
process_group = []
for partition_id in partition_group:
process_group.append((replica_id, partition_id))
yield process_group
Misalnya, untuk partition_groups = [[0, 1]]
dan num_replicas = 4
,
cross_partition
akan menghasilkan
[[(0, 0), (0, 1)], [(1, 0), (1, 1)], [(2, 0), (2, 1)], [(3, 0), (3, 1)]]
.
cross_replica_and_partition
Baik komunikasi lintas replika maupun lintas partisi dapat terjadi di dalam setiap
{i>process group<i}. Strategi ini menggunakan replica_groups
- daftar yang berisi
ID replika - dan menghitung produk Kartesius dari setiap replica_group
dengan
partition_ids
. replica_groups
harus memiliki elemen yang unik dan mencakup semua
replica_ids
. Secara lebih formal, menggunakan sintaksis Python:
def cross_replica_and_partition(replica_groups: List[List[ReplicaId]]) -> List[List[ProcessId]]:
for replica_group in replica_groups:
process_group = []
for partition_id in partition_ids:
for replica_id in replica_group:
process_group.append((replica_id, partition_id))
yield process_group
Misalnya, untuk replica_groups = [[0, 1], [2, 3]]
dan num_partitions = 2
,
cross_replica_and_partition
akan menghasilkan
[[(0, 0), (1, 0), (0, 1), (1, 1)], [(2, 0), (3, 0), (2, 1), (3, 1)]]
.
flattened_ids
Strategi ini menggunakan flattened_id_groups
- daftar daftar "diratakan"
ID proses dalam bentuk replica_id * num_partitions + partition_id
- dan
mengubahnya menjadi ID proses. flattened_id_groups
harus memiliki elemen yang unik
dan mencakup semua process_ids
. Secara lebih formal, menggunakan sintaksis Python:
def flattened_ids(flattened_id_groups: List[List[ui32]]) -> List[List[ProcessId]]:
for flattened_id_group in flattened_id_groups:
process_group = []
for flattened_id in flattened_id_group:
replica_id = flattened_id // num_partitions
partition_id = flattened_id % num_partitions
process_group.append((replica_id, partition_id))
yield process_group
Misalnya, untuk flattened_id_groups = [[0, 1, 2, 3], [4, 5, 6, 7]]
,
num_replicas = 4
dan num_partitions = 2
, flattened_ids
akan menghasilkan
[[(0, 0), (0, 1), (1, 0), (1, 1)], [(2, 0), (2, 1), (3, 0), (3, 1)]]
.
Akurasi
Saat ini, StableHLO tidak memberikan jaminan tentang akurasi numerik, tapi ini bisa berubah pada masa mendatang (#1156).
Semantik eksekusi operasi terkuantisasi
Penafsiran operasi StableHLO terkuantisasi dapat bervariasi bergantung pada persyaratan dan kemampuan perangkat kerasnya. Misalnya, beberapa perangkat keras mungkin memilih untuk menafsirkan operasi terkuantisasi menggunakan metode "{i>dequantize, do floating-point<i} operasi, dan terakhir "mengkuantisasikan" strategi. Orang lain dapat melakukan seluruh komputasi dengan aritmatika bilangan bulat. Karenanya, interpretasi dari operasi StableHLO terkuantisasi ditentukan secara eksklusif oleh terlepas dari implementasi layanan. Penafsiran kuantisasi hibrida (#1575) harus didasarkan pada semantik seperti yang ditentukan dalam spesifikasi (melalui 1792).
Error
Program StabilHLO divalidasi melalui berbagai kendala untuk operasi individu, yang mengatur banyak kelas {i>error<i} sebelum waktu proses. Namun, kondisi error masih mungkin terjadi, misalnya melalui overflow bilangan bulat, akses di luar batas, dll. Kecuali dinyatakan lain secara eksplisit, semua error ini menghasilkan perilaku yang ditentukan implementasi, tetapi ini dapat berubah dalam mendatang (#1157).
Pengecualian floating point
Sebagai pengecualian untuk aturan ini, pengecualian floating point dalam program StableHLO
memiliki perilaku yang jelas. Operasi yang menghasilkan pengecualian yang ditentukan oleh
Standar IEEE-754 (operasi tidak valid, pembagian dengan nol, overflow, underflow, atau
pengecualian tidak tepat) menghasilkan hasil default (seperti yang didefinisikan dalam standar) dan
melanjutkan eksekusi tanpa menaikkan tanda status yang sesuai; mirip dengan
raiseNoFlag
penanganan pengecualian dari standar. Pengecualian untuk non-standar
operasi (mis. aritmatika kompleks dan fungsi transendental tertentu) merupakan
ditentukan oleh implementasi.
Ketidakcocokan bentuk
StableHLO mendukung tensor berbentuk dinamis. Namun, bentuk harus sesuai dengan runtime, jika tidak, perilaku tidak ditentukan. StabilHLO tidak secara eksplisit menyediakan operasi yang dapat menyatakan bahwa tensor memiliki bentuk tertentu saat runtime. Membuat kode yang benar adalah tanggung jawab produsen.
Sebagai contoh spesifik, program di bawah valid. Namun, saat runtime,
bentuk %arg0
dan %arg1
yang sama harus sama. Jika tidak,
perilaku program tidak terdefinisi:
func.func @foo(%arg0: tensor<?xi32>, %arg1: tensor<?xi32>) -> tensor<?xi32> {
%0 = stablehlo.add %arg0, %arg1 : tensor<?xi32>
return %0 : tensor<?xi32>
}
Notasi
Untuk menjelaskan sintaksis, dokumen ini menggunakan ragam ISO EBNF yang dimodifikasi
standar (ISO/IEC 14977:1996,
Wikipedia),
dengan dua modifikasi: 1) aturan ditentukan menggunakan ::=
, bukan =
,
2) penyambungan dinyatakan menggunakan penjajaran, bukan ,
.
Untuk menjelaskan semantik (yaitu dalam bagian "Types", "Constants", dan "Ops"), kami menggunakan formula yang didasarkan pada sintaksis Python, yang diperluas dengan untuk mengekspresikan operasi array secara ringkas seperti yang dijelaskan di bawah ini. Ini berfungsi dengan baik untuk cuplikan kode berukuran kecil, tetapi dalam kasus yang jarang terjadi, ketika cuplikan kode yang lebih besar diperlukan, kita menggunakan {i>syntax<i} vanilla Python yang selalu diperkenalkan secara eksplisit.
Formula
Mari kita pelajari cara kerja formula berdasarkan contoh dari dot_general
spesifikasi
pendukung. Salah satu batasan untuk operasi ini terlihat seperti berikut:
dim(lhs, lhs_batching_dimensions...) = dim(rhs, rhs_batching_dimensions...)
.
Nama-nama yang digunakan dalam formula ini berasal
dari dua sumber: 1) fungsi global,
yaitu dim
, 2) definisi anggota dari elemen program yang sesuai, yaitu
Input lhs
, lhs_batching_dimensions
, rhs
, dan rhs_batching_dimensions
yang ditentukan dalam "Input" dari dot_general
.
Seperti yang disebutkan di atas, {i>syntax<i} formula ini berbasis Python dengan beberapa ekstensi yang berorientasi keringkasan. Untuk memahami formulanya, mari kita ubah menjadi sintaks vanilla Python.
A) Dalam formula ini, kita menggunakan =
untuk mewakili kesetaraan. Jadi, langkah pertama
untuk mendapatkan sintaksis Python, mengganti =
dengan ==
, seperti berikut:
dim(lhs, lhs_batching_dimensions...) == dim(rhs, rhs_batching_dimensions...)
.
B) Selain itu, rumus ini mendukung elipsis (...
) yang mengubah ekspresi skalar
menjadi ekspresi tensor. Singkatnya, f(xs...)
secara kasar berarti "untuk setiap
x
skalar pada tensor xs
, hitung f(x)
skalar, lalu tampilkan semuanya
hasil skalar ini bersama-sama
sebagai hasil tensor". Dalam sintaks{i> <i}vanilla Python,
contoh formula kita berubah menjadi:
[dim(lhs, dim1) for dim1 in lhs_batching_dimensions] ==
[dim(rhs, dim2) for dim2 in rhs_batching_dimensions]
.
Berkat elips, sering kali kita bisa
menghindari bekerja pada tingkat
skalar individu. Namun, dalam beberapa kasus yang rumit, tingkat semi-informal di tingkat
dapat digunakan seperti di formula start_indices[bi0, ..., :, ..., biN]
dari spesifikasi gather
. Untuk memberikan keringkasan, kami tidak
menyediakan formalisme yang tepat untuk menerjemahkan
sintaks tersebut ke vanilla Python,
berharap bahwa hal itu masih dapat dipahami
secara intuitif berdasarkan kasus per kasus.
Harap beri tahu kami jika beberapa formula tertentu terlihat tidak tembus pandang, dan kami akan mencoba
memperbaikinya.
Anda juga akan memperhatikan bahwa formula menggunakan elips untuk memperluas segala macam daftar, termasuk tensor, daftar tensor (yang misalnya dapat muncul dari jumlah tensor), dll. Ini adalah area lain di mana kami tidak memberikan formalisme (mis., daftar bahkan bukan bagian dari sistem jenis StableHLO) dan mengandalkan pemahaman intuitif.
C) Kendaraan notasi penting terakhir yang kami gunakan adalah penyiaran. Meskipun {i>opset<i} StableHLO tidak mendukung penyiaran implisit, formula dilakukan, juga untuk membuat keringkasan. Secara singkat, jika skalar digunakan dalam konteks di mana tensor diharapkan, skalarnya disiarkan ke sesuai dengan bentuk yang diinginkan.
Untuk melanjutkan contoh dot_general
, berikut batasan lainnya:
0 <= lhs_batching_dimensions < rank(lhs)
. Seperti yang ditentukan dalam dot_general
spesifikasi, lhs_batching_dimensions
adalah tensor, tetapi 0
dan
rank(lhs)
adalah skalar. Setelah kita menerapkan
siaran implisit, formulanya akan
menjadi [0, ..., 0] <= lhs_batching_dimensions < [rank(lhs), ..., rank(lhs)]
.
Ketika diterapkan ke operasi dot_general
tertentu, formula ini akan
mengevaluasi ke tensor boolean. Ketika formula digunakan sebagai {i>constraints<i},
berlaku jika formula bernilai true
atau ke tensor yang
hanya memiliki true
elemen.
Nama
Dalam formula, cakupan leksikal meliputi: 1) fungsi global, 2) definisi anggota,
3) definisi lokal. Daftar fungsi global disediakan di bawah ini. Daftar definisi elemen tergantung pada elemen program yang notasinya diterapkan ke:
- Untuk operasi, definisi anggota menyertakan nama yang diperkenalkan dalam "Input" dan "Output" bagian.
- Untuk yang lainnya, definisi anggota
mencakup bagian struktural dari
, dinamai menurut non-terminal EBNF yang sesuai. Sebagian besar
waktu, nama-nama bagian struktural ini
diperoleh dengan mengubah
nama non-terminal untuk snake case (misalnya
IntegerLiteral
=>integer_literal
), tetapi terkadang nama disingkat dalam prosesnya (mis.QuantizationStorageType
=>storage_type
), dalam hal ini nama tersebut diperkenalkan secara eksplisit mirip dengan "Input" / "Output" bagian yang beroperasi spesifikasi produk. - Selain itu, definisi anggota selalu menyertakan
self
untuk merujuk ke elemen program yang sesuai.
Nilai
Saat dievaluasi, formula menggunakan jenis nilai berikut:
1) Value
(nilai sebenarnya, misalnya, dense<[[1, 2], [3, 4]]> : tensor<2x2xi32>
;
mereka selalu tahu
tipenya),
2) Placeholder
(nilai mendatang, misalnya lhs
, rhs
, atau result
; nilai sebenarnya
nilainya belum diketahui, hanya jenisnya yang diketahui),
3) Type
(jenis seperti yang ditentukan di bagian "Types"),
4) Function
(fungsi global seperti yang didefinisikan di bagian "Fungsi").
Bergantung pada konteksnya, nama mungkin merujuk pada nilai yang berbeda. Selengkapnya
khususnya, "Semantik" untuk operasi (dan yang setara untuk program lain
) menentukan logika runtime, sehingga semua input tersedia sebagai Value
.
Sebaliknya, daftar “{i>Constraints<i}” untuk operasi (dan yang setara) mendefinisikan
"waktu kompilasi" logika, yaitu sesuatu yang biasanya
dieksekusi sebelum {i>runtime<i},
jadi hanya input konstan yang tersedia sebagai Value
dan input lainnya yang
hanya tersedia sebagai Placeholder
.
Nama | Di "Semantik" | Di bagian "Constraints" |
---|---|---|
Fungsi global | Function |
Function |
Input konstan | Value |
Value |
Input tidak konstan | Value |
Placeholder |
Output | Value |
Placeholder |
Definisi lokal | Tergantung definisi | Tergantung definisi |
Mari kita pertimbangkan contoh operasi transpose
:
%result = "stablehlo.transpose"(%operand) {
permutation = dense<[2, 1, 0]> : tensor<3xi64>
} : (tensor<2x3x2xi32>) -> tensor<2x3x2xi32>
Untuk operasi ini, permutation
adalah konstanta, sehingga tersedia sebagai Value
dalam semantik dan batasan. Sebaliknya, operand
dan result
adalah
tersedia sebagai Value
dalam semantik, tetapi hanya sebagai Placeholder
dalam batasan.
Fungsi
Konstruksi jenis
Tidak ada fungsi yang dapat digunakan untuk membuat jenis. Sebaliknya, kita secara langsung
menggunakan sintaks jenis karena biasanya lebih ringkas. Mis.
(tensor<E>, tensor<E>) -> (tensor<E>)
, bukan function_type(
[tensor_type([], E), tensor_type([], E)], [tensor_type([], E)])
.
Fungsi pada jenis
element_type
ditentukan berdasarkan jenis tensor dan jenis tensor terkuantisasi, serta menampilkanTensorElementType
atauQuantizedTensorElementType
bagian dariTensorType
atauQuantizedTensorType
yang sesuai.
def element_type(x: Value | Placeholder | Type):
if type(x) == TensorType:
return tensor_element_type(x)
if type(x) == QuantizedTensorType:
return quantized_tensor_element_type(x)
if type(x) is not Type:
return element_type(type(x))
is_per_axis_quantized(x: Value | Placeholder | Type) -> Value
adalah pintasan untukis_quantized(x) and quantization_dimension(x) is not None
.is_per_tensor_quantized(x: Value | Placeholder | Type) -> Value
adalah untukis_quantized(x) and quantization_dimension(x) is None
.is_promotable(x: Type, y: Type) -> bool
memeriksa apakah jenisx
dapat dipromosikan untuk mengetiky
. Jikax
dany
adalahQuantizedTensorElementType
, promosi hanya diterapkan untukstorage_type
. Versi promosi khusus ini saat ini digunakan dalam konteks komputasi pengurangan (lihat RFC untuk detail selengkapnya).
def is_promotable(x: Type, y: Type) -> Value:
is_same_type = (is_bool(x) and is_bool(y)) or
(is_integer(x) and is_integer(y)) or (is_float(x) and is_float(y)) or
(is_complex(x) and is_complex(y)) or
(is_quantized(x) and is_quantized(y) and expressed_type(x) = expressed_type(y))
if is_same_type == False:
return False
if is_integer(x) or is_float(x):
return bitwidth(x) <= bitwidth(y)
if is_complex(x):
return bitwidth(element_type(x)) <= bitwidth(element_type(y))
if is_quantized(x):
return bitwidth(storage_type(x)) <= bitwidth(storage_type(y))
return false
is_quantized(x: Value | Placeholder | Type) -> Value
adalah pintasan untukis_quantized_tensor_element_type(x)
.is_type_name(x: Value | Placeholder | Type) -> Value
. Tersedia untuk semua jenis datanya. Misalnya,is_float(x)
menampilkantrue
jikax
adalahFloatType
. Jikax
adalah nilai atau placeholder, fungsi ini adalah pintasan untukis_type_name(type(x))
.max_value(x: Type) -> Value
menampilkan nilai maksimumTensorElementType
. Jikax
bukanTensorElementType
,None
akan ditampilkan.min_value(x: Type) -> Value
menampilkan nilai minimum yang memungkinkanTensorElementType
. Jikax
bukanTensorElementType
,None
akan ditampilkan.member_name(x: Value | Placeholder | Type) -> Any
. Tersedia untuk semua anggota definisimember_name
dari semua jenis. Misalnya,tensor_element_type(x)
menampilkan bagianTensorElementType
dariTensorType
yang sesuai. Jikax
adalah nilai atau placeholder, fungsi ini adalah pintasan untukmember_name(type(x))
. Jikax
bukan jenis yang memiliki anggota yang sesuai, atau nilai atau placeholder dari jenis tersebut, akan menampilkanNone
.is_empty_algorithm(*args: Type)
memeriksa apakah semua kolom algoritma titik telah ditetapkan keNone
. Ini diperlukan karena algoritma titik telah menentukan implementasi perilaku default, jadi menentukan nilai {i>default<i} akan menjadi tidak benar.
Konstruksi nilai
operation_name(*xs: Value | Type) -> Value
. Tersedia untuk semua operasi. Misalnya,add(lhs, rhs)
menggunakan dua nilai tensorlhs
danrhs
, lalu menampilkan output dari mengevaluasi operasiadd
dengan input ini. Untuk beberapa operasi, misalnyabroadcast_in_dim
, jenis output-nya "beban", yaitu yang diperlukan untuk mengevaluasi operasi. Dalam hal ini, fungsi menganggap jenis ini sebagai argumen.
Fungsi pada nilai
Semua operator dan fungsi Python tersedia. Mis. keduanya langganan dan mengiris notasi dari Python tersedia untuk diindeks menjadi tensor, tensor terkuantisasi dan tuple.
to_destination_type(x: Value, destination_type: Type) -> Value
ditentukan di dan menampilkan nilai yang dikonversi darix
berdasarkantype(x)
dandestination_type
sebagai berikut:
def to_destination_type(x: Value, destination_type: Type) -> Value:
if type(x) == destination_type:
return x
if is_quantized(destination_type):
if is_quantized(type(x)):
return quantize(x, destination_type)
assert is_float(type(x))
return quantize(x, destination_type)
if is_quantized(type(x)):
assert destination_type = expressed_type(type(x))
return dequantize(type(x))
return convert(x, destination_type)
Ada diskusi awal tentang penggabungan convert
, uniform_quantize
, dan
Operasi uniform_dequantize
(#1576).
Setelah penggabungan, kita tidak memerlukan fungsi di atas dan dapat menggunakan nama operasi
untuk convert
.
is_nan(x: Value) -> Value
ditentukan pada tensor dan menampilkantrue
jika semua elemenx
adalahNaN
ataufalse
jika tidak. Jikax
bukan tensor, akan menampilkanNone
.is_sorted(x: Value) -> Value
ditentukan pada tensor dan menampilkantrue
jika elemenx
diurutkan dalam urutan menaik sehubungan dengan urutan naik urutan leksikografis indeksnya ataufalse
jika tidak. Jikax
bukan , menghasilkanNone
.is_unique(x: Value) -> Value
ditentukan pada tensor dan menampilkantrue
jikax
tidak memiliki elemen duplikat ataufalse
. Jikax
bukan tensor, akan menampilkanNone
.member_name(x: Value) -> Any
ditentukan untuk semua definisi anggotamember_name
dari semua nilai. Misalnya,real_part(x)
menampilkanRealPart
dariComplexConstant
yang sesuai. Jikax
bukan nilai yang memiliki anggota yang sesuai, akan menampilkanNone
.same(x: Value) -> Value
ditentukan pada tensor dan menampilkantrue
jika elemenx
semuanya sama satu sama lain ataufalse
jika tidak. Jika tensor tidak memiliki elemen, yang dihitung sebagai “semua sama satu sama lain”, yaitu menampilkantrue
. Jikax
bukan tensor, tampilkanNone
.split(x: Value, num_results: Value, axis: Value) -> Value
ditentukan di tensor dan menampilkan irisannum_results
darix
di sepanjang sumbuaxis
. Jikax
bukan tensor ataudim(x, axis) % num_results != 0
, makaNone
akan ditampilkan.is_defined_in_parent_scope(x: Value) -> Value
ditentukan pada string dan menampilkantrue
jikax
adalah nama fungsi yang ditentukan dalam cakupan yang sama sebagai fungsi induk dari op yang relevan.is_namespaced_op_name(x: Value) -> Value
ditentukan pada string dan ditampilkantrue
jikax
adalah nama op yang valid, itu mengikuti metode reguler berikut ekspresi:[a-zA-Z][a-zA-Z0-9_]*([.][a-zA-Z0-9_$]+)+
Komputasi bentuk
axes(x: Value | Placeholder | Type) -> Value
adalah pintasan untukrange(rank(x))
.dim(x: Value | Placeholder | Type, axis: Value) -> Value
adalah pintasan untukshape(x)[axis]
.dims(x: Value | Placeholder | Type, axes: List) -> List
adalah pintasan untuklist(map(lambda axis: dim(x, axis), axes))
.index_space(x: Value | Placeholder | Type) -> Value
ditentukan pada tensor dan menampilkan indekssize(x)
untukTensorType
yang sesuai yang diurutkan dalam urutan leksikografis menaik, yaitu[0, ..., 0]
,[0, ..., 1]
, ...,shape(x) - 1
. Jikax
bukan jenis tensor, jenis tensor terkuantisasi, atau nilai atau placeholder dari salah satu jenis ini, akan menampilkanNone
.rank(x: Value | Placeholder | Type) -> Value
adalah pintasan untuksize(shape(x))
.shape(x: Value | Placeholder | Type) -> Value
ditentukan di bagian "Functions tentang jenis" melaluimember_name
.size(x: Value | Placeholder | Type) -> Value
adalah pintasan untukreduce(lambda x, y: x * y, shape(x))
.
Komputasi kuantisasi
def baseline_element_type(x: Value | Placeholder | Type) -> Type
adalah untukelement_type(baseline_type(x))
.baseline_type
ditentukan berdasarkan jenis tensor dan jenis tensor terkuantisasi, serta mengubahnya ke "baseline", yaitu tipe dengan bentuk yang sama tetapi dengan parameter kuantisasi jenis elemen direset ke nilai default. Ini adalah digunakan sebagai trik praktis untuk membandingkan tipe tensor dan tensor terkuantisasi seragam, yang sering dibutuhkan. Untuk jenis terkuantisasi, hal ini memungkinkan membandingkan jenis yang mengabaikan parameter kuantisasi, yaitushape
,storage_type
,expressed_type
,storage_min
,storage_max
, danquantization_dimension
(untuk jenis terkuantisasi per sumbu) harus semuanya cocok, tetapiscales
danzero points
mungkin berbeda.
def baseline_type(x: Value | Placeholder | Type) -> Type:
if type(x) == TensorType:
return x
if type(x) == QuantizedTensorType:
element_type = quantized_tensor_element_type(x)
baseline_element_type = QuantizedTensorElementType(
storage_type = storage_type(element_type),
storage_min = storage_min(element_type),
storage_max = storage_max(element_type),
expressed_type = expressed_type(element_type),
quantization_dimension = quantization_dimension(element_type),
scales = [constant(1.0, expressed_type(element_type))] * dim(x, quantization_dimension(element_type)),
zero_points = [constant(0, storage_type(element_type))] * dim(x, quantization_dimension(element_type)))
return QuantizedTensorType(shape(x), baseline_element_type)
if type(x) is not Type:
return baseline_element_type(type(x))
dequantize
ditentukan pada jenis tensor terkuantisasi dan mengubahnya menjadi tipe tensor floating point. Hal ini terjadi melalui konversi elemen terkuantisasi yang mewakili nilai integer dari jenis penyimpanan ke dalam kolom nilai floating point dari jenis yang dinyatakan menggunakan titik nol dan skala yang terkait dengan jenis elemen terkuantisasi.
def compute_zero_points(quantized_type, result_type):
if is_per_tensor_quantized(quantized_type):
return broadcast_in_dim(constant(zero_point(quantized_type), storage_type(quantized_type)), [], result_type)
if is_per_axis_quantized(quantized_type):
for i in index_space(result_type):
d = quantization_dimension(quantized_type)
zero_points[i] = zero_points(quantized_type)[i[d]]
return zero_points
def compute_scales(quantized_type, result_type):
if is_per_tensor_quantized(quantized_type):
return broadcast_in_dim(constant(scale(quantized_type), expressed_type(quantized_type)), [],
type(result_type))
if is_per_axis_quantized(quantized_type):
for i in index_space(result_type):
d = quantization_dimension(quantized_type)
scales[i] = scales(quantized_type)[i[d]]
return scales
def dequantize(x: Value) -> Value:
assert is_quantized(x)
x_storage = bitcast_convert(x, storage_type(x))
x_storage_sub = x_storage - compute_zero_points(type(x), type(x_storage))
x_expressed_sub = convert(x_storage_sub, expressed_type(x))
return x_expressed_sub * compute_scales(type(x), type(x_expressed_sub))
quantize
ditentukan pada jenis tensor floating point dan mengubahnya menjadi jenis tensor terkuantisasi. Hal ini terjadi melalui konversi nilai floating point dari jenis yang dinyatakan menjadi nilai bilangan bulat yang sesuai dari jenis penyimpanan menggunakan titik dan skala nol yang terkait dengan jenis elemen terkuantisasi.
def quantize(x: Value, result_type: Type) -> Value:
assert is_float(x) and is_quantized(result_type)
zero_points = compute_zero_points(result_type, TensorType(shape(x), storage_type(result_type)))
converted_zero_points = convert(zero_points, expressed_type(result_type))
converted_min = convert(storage_min(result_type), expressed_type(result_type))
converted_max = convert(storage_max(result_type), expressed_type(result_type))
x_scaled = x / compute_scales(result_type, type(x))
x_scaled_add_zp = x_scaled + converted_zero_points
x_clamped = clamp(converted_min, x_scaled_add_zp, converted_max)
x_rounded = round_nearest_even(x_clamped)
return convert(x_rounded, result_type)
dequantize_op_quantize
digunakan untuk menentukan komputasi {i>element<i}-{i>wise<i} di tensor terkuantisasi. Model ini mendekuantisasi, yaitu mengubah elemen terkuantisasi menjadi tipe yang dinyatakan, lalu melakukan operasi, dan kemudian mengkuantisasi, yaitu mengubah kembali ke jenis penyimpanan mereka. Saat ini, fungsi tersebut hanya bekerja untuk kuantisasi per-tensor. Kuantisasi per sumbu sedang dalam proses (#1574).
def dequantize_op_quantize(op, *inputs_and_output_type):
inputs = inputs_and_output_type[:-1]
output_type = inputs_and_output_type[-1]
float_inputs = map(dequantize, inputs)
float_result = op(*float_inputs)
return quantize(float_result, output_type)
def dequantize_batch_norm_grad_or_training_quantize(op, *inputs_and_output_types):
inputs = inputs_and_output_type[:-3]
float_inputs = map(dequantize, inputs)
float_results = op(*float_inputs)
return map(quantize, float_results, inputs_and_output_type[-3:])
def dequantize_compare(lhs, rhs, comparison_direction):
float_lhs = dequantize(lhs)
float_rhs = dequantize(rhs)
return compare(float_lhs, float_rhs, comparison_direction, FLOAT)
def dequantize_select_quantize(pred, on_true, on_false, output_type):
float_on_true = dequantize(on_true)
float_on_false = dequantize(on_false)
float_result = select(pred, float_on_true, float_on_false)
return quantize(float_result, output_type)
hybrid_dequantize_then_op
digunakan untuk menentukan kuantisasi hanya bobot untuk operasi hybrid yang menerima lhs dalam floating point dan rhs dalam jenis terkuantisasi. Ini mendekuantisasi input terkuantisasi ke dalam jenis yang dinyatakan dan melakukan komputasi dalam {i>float<i}. Jenis elemen tensor float lhs dan jenis rhs terkuantisasi yang dinyatakan Tensor harus identik.
def hybrid_dequantize_then_op(op, lhs, rhs):
assert(is_float(lhs) and is_quantized(rhs) and element_type(lhs) == expressed_type(rhs))
return op(lhs, dequantize(rhs))
Komputasi grid
cross_partition(replica_groups: Value) -> Value
. Lihat "cross_replica" di atas.cross_replica(replica_groups: Value) -> Value
. Lihat "cross_replica" di atas.cross_replica_and_partition(replica_groups: Value) -> Value
. Lihat "cross_replica_and_partition" di atas.flattened_ids(replica_groups: Value) -> Value
. Lihat "flattened_id" di atas.
Dinamisme
Nilai StabilHLO dapat memiliki ukuran dimensi dinamis, mis. tensor<?xi64>
.
Namun, nilai StableHLO tidak boleh memiliki jumlah dimensi dinamis (tidak memiliki peringkat
dinamis, misalnya tensor<*xi64>
). Operand dan hasil diizinkan untuk menggunakan
ukuran dimensi, meskipun ada batasan ukurannya. Batasan akan
diverifikasi secara statis jika memungkinkan, jika tidak,
mereka ditangguhkan untuk runtime dan
ketidakcocokan data akan menyebabkan perilaku yang tidak terdefinisi. Lihat contoh berikut.
Ketidakcocokan bentuk untuk operasi elementwise unary
Pertimbangkan program mainan berikut:
func.func @foo(%arg0: tensor<?xf64>) {
%0 = stablehlo.abs %arg0 : (tensor<?xf64>) -> tensor<2xf64>
return
}
Program semacam itu tidak biasa, karena tidak umum mengetahui bentuk
hasil, tetapi bukan bentuk inputnya. Meskipun demikian, ini adalah StableHLO yang valid
program ini. Operasi abs
tidak dapat divalidasi secara statis dalam
, karena bentuk operand yang tepat tidak diketahui. Namun, bentuk
memang kompatibel, dan ini dapat diperiksa secara statis: ?
dapat berubah
menjadi 2
saat runtime, dan tidak akan ada masalah. Namun, ?
dapat
juga berubah menjadi beberapa integer lain, dalam hal ini perilakunya tidak terdefinisi.
Perhatikan bahwa jika ukuran dimensi dinamis dalam hasil, tidak boleh ada perilaku yang tidak terdefinisi. Sebenarnya, tidak ada kata "yang diharapkan" sehingga tidak bisa ada tidak cocok.
Ketidakcocokan bentuk untuk operasi elementwise biner
Pertimbangkan program mainan berikut:
func.func @foo(%arg0: tensor<?xf64>, %arg1: tensor<?xf64>) {
%0 = stablehlo.add %arg0, %arg0 : (tensor<?xf64>, tensor<?xf64>) -> tensor<?xf64>
return
}
Dalam hal operasi elementwise biner, bentuk input dan hasil harus disetujui pada runtime. Pada waktu kompilasi, dimensi statis harus sama, jika tidak, mereka hanya harus kompatibel. Jika salah satu dimensi dinamis dalam input, mungkin ada dimensi yang tidak ditentukan perilaku saat runtime, karena ukuran dinamis mungkin tidak cocok dengan di operand lain (baik itu statis maupun dinamis). Jika semua input statis, maka apakah hasilnya dinamis atau tidak, akan menjadi masalah: secara statis dimensi yang diketahui akan diperiksa secara statis, dan dimensi dinamis tidak menerapkan batasan apa pun.
Ketidakcocokan bentuk untuk operasi yang mengambil bentuk output-nya sebagai operand
Pertimbangkan program mainan berikut:
func.func @foo(%arg0: tensor<2xi32>) {
%0 = stablehlo.dynamic_iota %arg0, dim = 0 : (tensor<2xi32>) -> tensor<3x4xi64>
return
}
Nilai dalam operand bentuk saat runtime harus cocok dengan bentuk hasil,
jika tidak, perilaku
tidak terdefinisi. Artinya, pada runtime %arg0
harus memiliki
senilai dense<[3, 4]> : tensor<2xi32>
. Jika operand bentuk konstan, ini
dapat diverifikasi secara statis. Jika bentuk hasil sepenuhnya dinamis, maka ada
tidak boleh berupa ketidakcocokan.