stablehlo MLIR Dialect Builder API

Builder Methods

stablehlo::AbsOp

Creates a new stablehlo.abs operation.

MlirOp Abs(MlirOp &operand);

stablehlo::AddOp

Creates a new stablehlo.add operation.

MlirOp Add(MlirOp &lhs, MlirOp &rhs);

stablehlo::AfterAllOp

Creates a new stablehlo.after_all operation.

MlirOp AfterAll(MlirBuilder &builder, ArrayRef<MlirOp> inputs);

stablehlo::AllGatherOp

Creates a new stablehlo.all_gather operation.

SmallVector<MlirOp> AllGather(MlirBuilder &builder, Type resultType, ArrayRef<MlirOp> operands, uint64_t all_gather_dim, ::mlir::DenseIntElementsAttr replica_groups, /*optional*/::mlir::stablehlo::ChannelHandleAttr channel_handle = {}, /*optional*/bool use_global_device_ids = false);

stablehlo::AllReduceOp

Creates a new stablehlo.all_reduce operation.

This operation has a body region built via a callback function.

SmallVector<MlirOp> AllReduce(MlirBuilder &builder, ArrayRef<MlirOp> operands, const RegionBuilderCallback &computation, ::mlir::DenseIntElementsAttr replica_groups, /*optional*/::mlir::stablehlo::ChannelHandleAttr channel_handle = {}, /*optional*/bool use_global_device_ids = false);

stablehlo::AllToAllOp

Creates a new stablehlo.all_to_all operation.

SmallVector<MlirOp> AllToAll(MlirBuilder &builder, ArrayRef<MlirOp> operands, uint64_t split_dimension, uint64_t concat_dimension, uint64_t split_count, ::mlir::DenseIntElementsAttr replica_groups, /*optional*/::mlir::stablehlo::ChannelHandleAttr channel_handle = {});

stablehlo::AndOp

Creates a new stablehlo.and operation.

MlirOp And(MlirOp &lhs, MlirOp &rhs);

stablehlo::Atan2Op

Creates a new stablehlo.atan2 operation.

MlirOp Atan2(MlirOp &lhs, MlirOp &rhs);

stablehlo::BatchNormGradOp

Creates a new stablehlo.batch_norm_grad operation.

SmallVector<MlirOp, 3> BatchNormGrad(MlirOp &operand, MlirOp &scale, MlirOp &mean, MlirOp &variance, MlirOp &grad_output, ::llvm::APFloat epsilon, uint64_t feature_index);

stablehlo::BatchNormInferenceOp

Creates a new stablehlo.batch_norm_inference operation.

MlirOp BatchNormInference(MlirOp &operand, MlirOp &scale, MlirOp &offset, MlirOp &mean, MlirOp &variance, ::llvm::APFloat epsilon, uint64_t feature_index);

stablehlo::BatchNormTrainingOp

Creates a new stablehlo.batch_norm_training operation.

SmallVector<MlirOp, 3> BatchNormTraining(MlirOp &operand, MlirOp &scale, MlirOp &offset, ::llvm::APFloat epsilon, uint64_t feature_index);

stablehlo::BitcastConvertOp

Creates a new stablehlo.bitcast_convert operation.

MlirOp BitcastConvert(Type resultType, MlirOp &operand);

stablehlo::BroadcastInDimOp

Creates a new stablehlo.broadcast_in_dim operation.

MlirOp BroadcastInDim(Type resultType, MlirOp &operand, ::llvm::ArrayRef<int64_t> broadcast_dimensions);

stablehlo::BroadcastOp

Creates a new stablehlo.broadcast operation.

MlirOp Broadcast(MlirOp &operand, ::llvm::ArrayRef<int64_t> broadcast_sizes);

stablehlo::CbrtOp

Creates a new stablehlo.cbrt operation.

MlirOp Cbrt(MlirOp &operand, /*optional*/::mlir::stablehlo::ResultAccuracyAttr result_accuracy = {});

stablehlo::CeilOp

