- Adding async operations to HLO is cumbersome (i.e.
all-reduce-start
andall-reduce-done
). - The start and done split may be inadequate for some of the asynchronous use cases.
To target the first shortcoming, we propose to introduce one last set of new
asynchronous opcodes: kAsyncStart
, kAsyncUpdate
, and kAsyncDone
. The idea
is to create a generic asynchronous opcode that can wrap any HLO instruction.
The actual operation that will be performed asynchronously will be encoded using
a called computation that only has the instruction as its root and any
parameters for inputs. The in-flight input/output buffer handling and aliasing
can then be shared for any asynchronous operation. The async-start instruction’s
output shape will then be a tuple of the input operands, output values, and any
intermediate state that is needed for the async-update
or async-done
instructions.
%async_op {
%param0 = f32[64] parameter(0)
ROOT %op = f32[32] op(f32[64] %param0), op_specific_attr=”foo”
}
%async-start = (f32[64], f32[32], s32[]) async-start(f32[64] %operand),
calls=%async_op
%async-done = f32[32] async-done((f32[64], f32[32], s32[]) %async-start)
In the representation above, only async-start
has a called computation since
it is trivial to find what the async-done
does by following its operand to
find the corresponding async-start
to find the called computation.
Also note
that the first element in the output tuple of async-start
aliases with the
operand, so the buffer stays alive until at least the async-done instruction.
Similarly, the second element aliases with the output of async-done
, and the
third element is the context state that is used to keep track of the
asynchronous operation. This representation also supports multiple tensors in
the asynchronous operation input and/or output and the aliasing works the same
way:
%async_op {
%param0 = f32[64] parameter(0)
%param1 = f32[64] parameter(1)
ROOT %op = (f32[32], f32[32]) op(f32[64] %param0, f32[64] %param1),
op_specific_attr=”foo”
}
%async-start = ((f32[64], f32[64]), (f32[32], f32[32]), s32[])
async-start(f32[64] %operand0, f32[64] %operand1),
calls=%async_op
%async-done = (f32[32], f32[32]) async-done(%async-start)
In addition, the op can further be decomposed into zero or more async-update
steps that perform intermediate computations. The input/output aliasing works
the same way with the async-update
instruction and each async-start
and
async-update
instructions must have one user that is either another
async-update
or an async-done
:
%async_op {
%param0 = f32[64] parameter(0)
ROOT %op = f32[32] op(f32[64] %param0), op_specific_attr=”foo”
}
%async-start = (f32[64], f32[32], s32[]) async-start(f32[64] %operand),
calls=%async_op
%async-update0 = (f32[64], f32[32], s32[]) async-update(
(f32[64], f32[32], s32[]) %async-start)
%async-update1 = (f32[64], f32[32], s32[]) async-update(
(f32[64], f32[32], s32[]) %async-update0)
%async-done = f32[32] async-done((f32[64], f32[32], s32[]) %async-update1)
Syntax sugar
Since having a separate computation to define the operation that will be performed asynchronously is a bit cumbersome, we also propose a syntax sugar to automatically print and parse asynchronous operations as if they are first-class opcodes. The idea is to treat the “-start”, “-update”, and “-done” suffixes specially by automatically creating the computation and instruction (without the suffix) when parsing. For example, the code snippet above can be pretty-printed to the following and the two can be parsed to the same representation:
%op-start = (f32[64], f32[32], s32[]) op-start(f32[64] %operand),
op_specific_attr=”foo”
%op-update0 = (f32[64], f32[32], s32[]) op-update(
(f32[64], f32[32], s32[]) %op-start),
op_specific_attr=”foo”
%op-update1 = (f32[64], f32[32], s32[]) op-update(
(f32[64], f32[32], s32[]) %op-update0)
%op-done = f32[32] op-done((f32[64], f32[32], s32[]) %op-update1)
In order not to create ambiguities, the verifier will not allow an operation to
be wrapped with async-start if we explicitly defined an opcode for that
operation with the “-start” and/or “-done” suffixes. This is also an escape
hatch in case we have any instructions that require HLO-level treatment that
doesn’t fit in the model described above (e.g. the aliasing input/output
buffers). So, initially, copy-start
/copy-done
,
collective-permute-start
/collective-permute-done
etc. will continue to use
their respective first-class opcodes instead of the new
async-start
/async-done
opcodes until we clean up the code to remove these
“-start”/”-done” opcodes.