TextClassifier class for inference and for exporting to TFLite.
tflite_model_maker.text_classifier.TextClassifier(
model_spec, index_to_label, shuffle=True
)
Args |
model_spec | Specification for the model.
index_to_label | A list that maps each index to its label class name.
shuffle | Whether the data should be shuffled.
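index_to_label is a list in which position i holds the class name for model output index i. The sketch below shows how a predicted index maps back to a name. It assumes, for illustration only, that the list is built from the sorted distinct class names; the loader that produces the real list may order labels differently. The helpers build_index_to_label and label_for_index are hypothetical.

```python
def build_index_to_label(raw_labels):
    """Hypothetical helper: sorted distinct class names, one per output index."""
    return sorted(set(raw_labels))


def label_for_index(index_to_label, index):
    # Position i in the list is the class name for model output index i.
    return index_to_label[index]


index_to_label = build_index_to_label(["pos", "neg", "pos", "neutral"])
print(index_to_label)                      # ['neg', 'neutral', 'pos']
print(label_for_index(index_to_label, 2))  # pos
```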
Methods
create
@classmethod
create(
train_data,
model_spec='average_word_vec',
validation_data=None,
batch_size=None,
epochs=3,
steps_per_epoch=None,
shuffle=False,
do_train=True
)
Loads data and trains the model for text classification.
Args |
train_data | Training data.
model_spec | Specification for the model.
validation_data | Validation data. If None, validation is skipped.
batch_size | Batch size for training.
epochs | Number of epochs for training.
steps_per_epoch | Integer or None. Total number of steps (batches of samples) before declaring one epoch finished and starting the next epoch. If steps_per_epoch is None, the epoch runs until the input dataset is exhausted.
shuffle | Whether the data should be shuffled.
do_train | Whether to run training.
Returns | An instance based on TextClassifier.
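When steps_per_epoch is None, an epoch runs until the input dataset is exhausted, so for a finite, batched dataset the epoch length equals the number of batches. The minimal sketch below computes that count, assuming the final partial batch is kept (as with tf.data's batch when drop_remainder=False). The function steps_in_epoch is hypothetical and only illustrates the semantics described above.

```python
import math


def steps_in_epoch(num_examples, batch_size, steps_per_epoch=None):
    """Hypothetical helper illustrating the steps_per_epoch semantics."""
    if steps_per_epoch is not None:
        # An explicit value ends the epoch after that many batches.
        return steps_per_epoch
    # None: run until the data is exhausted; a final partial batch still counts.
    return math.ceil(num_examples / batch_size)


print(steps_in_epoch(1000, 32))      # 32
print(steps_in_epoch(1000, 32, 10))  # 10
```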
create_model
create_model(
with_loss_and_metrics=True
)
create_serving_model
create_serving_model()
Returns the underlying Keras model for serving.
evaluate
evaluate(
data, batch_size=32
)
Evaluates the model.
Args |
data | Data to be evaluated.
batch_size | Number of samples per evaluation step.
Returns | The loss value and accuracy.
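evaluate returns a loss value and an accuracy. The sketch below shows how those two numbers relate to per-example class probabilities. It assumes sparse integer labels and a sparse categorical cross-entropy loss; the classifier's actual loss is determined by its model_spec, so this is an illustration rather than the library's implementation. loss_and_accuracy is a hypothetical name.

```python
import math


def loss_and_accuracy(probabilities, labels):
    """Mean sparse categorical cross-entropy and accuracy (illustrative sketch)."""
    total_loss = 0.0
    correct = 0
    for probs, label in zip(probabilities, labels):
        # Cross-entropy for one example is -log(probability of the true class).
        total_loss += -math.log(probs[label])
        # The predicted class is the index with the highest probability.
        predicted = max(range(len(probs)), key=lambda i: probs[i])
        correct += int(predicted == label)
    n = len(labels)
    return total_loss / n, correct / n


loss, acc = loss_and_accuracy([[0.9, 0.1], [0.2, 0.8], [0.6, 0.4]], [0, 1, 1])
print(round(loss, 4), round(acc, 4))  # 0.4149 0.6667
```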
evaluate_tflite
evaluate_tflite(
tflite_filepath, data, postprocess_fn=None
)
Evaluates the tflite model.
Args |
tflite_filepath | File path to the TFLite model.
data | Data to be evaluated.
postprocess_fn | Postprocessing function applied to the output of lite_runner.run before the probabilities are calculated.
Returns | The evaluation result of the TFLite model, namely its accuracy.
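postprocess_fn sits between the raw interpreter output and the step that picks a class. The sketch below illustrates that ordering with a per-example accuracy loop. run_fn is a hypothetical stand-in for lite_runner.run, and the example postprocess_fn (selecting the first of several output tensors) is an assumed use case, not something this page prescribes.

```python
from typing import Any, Callable, Iterable, Optional, Tuple


def tflite_accuracy(
    run_fn: Callable[[Any], Any],
    data: Iterable[Tuple[Any, int]],
    postprocess_fn: Optional[Callable[[Any], Any]] = None,
) -> float:
    """Accuracy of a model run one example at a time (illustrative sketch)."""
    correct = 0
    total = 0
    for feature, label in data:
        output = run_fn(feature)
        if postprocess_fn is not None:
            # Applied to the raw runner output before choosing the top class.
            output = postprocess_fn(output)
        predicted = max(range(len(output)), key=lambda i: output[i])
        correct += int(predicted == label)
        total += 1
    return correct / total


def fake_run(feature):
    # Hypothetical runner returning two output tensors; the first holds class scores.
    scores = {"good movie": [0.1, 0.9], "bad movie": [0.7, 0.3]}[feature]
    return [scores, [0.0]]


data = [("good movie", 1), ("bad movie", 0), ("bad movie", 1)]
acc = tflite_accuracy(fake_run, data, postprocess_fn=lambda outputs: outputs[0])
print(acc)
```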
export
export(
export_dir,
tflite_filename='model.tflite',
label_filename='labels.txt',
vocab_filename='vocab.txt',
saved_model_filename='saved_model',
tfjs_folder_name='tfjs',
export_format=None,
**kwargs
)
Converts the retrained model based on export_format.
Args |
export_dir | The directory to save exported files to.
tflite_filename | File name for the TFLite model. The full export path is {export_dir}/{tflite_filename}.
label_filename | File name for the labels. The full export path is {export_dir}/{label_filename}.
vocab_filename | File name for the vocabulary. The full export path is {export_dir}/{vocab_filename}.
saved_model_filename | Path to the SavedModel or H5 file to save the model to. The full export path is {export_dir}/{saved_model_filename}/{saved_model.pb|assets|variables}.
tfjs_folder_name | Folder name for the TensorFlow.js model. The full export path is {export_dir}/{tfjs_folder_name}.
export_format | List of export formats, such as saved_model, tflite, label, and vocab.
**kwargs | Other parameters, such as quantized_config, for the TFLite model.
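Each file-name argument above is joined onto export_dir. The minimal sketch below computes the full paths implied by the default arguments, using os.path.join so the separator matches the host OS. export_paths is a hypothetical helper: it only builds path strings and does not write any files or call the library.

```python
import os


def export_paths(
    export_dir,
    tflite_filename="model.tflite",
    label_filename="labels.txt",
    vocab_filename="vocab.txt",
    saved_model_filename="saved_model",
    tfjs_folder_name="tfjs",
):
    """Hypothetical helper: full paths implied by the export() defaults."""
    return {
        "tflite": os.path.join(export_dir, tflite_filename),
        "label": os.path.join(export_dir, label_filename),
        "vocab": os.path.join(export_dir, vocab_filename),
        # The SavedModel directory holds saved_model.pb, assets/ and variables/.
        "saved_model": os.path.join(export_dir, saved_model_filename),
        "tfjs": os.path.join(export_dir, tfjs_folder_name),
    }


for fmt, path in export_paths("out").items():
    print(fmt, path)
```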
predict_top_k
predict_top_k(
data, k=1, batch_size=32
)
Predicts the top-k predictions.
Args |
data | Data to be evaluated. Either an instance of DataLoader or raw data entries such as a TF tensor or NumPy array.
k | Number of top results to be predicted.
batch_size | Number of samples per evaluation step.
Returns | Top-k results. Each result is a (label, probability) pair.
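Each top-k result pairs a label with its probability. The sketch below derives those pairs from one example's probability vector and an index_to_label list, ranking indices with heapq.nlargest. top_k is a hypothetical helper for a single example; it is not the method's implementation.

```python
import heapq


def top_k(probabilities, index_to_label, k=1):
    """Return the k most probable (label, probability) pairs, highest first."""
    # heapq.nlargest ranks the class indices by their probability.
    best = heapq.nlargest(k, range(len(probabilities)), key=lambda i: probabilities[i])
    return [(index_to_label[i], probabilities[i]) for i in best]


print(top_k([0.1, 0.7, 0.2], ["neg", "pos", "neutral"], k=2))
# [('pos', 0.7), ('neutral', 0.2)]
```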
summary
summary()
train
train(
train_data,
validation_data=None,
epochs=None,
batch_size=None,
steps_per_epoch=None
)
Feeds the training data to the model for training.
Class Variables |
ALLOWED_EXPORT_FORMAT | (<ExportFormat.TFLITE: 'TFLITE'>, <ExportFormat.LABEL: 'LABEL'>, <ExportFormat.VOCAB: 'VOCAB'>, <ExportFormat.SAVED_MODEL: 'SAVED_MODEL'>, <ExportFormat.TFJS: 'TFJS'>)
DEFAULT_EXPORT_FORMAT | (<ExportFormat.TFLITE: 'TFLITE'>, <ExportFormat.LABEL: 'LABEL'>, <ExportFormat.VOCAB: 'VOCAB'>)
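ALLOWED_EXPORT_FORMAT and DEFAULT_EXPORT_FORMAT suggest how an export_format argument is resolved. The sketch below assumes that export() falls back to DEFAULT_EXPORT_FORMAT when export_format is None and rejects anything outside ALLOWED_EXPORT_FORMAT. That behavior is an assumption, not confirmed by this page. The ExportFormat enum here is a local stand-in mirroring the listed values, not an import from the library, and resolve_export_format is hypothetical.

```python
import enum


class ExportFormat(enum.Enum):
    """Local stand-in mirroring the listed values (not the library's class)."""
    TFLITE = "TFLITE"
    LABEL = "LABEL"
    VOCAB = "VOCAB"
    SAVED_MODEL = "SAVED_MODEL"
    TFJS = "TFJS"


ALLOWED_EXPORT_FORMAT = (
    ExportFormat.TFLITE, ExportFormat.LABEL, ExportFormat.VOCAB,
    ExportFormat.SAVED_MODEL, ExportFormat.TFJS,
)
DEFAULT_EXPORT_FORMAT = (ExportFormat.TFLITE, ExportFormat.LABEL, ExportFormat.VOCAB)


def resolve_export_format(export_format=None):
    """Hypothetical helper: use the default when None, otherwise validate."""
    if export_format is None:
        return DEFAULT_EXPORT_FORMAT
    if isinstance(export_format, ExportFormat):
        # Accept a single format as well as a list of formats.
        export_format = (export_format,)
    for fmt in export_format:
        if fmt not in ALLOWED_EXPORT_FORMAT:
            raise ValueError(f"Unsupported export format: {fmt}")
    return tuple(export_format)


print(resolve_export_format())
print(resolve_export_format([ExportFormat.SAVED_MODEL]))
```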