O que é tfcompile?
O tfcompile
é uma ferramenta autônoma que compila gráficos do TensorFlow
em código executável antecipadamente (AOT, na sigla em inglês). Isso pode reduzir o tamanho total do binário e também evitar
algumas sobrecargas no ambiente de execução. Um caso de uso típico de tfcompile
é compilar um
gráfico de inferência em um código executável para dispositivos móveis.
O gráfico do TensorFlow normalmente é executado pelo ambiente de execução do TensorFlow. Isso gera
uma sobrecarga no ambiente de execução para a execução de cada nó no gráfico. Isso também leva
a um tamanho total do binário maior, já que o código do ambiente de execução do TensorFlow precisa
estar disponível, além do gráfico em si. O código executável produzido
por tfcompile
não usa o ambiente de execução do TensorFlow e só tem dependências em
kernels que são realmente usados na computação.
O compilador é baseado no framework XLA. O código que vincula o TensorFlow ao framework XLA fica em tensorflow/compiler.
O que o tfcompile faz?
O tfcompile
usa um subgráfico, identificado pelos conceitos de feeds e buscas do TensorFlow, e gera uma função que implementa esse subgráfico.
feeds
são os argumentos de entrada da função e fetches
são os argumentos de saída da função. Todas as entradas precisam ser totalmente especificadas pelos feeds. O subgráfico reduzido resultante não pode conter nós de marcador ou de variável. É comum especificar todos os marcadores de posição e variáveis como feeds, o que
garante que o subgráfico resultante não contenha mais esses nós. A função
gerada é empacotada como um cc_library
, com um arquivo principal que exporta a
assinatura da função e um arquivo de objeto que contém a implementação. O usuário
escreve o código para invocar a função gerada conforme apropriado.
Como usar tfcompile
Esta seção detalha as etapas avançadas para gerar um binário executável com
tfcompile
de um subgráfico do TensorFlow. Essas etapas são:
- Etapa 1: configurar o subgráfico para compilação
- Etapa 2: usar a macro de build
tf_library
para compilar o subgráfico - Etapa 3: escrever o código para invocar o subgráfico
- Etapa 4: criar o binário final
Etapa 1: configurar o subgráfico para compilação
Identifique os feeds e as buscas que correspondem aos argumentos de entrada e saída para a função gerada. Em seguida, configure feeds
e fetches
em um .proto
tensorflow.tf2xla.Config
.
# Each feed is a positional input argument for the generated function. The order
# of each entry matches the order of each input argument. Here “x_hold” and “y_hold”
# refer to the names of placeholder nodes defined in the graph.
feed {
id { node_name: "x_hold" }
shape {
dim { size: 2 }
dim { size: 3 }
}
}
feed {
id { node_name: "y_hold" }
shape {
dim { size: 3 }
dim { size: 2 }
}
}
# Each fetch is a positional output argument for the generated function. The order
# of each entry matches the order of each output argument. Here “x_y_prod”
# refers to the name of a matmul node defined in the graph.
fetch {
id { node_name: "x_y_prod" }
}
Etapa 2: usar a macro de build tf_library para compilar o subgráfico
Esta etapa converte o gráfico em um cc_library
usando a macro de build
tf_library
. O cc_library
consiste em um arquivo de objeto que contém o código gerado
do gráfico, com um arquivo principal que fornece acesso ao código
gerado. tf_library
usa tfcompile
para compilar o gráfico do TensorFlow em
código executável.
load("//tensorflow/compiler/aot:tfcompile.bzl", "tf_library")
# Use the tf_library macro to compile your graph into executable code.
tf_library(
# name is used to generate the following underlying build rules:
# <name> : cc_library packaging the generated header and object files
# <name>_test : cc_test containing a simple test and benchmark
# <name>_benchmark : cc_binary containing a stand-alone benchmark with minimal deps;
# can be run on a mobile device
name = "test_graph_tfmatmul",
# cpp_class specifies the name of the generated C++ class, with namespaces allowed.
# The class will be generated in the given namespace(s), or if no namespaces are
# given, within the global namespace.
cpp_class = "foo::bar::MatMulComp",
# graph is the input GraphDef proto, by default expected in binary format. To
# use the text format instead, just use the ‘.pbtxt’ suffix. A subgraph will be
# created from this input graph, with feeds as inputs and fetches as outputs.
# No Placeholder or Variable ops may exist in this subgraph.
graph = "test_graph_tfmatmul.pb",
# config is the input Config proto, by default expected in binary format. To
# use the text format instead, use the ‘.pbtxt’ suffix. This is where the
# feeds and fetches were specified above, in the previous step.
config = "test_graph_tfmatmul.config.pbtxt",
)
Para gerar o proto do GraphDef (test_graph_tfmatmul.pb) deste exemplo, execute make_test_graphs.py e especifique o local da saída com a flag --out_dir.
Os gráficos comuns contêm Variables
que representa os pesos aprendidos por treinamento, mas tfcompile
não pode
compilar um subgráfico que contenha Variables
. A ferramenta
freeze_graph.py
converte as variáveis em constantes usando valores armazenados em um arquivo de checkpoint. Por conveniência, a macro tf_library
é compatível com o argumento freeze_checkpoint
, que executa a ferramenta. Para mais exemplos, consulte
tensorflow/compiler/aot/tests/BUILD.
As constantes que aparecem no subgráfico compilado são compiladas diretamente no código gerado. Para transmitir as constantes para a função gerada, em vez de elas serem compiladas, basta transmiti-las como feeds.
Para detalhes sobre a macro de build tf_library
, consulte
tfcompile.bzl (em inglês).
Para detalhes sobre a ferramenta tfcompile
subjacente, consulte
tfcompile_main.cc.
Etapa 3: escrever o código para invocar o subgráfico
Esta etapa usa o arquivo principal (test_graph_tfmatmul.h
) gerado pela
macro de build tf_library
na etapa anterior para invocar o código gerado. O
arquivo de cabeçalho está localizado no diretório bazel-bin
correspondente ao
pacote de build e é nomeado com base no atributo de nome definido na macro de build
tf_library
. Por exemplo, o cabeçalho gerado para test_graph_tfmatmul
seria
test_graph_tfmatmul.h
. Confira abaixo uma versão resumida do que é gerado. O arquivo gerado, em bazel-bin
, contém outros comentários
úteis.
namespace foo {
namespace bar {
// MatMulComp represents a computation previously specified in a
// TensorFlow graph, now compiled into executable code.
class MatMulComp {
public:
// AllocMode controls the buffer allocation mode.
enum class AllocMode {
ARGS_RESULTS_AND_TEMPS, // Allocate arg, result and temp buffers
RESULTS_AND_TEMPS_ONLY, // Only allocate result and temp buffers
};
MatMulComp(AllocMode mode = AllocMode::ARGS_RESULTS_AND_TEMPS);
~MatMulComp();
// Runs the computation, with inputs read from arg buffers, and outputs
// written to result buffers. Returns true on success and false on failure.
bool Run();
// Arg methods for managing input buffers. Buffers are in row-major order.
// There is a set of methods for each positional argument.
void** args();
void set_arg0_data(float* data);
float* arg0_data();
float& arg0(size_t dim0, size_t dim1);
void set_arg1_data(float* data);
float* arg1_data();
float& arg1(size_t dim0, size_t dim1);
// Result methods for managing output buffers. Buffers are in row-major order.
// Must only be called after a successful Run call. There is a set of methods
// for each positional result.
void** results();
float* result0_data();
float& result0(size_t dim0, size_t dim1);
};
} // end namespace bar
} // end namespace foo
A classe C++ gerada é chamada de MatMulComp
no namespace foo::bar
,
porque esse foi o cpp_class
especificado na macro tf_library
. Todas
as classes geradas têm uma API semelhante, com a única diferença sendo os métodos
para processar buffers de argumento e resultado. Esses métodos diferem com base no número e
nos tipos de buffers, que foram especificados pelos argumentos feed
e fetch
para a macro tf_library
.
Há três tipos de buffers gerenciados na classe gerada: args
representando as entradas, results
representando as saídas e temps
representando os buffers temporários usados internamente para executar o cálculo. Por
padrão, cada instância da classe gerada aloca e gerencia todos esses
buffers para você. O argumento do construtor AllocMode
pode ser usado para mudar esse
comportamento. Todos os buffers estão alinhados a limites de 64 bytes.
A classe C++ gerada é apenas um wrapper em torno do código de baixo nível gerado pelo XLA.
Exemplo de invocação da função gerada com base em
tfcompile_test.cc
:
#define EIGEN_USE_THREADS
#define EIGEN_USE_CUSTOM_THREAD_POOL
#include <iostream>
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "third_party/tensorflow/compiler/aot/tests/test_graph_tfmatmul.h" // generated
int main(int argc, char** argv) {
Eigen::ThreadPool tp(2); // Size the thread pool as appropriate.
Eigen::ThreadPoolDevice device(&tp, tp.NumThreads());
foo::bar::MatMulComp matmul;
matmul.set_thread_pool(&device);
// Set up args and run the computation.
const float args[12] = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12};
std::copy(args + 0, args + 6, matmul.arg0_data());
std::copy(args + 6, args + 12, matmul.arg1_data());
matmul.Run();
// Check result
if (matmul.result0(0, 0) == 58) {
std::cout << "Success" << std::endl;
} else {
std::cout << "Failed. Expected value 58 at 0,0. Got:"
<< matmul.result0(0, 0) << std::endl;
}
return 0;
}
Etapa 4: criar o binário final
Esta etapa combina a biblioteca gerada por tf_library
na etapa 2 e o código escrito na etapa 3 para criar um binário final. Veja abaixo um exemplo de arquivo BUILD bazel
.
# Example of linking your binary
# Also see //tensorflow/compiler/aot/tests/BUILD
load("//tensorflow/compiler/aot:tfcompile.bzl", "tf_library")
# The same tf_library call from step 2 above.
tf_library(
name = "test_graph_tfmatmul",
...
)
# The executable code generated by tf_library can then be linked into your code.
cc_binary(
name = "my_binary",
srcs = [
"my_code.cc", # include test_graph_tfmatmul.h to access the generated header
],
deps = [
":test_graph_tfmatmul", # link in the generated object file
"//third_party/eigen3",
],
linkopts = [
"-lpthread",
]
)