Định nghĩa OpInterface

ShardableDataFlowOpInterface (ShardableDataFlowOpInterface)

Giao diện toán tử cho phép phân đoạn truyền tải các phân đoạn thông qua các cạnh luồng dữ liệu của các toán tử mở rộng giao diện này.

Cạnh luồng dữ liệu của một số toán tử X xác định một cầu nối giữa một tập hợp các nguồn (mỗi nguồn là một toán hạng của X hoặc một toán hạng của trình kết thúc khối của X) và một tập hợp các mục tiêu (mỗi mục tiêu là kết quả của X hoặc đối số khối của X), sao cho tất cả các nguồn và mục tiêu phải được phân đoạn theo cùng một cách. Một toán tử có thể có nhiều cạnh luồng dữ liệu vuông góc với nhau.

Chủ sở hữu là mục tiêu do người dùng chỉ định của cạnh luồng dữ liệu được sử dụng trong quá trình truyền của shardy. Người dùng có thể chọn giá trị này tuỳ ý nhưng giá trị này cần phải tĩnh.

Ví dụ:

  y_1, ..., y_n = custom_op (x_1, ..., x_n)
                  ((body_arg_1,..., body_arg_n) {
                    ...
                    return return_value_1, ..., return_value_n
                  })

custom_op này có hai loại cho các cạnh luồng dữ liệu, n cạnh giữa return_value_i (nguồn) và y_i (đích) và n cạnh giữa x_i(nguồn) và body_arg_i(đích). Trong trường hợp này, chủ sở hữu cạnh giống với mục tiêu.

Sau đây là ví dụ về một hoạt động có nhiều mục tiêu:

  y_0, ..., y_n = while (x_0, ..., x_n)
                  ((pred_arg_0,... , pred_arg_n) { ... })
                  ((body_arg_0,..., body_arg_n) {
                    ...
                    return return_value_0, ..., return_value_n
                  })

Toán tử while này có n cạnh luồng dữ liệu, cạnh luồng dữ liệu thứ i nằm giữa nguồn x_i, return_value_i và đích y_i, pred_arg_i, body_arg_i.

Phương thức:

getBlockArgumentEdgeOwnerShardings

mlir::ArrayRef<mlir::sdy::TensorShardingAttr> getBlockArgumentEdgeOwnerShardings();

Trả về các phân đoạn của tất cả chủ sở hữu cạnh luồng dữ liệu đối số khối.

setBlockArgumentEdgeOwnerSharding

void setBlockArgumentEdgeOwnerSharding(unsigned index, mlir::sdy::TensorShardingAttr sharding);

Đặt sharding của chủ sở hữu cạnh đối số của khối bằng index đã cho.

setBlockArgumentEdgeOwnerShardings

void setBlockArgumentEdgeOwnerShardings(mlir::ArrayRef<mlir::sdy::TensorShardingAttr> shardings);

Đặt shardings của tất cả chủ sở hữu cạnh đối số khối.

getOpResultEdgeOwnerShardings

mlir::ArrayRef<mlir::sdy::TensorShardingAttr> getOpResultEdgeOwnerShardings();

Trả về các phân đoạn của tất cả chủ sở hữu cạnh luồng dữ liệu kết quả hoạt động.

setOpResultEdgeOwnerSharding

void setOpResultEdgeOwnerSharding(unsigned index, mlir::sdy::TensorShardingAttr sharding);

Đặt sharding của chủ sở hữu cạnh kết quả toán tử bằng index đã cho.

setOpResultEdgeOwnerShardings

void setOpResultEdgeOwnerShardings(mlir::ArrayRef<mlir::sdy::TensorShardingAttr> shardings);

Đặt shardings cho tất cả chủ sở hữu cạnh kết quả trong hoạt động.

getBlockArgumentEdgeOwners

mlir::ArrayRef<mlir::BlockArgument> getBlockArgumentEdgeOwners();

Lấy tất cả chủ sở hữu cạnh đối số khối.

getOpResultEdgeOwners

mlir::ResultRange getOpResultEdgeOwners();

Lấy tất cả chủ sở hữu cạnh tranh trong kết quả hoạt động.

getEdgeSources

mlir::SmallVector<mlir::Value> getEdgeSources(mlir::Value target);

Lấy nguồn cạnh luồng dữ liệu theo giá trị target.

getEdgeOwnerFromTarget

mlir::Value getEdgeOwnerFromTarget(mlir::Value target);

Lấy target chủ sở hữu của một cạnh luồng dữ liệu, với target có thể là hoặc không phải là chủ sở hữu.

getEdgeOwnerFromSource

mlir::Value getEdgeOwnerFromSource(mlir::OpOperand&source);

Lấy mục tiêu chủ sở hữu của cạnh luồng dữ liệu dựa vào source.

getNonEdgeOwnerTargets

mlir::SmallVector<mlir::Value> getNonEdgeOwnerTargets(mlir::Value owner);

Lấy các mục tiêu không phải chủ sở hữu của một cạnh luồng dữ liệu dựa trên cạnh owner.

ShardingRuleOpInterface (ShardingRuleOpInterface)

Giao diện toán tử cho phép toán tử xác định quy tắc phân đoạn riêng. Quy tắc phân đoạn chỉ định cách phân vùng một toán tử theo các thuộc tính khác nhau trên op – bất kỳ thuộc tính nào, hình dạng của toán hạng, hình dạng của kết quả, v.v. Hãy xem OpShardingRuleAttr để biết thêm chi tiết.

Phương thức:

getShardingRule

mlir::sdy::OpShardingRuleAttr getShardingRule();

Trả về quy tắc phân đoạn của toán tử.