Tính linh động trong StableHLO

Trạng thái hiện tại của tính linh động được trình bày rõ ràng hơn trong RFC về tính linh động. Trang này sẽ cung cấp thông tin tổng quan cấp cao về RFC và thảo luận về các API cũng như công cụ quan trọng để tương tác với các chương trình động.

Tổng quan về thuật ngữ và hoạt động hỗ trợ cho tính linh hoạt

Trước tiên, để đề cập đến một số thuật ngữ sẽ xuất hiện trong tài liệu này, cũng như phần giới thiệu ngắn về khả năng hỗ trợ của các thuật ngữ đó trong StableHLO:

Kích thước linh hoạt

Phương diện động là phương diện có kích thước phương diện không xác định. Trong StableHLO, chúng ta biểu thị các phương diện động bằng ?, tức là tensor<16x?xf32>.

Tính linh hoạt có giới hạn

Tính linh động có giới hạn đề cập đến một phương diện linh động có giá trị là một giới hạn trên đã biết. Nhìn chung, điều này hữu ích cho việc thêm khoảng đệm vào tensor trong quá trình thực thi. Trong StableHLO, chúng ta biểu thị tính linh hoạt có giới hạn bằng cách sử dụng #stablehlo.bounds làm một phương thức mã hoá tensor, tức là một tensor hạng 2 có một phương diện động được giới hạn ở 16 và phương diện còn lại không có giới hạn có thể được biểu thị dưới dạng tensor<?x?xf32, #stablehlo.bounds<16, ?>>.

StableHLO có thể biểu thị tính linh hoạt có giới hạn, nhưng có sự hỗ trợ hạn chế về khung, bắt nguồn từ TensorFlow và có một số hỗ trợ trong PyTorch/XLA.

Tính linh động không giới hạn

Tính linh hoạt không giới hạn, như tên gọi ngụ ý, đề cập đến một phương diện động không có giới hạn đã biết về kích thước. Loại tính linh hoạt này rất phổ biến trong StableHLO, với sự hỗ trợ của JAX, PyTorch/XLA và TF, thường được dùng để xuất các mô hình có kích thước lô hoặc độ dài chuỗi linh hoạt.

Trong StableHLO, chúng ta chỉ cần bỏ qua việc mã hoá ranh giới cho dạng thức linh hoạt này, tức là tensor<?x?xf32>.

Tính đa hình của hình dạng

Tính đa hình về hình dạng là một thuật ngữ mà chúng tôi kế thừa từ JAX.

Đa hình dạng có 2 ý nghĩa chính:

  1. Mọi tính linh hoạt trong chương trình đều bắt nguồn từ các đối số đầu vào.
  2. Mọi tính linh hoạt chỉ liên quan đến hình dạng của tensor, tức là không phụ thuộc vào dữ liệu.

Với hai quy tắc này, sau khi biết được các hình dạng tĩnh của một chương trình, chúng ta có thể lấy một chương trình động và tinh chỉnh hoàn toàn chương trình đó thành một chương trình tĩnh để biên dịch (xem "Compiler passes for refining dynamic programs").

Nhìn chung, đa hình dạng sử dụng tính linh động không giới hạn. Nếu các hình dạng đối số đã biết có thể dẫn đến một chương trình hoàn toàn tĩnh, thì không cần phải đoán cách liên kết các giá trị.

Tính linh hoạt phụ thuộc vào dữ liệu

Tính linh hoạt phụ thuộc vào dữ liệu đề cập đến kích thước phương diện động liên quan đến dữ liệu bên trong một tensor. Ví dụ kinh điển là hàm nonzeros trả về chỉ mục của tất cả các phần tử là 0 trong một giá trị tensor. Không thể biết hình dạng nếu không đánh giá dữ liệu, nhưng thường có thể biên dịch bằng tính linh hoạt có giới hạn, tiêu tốn thêm bộ nhớ cho kích thước tensor đầu ra tiềm năng.

Bạn có thể lập mô hình cho nhiều hoạt động động phụ thuộc vào dữ liệu bằng cách sử dụng tính năng linh hoạt có giới hạn, trong đó bạn chỉ định giới hạn trên cho kích thước tensor và phần cứng thường sẽ triển khai tính năng này thông qua việc đệm tensor. Hiện tại, có một số hỗ trợ cho tính linh hoạt phụ thuộc vào dữ liệu trong PyTorch/XLA và TensorFlow, nhưng JAX hiện không theo dõi các thao tác dẫn đến tính linh hoạt phụ thuộc vào dữ liệu.

Xuất chương trình có kích thước linh hoạt

Hãy xem hướng dẫn về StableHLO để biết thông tin về cách xuất các chương trình có kích thước lô động hoặc độ dài chuỗi:

Các lượt truyền trình biên dịch để tinh chỉnh các chương trình động

Xoá quy trình truyền thẻ động

Có một số đường chuyền hữu ích để tinh chỉnh các hình dạng, rất thuận tiện khi tất cả đều được gói trong một quy trình đường chuyền createStablehloRemoveDynamismPipeline:

void createStablehloRemoveDynamismPipeline(OpPassManager &pm,
                                           TypeRange refinedTypes);

Các đường chuyền riêng lẻ để tinh chỉnh tính năng động

Các đường chuyền thường hữu ích cho việc tinh chỉnh hình dạng là:

Hãy xem tài liệu được liên kết để biết thông tin và ví dụ mới nhất.

Ví dụ: Tính linh động hữu ích như thế nào và tôi có thể sử dụng tính năng này như thế nào?

