{ "cells": [ { "cell_type": "markdown", "metadata": { "id": "ScitaPqhKtuW" }, "source": [ "##### Copyright 2021 The TensorFlow Hub Authors.\n", "\n", "Licensed under the Apache License, Version 2.0 (the \"License\");" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "jvztxQ6VsK2k" }, "outputs": [], "source": [ "# Copyright 2021 The TensorFlow Hub Authors. All Rights Reserved.\n", "#\n", "# Licensed under the Apache License, Version 2.0 (the \"License\");\n", "# you may not use this file except in compliance with the License.\n", "# You may obtain a copy of the License at\n", "#\n", "# http://www.apache.org/licenses/LICENSE-2.0\n", "#\n", "# Unless required by applicable law or agreed to in writing, software\n", "# distributed under the License is distributed on an \"AS IS\" BASIS,\n", "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", "# See the License for the specific language governing permissions and\n", "# limitations under the License.\n", "# ==============================================================================" ] }, { "cell_type": "markdown", "metadata": { "id": "oYM61xrTsP5d" }, "source": [ "# Retraining an Image Classifier\n" ] }, { "cell_type": "markdown", "metadata": { "id": "MfBg1C5NB3X0" }, "source": [ "\n", " \n", " \n", " \n", " \n", " \n", "
\n", " View on TensorFlow.org\n", " \n", " Run in Google Colab\n", " \n", " View on GitHub\n", " \n", " Download notebook\n", " \n", " See TF Hub models\n", "
" ] }, { "cell_type": "markdown", "metadata": { "id": "L1otmJgmbahf" }, "source": [ "## Introduction\n", "\n", "Image classification models have millions of parameters. Training them from\n", "scratch requires a lot of labeled training data and a lot of computing power. Transfer learning is a technique that shortcuts much of this by taking a piece of a model that has already been trained on a related task and reusing it in a new model.\n", "\n", "This Colab demonstrates how to build a Keras model for classifying five species of flowers by using a pre-trained TF2 SavedModel from TensorFlow Hub for image feature extraction, trained on the much larger and more general ImageNet dataset. Optionally, the feature extractor can be trained (\"fine-tuned\") alongside the newly added classifier.\n", "\n", "### Looking for a tool instead?\n", "\n", "This is a TensorFlow coding tutorial. If you want a tool that just builds the TensorFlow or TFLite model for, take a look at the [make_image_classifier](https://github.com/tensorflow/hub/tree/master/tensorflow_hub/tools/make_image_classifier) command-line tool that gets [installed](https://www.tensorflow.org/hub/installation) by the PIP package `tensorflow-hub[make_image_classifier]`, or at [this](https://colab.sandbox.google.com/github/tensorflow/examples/blob/master/tensorflow_examples/lite/model_maker/demo/image_classification.ipynb) TFLite colab.\n" ] }, { "cell_type": "markdown", "metadata": { "id": "bL54LWCHt5q5" }, "source": [ "## Setup" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "dlauq-4FWGZM" }, "outputs": [], "source": [ "import itertools\n", "import os\n", "\n", "import matplotlib.pylab as plt\n", "import numpy as np\n", "\n", "import tensorflow as tf\n", "import tensorflow_hub as hub\n", "\n", "print(\"TF version:\", tf.__version__)\n", "print(\"Hub version:\", hub.__version__)\n", "print(\"GPU is\", \"available\" if tf.config.list_physical_devices('GPU') else \"NOT AVAILABLE\")" ] }, { "cell_type": "markdown", "metadata": { "id": "mmaHHH7Pvmth" }, "source": [ "## Select the TF2 SavedModel module to use\n", "\n", "For starters, use [https://tfhub.dev/google/imagenet/mobilenet_v2_100_224/feature_vector/4](https://tfhub.dev/google/imagenet/mobilenet_v2_100_224/feature_vector/4). The same URL can be used in code to identify the SavedModel and in your browser to show its documentation. (Note that models in TF1 Hub format won't work here.)\n", "\n", "You can find more TF2 models that generate image feature vectors [here](https://tfhub.dev/s?module-type=image-feature-vector&tf-version=tf2).\n", "\n", "There are multiple possible models to try. All you need to do is select a different one on the cell below and follow up with the notebook." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "FlsEcKVeuCnf" }, "outputs": [], "source": [ "#@title\n", "\n", "model_name = \"efficientnetv2-xl-21k\" # @param ['efficientnetv2-s', 'efficientnetv2-m', 'efficientnetv2-l', 'efficientnetv2-s-21k', 'efficientnetv2-m-21k', 'efficientnetv2-l-21k', 'efficientnetv2-xl-21k', 'efficientnetv2-b0-21k', 'efficientnetv2-b1-21k', 'efficientnetv2-b2-21k', 'efficientnetv2-b3-21k', 'efficientnetv2-s-21k-ft1k', 'efficientnetv2-m-21k-ft1k', 'efficientnetv2-l-21k-ft1k', 'efficientnetv2-xl-21k-ft1k', 'efficientnetv2-b0-21k-ft1k', 'efficientnetv2-b1-21k-ft1k', 'efficientnetv2-b2-21k-ft1k', 'efficientnetv2-b3-21k-ft1k', 'efficientnetv2-b0', 'efficientnetv2-b1', 'efficientnetv2-b2', 'efficientnetv2-b3', 'efficientnet_b0', 'efficientnet_b1', 'efficientnet_b2', 'efficientnet_b3', 'efficientnet_b4', 'efficientnet_b5', 'efficientnet_b6', 'efficientnet_b7', 'bit_s-r50x1', 'inception_v3', 'inception_resnet_v2', 'resnet_v1_50', 'resnet_v1_101', 'resnet_v1_152', 'resnet_v2_50', 'resnet_v2_101', 'resnet_v2_152', 'nasnet_large', 'nasnet_mobile', 'pnasnet_large', 'mobilenet_v2_100_224', 'mobilenet_v2_130_224', 'mobilenet_v2_140_224', 'mobilenet_v3_small_100_224', 'mobilenet_v3_small_075_224', 'mobilenet_v3_large_100_224', 'mobilenet_v3_large_075_224']\n", "\n", "model_handle_map = {\n", " \"efficientnetv2-s\": \"https://tfhub.dev/google/imagenet/efficientnet_v2_imagenet1k_s/feature_vector/2\",\n", " \"efficientnetv2-m\": \"https://tfhub.dev/google/imagenet/efficientnet_v2_imagenet1k_m/feature_vector/2\",\n", " \"efficientnetv2-l\": \"https://tfhub.dev/google/imagenet/efficientnet_v2_imagenet1k_l/feature_vector/2\",\n", " \"efficientnetv2-s-21k\": \"https://tfhub.dev/google/imagenet/efficientnet_v2_imagenet21k_s/feature_vector/2\",\n", " \"efficientnetv2-m-21k\": \"https://tfhub.dev/google/imagenet/efficientnet_v2_imagenet21k_m/feature_vector/2\",\n", " \"efficientnetv2-l-21k\": \"https://tfhub.dev/google/imagenet/efficientnet_v2_imagenet21k_l/feature_vector/2\",\n", " \"efficientnetv2-xl-21k\": \"https://tfhub.dev/google/imagenet/efficientnet_v2_imagenet21k_xl/feature_vector/2\",\n", " \"efficientnetv2-b0-21k\": \"https://tfhub.dev/google/imagenet/efficientnet_v2_imagenet21k_b0/feature_vector/2\",\n", " \"efficientnetv2-b1-21k\": \"https://tfhub.dev/google/imagenet/efficientnet_v2_imagenet21k_b1/feature_vector/2\",\n", " \"efficientnetv2-b2-21k\": \"https://tfhub.dev/google/imagenet/efficientnet_v2_imagenet21k_b2/feature_vector/2\",\n", " \"efficientnetv2-b3-21k\": \"https://tfhub.dev/google/imagenet/efficientnet_v2_imagenet21k_b3/feature_vector/2\",\n", " \"efficientnetv2-s-21k-ft1k\": \"https://tfhub.dev/google/imagenet/efficientnet_v2_imagenet21k_ft1k_s/feature_vector/2\",\n", " \"efficientnetv2-m-21k-ft1k\": \"https://tfhub.dev/google/imagenet/efficientnet_v2_imagenet21k_ft1k_m/feature_vector/2\",\n", " \"efficientnetv2-l-21k-ft1k\": \"https://tfhub.dev/google/imagenet/efficientnet_v2_imagenet21k_ft1k_l/feature_vector/2\",\n", " \"efficientnetv2-xl-21k-ft1k\": \"https://tfhub.dev/google/imagenet/efficientnet_v2_imagenet21k_ft1k_xl/feature_vector/2\",\n", " \"efficientnetv2-b0-21k-ft1k\": \"https://tfhub.dev/google/imagenet/efficientnet_v2_imagenet21k_ft1k_b0/feature_vector/2\",\n", " \"efficientnetv2-b1-21k-ft1k\": \"https://tfhub.dev/google/imagenet/efficientnet_v2_imagenet21k_ft1k_b1/feature_vector/2\",\n", " \"efficientnetv2-b2-21k-ft1k\": \"https://tfhub.dev/google/imagenet/efficientnet_v2_imagenet21k_ft1k_b2/feature_vector/2\",\n", " 
\"efficientnetv2-b3-21k-ft1k\": \"https://tfhub.dev/google/imagenet/efficientnet_v2_imagenet21k_ft1k_b3/feature_vector/2\",\n", " \"efficientnetv2-b0\": \"https://tfhub.dev/google/imagenet/efficientnet_v2_imagenet1k_b0/feature_vector/2\",\n", " \"efficientnetv2-b1\": \"https://tfhub.dev/google/imagenet/efficientnet_v2_imagenet1k_b1/feature_vector/2\",\n", " \"efficientnetv2-b2\": \"https://tfhub.dev/google/imagenet/efficientnet_v2_imagenet1k_b2/feature_vector/2\",\n", " \"efficientnetv2-b3\": \"https://tfhub.dev/google/imagenet/efficientnet_v2_imagenet1k_b3/feature_vector/2\",\n", " \"efficientnet_b0\": \"https://tfhub.dev/tensorflow/efficientnet/b0/feature-vector/1\",\n", " \"efficientnet_b1\": \"https://tfhub.dev/tensorflow/efficientnet/b1/feature-vector/1\",\n", " \"efficientnet_b2\": \"https://tfhub.dev/tensorflow/efficientnet/b2/feature-vector/1\",\n", " \"efficientnet_b3\": \"https://tfhub.dev/tensorflow/efficientnet/b3/feature-vector/1\",\n", " \"efficientnet_b4\": \"https://tfhub.dev/tensorflow/efficientnet/b4/feature-vector/1\",\n", " \"efficientnet_b5\": \"https://tfhub.dev/tensorflow/efficientnet/b5/feature-vector/1\",\n", " \"efficientnet_b6\": \"https://tfhub.dev/tensorflow/efficientnet/b6/feature-vector/1\",\n", " \"efficientnet_b7\": \"https://tfhub.dev/tensorflow/efficientnet/b7/feature-vector/1\",\n", " \"bit_s-r50x1\": \"https://tfhub.dev/google/bit/s-r50x1/1\",\n", " \"inception_v3\": \"https://tfhub.dev/google/imagenet/inception_v3/feature-vector/4\",\n", " \"inception_resnet_v2\": \"https://tfhub.dev/google/imagenet/inception_resnet_v2/feature-vector/4\",\n", " \"resnet_v1_50\": \"https://tfhub.dev/google/imagenet/resnet_v1_50/feature-vector/4\",\n", " \"resnet_v1_101\": \"https://tfhub.dev/google/imagenet/resnet_v1_101/feature-vector/4\",\n", " \"resnet_v1_152\": \"https://tfhub.dev/google/imagenet/resnet_v1_152/feature-vector/4\",\n", " \"resnet_v2_50\": \"https://tfhub.dev/google/imagenet/resnet_v2_50/feature-vector/4\",\n", " \"resnet_v2_101\": \"https://tfhub.dev/google/imagenet/resnet_v2_101/feature-vector/4\",\n", " \"resnet_v2_152\": \"https://tfhub.dev/google/imagenet/resnet_v2_152/feature-vector/4\",\n", " \"nasnet_large\": \"https://tfhub.dev/google/imagenet/nasnet_large/feature_vector/4\",\n", " \"nasnet_mobile\": \"https://tfhub.dev/google/imagenet/nasnet_mobile/feature_vector/4\",\n", " \"pnasnet_large\": \"https://tfhub.dev/google/imagenet/pnasnet_large/feature_vector/4\",\n", " \"mobilenet_v2_100_224\": \"https://tfhub.dev/google/imagenet/mobilenet_v2_100_224/feature_vector/4\",\n", " \"mobilenet_v2_130_224\": \"https://tfhub.dev/google/imagenet/mobilenet_v2_130_224/feature_vector/4\",\n", " \"mobilenet_v2_140_224\": \"https://tfhub.dev/google/imagenet/mobilenet_v2_140_224/feature_vector/4\",\n", " \"mobilenet_v3_small_100_224\": \"https://tfhub.dev/google/imagenet/mobilenet_v3_small_100_224/feature_vector/5\",\n", " \"mobilenet_v3_small_075_224\": \"https://tfhub.dev/google/imagenet/mobilenet_v3_small_075_224/feature_vector/5\",\n", " \"mobilenet_v3_large_100_224\": \"https://tfhub.dev/google/imagenet/mobilenet_v3_large_100_224/feature_vector/5\",\n", " \"mobilenet_v3_large_075_224\": \"https://tfhub.dev/google/imagenet/mobilenet_v3_large_075_224/feature_vector/5\",\n", "}\n", "\n", "model_image_size_map = {\n", " \"efficientnetv2-s\": 384,\n", " \"efficientnetv2-m\": 480,\n", " \"efficientnetv2-l\": 480,\n", " \"efficientnetv2-b0\": 224,\n", " \"efficientnetv2-b1\": 240,\n", " \"efficientnetv2-b2\": 260,\n", " \"efficientnetv2-b3\": 300,\n", " 
\"efficientnetv2-s-21k\": 384,\n", " \"efficientnetv2-m-21k\": 480,\n", " \"efficientnetv2-l-21k\": 480,\n", " \"efficientnetv2-xl-21k\": 512,\n", " \"efficientnetv2-b0-21k\": 224,\n", " \"efficientnetv2-b1-21k\": 240,\n", " \"efficientnetv2-b2-21k\": 260,\n", " \"efficientnetv2-b3-21k\": 300,\n", " \"efficientnetv2-s-21k-ft1k\": 384,\n", " \"efficientnetv2-m-21k-ft1k\": 480,\n", " \"efficientnetv2-l-21k-ft1k\": 480,\n", " \"efficientnetv2-xl-21k-ft1k\": 512,\n", " \"efficientnetv2-b0-21k-ft1k\": 224,\n", " \"efficientnetv2-b1-21k-ft1k\": 240,\n", " \"efficientnetv2-b2-21k-ft1k\": 260,\n", " \"efficientnetv2-b3-21k-ft1k\": 300, \n", " \"efficientnet_b0\": 224,\n", " \"efficientnet_b1\": 240,\n", " \"efficientnet_b2\": 260,\n", " \"efficientnet_b3\": 300,\n", " \"efficientnet_b4\": 380,\n", " \"efficientnet_b5\": 456,\n", " \"efficientnet_b6\": 528,\n", " \"efficientnet_b7\": 600,\n", " \"inception_v3\": 299,\n", " \"inception_resnet_v2\": 299,\n", " \"nasnet_large\": 331,\n", " \"pnasnet_large\": 331,\n", "}\n", "\n", "model_handle = model_handle_map.get(model_name)\n", "pixels = model_image_size_map.get(model_name, 224)\n", "\n", "print(f\"Selected model: {model_name} : {model_handle}\")\n", "\n", "IMAGE_SIZE = (pixels, pixels)\n", "print(f\"Input size {IMAGE_SIZE}\")\n", "\n", "BATCH_SIZE = 16#@param {type:\"integer\"}" ] }, { "cell_type": "markdown", "metadata": { "id": "yTY8qzyYv3vl" }, "source": [ "## Set up the Flowers dataset\n", "\n", "Inputs are suitably resized for the selected module. Dataset augmentation (i.e., random distortions of an image each time it is read) improves training, esp. when fine-tuning." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "WBtFK1hO8KsO" }, "outputs": [], "source": [ "data_dir = tf.keras.utils.get_file(\n", " 'flower_photos',\n", " 'https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz',\n", " untar=True)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "cellView": "form", "id": "umB5tswsfTEQ" }, "outputs": [], "source": [ "def build_dataset(subset):\n", " return tf.keras.preprocessing.image_dataset_from_directory(\n", " data_dir,\n", " validation_split=.20,\n", " subset=subset,\n", " label_mode=\"categorical\",\n", " # Seed needs to provided when using validation_split and shuffle = True.\n", " # A fixed seed is used so that the validation set is stable across runs.\n", " seed=123,\n", " image_size=IMAGE_SIZE,\n", " batch_size=1)\n", "\n", "train_ds = build_dataset(\"training\")\n", "class_names = tuple(train_ds.class_names)\n", "train_size = train_ds.cardinality().numpy()\n", "train_ds = train_ds.unbatch().batch(BATCH_SIZE)\n", "train_ds = train_ds.repeat()\n", "\n", "normalization_layer = tf.keras.layers.Rescaling(1. 
/ 255)\n", "preprocessing_model = tf.keras.Sequential([normalization_layer])\n", "do_data_augmentation = False #@param {type:\"boolean\"}\n", "if do_data_augmentation:\n", " preprocessing_model.add(\n", " tf.keras.layers.RandomRotation(40))\n", " preprocessing_model.add(\n", " tf.keras.layers.RandomTranslation(0, 0.2))\n", " preprocessing_model.add(\n", " tf.keras.layers.RandomTranslation(0.2, 0))\n", " # Like the old tf.keras.preprocessing.image.ImageDataGenerator(),\n", " # image sizes are fixed when reading, and then a random zoom is applied.\n", " # If all training inputs are larger than image_size, one could also use\n", " # RandomCrop with a batch size of 1 and rebatch later.\n", " preprocessing_model.add(\n", " tf.keras.layers.RandomZoom(0.2, 0.2))\n", " preprocessing_model.add(\n", " tf.keras.layers.RandomFlip(mode=\"horizontal\"))\n", "train_ds = train_ds.map(lambda images, labels:\n", " (preprocessing_model(images), labels))\n", "\n", "val_ds = build_dataset(\"validation\")\n", "valid_size = val_ds.cardinality().numpy()\n", "val_ds = val_ds.unbatch().batch(BATCH_SIZE)\n", "val_ds = val_ds.map(lambda images, labels:\n", " (normalization_layer(images), labels))" ] }, { "cell_type": "markdown", "metadata": { "id": "FS_gVStowW3G" }, "source": [ "## Defining the model\n", "\n", "All it takes is to put a linear classifier on top of the `feature_extractor_layer` with the Hub module.\n", "\n", "For speed, we start out with a non-trainable `feature_extractor_layer`, but you can also enable fine-tuning for greater accuracy." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "RaJW3XrPyFiF" }, "outputs": [], "source": [ "do_fine_tuning = False #@param {type:\"boolean\"}" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "50FYNIb1dmJH" }, "outputs": [], "source": [ "print(\"Building model with\", model_handle)\n", "model = tf.keras.Sequential([\n", " # Explicitly define the input shape so the model can be properly\n", " # loaded by the TFLiteConverter\n", " tf.keras.layers.InputLayer(input_shape=IMAGE_SIZE + (3,)),\n", " hub.KerasLayer(model_handle, trainable=do_fine_tuning),\n", " tf.keras.layers.Dropout(rate=0.2),\n", " tf.keras.layers.Dense(len(class_names),\n", " kernel_regularizer=tf.keras.regularizers.l2(0.0001))\n", "])\n", "model.build((None,)+IMAGE_SIZE+(3,))\n", "model.summary()" ] }, { "cell_type": "markdown", "metadata": { "id": "u2e5WupIw2N2" }, "source": [ "## Training the model" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "9f3yBUvkd_VJ" }, "outputs": [], "source": [ "model.compile(\n", " optimizer=tf.keras.optimizers.SGD(learning_rate=0.005, momentum=0.9), \n", " loss=tf.keras.losses.CategoricalCrossentropy(from_logits=True, label_smoothing=0.1),\n", " metrics=['accuracy'])" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "w_YKX2Qnfg6x" }, "outputs": [], "source": [ "steps_per_epoch = train_size // BATCH_SIZE\n", "validation_steps = valid_size // BATCH_SIZE\n", "hist = model.fit(\n", " train_ds,\n", " epochs=5, steps_per_epoch=steps_per_epoch,\n", " validation_data=val_ds,\n", " validation_steps=validation_steps).history" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "CYOw0fTO1W4x" }, "outputs": [], "source": [ "plt.figure()\n", "plt.ylabel(\"Loss (training and validation)\")\n", "plt.xlabel(\"Training Steps\")\n", "plt.ylim([0,2])\n", "plt.plot(hist[\"loss\"])\n", "plt.plot(hist[\"val_loss\"])\n", "\n", "plt.figure()\n", "plt.ylabel(\"Accuracy 
(training and validation)\")\n", "plt.xlabel(\"Training Steps\")\n", "plt.ylim([0,1])\n", "plt.plot(hist[\"accuracy\"])\n", "plt.plot(hist[\"val_accuracy\"])" ] }, { "cell_type": "markdown", "metadata": { "id": "jZ8DKKgeKv4-" }, "source": [ "Try out the model on an image from the validation data:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "oi1iCNB9K1Ai" }, "outputs": [], "source": [ "x, y = next(iter(val_ds))\n", "image = x[0, :, :, :]\n", "true_index = np.argmax(y[0])\n", "plt.imshow(image)\n", "plt.axis('off')\n", "plt.show()\n", "\n", "# Expand the validation image to (1, 224, 224, 3) before predicting the label\n", "prediction_scores = model.predict(np.expand_dims(image, axis=0))\n", "predicted_index = np.argmax(prediction_scores)\n", "print(\"True label: \" + class_names[true_index])\n", "print(\"Predicted label: \" + class_names[predicted_index])" ] }, { "cell_type": "markdown", "metadata": { "id": "YCsAsQM1IRvA" }, "source": [ "Finally, the trained model can be saved for deployment to TF Serving or TFLite (on mobile) as follows." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "LGvTi69oIc2d" }, "outputs": [], "source": [ "saved_model_path = f\"/tmp/saved_flowers_model_{model_name}\"\n", "tf.saved_model.save(model, saved_model_path)" ] }, { "cell_type": "markdown", "metadata": { "id": "QzW4oNRjILaq" }, "source": [ "## Optional: Deployment to TensorFlow Lite\n", "\n", "[TensorFlow Lite](https://www.tensorflow.org/lite) lets you deploy TensorFlow models to mobile and IoT devices. The code below shows how to convert the trained model to TFLite and apply post-training tools from the [TensorFlow Model Optimization Toolkit](https://www.tensorflow.org/model_optimization). Finally, it runs it in the TFLite Interpreter to examine the resulting quality\n", "\n", " * Converting without optimization provides the same results as before (up to roundoff error).\n", " * Converting with optimization without any data quantizes the model weights to 8 bits, but inference still uses floating-point computation for the neural network activations. This reduces model size almost by a factor of 4 and improves CPU latency on mobile devices.\n", " * On top, computation of the neural network activations can be quantized to 8-bit integers as well if a small reference dataset is provided to calibrate the quantization range. On a mobile device, this accelerates inference further and makes it possible to run on accelerators like Edge TPU." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "Va1Vo92fSyV6" }, "outputs": [], "source": [ "#@title Optimization settings\n", "optimize_lite_model = False #@param {type:\"boolean\"}\n", "#@markdown Setting a value greater than zero enables quantization of neural network activations. 
A few dozen is already a useful amount.\n", "num_calibration_examples = 60 #@param {type:\"slider\", min:0, max:1000, step:1}\n", "representative_dataset = None\n", "if optimize_lite_model and num_calibration_examples:\n", " # Use a bounded number of training examples without labels for calibration.\n", " # TFLiteConverter expects a list of input tensors, each with batch size 1.\n", " representative_dataset = lambda: itertools.islice(\n", " ([image[None, ...]] for batch, _ in train_ds for image in batch),\n", " num_calibration_examples)\n", "\n", "converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_path)\n", "if optimize_lite_model:\n", " converter.optimizations = [tf.lite.Optimize.DEFAULT]\n", " if representative_dataset: # This is optional, see above.\n", " converter.representative_dataset = representative_dataset\n", "lite_model_content = converter.convert()\n", "\n", "with open(f\"/tmp/lite_flowers_model_{model_name}.tflite\", \"wb\") as f:\n", " f.write(lite_model_content)\n", "print(\"Wrote %sTFLite model of %d bytes.\" %\n", " (\"optimized \" if optimize_lite_model else \"\", len(lite_model_content)))" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "_wqEmD0xIqeG" }, "outputs": [], "source": [ "interpreter = tf.lite.Interpreter(model_content=lite_model_content)\n", "# This little helper wraps the TFLite Interpreter as a numpy-to-numpy function.\n", "def lite_model(images):\n", " interpreter.allocate_tensors()\n", " interpreter.set_tensor(interpreter.get_input_details()[0]['index'], images)\n", " interpreter.invoke()\n", " return interpreter.get_tensor(interpreter.get_output_details()[0]['index'])" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "JMMK-fZrKrk8" }, "outputs": [], "source": [ "#@markdown For rapid experimentation, start with a moderate number of examples.\n", "num_eval_examples = 50 #@param {type:\"slider\", min:0, max:700}\n", "eval_dataset = ((image, label) # TFLite expects batch size 1.\n", " for batch in train_ds\n", " for (image, label) in zip(*batch))\n", "count = 0\n", "count_lite_tf_agree = 0\n", "count_lite_correct = 0\n", "for image, label in eval_dataset:\n", " probs_lite = lite_model(image[None, ...])[0]\n", " probs_tf = model(image[None, ...]).numpy()[0]\n", " y_lite = np.argmax(probs_lite)\n", " y_tf = np.argmax(probs_tf)\n", " y_true = np.argmax(label)\n", " count +=1\n", " if y_lite == y_tf: count_lite_tf_agree += 1\n", " if y_lite == y_true: count_lite_correct += 1\n", " if count >= num_eval_examples: break\n", "print(\"TFLite model agrees with original model on %d of %d examples (%g%%).\" %\n", " (count_lite_tf_agree, count, 100.0 * count_lite_tf_agree / count))\n", "print(\"TFLite model is accurate on %d of %d examples (%g%%).\" %\n", " (count_lite_correct, count, 100.0 * count_lite_correct / count))" ] } ], "metadata": { "accelerator": "GPU", "colab": { "collapsed_sections": [ "ScitaPqhKtuW" ], "name": "tf2_image_retraining.ipynb", "toc_visible": true }, "kernelspec": { "display_name": "Python 3", "name": "python3" } }, "nbformat": 4, "nbformat_minor": 0 }