{ "cells": [ { "cell_type": "markdown", "metadata": { "id": "8vD3L4qeREvg" }, "source": [ "##### Copyright 2021 The TensorFlow Authors." ] }, { "cell_type": "code", "execution_count": 1, "metadata": { "cellView": "form", "execution": { "iopub.execute_input": "2023-11-07T21:50:12.959819Z", "iopub.status.busy": "2023-11-07T21:50:12.959526Z", "iopub.status.idle": "2023-11-07T21:50:12.964155Z", "shell.execute_reply": "2023-11-07T21:50:12.963414Z" }, "id": "qLCxmWRyRMZE" }, "outputs": [], "source": [ "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n", "# you may not use this file except in compliance with the License.\n", "# You may obtain a copy of the License at\n", "#\n", "# https://www.apache.org/licenses/LICENSE-2.0\n", "#\n", "# Unless required by applicable law or agreed to in writing, software\n", "# distributed under the License is distributed on an \"AS IS\" BASIS,\n", "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", "# See the License for the specific language governing permissions and\n", "# limitations under the License." ] }, { "cell_type": "markdown", "metadata": { "id": "4k5PoHrgJQOU" }, "source": [ "# JAX Model Conversion for TFLite\n", "\n", "## Overview\n", "\n", "Note: This API is new and is only available by installing tf-nightly via pip. It will be available in TensorFlow version 2.7. The API is also still experimental and subject to change.\n", "\n", "This CodeLab demonstrates how to build an MNIST recognition model with JAX and how to convert it to TensorFlow Lite. It also demonstrates how to optimize the JAX-converted TFLite model with post-training quantization." ] }, { "cell_type": "markdown", "metadata": { "id": "i8cfOBcjSByO" }, "source": [ "\n", " \n", " \n", " \n", " \n", "
View on TensorFlow.org\n", " Run in Google Colab View source on GitHub Download notebook
" ] }, { "cell_type": "markdown", "metadata": { "id": "lq-T8XZMJ-zv" }, "source": [ "## 先决条件\n", "\n", "建议在最新的 TensorFlow nightly pip 构建中尝试此功能。" ] }, { "cell_type": "code", "execution_count": 2, "metadata": { "execution": { "iopub.execute_input": "2023-11-07T21:50:12.968351Z", "iopub.status.busy": "2023-11-07T21:50:12.967704Z", "iopub.status.idle": "2023-11-07T21:50:52.606177Z", "shell.execute_reply": "2023-11-07T21:50:52.605155Z" }, "id": "EV04hKdrnE4f" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Collecting tf-nightly\r\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " Downloading tf_nightly-2.16.0.dev20231107-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (4.5 kB)\r\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Requirement already satisfied: absl-py>=1.0.0 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tf-nightly) (1.4.0)\r\n", "Requirement already satisfied: astunparse>=1.6.0 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tf-nightly) (1.6.3)\r\n", "Requirement already satisfied: flatbuffers>=23.5.26 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tf-nightly) (23.5.26)\r\n", "Requirement already satisfied: gast!=0.5.0,!=0.5.1,!=0.5.2,>=0.2.1 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tf-nightly) (0.5.4)\r\n", "Requirement already satisfied: google-pasta>=0.1.1 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tf-nightly) (0.2.0)\r\n", "Requirement already satisfied: h5py>=3.10.0 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tf-nightly) (3.10.0)\r\n", "Requirement already satisfied: libclang>=13.0.0 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tf-nightly) (16.0.6)\r\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Collecting ml-dtypes~=0.3.1 (from tf-nightly)\r\n", " Downloading ml_dtypes-0.3.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (20 kB)\r\n", 
"Requirement already satisfied: opt-einsum>=2.3.2 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tf-nightly) (3.3.0)\r\n", "Requirement already satisfied: packaging in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tf-nightly) (23.2)\r\n", "Requirement already satisfied: protobuf!=4.21.0,!=4.21.1,!=4.21.2,!=4.21.3,!=4.21.4,!=4.21.5,<5.0.0dev,>=3.20.3 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tf-nightly) (3.20.3)\r\n", "Requirement already satisfied: setuptools in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tf-nightly) (68.2.2)\r\n", "Requirement already satisfied: six>=1.12.0 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tf-nightly) (1.16.0)\r\n", "Requirement already satisfied: termcolor>=1.1.0 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tf-nightly) (2.3.0)\r\n", "Requirement already satisfied: typing-extensions>=3.6.6 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tf-nightly) (4.8.0)\r\n", "Requirement already satisfied: wrapt<1.15,>=1.11.0 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tf-nightly) (1.14.1)\r\n", "Requirement already satisfied: grpcio<2.0,>=1.24.3 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tf-nightly) (1.59.2)\r\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Collecting tb-nightly~=2.16.0.a (from tf-nightly)\r\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " Downloading tb_nightly-2.16.0a20231107-py3-none-any.whl.metadata (1.7 kB)\r\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Collecting tf-estimator-nightly~=2.14.0.dev (from tf-nightly)\r\n", " Downloading tf_estimator_nightly-2.14.0.dev2023080308-py2.py3-none-any.whl.metadata (1.3 kB)\r\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Collecting keras-nightly~=3.0.0.dev (from tf-nightly)\r\n", " Downloading keras_nightly-3.0.0.dev2023110703-py3-none-any.whl.metadata (5.3 kB)\r\n" ] }, { "name": "stdout", 
"output_type": "stream", "text": [ "Requirement already satisfied: tensorflow-io-gcs-filesystem>=0.23.1 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tf-nightly) (0.34.0)\r\n", "Requirement already satisfied: numpy<2.0.0,>=1.23.5 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tf-nightly) (1.26.1)\r\n", "Requirement already satisfied: wheel<1.0,>=0.23.0 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from astunparse>=1.6.0->tf-nightly) (0.41.3)\r\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Collecting rich (from keras-nightly~=3.0.0.dev->tf-nightly)\r\n", " Downloading rich-13.6.0-py3-none-any.whl.metadata (18 kB)\r\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Collecting namex (from keras-nightly~=3.0.0.dev->tf-nightly)\r\n", " Downloading namex-0.0.7-py3-none-any.whl (5.8 kB)\r\n", "Requirement already satisfied: dm-tree in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from keras-nightly~=3.0.0.dev->tf-nightly) (0.1.8)\r\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Requirement already satisfied: google-auth<3,>=1.6.3 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tb-nightly~=2.16.0.a->tf-nightly) (2.23.4)\r\n", "Requirement already satisfied: google-auth-oauthlib<2,>=0.5 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tb-nightly~=2.16.0.a->tf-nightly) (1.1.0)\r\n", "Requirement already satisfied: markdown>=2.6.8 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tb-nightly~=2.16.0.a->tf-nightly) (3.5.1)\r\n", "Requirement already satisfied: requests<3,>=2.21.0 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tb-nightly~=2.16.0.a->tf-nightly) (2.31.0)\r\n", "Requirement already satisfied: tensorboard-data-server<0.8.0,>=0.7.0 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tb-nightly~=2.16.0.a->tf-nightly) (0.7.2)\r\n", "Requirement already satisfied: werkzeug>=1.0.1 in 
/tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tb-nightly~=2.16.0.a->tf-nightly) (3.0.1)\r\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Requirement already satisfied: cachetools<6.0,>=2.0.0 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from google-auth<3,>=1.6.3->tb-nightly~=2.16.0.a->tf-nightly) (5.3.2)\r\n", "Requirement already satisfied: pyasn1-modules>=0.2.1 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from google-auth<3,>=1.6.3->tb-nightly~=2.16.0.a->tf-nightly) (0.3.0)\r\n", "Requirement already satisfied: rsa<5,>=3.1.4 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from google-auth<3,>=1.6.3->tb-nightly~=2.16.0.a->tf-nightly) (4.9)\r\n", "Requirement already satisfied: requests-oauthlib>=0.7.0 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from google-auth-oauthlib<2,>=0.5->tb-nightly~=2.16.0.a->tf-nightly) (1.3.1)\r\n", "Requirement already satisfied: importlib-metadata>=4.4 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from markdown>=2.6.8->tb-nightly~=2.16.0.a->tf-nightly) (6.8.0)\r\n", "Requirement already satisfied: charset-normalizer<4,>=2 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from requests<3,>=2.21.0->tb-nightly~=2.16.0.a->tf-nightly) (3.3.2)\r\n", "Requirement already satisfied: idna<4,>=2.5 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from requests<3,>=2.21.0->tb-nightly~=2.16.0.a->tf-nightly) (3.4)\r\n", "Requirement already satisfied: urllib3<3,>=1.21.1 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from requests<3,>=2.21.0->tb-nightly~=2.16.0.a->tf-nightly) (2.0.7)\r\n", "Requirement already satisfied: certifi>=2017.4.17 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from requests<3,>=2.21.0->tb-nightly~=2.16.0.a->tf-nightly) (2023.7.22)\r\n", "Requirement already satisfied: MarkupSafe>=2.1.1 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from werkzeug>=1.0.1->tb-nightly~=2.16.0.a->tf-nightly) (2.1.3)\r\n" ] }, { "name": 
"stdout", "output_type": "stream", "text": [ "Collecting markdown-it-py>=2.2.0 (from rich->keras-nightly~=3.0.0.dev->tf-nightly)\r\n", " Downloading markdown_it_py-3.0.0-py3-none-any.whl.metadata (6.9 kB)\r\n", "Requirement already satisfied: pygments<3.0.0,>=2.13.0 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from rich->keras-nightly~=3.0.0.dev->tf-nightly) (2.16.1)\r\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Requirement already satisfied: zipp>=0.5 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from importlib-metadata>=4.4->markdown>=2.6.8->tb-nightly~=2.16.0.a->tf-nightly) (3.17.0)\r\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Collecting mdurl~=0.1 (from markdown-it-py>=2.2.0->rich->keras-nightly~=3.0.0.dev->tf-nightly)\r\n", " Downloading mdurl-0.1.2-py3-none-any.whl (10.0 kB)\r\n", "Requirement already satisfied: pyasn1<0.6.0,>=0.4.6 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from pyasn1-modules>=0.2.1->google-auth<3,>=1.6.3->tb-nightly~=2.16.0.a->tf-nightly) (0.5.0)\r\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Requirement already satisfied: oauthlib>=3.0.0 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from requests-oauthlib>=0.7.0->google-auth-oauthlib<2,>=0.5->tb-nightly~=2.16.0.a->tf-nightly) (3.2.2)\r\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Downloading tf_nightly-2.16.0.dev20231107-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (521.7 MB)\r\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Downloading keras_nightly-3.0.0.dev2023110703-py3-none-any.whl (987 kB)\r\n", "Downloading ml_dtypes-0.3.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (206 kB)\r\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Downloading tb_nightly-2.16.0a20231107-py3-none-any.whl (5.5 MB)\r\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Downloading 
tf_estimator_nightly-2.14.0.dev2023080308-py2.py3-none-any.whl (440 kB)\r\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Downloading rich-13.6.0-py3-none-any.whl (239 kB)\r\n", "Downloading markdown_it_py-3.0.0-py3-none-any.whl (87 kB)\r\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Installing collected packages: namex, tf-estimator-nightly, ml-dtypes, mdurl, markdown-it-py, rich, tb-nightly, keras-nightly, tf-nightly\r\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " Attempting uninstall: ml-dtypes\r\n", " Found existing installation: ml-dtypes 0.2.0\r\n", " Uninstalling ml-dtypes-0.2.0:\r\n", " Successfully uninstalled ml-dtypes-0.2.0\r\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\u001b[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.\r\n", "tensorflow 2.15.0rc1 requires ml-dtypes~=0.2.0, but you have ml-dtypes 0.3.1 which is incompatible.\u001b[0m\u001b[31m\r\n", "\u001b[0mSuccessfully installed keras-nightly-3.0.0.dev2023110703 markdown-it-py-3.0.0 mdurl-0.1.2 ml-dtypes-0.3.1 namex-0.0.7 rich-13.6.0 tb-nightly-2.16.0a20231107 tf-estimator-nightly-2.14.0.dev2023080308 tf-nightly-2.16.0.dev20231107\r\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Collecting jax\r\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " Downloading jax-0.4.20-py3-none-any.whl.metadata (23 kB)\r\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Requirement already satisfied: ml-dtypes>=0.2.0 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from jax) (0.3.1)\r\n", "Requirement already satisfied: numpy>=1.22 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from jax) (1.26.1)\r\n", "Requirement already satisfied: opt-einsum in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from jax) (3.3.0)\r\n", "Requirement already satisfied: 
scipy>=1.9 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from jax) (1.11.3)\r\n", "Requirement already satisfied: importlib-metadata>=4.6 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from jax) (6.8.0)\r\n", "Requirement already satisfied: zipp>=0.5 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from importlib-metadata>=4.6->jax) (3.17.0)\r\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Downloading jax-0.4.20-py3-none-any.whl (1.7 MB)\r\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Installing collected packages: jax\r\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Successfully installed jax-0.4.20\r\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Collecting jaxlib\r\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " Downloading jaxlib-0.4.20-cp39-cp39-manylinux2014_x86_64.whl.metadata (2.1 kB)\r\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Requirement already satisfied: scipy>=1.9 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from jaxlib) (1.11.3)\r\n", "Requirement already satisfied: numpy>=1.22 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from jaxlib) (1.26.1)\r\n", "Requirement already satisfied: ml-dtypes>=0.2.0 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from jaxlib) (0.3.1)\r\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Downloading jaxlib-0.4.20-cp39-cp39-manylinux2014_x86_64.whl (85.8 MB)\r\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Installing collected packages: jaxlib\r\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Successfully installed jaxlib-0.4.20\r\n" ] } ], "source": [ "!pip install tf-nightly --upgrade\n", "!pip install jax --upgrade\n", "!pip install jaxlib --upgrade" ] }, { "cell_type": "markdown", "metadata": { "id": "QAeY43k9KM55" }, "source": [ "## Data Preparation\n", "\n", "Download the MNIST data with the Keras dataset API and preprocess it." ] }, { "cell_type": "code", 
"execution_count": 3, "metadata": { "execution": { "iopub.execute_input": "2023-11-07T21:50:52.611136Z", "iopub.status.busy": "2023-11-07T21:50:52.610811Z", "iopub.status.idle": "2023-11-07T21:50:55.516532Z", "shell.execute_reply": "2023-11-07T21:50:55.515773Z" }, "id": "qSOPSZJn1_Tj" }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "2023-11-07 21:50:53.094555: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:10778] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered\n", "2023-11-07 21:50:53.094603: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered\n", "2023-11-07 21:50:53.094963: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1533] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered\n" ] } ], "source": [ "import numpy as np\n", "import tensorflow as tf\n", "import functools\n", "\n", "import time\n", "import itertools\n", "\n", "import numpy.random as npr\n", "\n", "import jax.numpy as jnp\n", "from jax import jit, grad, random\n", "from jax.example_libraries import optimizers\n", "from jax.example_libraries import stax\n" ] }, { "cell_type": "code", "execution_count": 4, "metadata": { "execution": { "iopub.execute_input": "2023-11-07T21:50:55.520823Z", "iopub.status.busy": "2023-11-07T21:50:55.520423Z", "iopub.status.idle": "2023-11-07T21:50:56.232275Z", "shell.execute_reply": "2023-11-07T21:50:56.231205Z" }, "id": "hdJIt3Da2Qn1" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\r", "\u001b[1m 0/11490434\u001b[0m \u001b[37m━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[1m0s\u001b[0m 0s/step" 
] } ], "source": [ "def _one_hot(x, k, dtype=np.float32):\n", "  \"\"\"Create a one-hot encoding of x of size k.\"\"\"\n", "  return np.array(x[:, None] == np.arange(k), dtype)\n", "\n", "(train_images, train_labels), (test_images, test_labels) = tf.keras.datasets.mnist.load_data()\n", "train_images, test_images = train_images / 255.0, test_images / 255.0\n", "train_images = train_images.astype(np.float32)\n", "test_images = test_images.astype(np.float32)\n", "\n", "train_labels = _one_hot(train_labels, 10)\n", "test_labels = _one_hot(test_labels, 10)" ] }, { "cell_type": "markdown", "metadata": { "id": "0eFhx85YKlEY" }, "source": [ "## Build the MNIST model with JAX" ] }, { "cell_type": "code", "execution_count": 5, "metadata": { "execution": { "iopub.execute_input": "2023-11-07T21:50:56.236595Z", "iopub.status.busy": "2023-11-07T21:50:56.236330Z", "iopub.status.idle": "2023-11-07T21:50:56.300110Z", "shell.execute_reply": "2023-11-07T21:50:56.299350Z" }, "id": "mi3TKB9nnQdK" }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. 
Falling back to cpu.\n" ] } ], "source": [ "def loss(params, batch):\n", "  inputs, targets = batch\n", "  preds = predict(params, inputs)\n", "  return -jnp.mean(jnp.sum(preds * targets, axis=1))\n", "\n", "def accuracy(params, batch):\n", "  inputs, targets = batch\n", "  target_class = jnp.argmax(targets, axis=1)\n", "  predicted_class = jnp.argmax(predict(params, inputs), axis=1)\n", "  return jnp.mean(predicted_class == target_class)\n", "\n", "init_random_params, predict = stax.serial(\n", "    stax.Flatten,\n", "    stax.Dense(1024), stax.Relu,\n", "    stax.Dense(1024), stax.Relu,\n", "    stax.Dense(10), stax.LogSoftmax)\n", "\n", "rng = random.PRNGKey(0)" ] }, { "cell_type": "markdown", "metadata": { "id": "bRtnOBdJLd63" }, "source": [ "## Train and evaluate the model" ] }, { "cell_type": "code", "execution_count": 6, "metadata": { "execution": { "iopub.execute_input": "2023-11-07T21:50:56.304235Z", "iopub.status.busy": "2023-11-07T21:50:56.303957Z", "iopub.status.idle": "2023-11-07T21:51:31.014847Z", "shell.execute_reply": "2023-11-07T21:51:31.013934Z" }, "id": "SWbYRyj7LYZt" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\n", "Starting training...\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 0 in 4.47 sec\n", "Training set accuracy 0.8728833198547363\n", "Test set accuracy 0.880299985408783\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 1 in 2.45 sec\n", "Training set accuracy 0.8983833193778992\n", "Test set accuracy 0.9047999978065491\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 2 in 2.45 sec\n", "Training set accuracy 0.9102333188056946\n", "Test set accuracy 0.9138000011444092\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 3 in 2.47 sec\n", "Training set accuracy 0.9172333478927612\n", "Test set accuracy 0.9218999743461609\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 4 in 2.42 sec\n", "Training set accuracy 0.9224833250045776\n", "Test set 
accuracy 0.9253999590873718\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 5 in 2.45 sec\n", "Training set accuracy 0.9272000193595886\n", "Test set accuracy 0.9309999942779541\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 6 in 2.44 sec\n", "Training set accuracy 0.9328166842460632\n", "Test set accuracy 0.9334999918937683\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 7 in 2.41 sec\n", "Training set accuracy 0.9360166788101196\n", "Test set accuracy 0.9370999932289124\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 8 in 2.44 sec\n", "Training set accuracy 0.939050018787384\n", "Test set accuracy 0.939300000667572\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 9 in 2.45 sec\n", "Training set accuracy 0.9425666928291321\n", "Test set accuracy 0.9429000020027161\n" ] } ], "source": [ "step_size = 0.001\n", "num_epochs = 10\n", "batch_size = 128\n", "momentum_mass = 0.9\n", "\n", "\n", "num_train = train_images.shape[0]\n", "num_complete_batches, leftover = divmod(num_train, batch_size)\n", "num_batches = num_complete_batches + bool(leftover)\n", "\n", "def data_stream():\n", "  rng = npr.RandomState(0)\n", "  while True:\n", "    perm = rng.permutation(num_train)\n", "    for i in range(num_batches):\n", "      batch_idx = perm[i * batch_size:(i + 1) * batch_size]\n", "      yield train_images[batch_idx], train_labels[batch_idx]\n", "batches = data_stream()\n", "\n", "opt_init, opt_update, get_params = optimizers.momentum(step_size, mass=momentum_mass)\n", "\n", "@jit\n", "def update(i, opt_state, batch):\n", "  params = get_params(opt_state)\n", "  return opt_update(i, grad(loss)(params, batch), opt_state)\n", "\n", "_, init_params = init_random_params(rng, (-1, 28 * 28))\n", "opt_state = opt_init(init_params)\n", "itercount = itertools.count()\n", "\n", "print(\"\\nStarting training...\")\n", "for epoch in range(num_epochs):\n", "  start_time = time.time()\n", "  for _ in 
range(num_batches):\n", "    opt_state = update(next(itercount), opt_state, next(batches))\n", "  epoch_time = time.time() - start_time\n", "\n", "  params = get_params(opt_state)\n", "  train_acc = accuracy(params, (train_images, train_labels))\n", "  test_acc = accuracy(params, (test_images, test_labels))\n", "  print(\"Epoch {} in {:0.2f} sec\".format(epoch, epoch_time))\n", "  print(\"Training set accuracy {}\".format(train_acc))\n", "  print(\"Test set accuracy {}\".format(test_acc))" ] }, { "cell_type": "markdown", "metadata": { "id": "7Y1OZBhfQhOj" }, "source": [ "## Convert to a TFLite model\n", "\n", "The conversion involves the following steps:\n", "\n", "1. Inline the params into the JAX `predict` function with `functools.partial`.\n", "2. Build a `jnp.zeros` array, which serves as a \"placeholder\" tensor that JAX uses to trace the model.\n", "3. Call `experimental_from_jax`:\n", "\n", "> - The `serving_func` is wrapped in a list.\n", "> - The input is associated with a given name and passed in as an array wrapped in a list.\n", "\n", "Note: recent TensorFlow nightly builds mark `experimental_from_jax` as deprecated and recommend `jax2tf.convert` together with `lite.TFLiteConverter.from_saved_model` or `lite.TFLiteConverter.from_concrete_functions` instead.\n" ] }, { "cell_type": "code", "execution_count": 7, "metadata": { "execution": { "iopub.execute_input": "2023-11-07T21:51:31.018811Z", "iopub.status.busy": "2023-11-07T21:51:31.018528Z", "iopub.status.idle": "2023-11-07T21:51:31.426782Z", "shell.execute_reply": "2023-11-07T21:51:31.425974Z" }, "id": "6pcqKZqdNTmn" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "WARNING:tensorflow:From /tmpfs/tmp/ipykernel_38895/3025848042.py:3: TFLiteConverterV2.experimental_from_jax (from tensorflow.lite.python.lite) is deprecated and will be removed in a future version.\n", "Instructions for updating:\n", "Use `jax2tf.convert` and (`lite.TFLiteConverter.from_saved_model` or `lite.TFLiteConverter.from_concrete_functions`) instead.\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "2023-11-07 21:51:31.217318: W tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc:378] Ignored output_format.\n", "2023-11-07 21:51:31.217364: W tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc:381] Ignored drop_control_dependency.\n", "2023-11-07 21:51:31.217371: W 
tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc:387] Ignored change_concat_input_ranges.\n", "Summary on the non-converted ops:\n", "---------------------------------\n", " * Accepted dialects: tfl, builtin, func\n", " * Non-Converted Ops: 7, Total Ops 18, % non-converted = 38.89 %\n", " * 7 ARITH ops\n", "\n", "- arith.constant: 7 occurrences (f32: 6, i32: 1)\n", "\n", "\n", "\n", " (f32: 2)\n", " (f32: 3)\n", " (f32: 1)\n", "\n", " (f32: 1)\n" ] } ], "source": [ "serving_func = functools.partial(predict, params)\n", "x_input = jnp.zeros((1, 28, 28))\n", "converter = tf.lite.TFLiteConverter.experimental_from_jax(\n", "    [serving_func], [[('input1', x_input)]])\n", "tflite_model = converter.convert()\n", "with open('jax_mnist.tflite', 'wb') as f:\n", "  f.write(tflite_model)" ] }, { "cell_type": "markdown", "metadata": { "id": "sqEhzaJPSPS1" }, "source": [ "## Check the converted TFLite model\n", "\n", "Compare the converted model's results with those of the JAX model." ] }, { "cell_type": "code", "execution_count": 8, "metadata": { "execution": { "iopub.execute_input": "2023-11-07T21:51:31.430858Z", "iopub.status.busy": "2023-11-07T21:51:31.430276Z", "iopub.status.idle": "2023-11-07T21:51:31.796277Z", "shell.execute_reply": "2023-11-07T21:51:31.795267Z" }, "id": "acj2AYzjSlaY" }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "INFO: Created TensorFlow Lite XNNPACK delegate for CPU.\n" ] } ], "source": [ "expected = serving_func(train_images[0:1])\n", "\n", "# Run the model with TensorFlow Lite\n", "interpreter = tf.lite.Interpreter(model_content=tflite_model)\n", "interpreter.allocate_tensors()\n", "input_details = interpreter.get_input_details()\n", "output_details = interpreter.get_output_details()\n", "interpreter.set_tensor(input_details[0][\"index\"], train_images[0:1, :, :])\n", "interpreter.invoke()\n", "result = interpreter.get_tensor(output_details[0][\"index\"])\n", "\n", "# Assert if the result of TFLite model is consistent with the JAX model.\n", 
"np.testing.assert_almost_equal(expected, result, 1e-5)" ] }, { "cell_type": "markdown", "metadata": { "id": "Qy9Gp4H2SjBL" }, "source": [ "## 优化模型\n", "\n", "我们将提供一个 `representative_dataset` 来进行训练后量化,以优化模型。\n" ] }, { "cell_type": "code", "execution_count": 9, "metadata": { "execution": { "iopub.execute_input": "2023-11-07T21:51:31.800275Z", "iopub.status.busy": "2023-11-07T21:51:31.799965Z", "iopub.status.idle": "2023-11-07T21:51:33.164594Z", "shell.execute_reply": "2023-11-07T21:51:33.163825Z" }, "id": "KI0rLV-Meg-2" }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "2023-11-07 21:51:31.978225: W tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc:378] Ignored output_format.\n", "2023-11-07 21:51:31.978276: W tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc:381] Ignored drop_control_dependency.\n", "2023-11-07 21:51:31.978283: W tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc:387] Ignored change_concat_input_ranges.\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "Summary on the non-converted ops:\n", "---------------------------------\n", " * Accepted dialects: tfl, builtin, func\n", " * Non-Converted Ops: 7, Total Ops 18, % non-converted = 38.89 %\n", " * 7 ARITH ops\n", "\n", "- arith.constant: 7 occurrences (f32: 6, i32: 1)\n", "\n", "\n", "\n", " (f32: 2)\n", " (f32: 3)\n", " (f32: 1)\n", "\n", " (f32: 1)\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "2023-11-07 21:51:32.321877: W tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc:378] Ignored output_format.\n", "2023-11-07 21:51:32.321927: W tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc:381] Ignored drop_control_dependency.\n", "2023-11-07 21:51:32.321934: W tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc:387] Ignored change_concat_input_ranges.\n", "Summary on the non-converted ops:\n", "---------------------------------\n", " * Accepted dialects: tfl, 
builtin, func\n", " * Non-Converted Ops: 7, Total Ops 18, % non-converted = 38.89 %\n", " * 7 ARITH ops\n", "\n", "- arith.constant: 7 occurrences (f32: 6, i32: 1)\n", "\n", "\n", "\n", " (f32: 2)\n", " (f32: 3)\n", " (f32: 1)\n", "\n", " (f32: 1)\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "fully_quantize: 0, inference_type: 6, input_inference_type: FLOAT32, output_inference_type: FLOAT32\n" ] } ], "source": [ "def representative_dataset():\n", "  for i in range(1000):\n", "    x = train_images[i:i+1]\n", "    yield [x]\n", "\n", "converter = tf.lite.TFLiteConverter.experimental_from_jax(\n", "    [serving_func], [[('x', x_input)]])\n", "tflite_model = converter.convert()\n", "converter.optimizations = [tf.lite.Optimize.DEFAULT]\n", "converter.representative_dataset = representative_dataset\n", "converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]\n", "tflite_quant_model = converter.convert()\n", "with open('jax_mnist_quant.tflite', 'wb') as f:\n", "  f.write(tflite_quant_model)" ] }, { "cell_type": "markdown", "metadata": { "id": "15xQR3JZS8TV" }, "source": [ "## Evaluate the optimized model" ] }, { "cell_type": "code", "execution_count": 10, "metadata": { "execution": { "iopub.execute_input": "2023-11-07T21:51:33.168420Z", "iopub.status.busy": "2023-11-07T21:51:33.168126Z", "iopub.status.idle": "2023-11-07T21:51:33.179335Z", "shell.execute_reply": "2023-11-07T21:51:33.178727Z" }, "id": "X3oOm0OaevD6" }, "outputs": [], "source": [ "expected = serving_func(train_images[0:1])\n", "\n", "# Run the model with TensorFlow Lite\n", "interpreter = tf.lite.Interpreter(model_content=tflite_quant_model)\n", "interpreter.allocate_tensors()\n", "input_details = interpreter.get_input_details()\n", "output_details = interpreter.get_output_details()\n", "interpreter.set_tensor(input_details[0][\"index\"], train_images[0:1, :, :])\n", "interpreter.invoke()\n", "result = interpreter.get_tensor(output_details[0][\"index\"])\n", "\n", "# Assert if the result of TFLite 
model is consistent with the JAX model.\n", "# Note: the third positional argument is `decimal`, so passing 1e-5 gives a loose tolerance (about 1.5), not 1e-5.\n", "np.testing.assert_almost_equal(expected, result, 1e-5)" ] }, { "cell_type": "markdown", "metadata": { "id": "QqHXCNa3myor" }, "source": [ "## Compare the quantized model size\n", "\n", "The quantized model should be roughly a quarter of the size of the original model." ] }, { "cell_type": "code", "execution_count": 11, "metadata": { "execution": { "iopub.execute_input": "2023-11-07T21:51:33.182595Z", "iopub.status.busy": "2023-11-07T21:51:33.182323Z", "iopub.status.idle": "2023-11-07T21:51:33.456236Z", "shell.execute_reply": "2023-11-07T21:51:33.455136Z" }, "id": "imFPw007juVG" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "7.2M\tjax_mnist.tflite\r\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "1.8M\tjax_mnist_quant.tflite\r\n" ] } ], "source": [ "!du -h jax_mnist.tflite\n", "!du -h jax_mnist_quant.tflite" ] } ], "metadata": { "colab": { "collapsed_sections": [], "name": "overview.ipynb", "toc_visible": true }, "kernelspec": { "display_name": "Python 3", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.18" } }, "nbformat": 4, "nbformat_minor": 0 }