广义线性模型

在此笔记本中,我们将通过一个工作示例来介绍广义线性模型。我们使用两种算法以两种不同的方式解决此示例,以在 TensorFlow Probability 中有效地拟合 GLM:针对密集数据使用 Fisher 得分算法,针对稀疏数据使用坐标近端梯度下降算法。我们将拟合系数与真实系数进行对比,在坐标近端梯度下降算法下则与 R 语言的类似 glmnet 算法的输出进行对比。最后,我们提供了 GLM 一些关键属性的进一步数学细节和推导。

背景

广义线性模型 (GLM) 是一种封装在转换(联系函数)中并配备了指数族的响应分布的线性模型 (η=xβ) 。联系函数和响应分布的选择非常灵活,这为 GLM 赋予了出色的表达性。在下面的“GLM 事实的推导”中可以找到完整的详细信息,包括以明确的表示法对 GLM 构建的所有定义和结果的有序介绍。我们总结如下:

在 GLM 中,响应变量 Y 的预测分布与观察到的预测变量 x 的向量相关联。分布形式如下:

p(y,|,x)amp;=m(y,ϕ)exp(θ,T(y)A(θ)ϕ) θamp;:=h(η) ηamp;:=xβ

其中,β 是参数(“权重”),ϕ 是表示离散度(“方差”)的超参数,mhTA 由用户指定模型族表征。

Y 的均值取决于 x线性响应 η 和(逆)联系函数,即:

μ:=g1(η)

其中 g 是所谓的联系函数。在 TFP 中,联系函数和模型族的选择由 tfp.glm.ExponentialFamily 子类共同指定。示例包括:

TFP 更喜欢根据 Y 的分布而非联系函数来命名模型族,因为 tfp.Distribution 已经是一等公民。如果 tfp.glm.ExponentialFamily 子类名称包含第二个单词,则表示非正则联系函数

GLM 具有几项可有效地实现最大似然 estimator 的显著特性。这些特性中最主要的是为对数似然函数 梯度以及 Fisher 信息矩阵提供了简单的公式,它是在相同预测变量下对响应重新采样时负对数似然函数的 Hessian 的期望值。即:

β,(β,;,x,y)amp;=x,diag(MeanT(xβ)Var<emdatamdtype="emphasis">T(xβ))(T(y)Mean<emdatamdtype="emphasis">T(xβ)) E</em>YiGLM|xi[</em>β2,(β,;,x,Y)]amp;=x,diag(ϕ,MeanT(xβ)2VarT(xβ)),x

其中 x 是矩阵,其第 i 行是第 i 个数据样本的预测变量向量;y 是向量,其第 i 个坐标是第 i 个数据样本的观察到的响应。这里(粗略地讲),MeanT(η):=E[T(Y),|,η]VarT(η):=Var[T(Y),|,η],粗体表示这些函数的矢量化。有关这些期望和方差的分布的完整详细信息,请参阅下方的“GLM 事实的推导”。

示例

在本部分中,我们将简要介绍和展示 TensorFlow Probability 中的两种内置 GLM 拟合算法:Fisher 得分 (tfp.glm.fit) 和坐标近端梯度下降 (tfp.glm.fit_sparse)。

合成数据集

让我们假装加载一些训练数据集。

import numpy as np
import pandas as pd
import scipy
import tensorflow.compat.v2 as tf
tf.enable_v2_behavior()
import tensorflow_probability as tfp
tfd = tfp.distributions
def make_dataset(n, d, link, scale=1., dtype=np.float32):
  model_coefficients = tfd.Uniform(
      low=-1., high=np.array(1, dtype)).sample(d, seed=42)
  radius = np.sqrt(2.)
  model_coefficients *= radius / tf.linalg.norm(model_coefficients)
  mask = tf.random.shuffle(tf.range(d)) < int(0.5 * d)
  model_coefficients = tf.where(
      mask, model_coefficients, np.array(0., dtype))
  model_matrix = tfd.Normal(
      loc=0., scale=np.array(1, dtype)).sample([n, d], seed=43)
  scale = tf.convert_to_tensor(scale, dtype)
  linear_response = tf.linalg.matvec(model_matrix, model_coefficients)

  if link == 'linear':
    response = tfd.Normal(loc=linear_response, scale=scale).sample(seed=44)
  elif link == 'probit':
    response = tf.cast(
        tfd.Normal(loc=linear_response, scale=scale).sample(seed=44) > 0,
                   dtype)
  elif link == 'logit':
    response = tfd.Bernoulli(logits=linear_response).sample(seed=44)
  else:
    raise ValueError('unrecognized true link: {}'.format(link))
  return model_matrix, response, model_coefficients, mask

