tflite_model_maker.text_classifier.TextClassifier

TextClassifier class for inference and exporting to TFLite.

Args
model_spec Specification for the model.
index_to_label A list that maps from index to label class name.
shuffle Whether the data should be shuffled.

Methods

create

Loads data and trains the model for text classification.

Args
train_data Training data.
model_spec Specification for the model.
validation_data Validation data. If None, skips validation process.
batch_size Batch size for training.
epochs Number of epochs for training.
steps_per_epoch Integer or None. Total number of steps (batches of samples) before declaring one epoch finished and starting the next epoch. If steps_per_epoch is None, the epoch will run until the input dataset is exhausted.
shuffle Whether the data should be shuffled.
do_train Whether to run training.

Returns
An instance based on TextClassifier.
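The steps_per_epoch rule above can be illustrated with a small stand-alone sketch. The helper resolve_steps_per_epoch is hypothetical and not part of tflite_model_maker. Whether a final partial batch counts as a step depends on how the input pipeline batches the data, so both cases are shown here as an assumption.

```python
import math


def resolve_steps_per_epoch(steps_per_epoch, num_samples, batch_size,
                            drop_remainder=False):
    """Return how many batches make up one epoch (hypothetical helper)."""
    if steps_per_epoch is not None:
        # An explicit value always takes precedence.
        return steps_per_epoch
    if drop_remainder:
        # The last partial batch is discarded, so round down.
        return num_samples // batch_size
    # The last partial batch is kept, so round up.
    return math.ceil(num_samples / batch_size)


print(resolve_steps_per_epoch(None, 1000, 32))                       # 32
print(resolve_steps_per_epoch(None, 1000, 32, drop_remainder=True))  # 31
print(resolve_steps_per_epoch(10, 1000, 32))                         # 10
```

With steps_per_epoch=None, one epoch is simply one pass over the batched dataset; setting it explicitly caps the epoch length regardless of dataset size.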

create_model

create_serving_model

Returns the underlying Keras model for serving.

evaluate

Evaluates the model.

Args
data Data to be evaluated.
batch_size Number of samples per evaluation step.

Returns
The loss value and accuracy.
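As a rough illustration of what the two returned numbers measure, the sketch below computes a mean loss and an accuracy from per-sample class probabilities. It assumes the loss is sparse categorical cross-entropy; the loss the library actually reports is whatever the underlying Keras model was compiled with, and loss_and_accuracy is a hypothetical helper.

```python
import math


def loss_and_accuracy(probabilities, labels):
    """Mean sparse categorical cross-entropy and accuracy (sketch)."""
    total_loss = 0.0
    correct = 0
    for probs, label in zip(probabilities, labels):
        # Cross-entropy for one sample is -log(probability of the true class).
        total_loss += -math.log(probs[label])
        # A prediction is correct when the most probable class is the label.
        predicted = max(range(len(probs)), key=probs.__getitem__)
        correct += int(predicted == label)
    n = len(labels)
    return total_loss / n, correct / n


probs = [[0.9, 0.1], [0.2, 0.8], [0.6, 0.4]]
labels = [0, 1, 1]
loss, acc = loss_and_accuracy(probs, labels)
print(round(loss, 4), round(acc, 4))  # 0.4149 0.6667
```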

evaluate_tflite

Evaluates the tflite model.

Args
tflite_filepath File path to the TFLite model.
data Data to be evaluated.
postprocess_fn Postprocessing function that will be applied to the output of lite_runner.run before calculating the probabilities.

Returns
The evaluation result of the TFLite model: its accuracy.
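The role of postprocess_fn can be sketched without a real interpreter. In this hypothetical example each runner output is assumed to be a tuple of tensors, and postprocess_fn selects the class scores before the predicted class is taken. The names tflite_accuracy and argmax, and the shape of the fake outputs, are illustrative assumptions, not the library's internals.

```python
def argmax(values):
    """Index of the largest value."""
    return max(range(len(values)), key=values.__getitem__)


def tflite_accuracy(runner_outputs, labels, postprocess_fn=None):
    """Accuracy over raw runner outputs (hypothetical sketch)."""
    correct = 0
    for output, label in zip(runner_outputs, labels):
        # Optional hook applied to the raw output before scoring.
        if postprocess_fn is not None:
            output = postprocess_fn(output)
        correct += int(argmax(output) == label)
    return correct / len(labels)


# Hypothetical runner outputs: each is (class_scores, auxiliary_tensor).
outputs = [([0.1, 0.9], [0]), ([0.7, 0.3], [0]),
           ([0.4, 0.6], [0]), ([0.8, 0.2], [0])]
labels = [1, 0, 0, 0]
acc = tflite_accuracy(outputs, labels, postprocess_fn=lambda out: out[0])
print(acc)  # 0.75
```

Keeping the hook separate lets the same accuracy loop handle models whose raw output needs unpacking or reshaping first.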

export

Converts the retrained model based on export_format.

Args
export_dir The directory to save exported files.
tflite_filename File name to save tflite model. The full export path is {export_dir}/{tflite_filename}.
label_filename File name to save labels. The full export path is {export_dir}/{label_filename}.
vocab_filename File name to save vocabulary. The full export path is {export_dir}/{vocab_filename}.
saved_model_filename Path to SavedModel or H5 file to save the model. The full export path is {export_dir}/{saved_model_filename}/{saved_model.pb|assets|variables}.
tfjs_folder_name Folder name to save tfjs model. The full export path is {export_dir}/{tfjs_folder_name}.
export_format List of export formats; each can be saved_model, tflite, label, vocab, or tfjs (see ALLOWED_EXPORT_FORMAT).
**kwargs Other parameters like quantized_config for TFLITE model.
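The path rules listed above amount to joining export_dir with a per-format file or folder name. The sketch below mirrors that mapping; export_paths is a hypothetical helper, and the file names passed in the usage example are illustrative choices, not the library's documented defaults.

```python
import os


def export_paths(export_dir, export_format, tflite_filename, label_filename,
                 vocab_filename, saved_model_filename, tfjs_folder_name):
    """Map each requested export format to its output path (sketch)."""
    names = {
        "TFLITE": tflite_filename,
        "LABEL": label_filename,
        "VOCAB": vocab_filename,
        "SAVED_MODEL": saved_model_filename,  # a directory holding saved_model.pb, assets, variables
        "TFJS": tfjs_folder_name,
    }
    return {fmt: os.path.join(export_dir, names[fmt]) for fmt in export_format}


paths = export_paths(
    "out", ["TFLITE", "LABEL", "VOCAB"],
    tflite_filename="model.tflite", label_filename="labels.txt",
    vocab_filename="vocab.txt", saved_model_filename="saved_model",
    tfjs_folder_name="tfjs")
for fmt, path in paths.items():
    print(fmt, path)
```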

predict_top_k

Predicts the top-k predictions.

Args
data Data to be evaluated. Either an instance of DataLoader or raw data entries such as a TF tensor or a NumPy array.
k Number of top results to be predicted.
batch_size Number of samples per evaluation step.

Returns
Top-k results. Each one is a (label, probability) pair.
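The shape of the result can be shown with a minimal sketch that turns per-sample class probabilities into the k most probable (label, probability) pairs, using an index_to_label list like the one the class holds. The helper top_k_labels and the sample labels are hypothetical; the real method also runs the model to obtain the probabilities.

```python
import heapq


def top_k_labels(probabilities, index_to_label, k):
    """Return the k most probable (label, probability) pairs per sample."""
    results = []
    for probs in probabilities:
        # heapq.nlargest returns the k indices with the highest probability,
        # ordered from most to least probable.
        top = heapq.nlargest(k, range(len(probs)), key=probs.__getitem__)
        results.append([(index_to_label[i], probs[i]) for i in top])
    return results


index_to_label = ["negative", "neutral", "positive"]
probs = [[0.1, 0.3, 0.6], [0.7, 0.2, 0.1]]
print(top_k_labels(probs, index_to_label, k=2))
# [[('positive', 0.6), ('neutral', 0.3)], [('negative', 0.7), ('neutral', 0.2)]]
```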

summary

train

Feeds the training data for training.

ALLOWED_EXPORT_FORMAT (<ExportFormat.TFLITE: 'TFLITE'>, <ExportFormat.LABEL: 'LABEL'>, <ExportFormat.VOCAB: 'VOCAB'>, <ExportFormat.SAVED_MODEL: 'SAVED_MODEL'>, <ExportFormat.TFJS: 'TFJS'>)
DEFAULT_EXPORT_FORMAT (<ExportFormat.TFLITE: 'TFLITE'>, <ExportFormat.LABEL: 'LABEL'>, <ExportFormat.VOCAB: 'VOCAB'>)
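The two constants above suggest a simple rule: with no export_format given, the default formats are used, and any requested format must be one of the allowed ones. The sketch below defines a local ExportFormat enum that mirrors the values shown (it is not an import of the library's own class), and resolve_export_format is a hypothetical helper illustrating that check.

```python
import enum


class ExportFormat(enum.Enum):
    """Local mirror of the export format values listed above."""
    TFLITE = "TFLITE"
    LABEL = "LABEL"
    VOCAB = "VOCAB"
    SAVED_MODEL = "SAVED_MODEL"
    TFJS = "TFJS"


ALLOWED_EXPORT_FORMAT = (ExportFormat.TFLITE, ExportFormat.LABEL,
                         ExportFormat.VOCAB, ExportFormat.SAVED_MODEL,
                         ExportFormat.TFJS)
DEFAULT_EXPORT_FORMAT = (ExportFormat.TFLITE, ExportFormat.LABEL,
                         ExportFormat.VOCAB)


def resolve_export_format(export_format=None):
    """Fall back to the defaults and reject formats that are not allowed."""
    if export_format is None:
        return DEFAULT_EXPORT_FORMAT
    if not isinstance(export_format, (list, tuple)):
        # Accept a single format as well as a list of formats.
        export_format = (export_format,)
    for fmt in export_format:
        if fmt not in ALLOWED_EXPORT_FORMAT:
            raise ValueError(f"Export format {fmt} is not allowed.")
    return tuple(export_format)


print(resolve_export_format())
print(resolve_export_format(ExportFormat.SAVED_MODEL))
```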