Object that returns a tf.data.Dataset when invoked.
tf.keras.utils.experimental.DatasetCreator(
    dataset_fn, input_options=None
)
tf.keras.utils.experimental.DatasetCreator is designated as a supported type
for x, or the input, in tf.keras.Model.fit. Pass an instance of this class
to fit when using a callable (with an input_context argument) that returns
a tf.data.Dataset.
model = tf.keras.Sequential([tf.keras.layers.Dense(10)])
model.compile(tf.keras.optimizers.SGD(), loss="mse")

def dataset_fn(input_context):
  global_batch_size = 64
  batch_size = input_context.get_per_replica_batch_size(global_batch_size)
  dataset = tf.data.Dataset.from_tensors(([1.], [1.])).repeat()
  dataset = dataset.shard(
      input_context.num_input_pipelines, input_context.input_pipeline_id)
  dataset = dataset.batch(batch_size)
  dataset = dataset.prefetch(2)
  return dataset

input_options = tf.distribute.InputOptions(
    experimental_fetch_to_device=True,
    experimental_per_replica_buffer_size=2)

model.fit(tf.keras.utils.experimental.DatasetCreator(
    dataset_fn, input_options=input_options), epochs=10, steps_per_epoch=10)
Model.fit usage with DatasetCreator is intended to work across all
tf.distribute.Strategy types, as long as Strategy.scope is used at model
creation:
strategy = tf.distribute.experimental.ParameterServerStrategy(
    cluster_resolver)
with strategy.scope():
  model = tf.keras.Sequential([tf.keras.layers.Dense(10)])
  model.compile(tf.keras.optimizers.SGD(), loss="mse")

def dataset_fn(input_context):
  ...

input_options = ...

model.fit(tf.keras.utils.experimental.DatasetCreator(
    dataset_fn, input_options=input_options), epochs=10, steps_per_epoch=10)
| Args | |
|---|---|
| dataset_fn | A callable that takes a single argument of type tf.distribute.InputContext, which is used for batch size calculation and cross-worker input pipeline sharding (if neither is needed, the InputContext parameter can be ignored in the dataset_fn), and returns a tf.data.Dataset. A minimal sketch of the ignored-InputContext case follows this table. |
| input_options | Optional tf.distribute.InputOptions, used for specific options when used with distribution, for example, whether to prefetch dataset elements to accelerator device memory or host device memory, and the prefetch buffer size in the replica device memory. No effect if not used with distributed training. See tf.distribute.InputOptions for more information. |
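As the dataset_fn description notes, the InputContext argument can be ignored when neither per-replica batch size calculation nor cross-worker sharding is needed. The following is a minimal sketch of that case; the dataset contents, batch size, and fit arguments are illustrative, not part of the API.

import tensorflow as tf

def dataset_fn(_):
  # The InputContext argument is unused: no per-replica batch size or
  # sharding is derived from it in this sketch.
  return tf.data.Dataset.from_tensors(([1.], [1.])).repeat().batch(32)

model = tf.keras.Sequential([tf.keras.layers.Dense(10)])
model.compile(tf.keras.optimizers.SGD(), loss="mse")
model.fit(tf.keras.utils.experimental.DatasetCreator(dataset_fn),
          epochs=1, steps_per_epoch=5)

Since no input_options are passed, defaults apply; outside of distributed training the options would have no effect in any case.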
Methods
__call__
__call__(
    *args, **kwargs
)
Call self as a function.
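A minimal sketch of invoking the object directly, assuming that the __call__ arguments are simply forwarded to dataset_fn; normally Model.fit performs this call for you.

creator = tf.keras.utils.experimental.DatasetCreator(dataset_fn)
# Invoking the object returns the tf.data.Dataset built by dataset_fn; the
# dataset_fn sketched earlier ignores its argument, so None is acceptable here.
dataset = creator(None)
print(dataset.element_spec)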