{ "cells": [ { "cell_type": "markdown", "metadata": { "id": "D2J3nB-ZrRv1" }, "source": [ "##### Copyright 2018 The TensorFlow Probability Authors.\n", "\n", "Licensed under the Apache License, Version 2.0 (the \"License\");" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "cellView": "form", "id": "9qDhTJmprPnm" }, "outputs": [], "source": [ "#@title Licensed under the Apache License, Version 2.0 (the \"License\"); { display-mode: \"form\" }\n", "# you may not use this file except in compliance with the License.\n", "# You may obtain a copy of the License at\n", "#\n", "# https://www.apache.org/licenses/LICENSE-2.0\n", "#\n", "# Unless required by applicable law or agreed to in writing, software\n", "# distributed under the License is distributed on an \"AS IS\" BASIS,\n", "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", "# See the License for the specific language governing permissions and\n", "# limitations under the License." ] }, { "cell_type": "markdown", "metadata": { "id": "pfPtIQ3DdZ8r" }, "source": [ "# Generalized Linear Models\n", "\n", "\n", " \n", " \n", " \n", " \n", "
" ] }, { "cell_type": "markdown", "metadata": { "id": "EOfH1_F9YsOG" }, "source": [ "In this notebook we introduce Generalized Linear Models via a worked example. We solve this example in two different ways, using two algorithms for efficiently fitting GLMs in TensorFlow Probability: Fisher scoring for dense data, and coordinatewise proximal gradient descent for sparse data. We compare the fitted coefficients to the true coefficients and, in the case of coordinatewise proximal gradient descent, to the output of R's similar `glmnet` algorithm. Finally, we provide further mathematical details and derivations of several key properties of GLMs." ] },
{ "cell_type": "markdown", "metadata": { "id": "rjsfQ6vLb5I0" }, "source": [ "# Background" ] },
{ "cell_type": "markdown", "metadata": { "id": "TdMX-QKagFnY" }, "source": [ "A generalized linear model (GLM) is a linear model ($\\eta = x^\\top \\beta$) wrapped in a transformation (the link function) and equipped with a response distribution from an exponential family. The choice of link function and response distribution is very flexible, which lends great expressivity to GLMs. The full details, including a sequential presentation of all the definitions and results building up to GLMs in unambiguous notation, are found in \"Derivation of GLM Facts\" below. We summarize:\n", "\n", "In a GLM, a predictive distribution for the response variable $Y$ is associated with a vector of observed predictors $x$. The distribution has the form:\n", "\n", "$$ \\begin{align*} p(y \\,|\\, x) &= m(y, \\phi) \\exp\\left(\\frac{\\theta\\, T(y) - A(\\theta)}{\\phi}\\right) \\\\ \\theta &:= h(\\eta) \\\\ \\eta &:= x^\\top \\beta \\end{align*} $$\n", "\n", "Here $\\beta$ are the parameters (\"weights\"), $\\phi$ is a hyperparameter representing dispersion (\"variance\"), and $m$, $h$, $T$, $A$ are characterized by the user-specified model family.\n", "\n", "The mean of $Y$ depends on $x$ by composition of the **linear response** $\\eta$ and the (inverse) link function, i.e.:\n", "\n", "$$ \\mu := g^{-1}(\\eta) $$\n", "\n", "where $g$ is the so-called **link function**. In TFP the choice of link function and model family are jointly specified by a `tfp.glm.ExponentialFamily` subclass. Examples include:\n", "\n", "- `tfp.glm.Normal`, aka \"linear regression\"\n", "- `tfp.glm.Bernoulli`, aka \"logistic regression\"\n", "- `tfp.glm.Poisson`, aka \"Poisson regression\"\n", "- `tfp.glm.BernoulliNormalCDF`, aka \"probit regression\".\n", "\n", "TFP prefers to name model families according to the distribution over `Y` rather than the link function, since `tfp.Distribution`s are already first-class citizens. If the `tfp.glm.ExponentialFamily` subclass name contains a second word, this indicates a [non-canonical link function](https://en.wikipedia.org/wiki/Generalized_linear_model#Link_function)." ] },
{ "cell_type": "markdown", "metadata": { "id": "1oGScpRnqH_b" }, "source": [ "GLMs have several remarkable properties which permit efficient implementation of the maximum likelihood estimator. Chief among these properties are simple formulas for the gradient of the log-likelihood $\\ell$, and for the Fisher information matrix, which is the expected value of the Hessian of the negative log-likelihood under a re-sampling of the response under the same predictors. I.e.:\n", "\n", "$$ \\begin{align*} \\nabla_\\beta\\, \\ell(\\beta\\, ;\\, \\mathbf{x}, \\mathbf{y}) &= \\mathbf{x}^\\top \\,\\text{diag}\\left(\\frac{ {\\textbf{Mean}_T}'(\\mathbf{x} \\beta) }{ {\\textbf{Var}_T}(\\mathbf{x} \\beta) }\\right) \\left(\\mathbf{T}(\\mathbf{y}) - {\\textbf{Mean}_T}(\\mathbf{x} \\beta)\\right) \\\\ \\mathbb{E}_{Y_i \\sim \\text{GLM} | x_i} \\left[ \\nabla_\\beta^2\\, \\ell(\\beta\\, ;\\, \\mathbf{x}, \\mathbf{Y}) \\right] &= -\\mathbf{x}^\\top \\,\\text{diag}\\left( \\frac{ \\phi\\, {\\textbf{Mean}_T}'(\\mathbf{x} \\beta)^2 }{ {\\textbf{Var}_T}(\\mathbf{x} \\beta) }\\right)\\, \\mathbf{x} \\end{align*} $$\n", "\n", "where $\\mathbf{x}$ is the matrix whose $i$th row is the predictor vector for the $i$th data sample, and $\\mathbf{y}$ is the vector whose $i$th coordinate is the observed response for the $i$th data sample. Here (loosely speaking), ${\\text{Mean}_T}(\\eta) := \\mathbb{E}[T(Y)\\,|\\,\\eta]$ and ${\\text{Var}_T}(\\eta) := \\text{Var}[T(Y)\\,|\\,\\eta]$, and boldface denotes vectorization of these functions. Full details of which distributions these expectations and variances are taken over can be found in \"Derivation of GLM Facts\" below." ] },
{ "cell_type": "markdown", "metadata": { "id": "XuNDwfwBObKl" }, "source": [ "# An Example\n", "\n", "In this section we briefly describe and showcase two built-in GLM fitting algorithms in TensorFlow Probability: Fisher scoring (`tfp.glm.fit`) and coordinatewise proximal gradient descent (`tfp.glm.fit_sparse`)." ] },
{ "cell_type": "markdown", "metadata": { "id": "4phryMfsP4Sn" }, "source": [ "## Synthetic Data Set\n", "\n", "Let's pretend to load some training data set." ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "id": "DA2Rf9PPgMAD" }, "outputs": [], "source": [ "import numpy as np\n", "import pandas as pd\n", "import scipy\n", "import tensorflow.compat.v2 as tf\n", "tf.enable_v2_behavior()\n", "import tensorflow_probability as tfp\n", "tfd = tfp.distributions" ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "id": "KEVnTz2hh9RN" }, "outputs": [], "source": [ "def make_dataset(n, d, link, scale=1., dtype=np.float32):\n", "  model_coefficients = tfd.Uniform(\n", "      low=-1., high=np.array(1, dtype)).sample(d, seed=42)\n", "  radius = np.sqrt(2.)\n", "  model_coefficients *= radius / tf.linalg.norm(model_coefficients)\n", "  mask = tf.random.shuffle(tf.range(d)) < int(0.5 * d)\n", "  model_coefficients = tf.where(\n", "      mask, model_coefficients, np.array(0., dtype))\n", "  model_matrix = tfd.Normal(\n", "      loc=0., scale=np.array(1, dtype)).sample([n, d], seed=43)\n", "  scale = tf.convert_to_tensor(scale, dtype)\n", "  linear_response = tf.linalg.matvec(model_matrix, model_coefficients)\n", "\n", "  if link == 'linear':\n", "    response = tfd.Normal(loc=linear_response, scale=scale).sample(seed=44)\n", "  elif link == 'probit':\n", "    response = tf.cast(\n", "        tfd.Normal(loc=linear_response, scale=scale).sample(seed=44) > 0,\n", "        dtype)\n", "  elif link == 'logit':\n", "    response = tfd.Bernoulli(logits=linear_response).sample(seed=44)\n", "  else:\n", "    raise ValueError('unrecognized true link: {}'.format(link))\n", "  return model_matrix, response, model_coefficients, mask" ] },
{ "cell_type": "markdown", "metadata": { "id": "99Fk5XZKbvi4" }, "source": [ "### Note: Connect to a local runtime.\n", "\n", "In this notebook, we share data between the Python and R kernels using local files. To enable this sharing, please use runtimes on the same machine, where you have permission to read and write local files." ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "id": "2EAQjTrZJqKx" }, "outputs": [], "source": [ "x, y, model_coefficients_true, _ = [t.numpy() for t in make_dataset(\n", "    n=int(1e5), d=100, link='probit')]\n", "\n", "DATA_DIR = '/tmp/glm_example'\n", "tf.io.gfile.makedirs(DATA_DIR)\n", "with tf.io.gfile.GFile('{}/x.csv'.format(DATA_DIR), 'w') as f:\n", "  np.savetxt(f, x, delimiter=',')\n", "with tf.io.gfile.GFile('{}/y.csv'.format(DATA_DIR), 'w') as f:\n", "  np.savetxt(f, y.astype(np.int32) + 1, delimiter=',', fmt='%d')\n", "with tf.io.gfile.GFile(\n", "    '{}/model_coefficients_true.csv'.format(DATA_DIR), 'w') as f:\n", "  np.savetxt(f, model_coefficients_true, delimiter=',')" ] },
{ "cell_type": "markdown", "metadata": { "id": "0P5I-aJdN6GZ" }, "source": [ "## Without L1 Regularization" ] },
{ "cell_type": "markdown", "metadata": { "id": "VN6HfiH3bAb0" }, "source": [ "The function `tfp.glm.fit` implements Fisher scoring. Its arguments include:\n", "\n", "- `model_matrix` = $\\mathbf{x}$\n", "- `response` = $\\mathbf{y}$\n", "- `model` = a callable which, given the argument $\\boldsymbol{\\eta}$, returns the triple $\\left( {\\textbf{Mean}_T}(\\boldsymbol{\\eta}), {\\textbf{Var}_T}(\\boldsymbol{\\eta}), {\\textbf{Mean}_T}'(\\boldsymbol{\\eta}) \\right)$.\n", "\n", "We recommend that `model` be an instance of the `tfp.glm.ExponentialFamily` class. Several pre-made implementations are available, so for most common GLMs no custom code is necessary." ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "id": "iXkxVBSmesjn" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "is_converged: True\n", " num_iter: 6\n", " accuracy: 0.75241\n", " deviance: -0.992436110973\n", "||w0-w1||_2 / (1+||w0||_2): 0.0231555201462\n" ] } ], "source": [ "@tf.function(autograph=False)\n", "def fit_model():\n", "  model_coefficients, linear_response, is_converged, num_iter = tfp.glm.fit(\n", "      model_matrix=x, response=y, model=tfp.glm.BernoulliNormalCDF())\n", "  log_likelihood = tfp.glm.BernoulliNormalCDF().log_prob(y, linear_response)\n", "  return (model_coefficients, linear_response, is_converged, num_iter,\n", "          log_likelihood)\n", "\n", "[model_coefficients, linear_response, is_converged, num_iter,\n", " log_likelihood] = [t.numpy() for t in fit_model()]\n", "\n", "print(('is_converged: {}\\n'\n", "       ' num_iter: {}\\n'\n", "       ' accuracy: {}\\n'\n", "       ' deviance: {}\\n'\n", "       '||w0-w1||_2 / (1+||w0||_2): {}'\n", "      ).format(\n", "    is_converged,\n", "    num_iter,\n", "    np.mean((linear_response > 0.) == y),\n", "    2. * np.mean(log_likelihood),\n", "    np.linalg.norm(model_coefficients_true - model_coefficients, ord=2) /\n", "        (1. + np.linalg.norm(model_coefficients_true, ord=2))\n", "    ))" ] },
{ "cell_type": "markdown", "metadata": { "id": "h6qexoHAJzEF" }, "source": [ "### Mathematical Details\n", "\n", "Fisher scoring is a modification of Newton's method for finding the maximum-likelihood estimate\n", "\n", "$$ \\hat\\beta := \\underset{\\beta}{\\text{arg max}}\\ \\ \\ell(\\beta\\ ;\\ \\mathbf{x}, \\mathbf{y}).
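$$\n", "\n", "Vanilla Newton's method, searching for zeros of the gradient of the log-likelihood, would follow the update rule\n", "\n", "$$ \\beta^{(t+1)}_{\\text{Newton}} := \\beta^{(t)} - \\alpha \\left( \\nabla^2_\\beta\\, \\ell(\\beta\\ ;\\ \\mathbf{x}, \\mathbf{y}) \\right)_{\\beta = \\beta^{(t)}}^{-1} \\left( \\nabla_\\beta\\, \\ell(\\beta\\ ;\\ \\mathbf{x}, \\mathbf{y}) \\right)_{\\beta = \\beta^{(t)}} $$\n", "\n", "where $\\alpha \\in (0, 1]$ is a learning rate used to control the step size.\n", "\n", "In Fisher scoring, we replace the Hessian with the negative Fisher information matrix:\n", "\n", "$$ \\begin{align*} \\beta^{(t+1)} &:= \\beta^{(t)} - \\alpha\\, \\mathbb{E}_{ Y_i \\sim p_{\\text{OEF}(m, T)}(\\cdot | \\theta = h(x_i^\\top \\beta^{(t)}), \\phi) } \\left[ \\left( \\nabla^2_\\beta\\, \\ell(\\beta\\ ;\\ \\mathbf{x}, \\mathbf{Y}) \\right)_{\\beta = \\beta^{(t)}} \\right]^{-1} \\left( \\nabla_\\beta\\, \\ell(\\beta\\ ;\\ \\mathbf{x}, \\mathbf{y}) \\right)_{\\beta = \\beta^{(t)}} \\end{align*} $$\n", "\n", "[Note that here $\\mathbf{Y} = (Y_i)_{i=1}^{n}$ is random, whereas $\\mathbf{y}$ is still the vector of observed responses.]\n", "\n", "By the formulas in \"Fitting GLM Parameters To Data\" below, this simplifies to\n", "\n", "$$ \\begin{align*} \\beta^{(t+1)} &= \\beta^{(t)} + \\alpha \\left( \\mathbf{x}^\\top \\text{diag}\\left( \\frac{ \\phi\\, {\\textbf{Mean}_T}'(\\mathbf{x} \\beta^{(t)})^2 }{ {\\textbf{Var}_T}(\\mathbf{x} \\beta^{(t)}) }\\right)\\, \\mathbf{x} \\right)^{-1} \\left( \\mathbf{x}^\\top \\text{diag}\\left(\\frac{ {\\textbf{Mean}_T}'(\\mathbf{x} \\beta^{(t)}) }{ {\\textbf{Var}_T}(\\mathbf{x} \\beta^{(t)}) }\\right) \\left(\\mathbf{T}(\\mathbf{y}) - {\\textbf{Mean}_T}(\\mathbf{x} \\beta^{(t)})\\right) \\right). \\end{align*} $$" ] }, { "cell_type": "markdown", "metadata": { "id": "076quM7tN8_1" }, "source": [ "## With L1 Regularization" ] }, { "cell_type": "markdown", "metadata": { "id": "fnP3jeZOk7Y5" }, "source": [ "`tfp.glm.fit_sparse` implements a GLM fitter better suited to sparse data sets, based on the algorithm in [Yuan, Ho and Lin 2012](#1). Its features include:\n", "\n", "- L1 regularization\n", "- No matrix inversions\n", "- Few evaluations of the gradient and Hessian.\n", "\n", "We first present an example usage of the code. Details of the algorithm are elaborated further in \"Algorithm Details for `tfp.glm.fit_sparse`\" below." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "v_Oky1X4ijfv" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "is_converged: True\n", " num_iter: 1\n", "\n", "Coefficients:\n" ] }, { "data": { "text/html": [ "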
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " 
\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
\n", "

100 rows × 2 columns

\n", "
" ], "text/plain": [ " Learned True\n", "0 0.216240 0.220758\n", "1 0.000000 0.000000\n", "2 0.000000 0.000000\n", "3 0.000000 0.000000\n", "4 0.000000 0.000000\n", "5 0.043702 0.063950\n", "6 -0.145379 -0.153256\n", "7 0.000000 0.000000\n", "8 0.000000 0.000000\n", "9 0.000000 0.000000\n", "10 0.000000 0.000000\n", "11 0.000000 0.000000\n", "12 0.000000 0.000000\n", "13 0.024382 0.046572\n", "14 -0.242985 -0.242609\n", "15 -0.106168 -0.123367\n", "16 0.000000 0.000000\n", "17 -0.039745 -0.067560\n", "18 -0.217717 -0.222169\n", "19 0.000000 0.000000\n", "20 0.000000 0.000000\n", "21 -0.016553 -0.041692\n", "22 0.018959 0.049624\n", "23 -0.057686 -0.078299\n", "24 0.003642 0.035682\n", "25 0.000000 0.000000\n", "26 0.000000 0.000000\n", "27 -0.234406 -0.240482\n", "28 0.000000 0.000000\n", "29 0.232209 0.225448\n", ".. ... ...\n", "70 0.000000 0.000000\n", "71 0.130166 0.144485\n", "72 0.000000 0.000000\n", "73 0.000000 0.000000\n", "74 0.000000 0.000000\n", "75 -0.178534 -0.186722\n", "76 0.000000 0.000000\n", "77 0.218493 0.229656\n", "78 0.000000 0.000000\n", "79 0.000000 0.000000\n", "80 0.195579 0.200442\n", "81 0.000000 0.000000\n", "82 0.000000 0.000000\n", "83 0.031153 0.050457\n", "84 0.229065 0.231451\n", "85 -0.006512 -0.039516\n", "86 -0.107947 -0.119896\n", "87 0.000000 0.000000\n", "88 0.149419 0.171693\n", "89 0.000000 0.000000\n", "90 0.047955 0.063434\n", "91 0.000000 0.003592\n", "92 -0.083171 -0.107145\n", "93 0.084615 0.101221\n", "94 -0.168431 -0.175473\n", "95 0.138411 0.152623\n", "96 0.000000 0.000000\n", "97 0.061161 0.081945\n", "98 -0.083348 -0.104929\n", "99 -0.141154 -0.153871\n", "\n", "[100 rows x 2 columns]" ] }, "execution_count": 0, "metadata": { "tags": [] }, "output_type": "execute_result" } ], "source": [ "model = tfp.glm.Bernoulli()\n", "model_coefficients_start = tf.zeros(x.shape[-1], np.float32)\n", "@tf.function(autograph=False)\n", "def fit_model():\n", " return tfp.glm.fit_sparse(\n", " 
model_matrix=tf.convert_to_tensor(x),\n", " response=tf.convert_to_tensor(y),\n", " model=model,\n", " model_coefficients_start=model_coefficients_start,\n", " l1_regularizer=800.,\n", " l2_regularizer=None,\n", " maximum_iterations=10,\n", " maximum_full_sweeps_per_iteration=10,\n", " tolerance=1e-6,\n", " learning_rate=None)\n", "\n", "model_coefficients, is_converged, num_iter = [t.numpy() for t in fit_model()]\n", "coefs_comparison = pd.DataFrame({\n", " 'Learned': model_coefficients,\n", " 'True': model_coefficients_true,\n", "})\n", " \n", "print(('is_converged: {}\\n'\n", " ' num_iter: {}\\n\\n'\n", " 'Coefficients:').format(\n", " is_converged,\n", " num_iter))\n", "coefs_comparison" ] }, { "cell_type": "markdown", "metadata": { "id": "DrJC2J1YbR5L" }, "source": [ "Note that the learned coefficients have the same sparsity pattern as the true coefficients." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "hQ7SzrPZMpke" }, "outputs": [], "source": [ "# Save the learned coefficients to a file.\n", "with tf.io.gfile.GFile('{}/model_coefficients_prox.csv'.format(DATA_DIR), 'w') as f:\n", " np.savetxt(f, model_coefficients, delimiter=',')" ] }, { "cell_type": "markdown", "metadata": { "id": "VW9NgB1Zisqh" }, "source": [ "### Compare to R's `glmnet`\n", "\n", "We compare the output of coordinatewise proximal gradient descent to that of R's `glmnet`, which uses a similar algorithm." ] }, { "cell_type": "markdown", "metadata": { "id": "Aptz7SWwkd5v" }, "source": [ "#### NOTE: To execute this section, you must switch to an R colab runtime." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "RS1H3n53h9qc" }, "outputs": [], "source": [ "suppressMessages({\n", " library('glmnet')\n", "})" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "2X6zKSaxie7I" }, "outputs": [], "source": [ "data_dir <- '/tmp/glm_example'\n", "x <- as.matrix(read.csv(paste(data_dir, '/x.csv', sep=''),\n", " header=FALSE))\n", "y <- as.matrix(read.csv(paste(data_dir, '/y.csv', sep=''),\n", " header=FALSE, colClasses='integer'))" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": 
"Eb31LbhRjsSz" }, "outputs": [], "source": [ "fit <- glmnet(\n", "x = x,\n", "y = y,\n", "family = \"binomial\", # Logistic regression\n", "alpha = 1, # corresponds to l1_weight = 1, l2_weight = 0\n", "standardize = FALSE,\n", "intercept = FALSE,\n", "thresh = 1e-30,\n", "type.logistic = \"Newton\"\n", ")" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "HTN4RKQbhlCm" }, "outputs": [], "source": [ "write.csv(as.matrix(coef(fit, 0.008)),\n", " paste(data_dir, '/model_coefficients_glmnet.csv', sep=''),\n", " row.names=FALSE)" ] }, { "cell_type": "markdown", "metadata": { "id": "vsrEKgUGjGjf" }, "source": [ "#### Compare R, TFP and true coefficients (Note: back to the Python kernel)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "lCOlGo_4i2sb" }, "outputs": [], "source": [ "DATA_DIR = '/tmp/glm_example'\n", "with tf.io.gfile.GFile('{}/model_coefficients_glmnet.csv'.format(DATA_DIR),\n", " 'r') as f:\n", " model_coefficients_glmnet = np.loadtxt(f,\n", " skiprows=2 # Skip column name and intercept\n", " )\n", "\n", "with tf.io.gfile.GFile('{}/model_coefficients_prox.csv'.format(DATA_DIR),\n", " 'r') as f:\n", " model_coefficients_prox = np.loadtxt(f)\n", "\n", "with tf.io.gfile.GFile(\n", " '{}/model_coefficients_true.csv'.format(DATA_DIR), 'r') as f:\n", " model_coefficients_true = np.loadtxt(f)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "4l-SZ85lnKg5" }, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " 
\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
\n", "

100 rows × 3 columns

\n", "
" ], "text/plain": [ " R TFP True\n", "0 0.281080 0.216240 0.220758\n", "1 0.000000 0.000000 0.000000\n", "2 0.000000 0.000000 0.000000\n", "3 0.000000 0.000000 0.000000\n", "4 0.000000 0.000000 0.000000\n", "5 0.056625 0.043702 0.063950\n", "6 -0.188771 -0.145379 -0.153256\n", "7 0.000000 0.000000 0.000000\n", "8 0.000000 0.000000 0.000000\n", "9 0.000000 0.000000 0.000000\n", "10 0.000000 0.000000 0.000000\n", "11 0.000000 0.000000 0.000000\n", "12 0.000000 0.000000 0.000000\n", "13 0.030112 0.024382 0.046572\n", "14 -0.316488 -0.242985 -0.242609\n", "15 -0.139214 -0.106168 -0.123367\n", "16 0.000000 0.000000 0.000000\n", "17 -0.050239 -0.039745 -0.067560\n", "18 -0.283372 -0.217717 -0.222169\n", "19 0.000000 0.000000 0.000000\n", "20 0.000000 0.000000 0.000000\n", "21 -0.021815 -0.016553 -0.041692\n", "22 0.024070 0.018959 0.049624\n", "23 -0.074039 -0.057686 -0.078299\n", "24 0.005321 0.003642 0.035682\n", "25 0.000000 0.000000 0.000000\n", "26 0.000000 0.000000 0.000000\n", "27 -0.304958 -0.234406 -0.240482\n", "28 0.000000 0.000000 0.000000\n", "29 0.301562 0.232209 0.225448\n", ".. ... ... 
...\n", "70 0.000000 0.000000 0.000000\n", "71 0.169291 0.130166 0.144485\n", "72 0.000000 0.000000 0.000000\n", "73 0.000000 0.000000 0.000000\n", "74 0.000000 0.000000 0.000000\n", "75 -0.231294 -0.178534 -0.186722\n", "76 0.000000 0.000000 0.000000\n", "77 0.284215 0.218493 0.229656\n", "78 0.000000 0.000000 0.000000\n", "79 0.000000 0.000000 0.000000\n", "80 0.254524 0.195579 0.200442\n", "81 0.000000 0.000000 0.000000\n", "82 0.000000 0.000000 0.000000\n", "83 0.040716 0.031153 0.050457\n", "84 0.297475 0.229065 0.231451\n", "85 -0.008569 -0.006512 -0.039516\n", "86 -0.141028 -0.107947 -0.119896\n", "87 0.000000 0.000000 0.000000\n", "88 0.194130 0.149419 0.171693\n", "89 0.000000 0.000000 0.000000\n", "90 0.062601 0.047955 0.063434\n", "91 0.000000 0.000000 0.003592\n", "92 -0.107693 -0.083171 -0.107145\n", "93 0.109381 0.084615 0.101221\n", "94 -0.218831 -0.168431 -0.175473\n", "95 0.180662 0.138411 0.152623\n", "96 0.000000 0.000000 0.000000\n", "97 0.078815 0.061161 0.081945\n", "98 -0.108332 -0.083348 -0.104929\n", "99 -0.183284 -0.141154 -0.153871\n", "\n", "[100 rows x 3 columns]" ] }, "execution_count": 0, "metadata": { "tags": [] }, "output_type": "execute_result" } ], "source": [ "coefs_comparison = pd.DataFrame({\n", " 'TFP': model_coefficients_prox,\n", " 'R': model_coefficients_glmnet,\n", " 'True': model_coefficients_true,\n", "})\n", "coefs_comparison" ] }, { "cell_type": "markdown", "metadata": { "id": "Rfv0GVXqY74Y" }, "source": [ "# Algorithm Details for `tfp.glm.fit_sparse`\n", "\n", "We present the algorithm as a sequence of three modifications to Newton's method. In each one, the update rule for $\\beta$ is based on a vector $s$ and a matrix $H$ which approximate the gradient and Hessian of the log-likelihood. In step $t$, we choose a coordinate $j^{(t)}$ to change, and we update $\\beta$ according to the update rule:\n", "\n", "$$ \\begin{align*} u^{(t)} &:= \\frac{ \\left( s^{(t)} \\right)_{j^{(t)}} }{ \\left( H^{(t)} \\right)_{j^{(t)},\\, j^{(t)}} } \\\\[3mm] \\beta^{(t+1)} &:= \\beta^{(t)} - \\alpha\\, u^{(t)} \\,\\text{onehot}(j^{(t)}) \\end{align*} $$\n", "\n", "This update is a Newton-like step with learning rate $\\alpha$. Except for the final piece (L1 regularization), the modifications below differ only in how they update $s$ and $H$." ] }, { "cell_type": "markdown", "metadata": { "id": "fH7C1xBWUV7_" }, 
"source": [ "## Starting point: Coordinatewise Newton's method\n", "\n", "In coordinatewise Newton's method, we set $s$ and $H$ to the true gradient and Hessian of the log-likelihood:\n", "\n", "$$ \\begin{align*} s^{(t)}_{\\text{vanilla}} &:= \\left( \\nabla_\\beta\\, \\ell(\\beta \\,;\\, \\mathbf{x}, \\mathbf{y}) \\right)_{\\beta = \\beta^{(t)}} \\\\ H^{(t)}_{\\text{vanilla}} &:= \\left( \\nabla^2_\\beta\\, \\ell(\\beta \\,;\\, \\mathbf{x}, \\mathbf{y}) \\right)_{\\beta = \\beta^{(t)}} \\end{align*} $$" ] }, { "cell_type": "markdown", "metadata": { "id": "6rJZD6iyUl0v" }, "source": [ "## Fewer evaluations of the gradient and Hessian\n", "\n", "The gradient and Hessian of the log-likelihood are often expensive to compute, so it is often worthwhile to approximate them. We can do so as follows:\n", "\n", "- Usually, approximate the Hessian as locally constant and approximate the gradient to first order using the (approximate) Hessian:\n", "\n", "$$ \\begin{align*} H_{\\text{approx}}^{(t+1)} &:= H^{(t)} \\\\ s_{\\text{approx}}^{(t+1)} &:= s^{(t)} + H^{(t)} \\left( \\beta^{(t+1)} - \\beta^{(t)} \\right) \\end{align*} $$\n", "\n", "- Occasionally, perform a \"vanilla\" update step as above, setting $s^{(t+1)}$ to the exact gradient and $H^{(t+1)}$ to the exact Hessian of the log-likelihood, evaluated at $\\beta^{(t+1)}$." ] }, { "cell_type": "markdown", "metadata": { "id": "rfvvyaVnUqIQ" }, "source": [ "## Substitute negative Fisher information for Hessian\n", "\n", "To further reduce the cost of the vanilla update steps, we can set $H$ to the negative Fisher information matrix (efficiently computable using the formulas in \"Fitting GLM Parameters To Data\" below) rather than the exact Hessian:\n", "\n", "$$ \\begin{align*} H_{\\text{Fisher}}^{(t+1)} &:= \\mathbb{E}_{Y_i \\sim p_{\\text{OEF}(m, T)}(\\cdot | \\theta = h(x_i^\\top \\beta^{(t+1)}), \\phi)} \\left[ \\left( \\nabla_\\beta^2\\, \\ell(\\beta\\, ;\\, \\mathbf{x}, \\mathbf{Y}) \\right)_{\\beta = \\beta^{(t+1)}} \\right] \\\\ &= -\\mathbf{x}^\\top \\,\\text{diag}\\left( \\frac{ \\phi\\, {\\textbf{Mean}_T}'(\\mathbf{x} \\beta^{(t+1)})^2 }{ {\\textbf{Var}_T}(\\mathbf{x} \\beta^{(t+1)}) }\\right)\\, \\mathbf{x} \\\\ s_{\\text{Fisher}}^{(t+1)} &:= s_{\\text{vanilla}}^{(t+1)} \\\\ &= \\left( \\mathbf{x}^\\top \\,\\text{diag}\\left(\\frac{ {\\textbf{Mean}_T}'(\\mathbf{x} \\beta^{(t+1)}) }{ {\\textbf{Var}_T}(\\mathbf{x} \\beta^{(t+1)}) }\\right) \\left(\\mathbf{T}(\\mathbf{y}) - {\\textbf{Mean}_T}(\\mathbf{x} \\beta^{(t+1)})\\right) \\right) \\end{align*} $$" ] }, { "cell_type": "markdown", "metadata": { "id": 
"DTH07xYpWGcR" }, "source": [ "## L1 Regularization via Proximal Gradient Descent\n", "\n", "To incorporate L1 regularization, we replace the update rule\n", "\n", "$$ \\beta^{(t+1)} := \\beta^{(t)} - \\alpha\\, u^{(t)} \\,\\text{onehot}(j^{(t)}) $$\n", "\n", "with the more general update rule\n", "\n", "$$ \\begin{align*} \\gamma^{(t)} &:= -\\frac{\\alpha\\, r_{\\text{L1}}}{\\left(H^{(t)}\\right)_{j^{(t)},\\, j^{(t)}}} \\\\[2mm] \\left(\\beta_{\\text{reg}}^{(t+1)}\\right)_j &:= \\begin{cases} \\beta^{(t+1)}_j &\\text{if } j \\neq j^{(t)} \\\\ \\text{SoftThreshold} \\left( \\beta^{(t)}_j - \\alpha\\, u^{(t)} ,\\ \\gamma^{(t)} \\right) &\\text{if } j = j^{(t)} \\end{cases} \\end{align*} $$\n", "\n", "where $r_{\\text{L1}} > 0$ is a supplied constant (the L1 regularization coefficient) and $\\text{SoftThreshold}$ is the soft thresholding operator, defined by\n", "\n", "$$ \\text{SoftThreshold}(\\beta, \\gamma) := \\begin{cases} \\beta + \\gamma &\\text{if } \\beta < -\\gamma \\\\ 0 &\\text{if } -\\gamma \\leq \\beta \\leq \\gamma \\\\ \\beta - \\gamma &\\text{if } \\beta > \\gamma. \\end{cases} $$\n", "\n", "This update rule has the following two desirable properties, which we explain below:\n", "\n", "1. In the limiting case $r_{\\text{L1}} \\to 0$ (i.e., no L1 regularization), this update rule is identical to the original update rule.\n", "\n", "2. This update rule can be interpreted as applying a proximity operator whose fixed point is the solution to the L1-regularized minimization problem\n", "\n", "$$ \\underset{\\beta - \\beta^{(t)} \\in \\text{span}\\{ \\text{onehot}(j^{(t)}) \\}}{\\text{arg min}} \\left( -\\ell(\\beta \\,;\\, \\mathbf{x}, \\mathbf{y}) + r_{\\text{L1}} \\left\\lVert \\beta \\right\\rVert_1 \\right). $$" ] }, { "cell_type": "markdown", "metadata": { "id": "CSs7_osNPLVt" }, "source": [ "### Degenerate case $r_{\\text{L1}} = 0$ recovers the original update rule\n", "\n", "To see (1), note that if $r_{\\text{L1}} = 0$ then $\\gamma^{(t)} = 0$, hence\n", "\n", "$$ \\begin{align*} \\left(\\beta_{\\text{reg}}^{(t+1)}\\right)_{j^{(t)}} &= \\text{SoftThreshold} \\left( \\beta^{(t)}_{j^{(t)}} - \\alpha\\, u^{(t)} ,\\ 0 \\right) \\\\ &= \\beta^{(t)}_{j^{(t)}} - \\alpha\\, u^{(t)}. \\end{align*} $$\n", "\n", "Hence\n", "\n", "$$ \\begin{align*} \\beta_{\\text{reg}}^{(t+1)} &= \\beta^{(t)} - \\alpha\\, u^{(t)} \\,\\text{onehot}(j^{(t)}) \\\\ &= \\beta^{(t+1)}. 
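\\end{align*} $$" ] }, { "cell_type": "markdown", "metadata": { "id": "EiHy_0NIPT5f" }, "source": [ "### Proximity operator whose fixed point is the regularized MLE\n", "\n", "To see (2), first note (see [Wikipedia](#3)) that for any $\\gamma > 0$, the update rule\n", "\n", "$$ \\left(\\beta_{\\text{exact-prox}, \\gamma}^{(t+1)}\\right)_{j^{(t)}} := \\text{prox}_{\\gamma \\lVert \\cdot \\rVert_1} \\left( \\beta^{(t)}_{j^{(t)}} + \\frac{\\gamma}{r_{\\text{L1}}} \\left( \\left( \\nabla_\\beta\\, \\ell(\\beta \\,;\\, \\mathbf{x}, \\mathbf{y}) \\right)_{\\beta = \\beta^{(t)}} \\right)_{j^{(t)}} \\right) $$\n", "\n", "satisfies (2), where $\\text{prox}$ is the proximity operator (see [Yu](#4), where this operator is denoted $\\mathsf{P}$). The right-hand side of the above equation is computed [here](#2):\n", "\n", "$$ \\left(\\beta_{\\text{exact-prox}, \\gamma}^{(t+1)}\\right)_{j^{(t)}} = \\text{SoftThreshold} \\left( \\beta^{(t)}_{j^{(t)}} + \\frac{\\gamma}{r_{\\text{L1}}} \\left( \\left( \\nabla_\\beta\\, \\ell(\\beta \\,;\\, \\mathbf{x}, \\mathbf{y}) \\right)_{\\beta = \\beta^{(t)}} \\right)_{j^{(t)}} ,\\ \\gamma \\right). $$\n", "\n", "In particular, setting $\\gamma = \\gamma^{(t)} = -\\frac{\\alpha\\, r_{\\text{L1}}}{\\left(H^{(t)}\\right)_{j^{(t)}, j^{(t)}}}$ (note that $\\gamma^{(t)} > 0$ as long as the negative log-likelihood is convex), we obtain the update rule\n", "\n", "$$ \\left(\\beta_{\\text{exact-prox}, \\gamma^{(t)}}^{(t+1)}\\right)_{j^{(t)}} = \\text{SoftThreshold} \\left( \\beta^{(t)}_{j^{(t)}} - \\alpha \\frac{ \\left( \\left( \\nabla_\\beta\\, \\ell(\\beta \\,;\\, \\mathbf{x}, \\mathbf{y}) \\right)_{\\beta = \\beta^{(t)}} \\right)_{j^{(t)}} }{ \\left(H^{(t)}\\right)_{j^{(t)}, j^{(t)}} } ,\\ \\gamma^{(t)} \\right). $$\n", "\n", "We then replace the exact gradient $\\left( \\nabla_\\beta\\, \\ell(\\beta \\,;\\, \\mathbf{x}, \\mathbf{y}) \\right)_{\\beta = \\beta^{(t)}}$ with its approximation $s^{(t)}$, obtaining\n", "\n", "$$ \\begin{align*} \\left(\\beta_{\\text{exact-prox}, \\gamma^{(t)}}^{(t+1)}\\right)_{j^{(t)}} &\\approx \\text{SoftThreshold} \\left( \\beta^{(t)}_{j^{(t)}} - \\alpha \\frac{ \\left(s^{(t)}\\right)_{j^{(t)}} }{ \\left(H^{(t)}\\right)_{j^{(t)}, j^{(t)}} } ,\\ \\gamma^{(t)} \\right) \\\\ &= \\text{SoftThreshold} \\left( \\beta^{(t)}_{j^{(t)}} - \\alpha\\, u^{(t)} ,\\ \\gamma^{(t)} \\right). \\end{align*} $$\n", "\n", "Hence\n", "\n", "$$ \\beta_{\\text{exact-prox}, \\gamma^{(t)}}^{(t+1)} \\approx \\beta_{\\text{reg}}^{(t+1)}. 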
$$" ] }, { "cell_type": "markdown", "metadata": { "id": "P7YOOrmI8j0L" }, "source": [ "# Derivation of GLM Facts\n", "\n", "In this section we state in full detail and derive the results about GLMs that are used in the preceding sections. Then, we use TensorFlow's `gradients` to numerically verify the derived formulas for the gradient of the log-likelihood and for the Fisher information." ] }, { "cell_type": "markdown", "metadata": { "id": "lkHZyhuAIW-p" }, "source": [ "## Score and Fisher information" ] }, { "cell_type": "markdown", "metadata": { "id": "bbyYy0bE8pOK" }, "source": [ "Consider a family of probability distributions parameterized by parameter vector $\\theta$, having probability densities $\\left\\{p(\\cdot | \\theta)\\right\\}_{\\theta \\in \\mathcal{T}}$. The **score** of an outcome $y$ at parameter vector $\\theta_0$ is defined to be the gradient of the log likelihood of $y$ (evaluated at $\\theta_0$), that is,\n", "\n", "$$ \\text{score}(y, \\theta_0) := \\left[\\nabla_\\theta\\, \\log p(y | \\theta)\\right]_{\\theta=\\theta_0}. $$" ] }, { "cell_type": "markdown", "metadata": { "id": "IYGaMPIx8uOc" }, "source": [ "### Claim: Expectation of the score is zero\n", "\n", "Under mild regularity conditions (permitting us to pass differentiation under the integral),\n", "\n", "$$ \\mathbb{E}_{Y \\sim p(\\cdot | \\theta=\\theta_0)}\\left[\\text{score}(Y, \\theta_0)\\right] = 0. $$" ] }, { "cell_type": "markdown", "metadata": { "id": "b3H-wNmJ800R" }, "source": [ "#### Proof\n", "\n", "We have\n", "\n", "$$ \\begin{align*} \\mathbb{E}_{Y \\sim p(\\cdot | \\theta=\\theta_0)}\\left[\\text{score}(Y, \\theta_0)\\right] &:=\\mathbb{E}_{Y \\sim p(\\cdot | \\theta=\\theta_0)}\\left[\\left(\\nabla_\\theta \\log p(Y|\\theta)\\right)_{\\theta=\\theta_0}\\right] \\\\ &\\stackrel{\\text{(1)}}{=} \\mathbb{E}_{Y \\sim p(\\cdot | \\theta=\\theta_0)}\\left[\\frac{\\left(\\nabla_\\theta p(Y|\\theta)\\right)_{\\theta=\\theta_0}}{p(Y|\\theta=\\theta_0)}\\right] \\\\ &\\stackrel{\\text{(2)}}{=} \\int_{\\mathcal{Y}} \\left[\\frac{\\left(\\nabla_\\theta p(y|\\theta)\\right)_{\\theta=\\theta_0}}{p(y|\\theta=\\theta_0)}\\right] p(y | \\theta=\\theta_0)\\, dy \\\\ &= \\int_{\\mathcal{Y}} \\left(\\nabla_\\theta p(y|\\theta)\\right)_{\\theta=\\theta_0}\\, dy \\\\ &\\stackrel{\\text{(3)}}{=} \\left[\\nabla_\\theta \\left(\\int_{\\mathcal{Y}} p(y|\\theta)\\, dy\\right) \\right]_{\\theta=\\theta_0} \\\\ &\\stackrel{\\text{(4)}}{=} \\left[\\nabla_\\theta\\, 1 
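\\right]_{\\theta=\\theta_0} \\\\ &= 0, \\end{align*} $$\n", "\n", "where we have used: (1) the chain rule for differentiation, (2) the definition of expectation, (3) passing differentiation under the integral sign (using the regularity conditions), (4) the integral of a probability density is 1." ] }, { "cell_type": "markdown", "metadata": { "id": "1Y1DPVOI9OT2" }, "source": [ "### Claim (Fisher information): Variance of the score equals negative expected Hessian of the log likelihood\n", "\n", "Under mild regularity conditions (permitting us to pass differentiation under the integral),\n", "\n", "$$ \\mathbb{E}_{Y \\sim p(\\cdot | \\theta=\\theta_0)}\\left[ \\text{score}(Y, \\theta_0) \\text{score}(Y, \\theta_0)^\\top \\right] = -\\mathbb{E}_{Y \\sim p(\\cdot | \\theta=\\theta_0)}\\left[ \\left(\\nabla_\\theta^2 \\log p(Y | \\theta)\\right)_{\\theta=\\theta_0} \\right] $$\n", "\n", "where $\\nabla_\\theta^2 F$ denotes the Hessian matrix, whose $(i, j)$ entry is $\\frac{\\partial^2 F}{\\partial \\theta_i \\partial \\theta_j}$.\n", "\n", "The left-hand side of this equation is called the **Fisher information** of the family $\\left\\{p(\\cdot | \\theta)\\right\\}_{\\theta \\in \\mathcal{T}}$ at parameter vector $\\theta_0$." ] }, { "cell_type": "markdown", "metadata": { "id": "KF-ac0Bk-HmR" }, "source": [ "#### Proof of claim\n", "\n", "We have\n", "\n", "$$ \\begin{align*} \\mathbb{E}_{Y \\sim p(\\cdot | \\theta=\\theta_0)}\\left[ \\left(\\nabla_\\theta^2 \\log p(Y | \\theta)\\right)_{\\theta=\\theta_0} \\right] &\\stackrel{\\text{(1)}}{=} \\mathbb{E}_{Y \\sim p(\\cdot | \\theta=\\theta_0)}\\left[ \\left(\\nabla_\\theta^\\top \\frac{ \\nabla_\\theta p(Y | \\theta) }{ p(Y|\\theta) }\\right)_{\\theta=\\theta_0} \\right] \\\\ &\\stackrel{\\text{(2)}}{=} \\mathbb{E}_{Y \\sim p(\\cdot | \\theta=\\theta_0)}\\left[ \\frac{ \\left(\\nabla^2_\\theta p(Y | \\theta)\\right)_{\\theta=\\theta_0} }{ p(Y|\\theta=\\theta_0) } - \\left(\\frac{ \\left(\\nabla_\\theta\\, p(Y|\\theta)\\right)_{\\theta=\\theta_0} }{ p(Y|\\theta=\\theta_0) }\\right) \\left(\\frac{ \\left(\\nabla_\\theta\\, p(Y|\\theta)\\right)_{\\theta=\\theta_0} }{ p(Y|\\theta=\\theta_0) }\\right)^\\top \\right] \\\\ &\\stackrel{\\text{(3)}}{=} \\mathbb{E}_{Y \\sim p(\\cdot | \\theta=\\theta_0)}\\left[ \\frac{ \\left(\\nabla^2_\\theta p(Y | \\theta)\\right)_{\\theta=\\theta_0} }{ p(Y|\\theta=\\theta_0) } - \\text{score}(Y, \\theta_0) \\,\\text{score}(Y, \\theta_0)^\\top \\right], \\end{align*} $$\n", "\n", "where we have used (1) the chain rule for differentiation, (2) the quotient rule for differentiation, (3) the chain rule again, in reverse.\n", "\n", "To complete the proof, it suffices to show that\n", "\n", "$$ \\mathbb{E}_{Y \\sim p(\\cdot | \\theta=\\theta_0)}\\left[ \\frac{ \\left(\\nabla^2_\\theta p(Y | \\theta)\\right)_{\\theta=\\theta_0} }{ p(Y|\\theta=\\theta_0) } \\right] \\stackrel{\\text{?}}{=} 0. 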
$$\n", "\n", "To do that, we pass differentiation under the integral sign twice:\n", "\n", "$$ \\begin{align*} \\mathbb{E}_{Y \\sim p(\\cdot | \\theta=\\theta_0)}\\left[ \\frac{ \\left(\\nabla^2_\\theta p(Y | \\theta)\\right)_{\\theta=\\theta_0} }{ p(Y|\\theta=\\theta_0) } \\right] &= \\int_{\\mathcal{Y}} \\left[ \\frac{ \\left(\\nabla^2_\\theta p(y | \\theta)\\right)_{\\theta=\\theta_0} }{ p(y|\\theta=\\theta_0) } \\right] \\, p(y | \\theta=\\theta_0)\\, dy \\\\ &= \\int_{\\mathcal{Y}} \\left(\\nabla^2_\\theta p(y | \\theta)\\right)_{\\theta=\\theta_0} \\, dy \\\\ &= \\left[ \\nabla_\\theta^2 \\left( \\int_{\\mathcal{Y}} p(y | \\theta) \\, dy \\right) \\right]_{\\theta=\\theta_0} \\\\ &= \\left[ \\nabla_\\theta^2 \\, 1 \\right]_{\\theta=\\theta_0} \\\\ &= 0. \\end{align*} $$" ] }, { "cell_type": "markdown", "metadata": { "id": "kAIJfX7IX_lP" }, "source": [ "### Lemma about the derivative of the log partition function\n", "\n", "If $a$, $b$ and $c$ are scalar-valued functions, $c$ twice differentiable, such that the family of distributions $\\left\\{p(\\cdot | \\theta)\\right\\}_{\\theta \\in \\mathcal{T}}$ defined by\n", "\n", "$$ p(y|\\theta) = a(y) \\exp\\left(b(y)\\, \\theta - c(\\theta)\\right) $$\n", "\n", "satisfies the mild regularity conditions that permit passing differentiation with respect to $\\theta$ under an integral with respect to $y$, then\n", "\n", "$$ \\mathbb{E}_{Y \\sim p(\\cdot | \\theta=\\theta_0)} \\left[ b(Y) \\right] = c'(\\theta_0) $$\n", "\n", "and\n", "\n", "$$ \\text{Var}_{Y \\sim p(\\cdot | \\theta=\\theta_0)} \\left[ b(Y) \\right] = c''(\\theta_0). 
$$\n", "\n", "(Here $'$ denotes differentiation, so $c'$ and $c''$ are the first and second derivatives of $c$.)" ] }, { "cell_type": "markdown", "metadata": { "id": "CYBH-KwpfWhr" }, "source": [ "#### Proof\n", "\n", "For this family of distributions, we have $\\text{score}(y, \\theta_0) = b(y) - c'(\\theta_0)$. The first equation then follows from the fact that $\\mathbb{E}_{Y \\sim p(\\cdot | \\theta=\\theta_0)} \\left[ \\text{score}(Y, \\theta_0) \\right] = 0$. Next, we have\n", "\n", "$$ \\begin{align*} \\text{Var}_{Y \\sim p(\\cdot | \\theta=\\theta_0)} \\left[ b(Y) \\right] &= \\mathbb{E}_{Y \\sim p(\\cdot | \\theta=\\theta_0)} \\left[ \\left(b(Y) - c'(\\theta_0)\\right)^2 \\right] \\\\ &= \\text{the one entry of } \\mathbb{E}_{Y \\sim p(\\cdot | \\theta=\\theta_0)} \\left[ \\text{score}(Y, \\theta_0) \\text{score}(Y, \\theta_0)^\\top \\right] \\\\ &= \\text{the one entry of } -\\mathbb{E}_{Y \\sim p(\\cdot | \\theta=\\theta_0)} \\left[ \\left(\\nabla_\\theta^2 \\log p(Y | \\theta)\\right)_{\\theta=\\theta_0} \\right] \\\\ &= -\\mathbb{E}_{Y \\sim p(\\cdot | \\theta=\\theta_0)} \\left[ -c''(\\theta_0) \\right] \\\\ &= c''(\\theta_0). 
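\\end{align*} $$" ] }, { "cell_type": "markdown", "metadata": { "id": "AYpWUvvKcX-e" }, "source": [ "## Overdispersed Exponential Family\n", "\n", "A (scalar) **overdispersed exponential family** is a family of distributions whose densities take the form\n", "\n", "$$ p_{\\text{OEF}(m, T)}(y\\, |\\, \\theta, \\phi) = m(y, \\phi) \\exp\\left(\\frac{\\theta\\, T(y) - A(\\theta)}{\\phi}\\right), $$\n", "\n", "where $m$ and $T$ are known scalar-valued functions, and $\\theta$ and $\\phi$ are scalar parameters.\n", "\n", "*[Note that $A$ is overdetermined: for any $\\phi_0$, the function $A$ is completely determined by the constraint that $\\int p_{\\text{OEF}(m, T)}(y\\ |\\ \\theta, \\phi=\\phi_0)\\, dy = 1$ for all $\\theta$. The $A$'s produced by different values of $\\phi_0$ must all be the same, which places a constraint on the functions $m$ and $T$.]*" ] }, { "cell_type": "markdown", "metadata": { "id": "IgpoijwPf7TV" }, "source": [ "### Mean and variance of the sufficient statistic\n", "\n", "Under the same conditions as \"Lemma about the derivative of the log partition function,\" we have\n", "\n", "$$ \\mathbb{E}_{Y \\sim p_{\\text{OEF}(m, T)}(\\cdot | \\theta, \\phi)} \\left[ T(Y) \\right] = A'(\\theta) $$\n", "\n", "and\n", "\n", "$$ \\text{Var}_{Y \\sim p_{\\text{OEF}(m, T)}(\\cdot | \\theta, \\phi)} \\left[ T(Y) \\right] = \\phi A''(\\theta). $$" ] }, { "cell_type": "markdown", "metadata": { "id": "gyf51flphGOK" }, "source": [ "#### Proof\n", "\n", "By \"Lemma about the derivative of the log partition function,\" applied with $b(y) = T(y)/\\phi$ and $c(\\theta) = A(\\theta)/\\phi$, we have\n", "\n", "$$ \\mathbb{E}_{Y \\sim p_{\\text{OEF}(m, T)}(\\cdot | \\theta, \\phi)} \\left[ \\frac{T(Y)}{\\phi} \\right] = \\frac{A'(\\theta)}{\\phi} $$\n", "\n", "and\n", "\n", "$$ \\text{Var}_{Y \\sim p_{\\text{OEF}(m, T)}(\\cdot | \\theta, \\phi)} \\left[ \\frac{T(Y)}{\\phi} \\right] = \\frac{A''(\\theta)}{\\phi}. $$\n", "\n", "The result then follows from the fact that expectation is linear ($\\mathbb{E}[aX] = a\\mathbb{E}[X]$) and variance is degree-2 homogeneous ($\\text{Var}[aX] = a^2 \\,\\text{Var}[X]$)." ] }, { "cell_type": "markdown", "metadata": { "id": "mYOnAZv9d4XH" }, "source": [ "## Generalized Linear Model\n", "\n", "In a generalized linear model, a predictive distribution for the response variable $Y$ is associated with a vector of observed predictors $x$. The distribution is a member of an overdispersed exponential family, and the parameter $\\theta$ is replaced by $h(\\eta)$, where $h$ is a known function, $\\eta := x^\\top \\beta$ is the so-called **linear response**, and $\\beta$ is a vector of parameters (regression coefficients) to be learned. In general the dispersion parameter $\\phi$ could be learned too, but in our setup we will treat $\\phi$ as known. So our setup is\n", "\n", "$$ Y \\sim p_{\\text{OEF}(m, T)}(\\cdot\\, |\\, \\theta = h(\\eta), \\phi) $$\n", "\n", "where the model structure is characterized by the distribution $p_{\\text{OEF}(m, T)}$ and the function $h$ which converts linear response to parameters.\n", "\n", "Traditionally, the mapping from linear response $\\eta$ to mean $\\mu := \\mathbb{E}_{Y \\sim p_{\\text{OEF}(m, T)}(\\cdot\\, |\\, \\theta = 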
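h(\\eta), \\phi)}\\left[ Y\\right]$ is denoted\n", "\n", "$$ \\mu = g^{-1}(\\eta). $$\n", "\n", "This mapping is required to be one-to-one, and its inverse, $g$, is called the **link function** for this GLM. Typically, one describes a GLM by naming its link function and its family of distributions, e.g., a \"GLM with Bernoulli distribution and logit link function\" (also known as a logistic regression model). In order to fully characterize the GLM, the function $h$ must also be specified. If $h$ is the identity, then $g$ is said to be the **canonical link function**." ] }, { "cell_type": "markdown", "metadata": { "id": "t-mrWHH2-wtv" }, "source": [ "### Claim: Expressing $h'$ in terms of the sufficient statistic\n", "\n", "Define\n", "\n", "$$ {\\text{Mean}_T}(\\eta) := \\mathbb{E}_{Y \\sim p_{\\text{OEF}(m, T)}(\\cdot | \\theta = h(\\eta), \\phi)} \\left[ T(Y) \\right] $$\n", "\n", "and\n", "\n", "$$ {\\text{Var}_T}(\\eta) := \\text{Var}_{Y \\sim p_{\\text{OEF}(m, T)}(\\cdot | \\theta = h(\\eta), \\phi)} \\left[ T(Y) \\right]. $$\n", "\n", "Then we have\n", "\n", "$$ h'(\\eta) = \\frac{\\phi\\, {\\text{Mean}_T}'(\\eta)}{{\\text{Var}_T}(\\eta)}. $$" ] }, { "cell_type": "markdown", "metadata": { "id": "z36iGKlf_-3F" }, "source": [ "#### Proof\n", "\n", "By \"Mean and variance of the sufficient statistic,\" we have\n", "\n", "$$ {\\text{Mean}_T}(\\eta) = A'(h(\\eta)). $$\n", "\n", "Differentiating with the chain rule, we obtain\n", "\n", "$$ {\\text{Mean}_T}'(\\eta) = A''(h(\\eta))\\, h'(\\eta), $$\n", "\n", "and by \"Mean and variance of the sufficient statistic,\"\n", "\n", "$$ \\cdots = \\frac{1}{\\phi} {\\text{Var}_T}(\\eta)\\ h'(\\eta). $$\n", "\n", "The conclusion follows." ] }, { "cell_type": "markdown", "metadata": { "id": "D8LV_QHPx-wV" }, "source": [ "## Fitting GLM Parameters to Data\n", "\n", "The properties derived above lend themselves very well to fitting GLM parameters $\\beta$ to a data set. Quasi-Newton methods such as Fisher scoring rely on the gradient of the log likelihood and the Fisher information, which we now show can be computed especially efficiently for a GLM.\n", "\n", "Suppose we have observed predictor vectors $x_i$ and associated scalar responses $y_i$. In matrix form, we'll say we have observed predictors $\\mathbf{x}$ and response $\\mathbf{y}$, where $\\mathbf{x}$ is the matrix whose $i$th row is $x_i^\\top$ and $\\mathbf{y}$ is the vector whose $i$th element is $y_i$. The log likelihood of parameters $\\beta$ is then\n", "\n", "$$ \\ell(\\beta\\, ;\\, \\mathbf{x}, \\mathbf{y}) = \\sum_{i=1}^{N} \\log p_{\\text{OEF}(m, T)}(y_i\\, |\\, \\theta = h(x_i^\\top \\beta), \\phi). 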
$$" ] }, { "cell_type": "markdown", "metadata": { "id": "aghNxiO_HFW1" }, "source": [ "### For a single data sample\n", "\n", "To simplify the notation, let's first consider the case of a single data point, $N=1$; then we will extend to the general case by additivity.\n", "\n", "#### Gradient\n", "\n", "We have\n", "\n", "$$ \\begin{align*} \\ell(\\beta\\, ;\\, x, y) &= \\log p_{\\text{OEF}(m, T)}(y\\, |\\, \\theta = h(x^\\top \\beta), \\phi) \\\\ &= \\log m(y, \\phi) + \\frac{\\theta\\, T(y) - A(\\theta)}{\\phi}, \\quad\\text{where}\\ \\theta = h(x^\\top \\beta). \\end{align*} $$\n", "\n", "Hence by the chain rule,\n", "\n", "$$ \\nabla_\\beta \\ell(\\beta\\, ; \\, x, y) = \\frac{T(y) - A'(\\theta)}{\\phi}\\, h'(x^\\top \\beta)\\, x. $$\n", "\n", "Separately, by \"Mean and variance of the sufficient statistic,\" we have $A'(\\theta) = {\\text{Mean}_T}(x^\\top \\beta)$. Hence, by \"Claim: Expressing $h'$ in terms of the sufficient statistic,\" we have\n", "\n", "$$ \\cdots = \\left(T(y) - {\\text{Mean}_T}(x^\\top \\beta)\\right) \\frac{{\\text{Mean}_T}'(x^\\top \\beta)}{{\\text{Var}_T}(x^\\top \\beta)} \\,x. $$\n", "\n", "#### Hessian\n", "\n", "Differentiating a second time, by the product rule we obtain\n", "\n", "$$ \\begin{align*} \\nabla_\\beta^2 \\ell(\\beta\\, ;\\, x, y) &= \\left[ -A''(h(x^\\top \\beta))\\, h'(x^\\top \\beta) \\right] h'(x^\\top \\beta)\\, x x^\\top + \\left[ T(y) - A'(h(x^\\top \\beta)) \\right] h''(x^\\top \\beta)\\, x x^\\top \\\\ &= \\left( -{\\text{Mean}_T}'(x^\\top \\beta)\\, h'(x^\\top \\beta) + \\left[T(y) - A'(h(x^\\top \\beta))\\right] h''(x^\\top \\beta) \\right)\\, x x^\\top. \\end{align*} $$\n", "\n", "#### Fisher information\n", "\n", "By \"Mean and variance of the sufficient statistic,\" we have\n", "\n", "$$ \\mathbb{E}_{Y \\sim p_{\\text{OEF}(m, T)}(\\cdot | \\theta = h(x^\\top \\beta), \\phi)} \\left[ T(Y) - A'(h(x^\\top \\beta)) \\right] = 0. $$\n", "\n", "Hence\n", "\n", "$$ \\begin{align*} \\mathbb{E}_{Y \\sim p_{\\text{OEF}(m, T)}(\\cdot | \\theta = h(x^\\top \\beta), \\phi)} \\left[ \\nabla_\\beta^2 \\ell(\\beta\\, ;\\, x, Y) \\right] &= -{\\text{Mean}_T}'(x^\\top \\beta)\\, h'(x^\\top \\beta)\\, x x^\\top \\\\ &= -\\frac{\\phi\\, {\\text{Mean}_T}'(x^\\top \\beta)^2}{{\\text{Var}_T}(x^\\top \\beta)}\\, x x^\\top. 
\\end{align*} $$" ] }, { "cell_type": "markdown", "metadata": { "id": "BrA1A583HOng" }, "source": [ "### For multiple data samples\n", "\n", "We now extend the $N=1$ case to the general case. Let $\\boldsymbol{\\eta} := \\mathbf{x} \\beta$ denote the vector whose $i$th coordinate is the linear response from the $i$th data sample. Let $\\mathbf{T}$ (resp. ${\\textbf{Mean}_T}$, resp. ${\\textbf{Var}_T}$) denote the broadcasted (vectorized) function which applies the scalar-valued function $T$ (resp. ${\\text{Mean}_T}$, resp. ${\\text{Var}_T}$) to each coordinate. Then we have\n", "\n", "$$ \\begin{align*} \\nabla_\\beta \\ell(\\beta\\, ;\\, \\mathbf{x}, \\mathbf{y}) &= \\sum_{i=1}^{N} \\nabla_\\beta \\ell(\\beta\\, ;\\, x_i, y_i) \\\\ &= \\sum_{i=1}^{N} \\left(T(y_i) - {\\text{Mean}_T}(x_i^\\top \\beta)\\right) \\frac{{\\text{Mean}_T}'(x_i^\\top \\beta)}{{\\text{Var}_T}(x_i^\\top \\beta)} \\, x_i \\\\ &= \\mathbf{x}^\\top \\,\\text{diag}\\left(\\frac{ {\\textbf{Mean}_T}'(\\mathbf{x} \\beta) }{ {\\textbf{Var}_T}(\\mathbf{x} \\beta) }\\right) \\left(\\mathbf{T}(\\mathbf{y}) - {\\textbf{Mean}_T}(\\mathbf{x} \\beta)\\right) \\end{align*} $$\n", "\n", "and\n", "\n", "$$ \\begin{align*} \\mathbb{E}_{Y_i \\sim p_{\\text{OEF}(m, T)}(\\cdot | \\theta = h(x_i^\\top \\beta), \\phi)} \\left[ \\nabla_\\beta^2 \\ell(\\beta\\, ;\\, \\mathbf{x}, \\mathbf{Y}) \\right] &= \\sum_{i=1}^{N} \\mathbb{E}_{Y_i \\sim p_{\\text{OEF}(m, T)}(\\cdot | \\theta = h(x_i^\\top \\beta), \\phi)} \\left[ \\nabla_\\beta^2 \\ell(\\beta\\, ;\\, x_i, Y_i) \\right] \\\\ &= \\sum_{i=1}^{N} -\\frac{\\phi\\, {\\text{Mean}_T}'(x_i^\\top \\beta)^2}{{\\text{Var}_T}(x_i^\\top \\beta)}\\, x_i x_i^\\top \\\\ &= -\\mathbf{x}^\\top \\,\\text{diag}\\left( \\frac{ \\phi\\, {\\textbf{Mean}_T}'(\\mathbf{x} \\beta)^2 }{ {\\textbf{Var}_T}(\\mathbf{x} \\beta) }\\right)\\, \\mathbf{x}, \\end{align*} $$\n", "\n", "where the fractions denote element-wise division." ] }, { "cell_type": "markdown", "metadata": { "id": "jUrOmdt395hZ" }, "source": [ "## Verifying the Formulas Numerically" ] }, { "cell_type": "markdown", "metadata": { "id": "WVp59IBW-TK6" }, "source": [ "We now verify the above formula for the gradient of the log likelihood numerically using `tf.gradients`, and verify the formula for the Fisher information with a Monte Carlo estimate using `tf.hessians`:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": 
"oM-HDPdPepE-" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Coordinatewise relative error between naively computed gradients and formula-based gradients (should be zero):\n", "[[2.08845965e-16 1.67076772e-16 2.08845965e-16]\n", " [1.96118673e-16 3.13789877e-16 1.96118673e-16]\n", " [2.08845965e-16 1.67076772e-16 2.08845965e-16]\n", " [1.96118673e-16 3.13789877e-16 1.96118673e-16]\n", " [2.08845965e-16 1.67076772e-16 2.08845965e-16]\n", " [1.96118673e-16 3.13789877e-16 1.96118673e-16]\n", " [1.96118673e-16 3.13789877e-16 1.96118673e-16]\n", " [1.96118673e-16 3.13789877e-16 1.96118673e-16]\n", " [2.08845965e-16 1.67076772e-16 2.08845965e-16]\n", " [1.96118673e-16 3.13789877e-16 1.96118673e-16]\n", " [2.08845965e-16 1.67076772e-16 2.08845965e-16]\n", " [1.96118673e-16 3.13789877e-16 1.96118673e-16]\n", " [1.96118673e-16 3.13789877e-16 1.96118673e-16]\n", " [1.96118673e-16 3.13789877e-16 1.96118673e-16]\n", " [1.96118673e-16 3.13789877e-16 1.96118673e-16]\n", " [1.96118673e-16 3.13789877e-16 1.96118673e-16]\n", " [1.96118673e-16 3.13789877e-16 1.96118673e-16]\n", " [2.08845965e-16 1.67076772e-16 2.08845965e-16]\n", " [1.96118673e-16 3.13789877e-16 1.96118673e-16]\n", " [2.08845965e-16 1.67076772e-16 2.08845965e-16]]\n", "\n", "Coordinatewise relative error between average of naively computed Hessian and formula-based FIM (should approach zero as num_trials -> infinity):\n", "[[0.00072369 0.00072369 0.00072369]\n", " [0.00072369 0.00072369 0.00072369]\n", " [0.00072369 0.00072369 0.00072369]]\n", "\n" ] } ], "source": [ "def VerifyGradientAndFIM():\n", "  model = tfp.glm.BernoulliNormalCDF()\n", "  model_matrix = np.array([[1., 5, -2],\n", "                           [8, -1, 8]])\n", "\n", "  def _naive_grad_and_hessian_loss_fn(x, response):\n", "    # Computes gradient and Hessian of negative log likelihood using autodiff.\n", "    predicted_linear_response = tf.linalg.matvec(model_matrix, x)\n", "    log_probs = model.log_prob(response, predicted_linear_response)\n", "    grad_loss 
= tf.gradients(-log_probs, [x])[0]\n", "    hessian_loss = tf.hessians(-log_probs, [x])[0]\n", "    return [grad_loss, hessian_loss]\n", "\n", "  def _grad_neg_log_likelihood_and_fim_fn(x, response):\n", "    # Computes gradient of negative log likelihood and Fisher information matrix\n", "    # using the formulas above.\n", "    predicted_linear_response = tf.linalg.matvec(model_matrix, x)\n", "    mean, variance, grad_mean = model(predicted_linear_response)\n", "\n", "    v = (response - mean) * grad_mean / variance\n", "    grad_log_likelihood = tf.linalg.matvec(model_matrix, v, adjoint_a=True)\n", "    w = grad_mean**2 / variance\n", "\n", "    fisher_info = tf.linalg.matmul(\n", "        model_matrix,\n", "        w[..., tf.newaxis] * model_matrix,\n", "        adjoint_a=True)\n", "    return [-grad_log_likelihood, fisher_info]\n", "\n", "  @tf.function(autograph=False)\n", "  def compute_grad_hessian_estimates():\n", "    # Monte Carlo estimate of E[Hessian(-LogLikelihood)], where the expectation is\n", "    # as written in \"Claim (Fisher information)\" above.\n", "    num_trials = 20\n", "    trial_outputs = []\n", "    np.random.seed(10)\n", "    model_coefficients_ = np.random.random(size=(model_matrix.shape[1],))\n", "    model_coefficients = tf.convert_to_tensor(model_coefficients_)\n", "    for _ in range(num_trials):\n", "      # Sample from the distribution of `model`\n", "      response = np.random.binomial(\n", "          1,\n", "          scipy.stats.norm().cdf(np.matmul(model_matrix, model_coefficients_))\n", "      ).astype(np.float64)\n", "      trial_outputs.append(\n", "          list(_naive_grad_and_hessian_loss_fn(model_coefficients, response)) +\n", "          list(\n", "              _grad_neg_log_likelihood_and_fim_fn(model_coefficients, response))\n", "      )\n", "\n", "    naive_grads = tf.stack(\n", "        list(naive_grad for [naive_grad, _, _, _] in trial_outputs), axis=0)\n", "    fancy_grads = tf.stack(\n", "        list(fancy_grad for [_, _, fancy_grad, _] in trial_outputs), axis=0)\n", "\n", "    average_hess = tf.reduce_mean(tf.stack(\n", "        list(hess for [_, hess, _, _] in trial_outputs), axis=0), 
axis=0)\n", "    [_, _, _, fisher_info] = trial_outputs[0]\n", "    return naive_grads, fancy_grads, average_hess, fisher_info\n", "\n", "  naive_grads, fancy_grads, average_hess, fisher_info = [\n", "      t.numpy() for t in compute_grad_hessian_estimates()]\n", "\n", "  print(\"Coordinatewise relative error between naively computed gradients and\"\n", "        \" formula-based gradients (should be zero):\\n{}\\n\".format(\n", "            (naive_grads - fancy_grads) / naive_grads))\n", "\n", "  print(\"Coordinatewise relative error between average of naively computed\"\n", "        \" Hessian and formula-based FIM (should approach zero as num_trials\"\n", "        \" -> infinity):\\n{}\\n\".format(\n", "            (average_hess - fisher_info) / average_hess))\n", "\n", "VerifyGradientAndFIM()\n" ] }, { "cell_type": "markdown", "metadata": { "id": "bAiNubQ-WDHN" }, "source": [ "# References\n", "\n", "[1]: Guo-Xun Yuan, Chia-Hua Ho and Chih-Jen Lin. An Improved GLMNET for L1-regularized Logistic Regression. *Journal of Machine Learning Research*, 13, 2012. http://www.jmlr.org/papers/volume13/yuan12a/yuan12a.pdf\n", "\n", "[2]: skd. Derivation of Soft Thresholding Operator. 2018. https://math.stackexchange.com/q/511106\n", "\n", "[3]: Wikipedia Contributors. Proximal gradient methods for learning. *Wikipedia, The Free Encyclopedia*, 2018. https://en.wikipedia.org/wiki/Proximal_gradient_methods_for_learning\n", "\n", "[4]: Yao-Liang Yu. The Proximity Operator. https://www.cs.cmu.edu/~suvrit/teach/yaoliang_proximity.pdf" ] } ], "metadata": { "colab": { "collapsed_sections": [], "name": "Generalized_Linear_Models.ipynb", "toc_visible": true }, "kernelspec": { "display_name": "Python 3", "name": "python3" } }, "nbformat": 4, "nbformat_minor": 0 }