This tutorial trains a TensorFlow model to classify the MNIST dataset, where the training function is compiled using XLA.
First, import TensorFlow. Eager execution is enabled by default in TensorFlow 2.
import tensorflow as tf
2025-02-05 12:06:43.470310: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
E0000 00:00:1738757203.491631 10594 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1738757203.498157 10594 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
Then define some necessary constants and prepare the MNIST dataset.
# Size of each input image, 28 x 28 pixels
IMAGE_SIZE = 28 * 28
# Number of distinct number labels, [0..9]
NUM_CLASSES = 10
# Number of examples in each training batch (step)
TRAIN_BATCH_SIZE = 100
# Number of training steps to run
TRAIN_STEPS = 1000
# Loads MNIST dataset.
train, test = tf.keras.datasets.mnist.load_data()
train_ds = tf.data.Dataset.from_tensor_slices(train).batch(TRAIN_BATCH_SIZE).repeat()
# Casting from raw data to the required datatypes.
def cast(images, labels):
  images = tf.cast(
      tf.reshape(images, [-1, IMAGE_SIZE]), tf.float32)
  labels = tf.cast(labels, tf.int64)
  return (images, labels)
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz
11490434/11490434 ━━━━━━━━━━━━━━━━━━━━ 0s 0us/step
I0000 00:00:1738757208.209124 10594 gpu_device.cc:2022] Created device /job:localhost/replica:0/task:0/device:GPU:0 with 13680 MB memory: -> device: 0, name: Tesla T4, pci bus id: 0000:00:05.0, compute capability: 7.5
I0000 00:00:1738757208.211541 10594 gpu_device.cc:2022] Created device /job:localhost/replica:0/task:0/device:GPU:1 with 13756 MB memory: -> device: 1, name: Tesla T4, pci bus id: 0000:00:06.0, compute capability: 7.5
I0000 00:00:1738757208.213696 10594 gpu_device.cc:2022] Created device /job:localhost/replica:0/task:0/device:GPU:2 with 13756 MB memory: -> device: 2, name: Tesla T4, pci bus id: 0000:00:07.0, compute capability: 7.5
I0000 00:00:1738757208.215979 10594 gpu_device.cc:2022] Created device /job:localhost/replica:0/task:0/device:GPU:3 with 13756 MB memory: -> device: 3, name: Tesla T4, pci bus id: 0000:00:08.0, compute capability: 7.5
Finally, define the model and the optimizer. The model uses a single dense layer.
layer = tf.keras.layers.Dense(NUM_CLASSES)
optimizer = tf.keras.optimizers.Adam()
Define the training function
In the training function, you get the predicted labels using the layer defined above, compute the gradients of the loss, and let the optimizer apply them to minimize the loss. To compile the computation with XLA, place it inside tf.function with jit_compile=True.
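As a minimal sketch of what the flag does, the toy function below is run once eagerly and once through tf.function with jit_compile=True, and the two results are compared. The name scaled_sum and the input values are hypothetical and not part of this tutorial; the sketch assumes a TensorFlow 2 installation with XLA support for the local device.

```python
import tensorflow as tf


def scaled_sum(x, y):
  # Elementwise multiply-add followed by a reduction: the kind of
  # small op chain that XLA can fuse into a single kernel.
  return tf.reduce_sum(x * 2.0 + y)


# Wrapping the Python function is equivalent to using the decorator form.
compiled_scaled_sum = tf.function(scaled_sum, jit_compile=True)

x = tf.ones([4])
y = tf.ones([4])

eager_result = scaled_sum(x, y)          # runs op by op
xla_result = compiled_scaled_sum(x, y)   # traced, then compiled by XLA

print(float(eager_result), float(xla_result))
```

Both calls should return the same value; only the way the computation is executed differs. The training function that follows applies the same flag to a full training step.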
@tf.function(jit_compile=True)
def train_mnist(images, labels):
  images, labels = cast(images, labels)
  with tf.GradientTape() as tape:
    predicted_labels = layer(images)
    loss = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(
        logits=predicted_labels, labels=labels
    ))
  layer_variables = layer.trainable_variables
  grads = tape.gradient(loss, layer_variables)
  optimizer.apply_gradients(zip(grads, layer_variables))
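To make the loss in the training function concrete, here is a plain-Python sketch of what sparse softmax cross-entropy computes for a single example, together with its gradient with respect to the logits. The helper name sparse_softmax_cross_entropy is hypothetical and this is not TensorFlow's implementation; it only follows the standard formula, including the usual max-subtraction for numerical stability.

```python
import math


def sparse_softmax_cross_entropy(logits, label):
  """Return (loss, grad) for one example; grad is d(loss)/d(logits)."""
  # Subtract the largest logit so that exp() cannot overflow.
  max_logit = max(logits)
  shifted = [z - max_logit for z in logits]
  exps = [math.exp(z) for z in shifted]
  total = sum(exps)
  # loss = -log(softmax(logits)[label]) = log(sum(exp)) - shifted[label]
  loss = math.log(total) - shifted[label]
  # Gradient of the loss w.r.t. the logits: softmax minus one-hot label.
  grad = [e / total - (1.0 if i == label else 0.0)
          for i, e in enumerate(exps)]
  return loss, grad


# Ten equal logits: every class gets probability 0.1, so loss = log(10).
loss, grad = sparse_softmax_cross_entropy([0.0] * 10, label=3)
print(loss, grad[3])
```

The training function averages this per-example loss over the batch with tf.reduce_mean, and the gradient tape propagates the per-example gradient back to the layer's kernel and bias.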
Train and test the model
Once you have defined the training function, run the training loop.
for images, labels in train_ds:
  if optimizer.iterations > TRAIN_STEPS:
    break
  train_mnist(images, labels)
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
I0000 00:00:1738757209.044346 10594 service.cc:148] XLA service 0x499b3c20 initialized for platform CUDA (this does not guarantee that XLA will be used). Devices:
I0000 00:00:1738757209.044391 10594 service.cc:156] StreamExecutor device (0): Tesla T4, Compute Capability 7.5
I0000 00:00:1738757209.044395 10594 service.cc:156] StreamExecutor device (1): Tesla T4, Compute Capability 7.5
I0000 00:00:1738757209.044398 10594 service.cc:156] StreamExecutor device (2): Tesla T4, Compute Capability 7.5
I0000 00:00:1738757209.044402 10594 service.cc:156] StreamExecutor device (3): Tesla T4, Compute Capability 7.5
I0000 00:00:1738757209.094386 10594 cuda_dnn.cc:529] Loaded cuDNN version 90300
I0000 00:00:1738757209.562486 10594 device_compiler.h:188] Compiled cluster using XLA! This line is logged at most once for the lifetime of the process.
And, finally, check the accuracy:
images, labels = cast(test[0], test[1])
predicted_labels = layer(images)
correct_prediction = tf.equal(tf.argmax(predicted_labels, 1), labels)
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
print("Prediction accuracy after training: %s" % accuracy)
Prediction accuracy after training: tf.Tensor(0.8754, shape=(), dtype=float32)
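The accuracy check above reduces to three steps: take the argmax of each row of logits, compare it with the label, and average the matches. A plain-Python sketch of the same arithmetic on made-up values (the function name and the data are illustrative, not from the tutorial):

```python
def accuracy(logits_batch, labels):
  """Fraction of rows whose highest-scoring class equals the label."""
  correct = 0
  for logits, label in zip(logits_batch, labels):
    # Index of the largest logit; ties resolve to the first index.
    predicted = max(range(len(logits)), key=logits.__getitem__)
    correct += int(predicted == label)
  return correct / len(labels)


# Three two-class examples; the third is misclassified.
example_logits = [[0.1, 0.9], [0.8, 0.2], [0.3, 0.7]]
example_labels = [1, 0, 0]
acc = accuracy(example_logits, example_labels)
print(acc)
```

Here two of the three predictions match their labels, so the result is 2/3.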
Behind the scenes, the XLA compiler has compiled the entire TF function to HLO, which has enabled fusion optimizations. Using the introspection facilities, we can see the HLO code (other interesting possible values for "stage" are optimized_hlo for HLO after optimizations and optimized_hlo_dot for a Graphviz graph):
print(train_mnist.experimental_get_compiler_ir(images, labels)(stage='hlo'))
HloModule a_inference_train_mnist_6559__.187, input_output_alias={ {0}: (2, {}, may-alias), {1}: (3, {}, may-alias), {2}: (5, {}, may-alias), {3}: (6, {}, may-alias), {4}: (7, {}, may-alias), {5}: (8, {}, may-alias), {6}: (9, {}, may-alias) }, entry_computation_layout={(f32[10000,784]{1,0}, s64[10000]{0}, f32[784,10]{1,0}, f32[10]{0}, f32[], /*index=5*/s64[], f32[784,10]{1,0}, f32[784,10]{1,0}, f32[10]{0}, f32[10]{0})->(f32[784,10]{1,0}, f32[10]{0}, s64[], f32[784,10]{1,0}, f32[784,10]{1,0}, /*index=5*/f32[10]{0}, f32[10]{0})} %max_float_.71 (x.72: f32[], y.73: f32[]) -> f32[] { %x.72 = f32[] parameter(0) %y.73 = f32[] parameter(1) ROOT %maximum.74 = f32[] maximum(f32[] %x.72, f32[] %y.73) } %add_float_.81 (x.82: f32[], y.83: f32[]) -> f32[] { %x.82 = f32[] parameter(0) %y.83 = f32[] parameter(1) ROOT %add.84 = f32[] add(f32[] %x.82, f32[] %y.83) } %add_float_.100 (x.101: f32[], y.102: f32[]) -> f32[] { %x.101 = f32[] parameter(0) %y.102 = f32[] parameter(1) ROOT %add.103 = f32[] add(f32[] %x.101, f32[] %y.102) } %Mean-reduction.112 (x.113: f32[], y.114: f32[]) -> f32[] { %x.113 = f32[] parameter(0) %y.114 = f32[] parameter(1) ROOT %add.115 = f32[] add(f32[] %x.113, f32[] %y.114) } %region_0.127 (Arg_0.128: f32[], Arg_1.129: f32[]) -> f32[] { %Arg_0.128 = f32[] parameter(0), metadata={op_name="gradient_tape/dense_1/BiasAdd/BiasAddGrad"} %Arg_1.129 = f32[] parameter(1), metadata={op_name="gradient_tape/dense_1/BiasAdd/BiasAddGrad"} ROOT %add.130 = f32[] add(f32[] %Arg_0.128, f32[] %Arg_1.129), metadata={op_type="BiasAddGrad" op_name="gradient_tape/dense_1/BiasAdd/BiasAddGrad"} } ENTRY %a_inference_train_mnist_6559__.187 (arg0.1: f32[10000,784], arg1.2: s64[10000], arg2.3: f32[784,10], arg3.4: f32[10], arg4.5: f32[], arg5.6: s64[], arg6.7: f32[784,10], arg7.8: f32[784,10], arg8.9: f32[10], arg9.10: f32[10]) -> (f32[784,10], f32[10], s64[], f32[784,10], f32[784,10], /*index=5*/f32[10], f32[10]) { %arg1.2 = s64[10000]{0} parameter(1), parameter_replication={false}, 
metadata={op_name="XLA_Args"} %reshape.12 = s64[10000]{0} reshape(s64[10000]{0} %arg1.2) %broadcast.51 = s64[10000,10]{1,0} broadcast(s64[10000]{0} %reshape.12), dimensions={0}, metadata={op_type="SparseSoftmaxCrossEntropyWithLogits" op_name="SparseSoftmaxCrossEntropyWithLogits/SparseSoftmaxCrossEntropyWithLogits" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1196} %iota.50 = s64[10000,10]{1,0} iota(), iota_dimension=1, metadata={op_type="SparseSoftmaxCrossEntropyWithLogits" op_name="SparseSoftmaxCrossEntropyWithLogits/SparseSoftmaxCrossEntropyWithLogits" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1196} %compare.52 = pred[10000,10]{1,0} compare(s64[10000,10]{1,0} %broadcast.51, s64[10000,10]{1,0} %iota.50), direction=EQ, metadata={op_type="SparseSoftmaxCrossEntropyWithLogits" op_name="SparseSoftmaxCrossEntropyWithLogits/SparseSoftmaxCrossEntropyWithLogits" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1196} %constant.48 = f32[] constant(1), metadata={op_type="SparseSoftmaxCrossEntropyWithLogits" op_name="SparseSoftmaxCrossEntropyWithLogits/SparseSoftmaxCrossEntropyWithLogits" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1196} %broadcast.53 = f32[10000,10]{1,0} broadcast(f32[] %constant.48), dimensions={}, metadata={op_type="SparseSoftmaxCrossEntropyWithLogits" op_name="SparseSoftmaxCrossEntropyWithLogits/SparseSoftmaxCrossEntropyWithLogits" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1196} %constant.49 = f32[] constant(0), metadata={op_type="SparseSoftmaxCrossEntropyWithLogits" op_name="SparseSoftmaxCrossEntropyWithLogits/SparseSoftmaxCrossEntropyWithLogits" 
source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1196} %broadcast.54 = f32[10000,10]{1,0} broadcast(f32[] %constant.49), dimensions={}, metadata={op_type="SparseSoftmaxCrossEntropyWithLogits" op_name="SparseSoftmaxCrossEntropyWithLogits/SparseSoftmaxCrossEntropyWithLogits" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1196} %select.55 = f32[10000,10]{1,0} select(pred[10000,10]{1,0} %compare.52, f32[10000,10]{1,0} %broadcast.53, f32[10000,10]{1,0} %broadcast.54), metadata={op_type="SparseSoftmaxCrossEntropyWithLogits" op_name="SparseSoftmaxCrossEntropyWithLogits/SparseSoftmaxCrossEntropyWithLogits" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1196} %constant.56 = s64[] constant(0), metadata={op_type="SparseSoftmaxCrossEntropyWithLogits" op_name="SparseSoftmaxCrossEntropyWithLogits/SparseSoftmaxCrossEntropyWithLogits" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1196} %broadcast.57 = s64[10000]{0} broadcast(s64[] %constant.56), dimensions={}, metadata={op_type="SparseSoftmaxCrossEntropyWithLogits" op_name="SparseSoftmaxCrossEntropyWithLogits/SparseSoftmaxCrossEntropyWithLogits" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1196} %compare.58 = pred[10000]{0} compare(s64[10000]{0} %broadcast.57, s64[10000]{0} %reshape.12), direction=LE, metadata={op_type="SparseSoftmaxCrossEntropyWithLogits" op_name="SparseSoftmaxCrossEntropyWithLogits/SparseSoftmaxCrossEntropyWithLogits" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1196} %constant.59 = s64[] constant(10), metadata={op_type="SparseSoftmaxCrossEntropyWithLogits" 
op_name="SparseSoftmaxCrossEntropyWithLogits/SparseSoftmaxCrossEntropyWithLogits" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1196} %broadcast.60 = s64[10000]{0} broadcast(s64[] %constant.59), dimensions={}, metadata={op_type="SparseSoftmaxCrossEntropyWithLogits" op_name="SparseSoftmaxCrossEntropyWithLogits/SparseSoftmaxCrossEntropyWithLogits" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1196} %compare.61 = pred[10000]{0} compare(s64[10000]{0} %reshape.12, s64[10000]{0} %broadcast.60), direction=LT, metadata={op_type="SparseSoftmaxCrossEntropyWithLogits" op_name="SparseSoftmaxCrossEntropyWithLogits/SparseSoftmaxCrossEntropyWithLogits" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1196} %and.62 = pred[10000]{0} and(pred[10000]{0} %compare.58, pred[10000]{0} %compare.61), metadata={op_type="SparseSoftmaxCrossEntropyWithLogits" op_name="SparseSoftmaxCrossEntropyWithLogits/SparseSoftmaxCrossEntropyWithLogits" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1196} %constant.63 = f32[] constant(0), metadata={op_type="SparseSoftmaxCrossEntropyWithLogits" op_name="SparseSoftmaxCrossEntropyWithLogits/SparseSoftmaxCrossEntropyWithLogits" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1196} %broadcast.64 = f32[10000]{0} broadcast(f32[] %constant.63), dimensions={}, metadata={op_type="SparseSoftmaxCrossEntropyWithLogits" op_name="SparseSoftmaxCrossEntropyWithLogits/SparseSoftmaxCrossEntropyWithLogits" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1196} %constant.65 = f32[] constant(nan), metadata={op_type="SparseSoftmaxCrossEntropyWithLogits" 
op_name="SparseSoftmaxCrossEntropyWithLogits/SparseSoftmaxCrossEntropyWithLogits" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1196} %broadcast.66 = f32[10000]{0} broadcast(f32[] %constant.65), dimensions={}, metadata={op_type="SparseSoftmaxCrossEntropyWithLogits" op_name="SparseSoftmaxCrossEntropyWithLogits/SparseSoftmaxCrossEntropyWithLogits" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1196} %select.67 = f32[10000]{0} select(pred[10000]{0} %and.62, f32[10000]{0} %broadcast.64, f32[10000]{0} %broadcast.66), metadata={op_type="SparseSoftmaxCrossEntropyWithLogits" op_name="SparseSoftmaxCrossEntropyWithLogits/SparseSoftmaxCrossEntropyWithLogits" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1196} %broadcast.68 = f32[10000,10]{1,0} broadcast(f32[10000]{0} %select.67), dimensions={0}, metadata={op_type="SparseSoftmaxCrossEntropyWithLogits" op_name="SparseSoftmaxCrossEntropyWithLogits/SparseSoftmaxCrossEntropyWithLogits" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1196} %add.69 = f32[10000,10]{1,0} add(f32[10000,10]{1,0} %select.55, f32[10000,10]{1,0} %broadcast.68), metadata={op_type="SparseSoftmaxCrossEntropyWithLogits" op_name="SparseSoftmaxCrossEntropyWithLogits/SparseSoftmaxCrossEntropyWithLogits" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1196} %negate.96 = f32[10000,10]{1,0} negate(f32[10000,10]{1,0} %add.69), metadata={op_type="SparseSoftmaxCrossEntropyWithLogits" op_name="SparseSoftmaxCrossEntropyWithLogits/SparseSoftmaxCrossEntropyWithLogits" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1196} %constant.90 = f32[] constant(0), 
metadata={op_type="SparseSoftmaxCrossEntropyWithLogits" op_name="SparseSoftmaxCrossEntropyWithLogits/SparseSoftmaxCrossEntropyWithLogits" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1196} %broadcast.91 = f32[10000,10]{1,0} broadcast(f32[] %constant.90), dimensions={}, metadata={op_type="SparseSoftmaxCrossEntropyWithLogits" op_name="SparseSoftmaxCrossEntropyWithLogits/SparseSoftmaxCrossEntropyWithLogits" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1196} %compare.92 = pred[10000,10]{1,0} compare(f32[10000,10]{1,0} %add.69, f32[10000,10]{1,0} %broadcast.91), direction=EQ, metadata={op_type="SparseSoftmaxCrossEntropyWithLogits" op_name="SparseSoftmaxCrossEntropyWithLogits/SparseSoftmaxCrossEntropyWithLogits" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1196} %constant.93 = f32[] constant(0), metadata={op_type="SparseSoftmaxCrossEntropyWithLogits" op_name="SparseSoftmaxCrossEntropyWithLogits/SparseSoftmaxCrossEntropyWithLogits" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1196} %broadcast.94 = f32[10000,10]{1,0} broadcast(f32[] %constant.93), dimensions={}, metadata={op_type="SparseSoftmaxCrossEntropyWithLogits" op_name="SparseSoftmaxCrossEntropyWithLogits/SparseSoftmaxCrossEntropyWithLogits" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1196} %arg0.1 = f32[10000,784]{1,0} parameter(0), parameter_replication={false}, metadata={op_name="XLA_Args"} %reshape.11 = f32[10000,784]{1,0} reshape(f32[10000,784]{1,0} %arg0.1) %reshape.43 = f32[10000,784]{1,0} reshape(f32[10000,784]{1,0} %reshape.11), metadata={op_type="Reshape" op_name="Reshape" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" 
source_line=1196} %arg2.3 = f32[784,10]{1,0} parameter(2), parameter_replication={false}, metadata={op_name="XLA_Args"} %dot.44 = f32[10000,10]{1,0} dot(f32[10000,784]{1,0} %reshape.43, f32[784,10]{1,0} %arg2.3), lhs_contracting_dims={1}, rhs_contracting_dims={0}, frontend_attributes={grad_x="false",grad_y="false"}, metadata={op_type="MatMul" op_name="dense_1/MatMul" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1196} %transpose.45 = f32[10000,10]{1,0} transpose(f32[10000,10]{1,0} %dot.44), dimensions={0,1}, metadata={op_type="MatMul" op_name="dense_1/MatMul" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1196} %arg3.4 = f32[10]{0} parameter(3), parameter_replication={false}, metadata={op_name="XLA_Args"} %broadcast.46 = f32[10000,10]{1,0} broadcast(f32[10]{0} %arg3.4), dimensions={1}, metadata={op_type="BiasAdd" op_name="dense_1/BiasAdd" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1196} %add.47 = f32[10000,10]{1,0} add(f32[10000,10]{1,0} %transpose.45, f32[10000,10]{1,0} %broadcast.46), metadata={op_type="BiasAdd" op_name="dense_1/BiasAdd" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1196} %constant.70 = f32[] constant(-inf), metadata={op_type="SparseSoftmaxCrossEntropyWithLogits" op_name="SparseSoftmaxCrossEntropyWithLogits/SparseSoftmaxCrossEntropyWithLogits" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1196} %reduce.75 = f32[10000]{0} reduce(f32[10000,10]{1,0} %add.47, f32[] %constant.70), dimensions={1}, to_apply=%max_float_.71, metadata={op_type="SparseSoftmaxCrossEntropyWithLogits" op_name="SparseSoftmaxCrossEntropyWithLogits/SparseSoftmaxCrossEntropyWithLogits" 
source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1196} %broadcast.76 = f32[10000,10]{1,0} broadcast(f32[10000]{0} %reduce.75), dimensions={0}, metadata={op_type="SparseSoftmaxCrossEntropyWithLogits" op_name="SparseSoftmaxCrossEntropyWithLogits/SparseSoftmaxCrossEntropyWithLogits" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1196} %subtract.77 = f32[10000,10]{1,0} subtract(f32[10000,10]{1,0} %add.47, f32[10000,10]{1,0} %broadcast.76), metadata={op_type="SparseSoftmaxCrossEntropyWithLogits" op_name="SparseSoftmaxCrossEntropyWithLogits/SparseSoftmaxCrossEntropyWithLogits" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1196} %exponential.78 = f32[10000,10]{1,0} exponential(f32[10000,10]{1,0} %subtract.77), metadata={op_type="SparseSoftmaxCrossEntropyWithLogits" op_name="SparseSoftmaxCrossEntropyWithLogits/SparseSoftmaxCrossEntropyWithLogits" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1196} %convert.79 = f32[10000,10]{1,0} convert(f32[10000,10]{1,0} %exponential.78), metadata={op_type="SparseSoftmaxCrossEntropyWithLogits" op_name="SparseSoftmaxCrossEntropyWithLogits/SparseSoftmaxCrossEntropyWithLogits" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1196} %constant.80 = f32[] constant(0), metadata={op_type="SparseSoftmaxCrossEntropyWithLogits" op_name="SparseSoftmaxCrossEntropyWithLogits/SparseSoftmaxCrossEntropyWithLogits" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1196} %reduce.85 = f32[10000]{0} reduce(f32[10000,10]{1,0} %convert.79, f32[] %constant.80), dimensions={1}, to_apply=%add_float_.81, metadata={op_type="SparseSoftmaxCrossEntropyWithLogits" 
op_name="SparseSoftmaxCrossEntropyWithLogits/SparseSoftmaxCrossEntropyWithLogits" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1196} %convert.86 = f32[10000]{0} convert(f32[10000]{0} %reduce.85), metadata={op_type="SparseSoftmaxCrossEntropyWithLogits" op_name="SparseSoftmaxCrossEntropyWithLogits/SparseSoftmaxCrossEntropyWithLogits" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1196} %log.87 = f32[10000]{0} log(f32[10000]{0} %convert.86), metadata={op_type="SparseSoftmaxCrossEntropyWithLogits" op_name="SparseSoftmaxCrossEntropyWithLogits/SparseSoftmaxCrossEntropyWithLogits" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1196} %broadcast.88 = f32[10000,10]{1,0} broadcast(f32[10000]{0} %log.87), dimensions={0}, metadata={op_type="SparseSoftmaxCrossEntropyWithLogits" op_name="SparseSoftmaxCrossEntropyWithLogits/SparseSoftmaxCrossEntropyWithLogits" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1196} %subtract.89 = f32[10000,10]{1,0} subtract(f32[10000,10]{1,0} %subtract.77, f32[10000,10]{1,0} %broadcast.88), metadata={op_type="SparseSoftmaxCrossEntropyWithLogits" op_name="SparseSoftmaxCrossEntropyWithLogits/SparseSoftmaxCrossEntropyWithLogits" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1196} %select.95 = f32[10000,10]{1,0} select(pred[10000,10]{1,0} %compare.92, f32[10000,10]{1,0} %broadcast.94, f32[10000,10]{1,0} %subtract.89), metadata={op_type="SparseSoftmaxCrossEntropyWithLogits" op_name="SparseSoftmaxCrossEntropyWithLogits/SparseSoftmaxCrossEntropyWithLogits" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1196} %multiply.97 = f32[10000,10]{1,0} multiply(f32[10000,10]{1,0} 
%negate.96, f32[10000,10]{1,0} %select.95), metadata={op_type="SparseSoftmaxCrossEntropyWithLogits" op_name="SparseSoftmaxCrossEntropyWithLogits/SparseSoftmaxCrossEntropyWithLogits" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1196} %convert.98 = f32[10000,10]{1,0} convert(f32[10000,10]{1,0} %multiply.97), metadata={op_type="SparseSoftmaxCrossEntropyWithLogits" op_name="SparseSoftmaxCrossEntropyWithLogits/SparseSoftmaxCrossEntropyWithLogits" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1196} %constant.99 = f32[] constant(0), metadata={op_type="SparseSoftmaxCrossEntropyWithLogits" op_name="SparseSoftmaxCrossEntropyWithLogits/SparseSoftmaxCrossEntropyWithLogits" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1196} %reduce.104 = f32[10000]{0} reduce(f32[10000,10]{1,0} %convert.98, f32[] %constant.99), dimensions={1}, to_apply=%add_float_.100, metadata={op_type="SparseSoftmaxCrossEntropyWithLogits" op_name="SparseSoftmaxCrossEntropyWithLogits/SparseSoftmaxCrossEntropyWithLogits" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1196} %convert.105 = f32[10000]{0} convert(f32[10000]{0} %reduce.104), metadata={op_type="SparseSoftmaxCrossEntropyWithLogits" op_name="SparseSoftmaxCrossEntropyWithLogits/SparseSoftmaxCrossEntropyWithLogits" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1196} %convert.109 = f32[10000]{0} convert(f32[10000]{0} %convert.105), metadata={op_type="Mean" op_name="Mean" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1196} %constant.110 = f32[] constant(0), metadata={op_type="Mean" op_name="Mean" 
source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1196} %convert.111 = f32[] convert(f32[] %constant.110), metadata={op_type="Mean" op_name="Mean" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1196} %reduce.116 = f32[] reduce(f32[10000]{0} %convert.109, f32[] %convert.111), dimensions={0}, to_apply=%Mean-reduction.112, metadata={op_type="Mean" op_name="Mean" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1196} %constant.117 = s32[] constant(10000), metadata={op_type="Mean" op_name="Mean" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1196} %convert.118 = f32[] convert(s32[] %constant.117), metadata={op_type="Mean" op_name="Mean" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1196} %divide.119 = f32[] divide(f32[] %reduce.116, f32[] %convert.118), metadata={op_type="Mean" op_name="Mean" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1196} %convert.120 = f32[] convert(f32[] %divide.119), metadata={op_type="Mean" op_name="Mean" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1196} %arg6.7 = f32[784,10]{1,0} parameter(6), parameter_replication={false}, metadata={op_name="XLA_Args"} %constant.121 = f32[] constant(0.0001), metadata={op_type="Mul" op_name="gradient_tape/SparseSoftmaxCrossEntropyWithLogits/mul" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1196} %broadcast.122 = f32[10000,1]{1,0} broadcast(f32[] %constant.121), dimensions={}, metadata={op_type="Mul" op_name="gradient_tape/SparseSoftmaxCrossEntropyWithLogits/mul" 
source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1196} %reshape.123 = f32[10000]{0} reshape(f32[10000,1]{1,0} %broadcast.122), metadata={op_type="Mul" op_name="gradient_tape/SparseSoftmaxCrossEntropyWithLogits/mul" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1196} %broadcast.124 = f32[10000,10]{1,0} broadcast(f32[10000]{0} %reshape.123), dimensions={0}, metadata={op_type="Mul" op_name="gradient_tape/SparseSoftmaxCrossEntropyWithLogits/mul" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1196} %broadcast.106 = f32[10000,10]{1,0} broadcast(f32[10000]{0} %convert.86), dimensions={0}, metadata={op_type="SparseSoftmaxCrossEntropyWithLogits" op_name="SparseSoftmaxCrossEntropyWithLogits/SparseSoftmaxCrossEntropyWithLogits" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1196} %divide.107 = f32[10000,10]{1,0} divide(f32[10000,10]{1,0} %exponential.78, f32[10000,10]{1,0} %broadcast.106), metadata={op_type="SparseSoftmaxCrossEntropyWithLogits" op_name="SparseSoftmaxCrossEntropyWithLogits/SparseSoftmaxCrossEntropyWithLogits" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1196} %subtract.108 = f32[10000,10]{1,0} subtract(f32[10000,10]{1,0} %divide.107, f32[10000,10]{1,0} %add.69), metadata={op_type="SparseSoftmaxCrossEntropyWithLogits" op_name="SparseSoftmaxCrossEntropyWithLogits/SparseSoftmaxCrossEntropyWithLogits" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1196} %multiply.125 = f32[10000,10]{1,0} multiply(f32[10000,10]{1,0} %broadcast.124, f32[10000,10]{1,0} %subtract.108), metadata={op_type="Mul" op_name="gradient_tape/SparseSoftmaxCrossEntropyWithLogits/mul" 
source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1196} %dot.132 = f32[784,10]{1,0} dot(f32[10000,784]{1,0} %reshape.43, f32[10000,10]{1,0} %multiply.125), lhs_contracting_dims={0}, rhs_contracting_dims={0}, frontend_attributes={grad_x="false",grad_y="true"}, metadata={op_type="MatMul" op_name="gradient_tape/dense_1/MatMul/MatMul" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1196} %transpose.133 = f32[784,10]{1,0} transpose(f32[784,10]{1,0} %dot.132), dimensions={0,1}, metadata={op_type="MatMul" op_name="gradient_tape/dense_1/MatMul/MatMul" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1196} %subtract.154 = f32[784,10]{1,0} subtract(f32[784,10]{1,0} %transpose.133, f32[784,10]{1,0} %arg6.7), metadata={op_type="Sub" op_name="adam/Sub_2" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1196} %constant.155 = f32[] constant(0.1), metadata={op_type="Mul" op_name="adam/Mul_1" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1196} %broadcast.156 = f32[784,10]{1,0} broadcast(f32[] %constant.155), dimensions={}, metadata={op_type="Mul" op_name="adam/Mul_1" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1196} %multiply.157 = f32[784,10]{1,0} multiply(f32[784,10]{1,0} %subtract.154, f32[784,10]{1,0} %broadcast.156), metadata={op_type="Mul" op_name="adam/Mul_1" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1196} %add.158 = f32[784,10]{1,0} add(f32[784,10]{1,0} %arg6.7, f32[784,10]{1,0} %multiply.157), metadata={op_type="AssignAddVariableOp" op_name="adam/AssignAddVariableOp" 
source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1196} %arg4.5 = f32[] parameter(4), parameter_replication={false}, metadata={op_name="XLA_Args"} %constant.22 = f32[] constant(1), metadata={op_type="Sub" op_name="adam/sub" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1196} %constant.20 = f32[] constant(0.999), metadata={op_type="Pow" op_name="adam/Pow_1" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1196} %arg5.6 = s64[] parameter(5), parameter_replication={false}, metadata={op_name="XLA_Args"} %constant.13 = s64[] constant(1), metadata={op_type="AddV2" op_name="adam/Add" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1196} %add.14 = s64[] add(s64[] %arg5.6, s64[] %constant.13), metadata={op_type="AddV2" op_name="adam/Add" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1196} %convert.15 = f32[] convert(s64[] %add.14), metadata={op_type="Cast" op_name="adam/Cast_2" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1196} %power.21 = f32[] power(f32[] %constant.20, f32[] %convert.15), metadata={op_type="Pow" op_name="adam/Pow_1" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1196} %subtract.23 = f32[] subtract(f32[] %constant.22, f32[] %power.21), metadata={op_type="Sub" op_name="adam/sub" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1196} %sqrt.24 = f32[] sqrt(f32[] %subtract.23), metadata={op_type="Sqrt" op_name="adam/Sqrt"} %multiply.39 = f32[] multiply(f32[] %arg4.5, f32[] %sqrt.24), metadata={op_type="Mul" op_name="adam/mul" 
source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1196} %constant.18 = f32[] constant(1), metadata={op_type="Sub" op_name="adam/sub_1" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1196} %constant.16 = f32[] constant(0.9), metadata={op_type="Pow" op_name="adam/Pow" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1196} %power.17 = f32[] power(f32[] %constant.16, f32[] %convert.15), metadata={op_type="Pow" op_name="adam/Pow" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1196} %subtract.19 = f32[] subtract(f32[] %constant.18, f32[] %power.17), metadata={op_type="Sub" op_name="adam/sub_1" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1196} %divide.40 = f32[] divide(f32[] %multiply.39, f32[] %subtract.19), metadata={op_type="RealDiv" op_name="adam/truediv" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1196} %broadcast.159 = f32[784,10]{1,0} broadcast(f32[] %divide.40), dimensions={}, metadata={op_type="Mul" op_name="adam/Mul_3" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1196} %multiply.160 = f32[784,10]{1,0} multiply(f32[784,10]{1,0} %add.158, f32[784,10]{1,0} %broadcast.159), metadata={op_type="Mul" op_name="adam/Mul_3" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1196} %arg7.8 = f32[784,10]{1,0} parameter(7), parameter_replication={false}, metadata={op_name="XLA_Args"} %multiply.134 = f32[784,10]{1,0} multiply(f32[784,10]{1,0} %transpose.133, f32[784,10]{1,0} %transpose.133), metadata={op_type="Square" op_name="adam/Square" 
source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1196} %subtract.135 = f32[784,10]{1,0} subtract(f32[784,10]{1,0} %multiply.134, f32[784,10]{1,0} %arg7.8), metadata={op_type="Sub" op_name="adam/Sub_3" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1196} %constant.136 = f32[] constant(0.001), metadata={op_type="Mul" op_name="adam/Mul_2" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1196} %broadcast.137 = f32[784,10]{1,0} broadcast(f32[] %constant.136), dimensions={}, metadata={op_type="Mul" op_name="adam/Mul_2" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1196} %multiply.138 = f32[784,10]{1,0} multiply(f32[784,10]{1,0} %subtract.135, f32[784,10]{1,0} %broadcast.137), metadata={op_type="Mul" op_name="adam/Mul_2" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1196} %add.139 = f32[784,10]{1,0} add(f32[784,10]{1,0} %arg7.8, f32[784,10]{1,0} %multiply.138), metadata={op_type="AssignAddVariableOp" op_name="adam/AssignAddVariableOp_1" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1196} %sqrt.140 = f32[784,10]{1,0} sqrt(f32[784,10]{1,0} %add.139), metadata={op_type="Sqrt" op_name="adam/Sqrt_1"} %constant.141 = f32[] constant(1e-07), metadata={op_type="AddV2" op_name="adam/Add_1" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1196} %broadcast.142 = f32[784,10]{1,0} broadcast(f32[] %constant.141), dimensions={}, metadata={op_type="AddV2" op_name="adam/Add_1" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1196} %add.143 = f32[784,10]{1,0} add(f32[784,10]{1,0} 
%sqrt.140, f32[784,10]{1,0} %broadcast.142), metadata={op_type="AddV2" op_name="adam/Add_1" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1196} %divide.161 = f32[784,10]{1,0} divide(f32[784,10]{1,0} %multiply.160, f32[784,10]{1,0} %add.143), metadata={op_type="RealDiv" op_name="adam/truediv_1" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1196} %subtract.162 = f32[784,10]{1,0} subtract(f32[784,10]{1,0} %arg2.3, f32[784,10]{1,0} %divide.161), metadata={op_type="AssignSubVariableOp" op_name="adam/AssignSubVariableOp" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1196} %reshape.172 = f32[784,10]{1,0} reshape(f32[784,10]{1,0} %subtract.162), metadata={op_name="XLA_Retvals"} %copy.173 = f32[784,10]{1,0} copy(f32[784,10]{1,0} %reshape.172), metadata={op_name="XLA_Retvals"} %arg8.9 = f32[10]{0} parameter(8), parameter_replication={false}, metadata={op_name="XLA_Args"} %constant.126 = f32[] constant(-0), metadata={op_type="BiasAddGrad" op_name="gradient_tape/dense_1/BiasAdd/BiasAddGrad" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1196} %reduce.131 = f32[10]{0} reduce(f32[10000,10]{1,0} %multiply.125, f32[] %constant.126), dimensions={0}, to_apply=%region_0.127, metadata={op_type="BiasAddGrad" op_name="gradient_tape/dense_1/BiasAdd/BiasAddGrad"} %subtract.163 = f32[10]{0} subtract(f32[10]{0} %reduce.131, f32[10]{0} %arg8.9), metadata={op_type="Sub" op_name="adam/Sub_6" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1196} %constant.164 = f32[] constant(0.1), metadata={op_type="Mul" op_name="adam/Mul_5" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1196} %broadcast.165 = f32[10]{0} 
broadcast(f32[] %constant.164), dimensions={}, metadata={op_type="Mul" op_name="adam/Mul_5" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1196} %multiply.166 = f32[10]{0} multiply(f32[10]{0} %subtract.163, f32[10]{0} %broadcast.165), metadata={op_type="Mul" op_name="adam/Mul_5" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1196} %add.167 = f32[10]{0} add(f32[10]{0} %arg8.9, f32[10]{0} %multiply.166), metadata={op_type="AssignAddVariableOp" op_name="adam/AssignAddVariableOp_2" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1196} %constant.34 = f32[] constant(1), metadata={op_type="Sub" op_name="adam/sub_4" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1196} %constant.32 = f32[] constant(0.999), metadata={op_type="Pow" op_name="adam/Pow_3" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1196} %constant.25 = s64[] constant(1), metadata={op_type="AddV2" op_name="adam/Add_2" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1196} %add.26 = s64[] add(s64[] %arg5.6, s64[] %constant.25), metadata={op_type="AddV2" op_name="adam/Add_2" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1196} %convert.27 = f32[] convert(s64[] %add.26), metadata={op_type="Cast" op_name="adam/Cast_8" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1196} %power.33 = f32[] power(f32[] %constant.32, f32[] %convert.27), metadata={op_type="Pow" op_name="adam/Pow_3" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1196} %subtract.35 = 
f32[] subtract(f32[] %constant.34, f32[] %power.33), metadata={op_type="Sub" op_name="adam/sub_4" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1196} %sqrt.36 = f32[] sqrt(f32[] %subtract.35), metadata={op_type="Sqrt" op_name="adam/Sqrt_2"} %multiply.41 = f32[] multiply(f32[] %arg4.5, f32[] %sqrt.36), metadata={op_type="Mul" op_name="adam/mul_4" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1196} %constant.30 = f32[] constant(1), metadata={op_type="Sub" op_name="adam/sub_5" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1196} %constant.28 = f32[] constant(0.9), metadata={op_type="Pow" op_name="adam/Pow_2" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1196} %power.29 = f32[] power(f32[] %constant.28, f32[] %convert.27), metadata={op_type="Pow" op_name="adam/Pow_2" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1196} %subtract.31 = f32[] subtract(f32[] %constant.30, f32[] %power.29), metadata={op_type="Sub" op_name="adam/sub_5" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1196} %divide.42 = f32[] divide(f32[] %multiply.41, f32[] %subtract.31), metadata={op_type="RealDiv" op_name="adam/truediv_2" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1196} %broadcast.168 = f32[10]{0} broadcast(f32[] %divide.42), dimensions={}, metadata={op_type="Mul" op_name="adam/Mul_7" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1196} %multiply.169 = f32[10]{0} multiply(f32[10]{0} %add.167, f32[10]{0} %broadcast.168), metadata={op_type="Mul" op_name="adam/Mul_7" 
source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1196} %arg9.10 = f32[10]{0} parameter(9), parameter_replication={false}, metadata={op_name="XLA_Args"} %multiply.144 = f32[10]{0} multiply(f32[10]{0} %reduce.131, f32[10]{0} %reduce.131), metadata={op_type="Square" op_name="adam/Square_1" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1196} %subtract.145 = f32[10]{0} subtract(f32[10]{0} %multiply.144, f32[10]{0} %arg9.10), metadata={op_type="Sub" op_name="adam/Sub_7" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1196} %constant.146 = f32[] constant(0.001), metadata={op_type="Mul" op_name="adam/Mul_6" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1196} %broadcast.147 = f32[10]{0} broadcast(f32[] %constant.146), dimensions={}, metadata={op_type="Mul" op_name="adam/Mul_6" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1196} %multiply.148 = f32[10]{0} multiply(f32[10]{0} %subtract.145, f32[10]{0} %broadcast.147), metadata={op_type="Mul" op_name="adam/Mul_6" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1196} %add.149 = f32[10]{0} add(f32[10]{0} %arg9.10, f32[10]{0} %multiply.148), metadata={op_type="AssignAddVariableOp" op_name="adam/AssignAddVariableOp_3" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1196} %sqrt.150 = f32[10]{0} sqrt(f32[10]{0} %add.149), metadata={op_type="Sqrt" op_name="adam/Sqrt_3"} %constant.151 = f32[] constant(1e-07), metadata={op_type="AddV2" op_name="adam/Add_3" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1196} %broadcast.152 = 
f32[10]{0} broadcast(f32[] %constant.151), dimensions={}, metadata={op_type="AddV2" op_name="adam/Add_3" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1196} %add.153 = f32[10]{0} add(f32[10]{0} %sqrt.150, f32[10]{0} %broadcast.152), metadata={op_type="AddV2" op_name="adam/Add_3" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1196} %divide.170 = f32[10]{0} divide(f32[10]{0} %multiply.169, f32[10]{0} %add.153), metadata={op_type="RealDiv" op_name="adam/truediv_3" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1196} %subtract.171 = f32[10]{0} subtract(f32[10]{0} %arg3.4, f32[10]{0} %divide.170), metadata={op_type="AssignSubVariableOp" op_name="adam/AssignSubVariableOp_1" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1196} %reshape.174 = f32[10]{0} reshape(f32[10]{0} %subtract.171), metadata={op_name="XLA_Retvals"} %copy.175 = f32[10]{0} copy(f32[10]{0} %reshape.174), metadata={op_name="XLA_Retvals"} %constant.37 = s64[] constant(1), metadata={op_type="AddV2" op_name="adam/Add_4" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1196} %add.38 = s64[] add(s64[] %arg5.6, s64[] %constant.37), metadata={op_type="AddV2" op_name="adam/Add_4" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1196} %reshape.176 = s64[] reshape(s64[] %add.38), metadata={op_name="XLA_Retvals"} %copy.177 = s64[] copy(s64[] %reshape.176), metadata={op_name="XLA_Retvals"} %reshape.178 = f32[784,10]{1,0} reshape(f32[784,10]{1,0} %add.158), metadata={op_name="XLA_Retvals"} %copy.179 = f32[784,10]{1,0} copy(f32[784,10]{1,0} %reshape.178), metadata={op_name="XLA_Retvals"} %reshape.180 = f32[784,10]{1,0} 
reshape(f32[784,10]{1,0} %add.139), metadata={op_name="XLA_Retvals"} %copy.181 = f32[784,10]{1,0} copy(f32[784,10]{1,0} %reshape.180), metadata={op_name="XLA_Retvals"} %reshape.182 = f32[10]{0} reshape(f32[10]{0} %add.167), metadata={op_name="XLA_Retvals"} %copy.183 = f32[10]{0} copy(f32[10]{0} %reshape.182), metadata={op_name="XLA_Retvals"} %reshape.184 = f32[10]{0} reshape(f32[10]{0} %add.149), metadata={op_name="XLA_Retvals"} %copy.185 = f32[10]{0} copy(f32[10]{0} %reshape.184), metadata={op_name="XLA_Retvals"} ROOT %tuple.186 = (f32[784,10]{1,0}, f32[10]{0}, s64[], f32[784,10]{1,0}, f32[784,10]{1,0}, /*index=5*/f32[10]{0}, f32[10]{0}) tuple(f32[784,10]{1,0} %copy.173, f32[10]{0} %copy.175, s64[] %copy.177, f32[784,10]{1,0} %copy.179, f32[784,10]{1,0} %copy.181, /*index=5*/f32[10]{0} %copy.183, f32[10]{0} %copy.185), metadata={op_name="XLA_Retvals"} }
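The optimizer portion of the dump above is easier to follow when written out as ordinary arithmetic. The following is a minimal sketch, not the Keras implementation: it mirrors the `adam/*` ops visible in the HLO for a single scalar weight, using the constants that appear in the dump (0.9, 0.999, 0.1 = 1 - 0.9, 0.001 = 1 - 0.999, and 1e-07). The function name `adam_step` and its argument names are hypothetical and chosen for illustration; the learning rate is taken as an argument because the dump reads it from a runtime parameter (`%arg4.5`) rather than a constant.

```python
import math

def adam_step(param, grad, m, v, step, lr, beta1=0.9, beta2=0.999, eps=1e-7):
    """One Adam update for a scalar weight, following the op order in the HLO dump.

    Hypothetical helper for illustration; names are not taken from TensorFlow.
    """
    # adam/Add, adam/Cast_2: increment the step counter and use it as a float.
    t = step + 1
    # adam/Pow_1, adam/sub, adam/Sqrt, adam/mul, adam/Pow, adam/sub_1, adam/truediv:
    # bias-corrected step size.
    alpha = lr * math.sqrt(1.0 - beta2 ** t) / (1.0 - beta1 ** t)
    # adam/Sub_2, adam/Mul_1, adam/AssignAddVariableOp: first-moment update.
    m_new = m + (grad - m) * (1.0 - beta1)
    # adam/Square, adam/Sub_3, adam/Mul_2, adam/AssignAddVariableOp_1: second-moment update.
    v_new = v + (grad * grad - v) * (1.0 - beta2)
    # adam/Mul_3, adam/Sqrt_1, adam/Add_1, adam/truediv_1, adam/AssignSubVariableOp.
    param_new = param - (m_new * alpha) / (math.sqrt(v_new) + eps)
    return param_new, m_new, v_new, t

# Example call with made-up values: one step from zero-initialized moments.
new_param, new_m, new_v, new_step = adam_step(
    param=1.0, grad=0.5, m=0.0, v=0.0, step=0, lr=0.001)
```

In the dump, the same sequence runs once for the `f32[784,10]` kernel and once for the `f32[10]` bias, and the final `ROOT` tuple returns the updated variables, the incremented step counter, and the updated moment accumulators.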