tf.keras.utils.StepsPerExecutionTuner
Steps per execution tuner class.
tf.keras.utils.StepsPerExecutionTuner(
    optimizer,
    spe_variable,
    interval=5,
    change_spe_interval=10,
    change_threshold=0.1
)
Args

optimizer: The optimizer used for training/evaluation/prediction. Used to measure iterations and global throughput (optimizer.iterations / second).

spe_variable: A tf.Variable representing the steps_per_execution variable used during training/evaluation/prediction. Must be updatable with spe_variable.assign.

interval: Optional int, the number of seconds to wait between calls to measure throughput and tune spe_variable. Defaults to 5.

change_spe_interval: Optional int, the number of throughput measurements to collect before tuning. Defaults to 10.

change_threshold: Optional float, the percent difference in throughput required to trigger a steps_per_execution change. For example, 0.1 triggers a change if throughput changes by more than 10%.
Examples:

If you're using model.compile and model.fit, this functionality is available at compile time with steps_per_execution='auto':

model.compile(..., steps_per_execution='auto')
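A slightly fuller sketch of the compile-time route (the optimizer, loss, and data below are illustrative placeholders, not requirements of the API):

model.compile(
    optimizer="sgd",
    loss="sparse_categorical_crossentropy",
    steps_per_execution="auto",  # tune steps_per_execution automatically during fit
)
model.fit(x_train, y_train, batch_size=64, epochs=2)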
Custom training loop usage:

import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras.utils import StepsPerExecutionTuner

# Get model
inputs = keras.Input(shape=(784,), name="digits")
x = layers.Dense(64, activation="relu", name="dense_1")(inputs)
x = layers.Dense(64, activation="relu", name="dense_2")(x)
outputs = layers.Dense(10, name="predictions")(x)
model = keras.Model(inputs=inputs, outputs=outputs)

# Instantiate an optimizer to train the model.
optimizer = keras.optimizers.SGD(learning_rate=1e-3)
# Instantiate a loss function.
loss_fn = keras.losses.SparseCategoricalCrossentropy(from_logits=True)

# Prepare the training dataset.
batch_size = 64
(x_train, y_train), (_, _) = keras.datasets.mnist.load_data()
x_train = np.reshape(x_train, (-1, 784)).astype("float32") / 255.0
train_dataset = (
    tf.data.Dataset.from_tensor_slices((x_train, y_train))
    .batch(batch_size)
    .repeat()  # repeat so the iterator never runs dry as steps_per_execution grows
)
iterator = iter(train_dataset)

# Create our steps per execution variable
steps_per_execution = tf.Variable(
    1,
    dtype="int64",
    aggregation=tf.VariableAggregation.ONLY_FIRST_REPLICA
)

# Create the tuner
tuner = StepsPerExecutionTuner(
    optimizer, steps_per_execution
)

# Create a step function that runs a single training step
@tf.function
def step_fn(iterator):
    batch_data, labels = next(iterator)
    with tf.GradientTape() as tape:
        logits = model(batch_data, training=True)
        loss_value = loss_fn(labels, logits)
    grads = tape.gradient(loss_value, model.trainable_weights)
    optimizer.apply_gradients(zip(grads, model.trainable_weights))

# We can now pack multiple execution steps into one call
@tf.function
def multi_step_train_fn(iterator, steps_per_execution):
    for _ in tf.range(steps_per_execution):
        step_fn(iterator)

steps_per_epoch = 100
epochs = 2

# Start the tuner before training
tuner.start()

# We can now call our multi step training with our data
for epoch in range(epochs):
    for _ in range(steps_per_epoch):
        multi_step_train_fn(iterator, steps_per_execution)

# End the tuner after training
tuner.stop()
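Once training finishes and the tuner is stopped, the tuned value can be read back from the variable (a short sketch reusing the names above):

print("Tuned steps_per_execution:", int(steps_per_execution.numpy()))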
Attributes

steps_per_execution: Settable attribute representing the steps_per_execution variable.
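For example, the attribute can be read or assigned directly; assigning it is expected to update the underlying spe_variable (a sketch reusing the tuner from the example above):

current = tuner.steps_per_execution  # read the current value
tuner.steps_per_execution = 4        # setting it updates spe_variable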
Methods

start

start()

Starts the steps-per-execution tuning thread.

Returns a threading.Thread which runs every self.interval seconds to measure throughput and tune steps per execution.

stop

stop()

Stops the steps-per-execution tuning thread.
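Because start() spawns a background thread, one defensive pattern (a sketch, not mandated by the API) is to guarantee stop() runs even if training raises:

tuner.start()
try:
    for epoch in range(epochs):
        for _ in range(steps_per_epoch):
            multi_step_train_fn(iterator, steps_per_execution)
finally:
    tuner.stop()  # make sure the tuning thread shuts down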