![]() |
![]() |
![]() |
![]() |
This tutorial fine-tunes a Mask R-CNN model with a MobileNetV2 backbone from the TensorFlow Model Garden package (tensorflow-models).
Model Garden contains a collection of state-of-the-art models, implemented with TensorFlow's high-level APIs. The implementations demonstrate the best practices for modeling, letting users take full advantage of TensorFlow for their research and product development.
This tutorial demonstrates how to:
- Use models from the TensorFlow Models package.
- Train/fine-tune a pre-built Mask R-CNN with a MobileNet backbone for object detection and instance segmentation
- Export the trained/tuned Mask R-CNN model
Install Necessary Dependencies
# Install the Model Garden package (tf-models-official) and the extra helpers
# this notebook uses (e.g. tqdm for the TFRecord-generation progress bar).
pip install -U -q "tf-models-official"
pip install -U -q remotezip tqdm opencv-python einops
Import required libraries
import os
import io
import json
import tqdm
import shutil
import pprint
import pathlib
import tempfile
import requests
import collections
import matplotlib
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
from PIL import Image
from six import BytesIO
from etils import epath
from IPython import display
from urllib.request import urlopen
2023-02-09 12:05:40.296233: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer.so.7'; dlerror: libnvinfer.so.7: cannot open shared object file: No such file or directory 2023-02-09 12:05:40.296323: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer_plugin.so.7'; dlerror: libnvinfer_plugin.so.7: cannot open shared object file: No such file or directory 2023-02-09 12:05:40.296332: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Cannot dlopen some TensorRT libraries. If you would like to use Nvidia GPU with TensorRT, please make sure the missing libraries mentioned above are installed properly.
import orbit
import tensorflow as tf
import tensorflow_models as tfm
import tensorflow_datasets as tfds
from official.core import exp_factory
from official.core import config_definitions as cfg
from official.vision.data import tfrecord_lib
from official.vision.serving import export_saved_model_lib
from official.vision.dataloaders.tf_example_decoder import TfExampleDecoder
from official.vision.utils.object_detection import visualization_utils
from official.vision.ops.preprocess_ops import normalize_image, resize_and_crop_image
from official.vision.data.create_coco_tf_record import coco_annotations_to_lists
# Pretty-printer used later to dump the experiment config.
pp = pprint.PrettyPrinter(indent=4) # Set Pretty Print Indentation
print(tf.__version__) # Check the version of tensorflow used
# IPython magic: render matplotlib figures inline in the notebook.
%matplotlib inline
2.11.0
Download a subset of the LVIS dataset
LVIS: A dataset for large vocabulary instance segmentation.
# @title Download annotation files
# Fetch the LVIS v1 train and val annotation JSONs and the test-dev image-info
# JSON; each archive is unzipped into the working directory and then deleted.
wget https://s3-us-west-2.amazonaws.com/dl.fbaipublicfiles.com/LVIS/lvis_v1_train.json.zip
unzip -q lvis_v1_train.json.zip
rm lvis_v1_train.json.zip
wget https://s3-us-west-2.amazonaws.com/dl.fbaipublicfiles.com/LVIS/lvis_v1_val.json.zip
unzip -q lvis_v1_val.json.zip
rm lvis_v1_val.json.zip
wget https://s3-us-west-2.amazonaws.com/dl.fbaipublicfiles.com/LVIS/lvis_v1_image_info_test_dev.json.zip
unzip -q lvis_v1_image_info_test_dev.json.zip
rm lvis_v1_image_info_test_dev.json.zip
--2023-02-09 12:05:43-- https://s3-us-west-2.amazonaws.com/dl.fbaipublicfiles.com/LVIS/lvis_v1_train.json.zip Resolving s3-us-west-2.amazonaws.com (s3-us-west-2.amazonaws.com)... 52.92.197.112, 52.92.178.40, 52.218.235.72, ... Connecting to s3-us-west-2.amazonaws.com (s3-us-west-2.amazonaws.com)|52.92.197.112|:443... connected. HTTP request sent, awaiting response... 200 OK Length: 350264821 (334M) [application/zip] Saving to: ‘lvis_v1_train.json.zip’ lvis_v1_train.json. 100%[===================>] 334.04M 42.1MB/s in 8.1s 2023-02-09 12:05:51 (41.0 MB/s) - ‘lvis_v1_train.json.zip’ saved [350264821/350264821] --2023-02-09 12:06:00-- https://s3-us-west-2.amazonaws.com/dl.fbaipublicfiles.com/LVIS/lvis_v1_val.json.zip Resolving s3-us-west-2.amazonaws.com (s3-us-west-2.amazonaws.com)... 52.218.218.192, 3.5.80.17, 52.92.209.184, ... Connecting to s3-us-west-2.amazonaws.com (s3-us-west-2.amazonaws.com)|52.218.218.192|:443... connected. HTTP request sent, awaiting response... 200 OK Length: 64026968 (61M) [application/zip] Saving to: ‘lvis_v1_val.json.zip’ lvis_v1_val.json.zi 100%[===================>] 61.06M 34.0MB/s in 1.8s 2023-02-09 12:06:03 (34.0 MB/s) - ‘lvis_v1_val.json.zip’ saved [64026968/64026968] --2023-02-09 12:06:05-- https://s3-us-west-2.amazonaws.com/dl.fbaipublicfiles.com/LVIS/lvis_v1_image_info_test_dev.json.zip Resolving s3-us-west-2.amazonaws.com (s3-us-west-2.amazonaws.com)... 52.92.211.152, 52.218.216.136, 52.92.250.120, ... Connecting to s3-us-west-2.amazonaws.com (s3-us-west-2.amazonaws.com)|52.92.211.152|:443... connected. HTTP request sent, awaiting response... 200 OK Length: 384629 (376K) [application/zip] Saving to: ‘lvis_v1_image_info_test_dev.json.zip’ lvis_v1_image_info_ 100%[===================>] 375.61K 1.53MB/s in 0.2s 2023-02-09 12:06:05 (1.53 MB/s) - ‘lvis_v1_image_info_test_dev.json.zip’ saved [384629/384629]
# @title Lvis annotation parsing
# Ids of LVIS annotations whose bounding boxes are invalid. Any annotation
# with one of these ids is dropped before it is written to a TFRecord.
_INVALID_ANNOTATIONS = [
    # Train split.
    662101, 81217, 462924, 227817, 29381, 601484, 412185,
    504667, 572573, 91937, 239022, 181534, 101685,
    # Validation split.
    36668, 57541, 33126, 10932,
]
def get_category_map(annotation_path, num_classes):
    """Build a category map for the first ``num_classes`` categories of a file.

    Args:
      annotation_path: Path accepted by ``epath.Path`` pointing to an
        LVIS-style annotation JSON file with a top-level ``'categories'`` list.
      num_classes: Number of leading categories (in file order) to keep.

    Returns:
      A dict mapping each kept category's own ``'id'`` to
      ``{'id': <category id>, 'name': <category name>}``.

      The key is the category's real id, not its position in the list. That
      keeps the keys consistent with the annotations' ``category_id`` values,
      which they are compared against and used to name class labels. For
      LVIS, whose categories are listed in id order starting at 1, this gives
      the same keys as counting positions from 1.
    """
    with epath.Path(annotation_path).open() as f:
        data = json.load(f)
    return {
        cat_dict['id']: {'id': cat_dict['id'], 'name': cat_dict['name']}
        for cat_dict in data['categories'][:num_classes]
    }
class LvisAnnotation:
    """LVIS annotation helper class.

    The format of the annotations is explained on
    https://www.lvisdataset.org/dataset.

    Only annotations whose ``category_id`` is in the allowed set of category
    ids are indexed. :attr:`images` only exposes images that have at least one
    such annotation.
    """

    def __init__(self, annotation_path, allowed_category_ids=None):
        """Load an annotation file and index its kept annotations by image id.

        Args:
          annotation_path: Path accepted by ``epath.Path`` pointing to an
            LVIS-style annotation JSON file.
          allowed_category_ids: Iterable of category ids whose annotations are
            kept. Defaults to the module-level ``category_ids`` list.
        """
        if allowed_category_ids is None:
            # Fall back to the notebook-level global so existing calls work.
            allowed_category_ids = category_ids
        # A set gives O(1) membership tests in the per-annotation loop below.
        allowed = set(allowed_category_ids)

        with epath.Path(annotation_path).open() as f:
            self._data = json.load(f)

        img_id2annotations = collections.defaultdict(list)
        for a in self._data.get('annotations', []):
            if a['category_id'] in allowed:
                img_id2annotations[a['image_id']].append(a)
        # Sort each image's annotations by annotation id for a stable order.
        self._img_id2annotations = {
            k: sorted(v, key=lambda a: a['id'])
            for k, v in img_id2annotations.items()
        }

    @property
    def categories(self):
        """Return the category dicts, as sorted in the file."""
        return self._data['categories']

    @property
    def images(self):
        """Return the image dicts that have at least one kept annotation.

        Images keep the order in which they appear in the file.
        """
        return [
            image_info for image_info in self._data['images']
            if image_info['id'] in self._img_id2annotations
        ]

    def get_annotations(self, img_id):
        """Return the kept annotations of ``img_id``, sorted by annotation id.

        Images without any kept annotation yield an empty list.
        """
        return self._img_id2annotations.get(img_id, [])
def _generate_tf_records(prefix, images_zip, annotation_file, num_shards=5):
    """Download the images of an LVIS split and write them as sharded TFRecords.

    The images considered are those that have at least one annotation in the
    selected categories (see ``LvisAnnotation.images``). For each one:

    1. The image is downloaded from its ``coco_url`` into ``IMGS_DIR``.
    2. It is converted into a ``tf.train.Example`` holding boxes, class ids
       and names, crowd flags, areas and PNG-encoded instance masks.
    3. The example is written round-robin to ``num_shards`` files named
       ``<prefix>-XXXXX-of-XXXXX.tfrecord`` inside ``tf_records_dir``.

    Args:
      prefix: Shard-file prefix, e.g. ``'train'`` or ``'val'``.
      images_zip: URL of the split's image archive. Currently unused, because
        images are fetched one by one from their ``coco_url`` instead.
      annotation_file: Path to the LVIS annotation JSON for the split.
      num_shards: Number of output TFRecord files.

    Raises:
      requests.HTTPError: If an image download returns an error status.
      requests.Timeout: If an image download stalls.
    """
    lvis_annotation = LvisAnnotation(annotation_file)

    def _process_example(image_info, id_to_name_map):
        """Convert one downloaded image and its annotations to a tf.train.Example."""
        # The image was saved under the basename of its COCO URL.
        filename = pathlib.Path(image_info['coco_url']).name
        image = tf.io.read_file(os.path.join(IMGS_DIR, filename))

        instances = lvis_annotation.get_annotations(img_id=image_info['id'])
        # Drop annotations known to have invalid bounding boxes.
        instances = [x for x in instances if x['id'] not in _INVALID_ANNOTATIONS]
        # Mark every instance as non-crowd by setting 'iscrowd' to 0.
        instances = [dict(x, iscrowd=0) for x in instances]

        data, _ = coco_annotations_to_lists(instances,
                                            id_to_name_map,
                                            image_info['height'],
                                            image_info['width'],
                                            include_masks=True)

        keys_to_features = {
            'image/encoded':
                tfrecord_lib.convert_to_feature(image.numpy()),
            'image/filename':
                tfrecord_lib.convert_to_feature(filename.encode('utf8')),
            'image/format':
                tfrecord_lib.convert_to_feature('jpg'.encode('utf8')),
            'image/height':
                tfrecord_lib.convert_to_feature(image_info['height']),
            'image/width':
                tfrecord_lib.convert_to_feature(image_info['width']),
            'image/source_id':
                tfrecord_lib.convert_to_feature(str(image_info['id']).encode('utf8')),
            'image/object/bbox/xmin':
                tfrecord_lib.convert_to_feature(data['xmin']),
            'image/object/bbox/xmax':
                tfrecord_lib.convert_to_feature(data['xmax']),
            'image/object/bbox/ymin':
                tfrecord_lib.convert_to_feature(data['ymin']),
            'image/object/bbox/ymax':
                tfrecord_lib.convert_to_feature(data['ymax']),
            'image/object/class/text':
                tfrecord_lib.convert_to_feature(data['category_names']),
            'image/object/class/label':
                tfrecord_lib.convert_to_feature(data['category_id']),
            'image/object/is_crowd':
                tfrecord_lib.convert_to_feature(data['is_crowd']),
            'image/object/area':
                tfrecord_lib.convert_to_feature(data['area'], 'float_list'),
            'image/object/mask':
                tfrecord_lib.convert_to_feature(data['encoded_mask_png'])
        }
        return tf.train.Example(
            features=tf.train.Features(feature=keys_to_features))

    id_to_name_map = {cat_dict['id']: cat_dict['name']
                      for cat_dict in lvis_annotation.categories[:NUM_CLASSES]}

    writers = [
        tf.io.TFRecordWriter(
            os.path.join(tf_records_dir,
                         '%s-%05d-of-%05d.tfrecord' % (prefix, i, num_shards)))
        for i in range(num_shards)
    ]
    try:
        for idx, image_info in enumerate(tqdm.tqdm(lvis_annotation.images)):
            img_url = image_info['coco_url']
            response = requests.get(img_url, timeout=60)
            # Fail loudly instead of saving an HTTP error page as a .jpg.
            response.raise_for_status()
            img_name = img_url.split('/')[-1]
            with open(os.path.join(IMGS_DIR, img_name), 'wb') as handler:
                handler.write(response.content)
            tf_example = _process_example(image_info, id_to_name_map)
            writers[idx % num_shards].write(tf_example.SerializeToString())
    finally:
        # Close explicitly so every shard is flushed to disk, even on error.
        for writer in writers:
            writer.close()
