{ "cells": [ { "cell_type": "markdown", "metadata": { "id": "m7hbib3bSGO9" }, "source": [ "**Copyright 2020 The TensorFlow Authors.**" ] }, { "cell_type": "code", "execution_count": 1, "metadata": { "cellView": "form", "execution": { "iopub.execute_input": "2020-11-12T08:14:13.347908Z", "iopub.status.busy": "2020-11-12T08:14:13.346831Z", "iopub.status.idle": "2020-11-12T08:14:13.349206Z", "shell.execute_reply": "2020-11-12T08:14:13.349617Z" }, "id": "mEE8NFIMSGO-" }, "outputs": [], "source": [ "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n", "# you may not use this file except in compliance with the License.\n", "# You may obtain a copy of the License at\n", "#\n", "# https://www.apache.org/licenses/LICENSE-2.0\n", "#\n", "# Unless required by applicable law or agreed to in writing, software\n", "# distributed under the License is distributed on an \"AS IS\" BASIS,\n", "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", "# See the License for the specific language governing permissions and\n", "# limitations under the License." ] }, { "cell_type": "markdown", "metadata": { "id": "SyiSRgdtSGPC" }, "source": [ "# Weight clustering in Keras example" ] }, { "cell_type": "markdown", "metadata": { "id": "kW3os956SGPD" }, "source": [ "
Both `strip_clustering` and a standard compression algorithm (for example, gzip) are necessary to see the compression benefits of clustering.\n",
"\n",
"首先,为 TensorFlow 创建一个可压缩模型。在这里,`strip_clustering` 会移除聚类仅在训练期间才需要的所有变量(例如用于存储簇形心和索引的 `tf.Variable`),否则这些变量会在推理期间增加模型大小。"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {
"execution": {
"iopub.execute_input": "2020-11-12T08:15:03.920508Z",
"iopub.status.busy": "2020-11-12T08:15:03.919773Z",
"iopub.status.idle": "2020-11-12T08:15:03.956258Z",
"shell.execute_reply": "2020-11-12T08:15:03.955734Z"
},
"id": "4h6tSvMzSGPd"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Saving clustered model to: /tmp/tmpsc3jb7v8.h5\n"
]
}
],
"source": [
"final_model = tfmot.clustering.keras.strip_clustering(clustered_model)\n",
"\n",
"_, clustered_keras_file = tempfile.mkstemp('.h5')\n",
"print('Saving clustered model to: ', clustered_keras_file)\n",
"tf.keras.models.save_model(final_model, clustered_keras_file, \n",
" include_optimizer=False)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "jZcotzPSVBtu"
},
"source": [
"随后,为 TFLite 创建可压缩模型。您可以将聚类模型转换为可在目标后端上运行的格式。TensorFlow Lite 是可用于部署到移动设备的示例。"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {
"execution": {
"iopub.execute_input": "2020-11-12T08:15:03.963791Z",
"iopub.status.busy": "2020-11-12T08:15:03.962922Z",
"iopub.status.idle": "2020-11-12T08:15:04.502483Z",
"shell.execute_reply": "2020-11-12T08:15:04.501909Z"
},
"id": "v2N47QW6SGPh"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/training/tracking/tracking.py:111: Model.state_updates (from tensorflow.python.keras.engine.training) is deprecated and will be removed in a future version.\n",
"Instructions for updating:\n",
"This property should not be used in TensorFlow 2.0, as updates are applied automatically.\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/training/tracking/tracking.py:111: Layer.updates (from tensorflow.python.keras.engine.base_layer) is deprecated and will be removed in a future version.\n",
"Instructions for updating:\n",
"This property should not be used in TensorFlow 2.0, as updates are applied automatically.\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"INFO:tensorflow:Assets written to: /tmp/tmp69qei5fh/assets\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Saved clustered TFLite model to: /tmp/clustered_mnist.tflite\n"
]
}
],
"source": [
"clustered_tflite_file = '/tmp/clustered_mnist.tflite'\n",
"converter = tf.lite.TFLiteConverter.from_keras_model(final_model)\n",
"tflite_clustered_model = converter.convert()\n",
"with open(clustered_tflite_file, 'wb') as f:\n",
" f.write(tflite_clustered_model)\n",
"print('Saved clustered TFLite model to:', clustered_tflite_file)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "S7amG_9XV-w9"
},
"source": [
"定义一个辅助函数,通过 gzip 实际压缩模型并测量压缩后的大小。"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {
"execution": {
"iopub.execute_input": "2020-11-12T08:15:04.508349Z",
"iopub.status.busy": "2020-11-12T08:15:04.507345Z",
"iopub.status.idle": "2020-11-12T08:15:04.510058Z",
"shell.execute_reply": "2020-11-12T08:15:04.509568Z"
},
"id": "1XJ4QBMpW5JB"
},
"outputs": [],
"source": [
"def get_gzipped_model_size(file):\n",
" # It returns the size of the gzipped model in bytes.\n",
" import os\n",
" import zipfile\n",
"\n",
" _, zipped_file = tempfile.mkstemp('.zip')\n",
" with zipfile.ZipFile(zipped_file, 'w', compression=zipfile.ZIP_DEFLATED) as f:\n",
" f.write(file)\n",
"\n",
" return os.path.getsize(zipped_file)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "INeAOWRBSGPj"
},
"source": [
"比较后可以发现,聚类使模型大小缩减至原来的**六分之一**"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {
"execution": {
"iopub.execute_input": "2020-11-12T08:15:04.514561Z",
"iopub.status.busy": "2020-11-12T08:15:04.513899Z",
"iopub.status.idle": "2020-11-12T08:15:04.530798Z",
"shell.execute_reply": "2020-11-12T08:15:04.530264Z"
},
"id": "SG1MgZCeSGPk"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Size of gzipped baseline Keras model: 78076.00 bytes\n",
"Size of gzipped clustered Keras model: 12728.00 bytes\n",
"Size of gzipped clustered TFlite model: 12126.00 bytes\n"
]
}
],
"source": [
"print(\"Size of gzipped baseline Keras model: %.2f bytes\" % (get_gzipped_model_size(keras_file)))\n",
"print(\"Size of gzipped clustered Keras model: %.2f bytes\" % (get_gzipped_model_size(clustered_keras_file)))\n",
"print(\"Size of gzipped clustered TFlite model: %.2f bytes\" % (get_gzipped_model_size(clustered_tflite_file)))"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "5TOgpEGfSGPn"
},
"source": [
"## 通过将权重聚类与训练后量化相结合,创建一个大小缩减至**八分之一**的 TFLite 模型"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "BQb50aC3SGPn"
},
"source": [
"您可以将训练后量化应用于聚类模型来获得更多好处。"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {
"execution": {
"iopub.execute_input": "2020-11-12T08:15:04.537458Z",
"iopub.status.busy": "2020-11-12T08:15:04.536246Z",
"iopub.status.idle": "2020-11-12T08:15:05.024821Z",
"shell.execute_reply": "2020-11-12T08:15:05.024188Z"
},
"id": "XyHC8euLSGPo"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"INFO:tensorflow:Assets written to: /tmp/tmpmzv1zby7/assets\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"INFO:tensorflow:Assets written to: /tmp/tmpmzv1zby7/assets\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Saved quantized and clustered TFLite model to: /tmp/tmp5yu2mobb.tflite\n",
"Size of gzipped baseline Keras model: 78076.00 bytes\n",
"Size of gzipped clustered and quantized TFlite model: 9237.00 bytes\n"
]
}
],
"source": [
"converter = tf.lite.TFLiteConverter.from_keras_model(final_model)\n",
"converter.optimizations = [tf.lite.Optimize.DEFAULT]\n",
"tflite_quant_model = converter.convert()\n",
"\n",
"_, quantized_and_clustered_tflite_file = tempfile.mkstemp('.tflite')\n",
"\n",
"with open(quantized_and_clustered_tflite_file, 'wb') as f:\n",
" f.write(tflite_quant_model)\n",
"\n",
"print('Saved quantized and clustered TFLite model to:', quantized_and_clustered_tflite_file)\n",
"print(\"Size of gzipped baseline Keras model: %.2f bytes\" % (get_gzipped_model_size(keras_file)))\n",
"print(\"Size of gzipped clustered and quantized TFlite model: %.2f bytes\" % (get_gzipped_model_size(quantized_and_clustered_tflite_file)))"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "U-yBcocGSGPv"
},
"source": [
"## 查看从 TF 到 TFLite 的准确率持久性"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Jh_pcf0XSGPv"
},
"source": [
"定义一个辅助函数,基于测试数据集评估 TFLite 模型。"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {
"execution": {
"iopub.execute_input": "2020-11-12T08:15:05.032379Z",
"iopub.status.busy": "2020-11-12T08:15:05.031614Z",
"iopub.status.idle": "2020-11-12T08:15:05.033465Z",
"shell.execute_reply": "2020-11-12T08:15:05.033912Z"
},
"id": "EJ9B7pRISGPw"
},
"outputs": [],
"source": [
"def eval_model(interpreter):\n",
" input_index = interpreter.get_input_details()[0][\"index\"]\n",
" output_index = interpreter.get_output_details()[0][\"index\"]\n",
"\n",
" # Run predictions on every image in the \"test\" dataset.\n",
" prediction_digits = []\n",
" for i, test_image in enumerate(test_images):\n",
" if i % 1000 == 0:\n",
" print('Evaluated on {n} results so far.'.format(n=i))\n",
" # Pre-processing: add batch dimension and convert to float32 to match with\n",
" # the model's input data format.\n",
" test_image = np.expand_dims(test_image, axis=0).astype(np.float32)\n",
" interpreter.set_tensor(input_index, test_image)\n",
"\n",
" # Run inference.\n",
" interpreter.invoke()\n",
"\n",
" # Post-processing: remove batch dimension and find the digit with highest\n",
" # probability.\n",
" output = interpreter.tensor(output_index)\n",
" digit = np.argmax(output()[0])\n",
" prediction_digits.append(digit)\n",
"\n",
" print('\\n')\n",
" # Compare prediction results with ground truth labels to calculate accuracy.\n",
" prediction_digits = np.array(prediction_digits)\n",
" accuracy = (prediction_digits == test_labels).mean()\n",
" return accuracy"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "0swuxbPmSGPy"
},
"source": [
"评估已被聚类和量化的模型后,您将看到从 TensorFlow 持续到 TFLite 后端的准确率。"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {
"execution": {
"iopub.execute_input": "2020-11-12T08:15:05.038770Z",
"iopub.status.busy": "2020-11-12T08:15:05.038089Z",
"iopub.status.idle": "2020-11-12T08:15:06.686250Z",
"shell.execute_reply": "2020-11-12T08:15:06.685671Z"
},
"id": "RFD4LXjpSGPz"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Evaluated on 0 results so far.\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Evaluated on 1000 results so far.\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Evaluated on 2000 results so far.\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Evaluated on 3000 results so far.\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Evaluated on 4000 results so far.\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Evaluated on 5000 results so far.\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Evaluated on 6000 results so far.\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Evaluated on 7000 results so far.\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Evaluated on 8000 results so far.\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Evaluated on 9000 results so far.\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"\n",
"Clustered and quantized TFLite test_accuracy: 0.9759\n",
"Clustered TF test accuracy: 0.9760000109672546\n"
]
}
],
"source": [
"interpreter = tf.lite.Interpreter(model_content=tflite_quant_model)\n",
"interpreter.allocate_tensors()\n",
"\n",
"test_accuracy = eval_model(interpreter)\n",
"\n",
"print('Clustered and quantized TFLite test_accuracy:', test_accuracy)\n",
"print('Clustered TF test accuracy:', clustered_model_accuracy)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "JgXTEXC7SGP1"
},
"source": [
"## 结论"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "7JhbpowqSGP1"
},
"source": [
"在本教程中,您了解了如何使用 TensorFlow Model Optimization Toolkit API 创建聚类模型。更具体地说,您已经从头至尾完成了一个端到端示例,此示例为 MNIST 创建了一个大小缩减至原来的八分之一且准确率差异最小的模型。我们鼓励您试用这项新功能,这对于在资源受限的环境中进行部署特别重要。\n"
]
}
],
"metadata": {
"colab": {
"collapsed_sections": [],
"name": "clustering_example.ipynb",
"toc_visible": true
},
"kernelspec": {
"display_name": "Python 3",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.9"
}
},
"nbformat": 4,
"nbformat_minor": 0
}