Creates a MobileBert model spec for the text classification task. See also: tflite_model_maker.text_classifier.BertClassifierSpec.
tflite_model_maker.text_classifier.MobileBertClassifierSpec(
*,
uri='https://hub.tensorflow.google.cn/google/mobilebert/uncased_L-24_H-128_B-512_A-4_F-4_OPT/1',
model_dir=None,
seq_len=128,
dropout_rate=0.1,
initializer_range=0.02,
learning_rate=3e-05,
distribution_strategy='off',
num_gpus=-1,
tpu='',
trainable=True,
do_lower_case=True,
is_tf2=False,
name='MobileBert',
tflite_input_name=None,
default_batch_size=48,
index_to_label=None
)
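The bare * in the signature makes every parameter keyword-only, so positional calls are rejected. A minimal sketch of a keyword-only call follows, assuming the tflite-model-maker package is installed. The helper name build_mobilebert_spec and the override values are hypothetical, chosen only for illustration; the import is done lazily so the module can be loaded where the package is absent.

```python
def build_mobilebert_spec(**overrides):
    """Hypothetical helper: construct the spec using keyword arguments only."""
    # Lazy import: the package is only needed when the helper is called.
    from tflite_model_maker import text_classifier

    # Every parameter after the bare `*` must be passed by name.
    return text_classifier.MobileBertClassifierSpec(**overrides)


# Illustrative overrides (assumed values, not recommendations).
# Any parameter left out keeps the default shown in the signature.
EXAMPLE_OVERRIDES = {
    "seq_len": 256,
    "learning_rate": 2e-05,
    "index_to_label": ["neg", "pos"],
}
```

Calling build_mobilebert_spec(**EXAMPLE_OVERRIDES) would then return the spec. Constructing it may require network access to fetch the module referenced by uri.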
| Args | Description |
|---|---|
| uri | TF-Hub path/URL to the BERT module. |
| model_dir | The location of the model checkpoint files. |
| seq_len | Length of the sequence to feed into the model. |
| dropout_rate | The rate for dropout. |
| initializer_range | The standard deviation of the truncated_normal_initializer used to initialize all weight matrices. |
| learning_rate | The initial learning rate for Adam. |
| distribution_strategy | A string specifying which distribution strategy to use. Accepted values are 'off', 'one_device', 'mirrored', 'parameter_server', 'multi_worker_mirrored', and 'tpu' (case insensitive). 'off' means not to use Distribution Strategy; 'tpu' means to use TPUStrategy with the TPU address given by the tpu argument. |
| num_gpus | How many GPUs to use at each worker with the DistributionStrategies API. The default is -1, which means use all available GPUs. |
| tpu | TPU address to connect to. |
| trainable | Boolean, whether the pretrained layer is trainable. |
| do_lower_case | Boolean, whether to lowercase the input text. Should be True for uncased models and False for cased models. |
| is_tf2 | Boolean, whether the hub module is in TensorFlow 2.x format. |
| name | The name of the object. |
| tflite_input_name | Dict, input names for the TFLite model. |
| default_batch_size | Default batch size for training. |
| index_to_label | List of labels in the training data, e.g. ['neg', 'pos']. |
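Two of the arguments above carry constraints that are easy to check before constructing the spec: distribution_strategy must be one of the listed values (case insensitive), and do_lower_case should match whether the model is cased or uncased. The sketch below is a hypothetical pre-check, not the library's own validation. The function names are invented for illustration, and inferring casing from the substring "uncased" or "cased" in the uri is an assumption based on the naming of the default uri.

```python
# Accepted values as listed in the Args table.
ACCEPTED_STRATEGIES = (
    "off",
    "one_device",
    "mirrored",
    "parameter_server",
    "multi_worker_mirrored",
    "tpu",
)


def normalize_strategy(name):
    """Lowercase the strategy name and reject values not in the list."""
    lowered = name.lower()
    if lowered not in ACCEPTED_STRATEGIES:
        raise ValueError(f"Unsupported distribution_strategy: {name!r}")
    return lowered


def expected_do_lower_case(uri):
    """Guess do_lower_case from the uri (heuristic, assumed naming scheme).

    Returns True for 'uncased', False for 'cased', None if neither appears.
    'uncased' is tested first because it contains the substring 'cased'.
    """
    lowered = uri.lower()
    if "uncased" in lowered:
        return True
    if "cased" in lowered:
        return False
    return None
```

For the default uri, expected_do_lower_case returns True, which agrees with the default do_lower_case=True.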