# COCO 2017 image archives per split. They are passed to _generate_tf_records
# as `images_zip`, which currently downloads images one at a time instead.
_URLS = {
    'train_images': 'http://images.cocodataset.org/zips/train2017.zip',
    'validation_images': 'http://images.cocodataset.org/zips/val2017.zip',
    'test_images': 'http://images.cocodataset.org/zips/test2017.zip',
}
train_prefix = 'train'
valid_prefix = 'val'
train_annotation_path = './lvis_v1_train.json'
valid_annotation_path = './lvis_v1_val.json'
IMGS_DIR = './lvis_sub_dataset/'
tf_records_dir = './lvis_tfrecords/'
# exist_ok=True avoids the check-then-create race and is safe on re-runs.
os.makedirs(IMGS_DIR, exist_ok=True)
os.makedirs(tf_records_dir, exist_ok=True)

# Only the first NUM_CLASSES LVIS categories are used in this subset.
NUM_CLASSES = 3
category_index = get_category_map(valid_annotation_path, NUM_CLASSES)
# Category ids whose annotations are kept when parsing LVIS annotations.
category_ids = list(category_index.keys())
# The LVIS annotation-parsing and TFRecord-generation helpers defined earlier
# are based on the TensorFlow Datasets LVIS builder:
# https://github.com/tensorflow/datasets/blob/master/tensorflow_datasets/datasets/lvis/lvis_dataset_builder.py
# Build the training-split TFRecords (each image is downloaded along the way).
_generate_tf_records(train_prefix,
                     _URLS['train_images'],
                     train_annotation_path)
100%|██████████| 2338/2338 [09:46<00:00, 3.99it/s]
# Build the validation-split TFRecords.
_generate_tf_records(valid_prefix,
                     _URLS['validation_images'],
                     valid_annotation_path)
100%|██████████| 422/422 [01:44<00:00, 4.04it/s]
Configure the Mask R-CNN MobileNet FPN COCO model for a custom dataset
# Glob patterns matching the TFRecord shards.
train_data_input_path = './lvis_tfrecords/train*'
valid_data_input_path = './lvis_tfrecords/val*'
# NOTE(review): no test split is converted to TFRecords above.
test_data_input_path = './lvis_tfrecords/test*'
# Checkpoints/summaries go to model_dir; export_dir is presumably used by the
# model-export step.
model_dir = './trained_model/'
export_dir ='./exported_model/'
# exist_ok=True avoids the check-then-create race and is safe on re-runs.
os.makedirs(model_dir, exist_ok=True)
In Model Garden, the collections of parameters that define a model are called configs. Model Garden can create a config based on a known set of parameters via a factory.
Use the maskrcnn_mobilenet_coco
experiment configuration, as defined by tfm.vision.configs.maskrcnn.maskrcnn_mobilenet_coco
.
Please find all the registered experiments here
The configuration defines an experiment to train a Mask R-CNN model with MobileNet as the backbone and FPN as the decoder. The default configuration is trained on COCO train2017 and evaluated on COCO val2017.
There are also other alternative experiments available such as
maskrcnn_resnetfpn_coco
,
maskrcnn_spinenet_coco
and more. One can switch to them by changing the experiment name argument to the get_exp_config
function.
# Start from the registered Mask R-CNN + MobileNet COCO experiment config.
exp_config = exp_factory.get_exp_config('maskrcnn_mobilenet_coco')

# Local directory for the backbone checkpoint files.
model_ckpt_path = './model_ckpt/'
# exist_ok=True avoids the check-then-create race and is safe on re-runs.
os.makedirs(model_ckpt_path, exist_ok=True)
# Copy the MobileNetV2 (v2_1.0_float) checkpoint that is later set as the
# backbone's init_checkpoint: the data shard and the index file.
!gsutil cp gs://tf_model_garden/vision/mobilenet/v2_1.0_float/ckpt-180648.data-00000-of-00001 './model_ckpt/'
!gsutil cp gs://tf_model_garden/vision/mobilenet/v2_1.0_float/ckpt-180648.index './model_ckpt/'
Copying gs://tf_model_garden/vision/mobilenet/v2_1.0_float/ckpt-180648.data-00000-of-00001... Operation completed over 1 objects/26.9 MiB. Copying gs://tf_model_garden/vision/mobilenet/v2_1.0_float/ckpt-180648.index... Operation completed over 1 objects/7.5 KiB.
Adjust the model and dataset configurations so that they work with the custom dataset.
BATCH_SIZE = 8
HEIGHT, WIDTH = 256, 256
IMG_SHAPE = [HEIGHT, WIDTH, 3]

_task_cfg = exp_config.task

# Backbone config: initialise the backbone from the MobileNetV2 checkpoint
# and keep it frozen during fine-tuning.
_task_cfg.annotation_file = None
_task_cfg.freeze_backbone = True
_task_cfg.init_checkpoint = "./model_ckpt/ckpt-180648"
_task_cfg.init_checkpoint_modules = "backbone"

# Model config. One slot more than NUM_CLASSES, presumably reserving index 0
# for the background class; confirm against the Model Garden docs.
_task_cfg.model.num_classes = NUM_CLASSES + 1
_task_cfg.model.input_size = IMG_SHAPE

# Settings shared by the training and validation input pipelines.
for _split_cfg, _split_path in (
    (_task_cfg.train_data, train_data_input_path),
    (_task_cfg.validation_data, valid_data_input_path),
):
    _split_cfg.input_path = _split_path
    _split_cfg.dtype = 'float32'
    _split_cfg.global_batch_size = BATCH_SIZE

# Training-only settings. Equal min/max scale presumably disables scale
# jitter augmentation.
_task_cfg.train_data.shuffle_buffer_size = 64
_task_cfg.train_data.parser.aug_scale_max = 1.0
_task_cfg.train_data.parser.aug_scale_min = 1.0
Adjust the trainer configuration.
# Detect the available accelerator from the logical device names.
logical_device_names = [logical_device.name for logical_device in tf.config.list_logical_devices()]
if 'GPU' in ''.join(logical_device_names):
    print('This may be broken in Colab.')
    device = 'GPU'
elif 'TPU' in ''.join(logical_device_names):
    print('This may be broken in Colab.')
    device = 'TPU'
else:
    print('Running on CPU is slow, so only train for a few steps.')
    device = 'CPU'

if device == 'CPU':
    # Match the message above: a short smoke-test run on CPU.
    train_steps = 20
    steps_per_loop = 5
else:
    train_steps = 2000
    steps_per_loop = 200  # steps_per_loop = num_of_training_examples // train_batch_size

exp_config.trainer.steps_per_loop = steps_per_loop
# Summaries, checkpoints and validation all happen once per loop.
exp_config.trainer.summary_interval = steps_per_loop
exp_config.trainer.checkpoint_interval = steps_per_loop
exp_config.trainer.validation_interval = steps_per_loop
exp_config.trainer.validation_steps = steps_per_loop  # validation_steps = num_of_validation_examples // eval_batch_size
exp_config.trainer.train_steps = train_steps
# Linear warmup for one loop, then cosine decay over the whole run.
exp_config.trainer.optimizer_config.warmup.linear.warmup_steps = steps_per_loop
exp_config.trainer.optimizer_config.learning_rate.type = 'cosine'
exp_config.trainer.optimizer_config.learning_rate.cosine.decay_steps = train_steps
exp_config.trainer.optimizer_config.learning_rate.cosine.initial_learning_rate = 0.07
exp_config.trainer.optimizer_config.warmup.linear.warmup_learning_rate = 0.05
This may be broken in Colab.
Print the modified configuration.
# Dump the full experiment config for inspection.
pp.pprint(exp_config.as_dict())
# Colab-specific JavaScript: set the output iframe height to 500px. Kept as
# the cell's last expression so the notebook displays (and runs) it.
display.Javascript("google.colab.output.setIframeHeight('500px');")
{ 'runtime': { 'all_reduce_alg': None, 'batchnorm_spatial_persistent': False, 'dataset_num_private_threads': None, 'default_shard_dim': -1, 'distribution_strategy': 'mirrored', 'enable_xla': False, 'gpu_thread_mode': None, 'loss_scale': None, 'mixed_precision_dtype': 'bfloat16', 'num_cores_per_replica': 1, 'num_gpus': 0, 'num_packs': 1, 'per_gpu_thread_count': 0, 'run_eagerly': False, 'task_index': -1, 'tpu': None, 'tpu_enable_xla_dynamic_padder': None, 'worker_hosts': None}, 'task': { 'allowed_mask_class_ids': None, 'annotation_file': None, 'differential_privacy_config': None, 'freeze_backbone': True, 'init_checkpoint': './model_ckpt/ckpt-180648', 'init_checkpoint_modules': 'backbone', 'losses': { 'frcnn_box_weight': 1.0, 'frcnn_class_weight': 1.0, 'frcnn_huber_loss_delta': 1.0, 'l2_weight_decay': 4e-05, 'loss_weight': 1.0, 'mask_weight': 1.0, 'rpn_box_weight': 1.0, 'rpn_huber_loss_delta': 0.1111111111111111, 'rpn_score_weight': 1.0}, 'model': { 'anchor': { 'anchor_size': 3, 'aspect_ratios': [0.5, 1.0, 2.0], 'num_scales': 1}, 'backbone': { 'mobilenet': { 'filter_size_scale': 1.0, 'model_id': 'MobileNetV2', 'output_intermediate_endpoints': False, 'output_stride': None, 'stochastic_depth_drop_rate': 0.0}, 'type': 'mobilenet'}, 'decoder': { 'fpn': { 'fusion_type': 'sum', 'num_filters': 128, 'use_keras_layer': False, 'use_separable_conv': True}, 'type': 'fpn'}, 'detection_generator': { 'apply_nms': True, 'max_num_detections': 100, 'nms_iou_threshold': 0.5, 'nms_version': 'v2', 'pre_nms_score_threshold': 0.05, 'pre_nms_top_k': 5000, 'soft_nms_sigma': None, 'use_cpu_nms': False}, 'detection_head': { 'cascade_class_ensemble': False, 'class_agnostic_bbox_pred': False, 'fc_dims': 512, 'num_convs': 4, 'num_fcs': 1, 'num_filters': 128, 'use_separable_conv': True}, 'include_mask': True, 'input_size': [256, 256, 3], 'mask_head': { 'class_agnostic': False, 'num_convs': 4, 'num_filters': 128, 'upsample_factor': 2, 'use_separable_conv': True}, 'mask_roi_aligner': { 'crop_size': 
14, 'sample_offset': 0.5}, 'mask_sampler': {'num_sampled_masks': 128}, 'max_level': 6, 'min_level': 3, 'norm_activation': { 'activation': 'relu6', 'norm_epsilon': 0.001, 'norm_momentum': 0.99, 'use_sync_bn': True}, 'num_classes': 4, 'roi_aligner': { 'crop_size': 7, 'sample_offset': 0.5}, 'roi_generator': { 'nms_iou_threshold': 0.7, 'num_proposals': 1000, 'pre_nms_min_size_threshold': 0.0, 'pre_nms_score_threshold': 0.0, 'pre_nms_top_k': 2000, 'test_nms_iou_threshold': 0.7, 'test_num_proposals': 1000, 'test_pre_nms_min_size_threshold': 0.0, 'test_pre_nms_score_threshold': 0.0, 'test_pre_nms_top_k': 1000, 'use_batched_nms': False}, 'roi_sampler': { 'background_iou_high_threshold': 0.5, 'background_iou_low_threshold': 0.0, 'cascade_iou_thresholds': None, 'foreground_fraction': 0.25, 'foreground_iou_threshold': 0.5, 'mix_gt_boxes': True, 'num_sampled_rois': 512}, 'rpn_head': { 'num_convs': 1, 'num_filters': 128, 'use_separable_conv': True} }, 'name': None, 'per_category_metrics': False, 'train_data': { 'apply_tf_data_service_before_batching': False, 'block_length': 1, 'cache': False, 'cycle_length': None, 'decoder': { 'simple_decoder': { 'mask_binarize_threshold': None, 'regenerate_source_id': False}, 'type': 'simple_decoder'}, 'deterministic': None, 'drop_remainder': True, 'dtype': 'float32', 'enable_shared_tf_data_service_between_parallel_trainers': False, 'enable_tf_data_service': False, 'file_type': 'tfrecord', 'global_batch_size': 8, 'input_path': './lvis_tfrecords/train*', 'is_training': True, 'num_examples': -1, 'parser': { 'aug_rand_hflip': True, 'aug_scale_max': 1.0, 'aug_scale_min': 1.0, 'aug_type': None, 'mask_crop_size': 112, 'match_threshold': 0.5, 'max_num_instances': 100, 'num_channels': 3, 'rpn_batch_size_per_im': 256, 'rpn_fg_fraction': 0.5, 'rpn_match_threshold': 0.7, 'rpn_unmatched_threshold': 0.3, 'skip_crowd_during_training': True, 'unmatched_threshold': 0.5}, 'prefetch_buffer_size': None, 'seed': None, 'sharding': True, 'shuffle_buffer_size': 64, 
'tf_data_service_address': None, 'tf_data_service_job_name': None, 'tfds_as_supervised': False, 'tfds_data_dir': '', 'tfds_name': '', 'tfds_skip_decoding_feature': '', 'tfds_split': '', 'trainer_id': None}, 'use_coco_metrics': True, 'use_wod_metrics': False, 'validation_data': { 'apply_tf_data_service_before_batching': False, 'block_length': 1, 'cache': False, 'cycle_length': None, 'decoder': { 'simple_decoder': { 'mask_binarize_threshold': None, 'regenerate_source_id': False}, 'type': 'simple_decoder'}, 'deterministic': None, 'drop_remainder': False, 'dtype': 'float32', 'enable_shared_tf_data_service_between_parallel_trainers': False, 'enable_tf_data_service': False, 'file_type': 'tfrecord', 'global_batch_size': 8, 'input_path': './lvis_tfrecords/val*', 'is_training': False, 'num_examples': -1, 'parser': { 'aug_rand_hflip': False, 'aug_scale_max': 1.0, 'aug_scale_min': 1.0, 'aug_type': None, 'mask_crop_size': 112, 'match_threshold': 0.5, 'max_num_instances': 100, 'num_channels': 3, 'rpn_batch_size_per_im': 256, 'rpn_fg_fraction': 0.5, 'rpn_match_threshold': 0.7, 'rpn_unmatched_threshold': 0.3, 'skip_crowd_during_training': True, 'unmatched_threshold': 0.5}, 'prefetch_buffer_size': None, 'seed': None, 'sharding': True, 'shuffle_buffer_size': 10000, 'tf_data_service_address': None, 'tf_data_service_job_name': None, 'tfds_as_supervised': False, 'tfds_data_dir': '', 'tfds_name': '', 'tfds_skip_decoding_feature': '', 'tfds_split': '', 'trainer_id': None} }, 'trainer': { 'allow_tpu_summary': False, 'best_checkpoint_eval_metric': '', 'best_checkpoint_export_subdir': '', 'best_checkpoint_metric_comp': 'higher', 'checkpoint_interval': 200, 'continuous_eval_timeout': 3600, 'eval_tf_function': True, 'eval_tf_while_loop': False, 'loss_upper_bound': 1000000.0, 'max_to_keep': 5, 'optimizer_config': { 'ema': None, 'learning_rate': { 'cosine': { 'alpha': 0.0, 'decay_steps': 2000, 'initial_learning_rate': 0.07, 'name': 'CosineDecay', 'offset': 0}, 'type': 'cosine'}, 'optimizer': { 
'sgd': { 'clipnorm': None, 'clipvalue': None, 'decay': 0.0, 'global_clipnorm': None, 'momentum': 0.9, 'name': 'SGD', 'nesterov': False}, 'type': 'sgd'}, 'warmup': { 'linear': { 'name': 'linear', 'warmup_learning_rate': 0.05, 'warmup_steps': 200}, 'type': 'linear'} }, 'recovery_begin_steps': 0, 'recovery_max_trials': 0, 'steps_per_loop': 200, 'summary_interval': 200, 'train_steps': 2000, 'train_tf_function': True, 'train_tf_while_loop': True, 'validation_interval': 200, 'validation_steps': 200, 'validation_summary_subdir': 'validation'} } <IPython.core.display.Javascript object>
Set up the distribution strategy.
# Setting up the Strategy
# Setting up the Strategy
# Use the float16 mixed-precision policy only when the config asks for it.
if exp_config.runtime.mixed_precision_dtype == tf.float16:
    tf.keras.mixed_precision.set_global_policy('mixed_float16')

