XLA 플래그 안내

이 가이드에서는 사용자가 XLA의 기능을 효과적으로 탐색하고 활용할 수 있도록 엄선된 주요 XLA 플래그를 제공합니다. 다음 섹션에서는 런타임 성능과 메모리 사용량에 큰 영향을 미칠 수 있는 플래그를 자세히 설명합니다. 플래그를 사용 설정한 후 비정상 종료와 같은 문제가 발생하면 기본 설정으로 되돌리고 GitHub 문제를 만드는 것이 좋습니다.

성능 플래그

다음 플래그는 런타임 성능을 향상하는 데 도움이 됩니다. 이러한 설정을 실험하면 상당한 성능 향상을 얻을 수 있습니다.

플래그 설명 기본값 추천 값 후보 값
파이프라인
1. xla_should_allow_loop_variant_parameter_in_chain
2. xla_should_add_loop_invariant_op_in_chain
3. xla_tpu_enable_ici_ag_pipelining
이 3가지 플래그는 ICI(Interchip-Interconnect) all-gather 작업의 집단 파이프라인을 사용 설정하는 데 함께 사용해야 하며, 이를 통해 실행이 중복될 가능성이 높아집니다. 1. xla_should_allow_loop_variant_parameter_in_chain=kDisabled
2. xla_should_add_loop_invariant_op_in_chain=kDisabled
3. xla_tpu_enable_ici_ag_pipelining=false
1. xla_should_allow_loop_variant_parameter_in_chain=kEnabled
2. xla_should_add_loop_invariant_op_in_chain=kEnabled
3. xla_tpu_enable_ici_ag_pipelining=true
1. xla_should_allow_loop_variant_parameter_in_chain=kDisabled/kEnabled/kAuto
2. xla_should_add_loop_invariant_op_in_chain=kDisabled/kEnabled/kAuto
3. xla_tpu_enable_ici_ag_pipelining=true/false
v5e/Async
xla_enable_async_all_gather
xla_tpu_enable_async_collective_fusion
xla_tpu_enable_async_collective_fusion_fuse_all_gather
v5e에서 비동기 all-gather 작업을 활성화하려면 이 3가지 플래그를 함께 사용해야 합니다. xla_enable_async_all_gather=kAuto
xla_tpu_enable_async_collective_fusion=true
xla_tpu_enable_async_collective_fusion_fuse_all_gather=true
xla_enable_async_all_gather=kAuto
xla_tpu_enable_async_collective_fusion=true
xla_tpu_enable_async_collective_fusion_fuse_all_gather=true
xla_enable_async_all_gather=kDisabled/kEnabled/kAuto
xla_tpu_enable_async_collective_fusion=true/false
xla_tpu_enable_async_collective_fusion_fuse_all_gather=true/false
v5e/Async
xla_tpu_enable_async_collective_fusion
xla_tpu_enable_async_collective_fusion_fuse_all_reduce
이 두 플래그는 v5e에서 비동기 all-reduce 작업을 활성화하기 위해 함께 사용해야 합니다. xla_tpu_enable_async_collective_fusion=true
xla_tpu_enable_async_collective_fusion_fuse_all_reduce=false
xla_tpu_enable_async_collective_fusion=true
xla_tpu_enable_async_collective_fusion_fuse_all_reduce=true
xla_tpu_enable_async_collective_fusion=true/false
xla_tpu_enable_async_collective_fusion_fuse_all_reduce=true/false
Async
xla_tpu_enable_async_all_to_all
이 플래그는 비동기 all-to-all 통신을 사용 설정합니다. xla_tpu_enable_async_all_to_all=false xla_tpu_enable_async_all_to_all=true xla_tpu_enable_async_all_to_all=true/false
지연 시간 제한
xla_all_gather_latency_bound_threshold_in_bytes
이 플래그는 지연 시간 제한 (즉, 소규모) all-gather 작업을 위한 것입니다. 이를 사용 설정하면 지연 시간 제한 all-gather의 실행 시간을 줄일 수 있는 특정 최적화가 트리거됩니다. 일반적으로 추론 워크로드에 사용됩니다. xla_all_gather_latency_bound_threshold_in_bytes=-1
(사용 설정되지 않음)
4~16Mb(i.e. 4~16 * 1024 * 1024) [0, 9223372036854775807]
지연 시간 제한
xla_all_reduce_latency_bound_threshold_in_bytes
이 플래그는 지연 시간 제한 (즉, 소규모) all-gather 작업을 위한 것입니다. 이를 사용 설정하면 지연 시간 제한 all-reduce의 실행 시간을 줄일 수 있는 특정 최적화가 트리거됩니다. 일반적으로 추론 워크로드에 사용됩니다. xla_all_reduce_latency_bound_threshold_in_bytes=-1
(사용 설정되지 않음)
4~16Mb(i.e. 4~16 * 1024 * 1024) [0, 9223372036854775807]
지연 시간 제한
xla_collective_permute_latency_bound_threshold_in_bytes
이 플래그는 지연 시간 제한 (즉, 소규모) all-gather 작업을 위한 것입니다. 이를 사용 설정하면 지연 시간 제한 집단 순열의 실행 시간을 줄일 수 있는 특정 최적화가 트리거됩니다. 일반적으로 추론 워크로드에 사용됩니다. xla_collective_permute_latency_bound_threshold_in_bytes=-1
(사용 설정되지 않음)
4~16Mb(i.e. 4~16 * 1024 * 1024) [0, 9223372036854775807]
지연 시간 제한
xla_all_to_all_latency_bound_threshold_in_bytes
이 플래그는 지연 시간 제한 (즉, 소규모) all-gather 작업을 위한 것입니다. 이를 사용 설정하면 지연 시간 제한 all-to-all의 실행 시간을 줄일 수 있는 특정 최적화가 트리거됩니다. 일반적으로 추론 워크로드에 사용됩니다. xla_all_to_all_latency_bound_threshold_in_bytes=-1
(사용 설정되지 않음)
4~16Mb(i.e. 4~16 * 1024 * 1024) [0, 9223372036854775807]
xla_enable_async_collective_permute 모든 집단 순열 작업을 비동기 변형으로 다시 작성합니다. auto로 설정하면 XLA가 다른 구성이나 조건에 따라 비동기 집합을 자동으로 사용 설정할 수 있습니다. xla_enable_async_collective_permute=kAuto xla_enable_async_collective_permute=kAuto xla_enable_async_collective_permute=kAuto/kEnabled/kDisabled

