{ "cells": [ { "cell_type": "markdown", "metadata": { "id": "htW5SiGzeXYm" }, "source": [ "##### Copyright 2018 The TensorFlow Probability Authors.\n", "\n", "Licensed under the Apache License, Version 2.0 (the \"License\");" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "cellView": "form", "id": "9HGeUNoteaSm" }, "outputs": [], "source": [ "#@title Licensed under the Apache License, Version 2.0 (the \"License\"); { display-mode: \"form\" }\n", "# you may not use this file except in compliance with the License.\n", "# You may obtain a copy of the License at\n", "#\n", "# https://www.apache.org/licenses/LICENSE-2.0\n", "#\n", "# Unless required by applicable law or agreed to in writing, software\n", "# distributed under the License is distributed on an \"AS IS\" BASIS,\n", "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", "# See the License for the specific language governing permissions and\n", "# limitations under the License." ] }, { "cell_type": "markdown", "metadata": { "id": "JJ3UDciDVcB5" }, "source": [ "# Bayesian Gaussian Mixture Model and Hamiltonian MCMC\n", "\n", "
" ] }, { "cell_type": "markdown", "metadata": { "id": "lin40yCC6eBo" }, "source": [ "In this Colab we'll explore sampling from the posterior of a Bayesian Gaussian Mixture Model (BGMM) using only TensorFlow Probability primitives." ] }, { "cell_type": "markdown", "metadata": { "id": "eZs1ShikNBK2" }, "source": [ "## Model" ] }, { "cell_type": "markdown", "metadata": { "id": "7JjokKMbk2hJ" }, "source": [ "For $k\\in\\{1,\\ldots, K\\}$ mixture components, each of dimension $D$, we'd like to model $i\\in\\{1,\\ldots,N\\}$ iid samples using the following Bayesian Gaussian Mixture Model:\n", "\n", "$$\\begin{align*}\n", "\\theta &\\sim \\text{Dirichlet}(\\text{concentration}=\\alpha_0)\\\\\n", "\\mu_k &\\sim \\text{Normal}(\\text{loc}=\\mu_{0k}, \\text{scale}=I_D)\\\\\n", "T_k &\\sim \\text{Wishart}(\\text{df}=5, \\text{scale}=I_D)\\\\\n", "Z_i &\\sim \\text{Categorical}(\\text{probs}=\\theta)\\\\\n", "Y_i &\\sim \\text{Normal}(\\text{loc}=\\mu_{z_i}, \\text{scale}=T_{z_i}^{-1/2})\\\\\n", "\\end{align*}$$" ] }, { "cell_type": "markdown", "metadata": { "id": "iySRABi0qZnQ" }, "source": [ "Note that the `scale` arguments all have `cholesky` semantics. We use this convention because it is that of TF Distributions (which itself uses this convention in part because it is computationally advantageous)." ] }, { "cell_type": "markdown", "metadata": { "id": "Y6X_Beihwzyi" }, "source": [ "Our goal is to generate samples from the posterior:\n", "\n", "$$p\\left(\\theta, \\{\\mu_k, T_k\\}_{k=1}^K \\Big| \\{y_i\\}_{i=1}^N, \\alpha_0, \\{\\mu_{0k}\\}_{k=1}^K\\right)$$\n", "\n", "Notice that $\\{Z_i\\}_{i=1}^N$ is not present: we're interested only in those random variables that don't scale with $N$. (Luckily, there is a TF distribution that handles marginalizing out $Z_i$.)\n", "\n", "It is not possible to directly sample from this distribution owing to a computationally intractable normalization term.\n", "\n", "[Metropolis-Hastings algorithms](https://en.wikipedia.org/wiki/Metropolis%E2%80%93Hastings_algorithm) are a technique for sampling from intractable-to-normalize distributions.\n", "\n", "TensorFlow Probability offers a number of MCMC options, including several based on Metropolis-Hastings. In this notebook, we'll use [Hamiltonian Monte Carlo](https://en.wikipedia.org/wiki/Hamiltonian_Monte_Carlo) (`tfp.mcmc.HamiltonianMonteCarlo`). HMC is often a good choice because it can converge rapidly, samples the state space jointly (as opposed to coordinatewise), and leverages one of TF's virtues: automatic differentiation. That said, sampling from a BGMM posterior might actually be better done by other approaches, e.g., [Gibbs sampling](https://en.wikipedia.org/wiki/Gibbs_sampling)." ] }, { "cell_type": "code", "execution_count": null, "metadata": 
{ "id": "uswTWdgNu46j" }, "outputs": [], "source": [ "%matplotlib inline\n", "\n", "\n", "import functools\n", "\n", "import matplotlib.pyplot as plt; plt.style.use('ggplot')\n", "import numpy as np\n", "import seaborn as sns; sns.set_context('notebook')\n", "\n", "import tensorflow.compat.v2 as tf\n", "tf.enable_v2_behavior()\n", "import tensorflow_probability as tfp\n", "\n", "tfd = tfp.distributions\n", "tfb = tfp.bijectors\n", "\n", "physical_devices = tf.config.experimental.list_physical_devices('GPU')\n", "if len(physical_devices) > 0:\n", "  tf.config.experimental.set_memory_growth(physical_devices[0], True)" ] }, { "cell_type": "markdown", "metadata": { "id": "Uj9uHZN2yUqz" }, "source": [ "Before actually building the model, we need to define a new type of distribution. From the model specification above, it's clear we're parameterizing the MVN with an inverse covariance matrix, i.e., the [precision matrix](https://en.wikipedia.org/wiki/Precision_(statistics%29)). To accomplish this in TF, we'll need to roll out our own `Bijector`. This `Bijector` will use the forward transformation:\n", "\n", "- `Y = tf.linalg.triangular_solve(tf.linalg.matrix_transpose(chol_precision_tril), X, adjoint=True) + loc`\n", "\n", "The `log_prob` calculation is just the inverse, i.e.:\n", "\n", "- `X = tf.linalg.matmul(chol_precision_tril, Y - loc, adjoint_a=True)`\n", "\n", "Since all we need for HMC is `log_prob`, this means we avoid ever calling `tf.linalg.triangular_solve` (as would be the case for `tfd.MultivariateNormalTriL`). This is advantageous because `tf.linalg.matmul` is usually faster owing to better cache locality.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "nc4yy6vW-lC_" }, "outputs": [], "source": [ "class MVNCholPrecisionTriL(tfd.TransformedDistribution):\n", "  \"\"\"MVN from loc and (Cholesky) precision matrix.\"\"\"\n", "\n", "  def __init__(self, loc, chol_precision_tril, name=None):\n", "    super(MVNCholPrecisionTriL, self).__init__(\n", "        distribution=tfd.Independent(tfd.Normal(tf.zeros_like(loc),\n", "                                                scale=tf.ones_like(loc)),\n", "                                     reinterpreted_batch_ndims=1),\n", "        bijector=tfb.Chain([\n", "            tfb.Shift(shift=loc),\n", "            tfb.Invert(tfb.ScaleMatvecTriL(scale_tril=chol_precision_tril,\n", "                                           adjoint=True)),\n", "        ]),\n", "        name=name)" ] }, { "cell_type": "markdown", "metadata": { "id": 
"JDOkWhDQg4ZG" }, "source": [ "The `tfd.Independent` distribution turns independent draws of one distribution into a multivariate distribution with statistically independent coordinates. In terms of computing `log_prob`, this \"meta-distribution\" manifests as a simple sum over the event dimension(s).\n", "\n", "Also notice that we took the `adjoint` (\"transpose\") of the scale matrix. This is because if precision is inverse covariance, i.e., $P=C^{-1}$, and if $C=AA^\\top$, then $P=BB^{\\top}$ where $B=A^{-\\top}$." ] }, { "cell_type": "markdown", "metadata": { "id": "Pfkc8cmhh2Qz" }, "source": [ "Since this distribution is somewhat tricky, let's quickly verify that our `MVNCholPrecisionTriL` works as we expect." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "GhqbjwlIh1Vn" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "true mean: [ 1. -1.]\n", "sample mean: [ 1.0002806 -1.000105 ]\n", "true cov:\n", " [[ 1.0625 -0.03125 ]\n", " [-0.03125 0.015625]]\n", "sample cov:\n", " [[ 1.0641273 -0.03126175]\n", " [-0.03126175 0.01559312]]\n" ] } ], "source": [ "def compute_sample_stats(d, seed=42, n=int(1e6)):\n", "  x = d.sample(n, seed=seed)\n", "  sample_mean = tf.reduce_mean(x, axis=0, keepdims=True)\n", "  s = x - sample_mean\n", "  sample_cov = tf.linalg.matmul(s, s, adjoint_a=True) / tf.cast(n, s.dtype)\n", "  sample_scale = tf.linalg.cholesky(sample_cov)\n", "  sample_mean = sample_mean[0]\n", "  return [\n", "      sample_mean,\n", "      sample_cov,\n", "      sample_scale,\n", "  ]\n", "\n", "dtype = np.float32\n", "true_loc = np.array([1., -1.], dtype=dtype)\n", "true_chol_precision = np.array([[1., 0.],\n", "                                [2., 8.]],\n", "                               dtype=dtype)\n", "true_precision = np.matmul(true_chol_precision, true_chol_precision.T)\n", "true_cov = np.linalg.inv(true_precision)\n", "\n", "d = MVNCholPrecisionTriL(\n", "    loc=true_loc,\n", "    chol_precision_tril=true_chol_precision)\n", "\n", "[sample_mean, sample_cov, sample_scale] = [\n", "    t.numpy() for t in compute_sample_stats(d)]\n", "\n", "print('true mean:', true_loc)\n", "print('sample mean:', sample_mean)\n", "print('true cov:\\n', true_cov)\n", "print('sample cov:\\n', sample_cov)" ] }, { "cell_type": "markdown", "metadata": { "id": "N60z8scN1v6E" }, "source": [ "Since the sample mean and covariance are close to the true mean and covariance, it seems the distribution is correctly implemented. Now, we'll use 
`MVNCholPrecisionTriL` and `tfp.distributions.JointDistributionNamed` to specify the BGMM model. For the observational model, we'll use `tfd.MixtureSameFamily` to automatically integrate out the $\\{Z_i\\}_{i=1}^N$ draws." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "xhzxySDjL2-S" }, "outputs": [], "source": [ "dtype = np.float64\n", "dims = 2\n", "components = 3\n", "num_samples = 1000" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "xAOmHhZ7LzDQ" }, "outputs": [], "source": [ "bgmm = tfd.JointDistributionNamed(dict(\n", "    mix_probs=tfd.Dirichlet(\n", "        concentration=np.ones(components, dtype) / 10.),\n", "    loc=tfd.Independent(\n", "        tfd.Normal(\n", "            loc=np.stack([\n", "                -np.ones(dims, dtype),\n", "                np.zeros(dims, dtype),\n", "                np.ones(dims, dtype),\n", "            ]),\n", "            scale=tf.ones([components, dims], dtype)),\n", "        reinterpreted_batch_ndims=2),\n", "    precision=tfd.Independent(\n", "        tfd.WishartTriL(\n", "            df=5,\n", "            scale_tril=np.stack([np.eye(dims, dtype=dtype)]*components),\n", "            input_output_cholesky=True),\n", "        reinterpreted_batch_ndims=1),\n", "    s=lambda mix_probs, loc, precision: tfd.Sample(tfd.MixtureSameFamily(\n", "        mixture_distribution=tfd.Categorical(probs=mix_probs),\n", "        components_distribution=MVNCholPrecisionTriL(\n", "            loc=loc,\n", "            chol_precision_tril=precision)),\n", "        sample_shape=num_samples)\n", "))" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "CpLnRJr2TXYD" }, "outputs": [], "source": [ "def joint_log_prob(observations, mix_probs, loc, chol_precision):\n", "  \"\"\"BGMM with priors: loc=Normal, precision=Wishart, mix=Dirichlet.\n", "\n", "  Args:\n", "    observations: `[n, d]`-shaped `Tensor` representing Bayesian Gaussian\n", "      Mixture model draws. 
Each sample is a length-`d` vector.\n", "    mix_probs: `[K]`-shaped `Tensor` representing random draw from\n", "      `Dirichlet` prior.\n", "    loc: `[K, d]`-shaped `Tensor` representing the location parameter of the\n", "      `K` components.\n", "    chol_precision: `[K, d, d]`-shaped `Tensor` representing `K` lower\n", "      triangular `cholesky(Precision)` matrices, each being sampled from\n", "      a Wishart distribution.\n", "\n", "  Returns:\n", "    log_prob: `Tensor` representing joint log-density over all inputs.\n", "  \"\"\"\n", "  return bgmm.log_prob(\n", "      mix_probs=mix_probs, loc=loc, precision=chol_precision, s=observations)" ] }, { "cell_type": "markdown", "metadata": { "id": "7jTMXdymV1QJ" }, "source": [ "## Generate \"Training\" Data" ] }, { "cell_type": "markdown", "metadata": { "id": "rl4brz3G3pS7" }, "source": [ "For this demo, we'll sample some random data." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "1AJZAtwXV8RQ" }, "outputs": [], "source": [ "true_loc = np.array([[-2., -2],\n", "                     [0, 0],\n", "                     [2, 2]], dtype)\n", "random = np.random.RandomState(seed=43)\n", "\n", "true_hidden_component = random.randint(0, components, num_samples)\n", "observations = (true_loc[true_hidden_component] +\n", "                random.randn(num_samples, dims).astype(dtype))" ] }, { "cell_type": "markdown", "metadata": { "id": "zVOvMh7MV37A" }, "source": [ "## Bayesian Inference using HMC" ] }, { "cell_type": "markdown", "metadata": { "id": "cdN3iKFT32Jp" }, "source": [ "Now that we've used TFD to specify our model and obtained some observed data, we have all the necessary pieces to run HMC.\n", "\n", "To do this, we'll use a [partial application](https://en.wikipedia.org/wiki/Partial_application) to \"pin down\" the things we don't want to sample. In this case that means we need only pin down `observations`. (The hyperparameters are already baked into the prior distributions and are not part of the `joint_log_prob` function signature.)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "tVoaDFSf7L_j" }, "outputs": [], "source": [ "unnormalized_posterior_log_prob = functools.partial(joint_log_prob, observations)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "a0OMIWIYeMmQ" }, "outputs": [], "source": [ "initial_state = [\n", "    
tf.fill([components],\n", "            value=np.array(1. / components, dtype),\n", "            name='mix_probs'),\n", "    tf.constant(np.array([[-2., -2],\n", "                          [0, 0],\n", "                          [2, 2]], dtype),\n", "                name='loc'),\n", "    tf.linalg.eye(dims, batch_shape=[components], dtype=dtype, name='chol_precision'),\n", "]" ] }, { "cell_type": "markdown", "metadata": { "id": "TVpiT3LLyfcO" }, "source": [ "### Unconstrained Representation" ] }, { "cell_type": "markdown", "metadata": { "id": "JS8XOsxiyiBV" }, "source": [ "Hamiltonian Monte Carlo (HMC) requires the target log-probability function to be differentiable with respect to its arguments. Furthermore, HMC can exhibit dramatically higher statistical efficiency if the state space is unconstrained.\n", "\n", "This means we'll have to work out two main issues when sampling from the BGMM posterior:\n", "\n", "1. $\\theta$ represents a discrete probability vector, i.e., it must satisfy $\\sum_{k=1}^K \\theta_k = 1$ and $\\theta_k>0$.\n", "2. $T_k$ represents an inverse covariance matrix, i.e., it must satisfy $T_k \\succ 0$ (that is, be [positive definite](https://en.wikipedia.org/wiki/Positive-definite_matrix)).\n" ] }, { "cell_type": "markdown", "metadata": { "id": "Vt9SXJzO0Cks" }, "source": [ "To address this requirement we'll need to:\n", "\n", "1. transform the constrained variables to an unconstrained space\n", "2. run the MCMC in the unconstrained space\n", "3. transform the unconstrained variables back to the constrained space\n", "\n", "As with `MVNCholPrecisionTriL`, we'll use [`Bijector`](https://tensorflow.google.cn/probability/api_docs/python/tfp/bijectors/Bijector)s to transform random variables to unconstrained space.\n", "\n", "- The [`Dirichlet`](https://en.wikipedia.org/wiki/Dirichlet_distribution) is transformed to unconstrained space via the [softmax function](https://en.wikipedia.org/wiki/Softmax_function).\n", "\n", "- Our precision random variable is a distribution over positive semidefinite matrices. To unconstrain these we'll use the `FillTriangular` and `TransformDiagonal` bijectors. These convert vectors to lower-triangular matrices and ensure the diagonal is positive. The former is useful because it enables sampling only $d(d+1)/2$ floats rather than $d^2$." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "_atEQrDR7JvG" }, "outputs": [], "source": [ "unconstraining_bijectors = [\n", "    tfb.SoftmaxCentered(),\n", "    tfb.Identity(),\n", "    tfb.Chain([\n", "        tfb.TransformDiagonal(tfb.Softplus()),\n", "        tfb.FillTriangular(),\n", "    ])]" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "0zq6QJJ-NSPJ" }, "outputs": [], "source": [ "@tf.function(autograph=False)\n", "def sample():\n", "  return tfp.mcmc.sample_chain(\n", "      num_results=2000,\n", "      num_burnin_steps=500,\n", "      current_state=initial_state,\n", "      
kernel=tfp.mcmc.SimpleStepSizeAdaptation(\n", "          tfp.mcmc.TransformedTransitionKernel(\n", "              inner_kernel=tfp.mcmc.HamiltonianMonteCarlo(\n", "                  target_log_prob_fn=unnormalized_posterior_log_prob,\n", "                  step_size=0.065,\n", "                  num_leapfrog_steps=5),\n", "              bijector=unconstraining_bijectors),\n", "          num_adaptation_steps=400),\n", "      trace_fn=lambda _, pkr: pkr.inner_results.inner_results.is_accepted)\n", "\n", "[mix_probs, loc, chol_precision], is_accepted = sample()" ] }, { "cell_type": "markdown", "metadata": { "id": "QLEz96mg6fpZ" }, "source": [ "With the chain executed, we'll now compute and print the acceptance rate and the posterior means." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "_ceX1A3-ZFiN" }, "outputs": [], "source": [ "acceptance_rate = tf.reduce_mean(tf.cast(is_accepted, dtype=tf.float32)).numpy()\n", "mean_mix_probs = tf.reduce_mean(mix_probs, axis=0).numpy()\n", "mean_loc = tf.reduce_mean(loc, axis=0).numpy()\n", "mean_chol_precision = tf.reduce_mean(chol_precision, axis=0).numpy()\n", "precision = tf.linalg.matmul(chol_precision, chol_precision, transpose_b=True)\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "bqJ6RSJxegC6" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "acceptance_rate: 0.5305\n", "avg mix probs: [0.25248723 0.60729516 0.1402176 ]\n", "avg loc:\n", " [[-1.96466753 -2.12047249]\n", " [ 0.27628865 0.22944732]\n", " [ 2.06461244 2.54216122]]\n", "avg chol(precision):\n", " [[[ 1.05105032 0. ]\n", " [ 0.12699955 1.06553113]]\n", "\n", " [[ 0.76058015 0. ]\n", " [-0.50332767 0.77947431]]\n", "\n", " [[ 1.22770457 0. 
]\n", " [ 0.70670027 1.50914164]]]\n" ] } ], "source": [ "print('acceptance_rate:', acceptance_rate)\n", "print('avg mix probs:', mean_mix_probs)\n", "print('avg loc:\\n', mean_loc)\n", "print('avg chol(precision):\\n', mean_chol_precision)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "zFOU0j9kPdUy" }, "outputs": [ { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXcAAAEQCAYAAABLMTQcAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAG4dJREFUeJzt3Xts1fX9x/HXuRTanrbQcsqpxy4C4iXoJFRATWEzJArK\ngktWzP6YDpbNyIrzHpyTS+d0Y/7QKBbYENHBZGCNaOzMXDL/oCBWyhJlmbuJGigczmmp0gNSaM/v\nj3pOzzk9p6el355z+jnPR2Lo+Z5vP99PP5HXefP+XmoLhUIhAQCMYs/0BAAA1iPcAcBAhDsAGIhw\nBwADEe4AYCDCHQAMRLgjZ7zyyiuqrq5WVVWVvvjii5j3jh49qiuvvFI9PT0jOod58+bpvffeG9Fj\nABLhjhEUH2SNjY2aPXu2Dhw4EAnTqqoqVVVVac6cObr77ru1b9++fmNMnz5dVVVVmjFjhqqqqvSr\nX/1qyHM5f/681q5dq61bt+rgwYMaN25cv31sNtvQf0ggSzkzPQHkhtdff11r167V5s2bNX36dB09\nelQ2m00tLS2y2Wxqa2tTY2OjamtrtXr1an33u9+NfO/vfvc7XX/99cM6fiAQUFdXly699NLh/igj\npru7Ww6HI9PTgCGo3DHidu7cqd/+9rd68cUXNX369Jj3wjdIT5gwQXfeeafuuecePfXUUwn3SaWr\nq0tPPPGE5s6dq29961t68sknde7cOX366ae65ZZbJEmzZs3SkiVLUo514sQJLVu2TNddd53mz5+v\nV199NfJeT0+PNm3apJtuuknXXnutvve978nn8yUcZ/fu3Zo3b56uv/56bdq0Kea9559/Xj/72c/0\n8MMPa+bMmXr99df14Ycf6vvf/75mzZqluXPn6vHHH9f58+clSevXr4/8q+X8+fOaMWOG/u///k+S\ndPbsWV1zzTU6deqUurq69PDDD+u6667TrFmztHjxYrW3tw9qDWEOwh0jaseOHVq/fr1efvllTZs2\nLeX+N910k9ra2vTJJ58M+VgbN27URx99pDfffFNvvPGGPvzwQ23cuFGTJk3SW2+9JUlqaWnRSy+9\nlHKsBx54QF6vV01NTXr22Wf19NNPa//+/ZKkF198UX/+85/1wgsvqKWlRU8++aTy8/P7jfHf//5X\ndXV1euqpp7Rnzx51dHToxIkTMfv87W9/0y233KIDBw5o0aJFcjqdevTRR9Xc3KydO3dq//79euWV\nVyT1fjA1NzdLkj766CO53W598MEHkqSDBw9qypQpKi4u1uuvv67Ozk7t2bNHzc3Nqqur09ixY4e8\nnhjdCHeMqH379mn69Om6/PLLB7W/x+ORpJgTnrW1tZo9e7ZmzZql2bNnx1TR0d566y3V1taqtLRU\npaWlWr58uXbv3i2pr/ofzL8Cjh07pr///e966KGHlJeXpyuvvFKLFy/WG2+8IUlqaGjQ/fffr0su\nuUSSdMUVVyTs4f/lL3/RvHnzdO211yovL0/33ntvv31mzJihefPmSZLGjBmjadOm6ZprrpHNZpPX\n69Xtt98eCfAZM2bos88+0xdffKEPPvhANTU18vl8OnPmjA4cOKBZs2ZJkpxOpzo6OnT4
8GHZbDZN\nmzZNLpcr5c8Ns9Bzx4iqq6vThg0b9Oijj+rJJ59MuX+4vTF+/PjItg0bNgyq537ixAl5vd7Ia6/X\nK7/fL2loJ0v9fr/GjRungoKCmLH+8Y9/SJKOHz+ub3zjG4OaT0VFReR1QUFBzM8lKeZ9Sfr000/1\nm9/8RocOHdJXX32l7u5uXXXVVZKksWPH6uqrr1Zzc7MOHDigZcuW6eOPP1ZLS4uam5t15513SpJu\nu+02HT9+XA888IBOnTqlRYsW6f7776efn2Oo3DGiysrK9NJLL6mlpUVr1qxJuf8777wjt9utyZMn\nR7YNtufu8Xh09OjRyOvW1lZNnDhxyHOeOHGivvjiC50+fTqy7dixY5GxKioq9Pnnn6ccp7y8XMeP\nH4+8PnPmjDo6OmL2if/QWbNmjaZMmaK//vWvOnDggO67776Yn3/mzJnav3+//vnPf+qb3/ymZs6c\nqaamJh06dEgzZ86U1Fu519bWqrGxUX/605/07rvvRv4Fg9xBuGPElZeX6+WXX1ZTU5N+/etfR7aH\nQqFIcLW1tWn79u3asGGDHnzwwQs6zq233qqNGzeqvb1d7e3t2rBhg2677baY4w0k/H5FRYVmzJih\np59+Wl1dXfr444/V0NCgRYsWSZIWL16sZ599Vp999pkk6V//+le/6+YlacGCBXr33Xd18OBBnTt3\nTs8991zKnyEYDKqoqEgFBQX63//+px07dsS8P3v2bO3evVtTp06V0+nUddddp1dffVWVlZUqLS2V\nJL3//vv697//rZ6eHhUWFsrpdFK156Cca8sEg0E1NjZq4cKFOd2HTMc6RFelFRUVeumll3THHXco\nPz9ft99+u2w2m2bNmqVQKKTCwkJdffXVeu6551RdXR0zzrJly2S399Uh1dXVWr9+fb/j/fSnP1Uw\nGNSiRYtks9l0yy236O677044n2jhtYh+f926dVq9erXmzp2rcePG6d5779UNN9wgSVq6dKnOnTun\nH/3oR+ro6NCUKVP0/PPP9+u7T506VatWrdKDDz6oM2fOaOnSpZFzCsmsWLFCK1eu1AsvvKBp06Zp\n4cKFkRO5Um/f/ezZs5H++tSpU5Wfnx95LfVe9rl69Wr5fD65XC7deuutkQ+mVPj70cuIdQjlGJ/P\nF1q8eHHI5/NleioZxTr0YS36sBa9TFgH2jIAYCDL2jJPPfWU/H6/bDab8vPztXTpUk2aNMmq4QEA\nQ2BZuC9fvjxy6diBAwe0ceNGrV271qrhAQBDYFlbJvqa4GAwGHMCLJvY7XaVl5dn7fzShXXow1r0\nYS16mbAOtlBokBcRD8KmTZv04YcfSpIeffRRVVZWxrwfDAYVDAZjtjmdTpWVlVk1BQDIKe3t7ZHn\nD4W5XC5rwz1sz549ampq0s9//vOY7bt27VJDQ0PMNo/Hk/CyNgBAavfcc0+/B9fV1NSMTLhL0g9+\n8ANt2rRJRUVFkW2JKne73S632y2fz6fu7u6RmEo/Xq9Xra2taTlWNmMd+rAWfViLXtm+Dg6HQx6P\nR4FAoN8vmXG5XNacUP3qq68UDAY1YcIESb0nVIuLi2OCPXzAUXtDAABkIbfbnXC7JeF+9uzZyK3a\nNptNxcXFWrFihRVDAwAugCXhPm7cOD3xxBNWDAUAsMDovc4HAJAU4Q4ABiLcAcBAOffIXwC57ezp\notQ7STr83y8lFWlsYefITmiEEO4AcsZgg32g7xktYU+4A8hZwc7zqXeK4ipyjpqwJ9wB5IT4UI4O\n9lQh7ypyDrBf/38NZEPgE+4Aclo4sIOd5xK+7yrKSxr+riJn5L3wB4DU90GSyZAn3AHkhLGFnRfU\nc08W+tHCIR8d8JnGpZAAckZ0JR0O4r4/8wb83tOd5/v9Fw7+ofbu0yF7PmYAIA2GWsGfjmvbBINf\nt2FcTgU7z+l00XkVxlXs4Q+Ms6czdykl4Q4AceID
XeoL9USvwyEfLHLG/Asg/uqadAY94Q4AUQYT\n7NESvRfsPBc5ERvbh09f0BPuABClsMip053nvw7n1CdTowU7z0WC/XRnb7sm/DpauKIfyYAn3AHk\nnHDfPdHVLcHOc/0C3uXq2y9ZFT9QdS8pSSU/crhaBkBOGlvYGamco6+cCVfZhUmuonG5nDH/Zavs\nnRkApNFAFXW4ko8WbtkkC/j4D4lUx7Aa4Q4gp8VfGtlbvcfelBT8ukXT+3XyPnz8JZF9Yya+hn4k\n++6EO4CcFx2w4aBPVmX3BX/sCddUN0ElG2+kAp5wB4AoiW5yig/mC33cQDoDnnAHgDiDuYs18ZU2\n5wd8P50IdwBIYPLUErW2tg56/2SXVg4GbRkAyFKJAjpZ9Z+OxxAQ7gAwQjL5PHduYgIAAxHuAGAg\nwh0ADES4A4CBCHcAMJAlV8t0dnZq/fr1OnHihJxOpyoqKnTXXXepuLjYiuEBAENk2aWQt912m6ZN\nmyZJ2r59u/74xz/q7rvvtmp4AMAQWNKWKSoqigS7JF122WUKBAJWDA0AuACW99xDoZDeeecdzZw5\n0+qhAQCDZPkdqlu2bFFBQYEWLFjQ771gMKhgMBizzW63y+12Wz0NAMgJgUBAPT09MdtcLpdsoVAo\nZNVBtm3bps8//1yPPPKIHA5Hv/d37dqlhoaGmG3l5eWqr6+3agoAkFNqa2vl9/tjttXU1FgX7jt2\n7NB//vMfPfLIIxozZkzCfQaq3H0+n7q7u62YSkper3dIT3szFevQh7Xow1r0yvZ1cDgc8ng8SSt3\nS9oyR44c0e7du+X1evXYY49JkiZOnKiHHnqo3wFdLpcVhwQASEnb2paEe2VlpXbu3GnFUAAAC3CH\nKgAYiHAHAAMR7gBgIMIdAAxEuAOAgQh3ADAQ4Q4ABiLcAcBAhDsAGIhwBwADEe4AYCDCHQAMRLgD\ngIEIdwAwEOEOAAYi3AHAQIQ7ABiIcAcAAxHuAGAgwh0ADES4A4CBCHcAMBDhDgAGItwBwECEOwAY\niHAHAAMR7gBgIMIdAAxEuAOAgQh3ADCQ06qBtm3bpvfff19+v1/r1q1TZWWlVUMDAIbIssp99uzZ\n+uUvf6ny8nKrhgQAXCDLKvcrrrhCkhQKhawaEgBwgSwL98EIBoMKBoMx2+x2u9xudzqnAQDGCAQC\n6unpidnmcrnSG+6NjY1qaGiI2VZeXq76+np5PJ50TkVerzetx8tWrEMf1qIPa9FrNKzD6tWr5ff7\nY7bV1NSkN9wXLlyoG2+8MWab3d7b9vf5fOru7k7LPLxer1pbW9NyrGzGOvRhLfqwFr2yfR0cDoc8\nHo/q6uoyX7m7XC65XK50HhIAjJasrW1ZuG/dulXNzc3q6OjQ448/rqKiIq1bt86q4QEAQ2BZuC9d\nulRLly61ajgAwDBwhyoAGIhwBwADEe4AYCDCHQAMRLgDgIEIdwAwEOEOAAYi3AHAQIQ7ABgorc+W\nATB47bailPuUhTrTMBOMRlTuQJZptxUNKtjD+wKJULkDWSI6qNtPn0u6X1lhXjqmg1GOcAcyKL7y\nHijUo/ch4JEK4Q5kQLIqveN0V9LvGV84ZkTnBLMQ7kCaDCbQT57pX7mXFiSv0jmhimQId2CEJQr1\ncKAnCvNUaMlgMAh3YISEQz1ZoJ9M0oIpHWT7haodAyHcAQslq9LjA70j7sTp+CTVeLglE+63h6t2\ngh2pEO6ABQYT6tGBfvJMl0oL+lfo4aqdUMdwEe7AMMSHenTrJbpKP3mmfwsmacAX5PULdYlgx9AQ\n7sAFGKifHl2lh0M9umqPb8GEX5cWjkkY7IQ6LgThDgxBolBP1Ho5eaarX189XmnBGI0vzIuEutTb\nhiHUYQXCHRiERO2XRK2XVIEeqdLjgp1qHVYj3IEUoqv16PbL4UAwJtDbE1zaWBZ1WeP4wjyqdaQN\n4Q4kMVC1Hl2p
t5/uUkeCm5HGR91ZmijYCXWMJMIdSCBZb32wwR5tfGGeJk9wJWzBEOoYKYQ7ECe+\nDRN/wjTRZY3RxhfkqaxwDNU6MopwB742ULXe+7p/hV5WOCamry4lb8H07k+wIz0Id+S8dluR2o99\nOaTHBUix16uHb0bimnVkC8vC/dixY6qvr1dnZ6eKi4tVW1uriooKq4YHRsSFVOvxd5XGB3rvNkId\nmWVZuG/evFkLFizQnDlztGfPHv3+97/XqlWrrBoesFSqO0wHElOx8ywYZClLwv3LL7/U4cOHVV1d\nLUmqrq7Wiy++qFOnTqm4uNiKQwCWiA91aXDPVh9MoEuEOrKHJeEeCARUVlYmm80mSbLb7SotLVVb\nWxvhjqxwob8wYyhVukSoI3uk9YRqMBhUMBiM2Wa32+V2u9M5DeSYoVbryX5ZBqGObBQIBNTT0xOz\nzeVyWRPubrdb7e3tCoVCstls6unp0cmTJzVhwoSY/RobG9XQ0BCzrby8XPX19fJ4PFZMZdC8Xm9a\nj5etTF+HQ8e+7LdtoF9CHW+g1oskXX1Ryddflcgkpv9/MVijYR1Wr14tv98fs62mpsaacC8pKdGk\nSZPU1NSkuXPnqqmpSZMnT+7Xklm4cKFuvPHGmG12u12S5PP51N3dbcV0UvJ6vWptbU3LsbKZ6esQ\n3YoZitKYxwYMXKW3tppXsZv+/8VgZfs6OBwOeTwe1dXVjVzlLkk/+clPVF9fr9dee00ul0vLly/v\nt4/L5ZLL5bLqkMAFGV84JlK9Rwd5/D4SrRdkv2RtbcvC3ev16oknnrBqOGDYykKdSav38Un66pHv\nJdQxynGHKnJGOLCjT6xGh3i//Ql1jGKEO4yWqHpPFOhXX1SS1f1VYKjsmZ4AMNIGqsDLQp1U6DAS\nlTtyAgGOXEPlDgAGItwBwECEOwAYiHAHAAMR7gBgIMIdAAxEuAOAgQh3ADAQ4Q4ABiLcAcBAhDsA\nGIhwBwADEe4AYCDCHQAMRLgDgIEIdwAwEOEOAAYi3AHAQIQ7ABiIcAcAAxHuAGAgwh0ADES4A4CB\nCHcAMBDhDgAGItwBwEDO4Q6wZ88evfnmmzpy5IiWLFmi+fPnWzEvAMAwDLtynzx5su677z7NmTPH\nivkAACww7Mq9srJSkmSz2YY9GQCANYYd7kMRDAYVDAZjttntdrnd7nROAwCMEQgE1NPTE7PN5XKl\nDvcVK1aora0tZlsoFJLNZtPmzZuHVLE3NjaqoaEhZlt5ebnq6+vl8XgGPY4VvF5vWo+XrViHPqxF\nH9ai12hYh9WrV8vv98dsq6mpkS0UCoWsOMCGDRt06aWXDnhCdaDK3efzqbu724qppOT1etXa2pqW\nY2Uz1qEPa9GHteiV7evgcDjk8XguvHIfilSfEy6XSy6Xy8pDAkBOS9bWHvbVMnv37tWyZcu0f/9+\n7dq1S8uWLdPRo0eHOywAYBiGXblXV1erurrairkAACzCHaoAYCDCHQAMRLgDgIEIdwAwEOEOAAYi\n3AHAQIQ7ABiIcAcAAxHuAGAgwh0ADES4A4CBCHcAMBDhDgAGItwBwECEOwAYiHAHAAMR7gBgIMId\nAAxEuAOAgQh3ADAQ4Q4ABiLcAcBAhDsAGIhwBwADEe4AYCDCHQAMRLgDgIEIdwAwkDPTE8iErw6+\np7IE29srLkn7XABgJAw73Lds2aJDhw4pLy9P+fn5WrJkiaZMmWLF3CxVdvyzmNfdgRMxrx3uif32\nCSP0AYw2ww73GTNmaOnSpbLb7Tp48KCeeeYZrV+/3oq5WSJVqMdvd7gnJh2DkAcwWgw73KuqqiJf\nX3755Wpvbx/ukJYJh3J0oJ/3+4Y0RqKwB4BsZ2nP/e23344J+0wZKNS7AwOHe3fAJ4fbI2e5JzJG\nOODLjn9G9Q5gVEgZ7itWrFBbW1vMtlAoJJvNps2bN8tms0mS9u7dq3379qmuri
7pWMFgUMFgMGab\n3W6X2+2+kLknFB/s5/2+foE+UMA73J7I9xHwALJdIBBQT09PzDaXyyVbKBQKDXfw5uZmbd++XatW\nrRowqHft2qWGhoaYbeXl5aqvrx/W8b86+F7M6/hgjw7zZMEeDnWH2xP5Ohzuvdtj2zP5VTcMa84A\nYIXa2lr5/f6YbTU1NcMP95aWFm3dulUrV66Ux+MZcN+BKnefz6fu7u4hHz/6hGmyNkx0wHf7+/Zx\nlPcFdnS4R/8ZDvj4cDehevd6vWptbc30NLICa9GHteiV7evgcDjk8XiSVu7D7rlv3LhReXl5evrp\npyPtmpUrV6qoqKjfvi6XSy6Xa7iHjIgP9viTpf2qdn/slTLd/hORgA/32sN/AsBokKxbMuxwf+GF\nF4Y7xJANdHljqhOm8RIFfPzXADDajLrHDyS60SjV5YpDDfzeMT1Gt2QAmG1UPX4g2R2kFyq65550\nH4IdwCg06ir3gTjLPQlbKQm3DRDs4aqdYAcwWo2qyn2wwidGw18n2p7s+5IFOwCMJkaEe7LnxUSL\nvsQxUcBHB3siVO0ARpNRFe7tFZek7Ls7yz0pnx8T36ZJdU07AIw2oyrcpdgKOhz0DvfEmOo9WfU9\nkIGCnaodwGgz6sI9WnQlP5hqO1lFzyWPAEwzqsNd6h/AQwn7wY4JAKPNqA/3eImCOVngJ/oFHQQ7\nABMYF+6JhAM7/mQsbRgApsqJcA8Lh3e2P+0NAIbLqDtUAQC9CHcAMBDhDgAGItwBwECEOwAYiHAH\nAAMR7gBgIMIdAAxEuAOAgQh3ADAQ4Q4ABiLcAcBAWfPgMLs9vZ8zDocjrcfLVqxDH9aiD2vRK5vX\nIVVm2kKhUChNcwEApEnOtWUCgYBqa2sVCAQyPZWMYh36sBZ9WIteJqxDzoV7T0+P/H6/enp6Mj2V\njGId+rAWfViLXiasQ86FOwDkAsIdAAxEuAOAgRxr1qxZk+lJpFteXp6uuuoqjRkzJtNTySjWoQ9r\n0Ye16DXa14FLIQHAQLRlAMBAhDsAGChrHj+Qblu2bNGhQ4eUl5en/Px8LVmyRFOmTMn0tNJuz549\nevPNN3XkyBEtWbJE8+fPz/SU0urYsWOqr69XZ2eniouLVVtbq4qKikxPK+22bdum999/X36/X+vW\nrVNlZWWmp5QxnZ2dWr9+vU6cOCGn06mKigrdddddKi4uzvTUhiQnT6iG/fCHP9TNN9+scePGaePG\njbr11lszPaW0czgcuuGGG3T69GmVlZVp6tSpmZ5SWj3zzDO66aabdNdddykvL0+vvfaavv3tb2d6\nWmmXn5+v73znO2pubtacOXNUUlKS6SllTFdXly666CLdcccduvnmm/XJJ5+opaVFM2fOzPTUhiRn\n2zJVVVWRB+9cfvnlam9vz/CMMqOyslIXX3yxbDZbpqeSdl9++aUOHz6s6upqSVJ1dbUOHz6sU6dO\nZXhm6XfFFVeorKxMXF8hFRUVadq0aZHXl1122ah8DEHOhnu0t99+W1VVVZmeBtIsEAiorKws8sFm\nt9tVWlqqtra2DM8M2SIUCumdd94ZdVW7ZHDPfcWKFf3+koZCIdlsNm3evDnyF3rv3r3at2+f6urq\nMjHNETfYdQDQ35YtW1RQUKAFCxZkeipDZmy4r127NuU+zc3N2rlzp1atWmVsj3Ew65Cr3G632tvb\nIx92PT09OnnypCZMmJDpqSELbNu2TT6fT4888kimp3JBcrYt09LSoj/84Q/6xS9+IbfbnenpZIVc\n67eWlJRo0qRJampqkiQ1NTVp8uTJo+6qCFhvx44dOnz4sB5++OGs/oUdA8nZO1R//OMfKy8vTyUl\nJZHKbeXKlSoqKsr01NJq79692r59u4LBoJxOp8aOHavHHntMF198caanlhatra2qr69XMBiUy+XS\n8uXLddFFF2V6Wmm3detWNTc3q6OjQyUlJS
oqKtK6desyPa2MOHLkiB588EF5vV7l5eVJkiZOnKiH\nHnoowzMbmpwNdwAwWc62ZQDAZIQ7ABiIcAcAAxHuAGAgwh0ADES4A4CBCHcAMBDhDgAG+n+UFdHi\na8/mXgAAAABJRU5ErkJggg==\n", "text/plain": [ "" ] }, "metadata": { "tags": [] }, "output_type": "display_data" } ], "source": [ "loc_ = loc.numpy()\n", "ax = sns.kdeplot(loc_[:,0,0], loc_[:,0,1], shade=True, shade_lowest=False)\n", "ax = sns.kdeplot(loc_[:,1,0], loc_[:,1,1], shade=True, shade_lowest=False)\n", "ax = sns.kdeplot(loc_[:,2,0], loc_[:,2,1], shade=True, shade_lowest=False)\n", "plt.title('KDE of loc draws');" ] }, { "cell_type": "markdown", "metadata": { "id": "NmfNIM1c6mwc" }, "source": [ "## Conclusion" ] }, { "cell_type": "markdown", "metadata": { "id": "t8LeIeMn6ot4" }, "source": [ "This simple Colab demonstrated how TensorFlow Probability primitives can be used to build hierarchical Bayesian mixture models." ] } ], "metadata": { "colab": { "collapsed_sections": [], "name": "Bayesian_Gaussian_Mixture_Model.ipynb", "toc_visible": true }, "kernelspec": { "display_name": "Python 3", "name": "python3" } }, "nbformat": 4, "nbformat_minor": 0 }