{ "cells": [ { "cell_type": "markdown", "metadata": { "id": "h2q27gKz1H20" }, "source": [ "##### Copyright 2019 The TensorFlow Authors." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "cellView": "form", "id": "TUfAcER1oUS6" }, "outputs": [], "source": [ "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n", "# you may not use this file except in compliance with the License.\n", "# You may obtain a copy of the License at\n", "#\n", "# https://www.apache.org/licenses/LICENSE-2.0\n", "#\n", "# Unless required by applicable law or agreed to in writing, software\n", "# distributed under the License is distributed on an \"AS IS\" BASIS,\n", "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", "# See the License for the specific language governing permissions and\n", "# limitations under the License." ] }, { "cell_type": "markdown", "metadata": { "id": "Gb7qyhNL1yWt" }, "source": [ "# 使用 TensorFlow Lite Model Maker 进行文本分类" ] }, { "cell_type": "markdown", "metadata": { "id": "Fw5Y7snSuG51" }, "source": [ "\n", " \n", " \n", " \n", " \n", "
在 TensorFlow.org 上查看 在 Google Colab 中运行 在 Github 上查看源代码 下载笔记本
" ] }, { "cell_type": "markdown", "metadata": { "id": "sr3q-gvm3cI8" }, "source": [ "TensorFlow Lite Model Maker 库简化了在设备端 ML 应用中部署 TensorFlow 模型时修改此模型并将其转换为特定输入数据的过程。\n", "\n", "此笔记本展示了一个端到端的示例,该示例利用 Model Maker 库来说明如何改编和转换常用的文本分类模型来对移动设备上的电影评论进行分类。文本分类模型会将文本分类为预定义的类别。输入应为经过预处理的文本,而输出为类别的概率。本教程中使用的数据集是正面和负面的电影评论。" ] }, { "cell_type": "markdown", "metadata": { "id": "bcLF2PKkSbV3" }, "source": [ "## 前提条件\n" ] }, { "cell_type": "markdown", "metadata": { "id": "2vvAObmTqglq" }, "source": [ "### 安装所需的软件包\n", "\n", "要运行此示例,请安装所需的软件包,包括 [GitHub 仓库](https://github.com/tensorflow/examples/tree/master/tensorflow_examples/lite/model_maker)中的 Model Maker 软件包。" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "qhl8lqVamEty" }, "outputs": [], "source": [ "!sudo apt -y install libportaudio2\n", "!pip install -q tflite-model-maker\n", "!pip uninstall tflite_support_nightly\n", "!pip install tflite_support_nightly" ] }, { "cell_type": "markdown", "metadata": { "id": "l6lRhVK9Q_0U" }, "source": [ "导入所需的软件包。" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "XtxiUeZEiXpt" }, "outputs": [], "source": [ "import numpy as np\n", "import os\n", "\n", "from tflite_model_maker import model_spec\n", "from tflite_model_maker import text_classifier\n", "from tflite_model_maker.config import ExportFormat\n", "from tflite_model_maker.text_classifier import AverageWordVecSpec\n", "from tflite_model_maker.text_classifier import DataLoader\n", "\n", "from tflite_support.task import core\n", "from tflite_support.task import processor\n", "from tflite_support.task import text\n", "\n", "import tensorflow as tf\n", "assert tf.__version__.startswith('2')\n", "tf.get_logger().setLevel('ERROR')" ] }, { "cell_type": "markdown", "metadata": { "id": "BRd13bfetO7B" }, "source": [ "### 下载样本训练数据。\n", "\n", "在本教程中,我们将使用 [SST-2](https://nlp.stanford.edu/sentiment/index.html) (Stanford Sentiment Treebank),它是 [GLUE](https://gluebenchmark.com/) 基准测试中的一项任务。其中包含 67,349 条用于训练的电影评论和 872 条用于测试的电影评论。数据集有两个类:正面电影评论和负面电影评论。" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "R2BSkxWg6Rhx" }, "outputs": [], "source": [ "data_dir = tf.keras.utils.get_file(\n", " fname='SST-2.zip',\n", " origin='https://dl.fbaipublicfiles.com/glue/data/SST-2.zip',\n", " extract=True)\n", "data_dir = os.path.join(os.path.dirname(data_dir), 'SST-2')" ] }, { "cell_type": "markdown", "metadata": { "id": "gPYTbGrizcTC" }, "source": [ "SST-2 数据集以 TSV 格式存储。TSV 和 CSV 之间的唯一区别是,TSV 使用制表符 `\\t` 字符作为分隔符,而不是 CSV 格式中的逗号 `,`。\n", "\n", "以下是训练数据集的前 5 行。label=0 表示负面,label=1 表示正面。\n", "\n", "句子 | 标签 | | |\n", "--- | --- | --- | --- | ---\n", "hide new secretions from the parental units | 0 | | |\n", "contains no wit , only labored gags | 0 | | |\n", "that loves its characters and communicates something rather beautiful about human nature | 1 | | |\n", "remains utterly satisfied to remain the same throughout | 0 | | |\n", "on the worst revenge-of-the-nerds clichés the filmmakers could dredge up | 0 | | |\n", "\n", "接下来,我们将数据集加载到 Pandas 数据帧中,并将当前的标签名称(`0` 和 `1`)更改为更便于人类阅读的名称(`negative` 和 `positive`),并将它们用于模型训练。\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "iLNaOXnl3JQB" }, "outputs": [], "source": [ "import pandas as pd\n", "\n", "def replace_label(original_file, new_file):\n", " # Load the original file to pandas. We need to specify the separator as\n", " # '\\t' as the training data is stored in TSV format\n", " df = pd.read_csv(original_file, sep='\\t')\n", "\n", " # Define how we want to change the label name\n", " label_map = {0: 'negative', 1: 'positive'}\n", "\n", " # Excute the label change\n", " df.replace({'label': label_map}, inplace=True)\n", "\n", " # Write the updated dataset to a new file\n", " df.to_csv(new_file)\n", "\n", "# Replace the label name for both the training and test dataset. Then write the\n", "# updated CSV dataset to the current folder.\n", "replace_label(os.path.join(os.path.join(data_dir, 'train.tsv')), 'train.csv')\n", "replace_label(os.path.join(os.path.join(data_dir, 'dev.tsv')), 'dev.csv')" ] }, { "cell_type": "markdown", "metadata": { "id": "xushUyZXqP59" }, "source": [ "## 快速入门\n", "\n", "训练文本分类模型有五个步骤:\n", "\n", "**第 1 步:选择文本分类模型架构。**\n", "\n", "这里我们使用平均单词嵌入向量模型架构,它将生成一个小而快的模型,并具有不错的准确率。" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "CtdZ-JDwMimd" }, "outputs": [], "source": [ "spec = model_spec.get('average_word_vec')" ] }, { "cell_type": "markdown", "metadata": { "id": "yug6gR9qyHui" }, "source": [ "Model Maker 还支持其他模型架构,如 [BERT](https://arxiv.org/abs/1810.04805)。如果您有兴趣了解其他架构,请参阅下面的[为文本分类器选择模型架构](#scrollTo=kJ_B8fMDOhMR)部分。" ] }, { "cell_type": "markdown", "metadata": { "id": "s5U-A3tw6Y27" }, "source": [ "**第 2 步:加载训练数据和测试数据,然后根据特定的 `model_spec` 对其进行预处理。**\n", "\n", "Model Maker 可以接受 CSV 格式的输入数据。我们将使用先前创建的人类可读的标签名称加载训练数据集和测试数据集。\n", "\n", "每个模型架构都要求以特定的方式处理输入数据。`DataLoader` 会从 `model_spec` 读取要求,并自动执行必要的预处理。" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "HD5BvzWe6YKa" }, "outputs": [], "source": [ "train_data = DataLoader.from_csv(\n", " filename='train.csv',\n", " text_column='sentence',\n", " label_column='label',\n", " model_spec=spec,\n", " is_training=True)\n", "test_data = DataLoader.from_csv(\n", " filename='dev.csv',\n", " text_column='sentence',\n", " label_column='label',\n", " model_spec=spec,\n", " is_training=False)" ] }, { "cell_type": "markdown", "metadata": { "id": "2uZkLR6N6gDR" }, "source": [ "**第 3 步:用训练数据训练 TensorFlow 模型。**\n", "\n", "默认情况下,平均单词嵌入向量模型使用 `batch_size = 32`。因此,您将看到需要 2104 个步骤来遍历训练数据集中的 67,349 个句子。我们将对模型进行 10 个周期的训练,这意味着要对训练数据集遍历 10 次。" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "kwlYdTcg63xy" }, "outputs": [], "source": [ "model = text_classifier.create(train_data, model_spec=spec, epochs=10)" ] }, { "cell_type": "markdown", "metadata": { "id": "-BzCHLWJ6h7q" }, "source": [ "**第 4 步:在测试数据上评估模型。**\n", "\n", "在使用训练数据集中的句子训练文本分类模型后,我们将使用测试数据集中的剩余 872 个句子来评估模型在以前从未见过的新数据上的性能。\n", "\n", "由于默认批次大小为 32,因此需要 28 个步骤来遍历测试数据集中的 872 个句子。" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "8xmnl6Yy7ARn" }, "outputs": [], "source": [ "loss, acc = model.evaluate(test_data)" ] }, { "cell_type": "markdown", "metadata": { "id": "CgCDMe0e6jlT" }, "source": [ "**第 5 步:导出为 TensorFlow Lite 模型。**\n", "\n", "我们以 TensorFlow Lite 格式导出我们训练过的文本分类。我们将指定要导出模型的文件夹。默认情况下,对于平均单词嵌入向量模型架构,会导出浮点 TFLite 模型。" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "Hm_UULdW7A9T" }, "outputs": [], "source": [ "model.export(export_dir='average_word_vec')" ] }, { "cell_type": "markdown", "metadata": { "id": "rVxaf3x_7OfB" }, "source": [ "您可以使用 Colab 的左侧边栏下载 TensorFlow Lite 模型文件。按照我们在上面的 `export_dir` 参数中指定的那样,进入 `average_word_vec` 文件夹,右键点击 `model.tflite` 文件,然后选择 `Download` 以将其下载到本地计算机。\n", "\n", "可以使用 [TensorFlow Lite Task Library](https://tensorflow.google.cn/lite/inference_with_metadata/task_library/nl_classifier) 的 [NLClassifier API](https://tensorflow.google.cn/lite/inference_with_metadata/task_library/overview) 将此模型集成到 Android 或 iOS 应用中。\n", "\n", "有关如何在工作应用中使用模型的更多详细信息,请参阅 [TFLite 文本分类示例应用](https://github.com/tensorflow/examples/blob/master/lite/examples/text_classification/android/lib_task_api/src/main/java/org/tensorflow/lite/examples/textclassification/client/TextClassificationClient.java#L54)。\n", "\n", "*注 1:Android Studio Model Binding 目前还不支持文本分类,请使用 TensorFlow Lite Task Library。*\n", "\n", "*注 2:在与 TFLite 模型相同的文件夹中有一个 `model.json` 文件。它包含绑定在 TensorFlow Lite 模型中的[元数据](https://tensorflow.google.cn/lite/models/convert/metadata)的 JSON 表示。模型元数据能够帮助 TFLite Task Library 了解模型的作用以及如何对模型的数据进行预处理/后处理。您无需下载 `model.json` 文件,因为它仅用于信息目的,其内容已经包含在 TFLite 文件中。*\n", "\n", "*注 3:如果您使用 MobileBERT 或 BERT-Base 架构对文本分类模型进行训练,则需要使用 [BertNLClassifier API](https://tensorflow.google.cn/lite/inference_with_metadata/task_library/bert_nl_classifier) 将训练后的模型集成到移动应用中。*" ] }, { "cell_type": "markdown", "metadata": { "id": "l65ctmtW7_FF" }, "source": [ "以下各部分将逐步介绍该示例,以展示更多细节。" ] }, { "cell_type": "markdown", "metadata": { "id": "izO7NU7unYot" }, "source": [ "**第 6 步:使用 `TFLite Task Library` 演示如何使用训练的模型**" ] }, { "cell_type": "markdown", "metadata": { "id": "VDov6P4wppHO" }, "source": [ "将 dev.csv 文件读取到句子数据中,以使用训练的模型进行预测" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "XWwvHmIltQC2" }, "outputs": [], "source": [ "sentence_data = pd.read_csv('/content/dev.csv', index_col=0)\n", "sentence_data" ] }, { "cell_type": "markdown", "metadata": { "id": "y_-bejm5vRBf" }, "source": [ "模型配置参数" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "IAEEs3_3vPz5" }, "outputs": [], "source": [ "# Name of the TFLite text classification model.\n", "_MODEL = '/content/average_word_vec/model.tflite'\n", "# Whether to run the model on EdgeTPU.\n", "_ENABLE_EDGETPU = False\n", "# Number of CPU threads to run the model.\n", "_NUM_THREADS = 4" ] }, { "cell_type": "markdown", "metadata": { "id": "bInGjRcOtQbn" }, "source": [ "初始化模型\n", "\n", "我们还可以更改可能影响模型结果的参数,例如 `file_name{/code 0}、{code 1}use_coral{/code 1} 和 num_threads`。可以调整的参数包括:\n", "\n", "- `file_name`:TFLite 图像分类模型的名称。\n", "- `use_coral`:如果为 true,推断将被委托给连接的 Coral Edge TPU 设备。\n", "- `num_threads`:运行模型的 CPU 线程数。" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "Haham4qT8hmV" }, "outputs": [], "source": [ "# Initialize the text classification model.\n", "base_options = core.BaseOptions(file_name=_MODEL, use_coral=_ENABLE_EDGETPU, num_threads=_NUM_THREADS)\n", "options = text.NLClassifierOptions(base_options)\n", "\n", "# Create NLClassifier from options.\n", "classifier = text.NLClassifier.create_from_options(options)" ] }, { "cell_type": "markdown", "metadata": { "id": "9HLl9LC9oA3G" }, "source": [ "使用 `TFLite Task Library` 进行预测" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "pAQDHFs5tTxZ" }, "outputs": [], "source": [ "for idx in range(20):\n", " sentence = sentence_data['sentence'].iloc[idx]\n", " label = sentence_data['label'].iloc[idx]\n", " text_classification_result = classifier.classify(sentence)\n", " classification_list = text_classification_result.classifications[0].categories\n", "\n", " # Sort output by probability descending.\n", " predict_label = sorted(\n", " classification_list, key=lambda item: item.score, reverse=True)[0]\n", "\n", " print('truth_label: {} -----> predict_label: {}'.format(label, predict_label.category_name))" ] }, { "cell_type": "markdown", "metadata": { "id": "kJ_B8fMDOhMR" }, "source": [ "## 选择文本分类器的模型架构\n", "\n", "每个 `model_spec` 对象代表一个文本分类器的具体模型。TensorFlow Lite Model Maker 目前支持 [MobileBERT](https://arxiv.org/pdf/2004.02984.pdf),平均单词嵌入向量和 [BERT-Base](https://arxiv.org/pdf/1810.04805.pdf) 模型。\n", "\n", "支持的模型 | model_spec 名称 | 模型描述 | 模型规模\n", "--- | --- | --- | ---\n", "平均单词嵌入向量 | 'average_word_vec' | 带有 RELU 激活的平均单词嵌入向量。 | <1MB\n", "MobileBERT | 'mobilebert_classifier' | 比 BERT-Base 小 4.3 倍、快 5.5 倍,同时可获得具有竞争力的结果,适合设备端应用。 | 25MB w/ 量化
100MB w/o 量化\n", "BERT-Base | 'bert_classifier' | NLP 任务中广泛使用的标准 BERT 模型。 | 300MB\n", "\n", "在快速入门中,我们使用了平均词嵌入向量模型。我们切换到 [MobileBERT](https://arxiv.org/pdf/2004.02984.pdf) 来训练准确率更高的模型。" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "vEAWuZQ1PFiX" }, "outputs": [], "source": [ "mb_spec = model_spec.get('mobilebert_classifier')" ] }, { "cell_type": "markdown", "metadata": { "id": "ygEncJxtl-nQ" }, "source": [ "## 加载训练模型\n", "\n", "您可以上传您自己的数据集以完成本教程。请使用 Colab 的左侧边栏上传您的数据集。\n", "\n", "\n", "\"Upload\n", "\n", "如果您不想将数据集上传到云,也可以按照[指南](https://github.com/tensorflow/examples/tree/master/tensorflow_examples/lite/model_maker)在本地运行该库。" ] }, { "cell_type": "markdown", "metadata": { "id": "mWAusqz-WD5i" }, "source": [ "为简单起见,我们将重用先前下载的 SST-2 数据集。我们使用 `DataLoader.from_csv` 方法来加载数据。\n", "\n", "请注意,由于我们已经更改了模型架构,我们将需要重新加载训练数据集和测试数据集以应用新的预处理逻辑。" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "I_fOlZsklmlL" }, "outputs": [], "source": [ "train_data = DataLoader.from_csv(\n", " filename='train.csv',\n", " text_column='sentence',\n", " label_column='label',\n", " model_spec=mb_spec,\n", " is_training=True)\n", "test_data = DataLoader.from_csv(\n", " filename='dev.csv',\n", " text_column='sentence',\n", " label_column='label',\n", " model_spec=mb_spec,\n", " is_training=False)" ] }, { "cell_type": "markdown", "metadata": { "id": "MlHvVvv2hw4H" }, "source": [ "Model Maker 库还支持 `from_folder()` 方法来加载数据。它假定同一类的文本数据位于相同的子目录中,且子文件夹名称为类名称。每个文本文件包含一个电影评论样本。参数 `class_labels` 用于指定具体的子文件夹。" ] }, { "cell_type": "markdown", "metadata": { "id": "AWuoensX4vDA" }, "source": [ "## 训练 TensorFlow 模型\n", "\n", "使用训练数据训练文本分类模型。\n", "\n", "*注:由于 MobileBERT 是一个复杂模型,在 Colab GPU 上进行每个训练周期大约需要 10 分钟。请确保您使用的是 GPU 运行时。*" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "TvYSUuJY3QxR" }, "outputs": [], "source": [ "model = text_classifier.create(train_data, model_spec=mb_spec, epochs=3)" ] }, { "cell_type": "markdown", "metadata": { "id": "0JKI-pNc8idH" }, "source": [ "检查详细的模型结构。" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "gd7Hs8TF8n3H" }, "outputs": [], "source": [ "model.summary()" ] }, { "cell_type": "markdown", "metadata": { "id": "LP5FPk_tOxoZ" }, "source": [ "## 评估模型\n", "\n", "使用测试数据对我们刚刚训练的模型进行评估,并测量损失和准确率值。" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "A8c2ZQ0J3Riy" }, "outputs": [], "source": [ "loss, acc = model.evaluate(test_data)" ] }, { "cell_type": "markdown", "metadata": { "id": "esBGwHE2QxE8" }, "source": [ "## 导出为 TensorFlow Lite 模型\n", "\n", "使用[元数据](https://tensorflow.google.cn/lite/models/convert/metadata)将训练好的模型转换为 TensorFlow Lite 模型格式,以便以后在设备端机器学习应用中使用。标签文件和词汇文件嵌入在元数据中。默认的 TFLite 文件名为 `model.tflite`。\n", "\n", "在许多设备端 ML 应用中,模型大小是一个重要因素。因此,建议您应用量化模型以使其更小并可能加快运行速度。对于 BERT 和 MobileBERT 模型,默认的训练后量化技术是动态范围量化。" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "Im6wA9lK3TQB" }, "outputs": [], "source": [ "model.export(export_dir='mobilebert/')" ] }, { "cell_type": "markdown", "metadata": { "id": "w12kvDdHJIGH" }, "source": [ "TensorFlow Lite 模型文件可以使用 [TensorFlow Lite Task Library](https://tensorflow.google.cn/lite/inference_with_metadata/task_library/overview) 中的 [BertNLClassifier API](https://tensorflow.google.cn/lite/inference_with_metadata/task_library/bert_nl_classifier) 集成到移动应用中。请注意,这与用于将训练的文本分类与平均单词向量模型架构集成在一起的 `NLClassifier` API 不同。" ] }, { "cell_type": "markdown", "metadata": { "id": "AVy0ormoMZwL" }, "source": [ "导出格式可以是以下列表中的一个或多个:\n", "\n", "- `ExportFormat.TFLITE`\n", "- `ExportFormat.LABEL`\n", "- `ExportFormat.VOCAB`\n", "- `ExportFormat.SAVED_MODEL`\n", "\n", "默认情况下,它仅导出包含模型元数据的 TensorFlow Lite 模型文件。您也可以选择导出与模型相关的其他文件,以便更好地进行检查。例如,只导出标签文件和词汇文件,如下所示:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "nbK7nzK_Mfx4" }, "outputs": [], "source": [ "model.export(export_dir='mobilebert/', export_format=[ExportFormat.LABEL, ExportFormat.VOCAB])" ] }, { "cell_type": "markdown", "metadata": { "id": "HZKYthlVrTos" }, "source": [ "您可以使用 `evaluate_tflite` 方法对 TFLite 模型进行评估,以衡量其准确率。将训练好的 TensorFlow 模型转换为 TFLite 格式并应用量化可能会影响其准确率,因此建议在部署之前评估 TFLite 模型的准确率。" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "ochbq95ZrVFX" }, "outputs": [], "source": [ "accuracy = model.evaluate_tflite('mobilebert/model.tflite', test_data)\n", "print('TFLite model accuracy: ', accuracy)" ] }, { "cell_type": "markdown", "metadata": { "id": "EoWiA_zX8rxE" }, "source": [ "## 高级用法\n", "\n", "`create` 函数是 Model Maker 库用于创建模型的驱动函数。`model_spec` 参数定义模型规范。目前支持 `AverageWordVecSpec` 和 `BertClassifierSpec` 类。`create` 函数包括以下步骤:\n", "\n", "1. 根据 `model_spec` 为文本分类器创建模型。\n", "2. 训练分类器模型。默认周期和默认批次大小通过 `model_spec` 对象中的 `default_training_epochs` 和 `default_batch_size` 变量进行设置。\n", "\n", "本部分介绍了调整模型和训练超参数等高级用法主题。" ] }, { "cell_type": "markdown", "metadata": { "id": "E8VxPiOLy4Gv" }, "source": [ "### 自定义 MobileBERT 模型超参数\n", "\n", "可以调整的模型参数包括:\n", "\n", "- `seq_len`:馈送到模型的序列长度。\n", "- `initializer_range`:用于初始化所有权重矩阵的 truncated_normal_initializer 的标准差。\n", "- `trainable`:指定预训练层是否可训练的布尔值。\n", "\n", "可以调整的训练流水线参数包括:\n", "\n", "- `model_dir`:模型检查点文件的位置。如果未设置,则会使用临时目录。\n", "- `dropout_rate`:随机失活率。\n", "- `learning_rate`:Adam 优化器的初始学习率。\n", "- `tpu`:要连接的 TPU 地址。\n", "\n", "例如,您可以设置 `seq_len=256`(默认为 128)。这允许模型对较长的文本进行分类。" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "4tr9BLcjy4Sh" }, "outputs": [], "source": [ "new_model_spec = model_spec.get('mobilebert_classifier')\n", "new_model_spec.seq_len = 256" ] }, { "cell_type": "markdown", "metadata": { "id": "mwtiksguDfhl" }, "source": [ "### 自定义平均单词嵌入向量模型超参数\n", "\n", "您可以调整模型基础架构,如 `AverageWordVecSpec` 类中的 `wordvec_dim` 和 `seq_len` 变量。\n" ] }, { "cell_type": "markdown", "metadata": { "id": "cAOd5_bzH9AQ" }, "source": [ "例如,您可以使用较大的 `wordvec_dim` 值来训练模型。请注意,如果您修改模型,必须构造新的 `model_spec`。" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "e9WBN0UTQoMN" }, "outputs": [], "source": [ "new_model_spec = AverageWordVecSpec(wordvec_dim=32)" ] }, { "cell_type": "markdown", "metadata": { "id": "6LSTdghTP0Cv" }, "source": [ "获取预处理的数据。" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "DVZurFBORG3J" }, "outputs": [], "source": [ "new_train_data = DataLoader.from_csv(\n", " filename='train.csv',\n", " text_column='sentence',\n", " label_column='label',\n", " model_spec=new_model_spec,\n", " is_training=True)" ] }, { "cell_type": "markdown", "metadata": { "id": "tD7QVVHeRZoM" }, "source": [ "训练新模型。" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "PzpV246_JGEu" }, "outputs": [], "source": [ "model = text_classifier.create(new_train_data, model_spec=new_model_spec)" ] }, { "cell_type": "markdown", "metadata": { "id": "LvQuy7RSDir3" }, "source": [ "### 调节训练超参数\n", "\n", "您还可以调节影响模型准确率的训练超参数(如 `epochs` 和 `batch_size`)。例如,\n", "\n", "- `epochs`:更多的周期可能会获得更高的准确率,但也可能导致过拟合。\n", "- `batch_size`:一个训练步骤中要使用的样本数。\n", "\n", "例如,您可以使用更多周期进行训练。" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "rnWFaYZBG6NW" }, "outputs": [], "source": [ "model = text_classifier.create(new_train_data, model_spec=new_model_spec, epochs=20)" ] }, { "cell_type": "markdown", "metadata": { "id": "nUaKQZBQHBQR" }, "source": [ "使用 20 个训练周期评估新近重新训练的模型。" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "BMPi1xflHDSY" }, "outputs": [], "source": [ "new_test_data = DataLoader.from_csv(\n", " filename='dev.csv',\n", " text_column='sentence',\n", " label_column='label',\n", " model_spec=new_model_spec,\n", " is_training=False)\n", "\n", "loss, accuracy = model.evaluate(new_test_data)" ] }, { "cell_type": "markdown", "metadata": { "id": "Eq6B9lKMfhS6" }, "source": [ "### 更改模型架构\n", "\n", "您可以通过更改 `model_spec` 来更改模型。下文展示了如何更改为 BERT-Base 模型。\n", "\n", "将文本分类器的 `model_spec` 更改为 BERT-Base 模型。" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "QfFCWrwyggrT" }, "outputs": [], "source": [ "spec = model_spec.get('bert_classifier')" ] }, { "cell_type": "markdown", "metadata": { "id": "L2d7yycrgu6L" }, "source": [ "其余步骤相同。" ] }, { "cell_type": "markdown", "metadata": { "id": "GgiD_tkyQn7l" }, "source": [ "### 在 TensorFlow Lite 模型上自定义训练后量化\n", "\n", "[训练后量化](https://tensorflow.google.cn/lite/performance/post_training_quantization)是一种转换技术,可以缩减模型大小并缩短推断延迟,同时改善 CPU 和硬件加速器推断速度,且几乎不会降低模型准确率。因此,它被广泛用于优化模型。\n", "\n", "Model Maker 库在导出模型时会应用默认的训练后量化技术。如果您想自定义训练后量化,Model Maker 也支持使用 [QuantizationConfig](https://tensorflow.google.cn/lite/api_docs/python/tflite_model_maker/config/QuantizationConfig) 的多个训练后量化选项。我们以 float16 量化为例。首先,定义量化配置。\n", "\n", "```python\n", "config = QuantizationConfig.for_float16()\n", "```\n", "\n", "然后,我们使用此配置导出 TensorFlow Lite 模型。\n", "\n", "```python\n", "model.export(export_dir='.', tflite_filename='model_fp16.tflite', quantization_config=config)\n", "```" ] }, { "cell_type": "markdown", "metadata": { "id": "qkJGvMEx6VD-" }, "source": [ "# 阅读更多\n", "\n", "您可以阅读我们的[文本分类](https://tensorflow.google.cn/lite/examples/text_classification/overview)示例以了解技术细节。如需了解更多信息,请参阅:\n", "\n", "- TensorFlow Lite Model Maker [指南](https://tensorflow.google.cn/lite/models/modify/model_maker)和 [API 参考](https://tensorflow.google.cn/lite/api_docs/python/tflite_model_maker)。\n", "- Task Library:用于开发的 [NLClassifier](https://tensorflow.google.cn/lite/inference_with_metadata/task_library/nl_classifier) 和 [BertNLClassifier](https://tensorflow.google.cn/lite/inference_with_metadata/task_library/bert_nl_classifier)。\n", "- 端到端参考应用:[Android](https://github.com/tensorflow/examples/tree/master/lite/examples/text_classification/android) 和 [iOS](https://github.com/tensorflow/examples/tree/master/lite/examples/text_classification/ios)。" ] } ], "metadata": { "accelerator": "GPU", "colab": { "collapsed_sections": [], "name": "text_classification.ipynb", "toc_visible": true }, "kernelspec": { "display_name": "Python 3", "name": "python3" } }, "nbformat": 4, "nbformat_minor": 0 }