Ngôn ngữ "sdy"

Ngôn ngữ Shardy (SDY) xác định một cách trình bày phân đoạn tensor dựa trên trục và các thành phần API bổ sung để đính kèm các phân đoạn vào tensor.

Hoạt động tính toán

sdy.all_gather (sdy::AllGatherOp)

Thực hiện giao tiếp tất cả cùng nhau dọc theo các trục

Cú pháp:

operation ::= `sdy.all_gather` $gathering_axes $tensor `out_sharding````=```$out_sharding attr-dict `:` type($result)

Thu thập các phần của một tensor dọc theo các trục được chỉ định trong gathering_axes.

gathering_axes là danh sách các danh sách trục. Danh sách bên ngoài vượt quá kích thước của tensor. Mỗi danh sách bên trong chỉ định các trục theo đó sẽ thực hiện một hoạt động thu thập riêng biệt trên phương diện tương ứng. Phương thức này sẽ được áp dụng cho việc phân đoạn toán hạng (tensor) để lấy kết quả phân đoạn (out_sharding).

Xin lưu ý rằng out_sharding không được dùng để xác định việc phân đoạn kết quả. Thay vào đó, việc phân đoạn kết quả được xác định bằng cách phân đoạn toán hạng và gathering_axes, đồng thời out_sharding phải khớp với phân đoạn suy luận này.

Ví dụ:

%1 = stablehlo.tanh(%0) {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"a", "b", "c"}, {}, {"d"}\]>]>} : tensor<8x8x8xf32>
%2 = sdy.all_gather [{"b", "c"}, {}, {"d"}\] %1 out_sharding=<@mesh, [{"a"}, {}, {}\]> : tensor<8x8x8xf32>

Các quy tắc ràng buộc:

  • Phải đáp ứng các điều kiện ràng buộc được liệt kê trong Sdy_CollectiveOpInterface.
  • Các phần tử trong gathering_axes phải đáp ứng các quy tắc ràng buộc được liệt kê trong AxisRefListAttr.
  • Việc áp dụng gathering_axes cho hoạt động phân đoạn toán hạng sẽ nhận được out_sharding.

Đặc điểm: SameOperandsAndResultType

Giao diện: InferTypeOpInterface, Sdy_CollectiveOpInterface

Thuộc tính:

Thuộc tínhLoại MLIRMô tả
gathering_axes::mlir::sdy::ListOfAxisRefListsAttrDanh sách danh sách tham chiếu trục
out_sharding::mlir::sdy::TensorShardingAttrPhân đoạn tensor

Toán hạng:

Toán hạng Mô tả
tensor tensor của bất kỳ giá trị loại nào

Kết quả:

Kết quả Mô tả
result tensor của bất kỳ giá trị loại nào

sdy.all_reduce (sdy::AllReduceOp)

Thực hiện giao tiếp giảm tất cả theo các trục

Cú pháp:

operation ::= `sdy.all_reduce` $reduction_axes $tensor `out_sharding````=```$out_sharding attr-dict `:` type($result)

Giảm các phần của một tensor dọc theo các trục được chỉ định trong reduction_axes. Thứ tự của reduction_axes không quan trọng đối với kết quả, nhưng có thể ảnh hưởng đến thứ tự của các nhóm bản sao tương ứng.

Các quy tắc ràng buộc:

  • Phải đáp ứng các điều kiện ràng buộc được liệt kê trong Sdy_CollectiveOpInterface.
  • reduction_axes phải đáp ứng các quy tắc ràng buộc được liệt kê trong AxisRefListAttr;
  • reduction_axes không được chồng chéo với các trục phân đoạn toán hạng;

Đặc điểm: SameOperandsAndResultType

Giao diện: CollectiveOpInterface, InferTypeOpInterface

Thuộc tính:

Thuộc tínhLoại MLIRMô tả
reduction_axes::mlir::sdy::AxisRefListAttrDanh sách tệp tham chiếu trục
out_sharding::mlir::sdy::TensorShardingAttrPhân đoạn tensor

Toán hạng:

Toán hạng Mô tả
tensor tensor của bất kỳ giá trị loại nào

Kết quả:

Kết quả Mô tả
result tensor của bất kỳ giá trị loại nào

sdy.all_slice (sdy::AllSliceOp)

Thực hiện thao tác cắt động dọc theo các trục

Cú pháp:

operation ::= `sdy.all_slice` $slicing_axes $tensor `out_sharding````=```$out_sharding attr-dict `:` type($result)

Cắt các phần của một tensor dọc theo các trục được chỉ định trong slicing_axes. Có một tính chất song đôi đại số giữa sdy.all_slicesdy.all_gather.

slicing_axes là danh sách các danh sách trục. Danh sách bên ngoài vượt quá kích thước của tensor. Mỗi danh sách bên trong chỉ định các trục dọc theo đó một lát cắt sẽ được thực hiện trên phương diện tương ứng. Phương thức này sẽ được áp dụng cho việc phân đoạn toán hạng (tensor) để lấy kết quả phân đoạn (out_sharding).

Xin lưu ý rằng out_sharding không được dùng để xác định việc phân đoạn kết quả. Thay vào đó, việc phân đoạn kết quả được xác định bằng cách phân đoạn toán hạng và slicing_axes, đồng thời out_sharding phải khớp với phân đoạn suy luận này.

Ví dụ:

%1 = stablehlo.tanh(%0) {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"a"}, {}, {}\]>]>} : tensor<8x8x8xf32>
%2 = sdy.all_slice [{"b", "c"}, {}, {"d"}\] %1 out_sharding=<@mesh, [{"a", "b", "c"}, {}, {"d"}\]> : tensor<8x8x8xf32>

Các quy tắc ràng buộc:

  • Các phần tử trong slicing_axes phải đáp ứng các quy tắc ràng buộc được liệt kê trong AxisRefListAttr.
  • Phải đáp ứng các điều kiện ràng buộc được liệt kê trong Sdy_CollectiveOpInterface.
  • Việc áp dụng slicing_axes cho hoạt động phân đoạn toán hạng sẽ nhận được out_sharding.

Đặc điểm: SameOperandsAndResultType

