tl;dr
This page describes the internals of the cost model used by Latency Hiding Scheduler. If you are interested in tuning the model go straight to the Tuning section.
The Latency Hiding Scheduler (LHS) is a compiler pass that schedules an HLO DAG in a way that minimizes wall time.
Its decisions are guided by the unified cost model, which uses a mixture of performance tables and analytical models. In particular, XLA embeds performance tables for GEMMs and fast-interconnect collectives, and uses analytical networking and fusion cost models for the remaining cases. The rest of the document describes the inner workings of these at a high level.
Performance tables – ICI collectives
The performance table system consists of two main components: a collector and an interpolator.
Collector
The collector is a C++ tool responsible for generating the performance tables for collective operations. It measures the performance of individual HLO ops (e.g., `all-gather`, `all-reduce`) across a statically defined parameter space.
How It Works
The tool performs a sweep over a range of collective ops, transfer sizes, and transfer schemes for a given cluster. It uses the existing multi-host HLO runner infrastructure and `ExecutionProfile` data to run the generated HLO and gather performance metrics.
Data Collection Parameters
Latency tables are collected for a cross-product of the following parameters:
- Collective Type:
  - `all-reduce`
  - `all-gather`
  - `reduce-scatter`
- Transfer Size:
  - Logarithmic scale from 1024 B up to 2 GiB (e.g., 1024 B, 2048 B, 4096 B, ...)
- Transfer Scheme:
  - rail-aligned
  - non-rail-aligned
This sweep is run for intra-node clusters with 2, 4, and 8 devices.
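For illustration, below is a minimal sketch of how such a cross-product sweep could be enumerated. The loop structure, names, and printed output are hypothetical stand-ins; the real collector generates an HLO module per point and runs it through the multi-host HLO runner.

```cpp
// Hypothetical sketch of the collector's parameter sweep (not the actual tool).
#include <cstdint>
#include <cstdio>
#include <string>
#include <vector>

int main() {
  const std::vector<std::string> collectives = {"all-reduce", "all-gather",
                                                "reduce-scatter"};
  const std::vector<std::string> schemes = {"rail-aligned", "non-rail-aligned"};
  const std::vector<int> device_counts = {2, 4, 8};

  // Logarithmic transfer-size sweep: 1024 B, 2048 B, ..., 2 GiB.
  std::vector<std::uint64_t> sizes;
  for (std::uint64_t s = 1024; s <= (2ull << 30); s *= 2) sizes.push_back(s);

  for (const std::string& collective : collectives) {
    for (const std::string& scheme : schemes) {
      for (int devices : device_counts) {
        for (std::uint64_t bytes : sizes) {
          // The real collector would build an HLO module for this point, run
          // it via the multi-host HLO runner, and record the latency from
          // ExecutionProfile; here we only enumerate the points.
          std::printf("%s %s devices=%d bytes=%llu\n", collective.c_str(),
                      scheme.c_str(), devices,
                      static_cast<unsigned long long>(bytes));
        }
      }
    }
  }
  return 0;
}
```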
Output
The result of a collection run is a latency table in `.pbtxt` format (approximately 116 KB per platform).
Interpolator
The interpolator is the compiler component that consumes the generated performance tables to provide runtime estimates during compilation.
Internal Data Structure
On initialization, the Interpolator processes the performance table into a map. This map uses a tuple of `(collective_type, transfer_scheme)` as its key. The value associated with each key is a 2D Euclidean plane. This plane indexes the network throughput (measured by the Collector) along two axes:
- Transfer size.
- Number of devices involved.
Lookup and Interpolation
When the compiler encounters a collective operation, the Interpolator performs the following steps:
- It identifies the correct 2D throughput plane using the operation's `(collective_type, transfer_scheme)` as the map key.
- It then uses a weighted-average retrieval (based on Euclidean distance) within that 2D plane, using the operation's `(transfer_size, num_devices)` as the query point.
- The result of this lookup is a single network throughput value.
Rationale: Throughput and Extrapolation
The system is designed to store network throughput rather than raw latency. This design choice significantly simplifies extrapolating performance for transfer sizes not explicitly present in the table.
If the latency tables capture network bandwidth saturation at a collective size `S`, the throughput `T` at that point is considered the maximum. For any new collective of size `S' > S`, the runtime can be estimated as:
\[\text{EstimatedTime}(S') = \frac{S'}{T_{\text{saturated}}}\]
This allows the model to estimate performance for collectives of any size, even those larger than the 2GiB maximum measured by the Collector.
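As a purely illustrative example with made-up numbers: if the measured throughput saturates at 300 GB/s, a 4 GiB collective (larger than anything in the table) would be estimated at

\[\text{EstimatedTime}(4\ \text{GiB}) = \frac{4.29 \times 10^{9}\ \text{B}}{300 \times 10^{9}\ \text{B/s}} \approx 14.3\ \text{ms}\]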
If the tables do not include measurements in the bandwidth-saturated region, the model will:
- Underestimate the maximum throughput.
- Consequently, overestimate the runtime for large transfers.
In general, the XLA:GPU team maintains the performance tables, but when users decide to provide their own, it is their responsibility to ensure the tables are representative and include measurements in the bandwidth-saturated region for the target hardware.
Performance tables – GEMMs
Similar to the system for collectives, GEMM latency tables are supported by two components: a collector and an interpolator.
Collector
The collector is a C++ tool that computes performance tables for General Matrix Multiplications (GEMMs). It measures the performance of matrix multiplications at the HLO `dot` op level.
How It Works
The tool performs a sweep over a static space of GEMM dimensions (batch, two non-contracting, and one contracting dimension) and data types.
- Default Data Types: `LHS = bf16, f32`, `RHS = bf16, f32`, `OUT = bf16, f32`.
- Infrastructure: Re-uses the HLO op profiler.
Collection Parameters
Latency tables are collected for a cross-product of the following dimensions:
- batch: `{1, 2, 4}`
- m (non-contracting): `{256, 512, ..., 4096}`
- n (non-contracting): `{256, 512, ..., 4096}`
- k (contracting): `{256, 512, ..., 4096}`
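To make a single sweep point concrete, the sketch below prints a plausible HLO `dot` module for one `(b, m, n, k)` combination. The module text, names, and dimensions are illustrative only, not the actual collector output.

```cpp
// Hypothetical sketch: turning one sweep point into an HLO dot module string.
#include <cstdio>

int main() {
  const int b = 2, m = 1024, n = 2048, k = 512;  // one point of the sweep
  // LHS is [b, m, k], RHS is [b, k, n], output is [b, m, n]; dimension 0 is
  // the batch dimension and k is the contracting dimension.
  std::printf(
      "HloModule gemm_b%d_m%d_n%d_k%d\n\n"
      "ENTRY main {\n"
      "  lhs = bf16[%d,%d,%d] parameter(0)\n"
      "  rhs = bf16[%d,%d,%d] parameter(1)\n"
      "  ROOT dot = bf16[%d,%d,%d] dot(lhs, rhs),\n"
      "      lhs_batch_dims={0}, rhs_batch_dims={0},\n"
      "      lhs_contracting_dims={2}, rhs_contracting_dims={1}\n"
      "}\n",
      b, m, n, k, b, m, k, b, k, n, b, m, n);
  return 0;
}
```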
Output and Storage
A full sweep generates a `.pbtxt` latency table, ready to be consumed by the interpolator.
Interpolator
The interpolator is the compiler component that uses the generated tables to estimate GEMM performance.
Rationale: FLOPS Saturation
The collected latency tables allow the interpolator to reconstruct FLOPS for each entry:
\[\text{FLOPS} = \frac{2 \times b \times m \times n \times k}{\text{runtime} }\]
A key insight is that FLOPS saturate at a certain point; that is, the hardware reaches peak FLOPS beyond a certain matrix shape. This saturation allows the use of the same extrapolation method employed for collectives.
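As a purely illustrative example with made-up numbers: an entry with b = 1, m = n = k = 4096 and a measured runtime of 1.2 ms reconstructs to

\[\text{FLOPS} = \frac{2 \times 1 \times 4096^{3}}{1.2 \times 10^{-3}\ \text{s}} \approx 1.15 \times 10^{14}\ \text{FLOP/s} \approx 115\ \text{TFLOP/s}\]

Once this value stops growing for larger shapes, the hardware is considered saturated and the peak can be reused for extrapolation.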
Lookup and Interpolation
The interpolator builds a 4D Euclidean space from the table data. To provide a performance estimate, it performs a weighted-average interpolation within this 4D space. If there is no table for a given data type, each dimension is normalized to its size in bytes as a heuristic.
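A minimal sketch of such a 4D lookup is shown below, under the assumption (from the rationale above) that each table entry is reduced to its reconstructed FLOPS. The types and the inverse-distance weighting are hypothetical, not the actual XLA interpolator.

```cpp
// Hypothetical sketch of 4D weighted-average interpolation over GEMM shapes.
#include <cmath>
#include <cstdio>
#include <vector>

// One table entry, reduced to the reconstructed FLOPS at that shape.
struct GemmSample {
  double b, m, n, k;
  double flops;  // 2*b*m*n*k / measured_runtime
};

// Weighted-average interpolation in the 4D (b, m, n, k) space, again using
// inverse Euclidean distance as the weight.
double InterpolateFlops(const std::vector<GemmSample>& table, double b,
                        double m, double n, double k) {
  double weighted_sum = 0.0;
  double weight_total = 0.0;
  for (const GemmSample& s : table) {
    const double distance =
        std::sqrt((s.b - b) * (s.b - b) + (s.m - m) * (s.m - m) +
                  (s.n - n) * (s.n - n) + (s.k - k) * (s.k - k));
    if (distance == 0.0) return s.flops;  // exact table hit
    weighted_sum += s.flops / distance;
    weight_total += 1.0 / distance;
  }
  return weighted_sum / weight_total;
}

int main() {
  // Two toy entries; real tables cover the full sweep described above.
  const std::vector<GemmSample> table = {{1, 2048, 2048, 2048, 300e12},
                                         {1, 4096, 4096, 4096, 600e12}};
  const double b = 1, m = 3072, n = 3072, k = 3072;
  const double flops = InterpolateFlops(table, b, m, n, k);
  const double runtime_s = 2 * b * m * n * k / flops;
  std::printf("estimated runtime: %.3f ms\n", runtime_s * 1e3);
  return 0;
}
```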
Analytical Cost Model - DCN
S-curve Collective Cost Model
The S-curve model is a fully analytical networking roofline model.
Overview
The model is designed to estimate the performance of collective operations based on a set of fixed network properties.
Model Inputs
The model requires two categories of inputs:
Fixed Network Properties (User-Defined):
- Collective launch overhead
- NIC speed
- RTT (round trip time)
By default, XLA auto-detects the platform and uses values for the most common architectures. These properties are configurable by the user; see the Tuning section for details.
Per-Collective Inputs:
- Collective type (e.g., `AllGather`, `ReduceScatter`)
- Transfer size
- Number of nodes involved in the communication
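The sketch below shows, in a heavily hedged way, how these fixed properties and per-collective inputs could combine into a roofline-style estimate (launch overhead plus the larger of a latency-bound and a bandwidth-bound term). It is not the actual S-curve formula used by XLA, and all constants are made up.

```cpp
// Hypothetical roofline-style sketch; NOT the actual S-curve model.
#include <algorithm>
#include <cstdio>

struct NetworkProperties {
  double launch_overhead_s;  // collective launch overhead
  double nic_speed_bytes_s;  // NIC speed per GPU
  double rtt_s;              // round trip time
};

// Illustrative estimate: overhead plus the larger of the latency-bound and
// bandwidth-bound terms for the bytes each node must move.
double EstimateSeconds(const NetworkProperties& net, double transfer_bytes,
                       int num_nodes) {
  const double bytes_per_node = transfer_bytes * (num_nodes - 1) / num_nodes;
  const double bandwidth_bound = bytes_per_node / net.nic_speed_bytes_s;
  const double latency_bound = net.rtt_s;
  return net.launch_overhead_s + std::max(bandwidth_bound, latency_bound);
}

int main() {
  const NetworkProperties net{10e-6, 50e9, 20e-6};  // made-up values
  std::printf("%.1f us\n", EstimateSeconds(net, 64 << 20, 8) * 1e6);
  return 0;
}
```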
Integration
The S-curve model is integrated into XLA:GPU and is used on Hopper and Blackwell.
Analytical Cost Model - Fusions
For other kernels we rely on the GPU performance cost model to estimate their runtimes. You can read more about it here.
Tuning
The S-curve model can be tuned via XLA flags. The default configuration should be good enough in the majority of cases, but model controls are exposed for the remaining ones.
```bash
export NIC_SPEED_GBPS=...  # NIC speed per GPU, in gigabytes per second
export GPUS_PER_NODE=...   # Number of GPUs per node interconnected with a fast network (e.g. NVLink)
export XLA_FLAGS=--xla_gpu_analytical_latency_estimator_options="nic_speed_gbps=$NIC_SPEED_GBPS,gpus_per_node=$GPUS_PER_NODE"
```