Creates a new stablehlo.ceil operation.

MlirOp Ceil(MlirOp &operand);

stablehlo::CholeskyOp

Creates a new stablehlo.cholesky operation.

MlirOp Cholesky(MlirOp &a, /*optional*/bool lower = false);

stablehlo::ClampOp

Creates a new stablehlo.clamp operation.

MlirOp Clamp(MlirOp &min, MlirOp &operand, MlirOp &max);

stablehlo::ClzOp

Creates a new stablehlo.count_leading_zeros operation.

MlirOp Clz(MlirOp &operand);

stablehlo::CollectiveBroadcastOp

Creates a new stablehlo.collective_broadcast operation.

MlirOp CollectiveBroadcast(MlirOp &operand, ::mlir::DenseIntElementsAttr replica_groups, /*optional*/::mlir::stablehlo::ChannelHandleAttr channel_handle = {});

stablehlo::CollectivePermuteOp

Creates a new stablehlo.collective_permute operation.

MlirOp CollectivePermute(MlirOp &operand, ::mlir::DenseIntElementsAttr source_target_pairs, /*optional*/::mlir::stablehlo::ChannelHandleAttr channel_handle = {});

stablehlo::CompareOp

Creates a new stablehlo.compare operation.

MlirOp Compare(MlirOp &lhs, MlirOp &rhs, ::mlir::stablehlo::ComparisonDirection comparison_direction, /*optional*/::mlir::stablehlo::ComparisonTypeAttr compare_type = {});

stablehlo::ComplexOp

Creates a new stablehlo.complex operation.

MlirOp Complex(MlirOp &lhs, MlirOp &rhs);

stablehlo::CompositeOp

Creates a new stablehlo.composite operation.

SmallVector<MlirOp> Composite(MlirBuilder &builder, Type resultType, ArrayRef<MlirOp> inputs, ::llvm::StringRef name, ::llvm::StringRef decomposition, /*optional*/::mlir::DictionaryAttr composite_attributes = {}, /*optional*/uint32_t version = 0);

stablehlo::ConcatenateOp

Creates a new stablehlo.concatenate operation.

MlirOp Concatenate(MlirBuilder &builder, ArrayRef<MlirOp> inputs, uint64_t dimension);

stablehlo::ConstantOp

Creates a new stablehlo.constant operation.

MlirOp Constant(MlirBuilder &builder, ::mlir::ElementsAttr value);

stablehlo::ConvertOp

Creates a new stablehlo.convert operation.

MlirOp Convert(Type resultType, MlirOp &operand);

stablehlo::ConvolutionOp

Creates a new stablehlo.convolution operation.

MlirOp Convolution(Type resultType, MlirOp &lhs, MlirOp &rhs, ::mlir::stablehlo::ConvDimensionNumbersAttr dimension_numbers, uint64_t feature_group_count, uint64_t batch_group_count, /*optional*/::mlir::DenseI64ArrayAttr window_strides = {}, /*optional*/::mlir::DenseIntElementsAttr padding = {}, /*optional*/::mlir::DenseI64ArrayAttr lhs_dilation = {}, /*optional*/::mlir::DenseI64ArrayAttr rhs_dilation = {}, /*optional*/::mlir::DenseBoolArrayAttr window_reversal = {}, /*optional*/::mlir::ArrayAttr precision_config = {});

stablehlo::CosineOp

Creates a new stablehlo.cosine operation.

MlirOp Cosine(MlirOp &operand, /*optional*/::mlir::stablehlo::ResultAccuracyAttr result_accuracy = {});

stablehlo::CreateTokenOp

Creates a new stablehlo.create_token operation.

MlirOp CreateToken(MlirBuilder &builder);

stablehlo::CrossReplicaSumOp

Creates a new stablehlo.cross-replica-sum operation.

MlirOp CrossReplicaSum(MlirOp &operand, ::mlir::DenseIntElementsAttr replica_groups);

stablehlo::CustomCallOp

Creates a new stablehlo.custom_call operation.

