인터프리터 디자인

데이터 모델

StableHLO 프로그램은 텐서(n차원 배열)를 통한 연산이며, 현재 모델에서는 Tensor 클래스를 사용하여 구현됩니다. Tensor 객체의 기본 스토리지 클래스인 detail::Buffer주에서 마이너 순서로 연속 바이트 배열로 배치된 텐서 데이터의 변경 가능한 blob을 나타내는 mlir::HeapAsmResourceBlob 객체와 함께 텐서의 mlir::ShapedType를 저장합니다. detail::Buffer 객체는 메모리 관리를 간소화하기 위해 참조 카운트됩니다.

텐서의 개별 요소는 Element 클래스를 사용하여 표현되며, 이 클래스는 APInt, APFloat 또는 pair<APFloat,APFloat> 중 하나를 보유하는 구별된 합집합을 사용하여 저장합니다. 마지막 유형은 복합 유형의 요소를 저장하는 데 사용됩니다.

Tensor에는 개별 요소와 상호작용하는 다음 API가 있습니다.

  • Element Tensor::get(llvm::ArrayRef<int64_t> index): 다차원 색인 index의 개별 텐서 요소를 Element 객체로 추출합니다.
  • void Tensor::set(llvm::ArrayRef<int64_t> index, Element element);: Element 객체 element를 다차원 색인 index의 텐서로 업데이트합니다.

인터프리터 작동 방식

인터프리터에 대한 입력 함수는

SmallVector<Tensor> eval(func::FuncOp func, ArrayRef<Tensor> args);

이 코드는 다음을 수행합니다.

  1. 기호 테이블 맵 M을 사용하여 args에 제공된 func의 SSA 인수 및 연결된 런타임 Tensor 값을 추적합니다.
  2. func 내 각 작업에 대해 SSACFG 순서로 다음을 수행합니다.
    • 작업에 eval를 호출합니다. 작업의 각 SSA 피연산자의 경우 M에서 런타임 값을 추출하여 eval 호출의 인수로 제공합니다.
    • 작업의 SSA 결과와 M에서 평가된 값을 추적합니다.

(2)에서 언급된 작업 수준 eval는 작업의 실행 시맨틱스를 구현합니다. 다음은 stablehlo::AddOp의 예입니다. 이 예시에서 lhsrhs 텐서의 개별 요소는 Element 객체로 쌍 추출된 후 추가됩니다. 더하기의 결과인 Element 객체는 최종 result 텐서에 저장됩니다.

Tensor eval(AddOp op, const Tensor &lhs, const Tensor &rhs) {
  Tensor result(op.getType());

  for (auto it = result.index_begin(); it != result.index_end(); ++it)
    result.set(*it, lhs.get(*it) + rhs.get(*it));

  return result;
}

전반적으로 인터프리터 설계는 StableHLO의 참조 구현으로 기능하도록 되어 있으므로 개별 작업에 관한 eval 함수 구현의 가독성에 최적화되어 있습니다. 예를 들어 eval를 템플릿 함수로 정의하고 요소 유형으로 매개변수화하는 대신 Element::operator+ 등에서 다양한 요소 유형이 처리되는 방식에 관한 세부정보를 캡슐화하여 eval 구현을 간소화합니다.

일정한 접기에 인터프리터 사용

인터프리터 메커니즘을 사용하여 상수 피연산자 값이 있는 연산을 폴딩할 수 있습니다. 다음 코드 스니펫은 부동 소수점 유형 피연산자로 stablehlo::AddOp를 접는 구현 개념을 보여줍니다.

OpFoldResult AddOp::fold(FoldAdaptor adaptor) {
  auto attrs = adaptor.getOperands();
  DenseElementsAttr lhsData = attrs[0].dyn_cast<DenseElementsAttr>();
  DenseElementsAttr rhsData = attrs[1].dyn_cast<DenseElementsAttr>();
  if (!lhsData || !rhsData) return {};

  auto lhs = Tensor(lhsData);
  auto rhs = Tensor(rhsData);
  auto result = eval(*this, lhs, rhs);

  SmallVector<APFloat> values;
  for (auto i = 0; i < result.getNumElements(); ++i) {
    Element element = result.get(i);
    values.push_back(element.getValue().cast<FloatAttr>().getValue());
  }

  return DenseElementsAttr::get(result.getType(), values);
}

