Object that returns a `tf.data.Dataset` upon invoking.
```python
tf.keras.utils.experimental.DatasetCreator(
    dataset_fn
)
```
`tf.keras.utils.experimental.DatasetCreator` is designated as a supported type for `x`, the input, in `tf.keras.Model.fit`. Pass an instance of this class to `fit` when using a callable (with an `input_context` argument) that returns a `tf.data.Dataset`.
```python
model = tf.keras.Sequential([tf.keras.layers.Dense(10)])
model.compile(tf.keras.optimizers.SGD(), loss="mse")

def dataset_fn(input_context):
  global_batch_size = 64
  # Scale the global batch size down to the per-replica batch size.
  batch_size = input_context.get_per_replica_batch_size(global_batch_size)
  dataset = tf.data.Dataset.from_tensors(([1.], [1.])).repeat()
  # Shard across input pipelines so each worker reads a disjoint subset.
  dataset = dataset.shard(
      input_context.num_input_pipelines, input_context.input_pipeline_id)
  dataset = dataset.batch(batch_size)
  dataset = dataset.prefetch(2)
  return dataset

model.fit(tf.keras.utils.experimental.DatasetCreator(dataset_fn),
          epochs=10, steps_per_epoch=10)
```
`Model.fit` usage with `DatasetCreator` is intended to work across all `tf.distribute.Strategy`s, as long as `Strategy.scope` is used at model creation:
```python
strategy = tf.distribute.experimental.ParameterServerStrategy(
    cluster_resolver)
with strategy.scope():
  model = tf.keras.Sequential([tf.keras.layers.Dense(10)])
  model.compile(tf.keras.optimizers.SGD(), loss="mse")
...
```
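The elided portion would include the training call itself. As a minimal sketch, reusing the `dataset_fn` from the first example and assuming a running parameter-server cluster behind `cluster_resolver`, it might look like:

```python
# A minimal sketch, assuming the `dataset_fn` defined in the first example.
# As in that example, `steps_per_epoch` is passed explicitly, since the
# dataset is constructed per worker and its size is not known up front.
model.fit(tf.keras.utils.experimental.DatasetCreator(dataset_fn),
          epochs=10, steps_per_epoch=10)
```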
| Args | |
|---|---|
| `dataset_fn` | A callable that takes a single argument of type `tf.distribute.InputContext`, which is used for batch size calculation and cross-worker input pipeline sharding (if neither is needed, the `InputContext` parameter can be ignored in the `dataset_fn`), and returns a `tf.data.Dataset`. |
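When neither batch size calculation nor sharding is needed, the `InputContext` argument can simply be ignored. A minimal sketch (the name `simple_dataset_fn` is just for illustration):

```python
# A minimal sketch of a dataset_fn that ignores its InputContext argument:
# every input pipeline reads the full dataset with a fixed batch size.
def simple_dataset_fn(input_context):
  del input_context  # Unused: no per-replica batching or sharding.
  return tf.data.Dataset.from_tensors(([1.], [1.])).repeat().batch(64)
```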
Methods

`__call__`

```python
__call__(
    *args, **kwargs
)
```

Call `self` as a function.
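As the `*args, **kwargs` signature suggests, invoking the instance appears to forward its arguments to the wrapped `dataset_fn` and return the resulting `tf.data.Dataset`. A minimal sketch, assuming the `dataset_fn` from the first example:

```python
# A minimal sketch, assuming the `dataset_fn` from the first example.
# A default tf.distribute.InputContext has one input pipeline and one replica.
creator = tf.keras.utils.experimental.DatasetCreator(dataset_fn)
dataset = creator(tf.distribute.InputContext())  # a tf.data.Dataset
```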