SmallVector<MlirOp> CustomCall(MlirBuilder &builder, Type resultType, ArrayRef<MlirOp> inputs, ::llvm::StringRef call_target_name, /*optional*/bool has_side_effect = false, /*optional*/::mlir::Attribute backend_config = {}, /*optional*/::mlir::stablehlo::CustomCallApiVersion api_version = ::mlir::stablehlo::CustomCallApiVersion::API_VERSION_ORIGINAL, /*optional*/::mlir::ArrayAttr called_computations = {}, /*optional*/::mlir::ArrayAttr operand_layouts = {}, /*optional*/::mlir::ArrayAttr result_layouts = {}, /*optional*/::mlir::ArrayAttr output_operand_aliases = {});

stablehlo::DivOp

Creates a new stablehlo.divide operation.

MlirOp Div(MlirOp &lhs, MlirOp &rhs);

stablehlo::DotGeneralOp

Creates a new stablehlo.dot_general operation.

MlirOp DotGeneral(Type resultType, MlirOp &lhs, MlirOp &rhs, ::mlir::stablehlo::DotDimensionNumbersAttr dot_dimension_numbers, /*optional*/::mlir::ArrayAttr precision_config = {}, /*optional*/::mlir::stablehlo::DotAlgorithmAttr algorithm = {});

stablehlo::DotOp

Creates a new stablehlo.dot operation.

MlirOp Dot(Type resultType, MlirOp &lhs, MlirOp &rhs, /*optional*/::mlir::ArrayAttr precision_config = {});

stablehlo::DynamicBroadcastInDimOp

Creates a new stablehlo.dynamic_broadcast_in_dim operation.

MlirOp DynamicBroadcastInDim(Type resultType, MlirOp &operand, MlirOp &output_dimensions, ::llvm::ArrayRef<int64_t> broadcast_dimensions, /*optional*/::mlir::DenseI64ArrayAttr known_expanding_dimensions = {}, /*optional*/::mlir::DenseI64ArrayAttr known_nonexpanding_dimensions = {});

stablehlo::DynamicConvOp

Creates a new stablehlo.dynamic_conv operation.

MlirOp DynamicConv(Type resultType, MlirOp &lhs, MlirOp &rhs, MlirOp &padding, ::mlir::stablehlo::ConvDimensionNumbersAttr dimension_numbers, uint64_t feature_group_count, uint64_t batch_group_count, /*optional*/::mlir::DenseI64ArrayAttr window_strides = {}, /*optional*/::mlir::DenseI64ArrayAttr lhs_dilation = {}, /*optional*/::mlir::DenseI64ArrayAttr rhs_dilation = {}, /*optional*/::mlir::DenseBoolArrayAttr window_reversal = {}, /*optional*/::mlir::ArrayAttr precision_config = {});

stablehlo::DynamicGatherOp

Creates a new stablehlo.dynamic_gather operation.

MlirOp DynamicGather(MlirOp &operand, MlirOp &start_indices, MlirOp &slice_sizes, ::mlir::stablehlo::GatherDimensionNumbersAttr dimension_numbers, /*optional*/bool indices_are_sorted = false);

stablehlo::DynamicIotaOp

Creates a new stablehlo.dynamic_iota operation.

MlirOp DynamicIota(Type resultType, MlirOp &output_shape, uint64_t iota_dimension);

stablehlo::DynamicPadOp

Creates a new stablehlo.dynamic_pad operation.

MlirOp DynamicPad(Type resultType, MlirOp &operand, MlirOp &padding_value, MlirOp &edge_padding_low, MlirOp &edge_padding_high, MlirOp &interior_padding);

stablehlo::DynamicReshapeOp

Creates a new stablehlo.dynamic_reshape operation.

MlirOp DynamicReshape(Type resultType, MlirOp &operand, MlirOp &output_shape);

stablehlo::DynamicSliceOp

Creates a new stablehlo.dynamic_slice operation.