# `device` was determined from the logical device names when configuring
# the trainer.
if device == 'GPU':
    # Replicate the model across all visible GPUs.
    distribution_strategy = tf.distribute.MirroredStrategy()
elif device == 'TPU':
    tf.tpu.experimental.initialize_tpu_system()
    tpu = tf.distribute.cluster_resolver.TPUClusterResolver(tpu='/device:TPU_SYSTEM:0')
    # tf.distribute.TPUStrategy is the stable replacement for the deprecated
    # tf.distribute.experimental.TPUStrategy.
    distribution_strategy = tf.distribute.TPUStrategy(tpu)
else:
    print('Warning: this will be really slow.')
    distribution_strategy = tf.distribute.OneDeviceStrategy(logical_device_names[0])
print("Done")
INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0', '/job:localhost/replica:0/task:0/device:GPU:1', '/job:localhost/replica:0/task:0/device:GPU:2', '/job:localhost/replica:0/task:0/device:GPU:3') INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0', '/job:localhost/replica:0/task:0/device:GPU:1', '/job:localhost/replica:0/task:0/device:GPU:2', '/job:localhost/replica:0/task:0/device:GPU:3') Done
Create the Task
object (tfm.core.base_task.Task
) from the config_definitions.TaskConfig
.
The Task
object has all the methods necessary for building the dataset, building the model, and running training & evaluation. These methods are driven by tfm.core.train_lib.run_experiment
.
# Build the Task inside the strategy scope so that anything it creates is
# placed according to distribution_strategy; logs go to model_dir.
with distribution_strategy.scope():
    task = tfm.core.task_factory.get_task(exp_config.task, logging_dir=model_dir)
Visualize a batch of the data.
# Pull a single batch from the training input pipeline and report the image
# shape/dtype and the keys of the label dict.
for images, labels in task.build_inputs(exp_config.task.train_data).take(1):
    print()
    print(f'images.shape: {str(images.shape):16} images.dtype: {images.dtype!r}')
    print(f'labels.keys: {labels.keys()}')
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/autograph/pyct/static_analysis/liveness.py:83: Analyzer.lamba_check (from tensorflow.python.autograph.pyct.static_analysis.liveness) is deprecated and will be removed after 2023-09-23. Instructions for updating: Lambda fuctions will be no more assumed to be used in the statement where they are used, or at least in the same block. https://github.com/tensorflow/tensorflow/issues/56089 WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/autograph/pyct/static_analysis/liveness.py:83: Analyzer.lamba_check (from tensorflow.python.autograph.pyct.static_analysis.liveness) is deprecated and will be removed after 2023-09-23. Instructions for updating: Lambda fuctions will be no more assumed to be used in the statement where they are used, or at least in the same block. https://github.com/tensorflow/tensorflow/issues/56089 WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/util/deprecation.py:629: calling map_fn_v2 (from tensorflow.python.ops.map_fn) with dtype is deprecated and will be removed in a future version. Instructions for updating: Use fn_output_signature instead WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/util/deprecation.py:629: calling map_fn_v2 (from tensorflow.python.ops.map_fn) with dtype is deprecated and will be removed in a future version. 
Instructions for updating: Use fn_output_signature instead 2023-02-09 12:18:34.444674: W tensorflow/core/grappler/costs/op_level_cost_estimator.cc:690] Error in PredictCost() for the op: op: "CropAndResize" attr { key: "T" value { type: DT_FLOAT } } attr { key: "extrapolation_value" value { f: 0 } } attr { key: "method" value { s: "bilinear" } } inputs { dtype: DT_FLOAT shape { dim { size: -47 } dim { size: -31 } dim { size: -32 } dim { size: 1 } } } inputs { dtype: DT_FLOAT shape { dim { size: -3 } dim { size: 4 } } } inputs { dtype: DT_INT32 shape { dim { size: -3 } } } inputs { dtype: DT_INT32 shape { dim { size: 2 } } value { dtype: DT_INT32 tensor_shape { dim { size: 2 } } int_val: 112 } } device { type: "CPU" vendor: "GenuineIntel" model: "111" frequency: 2299 num_cores: 32 environment { key: "cpu_instruction_set" value: "AVX SSE, SSE2, SSE3, SSSE3, SSE4.1, SSE4.2" } environment { key: "eigen" value: "3.4.90" } l1_cache_size: 32768 l2_cache_size: 262144 l3_cache_size: 47185920 memory_size: 268435456 } outputs { dtype: DT_FLOAT shape { dim { size: -3 } dim { size: 112 } dim { size: 112 } dim { size: 1 } } } images.shape: (8, 256, 256, 3) images.dtype: tf.float32 labels.keys: dict_keys(['anchor_boxes', 'image_info', 'rpn_score_targets', 'rpn_box_targets', 'gt_boxes', 'gt_classes', 'gt_masks'])
Create a TfExampleDecoder to decode the serialized TFRecord examples (class labels are mapped to their names with the category_index dictionary created earlier)
# Decoder for the serialized tf.train.Examples in the TFRecords;
# include_mask=True so instance masks are decoded too (show_batch uses them).
tf_ex_decoder = TfExampleDecoder(include_mask=True)
Helper Function for Visualizing the results from TFRecords
Use visualize_boxes_and_labels_on_image_array
from visualization_utils
to draw bounding boxes and instance masks on the image.
def show_batch(raw_records, num_of_examples):
    """Plot decoded ground-truth examples side by side in a single row.

    Each record is decoded with the notebook-level ``tf_ex_decoder``. Its
    ground-truth boxes, class labels (named via ``category_index``) and
    instance masks are drawn onto the image before it is shown.

    Args:
      raw_records: Iterable of serialized ``tf.train.Example`` tensors. It
        should yield at most ``num_of_examples`` records.
      num_of_examples: Number of subplot columns to lay out.
    """
    plt.figure(figsize=(20, 20))
    use_normalized_coordinates = True
    min_score_thresh = 0.30
    for i, serialized_example in enumerate(raw_records):
        # One column per requested example.
        plt.subplot(1, num_of_examples, i + 1)
        decoded_tensors = tf_ex_decoder.decode(serialized_example)
        image = decoded_tensors['image'].numpy().astype('uint8')
        # Ground truth has no confidence; give every box a score of 1.0 so
        # that none is filtered out by min_score_thresh.
        scores = np.ones(shape=(len(decoded_tensors['groundtruth_boxes'])))
        # Drawing happens on `image` itself, which is then shown below.
        visualization_utils.visualize_boxes_and_labels_on_image_array(
            image,
            decoded_tensors['groundtruth_boxes'].numpy(),
            decoded_tensors['groundtruth_classes'].numpy().astype('int'),
            scores,
            category_index=category_index,
            use_normalized_coordinates=use_normalized_coordinates,
            min_score_thresh=min_score_thresh,
            instance_masks=decoded_tensors['groundtruth_instance_masks'].numpy().astype('uint8'),
            line_thickness=4)
        plt.imshow(image)
        plt.axis("off")
        plt.title(f"Image-{i+1}")
    plt.show()
