tflite_model_maker.question_answer.QuestionAnswer

QuestionAnswer class for inference and exporting to TFLite.

Args
model_spec Specification for the model.
shuffle Whether the training data should be shuffled.

Methods

create

Loads data and trains the model for question answering.

Args
train_data Training data.
model_spec Specification for the model.
batch_size Batch size for training.
epochs Number of epochs for training.
steps_per_epoch Integer or None. Total number of steps (batches of samples) before declaring one epoch finished and starting the next epoch. If steps_per_epoch is None, the epoch will run until the input dataset is exhausted.
shuffle Whether the data should be shuffled.
do_train Whether to run training.

Returns
An instance of QuestionAnswer.
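As a rough illustration of how steps_per_epoch relates to batch_size, the sketch below computes how many steps make up one epoch. batches_per_epoch is a hypothetical helper, not part of tflite_model_maker. It assumes that when steps_per_epoch is None, the final partial batch is kept rather than dropped; the actual input pipeline may behave differently.

```python
import math
from typing import Optional


def batches_per_epoch(num_examples: int, batch_size: int,
                      steps_per_epoch: Optional[int] = None) -> int:
    """Number of training steps in one epoch (hypothetical helper).

    If steps_per_epoch is given, it is used as-is. If it is None, the epoch
    runs until the data is exhausted: one step per batch, counting a final
    partial batch (an assumption about the input pipeline).
    """
    if steps_per_epoch is not None:
        return steps_per_epoch
    return math.ceil(num_examples / batch_size)


# 1000 examples with batch_size=32: 31 full batches plus one partial batch.
print(batches_per_epoch(1000, 32))      # 32
print(batches_per_epoch(1000, 32, 10))  # 10
```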

create_model

create_serving_model

Returns the underlying Keras model for serving.

evaluate

Evaluates the model.

Args
data Data to be evaluated.
max_answer_length The maximum length of an answer that can be generated. This is needed because the start and end predictions are not conditioned on one another.
null_score_diff_threshold If null_score - best_non_null is greater than this threshold, predict null (no answer). Only used for SQuAD v2.
verbose_logging If True, all warnings related to data processing are printed. Some of these warnings are expected during a normal SQuAD evaluation.
output_dir Directory in which to save the output JSON files predictions.json, nbest_predictions.json, and null_odds.json. If None, the JSON files are not saved.

Returns
A dict containing two metrics: exact match rate and F1 score.
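The max_answer_length and null_score_diff_threshold arguments above both act during answer post-processing. The sketch below shows one way they could interact, modeled on the original BERT SQuAD post-processing rather than on tflite_model_maker's own code. best_span and allow_null are hypothetical names, index 0 is assumed to be the [CLS] token whose start + end logit serves as the null score, and the brute-force search over all spans is a simplification of the usual n-best search.

```python
from typing import Optional, Sequence, Tuple


def best_span(start_logits: Sequence[float], end_logits: Sequence[float],
              max_answer_length: int,
              null_score_diff_threshold: float = 0.0,
              allow_null: bool = True) -> Optional[Tuple[int, int]]:
    """Returns (start, end) token indices of the best answer, or None for no answer.

    Hypothetical helper: index 0 is assumed to be the [CLS] token, whose
    start + end logit is used as the null (no-answer) score.
    """
    n = len(start_logits)
    best, best_score = None, float("-inf")
    for i in range(1, n):
        # Start and end logits are predicted independently, so invalid spans
        # (end before start, or longer than max_answer_length) are excluded here.
        for j in range(i, min(i + max_answer_length, n)):
            score = start_logits[i] + end_logits[j]
            if score > best_score:
                best, best_score = (i, j), score
    if allow_null:  # SQuAD v2-style data only
        null_score = start_logits[0] + end_logits[0]
        # Predict null if it beats the best non-null span by more than the threshold.
        if best is None or null_score - best_score > null_score_diff_threshold:
            return None
    return best


start = [0.0, 1.0, 5.0, 0.5]
end = [0.0, 0.2, 0.1, 4.0]
print(best_span(start, end, max_answer_length=2))  # (2, 3)
print(best_span(start, end, max_answer_length=1))  # (2, 2)
```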

evaluate_tflite

Evaluates the TFLite model.

Args
tflite_filepath File path to the TFLite model.
data Data to be evaluated.
max_answer_length The maximum length of an answer that can be generated. This is needed because the start and end predictions are not conditioned on one another.
null_score_diff_threshold If null_score - best_non_null is greater than this threshold, predict null (no answer). Only used for SQuAD v2.
verbose_logging If True, all warnings related to data processing are printed. Some of these warnings are expected during a normal SQuAD evaluation.
output_dir Directory in which to save the output JSON files predictions.json, nbest_predictions.json, and null_odds.json. If None, the JSON files are not saved.

Returns
A dict containing two metrics: exact match rate and F1 score.
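The two returned metrics are the standard SQuAD ones. The sketch below computes them for a single prediction, following the normalization in the official SQuAD evaluation scripts (lower-casing, removing punctuation and the articles a/an/the, collapsing whitespace). It is an illustration, not tflite_model_maker's own implementation; the official scripts additionally take the maximum over all reference answers for a question and average over the dataset.

```python
import collections
import re
import string


def normalize_answer(text: str) -> str:
    """Lower-cases, drops punctuation and the articles a/an/the, collapses whitespace."""
    text = text.lower()
    text = "".join(ch for ch in text if ch not in string.punctuation)
    text = re.sub(r"\b(a|an|the)\b", " ", text)
    return " ".join(text.split())


def exact_match(prediction: str, ground_truth: str) -> float:
    """1.0 if the normalized strings are identical, else 0.0."""
    return float(normalize_answer(prediction) == normalize_answer(ground_truth))


def f1_score(prediction: str, ground_truth: str) -> float:
    """Token-level F1 between the normalized prediction and reference."""
    pred_tokens = normalize_answer(prediction).split()
    gold_tokens = normalize_answer(ground_truth).split()
    if not pred_tokens or not gold_tokens:
        # "No answer" case (SQuAD v2): full credit only if both are empty.
        return float(pred_tokens == gold_tokens)
    common = collections.Counter(pred_tokens) & collections.Counter(gold_tokens)
    num_same = sum(common.values())
    if num_same == 0:
        return 0.0
    precision = num_same / len(pred_tokens)
    recall = num_same / len(gold_tokens)
    return 2 * precision * recall / (precision + recall)


print(exact_match("The Eiffel Tower", "eiffel tower"))         # 1.0
print(f1_score("Eiffel Tower in Paris", "the Eiffel Tower"))   # about 0.667
```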

export

Converts the retrained model to the formats listed in export_format.

Args
export_dir The directory to save exported files.
tflite_filename File name to save tflite model. The full export path is {export_dir}/{tflite_filename}.
label_filename File name to save labels. The full export path is {export_dir}/{label_filename}.
vocab_filename File name to save vocabulary. The full export path is {export_dir}/{vocab_filename}.
saved_model_filename Path to SavedModel or H5 file to save the model. The full export path is {export_dir}/{saved_model_filename}/{saved_model.pb|assets|variables}.
tfjs_folder_name Folder name to save tfjs model. The full export path is {export_dir}/{tfjs_folder_name}.
export_format List of export formats; each can be saved_model, tflite, label, or vocab.
**kwargs Other parameters, such as quantized_config for the TFLite model.
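The path rules listed for export can be summarized in a short sketch. export_paths is a hypothetical helper that only assembles the paths described above; it writes nothing, and the example file names are illustrative rather than confirmed library defaults.

```python
import os


def export_paths(export_dir: str, tflite_filename: str, label_filename: str,
                 vocab_filename: str, saved_model_filename: str,
                 tfjs_folder_name: str) -> dict:
    """Builds the full output paths described for export() (hypothetical helper)."""
    saved_model_dir = os.path.join(export_dir, saved_model_filename)
    return {
        "tflite": os.path.join(export_dir, tflite_filename),
        "label": os.path.join(export_dir, label_filename),
        "vocab": os.path.join(export_dir, vocab_filename),
        # A SavedModel is a directory holding saved_model.pb, assets/ and variables/.
        "saved_model": [os.path.join(saved_model_dir, part)
                        for part in ("saved_model.pb", "assets", "variables")],
        "tfjs": os.path.join(export_dir, tfjs_folder_name),
    }


# Illustrative file names, not necessarily the library defaults.
paths = export_paths("exported", "model.tflite", "labels.txt", "vocab.txt",
                     "saved_model", "tfjs")
print(paths["tflite"])  # exported/model.tflite on POSIX systems
```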

summary

train

Feeds the training data for training.

ALLOWED_EXPORT_FORMAT (<ExportFormat.TFLITE: 'TFLITE'>, <ExportFormat.VOCAB: 'VOCAB'>, <ExportFormat.SAVED_MODEL: 'SAVED_MODEL'>)
DEFAULT_EXPORT_FORMAT (<ExportFormat.TFLITE: 'TFLITE'>, <ExportFormat.VOCAB: 'VOCAB'>)
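The two class constants suggest how export_format is resolved: fall back to DEFAULT_EXPORT_FORMAT when nothing is requested, and accept only members of ALLOWED_EXPORT_FORMAT. The sketch below illustrates that idea with a local stand-in enum that mirrors the members shown above plus LABEL. resolve_export_formats is hypothetical, and raising ValueError for an unsupported format is an assumption; the library may handle it differently.

```python
import enum
from typing import Iterable, Optional, Tuple, Union


class ExportFormat(enum.Enum):
    # Local stand-in for the library enum (assumption, for illustration only).
    TFLITE = "TFLITE"
    SAVED_MODEL = "SAVED_MODEL"
    VOCAB = "VOCAB"
    LABEL = "LABEL"


ALLOWED_EXPORT_FORMAT = (ExportFormat.TFLITE, ExportFormat.VOCAB, ExportFormat.SAVED_MODEL)
DEFAULT_EXPORT_FORMAT = (ExportFormat.TFLITE, ExportFormat.VOCAB)


def resolve_export_formats(
        export_format: Optional[Union[ExportFormat, Iterable[ExportFormat]]] = None
) -> Tuple[ExportFormat, ...]:
    """Returns the formats to export, validated against ALLOWED_EXPORT_FORMAT."""
    if export_format is None:
        return DEFAULT_EXPORT_FORMAT
    if isinstance(export_format, ExportFormat):
        export_format = (export_format,)  # accept a single format
    formats = tuple(export_format)
    unsupported = [f for f in formats if f not in ALLOWED_EXPORT_FORMAT]
    if unsupported:
        raise ValueError(f"Unsupported export format(s): {unsupported}")
    return formats


print(resolve_export_formats())  # the TFLITE and VOCAB defaults
```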