{ "cells": [ { "cell_type": "markdown", "metadata": { "id": "h2q27gKz1H20" }, "source": [ "##### Copyright 2019 The TensorFlow Authors." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "cellView": "form", "id": "TUfAcER1oUS6" }, "outputs": [], "source": [ "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n", "# you may not use this file except in compliance with the License.\n", "# You may obtain a copy of the License at\n", "#\n", "# https://www.apache.org/licenses/LICENSE-2.0\n", "#\n", "# Unless required by applicable law or agreed to in writing, software\n", "# distributed under the License is distributed on an \"AS IS\" BASIS,\n", "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", "# See the License for the specific language governing permissions and\n", "# limitations under the License." ] }, { "cell_type": "markdown", "metadata": { "id": "Gb7qyhNL1yWt" }, "source": [ "# 使用 TensorFlow Lite Model Maker 进行文本分类" ] }, { "cell_type": "markdown", "metadata": { "id": "Fw5Y7snSuG51" }, "source": [ "
"Set the options for model inference, such as `file_name`, `use_coral`, and `num_threads`. The parameters that can be adjusted include:\n",
"\n",
"- `file_name`: Name of the TFLite text classification model.\n",
"- `use_coral`: If true, inference will be delegated to a connected Coral Edge TPU device.\n",
"- `num_threads`: The number of CPU threads to run the model."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "Haham4qT8hmV"
},
"outputs": [],
"source": [
"# Initialize the text classification model.\n",
"base_options = core.BaseOptions(file_name=_MODEL, use_coral=_ENABLE_EDGETPU, num_threads=_NUM_THREADS)\n",
"options = text.NLClassifierOptions(base_options)\n",
"\n",
"# Create NLClassifier from options.\n",
"classifier = text.NLClassifier.create_from_options(options)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "9HLl9LC9oA3G"
},
"source": [
"使用 `TFLite Task Library` 进行预测"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "pAQDHFs5tTxZ"
},
"outputs": [],
"source": [
"for idx in range(20):\n",
" sentence = sentence_data['sentence'].iloc[idx]\n",
" label = sentence_data['label'].iloc[idx]\n",
" text_classification_result = classifier.classify(sentence)\n",
" classification_list = text_classification_result.classifications[0].categories\n",
"\n",
" # Sort output by probability descending.\n",
" predict_label = sorted(\n",
" classification_list, key=lambda item: item.score, reverse=True)[0]\n",
"\n",
" print('truth_label: {} -----> predict_label: {}'.format(label, predict_label.category_name))"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "kJ_B8fMDOhMR"
},
"source": [
"## 选择文本分类器的模型架构\n",
"\n",
"每个 `model_spec` 对象代表一个文本分类器的具体模型。TensorFlow Lite Model Maker 目前支持 [MobileBERT](https://arxiv.org/pdf/2004.02984.pdf),平均单词嵌入向量和 [BERT-Base](https://arxiv.org/pdf/1810.04805.pdf) 模型。\n",
"\n",
"支持的模型 | model_spec 名称 | 模型描述 | 模型规模\n",
"--- | --- | --- | ---\n",
"平均单词嵌入向量 | 'average_word_vec' | 带有 RELU 激活的平均单词嵌入向量。 | <1MB\n",
"MobileBERT | 'mobilebert_classifier' | 比 BERT-Base 小 4.3 倍、快 5.5 倍,同时可获得具有竞争力的结果,适合设备端应用。 | 25MB w/ 量化
100MB w/o 量化\n",
"BERT-Base | 'bert_classifier' | NLP 任务中广泛使用的标准 BERT 模型。 | 300MB\n",
"\n",
"在快速入门中,我们使用了平均词嵌入向量模型。我们切换到 [MobileBERT](https://arxiv.org/pdf/2004.02984.pdf) 来训练准确率更高的模型。"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "vEAWuZQ1PFiX"
},
"outputs": [],
"source": [
"mb_spec = model_spec.get('mobilebert_classifier')"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "ygEncJxtl-nQ"
},
"source": [
"## 加载训练模型\n",
"\n",
"您可以上传您自己的数据集以完成本教程。请使用 Colab 的左侧边栏上传您的数据集。\n",
"\n",
"\n",
"
\n",
"\n",
"如果您不想将数据集上传到云,也可以按照[指南](https://github.com/tensorflow/examples/tree/master/tensorflow_examples/lite/model_maker)在本地运行该库。"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "mWAusqz-WD5i"
},
"source": [
"为简单起见,我们将重用先前下载的 SST-2 数据集。我们使用 `DataLoader.from_csv` 方法来加载数据。\n",
"\n",
"请注意,由于我们已经更改了模型架构,我们将需要重新加载训练数据集和测试数据集以应用新的预处理逻辑。"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "I_fOlZsklmlL"
},
"outputs": [],
"source": [
"train_data = DataLoader.from_csv(\n",
" filename='train.csv',\n",
" text_column='sentence',\n",
" label_column='label',\n",
" model_spec=mb_spec,\n",
" is_training=True)\n",
"test_data = DataLoader.from_csv(\n",
" filename='dev.csv',\n",
" text_column='sentence',\n",
" label_column='label',\n",
" model_spec=mb_spec,\n",
" is_training=False)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "MlHvVvv2hw4H"
},
"source": [
"Model Maker 库还支持 `from_folder()` 方法来加载数据。它假定同一类的文本数据位于相同的子目录中,且子文件夹名称为类名称。每个文本文件包含一个电影评论样本。参数 `class_labels` 用于指定具体的子文件夹。"
]
},
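{
"cell_type": "markdown",
"metadata": {},
"source": [
"For instance, a minimal sketch of loading folder data (the `train_folder/` layout with `pos` and `neg` subfolders is a hypothetical example, not part of this tutorial's SST-2 files; the argument names mirror the `from_csv` call above and are assumptions):\n",
"\n",
"```python\n",
"# Hypothetical layout: train_folder/pos/*.txt and train_folder/neg/*.txt,\n",
"# one movie review per text file.\n",
"folder_train_data = DataLoader.from_folder(\n",
"    filename='train_folder',\n",
"    model_spec=mb_spec,\n",
"    class_labels=['pos', 'neg'],  # only these subfolders are loaded as classes\n",
"    is_training=True)\n",
"```"
]
},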
{
"cell_type": "markdown",
"metadata": {
"id": "AWuoensX4vDA"
},
"source": [
"## 训练 TensorFlow 模型\n",
"\n",
"使用训练数据训练文本分类模型。\n",
"\n",
"*注:由于 MobileBERT 是一个复杂模型,在 Colab GPU 上进行每个训练周期大约需要 10 分钟。请确保您使用的是 GPU 运行时。*"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "TvYSUuJY3QxR"
},
"outputs": [],
"source": [
"model = text_classifier.create(train_data, model_spec=mb_spec, epochs=3)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "0JKI-pNc8idH"
},
"source": [
"检查详细的模型结构。"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "gd7Hs8TF8n3H"
},
"outputs": [],
"source": [
"model.summary()"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "LP5FPk_tOxoZ"
},
"source": [
"## 评估模型\n",
"\n",
"使用测试数据对我们刚刚训练的模型进行评估,并测量损失和准确率值。"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "A8c2ZQ0J3Riy"
},
"outputs": [],
"source": [
"loss, acc = model.evaluate(test_data)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "esBGwHE2QxE8"
},
"source": [
"## 导出为 TensorFlow Lite 模型\n",
"\n",
"使用[元数据](https://tensorflow.google.cn/lite/models/convert/metadata)将训练好的模型转换为 TensorFlow Lite 模型格式,以便以后在设备端机器学习应用中使用。标签文件和词汇文件嵌入在元数据中。默认的 TFLite 文件名为 `model.tflite`。\n",
"\n",
"在许多设备端 ML 应用中,模型大小是一个重要因素。因此,建议您应用量化模型以使其更小并可能加快运行速度。对于 BERT 和 MobileBERT 模型,默认的训练后量化技术是动态范围量化。"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "Im6wA9lK3TQB"
},
"outputs": [],
"source": [
"model.export(export_dir='mobilebert/')"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "w12kvDdHJIGH"
},
"source": [
"TensorFlow Lite 模型文件可以使用 [TensorFlow Lite Task Library](https://tensorflow.google.cn/lite/inference_with_metadata/task_library/overview) 中的 [BertNLClassifier API](https://tensorflow.google.cn/lite/inference_with_metadata/task_library/bert_nl_classifier) 集成到移动应用中。请注意,这与用于将训练的文本分类与平均单词向量模型架构集成在一起的 `NLClassifier` API 不同。"
]
},
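{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a quick sanity check in Python (rather than in a mobile app), a minimal sketch using the `tflite_support.task` modules imported earlier; the `BertNLClassifierOptions` usage and the example sentence are assumptions, and the model path comes from the export step above:\n",
"\n",
"```python\n",
"# Build a BertNLClassifier for the exported MobileBERT model.\n",
"bert_base_options = core.BaseOptions(file_name='mobilebert/model.tflite')\n",
"bert_options = text.BertNLClassifierOptions(base_options=bert_base_options)\n",
"bert_classifier = text.BertNLClassifier.create_from_options(bert_options)\n",
"\n",
"# Classify one sentence and print the top category.\n",
"result = bert_classifier.classify('This movie was a delightful surprise.')\n",
"top = sorted(result.classifications[0].categories,\n",
"             key=lambda c: c.score, reverse=True)[0]\n",
"print(top.category_name, top.score)\n",
"```"
]
},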
{
"cell_type": "markdown",
"metadata": {
"id": "AVy0ormoMZwL"
},
"source": [
"导出格式可以是以下列表中的一个或多个:\n",
"\n",
"- `ExportFormat.TFLITE`\n",
"- `ExportFormat.LABEL`\n",
"- `ExportFormat.VOCAB`\n",
"- `ExportFormat.SAVED_MODEL`\n",
"\n",
"默认情况下,它仅导出包含模型元数据的 TensorFlow Lite 模型文件。您也可以选择导出与模型相关的其他文件,以便更好地进行检查。例如,只导出标签文件和词汇文件,如下所示:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "nbK7nzK_Mfx4"
},
"outputs": [],
"source": [
"model.export(export_dir='mobilebert/', export_format=[ExportFormat.LABEL, ExportFormat.VOCAB])"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "HZKYthlVrTos"
},
"source": [
"您可以使用 `evaluate_tflite` 方法对 TFLite 模型进行评估,以衡量其准确率。将训练好的 TensorFlow 模型转换为 TFLite 格式并应用量化可能会影响其准确率,因此建议在部署之前评估 TFLite 模型的准确率。"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "ochbq95ZrVFX"
},
"outputs": [],
"source": [
"accuracy = model.evaluate_tflite('mobilebert/model.tflite', test_data)\n",
"print('TFLite model accuracy: ', accuracy)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "EoWiA_zX8rxE"
},
"source": [
"## 高级用法\n",
"\n",
"`create` 函数是 Model Maker 库用于创建模型的驱动函数。`model_spec` 参数定义模型规范。目前支持 `AverageWordVecSpec` 和 `BertClassifierSpec` 类。`create` 函数包括以下步骤:\n",
"\n",
"1. 根据 `model_spec` 为文本分类器创建模型。\n",
"2. 训练分类器模型。默认周期和默认批次大小通过 `model_spec` 对象中的 `default_training_epochs` 和 `default_batch_size` 变量进行设置。\n",
"\n",
"本部分介绍了调整模型和训练超参数等高级用法主题。"
]
},
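{
"cell_type": "markdown",
"metadata": {},
"source": [
"For example, a minimal sketch to inspect the defaults used in step 2 (assuming the `default_training_epochs` and `default_batch_size` attributes described above are readable on the spec object):\n",
"\n",
"```python\n",
"# Inspect the training defaults carried by the MobileBERT model spec.\n",
"print('default epochs:', mb_spec.default_training_epochs)\n",
"print('default batch size:', mb_spec.default_batch_size)\n",
"```"
]
},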
{
"cell_type": "markdown",
"metadata": {
"id": "E8VxPiOLy4Gv"
},
"source": [
"### 自定义 MobileBERT 模型超参数\n",
"\n",
"可以调整的模型参数包括:\n",
"\n",
"- `seq_len`:馈送到模型的序列长度。\n",
"- `initializer_range`:用于初始化所有权重矩阵的 truncated_normal_initializer 的标准差。\n",
"- `trainable`:指定预训练层是否可训练的布尔值。\n",
"\n",
"可以调整的训练流水线参数包括:\n",
"\n",
"- `model_dir`:模型检查点文件的位置。如果未设置,则会使用临时目录。\n",
"- `dropout_rate`:随机失活率。\n",
"- `learning_rate`:Adam 优化器的初始学习率。\n",
"- `tpu`:要连接的 TPU 地址。\n",
"\n",
"例如,您可以设置 `seq_len=256`(默认为 128)。这允许模型对较长的文本进行分类。"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "4tr9BLcjy4Sh"
},
"outputs": [],
"source": [
"new_model_spec = model_spec.get('mobilebert_classifier')\n",
"new_model_spec.seq_len = 256"
]
},
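{
"cell_type": "markdown",
"metadata": {},
"source": [
"The training pipeline parameters listed above can be adjusted on the same spec object before calling `create`; a minimal sketch (the specific values are illustrative assumptions, not tuned recommendations):\n",
"\n",
"```python\n",
"# Illustrative values only; tune them for your own dataset.\n",
"new_model_spec.dropout_rate = 0.2    # dropout rate\n",
"new_model_spec.learning_rate = 3e-5  # initial learning rate for Adam\n",
"```"
]
},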
{
"cell_type": "markdown",
"metadata": {
"id": "mwtiksguDfhl"
},
"source": [
"### 自定义平均单词嵌入向量模型超参数\n",
"\n",
"您可以调整模型基础架构,如 `AverageWordVecSpec` 类中的 `wordvec_dim` 和 `seq_len` 变量。\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "cAOd5_bzH9AQ"
},
"source": [
"例如,您可以使用较大的 `wordvec_dim` 值来训练模型。请注意,如果您修改模型,必须构造新的 `model_spec`。"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "e9WBN0UTQoMN"
},
"outputs": [],
"source": [
"new_model_spec = AverageWordVecSpec(wordvec_dim=32)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "6LSTdghTP0Cv"
},
"source": [
"获取预处理的数据。"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "DVZurFBORG3J"
},
"outputs": [],
"source": [
"new_train_data = DataLoader.from_csv(\n",
" filename='train.csv',\n",
" text_column='sentence',\n",
" label_column='label',\n",
" model_spec=new_model_spec,\n",
" is_training=True)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "tD7QVVHeRZoM"
},
"source": [
"训练新模型。"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "PzpV246_JGEu"
},
"outputs": [],
"source": [
"model = text_classifier.create(new_train_data, model_spec=new_model_spec)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "LvQuy7RSDir3"
},
"source": [
"### 调节训练超参数\n",
"\n",
"您还可以调节影响模型准确率的训练超参数(如 `epochs` 和 `batch_size`)。例如,\n",
"\n",
"- `epochs`:更多的周期可能会获得更高的准确率,但也可能导致过拟合。\n",
"- `batch_size`:一个训练步骤中要使用的样本数。\n",
"\n",
"例如,您可以使用更多周期进行训练。"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "rnWFaYZBG6NW"
},
"outputs": [],
"source": [
"model = text_classifier.create(new_train_data, model_spec=new_model_spec, epochs=20)"
]
},
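{
"cell_type": "markdown",
"metadata": {},
"source": [
"`batch_size` is passed to `create` in the same way as `epochs`; a minimal sketch (the value 64 is an illustrative assumption):\n",
"\n",
"```python\n",
"model = text_classifier.create(\n",
"    new_train_data, model_spec=new_model_spec, epochs=20, batch_size=64)\n",
"```"
]
},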
{
"cell_type": "markdown",
"metadata": {
"id": "nUaKQZBQHBQR"
},
"source": [
"使用 20 个训练周期评估新近重新训练的模型。"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "BMPi1xflHDSY"
},
"outputs": [],
"source": [
"new_test_data = DataLoader.from_csv(\n",
" filename='dev.csv',\n",
" text_column='sentence',\n",
" label_column='label',\n",
" model_spec=new_model_spec,\n",
" is_training=False)\n",
"\n",
"loss, accuracy = model.evaluate(new_test_data)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Eq6B9lKMfhS6"
},
"source": [
"### 更改模型架构\n",
"\n",
"您可以通过更改 `model_spec` 来更改模型。下文展示了如何更改为 BERT-Base 模型。\n",
"\n",
"将文本分类器的 `model_spec` 更改为 BERT-Base 模型。"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "QfFCWrwyggrT"
},
"outputs": [],
"source": [
"spec = model_spec.get('bert_classifier')"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "L2d7yycrgu6L"
},
"source": [
"其余步骤相同。"
]
},
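{
"cell_type": "markdown",
"metadata": {},
"source": [
"For concreteness, a minimal sketch that mirrors the MobileBERT steps above with the new spec (variable names are illustrative; expect longer training times and a larger model with BERT-Base):\n",
"\n",
"```python\n",
"# Reload the data with the BERT-Base preprocessing, then train, evaluate and export.\n",
"bert_train_data = DataLoader.from_csv(\n",
"    filename='train.csv',\n",
"    text_column='sentence',\n",
"    label_column='label',\n",
"    model_spec=spec,\n",
"    is_training=True)\n",
"bert_test_data = DataLoader.from_csv(\n",
"    filename='dev.csv',\n",
"    text_column='sentence',\n",
"    label_column='label',\n",
"    model_spec=spec,\n",
"    is_training=False)\n",
"\n",
"bert_model = text_classifier.create(bert_train_data, model_spec=spec, epochs=3)\n",
"loss, acc = bert_model.evaluate(bert_test_data)\n",
"bert_model.export(export_dir='bert/')\n",
"```"
]
},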
{
"cell_type": "markdown",
"metadata": {
"id": "GgiD_tkyQn7l"
},
"source": [
"### 在 TensorFlow Lite 模型上自定义训练后量化\n",
"\n",
"[训练后量化](https://tensorflow.google.cn/lite/performance/post_training_quantization)是一种转换技术,可以缩减模型大小并缩短推断延迟,同时改善 CPU 和硬件加速器推断速度,且几乎不会降低模型准确率。因此,它被广泛用于优化模型。\n",
"\n",
"Model Maker 库在导出模型时会应用默认的训练后量化技术。如果您想自定义训练后量化,Model Maker 也支持使用 [QuantizationConfig](https://tensorflow.google.cn/lite/api_docs/python/tflite_model_maker/config/QuantizationConfig) 的多个训练后量化选项。我们以 float16 量化为例。首先,定义量化配置。\n",
"\n",
"```python\n",
"config = QuantizationConfig.for_float16()\n",
"```\n",
"\n",
"然后,我们使用此配置导出 TensorFlow Lite 模型。\n",
"\n",
"```python\n",
"model.export(export_dir='.', tflite_filename='model_fp16.tflite', quantization_config=config)\n",
"```"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "qkJGvMEx6VD-"
},
"source": [
"# 阅读更多\n",
"\n",
"您可以阅读我们的[文本分类](https://tensorflow.google.cn/lite/examples/text_classification/overview)示例以了解技术细节。如需了解更多信息,请参阅:\n",
"\n",
"- TensorFlow Lite Model Maker [指南](https://tensorflow.google.cn/lite/models/modify/model_maker)和 [API 参考](https://tensorflow.google.cn/lite/api_docs/python/tflite_model_maker)。\n",
"- Task Library:用于开发的 [NLClassifier](https://tensorflow.google.cn/lite/inference_with_metadata/task_library/nl_classifier) 和 [BertNLClassifier](https://tensorflow.google.cn/lite/inference_with_metadata/task_library/bert_nl_classifier)。\n",
"- 端到端参考应用:[Android](https://github.com/tensorflow/examples/tree/master/lite/examples/text_classification/android) 和 [iOS](https://github.com/tensorflow/examples/tree/master/lite/examples/text_classification/ios)。"
]
}
],
"metadata": {
"accelerator": "GPU",
"colab": {
"collapsed_sections": [],
"name": "text_classification.ipynb",
"toc_visible": true
},
"kernelspec": {
"display_name": "Python 3",
"name": "python3"
}
},
"nbformat": 4,
"nbformat_minor": 0
}