# tf.keras.utils.experimental.DatasetCreator

[View source on GitHub](https://github.com/keras-team/keras/tree/v2.13.1/keras/utils/dataset_creator.py#L24-L116)

Object that returns a [`tf.data.Dataset`](../../../../tf/data/Dataset) upon invoking.

    tf.keras.utils.experimental.DatasetCreator(
        dataset_fn, input_options=None
    )

[`tf.keras.utils.experimental.DatasetCreator`](../../../../tf/keras/utils/experimental/DatasetCreator) is designated as a supported type for the `x` (input) argument in [`tf.keras.Model.fit`](../../../../tf/keras/Model#fit). Pass an instance of this class to `fit` when using a callable (with an `input_context` argument) that returns a [`tf.data.Dataset`](../../../../tf/data/Dataset).

    model = tf.keras.Sequential([tf.keras.layers.Dense(10)])
    model.compile(tf.keras.optimizers.SGD(), loss="mse")

    def dataset_fn(input_context):
      global_batch_size = 64
      batch_size = input_context.get_per_replica_batch_size(global_batch_size)
      dataset = tf.data.Dataset.from_tensors(([1.], [1.])).repeat()
      dataset = dataset.shard(
          input_context.num_input_pipelines, input_context.input_pipeline_id)
      dataset = dataset.batch(batch_size)
      dataset = dataset.prefetch(2)
      return dataset

    input_options = tf.distribute.InputOptions(
        experimental_fetch_to_device=True,
        experimental_per_replica_buffer_size=2)
    model.fit(tf.keras.utils.experimental.DatasetCreator(
        dataset_fn, input_options=input_options), epochs=10, steps_per_epoch=10)

[`Model.fit`](../../../../tf/keras/Model#fit) usage with `DatasetCreator` is intended to work across all [`tf.distribute.Strategy`](../../../../tf/distribute/Strategy)s, as long as [`Strategy.scope`](../../../../tf/distribute/MirroredStrategy#scope) is used at model creation:

    strategy = tf.distribute.experimental.ParameterServerStrategy(
        cluster_resolver)
    with strategy.scope():
      model = tf.keras.Sequential([tf.keras.layers.Dense(10)])
      model.compile(tf.keras.optimizers.SGD(), loss="mse")

    def dataset_fn(input_context):
      ...

    input_options = ...
    model.fit(tf.keras.utils.experimental.DatasetCreator(
        dataset_fn, input_options=input_options), epochs=10, steps_per_epoch=10)

**Note:** When using `DatasetCreator`, the `steps_per_epoch` argument in [`Model.fit`](../../../../tf/keras/Model#fit) must be provided, since the cardinality of such an input cannot be inferred.

Args
----

| Argument | Description |
|-----------------|---------------|
| `dataset_fn` | A callable that takes a single argument of type [`tf.distribute.InputContext`](../../../../tf/distribute/InputContext), which is used for batch size calculation and cross-worker input pipeline sharding (if neither is needed, the `InputContext` parameter can be ignored in the `dataset_fn`; see the sketch after this table), and returns a [`tf.data.Dataset`](../../../../tf/data/Dataset). |
| `input_options` | Optional [`tf.distribute.InputOptions`](../../../../tf/distribute/InputOptions), used to configure distribution-specific options, for example whether to prefetch dataset elements to accelerator device memory or host device memory, and the prefetch buffer size in the replica device memory. Has no effect if not used with distributed training. See [`tf.distribute.InputOptions`](../../../../tf/distribute/InputOptions) for more information. |
Methods
-------

### `__call__`

[View source](https://github.com/keras-team/keras/tree/v2.13.1/keras/utils/dataset_creator.py#L107-L116)

    __call__(
        *args, **kwargs
    )

Call self as a function.
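Invoking a `DatasetCreator` instance is normally left to `Model.fit`, which supplies the [`tf.distribute.InputContext`](../../../../tf/distribute/InputContext). As a rough sketch of the call shape, assuming `__call__` simply forwards its arguments to the wrapped `dataset_fn`:

    # Reuses the context-ignoring `dataset_fn` sketched after the Args table.
    creator = tf.keras.utils.experimental.DatasetCreator(dataset_fn)
    dataset = creator(None)  # forwarded as `dataset_fn(None)`, which ignores it
    print(dataset.element_spec)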