In this notebook we introduce Generalized Linear Models via a worked example. We solve this example in two different ways using two algorithms for efficiently fitting GLMs in TensorFlow Probability: Fisher scoring for dense data, and coordinatewise proximal gradient descent for sparse data. We compare the fitted coefficients to the true coefficients and, in the case of coordinatewise proximal gradient descent, to the output of R's similar `glmnet` algorithm. Finally, we provide further mathematical details and derivations of several key properties of GLMs.
Background
A generalized linear model (GLM) is a linear model ($\eta = x^\top \beta$) wrapped in a transformation (link function) and equipped with a response distribution from an exponential family. The choice of link function and response distribution is very flexible, which lends great expressivity to GLMs. The full details are found in "Derivation of GLM Facts" below, including a sequential presentation of all the definitions and results building up to GLMs in unambiguous notation. We summarize:

In a GLM, a predictive distribution for the response variable $Y$ is associated with a vector of observed predictors $x$. The distribution has the form:

$$
\begin{aligned}
p(y \,|\, x) &= m(y, \phi) \exp\left(\frac{\theta\, T(y) - A(\theta)}{\phi}\right) \\
\theta &:= h(\eta) \\
\eta &:= x^\top \beta
\end{aligned}
$$

Here $\beta$ are the parameters ("weights"), $\phi$ is a hyperparameter representing dispersion ("variance"), and $m$, $h$, $T$, $A$ are characterized by the user-specified model family.

The mean of $Y$ depends on $x$ by composition of the linear response $\eta$ and the (inverse) link function, i.e.:

$$
\mu := g^{-1}(\eta)
$$
where $g$ is the so-called link function. In TFP the choice of link function and model family is jointly specified by a `tfp.glm.ExponentialFamily` subclass. Examples include:

- `tfp.glm.Normal`, aka "linear regression"
- `tfp.glm.Bernoulli`, aka "logistic regression"
- `tfp.glm.Poisson`, aka "Poisson regression"
- `tfp.glm.BernoulliNormalCDF`, aka "probit regression".

TFP names model families after the distribution over `Y` rather than after the link function, because `tfp.Distribution`s are already first-class citizens. If the `tfp.glm.ExponentialFamily` subclass name contains a second word, this indicates a non-canonical link function.
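The sketch below makes the "jointly specified" idea concrete. For each family listed above, it computes in NumPy the triple $(\text{Mean}_T(\eta), \text{Var}_T(\eta), \text{Mean}_T'(\eta))$ as a function of the linear response $\eta$, where $T(y) = y$. This is the same `(mean, variance, grad_mean)` triple that the verification code near the end of this notebook unpacks from a `tfp.glm` model. The function names are hypothetical and this is not TFP's implementation. The Normal case assumes unit scale.

```python
import numpy as np
from scipy.stats import norm


def normal_triple(eta):
    """Identity link, unit scale: mean = eta, variance = 1, d mean / d eta = 1."""
    eta = np.asarray(eta, dtype=float)
    return eta, np.ones_like(eta), np.ones_like(eta)


def bernoulli_logit_triple(eta):
    """Logit link (canonical): mean = sigmoid(eta)."""
    eta = np.asarray(eta, dtype=float)
    mean = 1.0 / (1.0 + np.exp(-eta))
    variance = mean * (1.0 - mean)
    # d sigmoid / d eta = sigmoid * (1 - sigmoid), which equals the variance.
    return mean, variance, variance


def poisson_log_triple(eta):
    """Log link (canonical): mean = variance = d mean / d eta = exp(eta)."""
    eta = np.asarray(eta, dtype=float)
    mean = np.exp(eta)
    return mean, mean, mean


def bernoulli_probit_triple(eta):
    """Probit link (non-canonical): mean = Phi(eta), the standard normal CDF."""
    eta = np.asarray(eta, dtype=float)
    mean = norm.cdf(eta)
    return mean, mean * (1.0 - mean), norm.pdf(eta)
```

For the two canonical links (logit and log), the derivative of the mean equals the variance. For the probit link it does not. "Claim: Expressing $h'$ in terms of the sufficient statistic" below explains why.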
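GLMs have several remarkable properties which permit efficient implementation of the maximum likelihood estimator. Chief among these properties are simple formulas for the gradient of the log-likelihood $\ell$, and for the Fisher information matrix. The Fisher information matrix is the expected value of the Hessian of the negative log-likelihood when the response is re-sampled under the same predictors. I.e.:

$$
\begin{aligned}
\nabla_\beta\, \ell(\beta \,;\, \mathbf{x}, \mathbf{y})
&= \mathbf{x}^\top \,\text{diag}\left(\frac{{\textbf{Mean}_T}'(\mathbf{x} \beta)}{{\textbf{Var}_T}(\mathbf{x} \beta)}\right) \left(\mathbf{T}(\mathbf{y}) - {\textbf{Mean}_T}(\mathbf{x} \beta)\right) \\
\mathbb{E}_{Y_i \sim p_{\text{OEF}(m, T)}(\cdot \,|\, \theta = h(x_i^\top \beta), \phi)}\left[\nabla_\beta^2\, \ell(\beta \,;\, \mathbf{x}, \mathbf{Y})\right]
&= -\mathbf{x}^\top \,\text{diag}\left(\frac{{\textbf{Mean}_T}'(\mathbf{x} \beta)^2}{{\textbf{Var}_T}(\mathbf{x} \beta)}\right) \mathbf{x}
\end{aligned}
$$

where $\mathbf{x}$ is the matrix whose $i$th row is the predictor vector for the $i$th data sample, and $\mathbf{y}$ is the vector whose $i$th coordinate is the observed response for the $i$th data sample. Here (loosely speaking), ${\text{Mean}_T}(\eta) := \mathbb{E}[T(Y) \,|\, \eta]$ and ${\text{Var}_T}(\eta) := \text{Var}[T(Y) \,|\, \eta]$, and boldface denotes vectorization of these functions. "Derivation of GLM Facts" below gives full details of which distributions these expectations and variances are taken over.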
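As a quick sanity check of these two formulas, the NumPy sketch below uses Poisson regression, which has a canonical log link. It compares the formula-based gradient and expected Hessian against central finite differences of the log-likelihood. With a canonical link the Hessian does not depend on $\mathbf{y}$, so the observed Hessian should equal $-\mathbf{x}^\top \text{diag}(\cdot)\, \mathbf{x}$ exactly. The data, coefficients and helper names are hypothetical illustrations.

```python
import numpy as np

rng = np.random.default_rng(0)
n, d = 50, 3
x = rng.normal(size=(n, d))                    # model matrix, one row per sample
beta = np.array([0.2, -0.1, 0.3])              # arbitrary coefficients
y = rng.poisson(np.exp(x @ beta)).astype(float)


def log_likelihood(b):
    # Poisson log-likelihood, dropping the constant -sum(log(y_i!)).
    eta = x @ b
    return np.sum(y * eta - np.exp(eta))


# Poisson with log link: T(y) = y and Mean_T = Var_T = Mean_T' = exp(eta).
mean = var = grad_mean = np.exp(x @ beta)
grad_formula = x.T @ (grad_mean / var * (y - mean))
fisher_formula = x.T @ ((grad_mean**2 / var)[:, None] * x)   # = -E[Hessian]


def numerical_grad(f, b, h=1e-5):
    # Central finite differences, one coordinate at a time.
    return np.array([(f(b + h * e) - f(b - h * e)) / (2 * h)
                     for e in np.eye(b.size)])


def numerical_hessian(f, b, h=1e-4):
    # Finite differences of the finite-difference gradient, column by column.
    cols = [(numerical_grad(f, b + h * e) - numerical_grad(f, b - h * e)) / (2 * h)
            for e in np.eye(b.size)]
    return np.stack(cols, axis=1)
```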
An Example
In this section we briefly describe and showcase two built-in GLM fitting algorithms in TensorFlow Probability: Fisher scoring (`tfp.glm.fit`) and coordinatewise proximal gradient descent (`tfp.glm.fit_sparse`).
Synthetic Data Set
Let's pretend to load some training data set.
```python
import numpy as np
import pandas as pd
import scipy.stats  # `scipy.stats.norm` is used in the verification code below.
import tensorflow.compat.v2 as tf
tf.enable_v2_behavior()

import tensorflow_probability as tfp

tfd = tfp.distributions


def make_dataset(n, d, link, scale=1., dtype=np.float32):
    # Random "true" coefficients, rescaled to have L2 norm sqrt(2).
    model_coefficients = tfd.Uniform(
        low=-1., high=np.array(1, dtype)).sample(d, seed=42)
    radius = np.sqrt(2.)
    model_coefficients *= radius / tf.linalg.norm(model_coefficients)
    # Zero out a random half of the coefficients so the true model is sparse.
    mask = tf.random.shuffle(tf.range(d)) < int(0.5 * d)
    model_coefficients = tf.where(
        mask, model_coefficients, np.array(0., dtype))
    # Standard-normal model matrix with one row per sample.
    model_matrix = tfd.Normal(
        loc=0., scale=np.array(1, dtype)).sample([n, d], seed=43)
    scale = tf.convert_to_tensor(scale, dtype)
    linear_response = tf.linalg.matvec(model_matrix, model_coefficients)

    if link == 'linear':
        response = tfd.Normal(loc=linear_response, scale=scale).sample(seed=44)
    elif link == 'probit':
        # Thresholding a noisy linear response at zero gives probit labels.
        response = tf.cast(
            tfd.Normal(loc=linear_response, scale=scale).sample(seed=44) > 0,
            dtype)
    elif link == 'logit':
        # Cast to `dtype` so all branches return the same dtype.
        response = tfd.Bernoulli(
            logits=linear_response, dtype=dtype).sample(seed=44)
    else:
        raise ValueError('unrecognized true link: {}'.format(link))
    return model_matrix, response, model_coefficients, mask
```
Note: Connect to a local runtime.
In this notebook, we share data between Python and R kernels using local files. To enable this sharing, please use runtimes on the same machine where you have permission to read and write local files.
```python
x, y, model_coefficients_true, _ = [t.numpy() for t in make_dataset(
    n=int(1e5), d=100, link='probit')]

DATA_DIR = '/tmp/glm_example'
tf.io.gfile.makedirs(DATA_DIR)
with tf.io.gfile.GFile('{}/x.csv'.format(DATA_DIR), 'w') as f:
    np.savetxt(f, x, delimiter=',')
with tf.io.gfile.GFile('{}/y.csv'.format(DATA_DIR), 'w') as f:
    np.savetxt(f, y.astype(np.int32) + 1, delimiter=',', fmt='%d')
with tf.io.gfile.GFile(
        '{}/model_coefficients_true.csv'.format(DATA_DIR), 'w') as f:
    np.savetxt(f, model_coefficients_true, delimiter=',')
```
Without L1 Regularization
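The function `tfp.glm.fit` implements Fisher scoring, which takes as some of its arguments:

- `model_matrix` = $\mathbf{x}$
- `response` = $\mathbf{y}$
- `model` = callable which, given argument $\boldsymbol{\eta}$, returns the triple $\left({\textbf{Mean}_T}(\boldsymbol{\eta}), {\textbf{Var}_T}(\boldsymbol{\eta}), {\textbf{Mean}_T}'(\boldsymbol{\eta})\right)$.

We recommend that `model` be an instance of the `tfp.glm.ExponentialFamily` class. There are several pre-made implementations available, so for most common GLMs no custom code is necessary.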
```python
@tf.function(autograph=False)
def fit_model():
    model_coefficients, linear_response, is_converged, num_iter = tfp.glm.fit(
        model_matrix=x, response=y, model=tfp.glm.BernoulliNormalCDF())
    log_likelihood = tfp.glm.BernoulliNormalCDF().log_prob(y, linear_response)
    return (model_coefficients, linear_response, is_converged, num_iter,
            log_likelihood)

[model_coefficients, linear_response, is_converged, num_iter,
 log_likelihood] = [t.numpy() for t in fit_model()]

print(('is_converged: {}\n'
       '    num_iter: {}\n'
       '    accuracy: {}\n'
       '    deviance: {}\n'
       '||w0-w1||_2 / (1+||w0||_2): {}'
      ).format(
    is_converged,
    num_iter,
    np.mean((linear_response > 0.) == y),
    # Mean deviance per sample: -2 * log-likelihood (saturated model has log-likelihood 0).
    -2. * np.mean(log_likelihood),
    np.linalg.norm(model_coefficients_true - model_coefficients, ord=2) /
    (1. + np.linalg.norm(model_coefficients_true, ord=2))
))
```

```
is_converged: True
    num_iter: 6
    accuracy: 0.75241
    deviance: 0.992436110973
||w0-w1||_2 / (1+||w0||_2): 0.0231555201462
```
Mathematical Details
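Fisher scoring is a modification of Newton's method to find the maximum-likelihood estimate

$$
\hat\beta := \underset{\beta}{\text{arg max}}\ \ \ell(\beta \,;\, \mathbf{x}, \mathbf{y}).
$$

Vanilla Newton's method, searching for zeros of the gradient of the log-likelihood, would follow the update rule

$$
\beta^{(t+1)}_{\text{Newton}} := \beta^{(t)} - \alpha \left(\nabla^2_\beta\, \ell(\beta \,;\, \mathbf{x}, \mathbf{y})\right)_{\beta = \beta^{(t)}}^{-1} \left(\nabla_\beta\, \ell(\beta \,;\, \mathbf{x}, \mathbf{y})\right)_{\beta = \beta^{(t)}}
$$

where $\alpha \in (0, 1]$ is a learning rate used to control the step size.

In Fisher scoring, we replace the Hessian with the negative Fisher information matrix:

$$
\beta^{(t+1)} := \beta^{(t)} - \alpha\, \mathbb{E}_{Y_i \sim p_{\text{OEF}(m, T)}(\cdot \,|\, \theta = h(x_i^\top \beta^{(t)}), \phi)}\left[\left(\nabla^2_\beta\, \ell(\beta \,;\, \mathbf{x}, \mathbf{Y})\right)_{\beta = \beta^{(t)}}\right]^{-1} \left(\nabla_\beta\, \ell(\beta \,;\, \mathbf{x}, \mathbf{y})\right)_{\beta = \beta^{(t)}}
$$

[Note that here $\mathbf{Y} = (Y_i)_{i=1}^{n}$ is random, whereas $\mathbf{y}$ is still the vector of observed responses.]

By the formulas in "Fitting GLM Parameters to Data" below, this simplifies to

$$
\beta^{(t+1)} = \beta^{(t)} + \alpha \left(\mathbf{x}^\top \text{diag}\left(\frac{{\textbf{Mean}_T}'(\mathbf{x} \beta^{(t)})^2}{{\textbf{Var}_T}(\mathbf{x} \beta^{(t)})}\right) \mathbf{x}\right)^{-1} \left(\mathbf{x}^\top \text{diag}\left(\frac{{\textbf{Mean}_T}'(\mathbf{x} \beta^{(t)})}{{\textbf{Var}_T}(\mathbf{x} \beta^{(t)})}\right) \left(\mathbf{T}(\mathbf{y}) - {\textbf{Mean}_T}(\mathbf{x} \beta^{(t)})\right)\right).
$$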
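The simplified update is easy to iterate directly. Below is a minimal NumPy sketch of Fisher scoring for probit regression with $\alpha = 1$, on small synthetic data generated the same way as `make_dataset(..., link='probit')`. The data size, seed, stopping rule and helper names are assumptions for illustration. This is not how `tfp.glm.fit` is implemented internally.

```python
import numpy as np
from scipy.stats import norm

rng = np.random.default_rng(1)
n, d = 2000, 4
x = rng.normal(size=(n, d))
beta_true = np.array([0.5, -1.0, 0.0, 0.25])
y = (x @ beta_true + rng.normal(size=n) > 0).astype(float)   # probit labels


def probit_triple(eta):
    # (Mean_T, Var_T, Mean_T') for a Bernoulli response with probit link (phi = 1).
    mean = np.clip(norm.cdf(eta), 1e-12, 1 - 1e-12)   # clip to keep Var_T > 0
    return mean, mean * (1 - mean), norm.pdf(eta)


def score_and_fisher(beta):
    mean, var, grad_mean = probit_triple(x @ beta)
    score = x.T @ (grad_mean / var * (y - mean))           # gradient of log-likelihood
    fisher = x.T @ ((grad_mean**2 / var)[:, None] * x)     # Fisher information
    return score, fisher


def fisher_scoring(alpha=1.0, max_iters=50, tol=1e-10):
    beta = np.zeros(d)
    for num_iters in range(1, max_iters + 1):
        score, fisher = score_and_fisher(beta)
        step = np.linalg.solve(fisher, score)   # solve instead of forming the inverse
        beta = beta + alpha * step
        if np.max(np.abs(step)) < tol:
            break
    return beta, num_iters


beta_hat, num_iters = fisher_scoring()
```

Solving the linear system rather than inverting the Fisher information matrix is the usual numerically safer choice. Only a $d \times d$ system appears, so each iteration costs $O(n d^2)$.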
With L1 Regularization
`tfp.glm.fit_sparse` implements a GLM fitter more suited to sparse data sets, based on the algorithm in Yuan, Ho and Lin 2012 [1]. Its features include:

- L1 regularization
- No matrix inversions
- Few evaluations of the gradient and Hessian.

We first present an example usage of the code. Details of the algorithm are further elaborated in "Algorithm Details for `tfp.glm.fit_sparse`" below.
```python
model = tfp.glm.Bernoulli()
model_coefficients_start = tf.zeros(x.shape[-1], np.float32)

@tf.function(autograph=False)
def fit_model():
    return tfp.glm.fit_sparse(
        model_matrix=tf.convert_to_tensor(x),
        response=tf.convert_to_tensor(y),
        model=model,
        model_coefficients_start=model_coefficients_start,
        l1_regularizer=800.,
        l2_regularizer=None,
        maximum_iterations=10,
        maximum_full_sweeps_per_iteration=10,
        tolerance=1e-6,
        learning_rate=None)

model_coefficients, is_converged, num_iter = [t.numpy() for t in fit_model()]

coefs_comparison = pd.DataFrame({
    'Learned': model_coefficients,
    'True': model_coefficients_true,
})

print(('is_converged: {}\n'
       '    num_iter: {}\n\n'
       'Coefficients:').format(
    is_converged,
    num_iter))
coefs_comparison
```

```
is_converged: True
    num_iter: 1

Coefficients:
```
Note that the learned coefficients have the same sparsity pattern as the true coefficients.
```python
# Save the learned coefficients to a file.
with tf.io.gfile.GFile('{}/model_coefficients_prox.csv'.format(DATA_DIR), 'w') as f:
    np.savetxt(f, model_coefficients, delimiter=',')
```
Compare to R's glmnet
We compare the output of coordinatewise proximal gradient descent to that of R's `glmnet`, which uses a similar algorithm.
NOTE: To execute this section, you must switch to an R colab runtime.
```r
suppressMessages({
  library('glmnet')
})

data_dir <- '/tmp/glm_example'
x <- as.matrix(read.csv(paste(data_dir, '/x.csv', sep=''),
                        header=FALSE))
y <- as.matrix(read.csv(paste(data_dir, '/y.csv', sep=''),
                        header=FALSE, colClasses='integer'))

fit <- glmnet(
  x = x,
  y = y,
  family = "binomial",  # Logistic regression
  alpha = 1,            # corresponds to l1_weight = 1, l2_weight = 0
  standardize = FALSE,
  intercept = FALSE,
  thresh = 1e-30,
  type.logistic = "Newton"
)

write.csv(as.matrix(coef(fit, 0.008)),
          paste(data_dir, '/model_coefficients_glmnet.csv', sep=''),
          row.names=FALSE)
```
Compare R, TFP and true coefficients (Note: back to Python kernel)
```python
DATA_DIR = '/tmp/glm_example'
with tf.io.gfile.GFile('{}/model_coefficients_glmnet.csv'.format(DATA_DIR),
                       'r') as f:
    model_coefficients_glmnet = np.loadtxt(f,
        skiprows=2  # Skip column name and intercept
    )

with tf.io.gfile.GFile('{}/model_coefficients_prox.csv'.format(DATA_DIR),
                       'r') as f:
    model_coefficients_prox = np.loadtxt(f)

with tf.io.gfile.GFile(
        '{}/model_coefficients_true.csv'.format(DATA_DIR), 'r') as f:
    model_coefficients_true = np.loadtxt(f)

coefs_comparison = pd.DataFrame({
    'TFP': model_coefficients_prox,
    'R': model_coefficients_glmnet,
    'True': model_coefficients_true,
})
coefs_comparison
```
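Why does the R cell ask for `coef(fit, 0.008)` while the TFP cell used `l1_regularizer=800.`? Assume, as glmnet's documentation describes, that glmnet with `alpha = 1` minimizes the negative log-likelihood averaged over the $n$ observations plus $\lambda \lVert \beta \rVert_1$. Take the TFP objective to be $-\ell + r_{\text{L1}} \lVert \beta \rVert_1$. The two objectives then differ only by the constant factor $n = 10^5$, so they share a minimizer when $\lambda = r_{\text{L1}} / n$:

$$
-\ell(\beta \,;\, \mathbf{x}, \mathbf{y}) + r_{\text{L1}} \lVert \beta \rVert_1
= n \left( -\frac{1}{n}\, \ell(\beta \,;\, \mathbf{x}, \mathbf{y}) + \frac{r_{\text{L1}}}{n} \lVert \beta \rVert_1 \right),
\qquad
\lambda = \frac{r_{\text{L1}}}{n} = \frac{800}{10^5} = 0.008.
$$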
Algorithm Details for tfp.glm.fit_sparse
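We present the algorithm as a sequence of three modifications to Newton's method. In each one, the update rule for $\beta$ is based on a vector $s^{(t)}$ and a matrix $H^{(t)}$ which approximate the gradient and Hessian of the log-likelihood. In step $t$, we choose a coordinate $j^{(t)}$ to change, and we update $\beta$ according to the update rule:

$$
\begin{aligned}
u^{(t)} &:= \frac{\left(s^{(t)}\right)_{j^{(t)}}}{\left(H^{(t)}\right)_{j^{(t)},\, j^{(t)}}} \\
\beta^{(t+1)} &:= \beta^{(t)} - \alpha\, u^{(t)}\, \text{onehot}(j^{(t)})
\end{aligned}
$$

This update is a Newton-like step with learning rate $\alpha$. Except for the final piece (L1 regularization), the modifications below differ only in how they update $s^{(t)}$ and $H^{(t)}$.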
Starting point: Coordinatewise Newton's method
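In coordinatewise Newton's method, we set $s^{(t)}$ and $H^{(t)}$ to the true gradient and Hessian of the log-likelihood:

$$
\begin{aligned}
s^{(t)}_{\text{vanilla}} &:= \left(\nabla_\beta\, \ell(\beta \,;\, \mathbf{x}, \mathbf{y})\right)_{\beta = \beta^{(t)}} \\
H^{(t)}_{\text{vanilla}} &:= \left(\nabla^2_\beta\, \ell(\beta \,;\, \mathbf{x}, \mathbf{y})\right)_{\beta = \beta^{(t)}}
\end{aligned}
$$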
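Here is a minimal NumPy sketch of this starting point for a logistic-regression log-likelihood. It recomputes the exact gradient and Hessian before every single-coordinate step and cycles $j^{(t)}$ through the coordinates in order. The cyclic order, the data and the helper names are assumptions for illustration. Recomputing the full Hessian at every step is deliberately wasteful; the modifications below exist to avoid it.

```python
import numpy as np

rng = np.random.default_rng(2)
n, d = 500, 5
x = rng.normal(size=(n, d))
beta_true = np.array([1.0, -1.0, 0.5, 0.0, 0.0])
y = (rng.random(n) < 1.0 / (1.0 + np.exp(-(x @ beta_true)))).astype(float)


def log_likelihood(beta):
    eta = x @ beta
    return np.sum(y * eta - np.logaddexp(0.0, eta))


def grad_and_hessian(beta):
    # Exact gradient and Hessian of the logistic-regression log-likelihood.
    p = 1.0 / (1.0 + np.exp(-(x @ beta)))
    grad = x.T @ (y - p)
    hess = -x.T @ ((p * (1.0 - p))[:, None] * x)
    return grad, hess


def coordinatewise_newton(beta, num_sweeps=50, alpha=1.0):
    beta = beta.astype(float).copy()
    for _ in range(num_sweeps):
        for j in range(beta.size):            # cycle j^{(t)} through the coordinates
            s, H = grad_and_hessian(beta)     # s^{(t)}_vanilla, H^{(t)}_vanilla
            u = s[j] / H[j, j]
            beta[j] -= alpha * u              # beta - alpha * u * onehot(j)
    return beta


beta_hat = coordinatewise_newton(np.zeros(d))
```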
Fewer evaluations of the gradient and Hessian
The gradient and Hessian of the log-likelihood are often expensive to compute, so it is usually worthwhile to approximate them. We can do so as follows:

- Usually, approximate the Hessian as locally constant and approximate the gradient to first order using the (approximate) Hessian:

  $$
  \begin{aligned}
  H_{\text{approx}}^{(t+1)} &:= H^{(t)} \\
  s_{\text{approx}}^{(t+1)} &:= s^{(t)} + H^{(t)} \left(\beta^{(t+1)} - \beta^{(t)}\right)
  \end{aligned}
  $$

- Occasionally, perform a "vanilla" update step as above, setting $s^{(t+1)}$ to the exact gradient and $H^{(t+1)}$ to the exact Hessian of the log-likelihood, evaluated at $\beta^{(t+1)}$.
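The first-order gradient update can be checked numerically. The NumPy sketch below uses a Normal/identity-link log-likelihood with unit variance. There the Hessian really is constant, so $s^{(t)} + H^{(t)}(\beta^{(t+1)} - \beta^{(t)})$ reproduces the exact new gradient. For a logistic log-likelihood it is only approximate, and the error shrinks roughly quadratically with the step size. The data, step sizes and helper names are hypothetical.

```python
import numpy as np

rng = np.random.default_rng(3)
x = rng.normal(size=(200, 4))
y_normal = rng.normal(size=200)
y_binary = (rng.random(200) < 0.5).astype(float)


def normal_grad(beta):
    # Gradient of -0.5 * ||y - x beta||^2 (Normal GLM, identity link, unit variance).
    return x.T @ (y_normal - x @ beta)


def logistic_grad_and_hessian(beta):
    p = 1.0 / (1.0 + np.exp(-(x @ beta)))
    return x.T @ (y_binary - p), -x.T @ ((p * (1.0 - p))[:, None] * x)


def single_coordinate_step(beta, j, delta):
    new_beta = beta.copy()
    new_beta[j] += delta
    return new_beta


# Normal case: the Hessian -x^T x is constant, so the approximation is exact.
beta0 = np.zeros(4)
beta1 = single_coordinate_step(beta0, 2, 0.3)
s_approx_normal = normal_grad(beta0) + (-x.T @ x) @ (beta1 - beta0)

# Logistic case: approximation error for a small and a larger step.
s0, H0 = logistic_grad_and_hessian(beta0)
errors = []
for delta in (0.01, 0.1):
    beta1 = single_coordinate_step(beta0, 2, delta)
    s_exact, _ = logistic_grad_and_hessian(beta1)
    errors.append(np.linalg.norm(s0 + H0 @ (beta1 - beta0) - s_exact))
```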
Substitute negative Fisher information for Hessian
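To further reduce the cost of the vanilla update steps, we can set $H^{(t+1)}$ to the negative Fisher information matrix rather than the exact Hessian. The negative Fisher information matrix is efficiently computable using the formulas in "Fitting GLM Parameters to Data" below:

$$
\begin{aligned}
\tilde{H}^{(t+1)}_{\text{vanilla}} &:= \mathbb{E}_{Y_i \sim p_{\text{OEF}(m, T)}(\cdot \,|\, \theta = h(x_i^\top \beta^{(t+1)}), \phi)}\left[\left(\nabla_\beta^2\, \ell(\beta \,;\, \mathbf{x}, \mathbf{Y})\right)_{\beta = \beta^{(t+1)}}\right] \\
&= -\mathbf{x}^\top \,\text{diag}\left(\frac{{\textbf{Mean}_T}'(\mathbf{x} \beta^{(t+1)})^2}{{\textbf{Var}_T}(\mathbf{x} \beta^{(t+1)})}\right) \mathbf{x} \\
s^{(t+1)}_{\text{vanilla}} &:= \mathbf{x}^\top \,\text{diag}\left(\frac{{\textbf{Mean}_T}'(\mathbf{x} \beta^{(t+1)})}{{\textbf{Var}_T}(\mathbf{x} \beta^{(t+1)})}\right) \left(\mathbf{T}(\mathbf{y}) - {\textbf{Mean}_T}(\mathbf{x} \beta^{(t+1)})\right)
\end{aligned}
$$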
L1 Regularization via Proximal Gradient Descent
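To incorporate L1 regularization, we replace the update rule

$$
\beta^{(t+1)} := \beta^{(t)} - \alpha\, u^{(t)}\, \text{onehot}(j^{(t)})
$$

with the more general update rule

$$
\begin{aligned}
\gamma^{(t)} &:= -\frac{\alpha\, r_{\text{L1}}}{\left(H^{(t)}\right)_{j^{(t)},\, j^{(t)}}} \\
\left(\beta_{\text{reg}}^{(t+1)}\right)_j &:= \begin{cases}
\beta^{(t+1)}_j & \text{if } j \neq j^{(t)} \\
\text{SoftThreshold}\left(\beta^{(t)}_j - \alpha\, u^{(t)},\ \gamma^{(t)}\right) & \text{if } j = j^{(t)}
\end{cases}
\end{aligned}
$$

where $r_{\text{L1}} > 0$ is a supplied constant (the L1 regularization coefficient) and $\text{SoftThreshold}$ is the soft thresholding operator, defined by

$$
\text{SoftThreshold}(\beta, \gamma) := \begin{cases}
\beta + \gamma & \text{if } \beta < -\gamma \\
0 & \text{if } -\gamma \leq \beta \leq \gamma \\
\beta - \gamma & \text{if } \beta > \gamma.
\end{cases}
$$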
This update rule has the following two inspirational properties, which we explain below:
1. In the degenerate case $r_{\text{L1}} = 0$ (i.e., no L1 regularization), this update rule is identical to the original update rule.
2. This update rule can be interpreted as applying a proximity operator whose fixed point is the solution to the L1-regularized minimization problem

   $$
   \underset{\beta - \beta^{(t)} \in \text{span}\{ \text{onehot}(j^{(t)}) \}}{\text{arg min}} \left( -\ell(\beta \,;\, \mathbf{x}, \mathbf{y}) + r_{\text{L1}} \left\lVert \beta \right\rVert_1 \right).
   $$
Degenerate case recovers the original update rule
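To see (1), note that if $r_{\text{L1}} = 0$ then $\gamma^{(t)} = 0$, hence

$$
\left(\beta_{\text{reg}}^{(t+1)}\right)_{j^{(t)}} = \text{SoftThreshold}\left(\beta^{(t)}_{j^{(t)}} - \alpha\, u^{(t)},\ 0\right) = \beta^{(t)}_{j^{(t)}} - \alpha\, u^{(t)}.
$$

Hence

$$
\beta_{\text{reg}}^{(t+1)} = \beta^{(t)} - \alpha\, u^{(t)}\, \text{onehot}(j^{(t)}) = \beta^{(t+1)}.
$$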
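The soft thresholding operator and the regularized single-coordinate update translate directly into NumPy. The sketch below writes SoftThreshold in the compact form $\text{sign}(\beta) \max(|\beta| - \gamma, 0)$, which agrees with the three-case definition above for $\gamma \geq 0$. It also checks the degenerate case $r_{\text{L1}} = 0$. The function names and the small $s$, $H$ values are hypothetical.

```python
import numpy as np


def soft_threshold(beta, gamma):
    """SoftThreshold(beta, gamma) as defined above, for gamma >= 0 (elementwise)."""
    return np.sign(beta) * np.maximum(np.abs(beta) - gamma, 0.0)


def regularized_coordinate_update(beta, j, s, H, alpha, r_l1):
    """One L1-regularized coordinate step, returning beta_reg^{(t+1)}.

    s and H approximate the gradient and Hessian of the log-likelihood;
    H[j, j] is assumed negative, so gamma is nonnegative.
    """
    u = s[j] / H[j, j]
    gamma = -alpha * r_l1 / H[j, j]
    new_beta = beta.copy()                  # coordinates other than j are unchanged
    new_beta[j] = soft_threshold(beta[j] - alpha * u, gamma)
    return new_beta


beta = np.array([0.5, -0.2, 1.0])
s = np.array([0.3, -1.2, 0.4])              # hypothetical gradient values
H = -np.diag([2.0, 3.0, 1.5])               # hypothetical negative definite Hessian
```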
Proximity operator whose fixed point is the regularized MLE
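To see (2), first note (see Wikipedia [3]) that for any $\gamma > 0$, the update rule

$$
\left(\beta_{\text{exact-prox}, \gamma}^{(t+1)}\right)_{j^{(t)}} := \text{prox}_{\gamma \lVert \cdot \rVert_1}\left(\beta^{(t)}_{j^{(t)}} + \frac{\gamma}{r_{\text{L1}}} \left(\left(\nabla_\beta\, \ell(\beta \,;\, \mathbf{x}, \mathbf{y})\right)_{\beta = \beta^{(t)}}\right)_{j^{(t)}}\right)
$$

satisfies (2), where $\text{prox}$ is the proximity operator (see Yu [4], who uses different notation for this operator). The right-hand side of the above equation is computed in [2]:

$$
\left(\beta_{\text{exact-prox}, \gamma}^{(t+1)}\right)_{j^{(t)}} = \text{SoftThreshold}\left(\beta^{(t)}_{j^{(t)}} + \frac{\gamma}{r_{\text{L1}}} \left(\left(\nabla_\beta\, \ell(\beta \,;\, \mathbf{x}, \mathbf{y})\right)_{\beta = \beta^{(t)}}\right)_{j^{(t)}},\ \gamma\right).
$$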
In particular, setting $\gamma^{(t)} := -\frac{\alpha\, r_{\text{L1}}}{\left(H^{(t)}\right)_{j^{(t)},\, j^{(t)}}}$, we obtain the update rule below. Note that $\gamma^{(t)} > 0$ whenever $\left(H^{(t)}\right)_{j^{(t)},\, j^{(t)}} < 0$, e.g. when the negative log-likelihood is strictly convex.

$$
\left(\beta_{\text{exact-prox}, \gamma^{(t)}}^{(t+1)}\right)_{j^{(t)}} = \text{SoftThreshold}\left(\beta^{(t)}_{j^{(t)}} - \alpha \frac{\left(\left(\nabla_\beta\, \ell(\beta \,;\, \mathbf{x}, \mathbf{y})\right)_{\beta = \beta^{(t)}}\right)_{j^{(t)}}}{\left(H^{(t)}\right)_{j^{(t)},\, j^{(t)}}},\ \gamma^{(t)}\right).
$$
We then replace the exact gradient $\left(\nabla_\beta\, \ell(\beta \,;\, \mathbf{x}, \mathbf{y})\right)_{\beta = \beta^{(t)}}$ with its approximation $s^{(t)}$, obtaining

$$
\begin{aligned}
\left(\beta_{\text{exact-prox}, \gamma^{(t)}}^{(t+1)}\right)_{j^{(t)}} &\approx \text{SoftThreshold}\left(\beta^{(t)}_{j^{(t)}} - \alpha \frac{\left(s^{(t)}\right)_{j^{(t)}}}{\left(H^{(t)}\right)_{j^{(t)},\, j^{(t)}}},\ \gamma^{(t)}\right) \\
&= \text{SoftThreshold}\left(\beta^{(t)}_{j^{(t)}} - \alpha\, u^{(t)},\ \gamma^{(t)}\right).
\end{aligned}
$$

Hence

$$
\beta_{\text{exact-prox}, \gamma^{(t)}}^{(t+1)} \approx \beta_{\text{reg}}^{(t+1)}.
$$
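Putting the pieces together, the NumPy sketch below implements a simplified L1-regularized coordinatewise proximal method for logistic regression. It is not TFP's `fit_sparse`, and it assumes fixed $\alpha = 1$, no line search, cyclic coordinate order, and one full sweep per vanilla step. At the start of each sweep it performs a vanilla step with the Fisher-information substitution; for the canonical logit link, the negative Fisher information equals the Hessian. Within the sweep it holds $H$ fixed, updates $s$ to first order, and soft-thresholds each coordinate. The tests check the optimality (KKT) conditions of $-\ell + r_{\text{L1}} \lVert \beta \rVert_1$, which the fixed point should satisfy. The data and $r_{\text{L1}}$ value are hypothetical.

```python
import numpy as np

rng = np.random.default_rng(4)
x = rng.normal(size=(1000, 6))
beta_true = np.array([1.5, -1.0, 0.0, 0.0, 0.5, 0.0])
y = (rng.random(1000) < 1.0 / (1.0 + np.exp(-(x @ beta_true)))).astype(float)


def soft_threshold(v, gamma):
    return np.sign(v) * np.maximum(np.abs(v) - gamma, 0.0)


def logistic_grad_and_neg_fisher(beta):
    p = 1.0 / (1.0 + np.exp(-(x @ beta)))
    s = x.T @ (y - p)                                  # gradient of log-likelihood
    H = -x.T @ ((p * (1.0 - p))[:, None] * x)          # negative Fisher information
    return s, H


def l1_coordinatewise_prox(r_l1, num_sweeps=100, alpha=1.0):
    beta = np.zeros(x.shape[1])
    for _ in range(num_sweeps):
        s, H = logistic_grad_and_neg_fisher(beta)      # "vanilla" step
        for j in range(beta.size):
            u = s[j] / H[j, j]
            gamma = -alpha * r_l1 / H[j, j]
            new_bj = soft_threshold(beta[j] - alpha * u, gamma)
            s = s + H[:, j] * (new_bj - beta[j])       # s + H (beta_new - beta_old)
            beta[j] = new_bj
    return beta


r_l1 = 20.0
beta_hat = l1_coordinatewise_prox(r_l1)
```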
Derivation of GLM Facts
In this section we state in full detail and derive the results about GLMs that are used in the preceding sections. Then, we use TensorFlow's `gradients` to numerically verify the derived formulas for the gradient of the log-likelihood and the Fisher information.
Score and Fisher information
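Consider a family of probability distributions $\{p(\cdot \,|\, \theta)\}_{\theta \in \mathcal{T}}$ parameterized by parameter vector $\theta$, having probability densities $p(y \,|\, \theta)$. The score of an outcome $y$ at parameter vector $\theta_0$ is defined to be the gradient of the log likelihood of $y$ (evaluated at $\theta_0$), that is,

$$
\text{score}(y, \theta_0) := \left[\nabla_\theta\, \log p(y \,|\, \theta)\right]_{\theta=\theta_0}.
$$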
Claim: Expectation of the score is zero
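Under mild regularity conditions (permitting us to pass differentiation under the integral),

$$
\mathbb{E}_{Y \sim p(\cdot \,|\, \theta=\theta_0)}\left[\text{score}(Y, \theta_0)\right] = 0.
$$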
Proof
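We have

$$
\begin{aligned}
\mathbb{E}_{Y \sim p(\cdot | \theta=\theta_0)}\left[\text{score}(Y, \theta_0)\right]
&:= \mathbb{E}_{Y \sim p(\cdot | \theta=\theta_0)}\left[\left(\nabla_\theta \log p(Y | \theta)\right)_{\theta=\theta_0}\right] \\
&\stackrel{(1)}{=} \mathbb{E}_{Y \sim p(\cdot | \theta=\theta_0)}\left[\frac{\left(\nabla_\theta\, p(Y | \theta)\right)_{\theta=\theta_0}}{p(Y | \theta=\theta_0)}\right] \\
&\stackrel{(2)}{=} \int_{\mathcal{Y}} \left[\frac{\left(\nabla_\theta\, p(y | \theta)\right)_{\theta=\theta_0}}{p(y | \theta=\theta_0)}\right] p(y | \theta=\theta_0)\, dy \\
&= \int_{\mathcal{Y}} \left(\nabla_\theta\, p(y | \theta)\right)_{\theta=\theta_0} dy \\
&\stackrel{(3)}{=} \left[\nabla_\theta \left(\int_{\mathcal{Y}} p(y | \theta)\, dy\right)\right]_{\theta=\theta_0} \\
&\stackrel{(4)}{=} \left[\nabla_\theta\, 1\right]_{\theta=\theta_0} \\
&= 0,
\end{aligned}
$$

where we have used: (1) chain rule for differentiation, (2) definition of expectation, (3) passing differentiation under the integral sign (using the regularity conditions), (4) the integral of a probability density is 1.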
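For a concrete check, take a Poisson distribution parameterized by its rate $\lambda$ (an example family chosen here for illustration). The score is $y / \lambda - 1$. The NumPy/SciPy sketch below sums the score against the probability mass function over a truncated support and gets zero, as claimed.

```python
import numpy as np
from scipy.stats import poisson

lam0 = 3.7                      # parameter value theta_0 (the Poisson rate)
ys = np.arange(200)             # truncated support; P(Y >= 200) is negligible here
pmf = poisson.pmf(ys, lam0)
score = ys / lam0 - 1.0         # d/d(lam) [y log(lam) - lam - log(y!)] at lam0
expected_score = np.sum(pmf * score)
```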
Claim (Fisher information): Variance of the score equals negative expected Hessian of the log likelihood
Under mild regularity conditions (permitting us to pass differentiation under the integral),

$$
\mathbb{E}_{Y \sim p(\cdot | \theta=\theta_0)}\left[\text{score}(Y, \theta_0)\, \text{score}(Y, \theta_0)^\top\right] = -\mathbb{E}_{Y \sim p(\cdot | \theta=\theta_0)}\left[\left(\nabla_\theta^2 \log p(Y | \theta)\right)_{\theta=\theta_0}\right]
$$

where $\nabla_\theta^2 F$ denotes the Hessian matrix, whose $(i, j)$ entry is $\frac{\partial^2 F}{\partial \theta_i\, \partial \theta_j}$.

The left-hand side of this equation is called the Fisher information of the family $p(\cdot \,|\, \theta)$ at parameter vector $\theta_0$.
Proof of claim
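We have

$$
\begin{aligned}
\mathbb{E}\left[\left(\nabla_\theta^2 \log p(Y | \theta)\right)_{\theta=\theta_0}\right]
&\stackrel{(1)}{=} \mathbb{E}\left[\left(\nabla_\theta^\top \frac{\nabla_\theta\, p(Y | \theta)}{p(Y | \theta)}\right)_{\theta=\theta_0}\right] \\
&\stackrel{(2)}{=} \mathbb{E}\left[\frac{\left(\nabla^2_\theta\, p(Y | \theta)\right)_{\theta=\theta_0}}{p(Y | \theta=\theta_0)} - \left(\frac{\left(\nabla_\theta\, p(Y | \theta)\right)_{\theta=\theta_0}}{p(Y | \theta=\theta_0)}\right)\left(\frac{\left(\nabla_\theta\, p(Y | \theta)\right)_{\theta=\theta_0}}{p(Y | \theta=\theta_0)}\right)^\top\right] \\
&\stackrel{(3)}{=} \mathbb{E}\left[\frac{\left(\nabla^2_\theta\, p(Y | \theta)\right)_{\theta=\theta_0}}{p(Y | \theta=\theta_0)} - \text{score}(Y, \theta_0)\, \text{score}(Y, \theta_0)^\top\right],
\end{aligned}
$$

where all expectations are over $Y \sim p(\cdot \,|\, \theta=\theta_0)$. We have used (1) chain rule for differentiation, (2) quotient rule for differentiation, and (3) chain rule again, in reverse.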
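To complete the proof, it suffices to show that

$$
\mathbb{E}_{Y \sim p(\cdot | \theta=\theta_0)}\left[\frac{\left(\nabla^2_\theta\, p(Y | \theta)\right)_{\theta=\theta_0}}{p(Y | \theta=\theta_0)}\right] \stackrel{?}{=} 0.
$$

To do that, we pass differentiation under the integral sign twice:

$$
\begin{aligned}
\mathbb{E}_{Y \sim p(\cdot | \theta=\theta_0)}\left[\frac{\left(\nabla^2_\theta\, p(Y | \theta)\right)_{\theta=\theta_0}}{p(Y | \theta=\theta_0)}\right]
&= \int_{\mathcal{Y}} \left[\frac{\left(\nabla^2_\theta\, p(y | \theta)\right)_{\theta=\theta_0}}{p(y | \theta=\theta_0)}\right] p(y | \theta=\theta_0)\, dy \\
&= \int_{\mathcal{Y}} \left(\nabla^2_\theta\, p(y | \theta)\right)_{\theta=\theta_0} dy \\
&= \left[\nabla_\theta^2 \left(\int_{\mathcal{Y}} p(y | \theta)\, dy\right)\right]_{\theta=\theta_0} \\
&= \left[\nabla_\theta^2\, 1\right]_{\theta=\theta_0} \\
&= 0.
\end{aligned}
$$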
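A one-parameter example shows both sides of the claim agreeing exactly. The family is a Bernoulli distribution parameterized directly by its success probability $p$, chosen here for illustration. The support is only $\{0, 1\}$, so the expectations are exact two-term sums. Both sides equal $1 / (p(1 - p))$. With a scalar parameter, $\text{score}\, \text{score}^\top$ is just $\text{score}^2$.

```python
import numpy as np

p0 = 0.3
ys = np.array([0.0, 1.0])
probs = np.array([1.0 - p0, p0])                     # P(Y = 0), P(Y = 1)

# log p(y | p) = y log p + (1 - y) log(1 - p)
score = ys / p0 - (1.0 - ys) / (1.0 - p0)            # first derivative at p0
hessian = -ys / p0**2 - (1.0 - ys) / (1.0 - p0)**2   # second derivative at p0

lhs = np.sum(probs * score**2)                       # E[score^2]
rhs = -np.sum(probs * hessian)                       # -E[Hessian]
```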
Lemma about the derivative of the log partition function
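If $a$, $b$ and $c$ are scalar-valued functions, with $c$ twice differentiable, such that the family of distributions $\{p(\cdot \,|\, \theta)\}_{\theta \in \mathbb{R}}$ defined by

$$
p(y \,|\, \theta) = a(y) \exp\left(b(y)\, \theta - c(\theta)\right)
$$

satisfies the mild regularity conditions that permit passing differentiation with respect to $\theta$ under an integral with respect to $y$, then

$$
\mathbb{E}_{Y \sim p(\cdot | \theta)}\left[b(Y)\right] = c'(\theta)
$$

and

$$
\text{Var}_{Y \sim p(\cdot | \theta)}\left[b(Y)\right] = c''(\theta).
$$

(Here $'$ denotes differentiation, so $c'$ and $c''$ are the first and second derivatives of $c$.)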
Proof
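For this family of distributions, we have $\text{score}(y, \theta_0) = b(y) - c'(\theta_0)$. The first equation then follows from the fact that $\mathbb{E}_{Y \sim p(\cdot | \theta=\theta_0)}\left[\text{score}(Y, \theta_0)\right] = 0$. Next, we have

$$
\begin{aligned}
\text{Var}_{Y \sim p(\cdot | \theta=\theta_0)}\left[b(Y)\right]
&= \mathbb{E}\left[\left(b(Y) - c'(\theta_0)\right)^2\right] \\
&= \mathbb{E}\left[\text{score}(Y, \theta_0)^2\right] \\
&= -\mathbb{E}\left[\left(\frac{\partial^2}{\partial \theta^2} \log p(Y \,|\, \theta)\right)_{\theta=\theta_0}\right] \\
&= -\mathbb{E}\left[-c''(\theta_0)\right] \\
&= c''(\theta_0).
\end{aligned}
$$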
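The lemma can be checked numerically for the geometric distribution on $\{0, 1, 2, \dots\}$, written in this form. Take $a(y) = 1$, $b(y) = y$ and $c(\theta) = -\log(1 - e^\theta)$ for $\theta < 0$. This example family is chosen here for illustration. The sketch compares the mean and variance of $b(Y)$, computed over a truncated support, with finite-difference estimates of $c'$ and $c''$.

```python
import numpy as np

q = 0.6
theta = np.log(q)                     # natural parameter (must be negative)


def c(t):
    # Log partition function of the geometric family: -log(1 - e^t).
    return -np.log1p(-np.exp(t))


ys = np.arange(400)                   # truncated support; q**400 is negligible
pmf = np.exp(ys * theta - c(theta))   # a(y) = 1, b(y) = y
mean_b = np.sum(pmf * ys)
var_b = np.sum(pmf * (ys - mean_b) ** 2)

h = 1e-4
c_prime = (c(theta + h) - c(theta - h)) / (2 * h)                    # ~ c'(theta)
c_double_prime = (c(theta + h) - 2 * c(theta) + c(theta - h)) / h**2  # ~ c''(theta)
```

In closed form, $c'(\theta) = q/(1-q) = 1.5$ and $c''(\theta) = q/(1-q)^2 = 3.75$ for $q = e^\theta = 0.6$.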
Overdispersed Exponential Family
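A (scalar) overdispersed exponential family is a family of distributions whose densities take the form

$$
p_{\text{OEF}(m, T)}(y \,|\, \theta, \phi) = m(y, \phi) \exp\left(\frac{\theta\, T(y) - A(\theta)}{\phi}\right),
$$

where $m$ and $T$ are known scalar-valued functions, and $\theta$ and $\phi$ are scalar parameters.

[Note that $A$ is overdetermined: for any $\phi_0$, the function $A$ is completely determined by the constraint that $\int p_{\text{OEF}(m, T)}(y \,|\, \theta, \phi=\phi_0)\, dy = 1$ for all $\theta$. The $A$'s produced by different values of $\phi_0$ must all be the same, which places a constraint on the functions $m$ and $T$.]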
Mean and variance of the sufficient statistic
Under the same conditions as "Lemma about the derivative of the log partition function," we have

$$
\mathbb{E}_{Y \sim p_{\text{OEF}(m, T)}(\cdot | \theta, \phi)}\left[T(Y)\right] = A'(\theta)
$$

and

$$
\text{Var}_{Y \sim p_{\text{OEF}(m, T)}(\cdot | \theta, \phi)}\left[T(Y)\right] = \phi A''(\theta).
$$
Proof
By "Lemma about the derivative of the log partition function," applied with $a(y) = m(y, \phi)$, $b(y) = T(y)/\phi$ and $c(\theta) = A(\theta)/\phi$, we have

$$
\mathbb{E}_{Y \sim p_{\text{OEF}(m, T)}(\cdot | \theta, \phi)}\left[\frac{T(Y)}{\phi}\right] = \frac{A'(\theta)}{\phi}
$$

and

$$
\text{Var}_{Y \sim p_{\text{OEF}(m, T)}(\cdot | \theta, \phi)}\left[\frac{T(Y)}{\phi}\right] = \frac{A''(\theta)}{\phi}.
$$

The result then follows from the fact that expectation is linear ($\mathbb{E}[aX] = a\, \mathbb{E}[X]$) and variance is degree-2 homogeneous ($\text{Var}[aX] = a^2\, \text{Var}[X]$).
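Consider an overdispersed family with $\phi \neq 1$: the observed proportion $Y = K/N$ with $K \sim \text{Binomial}(N, p)$. This example is chosen here for illustration. It fits the form above with $T(y) = y$, $\theta = \text{logit}(p)$, $A(\theta) = \log(1 + e^\theta)$ and $\phi = 1/N$. The SciPy sketch below computes the mean and variance of $T(Y)$ exactly from the probability mass function and compares them with $A'(\theta)$ and $\phi A''(\theta)$.

```python
import numpy as np
from scipy.stats import binom

num_trials = 20                       # N, so the dispersion is phi = 1 / N
phi = 1.0 / num_trials
theta = 0.4                           # natural parameter: logit of the success probability
p = 1.0 / (1.0 + np.exp(-theta))
a_prime = p                           # A(theta) = log(1 + e^theta)  =>  A' = sigmoid
a_double_prime = p * (1.0 - p)        # A'' = sigmoid * (1 - sigmoid)

k = np.arange(num_trials + 1)
pmf = binom.pmf(k, num_trials, p)
y = k / num_trials                    # T(y) = y is the observed proportion
mean_T = np.sum(pmf * y)
var_T = np.sum(pmf * (y - mean_T) ** 2)
```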
Generalized Linear Model
In a generalized linear model, a predictive distribution for the response variable $Y$ is associated with a vector of observed predictors $x$. The distribution is a member of an overdispersed exponential family, and the parameter $\theta$ is replaced by $h(\eta)$. Here $h$ is a known function, $\eta := x^\top \beta$ is the so-called linear response, and $\beta$ is a vector of parameters (regression coefficients) to be learned. In general the dispersion parameter $\phi$ could be learned too, but in our setup we will treat $\phi$ as known. So our setup is

$$
Y \sim p_{\text{OEF}(m, T)}(\cdot \,|\, \theta = h(\eta), \phi)
$$

where the model structure is characterized by the distribution $p_{\text{OEF}(m, T)}$ and the function $h$ which converts linear response to parameters.
Traditionally, the mapping from linear response $\eta$ to mean $\mu$ is denoted

$$
\mu = g^{-1}(\eta).
$$

This mapping is required to be one-to-one, and its inverse, $g$, is called the link function for this GLM. Typically, one describes a GLM by naming its link function and its family of distributions, e.g., a "GLM with Bernoulli distribution and logit link function" (also known as a logistic regression model). In order to fully characterize the GLM, the function $h$ must also be specified. If $h$ is the identity, then $g$ is said to be the canonical link function.
Claim: Expressing $h'$ in terms of the sufficient statistic
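Define

$$
{\text{Mean}_T}(\eta) := \mathbb{E}_{Y \sim p_{\text{OEF}(m, T)}(\cdot | \theta = h(\eta), \phi)}\left[T(Y)\right]
$$

and

$$
{\text{Var}_T}(\eta) := \text{Var}_{Y \sim p_{\text{OEF}(m, T)}(\cdot | \theta = h(\eta), \phi)}\left[T(Y)\right].
$$

Then we have

$$
h'(\eta) = \frac{\phi\, {\text{Mean}_T}'(\eta)}{{\text{Var}_T}(\eta)}.
$$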
Proof
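By "Mean and variance of the sufficient statistic," we have

$$
{\text{Mean}_T}(\eta) = A'(h(\eta)).
$$

Differentiating with the chain rule, we obtain

$$
{\text{Mean}_T}'(\eta) = A''(h(\eta))\, h'(\eta),
$$

and by "Mean and variance of the sufficient statistic,"

$$
{\text{Mean}_T}'(\eta) = \frac{1}{\phi}\, {\text{Var}_T}(\eta)\, h'(\eta).
$$

The conclusion follows.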
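The claim also explains which links are canonical. The sketch below considers a Bernoulli response ($\phi = 1$). Under the logit link, $\phi\, \text{Mean}_T' / \text{Var}_T$ is identically 1, so $h$ is the identity and the link is canonical. Under the probit link it is not 1. In that case $h(\eta) = \text{logit}(\Phi(\eta))$, and a finite-difference derivative of that $h$ matches the formula. The grid of $\eta$ values and the helper name are illustrative assumptions.

```python
import numpy as np
from scipy.stats import norm

eta = np.linspace(-3.0, 3.0, 13)
phi = 1.0                                   # Bernoulli responses have unit dispersion

# Logit link: Mean_T = sigmoid and Mean_T' = Var_T = sigmoid * (1 - sigmoid).
sig = 1.0 / (1.0 + np.exp(-eta))
h_prime_logit = phi * (sig * (1 - sig)) / (sig * (1 - sig))

# Probit link: Mean_T = Phi (standard normal CDF), Mean_T' = standard normal pdf.
Phi, pdf = norm.cdf(eta), norm.pdf(eta)
h_prime_probit = phi * pdf / (Phi * (1 - Phi))


def h_probit(e):
    # theta = logit(Phi(e)), computed in log space for numerical stability.
    return norm.logcdf(e) - norm.logsf(e)


step = 1e-5
h_prime_fd = (h_probit(eta + step) - h_probit(eta - step)) / (2 * step)
```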
Fitting GLM Parameters to Data
The properties derived above lend themselves very well to fitting GLM parameters to a data set. Quasi-Newton methods such as Fisher scoring rely on the gradient of the log likelihood and the Fisher information, which we now show can be computed especially efficiently for a GLM.
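Suppose we have observed predictor vectors $x_i$ and associated scalar responses $y_i$. In matrix form, we'll say we have observed predictors $\mathbf{x}$ and response $\mathbf{y}$, where $\mathbf{x}$ is the matrix whose $i$th row is $x_i^\top$ and $\mathbf{y}$ is the vector whose $i$th element is $y_i$. The log likelihood of parameters $\beta$ is then

$$
\ell(\beta \,;\, \mathbf{x}, \mathbf{y}) = \sum_{i=1}^{N} \log p_{\text{OEF}(m, T)}(y_i \,|\, \theta = h(x_i^\top \beta), \phi).
$$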
For a single data sample
To simplify the notation, let's first consider the case of a single data point, $N = 1$; then we will extend to the general case by additivity.
Gradient
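We have

$$
\begin{aligned}
\ell(\beta \,;\, x, y) &= \log p_{\text{OEF}(m, T)}(y \,|\, \theta = h(x^\top \beta), \phi) \\
&= \log m(y, \phi) + \frac{\theta\, T(y) - A(\theta)}{\phi}, \quad \text{where } \theta = h(x^\top \beta).
\end{aligned}
$$

Hence by the chain rule,

$$
\nabla_\beta\, \ell(\beta \,;\, x, y) = \frac{T(y) - A'(\theta)}{\phi}\, h'(x^\top \beta)\, x.
$$

Separately, by "Mean and variance of the sufficient statistic," we have $A'(\theta) = {\text{Mean}_T}(x^\top \beta)$. Hence, by "Claim: Expressing $h'$ in terms of the sufficient statistic," we have

$$
\nabla_\beta\, \ell(\beta \,;\, x, y) = \left(T(y) - {\text{Mean}_T}(x^\top \beta)\right) \frac{{\text{Mean}_T}'(x^\top \beta)}{{\text{Var}_T}(x^\top \beta)}\, x.
$$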
Hessian
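Differentiating a second time, by the product rule we obtain

$$
\begin{aligned}
\nabla_\beta^2\, \ell(\beta \,;\, x, y)
&= \frac{1}{\phi}\left[-A''(h(x^\top \beta))\, h'(x^\top \beta)^2 + \left(T(y) - A'(h(x^\top \beta))\right) h''(x^\top \beta)\right] x x^\top \\
&= \frac{1}{\phi}\left[-{\text{Mean}_T}'(x^\top \beta)\, h'(x^\top \beta) + \left(T(y) - {\text{Mean}_T}(x^\top \beta)\right) h''(x^\top \beta)\right] x x^\top.
\end{aligned}
$$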
Fisher information
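By "Mean and variance of the sufficient statistic," we have

$$
\mathbb{E}_{Y \sim p_{\text{OEF}(m, T)}(\cdot | \theta = h(x^\top \beta), \phi)}\left[T(Y) - {\text{Mean}_T}(x^\top \beta)\right] = 0.
$$

Hence

$$
\begin{aligned}
\mathbb{E}_{Y \sim p_{\text{OEF}(m, T)}(\cdot | \theta = h(x^\top \beta), \phi)}\left[\nabla_\beta^2\, \ell(\beta \,;\, x, Y)\right]
&= -\frac{1}{\phi}\, {\text{Mean}_T}'(x^\top \beta)\, h'(x^\top \beta)\, x x^\top \\
&= -\frac{{\text{Mean}_T}'(x^\top \beta)^2}{{\text{Var}_T}(x^\top \beta)}\, x x^\top.
\end{aligned}
$$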
For multiple data samples
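We now extend the $N = 1$ case to the general case. Let $\boldsymbol{\eta} := \mathbf{x} \beta$ denote the vector whose $i$th coordinate is the linear response from the $i$th data sample. Let $\mathbf{T}$ (resp. ${\textbf{Mean}_T}$, resp. ${\textbf{Var}_T}$) denote the broadcasted (vectorized) function which applies the scalar-valued function $T$ (resp. ${\text{Mean}_T}$, resp. ${\text{Var}_T}$) to each coordinate. Then we have

$$
\begin{aligned}
\nabla_\beta\, \ell(\beta \,;\, \mathbf{x}, \mathbf{y})
&= \sum_{i=1}^{N} \nabla_\beta\, \ell(\beta \,;\, x_i, y_i) \\
&= \sum_{i=1}^{N} \left(T(y_i) - {\text{Mean}_T}(x_i^\top \beta)\right) \frac{{\text{Mean}_T}'(x_i^\top \beta)}{{\text{Var}_T}(x_i^\top \beta)}\, x_i \\
&= \mathbf{x}^\top \,\text{diag}\left(\frac{{\textbf{Mean}_T}'(\mathbf{x} \beta)}{{\textbf{Var}_T}(\mathbf{x} \beta)}\right) \left(\mathbf{T}(\mathbf{y}) - {\textbf{Mean}_T}(\mathbf{x} \beta)\right)
\end{aligned}
$$

and

$$
\begin{aligned}
\mathbb{E}_{Y_i \sim p_{\text{OEF}(m, T)}(\cdot | \theta = h(x_i^\top \beta), \phi)}\left[\nabla_\beta^2\, \ell(\beta \,;\, \mathbf{x}, \mathbf{Y})\right]
&= \sum_{i=1}^{N} \mathbb{E}_{Y_i \sim p_{\text{OEF}(m, T)}(\cdot | \theta = h(x_i^\top \beta), \phi)}\left[\nabla_\beta^2\, \ell(\beta \,;\, x_i, Y_i)\right] \\
&= \sum_{i=1}^{N} -\frac{{\text{Mean}_T}'(x_i^\top \beta)^2}{{\text{Var}_T}(x_i^\top \beta)}\, x_i x_i^\top \\
&= -\mathbf{x}^\top \,\text{diag}\left(\frac{{\textbf{Mean}_T}'(\mathbf{x} \beta)^2}{{\textbf{Var}_T}(\mathbf{x} \beta)}\right) \mathbf{x},
\end{aligned}
$$

where the fractions denote element-wise division.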
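The step from the per-sample sums to the matrix form can be checked directly. The NumPy sketch below uses probit regression on small random data. It accumulates the $N = 1$ formulas in a loop and compares the result with the vectorized expressions, then checks the gradient against finite differences of the log-likelihood. Here `fisher_*` holds the (positive) Fisher information, i.e. minus the expected Hessian. The data and helper names are hypothetical.

```python
import numpy as np
from scipy.stats import norm

rng = np.random.default_rng(5)
n, d = 30, 3
x = rng.normal(size=(n, d))
y = rng.integers(0, 2, size=n).astype(float)   # T(y) = y for a Bernoulli response
beta = np.array([0.3, -0.2, 0.1])


def probit_triple(eta):
    # (Mean_T, Var_T, Mean_T') for a Bernoulli response with probit link (phi = 1).
    mean = norm.cdf(eta)
    return mean, mean * (1 - mean), norm.pdf(eta)


# Sum of the single-sample (N = 1) formulas over the data set.
grad_loop = np.zeros(d)
fisher_loop = np.zeros((d, d))
for xi, yi in zip(x, y):
    m, v, g = probit_triple(xi @ beta)
    grad_loop += (yi - m) * g / v * xi
    fisher_loop += g**2 / v * np.outer(xi, xi)

# Vectorized (matrix) formulas.
m, v, g = probit_triple(x @ beta)
grad_vec = x.T @ (g / v * (y - m))
fisher_vec = x.T @ ((g**2 / v)[:, None] * x)


def log_likelihood(b):
    eta = x @ b
    return np.sum(y * norm.logcdf(eta) + (1 - y) * norm.logsf(eta))


eps = 1e-6
grad_fd = np.array([
    (log_likelihood(beta + eps * e) - log_likelihood(beta - eps * e)) / (2 * eps)
    for e in np.eye(d)])
```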
Verifying the Formulas Numerically
We now verify the above formula for the gradient of the log likelihood numerically using `tf.gradients`, and verify the formula for Fisher information with a Monte Carlo estimate using `tf.hessians`:
```python
def VerifyGradientAndFIM():
    model = tfp.glm.BernoulliNormalCDF()
    model_matrix = np.array([[1., 5, -2],
                             [8, -1, 8]])

    def _naive_grad_and_hessian_loss_fn(x, response):
        # Computes gradient and Hessian of negative log likelihood using autodiff.
        predicted_linear_response = tf.linalg.matvec(model_matrix, x)
        log_probs = model.log_prob(response, predicted_linear_response)
        grad_loss = tf.gradients(-log_probs, [x])[0]
        hessian_loss = tf.hessians(-log_probs, [x])[0]
        return [grad_loss, hessian_loss]

    def _grad_neg_log_likelihood_and_fim_fn(x, response):
        # Computes gradient of negative log likelihood and Fisher information matrix
        # using the formulas above.
        predicted_linear_response = tf.linalg.matvec(model_matrix, x)
        mean, variance, grad_mean = model(predicted_linear_response)

        v = (response - mean) * grad_mean / variance
        grad_log_likelihood = tf.linalg.matvec(model_matrix, v, adjoint_a=True)
        w = grad_mean**2 / variance

        fisher_info = tf.linalg.matmul(
            model_matrix,
            w[..., tf.newaxis] * model_matrix,
            adjoint_a=True)
        return [-grad_log_likelihood, fisher_info]

    @tf.function(autograph=False)
    def compute_grad_hessian_estimates():
        # Monte Carlo estimate of E[Hessian(-LogLikelihood)], where the expectation is
        # as written in "Claim (Fisher information)" above.
        num_trials = 20
        trial_outputs = []
        np.random.seed(10)
        model_coefficients_ = np.random.random(size=(model_matrix.shape[1],))
        model_coefficients = tf.convert_to_tensor(model_coefficients_)
        for _ in range(num_trials):
            # Sample from the distribution of `model`
            response = np.random.binomial(
                1,
                scipy.stats.norm().cdf(np.matmul(model_matrix, model_coefficients_))
            ).astype(np.float64)
            trial_outputs.append(
                list(_naive_grad_and_hessian_loss_fn(model_coefficients, response)) +
                list(
                    _grad_neg_log_likelihood_and_fim_fn(model_coefficients, response))
            )

        naive_grads = tf.stack(
            list(naive_grad for [naive_grad, _, _, _] in trial_outputs), axis=0)
        fancy_grads = tf.stack(
            list(fancy_grad for [_, _, fancy_grad, _] in trial_outputs), axis=0)

        average_hess = tf.reduce_mean(tf.stack(
            list(hess for [_, hess, _, _] in trial_outputs), axis=0), axis=0)
        [_, _, _, fisher_info] = trial_outputs[0]
        return naive_grads, fancy_grads, average_hess, fisher_info

    naive_grads, fancy_grads, average_hess, fisher_info = [
        t.numpy() for t in compute_grad_hessian_estimates()]

    print("Coordinatewise relative error between naively computed gradients and"
          " formula-based gradients (should be zero):\n{}\n".format(
              (naive_grads - fancy_grads) / naive_grads))

    print("Coordinatewise relative error between average of naively computed"
          " Hessian and formula-based FIM (should approach zero as num_trials"
          " -> infinity):\n{}\n".format(
              (average_hess - fisher_info) / average_hess))

VerifyGradientAndFIM()
```
```
Coordinatewise relative error between naively computed gradients and formula-based gradients (should be zero):
[[2.08845965e-16 1.67076772e-16 2.08845965e-16]
 [1.96118673e-16 3.13789877e-16 1.96118673e-16]
 [2.08845965e-16 1.67076772e-16 2.08845965e-16]
 [1.96118673e-16 3.13789877e-16 1.96118673e-16]
 [2.08845965e-16 1.67076772e-16 2.08845965e-16]
 [1.96118673e-16 3.13789877e-16 1.96118673e-16]
 [1.96118673e-16 3.13789877e-16 1.96118673e-16]
 [1.96118673e-16 3.13789877e-16 1.96118673e-16]
 [2.08845965e-16 1.67076772e-16 2.08845965e-16]
 [1.96118673e-16 3.13789877e-16 1.96118673e-16]
 [2.08845965e-16 1.67076772e-16 2.08845965e-16]
 [1.96118673e-16 3.13789877e-16 1.96118673e-16]
 [1.96118673e-16 3.13789877e-16 1.96118673e-16]
 [1.96118673e-16 3.13789877e-16 1.96118673e-16]
 [1.96118673e-16 3.13789877e-16 1.96118673e-16]
 [1.96118673e-16 3.13789877e-16 1.96118673e-16]
 [1.96118673e-16 3.13789877e-16 1.96118673e-16]
 [2.08845965e-16 1.67076772e-16 2.08845965e-16]
 [1.96118673e-16 3.13789877e-16 1.96118673e-16]
 [2.08845965e-16 1.67076772e-16 2.08845965e-16]]

Coordinatewise relative error between average of naively computed Hessian and formula-based FIM (should approach zero as num_trials -> infinity):
[[0.00072369 0.00072369 0.00072369]
 [0.00072369 0.00072369 0.00072369]
 [0.00072369 0.00072369 0.00072369]]
```
References
[1]: Guo-Xun Yuan, Chia-Hua Ho and Chih-Jen Lin. An Improved GLMNET for L1-regularized Logistic Regression. Journal of Machine Learning Research, 13, 2012. http://www.jmlr.org/papers/volume13/yuan12a/yuan12a.pdf
[2]: skd. Derivation of Soft Thresholding Operator. 2018. https://math.stackexchange.com/q/511106
[3]: Wikipedia Contributors. Proximal gradient methods for learning. Wikipedia, The Free Encyclopedia, 2018. https://en.wikipedia.org/wiki/Proximal_gradient_methods_for_learning
[4]: Yao-Liang Yu. The Proximity Operator. https://www.cs.cmu.edu/~suvrit/teach/yaoliang_proximity.pdf