Shardy (SDY) 방언은 축 기반 텐서 샤딩 표현과 샤딩을 텐서에 연결하는 추가 API 구성요소를 정의합니다.
작업
sdy.all_gather
(sdy::AllGatherOp)
축을 따라 all-gather 통신을 실행합니다.
구문:
operation ::= `sdy.all_gather` $gathering_axes $tensor `out_sharding````=```$out_sharding attr-dict `:` type($result)
gathering_axes
에 지정된 축을 따라 텐서 청크를 수집합니다.
gathering_axes
는 축 목록의 목록입니다. 외부 목록이 텐서 크기를 초과합니다. 각 내부 목록은 각 측정기준에 대해 별도의 수집을 실행해야 하는 축을 지정합니다. 이는 피연산자 (tensor
)의 샤딩에 적용되어 결과 (out_sharding
)의 샤딩을 가져옵니다.
out_sharding
는 결과의 샤딩을 결정하는 데 사용되지 않습니다. 대신 결과의 샤딩은 피연산자와 gathering_axes
의 샤딩에 따라 결정되며 out_sharding
는 이 추론된 샤딩과 일치해야 합니다.
예:
%1 = stablehlo.tanh(%0) {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"a", "b", "c"}, {}, {"d"}\]>]>} : tensor<8x8x8xf32>
%2 = sdy.all_gather [{"b", "c"}, {}, {"d"}\] %1 out_sharding=<@mesh, [{"a"}, {}, {}\]> : tensor<8x8x8xf32>
제약조건:
Sdy_CollectiveOpInterface
에 나열된 제약조건을 충족해야 합니다.gathering_axes
의 요소는AxisRefListAttr
에 나열된 제약 조건을 충족해야 합니다.- 피연산자 샤딩에
gathering_axes
를 적용하면out_sharding
가 됩니다.
트레잇: SameOperandsAndResultType
인터페이스: InferTypeOpInterface
, Sdy_CollectiveOpInterface
속성:
속성 | MLIR 유형 | 설명 |
---|---|---|
gathering_axes | ::mlir::sdy::ListOfAxisRefListsAttr | 축 참조 목록 목록 |
out_sharding | ::mlir::sdy::TensorShardingAttr | 텐서 샤딩 |
피연산자:
피연산자 | 설명 |
---|---|
tensor |
모든 유형 값의 텐서 |
결과:
결과 | 설명 |
---|---|
result |
모든 유형 값의 텐서 |
sdy.all_reduce
(sdy::AllReduceOp)
축을 따라 올리듀스 통신 수행
구문:
operation ::= `sdy.all_reduce` $reduction_axes $tensor `out_sharding````=```$out_sharding attr-dict `:` type($result)
reduction_axes
에 지정된 축을 따라 텐서 청크를 줄입니다.
reduction_axes
의 순서는 결과에 중요하지 않지만 상응하는 복제 그룹의 순서에 영향을 줄 수 있습니다.
제약조건:
Sdy_CollectiveOpInterface
에 나열된 제약조건을 충족해야 합니다.reduction_axes
는AxisRefListAttr
에 나열된 제약조건을 충족해야 합니다.reduction_axes
는 피연산자 샤딩 축과 겹쳐서는 안 됩니다.
트레잇: SameOperandsAndResultType
인터페이스: CollectiveOpInterface
, InferTypeOpInterface
속성:
속성 | MLIR 유형 | 설명 |
---|---|---|
reduction_axes | ::mlir::sdy::AxisRefListAttr | 축 참조 목록 |
out_sharding | ::mlir::sdy::TensorShardingAttr | 텐서 샤딩 |
피연산자:
피연산자 | 설명 |
---|---|
tensor |
모든 유형 값의 텐서 |
결과:
결과 | 설명 |
---|---|
result |
모든 유형 값의 텐서 |
sdy.all_slice
(sdy::AllSliceOp)
축을 따라 동적 슬라이스 작업을 실행합니다.
구문:
operation ::= `sdy.all_slice` $slicing_axes $tensor `out_sharding````=```$out_sharding attr-dict `:` type($result)
slicing_axes
에 지정된 축을 따라 텐서의 청크를 슬라이스합니다. sdy.all_slice
와 sdy.all_gather
사이에는 대수적 이중성이 있습니다.
slicing_axes
는 축 목록의 목록입니다. 외부 목록이 텐서 크기를 초과합니다. 각 내부 목록은 각 측정기준에서 슬라이스가 실행되어야 하는 축을 지정합니다. 이는 피연산자 (tensor
)의 샤딩에 적용되어 결과(out_sharding
)의 샤딩을 가져옵니다.
out_sharding
는 결과의 샤딩을 결정하는 데 사용되지 않습니다. 대신 결과의 샤딩은 피연산자와 slicing_axes
의 샤딩에 따라 결정되며 out_sharding
는 이 추론된 샤딩과 일치해야 합니다.
예:
%1 = stablehlo.tanh(%0) {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"a"}, {}, {}\]>]>} : tensor<8x8x8xf32>
%2 = sdy.all_slice [{"b", "c"}, {}, {"d"}\] %1 out_sharding=<@mesh, [{"a", "b", "c"}, {}, {"d"}\]> : tensor<8x8x8xf32>
제약조건:
slicing_axes
의 요소는AxisRefListAttr
에 나열된 제약 조건을 충족해야 합니다.Sdy_CollectiveOpInterface
에 나열된 제약조건을 충족해야 합니다.- 피연산자 샤딩에
slicing_axes
를 적용하면out_sharding
가 됩니다.
트레잇: SameOperandsAndResultType
인터페이스: CollectiveOpInterface
, InferTypeOpInterface
속성:
속성 | MLIR 유형 | 설명 |
---|---|---|
slicing_axes | ::mlir::sdy::ListOfAxisRefListsAttr | 축 참조 목록 목록 |
out_sharding | ::mlir::sdy::TensorShardingAttr | 텐서 샤딩 |
피연산자:
피연산자 | 설명 |
---|---|
tensor |
모든 유형 값의 텐서 |
결과:
결과 | 설명 |
---|---|
result |
모든 유형 값의 텐서 |
sdy.all_to_all
(sdy::AllToAllOp)
축을 따라 전체 대 전체 통신을 실행합니다.
구문:
operation ::= `sdy.all_to_all` $params $tensor `out_sharding````=```$out_sharding attr-dict `:` type($result)
이 연산은 매개변수 목록의 각 (axes, src_dim, tgt_dim) 튜플에 대해 tgt_dim
측정기준 및 axes
에 지정된 축을 따라 텐서 청크를 자르고, 축을 따라 청크를 흩어뜨린 후 src_dim
측정기준을 따라 연결합니다.
이 연산은 기본적으로 src_dim
및 axes
을 따라 all-gather를 수행한 후 tgt_dim
및 axes
을 따라 all-slice를 수행하는 것입니다. 즉, 입력 텐서의 축 샤딩 크기 src_dim
의 접미사가 출력 텐서의 축 샤딩 크기 tgt_dim
에 추가됩니다.
all-to-all은 피연산자 (tensor
)의 샤딩에 적용되어 결과 (out_sharding
)의 샤딩을 얻습니다.
out_sharding
는 결과의 샤딩을 결정하는 데 사용되지 않습니다. 대신 결과의 샤딩은 피연산자 src_dim
, tgt_dim
, axes
의 샤딩에 따라 결정되며 out_sharding
는 이 추론된 샤딩과 일치해야 합니다.
예:
%1 = stablehlo.tanh(%0) {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"a", "b"}, {"c"}, {}, {}\]>]>} : tensor<8x8x4x4x32>
%2 = sdy.all_to_all [{"b"}: 0->2, {"c"}: 1->3] %1 out_sharding=<@mesh, [{"a"}, {}, {"b"}, {"c"}\]> : tensor<8x8x4x4x32>
제약조건:
Sdy_CollectiveOpInterface
에 나열된 제약조건을 충족해야 합니다.- 매개변수 목록은 비워 둘 수 없습니다.
params
의 각 매개변수에 대해 다음을 실행합니다.axes
의 요소는AxisRefAttr
의 제약 조건을 충족해야 합니다.src_dim
및tgt_dim
는 유효한 크기 (음이 아니고 텐서 계급보다 작음)여야 합니다.- 모든
src_dim
또는tgt_dim
는 모든 매개변수에서 고유해야 합니다. src_dim
는 모든 매개변수에서 오름차순으로 정렬해야 합니다.
- 피연산자 샤딩에서
axes
를src_dim
에서tgt_dim
로 이동하면out_sharding
이 됩니다.
트레잇: SameOperandsAndResultType
인터페이스: InferTypeOpInterface
, Sdy_CollectiveOpInterface
속성:
속성 | MLIR 유형 | 설명 |
---|---|---|
params | ::mlir::sdy::AlltoAllParamListAttr | all-to-all 매개변수 목록 |
out_sharding | ::mlir::sdy::TensorShardingAttr | 텐서 샤딩 |
피연산자:
피연산자 | 설명 |
---|---|
tensor |
모든 유형 값의 텐서 |
결과:
결과 | 설명 |
---|---|
result |
모든 유형 값의 텐서 |
sdy.collective_permute
(sdy::CollectivePermuteOp)
축을 대체하기 위해 집합-순열 통신을 실행합니다.
구문:
operation ::= `sdy.collective_permute` $tensor `out_sharding````=```$out_sharding attr-dict `:` type($result)
각 기기에서 다른 기기로 입력 텐서의 청크를 전송하여 텐서를 샤딩하는 축을 재정렬/대체합니다.
집합 순열은 각 측정기준이 이전과 동일하게 샤딩되도록 입력 샤딩을 변환할 수 있습니다. 즉, 크기의 곱이 이전에 텐서를 샤딩한 축의 곱과 일치하는 축을 따라 샤딩되어야 합니다.
이 방법은 단일 측정기준 또는 여러 측정기준에서 축의 순서를 변경하고 샤딩된 축을 복제된 축으로 전환하는 데 유용합니다.
아래 예에서 샤딩된 텐서 크기는 tensor<1x4x2xf32>
이며 이는 집합 순열에 의해 보존됩니다.
예:
sdy.mesh @mesh = <["a"=2, "b"=2, "c"=4, "d"=2, "e"=2, "f"=2]>
%1 = stablehlo.tanh(%0) {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"a", "c"}, {"f"}, {"d", "e"}\]>]>} : tensor<8x8x8xf32>
%2 = sdy.collective_permute %1 out_sharding=<@mesh, [{"c":(1)2, "b", "f"}, {"a"}, {"e", "d"}\]> : tensor<8x8x8xf32>
제약조건:
Sdy_CollectiveOpInterface
에 나열된 제약조건을 충족해야 합니다.- 입력 샤딩과 출력 샤딩의 메시가 다른 경우 이러한 메시의 축은 정확히 동일해야 하며 기기 ID의 순서는 달라야 합니다.
- 각 측정기준의 경우
out_sharding
의 샤딩 축 크기 곱셈 결과는 상응하는 피연산자 측정기준 샤딩의 곱셈 결과와 일치해야 합니다.
트레잇: SameOperandsAndResultType
인터페이스: CollectiveOpInterface
, InferTypeOpInterface
속성:
속성 | MLIR 유형 | 설명 |
---|---|---|
out_sharding | ::mlir::sdy::TensorShardingAttr | 텐서 샤딩 |
피연산자:
피연산자 | 설명 |
---|---|
tensor |
모든 유형 값의 텐서 |
결과:
결과 | 설명 |
---|---|
result |
모든 유형 값의 텐서 |
sdy.constant
(sdy::ConstantOp)
상수 연산
상수 value
에서 output
텐서를 생성합니다.
참고: https://github.com/openxla/stablehlo/blob/main/docs/spec.md#constant
예:
%output = sdy.constant dense<[[0.0, 1.0], [2.0, 3.0]]> : tensor<2x2xf32>
트레잇: AlwaysSpeculatableImplTrait
인터페이스: ConditionallySpeculatable
, InferTypeOpInterface
, NoMemoryEffect (MemoryEffectOpInterface)
효과: MemoryEffects::Effect{}
속성:
속성 | MLIR 유형 | 설명 |
---|---|---|
value | ::mlir::ElementsAttr | 상수 벡터/텐서 속성 |
결과:
결과 | 설명 |
---|---|
output |
모든 유형 값의 정적 형식 텐서 |
sdy.data_flow_edge
(sdy::DataFlowEdgeOp)
데이터 흐름 가장자리 연산.
구문:
operation ::= `sdy.data_flow_edge` $input (`sharding````=``` $sharding^)? attr-dict `:` type($result)
일부 작업 X의 데이터 흐름 가장자리가 소스 집합(각각 X의 피연산자 또는 X의 블록 종료자의 피연산자)과 타겟 집합 (각각 X의 결과 또는 X의 블록 인수) 간의 브리지를 정의합니다. 따라서 모든 소스와 타겟은 동일한 방식으로 샤딩되어야 합니다.
연산에는 서로 직교하는 여러 데이터 흐름 가장자리가 있을 수 있습니다.
예를 들면 다음과 같습니다.
y_0, ..., y_n = while (x_0, ..., x_n)
((pred_arg_0,... , pred_arg_n) { ... })
((body_arg_0,..., body_arg_n) {
...
return return_value_0, ..., return_value_n
})
이 while 연산에는 n개의 데이터 흐름 가장자리가 있으며, i번째 데이터 흐름 가장자리가 소스 x_i
, return_value_i
와 타겟 y_i
, pred_arg_i
, body_arg_i
사이입니다.
sdy.data_flow_edge
는 다른 용도가 없어야 하는 에지의 소유자를 입력으로 사용합니다 (어떤 타겟이든 될 수 있지만 블록 인수가 아닌 op 결과가 더 좋음). 이 연산자는 원래 사용되지 않았던 입력을 사용할 수 있으므로 순수하지 않습니다.
sdy.data_flow_edge
는 또한 에지의 모든 타겟에 대한 선택적 샤딩을 보유하며, 전파 중에 타겟의 샤딩 대신 이 샤딩을 업데이트해야 합니다 (연결 가능한 경우). 이는 다음과 같은 작업을 훨씬 더 효율적으로 수행할 수 있으므로 연산에 많은 가장자리가 있는 경우에 유용합니다.
- 각 에지를 통해 별도로 전파됩니다.
- 모든 타겟을 한 번에 업데이트하는 대신 각 에지의 샤딩을 개별적으로 업데이트합니다(예: 연산에 결과 샤딩을 위한 변경 불가능한 단일
TensorShardingPerValueAttr
가 있음). - 소스의 샤딩이 변경될 때 각 에지를 작업 목록에 별도로 추가합니다.
전파는 소스가 피연산자이고 대상이 결과인 일반 연산자와 ID sdy.op_sharding_rule
인 것처럼 sdy.data_flow_edge
의 모든 소스와 대상 간에 샤딩을 전파합니다. 즉, 전방 전파는 소스에서 대상으로, 후방 전파는 대상에서 소스로 진행됩니다.
sdy.data_flow_edge
의 입력이 SdyDialect
연산자에 의해 정의되는 것은 허용되지 않으므로 등록되지 않은 sdy.sharding
속성이 있는 연산자에 의해 정의되었다고 가정할 수 있습니다.
트레잇: SameOperandsAndResultType
인터페이스: InferTypeOpInterface
속성:
속성 | MLIR 유형 | 설명 |
---|---|---|
sharding | ::mlir::sdy::TensorShardingAttr | 텐서 샤딩 |
피연산자:
피연산자 | 설명 |
---|---|
input |
모든 유형 값의 모양 |
결과:
결과 | 설명 |
---|---|
result |
모든 유형 값의 모양 |
sdy.manual_computation
(sdy::ManualComputationOp)
수동 집합을 사용한 멀티 디바이스 병렬 작업
구문:
operation ::= `sdy.manual_computation` `(`operands`)`
`in_shardings````=```custom<StrippedTensorShardingPerValueAttr>($in_shardings)
`out_shardings````=```custom<StrippedTensorShardingPerValueAttr>($out_shardings)
`manual_axes````=```$manual_axes
custom<SingleBlockRegionNoBlockId>($body)
attr-dict
`:`
functional-type(operands, results)
명시적 집합을 사용하여 기기별 로컬 코드로 작성된 영역으로 이동합니다. 이 영역에서는 논리적 도형이 기기별 로컬 물리적 버퍼 도형과 일치하고 집합이 물리적 교차 기기 통신에 정확하게 대응합니다.
본체는 manual_axes를 기준으로 로컬입니다. 전파는 manual_axes 목록에 없는 모든 자유축의 본문을 통해 발생합니다.
제약조건:
in_shardings
및out_shardings
의 요소는TensorShardingAttr
에 나열된 제약조건을 충족해야 합니다.- 연산 영역의 전역 및 로컬 텐서 입력/출력 수가 일치해야 합니다.
- 수동 축은 각 측정기준 샤딩에서 무료 축 앞에 있어야 합니다.
- 수동 축은 패딩을 도입할 수 없습니다. 즉, 측정기준 크기는 해당하는 수동 축 크기로 나눌 수 있어야 합니다.
- op regions 인수/결과의 전역 및 로컬 도형이 일치해야 합니다.
- 수동 축은 분할되지 않습니다.
트레잇: IsolatedFromAbove
, RecursiveMemoryEffects
, SingleBlockImplicitTerminator<ReturnOp>
, SingleBlock
인터페이스: ShardableDataFlowOpInterface
속성:
속성 | MLIR 유형 | 설명 |
---|---|---|
in_shardings | ::mlir::sdy::TensorShardingPerValueAttr | 연산의 피연산자/결과당 텐서 샤딩 |
out_shardings | ::mlir::sdy::TensorShardingPerValueAttr | 연산의 피연산자/결과당 텐서 샤딩 |
manual_axes | ::mlir::sdy::ManualAxesAttr | ManualComputationOp가 수동인 축 목록 |
피연산자:
피연산자 | 설명 |
---|---|
tensors |
모든 유형 값의 순위 지정된 텐서의 변형 |
결과:
결과 | 설명 |
---|---|
results |
모든 유형 값의 순위 지정된 텐서의 변형 |
sdy.mesh
(sdy::MeshOp)
이름이 지정된 메시
구문:
operation ::= `sdy.mesh` $sym_name `=` $mesh attr-dict
이름이 지정된 새 메시를 정의합니다. 모듈의 모든 메시에는 동일한 수의 기기가 있어야 합니다 (단일 device_id가 있는 메시 제외).
메시는 모듈의 SymbolTable
에 표시되고 name
에서 참조할 수 있는 Symbol
작업입니다.
트레잇: HasParent<ModuleOp>
인터페이스: Symbol
속성:
속성 | MLIR 유형 | 설명 |
---|---|---|
sym_name | ::mlir::StringAttr | 문자열 속성 |
mesh | ::mlir::sdy::MeshAttr | 축 메시 및 기기 목록 |
sdy.named_computation
(sdy::NamedComputationOp)
이름이 지정된 계산 작업
구문:
operation ::= `sdy.named_computation` `<`$name`>` `` `(` $operands `)`
(`in_shardings````=```custom<StrippedTensorShardingPerValueAttr>($in_shardings)^)?
(`out_shardings````=```custom<StrippedTensorShardingPerValueAttr>($out_shardings)^)?
custom<SingleBlockRegionNoBlockId>($body)
attr-dict
`:` functional-type($operands, results)
계산(작업 블록)을 그룹화하고 이름을 지정합니다. 모든 것이 인라인 처리된 것처럼 전파가 영역 안팎으로 흐릅니다.
이를 사용하여 호출 안내를 통해 다른 함수로 전파를 처리할 수 있습니다. Shardy 사용자는 호출 작업을 sdy.named_computation
작업으로 변환하여 호출된 함수의 본문을 named_computation
의 본문으로 복제/복사하는 가져오기/내보내기 패스를 작성해야 합니다.
리전의 각 블록 인수 및 반환 값의 유형은 피연산자의 유형 및 연산자의 결과 유형과 동일해야 합니다.
예:
%1 = sdy.named_computation<"foo">(%0) (%arg1: tensor<16x32xf32>) {
sdy.return %arg1 : tensor<16x32xf32>
} : (tensor<16x32xf32>) -> tensor<16x32xf32>
트레잇: IsolatedFromAbove
, RecursiveMemoryEffects
, RecursivelySpeculatableImplTrait
, SingleBlockImplicitTerminator<ReturnOp>
, SingleBlock
인터페이스: ConditionallySpeculatable
, InferTypeOpInterface
, ShardableDataFlowOpInterface
속성:
속성 | MLIR 유형 | 설명 |
---|---|---|
name | ::mlir::StringAttr | 문자열 속성 |
in_shardings | ::mlir::sdy::TensorShardingPerValueAttr | 연산의 피연산자/결과당 텐서 샤딩 |
out_shardings | ::mlir::sdy::TensorShardingPerValueAttr | 연산의 피연산자/결과당 텐서 샤딩 |
피연산자:
피연산자 | 설명 |
---|---|
operands |
모든 유형의 변형 인수 |
결과:
결과 | 설명 |
---|---|
«unnamed» | 모든 유형의 변형 인수 |
sdy.propagation_barrier
(sdy::PropagationBarrierOp)
전파 장벽 작업
구문:
operation ::= `sdy.propagation_barrier` $input `allowed_direction````=```$allowed_direction attr-dict `:` type($input)
이 연산자는 ID 연산자처럼 작동하여 입력으로 사용한 것과 동일한 값을 출력합니다. 하지만 전파 측면에서는 전파가 특정 방향으로만 전파되도록 허용합니다.
이렇게 하면 배리어 연산 결과와 연산자의 사용 간에 샤딩이 전파되지 않습니다.
FORWARD
는 샤딩이 피연산자에서 결과로만 흐를 수 있음을 의미합니다.BACKWARD
는 샤딩이 결과에서 피연산자로만 흐를 수 있음을 의미합니다.NONE
은 이 연산을 통해 샤딩이 전파될 수 없음을 의미합니다.- 이 연산은 중복되므로
BOTH
를 지정할 수 없습니다.
트레잇: AlwaysSpeculatableImplTrait
, SameOperandsAndResultType
인터페이스: ConditionallySpeculatable
, InferTypeOpInterface
, NoMemoryEffect (MemoryEffectOpInterface)
효과: MemoryEffects::Effect{}
속성:
속성 | MLIR 유형 | 설명 |
---|---|---|
allowed_direction | ::mlir::sdy::PropagationDirectionAttr | 전파 방향 enum |
피연산자:
피연산자 | 설명 |
---|---|
input |
모든 유형 값의 순위 지정된 텐서 |
결과:
결과 | 설명 |
---|---|
result |
모든 유형 값의 순위 지정된 텐서 |
sdy.reshard
(sdy::ReshardOp)
텐서를 다른 샤딩으로 리샤딩합니다.
구문:
operation ::= `sdy.reshard` $input $sharding attr-dict `:` type($result)
입력 텐서의 기존 샤딩과 다른 지정된 샤딩으로 입력 텐서를 리샤딩합니다.
ShardingConstraintOp와 ReshardOp는 모두 샤딩을 텐서에 연결합니다. 수명은 다음과 같습니다.
- 샤딩 전파 전에 ShardingConstraintOp가 사용자에 의해 추가됩니다.
- 샤딩 전파는 ShardingConstraintOp를 사용합니다. 샤딩 전파 결과에 ShardingConstraintOp가 없습니다. 대신 필요한 경우 ReshardOp를 추가할 수 있습니다.
- 파티셔너는 ReshardOp를 집계 연산 (또는 ID 연산)으로 변환합니다. 파티셔너의 결과에 ReshardOp가 없어야 합니다.
// TODO(b/331680067). 중복된 // 리샤드 작업을 삭제하는 정규화 패턴을 추가합니다.
트레잇: AlwaysSpeculatableImplTrait
, SameOperandsAndResultType
인터페이스: ConditionallySpeculatable
, InferTypeOpInterface
, NoMemoryEffect (MemoryEffectOpInterface)
효과: MemoryEffects::Effect{}
속성:
속성 | MLIR 유형 | 설명 |
---|---|---|
sharding | ::mlir::sdy::TensorShardingAttr | 텐서 샤딩 |
피연산자:
피연산자 | 설명 |
---|---|
input |
모든 유형 값의 텐서 |
결과:
결과 | 설명 |
---|---|
result |
모든 유형 값의 텐서 |
sdy.return
(sdy::ReturnOp)
sdy.return
작업은 sdy
리전 기반 작업 및 기타 Shardy 리전 기반 작업에 연결된 리전을 종료합니다. 이는 가변 인수입니다. 유형이 어떤 유형이든 될 수 있지만 동일한 종류 (예: AnyTensor
)여야 하며, 따라서 Shardy IR 스택의 다양한 수준에서 재사용할 수 있습니다.
구문:
operation ::= `sdy.return` attr-dict ($results^ `:` type($results))?
트레잇: AlwaysSpeculatableImplTrait
, Terminator
인터페이스: ConditionallySpeculatable
, NoMemoryEffect (MemoryEffectOpInterface)
효과: MemoryEffects::Effect{}
피연산자:
피연산자 | 설명 |
---|---|
results |
모든 유형의 변형 인수 |
sdy.sharding_constraint
(sdy::ShardingConstraintOp)
텐서를 지정된 샤딩으로 제한합니다.
구문:
operation ::= `sdy.sharding_constraint` $input $sharding attr-dict `:` type($result)
중간 텐서 (예: matmul의 결과)에 샤딩을 연결하여 이 텐서 또는 사용의 하위 집합을 샤딩해야 하는 방식을 나타냅니다.
샤딩에 개방형 측정기준과 제약 조건이 없는 축이 있는 경우 텐서를 개방형 측정기준을 따라 추가로 샤딩할 수 있음을 의미합니다.
이 작업은 다음 중 하나를 실행할 수 있습니다.
- 사용이 없음 (댕글링) - 즉, 연결된 샤딩이 입력 텐서 자체를 샤딩하는 방법입니다.
- 사용이 있음 - 즉, 연결된 샤딩은 샤딩 제약 조건 연산자의 사용이 샤딩되는 방식이지만 입력 텐서의 다른 사용에는 다른 샤딩이 있을 수 있습니다. 입력 텐서에 다른 사용이 없는 경우 동작은 사용이 없는 사례와 동일합니다.
트레잇: SameOperandsAndResultType
인터페이스: InferTypeOpInterface
속성:
속성 | MLIR 유형 | 설명 |
---|---|---|
sharding | ::mlir::sdy::TensorShardingAttr | 텐서 샤딩 |
피연산자:
피연산자 | 설명 |
---|---|
input |
모든 유형 값의 텐서 |
결과:
결과 | 설명 |
---|---|
result |
모든 유형 값의 텐서 |
sdy.sharding_group
(sdy::ShardingGroupOp)
그룹의 텐서가 동일한 샤딩을 갖도록 제약합니다.
구문:
operation ::= `sdy.sharding_group` $input `group_id````=```$group_id attr-dict `:` type($input)
이 연산자는 샤딩 그룹(동일한 샤딩이 적용되도록 시행되는 텐서 그룹)에 텐서를 할당하는 인터페이스를 제공합니다. 전파 중에 하나의 그룹 요소가 샤딩되면 다른 모든 구성원이 정확히 동일한 방식으로 샤딩됩니다. 이 연산은 인수 그룹 ID를 사용하고 결과를 반환하지 않지만 대신 내부 샤딩 그룹 표현을 수정하여 지정된 ID의 그룹에 입력 텐서를 추가합니다.
인터페이스: InferTypeOpInterface
속성:
속성 | MLIR 유형 | 설명 |
---|---|---|
group_id | ::mlir::IntegerAttr | 부호 없는 64비트 정수 속성 |
피연산자:
피연산자 | 설명 |
---|---|
input |
모든 유형 값의 순위 지정된 텐서 |
속성
AllToAllParamAttr
All-to-all 매개변수
구문:
#sdy.all_to_all_param<
::llvm::ArrayRef<AxisRefAttr>, # axes
int64_t, # src_dim
int64_t # tgt_dim
>
all-to-all을 실행할 축과 소스/타겟 측정기준이 포함된 튜플입니다.
매개변수:
매개변수 | C++ 유형 | 설명 |
---|---|---|
축 | ::llvm::ArrayRef<AxisRefAttr> |
all-to-all을 실행할 축 |
src_dim | int64_t |
소스 측정기준 색인 |
tgt_dim | int64_t |
타겟 측정기준 색인 |
AlltoAllParamListAttr
all-to-all 매개변수 목록
구문:
#sdy.all_to_all_param_list<
::llvm::ArrayRef<AllToAllParamAttr> # value
>
매개변수:
매개변수 | C++ 유형 | 설명 |
---|---|---|
값 | ::llvm::ArrayRef<AllToAllParamAttr> |
AxisRefAttr
전체 축 또는 분할된 하위 축 참조
구문:
#sdy.axis_ref<
::llvm::StringRef, # name
SubAxisInfoAttr # sub_axis_info
>
제약조건:
name
는 경계MeshAttr
에 있어야 합니다.sub_axis_info
가 있는 경우SubAxisInfoAttr
의 제약 조건을 충족해야 합니다.
매개변수:
매개변수 | C++ 유형 | 설명 |
---|---|---|
이름 | ::llvm::StringRef |
이 축의 이름 |
sub_axis_info | SubAxisInfoAttr |
하위 축인 경우 추가 정보 |
AxisRefListAttr
축 참조 목록
구문:
#sdy.axis_ref_list<
::llvm::ArrayRef<AxisRefAttr> # value
>
제약조건:
value
의 요소는AxisRefAttr
의 제약 조건을 충족해야 합니다.- 중복된 축 참조나 서로 겹치는 하위 축이 없습니다.
- 인접한 두 축 참조가 동일한 전체 축의 연속된 하위 축이 아닙니다. 즉, 하나의 하위 축 또는 전체 축으로 병합할 수 있습니다.
매개변수:
매개변수 | C++ 유형 | 설명 |
---|---|---|
값 | ::llvm::ArrayRef<AxisRefAttr> |
DimMappingAttr
측정기준의 요인 지수 목록
빈 목록은 null 매핑 (*
로 파싱/출력됨)을 나타냅니다. 즉, 측정기준이 어떤 요인에도 매핑되지 않습니다.
제약조건:
- 요인 색인이 하나 이상 있습니다.
- 요소 색인은 [0,
$factor_sizes
) 범위에 있어야 합니다. - 요인이 여러 개인 경우 요인 중 어느 것도 크기가 1일 수 없습니다.
- 중복된 요인 색인이 없습니다.
매개변수:
매개변수 | C++ 유형 | 설명 |
---|---|---|
factor_indices | ::llvm::ArrayRef<int64_t> |
이 측정기준이 매핑된 요인 |
DimensionShardingAttr
측정기준 샤딩
텐서 측정기준을 주요 측정기준에서 하위 측정기준으로 샤딩할 축 이름 목록, 측정기준을 더 샤딩할 수 있는지 나타내는 불리언, 샤딩 전파 중에 준수되는 이 측정기준 샤딩의 우선순위를 나타내는 선택적 정수입니다. 우선순위는 사용자 샤딩 주석에서 비롯되며 값이 낮을수록 우선순위가 높습니다. 주석에 우선순위가 없으면 가장 높은 우선순위가 가정됩니다.
제약조건:
axes
의 요소는AxisRefListAttr
에 나열된 제약 조건을 충족해야 합니다.- 측정기준 샤딩에 우선순위가 있는 경우:
- 우선순위는 0 이상입니다.
- 측정기준이 닫힌 경우 측정기준에 하나 이상의 축이 있습니다.
매개변수:
매개변수 | C++ 유형 | 설명 |
---|---|---|
축 | ::llvm::ArrayRef<AxisRefAttr> |
축 참조 |
is_closed | bool |
이 측정기준을 더 이상 샤딩할 수 없는지 여부 |
우선순위 | std::optional<int64_t> |
사용자 우선순위 기반 전파 중에 사용되는 우선순위 |
ListOfAxisRefListsAttr
축 참조 목록 목록
구문:
#sdy.list_of_axis_ref_lists<
::llvm::ArrayRef<AxisRefListAttr> # value
>
매개변수:
매개변수 | C++ 유형 | 설명 |
---|---|---|
값 | ::llvm::ArrayRef<AxisRefListAttr> |
ManualAxesAttr
ManualComputationOp가 수동인 축 목록
구문:
#sdy.manual_axes<
::llvm::ArrayRef<StringAttr> # value
>
매개변수:
매개변수 | C++ 유형 | 설명 |
---|---|---|
값 | ::llvm::ArrayRef<StringAttr> |
MeshAttr
축 메시 및 기기 목록
구문:
#sdy.mesh<
::llvm::ArrayRef<MeshAxisAttr>, # axes
::llvm::ArrayRef<int64_t> # device_ids
>
메시는 축 목록과 기기 순서를 지정하는 기기 ID 목록(선택사항)입니다.
축 목록이 비어 있으면 메시에 크기가 1인 이름이 지정되지 않은 암시적 축이 있습니다. 이 경우 기기 ID 목록이 제공되지 않으면 암시적 기기 ID 목록은 [0]입니다. 기기 ID 목록이 제공되는 경우 목록에는 0이 아닌 값의 단일 정수가 포함되어야 합니다. 이를 최대 샤딩 사례라고 합니다.
최대 샤딩이 아닌 모든 케이스의 경우 기기 ID 목록이 지정된 경우 축 크기의 곱이 기기 수와 일치해야 합니다. 기기 ID 목록을 지정하지 않으면 암시적 기기 ID 목록은 iota(product(axes))입니다. 편의상 iota(product(axes))와 동일한 기기 ID 목록을 지정하는 것도 허용되지 않습니다. 이 경우 기기 ID 목록을 지정해서는 안 됩니다.
다음은 메시의 몇 가지 예입니다.
- 빈 메시는 전파 중에 대체될 수 있는 자리표시자 메시를 나타냅니다. <[]>
- 이름이 지정되지 않은 축과 명시적 기기 ID가 있는 메시로, 일반적으로 최대 샤딩을 나타내는 데 사용됩니다. <[], device_ids=[3]>
- 축이 2개이고 기기 ID가 암시된 메시 iota(6): <["a"=2, "b"=3]>
- 두 축과 기기 순서를 지정하는 명시적 기기 ID가 있는 메시: <["a"=3, "b"=2], device_ids=[0, 2, 4, 1, 3, 5]>
제약조건:
axes
의 요소에는 중복 이름이 없어야 합니다.device_ids
가 지정된 경우:- 축 크기의 곱은 기기 수와 일치해야 합니다.
- 모든 요소는 음수가 아니어야 합니다.
device_ids
는iota(product(axis_sizes))
과 같지 않아야 합니다.- 정렬된
device_ids
는iota(product(axis_sizes))
여야 합니다.
매개변수:
매개변수 | C++ 유형 | 설명 |
---|---|---|
축 | ::llvm::ArrayRef<MeshAxisAttr> |
메시 축 |
device_ids | ::llvm::ArrayRef<int64_t> |
명시적 기기 순서 또는 최대 기기 ID |
MeshAxisAttr
메시의 이름이 지정된 축
구문:
#sdy.mesh_axis<
::llvm::StringRef, # name
int64_t # size
>
매개변수:
매개변수 | C++ 유형 | 설명 |
---|---|---|
이름 | ::llvm::StringRef |
이름 |
크기 | int64_t |
이 축의 크기 |
OpShardingRuleAttr
작업을 파티션할 수 있는 방법을 지정합니다.
구문:
#sdy.op_sharding_rule<
::llvm::ArrayRef<int64_t>, # factor_sizes
::llvm::ArrayRef<TensorMappingAttr>, # operand_mappings
::llvm::ArrayRef<TensorMappingAttr>, # result_mappings
::llvm::ArrayRef<int64_t>, # reduction_factors
::llvm::ArrayRef<int64_t>, # need_replication_factors
::llvm::ArrayRef<int64_t>, # permutation_factors
::llvm::ArrayRef<int64_t>, # blocked_propagation_factors
bool # is_custom_rule
>
샤딩 규칙은 연산의 다양한 속성(속성, 피연산자의 모양, 결과의 모양 등)에 따라 연산을 파티션할 수 있는 방법을 지정합니다. 예를 들면 다음과 같습니다.
%0 = stablehlo.add %arg0, %arg1 {
sdy.sharding_rule = #sdy.op_sharding_rule<
([i, j],[i, j])->([i, j])
{i=8, j=8}>
} : tensor<8x8xf32>
%1 = stablehlo.dot_general %arg2, %arg3, contracting_dims = [1] x [0] {
sdy.sharding_rule = #sdy.op_sharding_rule<
([i, k],[k, j])->([i, j])
{i=8, j=16, k=8}>
}: (tensor<8x8xf32>, tensor<8x16xf32>) -> tensor<8x16xf32>
크기가 1인 요소는 샤딩할 수 없지만 허용됩니다. 이는 주로 완전성을 위해서입니다. 점별 연산과 같은 많은 연산에는 피연산자와 결과 간에 일치하는 크기 1의 측정기준이 있습니다.
요인 유형:
reduction_factors
에는 감소가 필요한 요소의 색인(예: 점 연산의 축소 측정기준)이 포함되어 있습니다.need_replication_factors
에는 정렬 작업의 정렬된 측정기준과 같이 전체 복제가 필요한 요소의 색인이 포함됩니다.permutation_factors
에는 패딩 작업의 패딩 크기와 같이 샤딩된 경우 집합-순열이 필요한 요소의 색인이 포함됩니다.- 다른 모든 요인은 패스 스루 요인으로 간주됩니다. 즉, 매핑된 모든 텐서에서 동일한 방식으로 샤딩된 경우 통신이 필요하지 않은 요인입니다.
blocked_propagation_factors
에는 샤딩이 전파되지 않는 요소가 포함되어 있습니다. 요소 유형과 직교합니다. 즉, 차단된 전파 계수는 모든 계수 유형이 될 수 있습니다.
is_custom_rule
는 사용자가 정의한 규칙인지 여부를 나타냅니다. 사용자는 맞춤 호출에 대한 샤딩 규칙을 정의하거나 표준 작업에 대해 사전 정의된 샤딩 규칙을 덮어쓸 수 있습니다. 맞춤 규칙은 항상 보존되며 삭제되지 않습니다.
제약조건:
- 피연산자/결과 매핑 수는 연산의 피연산자/결과 수와 일치해야 합니다.
- 매핑이 하나 이상 있습니다 (operand/result가 없는 연산에 관한 규칙은 있을 수 없음).
- 각
TensorMappingAttr
의 순위는 상응하는 텐서 유형의 순위와 일치합니다. - 각 요인 그룹 (
reduction_factors
,need_replication_factors
,permutation_factors
)의 경우:- 요소는 [0,
$factor_sizes
] 범위여야 합니다. - 각 그룹 내에서 그리고 그룹 간에 중복된 요인 색인이 없습니다.
- 요소는 [0,
매개변수:
매개변수 | C++ 유형 | 설명 |
---|---|---|
factor_sizes | ::llvm::ArrayRef<int64_t> |
이 규칙의 모든 요소의 크기 |
operand_mappings | ::llvm::ArrayRef<TensorMappingAttr> |
피연산자 매핑 |
result_mappings | ::llvm::ArrayRef<TensorMappingAttr> |
결과 매핑 |
reduction_factors | ::llvm::ArrayRef<int64_t> |
감소가 필요한 요소 |
need_replication_factors | ::llvm::ArrayRef<int64_t> |
전체 복제가 필요한 요인 |
permutation_factors | ::llvm::ArrayRef<int64_t> |
collective-permute가 필요한 요소 |
blocked_propagation_factors | ::llvm::ArrayRef<int64_t> |
샤딩이 전파되지 않는 요소 |
is_custom_rule | bool |
규칙이 stablehlo.custom_call에 관한 것인지 여부 |
SubAxisInfoAttr
이 하위 축이 전체 축에서 파생되는 방식에 관한 정보
구문:
#sdy.sub_axis_info<
int64_t, # pre_size
int64_t # size
>
전체 축을 n개의 하위 축으로 분할하면 축의 모양이 [k_1,...,k_n]으로 변경되고, i번째 하위 축은 왼쪽의 모든 축 크기 m=prod(k_1,...,k_(i-1))
(예: 사전 크기)와 크기 k_i의 곱으로 표현할 수 있습니다. 따라서 sub-axis-info 속성은 이러한 두 숫자를 보유하며 다음과 같이 표시됩니다. 사전 크기 m과 크기 k의 경우 (m)k
.
제약조건:
pre-size
은(는) 1 이상입니다.size
가 1보다 큼pre-size
는 전체 축의 크기를 나눠야 합니다. 즉,pre-size
와size
모두 전체 축의 크기를 나누고 하위 축이 전체 축을 초과하지 않습니다.- 하위 축의 크기가 상응하는 전체 축의 크기와 같지 않은 경우 전체 축을 대신 사용해야 합니다.
매개변수:
매개변수 | C++ 유형 | 설명 |
---|---|---|
pre_size | int64_t |
이 하위 축 왼쪽에 있는 하위 축 크기의 곱 |
크기 | int64_t |
이 하위 축의 크기 |
TensorMappingAttr
텐서의 각 측정기준에 대한 인수 매핑입니다.
구문:
#sdy.tensor_mapping<
::llvm::ArrayRef<DimMappingAttr> # dim_mappings
>
제약조건:
dim_mappings
의 요소는DimMappingAttr
의 제약 조건을 충족해야 합니다.- 측정기준 간에 중복된 요인 색인이 없습니다.
매개변수:
매개변수 | C++ 유형 | 설명 |
---|---|---|
dim_mappings | ::llvm::ArrayRef<DimMappingAttr> |
측정기준 매핑 |
TensorShardingAttr
텐서 샤딩
구문:
#sdy.sharding<
::mlir::Attribute, # mesh_or_ref
::llvm::ArrayRef<DimensionShardingAttr>, # dim_shardings
::llvm::ArrayRef<AxisRefAttr> # replicated_axes
>
텐서 샤딩은 특정 메시에 바인딩되며 해당 메시의 축 이름만 참조할 수 있습니다. 크기 샤딩은 텐서의 각 크기에 대해 주요 축에서 보조 축으로 샤딩되는 축 (또는 하위 축)을 알려줍니다. 측정기준을 샤딩하지 않는 다른 모든 축은 암시적으로 또는 명시적으로 (복제된 축 목록에 표시되는 경우) 복제됩니다.
이 샤딩이 바인딩된 메시는 기호 이름, 상응하는 MeshOp
기호 참조 또는 인라인 MeshAttr
로 지정할 수 있습니다.
제약조건:
dim_shardings
의 요소는DimensionShardingAttr
에 나열된 제약 조건을 충족해야 합니다.replicated_axes
의 요소는AxisRefListAttr
에 나열된 제약 조건을 충족해야 합니다.- 상응하는 텐서 유형이
ShapedType
가 아닌 경우 샤딩의 순위는 0이어야 하며 복제된 축이 없어야 합니다. - 텐서에 등급이 있어야 합니다.
- 차원 샤딩 수는 텐서의 순위와 같습니다.
- 크기가 0인 측정기준은 샤딩되지 않습니다.
replicated_axes
의 항목은mesh_or_ref
를 기준으로 정렬됩니다 (AxisRefAttr::getMeshComparator
참고).
매개변수:
매개변수 | C++ 유형 | 설명 |
---|---|---|
mesh_or_ref | ::mlir::Attribute |
메시 속성 또는 평면 메시 기호 참조 속성 |
dim_shardings | ::llvm::ArrayRef<DimensionShardingAttr> |
측정기준 샤딩 |
replicated_axes | ::llvm::ArrayRef<AxisRefAttr> |
축 참조 |
TensorShardingPerValueAttr
연산자의 피연산자/결과당 텐서 샤딩
구문:
#sdy.sharding_per_value<
::llvm::ArrayRef<TensorShardingAttr> # shardings
>
연산의 피연산자/결과별로 하나씩 있는 TensorShardingAttr
목록입니다.
제약조건:
shardings
의 요소는TensorShardingAttr
의 제약 조건을 충족해야 합니다.
매개변수:
매개변수 | C++ 유형 | 설명 |
---|---|---|
shardings | ::llvm::ArrayRef<TensorShardingAttr> |
값별 샤딩 |
열거형
PropagationDirection
전파 방향 enum
케이스:
기호 | 값 | 문자열 |
---|---|---|
없음 | 0 |
없음 |
전달 | 1 |
전달 |
뒤로 | 2 |
뒤로 |
양측 | 3 |
양측 |