{ "cells": [ { "cell_type": "markdown", "metadata": { "id": "826IBSWMN4rr" }, "source": [ "##### Copyright 2020 The TensorFlow Authors." ] },
{ "cell_type": "code", "execution_count": 1, "metadata": { "cellView": "form", "execution": { "iopub.execute_input": "2025-06-21T11:18:37.567799Z", "iopub.status.busy": "2025-06-21T11:18:37.567500Z", "iopub.status.idle": "2025-06-21T11:18:37.571557Z", "shell.execute_reply": "2025-06-21T11:18:37.570938Z" }, "id": "ITj3u97-tNR7" }, "outputs": [], "source": [ "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n", "# you may not use this file except in compliance with the License.\n", "# You may obtain a copy of the License at\n", "#\n", "# https://www.apache.org/licenses/LICENSE-2.0\n", "#\n", "# Unless required by applicable law or agreed to in writing, software\n", "# distributed under the License is distributed on an \"AS IS\" BASIS,\n", "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", "# See the License for the specific language governing permissions and\n", "# limitations under the License." ] },
{ "cell_type": "markdown", "metadata": { "id": "BYwfpc4wN4rt" }, "source": [ "# Weight clustering comprehensive guide" ] },
{ "cell_type": "markdown", "metadata": { "id": "tidmcl3sN4rv" }, "source": [ "Welcome to the comprehensive guide for *weight clustering*, part of the TensorFlow Model Optimization toolkit.\n", "\n", "This page documents various use cases and shows how to use the API for each one. Once you know which APIs you need, find the parameters and the low-level details in the [API docs](https://www.tensorflow.org/model_optimization/api_docs/python/tfmot/clustering).\n", "\n", "* If you want to see the benefits of weight clustering and what's supported, check the [overview](https://www.tensorflow.org/model_optimization/guide/clustering).\n", "* For a single end-to-end example, see the [weight clustering example](https://www.tensorflow.org/model_optimization/guide/clustering/clustering_example).\n", "\n", "In this guide, the following use cases are covered:\n", "* Define a clustered model.\n", "* Checkpoint and deserialize a clustered model.\n", "* Improve the accuracy of the clustered model.\n", "* Deploy the clustered model (extra steps are required to see the compression benefits).\n" ] },
{ "cell_type": "markdown", "metadata": { "id": "RRtKxbo8N4rv" }, "source": [ "## Setup\n" ] },
{ "cell_type": "code", "execution_count": 2, "metadata": { "execution": { "iopub.execute_input": "2025-06-21T11:18:37.574981Z", "iopub.status.busy": "2025-06-21T11:18:37.574456Z", "iopub.status.idle": "2025-06-21T11:18:45.454370Z", "shell.execute_reply": "2025-06-21T11:18:45.453638Z" }, "id": "08dJRvOqN4rw" }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "2025-06-21 11:18:41.642629: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:467] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered\n", "WARNING: All log messages before absl::InitializeLog() is called are written to STDERR\n", "E0000 00:00:1750504721.664821 14116 cuda_dnn.cc:8579] Unable to register cuDNN factory: Attempting to register factory for 
plugin cuDNN when one has already been registered\n", "E0000 00:00:1750504721.671838 14116 cuda_blas.cc:1407] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered\n", "W0000 00:00:1750504721.689764 14116 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.\n", "W0000 00:00:1750504721.689787 14116 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.\n", "W0000 00:00:1750504721.689789 14116 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.\n", "W0000 00:00:1750504721.689792 14116 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Model: \"sequential_1\"\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "_________________________________________________________________\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " Layer (type) Output Shape Param # \n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "=================================================================\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " dense_1 (Dense) (None, 20) 420 \n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " \n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " flatten_1 (Flatten) (None, 20) 0 \n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " \n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "=================================================================\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Total params: 420 (1.64 KB)\n" ] }, { "name": "stdout", "output_type": "stream", 
"text": [ "Trainable params: 420 (1.64 KB)\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Non-trainable params: 0 (0.00 Byte)\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "_________________________________________________________________\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "2025-06-21 11:18:44.945673: E external/local_xla/xla/stream_executor/cuda/cuda_platform.cc:51] failed call to cuInit: INTERNAL: CUDA error: Failed call to cuInit: CUDA_ERROR_NO_DEVICE: no CUDA-capable device is detected\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\r", "1/1 [==============================] - ETA: 0s - loss: 16.1181 - accuracy: 0.0000e+00" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "1/1 [==============================] - 0s 342ms/step - loss: 16.1181 - accuracy: 0.0000e+00\n" ] } ], "source": [ "! 
pip install -q tensorflow-model-optimization\n", "\n", "import tensorflow as tf\n", "import tf_keras as keras\n", "import numpy as np\n", "import tempfile\n", "import os\n", "import tensorflow_model_optimization as tfmot\n", "\n", "input_dim = 20\n", "output_dim = 20\n", "x_train = np.random.randn(1, input_dim).astype(np.float32)\n", "y_train = keras.utils.to_categorical(np.random.randn(1), num_classes=output_dim)\n", "\n",
"def setup_model():\n", "    model = keras.Sequential([\n", "        keras.layers.Dense(input_dim, input_shape=[input_dim]),\n", "        keras.layers.Flatten()\n", "    ])\n", "    return model\n", "\n",
"def train_model(model):\n", "    model.compile(\n", "        loss=keras.losses.categorical_crossentropy,\n", "        optimizer='adam',\n", "        metrics=['accuracy']\n", "    )\n", "    model.summary()\n", "    model.fit(x_train, y_train)\n", "    return model\n", "\n",
"def save_model_weights(model):\n", "    _, pretrained_weights = tempfile.mkstemp('.h5')\n", "    model.save_weights(pretrained_weights)\n", "    return pretrained_weights\n", "\n",
"def setup_pretrained_weights():\n", "    model = setup_model()\n", "    model = train_model(model)\n", "    pretrained_weights = save_model_weights(model)\n", "    return pretrained_weights\n", "\n",
"def setup_pretrained_model():\n", "    model = setup_model()\n", "    pretrained_weights = setup_pretrained_weights()\n", "    model.load_weights(pretrained_weights)\n", "    return model\n", "\n",
"def save_model_file(model):\n", "    _, keras_file = tempfile.mkstemp('.h5')\n", "    model.save(keras_file, include_optimizer=False)\n", "    return keras_file\n", "\n",
"def get_gzipped_model_size(model):\n", "    # Returns the size of the zipped model in bytes.\n", "    import zipfile\n", "\n", "    keras_file = save_model_file(model)\n", "\n", "    _, zipped_file = tempfile.mkstemp('.zip')\n", "    with zipfile.ZipFile(zipped_file, 'w', compression=zipfile.ZIP_DEFLATED) as f:\n", "        f.write(keras_file)\n", "    return os.path.getsize(zipped_file)\n", "\n",
"setup_model()\n", "pretrained_weights = 
setup_pretrained_weights()" ] }, { "cell_type": "markdown", "metadata": { "id": "ARd37qONN4rz" }, "source": [ "## Define a clustered model\n" ] }, { "cell_type": "markdown", "metadata": { "id": "zHB3pkU3N4r0" }, "source": [ "### Cluster a whole model (sequential and functional)" ] }, { "cell_type": "markdown", "metadata": { "id": "ig-il1lmN4r1" }, "source": [ "**Tips** for better model accuracy:\n", "\n", "* You must pass a pre-trained model with acceptable accuracy to this API. Training models from scratch with clustering results in subpar accuracy. \n", "* In some cases, clustering certain layers has a detrimental effect on model accuracy. Check \"Cluster some layers\" to see how to skip clustering the layers that affect accuracy the most.\n", "\n", "To cluster all layers, apply `tfmot.clustering.keras.cluster_weights` to the model.\n" ] }, { "cell_type": "code", "execution_count": 3, "metadata": { "execution": { "iopub.execute_input": "2025-06-21T11:18:45.457762Z", "iopub.status.busy": "2025-06-21T11:18:45.457463Z", "iopub.status.idle": "2025-06-21T11:18:45.811639Z", "shell.execute_reply": "2025-06-21T11:18:45.810962Z" }, "id": "29g7OADjN4r1" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Model: \"sequential_2\"\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "_________________________________________________________________\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " Layer (type) Output Shape Param # \n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "=================================================================\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " cluster_dense_2 (ClusterWe (None, 20) 823 \n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " ights) \n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " \n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " cluster_flatten_2 (Cluster (None, 20) 0 \n" ] }, { "name": 
"stdout", "output_type": "stream", "text": [ " Weights) \n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " \n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "=================================================================\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Total params: 823 (4.78 KB)\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Trainable params: 423 (1.65 KB)\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Non-trainable params: 400 (3.12 KB)\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "_________________________________________________________________\n" ] } ], "source": [ "import tensorflow_model_optimization as tfmot\n", "\n", "cluster_weights = tfmot.clustering.keras.cluster_weights\n", "CentroidInitialization = tfmot.clustering.keras.CentroidInitialization\n", "\n", "clustering_params = {\n", " 'number_of_clusters': 3,\n", " 'cluster_centroids_init': CentroidInitialization.KMEANS_PLUS_PLUS\n", "}\n", "\n", "model = setup_model()\n", "model.load_weights(pretrained_weights)\n", "\n", "clustered_model = cluster_weights(model, **clustering_params)\n", "\n", "clustered_model.summary()" ] }, { "cell_type": "markdown", "metadata": { "id": "zEOHK4OON4r7" }, "source": [ "### Cluster some layers (sequential and functional models)\n" ] }, { "cell_type": "markdown", "metadata": { "id": "ENscQ7ZWN4r8" }, "source": [ "**Tips** for better model accuracy:\n", "\n", "* You must pass a pre-trained model with acceptable accuracy to this API. Training models from scratch with clustering results in subpar accuracy.\n", "* Cluster later layers with more redundant parameters (e.g. `keras.layers.Dense`, `keras.layers.Conv2D`), as opposed to the early layers.\n", "* Freeze early layers prior to the clustered layers during fine-tuning. Treat the number of frozen layers as a hyperparameter. 
Empirically, freezing most early layers is ideal for the current clustering API.\n", "* Avoid clustering critical layers (e.g. attention mechanisms).\n", "\n", "**More**: the `tfmot.clustering.keras.cluster_weights` API docs provide details on how to vary the clustering configuration per layer." ] },
{ "cell_type": "code", "execution_count": 4, "metadata": { "execution": { "iopub.execute_input": "2025-06-21T11:18:45.814456Z", "iopub.status.busy": "2025-06-21T11:18:45.814210Z", "iopub.status.idle": "2025-06-21T11:18:45.867716Z", "shell.execute_reply": "2025-06-21T11:18:45.866846Z" }, "id": "IqBdl3uJN4r_" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Model: \"sequential_3\"\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "_________________________________________________________________\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " Layer (type) Output Shape Param # \n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "=================================================================\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " cluster_dense_3 (ClusterWe (None, 20) 823 \n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " ights) \n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " \n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " flatten_3 (Flatten) (None, 20) 0 \n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " \n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "=================================================================\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Total params: 823 (4.78 KB)\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Trainable params: 423 (1.65 KB)\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Non-trainable params: 400 (3.12 KB)\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "_________________________________________________________________\n" ] } ],
"source": [ "# Create a base model\n", "base_model = setup_model()\n", "base_model.load_weights(pretrained_weights)\n", "\n", "# Helper function uses `cluster_weights` to make only\n", "# the Dense layers train with clustering\n", "def apply_clustering_to_dense(layer):\n", "    if isinstance(layer, keras.layers.Dense):\n", "        return cluster_weights(layer, **clustering_params)\n", "    return layer\n", "\n", "# Use `keras.models.clone_model` to apply `apply_clustering_to_dense`\n", "# to the layers of the model.\n", "clustered_model = keras.models.clone_model(\n", "    base_model,\n", "    clone_function=apply_clustering_to_dense,\n", ")\n", "\n", "clustered_model.summary()" ] },
{ "cell_type": "markdown", "metadata": { "id": "bU0SIhY2Q63C" }, "source": [ "### Cluster convolutional layers per channel\n", "\n", "The clustered model can be passed to further optimizations such as [post-training quantization](https://www.tensorflow.org/lite/performance/post_training_quantization). If the quantization is done per channel, the model should be clustered per channel as well. This increases the accuracy of the clustered and quantized model.\n", "\n", "**Note:** only `Conv2D` layers are clustered per channel.\n", "\n", "To cluster per channel, set the parameter `cluster_per_channel` to `True`. It can be set for some layers or for the whole model.\n", "\n", "**Tips:**\n", "\n", "* If a model is to be quantized further, consider using the [cluster preserving QAT technique](https://www.tensorflow.org/model_optimization/guide/combine/collaborative_optimization).\n", "\n", "* The model can be pruned before applying clustering per channel. With the parameter `preserve_sparsity` set to `True`, the sparsity is preserved during the clustering per channel. 
Note that the [sparsity and cluster preserving QAT technique](https://www.tensorflow.org/model_optimization/guide/combine/collaborative_optimization) should be used in this case." ] },
{ "cell_type": "markdown", "metadata": { "id": "WcFrw1dHmxTr" }, "source": [ "### Cluster custom Keras layer or specify which weights of layer to cluster\n", "\n", "`tfmot.clustering.keras.ClusterableLayer` serves two use cases:\n", "1. Cluster any layer that is not supported natively, including a custom Keras layer.\n", "2. Specify which weights of a supported layer are to be clustered.\n", "\n", "For example, the API defaults to only clustering the kernel of the\n", "`Dense` layer. The example below shows how to modify it to also cluster the bias. Note that when deriving from the Keras layer, you need to override the function `get_clusterable_weights`, where you specify the name of the trainable variable to be clustered and the trainable variable itself. For example, if you return an empty list `[]`, then no weights will be clusterable.\n", "\n", "**Common mistake:** clustering the bias usually harms model accuracy too much." 
] }, { "cell_type": "code", "execution_count": 5, "metadata": { "execution": { "iopub.execute_input": "2025-06-21T11:18:45.870715Z", "iopub.status.busy": "2025-06-21T11:18:45.870461Z", "iopub.status.idle": "2025-06-21T11:18:45.926712Z", "shell.execute_reply": "2025-06-21T11:18:45.926088Z" }, "id": "73iboQ7MmxTs" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Model: \"sequential_4\"\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "_________________________________________________________________\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " Layer (type) Output Shape Param # \n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "=================================================================\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " cluster_my_dense_layer (Cl (None, 20) 846 \n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " usterWeights) \n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " \n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " flatten_4 (Flatten) (None, 20) 0 \n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " \n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "=================================================================\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Total params: 846 (4.95 KB)\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Trainable params: 426 (1.66 KB)\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Non-trainable params: 420 (3.28 KB)\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "_________________________________________________________________\n" ] } ],
"source": [ "class MyDenseLayer(keras.layers.Dense, tfmot.clustering.keras.ClusterableLayer):\n", "\n", "    def get_clusterable_weights(self):\n", "        # Cluster kernel and bias. This is just an example, clustering\n", "        # bias usually hurts model accuracy.\n", "        return [('kernel', self.kernel), ('bias', self.bias)]\n", "\n", "# Use `cluster_weights` to make the `MyDenseLayer` layer train with clustering as usual.\n", "model_for_clustering = keras.Sequential([\n", "    tfmot.clustering.keras.cluster_weights(MyDenseLayer(20, input_shape=[input_dim]), **clustering_params),\n", "    keras.layers.Flatten()\n", "])\n", "\n", "model_for_clustering.summary()" ] },
{ "cell_type": "markdown", "metadata": { "id": "SYlWPXEWmxTs" }, "source": [ "You may also use `tfmot.clustering.keras.ClusterableLayer` to cluster a custom Keras layer. To do this, extend `keras.Layer` as usual and implement the `__init__`, `call`, and `build` functions, but also extend the `clusterable_layer.ClusterableLayer` class and implement:\n", "1. `get_clusterable_weights`, where you specify the weight kernel to be clustered, as shown above.\n", "2. `get_clusterable_algorithm`, where you specify the clustering algorithm for the weight tensor. This is needed because you must specify how the custom layer weights are shaped for clustering. The returned clustering algorithm class should be derived from the `clustering_algorithm.ClusteringAlgorithm` class, and the function `get_pulling_indices` should be overridden. An example of this function, which supports weights of rank 1, 2, and 3, can be found [here](https://github.com/tensorflow/model-optimization/blob/18e87d262e536c9a742aef700880e71b47a7f768/tensorflow_model_optimization/python/core/clustering/keras/clustering_algorithm.py#L62).\n", "\n", "An example of this use case can be found [here](https://github.com/tensorflow/model-optimization/blob/master/tensorflow_model_optimization/python/core/clustering/keras/mnist_clusterable_layer_test.py)." 
] }, { "cell_type": "markdown", "metadata": { "id": "hN0DgpvD5Add" }, "source": [ "## Checkpoint and deserialize a clustered model" ] }, { "cell_type": "markdown", "metadata": { "id": "hfji5KWN6XCF" }, "source": [ "**Your use case:** this code is only needed for the HDF5 model format (not HDF5 weights or other formats)." ] }, { "cell_type": "code", "execution_count": 6, "metadata": { "execution": { "iopub.execute_input": "2025-06-21T11:18:45.929746Z", "iopub.status.busy": "2025-06-21T11:18:45.929493Z", "iopub.status.idle": "2025-06-21T11:18:46.031181Z", "shell.execute_reply": "2025-06-21T11:18:46.030514Z" }, "id": "w7P67mPk6RkQ" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "WARNING:tensorflow:Compiled the loaded model, but the compiled metrics have yet to be built. `model.compile_metrics` will be empty until you train or evaluate the model.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "WARNING:tensorflow:No training configuration found in the save file, so the model was *not* compiled. 
Compile it manually.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Model: \"sequential_5\"\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "_________________________________________________________________\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " Layer (type) Output Shape Param # \n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "=================================================================\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " cluster_dense_4 (ClusterWe (None, 20) 823 \n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " ights) \n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " \n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " cluster_flatten_5 (Cluster (None, 20) 0 \n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " Weights) \n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " \n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "=================================================================\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Total params: 823 (4.78 KB)\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Trainable params: 423 (1.65 KB)\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Non-trainable params: 400 (3.12 KB)\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "_________________________________________________________________\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tf_keras/src/engine/training.py:3098: UserWarning: You are saving your model as an HDF5 file via `model.save()`. This file format is considered legacy. We recommend using instead the native TF-Keras format, e.g. 
`model.save('my_model.keras')`.\n", " saving_api.save_model(\n" ] } ],
"source": [ "# Define the model.\n", "base_model = setup_model()\n", "base_model.load_weights(pretrained_weights)\n", "clustered_model = cluster_weights(base_model, **clustering_params)\n", "\n", "# Save or checkpoint the model.\n", "_, keras_model_file = tempfile.mkstemp('.h5')\n", "clustered_model.save(keras_model_file, include_optimizer=True)\n", "\n", "# `cluster_scope` is needed for deserializing HDF5 models.\n", "with tfmot.clustering.keras.cluster_scope():\n", "    loaded_model = keras.models.load_model(keras_model_file)\n", "\n", "loaded_model.summary()" ] },
{ "cell_type": "markdown", "metadata": { "id": "cUv-scK-N4sN" }, "source": [ "## Improve the accuracy of the clustered model" ] },
{ "cell_type": "markdown", "metadata": { "id": "-fZZopDBN4sO" }, "source": [ "Consider the following tips for your specific use case:\n", "\n", "* Centroid initialization plays a key role in the final optimized model accuracy. In general, kmeans++ initialization outperforms linear, density, and random initialization. When not using kmeans++, linear initialization tends to outperform density and random initialization, since it does not tend to miss large weights. However, density initialization has been observed to give better accuracy when using very few clusters on weights with bimodal distributions.\n", "\n", "* When fine-tuning the clustered model, set a learning rate that is lower than the one used in training.\n", "\n", "* For general ideas to improve model accuracy, look for tips for your use case(s) under \"Define a clustered model\"." 
] }, { "cell_type": "markdown", "metadata": { "id": "4DXw7YbyN4sP" }, "source": [ "## Deployment" ] }, { "cell_type": "markdown", "metadata": { "id": "5Y5zLfPzN4sQ" }, "source": [ "### Export model with size compression" ] }, { "cell_type": "markdown", "metadata": { "id": "wX4OrHD9N4sQ" }, "source": [ "**Common mistake**: both `strip_clustering` and applying a standard compression algorithm (e.g. via gzip) are necessary to see the compression benefits of clustering." ] }, { "cell_type": "code", "execution_count": 7, "metadata": { "execution": { "iopub.execute_input": "2025-06-21T11:18:46.034548Z", "iopub.status.busy": "2025-06-21T11:18:46.034292Z", "iopub.status.idle": "2025-06-21T11:18:46.596998Z", "shell.execute_reply": "2025-06-21T11:18:46.596241Z" }, "id": "ZvuiCBsVN4sR" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\r", "1/1 [==============================] - ETA: 0s - loss: 1.9171 - accuracy: 0.0000e+00" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "1/1 [==============================] - 0s 392ms/step - loss: 1.9171 - accuracy: 0.0000e+00\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "final model\n", "Model: \"sequential_6\"\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "_________________________________________________________________\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " Layer (type) Output Shape Param # \n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "=================================================================\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " dense_5 (Dense) (None, 20) 420 \n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " \n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " flatten_6 (Flatten) (None, 20) 
0 \n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " \n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "=================================================================\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Total params: 420 (1.64 KB)\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Trainable params: 420 (1.64 KB)\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Non-trainable params: 0 (0.00 Byte)\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "_________________________________________________________________\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\n", "\n", "Size of gzipped clustered model without stripping: 3528.00 bytes\n", "WARNING:tensorflow:Compiled the loaded model, but the compiled metrics have yet to be built. `model.compile_metrics` will be empty until you train or evaluate the model.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Size of gzipped clustered model with stripping: 1477.00 bytes\n" ] } ], "source": [ "model = setup_model()\n", "clustered_model = cluster_weights(model, **clustering_params)\n", "\n", "clustered_model.compile(\n", " loss=keras.losses.categorical_crossentropy,\n", " optimizer='adam',\n", " metrics=['accuracy']\n", ")\n", "\n", "clustered_model.fit(\n", " x_train,\n", " y_train\n", ")\n", "\n", "final_model = tfmot.clustering.keras.strip_clustering(clustered_model)\n", "\n", "print(\"final model\")\n", "final_model.summary()\n", "\n", "print(\"\\n\")\n", "print(\"Size of gzipped clustered model without stripping: %.2f bytes\" \n", " % (get_gzipped_model_size(clustered_model)))\n", "print(\"Size of gzipped clustered model with stripping: %.2f bytes\" \n", " % (get_gzipped_model_size(final_model)))" ] } ], "metadata": { "colab": { "collapsed_sections": [], "name": "clustering_comprehensive_guide.ipynb", "provenance": [], "toc_visible": true }, "kernelspec": { "display_name": "Python 3", 
"name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.23" } }, "nbformat": 4, "nbformat_minor": 0 }