The base configuration for building datasets.
Inherits From: Config, ParamsDict
tfm.core.config_definitions.DataConfig(
default_params: dataclasses.InitVar[Optional[Mapping[str, Any]]] = None,
restrictions: dataclasses.InitVar[Optional[List[str]]] = None,
input_path: Union[Sequence[str], str, tfm.hyperparams.Config] = '',
tfds_name: str = '',
tfds_split: str = '',
global_batch_size: int = 0,
is_training: bool = None,
drop_remainder: bool = True,
shuffle_buffer_size: int = 100,
cache: bool = False,
cycle_length: Optional[int] = None,
block_length: int = 1,
deterministic: Optional[bool] = None,
sharding: bool = True,
enable_tf_data_service: bool = False,
tf_data_service_address: Optional[str] = None,
tf_data_service_job_name: Optional[str] = None,
tfds_data_dir: str = '',
tfds_as_supervised: bool = False,
tfds_skip_decoding_feature: str = '',
seed: Optional[int] = None
)
Attributes
input_path: The path to the input. It can be (1) a str indicating a file path/pattern; (2) a str indicating multiple file paths/patterns separated by commas (e.g. "a, b, c", or without spaces, "a,b,c"); (3) a list of str, each of which is a file path/pattern or multiple comma-separated file paths/patterns; or (4) a dictionary whose values take any of the previous three forms, for more advanced data mixing using named access. It should not be specified when tfds_name is specified.
tfds_name: The name of the TensorFlow Datasets (TFDS) dataset. It should not be specified when input_path is specified.
tfds_split: A str indicating which split of the data to load from TFDS. It is required when tfds_name is specified.
global_batch_size: The global batch size across all replicas.
is_training: Whether this data is used for training or not. This flag is useful for consumers of this object to determine whether the data should be repeated or shuffled.
drop_remainder: Whether to drop the last batch when it has fewer than global_batch_size elements.
shuffle_buffer_size: The buffer size used for shuffling training data.
cache: Whether to cache dataset examples. If True, the dataset is cached after applying the decode_fn and parse_fn. This avoids re-reading examples from disk and re-decoding and re-parsing them from the second epoch onward, but it requires significant memory overhead.
cycle_length: The number of files that will be processed concurrently when interleaving files.
block_length: The number of consecutive elements to produce from each input element before cycling to another input element when interleaving files.
deterministic: A boolean controlling whether determinism should be enforced.
sharding: Whether sharding is used in the input pipeline.
enable_tf_data_service: A boolean indicating whether to enable tf.data service for the input pipeline.
tf_data_service_address: The URI of a tf.data service to offload preprocessing onto during training. The URI should be in the format "protocol://address", e.g. "grpc://tf-data-service:5050". It can be overridden by the FLAGS.tf_data_service flag in the binary.
tf_data_service_job_name: The name of the tf.data service job. This argument makes it possible for multiple datasets to share the same job. The default behavior is that the dataset creates anonymous, exclusively owned jobs.
tfds_data_dir: A str specifying the directory to read/write TFDS data.
tfds_as_supervised: A bool. When loading a dataset from TFDS, if True, the returned tf.data.Dataset has a 2-tuple structure (input, label) according to builder.info.supervised_keys; if False (the default), the returned tf.data.Dataset yields a dictionary with all the features.
tfds_skip_decoding_feature: A str indicating which features to skip decoding for when loading a dataset from TFDS. Separate multiple features with commas. The main use case is to skip image/video decoding for better performance.
seed: An optional seed to use for deterministic shuffling/preprocessing.
BUILDER
default_params: Dataclass field
restrictions: Dataclass field
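The four accepted forms of input_path can be illustrated with a minimal, hypothetical normalizer. The names split_patterns and normalize_input_path are assumptions for illustration, not part of tfm, and a plain dict stands in for form (4), which the signature types as tfm.hyperparams.Config. The library's own parsing is not shown in this reference and may differ.

```python
from typing import Dict, List, Sequence, Union

# Forms (1)-(3): a single string or a sequence of strings.
PathSpec = Union[str, Sequence[str]]


def split_patterns(spec: PathSpec) -> List[str]:
    """Flattens forms (1)-(3) into a list of file paths/patterns."""
    items = [spec] if isinstance(spec, str) else list(spec)
    patterns: List[str] = []
    for item in items:
        # Each string may hold several comma-separated patterns, with or
        # without spaces after the commas.
        patterns.extend(p.strip() for p in item.split(",") if p.strip())
    return patterns


def normalize_input_path(
    input_path: Union[PathSpec, Dict[str, PathSpec]],
) -> Union[List[str], Dict[str, List[str]]]:
    """Also handles form (4): a mapping from names to forms (1)-(3)."""
    if isinstance(input_path, dict):
        return {name: split_patterns(spec) for name, spec in input_path.items()}
    return split_patterns(input_path)
```

With this sketch, normalize_input_path("a, b, c") and normalize_input_path(["a,b", "c"]) both return ['a', 'b', 'c'], and {"train": "x,y"} becomes {'train': ['x', 'y']}.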
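The arithmetic behind global_batch_size and drop_remainder can be sketched as follows, assuming batches are formed from one pass over a finite dataset and the global batch is split evenly across replicas. Both helper functions are hypothetical; in TensorFlow, tf.distribute.InputContext.get_per_replica_batch_size performs a similar even split.

```python
from typing import List


def batch_sizes(num_examples: int, global_batch_size: int,
                drop_remainder: bool) -> List[int]:
    """Sizes of the global batches produced by one pass over the data."""
    full_batches, remainder = divmod(num_examples, global_batch_size)
    sizes = [global_batch_size] * full_batches
    if remainder and not drop_remainder:
        sizes.append(remainder)  # Keep the final, smaller batch.
    return sizes


def per_replica_batch_size(global_batch_size: int, num_replicas: int) -> int:
    """Splits the global batch evenly across replicas."""
    if global_batch_size % num_replicas:
        raise ValueError("global_batch_size must be divisible by num_replicas")
    return global_batch_size // num_replicas
```

For 1,000 examples and global_batch_size=64, drop_remainder=True yields 15 batches of 64 (the last 40 examples are dropped), while drop_remainder=False adds a sixteenth batch of 40. On 8 replicas, each replica receives 8 examples per step.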
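shuffle_buffer_size sets the size of a buffer-based shuffle. The following plain-Python sketch mirrors the algorithm described in the tf.data.Dataset.shuffle documentation: fill a buffer, sample an element from it, and replace the sampled element with the next input. It is an illustration, not tf.data's implementation, and its seed argument only loosely corresponds to the seed attribute.

```python
import random
from typing import Iterable, Iterator, List, Optional, TypeVar

T = TypeVar("T")


def buffered_shuffle(items: Iterable[T], buffer_size: int,
                     seed: Optional[int] = None) -> Iterator[T]:
    """Shuffles a stream while holding at most buffer_size elements."""
    if buffer_size < 1:
        raise ValueError("buffer_size must be at least 1")
    rng = random.Random(seed)  # A fixed seed makes the order reproducible.
    buffer: List[T] = []
    for item in items:
        if len(buffer) < buffer_size:
            buffer.append(item)  # Fill the buffer first.
            continue
        # Emit a random buffered element and put the new one in its place.
        index = rng.randrange(buffer_size)
        yield buffer[index]
        buffer[index] = item
    # The input is exhausted: emit whatever is left in random order.
    rng.shuffle(buffer)
    yield from buffer
```

With buffer_size=1 the output order equals the input order, which shows why a buffer much smaller than the dataset gives only a local shuffle.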
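cycle_length and block_length follow the semantics of tf.data.Dataset.interleave. This plain-Python sketch computes the deterministic output order for a list of "files", each given as a list of records; the function name and the list-of-lists input are assumptions for illustration. When determinism is not enforced (see deterministic), tf.data may produce elements in a different order.

```python
from typing import Iterator, List, Optional, Sequence, TypeVar

T = TypeVar("T")
_EXHAUSTED = object()  # Sentinel marking the end of the input elements.


def interleave(elements: Sequence[Sequence[T]], cycle_length: int,
               block_length: int = 1) -> List[T]:
    """Deterministic interleave order over per-file record lists."""
    if cycle_length < 1 or block_length < 1:
        raise ValueError("cycle_length and block_length must be >= 1")
    source = iter(elements)
    slots: List[Optional[Iterator[T]]] = [None] * cycle_length
    input_done = False
    out: List[T] = []
    while not input_done or any(s is not None for s in slots):
        for i in range(cycle_length):
            if slots[i] is None:
                if input_done:
                    continue
                nxt = next(source, _EXHAUSTED)
                if nxt is _EXHAUSTED:
                    input_done = True
                    continue
                slots[i] = iter(nxt)  # Open the next file in this slot.
            # Take up to block_length consecutive records from this file.
            for _ in range(block_length):
                try:
                    out.append(next(slots[i]))
                except StopIteration:
                    slots[i] = None  # File exhausted; refill on a later visit.
                    break
    return out
```

For five inputs that each repeat their value six times, cycle_length=2 and block_length=4 give 1,1,1,1, 2,2,2,2, 1,1, 2,2, 3,3,3,3, 4,4,4,4, 3,3, 4,4, then six 5s, matching the example in the tf.data.Dataset.interleave documentation.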
Methods
as_dict
as_dict()
Returns a dict representation of params_dict.ParamsDict.
For the nested params_dict.ParamsDict, a nested dict will be returned.
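The nested-dict shape that as_dict returns can be pictured with Python's standard dataclasses.asdict, used here only as a stand-in that needs no TensorFlow install. OptimizerConfig and TrainerConfig are hypothetical classes; with tfm, one would call as_dict() on the config object itself.

```python
import dataclasses


@dataclasses.dataclass
class OptimizerConfig:
    """Hypothetical nested section."""
    name: str = "sgd"
    learning_rate: float = 0.1


@dataclasses.dataclass
class TrainerConfig:
    """Hypothetical top-level config containing a nested section."""
    train_steps: int = 100
    optimizer: OptimizerConfig = dataclasses.field(
        default_factory=OptimizerConfig)


# Nested sections become nested dicts in the plain-dict view.
plain = dataclasses.asdict(TrainerConfig())
```

Here plain is {'train_steps': 100, 'optimizer': {'name': 'sgd', 'learning_rate': 0.1}}.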
from_args
@classmethod
from_args(
*args, **kwargs
)
Builds a config from the given list of arguments.
from_json
@classmethod
from_json(
file_path: str
)
Wrapper for from_yaml.
from_yaml
@classmethod
from_yaml(
file_path: str
)
get
get(
key, value=None
)
Accesses through built-in dictionary get method.
lock
lock()
Makes the ParamsDict immutable.
override
override(
override_params, is_strict=True
)
Overrides the ParamsDict with a set of given params.
Args
override_params: A dict or a ParamsDict specifying the parameters to be overridden.
is_strict: A boolean specifying whether the override is strict. If True, keys in override_params must be present in the ParamsDict. If False, keys in override_params can differ from those currently defined in the ParamsDict; in this case, the ParamsDict is extended to include the new keys.
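The effect of is_strict can be modeled with plain nested dicts standing in for a ParamsDict. This is a simplified, hypothetical sketch: the function name, the recursive merge of nested sections, and the choice of KeyError for unknown keys are assumptions of the sketch, not a description of the library's code.

```python
import copy
from typing import Any, Dict


def override_params(params: Dict[str, Any], overrides: Dict[str, Any],
                    is_strict: bool = True) -> None:
    """Overrides `params` in place, recursing into nested sections."""
    for key, value in overrides.items():
        if key not in params:
            if is_strict:
                # Strict mode: every overridden key must already exist.
                raise KeyError(f"{key!r} is not defined; "
                               "pass is_strict=False to add it")
            params[key] = copy.deepcopy(value)  # Non-strict: extend.
        elif isinstance(params[key], dict) and isinstance(value, dict):
            override_params(params[key], value, is_strict)  # Merge section.
        else:
            params[key] = copy.deepcopy(value)
```

For params = {"global_batch_size": 32}, override_params(params, {"seed": 7}) raises KeyError, while the same call with is_strict=False adds the new "seed" key.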
replace
replace(
**kwargs
)
Returns an unlocked copy with the given fields overridden, leaving the current config unchanged.
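lock() and replace() work together: a locked config rejects attribute changes, while replace() hands back an unlocked copy with the requested overrides and leaves the original untouched. MiniParams is a hypothetical stand-in written for illustration, and its error types (ValueError when locked, KeyError for unknown fields) are this sketch's own choice.

```python
import copy
from typing import Any


class MiniParams:
    """Hypothetical stand-in showing lock() and replace() semantics."""

    def __init__(self, **fields: Any) -> None:
        object.__setattr__(self, "_locked", False)
        for name, value in fields.items():
            object.__setattr__(self, name, value)

    def __setattr__(self, name: str, value: Any) -> None:
        if self._locked:
            raise ValueError("params are locked; use replace() instead")
        if not hasattr(self, name):
            raise KeyError(f"{name!r} is not a defined field")
        object.__setattr__(self, name, value)

    def lock(self) -> None:
        """Makes this object immutable."""
        object.__setattr__(self, "_locked", True)

    def replace(self, **kwargs: Any) -> "MiniParams":
        new = copy.deepcopy(self)  # The original stays unchanged.
        object.__setattr__(new, "_locked", False)  # The copy is unlocked.
        for name, value in kwargs.items():
            setattr(new, name, value)  # Unknown names raise KeyError.
        return new
```

After p.lock(), p.replace(global_batch_size=64) still succeeds and returns a copy that can be edited further, whereas assigning to p directly raises an error.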
validate
validate()
Validates the consistency of the parameters based on the restrictions.
This method validates internal consistency using the predefined list of restrictions. A restriction is a string that specifies a binary operation. The supported binary operations are {'==', '!=', '<', '<=', '>', '>='}. The meaning of these operators is consistent with the underlying Python implementation, so users should make sure that the restrictions they define make sense for the types of the parameters involved.
For example, for a ParamsDict like the following:
    a:
      a1: 1
      a2: 2
    b:
      bb:
        bb1: 10
        bb2: 20
      ccc:
        a1: 1
        a3: 3
one can define two restrictions like this:
['a.a1 == b.ccc.a1', 'a.a2 <= b.bb.bb2']
These restrictions enforce:
- a.a1 = 1 == b.ccc.a1 = 1
- a.a2 = 2 <= b.bb.bb2 = 20
Raises
KeyError: If either of the following happens: (1) a parameter in any of the restrictions is not defined in the ParamsDict, or (2) an inconsistency violating a restriction is found.
ValueError: If the restriction defined in the string is not supported.
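The restriction check can be sketched over plain nested dicts. The sketch assumes each restriction is written as three whitespace-separated tokens (left name, operator, right name), as in the example, and resolves dotted names by walking nested dicts. The function names are hypothetical, and the library's parser may accept other spellings. Following the documented behavior, it raises KeyError for undefined parameters or violated restrictions and ValueError for unsupported operators.

```python
import operator
from typing import Any, Dict, List

# The six supported binary operations, mapped to Python's own operators.
_OPS = {"==": operator.eq, "!=": operator.ne, "<": operator.lt,
        "<=": operator.le, ">": operator.gt, ">=": operator.ge}


def lookup(params: Dict[str, Any], dotted_name: str) -> Any:
    """Resolves a name such as 'b.ccc.a1' in nested dicts."""
    value: Any = params
    for part in dotted_name.split("."):
        if not isinstance(value, dict) or part not in value:
            raise KeyError(f"{dotted_name!r} is not defined")
        value = value[part]
    return value


def validate(params: Dict[str, Any], restrictions: List[str]) -> None:
    """Checks every restriction, raising on the first failure."""
    for restriction in restrictions:
        tokens = restriction.split()
        if len(tokens) != 3 or tokens[1] not in _OPS:
            raise ValueError(f"Unsupported restriction: {restriction!r}")
        left, op, right = tokens
        if not _OPS[op](lookup(params, left), lookup(params, right)):
            raise KeyError(f"Restriction violated: {restriction!r}")
```

Applied to the example ParamsDict written as nested dicts, validate(params, ['a.a1 == b.ccc.a1', 'a.a2 <= b.bb.bb2']) returns without raising.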
__contains__
__contains__(
key
)
Implements the membership test operator.
__eq__
__eq__(
other
)
Class Variables
IMMUTABLE_TYPES = (<class 'str'>, <class 'int'>, <class 'float'>, <class 'bool'>, <class 'NoneType'>)
RESERVED_ATTR = ['_locked', '_restrictions']
SEQUENCE_TYPES = (<class 'list'>, <class 'tuple'>)
block_length = 1
cache = False
cycle_length = None
default_params = None
deterministic = None
drop_remainder = True
enable_tf_data_service = False
global_batch_size = 0
input_path = ''
is_training = None
restrictions = None
seed = None
sharding = True
shuffle_buffer_size = 100
tf_data_service_address = None
tf_data_service_job_name = None
tfds_as_supervised = False
tfds_data_dir = ''
tfds_name = ''
tfds_skip_decoding_feature = ''
tfds_split = ''