MlirOp DynamicSlice(MlirOp &operand, ArrayRef<MlirOp> start_indices, ::llvm::ArrayRef<int64_t> slice_sizes);

stablehlo::DynamicUpdateSliceOp

Creates a new stablehlo.dynamic_update_slice operation.

MlirOp DynamicUpdateSlice(MlirOp &operand, MlirOp &update, ArrayRef<MlirOp> start_indices);

stablehlo::EinsumOp

Creates a new stablehlo.einsum operation.

MlirOp Einsum(Type resultType, MlirOp &lhs, MlirOp &rhs, ::llvm::StringRef einsum_config);

stablehlo::ExpOp

Creates a new stablehlo.exponential operation.

MlirOp Exp(MlirOp &operand, /*optional*/::mlir::stablehlo::ResultAccuracyAttr result_accuracy = {});

stablehlo::Expm1Op

Creates a new stablehlo.exponential_minus_one operation.

MlirOp Expm1(MlirOp &operand, /*optional*/::mlir::stablehlo::ResultAccuracyAttr result_accuracy = {});

stablehlo::FftOp

Creates a new stablehlo.fft operation.

MlirOp Fft(MlirOp &operand, ::mlir::stablehlo::FftType fft_type, ::llvm::ArrayRef<int64_t> fft_length);

stablehlo::FloorOp

Creates a new stablehlo.floor operation.

MlirOp Floor(MlirOp &operand);

stablehlo::GatherOp

Creates a new stablehlo.gather operation.

MlirOp Gather(MlirOp &operand, MlirOp &start_indices, ::mlir::stablehlo::GatherDimensionNumbersAttr dimension_numbers, ::llvm::ArrayRef<int64_t> slice_sizes, /*optional*/bool indices_are_sorted = false);

stablehlo::GetDimensionSizeOp

Creates a new stablehlo.get_dimension_size operation.

MlirOp GetDimensionSize(MlirOp &operand, uint64_t dimension);

stablehlo::GetTupleElementOp

Creates a new stablehlo.get_tuple_element operation.

MlirOp GetTupleElement(MlirOp &operand, uint32_t index);

stablehlo::IfOp

Creates a new stablehlo.if operation.

This operation has two body regions (`true_branch` and `false_branch`), each built via its own callback function.

SmallVector<MlirOp> If(MlirOp &pred, const RegionBuilderCallback &true_branch, const RegionBuilderCallback &false_branch);

stablehlo::ImagOp

Creates a new stablehlo.imag operation.

MlirOp Imag(MlirOp &operand);

stablehlo::InfeedOp

Creates a new stablehlo.infeed operation.

SmallVector<MlirOp> Infeed(Type resultType, MlirOp &token, ::llvm::StringRef infeed_config = "", /*optional*/::mlir::ArrayAttr layout = {});

stablehlo::IotaOp

Creates a new stablehlo.iota operation.

MlirOp Iota(MlirBuilder &builder, Type resultType, uint64_t iota_dimension);

stablehlo::IsFiniteOp

Creates a new stablehlo.is_finite operation.

MlirOp IsFinite(MlirOp &x);

stablehlo::Log1pOp

Creates a new stablehlo.log_plus_one operation.

MlirOp Log1p(MlirOp &operand, /*optional*/::mlir::stablehlo::ResultAccuracyAttr result_accuracy = {});

stablehlo::LogOp

Creates a new stablehlo.log operation.

MlirOp Log(MlirOp &operand, /*optional*/::mlir::stablehlo::ResultAccuracyAttr result_accuracy = {});

stablehlo::LogisticOp

Creates a new stablehlo.logistic operation.

MlirOp Logistic(MlirOp &operand, /*optional*/::mlir::stablehlo::ResultAccuracyAttr result_accuracy = {});

stablehlo::MapOp

Creates a new stablehlo.map operation.

This operation has a body region built via a callback function.

MlirOp Map(MlirBuilder &builder, ArrayRef<MlirOp> inputs, const RegionBuilderCallback &computation, ::llvm::ArrayRef<int64_t> dimensions);

