배경
독자는 텐서 샤딩을 Shardy로 표현하는 방법을 설명하는 샤딩 표현의 기본사항을 숙지하고 있다고 가정합니다. 이 문서에서는 샤딩 표현을 프로그램에서 사용하는 방법을 보여줍니다(예: 프로그램의 특정 텐서에 샤딩을 연결하는 경우).
샤딩 전파는 텐서 하위 집합의 샤딩 제약 조건을 고려하여 프로그램의 모든 텐서에 대한 샤딩을 결정하는 프로세스입니다. Shardy의 컴파일러 API는 샤딩 전파에 영향을 주거나 제어하는 여러 가지 방법을 노출합니다. 또한 사용자가 수동으로 샤딩된 계산을 프로그램에 삽입할 수 있습니다.
목표
이 문서에서는 Shardy의 이러한 API 구성요소 설계를 설명하고 동작 및 불변 항목을 설명합니다. 이 API는 샤딩 전파를 제어하는 데 사용되지만 이 문서에서는 전파 동작이나 설계 방법에 관해 다루지 않습니다.
개요
입력/출력 샤딩 - 샤딩을 기본 함수의 입력 또는 출력에 연결하여 함수에 제공되거나 함수에서 반환될 때 입력/출력 텐서를 샤딩해야 하는 방식을 나타냅니다.
샤딩 제약조건 - 중간 텐서 (예: matmul의 결과)에 샤딩을 연결하여 이 텐서 또는 사용의 하위 집합을 샤딩해야 함을 나타냅니다.
샤딩 그룹: 여러 텐서를 ID별로 그룹화하여 동일한 방식으로 샤딩해야 함을 나타냅니다.
수동 계산: 메시 축의 하위 집합을 사용하여 수동으로 분할된 하위 계산을 묶습니다. 여기서 이러한 수동 축을 따라 샤딩이 모든 입력과 출력에 지정되고 하위 계산 내에서 텐서 유형은 이러한 샤딩과 관련하여 로컬입니다.
세부 설계
입력/출력 샤딩
사용자가 기본 함수의 입력과 출력에 샤딩을 지정할 수 있습니다.
MLIR에서는 속성을 함수 인수 및 결과에 연결할 수 있으므로 사용자는 이러한 방식으로 함수에 샤딩 속성을 연결할 수 있습니다.
예를 들면 다음과 같습니다.
@mesh_xy = <["x"=2, "y"=2]>
// The 1st input has a sharding specified, but the 2nd input doesn't.
// The output has a sharding specified.
func @main(%arg0: tensor<8x8xf32>
{sdy.sharding = #sdy.sharding<@mesh_xy, [{"x"}, {}]>},
%arg1: tensor<8x16xf32>)
-> (tensor<8x16xf32> {sdy.sharding = #sdy.sharding<@mesh_xy, [{}, {"y"}]>}) {
...
}
샤딩 제약 조건
사용자가 프로그램의 중간 텐서에 샤딩을 연결할 수 있습니다. 그러면 샤딩이 텐서 또는 사용의 하위 집합에 적용되어야 한다는 것을 파티셔너에 알립니다.
이는 텐서를 입력으로 사용하고 샤딩 속성이 연결된 MLIR 연산입니다. 작업은 다음 중 하나일 수 있습니다.
- 사용이 없음 (댕글링) - 즉, 연결된 샤딩이 텐서 자체를 샤딩하는 방법입니다.
- 사용이 있음 - 즉, 연결된 샤딩은 샤딩 제약 조건 연산자의 사용이 샤딩되는 방식이지만 입력 텐서의 다른 사용에는 다른 샤딩이 있을 수 있습니다. 입력 텐서에 다른 사용이 없는 경우 동작은 사용이 없는 케이스와 동일합니다. 전파는 텐서 자체의 샤딩을 결정하고 필요한 경우 다시 샤딩합니다.
개방형 측정기준 샤딩이 있을 수 있습니다. 즉, 피연산자를 사용 가능한 축을 따라 추가로 샤딩할 수 있습니다.
@mesh_xy = <["x"=2, "y"=2]>
%0 = ... : tensor<8x8xf32>
%1 = sdy.sharding_constraint %0 <@mesh_xy, [{"x"}, {?}]> : tensor<8x8xf32>
샤딩 그룹
두 개 이상의 텐서 간에 데이터 종속 항목이 없거나 강력한 데이터 종속 항목이 없는 경우, 사용자가 이러한 텐서를 동일한 방식 또는 유사한 방식으로 파티셔닝해야 한다는 것을 알고 있는 경우 Shardy API는 이 관계를 지정하는 방법을 제공합니다. 이를 통해 사용자는 텐서가 서로 파티션되어야 한다고 명시적으로 지정할 수 있습니다.
이를 위해 샤드 그룹이라는 개념을 도입합니다. 여기서 각 그룹은 동일한 샤드 그룹 ID와 연결된 임의 수의 안내를 포함합니다. 샤딩 그룹은 동일한 그룹 내에서 샤딩이 동일하도록 적용합니다.
예를 들어 아래와 같은 가상의 사용자 프로그램에서 프로그램의 입력과 정확히 동일하게 프로그램의 출력을 샤딩하려고 합니다. 두 프로그램 간에 데이터 종속 항목은 없습니다.
이 프로그램을 실행하면 샤딩 전파가 텐서 %1
및 %2
의 샤딩을 추론할 수 없으므로 결국 복제됩니다.
그러나 입력 %0
와 출력 %2
가 동일한 shard_group
내에 있다고 나타내는 shard_group
속성을 연결하면 샤딩 @mesh_xy,
[{"x"},{"y"}]>
가 입력 %0
에서 출력 %2
로, 그리고 나머지 그래프로 전파되도록 허용합니다. 여기서 나머지 그래프는 상수 %1
로 브로드캐스트됩니다. sdy.sharding_group
연산을 사용하여 그룹에 값을 할당할 수 있습니다.
@mesh_xy = <["x"=2, "y"=2]>
module @"jit_zeros_like" {
func.func @main(%arg0: tensor<8x2xi64> {sdy.sharding = #sdy.sharding<@mesh_xy, [{"x"},{"y"}]>} }) -> (tensor<8x2xi64>) {
%0 = sdy.sharding_group %arg0, id=0 : tensor<8x2xi64>
%1 = stablehlo.constant dense<0> : tensor<8x2xi64>
%2 = sdy.sharding_group %1, id=0 : tensor<8x2xi64>
return %2 : tensor<8x2xi64>
}
}
위의 간단한 예시에서 출력에 입력과 동일한 샤딩을 명시적으로 지정할 수도 있습니다. 이렇게 하면 입력에 할당할 샤딩을 미리 알고 있으므로 동일한 효과를 얻을 수 있습니다. 그러나 더 현실적인 경우에는 샤딩을 사용하여 여러 텐서의 샤딩을 동기화하면서 샤딩을 반드시 알 필요는 없습니다. Shardy가 나머지를 처리하고 할당할 최적의 샤딩을 찾습니다.
수동 계산
사용자는 계산의 일부가 파티션되는 방식과 사용되는 집합을 명시적으로 제어하고 싶을 수 있습니다. 예를 들어 일부 사용자는 컴파일러에 위임하는 대신 프런트엔드 API에서 수동으로 집계 matmul을 적용하려고 합니다. 이를 지원하는 수동 계산 API를 제공합니다.
수동 하위 계산을 위한 단일 리전이 있는 MLIR 작업입니다. 사용자는 메시 축의 하위 집합 (전체 포함)을 사용하여 이 하위 계산에 입력/출력 샤딩을 지정합니다. 하위 계산은 지정된 메시 축 (수동 축이라고도 함)과 관련하여 로컬/수동이고 지정되지 않은 축 (자유 축이라고도 함)과 관련하여 글로벌/파티션되지 않은 것입니다. 하위 계산은 이 작업 외부의 계산과 동일한 방식으로 전파 중에 자유 축을 따라 추가로 샤딩될 수 있습니다.
예를 들면 다음과 같습니다.
@mesh_name = <["data"=2, "model"=2]>
%0 = ... : tensor<16x32xf32>
%1 = sdy.manual_computation(%0)
in_shardings=[<@mesh_name, [{"data"}, {"model",?}]>]
out_shardings=[<@mesh_name, [{"data"}, {?}]>]
manual_axes={"data"}
(%arg1: tensor<8x32xf32>) {
// body
return %42 : tensor<8x32xf32>
} : (tensor<16x32xf32>) -> tensor<16x32xf32>
불변항
모든
in_shardings
,out_shardings
,manual_axes
는 동일한 메시를 참조해야 합니다.manual_axes
는 메시를 기준으로 정렬됩니다.manual_axes
는 모든 인/아웃 샤딩에서 명시적으로 사용해야 합니다. 즉, 각 샤딩의 경우 모든 수동 축이 측정기준을 샤딩하거나 명시적으로 복제되어야 합니다.자유 축 (
manual_axes
에 없는 메시 축)이 입력/출력 샤딩 중 하나에 있는 경우 동일한 측정기준 샤딩의 수동 축보다 작아야 합니다 (위 예에서는 측정기준 샤딩{"model", "data"}
이 유효하지 않음).계산의 리전/본문은 로컬 계산입니다 (예: 사용자 지정 집합 포함). 수동 축을 따라 인/아웃 샤딩과 관련하여 로컬이어야 합니다 (위의 참고사항 참고).
수동 계산 중첩
각 수동 계산이 고유한 수동 축 세트를 사용하여 작동하는 한 여러 수동 계산을 서로 중첩할 수 있습니다.