ELECTRA network training model.
```python
tfm.nlp.models.ElectraPretrainer(
    generator_network,
    discriminator_network,
    vocab_size,
    num_classes,
    num_token_predictions,
    mlm_activation=None,
    mlm_initializer='glorot_uniform',
    output_type='logits',
    disallow_correct=False,
    **kwargs
)
```
This is an implementation of the network structure described in "ELECTRA: Pre-training Text Encoders as Discriminators Rather Than Generators" (https://arxiv.org/abs/2003.10555).
The ElectraPretrainer lets a user pass in two transformer models, one for the generator and the other for the discriminator. It instantiates the masked language model (on the generator side) and the classification network (on the discriminator side) that are used to create the training objectives.
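For instance, a minimal sketch (not taken from the library's own examples) might pair two small `tfm.nlp.networks.BertEncoder` networks as generator and discriminator; the encoder sizes, vocabulary size, and head sizes below are illustrative assumptions:

```python
import tensorflow_models as tfm

# Illustrative sizes only. In practice the generator is usually a smaller
# encoder than the discriminator, and the two typically share token
# embeddings, which this sketch omits.
VOCAB_SIZE = 30522

generator = tfm.nlp.networks.BertEncoder(
    vocab_size=VOCAB_SIZE, hidden_size=64, num_layers=2, num_attention_heads=2)
discriminator = tfm.nlp.networks.BertEncoder(
    vocab_size=VOCAB_SIZE, hidden_size=128, num_layers=2, num_attention_heads=2)

pretrainer = tfm.nlp.models.ElectraPretrainer(
    generator_network=generator,
    discriminator_network=discriminator,
    vocab_size=VOCAB_SIZE,
    num_classes=2,               # e.g. a binary sentence-level (NSP-style) head
    num_token_predictions=20)    # masked positions predicted per example
```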
| Attributes | |
|---|---|
| `checkpoint_items` | Returns a dictionary of items to be additionally checkpointed. |
Methods
call
```python
call(
    inputs
)
```
ELECTRA forward pass.
| Args | |
|---|---|
| `inputs` | A dict of all inputs, the same as the standard BERT model. |
| Returns | |
|---|---|
| `outputs` | A dict of pretrainer model outputs, including (1) `lm_outputs`: a `[batch_size, num_token_predictions, vocab_size]` tensor of logits at the masked positions; (2) `sentence_outputs`: a `[batch_size, num_classes]` tensor of logits for the next sentence prediction (NSP) task; (3) `disc_logits`: a `[batch_size, sequence_length]` tensor of logits for the discriminator's replaced token detection task; (4) `disc_label`: a `[batch_size, sequence_length]` tensor of target labels for the discriminator's replaced token detection task. |
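Continuing the instantiation sketch above, a forward pass might look like the following. The input keys below follow the standard BERT pretraining feature names and are an assumption here rather than something spelled out in the table above; the output keys are the ones listed in the Returns table.

```python
import tensorflow as tf

# Dummy batch. Text inputs are [batch_size, seq_length]; masked-LM fields are
# [batch_size, num_token_predictions]. masked_lm_weights is included because it
# is a standard BERT pretraining feature, though it may not be required here.
batch_size, seq_length, num_preds = 2, 32, 20
dummy_inputs = {
    'input_word_ids': tf.random.uniform(
        [batch_size, seq_length], maxval=VOCAB_SIZE, dtype=tf.int32),
    'input_mask': tf.ones([batch_size, seq_length], dtype=tf.int32),
    'input_type_ids': tf.zeros([batch_size, seq_length], dtype=tf.int32),
    'masked_lm_positions': tf.random.uniform(
        [batch_size, num_preds], maxval=seq_length, dtype=tf.int32),
    'masked_lm_ids': tf.random.uniform(
        [batch_size, num_preds], maxval=VOCAB_SIZE, dtype=tf.int32),
    'masked_lm_weights': tf.ones([batch_size, num_preds], dtype=tf.float32),
}

outputs = pretrainer(dummy_inputs)
print(outputs['lm_outputs'].shape)    # (2, 20, 30522)
print(outputs['disc_logits'].shape)   # (2, 32)
```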