{ "cells": [ { "cell_type": "markdown", "metadata": { "id": "H9CL5JZzWuuj" }, "source": [ "##### Copyright 2019 The TensorFlow Probability Authors.\n", "\n", "Licensed under the Apache License, Version 2.0 (the \"License\");" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "cellView": "form", "id": "DcBheSMRXSq2" }, "outputs": [], "source": [ "#@title Licensed under the Apache License, Version 2.0 (the \"License\"); { display-mode: \"form\" }\n", "# you may not use this file except in compliance with the License.\n", "# You may obtain a copy of the License at\n", "#\n", "# https://www.apache.org/licenses/LICENSE-2.0\n", "#\n", "# Unless required by applicable law or agreed to in writing, software\n", "# distributed under the License is distributed on an \"AS IS\" BASIS,\n", "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", "# See the License for the specific language governing permissions and\n", "# limitations under the License." ] }, { "cell_type": "markdown", "metadata": { "id": "VYXS33IGWx4Q" }, "source": [ "# 可学习分布园地\n", "\n", "\n", " \n", " \n", " \n", " \n", "
在 TensorFlow.org 上查看 | 在 Google Colab 中运行 | 在 GitHub 上查看源代码 | 下载笔记本
" ] }, { "cell_type": "markdown", "metadata": { "id": "6m-yUlt3WP_n" }, "source": [ "在此 Colab 中,我们将展示构建可学习(“可训练”)分布的各种示例。(我们不会对这些分布进行说明,只会展示如何构建它们。)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "LDScDeCWUuFf" }, "outputs": [], "source": [ "import numpy as np\n", "import tensorflow.compat.v2 as tf\n", "import tensorflow_probability as tfp\n", "from tensorflow_probability.python.internal import prefer_static\n", "tfb = tfp.bijectors\n", "tfd = tfp.distributions\n", "tf.enable_v2_behavior()" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "bNpridm1U44X" }, "outputs": [], "source": [ "event_size = 4\n", "num_components = 3" ] }, { "cell_type": "markdown", "metadata": { "id": "vn_fp9oBV7NQ" }, "source": [ "## `chol(Cov)` 为缩放单位矩阵的可学习多元正态分布" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "ZED8qAvZU8yW" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "tfp.distributions.Independent(\"learnable_mvn_scaled_identity\", batch_shape=[], event_shape=[4], dtype=float32)\n", "(, )\n" ] } ], "source": [ "learnable_mvn_scaled_identity = tfd.Independent(\n", " tfd.Normal(\n", " loc=tf.Variable(tf.zeros(event_size), name='loc'),\n", " scale=tfp.util.TransformedVariable(\n", " tf.ones([1]),\n", " bijector=tfb.Exp(),\n", " name='scale')),\n", " reinterpreted_batch_ndims=1,\n", " name='learnable_mvn_scaled_identity')\n", "\n", "print(learnable_mvn_scaled_identity)\n", "print(learnable_mvn_scaled_identity.trainable_variables)" ] }, { "cell_type": "markdown", "metadata": { "id": "8LX4NbZ5V_UT" }, "source": [ "## `chol(Cov)` 为对角矩阵的可学习多元正态分布" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "tH7NUglGU9Xs" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "tfp.distributions.Independent(\"learnable_mvn_diag\", batch_shape=[], event_shape=[4], dtype=float32)\n", "(, )\n" ] } ], "source": [ "learnable_mvndiag = tfd.Independent(\n", " 
tfd.Normal(\n", " loc=tf.Variable(tf.zeros(event_size), name='loc'),\n", " scale=tfp.util.TransformedVariable(\n", " tf.ones(event_size),\n", " bijector=tfb.Softplus(), # Use Softplus...cuz why not?\n", " name='scale')),\n", " reinterpreted_batch_ndims=1,\n", " name='learnable_mvn_diag')\n", "\n", "print(learnable_mvndiag)\n", "print(learnable_mvndiag.trainable_variables)" ] }, { "cell_type": "markdown", "metadata": { "id": "q7QlKm8fWCCH" }, "source": [ "## 多元正态混合(球状)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "4kaN1OP8U_S-" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "tfp.distributions.MixtureSameFamily(\"learnable_mix_mvn_scaled_identity\", batch_shape=[], event_shape=[4], dtype=float32)\n", "(, , )\n" ] } ], "source": [ "learnable_mix_mvn_scaled_identity = tfd.MixtureSameFamily(\n", " mixture_distribution=tfd.Categorical(\n", " logits=tf.Variable(\n", " # `log(1.)` is 0, i.e. uniform logits; changing the `1.` (e.g. to `1.5`) initializes with a geometric decay.\n", " -tf.math.log(1.) * tf.range(num_components, dtype=tf.float32),\n", " name='logits')),\n", " components_distribution=tfd.Independent(\n", " tfd.Normal(\n", " loc=tf.Variable(\n", " tf.random.normal([num_components, event_size]),\n", " name='loc'),\n", " scale=tfp.util.TransformedVariable(\n", " 10. 
* tf.ones([num_components, 1]),\n", " bijector=tfb.Softplus(), # Use Softplus...cuz why not?\n", " name='scale')),\n", " reinterpreted_batch_ndims=1),\n", " name='learnable_mix_mvn_scaled_identity')\n", "\n", "print(learnable_mix_mvn_scaled_identity)\n", "print(learnable_mix_mvn_scaled_identity.trainable_variables)" ] }, { "cell_type": "markdown", "metadata": { "id": "0h_8Sp24WD7R" }, "source": [ "## 具有不可学习的第一混合权重的多元正态混合(球状)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "1oGkxQoLVOiu" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "tfp.distributions.MixtureSameFamily(\"learnable_mix_mvndiag_first_fixed\", batch_shape=[], event_shape=[4], dtype=float32)\n", "(, , )\n" ] } ], "source": [ "learnable_mix_mvndiag_first_fixed = tfd.MixtureSameFamily(\n", " mixture_distribution=tfd.Categorical(\n", " logits=tfp.util.TransformedVariable(\n", " # Initialize logits as geometric decay.\n", " -tf.math.log(1.5) * tf.range(num_components, dtype=tf.float32),\n", " # `Pad` prepends a constant 0, so the first logit is fixed and only the\n", " # remaining `num_components - 1` logits are backed by a trainable variable.\n", " bijector=tfb.Pad(paddings=[[1, 0]], constant_values=0),\n", " # `name` belongs to the variable, not to `tfd.Categorical`.\n", " name='logits')),\n", " components_distribution=tfd.Independent(\n", " tfd.Normal(\n", " loc=tf.Variable(\n", " # Use Rademacher...cuz why not?\n", " tfp.random.rademacher([num_components, event_size]),\n", " name='loc'),\n", " scale=tfp.util.TransformedVariable(\n", " 10. 
* tf.ones([num_components, 1]),\n", " bijector=tfb.Softplus(), # Use Softplus...cuz why not?\n", " name='scale')),\n", " reinterpreted_batch_ndims=1),\n", " name='learnable_mix_mvndiag_first_fixed')\n", "\n", "print(learnable_mix_mvndiag_first_fixed)\n", "print(learnable_mix_mvndiag_first_fixed.trainable_variables)" ] }, { "cell_type": "markdown", "metadata": { "id": "sOfgzpekWGcT" }, "source": [ "## 多元正态混合(全 `Cov`)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "Gz7kcj0zVQPv" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "tfp.distributions.MixtureSameFamily(\"learnable_mix_mvntril\", batch_shape=[], event_shape=[4], dtype=float32)\n", "(, , )\n" ] } ], "source": [ "learnable_mix_mvntril = tfd.MixtureSameFamily(\n", " mixture_distribution=tfd.Categorical(\n", " logits=tf.Variable(\n", " # `log(1.)` is 0, i.e. uniform logits; changing the `1.` (e.g. to `1.5`) initializes with a geometric decay.\n", " -tf.math.log(1.) * tf.range(num_components, dtype=tf.float32),\n", " name='logits')),\n", " components_distribution=tfd.MultivariateNormalTriL(\n", " loc=tf.Variable(tf.zeros([num_components, event_size]), name='loc'),\n", " scale_tril=tfp.util.TransformedVariable(\n", " 10. 
* tf.eye(event_size, batch_shape=[num_components]),\n", " bijector=tfb.FillScaleTriL(),\n", " name='scale_tril')),\n", " name='learnable_mix_mvntril')\n", "\n", "print(learnable_mix_mvntril)\n", "print(learnable_mix_mvntril.trainable_variables)" ] }, { "cell_type": "markdown", "metadata": { "id": "zU12fGeuWI29" }, "source": [ "## 具有不可学习的第一混合权重和第一分量的多元正态混合(全 `Cov`)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "NMb5hcsDVVyL" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "tfp.distributions.MixtureSameFamily(\"learnable_mix_mvntril_fixed_first\", batch_shape=[], event_shape=[4], dtype=float32)\n", "(, , )\n" ] } ], "source": [ "# Make a bijector which pads an eye to what otherwise fills a tril.\n", "num_tril_nonzero = lambda num_rows: num_rows * (num_rows + 1) // 2\n", "\n", "num_tril_rows = lambda nnz: prefer_static.cast(\n", " prefer_static.sqrt(0.25 + 2. * prefer_static.cast(nnz, tf.float32)) - 0.5,\n", " tf.int32)\n", "\n", "# TFP doesn't have a concat bijector, so we roll our own.\n", "class PadEye(tfb.Bijector):\n", "\n", " def __init__(self, tril_fn=None):\n", " if tril_fn is None:\n", " tril_fn = tfb.FillScaleTriL()\n", " self._tril_fn = getattr(tril_fn, 'inverse', tril_fn)\n", " super(PadEye, self).__init__(\n", " forward_min_event_ndims=2,\n", " inverse_min_event_ndims=2,\n", " is_constant_jacobian=True,\n", " name='PadEye')\n", "\n", " def _forward(self, x):\n", " num_rows = int(num_tril_rows(tf.compat.dimension_value(x.shape[-1])))\n", " eye = tf.eye(num_rows, batch_shape=prefer_static.shape(x)[:-2])\n", " return tf.concat([self._tril_fn(eye)[..., tf.newaxis, :], x],\n", " axis=prefer_static.rank(x) - 2)\n", "\n", " def _inverse(self, y):\n", " return y[..., 1:, :]\n", "\n", " def _forward_log_det_jacobian(self, x):\n", " return tf.zeros([], dtype=x.dtype)\n", "\n", " def _inverse_log_det_jacobian(self, y):\n", " return tf.zeros([], dtype=y.dtype)\n", "\n", " def _forward_event_shape(self, 
in_shape):\n", " n = prefer_static.size(in_shape)\n", " return in_shape + prefer_static.one_hot(n - 2, depth=n, dtype=tf.int32)\n", "\n", " def _inverse_event_shape(self, out_shape):\n", " n = prefer_static.size(out_shape)\n", " return out_shape - prefer_static.one_hot(n - 2, depth=n, dtype=tf.int32)\n", "\n", "\n", "tril_bijector = tfb.FillScaleTriL(diag_bijector=tfb.Softplus())\n", "learnable_mix_mvntril_fixed_first = tfd.MixtureSameFamily(\n", " mixture_distribution=tfd.Categorical(\n", " logits=tfp.util.TransformedVariable(\n", " # `log(1.)` is 0, i.e. uniform logits; changing the `1.` (e.g. to `1.5`) initializes with a geometric decay.\n", " -tf.math.log(1.) * tf.range(num_components, dtype=tf.float32),\n", " bijector=tfb.Pad(paddings=[(1, 0)]),\n", " name='logits')),\n", " components_distribution=tfd.MultivariateNormalTriL(\n", " loc=tfp.util.TransformedVariable(\n", " tf.zeros([num_components, event_size]),\n", " bijector=tfb.Pad(paddings=[(1, 0)], axis=-2),\n", " name='loc'),\n", " scale_tril=tfp.util.TransformedVariable(\n", " 10. * tf.eye(event_size, batch_shape=[num_components]),\n", " bijector=tfb.Chain([tril_bijector, PadEye(tril_bijector)]),\n", " name='scale_tril')),\n", " name='learnable_mix_mvntril_fixed_first')\n", "\n", "\n", "print(learnable_mix_mvntril_fixed_first)\n", "print(learnable_mix_mvntril_fixed_first.trainable_variables)" ] } ], "metadata": { "colab": { "collapsed_sections": [], "name": "Learnable_Distributions_Zoo.ipynb", "toc_visible": true }, "kernelspec": { "display_name": "Python 3", "name": "python3" } }, "nbformat": 4, "nbformat_minor": 0 }