{ "cells": [ { "cell_type": "markdown", "metadata": { "id": "FVKYfpQVYPaJ" }, "source": [ "##### Copyright 2018 The TensorFlow Probability Authors.\n", "\n", "Licensed under the Apache License, Version 2.0 (the \"License\");" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "cellView": "form", "id": "htHLjlnLYSoB" }, "outputs": [], "source": [ "#@title Licensed under the Apache License, Version 2.0 (the \"License\"); { display-mode: \"form\" }\n", "# you may not use this file except in compliance with the License.\n", "# You may obtain a copy of the License at\n", "#\n", "# https://www.apache.org/licenses/LICENSE-2.0\n", "#\n", "# Unless required by applicable law or agreed to in writing, software\n", "# distributed under the License is distributed on an \"AS IS\" BASIS,\n", "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", "# See the License for the specific language governing permissions and\n", "# limitations under the License." ] }, { "cell_type": "markdown", "metadata": { "id": "DcriL2xPrG3_" }, "source": [ "# TensorFlow Distributions の形状を理解する\n", "\n", "\n", " \n", " \n", " \n", " \n", "
TensorFlow.org で表示Google Colab で実行GitHub でソースを表示ノートブックをダウンロード
" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "J6t0EUihrG4B" }, "outputs": [], "source": [ "import collections\n", "\n", "import tensorflow as tf\n", "tf.compat.v2.enable_v2_behavior()\n", "\n", "import tensorflow_probability as tfp\n", "tfd = tfp.distributions\n", "tfb = tfp.bijectors" ] }, { "cell_type": "markdown", "metadata": { "id": "QD5lzFZerG4H" }, "source": [ "## 基礎\n", "\n", "TensorFlow Distributions の形状には関連する 3 つの重要な概念があります。\n", "\n", "- *イベントの形状*は、分布からの 1 つの抽出の形状を表します。抽出は次元間で依存する場合があります。スカラー分布の場合、イベントの形状は [] です。5 次元の MultivariateNormal の場合、イベントの形状は [5] です。\n", "- *バッチの形状*は、独立した、同一に分布されていない抽出である「バッチ」の分布を表します。\n", "- *サンプルの形状*は、 分布ファミリからの独立した、同一に分布されたバッチの抽出を表します。\n", "\n", "イベントの形状とバッチの形状は `Distribution` オブジェクトのプロパティですが、サンプルの形状は `sample` または `log_prob` への特定の呼び出しに関連付けられています。\n", "\n", "このノートブックでは、例を使ってこれらの概念を説明していくので、すぐに分からなくても、心配する必要はありません。\n", "\n", "また、これらの概念の概要については、[このブログ記事](https://ericmjl.github.io/blog/2019/5/29/reasoning-about-shapes-and-probability-distributions/)を参照してください。" ] }, { "cell_type": "markdown", "metadata": { "id": "yU34kIHDrG4I" }, "source": [ "### TensorFlow Eager に関する注意\n", "\n", "このノートブックは、すべて [TensorFlow Eager](https://research.googleblog.com/2017/10/eager-execution-imperative-define-by.html) を使用して記述されています。提示された概念は Eager に*依存*していませんが、Eager では、`Distribution` オブジェクトが Python で作成されるときに、分布バッチとイベントの形状が評価されます(したがって既知です)。一方、グラフ(非 Eager モード)では、グラフが実行されるまでイベントとバッチの形状が決定されていない分布を定義することができます。" ] }, { "cell_type": "markdown", "metadata": { "id": "MeirD-0JrG4K" }, "source": [ "## スカラー分布\n", "\n", "上記のように、`Distribution` オブジェクトではイベントとバッチの形状が定義されています。まず、分布を説明するユーティリティから始めます。" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "bq8guNPtrG4M" }, "outputs": [], "source": [ "def describe_distributions(distributions):\n", " print('\\n'.join([str(d) for d in distributions]))" ] }, { "cell_type": "markdown", "metadata": { "id": "06CafVXWrG4Q" }, "source": [ "このセクションでは、*スカラー*分布(イベントの形状が `[]` の分布)について説明します。典型的な例は、`rate` で指定されたポアソン分布です。" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "Sdz1OMg7rG4S" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "tfp.distributions.Poisson(\"One_Poisson_Scalar_Batch\", batch_shape=[], event_shape=[], dtype=float32)\n", "tfp.distributions.Poisson(\"Three_Poissons\", batch_shape=[3], event_shape=[], dtype=float32)\n", "tfp.distributions.Poisson(\"Two_by_Three_Poissons\", batch_shape=[2, 3], event_shape=[], dtype=float32)\n", "tfp.distributions.Poisson(\"One_Poisson_Vector_Batch\", batch_shape=[1], event_shape=[], dtype=float32)\n", "tfp.distributions.Poisson(\"One_Poisson_Expanded_Batch\", batch_shape=[1, 1], event_shape=[], dtype=float32)\n" ] } ], "source": [ "poisson_distributions = [\n", " tfd.Poisson(rate=1., name='One Poisson Scalar Batch'),\n", " tfd.Poisson(rate=[1., 10., 100.], name='Three Poissons'),\n", " tfd.Poisson(rate=[[1., 10., 100.,], [2., 20., 200.]],\n", " name='Two-by-Three Poissons'),\n", " tfd.Poisson(rate=[1.], name='One Poisson Vector Batch'),\n", " tfd.Poisson(rate=[[1.]], name='One Poisson Expanded Batch')\n", "]\n", "\n", "describe_distributions(poisson_distributions)" ] }, { "cell_type": "markdown", "metadata": { "id": "lVPVIsC9rG4a" }, "source": [ "ポアソン分布はスカラー分布であるため、そのイベントの形状は常に `[]` です。より多くのレートを指定すると、これらはバッチ形式で表示されます。例の最後のペアは興味深いものです。レートは 1 つだけですが、そのレートは空でない形状の numpy 配列に埋め込まれているため、その形状がバッチ形状になります。" ] }, { "cell_type": "markdown", "metadata": { "id": "cFlXG9O5rG4b" }, "source": [ "標準の正規分布もスカラーです。イベントの形状は、ポアソンの場合と同じように `[]` ですが、*ブロードキャスト*の最初の例で見ていきます。正規分布は、`loc` および `scale` パラメーターを使用して指定されます。" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "e5PXRPM1rG4c" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "tfp.distributions.Normal(\"Standard\", batch_shape=[], event_shape=[], dtype=float32)\n", "tfp.distributions.Normal(\"Standard_Vector_Batch\", batch_shape=[1], event_shape=[], dtype=float32)\n", "tfp.distributions.Normal(\"Different_Locs\", batch_shape=[4], event_shape=[], dtype=float32)\n", "tfp.distributions.Normal(\"Broadcasting_Scale\", batch_shape=[2, 4], event_shape=[], dtype=float32)\n" ] } ], "source": [ "normal_distributions = [\n", " tfd.Normal(loc=0., scale=1., name='Standard'),\n", " tfd.Normal(loc=[0.], scale=1., name='Standard Vector Batch'),\n", " tfd.Normal(loc=[0., 1., 2., 3.], scale=1., name='Different Locs'),\n", " tfd.Normal(loc=[0., 1., 2., 3.], scale=[[1.], [5.]],\n", " name='Broadcasting Scale')\n", "]\n", "\n", "describe_distributions(normal_distributions)" ] }, { "cell_type": "markdown", "metadata": { "id": "Dh70eNXHrG4i" }, "source": [ "上記の `Broadcasting Scale` 分布は興味深い例です。`loc` パラメーターは `[4]` の形状、`scale` パラメーターは `[2, 1]` の形状をもちます。[Numpy ブロードキャストルール](https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html)を使用すると、バッチ形状は `[2, 4]` になります。 `\"Broadcasting Scale\"` 分布を定義するための同等の(ただし、あまりエレガントではなく、推奨されない)方法は次のとおりです。" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "9G5JNBzQrG4j" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "tfp.distributions.Normal(\"Normal\", batch_shape=[2, 4], event_shape=[], dtype=float32)\n" ] } ], "source": [ "describe_distributions(\n", " [tfd.Normal(loc=[[0., 1., 2., 3], [0., 1., 2., 3.]],\n", " scale=[[1., 1., 1., 1.], [5., 5., 5., 5.]])])" ] }, { "cell_type": "markdown", "metadata": { "id": "_hSBWsokrG4p" }, "source": [ "以上のようにブロードキャストの表記は頭痛やバグの原因にもなりますが便利です。" ] }, { "cell_type": "markdown", "metadata": { "id": "trGxojHwrG4r" }, "source": [ "### スカラー分布のサンプリング" ] }, { "cell_type": "markdown", "metadata": { "id": "TDJqRz-qrG4t" }, "source": [ "分布で実行できる主なことは `sample` と `log_prob` の 2 つです。まず、サンプリングについて見ていきましょう。基本的なルールは、分布からサンプリングする場合、結果のテンソルは形状 `[sample_shape, batch_shape, event_shape]` になります。`batch_shape` と `event_shape` は `Distribution ` オブジェクトにより提供され、`sample_shape` は、`sample` の呼び出しにより提供されます。スカラー分布の場合、`event_shape = []` であるため、サンプルから返されるテンソルの形状は `[sample_shape, batch_shape]` になります。では、試してみましょう。" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "2TbeP0btrG4u" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "tfp.distributions.Poisson(\"One_Poisson_Scalar_Batch\", batch_shape=[], event_shape=[], dtype=float32)\n", "Sample shape: 1\n", "Returned sample tensor shape: (1,)\n", "Sample shape: 2\n", "Returned sample tensor shape: (2,)\n", "Sample shape: [1, 5]\n", "Returned sample tensor shape: (1, 5)\n", "Sample shape: [3, 4, 5]\n", "Returned sample tensor shape: (3, 4, 5)\n", "\n", "tfp.distributions.Poisson(\"Three_Poissons\", batch_shape=[3], event_shape=[], dtype=float32)\n", "Sample shape: 1\n", "Returned sample tensor shape: (1, 3)\n", "Sample shape: 2\n", "Returned sample tensor shape: (2, 3)\n", "Sample shape: [1, 5]\n", "Returned sample tensor shape: (1, 5, 3)\n", "Sample shape: [3, 4, 5]\n", "Returned sample tensor shape: (3, 4, 5, 3)\n", "\n", "tfp.distributions.Poisson(\"Two_by_Three_Poissons\", batch_shape=[2, 3], event_shape=[], dtype=float32)\n", "Sample shape: 1\n", "Returned sample tensor shape: (1, 2, 3)\n", "Sample shape: 2\n", "Returned sample tensor shape: (2, 2, 3)\n", "Sample shape: [1, 5]\n", "Returned sample tensor shape: (1, 5, 2, 3)\n", "Sample shape: [3, 4, 5]\n", "Returned sample tensor shape: (3, 4, 5, 2, 3)\n", "\n", "tfp.distributions.Poisson(\"One_Poisson_Vector_Batch\", batch_shape=[1], event_shape=[], dtype=float32)\n", "Sample shape: 1\n", "Returned sample tensor shape: (1, 1)\n", "Sample shape: 2\n", "Returned sample tensor shape: (2, 1)\n", "Sample shape: [1, 5]\n", "Returned sample tensor shape: (1, 5, 1)\n", "Sample shape: [3, 4, 5]\n", "Returned sample tensor shape: (3, 4, 5, 1)\n", "\n", "tfp.distributions.Poisson(\"One_Poisson_Expanded_Batch\", batch_shape=[1, 1], event_shape=[], dtype=float32)\n", "Sample shape: 1\n", "Returned sample tensor shape: (1, 1, 1)\n", "Sample shape: 2\n", "Returned sample tensor shape: (2, 1, 1)\n", "Sample shape: [1, 5]\n", "Returned sample tensor shape: (1, 5, 1, 1)\n", "Sample shape: [3, 4, 5]\n", "Returned sample tensor shape: (3, 4, 5, 1, 1)\n", "\n" ] } ], "source": [ "def describe_sample_tensor_shape(sample_shape, distribution):\n", " print('Sample shape:', sample_shape)\n", " print('Returned sample tensor shape:',\n", " distribution.sample(sample_shape).shape)\n", "\n", "def describe_sample_tensor_shapes(distributions, sample_shapes):\n", " started = False\n", " for distribution in distributions:\n", " print(distribution)\n", " for sample_shape in sample_shapes:\n", " describe_sample_tensor_shape(sample_shape, distribution)\n", " print()\n", "\n", "sample_shapes = [1, 2, [1, 5], [3, 4, 5]]\n", "describe_sample_tensor_shapes(poisson_distributions, sample_shapes)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "qiJK8UBorG40" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "tfp.distributions.Normal(\"Standard\", batch_shape=[], event_shape=[], dtype=float32)\n", "Sample shape: 1\n", "Returned sample tensor shape: (1,)\n", "Sample shape: 2\n", "Returned sample tensor shape: (2,)\n", "Sample shape: [1, 5]\n", "Returned sample tensor shape: (1, 5)\n", "Sample shape: [3, 4, 5]\n", "Returned sample tensor shape: (3, 4, 5)\n", "\n", "tfp.distributions.Normal(\"Standard_Vector_Batch\", batch_shape=[1], event_shape=[], dtype=float32)\n", "Sample shape: 1\n", "Returned sample tensor shape: (1, 1)\n", "Sample shape: 2\n", "Returned sample tensor shape: (2, 1)\n", "Sample shape: [1, 5]\n", "Returned sample tensor shape: (1, 5, 1)\n", "Sample shape: [3, 4, 5]\n", "Returned sample tensor shape: (3, 4, 5, 1)\n", "\n", "tfp.distributions.Normal(\"Different_Locs\", batch_shape=[4], event_shape=[], dtype=float32)\n", "Sample shape: 1\n", "Returned sample tensor shape: (1, 4)\n", "Sample shape: 2\n", "Returned sample tensor shape: (2, 4)\n", "Sample shape: [1, 5]\n", "Returned sample tensor shape: (1, 5, 4)\n", "Sample shape: [3, 4, 5]\n", "Returned sample tensor shape: (3, 4, 5, 4)\n", "\n", "tfp.distributions.Normal(\"Broadcasting_Scale\", batch_shape=[2, 4], event_shape=[], dtype=float32)\n", "Sample shape: 1\n", "Returned sample tensor shape: (1, 2, 4)\n", "Sample shape: 2\n", "Returned sample tensor shape: (2, 2, 4)\n", "Sample shape: [1, 5]\n", "Returned sample tensor shape: (1, 5, 2, 4)\n", "Sample shape: [3, 4, 5]\n", "Returned sample tensor shape: (3, 4, 5, 2, 4)\n", "\n" ] } ], "source": [ "describe_sample_tensor_shapes(normal_distributions, sample_shapes)" ] }, { "cell_type": "markdown", "metadata": { "id": "wDRB80oLrG48" }, "source": [ "`sample` についての説明は以上です。返されたサンプルテンソルの形状は `[sample_shape, batch_shape, event_shape]` です。\n", "\n", "### スカラー分布の `log_prob` の計算\n", "\n", "次に、`log_prob` を見てみましょう。これは少し注意する必要があります。`log_prob` は、分布の `log_prob` を計算する場所を表す(空でない)テンソルを入力として受け取ります。最も単純なケースでは、このテンソルは `[sample_shape, batch_shape, event_shape]` の形式になります。`batch_shape` と `event_shape` は 分布のバッチおよびイベントの形状に一致します。スカラー分布の場合は、`event_shape = []` なので、入力テンソルの形状は `[sample_shape, batch_shape]` です。この場合、`[sample_shape, batch_shape]` 形状のテンソルが返されます。" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "UgNIiFf9rG49" }, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 0, "metadata": {}, "output_type": "execute_result" } ], "source": [ "three_poissons = tfd.Poisson(rate=[1., 10., 100.], name='Three Poissons')\n", "three_poissons" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "OpN5WGog0WwC" }, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 0, "metadata": {}, "output_type": "execute_result" } ], "source": [ "three_poissons.log_prob([[1., 10., 100.], [100., 10., 1]]) # sample_shape is [2]." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "4szFj9lkrG5F" }, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 0, "metadata": {}, "output_type": "execute_result" } ], "source": [ "three_poissons.log_prob([[[[1., 10., 100.], [100., 10., 1.]]]]) # sample_shape is [1, 1, 2]." ] }, { "cell_type": "markdown", "metadata": { "id": "VG_n9BHsrG5M" }, "source": [ "最初の例では、入力と出力の形状が `[2, 3]` であり、2 番目の例では形状が `[1, 1, 2, 3]` であることに注意してください。\n", "\n", "ブロードキャストがない場合はそれだけです。ブロードキャストを考慮する場合のルールは次のとおりです。これは一般的な説明であり、スカラー分布は簡略化されていることに注意してください。\n", "\n", "1. `n = len(batch_shape) + len(event_shape)` を定義します。(スカラー分布の場合は、`len(event_shape)=0`。)\n", "2. 入力テンソル `t` の次元が `n` 未満の場合、正確に `n` 次元になるまで、左側にサイズ `1` の次元を追加して形状をパッディングします。\n", "3. `t'` の右端の次元 `n` を `log_prob` 計算している分布の `[batch_shape, event_shape]` に対してブロードキャストします。詳しく説明すると、`t'` がすでに分布と一致している次元の場合は何もせず、`t'` の次元がシングルトンの場合は、そのシングルトンを適切な数で複製します。その他の場合はエラーです。(スカラー分布の場合、event_shape = `[]` であるため、 `batch_shape` に対してのみブロードキャストします。)\n", "4. これで、`log_prob` を計算できるようになりました。結果のテンソルの形状は、`[sample_shape, batch_shape]` です。`sample_shape` は、右端の次元 `n` の左側にある `t` または `t'` の任意の次元として定義されます(`sample_shape = shape(t)[:-n]`)。\n", "\n", "これが何を意味するのかわからないと混乱するかもしれないので、いくつかの例を見てみましょう。" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "YwDVaeRHrG5O" }, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 0, "metadata": {}, "output_type": "execute_result" } ], "source": [ "three_poissons.log_prob([10.])" ] }, { "cell_type": "markdown", "metadata": { "id": "xAImEhtdrG5U" }, "source": [ "テンソル `[10.]` (形状 `[1]`)は 3 つの`batch_shape` でブロードキャストされるため、値 10 での 3 つのポワソンの対数確率をすべて評価します。" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "daDAG6p2rG5V" }, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 0, "metadata": {}, "output_type": "execute_result" } ], "source": [ "three_poissons.log_prob([[[1.], [10.]], [[100.], [1000.]]])" ] }, { "cell_type": "markdown", "metadata": { "id": "REEX-DgBrG5b" }, "source": [ "上記の例では、入力テンソルの形状は `[2, 2, 1]` ですが、分布オブジェクトの形状は 3 です。したがって、`[2, 2]` サンプル次元のそれぞれについて、提供された単一の値は、3 つのポワソンのそれぞれにブロードキャストします。\n", "\n", "これは役に立つ考え方です。`three_poissons` には `batch_shape = [2, 3]` があるため、`log_prob` の呼び出しには最後の次元が 1 または 3 のテンソルが必要です。それ以外はエラーです。(numpy ブロードキャストルールは、スカラーの特殊なケースを、形状 `[1]` のテンソルと完全に同等であるものとして扱います。)\n", "\n", "では、`batch_shape = [2, 3]` を使用して、より複雑なポアソン分布を使用して試してみましょう。" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "MkSWkwYarG5d" }, "outputs": [], "source": [ "poisson_2_by_3 = tfd.Poisson(\n", " rate=[[1., 10., 100.,], [2., 20., 200.]],\n", " name='Two-by-Three Poissons')" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "9YFRkkssrG5f" }, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 0, "metadata": {}, "output_type": "execute_result" } ], "source": [ "poisson_2_by_3.log_prob(1.)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "CqQXvOexrG5i" }, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 0, "metadata": {}, "output_type": "execute_result" } ], "source": [ "poisson_2_by_3.log_prob([1.]) # Exactly equivalent to above, demonstrating the scalar special case." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "1nCuYQC5rG5m" }, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 0, "metadata": {}, "output_type": "execute_result" } ], "source": [ "poisson_2_by_3.log_prob([[1., 1., 1.], [1., 1., 1.]]) # Another way to write the same thing. No broadcasting." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "2PgG6udBrG5p" }, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 0, "metadata": {}, "output_type": "execute_result" } ], "source": [ "poisson_2_by_3.log_prob([[1., 10., 100.]]) # Input is [1, 3] broadcast to [2, 3]." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "Gm7ejyoArG5s" }, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 0, "metadata": {}, "output_type": "execute_result" } ], "source": [ "poisson_2_by_3.log_prob([[1., 10., 100.], [1., 10., 100.]]) # Equivalent to above. No broadcasting." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "mVMSGVvGrG5w" }, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 0, "metadata": {}, "output_type": "execute_result" } ], "source": [ "poisson_2_by_3.log_prob([[1., 1., 1.], [2., 2., 2.]]) # No broadcasting." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "OVEpi5QErG5z" }, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 0, "metadata": {}, "output_type": "execute_result" } ], "source": [ "poisson_2_by_3.log_prob([[1.], [2.]]) # Equivalent to above. Input shape [2, 1] broadcast to [2, 3]." ] }, { "cell_type": "markdown", "metadata": { "id": "ZW2tApDGrG53" }, "source": [ "上記の例では、バッチを介したブロードキャストを見ていきましたが、サンプルの形状は空でした。値のコレクションがあり、バッチの各ポイントで各値の対数確率を取得する場合は、以下のように手動で実行できます。" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "03DvnmK2rG53" }, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 0, "metadata": {}, "output_type": "execute_result" } ], "source": [ "poisson_2_by_3.log_prob([[[1., 1., 1.], [1., 1., 1.]], [[2., 2., 2.], [2., 2., 2.]]]) # Input shape [2, 2, 3]." ] }, { "cell_type": "markdown", "metadata": { "id": "XkpJQ0dJrG56" }, "source": [ "または、ブロードキャストに最後のバッチ次元を処理させることもできます。" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "KJ6OsodCrG57" }, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 0, "metadata": {}, "output_type": "execute_result" } ], "source": [ "poisson_2_by_3.log_prob([[[1.], [1.]], [[2.], [2.]]]) # Input shape [2, 2, 1]." ] }, { "cell_type": "markdown", "metadata": { "id": "eZFx8pThrG5-" }, "source": [ "また、やや不自然ですがブロードキャストに最初のバッチ次元のみを処理させることもできます。" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "UoGs7GBSrG5_" }, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 0, "metadata": {}, "output_type": "execute_result" } ], "source": [ "poisson_2_by_3.log_prob([[[1., 1., 1.]], [[2., 2., 2.]]]) # Input shape [2, 1, 3]." ] }, { "cell_type": "markdown", "metadata": { "id": "cOP4OhGDrG6C" }, "source": [ "または、ブロードキャストに*両方*のバッチ次元を処理させることもできます。" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "tnG2f4tZrG6E" }, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 0, "metadata": {}, "output_type": "execute_result" } ], "source": [ "poisson_2_by_3.log_prob([[[1.]], [[2.]]]) # Input shape [2, 1, 1]." ] }, { "cell_type": "markdown", "metadata": { "id": "I1s1drAwrG6K" }, "source": [ "上記は、必要な値が 2 つしかない場合は問題ありませんでした。しかし、すべてのバッチポイントで評価する値のリストが長い場合は、次の表記を使用します。形状の右側にサイズ 1 の余分な次元を追加すると、非常に便利です。" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "oUxbYZN_rG6K" }, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 0, "metadata": {}, "output_type": "execute_result" } ], "source": [ "poisson_2_by_3.log_prob(tf.constant([1., 2.])[..., tf.newaxis, tf.newaxis])" ] }, { "cell_type": "markdown", "metadata": { "id": "Se893aIurG6M" }, "source": [ "これは[ストライドスライス表記](https://www.tensorflow.org/api_docs/cc/class/tensorflow/ops/strided-slice)のインスタンスであり、知っておく価値があります。" ] }, { "cell_type": "markdown", "metadata": { "id": "XNDhHqJmrG6N" }, "source": [ "完全を期すために `three_poissons` に戻ると、同じ例は次のようになります。" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "zKP7OmQsrG6N" }, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 0, "metadata": {}, "output_type": "execute_result" } ], "source": [ "three_poissons.log_prob([[1.], [10.], [50.], [100.]])" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "PK_9DwSdrG6R" }, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 0, "metadata": {}, "output_type": "execute_result" } ], "source": [ "three_poissons.log_prob(tf.constant([1., 10., 50., 100.])[..., tf.newaxis]) # Equivalent to above." ] }, { "cell_type": "markdown", "metadata": { "id": "lhL17DW5rG6T" }, "source": [ "## 多変量分布\n", "\n", "ここでは、空でないイベント形状を持つ多変量分布を見ていきます。まず、多項分布を見てみましょう。" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "lOdGa5n9rG6T" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "tfp.distributions.Multinomial(\"One_Multinomial\", batch_shape=[], event_shape=[3], dtype=float32)\n", "tfp.distributions.Multinomial(\"Two_Multinomials_Same_Probs\", batch_shape=[2], event_shape=[3], dtype=float32)\n", "tfp.distributions.Multinomial(\"Two_Multinomials_Same_Counts\", batch_shape=[2], event_shape=[3], dtype=float32)\n", "tfp.distributions.Multinomial(\"Two_Multinomials_Different_Everything\", batch_shape=[2], event_shape=[3], dtype=float32)\n" ] } ], "source": [ "multinomial_distributions = [\n", " # Multinomial is a vector-valued distribution: if we have k classes,\n", " # an individual sample from the distribution has k values in it, so the\n", " # event_shape is `[k]`.\n", " tfd.Multinomial(total_count=100., probs=[.5, .4, .1],\n", " name='One Multinomial'),\n", " tfd.Multinomial(total_count=[100., 1000.], probs=[.5, .4, .1],\n", " name='Two Multinomials Same Probs'),\n", " tfd.Multinomial(total_count=100., probs=[[.5, .4, .1], [.1, .2, .7]],\n", " name='Two Multinomials Same Counts'),\n", " tfd.Multinomial(total_count=[100., 1000.],\n", " probs=[[.5, .4, .1], [.1, .2, .7]],\n", " name='Two Multinomials Different Everything')\n", "\n", "]\n", "\n", "describe_distributions(multinomial_distributions)" ] }, { "cell_type": "markdown", "metadata": { "id": "-NQ8gK7irG6W" }, "source": [ "最後の 3 つの例では、batch_shape は常に `[2]` でしたが、ブロードキャストを使用して、共有する `total_count` または共有する `probs` 使用できます(または、使用しないこともできます)。内部では同じ形状になるようにブロードキャストされるためです。\n", "\n", "既知の事柄を考慮すると、サンプリングは簡単です。" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "hSr362qjrG6W" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "tfp.distributions.Multinomial(\"One_Multinomial\", batch_shape=[], event_shape=[3], dtype=float32)\n", "Sample shape: 1\n", "Returned sample tensor shape: (1, 3)\n", "Sample shape: 2\n", "Returned sample tensor shape: (2, 3)\n", "Sample shape: [1, 5]\n", "Returned sample tensor shape: (1, 5, 3)\n", "Sample shape: [3, 4, 5]\n", "Returned sample tensor shape: (3, 4, 5, 3)\n", "\n", "tfp.distributions.Multinomial(\"Two_Multinomials_Same_Probs\", batch_shape=[2], event_shape=[3], dtype=float32)\n", "Sample shape: 1\n", "Returned sample tensor shape: (1, 2, 3)\n", "Sample shape: 2\n", "Returned sample tensor shape: (2, 2, 3)\n", "Sample shape: [1, 5]\n", "Returned sample tensor shape: (1, 5, 2, 3)\n", "Sample shape: [3, 4, 5]\n", "Returned sample tensor shape: (3, 4, 5, 2, 3)\n", "\n", "tfp.distributions.Multinomial(\"Two_Multinomials_Same_Counts\", batch_shape=[2], event_shape=[3], dtype=float32)\n", "Sample shape: 1\n", "Returned sample tensor shape: (1, 2, 3)\n", "Sample shape: 2\n", "Returned sample tensor shape: (2, 2, 3)\n", "Sample shape: [1, 5]\n", "Returned sample tensor shape: (1, 5, 2, 3)\n", "Sample shape: [3, 4, 5]\n", "Returned sample tensor shape: (3, 4, 5, 2, 3)\n", "\n", "tfp.distributions.Multinomial(\"Two_Multinomials_Different_Everything\", batch_shape=[2], event_shape=[3], dtype=float32)\n", "Sample shape: 1\n", "Returned sample tensor shape: (1, 2, 3)\n", "Sample shape: 2\n", "Returned sample tensor shape: (2, 2, 3)\n", "Sample shape: [1, 5]\n", "Returned sample tensor shape: (1, 5, 2, 3)\n", "Sample shape: [3, 4, 5]\n", "Returned sample tensor shape: (3, 4, 5, 2, 3)\n", "\n" ] } ], "source": [ "describe_sample_tensor_shapes(multinomial_distributions, sample_shapes)" ] }, { "cell_type": "markdown", "metadata": { "id": "jjXgxPXCrG6Z" }, "source": [ "対数確率の計算も同様に簡単です。対角多変量正規分布の例を見てみましょう。(カウントと確率の制約により、ブロードキャストは許容できない値を生成することが多いため、多項分布はブロードキャストにあまり適していません。)平均は同じですがスケール(標準偏差)が異なる 2 つの 3 次元分布のバッチを使用します。" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "cnywBQdZrG6Z" }, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 0, "metadata": {}, "output_type": "execute_result" } ], "source": [ "two_multivariate_normals = tfd.MultivariateNormalDiag(loc=[1., 2., 3.], scale_diag=tf.ones([2, 3]) * [[1.], [2.]])\n", "two_multivariate_normals" ] }, { "cell_type": "markdown", "metadata": { "id": "S9xE21IirG6b" }, "source": [ "次に、各バッチポイントの平均とシフトされた平均での対数確率を評価します。" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "YBOLH33PrG6b" }, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 0, "metadata": {}, "output_type": "execute_result" } ], "source": [ "two_multivariate_normals.log_prob([[[1., 2., 3.]], [[3., 4., 5.]]]) # Input has shape [2,1,3]." ] }, { "cell_type": "markdown", "metadata": { "id": "oPDC6y3qrG6e" }, "source": [ "まったく同じように、[https://www.tensorflow.org/api_docs/cc/class/tensorflow/ops/strided-slice](ストライドスライス表記)を使用して、定数の中央に追加の形状 = 1 次元を挿入できます。" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "M9m9GMezrG6f" }, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 0, "metadata": {}, "output_type": "execute_result" } ], "source": [ "two_multivariate_normals.log_prob(\n", " tf.constant([[1., 2., 3.], [3., 4., 5.]])[:, tf.newaxis, :]) # Equivalent to above." ] }, { "cell_type": "markdown", "metadata": { "id": "jJA07wN7rG6i" }, "source": [ "一方、余分な次元を追加しない場合は、`[1., 2., 3.]` を最初のバッチポイントに渡し、`[3., 4., 5.]` を 2 番目のバッチポイントに渡します。" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "xaX1unvPrG6i" }, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 0, "metadata": {}, "output_type": "execute_result" } ], "source": [ "two_multivariate_normals.log_prob(tf.constant([[1., 2., 3.], [3., 4., 5.]]))" ] }, { "cell_type": "markdown", "metadata": { "id": "JnhW86vcUfT8" }, "source": [ "## 形状変換テクニック" ] }, { "cell_type": "markdown", "metadata": { "id": "6EYcFW7OrG6m" }, "source": [ "### Reshape Bijector\n", "\n", "`Reshape` Bijector を使用すると、分布の *event_shape* の形状を変換できます。以下に例を示します。" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "1YT_lQCarG6m" }, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 0, "metadata": {}, "output_type": "execute_result" } ], "source": [ "six_way_multinomial = tfd.Multinomial(total_count=1000., probs=[.3, .25, .2, .15, .08, .02])\n", "six_way_multinomial" ] }, { "cell_type": "markdown", "metadata": { "id": "c5a5uXQsUpMs" }, "source": [ "`[6]` のイベント形状を持つ多項分布を作成しました。Reshape Bijector を使用すると、これを `[2, 3]` のイベント形状を持つ分布として扱うことができます。\n", "\n", "`Bijector` は、${\\mathbb R}^n$ の開集合上の微分可能な 1 対 1 の関数を表します。`Bijectors` は、`TransformedDistribution` と組み合わせて使用されます。これは、基本分布 $p(x)$ および$Y = g(X)$ を表す `Bijector` に関して分布 $p(y)$ をモデル化します。では、実際に見てみましょう。" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "Wttfn9Q-rG6o" }, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 0, "metadata": {}, "output_type": "execute_result" } ], "source": [ "transformed_multinomial = tfd.TransformedDistribution(\n", " distribution=six_way_multinomial,\n", " bijector=tfb.Reshape(event_shape_out=[2, 3]))\n", "transformed_multinomial" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "sh6l4XZdrG6p" }, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 0, "metadata": {}, "output_type": "execute_result" } ], "source": [ "six_way_multinomial.log_prob([500., 100., 100., 150., 100., 50.])" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "6nMHlVrArG6r" }, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 0, "metadata": {}, "output_type": "execute_result" } ], "source": [ "transformed_multinomial.log_prob([[500., 100., 100.], [150., 100., 50.]])" ] }, { "cell_type": "markdown", "metadata": { "id": "dxZNZ02OrG6t" }, "source": [ "これは、`Reshape` Bijector が実行できる*唯一*のことです。イベント次元をバッチ次元に、またはバッチ次元をイベント次元に変換することはできません。" ] }, { "cell_type": "markdown", "metadata": { "id": "de7ek-FerG6t" }, "source": [ "### Independent 分布\n", "\n", "`Independent` 分布は、独立した、必ずしも同一ではない分布(バッチ)のコレクションを単一の分布として扱うために使用されます。より簡潔に言えば、`Independent` を使用すると、`batch_shape` の次元を `event_shape` の次元に変換できます。次に例を示します。" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "tLwioZPRrG6t" }, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 0, "metadata": {}, "output_type": "execute_result" } ], "source": [ "two_by_five_bernoulli = tfd.Bernoulli(\n", " probs=[[.05, .1, .15, .2, .25], [.3, .35, .4, .45, .5]],\n", " name=\"Two By Five Bernoulli\")\n", "two_by_five_bernoulli" ] }, { "cell_type": "markdown", "metadata": { "id": "-okVviR3rG6v" }, "source": [ "これは、表の確率が関連付けられた 2x5 のコインの配列として考えることができます。特定の任意の 1 と 0 のセットの確率を評価します。" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "9yq9jTGIrG6x" }, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 0, "metadata": {}, "output_type": "execute_result" } ], "source": [ "pattern = [[1., 0., 0., 1., 0.], [0., 0., 1., 1., 1.]]\n", "two_by_five_bernoulli.log_prob(pattern)" ] }, { "cell_type": "markdown", "metadata": { "id": "C9CA19oPrG6y" }, "source": [ "`Independent` を使用すると、これを 2 つの異なる「5 つのベルヌーイのセット」に変換できます。これは、特定のパターンで出現するコイントスの「行」を単一の結果と見なす場合に役立ちます。" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "1iR23BMBrG6z" }, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 0, "metadata": {}, "output_type": "execute_result" } ], "source": [ "two_sets_of_five = tfd.Independent(\n", " distribution=two_by_five_bernoulli,\n", " reinterpreted_batch_ndims=1,\n", " name=\"Two Sets Of Five\")\n", "two_sets_of_five" ] }, { "cell_type": "markdown", "metadata": { "id": "mRrkesaPrG67" }, "source": [ "数学的には、5 つの「セット」ごとの対数確率を計算しています。セット内の 5 つの「独立した」コイントスの対数確率を合計するため、分布は「independent」と呼ばれます。" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "LcM6OgKNrG66" }, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 0, "metadata": {}, "output_type": "execute_result" } ], "source": [ "two_sets_of_five.log_prob(pattern)" ] }, { "cell_type": "markdown", "metadata": { "id": "krpnUVL9rG7A" }, "source": [ "さらに、`Independent` を使用して、個々のイベントが 2x5 のベルヌーイのセットである分布を作成できます。" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "PXSsoidirG7A" }, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 0, "metadata": {}, "output_type": "execute_result" } ], "source": [ "one_set_of_two_by_five = tfd.Independent(\n", " distribution=two_by_five_bernoulli, reinterpreted_batch_ndims=2,\n", " name=\"One Set Of Two By Five\")\n", "one_set_of_two_by_five.log_prob(pattern)" ] }, { "cell_type": "markdown", "metadata": { "id": "QfbfHA4hrG7F" }, "source": [ "`sample` の観点では、`Independent` を使用しても何も変更されないことに注意してください。" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "uZ3NhQEZrG7F" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "tfp.distributions.Bernoulli(\"Two_By_Five_Bernoulli\", batch_shape=[2, 5], event_shape=[], dtype=int32)\n", "Sample shape: [3, 5]\n", "Returned sample tensor shape: (3, 5, 2, 5)\n", "\n", "tfp.distributions.Independent(\"Two_Sets_Of_Five\", batch_shape=[2], event_shape=[5], dtype=int32)\n", "Sample shape: [3, 5]\n", "Returned sample tensor shape: (3, 5, 2, 5)\n", "\n", "tfp.distributions.Independent(\"One_Set_Of_Two_By_Five\", batch_shape=[], event_shape=[2, 5], dtype=int32)\n", "Sample shape: [3, 5]\n", "Returned sample tensor shape: (3, 5, 2, 5)\n", "\n" ] } ], "source": [ "describe_sample_tensor_shapes(\n", " [two_by_five_bernoulli,\n", " two_sets_of_five,\n", " one_set_of_two_by_five],\n", " [[3, 5]])" ] }, { "cell_type": "markdown", "metadata": { "id": "usTT7v0trG7H" }, "source": [ "最後の演習として、サンプリングと対数確率の観点から、`Normal` 分布のベクトルバッチと `MultivariateNormalDiag` 分布の相違点と類似点を検討することをお勧めします。`Independent` を使用して、`Normal` のバッチから `MultivariateNormalDiag` を構築するにはどうすればよいでしょうか?(`MultivariateNormalDiag` は、実際にはこの方法で実装されていません。)" ] } ], "metadata": { "colab": { "name": "Understanding_TensorFlow_Distributions_Shapes.ipynb", "toc_visible": true }, "kernelspec": { "display_name": "Python 3", "name": "python3" } }, "nbformat": 4, "nbformat_minor": 0 }