stablehlo::MaxOp

Creates a new stablehlo.maximum operation.

MlirOp Max(MlirOp &lhs, MlirOp &rhs);

stablehlo::MinOp

Creates a new stablehlo.minimum operation.

MlirOp Min(MlirOp &lhs, MlirOp &rhs);

stablehlo::MulOp

Creates a new stablehlo.multiply operation.

MlirOp Mul(MlirOp &lhs, MlirOp &rhs);

stablehlo::NegOp

Creates a new stablehlo.negate operation.

MlirOp Neg(MlirOp &operand);

stablehlo::NotOp

Creates a new stablehlo.not operation.

MlirOp Not(MlirOp &operand);

stablehlo::OptimizationBarrierOp

Creates a new stablehlo.optimization_barrier operation.

SmallVector<MlirOp> OptimizationBarrier(MlirBuilder &builder, ArrayRef<MlirOp> operand);

stablehlo::OrOp

Creates a new stablehlo.or operation.

MlirOp Or(MlirOp &lhs, MlirOp &rhs);

stablehlo::OutfeedOp

Creates a new stablehlo.outfeed operation.

MlirOp Outfeed(ArrayRef<MlirOp> inputs, MlirOp &token, ::llvm::StringRef outfeed_config = "");

stablehlo::PadOp

Creates a new stablehlo.pad operation.

MlirOp Pad(MlirOp &operand, MlirOp &padding_value, ::llvm::ArrayRef<int64_t> edge_padding_low, ::llvm::ArrayRef<int64_t> edge_padding_high, ::llvm::ArrayRef<int64_t> interior_padding);

stablehlo::PartitionIdOp

Creates a new stablehlo.partition_id operation.

MlirOp PartitionId(MlirBuilder &builder);

stablehlo::PopulationCountOp

Creates a new stablehlo.popcnt operation.

MlirOp PopulationCount(MlirOp &operand);

stablehlo::PowOp

Creates a new stablehlo.power operation.

MlirOp Pow(MlirOp &lhs, MlirOp &rhs);

stablehlo::RealDynamicSliceOp

Creates a new stablehlo.real_dynamic_slice operation.

MlirOp RealDynamicSlice(Type resultType, MlirOp &operand, MlirOp &start_indices, MlirOp &limit_indices, MlirOp &strides);

stablehlo::RealOp

Creates a new stablehlo.real operation.

MlirOp Real(MlirOp &operand);

stablehlo::RecvOp

Creates a new stablehlo.recv operation.

SmallVector<MlirOp> Recv(Type resultType, MlirOp &token, ::mlir::stablehlo::ChannelHandleAttr channel_handle, /*optional*/bool is_host_transfer = false, /*optional*/::mlir::DenseIntElementsAttr source_target_pairs = {});

stablehlo::ReduceOp

Creates a new stablehlo.reduce operation.

This operation has a body region built via a callback function.

SmallVector<MlirOp> Reduce(MlirBuilder &builder, ArrayRef<MlirOp> inputs, ArrayRef<MlirOp> init_values, const RegionBuilderCallback &body, ::llvm::ArrayRef<int64_t> dimensions);

stablehlo::ReducePrecisionOp

Creates a new stablehlo.reduce_precision operation.

MlirOp ReducePrecision(MlirOp &operand, uint32_t exponent_bits, uint32_t mantissa_bits);

stablehlo::ReduceScatterOp

Creates a new stablehlo.reduce_scatter operation.

This operation has a body region built via a callback function.

MlirOp ReduceScatter(Type resultType, MlirOp &operand, const RegionBuilderCallback &computation, uint64_t scatter_dimension, ::mlir::DenseIntElementsAttr replica_groups, /*optional*/::mlir::stablehlo::ChannelHandleAttr channel_handle = {}, /*optional*/bool use_global_device_ids = false);

stablehlo::ReduceWindowOp

Creates a new stablehlo.reduce_window operation.

This operation has a body region built via a callback function.

