Image classification with Model Garden


This tutorial fine-tunes a Residual Network (ResNet) from the TensorFlow Model Garden package (tensorflow-models) to classify images in the CIFAR-10 dataset.

Model Garden contains a collection of state-of-the-art vision models, implemented with TensorFlow's high-level APIs. The implementations demonstrate best practices for modeling, letting users take full advantage of TensorFlow for their research and product development.

This tutorial uses ResNet-18, a convolutional neural network with 18 layers from the ResNet family of state-of-the-art image classifiers.
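The "18" in ResNet-18 counts the layers that carry learned weights. The sketch below assumes the standard ResNet-18 layout from the original ResNet paper (a stem convolution, four stages of two basic residual blocks with two 3x3 convolutions each, and a final fully connected classifier). It is plain Python arithmetic for illustration, not Model Garden code.

```python
# Assumed standard ResNet-18 layout: four stages of "basic" residual
# blocks, two blocks per stage, each block holding two 3x3 convolutions.
STAGES = [
    # (output filters, number of basic blocks)
    (64, 2),
    (128, 2),
    (256, 2),
    (512, 2),
]
CONVS_PER_BASIC_BLOCK = 2


def count_weighted_layers(stages, convs_per_block=CONVS_PER_BASIC_BLOCK):
  """Counts the weighted layers that give a ResNet its depth name.

  Counts the stem convolution, the convolutions inside the residual blocks,
  and the final fully connected layer. Pooling, batch norm, and 1x1
  projection shortcuts are conventionally left out of the count.
  """
  stem_conv = 1
  block_convs = sum(blocks * convs_per_block for _, blocks in stages)
  classifier = 1
  return stem_conv + block_convs + classifier


print(count_weighted_layers(STAGES))  # 18
```

The same count with three, four, six, and three blocks per stage gives 34, which is where the name ResNet-34 comes from.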

This tutorial demonstrates how to:

  1. Use models from the TensorFlow Models package.
  2. Fine-tune a pre-built ResNet for image classification.
  3. Export the tuned ResNet model.

Setup

Install and import the necessary modules. This tutorial uses the tf-models-official release of Model Garden.

pip uninstall -y opencv-python
pip install -U -q "tensorflow>=2.9.0" "tf-models-official"

Import TensorFlow, TensorFlow Datasets, and a few helper libraries.

import pprint
import tempfile

from IPython import display
import matplotlib.pyplot as plt

import tensorflow as tf
import tensorflow_datasets as tfds

The tensorflow_models package contains the ResNet vision model, and the official.vision.serving module contains the function to save and export the tuned model.

import tensorflow_models as tfm

# These are not in the tfm public API for v2.9. They will be available in v2.10
from official.vision.serving import export_saved_model_lib
import official.core.train_lib

Configure the ResNet-18 model for the CIFAR-10 dataset

The CIFAR-10 dataset contains 60,000 color images in 10 mutually exclusive classes, with 6,000 images in each class.

In Model Garden, the collections of parameters that define a model are called configs. Model Garden can create a config based on a known set of parameters via a factory.

Use the resnet_imagenet factory configuration, as defined by tfm.vision.configs.image_classification.image_classification_imagenet. The configuration is set up to train ResNet to converge on ImageNet.

exp_config = tfm.core.exp_factory.get_exp_config('resnet_imagenet')
tfds_name = 'cifar10'
ds_info = tfds.builder(tfds_name).info
ds_info
tfds.core.DatasetInfo(
    name='cifar10',
    full_name='cifar10/3.0.2',
    description="""
    The CIFAR-10 dataset consists of 60000 32x32 colour images in 10 classes, with 6000 images per class. There are 50000 training images and 10000 test images.
    """,
    homepage='https://www.cs.toronto.edu/~kriz/cifar.html',
    data_path='gs://tensorflow-datasets/datasets/cifar10/3.0.2',
    file_format=tfrecord,
    download_size=162.17 MiB,
    dataset_size=132.40 MiB,
    features=FeaturesDict({
        'id': Text(shape=(), dtype=tf.string),
        'image': Image(shape=(32, 32, 3), dtype=tf.uint8),
        'label': ClassLabel(shape=(), dtype=tf.int64, num_classes=10),
    }),
    supervised_keys=('image', 'label'),
    disable_shuffling=False,
    splits={
        'test': <SplitInfo num_examples=10000, num_shards=1>,
        'train': <SplitInfo num_examples=50000, num_shards=1>,
    },
    citation="""@TECHREPORT{Krizhevsky09learningmultiple,
        author = {Alex Krizhevsky},
        title = {Learning multiple layers of features from tiny images},
        institution = {},
        year = {2009}
    }""",
)

Adjust the model and dataset configurations so that they work with CIFAR-10 (cifar10).

# Configure model
exp_config.task.model.num_classes = 10
exp_config.task.model.input_size = list(ds_info.features["image"].shape)
exp_config.task.model.backbone.resnet.model_id = 18

# Configure training and testing data
batch_size = 128

exp_config.task.train_data.input_path = ''
exp_config.task.train_data.tfds_name = tfds_name
exp_config.task.train_data.tfds_split = 'train'
exp_config.task.train_data.global_batch_size = batch_size

exp_config.task.validation_data.input_path = ''
exp_config.task.validation_data.tfds_name = tfds_name
exp_config.task.validation_data.tfds_split = 'test'
exp_config.task.validation_data.global_batch_size = batch_size
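Each assignment above sets one field in a nested tree of configuration values, the same tree that `exp_config.as_dict()` prints later. As a rough plain-Python analogy, assuming the config were an ordinary nested dict, the same overrides can be written as dotted paths. `set_by_path` is a hypothetical helper, and the default values in the stand-in tree are placeholders rather than Model Garden's defaults.

```python
def set_by_path(config, dotted_path, value):
  """Sets config['a']['b']['c'] = value for dotted_path 'a.b.c'.

  Raises KeyError for a missing intermediate or leaf key, so a typo in a
  field name fails loudly instead of silently adding a new field.
  """
  *parents, leaf = dotted_path.split('.')
  node = config
  for key in parents:
    node = node[key]
  if leaf not in node:
    raise KeyError(f'unknown config field: {dotted_path}')
  node[leaf] = value


# Tiny stand-in for the real config tree: field names follow the printed
# config, values are placeholders.
config = {
    'task': {
        'model': {'num_classes': 1000, 'input_size': [None, None, 3]},
        'train_data': {'tfds_name': '', 'tfds_split': '', 'global_batch_size': 0},
    }
}

overrides = {
    'task.model.num_classes': 10,
    'task.model.input_size': [32, 32, 3],
    'task.train_data.tfds_name': 'cifar10',
    'task.train_data.tfds_split': 'train',
    'task.train_data.global_batch_size': 128,
}
for path, value in overrides.items():
  set_by_path(config, path, value)

print(config['task']['model'])
```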

Adjust the trainer configuration.

logical_device_names = [logical_device.name for logical_device in tf.config.list_logical_devices()]

if 'GPU' in ''.join(logical_device_names):
  print('This may be broken in Colab.')
  device = 'GPU'
elif 'TPU' in ''.join(logical_device_names):
  print('This may be broken in Colab.')
  device = 'TPU'
else:
  print('Running on CPU is slow, so only train for a few steps.')
  device = 'CPU'

if device=='CPU':
  train_steps = 20
  exp_config.trainer.steps_per_loop = 5
else:
  train_steps = 5000
  exp_config.trainer.steps_per_loop = 100

exp_config.trainer.summary_interval = 100
exp_config.trainer.checkpoint_interval = train_steps
exp_config.trainer.validation_interval = 1000
exp_config.trainer.validation_steps = ds_info.splits['test'].num_examples // batch_size
exp_config.trainer.train_steps = train_steps
exp_config.trainer.optimizer_config.learning_rate.type = 'cosine'
exp_config.trainer.optimizer_config.learning_rate.cosine.decay_steps = train_steps
exp_config.trainer.optimizer_config.learning_rate.cosine.initial_learning_rate = 0.1
exp_config.trainer.optimizer_config.warmup.linear.warmup_steps = 100
This may be broken in Colab.
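With the accelerator settings above (batch size 128, 5,000 training steps), a little arithmetic shows how much data training and evaluation actually touch. This sketch assumes CIFAR-10's 50,000/10,000 train/test split from `ds_info`, and that batching drops the final partial batch (the data configs printed in the next step have `drop_remainder` set to `True`). The constants are copied from the cell above; on the CPU branch, training runs only 20 steps.

```python
# Values copied from the configuration above (GPU/TPU branch).
TRAIN_EXAMPLES = 50_000  # ds_info.splits['train'].num_examples
TEST_EXAMPLES = 10_000   # ds_info.splits['test'].num_examples
BATCH_SIZE = 128
TRAIN_STEPS = 5000

# Integer division matches drop_remainder=True: the last partial batch of
# the test split is never evaluated.
validation_steps = TEST_EXAMPLES // BATCH_SIZE
evaluated_examples = validation_steps * BATCH_SIZE
skipped_examples = TEST_EXAMPLES - evaluated_examples

# Total training examples processed, expressed in passes over the train split.
examples_seen = TRAIN_STEPS * BATCH_SIZE
epochs = examples_seen / TRAIN_EXAMPLES

print(validation_steps, evaluated_examples, skipped_examples)  # 78 9984 16
print(examples_seen, epochs)  # 640000 12.8
```

So each evaluation covers 9,984 of the 10,000 test images, and the full run makes about 12.8 passes over the training set.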

Print the modified configuration.

pprint.pprint(exp_config.as_dict())

display.Javascript("google.colab.output.setIframeHeight('300px');")
{'runtime': {'all_reduce_alg': None,
             'batchnorm_spatial_persistent': False,
             'dataset_num_private_threads': None,
             'default_shard_dim': -1,
             'distribution_strategy': 'mirrored',
             'enable_xla': True,
             'gpu_thread_mode': None,
             'loss_scale': None,
             'mixed_precision_dtype': None,
             'num_cores_per_replica': 1,
             'num_gpus': 0,
             'num_packs': 1,
             'per_gpu_thread_count': 0,
             'run_eagerly': False,
             'task_index': -1,
             'tpu': None,
             'tpu_enable_xla_dynamic_padder': None,
             'worker_hosts': None},
 'task': {'differential_privacy_config': None,
          'evaluation': {'top_k': 5},
          'init_checkpoint': None,
          'init_checkpoint_modules': 'all',
          'losses': {'l2_weight_decay': 0.0001,
                     'label_smoothing': 0.0,
                     'loss_weight': 1.0,
                     'one_hot': True,
                     'soft_labels': False},
          'model': {'add_head_batch_norm': False,
                    'backbone': {'resnet': {'bn_trainable': True,
                                            'depth_multiplier': 1.0,
                                            'model_id': 18,
                                            'replace_stem_max_pool': False,
                                            'resnetd_shortcut': False,
                                            'scale_stem': True,
                                            'se_ratio': 0.0,
                                            'stem_type': 'v0',
                                            'stochastic_depth_drop_rate': 0.0},
                                 'type': 'resnet'},
                    'dropout_rate': 0.0,
                    'input_size': [32, 32, 3],
                    'kernel_initializer': 'random_uniform',
                    'norm_activation': {'activation': 'relu',
                                        'norm_epsilon': 1e-05,
                                        'norm_momentum': 0.9,
                                        'use_sync_bn': False},
                    'num_classes': 10},
          'model_output_keys': [],
          'name': None,
          'train_data': {'aug_policy': None,
                         'aug_rand_hflip': True,
                         'aug_type': None,
                         'block_length': 1,
                         'cache': False,
                         'color_jitter': 0.0,
                         'cycle_length': 10,
                         'decode_jpeg_only': True,
                         'decoder': {'simple_decoder': {'mask_binarize_threshold': None,
                                                        'regenerate_source_id': False},
                                     'type': 'simple_decoder'},
                         'deterministic': None,
                         'drop_remainder': True,
                         'dtype': 'float32',
                         'enable_tf_data_service': False,
                         'file_type': 'tfrecord',
                         'global_batch_size': 128,
                         'image_field_key': 'image/encoded',
                         'input_path': '',
                         'is_multilabel': False,
                         'is_training': True,
                         'label_field_key': 'image/class/label',
                         'mixup_and_cutmix': None,
                         'randaug_magnitude': 10,
                         'random_erasing': None,
                         'seed': None,
                         'sharding': True,
                         'shuffle_buffer_size': 10000,
                         'tf_data_service_address': None,
                         'tf_data_service_job_name': None,
                         'tfds_as_supervised': False,
                         'tfds_data_dir': '',
                         'tfds_name': 'cifar10',
                         'tfds_skip_decoding_feature': '',
                         'tfds_split': 'train'},
          'validation_data': {'aug_policy': None,
                              'aug_rand_hflip': True,
                              'aug_type': None,
                              'block_length': 1,
                              'cache': False,
                              'color_jitter': 0.0,
                              'cycle_length': 10,
                              'decode_jpeg_only': True,
                              'decoder': {'simple_decoder': {'mask_binarize_threshold': None,
                                                             'regenerate_source_id': False},
                                          'type': 'simple_decoder'},
                              'deterministic': None,
                              'drop_remainder': True,
                              'dtype': 'float32',
                              'enable_tf_data_service': False,
                              'file_type': 'tfrecord',
                              'global_batch_size': 128,
                              'image_field_key': 'image/encoded',
                              'input_path': '',
                              'is_multilabel': False,
                              'is_training': False,
                              'label_field_key': 'image/class/label',
                              'mixup_and_cutmix': None,
                              'randaug_magnitude': 10,
                              'random_erasing': None,
                              'seed': None,
                              'sharding': True,
                              'shuffle_buffer_size': 10000,
                              'tf_data_service_address': None,
                              'tf_data_service_job_name': None,
                              'tfds_as_supervised': False,
                              'tfds_data_dir': '',
                              'tfds_name': 'cifar10',
                              'tfds_skip_decoding_feature': '',
                              'tfds_split': 'test'}},
 'trainer': {'allow_tpu_summary': False,
             'best_checkpoint_eval_metric': '',
             'best_checkpoint_export_subdir': '',
             'best_checkpoint_metric_comp': 'higher',
             'checkpoint_interval': 5000,
             'continuous_eval_timeout': 3600,
             'eval_tf_function': True,
             'eval_tf_while_loop': False,
             'loss_upper_bound': 1000000.0,
             'max_to_keep': 5,
             'optimizer_config': {'ema': None,
                                  'learning_rate': {'cosine': {'alpha': 0.0,
                                                               'decay_steps': 5000,
                                                               'initial_learning_rate': 0.1,
                                                               'name': 'CosineDecay',
                                                               'offset': 0},
                                                    'type': 'cosine'},
                                  'optimizer': {'sgd': {'clipnorm': None,
                                                        'clipvalue': None,
                                                        'decay': 0.0,
                                                        'global_clipnorm': None,
                                                        'momentum': 0.9,
                                                        'name': 'SGD',
                                                        'nesterov': False},
                                                'type': 'sgd'},
                                  'warmup': {'linear': {'name': 'linear',
                                                        'warmup_learning_rate': 0,
                                                        'warmup_steps': 100},
                                             'type': 'linear'}},
             'recovery_begin_steps': 0,
             'recovery_max_trials': 0,
             'steps_per_loop': 100,
             'summary_interval': 100,
             'train_steps': 5000,
             'train_tf_function': True,
             'train_tf_while_loop': True,
             'validation_interval': 1000,
             'validation_steps': 78,
             'validation_summary_subdir': 'validation'}}
<IPython.core.display.Javascript object>
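The `learning_rate` and `warmup` sections of the printed config describe a cosine decay from 0.1 to 0 over 5,000 steps, preceded by a 100-step linear warmup. The sketch below reproduces that schedule in plain Python, assuming the warmup ramps linearly from `warmup_learning_rate` (0) to the value the cosine schedule has at step 100. The post-warmup values agree with the `learning_rate` column of the training log below (for example 0.09990134 at step 100 and 0.0 at step 5000); the exact warmup shape is an assumption.

```python
import math

# Values from the trainer config above.
INITIAL_LR = 0.1
DECAY_STEPS = 5000
WARMUP_STEPS = 100
WARMUP_START_LR = 0.0  # 'warmup_learning_rate': 0
ALPHA = 0.0            # 'alpha': 0.0, the final LR as a fraction of INITIAL_LR


def cosine_lr(step):
  """Cosine decay from INITIAL_LR to ALPHA * INITIAL_LR over DECAY_STEPS."""
  progress = min(step, DECAY_STEPS) / DECAY_STEPS
  cosine = 0.5 * (1.0 + math.cos(math.pi * progress))
  return INITIAL_LR * ((1.0 - ALPHA) * cosine + ALPHA)


def learning_rate(step):
  """Linear warmup to the cosine value at WARMUP_STEPS, then cosine decay."""
  if step < WARMUP_STEPS:
    target = cosine_lr(WARMUP_STEPS)
    return WARMUP_START_LR + step / WARMUP_STEPS * (target - WARMUP_START_LR)
  return cosine_lr(step)


for step in (0, 50, 100, 1000, 2500, 5000):
  print(step, round(learning_rate(step), 8))
```

Because the cosine term is measured from step 0 rather than from the end of warmup, the rate after warmup is already slightly below 0.1.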

Set up the distribution strategy.

logical_device_names = [logical_device.name for logical_device in tf.config.list_logical_devices()]

if exp_config.runtime.mixed_precision_dtype == tf.float16:
    tf.keras.mixed_precision.set_global_policy('mixed_float16')

if 'GPU' in ''.join(logical_device_names):
  distribution_strategy = tf.distribute.MirroredStrategy()
elif 'TPU' in ''.join(logical_device_names):
  tf.tpu.experimental.initialize_tpu_system()
  tpu = tf.distribute.cluster_resolver.TPUClusterResolver(tpu='/device:TPU_SYSTEM:0')
  distribution_strategy = tf.distribute.TPUStrategy(tpu)
else:
  print('Warning: this will be really slow.')
  distribution_strategy = tf.distribute.OneDeviceStrategy(logical_device_names[0])
INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0', '/job:localhost/replica:0/task:0/device:GPU:1', '/job:localhost/replica:0/task:0/device:GPU:2', '/job:localhost/replica:0/task:0/device:GPU:3')
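`MirroredStrategy` keeps one replica per GPU, and `global_batch_size` is the batch size summed over all replicas. Assuming the input pipeline splits the batch the way `tf.distribute.InputContext.get_per_replica_batch_size` does, each of the four GPUs listed in the log above receives 32 of the 128 examples per step. The plain-Python sketch below mirrors that division and its divisibility check; in the notebook, `distribution_strategy.num_replicas_in_sync` reports the actual replica count.

```python
def per_replica_batch_size(global_batch_size, num_replicas):
  """Splits a global batch evenly across replicas.

  Mirrors the check in tf.distribute.InputContext.get_per_replica_batch_size:
  the global batch must divide evenly so every replica gets the same work.
  """
  if global_batch_size % num_replicas != 0:
    raise ValueError(
        f'global batch size {global_batch_size} is not divisible by '
        f'{num_replicas} replicas')
  return global_batch_size // num_replicas


# The log above lists four GPUs, so MirroredStrategy creates four replicas.
print(per_replica_batch_size(128, 4))  # 32
```

On a single device (`OneDeviceStrategy`), there is one replica and the whole batch of 128 goes to it.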

Create the Task object (tfm.core.base_task.Task) from the config_definitions.TaskConfig.

The Task object has all the methods necessary for building the dataset, building the model, and running training and evaluation. These methods are driven by tfm.core.train_lib.run_experiment.
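The training log further down follows a regular cadence: train in loops of `steps_per_loop` steps until the next multiple of `validation_interval`, evaluate, and repeat until `train_steps`. The generator below is a simplified, hypothetical model of that cadence for reasoning about the trainer settings; it is not how Orbit or `run_experiment` is implemented.

```python
def train_and_eval_schedule(train_steps, validation_interval, steps_per_loop):
  """Yields ('train', start, end) and ('eval', step) events in order.

  Simplified model: train in chunks of steps_per_loop until the next
  validation boundary (or train_steps), then run one evaluation.
  """
  step = 0
  while step < train_steps:
    next_eval = min(step + validation_interval, train_steps)
    while step < next_eval:
      end = min(step + steps_per_loop, next_eval)
      yield ('train', step, end)
      step = end
    yield ('eval', step)


# GPU/TPU settings from the trainer config.
events = list(train_and_eval_schedule(5000, 1000, 100))
eval_steps = [event[1] for event in events if event[0] == 'eval']
print(eval_steps)  # [1000, 2000, 3000, 4000, 5000]
```

With the CPU settings (20 steps, loops of 5), the same model predicts four short training loops and a single evaluation at step 20.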

with distribution_strategy.scope():
  model_dir = tempfile.mkdtemp()
  task = tfm.core.task_factory.get_task(exp_config.task, logging_dir=model_dir)

tf.keras.utils.plot_model(task.build_model(), show_shapes=True)
You must install pydot (`pip install pydot`) and install graphviz (see instructions at https://graphviz.gitlab.io/download/) for plot_model/model_to_dot to work.
for images, labels in task.build_inputs(exp_config.task.train_data).take(1):
  print()
  print(f'images.shape: {str(images.shape):16}  images.dtype: {images.dtype!r}')
  print(f'labels.shape: {str(labels.shape):16}  labels.dtype: {labels.dtype!r}')
images.shape: (128, 32, 32, 3)  images.dtype: tf.float32
labels.shape: (128,)            labels.dtype: tf.int32
2022-06-23 21:52:37.082150: W tensorflow/core/kernels/data/cache_dataset_ops.cc:856] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.

Visualize the training data

The dataloader applies a z-score normalization using preprocess_ops.normalize_image(image, offset=MEAN_RGB, scale=STDDEV_RGB), so the images returned by the dataset can't be directly displayed by standard tools. The visualization code needs to rescale the data into the [0,1] range.

plt.hist(images.numpy().flatten());
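To see why the histogram is centered near zero rather than spread over [0, 255] or [0, 1], here is a plain-Python sketch of the z-score normalization and of the min-max rescaling that the visualization code applies. It assumes `MEAN_RGB` and `STDDEV_RGB` are the usual ImageNet channel statistics on the 0-255 scale; check `preprocess_ops` for the exact constants.

```python
# Assumed ImageNet channel statistics on the 0-255 scale.
MEAN_RGB = (0.485 * 255, 0.456 * 255, 0.406 * 255)
STDDEV_RGB = (0.229 * 255, 0.224 * 255, 0.225 * 255)


def normalize_pixel(rgb):
  """Z-score normalizes one (r, g, b) pixel given on the 0-255 scale."""
  return tuple((c - m) / s for c, m, s in zip(rgb, MEAN_RGB, STDDEV_RGB))


def rescale_to_unit(values):
  """Min-max rescales a flat list of numbers into [0, 1] for display."""
  lo, hi = min(values), max(values)
  delta = hi - lo
  if delta == 0:
    # Constant input: avoid dividing by zero.
    return [0.0 for _ in values]
  return [(v - lo) / delta for v in values]


pixels = [(0, 0, 0), (255, 255, 255), (128, 64, 32)]
normalized = [c for p in pixels for c in normalize_pixel(p)]
print(min(normalized), max(normalized))  # roughly -2.12 and 2.64
print(rescale_to_unit(normalized))
```

Rescaling by the minimum and maximum of the whole batch, as `show_batch` does, keeps the relative brightness between images in the batch, though it does not exactly undo the per-channel normalization.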


Use ds_info (which is an instance of tfds.core.DatasetInfo) to look up the text description of each class ID.

label_info = ds_info.features['label']
label_info.int2str(1)
'automobile'
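`int2str` is a lookup from a class ID to its name, backed by the ordered list of class names (available as `label_info.names`). The hypothetical `SimpleClassLabel` below is a minimal plain-Python stand-in that shows the mapping in both directions, using the CIFAR-10 class names in ID order; it is not the TFDS implementation.

```python
# CIFAR-10 class names in label-ID order.
CIFAR10_NAMES = ['airplane', 'automobile', 'bird', 'cat', 'deer',
                 'dog', 'frog', 'horse', 'ship', 'truck']


class SimpleClassLabel:
  """Hypothetical stand-in for the int <-> str lookup of a TFDS ClassLabel."""

  def __init__(self, names):
    self.names = list(names)
    self._str2int = {name: i for i, name in enumerate(self.names)}

  @property
  def num_classes(self):
    return len(self.names)

  def int2str(self, index):
    # int() accepts Python ints as well as NumPy or scalar tensor values.
    return self.names[int(index)]

  def str2int(self, name):
    return self._str2int[name]


cifar10_labels = SimpleClassLabel(CIFAR10_NAMES)
print(cifar10_labels.int2str(1))        # automobile
print(cifar10_labels.str2int('truck'))  # 9
```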

Visualize a batch of the data.

def show_batch(images, labels, predictions=None):
  plt.figure(figsize=(10, 10))
  # The batch is z-score normalized; rescale it into [0, 1] for imshow.
  vmin = images.numpy().min()
  vmax = images.numpy().max()
  delta = vmax - vmin

  for i in range(12):
    plt.subplot(6, 6, i + 1)
    plt.imshow((images[i] - vmin) / delta)
    if predictions is None:
      plt.title(label_info.int2str(int(labels[i])))
    else:
      # Green title for a correct prediction, red for an incorrect one.
      if int(labels[i]) == int(predictions[i]):
        color = 'g'
      else:
        color = 'r'
      plt.title(label_info.int2str(int(predictions[i])), color=color)
    plt.axis("off")

for images, labels in task.build_inputs(exp_config.task.train_data).take(1):
  show_batch(images, labels)
2022-06-23 21:52:39.250249: W tensorflow/core/kernels/data/cache_dataset_ops.cc:856] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.


Visualize the testing data

Visualize a batch of images from the validation dataset.

for images, labels in task.build_inputs(exp_config.task.validation_data).take(1):
  show_batch(images, labels)
2022-06-23 21:52:41.343590: W tensorflow/core/kernels/data/cache_dataset_ops.cc:856] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.


Train and evaluate

model, eval_logs = tfm.core.train_lib.run_experiment(
    distribution_strategy=distribution_strategy,
    task=task,
    mode='train_and_eval',
    params=exp_config,
    model_dir=model_dir,
    run_post_eval=True)
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
restoring or initializing model...
initialized model.
train | step:      0 | training until step 1000...
INFO:tensorflow:batch_all_reduce: 65 all-reduces with algorithm = nccl, num_packs = 1
train | step:    100 | steps/sec:    3.9 | output: 
    {'accuracy': 0.2059375,
     'learning_rate': 0.09990134,
     'top_5_accuracy': 0.7071094,
     'training_loss': 2.7588568}
saved checkpoint to /tmpfs/tmp/tmp2wg8yq3r/ckpt-100.
train | step:    200 | steps/sec:   26.2 | output: 
    {'accuracy': 0.22742188,
     'learning_rate': 0.09960574,
     'top_5_accuracy': 0.75578123,
     'training_loss': 2.6750154}
train | step:    300 | steps/sec:   30.2 | output: 
    {'accuracy': 0.28734374,
     'learning_rate': 0.09911436,
     'top_5_accuracy': 0.8138281,
     'training_loss': 2.2482069}
train | step:    400 | steps/sec:   30.7 | output: 
    {'accuracy': 0.30882812,
     'learning_rate': 0.09842916,
     'top_5_accuracy': 0.82773435,
     'training_loss': 2.1465735}
train | step:    500 | steps/sec:   30.6 | output: 
    {'accuracy': 0.3253125,
     'learning_rate': 0.09755283,
     'top_5_accuracy': 0.8352344,
     'training_loss': 2.091837}
train | step:    600 | steps/sec:   30.6 | output: 
    {'accuracy': 0.34140626,
     'learning_rate': 0.096488826,
     'top_5_accuracy': 0.8472656,
     'training_loss': 2.0435574}
train | step:    700 | steps/sec:   30.7 | output: 
    {'accuracy': 0.3514844,
     'learning_rate': 0.09524136,
     'top_5_accuracy': 0.860625,
     'training_loss': 2.0075445}
train | step:    800 | steps/sec:   30.6 | output: 
    {'accuracy': 0.36304688,
     'learning_rate': 0.09381534,
     'top_5_accuracy': 0.86148435,
     'training_loss': 1.9868404}
train | step:    900 | steps/sec:   30.7 | output: 
    {'accuracy': 0.3725781,
     'learning_rate': 0.092216395,
     'top_5_accuracy': 0.8624219,
     'training_loss': 1.9544175}
train | step:   1000 | steps/sec:   30.7 | output: 
    {'accuracy': 0.38976562,
     'learning_rate': 0.090450846,
     'top_5_accuracy': 0.86476564,
     'training_loss': 1.9300879}
 eval | step:   1000 | running 78 steps of evaluation...
 eval | step:   1000 | eval time:    5.1 sec | output: 
    {'accuracy': 0.51272035,
     'top_5_accuracy': 0.93719953,
     'validation_loss': 1.608658}
train | step:   1000 | training until step 2000...
train | step:   1100 | steps/sec:   11.9 | output: 
    {'accuracy': 0.3907031,
     'learning_rate': 0.08852567,
     'top_5_accuracy': 0.8735156,
     'training_loss': 1.9157768}
train | step:   1200 | steps/sec:   30.5 | output: 
    {'accuracy': 0.39632812,
     'learning_rate': 0.08644843,
     'top_5_accuracy': 0.8744531,
     'training_loss': 1.8925322}
train | step:   1300 | steps/sec:   30.7 | output: 
    {'accuracy': 0.4128906,
     'learning_rate': 0.08422736,
     'top_5_accuracy': 0.87773436,
     'training_loss': 1.870737}
train | step:   1400 | steps/sec:   30.7 | output: 
    {'accuracy': 0.411875,
     'learning_rate': 0.081871204,
     'top_5_accuracy': 0.88125,
     'training_loss': 1.8580977}
train | step:   1500 | steps/sec:   30.8 | output: 
    {'accuracy': 0.42554688,
     'learning_rate': 0.07938927,
     'top_5_accuracy': 0.88304687,
     'training_loss': 1.8234015}
train | step:   1600 | steps/sec:   30.7 | output: 
    {'accuracy': 0.43609375,
     'learning_rate': 0.07679134,
     'top_5_accuracy': 0.8894531,
     'training_loss': 1.7906485}
train | step:   1700 | steps/sec:   30.8 | output: 
    {'accuracy': 0.4403125,
     'learning_rate': 0.07408768,
     'top_5_accuracy': 0.8871094,
     'training_loss': 1.7919132}
train | step:   1800 | steps/sec:   30.8 | output: 
    {'accuracy': 0.42757812,
     'learning_rate': 0.071288966,
     'top_5_accuracy': 0.8889844,
     'training_loss': 1.8168057}
train | step:   1900 | steps/sec:   30.8 | output: 
    {'accuracy': 0.43703124,
     'learning_rate': 0.068406224,
     'top_5_accuracy': 0.895625,
     'training_loss': 1.7795025}
train | step:   2000 | steps/sec:   30.8 | output: 
    {'accuracy': 0.4628125,
     'learning_rate': 0.06545085,
     'top_5_accuracy': 0.90367186,
     'training_loss': 1.7216893}
 eval | step:   2000 | running 78 steps of evaluation...
 eval | step:   2000 | eval time:    0.9 sec | output: 
    {'accuracy': 0.5889423,
     'top_5_accuracy': 0.9557292,
     'validation_loss': 1.4047486}
train | step:   2000 | training until step 3000...
train | step:   2100 | steps/sec:   23.9 | output: 
    {'accuracy': 0.46179688,
     'learning_rate': 0.062434502,
     'top_5_accuracy': 0.9009375,
     'training_loss': 1.7257642}
train | step:   2200 | steps/sec:   30.7 | output: 
    {'accuracy': 0.47304687,
     'learning_rate': 0.059369065,
     'top_5_accuracy': 0.9096875,
     'training_loss': 1.6878849}
train | step:   2300 | steps/sec:   30.7 | output: 
    {'accuracy': 0.47828126,
     'learning_rate': 0.056266654,
     'top_5_accuracy': 0.9125,
     'training_loss': 1.6812003}
train | step:   2400 | steps/sec:   30.8 | output: 
    {'accuracy': 0.4865625,
     'learning_rate': 0.053139526,
     'top_5_accuracy': 0.9122656,
     'training_loss': 1.6623691}
train | step:   2500 | steps/sec:   30.6 | output: 
    {'accuracy': 0.48382813,
     'learning_rate': 0.049999997,
     'top_5_accuracy': 0.9142187,
     'training_loss': 1.6555617}
train | step:   2600 | steps/sec:   30.7 | output: 
    {'accuracy': 0.4999219,
     'learning_rate': 0.04686048,
     'top_5_accuracy': 0.91875,
     'training_loss': 1.6292816}
train | step:   2700 | steps/sec:   30.7 | output: 
    {'accuracy': 0.49382812,
     'learning_rate': 0.043733336,
     'top_5_accuracy': 0.911875,
     'training_loss': 1.6349418}
train | step:   2800 | steps/sec:   30.8 | output: 
    {'accuracy': 0.5015625,
     'learning_rate': 0.040630933,
     'top_5_accuracy': 0.9150781,
     'training_loss': 1.6125712}
train | step:   2900 | steps/sec:   30.8 | output: 
    {'accuracy': 0.47125,
     'learning_rate': 0.037565507,
     'top_5_accuracy': 0.8939844,
     'training_loss': 1.696453}
train | step:   3000 | steps/sec:   30.8 | output: 
    {'accuracy': 0.49671876,
     'learning_rate': 0.034549143,
     'top_5_accuracy': 0.9072656,
     'training_loss': 1.6314162}
 eval | step:   3000 | running 78 steps of evaluation...
 eval | step:   3000 | eval time:    0.9 sec | output: 
    {'accuracy': 0.6453325,
     'top_5_accuracy': 0.96854967,
     'validation_loss': 1.2068491}
train | step:   3000 | training until step 4000...
train | step:   3100 | steps/sec:   23.8 | output: 
    {'accuracy': 0.5042969,
     'learning_rate': 0.03159377,
     'top_5_accuracy': 0.91664064,
     'training_loss': 1.5914496}
train | step:   3200 | steps/sec:   30.6 | output: 
    {'accuracy': 0.5197656,
     'learning_rate': 0.028711034,
     'top_5_accuracy': 0.92234373,
     'training_loss': 1.5563523}
train | step:   3300 | steps/sec:   30.8 | output: 
    {'accuracy': 0.5250781,
     'learning_rate': 0.025912309,
     'top_5_accuracy': 0.92648435,
     'training_loss': 1.5450299}
train | step:   3400 | steps/sec:   30.7 | output: 
    {'accuracy': 0.530625,
     'learning_rate': 0.023208654,
     'top_5_accuracy': 0.9259375,
     'training_loss': 1.5256865}
train | step:   3500 | steps/sec:   30.7 | output: 
    {'accuracy': 0.5353125,
     'learning_rate': 0.020610739,
     'top_5_accuracy': 0.92757815,
     'training_loss': 1.5099089}
train | step:   3600 | steps/sec:   30.8 | output: 
    {'accuracy': 0.5408594,
     'learning_rate': 0.018128792,
     'top_5_accuracy': 0.9283594,
     'training_loss': 1.500019}
train | step:   3700 | steps/sec:   30.8 | output: 
    {'accuracy': 0.5442188,
     'learning_rate': 0.015772644,
     'top_5_accuracy': 0.9330469,
     'training_loss': 1.4787773}
train | step:   3800 | steps/sec:   30.8 | output: 
    {'accuracy': 0.5565625,
     'learning_rate': 0.01355157,
     'top_5_accuracy': 0.9333594,
     'training_loss': 1.4548306}
train | step:   3900 | steps/sec:   30.8 | output: 
    {'accuracy': 0.5578125,
     'learning_rate': 0.011474336,
     'top_5_accuracy': 0.93109375,
     'training_loss': 1.4592746}
train | step:   4000 | steps/sec:   30.8 | output: 
    {'accuracy': 0.55960935,
     'learning_rate': 0.009549147,
     'top_5_accuracy': 0.9355469,
     'training_loss': 1.4406807}
 eval | step:   4000 | running 78 steps of evaluation...
 eval | step:   4000 | eval time:    0.9 sec | output: 
    {'accuracy': 0.6991186,
     'top_5_accuracy': 0.9748598,
     'validation_loss': 1.067995}
train | step:   4000 | training until step 5000...
train | step:   4100 | steps/sec:   23.9 | output: 
    {'accuracy': 0.56289065,
     'learning_rate': 0.0077836006,
     'top_5_accuracy': 0.9366406,
     'training_loss': 1.4425566}
train | step:   4200 | steps/sec:   30.8 | output: 
    {'accuracy': 0.5626562,
     'learning_rate': 0.0061846706,
     'top_5_accuracy': 0.9396875,
     'training_loss': 1.4339583}
train | step:   4300 | steps/sec:   30.7 | output: 
    {'accuracy': 0.5664844,
     'learning_rate': 0.0047586444,
     'top_5_accuracy': 0.9380469,
     'training_loss': 1.4205724}
train | step:   4400 | steps/sec:   30.8 | output: 
    {'accuracy': 0.5621875,
     'learning_rate': 0.0035111725,
     'top_5_accuracy': 0.9351562,
     'training_loss': 1.4305773}
train | step:   4500 | steps/sec:   30.7 | output: 
    {'accuracy': 0.5800781,
     'learning_rate': 0.002447176,
     'top_5_accuracy': 0.9391406,
     'training_loss': 1.4054605}
train | step:   4600 | steps/sec:   30.7 | output: 
    {'accuracy': 0.5790625,
     'learning_rate': 0.0015708387,
     'top_5_accuracy': 0.9386719,
     'training_loss': 1.4075251}
train | step:   4700 | steps/sec:   30.8 | output: 
    {'accuracy': 0.57820314,
     'learning_rate': 0.0008856386,
     'top_5_accuracy': 0.93539065,
     'training_loss': 1.4003149}
train | step:   4800 | steps/sec:   30.8 | output: 
    {'accuracy': 0.58171874,
     'learning_rate': 0.00039426386,
     'top_5_accuracy': 0.93921876,
     'training_loss': 1.3960028}
train | step:   4900 | steps/sec:   30.8 | output: 
    {'accuracy': 0.5779688,
     'learning_rate': 9.866357e-05,
     'top_5_accuracy': 0.9441406,
     'training_loss': 1.3905802}
train | step:   5000 | steps/sec:   30.8 | output: 
    {'accuracy': 0.58304685,
     'learning_rate': 0.0,
     'top_5_accuracy': 0.940625,
     'training_loss': 1.3832752}
 eval | step:   5000 | running 78 steps of evaluation...
 eval | step:   5000 | eval time:    0.9 sec | output: 
    {'accuracy': 0.7166466,
     'top_5_accuracy': 0.97686297,
     'validation_loss': 1.0204651}
saved checkpoint to /tmpfs/tmp/tmp2wg8yq3r/ckpt-5000.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/ops/nn_ops.py:5219: tensor_shape_from_node_def_name (from tensorflow.python.framework.graph_util_impl) is deprecated and will be removed in a future version.
Instructions for updating:
Use `tf.compat.v1.graph_util.tensor_shape_from_node_def_name`
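In the training log above, the learning_rate column falls smoothly to exactly 0.0 at step 5000. These values are consistent with a cosine decay schedule, lr(step) = initial_lr * 0.5 * (1 + cos(pi * step / decay_steps)). The following sketch checks that reading against a few logged values. It assumes initial_lr = 0.1 and decay_steps = 5000, which are inferred from the log rather than read from the experiment config. It also ignores any warmup phase, which would only affect the earliest steps. Treat it as a check on the log, not as the trainer's own schedule code.

```python
import math


def cosine_decay(step: int, initial_lr: float = 0.1, decay_steps: int = 5000) -> float:
    """Cosine-decayed learning rate (assumed parameters, no warmup, no floor)."""
    # Clamp so steps past decay_steps stay at the final value (0.0).
    progress = min(step, decay_steps) / decay_steps
    return initial_lr * 0.5 * (1.0 + math.cos(math.pi * progress))


# A few learning_rate values copied from the training log.
logged = {3800: 0.01355157, 4000: 0.009549147, 4900: 9.866357e-05, 5000: 0.0}
for step, lr in logged.items():
    print(f'step {step}: computed {cosine_decay(step):.8g}, logged {lr:.8g}')
```

The computed and logged values agree to float32 precision. This supports the idea that the schedule anneals to zero exactly at the final training step.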
tf.keras.utils.plot_model(model, show_shapes=True)
You must install pydot (`pip install pydot`) and install graphviz (see instructions at https://graphviz.gitlab.io/download/) for plot_model/model_to_dot to work.

Print the accuracy, top_5_accuracy, and validation_loss evaluation metrics.

for key, value in eval_logs.items():
    print(f'{key:20}: {value.numpy():.3f}')
accuracy            : 0.717
top_5_accuracy      : 0.977
validation_loss     : 1.020
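top_5_accuracy counts a prediction as correct when the true label is among the five highest-scoring classes. By definition, it is never lower than accuracy, which is top-1 accuracy. The following pure-Python sketch illustrates that definition. It is independent of the Keras metric the trainer actually uses, and the scores and labels are made-up toy values.

```python
def top_k_accuracy(logits, labels, k=5):
    """Fraction of examples whose true label is among the k highest scores."""
    hits = 0
    for row, label in zip(logits, labels):
        # Indices of the k largest scores in this row.
        top_k = sorted(range(len(row)), key=lambda i: row[i], reverse=True)[:k]
        hits += int(label in top_k)
    return hits / len(labels)


# Hypothetical batch: 2 examples, 6 classes.
logits = [
    [0.1, 2.0, 0.3, 0.5, 1.5, 0.0],  # highest score: class 1
    [1.0, 0.2, 0.1, 3.0, 0.4, 0.6],  # highest score: class 3
]
labels = [4, 3]
print(top_k_accuracy(logits, labels, k=1))  # 0.5
print(top_k_accuracy(logits, labels, k=5))  # 1.0
```

The first example misses at k=1 but counts as a hit at k=5. Its true class has the second-highest score.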

Run a batch of the processed training data through the model and view the results.

# Take one batch from the training input pipeline, predict class scores, and
# convert each row of scores to the index of the highest-scoring class.
for images, labels in task.build_inputs(exp_config.task.train_data).take(1):
  predictions = model.predict(images)
  predictions = tf.argmax(predictions, axis=-1)

show_batch(images, labels, tf.cast(predictions, tf.int32))

if device=='CPU':
  plt.suptitle('The model was only trained for a few steps; it is not expected to do well.')
4/4 [==============================] - 3s 8ms/step
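The tf.argmax(predictions, axis=-1) step converts each row of class scores into the index of the highest-scoring class. In the tfds CIFAR-10 dataset, label indices 0 through 9 correspond to the class names listed in the following sketch. This stdlib-only sketch performs the same argmax and name lookup on made-up scores, not real model output. The argmax helper is a hypothetical stand-in for tf.argmax on a single row.

```python
# CIFAR-10 class names in tfds label order (index 0..9).
CIFAR10_NAMES = ['airplane', 'automobile', 'bird', 'cat', 'deer',
                 'dog', 'frog', 'horse', 'ship', 'truck']


def argmax(row):
    """Index of the largest value in one row (this version returns the first index on ties)."""
    return max(range(len(row)), key=row.__getitem__)


# Hypothetical scores for two images, 10 classes each.
scores = [
    [0.1, 0.0, 0.2, 2.5, 0.3, 1.9, 0.0, 0.1, 0.0, 0.2],
    [0.0, 0.4, 0.1, 0.0, 0.0, 0.1, 0.0, 0.2, 3.1, 0.6],
]
labels = [5, 8]
predicted = [argmax(row) for row in scores]
for p, y in zip(predicted, labels):
    status = 'correct' if p == y else 'wrong'
    print(f'predicted {CIFAR10_NAMES[p]:<10} true {CIFAR10_NAMES[y]:<10} {status}')
```

The first image is predicted as cat but labeled dog. The second is correctly predicted as ship.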

Export a SavedModel

The keras.Model object returned by train_lib.run_experiment expects its input images to be normalized the same way the dataset loader normalizes them. The loader uses preprocess_ops.normalize_image(image, offset=MEAN_RGB, scale=STDDEV_RGB), which applies per-channel mean and standard deviation statistics. This export function handles those details, so you can pass tf.uint8 images and get the correct results.
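The normalization described above is a per-channel standardization. Each channel value has the channel mean subtracted and is then divided by the channel standard deviation. The following stdlib-only sketch applies this to a single RGB pixel. The constants are an assumption: the widely used ImageNet statistics, scaled to the 0 to 255 range. Check MEAN_RGB and STDDEV_RGB in the preprocess_ops module of your installed Model Garden version for the values actually used.

```python
# Assumed per-channel statistics on the 0-255 scale (ImageNet values);
# not read from the tutorial's config.
MEAN_RGB = (0.485 * 255, 0.456 * 255, 0.406 * 255)
STDDEV_RGB = (0.229 * 255, 0.224 * 255, 0.225 * 255)


def normalize_pixel(rgb, offset=MEAN_RGB, scale=STDDEV_RGB):
    """Standardize one RGB pixel in the 0-255 range: (value - mean) / stddev per channel."""
    return tuple((float(v) - m) / s for v, m, s in zip(rgb, offset, scale))


print(normalize_pixel((0, 0, 0)))        # every channel becomes negative
print(normalize_pixel((255, 255, 255)))  # every channel becomes positive
```

Because the exported serving function performs this step itself, callers only need to supply raw uint8 pixels.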

# Saving and exporting the trained model
export_saved_model_lib.export_inference_graph(
    input_type='image_tensor',
    batch_size=1,
    input_image_size=[32, 32],
    params=exp_config,
    checkpoint_path=tf.train.latest_checkpoint(model_dir),
    export_dir='./export/')
WARNING:absl:Found untraced functions such as inference_for_tflite, inference_from_image_bytes, inference_from_tf_example, _jit_compiled_convolution_op, conv2d_layer_call_fn while saving (showing 5 of 64). These functions will not be directly callable after loading.
INFO:tensorflow:Assets written to: ./export/assets

Test the exported model.

# Load the exported SavedModel and get its default serving signature.
imported = tf.saved_model.load('./export/')
model_fn = imported.signatures['serving_default']

Visualize the predictions.

plt.figure(figsize=(10, 10))
# The SavedModel was exported with batch_size=1, so run it one image at a time.
for data in tfds.load('cifar10', split='test').batch(12).take(1):
  predictions = []
  for image in data['image']:
    # Add a batch dimension, call the serving signature, and keep the top class.
    index = tf.argmax(model_fn(image[tf.newaxis, ...])['logits'], axis=1)[0]
    predictions.append(index)
  show_batch(data['image'], data['label'], predictions)

  if device=='CPU':
    plt.suptitle('The model was only trained for a few steps; it is not expected to do better than random.')