Tính linh hoạt có nhiều ứng dụng, ở đây chúng ta sẽ chủ yếu tập trung vào trường hợp sử dụng phổ biến cho Đa hình dạng – tạo một biểu diễn mô hình được xuất linh hoạt, thường được dùng để biểu diễn kích thước lô động hoặc độ dài chuỗi.

Mô hình static add_one

Chúng ta sẽ sử dụng mô hình add_one đơn giản sau đây để minh hoạ điều này:

def add_one(x):
  return x + 1

Khi được theo dõi bằng tensor<4xf32>, chúng ta sẽ nhận được chương trình StableHLO sau đây:

// File: add_one.mlir
func.func @add_one(%arg0: tensor<4xf32>) -> tensor<4xf32> {
  %cst = stablehlo.constant dense<1.000000e+00> : tensor<4xf32>
  %0 = stablehlo.add %arg0, %cst : tensor<4xf32>
  return %0 : tensor<4xf32>
}

Mô hình này sẽ chỉ hoạt động đối với các đối số đầu vào có hình dạng tensor<4xf32>. Nếu thay đổi kích thước lô hoặc độ dài chuỗi, chúng ta sẽ cần theo dõi lại mã nguồn và hạ cấp xuống StableHLO. Ngoài ra, không có gì đảm bảo rằng chúng ta vẫn có quyền truy cập vào mã nguồn!

Mô hình dynamic add_one

Đây là lúc tính đa hình động của hình dạng phát huy tác dụng. Thay vào đó, JAX và PyTorch/XLA có thể phát ra mô hình add_one với IR hợp lệ một cách linh hoạt. IR này sẽ truyền hằng số để khớp với hình dạng đầu vào linh hoạt như sau:

// File: add_one_dynamic.mlir
func.func public @main(%arg0: tensor<?xf32>) -> tensor<?xf32> {
  %cst = stablehlo.constant dense<1.0> : tensor<f32>
  %0 = stablehlo.get_dimension_size %arg0, dim = 0 : (tensor<?xf32>) -> tensor<i32>
  %1 = stablehlo.reshape %0 : (tensor<i32>) -> tensor<1xi32>
  %2 = stablehlo.dynamic_broadcast_in_dim %cst, %1, dims = [] : (tensor<f32>, tensor<1xi32>) -> tensor<?xf32>
  %3 = stablehlo.add %arg0, %2 : tensor<?xf32>
  return %3 : tensor<?xf32>
}

Biểu diễn mô hình này linh hoạt hơn nhiều và cho phép chỉ định các giá trị bị trì hoãn như kích thước lô hoặc độ dài chuỗi. Bạn có thể triển khai mô hình này trên các nền tảng có hỗ trợ hình dạng động (chẳng hạn như AI Edge) hoặc có thể tinh chỉnh mô hình này bằng cách sử dụng các đường chuyền động được đề cập trong tài liệu này.

Tinh chỉnh mô hình động

Ví dụ: thứ tự truyền sau đây có thể tinh chỉnh hoàn toàn chương trình này:

stablehlo-opt add_one_dynamic.mlir \
  --stablehlo-refine-arguments='types=tensor<16xf32>' \
  --stablehlo-refine-shapes \
  --stablehlo-canonicalize-dynamism

Dần dần, chương trình sẽ được chuyển đổi như sau:

// After stablehlo-refine-arguments: Inputs updated, shapes not propagated
func.func public @main(%arg0: tensor<16xf32>) -> tensor<?xf32> {
  %c = stablehlo.constant dense<16> : tensor<1xi64>
  %0 = stablehlo.custom_call @stablehlo.shape_refinement_operand_wrapper(%arg0, %c) {indices_of_shape_operands = dense<1> : tensor<1xi64>} : (tensor<16xf32>, tensor<1xi64>) -> tensor<?xf32>
  ...
  %3 = stablehlo.dynamic_broadcast_in_dim %cst, %2, dims = [] : (tensor<f32>, tensor<1xi32>) -> tensor<?xf32>
  %4 = stablehlo.add %0, %3 : tensor<?xf32>
  return %4 : tensor<?xf32>
}

// After stablehlo-refine-shapes: Shapes propagated, dynamic ops still exist
func.func public @main(%arg0: tensor<16xf32>) -> tensor<16xf32> {
  %cst = stablehlo.constant dense<1.000000e+00> : tensor<f32>
  %c = stablehlo.constant dense<16> : tensor<1xi32>
  %0 = stablehlo.dynamic_broadcast_in_dim %cst, %c, dims = [] : (tensor<f32>, tensor<1xi32>) -> tensor<16xf32>
  %1 = stablehlo.add %arg0, %0 : tensor<16xf32>
  return %1 : tensor<16xf32>
}

// After stablehlo-canonicalize-dynamism: Dynamic ops replaced with static ops
func.func public @main(%arg0: tensor<16xf32>) -> tensor<16xf32> {
  %cst = stablehlo.constant dense<1.000000e+00> : tensor<f32>
  %0 = stablehlo.broadcast_in_dim %cst, dims = [] : (tensor<f32>) -> tensor<16xf32>
  %1 = stablehlo.add %arg0, %0 : tensor<16xf32>
  return %1 : tensor<16xf32>
}

// (Bonus) Use ` --stablehlo-aggressive-simplification` pass to canonicalize the
// constant broadcast, leaving us with the original static program in this case.
func.func public @main(%arg0: tensor<16xf32>) -> tensor<16xf32> {
  %cst = stablehlo.constant dense<1.000000e+00> : tensor<16xf32>
  %0 = stablehlo.add %arg0, %cst : tensor<16xf32>
  return %0 : tensor<16xf32>
}