현재는 StableHLO용 폴더를 구현할 계획이 없으므로 인터프리터를 상수 폴딩에 통합하는 작업을 적극적으로 진행하고 있지 않습니다. 그러나 앞으로 MHLO의 상수 폴딩에 인터프리터를 활용할 계획입니다. 이때 위 코드 스니펫의 인체공학을 개선할 예정입니다 (예: 상수 피연산자를 Tensor 객체에 압축하고 Tensor 결과를 OpFoldResult로 압축해제하는 도우미 함수가 있을 수 있음).

StableHLO 인터프리터 테스트

인터프리터는 (A) StableHLO 프로그램 및 (B) 프로그램에 제공할 데이터 값을 입력으로 취하고 출력 데이터 값을 생성합니다. 이 값은 사용자가 제공한 예상 데이터 값과 일치합니다. 데이터 값 (B)는 stablehlo.constant 연산을 사용하여 프로그램 자체에 하드 코딩됩니다. 인터프리터는 입력 프로그램을 평가합니다. 테스트 중인 작업의 출력은 아래와 같이 검사(예: check.expect_eq, check.expect_almost_eq)를 통해 확인됩니다. check.expect_eqcheck.expect_eq_const은 지원되는 모든 유형의 비트 동등성을 확인하고 check.expect_almost_eqcheck.expect_almost_eq_const는 부동 소수점 유형과 복합 유형의 경우 테스트 가이드라인 (G6)에 설명된 공차 내에서 거의 동등한지 확인합니다.

// CHECK-LABEL: Evaluated results of function: add_op_test_ui4
func.func @add_op_test_ui4() {
  %0 = stablehlo.constant dense<[0, 2]> : tensor<2xui4>
  %1 = stablehlo.constant dense<[15, 3]> : tensor<2xui4>
  %2 = stablehlo.add %0, %1 : tensor<2xui4>
  check.expect_eq_const %2, [15, 5] : tensor<2xui4>
  func.return
}

테스트 유틸리티 stablehlo-translate --interpret(코드)는 프로그램을 파싱하고 함수를 구성하는 작업을 포함하여 각 함수를 해석합니다. 각 StableHLO 작업에 다양한 런타임 동작을 실행하는 여러 테스트로 구성된 전용 테스트 모음이 있습니다. 테스트는 여기(예: interpret_*.mlir)에서 확인할 수 있습니다.

테스트 가이드라인

(G1) 모든 작업에 지원되는 모든 유형을 테스트해야 하나요?

다음 규칙을 조합하여 결정할 수 있습니다.

  1. 작업을 구현하는 동안 특정 유형을 처리할 eval 함수에 코드가 있는 경우 그러한 유형을 처리할 테스트가 있어야 합니다. 예를 들어 add 작업의 경우 정수, 불리언, 부동 소수점 및 복합 유형을 처리하는 전용 코드가 있으므로 각 유형 카테고리에 하나의 테스트가 필요합니다.

  2. 유형 집합이 해당하는 eval 함수에서 균일하게 처리되는 경우 이러한 모든 유형에 관한 단일 테스트로 충분합니다. 예를 들어 add 작업의 경우 정수 유형의 모든 변형 (si4, u4, si8, u8 등)이 llvm::APInt API를 사용하여 동일하게 처리되므로 이러한 각 변형의 테스트 추가를 건너뛰고 대신 단일 대표 테스트를 추가할 수 있습니다. 대표자를 선택할 때 모호함을 피하려면 다음 가이드라인을 따라야 합니다.

    • 균일하게 처리되는 모든 유형이 동일한 기본 유형인 경우(즉, 모두 정수, 부동 소수점 또는 복합 유형인 경우) 최대 비트 너비가 있는 유형을 선택합니다.
    • 균일하게 처리되는 모든 유형에 기본 유형이 혼합된 경우 기본 유형(정수, 부동 소수점, 불리언, 복합 유형)을 선호도에 따라 내림차순으로 정렬한 유형을 선택합니다.

(G2) 특정 작업의 동작을 처리하는 데 필요한 테스트 수는 어떻게 결정하나요?

