Object that returns a tf.data.Dataset upon invoking.
tf.keras.utils.experimental.DatasetCreator(
    dataset_fn, input_options=None
)
tf.keras.utils.experimental.DatasetCreator is designated as a supported type
for x, or the input, in tf.keras.Model.fit. Pass an instance of this class
to fit when using a callable (with an input_context argument) that returns
a tf.data.Dataset.
import tensorflow as tf

model = tf.keras.Sequential([tf.keras.layers.Dense(10)])
model.compile(tf.keras.optimizers.SGD(), loss="mse")
def dataset_fn(input_context):
  global_batch_size = 64
  batch_size = input_context.get_per_replica_batch_size(global_batch_size)
  dataset = tf.data.Dataset.from_tensors(([1.], [1.])).repeat()
  dataset = dataset.shard(
      input_context.num_input_pipelines, input_context.input_pipeline_id)
  dataset = dataset.batch(batch_size)
  dataset = dataset.prefetch(2)
  return dataset
input_options = tf.distribute.InputOptions(
    experimental_fetch_to_device=True,
    experimental_per_replica_buffer_size=2)
model.fit(tf.keras.utils.experimental.DatasetCreator(
    dataset_fn, input_options=input_options), epochs=10, steps_per_epoch=10)
Model.fit usage with DatasetCreator is intended to work across all
tf.distribute.Strategy implementations, as long as Strategy.scope is used
at model creation:
strategy = tf.distribute.experimental.ParameterServerStrategy(
    cluster_resolver)
with strategy.scope():
  model = tf.keras.Sequential([tf.keras.layers.Dense(10)])
model.compile(tf.keras.optimizers.SGD(), loss="mse")
...
| Args | |
|---|---|
| dataset_fn | A callable that takes a single argument of type tf.distribute.InputContext, which is used for batch size calculation and cross-worker input pipeline sharding (if neither is needed, the InputContext parameter can be ignored in the dataset_fn), and returns a tf.data.Dataset. |
| input_options | Optional tf.distribute.InputOptions, used for specific options when used with distribution, for example, whether to prefetch dataset elements to accelerator device memory or host device memory, and prefetch buffer size in the replica device memory. No effect if not used with distributed training. See tf.distribute.InputOptions for more information. |
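The batch size calculation and sharding that dataset_fn performs through InputContext can be sketched in plain Python. This is a minimal illustration of the arithmetic only, not TensorFlow's implementation: per_replica_batch_size and shard are hypothetical helper names, and the worker and replica counts are assumed values.

```python
def per_replica_batch_size(global_batch_size, num_replicas_in_sync):
    """Split a global batch evenly across replicas (hypothetical helper)."""
    if global_batch_size % num_replicas_in_sync != 0:
        raise ValueError(
            "global batch size must be divisible by the replica count")
    return global_batch_size // num_replicas_in_sync


def shard(elements, num_input_pipelines, input_pipeline_id):
    """Keep elements whose index modulo the pipeline count matches the id."""
    return [e for i, e in enumerate(elements)
            if i % num_input_pipelines == input_pipeline_id]


# Assumed setup: 2 input pipelines (workers) and 4 replicas in sync.
batch_size = per_replica_batch_size(64, 4)
worker_0 = shard(list(range(8)), 2, 0)
worker_1 = shard(list(range(8)), 2, 1)
```

Under these assumptions, each input pipeline sees a disjoint slice of the data, and each replica receives an equal share of the global batch.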
Methods
__call__
__call__(
    *args, **kwargs
)
Call self as a function.
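As a rough illustration of the call behavior, the following plain-Python stand-in wraps a callable and forwards its arguments when invoked. CreatorSketch and toy_dataset_fn are hypothetical names, not the Keras implementation, and the sketch assumes the call is passed straight through to dataset_fn; the returned list stands in for a tf.data.Dataset.

```python
class CreatorSketch:
    """Hypothetical stand-in for a callable dataset wrapper."""

    def __init__(self, dataset_fn, input_options=None):
        if not callable(dataset_fn):
            raise TypeError("dataset_fn must be callable")
        self.dataset_fn = dataset_fn
        self.input_options = input_options

    def __call__(self, *args, **kwargs):
        # Forward all arguments unchanged to the wrapped function.
        return self.dataset_fn(*args, **kwargs)


def toy_dataset_fn(input_context):
    # input_context is ignored here; a list stands in for a dataset.
    return [1.0, 2.0, 3.0]


creator = CreatorSketch(toy_dataset_fn)
result = creator(None)
```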