# tf.keras.utils.experimental.DatasetCreator

[View source on GitHub](https://github.com/tensorflow/tensorflow/blob/v2.5.0/tensorflow/python/keras/utils/dataset_creator.py#L23-L84)

Object that returns a [`tf.data.Dataset`](../../../../tf/data/Dataset) upon invoking.

    tf.keras.utils.experimental.DatasetCreator(
        dataset_fn
    )

[`tf.keras.utils.experimental.DatasetCreator`](../../../../tf/keras/utils/experimental/DatasetCreator) is designated as a supported type for `x`, the input, in [`tf.keras.Model.fit`](../../../../tf/keras/Model#fit). Pass an instance of this class to `fit` when using a callable (with an `input_context` argument) that returns a [`tf.data.Dataset`](../../../../tf/data/Dataset).

    model = tf.keras.Sequential([tf.keras.layers.Dense(10)])
    model.compile(tf.keras.optimizers.SGD(), loss="mse")

    def dataset_fn(input_context):
      global_batch_size = 64
      batch_size = input_context.get_per_replica_batch_size(global_batch_size)
      dataset = tf.data.Dataset.from_tensors(([1.], [1.])).repeat()
      dataset = dataset.shard(
          input_context.num_input_pipelines, input_context.input_pipeline_id)
      dataset = dataset.batch(batch_size)
      dataset = dataset.prefetch(2)
      return dataset

    model.fit(DatasetCreator(dataset_fn), epochs=10, steps_per_epoch=10)

[`Model.fit`](../../../../tf/keras/Model#fit) usage with `DatasetCreator` is intended to work across all [`tf.distribute.Strategy`](../../../../tf/distribute/Strategy)s, as long as [`Strategy.scope`](../../../../tf/distribute/MirroredStrategy#scope) is used at model creation:

    strategy = tf.distribute.experimental.ParameterServerStrategy(
        cluster_resolver)
    with strategy.scope():
      model = tf.keras.Sequential([tf.keras.layers.Dense(10)])
      model.compile(tf.keras.optimizers.SGD(), loss="mse")
    ...

**Note:** When using `DatasetCreator`, the `steps_per_epoch` argument in [`Model.fit`](../../../../tf/keras/Model#fit) must be provided, as the cardinality of such input cannot be inferred.

#### Args

| Arg          | Description |
|--------------|-------------|
| `dataset_fn` | A callable that takes a single argument of type [`tf.distribute.InputContext`](../../../../tf/distribute/InputContext), which is used for batch size calculation and cross-worker input pipeline sharding (if neither is needed, the `InputContext` parameter can be ignored in the `dataset_fn`), and returns a [`tf.data.Dataset`](../../../../tf/data/Dataset). |

## Methods

### `__call__`

[View source](https://github.com/tensorflow/tensorflow/blob/v2.5.0/tensorflow/python/keras/utils/dataset_creator.py#L77-L84)

    __call__(
        *args, **kwargs
    )

Call self as a function.

Last updated 2021-05-14 UTC.
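To make the wrapper-plus-context pattern concrete without requiring TensorFlow, here is a minimal pure-Python sketch of what `DatasetCreator` and `InputContext` do conceptually. The `SimpleInputContext` and `SimpleDatasetCreator` classes below are hypothetical stand-ins for illustration only; the real classes produce `tf.data.Dataset` objects and carry additional attributes.

```python
class SimpleInputContext:
    """Hypothetical stand-in for tf.distribute.InputContext (illustration only)."""

    def __init__(self, num_input_pipelines, input_pipeline_id, num_replicas_in_sync):
        self.num_input_pipelines = num_input_pipelines  # total input workers
        self.input_pipeline_id = input_pipeline_id      # this worker's index
        self.num_replicas_in_sync = num_replicas_in_sync

    def get_per_replica_batch_size(self, global_batch_size):
        # Mirrors the idea of InputContext.get_per_replica_batch_size:
        # the global batch is split evenly across replicas.
        return global_batch_size // self.num_replicas_in_sync


class SimpleDatasetCreator:
    """Hypothetical stand-in for DatasetCreator: stores a dataset_fn and
    invokes it on demand, which is what __call__ does in the real class."""

    def __init__(self, dataset_fn):
        if not callable(dataset_fn):
            raise TypeError("dataset_fn must be callable")
        self._dataset_fn = dataset_fn

    def __call__(self, *args, **kwargs):
        return self._dataset_fn(*args, **kwargs)


def dataset_fn(ctx):
    # Shard a toy range across input pipelines, then batch per replica,
    # analogous to the shard/batch steps in the tf.data example above.
    shard = [x for x in range(16)
             if x % ctx.num_input_pipelines == ctx.input_pipeline_id]
    batch_size = ctx.get_per_replica_batch_size(8)
    return [shard[i:i + batch_size] for i in range(0, len(shard), batch_size)]


creator = SimpleDatasetCreator(dataset_fn)
ctx = SimpleInputContext(num_input_pipelines=2, input_pipeline_id=0,
                         num_replicas_in_sync=4)
batches = creator(ctx)
print(batches)  # worker 0's shard (even numbers 0..14), batched in pairs
```

The key design point this sketch captures: the framework, not the user, constructs the context and invokes the creator, so each worker's copy of `dataset_fn` can build a correctly sharded pipeline at the right time rather than eagerly at call site.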