tfm.nlp.encoders.MobileBertEncoderConfig

MobileBERT encoder configuration.

Inherits From: Config, ParamsDict

Attributes

word_vocab_size: number of words in the vocabulary.
word_embed_size: word embedding size.
type_vocab_size: number of word types.
max_sequence_length: maximum length of the input sequence.
num_blocks: number of transformer blocks in the encoder model.
hidden_size: the hidden size for the transformer block.
num_attention_heads: number of attention heads in the transformer block.
intermediate_size: the size of the "intermediate" (a.k.a. feed-forward) layer.
hidden_activation: the non-linear activation function to apply to the output of the intermediate/feed-forward layer.
hidden_dropout_prob: dropout probability for the hidden layers.
attention_probs_dropout_prob: dropout probability of the attention probabilities.
intra_bottleneck_size: the size of the bottleneck.
initializer_range: the stddev of the truncated_normal_initializer for initializing all weight matrices.
use_bottleneck_attention: whether to use attention inputs from the bottleneck transformation. If True, key_query_shared_bottleneck is ignored.
key_query_shared_bottleneck: whether to share the linear transformation for keys and queries.
num_feedforward_networks: number of stacked feed-forward networks.
normalization_type: the type of normalization; only 'no_norm' and 'layer_norm' are supported. 'no_norm' represents the element-wise linear transformation for the student model, as suggested by the original MobileBERT paper. 'layer_norm' is used for the teacher model.
classifier_activation: whether to use the tanh activation for the final representation of the [CLS] token in fine-tuning.

BUILDER

default_params: Dataclass field.
restrictions: Dataclass field.
input_mask_dtype: Dataclass field.
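
A minimal construction sketch, assuming the tensorflow-models-official package is installed and imported as tfm; any field above can be overridden as a keyword argument:

import tensorflow_models as tfm

# Start from the defaults and override a few fields at construction time.
config = tfm.nlp.encoders.MobileBertEncoderConfig(
    num_blocks=12,
    hidden_size=256,
)
print(config.num_blocks)  # 12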

Methods

as_dict

Returns a dict representation of params_dict.ParamsDict.

For a nested params_dict.ParamsDict, a nested dict will be returned.
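
For instance (a sketch; the value shown is the default listed under Class Variables at the bottom of this page):

import tensorflow_models as tfm

config = tfm.nlp.encoders.MobileBertEncoderConfig()
params = config.as_dict()      # plain nested dict
print(params['hidden_size'])   # 512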

from_args

Builds a config from the given list of arguments.
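
A hedged sketch, assuming positional arguments are matched to the config fields in declaration order (the order shown under Attributes); treat this mapping as an assumption rather than documented behavior:

import tensorflow_models as tfm

config = tfm.nlp.encoders.MobileBertEncoderConfig.from_args(
    30522,          # word_vocab_size (first declared field, by assumption)
    128,            # word_embed_size (second, by assumption)
    num_blocks=12)  # keyword arguments behave as usual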

from_json

Wrapper for from_yaml.

from_yaml

Builds a config from a YAML file.
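
A sketch covering from_yaml and, since it is a thin wrapper, from_json as well; both are assumed to take a path to a file holding the fields to set ('config.yaml' and 'config.json' are hypothetical paths):

import tensorflow_models as tfm

# config.yaml might contain, e.g.:
#   num_blocks: 12
#   hidden_size: 256
config = tfm.nlp.encoders.MobileBertEncoderConfig.from_yaml('config.yaml')
config = tfm.nlp.encoders.MobileBertEncoderConfig.from_json('config.json')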

get

Accesses a parameter through the built-in dictionary get method.
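
For example:

import tensorflow_models as tfm

config = tfm.nlp.encoders.MobileBertEncoderConfig()
config.get('hidden_size')              # 512
config.get('not_a_field', 'fallback')  # 'fallback' for missing keys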

lock

Makes the ParamsDict immutable.
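
A sketch, assuming that mutating a locked config raises an error (a ValueError in the underlying ParamsDict):

import tensorflow_models as tfm

config = tfm.nlp.encoders.MobileBertEncoderConfig()
config.lock()
try:
  config.hidden_size = 256  # no changes are allowed once locked
except ValueError as e:
  print(e)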

override

Override the ParamsDict with a set of given params.

Args
override_params: a dict or a ParamsDict specifying the parameters to be overridden.
is_strict: a boolean specifying whether the override is strict. If True, keys in override_params must be present in the ParamsDict. If False, keys in override_params can be different from what is currently defined in the ParamsDict; in this case, the ParamsDict will be extended to include the new keys.
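
For example (a sketch; both keys already exist in the config, so a strict override succeeds):

import tensorflow_models as tfm

config = tfm.nlp.encoders.MobileBertEncoderConfig()
config.override({'num_blocks': 12, 'hidden_size': 256}, is_strict=True)
print(config.num_blocks)  # 12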

replace

Overrides parameters and returns an unlocked copy, leaving the current config unchanged.
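
For example (a sketch):

import tensorflow_models as tfm

config = tfm.nlp.encoders.MobileBertEncoderConfig()
smaller = config.replace(num_blocks=12)
print(config.num_blocks, smaller.num_blocks)  # 24 12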

validate

Validates the parameters' consistency based on the restrictions.

This method validates the internal consistency using the pre-defined list of restrictions. A restriction is defined as a string which specifies a binary operation. The supported binary operations are {'==', '!=', '<', '<=', '>', '>='}. Note that the meaning of these operators is consistent with the underlying Python implementation. Users should make sure the restrictions they define make sense for the types involved.

For example, for a ParamsDict like the following:

a:
  a1: 1
  a2: 2
b:
  bb:
    bb1: 10
    bb2: 20
  ccc:
    a1: 1
    a3: 3

one can define two restrictions: ['a.a1 == b.ccc.a1', 'a.a2 <= b.bb.bb2']

These enforce:

  • a.a1 == b.ccc.a1 (here 1 == 1)
  • a.a2 <= b.bb.bb2 (here 2 <= 20)

Raises
KeyError: if (1) any parameter referenced in a restriction is not defined in the ParamsDict, or (2) an inconsistency violating a restriction is found.
ValueError: if a restriction string specifies an unsupported operation.
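
The example above, written as a runnable sketch directly against the underlying params_dict.ParamsDict (the import path assumes the TensorFlow Model Garden's official package):

from official.modeling.hyperparams import params_dict

params = params_dict.ParamsDict(
    {'a': {'a1': 1, 'a2': 2},
     'b': {'bb': {'bb1': 10, 'bb2': 20},
           'ccc': {'a1': 1, 'a3': 3}}},
    restrictions=['a.a1 == b.ccc.a1', 'a.a2 <= b.bb.bb2'])
params.validate()  # passes; a violated restriction would raise KeyError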

__contains__

Implements the membership test operator.
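
For example:

import tensorflow_models as tfm

config = tfm.nlp.encoders.MobileBertEncoderConfig()
'hidden_size' in config   # True
'not_a_field' in config   # False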

__eq__

Class Variables

IMMUTABLE_TYPES: (str, int, float, bool, NoneType)
RESERVED_ATTR: ['_locked', '_restrictions']
SEQUENCE_TYPES: (list, tuple)
attention_probs_dropout_prob: 0.1
classifier_activation: True
default_params: None
hidden_activation: 'gelu'
hidden_dropout_prob: 0.1
hidden_size: 512
initializer_range: 0.02
input_mask_dtype: 'int32'
intermediate_size: 4096
intra_bottleneck_size: 1024
key_query_shared_bottleneck: False
max_sequence_length: 512
normalization_type: 'layer_norm'
num_attention_heads: 4
num_blocks: 24
num_feedforward_networks: 1
restrictions: None
type_vocab_size: 2
use_bottleneck_attention: False
word_embed_size: 128
word_vocab_size: 30522