Object that returns a tf.data.Dataset when invoked.
tf.keras.utils.experimental.DatasetCreator(
    dataset_fn
)
tf.keras.utils.experimental.DatasetCreator is a supported type for x, the
input, in tf.keras.Model.fit. Pass an instance of this class to fit when
using a callable (with an input_context argument) that returns a
tf.data.Dataset.
model = tf.keras.Sequential([tf.keras.layers.Dense(10)])
model.compile(tf.keras.optimizers.SGD(), loss="mse")

def dataset_fn(input_context):
  global_batch_size = 64
  # Scale the global batch size down to the per-replica batch size.
  batch_size = input_context.get_per_replica_batch_size(global_batch_size)
  dataset = tf.data.Dataset.from_tensors(([1.], [1.])).repeat()
  # Shard the dataset so each input pipeline reads a distinct slice.
  dataset = dataset.shard(
      input_context.num_input_pipelines, input_context.input_pipeline_id)
  dataset = dataset.batch(batch_size)
  dataset = dataset.prefetch(2)
  return dataset

model.fit(tf.keras.utils.experimental.DatasetCreator(dataset_fn),
          epochs=10, steps_per_epoch=10)
Model.fit usage with DatasetCreator is intended to work across all
tf.distribute.Strategy types, as long as Strategy.scope is used at model
creation:
strategy = tf.distribute.experimental.ParameterServerStrategy(
    cluster_resolver)
with strategy.scope():
  model = tf.keras.Sequential([tf.keras.layers.Dense(10)])
  model.compile(tf.keras.optimizers.SGD(), loss="mse")
...
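For a runnable end-to-end illustration, here is a minimal sketch of the same
pattern under a strategy that needs no cluster setup; the choice of
tf.distribute.MirroredStrategy and the small epochs/steps values are
illustrative assumptions, not part of the API:

import tensorflow as tf

strategy = tf.distribute.MirroredStrategy()
with strategy.scope():
  model = tf.keras.Sequential([tf.keras.layers.Dense(10)])
  model.compile(tf.keras.optimizers.SGD(), loss="mse")

def dataset_fn(input_context):
  # The same input_context-driven batching applies under any strategy.
  batch_size = input_context.get_per_replica_batch_size(64)
  return (tf.data.Dataset.from_tensors(([1.], [1.]))
          .repeat()
          .batch(batch_size)
          .prefetch(2))

model.fit(tf.keras.utils.experimental.DatasetCreator(dataset_fn),
          epochs=2, steps_per_epoch=5)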
| Args | |
|---|---|
| dataset_fn | A callable that takes a single argument of type tf.distribute.InputContext, which is used for batch size calculation and cross-worker input pipeline sharding (if neither is needed, the InputContext parameter can be ignored in the dataset_fn), and returns a tf.data.Dataset. |
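When neither batch size calculation nor sharding is needed, the callable can
simply leave its input_context argument unused. A minimal sketch:

def dataset_fn(input_context):
  # input_context is accepted but unused: no per-replica batch size
  # calculation and no cross-worker sharding is performed.
  return tf.data.Dataset.from_tensors(([1.], [1.])).repeat().batch(8)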
Methods
__call__
__call__(
    *args, **kwargs
)
Call self as a function.
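Invoking a DatasetCreator forwards its arguments to the wrapped dataset_fn and
returns the resulting tf.data.Dataset. Normally Model.fit does this for you;
the sketch below invokes the creator directly, and constructing a default
InputContext for standalone use is an assumption for illustration:

creator = tf.keras.utils.experimental.DatasetCreator(dataset_fn)
# Calling the creator invokes dataset_fn; a default tf.distribute.InputContext
# describes a single, unsharded input pipeline.
dataset = creator(tf.distribute.InputContext())
print(dataset.element_spec)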