注:连接到本地运行时。

在此笔记本中,我们使用本地文件在 Python 和 R 内核之间共享数据。要启用此共享,请在您具备本地文件读写权限的同一台计算机上使用运行时。

x, y, model_coefficients_true, _ = [t.numpy() for t in make_dataset(
    n=int(1e5), d=100, link='probit')]

DATA_DIR = '/tmp/glm_example'
tf.io.gfile.makedirs(DATA_DIR)
with tf.io.gfile.GFile('{}/x.csv'.format(DATA_DIR), 'w') as f:
  np.savetxt(f, x, delimiter=',')
with tf.io.gfile.GFile('{}/y.csv'.format(DATA_DIR), 'w') as f:
  np.savetxt(f, y.astype(np.int32) + 1, delimiter=',', fmt='%d')
with tf.io.gfile.GFile(
    '{}/model_coefficients_true.csv'.format(DATA_DIR), 'w') as f:
  np.savetxt(f, model_coefficients_true, delimiter=',')

不使用 L1 正则化

函数 tfp.glm.fit 实现 Fisher 得分,它采用一些参数:

  • model_matrix = x
  • response = y
  • model = 可调用对象,给定参数 η,返回三元组 (MeanT(η),VarT(η),MeanT(η))

我们建议该 modeltfp.glm.ExponentialFamily 类的实例。有几种预制的实现可用,对于大多数常见的 GLM,不需要自定义代码。

@tf.function(autograph=False)
def fit_model():
  model_coefficients, linear_response, is_converged, num_iter = tfp.glm.fit(
      model_matrix=x, response=y, model=tfp.glm.BernoulliNormalCDF())
  log_likelihood = tfp.glm.BernoulliNormalCDF().log_prob(y, linear_response)
  return (model_coefficients, linear_response, is_converged, num_iter,
          log_likelihood)

[model_coefficients, linear_response, is_converged, num_iter,
 log_likelihood] = [t.numpy() for t in fit_model()]

print(('is_converged: {}\n'
       '    num_iter: {}\n'
       '    accuracy: {}\n'
       '    deviance: {}\n'
       '||w0-w1||_2 / (1+||w0||_2): {}'
      ).format(
    is_converged,
    num_iter,
    np.mean((linear_response > 0.) == y),
    2. * np.mean(log_likelihood),
    np.linalg.norm(model_coefficients_true - model_coefficients, ord=2) /
        (1. + np.linalg.norm(model_coefficients_true, ord=2))
    ))
is_converged: True
    num_iter: 6
    accuracy: 0.75241
    deviance: -0.992436110973
||w0-w1||_2 / (1+||w0||_2): 0.0231555201462

数学细节

Fisher 得分法是对牛顿法的修改,用于寻找最大似然估计

ˆβ:=arg maxβ  (β ; x,y).

普通牛顿法,搜索对数似然函数梯度的零点,将遵循更新规则

\beta^{(t+1)}_{\text{Newton} } := \beta^{(t)}

其中 α(0,1] 是用于控制步长的学习率。

在 Fisher 得分法中,我们将 Hessian 替换为负的 Fisher 信息矩阵:

\begin{align*} \beta^{(t+1)} &:= \beta^{(t)}

[注:此处 Y=(Yi)ni=1 是随机的,而 y 仍是观察到的响应的向量。]

通过下文“将 GLM 参数拟合到数据”中的公式,可将其简化为

β(t+1)amp;=β(t)+α(xdiag(ϕ,MeanT(xβ(t))2VarT(xβ(t))),x)1(xdiag(MeanT(xβ(t))VarT(xβ(t)))(T(y)MeanT(xβ(t)))).

使用 L1 正则化

tfp.glm.fit_sparse 基于 Yuan, Ho and Lin 2012 中的算法实现了更适合稀疏数据集的 GLM 拟合器。特性包括:

  • L1 正则化
  • 不使用矩阵求逆
  • 只需少量梯度和 Hessian 评估。

我们首先展示代码的示例用法。算法的细节会在下文“tfp.glm.fit_sparse 的算法细节”中进一步阐述。

model = tfp.glm.Bernoulli()
model_coefficients_start = tf.zeros(x.shape[-1], np.float32)
@tf.function(autograph=False)
def fit_model():
  return tfp.glm.fit_sparse(
    model_matrix=tf.convert_to_tensor(x),
    response=tf.convert_to_tensor(y),
    model=model,
    model_coefficients_start=model_coefficients_start,
    l1_regularizer=800.,
    l2_regularizer=None,
    maximum_iterations=10,
    maximum_full_sweeps_per_iteration=10,
    tolerance=1e-6,
    learning_rate=None)