Giao diện: CollectiveOpInterface, InferTypeOpInterface

Thuộc tính:

Thuộc tínhLoại MLIRMô tả
slicing_axes::mlir::sdy::ListOfAxisRefListsAttrDanh sách danh sách tham chiếu trục
out_sharding::mlir::sdy::TensorShardingAttrPhân đoạn tensor

Toán hạng:

Toán hạng Mô tả
tensor tensor của bất kỳ giá trị loại nào

Kết quả:

Kết quả Mô tả
result tensor của bất kỳ giá trị loại nào

sdy.all_to_all (sdy::AllToAllOp)

Thực hiện giao tiếp tất cả với tất cả dọc theo các trục

Cú pháp:

operation ::= `sdy.all_to_all` $params $tensor `out_sharding````=```$out_sharding attr-dict `:` type($result)

Đối với mỗi bộ dữ liệu (axes, src_dim, tgt_dim) trong danh sách tham số, thao tác này sẽ cắt các phần của một tensor theo phương diện tgt_dim và các trục được chỉ định trong axes, phân tán các phần đó theo các trục và nối các phần đó theo phương diện src_dim.

Về cơ bản, toán tử này là sự kết hợp của một toán tử thu thập tất cả theo src_dimaxes, theo sau là một toán tử cắt tất cả theo tgt_dimaxes, tức là hậu tố của phương diện phân đoạn trục src_dim trên tensor đầu vào được thêm vào phương diện phân đoạn trục tgt_dim trên tensor đầu ra.

Phương thức tất cả với tất cả sẽ được áp dụng cho việc phân đoạn toán hạng (tensor) để lấy kết quả phân đoạn (out_sharding).

Xin lưu ý rằng out_sharding không được dùng để xác định việc phân đoạn kết quả. Thay vào đó, việc phân đoạn kết quả được xác định bằng cách phân đoạn toán hạng, src_dim, tgt_dimaxes, đồng thời out_sharding phải khớp với phân đoạn được suy luận này.

Ví dụ:

%1 = stablehlo.tanh(%0) {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"a", "b"}, {"c"}, {}, {}\]>]>} : tensor<8x8x4x4x32>
%2 = sdy.all_to_all [{"b"}: 0->2, {"c"}: 1->3] %1 out_sharding=<@mesh, [{"a"}, {}, {"b"}, {"c"}\]> : tensor<8x8x4x4x32>

Các quy tắc ràng buộc:

  • Phải đáp ứng các điều kiện ràng buộc được liệt kê trong Sdy_CollectiveOpInterface.
  • Danh sách tham số không được để trống.
  • Đối với mỗi tham số trong params:
    • Các phần tử trong axes phải đáp ứng các quy tắc ràng buộc của AxisRefAttr.
    • src_dimtgt_dim phải là các phương diện hợp lệ (không âm và nhỏ hơn thứ hạng của tensor).
    • Mọi src_dim hoặc tgt_dim phải là duy nhất trên tất cả các tham số.
    • src_dim phải được sắp xếp theo thứ tự tăng dần trên tất cả các tham số.
  • Việc di chuyển axes từ src_dim sang tgt_dim trong quá trình phân đoạn toán hạng sẽ nhận được out_sharding.

Đặc điểm: SameOperandsAndResultType

Giao diện: InferTypeOpInterface, Sdy_CollectiveOpInterface

Thuộc tính:

Thuộc tínhLoại MLIRMô tả
params::mlir::sdy::AlltoAllParamListAttrDanh sách tham số tất cả với tất cả
out_sharding::mlir::sdy::TensorShardingAttrPhân đoạn tensor

Toán hạng:

Toán hạng Mô tả
tensor tensor của bất kỳ giá trị loại nào

Kết quả:

Kết quả Mô tả
result tensor của bất kỳ giá trị loại nào

sdy.collective_permute (sdy::CollectivePermuteOp)

Thực hiện giao tiếp hoán vị tập thể để thay thế các trục

Cú pháp:

operation ::= `sdy.collective_permute` $tensor `out_sharding````=```$out_sharding attr-dict `:` type($result)

Gửi một phần của tensor đầu vào từ mỗi thiết bị sang thiết bị khác để sắp xếp lại/thay thế các trục phân đoạn tensor.

Một phép hoán vị tập hợp có thể biến đổi hoạt động phân đoạn đầu vào sao cho mỗi kích thước phải được phân đoạn như trước, tức là phải được phân đoạn dọc theo các trục có tích kích thước khớp với tích kích thước của các trục đã phân đoạn tensor trước đó.

Điều này hữu ích khi sắp xếp lại các trục trong một phương diện hoặc trên nhiều phương diện, cũng như hoán đổi các trục được phân đoạn với các trục được sao chép.

Trong ví dụ bên dưới, kích thước tensor phân đoạn là tensor<1x4x2xf32> và kích thước này được giữ nguyên bằng cách hoán vị tập thể.

Ví dụ:

sdy.mesh @mesh = <["a"=2, "b"=2, "c"=4, "d"=2, "e"=2, "f"=2]>
%1 = stablehlo.tanh(%0) {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"a", "c"}, {"f"}, {"d", "e"}\]>]>} : tensor<8x8x8xf32>
%2 = sdy.collective_permute %1 out_sharding=<@mesh, [{"c":(1)2, "b", "f"}, {"a"}, {"e", "d"}\]> : tensor<8x8x8xf32>

Các quy tắc ràng buộc:

  • Phải đáp ứng các điều kiện ràng buộc được liệt kê trong Sdy_CollectiveOpInterface.
  • Nếu phân đoạn đầu vào và đầu ra có các lưới khác nhau, thì các lưới đó phải có cùng một trục và thứ tự mã thiết bị khác nhau.
  • Đối với mỗi phương diện, tích của kích thước trục phân đoạn trong out_sharding phải khớp với tích của kích thước phân đoạn phương diện toán hạng tương ứng.

Đặc điểm: SameOperandsAndResultType

Giao diện: CollectiveOpInterface, InferTypeOpInterface

Thuộc tính:

