Tính linh động trong StableHLO

Trạng thái hiện tại của tính linh động được nêu rõ hơn trong RFC về tính linh động, trang này sẽ cung cấp thông tin tổng quan cấp cao về RFC và thảo luận về các API và công cụ quan trọng để tương tác với các chương trình động.

Tổng quan về thuật ngữ và hỗ trợ về tính năng Động

Trước tiên, để giới thiệu một số thuật ngữ sẽ xuất hiện trong tài liệu này, cũng như một phần giới thiệu ngắn về tính năng hỗ trợ của các thuật ngữ đó trong StableHLO:

Phương diện động

Phương diện động đề cập đến mọi phương diện có kích thước chưa xác định. Trong StableHLO, chúng ta biểu thị các phương diện động bằng ?, tức là tensor<16x?xf32>.

Tính linh động có giới hạn

Tính linh động có giới hạn đề cập đến một phương diện động có giá trị có giới hạn trên đã biết. Nói chung, điều này rất hữu ích để tạo khoảng đệm cho tensor trong quá trình thực thi. Trong StableHLO, chúng ta biểu thị tính linh động có giới hạn bằng cách sử dụng #stablehlo.bounds làm mã hoá tensor, tức là một tensor cấp 2 có một chiều động bị giới hạn ở 16 và chiều còn lại không có giới hạn có thể được biểu thị là tensor<?x?xf32, #stablehlo.bounds<16, ?>>.

StableHLO có thể biểu thị tính linh động có giới hạn, nhưng có hỗ trợ khung hạn chế, bắt nguồn từ TensorFlow và một số hỗ trợ trong PyTorch/XLA.

Tính linh động không giới hạn

Như tên gọi, tính năng linh động không giới hạn đề cập đến một phương diện động không có giới hạn kích thước nào được biết. Loại tính năng linh động này rất phổ biến trong StableHLO, với khả năng hỗ trợ JAX, PyTorch/XLA và TF, thường được dùng để xuất các mô hình có kích thước lô hoặc độ dài trình tự động.

Trong StableHLO, chúng ta chỉ cần xoá mã hoá giới hạn cho hình thức linh động này, tức là tensor<?x?xf32>.

Tính đa hình của hình dạng

Tính đa hình hình dạng là một thuật ngữ mà chúng ta kế thừa từ JAX.

Có hai tác động chính đến việc định hình tính đa hình:

  1. Tất cả tính năng linh động trong chương trình đều bắt nguồn từ các đối số đầu vào.
  2. Tất cả các động lực học chỉ liên quan đến hình dạng tensor, tức là không phụ thuộc vào dữ liệu.

Với 2 quy tắc này, sau khi biết hình dạng tĩnh của một chương trình, chúng tôi có thể đưa một chương trình động và tinh chỉnh hoàn toàn chương trình đó thành một chương trình tĩnh để biên dịch (xem phần "Thẻ về trình biên dịch để tinh chỉnh các chương trình động").

Nhìn chung, tính đa hình hình dạng sử dụng tính linh động không giới hạn, nếu các hình dạng đối số đã biết có thể dẫn đến một chương trình hoàn toàn tĩnh, thì bạn không cần phải đoán cách giới hạn các giá trị.

Tính linh động phụ thuộc vào dữ liệu

Tính linh động phụ thuộc vào dữ liệu đề cập đến kích thước phương diện động liên quan đến dữ liệu bên trong một tensor. Ví dụ chuẩn là hàm nonzeros trả về chỉ mục của tất cả các phần tử là 0 trong một giá trị tensor. Tuy không thể biết hình dạng nếu không đánh giá dữ liệu, nhưng hình dạng này thường có thể được biên dịch bằng cách sử dụng động lực giới hạn, dành thêm bộ nhớ cho kích thước tensor đầu ra tiềm năng.

Bạn có thể lập mô hình nhiều toán tử động phụ thuộc vào dữ liệu bằng cách sử dụng tính linh động có giới hạn, trong đó giới hạn trên của kích thước tensor được chỉ định và phần cứng thường sẽ triển khai điều này thông qua tính năng đệm tensor. Hiện tại, có một số tính năng hỗ trợ tính linh động phụ thuộc vào dữ liệu trong PyTorch/XLA và TensorFlow, nhưng JAX hiện không theo dõi các thao tác dẫn đến tính linh động phụ thuộc vào dữ liệu.

Xuất chương trình có phương diện động

Hãy xem hướng dẫn về StableHLO của chúng tôi để biết thông tin về cách xuất các chương trình có kích thước lô hoặc độ dài trình tự động:

Lượt truyền trình biên dịch để tinh chỉnh các chương trình động

Xoá quy trình truyền động

Có một số lượt truyền hữu ích để tinh chỉnh hình dạng, tất cả đều được đóng gói trong quy trình truyền createStablehloRemoveDynamismPipeline:

void createStablehloRemoveDynamismPipeline(OpPassManager &pm,
                                           TypeRange refinedTypes);

Các lượt đi riêng lẻ để tinh chỉnh tính linh động

Riêng lẻ, các lượt truyền có xu hướng hữu ích cho việc tinh chỉnh hình dạng là:

Hãy xem tài liệu được liên kết để biết thông tin mới nhất và các ví dụ.

Ví dụ: Tính năng động hữu ích như thế nào và làm cách nào để sử dụng?

