Ngôn ngữ Shardy (SDY) xác định một cách trình bày phân đoạn tensor dựa trên trục và các thành phần API bổ sung để đính kèm các phân đoạn vào tensor.
Hoạt động tính toán
sdy.all_gather
(sdy::AllGatherOp)
Thực hiện giao tiếp tất cả cùng nhau dọc theo các trục
Cú pháp:
operation ::= `sdy.all_gather` $gathering_axes $tensor `out_sharding````=```$out_sharding attr-dict `:` type($result)
Thu thập các phần của một tensor dọc theo các trục được chỉ định trong gathering_axes
.
gathering_axes
là danh sách các danh sách trục. Danh sách bên ngoài vượt quá kích thước của tensor. Mỗi danh sách bên trong chỉ định các trục theo đó sẽ thực hiện một hoạt động thu thập riêng biệt trên phương diện tương ứng. Phương thức này sẽ được áp dụng cho việc phân đoạn toán hạng (tensor
) để lấy kết quả phân đoạn (out_sharding
).
Xin lưu ý rằng out_sharding
không được dùng để xác định việc phân đoạn kết quả. Thay vào đó, việc phân đoạn kết quả được xác định bằng cách phân đoạn toán hạng và gathering_axes
, đồng thời out_sharding
phải khớp với phân đoạn suy luận này.
Ví dụ:
%1 = stablehlo.tanh(%0) {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"a", "b", "c"}, {}, {"d"}\]>]>} : tensor<8x8x8xf32>
%2 = sdy.all_gather [{"b", "c"}, {}, {"d"}\] %1 out_sharding=<@mesh, [{"a"}, {}, {}\]> : tensor<8x8x8xf32>
Các quy tắc ràng buộc:
- Phải đáp ứng các điều kiện ràng buộc được liệt kê trong
Sdy_CollectiveOpInterface
. - Các phần tử trong
gathering_axes
phải đáp ứng các quy tắc ràng buộc được liệt kê trongAxisRefListAttr
. - Việc áp dụng
gathering_axes
cho hoạt động phân đoạn toán hạng sẽ nhận đượcout_sharding
.
Đặc điểm: SameOperandsAndResultType
Giao diện: InferTypeOpInterface
, Sdy_CollectiveOpInterface
Thuộc tính:
Thuộc tính | Loại MLIR | Mô tả |
---|---|---|
gathering_axes | ::mlir::sdy::ListOfAxisRefListsAttr | Danh sách danh sách tham chiếu trục |
out_sharding | ::mlir::sdy::TensorShardingAttr | Phân đoạn tensor |
Toán hạng:
Toán hạng | Mô tả |
---|---|
tensor |
tensor của bất kỳ giá trị loại nào |
Kết quả:
Kết quả | Mô tả |
---|---|
result |
tensor của bất kỳ giá trị loại nào |
sdy.all_reduce
(sdy::AllReduceOp)
Thực hiện giao tiếp giảm tất cả theo các trục
Cú pháp:
operation ::= `sdy.all_reduce` $reduction_axes $tensor `out_sharding````=```$out_sharding attr-dict `:` type($result)
Giảm các phần của một tensor dọc theo các trục được chỉ định trong reduction_axes
.
Thứ tự của reduction_axes
không quan trọng đối với kết quả, nhưng có thể ảnh hưởng đến thứ tự của các nhóm bản sao tương ứng.
Các quy tắc ràng buộc:
- Phải đáp ứng các điều kiện ràng buộc được liệt kê trong
Sdy_CollectiveOpInterface
. reduction_axes
phải đáp ứng các quy tắc ràng buộc được liệt kê trongAxisRefListAttr
;reduction_axes
không được chồng chéo với các trục phân đoạn toán hạng;
Đặc điểm: SameOperandsAndResultType
Giao diện: CollectiveOpInterface
, InferTypeOpInterface
Thuộc tính:
Thuộc tính | Loại MLIR | Mô tả |
---|---|---|
reduction_axes | ::mlir::sdy::AxisRefListAttr | Danh sách tệp tham chiếu trục |
out_sharding | ::mlir::sdy::TensorShardingAttr | Phân đoạn tensor |
Toán hạng:
Toán hạng | Mô tả |
---|---|
tensor |
tensor của bất kỳ giá trị loại nào |
Kết quả:
Kết quả | Mô tả |
---|---|
result |
tensor của bất kỳ giá trị loại nào |
sdy.all_slice
(sdy::AllSliceOp)
Thực hiện thao tác cắt động dọc theo các trục
Cú pháp:
operation ::= `sdy.all_slice` $slicing_axes $tensor `out_sharding````=```$out_sharding attr-dict `:` type($result)
Cắt các phần của một tensor dọc theo các trục được chỉ định trong slicing_axes
. Có một tính chất song đôi đại số giữa sdy.all_slice
và sdy.all_gather
.
slicing_axes
là danh sách các danh sách trục. Danh sách bên ngoài vượt quá kích thước của tensor. Mỗi danh sách bên trong chỉ định các trục dọc theo đó một lát cắt sẽ được thực hiện trên phương diện tương ứng. Phương thức này sẽ được áp dụng cho việc phân đoạn toán hạng (tensor
) để lấy kết quả phân đoạn (out_sharding
).
Xin lưu ý rằng out_sharding
không được dùng để xác định việc phân đoạn kết quả. Thay vào đó, việc phân đoạn kết quả được xác định bằng cách phân đoạn toán hạng và slicing_axes
, đồng thời out_sharding
phải khớp với phân đoạn suy luận này.
Ví dụ:
%1 = stablehlo.tanh(%0) {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"a"}, {}, {}\]>]>} : tensor<8x8x8xf32>
%2 = sdy.all_slice [{"b", "c"}, {}, {"d"}\] %1 out_sharding=<@mesh, [{"a", "b", "c"}, {}, {"d"}\]> : tensor<8x8x8xf32>
Các quy tắc ràng buộc:
- Các phần tử trong
slicing_axes
phải đáp ứng các quy tắc ràng buộc được liệt kê trongAxisRefListAttr
. - Phải đáp ứng các điều kiện ràng buộc được liệt kê trong
Sdy_CollectiveOpInterface
. - Việc áp dụng
slicing_axes
cho hoạt động phân đoạn toán hạng sẽ nhận đượcout_sharding
.
Đặc điểm: SameOperandsAndResultType
Giao diện: CollectiveOpInterface
, InferTypeOpInterface
Thuộc tính:
Thuộc tính | Loại MLIR | Mô tả |
---|---|---|
slicing_axes | ::mlir::sdy::ListOfAxisRefListsAttr | Danh sách danh sách tham chiếu trục |
out_sharding | ::mlir::sdy::TensorShardingAttr | Phân đoạn tensor |
Toán hạng:
Toán hạng | Mô tả |
---|---|
tensor |
tensor của bất kỳ giá trị loại nào |
Kết quả:
Kết quả | Mô tả |
---|---|
result |
tensor của bất kỳ giá trị loại nào |
sdy.all_to_all
(sdy::AllToAllOp)
Thực hiện giao tiếp tất cả với tất cả dọc theo các trục
Cú pháp:
operation ::= `sdy.all_to_all` $params $tensor `out_sharding````=```$out_sharding attr-dict `:` type($result)
Đối với mỗi bộ dữ liệu (axes, src_dim, tgt_dim) trong danh sách tham số, thao tác này sẽ cắt các phần của một tensor theo phương diện tgt_dim
và các trục được chỉ định trong axes
, phân tán các phần đó theo các trục và nối các phần đó theo phương diện src_dim
.
Về cơ bản, toán tử này là sự kết hợp của một toán tử thu thập tất cả theo src_dim
và axes
, theo sau là một toán tử cắt tất cả theo tgt_dim
và axes
, tức là hậu tố của phương diện phân đoạn trục src_dim
trên tensor đầu vào được thêm vào phương diện phân đoạn trục tgt_dim
trên tensor đầu ra.
Phương thức tất cả với tất cả sẽ được áp dụng cho việc phân đoạn toán hạng (tensor
) để lấy kết quả phân đoạn (out_sharding
).
Xin lưu ý rằng out_sharding
không được dùng để xác định việc phân đoạn kết quả. Thay vào đó, việc phân đoạn kết quả được xác định bằng cách phân đoạn toán hạng, src_dim
, tgt_dim
và axes
, đồng thời out_sharding
phải khớp với phân đoạn được suy luận này.
Ví dụ:
%1 = stablehlo.tanh(%0) {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"a", "b"}, {"c"}, {}, {}\]>]>} : tensor<8x8x4x4x32>
%2 = sdy.all_to_all [{"b"}: 0->2, {"c"}: 1->3] %1 out_sharding=<@mesh, [{"a"}, {}, {"b"}, {"c"}\]> : tensor<8x8x4x4x32>
Các quy tắc ràng buộc:
- Phải đáp ứng các điều kiện ràng buộc được liệt kê trong
Sdy_CollectiveOpInterface
. - Danh sách tham số không được để trống.
- Đối với mỗi tham số trong
params
:- Các phần tử trong
axes
phải đáp ứng các quy tắc ràng buộc củaAxisRefAttr
. src_dim
vàtgt_dim
phải là các phương diện hợp lệ (không âm và nhỏ hơn thứ hạng của tensor).- Mọi
src_dim
hoặctgt_dim
phải là duy nhất trên tất cả các tham số. src_dim
phải được sắp xếp theo thứ tự tăng dần trên tất cả các tham số.
- Các phần tử trong
- Việc di chuyển
axes
từsrc_dim
sangtgt_dim
trong quá trình phân đoạn toán hạng sẽ nhận đượcout_sharding
.
Đặc điểm: SameOperandsAndResultType
Giao diện: InferTypeOpInterface
, Sdy_CollectiveOpInterface
Thuộc tính:
Thuộc tính | Loại MLIR | Mô tả |
---|---|---|
params | ::mlir::sdy::AlltoAllParamListAttr | Danh sách tham số tất cả với tất cả |
out_sharding | ::mlir::sdy::TensorShardingAttr | Phân đoạn tensor |
Toán hạng:
Toán hạng | Mô tả |
---|---|
tensor |
tensor của bất kỳ giá trị loại nào |
Kết quả:
Kết quả | Mô tả |
---|---|
result |
tensor của bất kỳ giá trị loại nào |
sdy.collective_permute
(sdy::CollectivePermuteOp)
Thực hiện giao tiếp hoán vị tập thể để thay thế các trục
Cú pháp:
operation ::= `sdy.collective_permute` $tensor `out_sharding````=```$out_sharding attr-dict `:` type($result)
Gửi một phần của tensor đầu vào từ mỗi thiết bị sang thiết bị khác để sắp xếp lại/thay thế các trục phân đoạn tensor.
Một phép hoán vị tập hợp có thể biến đổi hoạt động phân đoạn đầu vào sao cho mỗi kích thước phải được phân đoạn như trước, tức là phải được phân đoạn dọc theo các trục có tích kích thước khớp với tích kích thước của các trục đã phân đoạn tensor trước đó.
Điều này hữu ích khi sắp xếp lại các trục trong một phương diện hoặc trên nhiều phương diện, cũng như hoán đổi các trục được phân đoạn với các trục được sao chép.
Trong ví dụ bên dưới, kích thước tensor phân đoạn là tensor<1x4x2xf32>
và kích thước này được giữ nguyên bằng cách hoán vị tập thể.
Ví dụ:
sdy.mesh @mesh = <["a"=2, "b"=2, "c"=4, "d"=2, "e"=2, "f"=2]>
%1 = stablehlo.tanh(%0) {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"a", "c"}, {"f"}, {"d", "e"}\]>]>} : tensor<8x8x8xf32>
%2 = sdy.collective_permute %1 out_sharding=<@mesh, [{"c":(1)2, "b", "f"}, {"a"}, {"e", "d"}\]> : tensor<8x8x8xf32>
Các quy tắc ràng buộc:
- Phải đáp ứng các điều kiện ràng buộc được liệt kê trong
Sdy_CollectiveOpInterface
. - Nếu phân đoạn đầu vào và đầu ra có các lưới khác nhau, thì các lưới đó phải có cùng một trục và thứ tự mã thiết bị khác nhau.
- Đối với mỗi phương diện, tích của kích thước trục phân đoạn trong
out_sharding
phải khớp với tích của kích thước phân đoạn phương diện toán hạng tương ứng.
Đặc điểm: SameOperandsAndResultType
Giao diện: CollectiveOpInterface
, InferTypeOpInterface
Thuộc tính:
Thuộc tính | Loại MLIR | Mô tả |
---|---|---|
out_sharding | ::mlir::sdy::TensorShardingAttr | Phân đoạn tensor |
Toán hạng:
Toán hạng | Mô tả |
---|---|
tensor |
tensor của bất kỳ giá trị loại nào |
Kết quả:
Kết quả | Mô tả |
---|---|
result |
tensor của bất kỳ giá trị loại nào |
sdy.constant
(sdy::ConstantOp)
Toán tử hằng
Tạo một tensor output
từ một hằng số value
.
Xem: https://github.com/openxla/stablehlo/blob/main/docs/spec.md#constant
Ví dụ:
%output = sdy.constant dense<[[0.0, 1.0], [2.0, 3.0]]> : tensor<2x2xf32>
Đặc điểm: AlwaysSpeculatableImplTrait
Giao diện: ConditionallySpeculatable
, InferTypeOpInterface
, NoMemoryEffect (MemoryEffectOpInterface)
Hiệu ứng: MemoryEffects::Effect{}
Thuộc tính:
Thuộc tính | Loại MLIR | Mô tả |
---|---|---|
value | ::mlir::ElementsAttr | thuộc tính vectơ/tensor không đổi |
Kết quả:
Kết quả | Mô tả |
---|---|
output |
tensor có hình dạng tĩnh của bất kỳ giá trị loại nào |
sdy.data_flow_edge
(sdy::DataFlowEdgeOp)
Toán tử cạnh luồng dữ liệu.
Cú pháp:
operation ::= `sdy.data_flow_edge` $input (`sharding````=``` $sharding^)? attr-dict `:` type($result)
Cạnh luồng dữ liệu của một số toán tử X xác định một cầu nối giữa một tập hợp các nguồn (mỗi nguồn là một toán hạng của X hoặc một toán hạng của trình kết thúc khối của X) và một tập hợp các mục tiêu (mỗi mục tiêu là kết quả của X hoặc đối số khối của X), sao cho tất cả các nguồn và mục tiêu phải được phân đoạn theo cùng một cách.
Một toán tử có thể có nhiều cạnh luồng dữ liệu vuông góc với nhau.
Ví dụ:
y_0, ..., y_n = while (x_0, ..., x_n)
((pred_arg_0,... , pred_arg_n) { ... })
((body_arg_0,..., body_arg_n) {
...
return return_value_0, ..., return_value_n
})
Toán tử while này có n cạnh luồng dữ liệu, cạnh luồng dữ liệu thứ i nằm giữa nguồn x_i
, return_value_i
và đích y_i
, pred_arg_i
, body_arg_i
.
sdy.data_flow_edge
lấy làm đầu vào chủ sở hữu của một cạnh (có thể là bất kỳ mục tiêu nào, nhưng tốt nhất là kết quả của toán tử thay vì đối số khối), không được có bất kỳ mục đích sử dụng nào khác. Toán tử này không thuần tuý vì có thể nhận dữ liệu đầu vào ban đầu không có bất kỳ giá trị sử dụng nào.
sdy.data_flow_edge
cũng chứa một phân đoạn không bắt buộc cho tất cả các mục tiêu của cạnh và phân đoạn đó phải được cập nhật thay vì phân đoạn của mục tiêu (nếu có thể đính kèm) trong quá trình truyền tải. Điều này rất hữu ích khi một toán tử có nhiều cạnh, vì việc này sẽ hiệu quả hơn nhiều khi:
- truyền tải riêng qua từng cạnh.
- cập nhật từng phân đoạn cạnh riêng biệt thay vì tất cả các mục tiêu cùng một lúc (ví dụ: một toán tử có một
TensorShardingPerValueAttr
không thể thay đổi duy nhất cho các phân đoạn kết quả). - thêm từng cạnh vào danh sách công việc riêng biệt khi phân đoạn của một nguồn đã thay đổi.
Quá trình truyền tải sẽ truyền tải các phân đoạn giữa tất cả các nguồn và đích của sdy.data_flow_edge
như thể đó là một toán tử thông thường với các nguồn là toán hạng và đích là kết quả, cũng như một giá trị nhận dạng sdy.op_sharding_rule
. Điều đó có nghĩa là quá trình truyền về phía trước là từ nguồn đến đích và quá trình truyền về phía sau là từ đích đến nguồn.
Chúng tôi không cho phép toán tử SdyDialect
xác định đầu vào của sdy.data_flow_edge
, vì vậy, chúng ta có thể giả định rằng toán tử này được xác định bằng một toán tử có thuộc tính sdy.sharding
chưa đăng ký.
Đặc điểm: SameOperandsAndResultType
Giao diện: InferTypeOpInterface
Thuộc tính:
Thuộc tính | Loại MLIR | Mô tả |
---|---|---|
sharding | ::mlir::sdy::TensorShardingAttr | Phân đoạn tensor |
Toán hạng:
Toán hạng | Mô tả |
---|---|
input |
có hình dạng của bất kỳ giá trị loại nào |
Kết quả:
Kết quả | Mô tả |
---|---|
result |
có hình dạng của bất kỳ giá trị loại nào |
sdy.manual_computation
(sdy::ManualComputationOp)
Thao tác song song trên nhiều thiết bị bằng các tập hợp thủ công
Cú pháp:
operation ::= `sdy.manual_computation` `(`operands`)`
`in_shardings````=```custom<StrippedTensorShardingPerValueAttr>($in_shardings)
`out_shardings````=```custom<StrippedTensorShardingPerValueAttr>($out_shardings)
`manual_axes````=```$manual_axes
custom<SingleBlockRegionNoBlockId>($body)
attr-dict
`:`
functional-type(operands, results)
Chuyển đến một vùng được viết bằng mã cục bộ trên mỗi thiết bị với các tập hợp rõ ràng, trong đó các hình dạng logic khớp với các hình dạng vùng đệm vật lý cục bộ trên mỗi thiết bị và các tập hợp tương ứng chính xác với giao tiếp vật lý trên nhiều thiết bị.
Thân xe là cục bộ so với manual_axes. Quá trình truyền sẽ diễn ra thông qua thân trên mọi trục tự do – những trục không có trong danh sách manual_axes.
Các quy tắc ràng buộc:
- Các phần tử trong
in_shardings
vàout_shardings
phải đáp ứng các quy tắc ràng buộc được liệt kê trongTensorShardingAttr
. - Số lượng đầu vào/đầu ra tensor toàn cục và cục bộ của vùng toán tử phải khớp nhau.
- Trục thủ công phải đứng trước mọi trục tự do trong mỗi phân đoạn chiều.
- Trục thủ công không thể thêm khoảng đệm. Cụ thể, kích thước phương diện phải chia hết cho kích thước trục thủ công tương ứng.
- Hình dạng toàn cục và cục bộ của đối số/kết quả của vùng hoạt động phải khớp với nhau.
- Không có trục thủ công nào được tách.
Thuộc tính: IsolatedFromAbove
, RecursiveMemoryEffects
, SingleBlockImplicitTerminator<ReturnOp>
, SingleBlock
Giao diện: ShardableDataFlowOpInterface
Thuộc tính:
Thuộc tính | Loại MLIR | Mô tả |
---|---|---|
in_shardings | ::mlir::sdy::TensorShardingPerValueAttr | Phân đoạn tensor theo toán hạng/kết quả của một toán tử |
out_shardings | ::mlir::sdy::TensorShardingPerValueAttr | Phân đoạn tensor theo toán hạng/kết quả của một toán tử |
manual_axes | ::mlir::sdy::ManualAxesAttr | Danh sách các trục mà ManualComputationOp là thủ công |
Toán hạng:
Toán hạng | Mô tả |
---|---|
tensors |
biến của tensor được xếp hạng của bất kỳ giá trị loại nào |
Kết quả:
Kết quả | Mô tả |
---|---|
results |
biến của tensor được xếp hạng của bất kỳ giá trị loại nào |
sdy.mesh
(sdy::MeshOp)
Lưới được đặt tên
Cú pháp:
operation ::= `sdy.mesh` $sym_name `=` $mesh attr-dict
Xác định một lưới có tên mới. Tất cả lưới trong một mô-đun phải có cùng số lượng thiết bị (ngoại trừ lưới có một device_id).
Lưới là một toán tử Symbol
xuất hiện trong SymbolTable
của mô-đun và có thể được tham chiếu bằng name
của mô-đun đó.
Đặc điểm: HasParent<ModuleOp>
Giao diện: Symbol
Thuộc tính:
Thuộc tính | Loại MLIR | Mô tả |
---|---|---|
sym_name | ::mlir::StringAttr | thuộc tính chuỗi |
mesh | ::mlir::sdy::MeshAttr | Lưới trục và danh sách thiết bị |
sdy.named_computation
(sdy::NamedComputationOp)
Toán tử tính toán được đặt tên
Cú pháp:
operation ::= `sdy.named_computation` `<`$name`>` `` `(` $operands `)`
(`in_shardings````=```custom<StrippedTensorShardingPerValueAttr>($in_shardings)^)?
(`out_shardings````=```custom<StrippedTensorShardingPerValueAttr>($out_shardings)^)?
custom<SingleBlockRegionNoBlockId>($body)
attr-dict
`:` functional-type($operands, results)
Nhóm một phép tính, tức là một khối các phép toán và đặt tên cho phép tính đó. Quá trình truyền sẽ diễn ra vào/ra khỏi vùng này như thể mọi thứ đều được cùng dòng.
Bạn có thể dùng phương thức này để xử lý việc truyền thông qua các lệnh gọi đến các hàm khác. Mọi người dùng Shardy đều phải viết một thẻ nhập/xuất chuyển đổi các thao tác gọi thành thao tác sdy.named_computation
, sao chép/sao chép nội dung của hàm được gọi vào nội dung của named_computation
.
Loại của mỗi đối số khối và giá trị trả về trong vùng phải giống với loại toán hạng và loại kết quả của toán tử.
Ví dụ:
%1 = sdy.named_computation<"foo">(%0) (%arg1: tensor<16x32xf32>) {
sdy.return %arg1 : tensor<16x32xf32>
} : (tensor<16x32xf32>) -> tensor<16x32xf32>
Thuộc tính: IsolatedFromAbove
, RecursiveMemoryEffects
, RecursivelySpeculatableImplTrait
, SingleBlockImplicitTerminator<ReturnOp>
, SingleBlock
Giao diện: ConditionallySpeculatable
, InferTypeOpInterface
, ShardableDataFlowOpInterface
Thuộc tính:
Thuộc tính | Loại MLIR | Mô tả |
---|---|---|
name | ::mlir::StringAttr | thuộc tính chuỗi |
in_shardings | ::mlir::sdy::TensorShardingPerValueAttr | Phân đoạn tensor theo toán hạng/kết quả của một toán tử |
out_shardings | ::mlir::sdy::TensorShardingPerValueAttr | Phân đoạn tensor theo toán hạng/kết quả của một toán tử |
Toán hạng:
Toán hạng | Mô tả |
---|---|
operands |
biến kiểu bất kỳ |
Kết quả:
Kết quả | Mô tả |
---|---|
«unnamed» | biến kiểu bất kỳ |
sdy.propagation_barrier
(sdy::PropagationBarrierOp)
Thao tác rào cản truyền tải
Cú pháp:
operation ::= `sdy.propagation_barrier` $input `allowed_direction````=```$allowed_direction attr-dict `:` type($input)
Toán tử này hoạt động giống như toán tử nhận dạng, xuất ra cùng một giá trị đã lấy làm đầu vào. Nhưng về mặt truyền tải, điều này sẽ chỉ cho phép truyền tải chảy qua theo một hướng nhất định.
Điều này ngăn việc phân đoạn được truyền giữa các lần sử dụng kết quả của toán tử rào cản và toán hạng của toán tử đó.
FORWARD
có nghĩa là các phần phân đoạn chỉ có thể chuyển từ toán hạng đến kết quả.BACKWARD
có nghĩa là các phần phân đoạn chỉ có thể chuyển từ kết quả đến toán hạng.NONE
có nghĩa là không có hoạt động phân đoạn nào có thể truyền qua toán tử này.- Không thể chỉ định
BOTH
vì toán tử này sẽ thừa.
Thuộc tính: AlwaysSpeculatableImplTrait
, SameOperandsAndResultType
Giao diện: ConditionallySpeculatable
, InferTypeOpInterface
, NoMemoryEffect (MemoryEffectOpInterface)
Hiệu ứng: MemoryEffects::Effect{}
Thuộc tính:
Thuộc tính | Loại MLIR | Mô tả |
---|---|---|
allowed_direction | ::mlir::sdy::PropagationDirectionAttr | enum hướng truyền |
Toán hạng:
Toán hạng | Mô tả |
---|---|
input |
tensor được xếp hạng của bất kỳ giá trị loại nào |
Kết quả:
Kết quả | Mô tả |
---|---|
result |
tensor được xếp hạng của bất kỳ giá trị loại nào |
sdy.reshard
(sdy::ReshardOp)
Phân đoạn lại một tensor thành một phân đoạn khác
Cú pháp:
operation ::= `sdy.reshard` $input $sharding attr-dict `:` type($result)
Phân đoạn lại tensor đầu vào bằng phương thức phân đoạn đã chỉ định, khác với phương thức phân đoạn hiện có của tensor đầu vào.
Cả ShardingConstraintOp và ReshardOp đều đính kèm một phân đoạn vào một tensor. Thời gian hoạt động của chúng là:
- Trước khi truyền phân đoạn, người dùng sẽ thêm ShardingConstraintOp.
- Quá trình truyền tải phân đoạn sẽ sử dụng ShardingConstraintOp. Không có ShardingConstraintOp nào trong kết quả của quá trình truyền dữ liệu phân đoạn. Thay vào đó, bạn có thể thêm ReshardOp nếu cần.
- Trình phân vùng chuyển đổi ReshardOp thành một toán tử tập hợp (hoặc toán tử nhận dạng). Không được có ReshardOp trong kết quả của trình phân vùng.
// TODO(b/331680067). Thêm mẫu chuẩn hoá để xoá các thao tác phân đoạn lại // thừa.
Thuộc tính: AlwaysSpeculatableImplTrait
, SameOperandsAndResultType
Giao diện: ConditionallySpeculatable
, InferTypeOpInterface
, NoMemoryEffect (MemoryEffectOpInterface)
Hiệu ứng: MemoryEffects::Effect{}
Thuộc tính:
Thuộc tính | Loại MLIR | Mô tả |
---|---|---|
sharding | ::mlir::sdy::TensorShardingAttr | Phân đoạn tensor |
Toán hạng:
Toán hạng | Mô tả |
---|---|
input |
tensor của bất kỳ giá trị loại nào |
Kết quả:
Kết quả | Mô tả |
---|---|
result |
tensor của bất kỳ giá trị loại nào |
sdy.return
(sdy::ReturnOp)
Thao tác sdy.return
chấm dứt các vùng được đính kèm vào các thao tác dựa trên vùng sdy
và mọi thao tác dựa trên vùng Shardy khác. Đây là một hàm biến: hàm này lấy danh sách các giá trị có kiểu bất kỳ làm đối số (nhưng cùng loại, ví dụ: AnyTensor
) và do đó có thể được sử dụng lại ở nhiều cấp của ngăn xếp IR Shardy.
Cú pháp:
operation ::= `sdy.return` attr-dict ($results^ `:` type($results))?
Đặc điểm: AlwaysSpeculatableImplTrait
, Terminator
Giao diện: ConditionallySpeculatable
, NoMemoryEffect (MemoryEffectOpInterface)
Hiệu ứng: MemoryEffects::Effect{}
Toán hạng:
Toán hạng | Mô tả |
---|---|
results |
biến kiểu bất kỳ |
sdy.sharding_constraint
(sdy::ShardingConstraintOp)
Ràng buộc một tensor với phân đoạn đã chỉ định
Cú pháp:
operation ::= `sdy.sharding_constraint` $input $sharding attr-dict `:` type($result)
Đính kèm một phân đoạn vào một tensor trung gian (ví dụ: kết quả của matmul) để cho biết đây là cách phân đoạn tensor đó hoặc một tập hợp con của các mục sử dụng tensor đó.
Nếu phân đoạn có các phương diện mở và trục không ràng buộc, thì điều đó có nghĩa là tensor có thể được phân đoạn thêm dọc theo các phương diện mở.
Thao tác này có thể:
- Không có mục đích sử dụng (dangling) – tức là việc phân đoạn đính kèm là cách phân đoạn chính tensor đầu vào.
- Có trường hợp sử dụng – nghĩa là phân đoạn đính kèm là cách phân đoạn các trường hợp sử dụng của toán tử ràng buộc phân đoạn, trong khi các trường hợp sử dụng khác của tensor đầu vào có thể có phân đoạn khác (nếu tensor đầu vào không có trường hợp sử dụng nào khác thì hành vi sẽ giống như trường hợp không sử dụng).
Đặc điểm: SameOperandsAndResultType
Giao diện: InferTypeOpInterface
Thuộc tính:
Thuộc tính | Loại MLIR | Mô tả |
---|---|---|
sharding | ::mlir::sdy::TensorShardingAttr | Phân đoạn tensor |
Toán hạng:
Toán hạng | Mô tả |
---|---|
input |
tensor của bất kỳ giá trị loại nào |
Kết quả:
Kết quả | Mô tả |
---|---|
result |
tensor của bất kỳ giá trị loại nào |
sdy.sharding_group
(sdy::ShardingGroupOp)
Giới hạn các tensor trong nhóm để có cùng một phân đoạn.
Cú pháp:
operation ::= `sdy.sharding_group` $input `group_id````=```$group_id attr-dict `:` type($input)
Toán tử này cung cấp một giao diện để gán tensor cho các nhóm phân đoạn ( các nhóm tensor sẽ được thực thi để có các phân đoạn giống hệt nhau). Trong quá trình truyền tải, ngay khi một phần tử nhóm được phân đoạn, tất cả các thành phần khác sẽ được phân đoạn theo đúng cách. Thao tác này lấy mã nhận dạng nhóm đối số và không trả về kết quả nào, nhưng thay vào đó, sửa đổi nội dung trình bày nhóm phân đoạn nội bộ để thêm tensor đầu vào vào nhóm có mã nhận dạng đã cho.
Giao diện: InferTypeOpInterface
Thuộc tính:
Thuộc tính | Loại MLIR | Mô tả |
---|---|---|
group_id | ::mlir::IntegerAttr | Thuộc tính số nguyên 64 bit không dấu |
Toán hạng:
Toán hạng | Mô tả |
---|---|
input |
tensor được xếp hạng của bất kỳ giá trị loại nào |
Thuộc tính
AllToAllParamAttr
Thông số tất cả với tất cả
Cú pháp:
#sdy.all_to_all_param<
::llvm::ArrayRef<AxisRefAttr>, # axes
int64_t, # src_dim
int64_t # tgt_dim
>
Một bộ chứa các trục và phương diện nguồn/đích để thực hiện tính năng tất cả với tất cả.
Các thông số:
Thông số | Loại C++ | Mô tả |
---|---|---|
trục | ::llvm::ArrayRef<AxisRefAttr> |
các trục để thực hiện tất cả với tất cả |
src_dim | int64_t |
chỉ mục phương diện nguồn |
tgt_dim | int64_t |
chỉ mục phương diện mục tiêu |
AlltoAllParamListAttr
Danh sách tham số tất cả với tất cả
Cú pháp:
#sdy.all_to_all_param_list<
::llvm::ArrayRef<AllToAllParamAttr> # value
>
Các thông số:
Thông số | Loại C++ | Mô tả |
---|---|---|
value | ::llvm::ArrayRef<AllToAllParamAttr> |
AxisRefAttr
Tham chiếu đến một trục đầy đủ hoặc một trục phụ được chia
Cú pháp:
#sdy.axis_ref<
::llvm::StringRef, # name
SubAxisInfoAttr # sub_axis_info
>
Các quy tắc ràng buộc:
name
phải có trongMeshAttr
đã liên kết.- Nếu có
sub_axis_info
, thìsub_axis_info
phải đáp ứng các quy tắc ràng buộc củaSubAxisInfoAttr
.
Các thông số:
Thông số | Loại C++ | Mô tả |
---|---|---|
tên | ::llvm::StringRef |
tên của trục này |
sub_axis_info | SubAxisInfoAttr |
thông tin bổ sung nếu đây là trục phụ |
AxisRefListAttr
Danh sách tệp tham chiếu trục
Cú pháp:
#sdy.axis_ref_list<
::llvm::ArrayRef<AxisRefAttr> # value
>
Các quy tắc ràng buộc:
- Các phần tử trong
value
phải đáp ứng các quy tắc ràng buộc củaAxisRefAttr
. - Không có tham chiếu trục hoặc trục phụ trùng lặp chồng chéo lên nhau.
- Không có hai tệp tham chiếu trục liền kề nào là trục phụ liên tiếp của cùng một trục đầy đủ, nghĩa là chúng có thể được hợp nhất thành một trục phụ hoặc trục đầy đủ.
Các thông số:
Thông số | Loại C++ | Mô tả |
---|---|---|
value | ::llvm::ArrayRef<AxisRefAttr> |
DimMappingAttr
Danh sách chỉ số nhân tố cho một phương diện
Danh sách trống cho biết đây là mối liên kết rỗng (được phân tích cú pháp/in bằng *
), tức là phương diện không được liên kết với bất kỳ yếu tố nào.
Các quy tắc ràng buộc:
- Có ít nhất một chỉ mục nhân tố.
- Chỉ số hệ số phải nằm trong khoảng [0,
$factor_sizes
). - Nếu có nhiều yếu tố, thì không có yếu tố nào có thể có kích thước 1.
- Không có chỉ mục nhân tố trùng lặp.
Các thông số:
Thông số | Loại C++ | Mô tả |
---|---|---|
factor_indices | ::llvm::ArrayRef<int64_t> |
các yếu tố mà phương diện này được liên kết |
DimensionShardingAttr
Phân đoạn phương diện
Danh sách tên trục để phân đoạn một phương diện tensor từ chính đến phụ, một giá trị boolean cho biết liệu phương diện có thể được phân đoạn thêm hay không và một số nguyên tuỳ chọn biểu thị mức độ ưu tiên của việc phân đoạn phương diện này. Mức độ ưu tiên này sẽ được tuân thủ trong quá trình truyền tải phân đoạn. Mức độ ưu tiên bắt nguồn từ các chú thích phân đoạn người dùng và giá trị thấp hơn biểu thị mức độ ưu tiên cao hơn. Mức độ ưu tiên cao nhất được giả định khi mức độ ưu tiên bị thiếu trong chú thích.
Các quy tắc ràng buộc:
- Các phần tử trong
axes
phải đáp ứng các quy tắc ràng buộc được liệt kê trongAxisRefListAttr
. - Nếu một phương diện phân đoạn có mức độ ưu tiên:
- Mức độ ưu tiên lớn hơn hoặc bằng 0.
- Phương diện có ít nhất một trục nếu phương diện đó đóng.
Các thông số:
Thông số | Loại C++ | Mô tả |
---|---|---|
trục | ::llvm::ArrayRef<AxisRefAttr> |
tệp tham chiếu trục |
is_closed | bool |
liệu bạn có thể phân đoạn thêm phương diện này hay không |
của chiến dịch | std::optional<int64_t> |
mức độ ưu tiên được sử dụng trong quá trình truyền tải dựa trên mức độ ưu tiên của người dùng |
ListOfAxisRefListsAttr
Danh sách danh sách tham chiếu trục
Cú pháp:
#sdy.list_of_axis_ref_lists<
::llvm::ArrayRef<AxisRefListAttr> # value
>
Các thông số:
Thông số | Loại C++ | Mô tả |
---|---|---|
value | ::llvm::ArrayRef<AxisRefListAttr> |
ManualAxesAttr
Danh sách các trục mà ManualComputationOp là thủ công
Cú pháp:
#sdy.manual_axes<
::llvm::ArrayRef<StringAttr> # value
>
Các thông số:
Thông số | Loại C++ | Mô tả |
---|---|---|
value | ::llvm::ArrayRef<StringAttr> |
MeshAttr
Lưới trục và danh sách thiết bị
Cú pháp:
#sdy.mesh<
::llvm::ArrayRef<MeshAxisAttr>, # axes
::llvm::ArrayRef<int64_t> # device_ids
>
Lưới là danh sách các trục và danh sách mã thiết bị (không bắt buộc) chỉ định thứ tự thiết bị.
Nếu danh sách trục trống, thì lưới sẽ có một trục ẩn không tên có kích thước 1. Trong trường hợp này, nếu bạn không cung cấp danh sách mã thiết bị, thì danh sách mã thiết bị ngầm ẩn sẽ là [0]; nếu bạn cung cấp danh sách mã thiết bị, thì danh sách đó phải chứa một số nguyên bất kỳ có giá trị không âm. Chúng tôi gọi trường hợp này là phân đoạn tối đa.
Đối với tất cả các trường hợp không phân đoạn tối đa, nếu bạn chỉ định danh sách mã thiết bị, thì tích của kích thước trục phải khớp với số lượng thiết bị. Nếu bạn không chỉ định danh sách mã thiết bị, thì danh sách mã thiết bị ngầm ẩn sẽ là iota(product(axes)). Để đơn giản, chúng tôi cũng không cho phép chỉ định danh sách mã thiết bị giống với iota(product(axes)); trong trường hợp này, bạn không nên chỉ định danh sách mã thiết bị.
Dưới đây là một số ví dụ về lưới:
- Lưới trống đại diện cho lưới giữ chỗ có thể được thay thế trong quá trình truyền: <[]>
- Một lưới có trục không tên và mã thiết bị rõ ràng, thường được dùng để biểu thị phân đoạn tối đa: <[], device_ids=[3]>
- Một lưới có hai trục và mã thiết bị ngầm ẩn iota(6): <["a"=2, "b"=3]>
- Một lưới có hai trục và mã thiết bị rõ ràng chỉ định thứ tự thiết bị: <["a"=3, "b"=2], device_ids=[0, 2, 4, 1, 3, 5]>
Các quy tắc ràng buộc:
- Các phần tử trong
axes
không được có tên trùng lặp. - Nếu bạn chỉ định
device_ids
:- Sản phẩm của kích thước trục phải khớp với số lượng thiết bị.
- Tất cả các phần tử của mảng này đều phải là số không âm.
device_ids
không được bằngiota(product(axis_sizes))
.device_ids
đã sắp xếp phải làiota(product(axis_sizes))
.
Các thông số:
Thông số | Loại C++ | Mô tả |
---|---|---|
trục | ::llvm::ArrayRef<MeshAxisAttr> |
trục lưới |
device_ids | ::llvm::ArrayRef<int64_t> |
thứ tự thiết bị rõ ràng hoặc mã thiết bị tối đa |
MeshAxisAttr
Trục được đặt tên trong lưới
Cú pháp:
#sdy.mesh_axis<
::llvm::StringRef, # name
int64_t # size
>
Các thông số:
Thông số | Loại C++ | Mô tả |
---|---|---|
tên | ::llvm::StringRef |
tên |
size | int64_t |
kích thước của trục này |
OpShardingRuleAttr
Chỉ định cách phân vùng một toán tử.
Cú pháp:
#sdy.op_sharding_rule<
::llvm::ArrayRef<int64_t>, # factor_sizes
::llvm::ArrayRef<TensorMappingAttr>, # operand_mappings
::llvm::ArrayRef<TensorMappingAttr>, # result_mappings
::llvm::ArrayRef<int64_t>, # reduction_factors
::llvm::ArrayRef<int64_t>, # need_replication_factors
::llvm::ArrayRef<int64_t>, # permutation_factors
::llvm::ArrayRef<int64_t>, # blocked_propagation_factors
bool # is_custom_rule
>
Quy tắc phân đoạn chỉ định cách phân đoạn một toán tử theo nhiều thuộc tính trên toán tử – bất kỳ thuộc tính nào, hình dạng của toán hạng, hình dạng của kết quả, v.v. Ví dụ:
%0 = stablehlo.add %arg0, %arg1 {
sdy.sharding_rule = #sdy.op_sharding_rule<
([i, j],[i, j])->([i, j])
{i=8, j=8}>
} : tensor<8x8xf32>
%1 = stablehlo.dot_general %arg2, %arg3, contracting_dims = [1] x [0] {
sdy.sharding_rule = #sdy.op_sharding_rule<
([i, k],[k, j])->([i, j])
{i=8, j=16, k=8}>
}: (tensor<8x8xf32>, tensor<8x16xf32>) -> tensor<8x16xf32>
Xin lưu ý rằng chúng tôi cho phép các hệ số có kích thước 1 mặc dù không thể phân đoạn được, điều này chủ yếu là để đảm bảo tính đầy đủ vì nhiều toán tử như toán tử theo điểm có kích thước một phương diện tương ứng với các toán hạng và kết quả.
Loại yếu tố:
reduction_factors
chứa các chỉ mục của các hệ số cần giảm, chẳng hạn như các phương diện thu hẹp trong phép toán dấu chấm.need_replication_factors
chứa các chỉ mục của các yếu tố yêu cầu sao chép đầy đủ, chẳng hạn như phương diện đã sắp xếp trong một thao tác sắp xếp.permutation_factors
chứa các chỉ mục của các yếu tố yêu cầu hoán vị tập thể nếu các yếu tố đó được phân đoạn, chẳng hạn như kích thước khoảng đệm trong thao tác đệm.- Tất cả các yếu tố khác được coi là yếu tố truyền qua, tức là các yếu tố không yêu cầu bất kỳ hoạt động giao tiếp nào nếu được phân đoạn theo cùng một cách trên tất cả các tensor được liên kết với chúng.
blocked_propagation_factors
chứa các yếu tố theo đó không được phép phân đoạn. Phương diện này vuông góc với các loại yếu tố. Cụ thể,
yếu tố chặn truyền có thể là bất kỳ loại yếu tố nào.
is_custom_rule
mô tả liệu đây có phải là quy tắc do người dùng xác định hay không. Người dùng có thể xác định quy tắc phân đoạn cho các lệnh gọi tuỳ chỉnh hoặc ghi đè quy tắc phân đoạn được xác định trước cho các thao tác chuẩn. Quy tắc tuỳ chỉnh luôn được giữ nguyên/không bao giờ bị xoá.
Các quy tắc ràng buộc:
- Số lượt ánh xạ toán hạng/kết quả phải khớp với số lượng toán hạng/kết quả của toán tử.
- Có ít nhất một mối liên kết (không thể có quy tắc cho một toán tử không có toán hạng/kết quả).
- Hạng của mỗi
TensorMappingAttr
khớp với hạng của loại tensor tương ứng. - Đối với mỗi nhóm yếu tố (
reduction_factors
,need_replication_factors
,permutation_factors
):- Các phần tử phải nằm trong khoảng [0,
$factor_sizes
]. - Không có chỉ mục nhân tố trùng lặp trong mỗi nhóm và giữa các nhóm.
- Các phần tử phải nằm trong khoảng [0,
Các thông số:
Thông số | Loại C++ | Mô tả |
---|---|---|
factor_sizes | ::llvm::ArrayRef<int64_t> |
kích thước của tất cả các hệ số trong quy tắc này |
operand_mappings | ::llvm::ArrayRef<TensorMappingAttr> |
ánh xạ toán hạng |
result_mappings | ::llvm::ArrayRef<TensorMappingAttr> |
ánh xạ kết quả |
reduction_factors | ::llvm::ArrayRef<int64_t> |
các yếu tố cần giảm |
need_replication_factors | ::llvm::ArrayRef<int64_t> |
các yếu tố yêu cầu sao chép toàn bộ |
permutation_factors | ::llvm::ArrayRef<int64_t> |
các yếu tố yêu cầu hoán vị tập thể |
blocked_propagation_factors | ::llvm::ArrayRef<int64_t> |
các yếu tố theo đó không phân đoạn |
is_custom_rule | bool |
liệu quy tắc có phải là cho stablehlo.custom_call hay không |
SubAxisInfoAttr
Thông tin về cách trục phụ này được lấy từ trục đầy đủ
Cú pháp:
#sdy.sub_axis_info<
int64_t, # pre_size
int64_t # size
>
Khi chia một trục đầy đủ thành n trục phụ, trục này được định hình lại thành [k_1,...,k_n] và trục phụ thứ i có thể được biểu thị bằng tích của tất cả kích thước trục ở bên trái m=prod(k_1,...,k_(i-1))
(còn gọi là kích thước trước) và kích thước k_i. Do đó, thuộc tính sub-axis-info chứa hai số đó và được ký hiệu như sau: (m)k
cho kích thước trước m và kích thước k.
Các quy tắc ràng buộc:
pre-size
tối thiểu là 1.size
lớn hơn 1.pre-size
phải chia kích thước của trục đầy đủ, tức là cảpre-size
vàsize
chia kích thước của trục đầy đủ và trục phụ không vượt quá trục đầy đủ.- Kích thước của trục phụ không bằng kích thước của trục đầy đủ tương ứng, trong trường hợp này, bạn nên sử dụng trục đầy đủ.
Các thông số:
Thông số | Loại C++ | Mô tả |
---|---|---|
pre_size | int64_t |
tích của các kích thước trục phụ ở bên trái trục phụ này |
size | int64_t |
kích thước của trục phụ này |
TensorMappingAttr
Các ánh xạ nhân tố cho mỗi phương diện của một tensor.
Cú pháp:
#sdy.tensor_mapping<
::llvm::ArrayRef<DimMappingAttr> # dim_mappings
>
Các quy tắc ràng buộc:
- Các phần tử trong
dim_mappings
phải đáp ứng các quy tắc ràng buộc trongDimMappingAttr
. - Không có chỉ mục nhân tố trùng lặp trên các phương diện.
Các thông số:
Thông số | Loại C++ | Mô tả |
---|---|---|
dim_mappings | ::llvm::ArrayRef<DimMappingAttr> |
mối liên kết phương diện |
TensorShardingAttr
Phân đoạn tensor
Cú pháp:
#sdy.sharding<
::mlir::Attribute, # mesh_or_ref
::llvm::ArrayRef<DimensionShardingAttr>, # dim_shardings
::llvm::ArrayRef<AxisRefAttr> # replicated_axes
>
Phân đoạn tensor được liên kết với một lưới cụ thể và chỉ có thể tham chiếu tên trục từ lưới đó. Các phân đoạn theo phương diện cho chúng ta biết về mỗi phương diện của tensor, theo đó các trục (hoặc trục phụ) được phân đoạn từ chính đến phụ. Tất cả các trục khác không phân đoạn một phương diện sẽ được sao chép ngầm hoặc rõ ràng (nếu xuất hiện trong danh sách trục được sao chép).
Lưới liên kết với phân đoạn này có thể được chỉ định bằng tên biểu tượng, tham chiếu đến biểu tượng MeshOp
tương ứng hoặc MeshAttr
nội tuyến.
Các quy tắc ràng buộc:
- Các phần tử trong
dim_shardings
phải đáp ứng các quy tắc ràng buộc được liệt kê trongDimensionShardingAttr
. - Các phần tử trong
replicated_axes
phải đáp ứng các quy tắc ràng buộc được liệt kê trongAxisRefListAttr
. - Nếu loại tensor tương ứng không phải là
ShapedType
, thì phân đoạn phải có thứ hạng 0 và không có trục được sao chép. - Tensor phải có thứ hạng.
- Số lượng phân đoạn theo phương diện bằng với thứ hạng của tensor.
- Phương diện có kích thước 0 không được phân đoạn.
- Các mục trong
replicated_axes
được sắp xếp theomesh_or_ref
(xemAxisRefAttr::getMeshComparator
).
Các thông số:
Thông số | Loại C++ | Mô tả |
---|---|---|
mesh_or_ref | ::mlir::Attribute |
thuộc tính lưới hoặc thuộc tính tham chiếu biểu tượng lưới phẳng |
dim_shardings | ::llvm::ArrayRef<DimensionShardingAttr> |
phân đoạn phương diện |
replicated_axes | ::llvm::ArrayRef<AxisRefAttr> |
tệp tham chiếu trục |
TensorShardingPerValueAttr
Phân đoạn tensor theo toán hạng/kết quả của một toán tử
Cú pháp:
#sdy.sharding_per_value<
::llvm::ArrayRef<TensorShardingAttr> # shardings
>
Danh sách TensorShardingAttr
, mỗi TensorShardingAttr
cho một toán hạng/kết quả của một toán tử.
Các quy tắc ràng buộc:
- Các phần tử trong
shardings
phải đáp ứng các quy tắc ràng buộc củaTensorShardingAttr
.
Các thông số:
Thông số | Loại C++ | Mô tả |
---|---|---|
phân đoạn | ::llvm::ArrayRef<TensorShardingAttr> |
phân đoạn theo giá trị |
Enum
PropagationDirection
Enum hướng truyền
Trường hợp:
Biểu tượng | Giá trị | Chuỗi |
---|---|---|
KHÔNG CÓ | 0 |
KHÔNG CÓ |
FORWARD | 1 |
FORWARD |
QUAY LẠI | 2 |
QUAY LẠI |
CẢ HAI BÊN | 3 |
CẢ HAI BÊN |