Use XLA with tf.function

View on TensorFlow.org Run in Google Colab Download notebook View source on GitHub

This tutorial trains a TensorFlow model to classify the MNIST dataset, where the training function is compiled using XLA.

First, import TensorFlow. Eager execution is enabled by default in TensorFlow 2, so no additional setup is required.

import tensorflow as tf
2023-01-18 12:09:53.148023: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer.so.7'; dlerror: libnvinfer.so.7: cannot open shared object file: No such file or directory
2023-01-18 12:09:53.148129: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer_plugin.so.7'; dlerror: libnvinfer_plugin.so.7: cannot open shared object file: No such file or directory
2023-01-18 12:09:53.148139: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Cannot dlopen some TensorRT libraries. If you would like to use Nvidia GPU with TensorRT, please make sure the missing libraries mentioned above are installed properly.

Then define some necessary constants and prepare the MNIST dataset.

# Size of each flattened input image: 28 x 28 pixels
IMAGE_SIZE = 28 * 28
# Number of distinct digit labels, [0..9]
NUM_CLASSES = 10
# Number of examples in each training batch (one batch is consumed per step)
TRAIN_BATCH_SIZE = 100
# Number of training steps to run
TRAIN_STEPS = 1000

# Fetch the MNIST train and test splits (downloads the archive on first use).
train, test = tf.keras.datasets.mnist.load_data()
# Slice the training split into individual examples, group them into batches
# of TRAIN_BATCH_SIZE, and repeat indefinitely so that the training loop,
# not the dataset, decides when to stop.
train_ds = (
    tf.data.Dataset.from_tensor_slices(train)
    .batch(TRAIN_BATCH_SIZE)
    .repeat()
)

# Converts raw data into the datatypes used by the model and the loss.
def cast(images, labels):
  """Flatten the images to float32 rows and convert the labels to int64.

  Args:
    images: Image data that can be reshaped to rows of IMAGE_SIZE values.
    labels: Class labels for the images.

  Returns:
    A tuple `(images, labels)` where `images` is a float32 tensor reshaped
    to `[-1, IMAGE_SIZE]` and `labels` is an int64 tensor.
  """
  flat_images = tf.reshape(images, [-1, IMAGE_SIZE])
  return tf.cast(flat_images, tf.float32), tf.cast(labels, tf.int64)
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz
11490434/11490434 [==============================] - 0s 0us/step

Finally, define the model and the optimizer. The model uses a single dense layer.

# Single fully connected layer with one output per class; its outputs are
# passed to the loss function as logits.
layer = tf.keras.layers.Dense(NUM_CLASSES)
# Adam optimizer constructed with its default hyperparameters.
optimizer = tf.keras.optimizers.Adam()

Define the training function

In the training function, you compute the predictions (logits) using the layer defined above, and then use the optimizer to apply the gradients of the loss to the layer's variables, which minimizes the loss. In order to compile the computation using XLA, place it inside tf.function with jit_compile=True.

@tf.function(jit_compile=True)
def train_mnist(images, labels):
  """Run one XLA-compiled training step on a batch of examples.

  Casts the inputs with `cast`, computes the mean sparse softmax
  cross-entropy loss of `layer` on the batch, and applies the resulting
  gradients to the layer's trainable variables through `optimizer`.

  Args:
    images: Batch of raw images; flattened and cast to float32 by `cast`.
    labels: Batch of class labels; cast to int64 by `cast`.
  """
  images, labels = cast(images, labels)

  with tf.GradientTape() as tape:
    logits = layer(images)
    per_example_loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
        labels=labels, logits=logits)
    loss = tf.reduce_mean(per_example_loss)

  # Looked up after the forward pass: Keras layers typically create their
  # weights on first call, so reading the list earlier could find it empty.
  variables = layer.trainable_variables
  gradients = tape.gradient(loss, variables)
  optimizer.apply_gradients(zip(gradients, variables))

Train and test the model

Once you have defined the training function, use it to train the model.

# Stop once the optimizer has applied TRAIN_STEPS updates.
# `optimizer.iterations` counts the updates applied so far and is checked
# before each step, so `>=` is required here; a strict `>` comparison would
# run one extra step.
for images, labels in train_ds:
  if optimizer.iterations >= TRAIN_STEPS:
    break
  train_mnist(images, labels)

And, finally, check the accuracy:

# Evaluate on the held-out test split: `test` is the (images, labels) pair
# returned by `load_data`.
images, labels = cast(*test)
predicted_labels = layer(images)
# The predicted class is the index of the largest output for each example.
predicted_classes = tf.argmax(predicted_labels, axis=1)
correct_prediction = tf.equal(predicted_classes, labels)
# Accuracy is the fraction of examples whose predicted class matches.
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
print("Prediction accuracy after training: %s" % accuracy)
Prediction accuracy after training: tf.Tensor(0.8773, shape=(), dtype=float32)

Behind the scenes, the XLA compiler has compiled the entire TF function to HLO, which has enabled fusion optimizations. Using the introspection facilities, we can see the HLO code (other interesting possible values for "stage" are optimized_hlo for HLO after optimizations and optimized_hlo_dot for a Graphviz graph):

# Build an IR generator for these concrete inputs, then ask it for the
# HLO stage.
compiler_ir = train_mnist.experimental_get_compiler_ir(images, labels)
print(compiler_ir(stage='hlo'))
HloModule a_inference_train_mnist_5324__.206, input_output_alias={ {0}: (2, {}, may-alias), {1}: (3, {}, may-alias), {2}: (4, {}, may-alias), {3}: (6, {}, may-alias), {4}: (7, {}, may-alias), {5}: (8, {}, may-alias), {6}: (9, {}, may-alias) }, entry_computation_layout={(f32[10000,784]{1,0},s64[10000]{0},f32[784,10]{1,0},f32[10]{0},s64[],f32[],f32[784,10]{1,0},f32[784,10]{1,0},f32[10]{0},f32[10]{0})->(f32[784,10]{1,0}, f32[10]{0}, s64[], f32[784,10]{1,0}, f32[784,10]{1,0}, /*index=5*/f32[10]{0}, f32[10]{0})}

%max_float_.42 (x.43: f32[], y.44: f32[]) -> f32[] {
  %x.43 = f32[] parameter(0)
  %y.44 = f32[] parameter(1)
  ROOT %maximum.45 = f32[] maximum(f32[] %x.43, f32[] %y.44)
}

%add_float_.52 (x.53: f32[], y.54: f32[]) -> f32[] {
  %x.53 = f32[] parameter(0)
  %y.54 = f32[] parameter(1)
  ROOT %add.55 = f32[] add(f32[] %x.53, f32[] %y.54)
}

%add_float_.71 (x.72: f32[], y.73: f32[]) -> f32[] {
  %x.72 = f32[] parameter(0)
  %y.73 = f32[] parameter(1)
  ROOT %add.74 = f32[] add(f32[] %x.72, f32[] %y.73)
}

%Mean-reduction.83 (x.84: f32[], y.85: f32[]) -> f32[] {
  %x.84 = f32[] parameter(0)
  %y.85 = f32[] parameter(1)
  ROOT %add.86 = f32[] add(f32[] %x.84, f32[] %y.85)
}

%region_0.100 (Arg_0.101: f32[], Arg_1.102: f32[]) -> f32[] {
  %Arg_0.101 = f32[] parameter(0)
  %Arg_1.102 = f32[] parameter(1)
  ROOT %add.103 = f32[] add(f32[] %Arg_0.101, f32[] %Arg_1.102), metadata={op_type="BiasAddGrad" op_name="gradient_tape/dense/BiasAdd/BiasAddGrad"}
}

