A specification of a BERT model for question answering.
```python
tflite_model_maker.question_answer.BertQaSpec(
    uri='https://hub.tensorflow.google.cn/tensorflow/bert_en_uncased_L-12_H-768_A-12/1',
    model_dir=None,
    seq_len=384,
    query_len=64,
    doc_stride=128,
    dropout_rate=0.1,
    initializer_range=0.02,
    learning_rate=8e-05,
    distribution_strategy='mirrored',
    num_gpus=-1,
    tpu='',
    trainable=True,
    predict_batch_size=8,
    do_lower_case=True,
    is_tf2=True,
    tflite_input_name=None,
    tflite_output_name=None,
    init_from_squad_model=False,
    default_batch_size=16,
    name='Bert'
)
```
Methods
build
```python
build()
```
Builds the class. Used for lazy initialization.
convert_examples_to_features
```python
convert_examples_to_features(
    examples, is_training, output_fn, batch_size
)
```
Converts examples to features and writes them into a TFRecord file.
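The constructor's `seq_len`, `query_len`, and `doc_stride` parameters suggest that long contexts are split into overlapping windows when examples become features. A minimal sketch of that windowing, assuming the classic BERT SQuAD layout `[CLS] query [SEP] context [SEP]` (three special tokens per feature). `doc_spans` is a hypothetical helper, not part of `tflite_model_maker`; it works on token counts only and does no tokenization or TFRecord writing.

```python
from typing import List, Tuple


def doc_spans(num_doc_tokens: int, num_query_tokens: int,
              seq_len: int = 384, doc_stride: int = 128) -> List[Tuple[int, int]]:
    """Returns (start, length) windows over the context tokens.

    Hypothetical helper: assumes each feature is laid out as
    [CLS] query [SEP] context [SEP], so 3 positions hold special tokens.
    """
    max_tokens_for_doc = seq_len - num_query_tokens - 3
    if max_tokens_for_doc <= 0:
        raise ValueError("query is too long for seq_len")
    spans: List[Tuple[int, int]] = []
    start = 0
    while start < num_doc_tokens:
        # Take as many context tokens as fit in the remaining budget.
        length = min(num_doc_tokens - start, max_tokens_for_doc)
        spans.append((start, length))
        if start + length == num_doc_tokens:
            break  # this window reaches the end of the context
        # Advance by at most doc_stride so consecutive windows overlap.
        start += min(length, doc_stride)
    return spans


print(doc_spans(500, 64))  # [(0, 317), (128, 317), (256, 244)]
```

With the default `seq_len=384` and a 64-token query, each window holds up to 317 context tokens and consecutive windows start 128 tokens apart, so a 500-token context yields three overlapping features.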
create_model
```python
create_model()
```
Creates the model for the QA task.
evaluate
```python
evaluate(
    model,
    tflite_filepath,
    dataset,
    num_steps,
    eval_examples,
    eval_features,
    predict_file,
    version_2_with_negative,
    max_answer_length,
    null_score_diff_threshold,
    verbose_logging,
    output_dir
)
```
Evaluates the QA model.
Args | Description |
---|---|
`model` | The Keras model to be evaluated. |
`tflite_filepath` | File path to the TFLite model. |
`dataset` | `tf.data.Dataset` used for evaluation. |
`num_steps` | Number of steps to evaluate the model. |
`eval_examples` | List of `squad_lib.SquadExample` for the evaluation data. |
`eval_features` | List of `squad_lib.InputFeatures` for the evaluation data. |
`predict_file` | The input predict file. |
`version_2_with_negative` | Whether the input predict file is in SQuAD 2.0 format. |
`max_answer_length` | The maximum length of an answer that can be generated. This is needed because the start and end predictions are not conditioned on one another. |
`null_score_diff_threshold` | If `null_score - best_non_null` is greater than the threshold, predict null. This is only used for SQuAD v2. |
`verbose_logging` | If `True`, all warnings related to data processing are printed. A number of warnings are expected for a normal SQuAD evaluation. |
`output_dir` | The output directory in which to save the outputs as JSON files: `predictions.json`, `nbest_predictions.json`, `null_odds.json`. If `None`, saving to JSON files is skipped. |
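To make `max_answer_length` and `null_score_diff_threshold` concrete, here is a simplified sketch of span selection in the style of the standard BERT SQuAD post-processing. Assumptions: index 0 holds the `[CLS]` (null) logits, every other index is a context token, and spans are scored exhaustively instead of through an n-best list. `best_span` is a hypothetical helper, not this class's implementation.

```python
from typing import List, Optional, Tuple


def best_span(start_logits: List[float], end_logits: List[float],
              max_answer_length: int = 30,
              null_score_diff_threshold: float = 0.0,
              version_2_with_negative: bool = True) -> Optional[Tuple[int, int]]:
    """Picks the highest-scoring (start, end) span, or None for "no answer".

    Simplified sketch: index 0 is assumed to be [CLS] and every other index a
    context token; spans are scored exhaustively rather than via an n-best list.
    """
    n = len(start_logits)
    best: Optional[Tuple[int, int]] = None
    best_score = float("-inf")
    for s in range(1, n):
        # The end must not precede the start, and the span may hold at most
        # max_answer_length tokens, since start and end are predicted independently.
        for e in range(s, min(n, s + max_answer_length)):
            score = start_logits[s] + end_logits[e]
            if score > best_score:
                best, best_score = (s, e), score
    if version_2_with_negative:
        # SQuAD 2.0: predict null if its score exceeds the best span's score
        # by more than the threshold.
        null_score = start_logits[0] + end_logits[0]
        if best is None or null_score - best_score > null_score_diff_threshold:
            return None
    return best
```

Raising `null_score_diff_threshold` makes the sketch answer more often; lowering it makes "no answer" more likely.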
Returns | |
---|---|
A dict containing two metrics: the exact match rate and the F1 score. |
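These two metrics are conventionally computed as in the official SQuAD evaluation script: answers are normalized (lowercased, punctuation and the articles a/an/the removed, whitespace collapsed), exact match compares normalized strings, and F1 measures token overlap, taking the best score across gold answers. A minimal re-implementation under that assumption; the function names and the `exact_match`/`f1` keys are hypothetical and may not match the dict this method returns.

```python
import re
import string
from collections import Counter
from typing import Dict, List

_PUNCT = set(string.punctuation)


def normalize_answer(text: str) -> str:
    """Lowercases, strips punctuation and articles, collapses whitespace."""
    text = text.lower()
    text = "".join(ch for ch in text if ch not in _PUNCT)
    text = re.sub(r"\b(a|an|the)\b", " ", text)
    return " ".join(text.split())


def f1_score(prediction: str, ground_truth: str) -> float:
    pred_tokens = normalize_answer(prediction).split()
    gold_tokens = normalize_answer(ground_truth).split()
    if not pred_tokens or not gold_tokens:
        # SQuAD 2.0 convention: an empty answer only matches an empty answer.
        return float(pred_tokens == gold_tokens)
    common = Counter(pred_tokens) & Counter(gold_tokens)
    num_same = sum(common.values())
    if num_same == 0:
        return 0.0
    precision = num_same / len(pred_tokens)
    recall = num_same / len(gold_tokens)
    return 2 * precision * recall / (precision + recall)


def squad_metrics(predictions: Dict[str, str],
                  references: Dict[str, List[str]]) -> Dict[str, float]:
    """Averages EM and F1 as percentages, using the best gold answer per question."""
    em_total = f1_total = 0.0
    for qid, golds in references.items():
        pred = predictions.get(qid, "")
        em_total += max(float(normalize_answer(pred) == normalize_answer(g)) for g in golds)
        f1_total += max(f1_score(pred, g) for g in golds)
    n = len(references)
    return {"exact_match": 100.0 * em_total / n, "f1": 100.0 * f1_total / n}
```

For example, the prediction "eiffel tower in paris" against the gold answer "The Eiffel Tower" scores 0 for exact match but 2/3 for F1 (precision 0.5, recall 1.0).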
get_config
```python
get_config()
```
Gets the configuration.
get_default_quantization_config
```python
get_default_quantization_config()
```
Gets the default quantization configuration.
get_name_to_features
```python
get_name_to_features(
    is_training
)
```
Gets the dictionary describing the features.
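A sketch of what such a dictionary typically contains for BERT SQuAD features, assuming the TFRecord layout used by the TensorFlow Model Garden SQuAD pipeline: `input_ids`, `input_mask`, and `segment_ids`, plus start/end positions for training or `unique_ids` for prediction. Real code would use `tf.io.FixedLenFeature`; here a plain `(shape, dtype)` tuple stands in so the sketch runs without TensorFlow, and all feature names are assumptions rather than confirmed behavior of this method.

```python
from typing import Dict, Tuple

# Hypothetical stand-in for tf.io.FixedLenFeature: (shape, dtype name).
FeatureSpec = Tuple[Tuple[int, ...], str]


def name_to_features_sketch(seq_len: int, is_training: bool) -> Dict[str, FeatureSpec]:
    # Per-token features: one int64 per position in the sequence (assumed names).
    features: Dict[str, FeatureSpec] = {
        "input_ids": ((seq_len,), "int64"),
        "input_mask": ((seq_len,), "int64"),
        "segment_ids": ((seq_len,), "int64"),
    }
    if is_training:
        # Training records carry scalar answer-span labels.
        features["start_positions"] = ((), "int64")
        features["end_positions"] = ((), "int64")
    else:
        # Prediction records carry an id to map outputs back to features.
        features["unique_ids"] = ((), "int64")
    return features
```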
predict
```python
predict(
    model, dataset, num_steps
)
```
Runs prediction on `dataset` with the Keras `model`.
predict_tflite
```python
predict_tflite(
    tflite_filepath, dataset
)
```
Runs prediction on `dataset` with the TFLite model stored at `tflite_filepath`.
reorder_input_details
```python
reorder_input_details(
    tflite_input_details
)
```
Reorders the TFLite input details to match the input order of the Keras model.
reorder_output_details
```python
reorder_output_details(
    tflite_output_details
)
```
Reorders the TFLite output details to match the output order of the Keras model.
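The TFLite interpreter's `get_input_details()` and `get_output_details()` return lists of dicts (with keys such as `'name'` and `'index'`) whose order need not match the Keras model's inputs and outputs. One way to restore a fixed order is to match on tensor names. The sketch below assumes each expected name appears as a substring of exactly one TFLite tensor name; `reorder_details` and the example tensor names are hypothetical, not this class's implementation.

```python
from typing import Dict, List, Sequence


def reorder_details(details: Sequence[Dict], expected_names: Sequence[str]) -> List[Dict]:
    """Returns `details` sorted into the order given by `expected_names`.

    Hypothetical matching rule: an expected name must be a substring of
    exactly one tensor name, e.g. "input_mask" in "serving_default_input_mask:0".
    """
    ordered = []
    for name in expected_names:
        matches = [d for d in details if name in d["name"]]
        if len(matches) != 1:
            raise ValueError(f"expected one tensor matching {name!r}, found {len(matches)}")
        ordered.append(matches[0])
    return ordered


# Example with made-up TFLite details, listed out of order.
details = [
    {"name": "serving_default_input_type_ids:0", "index": 0},
    {"name": "serving_default_input_word_ids:0", "index": 1},
    {"name": "serving_default_input_mask:0", "index": 2},
]
keras_order = ["input_word_ids", "input_mask", "input_type_ids"]
print([d["index"] for d in reorder_details(details, keras_order)])  # [1, 2, 0]
```

Failing loudly on zero or multiple matches is a deliberate choice: silently feeding tensors in the wrong order would produce wrong predictions without any error.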
save_vocab
```python
save_vocab(
    vocab_filename
)
```
Prints the file path to the vocabulary.
select_data_from_record
```python
select_data_from_record(
    record
)
```
Dispatches records to features and labels.
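A sketch of such a dispatch, assuming the TensorFlow Model Garden convention: start/end positions go to the label dict, everything else goes to the feature dict, and the stored names `input_ids` and `segment_ids` are renamed to the Keras input names `input_word_ids` and `input_type_ids`. The renaming table is an assumption, not confirmed for this class.

```python
from typing import Any, Dict, Tuple

# Assumed renaming from stored feature names to Keras input names.
_RENAME = {"input_ids": "input_word_ids", "segment_ids": "input_type_ids"}
_LABELS = ("start_positions", "end_positions")


def select_data_from_record_sketch(record: Dict[str, Any]) -> Tuple[Dict[str, Any], Dict[str, Any]]:
    x: Dict[str, Any] = {}  # model inputs
    y: Dict[str, Any] = {}  # training labels
    for name, value in record.items():
        if name in _LABELS:
            y[name] = value
        else:
            x[_RENAME.get(name, name)] = value
    return x, y
```

In a TensorFlow input pipeline, a function like this would typically be applied with `tf.data.Dataset.map` after the records are parsed; for prediction records, which carry no positions, the label dict simply comes back empty.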
train
```python
train(
    train_ds, epochs, steps_per_epoch, **kwargs
)
```
Runs BERT QA training.
Args | Description |
---|---|
`train_ds` | `tf.data.Dataset`, the training data to be fed to `tf.keras.Model.fit()`. |
`epochs` | Integer, the number of training epochs. |
`steps_per_epoch` | Integer or `None`. Total number of steps (batches of samples) before declaring one epoch finished and starting the next epoch. If `steps_per_epoch` is `None`, the epoch runs until the input dataset is exhausted. |
`**kwargs` | Other parameters passed to `tf.keras.Model.fit()`. |
Returns | |
---|---|
`tf.keras.Model`, the trained Keras model. |
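When `steps_per_epoch` is `None`, one epoch is one full pass over `train_ds`, so for a finite, batched, non-repeated dataset the step count follows from the number of features and the batch size. A small arithmetic sketch, assuming the final partial batch is kept (no `drop_remainder`) and using the constructor's `default_batch_size=16` as the default; the helper names are hypothetical. The count is over features, which can outnumber examples when `doc_stride` splits long contexts into several windows.

```python
from typing import Optional


def steps_per_epoch_for(num_features: int, batch_size: int = 16) -> int:
    """Batches in one pass over a finite dataset, keeping the last partial batch."""
    return -(-num_features // batch_size)  # integer ceiling division


def total_train_steps(num_features: int, epochs: int, batch_size: int = 16,
                      steps_per_epoch: Optional[int] = None) -> int:
    if steps_per_epoch is None:
        # Mirrors the documented behavior: None means "run until the
        # dataset is exhausted", i.e. one full pass per epoch.
        steps_per_epoch = steps_per_epoch_for(num_features, batch_size)
    return epochs * steps_per_epoch


print(total_train_steps(1000, epochs=2))  # 126 (63 steps per epoch)
```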
Class Variables | Value |
---|---|
`compat_tf_versions` | `[2]` |
`convert_from_saved_model_tf2` | `True` |
`need_gen_vocab` | `False` |