Thuộc tínhLoại MLIRMô tả
out_sharding::mlir::sdy::TensorShardingAttrPhân đoạn tensor

Toán hạng:

Toán hạng Mô tả
tensor tensor của bất kỳ giá trị loại nào

Kết quả:

Kết quả Mô tả
result tensor của bất kỳ giá trị loại nào

sdy.constant (sdy::ConstantOp)

Toán tử hằng

Tạo một tensor output từ một hằng số value.

Xem: https://github.com/openxla/stablehlo/blob/main/docs/spec.md#constant

Ví dụ:

%output = sdy.constant dense<[[0.0, 1.0], [2.0, 3.0]]> : tensor<2x2xf32>

Đặc điểm: AlwaysSpeculatableImplTrait

Giao diện: ConditionallySpeculatable, InferTypeOpInterface, NoMemoryEffect (MemoryEffectOpInterface)

Hiệu ứng: MemoryEffects::Effect{}

Thuộc tính:

Thuộc tínhLoại MLIRMô tả
value::mlir::ElementsAttrthuộc tính vectơ/tensor không đổi

Kết quả:

Kết quả Mô tả
output tensor có hình dạng tĩnh của bất kỳ giá trị loại nào

sdy.data_flow_edge (sdy::DataFlowEdgeOp)

Toán tử cạnh luồng dữ liệu.

Cú pháp:

operation ::= `sdy.data_flow_edge` $input (`sharding````=``` $sharding^)? attr-dict `:` type($result)

Cạnh luồng dữ liệu của một số toán tử X xác định một cầu nối giữa một tập hợp các nguồn (mỗi nguồn là một toán hạng của X hoặc một toán hạng của trình kết thúc khối của X) và một tập hợp các mục tiêu (mỗi mục tiêu là kết quả của X hoặc đối số khối của X), sao cho tất cả các nguồn và mục tiêu phải được phân đoạn theo cùng một cách.

Một toán tử có thể có nhiều cạnh luồng dữ liệu vuông góc với nhau.

Ví dụ:

  y_0, ..., y_n = while (x_0, ..., x_n)
                  ((pred_arg_0,... , pred_arg_n) { ... })
                  ((body_arg_0,..., body_arg_n) {
                    ...
                    return return_value_0, ..., return_value_n
                  })

Toán tử while này có n cạnh luồng dữ liệu, cạnh luồng dữ liệu thứ i nằm giữa nguồn x_i, return_value_i và đích y_i, pred_arg_i, body_arg_i.

sdy.data_flow_edge lấy làm đầu vào chủ sở hữu của một cạnh (có thể là bất kỳ mục tiêu nào, nhưng tốt nhất là kết quả của toán tử thay vì đối số khối), không được có bất kỳ mục đích sử dụng nào khác. Toán tử này không thuần tuý vì có thể nhận dữ liệu đầu vào ban đầu không có bất kỳ giá trị sử dụng nào.

sdy.data_flow_edge cũng chứa một phân đoạn không bắt buộc cho tất cả các mục tiêu của cạnh và phân đoạn đó phải được cập nhật thay vì phân đoạn của mục tiêu (nếu có thể đính kèm) trong quá trình truyền tải. Điều này rất hữu ích khi một toán tử có nhiều cạnh, vì việc này sẽ hiệu quả hơn nhiều khi:

  • truyền tải riêng qua từng cạnh.
  • cập nhật từng phân đoạn cạnh riêng biệt thay vì tất cả các mục tiêu cùng một lúc (ví dụ: một toán tử có một TensorShardingPerValueAttr không thể thay đổi duy nhất cho các phân đoạn kết quả).
  • thêm từng cạnh vào danh sách công việc riêng biệt khi phân đoạn của một nguồn đã thay đổi.

Quá trình truyền tải sẽ truyền tải các phân đoạn giữa tất cả các nguồn và đích của sdy.data_flow_edge như thể đó là một toán tử thông thường với các nguồn là toán hạng và đích là kết quả, cũng như một giá trị nhận dạng sdy.op_sharding_rule. Điều đó có nghĩa là quá trình truyền về phía trước là từ nguồn đến đích và quá trình truyền về phía sau là từ đích đến nguồn.

Chúng tôi không cho phép toán tử SdyDialect xác định đầu vào của sdy.data_flow_edge, vì vậy, chúng ta có thể giả định rằng toán tử này được xác định bằng một toán tử có thuộc tính sdy.sharding chưa đăng ký.

Đặc điểm: SameOperandsAndResultType

Giao diện: InferTypeOpInterface

Thuộc tính:

Thuộc tínhLoại MLIRMô tả
sharding::mlir::sdy::TensorShardingAttrPhân đoạn tensor

Toán hạng:

Toán hạng Mô tả
input có hình dạng của bất kỳ giá trị loại nào

Kết quả:

Kết quả Mô tả
result có hình dạng của bất kỳ giá trị loại nào

sdy.manual_computation (sdy::ManualComputationOp)

Thao tác song song trên nhiều thiết bị bằng các tập hợp thủ công

Cú pháp:

operation ::= `sdy.manual_computation` `(`operands`)`
              `in_shardings````=```custom<StrippedTensorShardingPerValueAttr>($in_shardings)
              `out_shardings````=```custom<StrippedTensorShardingPerValueAttr>($out_shardings)
              `manual_axes````=```$manual_axes
              custom<SingleBlockRegionNoBlockId>($body)
              attr-dict
              `:`
              functional-type(operands, results)

Chuyển đến một vùng được viết bằng mã cục bộ trên mỗi thiết bị với các tập hợp rõ ràng, trong đó các hình dạng logic khớp với các hình dạng vùng đệm vật lý cục bộ trên mỗi thiết bị và các tập hợp tương ứng chính xác với giao tiếp vật lý trên nhiều thiết bị.

Thân xe là cục bộ so với manual_axes. Quá trình truyền sẽ diễn ra thông qua thân trên mọi trục tự do – những trục không có trong danh sách manual_axes.