model_coefficients, is_converged, num_iter = [t.numpy() for t in fit_model()]
coefs_comparison = pd.DataFrame({
  'Learned': model_coefficients,
  'True': model_coefficients_true,
})

print(('is_converged: {}\n'
       '    num_iter: {}\n\n'
       'Coefficients:').format(
    is_converged,
    num_iter))
coefs_comparison
is_converged: True
    num_iter: 1

Coefficients:

请注意,学习的系数与真实系数具有相同的稀疏模式。

# Save the learned coefficients to a file.
with tf.io.gfile.GFile('{}/model_coefficients_prox.csv'.format(DATA_DIR), 'w') as f:
  np.savetxt(f, model_coefficients, delimiter=',')

对比 R 语言的 glmnet

我们将坐标近端梯度下降算法的输出与使用类似算法的 R 语言的 glmnet 的输出进行对比。

注:要执行此部分,您必须切换到 R colab 运行时。

suppressMessages({
  library('glmnet')
})
data_dir <- '/tmp/glm_example'
x <- as.matrix(read.csv(paste(data_dir, '/x.csv', sep=''),
                        header=FALSE))
y <- as.matrix(read.csv(paste(data_dir, '/y.csv', sep=''),
                        header=FALSE, colClasses='integer'))
fit <- glmnet(
x = x,
y = y,
family = "binomial",  # Logistic regression
alpha = 1,  # corresponds to l1_weight = 1, l2_weight = 0
standardize = FALSE,
intercept = FALSE,
thresh = 1e-30,
type.logistic = "Newton"
)
write.csv(as.matrix(coef(fit, 0.008)),
          paste(data_dir, '/model_coefficients_glmnet.csv', sep=''),
          row.names=FALSE)

比较 R、TFP 和真实系数(注:回到 Python 内核)

DATA_DIR = '/tmp/glm_example'
with tf.io.gfile.GFile('{}/model_coefficients_glmnet.csv'.format(DATA_DIR),
                       'r') as f:
  model_coefficients_glmnet = np.loadtxt(f,
                                   skiprows=2  # Skip column name and intercept
                               )

with tf.io.gfile.GFile('{}/model_coefficients_prox.csv'.format(DATA_DIR),
                       'r') as f:
  model_coefficients_prox = np.loadtxt(f)

with tf.io.gfile.GFile(
    '{}/model_coefficients_true.csv'.format(DATA_DIR), 'r') as f:
  model_coefficients_true = np.loadtxt(f)
coefs_comparison = pd.DataFrame({
    'TFP': model_coefficients_prox,
    'R': model_coefficients_glmnet,
    'True': model_coefficients_true,
})
coefs_comparison

tfp.glm.fit_sparse 的算法细节

我们将算法依次呈现为对牛顿法的三种修改形式。在每种形式中,β 的更新规则都基于向量 s 和矩阵 H,它们会逼近对数似然函数的梯度和 Hessian。在步骤 t 中,我们选择坐标 j(t) 进行更改,并根据更新规则更新 β

\begin{align} u^{(t)} &:= \frac{ \left( s^{(t)} \right){j^{(t)} } }{ \left( H^{(t)} \right)*{j^{(t)},, j^{(t)} } } [3mm] \beta^{(t+1)} &:= \beta^{(t)}

此更新是一种类似牛顿法的步骤,学习率为 α。除了最后一部分(L1 正则化),下面的修改仅在 sH 的更新方式上有所不同。

起点:坐标牛顿法

在坐标牛顿法中,我们将 sH 设置为对数似然函数的真实梯度和 Hessian:

s(t)<emdatamdtype="emphasis">vanillaamp;:=(</em>β,(β,;,x,y))<emdatamdtype="emphasis">β=β(t) H(t)</em>vanillaamp;:=(2β,(β,;,x,y))β=β(t)

只需少量梯度和 Hessian 评估

对数似然函数的梯度和 Hessian 的计算通常十分消耗算力,因此通常值得对其采用逼近算法。我们可以如下处理:

  • 通常,将 Hessian 逼近为局部常值,并使用(逼近)Hessian 将梯度逼近为一阶:

H(t+1)approxamp;:=H(t) s(t+1)approxamp;:=s(t)+H(t)(β(t+1)β(t))

  • 有时,可执行上述“普通”更新步骤,将 s(t+1) 设置为对数似然函数的精确梯度并将 H(t+1) 设置为其精确 Hessian,在 β(t+1) 处评估。

使用负 Fisher 信息矩阵代替 Hessian

