Este guia oferece uma seleção de flags principais do XLA para ajudar os usuários a navegar e usar os recursos do XLA de maneira eficaz. As seções a seguir detalham flags que podem afetar significativamente o desempenho do tempo de execução e a utilização da memória. Se ocorrerem problemas, como falhas, depois de ativar uma flag, recomendamos reverter para a configuração padrão e criar um problema no GitHub.
Flags de performance
As flags a seguir são importantes para melhorar o desempenho do ambiente de execução. Testar essas configurações pode gerar ganhos consideráveis de performance.
| Sinalização | Descrição | Valores padrão | Valores sugeridos | Valores de candidatos | 
|---|---|---|---|---|
| Pipelining 1. xla_should_allow_loop_variant_parameter_in_chain2. xla_should_add_loop_invariant_op_in_chain3. xla_tpu_enable_ici_ag_pipelining | Essas três flags precisam ser usadas juntas para ativar o pipelining coletivo de operações de agregação total da ICI(Interchip-Interconnect), o que cria mais oportunidades de execução sobreposta. | 1. xla_should_allow_loop_variant_parameter_in_chain=kDisabled2. xla_should_add_loop_invariant_op_in_chain=kDisabled3. xla_tpu_enable_ici_ag_pipelining=false | 1. xla_should_allow_loop_variant_parameter_in_chain=kEnabled2. xla_should_add_loop_invariant_op_in_chain=kEnabled3. xla_tpu_enable_ici_ag_pipelining=true | 1. xla_should_allow_loop_variant_parameter_in_chain=kDisabled/kEnabled/kAuto2. xla_should_add_loop_invariant_op_in_chain=kDisabled/kEnabled/kAuto3. xla_tpu_enable_ici_ag_pipelining=true/false | 
| v5e/Async xla_enable_async_all_gatherxla_tpu_enable_async_collective_fusionxla_tpu_enable_async_collective_fusion_fuse_all_gather | Essas três flags precisam ser usadas juntas para ativar operações assíncronas de all-gather no v5e. | xla_enable_async_all_gather=kAutoxla_tpu_enable_async_collective_fusion=truexla_tpu_enable_async_collective_fusion_fuse_all_gather=true | xla_enable_async_all_gather=kAutoxla_tpu_enable_async_collective_fusion=truexla_tpu_enable_async_collective_fusion_fuse_all_gather=true | xla_enable_async_all_gather=kDisabled/kEnabled/kAutoxla_tpu_enable_async_collective_fusion=true/falsexla_tpu_enable_async_collective_fusion_fuse_all_gather=true/false | 
| v5e/Async xla_tpu_enable_async_collective_fusionxla_tpu_enable_async_collective_fusion_fuse_all_reduce | Essas duas flags precisam ser usadas juntas para ativar operações assíncronas de redução total no v5e. | xla_tpu_enable_async_collective_fusion=truexla_tpu_enable_async_collective_fusion_fuse_all_reduce=false | xla_tpu_enable_async_collective_fusion=truexla_tpu_enable_async_collective_fusion_fuse_all_reduce=true | xla_tpu_enable_async_collective_fusion=true/falsexla_tpu_enable_async_collective_fusion_fuse_all_reduce=true/false | 
| Async xla_tpu_enable_async_all_to_all | Essa flag ativa a comunicação assíncrona de todos para todos. | xla_tpu_enable_async_all_to_all=false | xla_tpu_enable_async_all_to_all=true | xla_tpu_enable_async_all_to_all=true/false | 
| Limitado por latência xla_all_gather_latency_bound_threshold_in_bytes | Essa flag é destinada a operações de coleta total com limite de latência (ou seja, de tamanho pequeno). Ao ativar essa opção, você aciona otimizações específicas que podem reduzir o tempo de execução de all-gathers vinculados à latência. Normalmente, ele é usado em cargas de trabalho de inferência. | xla_all_gather_latency_bound_threshold_in_bytes=-1(que não está ativado) | 4~16Mb(i.e. 4~16 * 1024 * 1024) | [0, 9223372036854775807] | 
| Limitado por latência xla_all_reduce_latency_bound_threshold_in_bytes | Essa flag é destinada a operações de coleta total com limite de latência (ou seja, de tamanho pequeno). Ao ativar essa opção, você aciona otimizações específicas que podem reduzir o tempo de execução para all-reduces vinculados à latência. Normalmente, ele é usado em cargas de trabalho de inferência. | xla_all_reduce_latency_bound_threshold_in_bytes=-1(que não está ativado) | 4~16Mb(i.e. 4~16 * 1024 * 1024) | [0, 9223372036854775807] | 
| Limitado por latência xla_collective_permute_latency_bound_threshold_in_bytes | Essa flag é destinada a operações de coleta total com limite de latência (ou seja, de tamanho pequeno). Ativar essa opção aciona otimizações específicas que podem reduzir o tempo de execução para permutações coletivas vinculadas à latência. Normalmente, ele é usado em cargas de trabalho de inferência. | xla_collective_permute_latency_bound_threshold_in_bytes=-1(que não está ativado) | 4~16Mb(i.e. 4~16 * 1024 * 1024) | [0, 9223372036854775807] | 
| Limitado por latência xla_all_to_all_latency_bound_threshold_in_bytes | Essa flag é destinada a operações de coleta total com limite de latência (ou seja, de tamanho pequeno). Ativar essa opção aciona otimizações específicas que podem reduzir o tempo de execução para all-to-all vinculados à latência. Normalmente, ele é usado em cargas de trabalho de inferência. | xla_all_to_all_latency_bound_threshold_in_bytes=-1(que não está ativado) | 4~16Mb(i.e. 4~16 * 1024 * 1024) | [0, 9223372036854775807] | 
| xla_enable_async_collective_permute | Reescreve todas as operações de permutação coletiva para as variantes assíncronas.  Quando definido como auto, o XLA pode ativar a coleta assíncrona com base em outras configurações ou condições automaticamente. | xla_enable_async_collective_permute=kAuto | xla_enable_async_collective_permute=kAuto | xla_enable_async_collective_permute=kAuto/kEnabled/kDisabled | 
Flags de memória
As flags listadas abaixo são fornecidas para resolver problemas relacionados à HBM. Esses valores só devem ser ajustados se você encontrar erros de falta de memória de HBM durante a compilação do modelo. Em todos os outros cenários, os valores padrão são recomendados, já que alterá-los pode afetar negativamente o desempenho.
| Sinalização | Descrição | Valores padrão | Valores sugeridos | Valores de candidatos | 
|---|---|---|---|---|
| Programador xla_latency_hiding_scheduler_rerun | Essa configuração ajusta o comportamento do programador de ocultação de latência. Ele funciona reduzindo gradualmente o limite de memória alocado para o agendamento a cada "nova execução" do processo. | xla_latency_hiding_scheduler_rerun=1 | xla_latency_hiding_scheduler_rerun=5 | 0~10(it doesn’t make much sense beyond 10 reruns) | 
| Fusão xla_tpu_rwb_fusion | Essa flag ativa fusões do tipo redução+transmissão e pode diminuir o uso de memória. | xla_tpu_rwb_fusion=true | xla_tpu_rwb_fusion=false | xla_tpu_rwb_fusion=true/false | 
| Programador xla_memory_scheduler | Essa flag especifica o algoritmo que o programador de memória vai usar para minimizar o consumo de memória. Usar um algoritmo mais avançado pode gerar um cronograma que consome menos memória, mas com um tempo de compilação maior. | xla_memory_scheduler=kDefault | xla_memory_scheduler=kBrkga | xla_memory_scheduler=kDefault/kList/kDfs/kPostOrder/kBrkga | 
| Programador xla_tpu_enable_latency_hiding_scheduler | Essa flag ativa o planejador de ocultação de latência, que permite realizar operações coletivas assíncronas em vez de síncronas. Desativá-la reduz o uso da memória, mas perde os ganhos de desempenho dessas operações assíncronas. | xla_tpu_enable_latency_hiding_scheduler=true | xla_tpu_enable_latency_hiding_scheduler=false | xla_tpu_enable_latency_hiding_scheduler=true/false | 
| SPMD xla_jf_spmd_threshold_for_windowed_einsum_mib | Essa flag define o limite inferior do tamanho mínimo do ponto para acionar a multiplicação de matrizes coletiva. Definir um valor mais alto economizaria memória, mas perderia oportunidades de realizar a multiplicação de matrizes coletiva. | xla_jf_spmd_threshold_for_windowed_einsum_mib=-1 | 10Mb~1Gb (i.e. 10*1024*1024 ~ 1024*1024*1024) | [0, 9223372036854775807] | 
Outras flags usadas com frequência
| Sinalização | Tipo | Observações | 
|---|---|---|
| xla_dump_to | String (filepath) | A pasta em que os arquivos HLO pré-otimização e outros artefatos serão colocados. Consulte Ferramentas XLA. | 
Flags do TPU XLA
| Sinalização | Tipo | Observações | 
|---|---|---|
| xla_tpu_enable_data_parallel_all_reduce_opt | Booleano (verdadeiro/falso) | Otimização para aumentar as oportunidades de sobreposição de todas as reduções de DCN (rede de data center) usadas para fragmentação paralela de dados. | 
| xla_tpu_data_parallel_opt_different_sized_ops | Booleano (verdadeiro/falso) | Permite o encadeamento de operações paralelas de dados em várias iterações, mesmo que os tamanhos de saída não correspondam ao que pode ser salvo no lugar das variáveis empilhadas. Pode aumentar a pressão da memória. | 
| xla_tpu_spmd_rng_bit_generator_unsafe | Booleano (verdadeiro/falso) | Executar ou não o HLO RngBitGenerator de maneira particionada, o que é inseguro se resultados determinísticos forem esperados com diferentes fragmentações em diferentes partes da computação. | 
| xla_tpu_megacore_fusion_allow_ags | Booleano (verdadeiro/falso) | Permite a fusão de all-gathers com convoluções/all-reduces. | 
| xla_tpu_enable_ag_backward_pipelining | Booleano (verdadeiro/falso) | Os pipelines fazem a coleta de todos os dados (atualmente, coleta de todos os dados em grande escala) de trás para frente nos loops de verificação. | 
Flags XLA da GPU
| Sinalização | Tipo | Observações | 
|---|---|---|
| xla_gpu_enable_latency_hiding_scheduler | Booleano (verdadeiro/falso) | Essa flag permite que os programadores de ocultação de latência sobreponham a comunicação assíncrona com a computação de maneira eficiente. O valor padrão é False. | 
| xla_gpu_enable_triton_gemm | Booleano (verdadeiro/falso) | Use a multiplicação de matrizes baseada em Triton. | 
| xla_gpu_graph_level | Sinalização (0 a 3) | A flag legada para definir o nível do gráfico da GPU. Use xla_gpu_enable_command_buffer em novos casos de uso. 0 = desativado; 1 = captura fusões e memcpys; 2 = captura gemms; 3 = captura convoluções. | 
| xla_gpu_all_reduce_combine_threshold_bytes | Número inteiro (bytes) | Essas flags ajustam quando combinar vários AllGather / ReduceScatter / AllReduce pequenos em um grande AllGather / ReduceScatter / AllReduce para reduzir o tempo gasto na comunicação entre dispositivos. Por exemplo, para os limites de AllGather / ReduceScatter em uma carga de trabalho baseada em Transformer, ajuste-os para que combinem pelo menos um AllGather / ReduceScatter de peso de camada do Transformer. Por padrão, o combine_threshold_bytes é definido como 256. | 
| xla_gpu_all_gather_combine_threshold_bytes | Número inteiro (bytes) | Consulte xla_gpu_all_reduce_combine_threshold_bytes acima. | 
| xla_gpu_reduce_scatter_combine_threshold_bytes | Número inteiro (bytes) | Consulte xla_gpu_all_reduce_combine_threshold_bytes acima. | 
| xla_gpu_enable_pipelined_all_gather | Booleano (verdadeiro/falso) | Ativa o encadeamento de instruções all-gather. | 
| xla_gpu_enable_pipelined_reduce_scatter | Booleano (verdadeiro/falso) | Ativa o encadeamento de instruções de redução e dispersão. | 
| xla_gpu_enable_pipelined_all_reduce | Booleano (verdadeiro/falso) | Ativa o encadeamento de instruções all-reduce. | 
| xla_gpu_enable_while_loop_double_buffering | Booleano (verdadeiro/falso) | Ative o buffer duplo para o loop "while". | 
| xla_gpu_enable_all_gather_combine_by_dim | Booleano (verdadeiro/falso) | Combine operações all-gather com a mesma dimensão de coleta ou independente da dimensão. | 
| xla_gpu_enable_reduce_scatter_combine_by_dim | Booleano (verdadeiro/falso) | Combine operações de redução e dispersão com a mesma dimensão ou independente da dimensão. |