{ "cells": [ { "cell_type": "markdown", "metadata": { "id": "Vi2Tl3VMnX68" }, "source": [ "##### Copyright 2021 The TensorFlow Probability Authors.\n", "\n", "Licensed under the Apache License, Version 2.0 (the \"License\");" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "cellView": "form", "id": "FW9em4rqnw0S" }, "outputs": [], "source": [ "#@title Licensed under the Apache License, Version 2.0 (the \"License\"); { display-mode: \"form\" }\n", "# you may not use this file except in compliance with the License.\n", "# You may obtain a copy of the License at\n", "#\n", "# https://www.apache.org/licenses/LICENSE-2.0\n", "#\n", "# Unless required by applicable law or agreed to in writing, software\n", "# distributed under the License is distributed on an \"AS IS\" BASIS,\n", "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", "# See the License for the specific language governing permissions and\n", "# limitations under the License." ] }, { "cell_type": "markdown", "metadata": { "id": "Zn9IkdJNQblp" }, "source": [ "# TFP Release Notes notebook (0.13.0)\n", "\n", "This notebook uses a few small snippets (little demos) to show what you can achieve with TFP 0.13.0.\n", "\n", "\n", " \n", " \n", " \n", " \n", "
" ] }, { "cell_type": "code", "execution_count": 1, "metadata": { "id": "Ceywx-aaQblq" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\u001b[K |████████████████████████████████| 5.4MB 8.8MB/s \n", "\u001b[K |████████████████████████████████| 3.9MB 37.1MB/s \n", "\u001b[K |████████████████████████████████| 296kB 31.6MB/s \n", "\u001b[?25h" ] } ], "source": [ "#@title Installs & imports { vertical-output: true }\n", "!pip3 install -qU tensorflow==2.5.0 tensorflow_probability==0.13.0 tensorflow-datasets inference_gym\n", "\n", "import tensorflow as tf\n", "import tensorflow_probability as tfp\n", "assert '0.13' in tfp.__version__, tfp.__version__\n", "assert '2.5' in tf.__version__, tf.__version__\n", "\n", "physical_devices = tf.config.list_physical_devices('CPU')\n", "tf.config.set_logical_device_configuration(\n", " physical_devices[0],\n", " [tf.config.LogicalDeviceConfiguration(),\n", " tf.config.LogicalDeviceConfiguration()])\n", "\n", "tfd = tfp.distributions\n", "tfb = tfp.bijectors\n", "tfpk = tfp.math.psd_kernels\n", "\n", "import matplotlib.pyplot as plt\n", "import numpy as np\n", "import scipy.interpolate\n", "import IPython\n", "import seaborn as sns\n", "import logging" ] }, { "cell_type": "markdown", "metadata": { "id": "iHBsq_t5IIQy" }, "source": [ "## Distributions [core math]" ] }, { "cell_type": "markdown", "metadata": { "id": "6N1TInBM8V1r" }, "source": [ "### `BetaQuotient`\n", "\n", "The ratio of two independent Beta-distributed random variables." ] }, { "cell_type": "code", "execution_count": 2, "metadata": { "id": "Yq4tIvL8lhLW" }, "outputs": [ { "data": { "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAXcAAAD4CAYAAAAXUaZHAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+WH4yJAAAQpklEQVR4nO3df6xfdX3H8edrIOKPuIq9NrXtdrtZNWj8Qe6QhW1B2GYVYvmDEIi6Tlmabeh0umFxyciWkNRtEVm2mXTQURICNMikEfeDIY4tGWUXUPkls0OQNoVeA6ibCa763h/fU/x6e8u99/u9P3o/9/lImnvO55zzPe984L7uJ59zvuekqpAkteWnFrsASdLcM9wlqUGGuyQ1yHCXpAYZ7pLUoOMXuwCAlStX1ujo6GKXIUlLyj333PPtqhqZatsxEe6jo6OMj48vdhmStKQkefxo25yWkaQGGe6S1CDDXZIaZLhLUoMMd0lqkOEuSQ0y3CWpQYa7JDXIcJekBh0T31Cdb6Nbb31++bFtZy9iJZK0MBy5S1KDDHdJapDhLkkNMtwlqUGGuyQ1yHCXpAYZ7pLUoGnDPcmOJAeTPDCp/cNJvp7kwSR/1td+aZK9SR5J8s75KFqS9MJm8iWma4C/Aq493JDkHcAm4C1V9VySV3ftJwMXAG8EXgP8S5LXVdUP57pwSdLRTTtyr6o7gacnNf8OsK2qnuv2Odi1bwJuqKrnquqbwF7g1DmsV5I0A4POub8O+OUke5L8a5Jf6NrXAE/07bevaztCki1JxpOMT0xMDFiGJGkqg4b78cBJwGnAHwK7kmQ2H1BV26tqrKrGRkZGBixDkjSVQcN9H3Bz9dwN/AhYCewH1vXtt7ZrkyQtoEHD/fPAOwCSvA44Afg2sBu4IMmLk6wHNgB3z0WhkqSZm/ZumSTXA2cAK5PsAy4DdgA7utsjfwBsrqoCHkyyC3gIOARc7J0ykrTwpg33qrrwKJved5T9LwcuH6YoSdJw/IaqJDXIcJekBhnuktQgw12SGmS4S1KDDHdJapDhLkkNMtwlqUGGuyQ1yHCXpAYZ7pLUIMNdkhpkuEtSgwx3SWqQ4S5JDZo23JPsSHKwezHH5G0fT1JJVnbrSfKXSfYm+VqSU+ajaEnSC5vJyP0aYOPkxiTrgF8HvtXX/C56r9bbAGwBPjt8iZKk2Zo23KvqTuDpKTZdAVwCVF/bJuDa7sXZdwErkqyek0olSTM20Jx7kk3A/qr66qRNa4An+tb3dW1TfcaWJONJxicmJgYpQ5J0FLMO9yQvBT4J/PEwJ66q7VU1VlVjIyMjw3yUJGmSaV+QPYWfB9YDX00CsBa4N8mpwH5gXd++a7s2SdICmvXIvarur6pXV9VoVY3Sm3o5paqeBHYDv9HdNXMa8J2qOjC3JUuSpjOTWyGvB/4DeH2SfUkueoHdvwg8CuwF/hb43TmpUpI0K9NOy1TVhdNsH+1bLuDi4cuSJA1jkDn3JWF0662LXYIkLRofPyBJDTLcJalBhrskNchwl6QGGe6S1CDDXZIaZLhLUoMMd0lqkOEuSQ0y3CWpQYa7JDXIcJekBhnuktQgw12SGmS4S1KDZvImph1JDiZ5oK/tz5N8PcnXkvx9khV92y5NsjfJI0neOV+FS5KObiYj92uAjZPabgPeVFVvBv4LuBQgycnABcAbu2P+Jslxc1atJGlGpg33qroTeHpS2z9X1aFu9S5gbbe8Cbihqp6rqm/Se5fqqXNYryRpBuZizv2DwD90y2uAJ/q27evajpBkS5LxJOMTExNzUIYk6bChwj3JHwGHgOtme2xVba+qsaoaGxkZGaYMSdIkA78gO8lvAucAZ1VVdc37gXV9u63t2iRJC2igkXuSjcAlwHuq6vt9m3YDFyR5cZL1wAbg7uHLlCTNxrQj9yTXA2cAK5PsAy6jd3fMi4HbkgDcVVW/XVUPJtkFPERvuubiqvrhfBUvSZr
atOFeVRdO0Xz1C+x/OXD5MEVJkoYz8Jz7UjW69dbnlx/bdvYiViJJ88fHD0hSgwx3SWqQ4S5JDTLcJalBhrskNchwl6QGGe6S1CDDXZIaZLhLUoMMd0lqkOEuSQ0y3CWpQYa7JDXIcJekBk0b7kl2JDmY5IG+tpOS3JbkG93PV3btSfKXSfYm+VqSU+azeEnS1GYycr8G2DipbStwe1VtAG7v1gHeRe/VehuALcBn56ZMSdJsTBvuVXUn8PSk5k3Azm55J3BuX/u11XMXsCLJ6rkqVpI0M4POua+qqgPd8pPAqm55DfBE3377urYjJNmSZDzJ+MTExIBlSJKmMvQF1aoqoAY4bntVjVXV2MjIyLBlSJL6DBruTx2ebul+Huza9wPr+vZb27VJkhbQoC/I3g1sBrZ1P2/pa/9QkhuAtwPf6Zu+Oeb4smxJrZo23JNcD5wBrEyyD7iMXqjvSnIR8Dhwfrf7F4F3A3uB7wMfmIeaJUnTmDbcq+rCo2w6a4p9C7h42KIkScPxG6qS1CDDXZIaZLhLUoMMd0lqkOEuSQ0y3CWpQYa7JDXIcJekBhnuktQgw12SGmS4S1KDDHdJapDhLkkNMtwlqUGGuyQ1aNA3MR2T+t+sJEnL2VAj9yS/n+TBJA8kuT7JiUnWJ9mTZG+SG5OcMFfFSpJmZuBwT7IG+D1grKreBBwHXAB8Criiql4LPANcNBeFSpJmbtg59+OBlyQ5HngpcAA4E7ip274TOHfIc0iSZmngcK+q/cBfAN+iF+rfAe4Bnq2qQ91u+4A1Ux2fZEuS8STjExMTg5YhSZrCMNMyrwQ2AeuB1wAvAzbO9Piq2l5VY1U1NjIyMmgZkqQpDDMt86vAN6tqoqr+D7gZOB1Y0U3TAKwF9g9ZoyRploYJ928BpyV5aZIAZwEPAXcA53X7bAZuGa5ESdJsDTPnvofehdN7gfu7z9oOfAL4WJK9wKuAq+egTknSLAz1Jaaqugy4bFLzo8Cpw3yuJGk4Pn5AkhpkuEtSgwx3SWqQ4S5JDTLcJalBhrskNchwl6QGGe6S1CDDXZIaZLhLUoMMd0lqkOEuSQ0y3CWpQYa7JDXIcJekBg31PPckK4CrgDcBBXwQeAS4ERgFHgPOr6pnhqpygY1uvfX55ce2nb2IlUjSYIYduV8J/GNVvQF4C/AwsBW4vao2ALd365KkBTRwuCf5aeBX6F6jV1U/qKpngU3Azm63ncC5wxYpSZqdYUbu64EJ4O+S3JfkqiQvA1ZV1YFunyeBVVMdnGRLkvEk4xMTE0OUIUmabJhwPx44BfhsVb0N+F8mTcFUVdGbiz9CVW2vqrGqGhsZGRmiDEnSZMNcUN0H7KuqPd36TfTC/akkq6vqQJLVwMFhi1wI/RdRJWmpG3jkXlVPAk8keX3XdBbwELAb2Ny1bQZuGapCSdKsDXUrJPBh4LokJwCPAh+g9wdjV5KLgMeB84c8hyRploYK96r6CjA2xaazhvlcSdJw/IaqJDXIcJekBhnuktQgw12SGmS4S1KDDHdJapDhLkkNMtwlqUGGuyQ1aNjHDzTPtzJJWoocuUtSgwx3SWqQ4S5JDTLcJalBhrskNchwl6QGDR3uSY5Lcl+SL3Tr65PsSbI3yY3dW5okSQtoLkbuHwEe7lv/FHBFVb0WeAa4aA7OIUmahaHCPcla4Gzgqm49wJnATd0uO4FzhzmHJGn2hh25fwa4BPhRt/4q4NmqOtSt7wPWTHVgki1JxpOMT0xMDFmGJKnfwOGe5BzgYFXdM8jxVbW9qsaqamxkZGTQMiRJUxjm2TKnA+9J8m7gROAVwJXAiiTHd6P3tcD+4cuUJM3GwCP3qrq0qtZW1ShwAfClqnovcAdwXrfbZuCWoauUJM3KfNzn/gngY0n20puDv3oeziFJegFz8sjfqvoy8OVu+VHg1Ln4XEnSYPyGqiQ1yHCXpAYZ7pLUIMNdkhpkuEtSgwx3SWrQnNwKuRyNbr31+eXHtp2
9iJVI0pEM91noD3RJOpY5LSNJDTLcJalBhrskNchwl6QGGe6S1CDDXZIaZLhLUoMGvs89yTrgWmAVUMD2qroyyUnAjcAo8BhwflU9M3ypU/Pec0k60jAj90PAx6vqZOA04OIkJwNbgduragNwe7cuSVpAw7xD9UBV3dstfw94GFgDbAJ2drvtBM4dtkhJ0uzMyZx7klHgbcAeYFVVHeg2PUlv2maqY7YkGU8yPjExMRdlSJI6Q4d7kpcDnwM+WlXf7d9WVUVvPv4IVbW9qsaqamxkZGTYMiRJfYZ6cFiSF9EL9uuq6uau+akkq6vqQJLVwMFhizzW+YRISceagUfuSQJcDTxcVZ/u27Qb2NwtbwZuGbw8SdIghhm5nw68H7g/yVe6tk8C24BdSS4CHgfOH65ESdJsDRzuVfXvQI6y+axBP3epOzxF4/SMpMXkyzoWmPPzkhaCjx+QpAY5cp8nPhZB0mJy5C5JDTLcJalBhrskNchwl6QGeUF1EXlbpKT54shdkhrkyP0YcbRbJx3RSxqEI3dJapDhLkkNMtwlqUHOuR/j5uqOGu/MkZYXw30J8aKrpJky3Bs23cPLHM1L7Zq3cE+yEbgSOA64qqq2zde5lru5COmjfcZMnm45V38YfNGJNHfm5YJqkuOAvwbeBZwMXJjk5Pk4lyTpSPM1cj8V2FtVjwIkuQHYBDw0T+dTZyYj7dlM1yy2pTZ1tJTr7bcUal/KFuL/k1TV3H9och6wsap+q1t/P/D2qvpQ3z5bgC3d6uuBR2ZxipXAt+eo3KXKPrAPDrMflm8f/GxVjUy1YdEuqFbVdmD7IMcmGa+qsTkuaUmxD+yDw+wH+2Aq8/Ulpv3Aur71tV2bJGkBzFe4/yewIcn6JCcAFwC75+lckqRJ5mVapqoOJfkQ8E/0boXcUVUPzuEpBprOaYx9YB8cZj/YB0eYlwuqkqTF5YPDJKlBhrskNWhJhXuSjUkeSbI3ydbFrmehJNmR5GCSB/raTkpyW5JvdD9fuZg1zrck65LckeShJA8m+UjXvmz6IcmJSe5O8tWuD/6ka1+fZE/3e3FjdxND05Icl+S+JF/o1pddH0xnyYT7Mn+kwTXAxkltW4Hbq2oDcHu33rJDwMer6mTgNODi7r//cuqH54Azq+otwFuBjUlOAz4FXFFVrwWeAS5axBoXykeAh/vWl2MfvKAlE+70PdKgqn4AHH6kQfOq6k7g6UnNm4Cd3fJO4NwFLWqBVdWBqrq3W/4evV/sNSyjfqie/+lWX9T9K+BM4Kauvek+AEiyFjgbuKpbD8usD2ZiKYX7GuCJvvV9XdtytaqqDnTLTwKrFrOYhZRkFHgbsIdl1g/ddMRXgIPAbcB/A89W1aFul+Xwe/EZ4BLgR936q1h+fTCtpRTuOorq3c+6LO5pTfJy4HPAR6vqu/3blkM/VNUPq+qt9L71fSrwhkUuaUElOQc4WFX3LHYtx7ql9LIOH2nwk55KsrqqDiRZTW8k17QkL6IX7NdV1c1d87LrB4CqejbJHcAvAiuSHN+NXFv/vTgdeE+SdwMnAq+g996I5dQHM7KURu4+0uAn7QY2d8ubgVsWsZZ5182rXg08XFWf7tu0bPohyUiSFd3yS4Bfo3ft4Q7gvG63pvugqi6tqrVVNUovA75UVe9lGfXBTC2pb6h2f60/w48faXD5Ipe0IJJcD5xB77GmTwGXAZ8HdgE/AzwOnF9Vky+6NiPJLwH/BtzPj+daP0lv3n1Z9EOSN9O7WHgcvYHZrqr60yQ/R+8Gg5OA+4D3VdVzi1fpwkhyBvAHVXXOcu2DF7Kkwl2SNDNLaVpGkjRDhrskNchwl6QGGe6S1CDDXZIaZLhLUoMMd0lq0P8DYL+jYbvBk1oAAAAASUVORK5CYII=\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light", "tags": [] }, "output_type": "display_data" } ], "source": [ "plt.hist(tfd.BetaQuotient(concentration1_numerator=5.,\n", " concentration0_numerator=2.,\n", " concentration1_denominator=3.,\n", " concentration0_denominator=8.).sample(1_000, seed=(1, 23)),\n", " bins='auto');" ] }, { "cell_type": "markdown", "metadata": { "id": "DtR1AvKz9y-P" }, "source": [ "### `DeterminantalPointProcess`\n", "\n", "A distribution over subsets of a given ground set, with each subset represented as a multi-hot vector. Samples exhibit repulsion: the probability of a subset is proportional to the determinant of the kernel submatrix for the selected points, i.e. the squared volume spanned by their feature vectors, so the distribution tends to produce diverse subsets. [Compare against i.i.d. Bernoulli samples.]" ] }, { "cell_type": "code", "execution_count": 3, "metadata": { "id": "QS5JxWys9ygT" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "tfp.distributions.DeterminantalPointProcess(\"DeterminantalPointProcess\", batch_shape=[4], event_shape=[256], dtype=int32)\n" ] }, { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "tags": [] }, "output_type": "display_data" } ], "source": [ "grid_size = 16\n", "# Generate grid_size**2 pts on the unit square.\n", "grid = np.arange(0, 1, 1./grid_size).astype(np.float32)\n", "import itertools\n", "points = np.array(list(itertools.product(grid, grid)))\n", "\n", "# Create the kernel L that parameterizes the DPP.\n", "kernel_amplitude = 2.\n", "kernel_lengthscale = [.1, .15, .2, .25] # Increasing length scale indicates more points are \"nearby\", tending toward smaller subsets.\n", "kernel = tfpk.ExponentiatedQuadratic(kernel_amplitude, kernel_lengthscale)\n", "kernel_matrix = kernel.matrix(points, points)\n", "\n", "eigenvalues, eigenvectors = tf.linalg.eigh(kernel_matrix)\n", "dpp = tfd.DeterminantalPointProcess(eigenvalues, eigenvectors)\n", "print(dpp)\n", "\n", "# The inner-most dimension of the result of `dpp.sample` is a multi-hot\n", "# encoding of a subset of {1, ..., ground_set_size}.\n", "# We will compare against a bernoulli distribution.\n", "samps_dpp = dpp.sample(seed=(1, 2)) # 4 x grid_size**2\n", "logits = tf.broadcast_to([[-1.], [-1.5], [-2], [-2.5]], [4, grid_size**2])\n", "samps_bern = tfd.Bernoulli(logits=logits).sample(seed=(2, 3))\n", "\n", "plt.figure(figsize=(12, 6))\n", "for i, (samp, samp_bern) in enumerate(zip(samps_dpp, samps_bern)):\n", " plt.subplot(241 + i)\n", " plt.scatter(*points[np.where(samp)].T)\n", " plt.title(f'DPP, length scale={kernel_lengthscale[i]}')\n", " plt.xticks([])\n", " plt.yticks([])\n", " plt.gca().set_aspect(1.)\n", " plt.subplot(241 + i + 4)\n", " plt.scatter(*points[np.where(samp_bern)].T)\n", " plt.title(f'bernoulli, logit={logits[i,0]}')\n", " plt.xticks([])\n", " plt.yticks([])\n", " plt.gca().set_aspect(1.)\n", "\n", "plt.tight_layout()\n", "plt.show()" ] }, { "cell_type": "markdown", "metadata": { "id": "qim-evPz8e72" }, "source": [ "### `SigmoidBeta`\n", "\n", "The log-odds of a `Beta` random variable (equivalently, the log-ratio of two independent Gamma random variables). Sampling directly in this log-odds space is more numerically stable than transforming `Beta` samples, as the fractions of non-finite values printed below show." ] }, { "cell_type": "code", "execution_count": 4, 
"metadata": { "id": "eIpGyo7Glx9s" }, "outputs": [ { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAYAAAAD4CAYAAADlwTGnAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+WH4yJAAAQqklEQVR4nO3dbYwd113H8e8Pm1i0gJsmRlDbxa7igjYgaLt1w6MKLsTBFX6TUkcqMsUQtUr6pEpV3EoVCopk04oQ1ATJSoKSEMkNbhErakhbhSDxInY2fSRxLRbbxU6L6sYhPCkJa/68uJP05mbXe9f75PX5fiTLM2fOzD0zGs1vZ87cc1NVSJLa831L3QBJ0tIwACSpUQaAJDXKAJCkRhkAktSolUvdgNm4/PLLa8OGDUvdDElaNh577LHvVtWaqZYtqwDYsGED4+PjS90MSVo2knxzumU+ApKkRhkAktQoA0CSGmUASFKjDABJapQBIEmNMgAkqVEGgCQ1ygCQpEYtq28CS9LFasNNn3vJ/Ik92xb8M70DkKRGGQCS1CgDQJIaZQBIUqMMAElqlAEgSY0yACSpUQaAJDXKAJCkRhkAktQoA0CSGmUASFKjDABJapQBIEmNMgAkqVEGgCQ1ygCQpEYN9YtgSbYCtwErgDuras/A8lXAvcCbgKeAd1bViW7ZbmAXcBZ4f1U92JV/CPg9oICvA++uqmfnYZ8kaVkY/BWwxTbjHUCSFcDtwDXACHBdkpGBaruAp6vqCuBWYG+37giwA7gS2ArckWRFkrXA+4HRqvopesGyY352SZI0jGEeAW0GJqrqWFU9D+wHtg/U2Q7c000fALYkSVe+v6qeq6rjwES3PejdffxAkpXAK4BvzW1XJEmzMUwArAVO9s2f6sqmrFNVk8AzwGXTrVtVTwKfBP4V+DbwTFV9fqoPT3J9kvEk46dPnx6iuZKkYSxJJ3CSS+ndHWwEXgO8Msm7pqpbVfuqarSqRtesWbOYzZSki9owAfAksL5vfl1XNmWd7pHOanqdwdOt+zbgeFWdrqr/BT4L/Pz57IAk6fwMEwCPApuSbExyCb3O2rGBOmPAzm76WuChqqqufEeSVUk2ApuAw/Qe/VyV5BVdX8EW4Mjcd0eSNKwZXwOtqskkNwIP0ntb5+6qejzJzcB4VY0BdwH3JZkAztC90dPVewB4ApgEbqiqs8ChJAeAL3XlXwb2zf/uSZKmk94f6svD6OhojY+PL3UzJGlenOt7ACf2bJuXz0jyWFWNTrXMbwJLUqMMAElqlAEgSY0yACSpUQaAJDXKAJCkRhkAktQoA0CSGmUASFKjDABJapQBIEmNMgAkqVEGgCQ1asbhoCVJ8+Nco38uBe8AJKlRBoAkNcoAkKRGGQCS1CgDQJIaZQBIUqMMAElqlAEgSY0yACSpUQaAJDXKAJCkRhkAktQoA0CSGmUASFKjHA5akhbQhTYEdD/vACSpUQaAJDXKAJCkRhkAktQoA0CSGmUASFKjfA1Uki5A/a+PntizbUE+wwCQpHl0Ib/3P2ioR0BJtiY5mmQiyU1TLF+V5NPd8kNJNvQt292VH01ydV/5q5IcSPKNJEeS/Nx87JAkaTgzBkCSFcDtwDXACHBdkpGBaruAp6vqCuBWYG+37giwA7gS2Arc0W0P4Dbg76rqJ4GfAY7MfXckScMa5g5gMzBRVceq6nlgP7B9oM524J5u+gCwJUm68v1V9VxVHQcmgM1JVgO/DNwFUFXPV9W/z313JEnDGiYA1gIn++ZPdWVT1qmqSeAZ4LJzrLsROA38eZIvJ7kzySun+vAk1ycZTzJ++vTpIZorSRrGUr0GuhJ4I/BnVfUG4L+Bl/UtAFTVvqoararRNWvWLGYbJemiNkwAPAms75tf15VNWSf
JSmA18NQ51j0FnKqqQ135AXqBIElaJMMEwKPApiQbk1xCr1N3bKDOGLCzm74WeKiqqivf0b0ltBHYBByuqn8DTib5iW6dLcATc9wXSdIszPg9gKqaTHIj8CCwAri7qh5PcjMwXlVj9Dpz70syAZyhFxJ09R6gd3GfBG6oqrPdpt8H3N+FyjHg3fO8b5Kkcxjqi2BVdRA4OFD28b7pZ4F3TLPuLcAtU5R/BRidTWMlSfPHsYAkqVEGgCQ1ygCQpEYZAJLUKANAkhplAEhSowwASWqUASBJjTIAJKlRBoAkNcoAkKRGGQCS1KihBoOTJE1vw02fW+omnBfvACSpUQaAJDXKAJCkRhkAktQoA0CSGmUASFKjDABJapQBIEmNMgAkqVEGgCQ1ygCQpEYZAJLUKANAkhplAEhSowwASWqUvwcgSbO0XMf/H2QASNIQLpaLfj8fAUlSowwASWqUASBJjTIAJKlRBoAkNcoAkKRGGQCS1CgDQJIaNVQAJNma5GiSiSQ3TbF8VZJPd8sPJdnQt2x3V340ydUD661I8uUkfzPXHZEkzc6MAZBkBXA7cA0wAlyXZGSg2i7g6aq6ArgV2NutOwLsAK4EtgJ3dNt7wQeAI3PdCUnS7A1zB7AZmKiqY1X1PLAf2D5QZztwTzd9ANiSJF35/qp6rqqOAxPd9kiyDtgG3Dn33ZAkzdYwYwGtBU72zZ8C3jJdnaqaTPIMcFlX/sjAumu76T8BPgL80Lk+PMn1wPUAr33ta4doriTN3cU49s+gJekETvJ24DtV9dhMdatqX1WNVtXomjVrFqF1ktSGYQLgSWB93/y6rmzKOklWAquBp86x7i8Av5nkBL1HSr+a5C/Oo/2SpPM0TAA8CmxKsjHJJfQ6dccG6owBO7vpa4GHqqq68h3dW0IbgU3A4araXVXrqmpDt72Hqupd87A/kqQhzdgH0D3TvxF4EFgB3F1Vjye5GRivqjHgLuC+JBPAGXoXdbp6DwBPAJPADVV1doH2RZI0C0P9IExVHQQODpR9vG/6WeAd06x7C3DLObb9MPDwMO2QJM0ffxFMkjotvPnTz6EgJKlRBoAkNcoAkKRGGQCS1CgDQJIaZQBIUqMMAElqlAEgSY0yACSpUQaAJDXKoSAkNau1oR8GeQcgSY0yACSpUQaAJDXKAJCkRhkAktQoA0CSGmUASFKjDABJapQBIEmN8pvAkprS+rd/+3kHIEmNMgAkqVEGgCQ1ygCQpEbZCSzpoman7/S8A5CkRhkAktQoA0CSGmUASFKjDABJapQBIEmN8jVQScve4KueJ/ZsW6KWLC/eAUhSo7wDkHTR8ctfw/EOQJIaNVQAJNma5GiSiSQ3TbF8VZJPd8sPJdnQt2x3V340ydVd2fokf5/kiSSPJ/nAfO2QJGk4MwZAkhXA7cA1wAhwXZKRgWq7gKer6grgVmBvt+4IsAO4EtgK3NFtbxL4cFWNAFcBN0yxTUnSAhrmDmAzMFFVx6rqeWA/sH2gznbgnm76ALAlSbry/VX1XFUdByaAzVX17ar6EkBV/SdwBFg7992RJA1rmABYC5zsmz/Fyy/WL9apqkngGeCyYdbtHhe9ATg01YcnuT7JeJLx06dPD9FcSdIwlrQTOMkPAp8BPlhV/zFVnaraV1WjVTW6Zs2axW2gJF3EhgmAJ4H1ffPrurIp6yRZCawGnjrXukm+n97F//6q+uz5NF6SdP6GCYBHgU1JNia5hF6n7thAnTFgZzd9LfBQVVVXvqN7S2gjsAk43PUP3AUcqao/no8dkSTNzoxfBKuqySQ3Ag8CK4C7q+rxJDcD41U1Ru9ifl+SCeAMvZCgq/cA8AS9N39uqKqzSX4R+G3g60m+0n3UR6vq4HzvoKSLg8M9zL+hvgncXZgPDpR9vG/6WeAd06x7C3DLQNk/ApltYyVJ88dvAktSowwASWqUg8FJWpYc8G3uvAOQpEYZAJLUKB8BSbpg+ZhnYXkHIEmN8g5A0gX
Dv/gXl3cAktQoA0CSGmUASFKj7AOQtKR87r90vAOQpEYZAJLUKANAkhplAEhSowwASWqUbwFJmne+2bM8eAcgSY3yDkDSvPCv/uXHAJB0XrzgL38GgKRpDV7kT+zZtkQt0UIwACQNzb/6Ly52AktSo7wDkPQS/pXfDu8AJKlRBoAkNcoAkKRGGQCS1Cg7gaUG2dErMACkJnjB11R8BCRJjTIAJKlRBoAkNco+AOkCd67n9/2Ds/mcX7NlAEiLpP8CPV+janrR11wYANICOd+Lsxd1LRYDQJpHXvS1nAwVAEm2ArcBK4A7q2rPwPJVwL3Am4CngHdW1Ylu2W5gF3AWeH9VPTjMNqXlwAu+lrMZAyDJCuB24NeAU8CjScaq6om+aruAp6vqiiQ7gL3AO5OMADuAK4HXAF9M8vpunZm2Kc3ZsB2os1lPulgMcwewGZioqmMASfYD24H+i/V24A+66QPAp5KkK99fVc8Bx5NMdNtjiG3qPM3mZ/zOt2Nysd9MWYiLtRd5tW6YAFgLnOybPwW8Zbo6VTWZ5Bngsq78kYF113bTM20TgCTXA9d3s/+V5OgQbb4QXA58d6kbAZC981tvDts572MyX227wFww58gFxGPycpdn75yOyY9Pt+CC7wSuqn3AvqVux2wlGa+q0aVux4XEY/JSHo+X85i83EIek2G+CfwksL5vfl1XNmWdJCuB1fQ6g6dbd5htSpIW0DAB8CiwKcnGJJfQ69QdG6gzBuzspq8FHqqq6sp3JFmVZCOwCTg85DYlSQtoxkdA3TP9G4EH6b2yeXdVPZ7kZmC8qsaAu4D7uk7eM/Qu6HT1HqDXuTsJ3FBVZwGm2ub8796SWnaPrRaBx+SlPB4v5zF5uQU7Jun9oS5Jao2jgUpSowwASWqUATBHSX42ySNJvpJkPMnmrjxJ/jTJRJKvJXlj3zo7k/xz92/n9FtfvpK8L8k3kjye5I/6ynd3x+Rokqv7yrd2ZRNJblqaVi+8JB9OUkku7+abPU+SfKI7R76W5K+SvKpvWdPnyQsWfH+ryn9z+Ad8Hrimm/4N4OG+6b8FAlwFHOrKXw0c6/6/tJu+dKn3Y56Pya8AXwRWdfM/0v0/AnwVWAVsBP6F3ksAK7rp1wGXdHVGlno/FuC4rKf34sM3gcs9T/h1YGU3vRfY63nykuOz4PvrHcDcFfDD3fRq4Fvd9Hbg3up5BHhVkh8Drga+UFVnqupp4AvA1sVu9AJ7L7CnekOAUFXf6cpfHBqkqo4DLwwN8uJwI1X1PPDC0CAXm1uBj9A7Z17Q7HlSVZ+vqslu9hF63wcCz5MXLPj+GgBz90HgE0lOAp8EdnflUw2hsfYc5ReT1wO/lORQkn9I8uauvNljkmQ78GRVfXVgUbPHZMDv0rsTAo/JCxZ8fy/4oSAuBEm+CPzoFIs+BmwBPlRVn0nyW/S+E/G2xWzfUpjhmKyk9+jiKuDNwANJXreIzVsSMxyTj9J75NGUcx2Tqvrrrs7H6H1P6P7FbJsMgKFU1bQX9CT3Ah/oZv8SuLObPtcwGG8dKH94npq6aGY4Ju8FPlu9B5mHk/wfvUG+zjUEyLIfGmS6Y5Lkp+k9y/5qb5Bc1gFf6l4YaPY8AUjyO8DbgS3d+QIX+XkyCws/ZM5Sd3Qs93/AEeCt3fQW4LFuehsv7dw73JW/GjhOr2Pv0m761Uu9H/N8TN4D3NxNv57ebWzo/S5Ef+feMXodXSu76Y18r7PryqXejwU8Pif4Xidwy+fJVnqjBKwZKPc86R2HBd9f7wDm7veB27pB8J7le0NXH6T3hscE8D/AuwGq6kySP6Q3HhL0LpRnFrfJC+5u4O4k/wQ8D+ys3hnd8tAg02n5PPkUvYv8F7o7o0eq6j3V9hAyL6pphuGZz89wKAhJapRvAUlSowwASWqUASBJjTIAJKlRBoAkNcoAkKRGGQCS1Kj/B0gffgdXuhSpAAAAAElFTkSuQmCC\n", "text/plain": 
[ "
" ] }, "metadata": { "needs_background": "light", "tags": [] }, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "Old way, fractions non-finite:\n", "0.4215\n", "0.8624\n" ] } ], "source": [ "plt.hist(tfd.SigmoidBeta(concentration1=.01, concentration0=2.).sample(10_000, seed=(1, 23)),\n", " bins='auto', density=True);\n", "plt.show()\n", "\n", "print('Old way, fractions non-finite:')\n", "print(np.sum(~tf.math.is_finite(\n", " tfb.Invert(tfb.Sigmoid())(tfd.Beta(concentration1=.01, concentration0=2.)).sample(10_000, seed=(1, 23)))) / 10_000)\n", "print(np.sum(~tf.math.is_finite(\n", " tfb.Invert(tfb.Sigmoid())(tfd.Beta(concentration1=2., concentration0=.01)).sample(10_000, seed=(2, 34)))) / 10_000)" ] }, { "cell_type": "markdown", "metadata": { "id": "hGGkU_8A8tOn" }, "source": [ "### Zipf\n", "\n", "JAX サポートが追加されました。" ] }, { "cell_type": "code", "execution_count": 5, "metadata": { "id": "l8bw9c49qPoY" }, "outputs": [ { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXwAAAD4CAYAAADvsV2wAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+WH4yJAAANF0lEQVR4nO3df6xf9V3H8efLVlSm3jlhUyn1MmEoTifLFadE49ww1VK6GKMQNRjIGhbBaRZNUeOf0qjxxwJxaQZ2iwghiLNdmUCmk39wUtBtsIprkMFFZjuNV6OJSnz7x/0WL9fey7e9p9/zoZ/nI2l6v4fv/Z536e3znvs5556bqkKSdOb7krEHkCTNhsGXpE4YfEnqhMGXpE4YfEnqxOaxB1jPOeecU/Pz82OPIUmvKo899tgXq+rc1dubDv78/DyHDh0aewxJelVJ8vkTbXdJR5I6YfAlqRNNBj/JjiR7l5aWxh5Fks4YTQa/qg5U1a65ubmxR5GkM0aTwZckDa/J4LukI0nDazL4LulI0vCaDL4kaXhNf+PVRszvPjjT/T2zZ/tM9ydJJ8sjfEnqRJPB96StJA2vyeB70laShtdk8CVJwzP4ktQJgy9JnWgy+J60laThNRl8T9pK0vCaDL4kaXgGX5I6YfAlqRMGX5I6YfAlqRNNBt/LMiVpeE0G38syJWl4TQZfkjQ8gy9JnTD4ktQJgy9JnTD4ktQJgy9JnWgy+F6HL0nDazL4XocvScNrMviSpOEZfEnqhMGXpE4YfEnqhMGXpE4YfEnqhMGXpE4YfEnqhMGXpE4YfEnqRJPB9146kjS8JoPvvXQkaXhNBl+SNDyDL0mdMPiS1AmDL0mdMPiS1AmDL0mdMPiS1AmDL0mdMPiS1AmDL0mdMPiS1AmDL0mdMPiS1AmDL0mdmFnwk7wxye1J7p3VPiVJ/2eq4Ce5I8nRJE+s2r4tyVNJjiTZvd5rVNXTVXX9RoaVJJ26
zVM+bx9wK/Dh4xuSbAJuA64AFoFHk+wHNgG3rHr/66rq6IanlSSdsqmCX1UPJ5lftfky4EhVPQ2Q5G5gZ1XdAlw55JCSpI3byBr+ecBzKx4vTradUJKvTfIB4NIkN6/zvF1JDiU5dOzYsQ2MJ0laadolnQ2rqn8CbpjieXuBvQALCwt1uueSpF5s5Aj/eeD8FY+3TLZJkhq0keA/ClyU5IIkZwFXA/uHGCrJjiR7l5aWhng5SRLTX5Z5F/AIcHGSxSTXV9WLwI3AA8Bh4J6qenKIoarqQFXtmpubG+LlJElMf5XONWtsvx+4f9CJJEmnRZO3VnBJR5KG12TwXdKRpOE1GXxJ0vAMviR1osngu4YvScNrMviu4UvS8JoMviRpeAZfkjph8CWpE00G35O2kjS8JoPvSVtJGl6TwZckDc/gS1InDL4kdaLJ4HvSVpKG12TwPWkrScNrMviSpOEZfEnqhMGXpE4YfEnqRJPB9yodSRpek8H3Kh1JGl6TwZckDc/gS1InDL4kdcLgS1InDL4kdcLgS1InDL4kdaLJ4PuNV5I0vCaD7zdeSdLwmgy+JGl4Bl+SOrF57AHOFPO7D85sX8/s2T6zfUk6c3iEL0mdMPiS1AmDL0mdMPiS1AmDL0mdMPiS1IkmL8tMsgPYceGFF449SpNmeQkoeBmodKZo8gjfWytI0vCaDL4kaXgGX5I6YfAlqRMGX5I6YfAlqRMGX5I6YfAlqRMGX5I6YfAlqRMGX5I6YfAlqRMGX5I6YfAlqRMGX5I6YfAlqRMz+wEoSd4FbAe+Gri9qh6c1b4lSVMe4Se5I8nRJE+s2r4tyVNJjiTZvd5rVNVHqurdwA3Aj5/6yJKkUzHtEf4+4Fbgw8c3JNkE3AZcASwCjybZD2wCbln1/tdV1dHJ278yeT9J0gxNFfyqejjJ/KrNlwFHquppgCR3Azur6hbgytWvkSTAHuBjVfX4WvtKsgvYBbB169ZpxpMkTWEjJ23PA55b8Xhxsm0tNwHvBH40yQ1rPamq9lbVQlUtnHvuuRsYT5K00sxO2lbV+4H3z2p/kqSX28gR/vPA+Sseb5ls27AkO5LsXVpaGuLlJElsLPiPAhcluSDJWcDVwP4hhqqqA1W1a25uboiXkyQx/WWZdwGPABcnWUxyfVW9CNwIPAAcBu6pqidP36iSpI2Y9iqda9bYfj9w/6ATSZJOiyZvreAaviQNr8ngu4YvScNrMviSpOEZfEnqRJPBdw1fkobXZPBdw5ek4TUZfEnS8Ay+JHWiyeC7hi9Jw2sy+K7hS9Lwmgy+JGl4Bl+SOmHwJakTTQbfk7aSNLwmg+9JW0kaXpPBlyQNz+BLUicMviR1wuBLUicMviR1osnge1mmJA2vyeB7WaYkDa/J4EuShmfwJakTm8ceQO2b331wpvt7Zs/2me5P6oVH+JLUCYMvSZ0w+JLUiSaD73X4kjS8JoPvdfiSNLwmgy9JGp7Bl6ROGHxJ6oTBl6ROGHxJ6oTBl6ROGHxJ6oTBl6ROGHxJ6oS3R1ZzZnk7Zm/FrJ40eYTvvXQkaXhNBt976UjS8JoMviRpeAZfkjph8CWpEwZfkjph8CWpEwZfkjph8CWpEwZfkjph8CWpEwZfkjph8CWpEwZfkjph8CWpEwZfkjph8CWpEzMLfpJvSfKBJPcmec+s9itJWjZV8JPckeRokidWbd+W5KkkR5LsXu81qupwVd0A/Bhw+amPLEk6FdMe4e8Dtq3ckGQTcBvwQ8AlwDVJLknybUk+uurX6yfvcxVwELh/sD+BJGkqU/0Q86p6OMn8qs2XAUeq6mmAJHcDO6vqFuDKNV5nP7A/yUHgD0/0nCS7gF0AW7dunWY8SdIUpgr+Gs4DnlvxeBH4rrWenOT7gR8Bvox1jvCrai+wF2BhYaE2MJ8kaYWNBP+kVNUngE/Man+SpJfbyFU6zwPnr3i8ZbJtw5LsSLJ3aWlpiJeTJLGx4D8KXJTkgiRnAVcD+4cY
qqoOVNWuubm5IV5OksT0l2XeBTwCXJxkMcn1VfUicCPwAHAYuKeqnjx9o0qSNmLaq3SuWWP7/XiJpSS9KjR5awXX8CVpeE0G3zV8SRpek8GXJA2vyeC7pCNJw5vZN16djKo6ABxYWFh499iz6Mw2v/vgTPf3zJ7tM92fhjPLj5XT9XHS5BG+JGl4Bl+SOmHwJakTTQbfk7aSNLwmg+91+JI0vCaDL0kansGXpE4YfEnqRJPB96StJA0vVe3+2Ngkx4DPjz0HcA7wxbGHOAHnOjnOdXKc6+S0NNc3VtW5qzc2HfxWJDlUVQtjz7Gac50c5zo5znVyWp1rpSaXdCRJwzP4ktQJgz+dvWMPsAbnOjnOdXKc6+S0OtdLXMOXpE54hC9JnTD4ktQJg7+OJOcn+fMkn03yZJL3jj3TcUk2JfnrJB8de5aVkrw2yb1J/jbJ4STfPfZMAEl+fvJ3+ESSu5J8+Uhz3JHkaJInVmx7XZKHknxu8vvXNDLXb0z+Hj+d5I+TvLaFuVb8t/clqSTntDJXkpsm/8+eTPLrs57rlRj89b0IvK+qLgHeBvxMkktGnum49wKHxx7iBH4X+NOq+mbgLTQwY5LzgJ8FFqrqzcAm4OqRxtkHbFu1bTfw8aq6CPj45PGs7eP/z/UQ8Oaq+nbg74CbZz0UJ56LJOcDPwg8O+uBJvaxaq4kbwd2Am+pqm8FfnOEudZl8NdRVS9U1eOTt/+N5XidN+5UkGQLsB344NizrJRkDvg+4HaAqvqvqvqXcad6yWbgK5JsBs4G/mGMIarqYeCfV23eCXxo8vaHgHfNdChOPFdVPVhVL04e/iWwpYW5Jn4b+EVglKtO1pjrPcCeqvrPyXOOznywV2Dwp5RkHrgU+OS4kwDwOyx/sP/P2IOscgFwDPj9yXLTB5O8Zuyhqup5lo+2ngVeAJaq6sFxp3qZN1TVC5O3vwC8Ycxh1nAd8LGxhwBIshN4vqo+NfYsq7wJ+N4kn0zyF0m+c+yBVjP4U0jylcAfAT9XVf868ixXAker6rEx51jDZuCtwO9V1aXAvzPO8sTLTNbEd7L8CekbgNck+clxpzqxWr5OuqlrpZP8MsvLm3c2MMvZwC8Bvzr2LCewGXgdy8u/vwDckyTjjvRyBv8VJPlSlmN/Z1XdN/Y8wOXAVUmeAe4GfiDJH4w70ksWgcWqOv5V0L0sfwIY2zuBv6+qY1X138B9wPeMPNNK/5jk6wEmvzezFJDkp4ErgZ+oNr5p55tY/sT9qcm/gS3A40m+btSpli0C99Wyv2L5K/CZn1Bej8Ffx+Sz8+3A4ar6rbHnAaiqm6tqS1XNs3zi8c+qqomj1ar6AvBckosnm94BfHbEkY57FnhbkrMnf6fvoIGTySvsB66dvH0t8CcjzvKSJNtYXjq8qqr+Y+x5AKrqM1X1+qqan/wbWATeOvnYG9tHgLcDJHkTcBbt3D0TMPiv5HLgp1g+iv6bya8fHnuoxt0E3Jnk08B3AL828jxMvuK4F3gc+AzLH/ejfBt8kruAR4CLkywmuR7YA1yR5HMsfzWyp5G5bgW+Cnho8rH/gUbmGt0ac90BvHFyqebdwLWNfFX0Em+tIEmd8Ahfkjph8CWpEwZfkjph8CWpEwZfkjph8CWpEwZfkjrxv4BSyobNGM5CAAAAAElFTkSuQmCC\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light", "tags": [] }, "output_type": "display_data" } ], "source": [ "plt.hist(tfd.Zipf(3.).sample(1_000, seed=(12, 34)).numpy(), bins='auto', density=True, log=True);" ] }, { "cell_type": "markdown", "metadata": { "id": "vjl-c4g78FBl" }, "source": [ "### `NormalInverseGaussian`\n", "\n", "A flexible parametric family that supports heavy-tailed, skewed, and vanilla Normal distributions." ] }, { "cell_type": "markdown", "metadata": { "id": "nXPmCA0k8G00" }, "source": [ "### `MatrixNormalLinearOperator`\n", "\n", "Matrix normal distribution." ] }, { "cell_type": "code", "execution_count": 6, "metadata": { "id": "v8uMP5hcr2kx" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "WARNING:tensorflow:From /usr/local/lib/python3.7/dist-packages/tensorflow/python/ops/linalg/linear_operator_kronecker.py:224: LinearOperator.graph_parents (from tensorflow.python.ops.linalg.linear_operator) is deprecated and will be removed in a future version.\n", "Instructions for updating:\n", "Do not call `graph_parents`.\n" ] }, { "data": { "text/plain": [ "" ] }, "execution_count": 6, "metadata": { "tags": [] }, "output_type": "execute_result" } ], "source": [ "# Initialize a single 2 x 3 Matrix Normal.\n", "mu = [[1., 2, 3], [3., 4, 5]]\n", "col_cov = [[ 0.36, 0.12, 0.06],\n", " [ 0.12, 0.29, -0.13],\n", " [ 0.06, -0.13, 0.26]]\n", "scale_column = tf.linalg.LinearOperatorLowerTriangular(tf.linalg.cholesky(col_cov))\n", "scale_row = tf.linalg.LinearOperatorDiag([0.9, 0.8])\n", "\n", "mvn = tfd.MatrixNormalLinearOperator(loc=mu, scale_row=scale_row, scale_column=scale_column)\n", "mvn.sample()" ] }, { "cell_type": "markdown", "metadata": { "id": "Idxrh0IC8SGs" }, "source": [ "### `MatrixTLinearOperator`\n", "\n", "Matrix Student's t distribution." ] }, { "cell_type": "code", "execution_count": 7, "metadata": { "id": "lQkQ-bw0sLr3" }, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 7, "metadata": { "tags": [] }, "output_type": "execute_result" } ], "source": [ "mu = [[1., 2, 3], [3., 4, 5]]\n", "col_cov = 
[[ 0.36, 0.12, 0.06],\n", " [ 0.12, 0.29, -0.13],\n", " [ 0.06, -0.13, 0.26]]\n", "scale_column = tf.linalg.LinearOperatorLowerTriangular(tf.linalg.cholesky(col_cov))\n", "scale_row = tf.linalg.LinearOperatorDiag([0.9, 0.8])\n", "\n", "mvn = tfd.MatrixTLinearOperator(\n", " df=2.,\n", " loc=mu,\n", " scale_row=scale_row,\n", " scale_column=scale_column)\n", "mvn.sample()" ] }, { "cell_type": "markdown", "metadata": { "id": "d-aAkORa77LE" }, "source": [ "## Distributions [software/wrappers]" ] }, { "cell_type": "markdown", "metadata": { "id": "-IVGCYN6o5SX" }, "source": [ "### `Sharded`\n", "\n", "Shards independent event portions of a distribution across multiple processors. It aggregates `log_prob` across devices and handles gradients in concert with `tfp.experimental.distribute.JointDistribution*`. See the [Distributed Inference](https://www.tensorflow.org/probability/examples/Distributed_Inference_with_JAX) notebook for much more." ] }, { "cell_type": "code", "execution_count": 8, "metadata": { "id": "GpW6oXQjpKJj" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "WARNING:tensorflow:There are non-GPU devices in `tf.distribute.Strategy`, not using nccl allreduce.\n", "WARNING:tensorflow:Collective ops is not configured at program startup. 
Some performance features may not be enabled.\n", "INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:CPU:0', '/job:localhost/replica:0/task:0/device:CPU:1')\n", "INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0', '/job:localhost/replica:0/task:0/device:CPU:1').\n" ] }, { "data": { "text/plain": [ "(PerReplica:{\n", " 0: ,\n", " 1: \n", " }, PerReplica:{\n", " 0: ,\n", " 1: \n", " })" ] }, "execution_count": 8, "metadata": { "tags": [] }, "output_type": "execute_result" } ], "source": [ "strategy = tf.distribute.MirroredStrategy()\n", "\n", "@tf.function\n", "def sample_and_lp(seed):\n", " d = tfp.experimental.distribute.Sharded(tfd.Normal(0, 1))\n", " s = d.sample(seed=seed)\n", " return s, d.log_prob(s)\n", "\n", "strategy.run(sample_and_lp, args=(tf.constant([12,34]),))" ] }, { "cell_type": "markdown", "metadata": { "id": "l3CaO7rPulgw" }, "source": [ "### `BatchBroadcast`\n", "\n", "Implicitly broadcasts the batch dimensions of an underlying distribution *with* or *to* a given batch shape." ] }, { "cell_type": "code", "execution_count": 9, "metadata": { "id": "yKtX4e6xuq63" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "underlying: tfp.distributions.MultivariateNormalDiag(\"MultivariateNormalDiag\", batch_shape=[7, 1], event_shape=[5], dtype=float32)\n", "broadcast [7, 1] *with* [8, 1, 6]: tfp.distributions.BatchBroadcast(\"BatchBroadcastMultivariateNormalDiag\", batch_shape=[8, 7, 6], event_shape=[5], dtype=float32)\n", "broadcast [7, 1] *to* [8, 1, 6] is invalid: Argument `to_shape` ([8 1 6]) is incompatible with underlying distribution batch shape ((7, 1)).\n", "broadcast [7, 1] *to* [8, 7, 6]: tfp.distributions.BatchBroadcast(\"BatchBroadcastMultivariateNormalDiag\", batch_shape=[8, 7, 6], event_shape=[5], dtype=float32)\n" ] } ], "source": [ "underlying = tfd.MultivariateNormalDiag(tf.zeros([7, 1, 5]), tf.ones([5]))\n", "print('underlying:', underlying)\n", "\n", "d = 
tfd.BatchBroadcast(underlying, [8, 1, 6])\n", "print('broadcast [7, 1] *with* [8, 1, 6]:', d)\n", "\n", "try:\n", " tfd.BatchBroadcast(underlying, to_shape=[8, 1, 6])\n", "except ValueError as e:\n", " print('broadcast [7, 1] *to* [8, 1, 6] is invalid:', e)\n", "\n", "d = tfd.BatchBroadcast(underlying, to_shape=[8, 7, 6])\n", "print('broadcast [7, 1] *to* [8, 7, 6]:', d)" ] }, { "cell_type": "markdown", "metadata": { "id": "d8DyDFP5WK6B" }, "source": [ "### `Masked`\n", "\n", "For single-program/multiple-data (SPMD) or sparse-as-masked-dense use cases: a distribution that masks out the `log_prob` of invalid underlying distributions." ] }, { "cell_type": "code", "execution_count": 10, "metadata": { "id": "4DQQ7VTwWbAa" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "tf.Tensor(\n", "[[-2.3054113 -1.8524303 -1.2220721 0. 0. 0.\n", " 0. ]\n", " [-1.118623 -1.1370811 -1.1574132 -5.884986 0. 0.\n", " 0. ]], shape=(2, 7), dtype=float32)\n", "tf.Tensor([ 0. -0.93683904 0. ], shape=(3,), dtype=float32)\n" ] } ], "source": [ "d = tfd.Masked(tfd.Normal(tf.zeros([7]), 1), \n", " validity_mask=tf.sequence_mask([3, 4], 7))\n", "print(d.log_prob(d.sample(seed=(1, 1))))\n", "\n", "d = tfd.Masked(tfd.Normal(0, 1), \n", " validity_mask=[False, True, False],\n", " safe_sample_fn=tfd.Distribution.mode)\n", "print(d.log_prob(d.sample(seed=(2, 2))))" ] }, { "cell_type": "markdown", "metadata": { "id": "Y52jA6ypIQm1" }, "source": [ "## Bijectors\n", "\n", "- Bijectors\n", " - Add bijectors that mimic `tf.nest.flatten` (`tfb.tree_flatten`) and `tf.nest.pack_sequence_as` (`tfb.pack_sequence_as`).\n", " - Add `tfp.experimental.bijectors.Sharded`.\n", " - Remove the deprecated `tfb.ScaleTrilL`; use `tfb.FillScaleTriL` instead.\n", " - Add `cls.parameter_properties()` annotations for bijectors.\n", " - Extend the range of `tfb.Power` to all reals for odd integer powers.\n", " - Infer the log-det-Jacobian of scalar bijectors using autodiff, if not otherwise specified.\n" ] }, { "cell_type": "markdown", "metadata": { "id": "Gtls14gqtIpk" }, "source": [ "### Restructuring bijectors" ] }, { "cell_type": "code", "execution_count": 11, 
"metadata": { "id": "GabDIiMAtPN2" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[, , ]\n", "(, {'b': , 'c': })\n", "(, {'b': , 'c': })\n", "[, , ]\n" ] } ], "source": [ "ex = (tf.constant(1.), dict(b=tf.constant(2.), c=tf.constant(3.)))\n", "b = tfb.tree_flatten(ex)\n", "print(b.forward(ex))\n", "print(b.inverse(list(tf.constant([1., 2, 3]))))\n", "\n", "b = tfb.pack_sequence_as(ex)\n", "print(b.forward(list(tf.constant([1., 2, 3]))))\n", "print(b.inverse(ex))" ] }, { "cell_type": "markdown", "metadata": { "id": "ibpH6g2zsR6i" }, "source": [ "### `Sharded`\n", "\n", "SPMD reduction of the log-determinant. See `Sharded` under Distributions, above." ] }, { "cell_type": "code", "execution_count": 13, "metadata": { "id": "Ja65bfTQsXnD" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "WARNING:tensorflow:There are non-GPU devices in `tf.distribute.Strategy`, not using nccl allreduce.\n", "WARNING:tensorflow:Collective ops is not configured at program startup. Some performance features may not be enabled.\n", "INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:CPU:0', '/job:localhost/replica:0/task:0/device:CPU:1')\n", "WARNING:tensorflow:Using MirroredStrategy eagerly has significant overhead currently. 
We will be working on improving this in the future, but for now please wrap `call_for_each_replica` or `experimental_run` or `run` inside a tf.function to get the best performance.\n", "INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0', '/job:localhost/replica:0/task:0/device:CPU:1').\n", "INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0', '/job:localhost/replica:0/task:0/device:CPU:1').\n" ] }, { "data": { "text/plain": [ "(PerReplica:{\n", " 0: ,\n", " 1: \n", " }, PerReplica:{\n", " 0: ,\n", " 1: \n", " }, PerReplica:{\n", " 0: ,\n", " 1: \n", " })" ] }, "execution_count": 13, "metadata": { "tags": [] }, "output_type": "execute_result" } ], "source": [ "strategy = tf.distribute.MirroredStrategy()\n", "\n", "def sample_lp_logdet(seed):\n", " d = tfd.TransformedDistribution(tfp.experimental.distribute.Sharded(tfd.Normal(0, 1), shard_axis_name='i'),\n", " tfp.experimental.bijectors.Sharded(tfb.Sigmoid(), shard_axis_name='i'))\n", " s = d.sample(seed=seed)\n", " return s, d.log_prob(s), d.bijector.inverse_log_det_jacobian(s)\n", "strategy.run(sample_lp_logdet, (tf.constant([1, 2]),))" ] }, { "cell_type": "markdown", "metadata": { "id": "nXyY5bgLIyrf" }, "source": [ "## VI\n", "\n", "- Add `build_split_flow_surrogate_posterior` to `tfp.experimental.vi` to build structured VI surrogate posteriors from normalizing flows.\n", "- Add `build_affine_surrogate_posterior` to `tfp.experimental.vi` to construct ADVI surrogate posteriors from an event shape.\n", "- Add `build_affine_surrogate_posterior_from_base_distribution` to `tfp.experimental.vi` to enable construction of ADVI surrogate posteriors with correlation structures induced by affine transformations.\n" ] }, { "cell_type": "markdown", "metadata": { "id": "LQRLYKUSyZGf" }, "source": [ "### VI/MAP/MLE\n", "\n", "- Added the convenience method `tfp.experimental.util.make_trainable(cls)` to create trainable instances of distributions and bijectors." ] }, { "cell_type": "code", "execution_count": null, "metadata": 
{ "id": "O-YaQ-SWwGr9" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "(, )\n", "tfp.distributions.Gamma(\"Gamma\", batch_shape=[], event_shape=[], dtype=float32)\n" ] } ], "source": [ "d = tfp.experimental.util.make_trainable(tfd.Gamma)\n", "print(d.trainable_variables)\n", "print(d)" ] }, { "cell_type": "markdown", "metadata": { "id": "FUx2oXURd3nE" }, "source": [ "## MCMC\n", "\n", "- MCMC diagnostics support arbitrary structures of states, not just lists.\n", "- Adds `remc_thermodynamic_integrals` to `tfp.experimental.mcmc`.\n", "- Adds `tfp.experimental.mcmc.windowed_adaptive_hmc`.\n", "- Adds an experimental API for initializing a Markov chain from a near-zero uniform distribution in unconstrained space: `tfp.experimental.mcmc.init_near_unconstrained_zero`.\n", "- Adds an experimental utility for retrying Markov chain initialization until an acceptable point is found: `tfp.experimental.mcmc.retry_init`.\n", "- Shuffles the experimental streaming MCMC API so that it slots into `tfp.mcmc` with a minimum of disruption.\n", "- Adds `ThinningKernel` to `experimental.mcmc`.\n", "- Adds the `experimental.mcmc.run_kernel` driver as a candidate streaming-based replacement for `mcmc.sample_chain`.\n" ] }, { "cell_type": "markdown", "metadata": { "id": "8euU8cFYIWwc" }, "source": [ "### `init_near_unconstrained_zero`, `retry_init`" ] }, { "cell_type": "code", "execution_count": 73, "metadata": { "id": "F7y01nhcIJaB" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "tfp.distributions.TransformedDistribution(\"default_joint_bijectorrestructureJointDistributionSequential\", batch_shape=StructTuple(\n", " c0=[],\n", " c1=[]\n", "), event_shape=StructTuple(\n", " c0=[],\n", " c1=[]\n", "), dtype=StructTuple(\n", " c0=float32,\n", " c1=float32\n", "))\n" ] }, { "data": { "text/plain": [ "StructTuple(\n", " c0=,\n", " c1=\n", ")" ] }, "execution_count": 73, "metadata": { "tags": [] }, "output_type": "execute_result" } ], "source": [ "@tfd.JointDistributionCoroutine\n", "def model():\n", "    Root = tfd.JointDistributionCoroutine.Root\n", "    c0 = yield Root(tfd.Gamma(2, 2, name='c0'))\n", "    c1 = yield Root(tfd.Gamma(2, 2, name='c1'))\n", "    counts = yield tfd.Sample(tfd.BetaBinomial(23, c1, c0), 10, name='counts')\n", "jd = model.experimental_pin(counts=model.sample(seed=[20, 30]).counts)\n", "\n", "init_dist = tfp.experimental.mcmc.init_near_unconstrained_zero(jd)\n", "print(init_dist)\n", "\n", "tfp.experimental.mcmc.retry_init(init_dist.sample, jd.unnormalized_log_prob)" ] }, { "cell_type": "markdown", "metadata": { "id": "zETWMfIZ9Vp9" }, "source": [ "### Windowed adaptive HMC and NUTS samplers" ] }, { "cell_type": "code", "execution_count": 41, "metadata": { "id": "HWaHCIij-RQX" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "StructTuple(\n", " c0=,\n", " c1=,\n", " counts=\n", ")\n", "WARNING:tensorflow:6 out of the last 6 calls to triggered tf.function retracing. Tracing is expensive and the excessive number of tracings could be due to (1) creating @tf.function repeatedly in a loop, (2) passing tensors with different shapes, (3) passing Python objects instead of tensors. For (1), please define your @tf.function outside of the loop. For (2), @tf.function has reduce_retracing=True option that relaxes argument shapes that can avoid unnecessary retracing. For (3), please refer to https://www.tensorflow.org/guide/function#controlling_retracing and https://www.tensorflow.org/api_docs/python/tf/function for more details.\n", "StructTuple(\n", " c0=,\n", " c1=,\n", " counts=\n", ")\n" ] }, { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light", "tags": [] }, "output_type": "display_data" } ], "source": [ "fig, ax = plt.subplots(1, 2, figsize=(10, 4))\n", "for i, n_evidence in enumerate((10, 250)):\n", "    ax[i].set_title(f'n evidence = {n_evidence}')\n", "    ax[i].set_xlim(0, 2.5); ax[i].set_ylim(0, 3.5)\n", "    @tfd.JointDistributionCoroutine\n", "    def model():\n", "        Root = tfd.JointDistributionCoroutine.Root\n", "        c0 = yield Root(tfd.Gamma(2, 2, name='c0'))\n", "        c1 = yield Root(tfd.Gamma(2, 2, name='c1'))\n", "        counts = yield tfd.Sample(tfd.BetaBinomial(23, c1, c0), n_evidence, name='counts')\n", "    s = model.sample(seed=[20, 30])\n", "    print(s)\n", "    jd = model.experimental_pin(counts=s.counts)\n", "    states, trace = tf.function(tfp.experimental.mcmc.windowed_adaptive_hmc)(\n", "        100, jd, num_leapfrog_steps=5, seed=[100, 200])\n", "    ax[i].scatter(states.c0.numpy().reshape(-1), states.c1.numpy().reshape(-1),\n", "                  marker='+', alpha=.1)\n", "    ax[i].scatter(s.c0, s.c1, marker='+', color='r')" ] }, { "cell_type": "markdown", "metadata": { "id": "UwebvDOFIpQ4" }, "source": [ "## Math and stats\n", "\n", "- Math/linear algebra\n", "\n", "    - Adds `tfp.math.trapz` for trapezoidal integration.\n", "    - Adds `tfp.math.log_bessel_kve`.\n", "    - Adds `no_pivot_ldl` to `experimental.linalg`.\n", "    - Adds a `marginal_fn` argument to `GaussianProcess` (see `no_pivot_ldl`).\n", "    - Adds `tfp.math.atan_difference(x, y)`.\n", "    - Adds `tfp.math.erfcx`, `tfp.math.logerfc` and `tfp.math.logerfcx`.\n", "    - Adds `tfp.math.dawsn` for Dawson's function.\n", "    - Adds `tfp.math.igammainv` and `tfp.math.igammacinv`.\n", "    - Adds `tfp.math.sqrt1pm1`.\n", "    - Adds `LogitNormal.stddev_approx` and `LogitNormal.variance_approx`.\n", "    - Adds `tfp.math.owens_t` for Owen's T function.\n", "    - Adds a `bracket_root` method to automatically initialize bounds for a root search.\n", "    - Adds Chandrupatla's method for finding roots of scalar functions.\n", "\n", "- Stats\n", "\n", "    - `tfp.stats.windowed_mean` efficiently computes windowed means.\n", "    - `tfp.stats.windowed_variance` efficiently and accurately computes windowed variances.\n", "    - `tfp.stats.cumulative_variance` efficiently and accurately computes cumulative variances.\n", "    - `RunningCovariance` and friends can now be initialized from an example Tensor, not just from an explicit shape and dtype.\n", "    - Cleaner API for `RunningCentralMoments`, `RunningMean` and `RunningPotentialScaleReduction`.\n" ] }, { "cell_type": "markdown", "metadata": { "id": "f375DSTDA-DV" }, "source": [ "### Owen's T, Erfcx, Logerfc, Logerfcx, Dawson functions" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "RH7qe5lpBI0M" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Calculated values: [0.07896 0.01134]\n", "Expected values: [0.07932763 0.01137507]\n" ] } ], "source": [ "# Owen's T gives the probability that X > h, 0 < Y < a * X. Let's check that\n", "# with random sampling.\n", "h = np.array([1., 2.]).astype(np.float32)\n", "a = np.array([10., 11.5]).astype(np.float32)\n", "probs = tfp.math.owens_t(h, a)\n", "\n", "x = tfd.Normal(0., 1.).sample(int(1e5), seed=(6, 245)).numpy()\n", "y = tfd.Normal(0., 1.).sample(int(1e5), seed=(7, 245)).numpy()\n", "\n", "true_values = (\n", "    (x[..., np.newaxis] > h) &\n", "    (0. < y[..., np.newaxis]) &\n", "    (y[..., np.newaxis] < a * x[..., np.newaxis]))\n", "\n", "print('Calculated values: {}'.format(\n", "    np.count_nonzero(true_values, axis=0) / 1e5))\n", "\n", "print('Expected values: {}'.format(probs))" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "VWZjRfnLG5sc" }, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light", "tags": [] }, "output_type": "display_data" }, { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light", "tags": [] }, "output_type": "display_data" }, { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light", "tags": [] }, "output_type": "display_data" }, { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light", "tags": [] }, "output_type": "display_data" } ], "source": [ "x = np.linspace(-3., 3., 100)\n", "plt.plot(x, tfp.math.erfcx(x))\n", "plt.ylabel('$erfcx(x)$')\n", "plt.show()\n", "\n", "plt.plot(x, tfp.math.logerfcx(x))\n", "plt.ylabel('$logerfcx(x)$')\n", "plt.show()\n", "\n", "plt.plot(x, tfp.math.logerfc(x))\n", "plt.ylabel('$logerfc(x)$')\n", "plt.show()\n", "\n", "plt.plot(x, tfp.math.dawsn(x))\n", "plt.ylabel('$dawsn(x)$')\n", "plt.show()" ] }, { "cell_type": "markdown", "metadata": { "id": "TAo7wY-vIbeR" }, "source": [ "### igammainv / igammacinv" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "2Le6YC8JIkII" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "x: [ 1. 2. 3. 4. 5. 6. 7. 8. 9. 10.]\n", "igammainv(igamma(a, x)):\n", " [1. 1.9999992 3.000003 4.0000024 5.0000257 5.999887 7.0002484\n", " 7.999243 8.99872 9.994673 ]\n", "\n", "\n", "x: [ 1. 2. 3. 4. 5. 6. 7. 8. 9. 10.]\n", "igammacinv(igammac(a, x)):\n", " [1. 2. 3. 4. 5. 6. 7. 8.000001\n", " 9. 9.999999]\n" ] } ], "source": [ "# Igammainv and Igammacinv are inverses to Igamma and Igammac\n", "\n", "x = np.linspace(1., 10., 10)\n", "y = tf.math.igamma(0.3, x)\n", "x_prime = tfp.math.igammainv(0.3, y)\n", "print('x: {}'.format(x))\n", "print('igammainv(igamma(a, x)):\\n {}'.format(x_prime))\n", "\n", "y = tf.math.igammac(0.3, x)\n", "x_prime = tfp.math.igammacinv(0.3, y)\n", "\n", "print('\\n')\n", "print('x: {}'.format(x))\n", "print('igammacinv(igammac(a, x)):\\n {}'.format(x_prime))" ] }, { "cell_type": "markdown", "metadata": { "id": "WPNKYHVP9bs9" }, "source": [ "### log-kve" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "OuQOhwJTHidN" }, "outputs": [ { "data": { "text/plain": [ "Text(0.5, 1.0, 'Log(BesselKve(v, x)')" ] }, "execution_count": 26, "metadata": { "tags": [] }, "output_type": "execute_result" }, { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light", "tags": [] }, "output_type": "display_data" } ], "source": [ "x = np.linspace(0., 5., 100)\n", "for v in [0.5, 2., 3]:\n", " plt.plot(x, tfp.math.log_bessel_kve(v, x).numpy())\n", "\n", "plt.title('Log(BesselKve(v, x)')" ] }, { "cell_type": "markdown", "metadata": { "id": "y6UbD6n0jbgn" }, "source": [ "## その他\n", "\n", "- STS\n", "\n", " - 内部 `tf.function` ラッピングを使用して、STS の予測と分解を高速化します。\n", " - 最終ステップの結果のみが必要な場合に、`LinearGaussianSSM` でフィルタリングを高速化するオプションを追加します。\n", " - 同時分布による変分推論:[ラドンモデルを使用したノートブックの例](https://www.tensorflow.org/probability/examples/Variational_Inference_and_Joint_Distributions)。\n", " - ディストリビューションを前処理バイジェクタに変換するための実験的サポートを追加します。\n", "\n", "- `tfp.random.sanitize_seed` を追加します。\n", "\n", "- `tfp.random.spherical_uniform` を追加します。\n" ] }, { "cell_type": "code", "execution_count": 64, "metadata": { "id": "bqEYFeZhG_yW" }, "outputs": [ { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAARMAAAD4CAYAAADPXQJNAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+WH4yJAAAWdUlEQVR4nO3df4xV9ZnH8fdH2NG0W5Vfq1QFNDWrEoy2Yy2raXf705YE2q0t1DRVViPu1t1EU6NWlzW4VZSk7jbtthArarMrY23apcXGqOg2ZpE6ZKkwGCtqqbJjpWDBLhUKPPvHOWe4XO+duTP33HN/fV7Jzb3n1z3fM8w8POd7zvk+igjMzOp1VLMbYGadwcHEzHLhYGJmuXAwMbNcOJiYWS7GN7sBYzF58uSYMWNGs5th1nU2bNjw24iYUmlZWwaTGTNm0N/f3+xmmHUdSduqLfNpjpnlwsHEzHLhYGJmuXAwMbNcOJiYWS5yCSaS7pH0uqTNVZZL0jckbZX0rKT3liy7VNIL6evSPNpjzTN/+TrmL19XeeHKOcmr1vWtreR1afhe4JvA/VWWfxI4PX2dD3wbOF/SROCfgF4ggA2SVkfEGzm1y3I0f/k6vjJ4Le/oGcfMrz7VuB2lAWf+/psB6Fs0u3H7stzkEkwi4meSZgyzyjzg/kjGO3ha0vGSpgJ/CTwaEbsAJD0KXAQ8kEe7rDhZdrH+5V1HTPctmn04G9mWBqCVcxgY3M2SScsqrz+G/TrgNF9RN62dBLxSMv1qOq/a/LeRdCVwJcC0adMa00qrKMtIrjkUnHfUc7Afnrlldu4ZyuKd18HK44aCzjWHrkn3fxfggNHq2uYO2IhYAawA6O3t9YhOLSb7Q6+YKSxck7xnGcrCNcwE+srXX3lczfsbNhMqW8dBqBhFBZPtwCkl0yen87aTnOqUzn+yoDZZjZI/xnVH9Jmc14g+k7Kgc1fOfSYOLo1VVDBZDVwtaRVJB+zuiBiU9Ahwm6QJ6XofB24sqE3WAMP+oWbBotb1a9jPcBlJedZijZVLMJH0AE
mGMVnSqyRXaP4EICK+AzwMfArYCuwFFqbLdkm6FXgm/aolWWestZ4sQ2m4NOj05fR1Wwb3APDmWwcAZyiNonYcULq3tzf81LCNpDwjyTKV80+dCDiYjIWkDRHRW2lZ23TAmo3VsJ3DlhsHE+tYeQQNB6DaOZhY1xhrQBi6/6VCB7Id5gf9zCrInhla//Iu3nzrAAODu31VaATOTMyqWLzzOt7sOcAH0rt+naEMz5mJWQV9i2Yzc+pxvOuYw//fzpxa+x263cjBxKyahWtYMmkZAz2zYPqFSUbirKQqBxOzYWQZykg8Lov7TMxG5mykJg4mZnWo5enlbuHTHDPLhTMTszr4Vv3DnJk0Q4WBlc3anTMTsxx0c0aScTApUoWBlQFfLbCO4GBSBJ/SWBfIa6S1i4B/BcYBd0fE0rLldwF/lU6+A/iziDg+XXYQ2JQu+3VEzM2jTa1kYHA3wOGR3J2RWAeqO5hIGgd8C/gYSamKZ9JCWluydSLimpL1/x44t+Qr/hAR59TbjpaUBo2Z+zcdMW02pIP+Y8kjM3k/sDUiXgJIB42eB2ypsv4XSMaI7XhDGUn5dCOr4Zk1SR7BpFIhrfMrrShpOnAqsLZk9jGS+oEDwNKI+FGVbduuCNeSScsAuGZ7kpjdlU7nNVCytbGyzviB2y4E2vs/mqI7YBcAD0XEwZJ50yNiu6TTgLWSNkXEi+UbtmMRruxy4cBt44+YNutEeQSTagW2KlkAfLl0RkRsT99fkvQkSX/K24JJO1vijMTKpX0kWUYyZ09SLur8Nr6TNo9g8gxwuqRTSYLIAuCS8pUknQFMoKTwSlp8a29E7JM0GbgAuDOHNhVvmI60dvzFMButuoNJRByQdDXwCMml4XsiYkDSEqA/Ilanqy4AVsWRhXrOBJZLOkRya//S0qtA7WL+8nUs3rnbI3HZqGV9JO2ckWRy6TOJiIdJqvaVzltcNn1Lhe3+G5iVRxua5vZTuGf/Qd4Z/wfb6KhLfWaj4Ttg6zB/+Tru2X+Qg4cClMwbGHSGYqPXzhlJxsFkrFbOYfHO3UlGItgT72DcUWLJpGX0LWz/Xwyz0fIQBHUozUDGHSXe2TOuI/6HMRsLB5OxykYqn34hAz2z+JsTfwA3vjLydmYdyqc5OZg59Tif2ljXczCpVbWrNL5qYwb4NMfMcuLMZCQr5ySXe8uHEXBGYnYEZyZmlgtnJsNIyhfczPo9u1jVcyvvOmY8S/bf7Mu/ZhU4MzGzXOjI5+7aQ29vb/T39xe2PxdYMktI2hARvZWWOTMxs1y4z6QGzkjMRubMxKwTtEDJWQcTM8tFLsFE0kWSnpe0VdINFZZfJmmHpI3p64qSZZdKeiF9XZpHe8y6RpaRbHsqeTUxQymkCFeqLyKuLtt2IkkNnV4ggA3ptm/U2y6zrvTas03bdTOKcJX6BPBoROxKt30UuAh4IId2mXW+7LGOFqgWWWQRrs9K+iDwS+CaiHilyrYnVdpJOxbhMitMlpHs2wMkJTRmTj2u0GfIiuqA/TEwIyLOBh4F7hvtF0TEiojojYjeKVOm5N5As7Z24tnJK7V3/8GhcrRFySOYjFiEKyJ2RsS+dPJu4H21bmtmNVi4hvn7b2agZxZPHzqTz711Mwv2/+PQ3dtFyCOYDBXhktRDUh9ndekKkqaWTM4Fnks/PwJ8XNKEtCDXx9N5ZjYGe/cfrrz75lsH2DK4p7CAUlQRrn+QNJekOPku4LJ0212SbiUJSABLss5YMxudvkWzmb/862wZ3AMcSJ507xnPEpYVsn8/6GfWYeYvX8eWwT2s6rk1907Y4R7087M5Zh2mr+efGehJRwcssMqkb6c360DNqCrpzMSs05TfyFbQvSbOTMwsF85MzDpVwRUUnJmYWS66L5i0wCAyZk3R4N/97jvNaeIj2madrHuCSRaR06cqXZnPukb2u77tqSOnc/7d755gUp6ROEOxLjUwuLsh96F0TzDJHs/OonPJ49pmHa3kvp
OBwd0smbSMvoX5V1zonmCS/UBvP+XIabMuMH/5Ohbv3M2bbx1g/cu7GlJYrnuCScYZiXWpJZOWsf7lxj2U333BxBmJdaEsA2lkqduuuM9k1i2PMOsWj7lkBrB453UNud+kqLo510raIulZSY9Lml6y7GBJPZ3V5duaWX76Fs1u2BPFRdXN+R+gNyL2Svpb4E5gfrrsDxFxTr3tqCTLRt5868AR05tu+UQjdmfW2hp8v0kemclQ3ZyI2A9kdXOGRMQTEbE3nXyaZOBoM+sgRdbNyVwO/LRk+hhJ/STjwy6NiB9V2mgsdXOyDMQZiRkNH+ek0Ks5kr5IUgr0QyWzp0fEdkmnAWslbYqIF8u3jYgVwApIxoAtpMFmVrM8gklNtW8kfRS4CfhQSQ0dImJ7+v6SpCeBc4G3BZN6OCMxK9Gg2yOKqptzLrAcmBsRr5fMnyDp6PTzZOACaqtRbGYtpqi6OcuAPwW+Lwng1xExFzgTWC7pEElgW1p2FcjM2oTr5phZzYarm9MVd8CaWeM5mJhZLhxMzCwXDiZm3SzHQaYdTMwsF903nomZNeShP2cmZpYLZyZm3agBD/05MzGzXDgzMetmOT7017mZiWsKmxWqc4OJmRWq805zCqqramZHcmZiZrnovMykweNcmlllzkzMLBdFFeE6WlJfuny9pBkly25M5z8vKb/BWheucVZiVqC6g0lJEa5PAmcBX5B0VtlqlwNvRMR7gLuAO9JtzyIZM3YmcBHwb+n3mVmbKaQIVzp9X/r5IeAjSgaDnQesioh9EfEysDX9PjMryPzl64YKmtcjj2BSqQjXSdXWiYgDwG5gUo3bAkkRLkn9kvp37NiRQ7PNLE9tczXHRbjM8pVlI+tf3nXEdN+i2WP6vjwyk1qKcA2tI2k8cByws8ZtzawN5JGZDBXhIgkEC4BLytZZDVwKrAMuBtZGREhaDfyHpK8D7wZOB36eQ5vMbARZBlJvRpIpqgjXd4HvSdoK7CIJOKTrPUhSxe8A8OWIOFhvm8yseC7CZWY1cxEuM2s4BxOzbuZSF2bWatrmPhMzy5FLXZhZq3JmYtaNXOrCzFqVMxOzbuZSF2bWahxMzCwXHR9M8hr4xcyG1/HBxMyK0bEdsHkP/GJmw3NmYma56NjMJO+BX8xseM5MzCwXdWUmkiYCfcAM4FfA5yPijbJ1zgG+DRwLHAS+FhF96bJ7gQ+RjFYPcFlEbKynTeWckZgVo97M5Abg8Yg4HXg8nS63F/hSRGSFtv5F0vEly6+LiHPSV66BxMzerlG3S9QbTEqLa90HfLp8hYj4ZUS8kH7+X+B1YEqd+zWzFlNvB+wJETGYfn4NOGG4lSW9H+gBXiyZ/TVJi0kzm4jYV2XbK4ErAaZNm1Zns826T6NvlxgxM5H0mKTNFV5HlACNZGTqqqNTS5oKfA9YGBGH0tk3AmcA5wETgeurbR8RKyKiNyJ6p0xxYmPWakbMTCLio9WWSfqNpKkRMZgGi9errHcssAa4KSKeLvnuLKvZJ2kl8JVRtd7Matbo2yXq7TPJimuRvv9n+QqSeoAfAvdHxENly6am7yLpb9lcZ3vMrEnqqpsjaRLwIDAN2EZyaXiXpF7gqoi4QtIXgZXAQMmml0XERklrSTpjBWxMt/n9SPt13Ryz5hiubo6LcJlZzVyEy8wazsHEzHLhYGJmuXAwMetwRY026GBiZrno2PFMzLpd0aMNOjMxs1w4MzHrUEWPNujMxMxy4czErMMVNdqgMxMzy4WDiVkHakYlSwcTM8uF+0zMOkgzK1k6MzGzXDgzMesgzaxkWVdmImmipEclvZC+T6iy3kFJG9PX6pL5p0paL2mrpL50iEczG4uVc5JXkxRRhAvgDyWFtuaWzL8DuCsi3gO8AVxeZ3vMutPKOfDas0OTfYtmF17NsuFFuKpJB5H+MJANMj2q7c0slQWSfXtg21NNy1DqDSa1FuE6RlK/pKclZQFjEvC7iDiQTr8KnFRtR5
KuTL+jf8eOHXU226xDrJwDv16XBJJMSYZSpBE7YCU9BpxYYdFNpRMREZKqjU49PSK2SzoNWCtpE4eLldckIlYAKyAZUHo0245FMzqwzOp29LFw4tmwcE3huy6kCFdEbE/fX5L0JHAu8APgeEnj0+zkZGD7GI7BrDtlpzJxMHnXuOS9CYEE6r80nBXhWkr1IlwTgL0RsU/SZOAC4M40k3kCuBhYVW37ojXzph+zUWnS6Uw19faZLAU+JukF4KPpNJJ6Jd2drnMm0C/pF8ATwNKI2JIuux64VtJWkj6U79bZHrPuceLZySszbfaR0wVzEa4qnJFY27j9lOT9xlcavqvhinD5DlizdtfEbKSUg0kVFTOSrMOrSR1cZhW1yO+jH/Qzs1w4M6lFlpFseyp5v/2Upl3LN2tVzkzMLBcOJrXIMpCjj03es1uXm/iEplmrcTAxs1w4mNRq4ZrkOv70C5PXwjUMDO5m4LYLm90ys5bgYGJmufAdsGOQZSMz929KpntmJdNffappbTIrwnB3wDozMbNc+D6TMcgykKEMxRmJmTMTs1bQjAp8eXNmUgdnJGaHOZiYNVEnDcbV8Lo5kv6qpGbORklvZYNKS7pX0ssly86ppz1m1jx1XRqWdCewKyKWSroBmBAR1w+z/kRgK3ByROyVdC/wk4h4qNo2lTT70rBZXSoMZdEuGUkjLw2Ptm7OxcBPI2Jvnfs1sxZTb2byu4g4Pv0s4I1susr6a4GvR8RP0ul7gdnAPtKKgBGxr8q2VwJXAkybNu1927ZtG3O7zZqifCiL6emjGG00lEVdmYmkxyRtrvCaV7peJFGpamRKS2HMAh4pmX0jcAZwHjCRZIDpiiJiRUT0RkTvlClTRmq2mRWskLo5qc8DP4yIP5Z8d1YNcJ+klcBXamy3WfvJMpAOHf6z3j6TrG4OjFz35gvAA6Uz0gCUnSJ9GthcZ3vMrEnqvc9kKfCgpMuBbSTZB5J6gasi4op0egZwCvBfZdv/u6QpgICNwFV1tses9XVYRpKpK5hExE7gIxXm9wNXlEz/igpFySPiw/Xs38xah5/NMbNcOJiYWS4cTMzqsXKOBxZPOZiYWS781LDZWJTfzdqh946MhjMTM8uFMxOzsejwu1nHwpmJmeXCmYlZPZyRDHFmYma5cDAxs1w4mJhZLhxMzCwXDiZdrhOKP1lrcDCx9ufnY1qCLw13qU4q/mStod4iXJ+TNCDpUDq6WrX1LpL0vKStaX2dbP6pktan8/sk9dTTHmu+vE6bavqeLCPZ9lTycobSVPVmJpuBvwaWV1tB0jjgW8DHgFeBZyStjogtwB3AXRGxStJ3gMuBb9fZJqtBloE4I7G81Dts43MAyXjQVb0f2BoRL6XrrgLmSXoO+DBwSbrefcAtOJi0pbxOm0b1PX4+pqUU0WdyEvBKyfSrwPnAJOB3EXGgZP7bxonNlBXhakxLu5AzEsvLiMFE0mPAiRUW3RQRw5W2yFVErABWQFJruKj9Wm3yOm0a0/c4I2kJdRXhqtF2kjIXmZPTeTuB4yWNT7OTbL6ZtaEiTnOeAU6XdCpJsFgAXBIRIekJkmLmqxi5iJe1gbxOm3z61X7qvTT8GUmvkhQfXyPpkXT+uyU9DJBmHVeT1Bh+DngwIgbSr7geuFbSVpI+lO/W0x4zax4l9cbbS29vb/T39ze7GWZdR9KGiKh4T5lvpzezXDiYmFkuHEzMLBcOJmaWi7bsgJW0A9hWw6qTgd82uDlF6ZRj6ZTjgO48lukRMaXSgrYMJrWS1F+t57nddMqxdMpxgI+lnE9zzCwXDiZmlotODyYrmt2AHHXKsXTKcYCP5Qgd3WdiZsXp9MzEzAriYGJmueioYFLvANetRNJESY9KeiF9n1BlvYOSNqav1UW3s5qRfsaSjk4HEd+aDio+o/hW1qaGY7lM0o6Sf4crmtHOkUi6R9LrkjZXWS5J30iP81lJ7x3VDiKiY17AmcCfA08CvVXWGQe8CJ
wG9AC/AM5qdtsrtPNO4Ib08w3AHVXW+32z2zqWnzHwd8B30s8LgL5mt7uOY7kM+Gaz21rDsXwQeC+wucryTwE/BQR8AFg/mu/vqMwkIp6LiOdHWG1ogOuI2E8yMNO8xrdu1OaRDLJN+v7pJrZltGr5GZce30PARzTCyORN0i6/LyOKiJ8Bu4ZZZR5wfySeJhkJcWqt399RwaRGlQa4rjqQdROdEBGD6efXgBOqrHeMpH5JT0tqlYBTy894aJ1IBtDaTTJAVqup9ffls+mpwUOSTqmwvB3U9bfRdhX9WmWA6zwMdyylExERkqpdw58eEdslnQaslbQpIl7Mu602rB8DD0TEPkmLSDKuDze5TYVru2ASjRvgunDDHYuk30iaGhGDaar5epXv2J6+vyTpSeBcknP8ZqrlZ5yt86qk8cBxJIOMt5oRjyUiStt9N0l/Vzuq62+jG09zhga4TsuRLgBa5ipIidUkg2xDlcG2JU2QdHT6eTJwAbClsBZWV8vPuPT4LgbWRtoL2GJGPJayfoW5JGMdt6PVwJfSqzofAHaXnGqPrNk9zDn3Vn+G5DxvH/Ab4JF0/ruBh8t6rX9J8j/4Tc1ud5VjmQQ8DrwAPAZMTOf3Anenn/8C2ERyhWETcHmz2z3czxhYAsxNPx8DfB/YCvwcOK3Zba7jWG4HBtJ/hyeAM5rd5irH8QAwCPwx/Tu5HLgKuCpdLpJSvi+mv08Vr4hWe/l2ejPLRTee5phZAziYmFkuHEzMLBcOJmaWCwcTM8uFg4mZ5cLBxMxy8f+RMgOAMaUEDAAAAABJRU5ErkJggg==\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light", "tags": [] }, "output_type": "display_data" } ], "source": [ "plt.figure(figsize=(4, 4))\n", "seed = tfp.random.sanitize_seed(123)\n", "seed1, seed2 = tfp.random.split_seed(seed)\n", "samps = tfp.random.spherical_uniform([30], dimension=2, seed=seed1)\n", "plt.scatter(*samps.numpy().T, marker='+')\n", "samps = tfp.random.spherical_uniform([30], dimension=2, seed=seed2)\n", "plt.scatter(*samps.numpy().T, marker='+');" ] } ], "metadata": { "colab": { "collapsed_sections": [], "name": "TFP_Release_Notebook_0_13_0.ipynb", "toc_visible": true }, "kernelspec": { "display_name": "Python 3", "name": "python3" } }, "nbformat": 4, "nbformat_minor": 0 }