This tutorial trains a TensorFlow model to classify the MNIST dataset, where the training function is compiled using XLA.
First, load TensorFlow. Eager execution is enabled by default in TensorFlow 2, so no extra step is needed.
import tensorflow as tf
Then define some necessary constants and prepare the MNIST dataset.
# Size of each input image, 28 x 28 pixels
IMAGE_SIZE = 28 * 28
# Number of distinct number labels, [0..9]
NUM_CLASSES = 10
# Number of examples in each training batch (step)
TRAIN_BATCH_SIZE = 100
# Number of training steps to run
TRAIN_STEPS = 1000
# Loads MNIST dataset.
train, test = tf.keras.datasets.mnist.load_data()
train_ds = tf.data.Dataset.from_tensor_slices(train).batch(TRAIN_BATCH_SIZE).repeat()
# Casting from raw data to the required datatypes.
def cast(images, labels):
  images = tf.cast(
      tf.reshape(images, [-1, IMAGE_SIZE]), tf.float32)
  labels = tf.cast(labels, tf.int64)
  return (images, labels)
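The .repeat() call matters because the training loop later asks for more batches than one pass over the data provides. A quick arithmetic check, assuming the standard MNIST training split of 60,000 images (NUM_TRAIN_EXAMPLES is a constant introduced only for this illustration, not something the tutorial defines):

```python
# Illustrative arithmetic only. NUM_TRAIN_EXAMPLES is an assumed constant
# (the standard MNIST training split); the other two mirror the tutorial.
NUM_TRAIN_EXAMPLES = 60_000
TRAIN_BATCH_SIZE = 100
TRAIN_STEPS = 1000

# Batches available in a single pass over the training data.
batches_per_epoch = NUM_TRAIN_EXAMPLES // TRAIN_BATCH_SIZE
# Passes over the data needed to supply TRAIN_STEPS batches.
epochs_needed = TRAIN_STEPS / batches_per_epoch

print(batches_per_epoch)       # 600
print(f"{epochs_needed:.2f}")  # 1.67
```

Since 1000 steps need more than the 600 batches of a single epoch, the dataset would run out without .repeat().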
Finally, define the model and the optimizer. The model uses a single dense layer.
layer = tf.keras.layers.Dense(NUM_CLASSES)
optimizer = tf.keras.optimizers.Adam()
Define the training function
In the training function, you get the predicted labels using the layer defined above, and then minimize the loss by applying its gradients with the optimizer. To compile the computation using XLA, place it inside tf.function with jit_compile=True.
@tf.function(jit_compile=True)
def train_mnist(images, labels):
  images, labels = cast(images, labels)
  with tf.GradientTape() as tape:
    predicted_labels = layer(images)
    loss = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(
        logits=predicted_labels, labels=labels
    ))
  layer_variables = layer.trainable_variables
  grads = tape.gradient(loss, layer_variables)
  optimizer.apply_gradients(zip(grads, layer_variables))
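To make the loss in train_mnist concrete, here is a plain-Python sketch of what sparse softmax cross-entropy computes for a single example. It illustrates the math only and is not TensorFlow's implementation; the helper name sparse_softmax_cross_entropy is hypothetical.

```python
import math

def sparse_softmax_cross_entropy(logits, label):
    """Cross-entropy of softmax(logits) against an integer class label."""
    # Subtract the max logit first so exp() cannot overflow.
    max_logit = max(logits)
    shifted = [x - max_logit for x in logits]
    log_sum_exp = math.log(sum(math.exp(x) for x in shifted))
    # Negative log-probability of the true class.
    return -(shifted[label] - log_sum_exp)

# With all-zero logits, each of the 10 classes gets probability 1/10,
# so the loss equals log(10).
uniform_loss = sparse_softmax_cross_entropy([0.0] * 10, label=3)
print(math.isclose(uniform_loss, math.log(10)))  # True
```

In the training function, tf.reduce_mean then averages this per-example loss over the batch. The same sequence of steps (max-reduce, subtract, exponentiate, sum, log) is visible as reduce, subtract, exponential, and log instructions in the HLO dump printed at the end of this tutorial.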
Train and test the model
Once you have defined the training function, train the model.
for images, labels in train_ds:
  if optimizer.iterations > TRAIN_STEPS:
    break
  train_mnist(images, labels)
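It is worth noting how the stopping condition counts steps. The sketch below mimics the loop with a plain integer standing in for optimizer.iterations, assuming the counter starts at 0 and increases by one on every apply_gradients call (count_steps is a hypothetical helper used only for illustration):

```python
def count_steps(train_steps):
    """Count loop bodies executed under the tutorial's `>` stop condition."""
    iterations = 0  # stand-in for optimizer.iterations
    steps_run = 0
    while True:  # stand-in for iterating over the repeated dataset
        if iterations > train_steps:
            break
        iterations += 1  # one apply_gradients call per training step
        steps_run += 1
    return steps_run

print(count_steps(1000))  # 1001
```

Because the check uses > rather than >=, the loop runs one step more than TRAIN_STEPS. That is harmless for this tutorial, but it is something to keep in mind when an exact step count matters.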
And, finally, check the accuracy:
images, labels = cast(test[0], test[1])
predicted_labels = layer(images)
correct_prediction = tf.equal(tf.argmax(predicted_labels, 1), labels)
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
print("Prediction accuracy after training: %s" % accuracy)
Prediction accuracy after training: tf.Tensor(0.8818, shape=(), dtype=float32)
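The accuracy computation above takes the index of the largest logit as the predicted class and averages the matches. A plain-Python sketch of the same idea on made-up logits (the accuracy_from_logits helper and the sample values are illustrative only):

```python
def accuracy_from_logits(logits_batch, labels):
    """Fraction of rows whose largest logit sits at the labeled index."""
    correct = 0
    for logits, label in zip(logits_batch, labels):
        # argmax: index of the largest logit in this row
        predicted = max(range(len(logits)), key=logits.__getitem__)
        correct += int(predicted == label)
    return correct / len(labels)

# Four made-up examples over three classes; only the third is misclassified.
sample_logits = [
    [2.0, 0.1, -1.0],
    [0.0, 3.0, 1.0],
    [0.5, 0.2, 0.1],
    [0.0, 0.0, 5.0],
]
sample_labels = [0, 1, 2, 2]
print(accuracy_from_logits(sample_logits, sample_labels))  # 0.75
```

No softmax is needed here: softmax preserves the ordering of the logits, so the argmax of the raw logits is already the predicted class.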
Behind the scenes, the XLA compiler has compiled the entire TF function to HLO, which has enabled fusion optimizations. Using the introspection facilities, we can see the HLO code (other interesting possible values for "stage" are optimized_hlo for HLO after optimizations and optimized_hlo_dot for a Graphviz graph):
print(train_mnist.experimental_get_compiler_ir(images, labels)(stage='hlo'))
HloModule a_inference_train_mnist_5553__.192, input_output_alias={ {0}: (2, {}, may-alias), {1}: (3, {}, may-alias), {2}: (5, {}, may-alias), {3}: (6, {}, may-alias), {4}: (7, {}, may-alias), {5}: (8, {}, may-alias), {6}: (9, {}, may-alias) }, entry_computation_layout={(f32[10000,784]{1,0}, s64[10000]{0}, f32[784,10]{1,0}, f32[10]{0}, f32[], /*index=5*/s64[], f32[784,10]{1,0}, f32[784,10]{1,0}, f32[10]{0}, f32[10]{0})->(f32[784,10]{1,0}, f32[10]{0}, s64[], f32[784,10]{1,0}, f32[784,10]{1,0}, /*index=5*/f32[10]{0}, f32[10]{0})} %max_float_.71 (x.72: f32[], y.73: f32[]) -> f32[] { %x.72 = f32[] parameter(0) %y.73 = f32[] parameter(1) ROOT %maximum.74 = f32[] maximum(f32[] %x.72, f32[] %y.73) } %add_float_.81 (x.82: f32[], y.83: f32[]) -> f32[] { %x.82 = f32[] parameter(0) %y.83 = f32[] parameter(1) ROOT %add.84 = f32[] add(f32[] %x.82, f32[] %y.83) } %add_float_.100 (x.101: f32[], y.102: f32[]) -> f32[] { %x.101 = f32[] parameter(0) %y.102 = f32[] parameter(1) ROOT %add.103 = f32[] add(f32[] %x.101, f32[] %y.102) } %Mean-reduction.112 (x.113: f32[], y.114: f32[]) -> f32[] { %x.113 = f32[] parameter(0) %y.114 = f32[] parameter(1) ROOT %add.115 = f32[] add(f32[] %x.113, f32[] %y.114) } %gradient_tape_dense_1_Add_Sum-reduction.129 (x.130: f32[], y.131: f32[]) -> f32[] { %x.130 = f32[] parameter(0) %y.131 = f32[] parameter(1) ROOT %add.132 = f32[] add(f32[] %x.130, f32[] %y.131) } ENTRY %a_inference_train_mnist_5553__.192 (arg0.1: f32[10000,784], arg1.2: s64[10000], arg2.3: f32[784,10], arg3.4: f32[10], arg4.5: f32[], arg5.6: s64[], arg6.7: f32[784,10], arg7.8: f32[784,10], arg8.9: f32[10], arg9.10: f32[10]) -> (f32[784,10], f32[10], s64[], f32[784,10], f32[784,10], /*index=5*/f32[10], f32[10]) { %arg1.2 = s64[10000]{0} parameter(1), parameter_replication={false}, metadata={op_name="XLA_Args"} %reshape.12 = s64[10000]{0} reshape(s64[10000]{0} %arg1.2) %broadcast.51 = s64[10000,10]{1,0} broadcast(s64[10000]{0} %reshape.12), dimensions={0}, 
metadata={op_type="SparseSoftmaxCrossEntropyWithLogits" op_name="SparseSoftmaxCrossEntropyWithLogits/SparseSoftmaxCrossEntropyWithLogits" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1177} %iota.50 = s64[10000,10]{1,0} iota(), iota_dimension=1, metadata={op_type="SparseSoftmaxCrossEntropyWithLogits" op_name="SparseSoftmaxCrossEntropyWithLogits/SparseSoftmaxCrossEntropyWithLogits" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1177} %compare.52 = pred[10000,10]{1,0} compare(s64[10000,10]{1,0} %broadcast.51, s64[10000,10]{1,0} %iota.50), direction=EQ, metadata={op_type="SparseSoftmaxCrossEntropyWithLogits" op_name="SparseSoftmaxCrossEntropyWithLogits/SparseSoftmaxCrossEntropyWithLogits" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1177} %constant.48 = f32[] constant(1), metadata={op_type="SparseSoftmaxCrossEntropyWithLogits" op_name="SparseSoftmaxCrossEntropyWithLogits/SparseSoftmaxCrossEntropyWithLogits" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1177} %broadcast.53 = f32[10000,10]{1,0} broadcast(f32[] %constant.48), dimensions={}, metadata={op_type="SparseSoftmaxCrossEntropyWithLogits" op_name="SparseSoftmaxCrossEntropyWithLogits/SparseSoftmaxCrossEntropyWithLogits" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1177} %constant.49 = f32[] constant(0), metadata={op_type="SparseSoftmaxCrossEntropyWithLogits" op_name="SparseSoftmaxCrossEntropyWithLogits/SparseSoftmaxCrossEntropyWithLogits" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1177} %broadcast.54 = f32[10000,10]{1,0} broadcast(f32[] %constant.49), dimensions={}, metadata={op_type="SparseSoftmaxCrossEntropyWithLogits" 
op_name="SparseSoftmaxCrossEntropyWithLogits/SparseSoftmaxCrossEntropyWithLogits" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1177} %select.55 = f32[10000,10]{1,0} select(pred[10000,10]{1,0} %compare.52, f32[10000,10]{1,0} %broadcast.53, f32[10000,10]{1,0} %broadcast.54), metadata={op_type="SparseSoftmaxCrossEntropyWithLogits" op_name="SparseSoftmaxCrossEntropyWithLogits/SparseSoftmaxCrossEntropyWithLogits" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1177} %constant.56 = s64[] constant(0), metadata={op_type="SparseSoftmaxCrossEntropyWithLogits" op_name="SparseSoftmaxCrossEntropyWithLogits/SparseSoftmaxCrossEntropyWithLogits" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1177} %broadcast.57 = s64[10000]{0} broadcast(s64[] %constant.56), dimensions={}, metadata={op_type="SparseSoftmaxCrossEntropyWithLogits" op_name="SparseSoftmaxCrossEntropyWithLogits/SparseSoftmaxCrossEntropyWithLogits" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1177} %compare.58 = pred[10000]{0} compare(s64[10000]{0} %broadcast.57, s64[10000]{0} %reshape.12), direction=LE, metadata={op_type="SparseSoftmaxCrossEntropyWithLogits" op_name="SparseSoftmaxCrossEntropyWithLogits/SparseSoftmaxCrossEntropyWithLogits" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1177} %constant.59 = s64[] constant(10), metadata={op_type="SparseSoftmaxCrossEntropyWithLogits" op_name="SparseSoftmaxCrossEntropyWithLogits/SparseSoftmaxCrossEntropyWithLogits" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1177} %broadcast.60 = s64[10000]{0} broadcast(s64[] %constant.59), dimensions={}, 
metadata={op_type="SparseSoftmaxCrossEntropyWithLogits" op_name="SparseSoftmaxCrossEntropyWithLogits/SparseSoftmaxCrossEntropyWithLogits" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1177} %compare.61 = pred[10000]{0} compare(s64[10000]{0} %reshape.12, s64[10000]{0} %broadcast.60), direction=LT, metadata={op_type="SparseSoftmaxCrossEntropyWithLogits" op_name="SparseSoftmaxCrossEntropyWithLogits/SparseSoftmaxCrossEntropyWithLogits" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1177} %and.62 = pred[10000]{0} and(pred[10000]{0} %compare.58, pred[10000]{0} %compare.61), metadata={op_type="SparseSoftmaxCrossEntropyWithLogits" op_name="SparseSoftmaxCrossEntropyWithLogits/SparseSoftmaxCrossEntropyWithLogits" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1177} %constant.63 = f32[] constant(0), metadata={op_type="SparseSoftmaxCrossEntropyWithLogits" op_name="SparseSoftmaxCrossEntropyWithLogits/SparseSoftmaxCrossEntropyWithLogits" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1177} %broadcast.64 = f32[10000]{0} broadcast(f32[] %constant.63), dimensions={}, metadata={op_type="SparseSoftmaxCrossEntropyWithLogits" op_name="SparseSoftmaxCrossEntropyWithLogits/SparseSoftmaxCrossEntropyWithLogits" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1177} %constant.65 = f32[] constant(nan), metadata={op_type="SparseSoftmaxCrossEntropyWithLogits" op_name="SparseSoftmaxCrossEntropyWithLogits/SparseSoftmaxCrossEntropyWithLogits" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1177} %broadcast.66 = f32[10000]{0} broadcast(f32[] %constant.65), dimensions={}, 
metadata={op_type="SparseSoftmaxCrossEntropyWithLogits" op_name="SparseSoftmaxCrossEntropyWithLogits/SparseSoftmaxCrossEntropyWithLogits" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1177} %select.67 = f32[10000]{0} select(pred[10000]{0} %and.62, f32[10000]{0} %broadcast.64, f32[10000]{0} %broadcast.66), metadata={op_type="SparseSoftmaxCrossEntropyWithLogits" op_name="SparseSoftmaxCrossEntropyWithLogits/SparseSoftmaxCrossEntropyWithLogits" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1177} %broadcast.68 = f32[10000,10]{1,0} broadcast(f32[10000]{0} %select.67), dimensions={0}, metadata={op_type="SparseSoftmaxCrossEntropyWithLogits" op_name="SparseSoftmaxCrossEntropyWithLogits/SparseSoftmaxCrossEntropyWithLogits" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1177} %add.69 = f32[10000,10]{1,0} add(f32[10000,10]{1,0} %select.55, f32[10000,10]{1,0} %broadcast.68), metadata={op_type="SparseSoftmaxCrossEntropyWithLogits" op_name="SparseSoftmaxCrossEntropyWithLogits/SparseSoftmaxCrossEntropyWithLogits" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1177} %negate.96 = f32[10000,10]{1,0} negate(f32[10000,10]{1,0} %add.69), metadata={op_type="SparseSoftmaxCrossEntropyWithLogits" op_name="SparseSoftmaxCrossEntropyWithLogits/SparseSoftmaxCrossEntropyWithLogits" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1177} %constant.90 = f32[] constant(0), metadata={op_type="SparseSoftmaxCrossEntropyWithLogits" op_name="SparseSoftmaxCrossEntropyWithLogits/SparseSoftmaxCrossEntropyWithLogits" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1177} %broadcast.91 = f32[10000,10]{1,0} broadcast(f32[] 
%constant.90), dimensions={}, metadata={op_type="SparseSoftmaxCrossEntropyWithLogits" op_name="SparseSoftmaxCrossEntropyWithLogits/SparseSoftmaxCrossEntropyWithLogits" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1177} %compare.92 = pred[10000,10]{1,0} compare(f32[10000,10]{1,0} %add.69, f32[10000,10]{1,0} %broadcast.91), direction=EQ, metadata={op_type="SparseSoftmaxCrossEntropyWithLogits" op_name="SparseSoftmaxCrossEntropyWithLogits/SparseSoftmaxCrossEntropyWithLogits" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1177} %constant.93 = f32[] constant(0), metadata={op_type="SparseSoftmaxCrossEntropyWithLogits" op_name="SparseSoftmaxCrossEntropyWithLogits/SparseSoftmaxCrossEntropyWithLogits" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1177} %broadcast.94 = f32[10000,10]{1,0} broadcast(f32[] %constant.93), dimensions={}, metadata={op_type="SparseSoftmaxCrossEntropyWithLogits" op_name="SparseSoftmaxCrossEntropyWithLogits/SparseSoftmaxCrossEntropyWithLogits" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1177} %arg0.1 = f32[10000,784]{1,0} parameter(0), parameter_replication={false}, metadata={op_name="XLA_Args"} %reshape.11 = f32[10000,784]{1,0} reshape(f32[10000,784]{1,0} %arg0.1) %reshape.43 = f32[10000,784]{1,0} reshape(f32[10000,784]{1,0} %reshape.11), metadata={op_type="Reshape" op_name="Reshape" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1177} %arg2.3 = f32[784,10]{1,0} parameter(2), parameter_replication={false}, metadata={op_name="XLA_Args"} %dot.44 = f32[10000,10]{1,0} dot(f32[10000,784]{1,0} %reshape.43, f32[784,10]{1,0} %arg2.3), lhs_contracting_dims={1}, rhs_contracting_dims={0}, 
frontend_attributes={grad_x="false",grad_y="false"}, metadata={op_type="MatMul" op_name="dense_1/MatMul" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1177} %transpose.45 = f32[10000,10]{1,0} transpose(f32[10000,10]{1,0} %dot.44), dimensions={0,1}, metadata={op_type="MatMul" op_name="dense_1/MatMul" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1177} %arg3.4 = f32[10]{0} parameter(3), parameter_replication={false}, metadata={op_name="XLA_Args"} %broadcast.46 = f32[10000,10]{1,0} broadcast(f32[10]{0} %arg3.4), dimensions={1}, metadata={op_type="AddV2" op_name="dense_1/Add" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1177} %add.47 = f32[10000,10]{1,0} add(f32[10000,10]{1,0} %transpose.45, f32[10000,10]{1,0} %broadcast.46), metadata={op_type="AddV2" op_name="dense_1/Add" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1177} %constant.70 = f32[] constant(-inf), metadata={op_type="SparseSoftmaxCrossEntropyWithLogits" op_name="SparseSoftmaxCrossEntropyWithLogits/SparseSoftmaxCrossEntropyWithLogits" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1177} %reduce.75 = f32[10000]{0} reduce(f32[10000,10]{1,0} %add.47, f32[] %constant.70), dimensions={1}, to_apply=%max_float_.71, metadata={op_type="SparseSoftmaxCrossEntropyWithLogits" op_name="SparseSoftmaxCrossEntropyWithLogits/SparseSoftmaxCrossEntropyWithLogits" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1177} %broadcast.76 = f32[10000,10]{1,0} broadcast(f32[10000]{0} %reduce.75), dimensions={0}, metadata={op_type="SparseSoftmaxCrossEntropyWithLogits" op_name="SparseSoftmaxCrossEntropyWithLogits/SparseSoftmaxCrossEntropyWithLogits" 
source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1177} %subtract.77 = f32[10000,10]{1,0} subtract(f32[10000,10]{1,0} %add.47, f32[10000,10]{1,0} %broadcast.76), metadata={op_type="SparseSoftmaxCrossEntropyWithLogits" op_name="SparseSoftmaxCrossEntropyWithLogits/SparseSoftmaxCrossEntropyWithLogits" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1177} %exponential.78 = f32[10000,10]{1,0} exponential(f32[10000,10]{1,0} %subtract.77), metadata={op_type="SparseSoftmaxCrossEntropyWithLogits" op_name="SparseSoftmaxCrossEntropyWithLogits/SparseSoftmaxCrossEntropyWithLogits" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1177} %convert.79 = f32[10000,10]{1,0} convert(f32[10000,10]{1,0} %exponential.78), metadata={op_type="SparseSoftmaxCrossEntropyWithLogits" op_name="SparseSoftmaxCrossEntropyWithLogits/SparseSoftmaxCrossEntropyWithLogits" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1177} %constant.80 = f32[] constant(0), metadata={op_type="SparseSoftmaxCrossEntropyWithLogits" op_name="SparseSoftmaxCrossEntropyWithLogits/SparseSoftmaxCrossEntropyWithLogits" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1177} %reduce.85 = f32[10000]{0} reduce(f32[10000,10]{1,0} %convert.79, f32[] %constant.80), dimensions={1}, to_apply=%add_float_.81, metadata={op_type="SparseSoftmaxCrossEntropyWithLogits" op_name="SparseSoftmaxCrossEntropyWithLogits/SparseSoftmaxCrossEntropyWithLogits" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1177} %convert.86 = f32[10000]{0} convert(f32[10000]{0} %reduce.85), metadata={op_type="SparseSoftmaxCrossEntropyWithLogits" 
op_name="SparseSoftmaxCrossEntropyWithLogits/SparseSoftmaxCrossEntropyWithLogits" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1177} %log.87 = f32[10000]{0} log(f32[10000]{0} %convert.86), metadata={op_type="SparseSoftmaxCrossEntropyWithLogits" op_name="SparseSoftmaxCrossEntropyWithLogits/SparseSoftmaxCrossEntropyWithLogits" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1177} %broadcast.88 = f32[10000,10]{1,0} broadcast(f32[10000]{0} %log.87), dimensions={0}, metadata={op_type="SparseSoftmaxCrossEntropyWithLogits" op_name="SparseSoftmaxCrossEntropyWithLogits/SparseSoftmaxCrossEntropyWithLogits" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1177} %subtract.89 = f32[10000,10]{1,0} subtract(f32[10000,10]{1,0} %subtract.77, f32[10000,10]{1,0} %broadcast.88), metadata={op_type="SparseSoftmaxCrossEntropyWithLogits" op_name="SparseSoftmaxCrossEntropyWithLogits/SparseSoftmaxCrossEntropyWithLogits" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1177} %select.95 = f32[10000,10]{1,0} select(pred[10000,10]{1,0} %compare.92, f32[10000,10]{1,0} %broadcast.94, f32[10000,10]{1,0} %subtract.89), metadata={op_type="SparseSoftmaxCrossEntropyWithLogits" op_name="SparseSoftmaxCrossEntropyWithLogits/SparseSoftmaxCrossEntropyWithLogits" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1177} %multiply.97 = f32[10000,10]{1,0} multiply(f32[10000,10]{1,0} %negate.96, f32[10000,10]{1,0} %select.95), metadata={op_type="SparseSoftmaxCrossEntropyWithLogits" op_name="SparseSoftmaxCrossEntropyWithLogits/SparseSoftmaxCrossEntropyWithLogits" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1177} %convert.98 = 
f32[10000,10]{1,0} convert(f32[10000,10]{1,0} %multiply.97), metadata={op_type="SparseSoftmaxCrossEntropyWithLogits" op_name="SparseSoftmaxCrossEntropyWithLogits/SparseSoftmaxCrossEntropyWithLogits" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1177} %constant.99 = f32[] constant(0), metadata={op_type="SparseSoftmaxCrossEntropyWithLogits" op_name="SparseSoftmaxCrossEntropyWithLogits/SparseSoftmaxCrossEntropyWithLogits" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1177} %reduce.104 = f32[10000]{0} reduce(f32[10000,10]{1,0} %convert.98, f32[] %constant.99), dimensions={1}, to_apply=%add_float_.100, metadata={op_type="SparseSoftmaxCrossEntropyWithLogits" op_name="SparseSoftmaxCrossEntropyWithLogits/SparseSoftmaxCrossEntropyWithLogits" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1177} %convert.105 = f32[10000]{0} convert(f32[10000]{0} %reduce.104), metadata={op_type="SparseSoftmaxCrossEntropyWithLogits" op_name="SparseSoftmaxCrossEntropyWithLogits/SparseSoftmaxCrossEntropyWithLogits" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1177} %convert.109 = f32[10000]{0} convert(f32[10000]{0} %convert.105), metadata={op_type="Mean" op_name="Mean" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1177} %constant.110 = f32[] constant(0), metadata={op_type="Mean" op_name="Mean" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1177} %convert.111 = f32[] convert(f32[] %constant.110), metadata={op_type="Mean" op_name="Mean" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1177} %reduce.116 = f32[] reduce(f32[10000]{0} %convert.109, 
f32[] %convert.111), dimensions={0}, to_apply=%Mean-reduction.112, metadata={op_type="Mean" op_name="Mean" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1177} %constant.117 = s32[] constant(10000), metadata={op_type="Mean" op_name="Mean" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1177} %convert.118 = f32[] convert(s32[] %constant.117), metadata={op_type="Mean" op_name="Mean" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1177} %divide.119 = f32[] divide(f32[] %reduce.116, f32[] %convert.118), metadata={op_type="Mean" op_name="Mean" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1177} %convert.120 = f32[] convert(f32[] %divide.119), metadata={op_type="Mean" op_name="Mean" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1177} %arg6.7 = f32[784,10]{1,0} parameter(6), parameter_replication={false}, metadata={op_name="XLA_Args"} %constant.121 = f32[] constant(0.0001), metadata={op_type="Mul" op_name="gradient_tape/SparseSoftmaxCrossEntropyWithLogits/mul" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1177} %broadcast.122 = f32[10000,1]{1,0} broadcast(f32[] %constant.121), dimensions={}, metadata={op_type="Mul" op_name="gradient_tape/SparseSoftmaxCrossEntropyWithLogits/mul" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1177} %reshape.123 = f32[10000]{0} reshape(f32[10000,1]{1,0} %broadcast.122), metadata={op_type="Mul" op_name="gradient_tape/SparseSoftmaxCrossEntropyWithLogits/mul" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1177} %broadcast.124 = 
f32[10000,10]{1,0} broadcast(f32[10000]{0} %reshape.123), dimensions={0}, metadata={op_type="Mul" op_name="gradient_tape/SparseSoftmaxCrossEntropyWithLogits/mul" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1177} %broadcast.106 = f32[10000,10]{1,0} broadcast(f32[10000]{0} %convert.86), dimensions={0}, metadata={op_type="SparseSoftmaxCrossEntropyWithLogits" op_name="SparseSoftmaxCrossEntropyWithLogits/SparseSoftmaxCrossEntropyWithLogits" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1177} %divide.107 = f32[10000,10]{1,0} divide(f32[10000,10]{1,0} %exponential.78, f32[10000,10]{1,0} %broadcast.106), metadata={op_type="SparseSoftmaxCrossEntropyWithLogits" op_name="SparseSoftmaxCrossEntropyWithLogits/SparseSoftmaxCrossEntropyWithLogits" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1177} %subtract.108 = f32[10000,10]{1,0} subtract(f32[10000,10]{1,0} %divide.107, f32[10000,10]{1,0} %add.69), metadata={op_type="SparseSoftmaxCrossEntropyWithLogits" op_name="SparseSoftmaxCrossEntropyWithLogits/SparseSoftmaxCrossEntropyWithLogits" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1177} %multiply.125 = f32[10000,10]{1,0} multiply(f32[10000,10]{1,0} %broadcast.124, f32[10000,10]{1,0} %subtract.108), metadata={op_type="Mul" op_name="gradient_tape/SparseSoftmaxCrossEntropyWithLogits/mul" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1177} %dot.137 = f32[784,10]{1,0} dot(f32[10000,784]{1,0} %reshape.43, f32[10000,10]{1,0} %multiply.125), lhs_contracting_dims={0}, rhs_contracting_dims={0}, frontend_attributes={grad_x="false",grad_y="true"}, metadata={op_type="MatMul" op_name="gradient_tape/dense_1/MatMul/MatMul" 
source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1177} %transpose.138 = f32[784,10]{1,0} transpose(f32[784,10]{1,0} %dot.137), dimensions={0,1}, metadata={op_type="MatMul" op_name="gradient_tape/dense_1/MatMul/MatMul" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1177} %subtract.159 = f32[784,10]{1,0} subtract(f32[784,10]{1,0} %transpose.138, f32[784,10]{1,0} %arg6.7), metadata={op_type="Sub" op_name="adam/Sub_2" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1177} %constant.160 = f32[] constant(0.1), metadata={op_type="Mul" op_name="adam/Mul_1" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1177} %broadcast.161 = f32[784,10]{1,0} broadcast(f32[] %constant.160), dimensions={}, metadata={op_type="Mul" op_name="adam/Mul_1" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1177} %multiply.162 = f32[784,10]{1,0} multiply(f32[784,10]{1,0} %subtract.159, f32[784,10]{1,0} %broadcast.161), metadata={op_type="Mul" op_name="adam/Mul_1" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1177} %add.163 = f32[784,10]{1,0} add(f32[784,10]{1,0} %arg6.7, f32[784,10]{1,0} %multiply.162), metadata={op_type="AssignAddVariableOp" op_name="adam/AssignAddVariableOp" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1177} %arg4.5 = f32[] parameter(4), parameter_replication={false}, metadata={op_name="XLA_Args"} %constant.22 = f32[] constant(1), metadata={op_type="Sub" op_name="adam/sub" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1177} %constant.20 = f32[] constant(0.999), 
metadata={op_type="Pow" op_name="adam/Pow_1" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1177} %arg5.6 = s64[] parameter(5), parameter_replication={false}, metadata={op_name="XLA_Args"} %constant.13 = s64[] constant(1), metadata={op_type="AddV2" op_name="adam/add" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1177} %add.14 = s64[] add(s64[] %arg5.6, s64[] %constant.13), metadata={op_type="AddV2" op_name="adam/add" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1177} %convert.15 = f32[] convert(s64[] %add.14), metadata={op_type="Cast" op_name="adam/Cast_1" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1177} %power.21 = f32[] power(f32[] %constant.20, f32[] %convert.15), metadata={op_type="Pow" op_name="adam/Pow_1" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1177} %subtract.23 = f32[] subtract(f32[] %constant.22, f32[] %power.21), metadata={op_type="Sub" op_name="adam/sub" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1177} %sqrt.24 = f32[] sqrt(f32[] %subtract.23), metadata={op_type="Sqrt" op_name="adam/Sqrt"} %multiply.25 = f32[] multiply(f32[] %arg4.5, f32[] %sqrt.24), metadata={op_type="Mul" op_name="adam/mul" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1177} %constant.18 = f32[] constant(1), metadata={op_type="Sub" op_name="adam/sub_1" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1177} %constant.16 = f32[] constant(0.9), metadata={op_type="Pow" op_name="adam/Pow" 
source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1177} %power.17 = f32[] power(f32[] %constant.16, f32[] %convert.15), metadata={op_type="Pow" op_name="adam/Pow" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1177} %subtract.19 = f32[] subtract(f32[] %constant.18, f32[] %power.17), metadata={op_type="Sub" op_name="adam/sub_1" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1177} %divide.26 = f32[] divide(f32[] %multiply.25, f32[] %subtract.19), metadata={op_type="RealDiv" op_name="adam/truediv" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1177} %broadcast.164 = f32[784,10]{1,0} broadcast(f32[] %divide.26), dimensions={}, metadata={op_type="Mul" op_name="adam/Mul_3" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1177} %multiply.165 = f32[784,10]{1,0} multiply(f32[784,10]{1,0} %add.163, f32[784,10]{1,0} %broadcast.164), metadata={op_type="Mul" op_name="adam/Mul_3" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1177} %arg7.8 = f32[784,10]{1,0} parameter(7), parameter_replication={false}, metadata={op_name="XLA_Args"} %multiply.139 = f32[784,10]{1,0} multiply(f32[784,10]{1,0} %transpose.138, f32[784,10]{1,0} %transpose.138), metadata={op_type="Square" op_name="adam/Square" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1177} %subtract.140 = f32[784,10]{1,0} subtract(f32[784,10]{1,0} %multiply.139, f32[784,10]{1,0} %arg7.8), metadata={op_type="Sub" op_name="adam/Sub_3" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1177} %constant.141 = f32[] constant(0.001), 
metadata={op_type="Mul" op_name="adam/Mul_2" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1177} %broadcast.142 = f32[784,10]{1,0} broadcast(f32[] %constant.141), dimensions={}, metadata={op_type="Mul" op_name="adam/Mul_2" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1177} %multiply.143 = f32[784,10]{1,0} multiply(f32[784,10]{1,0} %subtract.140, f32[784,10]{1,0} %broadcast.142), metadata={op_type="Mul" op_name="adam/Mul_2" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1177} %add.144 = f32[784,10]{1,0} add(f32[784,10]{1,0} %arg7.8, f32[784,10]{1,0} %multiply.143), metadata={op_type="AssignAddVariableOp" op_name="adam/AssignAddVariableOp_1" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1177} %sqrt.145 = f32[784,10]{1,0} sqrt(f32[784,10]{1,0} %add.144), metadata={op_type="Sqrt" op_name="adam/Sqrt_1"} %constant.146 = f32[] constant(1e-07), metadata={op_type="AddV2" op_name="adam/Add_1" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1177} %broadcast.147 = f32[784,10]{1,0} broadcast(f32[] %constant.146), dimensions={}, metadata={op_type="AddV2" op_name="adam/Add_1" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1177} %add.148 = f32[784,10]{1,0} add(f32[784,10]{1,0} %sqrt.145, f32[784,10]{1,0} %broadcast.147), metadata={op_type="AddV2" op_name="adam/Add_1" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1177} %divide.166 = f32[784,10]{1,0} divide(f32[784,10]{1,0} %multiply.165, f32[784,10]{1,0} %add.148), metadata={op_type="RealDiv" op_name="adam/truediv_1" 
source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1177} %subtract.167 = f32[784,10]{1,0} subtract(f32[784,10]{1,0} %arg2.3, f32[784,10]{1,0} %divide.166), metadata={op_type="AssignSubVariableOp" op_name="adam/AssignSubVariableOp" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1177} %reshape.177 = f32[784,10]{1,0} reshape(f32[784,10]{1,0} %subtract.167), metadata={op_name="XLA_Retvals"} %copy.178 = f32[784,10]{1,0} copy(f32[784,10]{1,0} %reshape.177), metadata={op_name="XLA_Retvals"} %arg8.9 = f32[10]{0} parameter(8), parameter_replication={false}, metadata={op_name="XLA_Args"} %convert.126 = f32[10000,10]{1,0} convert(f32[10000,10]{1,0} %multiply.125), metadata={op_type="Sum" op_name="gradient_tape/dense_1/Add/Sum" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1177} %constant.127 = f32[] constant(0), metadata={op_type="Sum" op_name="gradient_tape/dense_1/Add/Sum" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1177} %convert.128 = f32[] convert(f32[] %constant.127), metadata={op_type="Sum" op_name="gradient_tape/dense_1/Add/Sum" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1177} %reduce.133 = f32[10]{0} reduce(f32[10000,10]{1,0} %convert.126, f32[] %convert.128), dimensions={0}, to_apply=%gradient_tape_dense_1_Add_Sum-reduction.129, metadata={op_type="Sum" op_name="gradient_tape/dense_1/Add/Sum" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1177} %convert.134 = f32[10]{0} convert(f32[10]{0} %reduce.133), metadata={op_type="Sum" op_name="gradient_tape/dense_1/Add/Sum" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1177} 
%reshape.135 = f32[1,10]{1,0} reshape(f32[10]{0} %convert.134), metadata={op_type="Sum" op_name="gradient_tape/dense_1/Add/Sum" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1177} %reshape.136 = f32[10]{0} reshape(f32[1,10]{1,0} %reshape.135), metadata={op_type="Reshape" op_name="gradient_tape/dense_1/Add/Reshape" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1177} %subtract.168 = f32[10]{0} subtract(f32[10]{0} %reshape.136, f32[10]{0} %arg8.9), metadata={op_type="Sub" op_name="adam/Sub_6" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1177} %constant.169 = f32[] constant(0.1), metadata={op_type="Mul" op_name="adam/Mul_5" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1177} %broadcast.170 = f32[10]{0} broadcast(f32[] %constant.169), dimensions={}, metadata={op_type="Mul" op_name="adam/Mul_5" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1177} %multiply.171 = f32[10]{0} multiply(f32[10]{0} %subtract.168, f32[10]{0} %broadcast.170), metadata={op_type="Mul" op_name="adam/Mul_5" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1177} %add.172 = f32[10]{0} add(f32[10]{0} %arg8.9, f32[10]{0} %multiply.171), metadata={op_type="AssignAddVariableOp" op_name="adam/AssignAddVariableOp_2" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1177} %constant.36 = f32[] constant(1), metadata={op_type="Sub" op_name="adam/sub_4" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1177} %constant.34 = f32[] constant(0.999), metadata={op_type="Pow" op_name="adam/Pow_3" 
source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1177} %constant.27 = s64[] constant(1), metadata={op_type="AddV2" op_name="adam/add_2" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1177} %add.28 = s64[] add(s64[] %arg5.6, s64[] %constant.27), metadata={op_type="AddV2" op_name="adam/add_2" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1177} %convert.29 = f32[] convert(s64[] %add.28), metadata={op_type="Cast" op_name="adam/Cast_6" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1177} %power.35 = f32[] power(f32[] %constant.34, f32[] %convert.29), metadata={op_type="Pow" op_name="adam/Pow_3" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1177} %subtract.37 = f32[] subtract(f32[] %constant.36, f32[] %power.35), metadata={op_type="Sub" op_name="adam/sub_4" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1177} %sqrt.38 = f32[] sqrt(f32[] %subtract.37), metadata={op_type="Sqrt" op_name="adam/Sqrt_2"} %multiply.39 = f32[] multiply(f32[] %arg4.5, f32[] %sqrt.38), metadata={op_type="Mul" op_name="adam/mul_4" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1177} %constant.32 = f32[] constant(1), metadata={op_type="Sub" op_name="adam/sub_5" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1177} %constant.30 = f32[] constant(0.9), metadata={op_type="Pow" op_name="adam/Pow_2" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1177} %power.31 = f32[] power(f32[] %constant.30, f32[] %convert.29), 
metadata={op_type="Pow" op_name="adam/Pow_2" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1177} %subtract.33 = f32[] subtract(f32[] %constant.32, f32[] %power.31), metadata={op_type="Sub" op_name="adam/sub_5" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1177} %divide.40 = f32[] divide(f32[] %multiply.39, f32[] %subtract.33), metadata={op_type="RealDiv" op_name="adam/truediv_2" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1177} %broadcast.173 = f32[10]{0} broadcast(f32[] %divide.40), dimensions={}, metadata={op_type="Mul" op_name="adam/Mul_7" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1177} %multiply.174 = f32[10]{0} multiply(f32[10]{0} %add.172, f32[10]{0} %broadcast.173), metadata={op_type="Mul" op_name="adam/Mul_7" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1177} %arg9.10 = f32[10]{0} parameter(9), parameter_replication={false}, metadata={op_name="XLA_Args"} %multiply.149 = f32[10]{0} multiply(f32[10]{0} %reshape.136, f32[10]{0} %reshape.136), metadata={op_type="Square" op_name="adam/Square_1" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1177} %subtract.150 = f32[10]{0} subtract(f32[10]{0} %multiply.149, f32[10]{0} %arg9.10), metadata={op_type="Sub" op_name="adam/Sub_7" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1177} %constant.151 = f32[] constant(0.001), metadata={op_type="Mul" op_name="adam/Mul_6" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1177} %broadcast.152 = f32[10]{0} broadcast(f32[] %constant.151), dimensions={}, 
metadata={op_type="Mul" op_name="adam/Mul_6" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1177} %multiply.153 = f32[10]{0} multiply(f32[10]{0} %subtract.150, f32[10]{0} %broadcast.152), metadata={op_type="Mul" op_name="adam/Mul_6" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1177} %add.154 = f32[10]{0} add(f32[10]{0} %arg9.10, f32[10]{0} %multiply.153), metadata={op_type="AssignAddVariableOp" op_name="adam/AssignAddVariableOp_3" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1177} %sqrt.155 = f32[10]{0} sqrt(f32[10]{0} %add.154), metadata={op_type="Sqrt" op_name="adam/Sqrt_3"} %constant.156 = f32[] constant(1e-07), metadata={op_type="AddV2" op_name="adam/Add_3" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1177} %broadcast.157 = f32[10]{0} broadcast(f32[] %constant.156), dimensions={}, metadata={op_type="AddV2" op_name="adam/Add_3" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1177} %add.158 = f32[10]{0} add(f32[10]{0} %sqrt.155, f32[10]{0} %broadcast.157), metadata={op_type="AddV2" op_name="adam/Add_3" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1177} %divide.175 = f32[10]{0} divide(f32[10]{0} %multiply.174, f32[10]{0} %add.158), metadata={op_type="RealDiv" op_name="adam/truediv_3" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1177} %subtract.176 = f32[10]{0} subtract(f32[10]{0} %arg3.4, f32[10]{0} %divide.175), metadata={op_type="AssignSubVariableOp" op_name="adam/AssignSubVariableOp_1" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1177} 
%reshape.179 = f32[10]{0} reshape(f32[10]{0} %subtract.176), metadata={op_name="XLA_Retvals"} %copy.180 = f32[10]{0} copy(f32[10]{0} %reshape.179), metadata={op_name="XLA_Retvals"} %constant.41 = s64[] constant(1), metadata={op_type="AddV2" op_name="adam/add_4" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1177} %add.42 = s64[] add(s64[] %arg5.6, s64[] %constant.41), metadata={op_type="AddV2" op_name="adam/add_4" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1177} %reshape.181 = s64[] reshape(s64[] %add.42), metadata={op_name="XLA_Retvals"} %copy.182 = s64[] copy(s64[] %reshape.181), metadata={op_name="XLA_Retvals"} %reshape.183 = f32[784,10]{1,0} reshape(f32[784,10]{1,0} %add.163), metadata={op_name="XLA_Retvals"} %copy.184 = f32[784,10]{1,0} copy(f32[784,10]{1,0} %reshape.183), metadata={op_name="XLA_Retvals"} %reshape.185 = f32[784,10]{1,0} reshape(f32[784,10]{1,0} %add.144), metadata={op_name="XLA_Retvals"} %copy.186 = f32[784,10]{1,0} copy(f32[784,10]{1,0} %reshape.185), metadata={op_name="XLA_Retvals"} %reshape.187 = f32[10]{0} reshape(f32[10]{0} %add.172), metadata={op_name="XLA_Retvals"} %copy.188 = f32[10]{0} copy(f32[10]{0} %reshape.187), metadata={op_name="XLA_Retvals"} %reshape.189 = f32[10]{0} reshape(f32[10]{0} %add.154), metadata={op_name="XLA_Retvals"} %copy.190 = f32[10]{0} copy(f32[10]{0} %reshape.189), metadata={op_name="XLA_Retvals"} ROOT %tuple.191 = (f32[784,10]{1,0}, f32[10]{0}, s64[], f32[784,10]{1,0}, f32[784,10]{1,0}, /*index=5*/f32[10]{0}, f32[10]{0}) tuple(f32[784,10]{1,0} %copy.178, f32[10]{0} %copy.180, s64[] %copy.182, f32[784,10]{1,0} %copy.184, f32[784,10]{1,0} %copy.186, /*index=5*/f32[10]{0} %copy.188, f32[10]{0} %copy.190), metadata={op_name="XLA_Retvals"} }
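The `adam/...` instructions in the dump above are easier to follow when written out as plain arithmetic. The following is a minimal sketch for a single scalar parameter, using only the Python standard library. The function name `adam_step` and its argument names are hypothetical. Reading `%arg4.5` as the learning rate and `%arg5.6` as the step counter is an assumption based on how they are used in the dump. The default `lr=0.001` is likewise an assumption, since its value does not appear in the HLO. The constants `0.1`, `0.001`, `0.9`, `0.999` and `1e-07` are taken from the dump.

```python
import math


def adam_step(param, grad, m, v, step,
              lr=0.001, beta_1=0.9, beta_2=0.999, eps=1e-7):
    """One Adam update for a scalar, in the same order as the HLO ops."""
    # adam/add, adam/Cast_1: increment the step counter.
    t = step + 1

    # adam/Sub_2, adam/Mul_1, adam/AssignAddVariableOp:
    # m <- m + (g - m) * 0.1
    m = m + (grad - m) * (1.0 - beta_1)

    # adam/Square, adam/Sub_3, adam/Mul_2, adam/AssignAddVariableOp_1:
    # v <- v + (g^2 - v) * 0.001
    v = v + (grad * grad - v) * (1.0 - beta_2)

    # adam/Pow_1, adam/sub, adam/Sqrt, adam/mul, adam/Pow, adam/sub_1,
    # adam/truediv: bias-corrected step size.
    alpha = lr * math.sqrt(1.0 - beta_2 ** t) / (1.0 - beta_1 ** t)

    # adam/Mul_3, adam/Sqrt_1, adam/Add_1, adam/truediv_1,
    # adam/AssignSubVariableOp: apply the update.
    param = param - (m * alpha) / (math.sqrt(v) + eps)

    return param, m, v, t


new_param, new_m, new_v, new_t = adam_step(
    param=1.0, grad=0.5, m=0.0, v=0.0, step=0)
print(new_param, new_m, new_v, new_t)
```

In the dump the same sequence appears twice, once for the `f32[784,10]` kernel and once for the `f32[10]` bias, and the updated values are returned together in the `ROOT` tuple.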