XLA: como otimizar o compilador para machine learning

O OpenXLA é um compilador específico de domínio para álgebra linear que pode acelerar modelos do TensorFlow sem mudanças no código-fonte.

Introdução

Quando um programa do TensorFlow é executado, todas as operações são executadas individualmente pelo executor do TensorFlow. Cada operação do TensorFlow tem uma implementação de kernel da GPU pré-compilada para a qual o executor envia.

O XLA fornece um modo alternativo de execução de modelos: ele compila o gráfico do TensorFlow em uma sequência de kernels de computação gerados especificamente para o modelo especificado. Como esses kernels são exclusivos do modelo, eles podem explorar informações específicas do modelo para otimização. Por exemplo, vamos ver como um XLA de otimização faz no contexto de um cálculo simples do TensorFlow:

def model_fn(x, y, z):
  return tf.reduce_sum(x + y * z)

Se for executado sem XLA, o grafo iniciará três kernels: um para a multiplicação, um para a adição e outro para a redução. No entanto, o XLA pode otimizar o gráfico para que ele calcule o resultado em uma única inicialização do kernel. Ele faz isso "combinando" adição, multiplicação e redução em um único kernel da GPU. Além disso, essa operação combinada não grava os valores intermediários produzidos por y*z e x+y*z na memória. Em vez disso, ela "transmite" os resultados desses cálculos intermediários diretamente para os usuários, mantendo-os totalmente em registros da GPU. A fusão é a otimização mais importante do XLA. A largura de banda de memória normalmente é o recurso mais escasso em aceleradores de hardware. Portanto, a remoção de operações de memória é uma das melhores maneiras de melhorar o desempenho.

Ativar XLA para modelos do TensorFlow

Compilação explícita com tf.function(jit_compile=True)

A API de compilação explícita oferece um controle refinado para escolher quais funções precisam ser compiladas. Por exemplo, a seguinte função do TensorFlow que executa o treinamento MNIST é compilada com o XLA:

@tf.function(jit_compile=True)
def train_mnist(images, labels):
    images, labels = cast(images, labels)

    with tf.GradientTape() as tape:
      predicted_labels = layer(images)
      loss = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(
          logits=predicted_labels, labels=labels
      ))
    layer_variables = layer.trainable_variables
    grads = tape.gradient(loss, layer_variables)
    optimizer.apply_gradients(zip(grads, layer_variables))

A API jit_compile tem semântica precisa de compilação: a função inteira é compilada com XLA ou uma exceção errors.InvalidArgumentError é gerada. No momento, o XLA não pode compilar funções em que as dimensões não são inferíveis, ou seja, se não for possível inferir as dimensões de todos os tensores sem executar todo o cálculo. Por exemplo, a função a seguir não será compilada:

@tf.function
def not_compilable(x):
  return tf.unique(x)

No entanto, as formas podem variar entre as execuções:

@tf.function(jit_compile=True)
def recompiled_on_launch(a, b):
  return a + b

recompiled_on_launch(tf.ones([1, 10]), tf.ones([1, 10]))
recompiled_on_launch(tf.ones([1, 100]), tf.ones([1, 100]))

Consulte o tutorial no Colab para conferir um exemplo de uso mais detalhado e um tutorial em vídeo sobre o uso do jit_compile=True.

Uso com a Keras

Para modelos da Keras, jit_compile=True pode ser definido como um argumento para model.compile:

model.compile(optimizer="adam", jit_compile=True)

Uso com estratégia distribuída

XLA:GPU pode ser usada com estratégia distribuída do TF (MirroredStrategy ou MultiWorkerMirroredStrategy) anotando a função de etapa com jit_compile=True:

@tf.function(jit_compile=True)
def step_fn():
  t = tf.ones(shape=[100], dtype=tf.float32)
  ctx = tf.distribute.get_replica_context()
  return ctx.all_reduce(tf.distribute.ReduceOp.SUM, t)

@tf.function
def run_fn():
  return strategy.run(step_fn)

Clustering automático

Uma maneira simples de começar a usar XLA em modelos do TensorFlow sem nenhuma alteração é ativar o clustering automático, que encontra automaticamente clusters (subgráficos conectados) nas funções do TensorFlow que podem ser compiladas e executadas usando XLA. O clustering automático na GPU pode ser ativado definindo a variável de ambiente TF_XLA_FLAGS:

$ TF_XLA_FLAGS=--tf_xla_auto_jit=2 path/to/your/tf/program

No momento, o clustering automático está otimizado para cargas de trabalho da GPU, mas também pode ser ativado na CPU usando a sinalização --tf_xla_cpu_global_jit:

$ TF_XLA_FLAGS="--tf_xla_auto_jit=2 --tf_xla_cpu_global_jit" path/to/your/program

Para ver um exemplo de uso detalhado, consulte o Tutorial de clustering automático no Colab.

Compilação antecipada (AOT, na sigla em inglês) para CPU com tfcompile

Também é possível usar uma ferramenta autônoma tfcompile, que converte o gráfico do TensorFlow em código executável (somente para CPU x86-64).

Inspecionar programas compilados

O XLA fornece instalações de introspecção que permitem inspecionar os programas gerados. Para despejar os programas gerados, use a variável de ambiente XLA_FLAGS:

$ XLA_FLAGS="--xla_dump_to=/tmp/generated" TF_XLA_FLAGS="--tf_xla_auto_jit=2" my/tensorflow/program

Depois que o despejo for feito, você poderá encontrar os seguintes arquivos em /tmp/generated:

  • O module_XXXX.*_optimizations.txt gerou programas XLA, um para cada cluster compilado. Anexá-los ao enviar relatórios de bugs do XLA é extremamente útil.

  • O module_XXXX.ir-*.ll gerou arquivos na representação intermediária LLVM (link em inglês), com intrínsecos NVPTX.

  • O module_XXXX.ptx gerou arquivos PTX.

Também é possível despejar o grafo que visualiza a incorporação de clusters do XLA dentro do gráfico do TensorFlow com:

$ TF_DUMP_GRAPH_PREFIX=/tmp/generated TF_XLA_FLAGS="--tf_xla_clustering_debug"

Relatórios de bugs reproduzíveis

Um relatório de bug é muito mais fácil de reproduzir se incluir despejos para os programas XLA gerados e a incorporação de clustering automático usada. Para gerá-las para um programa do TensorFlow em execução com clustering automático, inicie:

$ TF_DUMP_GRAPH_PREFIX=/tmp/generated \
  TF_XLA_FLAGS="--tf_xla_clustering_debug --tf_xla_auto_jit=2" \
  XLA_FLAGS="--xla_dump_hlo_as_text --xla_dump_to=/tmp/generated" \
    my/tensorflow/program"

Ao registrar bugs, anexe o conteúdo do diretório /tmp/generated (mencionado acima).

Se possível, tente isolar um bug em um único programa do XLA usando run_hlo_module e executando-o iterativamente nos programas gerados.

Sugestões de leitura

Front-ends de XLA

Além do TensorFlow, os programas do XLA podem ser gerados por:

  • JAX: transformações combináveis de programas Python e NumPy
  • Julia: a linguagem Julia para computação científica
  • PyTorch: framework PyTorch
  • Nx: biblioteca de computação numérica para a linguagem de programação Elixir

Palestras

Como usar XLA do TF com jit_compile=True

Visão geral de XLA