{ "cells": [ { "cell_type": "markdown", "metadata": { "id": "tqrD7Yzlmlsk" }, "source": [ "##### Copyright 2019 The TensorFlow Authors." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "cellView": "form", "id": "2k8X1C1nmpKv" }, "outputs": [], "source": [ "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n", "# you may not use this file except in compliance with the License.\n", "# You may obtain a copy of the License at\n", "#\n", "# https://www.apache.org/licenses/LICENSE-2.0\n", "#\n", "# Unless required by applicable law or agreed to in writing, software\n", "# distributed under the License is distributed on an \"AS IS\" BASIS,\n", "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", "# See the License for the specific language governing permissions and\n", "# limitations under the License." ] }, { "cell_type": "markdown", "metadata": { "id": "32xflLc4NTx-" }, "source": [ "# 自定义联合算法,第 2 部分:实现联合平均" ] }, { "cell_type": "markdown", "metadata": { "id": "jtATV6DlqPs0" }, "source": [ "\n", " \n", " \n", " \n", " \n", "
在 TensorFlow.org 上查看在 Google Colab 中运行在 GitHub 上查看源代码 下载笔记本
" ] }, { "cell_type": "markdown", "metadata": { "id": "_igJ2sfaNWS8" }, "source": [ "本系列教程包括两个部分,此为第二部分。该系列演示了如何使用 [Federated Core (FC)](../federated_core.md) 在 TFF 中实现自定义类型的联合算法,它是[联合学习 (FL)](../federated_learning.md) 层(`tff.learning`)的基础。\n", "\n", "我们建议您先阅读[本系列的第一部分](custom_federated_algorithms_1.ipynb),其中介绍了此处使用的一些关键概念和编程抽象。\n", "\n", "本系列的第二部分使用第一部分中介绍的机制来实现简单版本的联合训练和评估算法。\n", "\n", "我们建议您查看[图像分类](federated_learning_for_image_classification.ipynb)和[文本生成](federated_learning_for_text_generation.ipynb)教程,以获得对 TFF 的 Federated Learning API 更高级和更循序渐进的介绍,因为它们将帮助您在上下文中理解我们在此描述的概念。" ] }, { "cell_type": "markdown", "metadata": { "id": "cuJuLEh2TfZG" }, "source": [ "## 准备工作\n", "\n", "在开始之前,请尝试运行以下“Hello World”示例,以确保您的环境已正确配置。如果无法正常运行,请参阅[安装](../install.md)指南查看说明。" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "rB1ovcX1mBxQ" }, "outputs": [], "source": [ "#@test {\"skip\": true}\n", "!pip install --quiet --upgrade tensorflow-federated-nightly\n", "!pip install --quiet --upgrade nest-asyncio\n", "\n", "import nest_asyncio\n", "nest_asyncio.apply()" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "-skNC6aovM46" }, "outputs": [], "source": [ "import collections\n", "\n", "import numpy as np\n", "import tensorflow as tf\n", "import tensorflow_federated as tff\n", "\n", "# TODO(b/148678573,b/148685415): must use the reference context because it\n", "# supports unbounded references and tff.sequence_* intrinsics.\n", "tff.backends.reference.set_reference_context()" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "zzXwGnZamIMM" }, "outputs": [ { "data": { "text/plain": [ "'Hello, World!'" ] }, "execution_count": 4, "metadata": { "tags": [] }, "output_type": "execute_result" } ], "source": [ "@tff.federated_computation\n", "def hello_world():\n", " return 'Hello, World!'\n", "\n", "hello_world()" ] }, { "cell_type": "markdown", "metadata": { "id": "iu5Gd8D6W33s" }, "source": [ "## 实现联合平均\n", "\n", "与[图像分类联合学习](federated_learning_for_image_classification.ipynb)一样,我们将使用 MNIST 示例,但由于这是一个低级教程,我们将绕过 Keras API 和 `tff.simulation`,编写原始模型代码,并从头开始构造联合数据集。\n" ] }, { "cell_type": "markdown", "metadata": { "id": "b6qCjef350c_" }, "source": [ "### 准备联合数据集\n", "\n", "为了进行演示,我们将模拟一个场景,其中有来自 10 个用户的数据,每个用户都会提供如何识别不同数字的知识。这是能够得到的最非[独立同分布](https://en.wikipedia.org/wiki/Independent_and_identically_distributed_random_variables)的情况。\n", "\n", "首先,加载标准 MNIST 数据:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "uThZM4Ds-KDQ" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz\n", "11493376/11490434 [==============================] - 0s 0us/step\n", "11501568/11490434 [==============================] - 0s 0us/step\n" ] } ], "source": [ "mnist_train, mnist_test = tf.keras.datasets.mnist.load_data()" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "PkJc5rHA2no_" }, "outputs": [ { "data": { "text/plain": [ "[(dtype('uint8'), (60000, 28, 28)), (dtype('uint8'), (60000,))]" ] }, "execution_count": 6, "metadata": { "tags": [] }, "output_type": "execute_result" } ], "source": [ "[(x.dtype, x.shape) for x in mnist_train]" ] }, { "cell_type": "markdown", "metadata": { "id": "mFET4BKJFbkP" }, "source": [ "数据以 Numpy 数组的形式出现,一个带有图像,另一个带有数字标签,其中第一个维度都遍历各个样本。我们来编写一个辅助函数,并使用与将联合序列馈送到 TFF 计算的方式相兼容的方式(即作为列表的列表,外部列表包括用户(数字),内部列表包括每个客户端序列中的数据批次)对其进行格式化。按照惯例,我们将每个批次构造为一对名为 `x` 和 `y` 的张量,每个张量都具有与首个批次相同的维度。同时,我们还将每个图像展平为一个具有 784 个元素的向量,并将其中的像素重新缩放到 `0..1` 范围内,这样我们就不必在模型逻辑上进行数据转换了。" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "XTaTLiq5GNqy" }, "outputs": [], "source": [ "NUM_EXAMPLES_PER_USER = 1000\n", "BATCH_SIZE = 100\n", "\n", "\n", "def get_data_for_digit(source, digit):\n", " output_sequence = []\n", " all_samples = [i for i, d in enumerate(source[1]) if d == digit]\n", " for i in range(0, min(len(all_samples), NUM_EXAMPLES_PER_USER), BATCH_SIZE):\n", " batch_samples = all_samples[i:i + BATCH_SIZE]\n", " output_sequence.append({\n", " 'x':\n", " np.array([source[0][i].flatten() / 255.0 for i in batch_samples],\n", " dtype=np.float32),\n", " 'y':\n", " np.array([source[1][i] for i in batch_samples], dtype=np.int32)\n", " })\n", " return output_sequence\n", "\n", "\n", "federated_train_data = [get_data_for_digit(mnist_train, d) for d in range(10)]\n", "\n", "federated_test_data = [get_data_for_digit(mnist_test, d) for d in range(10)]" ] }, { "cell_type": "markdown", "metadata": { "id": "xpNdBimWaMHD" }, "source": [ "作为快速的健全性检查,我们来看一下第五个客户端(对应数字 `5`)所贡献的最后一个数据批次中的 `Y` 张量。" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "bTNuL1W4bcuc" }, "outputs": [ { "data": { "text/plain": [ "array([5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5,\n", " 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5,\n", " 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5,\n", " 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5,\n", " 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5], dtype=int32)" ] }, "execution_count": 8, "metadata": { "tags": [] }, "output_type": "execute_result" } ], "source": [ "federated_train_data[5][-1]['y']" ] }, { "cell_type": "markdown", "metadata": { "id": "Xgvcwv7Obhat" }, "source": [ "保险起见,我们再检查一下该批次最后一个元素对应的图像。" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "cI4aat1za525" }, "outputs": [ { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAP8AAAD8CAYAAAC4nHJkAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4zLCBo\ndHRwOi8vbWF0cGxvdGxpYi5vcmcvnQurowAAEBZJREFUeJzt3W1sU2Ufx/HfHgJKAOOQzo6hY9mE\nyTbINrMYZQjIgyaOh4Vn4mCEJeALA6hB34AJgSVCggmoVIgsMWogymZAFoiAyoxZivQFSAxOFsco\nG8gMAxU2PPcLc++WW3q1dH2C6/tJTrL239Pzz9l+O+25TnslOY7jCIB1kuPdAID4IPyApQg/YCnC\nD1iK8AOWIvyApQg/YCnCD1iK8AOWSo3lxpKSkmK5OcBKoV6026cjf0NDg0aOHKmcnBzV1NT05akA\nxJoTpp6eHic7O9tpbm52rl+/7hQWFjqnTp0yriOJhYUlykuowj7yNzU1KScnR9nZ2erXr5/mzZun\n+vr6cJ8OQIyFHf62tjYNHz6893ZmZqba2tr+9TiPx6OSkhKVlJSEuykAURD2Cb/bnVS43Qm96upq\nVVdXB6wDiI+wj/yZmZlqbW3tvX3u3DllZGREpCkAMRDuCb/u7m5nxIgRzs8//9x7wu/kyZOc8GNh\nifMSqrBf9qempmrr1q2aOnWqbt68qaqqKo0ePTrcpwMQY0nO7d68R2tjvOcHoi7USHN5L2Apwg9Y\nivADliL8gKUIP2Apwg9YivADliL8gKUIP2Apwg9YivADliL8gKUIP2Apwg9YivADliL8gKUIP2Ap\nwg9YivADliL8gKUIP2CpmE7RjcTz0ksvGevXrl0z1nft2hXBbm6VlZVlrCcnm49dc+bMCVgbNmyY\ncd0VK1YY688++6yxfuTIEWM9EXDkByxF+AFLEX7AUoQfsBThByxF+AFLEX7AUn0a58/KytKgQYOU\nkpKi1NRUeb3eSPWFGJkxY4axPmHCBGN9yJAhxrrP5wtYW7BggXHdRYsWGespKSnGel9cvXrVWO/s\n7IzatmOlzxf5HDlyRA899FAkegEQQ7zsByzVp/AnJSVpypQpKi4ulsfjiVRPAGKgTy/7GxsblZGR\noY6ODk2ePFmjRo1SWVnZLY/xeDz8YwASUJ+O/BkZGZIkl8ulmTNnqqmp6V+Pqa6ultfr5WQgkGDC\nDv+1a9fU1dXV+/PBgweVn58fscYARFfYL/vb29s1c+ZMSVJPT48WLFigadOmRawxANGV5DiOE7ON\nJSXFalMI0aFDh4z1YOP8wX6nMfzzuiMrV6401g8cOGCs//TTT5FsJ6JC3ecM9QGWIvyApQg/YCnC\nD1iK8AOWIvyApfjq7nuAabjtqaeeMq47fvz4SLcTsj/++MNY/+9FZIE0NDQY6+vXrw9YO3v2rHHd\nRB2ijCSO/IClCD9gKcIPWIrwA5Yi/IClCD9gKcIPWIqP9N4DBg4cGLD222+/RXXbN27cMNY///zz\ngLVNmzYZ1+Xbn8LDR3oBGBF+wFKEH7AU4QcsRfgBSxF+wFKEH7AUn+e/B8yePTtu216xYoWxvmvX\nrtg0gjvGkR+wFOEHLEX4AUsRfsBShB+wFOEHLEX4AUsFHeevqqrSvn375HK5dPLkSUnS5cuXNXfu\nXLW0tCgrK0u7d+/Wgw8+GPVmbTVnzhxjfcuWLVHb9jvvvGOsM45/9wp65F+8ePG/JkeoqanRpEmT\ndObMGU2aNEk1NTVRaxBAdAQNf1lZmdLS0m65r76+XpWVlZKkyspK1dXVRac7AFET1nv+9vZ2ud1u\nSZLb7VZHR0dEmwIQfVG/tt/j8cjj8UR7MwDuUFhH/vT0dPn9fkmS3++Xy+UK+Njq6mp5vV6+jBFI\nMGGFv7y8XLW1tZKk2tpaTZ8+PaJNAYi+oOGfP3++nnzySf3444/KzMzUzp07tWbNGh06dEi5ubk6\ndOiQ1qxZE4teAUQQ39ufAAYMGGCsf/vtt8Z6fn5+2Ns+fPiwsV5RUWGsd3V1hb1tRAff2w/AiPAD\nliL8gKUIP2Apwg9YivADluKru2Ogf//+xvr27duN9b4M5QWzceNGY52hvHsXR37AUoQfsBThByxF\n+AFLEX7AUoQfsBThByzFOH8MPPPMM8b6/PnzY9PIbcyaNctYLywsNNavXLlirH/wwQd33BNigyM/\nYCnCD1iK8AOWIvyApQg/YCnCD1iK8AOW4qu7Y2D//v3G+rRp02LUSeQlJ5uPH/X19QFrwfbLzp07\njfW//vrLWLcVX90NwIjwA5Yi/IClCD9gKcIPWIrwA5Yi/IClgo7zV1VVad++fXK5XDp58qQkad26\ndXr//fc1dOhQSdKGDRv0/PPPB9+YpeP8RUVFxvq7775rrBcXF4e97dOnTxvrfr/fWB8+fLix/thj\njxnrfbmMZM2aNcb6pk2bwn7ue1nExvkXL16shoaGf92/cuVK+Xw++Xy+kIIPILEEDX9ZWZnS0tJi\n0QuAGAr7Pf/WrVtVWFioqqoqdXZ2RrInADEQVviXL1+u5uZm+Xw+ud1urV69OuBjPR6PSkpKVFJS\nEnaTACIvrPCnp6crJSVFycnJWrZsmZqamgI+trq6Wl6vV16vN+wmAUReWOH/5xnivXv3RnUWWQDR\nEfSru+fPn6+jR4/q0qVLyszM1JtvvqmjR4/K5/MpKSlJWVlZQaeYBpB4+Dx/AhgwYICxnp2dHfZz\nt7W1GevBTtYOGTLEWB85cqSx/vrrrwesPffcc8Z1b968aazPmDHDWD9w4ICxfq/i8/wAjAg/YCnC\nD1iK8AOWIvyApQg/YCmG+iLg/vvvN9b//PNPYz2Gv4KYS0lJCVjz+XzGdfPy8oz1xsZGY338+PHG\n+r2KoT4ARoQfsBThByxF+AFLEX7AUoQfsBThBywV9PP8+NsDDzwQsPbRRx8Z1509e7ax/vvvv4fV\n091g4MCBAWv33Xdfn547NZU/377gyA9YivADliL8gKUIP2Apwg9YivADliL8gKUYKA2RabqxqVOn\nGtcNNo11sM+1JzLTOL4kffjhhwFrI0aMiHQ7uAMc+QFLEX7AUoQfsBThByxF+AFLEX7AUoQfsFTQ\ncf7W1la9+OKLunDhgpKTk1VdXa2XX35Zly9f1ty5c9XS0qKsrCzt3r1bDz74YCx6vus0NDQY66Zp\nrCVpz549kWznjixevNhYX7t2rbHel7+J7u5uY/29994L+7kRwpE/NTVVmzdv1unTp/Xdd99p27Zt\n+uGHH1RTU6NJkybpzJkzmjRpkmpqamLRL4AICRp+t9utoqIiSdKgQYOUl5entrY21dfXq7KyUpJU\nWVmpurq66HYKIKLu6D1/S0uLTpw4odLSUrW3t8vtdkv6+x9ER0dHVBoEEB0hX9t/9epVVVRUaMuW\nLRo8eHDIG/B4PPJ4PGE1ByB6Qjryd3d3q6KiQgsXLtSsWbMkSenp6fL7/ZIkv98vl8t123Wrq6vl\n9Xrl9Xoj1DKASAgafsdxtHTpUuXl5WnVqlW995eXl6u2tlaSVFtbq+nTp0evSwARF3SK7mPHjmnc\nuHEqKChQcvLf/ys2bNig0tJSzZkzR7/88oseeeQR7dmzR2lpaeaN3cVTdJeWlgasHT582Lhu//79\nI91Owgj2OzX9eXV2dhrXDTYEumPHDmPdVqFO0R30Pf/TTz8d8Mm+/PLLO+sKQMLgCj/AUoQfsBTh\nByxF+AFLEX7AUoQfsFTQcf6IbuwuHuc3WbJkibEe7KOnKSkpkWwnpoL9Ti9evBiwVlFRYVy3sbEx\nrJ5sF2qkOfIDliL8gKUIP2Apwg9YivADliL8gKUIP2ApxvljYNSoUcb6Z599ZqwHm+I7moJNH75v\n3z5j3XSNw4ULF8LqCWaM8wMwIvyApQg/YCnCD1iK8AOWIvyApQg/YCnG+YF7DOP8AIwIP2Apwg9Y\nivADliL8gKUIP2Apwg9YKmj4W1tbNWHCBOXl5Wn06NF6++23JUnr1q3TsGHDNHbsWI0dO1ZffPFF\n1JsFEDlBL/Lx+/3y+/0qKipSV1eXiouLVVdXp927d2vgwIF65ZVXQt8YF/kAURfqRT6pwR7gdrvl\ndrslSYMGDVJeXp7a2tr61h2AuLuj9/wtLS06ceKESktLJUlbt25VYWGhqqqq1NnZedt1PB6PSkpK\nVFJS0vduAUSOE6Kuri6nqKjI+fTTTx3HcZwLFy44PT09zs2bN5033njDWbJkSdDnkMTCwhLlJVQh\nPfLGjRvOlClTnM2bN9+2fvbsWWf06NGEn4UlAZZQBX3Z7ziOli5dqry8PK1atar3fr/f3/vz3r17\nlZ+fH+ypACSQoGf7jx07pnHjxqmgoEDJyX//r9iwYYM+/vhj+Xw+JSUlKSsrS9u3b+89MRhwY5zt\nB6IuSKR78Xl+4B4TaqS5wg+wFOEHLEX4AUsRfsBShB+wFOEHLEX4AUsRfsBShB+wFOEHLEX4AUsR\nfsBShB+wFOEHLBX0CzwjaciQIcrKyuq9ffHiRQ0dOjSWLYQsUXtL1L4kegtXJHtraWkJ+bEx/Tz/\n/yspKZHX643X5o0StbdE7Uuit3DFqzde9gOWIvyApVLWrVu3Lp4NFBcXx3PzRonaW6L2JdFbuOLR\nW1zf8wOIH172A5aKS/gbGho0cuRI5eTkqKamJh4tBJSVlaWCggKNHTs27lOMVVVVyeVy3TInwuXL\nlzV58mTl5uZq8uTJAadJi0dviTJzc6CZpeO97xJuxuuQp/eIkJ6eHic7O9tpbm52rl+/7hQWFjqn\nTp2KdRsBPfroo87Fixfj3YbjOI7z1VdfOcePH79lNqRXX33V2bhxo+M4jrNx40bntddeS5je1q5d\n67z11ltx6eefzp8/7xw/ftxxHMe5cuWKk5ub65w6dSru+y5QX/HabzE/8jc1NSknJ0fZ2dnq16+f\n5s2bp/r6+li3cVcoKytTWlraLffV19ersrJSklRZWam6urp4tHbb3hKF2+1WUVGRpFtnlo73vgvU\nV7zEPPxtbW0aPnx47+3MzMyEmvI7KSlJU6ZMUXFxsTweT7zb+Zf29vbemZHcbrc6Ojri3NGtQpm5\nOZb+ObN0Iu27cGa8jrSYh9+5zeBCIs3k09jYqO+//14HDhzQtm3b9PXXX8e7pbvG8uXL1dzcLJ/P\nJ7fbrdWrV8e1n6tXr6qiokJbtmzR4MGD49rLP/1/X/HabzEPf2ZmplpbW3tvnzt3ThkZGbFuI6D/\n9uJyuTRz5kw1NTXFuaNbpaen906S6vf75XK54tzR/6SnpyslJUXJyclatmxZXPddd3e3KioqtHDh\nQs2aNau3v3jvu0B9xWO/xTz8TzzxhM6cOaOzZ8/qxo0b+uSTT1ReXh7rNm7r2rVr6urq6v354MGD\nCTf7cHl5uWprayVJtbW1mj59epw7+p9EmbnZCTCzdLz3XaC+4rbfYn6K0XGc/fv3O7m5uU52draz\nfv36eLRwW83NzU5hYaFTWFjoPP7443Hvbd68ec7DDz/spKamOsOGDXN27NjhXLp0yZk4caKTk5Pj\nTJw40fn1118TprdFixY5+fn5TkFBgfPCCy8458+fj0tv33zzjSPJKSgocMaMGeOMGTPG2b9/f9z3\nXaC+4rXfuMIPsBRX+AGWIvyApQg/YCnCD1iK8AOWIvyApQg/YCnCD1jqP1hNIrYb+rn+AAAAAElF\nTkSuQmCC\n", "text/plain": [ "
" ] }, "metadata": { "tags": [] }, "output_type": "display_data" } ], "source": [ "from matplotlib import pyplot as plt\n", "\n", "plt.imshow(federated_train_data[5][-1]['x'][-1].reshape(28, 28), cmap='gray')\n", "plt.grid(False)\n", "plt.show()" ] }, { "cell_type": "markdown", "metadata": { "id": "J-ox58PA56f8" }, "source": [ "### 关于 TensorFlow 与 TFF 的结合\n", "\n", "在本教程中,出于紧凑考虑,我们使用 `tff.tf_computation` 对引入 TensorFlow 逻辑的函数进行了直接装饰。但对于更复杂的逻辑,我们不建议使用这种模式。调试 TensorFlow 本身就是一种挑战,如果在 TensorFlow 完全序列化并重新导入后再对其进行调试,必然会丢失部分元数据并限制交互性,这会使调试面临更大挑战。\n", "\n", "因此,**我们强烈建议将复杂的 TF 逻辑编写为独立的 Python 函数**(即不使用 `tff.tf_computation` 装饰)。这样,在序列化 TFF 计算之前(例如,通过将 Python 函数用作参数调用 `tff.tf_computation`),可以使用 TF 最佳做法和工具(如 Eager 模式)对 TensorFlow 逻辑进行开发和测试。" ] }, { "cell_type": "markdown", "metadata": { "id": "RSd6UatXbzw-" }, "source": [ "### 定义损失函数\n", "\n", "现在有了数据,我们来定义一个可以用于训练的损失函数。首先,将输入类型定义为 TFF 命名元组。由于数据批次的大小可能会有所不同,因此我们将批次维度设置为 `None`,表示该维度的大小未知。" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "653xv5NXd4fy" }, "outputs": [ { "data": { "text/plain": [ "''" ] }, "execution_count": 10, "metadata": { "tags": [] }, "output_type": "execute_result" } ], "source": [ "BATCH_SPEC = collections.OrderedDict(\n", " x=tf.TensorSpec(shape=[None, 784], dtype=tf.float32),\n", " y=tf.TensorSpec(shape=[None], dtype=tf.int32))\n", "BATCH_TYPE = tff.to_type(BATCH_SPEC)\n", "\n", "str(BATCH_TYPE)" ] }, { "cell_type": "markdown", "metadata": { "id": "pb6qPUvyh5A1" }, "source": [ "您可能想知道为什么我们不能只定义普通的 Python 类型。回想一下[第 1 部分](custom_federated_algorithms_1.ipynb)中讨论的内容,我们解释了虽然可以使用 Python 来表达 TFF 计算的逻辑,但实际上 TFF 计算*不是* Python。上面定义的符号 `BATCH_TYPE` 表示抽象的 TFF 类型规范。区分这种*抽象的* TFF 类型与具体的 Python *表示* 类型(可用来表示 Python 函数主体中 TFF 类型的容器,如 `dict` 或 `collections.namedtuple`)很重要。与 Python 不同,针对类似元组的容器,TFF 具有单个抽象类型构造函数 `tff.StructType`,其元素可以单独命名或不命名。这种类型还用于对计算的形式化参数进行建模,因为 TFF 计算形式上只能声明一个参数和一个结果(稍后您将看到相关示例)。\n", "\n", "现在,我们来定义模型参数的 TFF 类型,仍然将其定义为*权重*和*偏差*的 TFF 命名元组。" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "Og7VViafh-30" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\n" ] } ], "source": [ "MODEL_SPEC = collections.OrderedDict(\n", " weights=tf.TensorSpec(shape=[784, 10], dtype=tf.float32),\n", " bias=tf.TensorSpec(shape=[10], dtype=tf.float32))\n", "MODEL_TYPE = tff.to_type(MODEL_SPEC)\n", "\n", "print(MODEL_TYPE)" ] }, { "cell_type": "markdown", "metadata": { "id": "iHhdaWSpfQxo" }, "source": [ "有了这些定义,现在我们可以在单个批次上定义给定模型的损失。请注意 `@tf.function` 装饰器在 `@tff.tf_computation` 装饰器内部的用法。通过这种用法,即使在由 `tff.tf_computation` 装饰器创建的 `tf.Graph` 上下文中,我们也可以使用类似 Python 的语义来编写 TF。" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "4EObiz_Ke0uK" }, "outputs": [], "source": [ "# NOTE: `forward_pass` is defined separately from `batch_loss` so that it can \n", "# be later called from within another tf.function. Necessary because a\n", "# @tf.function decorated method cannot invoke a @tff.tf_computation.\n", "\n", "@tf.function\n", "def forward_pass(model, batch):\n", " predicted_y = tf.nn.softmax(\n", " tf.matmul(batch['x'], model['weights']) + model['bias'])\n", " return -tf.reduce_mean(\n", " tf.reduce_sum(\n", " tf.one_hot(batch['y'], 10) * tf.math.log(predicted_y), axis=[1]))\n", "\n", "@tff.tf_computation(MODEL_TYPE, BATCH_TYPE)\n", "def batch_loss(model, batch):\n", " return forward_pass(model, batch)" ] }, { "cell_type": "markdown", "metadata": { "id": "8K0UZHGnr8SB" }, "source": [ "和预期一样,在给定模型和单个数据批次的情况下,计算 `batch_loss` 返回 `float32` 损失。请注意 `MODEL_TYPE` 和 `BATCH_TYPE` 合并为形式参数的二维元组的方式;您可以将 `batch_loss` 的类型识别为 `( -> float32)`。" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "4WXEAY8Nr89V" }, "outputs": [ { "data": { "text/plain": [ "'(<,> -> float32)'" ] }, "execution_count": 13, "metadata": { "tags": [] }, "output_type": "execute_result" } ], "source": [ "str(batch_loss.type_signature)" ] }, { "cell_type": "markdown", "metadata": { "id": "pAnt_UcdnvGa" }, "source": [ "作为健全性检查,我们来构造一个用零填充的初始模型,并计算上文中可视化的那批数据的损失。" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "U8Ne8igan3os" }, "outputs": [ { "data": { "text/plain": [ "2.3025854" ] }, "execution_count": 14, "metadata": { "tags": [] }, "output_type": "execute_result" } ], "source": [ "initial_model = collections.OrderedDict(\n", " weights=np.zeros([784, 10], dtype=np.float32),\n", " bias=np.zeros([10], dtype=np.float32))\n", "\n", "sample_batch = federated_train_data[5][-1]\n", "\n", "batch_loss(initial_model, sample_batch)" ] }, { "cell_type": "markdown", "metadata": { "id": "ckigEAyDAWFV" }, "source": [ "请注意,我们使用定义为 `dict` 的初始模型为 TFF 计算馈送数据,即便定义它的 Python 函数的主体将模型参数用作 `model['weight']` 和 `model['bias']` 。`batch_loss` 调用的参数并不是简单地传递给该函数的主体。\n", "\n", "当我们调用 `batch_loss` 时会发生什么情况?`batch_loss` 的 Python 主体已在上面的单元格中(在对其进行定义的位置)进行了跟踪和序列化。TFF 在计算定义时充当 `batch_loss` 的调用者,并在 `batch_loss` 被调用时充当调用的目标。在这两个角色中,TFF 均充当 TFF 的抽象类型系统和 Python 表示类型之间的桥梁。在调用时,TFF 将接受大多数标准 Python 容器类型(`dict`、`list`、`tuple`、`collections.namedtuple` 等),以将其作为抽象 TFF 元组的具体表示。虽然我们在上文中提到,TFF 计算在形式上仅接受单个参数,但如果参数的类型是元组,则可以将熟悉的 Python 调用语法与位置和/或关键字参数一起使用,它会按预期工作。" ] }, { "cell_type": "markdown", "metadata": { "id": "eB510nILYbId" }, "source": [ "### 单个批次上的梯度下降\n", "\n", "现在,我们来定义一个使用下面的损失函数来执行单步梯度下降的计算。请注意我们在定义此函数时,如何将 `batch_loss` 用作子组件。您可以在另一个计算的主体内部调用使用 `tff.tf_computation` 构造的计算,但正如我们在上文中提到的,您通常没有必要进行此操作。这是因为,序列化会丢失部分调试信息,因此对于不使用 `tff.tf_computation` 装饰器来编写和测试所有 TensorFlow 的更复杂的计算来说,这种方式更加可取。" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "O4uaVxw3AyYS" }, "outputs": [], "source": [ "@tff.tf_computation(MODEL_TYPE, BATCH_TYPE, tf.float32)\n", "def batch_train(initial_model, batch, learning_rate):\n", " # Define a group of model variables and set them to `initial_model`. Must\n", " # be defined outside the @tf.function.\n", " model_vars = collections.OrderedDict([\n", " (name, tf.Variable(name=name, initial_value=value))\n", " for name, value in initial_model.items()\n", " ])\n", " optimizer = tf.keras.optimizers.SGD(learning_rate)\n", "\n", " @tf.function\n", " def _train_on_batch(model_vars, batch):\n", " # Perform one step of gradient descent using loss from `batch_loss`.\n", " with tf.GradientTape() as tape:\n", " loss = forward_pass(model_vars, batch)\n", " grads = tape.gradient(loss, model_vars)\n", " optimizer.apply_gradients(\n", " zip(tf.nest.flatten(grads), tf.nest.flatten(model_vars)))\n", " return model_vars\n", "\n", " return _train_on_batch(model_vars, batch)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "Y84gQsaohC38" }, "outputs": [ { "data": { "text/plain": [ "'(<,,float32> -> )'" ] }, "execution_count": 16, "metadata": { "tags": [] }, "output_type": "execute_result" } ], "source": [ "str(batch_train.type_signature)" ] }, { "cell_type": "markdown", "metadata": { "id": "ID8xg9FCUL2A" }, "source": [ "当您在另一个此类函数的主体中调用使用 `tff.tf_computation` 装饰的 Python 函数时,内部 TFF 计算的逻辑会嵌入(本质上为内嵌)到外部计算的逻辑中。如上所述,如果要编写这两个计算,最好将内部函数(在本例中为 `batch_loss`)设置为常规 Python 或 `tf.function` 函数,而非 `tff.tf_computation` 函数。但这里我们演示了,在 `tff.tf_computation` 内部调用与其相同的函数基本上可以按预期工作。例如,如果您没有定义 `batch_loss` 的 Python 代码,而只有它的序列化 TFF 表示,则可能必须进行此操作。\n", "\n", "现在,将这个函数在初始模型上应用几次,以查看损失是否会减少。" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "8edcJTlXUULm" }, "outputs": [], "source": [ "model = initial_model\n", "losses = []\n", "for _ in range(5):\n", " model = batch_train(model, sample_batch, 0.1)\n", " losses.append(batch_loss(model, sample_batch))" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "3n1onojT1zHv" }, "outputs": [ { "data": { "text/plain": [ "[0.19690022, 0.13176313, 0.10113226, 0.082738124, 0.0703014]" ] }, "execution_count": 18, "metadata": { "tags": [] }, "output_type": "execute_result" } ], "source": [ "losses" ] }, { "cell_type": "markdown", "metadata": { "id": "EQk4Ha8PU-3P" }, "source": [ "### 本地数据序列上的梯度下降\n", "\n", "现在,由于 `batch_train` 似乎可以正常工作,我们来编写一个类似的训练函数 `local_train`,它会使用一个用户所有批次的整个序列,而不仅仅是一个批次。现在,新的计算将需要使用 `tff.SequenceType(BATCH_TYPE)` 而不是 `BATCH_TYPE`。" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "EfPD5a6QVNXM" }, "outputs": [], "source": [ "LOCAL_DATA_TYPE = tff.SequenceType(BATCH_TYPE)\n", "\n", "@tff.federated_computation(MODEL_TYPE, tf.float32, LOCAL_DATA_TYPE)\n", "def local_train(initial_model, learning_rate, all_batches):\n", "\n", " # Mapping function to apply to each batch.\n", " @tff.federated_computation(MODEL_TYPE, BATCH_TYPE)\n", " def batch_fn(model, batch):\n", " return batch_train(model, batch, learning_rate)\n", "\n", " return tff.sequence_reduce(all_batches, initial_model, batch_fn)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "sAhkS5yKUgjC" }, "outputs": [ { "data": { "text/plain": [ "'(<,float32,*> -> )'" ] }, "execution_count": 20, "metadata": { "tags": [] }, "output_type": "execute_result" } ], "source": [ "str(local_train.type_signature)" ] }, { "cell_type": "markdown", "metadata": { "id": "EYT-SiopYBtH" }, "source": [ "这段简短的代码中包含了很多细节,我们将逐一进行介绍。\n", "\n", "首先,虽然我们完全可以用 TensorFlow 实现此逻辑,像之前那样利用 `tf.data.Dataset.reduce` 来处理序列,但这次我们选择用胶水语言将此逻辑表达为 `tff.federated_computation`。我们已使用联合算子 `tff.sequence_reduce` 来执行归约。\n", "\n", "算子 `tff.sequence_reduce` 的用法类似于 `tf.data.Dataset.reduce`。您可以认为它在本质上与 `tf.data.Dataset.reduce` 相同,但是前者用于联合计算内部(您也许还记得,它不能包含 TensorFlow 代码)。它是一个模板算子,其形式参数三维元组由 `T` 型元素的*序列*、某种类型 `U` 的归约初始状态(我们将其抽象地称为*零*),以及类型 `( -> U)` 的*归约算子*(通过处理单个元素改变归约状态)组成。得到的结果是按顺序处理所有元素后归约的最终状态。在我们的示例中,归约状态是在数据前缀上训练的模型,且元素是数据批次。\n", "\n", "其次,请注意,我们再次将一个计算(`batch_train`)用作了另一个计算(`local_train`)中的组件,而非直接使用。我们不能将其用作归约算子,因为它需要一个额外参数,即学习率。为了解决这个问题,我们定义一个嵌入式联合计算 `batch_fn`,该计算绑定到其主体中 `local_train` 的参数 `learning_rate`。因此,以这种方式定义的子计算可以捕获其父级的形式参数,只要子计算未在其父级的主体之外调用。您可以将此模式视为 Python 中 `functools.partial` 的等效项。\n", "\n", "当然,以这种方式捕获 `learning_rate` 的实际含义是,在所有批次中都使用相同的学习率值。\n", "\n", "现在,我们在整个数据序列上尝试新定义的本地训练函数,该数据序列由贡献了样本批次的同一用户(数字 `5`)提供。" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "EnWFLoZGcSby" }, "outputs": [], "source": [ "locally_trained_model = local_train(initial_model, 0.1, federated_train_data[5])" ] }, { "cell_type": "markdown", "metadata": { "id": "y0UXUqGk9zoF" }, "source": [ "有效果吗?为了回答这个问题,我们需要实现评估。" ] }, { "cell_type": "markdown", "metadata": { "id": "a8WDKu6WYy__" }, "source": [ "### 本地评估\n", "\n", "下面是一种通过将所有数据批次的损失加总起来实现本地评估的方法(也可以算出平均值;我们将把它作为练习留给读者)。" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "0RiODuc6z7Ln" }, "outputs": [], "source": [ "@tff.federated_computation(MODEL_TYPE, LOCAL_DATA_TYPE)\n", "def local_eval(model, all_batches):\n", " # TODO(b/120157713): Replace with `tff.sequence_average()` once implemented.\n", " return tff.sequence_sum(\n", " tff.sequence_map(\n", " tff.federated_computation(lambda b: batch_loss(model, b), BATCH_TYPE),\n", " all_batches))" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "pH2XPEAKa4Dg" }, "outputs": [ { "data": { "text/plain": [ "'(<,*> -> float32)'" ] }, "execution_count": 23, "metadata": { "tags": [] }, "output_type": "execute_result" } ], "source": [ "str(local_eval.type_signature)" ] }, { "cell_type": "markdown", "metadata": { "id": "efX81HuE-BcO" }, "source": [ "同样,此代码演示了一些新的元素,我们将逐一进行介绍。\n", "\n", "首先,我们使用了两个新的联合算子来处理序列:一个是 `tff.sequence_map`,它接受*映射函数* `T->U` 和 `T` 的*序列*,然后发出通过逐点应用映射函数获得的 `U` 的序列;另一个是 `tff.sequence_sum`,它只是把所有元素加总起来。在这里,我们将每个数据批次映射到损失值,然后将生成的损失值加总以计算总损失。\n", "\n", "请注意,我们可以再次使用 `tff.sequence_reduce`,但这不是最佳选择,根据定义,归约过程是顺序的,而映射和求和可以并行计算。如果有选择的话,最好坚持使用不限制实现选择的算子,这样,当将来编译 TFF 计算以部署到特定环境时,就可以充分利用所有潜在机会,实现更快、扩展性更强、更节省资源的执行。\n", "\n", "其次,请注意,正如在 `local_train` 中一样,我们需要的组件函数(`batch_loss`)接受的参数比联合算子(`tff.sequence_map`)所期望的参数要多,因此我们再次定义了部分参数(内嵌),这次是通过直接将 `lambda` 封装为 `tff.federated_computation`。如果要使用 `tff.tf_computation` 将 TensorFlow 逻辑嵌入 TFF,建议将封装容器与函数一起作为参数内嵌使用。\n", "\n", "现在,看看我们的训练是否有效。" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "vPw6JSVf5q_x" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "initial_model loss = 23.025854\n", "locally_trained_model loss = 0.4348469\n" ] } ], "source": [ "print('initial_model loss =', local_eval(initial_model,\n", " federated_train_data[5]))\n", "print('locally_trained_model loss =',\n", " local_eval(locally_trained_model, federated_train_data[5]))" ] }, { "cell_type": "markdown", "metadata": { "id": "6Tvu70cnBsUf" }, "source": [ "确实,损失减少了。但如果我们根据其他用户的数据对其进行评估,会发生什么呢?" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "gjF0NYAj5wls" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "initial_model loss = 23.025854\n", "locally_trained_model loss = 74.50075\n" ] } ], "source": [ "print('initial_model loss =', local_eval(initial_model,\n", " federated_train_data[0]))\n", "print('locally_trained_model loss =',\n", " local_eval(locally_trained_model, federated_train_data[0]))" ] }, { "cell_type": "markdown", "metadata": { "id": "7WPumnRTBzUs" }, "source": [ "情况果然变得更糟了。该模型经过训练可以识别 `5`,但从未看到 `0`。这就出现了一个问题,即从全局角度来看,本地训练会对模型质量产生什么影响?" ] }, { "cell_type": "markdown", "metadata": { "id": "QJnL2mQRZKTO" }, "source": [ "### 联合评估\n", "\n", "至此,我们终于回到了联合类型和联合计算,即我们最开始讨论的主题。下面是一对源自服务器的模型的 TFF 类型定义,以及保留在客户端上的数据。" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "LjGGhpoEBh_6" }, "outputs": [], "source": [ "SERVER_MODEL_TYPE = tff.type_at_server(MODEL_TYPE)\n", "CLIENT_DATA_TYPE = tff.type_at_clients(LOCAL_DATA_TYPE)" ] }, { "cell_type": "markdown", "metadata": { "id": "4gTXV2-jZtE3" }, "source": [ "根据目前为止介绍的所有定义,在 TFF 中对联合评估的表达均为一行式,我们将模型分发给客户端,让每个客户端在其本地数据部分上调用本地评估,然后对损失进行平均。下面是一种编写方法。" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "2zChEPzEBx4T" }, "outputs": [], "source": [ "@tff.federated_computation(SERVER_MODEL_TYPE, CLIENT_DATA_TYPE)\n", "def federated_eval(model, data):\n", " return tff.federated_mean(\n", " tff.federated_map(local_eval, [tff.federated_broadcast(model), data]))" ] }, { "cell_type": "markdown", "metadata": { "id": "IWcNONNWaE0N" }, "source": [ "我们已经在更简单的场景中看到了 `tff.federated_mean` 和 `tff.federated_map` 的示例,直观来看,他们可以按照预期工作,但这部分代码并不像看上去那么简单,下面我们来仔细研究一下。\n", "\n", "首先,我们来分解一下*让每个客户端在其本地数据部分上调用本地评估*这个部分。您可能还记得前几部分的内容,`local_eval` 具有形式为 `( -> float32)` 的类型签名。\n", "\n", "联合算子 `tff.federated_map` 是一个模版,它接受二维元组作为参数,该二维元组由某种类型 `T->U` 的*映射函数*和类型 `{T}@CLIENTS` 的联合值(即,具有与映射函数的参数相同类型的成员组成)组成,并返回 `{U}@CLIENTS` 类型的结果。\n", "\n", "由于我们将 `local_eval` 作为映射函数馈送给每个客户端,因此第二个参数应为联合类型 `{}@CLIENTS`(即,根据前几部分的命名,它应该是一个联合元组)。每个客户端应将 `local_eval` 的完整参数集作为成员组成。相反,我们向它馈送的是 2 个元素的 Python `list`。这是什么情况?\n", "\n", "实际上,这是 TFF 中*隐式类型转换*的示例,它类似于您可能在其他地方遇到的隐式类型转换(例如,当您向接受 `float` 的函数馈送 `int`时)。目前很少使用隐式转换,但我们计划使它在 TFF 中更加普遍,以尽量减少样板文件。\n", "\n", "在这种情况下,应用的隐式转换在形式为 `{}@Z` 的联合元组和联合值为 `<{X}@Z,{Y}@Z>` 的元组之间等效。虽然二者是不同的类型签名,从程序员的角度来看,`Z` 中的每个设备都包含数据 `X` 和 `Y` 的两个单元。这里发生的情况与 Python 中的 `zip` 没什么区别,实际上,我们提供了一种算子 `tff.federated_zip`,使您可以显式地执行此类转换。当 `tff.federated_map` 遇到作为第二个参数的元组时,它将为您直接调用 `tff.federated_zip`。\n", "\n", "根据上述信息,您现在应该能够将表达式 `tff.federated_broadcast(model)` 识别为表示 TFF 类型 `{MODEL_TYPE}@CLIENTS` 的值,并将 `data` 识别为 TFF 类型 `{LOCAL_DATA_TYPE}@CLIENTS`(或简写为 `CLIENT_DATA_TYPE`)的值,两者通过隐式 `tff.federated_zip` 一起筛选,以形成 `tff.federated_map` 的第二个参数。\n", "\n", "如您所料,算子 `tff.federated_broadcast` 只是将数据从服务器传输到客户端。\n", "\n", "现在,我们来看看本地训练如何影响系统的平均损失。" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "tbmtJItcn94j" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "initial_model loss = 23.025852\n", "locally_trained_model loss = 54.432625\n" ] } ], "source": [ "print('initial_model loss =', federated_eval(initial_model,\n", " federated_train_data))\n", "print('locally_trained_model loss =',\n", " federated_eval(locally_trained_model, federated_train_data))" ] }, { "cell_type": "markdown", "metadata": { "id": "LQi2rGX_fK7i" }, "source": [ "确实,和预期一样,损失增加了。为了改进所有用户的模型,我们需要用每个用户自己的数据进行训练。" ] }, { "cell_type": "markdown", "metadata": { "id": "vkw9f59qfS7o" }, "source": [ "### 联合训练\n", "\n", "实现联合训练的最简单方法是进行本地训练,然后对模型进行平均。这会用到我们讨论过的相同构建块和模式,如下所示。" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "mBOC4uoG6dd-" }, "outputs": [], "source": [ "SERVER_FLOAT_TYPE = tff.type_at_server(tf.float32)\n", "\n", "\n", "@tff.federated_computation(SERVER_MODEL_TYPE, SERVER_FLOAT_TYPE,\n", " CLIENT_DATA_TYPE)\n", "def federated_train(model, learning_rate, data):\n", " return tff.federated_mean(\n", " tff.federated_map(local_train, [\n", " tff.federated_broadcast(model),\n", " tff.federated_broadcast(learning_rate), data\n", " ]))" ] }, { "cell_type": "markdown", "metadata": { "id": "z2vACMsQjzO1" }, "source": [ "请注意,在 `tff.learning` 所提供的联合平均的全功能实现中,由于多种原因(例如,裁剪更新范数的能力、用于压缩等),我们更喜欢对模型增量进行平均,而不是对模型进行平均。\n", "\n", "让我们通过进行几轮训练并比较前后的平均损失,来看看训练是否有效。" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "NLx-3rLs9jGY" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "round 0, loss=21.60552406311035\n", "round 1, loss=20.365678787231445\n", "round 2, loss=19.27480125427246\n", "round 3, loss=18.31110954284668\n", "round 4, loss=17.45725440979004\n" ] } ], "source": [ "model = initial_model\n", "learning_rate = 0.1\n", "for round_num in range(5):\n", " model = federated_train(model, learning_rate, federated_train_data)\n", " learning_rate = learning_rate * 0.9\n", " loss = federated_eval(model, federated_train_data)\n", " print('round {}, loss={}'.format(round_num, loss))" ] }, { "cell_type": "markdown", "metadata": { "id": "Z0VjSLQzlUIp" }, "source": [ "现在,为了完整起见,我们也在测试数据上运行一下,以确认我们的模型能够很好地泛化。" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "ZaZT45yFMOaM" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "initial_model test loss = 22.795593\n", "trained_model test loss = 17.278767\n" ] } ], "source": [ "print('initial_model test loss =',\n", " federated_eval(initial_model, federated_test_data))\n", "print('trained_model test loss =', federated_eval(model, federated_test_data))" ] }, { "cell_type": "markdown", "metadata": { "id": "pxlHHwLGlgFB" }, "source": [ "我们的教程到此结束。\n", "\n", "当然,我们的简化示例并不能反映您在更实际的场景中需要进行的诸多操作(例如,除了损失之外,我们没有计算指标)。我们鼓励您学习 `tff.learning` 中联合平均的[实现](https://github.com/tensorflow/federated/blob/master/tensorflow_federated/python/learning/federated_averaging.py),它是一个更完整的示例,并在其中演示了我们所建议的一些编程做法。" ] } ], "metadata": { "colab": { "collapsed_sections": [], "name": "custom_federated_algorithms_2.ipynb", "toc_visible": true }, "kernelspec": { "display_name": "Python 3", "name": "python3" } }, "nbformat": 4, "nbformat_minor": 0 }