API trình biên dịch

Thông tin khái quát

Chúng tôi giả định rằng độc giả đã nắm được ít nhất các kiến thức cơ bản về phương thức biểu diễn phân đoạn, mô tả cách biểu thị phân đoạn của một tensor trong Shardy. Tài liệu này cho biết cách sử dụng các đại diện phân đoạn trong một chương trình, ví dụ: để đính kèm một phân đoạn vào một tensor cụ thể của chương trình.

Truyền tải phân đoạn là quá trình quyết định phân đoạn cho mọi tensor trong một chương trình, với các điều kiện ràng buộc phân đoạn cho một tập hợp con của các tensor. API trình biên dịch của Shardy hiển thị một số cách để ảnh hưởng/kiểm soát việc truyền tải phân đoạn. Ngoài ra, tính năng này cho phép người dùng chèn các phép tính được phân đoạn theo cách thủ công vào chương trình của họ.

Mục tiêu

Tài liệu này mô tả thiết kế của các thành phần API như vậy trong Shardy và giải thích hành vi và hằng số của các thành phần đó. Xin lưu ý rằng mặc dù API này được dùng để kiểm soát việc truyền tải phân đoạn, nhưng tài liệu này KHÔNG thảo luận về bất kỳ điều gì liên quan đến hành vi truyền tải cũng như cách thiết kế hành vi truyền tải.

Tổng quan

  • Phân đoạn đầu vào/đầu ra – đính kèm một phân đoạn vào đầu vào hoặc đầu ra của hàm chính để cho biết đây là cách phân đoạn tensor đầu vào/đầu ra khi được cung cấp cho/trả về từ hàm.

  • Hạn chế phân đoạn – đính kèm một phân đoạn vào một tensor trung gian (ví dụ: kết quả của matmul) để cho biết đây là cách phân đoạn tensor đó hoặc một tập hợp con của các mục sử dụng tensor đó.

  • Nhóm phân đoạn – nhóm nhiều tensor theo một mã nhận dạng để cho biết rằng các tensor đó phải được phân đoạn theo cùng một cách.

  • Tính toán thủ công – bao gồm một phép tính phụ được phân vùng theo cách thủ công bằng cách sử dụng một tập hợp con các trục lưới, trong đó các phân đoạn dọc theo các trục thủ công đó được chỉ định cho tất cả dữ liệu đầu vào và đầu ra, và bên trong phép tính phụ, các loại tensor là cục bộ so với các phân đoạn đó.

Thiết kế chi tiết

Phân đoạn đầu vào/đầu ra

Cho phép người dùng chỉ định một phân đoạn cho dữ liệu đầu vào và đầu ra của hàm chính.

Trong MLIR, các thuộc tính có thể được đính kèm vào các đối số và kết quả của hàm, do đó, người dùng có thể đính kèm các thuộc tính phân đoạn vào hàm theo cách này.

Ví dụ:

@mesh_xy = <["x"=2, "y"=2]>

