tfm.core.config_definitions.DataConfig

The base configuration for building datasets.

Inherits From: Config, ParamsDict

Attributes

input_path The path to the input. It can be (1) a str giving a single file path/pattern, (2) a str giving multiple file paths/patterns separated by commas (e.g. "a, b, c", or "a,b,c" with no spaces), (3) a list of str, each of which is a file path/pattern or a comma-separated string of file paths/patterns, or (4) a dictionary mapping names to any of the previous three forms, for more advanced data mixing using named access. It should not be specified when tfds_name below is specified. See the example after this attribute list.
tfds_name The name of the tensorflow dataset (TFDS). It should not be specified when input_path above is specified.
tfds_split A str indicating which split of the data to load from TFDS. It is required when tfds_name above is specified.
global_batch_size The global batch size across all replicas.
is_training Whether this data is used for training or not. This flag is useful for consumers of this object to determine whether the data should be repeated or shuffled.
drop_remainder Whether the last batch should be dropped in the case it has fewer than global_batch_size elements.
shuffle_buffer_size The buffer size used for shuffling training data.
cache Whether to cache dataset examples. If True, we will cache the dataset after applying the decode_fn and parse_fn. It can be used to avoid re-reading from disk, re-decoding and re-parsing the example on the second epoch, but it requires significant memory overhead.
cycle_length The number of files that will be processed concurrently when interleaving files.
block_length The number of consecutive elements to produce from each input element before cycling to another input element when interleaving files.
deterministic A boolean controlling whether determinism should be enforced.
sharding Whether sharding is used in the input pipeline.
enable_tf_data_service A boolean indicating whether to enable tf.data service for the input pipeline.
tf_data_service_address The URI of a tf.data service to offload preprocessing onto during training. The URI should be in the format "protocol://address", e.g. "grpc://tf-data-service:5050". It can be overridden by the FLAGS.tf_data_service flag in the binary.
tf_data_service_job_name The name of the tf.data service job. This argument makes it possible for multiple datasets to share the same job. The default behavior is that the dataset creates anonymous, exclusively owned jobs.
tfds_data_dir A str specifying the directory to read/write TFDS data.
tfds_as_supervised A bool. When loading dataset from TFDS, if True, the returned tf.data.Dataset will have a 2-tuple structure (input, label) according to builder.info.supervised_keys; if False, the default, the returned tf.data.Dataset will have a dictionary with all the features.
tfds_skip_decoding_feature A str to indicate which features are skipped for decoding when loading dataset from TFDS. Use comma to separate multiple features. The main use case is to skip the image/video decoding for better performance.
enable_shared_tf_data_service_between_parallel_trainers A bool. When set to True, only a single tf.data service will be started, and it will be shared between all the trainers running simultaneously, e.g. when using Vizier to tune hyperparameters. This saves CPU and RAM resources compared to running a separate tf.data service for each trainer. Note that if the batch size differs between trainers, the field apply_tf_data_service_before_batching also needs to be True so that only a single tf.data service instance is created; in that case, the tf.data service is applied before the batching operation, so make sure not to apply any processing steps after batching (e.g. in postprocess_fn), since they would not be parallelized by the tf.data service and may slow down your tf.data pipeline. When using a shared tf.data service, the tf.data dataset must be infinite, and slow trainers may skip certain training examples. More details about the shared tf.data service can be found at: https://www.tensorflow.org/api_docs/python/tf/data/experimental/service#sharing_tfdata_service_with_concurrent_trainers.
apply_tf_data_service_before_batching A bool. If set to True, the tf.data service will be applied before the batching operation. This is useful to ensure that only a single tf.data service instance is created when enable_shared_tf_data_service_between_parallel_trainers is True and the batch size differs between parallel trainers.
trainer_id A string. The id of the trainer if there are multiple parallel trainers running at the same time, e.g. in the Vizier tuning case. It will be set automatically if this field is needed; users do not need to set it when creating experiment configs.
seed An optional seed to use for deterministic shuffling/preprocessing.
prefetch_buffer_size An int specifying the buffer size of prefetch datasets. If None, the buffer size is autotuned. Specifying this is useful in case autotuning uses up too much memory by making the buffer size too high.
autotune_algorithm If specified, use this algorithm for AUTOTUNE. See: https://www.tensorflow.org/api_docs/python/tf/data/experimental/AutotuneAlgorithm
BUILDER

default_params Dataclass field
restrictions Dataclass field
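
For illustration, a minimal sketch of constructing DataConfig instances for file-based and TFDS-based input. The file pattern, dataset name, and parameter values below are illustrative, not defaults:

import tensorflow_models as tfm

# File-based input: input_path may be a single pattern, a comma-separated
# string of patterns, a list of patterns, or a dict of named patterns.
train_data = tfm.core.config_definitions.DataConfig(
    input_path='/tmp/data/train*.tfrecord',  # illustrative pattern
    global_batch_size=64,
    is_training=True,
    shuffle_buffer_size=10000)

# TFDS-based input: tfds_name and tfds_split are used instead of input_path.
eval_data = tfm.core.config_definitions.DataConfig(
    tfds_name='mnist',
    tfds_split='test',
    global_batch_size=32,
    is_training=False,
    drop_remainder=False)

The consumer of the config (for example, the task's input reader) uses flags such as is_training and cache to decide whether to repeat, shuffle, or cache the data; the config itself only carries the settings.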

Methods

as_dict

View source

Returns a dict representation of params_dict.ParamsDict.

For the nested params_dict.ParamsDict, a nested dict will be returned.
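
For example, a minimal sketch (values are illustrative):

import tensorflow_models as tfm

data = tfm.core.config_definitions.DataConfig(global_batch_size=64, is_training=True)
d = data.as_dict()
print(d['global_batch_size'], d['is_training'])  # 64 True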

from_args

View source

Builds a config from the given list of arguments.

from_json

View source

Wrapper for from_yaml.

from_yaml

View source
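
A hedged sketch, assuming from_yaml builds a config from a YAML file path; the path and file contents below are illustrative:

import tensorflow_models as tfm

with open('/tmp/data_config.yaml', 'w') as f:
    f.write('global_batch_size: 128\nis_training: true\n')

data = tfm.core.config_definitions.DataConfig.from_yaml('/tmp/data_config.yaml')
print(data.global_batch_size)  # 128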

get

View source

Accesses a parameter through the built-in dictionary get method.
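
For example (the fallback value is illustrative):

import tensorflow_models as tfm

data = tfm.core.config_definitions.DataConfig(tfds_name='mnist')
print(data.get('tfds_name'))                # 'mnist'
print(data.get('not_a_field', 'fallback'))  # returns the provided default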

lock

View source

Makes the ParamsDict immutable.
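
A minimal sketch:

import tensorflow_models as tfm

data = tfm.core.config_definitions.DataConfig(global_batch_size=64)
data.lock()
data.global_batch_size = 128  # raises an error: the config is locked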

override

View source

Overrides the ParamsDict with a set of given params.

Args
override_params a dict or a ParamsDict specifying the parameters to be overridden.
is_strict a boolean specifying whether override is strict or not. If True, keys in override_params must be present in the ParamsDict. If False, keys in override_params can be different from what is currently defined in the ParamsDict. In this case, the ParamsDict will be extended to include the new keys.
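
For example, a minimal sketch (the values and the custom key are illustrative):

import tensorflow_models as tfm

data = tfm.core.config_definitions.DataConfig(global_batch_size=64)

# Strict override: every key must already exist in the config.
data.override({'global_batch_size': 128, 'is_training': True}, is_strict=True)

# Non-strict override: unknown keys extend the config instead of raising.
data.override({'my_custom_field': 'foo'}, is_strict=False)  # hypothetical key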

replace

View source

Overrides/returns an unlocked copy, leaving the current config unchanged.
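
For example, a minimal sketch (values are illustrative):

import tensorflow_models as tfm

train_data = tfm.core.config_definitions.DataConfig(global_batch_size=64, is_training=True)

# Returns an overridden, unlocked copy; train_data itself is unchanged.
eval_data = train_data.replace(global_batch_size=32, is_training=False)
print(train_data.global_batch_size, eval_data.global_batch_size)  # 64 32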

validate

View source

Validates the parameters' consistency based on the restrictions.

This method validates the internal consistency using the pre-defined list of restrictions. A restriction is defined as a string which specifies a binary operation. The supported binary operations are {'==', '!=', '<', '<=', '>', '>='}. Note that the meaning of these operators is consistent with the underlying Python implementation. Users should make sure the restrictions they define make sense for the types involved.

For example, for a ParamsDict like the following:

a:
  a1: 1
  a2: 2
b:
  bb:
    bb1: 10
    bb2: 20
  ccc:
    a1: 1
    a3: 3

one can define two restrictions like this: ['a.a1 == b.ccc.a1', 'a.a2 <= b.bb.bb2']

What they enforce is

  • a.a1 = 1 == b.ccc.a1 = 1
  • a.a2 = 2 <= b.bb.bb2 = 20

Raises
KeyError if any of the following happens: (1) any of the parameters in any of the restrictions is not defined in the ParamsDict; (2) any inconsistency violating a restriction is found.
ValueError if the restriction defined in the string is not supported.
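
A minimal sketch of the restrictions above, expressed on a plain params_dict.ParamsDict (the nested keys are the illustrative ones from the example):

from official.modeling.hyperparams import params_dict

params = params_dict.ParamsDict(
    {'a': {'a1': 1, 'a2': 2},
     'b': {'bb': {'bb1': 10, 'bb2': 20},
           'ccc': {'a1': 1, 'a3': 3}}},
    restrictions=['a.a1 == b.ccc.a1', 'a.a2 <= b.bb.bb2'])

params.validate()  # passes: a.a1 == b.ccc.a1 and a.a2 <= b.bb.bb2
params.a.a1 = 2
params.validate()  # raises KeyError: a.a1 == b.ccc.a1 is violated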

__contains__

View source

Implements the membership test operator.
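
For example:

import tensorflow_models as tfm

data = tfm.core.config_definitions.DataConfig()
print('input_path' in data)   # True: input_path is a defined parameter
print('not_a_field' in data)  # False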

__eq__

Class Variables

IMMUTABLE_TYPES (<class 'str'>, <class 'int'>, <class 'float'>, <class 'bool'>, <class 'NoneType'>)
RESERVED_ATTR ['_locked', '_restrictions']
SEQUENCE_TYPES (<class 'list'>, <class 'tuple'>)
apply_tf_data_service_before_batching False
autotune_algorithm None
block_length 1
cache False
cycle_length None
default_params None
deterministic None
drop_remainder True
enable_shared_tf_data_service_between_parallel_trainers False
enable_tf_data_service False
global_batch_size 0
input_path ''
is_training None
prefetch_buffer_size None
restrictions None
seed None
sharding True
shuffle_buffer_size 100
tf_data_service_address None
tf_data_service_job_name None
tfds_as_supervised False
tfds_data_dir ''
tfds_name ''
tfds_skip_decoding_feature ''
tfds_split ''
trainer_id None