Defines a function as a recompute-checkpoint for the tape auto-diff.
```
tf.recompute_grad(
    f
)
```
Tape checkpointing is a technique to reduce the memory consumption of the auto-diff tape:

- Without tape checkpointing, operations and intermediate values are recorded to the tape for use in the backward pass.
- With tape checkpointing, only the function call and its inputs are recorded. During back-propagation the `recompute_grad` custom gradient (`tf.custom_gradient`) recomputes the function under a localized Tape object. This recomputation of the function during backpropagation performs redundant calculation, but reduces the overall memory usage of the Tape.
```
y = tf.Variable(1.0)

def my_function(x):
  tf.print('running')
  z = x*y
  return z

my_function_recompute = tf.recompute_grad(my_function)

with tf.GradientTape() as tape:
  r = tf.constant(1.0)
  for i in range(4):
    r = my_function_recompute(r)
running
running
running
running

grad = tape.gradient(r, [y])
running
running
running
running
```
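Conceptually, the wrapper behaves like a `tf.custom_gradient` whose backward function re-runs `f` under a fresh, localized tape. The sketch below only illustrates that idea; it is not the library's actual implementation, and the `naive_recompute_grad` name, the single-tensor-input restriction, and the variable handling are simplifying assumptions.

```python
import tensorflow as tf

def naive_recompute_grad(f):
  """Illustrative sketch of a recompute checkpoint (not the real tf.recompute_grad)."""

  @tf.custom_gradient
  def wrapper(x):
    y = f(x)  # Forward pass; f's intermediates are not kept for backprop.

    def grad(dy, variables=None):
      # Backward pass: recompute f under a localized tape, then differentiate
      # the recomputed result with respect to the input (and captured variables).
      with tf.GradientTape() as local_tape:
        local_tape.watch(x)
        recomputed = f(x)
      if variables:
        grads = local_tape.gradient(
            recomputed, [x] + list(variables), output_gradients=dy)
        return grads[0], grads[1:]
      return local_tape.gradient(recomputed, x, output_gradients=dy)

    return y, grad

  return wrapper
```

The redundant `f(x)` call inside `grad` is what produces the second block of `running` prints in the example above.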
Without `recompute_grad`, the tape contains all intermediate steps, and no recomputation is performed.
```
with tf.GradientTape() as tape:
  r = tf.constant(1.0)
  for i in range(4):
    r = my_function(r)
running
running
running
running

grad = tape.gradient(r, [y])
```

If `f` was a `tf.keras` Model or Layer object, methods and attributes such as `f.variables` are not available on the returned function `g`. Either keep a reference of `f`, or use `g.__wrapped__` for accessing these variables and methods.
```
def print_running_and_return(x):
  tf.print("running")
  return x

model = tf.keras.Sequential([
  tf.keras.layers.Lambda(print_running_and_return),
  tf.keras.layers.Dense(2)
])

model_recompute = tf.recompute_grad(model)

with tf.GradientTape(persistent=True) as tape:
  r = tf.constant([[1,2]])
  for i in range(4):
    r = model_recompute(r)
running
running
running
running

grad = tape.gradient(r, model.variables)
running
running
running
running
```
Alternatively, use the __wrapped__ attribute to access the original
model object.
```
grad = tape.gradient(r, model_recompute.__wrapped__.variables)
running
running
running
running
```
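Either way, the resulting gradients can be consumed as usual. A hypothetical follow-up step (the optimizer and learning rate below are assumptions, reusing `tape`, `r`, and `model_recompute` from the example above) might look like:

```python
optimizer = tf.keras.optimizers.SGD(learning_rate=0.1)
variables = model_recompute.__wrapped__.trainable_variables

# Each gradient call on the persistent tape recomputes the model's forward pass.
grads = tape.gradient(r, variables)
optimizer.apply_gradients(zip(grads, variables))
```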
| Args | |
|---|---|
| `f` | function `f(*x)` that returns a `Tensor` or sequence of `Tensor` outputs. |

| Returns | |
|---|---|
| A function `g` wrapping `f` that defines a custom gradient, which recomputes `f` on the backwards pass of a gradient call. |
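Because the returned `g` is just a wrapped function, `tf.recompute_grad` can also be applied as a decorator. A minimal sketch (the `expensive_block` function, shapes, and weights below are illustrative assumptions, not from the original text):

```python
import tensorflow as tf

w1 = tf.Variable(tf.random.normal([128, 128]))
w2 = tf.Variable(tf.random.normal([128, 128]))

@tf.recompute_grad
def expensive_block(x):
  # The ReLU activation produced here is recomputed during backprop
  # instead of being stored on the tape.
  h = tf.nn.relu(tf.matmul(x, w1))
  return tf.matmul(h, w2)

with tf.GradientTape() as tape:
  out = expensive_block(tf.random.normal([32, 128]))
  loss = tf.reduce_sum(out)

grads = tape.gradient(loss, [w1, w2])
```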