{ "cells": [ { "cell_type": "markdown", "metadata": { "id": "D2J3nB-ZrRv1" }, "source": [ "##### Copyright 2018 The TensorFlow Probability Authors.\n", "\n", "Licensed under the Apache License, Version 2.0 (the \"License\");" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "cellView": "form", "id": "9qDhTJmprPnm" }, "outputs": [], "source": [ "#@title Licensed under the Apache License, Version 2.0 (the \"License\"); { display-mode: \"form\" }\n", "# you may not use this file except in compliance with the License.\n", "# You may obtain a copy of the License at\n", "#\n", "# https://www.apache.org/licenses/LICENSE-2.0\n", "#\n", "# Unless required by applicable law or agreed to in writing, software\n", "# distributed under the License is distributed on an \"AS IS\" BASIS,\n", "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", "# See the License for the specific language governing permissions and\n", "# limitations under the License." ] }, { "cell_type": "markdown", "metadata": { "id": "pfPtIQ3DdZ8r" }, "source": [ "# 广义线性模型\n", "\n", "
"
\n", " | Learned | \n", "True | \n", "
---|---|---|
0 | \n", "0.216240 | \n", "0.220758 | \n", "
1 | \n", "0.000000 | \n", "0.000000 | \n", "
2 | \n", "0.000000 | \n", "0.000000 | \n", "
3 | \n", "0.000000 | \n", "0.000000 | \n", "
4 | \n", "0.000000 | \n", "0.000000 | \n", "
5 | \n", "0.043702 | \n", "0.063950 | \n", "
6 | \n", "-0.145379 | \n", "-0.153256 | \n", "
7 | \n", "0.000000 | \n", "0.000000 | \n", "
8 | \n", "0.000000 | \n", "0.000000 | \n", "
9 | \n", "0.000000 | \n", "0.000000 | \n", "
10 | \n", "0.000000 | \n", "0.000000 | \n", "
11 | \n", "0.000000 | \n", "0.000000 | \n", "
12 | \n", "0.000000 | \n", "0.000000 | \n", "
13 | \n", "0.024382 | \n", "0.046572 | \n", "
14 | \n", "-0.242985 | \n", "-0.242609 | \n", "
15 | \n", "-0.106168 | \n", "-0.123367 | \n", "
16 | \n", "0.000000 | \n", "0.000000 | \n", "
17 | \n", "-0.039745 | \n", "-0.067560 | \n", "
18 | \n", "-0.217717 | \n", "-0.222169 | \n", "
19 | \n", "0.000000 | \n", "0.000000 | \n", "
20 | \n", "0.000000 | \n", "0.000000 | \n", "
21 | \n", "-0.016553 | \n", "-0.041692 | \n", "
22 | \n", "0.018959 | \n", "0.049624 | \n", "
23 | \n", "-0.057686 | \n", "-0.078299 | \n", "
24 | \n", "0.003642 | \n", "0.035682 | \n", "
25 | \n", "0.000000 | \n", "0.000000 | \n", "
26 | \n", "0.000000 | \n", "0.000000 | \n", "
27 | \n", "-0.234406 | \n", "-0.240482 | \n", "
28 | \n", "0.000000 | \n", "0.000000 | \n", "
29 | \n", "0.232209 | \n", "0.225448 | \n", "
... | \n", "... | \n", "... | \n", "
70 | \n", "0.000000 | \n", "0.000000 | \n", "
71 | \n", "0.130166 | \n", "0.144485 | \n", "
72 | \n", "0.000000 | \n", "0.000000 | \n", "
73 | \n", "0.000000 | \n", "0.000000 | \n", "
74 | \n", "0.000000 | \n", "0.000000 | \n", "
75 | \n", "-0.178534 | \n", "-0.186722 | \n", "
76 | \n", "0.000000 | \n", "0.000000 | \n", "
77 | \n", "0.218493 | \n", "0.229656 | \n", "
78 | \n", "0.000000 | \n", "0.000000 | \n", "
79 | \n", "0.000000 | \n", "0.000000 | \n", "
80 | \n", "0.195579 | \n", "0.200442 | \n", "
81 | \n", "0.000000 | \n", "0.000000 | \n", "
82 | \n", "0.000000 | \n", "0.000000 | \n", "
83 | \n", "0.031153 | \n", "0.050457 | \n", "
84 | \n", "0.229065 | \n", "0.231451 | \n", "
85 | \n", "-0.006512 | \n", "-0.039516 | \n", "
86 | \n", "-0.107947 | \n", "-0.119896 | \n", "
87 | \n", "0.000000 | \n", "0.000000 | \n", "
88 | \n", "0.149419 | \n", "0.171693 | \n", "
89 | \n", "0.000000 | \n", "0.000000 | \n", "
90 | \n", "0.047955 | \n", "0.063434 | \n", "
91 | \n", "0.000000 | \n", "0.003592 | \n", "
92 | \n", "-0.083171 | \n", "-0.107145 | \n", "
93 | \n", "0.084615 | \n", "0.101221 | \n", "
94 | \n", "-0.168431 | \n", "-0.175473 | \n", "
95 | \n", "0.138411 | \n", "0.152623 | \n", "
96 | \n", "0.000000 | \n", "0.000000 | \n", "
97 | \n", "0.061161 | \n", "0.081945 | \n", "
98 | \n", "-0.083348 | \n", "-0.104929 | \n", "
99 | \n", "-0.141154 | \n", "-0.153871 | \n", "
100 rows × 2 columns
\n", "\n", " | R | \n", "TFP | \n", "True | \n", "
---|---|---|---|
0 | \n", "0.281080 | \n", "0.216240 | \n", "0.220758 | \n", "
1 | \n", "0.000000 | \n", "0.000000 | \n", "0.000000 | \n", "
2 | \n", "0.000000 | \n", "0.000000 | \n", "0.000000 | \n", "
3 | \n", "0.000000 | \n", "0.000000 | \n", "0.000000 | \n", "
4 | \n", "0.000000 | \n", "0.000000 | \n", "0.000000 | \n", "
5 | \n", "0.056625 | \n", "0.043702 | \n", "0.063950 | \n", "
6 | \n", "-0.188771 | \n", "-0.145379 | \n", "-0.153256 | \n", "
7 | \n", "0.000000 | \n", "0.000000 | \n", "0.000000 | \n", "
8 | \n", "0.000000 | \n", "0.000000 | \n", "0.000000 | \n", "
9 | \n", "0.000000 | \n", "0.000000 | \n", "0.000000 | \n", "
10 | \n", "0.000000 | \n", "0.000000 | \n", "0.000000 | \n", "
11 | \n", "0.000000 | \n", "0.000000 | \n", "0.000000 | \n", "
12 | \n", "0.000000 | \n", "0.000000 | \n", "0.000000 | \n", "
13 | \n", "0.030112 | \n", "0.024382 | \n", "0.046572 | \n", "
14 | \n", "-0.316488 | \n", "-0.242985 | \n", "-0.242609 | \n", "
15 | \n", "-0.139214 | \n", "-0.106168 | \n", "-0.123367 | \n", "
16 | \n", "0.000000 | \n", "0.000000 | \n", "0.000000 | \n", "
17 | \n", "-0.050239 | \n", "-0.039745 | \n", "-0.067560 | \n", "
18 | \n", "-0.283372 | \n", "-0.217717 | \n", "-0.222169 | \n", "
19 | \n", "0.000000 | \n", "0.000000 | \n", "0.000000 | \n", "
20 | \n", "0.000000 | \n", "0.000000 | \n", "0.000000 | \n", "
21 | \n", "-0.021815 | \n", "-0.016553 | \n", "-0.041692 | \n", "
22 | \n", "0.024070 | \n", "0.018959 | \n", "0.049624 | \n", "
23 | \n", "-0.074039 | \n", "-0.057686 | \n", "-0.078299 | \n", "
24 | \n", "0.005321 | \n", "0.003642 | \n", "0.035682 | \n", "
25 | \n", "0.000000 | \n", "0.000000 | \n", "0.000000 | \n", "
26 | \n", "0.000000 | \n", "0.000000 | \n", "0.000000 | \n", "
27 | \n", "-0.304958 | \n", "-0.234406 | \n", "-0.240482 | \n", "
28 | \n", "0.000000 | \n", "0.000000 | \n", "0.000000 | \n", "
29 | \n", "0.301562 | \n", "0.232209 | \n", "0.225448 | \n", "
... | \n", "... | \n", "... | \n", "... | \n", "
70 | \n", "0.000000 | \n", "0.000000 | \n", "0.000000 | \n", "
71 | \n", "0.169291 | \n", "0.130166 | \n", "0.144485 | \n", "
72 | \n", "0.000000 | \n", "0.000000 | \n", "0.000000 | \n", "
73 | \n", "0.000000 | \n", "0.000000 | \n", "0.000000 | \n", "
74 | \n", "0.000000 | \n", "0.000000 | \n", "0.000000 | \n", "
75 | \n", "-0.231294 | \n", "-0.178534 | \n", "-0.186722 | \n", "
76 | \n", "0.000000 | \n", "0.000000 | \n", "0.000000 | \n", "
77 | \n", "0.284215 | \n", "0.218493 | \n", "0.229656 | \n", "
78 | \n", "0.000000 | \n", "0.000000 | \n", "0.000000 | \n", "
79 | \n", "0.000000 | \n", "0.000000 | \n", "0.000000 | \n", "
80 | \n", "0.254524 | \n", "0.195579 | \n", "0.200442 | \n", "
81 | \n", "0.000000 | \n", "0.000000 | \n", "0.000000 | \n", "
82 | \n", "0.000000 | \n", "0.000000 | \n", "0.000000 | \n", "
83 | \n", "0.040716 | \n", "0.031153 | \n", "0.050457 | \n", "
84 | \n", "0.297475 | \n", "0.229065 | \n", "0.231451 | \n", "
85 | \n", "-0.008569 | \n", "-0.006512 | \n", "-0.039516 | \n", "
86 | \n", "-0.141028 | \n", "-0.107947 | \n", "-0.119896 | \n", "
87 | \n", "0.000000 | \n", "0.000000 | \n", "0.000000 | \n", "
88 | \n", "0.194130 | \n", "0.149419 | \n", "0.171693 | \n", "
89 | \n", "0.000000 | \n", "0.000000 | \n", "0.000000 | \n", "
90 | \n", "0.062601 | \n", "0.047955 | \n", "0.063434 | \n", "
91 | \n", "0.000000 | \n", "0.000000 | \n", "0.003592 | \n", "
92 | \n", "-0.107693 | \n", "-0.083171 | \n", "-0.107145 | \n", "
93 | \n", "0.109381 | \n", "0.084615 | \n", "0.101221 | \n", "
94 | \n", "-0.218831 | \n", "-0.168431 | \n", "-0.175473 | \n", "
95 | \n", "0.180662 | \n", "0.138411 | \n", "0.152623 | \n", "
96 | \n", "0.000000 | \n", "0.000000 | \n", "0.000000 | \n", "
97 | \n", "0.078815 | \n", "0.061161 | \n", "0.081945 | \n", "
98 | \n", "-0.108332 | \n", "-0.083348 | \n", "-0.104929 | \n", "
99 | \n", "-0.183284 | \n", "-0.141154 | \n", "-0.153871 | \n", "
100 rows × 3 columns
\n", "$$ {\\text{Mean}_T}'(\\eta) = A''(h(\\eta)), h'(\\eta), $$
\n",
"\n",
"根据“充分统计量的均值和方差”\n",
"\n",
"$$ \\cdots = \\frac{1}{\\phi} {\\text{Var}_T}(\\eta)\\ h'(\\eta). $$\n",
"\n",
"结论如下。"
]
},
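{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a quick sanity check of the claim (an illustrative aside, not part of the original derivation), consider the Poisson family with its canonical link, so that $h(\\eta) = \\eta$, $A(\\theta) = e^\\theta$ and $\\phi = 1$. Then ${\\text{Mean}_T}(\\eta) = A'(h(\\eta)) = e^\\eta$ and ${\\text{Var}_T}(\\eta) = \\phi\\, A''(h(\\eta)) = e^\\eta$, and indeed\n",
"\n",
"$$ \\frac{\\phi\\, {\\text{Mean}_T}'(\\eta)}{{\\text{Var}_T}(\\eta)} = \\frac{e^\\eta}{e^\\eta} = 1 = h'(\\eta). $$"
]
},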
{
"cell_type": "markdown",
"metadata": {
"id": "D8LV_QHPx-wV"
},
"source": [
"## 将 GLM 参数拟合到数据\n",
"\n",
"上面推导出的属性非常适合将 GLM 参数 $\\beta$ 拟合到数据集。诸如 Fisher 得分法之类的拟牛顿法依赖于对数似然函数的梯度和 Fisher 信息,我们现在将展示对于 GLM 可以特别有效地计算这些信息。\n",
"\n",
"假设我们已经观察到预测变量向量 $x_i$ 和相关的标量响应 $y_i$。在矩阵形式中,我们会说我们观察到了预测变量 $\\mathbf{x}$ 和响应 $\\mathbf{y}$,其中 $\\mathbf{x}$ 是第 $i$ 行为 $x_i^\\top$ 的矩阵,$\\mathbf{y}$ 是第 $i$ 个元素为 $y_i$ 的向量。参数 $\\beta$ 的对数似然函数为\n",
"\n",
"$$ \\ell(\\beta, ;, \\mathbf{x}, \\mathbf{y}) = \\sum_{i=1}^{N} \\log p_{\\text{OEF}(m, T)}(y_i, |, \\theta = h(x_i^\\top \\beta), \\phi). $$"
]
},
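{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a concrete illustration, the following minimal sketch evaluates this log likelihood using `tfp.glm.BernoulliNormalCDF` (the probit-Bernoulli model used later in this notebook). The data and coefficient values are hypothetical, chosen only to make the cell self-contained:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"import tensorflow as tf\n",
"import tensorflow_probability as tfp\n",
"\n",
"def log_likelihood(model_matrix, response, model_coefficients):\n",
"  # Linear responses eta_i = x_i^T beta, computed for all samples at once.\n",
"  predicted_linear_response = tf.linalg.matvec(model_matrix, model_coefficients)\n",
"  # Sum over samples of log p_OEF(y_i | theta = h(eta_i), phi).\n",
"  return tf.reduce_sum(\n",
"      tfp.glm.BernoulliNormalCDF().log_prob(response, predicted_linear_response))\n",
"\n",
"# Hypothetical data: two samples, three predictors each.\n",
"x = np.array([[1., 5., -2.],\n",
"              [8., -1., 8.]])\n",
"y = np.array([1., 0.])\n",
"beta = np.array([0.1, -0.2, 0.3])\n",
"print(log_likelihood(x, y, beta))"
]
},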
{
"cell_type": "markdown",
"metadata": {
"id": "aghNxiO_HFW1"
},
"source": [
"### 对于单个数据样本\n",
"\n",
"为了简化表示法,让我们首先考虑单个数据点 $N=1$ 时的情况;然后我们将通过可加性扩展到一般情况。\n",
"\n",
"#### 梯度\n",
"\n",
"已知\n",
"\n",
"$$ \\begin{align*} \\ell(\\beta, ;, x, y) &= \\log p_{\\text{OEF}(m, T)}(y, |, \\theta = h(x^\\top \\beta), \\phi) \\ &= \\log m(y, \\phi) + \\frac{\\theta, T(y) - A(\\theta)}{\\phi}, \\quad\\text{where}\\ \\theta = h(x^\\top \\beta). \\end{align*} $$\n",
"\n",
"因此,根据链式法则,\n",
"\n",
"$$ \\nabla_\\beta \\ell(\\beta, ; , x, y) = \\frac{T(y) - A'(\\theta)}{\\phi}, h'(x^\\top \\beta), x. $$\n",
"\n",
"另外,根据充分统计量的均值和方差”,已知 $A'(\\theta) = {\\text{Mean}_T}(x^\\top \\beta)$。因此,根据“声明:用充分统计量表达 $h'$”,可得\n",
"\n",
"$$ \\cdots = \\left(T(y) - {\\text{Mean}_T}(x^\\top \\beta)\\right) \\frac{{\\text{Mean}_T}'(x^\\top \\beta)}{{\\text{Var}_T}(x^\\top \\beta)} ,x. $$\n",
"\n",
"#### Hessian\n",
"\n",
"由乘积法则二次求导,得到\n",
"\n",
"$$ \\begin{align*} \\nabla_\\beta^2 \\ell(\\beta, ;, x, y) &= \\left[ -A''(h(x^\\top \\beta)), h'(x^\\top \\beta) \\right] h'(x^\\top \\beta), x x^\\top + \\left[ T(y) - A'(h(x^\\top \\beta)) \\right] h''(x^\\top \\beta), xx^\\top ] \\ &= \\left( -{\\text{Mean}_T}'(x^\\top \\beta), h'(x^\\top \\beta) + \\left[T(y) - A'(h(x^\\top \\beta))\\right] \\right), x x^\\top. \\end{align*} $$\n",
"\n",
"#### Fisher 信息\n",
"\n",
"根据“充分统计量的均值和方差”,已知\n",
"\n",
"$$ \\mathbb{E}{Y \\sim p{\\text{OEF}(m, T)}(\\cdot | \\theta = h(x^\\top \\beta), \\phi)} \\left[ T(y) - A'(h(x^\\top \\beta)) \\right] = 0. $$\n",
"\n",
"因此\n",
"\n",
"$$ \\begin{align*} \\mathbb{E}{Y \\sim p{\\text{OEF}(m, T)}(\\cdot | \\theta = h(x^\\top \\beta), \\phi)} \\left[ \\nabla_\\beta^2 \\ell(\\beta, ;, x, y) \\right] &= -{\\text{Mean}_T}'(x^\\top \\beta), h'(x^\\top \\beta) x x^\\top \\ &= -\\frac{\\phi, {\\text{Mean}_T}'(x^\\top \\beta)^2}{{\\text{Var}_T}(x^\\top \\beta)}, x x^\\top. \\end{align*} $$"
]
},
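{
"cell_type": "markdown",
"metadata": {},
"source": [
"To see the gradient formula in a familiar special case (an illustrative aside), take the Bernoulli family with its canonical link, i.e. logistic regression: there $\\theta = x^\\top \\beta$ (so $h' \\equiv 1$), $T(y) = y$, $A(\\theta) = \\log(1 + e^\\theta)$ and $\\phi = 1$. Since $A'(\\theta) = \\sigma(\\theta)$ is the sigmoid function, the gradient above reduces to the familiar\n",
"\n",
"$$ \\nabla_\\beta \\ell(\\beta\\ ;\\ x, y) = \\left(y - \\sigma(x^\\top \\beta)\\right) x. $$"
]
},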
{
"cell_type": "markdown",
"metadata": {
"id": "BrA1A583HOng"
},
"source": [
"### 对于多个数据样本\n",
"\n",
"我们现在将 $N=1$ 情况扩展到一般情况。让$\\boldsymbol{\\eta} := \\mathbf{x} \\beta$ 表示第 i$ 个坐标是第 i$ 个数据样本的线性响应的向量。让 $\\mathbf{T}$ (resp. ${\\textbf{Mean}_T}$, resp. ${\\textbf{Var}_T}$) 表示对每个坐标应用标量值函数 $T$ (resp. ${\\text{Mean}_T}$, resp. ${\\text{Var}_T}$) 的广播(矢量化)函数。然后可得\n",
"\n",
"$$ \\begin{align*} \\nabla_\\beta \\ell(\\beta, ;, \\mathbf{x}, \\mathbf{y}) &= \\sum_{i=1}^{N} \\nabla_\\beta \\ell(\\beta, ;, x_i, y_i) \\ &= \\sum_{i=1}^{N} \\left(T(y) - {\\text{Mean}_T}(x_i^\\top \\beta)\\right) \\frac{{\\text{Mean}_T}'(x_i^\\top \\beta)}{{\\text{Var}_T}(x_i^\\top \\beta)} , x_i \\ &= \\mathbf{x}^\\top ,\\text{diag}\\left(\\frac{ {\\textbf{Mean}_T}'(\\mathbf{x} \\beta) }{ {\\textbf{Var}_T}(\\mathbf{x} \\beta) }\\right) \\left(\\mathbf{T}(\\mathbf{y}) - {\\textbf{Mean}_T}(\\mathbf{x} \\beta)\\right) \\ \\end{align*} $$\n",
"\n",
"和\n",
"\n",
"$$ \\begin{align*} \\mathbb{E}{Y_i \\sim p{\\text{OEF}(m, T)}(\\cdot | \\theta = h(x_i^\\top \\beta), \\phi)} \\left[ \\nabla_\\beta^2 \\ell(\\beta, ;, \\mathbf{x}, \\mathbf{Y}) \\right] &= \\sum_{i=1}^{N} \\mathbb{E}{Y_i \\sim p{\\text{OEF}(m, T)}(\\cdot | \\theta = h(x_i^\\top \\beta), \\phi)} \\left[ \\nabla_\\beta^2 \\ell(\\beta, ;, x_i, Y_i) \\right] \\ &= \\sum_{i=1}^{N} -\\frac{\\phi, {\\text{Mean}_T}'(x_i^\\top \\beta)^2}{{\\text{Var}_T}(x_i^\\top \\beta)}, x_i x_i^\\top \\ &= -\\mathbf{x}^\\top ,\\text{diag}\\left( \\frac{ \\phi, {\\textbf{Mean}_T}'(\\mathbf{x} \\beta)^2 }{ {\\textbf{Var}_T}(\\mathbf{x} \\beta) }\\right), \\mathbf{x}, \\end{align*} $$\n",
"\n",
"其中分数表示逐元素相除。"
]
},
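{
"cell_type": "markdown",
"metadata": {},
"source": [
"These two expressions are exactly the ingredients a Fisher scoring iteration needs. The sketch below (a hypothetical helper, assuming $\\phi = 1$ as for the Bernoulli model used in this notebook) performs one update $\\beta \\leftarrow \\beta + \\mathcal{I}(\\beta)^{-1} \\nabla_\\beta \\ell$ using the matrix formulas above; `tfp.glm.fit` provides a production implementation of this idea:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import tensorflow as tf\n",
"\n",
"def fisher_scoring_step(model, model_matrix, response, model_coefficients):\n",
"  # Linear responses: eta = x beta.\n",
"  eta = tf.linalg.matvec(model_matrix, model_coefficients)\n",
"  # Mean_T(eta), Var_T(eta) and Mean_T'(eta), broadcast over coordinates.\n",
"  mean, variance, grad_mean = model(eta)\n",
"  # Gradient: x^T diag(Mean_T'/Var_T) (T(y) - Mean_T), with T(y) = response.\n",
"  v = (response - mean) * grad_mean / variance\n",
"  grad_log_likelihood = tf.linalg.matvec(model_matrix, v, adjoint_a=True)\n",
"  # Fisher information: x^T diag(phi Mean_T'^2 / Var_T) x, with phi = 1 here.\n",
"  w = grad_mean**2 / variance\n",
"  fisher_info = tf.linalg.matmul(\n",
"      model_matrix, w[..., tf.newaxis] * model_matrix, adjoint_a=True)\n",
"  # Newton-type update; solve the linear system rather than invert the FIM.\n",
"  delta = tf.linalg.solve(fisher_info, grad_log_likelihood[..., tf.newaxis])\n",
"  return model_coefficients + tf.squeeze(delta, axis=-1)"
]
},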
{
"cell_type": "markdown",
"metadata": {
"id": "jUrOmdt395hZ"
},
"source": [
"## 以数值方式验证公式"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "WVp59IBW-TK6"
},
"source": [
"我们现在使用 `tf.gradients` 以数值方式验证上述对数似然函数的梯度的公式,并使用 `tf.hessians` 通过蒙特卡洛估计验证 Fisher 信息的公式:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "oM-HDPdPepE-"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Coordinatewise relative error between naively computed gradients and formula-based gradients (should be zero):\n",
"[[2.08845965e-16 1.67076772e-16 2.08845965e-16]\n",
" [1.96118673e-16 3.13789877e-16 1.96118673e-16]\n",
" [2.08845965e-16 1.67076772e-16 2.08845965e-16]\n",
" [1.96118673e-16 3.13789877e-16 1.96118673e-16]\n",
" [2.08845965e-16 1.67076772e-16 2.08845965e-16]\n",
" [1.96118673e-16 3.13789877e-16 1.96118673e-16]\n",
" [1.96118673e-16 3.13789877e-16 1.96118673e-16]\n",
" [1.96118673e-16 3.13789877e-16 1.96118673e-16]\n",
" [2.08845965e-16 1.67076772e-16 2.08845965e-16]\n",
" [1.96118673e-16 3.13789877e-16 1.96118673e-16]\n",
" [2.08845965e-16 1.67076772e-16 2.08845965e-16]\n",
" [1.96118673e-16 3.13789877e-16 1.96118673e-16]\n",
" [1.96118673e-16 3.13789877e-16 1.96118673e-16]\n",
" [1.96118673e-16 3.13789877e-16 1.96118673e-16]\n",
" [1.96118673e-16 3.13789877e-16 1.96118673e-16]\n",
" [1.96118673e-16 3.13789877e-16 1.96118673e-16]\n",
" [1.96118673e-16 3.13789877e-16 1.96118673e-16]\n",
" [2.08845965e-16 1.67076772e-16 2.08845965e-16]\n",
" [1.96118673e-16 3.13789877e-16 1.96118673e-16]\n",
" [2.08845965e-16 1.67076772e-16 2.08845965e-16]]\n",
"\n",
"Coordinatewise relative error between average of naively computed Hessian and formula-based FIM (should approach zero as num_trials -> infinity):\n",
"[[0.00072369 0.00072369 0.00072369]\n",
" [0.00072369 0.00072369 0.00072369]\n",
" [0.00072369 0.00072369 0.00072369]]\n",
"\n"
]
}
],
"source": [
"def VerifyGradientAndFIM():\n",
" model = tfp.glm.BernoulliNormalCDF()\n",
" model_matrix = np.array([[1., 5, -2],\n",
" [8, -1, 8]])\n",
"\n",
" def _naive_grad_and_hessian_loss_fn(x, response):\n",
" # Computes gradient and Hessian of negative log likelihood using autodiff.\n",
" predicted_linear_response = tf.linalg.matvec(model_matrix, x)\n",
" log_probs = model.log_prob(response, predicted_linear_response)\n",
" grad_loss = tf.gradients(-log_probs, [x])[0]\n",
" hessian_loss = tf.hessians(-log_probs, [x])[0]\n",
" return [grad_loss, hessian_loss]\n",
"\n",
" def _grad_neg_log_likelihood_and_fim_fn(x, response):\n",
" # Computes gradient of negative log likelihood and Fisher information matrix\n",
" # using the formulas above.\n",
" predicted_linear_response = tf.linalg.matvec(model_matrix, x)\n",
" mean, variance, grad_mean = model(predicted_linear_response)\n",
"\n",
" v = (response - mean) * grad_mean / variance\n",
" grad_log_likelihood = tf.linalg.matvec(model_matrix, v, adjoint_a=True)\n",
" w = grad_mean**2 / variance\n",
"\n",
" fisher_info = tf.linalg.matmul(\n",
" model_matrix,\n",
" w[..., tf.newaxis] * model_matrix,\n",
" adjoint_a=True)\n",
" return [-grad_log_likelihood, fisher_info]\n",
"\n",
" @tf.function(autograph=False)\n",
" def compute_grad_hessian_estimates():\n",
" # Monte Carlo estimate of E[Hessian(-LogLikelihood)], where the expectation is\n",
" # as written in \"Claim (Fisher information)\" above.\n",
" num_trials = 20\n",
" trial_outputs = []\n",
" np.random.seed(10)\n",
" model_coefficients_ = np.random.random(size=(model_matrix.shape[1],))\n",
" model_coefficients = tf.convert_to_tensor(model_coefficients_)\n",
" for _ in range(num_trials):\n",
" # Sample from the distribution of `model`\n",
" response = np.random.binomial(\n",
" 1,\n",
" scipy.stats.norm().cdf(np.matmul(model_matrix, model_coefficients_))\n",
" ).astype(np.float64)\n",
" trial_outputs.append(\n",
" list(_naive_grad_and_hessian_loss_fn(model_coefficients, response)) +\n",
" list(\n",
" _grad_neg_log_likelihood_and_fim_fn(model_coefficients, response))\n",
" )\n",
"\n",
" naive_grads = tf.stack(\n",
" list(naive_grad for [naive_grad, _, _, _] in trial_outputs), axis=0)\n",
" fancy_grads = tf.stack(\n",
" list(fancy_grad for [_, _, fancy_grad, _] in trial_outputs), axis=0)\n",
"\n",
" average_hess = tf.reduce_mean(tf.stack(\n",
" list(hess for [_, hess, _, _] in trial_outputs), axis=0), axis=0)\n",
" [_, _, _, fisher_info] = trial_outputs[0]\n",
" return naive_grads, fancy_grads, average_hess, fisher_info\n",
" \n",
" naive_grads, fancy_grads, average_hess, fisher_info = [\n",
" t.numpy() for t in compute_grad_hessian_estimates()]\n",
"\n",
" print(\"Coordinatewise relative error between naively computed gradients and\"\n",
" \" formula-based gradients (should be zero):\\n{}\\n\".format(\n",
" (naive_grads - fancy_grads) / naive_grads))\n",
"\n",
" print(\"Coordinatewise relative error between average of naively computed\"\n",
" \" Hessian and formula-based FIM (should approach zero as num_trials\"\n",
" \" -> infinity):\\n{}\\n\".format(\n",
" (average_hess - fisher_info) / average_hess))\n",
" \n",
"VerifyGradientAndFIM()\n"
]
}
],
"metadata": {
"colab": {
"collapsed_sections": [],
"name": "Generalized_Linear_Models.ipynb",
"toc_visible": true
},
"kernelspec": {
"display_name": "Python 3",
"name": "python3"
}
},
"nbformat": 4,
"nbformat_minor": 0
}