Các quy tắc ràng buộc:

  • Các phần tử trong in_shardingsout_shardings phải đáp ứng các quy tắc ràng buộc được liệt kê trong TensorShardingAttr.
  • Số lượng đầu vào/đầu ra tensor toàn cục và cục bộ của vùng toán tử phải khớp nhau.
  • Trục thủ công phải đứng trước mọi trục tự do trong mỗi phân đoạn chiều.
  • Trục thủ công không thể thêm khoảng đệm. Cụ thể, kích thước phương diện phải chia hết cho kích thước trục thủ công tương ứng.
  • Hình dạng toàn cục và cục bộ của đối số/kết quả của vùng hoạt động phải khớp với nhau.
  • Không có trục thủ công nào được tách.

Thuộc tính: IsolatedFromAbove, RecursiveMemoryEffects, SingleBlockImplicitTerminator<ReturnOp>, SingleBlock

Giao diện: ShardableDataFlowOpInterface

Thuộc tính:

Thuộc tínhLoại MLIRMô tả
in_shardings::mlir::sdy::TensorShardingPerValueAttrPhân đoạn tensor theo toán hạng/kết quả của một toán tử
out_shardings::mlir::sdy::TensorShardingPerValueAttrPhân đoạn tensor theo toán hạng/kết quả của một toán tử
manual_axes::mlir::sdy::ManualAxesAttrDanh sách các trục mà ManualComputationOp là thủ công

Toán hạng:

Toán hạng Mô tả
tensors biến của tensor được xếp hạng của bất kỳ giá trị loại nào

Kết quả:

Kết quả Mô tả
results biến của tensor được xếp hạng của bất kỳ giá trị loại nào

sdy.mesh (sdy::MeshOp)

Lưới được đặt tên

Cú pháp:

operation ::= `sdy.mesh` $sym_name `=` $mesh attr-dict

Xác định một lưới có tên mới. Tất cả lưới trong một mô-đun phải có cùng số lượng thiết bị (ngoại trừ lưới có một device_id). Lưới là một toán tử Symbol xuất hiện trong SymbolTable của mô-đun và có thể được tham chiếu bằng name của mô-đun đó.

Đặc điểm: HasParent<ModuleOp>

Giao diện: Symbol

Thuộc tính:

Thuộc tínhLoại MLIRMô tả
sym_name::mlir::StringAttrthuộc tính chuỗi
mesh::mlir::sdy::MeshAttrLưới trục và danh sách thiết bị

sdy.named_computation (sdy::NamedComputationOp)

Toán tử tính toán được đặt tên

Cú pháp:

operation ::= `sdy.named_computation` `<`$name`>` `` `(` $operands `)`
              (`in_shardings````=```custom<StrippedTensorShardingPerValueAttr>($in_shardings)^)?
              (`out_shardings````=```custom<StrippedTensorShardingPerValueAttr>($out_shardings)^)?
              custom<SingleBlockRegionNoBlockId>($body)
              attr-dict
              `:` functional-type($operands, results)

Nhóm một phép tính, tức là một khối các phép toán và đặt tên cho phép tính đó. Quá trình truyền sẽ diễn ra vào/ra khỏi vùng này như thể mọi thứ đều được cùng dòng.

Bạn có thể dùng phương thức này để xử lý việc truyền thông qua các lệnh gọi đến các hàm khác. Mọi người dùng Shardy đều phải viết một thẻ nhập/xuất chuyển đổi các thao tác gọi thành thao tác sdy.named_computation, sao chép/sao chép nội dung của hàm được gọi vào nội dung của named_computation.

Loại của mỗi đối số khối và giá trị trả về trong vùng phải giống với loại toán hạng và loại kết quả của toán tử.

Ví dụ:

%1 = sdy.named_computation<"foo">(%0) (%arg1: tensor<16x32xf32>) {
  sdy.return %arg1 : tensor<16x32xf32>
} : (tensor<16x32xf32>) -> tensor<16x32xf32>

Thuộc tính: IsolatedFromAbove, RecursiveMemoryEffects, RecursivelySpeculatableImplTrait, SingleBlockImplicitTerminator<ReturnOp>, SingleBlock

Giao diện: ConditionallySpeculatable, InferTypeOpInterface, ShardableDataFlowOpInterface

Thuộc tính:

Thuộc tínhLoại MLIRMô tả
name::mlir::StringAttrthuộc tính chuỗi
in_shardings::mlir::sdy::TensorShardingPerValueAttrPhân đoạn tensor theo toán hạng/kết quả của một toán tử
out_shardings::mlir::sdy::TensorShardingPerValueAttrPhân đoạn tensor theo toán hạng/kết quả của một toán tử

Toán hạng:

Toán hạng Mô tả
operands biến kiểu bất kỳ

Kết quả:

Kết quả Mô tả
«unnamed» biến kiểu bất kỳ

sdy.propagation_barrier (sdy::PropagationBarrierOp)

Thao tác rào cản truyền tải

Cú pháp:

operation ::= `sdy.propagation_barrier` $input `allowed_direction````=```$allowed_direction attr-dict `:` type($input)

Toán tử này hoạt động giống như toán tử nhận dạng, xuất ra cùng một giá trị đã lấy làm đầu vào. Nhưng về mặt truyền tải, điều này sẽ chỉ cho phép truyền tải chảy qua theo một hướng nhất định.

Điều này ngăn việc phân đoạn được truyền giữa các lần sử dụng kết quả của toán tử rào cản và toán hạng của toán tử đó.

  • FORWARD có nghĩa là các phần phân đoạn chỉ có thể chuyển từ toán hạng đến kết quả.
  • BACKWARD có nghĩa là các phần phân đoạn chỉ có thể chuyển từ kết quả đến toán hạng.
  • NONE có nghĩa là không có hoạt động phân đoạn nào có thể truyền qua toán tử này.
  • Không thể chỉ định BOTH vì toán tử này sẽ thừa.

Thuộc tính: AlwaysSpeculatableImplTrait, SameOperandsAndResultType

Giao diện: ConditionallySpeculatable, InferTypeOpInterface, NoMemoryEffect (MemoryEffectOpInterface)

Hiệu ứng: MemoryEffects::Effect{}

Thuộc tính:

Thuộc tínhLoại MLIRMô tả
allowed_direction::mlir::sdy::PropagationDirectionAttrenum hướng truyền