Visualization of Train Data
Each visualized box has three components:
- Class label of the object detected.
- Confidence score of the box (always 100% here, because ground-truth boxes are drawn with a score of 1.0).
- Instance Segmentation Mask
# Sample a few shuffled training records and plot their ground truth.
buffer_size = 100
num_of_examples = 3
train_tfrecords = tf.io.gfile.glob(exp_config.task.train_data.input_path)
raw_records = (
    tf.data.TFRecordDataset(train_tfrecords)
    .shuffle(buffer_size=buffer_size)
    .take(num_of_examples)
)
show_batch(raw_records, num_of_examples)
Train and evaluate
We follow the COCO challenge convention of evaluating object-detection accuracy with mAP (mean Average Precision). See here for a detailed explanation of how the detection evaluation metrics are computed.
IoU is defined as the area of the intersection divided by the area of the union of a predicted bounding box $B_p$ and a ground-truth bounding box $B_{gt}$: $\mathrm{IoU} = \frac{|B_p \cap B_{gt}|}{|B_p \cup B_{gt}|}$.
# Train for exp_config.trainer.train_steps steps with periodic evaluation
# ('train_and_eval'), followed by a final evaluation (run_post_eval=True).
# Checkpoints are written to model_dir; the result is unpacked into the
# model and the evaluation logs.
model, eval_logs = tfm.core.train_lib.run_experiment(
    distribution_strategy=distribution_strategy,
    task=task,
    mode='train_and_eval',
    params=exp_config,
    model_dir=model_dir,
    run_post_eval=True)
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',). INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',). INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',). INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',). INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',). INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',). INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',). INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',). INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',). INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',). INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',). INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',). INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',). INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',). 
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',). INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',). INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',). INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',). INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',). INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',). /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/keras/engine/functional.py:638: UserWarning: Input dict contained keys ['6'] which did not match any model input. They will be ignored by the model. inputs = self._flatten_to_reference_inputs(inputs) loading annotations into memory... Done (t=0.01s) creating index... index created! restoring or initializing model... initialized model. train | step: 0 | training until step 200... 
2023-02-09 12:19:02.077168: W tensorflow/core/grappler/costs/op_level_cost_estimator.cc:690] Error in PredictCost() for the op: op: "CropAndResize" attr { key: "T" value { type: DT_FLOAT } } attr { key: "extrapolation_value" value { f: 0 } } attr { key: "method" value { s: "bilinear" } } inputs { dtype: DT_FLOAT shape { dim { size: -47 } dim { size: -31 } dim { size: -32 } dim { size: 1 } } } inputs { dtype: DT_FLOAT shape { dim { size: -3 } dim { size: 4 } } } inputs { dtype: DT_INT32 shape { dim { size: -3 } } } inputs { dtype: DT_INT32 shape { dim { size: 2 } } value { dtype: DT_INT32 tensor_shape { dim { size: 2 } } int_val: 112 } } device { type: "CPU" vendor: "GenuineIntel" model: "111" frequency: 2299 num_cores: 32 environment { key: "cpu_instruction_set" value: "AVX SSE, SSE2, SSE3, SSSE3, SSE4.1, SSE4.2" } environment { key: "eigen" value: "3.4.90" } l1_cache_size: 32768 l2_cache_size: 262144 l3_cache_size: 47185920 memory_size: 268435456 } outputs { dtype: DT_FLOAT shape { dim { size: -3 } dim { size: 112 } dim { size: 112 } dim { size: 1 } } } INFO:tensorflow:batch_all_reduce: 1 all-reduces with algorithm = nccl, num_packs = 1 INFO:tensorflow:batch_all_reduce: 1 all-reduces with algorithm = nccl, num_packs = 1 INFO:tensorflow:batch_all_reduce: 1 all-reduces with algorithm = nccl, num_packs = 1 INFO:tensorflow:batch_all_reduce: 1 all-reduces with algorithm = nccl, num_packs = 1 INFO:tensorflow:batch_all_reduce: 1 all-reduces with algorithm = nccl, num_packs = 1 INFO:tensorflow:batch_all_reduce: 1 all-reduces with algorithm = nccl, num_packs = 1 INFO:tensorflow:batch_all_reduce: 1 all-reduces with algorithm = nccl, num_packs = 1 INFO:tensorflow:batch_all_reduce: 1 all-reduces with algorithm = nccl, num_packs = 1 INFO:tensorflow:batch_all_reduce: 1 all-reduces with algorithm = nccl, num_packs = 1 INFO:tensorflow:batch_all_reduce: 1 all-reduces with algorithm = nccl, num_packs = 1 INFO:tensorflow:batch_all_reduce: 1 all-reduces with algorithm = nccl, 
num_packs = 1 INFO:tensorflow:batch_all_reduce: 1 all-reduces with algorithm = nccl, num_packs = 1 INFO:tensorflow:batch_all_reduce: 1 all-reduces with algorithm = nccl, num_packs = 1 INFO:tensorflow:batch_all_reduce: 1 all-reduces with algorithm = nccl, num_packs = 1 INFO:tensorflow:batch_all_reduce: 1 all-reduces with algorithm = nccl, num_packs = 1 INFO:tensorflow:batch_all_reduce: 1 all-reduces with algorithm = nccl, num_packs = 1 INFO:tensorflow:batch_all_reduce: 1 all-reduces with algorithm = nccl, num_packs = 1 INFO:tensorflow:batch_all_reduce: 1 all-reduces with algorithm = nccl, num_packs = 1 INFO:tensorflow:batch_all_reduce: 1 all-reduces with algorithm = nccl, num_packs = 1 INFO:tensorflow:batch_all_reduce: 1 all-reduces with algorithm = nccl, num_packs = 1 /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/keras/initializers/initializers_v2.py:120: UserWarning: The initializer VarianceScaling is unseeded and being called multiple times, which will return identical values each time (even if the initializer is unseeded). Please update your code to provide a seed to the initializer, or avoid using the same initalizer instance more than once. warnings.warn( train | step: 200 | steps/sec: 1.4 | output: {'frcnn_box_loss': 0.26097748, 'frcnn_cls_loss': 0.058542486, 'learning_rate': 0.06828698, 'mask_loss': 0.5206564, 'model_loss': 0.9554084, 'rpn_box_loss': 0.03707892, 'rpn_score_loss': 0.078153364, 'total_loss': 1.2578363, 'training_loss': 1.2578363} saved checkpoint to ./trained_model/ckpt-200. eval | step: 200 | running 200 steps of evaluation... creating index... index created! Running per image evaluation... Evaluate annotation type *bbox* DONE (t=0.94s). Accumulating evaluation results... DONE (t=0.28s). 
Average Precision (AP) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.004 Average Precision (AP) @[ IoU=0.50 | area= all | maxDets=100 ] = 0.018 Average Precision (AP) @[ IoU=0.75 | area= all | maxDets=100 ] = 0.000 Average Precision (AP) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.000 Average Precision (AP) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.002 Average Precision (AP) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.015 Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets= 1 ] = 0.009 Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets= 10 ] = 0.029 Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.036 Average Recall (AR) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.001 Average Recall (AR) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.025 Average Recall (AR) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.094 Running per image evaluation... Evaluate annotation type *segm* DONE (t=1.02s). Accumulating evaluation results... DONE (t=0.27s). 
Average Precision (AP) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.002 Average Precision (AP) @[ IoU=0.50 | area= all | maxDets=100 ] = 0.013 Average Precision (AP) @[ IoU=0.75 | area= all | maxDets=100 ] = 0.000 Average Precision (AP) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.000 Average Precision (AP) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.002 Average Precision (AP) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.010 Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets= 1 ] = 0.009 Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets= 10 ] = 0.024 Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.027 Average Recall (AR) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.001 Average Recall (AR) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.012 Average Recall (AR) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.080 eval | step: 200 | eval time: 76.9 sec | output: {'AP': 0.0038142162, 'AP50': 0.018394616, 'AP75': 0.00031846602, 'APl': 0.01463741, 'APm': 0.0015962591, 'APs': 0.00018393727, 'ARl': 0.09382022, 'ARm': 0.025490196, 'ARmax1': 0.009154551, 'ARmax10': 0.02924071, 'ARmax100': 0.035702746, 'ARs': 0.0013513514, 'mask_AP': 0.0024596243, 'mask_AP50': 0.012702563, 'mask_AP75': 5.410377e-05, 'mask_APl': 0.01024499, 'mask_APm': 0.0015298778, 'mask_APs': 0.0004950495, 'mask_ARl': 0.08033708, 'mask_ARm': 0.012418301, 'mask_ARmax1': 0.009262251, 'mask_ARmax10': 0.024448035, 'mask_ARmax100': 0.02735595, 'mask_ARs': 0.0006756757, 'validation_loss': 0.0} train | step: 200 | training until step 400... train | step: 400 | steps/sec: 1.9 | output: {'frcnn_box_loss': 0.29886824, 'frcnn_cls_loss': 0.047152787, 'learning_rate': 0.06331559, 'mask_loss': 0.41211814, 'model_loss': 0.8108435, 'rpn_box_loss': 0.032853108, 'rpn_score_loss': 0.019851197, 'total_loss': 1.1125084, 'training_loss': 1.1125084} saved checkpoint to ./trained_model/ckpt-400. eval | step: 400 | running 200 steps of evaluation... 
creating index... index created! Running per image evaluation... Evaluate annotation type *bbox* DONE (t=0.96s). Accumulating evaluation results... DONE (t=1.21s). Average Precision (AP) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.036 Average Precision (AP) @[ IoU=0.50 | area= all | maxDets=100 ] = 0.104 Average Precision (AP) @[ IoU=0.75 | area= all | maxDets=100 ] = 0.010 Average Precision (AP) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.002 Average Precision (AP) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.021 Average Precision (AP) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.098 Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets= 1 ] = 0.048 Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets= 10 ] = 0.083 Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.090 Average Recall (AR) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.012 Average Recall (AR) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.073 Average Recall (AR) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.218 Running per image evaluation... Evaluate annotation type *segm* DONE (t=1.06s). Accumulating evaluation results... DONE (t=0.28s). 
Average Precision (AP) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.024 Average Precision (AP) @[ IoU=0.50 | area= all | maxDets=100 ] = 0.076 Average Precision (AP) @[ IoU=0.75 | area= all | maxDets=100 ] = 0.005 Average Precision (AP) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.000 Average Precision (AP) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.009 Average Precision (AP) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.076 Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets= 1 ] = 0.037 Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets= 10 ] = 0.053 Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.054 Average Recall (AR) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.001 Average Recall (AR) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.022 Average Recall (AR) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.163 eval | step: 400 | eval time: 44.4 sec | output: {'AP': 0.03559471, 'AP50': 0.10418426, 'AP75': 0.009932064, 'APl': 0.09798559, 'APm': 0.021183753, 'APs': 0.002425099, 'ARl': 0.21835206, 'ARm': 0.07287582, 'ARmax1': 0.048303716, 'ARmax10': 0.08255251, 'ARmax100': 0.089714594, 'ARs': 0.012162162, 'mask_AP': 0.024233514, 'mask_AP50': 0.07612163, 'mask_AP75': 0.005068625, 'mask_APl': 0.07589979, 'mask_APm': 0.0085301595, 'mask_APs': 0.0002750275, 'mask_ARl': 0.16348314, 'mask_ARm': 0.021568628, 'mask_ARmax1': 0.036618203, 'mask_ARmax10': 0.05261174, 'mask_ARmax100': 0.054281097, 'mask_ARs': 0.0006756757, 'validation_loss': 0.0} train | step: 400 | training until step 600... train | step: 600 | steps/sec: 2.8 | output: {'frcnn_box_loss': 0.26900575, 'frcnn_cls_loss': 0.04801208, 'learning_rate': 0.055572484, 'mask_loss': 0.37825814, 'model_loss': 0.74536073, 'rpn_box_loss': 0.031232227, 'rpn_score_loss': 0.018852688, 'total_loss': 1.045991, 'training_loss': 1.045991} saved checkpoint to ./trained_model/ckpt-600. eval | step: 600 | running 200 steps of evaluation... creating index... 
index created! Running per image evaluation... Evaluate annotation type *bbox* DONE (t=1.15s). Accumulating evaluation results... DONE (t=0.31s). Average Precision (AP) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.063 Average Precision (AP) @[ IoU=0.50 | area= all | maxDets=100 ] = 0.147 Average Precision (AP) @[ IoU=0.75 | area= all | maxDets=100 ] = 0.043 Average Precision (AP) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.007 Average Precision (AP) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.030 Average Precision (AP) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.177 Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets= 1 ] = 0.071 Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets= 10 ] = 0.098 Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.109 Average Recall (AR) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.029 Average Recall (AR) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.089 Average Recall (AR) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.254 Running per image evaluation... Evaluate annotation type *segm* DONE (t=1.25s). Accumulating evaluation results... DONE (t=0.28s). 
Average Precision (AP) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.040 Average Precision (AP) @[ IoU=0.50 | area= all | maxDets=100 ] = 0.106 Average Precision (AP) @[ IoU=0.75 | area= all | maxDets=100 ] = 0.011 Average Precision (AP) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.000 Average Precision (AP) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.013 Average Precision (AP) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.124 Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets= 1 ] = 0.047 Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets= 10 ] = 0.055 Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.057 Average Recall (AR) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.000 Average Recall (AR) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.026 Average Recall (AR) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.168 eval | step: 600 | eval time: 44.1 sec | output: {'AP': 0.06321701, 'AP50': 0.147283, 'AP75': 0.043131366, 'APl': 0.17651631, 'APm': 0.029862402, 'APs': 0.0072133164, 'ARl': 0.25411984, 'ARm': 0.08916122, 'ARmax1': 0.07134943, 'ARmax10': 0.09805918, 'ARmax100': 0.10899079, 'ARs': 0.028603604, 'mask_AP': 0.039753467, 'mask_AP50': 0.10610386, 'mask_AP75': 0.011098142, 'mask_APl': 0.12370983, 'mask_APm': 0.0126444055, 'mask_APs': 2.0148746e-05, 'mask_ARl': 0.167603, 'mask_ARm': 0.025653595, 'mask_ARmax1': 0.04695746, 'mask_ARmax10': 0.055411953, 'mask_ARmax100': 0.056758214, 'mask_ARs': 0.00045045046, 'validation_loss': 0.0} train | step: 600 | training until step 800... train | step: 800 | steps/sec: 2.8 | output: {'frcnn_box_loss': 0.27452332, 'frcnn_cls_loss': 0.044581868, 'learning_rate': 0.045815594, 'mask_loss': 0.38426793, 'model_loss': 0.7470253, 'rpn_box_loss': 0.02769132, 'rpn_score_loss': 0.015961029, 'total_loss': 1.0466467, 'training_loss': 1.0466467} saved checkpoint to ./trained_model/ckpt-800. eval | step: 800 | running 200 steps of evaluation... creating index... 
index created! Running per image evaluation... Evaluate annotation type *bbox* DONE (t=0.85s). Accumulating evaluation results... DONE (t=0.26s). Average Precision (AP) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.061 Average Precision (AP) @[ IoU=0.50 | area= all | maxDets=100 ] = 0.145 Average Precision (AP) @[ IoU=0.75 | area= all | maxDets=100 ] = 0.044 Average Precision (AP) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.006 Average Precision (AP) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.030 Average Precision (AP) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.177 Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets= 1 ] = 0.070 Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets= 10 ] = 0.098 Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.101 Average Recall (AR) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.015 Average Recall (AR) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.071 Average Recall (AR) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.258 Running per image evaluation... Evaluate annotation type *segm* DONE (t=0.94s). Accumulating evaluation results... DONE (t=0.26s). 
Average Precision (AP) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.041 Average Precision (AP) @[ IoU=0.50 | area= all | maxDets=100 ] = 0.121 Average Precision (AP) @[ IoU=0.75 | area= all | maxDets=100 ] = 0.010 Average Precision (AP) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.000 Average Precision (AP) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.017 Average Precision (AP) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.123 Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets= 1 ] = 0.049 Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets= 10 ] = 0.062 Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.063 Average Recall (AR) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.002 Average Recall (AR) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.032 Average Recall (AR) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.180 eval | step: 800 | eval time: 44.0 sec | output: {'AP': 0.061455246, 'AP50': 0.14542615, 'AP75': 0.043833878, 'APl': 0.17677288, 'APm': 0.030157736, 'APs': 0.006219705, 'ARl': 0.25805244, 'ARm': 0.07075164, 'ARmax1': 0.069682285, 'ARmax10': 0.09773829, 'ARmax100': 0.10113086, 'ARs': 0.01509009, 'mask_AP': 0.04056062, 'mask_AP50': 0.12092455, 'mask_AP75': 0.010073127, 'mask_APl': 0.12258576, 'mask_APm': 0.017008074, 'mask_APs': 0.00017664173, 'mask_ARl': 0.18014981, 'mask_ARm': 0.031862747, 'mask_ARmax1': 0.04932687, 'mask_ARmax10': 0.06182014, 'mask_ARmax100': 0.062735595, 'mask_ARs': 0.0018018018, 'validation_loss': 0.0} train | step: 800 | training until step 1000... train | step: 1000 | steps/sec: 2.8 | output: {'frcnn_box_loss': 0.2563517, 'frcnn_cls_loss': 0.044390183, 'learning_rate': 0.034999996, 'mask_loss': 0.36897781, 'model_loss': 0.7110931, 'rpn_box_loss': 0.026073322, 'rpn_score_loss': 0.015300027, 'total_loss': 1.0098797, 'training_loss': 1.0098797} saved checkpoint to ./trained_model/ckpt-1000. eval | step: 1000 | running 200 steps of evaluation... creating index... 
index created! Running per image evaluation... Evaluate annotation type *bbox* DONE (t=0.96s). Accumulating evaluation results... DONE (t=0.27s). Average Precision (AP) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.074 Average Precision (AP) @[ IoU=0.50 | area= all | maxDets=100 ] = 0.161 Average Precision (AP) @[ IoU=0.75 | area= all | maxDets=100 ] = 0.061 Average Precision (AP) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.008 Average Precision (AP) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.042 Average Precision (AP) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.200 Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets= 1 ] = 0.074 Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets= 10 ] = 0.104 Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.111 Average Recall (AR) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.023 Average Recall (AR) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.092 Average Recall (AR) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.260 Running per image evaluation... Evaluate annotation type *segm* DONE (t=1.06s). Accumulating evaluation results... DONE (t=0.26s). 
Average Precision (AP) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.046 Average Precision (AP) @[ IoU=0.50 | area= all | maxDets=100 ] = 0.122 Average Precision (AP) @[ IoU=0.75 | area= all | maxDets=100 ] = 0.026 Average Precision (AP) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.000 Average Precision (AP) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.018 Average Precision (AP) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.143 Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets= 1 ] = 0.054 Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets= 10 ] = 0.064 Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.066 Average Recall (AR) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.002 Average Recall (AR) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.030 Average Recall (AR) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.192 eval | step: 1000 | eval time: 43.3 sec | output: {'AP': 0.07433459, 'AP50': 0.16138124, 'AP75': 0.061414238, 'APl': 0.20042573, 'APm': 0.042105544, 'APs': 0.008244963, 'ARl': 0.2599251, 'ARm': 0.09232026, 'ARmax1': 0.074098006, 'ARmax10': 0.10350027, 'ARmax100': 0.11060851, 'ARs': 0.022972973, 'mask_AP': 0.045994613, 'mask_AP50': 0.122272044, 'mask_AP75': 0.026484298, 'mask_APl': 0.14291438, 'mask_APm': 0.017729025, 'mask_APs': 0.00014489137, 'mask_ARl': 0.19213483, 'mask_ARm': 0.030392157, 'mask_ARmax1': 0.053688746, 'mask_ARmax10': 0.06397415, 'mask_ARmax100': 0.065858915, 'mask_ARs': 0.0024774775, 'validation_loss': 0.0} train | step: 1000 | training until step 1200... train | step: 1200 | steps/sec: 2.8 | output: {'frcnn_box_loss': 0.2445597, 'frcnn_cls_loss': 0.044347066, 'learning_rate': 0.024184398, 'mask_loss': 0.363975, 'model_loss': 0.6969675, 'rpn_box_loss': 0.028565748, 'rpn_score_loss': 0.015519912, 'total_loss': 0.99509925, 'training_loss': 0.99509925} saved checkpoint to ./trained_model/ckpt-1200. eval | step: 1200 | running 200 steps of evaluation... creating index... 
index created! Running per image evaluation... Evaluate annotation type *bbox* DONE (t=1.13s). Accumulating evaluation results... DONE (t=0.29s). Average Precision (AP) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.083 Average Precision (AP) @[ IoU=0.50 | area= all | maxDets=100 ] = 0.173 Average Precision (AP) @[ IoU=0.75 | area= all | maxDets=100 ] = 0.069 Average Precision (AP) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.011 Average Precision (AP) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.049 Average Precision (AP) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.222 Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets= 1 ] = 0.085 Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets= 10 ] = 0.114 Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.126 Average Recall (AR) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.037 Average Recall (AR) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.113 Average Recall (AR) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.279 Running per image evaluation... Evaluate annotation type *segm* DONE (t=1.23s). Accumulating evaluation results... DONE (t=0.27s). 
Average Precision (AP) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.051 Average Precision (AP) @[ IoU=0.50 | area= all | maxDets=100 ] = 0.133 Average Precision (AP) @[ IoU=0.75 | area= all | maxDets=100 ] = 0.029 Average Precision (AP) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.000 Average Precision (AP) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.023 Average Precision (AP) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.154 Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets= 1 ] = 0.059 Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets= 10 ] = 0.068 Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.071 Average Recall (AR) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.005 Average Recall (AR) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.041 Average Recall (AR) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.196 eval | step: 1200 | eval time: 44.6 sec | output: {'AP': 0.08281581, 'AP50': 0.17325933, 'AP75': 0.069426864, 'APl': 0.22204006, 'APm': 0.0490283, 'APs': 0.01114552, 'ARl': 0.27940074, 'ARm': 0.11279956, 'ARmax1': 0.0854453, 'ARmax10': 0.11449421, 'ARmax100': 0.12585662, 'ARs': 0.036795128, 'mask_AP': 0.050573617, 'mask_AP50': 0.13301045, 'mask_AP75': 0.028750887, 'mask_APl': 0.15416817, 'mask_APm': 0.022767765, 'mask_APs': 0.000342094, 'mask_ARl': 0.19606742, 'mask_ARm': 0.041067537, 'mask_ARmax1': 0.059166353, 'mask_ARmax10': 0.06837254, 'mask_ARmax100': 0.07106505, 'mask_ARs': 0.0051217885, 'validation_loss': 0.0} train | step: 1200 | training until step 1400... train | step: 1400 | steps/sec: 2.8 | output: {'frcnn_box_loss': 0.2317424, 'frcnn_cls_loss': 0.042604752, 'learning_rate': 0.014427517, 'mask_loss': 0.35548222, 'model_loss': 0.67040056, 'rpn_box_loss': 0.025666475, 'rpn_score_loss': 0.01490463, 'total_loss': 0.9680551, 'training_loss': 0.9680551} saved checkpoint to ./trained_model/ckpt-1400. eval | step: 1400 | running 200 steps of evaluation... creating index... 
index created! Running per image evaluation... Evaluate annotation type *bbox* DONE (t=1.02s). Accumulating evaluation results... DONE (t=0.26s). Average Precision (AP) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.086 Average Precision (AP) @[ IoU=0.50 | area= all | maxDets=100 ] = 0.175 Average Precision (AP) @[ IoU=0.75 | area= all | maxDets=100 ] = 0.081 Average Precision (AP) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.013 Average Precision (AP) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.050 Average Precision (AP) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.227 Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets= 1 ] = 0.088 Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets= 10 ] = 0.117 Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.126 Average Recall (AR) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.032 Average Recall (AR) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.119 Average Recall (AR) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.278 Running per image evaluation... Evaluate annotation type *segm* DONE (t=1.09s). Accumulating evaluation results... DONE (t=0.26s). 
Average Precision (AP) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.054 Average Precision (AP) @[ IoU=0.50 | area= all | maxDets=100 ] = 0.135 Average Precision (AP) @[ IoU=0.75 | area= all | maxDets=100 ] = 0.035 Average Precision (AP) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.000 Average Precision (AP) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.021 Average Precision (AP) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.164 Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets= 1 ] = 0.062 Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets= 10 ] = 0.072 Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.075 Average Recall (AR) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.002 Average Recall (AR) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.045 Average Recall (AR) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.210 eval | step: 1400 | eval time: 43.3 sec | output: {'AP': 0.08645412, 'AP50': 0.1745393, 'AP75': 0.08108285, 'APl': 0.2271225, 'APm': 0.05006612, 'APs': 0.013359827, 'ARl': 0.27752808, 'ARm': 0.11857299, 'ARmax1': 0.08788147, 'ARmax10': 0.11722988, 'ARmax100': 0.12622288, 'ARs': 0.031981982, 'mask_AP': 0.054192472, 'mask_AP50': 0.13544387, 'mask_AP75': 0.035158496, 'mask_APl': 0.16427927, 'mask_APm': 0.021398347, 'mask_APs': 0.00022512162, 'mask_ARl': 0.2099251, 'mask_ARm': 0.045315903, 'mask_ARmax1': 0.061884686, 'mask_ARmax10': 0.0723855, 'mask_ARmax100': 0.07529341, 'mask_ARs': 0.0015765766, 'validation_loss': 0.0} train | step: 1400 | training until step 1600... train | step: 1600 | steps/sec: 2.8 | output: {'frcnn_box_loss': 0.23538384, 'frcnn_cls_loss': 0.04081902, 'learning_rate': 0.006684403, 'mask_loss': 0.3545543, 'model_loss': 0.66812444, 'rpn_box_loss': 0.023438226, 'rpn_score_loss': 0.013929193, 'total_loss': 0.9654837, 'training_loss': 0.9654837} saved checkpoint to ./trained_model/ckpt-1600. eval | step: 1600 | running 200 steps of evaluation... creating index... 
index created! Running per image evaluation... Evaluate annotation type *bbox* DONE (t=0.90s). Accumulating evaluation results... DONE (t=0.27s). Average Precision (AP) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.088 Average Precision (AP) @[ IoU=0.50 | area= all | maxDets=100 ] = 0.175 Average Precision (AP) @[ IoU=0.75 | area= all | maxDets=100 ] = 0.079 Average Precision (AP) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.014 Average Precision (AP) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.052 Average Precision (AP) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.234 Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets= 1 ] = 0.089 Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets= 10 ] = 0.117 Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.123 Average Recall (AR) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.031 Average Recall (AR) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.102 Average Recall (AR) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.286 Running per image evaluation... Evaluate annotation type *segm* DONE (t=1.01s). Accumulating evaluation results... DONE (t=0.28s). 
Average Precision (AP) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.053 Average Precision (AP) @[ IoU=0.50 | area= all | maxDets=100 ] = 0.134 Average Precision (AP) @[ IoU=0.75 | area= all | maxDets=100 ] = 0.033 Average Precision (AP) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.000 Average Precision (AP) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.023 Average Precision (AP) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.161 Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets= 1 ] = 0.061 Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets= 10 ] = 0.071 Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.073 Average Recall (AR) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.002 Average Recall (AR) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.041 Average Recall (AR) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.204 eval | step: 1600 | eval time: 43.9 sec | output: {'AP': 0.08810501, 'AP50': 0.17473447, 'AP75': 0.07890058, 'APl': 0.23405229, 'APm': 0.051752485, 'APs': 0.014340493, 'ARl': 0.2863296, 'ARm': 0.10157952, 'ARmax1': 0.08928158, 'ARmax10': 0.11696063, 'ARmax100': 0.122938015, 'ARs': 0.03108108, 'mask_AP': 0.05264981, 'mask_AP50': 0.13365833, 'mask_AP75': 0.032613363, 'mask_APl': 0.16077128, 'mask_APm': 0.022689762, 'mask_APs': 0.00028459763, 'mask_ARl': 0.20411985, 'mask_ARm': 0.041285403, 'mask_ARmax1': 0.060848624, 'mask_ARmax10': 0.07108018, 'mask_ARmax100': 0.07258799, 'mask_ARs': 0.0024774775, 'validation_loss': 0.0} train | step: 1600 | training until step 1800... train | step: 1800 | steps/sec: 2.8 | output: {'frcnn_box_loss': 0.22772422, 'frcnn_cls_loss': 0.039871745, 'learning_rate': 0.0017130232, 'mask_loss': 0.35741723, 'model_loss': 0.66372687, 'rpn_box_loss': 0.024775868, 'rpn_score_loss': 0.013937668, 'total_loss': 0.96093696, 'training_loss': 0.96093696} saved checkpoint to ./trained_model/ckpt-1800. eval | step: 1800 | running 200 steps of evaluation... creating index... 
index created! Running per image evaluation... Evaluate annotation type *bbox* DONE (t=0.91s). Accumulating evaluation results... DONE (t=0.26s). Average Precision (AP) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.091 Average Precision (AP) @[ IoU=0.50 | area= all | maxDets=100 ] = 0.178 Average Precision (AP) @[ IoU=0.75 | area= all | maxDets=100 ] = 0.084 Average Precision (AP) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.015 Average Precision (AP) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.056 Average Precision (AP) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.237 Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets= 1 ] = 0.091 Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets= 10 ] = 0.120 Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.126 Average Recall (AR) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.030 Average Recall (AR) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.109 Average Recall (AR) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.290 Running per image evaluation... Evaluate annotation type *segm* DONE (t=1.01s). Accumulating evaluation results... DONE (t=0.26s). 
Average Precision (AP) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.055 Average Precision (AP) @[ IoU=0.50 | area= all | maxDets=100 ] = 0.138 Average Precision (AP) @[ IoU=0.75 | area= all | maxDets=100 ] = 0.033 Average Precision (AP) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.000 Average Precision (AP) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.027 Average Precision (AP) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.163 Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets= 1 ] = 0.062 Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets= 10 ] = 0.073 Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.075 Average Recall (AR) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.002 Average Recall (AR) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.048 Average Recall (AR) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.206 eval | step: 1800 | eval time: 42.9 sec | output: {'AP': 0.09051445, 'AP50': 0.17811275, 'AP75': 0.08386774, 'APl': 0.237296, 'APm': 0.055846825, 'APs': 0.014534691, 'ARl': 0.2895131, 'ARm': 0.109095864, 'ARmax1': 0.09068169, 'ARmax10': 0.12029935, 'ARmax100': 0.12611519, 'ARs': 0.030405406, 'mask_AP': 0.054592002, 'mask_AP50': 0.13818993, 'mask_AP75': 0.03327635, 'mask_APl': 0.16250683, 'mask_APm': 0.026944423, 'mask_APs': 0.00029440268, 'mask_ARl': 0.2059925, 'mask_ARm': 0.04782135, 'mask_ARmax1': 0.062141027, 'mask_ARmax10': 0.07334189, 'mask_ARmax100': 0.0752805, 'mask_ARs': 0.0024774775, 'validation_loss': 0.0} train | step: 1800 | training until step 2000... train | step: 2000 | steps/sec: 2.8 | output: {'frcnn_box_loss': 0.21894428, 'frcnn_cls_loss': 0.042917825, 'learning_rate': 0.0, 'mask_loss': 0.34602836, 'model_loss': 0.6480473, 'rpn_box_loss': 0.02465954, 'rpn_score_loss': 0.015497429, 'total_loss': 0.9452122, 'training_loss': 0.9452122} saved checkpoint to ./trained_model/ckpt-2000. eval | step: 2000 | running 200 steps of evaluation... creating index... index created! 
Running per image evaluation... Evaluate annotation type *bbox* DONE (t=0.93s). Accumulating evaluation results... DONE (t=0.27s). Average Precision (AP) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.091 Average Precision (AP) @[ IoU=0.50 | area= all | maxDets=100 ] = 0.181 Average Precision (AP) @[ IoU=0.75 | area= all | maxDets=100 ] = 0.081 Average Precision (AP) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.015 Average Precision (AP) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.057 Average Precision (AP) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.237 Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets= 1 ] = 0.091 Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets= 10 ] = 0.120 Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.126 Average Recall (AR) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.034 Average Recall (AR) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.109 Average Recall (AR) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.285 Running per image evaluation... Evaluate annotation type *segm* DONE (t=1.18s). Accumulating evaluation results... DONE (t=0.26s). 
Average Precision (AP) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.055 Average Precision (AP) @[ IoU=0.50 | area= all | maxDets=100 ] = 0.139 Average Precision (AP) @[ IoU=0.75 | area= all | maxDets=100 ] = 0.033 Average Precision (AP) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.000 Average Precision (AP) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.027 Average Precision (AP) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.164 Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets= 1 ] = 0.063 Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets= 10 ] = 0.074 Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.076 Average Recall (AR) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.003 Average Recall (AR) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.048 Average Recall (AR) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.207 eval | step: 2000 | eval time: 43.1 sec | output: {'AP': 0.091350704, 'AP50': 0.18149593, 'AP75': 0.08133419, 'APl': 0.23709846, 'APm': 0.057230346, 'APs': 0.014834963, 'ARl': 0.2846442, 'ARm': 0.10925926, 'ARmax1': 0.090627834, 'ARmax10': 0.12028423, 'ARmax100': 0.1261539, 'ARs': 0.03425926, 'mask_AP': 0.055043407, 'mask_AP50': 0.13945857, 'mask_AP75': 0.0326878, 'mask_APl': 0.16421366, 'mask_APm': 0.026785547, 'mask_APs': 0.0004039821, 'mask_ARl': 0.20692883, 'mask_ARm': 0.04765795, 'mask_ARmax1': 0.06267953, 'mask_ARmax10': 0.07429607, 'mask_ARmax100': 0.07618083, 'mask_ARs': 0.0034034033, 'validation_loss': 0.0} eval | step: 2000 | running 200 steps of evaluation... creating index... index created! Running per image evaluation... Evaluate annotation type *bbox* DONE (t=0.94s). Accumulating evaluation results... DONE (t=0.27s). 
Average Precision (AP) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.091 Average Precision (AP) @[ IoU=0.50 | area= all | maxDets=100 ] = 0.181 Average Precision (AP) @[ IoU=0.75 | area= all | maxDets=100 ] = 0.081 Average Precision (AP) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.015 Average Precision (AP) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.057 Average Precision (AP) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.237 Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets= 1 ] = 0.091 Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets= 10 ] = 0.120 Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.126 Average Recall (AR) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.034 Average Recall (AR) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.109 Average Recall (AR) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.285 Running per image evaluation... Evaluate annotation type *segm* DONE (t=1.05s). Accumulating evaluation results... DONE (t=0.26s). 
Average Precision (AP) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.055 Average Precision (AP) @[ IoU=0.50 | area= all | maxDets=100 ] = 0.139 Average Precision (AP) @[ IoU=0.75 | area= all | maxDets=100 ] = 0.033 Average Precision (AP) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.000 Average Precision (AP) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.027 Average Precision (AP) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.164 Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets= 1 ] = 0.063 Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets= 10 ] = 0.074 Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.076 Average Recall (AR) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.003 Average Recall (AR) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.048 Average Recall (AR) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.207 eval | step: 2000 | eval time: 43.9 sec | output: {'AP': 0.091350704, 'AP50': 0.18149593, 'AP75': 0.08133419, 'APl': 0.23709846, 'APm': 0.057230346, 'APs': 0.014834963, 'ARl': 0.2846442, 'ARm': 0.10925926, 'ARmax1': 0.090627834, 'ARmax10': 0.12028423, 'ARmax100': 0.1261539, 'ARs': 0.03425926, 'mask_AP': 0.055043407, 'mask_AP50': 0.13945857, 'mask_AP75': 0.0326878, 'mask_APl': 0.16421366, 'mask_APm': 0.026785547, 'mask_APs': 0.0004039821, 'mask_ARl': 0.20692883, 'mask_ARm': 0.04765795, 'mask_ARmax1': 0.06267953, 'mask_ARmax10': 0.07429607, 'mask_ARmax100': 0.07618083, 'mask_ARs': 0.0034034033, 'validation_loss': 0.0}
Load logs in TensorBoard
# Load the TensorBoard notebook extension and open it on the training log
# directory "./trained_model" (presumably the model_dir used for training —
# confirm they match if you changed it).
%load_ext tensorboard
%tensorboard --logdir "./trained_model"
Saving and exporting the trained model
The keras.Model
object returned by train_lib.run_experiment
expects the data to be normalized by the dataset loader using the same mean and standard deviation statistics as in preprocess_ops.normalize_image(image, offset=MEAN_RGB, scale=STDDEV_RGB)
. This export function handles those details, so you can pass tf.uint8
images and get the correct results.
# Export the most recent training checkpoint as a SavedModel that takes
# uint8 image tensors (batch size 1, HEIGHT x WIDTH) as input.
latest_checkpoint = tf.train.latest_checkpoint(model_dir)
if latest_checkpoint is None:
    # tf.train.latest_checkpoint returns None when model_dir contains no
    # checkpoint. Passing None through would likely export the model without
    # restoring any trained weights, so fail loudly instead.
    raise FileNotFoundError(f'No checkpoint found in {model_dir!r}.')