SmallVector<MlirOp> ReduceWindow(MlirBuilder &builder, ArrayRef<MlirOp> inputs, ArrayRef<MlirOp> init_values, const RegionBuilderCallback &body, ::llvm::ArrayRef<int64_t> window_dimensions, /*optional*/::mlir::DenseI64ArrayAttr window_strides = {}, /*optional*/::mlir::DenseI64ArrayAttr base_dilations = {}, /*optional*/::mlir::DenseI64ArrayAttr window_dilations = {}, /*optional*/::mlir::DenseIntElementsAttr padding = {});

stablehlo::RemOp

Creates a new stablehlo.remainder operation.

MlirOp Rem(MlirOp &lhs, MlirOp &rhs);

stablehlo::ReplicaIdOp

Creates a new stablehlo.replica_id operation.

MlirOp ReplicaId(MlirBuilder &builder);

stablehlo::ReshapeOp

Creates a new stablehlo.reshape operation.

MlirOp Reshape(Type resultType, MlirOp &operand);

stablehlo::ReturnOp

Creates a new stablehlo.return operation.

This operation is a region terminator. It can only be called from inside a RegionBuilder callback while constructing the body region of another op.

void Return(RegionBuilder &builder, ArrayRef<MlirOp> results);

stablehlo::ReverseOp

Creates a new stablehlo.reverse operation.

MlirOp Reverse(MlirOp &operand, ::llvm::ArrayRef<int64_t> dimensions);

stablehlo::RngOp

Creates a new stablehlo.rng operation.

MlirOp Rng(MlirOp &a, MlirOp &b, MlirOp &shape, ::mlir::stablehlo::RngDistribution rng_distribution);

stablehlo::RoundNearestEvenOp

Creates a new stablehlo.round_nearest_even operation.

MlirOp RoundNearestEven(MlirOp &operand);

stablehlo::RoundOp

Creates a new stablehlo.round_nearest_afz operation.

MlirOp Round(MlirOp &operand);

stablehlo::RsqrtOp

Creates a new stablehlo.rsqrt operation.

MlirOp Rsqrt(MlirOp &operand, /*optional*/::mlir::stablehlo::ResultAccuracyAttr result_accuracy = {});

stablehlo::ScatterOp

Creates a new stablehlo.scatter operation.

This operation has a body region built via a callback function.

SmallVector<MlirOp> Scatter(ArrayRef<MlirOp> inputs, MlirOp &scatter_indices, ArrayRef<MlirOp> updates, const RegionBuilderCallback &update_computation, ::mlir::stablehlo::ScatterDimensionNumbersAttr scatter_dimension_numbers, /*optional*/bool indices_are_sorted = false, /*optional*/bool unique_indices = false);

stablehlo::SelectAndScatterOp

Creates a new stablehlo.select_and_scatter operation.

This operation has two body regions (`select` and `scatter`), each built via its own callback function.

MlirOp SelectAndScatter(MlirOp &operand, MlirOp &source, MlirOp &init_value, const RegionBuilderCallback &select, const RegionBuilderCallback &scatter, /*optional*/::mlir::DenseI64ArrayAttr window_dimensions = {}, /*optional*/::mlir::DenseI64ArrayAttr window_strides = {}, /*optional*/::mlir::DenseIntElementsAttr padding = {});

stablehlo::SelectOp

Creates a new stablehlo.select operation.

MlirOp Select(MlirOp &pred, MlirOp &on_true, MlirOp &on_false);

stablehlo::SendOp

Creates a new stablehlo.send operation.

MlirOp Send(ArrayRef<MlirOp> inputs, MlirOp &token, ::mlir::stablehlo::ChannelHandleAttr channel_handle, /*optional*/bool is_host_transfer = false, /*optional*/::mlir::DenseIntElementsAttr source_target_pairs = {});

stablehlo::SetDimensionSizeOp

Creates a new stablehlo.set_dimension_size operation.

MlirOp SetDimensionSize(MlirOp &operand, MlirOp &size, uint64_t dimension);

stablehlo::ShiftLeftOp

