{ "cells": [ { "cell_type": "markdown", "metadata": { "id": "Tce3stUlHN0L" }, "source": [ "##### Copyright 2020 The TensorFlow Authors." ] }, { "cell_type": "code", "execution_count": 1, "metadata": { "cellView": "form", "execution": { "iopub.execute_input": "2025-06-21T11:32:59.431244Z", "iopub.status.busy": "2025-06-21T11:32:59.430846Z", "iopub.status.idle": "2025-06-21T11:32:59.434995Z", "shell.execute_reply": "2025-06-21T11:32:59.434386Z" }, "id": "IcfrhafzkZbH" }, "outputs": [], "source": [ "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n", "# you may not use this file except in compliance with the License.\n", "# You may obtain a copy of the License at\n", "#\n", "# https://www.apache.org/licenses/LICENSE-2.0\n", "#\n", "# Unless required by applicable law or agreed to in writing, software\n", "# distributed under the License is distributed on an \"AS IS\" BASIS,\n", "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", "# See the License for the specific language governing permissions and\n", "# limitations under the License." ] }, { "cell_type": "markdown", "metadata": { "id": "qFdPvlXBOdUN" }, "source": [ "# Pruning comprehensive guide" ] }, { "cell_type": "markdown", "metadata": { "id": "MfBg1C5NB3X0" }, "source": [ "
" ] }, { "cell_type": "markdown", "metadata": { "id": "FbORZA_bQx1G" }, "source": [ "Welcome to the comprehensive guide for Keras weight pruning.\n", "\n", "This page documents various use cases and shows how to use the API for each one. Once you know which APIs you need, find the parameters and the low-level details in the\n", "[API docs](https://www.tensorflow.org/model_optimization/api_docs/python/tfmot/sparsity).\n", "\n", "* If you want to see the benefits of pruning and what's supported, see the [overview](https://www.tensorflow.org/model_optimization/guide/pruning).\n", "* For a single end-to-end example, see the [pruning example](https://www.tensorflow.org/model_optimization/guide/pruning/pruning_with_keras).\n", "\n", "The following use cases are covered:\n", "* Define and train a pruned model.\n", "  * Sequential and Functional.\n", "  * Keras `model.fit` and custom training loops.\n", "* Checkpoint and deserialize a pruned model.\n", "* Deploy a pruned model and see compression benefits.\n", "\n", "To configure the pruning algorithm, refer to the `tfmot.sparsity.keras.prune_low_magnitude` API docs." ] }, { "cell_type": "markdown", "metadata": { "id": "nuABqZnXVDvO" }, "source": [ "## Setup" ] }, { "cell_type": "markdown", "metadata": { "id": "u9mRDekZEfnR" }, "source": [ "If you only want to find the APIs you need and understand what they do, you can run this section without reading it." 
] }, { "cell_type": "code", "execution_count": 2, "metadata": { "cellView": "both", "execution": { "iopub.execute_input": "2025-06-21T11:32:59.438710Z", "iopub.status.busy": "2025-06-21T11:32:59.438089Z", "iopub.status.idle": "2025-06-21T11:33:07.528726Z", "shell.execute_reply": "2025-06-21T11:33:07.527860Z" }, "id": "lvpH1Hg7ULFz" }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "2025-06-21 11:33:03.628323: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:467] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered\n", "WARNING: All log messages before absl::InitializeLog() is called are written to STDERR\n", "E0000 00:00:1750505583.651308 27487 cuda_dnn.cc:8579] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered\n", "E0000 00:00:1750505583.658221 27487 cuda_blas.cc:1407] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered\n", "W0000 00:00:1750505583.676481 27487 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.\n", "W0000 00:00:1750505583.676504 27487 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.\n", "W0000 00:00:1750505583.676506 27487 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.\n", "W0000 00:00:1750505583.676509 27487 computation_placer.cc:177] computation placer already registered. 
Please check linkage and avoid linking the same target more than once.\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "2025-06-21 11:33:07.004484: E external/local_xla/xla/stream_executor/cuda/cuda_platform.cc:51] failed call to cuInit: INTERNAL: CUDA error: Failed call to cuInit: CUDA_ERROR_NO_DEVICE: no CUDA-capable device is detected\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\r", "1/1 [==============================] - ETA: 0s - loss: 16.1181 - accuracy: 0.0000e+00" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "1/1 [==============================] - 0s 341ms/step - loss: 16.1181 - accuracy: 0.0000e+00\n" ] } ], "source": [ "! pip install -q tensorflow-model-optimization\n", "\n", "import tensorflow as tf\n", "import numpy as np\n", "import tensorflow_model_optimization as tfmot\n", "import tf_keras as keras\n", "\n", "%load_ext tensorboard\n", "\n", "import tempfile\n", "\n", "input_shape = [20]\n", "x_train = np.random.randn(1, 20).astype(np.float32)\n", "y_train = keras.utils.to_categorical(np.random.randn(1), num_classes=20)\n", "\n", "def setup_model():\n", " model = keras.Sequential([\n", " keras.layers.Dense(20, input_shape=input_shape),\n", " keras.layers.Flatten()\n", " ])\n", " return model\n", "\n", "def setup_pretrained_weights():\n", " model = setup_model()\n", "\n", " model.compile(\n", " loss=keras.losses.categorical_crossentropy,\n", " optimizer='adam',\n", " metrics=['accuracy']\n", " )\n", "\n", " model.fit(x_train, y_train)\n", "\n", " _, pretrained_weights = tempfile.mkstemp('.tf')\n", "\n", " model.save_weights(pretrained_weights)\n", "\n", " return pretrained_weights\n", "\n", "def get_gzipped_model_size(model):\n", " # Returns size of gzipped model, in bytes.\n", " import os\n", " import zipfile\n", "\n", " _, 
keras_file = tempfile.mkstemp('.h5')\n", " model.save(keras_file, include_optimizer=False)\n", "\n", " _, zipped_file = tempfile.mkstemp('.zip')\n", " with zipfile.ZipFile(zipped_file, 'w', compression=zipfile.ZIP_DEFLATED) as f:\n", "  f.write(keras_file)\n", "\n", " return os.path.getsize(zipped_file)\n", "\n", "setup_model()\n", "pretrained_weights = setup_pretrained_weights()" ] }, { "cell_type": "markdown", "metadata": { "id": "TZyLYFTER4aP" }, "source": [ "## Define model" ] }, { "cell_type": "markdown", "metadata": { "id": "Ybigft1fTn4T" }, "source": [ "### Prune whole model (Sequential and Functional)" ] }, { "cell_type": "markdown", "metadata": { "id": "puZvqnp1xsn-" }, "source": [ "**Tips for better model accuracy:**\n", "* Try the \"Prune some layers\" approach to skip pruning the layers that reduce accuracy the most.\n", "* It's generally better to fine-tune with pruning than to train from scratch.\n", "\n", "To make the whole model train with pruning, apply `tfmot.sparsity.keras.prune_low_magnitude` to the model.\n" ] }, { "cell_type": "code", "execution_count": 3, "metadata": { "execution": { "iopub.execute_input": "2025-06-21T11:33:07.532787Z", "iopub.status.busy": "2025-06-21T11:33:07.532120Z", "iopub.status.idle": "2025-06-21T11:33:08.014415Z", "shell.execute_reply": "2025-06-21T11:33:08.013688Z" }, "id": "aIn-hFO_T_PU" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Model: \"sequential_2\"\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "_________________________________________________________________\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " Layer (type) Output Shape Param # \n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "=================================================================\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " prune_low_magnitude_dense_ (None, 20) 822 \n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " 2
(PruneLowMagnitude) \n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " \n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " prune_low_magnitude_flatte (None, 20) 1 \n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " n_2 (PruneLowMagnitude) \n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " \n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "=================================================================\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Total params: 823 (3.22 KB)\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Trainable params: 420 (1.64 KB)\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Non-trainable params: 403 (1.58 KB)\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "_________________________________________________________________\n" ] } ], "source": [ "base_model = setup_model()\n", "base_model.load_weights(pretrained_weights) # optional but recommended.\n", "\n", "model_for_pruning = tfmot.sparsity.keras.prune_low_magnitude(base_model)\n", "\n", "model_for_pruning.summary()" ] }, { "cell_type": "markdown", "metadata": { "id": "xTbTLn3dZM7h" }, "source": [ "### Prune some layers (Sequential and Functional)" ] }, { "cell_type": "markdown", "metadata": { "id": "MbM8o832xTxV" }, "source": [ "Pruning a model can have a negative effect on accuracy. You can selectively prune layers of a model to explore the trade-off between accuracy, speed, and model size.\n", "\n", "**Tips for better model accuracy:**\n", "* It's generally better to fine-tune with pruning than to train from scratch.\n", "* Try pruning the later layers instead of the first layers.\n", "* Avoid pruning critical layers (e.g. 
attention mechanism).\n", "\n", "**More**:\n", "* The `tfmot.sparsity.keras.prune_low_magnitude` API docs provide details on how to vary the pruning configuration per layer.\n", "\n", "In the example below, prune only the `Dense` layers." ] }, { "cell_type": "code", "execution_count": 4, "metadata": { "execution": { "iopub.execute_input": "2025-06-21T11:33:08.017973Z", "iopub.status.busy": "2025-06-21T11:33:08.017700Z", "iopub.status.idle": "2025-06-21T11:33:08.073948Z", "shell.execute_reply": "2025-06-21T11:33:08.073369Z" }, "id": "HN0B_QB-ZhE2" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "WARNING:tensorflow:Detecting that an object or model or tf.train.Checkpoint is being deleted with unrestored values. See the following logs for the specific values in question. To silence these warnings, use `status.expect_partial()`. See https://www.tensorflow.org/api_docs/python/tf/train/Checkpoint#restorefor details about the status object returned by the restore function.\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "WARNING:tensorflow:Detecting that an object or model or tf.train.Checkpoint is being deleted with unrestored values. See the following logs for the specific values in question. To silence these warnings, use `status.expect_partial()`. 
See https://www.tensorflow.org/api_docs/python/tf/train/Checkpoint#restorefor details about the status object returned by the restore function.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer._iterations\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer._iterations\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer._learning_rate\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer._learning_rate\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer._variables.1\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer._variables.1\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer._variables.2\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer._variables.2\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer._variables.3\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer._variables.3\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "WARNING:tensorflow:Value in 
checkpoint could not be found in the restored object: (root).optimizer._variables.4\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer._variables.4\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Model: \"sequential_3\"\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "_________________________________________________________________\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " Layer (type) Output Shape Param # \n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "=================================================================\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " prune_low_magnitude_dense_ (None, 20) 822 \n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " 3 (PruneLowMagnitude) \n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " \n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " flatten_3 (Flatten) (None, 20) 0 \n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " \n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "=================================================================\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Total params: 822 (3.21 KB)\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Trainable params: 420 (1.64 KB)\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Non-trainable params: 402 (1.57 KB)\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "_________________________________________________________________\n" ] } ], "source": [ "# Create a base model\n", "base_model = setup_model()\n", "base_model.load_weights(pretrained_weights) # optional but recommended for model accuracy\n", "\n", "# Helper function uses `prune_low_magnitude` to make only the \n", "# Dense layers train with pruning.\n", "def 
apply_pruning_to_dense(layer):\n", "  if isinstance(layer, keras.layers.Dense):\n", "    return tfmot.sparsity.keras.prune_low_magnitude(layer)\n", "  return layer\n", "\n", "# Use `keras.models.clone_model` to apply `apply_pruning_to_dense`\n", "# to the layers of the model.\n", "model_for_pruning = keras.models.clone_model(\n", "    base_model,\n", "    clone_function=apply_pruning_to_dense,\n", ")\n", "\n", "model_for_pruning.summary()" ] }, { "cell_type": "markdown", "metadata": { "id": "HiA28PrrW11H" }, "source": [ "While this example uses the layer type to decide what to prune, the easiest way to prune a particular layer is to set its `name` property and check for that name in the `clone_function`." ] }, { "cell_type": "code", "execution_count": 5, "metadata": { "execution": { "iopub.execute_input": "2025-06-21T11:33:08.077522Z", "iopub.status.busy": "2025-06-21T11:33:08.076869Z", "iopub.status.idle": "2025-06-21T11:33:08.080736Z", "shell.execute_reply": "2025-06-21T11:33:08.080070Z" }, "id": "CjY_JyB808Da" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "dense_3\n" ] } ], "source": [ "print(base_model.layers[0].name)" ] }, { "cell_type": "markdown", "metadata": { "id": "mpb_BydRaSoF" }, "source": [ "#### More readable but potentially lower model accuracy" ] }, { "cell_type": "markdown", "metadata": { "id": "2vqXeYffzSHp" }, "source": [ "This approach is not compatible with fine-tuning with pruning, which is why it may be less accurate than the examples above, which\n", "support fine-tuning.\n", "\n", "Although `prune_low_magnitude` can be applied while defining the initial model, loading pretrained weights afterwards does not work in the examples below." 
] }, { "cell_type": "markdown", "metadata": { "id": "s5p5jvH5KznJ" }, "source": [ "**Functional example**" ] }, { "cell_type": "code", "execution_count": 6, "metadata": { "execution": { "iopub.execute_input": "2025-06-21T11:33:08.083799Z", "iopub.status.busy": "2025-06-21T11:33:08.083247Z", "iopub.status.idle": "2025-06-21T11:33:08.120759Z", "shell.execute_reply": "2025-06-21T11:33:08.120162Z" }, "id": "7Wow55hg5oiM" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Model: \"model\"\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "_________________________________________________________________\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " Layer (type) Output Shape Param # \n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "=================================================================\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " input_1 (InputLayer) [(None, 20)] 0 \n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " \n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " prune_low_magnitude_dense_ (None, 10) 412 \n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " 4 (PruneLowMagnitude) \n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " \n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " flatten_4 (Flatten) (None, 10) 0 \n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " \n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "=================================================================\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Total params: 412 (1.61 KB)\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Trainable params: 210 (840.00 Byte)\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Non-trainable params: 202 (812.00 Byte)\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ 
"_________________________________________________________________\n" ] } ], "source": [ "# Use `prune_low_magnitude` to make the `Dense` layer train with pruning.\n", "i = keras.Input(shape=(20,))\n", "x = tfmot.sparsity.keras.prune_low_magnitude(keras.layers.Dense(10))(i)\n", "o = keras.layers.Flatten()(x)\n", "model_for_pruning = keras.Model(inputs=i, outputs=o)\n", "\n", "model_for_pruning.summary()" ] }, { "cell_type": "markdown", "metadata": { "id": "wIGj-r2of2ls" }, "source": [ "**Sequential example**\n" ] }, { "cell_type": "code", "execution_count": 7, "metadata": { "execution": { "iopub.execute_input": "2025-06-21T11:33:08.123829Z", "iopub.status.busy": "2025-06-21T11:33:08.123595Z", "iopub.status.idle": "2025-06-21T11:33:08.157995Z", "shell.execute_reply": "2025-06-21T11:33:08.157401Z" }, "id": "mQOiDUGgfi4y" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Model: \"sequential_4\"\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "_________________________________________________________________\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " Layer (type) Output Shape Param # \n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "=================================================================\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " prune_low_magnitude_dense_ (None, 20) 822 \n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " 5 (PruneLowMagnitude) \n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " \n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " flatten_5 (Flatten) (None, 20) 0 \n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " \n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "=================================================================\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Total params: 822 (3.21 KB)\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ 
"Trainable params: 420 (1.64 KB)\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Non-trainable params: 402 (1.57 KB)\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "_________________________________________________________________\n" ] } ], "source": [ "# Use `prune_low_magnitude` to make the `Dense` layer train with pruning.\n", "model_for_pruning = keras.Sequential([\n", " tfmot.sparsity.keras.prune_low_magnitude(keras.layers.Dense(20, input_shape=input_shape)),\n", " keras.layers.Flatten()\n", "])\n", "\n", "model_for_pruning.summary()" ] }, { "cell_type": "markdown", "metadata": { "id": "vnMguvVSnUqD" }, "source": [ "### Prune custom Keras layer or modify parts of layer to prune" ] }, { "cell_type": "markdown", "metadata": { "id": "BLgH1aFMjTK4" }, "source": [ "**Common mistake:** pruning the bias usually harms model accuracy too much.\n", "\n", "`tfmot.sparsity.keras.PrunableLayer` serves two use cases:\n", "1. Prune a custom Keras layer.\n", "2. Modify which parts of a built-in Keras layer are pruned.\n", "\n", "For example, by default the API prunes only the kernel of the\n", "`Dense` layer. 
The example below prunes the bias also.\n" ] }, { "cell_type": "code", "execution_count": 8, "metadata": { "execution": { "iopub.execute_input": "2025-06-21T11:33:08.161029Z", "iopub.status.busy": "2025-06-21T11:33:08.160794Z", "iopub.status.idle": "2025-06-21T11:33:08.202968Z", "shell.execute_reply": "2025-06-21T11:33:08.202361Z" }, "id": "77jgBjccnTh6" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Model: \"sequential_5\"\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "_________________________________________________________________\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " Layer (type) Output Shape Param # \n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "=================================================================\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " prune_low_magnitude_my_den (None, 20) 843 \n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " se_layer (PruneLowMagnitud \n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " e) \n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " \n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " flatten_6 (Flatten) (None, 20) 0 \n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " \n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "=================================================================\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Total params: 843 (3.30 KB)\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Trainable params: 420 (1.64 KB)\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Non-trainable params: 423 (1.66 KB)\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "_________________________________________________________________\n" ] } ], "source": [ "class MyDenseLayer(keras.layers.Dense, tfmot.sparsity.keras.PrunableLayer):\n", "\n", " def 
get_prunable_weights(self):\n", "  # Prune the bias too, though that usually harms model accuracy too much.\n", "  return [self.kernel, self.bias]\n", "\n", "# Use `prune_low_magnitude` to make the `MyDenseLayer` layer train with pruning.\n", "model_for_pruning = keras.Sequential([\n", " tfmot.sparsity.keras.prune_low_magnitude(MyDenseLayer(20, input_shape=input_shape)),\n", " keras.layers.Flatten()\n", "])\n", "\n", "model_for_pruning.summary()\n" ] }, { "cell_type": "markdown", "metadata": { "id": "itAyTyzvRroH" }, "source": [ "## Train model" ] }, { "cell_type": "markdown", "metadata": { "id": "y4hnWH2NY5MO" }, "source": [ "### Model.fit" ] }, { "cell_type": "markdown", "metadata": { "id": "_LYCDIunTE9B" }, "source": [ "Pass the `tfmot.sparsity.keras.UpdatePruningStep` callback to `fit` during training.\n", "\n", "To help debug training, also pass the `tfmot.sparsity.keras.PruningSummaries` callback." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "fKZ2PxcpY_WV" }, "outputs": [], "source": [ "# Define the model.\n", "base_model = setup_model()\n", "base_model.load_weights(pretrained_weights) # optional but recommended for model accuracy\n", "model_for_pruning = tfmot.sparsity.keras.prune_low_magnitude(base_model)\n", "\n", "log_dir = tempfile.mkdtemp()\n", "callbacks = [\n", " tfmot.sparsity.keras.UpdatePruningStep(),\n", " # Log sparsity and other metrics in TensorBoard.\n", " tfmot.sparsity.keras.PruningSummaries(log_dir=log_dir)\n", "]\n", "\n", "model_for_pruning.compile(\n", " loss=keras.losses.categorical_crossentropy,\n", " optimizer='adam',\n", " metrics=['accuracy']\n", ")\n", "\n", "model_for_pruning.fit(\n", " x_train,\n", " y_train,\n", " callbacks=callbacks,\n", " epochs=2,\n", ")\n", "\n", "#docs_infra: no_execute\n", "%tensorboard --logdir={log_dir}" ] }, { "cell_type": "markdown", "metadata": { "id": "6kcuGmf5MSnJ" }, "source": [ "For non-Colab users, you can see [the results of a previous 
run](https://tensorboard.dev/experiment/XiNXEBjHQ3Oabc6jRLKiXQ/#scalars&_smoothingWeight=0) of this code block on [TensorBoard.dev](https://tensorboard.dev/)." ] }, { "cell_type": "markdown", "metadata": { "id": "pDcSvbNdZA-1" }, "source": [ "### Custom training loop" ] }, { "cell_type": "markdown", "metadata": { "id": "uQA8GaD6T3-o" }, "source": [ "Call the hooks of the `tfmot.sparsity.keras.UpdatePruningStep` callback manually at the corresponding points of your training loop.\n", "\n", "To help debug training, use the `tfmot.sparsity.keras.PruningSummaries` callback in the same way." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "hPQUrkodbIF2" }, "outputs": [], "source": [ "# Define the model.\n", "base_model = setup_model()\n", "base_model.load_weights(pretrained_weights) # optional but recommended for model accuracy\n", "model_for_pruning = tfmot.sparsity.keras.prune_low_magnitude(base_model)\n", "\n", "# Boilerplate\n", "loss = keras.losses.categorical_crossentropy\n", "optimizer = keras.optimizers.Adam()\n", "log_dir = tempfile.mkdtemp()\n", "unused_arg = -1\n", "epochs = 2\n", "batches = 1 # example is hardcoded so that the number of batches cannot change.\n", "\n", "# Non-boilerplate.\n", "model_for_pruning.optimizer = optimizer\n", "step_callback = tfmot.sparsity.keras.UpdatePruningStep()\n", "step_callback.set_model(model_for_pruning)\n", "log_callback = tfmot.sparsity.keras.PruningSummaries(log_dir=log_dir) # Log sparsity and other metrics in TensorBoard.\n", "log_callback.set_model(model_for_pruning)\n", "\n", "step_callback.on_train_begin() # run pruning callback\n", "for _ in range(epochs):\n", "  log_callback.on_epoch_begin(epoch=unused_arg) # run pruning callback\n", "  for _ in range(batches):\n", "    step_callback.on_train_batch_begin(batch=unused_arg) # run pruning callback\n", "\n", "    with tf.GradientTape() as tape:\n", "      logits = model_for_pruning(x_train, training=True)\n", "      loss_value = loss(y_train, logits)\n", "    grads = tape.gradient(loss_value, model_for_pruning.trainable_variables)\n", "    optimizer.apply_gradients(zip(grads, model_for_pruning.trainable_variables))\n", "\n", "  step_callback.on_epoch_end(batch=unused_arg) # run pruning callback\n", "\n", "#docs_infra: no_execute\n", "%tensorboard --logdir={log_dir}" ] }, { "cell_type": "markdown", "metadata": { "id": "vh4lJt4zMh1v" }, "source": [ "For non-Colab users, you can see [the results of a previous run](https://tensorboard.dev/experiment/jDeGzF3xQeSyb7Qir1ZcBQ/#scalars&_smoothingWeight=0) of this code block on [TensorBoard.dev](https://tensorboard.dev/)." ] }, { "cell_type": "markdown", "metadata": { "id": "o8H-8lQ-cPa-" }, "source": [ "### Improve pruned model accuracy\n" ] }, { "cell_type": "markdown", "metadata": { "id": "Y2t4fYXvAV1V" }, "source": [ "First, look at the `tfmot.sparsity.keras.prune_low_magnitude` API docs\n", "to understand what a pruning schedule is and the math of\n", "each type of pruning schedule.\n", "\n", "**Tips**:\n", "\n", "* Use a learning rate that is neither too high nor too low while the model is pruning. Consider the [pruning schedule](https://www.tensorflow.org/model_optimization/api_docs/python/tfmot/sparsity/keras/PruningSchedule) to be a hyperparameter.\n", "\n", "* As a quick test, try pruning a model to the final sparsity at the beginning of training by setting `begin_step` to 0 with a `tfmot.sparsity.keras.ConstantSparsity` schedule. This sometimes gives good results.\n", "\n", "* Do not prune too frequently, so that the model has time to recover between pruning steps. The [pruning schedule](https://www.tensorflow.org/model_optimization/api_docs/python/tfmot/sparsity/keras/PruningSchedule) provides a reasonable default frequency.\n", "\n", "* For general ideas to improve model accuracy, look for tips for your use case(s) under \"Define model\"." 
] }, { "cell_type": "markdown", "metadata": { "id": "MpvX5IqahV1r" }, "source": [ "## Checkpoint and deserialize" ] }, { "cell_type": "markdown", "metadata": { "id": "GuZ5wlij1dcJ" }, "source": [ "You must preserve the optimizer step during checkpointing. This means that while you can use Keras HDF5 models for checkpointing, you cannot use Keras HDF5 weights, because a weights-only file does not store the optimizer state." ] }, { "cell_type": "code", "execution_count": 9, "metadata": { "execution": { "iopub.execute_input": "2025-06-21T11:33:08.207501Z", "iopub.status.busy": "2025-06-21T11:33:08.207153Z", "iopub.status.idle": "2025-06-21T11:33:08.268919Z", "shell.execute_reply": "2025-06-21T11:33:08.268244Z" }, "id": "6khQg-q7imfH" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "WARNING:tensorflow:Detecting that an object or model or tf.train.Checkpoint is being deleted with unrestored values. See the following logs for the specific values in question. To silence these warnings, use `status.expect_partial()`. See https://www.tensorflow.org/api_docs/python/tf/train/Checkpoint#restorefor details about the status object returned by the restore function.\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "WARNING:tensorflow:Detecting that an object or model or tf.train.Checkpoint is being deleted with unrestored values. See the following logs for the specific values in question. To silence these warnings, use `status.expect_partial()`. 
See https://www.tensorflow.org/api_docs/python/tf/train/Checkpoint#restorefor details about the status object returned by the restore function.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer._iterations\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer._iterations\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer._learning_rate\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer._learning_rate\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer._variables.1\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer._variables.1\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer._variables.2\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer._variables.2\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer._variables.3\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer._variables.3\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "WARNING:tensorflow:Value in 
checkpoint could not be found in the restored object: (root).optimizer._variables.4\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer._variables.4\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "WARNING:tensorflow:Compiled the loaded model, but the compiled metrics have yet to be built. `model.compile_metrics` will be empty until you train or evaluate the model.\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tf_keras/src/engine/training.py:3098: UserWarning: You are saving your model as an HDF5 file via `model.save()`. This file format is considered legacy. We recommend using instead the native TF-Keras format, e.g. `model.save('my_model.keras')`.\n", " saving_api.save_model(\n", "WARNING:tensorflow:Compiled the loaded model, but the compiled metrics have yet to be built. `model.compile_metrics` will be empty until you train or evaluate the model.\n" ] } ], "source": [ "# Define the model.\n", "base_model = setup_model()\n", "base_model.load_weights(pretrained_weights) # optional but recommended for model accuracy\n", "model_for_pruning = tfmot.sparsity.keras.prune_low_magnitude(base_model)\n", "\n", "_, keras_model_file = tempfile.mkstemp('.h5')\n", "\n", "# Checkpoint: saving the optimizer is necessary (include_optimizer=True is the default).\n", "model_for_pruning.save(keras_model_file, include_optimizer=True)" ] }, { "cell_type": "markdown", "metadata": { "id": "H-CLxooLYnRN" }, "source": [ "The above applies generally. 
The code below is only needed for the HDF5 model format (not HDF5 weights and other formats).\n" ] }, { "cell_type": "code", "execution_count": 10, "metadata": { "execution": { "iopub.execute_input": "2025-06-21T11:33:08.271992Z", "iopub.status.busy": "2025-06-21T11:33:08.271749Z", "iopub.status.idle": "2025-06-21T11:33:08.322637Z", "shell.execute_reply": "2025-06-21T11:33:08.321934Z" }, "id": "2nGC1hZnYlzb" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "WARNING:tensorflow:No training configuration found in the save file, so the model was *not* compiled. Compile it manually.\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "WARNING:tensorflow:No training configuration found in the save file, so the model was *not* compiled. Compile it manually.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Model: \"sequential_6\"\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "_________________________________________________________________\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " Layer (type) Output Shape Param # \n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "=================================================================\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " prune_low_magnitude_dense_ (None, 20) 822 \n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " 6 (PruneLowMagnitude) \n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " \n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " prune_low_magnitude_flatte (None, 20) 1 \n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " n_7 (PruneLowMagnitude) \n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " \n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "=================================================================\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Total params: 823 (3.22 KB)\n" ] }, { 
"name": "stdout", "output_type": "stream", "text": [ "Trainable params: 420 (1.64 KB)\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Non-trainable params: 403 (1.58 KB)\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "_________________________________________________________________\n" ] } ], "source": [ "# Deserialize model.\n", "with tfmot.sparsity.keras.prune_scope():\n", " loaded_model = keras.models.load_model(keras_model_file)\n", "\n", "loaded_model.summary()" ] }, { "cell_type": "markdown", "metadata": { "id": "jew8M217SgQw" }, "source": [ "## Deploy pruned model" ] }, { "cell_type": "markdown", "metadata": { "id": "2uj4SfF1cnTR" }, "source": [ "### Export model with size compression" ] }, { "cell_type": "markdown", "metadata": { "id": "57uNm47L4Yro" }, "source": [ "**Common mistake**: both `strip_pruning` and applying a standard compression algorithm (e.g. via gzip) are necessary to see the compression\n", "benefits of pruning." ] }, { "cell_type": "code", "execution_count": 11, "metadata": { "execution": { "iopub.execute_input": "2025-06-21T11:33:08.326978Z", "iopub.status.busy": "2025-06-21T11:33:08.326488Z", "iopub.status.idle": "2025-06-21T11:33:08.423473Z", "shell.execute_reply": "2025-06-21T11:33:08.422763Z" }, "id": "EZ3TD8cYkxZM" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "WARNING:tensorflow:Detecting that an object or model or tf.train.Checkpoint is being deleted with unrestored values. See the following logs for the specific values in question. To silence these warnings, use `status.expect_partial()`. See https://www.tensorflow.org/api_docs/python/tf/train/Checkpoint#restorefor details about the status object returned by the restore function.\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "WARNING:tensorflow:Detecting that an object or model or tf.train.Checkpoint is being deleted with unrestored values. See the following logs for the specific values in question. 
To silence these warnings, use `status.expect_partial()`. See https://www.tensorflow.org/api_docs/python/tf/train/Checkpoint#restorefor details about the status object returned by the restore function.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer._iterations\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer._iterations\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer._learning_rate\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer._learning_rate\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer._variables.1\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer._variables.1\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer._variables.2\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer._variables.2\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer._variables.3\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer._variables.3\n" ] }, { "name": "stdout", 
"output_type": "stream", "text": [ "WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer._variables.4\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer._variables.4\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "final model" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\n", "Model: \"sequential_7\"\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "_________________________________________________________________\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " Layer (type) Output Shape Param # \n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "=================================================================\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " dense_7 (Dense) (None, 20) 420 \n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " \n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " flatten_8 (Flatten) (None, 20) 0 \n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " \n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "=================================================================\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Total params: 420 (1.64 KB)\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Trainable params: 420 (1.64 KB)\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Non-trainable params: 0 (0.00 Byte)\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "_________________________________________________________________\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\n", "\n", "WARNING:tensorflow:Compiled the loaded model, but the compiled metrics have yet to be built. 
`model.compile_metrics` will be empty until you train or evaluate the model.\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "WARNING:tensorflow:Compiled the loaded model, but the compiled metrics have yet to be built. `model.compile_metrics` will be empty until you train or evaluate the model.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Size of gzipped pruned model without stripping: 3447.00 bytes\n", "WARNING:tensorflow:Compiled the loaded model, but the compiled metrics have yet to be built. `model.compile_metrics` will be empty until you train or evaluate the model.\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "WARNING:tensorflow:Compiled the loaded model, but the compiled metrics have yet to be built. `model.compile_metrics` will be empty until you train or evaluate the model.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Size of gzipped pruned model with stripping: 2934.00 bytes\n" ] } ], "source": [ "# Define the model.\n", "base_model = setup_model()\n", "base_model.load_weights(pretrained_weights) # optional but recommended for model accuracy\n", "model_for_pruning = tfmot.sparsity.keras.prune_low_magnitude(base_model)\n", "\n", "# Typically you train the model here.\n", "\n", "model_for_export = tfmot.sparsity.keras.strip_pruning(model_for_pruning)\n", "\n", "print(\"final model\")\n", "model_for_export.summary()\n", "\n", "print(\"\\n\")\n", "print(\"Size of gzipped pruned model without stripping: %.2f bytes\" % (get_gzipped_model_size(model_for_pruning)))\n", "print(\"Size of gzipped pruned model with stripping: %.2f bytes\" % (get_gzipped_model_size(model_for_export)))" ] }, { "cell_type": "markdown", "metadata": { "id": "qPXvYIHOctem" }, "source": [ "### Hardware-specific optimizations" ] }, { "cell_type": "markdown", "metadata": { "id": "yqk0jI49c1mw" }, "source": [ "Once different backends [enable pruning to improve 
latency](https://github.com/tensorflow/model-optimization/issues/173), using block sparsity can improve latency for certain hardware.\n", "\n", "Increasing the block size will decrease the peak sparsity that's achievable for a target model accuracy. Despite this, latency can still improve.\n", "\n", "For details on what's supported for block sparsity, see\n", "the `tfmot.sparsity.keras.prune_low_magnitude` API docs." ] }, { "cell_type": "code", "execution_count": 12, "metadata": { "execution": { "iopub.execute_input": "2025-06-21T11:33:08.426676Z", "iopub.status.busy": "2025-06-21T11:33:08.426428Z", "iopub.status.idle": "2025-06-21T11:33:08.486953Z", "shell.execute_reply": "2025-06-21T11:33:08.486318Z" }, "id": "xedaVDeFc0bw" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "WARNING:tensorflow:Detecting that an object or model or tf.train.Checkpoint is being deleted with unrestored values. See the following logs for the specific values in question. To silence these warnings, use `status.expect_partial()`. See https://www.tensorflow.org/api_docs/python/tf/train/Checkpoint#restorefor details about the status object returned by the restore function.\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "WARNING:tensorflow:Detecting that an object or model or tf.train.Checkpoint is being deleted with unrestored values. See the following logs for the specific values in question. To silence these warnings, use `status.expect_partial()`. 
See https://www.tensorflow.org/api_docs/python/tf/train/Checkpoint#restorefor details about the status object returned by the restore function.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer._iterations\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer._iterations\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer._learning_rate\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer._learning_rate\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer._variables.1\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer._variables.1\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer._variables.2\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer._variables.2\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer._variables.3\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer._variables.3\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "WARNING:tensorflow:Value in 
checkpoint could not be found in the restored object: (root).optimizer._variables.4\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer._variables.4\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Model: \"sequential_8\"\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "_________________________________________________________________\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " Layer (type) Output Shape Param # \n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "=================================================================\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " prune_low_magnitude_dense_ (None, 20) 822 \n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " 8 (PruneLowMagnitude) \n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " \n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " prune_low_magnitude_flatte (None, 20) 1 \n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " n_9 (PruneLowMagnitude) \n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " \n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "=================================================================\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Total params: 823 (3.22 KB)\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Trainable params: 420 (1.64 KB)\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Non-trainable params: 403 (1.58 KB)\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "_________________________________________________________________\n" ] } ], "source": [ "base_model = setup_model()\n", "\n", "# For using intrinsics on a CPU with 128-bit registers, together with 8-bit\n", "# quantized weights, a 1x16 block size is nice because the block 
perfectly\n", "# fits into the register (16 weights x 8 bits = 128 bits).\n", "pruning_params = {'block_size': [1, 16]}\n", "model_for_pruning = tfmot.sparsity.keras.prune_low_magnitude(base_model, **pruning_params)\n", "\n", "model_for_pruning.summary()" ] } ], "metadata": { "colab": { "collapsed_sections": [ "Tce3stUlHN0L" ], "name": "comprehensive_guide.ipynb", "toc_visible": true }, "kernelspec": { "display_name": "Python 3", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.23" } }, "nbformat": 4, "nbformat_minor": 0 }