export_saved_model_lib.export_inference_graph(
    input_type='image_tensor',
    batch_size=1,
    input_image_size=[HEIGHT, WIDTH],
    params=exp_config,
    checkpoint_path=latest_checkpoint,
    export_dir=export_dir)
/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/keras/engine/functional.py:638: UserWarning: Input dict contained keys ['6'] which did not match any model input. They will be ignored by the model. inputs = self._flatten_to_reference_inputs(inputs) WARNING:tensorflow:Skipping full serialization of Keras layer <official.vision.modeling.maskrcnn_model.MaskRCNNModel object at 0x7fb774139550>, because it is not built. WARNING:tensorflow:Skipping full serialization of Keras layer <official.vision.modeling.maskrcnn_model.MaskRCNNModel object at 0x7fb774139550>, because it is not built. WARNING:tensorflow:Skipping full serialization of Keras layer <keras.layers.merging.add.Add object at 0x7fca768361f0>, because it is not built. WARNING:tensorflow:Skipping full serialization of Keras layer <keras.layers.merging.add.Add object at 0x7fca768361f0>, because it is not built. WARNING:tensorflow:Skipping full serialization of Keras layer <keras.layers.merging.add.Add object at 0x7fb7b0077cd0>, because it is not built. WARNING:tensorflow:Skipping full serialization of Keras layer <keras.layers.merging.add.Add object at 0x7fb7b0077cd0>, because it is not built. WARNING:tensorflow:Skipping full serialization of Keras layer <keras.layers.merging.add.Add object at 0x7fb8a06d4430>, because it is not built. WARNING:tensorflow:Skipping full serialization of Keras layer <keras.layers.merging.add.Add object at 0x7fb8a06d4430>, because it is not built. WARNING:tensorflow:Skipping full serialization of Keras layer <keras.layers.merging.add.Add object at 0x7fb7b0313e80>, because it is not built. WARNING:tensorflow:Skipping full serialization of Keras layer <keras.layers.merging.add.Add object at 0x7fb7b0313e80>, because it is not built. WARNING:tensorflow:Skipping full serialization of Keras layer <keras.layers.merging.add.Add object at 0x7fb74c0dab80>, because it is not built. 
WARNING:tensorflow:Skipping full serialization of Keras layer <keras.layers.merging.add.Add object at 0x7fb74c0dab80>, because it is not built. WARNING:tensorflow:Skipping full serialization of Keras layer <keras.layers.merging.add.Add object at 0x7fca76701cd0>, because it is not built. WARNING:tensorflow:Skipping full serialization of Keras layer <keras.layers.merging.add.Add object at 0x7fca76701cd0>, because it is not built. WARNING:tensorflow:Skipping full serialization of Keras layer <keras.layers.merging.add.Add object at 0x7fb74c10efd0>, because it is not built. WARNING:tensorflow:Skipping full serialization of Keras layer <keras.layers.merging.add.Add object at 0x7fb74c10efd0>, because it is not built. WARNING:tensorflow:Skipping full serialization of Keras layer <official.vision.modeling.layers.detection_generator.DetectionGenerator object at 0x7fb774166790>, because it is not built. WARNING:tensorflow:Skipping full serialization of Keras layer <official.vision.modeling.layers.detection_generator.DetectionGenerator object at 0x7fb774166790>, because it is not built. WARNING:tensorflow:Skipping full serialization of Keras layer <official.vision.modeling.layers.mask_sampler.MaskSampler object at 0x7fb774166cd0>, because it is not built. WARNING:tensorflow:Skipping full serialization of Keras layer <official.vision.modeling.layers.mask_sampler.MaskSampler object at 0x7fb774166cd0>, because it is not built. WARNING:tensorflow:Skipping full serialization of Keras layer <official.vision.modeling.layers.roi_sampler.ROISampler object at 0x7fb890568eb0>, because it is not built. WARNING:tensorflow:Skipping full serialization of Keras layer <official.vision.modeling.layers.roi_sampler.ROISampler object at 0x7fb890568eb0>, because it is not built. WARNING:tensorflow:Skipping full serialization of Keras layer <official.vision.modeling.layers.box_sampler.BoxSampler object at 0x7fb8902d1220>, because it is not built. 
WARNING:tensorflow:Skipping full serialization of Keras layer <official.vision.modeling.layers.box_sampler.BoxSampler object at 0x7fb8902d1220>, because it is not built. WARNING:absl:Found untraced functions such as inference_for_tflite, inference_from_image_bytes, inference_from_tf_example, rpn_head_1_layer_call_fn, rpn_head_1_layer_call_and_return_conditional_losses while saving (showing 5 of 287). These functions will not be directly callable after loading. INFO:tensorflow:Assets written to: ./exported_model/assets INFO:tensorflow:Assets written to: ./exported_model/assets
Inference from Trained Model
def load_image_into_numpy_array(path):
    """Load an image from a URL or file path into a batched uint8 numpy array.

    The decoded image is converted to RGB, so grayscale, palette and RGBA
    inputs yield three channels instead of failing the reshape. Images that
    are already RGB keep exactly the same pixel values.

    Args:
      path: an ``http``/``https`` URL, or any path readable by
        ``tf.io.gfile`` (e.g. a local file).

    Returns:
      uint8 numpy array with shape (1, img_height, img_width, 3), i.e. a
      batch containing a single image.
    """
    if path.startswith('http'):
        # Context manager closes the HTTP response once the payload is read.
        with urlopen(path) as response:
            image_data = response.read()
    else:
        # Context manager ensures the GFile handle is not leaked.
        with tf.io.gfile.GFile(path, 'rb') as image_file:
            image_data = image_file.read()
    # convert('RGB') guarantees exactly three channels for any input mode.
    image = Image.open(BytesIO(image_data)).convert('RGB')
    # np.array (not np.asarray) returns a writable copy; prepend a batch axis
    # to produce the (1, H, W, 3) layout.
    return np.array(image, dtype=np.uint8)[np.newaxis, ...]
