Spesifikasi StableHLO

StableHLO adalah operasi yang ditetapkan untuk operasi tingkat tinggi (HLO) dalam model machine learning (ML). StableHLO berfungsi sebagai lapisan portabilitas antara framework ML dan compiler ML yang berbeda: Framework ML yang menghasilkan program StableHLO kompatibel dengan compiler ML yang menggunakan program StableHLO.

Tujuan kami adalah menyederhanakan dan mempercepat pengembangan ML dengan membuat lebih banyak interoperabilitas antara berbagai framework ML (seperti TensorFlow, JAX, dan PyTorch) dan compiler ML (seperti XLA dan IREE). Menjelang tujuan tersebut, dokumen ini memberikan spesifikasi untuk bahasa pemrograman StableHLO.

Spesifikasi ini berisi tiga bagian utama. Pertama, bagian Programs menjelaskan struktur program StableHLO yang terdiri dari fungsi StableHLO yang terdiri dari operasi StableHLO. Dalam struktur tersebut, bagian Ops menentukan semantik operasi individual. Bagian Execution menyediakan semantik untuk semua operasi yang dijalankan bersama dalam suatu program. Terakhir, bagian Notation membahas notasi yang digunakan di seluruh spesifikasi.

Program

Program ::= {Func}

Program StableHLO terdiri dari sejumlah fungsi StableHLO. Berikut adalah contoh program dengan fungsi @main yang memiliki 3 input (%image, %weights, dan %bias) serta 1 output. Isi fungsi memiliki 6 operasi.

func.func @main(
  %image: tensor<28x28xf32>,
  %weights: tensor<784x10xf32>,
  %bias: tensor<1x10xf32>
) -> tensor<1x10xf32> {
  %0 = "stablehlo.reshape"(%image) : (tensor<28x28xf32>) -> tensor<1x784xf32>
  %1 = "stablehlo.dot"(%0, %weights) : (tensor<1x784xf32>, tensor<784x10xf32>) -> tensor<1x10xf32>
  %2 = "stablehlo.add"(%1, %bias) : (tensor<1x10xf32>, tensor<1x10xf32>) -> tensor<1x10xf32>
  %3 = "stablehlo.constant"() { value = dense<0.0> : tensor<1x10xf32> } : () -> tensor<1x10xf32>
  %4 = "stablehlo.maximum"(%2, %3) : (tensor<1x10xf32>, tensor<1x10xf32>) -> tensor<1x10xf32>
  "func.return"(%4): (tensor<1x10xf32>) -> ()
}

Fungsi

Func        ::= 'func' '.' 'func' FuncId FuncInputs FuncOutputs '{' FuncBody '}'
FuncInputs  ::= '(' [FuncInput {',' FuncInput}] `)`
FuncInput   ::= '%' ValueId ':' ValueType
FuncOutputs ::= ['->' FuncOutput, {',' FuncOutput}]
FuncOutput  ::= ValueType
FuncBody    ::= {Op}

Fungsi StableHLO (yang juga disebut fungsi bernama) memiliki ID, input/output, dan isi. Di masa mendatang, kami berencana untuk memperkenalkan metadata tambahan untuk fungsi guna mencapai kompatibilitas yang lebih baik dengan HLO (#425, #626, #740, #744).

ID

FuncId  ::= '@' letter {letter | digit}
ValueId ::= '%' digit {digit}
          | '%' letter {letter | digit}
letter  ::= 'a' | ... | 'z' | 'A' | ... | 'Z' | '_'
digit   ::= '0' | ... | '9'

ID StableHLO mirip dengan ID dalam banyak bahasa pemrograman, dengan dua kekhasan: 1) semua ID memiliki karakteristik yang membedakan berbagai jenis ID, 2) ID nilai dapat berupa angka sepenuhnya untuk menyederhanakan pembuatan program StableHLO.

Jenis

Type         ::= ValueType | NonValueType
ValueType    ::= TensorType | QuantizedTensorType | TokenType | TupleType
NonValueType ::= TensorElementType | QuantizedTensorElementType | FunctionType | StringType

Jenis StableHLO dikategorikan ke dalam jenis nilai (yang juga disebut sebagai jenis kelas satu) yang mewakili nilai StableHLO dan jenis non-nilai yang menjelaskan elemen program lainnya. Jenis StableHLO mirip dengan jenis dalam banyak bahasa pemrograman, dengan keunikan utamanya adalah StableHLO khusus domain yang menghasilkan beberapa hasil yang tidak biasa (misalnya, jenis skalar bukan jenis nilai).

TensorType ::= 'tensor' '<' Shape TensorElementType '>'
Shape ::= {DimensionSize 'x'}
DimensionSize ::= digit {digit}

Jenis Tensor mewakili TensorFlow, yaitu array multidimensi. Elemen ini memiliki bentuk dan jenis elemen, dengan bentuk yang mewakili ukuran dimensi non-negatif dalam urutan menaik dari dimensi yang terkait (yang juga disebut sumbu) yang diberi nomor dari 0 hingga R-1. Jumlah dimensi R disebut peringkat. Misalnya, tensor<2x3xf32> adalah jenis tensor dengan bentuk 2x3 dan jenis elemen f32. Diagram ini memiliki dua dimensi (atau, dengan kata lain, dua sumbu) - dimensi ke-0 dan dimensi pertama - yang ukurannya adalah 2 dan 3. Peringkatnya adalah 2.