Toán hạng:

Toán hạng Mô tả
input tensor được xếp hạng của bất kỳ giá trị loại nào

Kết quả:

Kết quả Mô tả
result tensor được xếp hạng của bất kỳ giá trị loại nào

sdy.reshard (sdy::ReshardOp)

Phân đoạn lại một tensor thành một phân đoạn khác

Cú pháp:

operation ::= `sdy.reshard` $input $sharding attr-dict `:` type($result)

Phân đoạn lại tensor đầu vào bằng phương thức phân đoạn đã chỉ định, khác với phương thức phân đoạn hiện có của tensor đầu vào.

Cả ShardingConstraintOp và ReshardOp đều đính kèm một phân đoạn vào một tensor. Thời gian hoạt động của chúng là:

  1. Trước khi truyền phân đoạn, người dùng sẽ thêm ShardingConstraintOp.
  2. Quá trình truyền tải phân đoạn sẽ sử dụng ShardingConstraintOp. Không có ShardingConstraintOp nào trong kết quả của quá trình truyền dữ liệu phân đoạn. Thay vào đó, bạn có thể thêm ReshardOp nếu cần.
  3. Trình phân vùng chuyển đổi ReshardOp thành một toán tử tập hợp (hoặc toán tử nhận dạng). Không được có ReshardOp trong kết quả của trình phân vùng.

// TODO(b/331680067). Thêm mẫu chuẩn hoá để xoá các thao tác phân đoạn lại // thừa.

Thuộc tính: AlwaysSpeculatableImplTrait, SameOperandsAndResultType

Giao diện: ConditionallySpeculatable, InferTypeOpInterface, NoMemoryEffect (MemoryEffectOpInterface)

Hiệu ứng: MemoryEffects::Effect{}

Thuộc tính:

Thuộc tínhLoại MLIRMô tả
sharding::mlir::sdy::TensorShardingAttrPhân đoạn tensor

Toán hạng:

Toán hạng Mô tả
input tensor của bất kỳ giá trị loại nào

Kết quả:

Kết quả Mô tả
result tensor của bất kỳ giá trị loại nào

sdy.return (sdy::ReturnOp)

Thao tác sdy.return chấm dứt các vùng được đính kèm vào các thao tác dựa trên vùng sdy và mọi thao tác dựa trên vùng Shardy khác. Đây là một hàm biến: hàm này lấy danh sách các giá trị có kiểu bất kỳ làm đối số (nhưng cùng loại, ví dụ: AnyTensor) và do đó có thể được sử dụng lại ở nhiều cấp của ngăn xếp IR Shardy.

Cú pháp:

operation ::= `sdy.return` attr-dict ($results^ `:` type($results))?

Đặc điểm: AlwaysSpeculatableImplTrait, Terminator

Giao diện: ConditionallySpeculatable, NoMemoryEffect (MemoryEffectOpInterface)

Hiệu ứng: MemoryEffects::Effect{}

Toán hạng:

Toán hạng Mô tả
results biến kiểu bất kỳ

sdy.sharding_constraint (sdy::ShardingConstraintOp)

Ràng buộc một tensor với phân đoạn đã chỉ định

Cú pháp:

operation ::= `sdy.sharding_constraint` $input $sharding attr-dict `:` type($result)

Đính kèm một phân đoạn vào một tensor trung gian (ví dụ: kết quả của matmul) để cho biết đây là cách phân đoạn tensor đó hoặc một tập hợp con của các mục sử dụng tensor đó.

Nếu phân đoạn có các phương diện mở và trục không ràng buộc, thì điều đó có nghĩa là tensor có thể được phân đoạn thêm dọc theo các phương diện mở.

Thao tác này có thể:

  • Không có mục đích sử dụng (dangling) – tức là việc phân đoạn đính kèm là cách phân đoạn chính tensor đầu vào.
  • Có trường hợp sử dụng – nghĩa là phân đoạn đính kèm là cách phân đoạn các trường hợp sử dụng của toán tử ràng buộc phân đoạn, trong khi các trường hợp sử dụng khác của tensor đầu vào có thể có phân đoạn khác (nếu tensor đầu vào không có trường hợp sử dụng nào khác thì hành vi sẽ giống như trường hợp không sử dụng).

Đặc điểm: SameOperandsAndResultType

Giao diện: InferTypeOpInterface

Thuộc tính:

Thuộc tínhLoại MLIRMô tả
sharding::mlir::sdy::TensorShardingAttrPhân đoạn tensor

Toán hạng:

Toán hạng Mô tả
input tensor của bất kỳ giá trị loại nào

Kết quả:

Kết quả Mô tả
result tensor của bất kỳ giá trị loại nào

sdy.sharding_group (sdy::ShardingGroupOp)

Giới hạn các tensor trong nhóm để có cùng một phân đoạn.

Cú pháp:

operation ::= `sdy.sharding_group` $input `group_id````=```$group_id attr-dict `:` type($input)

Toán tử này cung cấp một giao diện để gán tensor cho các nhóm phân đoạn ( các nhóm tensor sẽ được thực thi để có các phân đoạn giống hệt nhau). Trong quá trình truyền tải, ngay khi một phần tử nhóm được phân đoạn, tất cả các thành phần khác sẽ được phân đoạn theo đúng cách. Thao tác này lấy mã nhận dạng nhóm đối số và không trả về kết quả nào, nhưng thay vào đó, sửa đổi nội dung trình bày nhóm phân đoạn nội bộ để thêm tensor đầu vào vào nhóm có mã nhận dạng đã cho.

Giao diện: InferTypeOpInterface

Thuộc tính:

Thuộc tínhLoại MLIRMô tả
group_id::mlir::IntegerAttrThuộc tính số nguyên 64 bit không dấu

Toán hạng:

Toán hạng Mô tả
input tensor được xếp hạng của bất kỳ giá trị loại nào

Thuộc tính

AllToAllParamAttr

Thông số tất cả với tất cả

Cú pháp:

#sdy.all_to_all_param<
  ::llvm::ArrayRef<AxisRefAttr>,   # axes
  int64_t,   # src_dim
  int64_t   # tgt_dim
