Use XLA with tf.function

This tutorial trains a TensorFlow model to classify the MNIST dataset, where the training function is compiled using XLA.

First, import TensorFlow. TensorFlow 2 runs eagerly by default, so no extra step is needed to enable eager execution.

import tensorflow as tf

Then define some necessary constants and prepare the MNIST dataset.

# Size of each input image, 28 x 28 pixels
IMAGE_SIZE = 28 * 28
# Number of distinct digit labels, [0..9]
NUM_CLASSES = 10
# Number of examples in each training batch (step)
TRAIN_BATCH_SIZE = 100
# Number of training steps to run
TRAIN_STEPS = 1000

# Loads MNIST dataset.
train, test = tf.keras.datasets.mnist.load_data()
train_ds = tf.data.Dataset.from_tensor_slices(train).batch(TRAIN_BATCH_SIZE).repeat()

# Casting from raw data to the required datatypes.
def cast(images, labels):
  images = tf.cast(
      tf.reshape(images, [-1, IMAGE_SIZE]), tf.float32)
  labels = tf.cast(labels, tf.int64)
  return (images, labels)
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz
11490434/11490434 [==============================] - 0s 0us/step
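
As a rough sanity check on these constants, the sketch below works out how much of the training set the loop will see. It assumes the standard MNIST split of 60,000 training images; the names NUM_TRAIN_EXAMPLES, steps_per_epoch, and passes are illustrative and not part of the tutorial code.

```python
# Illustrative arithmetic only; assumes the standard 60,000-image MNIST training split.
NUM_TRAIN_EXAMPLES = 60_000
TRAIN_BATCH_SIZE = 100
TRAIN_STEPS = 1000

# Number of batches needed to see every training image once.
steps_per_epoch = NUM_TRAIN_EXAMPLES // TRAIN_BATCH_SIZE
# Number of passes over the training data that TRAIN_STEPS steps correspond to.
passes = TRAIN_STEPS / steps_per_epoch

print(steps_per_epoch)   # 600
print(round(passes, 2))  # 1.67
```

Because the dataset is built with .repeat(), it keeps yielding batches after the first pass, so the training loop has to stop itself after the desired number of steps.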

Finally, define the model and the optimizer. The model uses a single dense layer.

layer = tf.keras.layers.Dense(NUM_CLASSES)
optimizer = tf.keras.optimizers.Adam()
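
To make the single dense layer concrete: with no activation, it computes logits = images x kernel + bias, where in this tutorial the kernel has shape [784, 10] and the bias has shape [10]. The following is a plain-Python sketch of that computation on tiny stand-in shapes, not the Keras implementation; dense_forward and the toy values are hypothetical names and data chosen for illustration.

```python
def dense_forward(images, kernel, bias):
    """Return logits[i][k] = sum_j images[i][j] * kernel[j][k] + bias[k]."""
    return [
        [sum(x * w for x, w in zip(row, col)) + b
         for col, b in zip(zip(*kernel), bias)]  # zip(*kernel) iterates over columns
        for row in images
    ]

# Tiny stand-in shapes: 2 "images" with 3 features each, 2 classes.
images = [[1.0, 2.0, 3.0], [0.0, 1.0, 0.0]]
kernel = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]  # shape [3, 2]
bias = [0.5, -0.5]                             # shape [2]

logits = dense_forward(images, kernel, bias)
print(logits)  # [[4.5, 4.5], [0.5, 0.5]]
```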

Define the training function

In the training function, you compute the predicted logits using the layer defined above, then use the optimizer to minimize the loss by applying its gradients to the layer's variables. To compile the computation with XLA, place it inside tf.function with jit_compile=True.

@tf.function(jit_compile=True)
def train_mnist(images, labels):
    images, labels = cast(images, labels)

    with tf.GradientTape() as tape:
      predicted_labels = layer(images)
      loss = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(
          logits=predicted_labels, labels=labels
      ))
    layer_variables = layer.trainable_variables
    grads = tape.gradient(loss, layer_variables)
    optimizer.apply_gradients(zip(grads, layer_variables))
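
The loss and the gradient that the tape computes can also be written out by hand. The following plain-Python sketch (standard library only, not TensorFlow's implementation) computes the mean sparse softmax cross-entropy and its gradient with respect to the logits, which is softmax(logits) minus the one-hot label, divided by the batch size. The function name sparse_softmax_xent and the toy inputs are made up for illustration.

```python
import math

def sparse_softmax_xent(logits, labels):
    """Return (mean cross-entropy loss, gradient of that loss w.r.t. the logits)."""
    n = len(logits)
    losses, grads = [], []
    for row, label in zip(logits, labels):
        m = max(row)  # subtract the row maximum for numerical stability
        exps = [math.exp(v - m) for v in row]
        total = sum(exps)
        probs = [e / total for e in exps]
        losses.append(-math.log(probs[label]))
        # d(mean loss)/d(logits) = (softmax - one_hot(label)) / n
        grads.append([(p - (1.0 if k == label else 0.0)) / n
                      for k, p in enumerate(probs)])
    return sum(losses) / n, grads

# Toy batch: 2 examples, 2 classes, both labelled as class 0.
logits = [[0.0, 0.0], [2.0, 0.0]]
labels = [0, 0]
loss, grad_logits = sparse_softmax_xent(logits, labels)
print(loss)
print(grad_logits)
```

Each row of the gradient sums to zero, because the softmax probabilities and the one-hot label both sum to one.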

Train and test the model

Once you have defined the training function, train the model by calling it on successive batches.

for images, labels in train_ds:
  if optimizer.iterations >= TRAIN_STEPS:
    break
  train_mnist(images, labels)
I0000 00:00:1706749171.368102    8926 device_compiler.h:186] Compiled cluster using XLA!  This line is logged at most once for the lifetime of the process.

Finally, check the accuracy on the test set:

images, labels = cast(test[0], test[1])
predicted_labels = layer(images)
correct_prediction = tf.equal(tf.argmax(predicted_labels, 1), labels)
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
print("Prediction accuracy after training: %s" % accuracy)
Prediction accuracy after training: tf.Tensor(0.8761, shape=(), dtype=float32)
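
The accuracy computation above takes the argmax of each row of logits, compares it with the true label, and averages the matches. A minimal plain-Python sketch of the same idea follows; the function name accuracy_of and the toy data are hypothetical and only serve as an illustration.

```python
def accuracy_of(logits, labels):
    """Fraction of rows whose largest logit sits at the true label index."""
    correct = 0
    for row, label in zip(logits, labels):
        predicted = max(range(len(row)), key=row.__getitem__)  # argmax of the row
        correct += int(predicted == label)
    return correct / len(labels)

# Toy data: 4 examples, 3 classes; the third example is misclassified.
toy_logits = [[0.1, 2.0, -1.0], [3.0, 0.0, 0.5], [0.0, 0.2, 0.1], [1.0, 1.5, 4.0]]
toy_labels = [1, 0, 2, 2]
print(accuracy_of(toy_logits, toy_labels))  # 0.75
```

In the tutorial code, the result is a scalar tf.Tensor, which is why the printed value is wrapped in tf.Tensor(...); calling accuracy.numpy() in eager mode would give a plain number instead.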

Behind the scenes, the XLA compiler has compiled the entire TF function to HLO, which enables fusion optimizations. Using the introspection facilities, you can inspect the HLO code. Other useful values for "stage" are optimized_hlo for HLO after optimizations and optimized_hlo_dot for a Graphviz graph. The call below prints the unoptimized HLO:

print(train_mnist.experimental_get_compiler_ir(images, labels)(stage='hlo'))
HloModule a_inference_train_mnist_5322__.187, input_output_alias={ {0}: (2, {}, may-alias), {1}: (3, {}, may-alias), {2}: (4, {}, may-alias), {3}: (6, {}, may-alias), {4}: (7, {}, may-alias), {5}: (8, {}, may-alias), {6}: (9, {}, may-alias) }, entry_computation_layout={(f32[10000,784]{1,0}, s64[10000]{0}, f32[784,10]{1,0}, f32[10]{0}, s64[], /*index=5*/f32[], f32[784,10]{1,0}, f32[784,10]{1,0}, f32[10]{0}, f32[10]{0})->(f32[784,10]{1,0}, f32[10]{0}, s64[], f32[784,10]{1,0}, f32[784,10]{1,0}, /*index=5*/f32[10]{0}, f32[10]{0})}

%max_float_.69 (x.70: f32[], y.71: f32[]) -> f32[] {
  %x.70 = f32[] parameter(0)
  %y.71 = f32[] parameter(1)
  ROOT %maximum.72 = f32[] maximum(f32[] %x.70, f32[] %y.71)
}

%add_float_.79 (x.80: f32[], y.81: f32[]) -> f32[] {
  %x.80 = f32[] parameter(0)
  %y.81 = f32[] parameter(1)
  ROOT %add.82 = f32[] add(f32[] %x.80, f32[] %y.81)
}

%add_float_.98 (x.99: f32[], y.100: f32[]) -> f32[] {
  %x.99 = f32[] parameter(0)
  %y.100 = f32[] parameter(1)
  ROOT %add.101 = f32[] add(f32[] %x.99, f32[] %y.100)
}

%Mean-reduction.110 (x.111: f32[], y.112: f32[]) -> f32[] {
  %x.111 = f32[] parameter(0)
  %y.112 = f32[] parameter(1)
  ROOT %add.113 = f32[] add(f32[] %x.111, f32[] %y.112)
}

%region_0.125 (Arg_0.126: f32[], Arg_1.127: f32[]) -> f32[] {
  %Arg_0.126 = f32[] parameter(0)
  %Arg_1.127 = f32[] parameter(1)
  ROOT %add.128 = f32[] add(f32[] %Arg_0.126, f32[] %Arg_1.127), metadata={op_type="BiasAddGrad" op_name="gradient_tape/dense/BiasAdd/BiasAddGrad"}
}

ENTRY %a_inference_train_mnist_5322__.187 (arg0.1: f32[10000,784], arg1.2: s64[10000], arg2.3: f32[784,10], arg3.4: f32[10], arg4.5: s64[], arg5.6: f32[], arg6.7: f32[784,10], arg7.8: f32[784,10], arg8.9: f32[10], arg9.10: f32[10]) -> (f32[784,10], f32[10], s64[], f32[784,10], f32[784,10], /*index=5*/f32[10], f32[10]) {
  %arg1.2 = s64[10000]{0} parameter(1), parameter_replication={false}, metadata={op_name="XLA_Args"}
  %reshape.12 = s64[10000]{0} reshape(s64[10000]{0} %arg1.2)
  %broadcast.49 = s64[10000,10]{1,0} broadcast(s64[10000]{0} %reshape.12), dimensions={0}, metadata={op_type="SparseSoftmaxCrossEntropyWithLogits" op_name="SparseSoftmaxCrossEntropyWithLogits/SparseSoftmaxCrossEntropyWithLogits" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1160}
  %iota.48 = s64[10000,10]{1,0} iota(), iota_dimension=1, metadata={op_type="SparseSoftmaxCrossEntropyWithLogits" op_name="SparseSoftmaxCrossEntropyWithLogits/SparseSoftmaxCrossEntropyWithLogits" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1160}
  %compare.50 = pred[10000,10]{1,0} compare(s64[10000,10]{1,0} %broadcast.49, s64[10000,10]{1,0} %iota.48), direction=EQ, metadata={op_type="SparseSoftmaxCrossEntropyWithLogits" op_name="SparseSoftmaxCrossEntropyWithLogits/SparseSoftmaxCrossEntropyWithLogits" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1160}
  %constant.46 = f32[] constant(1), metadata={op_type="SparseSoftmaxCrossEntropyWithLogits" op_name="SparseSoftmaxCrossEntropyWithLogits/SparseSoftmaxCrossEntropyWithLogits" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1160}
  %broadcast.51 = f32[10000,10]{1,0} broadcast(f32[] %constant.46), dimensions={}, metadata={op_type="SparseSoftmaxCrossEntropyWithLogits" op_name="SparseSoftmaxCrossEntropyWithLogits/SparseSoftmaxCrossEntropyWithLogits" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1160}
  %constant.47 = f32[] constant(0), metadata={op_type="SparseSoftmaxCrossEntropyWithLogits" op_name="SparseSoftmaxCrossEntropyWithLogits/SparseSoftmaxCrossEntropyWithLogits" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1160}
  %broadcast.52 = f32[10000,10]{1,0} broadcast(f32[] %constant.47), dimensions={}, metadata={op_type="SparseSoftmaxCrossEntropyWithLogits" op_name="SparseSoftmaxCrossEntropyWithLogits/SparseSoftmaxCrossEntropyWithLogits" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1160}
  %select.53 = f32[10000,10]{1,0} select(pred[10000,10]{1,0} %compare.50, f32[10000,10]{1,0} %broadcast.51, f32[10000,10]{1,0} %broadcast.52), metadata={op_type="SparseSoftmaxCrossEntropyWithLogits" op_name="SparseSoftmaxCrossEntropyWithLogits/SparseSoftmaxCrossEntropyWithLogits" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1160}
  %constant.54 = s64[] constant(0), metadata={op_type="SparseSoftmaxCrossEntropyWithLogits" op_name="SparseSoftmaxCrossEntropyWithLogits/SparseSoftmaxCrossEntropyWithLogits" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1160}
  %broadcast.55 = s64[10000]{0} broadcast(s64[] %constant.54), dimensions={}, metadata={op_type="SparseSoftmaxCrossEntropyWithLogits" op_name="SparseSoftmaxCrossEntropyWithLogits/SparseSoftmaxCrossEntropyWithLogits" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1160}
  %compare.56 = pred[10000]{0} compare(s64[10000]{0} %broadcast.55, s64[10000]{0} %reshape.12), direction=LE, metadata={op_type="SparseSoftmaxCrossEntropyWithLogits" op_name="SparseSoftmaxCrossEntropyWithLogits/SparseSoftmaxCrossEntropyWithLogits" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1160}
  %constant.57 = s64[] constant(10), metadata={op_type="SparseSoftmaxCrossEntropyWithLogits" op_name="SparseSoftmaxCrossEntropyWithLogits/SparseSoftmaxCrossEntropyWithLogits" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1160}
  %broadcast.58 = s64[10000]{0} broadcast(s64[] %constant.57), dimensions={}, metadata={op_type="SparseSoftmaxCrossEntropyWithLogits" op_name="SparseSoftmaxCrossEntropyWithLogits/SparseSoftmaxCrossEntropyWithLogits" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1160}
  %compare.59 = pred[10000]{0} compare(s64[10000]{0} %reshape.12, s64[10000]{0} %broadcast.58), direction=LT, metadata={op_type="SparseSoftmaxCrossEntropyWithLogits" op_name="SparseSoftmaxCrossEntropyWithLogits/SparseSoftmaxCrossEntropyWithLogits" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1160}
  %and.60 = pred[10000]{0} and(pred[10000]{0} %compare.56, pred[10000]{0} %compare.59), metadata={op_type="SparseSoftmaxCrossEntropyWithLogits" op_name="SparseSoftmaxCrossEntropyWithLogits/SparseSoftmaxCrossEntropyWithLogits" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1160}
  %constant.61 = f32[] constant(0), metadata={op_type="SparseSoftmaxCrossEntropyWithLogits" op_name="SparseSoftmaxCrossEntropyWithLogits/SparseSoftmaxCrossEntropyWithLogits" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1160}
  %broadcast.62 = f32[10000]{0} broadcast(f32[] %constant.61), dimensions={}, metadata={op_type="SparseSoftmaxCrossEntropyWithLogits" op_name="SparseSoftmaxCrossEntropyWithLogits/SparseSoftmaxCrossEntropyWithLogits" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1160}
  %constant.63 = f32[] constant(nan), metadata={op_type="SparseSoftmaxCrossEntropyWithLogits" op_name="SparseSoftmaxCrossEntropyWithLogits/SparseSoftmaxCrossEntropyWithLogits" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1160}
  %broadcast.64 = f32[10000]{0} broadcast(f32[] %constant.63), dimensions={}, metadata={op_type="SparseSoftmaxCrossEntropyWithLogits" op_name="SparseSoftmaxCrossEntropyWithLogits/SparseSoftmaxCrossEntropyWithLogits" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1160}
  %select.65 = f32[10000]{0} select(pred[10000]{0} %and.60, f32[10000]{0} %broadcast.62, f32[10000]{0} %broadcast.64), metadata={op_type="SparseSoftmaxCrossEntropyWithLogits" op_name="SparseSoftmaxCrossEntropyWithLogits/SparseSoftmaxCrossEntropyWithLogits" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1160}
  %broadcast.66 = f32[10000,10]{1,0} broadcast(f32[10000]{0} %select.65), dimensions={0}, metadata={op_type="SparseSoftmaxCrossEntropyWithLogits" op_name="SparseSoftmaxCrossEntropyWithLogits/SparseSoftmaxCrossEntropyWithLogits" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1160}
  %add.67 = f32[10000,10]{1,0} add(f32[10000,10]{1,0} %select.53, f32[10000,10]{1,0} %broadcast.66), metadata={op_type="SparseSoftmaxCrossEntropyWithLogits" op_name="SparseSoftmaxCrossEntropyWithLogits/SparseSoftmaxCrossEntropyWithLogits" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1160}
  %negate.94 = f32[10000,10]{1,0} negate(f32[10000,10]{1,0} %add.67), metadata={op_type="SparseSoftmaxCrossEntropyWithLogits" op_name="SparseSoftmaxCrossEntropyWithLogits/SparseSoftmaxCrossEntropyWithLogits" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1160}
  %constant.88 = f32[] constant(0), metadata={op_type="SparseSoftmaxCrossEntropyWithLogits" op_name="SparseSoftmaxCrossEntropyWithLogits/SparseSoftmaxCrossEntropyWithLogits" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1160}
  %broadcast.89 = f32[10000,10]{1,0} broadcast(f32[] %constant.88), dimensions={}, metadata={op_type="SparseSoftmaxCrossEntropyWithLogits" op_name="SparseSoftmaxCrossEntropyWithLogits/SparseSoftmaxCrossEntropyWithLogits" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1160}
  %compare.90 = pred[10000,10]{1,0} compare(f32[10000,10]{1,0} %add.67, f32[10000,10]{1,0} %broadcast.89), direction=EQ, metadata={op_type="SparseSoftmaxCrossEntropyWithLogits" op_name="SparseSoftmaxCrossEntropyWithLogits/SparseSoftmaxCrossEntropyWithLogits" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1160}
  %constant.91 = f32[] constant(0), metadata={op_type="SparseSoftmaxCrossEntropyWithLogits" op_name="SparseSoftmaxCrossEntropyWithLogits/SparseSoftmaxCrossEntropyWithLogits" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1160}
  %broadcast.92 = f32[10000,10]{1,0} broadcast(f32[] %constant.91), dimensions={}, metadata={op_type="SparseSoftmaxCrossEntropyWithLogits" op_name="SparseSoftmaxCrossEntropyWithLogits/SparseSoftmaxCrossEntropyWithLogits" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1160}
  %arg0.1 = f32[10000,784]{1,0} parameter(0), parameter_replication={false}, metadata={op_name="XLA_Args"}
  %reshape.11 = f32[10000,784]{1,0} reshape(f32[10000,784]{1,0} %arg0.1)
  %reshape.41 = f32[10000,784]{1,0} reshape(f32[10000,784]{1,0} %reshape.11), metadata={op_type="Reshape" op_name="Reshape" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1160}
  %arg2.3 = f32[784,10]{1,0} parameter(2), parameter_replication={false}, metadata={op_name="XLA_Args"}
  %dot.42 = f32[10000,10]{1,0} dot(f32[10000,784]{1,0} %reshape.41, f32[784,10]{1,0} %arg2.3), lhs_contracting_dims={1}, rhs_contracting_dims={0}, metadata={op_type="MatMul" op_name="dense/MatMul" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1160}
  %transpose.43 = f32[10000,10]{1,0} transpose(f32[10000,10]{1,0} %dot.42), dimensions={0,1}, metadata={op_type="MatMul" op_name="dense/MatMul" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1160}
  %arg3.4 = f32[10]{0} parameter(3), parameter_replication={false}, metadata={op_name="XLA_Args"}
  %broadcast.44 = f32[10000,10]{1,0} broadcast(f32[10]{0} %arg3.4), dimensions={1}, metadata={op_type="BiasAdd" op_name="dense/BiasAdd" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1160}
  %add.45 = f32[10000,10]{1,0} add(f32[10000,10]{1,0} %transpose.43, f32[10000,10]{1,0} %broadcast.44), metadata={op_type="BiasAdd" op_name="dense/BiasAdd" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1160}
  %constant.68 = f32[] constant(-inf), metadata={op_type="SparseSoftmaxCrossEntropyWithLogits" op_name="SparseSoftmaxCrossEntropyWithLogits/SparseSoftmaxCrossEntropyWithLogits" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1160}
  %reduce.73 = f32[10000]{0} reduce(f32[10000,10]{1,0} %add.45, f32[] %constant.68), dimensions={1}, to_apply=%max_float_.69, metadata={op_type="SparseSoftmaxCrossEntropyWithLogits" op_name="SparseSoftmaxCrossEntropyWithLogits/SparseSoftmaxCrossEntropyWithLogits" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1160}
  %broadcast.74 = f32[10000,10]{1,0} broadcast(f32[10000]{0} %reduce.73), dimensions={0}, metadata={op_type="SparseSoftmaxCrossEntropyWithLogits" op_name="SparseSoftmaxCrossEntropyWithLogits/SparseSoftmaxCrossEntropyWithLogits" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1160}
  %subtract.75 = f32[10000,10]{1,0} subtract(f32[10000,10]{1,0} %add.45, f32[10000,10]{1,0} %broadcast.74), metadata={op_type="SparseSoftmaxCrossEntropyWithLogits" op_name="SparseSoftmaxCrossEntropyWithLogits/SparseSoftmaxCrossEntropyWithLogits" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1160}
  %exponential.76 = f32[10000,10]{1,0} exponential(f32[10000,10]{1,0} %subtract.75), metadata={op_type="SparseSoftmaxCrossEntropyWithLogits" op_name="SparseSoftmaxCrossEntropyWithLogits/SparseSoftmaxCrossEntropyWithLogits" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1160}
  %convert.77 = f32[10000,10]{1,0} convert(f32[10000,10]{1,0} %exponential.76), metadata={op_type="SparseSoftmaxCrossEntropyWithLogits" op_name="SparseSoftmaxCrossEntropyWithLogits/SparseSoftmaxCrossEntropyWithLogits" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1160}
  %constant.78 = f32[] constant(0), metadata={op_type="SparseSoftmaxCrossEntropyWithLogits" op_name="SparseSoftmaxCrossEntropyWithLogits/SparseSoftmaxCrossEntropyWithLogits" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1160}
  %reduce.83 = f32[10000]{0} reduce(f32[10000,10]{1,0} %convert.77, f32[] %constant.78), dimensions={1}, to_apply=%add_float_.79, metadata={op_type="SparseSoftmaxCrossEntropyWithLogits" op_name="SparseSoftmaxCrossEntropyWithLogits/SparseSoftmaxCrossEntropyWithLogits" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1160}
  %convert.84 = f32[10000]{0} convert(f32[10000]{0} %reduce.83), metadata={op_type="SparseSoftmaxCrossEntropyWithLogits" op_name="SparseSoftmaxCrossEntropyWithLogits/SparseSoftmaxCrossEntropyWithLogits" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1160}
  %log.85 = f32[10000]{0} log(f32[10000]{0} %convert.84), metadata={op_type="SparseSoftmaxCrossEntropyWithLogits" op_name="SparseSoftmaxCrossEntropyWithLogits/SparseSoftmaxCrossEntropyWithLogits" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1160}
  %broadcast.86 = f32[10000,10]{1,0} broadcast(f32[10000]{0} %log.85), dimensions={0}, metadata={op_type="SparseSoftmaxCrossEntropyWithLogits" op_name="SparseSoftmaxCrossEntropyWithLogits/SparseSoftmaxCrossEntropyWithLogits" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1160}
  %subtract.87 = f32[10000,10]{1,0} subtract(f32[10000,10]{1,0} %subtract.75, f32[10000,10]{1,0} %broadcast.86), metadata={op_type="SparseSoftmaxCrossEntropyWithLogits" op_name="SparseSoftmaxCrossEntropyWithLogits/SparseSoftmaxCrossEntropyWithLogits" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1160}
  %select.93 = f32[10000,10]{1,0} select(pred[10000,10]{1,0} %compare.90, f32[10000,10]{1,0} %broadcast.92, f32[10000,10]{1,0} %subtract.87), metadata={op_type="SparseSoftmaxCrossEntropyWithLogits" op_name="SparseSoftmaxCrossEntropyWithLogits/SparseSoftmaxCrossEntropyWithLogits" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1160}
  %multiply.95 = f32[10000,10]{1,0} multiply(f32[10000,10]{1,0} %negate.94, f32[10000,10]{1,0} %select.93), metadata={op_type="SparseSoftmaxCrossEntropyWithLogits" op_name="SparseSoftmaxCrossEntropyWithLogits/SparseSoftmaxCrossEntropyWithLogits" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1160}
  %convert.96 = f32[10000,10]{1,0} convert(f32[10000,10]{1,0} %multiply.95), metadata={op_type="SparseSoftmaxCrossEntropyWithLogits" op_name="SparseSoftmaxCrossEntropyWithLogits/SparseSoftmaxCrossEntropyWithLogits" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1160}
  %constant.97 = f32[] constant(0), metadata={op_type="SparseSoftmaxCrossEntropyWithLogits" op_name="SparseSoftmaxCrossEntropyWithLogits/SparseSoftmaxCrossEntropyWithLogits" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1160}
  %reduce.102 = f32[10000]{0} reduce(f32[10000,10]{1,0} %convert.96, f32[] %constant.97), dimensions={1}, to_apply=%add_float_.98, metadata={op_type="SparseSoftmaxCrossEntropyWithLogits" op_name="SparseSoftmaxCrossEntropyWithLogits/SparseSoftmaxCrossEntropyWithLogits" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1160}
  %convert.103 = f32[10000]{0} convert(f32[10000]{0} %reduce.102), metadata={op_type="SparseSoftmaxCrossEntropyWithLogits" op_name="SparseSoftmaxCrossEntropyWithLogits/SparseSoftmaxCrossEntropyWithLogits" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1160}
  %convert.107 = f32[10000]{0} convert(f32[10000]{0} %convert.103), metadata={op_type="Mean" op_name="Mean" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1160}
  %constant.108 = f32[] constant(0), metadata={op_type="Mean" op_name="Mean" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1160}
  %convert.109 = f32[] convert(f32[] %constant.108), metadata={op_type="Mean" op_name="Mean" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1160}
  %reduce.114 = f32[] reduce(f32[10000]{0} %convert.107, f32[] %convert.109), dimensions={0}, to_apply=%Mean-reduction.110, metadata={op_type="Mean" op_name="Mean" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1160}
  %constant.115 = s32[] constant(10000), metadata={op_type="Mean" op_name="Mean" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1160}
  %convert.116 = f32[] convert(s32[] %constant.115), metadata={op_type="Mean" op_name="Mean" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1160}
  %divide.117 = f32[] divide(f32[] %reduce.114, f32[] %convert.116), metadata={op_type="Mean" op_name="Mean" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1160}
  %convert.118 = f32[] convert(f32[] %divide.117), metadata={op_type="Mean" op_name="Mean" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1160}
  %arg6.7 = f32[784,10]{1,0} parameter(6), parameter_replication={false}, metadata={op_name="XLA_Args"}
  %constant.119 = f32[] constant(0.0001), metadata={op_type="Mul" op_name="gradient_tape/SparseSoftmaxCrossEntropyWithLogits/mul" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1160}
  %broadcast.120 = f32[10000,1]{1,0} broadcast(f32[] %constant.119), dimensions={}, metadata={op_type="Mul" op_name="gradient_tape/SparseSoftmaxCrossEntropyWithLogits/mul" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1160}
  %reshape.121 = f32[10000]{0} reshape(f32[10000,1]{1,0} %broadcast.120), metadata={op_type="Mul" op_name="gradient_tape/SparseSoftmaxCrossEntropyWithLogits/mul" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1160}
  %broadcast.122 = f32[10000,10]{1,0} broadcast(f32[10000]{0} %reshape.121), dimensions={0}, metadata={op_type="Mul" op_name="gradient_tape/SparseSoftmaxCrossEntropyWithLogits/mul" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1160}
  %broadcast.104 = f32[10000,10]{1,0} broadcast(f32[10000]{0} %convert.84), dimensions={0}, metadata={op_type="SparseSoftmaxCrossEntropyWithLogits" op_name="SparseSoftmaxCrossEntropyWithLogits/SparseSoftmaxCrossEntropyWithLogits" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1160}
  %divide.105 = f32[10000,10]{1,0} divide(f32[10000,10]{1,0} %exponential.76, f32[10000,10]{1,0} %broadcast.104), metadata={op_type="SparseSoftmaxCrossEntropyWithLogits" op_name="SparseSoftmaxCrossEntropyWithLogits/SparseSoftmaxCrossEntropyWithLogits" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1160}
  %subtract.106 = f32[10000,10]{1,0} subtract(f32[10000,10]{1,0} %divide.105, f32[10000,10]{1,0} %add.67), metadata={op_type="SparseSoftmaxCrossEntropyWithLogits" op_name="SparseSoftmaxCrossEntropyWithLogits/SparseSoftmaxCrossEntropyWithLogits" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1160}
  %multiply.123 = f32[10000,10]{1,0} multiply(f32[10000,10]{1,0} %broadcast.122, f32[10000,10]{1,0} %subtract.106), metadata={op_type="Mul" op_name="gradient_tape/SparseSoftmaxCrossEntropyWithLogits/mul" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1160}
  %dot.130 = f32[784,10]{1,0} dot(f32[10000,784]{1,0} %reshape.41, f32[10000,10]{1,0} %multiply.123), lhs_contracting_dims={0}, rhs_contracting_dims={0}, metadata={op_type="MatMul" op_name="gradient_tape/dense/MatMul/MatMul" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1160}
  %transpose.131 = f32[784,10]{1,0} transpose(f32[784,10]{1,0} %dot.130), dimensions={0,1}, metadata={op_type="MatMul" op_name="gradient_tape/dense/MatMul/MatMul" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1160}
  %subtract.142 = f32[784,10]{1,0} subtract(f32[784,10]{1,0} %transpose.131, f32[784,10]{1,0} %arg6.7), metadata={op_type="Sub" op_name="Adam/StatefulPartitionedCall/sub_2" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1160}
  %constant.143 = f32[] constant(0.1), metadata={op_type="Mul" op_name="Adam/StatefulPartitionedCall/mul_1" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1160}
  %broadcast.144 = f32[784,10]{1,0} broadcast(f32[] %constant.143), dimensions={}, metadata={op_type="Mul" op_name="Adam/StatefulPartitionedCall/mul_1" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1160}
  %multiply.145 = f32[784,10]{1,0} multiply(f32[784,10]{1,0} %subtract.142, f32[784,10]{1,0} %broadcast.144), metadata={op_type="Mul" op_name="Adam/StatefulPartitionedCall/mul_1" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1160}
  %add.146 = f32[784,10]{1,0} add(f32[784,10]{1,0} %arg6.7, f32[784,10]{1,0} %multiply.145), metadata={op_type="AssignAddVariableOp" op_name="Adam/StatefulPartitionedCall/AssignAddVariableOp" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1160}
  %arg5.6 = f32[] parameter(5), parameter_replication={false}, metadata={op_name="XLA_Args"}
  %constant.36 = f32[] constant(1), metadata={op_type="Sub" op_name="Adam/StatefulPartitionedCall/sub" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1160}
  %constant.34 = f32[] constant(0.999), metadata={op_type="Pow" op_name="Adam/StatefulPartitionedCall/Pow_1" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1160}
  %arg4.5 = s64[] parameter(4), parameter_replication={false}, metadata={op_name="XLA_Args"}
  %constant.27 = s64[] constant(1), metadata={op_type="AddV2" op_name="Adam/StatefulPartitionedCall/add" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1160}
  %add.28 = s64[] add(s64[] %arg4.5, s64[] %constant.27), metadata={op_type="AddV2" op_name="Adam/StatefulPartitionedCall/add" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1160}
  %convert.29 = f32[] convert(s64[] %add.28), metadata={op_type="Cast" op_name="Adam/StatefulPartitionedCall/Cast" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1160}
  %power.35 = f32[] power(f32[] %constant.34, f32[] %convert.29), metadata={op_type="Pow" op_name="Adam/StatefulPartitionedCall/Pow_1" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1160}
  %subtract.37 = f32[] subtract(f32[] %constant.36, f32[] %power.35), metadata={op_type="Sub" op_name="Adam/StatefulPartitionedCall/sub" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1160}
  %sqrt.38 = f32[] sqrt(f32[] %subtract.37), metadata={op_type="Sqrt" op_name="Sqrt;Adam/StatefulPartitionedCall/Sqrt"}
  %multiply.39 = f32[] multiply(f32[] %arg5.6, f32[] %sqrt.38), metadata={op_type="Mul" op_name="Adam/StatefulPartitionedCall/mul" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1160}
  %constant.32 = f32[] constant(1), metadata={op_type="Sub" op_name="Adam/StatefulPartitionedCall/sub_1" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1160}
  %constant.30 = f32[] constant(0.9), metadata={op_type="Pow" op_name="Adam/StatefulPartitionedCall/Pow" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1160}
  %power.31 = f32[] power(f32[] %constant.30, f32[] %convert.29), metadata={op_type="Pow" op_name="Adam/StatefulPartitionedCall/Pow" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1160}
  %subtract.33 = f32[] subtract(f32[] %constant.32, f32[] %power.31), metadata={op_type="Sub" op_name="Adam/StatefulPartitionedCall/sub_1" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1160}
  %divide.40 = f32[] divide(f32[] %multiply.39, f32[] %subtract.33), metadata={op_type="RealDiv" op_name="Adam/StatefulPartitionedCall/truediv" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1160}
  %broadcast.147 = f32[784,10]{1,0} broadcast(f32[] %divide.40), dimensions={}, metadata={op_type="Mul" op_name="Adam/StatefulPartitionedCall/mul_3" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1160}
  %multiply.148 = f32[784,10]{1,0} multiply(f32[784,10]{1,0} %add.146, f32[784,10]{1,0} %broadcast.147), metadata={op_type="Mul" op_name="Adam/StatefulPartitionedCall/mul_3" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1160}
  %arg7.8 = f32[784,10]{1,0} parameter(7), parameter_replication={false}, metadata={op_name="XLA_Args"}
  %multiply.132 = f32[784,10]{1,0} multiply(f32[784,10]{1,0} %transpose.131, f32[784,10]{1,0} %transpose.131), metadata={op_type="Square" op_name="Adam/StatefulPartitionedCall/Square" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1160}
  %subtract.133 = f32[784,10]{1,0} subtract(f32[784,10]{1,0} %multiply.132, f32[784,10]{1,0} %arg7.8), metadata={op_type="Sub" op_name="Adam/StatefulPartitionedCall/sub_3" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1160}
  %constant.134 = f32[] constant(0.001), metadata={op_type="Mul" op_name="Adam/StatefulPartitionedCall/mul_2" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1160}
  %broadcast.135 = f32[784,10]{1,0} broadcast(f32[] %constant.134), dimensions={}, metadata={op_type="Mul" op_name="Adam/StatefulPartitionedCall/mul_2" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1160}
  %multiply.136 = f32[784,10]{1,0} multiply(f32[784,10]{1,0} %subtract.133, f32[784,10]{1,0} %broadcast.135), metadata={op_type="Mul" op_name="Adam/StatefulPartitionedCall/mul_2" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1160}
  %add.137 = f32[784,10]{1,0} add(f32[784,10]{1,0} %arg7.8, f32[784,10]{1,0} %multiply.136), metadata={op_type="AssignAddVariableOp" op_name="Adam/StatefulPartitionedCall/AssignAddVariableOp_1" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1160}
  %sqrt.138 = f32[784,10]{1,0} sqrt(f32[784,10]{1,0} %add.137), metadata={op_type="Sqrt" op_name="Sqrt_1;Adam/StatefulPartitionedCall/Sqrt_1"}
  %constant.139 = f32[] constant(1e-07), metadata={op_type="AddV2" op_name="Adam/StatefulPartitionedCall/add_1" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1160}
  %broadcast.140 = f32[784,10]{1,0} broadcast(f32[] %constant.139), dimensions={}, metadata={op_type="AddV2" op_name="Adam/StatefulPartitionedCall/add_1" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1160}
  %add.141 = f32[784,10]{1,0} add(f32[784,10]{1,0} %sqrt.138, f32[784,10]{1,0} %broadcast.140), metadata={op_type="AddV2" op_name="Adam/StatefulPartitionedCall/add_1" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1160}
  %divide.149 = f32[784,10]{1,0} divide(f32[784,10]{1,0} %multiply.148, f32[784,10]{1,0} %add.141), metadata={op_type="RealDiv" op_name="Adam/StatefulPartitionedCall/truediv_1" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1160}
  %subtract.150 = f32[784,10]{1,0} subtract(f32[784,10]{1,0} %arg2.3, f32[784,10]{1,0} %divide.149), metadata={op_type="AssignSubVariableOp" op_name="Adam/StatefulPartitionedCall/AssignSubVariableOp" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1160}
  %reshape.172 = f32[784,10]{1,0} reshape(f32[784,10]{1,0} %subtract.150), metadata={op_name="XLA_Retvals"}
  %copy.173 = f32[784,10]{1,0} copy(f32[784,10]{1,0} %reshape.172), metadata={op_name="XLA_Retvals"}
  %arg8.9 = f32[10]{0} parameter(8), parameter_replication={false}, metadata={op_name="XLA_Args"}
  %constant.124 = f32[] constant(-0), metadata={op_type="BiasAddGrad" op_name="gradient_tape/dense/BiasAdd/BiasAddGrad" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1160}
  %reduce.129 = f32[10]{0} reduce(f32[10000,10]{1,0} %multiply.123, f32[] %constant.124), dimensions={0}, to_apply=%region_0.125, metadata={op_type="BiasAddGrad" op_name="gradient_tape/dense/BiasAdd/BiasAddGrad"}
  %subtract.161 = f32[10]{0} subtract(f32[10]{0} %reduce.129, f32[10]{0} %arg8.9), metadata={op_type="Sub" op_name="Adam/StatefulPartitionedCall_1/sub_2" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1160}
  %constant.162 = f32[] constant(0.1), metadata={op_type="Mul" op_name="Adam/StatefulPartitionedCall_1/mul_1" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1160}
  %broadcast.163 = f32[10]{0} broadcast(f32[] %constant.162), dimensions={}, metadata={op_type="Mul" op_name="Adam/StatefulPartitionedCall_1/mul_1" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1160}
  %multiply.164 = f32[10]{0} multiply(f32[10]{0} %subtract.161, f32[10]{0} %broadcast.163), metadata={op_type="Mul" op_name="Adam/StatefulPartitionedCall_1/mul_1" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1160}
  %add.165 = f32[10]{0} add(f32[10]{0} %arg8.9, f32[10]{0} %multiply.164), metadata={op_type="AssignAddVariableOp" op_name="Adam/StatefulPartitionedCall_1/AssignAddVariableOp" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1160}
  %constant.22 = f32[] constant(1), metadata={op_type="Sub" op_name="Adam/StatefulPartitionedCall_1/sub" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1160}
  %constant.20 = f32[] constant(0.999), metadata={op_type="Pow" op_name="Adam/StatefulPartitionedCall_1/Pow_1" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1160}
  %constant.13 = s64[] constant(1), metadata={op_type="AddV2" op_name="Adam/StatefulPartitionedCall_1/add" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1160}
  %add.14 = s64[] add(s64[] %arg4.5, s64[] %constant.13), metadata={op_type="AddV2" op_name="Adam/StatefulPartitionedCall_1/add" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1160}
  %convert.15 = f32[] convert(s64[] %add.14), metadata={op_type="Cast" op_name="Adam/StatefulPartitionedCall_1/Cast" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1160}
  %power.21 = f32[] power(f32[] %constant.20, f32[] %convert.15), metadata={op_type="Pow" op_name="Adam/StatefulPartitionedCall_1/Pow_1" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1160}
  %subtract.23 = f32[] subtract(f32[] %constant.22, f32[] %power.21), metadata={op_type="Sub" op_name="Adam/StatefulPartitionedCall_1/sub" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1160}
  %sqrt.24 = f32[] sqrt(f32[] %subtract.23), metadata={op_type="Sqrt" op_name="Sqrt;Adam/StatefulPartitionedCall_1/Sqrt"}
  %multiply.25 = f32[] multiply(f32[] %arg5.6, f32[] %sqrt.24), metadata={op_type="Mul" op_name="Adam/StatefulPartitionedCall_1/mul" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1160}
  %constant.18 = f32[] constant(1), metadata={op_type="Sub" op_name="Adam/StatefulPartitionedCall_1/sub_1" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1160}
  %constant.16 = f32[] constant(0.9), metadata={op_type="Pow" op_name="Adam/StatefulPartitionedCall_1/Pow" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1160}
  %power.17 = f32[] power(f32[] %constant.16, f32[] %convert.15), metadata={op_type="Pow" op_name="Adam/StatefulPartitionedCall_1/Pow" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1160}
  %subtract.19 = f32[] subtract(f32[] %constant.18, f32[] %power.17), metadata={op_type="Sub" op_name="Adam/StatefulPartitionedCall_1/sub_1" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1160}
  %divide.26 = f32[] divide(f32[] %multiply.25, f32[] %subtract.19), metadata={op_type="RealDiv" op_name="Adam/StatefulPartitionedCall_1/truediv" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1160}
  %broadcast.166 = f32[10]{0} broadcast(f32[] %divide.26), dimensions={}, metadata={op_type="Mul" op_name="Adam/StatefulPartitionedCall_1/mul_3" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1160}
  %multiply.167 = f32[10]{0} multiply(f32[10]{0} %add.165, f32[10]{0} %broadcast.166), metadata={op_type="Mul" op_name="Adam/StatefulPartitionedCall_1/mul_3" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1160}
  %arg9.10 = f32[10]{0} parameter(9), parameter_replication={false}, metadata={op_name="XLA_Args"}
  %multiply.151 = f32[10]{0} multiply(f32[10]{0} %reduce.129, f32[10]{0} %reduce.129), metadata={op_type="Square" op_name="Adam/StatefulPartitionedCall_1/Square" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1160}
  %subtract.152 = f32[10]{0} subtract(f32[10]{0} %multiply.151, f32[10]{0} %arg9.10), metadata={op_type="Sub" op_name="Adam/StatefulPartitionedCall_1/sub_3" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1160}
  %constant.153 = f32[] constant(0.001), metadata={op_type="Mul" op_name="Adam/StatefulPartitionedCall_1/mul_2" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1160}
  %broadcast.154 = f32[10]{0} broadcast(f32[] %constant.153), dimensions={}, metadata={op_type="Mul" op_name="Adam/StatefulPartitionedCall_1/mul_2" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1160}
  %multiply.155 = f32[10]{0} multiply(f32[10]{0} %subtract.152, f32[10]{0} %broadcast.154), metadata={op_type="Mul" op_name="Adam/StatefulPartitionedCall_1/mul_2" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1160}
  %add.156 = f32[10]{0} add(f32[10]{0} %arg9.10, f32[10]{0} %multiply.155), metadata={op_type="AssignAddVariableOp" op_name="Adam/StatefulPartitionedCall_1/AssignAddVariableOp_1" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1160}
  %sqrt.157 = f32[10]{0} sqrt(f32[10]{0} %add.156), metadata={op_type="Sqrt" op_name="Sqrt_1;Adam/StatefulPartitionedCall_1/Sqrt_1"}
  %constant.158 = f32[] constant(1e-07), metadata={op_type="AddV2" op_name="Adam/StatefulPartitionedCall_1/add_1" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1160}
  %broadcast.159 = f32[10]{0} broadcast(f32[] %constant.158), dimensions={}, metadata={op_type="AddV2" op_name="Adam/StatefulPartitionedCall_1/add_1" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1160}
  %add.160 = f32[10]{0} add(f32[10]{0} %sqrt.157, f32[10]{0} %broadcast.159), metadata={op_type="AddV2" op_name="Adam/StatefulPartitionedCall_1/add_1" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1160}
  %divide.168 = f32[10]{0} divide(f32[10]{0} %multiply.167, f32[10]{0} %add.160), metadata={op_type="RealDiv" op_name="Adam/StatefulPartitionedCall_1/truediv_1" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1160}
  %subtract.169 = f32[10]{0} subtract(f32[10]{0} %arg3.4, f32[10]{0} %divide.168), metadata={op_type="AssignSubVariableOp" op_name="Adam/StatefulPartitionedCall_1/AssignSubVariableOp" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1160}
  %reshape.174 = f32[10]{0} reshape(f32[10]{0} %subtract.169), metadata={op_name="XLA_Retvals"}
  %copy.175 = f32[10]{0} copy(f32[10]{0} %reshape.174), metadata={op_name="XLA_Retvals"}
  %constant.170 = s64[] constant(1), metadata={op_type="AssignAddVariableOp" op_name="Adam/AssignAddVariableOp" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1160}
  %add.171 = s64[] add(s64[] %arg4.5, s64[] %constant.170), metadata={op_type="AssignAddVariableOp" op_name="Adam/AssignAddVariableOp" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1160}
  %reshape.176 = s64[] reshape(s64[] %add.171), metadata={op_name="XLA_Retvals"}
  %copy.177 = s64[] copy(s64[] %reshape.176), metadata={op_name="XLA_Retvals"}
  %reshape.178 = f32[784,10]{1,0} reshape(f32[784,10]{1,0} %add.146), metadata={op_name="XLA_Retvals"}
  %copy.179 = f32[784,10]{1,0} copy(f32[784,10]{1,0} %reshape.178), metadata={op_name="XLA_Retvals"}
  %reshape.180 = f32[784,10]{1,0} reshape(f32[784,10]{1,0} %add.137), metadata={op_name="XLA_Retvals"}
  %copy.181 = f32[784,10]{1,0} copy(f32[784,10]{1,0} %reshape.180), metadata={op_name="XLA_Retvals"}
  %reshape.182 = f32[10]{0} reshape(f32[10]{0} %add.165), metadata={op_name="XLA_Retvals"}
  %copy.183 = f32[10]{0} copy(f32[10]{0} %reshape.182), metadata={op_name="XLA_Retvals"}
  %reshape.184 = f32[10]{0} reshape(f32[10]{0} %add.156), metadata={op_name="XLA_Retvals"}
  %copy.185 = f32[10]{0} copy(f32[10]{0} %reshape.184), metadata={op_name="XLA_Retvals"}
  ROOT %tuple.186 = (f32[784,10]{1,0}, f32[10]{0}, s64[], f32[784,10]{1,0}, f32[784,10]{1,0}, /*index=5*/f32[10]{0}, f32[10]{0}) tuple(f32[784,10]{1,0} %copy.173, f32[10]{0} %copy.175, s64[] %copy.177, f32[784,10]{1,0} %copy.179, f32[784,10]{1,0} %copy.181, /*index=5*/f32[10]{0} %copy.183, f32[10]{0} %copy.185), metadata={op_name="XLA_Retvals"}
}
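Read bottom-up, the tail of this dump is one Adam update applied twice: once to the `f32[784,10]` kernel and once to the `f32[10]` bias. The sketch below restates the bias branch (`%add.14` through `%subtract.169`) as scalar Python so the arithmetic is easier to follow. It is an illustrative reading of the HLO, not the Keras implementation. The function name `adam_step` and its argument names are hypothetical. Treating `%arg5.6` as the learning rate and `%arg4.5` as the optimizer's iteration counter is an assumption based on how they are used in the dump.

```python
import math

def adam_step(param, grad, m, v, iterations, learning_rate,
              beta_1=0.9, beta_2=0.999, epsilon=1e-07):
    """One scalar Adam update, mirroring the bias branch of the HLO dump.

    Hypothetical helper for illustration only. `learning_rate` stands in for
    %arg5.6 and `iterations` for %arg4.5 (both assumptions).
    """
    # %add.14, %convert.15: step count as a float, t = iterations + 1
    t = float(iterations + 1)
    # %power.21 .. %divide.26: bias-corrected step size
    alpha = learning_rate * math.sqrt(1.0 - beta_2 ** t) / (1.0 - beta_1 ** t)
    # %subtract.161 .. %add.165: first moment, m += (g - m) * 0.1
    m = m + (grad - m) * (1.0 - beta_1)
    # %multiply.151 .. %add.156: second moment, v += (g*g - v) * 0.001
    v = v + (grad * grad - v) * (1.0 - beta_2)
    # %multiply.167 .. %subtract.169: param -= m * alpha / (sqrt(v) + 1e-07)
    param = param - (m * alpha) / (math.sqrt(v) + epsilon)
    return param, m, v

# Usage: a first step from zero-initialized moments.
new_param, new_m, new_v = adam_step(
    param=1.0, grad=0.5, m=0.0, v=0.0, iterations=0, learning_rate=0.001)
print(new_param, new_m, new_v)
```

The `ROOT` tuple returns the same quantities for both variables: the updated kernel and bias, the incremented iteration counter, and the new first and second moments for each. Because XLA compiles the whole function as one computation, these updates can be fused with the gradient computation instead of running as separate TensorFlow ops.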