为了进一步降低普通更新步骤的算力成本,我们可以将 H 设置为负 Fisher 信息矩阵(使用下文“将 GLM 参数拟合到数据”中的公式可以有效计算),而非确切的 Hessian:

H(t+1)Fisheramp;:=E<emdatamdtype="emphasis">Yip</em>OEF(m,T)(|θ=h(xiβ(t+1)),ϕ)[(2β,(β,;,x,Y))β=β(t+1)] amp;=x,diag(ϕ,Mean<emdatamdtype="emphasis">T(xβ(t+1))2Var<emdatamdtype="emphasis">T(xβ(t+1))),x s</em>Fisher(t+1)amp;:=s</em>vanilla(t+1) amp;=(x,diag(MeanT(xβ(t+1))VarT(xβ(t+1)))(T(y)MeanT(xβ(t+1))))

通过近端梯度下降求解 L1 正则化

为包含 L1 正则化,我们将更新规则

\beta^{(t+1)} := \beta^{(t)}

替换为更通用的更新规则

γ(t)amp;:=α,r<l>(H(t))<em>j(t),,j(t)\[2mm](β</em>reg(t+1))jamp;:={β(t+1)jamp;if jj(t) SoftThreshold(β(t)jα,u(t), γ(t))amp;if j=j(t)

其中 r_{\text<l>} &gt; 0r_{\text<l>} &gt; 0 是提供的常值(L1 正则化系数),SoftThreshold 是软阈值算子,定义为

SoftThreshold(β,γ):={β+γamp;if βlt;γ 0amp;if γβγ βγamp;if βgt;γ.

此更新规则具有以下两项令人欣喜的性质,解释如下:

  1. 在极限情况 r<l>0(即不使用 L1 正则化)下,此更新规则与原始更新规则相同。

  2. 此更新规则可以解释为应用邻近算子,其不动点是 L1 正则化最小化问题的解

\underset{\beta - \beta^{(t)} \in \text{span}{ \text{onehot}(j^{(t)}) } }{\text{arg min} } \left( -\ell(\beta ,;, \mathbf{x}, \mathbf{y})

退化情况 r<l>=0 可恢复原始更新规则

要查看 (1),请注意如果 r<l>=0γ(t)=0,因此

(β(t+1)reg)<emdatamdtype="emphasis">j(t)amp;=SoftThreshold(β(t)</em>j(t)α,u(t), 0) amp;=β(t)j(t)α,u(t).

因此

β(t+1)regamp;=β(t)α,u(t),onehot(j(t)) amp;=β(t+1).

不动点为正则化最大似然估计的邻近算子

要查看 (2),首先要注意(参见 Wikipedia)对于任何 \gamma &gt; 0\gamma &gt; 0,更新规则

(β(t+1)exact-prox,γ)<emdatamdtype="emphasis">j(t):=prox</em>γ1(β(t)<emdatamdtype="emphasis">j(t)+γr</em><l>((β,(β,;,x,y))<em>β=β(t))</em>j(t))

均满足 (2),其中 prox 是邻近算子(参见 Yu,其中该算子表示为 P)。上述方程的右半部分在此处计算:

\left(\beta{\text{exact-prox}, \gamma}^{(t+1)}\right){j^{(t)} }

特别地,设置 γ=γ(t)=α,r<l>(H(t))j(t),j(t)(注:只要负对数似然函数是凸函数,γ(t)0),我们得到更新规则

\left(\beta{\text{exact-prox}, \gamma^{(t)} }^{(t+1)}\right){j^{(t)} }

然后,我们将精确梯度 (β,(β,;,x,y))β=β(t) 替换为其近似值 s(t),得到

\begin{align} \left(\beta_{\text{exact-prox}, \gamma^{(t)} }^{(t+1)}\right){j^{(t)} } &\approx \text{SoftThreshold} \left( \beta^{(t)}{j^{(t)} } - \alpha \frac{ \left(s^{(t)}\right){j^{(t)} } }{ \left(H^{(t)}\right){j^{(t)}, j^{(t)} } } ,\ \gamma^{(t)} \right) \ &= \text{SoftThreshold} \left( \beta^{(t)}_{j^{(t)} } - \alpha, u^{(t)} ,\ \gamma^{(t)} \right). \end{align}

因此

β(t+1)exact-prox,γ(t)β(t+1)reg.

GLM 事实的推导

在本部分中,我们将详细说明并推导出在之前几部分中使用的 GLM 相关结果。然后,我们将使用 TensorFlow 的 gradients 对导出的对数似然函数和 Fisher 信息的梯度公式进行数值验证。

得分和 Fisher 信息

考虑由参数向量 θ 参数化的概率分布族,其概率密度为 \left{p(\cdot | \theta)\right}_{\theta \in \mathcal{T} }\left{p(\cdot | \theta)\right}_{\theta \in \mathcal{T} }。参数向量 θ0 处的结果 y得分定义为 y 的对数似然函数的梯度(在 θ0 处评估),即:

score(y,θ0):=[θ,logp(y|θ)]θ=θ0.

声明:得分的期望值为零

在非极端正则条件(允许我们传递积分符号内取微分)下,

EYp(|θ=θ0)[score(Y,θ0)]=0.

证明

已知

E<emdatamdtype="emphasis">Yp(|θ=θ0)[score(Y,θ0)]amp;:=E</em>Yp(|θ=θ0)[(θlogp(Y|θ))<emdatamdtype="emphasis">θ=θ0] amp;(1)=E</em>Yp(|θ=θ0)[(θp(Y|θ))<emdatamdtype="emphasis">θ=θ0p(Y|θ=θ0)] amp;(2)=</em>Y[(θp(y|θ))<emdatamdtype="emphasis">θ=θ0p(y|θ=θ0)]p(y|θ=θ0),dy amp;=</em>Y(θp(y|θ))<emdatamdtype="emphasis">θ=θ0,dy amp;(3)=[</em>θ(Yp(y|θ),dy)]<emdatamdtype="emphasis">θ=θ0 amp;(4)=[</em>θ,1]θ=θ0 amp;=0,

其中我们使用了:(1) 微分连锁律、(2) 期望的定义、(3) 传递积分符号内取微分(使用正则条件)、(4) 概率密度的积分为 1。

声明(Fisher 信息):得分方差等于对数似然函数的 Hessian 负期望值

在非极端正则条件(允许我们传递积分符号内取微分)下,

\mathbb{E}_{Y \sim p(\cdot | \theta=\theta_0)}\left[ \text{score}(Y, \theta_0) \text{score}(Y, \theta_0)^\top \right]

其中 2θF 表示 Hessian 矩阵,其 (i,j) 项为 2Fθiθj

此方程的左半部分称为参数向量 θ0 处的族 \left{p(\cdot | \theta)\right}_{\theta \in \mathcal{T} }\left{p(\cdot | \theta)\right}_{\theta \in \mathcal{T} }Fisher 信息

声明证明

已知

\begin{align} \mathbb{E}{Y \sim p(\cdot | \theta=\theta_0)}\left[ \left(\nabla\theta^2 \log p(Y | \theta)\right){\theta=\theta_0} \right] &\stackrel{\text{(1)} }{=} \mathbb{E}{Y \sim p(\cdot | \theta=\theta0)}\left[ \left(\nabla\theta^\top \frac{ \nabla_\theta p(Y | \theta) }{ p(Y|\theta) }\right){\theta=\theta_0} \right] \ &\stackrel{\text{(2)} }{=} \mathbb{E}*{Y \sim p(\cdot | \theta=\theta0)}\left[ \frac{ \left(\nabla^2\theta p(Y | \theta)\right)_{\theta=\theta_0} }{ p(Y|\theta=\theta_0) }

其中我们使用了 (1) 微分链式法则、(2) 微分商法则、(3)再次反向使用链式法则。

要完成证明,只需证明

E<emdatamdtype="emphasis">Yp(|θ=θ0)[(2</em>θp(Y|θ))θ=θ0p(Y|θ=θ0)]?=0.

为此,我们传递积分符号内取微分两次:

E<emdatamdtype="emphasis">Yp(|θ=θ0)[(2</em>θp(Y|θ))<emdatamdtype="emphasis">θ=θ0p(Y|θ=θ0)]amp;=</em>Y[(2θp(y|θ))<emdatamdtype="emphasis">θ=θ0p(y|θ=θ0)],p(y|θ=θ0),dy amp;=</em>Y(2θp(y|θ))<emdatamdtype="emphasis">θ=θ0,dy amp;=[</em>θ2(Yp(y|θ),dy)]<emdatamdtype="emphasis">θ=θ0 amp;=[</em>θ2,1]θ=θ0 amp;=0.

对数配分函数的导数相关引理

如果 abc 是标量值函数,则 c 二次可微,使分布族 \left{p(\cdot | \theta)\right}_{\theta \in \mathcal{T} }\left{p(\cdot | \theta)\right}_{\theta \in \mathcal{T} } 定义为

p(y|θ)=a(y)exp(b(y),θc(θ))

满足非极端正则条件,允许传递在对 y 的积分符号内取对 θ 的微分,然后

EYp(|θ=θ0)[b(Y)]=c(θ0)

VarYp(|θ=θ0)[b(Y)]=c(θ0).

(这里 表示微分,所以 ccc 的一阶导数和二阶导数。)

证明

对于此分布族,已知 score(y,θ0)=b(y)c(θ0)。然后第一个方程遵循以下事实 EYp(|θ=θ0)[score(y,θ0)]=0。接下来,已知

Var<emdatamdtype="emphasis">Yp(|θ=θ0)[b(Y)]amp;=E</em>Yp(|θ=θ0)[(b(Y)c(θ0))2] amp;=the one entry of E<emdatamdtype="emphasis">Yp(|θ=θ0)[score(y,θ0)score(y,θ0)] amp;=the one entry of E</em>Yp(|θ=θ0)[(2θlogp(|θ))<emdatamdtype="emphasis">θ=θ0] amp;=E</em>Yp(|θ=θ0)[c(θ0)] amp;=c(θ0).

过度离散指数族

过度离散指数族(标量)是一种分布族,其密度为

pOEF(m,T)(y,|,θ,ϕ)=m(y,ϕ)exp(θ,T(y)A(θ)ϕ),

其中 mT 是已知的标量值函数,θϕ 是标量参数。

[注:A 是超定的:对于任何 ϕ0,函数 A 完全由此约束定义:对所有 θ,均满足 \int p_{\text{OEF}(m, T)}(y\ |\ \theta, \phi=\phi_0), dy = 1\phi_0AmT$ 函数施加了约束。]

充分统计量的均值和方差

在与“对数配分函数的导数相关引理”部分的相同条件下,已知

\mathbb{E}{Y \sim p{\text{OEF}(m, T)}(\cdot | \theta, \phi)} \left[ T(Y) \right]

\text{Var}{Y \sim p{\text{OEF}(m, T)}(\cdot | \theta, \phi)} \left[ T(Y) \right]

证明

根据“对数配分函数的导数相关引理”,已知

\mathbb{E}{Y \sim p{\text{OEF}(m, T)}(\cdot | \theta, \phi)} \left[ \frac{T(Y)}{\phi} \right]

\text{Var}{Y \sim p{\text{OEF}(m, T)}(\cdot | \theta, \phi)} \left[ \frac{T(Y)}{\phi} \right]

结果满足期望为线性 (E[aX]=aE[X]) 并且方差为二次齐次式 (Var[aX]=a2,Var[X])。

广义线性模型

在广义线性模型中,响应变量 Y 的预测分布与观察到的预测变量 x 的向量相关联。该分布是过度离散指数族的成员,参数 θ 被替换为 h(η),其中 h 是已知函数,η:=xβ 是所谓的线性响应β 是要学习的参数(回归系数)的向量。通常,也可以学习离散参数 ϕ,但在我们的设置中,我们将 ϕ 视为已知。因此我们设置如下

YpOEF(m,T)(,|,θ=h(η),ϕ)

其中模型结构的特征在于分布 pOEF(m,T) 和将线性响应转换为参数的函数 h

传统上,从线性响应 η 到均值 μ:=EYpOEF(m,T)(,|,θ=h(η),ϕ)[Y] 的映射表示为

μ=g1(η).

此映射需为一对一映射,它的反函数 g 被称为此 GLM 的联系函数。通常,人们通过命名其联系函数及其分布族来描述 GLM,例如,“具有伯努利分布和 logit 联系函数的 GLM”(也称为逻辑回归模型)。为了完全表征 GLM,还必须指定函数 h。如果 h 为恒等函数,则称 g正则联系函数

声明:用充分统计量表达 h

定义

Mean<emdatamdtype="emphasis">T(η):=E</em>YpOEF(m,T)(|θ=h(η),ϕ)[T(Y)]

Var<emdatamdtype="emphasis">T(η):=Var</em>YpOEF(m,T)(|θ=h(η),ϕ)[T(Y)].

然后,已知

h(η)=ϕ,MeanT(η)VarT(η).

证明

根据“充分统计量的均值和方差”,已知

MeanT(η)=A(h(η)).

用链式法则求导,我们得到 \( {\text{Mean}_T}'(\eta) = A''(h(\eta)), h'(\eta), \)

根据“充分统计量的均值和方差”

=1ϕVarT(η) h(η).

结论如下。

将 GLM 参数拟合到数据

上面推导出的属性非常适合将 GLM 参数 β 拟合到数据集。诸如 Fisher 得分法之类的拟牛顿法依赖于对数似然函数的梯度和 Fisher 信息,我们现在将展示对于 GLM 可以特别有效地计算这些信息。

假设我们已经观察到预测变量向量 xi 和相关的标量响应 yi。在矩阵形式中,我们会说我们观察到了预测变量 x 和响应 y,其中 x 是第 i 行为 xi 的矩阵,y 是第 i 个元素为 yi 的向量。参数 β 的对数似然函数为

(β,;,x,y)=Ni=1logpOEF(m,T)(yi,|,θ=h(xiβ),ϕ).

对于单个数据样本

为了简化表示法,让我们首先考虑单个数据点 N=1 时的情况;然后我们将通过可加性扩展到一般情况。

梯度

已知

(β,;,x,y)amp;=logpOEF(m,T)(y,|,θ=h(xβ),ϕ) amp;=logm(y,ϕ)+θ,T(y)A(θ)ϕ,where θ=h(xβ).

因此,根据链式法则,

β(β,;,x,y)=T(y)A(θ)ϕ,h(xβ),x.

另外,根据充分统计量的均值和方差”,已知 A(θ)=MeanT(xβ)。因此,根据“声明:用充分统计量表达 h”,可得

=(T(y)MeanT(xβ))MeanT(xβ)VarT(xβ),x.

Hessian

由乘积法则二次求导,得到

2β(β,;,x,y)amp;=[A(h(xβ)),h(xβ)]h(xβ),xx+[T(y)A(h(xβ))]h(xβ),xx] amp;=(MeanT(xβ),h(xβ)+[T(y)A(h(xβ))]),xx.

Fisher 信息

根据“充分统计量的均值和方差”,已知

E<emdatamdtype="emphasis">Yp</em>OEF(m,T)(|θ=h(xβ),ϕ)[T(y)A(h(xβ))]=0.

因此

E<emdatamdtype="emphasis">Yp</em>OEF(m,T)(|θ=h(xβ),ϕ)[2β(β,;,x,y)]amp;=MeanT(xβ),h(xβ)xx amp;=ϕ,MeanT(xβ)2VarT(xβ),xx.

对于多个数据样本

我们现在将 N=1 情况扩展到一般情况。让η:=xβ 表示第 ii 个数据样本的线性响应的向量。让 T (resp. MeanT, resp. VarT) 表示对每个坐标应用标量值函数 T (resp. MeanT, resp. VarT) 的广播(矢量化)函数。然后可得

β(β,;,x,y)amp;=Ni=1β(β,;,xi,yi) amp;=Ni=1(T(y)MeanT(xiβ))MeanT(xiβ)VarT(xiβ),xi amp;=x,diag(MeanT(xβ)VarT(xβ))(T(y)MeanT(xβ)) 

E<emdatamdtype="emphasis">Yip</em>OEF(m,T)(|θ=h(xiβ),ϕ)[2β(β,;,x,Y)]amp;=Ni=1E<emdatamdtype="emphasis">Yip</em>OEF(m,T)(|θ=h(xiβ),ϕ)[2β(β,;,xi,Yi)] amp;=Ni=1ϕ,MeanT(xiβ)2VarT(xiβ),xixi amp;=x,diag(ϕ,MeanT(xβ)2VarT(xβ)),x,

其中分数表示逐元素相除。

以数值方式验证公式

我们现在使用 tf.gradients 以数值方式验证上述对数似然函数的梯度的公式,并使用 tf.hessians 通过蒙特卡洛估计验证 Fisher 信息的公式:

def VerifyGradientAndFIM():
  model = tfp.glm.BernoulliNormalCDF()
  model_matrix = np.array([[1., 5, -2],
                           [8, -1, 8]])

  def _naive_grad_and_hessian_loss_fn(x, response):
    # Computes gradient and Hessian of negative log likelihood using autodiff.
    predicted_linear_response = tf.linalg.matvec(model_matrix, x)
    log_probs = model.log_prob(response, predicted_linear_response)
    grad_loss = tf.gradients(-log_probs, [x])[0]
    hessian_loss = tf.hessians(-log_probs, [x])[0]
    return [grad_loss, hessian_loss]

  def _grad_neg_log_likelihood_and_fim_fn(x, response):
    # Computes gradient of negative log likelihood and Fisher information matrix
    # using the formulas above.
    predicted_linear_response = tf.linalg.matvec(model_matrix, x)
    mean, variance, grad_mean = model(predicted_linear_response)

    v = (response - mean) * grad_mean / variance
    grad_log_likelihood = tf.linalg.matvec(model_matrix, v, adjoint_a=True)
    w = grad_mean**2 / variance

    fisher_info = tf.linalg.matmul(
        model_matrix,
        w[..., tf.newaxis] * model_matrix,
        adjoint_a=True)
    return [-grad_log_likelihood, fisher_info]

  @tf.function(autograph=False)
  def compute_grad_hessian_estimates():
    # Monte Carlo estimate of E[Hessian(-LogLikelihood)], where the expectation is
    # as written in "Claim (Fisher information)" above.
    num_trials = 20
    trial_outputs = []
    np.random.seed(10)
    model_coefficients_ = np.random.random(size=(model_matrix.shape[1],))
    model_coefficients = tf.convert_to_tensor(model_coefficients_)
    for _ in range(num_trials):
      # Sample from the distribution of `model`
      response = np.random.binomial(
          1,
          scipy.stats.norm().cdf(np.matmul(model_matrix, model_coefficients_))
      ).astype(np.float64)
      trial_outputs.append(
          list(_naive_grad_and_hessian_loss_fn(model_coefficients, response)) +
          list(
              _grad_neg_log_likelihood_and_fim_fn(model_coefficients, response))
      )

    naive_grads = tf.stack(
        list(naive_grad for [naive_grad, _, _, _] in trial_outputs), axis=0)
    fancy_grads = tf.stack(
        list(fancy_grad for [_, _, fancy_grad, _] in trial_outputs), axis=0)

    average_hess = tf.reduce_mean(tf.stack(
        list(hess for [_, hess, _, _] in trial_outputs), axis=0), axis=0)
    [_, _, _, fisher_info] = trial_outputs[0]
    return naive_grads, fancy_grads, average_hess, fisher_info

  naive_grads, fancy_grads, average_hess, fisher_info = [
      t.numpy() for t in compute_grad_hessian_estimates()]

  print("Coordinatewise relative error between naively computed gradients and"
        " formula-based gradients (should be zero):\n{}\n".format(
            (naive_grads - fancy_grads) / naive_grads))

  print("Coordinatewise relative error between average of naively computed"
        " Hessian and formula-based FIM (should approach zero as num_trials"
        " -> infinity):\n{}\n".format(
                (average_hess - fisher_info) / average_hess))

VerifyGradientAndFIM()
Coordinatewise relative error between naively computed gradients and formula-based gradients (should be zero):
[[2.08845965e-16 1.67076772e-16 2.08845965e-16]
 [1.96118673e-16 3.13789877e-16 1.96118673e-16]
 [2.08845965e-16 1.67076772e-16 2.08845965e-16]
 [1.96118673e-16 3.13789877e-16 1.96118673e-16]
 [2.08845965e-16 1.67076772e-16 2.08845965e-16]
 [1.96118673e-16 3.13789877e-16 1.96118673e-16]
 [1.96118673e-16 3.13789877e-16 1.96118673e-16]
 [1.96118673e-16 3.13789877e-16 1.96118673e-16]
 [2.08845965e-16 1.67076772e-16 2.08845965e-16]
 [1.96118673e-16 3.13789877e-16 1.96118673e-16]
 [2.08845965e-16 1.67076772e-16 2.08845965e-16]
 [1.96118673e-16 3.13789877e-16 1.96118673e-16]
 [1.96118673e-16 3.13789877e-16 1.96118673e-16]
 [1.96118673e-16 3.13789877e-16 1.96118673e-16]
 [1.96118673e-16 3.13789877e-16 1.96118673e-16]
 [1.96118673e-16 3.13789877e-16 1.96118673e-16]
 [1.96118673e-16 3.13789877e-16 1.96118673e-16]
 [2.08845965e-16 1.67076772e-16 2.08845965e-16]
 [1.96118673e-16 3.13789877e-16 1.96118673e-16]
 [2.08845965e-16 1.67076772e-16 2.08845965e-16]]

Coordinatewise relative error between average of naively computed Hessian and formula-based FIM (should approach zero as num_trials -> infinity):
[[0.00072369 0.00072369 0.00072369]
 [0.00072369 0.00072369 0.00072369]
 [0.00072369 0.00072369 0.00072369]]

参考文献

[1]: Guo-Xun Yuan, Chia-Hua Ho and Chih-Jen Lin. An Improved GLMNET for L1-regularized Logistic Regression. Journal of Machine Learning Research, 13, 2012. http://www.jmlr.org/papers/volume13/yuan12a/yuan12a.pdf

[2]: skd. Derivation of Soft Thresholding Operator. 2018. https://math.stackexchange.com/q/511106

[3]: Wikipedia Contributors. Proximal gradient methods for learning. Wikipedia, The Free Encyclopedia, 2018. https://en.wikipedia.org/wiki/Proximal_gradient_methods_for_learning

[4]: Yao-Liang Yu. The Proximity Operator. https://www.cs.cmu.edu/~suvrit/teach/yaoliang_proximity.pdf