This tutorial fine-tunes a RetinaNet with a ResNet-50 backbone from the TensorFlow Model Garden package (tensorflow-models) to detect three types of blood cells in the BCCD dataset. The RetinaNet is pretrained on COCO train2017 and evaluated on COCO val2017.
Model Garden contains a collection of state-of-the-art models, implemented with TensorFlow's high-level APIs. The implementations demonstrate best practices for modeling, letting users take full advantage of TensorFlow for their research and product development.
This tutorial demonstrates how to:
- Use models from the TensorFlow Model Garden (TFM) package.
- Fine-tune a pre-trained RetinaNet with ResNet-50 as backbone for object detection.
- Export the tuned RetinaNet model.
Install necessary dependencies
pip install -U -q "tensorflow" "tf-models-official"
Import required libraries
import os
import io
import pprint
import tempfile
import matplotlib
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
from PIL import Image
from six import BytesIO
from IPython import display
from urllib.request import urlopen
2023-01-22 12:10:58.598384: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer.so.7'; dlerror: libnvinfer.so.7: cannot open shared object file: No such file or directory 2023-01-22 12:10:58.598507: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer_plugin.so.7'; dlerror: libnvinfer_plugin.so.7: cannot open shared object file: No such file or directory 2023-01-22 12:10:58.598518: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Cannot dlopen some TensorRT libraries. If you would like to use Nvidia GPU with TensorRT, please make sure the missing libraries mentioned above are installed properly.
Import required libraries from tensorflow models
import orbit
import tensorflow_models as tfm
from official.core import exp_factory
from official.core import config_definitions as cfg
from official.vision.serving import export_saved_model_lib
from official.vision.ops.preprocess_ops import normalize_image
from official.vision.ops.preprocess_ops import resize_and_crop_image
from official.vision.utils.object_detection import visualization_utils
from official.vision.dataloaders.tf_example_decoder import TfExampleDecoder
pp = pprint.PrettyPrinter(indent=4) # Set Pretty Print Indentation
print(tf.__version__) # Check the version of tensorflow used
%matplotlib inline
2.11.0
Custom dataset preparation for object detection
Models in the official Model Garden repository require data in the TFRecord format.
Please check this resource to learn more about the TFRecord data format.
Clone the Model Garden repository, because the required data conversion code lives there.
git clone --quiet https://github.com/tensorflow/models.git
Upload your custom data to Drive or to the local disk of the notebook, then unzip it.
curl -L 'https://public.roboflow.com/ds/ZpYLqHeT0W?key=ZXfZLRnhsc' > './BCCD.v1-bccd.coco.zip'
unzip -q -o './BCCD.v1-bccd.coco.zip' -d './BCC.v1-bccd.coco/'
rm './BCCD.v1-bccd.coco.zip'
% Total % Received % Xferd Average Speed Time Time Time Current Dload Upload Total Spent Left Speed 100 892 100 892 0 0 495 0 0:00:01 0:00:01 --:--:-- 494 100 15.2M 100 15.2M 0 0 7045k 0 0:00:02 0:00:02 --:--:-- 7045k
Change directory to the cloned models repository, where the data conversion tools are available.
%cd ./models/
/tmpfs/src/temp/docs/vision/models
CLI command to convert the data (train data).
TRAIN_DATA_DIR='../BCC.v1-bccd.coco/train'
TRAIN_ANNOTATION_FILE_DIR='../BCC.v1-bccd.coco/train/_annotations.coco.json'
OUTPUT_TFRECORD_TRAIN='../bccd_coco_tfrecords/train'
# The converter needs:
# 1. image_dir: directory that contains the images
# 2. object_annotations_file: COCO-style annotations in JSON format
# 3. output_file_prefix: path prefix for the converted TFRecord files
python -m official.vision.data.create_coco_tf_record --logtostderr \
--image_dir=${TRAIN_DATA_DIR} \
--object_annotations_file=${TRAIN_ANNOTATION_FILE_DIR} \
--output_file_prefix=$OUTPUT_TFRECORD_TRAIN \
--num_shards=1
2023-01-22 12:11:26.864361: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer.so.7'; dlerror: libnvinfer.so.7: cannot open shared object file: No such file or directory; LD_LIBRARY_PATH: /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/cv2/../../lib64: 2023-01-22 12:11:26.864469: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer_plugin.so.7'; dlerror: libnvinfer_plugin.so.7: cannot open shared object file: No such file or directory; LD_LIBRARY_PATH: /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/cv2/../../lib64: 2023-01-22 12:11:26.864479: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Cannot dlopen some TensorRT libraries. If you would like to use Nvidia GPU with TensorRT, please make sure the missing libraries mentioned above are installed properly. I0122 12:11:28.821597 140395509507904 create_coco_tf_record.py:499] writing to output path: ../bccd_coco_tfrecords/train I0122 12:11:28.891679 140395509507904 create_coco_tf_record.py:371] Building bounding box index. I0122 12:11:28.893189 140395509507904 create_coco_tf_record.py:382] 0 images are missing bboxes. I0122 12:11:29.134161 140395509507904 tfrecord_lib.py:168] On image 0 I0122 12:11:29.139740 140395509507904 tfrecord_lib.py:168] On image 100 I0122 12:11:29.144275 140395509507904 tfrecord_lib.py:168] On image 200 I0122 12:11:29.148532 140395509507904 tfrecord_lib.py:168] On image 300 I0122 12:11:29.153190 140395509507904 tfrecord_lib.py:168] On image 400 I0122 12:11:29.157748 140395509507904 tfrecord_lib.py:168] On image 500 I0122 12:11:29.162106 140395509507904 tfrecord_lib.py:168] On image 600 I0122 12:11:29.166518 140395509507904 tfrecord_lib.py:168] On image 700 I0122 12:11:29.185209 140395509507904 tfrecord_lib.py:180] Finished writing, skipped 6 annotations. 
I0122 12:11:29.192366 140395509507904 create_coco_tf_record.py:534] Finished writing, skipped 6 annotations.
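The converter log above reports a handful of skipped annotations. A minimal sketch of how one might inspect a COCO-style annotation file before conversion is shown here. It assumes the usual COCO layout (`images`, `annotations`, `categories`, with boxes stored as `[x, y, width, height]`), and the rule it uses to flag boxes (non-positive size or extending past the image) is an assumption about why a converter might skip an annotation, not a description of the converter's exact logic. The in-memory `example` dictionary is hypothetical.

```python
import json  # only needed when loading a real annotation file
from collections import Counter


def summarize_coco(coco):
    """Count annotations per category and flag suspicious boxes."""
    images = {img['id']: img for img in coco['images']}
    names = {cat['id']: cat['name'] for cat in coco['categories']}
    counts = Counter()
    suspicious = []
    for ann in coco['annotations']:
        counts[names[ann['category_id']]] += 1
        x, y, w, h = ann['bbox']  # COCO boxes: [x, y, width, height]
        img = images[ann['image_id']]
        # Assumed rule: degenerate boxes or boxes leaving the image are suspect.
        if w <= 0 or h <= 0 or x + w > img['width'] or y + h > img['height']:
            suspicious.append(ann['id'])
    return dict(counts), suspicious


# Hypothetical miniature annotation file, for illustration only.
example = {
    'images': [{'id': 0, 'width': 640, 'height': 480}],
    'categories': [{'id': 2, 'name': 'RBC'}, {'id': 3, 'name': 'WBC'}],
    'annotations': [
        {'id': 1, 'image_id': 0, 'category_id': 2, 'bbox': [10, 10, 50, 50]},
        {'id': 2, 'image_id': 0, 'category_id': 3, 'bbox': [100, 80, 120, 90]},
        {'id': 3, 'image_id': 0, 'category_id': 2, 'bbox': [600, 400, 80, 90]},
    ],
}

counts, suspicious = summarize_coco(example)
print(counts)
print(suspicious)

# For a real file, something like:
# with open('../BCC.v1-bccd.coco/train/_annotations.coco.json') as f:
#     counts, suspicious = summarize_coco(json.load(f))
```

Running a check like this on each split makes it easier to tell whether skipped annotations point to a labeling problem or are harmless.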
CLI command to convert the data (validation data).
VALID_DATA_DIR='../BCC.v1-bccd.coco/valid'
VALID_ANNOTATION_FILE_DIR='../BCC.v1-bccd.coco/valid/_annotations.coco.json'
OUTPUT_TFRECORD_VALID='../bccd_coco_tfrecords/valid'
python -m official.vision.data.create_coco_tf_record --logtostderr \
--image_dir=$VALID_DATA_DIR \
--object_annotations_file=$VALID_ANNOTATION_FILE_DIR \
--output_file_prefix=$OUTPUT_TFRECORD_VALID \
--num_shards=1
2023-01-22 12:11:30.949344: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer.so.7'; dlerror: libnvinfer.so.7: cannot open shared object file: No such file or directory; LD_LIBRARY_PATH: /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/cv2/../../lib64: 2023-01-22 12:11:30.949450: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer_plugin.so.7'; dlerror: libnvinfer_plugin.so.7: cannot open shared object file: No such file or directory; LD_LIBRARY_PATH: /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/cv2/../../lib64: 2023-01-22 12:11:30.949460: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Cannot dlopen some TensorRT libraries. If you would like to use Nvidia GPU with TensorRT, please make sure the missing libraries mentioned above are installed properly. I0122 12:11:32.639877 139722239469376 create_coco_tf_record.py:499] writing to output path: ../bccd_coco_tfrecords/valid I0122 12:11:32.647148 139722239469376 create_coco_tf_record.py:371] Building bounding box index. I0122 12:11:32.647382 139722239469376 create_coco_tf_record.py:382] 0 images are missing bboxes. I0122 12:11:32.823836 139722239469376 tfrecord_lib.py:168] On image 0 I0122 12:11:32.843802 139722239469376 tfrecord_lib.py:180] Finished writing, skipped 0 annotations. I0122 12:11:32.845139 139722239469376 create_coco_tf_record.py:534] Finished writing, skipped 0 annotations.
CLI command to convert the data (test data).
TEST_DATA_DIR='../BCC.v1-bccd.coco/test'
TEST_ANNOTATION_FILE_DIR='../BCC.v1-bccd.coco/test/_annotations.coco.json'
OUTPUT_TFRECORD_TEST='../bccd_coco_tfrecords/test'
python -m official.vision.data.create_coco_tf_record --logtostderr \
--image_dir=$TEST_DATA_DIR \
--object_annotations_file=$TEST_ANNOTATION_FILE_DIR \
--output_file_prefix=$OUTPUT_TFRECORD_TEST \
--num_shards=1
2023-01-22 12:11:34.586338: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer.so.7'; dlerror: libnvinfer.so.7: cannot open shared object file: No such file or directory; LD_LIBRARY_PATH: /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/cv2/../../lib64: 2023-01-22 12:11:34.586459: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer_plugin.so.7'; dlerror: libnvinfer_plugin.so.7: cannot open shared object file: No such file or directory; LD_LIBRARY_PATH: /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/cv2/../../lib64: 2023-01-22 12:11:34.586471: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Cannot dlopen some TensorRT libraries. If you would like to use Nvidia GPU with TensorRT, please make sure the missing libraries mentioned above are installed properly. I0122 12:11:36.304936 140400327403328 create_coco_tf_record.py:499] writing to output path: ../bccd_coco_tfrecords/test I0122 12:11:36.308660 140400327403328 create_coco_tf_record.py:371] Building bounding box index. I0122 12:11:36.308827 140400327403328 create_coco_tf_record.py:382] 0 images are missing bboxes. I0122 12:11:36.471271 140400327403328 tfrecord_lib.py:168] On image 0 I0122 12:11:36.490734 140400327403328 tfrecord_lib.py:180] Finished writing, skipped 0 annotations. I0122 12:11:36.491844 140400327403328 create_coco_tf_record.py:534] Finished writing, skipped 0 annotations.
Configure the RetinaNet ResNet FPN COCO model for the custom dataset.
The dataset used for fine-tuning the checkpoint is Blood Cells Detection (BCCD).
train_data_input_path = '../bccd_coco_tfrecords/train-00000-of-00001.tfrecord'
valid_data_input_path = '../bccd_coco_tfrecords/valid-00000-of-00001.tfrecord'
test_data_input_path = '../bccd_coco_tfrecords/test-00000-of-00001.tfrecord'
model_dir = '../trained_model/'
export_dir ='../exported_model/'
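The input paths above follow from the `output_file_prefix` and `num_shards` values passed to the converter. As a small sketch, inferred only from the file names used in this tutorial (a prefix followed by `-00000-of-00001.tfrecord`), the shard names can be generated like this. The helper `shard_paths` is hypothetical and not part of Model Garden.

```python
def shard_paths(prefix, num_shards):
    """Build shard file names of the form <prefix>-00000-of-0000N.tfrecord."""
    return [
        f'{prefix}-{index:05d}-of-{num_shards:05d}.tfrecord'
        for index in range(num_shards)
    ]


train_shards = shard_paths('../bccd_coco_tfrecords/train', 1)
print(train_shards)
```

With more than one shard, the same helper yields the full list of files to check before pointing `input_path` at them.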
In Model Garden, the collections of parameters that define a model are called configs. Model Garden can create a config based on a known set of parameters via a factory.
Use the retinanet_resnetfpn_coco experiment configuration, as defined by tfm.vision.configs.retinanet.retinanet_resnetfpn_coco.
The configuration defines an experiment to train a RetinaNet with ResNet-50 as backbone and FPN as decoder. The default configuration is trained on COCO train2017 and evaluated on COCO val2017.
There are also other alternative experiments available, such as retinanet_resnetfpn_coco, retinanet_spinenet_coco, fasterrcnn_resnetfpn_coco and more. One can switch to them by changing the experiment name argument to the get_exp_config function.
We are going to fine-tune from the ResNet-50 backbone checkpoint, which is already set in the default configuration.
exp_config = exp_factory.get_exp_config('retinanet_resnetfpn_coco')
Adjust the model and dataset configurations so that they work with the custom dataset (in this case BCCD).
batch_size = 8
num_classes = 3
HEIGHT, WIDTH = 256, 256
IMG_SIZE = [HEIGHT, WIDTH, 3]
# Backbone config.
exp_config.task.freeze_backbone = False
exp_config.task.annotation_file = ''
# Model config.
exp_config.task.model.input_size = IMG_SIZE
exp_config.task.model.num_classes = num_classes + 1
exp_config.task.model.detection_generator.tflite_post_processing.max_classes_per_detection = exp_config.task.model.num_classes
# Training data config.
exp_config.task.train_data.input_path = train_data_input_path
exp_config.task.train_data.dtype = 'float32'
exp_config.task.train_data.global_batch_size = batch_size
exp_config.task.train_data.parser.aug_scale_max = 1.0
exp_config.task.train_data.parser.aug_scale_min = 1.0
# Validation data config.
exp_config.task.validation_data.input_path = valid_data_input_path
exp_config.task.validation_data.dtype = 'float32'
exp_config.task.validation_data.global_batch_size = batch_size
Adjust the trainer configuration.
logical_device_names = [logical_device.name for logical_device in tf.config.list_logical_devices()]
if 'GPU' in ''.join(logical_device_names):
  print('This may be broken in Colab.')
  device = 'GPU'
elif 'TPU' in ''.join(logical_device_names):
  print('This may be broken in Colab.')
  device = 'TPU'
else:
  print('Running on CPU is slow, so only train for a few steps.')
  device = 'CPU'
train_steps = 1000
exp_config.trainer.steps_per_loop = 100 # steps_per_loop = num_of_training_examples // train_batch_size
exp_config.trainer.summary_interval = 100
exp_config.trainer.checkpoint_interval = 100
exp_config.trainer.validation_interval = 100
exp_config.trainer.validation_steps = 100 # validation_steps = num_of_validation_examples // eval_batch_size
exp_config.trainer.train_steps = train_steps
exp_config.trainer.optimizer_config.warmup.linear.warmup_steps = 100
exp_config.trainer.optimizer_config.learning_rate.type = 'cosine'
exp_config.trainer.optimizer_config.learning_rate.cosine.decay_steps = train_steps
exp_config.trainer.optimizer_config.learning_rate.cosine.initial_learning_rate = 0.1
exp_config.trainer.optimizer_config.warmup.linear.warmup_learning_rate = 0.05
This may be broken in Colab.
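The comments in the trainer configuration suggest deriving `steps_per_loop` and `validation_steps` from the dataset size and batch size, while the cell itself uses a fixed value of 100. A minimal sketch of that arithmetic follows. The example counts (760 training and 72 validation examples) are hypothetical placeholders, not the actual BCCD split sizes.

```python
def steps_for(num_examples, batch_size):
    """Number of full batches in one pass over the data (at least 1)."""
    return max(1, num_examples // batch_size)


def epochs_covered(train_steps, batch_size, num_examples):
    """Approximate number of passes over the data for a given step budget."""
    return train_steps * batch_size / num_examples


# Hypothetical dataset sizes, for illustration only.
num_train_examples = 760
num_valid_examples = 72
batch_size = 8

steps_per_loop = steps_for(num_train_examples, batch_size)
validation_steps = steps_for(num_valid_examples, batch_size)
print(steps_per_loop, validation_steps)
print(epochs_covered(1000, batch_size, num_train_examples))
```

Values computed this way could be assigned to `exp_config.trainer.steps_per_loop` and `exp_config.trainer.validation_steps` in place of the fixed numbers.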
Print the modified configuration.
pp.pprint(exp_config.as_dict())
display.Javascript('google.colab.output.setIframeHeight("500px");')
{ 'runtime': { 'all_reduce_alg': None, 'batchnorm_spatial_persistent': False, 'dataset_num_private_threads': None, 'default_shard_dim': -1, 'distribution_strategy': 'mirrored', 'enable_xla': False, 'gpu_thread_mode': None, 'loss_scale': None, 'mixed_precision_dtype': 'bfloat16', 'num_cores_per_replica': 1, 'num_gpus': 0, 'num_packs': 1, 'per_gpu_thread_count': 0, 'run_eagerly': False, 'task_index': -1, 'tpu': None, 'tpu_enable_xla_dynamic_padder': None, 'worker_hosts': None}, 'task': { 'annotation_file': '', 'differential_privacy_config': None, 'export_config': { 'cast_detection_classes_to_float': False, 'cast_num_detections_to_float': False, 'output_normalized_coordinates': False}, 'freeze_backbone': False, 'init_checkpoint': 'gs://cloud-tpu-checkpoints/vision-2.0/resnet50_imagenet/ckpt-28080', 'init_checkpoint_modules': 'backbone', 'losses': { 'box_loss_weight': 50, 'focal_loss_alpha': 0.25, 'focal_loss_gamma': 1.5, 'huber_loss_delta': 0.1, 'l2_weight_decay': 0.0001, 'loss_weight': 1.0}, 'model': { 'anchor': { 'anchor_size': 4.0, 'aspect_ratios': [0.5, 1.0, 2.0], 'num_scales': 3}, 'backbone': { 'resnet': { 'bn_trainable': True, 'depth_multiplier': 1.0, 'model_id': 50, 'replace_stem_max_pool': False, 'resnetd_shortcut': False, 'scale_stem': True, 'se_ratio': 0.0, 'stem_type': 'v0', 'stochastic_depth_drop_rate': 0.0}, 'type': 'resnet'}, 'decoder': { 'fpn': { 'fusion_type': 'sum', 'num_filters': 256, 'use_keras_layer': False, 'use_separable_conv': False}, 'type': 'fpn'}, 'detection_generator': { 'apply_nms': True, 'max_num_detections': 100, 'nms_iou_threshold': 0.5, 'nms_version': 'v2', 'pre_nms_score_threshold': 0.05, 'pre_nms_top_k': 5000, 'soft_nms_sigma': None, 'tflite_post_processing': { 'max_classes_per_detection': 4, 'max_detections': 200, 'nms_iou_threshold': 0.5, 'nms_score_threshold': 0.1, 'use_regular_nms': False}, 'use_cpu_nms': False}, 'head': { 'attribute_heads': [], 'num_convs': 4, 'num_filters': 256, 'share_classification_heads': False, 
'use_separable_conv': False}, 'input_size': [256, 256, 3], 'max_level': 7, 'min_level': 3, 'norm_activation': { 'activation': 'relu', 'norm_epsilon': 0.001, 'norm_momentum': 0.99, 'use_sync_bn': False}, 'num_classes': 4}, 'name': None, 'per_category_metrics': False, 'train_data': { 'apply_tf_data_service_before_batching': False, 'block_length': 1, 'cache': False, 'cycle_length': None, 'decoder': { 'simple_decoder': { 'mask_binarize_threshold': None, 'regenerate_source_id': False}, 'type': 'simple_decoder'}, 'deterministic': None, 'drop_remainder': True, 'dtype': 'float32', 'enable_shared_tf_data_service_between_parallel_trainers': False, 'enable_tf_data_service': False, 'file_type': 'tfrecord', 'global_batch_size': 8, 'input_path': '../bccd_coco_tfrecords/train-00000-of-00001.tfrecord', 'is_training': True, 'parser': { 'aug_policy': None, 'aug_rand_hflip': True, 'aug_scale_max': 1.0, 'aug_scale_min': 1.0, 'aug_type': None, 'match_threshold': 0.5, 'max_num_instances': 100, 'num_channels': 3, 'skip_crowd_during_training': True, 'unmatched_threshold': 0.5}, 'prefetch_buffer_size': None, 'seed': None, 'sharding': True, 'shuffle_buffer_size': 10000, 'tf_data_service_address': None, 'tf_data_service_job_name': None, 'tfds_as_supervised': False, 'tfds_data_dir': '', 'tfds_name': '', 'tfds_skip_decoding_feature': '', 'tfds_split': '', 'trainer_id': None}, 'use_coco_metrics': True, 'use_wod_metrics': False, 'validation_data': { 'apply_tf_data_service_before_batching': False, 'block_length': 1, 'cache': False, 'cycle_length': None, 'decoder': { 'simple_decoder': { 'mask_binarize_threshold': None, 'regenerate_source_id': False}, 'type': 'simple_decoder'}, 'deterministic': None, 'drop_remainder': True, 'dtype': 'float32', 'enable_shared_tf_data_service_between_parallel_trainers': False, 'enable_tf_data_service': False, 'file_type': 'tfrecord', 'global_batch_size': 8, 'input_path': '../bccd_coco_tfrecords/valid-00000-of-00001.tfrecord', 'is_training': False, 'parser': { 
'aug_policy': None, 'aug_rand_hflip': False, 'aug_scale_max': 1.0, 'aug_scale_min': 1.0, 'aug_type': None, 'match_threshold': 0.5, 'max_num_instances': 100, 'num_channels': 3, 'skip_crowd_during_training': True, 'unmatched_threshold': 0.5}, 'prefetch_buffer_size': None, 'seed': None, 'sharding': True, 'shuffle_buffer_size': 10000, 'tf_data_service_address': None, 'tf_data_service_job_name': None, 'tfds_as_supervised': False, 'tfds_data_dir': '', 'tfds_name': '', 'tfds_skip_decoding_feature': '', 'tfds_split': '', 'trainer_id': None} }, 'trainer': { 'allow_tpu_summary': False, 'best_checkpoint_eval_metric': '', 'best_checkpoint_export_subdir': '', 'best_checkpoint_metric_comp': 'higher', 'checkpoint_interval': 100, 'continuous_eval_timeout': 3600, 'eval_tf_function': True, 'eval_tf_while_loop': False, 'loss_upper_bound': 1000000.0, 'max_to_keep': 5, 'optimizer_config': { 'ema': None, 'learning_rate': { 'cosine': { 'alpha': 0.0, 'decay_steps': 1000, 'initial_learning_rate': 0.1, 'name': 'CosineDecay', 'offset': 0}, 'type': 'cosine'}, 'optimizer': { 'sgd': { 'clipnorm': None, 'clipvalue': None, 'decay': 0.0, 'global_clipnorm': None, 'momentum': 0.9, 'name': 'SGD', 'nesterov': False}, 'type': 'sgd'}, 'warmup': { 'linear': { 'name': 'linear', 'warmup_learning_rate': 0.05, 'warmup_steps': 100}, 'type': 'linear'} }, 'recovery_begin_steps': 0, 'recovery_max_trials': 0, 'steps_per_loop': 100, 'summary_interval': 100, 'train_steps': 1000, 'train_tf_function': True, 'train_tf_while_loop': True, 'validation_interval': 100, 'validation_steps': 100, 'validation_summary_subdir': 'validation'} } <IPython.core.display.Javascript object>
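The optimizer settings in the printed configuration (cosine decay with `initial_learning_rate` 0.1, `decay_steps` 1000, `alpha` 0.0, plus a linear warmup of 100 steps starting at 0.05) can be sketched in plain Python to see which learning rate to expect at a given step. The cosine part follows the standard cosine-decay formula. The warmup interpolation (a straight line from `warmup_learning_rate` to the schedule's value at the end of warmup) is an assumption about how the warmup is applied, not something stated in this tutorial.

```python
import math


def cosine_lr(step, initial_lr=0.1, decay_steps=1000, alpha=0.0):
    """Standard cosine decay from initial_lr down to alpha * initial_lr."""
    step = min(step, decay_steps)
    cosine = 0.5 * (1.0 + math.cos(math.pi * step / decay_steps))
    return initial_lr * ((1.0 - alpha) * cosine + alpha)


def lr_with_warmup(step, warmup_steps=100, warmup_lr=0.05):
    """Linear warmup (assumed form) followed by cosine decay."""
    if step < warmup_steps:
        target = cosine_lr(warmup_steps)
        return warmup_lr + (target - warmup_lr) * step / warmup_steps
    return cosine_lr(step)


for step in (0, 100, 200, 500, 1000):
    print(step, round(lr_with_warmup(step), 6))
```

Comparing these values with the `learning_rate` entries in the training log is a quick way to confirm that the schedule was configured as intended.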
Set up the distribution strategy.
if exp_config.runtime.mixed_precision_dtype == tf.float16:
  tf.keras.mixed_precision.set_global_policy('mixed_float16')
if 'GPU' in ''.join(logical_device_names):
  distribution_strategy = tf.distribute.MirroredStrategy()
elif 'TPU' in ''.join(logical_device_names):
  tf.tpu.experimental.initialize_tpu_system()
  tpu = tf.distribute.cluster_resolver.TPUClusterResolver(tpu='/device:TPU_SYSTEM:0')
  distribution_strategy = tf.distribute.experimental.TPUStrategy(tpu)
else:
  print('Warning: this will be really slow.')
  distribution_strategy = tf.distribute.OneDeviceStrategy(logical_device_names[0])
print('Done')
INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0', '/job:localhost/replica:0/task:0/device:GPU:1', '/job:localhost/replica:0/task:0/device:GPU:2', '/job:localhost/replica:0/task:0/device:GPU:3') INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0', '/job:localhost/replica:0/task:0/device:GPU:1', '/job:localhost/replica:0/task:0/device:GPU:2', '/job:localhost/replica:0/task:0/device:GPU:3') Done
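With a mirrored strategy, the configured `global_batch_size` is split across the replicas. As a rough sketch, assuming an even split and that the global batch size should be divisible by the number of replicas, the per-replica batch size for this run (global batch size 8 on the 4 GPUs listed above) works out as follows. The helper name is hypothetical.

```python
def per_replica_batch_size(global_batch_size, num_replicas):
    """Split a global batch evenly across replicas."""
    if global_batch_size % num_replicas != 0:
        raise ValueError(
            f'global_batch_size {global_batch_size} is not divisible by '
            f'{num_replicas} replicas')
    return global_batch_size // num_replicas


replica_batch = per_replica_batch_size(global_batch_size=8, num_replicas=4)
print(replica_batch)
```

On a single GPU or on CPU the whole batch of 8 stays on one device, which matters when estimating memory use.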
Create the Task object (tfm.core.base_task.Task) from the config_definitions.TaskConfig.
The Task object has all the methods necessary for building the dataset, building the model, and running training & evaluation. These methods are driven by tfm.core.train_lib.run_experiment.
with distribution_strategy.scope():
  task = tfm.core.task_factory.get_task(exp_config.task, logging_dir=model_dir)
Visualize a batch of the data.
for images, labels in task.build_inputs(exp_config.task.train_data).take(1):
  print()
  print(f'images.shape: {str(images.shape):16} images.dtype: {images.dtype!r}')
  print(f'labels.keys: {labels.keys()}')
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/autograph/pyct/static_analysis/liveness.py:83: Analyzer.lamba_check (from tensorflow.python.autograph.pyct.static_analysis.liveness) is deprecated and will be removed after 2023-09-23. Instructions for updating: Lambda fuctions will be no more assumed to be used in the statement where they are used, or at least in the same block. https://github.com/tensorflow/tensorflow/issues/56089 WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/autograph/pyct/static_analysis/liveness.py:83: Analyzer.lamba_check (from tensorflow.python.autograph.pyct.static_analysis.liveness) is deprecated and will be removed after 2023-09-23. Instructions for updating: Lambda fuctions will be no more assumed to be used in the statement where they are used, or at least in the same block. https://github.com/tensorflow/tensorflow/issues/56089 images.shape: (8, 256, 256, 3) images.dtype: tf.float32 labels.keys: dict_keys(['cls_targets', 'box_targets', 'anchor_boxes', 'cls_weights', 'box_weights', 'image_info'])
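The label keys above include anchor-based targets (`cls_targets`, `box_targets`, `anchor_boxes`). To get a feel for how many anchors the model works with, here is a back-of-the-envelope sketch using the values from the printed configuration (`min_level` 3, `max_level` 7, `num_scales` 3, three aspect ratios, 256x256 input). It assumes that pyramid level l has stride 2**l and that every feature-map location carries `num_scales * len(aspect_ratios)` anchors; this is the usual RetinaNet layout rather than something verified against the library code.

```python
import math


def count_anchors(input_size, min_level, max_level, num_scales, num_aspect_ratios):
    """Total number of anchors over all pyramid levels (assumed layout)."""
    height, width = input_size
    locations = 0
    for level in range(min_level, max_level + 1):
        stride = 2 ** level  # assumed stride of pyramid level `level`
        locations += math.ceil(height / stride) * math.ceil(width / stride)
    return locations * num_scales * num_aspect_ratios


total_anchors = count_anchors(
    input_size=(256, 256),
    min_level=3,
    max_level=7,
    num_scales=3,
    num_aspect_ratios=3)
print(total_anchors)
```

Under these assumptions each 256x256 image is covered by roughly twelve thousand anchors, almost all of which are background. That imbalance is what the focal loss settings in the configuration are meant to handle.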
Create a category index dictionary to map the labels to the corresponding label names.
category_index={
1: {
'id': 1,
'name': 'Platelets'
},
2: {
'id': 2,
'name': 'RBC'
},
3: {
'id': 3,
'name': 'WBC'
}
}
tf_ex_decoder = TfExampleDecoder()
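For datasets with more classes, writing the category index by hand becomes tedious. A small sketch of generating the same structure from a list of class names is shown below. It assumes class ids start at 1, as in the dictionary above; the extra unit added to `num_classes` in the model configuration is presumably reserved for a background class, which is an assumption here. The helper `build_category_index` is hypothetical.

```python
def build_category_index(class_names, first_id=1):
    """Map consecutive class ids to {'id': ..., 'name': ...} entries."""
    return {
        class_id: {'id': class_id, 'name': name}
        for class_id, name in enumerate(class_names, start=first_id)
    }


generated_index = build_category_index(['Platelets', 'RBC', 'WBC'])
# The model config uses num_classes + 1 (assumed: one extra slot for background).
model_num_classes = len(generated_index) + 1
print(generated_index)
print(model_num_classes)
```

The order of names must match the category ids in the annotation file, so it is worth checking against the `categories` section of the COCO JSON.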
Helper function for visualizing the results from TFRecords.
Use visualize_boxes_and_labels_on_image_array from visualization_utils to draw bounding boxes on the image.
def show_batch(raw_records, num_of_examples):
  plt.figure(figsize=(20, 20))
  use_normalized_coordinates = True
  min_score_thresh = 0.30
  for i, serialized_example in enumerate(raw_records):
    # One subplot per example, laid out in a single row.
    plt.subplot(1, num_of_examples, i + 1)
    decoded_tensors = tf_ex_decoder.decode(serialized_example)
    image = decoded_tensors['image'].numpy().astype('uint8')
    # Ground-truth boxes have no score, so draw them all with score 1.
    scores = np.ones(shape=(len(decoded_tensors['groundtruth_boxes'])))
    visualization_utils.visualize_boxes_and_labels_on_image_array(
        image,
        decoded_tensors['groundtruth_boxes'].numpy(),
        decoded_tensors['groundtruth_classes'].numpy().astype('int'),
        scores,
        category_index=category_index,
        use_normalized_coordinates=use_normalized_coordinates,
        max_boxes_to_draw=200,
        min_score_thresh=min_score_thresh,
        agnostic_mode=False,
        instance_masks=None,
        line_thickness=4)
    plt.imshow(image)
    plt.axis('off')
    plt.title(f'Image-{i+1}')
  plt.show()
Visualization of train data
The bounding box detection has two components:
- Class label of the object detected (e.g., RBC)
- Percentage of match between predicted and ground truth bounding boxes.
buffer_size = 20
num_of_examples = 3
raw_records = tf.data.TFRecordDataset(
exp_config.task.train_data.input_path).shuffle(
buffer_size=buffer_size).take(num_of_examples)
show_batch(raw_records, num_of_examples)
Train and evaluate.
We follow the COCO challenge tradition and evaluate the accuracy of object detection based on mAP (mean Average Precision). Please check here for a detailed explanation of how the evaluation metrics for the detection task are computed.
IoU is defined as the area of the intersection divided by the area of the union of a predicted bounding box and a ground truth bounding box.
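The IoU definition can be made concrete with a short, framework-free sketch. Boxes are assumed to be given as `[ymin, xmin, ymax, xmax]`; the function name and the example boxes are illustrative only. The list of thresholds at the end mirrors the `IoU=0.50:0.95` notation that COCO-style evaluation reports, that is, thresholds from 0.50 to 0.95 in steps of 0.05.

```python
def iou(box_a, box_b):
    """Intersection over union of two [ymin, xmin, ymax, xmax] boxes."""
    ya1, xa1, ya2, xa2 = box_a
    yb1, xb1, yb2, xb2 = box_b
    # Height and width of the overlap, clamped at zero when boxes are disjoint.
    inter_h = max(0.0, min(ya2, yb2) - max(ya1, yb1))
    inter_w = max(0.0, min(xa2, xb2) - max(xa1, xb1))
    intersection = inter_h * inter_w
    area_a = (ya2 - ya1) * (xa2 - xa1)
    area_b = (yb2 - yb1) * (xb2 - xb1)
    union = area_a + area_b - intersection
    return intersection / union if union > 0 else 0.0


predicted = [0.0, 0.0, 2.0, 2.0]
ground_truth = [1.0, 1.0, 3.0, 3.0]
overlap = iou(predicted, ground_truth)
print(overlap)

# IoU thresholds behind the "IoU=0.50:0.95" averages.
coco_iou_thresholds = [round(0.5 + 0.05 * i, 2) for i in range(10)]
print(coco_iou_thresholds)
```

A prediction counts as a match at a given threshold only if its IoU with a ground truth box of the same class reaches that threshold, which is why AP50 is always at least as high as AP75.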
model, eval_logs = tfm.core.train_lib.run_experiment(
distribution_strategy=distribution_strategy,
task=task,
mode='train_and_eval',
params=exp_config,
model_dir=model_dir,
run_post_eval=True)
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',). INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',). INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',). INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',). INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',). INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',). restoring or initializing model... initialized model. train | step: 0 | training until step 100... INFO:tensorflow:batch_all_reduce: 285 all-reduces with algorithm = nccl, num_packs = 1 INFO:tensorflow:batch_all_reduce: 285 all-reduces with algorithm = nccl, num_packs = 1 INFO:tensorflow:batch_all_reduce: 285 all-reduces with algorithm = nccl, num_packs = 1 INFO:tensorflow:batch_all_reduce: 285 all-reduces with algorithm = nccl, num_packs = 1 train | step: 100 | steps/sec: 1.3 | output: {'box_loss': 0.012490222, 'cls_loss': 0.6384361, 'learning_rate': 0.09755283, 'model_loss': 1.2629474, 'total_loss': 1.8098843, 'training_loss': 1.8098843} saved checkpoint to ../trained_model/ckpt-100. eval | step: 100 | running 100 steps of evaluation... creating index... index created! creating index... index created! Running per image evaluation... Evaluate annotation type *bbox* DONE (t=0.36s). Accumulating evaluation results... DONE (t=0.05s). 
Average Precision (AP) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.000 Average Precision (AP) @[ IoU=0.50 | area= all | maxDets=100 ] = 0.000 Average Precision (AP) @[ IoU=0.75 | area= all | maxDets=100 ] = 0.000 Average Precision (AP) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.000 Average Precision (AP) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.000 Average Precision (AP) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.002 Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets= 1 ] = 0.001 Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets= 10 ] = 0.006 Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.016 Average Recall (AR) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.000 Average Recall (AR) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.070 Average Recall (AR) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.008 eval | step: 100 | eval time: 22.0 sec | output: {'AP': 4.9814236e-05, 'AP50': 0.00017214617, 'AP75': 0.0, 'APl': 0.0023122071, 'APm': 9.130664e-05, 'APs': 0.0, 'ARl': 0.0076502734, 'ARm': 0.07, 'ARmax1': 0.00093896716, 'ARmax10': 0.0056338026, 'ARmax100': 0.016431924, 'ARs': 0.0, 'box_loss': 0.018262193, 'cls_loss': 460.0177, 'model_loss': 460.93082, 'total_loss': 461.50656, 'validation_loss': 461.50656} train | step: 100 | training until step 200... train | step: 200 | steps/sec: 2.8 | output: {'box_loss': 0.005326658, 'cls_loss': 0.44851834, 'learning_rate': 0.090450846, 'model_loss': 0.7148512, 'total_loss': 1.2860019, 'training_loss': 1.2860019} saved checkpoint to ../trained_model/ckpt-200. eval | step: 200 | running 100 steps of evaluation... creating index... index created! creating index... index created! Running per image evaluation... Evaluate annotation type *bbox* DONE (t=1.05s). Accumulating evaluation results... DONE (t=0.06s). 
Average Precision (AP) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.309 Average Precision (AP) @[ IoU=0.50 | area= all | maxDets=100 ] = 0.576 Average Precision (AP) @[ IoU=0.75 | area= all | maxDets=100 ] = 0.281 Average Precision (AP) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.000 Average Precision (AP) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.288 Average Precision (AP) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.327 Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets= 1 ] = 0.203 Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets= 10 ] = 0.350 Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.397 Average Recall (AR) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.000 Average Recall (AR) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.361 Average Recall (AR) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.415 eval | step: 200 | eval time: 3.5 sec | output: {'AP': 0.30938724, 'AP50': 0.57599366, 'AP75': 0.28053367, 'APl': 0.32672268, 'APm': 0.2883867, 'APs': 0.0, 'ARl': 0.4147541, 'ARm': 0.36086452, 'ARmax1': 0.20319515, 'ARmax10': 0.34980303, 'ARmax100': 0.39654395, 'ARs': 0.0, 'box_loss': 0.0056167482, 'cls_loss': 0.37327495, 'model_loss': 0.65411234, 'total_loss': 1.2206033, 'validation_loss': 1.2206033} train | step: 200 | training until step 300... train | step: 300 | steps/sec: 5.9 | output: {'box_loss': 0.004289604, 'cls_loss': 0.32028148, 'learning_rate': 0.07938927, 'model_loss': 0.5347617, 'total_loss': 1.096797, 'training_loss': 1.096797} saved checkpoint to ../trained_model/ckpt-300. eval | step: 300 | running 100 steps of evaluation... creating index... index created! creating index... index created! Running per image evaluation... Evaluate annotation type *bbox* DONE (t=1.05s). Accumulating evaluation results... DONE (t=0.05s). 
Average Precision (AP) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.343 Average Precision (AP) @[ IoU=0.50 | area= all | maxDets=100 ] = 0.601 Average Precision (AP) @[ IoU=0.75 | area= all | maxDets=100 ] = 0.363 Average Precision (AP) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.000 Average Precision (AP) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.285 Average Precision (AP) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.410 Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets= 1 ] = 0.230 Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets= 10 ] = 0.359 Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.409 Average Recall (AR) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.000 Average Recall (AR) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.333 Average Recall (AR) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.445 eval | step: 300 | eval time: 3.6 sec | output: {'AP': 0.34271824, 'AP50': 0.6007266, 'AP75': 0.3625663, 'APl': 0.41033825, 'APm': 0.28515357, 'APs': 0.0, 'ARl': 0.44508198, 'ARm': 0.333059, 'ARmax1': 0.2298889, 'ARmax10': 0.35857797, 'ARmax100': 0.40919676, 'ARs': 0.0, 'box_loss': 0.004234419, 'cls_loss': 0.3026318, 'model_loss': 0.5143528, 'total_loss': 1.0719291, 'validation_loss': 1.0719291} train | step: 300 | training until step 400... train | step: 400 | steps/sec: 5.9 | output: {'box_loss': 0.0036904868, 'cls_loss': 0.3034157, 'learning_rate': 0.06545085, 'model_loss': 0.48793995, 'total_loss': 1.041542, 'training_loss': 1.041542} saved checkpoint to ../trained_model/ckpt-400. eval | step: 400 | running 100 steps of evaluation... creating index... index created! creating index... index created! Running per image evaluation... Evaluate annotation type *bbox* DONE (t=1.02s). Accumulating evaluation results... DONE (t=0.05s). 
Average Precision (AP) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.371 Average Precision (AP) @[ IoU=0.50 | area= all | maxDets=100 ] = 0.603 Average Precision (AP) @[ IoU=0.75 | area= all | maxDets=100 ] = 0.440 Average Precision (AP) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.000 Average Precision (AP) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.352 Average Precision (AP) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.399 Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets= 1 ] = 0.249 Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets= 10 ] = 0.384 Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.445 Average Recall (AR) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.000 Average Recall (AR) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.416 Average Recall (AR) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.421 eval | step: 400 | eval time: 3.5 sec | output: {'AP': 0.37149358, 'AP50': 0.6033562, 'AP75': 0.44032815, 'APl': 0.39924088, 'APm': 0.3522066, 'APs': 0.0, 'ARl': 0.42103824, 'ARm': 0.41559434, 'ARmax1': 0.24924, 'ARmax10': 0.38401896, 'ARmax100': 0.44491005, 'ARs': 0.0, 'box_loss': 0.0033389109, 'cls_loss': 0.2925316, 'model_loss': 0.4594771, 'total_loss': 1.0092748, 'validation_loss': 1.0092748} train | step: 400 | training until step 500... train | step: 500 | steps/sec: 5.7 | output: {'box_loss': 0.0032879927, 'cls_loss': 0.27859658, 'learning_rate': 0.049999997, 'model_loss': 0.4429962, 'total_loss': 0.9895906, 'training_loss': 0.9895906} saved checkpoint to ../trained_model/ckpt-500. eval | step: 500 | running 100 steps of evaluation... creating index... index created! creating index... index created! Running per image evaluation... Evaluate annotation type *bbox* DONE (t=1.02s). Accumulating evaluation results... DONE (t=0.05s). 
Average Precision (AP) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.382 Average Precision (AP) @[ IoU=0.50 | area= all | maxDets=100 ] = 0.607 Average Precision (AP) @[ IoU=0.75 | area= all | maxDets=100 ] = 0.441 Average Precision (AP) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.000 Average Precision (AP) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.361 Average Precision (AP) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.415 Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets= 1 ] = 0.259 Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets= 10 ] = 0.399 Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.459 Average Recall (AR) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.000 Average Recall (AR) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.413 Average Recall (AR) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.453 eval | step: 500 | eval time: 3.6 sec | output: {'AP': 0.38194072, 'AP50': 0.60690856, 'AP75': 0.44128215, 'APl': 0.41530672, 'APm': 0.36100027, 'APs': 0.0, 'ARl': 0.45255008, 'ARm': 0.41261846, 'ARmax1': 0.25854716, 'ARmax10': 0.39946136, 'ARmax100': 0.45861977, 'ARs': 0.0, 'box_loss': 0.0030391, 'cls_loss': 0.27277103, 'model_loss': 0.42472604, 'total_loss': 0.9683125, 'validation_loss': 0.9683125} train | step: 500 | training until step 600... train | step: 600 | steps/sec: 5.9 | output: {'box_loss': 0.003014897, 'cls_loss': 0.26925796, 'learning_rate': 0.034549143, 'model_loss': 0.4200028, 'total_loss': 0.9611966, 'training_loss': 0.9611966} saved checkpoint to ../trained_model/ckpt-600. eval | step: 600 | running 100 steps of evaluation... creating index... index created! creating index... index created! Running per image evaluation... Evaluate annotation type *bbox* DONE (t=0.98s). Accumulating evaluation results... DONE (t=0.06s). 
Average Precision (AP) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.402 Average Precision (AP) @[ IoU=0.50 | area= all | maxDets=100 ] = 0.610 Average Precision (AP) @[ IoU=0.75 | area= all | maxDets=100 ] = 0.492 Average Precision (AP) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.000 Average Precision (AP) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.425 Average Precision (AP) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.439 Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets= 1 ] = 0.268 Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets= 10 ] = 0.419 Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.478 Average Recall (AR) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.000 Average Recall (AR) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.508 Average Recall (AR) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.459 eval | step: 600 | eval time: 3.5 sec | output: {'AP': 0.4017909, 'AP50': 0.60976845, 'AP75': 0.4924452, 'APl': 0.4394729, 'APm': 0.4249094, 'APs': 0.0, 'ARl': 0.4592896, 'ARm': 0.5084442, 'ARmax1': 0.26798356, 'ARmax10': 0.4187947, 'ARmax100': 0.47820064, 'ARs': 0.0, 'box_loss': 0.0026975253, 'cls_loss': 0.25852692, 'model_loss': 0.39340317, 'total_loss': 0.93243176, 'validation_loss': 0.93243176} train | step: 600 | training until step 700... train | step: 700 | steps/sec: 5.9 | output: {'box_loss': 0.002797648, 'cls_loss': 0.29278567, 'learning_rate': 0.020610739, 'model_loss': 0.43266803, 'total_loss': 0.97009385, 'training_loss': 0.97009385} saved checkpoint to ../trained_model/ckpt-700. eval | step: 700 | running 100 steps of evaluation... creating index... index created! creating index... index created! Running per image evaluation... Evaluate annotation type *bbox* DONE (t=0.94s). Accumulating evaluation results... DONE (t=0.07s). 
Average Precision (AP) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.299 Average Precision (AP) @[ IoU=0.50 | area= all | maxDets=100 ] = 0.463 Average Precision (AP) @[ IoU=0.75 | area= all | maxDets=100 ] = 0.349 Average Precision (AP) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.000 Average Precision (AP) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.390 Average Precision (AP) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.369 Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets= 1 ] = 0.260 Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets= 10 ] = 0.413 Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.482 Average Recall (AR) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.006 Average Recall (AR) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.536 Average Recall (AR) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.475 eval | step: 700 | eval time: 3.5 sec | output: {'AP': 0.29852465, 'AP50': 0.46254238, 'AP75': 0.34899247, 'APl': 0.36936423, 'APm': 0.39036143, 'APs': 4.7877307e-05, 'ARl': 0.47540984, 'ARm': 0.5358583, 'ARmax1': 0.2603382, 'ARmax10': 0.4129301, 'ARmax100': 0.48184597, 'ARs': 0.005882353, 'box_loss': 0.002813503, 'cls_loss': 0.2719897, 'model_loss': 0.41266486, 'total_loss': 0.9487193, 'validation_loss': 0.9487193} train | step: 700 | training until step 800... train | step: 800 | steps/sec: 5.9 | output: {'box_loss': 0.0027542955, 'cls_loss': 0.26488435, 'learning_rate': 0.009549147, 'model_loss': 0.40259916, 'total_loss': 0.9377283, 'training_loss': 0.9377283} saved checkpoint to ../trained_model/ckpt-800. eval | step: 800 | running 100 steps of evaluation... creating index... index created! creating index... index created! Running per image evaluation... Evaluate annotation type *bbox* DONE (t=0.91s). Accumulating evaluation results... DONE (t=0.06s). 
Average Precision (AP) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.384 Average Precision (AP) @[ IoU=0.50 | area= all | maxDets=100 ] = 0.592 Average Precision (AP) @[ IoU=0.75 | area= all | maxDets=100 ] = 0.474 Average Precision (AP) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.002 Average Precision (AP) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.398 Average Precision (AP) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.415 Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets= 1 ] = 0.263 Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets= 10 ] = 0.430 Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.491 Average Recall (AR) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.027 Average Recall (AR) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.514 Average Recall (AR) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.454 eval | step: 800 | eval time: 3.4 sec | output: {'AP': 0.38367727, 'AP50': 0.5922334, 'AP75': 0.47372296, 'APl': 0.4150953, 'APm': 0.3983694, 'APs': 0.0024488068, 'ARl': 0.45373407, 'ARm': 0.5139734, 'ARmax1': 0.2633304, 'ARmax10': 0.429968, 'ARmax100': 0.49090034, 'ARs': 0.02745098, 'box_loss': 0.0025948472, 'cls_loss': 0.25274864, 'model_loss': 0.38249102, 'total_loss': 0.91687936, 'validation_loss': 0.91687936} train | step: 800 | training until step 900... train | step: 900 | steps/sec: 6.0 | output: {'box_loss': 0.0025644554, 'cls_loss': 0.25460792, 'learning_rate': 0.002447176, 'model_loss': 0.38283065, 'total_loss': 0.9168205, 'training_loss': 0.9168205} saved checkpoint to ../trained_model/ckpt-900. eval | step: 900 | running 100 steps of evaluation... creating index... index created! creating index... index created! Running per image evaluation... Evaluate annotation type *bbox* DONE (t=0.89s). Accumulating evaluation results... DONE (t=0.06s). 
Average Precision (AP) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.357 Average Precision (AP) @[ IoU=0.50 | area= all | maxDets=100 ] = 0.537 Average Precision (AP) @[ IoU=0.75 | area= all | maxDets=100 ] = 0.421 Average Precision (AP) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.001 Average Precision (AP) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.403 Average Precision (AP) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.403 Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets= 1 ] = 0.270 Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets= 10 ] = 0.440 Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.503 Average Recall (AR) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.022 Average Recall (AR) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.540 Average Recall (AR) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.469 eval | step: 900 | eval time: 3.5 sec | output: {'AP': 0.35744554, 'AP50': 0.537141, 'AP75': 0.42080757, 'APl': 0.40314218, 'APm': 0.40293044, 'APs': 0.0007871322, 'ARl': 0.46857923, 'ARm': 0.53984624, 'ARmax1': 0.27011546, 'ARmax10': 0.43987975, 'ARmax100': 0.50345236, 'ARs': 0.021568628, 'box_loss': 0.0025267936, 'cls_loss': 0.25612128, 'model_loss': 0.38246095, 'total_loss': 0.91617805, 'validation_loss': 0.91617805} train | step: 900 | training until step 1000... train | step: 1000 | steps/sec: 5.7 | output: {'box_loss': 0.002532323, 'cls_loss': 0.25445548, 'learning_rate': 0.0, 'model_loss': 0.3810716, 'total_loss': 0.9147053, 'training_loss': 0.9147053} saved checkpoint to ../trained_model/ckpt-1000. eval | step: 1000 | running 100 steps of evaluation... creating index... index created! creating index... index created! Running per image evaluation... Evaluate annotation type *bbox* DONE (t=0.90s). Accumulating evaluation results... DONE (t=0.06s). 
Average Precision (AP) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.335 Average Precision (AP) @[ IoU=0.50 | area= all | maxDets=100 ] = 0.504 Average Precision (AP) @[ IoU=0.75 | area= all | maxDets=100 ] = 0.384 Average Precision (AP) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.000 Average Precision (AP) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.396 Average Precision (AP) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.364 Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets= 1 ] = 0.270 Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets= 10 ] = 0.429 Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.502 Average Recall (AR) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.018 Average Recall (AR) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.541 Average Recall (AR) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.468 eval | step: 1000 | eval time: 3.5 sec | output: {'AP': 0.33458188, 'AP50': 0.5040318, 'AP75': 0.38418216, 'APl': 0.36382246, 'APm': 0.39634687, 'APs': 0.00046737958, 'ARl': 0.46803278, 'ARm': 0.5410252, 'ARmax1': 0.2700074, 'ARmax10': 0.42944863, 'ARmax100': 0.5023168, 'ARs': 0.01764706, 'box_loss': 0.0025171877, 'cls_loss': 0.26180327, 'model_loss': 0.38766268, 'total_loss': 0.9212637, 'validation_loss': 0.9212637} eval | step: 1000 | running 100 steps of evaluation... creating index... index created! creating index... index created! Running per image evaluation... Evaluate annotation type *bbox* DONE (t=0.90s). Accumulating evaluation results... DONE (t=0.06s). 
Average Precision (AP) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.335 Average Precision (AP) @[ IoU=0.50 | area= all | maxDets=100 ] = 0.504 Average Precision (AP) @[ IoU=0.75 | area= all | maxDets=100 ] = 0.384 Average Precision (AP) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.000 Average Precision (AP) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.396 Average Precision (AP) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.364 Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets= 1 ] = 0.270 Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets= 10 ] = 0.429 Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.502 Average Recall (AR) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.018 Average Recall (AR) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.541 Average Recall (AR) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.468 eval | step: 1000 | eval time: 3.4 sec | output: {'AP': 0.33458188, 'AP50': 0.5040318, 'AP75': 0.38418216, 'APl': 0.36382246, 'APm': 0.39634687, 'APs': 0.00046737958, 'ARl': 0.46803278, 'ARm': 0.5410252, 'ARmax1': 0.2700074, 'ARmax10': 0.42944863, 'ARmax100': 0.5023168, 'ARs': 0.01764706, 'box_loss': 0.0025171877, 'cls_loss': 0.26180327, 'model_loss': 0.38766268, 'total_loss': 0.9212637, 'validation_loss': 0.9212637}
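The learning_rate values logged at each checkpoint above are consistent with a cosine decay from 0.1 down to 0 over 1000 steps. The sketch below checks that reading against the logged numbers; the constants INITIAL_LR and DECAY_STEPS are inferred from the log rather than taken from this section, and cosine_decay is an illustrative helper, not a Model Garden function.

```python
import math

# Assumed schedule parameters, inferred from the logged learning rates.
INITIAL_LR = 0.1
DECAY_STEPS = 1000


def cosine_decay(step, initial_lr=INITIAL_LR, decay_steps=DECAY_STEPS):
  """Cosine decay from initial_lr at step 0 down to 0 at decay_steps."""
  progress = min(step, decay_steps) / decay_steps
  return 0.5 * initial_lr * (1.0 + math.cos(math.pi * progress))


# learning_rate values copied from the training log above.
logged = {
    200: 0.090450846,
    300: 0.07938927,
    400: 0.06545085,
    500: 0.049999997,
    600: 0.034549143,
    700: 0.020610739,
    800: 0.009549147,
    900: 0.002447176,
    1000: 0.0,
}

for step, lr in logged.items():
  print(step, round(cosine_decay(step), 6), lr)
```

Because the rate reaches zero at step 1000, the last few hundred steps change the weights very little, which matches the nearly flat training loss between steps 800 and 1000.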
Load the training logs in TensorBoard.
%load_ext tensorboard
%tensorboard --logdir '../trained_model/'
Saving and exporting the trained model.
The keras.Model object returned by train_lib.run_experiment expects the data to be normalized by the dataset loader using the same mean and variance statistics as preprocess_ops.normalize_image(image, offset=MEAN_RGB, scale=STDDEV_RGB). The export_inference_graph call used here handles those details, so you can pass tf.uint8 images and get correct results.
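As a rough illustration of what this normalization amounts to, the sketch below applies a per-channel (pixel - mean) / stddev transform in NumPy. The mean and standard deviation values are the common ImageNet statistics scaled to the 0-255 range and are assumed here for illustration only; normalize_uint8_image is a hypothetical helper, not part of the Model Garden.

```python
import numpy as np

# Assumed values: ImageNet channel statistics on the 0-255 scale.
MEAN_RGB = np.array([0.485, 0.456, 0.406], dtype=np.float32) * 255.0
STDDEV_RGB = np.array([0.229, 0.224, 0.225], dtype=np.float32) * 255.0


def normalize_uint8_image(image):
  """Hypothetical helper: per-channel (pixel - mean) / stddev for RGB uint8."""
  image = image.astype(np.float32)
  # Broadcasting applies the three channel statistics along the last axis.
  return (image - MEAN_RGB) / STDDEV_RGB


pixels = np.full((2, 2, 3), 128, dtype=np.uint8)
normalized = normalize_uint8_image(pixels)
print(normalized.shape, normalized.dtype)
```

An exported model that performs this step internally can be fed raw tf.uint8 pixels, whereas calling the Keras model directly would require the caller to normalize first.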
export_saved_model_lib.export_inference_graph(
    input_type='image_tensor',
    batch_size=1,
    input_image_size=[HEIGHT, WIDTH],
    params=exp_config,
    checkpoint_path=tf.train.latest_checkpoint(model_dir),
    export_dir=export_dir)
WARNING:tensorflow:Skipping full serialization of Keras layer <official.vision.modeling.retinanet_model.RetinaNetModel object at 0x7f38b023a4f0>, because it is not built. WARNING:tensorflow:Skipping full serialization of Keras layer <official.vision.modeling.retinanet_model.RetinaNetModel object at 0x7f38b023a4f0>, because it is not built. WARNING:tensorflow:Skipping full serialization of Keras layer <official.vision.modeling.layers.detection_generator.MultilevelDetectionGenerator object at 0x7f4c2d1d26d0>, because it is not built. WARNING:tensorflow:Skipping full serialization of Keras layer <official.vision.modeling.layers.detection_generator.MultilevelDetectionGenerator object at 0x7f4c2d1d26d0>, because it is not built. WARNING:absl:Found untraced functions such as inference_for_tflite, inference_from_image_bytes, inference_from_tf_example, retina_net_head_1_layer_call_fn, retina_net_head_1_layer_call_and_return_conditional_losses while saving (showing 5 of 328). These functions will not be directly callable after loading. INFO:tensorflow:Assets written to: ../exported_model/assets INFO:tensorflow:Assets written to: ../exported_model/assets
Inference from trained model
def load_image_into_numpy_array(path):
  """Load an image from a file or URL into a numpy array.

  Puts the image into a numpy array to feed into the TensorFlow graph.
  Note that the array has a leading batch dimension followed by
  (height, width, channels), where channels=3 for RGB.

  Args:
    path: the file path or http(s) URL of the image

  Returns:
    uint8 numpy array with shape (1, img_height, img_width, 3)
  """
  image = None
  if path.startswith('http'):
    response = urlopen(path)
    image_data = response.read()
    image_data = BytesIO(image_data)
    image = Image.open(image_data)
  else:
    image_data = tf.io.gfile.GFile(path, 'rb').read()
    image = Image.open(BytesIO(image_data))

  (im_width, im_height) = image.size
  return np.array(image.getdata()).reshape(
      (1, im_height, im_width, 3)).astype(np.uint8)
def build_inputs_for_object_detection(image, input_image_size):
  """Builds Object Detection model inputs for serving."""
  image, _ = resize_and_crop_image(
      image,
      input_image_size,
      padded_size=input_image_size,
      aug_scale_min=1.0,
      aug_scale_max=1.0)
  return image
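With aug_scale_min and aug_scale_max both set to 1.0, no random scale jitter is applied, so the image is resized to fit the target size while keeping its aspect ratio and then padded up to padded_size. The sketch below reproduces only that size arithmetic, under the assumption that the scale factor is the smaller of the height and width ratios and that padding is added at the bottom and right. The helper fit_and_pad_size and the example sizes are hypothetical.

```python
def fit_and_pad_size(image_size, target_size):
  """Returns (scaled_size, padding) for an aspect-preserving fit.

  Both sizes are (height, width). Padding is the number of rows and columns
  that would be added to reach target_size after resizing.
  """
  height, width = image_size
  target_height, target_width = target_size
  # The tighter dimension decides the scale, so the image fits entirely.
  scale = min(target_height / height, target_width / width)
  scaled = (round(height * scale), round(width * scale))
  padding = (target_height - scaled[0], target_width - scaled[1])
  return scaled, padding


scaled, padding = fit_and_pad_size((480, 640), (256, 256))
print(scaled, padding)  # (192, 256) (64, 0)
```

Under these assumptions, a landscape image fills the full target width and receives padding rows, so predicted boxes are expressed in the coordinates of the padded image rather than the original one.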
Visualize test data.
num_of_examples = 3
test_ds = tf.data.TFRecordDataset(
    '../bccd_coco_tfrecords/test-00000-of-00001.tfrecord').take(
        num_of_examples)
show_batch(test_ds, num_of_examples)
Importing SavedModel.
imported = tf.saved_model.load(export_dir)
model_fn = imported.signatures['serving_default']
Visualize predictions.
input_image_size = (HEIGHT, WIDTH)
plt.figure(figsize=(20, 20))
min_score_thresh = 0.30  # Lower this threshold to also draw lower-confidence boxes.

for i, serialized_example in enumerate(test_ds):
  plt.subplot(1, 3, i+1)
  decoded_tensors = tf_ex_decoder.decode(serialized_example)
  image = build_inputs_for_object_detection(decoded_tensors['image'], input_image_size)
  image = tf.expand_dims(image, axis=0)
  image = tf.cast(image, dtype=tf.uint8)
  image_np = image[0].numpy()
  result = model_fn(image)
  visualization_utils.visualize_boxes_and_labels_on_image_array(
      image_np,
      result['detection_boxes'][0].numpy(),
      result['detection_classes'][0].numpy().astype(int),
      result['detection_scores'][0].numpy(),
      category_index=category_index,
      use_normalized_coordinates=False,
      max_boxes_to_draw=200,
      min_score_thresh=min_score_thresh,
      agnostic_mode=False,
      instance_masks=None,
      line_thickness=4)
  plt.imshow(image_np)
  plt.axis('off')

plt.show()
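The min_score_thresh argument controls which detections get drawn. The same filtering can be applied to the raw outputs when the detections are needed as data rather than as an image. The sketch below uses small made-up arrays in place of result['detection_boxes'][0], result['detection_classes'][0] and result['detection_scores'][0]. The helper filter_detections is hypothetical, and both the [ymin, xmin, ymax, xmax] pixel box layout and the strict "greater than" comparison are assumptions.

```python
import numpy as np


def filter_detections(boxes, classes, scores, min_score_thresh=0.30):
  """Hypothetical helper: keeps detections scoring above the threshold."""
  keep = scores > min_score_thresh
  return boxes[keep], classes[keep], scores[keep]


# Made-up detections standing in for one image's model outputs.
boxes = np.array([[10., 20., 110., 140.],
                  [30., 40., 90., 100.],
                  [0., 0., 50., 50.]])
classes = np.array([1, 3, 2])
scores = np.array([0.92, 0.12, 0.45])

kept_boxes, kept_classes, kept_scores = filter_detections(
    boxes, classes, scores)
print(kept_classes, kept_scores)
```

Raising the threshold trades recall for precision: fewer spurious boxes are kept, but faint or partially occluded cells are more likely to be dropped.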