tfp.experimental.nn.util.tune_dataset
Sets generally recommended parameters for a tf.data.Dataset.
tfp.experimental.nn.util.tune_dataset(
    dataset,
    batch_size=None,
    shuffle_size=None,
    preprocess_fn=None,
    repeat_count=-1
)
| Args | |
| --- | --- |
| dataset | tf.data.Dataset-like instance to be tuned according to this function's arguments. |
| batch_size | Python int representing the number of elements in each minibatch. |
| shuffle_size | Python int representing the number of elements to shuffle (at a time). |
| preprocess_fn | Python callable applied to each item in dataset. |
| repeat_count | Python int representing the number of times the dataset should be repeated. The default (repeat_count = -1) repeats the dataset indefinitely. If repeat_count is None, repeat is "off"; note that this deviates from tf.data.Dataset.repeat, which interprets None as "repeat indefinitely". Default value: -1 (i.e., repeat indefinitely). |
| Returns | |
| --- | --- |
| tuned_dataset | tf.data.Dataset instance tuned according to this function's arguments. |
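The repeat_count convention is easy to misread because None means the opposite of what it means in tf.data.Dataset.repeat. The following is a minimal pure-Python sketch of the documented semantics only; `apply_repeat` is a hypothetical helper written for illustration, it operates on a plain list rather than a tf.data.Dataset, and it is not how the library itself is implemented.

```python
import itertools


def apply_repeat(items, repeat_count=-1):
    """Emulates the documented repeat_count convention on a plain iterable.

    Hypothetical helper for illustration only; not part of TFP.
    """
    items = list(items)
    if repeat_count is None:
        # Repeat is "off": a single pass over the data.
        return iter(items)
    if repeat_count == -1:
        # Documented default: repeat indefinitely.
        return itertools.cycle(items)
    # Otherwise assume a non-negative count: that many passes over the data.
    return itertools.chain.from_iterable(itertools.repeat(items, repeat_count))


one_pass = list(apply_repeat([1, 2, 3], repeat_count=None))
two_passes = list(apply_repeat([1, 2, 3], repeat_count=2))
# The default never terminates, so only a finite prefix is taken here.
first_seven = list(itertools.islice(apply_repeat([1, 2, 3]), 7))
```

With the default, any loop over the result needs its own stopping condition, such as a fixed number of training steps.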
Example
import tensorflow as tf
import tensorflow_datasets as tfds
import tensorflow_probability as tfp

tune_dataset = tfp.experimental.nn.util.tune_dataset

# Load MNIST as (image, label) pairs.
[train_dataset, eval_dataset], datasets_info = tfds.load(
    name='mnist',
    split=['train', 'test'],
    with_info=True,
    as_supervised=True,
    shuffle_files=True)

def _preprocess(image, label):
  image = tf.cast(image, dtype=tf.int32)
  u = tf.random.uniform(shape=tf.shape(image), maxval=256, dtype=image.dtype)
  image = tf.cast(u < image, dtype=tf.float32)  # Randomly binarize.
  return image, label

@tf.function(autograph=False)
def one_step(iterator):
  x, y = next(iterator)
  return tf.reduce_mean(x)

ds = tune_dataset(
    train_dataset,
    batch_size=32,
    shuffle_size=int(datasets_info.splits['train'].num_examples / 7),
    preprocess_fn=_preprocess)

it = iter(ds)
for _ in range(3):  # Build graph / burn-in.
  one_step(it)
# %time is an IPython magic; run this line in a notebook or IPython shell.
%time one_step(it)
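The timing line in the example uses the IPython `%time` magic, which is unavailable in a plain Python script. A rough equivalent can be sketched with the standard library; `time_call` and `fake_step` below are hypothetical names introduced for illustration, and the stand-in workload is only there so the sketch runs without TensorFlow.

```python
import time


def time_call(fn, *args, warmup=3, **kwargs):
    """Calls `fn` `warmup` times untimed, then times one more call.

    Hypothetical helper; returns (result, elapsed_seconds).
    """
    for _ in range(warmup):
        fn(*args, **kwargs)  # Each warm-up iteration really invokes `fn`.
    start = time.perf_counter()
    result = fn(*args, **kwargs)
    elapsed = time.perf_counter() - start
    return result, elapsed


# Stand-in workload that records how often it is called.
calls = []


def fake_step(x):
    calls.append(x)
    return x * 2


result, elapsed = time_call(fake_step, 21)
```

In the example's setting, the call would look like `_, seconds = time_call(one_step, it)`, assuming `one_step` and `it` are defined as in the example.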
Except as otherwise noted, the content of this page is licensed under the Creative Commons Attribution 4.0 License, and code samples are licensed under the Apache 2.0 License. For details, see the Google Developers Site Policies.
Last updated 2023-11-21 UTC.