tfp.experimental.substrates.jax.math.soft_threshold

Soft Thresholding operator.

This operator is defined by the equations

                              { x[i] - gamma,  x[i] >   gamma
SoftThreshold(x, gamma)[i] =  { 0,             -gamma <= x[i] <= gamma
                              { x[i] + gamma,  x[i] <  -gamma
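
As a rough illustration of the definition, the operator can be written in one line with jax.numpy. This is a minimal sketch, not necessarily how the library implements it; the name soft_threshold_sketch is hypothetical and introduced only for this example.

```python
import jax.numpy as jnp


def soft_threshold_sketch(x, gamma):
    """Elementwise soft thresholding (hypothetical helper, for illustration)."""
    # sign(x) * max(|x| - gamma, 0) shrinks each entry toward zero by gamma
    # and maps every entry with |x[i]| <= gamma to zero.
    return jnp.sign(x) * jnp.maximum(jnp.abs(x) - gamma, 0.0)


x = jnp.array([-3.0, -1.0, 0.0, 0.5, 2.0])
y = soft_threshold_sketch(x, 1.0)
# With gamma = 1, y equals [-2, 0, 0, 0, 1] (up to the sign of zero).
```

Entries inside the interval [-gamma, gamma] are set to zero, which is why the operator tends to produce sparse outputs.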

In the context of proximal gradient methods, we have

SoftThreshold(x, gamma) = prox_{gamma L1}(x)

where prox is the proximity operator. Thus the soft thresholding operator is used in proximal gradient descent for optimizing a smooth function with (non-smooth) L1 regularization, as outlined below.

The proximity operator is defined as:

prox_r(x) = argmin{ r(z) + 0.5 ||x - z||_2**2 : z },

where r is a (weakly) convex function, not necessarily differentiable. Because the squared L2 norm is strictly convex, the above argmin is unique.
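
The identity SoftThreshold(x, gamma) = prox_{gamma L1}(x) can be checked numerically for scalar inputs. The sketch below minimizes gamma |z| + 0.5 (x - z)**2 by brute force over a grid and compares the result with the closed form. The helpers soft_threshold_sketch and prox_l1_grid, the grid, and the test points are all assumptions made for this example.

```python
import jax.numpy as jnp


def soft_threshold_sketch(x, gamma):
    """Elementwise soft thresholding (hypothetical helper, for illustration)."""
    return jnp.sign(x) * jnp.maximum(jnp.abs(x) - gamma, 0.0)


def prox_l1_grid(x, gamma, grid):
    """Approximates prox_{gamma L1}(x) for scalar x by searching over `grid`."""
    objective = gamma * jnp.abs(grid) + 0.5 * (x - grid) ** 2
    return grid[jnp.argmin(objective)]


grid = jnp.linspace(-5.0, 5.0, 1001)  # candidate z values, spacing 0.01
gamma = 0.7
xs = [-2.0, -0.3, 0.0, 0.5, 1.9]

grid_values = jnp.stack([prox_l1_grid(x, gamma, grid) for x in xs])
closed_form = soft_threshold_sketch(jnp.array(xs), gamma)

# Largest disagreement between the brute-force minimizer and the closed form;
# it should be no larger than the grid spacing.
max_gap = float(jnp.max(jnp.abs(grid_values - closed_form)))
```

The grid search is only a sanity check; its accuracy is limited by the grid spacing, whereas the closed form is exact.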

One important application of the proximity operator is as follows. Let L be a convex and differentiable function with Lipschitz-continuous gradient. Let R be a convex lower semicontinuous function which is possibly nondifferentiable. Let gamma be an arbitrary positive real. Then

x_star = argmin{ L(x) + R(x) : x }

if and only if the fixed-point equation is satisfied:

x_star = prox_{gamma R}(x_star - gamma grad L(x_star))
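
The fixed-point characterization can be verified on a toy problem where the minimizer is known. The sketch below assumes L(x) = 0.5 ||x - b||_2**2 and R(x) = ||x||_1, for which x_star = SoftThreshold(b, 1) and grad L(x) = x - b. The vector b, the gamma values, and the helper name are illustrative assumptions.

```python
import jax.numpy as jnp


def soft_threshold_sketch(x, gamma):
    """Elementwise soft thresholding (hypothetical helper, for illustration)."""
    return jnp.sign(x) * jnp.maximum(jnp.abs(x) - gamma, 0.0)


b = jnp.array([-2.5, -0.4, 0.0, 0.9, 3.0])

# Minimizer of 0.5 * ||x - b||^2 + ||x||_1, which is prox_{L1}(b).
x_star = soft_threshold_sketch(b, 1.0)

residuals = []
for gamma in [0.1, 1.0, 7.5]:
    # Gradient step on L from x_star, followed by prox_{gamma R}.
    step = x_star - gamma * (x_star - b)
    mapped = soft_threshold_sketch(step, gamma)
    residuals.append(float(jnp.max(jnp.abs(mapped - x_star))))
# Each residual should be zero up to floating-point error, for every gamma > 0.
```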

Proximal gradient descent thus typically consists of choosing an initial value x^{(0)} and repeatedly applying the update

x^{(k+1)} = prox_{gamma^{(k)} R}(x^{(k)} - gamma^{(k)} grad L(x^{(k)}))

where gamma is allowed to vary from iteration to iteration. Specializing to the case where R(x) = ||x||_1, we minimize L(x) + ||x||_1 by repeatedly applying the update

x^{(k+1)} = SoftThreshold(x^{(k)} - gamma grad L(x^{(k)}), gamma)
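
The update above can be sketched for a small least-squares problem. This example assumes L(x) = 0.5 ||A x - b||_2**2 with a made-up matrix A and vector b, and a constant step size gamma = 1 / ||A||_2**2 (the reciprocal of the Lipschitz constant of grad L). None of these choices come from the library; soft_threshold_sketch is again a hypothetical helper.

```python
import jax
import jax.numpy as jnp


def soft_threshold_sketch(x, gamma):
    """Elementwise soft thresholding (hypothetical helper, for illustration)."""
    return jnp.sign(x) * jnp.maximum(jnp.abs(x) - gamma, 0.0)


# Illustrative problem data (assumed for this example).
A = jnp.array([[2.0, 0.0, 1.0],
               [0.0, 1.0, 0.0],
               [1.0, 0.0, 3.0],
               [0.0, 1.0, 1.0]])
b = jnp.array([1.0, 0.2, -2.0, 0.1])


def smooth_loss(x):
    """Smooth part L(x) = 0.5 * ||A x - b||^2."""
    return 0.5 * jnp.sum((A @ x - b) ** 2)


def objective(x):
    """Full objective L(x) + ||x||_1."""
    return smooth_loss(x) + jnp.sum(jnp.abs(x))


grad_loss = jax.grad(smooth_loss)

# Constant step size: 1 / (largest singular value of A)^2.
gamma = 1.0 / jnp.linalg.norm(A, ord=2) ** 2

x = jnp.zeros(3)
history = [float(objective(x))]
for _ in range(200):
    # Gradient step on L, then soft thresholding with the same gamma.
    x = soft_threshold_sketch(x - gamma * grad_loss(x), gamma)
    history.append(float(objective(x)))

# At convergence x should approximately satisfy the fixed-point equation.
fixed_point_gap = float(jnp.max(jnp.abs(
    x - soft_threshold_sketch(x - gamma * grad_loss(x), gamma))))
```

A fixed step size keeps the sketch short; in practice gamma may be chosen per iteration, for example by a line search.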

(This idea can also be extended to second-order approximations, although the multivariate case does not have a known closed form like the one above.)

Args

x: float Tensor representing the input to the SoftThreshold function.
threshold: nonnegative scalar, float Tensor representing the radius of the interval on which each coordinate of SoftThreshold takes the value zero. Denoted gamma above.
name: Python string indicating the name of the TensorFlow operation. Default value: 'soft_threshold'.

Returns

softthreshold: float Tensor with the same shape and dtype as x, representing the value of the SoftThreshold function.

References

[1]: Yu, Yao-Liang. The Proximity Operator. https://www.cs.cmu.edu/~suvrit/teach/yaoliang_proximity.pdf

[2]: Wikipedia Contributors. Proximal gradient methods for learning. Wikipedia, The Free Encyclopedia, 2018. https://en.wikipedia.org/wiki/Proximal_gradient_methods_for_learning