{ "cells": [ { "cell_type": "markdown", "metadata": { "id": "Tce3stUlHN0L" }, "source": [ "##### Copyright 2021 The TensorFlow Authors." ] }, { "cell_type": "code", "execution_count": 1, "metadata": { "cellView": "form", "execution": { "iopub.execute_input": "2025-06-21T11:29:46.632121Z", "iopub.status.busy": "2025-06-21T11:29:46.631883Z", "iopub.status.idle": "2025-06-21T11:29:46.635903Z", "shell.execute_reply": "2025-06-21T11:29:46.635274Z" }, "id": "IcfrhafzkZbH" }, "outputs": [], "source": [ "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n", "# you may not use this file except in compliance with the License.\n", "# You may obtain a copy of the License at\n", "#\n", "# https://www.apache.org/licenses/LICENSE-2.0\n", "#\n", "# Unless required by applicable law or agreed to in writing, software\n", "# distributed under the License is distributed on an \"AS IS\" BASIS,\n", "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", "# See the License for the specific language governing permissions and\n", "# limitations under the License." ] }, { "cell_type": "markdown", "metadata": { "id": "qFdPvlXBOdUN" }, "source": [ "# Sparse weights using structural pruning" ] }, { "cell_type": "markdown", "metadata": { "id": "FbORZA_bQx1G" }, "source": [ "Structurally pruning the weights of your model so that they are sparse in a specific pattern can reduce model inference time on hardware that supports that pattern.\n", "\n", "This tutorial shows you how to:\n", "* Define and train a model on the MNIST dataset with a specific structural sparsity\n", "* Convert the pruned model to TensorFlow Lite format\n", "* Visualize the structure of the pruned weights\n", "\n", "For a general overview of the pruning technique for model optimization, see the [pruning overview](https://www.tensorflow.org/model_optimization/guide/pruning). For a tutorial on general weight pruning, see [Pruning in Keras](https://www.tensorflow.org/model_optimization/guide/pruning/pruning_with_keras)." ] }, { "cell_type": "markdown", "metadata": { "id": "0f4SoBcoXNcb" }, "source": [ "## Structural pruning of weights" ] }, { "cell_type": "markdown", "metadata": { "id": "rn_a9362Wr_B" }, "source": [ "Structural pruning systematically zeroes out model weights at the beginning of the training process. You apply this pruning technique to regular blocks of weights to speed up inference on supporting hardware. For example, you can group the weights in the model into blocks of four and zero out two of the weights in each block, which is known as a _2 by 4_ reduction. This technique applies only to the last dimension of the weight tensor in the model that is converted by TensorFlow Lite. For example, `Conv2D` layer weights in TensorFlow Lite have the structure `[channel_out, height, width, channel_in]` and `Dense` layer weights have the structure `[channel_out, channel_in]`. 
The sparsity pattern is applied to the weights in the last dimension: `channel_in`.\n", "\n", "Compared with random sparsity, structured sparsity generally yields lower accuracy because of its restrictive structure. However, it can significantly reduce inference time on supported hardware.\n", "\n", "Pruning can be combined with other model compression techniques for a better compression rate. See the quantization and clustering examples in [collaborative optimization techniques](https://blog.tensorflow.org/2021/10/Collaborative-Optimizations.html) for more details." ] }, { "cell_type": "markdown", "metadata": { "id": "nuABqZnXVDvO" }, "source": [ "## Setup" ] }, { "cell_type": "markdown", "metadata": { "id": "u9mRDekZEfnR" }, "source": [ "Prepare your development environment and data." ] }, { "cell_type": "code", "execution_count": 2, "metadata": { "cellView": "both", "execution": { "iopub.execute_input": "2025-06-21T11:29:46.639373Z", "iopub.status.busy": "2025-06-21T11:29:46.639126Z", "iopub.status.idle": "2025-06-21T11:29:52.606499Z", "shell.execute_reply": "2025-06-21T11:29:52.605201Z" }, "id": "lvpH1Hg7ULFz" }, "outputs": [], "source": [ "! pip install -q tensorflow\n", "! pip install -q tensorflow-model-optimization\n", "! 
pip install -q matplotlib" ] }, { "cell_type": "code", "execution_count": 3, "metadata": { "execution": { "iopub.execute_input": "2025-06-21T11:29:52.610242Z", "iopub.status.busy": "2025-06-21T11:29:52.609924Z", "iopub.status.idle": "2025-06-21T11:29:56.182292Z", "shell.execute_reply": "2025-06-21T11:29:56.181261Z" }, "id": "_hn5e5_gWr_E" }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "2025-06-21 11:29:52.922961: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:467] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered\n", "WARNING: All log messages before absl::InitializeLog() is called are written to STDERR\n", "E0000 00:00:1750505392.944354 24882 cuda_dnn.cc:8579] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered\n", "E0000 00:00:1750505392.950893 24882 cuda_blas.cc:1407] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered\n", "W0000 00:00:1750505392.968012 24882 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.\n", "W0000 00:00:1750505392.968035 24882 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.\n", "W0000 00:00:1750505392.968037 24882 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.\n", "W0000 00:00:1750505392.968040 24882 computation_placer.cc:177] computation placer already registered. 
Please check linkage and avoid linking the same target more than once.\n" ] } ], "source": [ "import tensorflow as tf\n", "from tensorflow import keras\n", "\n", "import tensorflow_model_optimization as tfmot\n", "prune_low_magnitude = tfmot.sparsity.keras.prune_low_magnitude" ] }, { "cell_type": "markdown", "metadata": { "id": "TZyLYFTER4aP" }, "source": [ "## Download and normalize image data from the [MNIST](https://www.tensorflow.org/datasets/catalog/mnist) dataset" ] }, { "cell_type": "code", "execution_count": 4, "metadata": { "execution": { "iopub.execute_input": "2025-06-21T11:29:56.186104Z", "iopub.status.busy": "2025-06-21T11:29:56.185626Z", "iopub.status.idle": "2025-06-21T11:29:56.634263Z", "shell.execute_reply": "2025-06-21T11:29:56.633219Z" }, "id": "hSf4jYKGWr_E" }, "outputs": [], "source": [ "# Load MNIST dataset.\n", "mnist = keras.datasets.mnist\n", "(train_images, train_labels), (test_images, test_labels) = mnist.load_data()\n", "\n", "# Normalize the input image so that each pixel value is between 0 and 1.\n", "train_images = train_images / 255.0\n", "test_images = test_images / 255.0" ] }, { "cell_type": "markdown", "metadata": { "id": "LKaL3XH1XO0Q" }, "source": [ "## Define structural pruning parameters" ] }, { "cell_type": "markdown", "metadata": { "id": "s9_33ta-Wr_E" }, "source": [ "Define parameters for pruning and specify the type of structural pruning. Set the parameters for pruning to `(2, 4)`.\n", "These settings mean that in a block of four elements, at least two with the lowest magnitude are set to zero.\n", "\n", "You don't have to set the `pruning_schedule` parameter. By default, the pruning mask is defined at the first step and it is not updated during the training." 
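] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The next cell is a minimal NumPy sketch that makes the `(2, 4)` rule concrete. It shows magnitude-based 2 by 4 masking, but it is not the toolkit's implementation. It assumes the last dimension is a multiple of four, splits that dimension into blocks of four, and zeroes the two smallest-magnitude values in each block. The helper `m_by_n_mask` and the sample array `w` are hypothetical and exist only for illustration." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import numpy as np\n", "\n", "def m_by_n_mask(weights, m=2, n=4):\n", "    # Hypothetical helper: in every block of n consecutive values along the\n", "    # last axis, mark the m smallest magnitudes as pruned (False).\n", "    if weights.shape[-1] % n != 0:\n", "        raise ValueError('The last dimension must be a multiple of n.')\n", "    blocks = weights.reshape(-1, n)\n", "    # Indices of the m smallest magnitudes in each block.\n", "    smallest = np.argsort(np.abs(blocks), axis=1)[:, :m]\n", "    mask = np.ones_like(blocks, dtype=bool)\n", "    np.put_along_axis(mask, smallest, False, axis=1)\n", "    return mask.reshape(weights.shape)\n", "\n", "w = np.array([[0.9, -0.1, 0.05, -0.7, 0.2, 0.3, -0.4, 0.01]])\n", "# Keep the two largest-magnitude values in each block of four and zero the rest.\n", "print(np.where(m_by_n_mask(w), w, 0.0))"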
] }, { "cell_type": "code", "execution_count": 5, "metadata": { "execution": { "iopub.execute_input": "2025-06-21T11:29:56.638336Z", "iopub.status.busy": "2025-06-21T11:29:56.638015Z", "iopub.status.idle": "2025-06-21T11:29:56.642139Z", "shell.execute_reply": "2025-06-21T11:29:56.641329Z" }, "id": "1EXNYAPJWr_F" }, "outputs": [], "source": [ "pruning_params_2_by_4 = {\n", " 'sparsity_m_by_n': (2, 4),\n", "}" ] }, { "cell_type": "markdown", "metadata": { "id": "mMKdsdAUWr_F" }, "source": [ "Define parameters for random pruning with a target sparsity of 50%." ] }, { "cell_type": "code", "execution_count": 6, "metadata": { "execution": { "iopub.execute_input": "2025-06-21T11:29:56.645408Z", "iopub.status.busy": "2025-06-21T11:29:56.644679Z", "iopub.status.idle": "2025-06-21T11:29:56.649172Z", "shell.execute_reply": "2025-06-21T11:29:56.648310Z" }, "id": "un24AZUOWr_F" }, "outputs": [], "source": [ "pruning_params_sparsity_0_5 = {\n", " 'pruning_schedule': tfmot.sparsity.keras.ConstantSparsity(target_sparsity=0.5,\n", " begin_step=0,\n", " frequency=100)\n", "}" ] }, { "cell_type": "markdown", "metadata": { "id": "jV4Yt0v5Wr_G" }, "source": [ "Define the model architecture and specify which layers to prune. Structural pruning is applied only to the layers of the model that you select.\n", "\n", "In the example below, we prune only some of the layers: the second `Conv2D` layer and the first `Dense` layer.\n", "\n", "Notice that the first `Conv2D` layer cannot be pruned structurally. To be pruned structurally, a layer needs more than one input channel, and this layer has only one because the MNIST images are grayscale. Instead, we prune the first `Conv2D` layer with random pruning."
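] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The following sketch shows one way to see why the first `Conv2D` layer is excluded. It builds zero-filled placeholder arrays, not the model's trained weights. These arrays use the Keras kernel layout `(height, width, channel_in, channel_out)` for both convolutions in this model. The sketch then transposes them to the TensorFlow Lite layout `[channel_out, height, width, channel_in]` described earlier. The first kernel ends with a single input channel, so its last dimension holds no block of four values. The second kernel ends with 32 input channels, which split evenly into blocks of four." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import numpy as np\n", "\n", "# Placeholder kernels in the Keras Conv2D layout: (height, width, channel_in, channel_out).\n", "kernels = {\n", "    'first Conv2D': np.zeros((5, 5, 1, 32)),\n", "    'second Conv2D': np.zeros((5, 5, 32, 64)),\n", "}\n", "\n", "for name, kernel in kernels.items():\n", "    # Reorder to the TensorFlow Lite layout [channel_out, height, width, channel_in].\n", "    tflite_kernel = np.transpose(kernel, (3, 0, 1, 2))\n", "    channel_in = tflite_kernel.shape[-1]\n", "    print(name, tflite_kernel.shape, 'blocks of four along channel_in:', channel_in // 4)"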
] }, { "cell_type": "code", "execution_count": 7, "metadata": { "execution": { "iopub.execute_input": "2025-06-21T11:29:56.652658Z", "iopub.status.busy": "2025-06-21T11:29:56.651860Z", "iopub.status.idle": "2025-06-21T11:29:57.415536Z", "shell.execute_reply": "2025-06-21T11:29:57.414581Z" }, "id": "BDGzC6YlWr_G" }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "2025-06-21 11:29:56.776140: E external/local_xla/xla/stream_executor/cuda/cuda_platform.cc:51] failed call to cuInit: INTERNAL: CUDA error: Failed call to cuInit: CUDA_ERROR_NO_DEVICE: no CUDA-capable device is detected\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Model: \"sequential\"\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "_________________________________________________________________\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " Layer (type) Output Shape Param # \n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "=================================================================\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " prune_low_magnitude_prunin (None, 28, 28, 32) 1634 \n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " g_sparsity_0_5 (PruneLowMa \n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " gnitude) \n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " \n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " max_pooling2d (MaxPooling2 (None, 14, 14, 32) 0 \n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " D) \n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " \n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " prune_low_magnitude_struct (None, 14, 14, 64) 102466 \n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " ural_pruning (PruneLowMagn \n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " itude) \n" ] }, { "name": "stdout", "output_type": "stream", "text": [ 
" \n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " batch_normalization (Batch (None, 14, 14, 64) 256 \n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " Normalization) \n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " \n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " re_lu (ReLU) (None, 14, 14, 64) 0 \n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " \n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " max_pooling2d_1 (MaxPoolin (None, 7, 7, 64) 0 \n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " g2D) \n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " \n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " flatten (Flatten) (None, 3136) 0 \n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " \n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " prune_low_magnitude_struct (None, 1024) 6423554 \n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " ural_pruning_dense (PruneL \n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " owMagnitude) \n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " \n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " dropout (Dropout) (None, 1024) 0 \n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " \n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " dense (Dense) (None, 10) 10250 \n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " \n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "=================================================================\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Total params: 6538160 (24.94 MB)\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Trainable params: 3274762 (12.49 MB)\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Non-trainable params: 3263398 (12.45 MB)\n" ] }, { "name": "stdout", 
"output_type": "stream", "text": [ "_________________________________________________________________\n" ] } ], "source": [ "model = keras.Sequential([\n", " prune_low_magnitude(\n", " keras.layers.Conv2D(\n", " 32, 5, padding='same', activation='relu',\n", " input_shape=(28, 28, 1),\n", " name=\"pruning_sparsity_0_5\"),\n", " **pruning_params_sparsity_0_5),\n", " keras.layers.MaxPooling2D((2, 2), (2, 2), padding='same'),\n", " prune_low_magnitude(\n", " keras.layers.Conv2D(\n", " 64, 5, padding='same',\n", " name=\"structural_pruning\"),\n", " **pruning_params_2_by_4),\n", " keras.layers.BatchNormalization(),\n", " keras.layers.ReLU(),\n", " keras.layers.MaxPooling2D((2, 2), (2, 2), padding='same'),\n", " keras.layers.Flatten(),\n", " prune_low_magnitude(\n", " keras.layers.Dense(\n", " 1024, activation='relu',\n", " name=\"structural_pruning_dense\"),\n", " **pruning_params_2_by_4),\n", " keras.layers.Dropout(0.4),\n", " keras.layers.Dense(10)\n", "])\n", "\n", "model.compile(optimizer='adam',\n", " loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),\n", " metrics=['accuracy'])\n", "\n", "model.summary()" ] }, { "cell_type": "markdown", "metadata": { "id": "U_ddzMppWr_G" }, "source": [ "Train and evaluate the model." 
] }, { "cell_type": "code", "execution_count": 8, "metadata": { "execution": { "iopub.execute_input": "2025-06-21T11:29:57.419028Z", "iopub.status.busy": "2025-06-21T11:29:57.418729Z", "iopub.status.idle": "2025-06-21T11:30:31.322479Z", "shell.execute_reply": "2025-06-21T11:30:31.321705Z" }, "id": "F4CnppA1Wr_H" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Pruned test accuracy: 0.9865999817848206\n" ] } ], "source": [ "batch_size = 128\n", "epochs = 2\n", "\n", "model.fit(\n", " train_images,\n", " train_labels,\n", " batch_size=batch_size,\n", " epochs=epochs,\n", " verbose=0,\n", " # UpdatePruningStep is required when training a model with pruning wrappers.\n", " callbacks=[tfmot.sparsity.keras.UpdatePruningStep()],\n", " validation_split=0.1)\n", "\n", "_, pruned_model_accuracy = model.evaluate(test_images, test_labels, verbose=0)\n", "print('Pruned test accuracy:', pruned_model_accuracy)" ] }, { "cell_type": "markdown", "metadata": { "id": "bA8EDPHMWr_H" }, "source": [ "Remove the pruning wrappers so that they are not included in the model when you convert it to TensorFlow Lite format."
] }, { "cell_type": "code", "execution_count": 9, "metadata": { "execution": { "iopub.execute_input": "2025-06-21T11:30:31.325952Z", "iopub.status.busy": "2025-06-21T11:30:31.325328Z", "iopub.status.idle": "2025-06-21T11:30:31.364954Z", "shell.execute_reply": "2025-06-21T11:30:31.364176Z" }, "id": "3wn-OQ_gWr_H" }, "outputs": [], "source": [ "model = tfmot.sparsity.keras.strip_pruning(model)" ] }, { "cell_type": "markdown", "metadata": { "id": "eM28m66YWr_H" }, "source": [ "## Convert the model to TensorFlow Lite format" ] }, { "cell_type": "code", "execution_count": 10, "metadata": { "execution": { "iopub.execute_input": "2025-06-21T11:30:31.367848Z", "iopub.status.busy": "2025-06-21T11:30:31.367614Z", "iopub.status.idle": "2025-06-21T11:30:32.800142Z", "shell.execute_reply": "2025-06-21T11:30:32.799284Z" }, "id": "EJ7DsA6-Wr_I" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Assets written to: /tmpfs/tmp/tmpvliqo0lw/assets\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "INFO:tensorflow:Assets written to: /tmpfs/tmp/tmpvliqo0lw/assets\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Saved converted pruned model to: /tmpfs/tmp/tmpr4fup215.tflite\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "WARNING: All log messages before absl::InitializeLog() is called are written to STDERR\n", "W0000 00:00:1750505432.319830 24882 tf_tfl_flatbuffer_helpers.cc:365] Ignored output_format.\n", "W0000 00:00:1750505432.319870 24882 tf_tfl_flatbuffer_helpers.cc:368] Ignored drop_control_dependency.\n", "I0000 00:00:1750505432.328462 24882 mlir_graph_optimization_pass.cc:425] MLIR V1 optimization pass is not enabled\n" ] } ], "source": [ "import tempfile\n", "\n", "converter = tf.lite.TFLiteConverter.from_keras_model(model)\n", "tflite_model = converter.convert()\n", "\n", "_, tflite_file = tempfile.mkstemp('.tflite')\n", "print('Saved converted pruned model to:', tflite_file)\n", "with open(tflite_file, 'wb') 
as f:\n", " f.write(tflite_model)" ] }, { "cell_type": "markdown", "metadata": { "id": "S44x9Rz3Wr_I" }, "source": [ "## Visualize and check weights" ] }, { "cell_type": "markdown", "metadata": { "id": "_CTu0wxFWr_J" }, "source": [ "Now visualize the structure of weights in the `Dense` layer pruned with 2 by 4 sparsity. Extract the weights from the tflite file." ] }, { "cell_type": "code", "execution_count": 11, "metadata": { "execution": { "iopub.execute_input": "2025-06-21T11:30:32.803495Z", "iopub.status.busy": "2025-06-21T11:30:32.803197Z", "iopub.status.idle": "2025-06-21T11:30:32.815115Z", "shell.execute_reply": "2025-06-21T11:30:32.814408Z" }, "id": "fOIp6QB5Wr_J" }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/lite/python/interpreter.py:457: UserWarning: Warning: tf.lite.Interpreter is deprecated and is scheduled for deletion in\n", " TF 2.20. Please use the LiteRT interpreter from the ai_edge_litert package.\n", " See the [migration guide](https://ai.google.dev/edge/litert/migration)\n", " for details.\n", " \n", " warnings.warn(_INTERPRETER_DELETION_WARNING)\n", "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/lite/python/interpreter.py:465: UserWarning: Warning: Enabling `experimental_preserve_all_tensors` with the BUILTIN or AUTO op resolver is intended for debugging purposes only. Be aware that this can significantly increase memory usage by storing all intermediate tensors. 
If you encounter memory problems or are not actively debugging, consider disabling this option.\n", " warnings.warn(\n", "INFO: Created TensorFlow Lite XNNPACK delegate for CPU.\n" ] } ], "source": [ "# Load the TensorFlow Lite file that contains the pruned model.\n", "interpreter = tf.lite.Interpreter(model_path=tflite_file, experimental_preserve_all_tensors=True)\n", "interpreter.allocate_tensors()\n", "\n", "details = interpreter.get_tensor_details()\n", "\n", "# Look up the weights of the pruned dense layer through the op graph. Matching on the\n", "# tensor name 'structural_pruning_dense/MatMul' can return the layer's output activations\n", "# instead of its weights. The first FULLY_CONNECTED op is the `structural_pruning_dense`\n", "# layer, and its second input is the weight tensor.\n", "op_details = interpreter._get_ops_details()\n", "fc_ops = [x for x in op_details if x[\"op_name\"] == \"FULLY_CONNECTED\"]\n", "tensor_data = interpreter.tensor(fc_ops[0][\"inputs\"][1])()" ] }, { "cell_type": "markdown", "metadata": { "id": "yy0jTs_QWr_K" }, "source": [ "To verify that you selected the weights of the pruned layer, print the shape of the weight tensor. TensorFlow Lite stores `Dense` weights as `[channel_out, channel_in]`. For this layer the shape should therefore be `(1024, 3136)`: 1024 units and 3136 flattened input features." ] }, { "cell_type": "code", "execution_count": 12, "metadata": { "execution": { "iopub.execute_input": "2025-06-21T11:30:32.818290Z", "iopub.status.busy": "2025-06-21T11:30:32.817912Z", "iopub.status.idle": "2025-06-21T11:30:32.821529Z", "shell.execute_reply": "2025-06-21T11:30:32.820923Z" }, "id": "mCDkwMUPWr_K" }, "outputs": [], "source": [ "print(f\"Shape of the Dense layer weight tensor is {tensor_data.shape}\")" ] }, { "cell_type": "markdown", "metadata": { "id": "mvYTILeUWr_K" }, "source": [ "Now visualize the structure of a small subset of the weight tensor. The weight tensor is sparse in its last dimension and follows the `(2, 4)` pattern: in every block of four consecutive elements, at least two are zero. To make the visualization clearer, replace all non-zero values with ones."
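] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Besides inspecting a plot, you can check the pattern numerically. The next cell is a minimal sketch built around a hypothetical helper, `is_m_by_n_sparse`, which is not part of the toolkit. The helper assumes that pruned weights are stored as exact zeros. It reports whether every block of `n` consecutive values along the last dimension contains at least `m` zeros. On the two small sample arrays, it prints `True` and then `False`. You can also pass it a weight array read from the interpreter." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import numpy as np\n", "\n", "def is_m_by_n_sparse(weights, m=2, n=4):\n", "    # Hypothetical helper: True if each block of n consecutive values along the\n", "    # last axis holds at least m exact zeros.\n", "    weights = np.asarray(weights)\n", "    if weights.shape[-1] % n != 0:\n", "        return False\n", "    blocks = weights.reshape(-1, n)\n", "    zeros_per_block = np.sum(blocks == 0, axis=1)\n", "    return bool(np.all(zeros_per_block >= m))\n", "\n", "# Both blocks of four contain two zeros.\n", "print(is_m_by_n_sparse(np.array([[0.5, 0.0, 0.0, -1.2, 0.0, 0.3, 0.0, 0.7]])))\n", "# The only block contains a single zero.\n", "print(is_m_by_n_sparse(np.array([[0.5, 0.1, 0.0, -1.2]])))"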
] }, { "cell_type": "code", "execution_count": 13, "metadata": { "execution": { "iopub.execute_input": "2025-06-21T11:30:32.824319Z", "iopub.status.busy": "2025-06-21T11:30:32.823926Z", "iopub.status.idle": "2025-06-21T11:30:32.828119Z", "shell.execute_reply": "2025-06-21T11:30:32.827560Z" }, "id": "WZfn34bRWr_K" }, "outputs": [], "source": [ "import matplotlib.pyplot as plt\n", "import numpy as np\n", "\n", "# The value 24 is chosen for convenience.\n", "width = height = 24\n", "\n", "subset_values_to_display = tensor_data[0:height, 0:width]\n", "\n", "# Mark non-zero weights with 1 and pruned (zero) weights with 0.\n", "val_ones = np.ones([height, width])\n", "val_zeros = np.zeros([height, width])\n", "subset_values_to_display = np.where(abs(subset_values_to_display) > 0, val_ones, val_zeros)" ] }, { "cell_type": "markdown", "metadata": { "id": "fOfWvKwKWr_L" }, "source": [ "Define an auxiliary function that draws separation lines between rows and between blocks of four elements, so that the structure is easy to see." ] }, { "cell_type": "code", "execution_count": 14, "metadata": { "execution": { "iopub.execute_input": "2025-06-21T11:30:32.830914Z", "iopub.status.busy": "2025-06-21T11:30:32.830567Z", "iopub.status.idle": "2025-06-21T11:30:32.835947Z", "shell.execute_reply": "2025-06-21T11:30:32.835322Z" }, "id": "LUplruw9Wr_L" }, "outputs": [], "source": [ "def plot_separation_lines(height, width):\n", "\n", "    block_size = [1, 4]\n", "\n", "    # Add separation lines to the figure.\n", "    num_hlines = int((height - 1) / block_size[0])\n", "    num_vlines = int((width - 1) / block_size[1])\n", "    line_y_pos = [y * block_size[0] for y in range(1, num_hlines + 1)]\n", "    line_x_pos = [x * block_size[1] for x in range(1, num_vlines + 1)]\n", "\n", "    for y_pos in line_y_pos:\n", "        plt.plot([-0.5, width], [y_pos - 0.5 , y_pos - 0.5], color='w')\n", "\n", "    for x_pos in line_x_pos:\n", "        plt.plot([x_pos - 0.5, x_pos - 0.5], [-0.5, height], color='w')" ] }, { "cell_type": "markdown", "metadata": { "id": "sbyjrRgLWr_L" }, "source": [ "Now visualize the subset of the weight tensor."
] }, { "cell_type": "code", "execution_count": 15, "metadata": { "execution": { "iopub.execute_input": "2025-06-21T11:30:32.838745Z", "iopub.status.busy": "2025-06-21T11:30:32.838373Z", "iopub.status.idle": "2025-06-21T11:30:33.053212Z", "shell.execute_reply": "2025-06-21T11:30:33.052452Z" }, "id": "ATeyf5vCWr_L" }, "outputs": [ { "data": { "image/png": "", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "plot_separation_lines(height, width)\n", "\n", "plt.axis('off')\n", "plt.imshow(subset_values_to_display)\n", "plt.colorbar()\n", "plt.title(\"Structural pruning for Dense layer\")\n", "plt.show()" ] }, { "cell_type": "markdown", "metadata": { "id": "72f7VlOLWr_M" }, "source": [ "Visualize the weights of the structurally pruned `Conv2D` layer. As with the `Dense` layer, the structural sparsity is applied in the last dimension, `channel_in`. As noted above, only the second `Conv2D` layer is structurally pruned." ] }, { "cell_type": "code", "execution_count": 16, "metadata": { "execution": { "iopub.execute_input": "2025-06-21T11:30:33.056113Z", "iopub.status.busy": "2025-06-21T11:30:33.055841Z", "iopub.status.idle": "2025-06-21T11:30:33.061027Z", "shell.execute_reply": "2025-06-21T11:30:33.060415Z" }, "id": "_Dkbt7eRWr_M" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Shape of the weight tensor is (64, 5, 5, 32)\n" ] } ], "source": [ "# Get the weights of the convolutional layer that has been pruned with 2 by 4 sparsity.\n", "# The second CONV_2D op is the `structural_pruning` layer; its second input is the kernel.\n", "op_details = interpreter._get_ops_details()\n", "op_name = 'CONV_2D'\n", "op_detail = [x for x in op_details if op_name in x[\"op_name\"]]\n", "tensor_data = interpreter.tensor(op_detail[1][\"inputs\"][1])()\n", "print(f\"Shape of the weight tensor is {tensor_data.shape}\")" ] }, { "cell_type": "markdown", "metadata": { "id": "m7a6uTdLWr_M" }, "source": [ "As with the `Dense` layer weights, the last dimension of the kernel has a `(2, 4)` structure." ] }, { "cell_type": "code", "execution_count": 17, "metadata": { "execution": { "iopub.execute_input": "2025-06-21T11:30:33.063626Z", "iopub.status.busy": "2025-06-21T11:30:33.063395Z", "iopub.status.idle": "2025-06-21T11:30:33.197818Z", "shell.execute_reply": "2025-06-21T11:30:33.197088Z" }, "id": "wyvLpfa6Wr_M" }, "outputs": [ { "data": { "image/png": "", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "# Flatten all but the last dimension so that each row holds one channel_in vector.\n", "weights_to_display = tf.reshape(tensor_data, [tf.reduce_prod(tensor_data.shape[:-1]), -1])\n", "weights_to_display = weights_to_display[0:height, 0:width]\n", "\n", "val_ones = np.ones([height, width])\n", "val_zeros = np.zeros([height, width])\n", "subset_values_to_display = np.where(abs(weights_to_display) > 1e-9, val_ones, val_zeros)\n", "\n", "plot_separation_lines(height, width)\n", "\n", "plt.axis('off')\n", "plt.imshow(subset_values_to_display)\n", "plt.colorbar()\n", "plt.title(\"Structurally pruned weights for Conv2D layer\")\n", "plt.show()" ] }, { "cell_type": "markdown", "metadata": { "id": "1aX2O8w0Wr_M" }, "source": [ "Let's see how those randomly pruned weights look. We extract them and display a subset of the weight tensor." ] }, { "cell_type": "code", "execution_count": 18, "metadata": { "execution": { "iopub.execute_input": "2025-06-21T11:30:33.200954Z", "iopub.status.busy": "2025-06-21T11:30:33.200698Z", "iopub.status.idle": "2025-06-21T11:30:33.205252Z", "shell.execute_reply": "2025-06-21T11:30:33.204534Z" }, "id": "eEHu5nizWr_M" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Shape of the weight tensor is (32, 5, 5, 1)\n" ] } ], "source": [ "# Get weights of the convolutional layer that has been pruned with random pruning.\n", "tensor_name = 'pruning_sparsity_0_5/Conv2D'\n", "detail = [x for x in details if tensor_name in x[\"name\"]]\n", "tensor_data = interpreter.tensor(detail[0][\"index\"])()\n", "print(f\"Shape of the weight tensor is {tensor_data.shape}\")" ] }, { "cell_type": "code", "execution_count": 19, "metadata": { "execution": { "iopub.execute_input": "2025-06-21T11:30:33.207690Z", "iopub.status.busy": "2025-06-21T11:30:33.207458Z", "iopub.status.idle": "2025-06-21T11:30:33.335957Z", "shell.execute_reply": "2025-06-21T11:30:33.335218Z" }, "id": "Cimzp3kVWr_M" }, "outputs": [ { "data": { "image/png": "", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "# Flatten each output channel's kernel into a single row.\n", "weights_to_display = tf.reshape(tensor_data, [tensor_data.shape[0], tf.reduce_prod(tensor_data.shape[1:])])\n", "weights_to_display = weights_to_display[0:height, 0:width]\n", "\n", "val_ones = np.ones([height, width])\n", "val_zeros = np.zeros([height, width])\n", "subset_values_to_display = np.where(abs(weights_to_display) > 0, val_ones, val_zeros)\n", "\n", "plot_separation_lines(height, width)\n", "\n", "plt.axis('off')\n", "plt.imshow(subset_values_to_display)\n", "plt.colorbar()\n", "plt.title(\"Unstructured pruned weights for Conv2D layer\")\n", "plt.show()" ] }, { "cell_type": "markdown", "metadata": { "id": "vqsfZdMpWr_N" }, "source": [ "The TensorFlow Model Optimization Toolkit includes a Python script, [`check_sparsity_m_by_n.py`](https://github.com/tensorflow/model-optimization/blob/master/tensorflow_model_optimization/python/core/sparsity/keras/tools/check_sparsity_m_by_n.py). This script checks which layers in the model stored in a given TensorFlow Lite file have structurally pruned weights. The following command demonstrates how to use this tool to check a model for 2 by 4 sparsity. The script path in the command is relative, so run it from the root of a local clone of the [model-optimization repository](https://github.com/tensorflow/model-optimization)." ] }, { "cell_type": "code", "execution_count": 20, "metadata": { "execution": { "iopub.execute_input": "2025-06-21T11:30:33.339059Z", "iopub.status.busy": "2025-06-21T11:30:33.338803Z", "iopub.status.idle": "2025-06-21T11:30:33.567340Z", "shell.execute_reply": "2025-06-21T11:30:33.566209Z" }, "id": "7HDYffebWr_N" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "python3: can't open file '/tmpfs/src/temp/tensorflow_model_optimization/g3doc/guide/pruning/./tensorflow_model_optimization/python/core/sparsity/keras/tools/check_sparsity_m_by_n.py': [Errno 2] No such file or directory\r\n" ] } ], "source": [ "! 
python3 ./tensorflow_model_optimization/python/core/sparsity/keras/tools/check_sparsity_m_by_n.py --model_tflite={tflite_file} --m_by_n=2,4\n" ] } ], "metadata": { "colab": { "collapsed_sections": [ "Tce3stUlHN0L" ], "name": "pruning_with_sparsity_2_by_4.ipynb", "provenance": [], "toc_visible": true }, "interpreter": { "hash": "5be03e09ac1816611305450014280c0b9eb46a3a95e12dcae8d73de01e2da776" }, "kernelspec": { "display_name": "Python 3.6.9 64-bit ('mo': venv)", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.23" } }, "nbformat": 4, "nbformat_minor": 0 }