Tính linh động có nhiều cách sử dụng, ở đây chúng ta sẽ chủ yếu tập trung vào trường hợp sử dụng phổ biến của tính đa hình hình dạng – tạo một bản trình bày mô hình xuất linh hoạt, thường dùng để biểu thị kích thước lô động hoặc độ dài trình tự.

Mô hình add_one tĩnh

Chúng ta sẽ sử dụng mô hình add_one đơn giản sau đây để minh hoạ điều này:

def add_one(x):
  return x + 1

Khi theo dõi bằng tensor<4xf32>, chúng ta sẽ nhận được chương trình StableHLO sau:

// File: add_one.mlir
func.func @add_one(%arg0: tensor<4xf32>) -> tensor<4xf32> {
  %cst = stablehlo.constant dense<1.000000e+00> : tensor<4xf32>
  %0 = stablehlo.add %arg0, %cst : tensor<4xf32>
  return %0 : tensor<4xf32>
}

Mô hình này chỉ hoạt động cho các đối số đầu vào có hình dạng tensor<4xf32>. Nếu đã từng thay đổi kích thước lô hoặc độ dài trình tự, chúng tôi sẽ cần truy xuất lại mã nguồn và hạ xuống thành StableHLO, đồng thời không có gì đảm bảo rằng chúng tôi vẫn có quyền truy cập vào mã nguồn!

Mô hình add_one động

Đây là lúc tính năng linh động đa hình của hình dạng phát huy tác dụng. Thay vào đó, JAX và PyTorch/XLA có thể phát ra mô hình add_one bằng IR hợp lệ một cách linh động. Mô hình này sẽ truyền hằng số để khớp với hình dạng đầu vào động như sau:

// File: add_one_dynamic.mlir
func.func public @main(%arg0: tensor<?xf32>) -> tensor<?xf32> {
  %cst = stablehlo.constant dense<1.0> : tensor<f32>
  %0 = stablehlo.get_dimension_size %arg0, dim = 0 : (tensor<?xf32>) -> tensor<i32>
  %1 = stablehlo.reshape %0 : (tensor<i32>) -> tensor<1xi32>
  %2 = stablehlo.dynamic_broadcast_in_dim %cst, %1, dims = [] : (tensor<f32>, tensor<1xi32>) -> tensor<?xf32>
  %3 = stablehlo.add %arg0, %2 : tensor<?xf32>
  return %3 : tensor<?xf32>
}

Cách trình bày mô hình này linh hoạt hơn nhiều và cho phép bạn trì hoãn việc chỉ định các giá trị như kích thước lô hoặc độ dài trình tự. Bạn có thể triển khai mô hình này trên các nền tảng có hỗ trợ hình dạng động (như AI Edge) hoặc tinh chỉnh mô hình này bằng cách sử dụng các lượt truyền tính năng linh động được đề cập trong tài liệu này.

Tinh chỉnh mô hình động

Ví dụ: thứ tự truyền sau đây có thể tinh chỉnh hoàn toàn chương trình này:

stablehlo-opt add_one_dynamic.mlir \
  --stablehlo-refine-arguments='types=tensor<16xf32>' \
  --stablehlo-refine-shapes \
  --stablehlo-canonicalize-dynamism

Sau đây là cách chương trình được chuyển đổi dần dần:

// After stablehlo-refine-arguments: Inputs updated, shapes not propagated
func.func public @main(%arg0: tensor<16xf32>) -> tensor<?xf32> {
  %c = stablehlo.constant dense<16> : tensor<1xi64>
  %0 = stablehlo.custom_call @stablehlo.shape_refinement_operand_wrapper(%arg0, %c) {indices_of_shape_operands = dense<1> : tensor<1xi64>} : (tensor<16xf32>, tensor<1xi64>) -> tensor<?xf32>
  ...
  %3 = stablehlo.dynamic_broadcast_in_dim %cst, %2, dims = [] : (tensor<f32>, tensor<1xi32>) -> tensor<?xf32>
  %4 = stablehlo.add %0, %3 : tensor<?xf32>
  return %4 : tensor<?xf32>
}

// After stablehlo-refine-shapes: Shapes propagated, dynamic ops still exist
func.func public @main(%arg0: tensor<16xf32>) -> tensor<16xf32> {
  %cst = stablehlo.constant dense<1.000000e+00> : tensor<f32>
  %c = stablehlo.constant dense<16> : tensor<1xi32>
  %0 = stablehlo.dynamic_broadcast_in_dim %cst, %c, dims = [] : (tensor<f32>, tensor<1xi32>) -> tensor<16xf32>
  %1 = stablehlo.add %arg0, %0 : tensor<16xf32>
  return %1 : tensor<16xf32>
}

// After stablehlo-canonicalize-dynamism: Dynamic ops replaced with static ops
func.func public @main(%arg0: tensor<16xf32>) -> tensor<16xf32> {
  %cst = stablehlo.constant dense<1.000000e+00> : tensor<f32>
  %0 = stablehlo.broadcast_in_dim %cst, dims = [] : (tensor<f32>) -> tensor<16xf32>
  %1 = stablehlo.add %arg0, %0 : tensor<16xf32>
  return %1 : tensor<16xf32>
}

// (Bonus) Use ` --stablehlo-aggressive-simplification` pass to canonicalize the
// constant broadcast, leaving us with the original static program in this case.
func.func public @main(%arg0: tensor<16xf32>) -> tensor<16xf32> {
  %cst = stablehlo.constant dense<1.000000e+00> : tensor<16xf32>
  %0 = stablehlo.add %arg0, %cst : tensor<16xf32>
  return %0 : tensor<16xf32>
}