A single-replica view of the training procedure.
Inherits From: Task
tfm.vision.MaskRCNNTask(
    params, logging_dir: Optional[str] = None, name: Optional[str] = None
)
The Mask R-CNN task provides artifacts for training and evaluation procedures, including loading and iterating over datasets, initializing the model, calculating the loss, post-processing, and computing customized metrics with reduction.
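As a rough illustration of how these artifacts fit together, the plain-Python sketch below shows one plausible order in which a single-replica driver could call a task's methods: build the model, initialize it, build inputs and metrics, run training steps, then run validation steps whose logs are aggregated and reduced. `StubTask` and `run` are hypothetical. They only record each call and return placeholder values, whereas `MaskRCNNTask` builds a Keras model and `tf.data` pipelines, and Model Garden's training library normally performs this orchestration for you.

```python
from typing import Any, Dict, List


class StubTask:
    """Hypothetical stand-in that records which Task-style method is called."""

    def __init__(self) -> None:
        self.calls: List[str] = []

    def build_model(self) -> Dict[str, float]:
        self.calls.append("build_model")
        return {"w": 0.0}  # placeholder "model"

    def initialize(self, model: Dict[str, float]) -> None:
        self.calls.append("initialize")  # hypothetical: restore pretrained weights

    def build_inputs(self, training: bool) -> List[float]:
        self.calls.append(f"build_inputs(training={training})")
        return [1.0, 2.0]  # two placeholder batches

    def build_metrics(self, training: bool = True) -> List[Any]:
        self.calls.append(f"build_metrics(training={training})")
        return []

    def train_step(self, inputs, model, optimizer, metrics=None) -> Dict[str, float]:
        self.calls.append("train_step")
        return {"loss": 0.0}

    def validation_step(self, inputs, model, metrics=None) -> Dict[str, float]:
        self.calls.append("validation_step")
        return {"loss": 0.0}

    def aggregate_logs(self, state=None, step_outputs=None):
        self.calls.append("aggregate_logs")
        return (state or []) + [step_outputs]

    def reduce_aggregated_logs(self, aggregated_logs, global_step=None):
        self.calls.append("reduce_aggregated_logs")
        return {"num_steps": len(aggregated_logs)}


def run(task: StubTask) -> Dict[str, Any]:
    """Drive one training pass and one evaluation pass on a single replica."""
    model = task.build_model()
    task.initialize(model)
    train_metrics = task.build_metrics(training=True)
    for batch in task.build_inputs(training=True):
        # optimizer=None is a placeholder; a real driver passes an optimizer.
        task.train_step(batch, model, optimizer=None, metrics=train_metrics)
    eval_metrics = task.build_metrics(training=False)
    state = None
    for batch in task.build_inputs(training=False):
        logs = task.validation_step(batch, model, metrics=eval_metrics)
        state = task.aggregate_logs(state, logs)
    return task.reduce_aggregated_logs(state)
```

Calling `run(StubTask())` returns `{'num_steps': 2}`, and the recorded calls begin with `build_model` followed by `initialize`.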
| Attributes | |
|---|---|
| logging_dir | |
| task_config | |
Methods
aggregate_logs
aggregate_logs(
    state: Optional[Any] = None, step_outputs: Optional[Dict[str, Any]] = None
) -> Optional[Any]
Optional aggregation over logs returned from a validation step.
build_inputs
build_inputs(
    params: tfm.vision.configs.maskrcnn.DataConfig,
    input_context: Optional[tf.distribute.InputContext] = None,
    dataset_fn: Optional[dataset_fn_lib.PossibleDatasetType] = None
) -> tf.data.Dataset
Builds the input dataset.
build_losses
build_losses(
    outputs: Mapping[str, Any],
    labels: Mapping[str, Any],
    aux_losses: Optional[Any] = None
) -> Dict[str, tf.Tensor]
Builds Mask R-CNN losses.
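Mask R-CNN combines the losses of several heads into one objective. As a hedged illustration, the function below assumes the total is a weighted sum of RPN objectness, RPN box, Fast R-CNN classification, Fast R-CNN box, and mask losses, plus any regularization terms passed as `aux_losses`. The function name, weight values, and dictionary keys are assumptions chosen for readability; the actual loss-weight fields live in the task's config under `tfm.vision.configs.maskrcnn`.

```python
from typing import Dict, Iterable, Mapping, Optional

# Hypothetical unit weights; real values come from the task's loss config.
DEFAULT_WEIGHTS = {
    "rpn_score_loss": 1.0,
    "rpn_box_loss": 1.0,
    "frcnn_cls_loss": 1.0,
    "frcnn_box_loss": 1.0,
    "mask_loss": 1.0,
}


def combine_maskrcnn_losses(
    head_losses: Mapping[str, float],
    weights: Mapping[str, float] = DEFAULT_WEIGHTS,
    aux_losses: Optional[Iterable[float]] = None,
) -> Dict[str, float]:
    """Weighted sum of per-head losses plus optional regularization terms."""
    model_loss = sum(weights[name] * head_losses[name] for name in weights)
    reg_loss = sum(aux_losses) if aux_losses else 0.0
    losses = dict(head_losses)  # keep the individual head losses for logging
    losses["model_loss"] = model_loss
    losses["total_loss"] = model_loss + reg_loss
    return losses
```

For example, head losses of 0.2, 0.1, 0.5, 0.3, and 0.4 with unit weights and `aux_losses=[0.05]` give a `model_loss` of about 1.5 and a `total_loss` of about 1.55. Keeping each head's loss in the returned dictionary makes it possible to log them separately.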
build_metrics
build_metrics(
    training: bool = True
)
Builds detection metrics.
build_model
build_model()
Builds the Mask R-CNN model.
create_optimizer
@classmethod
create_optimizer(
    optimizer_config: tfm.optimization.OptimizationConfig,
    runtime_config: Optional[tfm.core.base_task.RuntimeConfig] = None,
    dp_config: Optional[tfm.core.base_task.DifferentialPrivacyConfig] = None
)
Creates a TF optimizer from configurations.
| Args | |
|---|---|
| optimizer_config | the parameters of the optimization settings. | 
| runtime_config | the parameters of the runtime. | 
| dp_config | the parameters of differential privacy. | 
| Returns | |
|---|---|
| A tf.optimizers.Optimizer object. | 
inference_step
inference_step(
    inputs, model: tf.keras.Model
)
Performs the forward step.
With distribution strategies, this method runs on devices.
| Args | |
|---|---|
| inputs | a dictionary of input tensors. | 
| model | the keras.Model. | 
| Returns | |
|---|---|
| Model outputs. | 
initialize
initialize(
    model: tf.keras.Model
)
Loads a pretrained checkpoint.
process_compiled_metrics
process_compiled_metrics(
    compiled_metrics, labels, model_outputs
)
Processes and updates compiled_metrics.
Called when using the compile/fit API.
| Args | |
|---|---|
| compiled_metrics | the compiled metrics (model.compiled_metrics). | 
| labels | a tensor or a nested structure of tensors. | 
| model_outputs | a tensor or a nested structure of tensors. For example, output of the keras model built by self.build_model. | 
process_metrics
process_metrics(
    metrics, labels, model_outputs, **kwargs
)
Processes and updates metrics.
Called when using the custom training loop API.
| Args | |
|---|---|
| metrics | a nested structure of metrics objects, as returned by self.build_metrics. | 
| labels | a tensor or a nested structure of tensors. | 
| model_outputs | a tensor or a nested structure of tensors. For example, output of the keras model built by self.build_model. | 
| **kwargs | other args. | 
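With the custom training loop, updating metrics usually amounts to passing the labels and model outputs to each metric object's `update_state`. The sketch below mirrors that pattern in plain Python. `MeanOfKey` is a hypothetical metric modeled loosely on the Keras `update_state`/`result`/`reset_state` interface, and the `mask_loss` key is only an example.

```python
from typing import Any, List, Mapping


class MeanOfKey:
    """Hypothetical Keras-style metric: running mean of one model output."""

    def __init__(self, key: str) -> None:
        self.key = key
        self.total = 0.0
        self.count = 0

    def update_state(self, labels: Any, model_outputs: Mapping[str, float]) -> None:
        # This toy metric ignores labels and averages a single output value.
        self.total += float(model_outputs[self.key])
        self.count += 1

    def result(self) -> float:
        return self.total / self.count if self.count else 0.0

    def reset_state(self) -> None:
        self.total, self.count = 0.0, 0


def process_metrics(metrics: List[MeanOfKey], labels: Any,
                    model_outputs: Mapping[str, float], **kwargs: Any) -> None:
    """Update every metric with this step's labels and outputs."""
    # kwargs are accepted for signature compatibility and ignored here.
    for metric in metrics:
        metric.update_state(labels, model_outputs)
```

After two steps with `mask_loss` values of 0.5 and 0.25, the metric's `result()` is 0.375.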
reduce_aggregated_logs
reduce_aggregated_logs(
    aggregated_logs: Dict[str, Any], global_step: Optional[tf.Tensor] = None
) -> Dict[str, tf.Tensor]
Optional reduction of aggregated logs over validation steps.
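aggregate_logs and reduce_aggregated_logs work as a pair during evaluation: the first folds each validation step's outputs into a running state, and the second turns the accumulated state into final logs once all steps have run. The plain-Python sketch below assumes the state is a dict of lists and the reduction is a per-key mean. For a detection task like Mask R-CNN, the reduction is typically a detection evaluation (for example, COCO-style average precision) rather than a simple average.

```python
from typing import Dict, List, Optional


def aggregate_logs(state: Optional[Dict[str, List[float]]] = None,
                   step_outputs: Optional[Dict[str, float]] = None
                   ) -> Dict[str, List[float]]:
    """Append one validation step's outputs to the running state."""
    state = {} if state is None else state
    for key, value in (step_outputs or {}).items():
        state.setdefault(key, []).append(float(value))
    return state


def reduce_aggregated_logs(aggregated_logs: Dict[str, List[float]],
                           global_step: Optional[int] = None) -> Dict[str, float]:
    """Collapse the accumulated lists into one mean per key."""
    # global_step is unused in this sketch; a real task might log it.
    return {key: sum(values) / len(values)
            for key, values in aggregated_logs.items() if values}
```

Folding `{'loss': 1.0}` and `{'loss': 3.0}` through `aggregate_logs` and then reducing yields `{'loss': 2.0}`.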
train_step
train_step(
    inputs: Tuple[Any, Any],
    model: tf.keras.Model,
    optimizer: tf.keras.optimizers.Optimizer,
    metrics: Optional[List[Any]] = None
)
Runs the forward and backward passes.
| Args | |
|---|---|
| inputs | a tuple of input tensors, matching the Tuple[Any, Any] type hint. | 
| model | the model, forward pass definition. | 
| optimizer | the optimizer for this training step. | 
| metrics | a nested structure of metrics objects. | 
| Returns | |
|---|---|
| A dictionary of logs. | 
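Conceptually, a training step runs the model forward, computes the loss, differentiates it with respect to the trainable weights, applies the optimizer update, and returns a dictionary of logs keyed by the task's loss class variable ('loss'). In TensorFlow this is usually done with `tf.GradientTape`. The sketch below stays dependency-free by using a hypothetical one-weight linear model `y_hat = w * x` with a hand-derived gradient and plain SGD, so it illustrates the control flow only, not the Mask R-CNN model.

```python
from typing import Dict, List, Tuple

LOSS_KEY = "loss"  # matches the task's `loss` class variable


def train_step(inputs: Tuple[List[float], List[float]],
               weights: Dict[str, float],
               learning_rate: float = 0.01) -> Dict[str, float]:
    """One forward/backward pass for the hypothetical model y_hat = w * x."""
    xs, ys = inputs
    n = len(xs)
    # Forward pass and mean squared error.
    preds = [weights["w"] * x for x in xs]
    loss = sum((p - y) ** 2 for p, y in zip(preds, ys)) / n
    # Backward pass: dL/dw = (2 / n) * sum((p - y) * x).
    grad = 2.0 / n * sum((p - y) * x for p, y, x in zip(preds, ys, xs))
    # Plain SGD update, standing in for optimizer.apply_gradients.
    weights["w"] -= learning_rate * grad
    return {LOSS_KEY: loss}
```

Starting from `w = 0.0` on the batch `([1.0, 2.0, 3.0], [2.0, 4.0, 6.0])`, the first call reports a loss of 56/3, and repeated calls drive `w` toward 2.0.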
validation_step
validation_step(
    inputs: Tuple[Any, Any],
    model: tf.keras.Model,
    metrics: Optional[List[Any]] = None
)
Validation step.
| Args | |
|---|---|
| inputs | a tuple of input tensors, matching the Tuple[Any, Any] type hint. | 
| model | the keras.Model. | 
| metrics | a nested structure of metrics objects. | 
| Returns | |
|---|---|
| A dictionary of logs. | 
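A validation step runs only the forward pass: it computes the loss and, typically, the outputs needed for evaluation, without computing gradients or changing weights. The sketch below shows that contrast with a hypothetical one-weight model `y_hat = w * x` in plain Python. Returning raw predictions alongside the loss is an assumption about how a later aggregation step might consume them.

```python
from typing import Dict, List, Tuple


def validation_step(inputs: Tuple[List[float], List[float]],
                    weights: Dict[str, float]) -> Dict[str, object]:
    """Forward pass only for the hypothetical model y_hat = w * x."""
    xs, ys = inputs
    preds = [weights["w"] * x for x in xs]  # no gradient, no weight update
    loss = sum((p - y) ** 2 for p, y in zip(preds, ys)) / len(xs)
    # Return the loss plus raw predictions so a later aggregation step
    # could consume them.
    return {"loss": loss, "predictions": preds}
```

With `w = 1.5` and the single example `([2.0], [4.0])`, the step reports a loss of 1.0 and predictions `[3.0]`, and `w` is unchanged.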
| Class Variables | |
|---|---|
| loss | 'loss' |