def build_inputs_for_object_detection(image, input_image_size):
    """Prepare a single image as Object Detection serving input.

    Scale augmentation is disabled (both the minimum and maximum scale are
    1.0), and the output is padded to ``input_image_size``.

    Args:
      image: the image to preprocess.
      input_image_size: target size, used both as the resize target and as
        the padded size (assumed [height, width] — TODO confirm).

    Returns:
      The first value returned by ``resize_and_crop_image``. The second value
      (presumably image-info metadata) is discarded.
    """
    no_scale_jitter = 1.0
    processed_image, _discarded_info = resize_and_crop_image(
        image,
        input_image_size,
        padded_size=input_image_size,
        aug_scale_min=no_scale_jitter,
        aug_scale_max=no_scale_jitter,
    )
    return processed_image
Visualize test data
# Preview the first few records from the TFRecord files matching 'val*'.
num_of_examples = 3
test_tfrecords = tf.io.gfile.glob('./lvis_tfrecords/val*')
test_ds = tf.data.TFRecordDataset(test_tfrecords)
# Keep only as many records as will be displayed.
test_ds = test_ds.take(num_of_examples)
show_batch(test_ds, num_of_examples)
Importing SavedModel
# Reload the exported SavedModel and select its 'serving_default' signature,
# which is the callable used for inference below.
imported = tf.saved_model.load(export_dir=export_dir)
model_fn = imported.signatures['serving_default']
WARNING:absl:Importing a function (__inference_internal_grad_fn_424093) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_421283) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_419303) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_423193) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_421093) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_419223) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_423803) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_423773) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_419173) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_423353) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_423963) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_421263) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_419793) with ops with unsaved custom gradients. 
Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_426013) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_424203) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_418613) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_427043) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_420743) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_420473) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_422303) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_426353) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_425683) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_418483) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_425123) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_425943) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. 
WARNING:absl:Importing a function (__inference_internal_grad_fn_421613) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_419983) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_420053) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_426193) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_420353) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_424073) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_421953) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_420953) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_418573) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_419153) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_422233) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_426663) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_425203) with ops with unsaved custom gradients. 
Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_423903) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_421183) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_424193) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_426753) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_423703) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_423483) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_425763) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_425313) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_424443) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_423223) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_425053) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_423433) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. 
WARNING:absl:Importing a function (__inference_internal_grad_fn_424523) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_426813) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_419943) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_420033) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_426413) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_426733) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_422143) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_418563) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_424963) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_418773) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_419703) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_419973) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_421223) with ops with unsaved custom gradients. 
Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_420443) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_426903) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_425073) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_419633) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_418643) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_424513) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_418703) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_421343) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_426463) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_418603) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_420863) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_426453) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. 
WARNING:absl:Importing a function (__inference_internal_grad_fn_425903) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_426933) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_422133) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_419013) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_422283) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_425373) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_418393) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_425143) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_423953) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_425403) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_426393) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_424543) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_421313) with ops with unsaved custom gradients. 
Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_424343) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_423243) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_422603) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_419413) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_422983) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_421293) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_419873) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_419493) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_419283) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_423683) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_421943) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_426823) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. 
WARNING:absl:Importing a function (__inference_internal_grad_fn_426423) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_426163) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_421833) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_424503) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_426543) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_425933) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_420513) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_421673) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_424133) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_422333) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_421393) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_426333) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_420133) with ops with unsaved custom gradients. 
Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_420203) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_421723) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_424373) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_423813) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_424233) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_420153) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_425853) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_422973) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_420833) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_418433) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_420543) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_425443) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. 
WARNING:absl:Importing a function (__inference_internal_grad_fn_425003) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_425163) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_422613) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_424053) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_422213) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_423863) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_419813) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_423593) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_421713) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_419373) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_419673) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_424013) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_425213) with ops with unsaved custom gradients. 
Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_425783) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_420023) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_418883) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_421653) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_422913) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_418463) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_426403) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_421113) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_421483) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_423883) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_423873) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_426833) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. 
WARNING:absl:Importing a function (__inference_internal_grad_fn_419953) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_422843) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_420263) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_422773) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_426743) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_421763) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_424623) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_426623) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_424803) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_420313) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_421813) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_424713) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_423143) with ops with unsaved custom gradients. 
Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_421493) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_423833) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_420713) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_419663) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_426593) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_418673) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_421503) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_419263) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_424863) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_426023) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_426673) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_425843) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. 
WARNING:absl:Importing a function (__inference_internal_grad_fn_423413) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_426033) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_424423) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_419483) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_423003) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_422863) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_420573) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_425913) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_424493) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_418923) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_419513) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_424333) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_421013) with ops with unsaved custom gradients. 
Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_420593) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_419453) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_421663) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_422703) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_425043) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_421323) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_419103) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_422943) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_420583) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_422293) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_420603) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_422633) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. 
WARNING:absl:Importing a function (__inference_internal_grad_fn_424583) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_423713) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_422853) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_426913) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_421693) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_419473) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_418423) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_425773) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_423533) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_425583) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_419833) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_422793) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_419923) with ops with unsaved custom gradients. 
Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_422063) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_422003) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_426283) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_418803) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_421103) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_421563) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_420733) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_423023) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_418683) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_422593) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_421843) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_419743) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. 
WARNING:absl:Importing a function (__inference_internal_grad_fn_419543) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_425613) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_422363) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_426713) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_426483) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_418543) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_419213) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_423653) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_418903) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_419203) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_424873) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_425793) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_424063) with ops with unsaved custom gradients. 
Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_419183) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_418723) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_425063) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_421333) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_419863) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_421593) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_426053) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_421683) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_424653) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_424813) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_420763) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_424983) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. 
WARNING:absl:Importing a function (__inference_internal_grad_fn_425523) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_423183) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_419163) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_422693) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_418553) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_421443) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_425633) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_418453) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_425973) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_426343) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_422393) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_424773) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_421543) with ops with unsaved custom gradients. 
Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_426083) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_423403) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_420173) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_423213) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_422243) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_426383) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_425273) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_422873) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_419333) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_423543) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_421403) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_418523) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. 
WARNING:absl:Importing a function (__inference_internal_grad_fn_421143) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_420283) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_424413) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_426583) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_426203) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_422223) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_423393) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_425113) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_419733) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_420273) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_419613) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_419883) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_426553) with ops with unsaved custom gradients. 
Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_423423) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_420853) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_423563) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_424473) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_423923) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_424673) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_423613) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_421873) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_426773) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_421553) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_422473) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_420873) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. 
WARNING:absl:Importing a function (__inference_internal_grad_fn_426063) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_423173) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_424693) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_423553) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_421903) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_424143) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_425863) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_423033) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_421473) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_425923) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_421173) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_423233) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_426143) with ops with unsaved custom gradients. 
Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_420463) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_426983) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_423573) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_422093) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_419963) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_419363) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_421863) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_422273) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_419583) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_424823) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_421923) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_419623) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. 
WARNING:absl:Importing a function (__inference_internal_grad_fn_426763) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_422883) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_421243) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_421963) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_420453) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_425513) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_419443) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_418633) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_419193) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_424563) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_418993) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_422833) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_426633) with ops with unsaved custom gradients. 
Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_419403) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_426963) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_426323) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_424833) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_425833) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_426503) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_426273) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_423853) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_423363) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_422423) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_420943) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_422503) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. 
WARNING:absl:Importing a function (__inference_internal_grad_fn_422343) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_426123) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_418833) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_421783) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_420893) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_418473) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_424763) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_424953) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_419533) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_419913) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_419273) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_418973) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_421973) with ops with unsaved custom gradients. 
Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_423443) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_419503) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_418953) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_419383) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_423753) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_418593) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_422723) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_419903) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_423623) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_423993) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_425283) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_422643) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. 
WARNING:absl:Importing a function (__inference_internal_grad_fn_420563) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_421053) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_424683) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_423723) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_421913) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_421203) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_419593) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_420123) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_420523) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_418783) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_419343) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_425033) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_422313) with ops with unsaved custom gradients. 
Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_422513) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_419323) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_425433) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_425133) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_419033) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_423333) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_418353) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_426943) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_421413) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_418963) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_422153) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_420243) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. 
WARNING:absl:Importing a function (__inference_internal_grad_fn_425803) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_422803) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_423103) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_421933) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_422753) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_420993) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_420143) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_422663) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_422533) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_419133) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_423643) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_424453) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_424003) with ops with unsaved custom gradients. 
Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_421603) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_421893) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_421993) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_420363) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_420683) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_426073) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_418693) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_425673) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_426103) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_420783) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_423783) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_420963) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. 
WARNING:absl:Importing a function (__inference_internal_grad_fn_425663) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_425463) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_424153) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_422933) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_426873) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_422113) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_423943) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_421133) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_423793) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_422183) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_424923) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_424733) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_424353) with ops with unsaved custom gradients. 
Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_422253) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_426883) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_426153) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_426923) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_420183) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_421023) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_419723) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_419433) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_418623) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_424103) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_425693) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_424243) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. 
WARNING:absl:Importing a function (__inference_internal_grad_fn_422623) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_422583) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_420803) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_421353) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_418933) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_423083) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_422573) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_418713) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_422463) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_419603) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_424113) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_423323) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_421643) with ops with unsaved custom gradients. 
Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_425093) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_421523) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_419353) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_418383) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_424553) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_421073) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_421623) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_422033) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_425103) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_420503) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_425453) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_420723) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. 
WARNING:absl:Importing a function (__inference_internal_grad_fn_418983) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_420113) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_425713) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_420013) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_420253) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_424123) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_426563) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_423373) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_421803) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_420913) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_424213) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_422413) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_426373) with ops with unsaved custom gradients. 
Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_427013) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_421153) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_419803) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_418513) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_424603) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_420673) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_420923) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_424893) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_421303) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_420813) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_423453) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_423513) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. 
WARNING:absl:Importing a function (__inference_internal_grad_fn_418793) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_422023) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_419783) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_419773) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_424633) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_425023) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_421233) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_421083) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_424043) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_419243) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_423283) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_426653) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_425593) with ops with unsaved custom gradients. 
Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_423843) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_420553) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_426693) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_423823) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_422553) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_422163) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_421003) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_426363) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_427003) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_422823) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_424793) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_422763) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. 
WARNING:absl:Importing a function (__inference_internal_grad_fn_422743) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_424383) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_419113) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_424993) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_426853) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_420493) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_422013) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_422353) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_423293) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_421043) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_422963) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_427023) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_421573) with ops with unsaved custom gradients. 
Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_421733) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_423733) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_420903) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_418913) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_420613) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_419653) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_423693) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_419253) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_422103) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_420823) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_418373) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_421743) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. 
WARNING:absl:Importing a function (__inference_internal_grad_fn_422483) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_424083) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_426843) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_422123) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_426213) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_424293) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_418943) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_419313) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_422813) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_419143) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_423383) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_420533) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_422083) with ops with unsaved custom gradients. 
Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_422433) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_419053) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_420163) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_418583) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_423203) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_427033) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_426723) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_425343) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_421373) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_421753) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_420693) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_420423) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. 
WARNING:absl:Importing a function (__inference_internal_grad_fn_419563) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_421853) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_425983) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_425223) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_426223) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_424223) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_423093) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_424593) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_424463) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_425823) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_424163) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_424303) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_420323) with ops with unsaved custom gradients. 
Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_420333) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_426683) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_425333) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_426313) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_424853) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_423473) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_421273) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_424703) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_426953) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_426793) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_425873) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_419853) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. 
WARNING:absl:Importing a function (__inference_internal_grad_fn_423893) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_426183) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_422893) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_425883) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_420483) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_425483) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_418733) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_422993) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_422683) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_422783) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_420643) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_422673) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_425503) with ops with unsaved custom gradients. 
Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_424643) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_421383) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_420293) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_424253) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_421163) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_425393) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_426613) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_426043) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_422453) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_426643) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_421793) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_423273) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. 
WARNING:absl:Importing a function (__inference_internal_grad_fn_422173) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_425353) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_424183) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_422903) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_420983) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_418753) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_419003) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_421983) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_421453) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_424723) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_425573) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_424663) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_420343) with ops with unsaved custom gradients. 
Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_425553) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_421823) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_426443) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_418813) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_425993) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_420393) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_419713) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_425363) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_422523) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_418663) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_424933) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_419233) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. 
WARNING:absl:Importing a function (__inference_internal_grad_fn_419823) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_419293) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_420043) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_424403) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_425423) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_422203) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_421033) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_419123) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_422073) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_422403) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_418853) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_418653) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_421423) with ops with unsaved custom gradients. 
Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_420213) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_419063) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_420753) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_419753) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_425413) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_423113) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_421773) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_420383) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_425243) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_422383) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_423313) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_424033) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. 
WARNING:absl:Importing a function (__inference_internal_grad_fn_422543) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_418443) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_425253) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_424843) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_426893) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_418363) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_422563) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_420303) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_420433) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_423583) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_420003) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_420773) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_420083) with ops with unsaved custom gradients. 
Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_419553) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_418493) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_423663) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_423493) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_424973) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_422193) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_422713) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_420073) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_426523) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_423013) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_423163) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_420843) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. 
WARNING:absl:Importing a function (__inference_internal_grad_fn_419043) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_425533) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_423603) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_425193) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_423263) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_420413) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_419573) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_426113) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_420933) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_426573) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_424393) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_423973) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_421513) with ops with unsaved custom gradients. 
Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_421633) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_420063) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_425383) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_421213) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_420623) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_426973) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_425743) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_423673) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_424903) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_423743) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_420653) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_419933) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. 
WARNING:absl:Importing a function (__inference_internal_grad_fn_425703) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_419073) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_421433) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_419523) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_424023) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_423463) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_420883) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_422953) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_418823) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_425233) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_426433) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_426243) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_424263) with ops with unsaved custom gradients. 
Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_418533) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_425013) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_424483) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_418503) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_426133) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_423303) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_422053) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_425563) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_422043) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_425733) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_418413) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_425473) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. 
WARNING:absl:Importing a function (__inference_internal_grad_fn_426473) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_425083) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_418403) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_425293) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_420973) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_421703) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_423983) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_420193) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_423913) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_419693) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_423043) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_421253) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_426493) with ops with unsaved custom gradients. 
Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_424573) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_420093) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_426293) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_425183) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_423063) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_425953) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_424753) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_426993) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_423503) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_426803) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_425603) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_420103) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. 
WARNING:absl:Importing a function (__inference_internal_grad_fn_420793) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_422373) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_424533) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_424913) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_421463) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_424273) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_421193) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_419993) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_426703) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_425623) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_424883) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_424363) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_426533) with ops with unsaved custom gradients. 
Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_426093) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_424433) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_423153) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_421583) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_422493) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_422323) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_418863) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_422443) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_423133) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_423633) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_426173) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_420403) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. 
WARNING:absl:Importing a function (__inference_internal_grad_fn_424783) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_425493) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_426783) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_424743) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_420633) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_421533) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_419763) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_419843) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_419023) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_423073) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_418743) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_419423) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_423523) with ops with unsaved custom gradients. 
Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_419683) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_425813) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_425543) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_425653) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_423933) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_421883) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_423343) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_418893) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_425723) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_426263) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_419093) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_422923) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. 
WARNING:absl:Importing a function (__inference_internal_grad_fn_425303) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_424173) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_421063) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_421363) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_424613) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_424943) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_424283) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_420703) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_419893) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_418843) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_426603) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_420373) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_419083) with ops with unsaved custom gradients. 
Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_424323) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_419393) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_425963) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_425173) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_423053) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_426863) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_425893) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_419463) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_422733) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_426303) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_426233) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_423253) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. 
WARNING:absl:Importing a function (__inference_internal_grad_fn_425323) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_423123) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_423763) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_420663) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_422653) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_426513) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_420233) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_419643) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_425753) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_426253) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_425263) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_422263) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_424313) with ops with unsaved custom gradients. 
Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_425643) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_426003) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_418873) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_421123) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_420223) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_418763) with ops with unsaved custom gradients. Will likely fail if a gradient is requested. WARNING:absl:Importing a function (__inference_internal_grad_fn_425153) with ops with unsaved custom gradients. Will likely fail if a gradient is requested.
Visualize predictions
def reframe_image_corners_relative_to_boxes(boxes):
    """Express the full-image corners [0, 0, 1, 1] in each box's own frame.

    In a box's local frame the box itself spans [0, 0, 1, 1]. For every box,
    this returns where the image boundary lands in that local frame. The
    result is used as the ``boxes`` argument of ``tf.image.crop_and_resize``
    to paste a box-sized mask back onto a full-image canvas.

    Args:
        boxes: Float tensor of shape [num_boxes, 4] holding
            (ymin, xmin, ymax, xmax) in normalized image coordinates.

    Returns:
        Float tensor of shape [num_boxes, 4] holding (ymin, xmin, ymax, xmax)
        of the image expressed in each box's local frame.
    """
    top = boxes[:, 0]
    left = boxes[:, 1]
    bottom = boxes[:, 2]
    right = boxes[:, 3]
    # Clamp the extent so degenerate (zero-size) boxes don't divide by zero.
    box_height = tf.maximum(bottom - top, 1e-4)
    box_width = tf.maximum(right - left, 1e-4)
    image_corners_in_box = [
        (0 - top) / box_height,   # image top edge
        (0 - left) / box_width,   # image left edge
        (1 - top) / box_height,   # image bottom edge
        (1 - left) / box_width,   # image right edge
    ]
    return tf.stack(image_corners_in_box, axis=1)
