A callable that takes a single argument of type
tf.distribute.InputContext, which is used for batch size calculation and
cross-worker input pipeline sharding (if neither is needed, the
InputContext parameter can be ignored in the dataset_fn), and returns
a tf.data.Dataset.
input_options
Optional tf.distribute.InputOptions, used for specific
options when used with distribution, for example, whether to prefetch
dataset elements to accelerator device memory or host device memory, and
prefetch buffer size in the replica device memory. No effect if not used
with distributed training. See tf.distribute.InputOptions for more
information.
[null,null,["Last updated 2022-09-07 UTC."],[],[],null,["# tf.keras.utils.experimental.DatasetCreator\n\n\u003cbr /\u003e\n\n|------------------------------------------------------------------------------------------------------------------|\n| [View source on GitHub](https://github.com/keras-team/keras/tree/v2.8.0/keras/utils/dataset_creator.py#L22-L110) |\n\nObject that returns a [`tf.data.Dataset`](../../../../tf/data/Dataset) upon invoking. \n\n tf.keras.utils.experimental.DatasetCreator(\n dataset_fn, input_options=None\n )\n\n[`tf.keras.utils.experimental.DatasetCreator`](../../../../tf/keras/utils/experimental/DatasetCreator) is designated as a supported type\nfor `x`, or the input, in [`tf.keras.Model.fit`](../../../../tf/keras/Model#fit). Pass an instance of this class\nto `fit` when using a callable (with a `input_context` argument) that returns\na [`tf.data.Dataset`](../../../../tf/data/Dataset). \n\n model = tf.keras.Sequential([tf.keras.layers.Dense(10)])\n model.compile(tf.keras.optimizers.SGD(), loss=\"mse\")\n\n def dataset_fn(input_context):\n global_batch_size = 64\n batch_size = input_context.get_per_replica_batch_size(global_batch_size)\n dataset = tf.data.Dataset.from_tensors(([1.], [1.])).repeat()\n dataset = dataset.shard(\n input_context.num_input_pipelines, input_context.input_pipeline_id)\n dataset = dataset.batch(batch_size)\n dataset = dataset.prefetch(2)\n return dataset\n\n input_options = tf.distribute.InputOptions(\n experimental_fetch_to_device=True,\n experimental_per_replica_buffer_size=2)\n model.fit(tf.keras.utils.experimental.DatasetCreator(\n dataset_fn, input_options=input_options), epochs=10, steps_per_epoch=10)\n\n[`Model.fit`](../../../../tf/keras/Model#fit) usage with `DatasetCreator` is intended to work across all\n[`tf.distribute.Strategy`](../../../../tf/distribute/Strategy)s, as long as [`Strategy.scope`](../../../../tf/distribute/MirroredStrategy#scope) is used at model\ncreation: \n\n strategy = tf.distribute.experimental.ParameterServerStrategy(\n cluster_resolver)\n with strategy.scope():\n model = tf.keras.Sequential([tf.keras.layers.Dense(10)])\n model.compile(tf.keras.optimizers.SGD(), loss=\"mse\")\n\n def dataset_fn(input_context):\n ...\n\n input_options = ...\n model.fit(tf.keras.utils.experimental.DatasetCreator(\n dataset_fn, input_options=input_options), epochs=10, steps_per_epoch=10)\n\n| **Note:** When using `DatasetCreator`, `steps_per_epoch` argument in [`Model.fit`](../../../../tf/keras/Model#fit) must be provided as the cardinality of such input cannot be inferred.\n\n\u003cbr /\u003e\n\n\u003cbr /\u003e\n\n\u003cbr /\u003e\n\n| Args ---- ||\n|-----------------|--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|\n| `dataset_fn` | A callable that takes a single argument of type [`tf.distribute.InputContext`](../../../../tf/distribute/InputContext), which is used for batch size calculation and cross-worker input pipeline sharding (if neither is needed, the `InputContext` parameter can be ignored in the `dataset_fn`), and returns a [`tf.data.Dataset`](../../../../tf/data/Dataset). |\n| `input_options` | Optional [`tf.distribute.InputOptions`](../../../../tf/distribute/InputOptions), used for specific options when used with distribution, for example, whether to prefetch dataset elements to accelerator device memory or host device memory, and prefetch buffer size in the replica device memory. No effect if not used with distributed training. See [`tf.distribute.InputOptions`](../../../../tf/distribute/InputOptions) for more information. |\n\n\u003cbr /\u003e\n\nMethods\n-------\n\n### `__call__`\n\n[View source](https://github.com/keras-team/keras/tree/v2.8.0/keras/utils/dataset_creator.py#L102-L110) \n\n __call__(\n *args, **kwargs\n )\n\nCall self as a function."]]