{ "cells": [ { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "5rmpybwysXGV" }, "source": [ "##### Copyright 2019 The TensorFlow Authors." ] }, { "cell_type": "code", "execution_count": 1, "metadata": { "cellView": "form", "colab": {}, "colab_type": "code", "id": "m8y3rGtQsYP2" }, "outputs": [], "source": [ "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n", "# you may not use this file except in compliance with the License.\n", "# You may obtain a copy of the License at\n", "#\n", "# https://www.apache.org/licenses/LICENSE-2.0\n", "#\n", "# Unless required by applicable law or agreed to in writing, software\n", "# distributed under the License is distributed on an \"AS IS\" BASIS,\n", "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", "# See the License for the specific language governing permissions and\n", "# limitations under the License." ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "hrXv0rU9sIma" }, "source": [ "# Custom training: basics" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "7S0BwJ_8sLu7" }, "source": [ "\n", " \n", " \n", " \n", " \n", "
" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "k2o3TTG4TFpt" }, "source": [ "In the previous tutorial, you learned the TensorFlow APIs for automatic differentiation, one of the basic building blocks of machine learning.\n", "In this tutorial, you will use the TensorFlow primitives introduced in the earlier tutorials to do some simple machine learning.\n", "\n", "TensorFlow also includes `tf.keras`, a high-level neural network API that provides useful abstractions to reduce boilerplate and makes TensorFlow easier to use without sacrificing flexibility or performance. We strongly recommend the [tf.keras API](../../guide/keras/overview.ipynb) for development. In this short tutorial, however, you will train a neural network from first principles in order to build a solid foundation." ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "3LXMVuV0VhDr" }, "source": [ "## Setup" ] }, { "cell_type": "code", "execution_count": 2, "metadata": { "colab": {}, "colab_type": "code", "id": "NiolgWMPgpwI" }, "outputs": [], "source": [ "import tensorflow as tf" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "eMAWbDJFVmMk" }, "source": [ "## Variables\n", "\n", "Tensors in TensorFlow are immutable, stateless objects. Machine learning models, however, need state that changes: as the model trains, the same code that computes predictions should behave differently over time (ideally, with a lower loss). One way to represent this state, which must change as the computation proceeds, is to rely on the fact that Python is a stateful programming language:" ] }, { "cell_type": "code", "execution_count": 3, "metadata": { "colab": {}, "colab_type": "code", "id": "VkJwtLS_Jbn8" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "tf.Tensor(\n", "[[2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n", " [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n", " [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n", " [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n", " [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n", " [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n", " [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n", " [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n", " [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n", " [2. 2. 2. 2. 2. 2. 2. 2. 2. 
2.]], shape=(10, 10), dtype=float32)\n" ] } ], "source": [ "# Using Python state\n", "x = tf.zeros([10, 10])\n", "x += 2  # Equivalent to x = x + 2; this creates a new tensor and does not modify the original value of x\n", "\n", "print(x)" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "wfneTXy7JcUz" }, "source": [ "TensorFlow has stateful operations built in, and these are often easier to use than low-level Python representations of your state.\n", "\n", "A `tf.Variable` object stores a value and implicitly reads from this stored value whenever it is used. There are operations (`tf.Variable.assign_sub`, `tf.Variable.scatter_update`, and so on) that manipulate the value stored in a TensorFlow variable." ] }, { "cell_type": "code", "execution_count": 4, "metadata": { "colab": {}, "colab_type": "code", "id": "itxmrMil6DQi" }, "outputs": [], "source": [ "v = tf.Variable(1.0)\n", "# Use Python's `assert` as a debugging statement to test the condition\n", "assert v.numpy() == 1.0\n", "\n", "# Reassign the value of `v`\n", "v.assign(3.0)\n", "assert v.numpy() == 3.0\n", "\n", "# Apply TensorFlow's `tf.square()` operation to `v` and reassign the result\n", "v.assign(tf.square(v))\n", "assert v.numpy() == 9.0\n" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "-paSaeq1JzwC" }, "source": [ "Computations that use a `tf.Variable` are automatically traced when gradients are computed. For variables that represent embeddings, TensorFlow performs sparse updates by default, which is more efficient in both computation and memory.\n", "\n", "A `tf.Variable` is also a way to signal to readers of your code that a piece of state is mutable." ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "BMiFcDzE7Qu3" }, "source": [ "## Fit a linear model\n", "\n", "Let's use the concepts you have learned so far (`Tensor`, `Variable`, and `GradientTape`) to build and train a simple model. This typically involves a few steps:\n", "\n", "1. Define the model.\n", "2. Define a loss function.\n", "3. Obtain training data.\n", "4. Run through the training data and use an \"optimizer\" to adjust the variables to fit the data.\n", "\n", "Here, you'll create a simple linear model, `f(x) = x * W + b`, which has two variables: `W` (weights) and `b` (bias). You'll synthesize data such that a well-trained model ends up with `W = 3.0` and `b = 2.0`." ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "gFzH64Jn9PIm" }, "source": [ "### Define the model\n", "\n", "Let's define a simple class to encapsulate the variables and the computation." ] }, { "cell_type": "code", "execution_count": 5, "metadata": { "colab": {}, "colab_type": "code", "id": "_WRu7Pze7wk8" }, "outputs": [], "source": [ "class Model(object):\n", "  def __init__(self):\n", "    # Initialize the weight to `5.0` and the bias to `0.0`\n", "    # In practice, these should be initialized to random values (for example, with `tf.random.normal`)\n", "    self.W = tf.Variable(5.0)\n", "    self.b = tf.Variable(0.0)\n", "\n", "  def __call__(self, x):\n", "    return self.W * x + self.b\n", "\n", "model = Model()\n", "\n", "assert model(3.0).numpy() == 15.0" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "xa6j_yXa-j79" }, "source": [ "### Define a loss function\n", "\n", "A loss function measures how well the model's output for a given input matches the target output. The goal of training is to minimize this difference. Here we use the standard L2 loss, computed as the mean of the squared errors (mean squared error)." ] }, { "cell_type": "code", "execution_count": 6, "metadata": { "colab": {}, "colab_type": "code", "id": "Y0ysUFGY924U" }, "outputs": [], "source": [ "def loss(predicted_y, target_y):\n", "  return tf.reduce_mean(tf.square(predicted_y - target_y))" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "qutT_fkl_CBc" }, "source": [ "### Obtain training data\n", "\n", "First, synthesize the training data: compute `inputs * TRUE_W + TRUE_b` and add random Gaussian (normal) noise to the result." ] }, { "cell_type": "code", "execution_count": 7, "metadata": { "colab": {}, "colab_type": "code", "id": "gxPTb-kt_N5m" }, "outputs": [], "source": [ "TRUE_W = 3.0\n", "TRUE_b = 2.0\n", "NUM_EXAMPLES = 1000\n", "\n", "inputs = tf.random.normal(shape=[NUM_EXAMPLES])\n", "noise = tf.random.normal(shape=[NUM_EXAMPLES])\n", "outputs = inputs * TRUE_W + TRUE_b + noise" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "-50nq-wPBsAW" }, "source": [ 
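"As a rough sanity check on the loss values that appear later in this notebook, assume (as in the cell above) that `inputs` and `noise` are independent standard normal draws. The untrained model predicts $5x$, so its residual against the target $y = 3x + 2 + \\epsilon$ is $z = 2x - 2 - \\epsilon$. This residual is normally distributed with mean $-2$ and variance $4 \\cdot 1 + 1 = 5$, so the expected squared error is\n", "\n", "$$\n", "\\mathbb{E}[z^2] = \\operatorname{Var}(z) + \\left(\\mathbb{E}[z]\\right)^2 = 5 + (-2)^2 = 9.\n", "$$\n", "\n", "Even the true parameters $W = 3$, $b = 2$ leave a residual of $-\\epsilon$, with $\\mathbb{E}[\\epsilon^2] = 1$. With 1000 samples, you should therefore expect an initial loss somewhere around 9 (up to sampling noise), and a loss that levels off near 1 rather than 0 as training proceeds.\n", "\n", 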
"モデルを訓練する前に、モデルの予測値を赤で、訓練データを青でプロットすることで、損失を可視化します。" ] }, { "cell_type": "code", "execution_count": 8, "metadata": { "colab": {}, "colab_type": "code", "id": "_eb83LtrB4nt" }, "outputs": [ { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXkAAAD4CAYAAAAJmJb0AAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+WH4yJAAAgAElEQVR4nO3dbYxc53Uf8P+Z2V2KS8omOcs4Iq0duq3Sdhm4CkQICNQaDsjIKlFUVQCnJmYFmjKw5oySKi2MNu0CtZtggaZvKSuIpNiaCs2d2giqCjFq1lHIBk39wbVXKU2TVJTIKnctWTWXS8sSX6Rd7p5+eOZyZu7e97l37sv8f8CAu3fuzL2zss8+e57znEdUFUREVEyltG+AiIiSwyBPRFRgDPJERAXGIE9EVGAM8kREBTaU9g10Ghsb0127dqV9G0REufLKK69cU9XtTs9lKsjv2rULc3Nzad8GEVGuiMi823NM1xARFRiDPBFRgTHIExEVWCxBXkROishVEbnYcexLIvKWiJxvPfbHcS0iIgourpH87wF4zOH476rqg63HmZiuRUREAcUS5FX1TwBcj+O9iIgGSrMJ7NoFlErm32Yz1rdPOif/ayJyoZXO2ZrwtYiI8qXZBKamgPl5QNX8OzUVa6BPMsgfA/CXATwI4G0A/9bpJBGZEpE5EZlbXFxM8HaIiDJmehq4dav72K1b5nhMEgvyqvpjVV1V1TUA/xHAwy7nnVDVPaq6Z/t2xwVbRETF0GwCY2OAiHnMu6xhWliI7ZKJBXkRua/j2ycAXHQ7l4io8JpN4KmngKUl/3PHx2O7bCxtDUTkqwA+CWBMRN4E8EUAnxSRBwEogCsAPh/HtYiIcml6Glhe9j9vdBSYmYntsrEEeVU94HD4y3G8NxFRIfilYETMCH5mBqjVYrtsphqUEREV1vi4ew6+WgWuXEnksmxrQETUDzMzwMjI+uPDw7GmZ+wY5ImI+qFWA06eBCqV9rFKBXjhhVjTM3ZM1xAR9UutlmhAd8KRPBFRgTHIExEVGIM8EVGBMcgTERUYgzwRUYExyBMRubE3FBsbi73fe9JYQklEZNdoAM8/D6ytdR9fWgIOHTJf97kUMiqO5ImIOu3eDRw7tj7AW1ZWeur3nvBGUOtwJE9EZGk0gMuX/c+L2O/d2gjK2ifE2ggKSO4PA47kiYgsJ04EOy9iv/c+bAS1DoM8EQ22RgMol83E6uqq//k9NBRz+wMgxo2g1mGQJ6LBtW+fd/7drseGYm5/AMS4EdQ6DPJENHis0shz54KdPzEBqALXrvWUPJ+ZMRs/dYp5I6h1GOSJaHBYwX1yMtheq+UyUK8Dly7FcvlazaT9q1WTHapWzfdJVmOyuoaIBkOjARw/bkbkQSS0W1O/uw1zJE9ExdZsAps3m9x70ABfKiWSQ3GqkU+6bp4jeSIqrkbDBPcwRkbMDk4xD7edauSfesr83llZaR+Lu26eI3kiKh4r9x4mwFcqwOws8MEHieRTnGrkl5fbAd4Sd908R/JEVCz2IbOfSgU4ciTxRHmYWvg46+Y5kieiYrAWNU1OBg/w9TqaR65h13Qt8V4yYWrh46ybjyXIi8hJEbkqIhc7jm0TkT8Skb9o/bs1jmsREXXpnFgNuqgJMAH+kaOYmjK5cNV2TjyJQO9UIz8yYhbQdoq7bj6ukfzvAXjMduw3AZxT1QcAnGt9T0R0V8+VJY0G8OSTwM2bwV+zaZPJvR896tpL5pln4q94caqRP3nSLKBNtG5eV
WN5ANgF4GLH968BuK/19X0AXvN7j4ceekiJaDDMzqqOjqqaMbR5jI6a44FevGFD94v9HpXKujcX6enlmQFgTl3iapI5+Y+o6tutr/8fgI84nSQiUyIyJyJzi4uLCd4OEWVJ5I6MzSZw8KCpggmiXDYjd4eWBGFy30tLyaVyktSXidfWbxrHVQiqekJV96jqnu3bt/fjdogoA0J3ZGw0gKEhM7EapFskYPItp0655j+c8uRerFROniQZ5H8sIvcBQOvfqwlei4hyxq8jY2e+/tWh3dBjx4IHdwDYtAnf+vxXPCtnnPLkfpaW8jWaTzLIfx3AwdbXBwH8QYLXIqKc8erIaM2nfmG+gTsq+GurlyFB37heB1TRfP4GPnWq5ls5U6uZa46Pm3OCSHKTj9i5JevDPAB8FcDbAFYAvAngcwAqMFU1fwHgLIBtfu/DiVei/JudVa1WzaRmteo9Wel07uys+f4KduhamIlVQHXv3rvvWy67n9Z5X04TwH4PkWR/hmHBY+I1tuqaOB4M8kT5FrRixusXweXyhK4B4QJ8qaRar7veg9PDuq9qNVyAt35JZAmDPBH1hVvA7AyKnr8IJiYCB/c14G5g73xvrxF8HI/AZZ595BXk2daAiGITpGLGqXTyd2418Pcnh4DLwXLvCuAiJtB85OjdY1bLmjBzs0GMjJj2Nv3a5CNuDPJEFInTatUge5h2Tm4+iwZWIXgaxzAE/+isANYgeA51fByXulamHjwYvGVNUOUy8LnPmRL7tTWzh0ieAjzALpREFIFTb/SpKRNoT53qDradvViaTTMiVgXOYzc+juBVMwpgATuwC2/dPba01N7FL+4RvPWep04BjzySv+Bu4UieiEJzW6165oz3HqbT08B/0AbWIKED/AVMdAX4fom7v3u/icnZZ8OePXt0bm4u7dsgIh+lkhmN24l4N4Kcl50Yx48C590BYBllHMIpfBXhh9LDw+s35YjC73OlTUReUdU9Ts9xJE9EoQXJvXdpNACRUAH+OdRRguIe3FkX4EVMM0kn5XL7r4gPfSjAxQKIs797vzHIE1FoXqtVgfakbE2aeF/uMS0JgFAB/tdx1P0cBe65x/keTp1qT5Jev+59rVKACBh3f/d+Y5AnotCcer5YuXdrUvYL8w3MYhL34IPAufc1CGqYxW+U3QO85fr17nuoVICNG007BL9qH8CcP2QrPREB9u5NuL97nzHIExGAaBt43LjR7gszOWmC5k8mG3jvVglP41ioidVlDGESp/FypRYo/71hg5kQXVgAtm0D3nvPVNpY9zM1Bezf7zzan501m0ktL9vuQ4HXXzd/BeS1ZHIdt1VSaTy44pUoHWE38JidVR0eXr8aNEq/mTVAz2Pi7qFKJVqrAbeVtm4tFNw2DMlaX5og4LHildU1RAPO2oPDqc68WjWjWbuxsXZ9OgAcQBNfwSTKCJZ3B9rVM07590ql+/2j8qqK2bXLueuk22fOMlbXEA2YoKkXv1YATkGw2ewOwFexFU1MYgjhArxVPWMP8CLxBHjAOyfvN3lcGG5D/DQeTNcQ9c4t9VKvr09bVCr+KQ+r4Ve53H4PQPU8wneLXAP0Jkb0AGYTbSIWtJFYmLbIWQama4gGh1sawmonYOllodAdCEoIN3JXCI7isGdpZFzKZc9d/wrHK13D3jVEBePWCdI+nutngPere4/b2trgBHg/zMkTFUxSqzPPYzfWUg7wo6NmUtZPnleoxo1BnqhgZmZMKiYuV7Czq6GYX4DX1uM27kENs7EFeGthkt8q1kJOnvaAQZ4oB8IuVJLAu157u4N2v5mgLQnew0aUoBjF7UhNxZyItBcmeY3Si7BCNW4M8kQZZ5U5zs93r+a0B3rrF8Hk5PqVnGE9i0ak1MwFTODDiHnnDnQHdrfSx3rdfN3Z1oDAEkqiLPPas9Rv39Soj7CrVq0yymdRT6wc0raV67rSx3o93IrdogFLKInyx777kpty2fSM+eCD3q73T
ezDozgHINzo3b5bUxL8SiKLtHo1Cq54Jcohp92XnKyu9hbgD6CJVQgexblQuXcFUMNsX3ZrWl11TlFZgmwgPqgY5IlS4jeZ2o8AtQxBE5OBc+9WcL+GLShBY5tYBfwni7224Qu9ickASTzIi8gVEfm+iJwXEeZiiOA8mXrokGn8ZQX9bduSvYdVSOB+M1ZwtypnfgY/iXxdq/c7YNIwgEmrnD5tWgDbJ1U7uf3iG5g+NBH0a8XrL6nqtT5diyjT3Lo+rqy0G3M55ZfjEjb3Hnfe/fRp/xJHt66YbiPzzo3CFxbMeTMzLKUE2NaAqK/8uj4m6QCamMVk4Lw70C6LfBCXYrmHSsU/8FrP2yed/UbmtRqDupN+5OQVwMsi8oqITNmfFJEpEZkTkbnFxcU+3A5RMEEWINnPaTRMykXEPMbGul8XdDI1TgfQxFqE3PsazMRqXAF+dBQ4ciTYuV7bC1JIbrWVcT0A7Gz9+zMAvgfgE27nsk6esiLITklBa9OHh9uvS7q9rv1htQMO86I1QJeB2O4h72188wBZqZMXkS8BuKGq/8bpedbJU1YEqbt2O8eJNdEY12YYfg6gidN4EiVo6J2a4kzPDEqdetpSazUsIpsAlFT1vdbXjwL4rSSvSRSHIHXXYUoc+xXcAdNQzOo3E5RVORNnSwJWt2RD0jn5jwD4loh8D8B3AHxDVb+Z8DWJehak7jprNdhWv5kwAb4z9x5HgLfmIphDz45ER/Kq+gaAv5HkNYiSMDPjXt3RbJoJ1CTLHMP6KUZxL26HHr1fwxbHmvfNm4EbN8Lfh5WJp+zgilciFxs3tr+uVMzIFGgvYsoCa/QeJsB3jt7dFjVt2NBeqET5xjp5Iptm06w+7dwe7913zb9plEC6iZp7fxl78RjOep53/Tpw+DBw7Fi4ewqyaxP1F0fyRDbPPLN+/9OVFXM8Cw2vnkUDqz3k3v0CPGDmG44eNT3aSwGjxNBQ8Dp46h8GeRpIXgud3CphlpbSn2y9iRE8jWOhN/N4DxtRDthQTATYv998ffSoWZ07O9u9MKleXz9q//CHw3wS6hu3Avo0HlwMRVHYN5DwW3Tjt9Cp3wuWgjysjTzCbuaxCugBzHY95bYJidvPI+rPkfoHWVkM5YeLoSgsp401Rke9y/fcFjFVKqaqJCuTqkC0fjMAYuk547eQadA36sgSr8VQDPKUa14B+5pL39NSKR9lflexFWN4J3RwB4DnUMev42hP1xcB1tbcn3f7Ofq9juLHnaGosNwmQpeW3HcRSjuv7scqiwwT4LX1WEYJJWjgAO9VJun3c+JGHfnAIE+55hVQ7LsIWZOt8/P+uxCl5Sq24mkcC90OeBVACYp70N3DuFJxL2usVoE7d5w36gjSkoAbdeQDgzxlUpA2v4B3QOkc5XfuxARkL13Ty+j9AiYwjO4PtHmz+YzXrpmyRq9gHLWtL9sB54TbjGwaD1bXkGr4qo1KxblCxKq0qVbTr47xeiyHrJqxKmd+io2up4is/5mGqUCifAGrayhPglRtWP1jFhbMXqjvvrt+AZNI9kbsnayJVSBcakYBTGLWs+adFS6DJbVWw0RRuE2mzs+bXwD79wOnTrXLJpeWgJERYNMm4ObN9vlZDvCrkEhlkXcAjMD/gzEvThbm5ClTmk3vZfTz86afir1/zPIycPt2svcWB2srvrATq1buPUiAHxlhXpzaOJKnzOh1k+us12ZHbQesAMq24F4uu/+c7GkrGmwcyVPf+FXMZKnDY5yitgNWtHvOdCqXgS1b3F/LOnXqxJE89YW9/cD8vPkeaKcWstDhMW4rEJQRPvfuthXfpk0mNeXWRK2zuRgRwJE89YnTKP3WrfaCJa9cfFYXLnmxRu9hArw1en8O9XUBfnTULFoaG/NOx6iaSWm3dQU0eBjkKZKgi5Usbk2/Fha8c/HDw2b0mierkEirVt/DRteWBNYvxCB/7XT+8iRikKfQOlePqrZTL26Bvtl0H42rA
gcPuufiV1ai7TWahqvYGqlyZg3Oo3e7hYXg+fYipr4oGgZ5Cs0v9eJ0vlfNetRqmqw4j91dLQnCBPgLmEA5YEOx8XHnfjFu5xIBDPIUgdsoMezxIngfZXwclyON3muYDdzv3eo1Y+8XU6mYuninc4kABnmKwG2UuG1bd56+0TD/ZnnlaVTfxD6sQTCCtUgTq0G24iuXnRt/1WqmZcHammlAdvIkm4SRu8R714jIYwCOACgD+E+q+i/dzmXvmnxw2o1pZMQE80FYiLMMwRCityTYtAl4/33/NNXsLIM1BZPapiEiUgbwHIC/DWACwAERmUjympQ8pxaz995b/ABv5d7DBHhr9L6AHXdbEty4Ycocq1X311UqDPAUj6TTNQ8DeF1V31DVZQBfA/B4wtekPuhMGVy5Aly/nvYdJWsZEin3rjCbeezCWwDagd36+blt2HHkSDz3TZR0kN8J4Icd37/ZOnaXiEyJyJyIzC0uLiZ8O5SUolZzWLn3KKN3K/ducZoQ5cYblLTU2xqo6gkAJwCTk0/5diii/fuB48eLNckataHYMkrrtuErl92Dd63GoE7JSXok/xaA+zu+/2jrGBVIo1GsAG+N3qM0FHsZe3HvcHeAHx01OXgGckpD0kH+uwAeEJGPicgIgM8A+HrC16SIwrYqsF5TlAB/AE2sQvAozgXOvVvB/Rq2oATFYziLF15g+oUyxG1fwLgeAPYD+HMAPwAw7XUu93g10tiP021f1Xrd+16yvn9q0Md5TETaZ3UF6DpcrSb/34rIDtzjNT+catBHR5MfDbrtq2rfJ9V+L3nsEGn3PsqhFjUBZvR+Dnvxyzh791g//jsROUmtTp7CC9sXJi5urQfsYwB7e+A8B/kr2Bl61SoAoFSCqOLHs2eZlqHMY5DPmLT6v4QpgbTuxa/xWFZZufdx/Cj0Ztqo1+8uVbWvFWCApyxikM8Yt2AbVx16s2k2nhAxj7Exc8ypu6HbKH3bNvNvHhuP/RSjaGISJYQM7oD5jXbUv1skUZYwyGeMU7CNq6tgswk89VT31nFLS8ChQ+Zre3dDt5a2S0vml4MV7PMgSlnkXVu25PNPFiIwyGdOrysgvcogp6fN/qB2KyvmOSv9cPo0cPs2cPOm+3WWltz3Gc2a89gdqiyyiyrwk58AiFZiSpQ2VtcUiF9lTqnkPiAVMbllwIzS8xLA/VzF1rubeYQyMQFcavd6T6vqiSgIVtcMCL/KHK/0ipXzbzbTDfBxVev8FKNduzWFotoV4IH0qp6IesUgnzFhUwKd57ttlj0/7z06Hx5u5/zTDlqHD3u34PVj7bNq5d5DBfgdO1z/1BnEXa+oGFJvUEZt9pSAtUE24JwScEohuHEL8CLACy+03z/NoCXSLl5xW5xlnecUi1dDbqLdxSdtOT7ufD9F7b5JxcGRfIZE2SA7SID3YsU266+BNBc3HT7c/trrl41qdxXQu63UTKQAv3dvoMqZJKueiJLEkXyGpLVBdudfA0nPw5fLZi3R5s2mekfVHJua6i5Bdxs5AybAX7nS+ibqb6UdO4C3gjdEtf7SmZ42P/fx8fbG2kRZxpF8RjSbZiTtJOwCqbA57V7/GghjddWMgI8fN9U8qsCdO+vXGM3MmLkCu5GR1uh59+7oAX52NlSAt3CFK+URg3wGWLl1t42d9+93Pu6VQiiX473HOAWpSqnVzFxBpdI+VqkAb/yVfahNCnD5cvgLDw+b3yqMzjRAWCefAV6TjICJTZ2To81mO21glUVev25G9vv3A2fOeL9fFnTW5Qe2dSvwzjvhLzY87LwKjKggWCefcX65dWtFKmB2YXrySRPEVU3VzO3bZtLyxg3g2LHsB3ggZFVKo2F+K0QJ8Dt2MMDTQOPEawZ4TTJaFhbcd2G6dSubuzOJmPSRvT1CqKqUqKN3IHs/EKIUcCSfAW45907j496tfbMWzyoV0wPHfl8iwMGDAdLiVrP6KAE+YFkk0SDgSD4DzpzxP2dhI
T9xa2QEOHLEuY5fNcDn3b072sSqdQEiuosj+YS49W13EqTePQ+xy+qaefKkGamHru+3Ru9RAvzERD5+SER9xpF8Aqy+7Z3zfZ192+2piiA5+azrWqDUEqoVQNTR+5Ytd1sBE9F6HMlH4NdEzKtv++Tk+tc41bvnidtEauBWACMj0QJ8vc4AT+RHVTPzeOihhzTrZmdVR0dVTW7APEZHzXGLSPfzTg/7a2ZnVatV89pqVbVS8X+PLDzK5e7P4fTz6vxcXefu2BHtojt2xPrflCjvAMypS1zlYqiQ3BYudaYr/BY3Ob3GLmiHyUoF+NVfTWcBVE+bZvTSkoArVom6cDFUjIJMJs7MmAyEH6+g3LkNINBuU2CPjbdvA488kkzKx62XjiVSgN+5M1qAt3q9M8AThZJYkBeRL4nIWyJyvvUIUA2efUGahdVqpsKks++KE7/+MlZDLKuJV7XqvBDqmWfiaTsMtHu1q5peOm7NzqrVkPHWWrX6ox+Fv6l6PVJDMSJKfiT/u6r6YOsRoBo8+4JOJtZqwLVr3lV9q6vhdoFy+ytiacn7r4IwA2f7L7FY+qhv3Wr6LYRlLWqyt6gkosCYrgmpM41i1YX7pS3cRsMi7R401i5QXoHeq9+L118F27YBmza5P29x+2UV9vPetXVr9FWrs7PA2bPhX0dE3dxmZHt9APgSgCsALgA4CWCry3lTAOYAzI2Pjyc5AZ0ap4octwocqwLFqSJldta/YsftOb+Kn3WVL71+4KjlOqycIQoNSVXXiMhZAD/r8NQ0gG8DuAZAAfw2gPtU9Smv98tDdU1Une2B/RY/jY5259c7q1jcNuSuVs0ofHo6XJVNpJa/XvbtA86di/baDFV6EeWJV3VNX0ooRWQXgP+mqj/vdV6Rg7ydW5mltT2eXaVitsybn1+/kbW9lLFUCh4vvco4Q+kluLMskqgnqZRQish9Hd8+AeBiUtfKI6cJzeFh992hOidXVduTqU45crfcvX0CNraNqEdGogX4jRtZFkmUsCQnXv+ViHxfRC4A+CUA/zDBa+WOfUKzUglXBaPaHoXbY6RbRczhwxEnUN1YG9OurIR/bb3e381liQYUV7xmRNBVsp288un2OYCZmZgHzI1GtLLIiQng0qUYb4SIvNI17EKZEV7thisV58lWr5LKWi3BLEiUjpHsFkmUCtbJO/DrMpkEt4BdrZoNOHpekBSH3buj9Xvfu5cBniglDPI2VmOwMIuU4uC2BeD+/T0uSIrDvn3RgrvVb4aLmohSwyBv49QD5tYtcxxIbpTvtiWeddzqY7O25jzZmpjdu6NVzuzYwX4zRBnAIG/j1WUyyVF+6K3ykmY1FIuymcfevQzwRBnBIG/j1WXSb5Sf1HX7yiqL7KWhGNMzRJnBIG/j1XUxydF2LN0ee7Vvn9mfMGxZ7dAQG4oRZRSDvI3XJGeSo+3UJ1cbjfC5dxET3FdWuGqVKKO4GCoEpy35etoCLwuaTeDznwdu3gz3Oi5qIsoMbv8Xk9RH23FrNEx6JmyA37uXAZ4oJ7jiNaREV5L2S6MBHD8ePvfO0TtR7nAkP0issshjx8IH+HqdAZ4ohziSHxRRG4rt3cuqGaIc40i+6JpNYMOG8AG+XmfNO1EBFCLIp9FQLBesuvfl5eCvETEB/ujR5O6LiPom9+kae1mj1WoAKMAEaS/CtgMWAU6fHvAfGlHx5H4kn2SrgVyK2nOGAZ6okHI/ks9cY6+0NJvAZz8L3LkT/rXcSJuosHI/ks9MY680WYuawgZ4q6EYAzxRYeU+yGeisVeaopRGcjMPooGR+yBfuFYDQVjlRNbCpjDqdfZ6Jxoguc/JAwVpNRDUvn3RukUePsyySKIBVIggPzCitANmvxmigdZTukZEPi0il0RkTUT22J77pyLyuoi8JiKf6u02B1yzCYyNhUvNWIuaGOCJBlqvI/mLAH4FwPOdB0VkAsBnAOwGsAPAWRH5OVVd7fF6g6fZBA4dMhtzB
MHUDBF16Gkkr6qvquprDk89DuBrqvqBqv5fAK8DeLiXaw2kZhM4eDB4gK9WzaImBngiakmqumYngB92fP9m69g6IjIlInMiMre4uJjQ7eSI1VBMxNS+rwb846deB65cGaAZaCIKwjddIyJnAfysw1PTqvoHvd6Aqp4AcAIw2//1+n65xsoZIoqZb5BX1X0R3vctAPd3fP/R1jFyE7ah2MgIcPIkR+5E5CmpdM3XAXxGRDaIyMcAPADgOwldK/8ajXABvlJhgCeiQHotoXxCRN4E8IsAviEifwgAqnoJwO8DuAzgmwCeZmWNA2vlatDSyNFR00zs2jUGeCIKpKcSSlV9CcBLLs/NABiUDjLh2Rvh+xkZGYB+DUQUN654TYtTI3w3XLVKRBHlvkFZbjQawNCQqYYZGjJbWPkZGjLpGQZ4IoqIQb4frHbAVs27X+17uWzq3ldWmJ4hop4wyCcp6sTqnTuseyeiWDDIJ6XRAJ580j8tM1CN8Imo3zjxmoRmEzh+3Oy+5KVcNq0IiIgSwpF8Eqan/QM8YEooiYgSxJF8EhYWvJ8vl02AZ96diBLGkXwSxsedj4twYpWI+opBPgqraqZUMv82m93Pz8yYSplOVrdITqwSUR8xyIfVWTWjav6dmuoO9LWaqZTprJzhZh5ElALRIBOEfbJnzx6dm5tL+zbcNZsmwDv9zKpVVsoQUSpE5BVV3eP0HEfyYXhVzfhNthIRpYBBPgyvQO422UpElCIGeSduE6teVTMz7KpMRNnDOnk7e593a2IVMIHc3gOeVTNElGEM8nZOfd5v3TLHrYnV6WmTuhkfN4GfAZ6IMorVNXalkvPkqgiwttb/+yEi8sHqmjDc8u6cWCWiHGKQt3NarTo6yolVIsolBnk7p9Wq7PNORDnFiVcntRqDOhEVAkfyREQFxiBPRFRgPQV5Efm0iFwSkTUR2dNxfJeI3BaR863H8d5vlYiIwuo1J38RwK8AeN7huR+o6oM9vj8REfWgpyCvqq8CgIjEczdERBSrJHPyHxOR/yMi/1NE/pbbSSIyJSJzIjK3uLiY4O0QEQ0e3yAvImdF5KLD43GPl70NYFxVfwHAPwLwn0XkQ04nquoJVd2jqnu2b98e7VP4bcdHRDSgfNM1qrov7Juq6gcAPmh9/YqI/ADAzwGIvzGNV9dI1roT0YBLJF0jIttFpNz6+i8BeADAG0lcy7NrJBHRgOu1hPIJEXkTwC8C+IaI/GHrqU8AuCAi5wH8FwCHVfV6b7fqwm23Jm7HR0TUc3XNSwBecjj+IoAXe3nvwMbHTYrG6TgR0YDL/4pXdo0kInKV/yDPrmdMnoIAAAO4SURBVJFERK6K0YWSXSOJiBzlfyRPRESuGOSJiAqMQZ6IqMAY5ImICoxBnoiowERV076Hu0RkEYDDyqaejQG4lsD79hs/R7bwc2TLIH+Oqqo6dnjMVJBPiojMqeoe/zOzjZ8jW/g5soWfwxnTNUREBcYgT0RUYIMS5E+kfQMx4efIFn6ObOHncDAQOXkiokE1KCN5IqKBxCBPRFRgAxPkReS3ReSCiJwXkZdFZEfa9xSFiPxrEfmz1md5SUS2pH1PUYjIp0XkkoisiUjuyt5E5DEReU1EXheR30z7fqIQkZMiclVELqZ9L70QkftF5I9F5HLrf1PPpH1PUYjIPSLyHRH5Xutz/ItY3ndQcvIi8iFVfbf19T8AMKGqh1O+rdBE5FEA/0NV74jI7wCAqv6TlG8rNBH56wDWADwP4AuqGv8m7wlp7V/85wB+GcCbAL4L4ICqXk71xkISkU8AuAHgK6r682nfT1Qich+A+1T1T0XkXgCvAPh7OfzvIQA2qeoNERkG8C0Az6jqt3t534EZyVsBvmUTgFz+dlPVl1X1TuvbbwP4aJr3E5Wqvqqqr6V9HxE9DOB1VX1DVZcBfA3A4ynfU2iq+icAktl7uY9U9W1V/dPW1+8BeBXAznTvKjw1brS+HW49eo5TAxPkA
UBEZkTkhwBqAP552vcTg6cA/Pe0b2IA7QTww47v30QOg0oRicguAL8A4H+neyfRiEhZRM4DuArgj1S1589RqCAvImdF5KLD43EAUNVpVb0fQBPAr6V7t+78PkfrnGkAd2A+SyYF+RxEcRGRzQBeBPAbtr/cc0NVV1X1QZi/0B8WkZ7TaMXY/q9FVfcFPLUJ4AyALyZ4O5H5fQ4R+SyAvwNgr2Z4UiXEf4+8eQvA/R3ff7R1jFLSymG/CKCpqv817fvplaq+IyJ/DOAxAD1NjBdqJO9FRB7o+PZxAH+W1r30QkQeA/CPAfxdVb2V9v0MqO8CeEBEPiYiIwA+A+DrKd/TwGpNWH4ZwKuq+u/Svp+oRGS7VS0nIhthJvZ7jlODVF3zIoC/ClPRMQ/gsKrmbvQlIq8D2ABgqXXo2zmtEnoCwLMAtgN4B8B5Vf1UuncVnIjsB/DvAZQBnFTVmZRvKTQR+SqAT8K0tv0xgC+q6pdTvakIRORvAvhfAL4P8/9vAPhnqnomvbsKT0Q+DuAUzP+mSgB+X1V/q+f3HZQgT0Q0iAYmXUNENIgY5ImICoxBnoiowBjkiYgKjEGeiKjAGOSJiAqMQZ6IqMD+P30scbww7cBMAAAAAElFTkSuQmCC\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "Current loss: 8.123021\n" ] } ], "source": [ "import matplotlib.pyplot as plt\n", "\n", "plt.scatter(inputs, outputs, c='b')\n", "plt.scatter(inputs, model(inputs), c='r')\n", "plt.show()\n", "\n", "print('Current loss: %1.6f' % loss(model(inputs), outputs).numpy())" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "sSDP-yeq_4jE" }, "source": [ "### Define a training loop\n", "\n", "With the model and training data in place, train the model using [gradient descent](https://en.wikipedia.org/wiki/Gradient_descent), updating the weight variable (`W`) and the bias variable (`b`) to reduce the loss. There are many variants of gradient descent, and they are captured in `tf.keras.optimizers.Optimizer`, our recommended implementation. In the spirit of building from first principles, however, here you will implement the basic math yourself. You will use `tf.GradientTape` for automatic differentiation and the `assign_sub` method of `tf.Variable` to decrement a variable's value in place." ] }, { "cell_type": "code", "execution_count": 9, "metadata": { "colab": {}, "colab_type": "code", "id": "MBIACgdnA55X" }, "outputs": [], "source": [ "def train(model, inputs, outputs, learning_rate):\n", "  with tf.GradientTape() as t:\n", "    current_loss = loss(model(inputs), outputs)\n", "  dW, db = t.gradient(current_loss, [model.W, model.b])\n", "  model.W.assign_sub(learning_rate * dW)\n", "  model.b.assign_sub(learning_rate * db)" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "RwWPaJryD2aN" }, "source": [ "Finally, repeatedly run through the training data and see how `W` and `b` evolve. Each call to `train` performs one gradient-descent step on the full training set." ] }, { "cell_type": "code", "execution_count": 10, "metadata": { "colab": {}, "colab_type": "code", "id": "XdfkR223D9dW" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Epoch 0: W=5.00 b=0.00, loss=8.12302\n", "Epoch 1: W=4.65 b=0.38, loss=5.69242\n", "Epoch 2: W=4.36 b=0.68, loss=4.09065\n", "Epoch 3: W=4.12 b=0.93, loss=3.03475\n", "Epoch 4: W=3.92 b=1.12, loss=2.33847\n", "Epoch 5: W=3.76 b=1.28, loss=1.87918\n", "Epoch 6: W=3.63 b=1.41, loss=1.57614\n", "Epoch 7: W=3.52 
b=1.51, loss=1.37612\n", "Epoch 8: W=3.44 b=1.60, loss=1.24407\n", "Epoch 9: W=3.36 b=1.66, loss=1.15686\n" ] }, { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAWoAAAD4CAYAAADFAawfAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+WH4yJAAAgAElEQVR4nO3deXiU5b3/8fdNCAaBIIQguwGByB4wCCQeQaFFObicA8UqLmgtqLUKYj3VbvyqrdWjFltFpbXHKmhRtKcVoXqkooVgbUBckEVUxIA0rEolke3+/fHNMDNMgAnMZJ5kPq/req5ZnifDNyN+vL2fe3Hee0REJLgapLoAERE5MgW1iEjAKahFRAJOQS0iEnAKahGRgGuYjA9t1aqVz8vLS8ZHi4jUS8uWLdvqvc+t7lxSgjovL4/S0tJkfLSISL3knPvkcOfU9SEiEnAKahGRgFNQi4gEXFL6qEVEAPbu3UtZWRmVlZWpLiUwsrKy6NChA5mZmXH/jIJaRJKmrKyMZs2akZeXh3Mu1eWknPeebdu2UVZWRufOneP+ubiC2jm3HtgF7Af2ee8Lj6lKEUkrlZWVCukIzjlycnLYsmVLjX6uJi3qs733W2tWloikO4V0tGP5PoJ1M/GOO+Dll+HAgVRXIiISGPEGtQdeds4tc85NrO4C59xE51ypc660ps16AHbtghkzYORIyM+H+++H7dtr/jkiIhGmTJnC9OnTD74eOXIk11xzzcHXU6dO5f77709FaXGLN6jP9N4PAM4DvuOcO+vQC7z3M733hd77wtzcamdBHlmzZrB+PTz1FLRpA1OnQvv2cPXVoFmOInKMiouLKSkpAeDAgQNs3bqVlStXHjxfUlJCUVFRqsqLS1xB7b3fWPVYDvwROCMp1ZxwAlxyCfztb/D22zBhAjzzDAwcCGecAY8/DhUVSfmjRaR+KioqYunSpQCsXLmS3r1706xZM3bs2MFXX33FqlWrGDBgQIqrPLKj3kx0zjUBGnjvd1U9/zrw06RX1rcvPPww3H03PPmkdYtcdRXcfLO1sq+9Frp2TXoZIpIgkyfDihWJ/cyCAojo1qhOu3btaNiwIRs2bKCkpIQhQ4awceNGli5dSvPmzenTpw+NGjVKbF0JFk+L+mRgsXPubeBN4EXv/V+SW1aE7Gz4znfgvfdg0SL42tfggQegWzc491z4859h//5aK0dE6p6ioiJKSkoOBvWQIUMOvi4uLk51eUd11Ba19/4joF8t1HJkzsHQoXZ89hn89rfw6KNw4YXQqRNMmgTf+hacfHKqKxWR6hyl5ZtMoX7qd999l969e9OxY0fuu+8+srOzueqqq1JWV7yCNTwvXm3bwo9+ZDcfn3vOWtc/+AF07AiXXgqLF4N2VxeRKkVFRcybN4+WLVuSkZFBy5Yt2blzJ0uXLg38jUSoq0Ed0rAh/Od/wiuvwKpVcP31MH8+/Nu/Qb9+8MgjNuxPRNJanz592Lp1K4MHD456r3nz5rRq1SqFlcXH+SS0PAsLC33KNg748kt4+ml46CG7cdGsGVxxBVx3HfTqlZqaRNLUqlWr6NGjR6rLCJzqvhfn3LLDLc9Rt1vU1WnSBK65BpYvh6VL4aKL4De/gd694eyz4dlnYe/eVFcpIhK3+hfUIc7B4MHwxBNQVmbD/Navh3Hj7ObjT35i74uIBFz9DepIublw662wbh28+CIMGGDriuTlwZgxsHChbj6KSGClR1CHZGTAqFEW1uvW2TT1116DESOgRw8bn71zZ6qrFBGJkl5BHalLF+sOKSuz7pEWLWzmVLt21sf9+utaxU9EAiF9gzokKwsuv9xuPC5bBuPH2
8JQQ4fauOzJk+2cukZEJEUU1JEGDLARIuXlFtYDB9p6I0VF1p/9ve/ZSn4KbZE6Y/369fTu3TvVZRwXBXV1mja1Vfz+938ttH//exveN326hXe3bnD77bbCn0JbRJJMQX00zZvbhJkXX4R//tPWGOnSBe65x1bu6tHDhvq9/36qKxWRw9i3bx/jx4+nR48ejB07lt27d6e6pBqpfzMTa8uWLbbOyDPP2Kp+3lure9w4uPhi6N491RWKpFzkDLwUrXLK+vXr6dy5M4sXL6a4uJirr76anj17cssttyS2mBrQzMTakptra2L/9a+waRP8+tdw0knw4x/bVmL9+8MvfgEff5zqSkXSXseOHQ8uZ3rZZZexePHiFFdUMzXZhVwOp00buOEGO8rKbJr6nDlw2212DBxorexx42wkiUgaSuEqpzE7f9e1ndHVok60Dh1gyhR44w1rTd99t21scMstNnW9uBh+9StrhYtIrdiwYcPB7bieeuopzjzzzBRXVDMK6mTKy7Op68uWwQcfwJ13wr/+BTfdZIE+bJhtMVZenupKReq1/Px8HnroIXr06MGOHTu47rrrUl1SjehmYiqsWmU3IefMsecNGtjKfhdfbOtr5+SkukKRhNAyp9XTzcS6IDSkb+VKeOcd68f+5BOYONH6u887z3Zc17ojIoKCOrWcgz59rEtk7VrrIrn5Zli92nZcb93aWtq/+AW89ZbWHhFJUwrqoHDOprDffTd89JHdjJwyBXbssBb3gAG2V+Tll8Ps2erXFkkjGp4XRM7BoEF23H237br+8svw0kuwYAHMmmXXDRgAI0fCuefCkCGQmZnaukUkKdSirgvatoUrr7SFosrL4R//sI0PTjzRprIPHWo3IC+6yBaR+uijVFcsIgmkFnVd06ABFBba8cMfwuef2+zIv/zFWtx/+pNd162btbZHjrRhgE2bprRsETl2alHXdc2bw3/8Bzz6qE2wWb3adqrp2hUeewzOP99a28OHW+tbK/5JGtm2bRsFBQUUFBTQpk0b2rdvf/D1nj17jvvz//SnP3HRRRcdfH3XXXfRtWvXg69feOEFLrjgguP+c9Sirk+cs3VG8vPhxhuhshIWL7aW9l/+Av/1X3a0aRNubX/ta9CqVaorF0mKnJwcVlStBDVt2jSaNm0atRjTvn37aNjw2GOwqKiISZMmHXy9dOlSsrOzKS8vp3Xr1pSUlFBUVHTsv0AVtajrs6ws2w/yv/8b3n3X1iH53e+sT/uFF+DSS20I4MCB1o2yeDHs3ZvqqkWSasKECVx77bUMGjSIW2+9lWnTpnHvvfcePN+7d2/Wr18PwKxZszjjjDMoKChg0qRJ7N+/P+qzcnNzyc7OZt26dQBs3LiRMWPGUFJSAkBJScnBxaCOh1rU6aR9exuffdVVtv5Iaam1tl96Ce66C372M8jOtm6SUIs7Ly/VVUt9MmxY7HvjxsH118Pu3bb59KEmTLBj61YYOzb63KJFx1RGWVkZJSUlZGRkMG3atGqvWbVqFXPmzGHJkiVkZmZy/fXXM3v2bK644oqo64qLiykpKWH//v1069aNwYMH89JLLzF69GjefvttBg4ceEw1RlJQp6uMjPAQwB//2MZrL1wYDu4//tGuy8+37pEzz7QtybT6n9QD3/jGN8jIyDjiNQsXLmTZsmUHg7aiooLWrVvHXFdUVHQwqIcMGcIZZ5zBT3/6U9566y1OO+00srKyjrteBbWYFi2stTJ2rN1sXL06PJLkscfgwQftug4dbAXAoiI7+vXT+G2J35FawCeeeOTzrVodcwv6UE2aNDn4vGHDhhyImPVbWVkJgPeeK6+8krvuuuuIn1VcXMyvf/1r9u/fz7e//W2aNWtGZWUlixYtSkj/NKiPWqrjnK1HMmWKhfXnn9vY7QcesJBessRWABw40DZLOPts6+OePx+2b0919SI1kpeXx/LlywFYvnw5H1dt9jF8+HDmzp1LedUs4O3bt/PJJ5/E/
HyPHj3YtGkTixcvpn///gAUFBTwyCOPJKR/GmrQonbOZQClwEbv/eiE/OlSN2Rmhsdu33ijvffpp1BSYqFdUmLrkYRutPTsGW5xFxfbmO46tlC7pI8xY8bwxBNP0KtXLwYNGkT3qm30evbsyZ133snXv/51Dhw4QGZmJg899BCnnHJK1M875xg0aBCff/45mVX/dzlkyBBmzpyZsBZ13MucOuduBgqB7KMFtZY5TUNffglvvmmhHTpCq/+1ahUO7qIiC/zGjVNbr9QKLXNavZoucxpXi9o51wH4d+BnwM3HW6TUQ02aWBfI2Wfb6wMHrJ87stX95z/bucxMW6ck1OIuKrJp8iJSrXi7PqYDtwLNDneBc24iMBGgU6dOx1+Z1G0NGlgXSM+ecM019t6WLbB0aTi8Z8yAX/7SznXuHN1d0ru3jUwRkaMHtXNuNFDuvV/mnBt2uOu89zOBmWBdHwmrUOqP3Fy44AI7APbssXW2Qy3uhQttCVewtUkGDw63uAcNsunyImkonhZ1MXCBc24UkAVkO+dmee8vS25pUu81ahQey33zzTYscP36cB/3kiW2SuCBA3YzsmdPKCiwo18/e8zNTfVvIZJ0Rw1q7/1twG0AVS3qWxTSkhTOWRdI584wfry998UXdpNyyRKbSfnaa+FWN0C7duHwDgV4167W9SJST2jCiwRbdratVzJiRPi9bdtsFcAVK8LHyy/Dvn12vkkT6Ns3OsB797YJFSJ1UI2C2nu/CFiUlEpE4pWTA+ecY0fIV1/B++9Hh/fs2baRAlgLu3v32NZ3mzap+R2kVmzbto3hw4cDsHnzZjIyMsit6i578803adSo0XH/GXl5eZSWltIqiatQqkUt9cMJJ0D//naEhPq8I1vfS5fCH/4Qvubkk6PDu6DAJuhoxEm9kOxlTmtL8CsUOVaRfd4Ri7uzY0c4vEOP998fXuK1cWPbHT4yvPv00S459cSECRPIysrirbfeori4mOzs7KgA7927N/PmzSMvL49Zs2bxq1/9ij179jBo0CBmzJhR7WJO99xzDwsWLKBx48Y89dRTUZsHJIKCWtJPixa23Gbkkpt79sCqVeGW99tvw7PPwsyZdt45u0lZUAC9esFpp9nKgt27q++7BgKyymlClzkFaN68Oe+++y5PPPEEkydPZt68ecdW2GEoqEXAhgr262fHlVfae97bmiaR4V1aCnPnRm9n1rFjeGedUIDn59tKgxp9EkiJXOYU4JJLLjn4OGXKlMQWi4Ja5PCcg06d7Ijc966iAj74ANassWnya9bY8cQTsGtX+LrGja3FHRneoVZ4s8NO8q3XArLKaUKXOQVbmKm654mioBapqcaNbfhf377R73sPmzdHh/eaNbZE7LPP2sSdkHbtYgM8P9/+o6AbmbUqLy/vYFfFocucXnjhhUyZMoXWrVuzfft2du3aFbN6HsCcOXP4/ve/z5w5cxgyZEjCa1RQiySKc7a4VNu24cWpQiorYd266ABfswaeesrW+w454QRrcR8a4Pn5mkKfJMe7zCnAjh076Nu3LyeccAJPP/10wmuMe5nTmtAypyJx8h7Ky6PDO9Qi//jj8BrfYGO+I4P71FNtT8u8vMCGuJY5rV5SljkVkSRxzsZyn3wynHVW9Lk9e+DDD2MDfO7c2J10TjopHNp5eTYkMfJ1dnZt/DaSJApqkaBq1Mi2RKuuRbp1q7W416+PPtauten0u3dHX9+iRXRwRwb6KacoyANOQS1SF7VqZUfV0LEo3luQh8I7MtBDmxZXVET/TMuW1Qd56DiOUSre+6SMhKirjqW7WUEtUt84Z8u/5uYePsi3bIltja9fb+ulzJ9vNz8j5eRUH+CnnGLjxU86qdp9MbOysti2bRs5OTkKayykt23bRlZWVo1+TkEtkm6cg9at7TjjjNjzoRuc1QX5ypXw4ouxQd64sQ05jDzat6dDx46UdenClk2bbNihJgCRlZVFhw4davQzCmoRiRZ5g
3PQoNjzoSD/+GP45BPYuBE2bbJj40ZYtsz2x6yoIBPoHPmzzZsfDPHIQI963qaN7aspBymoRaRmIoN88ODqr/HeNn2IDPFQkIeev/oqfPZZeB3xyM/PzT16oLdqlTYtdAW1iCSec9Z6bt7ctlA7nAMH7MbnoSEe+bq01Frwh96Ea9jQJheFwrt1awv46h5zcuz6OqruVi4idV+DBuH+8oKCw1+3d69Nz6+uZb5pkw1LXLLEQj9yqn6Iczay5XBBfuhjy5aBmsqvoBaR4MvMtFUKO3Y88nX799t64+XlNrLlcI/vv2/7b27bFttSB/sPSE5O/MHeokVSu2EU1CJSf2RkhMeYx2PfPgvrI4V6eTm884497thx5D/31FOtZZ9gCmoRSV8NG4ZvjMZj717rXjlcoCepVa2gFhGJV2ZmeIXEWpQeY1tEROowBbWISMApqEVEAk5BLSIScApqEZGAU1CLiAScglpEJOAU1CIiAaegFhEJuKMGtXMuyzn3pnPubefcSufc/6uNwkRExMQzhfwr4Bzv/b+cc5nAYufcAu/9G0muTUREiCOovW2Z+6+ql5lVR8230Y3XsGGx740bB9dfD7t3w6hRsecnTLBj61YYOzb2/HXXwcUXw6efwuWXx56fOhXOPx/WrIFJk2LP//CHMGIErFgBkyfHnv/5z6GoCEpK4PbbY89Pn25r7b7yCtx5Z+z5Rx+F/Hx44QW4777Y808+acs7zpkDDz8ce37uXFu56/HH7TjU/Plw4okwYwY880zs+UWL7PHee2HevOhzjRvDggX2/I47YOHC6PM5OfDcc/b8tttg6dLo8x06wKxZ9nzyZPsOI3XvDjNn2vOJE21d4UgFBfb9AVx2GZSVRZ8fMgTuusuejxljK6FFGj4cfvQje37eebG7b48eDbfcYs/1dy/2vP7u2fN4/+6Ffp8Ei6uP2jmX4ZxbAZQD/+e9/3s110x0zpU650q3bNmS6DpFRNKW89Utmn24i507Cfgj8F3v/XuHu66wsNCXlpYmoDwRkfTgnFvmvS+s7lyNRn1473cCrwLnJqIwERE5unhGfeRWtaRxzjUGvgasTnZhIiJi4hn10Rb4vXMuAwv2Z7z3847yMyIikiDxjPp4B+hfC7WIiEg1NDNRRCTgFNQiIgGnoBYRCTgFtYhIwCmoRUQCTkEtIhJwCmoRkYBTUIuIBJyCWkQk4BTUIiIBp6AWEQk4BbWISMApqEVEAk5BLSIScApqEZGAU1CLiAScglpEJOAU1CIiAaegFhEJOAW1iEjAKahFRAJOQS0iEnAKahGRgFNQi4gEnIJaRCTgFNQiIgGnoBYRCTgFtYhIwCmoRUQCTkEtIhJwCmoRkYA7alA75zo65151zr3vnFvpnLupNgoTERHTMI5r9gFTvffLnXPNgGXOuf/z3r+f5NpERIQ4gtp7/xnwWdXzXc65VUB7IClBPWxY7HvjxsH118Pu3TBqVOz5CRPs2LoVxo6NPX/ddXDxxfDpp3D55bHnp06F88+HNWtg0qTY8z/8IYwYAStWwOTJsed//nMoKoKSErj99tjz06dDQQG88grceWfs+Ucfhfx8eOEFuO++2PNPPgkdO8KcOfDww7Hn586FVq3g8cftONT8+XDiiTBjBjzzTOz5RYvs8d57Yd686HONG8OCBfb8jjtg4cLo8zk58Nxz9vy222Dp0ujzHTrArFn2fPJk+w4jde8OM2fa84kTYe3a6PMFBfb9AVx2GZSVRZ8fMgTuusuejxkD27ZFnx8+HH70I3t+3nlQURF9fvRouOUWe66/e7Hn9XfPnsf7dy/0+yRajfqonXN5QH/g79Wcm+icK3XOlW7ZsiUx1YmICM57H9+FzjUFXgN+5r1//kjXFhYW+tLS0gSUJyKSHpxzy7z3hdWdi6tF7ZzLBJ4DZh8tpEVEJLHiGfXhgMeAVd77+5NfkoiIRIqnRV0MXA6c45xbUXVUc1tFRESSIZ5RH4sBVwu1iIhIN
TQzUUQk4BTUIiIBp6AWEQk4BbWISMApqEVEAk5BLSIScPGsniciIlUqKmwRruoO52DatMT/mQpqEUlbe/bYqneHC97qjt27q/8s56BzZwW1iMhh7dsH27fXLHR37Tr85zVvbku4tmoFbdtCnz7h16EjJyf8vEULaJikRFVQi0hgeQ+ffw6ffRZ9bN4c/fqf/4QdOw7/OU2aRAds9+6xoRt5tGwJjRrV3u95NApqEal1+/fDli1HD+DNm6GyMvbns7Ksldu2LfTqBeecA7m51YduTo5tQlCXKahFJGEqK2PDtroALi+HAwdif/6kk8IBXFxsj23ahN8LHdnZ1iecLhTUIhKXL7+ETz6BDRvsMfQ8MoB37oz9uQYN4OSTw4E7YEA4cCNDuE0baylLLAW1iOC93Vw7NIgjXx+6H2XDhtC+PbRrBz16WPdDdQGcmwsZGan5veoLBbVIGti3DzZujA3hUBBv2BA77KxJEzjlFDsGDgw/Dx1t2yqAa4uCWqQe2L37yK3hjRvtBl6k3FwL3F69bIf2yBDu1MlGPqRTP3CQKahF6gDvYdMmWLsW1qyBDz6Ajz8OB/HWrdHXZ2RAhw4WumedFRvCnTrBiSem5neRmlNQiwTIzp0WxqFjzRp7/OADu5kXkpVls+A6dYLTT48N4nbtkjf5Qmqf/lGK1LKvvoIPPwyHcORRXh6+rkEDC+Pu3WHoUHvMz7fH9u3tvKQHBbVIEhw4AJ9+GtsyXrvWuioixxCffLKF7wUX2GPo6NIFTjghdb+DBIeCWuQ4bN0a2ypeswbWrYueUde0qYXvoEFwxRXhMO7WzdaUEDkSBbXIUXhvM+tWrIC334ZVq8KhvH17+LqGDeHUUy2AR46M7qpo00YjKOTYKahFIuzday3iUCiHHrdsCV/Trp0F8Lhx0V0VeXmQmZmy0qUeU1BL2tq500I4MpDfe8/WKAbrH+7VC84/H/r1g4IC6NvX1qMQqU0Kaqn3vLcxx5GBvGKF3dQLyc21IL7xRnvs189azWohSxAoqKVeqaiAlSujA/mdd+CLL+x8gwbWTTFkCFx7bTiU1YcsQaagljpr8+bYVvKaNeGhb02bWghfdlk4kHv31ow8qXsU1BJ43ttwt3/8IzqY//nP8DWdOlkQjx0b7k/u3FmTQqR+UFBL4Hz1FSxbBkuW2FFSEh510ahReBGhUCu5b19bQEikvlJQS8pt3WphHArm0lILa4CuXWHUKNvtY9AgW/dYN/gk3Rw1qJ1zvwNGA+Xe+97JL0nqM+9tokgolJcssX5lsAA+/XS44QYL5qIim14tku7iaVE/DjwIPJHcUqQ+qqyM7cYILcnZsqWF8YQJFsyFhXV/E1KRZDhqUHvvX3fO5SW/FKkPtmyJbi0vWxaeQNKtG4webaFcXGzjlHWzT+ToEtZH7ZybCEwE6NSpU6I+VgLMe1i9OjqYP/jAzjVqZC3kG28Md2O0bp3aekXqqoQFtfd+JjAToLCw0CfqcyU4KirsRl9kN0ZoUaKcHAvka66xx9NP147SIomiUR9yWBUVsGgRLFwY7sbYu9fO5efDRReFuzG6d9fMPpFkUVBLlHXrYMECmD/fQrqy0roxBg6EKVPC3RitWqW6UpH0Ec/wvKeBYUAr51wZ8BPv/WPJLkxqR0UFvPaaBfOCBRbUYDf+Jk2yiSVnnaXRGCKpFM+oj0tqoxCpPaFW84IF8Oqr1mrOyoJzzoGbbrJwPvXUVFcpIiHq+kgDoVZzKJxDIzO6dYOJEy2Yhw5Vq1kkqBTU9dSHH0a3misqrNV89tnw3e9aOHftmuoqRSQeCup6orIyutW8dq2937WrDZk77zwYNkytZpG6SEFdh330UXiERmSredgw+M53LJy7dUt1lSJyvBTUdUhlJbz+eniERqjVfOqp8K1v2SpzQ4dqYXyR+kZBHXAffxwO5ldfhd27bdNVtZpF0oeCOoBWr4annoJnn7XnAF26wNVXh/ua1WoWSR8K6oDYt
An+8AeYPRuWL7dV5YYNsw1YQ61mTdEWSU8K6hT6/HN4/nkL51dftU1ZTz8d7r8fvvlNaNs21RWKSBAoqGvZV19Zf/Ps2fDCC/a6Sxf4wQ/g0kvhtNNSXaGIBI2CuhYcOAB/+5uF89y5sGMH5ObCt78N48fbXoDq1hCRw1FQJ9E771g4P/00fPopNGliS4OOHw8jRmiTVhGJj4I6wTZssBEbs2fDe+9BRgaMHAm/+AVceKGFtYhITSioE2D7dhtKN3u2dXEADBkCDz4I48ZZN4eIyLFSUB+jigq7GTh7tt0c3LvXbgTecYfdFOzSJdUVikh9oaCugf374a9/tXB+/nnYtcuG0H33u9bv3L+/bgqKSOIpqI/Ce9srcPZsm5CyeTNkZ8PYsRbOw4ZZP7SISLIoqA9j3brwTcG1a23fwFGjLJxHj9YO2yJSexTUEfbssZuCDz4Ib7xh3RhDh8L3vgdjxkCLFqmuUETSkYIaG7Xx6KMW0Js2QX4+3HOPTePu2DHV1YlIukvroF6zBqZPh9//3kZxjBgBv/kNnHuuLYokIhIEaRfU3tvIjV/+El580dZ2Hj8eJk+GPn1SXZ2ISKy0CerKSpvKPX26Te1u3RqmTYPrrrPnIiJBVe+DurwcHn4YZsyw5336wGOP2aQUjdwQkbqg3gb1e+9Z98bs2baU6KhRMGUKDB+uSSkiUrfUq6A+cABeeskW3n/lFWjcGK66Cm66Ses8i0jdVS+CevduePJJ639evRratYOf/xwmToScnFRXJyJyfOp0UG/aBA89BI88YmOhTz8dZs2Cb3zDZhKKiNQHdTKoly+3/uc5c2DfPlvn+eab4cwz1f8sIvVPnQnq/fttWdFf/hJefx2aNrWhdTfeCKeemurqRESSJ/BB/a9/wf/8DzzwAHz4IXTqBPfeC9dcA82bp7o6EZHki2uitHPuXOfcGufcOufc95NdFNiWVt/7HnToYK3m1q3hmWcsrKdOVUiLSPo4aovaOZcBPAR8DSgD/uGc+7P3/v1kFPTGG9a98dxz9nrsWBv/PGhQMv40EZHgi6fr4wxgnff+IwDn3B+AC4GEBvUXX9gmsG+8Ya3lm2+GG26wrg4RkXQWT1C3Bz6NeF0GxLRvnXMTgYkAnY4hXbOz7abgpZfaJJWmTWv8ESIi9VLCbiZ672cCMwEKCwv9sXzGrFmJqkZEpP6I52biRiBy+fwOVe+JiEgtiCeo/wF0c851ds41Ar4J/Dm5ZYmISMhRu9PBpWwAAAMJSURBVD689/ucczcALwEZwO+89yuTXpmIiABx9lF77+cD85Nci4iIVEM7A4qIBJyCWkQk4BTUIiIBp6AWEQk45/0xzU058oc6twX45Bh/vBWwNYHl1GX6LqLp+4im7yOsPnwXp3jvc6s7kZSgPh7OuVLvfWGq6wgCfRfR9H1E0/cRVt+/C3V9iIgEnIJaRCTgghjUM1NdQIDou4im7yOavo+wev1dBK6PWkREogWxRS0iIhEU1CIiAReYoE7FBrpB5Zzr6Jx71Tn3vnNupXPuplTXlGrOuQzn3FvOuXmpriXVnHMnOefmOudWO+dWOeeGpLqmVHLOTan69+Q959zTzrmsVNeUaIEI6ogNdM8DegKXOOd6praqlNoHTPXe9wQGA99J8+8D4CZgVaqLCIgHgL94708D+pHG34tzrj1wI1Dove+NLcX8zdRWlXiBCGoiNtD13u8BQhvopiXv/Wfe++VVz3dh/yK2T21VqeOc6wD8O/DbVNeSas655sBZwGMA3vs93vudqa0q5RoCjZ1zDYETgU0prifhghLU1W2gm7bBFMk5lwf0B/6e2kpSajpwK3Ag1YUEQGdgC/A/VV1Bv3XONUl1Uanivd8I3AtsAD4DPvfev5zaqhIvKEEt1XDONQWeAyZ7779IdT2p4JwbDZR775elupaAaAgMAB723vcHvgTS9p6Oc64F9n/fnYF2QBPn3GWprSrxghLU2kD3EM65TCykZ
3vvn091PSlUDFzgnFuPdYmd45xL5/3qy4Ay733o/7DmYsGdrkYAH3vvt3jv9wLPA0UprinhghLU2kA3gnPOYX2Qq7z396e6nlTy3t/mve/gvc/D/l781Xtf71pM8fLebwY+dc7lV701HHg/hSWl2gZgsHPuxKp/b4ZTD2+uxrVnYrJpA90YxcDlwLvOuRVV791etXelyHeB2VWNmo+Aq1JcT8p47//unJsLLMdGS71FPZxOrinkIiIBF5SuDxEROQwFtYhIwCmoRUQCTkEtIhJwCmoRkYBTUIuIBJyCWkQk4P4/H7eR/ck29KUAAAAASUVORK5CYII=\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "model = Model()\n", "\n", "# Collect the history of W-values and b-values to plot later\n", "Ws, bs = [], []\n", "epochs = range(10)\n", "for epoch in epochs:\n", "  Ws.append(model.W.numpy())\n", "  bs.append(model.b.numpy())\n", "  current_loss = loss(model(inputs), outputs)\n", "\n", "  train(model, inputs, outputs, learning_rate=0.1)\n", "  print('Epoch %2d: W=%1.2f b=%1.2f, loss=%2.5f' %\n", "        (epoch, Ws[-1], bs[-1], current_loss))\n", "\n", "# Plot it all\n", "plt.plot(epochs, Ws, 'r',\n", "         epochs, bs, 'b')\n", "plt.plot([TRUE_W] * len(epochs), 'r--',\n", "         [TRUE_b] * len(epochs), 'b--')\n", "plt.legend(['W', 'b', 'True W', 'True b'])\n", "plt.show()\n" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "vPnIVuaSJwWz" }, "source": [ "## Next steps\n", "\n", "In this tutorial, you used `tf.Variable` to build and train a simple linear model.\n", "\n", "In practice, a high-level API such as `tf.keras` is much more convenient for building neural networks. `tf.keras` provides higher-level building blocks (called \"layers\"), utilities to save and restore state, a suite of loss functions, a suite of optimization strategies, and more. Read the [TensorFlow Keras guide](https://www.tensorflow.org/guide/keras/overview) to learn more.\n" ] } ], "metadata": { "colab": { "collapsed_sections": [], "name": "custom_training.ipynb", "private_outputs": true, "provenance": [], "toc_visible": true, "version": "0.3.2" }, "kernelspec": { "display_name": "Python 3", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.6.9" } }, "nbformat": 4, "nbformat_minor": 0 }