>

Một bộ chứa các trục và phương diện nguồn/đích để thực hiện tính năng tất cả với tất cả.

Các thông số:

Thông số Loại C++ Mô tả
trục ::llvm::ArrayRef<AxisRefAttr> các trục để thực hiện tất cả với tất cả
src_dim int64_t chỉ mục phương diện nguồn
tgt_dim int64_t chỉ mục phương diện mục tiêu

AlltoAllParamListAttr

Danh sách tham số tất cả với tất cả

Cú pháp:

#sdy.all_to_all_param_list<
  ::llvm::ArrayRef<AllToAllParamAttr>   # value
>

Các thông số:

Thông số Loại C++ Mô tả
value ::llvm::ArrayRef<AllToAllParamAttr>

AxisRefAttr

Tham chiếu đến một trục đầy đủ hoặc một trục phụ được chia

Cú pháp:

#sdy.axis_ref<
  ::llvm::StringRef,   # name
  SubAxisInfoAttr   # sub_axis_info
>

Các quy tắc ràng buộc:

  • name phải có trong MeshAttr đã liên kết.
  • Nếu có sub_axis_info, thì sub_axis_info phải đáp ứng các quy tắc ràng buộc của SubAxisInfoAttr.

Các thông số:

Thông số Loại C++ Mô tả
tên ::llvm::StringRef tên của trục này
sub_axis_info SubAxisInfoAttr thông tin bổ sung nếu đây là trục phụ

AxisRefListAttr

Danh sách tệp tham chiếu trục

Cú pháp:

#sdy.axis_ref_list<
  ::llvm::ArrayRef<AxisRefAttr>   # value
>

Các quy tắc ràng buộc:

  • Các phần tử trong value phải đáp ứng các quy tắc ràng buộc của AxisRefAttr.
  • Không có tham chiếu trục hoặc trục phụ trùng lặp chồng chéo lên nhau.
  • Không có hai tệp tham chiếu trục liền kề nào là trục phụ liên tiếp của cùng một trục đầy đủ, nghĩa là chúng có thể được hợp nhất thành một trục phụ hoặc trục đầy đủ.

Các thông số:

Thông số Loại C++ Mô tả
value ::llvm::ArrayRef<AxisRefAttr>

DimMappingAttr

Danh sách chỉ số nhân tố cho một phương diện

Danh sách trống cho biết đây là mối liên kết rỗng (được phân tích cú pháp/in bằng *), tức là phương diện không được liên kết với bất kỳ yếu tố nào.

Các quy tắc ràng buộc:

  • Có ít nhất một chỉ mục nhân tố.
  • Chỉ số hệ số phải nằm trong khoảng [0, $factor_sizes).
  • Nếu có nhiều yếu tố, thì không có yếu tố nào có thể có kích thước 1.
  • Không có chỉ mục nhân tố trùng lặp.

Các thông số:

Thông số Loại C++ Mô tả
factor_indices ::llvm::ArrayRef<int64_t> các yếu tố mà phương diện này được liên kết

DimensionShardingAttr

Phân đoạn phương diện

Danh sách tên trục để phân đoạn một phương diện tensor từ chính đến phụ, một giá trị boolean cho biết liệu phương diện có thể được phân đoạn thêm hay không và một số nguyên tuỳ chọn biểu thị mức độ ưu tiên của việc phân đoạn phương diện này. Mức độ ưu tiên này sẽ được tuân thủ trong quá trình truyền tải phân đoạn. Mức độ ưu tiên bắt nguồn từ các chú thích phân đoạn người dùng và giá trị thấp hơn biểu thị mức độ ưu tiên cao hơn. Mức độ ưu tiên cao nhất được giả định khi mức độ ưu tiên bị thiếu trong chú thích.

Các quy tắc ràng buộc:

  • Các phần tử trong axes phải đáp ứng các quy tắc ràng buộc được liệt kê trong AxisRefListAttr.
  • Nếu một phương diện phân đoạn có mức độ ưu tiên:
    • Mức độ ưu tiên lớn hơn hoặc bằng 0.
    • Phương diện có ít nhất một trục nếu phương diện đó đóng.

Các thông số:

Thông số Loại C++ Mô tả
trục ::llvm::ArrayRef<AxisRefAttr> tệp tham chiếu trục
is_closed bool liệu bạn có thể phân đoạn thêm phương diện này hay không
của chiến dịch std::optional<int64_t> mức độ ưu tiên được sử dụng trong quá trình truyền tải dựa trên mức độ ưu tiên của người dùng

ListOfAxisRefListsAttr

Danh sách danh sách tham chiếu trục

Cú pháp:

#sdy.list_of_axis_ref_lists<
  ::llvm::ArrayRef<AxisRefListAttr>   # value
>

Các thông số:

Thông số Loại C++ Mô tả
value ::llvm::ArrayRef<AxisRefListAttr>

ManualAxesAttr

Danh sách các trục mà ManualComputationOp là thủ công

Cú pháp:

#sdy.manual_axes<
  ::llvm::ArrayRef<StringAttr>   # value
>

Các thông số:

Thông số Loại C++ Mô tả
value ::llvm::ArrayRef<StringAttr>

MeshAttr

Lưới trục và danh sách thiết bị

Cú pháp:

#sdy.mesh<
  ::llvm::ArrayRef<MeshAxisAttr>,   # axes
  ::llvm::ArrayRef<int64_t>   # device_ids
>

Lưới là danh sách các trục và danh sách mã thiết bị (không bắt buộc) chỉ định thứ tự thiết bị.

Nếu danh sách trục trống, thì lưới sẽ có một trục ẩn không tên có kích thước 1. Trong trường hợp này, nếu bạn không cung cấp danh sách mã thiết bị, thì danh sách mã thiết bị ngầm ẩn sẽ là [0]; nếu bạn cung cấp danh sách mã thiết bị, thì danh sách đó phải chứa một số nguyên bất kỳ có giá trị không âm. Chúng tôi gọi trường hợp này là phân đoạn tối đa.

