Task object for tagging (e.g., NER or POS).
Inherits From: Task
tfm.nlp.tasks.TaggingTask(
params, logging_dir: Optional[str] = None, name: Optional[str] = None
)
| Attributes | |
|---|---|
| logging_dir | |
| task_config | |
Methods
aggregate_logs
aggregate_logs(
state=None, step_outputs=None
)
Aggregates over logs returned from a validation step.
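As a rough illustration of the aggregate-then-reduce pattern, the plain-Python sketch below folds each validation step's outputs into a running state dictionary. It is not the TensorFlow implementation, and the `predict_ids` / `label_ids` keys are assumptions chosen for the example rather than a documented contract of `TaggingTask`.

```python
def aggregate_logs(state=None, step_outputs=None):
    """Hypothetical sketch: append one validation step's outputs to a running state."""
    if state is None:
        # First call: start with empty accumulators.
        state = {"predict_ids": [], "label_ids": []}
    if step_outputs is not None:
        # Each step contributes a batch (a list of per-example id sequences).
        state["predict_ids"].extend(step_outputs["predict_ids"])
        state["label_ids"].extend(step_outputs["label_ids"])
    return state


# Usage: fold two validation steps into a single state.
state = None
for step in [
    {"predict_ids": [[1, 0]], "label_ids": [[1, 1]]},
    {"predict_ids": [[2]], "label_ids": [[2]]},
]:
    state = aggregate_logs(state, step)
```

Keeping raw predictions instead of per-step scores lets a later reduction compute metrics over the whole validation set at once.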
build_inputs
build_inputs(
params: tfm.core.config_definitions.DataConfig,
input_context=None
)
Returns a tf.data.Dataset for the tagging task.
build_losses
build_losses(
labels, model_outputs, aux_losses=None
) -> tf.Tensor
Standard interface to compute losses.
| Args | |
|---|---|
| labels | optional label tensors. |
| model_outputs | a nested structure of output tensors. |
| aux_losses | auxiliary loss tensors, i.e., losses in keras.Model. |
| Returns | |
|---|---|
| The total loss tensor. |
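For a tagging task the loss is computed per token. One common approach, sketched below in plain Python, is softmax cross-entropy averaged over real tokens only. The convention that padding positions carry a negative label id is an assumption of this sketch, and the function is an illustration, not the library's `build_losses`.

```python
import math


def masked_token_cross_entropy(labels, logits):
    """Mean softmax cross-entropy over tokens whose label id is >= 0.

    labels: list of sequences of int label ids; negative ids mark padding (assumed).
    logits: matching nested list with one score vector per token.
    """
    total, count = 0.0, 0
    for seq_labels, seq_logits in zip(labels, logits):
        for label, scores in zip(seq_labels, seq_logits):
            if label < 0:
                continue  # padding token: contributes neither loss nor weight
            # log-sum-exp, shifted by the max score for numerical stability
            m = max(scores)
            log_norm = m + math.log(sum(math.exp(s - m) for s in scores))
            total += log_norm - scores[label]
            count += 1
    # Return 0.0 instead of dividing by zero when a batch has no real tokens.
    return total / count if count else 0.0


# Two real tokens with uniform 2-class logits, plus one padding token.
loss = masked_token_cross_entropy(
    labels=[[0, 1, -1]],
    logits=[[[0.0, 0.0], [0.0, 0.0], [5.0, -5.0]]],
)
```

Here the padding token's confident logits have no effect, and the result is log 2 for the two uniformly scored real tokens.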
build_metrics
build_metrics(
training: bool = True
)
Gets streaming metrics for training/validation.
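"Streaming" means a metric keeps running totals that are updated batch by batch and read out at the end. The plain-Python class below loosely mimics the `update_state` / `result` / `reset_state` shape of a Keras metric. It is only an illustration; returning a single token-accuracy metric is an assumption, not a description of what `TaggingTask.build_metrics` actually returns.

```python
class StreamingAccuracy:
    """Hypothetical streaming metric: keeps running totals across batches."""

    def __init__(self, name="accuracy"):
        self.name = name
        self.reset_state()

    def reset_state(self):
        self.correct = 0.0
        self.total = 0.0

    def update_state(self, y_true, y_pred, sample_weight=None):
        if sample_weight is None:
            sample_weight = [1.0] * len(y_true)
        for t, p, w in zip(y_true, y_pred, sample_weight):
            self.correct += w * (t == p)
            self.total += w

    def result(self):
        return self.correct / self.total if self.total else 0.0


def build_metrics(training=True):
    # `training` is accepted for signature parity; this sketch ignores it.
    return [StreamingAccuracy(name="accuracy")]


metric = build_metrics()[0]
metric.update_state([1, 2, 3], [1, 0, 3])  # 2 of 3 correct
metric.update_state([4], [4])              # running total: 3 of 4
```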
build_model
build_model()
[Optional] Creates model architecture.
| Returns | |
|---|---|
| A model instance. |
create_optimizer
@classmethod
create_optimizer(
optimizer_config: tfm.optimization.OptimizationConfig,
runtime_config: Optional[tfm.core.base_task.RuntimeConfig] = None,
dp_config: Optional[tfm.core.base_task.DifferentialPrivacyConfig] = None
)
Creates a TF optimizer from configurations.
| Args | |
|---|---|
| optimizer_config | the parameters of the Optimization settings. |
| runtime_config | the parameters of the runtime. |
| dp_config | the parameters of differential privacy. |
| Returns | |
|---|---|
| A tf.optimizers.Optimizer object. |
inference_step
inference_step(
inputs, model: tf.keras.Model
)
Performs the forward step.
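For tagging, the forward step produces one score vector per token, and a predicted label id per token is typically its argmax. The plain-Python sketch below shows that shape; the `logits` and `predict_ids` output keys, the `input_word_ids` input key, and `toy_model` are all hypothetical stand-ins, not the documented outputs of `TaggingTask.inference_step`.

```python
def inference_step(inputs, model):
    """Sketch: run the model and turn per-token logits into predicted label ids."""
    logits = model(inputs)  # nested list shaped [batch][token][num_classes]
    predict_ids = [
        [max(range(len(scores)), key=scores.__getitem__) for scores in seq]
        for seq in logits
    ]
    return {"logits": logits, "predict_ids": predict_ids}


def toy_model(inputs):
    # Hypothetical stand-in for a keras.Model: fixed 2-class scores for two
    # tokens per example, regardless of the actual input ids.
    return [[[0.1, 2.0], [3.0, -1.0]] for _ in inputs["input_word_ids"]]


outputs = inference_step({"input_word_ids": [[101, 102]]}, toy_model)
```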
initialize
initialize(
model: tf.keras.Model
)
[Optional] A callback function used as CheckpointManager's init_fn.
This function will be called when no checkpoint is found for the model. If there is a checkpoint, the checkpoint will be loaded and this function will not be called. You can use this callback function to load a pretrained checkpoint, saved under a directory other than the model_dir.
| Args | |
|---|---|
| model | The keras.Model built or used by this task. |
process_compiled_metrics
process_compiled_metrics(
compiled_metrics, labels, model_outputs
)
Processes and updates compiled_metrics.
Called when using the compile/fit API.
| Args | |
|---|---|
| compiled_metrics | the compiled metrics (model.compiled_metrics). |
| labels | a tensor or a nested structure of tensors. |
| model_outputs | a tensor or a nested structure of tensors; for example, the output of the Keras model built by self.build_model. |
process_metrics
process_metrics(
metrics, labels, model_outputs, **kwargs
)
Processes and updates metrics.
Called when using a custom training loop.
| Args | |
|---|---|
| metrics | a nested structure of metrics objects, as returned by self.build_metrics. |
| labels | a tensor or a nested structure of tensors. |
| model_outputs | a tensor or a nested structure of tensors; for example, the output of the Keras model built by self.build_model. |
| **kwargs | other args. |
reduce_aggregated_logs
reduce_aggregated_logs(
aggregated_logs, global_step=None
)
Reduces aggregated logs over validation steps.
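The aggregate/reduce split matters because some tagging metrics, such as entity-level F1, are not simple averages of per-batch values and must be computed over all predictions at once. The sketch below reduces an aggregated state, in the hypothetical format used for the aggregation example, to a single token-level accuracy; the real method may report different or additional metrics.

```python
def reduce_aggregated_logs(aggregated_logs, global_step=None):
    """Sketch: turn accumulated predictions into a final token-level accuracy."""
    correct = total = 0
    for preds, labels in zip(aggregated_logs["predict_ids"], aggregated_logs["label_ids"]):
        for p, t in zip(preds, labels):
            correct += p == t
            total += 1
    # `global_step` is accepted for signature parity and unused here.
    return {"accuracy": correct / total if total else 0.0}


logs = reduce_aggregated_logs({"predict_ids": [[1, 0], [2]], "label_ids": [[1, 1], [2]]})
```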
train_step
train_step(
inputs,
model: tf.keras.Model,
optimizer: tf.keras.optimizers.Optimizer,
metrics=None
)
Does the forward and backward passes.
With distribution strategies, this method runs on devices.
| Args | |
|---|---|
| inputs | a dictionary of input tensors. |
| model | the model, forward pass definition. |
| optimizer | the optimizer for this training step. |
| metrics | a nested structure of metrics objects. |
| Returns | |
|---|---|
| A dictionary of logs. |
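To make "forward and backward" concrete, the plain-Python sketch below runs one training step on a deliberately tiny, hypothetical tagger whose only parameters are one bias per class. It computes the masked cross-entropy (forward), its gradient with respect to the biases, which for softmax cross-entropy is the predicted probability minus the one-hot label (backward), and applies a plain SGD update in place of an optimizer. The padding convention (label id -1) and the model are assumptions; the returned `"loss"` key matches the `loss` class variable of this task.

```python
import math


def softmax(scores):
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    z = sum(exps)
    return [e / z for e in exps]


def train_step(inputs, model, learning_rate=0.5):
    """Sketch of one forward/backward pass for a toy bias-only tagger.

    `model` is a hypothetical dict holding one bias per class, so every token
    gets the same logits. Assumes at least one non-padding token (label >= 0).
    """
    labels = [t for seq in inputs["label_ids"] for t in seq if t >= 0]
    probs = softmax(model["bias"])  # forward pass (identical for every token)
    loss = -sum(math.log(probs[t]) for t in labels) / len(labels)
    # Backward pass: d(loss)/d(bias_k) = mean over tokens of (p_k - 1[k == label]).
    grads = [
        p - sum(1 for t in labels if t == k) / len(labels)
        for k, p in enumerate(probs)
    ]
    # Plain SGD update, standing in for the optimizer.
    model["bias"] = [b - learning_rate * g for b, g in zip(model["bias"], grads)]
    return {"loss": loss}


model = {"bias": [0.0, 0.0]}
batch = {"label_ids": [[1, 1, -1]]}
first = train_step(batch, model)["loss"]
second = train_step(batch, model)["loss"]
```

The second step's loss is lower than the first because the update moved probability mass toward class 1.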
validation_step
validation_step(
inputs, model: tf.keras.Model, metrics=None
)
Validation step.
| Args | |
|---|---|
| inputs | a dictionary of input tensors. |
| model | the keras.Model. |
| metrics | a nested structure of metrics objects. |
| Returns | |
|---|---|
| A dictionary of logs. |
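Unlike a training step, a validation step runs only the forward pass and applies no update. The sketch below computes a loss through a caller-supplied function, keeps predictions and labels for non-padding tokens so they can be aggregated later, and optionally updates metrics. The model output keys, the `loss_fn` parameter, `toy_model`, and the negative-id padding convention are all assumptions for illustration, not the library's signature.

```python
def validation_step(inputs, model, loss_fn, metrics=None):
    """Sketch: forward pass only; no gradients and no optimizer update."""
    outputs = model(inputs)  # assumed to return {"logits": ..., "predict_ids": ...}
    labels = inputs["label_ids"]
    logs = {"loss": loss_fn(labels, outputs["logits"])}
    # Keep only real (non-padding) tokens so later aggregation ignores padding.
    pairs = [
        (p, t)
        for seq_p, seq_t in zip(outputs["predict_ids"], labels)
        for p, t in zip(seq_p, seq_t)
        if t >= 0
    ]
    logs["predict_ids"] = [p for p, _ in pairs]
    logs["label_ids"] = [t for _, t in pairs]
    for metric in metrics or []:
        metric.update_state(logs["label_ids"], logs["predict_ids"])
    return logs


def toy_model(inputs):
    # Hypothetical model with fixed outputs for one 3-token example.
    return {"logits": [[[0.0, 1.0]] * 3], "predict_ids": [[1, 1, 1]]}


logs = validation_step({"label_ids": [[1, 0, -1]]}, toy_model,
                       loss_fn=lambda labels, logits: 0.0)
```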
| Class Variables | |
|---|---|
| loss | 'loss' |