tf.keras.utils.Sequence
Base object for fitting to a sequence of data, such as a dataset.
Every Sequence must implement the __getitem__ and the __len__ methods.
If you want to modify your dataset between epochs, you may implement on_epoch_end. The __getitem__ method should return a complete batch.
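A minimal sketch of that contract, assuming in-memory NumPy arrays rather than files on disk. ArraySequence is a hypothetical name. The try/except import is only there so the sketch also runs where TensorFlow is not installed; in that case it subclasses object instead of the real base class.

```python
import math

import numpy as np

try:
    import tensorflow as tf
    Sequence = tf.keras.utils.Sequence
except ImportError:
    # Assumption: fall back to a plain base class so this sketch runs
    # without TensorFlow; with TensorFlow installed, the real class is used.
    Sequence = object


class ArraySequence(Sequence):
    """Hypothetical Sequence serving (x, y) batches from in-memory arrays."""

    def __init__(self, x, y, batch_size):
        super().__init__()
        self.x, self.y = x, y
        self.batch_size = batch_size

    def __len__(self):
        # Number of batches per epoch; the last one may be partial.
        return math.ceil(len(self.x) / self.batch_size)

    def __getitem__(self, idx):
        # Return one complete batch as an (inputs, targets) tuple.
        low = idx * self.batch_size
        high = min(low + self.batch_size, len(self.x))
        return self.x[low:high], self.y[low:high]


x = np.arange(20, dtype="float32").reshape(10, 2)
y = np.arange(10)
seq = ArraySequence(x, y, batch_size=4)
print(len(seq))         # 3
print(seq[0][0].shape)  # (4, 2)
```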
Notes:
Sequence is a safer way to do multiprocessing. Because every batch is requested by its index, this structure guarantees that the model trains on each sample only once per epoch, which is not the case with generators.
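The guarantee follows from indexing: since each batch is requested by position, separate workers can split range(len(seq)) between them without fetching any sample twice. The sketch below illustrates the idea with a thread pool; it is not how Keras schedules its workers internally, and IndexSequence is a hypothetical toy class whose batches are simply lists of sample ids.

```python
import math
from concurrent.futures import ThreadPoolExecutor


class IndexSequence:
    """Hypothetical Sequence-like toy: batch i lists the sample ids it holds."""

    def __init__(self, num_samples, batch_size):
        self.num_samples = num_samples
        self.batch_size = batch_size

    def __len__(self):
        return math.ceil(self.num_samples / self.batch_size)

    def __getitem__(self, idx):
        low = idx * self.batch_size
        high = min(low + self.batch_size, self.num_samples)
        return list(range(low, high))


seq = IndexSequence(num_samples=10, batch_size=3)

# Workers fetch batches by index; every index in range(len(seq)) is
# submitted exactly once, so every sample appears exactly once.
with ThreadPoolExecutor(max_workers=4) as pool:
    batches = list(pool.map(lambda i: seq[i], range(len(seq))))

seen = sorted(s for batch in batches for s in batch)
print(seen == list(range(10)))  # True
```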
Examples:
    from skimage.io import imread
    from skimage.transform import resize
    import numpy as np
    import math
    import tensorflow as tf

    # Here, `x_set` is a list of paths to the images
    # and `y_set` holds the associated classes.
    class CIFAR10Sequence(tf.keras.utils.Sequence):

        def __init__(self, x_set, y_set, batch_size):
            self.x, self.y = x_set, y_set
            self.batch_size = batch_size

        def __len__(self):
            return math.ceil(len(self.x) / self.batch_size)

        def __getitem__(self, idx):
            low = idx * self.batch_size
            # Cap upper bound at array length; the last batch may be smaller
            # if the total number of items is not a multiple of batch size.
            high = min(low + self.batch_size, len(self.x))
            batch_x = self.x[low:high]
            batch_y = self.y[low:high]

            return np.array([
                resize(imread(file_name), (200, 200))
                for file_name in batch_x]), np.array(batch_y)
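The low/high arithmetic in __getitem__ can be checked on its own. batch_bounds below is a hypothetical helper, not part of Keras, that reproduces the slice bounds the example computes for each batch index of a dataset with n items.

```python
import math


def batch_bounds(n, batch_size):
    """Return the (low, high) slice bounds used for each batch index."""
    num_batches = math.ceil(n / batch_size)  # same formula as __len__
    bounds = []
    for idx in range(num_batches):
        low = idx * batch_size
        # Cap at n so the final, possibly smaller batch stays in range.
        high = min(low + batch_size, n)
        bounds.append((low, high))
    return bounds


print(batch_bounds(10, 4))  # [(0, 4), (4, 8), (8, 10)]
```

With 10 items and a batch size of 4, __len__ returns 3 and the final batch holds only 2 items.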
Methods
on_epoch_end
on_epoch_end()
Method called at the end of every epoch.
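A common use of on_epoch_end is reshuffling the sample order between epochs. The sketch below assumes in-memory NumPy arrays and falls back to subclassing object when TensorFlow is not installed; ShuffledSequence and its seed parameter are hypothetical.

```python
import math

import numpy as np

try:
    import tensorflow as tf
    Sequence = tf.keras.utils.Sequence
except ImportError:
    # Assumption: plain fallback base so the sketch runs without TensorFlow.
    Sequence = object


class ShuffledSequence(Sequence):
    """Hypothetical Sequence that reshuffles sample order every epoch."""

    def __init__(self, x, y, batch_size, seed=0):
        super().__init__()
        self.x, self.y = x, y
        self.batch_size = batch_size
        self.rng = np.random.default_rng(seed)
        self.indices = np.arange(len(x))
        self.on_epoch_end()  # shuffle once before the first epoch

    def __len__(self):
        return math.ceil(len(self.x) / self.batch_size)

    def __getitem__(self, idx):
        # Slice the shuffled index array, then gather the matching samples.
        start = idx * self.batch_size
        batch_idx = self.indices[start:start + self.batch_size]
        return self.x[batch_idx], self.y[batch_idx]

    def on_epoch_end(self):
        # Called at the end of every epoch; reorder samples for the next one.
        self.rng.shuffle(self.indices)


x = np.arange(10)
y = x * 2
seq = ShuffledSequence(x, y, batch_size=4)
epoch_1 = np.concatenate([seq[i][0] for i in range(len(seq))])
seq.on_epoch_end()  # what happens between epochs during training
epoch_2 = np.concatenate([seq[i][0] for i in range(len(seq))])
```

Shuffling an index array instead of the data itself keeps x and y aligned and avoids copying the arrays every epoch.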
__getitem__
__getitem__(index)
Gets the batch at position index.
Args:
index: Position of the batch in the Sequence.
Returns:
A batch.
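What "a batch" contains depends on the consumer. The tf.keras Model.fit documentation accepts a Sequence that returns either (inputs, targets) or (inputs, targets, sample_weights); the sketch below shows the three-element form. WeightedSequence and the constant weights are hypothetical, the TensorFlow import falls back to object when TensorFlow is absent, and seq[i] is ordinary Python syntax for seq.__getitem__(i).

```python
import math

import numpy as np

try:
    import tensorflow as tf
    Sequence = tf.keras.utils.Sequence
except ImportError:
    # Assumption: plain fallback base so the sketch runs without TensorFlow.
    Sequence = object


class WeightedSequence(Sequence):
    """Hypothetical Sequence yielding (inputs, targets, sample_weights)."""

    def __init__(self, x, y, weights, batch_size):
        super().__init__()
        self.x, self.y, self.w = x, y, weights
        self.batch_size = batch_size

    def __len__(self):
        return math.ceil(len(self.x) / self.batch_size)

    def __getitem__(self, index):
        # index is the batch position, from 0 to len(self) - 1.
        s = slice(index * self.batch_size, (index + 1) * self.batch_size)
        return self.x[s], self.y[s], self.w[s]


seq = WeightedSequence(np.ones((5, 3)), np.zeros(5), np.full(5, 0.5), batch_size=2)
inputs, targets, weights = seq[2]  # same as seq.__getitem__(2)
print(inputs.shape, weights)  # (1, 3) [0.5]
```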
__iter__
__iter__()
Creates a generator that iterates over the Sequence.
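In Keras 2.x the default __iter__ behaves like yielding self[i] for each i in range(len(self)), that is, one ordered pass over all batches. A standalone sketch of that behavior; iterate_batches and Squares are hypothetical names.

```python
def iterate_batches(seq):
    """Yield seq[0], seq[1], ..., seq[len(seq) - 1]: one pass, in order."""
    for i in range(len(seq)):
        yield seq[i]


class Squares:
    """Hypothetical Sequence-like object: batch i is [i * i]."""

    def __len__(self):
        return 4

    def __getitem__(self, idx):
        return [idx * idx]


print(list(iterate_batches(Squares())))  # [[0], [1], [4], [9]]
```

Tying iteration to __len__ matters: for an object that defines __getitem__ but not __iter__, Python's fallback iteration keeps calling __getitem__ with 0, 1, 2, ... until IndexError is raised. A __getitem__ like the one in Squares never raises, so that fallback would not stop.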
__len__
__len__()
Number of batches in the Sequence.
Returns:
The number of batches in the Sequence.
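For a dataset of N samples and a batch size B (symbols introduced here for illustration), the example's __len__ and __getitem__ give the following batch count and per-batch sizes; with N = 10 and B = 4 that is 3 batches of sizes 4, 4 and 2.

```latex
\text{len} = \left\lceil \frac{N}{B} \right\rceil,
\qquad
|\text{batch}_i| = \min(B,\; N - iB),
\quad i = 0, 1, \dots, \left\lceil \frac{N}{B} \right\rceil - 1
```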
Except as otherwise noted, the content of this page is licensed under the Creative Commons Attribution 4.0 License, and code samples are licensed under the Apache 2.0 License. For details, see the Google Developers Site Policies. Java is a registered trademark of Oracle and/or its affiliates. Some content is licensed under the numpy license.
Last updated 2023-10-06 UTC.