tf.keras.utils.experimental.DatasetCreator

Object that returns a tf.data.Dataset when invoked.

tf.keras.utils.experimental.DatasetCreator is a supported type for x (the input) in tf.keras.Model.fit. Pass an instance of this class to fit when the input is produced by a callable that takes an input_context argument and returns a tf.data.Dataset.

import tensorflow as tf

model = tf.keras.Sequential([tf.keras.layers.Dense(10)])
model.compile(tf.keras.optimizers.SGD(), loss="mse")

def dataset_fn(input_context):
  global_batch_size = 64
  batch_size = input_context.get_per_replica_batch_size(global_batch_size)
  dataset = tf.data.Dataset.from_tensors(([1.], [1.])).repeat()
  dataset = dataset.shard(
      input_context.num_input_pipelines, input_context.input_pipeline_id)
  dataset = dataset.batch(batch_size)
  dataset = dataset.prefetch(2)
  return dataset

input_options = tf.distribute.InputOptions(
    experimental_fetch_to_device=True,
    experimental_per_replica_buffer_size=2)
model.fit(tf.keras.utils.experimental.DatasetCreator(
    dataset_fn, input_options=input_options), epochs=10, steps_per_epoch=10)
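To inspect what a dataset_fn yields for each input pipeline without running fit, one option is to build tf.distribute.InputContext objects by hand and call the function directly. The following is a minimal sketch, not part of the original example: it assumes 2 input pipelines and 4 replicas in sync, and it swaps the constant dataset for tf.data.Dataset.range so that the effect of sharding is visible.

```python
import tensorflow as tf

GLOBAL_BATCH_SIZE = 64  # same value as in the example above

def dataset_fn(input_context):
  # Per-replica batch size = global batch size / number of replicas in sync.
  batch_size = input_context.get_per_replica_batch_size(GLOBAL_BATCH_SIZE)
  # Assumption for illustration: distinct elements (0, 1, 2, ...) instead of
  # a repeated constant, so the result of shard() can be observed.
  dataset = tf.data.Dataset.range(1000)
  dataset = dataset.shard(
      input_context.num_input_pipelines, input_context.input_pipeline_id)
  return dataset.batch(batch_size)

# Hand-built contexts standing in for what a strategy would pass in.
contexts = [
    tf.distribute.InputContext(
        num_input_pipelines=2, input_pipeline_id=i, num_replicas_in_sync=4)
    for i in range(2)
]
first_batches = [next(iter(dataset_fn(ctx))).numpy() for ctx in contexts]
print(first_batches[0][:4])  # pipeline 0 sees every second element from 0
print(first_batches[1][:4])  # pipeline 1 sees every second element from 1
```

Each pipeline receives a disjoint slice of the data, and each batch holds 64 / 4 = 16 elements.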

Model.fit usage with DatasetCreator is intended to work across all tf.distribute.Strategy implementations, as long as Strategy.scope is used at model creation:

strategy = tf.distribute.experimental.ParameterServerStrategy(
    cluster_resolver)
with strategy.scope():
  model = tf.keras.Sequential([tf.keras.layers.Dense(10)])
model.compile(tf.keras.optimizers.SGD(), loss="mse")
...

Args

dataset_fn: A callable that takes a single argument of type tf.distribute.InputContext and returns a tf.data.Dataset. The InputContext is used for batch size calculation and cross-worker input pipeline sharding; if neither is needed, dataset_fn can ignore the argument.
input_options: Optional tf.distribute.InputOptions carrying distribution-specific options, for example whether to prefetch dataset elements to accelerator device memory or host device memory, and the prefetch buffer size in replica device memory. Has no effect outside distributed training. See tf.distribute.InputOptions for more information.
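As a rough illustration of the input_options choices described above, the sketch below builds two tf.distribute.InputOptions objects: one that prefetches to accelerator memory with a per-replica buffer size of 2, and one that keeps elements in host memory. The variable names device_options and host_options are hypothetical.

```python
import tensorflow as tf

# Prefetch elements to the accelerator, with a prefetch buffer size of 2
# on each replica device.
device_options = tf.distribute.InputOptions(
    experimental_fetch_to_device=True,
    experimental_per_replica_buffer_size=2)

# Keep dataset elements in host memory instead of fetching to the device.
host_options = tf.distribute.InputOptions(
    experimental_fetch_to_device=False)

print(device_options.experimental_per_replica_buffer_size)
print(host_options.experimental_fetch_to_device)
```

Either object can be passed as the input_options argument of DatasetCreator. Outside distributed training the setting is ignored.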

Methods

__call__

Call self as a function.