Base class for defining a parallel dataset using Python code.
tf.keras.utils.PyDataset(
workers=1, use_multiprocessing=False, max_queue_size=10
)
Every PyDataset must implement the __getitem__() and the __len__()
methods. If you want to modify your dataset between epochs,
you may additionally implement on_epoch_end().
The __getitem__() method should return a complete batch
(not a single sample), and the __len__ method should return
the number of batches in the dataset (rather than the number of samples).
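To make the batch-level contract concrete, here is a minimal framework-free sketch. `ListBatches` is a hypothetical class that mirrors the two required methods on a plain Python list. It does not subclass `keras.utils.PyDataset`, so it runs without Keras installed. In real code you would subclass `PyDataset` and call `super().__init__(**kwargs)`, as the Example section does.

```python
import math


class ListBatches:
    """Hypothetical stand-in for a PyDataset subclass over a plain list.

    It only mirrors the __len__/__getitem__ contract; it is not a Keras class.
    """

    def __init__(self, items, batch_size):
        self.items = list(items)
        self.batch_size = batch_size

    def __len__(self):
        # Number of *batches*, not samples; round up so a partial
        # final batch is still counted.
        return math.ceil(len(self.items) / self.batch_size)

    def __getitem__(self, idx):
        # Return the complete batch at position idx (a list of samples).
        low = idx * self.batch_size
        high = min(low + self.batch_size, len(self.items))
        return self.items[low:high]


ds = ListBatches(range(10), batch_size=4)
print(len(ds))  # 3 batches for 10 samples
print(ds[0])    # [0, 1, 2, 3]
print(ds[2])    # [8, 9]
```

Rounding up in `__len__` is what keeps the final, shorter batch from being silently dropped.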
Notes:
- `PyDataset` is a safer way to do multiprocessing than a Python generator: its structure guarantees that the model trains on each sample only once per epoch, which Python generators do not guarantee.
- The arguments `workers`, `use_multiprocessing`, and `max_queue_size` exist to configure how `fit()` uses parallelism to iterate over the dataset. The `PyDataset` class does not use them directly. When you iterate over a `PyDataset` manually, no parallelism is applied.
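The once-per-epoch guarantee follows from indexing: an epoch that requests each index in `range(len(dataset))` exactly once sees every sample exactly once, because the batches partition the samples. The sketch below illustrates this with a hypothetical `IndexedBatches` class (not part of Keras) and a hypothetical `run_epoch` helper. Iterating by hand, as here, is plain sequential Python with no parallelism.

```python
import math
from collections import Counter


class IndexedBatches:
    """Hypothetical PyDataset-like class over a list (no Keras dependency)."""

    def __init__(self, samples, batch_size):
        self.samples = list(samples)
        self.batch_size = batch_size

    def __len__(self):
        return math.ceil(len(self.samples) / self.batch_size)

    def __getitem__(self, idx):
        low = idx * self.batch_size
        # Slicing past the end is safe in Python; the last batch is just shorter.
        return self.samples[low:low + self.batch_size]


def run_epoch(dataset):
    """Visit every batch index once, sequentially, and count each sample seen."""
    seen = Counter()
    for idx in range(len(dataset)):
        seen.update(dataset[idx])
    return seen


ds = IndexedBatches(["a", "b", "c", "d", "e"], batch_size=2)
counts = run_epoch(ds)
print(all(n == 1 for n in counts.values()), len(counts))  # True 5
```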
Example:
    from skimage.io import imread
    from skimage.transform import resize
    import numpy as np
    import math
    import keras

    # Here, `x_set` is a list of paths to the images
    # and `y_set` are the associated classes.

    class CIFAR10PyDataset(keras.utils.PyDataset):

        def __init__(self, x_set, y_set, batch_size, **kwargs):
            # Forward workers, use_multiprocessing and max_queue_size
            # to the base class.
            super().__init__(**kwargs)
            self.x, self.y = x_set, y_set
            self.batch_size = batch_size

        def __len__(self):
            # Return number of batches.
            return math.ceil(len(self.x) / self.batch_size)

        def __getitem__(self, idx):
            # Return x, y for batch idx.
            low = idx * self.batch_size
            # Cap upper bound at array length; the last batch may be smaller
            # if the total number of items is not a multiple of batch size.
            high = min(low + self.batch_size, len(self.x))
            batch_x = self.x[low:high]
            batch_y = self.y[low:high]

            return np.array([
                resize(imread(file_name), (200, 200))
                for file_name in batch_x]), np.array(batch_y)
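The bounds arithmetic in the example's `__getitem__` can be checked in isolation. The hypothetical helper `batch_bounds` below reproduces the `low`/`high` computation for every batch index, showing that the slices tile the dataset without gaps or overlaps and that only the final batch can be shorter. It works on plain integers, so it needs neither Keras nor scikit-image.

```python
import math


def batch_bounds(num_items, batch_size):
    """Return the (low, high) slice bounds the example's __getitem__ uses."""
    num_batches = math.ceil(num_items / batch_size)  # same formula as __len__
    bounds = []
    for idx in range(num_batches):
        low = idx * batch_size
        high = min(low + batch_size, num_items)  # cap at the array length
        bounds.append((low, high))
    return bounds


print(batch_bounds(10, 4))  # [(0, 4), (4, 8), (8, 10)]
print(batch_bounds(8, 4))   # [(0, 4), (4, 8)]
```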
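| Attributes | |
|---|---|
| max_queue_size | Maximum number of batches kept in the queue when fit() iterates over the dataset in parallel. |
| num_batches | Number of batches in the PyDataset. |
| use_multiprocessing | Whether fit() uses Python multiprocessing (rather than multithreading) for parallelism. |
| workers | Number of workers fit() uses to iterate over the dataset in parallel. |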
Methods
on_epoch_end
on_epoch_end()
Method called at the end of every epoch.
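A common use of `on_epoch_end()` is reshuffling the sample order between epochs. The sketch below assumes a hypothetical `ShuffledBatches` class that keeps a permutation of sample indices and reads each batch through it. It is framework-free (no `PyDataset` base class), and the explicit `on_epoch_end()` call stands in for the call made at the end of each epoch during training. Because only the permutation changes, every sample still appears exactly once per epoch.

```python
import math
import random


class ShuffledBatches:
    """Hypothetical PyDataset-like class that reshuffles in on_epoch_end()."""

    def __init__(self, samples, batch_size, seed=0):
        self.samples = list(samples)
        self.batch_size = batch_size
        self.rng = random.Random(seed)
        self.order = list(range(len(self.samples)))
        self.rng.shuffle(self.order)

    def __len__(self):
        return math.ceil(len(self.samples) / self.batch_size)

    def __getitem__(self, idx):
        # Map batch position -> shuffled sample indices -> samples.
        start = idx * self.batch_size
        positions = self.order[start:start + self.batch_size]
        return [self.samples[i] for i in positions]

    def on_epoch_end(self):
        # Draw a new permutation so the next epoch sees a different order.
        self.rng.shuffle(self.order)


def epoch(ds):
    """Collect every sample seen while visiting each batch index once."""
    return [s for idx in range(len(ds)) for s in ds[idx]]


ds = ShuffledBatches(range(6), batch_size=4)
first = epoch(ds)
ds.on_epoch_end()
second = epoch(ds)
print(sorted(first) == sorted(second) == list(range(6)))  # True
```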
__getitem__
__getitem__(index)
Gets the batch at position index.
| Args | |
|---|---|
| index | Position of the batch in the PyDataset. |

| Returns | |
|---|---|
| A batch. | |