메모리 플래그

아래에 나열된 플래그는 HBM 관련 문제를 해결하기 위해 제공됩니다. 모델 컴파일 중에 HBM '메모리 부족' 오류가 발생하는 경우에만 조정해야 합니다. 다른 모든 시나리오에서는 기본값을 사용하는 것이 좋습니다. 기본값을 변경하면 실적에 부정적인 영향을 미칠 수 있기 때문입니다.

플래그 설명 기본값 추천 값 후보 값
스케줄러
xla_latency_hiding_scheduler_rerun
이 설정은 지연 시간 숨기기 스케줄러의 동작을 조정합니다. 이 기능은 프로세스를 '다시 실행'할 때마다 예약에 할당된 메모리 한도를 점진적으로 줄여 작동합니다. xla_latency_hiding_scheduler_rerun=1 xla_latency_hiding_scheduler_rerun=5 0~10(it doesn’t make much sense beyond 10 reruns)
Fusion
xla_tpu_rwb_fusion
이 플래그는 reduce+broadcast 유형의 융합을 사용 설정하며 메모리 사용량을 줄일 수 있습니다. xla_tpu_rwb_fusion=true xla_tpu_rwb_fusion=false xla_tpu_rwb_fusion=true/false
스케줄러
xla_memory_scheduler
이 플래그는 메모리 스케줄러가 메모리 소비를 최소화하는 데 사용할 알고리즘을 지정합니다. 더 고급 알고리즘을 사용하면 컴파일 시간이 길어지는 대신 메모리 소비가 적은 일정을 얻을 수 있습니다. xla_memory_scheduler=kDefault xla_memory_scheduler=kBrkga xla_memory_scheduler=kDefault/kList/kDfs/kPostOrder/kBrkga
스케줄러
xla_tpu_enable_latency_hiding_scheduler
이 플래그는 동기식 대신 비동기식을 실행할 수 있는 지연 시간 숨김 스케줄러를 사용 설정합니다. 사용 중지하면 이러한 비동기 작업의 성능 향상을 잃는 대신 메모리 사용량이 줄어듭니다. xla_tpu_enable_latency_hiding_scheduler=true xla_tpu_enable_latency_hiding_scheduler=false xla_tpu_enable_latency_hiding_scheduler=true/false
SPMD
xla_jf_spmd_threshold_for_windowed_einsum_mib
이 플래그는 집단 matmul을 트리거하는 점의 최소 크기의 하한을 설정합니다. 더 높은 값으로 설정하면 집단 matmul을 실행할 기회를 놓치는 대신 메모리를 절약할 수 있습니다. xla_jf_spmd_threshold_for_windowed_einsum_mib=-1 10Mb~1Gb (i.e. 10*1024*1024 ~ 1024*1024*1024) [0, 9223372036854775807]

