Returns a tff.learning.optimizers.Optimizer for AdamW.
tff.learning.optimizers.build_adamw(
    learning_rate: optimizer.Float,
    beta_1: optimizer.Float = 0.9,
    beta_2: optimizer.Float = 0.999,
    epsilon: optimizer.Float = 1e-07,
    weight_decay: optimizer.Float = 0.004
) -> tff.learning.optimizers.Optimizer
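The returned object follows the tff.learning.optimizers.Optimizer interface of initialize and next. A minimal sketch of one local update, using illustrative hyperparameter values and a single toy weight tensor:

import tensorflow as tf
import tensorflow_federated as tff

# Build the optimizer (values are illustrative, not recommendations).
optimizer = tff.learning.optimizers.build_adamw(
    learning_rate=0.01, weight_decay=0.004)

# A toy weight structure and a matching gradient structure.
weights = (tf.constant([1.0, 2.0]),)
gradients = (tf.constant([0.1, -0.2]),)

# Initialize optimizer state from the weight specs, then apply one step.
specs = tf.nest.map_structure(
    lambda w: tf.TensorSpec(w.shape, w.dtype), weights)
state = optimizer.initialize(specs)
state, weights = optimizer.next(state, weights, gradients)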
The AdamW optimizer is based on "Decoupled Weight Decay Regularization" (Loshchilov & Hutter, 2019).
The update rule, given learning rate lr, epsilon eps, accumulator acc, preconditioner s, weight decay lambda, iteration t, weights w, and gradients g, is:
acc = beta_1 * acc + (1 - beta_1) * g
s = beta_2 * s + (1 - beta_2) * g**2
normalization = sqrt(1 - beta_2**t) / (1 - beta_1**t)
w = w - lr * (normalization * acc / (sqrt(s) + eps) + lambda * w)
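As a plain-NumPy illustration of this rule (a sketch of the math above, not the library implementation), a single step could be written as:

import numpy as np

def adamw_step(w, g, acc, s, t, lr=0.01, beta_1=0.9, beta_2=0.999,
               eps=1e-7, weight_decay=0.004):
    """One AdamW step following the update rule above (illustrative only)."""
    acc = beta_1 * acc + (1 - beta_1) * g            # first-moment accumulator
    s = beta_2 * s + (1 - beta_2) * g**2             # second-moment preconditioner
    normalization = np.sqrt(1 - beta_2**t) / (1 - beta_1**t)  # bias correction
    w = w - lr * (normalization * acc / (np.sqrt(s) + eps) + weight_decay * w)
    return w, acc, s

w = np.array([1.0, 2.0])
g = np.array([0.1, -0.2])
acc = np.zeros_like(w)
s = np.zeros_like(w)
w, acc, s = adamw_step(w, g, acc, s, t=1)

Note that, unlike plain Adam with L2 regularization, the weight-decay term lambda * w is added directly to the update rather than folded into the gradient, which is the "decoupled" part of AdamW.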