def reframe_box_masks_to_image_masks(box_masks, boxes, image_height,
                                     image_width, resize_method='bilinear'):
    """Paste per-box masks back onto full-size image canvases.

    Each mask in ``box_masks`` covers only its bounding box. This function
    resamples it into an ``image_height`` x ``image_width`` canvas, with the
    mask placed where its box lies. Pixels outside the box are filled with 0.

    Args:
        box_masks: Tensor of shape [num_masks, mask_height, mask_width].
        boxes: tf.float32 tensor of shape [num_masks, 4]. Row i is
            [ymin, xmin, ymax, xmax] of the box for mask i, in normalized
            image coordinates.
        image_height: Height of the output masks.
        image_width: Width of the output masks.
        resize_method: Either 'bilinear' or 'nearest'. When ``box_masks`` has
            dtype uint8 this argument is ignored and 'nearest' is used.

    Returns:
        Tensor of shape [num_masks, image_height, image_width] with the same
        dtype as ``box_masks``.
    """
    # uint8 masks are forced to nearest-neighbour sampling.
    if box_masks.dtype == tf.uint8:
        resize_method = 'nearest'
    mask_dtype = box_masks.dtype

    def _paste_masks():
        """Resample every box mask onto a full-image canvas."""
        mask_count = tf.shape(box_masks)[0]
        crops = tf.image.crop_and_resize(
            image=tf.expand_dims(box_masks, 3),
            boxes=reframe_image_corners_relative_to_boxes(boxes),
            box_indices=tf.range(mask_count),
            crop_size=[image_height, image_width],
            method=resize_method,
            extrapolation_value=0)
        # crop_and_resize returns float32; restore the caller's dtype.
        return tf.cast(crops, mask_dtype)

    def _no_masks():
        """Empty result with the expected trailing shape."""
        return tf.zeros([0, image_height, image_width, 1], mask_dtype)

    # tf.cond (rather than a Python `if`) keeps this usable in graph mode.
    has_masks = tf.shape(box_masks)[0] > 0
    full_masks = tf.cond(has_masks, _paste_masks, _no_masks)
    return tf.squeeze(full_masks, axis=3)