ENTRY %a_inference_train_mnist_5324__.206 (arg0.1: f32[10000,784], arg1.2: s64[10000], arg2.3: f32[784,10], arg3.4: f32[10], arg4.5: s64[], arg5.6: f32[], arg6.7: f32[784,10], arg7.8: f32[784,10], arg8.9: f32[10], arg9.10: f32[10]) -> (f32[784,10], f32[10], s64[], f32[784,10], f32[784,10], /*index=5*/f32[10], f32[10]) {
  %constant.13 = s32[2]{0} constant({-1, 784}), metadata={op_type="Reshape" op_name="Reshape" source_file="/tmpfs/tmp/ipykernel_11149/494562224.py" source_line=16}
  %arg1.2 = s64[10000]{0} parameter(1), parameter_replication={false}, metadata={op_name="XLA_Args"}
  %reshape.12 = s64[10000]{0} reshape(s64[10000]{0} %arg1.2)
  %broadcast.24 = s64[10000,10]{1,0} broadcast(s64[10000]{0} %reshape.12), dimensions={0}, metadata={op_type="SparseSoftmaxCrossEntropyWithLogits" op_name="SparseSoftmaxCrossEntropyWithLogits/SparseSoftmaxCrossEntropyWithLogits" source_file="/tmpfs/tmp/ipykernel_11149/3922067182.py" source_line=7}
  %iota.23 = s64[10000,10]{1,0} iota(), iota_dimension=1, metadata={op_type="SparseSoftmaxCrossEntropyWithLogits" op_name="SparseSoftmaxCrossEntropyWithLogits/SparseSoftmaxCrossEntropyWithLogits" source_file="/tmpfs/tmp/ipykernel_11149/3922067182.py" source_line=7}
  %compare.25 = pred[10000,10]{1,0} compare(s64[10000,10]{1,0} %broadcast.24, s64[10000,10]{1,0} %iota.23), direction=EQ, metadata={op_type="SparseSoftmaxCrossEntropyWithLogits" op_name="SparseSoftmaxCrossEntropyWithLogits/SparseSoftmaxCrossEntropyWithLogits" source_file="/tmpfs/tmp/ipykernel_11149/3922067182.py" source_line=7}
  %constant.20 = f32[] constant(1), metadata={op_type="SparseSoftmaxCrossEntropyWithLogits" op_name="SparseSoftmaxCrossEntropyWithLogits/SparseSoftmaxCrossEntropyWithLogits" source_file="/tmpfs/tmp/ipykernel_11149/3922067182.py" source_line=7}
  %broadcast.22 = f32[10000,10]{1,0} broadcast(f32[] %constant.20), dimensions={}, metadata={op_type="SparseSoftmaxCrossEntropyWithLogits" op_name="SparseSoftmaxCrossEntropyWithLogits/SparseSoftmaxCrossEntropyWithLogits" source_file="/tmpfs/tmp/ipykernel_11149/3922067182.py" source_line=7}
  %constant.19 = f32[] constant(0), metadata={op_type="SparseSoftmaxCrossEntropyWithLogits" op_name="SparseSoftmaxCrossEntropyWithLogits/SparseSoftmaxCrossEntropyWithLogits" source_file="/tmpfs/tmp/ipykernel_11149/3922067182.py" source_line=7}
  %broadcast.21 = f32[10000,10]{1,0} broadcast(f32[] %constant.19), dimensions={}, metadata={op_type="SparseSoftmaxCrossEntropyWithLogits" op_name="SparseSoftmaxCrossEntropyWithLogits/SparseSoftmaxCrossEntropyWithLogits" source_file="/tmpfs/tmp/ipykernel_11149/3922067182.py" source_line=7}
  %select.26 = f32[10000,10]{1,0} select(pred[10000,10]{1,0} %compare.25, f32[10000,10]{1,0} %broadcast.22, f32[10000,10]{1,0} %broadcast.21), metadata={op_type="SparseSoftmaxCrossEntropyWithLogits" op_name="SparseSoftmaxCrossEntropyWithLogits/SparseSoftmaxCrossEntropyWithLogits" source_file="/tmpfs/tmp/ipykernel_11149/3922067182.py" source_line=7}
  %constant.34 = s64[] constant(0), metadata={op_type="SparseSoftmaxCrossEntropyWithLogits" op_name="SparseSoftmaxCrossEntropyWithLogits/SparseSoftmaxCrossEntropyWithLogits" source_file="/tmpfs/tmp/ipykernel_11149/3922067182.py" source_line=7}
  %broadcast.35 = s64[10000]{0} broadcast(s64[] %constant.34), dimensions={}, metadata={op_type="SparseSoftmaxCrossEntropyWithLogits" op_name="SparseSoftmaxCrossEntropyWithLogits/SparseSoftmaxCrossEntropyWithLogits" source_file="/tmpfs/tmp/ipykernel_11149/3922067182.py" source_line=7}
  %compare.36 = pred[10000]{0} compare(s64[10000]{0} %broadcast.35, s64[10000]{0} %reshape.12), direction=LE, metadata={op_type="SparseSoftmaxCrossEntropyWithLogits" op_name="SparseSoftmaxCrossEntropyWithLogits/SparseSoftmaxCrossEntropyWithLogits" source_file="/tmpfs/tmp/ipykernel_11149/3922067182.py" source_line=7}
  %constant.31 = s64[] constant(10), metadata={op_type="SparseSoftmaxCrossEntropyWithLogits" op_name="SparseSoftmaxCrossEntropyWithLogits/SparseSoftmaxCrossEntropyWithLogits" source_file="/tmpfs/tmp/ipykernel_11149/3922067182.py" source_line=7}
  %broadcast.32 = s64[10000]{0} broadcast(s64[] %constant.31), dimensions={}, metadata={op_type="SparseSoftmaxCrossEntropyWithLogits" op_name="SparseSoftmaxCrossEntropyWithLogits/SparseSoftmaxCrossEntropyWithLogits" source_file="/tmpfs/tmp/ipykernel_11149/3922067182.py" source_line=7}
  %compare.33 = pred[10000]{0} compare(s64[10000]{0} %reshape.12, s64[10000]{0} %broadcast.32), direction=LT, metadata={op_type="SparseSoftmaxCrossEntropyWithLogits" op_name="SparseSoftmaxCrossEntropyWithLogits/SparseSoftmaxCrossEntropyWithLogits" source_file="/tmpfs/tmp/ipykernel_11149/3922067182.py" source_line=7}
  %and.37 = pred[10000]{0} and(pred[10000]{0} %compare.36, pred[10000]{0} %compare.33), metadata={op_type="SparseSoftmaxCrossEntropyWithLogits" op_name="SparseSoftmaxCrossEntropyWithLogits/SparseSoftmaxCrossEntropyWithLogits" source_file="/tmpfs/tmp/ipykernel_11149/3922067182.py" source_line=7}
  %constant.29 = f32[] constant(0), metadata={op_type="SparseSoftmaxCrossEntropyWithLogits" op_name="SparseSoftmaxCrossEntropyWithLogits/SparseSoftmaxCrossEntropyWithLogits" source_file="/tmpfs/tmp/ipykernel_11149/3922067182.py" source_line=7}
  %broadcast.30 = f32[10000]{0} broadcast(f32[] %constant.29), dimensions={}, metadata={op_type="SparseSoftmaxCrossEntropyWithLogits" op_name="SparseSoftmaxCrossEntropyWithLogits/SparseSoftmaxCrossEntropyWithLogits" source_file="/tmpfs/tmp/ipykernel_11149/3922067182.py" source_line=7}
  %constant.27 = f32[] constant(nan), metadata={op_type="SparseSoftmaxCrossEntropyWithLogits" op_name="SparseSoftmaxCrossEntropyWithLogits/SparseSoftmaxCrossEntropyWithLogits" source_file="/tmpfs/tmp/ipykernel_11149/3922067182.py" source_line=7}
  %broadcast.28 = f32[10000]{0} broadcast(f32[] %constant.27), dimensions={}, metadata={op_type="SparseSoftmaxCrossEntropyWithLogits" op_name="SparseSoftmaxCrossEntropyWithLogits/SparseSoftmaxCrossEntropyWithLogits" source_file="/tmpfs/tmp/ipykernel_11149/3922067182.py" source_line=7}
  %select.38 = f32[10000]{0} select(pred[10000]{0} %and.37, f32[10000]{0} %broadcast.30, f32[10000]{0} %broadcast.28), metadata={op_type="SparseSoftmaxCrossEntropyWithLogits" op_name="SparseSoftmaxCrossEntropyWithLogits/SparseSoftmaxCrossEntropyWithLogits" source_file="/tmpfs/tmp/ipykernel_11149/3922067182.py" source_line=7}
  %broadcast.39 = f32[10000,10]{1,0} broadcast(f32[10000]{0} %select.38), dimensions={0}, metadata={op_type="SparseSoftmaxCrossEntropyWithLogits" op_name="SparseSoftmaxCrossEntropyWithLogits/SparseSoftmaxCrossEntropyWithLogits" source_file="/tmpfs/tmp/ipykernel_11149/3922067182.py" source_line=7}
  %add.40 = f32[10000,10]{1,0} add(f32[10000,10]{1,0} %select.26, f32[10000,10]{1,0} %broadcast.39), metadata={op_type="SparseSoftmaxCrossEntropyWithLogits" op_name="SparseSoftmaxCrossEntropyWithLogits/SparseSoftmaxCrossEntropyWithLogits" source_file="/tmpfs/tmp/ipykernel_11149/3922067182.py" source_line=7}
  %negate.67 = f32[10000,10]{1,0} negate(f32[10000,10]{1,0} %add.40), metadata={op_type="SparseSoftmaxCrossEntropyWithLogits" op_name="SparseSoftmaxCrossEntropyWithLogits/SparseSoftmaxCrossEntropyWithLogits" source_file="/tmpfs/tmp/ipykernel_11149/3922067182.py" source_line=7}
  %constant.63 = f32[] constant(0), metadata={op_type="SparseSoftmaxCrossEntropyWithLogits" op_name="SparseSoftmaxCrossEntropyWithLogits/SparseSoftmaxCrossEntropyWithLogits" source_file="/tmpfs/tmp/ipykernel_11149/3922067182.py" source_line=7}
  %broadcast.64 = f32[10000,10]{1,0} broadcast(f32[] %constant.63), dimensions={}, metadata={op_type="SparseSoftmaxCrossEntropyWithLogits" op_name="SparseSoftmaxCrossEntropyWithLogits/SparseSoftmaxCrossEntropyWithLogits" source_file="/tmpfs/tmp/ipykernel_11149/3922067182.py" source_line=7}
  %compare.65 = pred[10000,10]{1,0} compare(f32[10000,10]{1,0} %add.40, f32[10000,10]{1,0} %broadcast.64), direction=EQ, metadata={op_type="SparseSoftmaxCrossEntropyWithLogits" op_name="SparseSoftmaxCrossEntropyWithLogits/SparseSoftmaxCrossEntropyWithLogits" source_file="/tmpfs/tmp/ipykernel_11149/3922067182.py" source_line=7}
  %constant.61 = f32[] constant(0), metadata={op_type="SparseSoftmaxCrossEntropyWithLogits" op_name="SparseSoftmaxCrossEntropyWithLogits/SparseSoftmaxCrossEntropyWithLogits" source_file="/tmpfs/tmp/ipykernel_11149/3922067182.py" source_line=7}
  %broadcast.62 = f32[10000,10]{1,0} broadcast(f32[] %constant.61), dimensions={}, metadata={op_type="SparseSoftmaxCrossEntropyWithLogits" op_name="SparseSoftmaxCrossEntropyWithLogits/SparseSoftmaxCrossEntropyWithLogits" source_file="/tmpfs/tmp/ipykernel_11149/3922067182.py" source_line=7}
  %arg0.1 = f32[10000,784]{1,0} parameter(0), parameter_replication={false}, metadata={op_name="XLA_Args"}
  %reshape.11 = f32[10000,784]{1,0} reshape(f32[10000,784]{1,0} %arg0.1)
  %reshape.14 = f32[10000,784]{1,0} reshape(f32[10000,784]{1,0} %reshape.11), metadata={op_type="Reshape" op_name="Reshape" source_file="/tmpfs/tmp/ipykernel_11149/494562224.py" source_line=16}
  %arg2.3 = f32[784,10]{1,0} parameter(2), parameter_replication={false}, metadata={op_name="XLA_Args"}
  %dot.15 = f32[10000,10]{1,0} dot(f32[10000,784]{1,0} %reshape.14, f32[784,10]{1,0} %arg2.3), lhs_contracting_dims={1}, rhs_contracting_dims={0}, metadata={op_type="MatMul" op_name="dense/MatMul" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/keras/layers/core/dense.py" source_line=241}
  %transpose.16 = f32[10000,10]{1,0} transpose(f32[10000,10]{1,0} %dot.15), dimensions={0,1}, metadata={op_type="MatMul" op_name="dense/MatMul" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/keras/layers/core/dense.py" source_line=241}
  %arg3.4 = f32[10]{0} parameter(3), parameter_replication={false}, metadata={op_name="XLA_Args"}
  %broadcast.17 = f32[10000,10]{1,0} broadcast(f32[10]{0} %arg3.4), dimensions={1}, metadata={op_type="BiasAdd" op_name="dense/BiasAdd" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/keras/layers/core/dense.py" source_line=252}
  %add.18 = f32[10000,10]{1,0} add(f32[10000,10]{1,0} %transpose.16, f32[10000,10]{1,0} %broadcast.17), metadata={op_type="BiasAdd" op_name="dense/BiasAdd" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/keras/layers/core/dense.py" source_line=252}
  %constant.41 = f32[] constant(-inf), metadata={op_type="SparseSoftmaxCrossEntropyWithLogits" op_name="SparseSoftmaxCrossEntropyWithLogits/SparseSoftmaxCrossEntropyWithLogits" source_file="/tmpfs/tmp/ipykernel_11149/3922067182.py" source_line=7}
  %reduce.46 = f32[10000]{0} reduce(f32[10000,10]{1,0} %add.18, f32[] %constant.41), dimensions={1}, to_apply=%max_float_.42, metadata={op_type="SparseSoftmaxCrossEntropyWithLogits" op_name="SparseSoftmaxCrossEntropyWithLogits/SparseSoftmaxCrossEntropyWithLogits" source_file="/tmpfs/tmp/ipykernel_11149/3922067182.py" source_line=7}
  %broadcast.47 = f32[10000,10]{1,0} broadcast(f32[10000]{0} %reduce.46), dimensions={0}, metadata={op_type="SparseSoftmaxCrossEntropyWithLogits" op_name="SparseSoftmaxCrossEntropyWithLogits/SparseSoftmaxCrossEntropyWithLogits" source_file="/tmpfs/tmp/ipykernel_11149/3922067182.py" source_line=7}
  %subtract.48 = f32[10000,10]{1,0} subtract(f32[10000,10]{1,0} %add.18, f32[10000,10]{1,0} %broadcast.47), metadata={op_type="SparseSoftmaxCrossEntropyWithLogits" op_name="SparseSoftmaxCrossEntropyWithLogits/SparseSoftmaxCrossEntropyWithLogits" source_file="/tmpfs/tmp/ipykernel_11149/3922067182.py" source_line=7}
  %exponential.49 = f32[10000,10]{1,0} exponential(f32[10000,10]{1,0} %subtract.48), metadata={op_type="SparseSoftmaxCrossEntropyWithLogits" op_name="SparseSoftmaxCrossEntropyWithLogits/SparseSoftmaxCrossEntropyWithLogits" source_file="/tmpfs/tmp/ipykernel_11149/3922067182.py" source_line=7}
  %convert.50 = f32[10000,10]{1,0} convert(f32[10000,10]{1,0} %exponential.49), metadata={op_type="SparseSoftmaxCrossEntropyWithLogits" op_name="SparseSoftmaxCrossEntropyWithLogits/SparseSoftmaxCrossEntropyWithLogits" source_file="/tmpfs/tmp/ipykernel_11149/3922067182.py" source_line=7}
  %constant.51 = f32[] constant(0), metadata={op_type="SparseSoftmaxCrossEntropyWithLogits" op_name="SparseSoftmaxCrossEntropyWithLogits/SparseSoftmaxCrossEntropyWithLogits" source_file="/tmpfs/tmp/ipykernel_11149/3922067182.py" source_line=7}
  %reduce.56 = f32[10000]{0} reduce(f32[10000,10]{1,0} %convert.50, f32[] %constant.51), dimensions={1}, to_apply=%add_float_.52, metadata={op_type="SparseSoftmaxCrossEntropyWithLogits" op_name="SparseSoftmaxCrossEntropyWithLogits/SparseSoftmaxCrossEntropyWithLogits" source_file="/tmpfs/tmp/ipykernel_11149/3922067182.py" source_line=7}
  %convert.57 = f32[10000]{0} convert(f32[10000]{0} %reduce.56), metadata={op_type="SparseSoftmaxCrossEntropyWithLogits" op_name="SparseSoftmaxCrossEntropyWithLogits/SparseSoftmaxCrossEntropyWithLogits" source_file="/tmpfs/tmp/ipykernel_11149/3922067182.py" source_line=7}
  %log.58 = f32[10000]{0} log(f32[10000]{0} %convert.57), metadata={op_type="SparseSoftmaxCrossEntropyWithLogits" op_name="SparseSoftmaxCrossEntropyWithLogits/SparseSoftmaxCrossEntropyWithLogits" source_file="/tmpfs/tmp/ipykernel_11149/3922067182.py" source_line=7}
  %broadcast.59 = f32[10000,10]{1,0} broadcast(f32[10000]{0} %log.58), dimensions={0}, metadata={op_type="SparseSoftmaxCrossEntropyWithLogits" op_name="SparseSoftmaxCrossEntropyWithLogits/SparseSoftmaxCrossEntropyWithLogits" source_file="/tmpfs/tmp/ipykernel_11149/3922067182.py" source_line=7}
  %subtract.60 = f32[10000,10]{1,0} subtract(f32[10000,10]{1,0} %subtract.48, f32[10000,10]{1,0} %broadcast.59), metadata={op_type="SparseSoftmaxCrossEntropyWithLogits" op_name="SparseSoftmaxCrossEntropyWithLogits/SparseSoftmaxCrossEntropyWithLogits" source_file="/tmpfs/tmp/ipykernel_11149/3922067182.py" source_line=7}
  %select.66 = f32[10000,10]{1,0} select(pred[10000,10]{1,0} %compare.65, f32[10000,10]{1,0} %broadcast.62, f32[10000,10]{1,0} %subtract.60), metadata={op_type="SparseSoftmaxCrossEntropyWithLogits" op_name="SparseSoftmaxCrossEntropyWithLogits/SparseSoftmaxCrossEntropyWithLogits" source_file="/tmpfs/tmp/ipykernel_11149/3922067182.py" source_line=7}
  %multiply.68 = f32[10000,10]{1,0} multiply(f32[10000,10]{1,0} %negate.67, f32[10000,10]{1,0} %select.66), metadata={op_type="SparseSoftmaxCrossEntropyWithLogits" op_name="SparseSoftmaxCrossEntropyWithLogits/SparseSoftmaxCrossEntropyWithLogits" source_file="/tmpfs/tmp/ipykernel_11149/3922067182.py" source_line=7}
  %convert.70 = f32[10000,10]{1,0} convert(f32[10000,10]{1,0} %multiply.68), metadata={op_type="SparseSoftmaxCrossEntropyWithLogits" op_name="SparseSoftmaxCrossEntropyWithLogits/SparseSoftmaxCrossEntropyWithLogits" source_file="/tmpfs/tmp/ipykernel_11149/3922067182.py" source_line=7}
  %constant.69 = f32[] constant(0), metadata={op_type="SparseSoftmaxCrossEntropyWithLogits" op_name="SparseSoftmaxCrossEntropyWithLogits/SparseSoftmaxCrossEntropyWithLogits" source_file="/tmpfs/tmp/ipykernel_11149/3922067182.py" source_line=7}
  %reduce.75 = f32[10000]{0} reduce(f32[10000,10]{1,0} %convert.70, f32[] %constant.69), dimensions={1}, to_apply=%add_float_.71, metadata={op_type="SparseSoftmaxCrossEntropyWithLogits" op_name="SparseSoftmaxCrossEntropyWithLogits/SparseSoftmaxCrossEntropyWithLogits" source_file="/tmpfs/tmp/ipykernel_11149/3922067182.py" source_line=7}
  %convert.76 = f32[10000]{0} convert(f32[10000]{0} %reduce.75), metadata={op_type="SparseSoftmaxCrossEntropyWithLogits" op_name="SparseSoftmaxCrossEntropyWithLogits/SparseSoftmaxCrossEntropyWithLogits" source_file="/tmpfs/tmp/ipykernel_11149/3922067182.py" source_line=7}
  %convert.80 = f32[10000]{0} convert(f32[10000]{0} %convert.76), metadata={op_type="Mean" op_name="Mean" source_file="/tmpfs/tmp/ipykernel_11149/3922067182.py" source_line=7}
  %constant.81 = f32[] constant(0), metadata={op_type="Mean" op_name="Mean" source_file="/tmpfs/tmp/ipykernel_11149/3922067182.py" source_line=7}
  %convert.82 = f32[] convert(f32[] %constant.81), metadata={op_type="Mean" op_name="Mean" source_file="/tmpfs/tmp/ipykernel_11149/3922067182.py" source_line=7}
  %reduce.87 = f32[] reduce(f32[10000]{0} %convert.80, f32[] %convert.82), dimensions={0}, to_apply=%Mean-reduction.83, metadata={op_type="Mean" op_name="Mean" source_file="/tmpfs/tmp/ipykernel_11149/3922067182.py" source_line=7}
  %constant.88 = s32[] constant(10000), metadata={op_type="Mean" op_name="Mean" source_file="/tmpfs/tmp/ipykernel_11149/3922067182.py" source_line=7}
  %convert.89 = f32[] convert(s32[] %constant.88), metadata={op_type="Mean" op_name="Mean" source_file="/tmpfs/tmp/ipykernel_11149/3922067182.py" source_line=7}
  %divide.90 = f32[] divide(f32[] %reduce.87, f32[] %convert.89), metadata={op_type="Mean" op_name="Mean" source_file="/tmpfs/tmp/ipykernel_11149/3922067182.py" source_line=7}
  %convert.91 = f32[] convert(f32[] %divide.90), metadata={op_type="Mean" op_name="Mean" source_file="/tmpfs/tmp/ipykernel_11149/3922067182.py" source_line=7}
  %constant.92 = f32[] constant(0.0001), metadata={op_type="Mul" op_name="gradient_tape/SparseSoftmaxCrossEntropyWithLogits/mul" source_file="/tmpfs/tmp/ipykernel_11149/3922067182.py" source_line=11}
  %broadcast.93 = f32[10000,1]{1,0} broadcast(f32[] %constant.92), dimensions={}, metadata={op_type="Mul" op_name="gradient_tape/SparseSoftmaxCrossEntropyWithLogits/mul" source_file="/tmpfs/tmp/ipykernel_11149/3922067182.py" source_line=11}
  %constant.109 = s64[] constant(1), metadata={op_type="AddV2" op_name="StatefulPartitionedCall/add" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/keras/optimizers/optimizer_experimental/adam.py" source_line=162}
  %constant.113 = f32[] constant(0.9), metadata={op_type="Pow" op_name="StatefulPartitionedCall/Pow" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/keras/optimizers/optimizer_experimental/adam.py" source_line=163}
  %constant.116 = f32[] constant(1), metadata={op_type="Sub" op_name="StatefulPartitionedCall/sub_1" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/keras/optimizers/optimizer_experimental/adam.py" source_line=170}
  %constant.119 = f32[] constant(0.999), metadata={op_type="Pow" op_name="StatefulPartitionedCall/Pow_1" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/keras/optimizers/optimizer_experimental/adam.py" source_line=164}
  %constant.122 = f32[] constant(1), metadata={op_type="Sub" op_name="StatefulPartitionedCall/sub" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/keras/optimizers/optimizer_experimental/adam.py" source_line=170}
  %constant.126 = s64[] constant(1), metadata={op_type="AddV2" op_name="StatefulPartitionedCall_1/add" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/keras/optimizers/optimizer_experimental/adam.py" source_line=162}
  %constant.130 = f32[] constant(0.9), metadata={op_type="Pow" op_name="StatefulPartitionedCall_1/Pow" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/keras/optimizers/optimizer_experimental/adam.py" source_line=163}
  %constant.133 = f32[] constant(1), metadata={op_type="Sub" op_name="StatefulPartitionedCall_1/sub_1" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/keras/optimizers/optimizer_experimental/adam.py" source_line=170}
  %constant.136 = f32[] constant(0.999), metadata={op_type="Pow" op_name="StatefulPartitionedCall_1/Pow_1" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/keras/optimizers/optimizer_experimental/adam.py" source_line=164}
  %constant.139 = f32[] constant(1), metadata={op_type="Sub" op_name="StatefulPartitionedCall_1/sub" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/keras/optimizers/optimizer_experimental/adam.py" source_line=170}
  %constant.148 = f32[] constant(0.1), metadata={op_type="Mul" op_name="StatefulPartitionedCall/mul_1" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/keras/optimizers/optimizer_experimental/adam.py" source_line=194}
  %constant.156 = f32[] constant(0.001), metadata={op_type="Mul" op_name="StatefulPartitionedCall/mul_2" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/keras/optimizers/optimizer_experimental/adam.py" source_line=195}
  %constant.162 = f32[] constant(1e-07), metadata={op_type="AddV2" op_name="StatefulPartitionedCall/add_1" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/keras/optimizers/optimizer_experimental/adam.py" source_line=200}
  %constant.169 = f32[] constant(0.1), metadata={op_type="Mul" op_name="StatefulPartitionedCall_1/mul_1" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/keras/optimizers/optimizer_experimental/adam.py" source_line=194}
  %constant.177 = f32[] constant(0.001), metadata={op_type="Mul" op_name="StatefulPartitionedCall_1/mul_2" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/keras/optimizers/optimizer_experimental/adam.py" source_line=195}
  %constant.183 = f32[] constant(1e-07), metadata={op_type="AddV2" op_name="StatefulPartitionedCall_1/add_1" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/keras/optimizers/optimizer_experimental/adam.py" source_line=200}
  %arg6.7 = f32[784,10]{1,0} parameter(6), parameter_replication={false}, metadata={op_name="XLA_Args"}
  %constant.94 = f32[] constant(0.0001), metadata={op_type="Mul" op_name="gradient_tape/SparseSoftmaxCrossEntropyWithLogits/mul" source_file="/tmpfs/tmp/ipykernel_11149/3922067182.py" source_line=11}
  %broadcast.95 = f32[10000,1]{1,0} broadcast(f32[] %constant.94), dimensions={}, metadata={op_type="Mul" op_name="gradient_tape/SparseSoftmaxCrossEntropyWithLogits/mul" source_file="/tmpfs/tmp/ipykernel_11149/3922067182.py" source_line=11}
  %reshape.96 = f32[10000]{0} reshape(f32[10000,1]{1,0} %broadcast.95), metadata={op_type="Mul" op_name="gradient_tape/SparseSoftmaxCrossEntropyWithLogits/mul" source_file="/tmpfs/tmp/ipykernel_11149/3922067182.py" source_line=11}
  %broadcast.97 = f32[10000,10]{1,0} broadcast(f32[10000]{0} %reshape.96), dimensions={0}, metadata={op_type="Mul" op_name="gradient_tape/SparseSoftmaxCrossEntropyWithLogits/mul" source_file="/tmpfs/tmp/ipykernel_11149/3922067182.py" source_line=11}
  %broadcast.77 = f32[10000,10]{1,0} broadcast(f32[10000]{0} %convert.57), dimensions={0}, metadata={op_type="SparseSoftmaxCrossEntropyWithLogits" op_name="SparseSoftmaxCrossEntropyWithLogits/SparseSoftmaxCrossEntropyWithLogits" source_file="/tmpfs/tmp/ipykernel_11149/3922067182.py" source_line=7}
  %divide.78 = f32[10000,10]{1,0} divide(f32[10000,10]{1,0} %exponential.49, f32[10000,10]{1,0} %broadcast.77), metadata={op_type="SparseSoftmaxCrossEntropyWithLogits" op_name="SparseSoftmaxCrossEntropyWithLogits/SparseSoftmaxCrossEntropyWithLogits" source_file="/tmpfs/tmp/ipykernel_11149/3922067182.py" source_line=7}
  %subtract.79 = f32[10000,10]{1,0} subtract(f32[10000,10]{1,0} %divide.78, f32[10000,10]{1,0} %add.40), metadata={op_type="SparseSoftmaxCrossEntropyWithLogits" op_name="SparseSoftmaxCrossEntropyWithLogits/SparseSoftmaxCrossEntropyWithLogits" source_file="/tmpfs/tmp/ipykernel_11149/3922067182.py" source_line=7}
  %multiply.98 = f32[10000,10]{1,0} multiply(f32[10000,10]{1,0} %broadcast.97, f32[10000,10]{1,0} %subtract.79), metadata={op_type="Mul" op_name="gradient_tape/SparseSoftmaxCrossEntropyWithLogits/mul" source_file="/tmpfs/tmp/ipykernel_11149/3922067182.py" source_line=11}
  %dot.105 = f32[784,10]{1,0} dot(f32[10000,784]{1,0} %reshape.14, f32[10000,10]{1,0} %multiply.98), lhs_contracting_dims={0}, rhs_contracting_dims={0}, metadata={op_type="MatMul" op_name="gradient_tape/dense/MatMul/MatMul" source_file="/tmpfs/tmp/ipykernel_11149/3922067182.py" source_line=11}
  %transpose.106 = f32[784,10]{1,0} transpose(f32[784,10]{1,0} %dot.105), dimensions={0,1}, metadata={op_type="MatMul" op_name="gradient_tape/dense/MatMul/MatMul" source_file="/tmpfs/tmp/ipykernel_11149/3922067182.py" source_line=11}
  %subtract.147 = f32[784,10]{1,0} subtract(f32[784,10]{1,0} %transpose.106, f32[784,10]{1,0} %arg6.7), metadata={op_type="Sub" op_name="StatefulPartitionedCall/sub_2" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/keras/optimizers/optimizer_experimental/adam.py" source_line=194}
  %constant.149 = f32[] constant(0.1), metadata={op_type="Mul" op_name="StatefulPartitionedCall/mul_1" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/keras/optimizers/optimizer_experimental/adam.py" source_line=194}
  %broadcast.150 = f32[784,10]{1,0} broadcast(f32[] %constant.149), dimensions={}, metadata={op_type="Mul" op_name="StatefulPartitionedCall/mul_1" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/keras/optimizers/optimizer_experimental/adam.py" source_line=194}
  %multiply.151 = f32[784,10]{1,0} multiply(f32[784,10]{1,0} %subtract.147, f32[784,10]{1,0} %broadcast.150), metadata={op_type="Mul" op_name="StatefulPartitionedCall/mul_1" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/keras/optimizers/optimizer_experimental/adam.py" source_line=194}
  %add.152 = f32[784,10]{1,0} add(f32[784,10]{1,0} %arg6.7, f32[784,10]{1,0} %multiply.151), metadata={op_type="AssignAddVariableOp" op_name="StatefulPartitionedCall/AssignAddVariableOp" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/keras/optimizers/optimizer_experimental/adam.py" source_line=194}
  %arg5.6 = f32[] parameter(5), parameter_replication={false}, metadata={op_name="XLA_Args"}
  %constant.123 = f32[] constant(1), metadata={op_type="Sub" op_name="StatefulPartitionedCall/sub" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/keras/optimizers/optimizer_experimental/adam.py" source_line=170}
  %constant.120 = f32[] constant(0.999), metadata={op_type="Pow" op_name="StatefulPartitionedCall/Pow_1" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/keras/optimizers/optimizer_experimental/adam.py" source_line=164}
  %arg4.5 = s64[] parameter(4), parameter_replication={false}, metadata={op_name="XLA_Args"}
  %constant.110 = s64[] constant(1), metadata={op_type="AddV2" op_name="StatefulPartitionedCall/add" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/keras/optimizers/optimizer_experimental/adam.py" source_line=162}
  %add.111 = s64[] add(s64[] %arg4.5, s64[] %constant.110), metadata={op_type="AddV2" op_name="StatefulPartitionedCall/add" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/keras/optimizers/optimizer_experimental/adam.py" source_line=162}
  %convert.112 = f32[] convert(s64[] %add.111), metadata={op_type="Cast" op_name="StatefulPartitionedCall/Cast" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/keras/optimizers/optimizer_experimental/adam.py" source_line=162}
  %power.121 = f32[] power(f32[] %constant.120, f32[] %convert.112), metadata={op_type="Pow" op_name="StatefulPartitionedCall/Pow_1" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/keras/optimizers/optimizer_experimental/adam.py" source_line=164}
  %subtract.124 = f32[] subtract(f32[] %constant.123, f32[] %power.121), metadata={op_type="Sub" op_name="StatefulPartitionedCall/sub" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/keras/optimizers/optimizer_experimental/adam.py" source_line=170}
  %sqrt.125 = f32[] sqrt(f32[] %subtract.124), metadata={op_type="Sqrt" op_name="Sqrt.StatefulPartitionedCall/Sqrt"}
  %multiply.143 = f32[] multiply(f32[] %arg5.6, f32[] %sqrt.125), metadata={op_type="Mul" op_name="StatefulPartitionedCall/mul" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/keras/optimizers/optimizer_experimental/adam.py" source_line=170}
  %constant.117 = f32[] constant(1), metadata={op_type="Sub" op_name="StatefulPartitionedCall/sub_1" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/keras/optimizers/optimizer_experimental/adam.py" source_line=170}
  %constant.114 = f32[] constant(0.9), metadata={op_type="Pow" op_name="StatefulPartitionedCall/Pow" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/keras/optimizers/optimizer_experimental/adam.py" source_line=163}
  %power.115 = f32[] power(f32[] %constant.114, f32[] %convert.112), metadata={op_type="Pow" op_name="StatefulPartitionedCall/Pow" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/keras/optimizers/optimizer_experimental/adam.py" source_line=163}
  %subtract.118 = f32[] subtract(f32[] %constant.117, f32[] %power.115), metadata={op_type="Sub" op_name="StatefulPartitionedCall/sub_1" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/keras/optimizers/optimizer_experimental/adam.py" source_line=170}
  %divide.144 = f32[] divide(f32[] %multiply.143, f32[] %subtract.118), metadata={op_type="RealDiv" op_name="StatefulPartitionedCall/truediv" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/keras/optimizers/optimizer_experimental/adam.py" source_line=170}
  %broadcast.153 = f32[784,10]{1,0} broadcast(f32[] %divide.144), dimensions={}, metadata={op_type="Mul" op_name="StatefulPartitionedCall/mul_3" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/keras/optimizers/optimizer_experimental/adam.py" source_line=200}
  %multiply.154 = f32[784,10]{1,0} multiply(f32[784,10]{1,0} %add.152, f32[784,10]{1,0} %broadcast.153), metadata={op_type="Mul" op_name="StatefulPartitionedCall/mul_3" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/keras/optimizers/optimizer_experimental/adam.py" source_line=200}
  %arg7.8 = f32[784,10]{1,0} parameter(7), parameter_replication={false}, metadata={op_name="XLA_Args"}
  %multiply.107 = f32[784,10]{1,0} multiply(f32[784,10]{1,0} %transpose.106, f32[784,10]{1,0} %transpose.106), metadata={op_type="Square" op_name="StatefulPartitionedCall/Square" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/keras/optimizers/optimizer_experimental/adam.py" source_line=195}
  %subtract.155 = f32[784,10]{1,0} subtract(f32[784,10]{1,0} %multiply.107, f32[784,10]{1,0} %arg7.8), metadata={op_type="Sub" op_name="StatefulPartitionedCall/sub_3" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/keras/optimizers/optimizer_experimental/adam.py" source_line=195}
  %constant.157 = f32[] constant(0.001), metadata={op_type="Mul" op_name="StatefulPartitionedCall/mul_2" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/keras/optimizers/optimizer_experimental/adam.py" source_line=195}
  %broadcast.158 = f32[784,10]{1,0} broadcast(f32[] %constant.157), dimensions={}, metadata={op_type="Mul" op_name="StatefulPartitionedCall/mul_2" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/keras/optimizers/optimizer_experimental/adam.py" source_line=195}
  %multiply.159 = f32[784,10]{1,0} multiply(f32[784,10]{1,0} %subtract.155, f32[784,10]{1,0} %broadcast.158), metadata={op_type="Mul" op_name="StatefulPartitionedCall/mul_2" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/keras/optimizers/optimizer_experimental/adam.py" source_line=195}
  %add.160 = f32[784,10]{1,0} add(f32[784,10]{1,0} %arg7.8, f32[784,10]{1,0} %multiply.159), metadata={op_type="AssignAddVariableOp" op_name="StatefulPartitionedCall/AssignAddVariableOp_1" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/keras/optimizers/optimizer_experimental/adam.py" source_line=195}
  %sqrt.161 = f32[784,10]{1,0} sqrt(f32[784,10]{1,0} %add.160), metadata={op_type="Sqrt" op_name="Sqrt_1.StatefulPartitionedCall/Sqrt_1"}
  %constant.163 = f32[] constant(1e-07), metadata={op_type="AddV2" op_name="StatefulPartitionedCall/add_1" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/keras/optimizers/optimizer_experimental/adam.py" source_line=200}
  %broadcast.164 = f32[784,10]{1,0} broadcast(f32[] %constant.163), dimensions={}, metadata={op_type="AddV2" op_name="StatefulPartitionedCall/add_1" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/keras/optimizers/optimizer_experimental/adam.py" source_line=200}
  %add.165 = f32[784,10]{1,0} add(f32[784,10]{1,0} %sqrt.161, f32[784,10]{1,0} %broadcast.164), metadata={op_type="AddV2" op_name="StatefulPartitionedCall/add_1" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/keras/optimizers/optimizer_experimental/adam.py" source_line=200}
  %divide.166 = f32[784,10]{1,0} divide(f32[784,10]{1,0} %multiply.154, f32[784,10]{1,0} %add.165), metadata={op_type="RealDiv" op_name="StatefulPartitionedCall/truediv_1" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/keras/optimizers/optimizer_experimental/adam.py" source_line=200}
  %subtract.167 = f32[784,10]{1,0} subtract(f32[784,10]{1,0} %arg2.3, f32[784,10]{1,0} %divide.166), metadata={op_type="AssignSubVariableOp" op_name="StatefulPartitionedCall/AssignSubVariableOp" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/keras/optimizers/optimizer_experimental/adam.py" source_line=200}
  %reshape.191 = f32[784,10]{1,0} reshape(f32[784,10]{1,0} %subtract.167), metadata={op_name="XLA_Retvals"}
  %copy.192 = f32[784,10]{1,0} copy(f32[784,10]{1,0} %reshape.191), metadata={op_name="XLA_Retvals"}
  %arg8.9 = f32[10]{0} parameter(8), parameter_replication={false}, metadata={op_name="XLA_Args"}
  %constant.99 = f32[] constant(-0), metadata={op_type="BiasAddGrad" op_name="gradient_tape/dense/BiasAdd/BiasAddGrad" source_file="/tmpfs/tmp/ipykernel_11149/3922067182.py" source_line=11}
  %reduce.104 = f32[10]{0} reduce(f32[10000,10]{1,0} %multiply.98, f32[] %constant.99), dimensions={0}, to_apply=%region_0.100, metadata={op_type="BiasAddGrad" op_name="gradient_tape/dense/BiasAdd/BiasAddGrad"}
  %subtract.168 = f32[10]{0} subtract(f32[10]{0} %reduce.104, f32[10]{0} %arg8.9), metadata={op_type="Sub" op_name="StatefulPartitionedCall_1/sub_2" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/keras/optimizers/optimizer_experimental/adam.py" source_line=194}
  %constant.170 = f32[] constant(0.1), metadata={op_type="Mul" op_name="StatefulPartitionedCall_1/mul_1" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/keras/optimizers/optimizer_experimental/adam.py" source_line=194}
  %broadcast.171 = f32[10]{0} broadcast(f32[] %constant.170), dimensions={}, metadata={op_type="Mul" op_name="StatefulPartitionedCall_1/mul_1" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/keras/optimizers/optimizer_experimental/adam.py" source_line=194}
  %multiply.172 = f32[10]{0} multiply(f32[10]{0} %subtract.168, f32[10]{0} %broadcast.171), metadata={op_type="Mul" op_name="StatefulPartitionedCall_1/mul_1" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/keras/optimizers/optimizer_experimental/adam.py" source_line=194}
  %add.173 = f32[10]{0} add(f32[10]{0} %arg8.9, f32[10]{0} %multiply.172), metadata={op_type="AssignAddVariableOp" op_name="StatefulPartitionedCall_1/AssignAddVariableOp" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/keras/optimizers/optimizer_experimental/adam.py" source_line=194}
  %constant.140 = f32[] constant(1), metadata={op_type="Sub" op_name="StatefulPartitionedCall_1/sub" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/keras/optimizers/optimizer_experimental/adam.py" source_line=170}
  %constant.137 = f32[] constant(0.999), metadata={op_type="Pow" op_name="StatefulPartitionedCall_1/Pow_1" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/keras/optimizers/optimizer_experimental/adam.py" source_line=164}
  %constant.127 = s64[] constant(1), metadata={op_type="AddV2" op_name="StatefulPartitionedCall_1/add" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/keras/optimizers/optimizer_experimental/adam.py" source_line=162}
  %add.128 = s64[] add(s64[] %arg4.5, s64[] %constant.127), metadata={op_type="AddV2" op_name="StatefulPartitionedCall_1/add" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/keras/optimizers/optimizer_experimental/adam.py" source_line=162}
  %convert.129 = f32[] convert(s64[] %add.128), metadata={op_type="Cast" op_name="StatefulPartitionedCall_1/Cast" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/keras/optimizers/optimizer_experimental/adam.py" source_line=162}
  %power.138 = f32[] power(f32[] %constant.137, f32[] %convert.129), metadata={op_type="Pow" op_name="StatefulPartitionedCall_1/Pow_1" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/keras/optimizers/optimizer_experimental/adam.py" source_line=164}
  %subtract.141 = f32[] subtract(f32[] %constant.140, f32[] %power.138), metadata={op_type="Sub" op_name="StatefulPartitionedCall_1/sub" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/keras/optimizers/optimizer_experimental/adam.py" source_line=170}
  %sqrt.142 = f32[] sqrt(f32[] %subtract.141), metadata={op_type="Sqrt" op_name="Sqrt.StatefulPartitionedCall_1/Sqrt"}
  %multiply.145 = f32[] multiply(f32[] %arg5.6, f32[] %sqrt.142), metadata={op_type="Mul" op_name="StatefulPartitionedCall_1/mul" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/keras/optimizers/optimizer_experimental/adam.py" source_line=170}
  %constant.134 = f32[] constant(1), metadata={op_type="Sub" op_name="StatefulPartitionedCall_1/sub_1" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/keras/optimizers/optimizer_experimental/adam.py" source_line=170}
  %constant.131 = f32[] constant(0.9), metadata={op_type="Pow" op_name="StatefulPartitionedCall_1/Pow" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/keras/optimizers/optimizer_experimental/adam.py" source_line=163}
  %power.132 = f32[] power(f32[] %constant.131, f32[] %convert.129), metadata={op_type="Pow" op_name="StatefulPartitionedCall_1/Pow" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/keras/optimizers/optimizer_experimental/adam.py" source_line=163}
  %subtract.135 = f32[] subtract(f32[] %constant.134, f32[] %power.132), metadata={op_type="Sub" op_name="StatefulPartitionedCall_1/sub_1" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/keras/optimizers/optimizer_experimental/adam.py" source_line=170}
  %divide.146 = f32[] divide(f32[] %multiply.145, f32[] %subtract.135), metadata={op_type="RealDiv" op_name="StatefulPartitionedCall_1/truediv" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/keras/optimizers/optimizer_experimental/adam.py" source_line=170}
  %broadcast.174 = f32[10]{0} broadcast(f32[] %divide.146), dimensions={}, metadata={op_type="Mul" op_name="StatefulPartitionedCall_1/mul_3" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/keras/optimizers/optimizer_experimental/adam.py" source_line=200}
  %multiply.175 = f32[10]{0} multiply(f32[10]{0} %add.173, f32[10]{0} %broadcast.174), metadata={op_type="Mul" op_name="StatefulPartitionedCall_1/mul_3" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/keras/optimizers/optimizer_experimental/adam.py" source_line=200}
  %arg9.10 = f32[10]{0} parameter(9), parameter_replication={false}, metadata={op_name="XLA_Args"}
  %multiply.108 = f32[10]{0} multiply(f32[10]{0} %reduce.104, f32[10]{0} %reduce.104), metadata={op_type="Square" op_name="StatefulPartitionedCall_1/Square" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/keras/optimizers/optimizer_experimental/adam.py" source_line=195}
  %subtract.176 = f32[10]{0} subtract(f32[10]{0} %multiply.108, f32[10]{0} %arg9.10), metadata={op_type="Sub" op_name="StatefulPartitionedCall_1/sub_3" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/keras/optimizers/optimizer_experimental/adam.py" source_line=195}
  %constant.178 = f32[] constant(0.001), metadata={op_type="Mul" op_name="StatefulPartitionedCall_1/mul_2" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/keras/optimizers/optimizer_experimental/adam.py" source_line=195}
  %broadcast.179 = f32[10]{0} broadcast(f32[] %constant.178), dimensions={}, metadata={op_type="Mul" op_name="StatefulPartitionedCall_1/mul_2" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/keras/optimizers/optimizer_experimental/adam.py" source_line=195}
  %multiply.180 = f32[10]{0} multiply(f32[10]{0} %subtract.176, f32[10]{0} %broadcast.179), metadata={op_type="Mul" op_name="StatefulPartitionedCall_1/mul_2" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/keras/optimizers/optimizer_experimental/adam.py" source_line=195}
  %add.181 = f32[10]{0} add(f32[10]{0} %arg9.10, f32[10]{0} %multiply.180), metadata={op_type="AssignAddVariableOp" op_name="StatefulPartitionedCall_1/AssignAddVariableOp_1" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/keras/optimizers/optimizer_experimental/adam.py" source_line=195}
  %sqrt.182 = f32[10]{0} sqrt(f32[10]{0} %add.181), metadata={op_type="Sqrt" op_name="Sqrt_1.StatefulPartitionedCall_1/Sqrt_1"}
  %constant.184 = f32[] constant(1e-07), metadata={op_type="AddV2" op_name="StatefulPartitionedCall_1/add_1" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/keras/optimizers/optimizer_experimental/adam.py" source_line=200}
  %broadcast.185 = f32[10]{0} broadcast(f32[] %constant.184), dimensions={}, metadata={op_type="AddV2" op_name="StatefulPartitionedCall_1/add_1" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/keras/optimizers/optimizer_experimental/adam.py" source_line=200}
  %add.186 = f32[10]{0} add(f32[10]{0} %sqrt.182, f32[10]{0} %broadcast.185), metadata={op_type="AddV2" op_name="StatefulPartitionedCall_1/add_1" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/keras/optimizers/optimizer_experimental/adam.py" source_line=200}
  %divide.187 = f32[10]{0} divide(f32[10]{0} %multiply.175, f32[10]{0} %add.186), metadata={op_type="RealDiv" op_name="StatefulPartitionedCall_1/truediv_1" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/keras/optimizers/optimizer_experimental/adam.py" source_line=200}
  %subtract.188 = f32[10]{0} subtract(f32[10]{0} %arg3.4, f32[10]{0} %divide.187), metadata={op_type="AssignSubVariableOp" op_name="StatefulPartitionedCall_1/AssignSubVariableOp" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/keras/optimizers/optimizer_experimental/adam.py" source_line=200}
  %reshape.193 = f32[10]{0} reshape(f32[10]{0} %subtract.188), metadata={op_name="XLA_Retvals"}
  %copy.194 = f32[10]{0} copy(f32[10]{0} %reshape.193), metadata={op_name="XLA_Retvals"}
  %constant.189 = s64[] constant(1), metadata={op_type="AssignAddVariableOp" op_name="AssignAddVariableOp" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/keras/optimizers/optimizer_experimental/optimizer.py" source_line=1236}
  %add.190 = s64[] add(s64[] %arg4.5, s64[] %constant.189), metadata={op_type="AssignAddVariableOp" op_name="AssignAddVariableOp" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/keras/optimizers/optimizer_experimental/optimizer.py" source_line=1236}
  %reshape.195 = s64[] reshape(s64[] %add.190), metadata={op_name="XLA_Retvals"}
  %copy.196 = s64[] copy(s64[] %reshape.195), metadata={op_name="XLA_Retvals"}
  %reshape.197 = f32[784,10]{1,0} reshape(f32[784,10]{1,0} %add.152), metadata={op_name="XLA_Retvals"}
  %copy.198 = f32[784,10]{1,0} copy(f32[784,10]{1,0} %reshape.197), metadata={op_name="XLA_Retvals"}
  %reshape.199 = f32[784,10]{1,0} reshape(f32[784,10]{1,0} %add.160), metadata={op_name="XLA_Retvals"}
  %copy.200 = f32[784,10]{1,0} copy(f32[784,10]{1,0} %reshape.199), metadata={op_name="XLA_Retvals"}
  %reshape.201 = f32[10]{0} reshape(f32[10]{0} %add.173), metadata={op_name="XLA_Retvals"}
  %copy.202 = f32[10]{0} copy(f32[10]{0} %reshape.201), metadata={op_name="XLA_Retvals"}
  %reshape.203 = f32[10]{0} reshape(f32[10]{0} %add.181), metadata={op_name="XLA_Retvals"}
  %copy.204 = f32[10]{0} copy(f32[10]{0} %reshape.203), metadata={op_name="XLA_Retvals"}
  ROOT %tuple.205 = (f32[784,10]{1,0}, f32[10]{0}, s64[], f32[784,10]{1,0}, f32[784,10]{1,0}, /*index=5*/f32[10]{0}, f32[10]{0}) tuple(f32[784,10]{1,0} %copy.192, f32[10]{0} %copy.194, s64[] %copy.196, f32[784,10]{1,0} %copy.198, f32[784,10]{1,0} %copy.200, /*index=5*/f32[10]{0} %copy.202, f32[10]{0} %copy.204), metadata={op_name="XLA_Retvals"}
}