Generalized Linear Models


In this notebook we introduce Generalized Linear Models via a worked example. We solve this example in two different ways using two algorithms for efficiently fitting GLMs in TensorFlow Probability: Fisher scoring for dense data, and coordinatewise proximal gradient descent for sparse data. We compare the fitted coefficients to the true coefficients and, in the case of coordinatewise proximal gradient descent, to the output of R's similar glmnet algorithm. Finally, we provide further mathematical details and derivations of several key properties of GLMs.

Background

A generalized linear model (GLM) is a linear model (\(\eta = x^\top \beta\)) wrapped in a transformation (link function) and equipped with a response distribution from an exponential family. The choice of link function and response distribution is very flexible, which lends great expressivity to GLMs. The full details, including a sequential presentation of all the definitions and results building up to GLMs in unambiguous notation, are found in "Derivation of GLM Facts" below. We summarize:

In a GLM, a predictive distribution for the response variable \(Y\) is associated with a vector of observed predictors \(x\). The distribution has the form:

\[ \begin{align*} p(y \, |\, x) &= m(y, \phi) \exp\left(\frac{\theta\, T(y) - A(\theta)}{\phi}\right) \\ \theta &:= h(\eta) \\ \eta &:= x^\top \beta \end{align*} \]

Here \(\beta\) are the parameters ("weights"), \(\phi\) a hyperparameter representing dispersion ("variance"), and \(m\), \(h\), \(T\), \(A\) are characterized by the user-specified model family.
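To make the abstract form concrete, here is a small check in plain Python. It is our own illustration, not TFP code. It shows that the Bernoulli distribution fits this form with \(m(y, \phi) = 1\), \(T(y) = y\), \(A(\theta) = \log(1 + e^\theta)\) and \(\phi = 1\). The helper names `oef_density`, `m_bern`, `T_bern` and `A_bern` are hypothetical.

```python
import math

def sigmoid(t):
    return 1.0 / (1.0 + math.exp(-t))

def oef_density(y, theta, phi, m, T, A):
    """m(y, phi) * exp((theta * T(y) - A(theta)) / phi)."""
    return m(y, phi) * math.exp((theta * T(y) - A(theta)) / phi)

# The Bernoulli family written in this form (phi = 1).
def m_bern(y, phi):
    return 1.0

def T_bern(y):
    return float(y)

def A_bern(theta):
    return math.log1p(math.exp(theta))

theta = 0.3
for y in (0, 1):
    print(y, oef_density(y, theta, 1.0, m_bern, T_bern, A_bern))
print(1.0 - sigmoid(theta), sigmoid(theta))  # the same two probabilities
```

With the canonical logit link, \(h\) is the identity, so \(\theta = \eta\) and \(P(Y = 1) = \text{sigmoid}(\eta)\).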

The mean of \(Y\) depends on \(x\) by composition of linear response \(\eta\) and (inverse) link function, i.e.:

\[ \mu := g^{-1}(\eta) \]

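where \(g\) is the so-called link function. In TFP, the choice of link function and model family is jointly specified by a tfp.glm.ExponentialFamily subclass. Examples include tfp.glm.Bernoulli (Bernoulli response with the canonical logit link) and tfp.glm.BernoulliNormalCDF (Bernoulli response with the probit link), both of which are used below.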

TFP prefers to name model families according to the distribution over Y rather than the link function, since tfp.distributions are already first-class citizens. If the tfp.glm.ExponentialFamily subclass name contains a second word, this indicates a non-canonical link function.

GLMs have several remarkable properties which permit efficient implementation of the maximum likelihood estimator. Chief among these properties are simple formulas for the gradient of the log-likelihood \(\ell\), and for the Fisher information matrix, which is the expected value of the Hessian of the negative log-likelihood under a re-sampling of the response under the same predictors. I.e.:

\[ \begin{align*} \nabla_\beta\, \ell(\beta\, ;\, \mathbf{x}, \mathbf{y}) &= \mathbf{x}^\top \,\text{diag}\left(\frac{ {\textbf{Mean}_T}'(\mathbf{x} \beta) }{ {\textbf{Var}_T}(\mathbf{x} \beta) }\right) \left(\mathbf{T}(\mathbf{y}) - {\textbf{Mean}_T}(\mathbf{x} \beta)\right) \\ \mathbb{E}_{Y_i \sim \text{GLM} | x_i} \left[ \nabla_\beta^2\, \ell(\beta\, ;\, \mathbf{x}, \mathbf{Y}) \right] &= -\mathbf{x}^\top \,\text{diag}\left( \frac{ \phi\, {\textbf{Mean}_T}'(\mathbf{x} \beta)^2 }{ {\textbf{Var}_T}(\mathbf{x} \beta) }\right)\, \mathbf{x} \end{align*} \]

where \(\mathbf{x}\) is the matrix whose \(i\)th row is the predictor vector for the \(i\)th data sample, and \(\mathbf{y}\) is the vector whose \(i\)th coordinate is the observed response for the \(i\)th data sample. Here (loosely speaking), \({\text{Mean}_T}(\eta) := \mathbb{E}[T(Y)\,|\,\eta]\) and \({\text{Var}_T}(\eta) := \text{Var}[T(Y)\,|\,\eta]\), and boldface denotes vectorization of these functions. Full details of the distributions over which these expectations and variances are taken can be found in "Derivation of GLM Facts" below.
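Here is a quick numerical sanity check of these two formulas, written as a NumPy sketch rather than TFP code. We assume the Bernoulli family with the canonical logit link. For that family \(\phi = 1\), \({\text{Mean}_T}(\eta) = \text{sigmoid}(\eta)\) and \({\text{Mean}_T}' = {\text{Var}_T}\). We compare both formulas against central finite differences. The data are random and all helper names are our own.

```python
import numpy as np

rng = np.random.default_rng(0)
n, d = 50, 3
x = rng.normal(size=(n, d))          # row i is the predictor vector x_i
beta = rng.normal(size=d)
y = (rng.uniform(size=n) < 1.0 / (1.0 + np.exp(-x @ beta))).astype(float)
phi = 1.0                            # the Bernoulli family has unit dispersion

def log_likelihood(b):
    """Bernoulli log-likelihood with logit link; T(y) = y."""
    eta = x @ b
    return np.sum(y * eta - np.logaddexp(0.0, eta))

def mean_var_grad_mean(eta):
    """(Mean_T, Var_T, Mean_T') for the Bernoulli family with logit link."""
    mean = 1.0 / (1.0 + np.exp(-eta))
    var = mean * (1.0 - mean)
    return mean, var, var            # for the logit link, Mean_T' equals Var_T

def grad_formula(b):
    mean, var, grad_mean = mean_var_grad_mean(x @ b)
    return x.T @ ((grad_mean / var) * (y - mean))

def fisher_formula(b):
    """Expected Hessian of the log-likelihood, per the formula above."""
    mean, var, grad_mean = mean_var_grad_mean(x @ b)
    return -x.T @ ((phi * grad_mean**2 / var)[:, None] * x)

def central_difference(f, b, eps=1e-6):
    """Differentiate f numerically at b, one coordinate at a time."""
    cols = [(f(b + eps * e) - f(b - eps * e)) / (2 * eps) for e in np.eye(len(b))]
    return np.stack(cols, axis=-1)

# With the canonical (logit) link the Hessian does not depend on y, so the
# expected Hessian equals the Hessian and can be checked by finite differences.
print(np.max(np.abs(grad_formula(beta) - central_difference(log_likelihood, beta))))
print(np.max(np.abs(fisher_formula(beta) - central_difference(grad_formula, beta))))
```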

An Example

In this section we briefly describe and showcase two built-in GLM fitting algorithms in TensorFlow Probability: Fisher scoring (tfp.glm.fit) and coordinatewise proximal gradient descent (tfp.glm.fit_sparse).

Synthetic Data Set

Let's pretend to load a training data set (in reality, we generate synthetic data).

import numpy as np
import pandas as pd
import scipy
import tensorflow.compat.v2 as tf
tf.enable_v2_behavior()
import tensorflow_probability as tfp
tfd = tfp.distributions

def make_dataset(n, d, link, scale=1., dtype=np.float32):
  # Draw d coefficients uniformly from [-1, 1) and rescale them to norm sqrt(2).
  model_coefficients = tfd.Uniform(
      low=-1., high=np.array(1, dtype)).sample(d, seed=42)
  radius = np.sqrt(2.)
  model_coefficients *= radius / tf.linalg.norm(model_coefficients)
  # Keep int(0.5 * d) randomly chosen coefficients and zero out the rest,
  # so that the true model is sparse.
  mask = tf.random.shuffle(tf.range(d)) < int(0.5 * d)
  model_coefficients = tf.where(
      mask, model_coefficients, np.array(0., dtype))
  # Predictors are i.i.d. standard normal.
  model_matrix = tfd.Normal(
      loc=0., scale=np.array(1, dtype)).sample([n, d], seed=43)
  scale = tf.convert_to_tensor(scale, dtype)
  linear_response = tf.linalg.matvec(model_matrix, model_coefficients)

  # Sample the responses according to the requested true link function.
  if link == 'linear':
    response = tfd.Normal(loc=linear_response, scale=scale).sample(seed=44)
  elif link == 'probit':
    response = tf.cast(
        tfd.Normal(loc=linear_response, scale=scale).sample(seed=44) > 0,
        dtype)
  elif link == 'logit':
    response = tfd.Bernoulli(logits=linear_response).sample(seed=44)
  else:
    raise ValueError('unrecognized true link: {}'.format(link))
  return model_matrix, response, model_coefficients, mask

Note: Connect to a local runtime.

In this notebook, we share data between Python and R kernels using local files. To enable this sharing, please use runtimes on the same machine where you have permission to read and write local files.

# Generate 100,000 samples with 100 predictors from a probit model.
x, y, model_coefficients_true, _ = [t.numpy() for t in make_dataset(
    n=int(1e5), d=100, link='probit')]

# Write the data to local CSV files so that the R kernel can read them.
DATA_DIR = '/tmp/glm_example'
tf.io.gfile.makedirs(DATA_DIR)
with tf.io.gfile.GFile('{}/x.csv'.format(DATA_DIR), 'w') as f:
  np.savetxt(f, x, delimiter=',')
with tf.io.gfile.GFile('{}/y.csv'.format(DATA_DIR), 'w') as f:
  # Shift the 0/1 labels to 1/2 before writing them out.
  np.savetxt(f, y.astype(np.int32) + 1, delimiter=',', fmt='%d')
with tf.io.gfile.GFile(
    '{}/model_coefficients_true.csv'.format(DATA_DIR), 'w') as f:
  np.savetxt(f, model_coefficients_true, delimiter=',')

Without L1 Regularization

The function tfp.glm.fit implements Fisher scoring. Among its arguments are:

  • model_matrix = \(\mathbf{x}\)
  • response = \(\mathbf{y}\)
  • model = callable which, given argument \(\boldsymbol{\eta}\), returns the triple \(\left( {\textbf{Mean}_T}(\boldsymbol{\eta}), {\textbf{Var}_T}(\boldsymbol{\eta}), {\textbf{Mean}_T}'(\boldsymbol{\eta}) \right)\).

We recommend that model be an instance of the tfp.glm.ExponentialFamily class. There are several pre-made implementations available, so for most common GLMs no custom code is necessary.

@tf.function(autograph=False)
def fit_model():
  model_coefficients, linear_response, is_converged, num_iter = tfp.glm.fit(
      model_matrix=x, response=y, model=tfp.glm.BernoulliNormalCDF())
  log_likelihood = tfp.glm.BernoulliNormalCDF().log_prob(y, linear_response)
  return (model_coefficients, linear_response, is_converged, num_iter,
          log_likelihood)

[model_coefficients, linear_response, is_converged, num_iter,
 log_likelihood] = [t.numpy() for t in fit_model()]

print(('is_converged: {}\n'
       '    num_iter: {}\n'
       '    accuracy: {}\n'
       '    deviance: {}\n'
       '||w0-w1||_2 / (1+||w0||_2): {}'
      ).format(
    is_converged,
    num_iter,
    np.mean((linear_response > 0.) == y),
    2. * np.mean(log_likelihood),
    np.linalg.norm(model_coefficients_true - model_coefficients, ord=2) /
        (1. + np.linalg.norm(model_coefficients_true, ord=2))
    ))
is_converged: True
    num_iter: 6
    accuracy: 0.75241
    deviance: -0.992436110973
||w0-w1||_2 / (1+||w0||_2): 0.0231555201462

Mathematical Details

Fisher scoring is a modification of Newton's method to find the maximum-likelihood estimate

\[ \hat\beta := \underset{\beta}{\text{arg max} }\ \ \ell(\beta\ ;\ \mathbf{x}, \mathbf{y}). \]

Vanilla Newton's method, searching for zeros of the gradient of the log-likelihood, would follow the update rule

\[ \beta^{(t+1)}_{\text{Newton} } := \beta^{(t)} - \alpha \left( \nabla^2_\beta\, \ell(\beta\ ;\ \mathbf{x}, \mathbf{y}) \right)_{\beta = \beta^{(t)} }^{-1} \left( \nabla_\beta\, \ell(\beta\ ;\ \mathbf{x}, \mathbf{y}) \right)_{\beta = \beta^{(t)} } \]

where \(\alpha \in (0, 1]\) is a learning rate used to control the step size.

In Fisher scoring, we replace the Hessian with the negative Fisher information matrix:

\[ \begin{align*} \beta^{(t+1)} &:= \beta^{(t)} - \alpha\, \mathbb{E}_{ Y_i \sim p_{\text{OEF}(m, T)}(\cdot | \theta = h(x_i^\top \beta^{(t)}), \phi) } \left[ \left( \nabla^2_\beta\, \ell(\beta\ ;\ \mathbf{x}, \mathbf{Y}) \right)_{\beta = \beta^{(t)} } \right]^{-1} \left( \nabla_\beta\, \ell(\beta\ ;\ \mathbf{x}, \mathbf{y}) \right)_{\beta = \beta^{(t)} } \end{align*} \]

[Note that here \(\mathbf{Y} = (Y_i)_{i=1}^{n}\) is random, whereas \(\mathbf{y}\) is still the vector of observed responses.]

By the formulas in "Fitting GLM Parameters to Data" below, this simplifies to

\[ \begin{align*} \beta^{(t+1)} &= \beta^{(t)} + \alpha \left( \mathbf{x}^\top \text{diag}\left( \frac{ \phi\, {\textbf{Mean}_T}'(\mathbf{x} \beta^{(t)})^2 }{ {\textbf{Var}_T}(\mathbf{x} \beta^{(t)}) }\right)\, \mathbf{x} \right)^{-1} \left( \mathbf{x}^\top \text{diag}\left(\frac{ {\textbf{Mean}_T}'(\mathbf{x} \beta^{(t)}) }{ {\textbf{Var}_T}(\mathbf{x} \beta^{(t)}) }\right) \left(\mathbf{T}(\mathbf{y}) - {\textbf{Mean}_T}(\mathbf{x} \beta^{(t)})\right) \right). \end{align*} \]
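The update above can be sketched in a few lines of NumPy. This is a minimal illustration under our own assumptions, not the implementation of tfp.glm.fit:

- Bernoulli family with logit link, so \(\phi = 1\) and \(T(y) = y\).
- A fixed learning rate \(\alpha = 1\).
- Convergence is declared once the step becomes tiny.

The `model` callable mimics the triple \(({\textbf{Mean}_T}, {\textbf{Var}_T}, {\textbf{Mean}_T}')\) described earlier. Its name and signature here are hypothetical.

```python
import numpy as np

def fisher_scoring(x, y, model, alpha=1.0, max_iter=50, tol=1e-10):
    """Fisher scoring for a GLM with T(y) = y and unit dispersion (phi = 1).

    `model(eta)` returns (Mean_T(eta), Var_T(eta), Mean_T'(eta)) elementwise.
    Returns (beta, is_converged, num_iter).
    """
    beta = np.zeros(x.shape[1])
    for num_iter in range(1, max_iter + 1):
        mean, var, grad_mean = model(x @ beta)
        # Gradient of the log-likelihood.
        score = x.T @ ((grad_mean / var) * (y - mean))
        # Fisher information (the negative expected Hessian), with phi = 1.
        fisher = x.T @ ((grad_mean**2 / var)[:, None] * x)
        step = np.linalg.solve(fisher, score)
        beta = beta + alpha * step
        if np.max(np.abs(step)) < tol:
            return beta, True, num_iter
    return beta, False, max_iter

def bernoulli_logit(eta):
    mean = 1.0 / (1.0 + np.exp(-eta))
    var = mean * (1.0 - mean)
    return mean, var, var  # Mean_T' = Var_T for the canonical logit link

rng = np.random.default_rng(1)
x = rng.normal(size=(2000, 4))
beta_true = np.array([1.0, -0.5, 0.0, 0.25])
y = (rng.uniform(size=2000) < 1.0 / (1.0 + np.exp(-x @ beta_true))).astype(float)

beta_hat, is_converged, num_iter = fisher_scoring(x, y, bernoulli_logit)
print(is_converged, num_iter, beta_hat)
```

Solving the linear system with `np.linalg.solve` avoids forming the inverse explicitly. A learning rate \(\alpha < 1\) trades speed for robustness when the starting point is far from the optimum.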

With L1 Regularization

tfp.glm.fit_sparse implements a GLM fitter more suited to sparse data sets, based on the algorithm in Yuan, Ho and Lin 2012. Its features include:

  • L1 regularization
  • No matrix inversions
  • Few evaluations of the gradient and Hessian.

We first present an example usage of the code. Details of the algorithm are further elaborated in "Algorithm Details for tfp.glm.fit_sparse" below.

model = tfp.glm.Bernoulli()
model_coefficients_start = tf.zeros(x.shape[-1], np.float32)
@tf.function(autograph=False)
def fit_model():
  return tfp.glm.fit_sparse(
    model_matrix=tf.convert_to_tensor(x),
    response=tf.convert_to_tensor(y),
    model=model,
    model_coefficients_start=model_coefficients_start,
    l1_regularizer=800.,
    l2_regularizer=None,
    maximum_iterations=10,
    maximum_full_sweeps_per_iteration=10,
    tolerance=1e-6,
    learning_rate=None)

model_coefficients, is_converged, num_iter = [t.numpy() for t in fit_model()]
coefs_comparison = pd.DataFrame({
  'Learned': model_coefficients,
  'True': model_coefficients_true,
})

print(('is_converged: {}\n'
       '    num_iter: {}\n\n'
       'Coefficients:').format(
    is_converged,
    num_iter))
coefs_comparison
is_converged: True
    num_iter: 1

Coefficients:

Note that the learned coefficients have the same sparsity pattern as the true coefficients.

# Save the learned coefficients to a file.
with tf.io.gfile.GFile('{}/model_coefficients_prox.csv'.format(DATA_DIR), 'w') as f:
  np.savetxt(f, model_coefficients, delimiter=',')

Compare to R's glmnet

We compare the output of coordinatewise proximal gradient descent to that of R's glmnet, which uses a similar algorithm.

NOTE: To execute this section, you must switch to an R colab runtime.

suppressMessages({
  library('glmnet')
})
data_dir <- '/tmp/glm_example'
x <- as.matrix(read.csv(paste(data_dir, '/x.csv', sep=''),
                        header=FALSE))
y <- as.matrix(read.csv(paste(data_dir, '/y.csv', sep=''),
                        header=FALSE, colClasses='integer'))
fit <- glmnet(
  x = x,
  y = y,
  family = "binomial",  # Logistic regression
  alpha = 1,  # corresponds to l1_weight = 1, l2_weight = 0
  standardize = FALSE,
  intercept = FALSE,
  thresh = 1e-30,
  type.logistic = "Newton"
)
# glmnet divides the log-likelihood by the number of rows (n = 1e5), so
# lambda = 0.008 corresponds to l1_regularizer = 800 in tfp.glm.fit_sparse.
write.csv(as.matrix(coef(fit, s = 0.008)),
          paste(data_dir, '/model_coefficients_glmnet.csv', sep=''),
          row.names=FALSE)

Compare R, TFP and true coefficients (Note: back to Python kernel)

DATA_DIR = '/tmp/glm_example'
with tf.io.gfile.GFile('{}/model_coefficients_glmnet.csv'.format(DATA_DIR),
                       'r') as f:
  # Skip the column-name row and the intercept row.
  model_coefficients_glmnet = np.loadtxt(f, skiprows=2)

with tf.io.gfile.GFile('{}/model_coefficients_prox.csv'.format(DATA_DIR),
                       'r') as f:
  model_coefficients_prox = np.loadtxt(f)

with tf.io.gfile.GFile(
    '{}/model_coefficients_true.csv'.format(DATA_DIR), 'r') as f:
  model_coefficients_true = np.loadtxt(f)
coefs_comparison = pd.DataFrame({
    'TFP': model_coefficients_prox,
    'R': model_coefficients_glmnet,
    'True': model_coefficients_true,
})
coefs_comparison

Algorithm Details for tfp.glm.fit_sparse

We present the algorithm as a sequence of three modifications to Newton's method. In each one, the update rule for \(\beta\) is based on a vector \(s\) and a matrix \(H\) which approximate the gradient and Hessian of the log-likelihood. In step \(t\), we choose a coordinate \(j^{(t)}\) to change, and we update \(\beta\) according to the update rule:

\[ \begin{align*} u^{(t)} &:= \frac{ \left( s^{(t)} \right)_{j^{(t)} } }{ \left( H^{(t)} \right)_{j^{(t)},\, j^{(t)} } } \\[3mm] \beta^{(t+1)} &:= \beta^{(t)} - \alpha\, u^{(t)} \,\text{onehot}(j^{(t)}) \end{align*} \]

This update is a Newton-like step with learning rate \(\alpha\). Except for the final piece (L1 regularization), the modifications below differ only in how they update \(s\) and \(H\).

Starting point: Coordinatewise Newton's method

In coordinatewise Newton's method, we set \(s\) and \(H\) to the true gradient and Hessian of the log-likelihood:

\[ \begin{align*} s^{(t)}_{\text{vanilla} } &:= \left( \nabla_\beta\, \ell(\beta \,;\, \mathbf{x}, \mathbf{y}) \right)_{\beta = \beta^{(t)} } \\ H^{(t)}_{\text{vanilla} } &:= \left( \nabla^2_\beta\, \ell(\beta \,;\, \mathbf{x}, \mathbf{y}) \right)_{\beta = \beta^{(t)} } \end{align*} \]

Fewer evaluations of the gradient and Hessian

The gradient and Hessian of the log-likelihood are often expensive to compute, so it is often worthwhile to approximate them. We can do so as follows:

  • Usually, approximate the Hessian as locally constant and approximate the gradient to first order using the (approximate) Hessian:

\[ \begin{align*} H_{\text{approx} }^{(t+1)} &:= H^{(t)} \\ s_{\text{approx} }^{(t+1)} &:= s^{(t)} + H^{(t)} \left( \beta^{(t+1)} - \beta^{(t)} \right) \end{align*} \]

  • Occasionally, perform a "vanilla" update step as above, setting \(s^{(t+1)}\) to the exact gradient and \(H^{(t+1)}\) to the exact Hessian of the log-likelihood, evaluated at \(\beta^{(t+1)}\).
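One way to see why the cheap update is reasonable: if the log-likelihood were exactly quadratic, \(s_{\text{approx}}\) would equal the true gradient. The NumPy sketch below uses a made-up concave quadratic, not a GLM likelihood, to check this. It also checks that after a full coordinate step with \(\alpha = 1\), the \(j^{(t)}\)th coordinate of the approximate gradient vanishes. For a GLM the approximation is only first-order accurate, which is why the exact "vanilla" refresh is still needed occasionally.

```python
import numpy as np

rng = np.random.default_rng(2)
d = 4
a = rng.normal(size=(d, d))
q = a @ a.T + d * np.eye(d)  # positive definite, so ell below is concave
b = rng.normal(size=d)

def grad(beta):
    """Gradient of the quadratic ell(beta) = -0.5 beta^T q beta + b^T beta."""
    return b - q @ beta

def hess(beta):
    """Hessian of the same quadratic (constant in beta)."""
    return -q

beta_t = rng.normal(size=d)
s_t, h_t = grad(beta_t), hess(beta_t)

# One coordinatewise Newton step on coordinate j with learning rate alpha.
j, alpha = 2, 1.0
u = s_t[j] / h_t[j, j]
beta_next = beta_t - alpha * u * np.eye(d)[j]

# Cheap update: keep H, move s to first order.
h_approx = h_t
s_approx = s_t + h_t @ (beta_next - beta_t)
print(np.max(np.abs(s_approx - grad(beta_next))))  # rounding error only
print(s_approx[j])                                 # ~0: coordinate j is now optimal
```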

Substitute negative Fisher information for Hessian

To further reduce the cost of the vanilla update steps, we can set \(H\) to the negative Fisher information matrix (efficiently computable using the formulas in "Fitting GLM Parameters to Data" below) rather than the exact Hessian:

\[ \begin{align*} H_{\text{Fisher} }^{(t+1)} &:= \mathbb{E}_{Y_i \sim p_{\text{OEF}(m, T)}(\cdot | \theta = h(x_i^\top \beta^{(t+1)}), \phi)} \left[ \left( \nabla_\beta^2\, \ell(\beta\, ;\, \mathbf{x}, \mathbf{Y}) \right)_{\beta = \beta^{(t+1)} } \right] \\ &= -\mathbf{x}^\top \,\text{diag}\left( \frac{ \phi\, {\textbf{Mean}_T}'(\mathbf{x} \beta^{(t+1)})^2 }{ {\textbf{Var}_T}(\mathbf{x} \beta^{(t+1)}) }\right)\, \mathbf{x} \\ s_{\text{Fisher} }^{(t+1)} &:= s_{\text{vanilla} }^{(t+1)} \\ &= \left( \mathbf{x}^\top \,\text{diag}\left(\frac{ {\textbf{Mean}_T}'(\mathbf{x} \beta^{(t+1)}) }{ {\textbf{Var}_T}(\mathbf{x} \beta^{(t+1)}) }\right) \left(\mathbf{T}(\mathbf{y}) - {\textbf{Mean}_T}(\mathbf{x} \beta^{(t+1)})\right) \right) \end{align*} \]

L1 Regularization via Proximal Gradient Descent

To incorporate L1 regularization, we replace the update rule

\[ \beta^{(t+1)} := \beta^{(t)} - \alpha\, u^{(t)} \,\text{onehot}(j^{(t)}) \]

with the more general update rule

\[ \begin{align*} \gamma^{(t)} &:= -\frac{\alpha\, r_{\text{L1} } }{\left(H^{(t)}\right)_{j^{(t)},\, j^{(t)} } } \\[2mm] \left(\beta_{\text{reg} }^{(t+1)}\right)_j &:= \begin{cases} \beta^{(t+1)}_j &\text{if } j \neq j^{(t)} \\ \text{SoftThreshold} \left( \beta^{(t)}_j - \alpha\, u^{(t)} ,\ \gamma^{(t)} \right) &\text{if } j = j^{(t)} \end{cases} \end{align*} \]

where \(r_{\text{L1} } > 0\) is a supplied constant (the L1 regularization coefficient) and \(\text{SoftThreshold}\) is the soft thresholding operator, defined by

\[ \text{SoftThreshold}(\beta, \gamma) := \begin{cases} \beta + \gamma &\text{if } \beta < -\gamma \\ 0 &\text{if } -\gamma \leq \beta \leq \gamma \\ \beta - \gamma &\text{if } \beta > \gamma. \end{cases} \]
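Here is a direct transcription of SoftThreshold for scalars, in plain Python. The function name is ours. Setting \(\gamma = 0\) returns the input unchanged, which is the degenerate case discussed below.

```python
def soft_threshold(beta, gamma):
    """SoftThreshold(beta, gamma) for a scalar beta and gamma >= 0."""
    if beta < -gamma:
        return beta + gamma
    if beta > gamma:
        return beta - gamma
    return 0.0

print([soft_threshold(b, 1.0) for b in (-3.0, -0.5, 0.0, 0.5, 3.0)])
# prints [-2.0, 0.0, 0.0, 0.0, 2.0]
print([soft_threshold(b, 0.0) for b in (-3.0, 0.5, 3.0)])
# prints [-3.0, 0.5, 3.0]
```

For arrays, the same operator can be written as `np.sign(beta) * np.maximum(np.abs(beta) - gamma, 0.0)`.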

This update rule has the following two inspirational properties, which we explain below:

  1. In the limiting case \(r_{\text{L1} } \to 0\) (i.e., no L1 regularization), this update rule is identical to the original update rule.

  2. This update rule can be interpreted as applying a proximity operator whose fixed point is the solution to the L1-regularized minimization problem

\[ \underset{\beta - \beta^{(t)} \in \text{span}\{ \text{onehot}(j^{(t)}) \} }{\text{arg min} } \left( -\ell(\beta \,;\, \mathbf{x}, \mathbf{y}) + r_{\text{L1} } \left\lVert \beta \right\rVert_1 \right). \]

Degenerate case \(r_{\text{L1} } = 0\) recovers the original update rule

To see (1), note that if \(r_{\text{L1} } = 0\) then \(\gamma^{(t)} = 0\), hence

\[ \begin{align*} \left(\beta_{\text{reg} }^{(t+1)}\right)_{j^{(t)} } &= \text{SoftThreshold} \left( \beta^{(t)}_{j^{(t)} } - \alpha\, u^{(t)} ,\ 0 \right) \\ &= \beta^{(t)}_{j^{(t)} } - \alpha\, u^{(t)}. \end{align*} \]

Hence

\[ \begin{align*} \beta_{\text{reg} }^{(t+1)} &= \beta^{(t)} - \alpha\, u^{(t)} \,\text{onehot}(j^{(t)}) \\ &= \beta^{(t+1)}. \end{align*} \]

Proximity operator whose fixed point is the regularized MLE

To see (2), first note (see Wikipedia) that for any \(\gamma > 0\), the update rule

\[ \left(\beta_{\text{exact-prox}, \gamma}^{(t+1)}\right)_{j^{(t)} } := \text{prox}_{\gamma \lVert \cdot \rVert_1} \left( \beta^{(t)}_{j^{(t)} } + \frac{\gamma}{r_{\text{L1} } } \left( \left( \nabla_\beta\, \ell(\beta \,;\, \mathbf{x}, \mathbf{y}) \right)_{\beta = \beta^{(t)} } \right)_{j^{(t)} } \right) \]

satisfies (2), where \(\text{prox}\) is the proximity operator (see Yu, where this operator is denoted \(\mathsf{P}\)). The right-hand side of the above equation is computed here:

\[ \left(\beta_{\text{exact-prox}, \gamma}^{(t+1)}\right)_{j^{(t)} } = \text{SoftThreshold} \left( \beta^{(t)}_{j^{(t)} } + \frac{\gamma}{r_{\text{L1} } } \left( \left( \nabla_\beta\, \ell(\beta \,;\, \mathbf{x}, \mathbf{y}) \right)_{\beta = \beta^{(t)} } \right)_{j^{(t)} } ,\ \gamma \right). \]

In particular, setting \(\gamma = \gamma^{(t)} = -\frac{\alpha\, r_{\text{L1} } }{\left(H^{(t)}\right)_{j^{(t)}, j^{(t)} } }\) (note that \(\gamma^{(t)} > 0\) as long as the negative log-likelihood is convex), we obtain the update rule

\[ \left(\beta_{\text{exact-prox}, \gamma^{(t)} }^{(t+1)}\right)_{j^{(t)} } = \text{SoftThreshold} \left( \beta^{(t)}_{j^{(t)} } - \alpha \frac{ \left( \left( \nabla_\beta\, \ell(\beta \,;\, \mathbf{x}, \mathbf{y}) \right)_{\beta = \beta^{(t)} } \right)_{j^{(t)} } }{ \left(H^{(t)}\right)_{j^{(t)}, j^{(t)} } } ,\ \gamma^{(t)} \right). \]

We then replace the exact gradient \(\left( \nabla_\beta\, \ell(\beta \,;\, \mathbf{x}, \mathbf{y}) \right)_{\beta = \beta^{(t)} }\) with its approximation \(s^{(t)}\), obtaining

\[ \begin{align*} \left(\beta_{\text{exact-prox}, \gamma^{(t)} }^{(t+1)}\right)_{j^{(t)} } &\approx \text{SoftThreshold} \left( \beta^{(t)}_{j^{(t)} } - \alpha \frac{ \left(s^{(t)}\right)_{j^{(t)} } }{ \left(H^{(t)}\right)_{j^{(t)}, j^{(t)} } } ,\ \gamma^{(t)} \right) \\ &= \text{SoftThreshold} \left( \beta^{(t)}_{j^{(t)} } - \alpha\, u^{(t)} ,\ \gamma^{(t)} \right). \end{align*} \]

Hence

\[ \beta_{\text{exact-prox}, \gamma^{(t)} }^{(t+1)} \approx \beta_{\text{reg} }^{(t+1)}. \]
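Putting the pieces together, the following NumPy sketch runs coordinatewise proximal Newton steps for L1-regularized logistic regression. It is a simplified illustration under our own assumptions, not the implementation of tfp.glm.fit_sparse. Each outer iteration does a "vanilla" refresh of \(s\) (the exact gradient) and \(H\) (the negative Fisher information). It then makes several sweeps over the coordinates, applying two updates per coordinate:

- The first-order gradient update \(s \leftarrow s + H\,\Delta\beta\).
- The soft-thresholded step with \(\gamma^{(t)} = -\alpha\, r_{\text{L1}} / H_{jj}\).

There is no convergence test or line search, and the names and iteration counts are illustrative.

```python
import numpy as np

def soft_threshold(z, gamma):
    return np.sign(z) * np.maximum(np.abs(z) - gamma, 0.0)

def fit_l1_logistic(x, y, r_l1, alpha=1.0, num_outer=50, num_sweeps=5):
    """Coordinatewise proximal Newton sketch for L1-regularized logistic regression."""
    d = x.shape[1]
    beta = np.zeros(d)
    for _ in range(num_outer):
        # "Vanilla" refresh: exact gradient s and negative Fisher information H.
        mean = 1.0 / (1.0 + np.exp(-(x @ beta)))
        s = x.T @ (y - mean)
        h = -x.T @ ((mean * (1.0 - mean))[:, None] * x)
        for _ in range(num_sweeps):
            for j in range(d):
                u = s[j] / h[j, j]
                gamma = -alpha * r_l1 / h[j, j]  # positive because h[j, j] < 0
                new_beta_j = soft_threshold(beta[j] - alpha * u, gamma)
                # First-order update of the gradient; H is held constant.
                s = s + h[:, j] * (new_beta_j - beta[j])
                beta[j] = new_beta_j
    return beta

rng = np.random.default_rng(3)
x = rng.normal(size=(500, 5))
beta_true = np.array([2.0, 0.0, 0.0, -1.5, 0.0])
y = (rng.uniform(size=500) < 1.0 / (1.0 + np.exp(-x @ beta_true))).astype(float)

beta_hat = fit_l1_logistic(x, y, r_l1=20.0)
# Exact gradient of the log-likelihood at the solution, for checking optimality.
grad_at_solution = x.T @ (y - 1.0 / (1.0 + np.exp(-(x @ beta_hat))))
print(beta_hat)
print(grad_at_solution)
```

At a fixed point, the optimality conditions of the L1 problem hold:

- \(\partial \ell / \partial \beta_j = r_{\text{L1}} \,\text{sign}(\beta_j)\) for nonzero coefficients.
- \(|\partial \ell / \partial \beta_j| \le r_{\text{L1}}\) for zero coefficients.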

Derivation of GLM Facts

In this section we state in full detail and derive the results about GLMs that are used in the preceding sections. Then, we use TensorFlow's gradients to numerically verify the derived formulas for gradient of the log-likelihood and Fisher information.

Score and Fisher information

Consider a family of probability distributions parameterized by parameter vector \(\theta\), having probability densities \(\left\{p(\cdot | \theta)\right\}_{\theta \in \mathcal{T} }\). The score of an outcome \(y\) at parameter vector \(\theta_0\) is defined to be the gradient of the log likelihood of \(y\) (evaluated at \(\theta_0\)), that is,

\[ \text{score}(y, \theta_0) := \left[\nabla_\theta\, \log p(y | \theta)\right]_{\theta=\theta_0}. \]

Claim: Expectation of the score is zero

Under mild regularity conditions (permitting us to pass differentiation under the integral),

\[ \mathbb{E}_{Y \sim p(\cdot | \theta=\theta_0)}\left[\text{score}(Y, \theta_0)\right] = 0. \]

Proof

We have

\[ \begin{align*} \mathbb{E}_{Y \sim p(\cdot | \theta=\theta_0)}\left[\text{score}(Y, \theta_0)\right] &:=\mathbb{E}_{Y \sim p(\cdot | \theta=\theta_0)}\left[\left(\nabla_\theta \log p(Y|\theta)\right)_{\theta=\theta_0}\right] \\ &\stackrel{\text{(1)} }{=} \mathbb{E}_{Y \sim p(\cdot | \theta=\theta_0)}\left[\frac{\left(\nabla_\theta p(Y|\theta)\right)_{\theta=\theta_0} }{p(Y|\theta=\theta_0)}\right] \\ &\stackrel{\text{(2)} }{=} \int_{\mathcal{Y} } \left[\frac{\left(\nabla_\theta p(y|\theta)\right)_{\theta=\theta_0} }{p(y|\theta=\theta_0)}\right] p(y | \theta=\theta_0)\, dy \\ &= \int_{\mathcal{Y} } \left(\nabla_\theta p(y|\theta)\right)_{\theta=\theta_0}\, dy \\ &\stackrel{\text{(3)} }{=} \left[\nabla_\theta \left(\int_{\mathcal{Y} } p(y|\theta)\, dy\right) \right]_{\theta=\theta_0} \\ &\stackrel{\text{(4)} }{=} \left[\nabla_\theta\, 1 \right]_{\theta=\theta_0} \\ &= 0, \end{align*} \]

where we have used: (1) chain rule for differentiation, (2) definition of expectation, (3) passing differentiation under the integral sign (using the regularity conditions), (4) the integral of a probability density is 1.
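Here is a numerical illustration in plain Python, using our own example. For the Poisson family with natural parameter \(\theta = \log(\text{rate})\), we approximate the score by a central difference and average it over the support. The support is truncated at \(y < 100\), where the remaining mass is negligible for rate \(e\). The result is zero up to finite-difference and rounding error.

```python
import math

def log_p(y, theta):
    """Poisson log-pmf with natural parameter theta = log(rate)."""
    return y * theta - math.exp(theta) - math.lgamma(y + 1)

def score(y, theta, eps=1e-5):
    """Central-difference approximation of d/dtheta log p(y | theta)."""
    return (log_p(y, theta + eps) - log_p(y, theta - eps)) / (2 * eps)

theta0 = 1.0
support = range(100)  # truncated support; the remaining mass is negligible here
expected_score = sum(math.exp(log_p(y, theta0)) * score(y, theta0) for y in support)
print(expected_score)  # approximately 0
```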

Claim (Fisher information): Variance of the score equals negative expected Hessian of the log likelihood

Under mild regularity conditions (permitting us to pass differentiation under the integral),

\[ \mathbb{E}_{Y \sim p(\cdot | \theta=\theta_0)}\left[ \text{score}(Y, \theta_0) \text{score}(Y, \theta_0)^\top \right] = -\mathbb{E}_{Y \sim p(\cdot | \theta=\theta_0)}\left[ \left(\nabla_\theta^2 \log p(Y | \theta)\right)_{\theta=\theta_0} \right] \]

where \(\nabla_\theta^2 F\) denotes the Hessian matrix, whose \((i, j)\) entry is \(\frac{\partial^2 F}{\partial \theta_i \partial \theta_j}\).

The left-hand side of this equation is called the Fisher information of the family \(\left\{p(\cdot | \theta)\right\}_{\theta \in \mathcal{T} }\) at parameter vector \(\theta_0\).

Proof of claim

We have

\[ \begin{align*} \mathbb{E}_{Y \sim p(\cdot | \theta=\theta_0)}\left[ \left(\nabla_\theta^2 \log p(Y | \theta)\right)_{\theta=\theta_0} \right] &\stackrel{\text{(1)} }{=} \mathbb{E}_{Y \sim p(\cdot | \theta=\theta_0)}\left[ \left(\nabla_\theta^\top \frac{ \nabla_\theta p(Y | \theta) }{ p(Y|\theta) }\right)_{\theta=\theta_0} \right] \\ &\stackrel{\text{(2)} }{=} \mathbb{E}_{Y \sim p(\cdot | \theta=\theta_0)}\left[ \frac{ \left(\nabla^2_\theta p(Y | \theta)\right)_{\theta=\theta_0} }{ p(Y|\theta=\theta_0) } - \left(\frac{ \left(\nabla_\theta\, p(Y|\theta)\right)_{\theta=\theta_0} }{ p(Y|\theta=\theta_0) }\right) \left(\frac{ \left(\nabla_\theta\, p(Y|\theta)\right)_{\theta=\theta_0} }{ p(Y|\theta=\theta_0) }\right)^\top \right] \\ &\stackrel{\text{(3)} }{=} \mathbb{E}_{Y \sim p(\cdot | \theta=\theta_0)}\left[ \frac{ \left(\nabla^2_\theta p(Y | \theta)\right)_{\theta=\theta_0} }{ p(Y|\theta=\theta_0) } - \text{score}(Y, \theta_0) \,\text{score}(Y, \theta_0)^\top \right], \end{align*} \]

where we have used (1) chain rule for differentiation, (2) quotient rule for differentiation, (3) chain rule again, in reverse.

To complete the proof, it suffices to show that

\[ \mathbb{E}_{Y \sim p(\cdot | \theta=\theta_0)}\left[ \frac{ \left(\nabla^2_\theta p(Y | \theta)\right)_{\theta=\theta_0} }{ p(Y|\theta=\theta_0) } \right] \stackrel{\text{?} }{=} 0. \]

To do that, we pass differentiation under the integral sign twice:

\[ \begin{align*} \mathbb{E}_{Y \sim p(\cdot | \theta=\theta_0)}\left[ \frac{ \left(\nabla^2_\theta p(Y | \theta)\right)_{\theta=\theta_0} }{ p(Y|\theta=\theta_0) } \right] &= \int_{\mathcal{Y} } \left[ \frac{ \left(\nabla^2_\theta p(y | \theta)\right)_{\theta=\theta_0} }{ p(y|\theta=\theta_0) } \right] \, p(y | \theta=\theta_0)\, dy \\ &= \int_{\mathcal{Y} } \left(\nabla^2_\theta p(y | \theta)\right)_{\theta=\theta_0} \, dy \\ &= \left[ \nabla_\theta^2 \left( \int_{\mathcal{Y} } p(y | \theta) \, dy \right) \right]_{\theta=\theta_0} \\ &= \left[ \nabla_\theta^2 \, 1 \right]_{\theta=\theta_0} \\ &= 0. \end{align*} \]
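Continuing the same Poisson example (our own, in plain Python), we can check the claim numerically. Both sides are computed with finite differences over the truncated support, and both should equal \(e^{\theta_0}\), which is the variance of a Poisson variable with rate \(e^{\theta_0}\).

```python
import math

def log_p(y, theta):
    """Poisson log-pmf with natural parameter theta = log(rate)."""
    return y * theta - math.exp(theta) - math.lgamma(y + 1)

def first_derivative(f, t, eps=1e-5):
    return (f(t + eps) - f(t - eps)) / (2 * eps)

def second_derivative(f, t, eps=1e-4):
    return (f(t + eps) - 2.0 * f(t) + f(t - eps)) / eps**2

theta0 = 1.0
expected_score_sq = 0.0     # E[score^2]
expected_neg_hessian = 0.0  # -E[d^2/dtheta^2 log p]
for y in range(100):        # truncated support
    prob = math.exp(log_p(y, theta0))

    def f(t, y=y):
        return log_p(y, t)

    expected_score_sq += prob * first_derivative(f, theta0) ** 2
    expected_neg_hessian -= prob * second_derivative(f, theta0)
print(expected_score_sq, expected_neg_hessian, math.exp(theta0))  # all approximately e
```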

Lemma about the derivative of the log partition function

If \(a\), \(b\) and \(c\) are scalar-valued functions, \(c\) twice differentiable, such that the family of distributions \(\left\{p(\cdot | \theta)\right\}_{\theta \in \mathcal{T} }\) defined by

\[ p(y|\theta) = a(y) \exp\left(b(y)\, \theta - c(\theta)\right) \]

satisfies the mild regularity conditions that permit passing differentiation with respect to \(\theta\) under an integral with respect to \(y\), then

\[ \mathbb{E}_{Y \sim p(\cdot | \theta=\theta_0)} \left[ b(Y) \right] = c'(\theta_0) \]

and

\[ \text{Var}_{Y \sim p(\cdot | \theta=\theta_0)} \left[ b(Y) \right] = c''(\theta_0). \]

(Here \('\) denotes differentiation, so \(c'\) and \(c''\) are the first and second derivatives of \(c\).)

Proof

For this family of distributions, we have \(\text{score}(y, \theta_0) = b(y) - c'(\theta_0)\). The first equation then follows from the fact that \(\mathbb{E}_{Y \sim p(\cdot | \theta=\theta_0)} \left[ \text{score}(Y, \theta_0) \right] = 0\). Next, we have

\[ \begin{align*} \text{Var}_{Y \sim p(\cdot | \theta=\theta_0)} \left[ b(Y) \right] &= \mathbb{E}_{Y \sim p(\cdot | \theta=\theta_0)} \left[ \left(b(Y) - c'(\theta_0)\right)^2 \right] \\ &= \text{the one entry of } \mathbb{E}_{Y \sim p(\cdot | \theta=\theta_0)} \left[ \text{score}(Y, \theta_0) \text{score}(Y, \theta_0)^\top \right] \\ &= \text{the one entry of } -\mathbb{E}_{Y \sim p(\cdot | \theta=\theta_0)} \left[ \left(\nabla_\theta^2 \log p(Y | \theta)\right)_{\theta=\theta_0} \right] \\ &= -\mathbb{E}_{Y \sim p(\cdot | \theta=\theta_0)} \left[ -c''(\theta_0) \right] \\ &= c''(\theta_0). \end{align*} \]
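Here is a quick check of the lemma in plain Python, using the binomial family with \(n\) trials as our own example. Its ingredients are:

- \(a(y) = \binom{n}{y}\)
- \(b(y) = y\)
- \(c(\theta) = n \log(1 + e^\theta)\)

so that \(c'(\theta) = n\,\text{sigmoid}(\theta)\) and \(c''(\theta) = n\,\text{sigmoid}(\theta)(1 - \text{sigmoid}(\theta))\). Exact sums over the finite support match these derivatives. The code uses `math.comb`, which requires Python 3.8 or later.

```python
import math

n_trials, theta0 = 10, 0.4

def c(theta):
    """Log partition function of the binomial family in its natural parameter."""
    return n_trials * math.log1p(math.exp(theta))

def density(y, theta):
    # a(y) = C(n, y) and b(y) = y
    return math.comb(n_trials, y) * math.exp(y * theta - c(theta))

support = range(n_trials + 1)
mean_b = sum(y * density(y, theta0) for y in support)
var_b = sum((y - mean_b) ** 2 * density(y, theta0) for y in support)

sig = 1.0 / (1.0 + math.exp(-theta0))
c_prime = n_trials * sig                       # c'(theta0)
c_double_prime = n_trials * sig * (1.0 - sig)  # c''(theta0)
print(mean_b, c_prime)
print(var_b, c_double_prime)
```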

Overdispersed Exponential Family

A (scalar) overdispersed exponential family is a family of distributions whose densities take the form

\[ p_{\text{OEF}(m, T)}(y\, |\, \theta, \phi) = m(y, \phi) \exp\left(\frac{\theta\, T(y) - A(\theta)}{\phi}\right), \]

where \(m\) and \(T\) are known scalar-valued functions, and \(\theta\) and \(\phi\) are scalar parameters.

[Note that \(A\) is overdetermined: for any \(\phi_0\), the function \(A\) is completely determined by the constraint that \(\int p_{\text{OEF}(m, T)}(y\ |\ \theta, \phi=\phi_0)\, dy = 1\) for all \(\theta\). The \(A\)'s produced by different values of \(\phi_0\) must all be the same, which places a constraint on the functions \(m\) and \(T\).]

Mean and variance of the sufficient statistic

Under the same conditions as "Lemma about the derivative of the log partition function," we have

\[ \mathbb{E}_{Y \sim p_{\text{OEF}(m, T)}(\cdot | \theta, \phi)} \left[ T(Y) \right] = A'(\theta) \]

and

\[ \text{Var}_{Y \sim p_{\text{OEF}(m, T)}(\cdot | \theta, \phi)} \left[ T(Y) \right] = \phi A''(\theta). \]

Proof

By "Lemma about the derivative of the log partition function," applied with \(a(y) = m(y, \phi)\), \(b(y) = T(y)/\phi\) and \(c(\theta) = A(\theta)/\phi\), we have

\[ \mathbb{E}_{Y \sim p_{\text{OEF}(m, T)}(\cdot | \theta, \phi)} \left[ \frac{T(Y)}{\phi} \right] = \frac{A'(\theta)}{\phi} \]

and

\[ \text{Var}_{Y \sim p_{\text{OEF}(m, T)}(\cdot | \theta, \phi)} \left[ \frac{T(Y)}{\phi} \right] = \frac{A''(\theta)}{\phi}. \]

The result then follows from the fact that expectation is linear (\(\mathbb{E}[aX] = a\mathbb{E}[X]\)) and variance is degree-2 homogeneous (\(\text{Var}[aX] = a^2 \,\text{Var}[X]\)).
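To see the role of \(\phi\) concretely, take the normal distribution with mean \(\theta\) and variance \(\phi\). It is an overdispersed exponential family with \(T(y) = y\), \(A(\theta) = \theta^2/2\) and \(m(y, \phi) = e^{-y^2/(2\phi)}/\sqrt{2\pi\phi}\). The plain-Python sketch below, our own example, integrates numerically on a fine grid. It recovers \(A'(\theta) = \theta\) and \(\phi A''(\theta) = \phi\).

```python
import math

theta, phi = 0.7, 2.5

def m(y, phi):
    return math.exp(-y * y / (2.0 * phi)) / math.sqrt(2.0 * math.pi * phi)

def A(theta):
    return 0.5 * theta * theta

def density(y):
    """p_OEF(y | theta, phi) with T(y) = y; this is Normal(theta, phi)."""
    return m(y, phi) * math.exp((theta * y - A(theta)) / phi)

step = 1e-3
ys = [theta + k * step for k in range(-60000, 60001)]  # about +-38 standard deviations
dens = [density(y) for y in ys]
mass = sum(dens) * step
mean_t = sum(y * p for y, p in zip(ys, dens)) * step
var_t = sum((y - mean_t) ** 2 * p for y, p in zip(ys, dens)) * step
print(mass, mean_t, var_t)  # approximately 1, A'(theta) = 0.7, phi * A''(theta) = 2.5
```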

Generalized Linear Model

In a generalized linear model, a predictive distribution for the response variable \(Y\) is associated with a vector of observed predictors \(x\). The distribution is a member of an overdispersed exponential family, and the parameter \(\theta\) is replaced by \(h(\eta)\) where \(h\) is a known function, \(\eta := x^\top \beta\) is the so-called linear response, and \(\beta\) is a vector of parameters (regression coefficients) to be learned. In general the dispersion parameter \(\phi\) could be learned too, but in our setup we will treat \(\phi\) as known. So our setup is

\[ Y \sim p_{\text{OEF}(m, T)}(\cdot\, |\, \theta = h(\eta), \phi) \]

where the model structure is characterized by the distribution \(p_{\text{OEF}(m, T)}\) and the function \(h\) which converts linear response to parameters.

Traditionally, the mapping from linear response \(\eta\) to mean \(\mu := \mathbb{E}_{Y \sim p_{\text{OEF}(m, T)}(\cdot\, |\, \theta = h(\eta), \phi)}\left[ Y\right]\) is denoted

\[ \mu = g^{-1}(\eta). \]

This mapping is required to be one-to-one, and its inverse, \(g\), is called the link function for this GLM. Typically, one describes a GLM by naming its link function and its family of distributions -- e.g., a "GLM with Bernoulli distribution and logit link function" (also known as a logistic regression model). In order to fully characterize the GLM, the function \(h\) must also be specified. If \(h\) is the identity, then \(g\) is said to be the canonical link function.

Claim: Expressing \(h'\) in terms of the sufficient statistic

Define

\[ {\text{Mean}_T}(\eta) := \mathbb{E}_{Y \sim p_{\text{OEF}(m, T)}(\cdot | \theta = h(\eta), \phi)} \left[ T(Y) \right] \]

and

\[ {\text{Var}_T}(\eta) := \text{Var}_{Y \sim p_{\text{OEF}(m, T)}(\cdot | \theta = h(\eta), \phi)} \left[ T(Y) \right]. \]

Then we have

\[ h'(\eta) = \frac{\phi\, {\text{Mean}_T}'(\eta)}{ {\text{Var}_T}(\eta)}. \]

Proof

By "Mean and variance of the sufficient statistic," we have

\[ {\text{Mean}_T}(\eta) = A'(h(\eta)). \]

Differentiating with the chain rule, we obtain

\[ {\text{Mean}_T}'(\eta) = A''(h(\eta))\, h'(\eta), \]

and by "Mean and variance of the sufficient statistic,"

\[ \cdots = \frac{1}{\phi} {\text{Var}_T}(\eta)\ h'(\eta). \]

The conclusion follows.
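For a non-canonical example, take the Bernoulli family with the probit link, the model behind tfp.glm.BernoulliNormalCDF in the example above. Here:

- \(\phi = 1\).
- \({\text{Mean}_T}(\eta) = \Phi(\eta)\), the standard normal CDF.
- \(h(\eta) = \log\frac{\Phi(\eta)}{1 - \Phi(\eta)}\), which is not the identity.

The plain-Python sketch below compares a finite-difference \(h'\) with the right-hand side of the claim. The function names are ours.

```python
import math

def normal_cdf(z):
    return 0.5 * (1.0 + math.erf(z / math.sqrt(2.0)))

def normal_pdf(z):
    return math.exp(-0.5 * z * z) / math.sqrt(2.0 * math.pi)

def h(eta):
    """Bernoulli natural parameter logit(Phi(eta)) under the probit link."""
    p = normal_cdf(eta)
    return math.log(p / (1.0 - p))

phi = 1.0  # the Bernoulli family has unit dispersion
eps = 1e-6
max_abs_diff = 0.0
for eta in (-1.5, 0.0, 0.8):
    h_prime_numeric = (h(eta + eps) - h(eta - eps)) / (2.0 * eps)
    mean_t = normal_cdf(eta)              # Mean_T(eta) = P(Y = 1)
    var_t = mean_t * (1.0 - mean_t)       # Var_T(eta)
    mean_t_prime = normal_pdf(eta)        # Mean_T'(eta)
    h_prime_claim = phi * mean_t_prime / var_t
    print(eta, h_prime_numeric, h_prime_claim)
    max_abs_diff = max(max_abs_diff, abs(h_prime_numeric - h_prime_claim))
```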

Fitting GLM Parameters to Data

The properties derived above lend themselves very well to fitting GLM parameters \(\beta\) to a data set. Quasi-Newton methods such as Fisher scoring rely on the gradient of the log likelihood and the Fisher information, which we now show can be computed especially efficiently for a GLM.

Suppose we have observed predictor vectors \(x_i\) and associated scalar responses \(y_i\). In matrix form, we'll say we have observed predictors \(\mathbf{x}\) and response \(\mathbf{y}\), where \(\mathbf{x}\) is the matrix whose \(i\)th row is \(x_i^\top\) and \(\mathbf{y}\) is the vector whose \(i\)th element is \(y_i\). The log likelihood of parameters \(\beta\) is then

\[ \ell(\beta\, ;\, \mathbf{x}, \mathbf{y}) = \sum_{i=1}^{N} \log p_{\text{OEF}(m, T)}(y_i\, |\, \theta = h(x_i^\top \beta), \phi). \]

For a single data sample

To simplify the notation, let's first consider the case of a single data point, \(N=1\); then we will extend to the general case by additivity.

Gradient

We have

\[ \begin{align*} \ell(\beta\, ;\, x, y) &= \log p_{\text{OEF}(m, T)}(y\, |\, \theta = h(x^\top \beta), \phi) \\ &= \log m(y, \phi) + \frac{\theta\, T(y) - A(\theta)}{\phi}, \quad\text{where}\ \theta = h(x^\top \beta). \end{align*} \]

Hence by the chain rule,

\[ \nabla_\beta \ell(\beta\, ; \, x, y) = \frac{T(y) - A'(\theta)}{\phi}\, h'(x^\top \beta)\, x. \]

Separately, by "Mean and variance of the sufficient statistic," we have \(A'(\theta) = {\text{Mean}_T}(x^\top \beta)\). Hence, by "Claim: Expressing \(h'\) in terms of the sufficient statistic," we have

\[ \cdots = \left(T(y) - {\text{Mean}_T}(x^\top \beta)\right) \frac{ {\text{Mean}_T}'(x^\top \beta)}{ {\text{Var}_T}(x^\top \beta)} \,x. \]
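The gradient formula can be checked against finite differences of the single-sample log likelihood. The sketch below assumes a normal response with variance \(\phi = \sigma^2 = 2.5\) and a log link, for which \(h(\eta) = e^\eta\), \({\text{Mean}_T}(\eta) = {\text{Mean}_T}'(\eta) = e^\eta\) and \({\text{Var}_T}(\eta) = \phi\); the function names and numbers are arbitrary illustrations.

```python
import numpy as np

PHI = 2.5  # dispersion (sigma^2) of the normal response

def log_lik(beta, x, y):
  # Single-sample log likelihood, normal response with log link:
  # theta = h(eta) = exp(eta), T(y) = y, A(theta) = theta^2 / 2.
  theta = np.exp(x @ beta)
  log_m = -y**2 / (2.0 * PHI) - 0.5 * np.log(2.0 * np.pi * PHI)
  return log_m + (theta * y - 0.5 * theta**2) / PHI

def grad_formula(beta, x, y):
  # (T(y) - Mean_T(eta)) * Mean_T'(eta) / Var_T(eta) * x
  eta = x @ beta
  mean_T = np.exp(eta)        # Mean_T(eta) = A'(h(eta)) = exp(eta)
  mean_T_prime = np.exp(eta)  # derivative of Mean_T
  var_T = PHI                 # Var_T(eta) = phi * A''(h(eta)) = phi
  return (y - mean_T) * mean_T_prime / var_T * x

def grad_finite_difference(beta, x, y, eps=1e-6):
  # Central differences, one coordinate of beta at a time.
  grad = np.zeros_like(beta)
  for j in range(beta.size):
    step = np.zeros_like(beta)
    step[j] = eps
    grad[j] = (log_lik(beta + step, x, y) - log_lik(beta - step, x, y)) / (2 * eps)
  return grad

x = np.array([1.0, -0.5, 2.0])
beta = np.array([0.2, 0.4, -0.1])
y = 0.5
print(grad_formula(beta, x, y))
print(grad_finite_difference(beta, x, y))
```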

Hessian

Differentiating the gradient a second time, by the product rule (and, in the second line, using \({\text{Mean}_T}'(\eta) = A''(h(\eta))\, h'(\eta)\) from the proof above) we obtain

\[ \begin{align*} \nabla_\beta^2 \ell(\beta\, ;\, x, y) &= \frac{1}{\phi}\left[ -A''(h(x^\top \beta))\, h'(x^\top \beta) \right] h'(x^\top \beta)\, x x^\top + \frac{1}{\phi}\left[ T(y) - A'(h(x^\top \beta)) \right] h''(x^\top \beta)\, xx^\top \\ &= \frac{1}{\phi}\left( -{\text{Mean}_T}'(x^\top \beta)\, h'(x^\top \beta) + \left[T(y) - A'(h(x^\top \beta))\right] h''(x^\top \beta) \right)\, x x^\top. \end{align*} \]

Fisher information

By "Mean and variance of the sufficient statistic," we have

\[ \mathbb{E}_{Y \sim p_{\text{OEF}(m, T)}(\cdot | \theta = h(x^\top \beta), \phi)} \left[ T(Y) - A'(h(x^\top \beta)) \right] = 0. \]

Hence the term involving \(h''\) vanishes in expectation, and by "Claim: Expressing \(h'\) in terms of the sufficient statistic,"

\[ \begin{align*} \mathbb{E}_{Y \sim p_{\text{OEF}(m, T)}(\cdot | \theta = h(x^\top \beta), \phi)} \left[ \nabla_\beta^2 \ell(\beta\, ;\, x, Y) \right] &= -\frac{1}{\phi}\, {\text{Mean}_T}'(x^\top \beta)\, h'(x^\top \beta)\, x x^\top \\ &= -\frac{ {\text{Mean}_T}'(x^\top \beta)^2}{ {\text{Var}_T}(x^\top \beta)}\, x x^\top. \end{align*} \]

For multiple data samples

We now extend the \(N=1\) case to the general case. Let \(\boldsymbol{\eta} := \mathbf{x} \beta\) denote the vector whose \(i\)th coordinate is the linear response from the \(i\)th data sample. Let \(\mathbf{T}\) (resp. \({\textbf{Mean}_T}\), \({\textbf{Mean}_T}'\), \({\textbf{Var}_T}\)) denote the broadcast (vectorized) function that applies the scalar-valued function \(T\) (resp. \({\text{Mean}_T}\), \({\text{Mean}_T}'\), \({\text{Var}_T}\)) to each coordinate. Then we have

\[ \begin{align*} \nabla_\beta \ell(\beta\, ;\, \mathbf{x}, \mathbf{y}) &= \sum_{i=1}^{N} \nabla_\beta \ell(\beta\, ;\, x_i, y_i) \\ &= \sum_{i=1}^{N} \left(T(y_i) - {\text{Mean}_T}(x_i^\top \beta)\right) \frac{ {\text{Mean}_T}'(x_i^\top \beta)}{ {\text{Var}_T}(x_i^\top \beta)} \, x_i \\ &= \mathbf{x}^\top \,\text{diag}\left(\frac{ {\textbf{Mean}_T}'(\mathbf{x} \beta) }{ {\textbf{Var}_T}(\mathbf{x} \beta) }\right) \left(\mathbf{T}(\mathbf{y}) - {\textbf{Mean}_T}(\mathbf{x} \beta)\right) \end{align*} \]

and

\[ \begin{align*} \mathbb{E}_{Y_i \sim p_{\text{OEF}(m, T)}(\cdot | \theta = h(x_i^\top \beta), \phi)} \left[ \nabla_\beta^2 \ell(\beta\, ;\, \mathbf{x}, \mathbf{Y}) \right] &= \sum_{i=1}^{N} \mathbb{E}_{Y_i \sim p_{\text{OEF}(m, T)}(\cdot | \theta = h(x_i^\top \beta), \phi)} \left[ \nabla_\beta^2 \ell(\beta\, ;\, x_i, Y_i) \right] \\ &= \sum_{i=1}^{N} -\frac{ {\text{Mean}_T}'(x_i^\top \beta)^2}{ {\text{Var}_T}(x_i^\top \beta)}\, x_i x_i^\top \\ &= -\mathbf{x}^\top \,\text{diag}\left( \frac{ {\textbf{Mean}_T}'(\mathbf{x} \beta)^2 }{ {\textbf{Var}_T}(\mathbf{x} \beta) }\right)\, \mathbf{x}, \end{align*} \]

where the fractions denote element-wise division.
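The vectorized formulas can also be checked without TensorFlow. The sketch below is an illustrative NumPy version, assuming a Bernoulli response with probit link (the family of `tfp.glm.BernoulliNormalCDF`, where \(\phi = 1\)) and an arbitrary small model matrix. It compares the gradient formula with central finite differences of the total log likelihood. Because each \(Y_i\) takes only the values 0 and 1, it computes the expected Hessian exactly, weighting both outcomes by their probabilities, rather than by Monte Carlo; the \(i\)th term depends on \(\beta\) only through \(\eta_i = x_i^\top \beta\), so its Hessian is \(\partial^2_{\eta_i} \ell_i\, x_i x_i^\top\).

```python
import math
import numpy as np

def norm_cdf(z):
  return 0.5 * (1.0 + math.erf(z / math.sqrt(2.0)))

def norm_pdf(z):
  return math.exp(-0.5 * z * z) / math.sqrt(2.0 * math.pi)

def log_lik_i(eta_i, y_i):
  # Bernoulli/probit log likelihood of one sample as a function of eta_i.
  p = norm_cdf(eta_i)
  return y_i * math.log(p) + (1.0 - y_i) * math.log1p(-p)

def d2(f, e, h=1e-4):
  # Central second difference approximation of f''(e).
  return (f(e + h) - 2.0 * f(e) + f(e - h)) / h**2

x = np.array([[1.0, 0.5, -0.2],
              [1.0, -1.0, 0.8],
              [1.0, 0.3, 1.5],
              [1.0, -0.7, -0.4]])
beta = np.array([0.1, 0.6, -0.3])
y = np.array([1.0, 0.0, 0.0, 1.0])
eta = x @ beta

# Vectorized formulas with Mean_T = Phi, Mean_T' = phi_pdf, Var_T = p (1 - p).
p = np.array([norm_cdf(e) for e in eta])
dp = np.array([norm_pdf(e) for e in eta])
var_T = p * (1.0 - p)
grad_formula = x.T @ (dp / var_T * (y - p))
fim_formula = x.T @ ((dp**2 / var_T)[:, np.newaxis] * x)

# Finite-difference gradient of the total log likelihood.
def total_log_lik(b):
  return sum(log_lik_i(e, yi) for e, yi in zip(x @ b, y))

eps = 1e-6
grad_fd = np.array([
    (total_log_lik(beta + eps * u) - total_log_lik(beta - eps * u)) / (2 * eps)
    for u in np.eye(3)])

# Exact expectation of the negative Hessian over each Y_i in {0, 1}.
expected_neg_hessian = np.zeros((3, 3))
for x_i, e_i, p_i in zip(x, eta, p):
  curvature = ((1.0 - p_i) * d2(lambda t: log_lik_i(t, 0.0), e_i)
               + p_i * d2(lambda t: log_lik_i(t, 1.0), e_i))
  expected_neg_hessian -= curvature * np.outer(x_i, x_i)

print(np.max(np.abs(grad_formula - grad_fd)))
print(np.max(np.abs(expected_neg_hessian - fim_formula)))
```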

Verifying the Formulas Numerically

We now verify the above formula for the gradient of the log likelihood numerically using tf.gradients, and verify the formula for the Fisher information against a Monte Carlo estimate of the expected Hessian of the negative log likelihood, computed with tf.hessians:

import numpy as np
import scipy.stats
import tensorflow as tf
import tensorflow_probability as tfp

def VerifyGradientAndFIM():
  model = tfp.glm.BernoulliNormalCDF()
  model_matrix = np.array([[1., 5, -2],
                           [8, -1, 8]])

  def _naive_grad_and_hessian_loss_fn(x, response):
    # Computes gradient and Hessian of negative log likelihood using autodiff.
    predicted_linear_response = tf.linalg.matvec(model_matrix, x)
    log_probs = model.log_prob(response, predicted_linear_response)
    grad_loss = tf.gradients(-log_probs, [x])[0]
    hessian_loss = tf.hessians(-log_probs, [x])[0]
    return [grad_loss, hessian_loss]

  def _grad_neg_log_likelihood_and_fim_fn(x, response):
    # Computes gradient of negative log likelihood and Fisher information matrix
    # using the formulas above.
    predicted_linear_response = tf.linalg.matvec(model_matrix, x)
    mean, variance, grad_mean = model(predicted_linear_response)

    v = (response - mean) * grad_mean / variance
    grad_log_likelihood = tf.linalg.matvec(model_matrix, v, adjoint_a=True)
    w = grad_mean**2 / variance

    fisher_info = tf.linalg.matmul(
        model_matrix,
        w[..., tf.newaxis] * model_matrix,
        adjoint_a=True)
    return [-grad_log_likelihood, fisher_info]

  @tf.function(autograph=False)
  def compute_grad_hessian_estimates():
    # Monte Carlo estimate of E[Hessian(-LogLikelihood)], where the expectation is
    # over the response distribution, as in the "Fisher information" derivation above.
    num_trials = 20
    trial_outputs = []
    np.random.seed(10)
    model_coefficients_ = np.random.random(size=(model_matrix.shape[1],))
    model_coefficients = tf.convert_to_tensor(model_coefficients_)
    for _ in range(num_trials):
      # Sample from the distribution of `model`
      response = np.random.binomial(
          1,
          scipy.stats.norm().cdf(np.matmul(model_matrix, model_coefficients_))
      ).astype(np.float64)
      trial_outputs.append(
          list(_naive_grad_and_hessian_loss_fn(model_coefficients, response)) +
          list(
              _grad_neg_log_likelihood_and_fim_fn(model_coefficients, response))
      )

    naive_grads = tf.stack(
        list(naive_grad for [naive_grad, _, _, _] in trial_outputs), axis=0)
    fancy_grads = tf.stack(
        list(fancy_grad for [_, _, fancy_grad, _] in trial_outputs), axis=0)

    average_hess = tf.reduce_mean(tf.stack(
        list(hess for [_, hess, _, _] in trial_outputs), axis=0), axis=0)
    [_, _, _, fisher_info] = trial_outputs[0]
    return naive_grads, fancy_grads, average_hess, fisher_info

  naive_grads, fancy_grads, average_hess, fisher_info = [
      t.numpy() for t in compute_grad_hessian_estimates()]

  print("Coordinatewise relative error between naively computed gradients and"
        " formula-based gradients (should be zero):\n{}\n".format(
            (naive_grads - fancy_grads) / naive_grads))

  print("Coordinatewise relative error between average of naively computed"
        " Hessian and formula-based FIM (should approach zero as num_trials"
        " -> infinity):\n{}\n".format(
                (average_hess - fisher_info) / average_hess))

VerifyGradientAndFIM()
Coordinatewise relative error between naively computed gradients and formula-based gradients (should be zero):
[[2.08845965e-16 1.67076772e-16 2.08845965e-16]
 [1.96118673e-16 3.13789877e-16 1.96118673e-16]
 [2.08845965e-16 1.67076772e-16 2.08845965e-16]
 [1.96118673e-16 3.13789877e-16 1.96118673e-16]
 [2.08845965e-16 1.67076772e-16 2.08845965e-16]
 [1.96118673e-16 3.13789877e-16 1.96118673e-16]
 [1.96118673e-16 3.13789877e-16 1.96118673e-16]
 [1.96118673e-16 3.13789877e-16 1.96118673e-16]
 [2.08845965e-16 1.67076772e-16 2.08845965e-16]
 [1.96118673e-16 3.13789877e-16 1.96118673e-16]
 [2.08845965e-16 1.67076772e-16 2.08845965e-16]
 [1.96118673e-16 3.13789877e-16 1.96118673e-16]
 [1.96118673e-16 3.13789877e-16 1.96118673e-16]
 [1.96118673e-16 3.13789877e-16 1.96118673e-16]
 [1.96118673e-16 3.13789877e-16 1.96118673e-16]
 [1.96118673e-16 3.13789877e-16 1.96118673e-16]
 [1.96118673e-16 3.13789877e-16 1.96118673e-16]
 [2.08845965e-16 1.67076772e-16 2.08845965e-16]
 [1.96118673e-16 3.13789877e-16 1.96118673e-16]
 [2.08845965e-16 1.67076772e-16 2.08845965e-16]]

Coordinatewise relative error between average of naively computed Hessian and formula-based FIM (should approach zero as num_trials -> infinity):
[[0.00072369 0.00072369 0.00072369]
 [0.00072369 0.00072369 0.00072369]
 [0.00072369 0.00072369 0.00072369]]
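With the formulas verified, the vectorized gradient and Fisher information are all that Fisher scoring needs: each step solves \(\mathcal{I}(\beta)\, \delta = \nabla_\beta \ell(\beta)\) and sets \(\beta \leftarrow \beta + \delta\). The sketch below is a plain NumPy illustration, not TFP's implementation (in TFP the supported route is `tfp.glm.fit`, which adds convergence checks and other options). It assumes a Poisson response with log link, where \(\phi = 1\) and \({\text{Mean}_T} = {\text{Mean}_T}' = {\text{Var}_T} = \exp\), so the weights `w` match the `grad_mean**2 / variance` weights used in `VerifyGradientAndFIM`. The data are simulated from arbitrary true coefficients.

```python
import numpy as np

# Poisson response with log link: Mean_T = Mean_T' = Var_T = exp(eta), phi = 1.
def mean_T(eta):
  return np.exp(eta)

def mean_T_prime(eta):
  return np.exp(eta)

def var_T(eta):
  return np.exp(eta)

def grad_and_fim(beta, x, y):
  # Vectorized gradient of the log likelihood and Fisher information matrix.
  eta = x @ beta
  grad = x.T @ (mean_T_prime(eta) / var_T(eta) * (y - mean_T(eta)))
  w = mean_T_prime(eta)**2 / var_T(eta)  # per-sample diagonal weights
  fim = x.T @ (w[:, np.newaxis] * x)
  return grad, fim

def fisher_scoring(x, y, num_steps=25):
  beta = np.zeros(x.shape[1])
  for _ in range(num_steps):
    grad, fim = grad_and_fim(beta, x, y)
    beta = beta + np.linalg.solve(fim, grad)  # beta + I(beta)^{-1} grad
  return beta

rng = np.random.default_rng(42)
x = np.column_stack([np.ones(500), rng.normal(size=(500, 2))])
true_beta = np.array([0.5, -0.3, 0.2])
y = rng.poisson(np.exp(x @ true_beta)).astype(float)

beta_hat = fisher_scoring(x, y)
grad_at_fit, _ = grad_and_fim(beta_hat, x, y)
print(beta_hat)                   # should be close to true_beta
print(np.abs(grad_at_fit).max())  # near zero at the maximum likelihood estimate
```

For this canonical-link family Fisher scoring coincides with Newton's method, because the \(h''\) term in the Hessian vanishes.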

References

[1]: Guo-Xun Yuan, Chia-Hua Ho and Chih-Jen Lin. An Improved GLMNET for L1-regularized Logistic Regression. Journal of Machine Learning Research, 13, 2012. http://www.jmlr.org/papers/volume13/yuan12a/yuan12a.pdf

[2]: skd. Derivation of Soft Thresholding Operator. 2018. https://math.stackexchange.com/q/511106

[3]: Wikipedia Contributors. Proximal gradient methods for learning. Wikipedia, The Free Encyclopedia, 2018. https://en.wikipedia.org/wiki/Proximal_gradient_methods_for_learning

[4]: Yao-Liang Yu. The Proximity Operator. https://www.cs.cmu.edu/~suvrit/teach/yaoliang_proximity.pdf