input_image_size = (HEIGHT, WIDTH)
plt.figure(figsize=(20, 20))
# Detection-score cut-off used when drawing boxes; lower it to see more
# (less confident) detections.
min_score_thresh = 0.40
# Per-pixel mask probability above which a pixel counts as foreground.
# Kept separate from min_score_thresh: reusing the score threshold here meant
# that lowering it to reveal more boxes also turned whole boxes into masks.
# The default matches the previous (shared) value, so output is unchanged.
mask_thresh = 0.40
# One subplot column per test image.
num_cols = 3
for i, serialized_example in enumerate(test_ds):
    # Stop once the grid is full instead of letting plt.subplot raise when
    # test_ds yields more examples than there are columns.
    if i >= num_cols:
        break
    plt.subplot(1, num_cols, i + 1)
    decoded_tensors = tf_ex_decoder.decode(serialized_example)
    image = build_inputs_for_object_detection(decoded_tensors['image'], input_image_size)
    image = tf.expand_dims(image, axis=0)
    image = tf.cast(image, dtype=tf.uint8)
    image_np = image[0].numpy()
    result = model_fn(image)

    # Visualize detections and masks.
    if 'detection_masks' in result:
        # Take the single batch element. convert_to_tensor accepts both
        # tensors and NumPy arrays, so this works for either kind of output.
        detection_masks = tf.convert_to_tensor(result['detection_masks'][0])
        detection_boxes = tf.convert_to_tensor(result['detection_boxes'][0])
        img_h, img_w = image_np.shape[:2]
        # Boxes are in pixel coordinates; reframing expects normalized ones.
        # crop_and_resize maps a normalized coordinate y to pixel
        # y * (size - 1), so normalize by (size - 1) per axis. For a
        # 256-pixel side this is the 255.0 divisor used previously, but it
        # now follows the actual image size.
        box_scale = tf.constant(
            [max(img_h - 1, 1), max(img_w - 1, 1),
             max(img_h - 1, 1), max(img_w - 1, 1)],
            dtype=detection_boxes.dtype)
        detection_masks_reframed = reframe_box_masks_to_image_masks(
            detection_masks, detection_boxes / box_scale, img_h, img_w)
        detection_masks_reframed = tf.cast(
            detection_masks_reframed > mask_thresh, tf.uint8)
        result['detection_masks_reframed'] = detection_masks_reframed.numpy()

    # Draws boxes, labels and (if present) masks onto image_np.
    visualization_utils.visualize_boxes_and_labels_on_image_array(
        image_np,
        result['detection_boxes'][0].numpy(),
        result['detection_classes'][0].numpy().astype(int),
        result['detection_scores'][0].numpy(),
        category_index=category_index,
        use_normalized_coordinates=False,
        max_boxes_to_draw=200,
        min_score_thresh=min_score_thresh,
        instance_masks=result.get('detection_masks_reframed', None),
        line_thickness=4)
    plt.imshow(image_np)
    plt.axis("off")

plt.show()