Builder Methods
stablehlo::AbsOp
Creates a new stablehlo.abs
operation.
MlirOp Abs(MlirOp &operand);
stablehlo::AddOp
Creates a new stablehlo.add
operation.
MlirOp Add(MlirOp &lhs, MlirOp &rhs);
stablehlo::AfterAllOp
Creates a new stablehlo.after_all
operation.
MlirOp AfterAll(MlirBuilder &builder, ArrayRef<MlirOp> inputs);
stablehlo::AllGatherOp
Creates a new stablehlo.all_gather
operation.
SmallVector<MlirOp> AllGather(MlirBuilder &builder, Type resultType, ArrayRef<MlirOp> operands, uint64_t all_gather_dim, ::mlir::DenseIntElementsAttr replica_groups, /*optional*/::mlir::stablehlo::ChannelHandleAttr channel_handle = {}, /*optional*/bool use_global_device_ids = false);
stablehlo::AllReduceOp
Creates a new stablehlo.all_reduce
operation.
This operation has a body region built via a callback function.
SmallVector<MlirOp> AllReduce(MlirBuilder &builder, ArrayRef<MlirOp> operands, const RegionBuilderCallback &computation, ::mlir::DenseIntElementsAttr replica_groups, /*optional*/::mlir::stablehlo::ChannelHandleAttr channel_handle = {}, /*optional*/bool use_global_device_ids = false);
stablehlo::AllToAllOp
Creates a new stablehlo.all_to_all
operation.
SmallVector<MlirOp> AllToAll(MlirBuilder &builder, ArrayRef<MlirOp> operands, uint64_t split_dimension, uint64_t concat_dimension, uint64_t split_count, ::mlir::DenseIntElementsAttr replica_groups, /*optional*/::mlir::stablehlo::ChannelHandleAttr channel_handle = {});
stablehlo::AndOp
Creates a new stablehlo.and
operation.
MlirOp And(MlirOp &lhs, MlirOp &rhs);
stablehlo::Atan2Op
Creates a new stablehlo.atan2
operation.
MlirOp Atan2(MlirOp &lhs, MlirOp &rhs);
stablehlo::BatchNormGradOp
Creates a new stablehlo.batch_norm_grad
operation.
SmallVector<MlirOp, 3> BatchNormGrad(MlirOp &operand, MlirOp &scale, MlirOp &mean, MlirOp &variance, MlirOp &grad_output, ::llvm::APFloat epsilon, uint64_t feature_index);
stablehlo::BatchNormInferenceOp
Creates a new stablehlo.batch_norm_inference
operation.
MlirOp BatchNormInference(MlirOp &operand, MlirOp &scale, MlirOp &offset, MlirOp &mean, MlirOp &variance, ::llvm::APFloat epsilon, uint64_t feature_index);
stablehlo::BatchNormTrainingOp
Creates a new stablehlo.batch_norm_training
operation.
SmallVector<MlirOp, 3> BatchNormTraining(MlirOp &operand, MlirOp &scale, MlirOp &offset, ::llvm::APFloat epsilon, uint64_t feature_index);
stablehlo::BitcastConvertOp
Creates a new stablehlo.bitcast_convert
operation.
MlirOp BitcastConvert(Type resultType, MlirOp &operand);
stablehlo::BroadcastInDimOp
Creates a new stablehlo.broadcast_in_dim
operation.
MlirOp BroadcastInDim(Type resultType, MlirOp &operand, ::llvm::ArrayRef<int64_t> broadcast_dimensions);
stablehlo::BroadcastOp
Creates a new stablehlo.broadcast
operation.
MlirOp Broadcast(MlirOp &operand, ::llvm::ArrayRef<int64_t> broadcast_sizes);
stablehlo::CbrtOp
Creates a new stablehlo.cbrt
operation.
MlirOp Cbrt(MlirOp &operand, /*optional*/::mlir::stablehlo::ResultAccuracyAttr result_accuracy = {});
stablehlo::CeilOp
Creates a new stablehlo.ceil
operation.
MlirOp Ceil(MlirOp &operand);
stablehlo::CholeskyOp
Creates a new stablehlo.cholesky
operation.
MlirOp Cholesky(MlirOp &a, /*optional*/bool lower = false);
stablehlo::ClampOp
Creates a new stablehlo.clamp
operation.
MlirOp Clamp(MlirOp &min, MlirOp &operand, MlirOp &max);
stablehlo::ClzOp
Creates a new stablehlo.count_leading_zeros
operation.
MlirOp Clz(MlirOp &operand);
stablehlo::CollectiveBroadcastOp
Creates a new stablehlo.collective_broadcast
operation.
MlirOp CollectiveBroadcast(MlirOp &operand, ::mlir::DenseIntElementsAttr replica_groups, /*optional*/::mlir::stablehlo::ChannelHandleAttr channel_handle = {});
stablehlo::CollectivePermuteOp
Creates a new stablehlo.collective_permute
operation.
MlirOp CollectivePermute(MlirOp &operand, ::mlir::DenseIntElementsAttr source_target_pairs, /*optional*/::mlir::stablehlo::ChannelHandleAttr channel_handle = {});
stablehlo::CompareOp
Creates a new stablehlo.compare
operation.
MlirOp Compare(MlirOp &lhs, MlirOp &rhs, ::mlir::stablehlo::ComparisonDirection comparison_direction, /*optional*/::mlir::stablehlo::ComparisonTypeAttr compare_type = {});
stablehlo::ComplexOp
Creates a new stablehlo.complex
operation.
MlirOp Complex(MlirOp &lhs, MlirOp &rhs);
stablehlo::CompositeOp
Creates a new stablehlo.composite
operation.
SmallVector<MlirOp> Composite(MlirBuilder &builder, Type resultType, ArrayRef<MlirOp> inputs, ::llvm::StringRef name, ::llvm::StringRef decomposition, /*optional*/::mlir::DictionaryAttr composite_attributes = {}, /*optional*/uint32_t version = 0);
stablehlo::ConcatenateOp
Creates a new stablehlo.concatenate
operation.
MlirOp Concatenate(MlirBuilder &builder, ArrayRef<MlirOp> inputs, uint64_t dimension);
stablehlo::ConstantOp
Creates a new stablehlo.constant
operation.
MlirOp Constant(MlirBuilder &builder, ::mlir::ElementsAttr value);
stablehlo::ConvertOp
Creates a new stablehlo.convert
operation.
MlirOp Convert(Type resultType, MlirOp &operand);
stablehlo::ConvolutionOp
Creates a new stablehlo.convolution
operation.
MlirOp Convolution(Type resultType, MlirOp &lhs, MlirOp &rhs, ::mlir::stablehlo::ConvDimensionNumbersAttr dimension_numbers, uint64_t feature_group_count, uint64_t batch_group_count, /*optional*/::mlir::DenseI64ArrayAttr window_strides = {}, /*optional*/::mlir::DenseIntElementsAttr padding = {}, /*optional*/::mlir::DenseI64ArrayAttr lhs_dilation = {}, /*optional*/::mlir::DenseI64ArrayAttr rhs_dilation = {}, /*optional*/::mlir::DenseBoolArrayAttr window_reversal = {}, /*optional*/::mlir::ArrayAttr precision_config = {});
stablehlo::CosineOp
Creates a new stablehlo.cosine
operation.
MlirOp Cosine(MlirOp &operand, /*optional*/::mlir::stablehlo::ResultAccuracyAttr result_accuracy = {});
stablehlo::CreateTokenOp
Creates a new stablehlo.create_token
operation.
MlirOp CreateToken(MlirBuilder &builder);
stablehlo::CrossReplicaSumOp
Creates a new stablehlo.cross-replica-sum
operation.
MlirOp CrossReplicaSum(MlirOp &operand, ::mlir::DenseIntElementsAttr replica_groups);
stablehlo::CustomCallOp
Creates a new stablehlo.custom_call
operation.
SmallVector<MlirOp> CustomCall(MlirBuilder &builder, Type resultType, ArrayRef<MlirOp> inputs, ::llvm::StringRef call_target_name, /*optional*/bool has_side_effect = false, /*optional*/::mlir::Attribute backend_config = {}, /*optional*/::mlir::stablehlo::CustomCallApiVersion api_version = ::mlir::stablehlo::CustomCallApiVersion::API_VERSION_ORIGINAL, /*optional*/::mlir::ArrayAttr called_computations = {}, /*optional*/::mlir::ArrayAttr operand_layouts = {}, /*optional*/::mlir::ArrayAttr result_layouts = {}, /*optional*/::mlir::ArrayAttr output_operand_aliases = {});
stablehlo::DivOp
Creates a new stablehlo.divide
operation.
MlirOp Div(MlirOp &lhs, MlirOp &rhs);
stablehlo::DotGeneralOp
Creates a new stablehlo.dot_general
operation.
MlirOp DotGeneral(Type resultType, MlirOp &lhs, MlirOp &rhs, ::mlir::stablehlo::DotDimensionNumbersAttr dot_dimension_numbers, /*optional*/::mlir::ArrayAttr precision_config = {}, /*optional*/::mlir::stablehlo::DotAlgorithmAttr algorithm = {});
stablehlo::DotOp
Creates a new stablehlo.dot
operation.
MlirOp Dot(Type resultType, MlirOp &lhs, MlirOp &rhs, /*optional*/::mlir::ArrayAttr precision_config = {});
stablehlo::DynamicBroadcastInDimOp
Creates a new stablehlo.dynamic_broadcast_in_dim
operation.
MlirOp DynamicBroadcastInDim(Type resultType, MlirOp &operand, MlirOp &output_dimensions, ::llvm::ArrayRef<int64_t> broadcast_dimensions, /*optional*/::mlir::DenseI64ArrayAttr known_expanding_dimensions = {}, /*optional*/::mlir::DenseI64ArrayAttr known_nonexpanding_dimensions = {});
stablehlo::DynamicConvOp
Creates a new stablehlo.dynamic_conv
operation.
MlirOp DynamicConv(Type resultType, MlirOp &lhs, MlirOp &rhs, MlirOp &padding, ::mlir::stablehlo::ConvDimensionNumbersAttr dimension_numbers, uint64_t feature_group_count, uint64_t batch_group_count, /*optional*/::mlir::DenseI64ArrayAttr window_strides = {}, /*optional*/::mlir::DenseI64ArrayAttr lhs_dilation = {}, /*optional*/::mlir::DenseI64ArrayAttr rhs_dilation = {}, /*optional*/::mlir::DenseBoolArrayAttr window_reversal = {}, /*optional*/::mlir::ArrayAttr precision_config = {});
stablehlo::DynamicGatherOp
Creates a new stablehlo.dynamic_gather
operation.
MlirOp DynamicGather(MlirOp &operand, MlirOp &start_indices, MlirOp &slice_sizes, ::mlir::stablehlo::GatherDimensionNumbersAttr dimension_numbers, /*optional*/bool indices_are_sorted = false);
stablehlo::DynamicIotaOp
Creates a new stablehlo.dynamic_iota
operation.
MlirOp DynamicIota(Type resultType, MlirOp &output_shape, uint64_t iota_dimension);
stablehlo::DynamicPadOp
Creates a new stablehlo.dynamic_pad
operation.
MlirOp DynamicPad(Type resultType, MlirOp &operand, MlirOp &padding_value, MlirOp &edge_padding_low, MlirOp &edge_padding_high, MlirOp &interior_padding);
stablehlo::DynamicReshapeOp
Creates a new stablehlo.dynamic_reshape
operation.
MlirOp DynamicReshape(Type resultType, MlirOp &operand, MlirOp &output_shape);
stablehlo::DynamicSliceOp
Creates a new stablehlo.dynamic_slice
operation.
MlirOp DynamicSlice(MlirOp &operand, ArrayRef<MlirOp> start_indices, ::llvm::ArrayRef<int64_t> slice_sizes);
stablehlo::DynamicUpdateSliceOp
Creates a new stablehlo.dynamic_update_slice
operation.
MlirOp DynamicUpdateSlice(MlirOp &operand, MlirOp &update, ArrayRef<MlirOp> start_indices);
stablehlo::EinsumOp
Creates a new stablehlo.einsum
operation.
MlirOp Einsum(Type resultType, MlirOp &lhs, MlirOp &rhs, ::llvm::StringRef einsum_config);
stablehlo::ExpOp
Creates a new stablehlo.exponential
operation.
MlirOp Exp(MlirOp &operand, /*optional*/::mlir::stablehlo::ResultAccuracyAttr result_accuracy = {});
stablehlo::Expm1Op
Creates a new stablehlo.exponential_minus_one
operation.
MlirOp Expm1(MlirOp &operand, /*optional*/::mlir::stablehlo::ResultAccuracyAttr result_accuracy = {});
stablehlo::FftOp
Creates a new stablehlo.fft
operation.
MlirOp Fft(MlirOp &operand, ::mlir::stablehlo::FftType fft_type, ::llvm::ArrayRef<int64_t> fft_length);
stablehlo::FloorOp
Creates a new stablehlo.floor
operation.
MlirOp Floor(MlirOp &operand);
stablehlo::GatherOp
Creates a new stablehlo.gather
operation.
MlirOp Gather(MlirOp &operand, MlirOp &start_indices, ::mlir::stablehlo::GatherDimensionNumbersAttr dimension_numbers, ::llvm::ArrayRef<int64_t> slice_sizes, /*optional*/bool indices_are_sorted = false);
stablehlo::GetDimensionSizeOp
Creates a new stablehlo.get_dimension_size
operation.
MlirOp GetDimensionSize(MlirOp &operand, uint64_t dimension);
stablehlo::GetTupleElementOp
Creates a new stablehlo.get_tuple_element
operation.
MlirOp GetTupleElement(MlirOp &operand, uint32_t index);
stablehlo::IfOp
Creates a new stablehlo.if
operation.
This operation has two body regions (true_branch and false_branch), each built via a callback function.
SmallVector<MlirOp> If(MlirOp &pred, const RegionBuilderCallback &true_branch, const RegionBuilderCallback &false_branch);
stablehlo::ImagOp
Creates a new stablehlo.imag
operation.
MlirOp Imag(MlirOp &operand);
stablehlo::InfeedOp
Creates a new stablehlo.infeed
operation.
SmallVector<MlirOp> Infeed(Type resultType, MlirOp &token, ::llvm::StringRef infeed_config = "", /*optional*/::mlir::ArrayAttr layout = {});
stablehlo::IotaOp
Creates a new stablehlo.iota
operation.
MlirOp Iota(MlirBuilder &builder, Type resultType, uint64_t iota_dimension);
stablehlo::IsFiniteOp
Creates a new stablehlo.is_finite
operation.
MlirOp IsFinite(MlirOp &x);
stablehlo::Log1pOp
Creates a new stablehlo.log_plus_one
operation.
MlirOp Log1p(MlirOp &operand, /*optional*/::mlir::stablehlo::ResultAccuracyAttr result_accuracy = {});
stablehlo::LogOp
Creates a new stablehlo.log
operation.
MlirOp Log(MlirOp &operand, /*optional*/::mlir::stablehlo::ResultAccuracyAttr result_accuracy = {});
stablehlo::LogisticOp
Creates a new stablehlo.logistic
operation.
MlirOp Logistic(MlirOp &operand, /*optional*/::mlir::stablehlo::ResultAccuracyAttr result_accuracy = {});
stablehlo::MapOp
Creates a new stablehlo.map
operation.
This operation has a body region built via a callback function.
MlirOp Map(MlirBuilder &builder, ArrayRef<MlirOp> inputs, const RegionBuilderCallback &computation, ::llvm::ArrayRef<int64_t> dimensions);
stablehlo::MaxOp
Creates a new stablehlo.maximum
operation.
MlirOp Max(MlirOp &lhs, MlirOp &rhs);
stablehlo::MinOp
Creates a new stablehlo.minimum
operation.
MlirOp Min(MlirOp &lhs, MlirOp &rhs);
stablehlo::MulOp
Creates a new stablehlo.multiply
operation.
MlirOp Mul(MlirOp &lhs, MlirOp &rhs);
stablehlo::NegOp
Creates a new stablehlo.negate
operation.
MlirOp Neg(MlirOp &operand);
stablehlo::NotOp
Creates a new stablehlo.not
operation.
MlirOp Not(MlirOp &operand);
stablehlo::OptimizationBarrierOp
Creates a new stablehlo.optimization_barrier
operation.
SmallVector<MlirOp> OptimizationBarrier(MlirBuilder &builder, ArrayRef<MlirOp> operand);
stablehlo::OrOp
Creates a new stablehlo.or
operation.
MlirOp Or(MlirOp &lhs, MlirOp &rhs);
stablehlo::OutfeedOp
Creates a new stablehlo.outfeed
operation.
MlirOp Outfeed(ArrayRef<MlirOp> inputs, MlirOp &token, ::llvm::StringRef outfeed_config = "");
stablehlo::PadOp
Creates a new stablehlo.pad
operation.
MlirOp Pad(MlirOp &operand, MlirOp &padding_value, ::llvm::ArrayRef<int64_t> edge_padding_low, ::llvm::ArrayRef<int64_t> edge_padding_high, ::llvm::ArrayRef<int64_t> interior_padding);
stablehlo::PartitionIdOp
Creates a new stablehlo.partition_id
operation.
MlirOp PartitionId(MlirBuilder &builder);
stablehlo::PopulationCountOp
Creates a new stablehlo.popcnt
operation.
MlirOp PopulationCount(MlirOp &operand);
stablehlo::PowOp
Creates a new stablehlo.power
operation.
MlirOp Pow(MlirOp &lhs, MlirOp &rhs);
stablehlo::RealDynamicSliceOp
Creates a new stablehlo.real_dynamic_slice
operation.
MlirOp RealDynamicSlice(Type resultType, MlirOp &operand, MlirOp &start_indices, MlirOp &limit_indices, MlirOp &strides);
stablehlo::RealOp
Creates a new stablehlo.real
operation.
MlirOp Real(MlirOp &operand);
stablehlo::RecvOp
Creates a new stablehlo.recv
operation.
SmallVector<MlirOp> Recv(Type resultType, MlirOp &token, ::mlir::stablehlo::ChannelHandleAttr channel_handle, /*optional*/bool is_host_transfer = false, /*optional*/::mlir::DenseIntElementsAttr source_target_pairs = {});
stablehlo::ReduceOp
Creates a new stablehlo.reduce
operation.
This operation has a body region built via a callback function.
SmallVector<MlirOp> Reduce(MlirBuilder &builder, ArrayRef<MlirOp> inputs, ArrayRef<MlirOp> init_values, const RegionBuilderCallback &body, ::llvm::ArrayRef<int64_t> dimensions);
stablehlo::ReducePrecisionOp
Creates a new stablehlo.reduce_precision
operation.
MlirOp ReducePrecision(MlirOp &operand, uint32_t exponent_bits, uint32_t mantissa_bits);
stablehlo::ReduceScatterOp
Creates a new stablehlo.reduce_scatter
operation.
This operation has a body region built via a callback function.
MlirOp ReduceScatter(Type resultType, MlirOp &operand, const RegionBuilderCallback &computation, uint64_t scatter_dimension, ::mlir::DenseIntElementsAttr replica_groups, /*optional*/::mlir::stablehlo::ChannelHandleAttr channel_handle = {}, /*optional*/bool use_global_device_ids = false);
stablehlo::ReduceWindowOp
Creates a new stablehlo.reduce_window
operation.
This operation has a body region built via a callback function.
SmallVector<MlirOp> ReduceWindow(MlirBuilder &builder, ArrayRef<MlirOp> inputs, ArrayRef<MlirOp> init_values, const RegionBuilderCallback &body, ::llvm::ArrayRef<int64_t> window_dimensions, /*optional*/::mlir::DenseI64ArrayAttr window_strides = {}, /*optional*/::mlir::DenseI64ArrayAttr base_dilations = {}, /*optional*/::mlir::DenseI64ArrayAttr window_dilations = {}, /*optional*/::mlir::DenseIntElementsAttr padding = {});
stablehlo::RemOp
Creates a new stablehlo.remainder
operation.
MlirOp Rem(MlirOp &lhs, MlirOp &rhs);
stablehlo::ReplicaIdOp
Creates a new stablehlo.replica_id
operation.
MlirOp ReplicaId(MlirBuilder &builder);
stablehlo::ReshapeOp
Creates a new stablehlo.reshape
operation.
MlirOp Reshape(Type resultType, MlirOp &operand);
stablehlo::ReturnOp
Creates a new stablehlo.return
operation.
This operation is a region terminator. It can only be called from within a region-builder callback (RegionBuilderCallback) while constructing the body of an op.
void Return(RegionBuilder &builder, ArrayRef<MlirOp> results);
stablehlo::ReverseOp
Creates a new stablehlo.reverse
operation.
MlirOp Reverse(MlirOp &operand, ::llvm::ArrayRef<int64_t> dimensions);
stablehlo::RngOp
Creates a new stablehlo.rng
operation.
MlirOp Rng(MlirOp &a, MlirOp &b, MlirOp &shape, ::mlir::stablehlo::RngDistribution rng_distribution);
stablehlo::RoundNearestEvenOp
Creates a new stablehlo.round_nearest_even
operation.
MlirOp RoundNearestEven(MlirOp &operand);
stablehlo::RoundOp
Creates a new stablehlo.round_nearest_afz
operation.
MlirOp Round(MlirOp &operand);
stablehlo::RsqrtOp
Creates a new stablehlo.rsqrt
operation.
MlirOp Rsqrt(MlirOp &operand, /*optional*/::mlir::stablehlo::ResultAccuracyAttr result_accuracy = {});
stablehlo::ScatterOp
Creates a new stablehlo.scatter
operation.
This operation has a body region built via a callback function.
SmallVector<MlirOp> Scatter(ArrayRef<MlirOp> inputs, MlirOp &scatter_indices, ArrayRef<MlirOp> updates, const RegionBuilderCallback &update_computation, ::mlir::stablehlo::ScatterDimensionNumbersAttr scatter_dimension_numbers, /*optional*/bool indices_are_sorted = false, /*optional*/bool unique_indices = false);
stablehlo::SelectAndScatterOp
Creates a new stablehlo.select_and_scatter
operation.
This operation has two body regions (select and scatter), each built via a callback function.
MlirOp SelectAndScatter(MlirOp &operand, MlirOp &source, MlirOp &init_value, const RegionBuilderCallback &select, const RegionBuilderCallback &scatter, /*optional*/::mlir::DenseI64ArrayAttr window_dimensions = {}, /*optional*/::mlir::DenseI64ArrayAttr window_strides = {}, /*optional*/::mlir::DenseIntElementsAttr padding = {});
stablehlo::SelectOp
Creates a new stablehlo.select
operation.
MlirOp Select(MlirOp &pred, MlirOp &on_true, MlirOp &on_false);
stablehlo::SendOp
Creates a new stablehlo.send
operation.
MlirOp Send(ArrayRef<MlirOp> inputs, MlirOp &token, ::mlir::stablehlo::ChannelHandleAttr channel_handle, /*optional*/bool is_host_transfer = false, /*optional*/::mlir::DenseIntElementsAttr source_target_pairs = {});
stablehlo::SetDimensionSizeOp
Creates a new stablehlo.set_dimension_size
operation.
MlirOp SetDimensionSize(MlirOp &operand, MlirOp &size, uint64_t dimension);
stablehlo::ShiftLeftOp
Creates a new stablehlo.shift_left
operation.
MlirOp ShiftLeft(MlirOp &lhs, MlirOp &rhs);
stablehlo::ShiftRightArithmeticOp
Creates a new stablehlo.shift_right_arithmetic
operation.
MlirOp ShiftRightArithmetic(MlirOp &lhs, MlirOp &rhs);
stablehlo::ShiftRightLogicalOp
Creates a new stablehlo.shift_right_logical
operation.
MlirOp ShiftRightLogical(MlirOp &lhs, MlirOp &rhs);
stablehlo::SignOp
Creates a new stablehlo.sign
operation.
MlirOp Sign(MlirOp &operand);
stablehlo::SineOp
Creates a new stablehlo.sine
operation.
MlirOp Sine(MlirOp &operand, /*optional*/::mlir::stablehlo::ResultAccuracyAttr result_accuracy = {});
stablehlo::SliceOp
Creates a new stablehlo.slice
operation.
MlirOp Slice(MlirOp &operand, ::llvm::ArrayRef<int64_t> start_indices, ::llvm::ArrayRef<int64_t> limit_indices, ::llvm::ArrayRef<int64_t> strides);
stablehlo::SortOp
Creates a new stablehlo.sort
operation.
This operation has a body region built via a callback function.
SmallVector<MlirOp> Sort(MlirBuilder &builder, ArrayRef<MlirOp> inputs, const RegionBuilderCallback &comparator, /*optional*/uint64_t dimension = -1, /*optional*/bool is_stable = false);
stablehlo::SqrtOp
Creates a new stablehlo.sqrt
operation.
MlirOp Sqrt(MlirOp &operand, /*optional*/::mlir::stablehlo::ResultAccuracyAttr result_accuracy = {});
stablehlo::SubtractOp
Creates a new stablehlo.subtract
operation.
MlirOp Subtract(MlirOp &lhs, MlirOp &rhs);
stablehlo::TanOp
Creates a new stablehlo.tan
operation.
MlirOp Tan(MlirOp &operand, /*optional*/::mlir::stablehlo::ResultAccuracyAttr result_accuracy = {});
stablehlo::TanhOp
Creates a new stablehlo.tanh
operation.
MlirOp Tanh(MlirOp &operand, /*optional*/::mlir::stablehlo::ResultAccuracyAttr result_accuracy = {});
stablehlo::TorchIndexSelectOp
Creates a new stablehlo.torch_index_select
operation.
MlirOp TorchIndexSelect(Type resultType, MlirOp &operand, MlirOp &index, uint64_t dim, uint64_t batch_dims);
stablehlo::TransposeOp
Creates a new stablehlo.transpose
operation.
MlirOp Transpose(MlirOp &operand, ::llvm::ArrayRef<int64_t> permutation);
stablehlo::TupleOp
Creates a new stablehlo.tuple
operation.
MlirOp Tuple(MlirBuilder &builder, ArrayRef<MlirOp> val);
stablehlo::UnaryEinsumOp
Creates a new stablehlo.unary_einsum
operation.
MlirOp UnaryEinsum(Type resultType, MlirOp &operand, ::llvm::StringRef einsum_config);
stablehlo::UniformDequantizeOp
Creates a new stablehlo.uniform_dequantize
operation.
MlirOp UniformDequantize(MlirOp &operand);
stablehlo::UniformQuantizeOp
Creates a new stablehlo.uniform_quantize
operation.
MlirOp UniformQuantize(Type resultType, MlirOp &operand);
stablehlo::WhileOp
Creates a new stablehlo.while
operation.
This operation has two body regions (cond and body), each built via a callback function.
SmallVector<MlirOp> While(MlirBuilder &builder, ArrayRef<MlirOp> operand, const RegionBuilderCallback &cond, const RegionBuilderCallback &body);
stablehlo::XorOp
Creates a new stablehlo.xor
operation.
MlirOp Xor(MlirOp &lhs, MlirOp &rhs);
Skipped Operations
Builders could not be generated for the following operations: