tf.keras.utils.experimental.DatasetCreator

Object that returns a tf.data.Dataset when called.

tf.keras.utils.experimental.DatasetCreator is a supported type for x, the input argument of tf.keras.Model.fit. Pass an instance of this class to fit when you have a callable that takes an input_context argument and returns a tf.data.Dataset.

import tensorflow as tf

model = tf.keras.Sequential([tf.keras.layers.Dense(10)])
model.compile(tf.keras.optimizers.SGD(), loss="mse")

def dataset_fn(input_context):
  global_batch_size = 64
  # Per-replica batch size = global batch size / number of replicas in sync.
  batch_size = input_context.get_per_replica_batch_size(global_batch_size)
  dataset = tf.data.Dataset.from_tensors(([1.], [1.])).repeat()
  # Shard so that each input pipeline reads a distinct subset of the data.
  dataset = dataset.shard(
      input_context.num_input_pipelines, input_context.input_pipeline_id)
  dataset = dataset.batch(batch_size)
  dataset = dataset.prefetch(2)
  return dataset

# The dataset repeats forever, so steps_per_epoch bounds each epoch.
model.fit(tf.keras.utils.experimental.DatasetCreator(dataset_fn),
          epochs=10, steps_per_epoch=10)
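To see what the InputContext arithmetic in dataset_fn produces, the sketch below builds a tf.distribute.InputContext by hand and calls dataset_fn directly, which is what invoking a DatasetCreator amounts to. The topology (2 input pipelines, 4 replicas in sync) is hypothetical and chosen only for illustration; under a real strategy these values come from the cluster.

```python
import tensorflow as tf

def dataset_fn(input_context):
  global_batch_size = 64
  # Split the global batch evenly: 64 // num_replicas_in_sync.
  batch_size = input_context.get_per_replica_batch_size(global_batch_size)
  dataset = tf.data.Dataset.from_tensors(([1.], [1.])).repeat()
  # This pipeline keeps every num_input_pipelines-th element, starting at its id.
  dataset = dataset.shard(
      input_context.num_input_pipelines, input_context.input_pipeline_id)
  dataset = dataset.batch(batch_size)
  return dataset.prefetch(2)

# Hypothetical topology: 2 input pipelines, 4 replicas in sync overall.
context = tf.distribute.InputContext(
    num_input_pipelines=2, input_pipeline_id=1, num_replicas_in_sync=4)

dataset = dataset_fn(context)
features, labels = next(iter(dataset))
print(context.get_per_replica_batch_size(64))  # 16
print(features.shape)  # (16, 1)
```

Each replica therefore receives batches of 16, and the 4 replicas together consume the global batch of 64.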

Model.fit usage with DatasetCreator is intended to work with all tf.distribute.Strategy implementations, as long as the model is created inside Strategy.scope():

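import tensorflow as tf

# cluster_resolver is assumed to be a tf.distribute.cluster_resolver.ClusterResolver
# (for example, a TFConfigClusterResolver) describing the worker and ps tasks.
strategy = tf.distribute.experimental.ParameterServerStrategy(
    cluster_resolver)
with strategy.scope():
  # Model variables must be created under the strategy's scope.
  model = tf.keras.Sequential([tf.keras.layers.Dense(10)])
model.compile(tf.keras.optimizers.SGD(), loss="mse")
...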
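ParameterServerStrategy needs a running cluster, so the sketch below swaps in tf.distribute.MirroredStrategy, which runs on a single machine (all local GPUs, or the CPU if there are none). Rather than calling Model.fit, it calls strategy.distribute_datasets_from_function directly to record which InputContext the strategy hands to dataset_fn; treating this as comparable to what fit does with a DatasetCreator is an assumption. The seen_contexts list is a hypothetical helper, and the global batch of 64 is assumed to divide evenly by the local replica count.

```python
import tensorflow as tf

# MirroredStrategy replicates across all local GPUs, or uses the CPU if none.
strategy = tf.distribute.MirroredStrategy()
seen_contexts = []  # hypothetical helper: records each InputContext passed in

def dataset_fn(input_context):
  seen_contexts.append(input_context)
  batch_size = input_context.get_per_replica_batch_size(64)
  dataset = tf.data.Dataset.from_tensors(([1.], [1.])).repeat()
  dataset = dataset.shard(
      input_context.num_input_pipelines, input_context.input_pipeline_id)
  return dataset.batch(batch_size).prefetch(2)

dist_dataset = strategy.distribute_datasets_from_function(dataset_fn)
next(iter(dist_dataset))  # pull one distributed batch through the pipeline

ctx = seen_contexts[0]
print(ctx.num_input_pipelines, ctx.input_pipeline_id, ctx.num_replicas_in_sync)
```

On a single machine there is one input pipeline (id 0), and num_replicas_in_sync matches strategy.num_replicas_in_sync, so get_per_replica_batch_size divides the global batch among the local devices.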

Args

dataset_fn: A callable that takes a single argument of type tf.distribute.InputContext and returns a tf.data.Dataset. The InputContext supplies what is needed for batch size calculation and cross-worker input pipeline sharding; if neither is needed, dataset_fn can ignore it.
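When neither per-replica batch sizing nor sharding is needed, dataset_fn still has to accept the InputContext argument but can discard it. The function name simple_dataset_fn and the 32-example toy data below are hypothetical; the function is called here with a default InputContext (1 pipeline, 1 replica) only to exercise it.

```python
import tensorflow as tf

def simple_dataset_fn(input_context):
  # The argument is required by the DatasetCreator contract but deliberately
  # unused: no per-replica batch sizing and no sharding.
  del input_context
  features = tf.ones([32, 1])
  labels = tf.ones([32, 1])
  return tf.data.Dataset.from_tensor_slices((features, labels)).batch(8).repeat()

dataset = simple_dataset_fn(tf.distribute.InputContext())
x, y = next(iter(dataset))
print(x.shape)  # (8, 1)
```

Because nothing is sharded, every input pipeline that calls this function sees the same elements in the same order.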

Methods

__call__

Call self as a function: calls dataset_fn with the given arguments and returns the resulting tf.data.Dataset.