{ "cells": [ { "cell_type": "markdown", "metadata": { "id": "hX4n9TsbGw-f" }, "source": [ "##### Copyright 2018 The TensorFlow Authors." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "cellView": "form", "id": "0nbI5DtDGw-i" }, "outputs": [], "source": [ "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n", "# you may not use this file except in compliance with the License.\n", "# You may obtain a copy of the License at\n", "#\n", "# https://www.apache.org/licenses/LICENSE-2.0\n", "#\n", "# Unless required by applicable law or agreed to in writing, software\n", "# distributed under the License is distributed on an \"AS IS\" BASIS,\n", "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", "# See the License for the specific language governing permissions and\n", "# limitations under the License." ] }, { "cell_type": "markdown", "metadata": { "id": "9TnJztDZGw-n" }, "source": [ "# Text classification with an RNN" ] }, { "cell_type": "markdown", "metadata": { "id": "lUWearf0Gw-p" }, "source": [ "This text classification tutorial trains a [recurrent neural network](https://developers.google.com/machine-learning/glossary/#recurrent_neural_network) on the [IMDB large movie review dataset](http://ai.stanford.edu/~amaas/data/sentiment/) for sentiment analysis." ] }, { "cell_type": "markdown", "metadata": { "id": "_2VQo4bajwUU" }, "source": [ "## Setup" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "z682XYsrjkY9" }, "outputs": [], "source": [ "import tensorflow_datasets as tfds\n", "import tensorflow as tf" ] }, { "cell_type": "markdown", "metadata": { "id": "1rXHa-w9JZhb" }, "source": [ "Import `matplotlib` and create a helper function to plot graphs:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "Mp1Z7P9pYRSK" }, "outputs": [], "source": [ "import matplotlib.pyplot as plt\n", "\n", "def plot_graphs(history, metric):\n", "  # Plot the training and validation curves of `metric` for each epoch.\n", "  plt.plot(history.history[metric])\n", "  plt.plot(history.history['val_'+metric], '')\n", "  plt.xlabel(\"Epochs\")\n", "  plt.ylabel(metric)\n", "  plt.legend([metric, 'val_'+metric])\n", "  plt.show()" ] }, { "cell_type": "markdown", "metadata": { "id": "pRmMubr0jrE2" }, "source": [ "## Setup input pipeline\n", "\n", "The IMDB large movie review dataset is a *binary classification* dataset: every review has either a *positive* or a *negative* sentiment.\n", "\n", "Download the dataset using [TFDS](https://tensorflow.google.cn/datasets).\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "SHRwRoP2nVHX" }, "outputs": [], "source": [ "dataset, info = tfds.load('imdb_reviews/subwords8k', with_info=True,\n", "                          as_supervised=True)\n", "train_dataset, test_dataset = dataset['train'], dataset['test']" ] }, { "cell_type": "markdown", "metadata": { "id": "MCorLciXSDJE" }, "source": [ "The dataset `info` includes the encoder (a `tfds.features.text.SubwordTextEncoder`)." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "EplYp5pNnW1S" }, "outputs": [], "source": [ "encoder = info.features['text'].encoder" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "e7ACuHM5hFp3" }, "outputs": [], "source": [ "print('Vocabulary size: {}'.format(encoder.vocab_size))" ] },
{ "cell_type": "markdown", "metadata": { "id": "tAfGg8YRe6fu" }, "source": [ "This text encoder reversibly encodes any string, falling back to byte-encoding if necessary." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "Bq6xDmf2SAs-" }, "outputs": [], "source": [ "sample_string = 'Hello TensorFlow.'\n", "\n", "encoded_string = encoder.encode(sample_string)\n", "print('Encoded string is {}'.format(encoded_string))\n", "\n", "original_string = encoder.decode(encoded_string)\n", "print('The original string: \"{}\"'.format(original_string))" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "TN7QbKaM4-5H" }, "outputs": [], "source": [ "assert original_string == sample_string" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "MDVc6UGO5Dh6" }, "outputs": [], "source": [ "for index in encoded_string:\n", "  print('{} ----> {}'.format(index, encoder.decode([index])))" ] }, { "cell_type": "markdown", "metadata": { "id": "GlYWqhTVlUyQ" }, "source": [ "## Prepare the data for training" ] }, { "cell_type": "markdown", "metadata": { "id": "z2qVJzcEluH_" }, "source": [ "Next, create batches of these encoded strings. Use the `padded_batch` method to zero-pad the sequences to the length of the longest string in the batch:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "dDsCaZCDYZgm" }, "outputs": [], "source": [ "BUFFER_SIZE = 10000\n", "BATCH_SIZE = 64" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "VznrltNOnUc5" }, "outputs": [], "source": [ "train_dataset = train_dataset.shuffle(BUFFER_SIZE)\n", "train_dataset = train_dataset.padded_batch(BATCH_SIZE)\n", "\n", "test_dataset = test_dataset.padded_batch(BATCH_SIZE)" ] },
{ "cell_type": "markdown", "metadata": { "id": "bjUqGVBxGw-t" }, "source": [ "## Create the model" ] }, { "cell_type": "markdown", "metadata": { "id": "bgs6nnSTGw-t" }, "source": [ "Build a `tf.keras.Sequential` model and start with an embedding layer. An embedding layer stores one vector per word. When called, it converts sequences of word indices into sequences of vectors. These vectors are trainable. After training (on enough data), words with similar meanings often have similar vectors.\n", "\n", "This index lookup is much more efficient than the equivalent operation of passing a one-hot encoded vector through a `tf.keras.layers.Dense` layer.\n", "\n", "A recurrent neural network (RNN) processes sequence input by iterating through the elements. RNNs pass the output from one timestep into their input at the next timestep.\n", "\n", "The `tf.keras.layers.Bidirectional` wrapper can also be used with an RNN layer. This propagates the input forward and backward through the RNN layer and then concatenates the outputs. This helps the RNN learn long-range dependencies." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "LwfoBkmRYcP3" }, "outputs": [], "source": [ "model = tf.keras.Sequential([\n", "    tf.keras.layers.Embedding(encoder.vocab_size, 64),\n", "    tf.keras.layers.Bidirectional(tf.keras.layers.LSTM(64)),\n", "    tf.keras.layers.Dense(64, activation='relu'),\n", "    tf.keras.layers.Dense(1)\n", "])" ] }, { "cell_type": "markdown", "metadata": { "id": "QIGmIGkkouUb" }, "source": [ "Note that a Keras sequential model is used here because every layer in the model has a single input and produces a single output. If you want to use a stateful RNN layer, you may need to build the model with the Keras functional API or model subclassing so that you can retrieve and reuse the RNN layer states. See the [Keras RNN guide](https://tensorflow.google.cn/guide/keras/rnn#rnn_state_reuse) for more details." ] }, { "cell_type": "markdown", "metadata": { "id": "sRI776ZcH3Tf" }, "source": [ "Compile the Keras model to configure the training process:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "kj2xei41YZjC" }, "outputs": [], "source": [ "model.compile(loss=tf.keras.losses.BinaryCrossentropy(from_logits=True),\n", "              optimizer=tf.keras.optimizers.Adam(1e-4),\n", "              metrics=['accuracy'])" ] },
{ "cell_type": "markdown", "metadata": { "id": "zIwH3nto596k" }, "source": [ "## Train the model" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "hw86wWS4YgR2" }, "outputs": [], "source": [ "history = model.fit(train_dataset, epochs=10,\n", "                    validation_data=test_dataset,\n", "                    validation_steps=30)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "BaNbXi43YgUT" }, "outputs": [], "source": [ "test_loss, test_acc = model.evaluate(test_dataset)\n", "\n", "print('Test Loss: {}'.format(test_loss))\n", "print('Test Accuracy: {}'.format(test_acc))" ] }, { "cell_type": "markdown", "metadata": { "id": "DwSE_386uhxD" }, "source": [ "The model above does not mask the padding applied to the sequences. This can lead to skew if the model is trained on padded sequences and tested on un-padded sequences. Ideally you would [use masking](../../guide/keras/masking_and_padding) to avoid this, but as you can see below, it has only a small effect on the output.\n", "\n", "The final `Dense(1)` layer outputs a logit (the loss is configured with `from_logits=True`), so a prediction >= 0.0 is positive; otherwise it is negative." ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "id": "8w0dseJMiEUh" }, "outputs": [], "source": [ "def pad_to_size(vec, size):\n", "  # Right-pad `vec` with zeros up to length `size` (modifies `vec` in place).\n", "  zeros = [0] * (size - len(vec))\n", "  vec.extend(zeros)\n", "  return vec" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "Y-E4cgkIvmVu" }, "outputs": [], "source": [ "def sample_predict(sample_pred_text, pad):\n", "  # Encode the text into subword indices with the dataset's encoder.\n", "  encoded_sample_pred_text = encoder.encode(sample_pred_text)\n", "\n", "  if pad:\n", "    encoded_sample_pred_text = pad_to_size(encoded_sample_pred_text, 64)\n", "  encoded_sample_pred_text = tf.cast(encoded_sample_pred_text, tf.float32)\n", "  # Add a batch dimension before calling the model; the result is a logit.\n", "  predictions = model.predict(tf.expand_dims(encoded_sample_pred_text, 0))\n", "\n", "  return predictions" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "O41gw3KfWHus" }, "outputs": [], "source": [ "# predict on a sample text without padding.\n", "\n", "sample_pred_text = ('The movie was cool. The animation and the graphics '\n", "                    'were out of this world. I would recommend this movie.')\n", "predictions = sample_predict(sample_pred_text, pad=False)\n", "print(predictions)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "kFh4xLARucTy" }, "outputs": [], "source": [ "# predict on a sample text with padding\n", "\n", "sample_pred_text = ('The movie was cool. The animation and the graphics '\n", "                    'were out of this world. 
I would recommend this movie.')\n", "predictions = sample_predict(sample_pred_text, pad=True)\n", "print(predictions)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "ZfIVoxiNmKBF" }, "outputs": [], "source": [ "plot_graphs(history, 'accuracy')" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "IUzgkqnhmKD2" }, "outputs": [], "source": [ "plot_graphs(history, 'loss')" ] }, { "cell_type": "markdown", "metadata": { "id": "7g1evcaRpTKm" }, "source": [ "## Stack two or more LSTM layers\n", "\n", "Keras recurrent layers have two available modes, controlled by the `return_sequences` constructor argument:\n", "\n", "- With `return_sequences=True`: return the full sequence of successive outputs for each timestep (a 3D tensor of shape `(batch_size, timesteps, output_features)`).\n", "- With `return_sequences=False` (the default): return only the last output for each input sequence (a 2D tensor of shape `(batch_size, output_features)`)." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "jo1jjO3vn0jo" }, "outputs": [], "source": [ "model = tf.keras.Sequential([\n", "    tf.keras.layers.Embedding(encoder.vocab_size, 64),\n", "    tf.keras.layers.Bidirectional(tf.keras.layers.LSTM(64, return_sequences=True)),\n", "    tf.keras.layers.Bidirectional(tf.keras.layers.LSTM(32)),\n", "    tf.keras.layers.Dense(64, activation='relu'),\n", "    tf.keras.layers.Dropout(0.5),\n", "    tf.keras.layers.Dense(1)\n", "])" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "hEPV5jVGp-is" }, "outputs": [], "source": [ "model.compile(loss=tf.keras.losses.BinaryCrossentropy(from_logits=True),\n", "              optimizer=tf.keras.optimizers.Adam(1e-4),\n", "              metrics=['accuracy'])" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "LeSE-YjdqAeN" }, "outputs": [], "source": [ "history = model.fit(train_dataset, epochs=10,\n", "                    validation_data=test_dataset,\n", "                    validation_steps=30)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "_LdwilM1qPM3" }, "outputs": [], "source": [ "test_loss, test_acc = model.evaluate(test_dataset)\n", "\n", "print('Test Loss: {}'.format(test_loss))\n", "print('Test Accuracy: {}'.format(test_acc))" ] 
}, { "cell_type": "code", "execution_count": null, "metadata": { "id": "ykUKnAoqbycW" }, "outputs": [], "source": [ "# predict on a sample text without padding.\n", "\n", "sample_pred_text = ('The movie was not good. The animation and the graphics '\n", "                    'were terrible. I would not recommend this movie.')\n", "predictions = sample_predict(sample_pred_text, pad=False)\n", "print(predictions)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "2RiC-94zvdZO" }, "outputs": [], "source": [ "# predict on a sample text with padding\n", "\n", "sample_pred_text = ('The movie was not good. The animation and the graphics '\n", "                    'were terrible. I would not recommend this movie.')\n", "predictions = sample_predict(sample_pred_text, pad=True)\n", "print(predictions)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "_YYub0EDtwCu" }, "outputs": [], "source": [ "plot_graphs(history, 'accuracy')" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "DPV3Nn9xtwFM" }, "outputs": [], "source": [ "plot_graphs(history, 'loss')" ] }, { "cell_type": "markdown", "metadata": { "id": "9xvpE3BaGw_V" }, "source": [ "Check out other existing recurrent layers, such as [GRU layers](https://tensorflow.google.cn/api_docs/python/tf/keras/layers/GRU).\n", "\n", "If you're interested in building custom RNNs, see the [Keras RNN guide](../../guide/keras/rnn.ipynb).\n" ] } ], "metadata": { "colab": { "collapsed_sections": [], "name": "text_classification_rnn.ipynb", "toc_visible": true }, "kernelspec": { "display_name": "Python 3", "name": "python3" } }, "nbformat": 4, "nbformat_minor": 0 }