A specification of the EfficientDet model.
tflite_model_maker.object_detector.EfficientDetSpec(
model_name: str,
uri: str,
hparams: str = '',
model_dir: Optional[str] = None,
epochs: int = 50,
batch_size: int = 64,
steps_per_execution: int = 1,
moving_average_decay: int = 0,
var_freeze_expr: str = '(efficientnet|fpn_cells|resample_p6)',
tflite_max_detections: int = 25,
strategy: Optional[str] = None,
tpu: Optional[str] = None,
gcp_project: Optional[str] = None,
tpu_zone: Optional[str] = None,
use_xla: bool = False,
profile: bool = False,
debug: bool = False,
tf_random_seed: int = 111111,
verbose: int = 0
) -> None
Args | |
---|---|
model_name | Model name. |
uri | TF-Hub path or URL of the EfficientDet module. |
hparams | Hyperparameters used to override the default configuration. Can be: 1) a dict of parameter names and values; 2) a string of comma-separated k=v pairs of hyperparameters; 3) a string naming a YAML file that contains the attributes to use as hyperparameters. |
model_dir | The location to save the model checkpoint files. |
epochs | Default training epochs. |
batch_size | Training and evaluation batch size. |
steps_per_execution | Number of steps per training execution. |
moving_average_decay | Float. The decay to use for maintaining moving averages of the trained parameters. |
var_freeze_expr | Regular expression matching the variables to freeze during training. |
tflite_max_detections | The max number of output detections in the TFLite model. |
strategy | A string specifying which distribution strategy to use. Accepted values are 'tpu', 'gpus', and None. 'tpu' means to use TPUStrategy. 'gpus' means to use MirroredStrategy for multiple GPUs. If None, use the TF default with OneDeviceStrategy. |
tpu | The Cloud TPU to use for training. This should be either the name used when creating the Cloud TPU, or a grpc://ip.address.of.tpu:8470 url. |
gcp_project | Project name for the Cloud TPU-enabled project. If not specified, we will attempt to automatically detect the GCE project from metadata. |
tpu_zone | GCE zone where the Cloud TPU is located. If not specified, we will attempt to automatically detect the GCE zone from metadata. |
use_xla | Use XLA even if strategy is not tpu. If strategy is tpu, always use XLA, and this flag has no effect. |
profile | Enable profile mode. |
debug | Enable debug mode. |
tf_random_seed | Fixed random seed for deterministic execution across runs for debugging. |
verbose | Verbosity mode for tf.keras.callbacks.ModelCheckpoint, 0 or 1. |
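The comma-separated form of hparams can be illustrated with a small sketch. The helper `to_kv_string` is hypothetical (it is not part of tflite_model_maker), and the two keys are placeholders; check the EfficientDet configuration for the names it actually accepts.

```python
from typing import Dict, Union

Scalar = Union[int, float, str, bool]


def to_kv_string(overrides: Dict[str, Scalar]) -> str:
    """Join a flat dict into the 'k=v,k=v' form described for `hparams`.

    Hypothetical helper for illustration only.
    """
    return ",".join(f"{key}={value}" for key, value in overrides.items())


# Placeholder keys and values, assumed for the example.
overrides = {"learning_rate": 0.08, "max_instances_per_image": 50}
hparams = to_kv_string(overrides)
print(hparams)  # learning_rate=0.08,max_instances_per_image=50
```

Either the string or the dict could then be supplied as the hparams argument, since the description above lists both forms.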
Methods
create_model
create_model() -> tf.keras.Model
Creates the EfficientDet model.
evaluate
evaluate(
model: tf.keras.Model,
dataset: tf.data.Dataset,
steps: int,
json_file: Optional[str] = None
) -> Dict[str, float]
Evaluate the EfficientDet Keras model.
Args | |
---|---|
model | The Keras model to be evaluated. |
dataset | tf.data.Dataset used for evaluation. |
steps | Number of steps to evaluate the model. |
json_file | JSON file in COCO data format containing golden bounding boxes. Used for validation. If None, use the ground truth from the dataloader. Refer to https://towardsdatascience.com/coco-data-format-for-object-detection-a4c5eaf518c5 for a description of the COCO data format. |
Returns | |
---|---|
A dict containing AP metrics. |
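One way to choose a value for steps is sketched below, assuming a step corresponds to one batch (the usual Keras meaning). The helper `eval_steps` and its `drop_remainder` flag are hypothetical; whether a final partial batch exists depends on how the dataset was batched.

```python
import math


def eval_steps(num_examples: int, batch_size: int, drop_remainder: bool = False) -> int:
    """Number of batches needed to cover the evaluation set once.

    Hypothetical helper: with drop_remainder=True the final partial batch
    is assumed to have been discarded when the dataset was built.
    """
    if num_examples < 0 or batch_size <= 0:
        raise ValueError("num_examples must be >= 0 and batch_size must be > 0")
    if drop_remainder:
        return num_examples // batch_size
    return math.ceil(num_examples / batch_size)


print(eval_steps(500, 64))                       # 8
print(eval_steps(500, 64, drop_remainder=True))  # 7
```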
evaluate_tflite
evaluate_tflite(
tflite_filepath: str,
dataset: tf.data.Dataset,
steps: int,
json_file: Optional[str] = None
) -> Dict[str, float]
Evaluate the EfficientDet TFLite model.
Args | |
---|---|
tflite_filepath | File path to the TFLite model. |
dataset | tf.data.Dataset used for evaluation. |
steps | Number of steps to evaluate the model. |
json_file | JSON file in COCO data format containing golden bounding boxes. Used for validation. If None, use the ground truth from the dataloader. Refer to https://towardsdatascience.com/coco-data-format-for-object-detection-a4c5eaf518c5 for a description of the COCO data format. |
Returns | |
---|---|
A dict containing AP metrics. |
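Because evaluate and evaluate_tflite both return a dict of AP metrics, the two results can be compared key by key. The sketch below is a hypothetical helper; the key names and numbers are made-up placeholders, not output from a real run.

```python
from typing import Dict


def metric_deltas(keras_metrics: Dict[str, float],
                  tflite_metrics: Dict[str, float]) -> Dict[str, float]:
    """Return (tflite - keras) for every metric name present in both dicts."""
    shared = keras_metrics.keys() & tflite_metrics.keys()
    return {name: tflite_metrics[name] - keras_metrics[name] for name in sorted(shared)}


# Placeholder values chosen only to make the arithmetic easy to follow.
keras_result = {"AP": 0.5, "AP50": 0.75}
tflite_result = {"AP": 0.25, "AP50": 0.5}
print(metric_deltas(keras_result, tflite_result))  # {'AP': -0.25, 'AP50': -0.25}
```

A negative delta here would indicate that the converted model scored lower on that metric than the Keras model.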
export_saved_model
export_saved_model(
model: tf.keras.Model,
saved_model_dir: str,
batch_size: Optional[int] = None,
pre_mode: Optional[str] = 'infer',
post_mode: Optional[str] = 'global'
) -> None
Saves the model to TensorFlow SavedModel.
Args | |
---|---|
model | The EfficientDetNet model used for training, which has no pre- or post-processing. |
saved_model_dir | Folder path for the saved model. |
batch_size | Batch size to be saved in saved_model. |
pre_mode | Pre-processing mode in ExportModel; must be one of {None, 'infer'}. |
post_mode | Post-processing mode in ExportModel; must be one of {None, 'global', 'per_class', 'tflite'}. |
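The allowed values for pre_mode and post_mode listed above can be checked before calling export_saved_model. This is a minimal sketch; `check_export_modes` is a hypothetical helper, not a function provided by the library.

```python
from typing import Optional

# Allowed values as documented for export_saved_model.
PRE_MODES = (None, "infer")
POST_MODES = (None, "global", "per_class", "tflite")


def check_export_modes(pre_mode: Optional[str] = "infer",
                       post_mode: Optional[str] = "global") -> None:
    """Raise ValueError if either mode is outside the documented set."""
    if pre_mode not in PRE_MODES:
        raise ValueError(f"pre_mode must be one of {PRE_MODES}, got {pre_mode!r}")
    if post_mode not in POST_MODES:
        raise ValueError(f"post_mode must be one of {POST_MODES}, got {post_mode!r}")


check_export_modes()                      # the documented defaults pass
check_export_modes(None, "per_class")     # also valid
```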
export_tflite
export_tflite(
model: tf.keras.Model,
tflite_filepath: str,
quantization_config: Optional[tflite_model_maker.config.QuantizationConfig
] = None
) -> None
Converts the retrained model to TFLite format and saves it.
The exported TFLite model has the following inputs and outputs:
One input:
image: a float32 tensor of shape [1, height, width, 3] containing the normalized input image, where self.config.image_size is [height, width].
Four Outputs | |
---|---|
detection_boxes | A float32 tensor of shape [1, num_boxes, 4] with box locations. |
detection_classes | A float32 tensor of shape [1, num_boxes] with class indices. |
detection_scores | A float32 tensor of shape [1, num_boxes] with class scores. |
num_boxes | A float32 tensor of size 1 containing the number of detected boxes. |
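The four outputs can be combined into a list of detections. The sketch below works on plain nested lists shaped like the tensors described above; `Detection`, `collect_detections`, and the `min_score` threshold are hypothetical names introduced for illustration, and the order of the four box values is left as whatever the model emits.

```python
from typing import List, NamedTuple, Sequence


class Detection(NamedTuple):
    box: Sequence[float]  # four values, in the order the model emits them
    class_index: int
    score: float


def collect_detections(boxes, classes, scores, num_boxes,
                       min_score: float = 0.5) -> List[Detection]:
    """Keep the first num_boxes entries whose score is at least min_score."""
    count = int(num_boxes[0])  # num_boxes is a float tensor of size 1
    results = []
    for i in range(count):
        score = float(scores[0][i])
        if score >= min_score:
            results.append(Detection(list(boxes[0][i]), int(classes[0][i]), score))
    return results


# Toy data with a leading batch dimension of 1; values are made up.
boxes = [[[0.1, 0.2, 0.5, 0.6], [0.0, 0.0, 1.0, 1.0], [0.0, 0.0, 0.0, 0.0]]]
classes = [[2.0, 0.0, 0.0]]
scores = [[0.9, 0.3, 0.0]]
num_boxes = [2.0]

for det in collect_detections(boxes, classes, scores, num_boxes):
    print(det.class_index, det.score, det.box)
```

The class indices arrive as floats, so the sketch casts them to int before use.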
Args | |
---|---|
model | The EfficientDetNet model used for training, which has no pre- or post-processing. |
tflite_filepath | File path to save the TFLite model. |
quantization_config | Configuration for post-training quantization. |
get_default_quantization_config
get_default_quantization_config(
representative_data: tflite_model_maker.object_detector.DataLoader
) -> tflite_model_maker.config.QuantizationConfig
Gets the default quantization configuration.
train
train(
model: tf.keras.Model,
train_dataset: tf.data.Dataset,
steps_per_epoch: int,
val_dataset: Optional[tf.data.Dataset],
validation_steps: int,
epochs: Optional[int] = None,
batch_size: Optional[int] = None,
val_json_file: Optional[str] = None
) -> tf.keras.Model
Run EfficientDet training.
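The methods documented on this page can be chained roughly as follows. This is only a sketch built from the signatures above: `train_evaluate_export` is a hypothetical function, `spec` is assumed to be an EfficientDetSpec-like object, and any setup the library may perform around these calls (for example a distribution strategy scope) is not reproduced here.

```python
from typing import Any, Dict, Optional


def train_evaluate_export(spec: Any,
                          train_dataset: Any,
                          steps_per_epoch: int,
                          val_dataset: Any,
                          validation_steps: int,
                          tflite_filepath: str,
                          epochs: Optional[int] = None) -> Dict[str, float]:
    """Create, train, evaluate, and export a model using the spec's methods."""
    model = spec.create_model()
    # train() returns the trained Keras model according to its signature.
    model = spec.train(model, train_dataset, steps_per_epoch,
                       val_dataset, validation_steps, epochs=epochs)
    # Evaluate on the validation data, using the dataloader's ground truth.
    metrics = spec.evaluate(model, val_dataset, validation_steps)
    # Export without a quantization_config, i.e. with the default of None.
    spec.export_tflite(model, tflite_filepath)
    return metrics
```

Passing epochs=None leaves the choice to the method's own default handling.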
Class Variables | |
---|---|
compat_tf_versions | [2] |