- Adding async operations to HLO is cumbersome (e.g., all-reduce-start and all-reduce-done).
- The start and done split may be inadequate for some of the asynchronous use cases.
To target the first shortcoming, we propose to introduce one last set of new
asynchronous opcodes: kAsyncStart, kAsyncUpdate, and kAsyncDone. The idea
is to create a generic asynchronous opcode that can wrap any HLO instruction.
The actual operation that will be performed asynchronously will be encoded using
a called computation that only has the instruction as its root and any
parameters for inputs. The in-flight input/output buffer handling and aliasing
can then be shared for any asynchronous operation. The async-start instruction’s
output shape will then be a tuple of the input operands, output values, and any
intermediate state that is needed for the async-update or async-done
instructions.
%async_op {
%param0 = f32[64] parameter(0)
ROOT %op = f32[32] op(f32[64] %param0), op_specific_attr="foo"
}
%async-start = ((f32[64]), f32[32], s32[]) async-start(f32[64] %operand),
calls=%async_op
%async-done = f32[32] async-done(((f32[64]), f32[32], s32[]) %async-start)
In the representation above, only async-start has a called computation, because
what async-done does can be found trivially by following its operand back to
the corresponding async-start and its called computation.
Also note that the first element in the output tuple of async-start is a
tuple containing the operands. The elements of this operand tuple alias with
the respective operands, so their buffers stay alive until at least the
async-done instruction. Similarly, the second element aliases with the output
of async-done, and the third element is the context state that is used to
keep track of the asynchronous operation. This representation naturally
supports multiple tensors in the asynchronous operation input and/or output:
%async_op {
%param0 = f32[64] parameter(0)
%param1 = f32[64] parameter(1)
ROOT %op = (f32[32], f32[32]) op(f32[64] %param0, f32[64] %param1),
op_specific_attr="foo"
}
%async-start = ((f32[64], f32[64]), (f32[32], f32[32]), s32[])
async-start(f32[64] %operand0, f32[64] %operand1),
calls=%async_op
%async-done = (f32[32], f32[32]) async-done(%async-start)
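The tuple layout described above can be sketched as a small shape-string builder. This is only an illustrative model, not an XLA API: the function name async_start_shape is hypothetical, and the s32[] context shape is simply the one used in the examples in this document.

```python
def async_start_shape(operand_shapes, result_shape, context_shape="s32[]"):
    """Build the shape string of an async-start instruction.

    Layout (from the text above):
      index 0: tuple of the operand shapes (aliases the operands)
      index 1: shape of the wrapped op's result (aliases async-done's output)
      index 2: context state; s32[] is an assumption taken from the examples
    """
    operands = "(" + ", ".join(operand_shapes) + ")"
    return f"({operands}, {result_shape}, {context_shape})"


# Single-operand example from above.
print(async_start_shape(["f32[64]"], "f32[32]"))
# ((f32[64]), f32[32], s32[])
```

Passing two operand shapes and a tuple result shape reproduces the shape of the multi-tensor async-start shown above.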
In addition, the op can be further decomposed into zero or more async-update
steps that perform intermediate computations. The input/output aliasing works
the same way with the async-update instruction, and each async-start and
async-update instruction must have one user that is either another
async-update or an async-done:
%async_op {
%param0 = f32[64] parameter(0)
ROOT %op = f32[32] op(f32[64] %param0), op_specific_attr="foo"
}
%async-start = ((f32[64]), f32[32], s32[]) async-start(f32[64] %operand),
calls=%async_op
%async-update0 = ((f32[64]), f32[32], s32[]) async-update(
((f32[64]), f32[32], s32[]) %async-start)
%async-update1 = ((f32[64]), f32[32], s32[]) async-update(
((f32[64]), f32[32], s32[]) %async-update0)
%async-done = f32[32] async-done(((f32[64]), f32[32], s32[]) %async-update1)
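The single-user rule for the chain can be checked mechanically. The sketch below models instructions as plain (name, opcode, operand names) tuples, which is a hypothetical representation chosen for illustration and not how XLA stores HLO. It reads "one user" as "exactly one user", which is an assumption about the intended rule.

```python
def check_async_chain(instructions):
    """Return a list of rule violations for async-start/async-update users.

    instructions: list of (name, opcode, operand_names) tuples.
    """
    opcode_of = {name: opcode for name, opcode, _ in instructions}
    users = {name: [] for name, _, _ in instructions}
    for name, _, operands in instructions:
        for operand in operands:
            if operand in users:  # ignore operands defined outside the list
                users[operand].append(name)

    errors = []
    for name, opcode, _ in instructions:
        if opcode not in ("async-start", "async-update"):
            continue
        found = users[name]
        if len(found) != 1:
            errors.append(f"{name}: expected one user, found {len(found)}")
        elif opcode_of[found[0]] not in ("async-update", "async-done"):
            errors.append(f"{name}: user {found[0]} is {opcode_of[found[0]]}")
    return errors


# The chain from the example above.
chain = [
    ("async-start", "async-start", ["operand"]),
    ("async-update0", "async-update", ["async-start"]),
    ("async-update1", "async-update", ["async-update0"]),
    ("async-done", "async-done", ["async-update1"]),
]
print(check_async_chain(chain))
```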
Syntax sugar
The HLO parser supports syntax sugar to automatically parse and print
asynchronous operations as if they are first-class opcodes. The parser treats
the -start, -update, and -done suffixes specially by automatically
creating the async computation and the wrapped instruction (without the suffix).
For example, an asynchronous custom-call can be written as:
%cc-start = ((f32[64]), f32[32], s32[]) custom-call-start(%operand),
custom_call_target="foo"
%cc-update = ((f32[64]), f32[32], s32[]) custom-call-update(%cc-start)
%result = f32[32] custom-call-done(%cc-update)
The parser desugars this into the following equivalent HLO:
%async_computation {
%p0 = f32[64] parameter(0)
ROOT %custom-call = f32[32] custom-call(%p0), custom_call_target="foo"
}
%async-start = ((f32[64]), f32[32], s32[]) async-start(%operand),
calls=%async_computation
%async-update = ((f32[64]), f32[32], s32[]) async-update(%async-start)
%result = f32[32] async-done(%async-update)
This desugaring is supported for most HLO opcodes (e.g., custom-call, dot,
all-reduce, etc.).
Exceptions
In order not to create ambiguities, the parser will not desugar operations that
have explicit first-class opcodes defined with the -start and/or -done
suffixes (e.g., copy-start/copy-done,
collective-permute-start/collective-permute-done). These will continue to
use their respective first-class opcodes.
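The suffix handling and its exceptions can be modeled roughly as follows. This is a sketch of the rule as described here, not the actual parser code: the function name split_async_sugar is hypothetical, and FIRST_CLASS contains only the opcodes named in this section, so it should not be read as a complete list.

```python
# Only the first-class opcodes named in the text; the real set is larger
# (assumption: this list is illustrative, not exhaustive).
FIRST_CLASS = {
    "copy-start", "copy-done",
    "collective-permute-start", "collective-permute-done",
}

SUFFIX_TO_ASYNC = {
    "-start": "async-start",
    "-update": "async-update",
    "-done": "async-done",
}


def split_async_sugar(opcode):
    """Return (async_opcode, wrapped_opcode), or None if not desugared."""
    if opcode in FIRST_CLASS:
        return None  # keeps its own first-class opcode
    for suffix, async_opcode in SUFFIX_TO_ASYNC.items():
        if opcode.endswith(suffix) and len(opcode) > len(suffix):
            wrapped = opcode[: -len(suffix)]
            if wrapped == "async":
                return None  # already the generic form
            return async_opcode, wrapped
    return None


print(split_async_sugar("custom-call-start"))
# ('async-start', 'custom-call')
```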
Late Binding
In some cases, the operands (inputs) or outputs of an asynchronous
operation are not all available or allocated when the operation starts.
XLA supports late binding, which allows operands to be incrementally
bound during async-update steps, and outputs to be bound during either
async-update or async-done steps.
Representation in HLO
For a called computation that expects \(N\) parameters, we can start the
asynchronous execution with fewer than \(N\) operands. The remaining
operands are passed in subsequent async-update instructions.
- async-start binds the first \(K\) operands (\(K < N\)).
- async-update instructions bind the remaining \(N - K\) operands.
Operand bindings must happen in left-to-right order. That is, if a computation expects parameters \(P_0, P_1, \dots, P_{N-1}\), they must be bound in that order across the async chain.
The async-start and async-update shapes reflect the incrementally
bound parameters. Specifically, the first element of the tuple shape
(the operand shapes) grows as more operands are bound.
Output binding is independent of operand binding and can happen at any
step in the async chain (either in an async-update or at the final
async-done).
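The way the operand tuple grows along the chain can be sketched as below. The helper operand_tuples is a hypothetical name, and the model assumes that every parameter must be bound by the last async-update, which is how the rules above read.

```python
def operand_tuples(param_shapes, bound_per_step):
    """Return the operand tuple shape (index 0) after each step of the chain.

    param_shapes:   shapes of P_0 .. P_{N-1} of the called computation.
    bound_per_step: number of operands newly bound at async-start, followed
                    by the number bound at each async-update, in chain order.
    Operands are always taken left to right, so each step binds a prefix.
    """
    shapes = []
    bound = 0
    for count in bound_per_step:
        bound += count
        if bound > len(param_shapes):
            raise ValueError("more operands bound than parameters")
        shapes.append("(" + ", ".join(param_shapes[:bound]) + ")")
    if bound != len(param_shapes):
        raise ValueError("not all parameters are bound by the end of the chain")
    return shapes


# Two f32[] parameters: one bound at start, one at the update.
print(operand_tuples(["f32[]", "f32[]"], [1, 1]))
# ['(f32[])', '(f32[], f32[])']
```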
Example with kCall
Consider a called computation %foo that takes two parameters:
%foo {
%p0 = f32[] parameter(0)
%p1 = f32[] parameter(1)
ROOT %add = f32[] add(%p0, %p1)
}
We can call this computation asynchronously, binding %p0 at start and
%p1 at update:
%call-start = ((f32[]), (), s32[]) call-start(%operand0), to_apply=%foo
%call-update = ((f32[], f32[]), f32[], s32[]) call-update(%call-start, %operand1)
%result = f32[] call-done(%call-update)
The parser desugars this into the following HLO:
%async-start = ((f32[]), (), s32[]) async-start(%operand0), calls=%foo
%async-update = ((f32[], f32[]), f32[], s32[]) async-update(%async-start, %operand1)
%result = f32[] async-done(%async-update)
Late-Bound Outputs
In addition to operands (inputs), the outputs of an asynchronous operation can also be bound late. This is useful when the output buffers are not known or allocated at the start of the operation.
To represent late-bound outputs:
- The async-start (or call-start) instruction is defined with an empty tuple () at index 1 of its output shape (the result slot).
- A subsequent async-update (or call-update) instruction specifies the actual output shape at index 1, replacing the empty tuple.
- Alternatively, the output can be bound at the end of the chain by the async-done (or call-done) instruction, which returns the final output shape. This can be done regardless of whether there are intermediate async-update steps in the chain.
Example with async-update
// Output is not bound at start (index 1 is ())
%call-start = ((f32[1024]), (), s32[]) call-start(%input_buffer), to_apply=%foo
// Output is bound at update (index 1 becomes (f32[1024]))
%call-update = ((f32[1024]), (f32[1024]), s32[]) call-update(%call-start, %output_buffer)
%result = (f32[1024]) call-done(%call-update)
The parser desugars this into:
%async-start = ((f32[1024]), (), s32[]) async-start(%input_buffer), calls=%foo
%async-update = ((f32[1024]), (f32[1024]), s32[]) async-update(%async-start, %output_buffer)
%result = (f32[1024]) async-done(%async-update)
Example with async-done (without async-update)
If there are no intermediate update steps, the output can be bound
directly at async-done:
// Output is not bound at start (index 1 is ())
%call-start = ((f32[1024]), (), s32[]) call-start(%input_buffer), to_apply=%foo
// Output is bound at done
%result = (f32[1024]) call-done(%call-start)
The parser desugars this into:
%async-start = ((f32[1024]), (), s32[]) async-start(%input_buffer), calls=%foo
%result = (f32[1024]) async-done(%async-start)
Example with intermediate async-update and output bound at async-done
Intermediate update steps can bind operands while the output is still bound at the very end:
// Output is not bound at start, no operands bound
%call-start = ((), (), s32[]) call-start(), to_apply=%foo
// Operands are bound at update, but output remains unbound (index 1 is ())
%call-update = ((f32[], f32[]), (), s32[]) call-update(%call-start, %operand0, %operand1)
// Output is bound at done
%result = f32[] call-done(%call-update)
The parser desugars this into:
%async-start = ((), (), s32[]) async-start(), calls=%foo
%async-update = ((f32[], f32[]), (), s32[]) async-update(%async-start, %operand0, %operand1)
%result = f32[] async-done(%async-update)
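The three late-bound output examples differ only in where index 1 of the tuple shape stops being (). A rough way to express that, using a hypothetical helper output_binding_step and treating the literal string "()" as the "unbound" marker (a simplification that cannot distinguish a genuinely empty tuple result):

```python
def output_binding_step(result_slots):
    """Report which step of the chain binds the output.

    result_slots: shape strings found at index 1 of async-start and of each
                  async-update, in chain order. "()" means not yet bound.
    """
    for index, slot in enumerate(result_slots):
        if slot != "()":
            return "async-start" if index == 0 else f"async-update #{index}"
    # Every step left the slot empty, so async-done supplies the output shape.
    return "async-done"


print(output_binding_step(["()", "(f32[1024])"]))
# async-update #1
```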