{ "cells": [ { "cell_type": "markdown", "metadata": { "id": "Ic4_occAAiAT" }, "source": [ "##### Copyright 2019 The TensorFlow Authors." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "cellView": "form", "id": "ioaprt5q5US7" }, "outputs": [], "source": [ "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n", "# you may not use this file except in compliance with the License.\n", "# You may obtain a copy of the License at\n", "#\n", "# https://www.apache.org/licenses/LICENSE-2.0\n", "#\n", "# Unless required by applicable law or agreed to in writing, software\n", "# distributed under the License is distributed on an \"AS IS\" BASIS,\n", "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", "# See the License for the specific language governing permissions and\n", "# limitations under the License." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "cellView": "form", "id": "yCl0eTNH5RS3" }, "outputs": [], "source": [ "#@title MIT License\n", "#\n", "# Copyright (c) 2017 François Chollet\n", "#\n", "# Permission is hereby granted, free of charge, to any person obtaining a\n", "# copy of this software and associated documentation files (the \"Software\"),\n", "# to deal in the Software without restriction, including without limitation\n", "# the rights to use, copy, modify, merge, publish, distribute, sublicense,\n", "# and/or sell copies of the Software, and to permit persons to whom the\n", "# Software is furnished to do so, subject to the following conditions:\n", "#\n", "# The above copyright notice and this permission notice shall be included in\n", "# all copies or substantial portions of the Software.\n", "#\n", "# THE SOFTWARE IS PROVIDED \"AS IS\", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR\n", "# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,\n", "# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL\n", "# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER\n", "# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING\n", "# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER\n", "# DEALINGS IN THE SOFTWARE." ] }, { "cell_type": "markdown", "metadata": { "id": "ItXfxkxvosLH" }, "source": [ "# Text classification with TensorFlow Hub: Movie reviews" ] }, { "cell_type": "markdown", "metadata": { "id": "hKY4XMc9o8iB" }, "source": [ "\n", " \n", " \n", " \n", " \n", " \n", "
\n", " View on TensorFlow.org\n", " \n", " Run in Google Colab\n", " \n", " View on GitHub\n", " \n", " Download notebook\n", " \n", " See TF Hub models\n", "
" ] }, { "cell_type": "markdown", "metadata": { "id": "Eg62Pmz3o83v" }, "source": [ "This notebook classifies movie reviews as *positive* or *negative* using the text of the review. This is an example of *binary*—or two-class—classification, an important and widely applicable kind of machine learning problem.\n", "\n", "The tutorial demonstrates the basic application of transfer learning with [TensorFlow Hub](https://tfhub.dev) and Keras.\n", "\n", "It uses the [IMDB dataset](https://www.tensorflow.org/api_docs/python/tf/keras/datasets/imdb) that contains the text of 50,000 movie reviews from the [Internet Movie Database](https://www.imdb.com/). These are split into 25,000 reviews for training and 25,000 reviews for testing. The training and testing sets are *balanced*, meaning they contain an equal number of positive and negative reviews. \n", "\n", "This notebook uses [`tf.keras`](https://www.tensorflow.org/guide/keras), a high-level API to build and train models in TensorFlow, and [`tensorflow_hub`](https://www.tensorflow.org/hub), a library for loading trained models from [TFHub](https://tfhub.dev) in a single line of code. For a more advanced text classification tutorial using `tf.keras`, see the [MLCC Text Classification Guide](https://developers.google.com/machine-learning/guides/text-classification/)." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "IHTzYqKZ7auw" }, "outputs": [], "source": [ "!pip install tensorflow-hub\n", "!pip install tensorflow-datasets" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "2ew7HTbPpCJH" }, "outputs": [], "source": [ "import os\n", "import numpy as np\n", "\n", "import tensorflow as tf\n", "import tensorflow_hub as hub\n", "import tensorflow_datasets as tfds\n", "\n", "print(\"Version: \", tf.__version__)\n", "print(\"Eager mode: \", tf.executing_eagerly())\n", "print(\"Hub version: \", hub.__version__)\n", "print(\"GPU is\", \"available\" if tf.config.list_physical_devices(\"GPU\") else \"NOT AVAILABLE\")" ] }, { "cell_type": "markdown", "metadata": { "id": "iAsKG535pHep" }, "source": [ "## Download the IMDB dataset\n", "\n", "The IMDB dataset is available on [imdb reviews](https://www.tensorflow.org/datasets/catalog/imdb_reviews) or on [TensorFlow datasets](https://www.tensorflow.org/datasets). The following code downloads the IMDB dataset to your machine (or the colab runtime):" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "zXXx5Oc3pOmN" }, "outputs": [], "source": [ "# Split the training set into 60% and 40% to end up with 15,000 examples\n", "# for training, 10,000 examples for validation and 25,000 examples for testing.\n", "train_data, validation_data, test_data = tfds.load(\n", " name=\"imdb_reviews\", \n", " split=('train[:60%]', 'train[60%:]', 'test'),\n", " as_supervised=True)" ] }, { "cell_type": "markdown", "metadata": { "id": "l50X3GfjpU4r" }, "source": [ "## Explore the data \n", "\n", "Let's take a moment to understand the format of the data. Each example is a sentence representing the movie review and a corresponding label. The sentence is not preprocessed in any way. The label is an integer value of either 0 or 1, where 0 is a negative review, and 1 is a positive review.\n", "\n", "Let's print first 10 examples." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "QtTS4kpEpjbi" }, "outputs": [], "source": [ "train_examples_batch, train_labels_batch = next(iter(train_data.batch(10)))\n", "train_examples_batch" ] }, { "cell_type": "markdown", "metadata": { "id": "IFtaCHTdc-GY" }, "source": [ "Let's also print the first 10 labels." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "tvAjVXOWc6Mj" }, "outputs": [], "source": [ "train_labels_batch" ] }, { "cell_type": "markdown", "metadata": { "id": "LLC02j2g-llC" }, "source": [ "## Build the model\n", "\n", "The neural network is created by stacking layers—this requires three main architectural decisions:\n", "\n", "* How to represent the text?\n", "* How many layers to use in the model?\n", "* How many *hidden units* to use for each layer?\n", "\n", "In this example, the input data consists of sentences. The labels to predict are either 0 or 1.\n", "\n", "One way to represent the text is to convert sentences into embeddings vectors. Use a pre-trained text embedding as the first layer, which will have three advantages:\n", "\n", "* You don't have to worry about text preprocessing,\n", "* Benefit from transfer learning,\n", "* the embedding has a fixed size, so it's simpler to process.\n", "\n", "For this example you use a **pre-trained text embedding model** from [TensorFlow Hub](https://tfhub.dev) called [google/nnlm-en-dim50/2](https://tfhub.dev/google/nnlm-en-dim50/2).\n", "\n", "There are many other pre-trained text embeddings from TFHub that can be used in this tutorial:\n", "\n", "* [google/nnlm-en-dim128/2](https://tfhub.dev/google/nnlm-en-dim128/2) - trained with the same NNLM architecture on the same data as [google/nnlm-en-dim50/2](https://tfhub.dev/google/nnlm-en-dim50/2), but with a larger embedding dimension. Larger dimensional embeddings can improve on your task but it may take longer to train your model.\n", "* [google/nnlm-en-dim128-with-normalization/2](https://tfhub.dev/google/nnlm-en-dim128-with-normalization/2) - the same as [google/nnlm-en-dim128/2](https://tfhub.dev/google/nnlm-en-dim128/2), but with additional text normalization such as removing punctuation. This can help if the text in your task contains additional characters or punctuation.\n", "* [google/universal-sentence-encoder/4](https://tfhub.dev/google/universal-sentence-encoder/4) - a much larger model yielding 512 dimensional embeddings trained with a deep averaging network (DAN) encoder.\n", "\n", "And many more! Find more [text embedding models](https://tfhub.dev/s?module-type=text-embedding) on TFHub." ] }, { "cell_type": "markdown", "metadata": { "id": "In2nDpTLkgKa" }, "source": [ "Let's first create a Keras layer that uses a TensorFlow Hub model to embed the sentences, and try it out on a couple of input examples. Note that no matter the length of the input text, the output shape of the embeddings is: `(num_examples, embedding_dimension)`." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "_NUbzVeYkgcO" }, "outputs": [], "source": [ "embedding = \"https://tfhub.dev/google/nnlm-en-dim50/2\"\n", "hub_layer = hub.KerasLayer(embedding, input_shape=[], \n", " dtype=tf.string, trainable=True)\n", "hub_layer(train_examples_batch[:3])" ] }, { "cell_type": "markdown", "metadata": { "id": "dfSbV6igl1EH" }, "source": [ "Let's now build the full model:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "xpKOoWgu-llD" }, "outputs": [], "source": [ "model = tf.keras.Sequential()\n", "model.add(hub_layer)\n", "model.add(tf.keras.layers.Dense(16, activation='relu'))\n", "model.add(tf.keras.layers.Dense(1))\n", "\n", "model.summary()" ] }, { "cell_type": "markdown", "metadata": { "id": "6PbKQ6mucuKL" }, "source": [ "The layers are stacked sequentially to build the classifier:\n", "\n", "1. The first layer is a TensorFlow Hub layer. This layer uses a pre-trained Saved Model to map a sentence into its embedding vector. The pre-trained text embedding model that you are using ([google/nnlm-en-dim50/2](https://tfhub.dev/google/nnlm-en-dim50/2)) splits the sentence into tokens, embeds each token and then combines the embedding. The resulting dimensions are: `(num_examples, embedding_dimension)`. For this NNLM model, the `embedding_dimension` is 50.\n", "2. This fixed-length output vector is piped through a fully-connected (`Dense`) layer with 16 hidden units.\n", "3. The last layer is densely connected with a single output node.\n", "\n", "Let's compile the model." ] }, { "cell_type": "markdown", "metadata": { "id": "L4EqVWg4-llM" }, "source": [ "### Loss function and optimizer\n", "\n", "A model needs a loss function and an optimizer for training. Since this is a binary classification problem and the model outputs logits (a single-unit layer with a linear activation), you'll use the `binary_crossentropy` loss function.\n", "\n", "This isn't the only choice for a loss function, you could, for instance, choose `mean_squared_error`. But, generally, `binary_crossentropy` is better for dealing with probabilities—it measures the \"distance\" between probability distributions, or in our case, between the ground-truth distribution and the predictions.\n", "\n", "Later, when you are exploring regression problems (say, to predict the price of a house), you'll see how to use another loss function called mean squared error.\n", "\n", "Now, configure the model to use an optimizer and a loss function:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "Mr0GP-cQ-llN" }, "outputs": [], "source": [ "model.compile(optimizer='adam',\n", " loss=tf.keras.losses.BinaryCrossentropy(from_logits=True),\n", " metrics=['accuracy'])" ] }, { "cell_type": "markdown", "metadata": { "id": "35jv_fzP-llU" }, "source": [ "## Train the model\n", "\n", "Train the model for 10 epochs in mini-batches of 512 samples. This is 10 iterations over all samples in the `x_train` and `y_train` tensors. While training, monitor the model's loss and accuracy on the 10,000 samples from the validation set:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "tXSGrjWZ-llW" }, "outputs": [], "source": [ "history = model.fit(train_data.shuffle(10000).batch(512),\n", " epochs=10,\n", " validation_data=validation_data.batch(512),\n", " verbose=1)" ] }, { "cell_type": "markdown", "metadata": { "id": "9EEGuDVuzb5r" }, "source": [ "## Evaluate the model\n", "\n", "And let's see how the model performs. 
{ "cell_type": "markdown", "metadata": { "id": "5KggXVeL-llZ" }, "source": [ "## Further reading\n", "\n", "* For a more general way to work with string inputs and for a more detailed analysis of the progress of accuracy and loss during training, see the [Text classification with preprocessed text](./text_classification.ipynb) tutorial.\n", "* Try out more [text-related tutorials](https://www.tensorflow.org/hub/tutorials#text-related-tutorials) using trained models from TFHub." ] } ], "metadata": { "colab": { "collapsed_sections": [], "name": "text_classification_with_hub.ipynb", "toc_visible": true }, "kernelspec": { "display_name": "Python 3", "name": "python3" } }, "nbformat": 4, "nbformat_minor": 0 }