tflite_model_maker.image_classifier.ImageClassifier

ImageClassifier class for inference and exporting to TFLite.

model_spec Specification for the model.
index_to_label A list that maps each index to its label class name.
shuffle Whether the data should be shuffled.
hparams A namedtuple of hyperparameters. This function expects the fields .dropout_rate (the fraction of the input units to drop, used in the dropout layer) and .do_fine_tuning (if True, the Hub module is trained together with the classification layer on top).
use_augmentation Whether to use data augmentation for preprocessing.
representative_data Representative dataset for full integer quantization. Used when converting the Keras model to the TFLite model with full integer quantization.
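The two hparams fields can be illustrated with a minimal pure-Python sketch. The HParams namedtuple and the apply_dropout helper below are hypothetical stand-ins, not the library's classes. apply_dropout uses inverted dropout (drop each unit with probability dropout_rate and scale the survivors by 1 / (1 - dropout_rate)), which is the scheme the Keras Dropout layer applies during training.

```python
import random
from collections import namedtuple

# Hypothetical stand-in for the hyperparameter namedtuple; only the two
# fields described above are modeled here.
HParams = namedtuple("HParams", ["dropout_rate", "do_fine_tuning"])


def apply_dropout(inputs, rate, rng):
    """Inverted dropout: zero each unit with probability `rate` and scale
    the survivors by 1 / (1 - rate) so the expected value is unchanged."""
    if not 0.0 <= rate < 1.0:
        raise ValueError("rate must be in [0, 1)")
    keep = 1.0 - rate
    return [x / keep if rng.random() >= rate else 0.0 for x in inputs]


hparams = HParams(dropout_rate=0.2, do_fine_tuning=False)
rng = random.Random(0)
print(apply_dropout([1.0, 2.0, 3.0, 4.0], hparams.dropout_rate, rng))
```

With do_fine_tuning set to False, only the classification layer on top would be trained; the flag itself is just carried in the tuple here.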

Methods

create

Loads data and retrains the model based on data for image classification.

Args
train_data Training data.
model_spec Specification for the model.
validation_data Validation data. If None, skips validation process.
batch_size Number of samples per training step.
epochs Number of epochs for training.
train_whole_model If True, the Hub module is trained together with the classification layer on top. Otherwise, only the top classification layer is trained.
dropout_rate The dropout rate: the fraction of the input units to drop in the dropout layer.
learning_rate Base learning rate for a training batch size of 256; the effective rate scales linearly with the batch size. This linear scaling applies when use_hub_library is False.
momentum A Python float forwarded to the optimizer. Only used when use_hub_library is True.
shuffle Whether the data should be shuffled.
use_augmentation Whether to use data augmentation for preprocessing.
use_hub_library Whether to use make_image_classifier_lib from TensorFlow Hub to retrain the model.
warmup_steps Number of warmup steps for the learning-rate warmup schedule. If None, it defaults to the total number of training steps in two epochs. Only used when use_hub_library is False.
model_dir The location of the model checkpoint files. Only used when use_hub_library is False.
do_train Whether to run training.
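The learning-rate scaling rule and the default warmup length can be worked through with a small sketch. The helper names are hypothetical, steps per epoch are rounded down, and the linear warmup shape is an assumption; the library's actual schedule may differ in these details.

```python
def scaled_learning_rate(base_lr, batch_size, reference_batch_size=256):
    """Linear scaling rule: base_lr is defined for a batch of 256."""
    return base_lr * batch_size / reference_batch_size


def default_warmup_steps(num_train_examples, batch_size, warmup_epochs=2):
    """Total training steps in two epochs (rounding steps per epoch down
    is an assumption of this sketch)."""
    steps_per_epoch = num_train_examples // batch_size
    return warmup_epochs * steps_per_epoch


def warmup_lr(step, peak_lr, warmup_steps):
    """Linear warmup from 0 to peak_lr, then constant (any decay after
    warmup is omitted in this sketch)."""
    if warmup_steps <= 0 or step >= warmup_steps:
        return peak_lr
    return peak_lr * step / warmup_steps


lr = scaled_learning_rate(base_lr=0.04, batch_size=32)
steps = default_warmup_steps(num_train_examples=3200, batch_size=32)
print(lr, steps, warmup_lr(step=50, peak_lr=lr, warmup_steps=steps))
```

For these example values, the effective rate is 0.04 × 32 / 256 = 0.005, and 3200 examples at batch size 32 give 100 steps per epoch, so warmup lasts 200 steps.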

Returns
An instance based on ImageClassifier.

create_model

Creates the classifier model for retraining.

create_serving_model

Returns the underlying Keras model for serving.

evaluate

Evaluates the model.

Args
data Data to be evaluated.
batch_size Number of samples per evaluation step.

Returns
The loss value and accuracy.
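The sketch below shows how a loss and an accuracy can be computed from per-example class probabilities, assuming a plain cross-entropy loss and integer labels. The loss the library actually uses (for example, whether label smoothing is applied) may differ, and loss_and_accuracy is a hypothetical helper.

```python
import math


def loss_and_accuracy(probabilities, labels, eps=1e-7):
    """Mean cross-entropy and top-1 accuracy for integer class labels."""
    total_loss = 0.0
    correct = 0
    for probs, label in zip(probabilities, labels):
        # Cross-entropy of the true class; eps guards against log(0).
        total_loss += -math.log(max(probs[label], eps))
        predicted = max(range(len(probs)), key=probs.__getitem__)
        correct += int(predicted == label)
    n = len(labels)
    return total_loss / n, correct / n


probs = [[0.7, 0.2, 0.1], [0.1, 0.3, 0.6], [0.5, 0.4, 0.1]]
labels = [0, 2, 1]
print(loss_and_accuracy(probs, labels))
```

Here the first two examples are classified correctly and the third is not, so the accuracy is 2/3.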

evaluate_tflite

Evaluates the tflite model.

Args
tflite_filepath File path to the TFLite model.
data Data to be evaluated.
postprocess_fn Postprocessing function that will be applied to the output of lite_runner.run before calculating the probabilities.

Returns
The evaluation result of the TFLite model: its accuracy.
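The role of postprocess_fn can be illustrated with a hypothetical evaluation loop: run_model stands in for lite_runner.run, postprocess_fn (when given) transforms the raw output, and a softmax turns the result into probabilities. Applying softmax at this step is an assumption of the sketch; because softmax preserves the argmax, it does not change top-1 accuracy.

```python
import math


def softmax(logits):
    """Numerically stable softmax over a list of scores."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def tflite_accuracy(run_model, dataset, postprocess_fn=None):
    """Top-1 accuracy of run_model over (input, label) pairs.

    run_model is a hypothetical stand-in for lite_runner.run. If
    postprocess_fn is given, it is applied to the raw output before the
    probabilities are computed.
    """
    correct = 0
    total = 0
    for x, label in dataset:
        output = run_model(x)
        if postprocess_fn is not None:
            output = postprocess_fn(output)
        probs = softmax(output)
        predicted = max(range(len(probs)), key=probs.__getitem__)
        correct += int(predicted == label)
        total += 1
    return correct / total


# Hypothetical model whose raw output lists class scores in reverse order.
def reversed_model(x):
    return list(reversed(x))


data = [([2.0, 0.5, 0.1], 0), ([0.1, 0.2, 3.0], 2)]
print(tflite_accuracy(reversed_model, data))
print(tflite_accuracy(reversed_model, data,
                      postprocess_fn=lambda out: list(reversed(out))))
```

Without a postprocessing step the reversed outputs are misread and accuracy is 0.0; undoing the reversal in postprocess_fn restores it to 1.0.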

export

Converts the retrained model to the formats listed in export_format.

Args
export_dir The directory to save exported files.
tflite_filename File name to save tflite model. The full export path is {export_dir}/{tflite_filename}.
label_filename File name to save labels. The full export path is {export_dir}/{label_filename}.
vocab_filename File name to save vocabulary. The full export path is {export_dir}/{vocab_filename}.
saved_model_filename Path to SavedModel or H5 file to save the model. The full export path is {export_dir}/{saved_model_filename}/{saved_model.pb|assets|variables}.
tfjs_folder_name Folder name to save tfjs model. The full export path is {export_dir}/{tfjs_folder_name}.
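export_format List of export formats to produce. ImageClassifier accepts the formats in ALLOWED_EXPORT_FORMAT (TFLITE, LABEL, SAVED_MODEL, TFJS); DEFAULT_EXPORT_FORMAT is used when none is given.
**kwargs Other parameters for the TFLite model, such as quantized.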
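The path rules above can be sketched with os.path.join. The default file names used here (model.tflite, labels.txt, saved_model, tfjs) and the one-label-per-line layout of the label file are assumptions for illustration; export_paths and write_labels are hypothetical helpers, not library functions.

```python
import os
import tempfile


def export_paths(export_dir, export_format,
                 tflite_filename="model.tflite",
                 label_filename="labels.txt",
                 saved_model_filename="saved_model",
                 tfjs_folder_name="tfjs"):
    """Map each requested format name to {export_dir}/{file or folder name}."""
    names = {
        "TFLITE": tflite_filename,
        "LABEL": label_filename,
        "SAVED_MODEL": saved_model_filename,
        "TFJS": tfjs_folder_name,
    }
    return {fmt: os.path.join(export_dir, names[fmt]) for fmt in export_format}


def write_labels(path, index_to_label):
    """Write one label per line so that line i names class index i."""
    with open(path, "w", encoding="utf-8") as f:
        for label in index_to_label:
            f.write(label + "\n")


with tempfile.TemporaryDirectory() as export_dir:
    paths = export_paths(export_dir, ["TFLITE", "LABEL"])
    write_labels(paths["LABEL"], ["daisy", "rose", "tulip"])
    with open(paths["LABEL"], encoding="utf-8") as f:
        print(f.read().splitlines())
```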

predict_top_k

Predicts the top-k predictions.

Args
data Data to be evaluated. Either an instance of DataLoader or raw data entries such as a TF tensor or NumPy array.
k Number of top results to be predicted.
batch_size Number of samples per evaluation step.

Returns
The top-k results, each a (label, probability) pair.
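A minimal sketch of top-k selection over one example's class probabilities, assuming index i of the probability vector corresponds to index_to_label[i]. top_k is a hypothetical helper, not the library's implementation.

```python
def top_k(probabilities, index_to_label, k):
    """Return the k most probable (label, probability) pairs, highest first."""
    ranked = sorted(range(len(probabilities)),
                    key=lambda i: probabilities[i], reverse=True)
    return [(index_to_label[i], probabilities[i]) for i in ranked[:k]]


index_to_label = ["daisy", "rose", "tulip"]
batch = [[0.1, 0.7, 0.2], [0.6, 0.1, 0.3]]
# One top-k list per example in the batch.
print([top_k(p, index_to_label, k=2) for p in batch])
```

Python's sort is stable even with reverse=True, so classes with equal probability keep their index order.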

summary

train

Feeds the training data for training.

Args
train_data Training data.
validation_data Validation data. If None, skips validation process.
hparams An instance of hub_lib.HParams or train_image_classifier_lib.HParams. A namedtuple of hyperparameters.

Returns
The tf.keras.callbacks.History object returned by tf.keras.Model.fit*().

ALLOWED_EXPORT_FORMAT (<ExportFormat.TFLITE: 'TFLITE'>, <ExportFormat.LABEL: 'LABEL'>, <ExportFormat.SAVED_MODEL: 'SAVED_MODEL'>, <ExportFormat.TFJS: 'TFJS'>)
DEFAULT_EXPORT_FORMAT (<ExportFormat.TFLITE: 'TFLITE'>, <ExportFormat.LABEL: 'LABEL'>)
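The two constants can be mirrored with a local Enum to show how a requested export format might be resolved and validated. This ExportFormat is a stand-in, not the library's class; its VOCAB member is included only to show a format this class does not allow, and resolve_export_format is a hypothetical helper.

```python
from enum import Enum


class ExportFormat(Enum):
    """Local stand-in whose members print like the tuples shown above."""
    TFLITE = "TFLITE"
    LABEL = "LABEL"
    SAVED_MODEL = "SAVED_MODEL"
    TFJS = "TFJS"
    VOCAB = "VOCAB"


ALLOWED_EXPORT_FORMAT = (ExportFormat.TFLITE, ExportFormat.LABEL,
                         ExportFormat.SAVED_MODEL, ExportFormat.TFJS)
DEFAULT_EXPORT_FORMAT = (ExportFormat.TFLITE, ExportFormat.LABEL)


def resolve_export_format(export_format=None):
    """Fall back to the defaults, accept a single format or a sequence,
    and reject anything outside ALLOWED_EXPORT_FORMAT."""
    if export_format is None:
        return DEFAULT_EXPORT_FORMAT
    if isinstance(export_format, ExportFormat):
        export_format = (export_format,)
    for fmt in export_format:
        if fmt not in ALLOWED_EXPORT_FORMAT:
            raise ValueError(f"Export format {fmt} is not allowed.")
    return tuple(export_format)


print(resolve_export_format())
print(resolve_export_format(ExportFormat.TFJS))
```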