Input reader that returns a tf.data.Dataset instance.
tfm.core.input_reader.InputReader(
    params: tfm.core.config_definitions.DataConfig,
    dataset_fn=tf.data.TFRecordDataset,
    decoder_fn: Optional[Callable[..., Any]] = None,
    combine_fn: Optional[Callable[..., Any]] = None,
    sample_fn: Optional[Callable[..., Any]] = None,
    parser_fn: Optional[Callable[..., Any]] = None,
    filter_fn: Optional[Callable[..., tf.Tensor]] = None,
    transform_and_batch_fn: Optional[Callable[[tf.data.Dataset, Optional[tf.distribute.InputContext]], tf.data.Dataset]] = None,
    postprocess_fn: Optional[Callable[..., Any]] = None
)
Args | |
---|---|
params | A config_definitions.DataConfig object. |
dataset_fn | A tf.data.Dataset class that consumes the input files, for example tf.data.TFRecordDataset. |
decoder_fn | An optional callable that takes a serialized data string and decodes it into a dictionary of raw tensors. |
combine_fn | An optional callable that takes a dictionary of tf.data.Dataset objects as input and outputs a combined dataset. It is executed after decoder_fn and before sample_fn. |
sample_fn | An optional callable that takes a tf.data.Dataset object as input and outputs the transformed dataset. It performs sampling on the decoded raw tensor dictionaries before parser_fn. |
parser_fn | An optional callable that takes the decoded raw tensor dictionary and parses it into a dictionary of tensors that the model can consume. It is executed after decoder_fn. |
filter_fn | An optional callable that maps a dataset element to a boolean. It is executed after parser_fn. |
transform_and_batch_fn | An optional callable that takes a tf.data.Dataset object and an optional tf.distribute.InputContext as input, and returns a tf.data.Dataset object. It is executed after parser_fn to transform and batch the dataset. If None, the dataset is batched to the per-replica batch size after parser_fn is executed. |
postprocess_fn | An optional callable that processes batched tensors. It is executed after batching. |
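
The execution order described in the argument table can be sketched in plain Python. This is only an illustration of the documented ordering (decoder_fn, then sample_fn, then parser_fn, then filter_fn, then batching, then postprocess_fn), not the library's implementation: `run_stages` is a hypothetical helper, Python iterables stand in for tf.data.Dataset objects, combine_fn is left out because the sketch has a single input, and keeping the final partial batch is a choice made here for simplicity.

```python
import itertools
from typing import Any, Callable, Iterable, Iterator, List, Optional


def run_stages(
    records: Iterable[Any],
    decoder_fn: Optional[Callable[[Any], Any]] = None,
    sample_fn: Optional[Callable[[Iterable[Any]], Iterable[Any]]] = None,
    parser_fn: Optional[Callable[[Any], Any]] = None,
    filter_fn: Optional[Callable[[Any], bool]] = None,
    batch_size: int = 2,
    postprocess_fn: Optional[Callable[[List[Any]], Any]] = None,
) -> Iterator[Any]:
    """Hypothetical helper: applies the stages in the documented order."""
    stream: Iterable[Any] = records
    if decoder_fn is not None:
        stream = map(decoder_fn, stream)    # element-wise decode
    if sample_fn is not None:
        stream = sample_fn(stream)          # whole stream in, whole stream out
    if parser_fn is not None:
        stream = map(parser_fn, stream)     # element-wise parse
    if filter_fn is not None:
        stream = filter(filter_fn, stream)  # keep elements mapped to True
    batch: List[Any] = []
    for element in stream:
        batch.append(element)
        if len(batch) == batch_size:
            yield postprocess_fn(batch) if postprocess_fn else batch
            batch = []
    if batch:  # this sketch keeps the final partial batch
        yield postprocess_fn(batch) if postprocess_fn else batch


# Usage: six "serialized" records flow through every stage.
batches = list(
    run_stages(
        ["1", "2", "3", "4", "5", "6"],
        decoder_fn=lambda s: {"raw": int(s)},
        sample_fn=lambda ds: itertools.islice(ds, 5),  # keep the first five
        parser_fn=lambda d: {"x": d["raw"] * 10},
        filter_fn=lambda e: e["x"] != 30,
        batch_size=2,
        postprocess_fn=lambda b: [e["x"] for e in b],
    )
)
```

Note that sample_fn receives the whole stream, while decoder_fn, parser_fn, and filter_fn see one element at a time, which mirrors the distinction the table draws between dataset-level and element-level callables.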
Methods
get_files
get_files(
    input_path
)
Gets matched files. Can be overridden by subclasses.
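The matching rules are not spelled out here, so the following is only a rough sketch of what "gets matched files" could look like, assuming shell-style glob patterns on a local filesystem. `match_files` is a hypothetical standalone helper built on the standard library, not the method itself; a subclass overriding get_files could return a list of paths in a similar way.

```python
import glob
import os
import tempfile
from typing import List, Sequence, Union


def match_files(input_path: Union[str, Sequence[str]]) -> List[str]:
    """Hypothetical helper: expands one or more glob patterns to sorted paths."""
    patterns = [input_path] if isinstance(input_path, str) else list(input_path)
    matched: List[str] = []
    for pattern in patterns:
        matched.extend(glob.glob(pattern))
    # De-duplicate in case several patterns match the same file.
    return sorted(set(matched))


# Usage: create throwaway shard files and match them with a pattern.
with tempfile.TemporaryDirectory() as tmp_dir:
    for name in ("train-00000.tfrecord", "train-00001.tfrecord", "notes.txt"):
        open(os.path.join(tmp_dir, name), "w").close()
    shards = match_files(os.path.join(tmp_dir, "train-*.tfrecord"))
    shard_names = [os.path.basename(path) for path in shards]
```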
read
read(
    input_context: Optional[tf.distribute.InputContext] = None,
    dataset: Optional[tf.data.Dataset] = None
) -> tf.data.Dataset
Generates a tf.data.Dataset object.
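When no transform_and_batch_fn is supplied, the dataset is batched to the per-replica batch size. A minimal sketch of that arithmetic follows, assuming a global batch size (taken here to come from the data config) that is split evenly across replicas, in the same spirit as tf.distribute.InputContext.get_per_replica_batch_size. `per_replica_batch_size` is a hypothetical helper written in plain Python for illustration.

```python
def per_replica_batch_size(global_batch_size: int, num_replicas_in_sync: int) -> int:
    """Hypothetical helper: splits a global batch evenly across replicas."""
    if num_replicas_in_sync <= 0:
        raise ValueError("num_replicas_in_sync must be positive")
    if global_batch_size % num_replicas_in_sync != 0:
        raise ValueError(
            f"global batch size {global_batch_size} is not divisible by "
            f"{num_replicas_in_sync} replicas"
        )
    return global_batch_size // num_replicas_in_sync


# Usage: a global batch of 128 split across 8 replicas.
local_batch = per_replica_batch_size(global_batch_size=128, num_replicas_in_sync=8)
```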
Class Variables | |
---|---|
static_randnum | 1021534207 |