O XProf é uma ótima maneira de adquirir e visualizar traces e perfis de desempenho do seu programa, incluindo a atividade na GPU e na TPU. O resultado final vai ficar parecido com este:

Captura programática
Você pode instrumentar seu código para capturar um rastreamento do criador de perfis para código JAX usando os métodos
jax.profiler.start_trace
e jax.profiler.stop_trace. Chame
jax.profiler.start_trace
com o diretório em que os arquivos de rastreamento serão gravados. Ele precisa ser o mesmo diretório --logdir
usado para iniciar o XProf. Em seguida, use o XProf para ver os traces.
Por exemplo, para fazer um rastreamento do criador de perfil:
import jax
jax.profiler.start_trace("/tmp/profile-data")
# Run the operations to be profiled
key = jax.random.key(0)
x = jax.random.normal(key, (5000, 5000))
y = x @ x
y.block_until_ready()
jax.profiler.stop_trace()
Observe a chamada jax.block_until_ready. Usamos isso para garantir que a execução no dispositivo seja capturada pelo rastreamento. Consulte Envio assíncrono para detalhes sobre por que isso é necessário.
Você também pode usar o gerenciador de contexto jax.profiler.trace como alternativa a start_trace e stop_trace:
import jax
with jax.profiler.trace("/tmp/profile-data"):
key = jax.random.key(0)
x = jax.random.normal(key, (5000, 5000))
y = x @ x
y.block_until_ready()
Como visualizar o trace
Depois de capturar um rastreamento, é possível visualizá-lo usando a interface do XProf.
É possível iniciar a interface do criador de perfil diretamente usando o comando autônomo XProf, apontando para o diretório de registros:
$ xprof --port=8791 /tmp/profile-data
Attempting to start XProf server:
Log Directory: /tmp/profile-data
Port: 8791
Worker Service Address: 0.0.0.0:50051
Hide Capture Button: False
XProf at http://localhost:8791/ (Press CTRL+C to quit)
Navegue até o URL fornecido (por exemplo, http://localhost:8791/) no navegador para ver o perfil.
Os rastreamentos disponíveis aparecem no menu suspenso "Sessões" à esquerda. Selecione a sessão de seu interesse e, no menu suspenso "Ferramentas", escolha "Visualizador de rastreamento". Agora você vai ver uma linha do tempo da execução. Use as teclas WASD para navegar pelo rastreamento e clique ou arraste para selecionar eventos e ver mais detalhes. Consulte a documentação da ferramenta Trace Viewer para mais detalhes sobre como usar o visualizador de rastreamento.
Captura manual via XProf
Confira abaixo as instruções para capturar um rastreamento de N segundos acionado manualmente de um programa em execução.
Inicie um servidor XProf:
xprof --logdir /tmp/profile-data/Você poderá carregar o XProf em
<http://localhost:8791/>. Você pode especificar uma porta diferente com a flag--port.No programa ou processo Python que você quer criar um perfil, adicione o seguinte em algum lugar perto do início:
import jax.profiler jax.profiler.start_server(9999)Isso inicia o servidor do criador de perfil a que o XProf se conecta. O servidor do criador de perfis precisa estar em execução antes de você passar para a próxima etapa. Quando terminar de usar o servidor, chame
jax.profiler.stop_server()para desligá-lo.Se você quiser criar um perfil de um trecho de um programa de longa duração (por exemplo, um loop de treinamento longo), coloque isso no início do programa e inicie o programa normalmente. Se você quiser criar um perfil de um programa curto (por exemplo, um microbenchmark), uma opção é iniciar o servidor do profiler em um shell do IPython e executar o programa curto com
%rundepois de iniciar a captura na próxima etapa. Outra opção é iniciar o servidor do criador de perfil no início do programa e usartime.sleep()para ter tempo suficiente para iniciar a captura.Abra
<http://localhost:8791/>e clique no botão "CAPTURE PROFILE" (CAPTURA DE PERFIL) no canto superior esquerdo. Insira "localhost:9999" como o URL do serviço de perfil. Esse é o endereço do servidor do criador de perfil que você iniciou na etapa anterior. Digite o número de milissegundos que você quer criar um perfil e clique em "CAPTURE".Se o código que você quer criar um perfil ainda não estiver em execução (por exemplo, se você iniciou o servidor do criador de perfil em um shell do Python), execute-o enquanto a captura estiver em andamento.
Depois que a captura terminar, o XProf será atualizado automaticamente. Nem todos os recursos de criação de perfil do XProf estão conectados ao JAX. Por isso, pode parecer que nada foi capturado. À esquerda, em "Ferramentas", selecione "Visualizador de rastreamento".
Agora você vai ver uma linha do tempo da execução. Use as teclas WASD para navegar pelo rastreamento e clique ou arraste para selecionar eventos e ver mais detalhes na parte de baixo. Consulte a documentação da ferramenta Trace Viewer para mais detalhes sobre como usar o visualizador de rastreamento.
XProf e TensorBoard
O XProf é a ferramenta que alimenta a funcionalidade de criação de perfil e captura de rastreamento no TensorBoard. Desde que o xprof esteja instalado, uma guia "Perfil" vai aparecer no TensorBoard. Usar isso é idêntico a iniciar o XProf
de forma independente, desde que ele seja iniciado apontando para o mesmo diretório de registros.
Isso inclui a funcionalidade de captura, análise e visualização de perfis. O XProf substitui a funcionalidade tensorboard_plugin_profile que era recomendada anteriormente.
$ tensorboard --logdir=/tmp/profile-data
[...]
Serving TensorBoard on localhost; to expose to the network, use a proxy or pass --bind_all
TensorBoard 2.19.0 at http://localhost:6006/ (Press CTRL+C to quit)
Adicionar eventos de rastreamento personalizados
Por padrão, os eventos no visualizador de rastreamento são principalmente funções JAX internas de baixo nível. É possível adicionar seus próprios eventos e funções usando
jax.profiler.TraceAnnotation
e jax.profiler.annotate_function
no seu código.
Configurar opções do criador de perfil
O método start_trace aceita um parâmetro profiler_options opcional, que permite um controle detalhado sobre o comportamento do criador de perfis. Esse parâmetro precisa ser uma instância de jax.profiler.ProfileOptions.
Por exemplo, para desativar todos os rastreamentos de python e host:
import jax
options = jax.profiler.ProfileOptions()
options.python_tracer_level = 0
options.host_tracer_level = 0
jax.profiler.start_trace("/tmp/profile-data", profiler_options=options)
# Run the operations to be profiled
key = jax.random.key(0)
x = jax.random.normal(key, (5000, 5000))
y = x @ x
y.block_until_ready()
jax.profiler.stop_trace()
Opções gerais
host_tracer_level: define o nível de rastreamento para atividades do lado do host.Valores aceitos:
0: desativa completamente o rastreamento do host (CPU).1: ativa o rastreamento apenas de eventos TraceMe instrumentados pelo usuário.2: inclui rastreamentos de nível 1 e detalhes de execução de programas de alto nível, como operações XLA caras (padrão).3: inclui rastreamentos de nível 2 e detalhes mais verbosos e de baixo nível da execução do programa, como operações XLA baratas.
device_tracer_level: controla se o rastreamento de dispositivos está ativado.Valores aceitos:
0: desativa o rastreamento de dispositivos.1: ativa o rastreamento de dispositivos (padrão).
python_tracer_level: controla se o rastreamento do Python está ativado.Valores aceitos:
0: desativa o rastreamento de chamadas de função Python (padrão).1: ativa o rastreamento do Python.
Opções de configuração avançada
Opções de TPU
tpu_trace_mode: especifica o modo para rastreamento de TPU.Valores aceitos:
TRACE_ONLY_HOST: isso significa que apenas as atividades do lado do host (CPU) são rastreadas, e nenhum rastreamento de dispositivo (TPU/GPU) é coletado.TRACE_ONLY_XLA: isso significa que apenas as operações no nível do XLA no dispositivo são rastreadas.TRACE_COMPUTE: rastreia operações de computação no dispositivo.TRACE_COMPUTE_AND_SYNC: rastreia operações de computação e eventos de sincronização no dispositivo.
Se "tpu_trace_mode" não for fornecido, o padrão de trace_mode será
TRACE_ONLY_XLA.tpu_num_sparse_cores_to_trace: especifica o número de núcleos esparsos a serem rastreados na TPU.tpu_num_sparse_core_tiles_to_trace: especifica o número de blocos em cada núcleo esparso a ser rastreado na TPU.tpu_num_chips_to_profile_per_task: especifica o número de chips de TPU a serem criados por perfil por tarefa.
Opções de GPU
As seguintes opções estão disponíveis para o perfil da GPU:
gpu_max_callback_api_events: define o número máximo de eventos coletados pela API de callback do CUPTI. O padrão é2*1024*1024.gpu_max_activity_api_events: define o número máximo de eventos coletados pela API de atividade CUPTI. O padrão é2*1024*1024.gpu_max_annotation_strings: define o número máximo de strings de anotação que podem ser coletadas. O padrão é1024*1024.gpu_enable_nvtx_tracking: ativa o rastreamento de NVTX no CUPTI. O padrão éFalse.gpu_enable_cupti_activity_graph_trace: ativa o rastreamento do gráfico de atividade do CUPTI para gráficos CUDA. O padrão éFalse.gpu_pm_sample_counters: uma string separada por vírgulas de métricas de monitoramento de desempenho da GPU a serem coletadas usando o recurso de amostragem PM do CUPTI (por exemplo,"sm__cycles_active.avg.pct_of_peak_sustained_elapsed"). A amostragem PM é desativada por padrão. Para conferir as métricas disponíveis, consulte a documentação da NVIDIA CUPTI (em inglês).gpu_pm_sample_interval_us: define o intervalo de amostragem em microssegundos para a amostragem de PM do CUPTI. O padrão é500.gpu_pm_sample_buffer_size_per_gpu_mb: define o tamanho do buffer de memória do sistema por dispositivo em MB para amostragem de PM do CUPTI. O padrão é 64 MB. O valor máximo aceito é de 4 GB.gpu_num_chips_to_profile_per_task: especifica o número de dispositivos de GPU a serem criados por perfil por tarefa. Se não for especificado, definido como 0 ou definido como um valor inválido, todas as GPUs disponíveis serão analisadas. Isso pode ser usado para diminuir o tamanho da coleta de rastreamentos.gpu_dump_graph_node_mapping: se ativado, despeja informações de mapeamento de nós do gráfico CUDA no rastreamento. O padrão éFalse.
Exemplo:
options = ProfileOptions()
options.advanced_configuration = {"tpu_trace_mode" : "TRACE_ONLY_HOST", "tpu_num_sparse_cores_to_trace" : 2}
Retorna InvalidArgumentError se forem encontradas chaves ou valores de opção não reconhecidos.