기타 일반적으로 사용되는 플래그

플래그 유형 참고
xla_dump_to 문자열 (파일 경로) 사전 최적화 HLO 파일 및 기타 아티팩트가 배치될 폴더입니다 (XLA 도구 참고).

TPU XLA 플래그

플래그 유형 참고
xla_tpu_enable_data_parallel_all_reduce_opt 불리언 (true/false) 데이터 병렬 샤딩에 사용되는 DCN (데이터 센터 네트워킹) all-reduce의 중복 기회를 늘리기 위한 최적화
xla_tpu_data_parallel_opt_different_sized_ops 불리언 (true/false) 출력 크기가 스택 변수에 저장할 수 있는 크기와 일치하지 않더라도 여러 반복에 걸쳐 데이터 병렬 작업의 파이프라인을 사용 설정합니다. 메모리 부족을 늘릴 수 있습니다.
xla_tpu_spmd_rng_bit_generator_unsafe 불리언 (true/false) 계산의 여러 부분에서 서로 다른 샤딩으로 결정적 결과를 예상하는 경우 안전하지 않은 방식으로 RngBitGenerator HLO를 파티셔닝된 방식으로 실행할지 여부입니다.
xla_tpu_megacore_fusion_allow_ags 불리언 (true/false) 모든 수집을 컨볼루션/모든 축소와 융합할 수 있습니다.
xla_tpu_enable_ag_backward_pipelining 불리언 (true/false) 파이프라인은 스캔 루프를 통해 뒤로 모두 수집합니다 (현재 메가스케일 모두 수집).

GPU XLA 플래그

플래그 유형 참고
xla_gpu_enable_latency_hiding_scheduler 불리언 (true/false) 이 플래그를 사용하면 지연 시간 숨기기 스케줄러가 비동기 통신과 계산을 효율적으로 중복할 수 있습니다. 기본값은 False입니다.
xla_gpu_enable_triton_gemm 불리언 (true/false) Triton 기반 행렬 곱셈을 사용합니다.
xla_gpu_graph_level 플래그 (0~3) GPU 그래프 수준을 설정하는 기존 플래그입니다. 새 사용 사례에서 xla_gpu_enable_command_buffer 사용 0 = 사용 안함, 1 = 융합 및 memcpys 캡처, 2 = gemm 캡처, 3 = 컨볼루션 캡처
xla_gpu_all_reduce_combine_threshold_bytes 정수 (바이트) 이러한 플래그는 기기 간 통신에 소요되는 시간을 줄이기 위해 여러 개의 작은 AllGather / ReduceScatter / AllReduce를 하나의 큰 AllGather / ReduceScatter / AllReduce로 결합하는 시점을 조정합니다. 예를 들어 트랜스포머 기반 워크로드의 AllGather / ReduceScatter 임계값의 경우 최소한 트랜스포머 레이어의 가중치 AllGather / ReduceScatter를 결합할 수 있을 만큼 높게 조정하는 것이 좋습니다. 기본적으로 combine_threshold_bytes는 256으로 설정됩니다.
xla_gpu_all_gather_combine_threshold_bytes 정수 (바이트) 위의 xla_gpu_all_reduce_combine_threshold_bytes를 참고하세요.
xla_gpu_reduce_scatter_combine_threshold_bytes 정수 (바이트) 위의 xla_gpu_all_reduce_combine_threshold_bytes를 참고하세요.
xla_gpu_enable_pipelined_all_gather 불리언 (true/false) all-gather 명령어의 파이프라인을 사용 설정합니다.
xla_gpu_enable_pipelined_reduce_scatter 불리언 (true/false) reduce-scatter 명령어의 파이프라인 지원
xla_gpu_enable_pipelined_all_reduce 불리언 (true/false) all-reduce 명령어의 파이프라인 처리를 사용 설정합니다.
xla_gpu_enable_while_loop_double_buffering 불리언 (true/false) while 루프의 더블 버퍼링 사용 설정
xla_gpu_enable_all_gather_combine_by_dim 불리언 (true/false) 동일한 수집 측정기준을 사용하거나 측정기준과 관계없이 모든 수집 작업을 결합합니다.
xla_gpu_enable_reduce_scatter_combine_by_dim 불리언 (true/false) 동일한 측정기준을 사용하거나 측정기준과 관계없이 reduce-scatter 작업을 결합합니다.