목표는 최소한의 테스트로 연산의 인터프리터 로직(즉, 구현의 모든 특수 사례)을 포괄적으로 다루는 것입니다. 테스트 수를 최소화하는 것이 유지관리를 위해 중요합니다. 테스트 수가 적을수록 테스트를 검토하고 작업을 포괄적으로 다루는지 확인하기가 더 쉽습니다. 따라서 대부분의 간단한 작업에서 하나의 테스트만 거치게 됩니다. 어떤 이유로든 포괄적인 적용이 비실용적이라면 90% 이상으로 중단하는 것이 좋습니다. 이는 pull 요청을 검토하는 동안 사례별로 결정됩니다.

(G3) 인터프리터 인프라 테스트를 추가하는 방법은 무엇인가요?

인터프리터 인프라는 대체로 간단하며 Google 신뢰 기반에 추가할 수 있습니다. 중요한 점은 다양한 유형이 기본 인터프리터 저장소에 패키징되고 언패킹되는 방식입니다. (G1)에서 설명한 것처럼 다르게 처리되는 작업 유형만 테스트합니다. 따라서 다양한 정수/부동 소수점 유형에 해당하는 패킹/패킹 해제 코드가 테스트 중에 완전히 포함되지 않을 수 있습니다. 전체 적용 범위를 보장하기 위해 모든 StableHLO 요소 유형을 지원하는 constant와 같은 작업을 선택하고 포괄적 테스트를 작성할 수 있습니다.

(G4) 작업 구현이 다른 작업에 종속되는 경우 후자의 테스트를 작성해야 하나요?

아니요. 예를 들어 batch_norm_grad의 구현은 divide, subtract, multiply 등을 기반으로 할 수 있습니다. 전자를 테스트하는 동안 후자의 작업은 테스트하지 않아야 합니다.

(G5) 구현 정의 / 정의되지 않은 동작을 실행하는 테스트를 작성해야 하나요?

작업의 구현 정의 또는 정의되지 않은 동작을 실행하는 테스트를 작성하면 안 됩니다. 구현 정의 동작을 실행하는 테스트는 일반화하면 안 되는 인터프리터의 로컬 동작을 보여줍니다. 정의되지 않은 동작을 실행하는 테스트는 작업의 동작을 이해하는 데 도움이 되지 않습니다.

(G6) 부동 소수점 유형의 테스트를 작성하는 동안 검사에서 예상 결과를 어떤 정밀도로 지정해야 하나요?

기본 연산 (더하기, 빼기, 곱하기, 나누기, 제곱)의 경우 IEEE 사양을 따르는 구현은 수학적으로 정확한 결과의 0.5 ULP 내에서 반올림된 결과를 제공해야 합니다. 그렇지만 이러한 작업에서 발생할 것으로 예상되는 결과가 최대 1개의 ULP 차이로 상상할 수 있습니다. 하지만 정밀도 보장이 구현되는(근거) 초월 함수 (sine, cosine 등)에는 작동하지 않을 수 있습니다.

현재 구현에서는 '일률적인' 공차 값 0.0001을 사용합니다. 다음 예는 위의 허용 오차를 실제로 보여줍니다.

func.func @check_tolerance() {
  %0 = stablehlo.constant dense<0.2> : tensor<f32>

  // The following check succeeds as %0 is almost equal to the provided
  // constant modulo the tolerance, mentioned above.
  check.expect_almost_eq_const %0, dense<0.19999> : tensor<f32>

  // The following check fails as %0 is not bitwise equal to the provided
  // constant.
  check.expect_eq_const %0, dense<0.19999> : tensor<f32>

  func.return
}

이는 StableHLO 작업의 수치 정확성을 테스트하는 첫 번째 단계일 뿐입니다. 현재 이는 StableHLO 사양에서 잘 지정되지 않은 영역이며, StableHLO를 실제로 사용해 본 경험과 이해관계자의 의견을 바탕으로 #1156을 알아내기 위해 노력하고 있습니다. 이 작업이 진행됨에 따라 인프라를 업데이트할 예정입니다.

(G7) 테스트의 코딩 스타일에 관해 궁금한 점이 있나요?

  1. SSA 값 (예: %0, %1 등)을 기본값으로 설정하지 않고 입력/출력의 실제 이름을 사용해야 합니다.
  2. 테스트에서 예쁘게 인쇄된 형식을 사용해야 합니다(있는 경우).

(G8) 사양에 이미 제공된 예시를 포함해야 하나요? 예 (테스트의 완전성을 위해).