Modelo de costos del LHS


Resumen:

En esta página, se describen los detalles internos del modelo de costos que utiliza el programador de ocultamiento de latencia. Si te interesa ajustar el modelo, ve directamente a la sección de ajuste.

El Latency Hiding Scheduler (LHS) es un paso del compilador que programa un DAG de HLO de manera que se minimice el tiempo real.

Sus decisiones se basan en el modelo de costos unificado, que utiliza una combinación de tablas de rendimiento y modelos analíticos. En particular, XLA incorpora tablas de rendimiento para GEMM y operaciones colectivas de interconexión rápida, y usa un modelo analítico de costos de redes y fusión para otros casos. En el resto del documento, se describe el funcionamiento interno de estos elementos a un nivel general.


Tablas de rendimiento: ICI colectivas

La tabla de rendimiento consta de dos componentes principales: un recopilador y un interpolador.

Recopilador

El colector es una herramienta de C++ responsable de generar las tablas de rendimiento para las operaciones colectivas. Mide el rendimiento de las operaciones de HLO individuales (p.ej., all-gather, all-reduce) en un espacio de parámetros definido de forma estática.

Cómo funciona

La herramienta realiza un análisis en un rango de operaciones colectivas, tamaños de transferencia y esquemas de transferencia para un clúster determinado. Utiliza la infraestructura existente del ejecutor de HLO de varios hosts y los datos de ExecutionProfile para ejecutar el HLO generado y recopilar métricas de rendimiento.

Parámetros de recopilación de datos

Las tablas de latencia se recopilan para un producto cruzado de los siguientes parámetros:

  • Tipo de colectivo:
    • all-reduce
    • all-gather
    • reduce-scatter
  • Tamaño de transferencia:
    • Escala logarítmica de 1,024 B hasta 2 GiB (p.ej., 1024 B, 2048 B, 4096 B, …)
  • Esquema de transferencia:
    • rail-aligned
    • non-rail-aligned

Este análisis se ejecuta para clústeres dentro del nodo con 2, 4 y 8 dispositivos.

Salida

El resultado de una ejecución de recopilación es una tabla de latencia en formato .pbtxt (aproximadamente 116 KB por plataforma).

Interpolator

El interpolador es el componente del compilador que consume las tablas de rendimiento generadas para proporcionar estimaciones del tiempo de ejecución durante la compilación.

Estructura de datos interna

Durante la inicialización, el interpolador procesa la tabla de rendimiento en un mapa. Este mapa usa una tupla de (collective_type, transfer_scheme) como su clave.

El valor asociado con cada clave es un plano euclidiano 2D. Este plano indexa el rendimiento de la red (medido por el recopilador) en función de dos ejes:

  1. Tamaño de la transferencia.
  2. Cantidad de dispositivos involucrados.

Búsqueda e interpolación

Cuando el compilador encuentra una operación colectiva, el interpolador realiza los siguientes pasos:

  1. Identifica el plano de capacidad de procesamiento 2D correcto con el (collective_type, transfer_scheme) de la operación como clave del mapa.
  2. Luego, usa una recuperación de promedio ponderado (basada en la distancia euclidiana) dentro de ese plano 2D, con el (transfer_size, num_devices) de la operación como punto de consulta.
  3. El resultado de esta búsqueda es un solo valor único de capacidad de procesamiento de la red.

Justificación: Capacidad de procesamiento y extrapolación

El sistema está diseñado para almacenar el rendimiento de la red en lugar de la latencia sin procesar. Esta elección de diseño simplifica significativamente la extrapolación del rendimiento para los tamaños de transferencia que no están presentes de forma explícita en la tabla.

Si las tablas de latencia capturan la saturación del ancho de banda de la red en un tamaño colectivo S, la capacidad de procesamiento T en ese punto se considera la máxima. Para cualquier colectivo nuevo de tamaño S' > S, el tiempo de ejecución se puede estimar de la siguiente manera:

\[\text{EstimatedTime}(S') = \frac{S'}{T_{\text{saturated} } }\]

Esto permite que el modelo estime el rendimiento para colectivos de cualquier tamaño, incluso aquellos más grandes que el máximo de 2 GiB medido por el Collector.

  • Subestime la capacidad de procesamiento máxima.
  • Por lo tanto, sobreestiman el tiempo de ejecución para las transferencias grandes.