Ini menentukan dukungan untuk bentuk statis yang ukuran dimensi diketahui secara statis. Di masa mendatang, kami juga berencana memperkenalkan dukungan untuk bentuk dinamis dengan ukuran dimensi yang sebagian atau seluruh tidak diketahui (#8). Selain itu, kami akan mempelajari perluasan jenis TensorFlow di luar ukuran dimensi dan jenis elemen, misalnya, untuk menyertakan tata letak (#629) dan ketersebaran (#1078).

QuantizedTensorType ::= 'tensor' '<' Shape QuantizedTensorElementType '>'
QuantizedTensorElementType ::= '!quant.uniform' '<'
                  QuantizationStorageType
                  ['<' QuantizationStorageMin ':' QuantizationStorageMax '>']
                  ':' QuantizationExpressedType
                  [':' QuantizationDimension]
                  ',' QuantizationParameters '>'
QuantizationStorageType ::= IntegerType
QuantizationStorageMin ::= IntegerConstant
QuantizationStorageMax ::= IntegerConstant
QuantizationExpressedType ::= FloatType
QuantizationDimension ::= IntegerConstant
QuantizationParameters ::= QuantizationParameter
                         | '{' QuantizationParameter {',' QuantizationParameter} '}'
QuantizationParameter ::= QuantizationScale ':' QuantizationZeroPoint
QuantizationScale ::= FloatConstant
QuantizationZeroPoint ::= IntegerConstant
Nama Jenis Batasan
storage_type jenis bilangan bulat (C1-C4), (C9)
storage_min konstanta bilangan bulat (C2), (C4), (C8)
storage_max konstanta bilangan bulat (C3), (C4), (C8)
expressed_type jenis floating point (C1), (C5)
quantization_dimension konstanta bilangan bulat opsional (C11-C13)
scales bilangan variadic dari konstanta floating point (C5-C7), (C10), (C11), (C13)
zero_points angka variadic konstanta integer (C8-C10)

Jenis elemen terkuantisasi mewakili nilai bilangan bulat dari jenis penyimpanan dalam rentang dari storage_min hingga storage_max (inklusif) yang sesuai dengan nilai floating point dari jenis yang dinyatakan. Untuk nilai bilangan bulat tertentu i, nilai floating point yang sesuai, f, dapat dihitung sebagai f = (i - zero_point) * scale, dengan scale dan zero_point disebut parameter kuantisasi. storage_min dan storage_max bersifat opsional dalam tata bahasa, tetapi memiliki nilai default min_value(storage_type) dan max_value(storage_type) masing-masing. Jenis elemen terkuantisasi memiliki batasan berikut:

  • (C1) num_bits(storage_type) < num_bits(expressed_type).
  • (C2) type(storage_min) = storage_type.
  • (C3) type(storage_max) = storage_type.
  • (C4) min_value(storage_type) <= storage_min < storage_max <= max_value(storage_type).
  • (C5) type(scales...) = expressed_type.
  • (C6) 0 < scales.
  • (C7) is_finite(scales...).
  • (C8) storage_min <= zero_points <= storage_max.
  • (C9) type(zero_points...) = storage_type.
  • (C10) size(scales) = size(zero_points).
  • (C11) Jika is_empty(quantization_dimension), maka size(scales) = 1.
  • (C12) 0 <= quantization_dimension.

Saat ini, QuantizationScale adalah konstanta floating point, tetapi ada minat yang kuat pada skala berbasis bilangan bulat, yang ditampilkan dengan pengganda dan pergeseran. Kami berencana mempelajari hal ini dalam waktu dekat (#1404).

Terdapat diskusi yang sedang berlangsung tentang semantik QuantizationZeroPoint, termasuk jenis, nilai, dan apakah hanya boleh ada satu atau kemungkinan beberapa titik nol dalam jenis TensorFlow terkuantisasi. Berdasarkan hasil diskusi ini, spesifikasi di sekitar titik nol dapat berubah di masa mendatang (#1405).

Diskusi berkelanjutan lainnya melibatkan semantik QuantizationStorageMin dan QuantizationStorageMax untuk menentukan apakah ada batasan yang harus diberlakukan pada nilai ini dan pada nilai Tensor terkuantisasi (#1406).

Terakhir, kita berencana mempelajari cara merepresentasikan skala dan titik nol yang tidak diketahui, mirip dengan cara kita akan mempelajari cara mewakili ukuran dimensi yang tidak diketahui (#1407).

Jenis Quantized Tensor merepresentasikan TensorFlow dengan elemen terkuantisasi. Tensor ini sama persis dengan TensorFlow biasa, hanya saja elemennya memiliki jenis elemen terkuantisasi, bukan jenis elemen reguler.

Pada TensorFlow terkuantisasi, kuantisasi dapat berupa per-tensor, yang berarti memiliki satu scale dan zero_point untuk seluruh TensorFlow atau dapat berupa per-sumbu, yang berarti memiliki beberapa scales dan zero_points, satu pasangan per irisan dari dimensi tertentu quantization_dimension. Secara lebih formal, dalam Tensor t dengan kuantisasi per sumbu, terdapat irisan dim(t, quantization_dimension) dari quantization_dimension: t[:, ..., 0, ..., :], t[:, ..., 1, ..., :], dll. Semua elemen dalam irisan i menggunakan scales[i] dan zero_points[i] sebagai parameter kuantisasinya. Jenis Quantized Tensor memiliki batasan berikut:

  • Untuk kuantisasi per-tensor:
    • Tidak ada batasan tambahan.
  • Untuk kuantisasi per sumbu:
    • (C12) quantization_dimension < rank(self).
    • (C13) dim(self, quantization_dimension) = size(scales)
TokenType ::= 'token'

Jenis token mewakili token, yaitu nilai buram yang dihasilkan dan dipakai oleh beberapa operasi. Token digunakan untuk menerapkan urutan eksekusi pada operasi seperti yang dijelaskan di bagian Eksekusi.

TupleType ::= 'tuple' '<' TupleElementTypes '>'
TupleElementTypes ::= [ValueType {',' ValueType}]

Jenis tupel mewakili tupel, yaitu daftar heterogen. Tuple adalah fitur lama yang hanya ada untuk kompatibilitas dengan HLO. Di HLO, tuple digunakan untuk merepresentasikan input dan output variadic. Di StableHLO, input dan output variadik didukung secara native, dan satu-satunya penggunaan tuple di StableHLO adalah untuk secara komprehensif merepresentasikan HLO ABI, di mana misalnya, T, tuple<T>, dan tuple<tuple<T>> mungkin sangat berbeda, bergantung pada implementasi tertentu. Di masa mendatang, kami berencana membuat perubahan pada ABI HLO yang dapat memungkinkan kami menghapus jenis tuple dari StableHLO (#598).

TensorElementType ::= BooleanType | IntegerType | FloatType | ComplexType
BooleanType ::= 'i1'
IntegerType ::= SignedIntegerType | UnsignedIntegerType
SignedIntegerType ::= 'si4' | 'si8' | 'si16' | 'si32' | 'si64'
UnsignedIntegerType ::= 'ui4' | 'ui8' | 'ui16' | 'ui32' | 'ui64'
FloatType ::= 'f8E4M3FN' | 'f8E5M2' | 'f8E4M3FNUZ' | 'f8E5M2FNUZ'
            | 'f8E4M3B11FNUZ' | 'bf16' | 'f16' | 'f32' | 'f64'
ComplexType ::= 'complex' '<' ComplexElementType '>'
ComplexElementType ::= 'f32' | 'f64'

Jenis elemen mewakili elemen jenis TensorFlow. Tidak seperti banyak bahasa pemrograman, jenis ini bukan class pertama di StableHLO. Artinya, program StableHLO tidak dapat secara langsung merepresentasikan nilai jenis ini (sehingga idiomatis untuk merepresentasikan nilai skalar jenis T dengan nilai tensor 0 dimensi jenis tensor<T>).

  • Jenis Boolean mewakili nilai boolean true dan false.
  • Jenis integer dapat ditandatangani (si) atau tidak ditandatangani (ui) dan memiliki salah satu lebar bit yang didukung (4, 8, 16, 32 atau 64). Jenis siN bertanda tangan mewakili nilai integer dari -2^(N-1) hingga 2^(N-1)-1 inklusif, dan uiN yang tidak ditandatangani mewakili nilai integer dari 0 sampai 2^N-1 inklusif.
  • Jenis floating point dapat berupa salah satu dari hal berikut:
  • Jenis kompleks mewakili nilai kompleks yang memiliki bagian nyata dan bagian imajiner dari jenis elemen yang sama. Jenis kompleks yang didukung adalah complex<f32> (kedua bagian berjenis f32) dan complex<f64> (keduanya berjenis f64).
FunctionType ::= '(' InputTypes ')' '->' '(' OutputTypes ')'
InputTypes ::= [ValueType {',' ValueType}]
OutputTypes ::= [ValueType {',' ValueType}]

Jenis fungsi mewakili fungsi bernama dan anonim. Keduanya memiliki jenis input (daftar jenis di sisi kiri ->) dan jenis output (daftar jenis di sisi kanan ->). Dalam banyak bahasa pemrograman, jenis fungsi adalah class pertama, tetapi tidak di StableHLO.

StringType ::= 'string'

Jenis string mewakili urutan byte. Tidak seperti banyak bahasa pemrograman, jenis string bukanlah class pertama di StableHLO dan hanya digunakan untuk menentukan metadata statis elemen program.

Operasi

Operasi StableHLO (yang juga disebut operasi) mewakili sekumpulan operasi tingkat tinggi tertutup dalam model machine learning. Seperti dibahas di atas, sintaksis StableHLO sangat terinspirasi oleh MLIR, yang belum tentu merupakan alternatif yang paling ergonomis, tetapi dianggap paling cocok untuk tujuan StableHLO dalam menciptakan lebih banyak interoperabilitas antara framework ML dan compiler ML.

Op            ::= [OpOutputs] OpName OpInputs ':' OpSignature
OpName        ::= '"' 'stablehlo' '.' OpMnemonic '"'
OpMnemonic    ::= 'abs' | 'add' | ...

Operasi StableHLO (yang juga disebut operasi) memiliki nama, input/output, dan tanda tangan. Nama ini terdiri dari awalan stablehlo. dan mnemonik yang secara unik mengidentifikasi salah satu operasi yang didukung. Lihat di bawah untuk daftar lengkap semua operasi yang didukung.

Saat ini, program StableHLO di alam liar terkadang berisi operasi yang tidak dijelaskan dalam dokumen ini. Di masa mendatang, kami berencana untuk menyerap operasi ini ke dalam opset StableHLO atau melarangnya muncul dalam program StableHLO. Sementara itu, berikut adalah daftar operasi tersebut:

  • builtin.module, func.func, func.call, dan func.return (#425).
  • Operasi chlo (#602).
  • Kategori "Not in HLO" dari operasi StableHLO - awalnya merupakan bagian dari opset StableHLO, tetapi kemudian dianggap tidak cocok: broadcast, create_token, cross-replica-sum, dot, einsum, torch_index_select, unary_einsum (#3).
  • Kategori "Dynamism" dari operasi StableHLO - operasi tersebut di-bootstrap dari MHLO, tetapi kami belum menentukannya: compute_reshape_shape, cstr_reshapable, dynamic_broadcast_in_dim, dynamic_conv, dynamic_gather, dynamic_iota, dynamic_pad, dynamic_reshape, real_dynamic_slice, set_dimension_size (#8).
  • Komputasi bentuk, termasuk operasi arith, shape, dan tensor (#8).
OpInputs        ::= OpInputValues OpInputFuncs OpInputAttrs
OpInputValues   ::= '(' [OpInputValue {',' OpInputValue}] ')'
OpInputValue    ::= ValueId
OpInputFuncs    ::= ['(' OpInputFunc {',' OpInputFunc} ')']
OpInputAttrs    ::= ['{' OpInputAttr {',' OpInputAttr} '}']
OpOutputs       ::= [OpOutput {',' OpOutput} '=']
OpOutput        ::= ValueId

Operasi menggunakan input dan menghasilkan output. Input dikategorikan ke dalam nilai input (dihitung selama eksekusi), fungsi input (disediakan secara statis, karena dalam fungsi StableHLO bukan nilai kelas satu), dan atribut input (juga disediakan secara statis). Jenis input dan output yang dipakai dan dihasilkan oleh operasi bergantung pada mnemoniknya. Misalnya, operasi add menggunakan 2 nilai input dan menghasilkan 1 nilai output. Sebagai perbandingan, operasi select_and_scatter menggunakan 3 nilai input, 2 fungsi input, dan 3 atribut input.

OpInputFunc ::= '{' Unused FuncInputs ':' FuncBody '}'
Unused      ::= '^' digit {digit}
              | '^' letter {letter | digit}

Fungsi input (yang juga disebut fungsi anonim) sangat mirip dengan fungsi bernama, kecuali: 1) tidak memiliki ID (sehingga disebut "anonim"), 2) fungsi input tidak mendeklarasikan jenis output (jenis output disimpulkan dari op return dalam fungsi).

Sintaksis untuk fungsi input menyertakan bagian yang saat ini tidak digunakan (lihat produksi Unused di atas) yang tersedia untuk kompatibilitas dengan MLIR. Di MLIR, ada konsep "region" yang lebih umum yang dapat memiliki beberapa "blok" operasi yang terhubung bersama melalui operasi jump. Blok ini memiliki ID yang sesuai dengan produksi Unused, sehingga dapat dibedakan satu sama lain. StableHLO tidak memiliki operasi jump, sehingga bagian sintaksis MLIR yang sesuai tidak digunakan (tetapi masih ada).

OpInputAttr      ::= OpInputAttrName '=' OpInputAttrValue
OpInputAttrName  ::= letter {letter | digit}
OpInputAttrValue ::= Constant

Atribut input memiliki nama dan nilai yang merupakan salah satu konstanta yang didukung. Library ini adalah cara utama untuk menentukan metadata statis elemen program. Misalnya, operasi concatenate menggunakan atribut dimension untuk menentukan dimensi yang digunakan untuk menggabungkan nilai inputnya. Demikian pula, operasi slice menggunakan beberapa atribut seperti start_indices dan limit_indices untuk menentukan batas yang digunakan untuk memotong nilai input.

Saat ini, program StableHLO di alam liar terkadang berisi atribut yang tidak dijelaskan dalam dokumen ini. Di masa mendatang, kami berencana untuk menyerap atribut ini ke dalam opset StableHLO atau melarangnya muncul dalam program StableHLO. Sementara itu, berikut adalah daftar atribut tersebut:

  • layout (#629).
  • mhlo.frontend_attributes (#628).
  • mhlo.sharding (#619).
  • output_operand_aliases (#740).
  • Metadata lokasi (#594).
OpSignature ::= '(' [ValueType {',' ValueType}] ')' '->' '(' [ValueType {',' ValueType}] ')'

Signature Op terdiri dari jenis semua nilai input (daftar jenis di sisi kiri ->) dan jenis semua nilai output (daftar jenis di sisi kanan ->). Sebenarnya, jenis input bersifat redundan, dan jenis output juga hampir selalu redundan (karena untuk sebagian besar operasi StableHLO, jenis output dapat disimpulkan dari input). Meskipun demikian, tanda tangan operasi sengaja menjadi bagian dari sintaksis StableHLO untuk kompatibilitas dengan MLIR.

Berikut adalah contoh operasi yang mnemoniknya adalah select_and_scatter. Fungsi ini menggunakan 3 nilai input (%operand, %source, dan %init_value), 2 fungsi input, dan 3 atribut input (window_dimensions, window_strides, dan padding). Perhatikan bahwa tanda tangan operasi hanya menyertakan jenis nilai inputnya (tetapi bukan jenis fungsi dan atribut input yang disediakan secara inline).

%result = "stablehlo.select_and_scatter"(%operand, %source, %init_value) ({
  ^bb0(%arg0: tensor<i32>, %arg1: tensor<i32>):
    %0 = "stablehlo.compare"(%arg0, %arg1) {
      comparison_direction = #stablehlo<comparison_direction GE>
    } : (tensor<i32>, tensor<i32>) -> tensor<i1>
    "stablehlo.return"(%0) : (tensor<i1>) -> ()
}, {
  ^bb0(%arg0: tensor<i32>, %arg1: tensor<i32>):
    %0 = "stablehlo.add"(%arg0, %arg1) : (tensor<i32>, tensor<i32>) -> tensor<i32>
    "stablehlo.return"(%0) : (tensor<i32>) -> ()
}) {
  window_dimensions = dense<[3, 1]> : tensor<2xi64>,
  window_strides = dense<[2, 1]> : tensor<2xi64>,
  padding = dense<[[0, 1], [0, 0]]> : tensor<2x2xi64>
} : (tensor<4x2xi32>, tensor<2x2xi32>, tensor<i32>) -> tensor<4x2xi32>

Konstanta

Constant ::= BooleanConstant
           | IntegerConstant
           | FloatConstant
           | ComplexConstant
           | TensorConstant
           | QuantizedTensorConstant
           | StringConstant
           | EnumConstant

Konstanta StableHLO memiliki literal dan jenis yang bersama-sama mewakili nilai StableHLO. Umumnya, jenis tersebut adalah bagian dari sintaksis konstanta, kecuali jika bersifat tidak ambigu (mis., konstanta boolean secara jelas memiliki jenis i1, sedangkan konstanta integer dapat memiliki beberapa jenis yang memungkinkan).

BooleanConstant ::= BooleanLiteral
BooleanLiteral  ::= 'true' | 'false'

Konstanta Boolean mewakili nilai boolean true dan false. Konstanta Boolean memiliki jenis i1.

IntegerConstant   ::= IntegerLiteral ':' IntegerType
IntegerLiteral    ::= ['-' | '+'] DecimalDigits
                    | ['-' | '+'] '0x' HexadecimalDigits
DecimalDigits     ::= decimalDigit {decimalDigit}
HexadecimalDigits ::= hexadecimalDigit {hexadecimalDigit}
decimalDigit      ::= '0' | ... | '9'
hexadecimalDigit  ::= decimalDigit | 'a' | ... | 'f' | 'A' | ... | 'F'

Konstanta bilangan bulat merepresentasikan nilai bilangan bulat melalui string yang menggunakan notasi desimal atau heksadesimal. Basis lain, misalnya biner atau oktal, tidak didukung. Konstanta bilangan bulat memiliki batasan berikut:

  • (C1) is_wellformed(integer_literal, integer_type).
FloatConstant  ::= FloatLiteral ':' FloatType
FloatLiteral   ::= SignPart IntegerPart FractionalPart ScientificPart
                 | '0x' [HexadecimalDigits]
SignPart       ::= ['-' | '+']
IntegerPart    ::= DecimalDigits
FractionalPart ::= ['.' [DecimalDigits]]
ScientificPart ::= [('e' | 'E') ['-' | '+'] DecimalDigits]

Konstanta floating point mewakili nilai floating point melalui string yang menggunakan notasi ilmiah atau desimal. Selain itu, notasi heksadesimal dapat digunakan untuk menentukan secara langsung bit yang mendasarinya dalam format floating point dari jenis yang sesuai. Konstanta floating point memiliki batasan berikut:

  • (C1) Jika notasi non-heksadesimal digunakan, is_wellformed(float_literal, float_type).
  • (C2) Jika notasi heksadesimal digunakan, size(hexadecimal_digits) = num_bits(float_type) / 4.
ComplexConstant ::= ComplexLiteral ':' ComplexType
ComplexLiteral  ::= '(' RealPart ',' ImaginaryPart ')'
RealPart        ::= FloatLiteral
ImaginaryPart   ::= FloatLiteral

Konstanta kompleks merepresentasikan nilai kompleks menggunakan daftar bagian nyata (yang pertama) dan bagian imajiner (mendatang kedua). Misalnya, (1.0, 0.0) : complex<f32> mewakili 1.0 + 0.0i, dan (0.0, 1.0) : complex<f32> mewakili 0.0 + 1.0i. Urutan bagian-bagian ini kemudian disimpan dalam memori ditentukan oleh implementasi. Konstanta kompleks memiliki batasan berikut:

  • (C1) is_wellformed(real_part, complex_element_type(complex_type)).
  • (C2) is_wellformed(imaginary_part, complex_element_type(complex_type)).
TensorConstant ::= TensorLiteral ':' TensorType
TensorLiteral  ::= 'dense' '<' (DenseLiteral | ElementLiteral) '>'
DenseLiteral   ::= DenseDimension | DenseElements
DenseDimension ::= '[' [DenseLiteral {',' DenseLiteral}] ']'
DenseElements  ::= [ElementLiteral {',' ElementLiteral}]
ElementLiteral ::= BooleanLiteral | IntegerLiteral | FloatLiteral | ComplexLiteral

Konstanta Tensor merepresentasikan nilai TensorFlow menggunakan daftar bertingkat yang ditentukan melalui notasi NumPy. Misalnya, dense<[[1, 2, 3], [4, 5, 6]]> : tensor<2x3xi32> mewakili nilai TensorFlow dengan pemetaan berikut dari indeks ke elemen: {0, 0} => 1, {0, 1} => 2, {0, 2} => 3, {1, 0} => 4, {1, 1} => 5, {1, 2} => 6. Urutan elemen-elemen ini kemudian disimpan dalam memori ditentukan oleh implementasi. Konstanta Tensor memiliki batasan berikut:

  • (C1) has_syntax(tensor_literal, element_type(tensor_type)), dengan:
    • has_syntax(element_literal: Syntax, element_type: Type) = is_wellformed(element_literal, type).
    • has_syntax(tensor_literal: List, element_type: Type) = has_syntax(tensor_literal..., element_type).
  • (C2) has_shape(tensor_literal, shape(tensor_type)), dengan:
    • has_shape(element_literal: Syntax, []) = true.
    • has_shape(tensor_literal: List, shape: List) = size(tensor_literal) = shape[0] and has_shape(tensor_literal..., shape[1:]).
    • jika tidak, false.
QuantizedTensorConstant ::= QuantizedTensorLiteral ':' QuantizedTensorType
QuantizedTensorLiteral  ::= 'dense' '<' (DenseLiteral | ElementLiteral) '>'

Konstanta Tensor terkuantisasi merepresentasikan nilai tensor terkuantisasi menggunakan notasi yang sama dengan konstanta Tensor, dengan elemen yang ditetapkan sebagai konstanta jenis penyimpanannya. Konstanta TensorFlow terkuantisasi memiliki batasan berikut:

  • (C1) has_syntax(quantized_tensor_literal, storage_type(quantized_tensor_type)).
  • (C2) has_shape(quantized_tensor_literal, shape(quantized_tensor_type)).
StringConstant  ::= StringLiteral
StringLiteral   ::= '"' {stringCharacter | escapeSequence} '"'
stringCharacter ::= all ASCII characters except '\00', '\01', ... '\1f' and '"'
escapeSequence  ::= '\' ('"' | '\' | 'n' | 't' | (hexadecimalDigit hexadecimalDigit))

Literal string terdiri dari byte yang ditentukan menggunakan karakter ASCII dan urutan escape. Keduanya tidak bergantung pada encoding, sehingga interpretasi byte ini ditentukan oleh implementasinya. Literal string memiliki jenis string.

Operasi

abs

Semantik

Menjalankan operasi abs yang berbasis elemen pada TensorFlow operand dan menghasilkan tensor result. Bergantung pada jenis elemen, lakukan hal berikut:

  • Untuk bilangan bulat bertanda tangan: modulus bilangan bulat.
  • Untuk float: abs dari IEEE-754.
  • Untuk bilangan kompleks: modulus kompleks.
  • Untuk jenis terkuantisasi: dequantize_op_quantize(abs, operand, type(result)).

Input

Label Nama Jenis Batasan
(I1) operand Tensor integer bertanda tangan, floating point, atau tipe kompleks atau quantized per-tensor (C1-C2)

Output

Nama Jenis Batasan
result Tensor integer bertanda tangan atau tipe floating point atau quantized per-tensor (C1-C2)

Batasan

  • (C1) shape(result) = shape(operand).
  • (C2) baseline_element_type(result) ditentukan sebagai:
    • complex_element_type(element_type(operand)) jika is_complex(operand).
    • baseline_element_type(operand) jika tidak.

Contoh

// %operand: [-2, 0, 2]
%result = "stablehlo.abs"(%operand) : (tensor<3xi32>) -> tensor<3xi32>
// %result: [2, 0, 2]

Contoh Lainnya

add

Semantik

Melakukan penambahan dua TensorFlow untuk elemen lhs dan rhs serta menghasilkan tensor result. Bergantung pada jenis elemen, lakukan hal berikut:

  • Untuk boolean: logika OR.
  • Untuk bilangan bulat: penjumlahan bilangan bulat.
  • Untuk float: addition dari IEEE-754.
  • Untuk bilangan kompleks: penjumlahan kompleks.
  • Untuk jenis terkuantisasi: dequantize_op_quantize(add, lhs, rhs, type(result)).

Input

Label Nama Jenis Batasan
(I1) lhs Quantized Tensor atau per-tensor (C1)
(I2) rhs Quantized Tensor atau per-tensor (C1)

Output

Nama Jenis Batasan
result Quantized Tensor atau per-tensor (C1)

Batasan

  • (C1) baseline_type(lhs) = baseline_type(rhs) = baseline_type(result).

Contoh

// %lhs: [[1, 2], [3, 4]]
// %rhs: [[5, 6], [7, 8]]
%result = "stablehlo.add"(%lhs, %rhs) : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
// %result: [[6, 8], [10, 12]]

Contoh Lainnya

after_all

Semantik

Memastikan bahwa operasi yang menghasilkan inputs dijalankan sebelum operasi apa pun yang bergantung pada result. Eksekusi operasi ini tidak melakukan apa pun, hanya ada untuk membangun dependensi data dari result ke inputs.

Input

Label Nama Jenis
(I1) inputs angka variadik token

Output

Nama Jenis
result token

Contoh

// %input0: !stablehlo.token
// %input1: !stablehlo.token
%result = "stablehlo.after_all"(%input0, %input1) : (!stablehlo.token, !stablehlo.token) -> !stablehlo.token

Contoh Lainnya

all_gather

Semantik

Dalam setiap grup proses di petak proses StableHLO, gabungkan nilai tensor operand dari setiap proses di sepanjang all_gather_dim dan menghasilkan tensor result.

Operasi ini membagi petak proses StableHLO menjadi process_groups yang ditentukan sebagai berikut:

  • cross_replica(replica_groups) jika channel_id <= 0 and use_global_device_ids = false.
  • cross_replica_and_partition(replica_groups) jika channel_id > 0 and use_global_device_ids = false.
  • flattened_ids(replica_groups) jika channel_id > 0 and use_global_device_ids = true.

Setelah itu, dalam setiap process_group:

  • operands@receiver = [operand@sender for sender in process_group] untuk semua receiver di process_group.
  • result@process = concatenate(operands@process, all_gather_dim) untuk semua process di process_group.

Input

Label Nama Jenis Batasan
(I1) operand Quantized Tensor atau per-tensor (C1), (C6)
(I2) all_gather_dim konstanta jenis si64 (C1), (C6)
(I3) replica_groups Konstanta TensorFlow 2 dimensi dari jenis si64 (C2-C4)
(I4) channel_id konstanta jenis si64 (C5)
(I5) use_global_device_ids konstanta jenis i1 (C5)

Output

Nama Jenis Batasan
result Quantized Tensor atau per-tensor (C6)

Batasan

  • (C1) 0 <= all_gather_dim < rank(operand).
  • (C2) is_unique(replica_groups).
  • (C3) size(replica_groups) didefinisikan sebagai:
    • num_replicas jika cross_replica digunakan.
    • num_replicas jika cross_replica_and_partition digunakan.
    • num_processes jika flattened_ids digunakan.
  • (C4) 0 <= replica_groups < size(replica_groups).
  • (C5) Jika use_global_device_ids = true, maka channel_id > 0.
  • (C6) type(result) = type(operand) kecuali:
    • dim(result, all_gather_dim) = dim(operand, all_gather_dim) * dim(process_groups, 1).

Contoh

// num_replicas: 2
// num_partitions: 1
// %operand@(0, 0): [[1, 2], [3, 4]]
// %operand@(1, 0): [[5, 6], [7, 8]]
%result = "stablehlo.all_gather"(%operand) {
  all_gather_dim = 1 : i64,
  replica_groups = dense<[[0, 1]]> : tensor<1x2xi64>,
  // channel_id = 0
  channel_handle = #stablehlo.channel_handle<handle = 0, type = 0>
  // use_global_device_ids = false
} : (tensor<2x2xi64>) -> tensor<2x4xi64>
// %result@(0, 0): [[1, 2, 5, 6], [3, 4, 7, 8]]
// %result@(1, 0): [[1, 2, 5, 6], [3, 4, 7, 8]]

Contoh Lainnya

all_reduce

Semantik

Dalam setiap grup proses di petak proses StableHLO, terapkan fungsi reduksi computation ke nilai TensorFlow operand dari setiap proses dan menghasilkan TensorFlow result.

Operasi ini membagi petak proses StableHLO menjadi process_groups yang ditentukan sebagai berikut:

  • cross_replica(replica_groups) jika channel_id <= 0 and use_global_device_ids = false.
  • cross_replica_and_partition(replica_groups) jika channel_id > 0 and use_global_device_ids = false.
  • flattened_ids(replica_groups) jika channel_id > 0 and use_global_device_ids = true.

Setelah itu, dalam setiap process_group:

  • result@process[result_index] = exec(schedule) untuk beberapa hierarki biner schedule dengan:
    • exec(node) = computation(exec(node.left), exec(node.right))
    • exec(leaf) = leaf.value
  • schedule adalah hierarki biner yang ditentukan implementasi yang memiliki traversal sesuai urutannya adalah to_destination_type(operands@process_group...[result_index], type(func_inputs(computation)[0])).

Input

Label Nama Jenis Batasan
(I1) operand Quantized Tensor atau per-tensor (C5), (C6)
(I2) replica_groups bilangan variadic konstanta TensorFlow 1 dimensi dari jenis si64 (C1-C3)
(I3) channel_id konstanta jenis si64 (C4)
(I4) use_global_device_ids konstanta jenis i1 (C4)
(I5) computation fungsi (C5)

Output

Nama Jenis Batasan
result Quantized Tensor atau per-tensor (C6-C7)

Batasan

  • (C1) is_unique(replica_groups).
  • (C2) size(replica_groups) ditentukan sebagai:
    • num_replicas jika cross_replica digunakan.
    • num_replicas jika cross_replica_and_partition digunakan.
    • num_processes jika flattened_ids digunakan.
  • (C3) 0 <= replica_groups < size(replica_groups).
  • (C4) Jika use_global_device_ids = true, maka channel_id > 0.
  • (C5) computation memiliki jenis (tensor<E>, tensor<E>) -> (tensor<E>) dengan is_promotable(element_type(operand), E).
  • (C6) shape(result) = shape(operand).
  • (C7) element_type(result) = E.

Contoh

// num_replicas: 2
// num_partitions: 1
// %operand@(0, 0): [1, 2, 3, 4]
// %operand@(1, 0): [5, 6, 7, 8]
%result = "stablehlo.all_reduce"(%operand) ({
  ^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
    %0 = "stablehlo.add"(%arg0, %arg1) : (tensor<i64>, tensor<i64>) -> tensor<i64>
    "stablehlo.return"(%0) : (tensor<i64>) -> ()
}) {
  replica_groups = dense<[[0, 1]]> : tensor<1x2xi64>,
  channel_handle = #stablehlo.channel_handle<handle = 0, type = 0>
} : (tensor<i64>) -> tensor<i64>
// %result@(0, 0): [6, 8, 10, 12]
// %result@(1, 0): [6, 8, 10, 12]

Contoh Lainnya

all_to_all

Semantik

Dalam setiap grup proses di petak proses StableHLO, membagi nilai tensor operand di sepanjang split_dimension menjadi beberapa bagian, menyebarkan bagian terpisah di antara proses, menggabungkan bagian yang tersebar di sepanjang concat_dimension, dan menghasilkan TensorFlow result.

Operasi ini membagi petak proses StableHLO menjadi process_groups yang ditentukan sebagai berikut:

  • cross_replica(replica_groups) jika channel_id <= 0.
  • cross_partition(replica_groups) jika channel_id > 0.

Setelah itu, dalam setiap process_group:

  • split_parts@sender = split(operand@sender, split_count, split_dimension) untuk semua sender di process_group.
  • scattered_parts@receiver = [split_parts@sender[receiver_index] for sender in process_group] dengan receiver_index = process_group.index(receiver).
  • result@process = concatenate(scattered_parts@process, concat_dimension).

Input

Label Nama Jenis Batasan
(I1) operand Quantized Tensor atau per-tensor (C1-C3), (C9)
(I2) split_dimension konstanta jenis si64 (C1), (C2), (C9)
(I3) concat_dimension konstanta jenis si64 (C3), (C9)
(I4) split_count konstanta jenis si64 (C2), (C4), (C8), (C9)
(I5) replica_groups Konstanta TensorFlow 2 dimensi dari jenis si64 (C5-C8)
(I6) channel_id konstanta jenis si64

Output

Nama Jenis Batasan
result Quantized Tensor atau per-tensor (C9)

Batasan

  • (C1) 0 <= split_dimension < rank(operand).
  • (C2) dim(operand, split_dimension) % split_count = 0.
  • (C3) 0 <= concat_dimension < rank(operand).
  • (C4) 0 < split_count.
  • (C5) is_unique(replica_groups).
  • (C6) size(replica_groups) didefinisikan sebagai:
    • num_replicas jika cross_replica digunakan.
    • num_partitions jika cross_partition digunakan.
  • (C7) 0 <= replica_groups < size(replica_groups).
  • (C8) dim(replica_groups, 1) = split_count.
  • (C9) type(result) = type(operand) kecuali:
    • dim(result, split_dimension) = dim(operand, split_dimension) / split_count.
    • dim(result, concat_dimension) = dim(operand, concat_dimension) * split_count.

Contoh

// num_replicas: 2
// num_partitions: 1
// %operand@(0, 0): [[1, 2, 3, 4],
//                   [5, 6, 7, 8]]
// %operand@(1, 0): [[9, 10, 11, 12],
//                   [13, 14, 15, 16]]
%result = "stablehlo.all_to_all"(%operand) {
  split_dimension = 1 : i64,
  concat_dimension = 0 : i64,
  split_count = 2 : i64,
  replica_groups = dense<[[0, 1]]> : tensor<1x2xi64>
} : (tensor<2x4xi64>) -> tensor<4x2xi64>
// %result@(0, 0): [[1, 2],
//                  [5, 6],
//                  [9, 10],
//                  [13, 14]]
// %result@(1, 0): [[3, 4],
//                  [7, 8],
//                  [11, 12],
//                  [15, 16]]

Contoh Lainnya

dan

Semantik

Menjalankan AND dari dua TensorFlow lhs dan rhs, serta menghasilkan tensor result. Bergantung pada jenis elemen, lakukan hal berikut:

  • Untuk boolean: logika AND.
  • Untuk bilangan bulat: bitwise AND.

Input

Label Nama Jenis Batasan
(I1) lhs Tensor tipe boolean atau integer (C1)
(I2) rhs Tensor tipe boolean atau integer (C1)

Output

Nama Jenis Batasan
result Tensor tipe boolean atau integer (C1)

Batasan

  • (C1) type(lhs) = type(rhs) = type(result).

Contoh

// %lhs: [[1, 2], [3, 4]]
// %rhs: [[5, 6], [7, 8]]
%result = "stablehlo.and"(%lhs, %rhs) : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
// %result: [[1, 2], [3, 0]]

atan2

Semantik

Menjalankan operasi atan2 yang mengutamakan elemen pada TensorFlow lhs dan rhs serta menghasilkan tensor result. Bergantung pada jenis elemen, lakukan hal berikut:

  • Untuk float: atan2 dari IEEE-754.
  • Untuk bilangan kompleks: atan2 kompleks.
  • Untuk jenis terkuantisasi: dequantize_op_quantize(atan2, lhs, rhs, type(result)).

Input

Label Nama Jenis Batasan
(I1) lhs Tensor floating-point atau tipe kompleks atau kuantisasi per-tensor (C1)
(I2) rhs Tensor floating-point atau tipe kompleks atau kuantisasi per-tensor (C1)

Output

Nama Jenis Batasan
result Tensor floating-point atau tipe kompleks atau kuantisasi per-tensor (C1)

Batasan

  • (C1) baseline_type(lhs) = baseline_type(rhs) = baseline_type(result).

Contoh

// %lhs: [0.0, 1.0, -1.0]
// %rhs: [0.0, 0.0, 0.0]
%result = "stablehlo.atan2"(%lhs, %rhs) : (tensor<3xf64>, tensor<3xf64>) -> tensor<3xf64>
// %result: [0.0, 1.57079637, -1.57079637] // [0.0, pi/2, -pi/2]

Contoh Lainnya

batch_norm_grad

Semantik

Menghitung gradien beberapa input propagasi mundur batch_norm_training dari grad_output, dan menghasilkan tensor grad_operand, grad_scale, dan grad_offset. Secara lebih formal, operasi ini dapat dinyatakan sebagai dekomposisi ke operasi StableHLO yang ada menggunakan sintaksis Python sebagai berikut:

def compute_sum(operand, feature_index):
  (sum,) = reduce(
      inputs=[operand],
      init_values=[constant(0, element_type(operand))],
      dimensions=[i for i in range(rank(operand)) if i != feature_index],
      body=lambda x, y: add(x, y))
  return sum

def compute_mean(operand, feature_index):
  sum = compute_sum(operand, feature_index)
  divisor = constant(size(operand) / dim(operand, feature_index),
                     element_type(operand))
  divisor_bcast = broadcast_in_dim(divisor, [], type(sum))
  return divide(sum, divisor_bcast)

def batch_norm_grad(operand, scale, mean, variance, grad_output, epsilon, feature_index):
  # Broadcast inputs to type(operand)
  scale_bcast = broadcast_in_dim(scale, [feature_index], type(operand))
  mean_bcast = broadcast_in_dim(mean, [feature_index], type(operand))
  variance_bcast = broadcast_in_dim(variance, [feature_index], type(operand))
  epsilon_bcast = broadcast_in_dim(constant(epsilon, element_type(operand)), [],
                                   type(operand))

  # Perform normalization using the provided `mean` and `variance`
  # Intermediate values will be useful for computing gradients
  centered_operand = subtract(operand, mean_bcast)
  stddev = sqrt(add(variance_bcast, epsilon_bcast))
  normalized_operand = divide(centered_operand, stddev)

  # Use the implementation from batchnorm_expander.cc in XLA
  # Temporary variables have exactly the same names as in the C++ code
  elements_per_feature = broadcast_in_dim(
      constant(divide(size(operand), dim(operand, feature_index)),
               element_type(grad_output)),
      [], type(operand))
  i1 = multiply(grad_output, elements_per_feature)
  i2 = broadcast_in_dim(
      compute_sum(grad_output, feature_index), [feature_index], type(operand))
  i3 = broadcast_in_dim(
      compute_sum(multiply(grad_output, centered_operand), feature_index),
      [feature_index], type(operand))
  i4 = multiply(i3, centered_operand)
  i5 = divide(i4, add(variance_bcast, epsilon_bcast))
  i6 = subtract(subtract(i1, i2), i5)

  grad_operand =
      multiply(divide(divide(scale_bcast, stddev), elements_per_feature), i6)
  grad_scale =
      compute_sum(multiply(grad_output, normalized_operand), feature_index)
  grad_offset = compute_sum(grad_output, feature_index)

  return grad_operand, grad_scale, grad_offset

Untuk jenis terkuantisasi, jalankan dequantize_batch_norm_grad_or_training_quantize(lambda operand, scale, mean, variance, grad_output: batch_norm_grad(operand, scale, mean, variance, grad_output, epsilon, feature_index), operand, scale, mean, variance, grad_output, type(grad_operand), type(grad_scale), type(feature_index)).

Input

Label Nama Jenis Batasan
(I1) operand Tensor floating-point atau per-tensor quantizedTensor (C1-C3), (C5)
(I2) scale Tensor 1-dimensi dari jenis terkuantisasi floating-point atau per-tensor (C2), (C4), (C5)
(I3) mean Tensor 1-dimensi dari jenis terkuantisasi floating-point atau per-tensor (C2), (C4)
(I4) variance Tensor 1-dimensi dari jenis terkuantisasi floating-point atau per-tensor (C2), (C4)
(I5) grad_output Tensor floating-point atau per-tensor quantizedTensor (C2), (C3)
(I6) epsilon konstanta jenis f32
(I7) feature_index konstanta jenis si64 (C1), (C5)

Output

Nama Jenis Batasan
grad_operand Tensor floating-point atau per-tensor quantizedTensor (C2), (C3)
grad_scale Tensor 1-dimensi dari jenis terkuantisasi floating-point atau per-tensor (C2), (C4)
grad_offset Tensor 1-dimensi dari jenis terkuantisasi floating-point atau per-tensor (C2), (C4)

Batasan

  • (C1) 0 <= feature_index < rank(operand).
  • (C2) operand, scale, mean, variance, grad_output, grad_operand, grad_scale, dan grad_offset memiliki baseline_element_type yang sama.
  • (C3) operand, grad_output, dan grad_operand memiliki bentuk yang sama.
  • (C4) scale, mean, variance, grad_scale, dan grad_offset memiliki bentuk yang sama.
  • (C5) size(scale) = dim(operand, feature_index).

Contoh

// %operand: [
//            [[1.0, 2.0], [3.0, 4.0]],
//            [[3.0, 4.0], [1.0, 2.0]]
//           ]
// %scale: [1.0, 1.0]
// %mean: [2.0, 3.0]
// %variance: [1.0, 1.0]
// %grad_output: [
//                [[0.1, 0.1], [0.1, 0.1]],
//                [[0.1, 0.1], [0.1, 0.1]]
//               ]
%grad_operand, %grad_scale, %grad_offset =
"stablehlo.batch_norm_grad"(%operand, %scale, %mean, %variance, %grad_output) {
  epsilon = 0.0 : f32,
  feature_index = 2 : i64
} : (tensor<2x2x2xf64>, tensor<2xf64>, tensor<2xf64>, tensor<2xf64>,
     tensor<2x2x2xf64>) -> (tensor<2x2x2xf64>, tensor<2xf64>, tensor<2xf64>)
// %grad_operand: [
//                 [[0.0, 0.0], [0.0, 0.0]],
//                 [[0.0, 0.0], [0.0, 0.0]]
//                ]
// %grad_scale:  [0.0, 0.0]
// %grad_offset: [0.4, 0.4]

batch_norm_inference

Semantik

Menormalisasi TensorFlow operand di semua dimensi kecuali untuk dimensi feature_index dan menghasilkan TensorFlow result. Secara lebih formal, operasi ini dapat dinyatakan sebagai dekomposisi ke operasi StableHLO yang ada menggunakan sintaksis Python sebagai berikut:

def batch_norm_inference(operand, scale, offset, mean, variance, epsilon, feature_index):
  # Broadcast inputs to shape(operand)
  scale_bcast = broadcast_in_dim(scale, [feature_index], type(operand))
  offset_bcast = broadcast_in_dim(offset, [feature_index], type(operand))
  mean_bcast = broadcast_in_dim(mean, [feature_index], type(operand))
  variance_bcast = broadcast_in_dim(variance, [feature_index], type(operand))
  epsilon_bcast = broadcast_in_dim(constant(epsilon, element_type(operand)), [],
                                   type(operand))

  # Perform normalization using the provided `mean` and `variance` instead of
  # computing them like `batch_norm_training` does.
  centered_operand = subtract(operand, mean_bcast)
  stddev = sqrt(add(variance_bcast, epsilon_bcast))
  normalized_operand = divide(centered_operand, stddev)
  return add(multiply(scale_bcast, normalized_operand), offset_bcast)

Untuk jenis terkuantisasi, jalankan dequantize_op_quantize(lambda operand, scale, offset, mean, variance: batch_norm_inference(operand, scale, offset, mean, variance, epsilon, feature_index), operand, scale, offset, mean, variance, type(result)).

Input

Label Nama Jenis Batasan
(I1) operand Tensor floating-point atau per-tensor quantizedTensor (C1-C7)
(I2) scale Tensor 1-dimensi dari jenis terkuantisasi floating-point atau per-tensor (C2), (C3)
(I3) offset Tensor 1-dimensi dari jenis terkuantisasi floating-point atau per-tensor (C2), (C4)
(I4) mean Tensor 1-dimensi dari jenis terkuantisasi floating-point atau per-tensor (C5)
(I5) variance Tensor 1-dimensi dari jenis terkuantisasi floating-point atau per-tensor (C2), (C6)
(I6) epsilon konstanta jenis f32
(I7) feature_index konstanta jenis si64 (C1), (C3-C6)

Output

Nama Jenis Batasan
result Tensor floating-point atau per-tensor quantizedTensor (C2), (C7)

Batasan

  • (C1) 0 <= feature_index < rank(operand).
  • (C2) operand, scale, offset, mean, variance, dan result memiliki baseline_element_type yang sama.
  • (C3) size(scale) = dim(operand, feature_index).
  • (C4) size(offset) = dim(operand, feature_index).
  • (C5) size(mean) = dim(operand, feature_index).
  • (C6) size(variance) = dim(operand, feature_index).
  • (C7) baseline_type(operand) = baseline_type(result).

Contoh

// %operand: [
//            [[1.0, 2.0], [3.0, 4.0]],
//            [[3.0, 4.0], [1.0, 2.0]]
//           ]
// %scale: [1.0, 1.0]
// %offset: [1.0, 1.0]
// %mean: [2.0, 3.0]
// %variance: [1.0, 1.0]
%result = "stablehlo.batch_norm_inference"(%operand, %scale, %offset, %mean, %variance) {
  epsilon = 0.0 : f32,
  feature_index = 2 : i64
} : (tensor<2x2x2xf64>, tensor<2xf64>, tensor<2xf64>, tensor<2xf64>, tensor<2xf64>) -> tensor<2x2x2xf64>
// %result: [
//           [[0.0, 0.0], [2.0, 2.0]],
//           [[2.0, 2.0], [0.0, 0.0]]
//          ]

batch_norm_training

Semantik

Menghitung rata-rata dan varians di semua dimensi kecuali untuk dimensi feature_index dan menormalkan TensorFlow operand yang menghasilkan TensorFlow output, batch_mean, dan batch_var. Secara lebih formal, operasi ini dapat dinyatakan sebagai dekomposisi ke operasi StableHLO yang ada menggunakan sintaksis Python sebagai berikut:

def compute_mean(operand, feature_index):
  (sum,) = reduce(
      inputs=[operand],
      init_values=[constant(0, element_type(operand))],
      dimensions=[i for i in range(rank(operand)) if i != feature_index],
      body=lambda x, y: add(x, y))
  divisor = constant(size(operand) / dim(operand, feature_index),
                     element_type(operand))
  divisor_bcast = broadcast_in_dim(divisor, [], type(sum))
  return divide(sum, divisor_bcast)

def compute_variance(operand, feature_index):
  mean = compute_mean(operand, feature_index)
  mean_bcast = broadcast_in_dim(mean, [feature_index], type(operand))
  centered_operand = subtract(operand, mean_bcast)
  return compute_mean(mul(centered_operand, centered_operand), feature_index)

def batch_norm_training(operand, scale, offset, epsilon, feature_index):
  mean = compute_mean(operand, feature_index)
  variance = compute_variance(operand, feature_index)
  return batch_norm_inference(operand, scale, offset, mean, variance, epsilon,
                              feature_index),
         mean, variance

Untuk jenis terkuantisasi, jalankan dequantize_batch_norm_grad_or_training_quantize(lambda operand, scale, offset: batch_norm_training(operand, scale, offset, epsilon, feature_index), operand, scale, offset, type(output), type(batch_mean), type(batch_var)).

Input

Label Nama Jenis Batasan
(I1) operand Tensor floating-point atau per-tensor quantizedTensor (C1)
(I2) scale Tensor 1-dimensi untuk floating point atau per-tensor terkuantisasi (C2), (C3)
(I3) offset Tensor 1-dimensi untuk floating point atau per-tensor terkuantisasi (C2), (C4)
(I4) epsilon konstanta jenis f32 (C1), (C3-C6)
(I5) feature_index konstanta jenis si64 (C1), (C3-C6)

Output

Nama Jenis Batasan
output Tensor floating-point atau per-tensor quantizedTensor (C7)
batch_mean Tensor 1-dimensi untuk floating point atau per-tensor terkuantisasi (C2), (C5)
batch_var Tensor 1-dimensi untuk floating point atau per-tensor terkuantisasi (C2), (C6)

Batasan

  • (C1) 0 <= feature_index < rank(operand).
  • (C2) operand, scale, offset, batch_mean, batch_var, dan output memiliki baseline_element_type yang sama.
  • (C3) size(scale) = dim(operand, feature_index).
  • (C4) size(offset) = dim(operand, feature_index).
  • (C5) size(batch_mean) = dim(operand, feature_index).
  • (C6) size(batch_var) = dim(operand, feature_index).
  • (C7) baseline_type(output) = baseline_type(operand).

Contoh

// %operand: [
//            [[1.0, 2.0], [3.0, 4.0]],
//            [[3.0, 4.0], [1.0, 2.0]]
//           ]
// %scale: [1.0, 1.0]
// %offset: [1.0, 1.0]
%output, %batch_mean, %batch_var = "stablehlo.batch_norm_training"(%operand, %scale, %offset) {
  epsilon = 0.0 : f32,
  feature_index = 2 : i64
} : (tensor<2x2x2xf64>, tensor<2xf64>, tensor<2xf64>) ->
    (tensor<2x2x2xf64>, tensor<2xf64>, tensor<2xf64>)
// %output: [
//           [[0.0, 0.0], [2.0, 2.0]],
//           [[2.0, 2.0], [0.0, 0.0]]
//          ]
// %batch_mean: [2.0, 3.0]
// %batch_var: [1.0, 1.0]

bitcast_convert

Semantik

Melakukan operasi bitcast pada5_tensor operand dan menghasilkan tensor result dengan bit dari seluruh tensor operand diinterpretasikan ulang menggunakan jenis_tensor result.

Secara lebih formal, dengan mempertimbangkan E = element_type(operand), E' = element_type(result), dan R = rank(operand):

  • Jika num_bits(E') < num_bits(E), bits(result[i0, ..., iR-1, :]) = bits(operand[i0, ..., iR-1]).
  • Jika num_bits(E') > num_bits(E), bits(result[i0, ..., iR-2]) = bits(operand[i0, ..., iR-2, :]).
  • Jika num_bits(E') = num_bits(E), bits(result[i0, ..., iR-1]) = bits(operand[i0, ..., iR-1]).

bits menampilkan representasi dalam memori dari nilai tertentu, dan perilakunya ditentukan oleh implementasi karena representasi yang tepat dari tensor ditentukan oleh implementasi, dan representasi yang tepat dari jenis elemen juga ditentukan oleh implementasi.

Input

Label Nama Jenis Batasan
(I1) operand Tensor atau Quantized Tensor (C1-C2)

Output

Nama Jenis Batasan
result Tensor atau Quantized Tensor (C1-C2)

Batasan

  • (C1) E = is_quantized(operand) ? storage_type(operand) : element_type(operand), E' = is_quantized(result) ? storage_type(result) : element_type(result), dan R = rank(operand) yang ditentukan:
    • Jika num_bits(E') = num_bits(E), shape(result) = shape(operand).
    • Jika num_bits(E') < num_bits(E):
    • rank(result) = R + 1.
    • dim(result, i) = dim(operand, i) untuk semua 0 <= i < R.
    • dim(result, R) * num_bits(E') = num_bits(E).
    • Jika num_bits(E') > num_bits(E):
    • rank(result) = R - 1.
    • dim(result, i) = dim(operand, i) untuk semua 0 <= i < R.
    • dim(operand, R - 1) * num_bits(E) = num_bits(E').
  • (C2) Jika is_complex(operand) or is_complex(result), maka is_complex(operand) and is_complex(result).

Contoh

// %operand: 0x0123456789ABCDEF
%result = "stablehlo.bitcast_convert"(%operand) : (tensor<f64>) -> tensor<4xf16>
// %result: [0xCDEF, 0x89AB, 0x4567, 0x0123] // little-endian representation

Contoh Lainnya

broadcast_in_dim

Semantik

Memperluas dimensi dan/atau peringkat TensorFlow input dengan menduplikasi data pada Tensor operand dan menghasilkan result. Secara lebih formal, result[result_index] = operand[operand_index] dengan semua d di axes(operand):

  • operand_index[d] = 0 jika dim(operand, d) = 1.
  • operand_index[d] = result_index[broadcast_dimensions[d]] jika tidak.

Input

Label Nama Jenis Batasan
(I1) operand Tensor atau Quantized Tensor (C1-C2), (C5-C6)
(I2) broadcast_dimensions Konstanta TensorFlow 1 dimensi dari jenis si64 (C2-C6)

Output

Nama Jenis Batasan
result Tensor atau Quantized Tensor (C1), (C3), (C5-C6)

Batasan

  • (C1) element_type(result) diberikan oleh:
    • element_type(operand), jika !is_per_axis_quantized(operand).
    • element_type(operand) kecuali bahwa quantization_dimension(operand), scales(operand), dan zero_points(operand) mungkin berbeda dengan respons quantization_dimension(result), scales(result), dan zero_points(result), jika tidak.
  • (C2) size(broadcast_dimensions) = rank(operand).
  • (C3) 0 <= broadcast_dimensions < rank(result).
  • (C4) is_unique(broadcast_dimensions).
  • (C5) Untuk semua d di axes(operand):
    • dim(operand, d) = 1 atau
    • dim(operand, d) = dim(result, broadcast_dimensions[d]).
  • (C6) Jika is_per_axis_quantized(result):
    • quantization_dimension(result) = broadcast_dimensions[quantization_dimension(operand)].
    • Jika dim(operand, quantization_dimension(operand)) = 1, maka scales(result)[i] = scales(operand)[0] and zero_points(result)[i] = zero_points(operand)[0] for i in range(dim(result, quantization_dimension(result))).

Contoh

// %operand: [
//            [1, 2, 3]
//           ]
%result = "stablehlo.broadcast_in_dim"(%operand) {
  broadcast_dimensions = array<i64: 2, 1>
} : (tensor<1x3xi32>) -> tensor<2x3x2xi32>
// %result: [
//            [
//             [1, 1],
//             [2, 2],
//             [3, 3]
//            ],
//            [
//             [1, 1],
//             [2, 2],
//             [3, 3]
//            ]
//          ]

Contoh Lainnya

casing

Semantik

Menghasilkan output dari mengeksekusi tepat satu fungsi dari branches, bergantung pada nilai index. Secara lebih formal, result = selected_branch() jika:

  • selected_branch = branches[index] jika 0 <= index < size(branches).
  • selected_branch = branches[-1] jika tidak.

Input

Label Nama Jenis Batasan
(I1) index Tensor 0 dimensi dari jenis si32
(I2) branches jumlah fungsi variadic (C1-C4)

Output

Nama Jenis Batasan
results bilangan variadic, quantized TensorFlow, atau token (C4)

Batasan

  • (C1) 0 < size(branches).
  • (C2) input_types(branches...) = [].
  • (C3) same(output_types(branches...)).
  • (C4) type(results...) = output_types(branches[0]).

Contoh

// %index: -1
// %result_branch0: [0, 0]
// %result_branch1: [1, 1]
%result0, %result1 = "stablehlo.case"(%index) ({
  "stablehlo.return"(%result_branch0, %result_branch0) : (tensor<2xi64>, tensor<2xi64>) -> ()
}, {
  "stablehlo.return"(%result_branch1, %result_branch1) : (tensor<2xi64>, tensor<2xi64>) -> ()
}) : (tensor<i32>) -> (tensor<2xi64>, tensor<2xi64>)
// %result0: [1, 1]
// %result1: [1, 1]

Contoh Lainnya

Cbrt

Semantik

Menjalankan operasi root kubik berbasis elemen pada TensorFlow operand dan menghasilkan tensor result. Bergantung pada jenis elemen, lakukan hal berikut:

  • Untuk float: rootn(x, 3) dari IEEE-754.
  • Untuk bilangan kompleks: akar kubik kompleks.
  • Untuk jenis terkuantifikasi: dequantize_op_quantize(cbrt, operand, type(result))

Input

Label Nama Jenis Batasan
(I1) operand Tensor floating-point atau tipe kompleks atau kuantisasi per-tensor (C1)

Output

Nama Jenis Batasan
result Tensor floating-point atau tipe kompleks atau kuantisasi per-tensor (C1)

Batasan

  • (C1) baseline_type(operand) = baseline_type(result).

Contoh

// %operand: [0.0, 1.0, 8.0, 27.0]
%result = "stablehlo.cbrt"(%operand) : (tensor<4xf64>) -> tensor<4xf64>
// %result: [0.0, 1.0, 2.0, 3.0]

Contoh Lainnya

ceil

Semantik

Menjalankan ceil element-wise dari operandtensor dan menghasilkan result. Mengimplementasikan operasi roundToIntegralTowardPositive dari spesifikasi IEEE-754. Untuk jenis terkuantisasi, jalankan dequantize_op_quantize(ceil, operand, type(result)).

Input

Label Nama Jenis Batasan
(I1) operand Tensor floating-point atau per-tensor quantizedTensor (C1)

Output

Nama Jenis Batasan
result Tensor floating-point atau per-tensor quantizedTensor (C1)

Batasan

  • (C1) baseline_type(operand) = baseline_type(result).

Contoh

// %operand: [-0.8166, -0.2530, 0.2530, 0.8166, 2.0]
%result = "stablehlo.ceil"(%operand) : (tensor<5xf32>) -> tensor<5xf32>
// %result: [-0.0, -0.0, 1.0, 1.0, 2.0]

Contoh Lainnya

Cholesky

Semantik

Menghitung dekomposisi Cholesky dari batch matriks.

Secara lebih formal, untuk semua i di index_space(result), result[i0, ..., iR-3, :, :] adalah dekomposisi Cholesky dari a[i0, ..., iR-3, :, :], dalam bentuk matriks segitiga bawah (jika lower adalah true) atau matriks segitiga atas (jika lower adalah false). Nilai output dalam segitiga yang berlawanan, yaitu segitiga atas ketat atau segitiga bawah ketat yang bersesuaian, ditentukan oleh implementasi.

Jika ada i yang matriks inputnya bukan matriks pasti positif Hermitian, perilakunya tidak terdefinisi.

Untuk jenis terkuantisasi, jalankan dequantize_op_quantize(lambda operand: cholesky(operand, lower), a, type(result)).

Input

Label Nama Jenis Batasan
(I1) a Tensor floating-point atau tipe kompleks atau kuantisasi per-tensor (C1-C3)
(I2) lower Konstanta integer 0 dimensi dari jenis i1

Output

Nama Jenis Batasan
result Tensor floating-point atau tipe kompleks atau kuantisasi per-tensor (C1)

Batasan

  • (C1) baseline_type(a) = baseline_type(result).
  • (C2) 2 <= rank(a).
  • (C3) dim(a, -2) = dim(a, -1).

Contoh

// %a: [
//      [1.0, 2.0, 3.0],
//      [2.0, 20.0, 26.0],
//      [3.0, 26.0, 70.0]
//     ]
%result = "stablehlo.cholesky"(%a) {
  lower = true
} : (tensor<3x3xf32>) -> tensor<3x3xf64>
// %result: [
//           [1.0, 0.0, 0.0],
//           [2.0, 4.0, 0.0],
//           [3.0, 5.0, 6.0]
//          ]

klem

Semantik

Menjepit setiap elemen TensorFlow operand di antara nilai minimum dan maksimum serta menghasilkan5 result. Secara lebih formal, result[result_index] = minimum(maximum(operand[result_index], min_element), max_element), dengan min_element = rank(min) = 0 ? min[] : min[result_index], max_element = rank(max) = 0 ? max[] : max[result_index]. Untuk jenis terkuantisasi, menjalankan dequantize_op_quantize(clamp, min, operand, max, type(result)).

Memaksakan pengurutan pada bilangan kompleks melibatkan semantik yang mengejutkan, jadi di masa mendatang kami berencana menghapus dukungan untuk bilangan kompleks untuk operasi ini (#560).

Input

Label Nama Jenis Batasan
(I1) min Quantized Tensor atau per-tensor (C1), (C3)
(I2) operand Quantized Tensor atau per-tensor (C1-C4)
(I3) max Quantized Tensor atau per-tensor (C2), (C3)

Output

Nama Jenis Batasan
result Quantized Tensor atau per-tensor (C4)

Batasan

  • (C1) rank(min) = 0 or shape(min) = shape(operand).
  • (C2) rank(max) = 0 or shape(max) = shape(operand).
  • (C3) baseline_element_type(min) = baseline_element_type(operand) = baseline_element_type(max).
  • (C4) baseline_type(operand) = baseline_type(result).

Contoh

// %min: [5, 10, 15]
// %operand: [3, 13, 23]
// %max: [10, 15, 20]
%result = "stablehlo.clamp"(%min, %operand, %max) : (tensor<3xi32>, tensor<3xi32>, tensor<3xi32>) -> tensor<3xi32>
// %result: [5, 13, 20]

Contoh Lainnya

collective_broadcast

Semantik

Dalam setiap grup proses di petak proses StableHLO, kirim nilai tensor operand dari proses sumber ke proses target dan hasilkan tensor result.

Operasi ini membagi petak proses StableHLO menjadi process_groups yang ditentukan sebagai berikut:

  • cross_replica(replica_groups) jika channel_id <= 0.
  • cross_partition(replica_groups) jika channel_id > 0.

Setelah itu, result@process diberikan oleh:

  • operand@process_groups[i, 0] jika ada i sehingga prosesnya berada di process_groups[i].
  • broadcast_in_dim(constant(0, element_type(result)), [], type(result)) jika tidak.

Input

Label Nama Jenis Batasan
(I1) operand Tensor (C3)
(I2) replica_groups bilangan variadic konstanta TensorFlow 1 dimensi dari jenis si64 (C1), (C2)
(I3) channel_id konstanta jenis si64

Output

Nama Jenis Batasan
result Tensor (C3)

Batasan

  • (C1) is_unique(replica_groups).
  • (C2) 0 <= replica_groups < N dengan N ditentukan sebagai:
    • num_replicas jika cross_replica digunakan.
    • num_partitions jika cross_partition digunakan.
  • (C3) type(result) = type(operand).

Contoh

// num_replicas: 4
// num_partitions: 1
// %operand@(0, 0): [[1, 2]]
// %operand@(1, 0): [[3, 4]]
// %operand@(2, 0): [[5, 6]]
// %operand@(3, 0): [[7, 8]]
%result = "stablehlo.collective_broadcast"(%operand) {
  replica_groups = dense<[[2, 1]]> : tensor<1x2xi64>,
  channel_handle = #stablehlo.channel_handle<handle = 0, type = 0>
} : (tensor1x2xi64>) -> tensor<1x2xi64>
// %result@(0, 0): [[0, 0]]
// %result@(1, 0): [[5, 6]]
// %result@(2, 0): [[5, 6]]
// %result@(3, 0): [[0, 0]]

collective_permute

Semantik

Dalam setiap grup proses di petak proses StableHLO, kirim nilai tensor operand dari proses sumber ke proses target dan menghasilkan tensor result.

Operasi ini membagi petak proses StableHLO menjadi process_groups yang ditentukan sebagai berikut:

  • cross_replica(source_target_pairs) jika channel_id <= 0.
  • cross_partition(source_target_pairs) jika channel_id > 0.

Setelah itu, result@process diberikan oleh:

  • operand@process_groups[i, 0], jika ada i yang sedemikian rupa sehingga process_groups[i, 1] = process.
  • broadcast_in_dim(constant(is_quantized(result) ? quantize(0, element_type(result)) : 0, element_type(result)), [], type(result)) jika tidak.

Input

Label Nama Jenis Batasan
(I1) operand Quantized Tensor atau per-tensor (C5)
(I2) source_target_pairs Konstanta TensorFlow 2 dimensi dari jenis si64 (C1-C4)
(I3) channel_id konstanta jenis si64

Output

Nama Jenis Batasan
result Quantized Tensor atau per-tensor (C1)

Batasan

  • (C1) dim(source_target_pairs, 1) = 2.
  • (C2) is_unique(source_target_pairs[:, 0]).
  • (C3) is_unique(source_target_pairs[:, 1]).
  • (C4) 0 <= source_target_pairs < N, dengan N ditentukan sebagai:
    • num_replicas jika cross_replica digunakan.
    • num_partitions jika cross_partition digunakan.
  • (C5) type(result) = type(operand).

Contoh

// num_replicas: 3
// num_partitions: 1
// %operand@(0, 0): [[1, 2], [3, 4]]
// %operand@(1, 0): [[5, 6], [7, 8]]
// %operand@(2, 0): [[9, 10], [11, 12]]
%result = "stablehlo.collective_permute"(%operand) {
  source_target_pairs = dense<[[0, 1], [1, 2]]> : tensor<2x2xi64>,
  channel_handle = #stablehlo.channel_handle<handle = 0, type = 0>
} : (tensor<2x2xi64>) -> tensor<2x2xi64>
//
// %result@(0, 0): [[0, 0], [0, 0]]
// %result@(1, 0): [[1, 2], [3, 4]]
// %result@(2, 0): [[5, 6], [7, 8]]

Contoh Lainnya

compare

Semantik

Melakukan perbandingan element-wise dari lhs dan rhs5 sesuai dengan comparison_direction dan compare_type, serta menghasilkan result.

Nilai comparison_direction dan compare_type memiliki semantik berikut:

Untuk jenis elemen boolean dan bilangan bulat:

  • EQ: lhs = rhs.
  • NE: lhs != rhs.
  • GE: lhs >= rhs.
  • GT: lhs > rhs.
  • LE: lhs <= rhs.
  • LT: lhs < rhs.

Untuk jenis elemen floating point dengan compare_type = FLOAT, operasi akan menerapkan operasi IEEE-754 berikut:

  • EQ: compareQuietEqual.
  • NE: compareQuietNotEqual.
  • GE: compareQuietGreaterEqual.
  • GT: compareQuietGreater.
  • LE: compareQuietLessEqual.
  • LT: compareQuietLess.

Untuk jenis elemen floating point dengan compare_type = TOTALORDER, operasi menggunakan kombinasi operasi totalOrder dan compareQuietEqual dari IEEE-754. Fitur ini tampaknya tidak digunakan, jadi di masa mendatang, kami berencana menghapusnya (#584).

Untuk jenis elemen yang kompleks, perbandingan leksikografis pasangan (real, imag) dilakukan menggunakan comparison_direction dan compare_type yang disediakan. Memaksakan pengurutan pada bilangan kompleks melibatkan semantik yang mengejutkan, jadi di masa mendatang kami berencana menghapus dukungan untuk bilangan kompleks saat comparison_direction adalah GE, GT, LE, atau LT (#560).

Untuk jenis terkuantisasi. menjalankan dequantize_compare(lhs, rhs, comparison_direction).

Input

Label Nama Jenis Batasan
(I1) lhs Quantized Tensor atau per-tensor (C1-C3)
(I2) rhs Quantized Tensor atau per-tensor (C1-C2)
(I3) comparison_direction enum EQ, NE, GE, GT, LE, dan LT
(I4) compare_type enum FLOAT, TOTALORDER, SIGNED, dan UNSIGNED (C3)

Output

Nama Jenis Batasan
result Tensor tipe boolean (C2)

Batasan

  • (C1) baseline_element_type(lhs) = baseline_element_type(rhs).
  • (C2) shape(lhs) = shape(rhs) = shape(result).
  • (C3) compare_type didefinisikan sebagai:
    • SIGNED jika is_signed_integer(element_type(lhs)).
    • UNSIGNED jika is_unsigned_integer(element_type(lhs)) or is_boolean(element_type(lhs)).
    • FLOAT atau TOTALORDER jika is_float(element_type(lhs)).
    • FLOAT jika is_complex(element_type(lhs)).

Contoh

// %lhs: [1.0, 3.0]
// %rhs: [1.1, 2.9]
%result = "stablehlo.compare"(%lhs, %rhs) {
  comparison_direction = #stablehlo<comparison_direction LT>,
  compare_type = #stablehlo<comparison_type FLOAT>
} : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xi1>
// %result: [true, false]

Contoh Lainnya

kompleks

Semantik

Melakukan konversi berbasis elemen ke nilai kompleks dari pasangan nilai nyata dan imajiner, lhs dan rhs, serta menghasilkan TensorFlow result.

Input

Label Nama Jenis Batasan
(I1) lhs Tensor jenis f32 atau f64 (C1-C3)
(I2) rhs Tensor jenis f32 atau f64 (C1)

Output

Nama Jenis Batasan
result Tensor tipe kompleks (C2), (C3)

Batasan

  • (C1) type(lhs) = type(rhs).
  • (C2) shape(result) = shape(lhs).
  • (C3) element_type(result) memiliki jenis complex<E> dengan E = element_type(lhs).

Contoh

// %lhs: [1.0, 3.0]
// %rhs: [2.0, 4.0]
%result = "stablehlo.complex"(%lhs, %rhs) : (tensor<2xf64>, tensor<2xf64>) -> tensor<2xcomplex<f64>>
// %result: [(1.0, 2.0), (3.0, 4.0)]

Contoh Lainnya

concatenate

Semantik

Menggabungkan inputs di sepanjang dimensi dimension dalam urutan yang sama seperti argumen yang diberikan dan menghasilkan argumen result. Secara lebih formal, result[i0, ..., id, ..., iR-1] = inputs[k][i0, ..., kd, ..., iR-1], dengan:

  1. id = d0 + ... + dk-1 + kd.
  2. d sama dengan dimension, dan d0, ... adalah ukuran dimensi ke-d dari inputs.

Input

Label Nama Jenis Batasan
(I1) inputs jumlah variadic Tensor atau quantized5 per-tensor (C1-C6)
(I2) dimension konstanta jenis si64 (C2), (C4), (C6)

Output

Nama Jenis Batasan
result Quantized Tensor atau per-tensor (C5-C6)

Batasan

  • (C1) same(element_type(inputs...)).
  • (C2) same(shape(inputs...)) kecuali untuk dim(inputs..., dimension).
  • (C3) 0 < size(inputs).
  • (C4) 0 <= dimension < rank(inputs[0]).
  • (C5) element_type(result) = element_type(inputs[0]).
  • (C6) shape(result) = shape(inputs[0]) kecuali untuk:
    • dim(result, dimension) = dim(inputs[0], dimension) + ....

Contoh

// %input0: [[1, 2], [3, 4], [5, 6]]
// %input1: [[7, 8]]
%result = "stablehlo.concatenate"(%input0, %input1) {
  dimension = 0 : i64
} : (tensor<3x2xi64>, tensor<1x2xi64>) -> tensor<4x2xi64>
// %result: [[1, 2], [3, 4], [5, 6], [7, 8]]

Contoh Lainnya

konstanta

Semantik

Menghasilkan5ten output dari konstanta value.

Input

Label Nama Jenis Batasan
(I1) value konstanta (C1)

Output

Nama Jenis Batasan
output Tensor atau Quantized Tensor (C1)

Batasan

  • (C1) type(value) = type(output).

Contoh

%output = "stablehlo.constant"() {
  value = dense<[[0.0, 1.0], [2.0, 3.0]]> : tensor<2x2xf32>
} : () -> tensor<2x2xf32>
// %output: [[0.0, 1.0], [2.0, 3.0]]

Contoh Lainnya

melakukan konversi

Semantik

Melakukan konversi element-wise dari satu jenis elemen ke jenis elemen lainnya pada tensor operand dan menghasilkan result.

Untuk konversi boolean-to-any-supported-type, nilai false dikonversi menjadi nol, dan nilai true dikonversi menjadi satu. Untuk konversi any-supported-type-to-boolean, nilai nol dikonversi menjadi false, dan nilai bukan nol dikonversi menjadi true. Lihat di bawah untuk mengetahui cara kerjanya untuk jenis yang kompleks.

Untuk konversi yang melibatkan integer-ke-bilangan bulat, bilangan bulat-ke-floating-point atau titik-mengambang-ke-mengambang, jika nilai sumber dapat direpresentasikan dengan tepat dalam jenis tujuan, nilai hasil adalah representasi yang tepat tersebut. Jika tidak, perilaku akan diputuskan (#180).

Untuk konversi yang melibatkan floating-point-to-integer, bagian pecahan akan dipotong. Jika nilai yang terpotong tidak dapat dinyatakan dalam jenis tujuan, perilakunya akan ditentukan nanti (#180).

Konversi yang melibatkan kompleks-ke-kompleks mengikuti perilaku konversi floating-point-to-floating-point yang sama untuk mengonversi bagian nyata dan imajiner.

Untuk konversi complex-to-any-other-type dan complex-to-any-other-type, nilai imajiner sumber diabaikan atau nilai imajiner tujuan dinolkan. Konversi bagian asli akan mengikuti konversi floating point.

Pada prinsipnya, operasi ini dapat mengekspresikan dekuantisasi (konversi dari tensor terkuantisasi ke TensorFlow reguler), kuantisasi (konversi dari tensor reguler ke TensorFlow terkuantisasi) dan rekuantisasi (konversi antara tensor terkuantisasi), tetapi saat ini kami memiliki operasi khusus untuk itu - uniform_dequantize untuk kasus penggunaan pertama dan uniform_quantize untuk kasus penggunaan kedua dan ketiga. Di masa mendatang, kedua operasi ini dapat digabungkan menjadi convert (#1576).

Input

Label Nama Jenis Batasan
(I1) operand Tensor (C1)

Output

Nama Jenis Batasan
result Tensor (C1)

Batasan

  • (C1) shape(operand) = shape(result).

Contoh

// %operand: [-1, 0, 1]
%result = "stablehlo.convert"(%operand) : (tensor<3xi64>) -> tensor<3xcomplex<f64>>
// %result: [(-1.0, 0.0), (0.0, 0.0), (1.0, 0.0)]

Contoh Lainnya

konvolusi

Semantik

Menghitung perkalian titik di antara jendela lhs dan irisan rhs, lalu menghasilkan result. Diagram berikut menunjukkan cara elemen dalam result dihitung dari lhs dan rhs menggunakan contoh konkret.

Secara lebih formal, pertimbangkan framing ulang input berikut dalam hal lhs agar dapat mengekspresikan jendela lhs:

  • lhs_window_dimensions = lhs_shape(dim(lhs, input_batch_dimension), dim(rhs, kernel_spatial_dimensions), dim(lhs, input_feature_dimension)).
  • lhs_window_strides = lhs_shape(1, window_strides, 1).
  • lhs_padding = lhs_shape([0, 0], padding, [0, 0]).
  • lhs_base_dilations = lhs_shape(1, lhs_dilation, 1).
  • lhs_window_dilations = lhs_shape(1, rhs_dilation, 1).

Pembingkaian ulang data ({i>reframing<i}) ini menggunakan fungsi bantuan berikut:

  • lhs_shape(n, hw, c) = permute([n] + hw + [c], [input_batch_dimension] + input_spatial_dimensions + [input_feature_dimension]).
  • result_shape(n1, hw, c1) = permute([n1] + hw + [c1], [output_batch_dimension] + output_spatial_dimensions + [output_feature_dimension]).
  • permute([j0, j1, ..., jR-1], permutation) = [i0, i1, ..., iR-1] dengan j[d] = i[permutation[d]].

Jika feature_group_count = 1 dan batch_group_count = 1, maka untuk semua output_spatial_index di index_space(dim(result, output_spatial_dimensions...)), result[result_shape(:, output_spatial_index, :)] = dot_product jika:

  • padding_value = constant(0, element_type(lhs)).
  • padded_lhs = pad(lhs, padding_value, lhs_padding[:, 0], lhs_padding[:, 1], lhs_base_dilations - 1).
  • lhs_window_start = lhs_shape(0, output_spatial_index, 0) * lhs_window_strides.
  • lhs_window = slice(padded_lhs, lhs_window_start, lhs_window_start + lhs_window_dimensions, lhs_window_dilations).
  • reversed_lhs_window = reverse(lhs_window, [input_spatial_dimensions[dim] for dim in range(size(window_reversal)) if window_reversal[dim] = true]). Fitur ini tampaknya tidak digunakan, jadi kami berencana menghapusnya di masa mendatang (#1181).
  • dot_product = dot_general(reversed_lhs_window, rhs, lhs_batching_dimensions=[], lhs_contracting_dimensions=input_spatial_dimensions + [input_feature_dimension], rhs_batching_dimensions=[], rhs_contracting_dimensions=kernel_spatial_dimensions + [kernel_input_feature_dimension]).

Jika feature_group_count > 1:

  • lhses = split(lhs, feature_group_count, input_feature_dimension).
  • rhses = split(rhs, feature_group_count, kernel_output_feature_dimension).
  • results... = convolution(lhses..., rhses..., ..., feature_group_count=1, ...).
  • result = concatenate(results, output_feature_dimension).

Jika batch_group_count > 1:

  • lhses = split(lhs, batch_group_count, input_batch_dimension).
  • rhses = split(rhs, batch_group_count, kernel_output_feature_dimension).
  • results... = convolution(lhses..., rhses..., ..., batch_group_count=1, ...).
  • result = concatenate(results, output_feature_dimension).

Untuk jenis terkuantisasi, jalankan dequantize_op_quantize( lambda lhs, rhs: convolution(lhs, rhs, window_strides, padding, lhs_dilation, rhs_dilation, window_reversal, input_batch_dimension, input_feature_dimension, input_spatial_dimensions, kernel_input_feature_dimension, kernel_output_feature_dimension, kernel_spatial_dimensions, output_batch_dimension, output_feature_dimension, output_spatial_dimensions, feature_group_count, batch_group_count, precision_config), lhs, rhs, type(result)).

Input

Label Nama Jenis Batasan
(I1) lhs Quantized Tensor atau per-tensor (C1), (C10-C11), (C14) (C25), (C27-C30)
(I2) rhs Tensor atau Quantized Tensor (C1), (C14-C16), (C25), (C27-C32)
(I3) window_strides Konstanta TensorFlow 1 dimensi dari jenis si64 (C2-C3), (C25)
(I4) padding Konstanta TensorFlow 2 dimensi dari jenis si64 (C4), (C25)
(I5) lhs_dilation Konstanta TensorFlow 1 dimensi dari jenis si64 (C5-C6), (C25)
(I6) rhs_dilation Konstanta TensorFlow 1 dimensi dari jenis si64 (C7-C8), (C25)
(I7) window_reversal Konstanta TensorFlow 1 dimensi dari jenis i1 (C9)
(I8) input_batch_dimension konstanta jenis si64 (C10), (C13), (C25)
(I9) input_feature_dimension konstanta jenis si64 (C11), (C13-C14)
(I10) input_spatial_dimensions Konstanta TensorFlow 1 dimensi dari jenis si64 (C12), (C13), (C25)
(I11) kernel_input_feature_dimension konstanta jenis si64 (C14), (C18)
(I12) kernel_output_feature_dimension konstanta jenis si64 (C15-C16), (C18), (C25), (C32)
(I13) kernel_spatial_dimensions Konstanta TensorFlow 1 dimensi dari jenis si64 (C17-C18), (C25)
(I14) output_batch_dimension konstanta jenis si64 (C20), (C25)
(I15) output_feature_dimension konstanta jenis si64 (C20), (C25), (C33)
(I16) output_spatial_dimensions Konstanta TensorFlow 1 dimensi dari jenis si64 (C19-C20), (C25)
(I17) feature_group_count konstanta jenis si64 (C11), (C14), (C16), (C21), (C23)
(I18) batch_group_count konstanta jenis si64 (C10), (C15), (C22), (C23), (C25)
(I19) precision_config jumlah variadic enum DEFAULT, HIGH, dan HIGHEST (C24)

Output

Nama Jenis Batasan
result Tensor atau Quantized Tensor (C25-C28), (C30-C31), (C33)

Batasan

  • (C1) N = rank(lhs) = rank(rhs).
  • (C2) size(window_strides) = N - 2.
  • (C3) 0 < window_strides.
  • (C4) shape(padding) = [N - 2, 2].
  • (C5) size(lhs_dilation) = N - 2.
  • (C6) 0 < lhs_dilation.
  • (C7) size(rhs_dilation) = N - 2.
  • (C8) 0 < rhs_dilation.
  • (C9) size(window_reversal) = N - 2.
  • (C10) dim(lhs, input_batch_dimension) % batch_group_count = 0.
  • (C11) dim(lhs, input_feature_dimension) % feature_group_count = 0.
  • (C12) size(input_spatial_dimensions) = N - 2.
  • (C13) input_dimensions = [input_batch_dimension] + input_spatial_dimensions + [input_feature_dimension] yang ditentukan:
    • is_unique(input_dimensions).
    • 0 <= input_dimensions < N.
  • (C14) dim(rhs, kernel_input_feature_dimension) = dim(lhs, input_feature_dimension) / feature_group_count.
  • (C15) dim(rhs, kernel_output_feature_dimension) % batch_group_count = 0.
  • (C16) dim(rhs, kernel_output_feature_dimension) % feature_group_count = 0.
  • (C17) size(kernel_spatial_dimensions) = N - 2.
  • (C18) kernel_dimensions = kernel_spatial_dimensions + [kernel_input_feature_dimension] + [kernel_output_feature_dimension] yang ditentukan:
    • is_unique(kernel_dimensions).
    • 0 <= kernel_dimensions < N.
  • (C19) size(output_spatial_dimensions) = N - 2.
  • (C20) output_dimensions = [output_batch_dimension] + output_spatial_dimensions + [output_feature_dimension] yang ditentukan:
    • is_unique(output_dimensions).
    • 0 <= output_dimensions < N.
  • (C21) 0 < feature_group_count.
  • (C22) 0 < batch_group_count.
  • (C23) feature_group_count = 1 or batch_group_count = 1.
  • (C24) size(precision_config) = 2.
  • (C25) dim(result, result_dim) didefinisikan sebagai:
    • dim(lhs, input_batch_dimension) / batch_group_count jika result_dim = output_batch_dimension.
    • dim(rhs, kernel_output_feature_dimension) jika result_dim = output_feature_dimension.
    • num_windows jika tidak, jika:
    • output_spatial_dimensions[spatial_dim] = result_dim.
    • lhs_dim = input_spatial_dimensions[spatial_dim].
    • rhs_dim = kernel_spatial_dimensions[spatial_dim].
    • dilated_input_shape[lhs_dim] = dim(lhs, lhs_dim) = 0 ? 0 : (dim(lhs, lhs_dim) - 1) * lhs_dilation[spatial_dim] + 1.
    • padded_input_shape[lhs_dim] = padding[spatial_dim, 0] + dilated_input_shape[lhs_dim] + padding[spatial_dim, 1].
    • dilated_window_shape[lhs_dim] = dim(rhs, rhs_dim) = 0 ? 0 : (dim(rhs, rhs_dim) - 1) * rhs_dilation[spatial_dim] + 1.
    • is_empty_window[lhs_dim] = padded_input_shape[lhs_dim] = 0 || dilated_window_shape[lhs_dim] > padded_input_shape[lhs_dim].
    • num_windows = is_empty_window[lhs_dim] ? 0 : floor((padded_input_shape[lhs_dim] - dilated_window_shape[lhs_dim]) / window_strides[spatial_dim]) + 1.
  • (C26) rank(result) = N.
  • Jika operasi menggunakan TensorFlow non-terkuantisasi:
    • (C27) element_type(lhs) = element_type(rhs) = element_type(result).
  • Jika operasi menggunakan kuantisasi Tensor:
    • (C28) is_quantized_tensor(lhs) and is_quantized_tensor(rhs) and is_quantized_tensor(result).
    • (C29) storage_type(lhs) = storage_type(rhs).
    • (C30) expressed_type(lhs) = expressed_type(rhs) = expressed_type(result).
    • (C31) Jika is_per_tensor_quantized(rhs), maka is_per_tensor_quantized(result).
    • (C32) Jika is_per_axis_quantized(rhs), maka quantization_dimension(rhs) = kernel_output_feature_dimension.
    • (C33) Jika is_per_axis_quantized(result), maka quantization_dimension(result) = output_feature_dimension.

Contoh

// %lhs: [[
//        [
//          [1], [2], [5], [6]
//        ],
//        [
//          [3], [4], [7], [8]
//        ],
//        [
//          [10], [11], [14], [15]
//        ],
//        [
//          [12], [13], [16], [17]
//        ]
//      ]]
//
// %rhs : [
//         [[[1]], [[1]], [[1]]],
//         [[[1]], [[1]], [[1]]],
//         [[[1]], [[1]], [[1]]]
//        ]
%result = "stablehlo.convolution"(%lhs, %rhs) {
  window_strides = dense<4> : tensor<2xi64>,
  padding = dense<0> : tensor<2x2xi64>,
  lhs_dilation = dense<2> : tensor<2xi64>,
  rhs_dilation = dense<1> : tensor<2xi64>,
  window_reversal = dense<false> : tensor<2xi1>,
  // In the StableHLO dialect, dimension numbers are encoded via:
  // `[<input dimensions>]x[<kernel dimensions>]->[output dimensions]`.
  // "b" is batch dimension, "f" is feature dimension,
  // "i" is input feature dimension, "o" is output feature dimension,
  // "0/1/etc" are spatial dimensions.
  dimension_numbers = #stablehlo.conv<[b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f]>,
  feature_group_count = 1 : i64,
  batch_group_count = 1 : i64,
  precision_config = [#stablehlo<precision DEFAULT>, #stablehlo<precision DEFAULT>]
} : (tensor<1x4x4x1xi32>, tensor<3x3x1x1xi32>) -> tensor<1x2x2x1xi32>
// %result: [[
//            [[10], [26]],
//            [[46], [62]]
//          ]]

cosinus

Semantik

Menjalankan operasi kosinus element-wise pada TensorFlow operand dan menghasilkan tensor result. Bergantung pada jenis elemen, lakukan hal berikut:

  • Untuk float: cos dari IEEE-754.
  • Untuk bilangan kompleks: kosinus kompleks.
  • Untuk jenis terkuantisasi: dequantize_op_quantize(cosine, operand, type(result)).

Input

Label Nama Jenis Batasan
(I1) operand Tensor floating-point atau tipe kompleks atau kuantisasi per-tensor (C1)

Output

Nama Jenis Batasan
result Tensor floating-point atau tipe kompleks atau kuantisasi per-tensor (C1)

Batasan

  • (C1) baseline_type(operand) = baseline_type(result).

Contoh

// %operand: [
//            [0.0, 1.57079632],       // [0, pi/2]
//            [3.14159265, 4.71238898] // [pi, 3pi/2]
//           ]
%result = "stablehlo.cosine"(%operand) : (tensor<2x2xf32>) -> tensor<2x2xf32>
// %result: [[1.0, 0.0], [-1.0, 0.0]]

Contoh Lainnya

count_leading_zeros

Semantik

Melakukan penghitungan berdasarkan elemen dari jumlah bit nol di depan pada tensor operand dan menghasilkan result.

Input

Label Nama Jenis Batasan
(I1) operand Tensor tipe integer (C1)

Output

Nama Jenis Batasan
result Tensor tipe integer (C1)

Batasan

  • (C1) type(operand) = type(result).

Contoh

// %operand: [[0, 1], [128, -1]]
%result = "stablehlo.count_leading_zeros"(%operand) : (tensor<2x2xi64>) -> tensor<2x2xi64>
// %result: [[64, 63], [56, 0]]

Contoh Lainnya

custom_call

Semantik

Mengenkapsulasi call_target_name operasi yang ditentukan implementasi yang mengambil inputs dan called_computations serta menghasilkan results. has_side_effect, backend_config, dan api_version dapat digunakan untuk menyediakan metadata tambahan yang ditentukan oleh implementasi.

Saat ini, operasi tersebut berisi kumpulan metadata yang cukup tidak teratur yang mencerminkan evolusi organik dari operasi pasangannya di compiler XLA. Ke depannya, kami berencana menyatukan metadata ini (#741).

Input

Label Nama Jenis
(I1) inputs jumlah nilai variadic
(I2) call_target_name konstanta jenis string
(I3) has_side_effect konstanta jenis i1
(I4) backend_config konstanta jenis string
(I5) api_version konstanta jenis si32
(I6) called_computations jumlah variadik konstanta jenis string

Output

Nama Jenis
results jumlah nilai variadic

Contoh

%results = "stablehlo.custom_call"(%input0) {
  call_target_name = "foo",
  has_side_effect = false,
  backend_config = "bar",
  api_version = 1 : i32,
  called_computations = [@foo]
} : (tensor<f64>) -> tensor<f64>

bagi

Semantik

Melakukan pembagian berbasis elemen dari TensorFlow lhs dan pembagi rhs pembagi serta menghasilkan tensor result. Bergantung pada jenis elemen, lakukan hal berikut:

  • Untuk bilangan bulat: pembagian bilangan bulat yang menghasilkan hasil bagi aljabar dengan bagian pecahan yang dihapus.
  • Untuk float: division dari IEEE-754.
  • Untuk bilangan kompleks: pembagian kompleks.
  • Untuk jenis terkuantisasi:
    • dequantize_op_quantize(divide, lhs, rhs, type(result)).

Input

Label Nama Jenis Batasan
(I1) lhs Tensor integer, floating point atau tipe kompleks atau quantized per-tensor (C1)
(I2) rhs Tensor integer, floating point atau tipe kompleks atau quantized per-tensor (C1)

Output

Nama Jenis Batasan
result Tensor integer, floating point, atau tipe kompleks atau kuantisasi per-tensor (C1)

Batasan

  • (C1) baseline_type(lhs) = baseline_type(rhs) = baseline_type(result).

Contoh

// %lhs: [17.1, -17.1, 17.1, -17.1]
// %rhs: [3.0, 3.0, -3.0, -3.0]
%result = "stablehlo.divide"(%lhs, %rhs) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
// %result: [5.66666651, -5.66666651, -5.66666651, 5.66666651]

Contoh Lainnya

dot_general

Semantik

Menghitung produk titik di antara irisan lhs dan irisan rhs serta menghasilkan tensor result.

Secara lebih formal, result[result_index] = dot_product, jika:

  • lhs_result_dimensions = [d for d in axes(lhs) and d not in lhs_batching_dimensions and d not in lhs_contracting_dimensions].
  • rhs_result_dimensions = [d for d in axes(rhs) and d not in rhs_batching_dimensions and d not in rhs_contracting_dimensions].
  • result_batching_index + result_lhs_index + result_rhs_index = result_index dengan size(result_batching_index) = size(lhs_batching_dimensions), size(result_lhs_index) = size(lhs_result_dimensions), dan size(result_rhs_index) = size(rhs_result_dimensions).
  • transposed_lhs = transpose(lhs, lhs_batching_dimensions + lhs_result_dimensions + lhs_contracting_dimensions).
  • transposed_lhs_slice = slice(transposed_lhs, result_batching_index + result_lhs_index + [:, ..., :]).
  • reshaped_lhs_slice = reshape(transposed_lhs_slice, dims(lhs, lhs_contracting_dimensions)).
  • transposed_rhs = transpose(rhs, rhs_batching_dimensions + rhs_result_dimensions + rhs_contracting_dimensions).
  • transposed_rhs_slice = slice(transposed_rhs, result_batching_index + result_rhs_index + [:, ..., :]).
  • reshaped_rhs_slice = reshape(transposed_rhs_slice, dims(rhs, rhs_contracting_dimensions)).
  • dot_product = reduce( inputs=[multiply(reshaped_lhs_slice, reshaped_rhs_slice)], init_values=[constant(0, element_type(result))], dimensions=range(size(lhs_contracting_dimensions)), body=lambda x, y: add(x, y)).

Untuk jenis terkuantisasi, jalankan dequantize_op_quantize( lambda lhs, rhs: dot_general(lhs, rhs, lhs_batching_dimensions, rhs_batching_dimensions, lhs_contracting_dimensions, rhs_contracting_dimensions, precision_config), lhs, rhs, type(result)).

Nilai ini hanya menetapkan semantik untuk kuantisasi per tensor. Kuantisasi per sumbu sedang dalam proses (#1574). Selain itu, pada masa mendatang, kami dapat mempertimbangkan untuk menambahkan dukungan untuk kuantisasi campuran (#1575).

precision_config mengontrol kompromi antara kecepatan dan akurasi untuk komputasi pada backend akselerator. Ini bisa berupa salah satu dari hal berikut (pada saat ini, semantik nilai enum ini tidak ditentukan, tetapi kami merencanakan untuk mengatasinya di #755):

  • DEFAULT: Penghitungan tercepat, tetapi perkiraan yang paling tidak akurat terhadap angka asli.
  • HIGH: Penghitungan lebih lambat, tetapi perkiraan lebih akurat terhadap angka asli.
  • HIGHEST: Penghitungan paling lambat, tetapi merupakan perkiraan yang paling akurat terhadap angka asli.

Input

Label Nama Jenis Batasan
(I1) lhs Quantized Tensor atau per-tensor (C5-C6), (C9-C10), (C12-C16)
(I2) rhs Quantized Tensor atau per-tensor (C7-C10), (C12)
(I3) lhs_batching_dimensions Konstanta TensorFlow 1 dimensi dari jenis si64 (C1), (C3), (C5), (C9), (C12)
(I4) rhs_batching_dimensions Konstanta TensorFlow 1 dimensi dari jenis si64 (C1), (C4), (C7), (C9)
(I5) lhs_contracting_dimensions Konstanta TensorFlow 1 dimensi dari jenis si64 (C2), (C3), (C6), (C10)
(I6) rhs_contracting_dimensions Konstanta TensorFlow 1 dimensi dari jenis si64 (C2), (C4), (C8), (C10)
(I7) precision_config jumlah variadic enum DEFAULT, HIGH, dan HIGHEST (C11)

Output

Nama Jenis Batasan
result Quantized Tensor atau per-tensor (C12), (C14), (C16)

Batasan

  • (C1) size(lhs_batching_dimensions) = size(rhs_batching_dimensions).
  • (C2) size(lhs_contracting_dimensions) = size(rhs_contracting_dimensions).
  • (C3) is_unique(lhs_batching_dimensions + lhs_contracting_dimensions).
  • (C4) is_unique(rhs_batching_dimensions + rhs_contracting_dimensions).
  • (C5) 0 <= lhs_batching_dimensions < rank(lhs).
  • (C6) 0 <= lhs_contracting_dimensions < rank(lhs).
  • (C7) 0 <= rhs_batching_dimensions < rank(rhs).
  • (C8) 0 <= rhs_contracting_dimensions < rank(rhs).
  • (C9) dim(lhs, lhs_batching_dimensions...) = dim(rhs, rhs_batching_dimensions...).
  • (C10) dim(lhs, lhs_contracting_dimensions...) = dim(rhs, rhs_contracting_dimensions...).
  • (C11) size(precision_config) = 2.
  • (C12) shape(result) = dim(lhs, lhs_batching_dimensions) + dim(lhs, lhs_result_dimensions) + dim(rhs, rhs_result_dimensions).
  • Jika operasi menggunakan TensorFlow non-terkuantisasi:
    • (C13) element_type(lhs) = element_type(rhs)
  • Jika operasi menggunakan kuantisasi Tensor:
    • (C14) is_quantized(lhs) and is_quantized(rhs) and is_quantized(result).
    • (C15) storage_type(lhs) = storage_type(rhs).
    • (C16) expressed_type(lhs) = expressed_type(rhs) = expressed_type(result).
    • (C17) zero_points(rhs) = 0.

Contoh

// %lhs: [
//        [[1, 2],
//         [3, 4]],
//        [[5, 6],
//         [7, 8]]
//       ]
// %rhs: [
//        [[1, 0],
//         [0, 1]],
//        [[1, 0],
//         [0, 1]]
//       ]
%result = "stablehlo.dot_general"(%lhs, %rhs) {
  dot_dimension_numbers = #stablehlo.dot<
    lhs_batching_dimensions = [0],
    rhs_batching_dimensions = [0],
    lhs_contracting_dimensions = [2],
    rhs_contracting_dimensions = [1]
  >,
  precision_config = [#stablehlo<precision DEFAULT>, #stablehlo<precision DEFAULT>]
} : (tensor<2x2x2xi64>, tensor<2x2x2xi64>) -> tensor<2x2x2xi64>
// %result: [
//           [[1, 2],
//            [3, 4]],
//           [[5, 6],
//            [7, 8]]
//          ]

Contoh Lainnya

dynamic_slice

Semantik

Mengekstrak irisan dari operand menggunakan indeks awal yang dikomputasi secara dinamis dan menghasilkan5 result. start_indices berisi indeks awal irisan untuk setiap dimensi yang bergantung pada kemungkinan penyesuaian, dan slice_sizes berisi ukuran irisan untuk setiap dimensi. Secara lebih formal, result[result_index] = operand[operand_index] jika:

  • adjusted_start_indices = clamp(0, start_indices, shape(operand) - slice_sizes).
  • operand_index = adjusted_start_indices + result_index.

Input

Label Nama Jenis Batasan
(I1) operand Quantized Tensor atau per-tensor (C1), (C2), (C4)
(I2) start_indices jumlah variadic TensorFlow 0 dimensi dari jenis integer (C2), (C3)
(I3) slice_sizes Konstanta TensorFlow 1 dimensi dari jenis si64 (C2), (C4), (C5)

Output

Nama Jenis Batasan
result Quantized Tensor atau per-tensor (C1), (C5)

Batasan

  • (C1) element_type(operand) = element_type(result).
  • (C2) size(start_indices) = size(slice_sizes) = rank(operand).
  • (C3) same(type(start_indices...)).
  • (C4) 0 <= slice_sizes <= shape(operand).
  • (C5) shape(result) = slice_sizes.

Contoh

// %operand: [
//            [0, 0, 1, 1],
//            [0, 0, 1, 1],
//            [0, 0, 0, 0],
//            [0, 0, 0, 0]
//           ]
// %start_indices0: -1
// %start_indices1: 3
%result = "stablehlo.dynamic_slice"(%operand, %start_indices0, %start_indices1) {
  slice_sizes = dense<[2, 2]> : tensor<2xi64>
} : (tensor<4x4xi32>, tensor<i64>, tensor<i64>) -> tensor<2x2xi32>
// %result: [
//           [1, 1],
//           [1, 1]
//          ]

Contoh Lainnya

dynamic_update_slice

Semantik

Menghasilkan5_tensor result yang sama dengan5_tensor operand kecuali bahwa irisan yang dimulai pada start_indices diperbarui dengan nilai di update. Secara lebih formal, result[result_index] ditentukan sebagai:

  • update[update_index] jika 0 <= update_index < shape(update) dengan:
    • adjusted_start_indices = clamp(0, start_indices, shape(operand) - shape(update)).
    • update_index = result_index - adjusted_start_indices.
  • operand[result_index] jika tidak.

Input

Label Nama Jenis Batasan
(I1) operand Quantized Tensor atau per-tensor (C1-C4), (C6)
(I2) update Quantized Tensor atau per-tensor (C2), (C3), (C6)
(I3) start_indices jumlah variadic TensorFlow 0 dimensi dari jenis integer (C4), (C5)

Output

Nama Jenis Batasan
result Quantized Tensor atau per-tensor (C1)

Batasan

  • (C1) type(operand) = type(result).
  • (C2) element_type(update) = element_type(operand).
  • (C3) rank(update) = rank(operand).
  • (C4) size(start_indices) = rank(operand).
  • (C5) same(type(start_indices...)).
  • (C6) 0 <= shape(update) <= shape(operand).

Contoh

// %operand: [
//            [1, 1, 0, 0],
//            [1, 1, 0, 0],
//            [1, 1, 1, 1],
//            [1, 1, 1, 1]
//           ]
// %update: [
//           [1, 1],
//           [1, 1]
//          ]
// %start_indices0: -1
// %start_indices1: 3
%result = "stablehlo.dynamic_update_slice"(%operand, %update, %start_indices0, %start_indices1)
  : (tensor<4x4xi32>, tensor<2x2xi32>, tensor<i64>, tensor<i64>) -> tensor<4x4xi32>
// %result: [
//           [1, 1, 1, 1],
//           [1, 1, 1, 1],
//           [1, 1, 1, 1],
//           [1, 1, 1, 1]
//          ]

Contoh Lainnya

berpangkat

Semantik

Menjalankan operasi eksponensial element-wise pada TensorFlow operand dan menghasilkan tensor result. Bergantung pada jenis elemen, lakukan hal berikut:

  • Untuk float: exp dari IEEE-754.
  • Untuk bilangan kompleks: eksponensial kompleks.
  • Untuk jenis terkuantisasi: dequantize_op_quantize(exponential, operand, type(result)).

Input

Label Nama Jenis Batasan
(I1) operand Tensor floating-point atau tipe kompleks atau kuantisasi per-tensor (C1)

Output

Nama Jenis Batasan
result Tensor floating-point atau tipe kompleks atau kuantisasi per-tensor (C1)

Batasan

  • (C1) baseline_type(operand) = baseline_type(result).

Contoh

// %operand: [[0.0, 1.0], [2.0, 3.0]]
%result = "stablehlo.exponential"(%operand) : (tensor<2x2xf64>) -> tensor<2x2xf64>
// %result: [[1.0, 2.7182818284590451], [7.3890560989306504, 20.085536923187668]]

Contoh Lainnya

exponential_minus_one

Semantik

Melakukan operasi eksponensial elemen dikurangi satu pada TensorFlow operand dan menghasilkan5 result. Bergantung pada jenis elemen, lakukan hal berikut:

  • Untuk float: expm1 dari IEEE-754.
  • Untuk bilangan kompleks: eksponensial kompleks dikurangi satu.
  • Untuk jenis terkuantisasi: dequantize_op_quantize(exponential_minus_one, operand, type(result)).

Input

Label Nama Jenis Batasan
(I1) operand Tensor floating-point atau tipe kompleks atau kuantisasi per-tensor (C1)

Output

Nama Jenis Batasan
result Tensor floating-point atau tipe kompleks atau kuantisasi per-tensor (C1)

Batasan

  • (C1) baseline_type(operand) = baseline_type(result).

Contoh

// %operand: [0.0, 1.0]
%result = "stablehlo.exponential_minus_one"(%operand) : (tensor<2xf64>) -> tensor<2xf64>
// %result: [0.0, 1.71828187]

Contoh Lainnya

fft [ift]

Semantik

Melakukan transformasi Fourier maju dan terbalik untuk input/output yang nyata dan kompleks.

fft_type adalah salah satu dari berikut ini:

  • FFT: Meneruskan FFT kompleks ke kompleks.
  • IFFT: FFT kompleks ke kompleks terbalik.
  • RFFT: Meneruskan FFT real-ke-kompleks.
  • IRFFT: FFT real-ke-kompleks terbalik (yaitu mengambil kompleks, menampilkan real).

Secara lebih formal, mengingat fungsi fft yang menggunakan TensorFlow 1 dimensi dari jenis kompleks sebagai input, menghasilkan TensorFlow 1 dimensi dari jenis yang sama dengan output dan menghitung transformasi Fourier diskret:

Untuk fft_type = FFT, result ditentukan sebagai hasil akhir dari rangkaian komputasi L dengan L = size(fft_length). Misalnya, untuk L = 3:

  • result1[i0, ..., :] = fft(operand[i0, ..., :]).
  • result2[i0, ..., :, iR-1] = fft(result1[i0, ..., :, iR-1]).
  • result[i0, ..., :, iR-2, iR-1] = fft(result2[i0, ..., :, iR-2, iR-1]).

Selain itu, dengan fungsi ifft yang memiliki tanda tangan jenis yang sama dan menghitung kebalikan dari fft:

Untuk fft_type = IFFT, result ditentukan sebagai kebalikan dari komputasi untuk fft_type = FFT. Misalnya, untuk L = 3:

  • result1[i0, ..., :, iR-2, iR-1] = ifft(operand[i0, ..., :, iR-2, iR-1]).
  • result2[i0, ..., :, iR-1] = ifft(result1[i0, ..., :, iR-1]).
  • result[i0, ..., :] = ifft(result2[i0, ..., :]).

Selain itu, mengingat fungsi rfft yang menggunakan TensorFlow 1 dimensi dari jenis floating point, menghasilkan TensorFlow 1 dimensi dari jenis kompleks dari semantik floating-point yang sama dan berfungsi sebagai berikut:

  • rfft(real_operand) = truncated_result di mana
  • complex_operand... = (real_operand..., 0.0).
  • complex_result = fft(complex_operand).
  • truncated_result = complex_result[:(rank(complex_result) / 2 + 1)].

(Saat transformasi Fourier yang terpisah dihitung untuk operand nyata, elemen N/2 + 1 pertama hasil tersebut secara jelas menentukan hasil lainnya, sehingga hasil rfft terpotong untuk menghindari komputasi elemen yang berlebihan).

Untuk fft_type = RFFT, result ditentukan sebagai hasil akhir dari rangkaian komputasi L dengan L = size(fft_length). Misalnya, untuk L = 3:

  • result1[i0, ..., :] = rfft(operand[i0, ..., :]).
  • result2[i0, ..., :, iR-1] = fft(result1[i0, ..., :, iR-1]).
  • result[i0, ..., :, iR-2, iR-1] = fft(result2[i0, ..., :, iR-2, iR-1]).

Terakhir, dengan mempertimbangkan fungsi irfft yang memiliki tanda tangan jenis yang sama dan menghitung kebalikan dari rfft:

Untuk fft_type = IRFFT, result ditentukan sebagai kebalikan dari komputasi untuk fft_type = RFFT. Misalnya, untuk L = 3:

  • result1[i0, ..., :, iR-2, iR-1] = ifft(operand[i0, ..., :, iR-2, iR-1]).
  • result2[i0, ..., :, iR-1] = ifft(result1[i0, ..., :, iR-1]).
  • result[i0, ..., :] = irfft(result2[i0, ..., :]).

Input

Label Nama Jenis Batasan
(I1) operand Tensor floating point atau tipe kompleks (C1), (C2), (C4), (C5)
(I2) fft_type enum FFT, IFFT, RFFT, dan IRFFT (C2), (C5)
(I3) fft_length Konstanta TensorFlow 1 dimensi dari jenis si64 (C1), (C3), (C4)

Output

Nama Jenis Batasan
result Tensor floating point atau tipe kompleks (C2), (C4), (C5)

Batasan

  • (C1) size(fft_length) <= rank(operand).
  • (C2) Hubungan antara jenis elemen operand dan result bervariasi:
    • Jika fft_type = FFT, element_type(operand), dan element_type(result) memiliki jenis kompleks yang sama.
    • Jika fft_type = IFFT, element_type(operand), dan element_type(result) memiliki jenis kompleks yang sama.
    • Jika fft_type = RFFT, element_type(operand) adalah jenis floating point dan element_type(result) adalah jenis kompleks dari semantik floating point yang sama.
    • Jika fft_type = IRFFT, element_type(operand) adalah jenis kompleks dan element_type(result) adalah jenis floating point dari semantik floating point yang sama.
  • (C3) 1 <= size(fft_length) <= 3.
  • (C4) Jika di antara operand dan result, terdapat real tensor dari jenis floating point, lalu shape(real)[-size(fft_length):] = fft_length.
  • (C5) shape(result) = shape(operand) kecuali untuk:
    • Jika fft_type = RFFT, dim(result, -1) = dim(operand, -1) = 0 ? 0 : dim(operand, -1) / 2 + 1.
    • Jika fft_type = IRFFT, dim(operand, -1) = dim(result, -1) = 0 ? 0 : dim(result, -1) / 2 + 1.

Contoh

// %operand: [(1.0, 0.0), (0.0, 0.0), (0.0, 0.0), (0.0, 0.0)]
%result = "stablehlo.fft"(%operand) {
  fft_type = #stablehlo<fft_type FFT>,
  fft_length = dense<4> : tensor<1xi64>
} : (tensor<4xcomplex<f32>>) -> tensor<4xcomplex<f32>>
// %result: [(1.0, 0.0), (1.0, 0.0), (1.0, 0.0), (1.0, 0.0)]

floor

Semantik

Melakukan floor-wise element-wise dari operandtensor dan menghasilkan result. Mengimplementasikan operasi roundToIntegralTowardNegative dari spesifikasi IEEE-754. Untuk jenis terkuantisasi, jalankan dequantize_op_quantize(floor, operand, type(result)).

Input

Label Nama Jenis Batasan
(I1) operand Tensor floating-point atau per-tensor quantizedTensor (C1)

Output

Nama Jenis Batasan
result Tensor floating-point atau per-tensor quantizedTensor (C1)

Batasan

  • (C1) baseline_type(operand) = baseline_type(result).

Contoh

// %operand: [-0.8166, -0.2530, 0.2530, 0.8166, 2.0]
%result = "stablehlo.floor"(%operand) : (tensor<5xf32>) -> tensor<5xf32>
// %result: [-1.0, -1.0, 0.0, 0.0, 2.0]

Contoh Lainnya

mengumpulkan

Semantik

Mengumpulkan irisan dari TensorFlow operand dari offset yang ditentukan dalam start_indices dan menghasilkan TensorFlow result.

Diagram berikut menunjukkan cara elemen di result dipetakan pada elemen di operand menggunakan contoh konkret. Diagram memilih beberapa contoh indeks result dan menjelaskan secara mendetail indeks operand mana yang sesuai.

Secara lebih formal, result[result_index] = operand[operand_index] jika:

  • batch_dims = [d for d in axes(result) and d not in offset_dims].
  • batch_index = result_index[batch_dims...].
  • start_index didefinisikan sebagai:
    • start_indices[bi0, ..., :, ..., biN] dengan bi adalah elemen individual dalam batch_index dan : disisipkan pada indeks index_vector_dim, jika index_vector_dim < rank(start_indices).
    • [start_indices[batch_index]] jika tidak.
  • Untuk d_operand di axes(operand),
    • full_start_index[d_operand] = clamp(start_index[d_start], 0, dim(operand, d_operand) - slice_sizes[d_operand]) jika d_operand = start_index_map[d_start].
    • full_start_index[d_operand] = 0 jika tidak.
  • offset_index = result_index[offset_dims...].
  • full_offset_index = [oi0, ..., 0, ..., oiN] dengan oi adalah elemen individual di offset_index, dan 0 disisipkan pada indeks dari collapsed_slice_dims.
  • operand_index = full_start_index + full_offset_index.

Jika indices_are_sorted adalah true, implementasi dapat mengasumsikan bahwa start_indices diurutkan sehubungan dengan start_index_map, jika tidak, perilaku tidak ditentukan. Secara lebih formal, untuk semua i1 < i2 dari indices(result), full_start_index(i1) <= full_start_index(i2).

Input

Label Nama Jenis Batasan
(I1) operand Quantized Tensor atau per-tensor (C1), (C7), (C10-C12), (C14)
(I2) start_indices Tensor tipe integer (C2), (C3), (C13)
(I3) offset_dims Konstanta TensorFlow 1 dimensi dari jenis si64 (C1), (C4-C5), (C13)
(I4) collapsed_slice_dims Konstanta TensorFlow 1 dimensi dari jenis si64 (C1), (C6-C8), (C13)
(I5) start_index_map Konstanta TensorFlow 1 dimensi dari jenis si64 (C3), (C9), (C10)
(I6) index_vector_dim konstanta jenis si64 (C2), (C3), (C13)
(I7) slice_sizes Konstanta TensorFlow 1 dimensi dari jenis si64 (C8), (C11-C13)
(I8) indices_are_sorted konstanta jenis i1

Output

Nama Jenis Batasan
result Quantized Tensor atau per-tensor (C5), (C13-C14)

Batasan

  • (C1) rank(operand) = size(offset_dims) + size(collapsed_slice_dims).
  • (C2) 0 <= index_vector_dim <= rank(start_indices).
  • (C3) size(start_index_map) = index_vector_dim < rank(start_indices) ? dim(start_indices, index_vector_dim) : 1.
  • (C4) is_unique(offset_dims) and is_sorted(offset_dims).
  • (C5) 0 <= offset_dims < rank(result).
  • (C6) is_unique(collapsed_slice_dims) and is_sorted(collapsed_slice_dims).
  • (C7) 0 <= collapsed_slice_dims < rank(operand).
  • (C8) slice_sizes[collapsed_slice_dims...] <= 1.
  • (C9) is_unique(start_index_map).
  • (C10) 0 <= start_index_map < rank(operand).
  • (C11) size(slice_sizes) = rank(operand).
  • (C12) 0 <= slice_sizes <= shape(operand).
  • (C13) shape(result) = combine(batch_dim_sizes, offset_dim_sizes) dalam hal:
    • batch_dim_sizes = shape(start_indices) kecuali bahwa ukuran dimensi start_indices yang sesuai dengan index_vector_dim tidak disertakan.
    • offset_dim_sizes = shape(slice_sizes) kecuali bahwa ukuran dimensi di slice_sizes yang sesuai dengan collapsed_slice_dims tidak disertakan.
    • combine menempatkan batch_dim_sizes pada sumbu yang sesuai dengan batch_dims dan offset_dim_sizes pada sumbu yang sesuai dengan offset_dims.
  • (C14) element_type(operand) = element_type(result).

Contoh

// %operand: [
//            [[1, 2], [3, 4], [5, 6], [7, 8]],
//            [[9, 10],[11, 12], [13, 14], [15, 16]],
//            [[17, 18], [19, 20], [21, 22], [23, 24]]
//           ]
// %start_indices: [
//                  [[0, 0], [1, 0], [2, 1]],
//                  [[0, 1], [1, 1], [0, 2]]
//                 ]
%result = "stablehlo.gather"(%operand, %start_indices) {
  dimension_numbers = #stablehlo.gather<
    offset_dims = [2, 3],
    collapsed_slice_dims = [0],
    start_index_map = [1, 0],
    index_vector_dim = 2>,
  slice_sizes = dense<[1, 2, 2]> : tensor<3xi64>,
  indices_are_sorted = false
} : (tensor<3x4x2xi32>, tensor<2x3x2xi64>) -> tensor<2x3x2x2xi32>
// %result: [
//            [
//              [[1, 2], [3, 4]],
//              [[3, 4], [5, 6]],
//              [[13, 14], [15, 16]]
//            ],
//            [
//              [[9, 10], [11, 12]],
//              [[11, 12], [13, 14]],
//              [[17, 18], [19, 20]]
//            ]
//          ]

Contoh Lainnya

get_dimension_size

Semantik

Menghasilkan ukuran dimension yang ditentukan dari operand. Secara lebih formal, result = dim(operand, dimension).

Input

Label Nama Jenis Batasan
(I1) operand Tensor (C1)
(I2) dimension konstanta jenis si64 (C1)

Output

Nama Jenis
result Tensor 0 dimensi dari jenis si32

Batasan

  • (C1) 0 <= dimension < rank(operand).

Contoh

// %operand: [[1, 2, 3], [4, 5, 6]]
%result = "stablehlo.get_dimension_size"(%operand) {
  dimension = 1 : i64
} : (tensor<2x3xi64>) -> tensor<i32>
// %result: 3

Contoh Lainnya

get_tuple_element

Semantik

Mengekstrak elemen pada posisi index dari tuple operand dan menghasilkan result. Secara lebih formal, result = operand[index].

Input

Label Nama Jenis Batasan
(I1) operand tuple (C1), (C2)
(I2) index konstanta jenis si32 (C1), (C2)

Output

Nama Jenis Batasan
result jenis apa pun yang didukung (C2)

Batasan

  • (C1) 0 <= index < size(operand).
  • (C2) type(result) = tuple_element_types(operand)[index].

Contoh

// %operand: ([1.0, 2.0], (3))
%result = "stablehlo.get_tuple_element"(%operand) {
  index = 0 : i32
} : (tuple<tensor<2xf32>, tuple<tensor<i32>>>) -> tensor<2xf32>
// %result: [1.0, 2.0]

Contoh Lainnya

if

Semantik

Menghasilkan output dari mengeksekusi tepat satu fungsi dari true_branch atau false_branch, bergantung pada nilai pred. Secara lebih formal, result = pred ? true_branch() : false_branch().

Input

Label Nama Jenis Batasan
(I1) pred Tensor 0 dimensi dari jenis i1
(I2) true_branch fungsi (C1-C3)
(I3) false_branch fungsi (C1), (C2)

Output

Nama Jenis Batasan
results bilangan variadic, quantized TensorFlow, atau token (C3)

Batasan

  • (C1) input_types(true_branch) = input_types(false_branch) = [].
  • (C2) output_types(true_branch) = output_types(false_branch).
  • (C3) type(results...) = output_types(true_branch).

Contoh

// %result_true_branch: 10
// %result_false_branch: 11
// %pred: true
%result = "stablehlo.if"(%pred) ({
  "stablehlo.return"(%result_true_branch) : (tensor<i32>) -> ()
}, {
  "stablehlo.return"(%result_false_branch) : (tensor<i32>) -> ()
}) : (tensor<i1>) -> tensor<i32>
// %result: 10

Contoh Lainnya

gambar

Semantik

Mengekstrak bagian imajiner, berdasarkan elemen, dari operand dan menghasilkan tensor result. Secara lebih formal, untuk setiap elemen x: imag(x) = is_complex(x) ? imaginary_part(x) : constant(0, element_type(result)).

Input

Label Nama Jenis Batasan
(I1) operand Tensor floating point atau tipe kompleks (C1), (C2)

Output

Nama Jenis Batasan
result Tensor tipe floating-point (C1), (C2)

Batasan

  • (C1) shape(result) = shape(operand).
  • (C2) element_type(result) ditentukan sebagai:
    • complex_element_type(element_type(operand)) jika is_complex(operand).
    • element_type(operand) jika tidak.

Contoh

// %operand: [(1.0, 2.0), (3.0, 4.0)]
%result = "stablehlo.imag"(%operand) : (tensor<2xcomplex<f32>>) -> tensor<2xf32>
// %result: [2.0, 4.0]

Contoh Lainnya

dalam feed

Semantik

Membaca data dari infeed dan menghasilkan results.

Semantik infeed_config ditentukan oleh implementasi.

results terdiri dari nilai payload yang akan muncul terlebih dahulu dan token yang akan muncul terakhir. Di masa mendatang, kami berencana membagi payload dan token menjadi dua output terpisah untuk meningkatkan kejelasan (#670).

Input

Label Nama Jenis
(I1) token token
(I2) infeed_config konstanta jenis string

Output

Nama Jenis Batasan
results bilangan variadic, quantized TensorFlow, atau token (C1-C3)

Batasan

  • (C1) 0 < size(results).
  • (C2) is_empty(result[:-1]) atau is_tensor(type(results[:-1])).
  • (C3) is_token(type(results[-1])).

Contoh

// %token: !stablehlo.token
// infeed_queue[0]: [[1, 2], [3, 4]]
// infeed_queue[1]: [[5, 6], [7, 8]]
%results0:2 = "stablehlo.infeed"(%token) {
  infeed_config = ""
} : (!stablehlo.token) -> (tensor<2x2xi64>, !stablehlo.token)
// results0#0: [[1, 2], [3, 4]]
%results1:2 = "stablehlo.infeed"(%token) {
  infeed_config = ""
} : (!stablehlo.token) -> (tensor<2x2xi64>, !stablehlo.token)
// results1#0: [[5, 6], [7, 8]]

Contoh Lainnya

Iota

Semantik

Mengisi reCAPTCHA output dengan nilai dalam urutan yang meningkat mulai dari nol di sepanjang dimensi iota_dimension. Secara lebih formal,

output[result_index] = constant(is_quantized(output) ? quantize(result_index[iota_dimension], element_type(output)) : result_index[iota_dimension], element_type(output)).

Input

Label Nama Jenis Batasan
(I1) iota_dimension si64 (C1)

Output

Nama Jenis Batasan
output Tensor integer, floating point atau tipe kompleks atau quantized per-tensor (C1)

Batasan

  • (C1) 0 <= iota_dimension < rank(output).

Contoh

%output = "stablehlo.iota"() {
  iota_dimension = 0 : i64
} : () -> tensor<4x5xi32>
// %output: [
//           [0, 0, 0, 0, 0],
//           [1, 1, 1, 1, 1],
//           [2, 2, 2, 2, 2],
//           [3, 3, 3, 3, 3]
//          ]

%output = "stablehlo.iota"() {
  iota_dimension = 1 : i64
} : () -> tensor<4x5xi32>
// %output: [
//           [0, 1, 2, 3, 4],
//           [0, 1, 2, 3, 4],
//           [0, 1, 2, 3, 4],
//           [0, 1, 2, 3, 4]
//          ]

Contoh Lainnya

is_finite

Semantik

Melakukan pemeriksaan element-wise apakah nilai dalam x terbatas (yaitu bukan +Inf, -Inf, atau NaN) dan menghasilkan y_tensor. Menerapkan operasi isFinite dari spesifikasi IEEE-754. Untuk jenis terkuantisasi, hasilnya selalu true.

Input

Label Nama Jenis Batasan
(I1) x Tensor floating-point atau per-tensor quantizedTensor (C1)

Output

Nama Jenis Batasan
y Tensor tipe boolean (C1)

Batasan

  • (C1) shape(x) = shape(y).

Contoh

// Logical values: -Inf, +Inf, NaN, ...
// %x: [0xFFF0000000000000, 0x7FF0000000000000, 0x7FF8000000000000, -10.0, -0.0, 0.0, 10.0]
%y = "stablehlo.is_finite"(%x) : (tensor<7xf64) -> tensor<7xi1>
// %y: [false, false, false, true, true, true, true]

Contoh Lainnya

log

Semantik

Menjalankan operasi logaritma element-wise pada TensorFlow operand dan menghasilkan tensor result. Bergantung pada jenis elemen, lakukan hal berikut:

  • Untuk float: log dari IEEE-754.
  • Untuk bilangan kompleks: logaritma kompleks.
  • Untuk jenis terkuantisasi: dequantize_op_quantize(log, operand, type(result)).

Input

Label Nama Jenis Batasan
(I1) operand Tensor floating-point atau tipe kompleks atau kuantisasi per-tensor (C1)

Output

Nama Jenis Batasan
result Tensor floating-point atau tipe kompleks atau kuantisasi per-tensor (C1)

Batasan

  • (C1) baseline_type(operand) = baseline_type(result).

Contoh

// %operand: [[1.0, 2.0], [3.0, 4.0]]
%result = "stablehlo.log"(%operand) : (tensor<2x2xf64>) -> tensor<2x2xf64>
// %result: [[0.0, 0.69314718055994529], [1.0986122886681098, 1.3862943611198906]]

Contoh Lainnya

log_plus_one

Semantik

Melakukan logaritma element-wise plus satu operasi pada TensorFlow operand dan menghasilkan5 result. Bergantung pada jenis elemen, lakukan hal berikut:

  • Untuk float: logp1 dari IEEE-754.
  • Untuk bilangan kompleks: logaritma kompleks plus satu.
  • Untuk jenis terkuantisasi: dequantize_op_quantize(log_plus_one, operand, type(result)).

Input

Label Nama Jenis Batasan
(I1) operand Tensor floating-point atau tipe kompleks atau kuantisasi per-tensor (C1)

Output

Nama Jenis Batasan
result Tensor floating-point atau tipe kompleks atau kuantisasi per-tensor (C1)

Batasan

  • (C1) baseline_type(operand) = baseline_type(result).

Contoh

// %operand: [0.0, -0.999, 7.0, 6.38905621, 15.0]
%result = "stablehlo.log_plus_one"(%operand) : (tensor<5xf64>) -> tensor<5xf64>
// %result: [0.0, -6.90776825, 2.07944155, 2.0, 2.77258873]

Contoh Lainnya

logistik

Semantik

Menjalankan operasi logistik element-wise pada TensorFlow operand dan menghasilkan tensor result. Bergantung pada jenis elemen, lakukan hal berikut:

  • Untuk float: division(1, addition(1, exp(-x))) dari IEEE-754.
  • Untuk bilangan kompleks: logistik kompleks.
  • Untuk jenis terkuantisasi: dequantize_op_quantize(logistic, operand, type(result)).

Input

Label Nama Jenis Batasan
(I1) operand Tensor floating-point atau tipe kompleks atau kuantisasi per-tensor (C1)

Output

Nama Jenis Batasan
result Tensor floating-point atau tipe kompleks atau kuantisasi per-tensor (C1)

Batasan

  • (C1) baseline_type(operand) = baseline_type(result).

Contoh

// %operand: [[0.0, 1.0], [2.0, 3.0]]
%result = "stablehlo.logistic"(%operand) : (tensor<2x2xf64>) -> tensor<2x2xf64>
// %result: [[0.5, 0.73105858], [0.88079708, 0.95257413]]

Contoh Lainnya

map

Semantik

Menerapkan fungsi peta computation ke inputs di sepanjang dimensions dan menghasilkan TensorFlow result.

Secara lebih formal, result[result_index] = computation(inputs...[result_index]). Perlu diketahui bahwa dimensions saat ini tidak digunakan dan mungkin akan dihapus pada masa mendatang (#487).

Input

Label Nama Jenis Batasan
(I1) inputs jumlah variadic Tensor atau quantized5 per-tensor (C1-C4)
(I2) dimensions Konstanta TensorFlow 1 dimensi dari jenis si64 (C3)
(I3) computation fungsi (C4)

Output

Nama Jenis Batasan
result Quantized Tensor atau per-tensor (C1), (C4)

Batasan

  • (C1) shape(inputs...) = shape(result).
  • (C2) 0 < size(inputs) = N.
  • (C3) dimensions = range(rank(inputs[0])).
  • (C4) computation memiliki jenis (tensor<E0>, ..., tensor<EN-1>) -> tensor<E'> dengan Ei = element_type(inputs[i]) dan E' = element_type(result).

Contoh

// %input0: [[0, 1], [2, 3]]
// %input1: [[4, 5], [6, 7]]
%result = "stablehlo.map"(%input0, %input1) ({
  ^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
    %0 = stablehlo.multiply %arg0, %arg1 : tensor<i64>
    stablehlo.return %0 : tensor<i64>
}) {
  dimensions = dense<[0, 1]> : tensor<2xi64>
} : (tensor<2x2xi64>, tensor<2x2xi64>) -> tensor<2x2xi64>
// %result: [[0, 5], [12, 21]]

Contoh Lainnya

maksimum

Semantik

Menjalankan operasi maks element-wise pada TensorFlow lhs dan rhs serta menghasilkan tensor result. Bergantung pada jenis elemen, lakukan hal berikut:

  • Untuk boolean: logika OR.
  • Untuk bilangan bulat: maksimum bilangan bulat.
  • Untuk float: maximum dari IEEE-754.
  • Untuk bilangan kompleks: nilai maksimum leksikografis untuk pasangan (real, imaginary). Memaksakan pengurutan pada bilangan kompleks melibatkan semantik yang mengejutkan, jadi di masa mendatang kami berencana menghapus dukungan untuk bilangan kompleks untuk operasi ini (#560).
  • Untuk jenis terkuantisasi:
    • dequantize_op_quantize(maximum, lhs, rhs, type(result)).

Input

Label Nama Jenis Batasan
(I1) lhs Quantized Tensor atau per-tensor (C1)
(I2) rhs Quantized Tensor atau per-tensor (C1)

Output

Nama Jenis Batasan
result Quantized Tensor atau per-tensor (C1)

Batasan

  • (C1) baseline_type(lhs) = baseline_type(rhs) = baseline_type(result).

Contoh

// %lhs: [[1, 2], [7, 8]]
// %rhs: [[5, 6], [3, 4]]
%result = "stablehlo.maximum"(%lhs, %rhs) : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
// %result: [[5, 6], [7, 8]]

Contoh Lainnya

minimum

Semantik

Menjalankan operasi min element-wise pada TensorFlow lhs dan rhs serta menghasilkan tensor result. Bergantung pada jenis elemen, lakukan hal berikut:

  • Untuk boolean: logika AND.
  • Untuk bilangan bulat: minimum bilangan bulat.
  • Untuk float: minimum dari IEEE-754.
  • Untuk bilangan kompleks: minimum leksikografis untuk pasangan (real, imaginary). Memaksakan pengurutan pada bilangan kompleks melibatkan semantik yang mengejutkan, jadi di masa mendatang kami berencana menghapus dukungan untuk bilangan kompleks untuk operasi ini (#560).
  • Untuk jenis terkuantisasi:
    • dequantize_op_quantize(minimum, lhs, rhs, type(result)).

Input

Label Nama Jenis Batasan
(I1) lhs Quantized Tensor atau per-tensor (C1)
(I2) rhs Quantized Tensor atau per-tensor (C1)

Output

Nama Jenis Batasan
result Quantized Tensor atau per-tensor (C1)

Batasan

  • (C1) baseline_type(lhs) = baseline_type(rhs) = baseline_type(result).

Contoh

// %lhs: [[1, 2], [7, 8]]
// %rhs: [[5, 6], [3, 4]]
%result = "stablehlo.minimum"(%lhs, %rhs) : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
// %result: [[1, 2], [3, 4]]

Contoh Lainnya

kalikan

Semantik

Melakukan perkalian element-wise dari dua TensorFlow lhs dan rhs serta menghasilkan result Tensor. Bergantung pada jenis elemen, lakukan hal berikut:

  • Untuk boolean: logika AND.
  • Untuk bilangan bulat: perkalian bilangan bulat.
  • Untuk float: multiplication dari IEEE-754.
  • Untuk bilangan kompleks: perkalian kompleks.
  • Untuk jenis terkuantisasi:
    • dequantize_op_quantize(multiply, lhs, rhs, type(result)).

Input

Label Nama Jenis Batasan
(I1) lhs Quantized Tensor atau per-tensor (C1)
(I2) rhs Quantized Tensor atau per-tensor (C1)

Output

Nama Jenis Batasan
result Quantized Tensor atau per-tensor (C1)

Batasan

  • (C1) baseline_type(operand) = baseline_type(result).

Contoh

// %lhs: [[1, 2], [3, 4]]
// %rhs: [[5, 6], [7, 8]]
%result = "stablehlo.multiply"(%lhs, %rhs) : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
// %result: [[5, 12], [21, 32]]

Contoh Lainnya

negasi

Semantik

Melakukan negasi berbasis elemen dari reCAPTCHA operand dan menghasilkan tensor result. Bergantung pada jenis elemen, lakukan hal berikut:

  • Untuk bilangan bulat bertanda: negasi bilangan bulat.
  • Untuk bilangan bulat yang tidak ditandatangani: bitcast ke bilangan bulat bertanda tangan, negasi bilangan bulat, bitcast kembali ke bilangan bulat yang tidak ditandatangani.
  • Untuk float: negate dari IEEE-754.
  • Untuk bilangan kompleks: negasi kompleks.
  • Untuk jenis terkuantisasi: dequantize_op_quantize(negate, operand, type(result)).

Input

Label Nama Jenis Batasan
(I1) operand Tensor integer, floating point, atau tipe kompleks atau kuantisasi per-tensor (C1)

Output

Nama Jenis Batasan
result Tensor integer, floating point, atau tipe kompleks atau kuantisasi per-tensor (C1)

Batasan

  • (C1) baseline_type(operand) = baseline_type(result).

Contoh

// Negation operation with integer Tensors
// %operand: [0, -2]
%result = "stablehlo.negate"(%operand) : (tensor<2xi32>) -> tensor<2xi32>
// %result: [0, 2]

// Negation operation with with complex tensors
// %operand: (2.5, 0.0)
%result = "stablehlo.negate"(%operand) : (tensor<1xcomplex<f32>>) -> tensor<1xcomplex<f32>>
// %result: [-2.5, -0.0]

Contoh Lainnya

bukan

Semantik

Melakukan NOT-wise NOT dari operand_tensor dan menghasilkan result_tensor. Bergantung pada jenis elemen, lakukan hal berikut:

  • Untuk boolean: logika NOT.
  • Untuk bilangan bulat: bitwise NOT.

Argumen

Nama Jenis Batasan
operand Tensor tipe boolean atau integer (C1)

Output

Nama Jenis Batasan
result Tensor tipe boolean atau integer (C1)

Batasan

  • (C1) type(operand) = type(result).

Contoh

// Bitwise operation with with integer tensors
// %operand: [[1, 2], [3, 4]]
%result = "stablehlo.not"(%operand) : (tensor<2x2xi32>) -> tensor<2x2xi32>
// %result: [[-2, -3], [-4, -5]]

// Bitwise operation with with boolean tensors
// %operand: [true, false]
%result = "stablehlo.not"(%operand) : (tensor<2xi1>) -> tensor<2xi1>
// %result: [false, true]

optimization_barrier

Semantik

Memastikan bahwa operasi yang menghasilkan operand dijalankan sebelum operasi apa pun yang bergantung pada result dan mencegah transformasi compiler memindahkan operasi melalui batasan. Selain itu, operasi tersebut adalah identitas, yaitu result = operand.

Argumen

Nama Jenis Batasan
operand jumlah variadic Tensor, token terkuantisasi per-tensor, atau token (C1)

Output

Nama Jenis Batasan
result jumlah variadic Tensor, token terkuantisasi per-tensor, atau token (C1)

Batasan

  • (C1) type(operand...) = type(result...).

Contoh

// %operand0: 0.0
// %operand1: 1.0
%result0, %result1 = "stablehlo.optimization_barrier"(%operand0, %operand1) : (tensor<f32>, tensor<f32>) -> (tensor<f32>, tensor<f32>)
// %result0: 0.0
// %result1: 1.0

Contoh Lainnya

atau

Semantik

Melakukan OR berdasarkan elemen dari dua TensorFlow lhs dan rhs serta menghasilkan tensor result. Bergantung pada jenis elemen, lakukan hal berikut:

  • Untuk boolean: logika OR.
  • Untuk bilangan bulat: bitwise OR.

Input

Label Nama Jenis Batasan
(I1) lhs Tensor integer atau jenis boolean (C1)
(I2) rhs Tensor integer atau jenis boolean (C1)

Output

Nama Jenis Batasan
result Tensor integer atau jenis boolean (C1)

Batasan

  • (C1) type(lhs) = type(rhs) = type(result).

Contoh

// Bitwise operation with with integer tensors
// %lhs: [[1, 2], [3, 4]]
// %rhs: [[5, 6], [7, 8]]
%result = "stablehlo.or"(%lhs, %rhs) : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
// %result: [[5, 6], [7, 12]]

// Logical operation with with boolean tensors
// %lhs: [[false, false], [true, true]]
// %rhs: [[false, true], [false, true]]
%result = "stablehlo.or"(%lhs, %rhs) : (tensor<2x2xi1>, tensor<2x2xi1>) -> tensor<2x2xi1>
// %result: [[false, true], [true, true]]

outfeed

Semantik

Menulis inputs ke outfeed dan menghasilkan token result.

Semantik outfeed_config ditentukan oleh implementasi.

Input

Label Nama Jenis
(I1) inputs bilangan variadic dari TensorFlow atau Quantized Tensor
(I2) token token
(I3) outfeed_config konstanta jenis string

Output

Nama Jenis
result token

Contoh

%result = "stablehlo.outfeed"(%inputs0, %token) {
  outfeed_config = ""
} : (tensor<2x2x2xi64>, !stablehlo.token) -> !stablehlo.token

Contoh Lainnya

bantalan

Semantik

Memperluas operand dengan padding di sekitar TensorFlow serta di antara elemen tensor dengan padding_value yang diberikan.

edge_padding_low dan edge_padding_high menentukan jumlah padding yang ditambahkan di bagian bawah (di samping indeks 0) dan kelas atas (di samping indeks tertinggi) dari setiap dimensi. Jumlah padding bisa negatif, dengan nilai absolut padding negatif menunjukkan jumlah elemen yang akan dihapus dari dimensi yang ditentukan.

interior_padding menentukan jumlah padding yang ditambahkan di antara dua elemen di setiap dimensi, yang tidak boleh negatif. Padding interior terjadi sebelum padding tepi sehingga padding tepi negatif akan menghapus elemen dari operand dengan padding interior.

Secara lebih formal, result[result_index] ditentukan sebagai:

  • operand[operand_index] jika result_index = edge_padding_low + operand_index * (interior_padding + 1).
  • padding_value jika tidak.

Input

Label Nama Jenis Batasan
(I1) operand Quantized Tensor atau per-tensor (C1), (C2), (C4)
(I2) padding_value Tensor 0-dimensi atau Tensortized per-tensor (C1)
(I3) edge_padding_low Konstanta TensorFlow 1 dimensi dari jenis si64 (C1), (C4)
(I4) edge_padding_high Konstanta TensorFlow 1 dimensi dari jenis si64 (C1), (C4)
(I5) interior_padding Konstanta TensorFlow 1 dimensi dari jenis si64 (C2-C4)

Output

Nama Jenis Batasan
result Quantized Tensor atau per-tensor (C3-C6)

Batasan

  • (C1) element_type(operand) = element_type(padding_value) = element_type(result).
  • (C2) size(edge_padding_low) = size(edge_padding_high) = size(interior_padding) = rank(operand).
  • (C3) 0 <= interior_padding.
  • (C4) shape(result) = shape(operand) + edge_padding_low + max(shape(operand) - 1, 0) * interior_padding + edge_padding_high.

Contoh

// %operand: [
//            [1, 2, 3],
//            [4, 5, 6]
//           ]
// %padding_value: 0
%result = "stablehlo.pad"(%operand, %padding_value) {
  edge_padding_low = dense<[0, 1]> : tensor<2xi64>,
  edge_padding_high = dense<[2, 1]> : tensor<2xi64>,
  interior_padding = dense<[1, 2]> : tensor<2xi64>
} : (tensor<2x3xi32>, tensor<i32>) -> tensor<5x9xi32>
// %result: [
//           [0, 1, 0, 0, 2, 0, 0, 3, 0],
//           [0, 0, 0, 0, 0, 0, 0, 0, 0],
//           [0, 4, 0, 0, 5, 0, 0, 6, 0],
//           [0, 0, 0, 0, 0, 0, 0, 0, 0],
//           [0, 0, 0, 0, 0, 0, 0, 0, 0]
//          ]

Contoh Lainnya

partition_id

Semantik

Menghasilkan partition_id dari proses saat ini.

Output

Nama Jenis
result Tensor 0 dimensi dari jenis ui32

Contoh

%result = "stablehlo.partition_id"() : () -> tensor<ui32>

Contoh Lainnya

popcnt

Semantik

Melakukan penghitungan element-wise dari jumlah bit yang ditetapkan dalam TensorFlow operand dan menghasilkan5 result.

Input

Label Nama Jenis Batasan
(I1) operand Tensor tipe integer (C1)

Output

Nama Jenis Batasan
result Tensor tipe integer (C1)

Batasan

  • (C1) type(operand) = type(result).

Contoh

// %operand: [0, 1, 2, 127]
%result = "stablehlo.popcnt"(%operand) : (tensor<4xi64>) -> tensor<4xi64>
// %result: [0, 1, 1, 7]

Contoh Lainnya

daya

Semantik

Melakukan eksponensiasi element-wise untuk TensorFlow lhs dengan TensorFlow rhs dan menghasilkan TensorFlow result. Bergantung pada jenis elemen, lakukan hal berikut:

  • Untuk bilangan bulat: eksponensiasi bilangan bulat.
  • Untuk float: pow dari IEEE-754.
  • Untuk bilangan kompleks: pangkat kompleks.
  • Untuk jenis terkuantisasi: dequantize_op_quantize(power, lhs, rhs, type(result)).

Input

Label Nama Jenis Batasan
(I1) lhs Tensor integer, floating point, atau tipe kompleks atau kuantisasi per-tensor (C1)
(I2) rhs Tensor integer, floating point, atau tipe kompleks atau kuantisasi per-tensor (C1)

Output

Nama Jenis Batasan
result Tensor integer, floating point, atau tipe kompleks atau kuantisasi per-tensor (C1)

Batasan

  • (C1) baseline_type(operand) = baseline_type(result).

Contoh

// %lhs: [-2.0, -0.0, -36.0, 5.0, 3.0, 10000.0]
// %rhs: [2.0, 2.0, 1.1, 2.0, -1.0, 10.0]
%result = "stablehlo.power"(%lhs, %rhs) : (tensor<6xf64>, tensor<6xf64>) -> tensor<6xf64>
// %result: [4.0, 0.0, -nan, 25.0, 0.333333343, inf]

Contoh Lainnya

real

Semantik

Mengekstrak bagian nyata, sesuai elemen, dari operand dan menghasilkan tensor result. Secara lebih formal, untuk setiap elemen x: real(x) = is_complex(x) ? real_part(x) : x.

Input

Label Nama Jenis Batasan
(I1) operand Tensor floating point atau tipe kompleks (C1), (C2)

Output

Nama Jenis Batasan
result Tensor tipe floating-point (C1), (C2)

Batasan

  • (C1) shape(result) = shape(operand).
  • (C2) element_type(result) ditentukan sebagai:
    • complex_element_type(element_type(operand)) jika is_complex(operand).
    • element_type(operand) jika tidak.

Contoh

// %operand: [(1.0, 2.0), (3.0, 4.0)]
%result = "stablehlo.real"(%operand) : (tensor<2xcomplex<f32>>) -> tensor<2xf32>
// %result: [1.0, 3.0]

Contoh Lainnya

diambil

Semantik

Menerima data dari saluran dengan channel_id dan menghasilkan results.

Jika is_host_transfer adalah true, operasi akan mentransfer data dari host. Jika tidak, server akan mentransfer data dari perangkat lain. Hal ini berarti ditentukan oleh implementasi. Flag ini menduplikasi informasi yang diberikan di channel_type, sehingga ke depannya kami berencana untuk hanya menyimpan salah satunya (#666).

results terdiri dari nilai payload yang akan muncul terlebih dahulu dan token yang akan muncul terakhir. Di masa mendatang, kami berencana membagi payload dan token menjadi dua output terpisah untuk meningkatkan kejelasan (#670).

Input

Label Nama Jenis Batasan
(I1) token token (C4)
(I2) channel_id konstanta jenis si64
(I3) channel_type enum DEVICE_TO_DEVICE dan HOST_TO_DEVICE (C1)
(I4) is_host_transfer konstanta jenis i1 (C1)

Output

Nama Jenis Batasan
results bilangan variadic, quantized TensorFlow, atau token (C2-C4)

Batasan

  • (C1) channel_type didefinisikan sebagai:
    • HOST_TO_DEVICE jika is_host_transfer = true,
    • DEVICE_TO_DEVICE jika tidak.
  • (C2) 0 < size(results).
  • (C3) is_empty(result[:-1]) atau is_tensor(type(results[:-1])).
  • (C4) is_token(type(results[-1])).

Contoh

%results0, %results1 = "stablehlo.recv"(%token) {
  channel_handle = #stablehlo.channel_handle<handle = 1, type = 3>,
  is_host_transfer = true
} : (!stablehlo.token) -> (tensor<2x2xi64>, !stablehlo.token)

Contoh Lainnya

reduce

Semantik

Menerapkan fungsi reduksi body ke inputs dan init_values di sepanjang dimensions dan menghasilkan5_tensor results.

Urutan pengurangan ditentukan oleh penerapan, yang berarti bahwa body dan init_values harus membentuk monoid untuk menjamin bahwa operasi memberikan hasil yang sama untuk semua input pada semua implementasi. Namun, kondisi ini tidak berlaku untuk banyak pengurangan yang populer. Misalnya, penambahan floating point untuk body dan nol untuk init_values sebenarnya tidak membentuk monoid karena penambahan floating point tidak asosiatif.

Secara lebih formal, results...[j0, ..., jR-1] = reduce(input_slices_converted) jika:

  • input_slices = inputs...[j0, ..., :, ..., jR-1], dengan : disisipkan di dimensions.
  • input_slices_converted = to_destination_type(input_slices..., type(func_inputs(body)[:len(func_inputs(body))//2])...).
  • init_values_converted = to_destination_type(init_values..., type(func_inputs(body)[len(func_inputs(body))//2:])...).
  • reduce(input_slices_converted) = exec(schedule) untuk beberapa hierarki biner schedule dengan:
    • exec(node) = body(exec(node.left), exec(node.right)).
    • exec(leaf) = leaf.value.
  • schedule adalah hierarki biner lengkap yang ditentukan implementasi yang memiliki traversal sesuai urutan yang terdiri dari:
    • input_slices_converted...[index], untuk semua index di index_space(input_slices_converted) dalam urutan leksikografis menaik dari index.
    • Diselingi dengan jumlah init_values_converted yang ditentukan implementasi pada posisi yang ditentukan oleh implementasi.

Input

Label Nama Jenis Batasan
(I1) inputs jumlah variadic Tensor atau quantized5 per-tensor (C1-C4), (C6), (C7)
(I2) init_values bilangan variadic dari 0-dimensions atau quantized true-tensor per-tensor (C2), (C3)
(I3) dimensions Konstanta TensorFlow 1 dimensi dari jenis si64 (C4), (C5), (C7)
(I4) body fungsi (C6)

Output

Nama Jenis Batasan
results jumlah variadic Tensor atau quantized5 per-tensor (C3), (C7), (C8)

Batasan

  • (C1) same(shape(inputs...)).
  • (C2) element_type(inputs...) = element_type(init_values...).
  • (C3) 0 < size(inputs) = size(init_values) = size(results) = N.
  • (C4) 0 <= dimensions < rank(inputs[0]).
  • (C5) is_unique(dimensions).
  • (C6) body memiliki jenis (tensor<E0>, ..., tensor<EN-1>, tensor<E0>, ..., tensor<EN-1>) -> (tensor<E0>, ..., tensor<EN-1>) dengan is_promotable(element_type(inputs[i]), Ei).
  • (C7) shape(results...) = shape(inputs...) kecuali bahwa ukuran dimensi inputs... yang sesuai dengan dimensions tidak disertakan.
  • (C8) element_type(results[i]) = Ei untuk semua i di [0,N).

Contoh

// %input = [[0, 1, 2, 3, 4, 5]]
// %init_value = 0
%result = "stablehlo.reduce"(%input, %init_value) ({
  ^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
    %0 = "stablehlo.add"(%arg0, %arg1) : (tensor<i64>, tensor<i64>) -> tensor<i64>
    "stablehlo.return"(%0) : (tensor<i64>) -> ()
}) {
  dimensions = dense<1> : tensor<1xi64>
} : (tensor<1x6xi64>, tensor<i64>) -> tensor<1xi64>
// %result = [15]

Contoh Lainnya

reduce_precision

Semantik

Melakukan konversi operand berbasis elemen ke jenis floating point lainnya yang menggunakan exponent_bits dan mantissa_bits serta kembali ke jenis floating point asli dan menghasilkan TensorFlow output.

Secara lebih formal:

  • Bit mantissa dari nilai asli diperbarui untuk membulatkan nilai asli ke nilai terdekat yang dapat diwakili dengan mantissa_bits menggunakan semantik roundToIntegralTiesToEven.
  • Kemudian, jika mantissa_bits lebih kecil dari jumlah bit mantissa nilai asli, bit mantissa dipotong menjadi mantissa_bits.
  • Kemudian, jika bit eksponen dari hasil perantara tidak sesuai dengan rentang yang disediakan oleh exponent_bits, hasil perantara akan ditambahkan ke tak terbatas menggunakan tanda asli atau underflow ke nol menggunakan tanda asli.
  • Untuk jenis terkuantisasi, jalankan dequantize_op_quantize( lambda operand: reduce_precision(operand, exponent_bits, mantissa_bits), operand, type(result)).

Input

Label Nama Jenis Batasan
(I1) operand Tensor floating-point atau per-tensor quantizedTensor (C1)
(I2) exponent_bits konstanta jenis si32 (C2)
(I3) mantissa_bits konstanta jenis si32 (C3)

Output

Nama Jenis Batasan
output Tensor floating-point atau per-tensor quantizedTensor (C1)

Batasan

  • (C1) baseline_type(operand) = baseline_type(output).
  • (C2) 1 <= exponent_bits.
  • (C3) 0 <= mantissa_bits.

Contoh

// Logical values: +Inf, NaN, +Denormal, 0.0, 65519.0, 65520.0
// %operand: [0x7FF0000000000000, 0x7FFFFFFFFFFFFFFF, 0x0000000000000001, 0.0, 65519.0, 65520.0]
%output = "stablehlo.reduce_precision"(%operand) {
  exponent_bits = 5 : i32,
  mantissa_bits = 10 : i32
} : (tensor<6xf64>) -> tensor<6xf64>
// Logical values: +Inf, NaN, 0.0, 0.0, 65504.0, +Inf
// %output: [0x7FF0000000000000, 0x7FFFFFFFFFFFFFFF, 0.0, 0.0, 65504.0, 0x7FF0000000000000]

Contoh Lainnya

reduce_scatter

Semantik

Dalam setiap grup proses di petak proses StableHLO, lakukan pengurangan, menggunakan computations, pada nilai Tensor operand dari setiap proses, membagi hasil pengurangan di sepanjang scatter_dimension menjadi beberapa bagian, dan menyebarkan bagian pemisahan di antara proses untuk menghasilkan result.

Operasi ini membagi petak proses StableHLO menjadi process_groups yang ditentukan sebagai berikut:

  • cross_replica(replica_groups) jika channel_id <= 0 and use_global_device_ids = false.
  • cross_replica_and_partition(replica_groups) jika channel_id > 0 and use_global_device_ids = false.
  • flattened_ids(replica_groups) jika channel_id > 0 and use_global_device_ids = true.

Setelah itu, dalam setiap process_group:

  • reduced_value = all_reduce(operand, replica_groups, channel_id, use_global_device_ids, computation).
  • parts@sender = split(reduced_value@sender, dim(process_groups, 1), scatter_dimension).
  • result@receiver = parts@sender[receiver_index] untuk semua sender di process_group, dengan receiver_index = process_group.index(receiver).

Input

Label Nama Jenis Batasan
(I1) operand Quantized Tensor atau per-tensor (C1), (C2), (C7), (C8)
(I2) scatter_dimension konstanta jenis si64 (C1), (C2), (C8)
(I3) replica_groups Konstanta TensorFlow 2 dimensi dari jenis si64 (C3-C5)
(I4) channel_id konstanta jenis si64 (C6)
(I5) use_global_device_ids konstanta jenis i1 (C6)
(I6) computation fungsi (C7)

Output

Nama Jenis Batasan
result Quantized Tensor atau per-tensor (C8-C9)

Batasan

  • (C1) dim(operand, scatter_dimension) % dim(process_groups, 1) = 0.
  • (C2) 0 <= scatter_dimension < rank(operand).
  • (C3) is_unique(replica_groups).
  • (C4) size(replica_groups) didefinisikan sebagai:
    • num_replicas jika cross_replica digunakan.
    • num_replicas jika cross_replica_and_partition digunakan.
    • num_processes jika flattened_ids digunakan.
  • (C5) 0 <= replica_groups < size(replica_groups).
  • (C6) Jika use_global_device_ids = true, maka channel_id > 0.
  • (C7) computation memiliki jenis (tensor<E>, tensor<E>) -> (tensor<E>) dengan is_promotable(element_type(operand), E).
  • (C8) shape(result) = shape(operand) kecuali:
    • dim(result, scatter_dimension) = dim(operand, scatter_dimension) / dim(process_groups, 1).
  • (C9) element_type(result) = E.

Contoh

// num_replicas: 2
// num_partitions: 1
// %operand@(0, 0): [[1, 2, 3, 4],
//                   [5, 6, 7, 8]]
// %operand@(1, 0): [[9, 10, 11, 12],
//                   [13, 14, 15, 16]]
%result = "stablehlo.reduce_scatter"(%operand) ({
  ^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
  %0 = "stablehlo.add"(%arg0, %arg1) : (tensor<i64>, tensor<i64>) -> tensor<i64>
  "stablehlo.return"(%0) : (tensor<i64>) -> ()
}) {
  scatter_dimension = 1 : i64,
  replica_groups = dense<[[0, 1]]> : tensor<1x2xi64>,
  channel_handle = #stablehlo.channel_handle<handle = 0, type = 0>
} : (tensor<2x4xi64>) -> tensor<2x2xi64>
//
// %result@(0, 0): [[10, 12],
//                  [18, 20]]
// %result@(1, 0): [[14, 16],
//                  [22, 24]]

Contoh Lainnya

reduce_window

Semantik

Menerapkan fungsi pengurangan body ke jendela inputs dan init_values serta menghasilkan results.

Diagram berikut menunjukkan cara elemen dalam results... dihitung dari inputs... menggunakan contoh konkret.

Secara lebih formal, results...[result_index] = reduce(windows, init_values, axes(inputs...), body) (lihat mengurangi) jika:

  • padded_inputs = pad(inputs..., init_values..., padding[:, 0], padding[:, 1], base_dilations - 1).
  • window_start = result_index * window_strides.
  • window_end = window_start + (window_dimensions - 1) * window_dilations + 1.
  • windows = slice(padded_inputs..., window_start, window_end, window_dilations).

Input

Label Nama Jenis Batasan
(I1) inputs jumlah variadic Tensor atau quantized5 per-tensor (C1-C4), (C6), (C8), (C10), (C12), (C13), (C15)
(I2) init_values bilangan variadic dari 0-dimensions atau quantized true-tensor per-tensor (C1), (C13)
(I3) window_dimensions Konstanta TensorFlow 1 dimensi dari jenis si64 (C4), (C5), (C15)
(I4) window_strides Konstanta TensorFlow 1 dimensi dari jenis si64 (C6), (C7), (C15)
(I5) base_dilations Konstanta TensorFlow 1 dimensi dari jenis si64 (C8), (C9), (C15)
(I6) window_dilations Konstanta TensorFlow 1 dimensi dari jenis si64 (C10), (C11), (C15)
(I7) padding Konstanta TensorFlow 2 dimensi dari jenis si64 (C12), (C15)
(I8) body fungsi (C13)

Output

Nama Jenis Batasan
results jumlah variadic Tensor atau quantized5 per-tensor (C1), (C14-C16)

Batasan

  • (C1) 0 < size(inputs) = size(init_values) = size(results) = N.
  • (C2) same(shape(inputs...)).
  • (C3) element_type(inputs...) = element_type(init_values...).
  • (C4) size(window_dimensions) = rank(inputs[0]).
  • (C5) 0 < window_dimensions.
  • (C6) size(window_strides) = rank(inputs[0]).
  • (C7) 0 < window_strides.
  • (C8) size(base_dilations) = rank(inputs[0]).
  • (C9) 0 < base_dilations.
  • (C10) size(window_dilations) = rank(inputs[0]).
  • (C11) 0 < window_dilations.
  • (C12) shape(padding) = [rank(inputs[0]), 2].
  • (C13) body memiliki jenis (tensor<E0>, ..., tensor<EN-1>, tensor<E0>, ..., tensor<EN-1>) -> (tensor<E0>, ..., tensor<EN-1>) dengan is_promotable(element_type(inputs[i]), Ei).
  • (C14) same(shape(results...)).
  • (C15) shape(results[0]) = num_windows dalam hal:
    • dilated_input_shape = shape(inputs[0]) = 0 ? 0 : (shape(inputs[0]) - 1) * base_dilations + 1.
    • padded_input_shape = padding[:, 0] + dilated_input_shape + padding[:, 1].
    • dilated_window_shape = (window_dimensions - 1) * window_dilations + 1.
    • is_empty_window = padded_input_shape = 0 || dilated_window_shape > padded_input_shape.
    • num_windows = is_empty_window ? 0 : floor((padded_input_shape - dilated_window_shape) / window_strides) + 1.
  • (C16) element_type(results[i]) = Ei untuk semua i di [0,N).

Contoh

// %input = [[1, 2], [3, 4], [5, 6]]
// %init_value = 0
%result = "stablehlo.reduce_window"(%input, %init_value) ({
  ^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
    %0 = "stablehlo.add"(%arg0, %arg1) : (tensor<i64>, tensor<i64>) -> tensor<i64>
    "stablehlo.return"(%0) : (tensor<i64>) -> ()
}) {
  window_dimensions = dense<[2, 1]> : tensor<2xi64>,
  window_strides = dense<[4, 1]> : tensor<2xi64>,
  base_dilations = dense<[2, 1]> : tensor<2xi64>,
  window_dilations = dense<[3, 1]> : tensor<2xi64>,
  padding = dense<[[2, 1], [0, 0]]> : tensor<2x2xi64>
} : (tensor<3x2xi64>, tensor<i64>) -> tensor<2x2xi64>
// %result = [[0, 0], [3, 4]]

Contoh Lainnya

sisa

Semantik

Melakukan sisa composable dividen lhs dan pembagi rhs yang bijaksana serta menghasilkan tensor result.

Secara lebih formal, tanda hasil diambil dari dividen, dan nilai absolut hasil selalu lebih kecil dari nilai absolut pembagi. Sisanya dihitung sebagai lhs - d * rhs, dengan d diberikan oleh:

  • Untuk bilangan bulat: stablehlo.divide(lhs, rhs).
  • Untuk float: division(lhs, rhs) dari IEEE-754 dengan atribut pembulatan roundTowardZero.
  • Untuk angka kompleks: TBD (#997).
  • Untuk jenis terkuantisasi:
    • dequantize_op_quantize(remainder, lhs, rhs, type(result)).

Untuk jenis elemen floating point, operasi ini berbeda dengan operasi remainder dari spesifikasi IEEE-754, dengan d merupakan nilai integral yang terdekat dengan nilai pasti lhs/rhs dengan nilai yang sama dengan genap.

Input

Label Nama Jenis Batasan
(I1) lhs Tensor integer, floating point atau tipe kompleks atau quantized per-tensor (C1)
(I2) rhs Tensor integer, floating point atau tipe kompleks atau quantized per-tensor (C1)

Output

Nama Jenis Batasan
result Tensor integer, floating point atau tipe kompleks atau quantized per-tensor (C1)

Batasan

  • (C1) baseline_type(operand) = baseline_type(result).

Contoh

// %lhs: [17, -17, 17, -17]
// %rhs: [3, 3, -3, -3]
%result = "stablehlo.remainder"(%lhs, %rhs) : (tensor<4xi64>, tensor<4xi64>) -> tensor<4xi64>
// %result: [2, -2, 2, -2]

Contoh Lainnya

replica_id

Semantik

Menghasilkan replica_id dari proses saat ini.

Output

Nama Jenis
result Tensor 0 dimensi dari jenis ui32

Contoh

%result = "stablehlo.replica_id"() : () -> tensor<ui32>

Contoh Lainnya

bentuk ulang

Semantik

Melakukan pembentukan ulang TensorFlow operand menjadi TensorFlow result. Secara konseptual, hal ini berarti mempertahankan representasi kanonis yang sama, tetapi berpotensi mengubah bentuk, misalnya, dari tensor<2x3xf32> menjadi tensor<3x2xf32> atau tensor<6xf32>.

Secara lebih formal, result[result_index] = operand[operand_index] dengan result_index dan operand_index memiliki posisi yang sama dalam pengurutan leksikografis index_space(result) dan index_space(operand).

Input

Label Nama Jenis Batasan
(I1) operand Tensor atau Quantized Tensor (C1-C3)

Output

Nama Jenis Batasan
result Tensor atau Quantized Tensor (C1-C3)

Batasan

  • (C1) element_type(result) diberikan oleh:
    • element_type(operand), jika !is_per_axis_quantized(operand).
    • element_type(operand) kecuali bahwa quantization_dimension(operand) dan quantization_dimension(result) mungkin berbeda.
  • (C2) size(operand) = size(result).
  • (C3) Jika is_per_axis_quantized(operand):
    • reduce(dims(operand, [0, 1, ..., quantization_dimension(operand) - 1]), init_values=1, dimensions=[0], body=lambda x, y: x * y) = reduce(dims(result, [0, 1, ..., quantization_dimension(result) - 1]), init_values=1, dimensions=[0], body=lambda x, y: x * y).
    • dim(operand, quantization_dimension(operand)) = dim(result, quantization_dimension(result)).
    • reduce(dims(operand, [quantization_dimension(operand) + 1, ..., rank(operand) - 1]), init_values=1, dimensions=[0], body=lambda x, y: x * y) = reduce(dims(result, [quantization_dimension(result) + 1, ..., rank(result) - 1]), init_values=1, dimensions=[0], body=lambda x, y: x * y).

Contoh

// %operand: [[1, 2, 3], [4, 5, 6]]]
%result = "stablehlo.reshape"(%operand) : (tensor<2x3xi32>) -> tensor<3x2xi32>
// %result: [[1, 2], [3, 4], [5, 6]]

Contoh Lainnya

reverse

Semantik

Membalik urutan elemen dalam operand di sepanjang dimensions yang ditentukan dan menghasilkan TensorFlow result. Secara lebih formal, result[result_index] = operand[operand_index] jika:

  • operand_index[d] = dim(result, d) - result_index[d] - 1 jika d dalam dimensions.
  • operand_index[d] = result_index[d] jika tidak.

Input

Label Nama Jenis Batasan
(I1) operand Quantized Tensor atau per-tensor (C1), (C3)
(I2) dimensions Konstanta TensorFlow 1 dimensi dari jenis si64 (C2), (C3)

Output

Nama Jenis Batasan
result Quantized Tensor atau per-tensor (C1), (C3)

Batasan

  • (C1) type(operand) = type(result).
  • (C2) is_unique(dimensions).
  • (C3) 0 <= dimensions < rank(result).

Contoh

// %operand = [[1, 2], [3, 4], [5, 6]]
%result = "stablehlo.reverse"(%operand) {
  dimensions = dense<1> : tensor<1xi64>
} : (tensor<3x2xi32>) -> tensor<3x2xi32>
// %result: [[2, 1], [4, 3], [6, 5]]

Contoh Lainnya

RNG

Semantik

Menghasilkan angka acak menggunakan algoritme rng_distribution dan menghasilkan tensor result dari bentuk shape tertentu.

Jika rng_distribution = UNIFORM, angka acak akan dihasilkan dengan mengikuti distribusi seragam selama interval [a, b). Jika a >= b, perilaku tidak ditentukan.

Jika rng_distribution = NORMAL, angka acak akan dihasilkan mengikuti distribusi normal dengan rataan = a dan simpangan baku = b. Jika b < 0, perilaku tidak terdefinisi.

Cara yang tepat untuk menghasilkan angka acak ditentukan oleh implementasinya. Misalnya, mungkin atau tidak bersifat deterministik, dan mungkin menggunakan status tersembunyi, mungkin juga tidak.

Dalam percakapan dengan banyak pemangku kepentingan, operasi ini menjadi tidak digunakan lagi secara efektif, sehingga ke depannya kami berencana mempelajari penghapusannya (#597).

Input

Label Nama Jenis Batasan
(I1) a Tensor 0-dimensi untuk jenis integer, boolean, atau floating point (C1), (C2)
(I2) b Tensor 0-dimensi untuk jenis integer, boolean, atau floating point (C1), (C2)
(I3) shape Konstanta TensorFlow 1 dimensi dari jenis si64 (C3)
(I4) rng_distribution enum UNIFORM dan NORMAL (C2)

Output

Nama Jenis Batasan
result Tensor integer, boolean, atau tipe floating point (C1-C3)

Batasan

  • (C1) element_type(a) = element_type(b) = element_type(result).
  • (C2) Jika rng_distribution = NORMAL, maka is_float(a).
  • (C3) shape(result) = shape.

Contoh

// %a = 0
// %b = 2
// %shape = [3, 3]
%result = "stablehlo.rng"(%a, %b, %shape) {
  rng_distribution = #stablehlo<rng_distribution UNIFORM>
} : (tensor<i32>, tensor<i32>, tensor<2xi64>) -> tensor<3x3xi32>
// %result: [
//           [1, 0, 1],
//           [1, 1, 1],
//           [0, 0, 0]
//          ]

rng_bit_generator

Semantik

Menampilkan output yang diisi dengan bit acak yang seragam dan status output yang diperbarui output_state menggunakan algoritme generator angka pseudorandom rng_algorithm berdasarkan status awal initial_state. Outputnya dijamin akan menjadi fungsi deterministik initial_state, tetapi tidak dijamin akan deterministik di antara implementasi.

rng_algorithm adalah salah satu dari berikut ini:

  • DEFAULT: Algoritme yang ditentukan implementasi.
  • THREE_FRY: Varian yang ditentukan implementasi dari algoritme Threefry.*
  • PHILOX: Varian yang ditentukan implementasi dari algoritme Philox.*

* Lihat: Salmon et al. SC 2011. Angka acak paralel: semudah 1, 2, 3.

Input

Label Nama Jenis Batasan
(I1) rng_algorithm enum DEFAULT, THREE_FRY, dan PHILOX (C2)
(I2) initial_state Tensor 1 dimensi jenis ui64 (C1), (C2)

Output

Nama Jenis Batasan
output_state Tensor 1 dimensi jenis ui64 (C1)
output Tensor tipe integer atau floating point

Batasan

  • (C1) type(initial_state) = type(output_state).
  • (C2) size(initial_state) ditentukan sebagai:
    • ditentukan implementasinya jika rng_algorithm = DEFAULT.
    • 2 jika rng_algorithm = THREE_FRY.
    • 2 atau 3 jika rng_algorithm = PHILOX.

Contoh

// %initial_state: [1, 2]
%output_state, %output = "stablehlo.rng_bit_generator"(%initial_state) {
  rng_algorithm = #stablehlo<rng_algorithm THREE_FRY>
} : (tensor<2xui64>) -> (tensor<2xui64>, tensor<2x2xui64>)
// %output_state: [1, 6]
// %output: [
//           [9236835810183407956, 16087790271692313299],
//           [18212823393184779219, 2658481902456610144]
//          ]

round_nearest_afz

Semantik

Melakukan pembulatan berbasis elemen ke arah bilangan bulat terdekat, yang memisahkan ikatan dari nol, pada TensorFlow operand dan menghasilkan TensorFlow result. Mengimplementasikan operasi roundToIntegralTiesToAway dari spesifikasi IEEE-754. Untuk jenis terkuantisasi, jalankan dequantize_op_quantize(round_nearest_afz, operand, type(result)).

Input

Label Nama Jenis Batasan
(I1) operand Tensor floating-point atau per-tensor quantizedTensor (C1)

Output

Nama Jenis Batasan
result Tensor floating-point atau per-tensor quantizedTensor (C1)

Batasan

  • (C1) baseline_type(operand) = baseline_type(result).

Contoh

// %operand = [-2.5, 0.4, 0.5, 0.6, 2.5]
%result = "stablehlo.round_nearest_afz"(%operand) : (tensor<5xf64>) -> tensor<5xf64>
// %result: [-3.0, 0.0, 1.0, 1.0, 3.0]

Contoh Lainnya

round_nearest_even

Semantik

Melakukan pembulatan berbasis elemen ke arah bilangan bulat terdekat, memutus ikatan ke bilangan bulat genap, pada TensorFlow operand dan menghasilkan tensor result. Mengimplementasikan operasi roundToIntegralTiesToEven dari spesifikasi IEEE-754. Untuk jenis terkuantisasi, jalankan dequantize_op_quantize(round_nearest_even, operand, type(result)).

Input

Label Nama Jenis Batasan
(I1) operand Tensor floating-point atau per-tensor quantizedTensor (C1)

Output

Nama Jenis Batasan
result Tensor floating-point atau per-tensor quantizedTensor (C1)

Batasan

  • (C1) baseline_type(operand) = baseline_type(result).

Contoh

// %operand = [-2.5, 0.4, 0.5, 0.6, 2.5]
%result = "stablehlo.round_nearest_even"(%operand) : (tensor<5xf64>) -> tensor<5xf64>
// %result: [-2.0, 0.0, 0.0, 1.0, 2.0]

Contoh Lainnya

{i>rsqrt<i}

Semantik

Melakukan operasi root kuadrat timbal balik berbasis elemen pada TensorFlow operand dan menghasilkan5 result. Bergantung pada jenis elemen, lakukan hal berikut:

  • Untuk float: rSqrt dari IEEE-754.
  • Untuk bilangan kompleks: akar kuadrat timbal balik kompleks.
  • Untuk jenis terkuantisasi: dequantize_op_quantize(rsqrt, operand, type(result)).

Input

Label Nama Jenis Batasan
(I1) operand Tensor floating-point atau tipe kompleks atau kuantisasi per-tensor (C1)

Output

Nama Jenis Batasan
result Tensor floating-point atau tipe kompleks atau kuantisasi per-tensor (C1)

Batasan

  • (C1) baseline_type(operand) = baseline_type(result).

Contoh

// %operand: [[1.0, 4.0], [9.0, 25.0]]
%result = "stablehlo.rsqrt"(%operand) : (tensor<2x2xf32>) -> tensor<2x2xf32>
// %result: [[1.0, 0.5], [0.33333343, 0.2]]

Contoh Lainnya

scatter

Semantik

Menghasilkan5_tensor results yang setara dengan_tensor inputs, kecuali jika beberapa irisan yang ditentukan oleh scatter_indices diperbarui dengan nilai updates menggunakan update_computation.

Diagram berikut menunjukkan cara elemen di updates... dipetakan pada elemen di results... menggunakan contoh konkret. Diagram memilih beberapa contoh indeks updates... dan menjelaskan secara mendetail indeks results... yang terkait dengannya.

Secara lebih formal, untuk semua update_index di index_space(updates[0]):

  • update_scatter_dims = [d for d in axes(updates[0]) and d not in update_window_dims].
  • update_scatter_index = update_index[update_scatter_dims...].
  • start_index didefinisikan sebagai:
    • scatter_indices[si0, ..., :, ..., siN] dengan si adalah elemen individual dalam update_scatter_index dan : disisipkan pada indeks index_vector_dim, jika index_vector_dim < rank(scatter_indices).
    • [scatter_indices[update_scatter_index]] jika tidak.
  • Untuk d_input di axes(inputs[0]),
    • full_start_index[d_input] = start_index[d_start] jika d_input = scatter_dims_to_operand_dims[d_start].
    • full_start_index[d_input] = 0 jika tidak.
  • update_window_index = update_index[update_window_dims...].
  • full_window_index = [wi0, ..., 0, ..., wiN] dengan wi adalah elemen individual di update_window_index, dan 0 disisipkan pada indeks dari inserted_window_dims.
  • result_index = full_start_index + full_window_index.

Dengan demikian, results = exec(schedule, inputs), dengan:

  • schedule adalah permutasi index_space(updates[0]) yang ditentukan oleh implementasi.
  • exec([update_index, ...], results) = exec([...], updated_results) dalam hal ini:
    • Jika result_index dalam batas untuk shape(results...)
    • updates_converted = to_destination_type( updates...[update_index], type(func_inputs(update_computation) [len(func_inputs(update_computation))//2:])... )
    • updated_values = update_computation(results...[result_index], updates_converted)
    • updated_results adalah salinan results dengan results...[result_index] ditetapkan ke updated_values....
    • Atau
    • updated_results = results.
  • exec([], results) = results.

Jika indices_are_sorted adalah true, implementasi dapat mengasumsikan bahwa scatter_indices diurutkan dalam kaitannya dengan scatter_dims_to_operand_dims, jika tidak, perilaku tidak ditentukan. Secara lebih formal, untuk semua i1 < i2 dari indices(result), full_start_index(i1) <= full_start_index(i2).

Jika unique_indices adalah true, penerapan dapat mengasumsikan bahwa semua indeks result_index yang tersebar bersifat unik. Jika unique_indices adalah true, tetapi indeks yang disebar tidak unik, maka perilaku tidak ditentukan.

Input

Label Nama Jenis Batasan
(I1) inputs jumlah variadic Tensor atau quantized5 per-tensor (C1), (C2), (C4-C6), (C10), (C13), (C15-C16)
(I2) scatter_indices Tensor tipe integer (C4), (C11), (C14)
(I3) updates jumlah variadic Tensor atau quantized5 per-tensor (C3-C6), (C8)
(I4) update_window_dims Konstanta TensorFlow 1 dimensi dari jenis si64 (C2), (C4), (C7), (C8)
(I5) inserted_window_dims Konstanta TensorFlow 1 dimensi dari jenis si64 (C2), (C4), (C9), (C10)
(I6) scatter_dims_to_operand_dims Konstanta TensorFlow 1 dimensi dari jenis si64 (C11-C13)
(I7) index_vector_dim konstanta jenis si64 (C4), (C11), (C14)
(I8) indices_are_sorted konstanta jenis i1
(I9) unique_indices konstanta jenis i1
(I10) update_computation fungsi (C15)

Output

Nama Jenis Batasan
results jumlah variadic Tensor atau quantized5 per-tensor (C15-C17)

Batasan

  • (C1) same(shape(inputs...)).
  • (C2) rank(inputs[0]) = size(update_window_dims) + size(inserted_window_dims).
  • (C3) same(shape(updates...)).
  • (C4) shape(updates[0]) = combine(update_scatter_dim_sizes, update_window_dim_sizes) jika:
    • update_scatter_dim_sizes = shape(scatter_indices) kecuali bahwa ukuran dimensi scatter_indices yang sesuai dengan index_vector_dim tidak disertakan.
    • update_window_dim_sizes <= shape(inputs[0]) kecuali bahwa ukuran dimensi di inputs[0] yang sesuai dengan inserted_window_dims tidak disertakan.
    • combine menempatkan update_scatter_dim_sizes pada sumbu yang sesuai dengan update_scatter_dims dan update_window_dim_sizes pada sumbu yang sesuai dengan update_window_dims.
  • (C5) 0 < size(inputs) = size(updates) = N.
  • (C6) element_type(updates...) = element_type(inputs...).
  • (C7) is_unique(update_window_dims) and is_sorted(update_window_dims).
  • (C8) 0 <= update_window_dims < rank(updates[0]).
  • (C9) is_unique(inserted_window_dims) and is_sorted(update_window_dims).
  • (C10) 0 <= inserted_window_dims < rank(inputs[0]).
  • (C11) size(scatter_dims_to_operand_dims) = index_vector_dim < rank(scatter_indices) ? dim(scatter_indices, index_vector_dim) : 1.
  • (C12) is_unique(scatter_dims_to_operand_dims).
  • (C13) 0 <= scatter_dims_to_operand_dims < rank(inputs[0])
  • (C14) 0 <= index_vector_dim <= rank(scatter_indices).
  • (C15) update_computation memiliki jenis (tensor<E0>, ..., tensor<EN-1>, tensor<E0>, ..., tensor<EN-1>) -> (tensor<E0>, ..., tensor<EN-1>), dengan is_promotable(element_type(inputs[i]), Ei).
  • (C16) shape(inputs...) = shape(results...).
  • (C17) element_type(results[i]) = Ei untuk semua i di [0,N).

Contoh

// %input: [
//          [[1, 2], [3, 4], [5, 6], [7, 8]],
//          [[9, 10], [11, 12], [13, 14], [15, 16]],
//          [[17, 18], [19, 20], [21, 22], [23, 24]]
//         ]
// %scatter_indices: [[[0, 2], [1, 0], [2, 1]], [[0, 1], [1, 0], [0, 9]]]
// %update: [
//           [[[1, 1], [1, 1]], [[1, 1], [1, 1]], [[1, 1], [1, 1]]],
//           [[[1, 1], [1, 1]], [[1, 1], [1, 1]], [[1, 1], [1, 1]]]
//          ]
%result = "stablehlo.scatter"(%input, %scatter_indices, %update) ({
  ^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
    %0 = "stablehlo.add"(%arg0, %arg1) : (tensor<i64>, tensor<i64>) -> tensor<i64>
    "stablehlo.return"(%0) : (tensor<i64>) -> ()
}) {
  scatter_dimension_numbers = #stablehlo.scatter<
    update_window_dims = [2, 3],
    inserted_window_dims = [0],
    scatter_dims_to_operand_dims = [1, 0],
    index_vector_dim = 2>,
  indices_are_sorted = false,
  unique_indices = false
} : (tensor<3x4x2xi64>, tensor<2x3x2xi64>, tensor<2x3x2x2xi64>) -> tensor<3x4x2xi64>
// %result: [
//           [[1, 2], [5, 6], [7, 8], [7, 8]],
//           [[10, 11], [12, 13], [14, 15], [16, 17]],
//           [[18, 19], [20, 21], [21, 22], [23, 24]]
//          ]

Contoh Lainnya

pilih

Semantik

Menghasilkan5_tensor result dengan setiap elemen dipilih dari_tensor on_true atau on_false berdasarkan nilai elemen pred yang sesuai. Secara lebih formal, result[result_index] = pred_element ? on_true[result_index] : on_false[result_index], dengan pred_element = rank(pred) = 0 ? pred[] : pred[result_index]. Untuk jenis terkuantisasi, jalankan dequantize_select_quantize(pred, on_true, on_false, type(result)).

Input

Label Nama Jenis Batasan
(I1) pred Tensor jenis i1 (C1)
(I2) on_true Quantized Tensor atau per-tensor (C1-C2)
(I3) on_false Quantized Tensor atau per-tensor (C2)

Output

Nama Jenis Batasan
result Quantized Tensor atau per-tensor (C2)

Batasan

  • (C1) rank(pred) = 0 or shape(pred) = shape(on_true).
  • (C2) baseline_type(on_true) = baseline_type(on_false) = baseline_type(result).

Contoh

// %pred: [[false, true], [true, false]]
// %on_true: [[1, 2], [3, 4]]
// %on_false: [[5, 6], [7, 8]]
%result = "stablehlo.select"(%pred, %on_true, %on_false) : (tensor<2x2xi1>, tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
// %result: [[5, 2], [3, 8]]

Contoh Lainnya

select_and_scatter

Semantik

Sebarkan nilai dari TensorFlow source menggunakan scatter berdasarkan hasil reduce_window dari TensorFlow input menggunakan select dan menghasilkan tensor result.

Diagram berikut menunjukkan cara elemen dalam result dihitung dari operand dan source menggunakan contoh konkret.

Secara lebih formal:

  • selected_values = reduce_window_without_init(...) dengan input berikut:

    • `input = [operand].
    • window_dimensions, window_strides, dan padding yang digunakan apa adanya.
    • base_dilations = windows_dilations = 1.
    • body didefinisikan sebagai:
    def body(arg0: tensor<E>, arg1: tensor<E>) -> tensor<E>:
      return select(arg0, arg1) ? arg0 : arg1;
    

    dengan E = element_type(operand), dan reduce_window_without_init berfungsi persis seperti reduce_window, kecuali bahwa schedule dari reduce yang mendasarinya (lihat mengurangi) tidak menyertakan nilai init. Saat ini tidak ditentukan apa yang terjadi jika jendela yang sesuai tidak memiliki nilai (#731).

  • result[result_index] = reduce([source_values], [init_value], [0], scatter) dengan:

    • source_values = [source[source_index] for source_index in source_indices].
    • selected_index(source_index) = operand_index jika selected_values[source_index] memiliki elemen operand dari operand_index.
    • source_indices = [source_index for source_index in indices(source) if selected_index(source_index) = result_index].

Input

Label Nama Jenis Batasan
(I1) operand Quantized Tensor atau per-tensor (C1-C4), (C6), (C8-C11)
(I2) source Quantized Tensor atau per-tensor (C1), (C2)
(I3) init_value Tensor 0-dimensi atau Tensortized per-tensor (C3)
(I4) window_dimensions Konstanta TensorFlow 1 dimensi dari jenis si64 (C2), (C4), (C5)
(I5) window_strides Konstanta TensorFlow 1 dimensi dari jenis si64 (C2), (C6), (C7)
(I6) padding Konstanta TensorFlow 2 dimensi dari jenis si64 (C2), (C8)
(I7) select fungsi (C9)
(I8) scatter fungsi (C10)

Output

Nama Jenis Batasan
result Quantized Tensor atau per-tensor (C11-C12)

Batasan

  • (C1) element_type(operand) = element_type(source).
  • (C2) shape(source) = num_windows jika:
    • padded_operand_shape = padding[:, 0] + shape(operand) + padding[:, 1].
    • is_empty_window = padded_operand_shape = 0 || window_dimensions > padded_operand_shape.
    • num_windows = is_empty_window ? 0 : floor((padded_operand_shape - window_dimensions) / window_strides) + 1.
  • (C3) element_type(init_value) = element_type(operand).
  • (C4) size(window_dimensions) = rank(operand).
  • (C5) 0 < window_dimensions.
  • (C6) size(window_strides) = rank(operand).
  • (C7) 0 < window_strides.
  • (C8) shape(padding) = [rank(operand), 2].
  • (C9) select memiliki jenis (tensor<E>, tensor<E>) -> tensor<i1> dengan E = element_type(operand).
  • (C10) scatter memiliki jenis (tensor<E>, tensor<E>) -> tensor<E> dengan is_promotable(element_type(operand), E).
  • (C11) shape(operand) = shape(result).
  • (C12) element_type(result) = E.

Contoh

// %operand: [[1, 5], [2, 5], [3, 6], [4, 4]]
// %source: [[5, 6], [7, 8]]
// %init_value: 0
%result = "stablehlo.select_and_scatter"(%operand, %source, %init_value) ({
  ^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
    %0 = "stablehlo.compare"(%arg0, %arg1) {
      comparison_direction = #stablehlo<comparison_direction GE>
    } : (tensor<i64>, tensor<i64>) -> tensor<i1>
    "stablehlo.return"(%0) : (tensor<i1>) -> ()
}, {
  ^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
    %0 = "stablehlo.add"(%arg0, %arg1) : (tensor<i64>, tensor<i64>) -> tensor<i64>
    "stablehlo.return"(%0) : (tensor<i64>) -> ()
}) {
  window_dimensions = dense<[3, 1]> : tensor<2xi64>,
  window_strides = dense<[2, 1]> : tensor<2xi64>,
  padding = dense<[[0, 1], [0, 0]]> : tensor<2x2xi64>
} : (tensor<4x2xi64>, tensor<2x2xi64>, tensor<i64>) -> tensor<4x2xi64>
// %result: [[0, 0], [0, 0], [5, 14], [7, 0]]

Contoh Lainnya

kirim

Semantik

Mengirim inputs ke saluran channel_id dan menghasilkan token result.

Jika is_host_transfer adalah true, operasi akan mentransfer data ke host. Jika tidak, data akan ditransfer ke perangkat lain. Hal ini berarti ditentukan oleh implementasi. Flag ini menduplikasi informasi yang diberikan di channel_type, sehingga ke depannya kami berencana untuk hanya menyimpan salah satunya (#666).

Input

Label Nama Jenis Batasan
(I1) inputs bilangan variadic dari TensorFlow atau Quantized Tensor
(I2) token token
(I3) channel_id konstanta jenis si64
(I4) channel_type enum DEVICE_TO_DEVICE dan DEVICE_TO_HOST (C1)
(I5) is_host_transfer konstanta jenis i1 (C1)

Output

Nama Jenis
result token

Batasan

  • (C1) channel_type didefinisikan sebagai:
    • DEVICE_TO_HOST jika is_host_transfer = true,
    • DEVICE_TO_DEVICE jika tidak.

Contoh

%result = "stablehlo.send"(%operand, %token) {
  channel_handle = #stablehlo.channel_handle<handle = 1, type = 2>,
  is_host_transfer = true
} : (tensor<2x2xi64>, !stablehlo.token) -> !stablehlo.token

Contoh Lainnya

shift_left

Semantik

Melakukan operasi kiri-shift element-wise pada TensorFlow lhs dengan jumlah rhs bit dan menghasilkan TensorFlow result.

Input

Label Nama Jenis Batasan
(I1) lhs Tensor tipe integer (C1)
(I2) rhs Tensor tipe integer (C1)

Output

Nama Jenis Batasan
result Tensor tipe integer (C1)

Batasan

  • (C1) type(lhs) = type(rhs) = type(result).

Contoh

// %lhs: [-1, 0, 1]
// %rhs: [1, 2, 3]
%result = "stablehlo.shift_left"(%lhs, %rhs): (tensor<3xi64>, tensor<3xi64>) -> tensor<3xi64>
// %result: [-2, 0, 8]

Contoh Lainnya

shift_right_arithmetic

Semantik

Melakukan operasi pergeseran kanan aritmatika element-wise pada TensorFlow lhs dengan jumlah rhs dari bit dan menghasilkan TensorFlow result.

Input

Label Nama Jenis Batasan
(I1) lhs Tensor tipe integer (C1)
(I2) rhs Tensor tipe integer (C1)

Output

Nama Jenis Batasan
result Tensor tipe integer (C1)

Batasan

  • (C1) type(lhs) = type(rhs) = type(result).

Contoh

// %lhs: [-1, 0, 8]
// %rhs: [1, 2, 3]
%result = "stablehlo.shift_right_arithmetic"(%lhs, %rhs): (tensor<3xi64>, tensor<3xi64>) -> tensor<3xi64>
// %result: [-1, 0, 1]

Contoh Lainnya

shift_right_logical

Semantik

Melakukan operasi pergeseran kanan logis yang mengatur elemen (element-wise) pada TensorFlow lhs dengan jumlah rhs bit dan menghasilkan TensorFlow result.

Input

Label Nama Jenis Batasan
(I1) lhs Tensor tipe integer (C1)
(I2) rhs Tensor tipe integer (C1)

Output

Nama Jenis Batasan
result Tensor tipe integer (C1)

Batasan

  • (C1) type(lhs) = type(rhs) = type(result).

Contoh

// %lhs: [-1, 0, 8]
// %rhs: [1, 2, 3]
%result = "stablehlo.shift_right_logical"(%lhs, %rhs): (tensor<3xi64>, tensor<3xi64>) -> tensor<3xi64>
// %result: [9223372036854775807, 0, 1]

Contoh Lainnya

tanda

Semantik

Menampilkan tanda elemen operand dan menghasilkan result. Secara lebih formal, untuk setiap elemen x, semantik dapat dinyatakan menggunakan sintaksis Python sebagai berikut:

def sign(x):
  if is_integer(x):
    if compare(x, 0, LT, SIGNED): return -1
    if compare(x, 0, EQ, SIGNED): return 0
    return 1
  elif is_float(x):
    if is_nan(x): return NaN
    if compare(x, -0.0, EQ, FLOAT): return -0.0
    if compare(x, +0.0, EQ, FLOAT): return +0.0
    if compare(x, 0.0, LT, FLOAT): return -1.0
    return 1.0
  elif is_complex(x):
    if is_nan(real(x)) or is_nan(imag(x)): return (NaN, NaN)
    if compare(x, (0.0, 0.0), EQ, FLOAT): return (0.0, 0.0)
    return divide(x, convert(abs(x), type(x)))

Untuk jenis terkuantisasi, jalankan dequantize_op_quantize(sign, operand, type(result)).

Input

Label Nama Jenis Batasan
(I1) operand Tensor integer bertanda tangan, floating point, atau tipe kompleks atau quantized per-tensor (C1)

Output

Nama Jenis Batasan
result Tensor integer bertanda tangan, floating point, atau tipe kompleks atau quantized per-tensor (C1)

Batasan

  • (C1) baseline_type(operand) = baseline_type(result).

Contoh

// Logical values: +NaN, -1.0, -0.0, +0.0, 1.0
// operand: [0x7FFFFFFFFFFFFFFF, -1.0, -0.0, 0.0, 1.0]
%result = "stablehlo.sign"(%operand) : (tensor<5xf64>) -> tensor<5xf64>
// Logical values: +NaN, -1.0, -0.0, +0.0, 1.0
// %result: [0x7FFFFFFFFFFFFFFF, -1.0, -0.0, 0.0, 1.0]

Contoh Lainnya

sinus

Semantik

Menjalankan operasi sinus element-wise pada TensorFlow operand dan menghasilkan tensor result. Bergantung pada jenis elemen, lakukan hal berikut:

  • Untuk float: sin dari IEEE-754.
  • Untuk bilangan kompleks: sinus kompleks.
  • Untuk jenis terkuantisasi: dequantize_op_quantize(sine, operand, type(result)).

Input

Label Nama Jenis Batasan
(I1) operand Tensor floating-point atau tipe kompleks atau kuantisasi per-tensor (C1)

Output

Nama Jenis Batasan
result Tensor floating-point atau tipe kompleks atau kuantisasi per-tensor (C1)

Batasan

  • (C1) baseline_type(operand) = baseline_type(result).

Contoh

// %operand: [
//            [0.0, 1.57079632],       // [0, pi/2]
//            [3.14159265, 4.71238898] // [pi, 3pi/2]
//           ]
%result = "stablehlo.sine"(%operand) : (tensor<2x2xf32>) -> tensor<2x2xf32>
// %result: [[0.0, 1.0], [0.0, -1.0]]

Contoh Lainnya

slice

Semantik

Mengekstrak slice dari operand menggunakan indeks awal yang dikomputasi secara statis dan menghasilkan5_tensor result. start_indices berisi indeks awal irisan untuk setiap dimensi, limit_indices berisi indeks akhir (eksklusif) untuk irisan setiap dimensi, dan strides berisi langkah untuk setiap dimensi.

Secara lebih formal, result[result_index] = operand[operand_index] dengan operand_index = start_indices + result_index * strides.

Input

Label Nama Jenis Batasan
(I1) operand Quantized Tensor atau per-tensor (C1-C3), (C5)
(I2) start_indices Konstanta TensorFlow 1 dimensi dari jenis si64 (C2), (C3), (C5)
(I3) limit_indices Konstanta TensorFlow 1 dimensi dari jenis si64 (C2), (C3), (C5)
(I4) strides Konstanta TensorFlow 1 dimensi dari jenis si64 (C2), (C4)

Output

Nama Jenis Batasan
result Quantized Tensor atau per-tensor (C1), (C5)

Batasan

  • (C1) element_type(operand) = element_type(result).
  • (C2) size(start_indices) = size(limit_indices) = size(strides) = rank(operand).
  • (C3) 0 <= start_indices <= limit_indices <= shape(operand).
  • (C4) 0 < strides.
  • (C5) shape(result) = ceil((limit_indices - start_indices) / strides).

Contoh

// %operand: [
//            [0, 0, 0, 0],
//            [0, 0, 1, 1],
//            [0, 0, 1, 1]
//           ]
%result = "stablehlo.slice"(%operand) {
  start_indices = dense<[1, 2]> : tensor<2xi64>,
  limit_indices = dense<[3, 4]> : tensor<2xi64>,
  strides = dense<1> : tensor<2xi64>
} : (tensor<3x4xi64>) -> tensor<2x2xi64>
// % result: [
//            [1, 1],
//            [1, 1]
//           ]

Contoh Lainnya

sort

Semantik

Mengurutkan irisan inputs 1 dimensi di sepanjang dimensi dimension secara bersamaan, sesuai dengan comparator dan menghasilkan results.

Tidak seperti input serupa di operasi lain, dimension memungkinkan nilai negatif, dengan semantik yang dijelaskan di bawah ini. Pada masa mendatang, hal ini mungkin tidak diizinkan karena alasan konsistensi (#1377).

Jika is_stable bernilai benar, pengurutan akan stabil, yaitu, urutan relatif elemen yang dianggap sama oleh pembanding akan dipertahankan. Untuk kasus dengan satu input, dua elemen e1 dan e2 dianggap sama oleh pembanding jika dan hanya jika comparator(e1, e2) = comparator(e2, e1) = false. Lihat formalisasi di bawah untuk mengetahui cara generalisasi ke beberapa input.

Secara lebih formal, untuk semua result_index di index_space(results[0]):

  • adjusted_dimension = dimension >= 0 ? dimension : rank(inputs[0]) + dimension.
  • result_slice = [ri0, ..., :, ..., riR-1] dengan riN adalah elemen individual di result_index, dan : disisipkan di adjusted_dimension.
  • inputs_together = (inputs[0]..., ..., inputs[N-1]...).
  • results_together[result_slice] = sort(inputs_together[result_slice], comparator_together).
  • tempat sort mengurutkan potongan 1 dimensi dalam urutan tidak menurun dengan mengharapkan comparator_together menampilkan true jika argumen sisi kiri kurang dari argumen kedua sebelah kanan.
  • def comparator_together(lhs_together, rhs_together):
      args = []
      for (lhs_el, rhs_el) in zip(lhs_together, rhs_together):
        args.append(lhs_el)
        args.append(rhs_el)
      return comparator(*args)
    
  • (results[0]..., ..., results[N-1]...) = results_together.

Input

Label Nama Jenis Batasan
(I1) inputs jumlah variadic Tensor atau quantized5 per-tensor (C1-C5)
(I2) dimension konstanta jenis si64 (C4)
(I3) is_stable konstanta jenis i1
(I4) comparator fungsi (C5)

Output

Nama Jenis Batasan
results jumlah variadic Tensor atau quantized5 per-tensor (C2), (C3)

Batasan

  • (C1) 0 < size(inputs).
  • (C2) type(inputs...) = type(results...).
  • (C3) same(shape(inputs...) + shape(results...)).
  • (C4) -R <= dimension < R, dengan R = rank(inputs[0]).
  • (C5) comparator memiliki jenis (tensor<E1>, tensor<E1>, ..., tensor<EN-1>, tensor<EN-1>) -> tensor<i1>, dengan Ei = element_type(inputs[i]).

Contoh

// %input0 = [[1, 2, 3], [3, 2, 1]]
// %input1 = [[3, 2, 1], [1, 2, 3]]
%result0, %result1 = "stablehlo.sort"(%input0, %input1) ({
  ^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>, %arg2: tensor<i64>, %arg3: tensor<i64>):
    %predicate = "stablehlo.compare"(%arg0, %arg1) {
      comparison_direction = #stablehlo<comparison_direction GT>
    } : (tensor<i64>, tensor<i64>) -> tensor<i1>
    "stablehlo.return"(%predicate) : (tensor<i1>) -> ()
}) {
  dimension = 0 : i64,
  is_stable = true
} : (tensor<2x3xi64>, tensor<2x3xi64>) -> (tensor<2x3xi64>, tensor<2x3xi64>)
// %result0 = [[3, 2, 3], [1, 2, 1]]
// %result1 = [[1, 2, 1], [3, 2, 3]]

Contoh Lainnya

sqrt

Semantik

Menjalankan operasi root kuadrat element-wise pada TensorFlow operand dan menghasilkan tensor result. Bergantung pada jenis elemen, lakukan hal berikut:

  • Untuk float: squareRoot dari IEEE-754.
  • Untuk bilangan kompleks: akar kuadrat kompleks.
  • Untuk jenis terkuantisasi: dequantize_op_quantize(sqrt, operand, type(result)).

Input

Label Nama Jenis Batasan
(I1) operand Tensor floating-point atau tipe kompleks atau kuantisasi per-tensor (C1)

Output

Nama Jenis Batasan
result Tensor floating-point atau tipe kompleks atau kuantisasi per-tensor (C1)

Batasan

  • (C1) baseline_type(operand) = baseline_type(result).

Contoh

// %operand: [[0.0, 1.0], [4.0, 9.0]]
%result = "stablehlo.sqrt"(%operand) : (tensor<2x2xf32>) -> tensor<2x2xf32>
// %result: [[0.0, 1.0], [2.0, 3.0]]

Contoh Lainnya

kurangi

Semantik

Melakukan pengurangan berbasis elemen dari dua TensorFlow lhs dan rhs serta menghasilkan tensor result. Bergantung pada jenis elemen, lakukan hal berikut:

  • Untuk bilangan bulat: pengurangan bilangan bulat.
  • Untuk float: subtraction dari IEEE-754.
  • Untuk bilangan kompleks: pengurangan kompleks.
  • Untuk jenis terkuantisasi:
    • dequantize_op_quantize(subtract, lhs, rhs, type(result)).

Input

Label Nama Jenis Batasan
(I1) lhs Tensor integer, floating point, atau tipe kompleks atau kuantisasi per-tensor (C1)
(I2) rhs Tensor integer, floating point, atau tipe kompleks atau kuantisasi per-tensor (C1)

Output

Nama Jenis Batasan
result Tensor integer, floating point, atau tipe kompleks atau kuantisasi per-tensor (C1)

Batasan

  • (C1) baseline_type(lhs) = baseline_type(rhs) = baseline_type(result).

Contoh

// %lhs: [[6, 8], [10, 12]]
// %rhs: [[5, 6], [7, 8]]
%result = "stablehlo.subtract"(%lhs, %rhs) : (tensor<2x2xf32>, tensor<2x2xf32>) -> (tensor<2x2xf32>)
// %result: [[1, 2], [3, 4]]

Contoh Lainnya

Tanh

Semantik

Menjalankan operasi tangen hiperbolik berbasis elemen pada TensorFlow operand dan menghasilkan5 result. Bergantung pada jenis elemen, lakukan hal berikut:

  • Untuk float: tanh dari IEEE-754.
  • Untuk bilangan kompleks: tangen hiperbolik kompleks.
  • Untuk jenis terkuantisasi:
    • dequantize_op_quantize(tanh, operand, type(result)).

Input

Label Nama Jenis Batasan
(I1) operand Tensor floating-point atau tipe kompleks atau kuantisasi per-tensor (C1)

Output

Nama Jenis Batasan
result Tensor floating-point atau tipe kompleks atau kuantisasi per-tensor (C1)

Batasan

  • (C1) baseline_type(operand) = baseline_type(result).

Contoh

// %operand: [-1.0, 0.0, 1.0]
%result = "stablehlo.tanh"(%operand) : (tensor<3xf32>) -> tensor<3xf32>
// %result: [-0.76159416, 0.0, 0.76159416]

Contoh Lainnya

{i>transpose<i}

Semantik

Membisukan dimensi TensorFlow operand menggunakan permutation dan menghasilkan tensor result. Secara lebih formal, result[result_index] = operand[operand_index] dengan result_index[d] = operand_index[permutation[d]].

Input

Label Nama Jenis Batasan
(I1) operand Tensor atau Quantized Tensor (C1-C4)
(I2) permutation Konstanta TensorFlow 1 dimensi dari jenis si64 (C2-C4)

Output

Nama Jenis Batasan
result Tensor atau Quantized Tensor (C1), (C3-C4)

Batasan

  • (C1) element_type(result) diberikan oleh:
    • element_type(operand), jika !is_per_axis_quantized(operand).
    • element_type(operand) kecuali bahwa quantization_dimension(operand) dan quantization_dimension(result) mungkin berbeda.
  • (C2) permutation adalah permutasi dari range(rank(operand)).
  • (C3) shape(result) = dim(operand, permutation...).
  • (C4) Jika is_per_axis_quantized(result), maka quantization_dimension(operand) = permutation(quantization_dimension(result)).

Contoh

// %operand: [
//            [[1,2], [3,4], [5,6]],
//            [[7,8], [9,10], [11,12]]
//           ]
%result = "stablehlo.transpose"(%operand) {
  permutation = dense<[2, 1, 0]> : tensor<3xi64>
} : (tensor<2x3x2xi32>) -> tensor<2x3x2xi32>
// %result: [
//           [[1,7], [3,9], [5,11]],
//           [[2,8], [4,10], [6,12]]
//          ]

Contoh Lainnya

triangular_solve

Semantik

Menyelesaikan batch sistem persamaan linear dengan matriks koefisien segitiga bawah atau atas.

Secara lebih formal, dengan a dan b, result[i0, ..., iR-3, :, :] adalah solusi untuk op(a[i0, ..., iR-3, :, :]) * x = b[i0, ..., iR-3, :, :] jika left_side adalah true atau x * op(a[i0, ..., iR-3, :, :]) = b[i0, ..., iR-3, :, :] saat left_side adalah false, memecahkan variabel x dengan op(a) ditentukan oleh transpose_a, yang dapat berupa salah satu dari berikut:

  • NO_TRANSPOSE: Melakukan operasi menggunakan a apa adanya.
  • TRANSPOSE: Melakukan operasi pada transposisi a.
  • ADJOINT: Melakukan operasi pada transpose konjugasi a.

Data input hanya dibaca dari segitiga bawah a, jika lower adalah true, atau segitiga atas a, jika tidak. Data output ditampilkan dalam segitiga yang sama; nilai dalam segitiga lainnya ditentukan oleh implementasi.

Jika unit_diagonal bernilai benar, implementasi dapat mengasumsikan bahwa elemen diagonal a sama dengan 1, jika tidak, perilaku tidak ditentukan.

Untuk jenis terkuantisasi, jalankan dequantize_op_quantize(lambda x, y: triangular_solve(x, y, left_side, lower, unit_diagonal, transpose_a), a, b, type(result)).

Input

Label Nama Jenis Batasan
(I1) a Tensor floating-point atau tipe kompleks atau kuantisasi per-tensor (C1-C3)
(I2) b Tensor floating-point atau tipe kompleks atau kuantisasi per-tensor (C1-C4)
(I3) left_side konstanta jenis i1 (C3)
(I4) lower konstanta jenis i1
(I5) unit_diagonal konstanta jenis i1
(I6) transpose_a enum NO_TRANSPOSE, TRANSPOSE, dan ADJOINT

Output

Nama Jenis Batasan
result Tensor floating-point atau tipe kompleks atau kuantisasi per-tensor (C1)

Batasan

  • (C1) baseline_element_type(a) = baseline_element_type(b).
  • (C2) 2 <= rank(a) = rank(b) = R.
  • (C3) Hubungan antara shape(a) dan shape(b) ditentukan sebagai berikut:
    • shape(a)[:-3] = shape(b)[:-3].
    • dim(a, -2) = dim(a, -1) = dim(b, left_side ? -2 : -1).
  • (C4) baseline_type(b) = baseline_type(result).

Contoh

// %a = [
//       [1.0, 0.0, 0.0],
//       [2.0, 4.0, 0.0],
//       [3.0, 5.0, 6.0]
//      ]
// %b = [
//       [2.0, 0.0, 0.0],
//       [4.0, 8.0, 0.0],
//       [6.0, 10.0, 12.0]
//      ]
%result = "stablehlo.triangular_solve"(%a, %b) {
  left_side = true,
  lower = true,
  unit_diagonal = false,
  transpose_a = #stablehlo<transpose NO_TRANSPOSE>
} : (tensor<3x3xf32>, tensor<3x3xf32>) -> tensor<3x3xf32>
// %result: [
//           [2.0, 0.0, 0.0],
//           [0.0, 2.0, 0.0],
//           [0.0, 0.0, 2.0]
//          ]

tuple

Semantik

Menghasilkan tuple result dari nilai val.

Input

Label Nama Jenis Batasan
(I1) val jumlah nilai variadic (C1)

Output

Nama Jenis Batasan
result tuple (C1)

Batasan

  • (C1) result memiliki jenis tuple<E0, ..., EN-1> dengan Ei = type(val[i]).

Contoh

// %val0: [1.0, 2.0]
// %val1: (3)
%result = "stablehlo.tuple"(%val0, %val1) : (tensor<2xf32>, tuple<tensor<i32>>) -> tuple<tensor<2xf32>, tuple<tensor<i32>>>
// %result: ([1.0, 2.0], (3))

Contoh Lainnya

uniform_dequantize

Semantik

Melakukan konversi element-wise operand menjadi tensor floating point result sesuai dengan parameter kuantisasi yang ditentukan oleh jenis operand.

Secara lebih formal, result = dequantize(operand).

Input

Label Nama Jenis Batasan
(I1) operand Tensor terkuantisasi (C1), (C2)

Output

Nama Jenis Batasan
result Tensor tipe floating-point (C1), (C2)

Batasan

  • (C1) shape(operand) = shape(result).
  • (C2) element_type(result) = expressed_type(operand).

Contoh

// %operand: [10, 10]
%result = "stablehlo.uniform_dequantize"(%operand) : (tensor<2x!quant.uniform<i8:f32:0, {0.1:-30,0.5:-20}>>) -> tensor<2xf32>
// %result: [4.0, 15.0]

uniform_quantize

Semantik

Melakukan konversi berbasis elemen tensor floating point atau operand tensor terkuantisasi ke result kuantisasi energi sesuai dengan parameter kuantisasi yang ditentukan oleh jenis result.

Secara lebih formal,

  • Jika is_float(operand):
    • result = quantize(operand, type(result)).
  • Jika is_quantized(operand):
    • float_result = dequantize(operand).
    • result = quantize(float_result, type(result)).

Input

Label Nama Jenis Batasan
(I1) operand Tensor floating point atau tipe terkuantisasi (C1), (C2)

Output

Nama Jenis Batasan
result Tensor terkuantisasi (C1), (C2)

Batasan

  • (C1) shape(operand) = shape(result).
  • (C2) expressed_type(result) = is_float(operand) ? element_type(operand) : expressed_type(operand).

Contoh

// %operand: [4.0, 15.0]
%result = "stablehlo.uniform_quantize"(%operand) : (tensor<2xf32>) -> tensor<2x!quant.uniform<i8:f32:0, {0.1:-30,0.5:-20}>>
// %result: [10, 10]

// %operand: [10, 10]
%result = "stablehlo.uniform_quantize"(%operand) : (tensor<2x!quant.uniform<i8:f32:0, {0.1:-30,0.5:-20}>>) -> tensor<2x!quant.uniform<i8:f32:0, {0.1:-20,0.2:-30}>>
// %result: [20, 45]

saat

Semantik

Menghasilkan output dari menjalankan fungsi body 0 kali atau lebih, sementara fungsi cond menghasilkan true. Secara lebih formal, semantik dapat dinyatakan menggunakan sintaksis Python sebagai berikut:

internal_state = operand
while cond(*internal_state):
  internal_state = body(*internal_state)
results = internal_state

Perilaku loop tanpa henti akan ditentukan (#383).

Input

Label Nama Jenis Batasan
(I1) operand bilangan variadic, quantized TensorFlow, atau token (C1-C3)
(I2) cond fungsi (C1)
(I3) body fungsi (C2)

Output

Nama Jenis Batasan
results bilangan variadic, quantized TensorFlow, atau token (C3)

Batasan

  • (C1) cond memiliki jenis (T0, ..., TN-1) -> tensor<i1>, dengan Ti = type(operand[i]).
  • (C2) body memiliki jenis (T0, ..., TN-1) -> (T0, ..., TN-1), dengan Ti = type(operand[i]).
  • (C3) type(results...) = type(operand...).

Contoh

// %init_i: 1
// %init_sum: 0
// %one: 1
// %ten: 10
%results0, %results1 = "stablehlo.while"(%init_i, %init_sum) ({
  ^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
    %cond = "stablehlo.compare"(%arg0, %ten) {
      comparison_direction = #stablehlo<comparison_direction LT>
    } : (tensor<i64>, tensor<i64>) -> tensor<i1>
    stablehlo.return %cond : tensor<i1>
  }, {
  ^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
    %new_sum = stablehlo.add %arg1, %one : tensor<i64>
    %new_i = stablehlo.add %arg0, %one : tensor<i64>
    stablehlo.return %new_i, %new_sum : tensor<i64>, tensor<i64>
}) : (tensor<i64>, tensor<i64>) -> (tensor<i64>, tensor<i64>)
// %results0: 10
// %results1: 10

Contoh Lainnya

Xor

Semantik

Menjalankan XOR berbasis elemen dari dua TensorFlow lhs dan rhs serta menghasilkan tensor result. Bergantung pada jenis elemen, lakukan hal berikut:

  • Untuk boolean: XOR logis.
  • Untuk bilangan bulat: bitwise XOR.

Input

Label Nama Jenis Batasan
(I1) lhs Tensor tipe boolean atau integer (C1)
(I2) rhs Tensor tipe boolean atau integer (C1)

Output

Nama Jenis Batasan
result Tensor tipe boolean atau integer (C1)

Batasan

  • (C1) type(lhs) = type(rhs) = type(result).

Contoh

// Bitwise operation with with integer tensors
// %lhs: [[1, 2], [3, 4]]
// %rhs: [[5, 6], [7, 8]]
%result = "stablehlo.xor"(%lhs, %rhs) : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
// %result: [[4, 4], [4, 12]]

// Logical operation with with boolean tensors
// %lhs: [[false, false], [true, true]]
// %rhs: [[false, true], [false, true]]
%result = "stablehlo.xor"(%lhs, %rhs) : (tensor<2x2xi1>, tensor<2x2xi1>) -> tensor<2x2xi1>
// %result: [[false, true], [true, false]]

Eksekusi

Eksekusi berurutan

Program StableHLO dijalankan dengan memberikan nilai input ke fungsi main dan menghitung nilai output. Nilai output fungsi dihitung dengan mengeksekusi grafik operasi yang di-root pada op return yang sesuai.

Urutan eksekusi ditentukan implementasinya selama selaras dengan dataflow, yaitu jika op dijalankan sebelum digunakan. Di StableHLO, semua operasi yang memberikan dampak samping menggunakan satu token dan menghasilkan satu token (beberapa token dapat di-multiplex menjadi satu token melalui after_all), sehingga urutan eksekusi efek samping juga selaras dengan dataflow. Kemungkinan urutan eksekusi dari contoh program di atas adalah %0%1%2%3%4return atau %3%0%1%2%4return.

Secara lebih formal, proses StableHLO adalah kombinasi dari: 1) program StableHLO, 2) status operasi (belum dijalankan, sudah dieksekusi), dan 3) nilai perantara yang sedang diproses oleh proses tersebut. Proses ini dimulai dengan nilai input ke fungsi main, berlanjut melalui grafik operasi yang memperbarui status operasi dan nilai menengah, serta selesai dengan nilai output. Formasi lebih lanjut akan ditentukan (#484).

Eksekusi paralel

Program StableHLO dapat dijalankan secara paralel, yang disusun ke dalam petak proses 2D num_replicas oleh num_partitions yang keduanya memiliki jenis ui32.

Di petak proses StableHLO, num_replicas * num_partitions proses StableHLO dijalankan secara bersamaan. Setiap proses memiliki process_id = (replica_id, partition_id) unik, dengan replica_id di replica_ids = range(num_replicas) dan partition_id di partition_ids = range(num_partitions) yang keduanya memiliki jenis ui32.

Ukuran petak proses diketahui secara statis untuk setiap program (di masa mendatang, kami berencana untuk menjadikannya sebagai bagian eksplisit dari program StableHLO #650), dan posisi dalam petak proses diketahui secara statis untuk setiap proses. Setiap proses memiliki akses ke posisinya dalam petak proses melalui operasi replica_id dan partition_id.

Dalam petak proses, semua program bisa sama (dalam gaya "Satu Program, Beberapa Data"), semuanya dapat berbeda (dalam gaya "Beberapa Program, Multi-Data") atau apa pun di antaranya. Di masa mendatang, kami berencana untuk memperkenalkan dukungan bagi idiom lain dalam menentukan program StableHLO paralel, termasuk GSPMD (#619).

Dalam petak proses, sebagian besar proses bersifat independen satu sama lain. Proses tersebut memiliki status operasi terpisah, nilai input/menengah/output terpisah, dan sebagian besar operasi dijalankan secara terpisah antarproses, dengan pengecualian sejumlah kecil operasi kolektif yang dijelaskan di bawah ini.

Mengingat bahwa eksekusi sebagian besar operasi hanya menggunakan nilai dari proses yang sama, biasanya nilai ini akan dianggap tidak ambigu. Namun, saat mendeskripsikan semantik operasi kolektif, hal itu tidak mencukupi, dan menyebabkan notasi name@process_id untuk merujuk pada nilai name dalam proses tertentu. (Dari perspektif tersebut, name yang tidak memenuhi syarat dapat dilihat sebagai singkatan untuk name@(replica_id(), partition_id())).

Urutan eksekusi di seluruh proses ditentukan oleh implementasi, kecuali untuk sinkronisasi yang diperkenalkan oleh komunikasi titik ke titik dan operasi kolektif seperti yang dijelaskan di bawah ini.

Komunikasi point-to-point

Proses StableHLO dapat berkomunikasi satu sama lain melalui saluran StableHLO. Saluran diwakili oleh ID positif jenis si64. Melalui berbagai operasi, Anda dapat mengirim nilai ke saluran dan menerimanya dari saluran.

Formalisasi lebih lanjut, misalnya asal ID saluran ini, cara program mengenalinya, dan jenis sinkronisasi yang diperkenalkan oleh ID tersebut, akan ditentukan (#484).

Komunikasi streaming

Setiap proses StableHLO memiliki akses ke dua antarmuka streaming:

  • Infeed yang dapat dibaca.
  • Outfeed yang dapat ditulisi.

Tidak seperti saluran, yang digunakan untuk berkomunikasi antar-proses dan karenanya memiliki proses di kedua ujungnya, infeed dan outfeed memiliki implementasi akhir lainnya yang ditentukan.

Formalisasi lebih lanjut, misalnya pengaruh komunikasi streaming terhadap urutan eksekusi dan jenis sinkronisasi yang diperkenalkan olehnya, akan ditentukan (#484).

Operasi kolektif

Ada enam operasi kolektif di StableHLO: all_gather, all_reduce, all_to_all, collective_broadcast, collective_permute, dan reduce_scatter. Semua operasi ini membagi proses dalam petak proses StableHLO menjadi grup proses StableHLO dan menjalankan komputasi bersama dalam setiap grup proses, secara terpisah dari grup proses lainnya.

Dalam setiap grup proses, operasi kolektif dapat menimbulkan penghalang sinkronisasi. Formalisasi lebih lanjut, misalnya menjelaskan kapan tepatnya sinkronisasi ini terjadi, bagaimana proses yang sebenarnya sampai pada batasan ini, dan apa yang terjadi jika tidak, akan ditentukan (#484).

Jika grup proses melibatkan komunikasi lintas partisi, yaitu ada proses dalam grup proses yang ID partisinya berbeda, eksekusi operasi kolektif memerlukan saluran, dan operasi kolektif harus memberikan channel_id positif dari jenis si64. Komunikasi lintas replika tidak memerlukan saluran.

Komputasi yang dilakukan oleh operasi kolektif bersifat khusus untuk operasi individual dan dijelaskan di bagian operasi individual di atas. Namun, strategi yang membagi petak proses menjadi beberapa grup proses dibagikan di antara operasi tersebut dan dijelaskan di bagian ini. Secara lebih formal, StableHLO mendukung empat strategi berikut.

cross_replica

Hanya komunikasi lintas replika yang terjadi dalam setiap grup proses. Strategi ini mengambil replica_groups - daftar ID replika - dan menghitung produk Kartesius replica_groups dengan partition_ids. replica_groups harus memiliki elemen unik dan mencakup semua replica_ids. Secara lebih formal, dengan menggunakan sintaksis Python:

def cross_replica(replica_groups: List[List[ReplicaId]]) -> List[List[ProcessId]]:
  for replica_group in replica_groups:
    for partition_id in partition_ids:
      process_group = []
      for replica_id in replica_group:
        process_group.append((replica_id, partition_id))
      yield process_group

Misalnya, untuk replica_groups = [[0, 1], [2, 3]] dan num_partitions = 2, cross_replica akan menghasilkan [[(0, 0), (1, 0)], [(0, 1), (1, 1)], [(2, 0), (3, 0)], [(2, 1), (3, 1)]].

cross_partition

Hanya komunikasi lintas partisi yang terjadi dalam setiap grup proses. Strategi ini mengambil partition_groups - daftar ID partisi - dan menghitung hasil kali Kartesius partition_groups dengan replica_ids. partition_groups harus memiliki elemen unik dan mencakup semua partition_ids. Secara lebih formal, menggunakan sintaksis Python:

def cross_partition(partition_groups: List[List[PartitionId]]) -> List[List[ProcessId]]:
  for partition_group in partition_groups:
    for replica_id in replica_ids:
      process_group = []
      for partition_id in partition_group:
        process_group.append((replica_id, partition_id))
      yield process_group

Misalnya, untuk partition_groups = [[0, 1]] dan num_replicas = 4, cross_partition akan menghasilkan [[(0, 0), (0, 1)], [(1, 0), (1, 1)], [(2, 0), (2, 1)], [(3, 0), (3, 1)]].

cross_replica_and_partition

Komunikasi lintas replika dan lintas partisi dapat terjadi dalam setiap grup proses. Strategi ini menggunakan replica_groups - daftar ID replika - dan menghitung produk Kartesius dari setiap replica_group dengan partition_ids. replica_groups harus memiliki elemen unik dan mencakup semua replica_ids. Secara lebih formal, menggunakan sintaksis Python:

def cross_replica_and_partition(replica_groups: List[List[ReplicaId]]) -> List[List[ProcessId]]:
  for replica_group in replica_groups:
    process_group = []
    for partition_id in partition_ids:
      for replica_id in replica_group:
        process_group.append((replica_id, partition_id))
    yield process_group

Misalnya, untuk replica_groups = [[0, 1], [2, 3]] dan num_partitions = 2, cross_replica_and_partition akan menghasilkan [[(0, 0), (1, 0), (0, 1), (1, 1)], [(2, 0), (3, 0), (2, 1), (3, 1)]].

flattened_ids

Strategi ini mengambil flattened_id_groups - daftar ID proses yang "diratakan" dalam bentuk replica_id * num_partitions + partition_id - dan mengubahnya menjadi ID proses. flattened_id_groups harus memiliki elemen yang unik dan mencakup semua process_ids. Secara lebih formal, menggunakan sintaksis Python:

def flattened_ids(flattened_id_groups: List[List[ui32]]) -> List[List[ProcessId]]:
  for flattened_id_group in flattened_id_groups:
    process_group = []
    for flattened_id in flattened_id_group:
      replica_id = flattened_id // num_partitions
      partition_id = flattened_id % num_partitions
      process_group.append((replica_id, partition_id))
    yield process_group

Misalnya, untuk flattened_id_groups = [[0, 1, 2, 3], [4, 5, 6, 7]], num_replicas = 4, dan num_partitions = 2, flattened_ids akan menghasilkan [[(0, 0), (0, 1), (1, 0), (1, 1)], [(2, 0), (2, 1), (3, 0), (3, 1)]].

Akurasi

Saat ini, StableHLO tidak memberikan jaminan tentang akurasi numerik, tetapi hal ini dapat berubah pada masa mendatang (#1156).

Error

Program StableHLO divalidasi melalui serangkaian batasan untuk setiap operasi, yang mengesampingkan banyak kelas error sebelum runtime. Namun, kondisi error masih mungkin terjadi, misalnya melalui overflow integer, akses di luar batas, dll. Kecuali jika dipanggil secara eksplisit, semua error ini menghasilkan perilaku yang ditentukan implementasi, tetapi ini dapat berubah di masa mendatang (#1157).

Sebagai pengecualian untuk aturan ini, pengecualian floating point dalam program StableHLO memiliki perilaku yang ditetapkan dengan baik. Operasi yang menghasilkan pengecualian yang ditentukan oleh standar IEEE-754 (operasi tidak valid, pembagian dengan nol, overflow, underflow, atau pengecualian tidak tepat) memberikan hasil default (seperti yang ditentukan dalam standar) dan melanjutkan eksekusi tanpa menaikkan flag status yang sesuai; mirip dengan penanganan pengecualian raiseNoFlag dari standar. Pengecualian untuk operasi nonstandar (misalnya, aritmetika kompleks dan fungsi transendental tertentu) ditentukan oleh implementasi.

Notasi

Untuk mendeskripsikan sintaksis, dokumen ini menggunakan ragam ISO dari sintaksis EBNF yang dimodifikasi (ISO/IEC 14977:1996, Wikipedia), dengan dua modifikasi: 1) aturan ditentukan menggunakan ::=, bukan =,

2) penyambungan dinyatakan menggunakan penjajaran, bukan ,.

Untuk mendeskripsikan semantik (yaitu dalam bagian "Jenis", "Konstanta", dan "Ops"), kami menggunakan formula yang didasarkan pada sintaksis Python yang diperluas dengan dukungan untuk mengekspresikan operasi array secara singkat seperti yang dijelaskan di bawah ini. Ini berfungsi dengan baik untuk cuplikan kode kecil, tetapi dalam kasus yang jarang terjadi ketika cuplikan kode yang lebih besar diperlukan, kami menggunakan sintaksis Python vanilla yang selalu diperkenalkan secara eksplisit.

Formula

Mari kita pelajari cara kerja formula berdasarkan contoh dari spesifikasi dot_general. Salah satu batasan untuk operasi ini terlihat sebagai berikut: dim(lhs, lhs_batching_dimensions...) = dim(rhs, rhs_batching_dimensions...).

Nama-nama yang digunakan dalam formula ini berasal dari dua sumber: 1) fungsi global, yaitu dim, 2) definisi anggota dari elemen program yang sesuai, yaitu input lhs, lhs_batching_dimensions, rhs, dan rhs_batching_dimensions yang ditentukan di bagian "Input" pada dot_general.

Seperti yang disebutkan di atas, sintaksis formula ini berbasis Python dengan beberapa ekstensi berorientasi keringkasan. Untuk memahami formula itu, mari kita ubah menjadi sintaks{i> vanilla Python<i}.

A) Dalam formula ini, kita menggunakan = untuk mewakili kesetaraan sehingga langkah pertama untuk mendapatkan sintaksis Python adalah mengganti = dengan ==, sebagai berikut: dim(lhs, lhs_batching_dimensions...) == dim(rhs, rhs_batching_dimensions...).

B) Selain itu, formula ini mendukung elipsis (...) yang mengubah ekspresi skalar menjadi ekspresi TensorFlow. Singkatnya, f(xs...) secara kasar berarti "untuk setiap x skalar pada TensorFlow xs, menghitung f(x) skalar, lalu menampilkan semua hasil skalar ini bersama-sama sebagai hasil TensorFlow". Dalam sintaksis Python vanila, contoh formula kami berubah menjadi: [dim(lhs, dim1) for dim1 in lhs_batching_dimensions] == [dim(rhs, dim2) for dim2 in rhs_batching_dimensions].

Berkat elipsis, menghindari bekerja pada tingkat skalar individual sering kali dapat dilakukan. Namun, dalam beberapa kasus yang rumit, sintaksis semi-informal tingkat rendah dapat digunakan seperti dalam formula start_indices[bi0, ..., :, ..., biN] dari spesifikasi gather. Dalam hal keringkasan, kami tidak memberikan formalisme yang tepat untuk menerjemahkan sintaksis semacam itu ke vanilla Python, dengan harapan bahwa sintaksis tersebut masih dapat dipahami secara intuitif berdasarkan kasus per kasus. Harap beri tahu kami jika beberapa formula tertentu terlihat buram, dan kami akan mencoba meningkatkannya.

Selain itu, Anda akan melihat bahwa formula menggunakan elipsis untuk memperluas semua jenis daftar, termasuk tensor, daftar Tensor (yang bisa muncul dari jumlah tensor variadik), dll. Ini adalah area lain tempat kami tidak memberikan formalisme yang tepat (misalnya, daftar bahkan bukan bagian dari sistem jenis StableHLO) dan sebagai gantinya mengandalkan pemahaman intuitif.

C) Alat notasi terakhir yang penting dan kami gunakan adalah penyiaran implisit. Meskipun opset StableHLO tidak mendukung penyiaran implisit, formula juga mendukung keringkasan. Intinya, jika skalar digunakan dalam konteks di mana TensorFlow diharapkan, skalar akan disiarkan ke bentuk yang diharapkan.

Untuk melanjutkan contoh dot_general, berikut batasan lainnya: 0 <= lhs_batching_dimensions < rank(lhs). Seperti yang ditentukan dalam spesifikasi dot_general, lhs_batching_dimensions adalah TensorFlow, tetapi 0 dan rank(lhs) adalah skalar. Setelah kita menerapkan siaran implisit, formula akan menjadi [0, ..., 0] <= lhs_batching_dimensions < [rank(lhs), ..., rank(lhs)].

Jika diterapkan pada operasi dot_general tertentu, formula ini akan dievaluasi menjadi TensorFlow boolean. Saat formula digunakan sebagai batasan, batasan akan berlaku jika formula dievaluasi ke true atau ke TensorFlow yang hanya memiliki elemen true.

Nama

Dalam formula, ruang lingkup leksikal meliputi: 1) fungsi global, 2) definisi anggota,

3) definisi lokal. Daftar fungsi global disediakan di bawah ini. Daftar definisi elemen bergantung pada elemen program tempat notasi diterapkan:

  • Untuk operasi, definisi anggota mencakup nama yang diperkenalkan di bagian "Input" dan "Output".
  • Untuk lainnya, definisi anggota mencakup bagian struktural elemen program, dinamai menurut non-terminal EBNF terkait. Biasanya, nama bagian struktural ini diperoleh dengan mengonversi nama non-terminal menjadi snake case (mis. IntegerLiteral => integer_literal), tetapi terkadang nama disingkat dalam proses (misalnya QuantizationStorageType => storage_type) yang dalam hal ini nama diperkenalkan secara eksplisit mirip dengan bagian "Inputs" / "Outputs" dalam spesifikasi operasi.
  • Selain itu, definisi anggota selalu menyertakan self untuk merujuk ke elemen program yang sesuai.

Nilai

Saat dievaluasi, formula akan berfungsi dengan jenis nilai berikut: 1) Value (nilai sebenarnya, misalnya dense<[[1, 2], [3, 4]]> : tensor<2x2xi32>; nilai tersebut selalu mengetahui jenisnya), 2) Placeholder (nilai mendatang, misalnya lhs, rhs, atau result; nilai sebenarnya belum diketahui, hanya jenisnya yang diketahui), 3) Type (jenis seperti yang ditentukan di bagian "Jenis"), 4) Function (fungsi global" seperti yang ditentukan di bagian "Jenis). 4) Function (fungsi global" seperti yang ditentukan di bagian "Jenis").

Bergantung pada konteksnya, nama mungkin merujuk ke nilai yang berbeda. Lebih khususnya, bagian "Semantik" untuk operasi (dan yang setara untuk elemen program lainnya) menentukan logika runtime, sehingga semua input tersedia sebagai Value. Sebaliknya, bagian "Batasan" untuk operasi (dan yang setara) menentukan logika "waktu kompilasi", yaitu sesuatu yang biasanya dieksekusi sebelum runtime, sehingga hanya input konstan yang tersedia sebagai Value dan input lainnya hanya tersedia sebagai Placeholder.

Nama Di "Semantik" Di "Constraints"
Fungsi global Function Function
Input konstan Value Value
Input non-konstanta Value Placeholder
Output Value Placeholder
Definisi lokal Bergantung pada definisi Bergantung pada definisi

Mari kita perhatikan contoh operasi transpose:

%result = "stablehlo.transpose"(%operand) {
  permutation = dense<[2, 1, 0]> : tensor<3xi64>
} : (tensor<2x3x2xi32>) -> tensor<2x3x2xi32>

Untuk operasi ini, permutation adalah konstanta, sehingga tersedia sebagai Value dalam semantik dan batasan. Sebaliknya, operand dan result tersedia sebagai Value dalam semantik, tetapi hanya sebagai Placeholder dalam batasan.

Fungsi

Konstruksi jenis

Tidak ada fungsi yang dapat digunakan untuk membuat jenis. Sebaliknya, kita langsung menggunakan sintaksis jenis karena biasanya lebih ringkas. Misalnya, (tensor<E>, tensor<E>) -> (tensor<E>), bukan function_type( [tensor_type([], E), tensor_type([], E)], [tensor_type([], E)]).

Fungsi pada jenis

  • element_type ditentukan pada jenis TensorFlow, jenis Quantized, dan menampilkan, masing-masing, bagian TensorElementType atau QuantizedTensorElementType dari TensorType atau QuantizedTensorType yang terkait.
def element_type(x: Value | Placeholder | Type):
 if type(x) == TensorType:
    return tensor_element_type(x)
  if type(x) == QuantizedTensorType:
    return quantized_tensor_element_type(x)
  if type(x) is not Type:
    return element_type(type(x))
  • is_per_axis_quantized(x: Value | Placeholder | Type) -> Value adalah pintasan untuk is_quantized(x) and quantization_dimension(x) is not None.

  • is_per_tensor_quantized(x: Value | Placeholder | Type) -> Value adalah pintasan untuk is_quantized(x) and quantization_dimension(x) is None.

  • is_promotable(x: Type, y: Type) -> bool memeriksa apakah jenis x dapat dipromosikan ke jenis y. Jika x dan y adalah QuantizedTensorElementType, promosi hanya berlaku untuk storage_type. Versi promosi khusus ini saat ini digunakan dalam konteks komputasi pengurangan (lihat RFC untuk detail selengkapnya).

def is_promotable(x: Type, y: Type) -> Value:
  is_same_type = (is_bool(x) and is_bool(y)) or
    (is_integer(x) and is_integer(y)) or (is_float(x) and is_float(y)) or
    (is_complex(x) and is_complex(y)) or
    (is_quantized(x) and is_quantized(y) and expressed_type(x) = expressed_type(y))

  if is_same_type == False:
    return False

  if is_integer(x) or is_float(x):
    return bitwidth(x) <= bitwidth(y)

  if is_complex(x):
    return bitwidth(element_type(x)) <= bitwidth(element_type(y))

  if is_quantized(x):
    return bitwidth(storage_type(x)) <= bitwidth(storage_type(y))

  return false
  • is_quantized(x: Value | Placeholder | Type) -> Value adalah pintasan untuk is_quantized_tensor_element_type(x).

  • is_type_name(x: Value | Placeholder | Type) -> Value. Tersedia untuk semua jenis. Misalnya, is_float(x) akan menampilkan true jika x adalah FloatType. Jika x adalah nilai atau placeholder, fungsi ini adalah pintasan untuk is_type_name(type(x)).

  • max_value(x: Type) -> Value menampilkan nilai maksimum TensorElementType. Jika x bukan TensorElementType, tampilkan None.

  • min_value(x: Type) -> Value menampilkan nilai minimum TensorElementType yang memungkinkan. Jika x bukan TensorElementType, tampilkan None.

  • member_name(x: Value | Placeholder | Type) -> Any. Tersedia untuk semua definisi anggota member_name dari semua jenis. Misalnya, tensor_element_type(x) menampilkan bagian TensorElementType dari TensorType yang sesuai. Jika x adalah nilai atau placeholder, fungsi ini adalah pintasan untuk member_name(type(x)). Jika x bukan jenis yang memiliki anggota yang sesuai, atau nilai atau placeholder dari jenis tersebut, None akan ditampilkan.

Konstruksi nilai

  • operation_name(*xs: Value | Type) -> Value. Tersedia untuk semua operasi. Misalnya, add(lhs, rhs) mengambil dua nilai Tensor lhs dan rhs, lalu menampilkan output mengevaluasi operasi add dengan input ini. Untuk beberapa operasi, misalnya broadcast_in_dim, jenis output-nya adalah "load-bearing", yaitu diperlukan untuk mengevaluasi operasi. Dalam hal ini, fungsi menggunakan jenis ini sebagai argumen.

Fungsi pada nilai

  • Semua operator dan fungsi Python tersedia. Misalnya, notasi subscription dan slicing dari Python tersedia untuk diindeks ke dalam TensorFlow, quantized Tensor, dan tuple.

  • to_destination_type(x: Value, destination_type: Type) -> Value ditentukan pada tensor dan menampilkan nilai x yang dikonversi berdasarkan type(x) dan destination_type sebagai berikut:

def to_destination_type(x: Value, destination_type: Type) -> Value:
  if type(x) == destination_type:
    return x

  if is_quantized(destination_type):
    if is_quantized(type(x)):
      return quantize(x, destination_type)
    assert is_float(type(x))
    return quantize(x, destination_type)

  if is_quantized(type(x)):
    assert destination_type = expressed_type(type(x))
    return dequantize(type(x))

  return convert(x, destination_type)

Ada diskusi awal tentang menggabungkan operasi convert, uniform_quantize, dan uniform_dequantize (#1576). Setelah penggabungan, kita tidak memerlukan fungsi di atas dan dapat menggunakan nama operasi untuk convert sebagai gantinya.

  • is_nan(x: Value) -> Value ditentukan pada TensorFlow dan menampilkan true jika semua elemen x adalah NaN atau false. Jika x bukan TensorFlow, akan menampilkan None.

  • is_sorted(x: Value) -> Value ditentukan pada TensorFlow dan menampilkan true jika elemen x diurutkan dalam urutan menaik sehubungan dengan urutan leksikografis menaik dari indeksnya atau false jika tidak. Jika x bukan tensor, tampilkan None.

  • is_unique(x: Value) -> Value ditentukan pada TensorFlow dan menampilkan true jika x tidak memiliki elemen duplikat atau false jika tidak. Jika x bukan TensorFlow, akan menampilkan None.

  • member_name(x: Value) -> Any ditentukan untuk semua definisi anggota member_name dari semua nilai. Misalnya, real_part(x) menampilkan bagian RealPart dari ComplexConstant yang sesuai. Jika x bukan nilai yang memiliki anggota yang sesuai, tampilkan None.

  • same(x: Value) -> Value ditentukan pada TensorFlow dan menampilkan true jika elemen x sama satu sama lain atau false jika tidak. Jika Tensor tidak memiliki elemen, hal ini dianggap sebagai "semua sama satu sama lain", yaitu fungsi akan menampilkan true. Jika x bukan TensorFlow, tampilkan None.

  • split(x: Value, num_results: Value, axis: Value) -> Value ditentukan pada tensor dan menampilkan irisan num_results dari x di sepanjang sumbu axis. Jika x bukan TensorFlow atau dim(x, axis) % num_results != 0, tampilkan None.

Komputasi bentuk

  • axes(x: Value | Placeholder | Type) -> Value adalah pintasan untuk range(rank(x)).

  • dim(x: Value | Placeholder | Type, axis: Value) -> Value adalah pintasan untuk shape(x)[axis].

  • dims(x: Value | Placeholder | Type, axes: List) -> List adalah pintasan untuk list(map(lambda axis: dim(x, axis), axes)).

  • index_space(x: Value | Placeholder | Type) -> Value ditentukan pada TensorFlow dan menampilkan indeks size(x) untuk TensorType yang sesuai yang diurutkan dalam urutan leksikografis menaik, yaitu [0, ..., 0], [0, ..., 1], ..., shape(x) - 1. Jika x bukan jenis TensorFlow, jenis TensorFlow terkuantisasi, atau nilai atau placeholder dari salah satu jenis ini, akan menampilkan None.

  • rank(x: Value | Placeholder | Type) -> Value adalah pintasan untuk size(shape(x)).

  • shape(x: Value | Placeholder | Type) -> Value ditentukan di bagian "Fungsi pada jenis" melalui member_name.

  • size(x: Value | Placeholder | Type) -> Value adalah pintasan untuk reduce(lambda x, y: x * y, shape(x)).

Komputasi kuantisasi

  • def baseline_element_type(x: Value | Placeholder | Type) -> Type adalah pintasan untuk element_type(baseline_type(x)).

  • baseline_type ditentukan pada jenis TensorFlow dan jenis TensorFlow terkuantisasi, lalu mengubahnya menjadi "baseline", yaitu jenis dengan bentuk yang sama tetapi dengan parameter kuantisasi jenis elemen yang direset ke nilai default. Ini digunakan sebagai trik praktis untuk membandingkan jenis TensorFlow dan Tensor terkuantisasi secara seragam, yang cukup sering diperlukan. Untuk jenis terkuantisasi, hal ini memungkinkan pembandingan jenis yang mengabaikan parameter kuantisasi, yaitu, shape, storage_type, expressed_type, storage_min, storage_max, dan quantization_dimension (untuk jenis terkuantisasi per sumbu) semuanya harus cocok, tetapi scales dan zero points mungkin berbeda.

def baseline_type(x: Value | Placeholder | Type) -> Type:
  if type(x) == TensorType:
    return x
  if type(x) == QuantizedTensorType:
    element_type = quantized_tensor_element_type(x)
    baseline_element_type = QuantizedTensorElementType(
      storage_type = storage_type(element_type),
      storage_min = storage_min(element_type),
      storage_max = storage_max(element_type),
      expressed_type = expressed_type(element_type),
      quantization_dimension = quantization_dimension(element_type),
      scales = [constant(1.0, expressed_type(element_type))] * dim(x, quantization_dimension(element_type)),
      zero_points = [constant(0, storage_type(element_type))] * dim(x, quantization_dimension(element_type)))
    return QuantizedTensorType(shape(x), baseline_element_type)
  if type(x) is not Type:
    return baseline_element_type(type(x))
  • dequantize ditentukan pada jenis TensorFlow terkuantisasi dan mengubahnya menjadi jenis TensorFlow floating point. Hal ini terjadi dengan mengonversi elemen terkuantisasi yang mewakili nilai bilangan bulat jenis penyimpanan menjadi nilai floating point yang sesuai dari jenis yang dinyatakan menggunakan titik nol dan skala yang terkait dengan jenis elemen terkuantisasi.
def compute_zero_points(quantized_type, result_type):
  if is_per_tensor_quantized(quantized_type):
    return broadcast_in_dim(constant(zero_point(quantized_type), storage_type(quantized_type)), [], result_type)
  if is_per_axis_quantized(quantized_type):
    for i in index_space(result_type):
      d = quantization_dimension(quantized_type)
      zero_points[i] = zero_points(quantized_type)[i[d]]
    return zero_points

def compute_scales(quantized_type, result_type):
  if is_per_tensor_quantized(quantized_type):
    return broadcast_in_dim(constant(scale(quantized_type), expressed_type(quantized_type)), [],
            type(result_type))
  if is_per_axis_quantized(quantized_type):
    for i in index_space(result_type):
      d = quantization_dimension(quantized_type)
      scales[i] = scales(quantized_type)[i[d]]
    return scales

def dequantize(x: Value) -> Value:
  assert is_quantized(x)
  x_storage = bitcast_convert(x, storage_type(x))
  x_storage_sub = x_storage - compute_zero_points(type(x), type(x_storage))
  x_expressed_sub = convert(x_storage_sub, expressed_type(x))
  return x_expressed_sub * compute_scales(type(x), type(x_expressed_sub))
  • quantize ditentukan pada jenis TensorFlow floating point dan mengubahnya menjadi jenis TensorFlow terkuantisasi. Hal ini terjadi dengan mengonversi nilai floating point dari jenis yang dinyatakan menjadi nilai bilangan bulat jenis penyimpanan yang sesuai menggunakan titik nol dan skala yang terkait dengan jenis elemen terkuantisasi.
def quantize(x: Value, type: Type) -> Value:
  assert is_float(x) and is_quantized(type)
  x_expressed_rounded = round_nearest_even(x / compute_scales(type, type(x)))
  x_storage_rounded = convert(x_expressed_rounded, storage_type(type))
  x_storage_add = x_storage_rounded + compute_zero_points(type, type(x_storage_rounded))
  x_storage = clamp(storage_min(type), x_storage_add, storage_max(type))
  return bitcast_convert(x_storage, type)
  • dequantize_op_quantize digunakan untuk menentukan komputasi element-wise pada tensor terkuantisasi. Proses men-dekuantisasi, yaitu mengubah elemen terkuantisasi menjadi jenis yang diekspresikan, lalu menjalankan operasi, dan kemudian melakukan kuantisasi, yaitu mengubah hasil kembali menjadi jenis penyimpanannya. Saat ini, fungsi ini hanya berfungsi untuk kuantisasi per-tensor. Kuantisasi per sumbu sedang dalam proses (#1574).
def dequantize_op_quantize(op, *inputs_and_output_type):
  inputs = inputs_and_output_type[:-1]
  output_type = inputs_and_output_type[-1]

  float_inputs = map(dequantize, inputs)
  float_result = op(*float_inputs)
  return quantize(float_result, output_type)

def dequantize_batch_norm_grad_or_training_quantize(op, *inputs_and_output_types):
  inputs = inputs_and_output_type[:-3]
  float_inputs = map(dequantize, inputs)
  float_results = op(*float_inputs)
  return map(quantize, float_results, inputs_and_output_type[-3:])

def dequantize_compare(lhs, rhs, comparison_direction):
  float_lhs = dequantize(lhs)
  float_rhs = dequantize(rhs)
  return compare(float_lhs, float_rhs, comparison_direction, FLOAT)

def dequantize_select_quantize(pred, on_true, on_false, output_type):
  float_on_true = dequantize(on_true)
  float_on_false = dequantize(on_false)
  float_result = select(pred, float_on_true, float_on_false)
  return quantize(float_result, output_type)

Komputasi petak

  • cross_partition(replica_groups: Value) -> Value. Lihat bagian "cross_replica" di atas.

  • cross_replica(replica_groups: Value) -> Value. Lihat bagian "cross_replica" di atas.

  • cross_replica_and_partition(replica_groups: Value) -> Value. Lihat bagian "cross_replica_and_partition" di atas.

  • flattened_ids(replica_groups: Value) -> Value. Lihat bagian "flattened_ids" di atas.