Đối với tất cả các trường hợp không phân đoạn tối đa, nếu bạn chỉ định danh sách mã thiết bị, thì tích của kích thước trục phải khớp với số lượng thiết bị. Nếu bạn không chỉ định danh sách mã thiết bị, thì danh sách mã thiết bị ngầm ẩn sẽ là iota(product(axes)). Để đơn giản, chúng tôi cũng không cho phép chỉ định danh sách mã thiết bị giống với iota(product(axes)); trong trường hợp này, bạn không nên chỉ định danh sách mã thiết bị.

Dưới đây là một số ví dụ về lưới:

  • Lưới trống đại diện cho lưới giữ chỗ có thể được thay thế trong quá trình truyền: <[]>
  • Một lưới có trục không tên và mã thiết bị rõ ràng, thường được dùng để biểu thị phân đoạn tối đa: <[], device_ids=[3]>
  • Một lưới có hai trục và mã thiết bị ngầm ẩn iota(6): <["a"=2, "b"=3]>
  • Một lưới có hai trục và mã thiết bị rõ ràng chỉ định thứ tự thiết bị: <["a"=3, "b"=2], device_ids=[0, 2, 4, 1, 3, 5]>

Các quy tắc ràng buộc:

  • Các phần tử trong axes không được có tên trùng lặp.
  • Nếu bạn chỉ định device_ids:
    • Sản phẩm của kích thước trục phải khớp với số lượng thiết bị.
    • Tất cả các phần tử của mảng này đều phải là số không âm.
    • device_ids không được bằng iota(product(axis_sizes)).
    • device_ids đã sắp xếp phải là iota(product(axis_sizes)).

Các thông số:

Thông số Loại C++ Mô tả
trục ::llvm::ArrayRef<MeshAxisAttr> trục lưới
device_ids ::llvm::ArrayRef<int64_t> thứ tự thiết bị rõ ràng hoặc mã thiết bị tối đa

MeshAxisAttr

Trục được đặt tên trong lưới

Cú pháp:

#sdy.mesh_axis<
  ::llvm::StringRef,   # name
  int64_t   # size
>

Các thông số:

Thông số Loại C++ Mô tả
tên ::llvm::StringRef tên
size int64_t kích thước của trục này

OpShardingRuleAttr

Chỉ định cách phân vùng một toán tử.

Cú pháp:

#sdy.op_sharding_rule<
  ::llvm::ArrayRef<int64_t>,   # factor_sizes
  ::llvm::ArrayRef<TensorMappingAttr>,   # operand_mappings
  ::llvm::ArrayRef<TensorMappingAttr>,   # result_mappings
  ::llvm::ArrayRef<int64_t>,   # reduction_factors
  ::llvm::ArrayRef<int64_t>,   # need_replication_factors
  ::llvm::ArrayRef<int64_t>,   # permutation_factors
  ::llvm::ArrayRef<int64_t>,   # blocked_propagation_factors
  bool   # is_custom_rule
>

Quy tắc phân đoạn chỉ định cách phân đoạn một toán tử theo nhiều thuộc tính trên toán tử – bất kỳ thuộc tính nào, hình dạng của toán hạng, hình dạng của kết quả, v.v. Ví dụ:

%0 = stablehlo.add %arg0, %arg1 {
    sdy.sharding_rule = #sdy.op_sharding_rule<
        ([i, j],[i, j])->([i, j])
        {i=8, j=8}>
} : tensor<8x8xf32>
%1 = stablehlo.dot_general %arg2, %arg3, contracting_dims = [1] x [0] {
  sdy.sharding_rule = #sdy.op_sharding_rule<
      ([i, k],[k, j])->([i, j])
      {i=8, j=16, k=8}>
}: (tensor<8x8xf32>, tensor<8x16xf32>) -> tensor<8x16xf32>

Xin lưu ý rằng chúng tôi cho phép các hệ số có kích thước 1 mặc dù không thể phân đoạn được, điều này chủ yếu là để đảm bảo tính đầy đủ vì nhiều toán tử như toán tử theo điểm có kích thước một phương diện tương ứng với các toán hạng và kết quả.

Loại yếu tố:

  • reduction_factors chứa các chỉ mục của các hệ số cần giảm, chẳng hạn như các phương diện thu hẹp trong phép toán dấu chấm.
  • need_replication_factors chứa các chỉ mục của các yếu tố yêu cầu sao chép đầy đủ, chẳng hạn như phương diện đã sắp xếp trong một thao tác sắp xếp.
  • permutation_factors chứa các chỉ mục của các yếu tố yêu cầu hoán vị tập thể nếu các yếu tố đó được phân đoạn, chẳng hạn như kích thước khoảng đệm trong thao tác đệm.
  • Tất cả các yếu tố khác được coi là yếu tố truyền qua, tức là các yếu tố không yêu cầu bất kỳ hoạt động giao tiếp nào nếu được phân đoạn theo cùng một cách trên tất cả các tensor được liên kết với chúng.

blocked_propagation_factors chứa các yếu tố theo đó không được phép phân đoạn. Phương diện này vuông góc với các loại yếu tố. Cụ thể, yếu tố chặn truyền có thể là bất kỳ loại yếu tố nào.

is_custom_rule mô tả liệu đây có phải là quy tắc do người dùng xác định hay không. Người dùng có thể xác định quy tắc phân đoạn cho các lệnh gọi tuỳ chỉnh hoặc ghi đè quy tắc phân đoạn được xác định trước cho các thao tác chuẩn. Quy tắc tuỳ chỉnh luôn được giữ nguyên/không bao giờ bị xoá.

Các quy tắc ràng buộc:

  • Số lượt ánh xạ toán hạng/kết quả phải khớp với số lượng toán hạng/kết quả của toán tử.
  • Có ít nhất một mối liên kết (không thể có quy tắc cho một toán tử không có toán hạng/kết quả).
  • Hạng của mỗi TensorMappingAttr khớp với hạng của loại tensor tương ứng.
  • Đối với mỗi nhóm yếu tố (reduction_factors, need_replication_factors, permutation_factors):
    • Các phần tử phải nằm trong khoảng [0, $factor_sizes].
    • Không có chỉ mục nhân tố trùng lặp trong mỗi nhóm và giữa các nhóm.