// The 1st input has a sharding specified, but the 2nd input doesn't.
// The output has a sharding specified.
func @main(%arg0: tensor<8x8xf32>
            {sdy.sharding = #sdy.sharding<@mesh_xy, [{"x"}, {}]>},
            %arg1: tensor<8x16xf32>)
    -> (tensor<8x16xf32> {sdy.sharding = #sdy.sharding<@mesh_xy, [{}, {"y"}]>}) {
  ...
}

Quy tắc ràng buộc phân đoạn

Cho phép người dùng đính kèm một phân đoạn vào một tensor trung gian trong chương trình của họ, điều này cho trình phân vùng biết rằng đây là cách phân đoạn tensor đó hoặc một tập hợp con của các mục đích sử dụng.

Đây là một toán tử MLIR lấy tensor làm dữ liệu đầu vào và có một thuộc tính phân đoạn được đính kèm. Thao tác này có thể:

  • Không có mục đích sử dụng (dangling) – tức là việc phân đoạn đính kèm là cách phân đoạn chính tensor.
  • Có trường hợp sử dụng – nghĩa là phân đoạn đính kèm là cách phân đoạn các trường hợp sử dụng của toán tử ràng buộc phân đoạn, trong khi các trường hợp sử dụng khác của tensor đầu vào có thể có phân đoạn khác (nếu tensor đầu vào không có trường hợp sử dụng nào khác thì hành vi sẽ giống như trường hợp không có trường hợp sử dụng). Quá trình truyền tải sẽ xác định việc phân đoạn của chính tensor và phân đoạn lại nếu cần.

Hàm này có thể có các phân đoạn phương diện mở, tức là toán hạng có thể được phân đoạn thêm theo các trục có sẵn.

@mesh_xy = <["x"=2, "y"=2]>

%0 = ... : tensor<8x8xf32>
%1 = sdy.sharding_constraint %0 <@mesh_xy, [{"x"}, {?}]> : tensor<8x8xf32>

Nhóm phân đoạn

Trong trường hợp không có phần phụ thuộc dữ liệu hoặc không có phần phụ thuộc dữ liệu mạnh giữa hai hoặc nhiều tensor, trong khi người dùng biết rằng các tensor đó phải được phân vùng theo cùng một cách hoặc theo cách tương tự, thì Shardy API cung cấp một cách để chỉ định mối quan hệ này. Điều này cho phép người dùng tự do chỉ định rõ ràng rằng các tensor phải được phân vùng như nhau.

Để đạt được điều này, chúng tôi giới thiệu khái niệm về nhóm phân mảnh, trong đó mỗi nhóm chứa số lượng lệnh bất kỳ được liên kết với cùng một mã nhận dạng nhóm phân mảnh. Các nhóm phân đoạn thực thi các phân đoạn trong cùng một nhóm phải giống nhau.

Ví dụ: trong một chương trình người dùng giả định như minh hoạ bên dưới, chúng ta muốn phân đoạn đầu ra của chương trình giống hệt như đầu vào của chương trình trong khi không có phần phụ thuộc dữ liệu nào giữa hai đầu ra này.

Nếu chúng ta chạy chương trình này, quá trình truyền tải phân đoạn sẽ không thể suy luận về việc phân đoạn các tensor %1%2, và cuối cùng các tensor này sẽ được sao chép. Tuy nhiên, bằng cách đính kèm thuộc tính shard_group cho biết %0 đầu vào và %2 đầu ra nằm trong cùng một shard_group, chúng ta cho phép truyền tải @mesh_xy, [{"x"},{"y"}]> phân đoạn từ %0 đầu vào đến %2 đầu ra, và lần lượt đến phần còn lại của biểu đồ, được truyền liên tục %1 tại đây. Chúng ta có thể gán giá trị cho một nhóm bằng toán tử sdy.sharding_group.

@mesh_xy = <["x"=2, "y"=2]>

module @"jit_zeros_like" {
  func.func @main(%arg0: tensor<8x2xi64> {sdy.sharding = #sdy.sharding<@mesh_xy, [{"x"},{"y"}]>} }) -> (tensor<8x2xi64>) {
    %0 = sdy.sharding_group %arg0, id=0 : tensor<8x2xi64>
    %1 = stablehlo.constant dense<0> : tensor<8x2xi64>
    %2 = sdy.sharding_group %1, id=0 : tensor<8x2xi64>
    return %2 : tensor<8x2xi64>
  }
}

Trong ví dụ đơn giản ở trên, chúng ta cũng có thể chỉ định rõ ràng cùng một phân đoạn trên đầu ra như đầu vào, điều này sẽ đạt được hiệu quả tương tự, vì chúng ta đã biết trước phân đoạn nào chúng ta muốn chỉ định cho đầu vào, nhưng trong các trường hợp thực tế hơn, chúng ta sử dụng phân đoạn để đồng bộ hoá việc phân đoạn nhiều tensor mà không nhất thiết phải biết phân đoạn cho bất kỳ phân đoạn nào trong số đó, trong khi Shardy sẽ xử lý phần còn lại và tìm phân đoạn tốt nhất để chỉ định cho các phân đoạn đó.

Tính toán thủ công

Người dùng có thể muốn kiểm soát rõ ràng cách phân vùng các phần của phép tính và những tập hợp nào được sử dụng. Ví dụ: một số người dùng muốn áp dụng matmul tập hợp theo cách thủ công (từ API giao diện người dùng) thay vì trì hoãn cho trình biên dịch. Chúng tôi cung cấp một API Tính toán thủ công cho phép họ làm việc đó.

Đây là toán tử MLIR với một vùng duy nhất cho phép tính toán phụ theo cách thủ công. Người dùng sẽ chỉ định các phân đoạn đầu vào/đầu ra cho phép tính phụ này bằng cách sử dụng một tập hợp con (có thể bao gồm tất cả) các trục lưới. Tính toán phụ sẽ là cục bộ/thủ công đối với các trục lưới được chỉ định (còn gọi là trục thủ công) và toàn cục/không phân vùng đối với các trục không được chỉ định (còn gọi là trục tự do). Bạn có thể phân đoạn thêm phép tính phụ dọc theo các trục tự do trong quá trình truyền tải theo cách tương tự như phép tính bên ngoài phép toán này.

Ví dụ:

@mesh_name = <["data"=2, "model"=2]>

%0 = ... : tensor<16x32xf32>
%1 = sdy.manual_computation(%0)
    in_shardings=[<@mesh_name, [{"data"}, {"model",?}]>]
    out_shardings=[<@mesh_name, [{"data"}, {?}]>]
    manual_axes={"data"}
    (%arg1: tensor<8x32xf32>) {
  // body
  return %42 : tensor<8x32xf32>
} : (tensor<16x32xf32>) -> tensor<16x32xf32>

Biến không đổi

  1. Tất cả in_shardings, out_shardingsmanual_axes phải tham chiếu đến cùng một lưới. manual_axes được sắp xếp theo lưới.

  2. Bạn phải sử dụng manual_axes một cách rõ ràng trong tất cả các hoạt động phân đoạn vào/ra, tức là đối với mỗi hoạt động phân đoạn, tất cả các trục thủ công phải phân đoạn một phương diện hoặc được sao chép một cách rõ ràng.

  3. Nếu một trục tự do (bất kỳ trục lưới nào không nằm trong manual_axes) tồn tại trong một trong các phân đoạn vào/ra, thì trục đó phải là trục phụ đối với bất kỳ trục thủ công nào trong cùng một phân đoạn phương diện (trong ví dụ trên, phân đoạn phương diện {"model", "data"} sẽ không hợp lệ).

  4. Khu vực/phần nội dung của phép tính là phép tính cục bộ (ví dụ: bao gồm cả các tập hợp do người dùng chỉ định). Phải là cục bộ đối với việc phân đoạn vào/ra dọc theo các trục thủ công (xem ghi chú ở trên).

Lồng các phép tính thủ công

Bạn có thể lồng nhiều phép tính thủ công vào nhau, miễn là mỗi phép tính hoạt động trên một nhóm trục thủ công riêng biệt.