Un volcado de HLO es una representación textual de los módulos de HLO en diferentes etapas del cálculo. Es útil para la depuración y, a menudo, debes incluirlo en los informes de errores. Por lo general, se trata de un archivo de texto legible que enumera las instrucciones del HLO y sus propiedades. A veces, los módulos de HLO se vuelcan de la siguiente manera:
- HloProto: Son archivos de búfer de protocolo, que tienen un formato más estructurado y legible por máquina.
- HloSnapshot: Módulo de HLO más sus entradas. Para reproducir HLO, a veces necesitas las entradas reales que se proporcionan a un cálculo determinado en lugar de datos aleatorios.
Puedes usar marcas de XLA para especificar y obtener volcados. En la mayoría de los casos, puedes establecerlo con una variable de entorno. JAX también ofrece una forma programática de imprimir el volcado de HLO.
Ejecución local
Usa variables de entorno
Puedes establecer la variable de entorno XLA_FLAGS con las marcas necesarias para obtener volcados. Esto funciona para JAX, TensorFlow y PyTorch/XLA.
Para volcar módulos de HLO y otra información de depuración en un directorio específico, ejecuta tu programa con la marca --xla_dump_to:
XLA_FLAGS="--xla_dump_to=DIRECTORY_PATH"
Por ejemplo, puedes usar /tmp o /tmp/xladump como rutas de acceso.
De forma predeterminada, esto vuelca los módulos de HLO como texto, al principio y al final de la canalización de optimización.
También puedes especificar el formato de forma explícita:
- Volcados de texto
XLA_FLAGS="--xla_dump_hlo_as_text --xla_dump_to=DIRECTORY_PATH"
- Protos de HLO
XLA_FLAGS="--xla_dump_hlo_as_proto --xla_dump_to=DIRECTORY_PATH"
- Instantáneas del HLO
XLA_FLAGS="--xla_dump_hlo_snapshots --xla_dump_to=DIRECTORY_PATH"
- Renderización de gráficos con el servidor de Graphviz (solo funciona bien para gráficos pequeños)
XLA_FLAGS="--xla_dump_hlo_as_url --xla_dump_to=DIRECTORY_PATH"
- Renderización de grafos en archivos HTML (solo funciona bien para grafos pequeños)
XLA_FLAGS="--xla_dump_hlo_as_html --xla_dump_to=DIRECTORY_PATH"
En el caso de los gráficos más grandes, puedes usar interactive_graphviz para visualizar partes del gráfico.
Cómo volcar pases intermedios específicos
Además de los HLOs estándar preoptimizados o completamente optimizados, también puedes volcar el estado de los HLOs después de un pase de compilador en particular.
XLA_FLAGS="--xla_dump_hlo_pass_re=regex --xla_dump_to=DIRECTORY_PATH"
Se volcarán los módulos de HLO para los pases cuyos nombres coincidan con la expresión regular (regex). Por ejemplo, puedes observar los HLO que resultan de los pases relacionados con la partición SPMD con el siguiente comando:
XLA_FLAGS="--xla_dump_to=DIRECTORY_PATH --xla_dump_hlo_pass_re=spmd|propagation"
Para volcar el resultado después de cada paso de XLA (esto generará muchos archivos), puedes configurar lo siguiente:
XLA_FLAGS="--xla_dump_to=DIRECTORY_PATH --xla_dump_hlo_pass_re=.*"
Opciones específicas de JAX
De forma programática en JAX
En lugar de pasar marcas o variables de entorno, también puedes volcar HLO de forma programática con las APIs lower y compile de JAX.
Para recuperar de forma local el HLO original sin optimizar y reducido, haz lo siguiente:
jax.jit(f).lower(*args).as_text('hlo')
Para volcar a archivos durante los pases de compilación de HLO, especifica lo siguiente:
compilation_args = {
'xla_dump_to': DIRECTORY_PATH,
'xla_dump_hlo_pass_re': 'spmd|propagation', # or some other pass filter
...
}
jax.jit(f).lower(*args).compile(compilation_args)
Cómo volcar Jaxprs
Los jaxprs son la representación intermedia de JAX para los registros de ejecución de programas. Para volcar esta información, configura las variables de entorno:
JAX_DUMP_IR_TO="DIRECTORY_PATH" JAX_DUMP_IR_MODES=jaxpr
Obtén más información en la documentación de JAX sobre Exporting and serializing staged-out computations: Debugging (Cómo exportar y serializar cálculos por etapas: Depuración).
Google Colab
Variables de entorno
En la primera celda ejecutada de tu notebook (porque las variables de entorno y las marcas de línea de comandos suelen procesarse solo una vez, p.ej., en el momento de la importación del módulo o de la inicialización del backend de XLA), agrega el XLA_FLAGS que se describió anteriormente con os.environ, por ejemplo:
import os
os.environ['XLA_FLAGS'] = "--xla_dump_to=DIRECTORY_PATH"
Esto volcará el cálculo en DIRECTORY_PATH, por ejemplo, /tmp. En Colab, navega al explorador de "Archivos" en la barra lateral izquierda para ver y acceder a este directorio.
Puedes usar todas las marcas mencionadas en la sección Ejecución local.
Opciones específicas de JAX
De manera similar a la ejecución local, para la introspección interactiva en vivo, puedes imprimir directamente el HLO preoptimizado de un cálculo:
def f(x):
return jax.numpy.sin(jax.numpy.cos(x))
c = jax.jit(f).lower(3.).compiler_ir('hlo')
print(c.as_hlo_text())
También puedes imprimir directamente el HLO optimizado de un cálculo:
def optimized_HLO(f, *args, platform=None):
print(jax.jit(f).lower(*args).compile().as_text())
def f(x):
return jax.numpy.sin(jax.numpy.cos(x))
optimized_HLO(f, 1.0)
Volcado de todos los cómputos o de una pequeña cantidad de ellos
Si deseas ver todo en un volcado, incluidas todas las compilaciones pequeñas, establece la variable de entorno de JAX de la siguiente manera:
JAX_COMPILER_DETAILED_LOGGING_MIN_OPS=0
Mosaico
Mosaic es un compilador para el backend de Pallas TPU y el backend experimental de Pallas GPU. Para volcar el cálculo del mosaico, establece la siguiente marca:
--xla_mosaic_dump_to=/tmp/mosaic_dumps
O bien, establece los argumentos de inicialización de la TPU como una variable de entorno:
export LIBTPU_INIT_ARGS="--xla_mosaic_dump_to=/tmp/mosaic_dumps"
Consulta la documentación de JAX sobre Pallas y Mosaic para obtener más información.
Más con HLO Dumps
Cómo encontrar el cálculo correcto
Por lo general, se descartan muchos cálculos. Los archivos volcados se nombran explícitamente con el "nombre de la computación" de JAX, TensorFlow o PyTorch/XLA que se menciona en los registros, lo que facilita la identificación de los archivos HLO pertinentes. Por ejemplo:
1624325116260738.module_0065.pmap__unnamed_wrapped_function_.186875.before_optimizations.txt
De lo contrario, puedes usar ripgrep para identificar rápidamente qué módulo contiene símbolos o cálculos particulares.
Nota: Incluye los 3 archivos volcados antes/después/asignación de búferes de interés en tus informes de errores.
Conversión de HLO
Una herramienta llamada hlo-opt que puede traducir entre formatos de HLOProto y texto.
Es útil en los casos en los que tienes un formato, pero necesitas el otro para la depuración.
Aprende a usarlo: Documentación de XLA Tooling: hlo-opt.
Volver a reproducir
Puedes ejecutar (reproducir) los cálculos volcados en un backend de XLA especificado con datos falsos o instantáneas de entrada. Esta es una forma conveniente de reproducir, iterar y depurar problemas en XLA.
Los siguientes comandos usan datos falsos. Si guardaste instantáneas de HLO, puedes pasarlas en su lugar y se usarán los datos de la instantánea. Para seguir usando datos falsos mientras ejecutas la instantánea, pasa la marca --force_fake_data.
Backend de CPU:
bazel run -c opt //xla/hlo/tools:run_hlo_module -- --platform=cpu
/tmp/xladump/module_4561.before_optimizations.txt
Backend de GPU:
bazel run -c opt //xla/hlo/tools:run_hlo_module -- --platform=CUDA
/tmp/xladump/module_4561.before_optimizations.txt