Object that returns a tf.data.Dataset when invoked.
tf.keras.utils.experimental.DatasetCreator(dataset_fn, input_options=None)
tf.keras.utils.experimental.DatasetCreator is designated as a supported type for x, the input, in tf.keras.Model.fit. Pass an instance of this class to fit when using a callable (with an input_context argument) that returns a tf.data.Dataset.
import tensorflow as tf

model = tf.keras.Sequential([tf.keras.layers.Dense(10)])
model.compile(tf.keras.optimizers.SGD(), loss="mse")

def dataset_fn(input_context):
    global_batch_size = 64
    # Split the global batch evenly across the replicas in sync.
    batch_size = input_context.get_per_replica_batch_size(global_batch_size)
    dataset = tf.data.Dataset.from_tensors(([1.], [1.])).repeat()
    # Give each input pipeline its own shard of the data.
    dataset = dataset.shard(
        input_context.num_input_pipelines, input_context.input_pipeline_id)
    dataset = dataset.batch(batch_size)
    dataset = dataset.prefetch(2)
    return dataset

input_options = tf.distribute.InputOptions(
    experimental_fetch_to_device=True,
    experimental_per_replica_buffer_size=2)
model.fit(tf.keras.utils.experimental.DatasetCreator(
    dataset_fn, input_options=input_options), epochs=10, steps_per_epoch=10)
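The batch-size and sharding calls in dataset_fn above come down to simple arithmetic. The pure-Python sketch below mirrors that arithmetic without TensorFlow. per_replica_batch_size and shard are hypothetical helpers written for illustration; the sketch assumes get_per_replica_batch_size divides the global batch evenly by num_replicas_in_sync (raising an error otherwise) and that Dataset.shard(num_shards, index) keeps every element whose position modulo num_shards equals index.

```python
def per_replica_batch_size(global_batch_size: int, num_replicas_in_sync: int) -> int:
    """Hypothetical stand-in for InputContext.get_per_replica_batch_size."""
    if global_batch_size % num_replicas_in_sync != 0:
        raise ValueError(
            f"global_batch_size {global_batch_size} is not divisible by "
            f"num_replicas_in_sync {num_replicas_in_sync}")
    return global_batch_size // num_replicas_in_sync


def shard(elements, num_shards: int, index: int) -> list:
    """Hypothetical stand-in for Dataset.shard: keep the elements whose
    position modulo num_shards equals index."""
    return [e for i, e in enumerate(elements) if i % num_shards == index]


# Assumed setup: 2 input pipelines, 4 replicas in total, global batch of 64.
num_input_pipelines = 2
num_replicas = 4
batch = per_replica_batch_size(64, num_replicas)
records = list(range(10))
shards = [shard(records, num_input_pipelines, pid)
          for pid in range(num_input_pipelines)]
print(batch)   # 16
print(shards)  # [[0, 2, 4, 6, 8], [1, 3, 5, 7, 9]]
```

With these numbers, each replica processes 16 of the 64 examples in a global batch, and the two pipelines read disjoint halves of the records.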
Model.fit usage with DatasetCreator is intended to work across all tf.distribute.Strategy types, as long as the model is created within Strategy.scope:
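import tensorflow as tf

# `cluster_resolver` must already describe the cluster (workers and
# parameter servers).
strategy = tf.distribute.experimental.ParameterServerStrategy(
    cluster_resolver)
with strategy.scope():
    model = tf.keras.Sequential([tf.keras.layers.Dense(10)])
    model.compile(tf.keras.optimizers.SGD(), loss="mse")
...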
Args | |
---|---|
dataset_fn | A callable that takes a single argument of type tf.distribute.InputContext, which is used for batch size calculation and cross-worker input pipeline sharding (if neither is needed, the InputContext parameter can be ignored in the dataset_fn), and returns a tf.data.Dataset. |
input_options | Optional tf.distribute.InputOptions, used to set distribution-specific options, for example whether to prefetch dataset elements to accelerator device memory or host device memory, and the prefetch buffer size in the replica device memory. Has no effect when not used with distributed training. See tf.distribute.InputOptions for more information. |
Methods
__call__
__call__(*args, **kwargs)
Call self as a function.
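Per the class summary, invoking a DatasetCreator returns a tf.data.Dataset. The pure-Python model below illustrates that contract under stated assumptions: it assumes __call__ forwards its arguments unchanged to dataset_fn, and that a non-callable dataset_fn or a non-dataset return value is rejected with TypeError. SimpleDatasetCreator and FakeDataset are hypothetical names; this is not the TensorFlow implementation.

```python
from typing import Any, Callable, Optional


class FakeDataset:
    """Hypothetical placeholder for tf.data.Dataset in this sketch."""

    def __init__(self, elements):
        self.elements = list(elements)


class SimpleDatasetCreator:
    """Hypothetical, simplified model of DatasetCreator (not the TF source)."""

    def __init__(self, dataset_fn: Callable[..., FakeDataset],
                 input_options: Optional[Any] = None):
        # Assumed validation: dataset_fn must be callable.
        if not callable(dataset_fn):
            raise TypeError("dataset_fn must be a callable.")
        self.dataset_fn = dataset_fn
        # Stored for the trainer; only relevant under distributed training.
        self.input_options = input_options

    def __call__(self, *args, **kwargs) -> FakeDataset:
        # Assumed behavior: forward the caller's arguments (for example an
        # InputContext) to dataset_fn and check that a dataset comes back.
        dataset = self.dataset_fn(*args, **kwargs)
        if not isinstance(dataset, FakeDataset):
            raise TypeError("dataset_fn must return a dataset.")
        return dataset


def dataset_fn(input_context):
    del input_context  # No batch-size or sharding logic needed here.
    return FakeDataset([1, 2, 3])


creator = SimpleDatasetCreator(dataset_fn)
ds = creator(None)  # A trainer would pass a real InputContext instead of None.
print(ds.elements)  # [1, 2, 3]
```

Here dataset_fn ignores its argument, which the Args table allows when neither batch size calculation nor sharding is needed.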