StableHLO, makine öğrenimi (ML) modellerinde üst düzey işlemler (HLO) için kullanılan bir işlem kümesidir. StableHLO, farklı makine öğrenimi çerçeveleri ve ML derleyicileri arasında taşınabilirlik katmanı olarak çalışır: StableHLO programları üreten ML çerçeveleri, StableHLO programlarını kullanan ML derleyicileriyle uyumludur.
Amacımız, çeşitli makine öğrenimi çerçeveleri (ör. TensorFlow, JAX ve PyTorch) ile ML derleyicileri (ör. XLA ve IREE) arasında daha fazla birlikte çalışabilirlik sağlayarak makine öğrenimi geliştirme sürecini basitleştirmek ve hızlandırmaktır. Bu amaçla, bu belgede StableHLO programlama dili için spesifikasyon sağlanmaktadır.
Bu spesifikasyon üç ana bölümden oluşur. İlk olarak Programlar bölümünde, StableHLO işlevlerinden oluşan StableHLO programlarının yapısı açıklanmaktadır. Bunlar StableHLO işlemlerinden oluşur. Bu yapı içinde, İşlemler bölümü her bir işlemin anlamını belirtir. Yürütme bölümü, bir program içinde birlikte yürütülen tüm bu işlemlerin anlamlarını içerir. Son olarak, Not bölümünde spesifikasyon boyunca kullanılan not açıklanmaktadır.
Programlar
Program ::= {Func}
StableHLO programları, isteğe bağlı sayıda StableHLO işlevinden oluşur.
Aşağıda, 3 giriş (%image
, %weights
ve %bias
) ve 1 çıkışa sahip bir @main
işlevine sahip örnek program gösterilmektedir. İşlevin gövdesinde
6 işlem vardır.
func.func @main(
%image: tensor<28x28xf32>,
%weights: tensor<784x10xf32>,
%bias: tensor<1x10xf32>
) -> tensor<1x10xf32> {
%0 = "stablehlo.reshape"(%image) : (tensor<28x28xf32>) -> tensor<1x784xf32>
%1 = "stablehlo.dot"(%0, %weights) : (tensor<1x784xf32>, tensor<784x10xf32>) -> tensor<1x10xf32>
%2 = "stablehlo.add"(%1, %bias) : (tensor<1x10xf32>, tensor<1x10xf32>) -> tensor<1x10xf32>
%3 = "stablehlo.constant"() { value = dense<0.0> : tensor<1x10xf32> } : () -> tensor<1x10xf32>
%4 = "stablehlo.maximum"(%2, %3) : (tensor<1x10xf32>, tensor<1x10xf32>) -> tensor<1x10xf32>
"func.return"(%4): (tensor<1x10xf32>) -> ()
}
İşlevler
Func ::= 'func' '.' 'func' FuncId FuncInputs FuncOutputs '{' FuncBody '}'
FuncInputs ::= '(' [FuncInput {',' FuncInput}] `)`
FuncInput ::= '%' ValueId ':' ValueType
FuncOutputs ::= ['->' FuncOutput, {',' FuncOutput}]
FuncOutput ::= ValueType
FuncBody ::= {Op}
Kararlı HLO işlevleri (adlandırılmış işlevler olarak da adlandırılır) bir tanımlayıcıya, giriş/çıkışlara ve bir gövdeye sahiptir. Gelecekte HLO ile daha iyi uyumluluk sağlamak amacıyla işlevler için ek meta veriler sunmayı planlıyoruz (#425, #626, #740, #744).
Tanımlayıcılar
FuncId ::= '@' letter {letter | digit}
ValueId ::= '%' digit {digit}
| '%' letter {letter | digit}
letter ::= 'a' | ... | 'z' | 'A' | ... | 'Z' | '_'
digit ::= '0' | ... | '9'
StableHLO tanımlayıcıları, birçok programlama dilindeki tanımlayıcılara benzer. Bu tanımlayıcıların iki özelliği vardır: 1) tüm tanımlayıcılar, farklı tanımlayıcı türlerini ayırt eden kimlik işaretlerine sahiptir, 2) değer tanımlayıcıları, StableHLO programlarının oluşturulmasını basitleştirmek için tamamen sayısal olabilir.
Türler
Type ::= ValueType | NonValueType
ValueType ::= TensorType | QuantizedTensorType | TokenType | TupleType
NonValueType ::= TensorElementType | QuantizedTensorElementType | FunctionType | StringType
KararlıHLO türleri, StableHLO değerlerini temsil eden değer türleri (birinci sınıf türler olarak da adlandırılır) ve diğer program öğelerini açıklayan değer olmayan türleri olarak sınıflandırılır. StableHLO türleri, birçok programlama dilindeki türlere benzerdir.Temel özelliği StableHLO'nun alana özgü yapısıdır ve bu da bazı olağandışı sonuçlara neden olur (ör. skaler türler değer türleri değildir).
TensorType ::= 'tensor' '<' Shape TensorElementType '>'
Shape ::= {DimensionSize 'x'}
DimensionSize ::= digit {digit}
Tensör türleri, tensörleri (yani çok boyutlu dizileri) temsil eder. Bunlar bir
şekil ve bir öğe türüne sahiptir. Burada bir şekil, 0
ile R-1
arasında numaralandırılmış ve karşılık gelen boyutların (eksenler olarak da adlandırılır) artan sırasındaki negatif olmayan boyut boyutlarını temsil eder. R
boyutlarının sayısına sıralama denir. Örneğin tensor<2x3xf32>
, 2x3
şekline ve f32
öğe türüne sahip bir tensor türüdür. Boyutu 2 ve 3 olmak üzere, 0. boyut ve 1. boyut olmak üzere iki boyutu (veya diğer bir deyişle iki ekseni) vardır. Rütbesi 2.
Bu, boyut boyutlarının statik olarak bilindiği statik şekiller için desteği tanımlar. Gelecekte, boyut boyutlarının kısmen veya tamamen bilinmediği (#8) dinamik şekiller için de destek sunmayı planlıyoruz. Ayrıca, tensör türlerini örneğin, düzenler (#629) ve seyreklik (#1078) eklemek için boyut boyutlarının ve öğe türlerinin ötesine genişletmeyi de planlıyoruz.
QuantizedTensorType ::= 'tensor' '<' Shape QuantizedTensorElementType '>'
QuantizedTensorElementType ::= '!quant.uniform' '<'
QuantizationStorageType
['<' QuantizationStorageMin ':' QuantizationStorageMax '>']
':' QuantizationExpressedType
[':' QuantizationDimension]
',' QuantizationParameters '>'
QuantizationStorageType ::= IntegerType
QuantizationStorageMin ::= IntegerConstant
QuantizationStorageMax ::= IntegerConstant
QuantizationExpressedType ::= FloatType
QuantizationDimension ::= IntegerConstant
QuantizationParameters ::= QuantizationParameter
| '{' QuantizationParameter {',' QuantizationParameter} '}'
QuantizationParameter ::= QuantizationScale ':' QuantizationZeroPoint
QuantizationScale ::= FloatConstant
QuantizationZeroPoint ::= IntegerConstant
Ad | Tür | Sınırlamalar |
---|---|---|
storage_type |
tam sayı türü | (C1-C4), (C9) |
storage_min |
Tam sayı sabiti | (C2), (C4), (C8) |
storage_max |
Tam sayı sabiti | (C3), (C4), (C8) |
expressed_type |
kayan nokta türü | (C1), (C5) |
quantization_dimension |
isteğe bağlı tam sayı sabiti | (C11-C13) |
scales |
kayan nokta sabitlerinin değişken sayısı | (C5-C7), (C10), (C11), (C13) |
zero_points |
tam sayı sabitlerinin değişken sayısı | (C8-C10) |
Nicelleştirilmiş öğe türleri, storage_min
ile storage_max
(dahil) aralığındaki ve ifade edilen bir türdeki kayan nokta değerlerine karşılık gelen depolama türünün tam sayı değerlerini temsil eder. Belirli bir tam sayı değeri i
için karşılık gelen f
kayan nokta değeri f = (i - zero_point) * scale
olarak hesaplanabilir. Burada scale
ve zero_point
, ölçüm parametreleri olarak adlandırılır. storage_min
ve storage_max
, dilbilgisinde isteğe bağlıdır ancak sırasıyla min_value(storage_type)
ve max_value(storage_type)
varsayılan değerlerine sahiptir. Nicel öğe türlerinde aşağıdaki kısıtlamalar bulunur:
- (C1)
num_bits(storage_type) < num_bits(expressed_type)
. - (C2)
type(storage_min) = storage_type
. - (C3)
type(storage_max) = storage_type
. - (C4)
min_value(storage_type) <= storage_min < storage_max <= max_value(storage_type)
. - (C5)
type(scales...) = expressed_type
. - (C6)
0 < scales
. - (C7)
is_finite(scales...)
. - (C8)
storage_min <= zero_points <= storage_max
. - (C9)
type(zero_points...) = storage_type
. - (C10)
size(scales) = size(zero_points)
. - (C11)
is_empty(quantization_dimension)
isesize(scales) = 1
. - (C12)
0 <= quantization_dimension
.
QuantizationScale
, şu anda bir kayan nokta sabitidir ancak çarpanlar ve kaydırmalarla temsil edilen tam sayı tabanlı ölçeklere büyük ilgi göstermektedir. Bu konuyu yakın gelecekte keşfetmeyi planlıyoruz
(#1404).
Tür, değerler ve ölçülmüş bir tensör türünde yalnızca bir veya potansiyel olarak birden çok sıfır noktasının olup olmadığı dahil olmak üzere QuantizationZeroPoint
öğesinin anlamıyla ilgili tartışmalar devam etmektedir. Bu tartışmanın sonuçlarına göre, sıfır puan civarındaki spesifikasyon gelecekte değişebilir (#1405).
Devam eden bir diğer tartışma, bu değerler ve ölçülmüş tensörlerin değerleri (#1406) üzerinde herhangi bir kısıtlama uygulanması gerekip gerekmediğini belirlemek için QuantizationStorageMin
ve QuantizationStorageMax
anlamını içerir.
Son olarak, bilinmeyen boyutların temsil edilmesiyle ilgili keşifler yapmayı planladığımız gibi (#1407) bilinmeyen ölçekleri ve sıfır noktalarını temsil etmeyi planlıyoruz.
Nicelleştirilmiş tensör türleri, ölçülmüş öğelere sahip tensörleri temsil eder. Bu tensörler, normal tensörlerle tamamen aynıdır. Tek fark, öğelerinin normal öğe türleri yerine ölçülmüş öğe türlerine sahip olmasıdır.
Ölçülen tensörlerde niceleme, tensör başına, yani tüm tensör için bir scale
ve zero_point
içerebilir veya eksen başına birden fazla scales
ve zero_points
, belirli bir boyutun (quantization_dimension
boyutu) her dilimi için bir çift içerebilir. Daha resmi olarak, eksen başına nicelemenin kullanıldığı bir tensör t
öğesinde dim(t, quantization_dimension)
quantization_dimension
dilimi bulunur: t[:, ..., 0, ..., :], t[:, ..., 1, ..., :]
vb. i
. dilimdeki tüm öğeler, niceleme parametreleri olarak scales[i]
ve zero_points[i]
kullanır. Nicelleştirilmiş tensör türleri aşağıdaki kısıtlamalara sahiptir:
- Tensör başına ölçüm için:
- Ek sınırlama yok.
- Eksen başına ölçüm için:
- (C12)
quantization_dimension < rank(self)
. - (C13)
dim(self, quantization_dimension) = size(scales)
.
- (C12)
TokenType ::= 'token'
Jeton türleri jetonları, yani bazı işlemler tarafından üretilen ve tüketilen opak değerleri temsil eder. Jetonlar, Yürütme bölümünde açıklandığı gibi işlemlere yürütme sırası uygulamak için kullanılır.
TupleType ::= 'tuple' '<' TupleElementTypes '>'
TupleElementTypes ::= [ValueType {',' ValueType}]
Tuple türleri, yani heterojen listeler gibi unsurları temsil eder. Tuple'lar, yalnızca HLO ile uyumluluk için mevcut olan eski bir özelliktir. HLO'da değişken giriş ve çıkışları temsil etmek için tuple'ler kullanılır. StableHLO'da değişken giriş ve çıkışlar yerel olarak desteklenir.StableHLO'da tuple'ların kullanımı, HLO ABI'yi kapsamlı bir şekilde temsil etmektir. Örneğin T
, tuple<T>
ve tuple<tuple<T>>
belirli bir uygulamaya bağlı olarak önemli ölçüde farklı olabilir. Gelecekte HLO ABI'de, tuple türlerini StableHLO'dan kaldırmamıza olanak tanıyacak değişiklikler yapmayı planlıyoruz (#598).
TensorElementType ::= BooleanType | IntegerType | FloatType | ComplexType
BooleanType ::= 'i1'
IntegerType ::= SignedIntegerType | UnsignedIntegerType
SignedIntegerType ::= 'si4' | 'si8' | 'si16' | 'si32' | 'si64'
UnsignedIntegerType ::= 'ui4' | 'ui8' | 'ui16' | 'ui32' | 'ui64'
FloatType ::= 'f8E4M3FN' | 'f8E5M2' | 'f8E4M3FNUZ' | 'f8E5M2FNUZ'
| 'f8E4M3B11FNUZ' | 'bf16' | 'f16' | 'f32' | 'f64'
ComplexType ::= 'complex' '<' ComplexElementType '>'
ComplexElementType ::= 'f32' | 'f64'
Öğe türleri, tensör türlerinin öğelerini temsil eder. Birçok programlama dilinden farklı olarak bu türler StableHLO'da birinci sınıf değildir. Bu, SttableHLO programlarının bu türlerin değerlerini doğrudan temsil edemeyeceği anlamına gelir (sonuç olarak, T
türündeki skaler değerleri tensor<T>
türünde 0 boyutlu tensör değerleriyle temsil etmek deyimseldir).
- Boole türü,
true
vefalse
boole değerlerini temsil eder. - Tam sayı türleri işaretli (
si
) veya işaretsiz (ui
) olabilir ve desteklenen bit genişliklerinden birine (4
,8
,16
,32
veya64
) sahip olabilir. İmzalısiN
türleri,-2^(N-1)
ile2^(N-1)-1
arasındaki tam sayı değerlerini, imzasızuiN
türleri ise0
ile2^N-1
arasındaki tam sayı değerlerini temsil eder. - Kayan nokta türleri aşağıdakilerden biri olabilir:
- Derin Öğrenme için FP8 Biçimleri bölümünde açıklanan FP8 biçiminin sırasıyla
E4M3
veE5M2
kodlamalarına karşılık gelenf8E4M3FN
vef8E5M2
türleri. - Derin Sinir Ağları için 8 bit Sayısal Biçimler bölümünde açıklanan FP8 biçimlerinin
E4M3
veE5M2
kodlamalarına karşılık gelenf8E4M3FNUZ
vef8E5M2FNUZ
türleri. - Derin Sinir Ağları için Hibrit 8 bit Kayan Nokta (HFP8) Eğitimi ve Çıkarımı bölümünde açıklanan FP8 biçimlerinin
E4M3
kodlamasına karşılık gelenf8E4M3B11FNUZ
türü. - BFloat16: Cloud TPU'larda yüksek performansın sırrı bölümünde açıklanan
bfloat16
biçimine karşılık gelenbf16
türü. - IEEE 754 standardında açıklanan sırasıyla
binary16
("yarı hassasiyet"),binary32
("tek kesinlik") vebinary64
("çift kesinlik") biçimlerine karşılık gelenf16
,f32
vef64
türleridir.
- Derin Öğrenme için FP8 Biçimleri bölümünde açıklanan FP8 biçiminin sırasıyla
- Karmaşık türler, gerçek bir kısmı ve aynı öğe türünün sanal bir kısmı olan karmaşık değerleri temsil eder. Desteklenen karmaşık türler
complex<f32>
(her iki parça daf32
türündedir) vecomplex<f64>
'dir (her iki parça daf64
türündedir).
FunctionType ::= '(' InputTypes ')' '->' '(' OutputTypes ')'
InputTypes ::= [ValueType {',' ValueType}]
OutputTypes ::= [ValueType {',' ValueType}]
İşlev türleri, hem adlandırılmış hem de anonim işlevleri temsil eder. Giriş türleri (->
öğesinin sol tarafındaki türlerin listesi) ve çıkış türleri (->
uygulamasının sağ tarafındaki türler listesi) vardır. Birçok programlama dilinde fonksiyon türleri birinci sınıftır ancak StableHLO'da değildir.
StringType ::= 'string'
Dize türü, bayt dizilerini temsil eder. Birçok programlama dilinden farklı olarak dize türü StableHLO'da birinci sınıf değildir ve yalnızca program öğeleri için statik meta verileri belirtmek amacıyla kullanılır.
İşlemler
KararlıHLO işlemleri (işlemler olarak da adlandırılır), makine öğrenimi modellerinde kapalı bir üst düzey işlem grubunu temsil eder. Yukarıda açıklandığı gibi SttableHLO söz dizimi, büyük ölçüde MLIR'den esinlenilmiştir. MLIR, her zaman en ergonomik alternatif olmasa da StableHLO'nun ML çerçeveleri ile ML derleyicileri arasında daha fazla birlikte çalışabilirlik oluşturma hedefine en uygun yöntem olduğu tahmin edilmektedir.
Op ::= [OpOutputs] OpName OpInputs ':' OpSignature
OpName ::= '"' 'stablehlo' '.' OpMnemonic '"'
OpMnemonic ::= 'abs' | 'add' | ...
KararlıHLO işlemlerinin (işlemler olarak da adlandırılır) bir adı, girişleri/çıkışları ve bir imzası vardır. Ad, stablehlo.
ön eki ve desteklenen işlemlerden birini benzersiz şekilde tanımlayan bir hatırlatıcıdan oluşur. Desteklenen tüm işlemlerin kapsamlı bir listesini aşağıda bulabilirsiniz.
Şu anda, doğadaki StableHLO programları bazen bu belgede açıklanmayan işlemler içermektedir. Gelecekte bu işlemleri StableHLO mantığına almayı veya StableHLO programlarında görünmelerini engellemeyi planlıyoruz. Bu süre zarfında söz konusu işlemlerin listesi şu şekildedir:
builtin.module
,func.func
,func.call
vefunc.return
(#425).chlo
işlemleri (#602).- StableHLO işlemlerinin "HLO'da değil" kategorisinde - başlangıçta StableHLO operasyonunun bir parçasıydılar, ancak daha sonra bu işleme uygun olmadığı kabul edildi:
broadcast
,create_token
,cross-replica-sum
,dot
,einsum
,torch_index_select
,unary_einsum
(#3). - StableHLO işlemlerinin "Dynamism" kategorisi. Bunlar MHLO'dan yüklenmiştir, ancak henüz belirtilmemiştir:
compute_reshape_shape
,cstr_reshapable
,dynamic_broadcast_in_dim
,dynamic_conv
,dynamic_gather
,dynamic_iota
,dynamic_pad
,dynamic_reshape
,real_dynamic_slice
,set_dimension_size
(#8). arith
,shape
vetensor
işlemleri dahil şekil hesaplamaları (#8).
OpInputs ::= OpInputValues OpInputFuncs OpInputAttrs
OpInputValues ::= '(' [OpInputValue {',' OpInputValue}] ')'
OpInputValue ::= ValueId
OpInputFuncs ::= ['(' OpInputFunc {',' OpInputFunc} ')']
OpInputAttrs ::= ['{' OpInputAttr {',' OpInputAttr} '}']
OpOutputs ::= [OpOutput {',' OpOutput} '=']
OpOutput ::= ValueId
İşlemler girişleri tüketir ve çıktı oluşturur. Girişler; giriş değerleri (yürütme sırasında hesaplanan), giriş işlevleri (StableHLO işlevlerinde birinci sınıf değerler olmadığı için statik olarak sağlanır) ve giriş özellikleri (statik olarak da sağlanır) olarak sınıflandırılır. Bir opsiyon tarafından tüketilen ve üretilen giriş ve çıkışların türü, anımsatıcıya bağlıdır. Örneğin, add
işlemi 2 giriş değeri tüketir ve 1 çıkış değeri üretir. Buna karşılık, select_and_scatter
işlemi 3 giriş değeri, 2 giriş işlevi ve 3 giriş özelliği tüketir.
OpInputFunc ::= '{' Unused FuncInputs ':' FuncBody '}'
Unused ::= '^' digit {digit}
| '^' letter {letter | digit}
Giriş işlevleri (anonim işlevler olarak da adlandırılır), adlandırılmış işlevlere çok benzerdir. Tek fark: 1) tanımlayıcısı yoktur (dolayısıyla "anonim" olarak adlandırılır), 2) çıkış türlerini belirtmez (çıktı türleri, işlev içindeki return
işleminden belirlenir).
Giriş işlevlerinin söz dizimi, MLIR ile uyumluluk için o anda kullanılmayan bir parça (yukarıdaki Unused
üretimine bakın) içerir. MLIR'de atlama işlemleri aracılığıyla birbirine bağlanan birden fazla işlem "blosu" içerebilen daha genel bir "bölge" kavramı vardır. Bu bloklar, Unused
üretimine karşılık gelen kimliklere sahiptir. Böylece birbirlerinden ayırt edilebilirler.
StableHLO'da atlama işlemleri olmadığından MLIR söz diziminin karşılık gelen bölümü kullanılmaz (ancak hâlâ mevcuttur).
OpInputAttr ::= OpInputAttrName '=' OpInputAttrValue
OpInputAttrName ::= letter {letter | digit}
OpInputAttrValue ::= Constant
Giriş özellikleri, desteklenen sabit değerlerden biri olan bir ada ve değere sahiptir. Bunlar, program öğeleri için statik meta verileri belirtmenin birincil yoludur. Örneğin, concatenate
işlemi, giriş değerlerinin birleştirildiği boyutu belirtmek için dimension
özelliğini kullanır. Benzer şekilde, slice
işlemi, giriş değerini bölümlendirmek için kullanılan sınırları belirtmek amacıyla start_indices
ve limit_indices
gibi birden çok özellik kullanır.
Şu anda, normalde kullanıma sunulan StableHLO programları bazen bu belgede açıklanmayan özellikler içermektedir. Gelecekte bu özellikleri StableHLO türüne almayı veya StableHLO programlarında görünmelerini engellemeyi planlıyoruz. Bu arada, bu özelliklerin listesi şu şekildedir:
layout
(#629).mhlo.frontend_attributes
(#628).mhlo.sharding
(#619).output_operand_aliases
(#740).- Konum meta verileri (#594).
OpSignature ::= '(' [ValueType {',' ValueType}] ')' '->' '(' [ValueType {',' ValueType}] ')'
İşlem imzası, tüm giriş değeri türlerinden (->
öğesinin sol tarafındaki tür listesi) ve tüm çıkış değeri türlerinden (->
öğesinin sağ tarafındaki tür listesi) oluşur. Özet olarak, giriş türleri yedektir ve çıkış türleri de neredeyse her zaman yedektir (çünkü çoğu StableHLO işlemleri için çıkış türleri girişlerden çıkarılabilir). Bununla birlikte işlem imzası, MLIR ile uyumluluk için kasıtlı olarak StableHLO söz diziminin bir parçasıdır.
Aşağıda, anımsatıcısı select_and_scatter
olan bir işlem örneği verilmiştir. 3 giriş değeri (%operand
, %source
ve %init_value
), 2 giriş işlevi ve 3 giriş özelliği (window_dimensions
, window_strides
ve padding
) kullanır.
İşlem imzasının, satır içinde sağlanan giriş işlevi ve özellik türlerini değil, yalnızca giriş değeri türlerini içerdiğini unutmayın.
%result = "stablehlo.select_and_scatter"(%operand, %source, %init_value) ({
^bb0(%arg0: tensor<i32>, %arg1: tensor<i32>):
%0 = "stablehlo.compare"(%arg0, %arg1) {
comparison_direction = #stablehlo<comparison_direction GE>
} : (tensor<i32>, tensor<i32>) -> tensor<i1>
"stablehlo.return"(%0) : (tensor<i1>) -> ()
}, {
^bb0(%arg0: tensor<i32>, %arg1: tensor<i32>):
%0 = "stablehlo.add"(%arg0, %arg1) : (tensor<i32>, tensor<i32>) -> tensor<i32>
"stablehlo.return"(%0) : (tensor<i32>) -> ()
}) {
window_dimensions = dense<[3, 1]> : tensor<2xi64>,
window_strides = dense<[2, 1]> : tensor<2xi64>,
padding = dense<[[0, 1], [0, 0]]> : tensor<2x2xi64>
} : (tensor<4x2xi32>, tensor<2x2xi32>, tensor<i32>) -> tensor<4x2xi32>
Sabitler
Constant ::= BooleanConstant
| IntegerConstant
| FloatConstant
| ComplexConstant
| TensorConstant
| QuantizedTensorConstant
| StringConstant
| EnumConstant
StableHLO sabitlerinin bir değişmez değeri ve türü, birlikte bir StableHLO değerini temsil eder. Genel olarak tür, belirsiz olmadığı durumlar haricinde sabit söz diziminin bir parçasıdır (ör. bir boole sabiti, i1
türünü açık bir şekilde gösterirken, bir tam sayı sabitinin birden fazla olası türü olabilir).
BooleanConstant ::= BooleanLiteral
BooleanLiteral ::= 'true' | 'false'
Boole sabitleri, true
ve false
boole değerlerini temsil eder. Boole sabitleri i1
türüne sahiptir.
IntegerConstant ::= IntegerLiteral ':' IntegerType
IntegerLiteral ::= ['-' | '+'] DecimalDigits
| ['-' | '+'] '0x' HexadecimalDigits
DecimalDigits ::= decimalDigit {decimalDigit}
HexadecimalDigits ::= hexadecimalDigit {hexadecimalDigit}
decimalDigit ::= '0' | ... | '9'
hexadecimalDigit ::= decimalDigit | 'a' | ... | 'f' | 'A' | ... | 'F'
Tam sayı sabitleri, ondalık veya onaltılık gösterim kullanan dizeler aracılığıyla tam sayı değerlerini temsil eder. Diğer bazlar (ör. ikili veya sekizlik) desteklenmez. Tam sayı sabitleri aşağıdaki kısıtlamalara sahiptir:
- (C1)
is_wellformed(integer_literal, integer_type)
.
FloatConstant ::= FloatLiteral ':' FloatType
FloatLiteral ::= SignPart IntegerPart FractionalPart ScientificPart
| '0x' [HexadecimalDigits]
SignPart ::= ['-' | '+']
IntegerPart ::= DecimalDigits
FractionalPart ::= ['.' [DecimalDigits]]
ScientificPart ::= [('e' | 'E') ['-' | '+'] DecimalDigits]
Kayan nokta sabitleri, ondalık veya bilimsel gösterim kullanan dizeler aracılığıyla kayan nokta değerlerini temsil eder. Buna ek olarak, alttaki bitleri ilgili türün kayan nokta biçiminde doğrudan belirtmek için onaltılı gösterim kullanılabilir. Kayan nokta sabitleri aşağıdaki kısıtlamalara sahiptir:
- (C1) Onaltılık olmayan gösterim kullanılıyorsa
is_wellformed(float_literal, float_type)
. - (C2) Onaltılık gösterim kullanılıyorsa
size(hexadecimal_digits) = num_bits(float_type) / 4
.
ComplexConstant ::= ComplexLiteral ':' ComplexType
ComplexLiteral ::= '(' RealPart ',' ImaginaryPart ')'
RealPart ::= FloatLiteral
ImaginaryPart ::= FloatLiteral
Karmaşık sabitler, gerçek bir parçanın (önce gelir) ve sanal bir kısmın (ikinci adım) listelerini kullanarak karmaşık değerleri temsil eder. Örneğin, (1.0, 0.0) : complex<f32>
1.0 + 0.0i
değerini, (0.0, 1.0) : complex<f32>
ise 0.0 + 1.0i
değerini temsil eder. Bu parçaların bellekte depolandığı sıra, uygulama tarafından belirlenir. Karmaşık sabitlerde aşağıdaki kısıtlamalar bulunur:
- (C1)
is_wellformed(real_part, complex_element_type(complex_type))
. - (C2)
is_wellformed(imaginary_part, complex_element_type(complex_type))
.
TensorConstant ::= TensorLiteral ':' TensorType
TensorLiteral ::= 'dense' '<' (DenseLiteral | ElementLiteral) '>'
DenseLiteral ::= DenseDimension | DenseElements
DenseDimension ::= '[' [DenseLiteral {',' DenseLiteral}] ']'
DenseElements ::= [ElementLiteral {',' ElementLiteral}]
ElementLiteral ::= BooleanLiteral | IntegerLiteral | FloatLiteral | ComplexLiteral
Tensör sabitleri, NumPy gösterimiyle belirtilen iç içe yerleştirilmiş listeleri kullanan tensör değerlerini temsil eder. Örneğin, dense<[[1, 2, 3], [4, 5, 6]]> : tensor<2x3xi32>
, dizinlerden öğelere aşağıdaki eşlemeyle bir tensör değerini temsil eder:
{0, 0} => 1
, {0, 1} => 2
, {0, 2} => 3
, {1, 0} => 4
, {1, 1} => 5
,
{1, 2} => 6
. Bu öğelerin bellekte depolandığı sıra, uygulama tarafından belirlenir. Tensör sabitleri aşağıdaki kısıtlamalara sahiptir:
- (C1)
has_syntax(tensor_literal, element_type(tensor_type))
; burada:has_syntax(element_literal: Syntax, element_type: Type) = is_wellformed(element_literal, type)
.has_syntax(tensor_literal: List, element_type: Type) = has_syntax(tensor_literal..., element_type)
.
- (C2)
has_shape(tensor_literal, shape(tensor_type))
; burada:has_shape(element_literal: Syntax, []) = true
.has_shape(tensor_literal: List, shape: List) = size(tensor_literal) = shape[0] and has_shape(tensor_literal..., shape[1:])
.- aksi takdirde
false
.
QuantizedTensorConstant ::= QuantizedTensorLiteral ':' QuantizedTensorType
QuantizedTensorLiteral ::= 'dense' '<' (DenseLiteral | ElementLiteral) '>'
Nicelleştirilmiş tensör sabitleri, tensör sabitleriyle aynı gösterimi kullanan nicelleştirilmiş tensör değerlerini temsil eder ve öğeler, depolama türlerinin sabitleri olarak belirtilir. Nicelleştirilmiş tensör sabitleri aşağıdaki kısıtlamalara sahiptir:
- (C1)
has_syntax(quantized_tensor_literal, storage_type(quantized_tensor_type))
. - (C2)
has_shape(quantized_tensor_literal, shape(quantized_tensor_type))
.
StringConstant ::= StringLiteral
StringLiteral ::= '"' {stringCharacter | escapeSequence} '"'
stringCharacter ::= all ASCII characters except '\00', '\01', ... '\1f' and '"'
escapeSequence ::= '\' ('"' | '\' | 'n' | 't' | (hexadecimalDigit hexadecimalDigit))
Dize hazır değerleri, ASCII karakterleri ve çıkış dizileri kullanılarak belirtilen baytlardan oluşur. Bunlar kodlamadan bağımsızdır. Bu nedenle, bu baytların yorumlanması uygulama tanımlıdır. Dize değişmez değerlerinin türü string
.
İşlemler
abs
Anlambilim
operand
tensörü üzerinde öğe düzeyinde abs işlemi gerçekleştirir ve bir result
tensörü üretir. Öğe türüne bağlı olarak aşağıdakileri yapar:
- İşaretli tam sayılar için: tam sayı modülü.
- Kayan reklamlar için: IEEE-754'ten
abs
. - Karmaşık sayılar için: karmaşık modül.
- Nicel türler için:
dequantize_op_quantize(abs, operand, type(result))
.
Girişler
Etiket | Ad | Tür | Sınırlamalar |
---|---|---|---|
(I1) | operand |
işaretli tam sayı, kayan nokta, karmaşık tür veya tensör başına ölçülmüş tensör tensörü | (C1-C2) |
Çıkışlar
Ad | Tür | Sınırlamalar |
---|---|---|
result |
işaretli tam sayı veya kayan nokta türü ya da tensör başına ölçülmüş tensör tensörü | (C1-C2) |
Sınırlamalar
- (C1)
shape(result) = shape(operand)
. - (C2)
baseline_element_type(result)
şu şekilde tanımlanır:is_complex(operand)
isecomplex_element_type(element_type(operand))
.- Aksi takdirde
baseline_element_type(operand)
.
Örnekler
// %operand: [-2, 0, 2]
%result = "stablehlo.abs"(%operand) : (tensor<3xi32>) -> tensor<3xi32>
// %result: [2, 0, 2]
add
Anlambilim
İki tensör (lhs
ve rhs
) öğe bazında ekleme yapar ve bir result
tensörü üretir. Öğe türüne bağlı olarak aşağıdakileri yapar:
- Boole için: mantıksal OR.
- Tam sayılar için: tamsayılarla toplama.
- Kayan reklamlar için: IEEE-754'ten
addition
. - Karmaşık sayılar için: karmaşık toplama.
- Nicel türler için:
dequantize_op_quantize(add, lhs, rhs, type(result))
.
Girişler
Etiket | Ad | Tür | Sınırlamalar |
---|---|---|---|
(I1) | lhs |
tensör veya tensör başına ölçülebilir tensör | (C1) |
(I2) | rhs |
tensör veya tensör başına ölçülebilir tensör | (C1) |
Çıkışlar
Ad | Tür | Sınırlamalar |
---|---|---|
result |
tensör veya tensör başına ölçülebilir tensör | (C1) |
Sınırlamalar
- (C1)
baseline_type(lhs) = baseline_type(rhs) = baseline_type(result)
.
Örnekler
// %lhs: [[1, 2], [3, 4]]
// %rhs: [[5, 6], [7, 8]]
%result = "stablehlo.add"(%lhs, %rhs) : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
// %result: [[6, 8], [10, 12]]
after_all
Anlambilim
inputs
öğesini üreten işlemlerin, result
öğesine bağlı tüm işlemlerden önce yürütülmesini sağlar. Bu işlemin yürütülmesi hiçbir şey yapmaz, yalnızca result
ile inputs
arasında veri bağımlılığı oluşturmak için vardır.
Girişler
Etiket | Ad | Tür |
---|---|---|
(I1) | inputs |
token değişken sayısı |
Çıkışlar
Ad | Tür |
---|---|
result |
token |
Örnekler
// %input0: !stablehlo.token
// %input1: !stablehlo.token
%result = "stablehlo.after_all"(%input0, %input1) : (!stablehlo.token, !stablehlo.token) -> !stablehlo.token
all_gather
Anlambilim
StableHLO işlem tablosundaki her işlem grubu içinde, all_gather_dim
boyunca her süreçten operand
tensörünün değerlerini birleştirir ve bir result
tensörü oluşturur.
İşlem, StableHLO işlem tablosunu aşağıdaki şekilde tanımlanan process_groups
olarak böler:
channel_id <= 0 and use_global_device_ids = false
isecross_replica(replica_groups)
.channel_id > 0 and use_global_device_ids = false
isecross_replica_and_partition(replica_groups)
.channel_id > 0 and use_global_device_ids = true
iseflattened_ids(replica_groups)
.
Daha sonra, her process_group
içinde:
process_group
bölgesindekireceiver
tümü içinoperands@receiver = [operand@sender for sender in process_group]
.process_group
bölgesindekiprocess
tümü içinresult@process = concatenate(operands@process, all_gather_dim)
.
Girişler
Etiket | Ad | Tür | Sınırlamalar |
---|---|---|---|
(I1) | operand |
tensör veya tensör başına ölçülebilir tensör | (C1), (C6) |
(I2) | all_gather_dim |
si64 türünün sabiti |
(C1), (C6) |
(I3) | replica_groups |
si64 türünde 2 boyutlu tensör sabiti |
(C2-C4) |
(I4) | channel_id |
si64 türünün sabiti |
(C5) |
(I5) | use_global_device_ids |
i1 türünün sabiti |
(C5) |
Çıkışlar
Ad | Tür | Sınırlamalar |
---|---|---|
result |
tensör veya tensör başına ölçülebilir tensör | (C6) |
Sınırlamalar
- (C1)
0 <= all_gather_dim < rank(operand)
. - (C2)
is_unique(replica_groups)
. - (C3)
size(replica_groups)
şu şekilde tanımlanır:cross_replica
kullanılıyorsanum_replicas
.cross_replica_and_partition
kullanılıyorsanum_replicas
.flattened_ids
kullanılıyorsanum_processes
.
- (C4)
0 <= replica_groups < size(replica_groups)
. - (C5)
use_global_device_ids = true
isechannel_id > 0
. - (C6)
type(result) = type(operand)
, şunlar hariç:dim(result, all_gather_dim) = dim(operand, all_gather_dim) * dim(process_groups, 1)
.
Örnekler
// num_replicas: 2
// num_partitions: 1
// %operand@(0, 0): [[1, 2], [3, 4]]
// %operand@(1, 0): [[5, 6], [7, 8]]
%result = "stablehlo.all_gather"(%operand) {
all_gather_dim = 1 : i64,
replica_groups = dense<[[0, 1]]> : tensor<1x2xi64>,
// channel_id = 0
channel_handle = #stablehlo.channel_handle<handle = 0, type = 0>
// use_global_device_ids = false
} : (tensor<2x2xi64>) -> tensor<2x4xi64>
// %result@(0, 0): [[1, 2, 5, 6], [3, 4, 7, 8]]
// %result@(1, 0): [[1, 2, 5, 6], [3, 4, 7, 8]]
all_reduce
Anlambilim
StableHLO işlem tablosundaki her işlem grubu içinde, her işlemdeki operand
tensörünün değerlerine bir computation
azaltma işlevi uygular ve bir result
tensörü üretir.
İşlem, StableHLO işlem tablosunu aşağıdaki şekilde tanımlanan process_groups
olarak böler:
channel_id <= 0 and use_global_device_ids = false
isecross_replica(replica_groups)
.channel_id > 0 and use_global_device_ids = false
isecross_replica_and_partition(replica_groups)
.channel_id > 0 and use_global_device_ids = true
iseflattened_ids(replica_groups)
.
Daha sonra, her process_group
içinde:
- Aşağıdaki ikili ağaçlarda
schedule
içinresult@process[result_index] = exec(schedule)
değeri:exec(node)
=computation(exec(node.left), exec(node.right))
.exec(leaf)
=leaf.value
.
schedule
, sıralı geçişito_destination_type(operands@process_group...[result_index], type(func_inputs(computation)[0]))
olan uygulama tanımlı bir ikili ağaçtır.
Girişler
Etiket | Ad | Tür | Sınırlamalar |
---|---|---|---|
(I1) | operand |
tensör veya tensör başına ölçülebilir tensör | (C5), (C6) |
(I2) | replica_groups |
si64 türündeki 1 boyutlu tensör sabitlerinin değişken sayısı |
(C1-C3) |
(I3) | channel_id |
si64 türünün sabiti |
(C4) |
(I4) | use_global_device_ids |
i1 türünün sabiti |
(C4) |
(I5) | computation |
işlev | (C5) |
Çıkışlar
Ad | Tür | Sınırlamalar |
---|---|---|
result |
tensör veya tensör başına ölçülebilir tensör | (C6-C7) |
Sınırlamalar
- (C1)
is_unique(replica_groups)
. - (C2)
size(replica_groups)
şu şekilde tanımlanır:cross_replica
kullanılıyorsanum_replicas
.cross_replica_and_partition
kullanılıyorsanum_replicas
.flattened_ids
kullanılıyorsanum_processes
.
- (C3)
0 <= replica_groups < size(replica_groups)
. - (C4)
use_global_device_ids = true
ise,channel_id > 0
. - (C5)
computation
,is_promotable(element_type(operand), E)
olacak şekilde(tensor<E>, tensor<E>) -> (tensor<E>)
türündedir. - (C6)
shape(result) = shape(operand)
. - (C7)
element_type(result) = E
.
Örnekler
// num_replicas: 2
// num_partitions: 1
// %operand@(0, 0): [1, 2, 3, 4]
// %operand@(1, 0): [5, 6, 7, 8]
%result = "stablehlo.all_reduce"(%operand) ({
^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
%0 = "stablehlo.add"(%arg0, %arg1) : (tensor<i64>, tensor<i64>) -> tensor<i64>
"stablehlo.return"(%0) : (tensor<i64>) -> ()
}) {
replica_groups = dense<[[0, 1]]> : tensor<1x2xi64>,
channel_handle = #stablehlo.channel_handle<handle = 0, type = 0>
} : (tensor<i64>) -> tensor<i64>
// %result@(0, 0): [6, 8, 10, 12]
// %result@(1, 0): [6, 8, 10, 12]
all_to_all
Anlambilim
StableHLO işlem ızgarasındaki her işlem grubunda, split_dimension
boyunca operand
tensörünün değerlerini parçalara ayırır, bölünmüş parçaları süreçler arasında dağıtır, concat_dimension
boyunca dağınık parçaları birleştirir ve bir result
tensörü oluşturur.
İşlem, StableHLO işlem tablosunu aşağıdaki şekilde tanımlanan process_groups
olarak böler:
channel_id <= 0
isecross_replica(replica_groups)
.channel_id > 0
isecross_partition(replica_groups)
.
Daha sonra, her process_group
içinde:
process_group
bölgesindeki tümsender
içinsplit_parts@sender = split(operand@sender, split_count, split_dimension)
.scattered_parts@receiver = [split_parts@sender[receiver_index] for sender in process_group]
buradareceiver_index = process_group.index(receiver)
.result@process = concatenate(scattered_parts@process, concat_dimension)
.
Girişler
Etiket | Ad | Tür | Sınırlamalar |
---|---|---|---|
(I1) | operand |
tensör veya tensör başına ölçülebilir tensör | (C1-C3), (C9) |
(I2) | split_dimension |
si64 türünün sabiti |
(C1), (C2), (C9) |
(I3) | concat_dimension |
si64 türünün sabiti |
(C3), (C9) |
(I4) | split_count |
si64 türünün sabiti |
(C2), (C4), (C8), (C9) |
(I5) | replica_groups |
si64 türünde 2 boyutlu tensör sabiti |
(C5-C8) |
(I6) | channel_id |
si64 türünün sabiti |
Çıkışlar
Ad | Tür | Sınırlamalar |
---|---|---|
result |
tensör veya tensör başına ölçülebilir tensör | (C9) |
Sınırlamalar
- (C1)
0 <= split_dimension < rank(operand)
. - (C2)
dim(operand, split_dimension) % split_count = 0
. - (C3)
0 <= concat_dimension < rank(operand)
. - (C4)
0 < split_count
. - (C5)
is_unique(replica_groups)
. - (C6)
size(replica_groups)
şu şekilde tanımlanır:cross_replica
kullanılıyorsanum_replicas
.cross_partition
kullanılıyorsanum_partitions
.
- (C7)
0 <= replica_groups < size(replica_groups)
. - (C8)
dim(replica_groups, 1) = split_count
. - (C9)
type(result) = type(operand)
, şunlar hariç:dim(result, split_dimension) = dim(operand, split_dimension) / split_count
.dim(result, concat_dimension) = dim(operand, concat_dimension) * split_count
.
Örnekler
// num_replicas: 2
// num_partitions: 1
// %operand@(0, 0): [[1, 2, 3, 4],
// [5, 6, 7, 8]]
// %operand@(1, 0): [[9, 10, 11, 12],
// [13, 14, 15, 16]]
%result = "stablehlo.all_to_all"(%operand) {
split_dimension = 1 : i64,
concat_dimension = 0 : i64,
split_count = 2 : i64,
replica_groups = dense<[[0, 1]]> : tensor<1x2xi64>
} : (tensor<2x4xi64>) -> tensor<4x2xi64>
// %result@(0, 0): [[1, 2],
// [5, 6],
// [9, 10],
// [13, 14]]
// %result@(1, 0): [[3, 4],
// [7, 8],
// [11, 12],
// [15, 16]]
ve
Anlambilim
İki tensörün (lhs
ve rhs
) öğe bazında AND işlemini gerçekleştirir ve bir result
tensörü üretir. Öğe türüne bağlı olarak aşağıdakileri yapar:
- Boole için: mantıksal VE.
- Tam sayılar için: bit tabanlı VE.
Girişler
Etiket | Ad | Tür | Sınırlamalar |
---|---|---|---|
(I1) | lhs |
boole veya tam sayı türü tensörü | (C1) |
(I2) | rhs |
boole veya tam sayı türü tensörü | (C1) |
Çıkışlar
Ad | Tür | Sınırlamalar |
---|---|---|
result |
boole veya tam sayı türü tensörü | (C1) |
Sınırlamalar
- (C1)
type(lhs) = type(rhs) = type(result)
.
Örnekler
// %lhs: [[1, 2], [3, 4]]
// %rhs: [[5, 6], [7, 8]]
%result = "stablehlo.and"(%lhs, %rhs) : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
// %result: [[1, 2], [3, 0]]
atan2
Anlambilim
lhs
ve rhs
tensörü üzerinde öğe bazında atan2 işlemi gerçekleştirir ve bir result
tensörü üretir. Öğe türüne bağlı olarak aşağıdakileri yapar:
- Kayan reklamlar için: IEEE-754'ten
atan2
. - Karmaşık sayılar için: karmaşık atan2.
- Nicel türler için:
dequantize_op_quantize(atan2, lhs, rhs, type(result))
.
Girişler
Etiket | Ad | Tür | Sınırlamalar |
---|---|---|---|
(I1) | lhs |
kayan nokta veya karmaşık tür ya da tensör başına ölçülmüş tensör tensörü | (C1) |
(I2) | rhs |
kayan nokta veya karmaşık tür ya da tensör başına ölçülmüş tensör tensörü | (C1) |
Çıkışlar
Ad | Tür | Sınırlamalar |
---|---|---|
result |
kayan nokta veya karmaşık tür ya da tensör başına ölçülmüş tensör tensörü | (C1) |
Sınırlamalar
- (C1)
baseline_type(lhs) = baseline_type(rhs) = baseline_type(result)
.
Örnekler
// %lhs: [0.0, 1.0, -1.0]
// %rhs: [0.0, 0.0, 0.0]
%result = "stablehlo.atan2"(%lhs, %rhs) : (tensor<3xf64>, tensor<3xf64>) -> tensor<3xf64>
// %result: [0.0, 1.57079637, -1.57079637] // [0.0, pi/2, -pi/2]
batch_norm_grad
Anlambilim
Çeşitli batch_norm_training
girdilerinin gradyanlarını grad_output
geri çoğaltıcı olarak hesaplar ve grad_operand
, grad_scale
ve grad_offset
tensör üretir. Daha resmi olarak bu işlem, Python söz dizimi kullanılarak mevcut StableHLO işlemlerine ayrıştırma olarak aşağıdaki gibi ifade edilebilir:
def compute_sum(operand, feature_index):
(sum,) = reduce(
inputs=[operand],
init_values=[constant(0, element_type(operand))],
dimensions=[i for i in range(rank(operand)) if i != feature_index],
body=lambda x, y: add(x, y))
return sum
def compute_mean(operand, feature_index):
sum = compute_sum(operand, feature_index)
divisor = constant(size(operand) / dim(operand, feature_index),
element_type(operand))
divisor_bcast = broadcast_in_dim(divisor, [], type(sum))
return divide(sum, divisor_bcast)
def batch_norm_grad(operand, scale, mean, variance, grad_output, epsilon, feature_index):
# Broadcast inputs to type(operand)
scale_bcast = broadcast_in_dim(scale, [feature_index], type(operand))
mean_bcast = broadcast_in_dim(mean, [feature_index], type(operand))
variance_bcast = broadcast_in_dim(variance, [feature_index], type(operand))
epsilon_bcast = broadcast_in_dim(constant(epsilon, element_type(operand)), [],
type(operand))
# Perform normalization using the provided `mean` and `variance`
# Intermediate values will be useful for computing gradients
centered_operand = subtract(operand, mean_bcast)
stddev = sqrt(add(variance_bcast, epsilon_bcast))
normalized_operand = divide(centered_operand, stddev)
# Use the implementation from batchnorm_expander.cc in XLA
# Temporary variables have exactly the same names as in the C++ code
elements_per_feature = broadcast_in_dim(
constant(divide(size(operand), dim(operand, feature_index)),
element_type(grad_output)),
[], type(operand))
i1 = multiply(grad_output, elements_per_feature)
i2 = broadcast_in_dim(
compute_sum(grad_output, feature_index), [feature_index], type(operand))
i3 = broadcast_in_dim(
compute_sum(multiply(grad_output, centered_operand), feature_index),
[feature_index], type(operand))
i4 = multiply(i3, centered_operand)
i5 = divide(i4, add(variance_bcast, epsilon_bcast))
i6 = subtract(subtract(i1, i2), i5)
grad_operand =
multiply(divide(divide(scale_bcast, stddev), elements_per_feature), i6)
grad_scale =
compute_sum(multiply(grad_output, normalized_operand), feature_index)
grad_offset = compute_sum(grad_output, feature_index)
return grad_operand, grad_scale, grad_offset
Miktarı ölçülmüş türler için dequantize_batch_norm_grad_or_training_quantize(lambda operand, scale, mean,
variance, grad_output: batch_norm_grad(operand, scale, mean, variance,
grad_output, epsilon, feature_index), operand, scale, mean, variance,
grad_output, type(grad_operand), type(grad_scale), type(feature_index))
performans gösterir.
Girişler
Etiket | Ad | Tür | Sınırlamalar |
---|---|---|---|
(I1) | operand |
kayan nokta türü tensörü veya tensör başına ölçülebilir tensör | (C1-C3), (C5) |
(I2) | scale |
Kayan noktanın veya tensör başına ölçülmüş türde 1 boyutlu tensör | (C2), (C4), (C5) |
(I3) | mean |
Kayan noktanın veya tensör başına ölçülmüş türde 1 boyutlu tensör | (C2), (C4) |
(I4) | variance |
Kayan noktanın veya tensör başına ölçülmüş türde 1 boyutlu tensör | (C2), (C4) |
(I5) | grad_output |
kayan nokta türü tensörü veya tensör başına ölçülebilir tensör | (C2), (C3) |
(I6) | epsilon |
f32 türünün sabiti |
|
(I7) | feature_index |
si64 türünün sabiti |
(C1), (C5) |
Çıkışlar
Ad | Tür | Sınırlamalar |
---|---|---|
grad_operand |
kayan nokta türü tensörü veya tensör başına ölçülebilir tensör | (C2), (C3) |
grad_scale |
Kayan noktanın veya tensör başına ölçülmüş türde 1 boyutlu tensör | (C2), (C4) |
grad_offset |
Kayan noktanın veya tensör başına ölçülmüş türde 1 boyutlu tensör | (C2), (C4) |
Sınırlamalar
- (C1)
0 <= feature_index < rank(operand)
. - (C2)
operand
,scale
,mean
,variance
,grad_output
,grad_operand
,grad_scale
vegrad_offset
aynıbaseline_element_type
değerine sahip. - (C3)
operand
,grad_output
vegrad_operand
şekle sahip. - (C4)
scale
,mean
,variance
,grad_scale
vegrad_offset
aynı şekle sahip. - (C5)
size(scale) = dim(operand, feature_index)
.
Örnekler
// %operand: [
// [[1.0, 2.0], [3.0, 4.0]],
// [[3.0, 4.0], [1.0, 2.0]]
// ]
// %scale: [1.0, 1.0]
// %mean: [2.0, 3.0]
// %variance: [1.0, 1.0]
// %grad_output: [
// [[0.1, 0.1], [0.1, 0.1]],
// [[0.1, 0.1], [0.1, 0.1]]
// ]
%grad_operand, %grad_scale, %grad_offset =
"stablehlo.batch_norm_grad"(%operand, %scale, %mean, %variance, %grad_output) {
epsilon = 0.0 : f32,
feature_index = 2 : i64
} : (tensor<2x2x2xf64>, tensor<2xf64>, tensor<2xf64>, tensor<2xf64>,
tensor<2x2x2xf64>) -> (tensor<2x2x2xf64>, tensor<2xf64>, tensor<2xf64>)
// %grad_operand: [
// [[0.0, 0.0], [0.0, 0.0]],
// [[0.0, 0.0], [0.0, 0.0]]
// ]
// %grad_scale: [0.0, 0.0]
// %grad_offset: [0.4, 0.4]
batch_norm_inference
Anlambilim
operand
tensörünü feature_index
boyutu hariç tüm boyutlarda normalleştirir ve bir result
tensörü oluşturur. Daha resmi olarak bu işlem, Python söz dizimi kullanılarak mevcut StableHLO işlemlerine ayrıştırma olarak aşağıdaki gibi ifade edilebilir:
def batch_norm_inference(operand, scale, offset, mean, variance, epsilon, feature_index):
# Broadcast inputs to shape(operand)
scale_bcast = broadcast_in_dim(scale, [feature_index], type(operand))
offset_bcast = broadcast_in_dim(offset, [feature_index], type(operand))
mean_bcast = broadcast_in_dim(mean, [feature_index], type(operand))
variance_bcast = broadcast_in_dim(variance, [feature_index], type(operand))
epsilon_bcast = broadcast_in_dim(constant(epsilon, element_type(operand)), [],
type(operand))
# Perform normalization using the provided `mean` and `variance` instead of
# computing them like `batch_norm_training` does.
centered_operand = subtract(operand, mean_bcast)
stddev = sqrt(add(variance_bcast, epsilon_bcast))
normalized_operand = divide(centered_operand, stddev)
return add(multiply(scale_bcast, normalized_operand), offset_bcast)
Miktarı ölçülmüş türler için dequantize_op_quantize(lambda operand, scale, offset, mean, variance:
batch_norm_inference(operand, scale, offset, mean, variance, epsilon,
feature_index), operand, scale, offset, mean, variance, type(result))
performans gösterir.
Girişler
Etiket | Ad | Tür | Sınırlamalar |
---|---|---|---|
(I1) | operand |
kayan nokta türü tensörü veya tensör başına ölçülebilir tensör | (C1-C7) |
(I2) | scale |
Kayan noktanın veya tensör başına ölçülmüş türde 1 boyutlu tensör | (C2), (C3) |
(I3) | offset |
Kayan noktanın veya tensör başına ölçülmüş türde 1 boyutlu tensör | (C2), (C4) |
(I4) | mean |
Kayan noktanın veya tensör başına ölçülmüş türde 1 boyutlu tensör | (C5) |
(I5) | variance |
Kayan noktanın veya tensör başına ölçülmüş türde 1 boyutlu tensör | (C2), (C6) |
(I6) | epsilon |
f32 türünün sabiti |
|
(I7) | feature_index |
si64 türünün sabiti |
(C1), (C3-C6) |
Çıkışlar
Ad | Tür | Sınırlamalar |
---|---|---|
result |
kayan nokta türü tensörü veya tensör başına ölçülebilir tensör | (C2), (C7) |
Sınırlamalar
- (C1)
0 <= feature_index < rank(operand)
. - (C2)
operand
,scale
,offset
,mean
,variance
veresult
aynıbaseline_element_type
değerine sahip. - (C3)
size(scale) = dim(operand, feature_index)
. - (C4)
size(offset) = dim(operand, feature_index)
. - (C5)
size(mean) = dim(operand, feature_index)
. - (C6)
size(variance) = dim(operand, feature_index)
. - (C7)
baseline_type(operand) = baseline_type(result)
.
Örnekler
// %operand: [
// [[1.0, 2.0], [3.0, 4.0]],
// [[3.0, 4.0], [1.0, 2.0]]
// ]
// %scale: [1.0, 1.0]
// %offset: [1.0, 1.0]
// %mean: [2.0, 3.0]
// %variance: [1.0, 1.0]
%result = "stablehlo.batch_norm_inference"(%operand, %scale, %offset, %mean, %variance) {
epsilon = 0.0 : f32,
feature_index = 2 : i64
} : (tensor<2x2x2xf64>, tensor<2xf64>, tensor<2xf64>, tensor<2xf64>, tensor<2xf64>) -> tensor<2x2x2xf64>
// %result: [
// [[0.0, 0.0], [2.0, 2.0]],
// [[2.0, 2.0], [0.0, 0.0]]
// ]
batch_norm_training
Anlambilim
feature_index
boyutu hariç tüm boyutlarda ortalama ve varyansı hesaplar ve output
, batch_mean
ve batch_var
tensörlerini üreten operand
tensörünü normalleştirir. Daha resmi olarak bu işlem, Python söz dizimi kullanılarak mevcut StableHLO işlemlerine ayrıştırma olarak aşağıdaki gibi ifade edilebilir:
def compute_mean(operand, feature_index):
(sum,) = reduce(
inputs=[operand],
init_values=[constant(0, element_type(operand))],
dimensions=[i for i in range(rank(operand)) if i != feature_index],
body=lambda x, y: add(x, y))
divisor = constant(size(operand) / dim(operand, feature_index),
element_type(operand))
divisor_bcast = broadcast_in_dim(divisor, [], type(sum))
return divide(sum, divisor_bcast)
def compute_variance(operand, feature_index):
mean = compute_mean(operand, feature_index)
mean_bcast = broadcast_in_dim(mean, [feature_index], type(operand))
centered_operand = subtract(operand, mean_bcast)
return compute_mean(mul(centered_operand, centered_operand), feature_index)
def batch_norm_training(operand, scale, offset, epsilon, feature_index):
mean = compute_mean(operand, feature_index)
variance = compute_variance(operand, feature_index)
return batch_norm_inference(operand, scale, offset, mean, variance, epsilon,
feature_index),
mean, variance
Miktarı ölçülmüş türler için dequantize_batch_norm_grad_or_training_quantize(lambda operand, scale, offset:
batch_norm_training(operand, scale, offset, epsilon, feature_index), operand,
scale, offset, type(output), type(batch_mean), type(batch_var))
performans gösterir.
Girişler
Etiket | Ad | Tür | Sınırlamalar |
---|---|---|---|
(I1) | operand |
kayan nokta türü tensörü veya tensör başına ölçülebilir tensör | (C1) |
(I2) | scale |
Kayan noktanın veya tensör başına ölçülmüş 1 boyutlu tensörü | (C2), (C3) |
(I3) | offset |
Kayan noktanın veya tensör başına ölçülmüş 1 boyutlu tensörü | (C2), (C4) |
(I4) | epsilon |
f32 türünün sabiti |
(C1), (C3-C6) |
(I5) | feature_index |
si64 türünün sabiti |
(C1), (C3-C6) |
Çıkışlar
Ad | Tür | Sınırlamalar |
---|---|---|
output |
kayan nokta türü tensörü veya tensör başına ölçülebilir tensör | (C7) |
batch_mean |
Kayan noktanın veya tensör başına ölçülmüş 1 boyutlu tensörü | (C2), (C5) |
batch_var |
Kayan noktanın veya tensör başına ölçülmüş 1 boyutlu tensörü | (C2), (C6) |
Sınırlamalar
- (C1)
0 <= feature_index < rank(operand)
. - (C2)
operand
,scale
,offset
,batch_mean
,batch_var
veoutput
aynıbaseline_element_type
değerine sahip. - (C3)
size(scale) = dim(operand, feature_index)
. - (C4)
size(offset) = dim(operand, feature_index)
. - (C5)
size(batch_mean) = dim(operand, feature_index)
. - (C6)
size(batch_var) = dim(operand, feature_index)
. - (C7)
baseline_type(output) = baseline_type(operand)
.
Örnekler
// %operand: [
// [[1.0, 2.0], [3.0, 4.0]],
// [[3.0, 4.0], [1.0, 2.0]]
// ]
// %scale: [1.0, 1.0]
// %offset: [1.0, 1.0]
%output, %batch_mean, %batch_var = "stablehlo.batch_norm_training"(%operand, %scale, %offset) {
epsilon = 0.0 : f32,
feature_index = 2 : i64
} : (tensor<2x2x2xf64>, tensor<2xf64>, tensor<2xf64>) ->
(tensor<2x2x2xf64>, tensor<2xf64>, tensor<2xf64>)
// %output: [
// [[0.0, 0.0], [2.0, 2.0]],
// [[2.0, 2.0], [0.0, 0.0]]
// ]
// %batch_mean: [2.0, 3.0]
// %batch_var: [1.0, 1.0]
bitcast_convert
Anlambilim
operand
tensörü üzerinde bir bitcast işlemi gerçekleştirir ve tüm operand
tensörünün bitlerinin result
tensörü türü kullanılarak yeniden yorumlandığı bir result
tensörü oluşturur.
E = element_type(operand)
, E' = element_type(result)
ve R = rank(operand)
göz önünde bulundurulduğunda daha resmi bir şekilde:
num_bits(E') < num_bits(E)
ise,bits(result[i0, ..., iR-1, :]) = bits(operand[i0, ..., iR-1])
.num_bits(E') > num_bits(E)
ise,bits(result[i0, ..., iR-2]) = bits(operand[i0, ..., iR-2, :])
.num_bits(E') = num_bits(E)
ise,bits(result[i0, ..., iR-1]) = bits(operand[i0, ..., iR-1])
.
bits
, belirli bir değerin bellek içi gösterimini döndürür ve tensörlerin tam temsili uygulama tanımlı olduğu ve öğe türlerinin tam temsili de uygulama tanımlı olduğu için davranışı uygulama tanımlıdır.
Girişler
Etiket | Ad | Tür | Sınırlamalar |
---|---|---|---|
(I1) | operand |
tensör veya ölçülmüş tensör | (C1-C2) |
Çıkışlar
Ad | Tür | Sınırlamalar |
---|---|---|
result |
tensör veya ölçülmüş tensör | (C1-C2) |
Sınırlamalar
- (C1)
E = is_quantized(operand) ? storage_type(operand) : element_type(operand)
,E' = is_quantized(result) ? storage_type(result) : element_type(result)
veR = rank(operand)
göz önünde bulundurulduğunda:num_bits(E') = num_bits(E)
ise,shape(result) = shape(operand)
.num_bits(E') < num_bits(E)
ise:rank(result) = R + 1
.- Tüm
0 <= i < R
içindim(result, i) = dim(operand, i)
. dim(result, R) * num_bits(E') = num_bits(E)
.num_bits(E') > num_bits(E)
ise:rank(result) = R - 1
.- Tüm
0 <= i < R
içindim(result, i) = dim(operand, i)
. dim(operand, R - 1) * num_bits(E) = num_bits(E')
.
- (C2)
is_complex(operand) or is_complex(result)
ise,is_complex(operand) and is_complex(result)
.
Örnekler
// %operand: 0x0123456789ABCDEF
%result = "stablehlo.bitcast_convert"(%operand) : (tensor<f64>) -> tensor<4xf16>
// %result: [0xCDEF, 0x89AB, 0x4567, 0x0123] // little-endian representation
broadcast_in_dim
Anlambilim
operand
tensöründeki verileri kopyalayarak bir giriş tensörünün boyutlarını ve/veya sıralamasını genişletir ve result
tensörü oluşturur. Daha resmi bir şekilde ifade etmek gerekirse axes(operand)
içindeki tüm d
için result[result_index] = operand[operand_index]
:
dim(operand, d) = 1
iseoperand_index[d] = 0
.- Aksi takdirde
operand_index[d] = result_index[broadcast_dimensions[d]]
.
Girişler
Etiket | Ad | Tür | Sınırlamalar |
---|---|---|---|
(I1) | operand |
tensör veya ölçülmüş tensör | (C1-C2), (C5-C6) |
(I2) | broadcast_dimensions |
si64 türünde 1 boyutlu tensör sabiti |
(C2-C6) |
Çıkışlar
Ad | Tür | Sınırlamalar |
---|---|---|
result |
tensör veya ölçülmüş tensör | (C1), (C3), (C5-C6) |
Sınırlamalar
- (C1)
element_type(result)
değerini veren:!is_per_axis_quantized(operand)
iseelement_type(operand)
.- Aksi takdirde
quantization_dimension(operand)
,scales(operand)
vezero_points(operand)
quantization_dimension(result)
,scales(result)
vezero_points(result)
yanıtlarından farklı olabilir. Ancakelement_type(operand)
.
- (C2)
size(broadcast_dimensions) = rank(operand)
. - (C3)
0 <= broadcast_dimensions < rank(result)
. - (C4)
is_unique(broadcast_dimensions)
. - (C5)
axes(operand)
içindeki tümd
için:dim(operand, d) = 1
veyadim(operand, d) = dim(result, broadcast_dimensions[d])
.
- (C6)
is_per_axis_quantized(result)
ise:quantization_dimension(result) = broadcast_dimensions[quantization_dimension(operand)]
.dim(operand, quantization_dimension(operand)) = 1
isescales(result)[i] = scales(operand)[0] and zero_points(result)[i] = zero_points(operand)[0] for i in range(dim(result, quantization_dimension(result)))
.
Örnekler
// %operand: [
// [1, 2, 3]
// ]
%result = "stablehlo.broadcast_in_dim"(%operand) {
broadcast_dimensions = array<i64: 2, 1>
} : (tensor<1x3xi32>) -> tensor<2x3x2xi32>
// %result: [
// [
// [1, 1],
// [2, 2],
// [3, 3]
// ],
// [
// [1, 1],
// [2, 2],
// [3, 3]
// ]
// ]
kılıf
Anlambilim
index
değerine bağlı olarak, branches
işlevinden tam olarak bir işlev yürüterek çıktıyı üretir. Daha resmî bir yapıya sahiptir. result = selected_branch()
Burada:
0 <= index < size(branches)
iseselected_branch = branches[index]
.- Aksi takdirde
selected_branch = branches[-1]
.
Girişler
Etiket | Ad | Tür | Sınırlamalar |
---|---|---|---|
(I1) | index |
si32 türünde 0 boyutlu tensör |
|
(I2) | branches |
değişken sayıda fonksiyon | (C1-C4) |
Çıkışlar
Ad | Tür | Sınırlamalar |
---|---|---|
results |
değişken sayıda tensör, ölçülmüş tensör veya jeton | (C4) |
Sınırlamalar
- (C1)
0 < size(branches)
. - (C2)
input_types(branches...) = []
. - (C3)
same(output_types(branches...))
. - (C4)
type(results...) = output_types(branches[0])
.
Örnekler
// %index: -1
// %result_branch0: [0, 0]
// %result_branch1: [1, 1]
%result0, %result1 = "stablehlo.case"(%index) ({
"stablehlo.return"(%result_branch0, %result_branch0) : (tensor<2xi64>, tensor<2xi64>) -> ()
}, {
"stablehlo.return"(%result_branch1, %result_branch1) : (tensor<2xi64>, tensor<2xi64>) -> ()
}) : (tensor<i32>) -> (tensor<2xi64>, tensor<2xi64>)
// %result0: [1, 1]
// %result1: [1, 1]
cbrt
Anlambilim
operand
tensörü üzerinde öğe bazında kübik kök işlemi gerçekleştirir ve bir result
tensörü üretir. Öğe türüne bağlı olarak aşağıdakileri yapar:
- Kayan reklamlar için: IEEE-754'ten
rootn(x, 3)
. - Karmaşık sayılar için: karmaşık kübik kök.
- Miktarı belirlenmiş türler için:
dequantize_op_quantize(cbrt, operand, type(result))
Girişler
Etiket | Ad | Tür | Sınırlamalar |
---|---|---|---|
(I1) | operand |
kayan nokta veya karmaşık tür ya da tensör başına ölçülmüş tensör tensörü | (C1) |
Çıkışlar
Ad | Tür | Sınırlamalar |
---|---|---|
result |
kayan nokta veya karmaşık tür ya da tensör başına ölçülmüş tensör tensörü | (C1) |
Sınırlamalar
- (C1)
baseline_type(operand) = baseline_type(result)
.
Örnekler
// %operand: [0.0, 1.0, 8.0, 27.0]
%result = "stablehlo.cbrt"(%operand) : (tensor<4xf64>) -> tensor<4xf64>
// %result: [0.0, 1.0, 2.0, 3.0]
ceil
Anlambilim
operand
tensörünün öğe düzeyinde tavanı gerçekleştirir ve result
tensörü üretir.
IEEE-754 spesifikasyonundan roundToIntegralTowardPositive
işlemini uygular. Miktarı ölçülmüş türler için dequantize_op_quantize(ceil, operand, type(result))
performans gösterir.
Girişler
Etiket | Ad | Tür | Sınırlamalar |
---|---|---|---|
(I1) | operand |
kayan nokta türü tensörü veya tensör başına ölçülebilir tensör | (C1) |
Çıkışlar
Ad | Tür | Sınırlamalar |
---|---|---|
result |
kayan nokta türü tensörü veya tensör başına ölçülebilir tensör | (C1) |
Sınırlamalar
- (C1)
baseline_type(operand) = baseline_type(result)
.
Örnekler
// %operand: [-0.8166, -0.2530, 0.2530, 0.8166, 2.0]
%result = "stablehlo.ceil"(%operand) : (tensor<5xf32>) -> tensor<5xf32>
// %result: [-0.0, -0.0, 1.0, 1.0, 2.0]
Cholesky
Anlambilim
Bir grup matrisin Cholesky'nin ayrıştırılmasını hesaplar.
Daha resmi olarak index_space(result)
içindeki tüm i
için result[i0, ..., iR-3, :, :]
, alt üçgen (lower
true
ise) veya üst üçgen (lower
false
ise) matrisi biçiminde bir Cholesky ayrışmasıdır.a[i0, ..., iR-3, :, :]
Karşıt üçgendeki çıkış değerleri (yani katı üst üçgen veya karşılık olarak yüksek katı alt üçgen) uygulama tarafından tanımlanır.
Giriş matrisinin Hermityen pozitif tanımlı bir matris olmadığı i
varsa davranış tanımsızdır.
Miktarı ölçülmüş türler için dequantize_op_quantize(lambda operand: cholesky(operand, lower), a, type(result))
performans gösterir.
Girişler
Etiket | Ad | Tür | Sınırlamalar |
---|---|---|---|
(I1) | a |
kayan nokta veya karmaşık tür ya da tensör başına ölçülmüş tensör tensörü | (C1-C3) |
(I2) | lower |
i1 türünde 0 boyutlu tensör sabiti |
Çıkışlar
Ad | Tür | Sınırlamalar |
---|---|---|
result |
kayan nokta veya karmaşık tür ya da tensör başına ölçülmüş tensör tensörü | (C1) |
Sınırlamalar
- (C1)
baseline_type(a) = baseline_type(result)
. - (C2)
2 <= rank(a)
. - (C3)
dim(a, -2) = dim(a, -1)
.
Örnekler
// %a: [
// [1.0, 2.0, 3.0],
// [2.0, 20.0, 26.0],
// [3.0, 26.0, 70.0]
// ]
%result = "stablehlo.cholesky"(%a) {
lower = true
} : (tensor<3x3xf32>) -> tensor<3x3xf64>
// %result: [
// [1.0, 0.0, 0.0],
// [2.0, 4.0, 0.0],
// [3.0, 5.0, 6.0]
// ]
kenetli
Anlambilim
operand
tensörünün her öğesini bir minimum ve maksimum değer arasında sabitler ve bir result
tensörü oluşturur. Daha resmi ifade etmek gerekirse result[result_index] =
minimum(maximum(operand[result_index], min_element), max_element)
;
burada min_element = rank(min) = 0 ? min[] : min[result_index]
,
max_element = rank(max) = 0 ? max[] : max[result_index]
. Ölçülen türler için dequantize_op_quantize(clamp, min, operand, max, type(result))
performans gösterir.
Karmaşık sayılara sıralama uygulamak şaşırtıcı anlamlar içerir. Bu nedenle, gelecekte bu işleme yönelik karmaşık sayılara yönelik desteği kaldırmayı planlıyoruz (#560).
Girişler
Etiket | Ad | Tür | Sınırlamalar |
---|---|---|---|
(I1) | min |
tensör veya tensör başına ölçülebilir tensör | (C1), (C3) |
(I2) | operand |
tensör veya tensör başına ölçülebilir tensör | (C1-C4) |
(I3) | max |
tensör veya tensör başına ölçülebilir tensör | (C2), (C3) |
Çıkışlar
Ad | Tür | Sınırlamalar |
---|---|---|
result |
tensör veya tensör başına ölçülebilir tensör | (C4) |
Sınırlamalar
- (C1)
rank(min) = 0 or shape(min) = shape(operand)
. - (C2)
rank(max) = 0 or shape(max) = shape(operand)
. - (C3)
baseline_element_type(min) = baseline_element_type(operand) = baseline_element_type(max)
. - (C4)
baseline_type(operand) = baseline_type(result)
.
Örnekler
// %min: [5, 10, 15]
// %operand: [3, 13, 23]
// %max: [10, 15, 20]
%result = "stablehlo.clamp"(%min, %operand, %max) : (tensor<3xi32>, tensor<3xi32>, tensor<3xi32>) -> tensor<3xi32>
// %result: [5, 13, 20]
collective_broadcast
Anlambilim
StableHLO işlem ızgarasındaki her işlem grubunda, kaynak işlemdeki operand
tensörünün değerini hedef işlemlere gönderin ve bir result
tensörü oluşturun.
İşlem, StableHLO işlem tablosunu aşağıdaki şekilde tanımlanan process_groups
olarak böler:
channel_id <= 0
isecross_replica(replica_groups)
.channel_id > 0
isecross_partition(replica_groups)
.
Sonrasında, result@process
ödülünü veren:
- İşlemin
process_groups[i]
içinde olduğu biri
varsaoperand@process_groups[i, 0]
. - Aksi takdirde
broadcast_in_dim(constant(0, element_type(result)), [], type(result))
.
Girişler
Etiket | Ad | Tür | Sınırlamalar |
---|---|---|---|
(I1) | operand |
tensor | (C3) |
(I2) | replica_groups |
si64 türündeki 1 boyutlu tensör sabitlerinin değişken sayısı |
(C1), (C2) |
(I3) | channel_id |
si64 türünün sabiti |
Çıkışlar
Ad | Tür | Sınırlamalar |
---|---|---|
result |
tensor | (C3) |
Sınırlamalar
- (C1)
is_unique(replica_groups)
. - (C2)
0 <= replica_groups < N
içindeN
şu şekilde tanımlanır:cross_replica
kullanılıyorsanum_replicas
.cross_partition
kullanılıyorsanum_partitions
.
- (C3)
type(result) = type(operand)
.
Örnekler
// num_replicas: 4
// num_partitions: 1
// %operand@(0, 0): [[1, 2]]
// %operand@(1, 0): [[3, 4]]
// %operand@(2, 0): [[5, 6]]
// %operand@(3, 0): [[7, 8]]
%result = "stablehlo.collective_broadcast"(%operand) {
replica_groups = dense<[[2, 1]]> : tensor<1x2xi64>,
channel_handle = #stablehlo.channel_handle<handle = 0, type = 0>
} : (tensor1x2xi64>) -> tensor<1x2xi64>
// %result@(0, 0): [[0, 0]]
// %result@(1, 0): [[5, 6]]
// %result@(2, 0): [[5, 6]]
// %result@(3, 0): [[0, 0]]
collective_permute
Anlambilim
StableHLO işlem ızgarasındaki her işlem grubunda, operand
tensörünün değerini kaynak işlemden hedef işleme gönderir ve bir result
tensörü üretir.
İşlem, StableHLO işlem tablosunu aşağıdaki şekilde tanımlanan process_groups
olarak böler:
channel_id <= 0
isecross_replica(source_target_pairs)
.channel_id > 0
isecross_partition(source_target_pairs)
.
Sonrasında, result@process
ödülünü veren:
process_groups[i, 1] = process
gibi biri
varsaoperand@process_groups[i, 0]
.- Aksi takdirde
broadcast_in_dim(constant(is_quantized(result) ? quantize(0, element_type(result)) : 0, element_type(result)), [], type(result))
.
Girişler
Etiket | Ad | Tür | Sınırlamalar |
---|---|---|---|
(I1) | operand |
tensör veya tensör başına ölçülebilir tensör | (C5) |
(I2) | source_target_pairs |
si64 türünde 2 boyutlu tensör sabiti |
(C1-C4) |
(I3) | channel_id |
si64 türünün sabiti |
Çıkışlar
Ad | Tür | Sınırlamalar |
---|---|---|
result |
tensör veya tensör başına ölçülebilir tensör | (C1) |
Sınırlamalar
- (C1)
dim(source_target_pairs, 1) = 2
. - (C2)
is_unique(source_target_pairs[:, 0])
. - (C3)
is_unique(source_target_pairs[:, 1])
. - (C4)
0 <= source_target_pairs < N
; buradaN
şu şekilde tanımlanır:cross_replica
kullanılıyorsanum_replicas
.cross_partition
kullanılıyorsanum_partitions
.
- (C5)
type(result) = type(operand)
.
Örnekler
// num_replicas: 3
// num_partitions: 1
// %operand@(0, 0): [[1, 2], [3, 4]]
// %operand@(1, 0): [[5, 6], [7, 8]]
// %operand@(2, 0): [[9, 10], [11, 12]]
%result = "stablehlo.collective_permute"(%operand) {
source_target_pairs = dense<[[0, 1], [1, 2]]> : tensor<2x2xi64>,
channel_handle = #stablehlo.channel_handle<handle = 0, type = 0>
} : (tensor<2x2xi64>) -> tensor<2x2xi64>
//
// %result@(0, 0): [[0, 0], [0, 0]]
// %result@(1, 0): [[1, 2], [3, 4]]
// %result@(2, 0): [[5, 6], [7, 8]]
compare
Anlambilim
comparison_direction
ve compare_type
tensörlerine göre lhs
ve rhs
tensörleri için öğe bazında karşılaştırma yapar ve bir result
tensörü üretir.
comparison_direction
ve compare_type
değerleri şu anlamlara sahiptir:
Boole ve tam sayı öğe türleri için:
EQ
:lhs = rhs
.NE
:lhs != rhs
.GE
:lhs >= rhs
.GT
:lhs > rhs
.LE
:lhs <= rhs
.LT
:lhs < rhs
.
İşlem, compare_type = FLOAT
içeren kayan nokta öğe türleri için aşağıdaki IEEE-754 işlemlerini uygular:
EQ
:compareQuietEqual
.NE
:compareQuietNotEqual
.GE
:compareQuietGreaterEqual
.GT
:compareQuietGreater
.LE
:compareQuietLessEqual
.LT
:compareQuietLess
.
İşlem, compare_type = TOTALORDER
içeren kayan nokta öğe türleri için IEEE-754'teki totalOrder
ve compareQuietEqual
işlemlerinin kombinasyonunu kullanır. Bu özelliğin kullanılmadığı görülüyor, bu nedenle gelecekte bu özelliği kaldırmayı planlıyoruz (#584).
Karmaşık öğe türleri için (real, imag)
çiftlerinin sözlüksel karşılaştırması, sağlanan comparison_direction
ve compare_type
kullanılarak yapılır.
Karmaşık sayılara sıralama uygulamak şaşırtıcı anlamlar içerir. Bu nedenle, gelecekte comparison_direction
GE
, GT
, LE
veya LT
(#560) olduğunda karmaşık sayılara yönelik desteği kaldırmayı planlıyoruz.
Ölçülen türler için dequantize_compare(lhs, rhs,
comparison_direction)
performans gösterir.
Girişler
Etiket | Ad | Tür | Sınırlamalar |
---|---|---|---|
(I1) | lhs |
tensör veya tensör başına ölçülebilir tensör | (C1-C3) |
(I2) | rhs |
tensör veya tensör başına ölçülebilir tensör | (C1-C2) |
(I3) | comparison_direction |
EQ , NE , GE , GT , LE ve LT sıralaması |
|
(I4) | compare_type |
FLOAT , TOTALORDER , SIGNED ve UNSIGNED sıralaması |
(C3) |
Çıkışlar
Ad | Tür | Sınırlamalar |
---|---|---|
result |
boole türü tensör | (C2) |
Sınırlamalar
- (C1)
baseline_element_type(lhs) = baseline_element_type(rhs)
. - (C2)
shape(lhs) = shape(rhs) = shape(result)
. - (C3)
compare_type
şu şekilde tanımlanır:is_signed_integer(element_type(lhs))
iseSIGNED
.is_unsigned_integer(element_type(lhs)) or is_boolean(element_type(lhs))
iseUNSIGNED
.is_float(element_type(lhs))
iseFLOAT
veyaTOTALORDER
.is_complex(element_type(lhs))
iseFLOAT
.
Örnekler
// %lhs: [1.0, 3.0]
// %rhs: [1.1, 2.9]
%result = "stablehlo.compare"(%lhs, %rhs) {
comparison_direction = #stablehlo<comparison_direction LT>,
compare_type = #stablehlo<comparison_type FLOAT>
} : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xi1>
// %result: [true, false]
karmaşık
Anlambilim
Bir çift gerçek ve sanal değerden (lhs
ve rhs
) karmaşık bir değere öğe bazında dönüşüm gerçekleştirir ve bir result
tensörü oluşturur.
Girişler
Etiket | Ad | Tür | Sınırlamalar |
---|---|---|---|
(I1) | lhs |
f32 veya f64 türünde tensör |
(C1-C3) |
(I2) | rhs |
f32 veya f64 türünde tensör |
(C1) |
Çıkışlar
Ad | Tür | Sınırlamalar |
---|---|---|
result |
karmaşık türde tensör | (C2), (C3) |
Sınırlamalar
- (C1)
type(lhs) = type(rhs)
. - (C2)
shape(result) = shape(lhs)
. - (C3)
element_type(result)
,E = element_type(lhs)
buradacomplex<E>
türündedir.
Örnekler
// %lhs: [1.0, 3.0]
// %rhs: [2.0, 4.0]
%result = "stablehlo.complex"(%lhs, %rhs) : (tensor<2xf64>, tensor<2xf64>) -> tensor<2xcomplex<f64>>
// %result: [(1.0, 2.0), (3.0, 4.0)]
concatenate
Anlambilim
inputs
öğesini, dimension
boyutu boyunca belirtilen bağımsız değişkenlerle aynı sırayla birleştirir ve bir result
tensörü oluşturur. Daha resmi bir şekilde,
result[i0, ..., id, ..., iR-1] = inputs[k][i0, ..., kd, ..., iR-1]
. Burada:
id = d0 + ... + dk-1 + kd
.d
,dimension
değerine eşittir ved0
,inputs
öğesinind
. boyut boyutlarıdır.
Girişler
Etiket | Ad | Tür | Sınırlamalar |
---|---|---|---|
(I1) | inputs |
değişken sayıda tensör veya tensör başına ölçülmüş tensör | (C1-C6) |
(I2) | dimension |
si64 türünün sabiti |
(C2), (C4), (C6) |
Çıkışlar
Ad | Tür | Sınırlamalar |
---|---|---|
result |
tensör veya tensör başına ölçülebilir tensör | (C5-C6) |
Sınırlamalar
- (C1)
same(element_type(inputs...))
. - (C2)
dim(inputs..., dimension)
hariçsame(shape(inputs...))
. - (C3)
0 < size(inputs)
. - (C4)
0 <= dimension < rank(inputs[0])
. - (C5)
element_type(result) = element_type(inputs[0])
. - (C6)
shape(result) = shape(inputs[0])
, şunlar hariç:dim(result, dimension) = dim(inputs[0], dimension) + ...
.
Örnekler
// %input0: [[1, 2], [3, 4], [5, 6]]
// %input1: [[7, 8]]
%result = "stablehlo.concatenate"(%input0, %input1) {
dimension = 0 : i64
} : (tensor<3x2xi64>, tensor<1x2xi64>) -> tensor<4x2xi64>
// %result: [[1, 2], [3, 4], [5, 6], [7, 8]]
sabit
Anlambilim
Sabit bir value
değerinden bir output
tensörü üretir.
Girişler
Etiket | Ad | Tür | Sınırlamalar |
---|---|---|---|
(I1) | value |
sabit | (C1) |
Çıkışlar
Ad | Tür | Sınırlamalar |
---|---|---|
output |
tensör veya ölçülmüş tensör | (C1) |
Sınırlamalar
- (C1)
type(value) = type(output)
.
Örnekler
%output = "stablehlo.constant"() {
value = dense<[[0.0, 1.0], [2.0, 3.0]]> : tensor<2x2xf32>
} : () -> tensor<2x2xf32>
// %output: [[0.0, 1.0], [2.0, 3.0]]
Dönüşüm gerçekleştirme
Anlambilim
operand
tensörü üzerinde bir öğe türünden diğerine öğe düzeyinde bir dönüşüm gerçekleştirir ve bir result
tensörü oluşturur.
boolean-to-any-supported-type dönüşümde false
değeri sıfıra, true
değeri ise bire dönüştürülür. any-supported-type-to-boolean dönüşümde, sıfır değeri false
, sıfır olmayan değerler de true
değerine dönüştürülür. Bunun karmaşık türler için nasıl
çalıştığını aşağıda görebilirsiniz.
Tamsayıdan-tamsayıya, tamsayıdan-kayan noktaya veya kayan-noktadan-kayan noktaya dayalı dönüşümler için kaynak değer, hedef türünde tam olarak temsil edilebiliyorsa sonuç değeri bu tam gösterim olur. Aksi takdirde davranış TBD olur (#180).
floating-point-to-integer dönüşüm içeren dönüşümlerde kesirli kısım kısaltılır. Kısaltılmış değer hedef türünde temsil edilemiyorsa davranış TBD (#180) olur.
Karmaşıktan karmaşıka dönüşüm içeren dönüşümler, gerçek ve sanal parçaların dönüştürülmesi için kayan noktadan kayan noktaya dönüşümlerle aynı davranış izler.
complex-to-any-other-type ve complex-to-any-other-type dönüşümlerinde kaynak sanal değeri yok sayılır ya da hedef sanal değeri sıfırlanır. Gerçek parçanın dönüşümü, kayan nokta dönüşümlerini izler.
Prensip olarak, bu işlem; miktar azalmasını (ölçümlü tensörlerden normal tensörlere dönüştürme), nicelleştirmeyi (normal tensörlerden nicelleştirilmiş tensörlere dönüştürme) ve yeniden ölçmeyi (sayısal tensörler arasındaki dönüşüm) ifade edebilir ancak şu anda bunun için özel işlemlerimiz vardır: ikinci kullanım durumu için üçüncü ve uniform_quantize
.uniform_dequantize
Gelecekte bu iki işlem convert
(#1576) ile birleştirilebilir.
Girişler
Etiket | Ad | Tür | Sınırlamalar |
---|---|---|---|
(I1) | operand |
tensor | (C1) |
Çıkışlar
Ad | Tür | Sınırlamalar |
---|---|---|
result |
tensor | (C1) |
Sınırlamalar
- (C1)
shape(operand) = shape(result)
.
Örnekler
// %operand: [-1, 0, 1]
%result = "stablehlo.convert"(%operand) : (tensor<3xi64>) -> tensor<3xcomplex<f64>>
// %result: [(-1.0, 0.0), (0.0, 0.0), (1.0, 0.0)]
konvolüsyon
Anlambilim
lhs
pencereleri ile rhs
dilimleri arasındaki nokta ürünlerini hesaplar ve result
sonucunu verir. Aşağıdaki şemada, result
içindeki öğelerin lhs
ve rhs
değerlerinden nasıl hesaplandığı gösterilmektedir.
Daha resmi bir şekilde, lhs
pencerelerini ifade etmek için girişleri lhs
çerçevesinde aşağıdaki yeniden çerçevelemeyi düşünün:
lhs_window_dimensions = lhs_shape(dim(lhs, input_batch_dimension), dim(rhs, kernel_spatial_dimensions), dim(lhs, input_feature_dimension))
.lhs_window_strides = lhs_shape(1, window_strides, 1)
.lhs_padding = lhs_shape([0, 0], padding, [0, 0])
.lhs_base_dilations = lhs_shape(1, lhs_dilation, 1)
.lhs_window_dilations = lhs_shape(1, rhs_dilation, 1)
.
Bu yeniden çerçeveleme aşağıdaki yardımcı işlevleri kullanır:
lhs_shape(n, hw, c) = permute([n] + hw + [c], [input_batch_dimension] + input_spatial_dimensions + [input_feature_dimension])
.result_shape(n1, hw, c1) = permute([n1] + hw + [c1], [output_batch_dimension] + output_spatial_dimensions + [output_feature_dimension])
.permute([j0, j1, ..., jR-1], permutation) = [i0, i1, ..., iR-1]
; buradaj[d] = i[permutation[d]]
.
feature_group_count = 1
ve batch_group_count = 1
ise aşağıdaki durumlarda index_space(dim(result, output_spatial_dimensions...))
result[result_shape(:, output_spatial_index, :)] = dot_product
konumundaki tüm output_spatial_index
için:
padding_value = constant(0, element_type(lhs))
.padded_lhs = pad(lhs, padding_value, lhs_padding[:, 0], lhs_padding[:, 1], lhs_base_dilations - 1)
.lhs_window_start = lhs_shape(0, output_spatial_index, 0) * lhs_window_strides
.lhs_window = slice(padded_lhs, lhs_window_start, lhs_window_start + lhs_window_dimensions, lhs_window_dilations)
.reversed_lhs_window = reverse(lhs_window, [input_spatial_dimensions[dim] for dim in range(size(window_reversal)) if window_reversal[dim] = true])
. Bu özelliğin kullanılmadığı görülüyor. Bu nedenle gelecekte bu özelliği kaldırmayı planlıyoruz (#1181).dot_product = dot_general(reversed_lhs_window, rhs, lhs_batching_dimensions=[], lhs_contracting_dimensions=input_spatial_dimensions + [input_feature_dimension], rhs_batching_dimensions=[], rhs_contracting_dimensions=kernel_spatial_dimensions + [kernel_input_feature_dimension])
.
feature_group_count > 1
ise:
lhses = split(lhs, feature_group_count, input_feature_dimension)
.rhses = split(rhs, feature_group_count, kernel_output_feature_dimension)
.results... = convolution(lhses..., rhses..., ..., feature_group_count=1, ...)
.result = concatenate(results, output_feature_dimension)
.
batch_group_count > 1
ise:
lhses = split(lhs, batch_group_count, input_batch_dimension)
.rhses = split(rhs, batch_group_count, kernel_output_feature_dimension)
.results... = convolution(lhses..., rhses..., ..., batch_group_count=1, ...)
.result = concatenate(results, output_feature_dimension)
.
Miktarı ölçülmüş türler için dequantize_op_quantize(
lambda lhs, rhs: convolution(lhs, rhs, window_strides, padding,
lhs_dilation, rhs_dilation, window_reversal, input_batch_dimension,
input_feature_dimension, input_spatial_dimensions,
kernel_input_feature_dimension, kernel_output_feature_dimension,
kernel_spatial_dimensions, output_batch_dimension,
output_feature_dimension, output_spatial_dimensions,
feature_group_count, batch_group_count, precision_config), lhs, rhs,
type(result))
performans gösterir.
Girişler
Etiket | Ad | Tür | Sınırlamalar |
---|---|---|---|
(I1) | lhs |
tensör veya tensör başına ölçülebilir tensör | (C1), (C10-C11), (C14) (C25), (C27-C30) |
(I2) | rhs |
tensör veya ölçülmüş tensör | (C1), (C14-C16), (C25), (C27-C32) |
(I3) | window_strides |
si64 türünde 1 boyutlu tensör sabiti |
(C2-C3), (C25) |
(I4) | padding |
si64 türünde 2 boyutlu tensör sabiti |
(C4), (C25) |
(I5) | lhs_dilation |
si64 türünde 1 boyutlu tensör sabiti |
(C5-C6), (C25) |
(I6) | rhs_dilation |
si64 türünde 1 boyutlu tensör sabiti |
(C7-C8), (C25) |
(I7) | window_reversal |
i1 türünde 1 boyutlu tensör sabiti |
(C9) |
(I8) | input_batch_dimension |
si64 türünün sabiti |
(C10), (C13), (C25) |
(I9) | input_feature_dimension |
si64 türünün sabiti |
(C11), (C13-C14) |
(I10) | input_spatial_dimensions |
si64 türünde 1 boyutlu tensör sabiti |
(C12), (C13), (C25) |
(I11) | kernel_input_feature_dimension |
si64 türünün sabiti |
(C14), (C18) |
(I12) | kernel_output_feature_dimension |
si64 türünün sabiti |
(C15-C16), (C18), (C25), (C32) |
(I13) | kernel_spatial_dimensions |
si64 türünde 1 boyutlu tensör sabiti |
(C17-C18), (C25) |
(I14) | output_batch_dimension |
si64 türünün sabiti |
(C20), (C25) |
(I15) | output_feature_dimension |
si64 türünün sabiti |
(C20), (C25), (C33) |
(I16) | output_spatial_dimensions |
si64 türünde 1 boyutlu tensör sabiti |
(C19-C20), (C25) |
(I17) | feature_group_count |
si64 türünün sabiti |
(C11), (C14), (C16), (C21), (C23) |
(I18) | batch_group_count |
si64 türünün sabiti |
(C10), (C15), (C22), (C23), (C25) |
(I19) | precision_config |
DEFAULT , HIGH ve HIGHEST enumlarının değişken sayısı |
(C24) |
Çıkışlar
Ad | Tür | Sınırlamalar |
---|---|---|
result |
tensör veya ölçülmüş tensör | (C25-C28), (C30-C31), (C33) |
Sınırlamalar
- (C1)
N = rank(lhs) = rank(rhs)
. - (C2)
size(window_strides) = N - 2
. - (C3)
0 < window_strides
. - (C4)
shape(padding) = [N - 2, 2]
. - (C5)
size(lhs_dilation) = N - 2
. - (C6)
0 < lhs_dilation
. - (C7)
size(rhs_dilation) = N - 2
. - (C8)
0 < rhs_dilation
. - (C9)
size(window_reversal) = N - 2
. - (C10)
dim(lhs, input_batch_dimension) % batch_group_count = 0
. - (C11)
dim(lhs, input_feature_dimension) % feature_group_count = 0
. - (C12)
size(input_spatial_dimensions) = N - 2
. - (C13) Verilen
input_dimensions = [input_batch_dimension] + input_spatial_dimensions + [input_feature_dimension]
:is_unique(input_dimensions)
.0 <= input_dimensions < N
.
- (C14)
dim(rhs, kernel_input_feature_dimension) = dim(lhs, input_feature_dimension) / feature_group_count
. - (C15)
dim(rhs, kernel_output_feature_dimension) % batch_group_count = 0
. - (C16)
dim(rhs, kernel_output_feature_dimension) % feature_group_count = 0
. - (C17)
size(kernel_spatial_dimensions) = N - 2
. - (C18) Verilen
kernel_dimensions = kernel_spatial_dimensions + [kernel_input_feature_dimension] + [kernel_output_feature_dimension]
:is_unique(kernel_dimensions)
.0 <= kernel_dimensions < N
.
- (C19)
size(output_spatial_dimensions) = N - 2
. - (C20) Verilen
output_dimensions = [output_batch_dimension] + output_spatial_dimensions + [output_feature_dimension]
:is_unique(output_dimensions)
.0 <= output_dimensions < N
.
- (C21)
0 < feature_group_count
. - (C22)
0 < batch_group_count
. - (C23)
feature_group_count = 1 or batch_group_count = 1
. - (C24)
size(precision_config) = 2
. - (C25)
dim(result, result_dim)
şu şekilde tanımlanır:result_dim = output_batch_dimension
isedim(lhs, input_batch_dimension) / batch_group_count
.result_dim = output_feature_dimension
isedim(rhs, kernel_output_feature_dimension)
.num_windows
ise aksi takdirde:output_spatial_dimensions[spatial_dim] = result_dim
.lhs_dim = input_spatial_dimensions[spatial_dim]
.rhs_dim = kernel_spatial_dimensions[spatial_dim]
.dilated_input_shape[lhs_dim] = dim(lhs, lhs_dim) = 0 ? 0 : (dim(lhs, lhs_dim) - 1) * lhs_dilation[spatial_dim] + 1
.padded_input_shape[lhs_dim] = padding[spatial_dim, 0] + dilated_input_shape[lhs_dim] + padding[spatial_dim, 1]
.dilated_window_shape[lhs_dim] = dim(rhs, rhs_dim) = 0 ? 0 : (dim(rhs, rhs_dim) - 1) * rhs_dilation[spatial_dim] + 1
.is_empty_window[lhs_dim] = padded_input_shape[lhs_dim] = 0 || dilated_window_shape[lhs_dim] > padded_input_shape[lhs_dim]
.num_windows = is_empty_window[lhs_dim] ? 0 : floor((padded_input_shape[lhs_dim] - dilated_window_shape[lhs_dim]) / window_strides[spatial_dim]) + 1
.
- (C26)
rank(result) = N
. - İşlemde ölçülmeyen tensörler kullanılıyorsa:
- (C27)
element_type(lhs) = element_type(rhs) = element_type(result)
.
- (C27)
- İşlemde ölçülmüş tensörler kullanılıyorsa:
- (C28)
is_quantized_tensor(lhs) and is_quantized_tensor(rhs) and is_quantized_tensor(result)
. - (C29)
storage_type(lhs) = storage_type(rhs)
. - (C30)
expressed_type(lhs) = expressed_type(rhs) = expressed_type(result)
. - (C31)
is_per_tensor_quantized(rhs)
iseis_per_tensor_quantized(result)
. - (C32)
is_per_axis_quantized(rhs)
isequantization_dimension(rhs) = kernel_output_feature_dimension
. - (C33)
is_per_axis_quantized(result)
isequantization_dimension(result) = output_feature_dimension
.
- (C28)
Örnekler
// %lhs: [[
// [
// [1], [2], [5], [6]
// ],
// [
// [3], [4], [7], [8]
// ],
// [
// [10], [11], [14], [15]
// ],
// [
// [12], [13], [16], [17]
// ]
// ]]
//
// %rhs : [
// [[[1]], [[1]], [[1]]],
// [[[1]], [[1]], [[1]]],
// [[[1]], [[1]], [[1]]]
// ]
%result = "stablehlo.convolution"(%lhs, %rhs) {
window_strides = dense<4> : tensor<2xi64>,
padding = dense<0> : tensor<2x2xi64>,
lhs_dilation = dense<2> : tensor<2xi64>,
rhs_dilation = dense<1> : tensor<2xi64>,
window_reversal = dense<false> : tensor<2xi1>,
// In the StableHLO dialect, dimension numbers are encoded via:
// `[<input dimensions>]x[<kernel dimensions>]->[output dimensions]`.
// "b" is batch dimension, "f" is feature dimension,
// "i" is input feature dimension, "o" is output feature dimension,
// "0/1/etc" are spatial dimensions.
dimension_numbers = #stablehlo.conv<[b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f]>,
feature_group_count = 1 : i64,
batch_group_count = 1 : i64,
precision_config = [#stablehlo<precision DEFAULT>, #stablehlo<precision DEFAULT>]
} : (tensor<1x4x4x1xi32>, tensor<3x3x1x1xi32>) -> tensor<1x2x2x1xi32>
// %result: [[
// [[10], [26]],
// [[46], [62]]
// ]]
kosinüs
Anlambilim
operand
tensörü üzerinde öğe düzeyinde kosinüs işlemi gerçekleştirir ve result
tensörü üretir. Öğe türüne bağlı olarak aşağıdakileri yapar:
- Kayan reklamlar için: IEEE-754'ten
cos
. - Karmaşık sayılar için: karmaşık kosinüs.
- Nicel türler için:
dequantize_op_quantize(cosine, operand, type(result))
.
Girişler
Etiket | Ad | Tür | Sınırlamalar |
---|---|---|---|
(I1) | operand |
kayan nokta veya karmaşık tür ya da tensör başına ölçülmüş tensör tensörü | (C1) |
Çıkışlar
Ad | Tür | Sınırlamalar |
---|---|---|
result |
kayan nokta veya karmaşık tür ya da tensör başına ölçülmüş tensör tensörü | (C1) |
Sınırlamalar
- (C1)
baseline_type(operand) = baseline_type(result)
.
Örnekler
// %operand: [
// [0.0, 1.57079632], // [0, pi/2]
// [3.14159265, 4.71238898] // [pi, 3pi/2]
// ]
%result = "stablehlo.cosine"(%operand) : (tensor<2x2xf32>) -> tensor<2x2xf32>
// %result: [[1.0, 0.0], [-1.0, 0.0]]
count_leading_zeros
Anlambilim
operand
tensöründe baştaki sıfır bitlerinin öğe bazında sayımını yapar ve bir result
tensörü üretir.
Girişler
Etiket | Ad | Tür | Sınırlamalar |
---|---|---|---|
(I1) | operand |
tam sayı türünde tensör | (C1) |
Çıkışlar
Ad | Tür | Sınırlamalar |
---|---|---|
result |
tam sayı türünde tensör | (C1) |
Sınırlamalar
- (C1)
type(operand) = type(result)
.
Örnekler
// %operand: [[0, 1], [128, -1]]
%result = "stablehlo.count_leading_zeros"(%operand) : (tensor<2x2xi64>) -> tensor<2x2xi64>
// %result: [[64, 63], [56, 0]]
custom_call
Anlambilim
inputs
ile called_computations
alan ve results
sonucunu veren uygulama tanımlı bir call_target_name
işlemini içerir. has_side_effect
,
backend_config
ve api_version
, uygulama tanımlı ek meta veriler sağlamak için
kullanılabilir.
Şu anda bu işlem, XLA derleyicisindeki eşdeğer işleminin organik evrimini yansıtan, oldukça düzensiz bir meta veri koleksiyonu içermektedir. Gelecekte bu meta verileri birleştirmeyi planlıyoruz (#741).
Girişler
Etiket | Ad | Tür |
---|---|---|
(I1) | inputs |
değişken sayı |
(I2) | call_target_name |
string türünün sabiti |
(I3) | has_side_effect |
i1 türünün sabiti |
(I4) | backend_config |
string türünün sabiti |
(I5) | api_version |
si32 türünün sabiti |
(I6) | called_computations |
string türünün değişken sabit sayısı |
Çıkışlar
Ad | Tür |
---|---|
results |
değişken sayı |
Örnekler
%results = "stablehlo.custom_call"(%input0) {
call_target_name = "foo",
has_side_effect = false,
backend_config = "bar",
api_version = 1 : i32,
called_computations = [@foo]
} : (tensor<f64>) -> tensor<f64>
bölme
Anlambilim
Bölen lhs
ve bölen rhs
tensörlerinin öğe bazında bölme işlemini gerçekleştirir ve bir result
tensörü oluşturur. Öğe türüne bağlı olarak aşağıdakileri yapar:
- Tam sayılar için: Herhangi bir kesirli kısmı atılarak cebir bölümünü oluşturan tamsayı bölme.
- Kayan reklamlar için: IEEE-754'ten
division
. - Karmaşık sayılar için: karmaşık bölme.
- Nicel türler için:
dequantize_op_quantize(divide, lhs, rhs, type(result))
.
Girişler
Etiket | Ad | Tür | Sınırlamalar |
---|---|---|---|
(I1) | lhs |
tam sayı, kayan nokta veya karmaşık tür tensörü ya da tensör başına ölçülmüş tensör | (C1) |
(I2) | rhs |
tam sayı, kayan nokta veya karmaşık tür tensörü ya da tensör başına ölçülmüş tensör | (C1) |
Çıkışlar
Ad | Tür | Sınırlamalar |
---|---|---|
result |
tam sayı, kayan nokta, karmaşık tür veya tensör başına ölçülmüş tensör tensörü | (C1) |
Sınırlamalar
- (C1)
baseline_type(lhs) = baseline_type(rhs) = baseline_type(result)
.
Örnekler
// %lhs: [17.1, -17.1, 17.1, -17.1]
// %rhs: [3.0, 3.0, -3.0, -3.0]
%result = "stablehlo.divide"(%lhs, %rhs) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
// %result: [5.66666651, -5.66666651, -5.66666651, 5.66666651]
dot_general
Anlambilim
lhs
dilimleri ile rhs
dilimleri arasındaki nokta ürünlerini hesaplar ve bir result
tensör oluşturur.
Daha resmîdir, result[result_index] = dot_product
. Burada:
lhs_result_dimensions = [d for d in axes(lhs) and d not in lhs_batching_dimensions and d not in lhs_contracting_dimensions]
.rhs_result_dimensions = [d for d in axes(rhs) and d not in rhs_batching_dimensions and d not in rhs_contracting_dimensions]
.result_batching_index + result_lhs_index + result_rhs_index = result_index
size(result_batching_index) = size(lhs_batching_dimensions)
,size(result_lhs_index) = size(lhs_result_dimensions)
vesize(result_rhs_index) = size(rhs_result_dimensions)
olacak şekilde.transposed_lhs = transpose(lhs, lhs_batching_dimensions + lhs_result_dimensions + lhs_contracting_dimensions)
.transposed_lhs_slice = slice(transposed_lhs, result_batching_index + result_lhs_index + [:, ..., :])
.reshaped_lhs_slice = reshape(transposed_lhs_slice, dims(lhs, lhs_contracting_dimensions))
.transposed_rhs = transpose(rhs, rhs_batching_dimensions + rhs_result_dimensions + rhs_contracting_dimensions)
.transposed_rhs_slice = slice(transposed_rhs, result_batching_index + result_rhs_index + [:, ..., :])
.reshaped_rhs_slice = reshape(transposed_rhs_slice, dims(rhs, rhs_contracting_dimensions))
.dot_product = reduce( inputs=[multiply(reshaped_lhs_slice, reshaped_rhs_slice)], init_values=[constant(0, element_type(result))], dimensions=range(size(lhs_contracting_dimensions)), body=lambda x, y: add(x, y))
.
Miktarı ölçülmüş türler için dequantize_op_quantize(
lambda lhs, rhs: dot_general(lhs, rhs, lhs_batching_dimensions,
rhs_batching_dimensions, lhs_contracting_dimensions,
rhs_contracting_dimensions, precision_config), lhs, rhs, type(result))
performans gösterir.
Bu yalnızca tensor başına niceleme için anlamları belirtir. Eksen başına ölçümle ilgili çalışmalar devam etmektedir (#1574). Ayrıca, gelecekte karma niceleme için destek eklemeyi düşünebiliriz (#1575).
precision_config
, hızlandırıcı arka uçlarındaki hesaplamalarda hız ile doğruluk arasındaki dengeyi kontrol eder. Bu, aşağıdakilerden biri olabilir (şu anda bu enum değerlerinin anlamı eksiktir, ancak bu sorunu #755 içinde ele almayı planlıyoruz):
DEFAULT
: En hızlı hesaplama, ancak orijinal sayıya en az doğru tahmindir.HIGH
: Hesaplama daha yavaştır, ancak orijinal sayıya daha doğru bir tahminde bulunur.HIGHEST
: Orijinal sayıya en yavaş, ancak en doğru tahminde bulunur.
Girişler
Etiket | Ad | Tür | Sınırlamalar |
---|---|---|---|
(I1) | lhs |
tensör veya tensör başına ölçülebilir tensör | (C5-C6), (C9-C10), (C12-C16) |
(I2) | rhs |
tensör veya tensör başına ölçülebilir tensör | (C7-C10), (C12) |
(I3) | lhs_batching_dimensions |
si64 türünde 1 boyutlu tensör sabiti |
(C1), (C3), (C5), (C9), (C12) |
(I4) | rhs_batching_dimensions |
si64 türünde 1 boyutlu tensör sabiti |
(C1), (C4), (C7), (C9) |
(I5) | lhs_contracting_dimensions |
si64 türünde 1 boyutlu tensör sabiti |
(C2), (C3), (C6), (C10) |
(I6) | rhs_contracting_dimensions |
si64 türünde 1 boyutlu tensör sabiti |
(C2), (C4), (C8), (C10) |
(I7) | precision_config |
DEFAULT , HIGH ve HIGHEST enumlarının değişken sayısı |
(C11) |
Çıkışlar
Ad | Tür | Sınırlamalar |
---|---|---|
result |
tensör veya tensör başına ölçülebilir tensör | (C12), (C14), (C16) |
Sınırlamalar
- (C1)
size(lhs_batching_dimensions) = size(rhs_batching_dimensions)
. - (C2)
size(lhs_contracting_dimensions) = size(rhs_contracting_dimensions)
. - (C3)
is_unique(lhs_batching_dimensions + lhs_contracting_dimensions)
. - (C4)
is_unique(rhs_batching_dimensions + rhs_contracting_dimensions)
. - (C5)
0 <= lhs_batching_dimensions < rank(lhs)
. - (C6)
0 <= lhs_contracting_dimensions < rank(lhs)
. - (C7)
0 <= rhs_batching_dimensions < rank(rhs)
. - (C8)
0 <= rhs_contracting_dimensions < rank(rhs)
. - (C9)
dim(lhs, lhs_batching_dimensions...) = dim(rhs, rhs_batching_dimensions...)
. - (C10)
dim(lhs, lhs_contracting_dimensions...) = dim(rhs, rhs_contracting_dimensions...)
. - (C11)
size(precision_config) = 2
. - (C12)
shape(result) = dim(lhs, lhs_batching_dimensions) + dim(lhs, lhs_result_dimensions) + dim(rhs, rhs_result_dimensions)
. - İşlemde ölçülmeyen tensörler kullanılıyorsa:
- (C13)
element_type(lhs) = element_type(rhs)
.
- (C13)
- İşlemde ölçülmüş tensörler kullanılıyorsa:
- (C14)
is_quantized(lhs) and is_quantized(rhs) and is_quantized(result)
. - (C15)
storage_type(lhs) = storage_type(rhs)
. - (C16)
expressed_type(lhs) = expressed_type(rhs) = expressed_type(result)
. - (C17)
zero_points(rhs) = 0
.
- (C14)
Örnekler
// %lhs: [
// [[1, 2],
// [3, 4]],
// [[5, 6],
// [7, 8]]
// ]
// %rhs: [
// [[1, 0],
// [0, 1]],
// [[1, 0],
// [0, 1]]
// ]
%result = "stablehlo.dot_general"(%lhs, %rhs) {
dot_dimension_numbers = #stablehlo.dot<
lhs_batching_dimensions = [0],
rhs_batching_dimensions = [0],
lhs_contracting_dimensions = [2],
rhs_contracting_dimensions = [1]
>,
precision_config = [#stablehlo<precision DEFAULT>, #stablehlo<precision DEFAULT>]
} : (tensor<2x2x2xi64>, tensor<2x2x2xi64>) -> tensor<2x2x2xi64>
// %result: [
// [[1, 2],
// [3, 4]],
// [[5, 6],
// [7, 8]]
// ]
dynamic_slice
Anlambilim
Dinamik olarak hesaplanan başlangıç dizinlerini kullanarak operand
öğesinden bir dilim çıkarır ve bir result
tensörü üretir. start_indices
, potansiyel düzenlemeye tabi olan her boyut için dilimin başlangıç dizinlerini, slice_sizes
ise her boyut için dilimin boyutlarını içerir. Daha resmi şekilde,
result[result_index] = operand[operand_index]
burada:
adjusted_start_indices = clamp(0, start_indices, shape(operand) - slice_sizes)
.operand_index = adjusted_start_indices + result_index
.
Girişler
Etiket | Ad | Tür | Sınırlamalar |
---|---|---|---|
(I1) | operand |
tensör veya tensör başına ölçülebilir tensör | (C1), (C2), (C4) |
(I2) | start_indices |
tam sayı türündeki 0 boyutlu tensörlerin değişken sayısı | (C2), (C3) |
(I3) | slice_sizes |
si64 türünde 1 boyutlu tensör sabiti |
(C2), (C4), (C5) |
Çıkışlar
Ad | Tür | Sınırlamalar |
---|---|---|
result |
tensör veya tensör başına ölçülebilir tensör | (C1), (C5) |
Sınırlamalar
- (C1)
element_type(operand) = element_type(result)
. - (C2)
size(start_indices) = size(slice_sizes) = rank(operand)
. - (C3)
same(type(start_indices...))
. - (C4)
0 <= slice_sizes <= shape(operand)
. - (C5)
shape(result) = slice_sizes
.
Örnekler
// %operand: [
// [0, 0, 1, 1],
// [0, 0, 1, 1],
// [0, 0, 0, 0],
// [0, 0, 0, 0]
// ]
// %start_indices0: -1
// %start_indices1: 3
%result = "stablehlo.dynamic_slice"(%operand, %start_indices0, %start_indices1) {
slice_sizes = dense<[2, 2]> : tensor<2xi64>
} : (tensor<4x4xi32>, tensor<i64>, tensor<i64>) -> tensor<2x2xi32>
// %result: [
// [1, 1],
// [1, 1]
// ]
dynamic_update_slice
Anlambilim
start_indices
ile başlayan dilimin update
içindeki değerlerle güncellenmesi dışında operand
tensörüne eşit bir result
tensörü oluşturur.
Daha resmi bir şekilde, result[result_index]
şu şekilde tanımlanır:
update[update_index]
0 <= update_index < shape(update)
ise:adjusted_start_indices = clamp(0, start_indices, shape(operand) - shape(update))
.update_index = result_index - adjusted_start_indices
.
- Aksi takdirde
operand[result_index]
.
Girişler
Etiket | Ad | Tür | Sınırlamalar |
---|---|---|---|
(I1) | operand |
tensör veya tensör başına ölçülebilir tensör | (C1-C4), (C6) |
(I2) | update |
tensör veya tensör başına ölçülebilir tensör | (C2), (C3), (C6) |
(I3) | start_indices |
tam sayı türündeki 0 boyutlu tensörlerin değişken sayısı | (C4), (C5) |
Çıkışlar
Ad | Tür | Sınırlamalar |
---|---|---|
result |
tensör veya tensör başına ölçülebilir tensör | (C1) |
Sınırlamalar
- (C1)
type(operand) = type(result)
. - (C2)
element_type(update) = element_type(operand)
. - (C3)
rank(update) = rank(operand)
. - (C4)
size(start_indices) = rank(operand)
. - (C5)
same(type(start_indices...))
. - (C6)
0 <= shape(update) <= shape(operand)
.
Örnekler
// %operand: [
// [1, 1, 0, 0],
// [1, 1, 0, 0],
// [1, 1, 1, 1],
// [1, 1, 1, 1]
// ]
// %update: [
// [1, 1],
// [1, 1]
// ]
// %start_indices0: -1
// %start_indices1: 3
%result = "stablehlo.dynamic_update_slice"(%operand, %update, %start_indices0, %start_indices1)
: (tensor<4x4xi32>, tensor<2x2xi32>, tensor<i64>, tensor<i64>) -> tensor<4x4xi32>
// %result: [
// [1, 1, 1, 1],
// [1, 1, 1, 1],
// [1, 1, 1, 1],
// [1, 1, 1, 1]
// ]
üstel
Anlambilim
operand
tensöründe öğe bazında üstel işlem gerçekleştirir ve bir result
tensörü üretir. Öğe türüne bağlı olarak aşağıdakileri yapar:
- Kayan reklamlar için: IEEE-754'ten
exp
. - Karmaşık sayılar için: karmaşık üstel.
- Ölçülen türler için:
dequantize_op_quantize(exponential, operand, type(result))
.
Girişler
Etiket | Ad | Tür | Sınırlamalar |
---|---|---|---|
(I1) | operand |
kayan nokta veya karmaşık tür ya da tensör başına ölçülmüş tensör tensörü | (C1) |
Çıkışlar
Ad | Tür | Sınırlamalar |
---|---|---|
result |
kayan nokta veya karmaşık tür ya da tensör başına ölçülmüş tensör tensörü | (C1) |
Sınırlamalar
- (C1)
baseline_type(operand) = baseline_type(result)
.
Örnekler
// %operand: [[0.0, 1.0], [2.0, 3.0]]
%result = "stablehlo.exponential"(%operand) : (tensor<2x2xf64>) -> tensor<2x2xf64>
// %result: [[1.0, 2.7182818284590451], [7.3890560989306504, 20.085536923187668]]
exponential_minus_one
Anlambilim
Öğe bazında üstel eksi bir işlem operand
tensör üzerinde gerçekleştirir ve bir result
tensör üretir. Öğe türüne bağlı olarak aşağıdakileri yapar:
- Kayan reklamlar için: IEEE-754'ten
expm1
. - Karmaşık sayılar için: karmaşık üstel eksi bir.
- Ölçülen türler için:
dequantize_op_quantize(exponential_minus_one, operand, type(result))
.
Girişler
Etiket | Ad | Tür | Sınırlamalar |
---|---|---|---|
(I1) | operand |
kayan nokta veya karmaşık tür ya da tensör başına ölçülmüş tensör tensörü | (C1) |
Çıkışlar
Ad | Tür | Sınırlamalar |
---|---|---|
result |
kayan nokta veya karmaşık tür ya da tensör başına ölçülmüş tensör tensörü | (C1) |
Sınırlamalar
- (C1)
baseline_type(operand) = baseline_type(result)
.
Örnekler
// %operand: [0.0, 1.0]
%result = "stablehlo.exponential_minus_one"(%operand) : (tensor<2xf64>) -> tensor<2xf64>
// %result: [0.0, 1.71828187]
f.
Anlambilim
Gerçek ve karmaşık girişler/çıkışlar için ileri ve ters Fourier dönüşümlerini uygular.
fft_type
şunlardan biridir:
FFT
: Karmaşıktan karmaşığa FFT'yi yönlendirin.IFFT
: Ters karmaşıktan karmaşığa FFT.RFFT
: Gerçekten karmaşıka FFT'ye yönlendirir.IRFFT
: Reel ile karmaşık FFT'nin tersine (yani karmaşıktır, gerçek sonucunu döndürür).
Daha resmi olarak, girdi olarak karmaşık türlerin 1 boyutlu tensörlerini alan fft
fonksiyonunun çıkışla aynı türde 1 boyutlu tensörler üretmesi ve ayrı Fourier dönüşümünü hesaplar:
fft_type = FFT
için result
, L = size(fft_length)
olduğu bir dizi L hesaplamasının nihai sonucu olarak tanımlanır. Örneğin, L = 3
için:
result1[i0, ..., :] = fft(operand[i0, ..., :])
.result2[i0, ..., :, iR-1] = fft(result1[i0, ..., :, iR-1])
.result[i0, ..., :, iR-2, iR-1] = fft(result2[i0, ..., :, iR-2, iR-1])
.
Ayrıca, aynı tür imzaya sahip olan ve fft
değerinin tersini hesaplayan ifft
işlevi verildiğinde:
fft_type = IFFT
için result
, fft_type = FFT
için hesaplamaların tersi olarak tanımlanır. Örneğin, L = 3
için:
result1[i0, ..., :, iR-2, iR-1] = ifft(operand[i0, ..., :, iR-2, iR-1])
.result2[i0, ..., :, iR-1] = ifft(result1[i0, ..., :, iR-1])
.result[i0, ..., :] = ifft(result2[i0, ..., :])
.
Ayrıca, kayan nokta türlerinin 1 boyutlu tensörlerini alan rfft
işlevi göz önünde bulundurulduğunda, aynı kayan nokta anlamında karmaşık türde 1 boyutlu tensörler oluşturur ve aşağıdaki şekilde çalışır:
rfft(real_operand) = truncated_result
buradacomplex_operand... = (real_operand..., 0.0)
.complex_result = fft(complex_operand)
.truncated_result = complex_result[:(rank(complex_result) / 2 + 1)]
.
(Gerçek işlenenler için ayrı Fourier dönüşümü hesaplandığında, sonucun ilk N/2 + 1
öğeleri sonucun geri kalanını açık bir şekilde tanımlar. Bu nedenle rfft
sonucu, gereksiz öğelerin hesaplanmasını önlemek için kısaltılır).
fft_type = RFFT
için result
, L = size(fft_length)
olduğu bir dizi L hesaplamasının nihai sonucu olarak tanımlanır. Örneğin, L = 3
için:
result1[i0, ..., :] = rfft(operand[i0, ..., :])
.result2[i0, ..., :, iR-1] = fft(result1[i0, ..., :, iR-1])
.result[i0, ..., :, iR-2, iR-1] = fft(result2[i0, ..., :, iR-2, iR-1])
.
Son olarak, aynı tür imzaya sahip olan ve rfft
değerinin tersini hesaplayan irfft
fonksiyonunu hatırlayın:
fft_type = IRFFT
için result
, fft_type = RFFT
için hesaplamaların tersi olarak tanımlanır. Örneğin, L = 3
için:
result1[i0, ..., :, iR-2, iR-1] = ifft(operand[i0, ..., :, iR-2, iR-1])
.result2[i0, ..., :, iR-1] = ifft(result1[i0, ..., :, iR-1])
.result[i0, ..., :] = irfft(result2[i0, ..., :])
.
Girişler
Etiket | Ad | Tür | Sınırlamalar |
---|---|---|---|
(I1) | operand |
kayan nokta veya karmaşık türde tensör | (C1), (C2), (C4), (C5) |
(I2) | fft_type |
FFT , IFFT , RFFT ve IRFFT sıralaması |
(C2), (C5) |
(I3) | fft_length |
si64 türünde 1 boyutlu tensör sabiti |
(C1), (C3), (C4) |
Çıkışlar
Ad | Tür | Sınırlamalar |
---|---|---|
result |
kayan nokta veya karmaşık türde tensör | (C2), (C4), (C5) |
Sınırlamalar
- (C1)
size(fft_length) <= rank(operand)
. - (C2)
operand
veresult
öğe türleri arasındaki ilişki değişiklik gösterir:fft_type = FFT
,element_type(operand)
veelement_type(result)
aynı karmaşık türdeyse.fft_type = IFFT
,element_type(operand)
veelement_type(result)
aynı karmaşık türdeyse.fft_type = RFFT
iseelement_type(operand)
bir kayan nokta türü,element_type(result)
ise aynı kayan nokta anlamının karmaşık bir türüdür.fft_type = IRFFT
iseelement_type(operand)
karmaşık bir tür veelement_type(result)
, aynı kayan nokta semantiğine ait bir kayan nokta türüdür.
- (C3)
1 <= size(fft_length) <= 3
. - (C4)
operand
veresult
arasında bir kayan nokta türündereal
tensörü varsa bu durumdashape(real)[-size(fft_length):] = fft_length
olur. - (C5) Aşağıdakiler hariç
shape(result) = shape(operand)
:fft_type = RFFT
ise,dim(result, -1) = dim(operand, -1) = 0 ? 0 : dim(operand, -1) / 2 + 1
.fft_type = IRFFT
ise,dim(operand, -1) = dim(result, -1) = 0 ? 0 : dim(result, -1) / 2 + 1
.
Örnekler
// %operand: [(1.0, 0.0), (0.0, 0.0), (0.0, 0.0), (0.0, 0.0)]
%result = "stablehlo.fft"(%operand) {
fft_type = #stablehlo<fft_type FFT>,
fft_length = dense<4> : tensor<1xi64>
} : (tensor<4xcomplex<f32>>) -> tensor<4xcomplex<f32>>
// %result: [(1.0, 0.0), (1.0, 0.0), (1.0, 0.0), (1.0, 0.0)]
floor
Anlambilim
Öğe bazında operand
tensör tabanını gerçekleştirir ve result
tensörü üretir.
IEEE-754 spesifikasyonundan roundToIntegralTowardNegative
işlemini uygular. Miktarı ölçülmüş türler için dequantize_op_quantize(floor, operand, type(result))
performans gösterir.
Girişler
Etiket | Ad | Tür | Sınırlamalar |
---|---|---|---|
(I1) | operand |
kayan nokta türü tensörü veya tensör başına ölçülebilir tensör | (C1) |
Çıkışlar
Ad | Tür | Sınırlamalar |
---|---|---|
result |
kayan nokta türü tensörü veya tensör başına ölçülebilir tensör | (C1) |
Sınırlamalar
- (C1)
baseline_type(operand) = baseline_type(result)
.
Örnekler
// %operand: [-0.8166, -0.2530, 0.2530, 0.8166, 2.0]
%result = "stablehlo.floor"(%operand) : (tensor<5xf32>) -> tensor<5xf32>
// %result: [-1.0, -1.0, 0.0, 0.0, 2.0]
toplamak
Anlambilim
operand
tensöründen dilimleri start_indices
politikasındaki ofsetlerden toplar ve bir result
tensörü oluşturur.
Aşağıdaki şemada, result
içindeki öğelerin operand
içindeki öğelerle nasıl eşleştiği somut bir örnek üzerinden gösterilmektedir. Şema birkaç örnek result
dizini seçer ve bunların hangi operand
dizinlerine karşılık geldiğini ayrıntılı olarak açıklar.
Daha resmîdir. result[result_index] = operand[operand_index]
. Burada:
batch_dims = [d for d in axes(result) and d not in offset_dims]
.batch_index = result_index[batch_dims...]
.start_index
şu şekilde tanımlanır:index_vector_dim
<rank(start_indices)
isebi
değeribatch_index
öğesindeki bağımsız öğelerdir ve:
index_vector_dim
dizinine eklenir.start_indices[bi0, ..., :, ..., biN]
.- Aksi takdirde
[start_indices[batch_index]]
.
axes(operand)
bölgesinded_operand
içind_operand = start_index_map[d_start]
isefull_start_index[d_operand] = clamp(start_index[d_start], 0, dim(operand, d_operand) - slice_sizes[d_operand])
.- Aksi takdirde
full_start_index[d_operand] = 0
.
offset_index = result_index[offset_dims...]
.oi
öğesininoffset_index
içindeki bağımsız öğeler olduğu ve0
değerinincollapsed_slice_dims
içindeki dizinlere eklendiğifull_offset_index = [oi0, ..., 0, ..., oiN]
.operand_index = full_start_index + full_offset_index
.
indices_are_sorted
true
ise uygulama, start_indices
öğesinin start_index_map
göre sıralandığı varsayılabilir, aksi takdirde davranış tanımsız kalır. Daha resmi olarak, indices(result)
tarihinden itibaren tüm i1 < i2
için
full_start_index(i1) <= full_start_index(i2)
.
Girişler
Etiket | Ad | Tür | Sınırlamalar |
---|---|---|---|
(I1) | operand |
tensör veya tensör başına ölçülebilir tensör | (C1), (C7), (C10-C12), (C14) |
(I2) | start_indices |
tam sayı türünde tensör | (C2), (C3), (C13) |
(I3) | offset_dims |
si64 türünde 1 boyutlu tensör sabiti |
(C1), (C4-C5), (C13) |
(I4) | collapsed_slice_dims |
si64 türünde 1 boyutlu tensör sabiti |
(C1), (C6-C8), (C13) |
(I5) | start_index_map |
si64 türünde 1 boyutlu tensör sabiti |
(C3), (C9), (C10) |
(I6) | index_vector_dim |
si64 türünün sabiti |
(C2), (C3), (C13) |
(I7) | slice_sizes |
si64 türünde 1 boyutlu tensör sabiti |
(C8), (C11-C13) |
(I8) | indices_are_sorted |
i1 türünün sabiti |
Çıkışlar
Ad | Tür | Sınırlamalar |
---|---|---|
result |
tensör veya tensör başına ölçülebilir tensör | (C5), (C13-C14) |
Sınırlamalar
- (C1)
rank(operand) = size(offset_dims) + size(collapsed_slice_dims)
. - (C2)
0 <= index_vector_dim <= rank(start_indices)
. - (C3)
size(start_index_map) = index_vector_dim < rank(start_indices) ? dim(start_indices, index_vector_dim) : 1
. - (C4)
is_unique(offset_dims) and is_sorted(offset_dims)
. - (C5)
0 <= offset_dims < rank(result)
. - (C6)
is_unique(collapsed_slice_dims) and is_sorted(collapsed_slice_dims)
. - (C7)
0 <= collapsed_slice_dims < rank(operand)
. - (C8)
slice_sizes[collapsed_slice_dims...] <= 1
. - (C9)
is_unique(start_index_map)
. - (C10)
0 <= start_index_map < rank(operand)
. - (C11)
size(slice_sizes) = rank(operand)
. - (C12)
0 <= slice_sizes <= shape(operand)
. - (C13)
shape(result) = combine(batch_dim_sizes, offset_dim_sizes)
. Burada:batch_dim_sizes = shape(start_indices)
dışında,index_vector_dim
öğesine karşılık gelenstart_indices
boyut boyutu dahil edilmez.offset_dim_sizes = shape(slice_sizes)
dışında,slice_sizes
içindecollapsed_slice_dims
'ye karşılık gelen boyut boyutları dahil edilmez.combine
,batch_dim_sizes
öğesinibatch_dims
veoffset_dim_sizes
öğesine karşılık gelen eksenlereoffset_dims
değerine karşılık gelen eksenlere yerleştirir.
- (C14)
element_type(operand) = element_type(result)
.
Örnekler
// %operand: [
// [[1, 2], [3, 4], [5, 6], [7, 8]],
// [[9, 10],[11, 12], [13, 14], [15, 16]],
// [[17, 18], [19, 20], [21, 22], [23, 24]]
// ]
// %start_indices: [
// [[0, 0], [1, 0], [2, 1]],
// [[0, 1], [1, 1], [0, 2]]
// ]
%result = "stablehlo.gather"(%operand, %start_indices) {
dimension_numbers = #stablehlo.gather<
offset_dims = [2, 3],
collapsed_slice_dims = [0],
start_index_map = [1, 0],
index_vector_dim = 2>,
slice_sizes = dense<[1, 2, 2]> : tensor<3xi64>,
indices_are_sorted = false
} : (tensor<3x4x2xi32>, tensor<2x3x2xi64>) -> tensor<2x3x2x2xi32>
// %result: [
// [
// [[1, 2], [3, 4]],
// [[3, 4], [5, 6]],
// [[13, 14], [15, 16]]
// ],
// [
// [[9, 10], [11, 12]],
// [[11, 12], [13, 14]],
// [[17, 18], [19, 20]]
// ]
// ]
get_dimension_size
Anlambilim
operand
için belirtilen dimension
değerinin boyutunu oluşturur. Daha resmi olarak,
result = dim(operand, dimension)
.
Girişler
Etiket | Ad | Tür | Sınırlamalar |
---|---|---|---|
(I1) | operand |
tensor | (C1) |
(I2) | dimension |
si64 türünün sabiti |
(C1) |
Çıkışlar
Ad | Tür |
---|---|
result |
si32 türünde 0 boyutlu tensör |
Sınırlamalar
- (C1)
0 <= dimension < rank(operand)
.
Örnekler
// %operand: [[1, 2, 3], [4, 5, 6]]
%result = "stablehlo.get_dimension_size"(%operand) {
dimension = 1 : i64
} : (tensor<2x3xi64>) -> tensor<i32>
// %result: 3
get_tuple_element
Anlambilim
operand
defterinin index
konumundaki öğeyi çıkarır ve result
sonucunu verir. Daha resmi şekilde, result = operand[index]
.
Girişler
Etiket | Ad | Tür | Sınırlamalar |
---|---|---|---|
(I1) | operand |
tuple | (C1), (C2) |
(I2) | index |
si32 türünün sabiti |
(C1), (C2) |
Çıkışlar
Ad | Tür | Sınırlamalar |
---|---|---|
result |
desteklenen tüm türler | (C2) |
Sınırlamalar
- (C1)
0 <= index < size(operand)
. - (C2)
type(result) = tuple_element_types(operand)[index]
.
Örnekler
// %operand: ([1.0, 2.0], (3))
%result = "stablehlo.get_tuple_element"(%operand) {
index = 0 : i32
} : (tuple<tensor<2xf32>, tuple<tensor<i32>>>) -> tensor<2xf32>
// %result: [1.0, 2.0]
if
Anlambilim
pred
değerine bağlı olarak, true_branch
veya false_branch
işlevinde tam olarak bir işlev yürüterek çıktıyı üretir. Daha resmi şekilde, result =
pred ? true_branch() : false_branch()
.
Girişler
Etiket | Ad | Tür | Sınırlamalar |
---|---|---|---|
(I1) | pred |
i1 türünde 0 boyutlu tensör |
|
(I2) | true_branch |
işlev | (C1-C3) |
(I3) | false_branch |
işlev | (C1), (C2) |
Çıkışlar
Ad | Tür | Sınırlamalar |
---|---|---|
results |
değişken sayıda tensör, ölçülmüş tensör veya jeton | (C3) |
Sınırlamalar
- (C1)
input_types(true_branch) = input_types(false_branch) = []
. - (C2)
output_types(true_branch) = output_types(false_branch)
. - (C3)
type(results...) = output_types(true_branch)
.
Örnekler
// %result_true_branch: 10
// %result_false_branch: 11
// %pred: true
%result = "stablehlo.if"(%pred) ({
"stablehlo.return"(%result_true_branch) : (tensor<i32>) -> ()
}, {
"stablehlo.return"(%result_false_branch) : (tensor<i32>) -> ()
}) : (tensor<i1>) -> tensor<i32>
// %result: 10
hayal etmek
Anlambilim
operand
öğesinden öğe düzeyinde sanal bölümü çıkarır ve bir result
tensörü üretir. Daha resmi bir şekilde, her x
öğesi için:
imag(x) = is_complex(x) ? imaginary_part(x) :
constant(0, element_type(result))
.
Girişler
Etiket | Ad | Tür | Sınırlamalar |
---|---|---|---|
(I1) | operand |
kayan nokta veya karmaşık türde tensör | (C1), (C2) |
Çıkışlar
Ad | Tür | Sınırlamalar |
---|---|---|
result |
kayan nokta türünün tensörü | (C1), (C2) |
Sınırlamalar
- (C1)
shape(result) = shape(operand)
. - (C2)
element_type(result)
şu şekilde tanımlanır:is_complex(operand)
isecomplex_element_type(element_type(operand))
.- Aksi takdirde
element_type(operand)
.
Örnekler
// %operand: [(1.0, 2.0), (3.0, 4.0)]
%result = "stablehlo.imag"(%operand) : (tensor<2xcomplex<f32>>) -> tensor<2xf32>
// %result: [2.0, 4.0]
feed içi
Anlambilim
Feed içi verileri okur ve results
değerini oluşturur.
infeed_config
semantiği, uygulama tarafından tanımlanmıştır.
results
, başta gelen yük değerleri ve en son gelen jetondan oluşur. Gelecekte, netliği artırmak için yükü ve jetonu iki ayrı çıkışa bölmeyi planlıyoruz (#670).
Girişler
Etiket | Ad | Tür |
---|---|---|
(I1) | token |
token |
(I2) | infeed_config |
string türünün sabiti |
Çıkışlar
Ad | Tür | Sınırlamalar |
---|---|---|
results |
değişken sayıda tensör, ölçülmüş tensör veya jeton | (C1-C3) |
Sınırlamalar
- (C1)
0 < size(results)
. - (C2)
is_empty(result[:-1])
veyais_tensor(type(results[:-1]))
. - (C3)
is_token(type(results[-1]))
.
Örnekler
// %token: !stablehlo.token
// infeed_queue[0]: [[1, 2], [3, 4]]
// infeed_queue[1]: [[5, 6], [7, 8]]
%results0:2 = "stablehlo.infeed"(%token) {
infeed_config = ""
} : (!stablehlo.token) -> (tensor<2x2xi64>, !stablehlo.token)
// results0#0: [[1, 2], [3, 4]]
%results1:2 = "stablehlo.infeed"(%token) {
infeed_config = ""
} : (!stablehlo.token) -> (tensor<2x2xi64>, !stablehlo.token)
// results1#0: [[5, 6], [7, 8]]
Iota
Anlambilim
Bir output
tensörünü, iota_dimension
boyutu boyunca sıfırdan başlayıp artan sırada değerlerle doldurur. Daha resmî bir şekilde ifade etmek
output[result_index] = constant(is_quantized(output) ?
quantize(result_index[iota_dimension], element_type(output)) :
result_index[iota_dimension], element_type(output))
.
Girişler
Etiket | Ad | Tür | Sınırlamalar |
---|---|---|---|
(I1) | iota_dimension |
si64 |
(C1) |
Çıkışlar
Ad | Tür | Sınırlamalar |
---|---|---|
output |
tam sayı, kayan nokta veya karmaşık tür tensörü ya da tensör başına ölçülmüş tensör | (C1) |
Sınırlamalar
- (C1)
0 <= iota_dimension < rank(output)
.
Örnekler
%output = "stablehlo.iota"() {
iota_dimension = 0 : i64
} : () -> tensor<4x5xi32>
// %output: [
// [0, 0, 0, 0, 0],
// [1, 1, 1, 1, 1],
// [2, 2, 2, 2, 2],
// [3, 3, 3, 3, 3]
// ]
%output = "stablehlo.iota"() {
iota_dimension = 1 : i64
} : () -> tensor<4x5xi32>
// %output: [
// [0, 1, 2, 3, 4],
// [0, 1, 2, 3, 4],
// [0, 1, 2, 3, 4],
// [0, 1, 2, 3, 4]
// ]
is_finite
Anlambilim
x
nesnesindeki değerin sonlu olup olmadığını (yani ne +Inf, -Inf veya NaN olmadığını) öğe bazında kontrol eder ve bir y
tensörü üretir. IEEE-754 spesifikasyonundan isFinite
işlemini uygular. Nicelleştirilmiş türler için sonuç her zaman true
olur.
Girişler
Etiket | Ad | Tür | Sınırlamalar |
---|---|---|---|
(I1) | x |
kayan nokta türü tensörü veya tensör başına ölçülebilir tensör | (C1) |
Çıkışlar
Ad | Tür | Sınırlamalar |
---|---|---|
y |
boole türü tensör | (C1) |
Sınırlamalar
- (C1)
shape(x) = shape(y)
.
Örnekler
// Logical values: -Inf, +Inf, NaN, ...
// %x: [0xFFF0000000000000, 0x7FF0000000000000, 0x7FF8000000000000, -10.0, -0.0, 0.0, 10.0]
%y = "stablehlo.is_finite"(%x) : (tensor<7xf64) -> tensor<7xi1>
// %y: [false, false, false, true, true, true, true]
log
Anlambilim
operand
tensörü üzerinde öğe bazında logaritma işlemi gerçekleştirir ve bir result
tensörü üretir. Öğe türüne bağlı olarak aşağıdakileri yapar:
- Kayan reklamlar için: IEEE-754'ten
log
. - Karmaşık sayılar için: karmaşık logaritma.
- Nicel türler için:
dequantize_op_quantize(log, operand, type(result))
.
Girişler
Etiket | Ad | Tür | Sınırlamalar |
---|---|---|---|
(I1) | operand |
kayan nokta veya karmaşık tür ya da tensör başına ölçülmüş tensör tensörü | (C1) |
Çıkışlar
Ad | Tür | Sınırlamalar |
---|---|---|
result |
kayan nokta veya karmaşık tür ya da tensör başına ölçülmüş tensör tensörü | (C1) |
Sınırlamalar
- (C1)
baseline_type(operand) = baseline_type(result)
.
Örnekler
// %operand: [[1.0, 2.0], [3.0, 4.0]]
%result = "stablehlo.log"(%operand) : (tensor<2x2xf64>) -> tensor<2x2xf64>
// %result: [[0.0, 0.69314718055994529], [1.0986122886681098, 1.3862943611198906]]
log_plus_one
Anlambilim
Öğe bazında logaritma ve operand
tensöründe bir işlem gerçekleştirir ve bir result
tensörü oluşturur. Öğe türüne bağlı olarak aşağıdakileri yapar:
- Kayan reklamlar için: IEEE-754'ten
logp1
. - Karmaşık sayılar için: karmaşık logaritma artı bir.
- Ölçülen türler için:
dequantize_op_quantize(log_plus_one, operand, type(result))
.
Girişler
Etiket | Ad | Tür | Sınırlamalar |
---|---|---|---|
(I1) | operand |
kayan nokta veya karmaşık tür ya da tensör başına ölçülmüş tensör tensörü | (C1) |
Çıkışlar
Ad | Tür | Sınırlamalar |
---|---|---|
result |
kayan nokta veya karmaşık tür ya da tensör başına ölçülmüş tensör tensörü | (C1) |
Sınırlamalar
- (C1)
baseline_type(operand) = baseline_type(result)
.
Örnekler
// %operand: [0.0, -0.999, 7.0, 6.38905621, 15.0]
%result = "stablehlo.log_plus_one"(%operand) : (tensor<5xf64>) -> tensor<5xf64>
// %result: [0.0, -6.90776825, 2.07944155, 2.0, 2.77258873]
lojistik
Anlambilim
operand
tensörü üzerinde öğe düzeyinde lojistik işlem gerçekleştirir ve bir result
tensörü üretir. Öğe türüne bağlı olarak aşağıdakileri yapar:
- Kayan reklamlar için: IEEE-754'ten
division(1, addition(1, exp(-x)))
. - Karmaşık sayılar için: karmaşık lojistik.
- Ölçülen türler için:
dequantize_op_quantize(logistic, operand, type(result))
.
Girişler
Etiket | Ad | Tür | Sınırlamalar |
---|---|---|---|
(I1) | operand |
kayan nokta veya karmaşık tür ya da tensör başına ölçülmüş tensör tensörü | (C1) |
Çıkışlar
Ad | Tür | Sınırlamalar |
---|---|---|
result |
kayan nokta veya karmaşık tür ya da tensör başına ölçülmüş tensör tensörü | (C1) |
Sınırlamalar
- (C1)
baseline_type(operand) = baseline_type(result)
.
Örnekler
// %operand: [[0.0, 1.0], [2.0, 3.0]]
%result = "stablehlo.logistic"(%operand) : (tensor<2x2xf64>) -> tensor<2x2xf64>
// %result: [[0.5, 0.73105858], [0.88079708, 0.95257413]]
harita
Anlambilim
dimensions
boyunca inputs
öğesine bir computation
harita işlevi uygular ve bir result
tensörü oluşturur.
Daha resmi şekilde, result[result_index] = computation(inputs...[result_index])
.
dimensions
etiketinin şu anda kullanılmadığını ve gelecekte kaldırılma ihtimalinin yüksek olduğunu unutmayın (#487).
Girişler
Etiket | Ad | Tür | Sınırlamalar |
---|---|---|---|
(I1) | inputs |
değişken sayıda tensör veya tensör başına ölçülmüş tensör | (C1-C4) |
(I2) | dimensions |
si64 türünde 1 boyutlu tensör sabiti |
(C3) |
(I3) | computation |
işlev | (C4) |
Çıkışlar
Ad | Tür | Sınırlamalar |
---|---|---|
result |
tensör veya tensör başına ölçülebilir tensör | (C1), (C4) |
Sınırlamalar
- (C1)
shape(inputs...) = shape(result)
. - (C2)
0 < size(inputs) = N
. - (C3)
dimensions = range(rank(inputs[0]))
. - (C4)
computation
,Ei = element_type(inputs[i])
veE' = element_type(result)
olacak şekilde(tensor<E0>, ..., tensor<EN-1>) -> tensor<E'>
türündedir.
Örnekler
// %input0: [[0, 1], [2, 3]]
// %input1: [[4, 5], [6, 7]]
%result = "stablehlo.map"(%input0, %input1) ({
^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
%0 = stablehlo.multiply %arg0, %arg1 : tensor<i64>
stablehlo.return %0 : tensor<i64>
}) {
dimensions = dense<[0, 1]> : tensor<2xi64>
} : (tensor<2x2xi64>, tensor<2x2xi64>) -> tensor<2x2xi64>
// %result: [[0, 5], [12, 21]]
maksimum
Anlambilim
lhs
ve rhs
tensörlerinde öğe bazında maksimum işlem gerçekleştirir ve bir result
tensörü üretir. Öğe türüne bağlı olarak aşağıdakileri yapar:
- Boole için: mantıksal OR.
- Tam sayılar için: maksimum tam sayı.
- Kayan reklamlar için: IEEE-754'ten
maximum
. - Karmaşık sayılar için:
(real, imaginary)
çiftinin sözlükteki maksimum değeri. Karmaşık sayılara sıralama uygulamak şaşırtıcı anlamlar içerir. Bu nedenle, gelecekte bu işleme yönelik karmaşık sayılara yönelik desteği kaldırmayı planlıyoruz (#560). - Nicel türler için:
dequantize_op_quantize(maximum, lhs, rhs, type(result))
.
Girişler
Etiket | Ad | Tür | Sınırlamalar |
---|---|---|---|
(I1) | lhs |
tensör veya tensör başına ölçülebilir tensör | (C1) |
(I2) | rhs |
tensör veya tensör başına ölçülebilir tensör | (C1) |
Çıkışlar
Ad | Tür | Sınırlamalar |
---|---|---|
result |
tensör veya tensör başına ölçülebilir tensör | (C1) |
Sınırlamalar
- (C1)
baseline_type(lhs) = baseline_type(rhs) = baseline_type(result)
.
Örnekler
// %lhs: [[1, 2], [7, 8]]
// %rhs: [[5, 6], [3, 4]]
%result = "stablehlo.maximum"(%lhs, %rhs) : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
// %result: [[5, 6], [7, 8]]
minimum
Anlambilim
lhs
ve rhs
tensörlerinde öğe bazında minimum işlem ve result
tensör üretir. Öğe türüne bağlı olarak aşağıdakileri yapar:
- Boole için: mantıksal VE.
- Tam sayılar için: minimum tam sayı.
- Kayan reklamlar için: IEEE-754'ten
minimum
. - Karmaşık sayılar için:
(real, imaginary)
çifti için sözlükte belirtilen minimum değer. Karmaşık sayılara sıralama uygulamak şaşırtıcı anlamlar içerir. Bu nedenle, gelecekte bu işleme yönelik karmaşık sayılara yönelik desteği kaldırmayı planlıyoruz (#560). - Nicel türler için:
dequantize_op_quantize(minimum, lhs, rhs, type(result))
.
Girişler
Etiket | Ad | Tür | Sınırlamalar |
---|---|---|---|
(I1) | lhs |
tensör veya tensör başına ölçülebilir tensör | (C1) |
(I2) | rhs |
tensör veya tensör başına ölçülebilir tensör | (C1) |
Çıkışlar
Ad | Tür | Sınırlamalar |
---|---|---|
result |
tensör veya tensör başına ölçülebilir tensör | (C1) |
Sınırlamalar
- (C1)
baseline_type(lhs) = baseline_type(rhs) = baseline_type(result)
.
Örnekler
// %lhs: [[1, 2], [7, 8]]
// %rhs: [[5, 6], [3, 4]]
%result = "stablehlo.minimum"(%lhs, %rhs) : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
// %result: [[1, 2], [3, 4]]
çarpma
Anlambilim
İki tensörün (lhs
ve rhs
) öğe bazlı çarpımını gerçekleştirir ve bir result
tensörü üretir. Öğe türüne bağlı olarak aşağıdakileri yapar:
- Boole için: mantıksal VE.
- Tam sayılar için: tam sayı çarpım.
- Kayan reklamlar için: IEEE-754'ten
multiplication
. - Karmaşık sayılar için: karmaşık çarpma.
- Nicel türler için:
dequantize_op_quantize(multiply, lhs, rhs, type(result))
.
Girişler
Etiket | Ad | Tür | Sınırlamalar |
---|---|---|---|
(I1) | lhs |
tensör veya tensör başına ölçülebilir tensör | (C1) |
(I2) | rhs |
tensör veya tensör başına ölçülebilir tensör | (C1) |
Çıkışlar
Ad | Tür | Sınırlamalar |
---|---|---|
result |
tensör veya tensör başına ölçülebilir tensör | (C1) |
Sınırlamalar
- (C1)
baseline_type(operand) = baseline_type(result)
.
Örnekler
// %lhs: [[1, 2], [3, 4]]
// %rhs: [[5, 6], [7, 8]]
%result = "stablehlo.multiply"(%lhs, %rhs) : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
// %result: [[5, 12], [21, 32]]
işareti değiştir
Anlambilim
operand
tensörünün öğe düzeyinde olumsuzluğunu gerçekleştirir ve result
tensörü üretir. Öğe türüne bağlı olarak aşağıdakileri yapar:
- İşaretli tam sayılar için: tam sayı olumsuzlama.
- İmzasız tam sayılar için: bitcast'ten işaretli tam sayıya, tam sayı olumsuzlama, imzasız tam sayıya geri gönderme.
- Kayan reklamlar için: IEEE-754'ten
negate
. - Karmaşık sayılar için: karmaşık olumsuzlama.
- Ölçülen türler için:
dequantize_op_quantize(negate, operand, type(result))
.
Girişler
Etiket | Ad | Tür | Sınırlamalar |
---|---|---|---|
(I1) | operand |
tam sayı, kayan nokta, karmaşık tür veya tensör başına ölçülmüş tensör tensörü | (C1) |
Çıkışlar
Ad | Tür | Sınırlamalar |
---|---|---|
result |
tam sayı, kayan nokta, karmaşık tür veya tensör başına ölçülmüş tensör tensörü | (C1) |
Sınırlamalar
- (C1)
baseline_type(operand) = baseline_type(result)
.
Örnekler
// Negation operation with integer Tensors
// %operand: [0, -2]
%result = "stablehlo.negate"(%operand) : (tensor<2xi32>) -> tensor<2xi32>
// %result: [0, 2]
// Negation operation with with complex tensors
// %operand: (2.5, 0.0)
%result = "stablehlo.negate"(%operand) : (tensor<1xcomplex<f32>>) -> tensor<1xcomplex<f32>>
// %result: [-2.5, -0.0]
değil
Anlambilim
Öğe bazında NOT tensörü operand
gerçekleştirir ve result
tensör üretir.
Öğe türüne bağlı olarak aşağıdakileri yapar:
- Boole için: mantıksal DEĞİL.
- Tam sayılar için: bit tabanlı DEĞİL.
Bağımsız değişkenler
Ad | Tür | Sınırlamalar |
---|---|---|
operand |
boole veya tam sayı türü tensörü | (C1) |
Çıkışlar
Ad | Tür | Sınırlamalar |
---|---|---|
result |
boole veya tam sayı türü tensörü | (C1) |
Sınırlamalar
- (C1)
type(operand) = type(result)
.
Örnekler
// Bitwise operation with with integer tensors
// %operand: [[1, 2], [3, 4]]
%result = "stablehlo.not"(%operand) : (tensor<2x2xi32>) -> tensor<2x2xi32>
// %result: [[-2, -3], [-4, -5]]
// Bitwise operation with with boolean tensors
// %operand: [true, false]
%result = "stablehlo.not"(%operand) : (tensor<2xi1>) -> tensor<2xi1>
// %result: [false, true]
optimization_barrier
Anlambilim
operand
üreten işlemlerin, result
öğesine bağlı tüm işlemlerden önce yürütülmesini sağlar ve derleyici dönüşümlerinin işlemleri bariyer boyunca taşımasını önler. Bunun dışındaki işlem bir kimliktir (yani result = operand
).
Bağımsız değişkenler
Ad | Tür | Sınırlamalar |
---|---|---|
operand |
değişken sayıda tensör, tensör başına ölçülebilir tensör veya jeton | (C1) |
Çıkışlar
Ad | Tür | Sınırlamalar |
---|---|---|
result |
değişken sayıda tensör, tensör başına ölçülebilir tensör veya jeton | (C1) |
Sınırlamalar
- (C1)
type(operand...) = type(result...)
.
Örnekler
// %operand0: 0.0
// %operand1: 1.0
%result0, %result1 = "stablehlo.optimization_barrier"(%operand0, %operand1) : (tensor<f32>, tensor<f32>) -> (tensor<f32>, tensor<f32>)
// %result0: 0.0
// %result1: 1.0
veya
Anlambilim
İki tensör lhs
ve rhs
için öğe bazında VEYA gerçekleştirir ve bir result
tensörü üretir. Öğe türüne bağlı olarak aşağıdakileri yapar:
- Boole için: mantıksal OR.
- Tam sayılar için: bit tabanlı VEYA.
Girişler
Etiket | Ad | Tür | Sınırlamalar |
---|---|---|---|
(I1) | lhs |
tam sayı veya boole türünde tensör | (C1) |
(I2) | rhs |
tam sayı veya boole türünde tensör | (C1) |
Çıkışlar
Ad | Tür | Sınırlamalar |
---|---|---|
result |
tam sayı veya boole türünde tensör | (C1) |
Sınırlamalar
- (C1)
type(lhs) = type(rhs) = type(result)
.
Örnekler
// Bitwise operation with with integer tensors
// %lhs: [[1, 2], [3, 4]]
// %rhs: [[5, 6], [7, 8]]
%result = "stablehlo.or"(%lhs, %rhs) : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
// %result: [[5, 6], [7, 12]]
// Logical operation with with boolean tensors
// %lhs: [[false, false], [true, true]]
// %rhs: [[false, true], [false, true]]
%result = "stablehlo.or"(%lhs, %rhs) : (tensor<2x2xi1>, tensor<2x2xi1>) -> tensor<2x2xi1>
// %result: [[false, true], [true, true]]
feed dışı
Anlambilim
Özet akışına inputs
yazar ve result
jetonu oluşturur.
outfeed_config
semantiği, uygulama tarafından tanımlanmıştır.
Girişler
Etiket | Ad | Tür |
---|---|---|
(I1) | inputs |
değişken sayıda tensör veya sayısal tensör |
(I2) | token |
token |
(I3) | outfeed_config |
string türünün sabiti |
Çıkışlar
Ad | Tür |
---|---|
result |
token |
Örnekler
%result = "stablehlo.outfeed"(%inputs0, %token) {
outfeed_config = ""
} : (tensor<2x2x2xi64>, !stablehlo.token) -> !stablehlo.token
tampon
Anlambilim
operand
öğesini, belirtilen padding_value
ile tensörün elemanlarının arasına ve tensörün etrafına dolgu yaparak genişletir.
edge_padding_low
ve edge_padding_high
, her bir boyutun sırasıyla alt uç (dizin 0'ın yanında) ve üst uç noktasına (en yüksek dizinin yanında) eklenen dolgu miktarını belirtir. Dolgu miktarı negatif olabilir. Negatif dolgunun mutlak değeri, belirtilen boyuttan kaldırılacak öğe sayısını gösterir.
interior_padding
, her bir boyuttaki herhangi iki öğe arasına negatif olmayabilecek dolgu miktarını belirtir. İç dolgu, kenar dolgusundan önce gerçekleşir. Böylece negatif kenar dolgusu, iç dolgulu işlenenden öğeleri kaldırır.
Daha resmi bir şekilde, result[result_index]
şu şekilde tanımlanır:
result_index = edge_padding_low + operand_index * (interior_padding + 1)
iseoperand[operand_index]
.- Aksi takdirde
padding_value
.
Girişler
Etiket | Ad | Tür | Sınırlamalar |
---|---|---|---|
(I1) | operand |
tensör veya tensör başına ölçülebilir tensör | (C1), (C2), (C4) |
(I2) | padding_value |
0 boyutlu tensör veya tensör başına ölçülmüş tensör | (C1) |
(I3) | edge_padding_low |
si64 türünde 1 boyutlu tensör sabiti |
(C1), (C4) |
(I4) | edge_padding_high |
si64 türünde 1 boyutlu tensör sabiti |
(C1), (C4) |
(I5) | interior_padding |
si64 türünde 1 boyutlu tensör sabiti |
(C2-C4) |
Çıkışlar
Ad | Tür | Sınırlamalar |
---|---|---|
result |
tensör veya tensör başına ölçülebilir tensör | (C3-C6) |
Sınırlamalar
- (C1)
element_type(operand) = element_type(padding_value) = element_type(result)
. - (C2)
size(edge_padding_low) = size(edge_padding_high) = size(interior_padding) = rank(operand)
. - (C3)
0 <= interior_padding
. - (C4)
shape(result) = shape(operand) + edge_padding_low + max(shape(operand) - 1, 0) * interior_padding + edge_padding_high
.
Örnekler
// %operand: [
// [1, 2, 3],
// [4, 5, 6]
// ]
// %padding_value: 0
%result = "stablehlo.pad"(%operand, %padding_value) {
edge_padding_low = dense<[0, 1]> : tensor<2xi64>,
edge_padding_high = dense<[2, 1]> : tensor<2xi64>,
interior_padding = dense<[1, 2]> : tensor<2xi64>
} : (tensor<2x3xi32>, tensor<i32>) -> tensor<5x9xi32>
// %result: [
// [0, 1, 0, 0, 2, 0, 0, 3, 0],
// [0, 0, 0, 0, 0, 0, 0, 0, 0],
// [0, 4, 0, 0, 5, 0, 0, 6, 0],
// [0, 0, 0, 0, 0, 0, 0, 0, 0],
// [0, 0, 0, 0, 0, 0, 0, 0, 0]
// ]
partition_id
Anlambilim
Mevcut işlemin partition_id
kadarını oluşturur.
Çıkışlar
Ad | Tür |
---|---|
result |
ui32 türünde 0 boyutlu tensör |
Örnekler
%result = "stablehlo.partition_id"() : () -> tensor<ui32>
Popcnt
Anlambilim
operand
tensöründe ayarlanan bit sayısını öğe bazında hesaplar ve bir result
tensörü üretir.
Girişler
Etiket | Ad | Tür | Sınırlamalar |
---|---|---|---|
(I1) | operand |
tam sayı türünde tensör | (C1) |
Çıkışlar
Ad | Tür | Sınırlamalar |
---|---|---|
result |
tam sayı türünde tensör | (C1) |
Sınırlamalar
- (C1)
type(operand) = type(result)
.
Örnekler
// %operand: [0, 1, 2, 127]
%result = "stablehlo.popcnt"(%operand) : (tensor<4xi64>) -> tensor<4xi64>
// %result: [0, 1, 1, 7]
güç
Anlambilim
lhs
tensörünün öğe düzeyinde üssünü rhs
tensörü ile gerçekleştirir ve bir result
tensörü oluşturur. Öğe türüne bağlı olarak aşağıdakileri yapar:
- Tam sayılar için: tam sayıların üssünü alma.
- Kayan reklamlar için: IEEE-754'ten
pow
. - Karmaşık sayılar için: karmaşık üsleme.
- Nicel türler için:
dequantize_op_quantize(power, lhs, rhs, type(result))
.
Girişler
Etiket | Ad | Tür | Sınırlamalar |
---|---|---|---|
(I1) | lhs |
tam sayı, kayan nokta, karmaşık tür veya tensör başına ölçülmüş tensör tensörü | (C1) |
(I2) | rhs |
tam sayı, kayan nokta, karmaşık tür veya tensör başına ölçülmüş tensör tensörü | (C1) |
Çıkışlar
Ad | Tür | Sınırlamalar |
---|---|---|
result |
tam sayı, kayan nokta, karmaşık tür veya tensör başına ölçülmüş tensör tensörü | (C1) |
Sınırlamalar
- (C1)
baseline_type(operand) = baseline_type(result)
.
Örnekler
// %lhs: [-2.0, -0.0, -36.0, 5.0, 3.0, 10000.0]
// %rhs: [2.0, 2.0, 1.1, 2.0, -1.0, 10.0]
%result = "stablehlo.power"(%lhs, %rhs) : (tensor<6xf64>, tensor<6xf64>) -> tensor<6xf64>
// %result: [4.0, 0.0, -nan, 25.0, 0.333333343, inf]
gerçek
Anlambilim
operand
öğesinden öğe olarak gerçek kısmı çıkarır ve bir result
tensörü oluşturur. Daha resmi bir şekilde, her x
öğesi için:
real(x) = is_complex(x) ? real_part(x) : x
.
Girişler
Etiket | Ad | Tür | Sınırlamalar |
---|---|---|---|
(I1) | operand |
kayan nokta veya karmaşık türde tensör | (C1), (C2) |
Çıkışlar
Ad | Tür | Sınırlamalar |
---|---|---|
result |
kayan nokta türünün tensörü | (C1), (C2) |
Sınırlamalar
- (C1)
shape(result) = shape(operand)
. - (C2)
element_type(result)
şu şekilde tanımlanır:is_complex(operand)
isecomplex_element_type(element_type(operand))
.- Aksi takdirde
element_type(operand)
.
Örnekler
// %operand: [(1.0, 2.0), (3.0, 4.0)]
%result = "stablehlo.real"(%operand) : (tensor<2xcomplex<f32>>) -> tensor<2xf32>
// %result: [1.0, 3.0]
Recv
Anlambilim
channel_id
içeren bir kanaldan veri alır ve results
üretir.
is_host_transfer
değeri true
ise işlem, verileri ana makineden aktarır. Aksi takdirde, veriler başka bir cihazdan aktarılır. Bu, uygulama tanımlı
anlamına gelir. Bu işaret, channel_type
içinde sağlanan bilgilerin aynısını oluşturur. Dolayısıyla gelecekte bu işaretlerden yalnızca birini tutmayı planlıyoruz
(#666).
results
, başta gelen yük değerleri ve en son gelen jetondan oluşur. Gelecekte, netliği artırmak için yükü ve jetonu iki ayrı çıkışa bölmeyi planlıyoruz (#670).
Girişler
Etiket | Ad | Tür | Sınırlamalar |
---|---|---|---|
(I1) | token |
token |
(C4) |
(I2) | channel_id |
si64 türünün sabiti |
|
(I3) | channel_type |
DEVICE_TO_DEVICE ve HOST_TO_DEVICE sıralaması |
(C1) |
(I4) | is_host_transfer |
i1 türünün sabiti |
(C1) |
Çıkışlar
Ad | Tür | Sınırlamalar |
---|---|---|
results |
değişken sayıda tensör, ölçülmüş tensör veya jeton | (C2-C4) |
Sınırlamalar
- (C1)
channel_type
şu şekilde tanımlanır:is_host_transfer = true
iseHOST_TO_DEVICE
,- Aksi takdirde
DEVICE_TO_DEVICE
.
- (C2)
0 < size(results)
. - (C3)
is_empty(result[:-1])
veyais_tensor(type(results[:-1]))
. - (C4)
is_token(type(results[-1]))
.
Örnekler
%results0, %results1 = "stablehlo.recv"(%token) {
channel_handle = #stablehlo.channel_handle<handle = 1, type = 3>,
is_host_transfer = true
} : (!stablehlo.token) -> (tensor<2x2xi64>, !stablehlo.token)
reduce
Anlambilim
dimensions
boyunca inputs
ve init_values
için body
azaltma işlevi uygular ve results
tensörleri üretir.
Azaltma sırası uygulama tarafından tanımlanır. Bu nedenle, işlemin tüm uygulamalardaki tüm girişler için aynı sonuçları vermesini sağlamak için body
ve init_values
monoidi oluşturmalıdır. Ancak bu koşul birçok popüler indirim için geçerli değildir. Örneğin, body
için kayan nokta ekleme ve init_values
için sıfırın eklenmesi, aslında bir monoid oluşturmaz çünkü kayan nokta eklemesi ilişkisel değildir.
Daha resmîdir. results...[j0, ..., jR-1] = reduce(input_slices_converted)
. Burada:
input_slices = inputs...[j0, ..., :, ..., jR-1]
; burada:
,dimensions
konumuna eklenir.input_slices_converted = to_destination_type(input_slices..., type(func_inputs(body)[:len(func_inputs(body))//2])...)
.init_values_converted = to_destination_type(init_values..., type(func_inputs(body)[len(func_inputs(body))//2:])...)
.- Aşağıdaki ikili ağaçlarda
schedule
içinreduce(input_slices_converted) = exec(schedule)
değeri:exec(node) = body(exec(node.left), exec(node.right))
.exec(leaf) = leaf.value
.
schedule
, sıralı geçişi aşağıdakilerden oluşan, uygulama tanımlı bir tam ikili ağaçtır:index
olan artan sözlük sıralamasına göreindex_space(input_slices_converted)
içindeki tümindex
içininput_slices_converted...[index]
değerleri.- Uygulama tanımlı konumlarda, uygulama tanımlı bir miktarda
init_values_converted
ile serpiştirilir.
Girişler
Etiket | Ad | Tür | Sınırlamalar |
---|---|---|---|
(I1) | inputs |
değişken sayıda tensör veya tensör başına ölçülmüş tensör | (C1-C4), (C6), (C7) |
(I2) | init_values |
0 boyutlu tensörlerin değişken sayısı veya tensör başına ölçülmüş tensörlerin değişken sayısı | (C2), (C3) |
(I3) | dimensions |
si64 türünde 1 boyutlu tensör sabiti |
(C4), (C5), (C7) |
(I4) | body |
işlev | (C6) |
Çıkışlar
Ad | Tür | Sınırlamalar |
---|---|---|
results |
değişken sayıda tensör veya tensör başına ölçülmüş tensör | (C3), (C7), (C8) |
Sınırlamalar
- (C1)
same(shape(inputs...))
. - (C2)
element_type(inputs...) = element_type(init_values...)
. - (C3)
0 < size(inputs) = size(init_values) = size(results) = N
. - (C4)
0 <= dimensions < rank(inputs[0])
. - (C5)
is_unique(dimensions)
. - (C6)
body
,is_promotable(element_type(inputs[i]), Ei)
türünde(tensor<E0>, ..., tensor<EN-1>, tensor<E0>, ...,
tensor<EN-1>) -> (tensor<E0>, ..., tensor<EN-1>)
türüne sahiptir. - (C7)
shape(results...) = shape(inputs...)
dışında,dimensions
öğesine karşılık geleninputs...
boyut boyutları dahil edilmez. - (C8)
[0,N)
kapsamındaki tümi
içinelement_type(results[i]) = Ei
.
Örnekler
// %input = [[0, 1, 2, 3, 4, 5]]
// %init_value = 0
%result = "stablehlo.reduce"(%input, %init_value) ({
^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
%0 = "stablehlo.add"(%arg0, %arg1) : (tensor<i64>, tensor<i64>) -> tensor<i64>
"stablehlo.return"(%0) : (tensor<i64>) -> ()
}) {
dimensions = dense<1> : tensor<1xi64>
} : (tensor<1x6xi64>, tensor<i64>) -> tensor<1xi64>
// %result = [15]
reduce_precision
Anlambilim
operand
değerini, exponent_bits
ve mantissa_bits
kullanan başka bir kayan nokta türüne ve tekrar orijinal kayan nokta türüne geri döndürerek output
tensörü oluşturur.
Daha resmî olarak:
- Orijinal değerin mantissa bitleri, orijinal değeri
roundToIntegralTiesToEven
semantiği kullanılarakmantissa_bits
ile temsil edilen en yakın değere yuvarlanacak şekilde güncellenir. - Daha sonra,
mantissa_bits
orijinal değerin mantissa bit sayısından küçükse mantissa bitlerimantissa_bits
değerine kısaltılır. - Daha sonra, ara sonucun üs bitleri
exponent_bits
tarafından sağlanan aralığa sığmazsa ara sonuç, orijinal işareti kullanarak sonsuza kadar taşar veya orijinal işareti kullanarak sıfıra geçer. - Miktarı ölçülmüş türler için
dequantize_op_quantize( lambda operand: reduce_precision(operand, exponent_bits, mantissa_bits), operand, type(result))
performans gösterir.
Girişler
Etiket | Ad | Tür | Sınırlamalar |
---|---|---|---|
(I1) | operand |
kayan nokta türü tensörü veya tensör başına ölçülebilir tensör | (C1) |
(I2) | exponent_bits |
si32 türünün sabiti |
(C2) |
(I3) | mantissa_bits |
si32 türünün sabiti |
(C3) |
Çıkışlar
Ad | Tür | Sınırlamalar |
---|---|---|
output |
kayan nokta türü tensörü veya tensör başına ölçülebilir tensör | (C1) |
Sınırlamalar
- (C1)
baseline_type(operand) = baseline_type(output)
. - (C2)
1 <= exponent_bits
. - (C3)
0 <= mantissa_bits
.
Örnekler
// Logical values: +Inf, NaN, +Denormal, 0.0, 65519.0, 65520.0
// %operand: [0x7FF0000000000000, 0x7FFFFFFFFFFFFFFF, 0x0000000000000001, 0.0, 65519.0, 65520.0]
%output = "stablehlo.reduce_precision"(%operand) {
exponent_bits = 5 : i32,
mantissa_bits = 10 : i32
} : (tensor<6xf64>) -> tensor<6xf64>
// Logical values: +Inf, NaN, 0.0, 0.0, 65504.0, +Inf
// %output: [0x7FF0000000000000, 0x7FFFFFFFFFFFFFFF, 0.0, 0.0, 65504.0, 0x7FF0000000000000]
reduce_scatter
Anlambilim
StableHLO işlem tablosundaki her bir işlem grubunda, her bir işlemdeki operand
tensörünün değerleri üzerinde computations
kullanarak azaltma işlemi gerçekleştirir, azaltma sonucunu scatter_dimension
boyunca parçalara ayırır ve result
sonucunu üretmek için ayırma parçalarını işlemler arasında dağıtır.
İşlem, StableHLO işlem tablosunu aşağıdaki şekilde tanımlanan process_groups
olarak böler:
channel_id <= 0 and use_global_device_ids = false
isecross_replica(replica_groups)
.channel_id > 0 and use_global_device_ids = false
isecross_replica_and_partition(replica_groups)
.channel_id > 0 and use_global_device_ids = true
iseflattened_ids(replica_groups)
.
Daha sonra, her process_group
içinde:
reduced_value = all_reduce(operand, replica_groups, channel_id, use_global_device_ids, computation)
.parts@sender = split(reduced_value@sender, dim(process_groups, 1), scatter_dimension)
.receiver_index = process_group.index(receiver)
olacak şekildeprocess_group
bölgesindeki tümsender
içinresult@receiver = parts@sender[receiver_index]
.
Girişler
Etiket | Ad | Tür | Sınırlamalar |
---|---|---|---|
(I1) | operand |
tensör veya tensör başına ölçülebilir tensör | (C1), (C2), (C7), (C8) |
(I2) | scatter_dimension |
si64 türünün sabiti |
(C1), (C2), (C8) |
(I3) | replica_groups |
si64 türünde 2 boyutlu tensör sabiti |
(C3-C5) |
(I4) | channel_id |
si64 türünün sabiti |
(C6) |
(I5) | use_global_device_ids |
i1 türünün sabiti |
(C6) |
(I6) | computation |
işlev | (C7) |
Çıkışlar
Ad | Tür | Sınırlamalar |
---|---|---|
result |
tensör veya tensör başına ölçülebilir tensör | (C8-C9) |
Sınırlamalar
- (C1)
dim(operand, scatter_dimension) % dim(process_groups, 1) = 0
. - (C2)
0 <= scatter_dimension < rank(operand)
. - (C3)
is_unique(replica_groups)
. - (C4)
size(replica_groups)
şu şekilde tanımlanır:cross_replica
kullanılıyorsanum_replicas
.cross_replica_and_partition
kullanılıyorsanum_replicas
.flattened_ids
kullanılıyorsanum_processes
.
- (C5)
0 <= replica_groups < size(replica_groups)
. - (C6)
use_global_device_ids = true
ise,channel_id > 0
. - (C7)
computation
,is_promotable(element_type(operand), E)
olacak şekilde(tensor<E>, tensor<E>) -> (tensor<E>)
türündedir. - (C8)
shape(result) = shape(operand)
, şunlar hariç:dim(result, scatter_dimension) = dim(operand, scatter_dimension) / dim(process_groups, 1)
.
- (C9)
element_type(result) = E
.
Örnekler
// num_replicas: 2
// num_partitions: 1
// %operand@(0, 0): [[1, 2, 3, 4],
// [5, 6, 7, 8]]
// %operand@(1, 0): [[9, 10, 11, 12],
// [13, 14, 15, 16]]
%result = "stablehlo.reduce_scatter"(%operand) ({
^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
%0 = "stablehlo.add"(%arg0, %arg1) : (tensor<i64>, tensor<i64>) -> tensor<i64>
"stablehlo.return"(%0) : (tensor<i64>) -> ()
}) {
scatter_dimension = 1 : i64,
replica_groups = dense<[[0, 1]]> : tensor<1x2xi64>,
channel_handle = #stablehlo.channel_handle<handle = 0, type = 0>
} : (tensor<2x4xi64>) -> tensor<2x2xi64>
//
// %result@(0, 0): [[10, 12],
// [18, 20]]
// %result@(1, 0): [[14, 16],
// [22, 24]]
reduce_window
Anlambilim
inputs
ve init_values
pencerelerine body
azaltma işlevi uygular ve results
sonucunu verir.
Aşağıdaki şemada, results...
içindeki öğelerin inputs...
üzerinden nasıl hesaplandığı somut bir örnek kullanılarak gösterilmektedir.
Daha resmi bir şekilde ifade etmek gerekirse:
results...[result_index] = reduce(windows, init_values, axes(inputs...), body)
(reduce bölümünü inceleyin):
padded_inputs = pad(inputs..., init_values..., padding[:, 0], padding[:, 1], base_dilations - 1)
.window_start = result_index * window_strides
.window_end = window_start + (window_dimensions - 1) * window_dilations + 1
.windows = slice(padded_inputs..., window_start, window_end, window_dilations)
.
Girişler
Etiket | Ad | Tür | Sınırlamalar |
---|---|---|---|
(I1) | inputs |
değişken sayıda tensör veya tensör başına ölçülmüş tensör | (C1-C4), (C6), (C8), (C10), (C12), (C13), (C15) |
(I2) | init_values |
0 boyutlu tensörlerin değişken sayısı veya tensör başına ölçülmüş tensörlerin değişken sayısı | (C1), (C13) |
(I3) | window_dimensions |
si64 türünde 1 boyutlu tensör sabiti |
(C4), (C5), (C15) |
(I4) | window_strides |
si64 türünde 1 boyutlu tensör sabiti |
(C6), (C7), (C15) |
(I5) | base_dilations |
si64 türünde 1 boyutlu tensör sabiti |
(C8), (C9), (C15) |
(I6) | window_dilations |
si64 türünde 1 boyutlu tensör sabiti |
(C10), (C11), (C15) |
(I7) | padding |
si64 türünde 2 boyutlu tensör sabiti |
(C12), (C15) |
(I8) | body |
işlev | (C13) |
Çıkışlar
Ad | Tür | Sınırlamalar |
---|---|---|
results |
değişken sayıda tensör veya tensör başına ölçülmüş tensör | (C1), (C14-C16) |
Sınırlamalar
- (C1)
0 < size(inputs) = size(init_values) = size(results) = N
. - (C2)
same(shape(inputs...))
. - (C3)
element_type(inputs...) = element_type(init_values...)
. - (C4)
size(window_dimensions) = rank(inputs[0])
. - (C5)
0 < window_dimensions
. - (C6)
size(window_strides) = rank(inputs[0])
. - (C7)
0 < window_strides
. - (C8)
size(base_dilations) = rank(inputs[0])
. - (C9)
0 < base_dilations
. - (C10)
size(window_dilations) = rank(inputs[0])
. - (C11)
0 < window_dilations
. - (C12)
shape(padding) = [rank(inputs[0]), 2]
. - (C13)
body
,is_promotable(element_type(inputs[i]), Ei)
türünde(tensor<E0>, ..., tensor<EN-1>, tensor<E0>, ...,
tensor<EN-1>) -> (tensor<E0>, ..., tensor<EN-1>)
türüne sahiptir. - (C14)
same(shape(results...))
. - (C15)
shape(results[0]) = num_windows
. Burada:dilated_input_shape = shape(inputs[0]) = 0 ? 0 : (shape(inputs[0]) - 1) * base_dilations + 1
.padded_input_shape = padding[:, 0] + dilated_input_shape + padding[:, 1]
.dilated_window_shape = (window_dimensions - 1) * window_dilations + 1
.is_empty_window = padded_input_shape = 0 || dilated_window_shape > padded_input_shape
.num_windows = is_empty_window ? 0 : floor((padded_input_shape - dilated_window_shape) / window_strides) + 1
.
- (C16)
[0,N)
kapsamındaki tümi
içinelement_type(results[i]) = Ei
.
Örnekler
// %input = [[1, 2], [3, 4], [5, 6]]
// %init_value = 0
%result = "stablehlo.reduce_window"(%input, %init_value) ({
^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
%0 = "stablehlo.add"(%arg0, %arg1) : (tensor<i64>, tensor<i64>) -> tensor<i64>
"stablehlo.return"(%0) : (tensor<i64>) -> ()
}) {
window_dimensions = dense<[2, 1]> : tensor<2xi64>,
window_strides = dense<[4, 1]> : tensor<2xi64>,
base_dilations = dense<[2, 1]> : tensor<2xi64>,
window_dilations = dense<[3, 1]> : tensor<2xi64>,
padding = dense<[[2, 1], [0, 0]]> : tensor<2x2xi64>
} : (tensor<3x2xi64>, tensor<i64>) -> tensor<2x2xi64>
// %result = [[0, 0], [3, 4]]
kalan
Anlambilim
Bölen lhs
ve bölen rhs
tensörlerinin öğe bazında geri kalanını gerçekleştirir ve bir result
tensörü oluşturur.
Daha resmi olarak, sonucun işareti bölmeden alınır ve sonucun mutlak değeri her zaman bölenin mutlak değerinden küçük olur.
Kalanı lhs - d * rhs
olarak hesaplanır. Burada d
şu şekilde verilir:
- Tam sayılar için:
stablehlo.divide(lhs, rhs)
. - Kayan öğeler için: IEEE-754'ten gelen
division(lhs, rhs)
yuvarlama özelliğiyleroundTowardZero
. - Karmaşık sayılar için: TBD (#997).
- Nicel türler için:
dequantize_op_quantize(remainder, lhs, rhs, type(result))
.
Kayan nokta öğesi türleri için bu işlem, IEEE-754 spesifikasyonundaki remainder
işlemiyle zıtlık gösterir. Burada d
, çift eşittir ve lhs/rhs
değerinin tam değerine en yakın integral değerdir.
Girişler
Etiket | Ad | Tür | Sınırlamalar |
---|---|---|---|
(I1) | lhs |
tam sayı, kayan nokta veya karmaşık tür tensörü ya da tensör başına ölçülmüş tensör | (C1) |
(I2) | rhs |
tam sayı, kayan nokta veya karmaşık tür tensörü ya da tensör başına ölçülmüş tensör | (C1) |
Çıkışlar
Ad | Tür | Sınırlamalar |
---|---|---|
result |
tam sayı, kayan nokta veya karmaşık tür tensörü ya da tensör başına ölçülmüş tensör | (C1) |
Sınırlamalar
- (C1)
baseline_type(operand) = baseline_type(result)
.
Örnekler
// %lhs: [17, -17, 17, -17]
// %rhs: [3, 3, -3, -3]
%result = "stablehlo.remainder"(%lhs, %rhs) : (tensor<4xi64>, tensor<4xi64>) -> tensor<4xi64>
// %result: [2, -2, 2, -2]
replica_id
Anlambilim
Mevcut işlemin replica_id
kadarını oluşturur.
Çıkışlar
Ad | Tür |
---|---|
result |
ui32 türünde 0 boyutlu tensör |
Örnekler
%result = "stablehlo.replica_id"() : () -> tensor<ui32>
yeniden şekillendirmek
Anlambilim
operand
tensörünü result
tensöre göre yeniden şekillendirir. Kavram olarak bu, aynı standart gösterimi korurken şeklin potansiyel olarak değiştirilmesi anlamına gelir (ör. tensor<2x3xf32>
yerine tensor<3x2xf32>
veya tensor<6xf32>
).
Daha resmi bir şekilde ifade etmek gerekirse result[result_index] = operand[operand_index]
; burada result_index
ve operand_index
, index_space(result)
ve index_space(operand)
sözlük sıralamasında aynı konuma sahiptir.
Girişler
Etiket | Ad | Tür | Sınırlamalar |
---|---|---|---|
(I1) | operand |
tensör veya ölçülmüş tensör | (C1-C3) |
Çıkışlar
Ad | Tür | Sınırlamalar |
---|---|---|
result |
tensör veya ölçülmüş tensör | (C1-C3) |
Sınırlamalar
- (C1)
element_type(result)
değerini veren:!is_per_axis_quantized(operand)
iseelement_type(operand)
.- Aksi takdirde
quantization_dimension(operand)
vequantization_dimension(result)
farklı olabilir. Bunun dışındaelement_type(operand)
.
- (C2)
size(operand) = size(result)
. - (C3)
is_per_axis_quantized(operand)
ise:reduce(dims(operand, [0, 1, ..., quantization_dimension(operand) - 1]), init_values=1, dimensions=[0], body=lambda x, y: x * y) = reduce(dims(result, [0, 1, ..., quantization_dimension(result) - 1]), init_values=1, dimensions=[0], body=lambda x, y: x * y)
.dim(operand, quantization_dimension(operand)) = dim(result, quantization_dimension(result))
.reduce(dims(operand, [quantization_dimension(operand) + 1, ..., rank(operand) - 1]), init_values=1, dimensions=[0], body=lambda x, y: x * y) = reduce(dims(result, [quantization_dimension(result) + 1, ..., rank(result) - 1]), init_values=1, dimensions=[0], body=lambda x, y: x * y)
.
Örnekler
// %operand: [[1, 2, 3], [4, 5, 6]]]
%result = "stablehlo.reshape"(%operand) : (tensor<2x3xi32>) -> tensor<3x2xi32>
// %result: [[1, 2], [3, 4], [5, 6]]
geri al
Anlambilim
Belirtilen dimensions
boyunca operand
içindeki öğelerin sırasını tersine çevirir ve bir result
tensörü oluşturur. Daha resmi şekilde,
result[result_index] = operand[operand_index]
burada:
dimensions
içinded
iseoperand_index[d] = dim(result, d) - result_index[d] - 1
.- Aksi takdirde
operand_index[d] = result_index[d]
.
Girişler
Etiket | Ad | Tür | Sınırlamalar |
---|---|---|---|
(I1) | operand |
tensör veya tensör başına ölçülebilir tensör | (C1), (C3) |
(I2) | dimensions |
si64 türünde 1 boyutlu tensör sabiti |
(C2), (C3) |
Çıkışlar
Ad | Tür | Sınırlamalar |
---|---|---|
result |
tensör veya tensör başına ölçülebilir tensör | (C1), (C3) |
Sınırlamalar
- (C1)
type(operand) = type(result)
. - (C2)
is_unique(dimensions)
. - (C3)
0 <= dimensions < rank(result)
.
Örnekler
// %operand = [[1, 2], [3, 4], [5, 6]]
%result = "stablehlo.reverse"(%operand) {
dimensions = dense<1> : tensor<1xi64>
} : (tensor<3x2xi32>) -> tensor<3x2xi32>
// %result: [[2, 1], [4, 3], [6, 5]]
rng
Anlambilim
rng_distribution
algoritmasını kullanarak rastgele sayılar üretir ve belirli bir shape
şekli için result
tensörü üretir.
rng_distribution = UNIFORM
ise [a, b)
aralığı boyunca tek tip dağılım izlenerek rastgele sayılar oluşturulur. a >= b
ise davranış tanımsızdır.
rng_distribution = NORMAL
ise rastgele sayılar, ortalama = a
ve standart sapma = b
şeklinde normal dağılım izlenerek oluşturulur.
b < 0
ise davranış tanımsızdır.
Rastgele sayıların tam olarak nasıl oluşturulduğu uygulama tarafından tanımlanır. Örneğin, deterministik olabilir veya olmayabilir ve gizli durumu kullanabilir ya da kullanmayabilirler.
Birçok paydaşla yapılan görüşmelerde bu seçenek etkili bir şekilde kullanımdan kaldırıldı, bu nedenle gelecekte bu işlevi kaldırmayı (#597) planlamayı planlıyoruz.
Girişler
Etiket | Ad | Tür | Sınırlamalar |
---|---|---|---|
(I1) | a |
Tam sayı, boole veya kayan nokta türünün 0 boyutlu tensörü | (C1), (C2) |
(I2) | b |
Tam sayı, boole veya kayan nokta türünün 0 boyutlu tensörü | (C1), (C2) |
(I3) | shape |
si64 türünde 1 boyutlu tensör sabiti |
(C3) |
(I4) | rng_distribution |
UNIFORM ve NORMAL sıralaması |
(C2) |
Çıkışlar
Ad | Tür | Sınırlamalar |
---|---|---|
result |
tam sayı, boole veya kayan nokta türünde tensör | (C1-C3) |
Sınırlamalar
- (C1)
element_type(a) = element_type(b) = element_type(result)
. - (C2)
rng_distribution = NORMAL
iseis_float(a)
. - (C3)
shape(result) = shape
.
Örnekler
// %a = 0
// %b = 2
// %shape = [3, 3]
%result = "stablehlo.rng"(%a, %b, %shape) {
rng_distribution = #stablehlo<rng_distribution UNIFORM>
} : (tensor<i32>, tensor<i32>, tensor<2xi64>) -> tensor<3x3xi32>
// %result: [
// [1, 0, 1],
// [1, 1, 1],
// [0, 0, 0]
// ]
rng_bit_generator
Anlambilim
Başlangıç durumu initial_state
olan sahte rastgele sayı oluşturma algoritmasını kullanarak rng_algorithm
tek tip rastgele bitlerle dolu bir output
ve güncellenmiş çıkış durumu output_state
döndürür. Sonucun initial_state
için belirleyici işlev olacağı garanti edilir, ancak uygulamalar arasında belirleyici olacağı garanti edilmez.
rng_algorithm
şunlardan biridir:
DEFAULT
: Uygulama tanımlı algoritma.THREE_FRY
: Threefry algoritmasının uygulama tanımlı varyantı.*PHILOX
: Philox algoritmasının uygulama tanımlı varyantı.*
Girişler
Etiket | Ad | Tür | Sınırlamalar |
---|---|---|---|
(I1) | rng_algorithm |
DEFAULT , THREE_FRY ve PHILOX sıralaması |
(C2) |
(I2) | initial_state |
ui64 türünde 1 boyutlu tensör |
(C1), (C2) |
Çıkışlar
Ad | Tür | Sınırlamalar |
---|---|---|
output_state |
ui64 türünde 1 boyutlu tensör |
(C1) |
output |
tam sayı veya kayan nokta türü tensörü |
Sınırlamalar
- (C1)
type(initial_state) = type(output_state)
. - (C2)
size(initial_state)
şu şekilde tanımlanır:rng_algorithm = DEFAULT
ise uygulama tanımlıdır.rng_algorithm = THREE_FRY
ise2
.rng_algorithm = PHILOX
ise2
veya3
.
Örnekler
// %initial_state: [1, 2]
%output_state, %output = "stablehlo.rng_bit_generator"(%initial_state) {
rng_algorithm = #stablehlo<rng_algorithm THREE_FRY>
} : (tensor<2xui64>) -> (tensor<2xui64>, tensor<2x2xui64>)
// %output_state: [1, 6]
// %output: [
// [9236835810183407956, 16087790271692313299],
// [18212823393184779219, 2658481902456610144]
// ]
round_nearest_afz
Anlambilim
operand
tensöründe, en yakın tam sayıya doğru öğe bazında yuvarlama gerçekleştirir ve bağları sıfırdan uzaklaştırarak result
tensörü oluşturur. IEEE-754 spesifikasyonundan roundToIntegralTiesToAway
işlemini uygular. Miktarı ölçülmüş türler için dequantize_op_quantize(round_nearest_afz, operand, type(result))
performans gösterir.
Girişler
Etiket | Ad | Tür | Sınırlamalar |
---|---|---|---|
(I1) | operand |
kayan nokta türü tensörü veya tensör başına ölçülebilir tensör | (C1) |
Çıkışlar
Ad | Tür | Sınırlamalar |
---|---|---|
result |
kayan nokta türü tensörü veya tensör başına ölçülebilir tensör | (C1) |
Sınırlamalar
- (C1)
baseline_type(operand) = baseline_type(result)
.
Örnekler
// %operand = [-2.5, 0.4, 0.5, 0.6, 2.5]
%result = "stablehlo.round_nearest_afz"(%operand) : (tensor<5xf64>) -> tensor<5xf64>
// %result: [-3.0, 0.0, 1.0, 1.0, 3.0]
round_nearest_even
Anlambilim
operand
tensöründe, çift tam sayıya doğru olan bağları kopararak öğe bazında yuvarlama gerçekleştirir ve bir result
tensörü oluşturur. IEEE-754 spesifikasyonundan roundToIntegralTiesToEven
işlemini uygular. Miktarı ölçülmüş türler için dequantize_op_quantize(round_nearest_even, operand, type(result))
performans gösterir.
Girişler
Etiket | Ad | Tür | Sınırlamalar |
---|---|---|---|
(I1) | operand |
kayan nokta türü tensörü veya tensör başına ölçülebilir tensör | (C1) |
Çıkışlar
Ad | Tür | Sınırlamalar |
---|---|---|
result |
kayan nokta türü tensörü veya tensör başına ölçülebilir tensör | (C1) |
Sınırlamalar
- (C1)
baseline_type(operand) = baseline_type(result)
.
Örnekler
// %operand = [-2.5, 0.4, 0.5, 0.6, 2.5]
%result = "stablehlo.round_nearest_even"(%operand) : (tensor<5xf64>) -> tensor<5xf64>
// %result: [-2.0, 0.0, 0.0, 1.0, 2.0]
rsqrt
Anlambilim
operand
tensörü üzerinde öğe bazında ters karekök işlemi gerçekleştirir ve bir result
tensörü oluşturur. Öğe türüne bağlı olarak aşağıdakileri yapar:
- Kayan reklamlar için: IEEE-754'ten
rSqrt
. - Karmaşık sayılar için: karmaşık ters karekök.
- Nicel türler için:
dequantize_op_quantize(rsqrt, operand, type(result))
.
Girişler
Etiket | Ad | Tür | Sınırlamalar |
---|---|---|---|
(I1) | operand |
kayan nokta veya karmaşık tür ya da tensör başına ölçülmüş tensör tensörü | (C1) |
Çıkışlar
Ad | Tür | Sınırlamalar |
---|---|---|
result |
kayan nokta veya karmaşık tür ya da tensör başına ölçülmüş tensör tensörü | (C1) |
Sınırlamalar
- (C1)
baseline_type(operand) = baseline_type(result)
.
Örnekler
// %operand: [[1.0, 4.0], [9.0, 25.0]]
%result = "stablehlo.rsqrt"(%operand) : (tensor<2x2xf32>) -> tensor<2x2xf32>
// %result: [[1.0, 0.5], [0.33333343, 0.2]]
scatter
Anlambilim
scatter_indices
tarafından belirtilen birkaç dilim, update_computation
kullanılarak updates
değerleriyle güncellenir. Bunun dışında inputs
tensörlerine eşit olan results
tensörleri oluşturur.
Aşağıdaki şemada, updates...
içindeki öğelerin results...
içindeki öğelerle nasıl eşleştiği somut bir örnek üzerinden gösterilmektedir. Şema birkaç örnek updates...
dizini seçer ve bunların hangi results...
dizinleriyle ilişkili olduğunu ayrıntılı olarak açıklar.
Daha resmi şekilde (index_space(updates[0])
içindeki tüm update_index
için):
update_scatter_dims = [d for d in axes(updates[0]) and d not in update_window_dims]
.update_scatter_index = update_index[update_scatter_dims...]
.start_index
şu şekilde tanımlanır:index_vector_dim
<rank(scatter_indices)
isesi
öğesininupdate_scatter_index
nesnesindeki bağımsız öğeler olduğu ve:
öğesininindex_vector_dim
dizinine eklendiğiscatter_indices[si0, ..., :, ..., siN]
.- Aksi takdirde
[scatter_indices[update_scatter_index]]
.
axes(inputs[0])
bölgesinded_input
içind_input = scatter_dims_to_operand_dims[d_start]
isefull_start_index[d_input] = start_index[d_start]
.- Aksi takdirde
full_start_index[d_input] = 0
.
update_window_index = update_index[update_window_dims...]
.wi
öğesininupdate_window_index
içindeki bağımsız öğeler olduğu ve0
değerinininserted_window_dims
içindeki dizinlere eklendiğifull_window_index = [wi0, ..., 0, ..., wiN]
.result_index = full_start_index + full_window_index
.
Bu bilgilere dayanarak results = exec(schedule, inputs)
:
schedule
,index_space(updates[0])
için uygulama tanımlı bir permütasyondur.exec([update_index, ...], results) = exec([...], updated_results)
. Bu durumda:result_index
,shape(results...)
sınırları içindeyseupdates_converted = to_destination_type( updates...[update_index], type(func_inputs(update_computation) [len(func_inputs(update_computation))//2:])... )
updated_values = update_computation(results...[result_index], updates_converted)
updated_results
,results...[result_index]
updated_values...
olarak ayarlanmışresults
kopyasıdır.- Aksi halde
updated_results = results
.
exec([], results) = results
.
indices_are_sorted
true
ise uygulama, scatter_indices
öğesinin scatter_dims_to_operand_dims
göre sıralandığı varsayılabilir, aksi takdirde davranış tanımsız kalır. Daha resmi şekilde ifade etmek gerekirse indices(result)
tarihinden itibaren tüm i1 < i2
için full_start_index(i1)
<= full_start_index(i2)
.
unique_indices
değeri true
ise uygulama, dağıtılmış tüm result_index
dizinlerinin benzersiz olduğunu varsayabilir. unique_indices
değeri true
ise ancak dağılan dizinler benzersiz değilse davranış tanımsız olur.
Girişler
Etiket | Ad | Tür | Sınırlamalar |
---|---|---|---|
(I1) | inputs |
değişken sayıda tensör veya tensör başına ölçülmüş tensör | (C1), (C2), (C4-C6), (C10), (C13), (C15-C16) |
(I2) | scatter_indices |
tam sayı türünde tensör | (C4), (C11), (C14) |
(I3) | updates |
değişken sayıda tensör veya tensör başına ölçülmüş tensör | (C3-C6), (C8) |
(I4) | update_window_dims |
si64 türünde 1 boyutlu tensör sabiti |
(C2), (C4), (C7), (C8) |
(I5) | inserted_window_dims |
si64 türünde 1 boyutlu tensör sabiti |
(C2), (C4), (C9), (C10) |
(I6) | scatter_dims_to_operand_dims |
si64 türünde 1 boyutlu tensör sabiti |
(C11-C13) |
(I7) | index_vector_dim |
si64 türünün sabiti |
(C4), (C11), (C14) |
(I8) | indices_are_sorted |
i1 türünün sabiti |
|
(I9) | unique_indices |
i1 türünün sabiti |
|
(I10) | update_computation |
işlev | (C15) |
Çıkışlar
Ad | Tür | Sınırlamalar |
---|---|---|
results |
değişken sayıda tensör veya tensör başına ölçülmüş tensör | (C15-C17) |
Sınırlamalar
- (C1)
same(shape(inputs...))
. - (C2)
rank(inputs[0]) = size(update_window_dims) + size(inserted_window_dims)
. - (C3)
same(shape(updates...))
. - (C4)
shape(updates[0]) = combine(update_scatter_dim_sizes, update_window_dim_sizes)
; burada:update_scatter_dim_sizes = shape(scatter_indices)
dışında,index_vector_dim
değerine karşılık gelenscatter_indices
boyut boyutu dahil edilmez.update_window_dim_sizes <= shape(inputs[0])
dışında,inputs[0]
içindeinserted_window_dims
öğesine karşılık gelen boyut boyutları dahil edilmez.combine
,update_scatter_dim_sizes
öğesiniupdate_window_dims
öğesine karşılık gelen eksenlerdeupdate_scatter_dims
veupdate_window_dim_sizes
'ye karşılık gelen eksenlere yerleştirir.
- (C5)
0 < size(inputs) = size(updates) = N
. - (C6)
element_type(updates...) = element_type(inputs...)
. - (C7)
is_unique(update_window_dims) and is_sorted(update_window_dims)
. - (C8)
0 <= update_window_dims < rank(updates[0])
. - (C9)
is_unique(inserted_window_dims) and is_sorted(update_window_dims)
. - (C10)
0 <= inserted_window_dims < rank(inputs[0])
. - (C11)
size(scatter_dims_to_operand_dims) = index_vector_dim < rank(scatter_indices) ? dim(scatter_indices, index_vector_dim) : 1
. - (C12)
is_unique(scatter_dims_to_operand_dims)
. - (C13)
0 <= scatter_dims_to_operand_dims < rank(inputs[0])
. - (C14)
0 <= index_vector_dim <= rank(scatter_indices)
. - (C15)
update_computation
,(tensor<E0>, ..., tensor<EN-1>, tensor<E0>, ..., tensor<EN-1>) -> (tensor<E0>, ..., tensor<EN-1>)
türündedir. Buradais_promotable(element_type(inputs[i]), Ei)
bulunur. - (C16)
shape(inputs...) = shape(results...)
. - (C17)
[0,N)
kapsamındaki tümi
içinelement_type(results[i]) = Ei
.
Örnekler
// %input: [
// [[1, 2], [3, 4], [5, 6], [7, 8]],
// [[9, 10], [11, 12], [13, 14], [15, 16]],
// [[17, 18], [19, 20], [21, 22], [23, 24]]
// ]
// %scatter_indices: [[[0, 2], [1, 0], [2, 1]], [[0, 1], [1, 0], [0, 9]]]
// %update: [
// [[[1, 1], [1, 1]], [[1, 1], [1, 1]], [[1, 1], [1, 1]]],
// [[[1, 1], [1, 1]], [[1, 1], [1, 1]], [[1, 1], [1, 1]]]
// ]
%result = "stablehlo.scatter"(%input, %scatter_indices, %update) ({
^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
%0 = "stablehlo.add"(%arg0, %arg1) : (tensor<i64>, tensor<i64>) -> tensor<i64>
"stablehlo.return"(%0) : (tensor<i64>) -> ()
}) {
scatter_dimension_numbers = #stablehlo.scatter<
update_window_dims = [2, 3],
inserted_window_dims = [0],
scatter_dims_to_operand_dims = [1, 0],
index_vector_dim = 2>,
indices_are_sorted = false,
unique_indices = false
} : (tensor<3x4x2xi64>, tensor<2x3x2xi64>, tensor<2x3x2x2xi64>) -> tensor<3x4x2xi64>
// %result: [
// [[1, 2], [5, 6], [7, 8], [7, 8]],
// [[10, 11], [12, 13], [14, 15], [16, 17]],
// [[18, 19], [20, 21], [21, 22], [23, 24]]
// ]
seç
Anlambilim
Her öğenin, karşılık gelen pred
öğesinin değerine göre on_true
veya on_false
tensöründen seçildiği bir result
tensörü oluşturur.
Daha resmi şekilde, result[result_index] = pred_element ? on_true[result_index] :
on_false[result_index]
. Burada: pred_element = rank(pred) = 0 ? pred[] :
pred[result_index]
. Miktarı ölçülmüş türler için dequantize_select_quantize(pred, on_true, on_false, type(result))
performans gösterir.
Girişler
Etiket | Ad | Tür | Sınırlamalar |
---|---|---|---|
(I1) | pred |
i1 türünde tensör |
(C1) |
(I2) | on_true |
tensör veya tensör başına ölçülebilir tensör | (C1-C2) |
(I3) | on_false |
tensör veya tensör başına ölçülebilir tensör | (C2) |
Çıkışlar
Ad | Tür | Sınırlamalar |
---|---|---|
result |
tensör veya tensör başına ölçülebilir tensör | (C2) |
Sınırlamalar
- (C1)
rank(pred) = 0 or shape(pred) = shape(on_true)
. - (C2)
baseline_type(on_true) = baseline_type(on_false) = baseline_type(result)
.
Örnekler
// %pred: [[false, true], [true, false]]
// %on_true: [[1, 2], [3, 4]]
// %on_false: [[5, 6], [7, 8]]
%result = "stablehlo.select"(%pred, %on_true, %on_false) : (tensor<2x2xi1>, tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
// %result: [[5, 2], [3, 8]]
select_and_scatter
Anlambilim
select
kullanarak input
tensörünün reduce_window
sonucuna bağlı olarak scatter
kullanarak source
tensöründen gelen değerleri dağıtır ve bir result
tensörü oluşturur.
Aşağıdaki şemada, result
içindeki öğelerin operand
ve source
değerlerinden nasıl hesaplandığı gösterilmektedir.
Daha resmî olarak:
Şu girişlerle birlikte
selected_values = reduce_window_without_init(...)
:- `girişler = [işleme alınan].
window_dimensions
,window_strides
vepadding
.base_dilations = windows_dilations = 1
.body
şu şekilde tanımlanır:
def body(arg0: tensor<E>, arg1: tensor<E>) -> tensor<E>: return select(arg0, arg1) ? arg0 : arg1;
Bu durumda
E = element_type(operand)
vereduce_window_without_init
, tam olarakreduce_window
işlevine benzer şekilde çalışır. Bununla birlikte, temeldekireduce
schedule
öğesinin (reduce bölümünü inceleyin) başlangıç değerlerini içermez. İlgili pencerede değer yoksa ne olacağı şu an için belirsizdir (#731).result[result_index] = reduce([source_values], [init_value], [0], scatter)
Burada:source_values = [source[source_index] for source_index in source_indices]
.selected_values[source_index]
,operand_index
öğesindenoperand
öğesine sahipseselected_index(source_index) = operand_index
.source_indices = [source_index for source_index in indices(source) if selected_index(source_index) = result_index]
.
Girişler
Etiket | Ad | Tür | Sınırlamalar |
---|---|---|---|
(I1) | operand |
tensör veya tensör başına ölçülebilir tensör | (C1-C4), (C6), (C8-C11) |
(I2) | source |
tensör veya tensör başına ölçülebilir tensör | (C1), (C2) |
(I3) | init_value |
0 boyutlu tensör veya tensör başına ölçülmüş tensör | (C3) |
(I4) | window_dimensions |
si64 türünde 1 boyutlu tensör sabiti |
(C2), (C4), (C5) |
(I5) | window_strides |
si64 türünde 1 boyutlu tensör sabiti |
(C2), (C6), (C7) |
(I6) | padding |
si64 türünde 2 boyutlu tensör sabiti |
(C2), (C8) |
(I7) | select |
işlev | (C9) |
(I8) | scatter |
işlev | (C10) |
Çıkışlar
Ad | Tür | Sınırlamalar |
---|---|---|
result |
tensör veya tensör başına ölçülebilir tensör | (C11-C12) |
Sınırlamalar
- (C1)
element_type(operand) = element_type(source)
. - (C2)
shape(source) = num_windows
; burada:padded_operand_shape = padding[:, 0] + shape(operand) + padding[:, 1]
.is_empty_window = padded_operand_shape = 0 || window_dimensions > padded_operand_shape
.num_windows = is_empty_window ? 0 : floor((padded_operand_shape - window_dimensions) / window_strides) + 1
.
- (C3)
element_type(init_value) = element_type(operand)
. - (C4)
size(window_dimensions) = rank(operand)
. - (C5)
0 < window_dimensions
. - (C6)
size(window_strides) = rank(operand)
. - (C7)
0 < window_strides
. - (C8)
shape(padding) = [rank(operand), 2]
. - (C9)
select
,E = element_type(operand)
olacak şekilde(tensor<E>, tensor<E>) -> tensor<i1>
türündedir. - (C10)
scatter
,is_promotable(element_type(operand), E)
burada(tensor<E>, tensor<E>) -> tensor<E>
türündedir. - (C11)
shape(operand) = shape(result)
. - (C12)
element_type(result) = E
.
Örnekler
// %operand: [[1, 5], [2, 5], [3, 6], [4, 4]]
// %source: [[5, 6], [7, 8]]
// %init_value: 0
%result = "stablehlo.select_and_scatter"(%operand, %source, %init_value) ({
^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
%0 = "stablehlo.compare"(%arg0, %arg1) {
comparison_direction = #stablehlo<comparison_direction GE>
} : (tensor<i64>, tensor<i64>) -> tensor<i1>
"stablehlo.return"(%0) : (tensor<i1>) -> ()
}, {
^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
%0 = "stablehlo.add"(%arg0, %arg1) : (tensor<i64>, tensor<i64>) -> tensor<i64>
"stablehlo.return"(%0) : (tensor<i64>) -> ()
}) {
window_dimensions = dense<[3, 1]> : tensor<2xi64>,
window_strides = dense<[2, 1]> : tensor<2xi64>,
padding = dense<[[0, 1], [0, 0]]> : tensor<2x2xi64>
} : (tensor<4x2xi64>, tensor<2x2xi64>, tensor<i64>) -> tensor<4x2xi64>
// %result: [[0, 0], [0, 0], [5, 14], [7, 0]]
gönder
Anlambilim
channel_id
kanalına inputs
gönderir ve result
jetonu oluşturur.
is_host_transfer
değeri true
ise işlem, verileri ana makineye aktarır. Aksi takdirde veriler başka bir cihaza aktarılır. Bu, uygulama tanımlı
anlamına gelir. Bu işaret, channel_type
içinde sağlanan bilgilerin aynısını oluşturur. Dolayısıyla gelecekte bu işaretlerden yalnızca birini tutmayı planlıyoruz
(#666).
Girişler
Etiket | Ad | Tür | Sınırlamalar |
---|---|---|---|
(I1) | inputs |
değişken sayıda tensör veya sayısal tensör | |
(I2) | token |
token |
|
(I3) | channel_id |
si64 türünün sabiti |
|
(I4) | channel_type |
DEVICE_TO_DEVICE ve DEVICE_TO_HOST sıralaması |
(C1) |
(I5) | is_host_transfer |
i1 türünün sabiti |
(C1) |
Çıkışlar
Ad | Tür |
---|---|
result |
token |
Sınırlamalar
- (C1)
channel_type
şu şekilde tanımlanır:is_host_transfer = true
iseDEVICE_TO_HOST
,- Aksi takdirde
DEVICE_TO_DEVICE
.
Örnekler
%result = "stablehlo.send"(%operand, %token) {
channel_handle = #stablehlo.channel_handle<handle = 1, type = 2>,
is_host_transfer = true
} : (tensor<2x2xi64>, !stablehlo.token) -> !stablehlo.token
shift_left
Anlambilim
lhs
tensöründe rhs
bit sayısı kadar öğe düzeyinde sola kaydırma işlemi gerçekleştirir ve bir result
tensörü üretir.
Girişler
Etiket | Ad | Tür | Sınırlamalar |
---|---|---|---|
(I1) | lhs |
tam sayı türünde tensör | (C1) |
(I2) | rhs |
tam sayı türünde tensör | (C1) |
Çıkışlar
Ad | Tür | Sınırlamalar |
---|---|---|
result |
tam sayı türünde tensör | (C1) |
Sınırlamalar
- (C1)
type(lhs) = type(rhs) = type(result)
.
Örnekler
// %lhs: [-1, 0, 1]
// %rhs: [1, 2, 3]
%result = "stablehlo.shift_left"(%lhs, %rhs): (tensor<3xi64>, tensor<3xi64>) -> tensor<3xi64>
// %result: [-2, 0, 8]
shift_right_arithmetic
Anlambilim
lhs
tensöründe rhs
bit sayısı kadar öğe bazında aritmetik sağa kaydırma işlemi gerçekleştirir ve bir result
tensörü üretir.
Girişler
Etiket | Ad | Tür | Sınırlamalar |
---|---|---|---|
(I1) | lhs |
tam sayı türünde tensör | (C1) |
(I2) | rhs |
tam sayı türünde tensör | (C1) |
Çıkışlar
Ad | Tür | Sınırlamalar |
---|---|---|
result |
tam sayı türünde tensör | (C1) |
Sınırlamalar
- (C1)
type(lhs) = type(rhs) = type(result)
.
Örnekler
// %lhs: [-1, 0, 8]
// %rhs: [1, 2, 3]
%result = "stablehlo.shift_right_arithmetic"(%lhs, %rhs): (tensor<3xi64>, tensor<3xi64>) -> tensor<3xi64>
// %result: [-1, 0, 1]
shift_right_logical
Anlambilim
lhs
tensöründe rhs
bit sayısı kadar öğe bazında mantıksal sağa kaydırma işlemi gerçekleştirir ve bir result
tensörü üretir.
Girişler
Etiket | Ad | Tür | Sınırlamalar |
---|---|---|---|
(I1) | lhs |
tam sayı türünde tensör | (C1) |
(I2) | rhs |
tam sayı türünde tensör | (C1) |
Çıkışlar
Ad | Tür | Sınırlamalar |
---|---|---|
result |
tam sayı türünde tensör | (C1) |
Sınırlamalar
- (C1)
type(lhs) = type(rhs) = type(result)
.
Örnekler
// %lhs: [-1, 0, 8]
// %rhs: [1, 2, 3]
%result = "stablehlo.shift_right_logical"(%lhs, %rhs): (tensor<3xi64>, tensor<3xi64>) -> tensor<3xi64>
// %result: [9223372036854775807, 0, 1]
işaret
Anlambilim
Öğe düzeyinde operand
işaretini döndürür ve bir result
tensörü üretir.
Daha resmi bir şekilde, her x
öğesi için anlamlar Python söz dizimi kullanılarak aşağıdaki gibi ifade edilebilir:
def sign(x):
if is_integer(x):
if compare(x, 0, LT, SIGNED): return -1
if compare(x, 0, EQ, SIGNED): return 0
return 1
elif is_float(x):
if is_nan(x): return NaN
if compare(x, -0.0, EQ, FLOAT): return -0.0
if compare(x, +0.0, EQ, FLOAT): return +0.0
if compare(x, 0.0, LT, FLOAT): return -1.0
return 1.0
elif is_complex(x):
if is_nan(real(x)) or is_nan(imag(x)): return (NaN, NaN)
if compare(x, (0.0, 0.0), EQ, FLOAT): return (0.0, 0.0)
return divide(x, convert(abs(x), type(x)))
Miktarı ölçülmüş türler için dequantize_op_quantize(sign, operand, type(result))
performans gösterir.
Girişler
Etiket | Ad | Tür | Sınırlamalar |
---|---|---|---|
(I1) | operand |
işaretli tam sayı, kayan nokta, karmaşık tür veya tensör başına ölçülmüş tensör tensörü | (C1) |
Çıkışlar
Ad | Tür | Sınırlamalar |
---|---|---|
result |
işaretli tam sayı, kayan nokta, karmaşık tür veya tensör başına ölçülmüş tensör tensörü | (C1) |
Sınırlamalar
- (C1)
baseline_type(operand) = baseline_type(result)
.
Örnekler
// Logical values: +NaN, -1.0, -0.0, +0.0, 1.0
// operand: [0x7FFFFFFFFFFFFFFF, -1.0, -0.0, 0.0, 1.0]
%result = "stablehlo.sign"(%operand) : (tensor<5xf64>) -> tensor<5xf64>
// Logical values: +NaN, -1.0, -0.0, +0.0, 1.0
// %result: [0x7FFFFFFFFFFFFFFF, -1.0, -0.0, 0.0, 1.0]
sinüs
Anlambilim
operand
tensörü üzerinde öğe bazında sinüs işlemi gerçekleştirir ve result
tensörü üretir. Öğe türüne bağlı olarak aşağıdakileri yapar:
- Kayan reklamlar için: IEEE-754'ten
sin
. - Karmaşık sayılar için: karmaşık sinüs.
- Nicel türler için:
dequantize_op_quantize(sine, operand, type(result))
.
Girişler
Etiket | Ad | Tür | Sınırlamalar |
---|---|---|---|
(I1) | operand |
kayan nokta veya karmaşık tür ya da tensör başına ölçülmüş tensör tensörü | (C1) |
Çıkışlar
Ad | Tür | Sınırlamalar |
---|---|---|
result |
kayan nokta veya karmaşık tür ya da tensör başına ölçülmüş tensör tensörü | (C1) |
Sınırlamalar
- (C1)
baseline_type(operand) = baseline_type(result)
.
Örnekler
// %operand: [
// [0.0, 1.57079632], // [0, pi/2]
// [3.14159265, 4.71238898] // [pi, 3pi/2]
// ]
%result = "stablehlo.sine"(%operand) : (tensor<2x2xf32>) -> tensor<2x2xf32>
// %result: [[0.0, 1.0], [0.0, -1.0]]
dilim
Anlambilim
Statik olarak hesaplanan başlangıç dizinlerini kullanarak operand
öğesinden bir dilim çıkarır ve bir result
tensörü oluşturur. start_indices
her bir boyut için dilimin başlangıç dizinlerini, limit_indices
her boyut için dilimin bitiş dizinlerini (hariç), strides
ise her boyuta ait adımları içerir.
Daha resmi olarak, result[result_index] = operand[operand_index]
burada
operand_index = start_indices + result_index * strides
.
Girişler
Etiket | Ad | Tür | Sınırlamalar |
---|---|---|---|
(I1) | operand |
tensör veya tensör başına ölçülebilir tensör | (C1-C3), (C5) |
(I2) | start_indices |
si64 türünde 1 boyutlu tensör sabiti |
(C2), (C3), (C5) |
(I3) | limit_indices |
si64 türünde 1 boyutlu tensör sabiti |
(C2), (C3), (C5) |
(I4) | strides |
si64 türünde 1 boyutlu tensör sabiti |
(C2), (C4) |
Çıkışlar
Ad | Tür | Sınırlamalar |
---|---|---|
result |
tensör veya tensör başına ölçülebilir tensör | (C1), (C5) |
Sınırlamalar
- (C1)
element_type(operand) = element_type(result)
. - (C2)
size(start_indices) = size(limit_indices) = size(strides) = rank(operand)
. - (C3)
0 <= start_indices <= limit_indices <= shape(operand)
. - (C4)
0 < strides
. - (C5)
shape(result) = ceil((limit_indices - start_indices) / strides)
.
Örnekler
// %operand: [
// [0, 0, 0, 0],
// [0, 0, 1, 1],
// [0, 0, 1, 1]
// ]
%result = "stablehlo.slice"(%operand) {
start_indices = dense<[1, 2]> : tensor<2xi64>,
limit_indices = dense<[3, 4]> : tensor<2xi64>,
strides = dense<1> : tensor<2xi64>
} : (tensor<3x4xi64>) -> tensor<2x2xi64>
// % result: [
// [1, 1],
// [1, 1]
// ]
sıralama
Anlambilim
inputs
öğesinin 1 boyutlu dilimlerini dimension
boyutu boyunca, comparator
ölçütüne göre birlikte sıralar ve results
sonucunu oluşturur.
Diğer işlemlerdeki benzer girişlerin aksine dimension
, aşağıda açıklanan anlamlarla birlikte negatif değerlere izin verir. Gelecekte, tutarlılık nedeniyle buna izin verilmeyebilir (#1377).
is_stable
doğruysa sıralama sabit olur. Diğer bir deyişle, karşılaştırıcıya eşit olduğu kabul edilen öğelerin göreli sırası korunur. Tek bir girişin olduğu durumda, e1
ve e2
iki öğenin yalnızca comparator(e1, e2) = comparator(e2, e1) = false
olması durumunda karşılaştırıcı tarafından eşit olarak kabul edilir. Bunun birden fazla girişe nasıl genelleştirildiğini öğrenmek için aşağıdaki biçimlendirmeye bakın.
Daha resmi şekilde (index_space(results[0])
içindeki tüm result_index
için):
adjusted_dimension = dimension >= 0 ? dimension : rank(inputs[0]) + dimension
.riN
öğesininresult_index
içindeki bağımsız öğeler olduğu ve:
öğesininadjusted_dimension
konumuna eklendiğiresult_slice = [ri0, ..., :, ..., riR-1]
.inputs_together = (inputs[0]..., ..., inputs[N-1]...)
.results_together[result_slice] = sort(inputs_together[result_slice], comparator_together)
.- Bu örnekte
sort
, 1 boyutlu bir dilimi azalan düzende sıralar. Bu durumda, sol taraftaki bağımsız değişken sağ taraftaki ikinci bağımsız değişkenden küçüksecomparator_together
,true
değerini döndürür. def comparator_together(lhs_together, rhs_together): args = [] for (lhs_el, rhs_el) in zip(lhs_together, rhs_together): args.append(lhs_el) args.append(rhs_el) return comparator(*args)
(results[0]..., ..., results[N-1]...) = results_together
.
Girişler
Etiket | Ad | Tür | Sınırlamalar |
---|---|---|---|
(I1) | inputs |
değişken sayıda tensör veya tensör başına ölçülmüş tensör | (C1-C5) |
(I2) | dimension |
si64 türünün sabiti |
(C4) |
(I3) | is_stable |
i1 türünün sabiti |
|
(I4) | comparator |
işlev | (C5) |
Çıkışlar
Ad | Tür | Sınırlamalar |
---|---|---|
results |
değişken sayıda tensör veya tensör başına ölçülmüş tensör | (C2), (C3) |
Sınırlamalar
- (C1)
0 < size(inputs)
. - (C2)
type(inputs...) = type(results...)
. - (C3)
same(shape(inputs...) + shape(results...))
. - (C4)
-R <= dimension < R
; buradaR = rank(inputs[0])
. - (C5)
comparator
,(tensor<E1>, tensor<E1>, ..., tensor<EN-1>, tensor<EN-1>) -> tensor<i1>
türündedir (Ei = element_type(inputs[i])
).
Örnekler
// %input0 = [[1, 2, 3], [3, 2, 1]]
// %input1 = [[3, 2, 1], [1, 2, 3]]
%result0, %result1 = "stablehlo.sort"(%input0, %input1) ({
^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>, %arg2: tensor<i64>, %arg3: tensor<i64>):
%predicate = "stablehlo.compare"(%arg0, %arg1) {
comparison_direction = #stablehlo<comparison_direction GT>
} : (tensor<i64>, tensor<i64>) -> tensor<i1>
"stablehlo.return"(%predicate) : (tensor<i1>) -> ()
}) {
dimension = 0 : i64,
is_stable = true
} : (tensor<2x3xi64>, tensor<2x3xi64>) -> (tensor<2x3xi64>, tensor<2x3xi64>)
// %result0 = [[3, 2, 3], [1, 2, 1]]
// %result1 = [[1, 2, 1], [3, 2, 3]]
sqrt
Anlambilim
operand
tensörü üzerinde öğe bazında karekök işlemi gerçekleştirir ve bir result
tensörü üretir. Öğe türüne bağlı olarak aşağıdakileri yapar:
- Kayan reklamlar için: IEEE-754'ten
squareRoot
. - Karmaşık sayılar için: karmaşık karekök.
- Nicel türler için:
dequantize_op_quantize(sqrt, operand, type(result))
.
Girişler
Etiket | Ad | Tür | Sınırlamalar |
---|---|---|---|
(I1) | operand |
kayan nokta veya karmaşık tür ya da tensör başına ölçülmüş tensör tensörü | (C1) |
Çıkışlar
Ad | Tür | Sınırlamalar |
---|---|---|
result |
kayan nokta veya karmaşık tür ya da tensör başına ölçülmüş tensör tensörü | (C1) |
Sınırlamalar
- (C1)
baseline_type(operand) = baseline_type(result)
.
Örnekler
// %operand: [[0.0, 1.0], [4.0, 9.0]]
%result = "stablehlo.sqrt"(%operand) : (tensor<2x2xf32>) -> tensor<2x2xf32>
// %result: [[0.0, 1.0], [2.0, 3.0]]
çıkar
Anlambilim
İki tensörün (lhs
ve rhs
) öğe bazında çıkarma işlemini gerçekleştirir ve bir result
tensörü üretir. Öğe türüne bağlı olarak aşağıdakileri yapar:
- Tam sayılar için: Tam sayı çıkarma.
- Kayan reklamlar için: IEEE-754'ten
subtraction
. - Karmaşık sayılar için: karmaşık çıkarma işlemi.
- Nicel türler için:
dequantize_op_quantize(subtract, lhs, rhs, type(result))
.
Girişler
Etiket | Ad | Tür | Sınırlamalar |
---|---|---|---|
(I1) | lhs |
tam sayı, kayan nokta, karmaşık tür veya tensör başına ölçülmüş tensör tensörü | (C1) |
(I2) | rhs |
tam sayı, kayan nokta, karmaşık tür veya tensör başına ölçülmüş tensör tensörü | (C1) |
Çıkışlar
Ad | Tür | Sınırlamalar |
---|---|---|
result |
tam sayı, kayan nokta, karmaşık tür veya tensör başına ölçülmüş tensör tensörü | (C1) |
Sınırlamalar
- (C1)
baseline_type(lhs) = baseline_type(rhs) = baseline_type(result)
.
Örnekler
// %lhs: [[6, 8], [10, 12]]
// %rhs: [[5, 6], [7, 8]]
%result = "stablehlo.subtract"(%lhs, %rhs) : (tensor<2x2xf32>, tensor<2x2xf32>) -> (tensor<2x2xf32>)
// %result: [[1, 2], [3, 4]]
Tanh
Anlambilim
operand
tensörü üzerinde öğe bazında hiperbolik tanjant işlemi gerçekleştirir ve bir result
tensörü oluşturur. Öğe türüne bağlı olarak aşağıdakileri yapar:
- Kayan reklamlar için: IEEE-754'ten
tanh
. - Karmaşık sayılar için: karmaşık hiperbolik tanjant.
- Nicel türler için:
dequantize_op_quantize(tanh, operand, type(result))
.
Girişler
Etiket | Ad | Tür | Sınırlamalar |
---|---|---|---|
(I1) | operand |
kayan nokta veya karmaşık tür ya da tensör başına ölçülmüş tensör tensörü | (C1) |
Çıkışlar
Ad | Tür | Sınırlamalar |
---|---|---|
result |
kayan nokta veya karmaşık tür ya da tensör başına ölçülmüş tensör tensörü | (C1) |
Sınırlamalar
- (C1)
baseline_type(operand) = baseline_type(result)
.
Örnekler
// %operand: [-1.0, 0.0, 1.0]
%result = "stablehlo.tanh"(%operand) : (tensor<3xf32>) -> tensor<3xf32>
// %result: [-0.76159416, 0.0, 0.76159416]
ters çevir
Anlambilim
permutation
kullanarak operand
tensörünün boyutlarını değiştirir ve bir result
tensör üretir. Daha resmi olarak, result[result_index] = operand[operand_index]
burada result_index[d] = operand_index[permutation[d]]
.
Girişler
Etiket | Ad | Tür | Sınırlamalar |
---|---|---|---|
(I1) | operand |
tensör veya ölçülmüş tensör | (C1-C4) |
(I2) | permutation |
si64 türünde 1 boyutlu tensör sabiti |
(C2-C4) |
Çıkışlar
Ad | Tür | Sınırlamalar |
---|---|---|
result |
tensör veya ölçülmüş tensör | (C1), (C3-C4) |
Sınırlamalar
- (C1)
element_type(result)
değerini veren:!is_per_axis_quantized(operand)
iseelement_type(operand)
.- Aksi takdirde
quantization_dimension(operand)
vequantization_dimension(result)
farklı olabilir. Bunun dışındaelement_type(operand)
.
- (C2)
permutation
,range(rank(operand))
permütasyonudur. - (C3)
shape(result) = dim(operand, permutation...)
. - (C4)
is_per_axis_quantized(result)
ise,quantization_dimension(operand) = permutation(quantization_dimension(result))
.
Örnekler
// %operand: [
// [[1,2], [3,4], [5,6]],
// [[7,8], [9,10], [11,12]]
// ]
%result = "stablehlo.transpose"(%operand) {
permutation = dense<[2, 1, 0]> : tensor<3xi64>
} : (tensor<2x3x2xi32>) -> tensor<2x3x2xi32>
// %result: [
// [[1,7], [3,9], [5,11]],
// [[2,8], [4,10], [6,12]]
// ]
triangular_solve
Anlambilim
Alt veya üst üçgen katsayı matrislerine sahip doğrusal denklem sistemlerini çözer.
Daha resmi bir şekilde ifade etmek gerekirse a
ve b
verildiğinde result[i0, ..., iR-3, :, :]
, left_side
false
olduğunda op(a[i0, ..., iR-3, :, :]) * x = b[i0, ..., iR-3, :, :]
için çözümdür. left_side
false
olduğunda x * op(a[i0, ..., iR-3, :, :]) = b[i0, ..., iR-3, :, :]
için çözüm olur. x
değişkeninin çözümü, op(a)
'ın transpose_a
ile belirlenir. Bu, aşağıdakilerden biri olabilir:true
NO_TRANSPOSE
:a
öğesini olduğu gibi kullanarak işlemi gerçekleştirin.TRANSPOSE
:a
ters çevirmesinde işlem yap.ADJOINT
:a
eşlenik ters çevirmesinde işlem yap.
Giriş verileri, lower
değerinin true
ise alt üçgenden (a
) veya üst üçgenden (a
) farklı olması durumunda okunur. Çıktı verileri aynı üçgende döndürülür; diğer üçgendeki değerler ise uygulama tarafından tanımlanır.
unit_diagonal
doğruysa uygulama, a
köşegen öğelerinin 1'e eşit olduğunu varsayabilir, aksi takdirde davranış tanımsız olur.
Miktarı ölçülmüş türler için dequantize_op_quantize(lambda x, y: triangular_solve(x, y, left_side, lower,
unit_diagonal, transpose_a), a, b, type(result))
performans gösterir.
Girişler
Etiket | Ad | Tür | Sınırlamalar |
---|---|---|---|
(I1) | a |
kayan nokta veya karmaşık tür ya da tensör başına ölçülmüş tensör tensörü | (C1-C3) |
(I2) | b |
kayan nokta veya karmaşık tür ya da tensör başına ölçülmüş tensör tensörü | (C1-C4) |
(I3) | left_side |
i1 türünün sabiti |
(C3) |
(I4) | lower |
i1 türünün sabiti |
|
(I5) | unit_diagonal |
i1 türünün sabiti |
|
(I6) | transpose_a |
NO_TRANSPOSE , TRANSPOSE ve ADJOINT sıralaması |
Çıkışlar
Ad | Tür | Sınırlamalar |
---|---|---|
result |
kayan nokta veya karmaşık tür ya da tensör başına ölçülmüş tensör tensörü | (C1) |
Sınırlamalar
- (C1)
baseline_element_type(a) = baseline_element_type(b)
. - (C2)
2 <= rank(a) = rank(b) = R
. - (C3)
shape(a)
ileshape(b)
arasındaki ilişki aşağıdaki şekilde tanımlanır:shape(a)[:-3] = shape(b)[:-3]
.dim(a, -2) = dim(a, -1) = dim(b, left_side ? -2 : -1)
.
- (C4)
baseline_type(b) = baseline_type(result)
.
Örnekler
// %a = [
// [1.0, 0.0, 0.0],
// [2.0, 4.0, 0.0],
// [3.0, 5.0, 6.0]
// ]
// %b = [
// [2.0, 0.0, 0.0],
// [4.0, 8.0, 0.0],
// [6.0, 10.0, 12.0]
// ]
%result = "stablehlo.triangular_solve"(%a, %b) {
left_side = true,
lower = true,
unit_diagonal = false,
transpose_a = #stablehlo<transpose NO_TRANSPOSE>
} : (tensor<3x3xf32>, tensor<3x3xf32>) -> tensor<3x3xf32>
// %result: [
// [2.0, 0.0, 0.0],
// [0.0, 2.0, 0.0],
// [0.0, 0.0, 2.0]
// ]
tuple
Anlambilim
val
değerlerinden bir result
demeti oluşturur.
Girişler
Etiket | Ad | Tür | Sınırlamalar |
---|---|---|---|
(I1) | val |
değişken sayı | (C1) |
Çıkışlar
Ad | Tür | Sınırlamalar |
---|---|---|
result |
tuple | (C1) |
Sınırlamalar
- (C1)
result
,Ei = type(val[i])
buradatuple<E0, ..., EN-1>
türündedir.
Örnekler
// %val0: [1.0, 2.0]
// %val1: (3)
%result = "stablehlo.tuple"(%val0, %val1) : (tensor<2xf32>, tuple<tensor<i32>>) -> tuple<tensor<2xf32>, tuple<tensor<i32>>>
// %result: ([1.0, 2.0], (3))
uniform_dequantize
Anlambilim
operand
türü tarafından tanımlanan ölçüm parametrelerine göre ölçülmüş operand
tensörünü, result
kayan nokta tensörüne (result
) öğe bazında dönüştürür.
Daha resmi şekilde, result = dequantize(operand)
.
Girişler
Etiket | Ad | Tür | Sınırlamalar |
---|---|---|---|
(I1) | operand |
sayısallaştırılmış tensör | (C1), (C2) |
Çıkışlar
Ad | Tür | Sınırlamalar |
---|---|---|
result |
kayan nokta türünün tensörü | (C1), (C2) |
Sınırlamalar
- (C1)
shape(operand) = shape(result)
. - (C2)
element_type(result) = expressed_type(operand)
.
Örnekler
// %operand: [10, 10]
%result = "stablehlo.uniform_dequantize"(%operand) : (tensor<2x!quant.uniform<i8:f32:0, {0.1:-30,0.5:-20}>>) -> tensor<2xf32>
// %result: [4.0, 15.0]
uniform_quantize
Anlambilim
Kayan nokta tensörünü veya ölçülmüş tensörü, result
türü tarafından tanımlanan niceleme parametrelerine göre operand
nicel bir tensöre result
dönüştürür.
Daha resmî bir şekilde ifade etmek
is_float(operand)
ise:result = quantize(operand, type(result))
.
is_quantized(operand)
ise:float_result = dequantize(operand)
.result = quantize(float_result, type(result))
.
Girişler
Etiket | Ad | Tür | Sınırlamalar |
---|---|---|---|
(I1) | operand |
kayan nokta veya ölçülmüş tür tensörü | (C1), (C2) |
Çıkışlar
Ad | Tür | Sınırlamalar |
---|---|---|
result |
sayısallaştırılmış tensör | (C1), (C2) |
Sınırlamalar
- (C1)
shape(operand) = shape(result)
. - (C2)
expressed_type(result) = is_float(operand) ? element_type(operand) : expressed_type(operand)
.
Örnekler
// %operand: [4.0, 15.0]
%result = "stablehlo.uniform_quantize"(%operand) : (tensor<2xf32>) -> tensor<2x!quant.uniform<i8:f32:0, {0.1:-30,0.5:-20}>>
// %result: [10, 10]
// %operand: [10, 10]
%result = "stablehlo.uniform_quantize"(%operand) : (tensor<2x!quant.uniform<i8:f32:0, {0.1:-30,0.5:-20}>>) -> tensor<2x!quant.uniform<i8:f32:0, {0.1:-20,0.2:-30}>>
// %result: [20, 45]
süre
Anlambilim
cond
işlevi true
çıktısını verirken body
işlevini 0 veya daha fazla kez yürütürken çıkışı üretir. Daha resmi olarak, anlamlar Python söz dizimi kullanılarak aşağıdaki gibi ifade edilebilir:
internal_state = operand
while cond(*internal_state):
internal_state = body(*internal_state)
results = internal_state
Sonsuz döngünün davranışı TBD'dir (#383).
Girişler
Etiket | Ad | Tür | Sınırlamalar |
---|---|---|---|
(I1) | operand |
değişken sayıda tensör, ölçülmüş tensör veya jeton | (C1-C3) |
(I2) | cond |
işlev | (C1) |
(I3) | body |
işlev | (C2) |
Çıkışlar
Ad | Tür | Sınırlamalar |
---|---|---|
results |
değişken sayıda tensör, ölçülmüş tensör veya jeton | (C3) |
Sınırlamalar
- (C1)
cond
,(T0, ..., TN-1) -> tensor<i1>
türündedir; buradaTi = type(operand[i])
. - (C2)
body
,(T0, ..., TN-1) -> (T0, ..., TN-1)
türündedir; buradaTi = type(operand[i])
. - (C3)
type(results...) = type(operand...)
.
Örnekler
// %init_i: 1
// %init_sum: 0
// %one: 1
// %ten: 10
%results0, %results1 = "stablehlo.while"(%init_i, %init_sum) ({
^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
%cond = "stablehlo.compare"(%arg0, %ten) {
comparison_direction = #stablehlo<comparison_direction LT>
} : (tensor<i64>, tensor<i64>) -> tensor<i1>
stablehlo.return %cond : tensor<i1>
}, {
^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
%new_sum = stablehlo.add %arg1, %one : tensor<i64>
%new_i = stablehlo.add %arg0, %one : tensor<i64>
stablehlo.return %new_i, %new_sum : tensor<i64>, tensor<i64>
}) : (tensor<i64>, tensor<i64>) -> (tensor<i64>, tensor<i64>)
// %results0: 10
// %results1: 10
Xor
Anlambilim
İki tensör lhs
ve rhs
tensörünün öğe düzeyinde ÖZELVEYA gerçekleştirmesini sağlar ve bir result
tensörü üretir. Öğe türüne bağlı olarak aşağıdakileri yapar:
- Boole için: mantıksal ÖZELVEYA.
- Tam sayılar için: bit tabanlı ÖZELVEYA.
Girişler
Etiket | Ad | Tür | Sınırlamalar |
---|---|---|---|
(I1) | lhs |
boole veya tam sayı türü tensörü | (C1) |
(I2) | rhs |
boole veya tam sayı türü tensörü | (C1) |
Çıkışlar
Ad | Tür | Sınırlamalar |
---|---|---|
result |
boole veya tam sayı türü tensörü | (C1) |
Sınırlamalar
- (C1)
type(lhs) = type(rhs) = type(result)
.
Örnekler
// Bitwise operation with with integer tensors
// %lhs: [[1, 2], [3, 4]]
// %rhs: [[5, 6], [7, 8]]
%result = "stablehlo.xor"(%lhs, %rhs) : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
// %result: [[4, 4], [4, 12]]
// Logical operation with with boolean tensors
// %lhs: [[false, false], [true, true]]
// %rhs: [[false, true], [false, true]]
%result = "stablehlo.xor"(%lhs, %rhs) : (tensor<2x2xi1>, tensor<2x2xi1>) -> tensor<2x2xi1>
// %result: [[false, true], [true, false]]
Uygulama
Sıralı yürütme
StableHLO programı, main
işlevine giriş değerleri sağlanarak ve çıkış değerleri hesaplanarak yürütülür. Bir fonksiyonun çıkış değerleri, karşılık gelen return
işlemde rootlanmış işlem grafiği yürütülerek hesaplanır.
Yürütme sırası, veri akışıyla uyumlu olduğu sürece (işlemler kullanılmadan önce yürütülürse) uygulama tarafından tanımlanır. StableHLO'da, yan etkiye sahip tüm işlemler bir jeton kullanır ve bir jeton oluşturur (after_all
aracılığıyla birden fazla jeton tek bir jetona çoğaltılabilir). Bu nedenle, yan etkileri yürütme sırası da veri akışıyla uyumludur. Yukarıdaki örnek programın olası yürütme siparişleri şunlardır: %0
→ %1
→ %2
→ %3
→ %4
→ return
veya %3
→ %0
→
%1
→ %2
→ %4
→ return
.
Daha resmi olarak, StableHLO süreci şunların birleşiminden oluşur:
1) bir StableHLO programı, 2) işlem durumları (henüz yürütülmemiştir, zaten yürütülmüştür) ve 3) sürecin üzerinde çalıştığı ara değerler.
Süreç, main
işlevine giriş değerleriyle başlar, işlem durumları ile ara değerleri güncelleyen işlem grafiğinde ilerler ve çıkış değerleriyle biter. Daha fazla biçimselleştirme henüz belirlenmemiştir (#484).
Paralel yürütme
StableHLO programları paralel olarak yürütülebilir ve her ikisi de ui32
türünde olan num_replicas
ile num_partitions
arasındaki 2D işlem ızgarasında düzenlenir.
StableHLO işlem tablosunda num_replicas * num_partitions
StableHLO işlemi aynı anda yürütülmektedir. Her işlemin benzersiz bir process_id = (replica_id, partition_id)
özelliği vardır. Burada replica_ids = range(num_replicas)
ve partition_ids = range(num_partitions)
içinde partition_id
türündeki replica_id
, her ikisi de ui32
türüne sahiptir.
Süreç tablosunun boyutu her program için statik olarak bilinmektedir (gelecekte bunu StableHLO programlarının #650 açık bir parçası yapmayı planlıyoruz) ve süreç tablosundaki konum her süreç için statik olarak bilinmektedir. Her işlem, replica_id
ve partition_id
işlemleri aracılığıyla işlem tablosundaki kendi konumuna erişebilir.
Süreç tablosunda tüm programlar aynı olabilir ("Tek Program, Çoklu Veri" stilinde), farklı olabilir ("Çoklu Program, Çoklu Veri" stilinde) veya ikisi arasında bir durum olabilir. Gelecekte, GSPMD (#619) dahil olmak üzere paralel StableHLO programlarını tanımlamaya yönelik diğer deyimler için de destek sunmayı planlıyoruz.
Süreç tablosunda süreçler çoğunlukla birbirinden bağımsızdır. Ayrı işlem durumları, ayrı giriş/orta/çıkış değerleri vardır ve işlemlerin çoğu, aşağıda açıklanan az sayıda toplu işlem haricinde süreçler arasında ayrı ayrı yürütülür.
İşlemlerin çoğunun yürütülmesinde yalnızca aynı işlemden gelen değerler kullanıldığı göz önünde bulundurulduğunda, bu değerlere adlarıyla atıfta bulunmak genellikle belirsizdir.
Ancak, toplu işlemlerin anlamını açıklarken bu yeterli değildir. Bu da, belirli bir süreçte name
değerine referansta bulunması için name@process_id
gösteriminin oluşmasına yol açar. (Bu açıdan, uygun olmayan name
, name@(replica_id(), partition_id())
için bir kısaltma olarak görülebilir.)
Aşağıda açıklandığı gibi, noktadan noktaya iletişim ve toplu işlemlerle sağlanan senkronizasyon hariç, süreçler genelinde yürütme sırası uygulama tarafından tanımlanır.
Noktadan noktaya iletişim
StableHLO süreçleri, StableHLO kanalları üzerinden birbirleriyle iletişim kurabilir. Bir kanal, si64
türünün pozitif bir kimliğiyle temsil edilir. Çeşitli işlemler aracılığıyla, kanallara değer göndermek ve bu değerleri kanallardan almak mümkündür.
Daha fazla resmileştirme (ör. bu kanal kimliklerinin nereden geldiği, programların bunları nasıl fark ettiği ve bunlar tarafından ne tür bir senkronizasyon başlatılan) TBD'dir (#484).
Canlı iletişim
Her StableHLO işleminin iki akış arayüzüne erişimi vardır:
- Okunabilen feed içi.
- Yazılabilecek feed dışı.
Süreçler arasında iletişim kurmak için kullanılan ve dolayısıyla her iki ucunda da süreçler olan kanalların aksine, feed içi ve feed'lerin diğer son uygulamaları tanımlanmıştır.
Daha fazla resmilik (ör. akış iletişiminin yürütme sırasını nasıl etkilediği ve ne tür senkronizasyonların ortaya çıktığı) TBD'dir (#484).
Ortak operasyonlar
StableHLO'da altı toplu işlem vardır: all_gather
, all_reduce
,
all_to_all
, collective_broadcast
, collective_permute
ve
reduce_scatter
. Tüm bu işlemler, StableHLO işlem tablosundaki işlemleri StableHLO işlem gruplarına ayırır ve her süreç grubu içinde diğer süreç gruplarından bağımsız olarak ortak bir hesaplama yürütür.
Her süreç grubu içinde toplu işlemler bir senkronizasyon bariyeri oluşturabilir. Daha da resmileştirme (ör. bu senkronizasyonun tam olarak ne zaman gerçekleştiği, süreçlerin bu engele tam olarak nasıl ulaştığı ve ulaşmazsa ne olacağı) henüz belirlenmemiştir (#484).
Süreç grubu bölümler arası iletişim içeriyorsa (ör. işlem grubunda bölüm kimlikleri farklı olan süreçler varsa), toplu işlemin yürütülmesi bir kanala ihtiyaç duyar ve toplu işlem, si64
türünde pozitif bir channel_id
sağlamalıdır. Çapraz replika iletişiminde kanallara
ihtiyaç yoktur.
Toplu operasyonlar tarafından gerçekleştirilen hesaplamalar, her işleme özeldir ve yukarıdaki işlem bölümlerinde ayrı ayrı açıklanmıştır. Bununla birlikte, işlem tablosunun süreç gruplarına ayrıldığı stratejiler bu işlemler arasında paylaşılır ve bu bölümde açıklanmaktadır. Daha resmi olarak, StableHLO aşağıdaki dört stratejiyi destekler.
cross_replica
Her süreç grubu içinde yalnızca çapraz replika iletişimleri gerçekleşir. Bu strateji, kopya kimliklerden oluşan bir liste olan replica_groups
öğesini alır ve partition_ids
tarafından replica_groups
Kartezyen çarpımını hesaplar. replica_groups
benzersiz öğelere sahip olmalı ve tüm replica_ids
öğelerini kapsamalıdır. Daha resmi bir şekilde, Python söz dizimini kullanarak:
def cross_replica(replica_groups: List[List[ReplicaId]]) -> List[List[ProcessId]]:
for replica_group in replica_groups:
for partition_id in partition_ids:
process_group = []
for replica_id in replica_group:
process_group.append((replica_id, partition_id))
yield process_group
Örneğin, replica_groups = [[0, 1], [2, 3]]
ve num_partitions = 2
için cross_replica
işlevi [[(0, 0), (1, 0)], [(0, 1), (1, 1)], [(2, 0), (3, 0)], [(2, 1), (3, 1)]]
sonucunu oluşturur.
cross_partition
Her süreç grubu içinde yalnızca bölümler arası iletişimler gerçekleşir. Bu strateji partition_groups
(bölüm kimlikleri listesi listesi) öğesini alır ve replica_ids
ile partition_groups
Kartezyen çarpımını hesaplar.
partition_groups
benzersiz öğelere sahip olmalı ve tüm partition_ids
öğelerini kapsamalıdır.
Daha resmi bir şekilde, Python söz dizimi kullanılarak:
def cross_partition(partition_groups: List[List[PartitionId]]) -> List[List[ProcessId]]:
for partition_group in partition_groups:
for replica_id in replica_ids:
process_group = []
for partition_id in partition_group:
process_group.append((replica_id, partition_id))
yield process_group
Örneğin, partition_groups = [[0, 1]]
ve num_replicas = 4
için cross_partition
işlevi [[(0, 0), (0, 1)], [(1, 0), (1, 1)], [(2, 0), (2, 1)], [(3, 0), (3, 1)]]
sonucunu oluşturur.
cross_replica_and_partition
Her süreç grubu içinde hem kopyalar hem de bölümler arası iletişimler gerçekleşebilir. Bu strateji, replica_groups
(çoğaltma kimlikleri listesi) öğesini alır ve her bir replica_group
öğesinin Kartezyen ürünlerini partition_ids
değerine göre hesaplar. replica_groups
benzersiz öğelere sahip olmalı ve tüm replica_ids
öğelerini kapsamalıdır. Daha resmi bir şekilde, Python söz dizimi kullanılarak:
def cross_replica_and_partition(replica_groups: List[List[ReplicaId]]) -> List[List[ProcessId]]:
for replica_group in replica_groups:
process_group = []
for partition_id in partition_ids:
for replica_id in replica_group:
process_group.append((replica_id, partition_id))
yield process_group
Örneğin, replica_groups = [[0, 1], [2, 3]]
ve num_partitions = 2
için cross_replica_and_partition
işlevi [[(0, 0), (1, 0), (0, 1), (1, 1)], [(2, 0), (3, 0), (2, 1), (3, 1)]]
sonucunu oluşturur.
flattened_ids
Bu strateji, "birleştirilmiş" işlem kimlikleri listelerinden oluşan ve replica_id * num_partitions + partition_id
biçimindeki flattened_id_groups
listesini alır ve bunları işlem kimliklerine dönüştürür. flattened_id_groups
benzersiz öğelere sahip olmalı ve tüm process_ids
öğelerini kapsamalıdır. Daha resmi bir şekilde, Python söz dizimi kullanılarak:
def flattened_ids(flattened_id_groups: List[List[ui32]]) -> List[List[ProcessId]]:
for flattened_id_group in flattened_id_groups:
process_group = []
for flattened_id in flattened_id_group:
replica_id = flattened_id // num_partitions
partition_id = flattened_id % num_partitions
process_group.append((replica_id, partition_id))
yield process_group
Örneğin, flattened_id_groups = [[0, 1, 2, 3], [4, 5, 6, 7]]
, num_replicas = 4
ve num_partitions = 2
için flattened_ids
, [[(0, 0), (0, 1), (1, 0), (1, 1)], [(2, 0), (2, 1), (3, 0), (3, 1)]]
sonucunu verir.
Doğruluk
Şu anda StableHLO, sayısal doğruluk konusunda garanti vermemektedir ancak bu durum ileride değişebilir (#1156).
Hatalar
StableHLO programları, bağımsız işlemlere yönelik çok sayıda kısıtlama grubuyla doğrulanır. Bu kısıtlamalar, çalışma zamanından önce birçok hata sınıfını ortadan kaldırır. Bununla birlikte, tam sayı taşmaları, sınır dışı erişimler vb. hata durumlarıyla karşılaşmaya devam edebilirsiniz. Açıkça belirtilmediği sürece tüm bu hatalar uygulama tanımlı davranışla sonuçlanır, ancak bu durum ileride değişebilir (#1157).
Bu kurala bir istisna olarak, StableHLO programlarındaki kayan nokta istisnaları iyi tanımlanmış davranışa sahiptir. IEEE-754 standardı tarafından tanımlanan istisnalara (geçersiz işlem, sıfıra bölme, taşma, alt akış veya tam olmayan istisnalar) sonuçlanan işlemler, varsayılan sonuçlar üretir (standartta tanımlandığı gibi) ve ilgili durum işaretini kaldırmadan yürütülmeye devam eder. Standartta raiseNoFlag
istisna işlemeye benzer. Standart olmayan işlemler (ör. karmaşık aritmetik ve belirli transandantal işlevler) ile ilgili istisnalar uygulama tarafından tanımlanır.
Notasyon
Bu dokümanda söz dizimini açıklamak için bu dokümanda EBNF söz diziminin değiştirilmiş ISO lezzetini (ISO/IEC 14977:1996, Wikipedia) kullanılmaktadır: 1) kurallar =
yerine ::=
kullanılarak tanımlanır,
2) Birleştirme, ,
yerine yan yana koyma kullanılarak ifade edilir.
Anlamları açıklamak için (ör. "Türler", "Sabitler" ve "İşlemler" bölümlerinde) Python söz dizimine dayalı formüller kullanıyoruz. Ayrıca dizi işlemlerinin aşağıda açıklandığı gibi kısa bir şekilde ifade edilmesini de destekliyoruz. Bu, küçük kod snippet'leri için iyi sonuç verir ancak daha büyük kod snippet'lerinin gerektiği nadir durumlarda, her zaman açık bir şekilde sunulan vanilla Python söz dizimini kullanırız.
Formüller
dot_general
spesifikasyonundan bir örnek üzerinden formüllerin nasıl çalıştığını inceleyelim. Bu işlemin kısıtlamalarından biri aşağıdaki gibi görünür:
dim(lhs, lhs_batching_dimensions...) = dim(rhs, rhs_batching_dimensions...)
.
Bu formülde kullanılan adlar iki kaynaktan gelir: 1) genel işlevler (ör. dim
, 2) karşılık gelen program öğesinin üye tanımları, ör. dot_general
"Girişler" bölümünde tanımlanan lhs
, lhs_batching_dimensions
, rhs
ve rhs_batching_dimensions
girişleri.
Yukarıda belirtildiği gibi, bu formülün söz dizimi, kısa ve öz bazı uzantılarla birlikte Python tabanlıdır. Formülü anlamak için vanilla Python söz dizimine dönüştürelim.
Y) Bu formüllerde eşitliği temsil etmek için =
kullanılmaktadır. Dolayısıyla Python söz dizimini elde etmenin ilk adımı, =
yerine ==
ile geçmektir. Örneğin: dim(lhs, lhs_batching_dimensions...) == dim(rhs, rhs_batching_dimensions...)
.
B) Ayrıca bu formüller, skaler ifadeleri tensör ifadelere dönüştüren elipsleri (...
) destekler. Özetle f(xs...)
, yaklaşık olarak "xs
tensöründeki her skaler x
için skaler bir f(x)
hesaplayın ve ardından tüm bu skaler sonuçları tensör sonucu olarak birlikte döndürün" anlamına gelir. Vanilla Python söz diziminde, örnek formülümüz şu şekildedir: [dim(lhs, dim1) for dim1 in lhs_batching_dimensions] ==
[dim(rhs, dim2) for dim2 in rhs_batching_dimensions]
.
Elipsler sayesinde, tek tek skalerler düzeyinde çalışmaktan kaçınmak çoğu zaman mümkündür. Bununla birlikte, bazı yanıltıcı durumlarda, gather
spesifikasyonundaki start_indices[bi0, ..., :, ..., biN]
formülündeki gibi alt düzey yarı resmi söz dizimi kullanılabilir. Konuyu kısa ve öz tutmak adına, bu tür söz diziminin vanilla Python'a çevrilmesi konusunda kesin bir formellik vermiyoruz. Bunun nedeni, bu söz diziminin duruma göre hâlâ sezgisel olarak anlaşılabilmesini sağlamaktır.
Belirli formüller opak görünüyorsa lütfen bize bildirin, bunları iyileştirmeye çalışalım.
Ayrıca formüllerde, tensörler, tensör listeleri (ör. çok sayıda tensörden ortaya çıkabilir) dahil her tür listeyi genişletmek için elipslerin kullanıldığını fark edeceksiniz. Bu, kesin bir resmilik sunmadığımız başka bir alandır (ör. listeler StableHLO türü sisteminin bir parçası bile değildir).
C) Kullandığımız son önemli notasyon aracı, örtülü yayındır. StableHLO operasyonu örtülü yayınlamayı desteklemese de formüller aynı zamanda kısa ve öz yayını destekler. Özetle, tensörün beklendiği bir bağlamda skaler bir kullanım kullanılıyorsa skaler beklenen şekle yayınlanır.
dot_general
örneğine devam etmek için aşağıdaki başka bir kısıtlamayı görebilirsiniz:
0 <= lhs_batching_dimensions < rank(lhs)
. dot_general
spesifikasyonunda tanımlandığı gibi lhs_batching_dimensions
bir tensördür ancak hem 0
hem de rank(lhs)
skalerdir. Örtülü yayını uyguladıktan sonra formül [0, ..., 0] <= lhs_batching_dimensions < [rank(lhs), ..., rank(lhs)]
haline gelir.
Belirli bir dot_general
işlemine uygulandığında bu formül, boole tensörü olarak değerlendirilir. Formüller kısıtlama olarak kullanıldığında formülün true
veya yalnızca true
öğesi içeren bir tensör olarak değerlendirilmesi durumunda kısıtlama korunur.
Adlar
Formüllerde sözlüksel kapsam şunları içerir: 1) genel işlevler, 2) üye tanımları,
3) yerel tanımlar. Genel işlevlerin listesi aşağıda verilmiştir. Öğe tanımları listesi, gösterimin uygulandığı program öğesine bağlıdır:
- İşlemler için üye tanımları, "Girişler" ve "Çıkışlar" bölümlerinde verilen adları içerir.
- Diğer her şeyde üye tanımları, program öğesinin yapısal parçalarını içerir. Adını ilgili EBNF olmayan terminallerden alır. Çoğu zaman bu yapısal parçaların adları, terminal olmayan bölümlerin adlarının yılan şekline (ör.
IntegerLiteral
=>integer_literal
) dönüştürülmesiyle elde edilir. Ancak bazen adlar işlem sırasında kısaltılır (ör.QuantizationStorageType
=>storage_type
). Bu durumda adlar, "Girdiler" / "işlem bölümleri"nde açık bir şekilde benzer şekilde tanıtılır. - Ayrıca, üye tanımları ilgili program öğesine işaret etmek için her zaman
self
içerir.
Değerler
Formüller değerlendirildiğinde şu değer türleriyle çalışır:
1) Value
(gerçek değerler, ör. dense<[[1, 2], [3, 4]]> : tensor<2x2xi32>
;
türlerini her zaman bilirler),
2) Placeholder
(gelecek değerler, ör. lhs
, rhs
veya result
; gerçek değerleri henüz bilinmiyor, yalnızca türleri bilinmektedir),
3) Type
("Türler" bölümünde tanımlanan türler),
4) Function
(Tanımlanan genel işlevler) bölümünde.
Bağlama bağlı olarak, adlar farklı değerleri ifade ediyor olabilir. Daha açık şekilde belirtmek gerekirse işlemler için "Semantik" bölümü (ve diğer program öğelerinin eşdeğerleri) çalışma zamanı mantığını tanımladığından, tüm girişler Value
olarak kullanılabilir.
Bunun aksine, işlemler (ve eşdeğerleri) için "Kısıtlamalar" bölümü, "derleme süresi" mantığını (yani genellikle çalışma zamanından önce yürütülen bir şey) tanımlar. Bu nedenle, yalnızca sabit girişler Value
olarak, diğer girişler ise yalnızca Placeholder
olarak kullanılabilir.
Adlar | "Anlambilim"de | "Kısıtlamalar"da |
---|---|---|
Genel işlevler | Function |
Function |
Sabit girişler | Value |
Value |
Sabit olmayan girişler | Value |
Placeholder |
Çıkışlar | Value |
Placeholder |
Yerel tanımlar | Tanıma göre değişir | Tanıma göre değişir |
Örnek bir transpose
işlemi ele alalım:
%result = "stablehlo.transpose"(%operand) {
permutation = dense<[2, 1, 0]> : tensor<3xi64>
} : (tensor<2x3x2xi32>) -> tensor<2x3x2xi32>
Bu işlem için permutation
sabit bir değerdir. Bu nedenle, hem anlam hem de kısıtlamalar açısından Value
olarak kullanılabilir. Öte yandan, operand
ve result
, anlamsal olarak Value
olarak ancak kısıtlamalarda yalnızca bir Placeholder
olarak kullanılabilir.
İşlevler
İnşaat türü
Tür oluşturmak için kullanılabilecek işlev yoktur. Bunun yerine, genellikle daha kısa ve öz olduğu için doğrudan tür söz dizimini kullanırız. Ör. function_type(
[tensor_type([], E), tensor_type([], E)], [tensor_type([], E)])
yerine (tensor<E>, tensor<E>) -> (tensor<E>)
.
Türlerdeki işlevler
element_type
, sırasıyla karşılık gelenTensorType
veyaQuantizedTensorType
öğesininTensorElementType
ya daQuantizedTensorElementType
bölümünde tensör türleri ve ölçülmüş tensör türleri ve dönüşlerde tanımlanır.
def element_type(x: Value | Placeholder | Type):
if type(x) == TensorType:
return tensor_element_type(x)
if type(x) == QuantizedTensorType:
return quantized_tensor_element_type(x)
if type(x) is not Type:
return element_type(type(x))
is_per_axis_quantized(x: Value | Placeholder | Type) -> Value
,is_quantized(x) and quantization_dimension(x) is not None
için bir kısayoldur.is_per_tensor_quantized(x: Value | Placeholder | Type) -> Value
,is_quantized(x) and quantization_dimension(x) is None
için bir kısayoldur.is_promotable(x: Type, y: Type) -> bool
,x
türününy
türüne yükseltilip yükseltilemeyeceğini kontrol eder.x
vey
QuantizedTensorElementType
olduğunda promosyon yalnızcastorage_type
için geçerlidir. Promosyonun bu sürümü, şu anda hesaplamayı azaltmak için kullanılmaktadır (daha fazla ayrıntı için RFC bölümüne bakın).
def is_promotable(x: Type, y: Type) -> Value:
is_same_type = (is_bool(x) and is_bool(y)) or
(is_integer(x) and is_integer(y)) or (is_float(x) and is_float(y)) or
(is_complex(x) and is_complex(y)) or
(is_quantized(x) and is_quantized(y) and expressed_type(x) = expressed_type(y))
if is_same_type == False:
return False
if is_integer(x) or is_float(x):
return bitwidth(x) <= bitwidth(y)
if is_complex(x):
return bitwidth(element_type(x)) <= bitwidth(element_type(y))
if is_quantized(x):
return bitwidth(storage_type(x)) <= bitwidth(storage_type(y))
return false
is_quantized(x: Value | Placeholder | Type) -> Value
,is_quantized_tensor_element_type(x)
için bir kısayoldur.is_type_name(x: Value | Placeholder | Type) -> Value
. Tüm türler için kullanılabilir. Örneğin,x
birFloatType
iseis_float(x)
,true
değerini döndürür.x
bir değer veya yer tutucuysa bu işlevis_type_name(type(x))
için bir kısayoldur.max_value(x: Type) -> Value
,TensorElementType
değerinin maksimum değerini döndürür.x
birTensorElementType
değilseNone
değerini döndürür.min_value(x: Type) -> Value
,TensorElementType
özelliğinin mümkün olan en düşük değerini döndürür.x
birTensorElementType
değilseNone
değerini döndürür.member_name(x: Value | Placeholder | Type) -> Any
. Her türdenmember_name
üye tanımlarının tamamı için kullanılabilir. Örneğin,tensor_element_type(x)
, karşılık gelen birTensorType
öğesininTensorElementType
kısmını döndürür.x
bir değer veya yer tutucuysa bu işlevmember_name(type(x))
için bir kısayoldur.x
uygun bir üyeye sahip bir tür veya böyle bir türde değer ya da yer tutucu değilseNone
değerini döndürür.
Değer oluşturma
operation_name(*xs: Value | Type) -> Value
. Tüm işlemler için kullanılabilir. Örneğinadd(lhs, rhs)
, iki tensör değeri (lhs
verhs
) alır veadd
işlemini bu girişlerle değerlendirmenin çıkışını döndürür.broadcast_in_dim
gibi bazı işlemler için çıkış türleri "yük taşıyan"dır. Yani bir işlemin değerlendirilmesi için gereklidir. Bu durumda, işlev bu türleri bağımsız değişken olarak alır.
Değerler üzerindeki işlev
Python'un tüm operatör ve işlevleri kullanılabilir. Örneğin, Python'dan hem abonelik hem de dilimleme gösterimleri, tensörler, ölçülmüş tensörler ve tup'lar olarak dizine eklenebilir.
to_destination_type(x: Value, destination_type: Type) -> Value
, tensörler üzerinde tanımlanır vetype(x)
vedestination_type
özelliklerine göre dönüştürülmüşx
değerini aşağıdaki gibi döndürür:
def to_destination_type(x: Value, destination_type: Type) -> Value:
if type(x) == destination_type:
return x
if is_quantized(destination_type):
if is_quantized(type(x)):
return quantize(x, destination_type)
assert is_float(type(x))
return quantize(x, destination_type)
if is_quantized(type(x)):
assert destination_type = expressed_type(type(x))
return dequantize(type(x))
return convert(x, destination_type)
convert
, uniform_quantize
ve uniform_dequantize
işlemlerinin birleştirilmesiyle ilgili başlangıç tartışmaları bulunmaktadır (#1576).
Birleştirme işleminden sonra yukarıdaki işleve ihtiyacımız olmaz ve bunun yerine convert
için işlem adını kullanabiliriz.
is_nan(x: Value) -> Value
, tensörlerde tanımlanır vex
'nin tüm öğeleriNaN
veyafalse
isetrue
değerini döndürür.x
bir tensör değilseNone
değerini döndürür.is_sorted(x: Value) -> Value
tensörler üzerinde tanımlanır vex
öğeleri, dizinlerinin artan sözlük sıralamasına göre artan düzende sıralanmışsatrue
aksi takdirdefalse
döndürülür.x
bir tensör değilseNone
değerini döndürür.is_unique(x: Value) -> Value
tensörlerde tanımlanır vex
yinelenen öğeler içermiyorsatrue
, aksi haldefalse
değerini döndürür.x
bir tensör değilseNone
değerini döndürür.member_name(x: Value) -> Any
, tüm değerlerinmember_name
üye tanımlarının tamamı için tanımlandı. Örneğin,real_part(x)
, karşılık gelen birComplexConstant
öğesininRealPart
bölümünü döndürür.x
uygun bir üye içeren bir değer değilseNone
değerini döndürür.same(x: Value) -> Value
tensörlerde tanımlanır vex
öğelerinin tamamı birbirine eşitsetrue
veya aksi haldefalse
değerini döndürür. Tensörde öğe yoksa bu, "tümü birbirine eşit" olarak sayılır. Yani, işlevtrue
değerini döndürür.x
bir tensör değilseNone
değerini döndürür.split(x: Value, num_results: Value, axis: Value) -> Value
tensörlerde tanımlanır veaxis
ekseni boyuncanum_results
x
dilim döndürür.x
bir tensör değilse veyadim(x, axis) % num_results != 0
,None
değerini döndürür.
Şekil hesaplamaları
axes(x: Value | Placeholder | Type) -> Value
,range(rank(x))
için bir kısayoldur.dim(x: Value | Placeholder | Type, axis: Value) -> Value
,shape(x)[axis]
için bir kısayoldur.dims(x: Value | Placeholder | Type, axes: List) -> List
,list(map(lambda axis: dim(x, axis), axes))
için bir kısayoldur.index_space(x: Value | Placeholder | Type) -> Value
tensörlerde tanımlanır ve karşılık gelenTensorType
için artan sözlük sıralamasına göre sıralanmışsize(x)
dizinleri döndürür (ör.[0, ..., 0]
,[0, ..., 1]
, ...,shape(x) - 1
).x
bir tensör türü, ölçülmüş bir tensör türü veya bu türlerden birinin değeri veya yer tutucusu değilseNone
değerini döndürür.rank(x: Value | Placeholder | Type) -> Value
,size(shape(x))
için bir kısayoldur.shape(x: Value | Placeholder | Type) -> Value
,member_name
aracılığıyla "Türlerdeki işlevler" bölümünde tanımlanmıştır.size(x: Value | Placeholder | Type) -> Value
,reduce(lambda x, y: x * y, shape(x))
için bir kısayoldur.
Nicelleştirme hesaplamaları
def baseline_element_type(x: Value | Placeholder | Type) -> Type
,element_type(baseline_type(x))
için bir kısayoldur.baseline_type
, tensör türlerinde ve ölçülmüş tensör türlerinde tanımlanır ve bunları "referans" değerine, yani aynı şekle sahip ancak öğe türünün niceleme parametreleri varsayılan değerlere sıfırlanan bir türe dönüştürür. Bu yöntem, hem tensör hem de ölçülmüş tensör türlerini eşit şekilde karşılaştırmak için pratik bir yöntem olarak kullanılır. Bu işlem sık sık kullanılır. Nicelleştirilmiş türlerde bu, nicelik parametreleri yok sayılmadan türlerin karşılaştırılmasını sağlar. Yani,shape
,storage_type
,expressed_type
,storage_min
,storage_max
vequantization_dimension
(eksen başına ölçülmüş tür için) tümünün eşleşmesi gerekir ancakscales
vezero points
farklı olabilir.
def baseline_type(x: Value | Placeholder | Type) -> Type:
if type(x) == TensorType:
return x
if type(x) == QuantizedTensorType:
element_type = quantized_tensor_element_type(x)
baseline_element_type = QuantizedTensorElementType(
storage_type = storage_type(element_type),
storage_min = storage_min(element_type),
storage_max = storage_max(element_type),
expressed_type = expressed_type(element_type),
quantization_dimension = quantization_dimension(element_type),
scales = [constant(1.0, expressed_type(element_type))] * dim(x, quantization_dimension(element_type)),
zero_points = [constant(0, storage_type(element_type))] * dim(x, quantization_dimension(element_type)))
return QuantizedTensorType(shape(x), baseline_element_type)
if type(x) is not Type:
return baseline_element_type(type(x))
dequantize
, ölçülmüş tensör türlerinde tanımlanır ve bunları kayan nokta tensör türlerine dönüştürür. Bu, depolama türünün tam sayı değerlerini temsil eden nicelenmiş öğelerin, ölçülmüş öğe türüyle ilişkili sıfır noktası ve ölçek kullanılarak ifade edilen türün karşılık gelen kayan nokta değerlerine dönüştürülmesiyle gerçekleşir.
def compute_zero_points(quantized_type, result_type):
if is_per_tensor_quantized(quantized_type):
return broadcast_in_dim(constant(zero_point(quantized_type), storage_type(quantized_type)), [], result_type)
if is_per_axis_quantized(quantized_type):
for i in index_space(result_type):
d = quantization_dimension(quantized_type)
zero_points[i] = zero_points(quantized_type)[i[d]]
return zero_points
def compute_scales(quantized_type, result_type):
if is_per_tensor_quantized(quantized_type):
return broadcast_in_dim(constant(scale(quantized_type), expressed_type(quantized_type)), [],
type(result_type))
if is_per_axis_quantized(quantized_type):
for i in index_space(result_type):
d = quantization_dimension(quantized_type)
scales[i] = scales(quantized_type)[i[d]]
return scales
def dequantize(x: Value) -> Value:
assert is_quantized(x)
x_storage = bitcast_convert(x, storage_type(x))
x_storage_sub = x_storage - compute_zero_points(type(x), type(x_storage))
x_expressed_sub = convert(x_storage_sub, expressed_type(x))
return x_expressed_sub * compute_scales(type(x), type(x_expressed_sub))
quantize
, kayan nokta tensör türlerinde tanımlanır ve bunları ölçülmüş tensör türlerine dönüştürür. Bu işlem, ifade edilen türdeki kayan nokta değerlerinin, ölçülmüş öğe türüyle ilişkili sıfır noktası ve ölçek kullanılarak depolama türünün karşılık gelen tam sayı değerlerine dönüştürülmesiyle gerçekleşir.
def quantize(x: Value, type: Type) -> Value:
assert is_float(x) and is_quantized(type)
x_expressed_rounded = round_nearest_even(x / compute_scales(type, type(x)))
x_storage_rounded = convert(x_expressed_rounded, storage_type(type))
x_storage_add = x_storage_rounded + compute_zero_points(type, type(x_storage_rounded))
x_storage = clamp(storage_min(type), x_storage_add, storage_max(type))
return bitcast_convert(x_storage, type)
dequantize_op_quantize
, ölçülmüş tensörler üzerinde öğe tabanlı hesaplamaları belirtmek için kullanılır. Bu işlem, nicelikten arındırılır, yani ölçülmüş öğeleri ifade edilen türlere dönüştürür, ardından bir işlem gerçekleştirir ve nicelikleri ölçer. Başka bir deyişle, sonuçları tekrar depolama türlerine dönüştürür. Şu anda bu işlev yalnızca tensör başına ölçüm için çalışmaktadır. Eksen başına ölçümle ilgili çalışmalar devam etmektedir (#1574).
def dequantize_op_quantize(op, *inputs_and_output_type):
inputs = inputs_and_output_type[:-1]
output_type = inputs_and_output_type[-1]
float_inputs = map(dequantize, inputs)
float_result = op(*float_inputs)
return quantize(float_result, output_type)
def dequantize_batch_norm_grad_or_training_quantize(op, *inputs_and_output_types):
inputs = inputs_and_output_type[:-3]
float_inputs = map(dequantize, inputs)
float_results = op(*float_inputs)
return map(quantize, float_results, inputs_and_output_type[-3:])
def dequantize_compare(lhs, rhs, comparison_direction):
float_lhs = dequantize(lhs)
float_rhs = dequantize(rhs)
return compare(float_lhs, float_rhs, comparison_direction, FLOAT)
def dequantize_select_quantize(pred, on_true, on_false, output_type):
float_on_true = dequantize(on_true)
float_on_false = dequantize(on_false)
float_result = select(pred, float_on_true, float_on_false)
return quantize(float_result, output_type)
Izgara hesaplamaları
cross_partition(replica_groups: Value) -> Value
. Yukarıdaki "cross_replica" bölümüne bakın.cross_replica(replica_groups: Value) -> Value
. Yukarıdaki "cross_replica" bölümüne bakın.cross_replica_and_partition(replica_groups: Value) -> Value
. Yukarıdaki "cross_replica_and_partition" bölümüne bakın.flattened_ids(replica_groups: Value) -> Value
. Yukarıdaki "flattened_id" bölümüne bakın.