A specification of the BERT model for text classification.
tflite_model_maker.text_classifier.BertClassifierSpec(
uri='https://hub.tensorflow.google.cn/tensorflow/bert_en_uncased_L-12_H-768_A-12/1',
model_dir=None,
seq_len=128,
dropout_rate=0.1,
initializer_range=0.02,
learning_rate=3e-05,
distribution_strategy='mirrored',
num_gpus=-1,
tpu='',
trainable=True,
do_lower_case=True,
is_tf2=True,
name='Bert',
tflite_input_name=None,
default_batch_size=32,
index_to_label=None
)
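As a minimal sketch of how the constructor above might be called, the snippet below gathers a few of the documented keyword arguments and overrides `seq_len`. The names `spec_kwargs` and `make_spec` and the value `256` are hypothetical, chosen only for illustration. The library import is deferred so the snippet can be read without `tflite_model_maker` installed.

```python
# Hypothetical overrides of the documented constructor arguments.
# Arguments not listed here keep the defaults from the signature above.
spec_kwargs = {
    "seq_len": 256,            # illustrative value; the documented default is 128
    "dropout_rate": 0.1,       # documented default
    "learning_rate": 3e-05,    # documented default
    "do_lower_case": True,     # documented default
    "default_batch_size": 32,  # documented default
}


def make_spec(**overrides):
    """Builds a BertClassifierSpec from spec_kwargs plus any overrides."""
    # Imported lazily: this line needs tflite_model_maker to be installed.
    from tflite_model_maker import text_classifier

    merged = {**spec_kwargs, **overrides}
    return text_classifier.BertClassifierSpec(**merged)
```

Calling `make_spec(seq_len=128)` would restore the default sequence length while keeping the other settings.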
Methods
build
build()
Builds the class. Used for lazy initialization.
convert_examples_to_features
convert_examples_to_features(
examples, tfrecord_file, label_names
)
Converts examples to features and writes them into a TFRecord file.
create_model
create_model(
num_classes, optimizer='adam', with_loss_and_metrics=True
)
Creates the Keras model.
get_config
get_config()
Gets the configuration.
get_default_quantization_config
get_default_quantization_config()
Gets the default quantization configuration.
get_name_to_features
get_name_to_features()
Gets the dictionary describing the features.
reorder_input_details
reorder_input_details(
tflite_input_details
)
Reorders the TFLite input details to match the input order of the Keras model.
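The idea behind this kind of reordering can be illustrated with plain Python. This is a sketch under assumptions, not the library's implementation. It assumes the input details are a list of dicts with a `name` key (the shape returned by a TFLite interpreter's `get_input_details()`), and that the Keras inputs are named `input_word_ids`, `input_mask`, and `input_type_ids`. The helper `reorder_by_name` is hypothetical.

```python
def reorder_by_name(input_details, expected_order):
    """Returns input_details sorted to follow the names in expected_order."""
    position = {name: i for i, name in enumerate(expected_order)}

    def sort_key(detail):
        # Substring match, because converted tensor names may carry extra
        # text around the original input name.
        for name, index in position.items():
            if name in detail["name"]:
                return index
        raise KeyError(f"Unexpected input tensor: {detail['name']}")

    return sorted(input_details, key=sort_key)


# Assumed Keras input order (hypothetical names for illustration).
keras_order = ["input_word_ids", "input_mask", "input_type_ids"]

# Details as a converter might emit them, in a different order.
details = [
    {"name": "serving_default_input_type_ids:0", "index": 0},
    {"name": "serving_default_input_word_ids:0", "index": 1},
    {"name": "serving_default_input_mask:0", "index": 2},
]

ordered = reorder_by_name(details, keras_order)
```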
run_classifier
run_classifier(
train_ds, validation_ds, epochs, steps_per_epoch, num_classes, **kwargs
)
Creates classifier and runs the classifier training.
Args | |
---|---|
train_ds | tf.data.Dataset, training data to be fed into tf.keras.Model.fit(). |
validation_ds | tf.data.Dataset, validation data to be fed into tf.keras.Model.fit(). |
epochs | Integer, number of training epochs. |
steps_per_epoch | Integer or None. Total number of steps (batches of samples) before declaring one epoch finished and starting the next epoch. If steps_per_epoch is None, the epoch will run until the input dataset is exhausted. |
num_classes | Integer, number of classes. |
**kwargs | Other parameters passed to tf.keras.Model.fit(). |
Returns | |
---|---|
tf.keras.Model, the keras model that's already trained. |
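To make the `steps_per_epoch` argument of `run_classifier` concrete, here is a small sketch of how a value covering one full pass over the data could be computed. The helper name `steps_for_one_pass` and the dataset size of 1000 are hypothetical. The batch size of 32 mirrors the documented `default_batch_size`.

```python
import math


def steps_for_one_pass(num_examples, batch_size=32, drop_remainder=False):
    """Number of batches needed to see every example once."""
    if num_examples <= 0 or batch_size <= 0:
        raise ValueError("num_examples and batch_size must be positive")
    if drop_remainder:
        # The final partial batch is discarded.
        return num_examples // batch_size
    # The final partial batch counts as one extra step.
    return math.ceil(num_examples / batch_size)


steps = steps_for_one_pass(1000)  # 1000 examples, batches of 32
```

Passing `None` instead lets the epoch run until the dataset is exhausted, as described in the Args table.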
save_vocab
save_vocab(
vocab_filename
)
Prints the file path to the vocabulary.
select_data_from_record
select_data_from_record(
record
)
Dispatches records to features and labels.
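A sketch of what splitting a record into features and labels might look like, using plain dictionaries in place of parsed TFRecord tensors. The record keys (`input_ids`, `input_mask`, `segment_ids`, `label_ids`) and the feature names are assumptions based on common BERT preprocessing conventions. This page does not confirm them, and `split_record` is a hypothetical helper.

```python
def split_record(record):
    """Splits one parsed record into a (features, label) pair."""
    # Assumed key names; the real spec may use different ones.
    features = {
        "input_word_ids": record["input_ids"],
        "input_mask": record["input_mask"],
        "input_type_ids": record["segment_ids"],
    }
    label = record["label_ids"]
    return features, label


# A toy record with a sequence length of 4 (values are illustrative).
example_record = {
    "input_ids": [101, 7592, 102, 0],
    "input_mask": [1, 1, 1, 0],
    "segment_ids": [0, 0, 0, 0],
    "label_ids": 1,
}

features, label = split_record(example_record)
```

A `(features, label)` pair is the element structure that `tf.keras.Model.fit()` accepts from a `tf.data.Dataset`.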
Class Variables | |
---|---|
compat_tf_versions | [2] |
convert_from_saved_model_tf2 | True |
need_gen_vocab | False |