This tutorial fine-tunes a Residual Network (ResNet) from the TensorFlow Model Garden package (tensorflow-models) to classify images in the CIFAR dataset.
Model Garden contains a collection of state-of-the-art vision models, implemented with TensorFlow's high-level APIs. The implementations demonstrate best practices for modeling, letting users take full advantage of TensorFlow for their research and product development.
This tutorial uses ResNet-18, an 18-layer convolutional neural network from the ResNet family of state-of-the-art image classifiers.
This tutorial demonstrates how to:
- Use models from the TensorFlow Models package.
- Fine-tune a pre-built ResNet for image classification.
- Export the tuned ResNet model.
Setup
Install and import the necessary modules.
pip install -U -q "tf-models-official"
Import TensorFlow, TensorFlow Datasets, and a few helper libraries.
import pprint
import tempfile
from IPython import display
import matplotlib.pyplot as plt
import tensorflow as tf
import tensorflow_datasets as tfds
2022-12-17 12:25:32.700765: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer.so.7'; dlerror: libnvinfer.so.7: cannot open shared object file: No such file or directory 2022-12-17 12:25:32.700882: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer_plugin.so.7'; dlerror: libnvinfer_plugin.so.7: cannot open shared object file: No such file or directory 2022-12-17 12:25:32.700892: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Cannot dlopen some TensorRT libraries. If you would like to use Nvidia GPU with TensorRT, please make sure the missing libraries mentioned above are installed properly.
The tensorflow_models package contains the ResNet vision model, and the official.vision.serving module contains the function to save and export the tuned model.
import tensorflow_models as tfm
# These are not in the tfm public API for v2.9. They will be available in v2.10
from official.vision.serving import export_saved_model_lib
import official.core.train_lib
Configure the ResNet-18 model for the CIFAR-10 dataset
The CIFAR-10 dataset contains 60,000 color images in 10 mutually exclusive classes, with 6,000 images in each class.
In Model Garden, the collections of parameters that define a model are called configs. Model Garden can create a config based on a known set of parameters via a factory.
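The factory idea can be illustrated with a small registry pattern. The sketch below is not Model Garden's implementation: `register_config`, `get_config`, `ModelConfig`, `ExperimentConfig`, and the name `toy_resnet_imagenet` are all hypothetical and exist only for illustration. It shows a name being mapped to a function that builds a default config, which the caller then overrides for a new task.

```python
import dataclasses
from typing import Callable, Dict


@dataclasses.dataclass
class ModelConfig:
    # Hypothetical fields, loosely mirroring the ones changed in this tutorial.
    num_classes: int = 1000
    model_id: int = 50


@dataclasses.dataclass
class ExperimentConfig:
    model: ModelConfig = dataclasses.field(default_factory=ModelConfig)
    train_steps: int = 0


# Registry mapping an experiment name to a function that builds its config.
_REGISTRY: Dict[str, Callable[[], ExperimentConfig]] = {}


def register_config(name: str):
    """Decorator that stores a config-building function under `name`."""
    def decorator(fn):
        _REGISTRY[name] = fn
        return fn
    return decorator


def get_config(name: str) -> ExperimentConfig:
    """Builds a fresh config from the registered factory function."""
    return _REGISTRY[name]()


@register_config('toy_resnet_imagenet')
def toy_resnet_imagenet() -> ExperimentConfig:
    return ExperimentConfig(model=ModelConfig(num_classes=1000, model_id=50))


# Start from the known defaults, then override what differs for the new task.
config = get_config('toy_resnet_imagenet')
config.model.num_classes = 10
config.model.model_id = 18
print(dataclasses.asdict(config))
```

Because the factory builds a new object on every call, overriding fields on one config does not change the registered defaults.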
Use the resnet_imagenet factory configuration, as defined by tfm.vision.configs.image_classification.image_classification_imagenet. The configuration is set up to train ResNet to converge on ImageNet.
exp_config = tfm.core.exp_factory.get_exp_config('resnet_imagenet')
tfds_name = 'cifar10'
ds_info = tfds.builder(tfds_name).info
ds_info
tfds.core.DatasetInfo( name='cifar10', full_name='cifar10/3.0.2', description=""" The CIFAR-10 dataset consists of 60000 32x32 colour images in 10 classes, with 6000 images per class. There are 50000 training images and 10000 test images. """, homepage='https://www.cs.toronto.edu/~kriz/cifar.html', data_path='gs://tensorflow-datasets/datasets/cifar10/3.0.2', file_format=tfrecord, download_size=162.17 MiB, dataset_size=132.40 MiB, features=FeaturesDict({ 'id': Text(shape=(), dtype=tf.string), 'image': Image(shape=(32, 32, 3), dtype=tf.uint8), 'label': ClassLabel(shape=(), dtype=tf.int64, num_classes=10), }), supervised_keys=('image', 'label'), disable_shuffling=False, splits={ 'test': <SplitInfo num_examples=10000, num_shards=1>, 'train': <SplitInfo num_examples=50000, num_shards=1>, }, citation="""@TECHREPORT{Krizhevsky09learningmultiple, author = {Alex Krizhevsky}, title = {Learning multiple layers of features from tiny images}, institution = {}, year = {2009} }""", )
Adjust the model and dataset configurations so that they work with CIFAR-10 (cifar10).
# Configure model
exp_config.task.model.num_classes = 10
exp_config.task.model.input_size = list(ds_info.features["image"].shape)
exp_config.task.model.backbone.resnet.model_id = 18
# Configure training and testing data
batch_size = 128
exp_config.task.train_data.input_path = ''
exp_config.task.train_data.tfds_name = tfds_name
exp_config.task.train_data.tfds_split = 'train'
exp_config.task.train_data.global_batch_size = batch_size
exp_config.task.validation_data.input_path = ''
exp_config.task.validation_data.tfds_name = tfds_name
exp_config.task.validation_data.tfds_split = 'test'
exp_config.task.validation_data.global_batch_size = batch_size
Adjust the trainer configuration.
logical_device_names = [logical_device.name for logical_device in tf.config.list_logical_devices()]
if 'GPU' in ''.join(logical_device_names):
    print('This may be broken in Colab.')
    device = 'GPU'
elif 'TPU' in ''.join(logical_device_names):
    print('This may be broken in Colab.')
    device = 'TPU'
else:
    print('Running on CPU is slow, so only train for a few steps.')
    device = 'CPU'
if device == 'CPU':
    train_steps = 20
    exp_config.trainer.steps_per_loop = 5
else:
    train_steps = 5000
    exp_config.trainer.steps_per_loop = 100
exp_config.trainer.summary_interval = 100
exp_config.trainer.checkpoint_interval = train_steps
exp_config.trainer.validation_interval = 1000
exp_config.trainer.validation_steps = ds_info.splits['test'].num_examples // batch_size
exp_config.trainer.train_steps = train_steps
exp_config.trainer.optimizer_config.learning_rate.type = 'cosine'
exp_config.trainer.optimizer_config.learning_rate.cosine.decay_steps = train_steps
exp_config.trainer.optimizer_config.learning_rate.cosine.initial_learning_rate = 0.1
exp_config.trainer.optimizer_config.warmup.linear.warmup_steps = 100
This may be broken in Colab.
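To make the learning-rate settings concrete, here is a minimal pure-Python sketch of a cosine decay schedule with linear warmup, using the values configured above (initial rate 0.1, 5000 decay steps, 100 warmup steps). It is an illustration and not the Model Garden code. In particular, it assumes that warmup ramps linearly from the warmup rate to whatever value the cosine schedule has at the last warmup step. The function names are chosen for this example.

```python
import math


def cosine_decay(step, initial_lr=0.1, decay_steps=5000, alpha=0.0):
    """Cosine decay from initial_lr down to alpha * initial_lr."""
    step = min(step, decay_steps)
    cosine = 0.5 * (1.0 + math.cos(math.pi * step / decay_steps))
    return initial_lr * ((1.0 - alpha) * cosine + alpha)


def learning_rate(step, warmup_steps=100, warmup_lr=0.0):
    """Linear warmup followed by cosine decay.

    Assumption: during warmup the rate ramps linearly from warmup_lr to the
    value the cosine schedule has at step == warmup_steps.
    """
    if step < warmup_steps:
        target = cosine_decay(warmup_steps)
        return warmup_lr + (step / warmup_steps) * (target - warmup_lr)
    return cosine_decay(step)


for step in (0, 50, 100, 2500, 5000):
    print(f'step {step:5d}: lr = {learning_rate(step):.6f}')
```

Under these assumptions the rate reaches about 0.0999 at step 100, passes 0.05 at the halfway point, and falls to 0 at step 5000, which agrees with the learning_rate values reported in the training log of this tutorial.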
Print the modified configuration.
pprint.pprint(exp_config.as_dict())
display.Javascript("google.colab.output.setIframeHeight('300px');")
{'runtime': {'all_reduce_alg': None, 'batchnorm_spatial_persistent': False, 'dataset_num_private_threads': None, 'default_shard_dim': -1, 'distribution_strategy': 'mirrored', 'enable_xla': True, 'gpu_thread_mode': None, 'loss_scale': None, 'mixed_precision_dtype': None, 'num_cores_per_replica': 1, 'num_gpus': 0, 'num_packs': 1, 'per_gpu_thread_count': 0, 'run_eagerly': False, 'task_index': -1, 'tpu': None, 'tpu_enable_xla_dynamic_padder': None, 'worker_hosts': None}, 'task': {'differential_privacy_config': None, 'evaluation': {'precision_and_recall_thresholds': None, 'report_per_class_precision_and_recall': False, 'top_k': 5}, 'freeze_backbone': False, 'init_checkpoint': None, 'init_checkpoint_modules': 'all', 'losses': {'l2_weight_decay': 0.0001, 'label_smoothing': 0.0, 'loss_weight': 1.0, 'one_hot': True, 'soft_labels': False}, 'model': {'add_head_batch_norm': False, 'backbone': {'resnet': {'bn_trainable': True, 'depth_multiplier': 1.0, 'model_id': 18, 'replace_stem_max_pool': False, 'resnetd_shortcut': False, 'scale_stem': True, 'se_ratio': 0.0, 'stem_type': 'v0', 'stochastic_depth_drop_rate': 0.0}, 'type': 'resnet'}, 'dropout_rate': 0.0, 'input_size': [32, 32, 3], 'kernel_initializer': 'random_uniform', 'norm_activation': {'activation': 'relu', 'norm_epsilon': 1e-05, 'norm_momentum': 0.9, 'use_sync_bn': False}, 'num_classes': 10, 'output_softmax': False}, 'model_output_keys': [], 'name': None, 'train_data': {'apply_tf_data_service_before_batching': False, 'aug_crop': True, 'aug_policy': None, 'aug_rand_hflip': True, 'aug_type': None, 'block_length': 1, 'cache': False, 'color_jitter': 0.0, 'crop_area_range': (0.08, 1.0), 'cycle_length': 10, 'decode_jpeg_only': True, 'decoder': {'simple_decoder': {'mask_binarize_threshold': None, 'regenerate_source_id': False}, 'type': 'simple_decoder'}, 'deterministic': None, 'drop_remainder': True, 'dtype': 'float32', 'enable_shared_tf_data_service_between_parallel_trainers': False, 'enable_tf_data_service': False, 'file_type': 
'tfrecord', 'global_batch_size': 128, 'image_field_key': 'image/encoded', 'input_path': '', 'is_multilabel': False, 'is_training': True, 'label_field_key': 'image/class/label', 'mixup_and_cutmix': None, 'prefetch_buffer_size': None, 'randaug_magnitude': 10, 'random_erasing': None, 'seed': None, 'sharding': True, 'shuffle_buffer_size': 10000, 'tf_data_service_address': None, 'tf_data_service_job_name': None, 'tfds_as_supervised': False, 'tfds_data_dir': '', 'tfds_name': 'cifar10', 'tfds_skip_decoding_feature': '', 'tfds_split': 'train', 'trainer_id': None}, 'validation_data': {'apply_tf_data_service_before_batching': False, 'aug_crop': True, 'aug_policy': None, 'aug_rand_hflip': True, 'aug_type': None, 'block_length': 1, 'cache': False, 'color_jitter': 0.0, 'crop_area_range': (0.08, 1.0), 'cycle_length': 10, 'decode_jpeg_only': True, 'decoder': {'simple_decoder': {'mask_binarize_threshold': None, 'regenerate_source_id': False}, 'type': 'simple_decoder'}, 'deterministic': None, 'drop_remainder': True, 'dtype': 'float32', 'enable_shared_tf_data_service_between_parallel_trainers': False, 'enable_tf_data_service': False, 'file_type': 'tfrecord', 'global_batch_size': 128, 'image_field_key': 'image/encoded', 'input_path': '', 'is_multilabel': False, 'is_training': False, 'label_field_key': 'image/class/label', 'mixup_and_cutmix': None, 'prefetch_buffer_size': None, 'randaug_magnitude': 10, 'random_erasing': None, 'seed': None, 'sharding': True, 'shuffle_buffer_size': 10000, 'tf_data_service_address': None, 'tf_data_service_job_name': None, 'tfds_as_supervised': False, 'tfds_data_dir': '', 'tfds_name': 'cifar10', 'tfds_skip_decoding_feature': '', 'tfds_split': 'test', 'trainer_id': None} }, 'trainer': {'allow_tpu_summary': False, 'best_checkpoint_eval_metric': '', 'best_checkpoint_export_subdir': '', 'best_checkpoint_metric_comp': 'higher', 'checkpoint_interval': 5000, 'continuous_eval_timeout': 3600, 'eval_tf_function': True, 'eval_tf_while_loop': False, 
'loss_upper_bound': 1000000.0, 'max_to_keep': 5, 'optimizer_config': {'ema': None, 'learning_rate': {'cosine': {'alpha': 0.0, 'decay_steps': 5000, 'initial_learning_rate': 0.1, 'name': 'CosineDecay', 'offset': 0}, 'type': 'cosine'}, 'optimizer': {'sgd': {'clipnorm': None, 'clipvalue': None, 'decay': 0.0, 'global_clipnorm': None, 'momentum': 0.9, 'name': 'SGD', 'nesterov': False}, 'type': 'sgd'}, 'warmup': {'linear': {'name': 'linear', 'warmup_learning_rate': 0, 'warmup_steps': 100}, 'type': 'linear'} }, 'recovery_begin_steps': 0, 'recovery_max_trials': 0, 'steps_per_loop': 100, 'summary_interval': 100, 'train_steps': 5000, 'train_tf_function': True, 'train_tf_while_loop': True, 'validation_interval': 1000, 'validation_steps': 78, 'validation_summary_subdir': 'validation'} } <IPython.core.display.Javascript object>
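The printed configuration reports 'validation_steps': 78 and 'train_steps': 5000. The short calculation below shows where those numbers lead, using the split sizes from ds_info and the batch size of 128. The variable names are chosen for this example, and the epoch count is only an approximation of how the GPU/TPU setting relates to the size of the training split.

```python
train_examples = 50_000   # size of the 'train' split reported by ds_info
test_examples = 10_000    # size of the 'test' split reported by ds_info
batch_size = 128
train_steps = 5_000

# Integer division leaves out the final partial batch.
validation_steps = test_examples // batch_size
evaluated = validation_steps * batch_size
skipped = test_examples - evaluated

# Approximate number of passes over the training split.
epochs = train_steps * batch_size / train_examples

print(f'validation_steps = {validation_steps}')
print(f'evaluated examples = {evaluated}, skipped = {skipped}')
print(f'approximate training epochs = {epochs:.1f}')
```

With these settings, each evaluation covers 9,984 of the 10,000 test images, and 5,000 training steps correspond to roughly 12.8 passes over the training split.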
Set up the distribution strategy.
logical_device_names = [logical_device.name for logical_device in tf.config.list_logical_devices()]
if exp_config.runtime.mixed_precision_dtype == tf.float16:
    tf.keras.mixed_precision.set_global_policy('mixed_float16')
if 'GPU' in ''.join(logical_device_names):
    distribution_strategy = tf.distribute.MirroredStrategy()
elif 'TPU' in ''.join(logical_device_names):
    tf.tpu.experimental.initialize_tpu_system()
    tpu = tf.distribute.cluster_resolver.TPUClusterResolver(tpu='/device:TPU_SYSTEM:0')
    distribution_strategy = tf.distribute.experimental.TPUStrategy(tpu)
else:
    print('Warning: this will be really slow.')
    distribution_strategy = tf.distribute.OneDeviceStrategy(logical_device_names[0])
INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0', '/job:localhost/replica:0/task:0/device:GPU:1', '/job:localhost/replica:0/task:0/device:GPU:2', '/job:localhost/replica:0/task:0/device:GPU:3') INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0', '/job:localhost/replica:0/task:0/device:GPU:1', '/job:localhost/replica:0/task:0/device:GPU:2', '/job:localhost/replica:0/task:0/device:GPU:3')
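With a synchronous strategy such as MirroredStrategy, a global batch is typically divided evenly among the replicas. The sketch below only illustrates that arithmetic for the four GPUs listed in the log and the tutorial's global batch size of 128. The helper name `per_replica_batch_size` is made up for this example, and the even split is an assumption about how the input pipeline shards batches.

```python
def per_replica_batch_size(global_batch_size: int, num_replicas: int) -> int:
    """Splits a global batch evenly across replicas (sketch of the arithmetic)."""
    if global_batch_size % num_replicas != 0:
        raise ValueError(
            f'global batch size {global_batch_size} is not divisible by '
            f'{num_replicas} replicas')
    return global_batch_size // num_replicas


# The log above lists four GPUs; the tutorial's global batch size is 128.
for replicas in (1, 4):
    size = per_replica_batch_size(128, replicas)
    print(f'{replicas} replica(s) -> {size} examples each')
```

This is why the configuration speaks of a global_batch_size: the number stays the same regardless of how many devices share the work.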
Create the Task object (tfm.core.base_task.Task) from the config_definitions.TaskConfig.
The Task object has all the methods necessary for building the dataset, building the model, and running training & evaluation. These methods are driven by tfm.core.train_lib.run_experiment.
with distribution_strategy.scope():
    model_dir = tempfile.mkdtemp()
    task = tfm.core.task_factory.get_task(exp_config.task, logging_dir=model_dir)
# tf.keras.utils.plot_model(task.build_model(), show_shapes=True)
for images, labels in task.build_inputs(exp_config.task.train_data).take(1):
    print()
    print(f'images.shape: {str(images.shape):16} images.dtype: {images.dtype!r}')
    print(f'labels.shape: {str(labels.shape):16} labels.dtype: {labels.dtype!r}')
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/autograph/pyct/static_analysis/liveness.py:83: Analyzer.lamba_check (from tensorflow.python.autograph.pyct.static_analysis.liveness) is deprecated and will be removed after 2023-09-23. Instructions for updating: Lambda fuctions will be no more assumed to be used in the statement where they are used, or at least in the same block. https://github.com/tensorflow/tensorflow/issues/56089 WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/autograph/pyct/static_analysis/liveness.py:83: Analyzer.lamba_check (from tensorflow.python.autograph.pyct.static_analysis.liveness) is deprecated and will be removed after 2023-09-23. Instructions for updating: Lambda fuctions will be no more assumed to be used in the statement where they are used, or at least in the same block. https://github.com/tensorflow/tensorflow/issues/56089 images.shape: (128, 32, 32, 3) images.dtype: tf.float32 labels.shape: (128,) labels.dtype: tf.int32 2022-12-17 12:25:42.541340: W tensorflow/core/kernels/data/cache_dataset_ops.cc:856] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.
Visualize the training data
The dataloader applies a z-score normalization using preprocess_ops.normalize_image(image, offset=MEAN_RGB, scale=STDDEV_RGB), so the images returned by the dataset can't be directly displayed by standard tools. The visualization code needs to rescale the data into the [0,1] range.
plt.hist(images.numpy().flatten());
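The two steps described above, z-score normalization and rescaling for display, can be sketched in plain Python. The MEAN_RGB and STDDEV_RGB values below are assumed ImageNet channel statistics on the 0-255 scale and should be treated as illustrative. The helper names `normalize_pixel` and `rescale_to_unit_range` are made up for this example and are not part of preprocess_ops.

```python
# Assumed per-channel statistics on the 0-255 scale (ImageNet mean and
# standard deviation). Treat these as illustrative values.
MEAN_RGB = (0.485 * 255, 0.456 * 255, 0.406 * 255)
STDDEV_RGB = (0.229 * 255, 0.224 * 255, 0.225 * 255)


def normalize_pixel(pixel):
    """Z-score normalizes one (r, g, b) pixel given in the 0-255 range."""
    return tuple((value - mean) / std
                 for value, mean, std in zip(pixel, MEAN_RGB, STDDEV_RGB))


def rescale_to_unit_range(values):
    """Min-max rescales a flat list of numbers into [0, 1] for display."""
    low, high = min(values), max(values)
    delta = high - low
    if delta == 0:
        return [0.0 for _ in values]
    return [(value - low) / delta for value in values]


pixels = [(0, 0, 0), (128, 128, 128), (255, 255, 255)]
normalized = [channel for pixel in pixels for channel in normalize_pixel(pixel)]
print('normalized range:', min(normalized), max(normalized))

displayable = rescale_to_unit_range(normalized)
print('display range:', min(displayable), max(displayable))
```

The normalized values fall on both sides of zero, which is why an image viewer that expects values in [0, 1] or [0, 255] cannot show them without the rescaling step.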
Use ds_info (which is an instance of tfds.core.DatasetInfo) to look up the text descriptions of each class ID.
label_info = ds_info.features['label']
label_info.int2str(1)
'automobile'
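Conceptually, the lookup is a mapping between integer class IDs and names. The sketch below reproduces it with a plain list of the ten CIFAR-10 class names in label order. The standalone `int2str` and `str2int` functions here are stand-ins written for this example and are not the TensorFlow Datasets implementation.

```python
# CIFAR-10 class names in label-ID order.
CLASS_NAMES = ['airplane', 'automobile', 'bird', 'cat', 'deer',
               'dog', 'frog', 'horse', 'ship', 'truck']


def int2str(label_id: int) -> str:
    """Returns the class name for an integer label ID."""
    return CLASS_NAMES[label_id]


def str2int(name: str) -> int:
    """Returns the integer label ID for a class name."""
    return CLASS_NAMES.index(name)


print(int2str(1))
print(str2int('truck'))
```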
Visualize a batch of the data.
def show_batch(images, labels, predictions=None):
    plt.figure(figsize=(10, 10))
    # Rescale the normalized pixel values into [0, 1] so imshow can display them.
    low = images.numpy().min()
    high = images.numpy().max()
    delta = high - low
    for i in range(12):
        plt.subplot(6, 6, i + 1)
        plt.imshow((images[i] - low) / delta)
        if predictions is None:
            plt.title(label_info.int2str(labels[i]))
        else:
            # Green title for a correct prediction, red for a wrong one.
            if labels[i] == predictions[i]:
                color = 'g'
            else:
                color = 'r'
            plt.title(label_info.int2str(predictions[i]), color=color)
        plt.axis("off")
plt.figure(figsize=(10, 10))
for images, labels in task.build_inputs(exp_config.task.train_data).take(1):
    show_batch(images, labels)
2022-12-17 12:25:44.100513: W tensorflow/core/kernels/data/cache_dataset_ops.cc:856] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead. <Figure size 1000x1000 with 0 Axes>
Visualize the testing data
Visualize a batch of images from the validation dataset.
plt.figure(figsize=(10, 10));
for images, labels in task.build_inputs(exp_config.task.validation_data).take(1):
    show_batch(images, labels)
2022-12-17 12:25:45.661145: W tensorflow/core/kernels/data/cache_dataset_ops.cc:856] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead. <Figure size 1000x1000 with 0 Axes>
Train and evaluate
model, eval_logs = tfm.core.train_lib.run_experiment(
    distribution_strategy=distribution_strategy,
    task=task,
    mode='train_and_eval',
    params=exp_config,
    model_dir=model_dir,
    run_post_eval=True)
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',). INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',). INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',). INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',). INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',). INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',). INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',). INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',). INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',). INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',). INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',). INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',). INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',). INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',). 
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',). INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',). INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',). INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',). INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',). INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',). restoring or initializing model... initialized model. train | step: 0 | training until step 1000... INFO:tensorflow:batch_all_reduce: 65 all-reduces with algorithm = nccl, num_packs = 1 INFO:tensorflow:batch_all_reduce: 65 all-reduces with algorithm = nccl, num_packs = 1 INFO:tensorflow:batch_all_reduce: 65 all-reduces with algorithm = nccl, num_packs = 1 INFO:tensorflow:batch_all_reduce: 65 all-reduces with algorithm = nccl, num_packs = 1 train | step: 100 | steps/sec: 3.9 | output: {'accuracy': 0.20242187, 'learning_rate': 0.09990134, 'top_5_accuracy': 0.71296877, 'training_loss': 2.6425662} saved checkpoint to /tmpfs/tmp/tmpctqej0c8/ckpt-100. 
train | step: 200 | steps/sec: 24.6 | output: {'accuracy': 0.24070312, 'learning_rate': 0.09960574, 'top_5_accuracy': 0.7657812, 'training_loss': 2.5396342} train | step: 300 | steps/sec: 30.8 | output: {'accuracy': 0.29671875, 'learning_rate': 0.09911436, 'top_5_accuracy': 0.82375, 'training_loss': 2.2092695} train | step: 400 | steps/sec: 31.5 | output: {'accuracy': 0.31757814, 'learning_rate': 0.09842916, 'top_5_accuracy': 0.8296875, 'training_loss': 2.1373692} train | step: 500 | steps/sec: 31.4 | output: {'accuracy': 0.33101562, 'learning_rate': 0.09755283, 'top_5_accuracy': 0.84054685, 'training_loss': 2.0857556} train | step: 600 | steps/sec: 31.6 | output: {'accuracy': 0.34929687, 'learning_rate': 0.096488826, 'top_5_accuracy': 0.85148436, 'training_loss': 2.0427072} train | step: 700 | steps/sec: 31.5 | output: {'accuracy': 0.35867187, 'learning_rate': 0.09524136, 'top_5_accuracy': 0.8584375, 'training_loss': 1.9956405} train | step: 800 | steps/sec: 31.4 | output: {'accuracy': 0.3709375, 'learning_rate': 0.09381534, 'top_5_accuracy': 0.86507815, 'training_loss': 1.9741195} train | step: 900 | steps/sec: 31.4 | output: {'accuracy': 0.38078126, 'learning_rate': 0.092216395, 'top_5_accuracy': 0.8708594, 'training_loss': 1.94308} train | step: 1000 | steps/sec: 31.5 | output: {'accuracy': 0.3907031, 'learning_rate': 0.090450846, 'top_5_accuracy': 0.8723438, 'training_loss': 1.9159793} eval | step: 1000 | running 78 steps of evaluation... eval | step: 1000 | eval time: 5.7 sec | output: {'accuracy': 0.5116186, 'top_5_accuracy': 0.9391026, 'validation_loss': 1.6100333} train | step: 1000 | training until step 2000... 
train | step: 1100 | steps/sec: 11.3 | output: {'accuracy': 0.40046874, 'learning_rate': 0.08852567, 'top_5_accuracy': 0.87492186, 'training_loss': 1.896727} train | step: 1200 | steps/sec: 31.5 | output: {'accuracy': 0.40625, 'learning_rate': 0.08644843, 'top_5_accuracy': 0.8813281, 'training_loss': 1.8706796} train | step: 1300 | steps/sec: 31.3 | output: {'accuracy': 0.41507813, 'learning_rate': 0.08422736, 'top_5_accuracy': 0.8851563, 'training_loss': 1.8534433} train | step: 1400 | steps/sec: 31.5 | output: {'accuracy': 0.4297656, 'learning_rate': 0.081871204, 'top_5_accuracy': 0.89023435, 'training_loss': 1.8151345} train | step: 1500 | steps/sec: 31.4 | output: {'accuracy': 0.4285156, 'learning_rate': 0.07938927, 'top_5_accuracy': 0.89039063, 'training_loss': 1.8241135} train | step: 1600 | steps/sec: 31.4 | output: {'accuracy': 0.43828124, 'learning_rate': 0.07679134, 'top_5_accuracy': 0.894375, 'training_loss': 1.7994397} train | step: 1700 | steps/sec: 31.3 | output: {'accuracy': 0.44578126, 'learning_rate': 0.07408768, 'top_5_accuracy': 0.8994531, 'training_loss': 1.767456} train | step: 1800 | steps/sec: 31.5 | output: {'accuracy': 0.450625, 'learning_rate': 0.071288966, 'top_5_accuracy': 0.9035938, 'training_loss': 1.733595} train | step: 1900 | steps/sec: 31.5 | output: {'accuracy': 0.45164064, 'learning_rate': 0.068406224, 'top_5_accuracy': 0.901875, 'training_loss': 1.734338} train | step: 2000 | steps/sec: 31.4 | output: {'accuracy': 0.46453124, 'learning_rate': 0.06545085, 'top_5_accuracy': 0.9046094, 'training_loss': 1.7166274} eval | step: 2000 | running 78 steps of evaluation... eval | step: 2000 | eval time: 0.9 sec | output: {'accuracy': 0.58984375, 'top_5_accuracy': 0.95502806, 'validation_loss': 1.3773639} train | step: 2000 | training until step 3000... 
train | step: 2100 | steps/sec: 24.3 | output: {'accuracy': 0.46710938, 'learning_rate': 0.062434502, 'top_5_accuracy': 0.9053906, 'training_loss': 1.6979953} train | step: 2200 | steps/sec: 31.5 | output: {'accuracy': 0.47875, 'learning_rate': 0.059369065, 'top_5_accuracy': 0.91085935, 'training_loss': 1.6677152} train | step: 2300 | steps/sec: 31.4 | output: {'accuracy': 0.484375, 'learning_rate': 0.056266654, 'top_5_accuracy': 0.91578126, 'training_loss': 1.6433257} train | step: 2400 | steps/sec: 31.5 | output: {'accuracy': 0.49023438, 'learning_rate': 0.053139526, 'top_5_accuracy': 0.91234374, 'training_loss': 1.6406174} train | step: 2500 | steps/sec: 31.5 | output: {'accuracy': 0.48601562, 'learning_rate': 0.049999997, 'top_5_accuracy': 0.9121094, 'training_loss': 1.6500366} train | step: 2600 | steps/sec: 31.5 | output: {'accuracy': 0.50890625, 'learning_rate': 0.04686048, 'top_5_accuracy': 0.91796875, 'training_loss': 1.6033983} train | step: 2700 | steps/sec: 31.4 | output: {'accuracy': 0.5030469, 'learning_rate': 0.043733336, 'top_5_accuracy': 0.92078125, 'training_loss': 1.5989766} train | step: 2800 | steps/sec: 31.5 | output: {'accuracy': 0.51734376, 'learning_rate': 0.040630933, 'top_5_accuracy': 0.9248437, 'training_loss': 1.5626732} train | step: 2900 | steps/sec: 31.5 | output: {'accuracy': 0.5224219, 'learning_rate': 0.037565507, 'top_5_accuracy': 0.9183594, 'training_loss': 1.5648793} train | step: 3000 | steps/sec: 31.5 | output: {'accuracy': 0.5271094, 'learning_rate': 0.034549143, 'top_5_accuracy': 0.9263281, 'training_loss': 1.537529} eval | step: 3000 | running 78 steps of evaluation... eval | step: 3000 | eval time: 0.9 sec | output: {'accuracy': 0.65965545, 'top_5_accuracy': 0.96784854, 'validation_loss': 1.1733003} train | step: 3000 | training until step 4000... 
train | step: 3100 | steps/sec: 24.2 | output: {'accuracy': 0.52953124, 'learning_rate': 0.03159377, 'top_5_accuracy': 0.92929685, 'training_loss': 1.5256608} train | step: 3200 | steps/sec: 31.5 | output: {'accuracy': 0.5339062, 'learning_rate': 0.028711034, 'top_5_accuracy': 0.9330469, 'training_loss': 1.5184388} train | step: 3300 | steps/sec: 31.5 | output: {'accuracy': 0.53945315, 'learning_rate': 0.025912309, 'top_5_accuracy': 0.92742187, 'training_loss': 1.5155553} train | step: 3400 | steps/sec: 31.5 | output: {'accuracy': 0.5420312, 'learning_rate': 0.023208654, 'top_5_accuracy': 0.93242186, 'training_loss': 1.494617} train | step: 3500 | steps/sec: 31.5 | output: {'accuracy': 0.5495312, 'learning_rate': 0.020610739, 'top_5_accuracy': 0.930625, 'training_loss': 1.4852078} train | step: 3600 | steps/sec: 31.4 | output: {'accuracy': 0.55140626, 'learning_rate': 0.018128792, 'top_5_accuracy': 0.933125, 'training_loss': 1.4555632} train | step: 3700 | steps/sec: 31.3 | output: {'accuracy': 0.56242186, 'learning_rate': 0.015772644, 'top_5_accuracy': 0.936875, 'training_loss': 1.4414537} train | step: 3800 | steps/sec: 31.6 | output: {'accuracy': 0.5635156, 'learning_rate': 0.01355157, 'top_5_accuracy': 0.93953127, 'training_loss': 1.4296999} train | step: 3900 | steps/sec: 31.5 | output: {'accuracy': 0.5769531, 'learning_rate': 0.011474336, 'top_5_accuracy': 0.93671876, 'training_loss': 1.4046975} train | step: 4000 | steps/sec: 31.6 | output: {'accuracy': 0.5765625, 'learning_rate': 0.009549147, 'top_5_accuracy': 0.93875, 'training_loss': 1.4086288} eval | step: 4000 | running 78 steps of evaluation... eval | step: 4000 | eval time: 0.9 sec | output: {'accuracy': 0.7139423, 'top_5_accuracy': 0.97566104, 'validation_loss': 1.0238551} train | step: 4000 | training until step 5000... 
train | step: 4100 | steps/sec: 24.3 | output: {'accuracy': 0.58640623, 'learning_rate': 0.0077836006, 'top_5_accuracy': 0.9411719, 'training_loss': 1.3843393} train | step: 4200 | steps/sec: 31.5 | output: {'accuracy': 0.58023435, 'learning_rate': 0.0061846706, 'top_5_accuracy': 0.94109374, 'training_loss': 1.3896589} train | step: 4300 | steps/sec: 31.4 | output: {'accuracy': 0.5786719, 'learning_rate': 0.0047586444, 'top_5_accuracy': 0.93921876, 'training_loss': 1.3887373} train | step: 4400 | steps/sec: 31.4 | output: {'accuracy': 0.58320314, 'learning_rate': 0.0035111725, 'top_5_accuracy': 0.9433594, 'training_loss': 1.3801813} train | step: 4500 | steps/sec: 31.4 | output: {'accuracy': 0.5899219, 'learning_rate': 0.002447176, 'top_5_accuracy': 0.946875, 'training_loss': 1.3682168} train | step: 4600 | steps/sec: 31.5 | output: {'accuracy': 0.5877344, 'learning_rate': 0.0015708387, 'top_5_accuracy': 0.94242185, 'training_loss': 1.37529} train | step: 4700 | steps/sec: 31.6 | output: {'accuracy': 0.5903125, 'learning_rate': 0.0008856386, 'top_5_accuracy': 0.9416406, 'training_loss': 1.3593487} train | step: 4800 | steps/sec: 31.6 | output: {'accuracy': 0.6027344, 'learning_rate': 0.00039426386, 'top_5_accuracy': 0.9451563, 'training_loss': 1.3419139} train | step: 4900 | steps/sec: 31.4 | output: {'accuracy': 0.59585935, 'learning_rate': 9.866357e-05, 'top_5_accuracy': 0.94953126, 'training_loss': 1.3465914} train | step: 5000 | steps/sec: 31.3 | output: {'accuracy': 0.59953123, 'learning_rate': 0.0, 'top_5_accuracy': 0.9472656, 'training_loss': 1.3405178} eval | step: 5000 | running 78 steps of evaluation... eval | step: 5000 | eval time: 0.9 sec | output: {'accuracy': 0.72415864, 'top_5_accuracy': 0.9777644, 'validation_loss': 0.9919631} saved checkpoint to /tmpfs/tmp/tmpctqej0c8/ckpt-5000. 
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/ops/nn_ops.py:5250: tensor_shape_from_node_def_name (from tensorflow.python.framework.graph_util_impl) is deprecated and will be removed in a future version. Instructions for updating: This API was designed for TensorFlow v1. See https://www.tensorflow.org/guide/migrate for instructions on how to migrate your code to TensorFlow v2. WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/ops/nn_ops.py:5250: tensor_shape_from_node_def_name (from tensorflow.python.framework.graph_util_impl) is deprecated and will be removed in a future version. Instructions for updating: This API was designed for TensorFlow v1. See https://www.tensorflow.org/guide/migrate for instructions on how to migrate your code to TensorFlow v2. eval | step: 5000 | running 78 steps of evaluation... eval | step: 5000 | eval time: 0.9 sec | output: {'accuracy': 0.72415864, 'top_5_accuracy': 0.9777644, 'validation_loss': 0.9919631}
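The log above alternates between training and evaluation. A simplified model of that schedule is sketched below, under the assumption that training advances in chunks of steps_per_loop and that an evaluation runs whenever the step count reaches a multiple of validation_interval. The `schedule` generator is hypothetical and ignores checkpointing, summaries, and the final post-training evaluation.

```python
def schedule(train_steps, validation_interval, steps_per_loop):
    """Yields ('train', step) and ('eval', step) events in execution order.

    Simplified model: training advances in chunks of steps_per_loop and an
    evaluation runs each time the step count reaches a multiple of
    validation_interval.
    """
    step = 0
    while step < train_steps:
        step = min(step + steps_per_loop, train_steps)
        yield ('train', step)
        if step % validation_interval == 0:
            yield ('eval', step)


events = list(schedule(train_steps=5000, validation_interval=1000, steps_per_loop=100))
eval_steps = [step for kind, step in events if kind == 'eval']
print('evaluations at steps:', eval_steps)
print('training loops:', sum(1 for kind, _ in events if kind == 'train'))
```

With the GPU/TPU settings from this tutorial, the sketch yields 50 training loops and evaluations at steps 1000 through 5000, matching the pattern of train and eval lines in the log.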
# tf.keras.utils.plot_model(model, show_shapes=True)
Print the accuracy, top_5_accuracy, and validation_loss evaluation metrics.
for key, value in eval_logs.items():
    if isinstance(value, tf.Tensor):
        value = value.numpy()
    print(f'{key:20}: {value:.3f}')
accuracy : 0.724 top_5_accuracy : 0.978 validation_loss : 0.992
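The difference between accuracy and top_5_accuracy is easiest to see on a toy example. The sketch below computes top-k accuracy in plain Python on made-up scores for three examples and four classes. The `top_k_accuracy` function is written for this illustration and is not the metric implementation used during training.

```python
def top_k_accuracy(logits, labels, k=1):
    """Fraction of examples whose true label is among the k largest logits."""
    hits = 0
    for scores, label in zip(logits, labels):
        # Indices of the classes for this example, highest score first.
        ranked = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)
        if label in ranked[:k]:
            hits += 1
    return hits / len(labels)


# Three toy examples with four classes each (made-up scores).
logits = [
    [2.0, 0.5, 0.1, -1.0],   # predicts class 0
    [0.2, 0.1, 1.5, 0.9],    # predicts class 2, runner-up class 3
    [0.3, 0.4, 0.1, 0.2],    # predicts class 1, runner-up class 0
]
labels = [0, 3, 2]

print('top-1 accuracy:', top_k_accuracy(logits, labels, k=1))
print('top-2 accuracy:', top_k_accuracy(logits, labels, k=2))
```

The second example counts as a miss for top-1 but as a hit for top-2, because the true class has the second-highest score. The same effect explains why top_5_accuracy is much higher than accuracy in the results above.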
Run a batch of the processed training data through the model, and view the results.
for images, labels in task.build_inputs(exp_config.task.train_data).take(1):
    predictions = model.predict(images)
    predictions = tf.argmax(predictions, axis=-1)
show_batch(images, labels, tf.cast(predictions, tf.int32))
if device == 'CPU':
    plt.suptitle('The model was only trained for a few steps, it is not expected to do well.')
2022-12-17 12:29:06.219775: W tensorflow/core/kernels/data/cache_dataset_ops.cc:856] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead. 4/4 [==============================] - 3s 7ms/step
Export a SavedModel
The keras.Model object returned by train_lib.run_experiment expects the data to be normalized by the dataset loader using the same mean and variance statistics in preprocess_ops.normalize_image(image, offset=MEAN_RGB, scale=STDDEV_RGB). This export function handles those details, so you can pass tf.uint8 images and get the correct results.
# Saving and exporting the trained model
export_saved_model_lib.export_inference_graph(
    input_type='image_tensor',
    batch_size=1,
    input_image_size=[32, 32],
    params=exp_config,
    checkpoint_path=tf.train.latest_checkpoint(model_dir),
    export_dir='./export/')
WARNING:absl:Found untraced functions such as inference_for_tflite, inference_from_image_bytes, inference_from_tf_example, _jit_compiled_convolution_op, conv2d_layer_call_fn while saving (showing 5 of 64). These functions will not be directly callable after loading. INFO:tensorflow:Assets written to: ./export/assets INFO:tensorflow:Assets written to: ./export/assets
Test the exported model.
# Importing SavedModel
imported = tf.saved_model.load('./export/')
model_fn = imported.signatures['serving_default']
Visualize the predictions.
plt.figure(figsize=(10, 10))
for data in tfds.load('cifar10', split='test').batch(12).take(1):
    predictions = []
    for image in data['image']:
        index = tf.argmax(model_fn(image[tf.newaxis, ...])['logits'], axis=1)[0]
        predictions.append(index)
    show_batch(data['image'], data['label'], predictions)
    if device == 'CPU':
        plt.suptitle('The model was only trained for a few steps, it is not expected to do better than random.')
2022-12-17 12:29:21.455321: W tensorflow/core/kernels/data/cache_dataset_ops.cc:856] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead. <Figure size 1000x1000 with 0 Axes>