{ "cells": [ { "cell_type": "markdown", "metadata": { "id": "h2q27gKz1H20" }, "source": [ "##### Copyright 2019 The TensorFlow Authors." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "cellView": "form", "id": "TUfAcER1oUS6" }, "outputs": [], "source": [ "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n", "# you may not use this file except in compliance with the License.\n", "# You may obtain a copy of the License at\n", "#\n", "# https://www.apache.org/licenses/LICENSE-2.0\n", "#\n", "# Unless required by applicable law or agreed to in writing, software\n", "# distributed under the License is distributed on an \"AS IS\" BASIS,\n", "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", "# See the License for the specific language governing permissions and\n", "# limitations under the License." ] }, { "cell_type": "markdown", "metadata": { "id": "Gb7qyhNL1yWt" }, "source": [ "# 使用 TensorFlow Lite Model Maker 进行图像分类" ] }, { "cell_type": "markdown", "metadata": { "id": "nDABAblytltI" }, "source": [ "\n", " \n", " \n", " \n", " \n", " \n", "
在 TensorFlow.org 上查看 在 Google Colab 中运行 在 Github 上查看源代码 下载笔记本 查看 TF Hub 模型
" ] }, { "cell_type": "markdown", "metadata": { "id": "m86-Nh4pMHqY" }, "source": [ "为设备端 ML 应用部署此模型时,[TensorFlow Lite Model Maker 库](https://tensorflow.google.cn/lite/models/modify/model_maker)可以简化将 TensorFlow 神经网络模型适配和转换为特定输入数据的过程。\n", "\n", "此笔记本展示了一个端到端示例,该示例使用 Model Maker 库演示了如何调整和转换在移动设备上对花卉进行分类的常用图像分类模型。" ] }, { "cell_type": "markdown", "metadata": { "id": "bcLF2PKkSbV3" }, "source": [ "## 前提条件\n", "\n", "要运行此示例,我们首先需要安装几个所需的软件包,包括 GitHub [仓库](https://github.com/tensorflow/examples/tree/master/tensorflow_examples/lite/model_maker)中的 Model Maker 软件包。" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "6cv3K3oaksJv" }, "outputs": [], "source": [ "!sudo apt -y install libportaudio2\n", "!pip install -q tflite-model-maker" ] }, { "cell_type": "markdown", "metadata": { "id": "Gx1HGRoFQ54j" }, "source": [ "导入所需的软件包。" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "XtxiUeZEiXpt" }, "outputs": [], "source": [ "import os\n", "\n", "import numpy as np\n", "\n", "import tensorflow as tf\n", "assert tf.__version__.startswith('2')\n", "\n", "from tflite_model_maker import model_spec\n", "from tflite_model_maker import image_classifier\n", "from tflite_model_maker.config import ExportFormat\n", "from tflite_model_maker.config import QuantizationConfig\n", "from tflite_model_maker.image_classifier import DataLoader\n", "\n", "import matplotlib.pyplot as plt" ] }, { "cell_type": "markdown", "metadata": { "id": "KKRaYHABpob5" }, "source": [ "## 简单的端到端示例" ] }, { "cell_type": "markdown", "metadata": { "id": "SiZZ5DHXotaW" }, "source": [ "### 获取数据路径\n", "\n", "让我们获取一些图像来试验一下这个简单的端到端示例。对 Model Maker 来说,数百个图像是好的开始,但更多数据可以获得更高的准确率。" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "cellView": "form", "id": "3jz5x0JoskPv" }, "outputs": [], "source": [ "image_path = tf.keras.utils.get_file(\n", " 'flower_photos.tgz',\n", " 'https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz',\n", " extract=True)\n", "image_path = os.path.join(os.path.dirname(image_path), 'flower_photos')" ] }, { "cell_type": "markdown", "metadata": { "id": "a55MR6i6nuDm" }, "source": [ "您可以将 `image_path` 替换为自己的图像文件夹。对于将数据上传到 Colab,您可以在左侧边栏中找到上传按钮(如下图中的红色方框所示)。您可以尝试上传一个 Zip 文件并将其解压缩。根文件路径为当前路径。\n", "\n", " \"上传文件\" " ] }, { "cell_type": "markdown", "metadata": { "id": "NNRNv_mloS89" }, "source": [ "如果您不想将图像上传到云端,也可以按照 GitHub 中的[指南](https://github.com/tensorflow/examples/tree/master/tensorflow_examples/lite/model_maker)尝试在本地运行库。" ] }, { "cell_type": "markdown", "metadata": { "id": "w-VDriAdsowu" }, "source": [ "### 运行示例\n", "\n", "如下所示,该示例包含 4 行代码,每行代码表示整个过程中的一个步骤。\n" ] }, { "cell_type": "markdown", "metadata": { "id": "6ahtcO86tZBL" }, "source": [ "第 1 步:加载特定于设备端 ML 应用的输入数据,并将其拆分为训练数据和测试数据。" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "lANoNS_gtdH1" }, "outputs": [], "source": [ "data = DataLoader.from_folder(image_path)\n", "train_data, test_data = data.split(0.9)" ] }, { "cell_type": "markdown", "metadata": { "id": "Y_9IWyIztuRF" }, "source": [ "第 2 步:自定义 TensorFlow 模型。" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "yRXMZbrwtyRD" }, "outputs": [], "source": [ "model = image_classifier.create(train_data)" ] }, { "cell_type": "markdown", "metadata": { "id": "oxU2fDr-t2Ya" }, "source": [ "第 3 步:评估模型。" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "wQr02VxJt6Cs" }, "outputs": [], "source": [ "loss, accuracy = model.evaluate(test_data)" ] }, { "cell_type": "markdown", "metadata": { "id": "eVZw9zU8t84y" }, "source": [ "第 4 步:导出为 TensorFlow Lite 模型。\n", "\n", "在这里,我们导出带有[元数据](https://tensorflow.google.cn/lite/models/convert/metadata)的 TensorFlow Lite 模型,它提供了模型描述的标准。标签文件嵌入在元数据中。默认的训练后量化技术是图像分类任务的全整数量化。\n", "\n", "与上传部分相同,您可以在左侧边栏下载它供您自己使用。" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "Zb-eIzfluCoa" }, "outputs": [], "source": [ "model.export(export_dir='.')" ] }, { "cell_type": "markdown", "metadata": { "id": "pyju1qc_v-wy" }, "source": [ "完成上述 4 个简单步骤后,我们可以在设备端应用(如[图像分类](https://github.com/tensorflow/examples/tree/master/lite/examples/image_classification)参考应用)中进一步使用 TensorFlow Lite 模型文件。" ] }, { "cell_type": "markdown", "metadata": { "id": "R1QG32ivs9lF" }, "source": [ "## 详细流程\n", "\n", "目前,我们支持多种用于图像分类的预训练模型(例如 EfficientNet-Lite*、MobileNetV2 和 ResNet50 模型)。向此库添加新的预训练模型也很灵活,您只需编写几行代码。\n", "\n", "下文对此端到端示例进行了逐步介绍,以展示更多详细信息。" ] }, { "cell_type": "markdown", "metadata": { "id": "ygEncJxtl-nQ" }, "source": [ "### 第 1 步:加载特定于设备端 ML 应用的输入数据\n", "\n", "花卉数据集包含分属于 5 个类的 3670 个图像。下载数据集的存档版本并解压缩。\n", "\n", "该数据集具有以下目录结构:\n", "\n", "
<b>flower_photos</b>\n",
        "|__ <b>daisy</b>\n",
        "    |______ 100080576_f52e8ee070_n.jpg\n",
        "    |______ 14167534527_781ceb1b7a_n.jpg\n",
        "    |______ ...\n",
        "|__ <b>dandelion</b>\n",
        "    |______ 10043234166_e6dd915111_n.jpg\n",
        "    |______ 1426682852_e62169221f_m.jpg\n",
        "    |______ ...\n",
        "|__ <b>roses</b>\n",
        "    |______ 102501987_3cdb8e5394_n.jpg\n",
        "    |______ 14982802401_a3dfb22afb.jpg\n",
        "    |______ ...\n",
        "|__ <b>sunflowers</b>\n",
        "    |______ 12471791574_bb1be83df4.jpg\n",
        "    |______ 15122112402_cafa41934f.jpg\n",
        "    |______ ...\n",
        "|__ <b>tulips</b>\n",
        "    |______ 13976522214_ccec508fe7.jpg\n",
        "    |______ 14487943607_651e8062a1_m.jpg\n",
        "    |______ ...\n",
        "
" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "7tOfUr2KlgpU" }, "outputs": [], "source": [ "image_path = tf.keras.utils.get_file(\n", " 'flower_photos.tgz',\n", " 'https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz',\n", " extract=True)\n", "image_path = os.path.join(os.path.dirname(image_path), 'flower_photos')" ] }, { "cell_type": "markdown", "metadata": { "id": "E051HBUM5owi" }, "source": [ "使用 `DataLoader` 类加载数据。\n", "\n", "对于 `from_folder()` 方法,它可以从文件夹加载数据。它假定同一类的图像数据位于相同的子目录中,且子文件夹名称为类名称。目前支持 JPEG 编码的图像和 PNG 编码的图像。" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "I_fOlZsklmlL" }, "outputs": [], "source": [ "data = DataLoader.from_folder(image_path)" ] }, { "cell_type": "markdown", "metadata": { "id": "u501eT4koURB" }, "source": [ "将它拆分为训练数据 (80%)、验证数据(10%,可选)和测试数据 (10%)。" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "cY4UU5SUobtJ" }, "outputs": [], "source": [ "train_data, rest_data = data.split(0.8)\n", "validation_data, test_data = rest_data.split(0.5)" ] }, { "cell_type": "markdown", "metadata": { "id": "Z9_MYPie3EMO" }, "source": [ "显示 25 个带标签的图像样本。" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "Ih4Wx44I482b" }, "outputs": [], "source": [ "plt.figure(figsize=(10,10))\n", "for i, (image, label) in enumerate(data.gen_dataset().unbatch().take(25)):\n", " plt.subplot(5,5,i+1)\n", " plt.xticks([])\n", " plt.yticks([])\n", " plt.grid(False)\n", " plt.imshow(image.numpy(), cmap=plt.cm.gray)\n", " plt.xlabel(data.index_to_label[label.numpy()])\n", "plt.show()" ] }, { "cell_type": "markdown", "metadata": { "id": "AWuoensX4vDA" }, "source": [ "### 第 2 步:自定义 TensorFlow 模型\n", "\n", "根据加载的数据创建自定义图像分类器模型。默认模型为 EfficientNet-Lite0。\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "TvYSUuJY3QxR" }, "outputs": [], "source": [ "model = image_classifier.create(train_data, validation_data=validation_data)" ] }, { "cell_type": "markdown", "metadata": { "id": "4JFOKWnH9x8_" }, "source": [ "看一下详细的模型结构。" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "QNXAfjl192dC" }, "outputs": [], "source": [ "model.summary()" ] }, { "cell_type": "markdown", "metadata": { "id": "LP5FPk_tOxoZ" }, "source": [ "### 第 3 步:评估自定义的模型\n", "\n", "评估模型的结果,获得模型的损失和准确率。" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "A8c2ZQ0J3Riy" }, "outputs": [], "source": [ "loss, accuracy = model.evaluate(test_data)" ] }, { "cell_type": "markdown", "metadata": { "id": "6ZCrYOWoCt05" }, "source": [ "我们可以对 100 个测试图像的预测结果进行绘制。红色的预测标签为错误的预测结果,其余为正确的预测。" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "n9O9Kx7nDQWD" }, "outputs": [], "source": [ "# A helper function that returns 'red'/'black' depending on if its two input\n", "# parameter matches or not.\n", "def get_label_color(val1, val2):\n", " if val1 == val2:\n", " return 'black'\n", " else:\n", " return 'red'\n", "\n", "# Then plot 100 test images and their predicted labels.\n", "# If a prediction result is different from the label provided label in \"test\"\n", "# dataset, we will highlight it in red color.\n", "plt.figure(figsize=(20, 20))\n", "predicts = model.predict_top_k(test_data)\n", "for i, (image, label) in enumerate(test_data.gen_dataset().unbatch().take(100)):\n", " ax = plt.subplot(10, 10, i+1)\n", " plt.xticks([])\n", " plt.yticks([])\n", " plt.grid(False)\n", " plt.imshow(image.numpy(), cmap=plt.cm.gray)\n", "\n", " predict_label = predicts[i][0][0]\n", " color = get_label_color(predict_label,\n", " test_data.index_to_label[label.numpy()])\n", " ax.xaxis.label.set_color(color)\n", " plt.xlabel('Predicted: %s' % predict_label)\n", "plt.show()" ] }, { "cell_type": "markdown", "metadata": { "id": "S3H0rkbLUZAG" }, "source": [ "如果准确率不符合应用要求,可以参考[高级用法](#scrollTo=zNDBP2qA54aK)来探索替代方法(例如更改为更大的模型、调整重新训练参数等)。" ] }, { "cell_type": "markdown", "metadata": { "id": "aeHoGAceO2xV" }, "source": [ "### 第 4 步:导出为 TensorFlow Lite 模型\n", "\n", "将经过训练的模型转换为带有[元数据](https://tensorflow.google.cn/lite/models/convert/metadata)的 TensorFlow Lite 模型格式,以便您以后可以在设备端 ML 应用中使用。标签文件和词汇文件嵌入在元数据中。默认的 TFLite 文件名是 `model.tflite`。\n", "\n", "在许多设备端 ML 应用中,模型大小是一个重要因素。因此,建议您应用量化模型以使其更小并可能加快运行速度。对于图像分类任务,默认的训练后量化技术是全整数量化。" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "Im6wA9lK3TQB" }, "outputs": [], "source": [ "model.export(export_dir='.')" ] }, { "cell_type": "markdown", "metadata": { "id": "ROS2Ay2jMPCl" }, "source": [ "有关如何将 TensorFlow Lite 模型集成到移动应用中的更多详细信息,请参阅图片分类[示例指南](https://tensorflow.google.cn/lite/examples/image_classification/overview)。\n", "\n", "可以使用 [TensorFlow Lite Task Library](https://tensorflow.google.cn/lite/inference_with_metadata/task_library/overview) 的 [ImageClassifier API](https://tensorflow.google.cn/lite/inference_with_metadata/task_library/image_classifier) 将此模型集成到 Android 或 iOS 应用中。" ] }, { "cell_type": "markdown", "metadata": { "id": "habFnvRxxQ4A" }, "source": [ "允许的导出格式可以是以下列表中的一个或多个:\n", "\n", "- `ExportFormat.TFLITE`\n", "- `ExportFormat.LABEL`\n", "- `ExportFormat.SAVED_MODEL`\n", "\n", "默认情况下,它仅导出带有元数据的 TensorFlow Lite 模型。您也可以有选择地导出不同的文件。例如,仅导出标签文件,如下所示:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "BvxWsOTmKG4P" }, "outputs": [], "source": [ "model.export(export_dir='.', export_format=ExportFormat.LABEL)" ] }, { "cell_type": "markdown", "metadata": { "id": "-4jQaxyT5_KV" }, "source": [ "您还可以使用 `evaluate_tflite` 方法评估 tflite 模型。" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "S1YoPX5wOK-u" }, "outputs": [], "source": [ "model.evaluate_tflite('model.tflite', test_data)" ] }, { "cell_type": "markdown", "metadata": { "id": "zNDBP2qA54aK" }, "source": [ "## 高级用法\n", "\n", "`create` 函数是此库的关键部分。它使用迁移学习以及与[教程](https://tensorflow.google.cn/tutorials/images/transfer_learning)类似的预训练模型。\n", "\n", "`create` 函数包含以下步骤:\n", "\n", "1. 根据参数 `validation_ratio` 和 `test_ratio` 将数据拆分为训练数据、验证数据和测试数据。`validation_ratio` 和 `test_ratio` 的默认值分别是 `0.1` 和 `0.1`。\n", "2. 从 TensorFlow Hub 下载[图像特征向量](https://tensorflow.google.cn/hub/common_signatures/images#image_feature_vector)作为基础模型。默认的预训练模型是 EfficientNet-Lite0。\n", "3. 在头层和预训练模型之间添加一个包含带 `dropout_rate` 的随机失活层的分类器头。默认的 `dropout_rate` 是 TensorFlow Hub 上 [make_image_classifier_lib](https://github.com/tensorflow/hub/blob/master/tensorflow_hub/tools/make_image_classifier/make_image_classifier_lib.py#L55) 中的默认 `dropout_rate` 值。\n", "4. 预处理原始输入数据。目前,预处理步骤包括将每个图像像素的值归一化为模型输入尺度,并将其调整为模型输入大小。EfficientNet-Lite0 的输入尺度为 `[0, 1]`,输入图像大小为 `[224, 224, 3]`。\n", "5. 将数据馈送到分类器模型中。默认情况下,训练参数(例如训练周期、批次大小、学习率、动量)是 TensorFlow Hub 上 [make_image_classifier_lib](https://github.com/tensorflow/hub/blob/master/tensorflow_hub/tools/make_image_classifier/make_image_classifier_lib.py#L55) 中的默认值。只有分类器头经过了训练。\n", "\n", "在本部分中,我们介绍了几个高级主题,包括切换到不同的图像分类器模型、更改训练超参数等。\n" ] }, { "cell_type": "markdown", "metadata": { "id": "Gc4Jk8TvBQfm" }, "source": [ "## 在 TensorFlow Lite 模型上自定义训练后量化\n" ] }, { "cell_type": "markdown", "metadata": { "id": "tD8BOYrHBiDt" }, "source": [ "[训练后量化](https://tensorflow.google.cn/lite/performance/post_training_quantization)是一种转换技术,可以缩减模型大小并缩短推断延迟,同时改善 CPU 和硬件加速器推断速度,且几乎不会降低模型准确率。因此,它被广泛用于优化模型。\n" ] }, { "cell_type": "markdown", "metadata": { "id": "iyIo0d5TCzE2" }, "source": [ "Model Maker 库在导出模型时会应用默认的训练后量化技术。如果您想自定义训练后量化,Model Maker 也支持使用 [QuantizationConfig](https://tensorflow.google.cn/lite/api_docs/python/tflite_model_maker/config/QuantizationConfig) 的多个训练后量化选项。我们以 float16 量化为例。首先,定义量化配置。" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "k8hL2mstCxQl" }, "outputs": [], "source": [ "config = QuantizationConfig.for_float16()" ] }, { "cell_type": "markdown", "metadata": { "id": "K1gzx_rmFMOA" }, "source": [ "然后,我们使用此配置导出 TensorFlow Lite 模型。" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "WTJzFQnJFMjr" }, "outputs": [], "source": [ "model.export(export_dir='.', tflite_filename='model_fp16.tflite', quantization_config=config)" ] }, { "cell_type": "markdown", "metadata": { "id": "Safo0e40wKZW" }, "source": [ "在 Colab 中,您可以从左侧边栏下载名为 `model_fp16.tflite` 的模型,与上文中的上传部分相同。" ] }, { "cell_type": "markdown", "metadata": { "id": "A4kiTJtZ_sDm" }, "source": [ "## 更改模型\n" ] }, { "cell_type": "markdown", "metadata": { "id": "794vgj6ud7Ep" }, "source": [ "### 更改为此库所支持的模型。\n", "\n", "此库目前支持 EfficientNet-Lite、MobileNetV2 和 ResNet50 模型。[EfficientNet-Lite](https://github.com/tensorflow/tpu/tree/master/models/official/efficientnet/lite) 是一系列图像分类模型,可以实现最先进的准确率,并适用于 Edge 设备。默认模型为 EfficientNet-Lite0。\n", "\n", "我们只需将 `create` 方法中的参数 `model_spec` 设置为 MobileNetV2 模型规范,即可将模型切换为 MobileNetV2。" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "7JKsJ6-P6ae1" }, "outputs": [], "source": [ "model = image_classifier.create(train_data, model_spec=model_spec.get('mobilenet_v2'), validation_data=validation_data)" ] }, { "cell_type": "markdown", "metadata": { "id": "gm_B1Wv08AxR" }, "source": [ "评估新近重新训练的 MobileNetV2 模型,查看测试数据中的准确率和损失。" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "lB2Go3HW8X7_" }, "outputs": [], "source": [ "loss, accuracy = model.evaluate(test_data)" ] }, { "cell_type": "markdown", "metadata": { "id": "vAciGzVWtmWp" }, "source": [ "### 更改为 TensorFlow Hub 中的模型\n", "\n", "此外,我们还可以切换为其他新模型,这些模型输入图像并输出具有 TensorFlow Hub 格式的特征向量。\n", "\n", "以 [Inception V3](https://tfhub.dev/google/imagenet/inception_v3/feature_vector/1) 模型为例,我们可以定义 `inception_v3_spec`,它是 [image_classifier.ModelSpec](https://tensorflow.google.cn/lite/api_docs/python/tflite_model_maker/image_classifier/ModelSpec) 的对象,且包含 Inception V3 模型的规范。\n", "\n", "我们需要指定模型名称 `name`,以及 TensorFlow Hub 模型的网址 `uri`。同时,`input_image_shape` 的默认值为 `[224, 224]`。对于 Inception V3 模型,我们需要将其更改为 `[299, 299]`。" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "xdiMF2WMfAR4" }, "outputs": [], "source": [ "inception_v3_spec = image_classifier.ModelSpec(\n", " uri='https://tfhub.dev/google/imagenet/inception_v3/feature_vector/1')\n", "inception_v3_spec.input_image_shape = [299, 299]" ] }, { "cell_type": "markdown", "metadata": { "id": "T_GGIoXZCs5F" }, "source": [ "然后,将 `create` 方法中的参数 `model_spec` 设置为 `inception_v3_spec`,我们便可重新训练 Inception V3 模型。\n", "\n", "其余步骤完全相同,最后我们可以获得自定义的 InceptionV3 TensorFlow Lite 模型。" ] }, { "cell_type": "markdown", "metadata": { "id": "UhZ5IRKdeex3" }, "source": [ "### 更改您的自定义模型" ] }, { "cell_type": "markdown", "metadata": { "id": "svTjlZhrCrcV" }, "source": [ "如果我们想使用 TensorFlow Hub 中没有的自定义模型,应在 TensorFlow Hub 中创建并导出 [ModelSpec](https://tensorflow.google.cn/hub/api_docs/python/hub/ModuleSpec)。\n", "\n", "然后像上面的过程一样开始定义 `ModelSpec` 对象。" ] }, { "cell_type": "markdown", "metadata": { "id": "4M9bn703AHt2" }, "source": [ "## 更改训练超参数\n", "\n", "我们还可以更改训练超参数(如 `epochs`、`dropout_rate` 和 `batch_size`),它们可能会影响模型的准确率。可以调整的模型参数包括:\n", "\n", "- `epochs`:更多的周期可能会获得更高的准确率,直到收敛为止,但训练的周期过多可能会导致过拟合。\n", "- `dropout_rate` :随机失活率,避免过拟合。默认为 None。\n", "- `batch_size` :在一个训练步骤中使用的样本数。默认为 None。\n", "- `validation_data` :验证数据。如果为 None,则跳过验证过程。默认为 None。\n", "- `train_whole_model`:如果为 True,则将 Hub 模块与顶部的分类层一起训练。否则,仅训练顶部分类层。默认为 None。\n", "- `learning_rate` :基础学习率。默认为 None。\n", "- `momentum`:转发给优化器的 Python 浮点数。仅在 `use_hub_library` 为 True 时使用。默认为 None。\n", "- `shuffle`:布尔值,表示是否应对数据进行乱序。默认为 False。\n", "- `use_augmentation`:布尔值,使用数据增强进行预处理。默认为 False。\n", "- `use_hub_library`:布尔值,使用来自 TensorFlow Hub 的 `make_image_classifier_lib` 重新训练模型。对于具有多个类别的复杂数据集,此训练流水线可以实现更好的性能。默认为 True。\n", "- `warmup_steps`:用于学习率预热计划的预热步骤数。如果为 None,则使用默认的 warmup_steps,它是两个周期中的总训练步数。仅当 `use_hub_library` 为 False 时使用。默认为 None。\n", "- `model_dir`: 可选,模型检查点文件的位置。仅当 `use_hub_library` 为 False 时使用。默认为 None。\n", "\n", "默认为 None 的参数(如 `epochs`),将从 TensorFlow Hub 库或 [train_image_classifier_lib](https://github.com/tensorflow/examples/blob/f0260433d133fd3cea4a920d1e53ecda07163aee/tensorflow_examples/lite/model_maker/core/task/train_image_classifier_lib.py#L61) 获取 [make_image_classifier_lib](https://github.com/tensorflow/hub/blob/02ab9b7d3455e99e97abecf43c5d598a5528e20c/tensorflow_hub/tools/make_image_classifier/make_image_classifier_lib.py#L54) 中的具体默认参数。\n", "\n", "例如,我们可以进行更多周期的训练。\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "A3k7mhH54QcK" }, "outputs": [], "source": [ "model = image_classifier.create(train_data, validation_data=validation_data, epochs=10)" ] }, { "cell_type": "markdown", "metadata": { "id": "VaYBQymQDsXU" }, "source": [ "使用 10 个训练周期评估新近重新训练的模型。" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "VafIYpKWD4Sw" }, "outputs": [], "source": [ "loss, accuracy = model.evaluate(test_data)" ] }, { "cell_type": "markdown", "metadata": { "id": "dhBU5NCy5Ji2" }, "source": [ "# 阅读更多\n", "\n", "您可以阅读我们的[图像分类](https://tensorflow.google.cn/lite/examples/image_classification/overview)示例以了解技术细节。如需了解更多信息,请参阅:\n", "\n", "- TensorFlow Lite Model Maker [指南](https://tensorflow.google.cn/lite/models/modify/model_maker)和 [API 参考](https://tensorflow.google.cn/lite/api_docs/python/tflite_model_maker)。\n", "- Task Library:用于部署的 [ImageClassifier](https://tensorflow.google.cn/lite/inference_with_metadata/task_library/image_classifier)。\n", "- 端到端参考应用: [Android](https://github.com/tensorflow/examples/tree/master/lite/examples/image_classification/android)、 [iOS](https://github.com/tensorflow/examples/tree/master/lite/examples/image_classification/ios) 和 [Raspberry PI](https://github.com/tensorflow/examples/tree/master/lite/examples/image_classification/raspberry_pi)。\n" ] } ], "metadata": { "accelerator": "GPU", "colab": { "collapsed_sections": [], "name": "image_classification.ipynb", "toc_visible": true }, "kernelspec": { "display_name": "Python 3", "name": "python3" } }, "nbformat": 4, "nbformat_minor": 0 }