{ "cells": [ { "cell_type": "markdown", "metadata": { "id": "Fqp93JixVuiN" }, "source": [ "##### Copyright 2018 The TensorFlow Probability Authors.\n", "\n", "Licensed under the Apache License, Version 2.0 (the \"License\");" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "cellView": "form", "id": "MeKZo1dnV1cE" }, "outputs": [], "source": [ "#@title Licensed under the Apache License, Version 2.0 (the \"License\"); { display-mode: \"form\" }\n", "# you may not use this file except in compliance with the License.\n", "# You may obtain a copy of the License at\n", "#\n", "# https://www.apache.org/licenses/LICENSE-2.0\n", "#\n", "# Unless required by applicable law or agreed to in writing, software\n", "# distributed under the License is distributed on an \"AS IS\" BASIS,\n", "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", "# See the License for the specific language governing permissions and\n", "# limitations under the License." ] }, { "cell_type": "markdown", "metadata": { "id": "DcriL2xPrG3_" }, "source": [ "# TensorFlow Distributions 浅析\n", "\n", "\n", " \n", " \n", " \n", " \n", "
在 TensorFlow.org 上查看 在 Google Colab 中运行 在 Github 上查看源代码 下载笔记本
" ] }, { "cell_type": "markdown", "metadata": { "id": "10i0wSQfJClb" }, "source": [ "在此笔记本中,我们将探索 TensorFlow Distributions(简称为 TFD)。此笔记本的目标是用通俗易懂的方式让您了解学习曲线,包括了解 TFD 对张量形状的处理。此笔记本尝试先列举示例,而不是介绍抽象的概念。我们首先介绍执行操作时公认的简单方式,而将最基本的抽象概念留到最后。如果您更偏爱较抽象的参考教程,请参阅[了解 TensorFlow Distributions 形状](https://github.com/tensorflow/probability/blob/main/tensorflow_probability/examples/jupyter_notebooks/Understanding_TensorFlow_Distributions_Shapes.ipynb)。如果对本文介绍的内容有任何疑问,请随时联系(或加入)[TensorFlow Probability 邮寄名单](https://groups.google.com/a/tensorflow.org/forum/#!forum/tfprobability)。我们非常乐意为您提供帮助。" ] }, { "cell_type": "markdown", "metadata": { "id": "kII2QSIEJn0X" }, "source": [ "首先,我们需要导入相应的库。我们的整个库为 `tensorflow_probability`。按照惯例,我们通常将该分布库称为 `tfd`。\n", "\n", "[Tensorflow Eager](https://tensorflow.google.cn/guide/eager) 是 TensorFlow 的命令式执行环境。在 TensorFlow Eager 中,每个 TF 运算都会立即得到计算并生成结果。这与 TensorFlow 的标准“计算图”模式形成对比,在“计算图”模式下,TF 运算会将节点添加到稍后执行的计算图中。整个笔记本使用 TF Eager 编写,但是本文介绍的任何概念都与其无关,并且 TFP 可以在计算图模式下使用。" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "J6t0EUihrG4B" }, "outputs": [], "source": [ "import collections\n", "\n", "import tensorflow as tf\n", "import tensorflow_probability as tfp\n", "tfd = tfp.distributions\n", "\n", "try:\n", " tf.compat.v1.enable_eager_execution()\n", "except ValueError:\n", " pass\n", "\n", "import matplotlib.pyplot as plt" ] }, { "cell_type": "markdown", "metadata": { "id": "QD5lzFZerG4H" }, "source": [ "## 基本的一元分布\n" ] }, { "cell_type": "markdown", "metadata": { "id": "Gos16Z82LSQQ" }, "source": [ "我们立即创建一个正态分布:" ] }, { "cell_type": "code", "execution_count": 3, "metadata": { "id": "r3zofkWpLEvY" }, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 3, "metadata": { "tags": [] }, "output_type": "execute_result" } ], "source": [ "n = tfd.Normal(loc=0., scale=1.)\n", "n" ] }, { "cell_type": "markdown", "metadata": { "id": "7rlXP5HaLVsc" }, "source": [ "我们可以通过它绘制一个样本:" ] }, { "cell_type": "code", "execution_count": 4, 
"metadata": { "id": "14RRJONELX3O" }, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 4, "metadata": { "tags": [] }, "output_type": "execute_result" } ], "source": [ "n.sample()" ] }, { "cell_type": "markdown", "metadata": { "id": "NsDAltf5Le33" }, "source": [ "我们可以绘制多个样本:" ] }, { "cell_type": "code", "execution_count": 5, "metadata": { "id": "v5jdTzbqLhrl" }, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 5, "metadata": { "tags": [] }, "output_type": "execute_result" } ], "source": [ "n.sample(3)" ] }, { "cell_type": "markdown", "metadata": { "id": "4wX5cjUXLmrD" }, "source": [ "我们可以计算一个对数概率:" ] }, { "cell_type": "code", "execution_count": 6, "metadata": { "id": "hrCTzv2cLoLw" }, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 6, "metadata": { "tags": [] }, "output_type": "execute_result" } ], "source": [ "n.log_prob(0.)" ] }, { "cell_type": "markdown", "metadata": { "id": "lHYIb0psLrzE" }, "source": [ "我们可以计算多个对数概率:" ] }, { "cell_type": "code", "execution_count": 7, "metadata": { "id": "4dgwzazNLw6H" }, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 7, "metadata": { "tags": [] }, "output_type": "execute_result" } ], "source": [ "n.log_prob([0., 2., 4.])" ] }, { "cell_type": "markdown", "metadata": { "id": "mY5hHMClL-i1" }, "source": [ "存在各种各样的分布。我们试一试伯努利分布:" ] }, { "cell_type": "code", "execution_count": 8, "metadata": { "id": "OIJErPQWMDfP" }, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 8, "metadata": { "tags": [] }, "output_type": "execute_result" } ], "source": [ "b = tfd.Bernoulli(probs=0.7)\n", "b" ] }, { "cell_type": "code", "execution_count": 9, "metadata": { "id": "oqDcgYE8Mck2" }, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 9, "metadata": { "tags": [] }, "output_type": "execute_result" } ], "source": [ "b.sample()" ] }, { "cell_type": "code", "execution_count": 10, "metadata": { "id": "6HbbzPNTMgXh" }, "outputs": [ { "data": { 
"text/plain": [ "" ] }, "execution_count": 10, "metadata": { "tags": [] }, "output_type": "execute_result" } ], "source": [ "b.sample(8)" ] }, { "cell_type": "code", "execution_count": 11, "metadata": { "id": "LNy0tIKmMuL3" }, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 11, "metadata": { "tags": [] }, "output_type": "execute_result" } ], "source": [ "b.log_prob(1)" ] }, { "cell_type": "code", "execution_count": 12, "metadata": { "id": "sghhA8onM0IN" }, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 12, "metadata": { "tags": [] }, "output_type": "execute_result" } ], "source": [ "b.log_prob([1, 0, 1, 0])" ] }, { "cell_type": "markdown", "metadata": { "id": "ztM2d-N9nNX2" }, "source": [ "## 多元分布" ] }, { "cell_type": "markdown", "metadata": { "id": "MT2ZyGCoHMla" }, "source": [ "我们使用对角协方差创建一个多元正态分布:" ] }, { "cell_type": "code", "execution_count": 13, "metadata": { "id": "MuFrhR4enQI5" }, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 13, "metadata": { "tags": [] }, "output_type": "execute_result" } ], "source": [ "nd = tfd.MultivariateNormalDiag(loc=[0., 10.], scale_diag=[1., 4.])\n", "nd" ] }, { "cell_type": "markdown", "metadata": { "id": "QUcnlm3vHFRj" }, "source": [ "将此分布与我们之前创建的一元正态分布进行比较,看看有什么不同? 
" ] }, { "cell_type": "code", "execution_count": 14, "metadata": { "id": "ggInhJ-LHVhR" }, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 14, "metadata": { "tags": [] }, "output_type": "execute_result" } ], "source": [ "tfd.Normal(loc=0., scale=1.)" ] }, { "cell_type": "markdown", "metadata": { "id": "0ze8A19LnqO5" }, "source": [ "我们发现,一元正态分布的 `event_shape` 为 `()`,表明这是标量分布。多元正态分布的 `event_shape` 为 `2`,表明此分布的基本[事件空间](https://en.wikipedia.org/wiki/Event_(probability_theory))为二维空间。" ] }, { "cell_type": "markdown", "metadata": { "id": "lJTTGCuuHpf5" }, "source": [ "抽样的工作方式与以前相同:" ] }, { "cell_type": "code", "execution_count": 15, "metadata": { "id": "xQzfj0vHnw-5" }, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 15, "metadata": { "tags": [] }, "output_type": "execute_result" } ], "source": [ "nd.sample()" ] }, { "cell_type": "code", "execution_count": 16, "metadata": { "id": "dSyxxVmNnzT1" }, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 16, "metadata": { "tags": [] }, "output_type": "execute_result" } ], "source": [ "nd.sample(5)" ] }, { "cell_type": "code", "execution_count": 17, "metadata": { "id": "SC0fmV3hn0Zp" }, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 17, "metadata": { "tags": [] }, "output_type": "execute_result" } ], "source": [ "nd.log_prob([0., 10])" ] }, { "cell_type": "markdown", "metadata": { "id": "lmI_pjVJIT4I" }, "source": [ "多元正态分布通常没有对角协方差。通过 TFD,可以采用多种方式创建多元正态分布,包括完全协方差规范(由协方差矩阵的 Cholesky 因子参数化),也就是我们在本文中使用的规范。" ] }, { "cell_type": "code", "execution_count": 18, "metadata": { "id": "qPEWjisBolk2" }, "outputs": [ { "data": { "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAXAAAAEICAYAAABGaK+TAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90\nbGliIHZlcnNpb24zLjMuNCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8QVMy6AAAACXBIWXMAAAsT\nAAALEwEAmpwYAAAuXElEQVR4nO2de3BV53nun1dCQluyhLgKS4AxIMA2VwffQn0rGCeOc5t62thO\n6znOHM9xT0/Tnjq9nDqJm0zbcyaeXDzp0LpxTjIJTiYXp3XS5thAbcfB2C62ARsjkIgxSIAQSEIC\nSej2nT8efbOWtvaWtC/aWws9vxnNZq/b9+01zLPe9X7vxZxzEEIIET0K8j0BIYQQ6SEBF0KIiCIB\nF0KIiCIBF0KIiCIBF0KIiCIBF0KIiCIBF0KIiCIBF5HAzI6aWbeZdZpZu5m9Ymb/zczG9X/YzBab\nmTOzaRM917hxbzOzxlyOKaYOEnARJT7qnCsHcAWA/w3gLwA8ld8pCZE/JOAicjjnzjnnngXwewAe\nMLNVAGBmHzGzt8ysw8yOm9ljodN+NfTZbmbnzewmM1tqZv9hZmfN7IyZbTOzykRjGvmamZ02s3Nm\ntj807nQze9zMjplZs5n9o5nFzKwMwC8BVA+Ned7MqifqvoiphwRcRBbn3OsAGgHcPLTpAoA/AFAJ\n4CMAHjazTwztu2Xos9I5d5lzbjcAA/D3AKoBXAVgIYDHkgy3Zegay4eu/3sAzg7t+z9D29cBWAag\nBsAXnHMXAHwYwImhMS9zzp3I5DcLEUYCLqLOCQCzAMA596Jz7m3n3KBzbj+AHwC4NdmJzrkG59x2\n59xF51wLgK+OcnwfgHIAKwGYc+6gc+6kmRmA/wrgT51zrc65TgB/B+BTWfuFQiQhpws6QkwANQBa\nAcDMbgB946sAFAOYDuDHyU40s3kAngAt+HLQoGlLdKxz7j/M7JsA/gHAIjP7GYBHAJQAKAXwBrWc\nlwZQmOkPE2IsZIGLyGJm14EC/uuhTU8DeBbAQufcDAD/CIopACQqu/n3Q9vXOOcqAHw6dPwInHNP\nOOc+AOAa0GXyOQBnAHQDuMY5Vzn0N8M5d9ko4wqRFSTgInKYWYWZ3Q3ghwC+75x7e2hXOYBW51yP\nmV0P4L7QaS0ABgEsCW0rB3AeXNisAQU52ZjXmdkNZlYE+tp7AAw45wYB/DOArw1Z9DCzGjO7c+jU\nZgCzzWxGhj9biBFIwEWU+LmZdQI4DuCvQZ/1fwnt/0MAXxo65gsAfuR3OOe6APwtgF1DceQ3Avgb\nANcCOAfg3wA8M8rYFaBQtwF4H1zAfHxo318AaADwqpl1ANgBYMXQuHWgL/43Q+MqCkVkDVNDByGE\niCaywIUQIqKMKeBm9u2h5IV3Qttmmdl2M6sf+pw5sdMUQggRz3gs8O8A+FDctr8EsNM5Vwtg59B3\nIYQQOWRcPnAzWwzgF845nzp8CMBtQ4kMlwN40Tm3YkJnKoQQYhjpJvJUOedOAsCQiM9LdqCZPQTg\nIQAoKyv7wMqVK9McUgghpiZvvPHGGefc3PjtE56J6Zx7EsCTALBhwwa3Z8+eiR5SCCEuKczs/UTb\n041CaR5ynWDo83S6ExNCCJEe6Qr4swAeGPr3AwD+NTvTEUIIMV7GE0b4AwC7Aawws0Yz+wxYMOgO\nM6sHcMfQdyGEEDlkTB+4c+7eJLs2ZXkuQgghUkCZmEIIEVEk4EIIEVEk4EIIEVEk4EIIEVEk4EII\nEVEk4EIIEVEk4EIIEVEk4EIIEVEk4EIIEVEk4EIIEVEk4EIIEVEk4EIIEVEk4EIIEVEk4EIIEVEk\n4EIIEVEk4EIIEVEk4EIIEVEk4EIIEVEk4EIIEVEk4EIIEVEk4EIIEVEk4EIIEVEk4EIIEVEk4EII\nEVEk4EIIEVEk4EIIEVEk4EIIEVEk4EIIEVEk4EIIEVEk4EIIE
VEk4EIIEVEk4EIIEVEk4EIIEVEk\n4EIIEVEk4EIIEVEyEnAz+1MzO2Bm75jZD8ysJFsTE0IIMTppC7iZ1QD4YwAbnHOrABQC+FS2JiaE\nEGJ0pmXh/JiZ9QEoBXAi8ykJIfJBfT2wYwfQ1ATU1ACbNwO1tfmelRiNtC1w51wTgMcBHANwEsA5\n59zz8ceZ2UNmtsfM9rS0tKQ/UyHEhFFfDzz1FNDZCSxYwM+nnuJ2MXnJxIUyE8DHAVwJoBpAmZl9\nOv4459yTzrkNzrkNc+fOTX+mQogJY8cOYNYsoLISKCjg56xZ3C4mL5ksYm4G8J5zrsU51wfgGQAf\nzM60hBC5pKkJqKgYvq2iAjghp+ikJhMBPwbgRjMrNTMDsAnAwexMSwiRS2pqgI6O4ds6OoDq6vzM\nR4yPTHzgrwH4CYA3Abw9dK0nszQvIUQO2bwZaG0F2tuBwUF+trZyu5i8mHMuZ4Nt2LDB7dmzJ2fj\nCSHGj49COXGClreiUCYPZvaGc25D/PZMwwiFEJcItbUS7KihVHohhIgoEnAhhIgoEnAhhIgoEnAh\nhIgoEnAhhIgoEnAhhIgoEnAhhIgoigMX4hJEpWGnBrLAhbjEUGnYqYMscCEmOala0+HSsEDwuWOH\nrPBLDVngQkxi0rGmVRp26iABF2ISk06jBZWGnTrIhSLEJKapiZZ3mIoKbk/G5s200v2xHR0sDfvJ\nT2ZnTlognTzIAhdiEpOONV1bC3zmM0B5OUW2vJzfsyGyWiCdXMgCF2ISk641PVGlYbVAOrmQgAsx\nCUjmlvDWtN9XXU3xrq3NjysjHZeOmDgk4ELkGe+WmDWL4tjRwe/e7ZHImh7rnInCu3S85Q1ogTSf\nSMCFyDPpuCXSOcdb7Pv3s+flzJnA6tWpWe4TvUAqUkMCLkSeiXdLNDcDBw8Cx47xeyKBTdWV4S32\ngQHgvfcYktjaCpSVpWa5j+bSEblHAi5Engm7JZqbgV27KLALFwZRHvECm6orw1vs+/YBpaX86+5m\ncs/ataktQqp35uRBYYRC5JnNm2kN19cDzz4LNDQAR48C8+YlT9zx57S3A4OD/Gxt5fZE+OzM9nYg\nFuO2khJ+V5ZmdJEFLkQOGC1ipLYW2LQJePxxoKUFmDMHmD0bOHSI/547d6RrJFVXRthi7+6mBd7T\nw+9ahIwuEnAhJphwxEhREfD888D3vgfceSdw330U3YYGivicORTWWAzo6gLq6oDp0xMLbCquDL/4\nWF0NvP02xxgYAJYs0SJklJGAC5EGqcRge//zxYvA7t2AcxTnp58GXnkFeOSRYFFy5Ur6wDs7gbY2\n4MwZ/j3ySGbzDVvsFy7w2rNmUcCVCh9dJOBCpEiqMdhenF9+meLd2AgUF9MaLyyk6+SDH+R1qqqA\nFSuA7duB3l66UtasAXbuBBYvDq6fThKPFh8vPSTgQqRIqjHY3v/c3k7Ld/p0CnlZGWOxT53i99ZW\nHn/6NMV6cBDYuJGi3t4eXD/bSTwqThVdJOAi8ky0ANXXA9u2Aa+9xu/nz3OMcAjfaDHY3v9cXAyc\nO8cFxN5eYNEiLijOm0d/tHdxHD/OfVddRfGOv34265Ek889v2QLcf7+EfLIjAReRZqKt0ZISXu/Y\nMY4xdy5jtf/934G77w4EdrRIDu9/3raNcdh9fcDSpXSfdHXx39XVwXx/9Sta5Wb8XlU1/PrZrEcS\n758vK+N4+/ZxoXOiU/NFZigOXESadBoeJCO+VOpvfgN86UsU09mzef0TJ2gxt7UBb745vhhsgCL4\n2GPAP/0THwwXLtCVsno1hXzz5mD8BQtorbe1Ab/+NbeHr5/Nhg0+PryujuIdiwVvCOneR5E7JOAi\n0mSzfVj8w+DECYrrhQt0LxQVUXR7e4HLL6fVmmq97TvuAL75TeCee4BlyxgF4s/149fW0vc9cybQ\n389Fz/D1U03iGQ3/MGhr4
9sGQLdOZaUSfKKAXCgi0mSzOl68a6Ktjdft7KTbw0eOnDvHcbdsAR5+\nOPVxkkWDNDXx+vv2UZQrK4FbbqGIx0efnDsHvP8+BX/VqvTrkXj//PTpFG6Abp3165XgEwVkgYtI\nMxHWqGfmTLoTLruMVnhvLxcwBwfpRvFj1NcDW7cCjz7Kz3S700ybBrzwQmABd3fze2FhMI538axd\nS5EtLw/mkc4cvH9+3Tr69p0DbrqJgp7ufRS5QwIuIk0224fFPwyqqyna69fTMj52jCJ3ww1MrPEh\nfY8/Dvz0p8DPfgZ8/evAgw8yjjsV6uuB//xPVgp87z0+SPwipv9M5u/fti2zNme1tcAXvwh85zt8\nq+jvz24bNjFxZORCMbNKAN8CsAqAA/Cgc253FuYlxLjJVoJKfH2RJUuAq68GfvhDYP58Wr01NbSI\njx7lcT/4AV0ZZrTYi4t57t/8zfDEm9HwlnVPDxc1m5qAd9/l2LffTkEFkkefvPACj8s0rFCJPtEj\nUx/4NwD8P+fcPWZWDKA0C3MSIm/Ei9jWraxREvaxe6t70ybg7Fn6xwcGKKaxGDBjBotSjVdAvWU9\nfz7dJtdcw8+SEl6vvJzHJfP3A4kXctXm7NInbReKmVUAuAXAUwDgnOt1zrVnaV5C5BXv1962Ddi7\nl66T5mbgpZcYA370KKNQzCjeRUXD/efTpo0/gsNH0qxcyQXEri76oE+dGu6HTubvv/HG7IUVimiR\niQW+BEALgP9rZmsBvAHgs865C+GDzOwhAA8BwKJFizIYTojsMVr25vbttLD7+ijSjY38c46Ll/39\ntIx37WJiz+nTFPHBQQrn2bOMqW5o4DiJrPDw+EeOcBwfPvj668CBA3wI+NrdQFB29tvfZrbmwoX0\nty9enFqbs/jM0htvDKoiimhhzrn0TjTbAOBVABudc6+Z2TcAdDjnPp/snA0bNrg9e/akN1MhxmC8\nKfXh7M2KCgro/v30ec+fTwGvqKBPu6WFQlxQwGiUJUuY9LJwIb/39gInT9Iid47+8eJi4MorKYyF\nhSMXAxON//LLwM0385ovvMDjbr+dAt7aymsAw8/zQu337dhBq7+6evTf/pWvAIcP8/eZ8RorVgQL\ns2LyYWZvOOc2xG/PxAJvBNDonBt6juMnAP4yg+sJkRJhwS4q4ueyZWOn1IejOZqbWR+7sJAuiVOn\naN2uXUtxmzeP57z9Nj9jMSbjHDrEMU6coJjOmEGLPRYDli8Hrr9+ZBGqROMDwb7GxqDZ8LXXBmn6\n/hwgeQ2Uhx8ev7+9pYXXKR1asTLjW0Q6tVREfklbwJ1zp8zsuJmtcM4dArAJwLvZm5oQyYmvgfLc\nc0y8WbgwCLEDEotSOJrDp5D79mLO8dzjx7mYCNBNUlHBawP0iZ88SbGPxejCuPpqWs5VVcMXFBMt\nJiaKJlm6lHNwjvsKCkZew+8Lk+piZVMT3TUzZwbbSkp475R1GT0yjQP/HwC2mdl+AOsA/F3GMxJi\nHMTHRHtRqqsLjkmWCh5O2PEp5D55ZuZMWt0dHVxMdI7HlJXRcj1+nJ/nztFqr63lsYcPU8zPnWM4\n4K5dtO4TLSaOVssk3X3jpaZmeNYlwPkm6/ojJjcZCbhzbq9zboNzbo1z7hPOubZsTUyI0YivgeL9\nue3twbZk4haO5qispEB3dTEKZOVKukKuvJIPhT17+FCoqACuu47CNzhIYZ8+nQWv3nuPXXOWLGGm\npnN0T7z5ZuJsxtGyR9PdN142b+YbRWsrs0u7urjoGs4sFdFBmZgiksRboytXUpSKi8cWNx/N8dZb\n9GW/9x6vN3cuRXnFCrpE2tv5+cADtLaPH6clXlvLCBGfIdnbS6u8p4fHx2L8d7JyrKNlj6a7b7zU\n1gKf+xxw223BHG+/XQuYUSXtKJR0UBSKyBbxkRwdHYzmuPxyhvSNFYnhz+3uZsnWI0dYMnb
ePOCK\nK2iVL1jA85ubgWef5ba+PlrXZrRci4tpcff18fx77w0WL32dEnW7EZkyEVEoQuSN+LT36urxW5Hx\nTQzKy7mI+P77QUODujoKdm8vrfTKyqCg1fHjFPG+Pgq5GcP/ZsygFe+t//Xrs9tsQoh4JOAisqRb\nu8NHgfzbv9G67u+nuBYW0oXyi18Ei3xNTRzDW90nT/LYixcp6gMD3F5VRT+8f5h88pPZbX02XtTf\ncmohARdTjpoaukzefZfWd2lpUEr1yBH6t2traYWfPk2XihmjXRYvZrJPXR3FvqyMlnhLC+ty/8Ef\nBIL53e9mr/XZeBhvezmJ/KWDFjHFlMHXN3n7beCZZyhwJ08y1NA5inRXV5CYs2QJXSP19Qw13LiR\nC6Tl5bSmy8povZeWUgiXLRvegmy8YX/Zqic+nvZy8W3jUi09KyYXssBFZBiP5ZjsmLB1WlVFF8nA\nAAW4sDCI4e7v5/6uLgr67/wO8MYbFN133+XfhQu00hcsYPd4f368de273bS0MMvy9Gme96lPUahT\nzSAdi/376X/3Mew+Mmb6dF7/jjtGunUuXqSP/4//GPjYx2SNRw0JuIgE43EPJDvGF4A6fZruj44O\ninRnJ4UOoJBNmwbMmUMrOxbjIuT06dy/fz/PKyjgtosXg5jsqio2QjhyhEL96KNBmOHx44wH9xZ9\nLMbGxjffzIXTVDJIR7s327bxWqWlfDOor+eDZfZszuPzQxWKwlmgzc1MOPL+fW+Na5E1OsiFIiLB\neNwDiY4ZGGBlwebmwPI+cIDhf2fPUpR9ESqANVBuvZUC69uKzZ7Nh8CSJWw9tn49t7W304KeMYPX\nefllimNREfDii0ytP3eOSUGVlcBVV9EinjWLPvSXX6Y7p72dFQg9qTQT9g+tffv4QOjt5VvCtGn8\njefOcfusWXyIhd06vowAwAXYRPdUTG4k4CISjKf7fKJjGhvpFikrAw4epOB1dQHvvMPPwUGKvHP0\nd8+fPzJRpq+P1/Vp9xUVLDa1ahWwZg1FsrGRol9by7T6WbMo8kePBr016+oo1gUFFNmeHu7r6+P3\n5mbOOZX0eP/Q6u1lDPtVV/E3XbzIB0lFReCzb2wcns3Z2spjfRZqonsqJjdyoYhIMJ7u8/HHNDfT\nfeFT330STm8vLfHp0ylYXrwXL2bp2CeeSDz2zJkU3ViM59fU0GK/5Zbhrgmfou/x5/jt+/YF3Xuq\nq/lgKSnhp7f6k9XyjseP65sgl5fTDdTdzfkWFQVz8olJPn7eZ5Ju3BhUPlQjiGghC1xEgvHUAQkf\nc/IksHNn4B6JxQKXQmEhXQzTpgX1u2MxiniixGR/3epq1jppbaXVWl0dzCHsmvBi2tND98mFC7Te\nKyp4Tlsbxzt0iH8DAxTd48dTT4/344a7+dTU8K3j3Dla5a2t/HvwQZ5TW8vys9/4BssGTJ+efm0V\nkV+USi/ySioxyf5Y37Rg2TJazOFzAR7z7LO0aq+4Avj5zynSnZ2BkAIUuIoKClhhISNKbrsNeOyx\nkWNv304f8uHDtFyXL6fLJD7KZWCA7pB33uF1P/xhjr1/P33Rq1dz4fD11+lK8Y0jLl5MPPZY9ye+\nLMDevVysXbKEYtzRQcv7wQcZhTLWPVUUyuREqfRi0jHexBOPz7z0URff+x5f/detGx5B8fDDgWuh\noIDi+c47FF6/uOcXIIuKgkiUefOA++9PPM+dO7l4eeutQQef/fuHz23TJi6Y9vdTqH01wy1bgG9+\nM/hNp0/THz17Nh8yPT1Bf81U70/YJdLZybFSEWF1oo82EnCRN8ZKNU9kfQIUsUOHKN4FBaxnsnFj\nEEFRWzvcH15Wxs/iYopkUxNFtriYfvH+/qDWyVjzjO/gE35wvPIKfewXLzIyZeVKWuHl5cNFsq+P\nFQAPHw784rffznmkcn88EuGpiwRc5BwvzNu20W1x1VX
BIlpFBQXysceA558faWGXlARRF5WVgdVa\nV0eXhk+k8Uk0AIXxiiu4r6wsWLTs7aXvt7KS3XfOngX+6I/oflizJrBkfahfezuTY2bODMIIvahu\n28b5zp8fLHbu2gXcdBPnHqamhttuvXX4PfEx5P5hlahzz0Sm4ovooUVMkVPCqdwLF1IEffcagO6J\nhgZGaoQt7IsXKdyvvUYR8wuFHR3AsWPc/txzQTx3uHa2rxZ47730XS9ezGsUFPCvv58RIL6gVdiy\n3r6dc/K9Kjs7KaBnzgTiXVHB8f1DyIyLlGVl9EkXFg5PlV+2bPiCbH19EEMeTm8vKsq8A4+4tJGA\ni5wSdgtcfTUFrKCAAtreTr/ymjW0jktLAyH0XXHMgqiLlhZu7+6mQLe1MfrE1/VIFG3hI0BaWmiJ\nl5RQwJuaaDUfO0aL2z8wvv1tzmdwkONUVPD448eD2GkvsuvWBZEgzvGco0c5p3DtkZ07OZ+33gK+\n/30m/FxzDecbTlJyLvMOPOLSRgIucko42aaqir7rGTOCELolSxit4S1sIGg43NEB3HADRczHcBcX\nBzHP5eW03r/wheHFmcLWuHP0da9axfN7eniMWZBOX1TEt4Lubs5r6VLOMxbjQ8VHj4Rrf994I/f7\n43xrt+rq4PeEs0N/+EMuin7604FbxL+FANw2MJB5Bx5xaSMfuMgp8ck2VVUU41tuobW8dWtgYe/a\nFZxXXEyh/Mxn+H3HDoYBrllDIT10iMJbVUUhfPxxdufp7w98yg8/zM8HHuA5PqW9vZ2C2dvLBcZl\ny2hFP/MMz3/uOWZeep+191eHa38DQcTIzTfzN/gwvvjsUL+I6u/B5ZfzzaCubmRCjRYoxWhIwEVO\nCS8u+lZo4cxDv3/WLC4A+rjmLVuA++4bHj4H0CWxbx8t49JSWs2lpRT006eBO+/kGJ//PI89d45/\nAK32FSsonP39tMZragK/dEcHxfzAAcZJ33UXrevCQuBLXxoprPEdgnxTh/gM0lOnAqEG+LD69a+5\nfXBw5D0RIhkScDEq2S7+n6gVmhcqX2I1FqMFPDDAVHUzWsa+yJIf34u9F8TubrpHioqClmkFBRTf\nV16hC2LtWm5raGD0y3XX0W3S1kYfdksLk2zOn+c1a2u57ehR+q7vvXfkfMP3JdG9iX9gFRUNjy6p\nquKbRLxVL8tbjIUyMUVSEjUO9m6MbIpLfT1dHqdP0wouKWFSzT33UDRHG7++nj7vo0eD5gpnz9It\nMW8e3R7//M8U6K4uhv/FYrR0AWZKFhZyoXHpUo7zta9x/9VXU/Q7OjhOSwsFf9kyulv88WPdl0QZ\npGP9LiHCKBNTpEw2ezqOZsk//TRdHrNnc7yeHn7/+tcpwKONX1vLNPHPf55+7cpKjrFnT3BMYyOt\neV/zpK+P1nksBnz5yyPnN20ahdaL98GDfAAUFPDh8txzHMdneo51XxJZ5osXj3wLkXiLVJGAi6Rk\nK5FkrJTwnTu57cwZhgzOn89j9+8HPvrR4DrNzRTTY8f43T8EGhq4cHjiBOfmHK/jqw6eP0+hvfxy\numOKi4OOO56wyBYWsp53Vxctc5+hOX8+r1tQwGuHFx1TvS9anBTZQAIukjKeEq5j4V0czc0U0JUr\nA9HzPu3GxiCh5swZCvSSJYxO8eP77jEFBQw7fP551kK58076wNeupYgePkwfc2kphdentfvuO+3t\nQR3w5csTz/n++/kwaGmh5d3fzwXPK6/k/ooKLoT6UMF07osQ2UACLpIyVsTIWHjL27cy8+nlGzfS\n3dHURBGvqqI1G4sF7crefRf4yEc4HkDLu6AgSEufO5fXfPllXmf3blrFPT1BPeyiIsZnV1QwmqW9\nncdcdhn94729zLSMr2hYWwt87nOcW08Pxbyqii4VgGLe1kZLXlEjIp8okUckJZwAk04iifehe/EO\nZ1V6i7WpKVjMKyy
kf7qoKKh5EosxU/G11+jyqKykeJeW8tj33uP35mb6ubu7Ka6Dg0zVr6sL3EDr\n1wMf+hCwYQP97QsXcvE0UYd2n8X5xBN8CPT20nL3Vv369YxaUYKNyCeywMWoJPLVjje00PvQw0k5\nJSV0eSxcGMRJ795NF0hzM0Vy2jS6PfbsAT7xCeDjHw+a/zoXCHJjI4V/4ULg/fd5Tl8ffd6rV/NB\nsGcPxfqyy2jB+wJU69bRqg8n1CRbJH3kES60vvYax7/tNrpZJNgi30jARUqkUsPb+9B9ynxdHcV7\n3rzhx//Lv1BIly2jpX7hAq1p32gXYCbkzp0U4FmzKMYdHawh0tPDa/qWYU1NdMXs38/z77yTLpS2\nNtbs9j74F14YnlADBIuR8Q+p++4DvvjFCbutQqSFXCgiJbxb5OJF+p9ffJEhf08/PfLYcIuzuXNp\nZV933fAsRm/hDgxQ3KdPp/V84QKtZE9VFWtmz5hBS31wkJZ9fz+P/eAHab23tdE1M20aFx3vuYeL\np9dey+v4Hpm+oUM4yqa5mZb+Sy+xrOyePaxE+JOf8Pv27RNzT4VIFyXyTGHSybJ89FEK3+7d9GeX\nlNDv3NwMfOc7yd0tY7Xsij/u9Gn6v30ESl0dQ/qqqhj33dDAqoFHjjCLcenSoFPOsmXAb35D3/Xl\nlwdjnDwJvPoqfeDxCTXd3bTIAbpbOjroolm6lNZ9WxsfMuHOOkLkimSJPBLwSUa2U9dHGyedLMut\nWxnCV1BAgQWC8qlbtnDhLxts384Fxo6OwIIvLaV1XlgYzDPZA2LrVi5KhkMg29u54Bieo2/P9uMf\n05pfvpxjdnUFdVWWL+fvO3WKFn22fqMQ4yWZgMuFMokINzuIj4rINuEsy3ANah+bnYzNmwMXhnNB\n/et16yii2cD3oFyzhouSfX2MD1+xguIcnqePFvnyl/kZrpMS7lD/y1/S197cPPJ+9vQwUuYDH2Bo\n4JkzTMf3xbH8MVVV2fuNQmQDLWJOIrKZuj4W6WZZ1tbS0t63L4joWL8+cDn84R/S3VBZObwtWTIS\nvXGE78OBAxTu7m4Ka3MzY8SPH+f5ya7vQyCffppvDPPmcTEzFhu+6BoOdfQWd00N3TO+a70vkrVk\niZJ1xOQiYwE3s0IAewA0OefuznxKU5dc9kDMJMvy/vtpkXr3y5EjXNBctYr/LiykwNbXB9mS4VKw\nHu8mOXeO7ouiIlrJc+YwNR4I+kvGYrwPZ8/yjWHRouENhZOJ+Lx5DEMM/04geCgmCnWcM4eC3t7O\n3zJ9OsW7sFDdcMTkIhsulM8COJiF60x5vKiGSSd1Pdx/MZn7JexiSLVdV3yCT2MjBde3Nisq4iLk\nmTMUwr17R7qCfAXC7m4KcU8Pf2t3NyNFjhzhcStX0vptawtqmgwOshTseNw+4Q5AnoqKwBUSH+oY\ni9HKX7kSePJJNn9YupQCrmQdMdnIyAI3swUAPgLgbwH8z6zMaAqTrdT18cRoJ6vLPR6Bind7nD1L\nkTtwgBZzQwNdEQMDQXsxL7T++jt2BE0Uiov519fH7zNmMJpk7lz+rV7N776DfHwX+9HeUOLfNJqb\n+YDo6QkaDP/4x6x7cvEire2lS5lKX1sL3HHH+O59qvdsohanxdQiUxfK1wH8OYDyZAeY2UMAHgKA\nRYsWZTjcpU0mogqk7kNPpyJeoofEkSMMKfTuDp9NGYvxe2XlSKFtaqJ7o6mJgg0EzYXnzKGgNzby\nYeAcu/MAfDCM1+1TX883geeeo+AvXMgHAcCY8s5Oinf8W0+4SmE2SOXBKkQqpC3gZnY3gNPOuTfM\n7LZkxznnngTwJMAwwnTHmypkUmY0Fz70RA+JNWuCbvL79weNgxcvDhJy4oW2poYW7759PMZ3ziks\nDGK3X3+drpmlS4MHhXO0msd6QwmLps/E/NnPgnK1hw8Hne0BNnbwtLdnd+E4l4vTY
mqRiQW+EcDH\nzOwuACUAKszs+865T2dnaiJVUl2YTOe1PtFDYunSIErDW98nTlAs162jWyIstN4y3rUrSNTp6KDl\nW1sbdIYvLuZ1amt53NKl9JF73/tobyhh0fShkm++ycSjmpqgMmJXF/32YbL90BvtwSrXisiEtAXc\nOfdXAP4KAIYs8Eck3vklFR96uq/1yR4Sq1ePTJDxCTbl5YHQJrKMu7sp3jU1QUTIq69yjHDN7YoK\nFq2aN4+W+GjEi2ZdXeDiMQuSkE6eZLGrMIkeemGhLSri+OGO9+ncs2nT5FoRmaE48EuIVHzo6b7W\nJ3pINDRQpB59dOwGv7592sWLFNRrr6V/+623GE8enk9bG48BaKX/6ld0o5w4Qct+tDDCeNFsa6Nv\n/b33gHfeofgWFrJM7Lx5fFAke+iFHzpFRUHKvfejJ5uDF/34lH8/hi+ZK9eKSJesCLhz7kUAL2bj\nWiIzxutDzySRJ/yQKCwMLNqqqtGtyPp6LijOnx9Yw7t2UcBnzQqaN1RU8OFz5AirDZ48SdE8cYIC\naMbxKyvpAnn//eEFsgD6yR9/nJEt8+fzs6OD4w4M8JiBAVrf99wTNHVI9NALP+z27eM5ztGPfuut\nwTHhc8Kiv2YN3Un799PVtHo1x/judxOHOE5E3L+4NJEFPkXJJJEn/JDYunV4ZMhoVqTvvgMMd2Ps\n3cvsTp+F2dREf/qWLRTWZ5+l8DpHa7mzkxb5+fMMKWxuHv7QCKfiNzZyf28vBbuqKnh4eDFtaBi9\nvkn4YRd+K/DunUSiG/+GU1vLkMhwLZZstKwTUxsJ+BSkvp6i9vzzFLR16yimY8WcJ1pwS8WSb2ri\nWLt383ssxtju5ubEbpf6eoprZydFs6KCwnvqFC3a/n66YnwTZP/QCIunv157O/CLXwR+dd/Uwbd2\nG42w0Hrxdy4Q3kSiO577kmncvxAqZjXF8K/2paVcRAQo5N3doy+eJSu0VVQ0/uzRmhqKts949Bbs\nli2J3S1+vIULeey5c4xe6eiggE6bRit65crh2ZXJsi99VMzHP07Xh3f5jGXxhrNWly9n4lJrK/+d\nLIN1PFm1mbasE0IW+BQjPrzu8suDTu07dtAvmyiyYts2Lj729vK8lSt5na6u4b7r0azIzZuBr3xl\neNbj3LmsrTLaPK++mr5y3xato4OujDVrgOuvpxC3twfimMw1ccMN459rmLDfv7OTi5c+CiUcYRP/\nW8djXWcS9y+EBHyKkejVvrubVvgnPjEynA1g5Mi3vsXFu0WLeLxffBwYSC17ND7LMVnWY3ievk7J\nu+/Sp33rrYzpvnCBfS07OriY6sUxmXj635NOpmuqQptpVq0Q40ECPsVIZJ3u3UuRjF+IfPppirXv\nwXHyJF0YvpCUX3wcr7jt2MEokg98INiWLOsxfp5VVbTYV6ygD/qmm4IFyrNn2ZYt3KZtNPHMlYjK\nuhYTjQQ8z+Q6Ey+RddrcHPjDPRUV7Hd5zTVsT1ZSQvdBfz8FfdUqWs+plFdNZcEzmRUdiyVeoGxo\nGF54SuIppgJaxMwjuezA40m0cOYbHYTxC4WNjbR8Ozu58FlSwpjqhgZa0qmIZCrlcpMt8PX18a3g\npZdYO/yll/h9rE454y2zK0SUkAWeR/JV5ChRuF4ia7e2lok3ra10W1x2GUU8FqPPedas1MZNNWwu\nkRXtMyFnzw5C+l54gQuLyVA1QHGpIgs8j4zVbCBXJLJ2N22iOPrel2VlTJw5f577V60KMhozGSdV\nEfU1UJJ9JiLd/p9CTHZkgeeQeH/3tGljZ+LlykcebrTQ1MS6IwsWAHfdBfzoR/R9z5hB0b3iCob2\npZMxmKlvur+f1vbhw0FCzu23c3syctmqTohcIgHPEYle40+eHL2+dS5f/ePH2r2bc/mt3wJ+93fZ\nv3JwkBbs6tXj6w85EQ+fmhr6430NEoBCXp60p
YhS1sWli1woOSLRa/zSpRSXZC6FTF79U120ix9r\n/nyKdF0dI1HuvZfhg4sWja8/5EQt0KbTyzP+HF8rZf9+LWiKaCMBzxHJ/N39/Sxu9OUv8zMsiun6\nyNMRz/ixVq6k2J06xU8fg/3EEyPnmYiJ8jun40cPn/P220H3oLVrcxP5I8REIRdKjkjnNT7dV/90\nolsSJc6sXs0wwnQyCSfS75yOH92fs3UrffiqwS0uBWSB54hsvPqP5xwgPcs90ViFhayznejtYCxS\nifnOJZMl8keIbCABzxGZvvqnEnaXjnhmuzJeug+fiWayPliESAdzYzUXzCIbNmxwe3xhDTFhhCNK\n4os5TYSbIFm0SbgvZnX15GjYm+t7I0Q2MLM3nHMbRmyXgF+a5Eo8oyiIk/HBIsRoJBNwLWJeouSq\nmFO+ygFkggpdiUsF+cBFRmhRUIj8IQEXGaFFQSHyhwRcZMRkjTYRYiogARcZoca8QuQPLWKKjNGi\noBD5QRa4EEJEFAm4EEJEFAm4EEJEFAm4EEJEFAm4EEJEFAm4EEJEFAm4EEJEFAm4EEJEFAm4EEJE\nlLQF3MwWmtkLZnbQzA6Y2WezOTEhhBCjk0kqfT+AP3POvWlm5QDeMLPtzrl3szQ3IYQQo5C2Be6c\nO+mce3Po350ADgKoydbEhBBCjE5WilmZ2WIA6wG8lmDfQwAeAoBFixZlYziRIcl6WAohokXGi5hm\ndhmAnwL4E+dcR/x+59yTzrkNzrkNc+fOzXQ4kSG+h2VnJ7BgAT+feorbhRDRIiMBN7MiULy3Oeee\nyc6UxEQS7mFZUMDPWbO4XQgRLTKJQjEATwE46Jz7avamJCYS9bAU4tIhEwt8I4DfB/DbZrZ36O+u\nLM1LTBDqYSnEpUPai5jOuV8DsCzOReSAzZvp8wZoeXd0sIflJz+Z33kJIVJHmZhTDPWwFOLSQT0x\npyDqYSnEpYEscCGEiCgScCGEiCgScCGEiCgScCGEiCgScCGEiCgScCGEiCgScCGEiCgScCGEiCgS\ncCGEiCgScCGEiCgScCGEiCgScCGEiCgScCGEiCgScCGEiCgScCGEiCgScCGEiCgScCGEiCgScCGE\niCgScCGEiCgScCGEiCgScCGEiCgScCGEiCgScCGEiCgScCGEiCgScCGEiCgScCGEiCgScCGEiCgS\ncCGEiCgScCGEiCgScCGEiCgScCGEiCgScCGEiCgScCGEiCgZCbiZfcjMDplZg5n9ZbYmJYQQYmzS\nFnAzKwTwDwA+DOBqAPea2dXZmpgQQojRycQCvx5Ag3PuN865XgA/BPDx7ExLCCHEWEzL4NwaAMdD\n3xsB3BB/kJk9BOChoa/nzexQBmNmgzkAzuR5DpMF3YsA3YsA3YuAyXIvrki0MRMBtwTb3IgNzj0J\n4MkMxskqZrbHObch3/OYDOheBOheBOheBEz2e5GJC6URwMLQ9wUATmQ2HSGEEOMlEwH/TwC1Znal\nmRUD+BSAZ7MzLSGEEGORtgvFOddvZn8E4DkAhQC+7Zw7kLWZTRyTxp0zCdC9CNC9CNC9CJjU98Kc\nG+G2FkIIEQGUiSmEEBFFAi6EEBFlSgu4mT1iZs7M5uR7LvnCzL5iZnVmtt/MfmZmlfmeU65RSQhi\nZgvN7AUzO2hmB8zss/meU74xs0Ize8vMfpHvuSRiygq4mS0EcAeAY/meS57ZDmCVc24NgMMA/irP\n88kpKgkxjH4Af+acuwrAjQD++xS+F57PAjiY70kkY8oKOICvAfhzJEg+mko45553zvUPfX0VjOef\nSqgkxBDOuZPOuTeH/t0JCldNfmeVP8xsAYCPAPhWvueSjCkp4Gb2MQBNzrl9+Z7LJONBAL/M9yRy\nTKKSEFNWtDxmthjAegCv5Xkq+eTroJE3mOd5JCWTVPpJjZntADA/wa6/BvC/AGzJ7Yzyx2j3wjn3\nr0PH/DX4C
r0tl3ObBIyrJMRUwswuA/BTAH/inOvI93zygZndDeC0c+4NM7stz9NJyiUr4M65zYm2\nm9lqAFcC2GdmAF0Gb5rZ9c65UzmcYs5Idi88ZvYAgLsBbHJTLzFAJSFCmFkRKN7bnHPP5Hs+eWQj\ngI+Z2V0ASgBUmNn3nXOfzvO8hjHlE3nM7CiADc65yVBxLOeY2YcAfBXArc65lnzPJ9eY2TRw8XYT\ngCawRMR9EckqzipGi+a7AFqdc3+S5+lMGoYs8Eecc3fneSojmJI+cDGMbwIoB7DdzPaa2T/me0K5\nZGgB15eEOAjgR1NRvIfYCOD3Afz20P+FvUMWqJikTHkLXAghoooscCGEiCgScCGEiCgScCGEiCgS\ncCGEiCgScCGEiCgScCGEiCgScCGEiCj/H2Li1YPt41w8AAAAAElFTkSuQmCC\n", "text/plain": [ "
" ] }, "metadata": { "tags": [] }, "output_type": "display_data" } ], "source": [ "covariance_matrix = [[1., .7], [.7, 1.]]\n", "nd = tfd.MultivariateNormalTriL(\n", " loc = [0., 5], scale_tril = tf.linalg.cholesky(covariance_matrix))\n", "data = nd.sample(200)\n", "plt.scatter(data[:, 0], data[:, 1], color='blue', alpha=0.4)\n", "plt.axis([-5, 5, 0, 10])\n", "plt.title(\"Data set\")\n", "plt.show()" ] }, { "cell_type": "markdown", "metadata": { "id": "57lLzC7MQV-9" }, "source": [ "## 多个分布" ] }, { "cell_type": "markdown", "metadata": { "id": "aRYY7-KvQupB" }, "source": [ "我们介绍的第一个伯努利分布表示公平地抛掷一枚硬币。我们也可以在一个 `Distribution` 对象中创建一批独立的伯努利分布,每个分布具有自己的参数:" ] }, { "cell_type": "code", "execution_count": 19, "metadata": { "id": "as9fo-XtRAFo" }, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 19, "metadata": { "tags": [] }, "output_type": "execute_result" } ], "source": [ "b3 = tfd.Bernoulli(probs=[.3, .5, .7])\n", "b3" ] }, { "cell_type": "markdown", "metadata": { "id": "x_7t57XzRGVD" }, "source": [ "重要的是弄清楚其具体含义。上述调用定义了三个独立的伯努利分布,它们恰好都在同一个 Python `Distribution` 对象中。这三个分布无法分别操作。请注意,`batch_shape` 为 `(3,)`,表明该批次包含三个分布,`event_shape` 为 `()`,表明各个分布具有一元事件空间。\n", "\n", "如果我们调用 `sample`,将得到所有这三个分布的样本:" ] }, { "cell_type": "code", "execution_count": 20, "metadata": { "id": "bQQJ_N7XRkuh" }, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 20, "metadata": { "tags": [] }, "output_type": "execute_result" } ], "source": [ "b3.sample()" ] }, { "cell_type": "code", "execution_count": 21, "metadata": { "id": "aM6JOl3HSQb3" }, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 21, "metadata": { "tags": [] }, "output_type": "execute_result" } ], "source": [ "b3.sample(6)" ] }, { "cell_type": "markdown", "metadata": { "id": "7NRbaUyLR2yf" }, "source": [ "如果调用 `prob`(其形状语义与 `log_prob` 相同;为了清楚起见,我们将 `prob` 与这些小的伯努利示例结合使用,但是 `log_prob` 通常更适合在应用中使用),我们可以向其传递一个向量,并计算抛掷每个硬币时得到该值的概率:" ] }, { "cell_type": "code", "execution_count": 22, 
"metadata": { "id": "UKRV_z47NUV9" }, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 22, "metadata": { "tags": [] }, "output_type": "execute_result" } ], "source": [ "b3.prob([1, 1, 0])" ] }, { "cell_type": "markdown", "metadata": { "id": "Y3MexqrtREPP" }, "source": [ "API 为何会包含批次形状?从语义上来讲,API 可以创建一组分布并使用 `for` 循环(至少在 Eager 模式下如此;在 TF 计算图模式下,需要使用 `tf.while` 循环)对这些分布进行迭代,从而执行相同的计算。但是,极为常见的情况是一组(可能较大的)分布采用相同的参数设置;为了能够使用硬件加速器快速进行计算,关键要素是尽可能使用向量化计算。" ] }, { "cell_type": "markdown", "metadata": { "id": "t52ptQXvUO07" }, "source": [ "## 使用 Independent 将批次汇总到事件" ] }, { "cell_type": "markdown", "metadata": { "id": "oN3mut1NTOXX" }, "source": [ "在上一部分中,我们创建了单个 `Distribution` 对象 `b3`,它表示抛掷硬币三次。如果我们在向量 $v$ 上调用 `b3.prob`,第 $i$ 个条目就是抛掷第 $i$ 枚硬币时得到的值为 $v[i]$ 的概率。\n", "\n", "假设我们改为对同一底层系列中的独立随机变量指定“联合”分布。这是不同的数学对象,因为在这个新分布中,向量 $v$ 的 `prob` 将返回单个值,表示抛掷整组硬币时得到的值均为向量 $v$ 的概率。\n", "\n", "我们如何实现这个目标呢?我们使用名为 `Independent` 的“高阶”分布,它获取一个分布,并得到一个批次形状移动至事件形状的新分布:" ] }, { "cell_type": "code", "execution_count": 23, "metadata": { "id": "V_DcGAG2Tqxj" }, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 23, "metadata": { "tags": [] }, "output_type": "execute_result" } ], "source": [ "b3_joint = tfd.Independent(b3, reinterpreted_batch_ndims=1)\n", "b3_joint" ] }, { "cell_type": "markdown", "metadata": { "id": "Zkv5TRVFVLUo" }, "source": [ "将该形状与原始 `b3` 的形状进行比较:" ] }, { "cell_type": "code", "execution_count": 24, "metadata": { "id": "5bBsLX-6VT36" }, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 24, "metadata": { "tags": [] }, "output_type": "execute_result" } ], "source": [ "b3" ] }, { "cell_type": "markdown", "metadata": { "id": "0uveNoPNVVYy" }, "source": [ "如前所述,我们发现 `Independent` 已将批次形状移动至事件形状:`b3_joint` 是三维事件空间 (`event_shape = (3,)`) 上的单一分布 (`batch_shape = ()`)。\n", "\n", "我们检查语义:" ] }, { "cell_type": "code", "execution_count": 25, "metadata": { "id": "eDsO2gLcVlY9" }, "outputs": [ { "data": { "text/plain": [ "" ] }, 
"execution_count": 25, "metadata": { "tags": [] }, "output_type": "execute_result" } ], "source": [ "b3_joint.prob([1, 1, 0])" ] }, { "cell_type": "markdown", "metadata": { "id": "IktKInQ5WQJz" }, "source": [ "通过另一种方式也可以获得相同的结果,也就是使用 `b3` 计算概率,并通过乘法(或者采用更常见的做法,即使用对数概率,求和)手动简化:" ] }, { "cell_type": "code", "execution_count": 26, "metadata": { "id": "rRIVEchSV-RZ" }, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 26, "metadata": { "tags": [] }, "output_type": "execute_result" } ], "source": [ "tf.reduce_prod(b3.prob([1, 1, 0]))" ] }, { "cell_type": "markdown", "metadata": { "id": "ikayH3d2Wcf-" }, "source": [ "借助 `Indpendent`,用户可以更明确地表示所需的概念。我们将其视为极为有用的标记,但是它并不完全必要。" ] }, { "cell_type": "markdown", "metadata": { "id": "wVivnv1qWi9f" }, "source": [ "以下事实较为有趣:\n", "\n", "- `b3.sample` 和 `b3_joint.sample` 采用不同的概念实现,但输出无区别:在计算概率时,使用 `Independent` 基于批次生成的一批独立分布与单个分布之间存在差异,但在抽样时,这两种分布之间没有差别。\n", "- 可以使用标量 `Normal` 和 `Independent` 分布轻松实现 `MultivariateNormalDiag`(实际上不会以这种方式来实现它,但可以这么做)。\n" ] }, { "cell_type": "markdown", "metadata": { "id": "INu1viAVXz93" }, "source": [ "## 多元分布的批次" ] }, { "cell_type": "markdown", "metadata": { "id": "G_cEhLU-Tjhm" }, "source": [ "我们创建一个批次,其中包含三个完全协方差的二维多元正态分布: " ] }, { "cell_type": "code", "execution_count": 27, "metadata": { "id": "mtxwqizfTwKi" }, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 27, "metadata": { "tags": [] }, "output_type": "execute_result" } ], "source": [ "covariance_matrix = [[[1., .1], [.1, 1.]], \n", " [[1., .3], [.3, 1.]],\n", " [[1., .5], [.5, 1.]]]\n", "nd_batch = tfd.MultivariateNormalTriL(\n", " loc = [[0., 0.], [1., 1.], [2., 2.]],\n", " scale_tril = tf.linalg.cholesky(covariance_matrix))\n", "nd_batch" ] }, { "cell_type": "markdown", "metadata": { "id": "osDjz1vXUVkr" }, "source": [ "我们发现 `batch_shape = (3,)`,因此,有三个独立的多元正态分布;`event_shape = (2,)`,因此,每个多元正态分布均为二维分布。在本例中,单个分布没有独立的元素。\n", "\n", "抽样如下:" ] }, { "cell_type": "code", "execution_count": 28, "metadata": { "id": 
"82u32RUpYKeK" }, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 28, "metadata": { "tags": [] }, "output_type": "execute_result" } ], "source": [ "nd_batch.sample(4)" ] }, { "cell_type": "markdown", "metadata": { "id": "2I-cYckNYTmf" }, "source": [ "由于 `batch_shape = (3,)` 并且 `event_shape = (2,)`,我们将形状张量 `(3, 2)` 传递给 `log_prob`:" ] }, { "cell_type": "code", "execution_count": 29, "metadata": { "id": "-7p02_66YRpX" }, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 29, "metadata": { "tags": [] }, "output_type": "execute_result" } ], "source": [ "nd_batch.log_prob([[0., 0.], [1., 1.], [2., 2.]])" ] }, { "cell_type": "markdown", "metadata": { "id": "72uiME85SmEH" }, "source": [ "## 广播(也可以说这为何让人如此困惑?)" ] }, { "cell_type": "markdown", "metadata": { "id": "3aWnXjyYZYtp" }, "source": [ "到目前为止,我们所执行的操作可以抽象概括为,每个分布都有一个批次形状 `B` 和一个事件形状 `E`。令 `BE` 作为事件形状的串联:\n", "\n", "- 对于一元标量分布 `n` 和 `b`,`BE = ()`。\n", "- 对于二维多元正态分布 `nd`,`BE = (2)`。\n", "- 对于 `b3` 和 `b3_joint`,`BE = (3)`。\n", "- 对于多元正态分布批次 `ndb`,`BE = (3, 2)`。\n", "\n", "到目前为止,我们使用的“计算规则”如下:\n", "\n", "- 没有参数的样本将返回形状为 `BE` 的张量;标量为 n 的抽样将返回张量“n * `BE`”。\n", "- `prob` 和 `log_prob` 使用形状为 `BE` 的张量,并返回形状为 `B` 的结果。\n", "\n", "`prob` 和 `log_prob` 的实际“计算规则”更复杂,虽然性能和速度可能不错,但同时也增加了复杂性和挑战。实际规则(本质上)是 {nbsp}`log_prob` 的参数必须可以根据 BE 进行广播;输出中会预留“额外”维度。 " ] }, { "cell_type": "markdown", "metadata": { "id": "iwv81UjpmlkX" }, "source": [ "我们来探索具体含义。对于一元正态分布 `n`,`BE = ()`,因此,`log_prob` 应该为标量。如果我们向 `log_prob` 传递具有非空形状的张量,它们在输出中将显示为批次维度:" ] }, { "cell_type": "code", "execution_count": 30, "metadata": { "id": "xRMkZd2cnqnG" }, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 30, "metadata": { "tags": [] }, "output_type": "execute_result" } ], "source": [ "n = tfd.Normal(loc=0., scale=1.)\n", "n" ] }, { "cell_type": "code", "execution_count": 31, "metadata": { "id": "mci1cs1NnLDb" }, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 31, "metadata": { "tags": [] }, 
"output_type": "execute_result" } ], "source": [ "n.log_prob(0.)" ] }, { "cell_type": "code", "execution_count": 32, "metadata": { "id": "MQW1XSB1nRlH" }, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 32, "metadata": { "tags": [] }, "output_type": "execute_result" } ], "source": [ "n.log_prob([0.])" ] }, { "cell_type": "code", "execution_count": 33, "metadata": { "id": "z-6d3PtTnT1W" }, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 33, "metadata": { "tags": [] }, "output_type": "execute_result" } ], "source": [ "n.log_prob([[0., 1.], [-1., 2.]])" ] }, { "cell_type": "markdown", "metadata": { "id": "6BkE19lEh9XY" }, "source": [ "我们转到二维多元正态分布 `nd`(为了便于阐述,对参数进行了更改):" ] }, { "cell_type": "code", "execution_count": 34, "metadata": { "id": "Y1D3zg9kn8HJ" }, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 34, "metadata": { "tags": [] }, "output_type": "execute_result" } ], "source": [ "nd = tfd.MultivariateNormalDiag(loc=[0., 1.], scale_diag=[1., 1.])\n", "nd" ] }, { "cell_type": "markdown", "metadata": { "id": "SyZS-on4oCR4" }, "source": [ "`log_prob`“应该”使用形状为 `(2,)` 的参数,但它将接受根据此形状广播的任何参数: " ] }, { "cell_type": "code", "execution_count": 35, "metadata": { "id": "RHyn5rV7oMzq" }, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 35, "metadata": { "tags": [] }, "output_type": "execute_result" } ], "source": [ "nd.log_prob([0., 0.])" ] }, { "cell_type": "markdown", "metadata": { "id": "DTnAETFGo17O" }, "source": [ "不过,我们可以传入“更多”示例,并同时计算所有 `log_prob`:" ] }, { "cell_type": "code", "execution_count": 36, "metadata": { "id": "-eSm6Hnlo1sn" }, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 36, "metadata": { "tags": [] }, "output_type": "execute_result" } ], "source": [ "nd.log_prob([[0., 0.],\n", " [1., 1.],\n", " [2., 2.]])" ] }, { "cell_type": "markdown", "metadata": { "id": "dgxneFROpG7L" }, "source": [ "说服力可能还不够强,我们可以在事件维度上广播:" ] }, { "cell_type": "code", "execution_count": 
37, "metadata": { "id": "-YRxLZLcoW29" }, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 37, "metadata": { "tags": [] }, "output_type": "execute_result" } ], "source": [ "nd.log_prob([0.])" ] }, { "cell_type": "code", "execution_count": 38, "metadata": { "id": "Md6RkXrcpNiK" }, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 38, "metadata": { "tags": [] }, "output_type": "execute_result" } ], "source": [ "nd.log_prob([[0.], [1.], [2.]])" ] }, { "cell_type": "markdown", "metadata": { "id": "266h1o2KoZZL" }, "source": [ "Broadcasting this way is a consequence of our \"enable broadcasting whenever possible\" design; this usage is somewhat controversial and could potentially be removed in a future version of TFP.\n", "\n", "Now let's look at the three coins example again:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "mKHtmSP6SnvY" }, "outputs": [], "source": [ "b3 = tfd.Bernoulli(probs=[.3, .5, .7])" ] }, { "cell_type": "markdown", "metadata": { "id": "bGOJAgv_p059" }, "source": [ "Here, using broadcasting to represent the probability that *each* coin comes up heads is quite intuitive:" ] }, { "cell_type": "code", "execution_count": 40, "metadata": { "id": "ZYC6J8-dp50r" }, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 40, "metadata": { "tags": [] }, "output_type": "execute_result" } ], "source": [ "b3.prob([1])" ] }, { "cell_type": "markdown", "metadata": { "id": "5gxdAEjBiLiw" }, "source": [ "(Compare this to `b3.prob([1., 1., 1.])`, which we used earlier when `b3` was introduced.)\n", "\n", "Now suppose we want to know, for each coin, the probability that the coin comes up heads *and* the probability that it comes up tails. We could imagine trying:\n", "\n", "`b3.log_prob([0, 1])`\n", "\n", "Unfortunately, this produces an error with a long and not very readable stack trace. `b3` has `BE = (3,)`, so we must pass `b3.prob` something broadcastable against `(3,)`. `[0, 1]` has shape `(2,)`, so it does not broadcast and raises an error. Instead, we have to write:" ] }, { "cell_type": "code", "execution_count": 41, "metadata": { "id": "_ry9LMiIieUx" }, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 41, "metadata": { "tags": [] }, "output_type": "execute_result" } ], "source": [ "b3.prob([[0], [1]])" ] }, { "cell_type": "markdown", "metadata": { "id": "mxZ1WeK1qRcc" }, "source": [ "Why? `[[0], [1]]` has shape `(2, 1)`, so it broadcasts against shape `(3,)` to produce a broadcast shape of `(2, 3)`." ] }, { "cell_type": 
"markdown", "metadata": { "id": "WJBxD-zOrLDQ" }, "source": [ "广播非常有用:在有些情况下,广播可以大幅减少所使用的内存量,并且通常可以缩短用户代码。但是,广播对程序来说是一项挑战。如果调用 `log_prob` 并得到错误,问题几乎总是广播失败。" ] }, { "cell_type": "markdown", "metadata": { "id": "JpjjIGThrj8Q" }, "source": [ "## 延伸内容\n", "\n", "在本教程中,我们进行了简单介绍(希望如此)。作为延伸内容,有以下几点建议:\n", "\n", "- `event_shape`、`batch_shape` 和 `sample_shape` 可以是任意秩(在本教程中,它们始终为标量或秩 1)。这可以提高性能,但也会为编程带来挑战,特别是在涉及到广播时。要想深入了解形状操作,请参阅 [了解 TensorFlow Distributions 形状](https://github.com/tensorflow/probability/blob/main/tensorflow_probability/examples/jupyter_notebooks/Understanding_TensorFlow_Distributions_Shapes.ipynb)。\n", "- TFP 包含一个强大的抽象概念 `Bijectors`,将该概念与 `TransformedDistribution` 结合使用,可以轻松灵活地创建新的分布,这些分布是现有分布的可逆转换。我们很快会尝试编写相关的教程,但目前可以参阅[本文档](https://tensorflow.google.cn/probability/api_docs/python/tfp/distributions/TransformedDistribution)。\n" ] } ], "metadata": { "colab": { "collapsed_sections": [], "name": "TensorFlow_Distributions_Tutorial.ipynb", "toc_visible": true }, "kernelspec": { "display_name": "Python 3", "name": "python3" } }, "nbformat": 4, "nbformat_minor": 0 }