Creates a new stablehlo.shift_left operation.

MlirOp ShiftLeft(MlirOp &lhs, MlirOp &rhs);

stablehlo::ShiftRightArithmeticOp

Creates a new stablehlo.shift_right_arithmetic operation.

MlirOp ShiftRightArithmetic(MlirOp &lhs, MlirOp &rhs);

stablehlo::ShiftRightLogicalOp

Creates a new stablehlo.shift_right_logical operation.

MlirOp ShiftRightLogical(MlirOp &lhs, MlirOp &rhs);

stablehlo::SignOp

Creates a new stablehlo.sign operation.

MlirOp Sign(MlirOp &operand);

stablehlo::SineOp

Creates a new stablehlo.sine operation.

MlirOp Sine(MlirOp &operand, /*optional*/::mlir::stablehlo::ResultAccuracyAttr result_accuracy = {});

stablehlo::SliceOp

Creates a new stablehlo.slice operation.

MlirOp Slice(MlirOp &operand, ::llvm::ArrayRef<int64_t> start_indices, ::llvm::ArrayRef<int64_t> limit_indices, ::llvm::ArrayRef<int64_t> strides);

stablehlo::SortOp

Creates a new stablehlo.sort operation.

This operation has a body region built via a callback function.

SmallVector<MlirOp> Sort(MlirBuilder &builder, ArrayRef<MlirOp> inputs, const RegionBuilderCallback &comparator, /*optional*/uint64_t dimension = -1, /*optional*/bool is_stable = false);

stablehlo::SqrtOp

Creates a new stablehlo.sqrt operation.

MlirOp Sqrt(MlirOp &operand, /*optional*/::mlir::stablehlo::ResultAccuracyAttr result_accuracy = {});

stablehlo::SubtractOp

Creates a new stablehlo.subtract operation.

MlirOp Subtract(MlirOp &lhs, MlirOp &rhs);

stablehlo::TanOp

Creates a new stablehlo.tan operation.

MlirOp Tan(MlirOp &operand, /*optional*/::mlir::stablehlo::ResultAccuracyAttr result_accuracy = {});

stablehlo::TanhOp

Creates a new stablehlo.tanh operation.

MlirOp Tanh(MlirOp &operand, /*optional*/::mlir::stablehlo::ResultAccuracyAttr result_accuracy = {});

stablehlo::TorchIndexSelectOp

Creates a new stablehlo.torch_index_select operation.

MlirOp TorchIndexSelect(Type resultType, MlirOp &operand, MlirOp &index, uint64_t dim, uint64_t batch_dims);

stablehlo::TransposeOp

Creates a new stablehlo.transpose operation.

MlirOp Transpose(MlirOp &operand, ::llvm::ArrayRef<int64_t> permutation);

stablehlo::TupleOp

Creates a new stablehlo.tuple operation.

MlirOp Tuple(MlirBuilder &builder, ArrayRef<MlirOp> val);

stablehlo::UnaryEinsumOp

Creates a new stablehlo.unary_einsum operation.

MlirOp UnaryEinsum(Type resultType, MlirOp &operand, ::llvm::StringRef einsum_config);

stablehlo::UniformDequantizeOp

Creates a new stablehlo.uniform_dequantize operation.

MlirOp UniformDequantize(MlirOp &operand);

stablehlo::UniformQuantizeOp

Creates a new stablehlo.uniform_quantize operation.

MlirOp UniformQuantize(Type resultType, MlirOp &operand);

stablehlo::WhileOp

Creates a new stablehlo.while operation.

This operation has two body regions (`cond` and `body`), each built via its own callback function.

SmallVector<MlirOp> While(MlirBuilder &builder, ArrayRef<MlirOp> operand, const RegionBuilderCallback &cond, const RegionBuilderCallback &body);

stablehlo::XorOp

Creates a new stablehlo.xor operation.

MlirOp Xor(MlirOp &lhs, MlirOp &rhs);

Skipped Operations

Unable to generate builders for the following operations: