This class allows extending optimizers with decoupled weight decay.
    tfa.optimizers.DecoupledWeightDecayExtension(
        weight_decay: Union[FloatTensorLike, Callable],
        exclude_from_weight_decay: Optional[List[str]] = None,
        **kwargs
    )
It implements the decoupled weight decay described by Loshchilov & Hutter, in which the weight decay is decoupled from the optimization steps w.r.t. the loss function. For SGD variants, this simplifies hyperparameter search since it decouples the settings of weight decay and learning rate. For adaptive gradient algorithms, it regularizes variables with large gradients more than L2 regularization would, which was shown to yield better training loss and generalization error in the paper above.
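As a rough single-parameter sketch (not the library's implementation; all values below are placeholders), the difference between L2 regularization and decoupled weight decay for plain SGD looks like this:

    # Toy, single-parameter illustration -- not the tfa implementation.
    # w, grad, lr, and wd are placeholder values for one update step.
    w, grad, lr, wd = 1.0, 0.3, 0.1, 1e-4

    # SGD with L2 regularization: the decay term is folded into the gradient,
    # so it is coupled to the learning rate (and, for adaptive optimizers,
    # to the per-parameter scaling).
    w_l2 = w - lr * (grad + wd * w)

    # SGD with decoupled weight decay (SGDW): the weights are shrunk directly,
    # separately from the gradient step.
    w_sgdw = w - lr * grad - wd * w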
This class alone is not an optimizer but rather extends existing optimizers with decoupled weight decay. We explicitly define the two examples used in the above paper (SGDW and AdamW), but in general this can extend any OptimizerX class by using:

    ExtendedCls = extend_with_decoupled_weight_decay(OptimizerX)
Weight decay can then be set when instantiating the optimizer:

    optimizerX = ExtendedCls(weight_decay=0.001, learning_rate=0.001)
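For instance, an SGD variant with decoupled weight decay can be built the same way (a sketch; `tfa.optimizers.SGDW` already ships this combination):

    import tensorflow as tf
    import tensorflow_addons as tfa

    # Extend plain SGD with decoupled weight decay; this mirrors what
    # tfa.optimizers.SGDW provides out of the box.
    SGDW = tfa.optimizers.extend_with_decoupled_weight_decay(
        tf.keras.optimizers.SGD)
    optimizer = SGDW(weight_decay=1e-4, learning_rate=0.01, momentum=0.9)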
For this to work, the extension must be the first class the optimizer with weight decay inherits from, e.g.:

    class AdamW(DecoupledWeightDecayExtension, tf.keras.optimizers.Adam):
        def __init__(self, weight_decay, *args, **kwargs):
            super(AdamW, self).__init__(weight_decay, *args, **kwargs)
Note: when applying a decay to the learning rate, be sure to manually apply the decay to the `weight_decay` as well. For example:

    step = tf.Variable(0, trainable=False)
    schedule = tf.optimizers.schedules.PiecewiseConstantDecay(
        [10000, 15000], [1e-0, 1e-1, 1e-2])
    # lr and wd can be a function or a tensor
    lr = 1e-1 * schedule(step)
    wd = lambda: 1e-4 * schedule(step)

    # ...

    optimizer = tfa.optimizers.AdamW(learning_rate=lr, weight_decay=wd)
Args | |
---|---|
`weight_decay` | A `Tensor`, a floating point value, or a schedule that is a `tf.keras.optimizers.schedules.LearningRateSchedule` to decay the variable by, in the update step.
`exclude_from_weight_decay` | List of regex patterns of variables excluded from weight decay. Variables whose name contains a substring matching the pattern will be excluded. Note that `decay_var_list` in `minimize` or `apply_gradients` takes priority over `exclude_from_weight_decay` if specified.
`**kwargs` | Optional list or tuple or set of `Variable` objects to decay.
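As an illustration of `exclude_from_weight_decay` (a sketch; the patterns below are hypothetical and depend on how your model names its variables):

    import tensorflow_addons as tfa

    # Skip decoupled weight decay for variables whose names match these
    # example patterns (e.g. biases and layer-norm parameters).
    optimizer = tfa.optimizers.AdamW(
        learning_rate=1e-3,
        weight_decay=1e-4,
        exclude_from_weight_decay=["bias", "layer_normalization"],
    )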
Methods
apply_gradients
    apply_gradients(
        grads_and_vars, name=None, decay_var_list=None, **kwargs
    )
Apply gradients to variables.
This is the second part of `minimize()`. It returns an `Operation` that applies gradients.
Args | |
---|---|
`grads_and_vars` | List of (gradient, variable) pairs.
`name` | Optional name for the returned operation. Defaults to the name passed to the `Optimizer` constructor.
`decay_var_list` | Optional list of variables to be decayed. Defaults to all variables in `var_list`. Note that `decay_var_list` takes priority over `exclude_from_weight_decay` if specified.
`**kwargs` | Additional arguments to pass to the base optimizer's `apply_gradients` method, e.g., TF 2.2 added the argument `experimental_aggregate_gradients`.
Returns | |
---|---|
An `Operation` that applies the specified gradients. |
Raises | |
---|---|
`TypeError` | If `grads_and_vars` is malformed.
`ValueError` | If none of the variables have gradients.
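A minimal sketch of `apply_gradients` with `decay_var_list` in a custom training step (the model, data, and the kernel-only selection below are placeholders):

    import tensorflow as tf
    import tensorflow_addons as tfa

    model = tf.keras.Sequential([tf.keras.layers.Dense(1)])
    optimizer = tfa.optimizers.AdamW(learning_rate=1e-3, weight_decay=1e-4)

    x = tf.random.normal([8, 4])
    y = tf.random.normal([8, 1])

    with tf.GradientTape() as tape:
        loss = tf.reduce_mean(tf.square(model(x) - y))
    grads = tape.gradient(loss, model.trainable_variables)

    # Decay only the kernel weights; biases are excluded via decay_var_list.
    decay_vars = [v for v in model.trainable_variables if "kernel" in v.name]
    optimizer.apply_gradients(
        zip(grads, model.trainable_variables), decay_var_list=decay_vars
    )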
from_config
    @classmethod
    from_config(
        config, custom_objects=None
    )
get_config
    get_config()
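A short sketch of the serialization round trip (assuming a plain float `weight_decay`; a schedule or callable may require `custom_objects`):

    import tensorflow_addons as tfa

    # Serialize an AdamW optimizer to a config dict and rebuild it.
    optimizer = tfa.optimizers.AdamW(learning_rate=1e-3, weight_decay=1e-4)
    config = optimizer.get_config()
    restored = tfa.optimizers.AdamW.from_config(config)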
minimize
    minimize(
        loss, var_list, grad_loss=None, name=None, decay_var_list=None, tape=None
    )

Minimize `loss` by updating `var_list`.
This method simply computes the gradients using `tf.GradientTape` and calls `apply_gradients()`. If you want to process the gradients before applying them, call `tf.GradientTape` and `apply_gradients()` explicitly instead of using this function.
Args | |
---|---|
`loss` | `Tensor` or callable. If a callable, `loss` should take no arguments and return the value to minimize. If a `Tensor`, the `tape` argument must be passed.
`var_list` | List or tuple of `Variable` objects to update to minimize `loss`, or a callable returning the list or tuple of `Variable` objects. Use a callable when the variable list would otherwise be incomplete before `minimize`, since the variables are created the first time `loss` is called.
`grad_loss` | Optional. A `Tensor` holding the gradient computed for `loss`.
`decay_var_list` | Optional list of variables to be decayed. Defaults to all variables in `var_list`. Note that `decay_var_list` takes priority over `exclude_from_weight_decay` if specified.
`name` | Optional name for the returned operation.
`tape` | (Optional) `tf.GradientTape`. If `loss` is provided as a `Tensor`, the tape that computed the loss must be provided.
Returns | |
---|---|
An `Operation` that updates the variables in `var_list`. |
Raises | |
---|---|
`ValueError` | If some of the variables are not `Variable` objects.
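A minimal sketch of `minimize` with a callable loss and `decay_var_list` (the variables and the toy quadratic loss are placeholders):

    import tensorflow as tf
    import tensorflow_addons as tfa

    w = tf.Variable([1.0, 2.0])
    b = tf.Variable(0.0)
    optimizer = tfa.optimizers.SGDW(learning_rate=0.1, weight_decay=1e-4)

    # The loss is passed as a callable, so no tape argument is needed.
    loss = lambda: tf.reduce_sum(tf.square(w)) + tf.square(b)

    # Decay only `w`; `b` is updated by the gradient step but not decayed.
    optimizer.minimize(loss, var_list=[w, b], decay_var_list=[w])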