En general, los equipos de XLA:GPU mantienen tablas de rendimiento, pero, en los casos en que el usuario decide proporcionar las suyas propias, es responsabilidad del usuario que genera las tablas asegurarse de que sean representativas y de que incluyan mediciones en la región saturada de ancho de banda para el hardware objetivo.


Tablas de rendimiento: GEMM

Al igual que el sistema de colectivos, las tablas de latencia de GEMM se admiten con dos componentes: un recopilador y un interpolador.

Recopilador

El colector es una herramienta de C++ que calcula tablas de rendimiento para las multiplicaciones de matrices generales (GEMM). Mide el rendimiento de las multiplicaciones de matrices a nivel de la operación dot del HLO.

Cómo funciona

La herramienta realiza un análisis en un espacio estático de dimensiones de GEMM (lote, dos dimensiones no contractivas y una contractiva) y tipos de datos.

  • Tipos de datos predeterminados: LHS = bf16,f32, RHS = bf16,f32, OUT = bf16,f32.
  • Infraestructura: Vuelve a usar el generador de perfiles de operaciones de HLO.

Parámetros de la colección

Las tablas de latencia se recopilan para un producto cruzado de las siguientes dimensiones:

  • batch: {1, 2, 4}
  • m (sin contrato): {256, 512, ..., 4096}
  • n (sin contrato): {256, 512, ..., 4096}
  • k (contracción): {256, 512, ..., 4096}

Salida y almacenamiento

Un análisis completo genera una tabla de latencia .pbtxt, lista para que la consuma el interpolador.

Interpolator

El interpolador es el componente del compilador que usa las tablas generadas para estimar el rendimiento de GEMM.

Justificación: Saturación de FLOPS

Las tablas de latencia recopiladas permiten que el interpolador reconstruya los FLOPS para cada entrada:

\[\text{FLOPS} = \frac{2 \times b \times m \times n \times k}{\text{runtime} }\]

Una observación clave es que los FLOPS se saturan en un punto determinado, es decir, el hardware alcanza el máximo de FLOPS más allá de una determinada forma de matriz. Esta saturación permite usar el mismo método de extrapolación que se emplea para los colectivos.

Búsqueda e interpolación

El interpolador crea un espacio euclidiano 4D a partir de los datos de la tabla. Para proporcionar una estimación del rendimiento, realiza una interpolación de promedio ponderado dentro de este espacio 4D. Si no hay una tabla para un determinado tipo de datos, como heurística, cada dimensión se normaliza según la cantidad de bytes.


Modelo de costos analíticos: DCN

Modelo de costos colectivos de curva en S

El modelo de curva S es un modelo de techo de red completamente analítico.

Descripción general

El modelo está diseñado para estimar el rendimiento de las operaciones colectivas en función de un conjunto de propiedades de red fijas.

Entradas del modelo

El modelo requiere dos categorías de entradas:

  1. Propiedades de red fijas (definidas por el usuario):

    • Sobrecarga de lanzamiento colectiva
    • Velocidad de la NIC
    • RTT (tiempo de ida y vuelta)

    De forma predeterminada, XLA detecta automáticamente una plataforma y usa valores para las arquitecturas más comunes. El usuario puede configurar estas propiedades. Consulta la sección de ajuste para obtener más detalles.

  2. Entradas por colectivo:

    • Tipo colectivo (p.ej., AllGather, ReduceScatter)
    • Tamaño de transferencia
    • Cantidad de nodos involucrados en la comunicación

Integración

El modelo de curva S está integrado en XLA:GPU y se usa en Hopper y Blackwell.


Modelo de costos analíticos: Fusions

Para otros kernels, nos basamos en el modelo de costo de rendimiento de la GPU para estimar los tiempos de ejecución correctos. Puedes obtener más información al respecto aquí.


Ajuste

El modelo de curva S se puede ajustar con los parámetros de XLA correctos. La configuración predeterminada debería ser suficiente en la mayoría de los casos, pero el control del modelo se expone en otros casos.

export NIC_SPEED_GBPS=... # NIC speed per GPU in Gigabytes
export GPUS_PER_NODE=... # Num of GPUs per cluster interconnected with fast network (e.g. NVLINK)
export XLA_FLAGS=--xla_gpu_analytical_latency_estimator_options="nic_speed_gbps=$NIC_SPEED_GBPS,gpus_per_node=$GPUS_PER_NODE"