XLA: como otimizar o compilador para machine learning

O OpenXLA é um compilador específico de domínio para álgebra linear que pode acelerar modelos do TensorFlow sem mudanças no código-fonte.

Introdução

Quando um programa do TensorFlow é executado, todas as operações são executadas individualmente pelo executor do TensorFlow. Cada operação do TensorFlow tem uma implementação de kernel da GPU pré-compilada para onde o executor envia.

O XLA fornece um modo alternativo de execução de modelos: ele compila o grafo do TensorFlow em uma sequência de kernels de computação gerados especificamente para o modelo fornecido. Como esses kernels são exclusivos do modelo, eles podem explorar informações específicas do modelo para otimização. Por exemplo, vamos analisar o XLA de otimização no contexto de um cálculo simples do TensorFlow:

def model_fn(x, y, z):
  return tf.reduce_sum(x + y * z)

Executado sem o XLA, o gráfico inicia três kernels: um para a multiplicação, outro para a adição e outro para a redução. No entanto, o XLA pode otimizar o grafo para que ele calcule o resultado em uma única inicialização do kernel. Ele faz isso "combinando" adição, multiplicação e redução em um único kernel da GPU. Além disso, essa operação combinada não grava os valores intermediários produzidos por y*z e x+y*z na memória. Em vez disso, ela "transmite" os resultados desses cálculos intermediários diretamente para os usuários, mantendo-os inteiramente em registros da GPU. A fusão é a otimização mais importante do XLA. A largura de banda de memória normalmente é o recurso mais escasso em aceleradores de hardware. Portanto, remover operações de memória é uma das melhores maneiras de melhorar o desempenho.

Ativar o XLA para modelos do TensorFlow

Compilação explícita com tf.function(jit_compile=True)

A API de compilação explícita oferece um controle refinado para escolher quais funções devem ser compiladas. Por exemplo, a seguinte função do TensorFlow que realiza o treinamento MNIST é compilada com o XLA:

@tf.function(jit_compile=True)
def train_mnist(images, labels):
    images, labels = cast(images, labels)

    with tf.GradientTape() as tape:
      predicted_labels = layer(images)
      loss = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(
          logits=predicted_labels, labels=labels
      ))
    layer_variables = layer.trainable_variables
    grads = tape.gradient(loss, layer_variables)
    optimizer.apply_gradients(zip(grads, layer_variables))

A API jit_compile tem semântica must-compile: a função inteira é compilada com XLA ou uma exceção errors.InvalidArgumentError é gerada. No momento, o XLA não pode compilar funções em que as dimensões não são inferíveis: isto é, se não for possível inferir as dimensões de todos os tensores sem executar o cálculo completo. Por exemplo, a função a seguir não pode ser compilada:

@tf.function
def not_compilable(x):
  return tf.unique(x)

No entanto, as formas podem variar entre as execuções:

@tf.function(jit_compile=True)
def recompiled_on_launch(a, b):
  return a + b

recompiled_on_launch(tf.ones([1, 10]), tf.ones([1, 10]))
recompiled_on_launch(tf.ones([1, 100]), tf.ones([1, 100]))

Consulte o tutorial no Colab para conferir um exemplo de uso mais detalhado e um tutorial em vídeo sobre o uso de jit_compile=True.

Uso com o Keras

Para modelos do Keras, jit_compile=True pode ser definido como um argumento para model.compile:

model.compile(optimizer="adam", jit_compile=True)

Uso com a estratégia distribuída

XLA:GPU pode ser usado com a estratégia distribuída do TF (MirroredStrategy ou MultiWorkerMirroredStrategy) anotando a função de etapa com jit_compile=True:

@tf.function(jit_compile=True)
def step_fn():
  t = tf.ones(shape=[100], dtype=tf.float32)
  ctx = tf.distribute.get_replica_context()
  return ctx.all_reduce(tf.distribute.ReduceOp.SUM, t)

@tf.function
def run_fn():
  return strategy.run(step_fn)

Clustering automático

Uma forma simples de começar a usar XLA em modelos do TensorFlow sem nenhuma mudança é ativar o clustering automático, que encontra automaticamente clusters (subgrafos conectados) nas funções do TensorFlow que podem ser compiladas e executadas usando XLA. O clustering automático na GPU pode ser ativado ao definir a variável de ambiente TF_XLA_FLAGS:

$ TF_XLA_FLAGS=--tf_xla_auto_jit=2 path/to/your/tf/program

Atualmente, o clustering automático é otimizado para cargas de trabalho da GPU, mas também pode ser ativado na CPU usando a flag --tf_xla_cpu_global_jit adicionalmente:

$ TF_XLA_FLAGS="--tf_xla_auto_jit=2 --tf_xla_cpu_global_jit" path/to/your/program

Para conferir um exemplo de uso detalhado, consulte o tutorial sobre clustering automático no Colab.

Compilação AOT (antecipada) para CPU com tfcompile

Também é possível usar uma ferramenta autônoma tfcompile, que converte o grafo do TensorFlow em código executável (somente para CPU x86-64).

Inspecionar programas compilados

O XLA fornece instalações de introspecção que permitem inspecionar os programas gerados. Para despejar os programas gerados, use a variável de ambiente XLA_FLAGS:

$ XLA_FLAGS="--xla_dump_to=/tmp/generated" TF_XLA_FLAGS="--tf_xla_auto_jit=2" my/tensorflow/program

Depois que o despejo for realizado, você poderá encontrar os seguintes arquivos em /tmp/generated:

  • O module_XXXX.*_optimizations.txt gerou programas XLA, um para cada cluster compilado. Anexá-los ao enviar relatórios de bugs do XLA é extremamente útil.

  • O module_XXXX.ir-*.ll gerou arquivos na LLVM, com intrínsecos NVPTX.

  • O module_XXXX.ptx gerou PTX.

Também é possível descarregar o gráfico que visualiza a incorporação de clusters XLA dentro do gráfico do TensorFlow com:

$ TF_DUMP_GRAPH_PREFIX=/tmp/generated TF_XLA_FLAGS="--tf_xla_clustering_debug"

Relatórios de bugs reproduzíveis

É muito mais fácil reproduzir um relatório de bug se ele inclui dumps para os programas XLA gerados e a incorporação de clustering automático usada. Para gerá-los para um programa do TensorFlow executado com clustering automático, inicie:

$ TF_DUMP_GRAPH_PREFIX=/tmp/generated \
  TF_XLA_FLAGS="--tf_xla_clustering_debug --tf_xla_auto_jit=2" \
  XLA_FLAGS="--xla_dump_hlo_as_text --xla_dump_to=/tmp/generated" \
    my/tensorflow/program"

Ao enviar relatórios de bugs, anexe o conteúdo do diretório /tmp/generated (mencionado acima).

Se possível, tente isolar um bug em um único programa XLA usando run_hlo_module e executando-o iterativamente nos programas gerados.

Leitura adicional

Front-ends de XLA

Além do TensorFlow, os programas XLA podem ser gerados por:

  • JAX: transformações combináveis de programas Python e NumPy
  • Julia: a linguagem Julia para computação científica;
  • PyTorch: framework PyTorch
  • Nx: biblioteca de computação numérica para a linguagem de programação Elixir

Palestras

Como usar XLA do TF com jit_compile=True

Visão geral da XLA