Các thông số:

Thông số Loại C++ Mô tả
factor_sizes ::llvm::ArrayRef<int64_t> kích thước của tất cả các hệ số trong quy tắc này
operand_mappings ::llvm::ArrayRef<TensorMappingAttr> ánh xạ toán hạng
result_mappings ::llvm::ArrayRef<TensorMappingAttr> ánh xạ kết quả
reduction_factors ::llvm::ArrayRef<int64_t> các yếu tố cần giảm
need_replication_factors ::llvm::ArrayRef<int64_t> các yếu tố yêu cầu sao chép toàn bộ
permutation_factors ::llvm::ArrayRef<int64_t> các yếu tố yêu cầu hoán vị tập thể
blocked_propagation_factors ::llvm::ArrayRef<int64_t> các yếu tố theo đó không phân đoạn
is_custom_rule bool liệu quy tắc có phải là cho stablehlo.custom_call hay không

SubAxisInfoAttr

Thông tin về cách trục phụ này được lấy từ trục đầy đủ

Cú pháp:

#sdy.sub_axis_info<
  int64_t,   # pre_size
  int64_t   # size
>

Khi chia một trục đầy đủ thành n trục phụ, trục này được định hình lại thành [k_1,...,k_n] và trục phụ thứ i có thể được biểu thị bằng tích của tất cả kích thước trục ở bên trái m=prod(k_1,...,k_(i-1)) (còn gọi là kích thước trước) và kích thước k_i. Do đó, thuộc tính sub-axis-info chứa hai số đó và được ký hiệu như sau: (m)k cho kích thước trước m và kích thước k.

Các quy tắc ràng buộc:

  • pre-size tối thiểu là 1.
  • size lớn hơn 1.
  • pre-size phải chia kích thước của trục đầy đủ, tức là cả pre-sizesize chia kích thước của trục đầy đủ và trục phụ không vượt quá trục đầy đủ.
  • Kích thước của trục phụ không bằng kích thước của trục đầy đủ tương ứng, trong trường hợp này, bạn nên sử dụng trục đầy đủ.

Các thông số:

Thông số Loại C++ Mô tả
pre_size int64_t tích của các kích thước trục phụ ở bên trái trục phụ này
size int64_t kích thước của trục phụ này

TensorMappingAttr

Các ánh xạ nhân tố cho mỗi phương diện của một tensor.

Cú pháp:

#sdy.tensor_mapping<
  ::llvm::ArrayRef<DimMappingAttr>   # dim_mappings
>

Các quy tắc ràng buộc:

  • Các phần tử trong dim_mappings phải đáp ứng các quy tắc ràng buộc trong DimMappingAttr.
  • Không có chỉ mục nhân tố trùng lặp trên các phương diện.

Các thông số:

Thông số Loại C++ Mô tả
dim_mappings ::llvm::ArrayRef<DimMappingAttr> mối liên kết phương diện

TensorShardingAttr

Phân đoạn tensor

Cú pháp:

#sdy.sharding<
  ::mlir::Attribute,   # mesh_or_ref
  ::llvm::ArrayRef<DimensionShardingAttr>,   # dim_shardings
  ::llvm::ArrayRef<AxisRefAttr>   # replicated_axes
>

Phân đoạn tensor được liên kết với một lưới cụ thể và chỉ có thể tham chiếu tên trục từ lưới đó. Các phân đoạn theo phương diện cho chúng ta biết về mỗi phương diện của tensor, theo đó các trục (hoặc trục phụ) được phân đoạn từ chính đến phụ. Tất cả các trục khác không phân đoạn một phương diện sẽ được sao chép ngầm hoặc rõ ràng (nếu xuất hiện trong danh sách trục được sao chép).

Lưới liên kết với phân đoạn này có thể được chỉ định bằng tên biểu tượng, tham chiếu đến biểu tượng MeshOp tương ứng hoặc MeshAttr nội tuyến.

Các quy tắc ràng buộc:

  • Các phần tử trong dim_shardings phải đáp ứng các quy tắc ràng buộc được liệt kê trong DimensionShardingAttr.
  • Các phần tử trong replicated_axes phải đáp ứng các quy tắc ràng buộc được liệt kê trong AxisRefListAttr.
  • Nếu loại tensor tương ứng không phải là ShapedType, thì phân đoạn phải có thứ hạng 0 và không có trục được sao chép.
  • Tensor phải có thứ hạng.
  • Số lượng phân đoạn theo phương diện bằng với thứ hạng của tensor.
  • Phương diện có kích thước 0 không được phân đoạn.
  • Các mục trong replicated_axes được sắp xếp theo mesh_or_ref (xem AxisRefAttr::getMeshComparator).

Các thông số:

Thông số Loại C++ Mô tả
mesh_or_ref ::mlir::Attribute thuộc tính lưới hoặc thuộc tính tham chiếu biểu tượng lưới phẳng
dim_shardings ::llvm::ArrayRef<DimensionShardingAttr> phân đoạn phương diện
replicated_axes ::llvm::ArrayRef<AxisRefAttr> tệp tham chiếu trục

TensorShardingPerValueAttr

Phân đoạn tensor theo toán hạng/kết quả của một toán tử

Cú pháp:

#sdy.sharding_per_value<
  ::llvm::ArrayRef<TensorShardingAttr>   # shardings
>

Danh sách TensorShardingAttr, mỗi TensorShardingAttr cho một toán hạng/kết quả của một toán tử.

Các quy tắc ràng buộc:

  • Các phần tử trong shardings phải đáp ứng các quy tắc ràng buộc của TensorShardingAttr.

Các thông số:

Thông số Loại C++ Mô tả
phân đoạn ::llvm::ArrayRef<TensorShardingAttr> phân đoạn theo giá trị

Enum

PropagationDirection

Enum hướng truyền

Trường hợp:

Biểu tượng Giá trị Chuỗi
KHÔNG CÓ 0 KHÔNG CÓ
FORWARD 1 FORWARD
QUAY LẠI 2 QUAY LẠI
CẢ HAI BÊN 3 CẢ HAI BÊN