**Copyright 2021 The TensorFlow Authors.**

In [1]:
#@title Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Sparsity preserving clustering Keras example

## Overview

This is an end-to-end example showing how to use the **sparsity preserving clustering** API, part of the TensorFlow Model Optimization Toolkit's collaborative optimization pipeline.
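The core idea can be illustrated without TensorFlow. The following is a toy, pure-Python sketch, not the TFMOT implementation: it prunes a short weight list by magnitude, then compares plain 1-D k-means clustering with a variant that clusters only the non-zero weights. The helper names (`prune_by_magnitude`, `kmeans_1d`, and so on) and the evenly spaced centroid initialization are hypothetical choices made for this illustration.

```python
# Toy illustration of pruning, naive clustering, and sparsity preserving
# clustering. All helper names are hypothetical; this is NOT how TFMOT
# implements these techniques internally.


def prune_by_magnitude(weights, target_sparsity):
    """Set the smallest-magnitude fraction of the weights to exactly 0.0."""
    n_prune = int(len(weights) * target_sparsity)
    order = sorted(range(len(weights)), key=lambda i: abs(weights[i]))
    pruned = list(weights)
    for i in order[:n_prune]:
        pruned[i] = 0.0
    return pruned


def linear_centroids(values, k):
    """Assumed initialization: k evenly spaced centroids from min to max."""
    lo, hi = min(values), max(values)
    return [lo + i * (hi - lo) / (k - 1) for i in range(k)]


def nearest(value, centroids):
    """Return the centroid closest to `value`."""
    return min(centroids, key=lambda c: abs(value - c))


def kmeans_1d(values, k, iterations=10):
    """Plain 1-D k-means (Lloyd's algorithm)."""
    centroids = linear_centroids(values, k)
    for _ in range(iterations):
        groups = [[] for _ in centroids]
        for v in values:
            idx = min(range(k), key=lambda j: abs(v - centroids[j]))
            groups[idx].append(v)
        # An empty cluster keeps its previous centroid.
        centroids = [sum(g) / len(g) if g else centroids[j]
                     for j, g in enumerate(groups)]
    return centroids


def cluster(weights, k):
    """Naive clustering: pruned zeros are clustered like any other weight."""
    centroids = kmeans_1d(weights, k)
    return [nearest(w, centroids) for w in weights]


def cluster_preserving_sparsity(weights, k):
    """Cluster only the non-zero weights; pruned zeros stay exactly 0.0."""
    centroids = kmeans_1d([w for w in weights if w != 0.0], k)
    return [0.0 if w == 0.0 else nearest(w, centroids) for w in weights]


def sparsity(weights):
    """Fraction of weights that are exactly zero."""
    return sum(w == 0.0 for w in weights) / len(weights)


weights = [0.9, -0.05, 0.4, 0.02, -0.7, 0.35, -0.01, 0.8]
pruned = prune_by_magnitude(weights, 0.5)
naive = cluster(pruned, k=3)
preserved = cluster_preserving_sparsity(pruned, k=3)
print(sparsity(pruned), sparsity(naive), sparsity(preserved))  # 0.5 0.0 0.5
```

In the naive variant, the pruned zeros are pulled into a cluster whose centroid is small but non-zero, so all of the sparsity is lost. The preserving variant keeps the zero mask fixed and clusters only the surviving weights.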

### Other pages

For an introduction to the pipeline and other available techniques, see the [collaborative optimization overview page](https://www.tensorflow.org/model_optimization/guide/combine/collaborative_optimization).

### Contents

In this tutorial, you will:

1. Train a Keras model for the MNIST dataset from scratch.
2. Fine-tune the model with pruning, check its accuracy, and observe that the model was successfully pruned.
3. Apply weight clustering to the pruned model and observe the loss of sparsity.
4. Apply sparsity preserving clustering to the pruned model and observe that the sparsity applied earlier has been preserved.
5. Generate a TFLite model from the pruned and clustered model and check that its accuracy has been preserved.
6. Compare the sizes of the different models to observe the compression benefits of applying pruning followed by the collaborative optimization technique of sparsity preserving clustering.
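The size comparison in the last step depends on compressibility: pruned weights contain many zeros and clustered weights reuse a handful of distinct values, and a general-purpose compressor such as DEFLATE exploits both. As a rough, TensorFlow-free sketch, the snippet below zips synthetic float32 weights in memory with Python's `zipfile` module (which this notebook also imports). The random weights, the 50% threshold, the centroid values, and the helper name `zipped_size` are all assumptions made for illustration, not values from this tutorial.

```python
import io
import random
import struct
import zipfile


def zipped_size(values):
    """Size in bytes of `values`, packed as float32, inside an in-memory ZIP."""
    raw = struct.pack(f"<{len(values)}f", *values)
    buffer = io.BytesIO()
    with zipfile.ZipFile(buffer, "w", compression=zipfile.ZIP_DEFLATED) as zf:
        zf.writestr("weights.bin", raw)
    return len(buffer.getvalue())


rng = random.Random(0)
dense = [rng.uniform(-1.0, 1.0) for _ in range(10_000)]

# Magnitude pruning to roughly 50% sparsity: zero the smaller half by |w|.
threshold = sorted(abs(w) for w in dense)[len(dense) // 2]
pruned = [w if abs(w) >= threshold else 0.0 for w in dense]

# Simplified sparsity preserving clustering: snap each surviving weight to the
# nearest of a few hypothetical centroids and leave the zeros untouched.
centroids = [-0.9, -0.6, 0.6, 0.9]
pruned_clustered = [
    0.0 if w == 0.0 else min(centroids, key=lambda c: abs(w - c)) for w in pruned
]

sizes = {
    "dense": zipped_size(dense),
    "pruned": zipped_size(pruned),
    "pruned_clustered": zipped_size(pruned_clustered),
}
for name, size in sizes.items():
    print(f"{name}: {size} bytes zipped")
```

Exact byte counts depend on the data and the zlib version, so only the ordering (dense > pruned > pruned and clustered) should be expected to hold.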

## Setup

You can run this Jupyter notebook in a local [virtualenv](https://www.tensorflow.org/install/pip?lang=python3#2.-create-a-virtual-environment-recommended) or in [Colab](https://colab.sandbox.google.com/). For details on setting up dependencies, see the [installation guide](https://www.tensorflow.org/model_optimization/guide/install).
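Before running the rest of the notebook, you may want to confirm which of the packages it installs and imports are available. A small sketch using only the standard library's `importlib.metadata`; the helper name `installed_versions` is illustrative, and the distribution names listed (for example `tf-keras` for the `tf_keras` import) are assumptions about how these packages are published.

```python
from importlib import metadata


def installed_versions(distributions):
    """Map each distribution name to its installed version, or None if missing."""
    versions = {}
    for name in distributions:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions


for name, version in installed_versions(
        ["tensorflow", "tf-keras", "tensorflow-model-optimization"]).items():
    print(f"{name}: {version or 'not installed'}")
```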

In [2]:
! pip install -q tensorflow-model-optimization

In [3]:
import tensorflow as tf
import tf_keras as keras

import numpy as np
import tempfile
import zipfile
import os

## Train a Keras model for MNIST to be pruned and clustered
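Pruning and clustering act on the model's weights, so it helps to know how many there are. Assuming the Keras defaults (`padding='valid'` for `Conv2D` and `MaxPooling2D`, and pooling strides equal to `pool_size`), the parameter count of the model defined in the next cell can be worked out by hand. The helper name `valid_output_size` is illustrative.

```python
def valid_output_size(size, kernel, stride=1):
    """Spatial output size of a window op with 'valid' padding."""
    return (size - kernel) // stride + 1


# Conv2D(filters=12, kernel_size=(3, 3)) on a 28x28x1 input:
# one 3x3x1 kernel per filter, plus one bias per filter.
conv_weights = 3 * 3 * 1 * 12 + 12
conv_out = valid_output_size(28, 3)                   # 26x26x12 feature maps

# MaxPooling2D(pool_size=(2, 2)) has no weights; strides default to pool_size.
pool_out = valid_output_size(conv_out, 2, stride=2)   # 13x13x12
flattened = pool_out * pool_out * 12                  # inputs to the Dense layer

# Dense(10): one weight per input per unit, plus one bias per unit.
dense_weights = flattened * 10 + 10
total = conv_weights + dense_weights
print(conv_weights, dense_weights, total)  # → 120 20290 20410
```

Calling `model.summary()` after building the model should report the same total if these assumptions hold. Almost all of the weights sit in the final `Dense` layer, which is where pruning and clustering have the most to compress.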

In [4]:
# Load MNIST dataset
mnist = keras.datasets.mnist
(train_images, train_labels), (test_images, test_labels) = mnist.load_data()

# Normalize the input images so that each pixel value is between 0 and 1.
train_images = train_images / 255.0
test_images  = test_images / 255.0

model = keras.Sequential([
  keras.layers.InputLayer(input_shape=(28, 28)),
  keras.layers.Reshape(target_shape=(28, 28, 1)),
  keras.layers.Conv2D(filters=12, kernel_size=(3, 3),
                      activation=tf.nn.relu),
  keras.layers.MaxPooling2D(pool_size=(2, 2)),
  keras.layers.Flatten(),
  keras.layers.Dense(10)
])

# Train the digit classification model
model.compile(optimizer='adam',
              loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
              metrics=['accuracy'])

model.fit(
    train_images,
    train_labels,
    validation_split=0.1,
    epochs=10
)


 378/1688 [=====>........................] - ETA: 5s - loss: 0.0644 - accuracy: 0.9811

 390/1688 [=====>........................] - ETA: 5s - loss: 0.0645 - accuracy: 0.9812

Epoch 6/10


   1/1688 [..............................] - ETA: 8s - loss: 0.0329 - accuracy: 1.0000

  14/1688 [..............................] - ETA: 6s - loss: 0.0442 - accuracy: 0.9888

  26/1688 [..............................] - ETA: 6s - loss: 0.0483 - accuracy: 0.9832

  38/1688 [..............................] - ETA: 6s - loss: 0.0503 - accuracy: 0.9844

  50/1688 [..............................] - ETA: 6s - loss: 0.0486 - accuracy: 0.9850

  62/1688 [>.............................] - ETA: 6s - loss: 0.0489 - accuracy: 0.9854

  74/1688 [>.............................] - ETA: 6s - loss: 0.0468 - accuracy: 0.9856

  86/1688 [>.............................] - ETA: 6s - loss: 0.0490 - accuracy: 0.9851

  98/1688 [>.............................] - ETA: 6s - loss: 0.0503 - accuracy: 0.9847

 110/1688 [>.............................] - ETA: 6s - loss: 0.0522 - accuracy: 0.9838

 122/1688 [=>............................] - ETA: 6s - loss: 0.0572 - accuracy: 0.9834

 134/1688 [=>............................] - ETA: 6s - loss: 0.0547 - accuracy: 0.9837

 146/1688 [=>............................] - ETA: 6s - loss: 0.0546 - accuracy: 0.9842

 158/1688 [=>............................] - ETA: 6s - loss: 0.0544 - accuracy: 0.9842

 170/1688 [==>...........................] - ETA: 6s - loss: 0.0560 - accuracy: 0.9840

 183/1688 [==>...........................] - ETA: 6s - loss: 0.0562 - accuracy: 0.9833

 195/1688 [==>...........................] - ETA: 6s - loss: 0.0562 - accuracy: 0.9835

 207/1688 [==>...........................] - ETA: 6s - loss: 0.0565 - accuracy: 0.9832

 219/1688 [==>...........................] - ETA: 6s - loss: 0.0568 - accuracy: 0.9827

 231/1688 [===>..........................] - ETA: 6s - loss: 0.0574 - accuracy: 0.9820

 243/1688 [===>..........................] - ETA: 6s - loss: 0.0571 - accuracy: 0.9821

 255/1688 [===>..........................] - ETA: 6s - loss: 0.0563 - accuracy: 0.9826

 267/1688 [===>..........................] - ETA: 6s - loss: 0.0575 - accuracy: 0.9820

 279/1688 [===>..........................] - ETA: 5s - loss: 0.0584 - accuracy: 0.9817

 291/1688 [====>.........................] - ETA: 5s - loss: 0.0583 - accuracy: 0.9817

 303/1688 [====>.........................] - ETA: 5s - loss: 0.0577 - accuracy: 0.9821

 315/1688 [====>.........................] - ETA: 5s - loss: 0.0576 - accuracy: 0.9822

 327/1688 [====>.........................] - ETA: 5s - loss: 0.0586 - accuracy: 0.9819

 339/1688 [=====>........................] - ETA: 5s - loss: 0.0586 - accuracy: 0.9817

 351/1688 [=====>........................] - ETA: 5s - loss: 0.0582 - accuracy: 0.9818

 363/1688 [=====>........................] - ETA: 5s - loss: 0.0573 - accuracy: 0.9821

 376/1688 [=====>........................] - ETA: 5s - loss: 0.0565 - accuracy: 0.9825

 388/1688 [=====>........................] - ETA: 5s - loss: 0.0558 - accuracy: 0.9827

Epoch 7/10


   1/1688 [..............................] - ETA: 8s - loss: 0.0073 - accuracy: 1.0000

  13/1688 [..............................] - ETA: 7s - loss: 0.0544 - accuracy: 0.9832

  25/1688 [..............................] - ETA: 7s - loss: 0.0649 - accuracy: 0.9812

  37/1688 [..............................] - ETA: 7s - loss: 0.0622 - accuracy: 0.9806

  49/1688 [..............................] - ETA: 6s - loss: 0.0560 - accuracy: 0.9834

  61/1688 [>.............................] - ETA: 6s - loss: 0.0552 - accuracy: 0.9836

  73/1688 [>.............................] - ETA: 6s - loss: 0.0537 - accuracy: 0.9837

  85/1688 [>.............................] - ETA: 6s - loss: 0.0497 - accuracy: 0.9853

  98/1688 [>.............................] - ETA: 6s - loss: 0.0492 - accuracy: 0.9857

 110/1688 [>.............................] - ETA: 6s - loss: 0.0471 - accuracy: 0.9858

 122/1688 [=>............................] - ETA: 6s - loss: 0.0462 - accuracy: 0.9851

 134/1688 [=>............................] - ETA: 6s - loss: 0.0462 - accuracy: 0.9851

 146/1688 [=>............................] - ETA: 6s - loss: 0.0450 - accuracy: 0.9859

 158/1688 [=>............................] - ETA: 6s - loss: 0.0454 - accuracy: 0.9856

 170/1688 [==>...........................] - ETA: 6s - loss: 0.0441 - accuracy: 0.9860

 182/1688 [==>...........................] - ETA: 6s - loss: 0.0431 - accuracy: 0.9864

 194/1688 [==>...........................] - ETA: 6s - loss: 0.0432 - accuracy: 0.9861

 207/1688 [==>...........................] - ETA: 6s - loss: 0.0428 - accuracy: 0.9863

 219/1688 [==>...........................] - ETA: 6s - loss: 0.0426 - accuracy: 0.9866

 231/1688 [===>..........................] - ETA: 6s - loss: 0.0426 - accuracy: 0.9867

 243/1688 [===>..........................] - ETA: 6s - loss: 0.0420 - accuracy: 0.9868

 255/1688 [===>..........................] - ETA: 6s - loss: 0.0415 - accuracy: 0.9866

 267/1688 [===>..........................] - ETA: 6s - loss: 0.0433 - accuracy: 0.9861

 279/1688 [===>..........................] - ETA: 6s - loss: 0.0430 - accuracy: 0.9862

 291/1688 [====>.........................] - ETA: 5s - loss: 0.0428 - accuracy: 0.9860

 303/1688 [====>.........................] - ETA: 5s - loss: 0.0428 - accuracy: 0.9862

 315/1688 [====>.........................] - ETA: 5s - loss: 0.0432 - accuracy: 0.9859

 327/1688 [====>.........................] - ETA: 5s - loss: 0.0431 - accuracy: 0.9858

 339/1688 [=====>........................] - ETA: 5s - loss: 0.0434 - accuracy: 0.9857

 351/1688 [=====>........................] - ETA: 5s - loss: 0.0443 - accuracy: 0.9854

 363/1688 [=====>........................] - ETA: 5s - loss: 0.0438 - accuracy: 0.9855

 375/1688 [=====>........................] - ETA: 5s - loss: 0.0443 - accuracy: 0.9854

 387/1688 [=====>........................] - ETA: 5s - loss: 0.0445 - accuracy: 0.9854

Epoch 8/10


   1/1688 [..............................] - ETA: 8s - loss: 0.1088 - accuracy: 0.9375

  13/1688 [..............................] - ETA: 6s - loss: 0.0553 - accuracy: 0.9784

  26/1688 [..............................] - ETA: 6s - loss: 0.0478 - accuracy: 0.9808

  38/1688 [..............................] - ETA: 6s - loss: 0.0483 - accuracy: 0.9819

  50/1688 [..............................] - ETA: 6s - loss: 0.0536 - accuracy: 0.9806

  62/1688 [>.............................] - ETA: 6s - loss: 0.0503 - accuracy: 0.9819

  74/1688 [>.............................] - ETA: 6s - loss: 0.0459 - accuracy: 0.9840

  86/1688 [>.............................] - ETA: 6s - loss: 0.0490 - accuracy: 0.9829

  98/1688 [>.............................] - ETA: 6s - loss: 0.0472 - accuracy: 0.9834

 110/1688 [>.............................] - ETA: 6s - loss: 0.0475 - accuracy: 0.9835

 122/1688 [=>............................] - ETA: 6s - loss: 0.0454 - accuracy: 0.9844

 134/1688 [=>............................] - ETA: 6s - loss: 0.0438 - accuracy: 0.9848

 146/1688 [=>............................] - ETA: 6s - loss: 0.0430 - accuracy: 0.9852

 159/1688 [=>............................] - ETA: 6s - loss: 0.0426 - accuracy: 0.9858

 171/1688 [==>...........................] - ETA: 6s - loss: 0.0409 - accuracy: 0.9868

 183/1688 [==>...........................] - ETA: 6s - loss: 0.0397 - accuracy: 0.9872

 195/1688 [==>...........................] - ETA: 6s - loss: 0.0412 - accuracy: 0.9870

 207/1688 [==>...........................] - ETA: 6s - loss: 0.0404 - accuracy: 0.9870

 219/1688 [==>...........................] - ETA: 6s - loss: 0.0399 - accuracy: 0.9872

 231/1688 [===>..........................] - ETA: 6s - loss: 0.0402 - accuracy: 0.9871

 243/1688 [===>..........................] - ETA: 6s - loss: 0.0408 - accuracy: 0.9870

 255/1688 [===>..........................] - ETA: 6s - loss: 0.0410 - accuracy: 0.9871

 267/1688 [===>..........................] - ETA: 6s - loss: 0.0407 - accuracy: 0.9874

 279/1688 [===>..........................] - ETA: 5s - loss: 0.0405 - accuracy: 0.9875

 291/1688 [====>.........................] - ETA: 5s - loss: 0.0404 - accuracy: 0.9878

 303/1688 [====>.........................] - ETA: 5s - loss: 0.0415 - accuracy: 0.9873

 315/1688 [====>.........................] - ETA: 5s - loss: 0.0419 - accuracy: 0.9871

 327/1688 [====>.........................] - ETA: 5s - loss: 0.0423 - accuracy: 0.9869

 339/1688 [=====>........................] - ETA: 5s - loss: 0.0422 - accuracy: 0.9872

 351/1688 [=====>........................] - ETA: 5s - loss: 0.0419 - accuracy: 0.9873

 364/1688 [=====>........................] - ETA: 5s - loss: 0.0417 - accuracy: 0.9874

 376/1688 [=====>........................] - ETA: 5s - loss: 0.0416 - accuracy: 0.9874

 388/1688 [=====>........................] - ETA: 5s - loss: 0.0421 - accuracy: 0.9873

Epoch 9/10


   1/1688 [..............................] - ETA: 8s - loss: 0.0299 - accuracy: 1.0000

  13/1688 [..............................] - ETA: 7s - loss: 0.0434 - accuracy: 0.9880

  25/1688 [..............................] - ETA: 7s - loss: 0.0339 - accuracy: 0.9912

  37/1688 [..............................] - ETA: 6s - loss: 0.0376 - accuracy: 0.9907

  49/1688 [..............................] - ETA: 6s - loss: 0.0354 - accuracy: 0.9904

  61/1688 [>.............................] - ETA: 6s - loss: 0.0393 - accuracy: 0.9903

  73/1688 [>.............................] - ETA: 6s - loss: 0.0392 - accuracy: 0.9893

  86/1688 [>.............................] - ETA: 6s - loss: 0.0384 - accuracy: 0.9887

  98/1688 [>.............................] - ETA: 6s - loss: 0.0412 - accuracy: 0.9876

 110/1688 [>.............................] - ETA: 6s - loss: 0.0405 - accuracy: 0.9878

 123/1688 [=>............................] - ETA: 6s - loss: 0.0386 - accuracy: 0.9891

 135/1688 [=>............................] - ETA: 6s - loss: 0.0376 - accuracy: 0.9894

 147/1688 [=>............................] - ETA: 6s - loss: 0.0376 - accuracy: 0.9889

 159/1688 [=>............................] - ETA: 6s - loss: 0.0371 - accuracy: 0.9894

 171/1688 [==>...........................] - ETA: 6s - loss: 0.0365 - accuracy: 0.9896

 183/1688 [==>...........................] - ETA: 6s - loss: 0.0375 - accuracy: 0.9892

 195/1688 [==>...........................] - ETA: 6s - loss: 0.0382 - accuracy: 0.9893

 207/1688 [==>...........................] - ETA: 6s - loss: 0.0412 - accuracy: 0.9890

 220/1688 [==>...........................] - ETA: 6s - loss: 0.0407 - accuracy: 0.9892

 232/1688 [===>..........................] - ETA: 6s - loss: 0.0404 - accuracy: 0.9891

 245/1688 [===>..........................] - ETA: 6s - loss: 0.0389 - accuracy: 0.9897

 257/1688 [===>..........................] - ETA: 6s - loss: 0.0383 - accuracy: 0.9899

 269/1688 [===>..........................] - ETA: 5s - loss: 0.0384 - accuracy: 0.9897

 281/1688 [===>..........................] - ETA: 5s - loss: 0.0383 - accuracy: 0.9895

 293/1688 [====>.........................] - ETA: 5s - loss: 0.0384 - accuracy: 0.9893

 305/1688 [====>.........................] - ETA: 5s - loss: 0.0386 - accuracy: 0.9892

 317/1688 [====>.........................] - ETA: 5s - loss: 0.0378 - accuracy: 0.9895

 329/1688 [====>.........................] - ETA: 5s - loss: 0.0376 - accuracy: 0.9896

 341/1688 [=====>........................] - ETA: 5s - loss: 0.0378 - accuracy: 0.9895

 353/1688 [=====>........................] - ETA: 5s - loss: 0.0393 - accuracy: 0.9887

 365/1688 [=====>........................] - ETA: 5s - loss: 0.0393 - accuracy: 0.9885

 377/1688 [=====>........................] - ETA: 5s - loss: 0.0389 - accuracy: 0.9886

 389/1688 [=====>........................] - ETA: 5s - loss: 0.0385 - accuracy: 0.9888

Epoch 10/10


   1/1688 [..............................] - ETA: 8s - loss: 0.0055 - accuracy: 1.0000

  13/1688 [..............................] - ETA: 7s - loss: 0.0175 - accuracy: 0.9976

  25/1688 [..............................] - ETA: 6s - loss: 0.0235 - accuracy: 0.9925

  37/1688 [..............................] - ETA: 7s - loss: 0.0226 - accuracy: 0.9924

  49/1688 [..............................] - ETA: 6s - loss: 0.0235 - accuracy: 0.9917

  61/1688 [>.............................] - ETA: 6s - loss: 0.0243 - accuracy: 0.9918

  73/1688 [>.............................] - ETA: 6s - loss: 0.0229 - accuracy: 0.9927

  85/1688 [>.............................] - ETA: 6s - loss: 0.0262 - accuracy: 0.9919

  97/1688 [>.............................] - ETA: 6s - loss: 0.0261 - accuracy: 0.9919

 109/1688 [>.............................] - ETA: 6s - loss: 0.0292 - accuracy: 0.9911

 121/1688 [=>............................] - ETA: 6s - loss: 0.0292 - accuracy: 0.9912

 133/1688 [=>............................] - ETA: 6s - loss: 0.0291 - accuracy: 0.9911

 145/1688 [=>............................] - ETA: 6s - loss: 0.0285 - accuracy: 0.9916

 158/1688 [=>............................] - ETA: 6s - loss: 0.0303 - accuracy: 0.9913

 170/1688 [==>...........................] - ETA: 6s - loss: 0.0319 - accuracy: 0.9904

 182/1688 [==>...........................] - ETA: 6s - loss: 0.0333 - accuracy: 0.9906

 194/1688 [==>...........................] - ETA: 6s - loss: 0.0350 - accuracy: 0.9902

 206/1688 [==>...........................] - ETA: 6s - loss: 0.0350 - accuracy: 0.9898

 218/1688 [==>...........................] - ETA: 6s - loss: 0.0343 - accuracy: 0.9903

 230/1688 [===>..........................] - ETA: 6s - loss: 0.0334 - accuracy: 0.9904

 242/1688 [===>..........................] - ETA: 6s - loss: 0.0331 - accuracy: 0.9906

 254/1688 [===>..........................] - ETA: 6s - loss: 0.0329 - accuracy: 0.9905

 266/1688 [===>..........................] - ETA: 6s - loss: 0.0324 - accuracy: 0.9906

 278/1688 [===>..........................] - ETA: 5s - loss: 0.0328 - accuracy: 0.9904

 290/1688 [====>.........................] - ETA: 5s - loss: 0.0329 - accuracy: 0.9904

 302/1688 [====>.........................] - ETA: 5s - loss: 0.0328 - accuracy: 0.9904

 314/1688 [====>.........................] - ETA: 5s - loss: 0.0332 - accuracy: 0.9902

 326/1688 [====>.........................] - ETA: 5s - loss: 0.0333 - accuracy: 0.9903

 338/1688 [=====>........................] - ETA: 5s - loss: 0.0332 - accuracy: 0.9902

 350/1688 [=====>........................] - ETA: 5s - loss: 0.0330 - accuracy: 0.9903

 362/1688 [=====>........................] - ETA: 5s - loss: 0.0328 - accuracy: 0.9903

 374/1688 [=====>........................] - ETA: 5s - loss: 0.0326 - accuracy: 0.9904

 386/1688 [=====>........................] - ETA: 5s - loss: 0.0329 - accuracy: 0.9902

<tf_keras.src.callbacks.History at 0x7f53b0e72940>

### Evaluate the baseline model and save it for later use

In [5]:
_, baseline_model_accuracy = model.evaluate(
    test_images, test_labels, verbose=0)

print('Baseline test accuracy:', baseline_model_accuracy)

_, keras_file = tempfile.mkstemp('.h5')
print('Saving model to: ', keras_file)
keras.models.save_model(model, keras_file, include_optimizer=False)

Baseline test accuracy: 0.9805999994277954
Saving model to:  /tmpfs/tmp/tmpytpcnxd_.h5


## Prune and fine-tune the model to 50% sparsity

Apply the `prune_low_magnitude()` API to the whole pre-trained model to obtain the pruned model that will be clustered in the next step. For guidance on using the API to achieve the best compression rate while maintaining your target accuracy, refer to the [pruning comprehensive guide](https://www.tensorflow.org/model_optimization/guide/pruning/comprehensive_guide).
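Conceptually, magnitude-based pruning to 50% sparsity zeroes the half of each kernel's entries that have the smallest absolute values. The NumPy sketch below shows only that end state, computed in one shot. It is a simplification, not how `prune_low_magnitude` works internally: the real API keeps a mask and updates it on a schedule while the model trains. The helper name `magnitude_prune` is hypothetical. The kernel shape (3x3x1x12) is inferred from this model's summary and sparsity printout.

```python
import numpy as np

def magnitude_prune(weights: np.ndarray, sparsity: float) -> np.ndarray:
    """Return a copy of `weights` with the smallest-magnitude fraction zeroed."""
    flat = np.abs(weights).ravel()
    k = int(round(sparsity * flat.size))  # number of weights to set to zero
    if k == 0:
        return weights.copy()
    # Indices of the k smallest |w| values (order among them is arbitrary).
    prune_idx = np.argpartition(flat, k - 1)[:k]
    mask = np.ones(flat.size, dtype=bool)
    mask[prune_idx] = False
    return weights * mask.reshape(weights.shape)

# A random kernel with the same shape as this model's conv2d kernel.
rng = np.random.default_rng(0)
kernel = rng.normal(size=(3, 3, 1, 12))
pruned = magnitude_prune(kernel, sparsity=0.5)
print(f"{np.count_nonzero(pruned == 0)}/{pruned.size} weights zeroed")  # 54/108
```

Every surviving weight is at least as large in magnitude as every pruned one. This ranking is what "low magnitude" refers to.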

### Define the model and apply the sparsity API

Note that the pre-trained model is used.

In [6]:
import tensorflow_model_optimization as tfmot

prune_low_magnitude = tfmot.sparsity.keras.prune_low_magnitude

pruning_params = {
    'pruning_schedule': tfmot.sparsity.keras.ConstantSparsity(0.5, begin_step=0, frequency=100)
}

callbacks = [
  tfmot.sparsity.keras.UpdatePruningStep()
]

pruned_model = prune_low_magnitude(model, **pruning_params)

# Use smaller learning rate for fine-tuning
opt = keras.optimizers.Adam(learning_rate=1e-5)

pruned_model.compile(
  loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
  optimizer=opt,
  metrics=['accuracy'])

pruned_model.summary()

Model: "sequential"
_________________________________________________________________
 Layer (type)                                           Output Shape          Param #
=================================================================
 prune_low_magnitude_reshape (PruneLowMagnitude)        (None, 28, 28, 1)     1
 prune_low_magnitude_conv2d (PruneLowMagnitude)         (None, 26, 26, 12)    230
 prune_low_magnitude_max_pooling2d (PruneLowMagnitude)  (None, 13, 13, 12)    1
 prune_low_magnitude_flatten (PruneLowMagnitude)        (None, 2028)          1
 prune_low_magnitude_dense (PruneLowMagnitude)          (None, 10)            40572
=================================================================
Total params: 40805 (159.41 KB)
Trainable params: 20410 (79.73 KB)
Non-trainable params: 20395 (79.69 KB)
_________________________________________________________________

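The parameter counts in the summary above can be checked by hand. The sketch below rests on an assumption: each `PruneLowMagnitude` wrapper adds three non-trainable variables. These are a mask with the same shape as the wrapped kernel, a scalar threshold, and a scalar pruning-step counter; layers without a kernel get only the counter. The kernel shapes are inferred from the output shapes. A 3x3 convolution with 12 filters maps 28x28x1 to 26x26x12, and the dense layer maps 13*13*12 = 2028 inputs to 10 outputs.

```python
# Assumed wrapper variables: mask (kernel-shaped), threshold (1), pruning_step (1).
conv_kernel = 3 * 3 * 1 * 12        # 3x3 filters, 1 input channel, 12 filters
conv_bias = 12
dense_kernel = (13 * 13 * 12) * 10  # 2028 flattened inputs -> 10 classes
dense_bias = 10

def wrapped_params(kernel, bias):
    """Return (trainable, non_trainable) counts for a wrapped layer with a kernel."""
    trainable = kernel + bias
    non_trainable = kernel + 1 + 1  # mask + threshold + pruning_step
    return trainable, non_trainable

conv_t, conv_nt = wrapped_params(conv_kernel, conv_bias)
dense_t, dense_nt = wrapped_params(dense_kernel, dense_bias)
# Reshape, MaxPooling2D and Flatten have no weights: only pruning_step each.
other_nt = 3 * 1

trainable = conv_t + dense_t
non_trainable = conv_nt + dense_nt + other_nt
print(conv_t + conv_nt, dense_t + dense_nt)                  # 230 40572
print(trainable, non_trainable, trainable + non_trainable)  # 20410 20395 40805
```

Under this assumption, roughly half of the reported parameters are pruning bookkeeping. The masks account for most of them, and they disappear once the wrappers are stripped.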

### Fine-tune the model, check sparsity, and evaluate the accuracy against the baseline

Fine-tune the model with pruning for 3 epochs.
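The progress bars report 1688 steps per epoch. The calculation below reproduces that number and roughly estimates how often `ConstantSparsity(0.5, begin_step=0, frequency=100)` re-applies its mask. It assumes three things not shown in this section:

- the training set holds MNIST's 60,000 images;
- `fit` uses its default batch size of 32;
- the schedule prunes on steps where `(step - begin_step) % frequency == 0`.

The exact step indexing used by `UpdatePruningStep` may shift the count by one.

```python
import math

# Assumptions: 60,000 training images, default batch_size of 32.
num_train_images = 60_000
validation_split = 0.1
batch_size = 32
epochs = 3

# validation_split holds out the last 10% of the samples for validation.
train_samples = math.floor(num_train_images * (1 - validation_split))  # 54000
steps_per_epoch = math.ceil(train_samples / batch_size)                # 1688
total_steps = steps_per_epoch * epochs                                 # 5064

# Assumed schedule rule: prune when (step - begin_step) % frequency == 0.
begin_step, frequency = 0, 100
mask_updates = sum(
    1 for step in range(total_steps)
    if step >= begin_step and (step - begin_step) % frequency == 0
)
print(steps_per_epoch, total_steps, mask_updates)  # 1688 5064 51
```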

In [7]:
# Fine-tune model
pruned_model.fit(
  train_images,
  train_labels,
  epochs=3,
  validation_split=0.1,
  callbacks=callbacks)

Epoch 1/3


   1/1688 [..............................] - ETA: 46:35 - loss: 0.0723 - accuracy: 0.9688

  12/1688 [..............................] - ETA: 8s - loss: 0.0342 - accuracy: 0.9948   

  24/1688 [..............................] - ETA: 7s - loss: 0.0331 - accuracy: 0.9922

  35/1688 [..............................] - ETA: 7s - loss: 0.0301 - accuracy: 0.9920

  46/1688 [..............................] - ETA: 7s - loss: 0.0320 - accuracy: 0.9918

  58/1688 [>.............................] - ETA: 7s - loss: 0.0271 - accuracy: 0.9935

  70/1688 [>.............................] - ETA: 7s - loss: 0.0287 - accuracy: 0.9929

  82/1688 [>.............................] - ETA: 7s - loss: 0.0265 - accuracy: 0.9935

  94/1688 [>.............................] - ETA: 7s - loss: 0.0282 - accuracy: 0.9927

 105/1688 [>.............................] - ETA: 7s - loss: 0.0378 - accuracy: 0.9893

 117/1688 [=>............................] - ETA: 7s - loss: 0.0511 - accuracy: 0.9861

 129/1688 [=>............................] - ETA: 7s - loss: 0.0576 - accuracy: 0.9843

 141/1688 [=>............................] - ETA: 7s - loss: 0.0610 - accuracy: 0.9820

 153/1688 [=>............................] - ETA: 6s - loss: 0.0652 - accuracy: 0.9806

 164/1688 [=>............................] - ETA: 6s - loss: 0.0709 - accuracy: 0.9783

 176/1688 [==>...........................] - ETA: 6s - loss: 0.0730 - accuracy: 0.9776

 187/1688 [==>...........................] - ETA: 6s - loss: 0.0768 - accuracy: 0.9769

 198/1688 [==>...........................] - ETA: 6s - loss: 0.0798 - accuracy: 0.9755

 209/1688 [==>...........................] - ETA: 6s - loss: 0.0869 - accuracy: 0.9740

 220/1688 [==>...........................] - ETA: 6s - loss: 0.0903 - accuracy: 0.9729

 232/1688 [===>..........................] - ETA: 6s - loss: 0.0916 - accuracy: 0.9727

 243/1688 [===>..........................] - ETA: 6s - loss: 0.0957 - accuracy: 0.9711

 254/1688 [===>..........................] - ETA: 6s - loss: 0.0965 - accuracy: 0.9705

 266/1688 [===>..........................] - ETA: 6s - loss: 0.1013 - accuracy: 0.9691

 278/1688 [===>..........................] - ETA: 6s - loss: 0.1016 - accuracy: 0.9689

 289/1688 [====>.........................] - ETA: 6s - loss: 0.1023 - accuracy: 0.9685

 300/1688 [====>.........................] - ETA: 6s - loss: 0.1040 - accuracy: 0.9678

 312/1688 [====>.........................] - ETA: 6s - loss: 0.1053 - accuracy: 0.9678

 324/1688 [====>.........................] - ETA: 6s - loss: 0.1084 - accuracy: 0.9673

 336/1688 [====>.........................] - ETA: 6s - loss: 0.1091 - accuracy: 0.9669

 348/1688 [=====>........................] - ETA: 6s - loss: 0.1089 - accuracy: 0.9667

 360/1688 [=====>........................] - ETA: 6s - loss: 0.1106 - accuracy: 0.9664

 372/1688 [=====>........................] - ETA: 5s - loss: 0.1105 - accuracy: 0.9664

 384/1688 [=====>........................] - ETA: 5s - loss: 0.1101 - accuracy: 0.9665

Epoch 2/3


   1/1688 [..............................] - ETA: 8s - loss: 0.0299 - accuracy: 1.0000

  13/1688 [..............................] - ETA: 7s - loss: 0.0701 - accuracy: 0.9808

  24/1688 [..............................] - ETA: 7s - loss: 0.0873 - accuracy: 0.9779

  36/1688 [..............................] - ETA: 7s - loss: 0.0792 - accuracy: 0.9774

  48/1688 [..............................] - ETA: 7s - loss: 0.0702 - accuracy: 0.9805

  60/1688 [>.............................] - ETA: 7s - loss: 0.0649 - accuracy: 0.9818

  72/1688 [>.............................] - ETA: 7s - loss: 0.0675 - accuracy: 0.9813

  84/1688 [>.............................] - ETA: 7s - loss: 0.0690 - accuracy: 0.9799

  96/1688 [>.............................] - ETA: 7s - loss: 0.0674 - accuracy: 0.9808

 108/1688 [>.............................] - ETA: 7s - loss: 0.0693 - accuracy: 0.9806

 119/1688 [=>............................] - ETA: 7s - loss: 0.0708 - accuracy: 0.9800

 131/1688 [=>............................] - ETA: 6s - loss: 0.0694 - accuracy: 0.9802

 143/1688 [=>............................] - ETA: 6s - loss: 0.0714 - accuracy: 0.9797

 155/1688 [=>............................] - ETA: 6s - loss: 0.0701 - accuracy: 0.9800

 167/1688 [=>............................] - ETA: 6s - loss: 0.0694 - accuracy: 0.9802

 179/1688 [==>...........................] - ETA: 6s - loss: 0.0691 - accuracy: 0.9799

 191/1688 [==>...........................] - ETA: 6s - loss: 0.0680 - accuracy: 0.9800

 203/1688 [==>...........................] - ETA: 6s - loss: 0.0681 - accuracy: 0.9800

 215/1688 [==>...........................] - ETA: 6s - loss: 0.0696 - accuracy: 0.9795

 227/1688 [===>..........................] - ETA: 6s - loss: 0.0690 - accuracy: 0.9795

 239/1688 [===>..........................] - ETA: 6s - loss: 0.0688 - accuracy: 0.9793

 251/1688 [===>..........................] - ETA: 6s - loss: 0.0692 - accuracy: 0.9788

 263/1688 [===>..........................] - ETA: 6s - loss: 0.0686 - accuracy: 0.9790

 275/1688 [===>..........................] - ETA: 6s - loss: 0.0693 - accuracy: 0.9789

 287/1688 [====>.........................] - ETA: 6s - loss: 0.0689 - accuracy: 0.9789

 299/1688 [====>.........................] - ETA: 6s - loss: 0.0687 - accuracy: 0.9786

 311/1688 [====>.........................] - ETA: 6s - loss: 0.0689 - accuracy: 0.9783

 323/1688 [====>.........................] - ETA: 6s - loss: 0.0691 - accuracy: 0.9780

 335/1688 [====>.........................] - ETA: 6s - loss: 0.0687 - accuracy: 0.9782

 347/1688 [=====>........................] - ETA: 5s - loss: 0.0687 - accuracy: 0.9781

 359/1688 [=====>........................] - ETA: 5s - loss: 0.0696 - accuracy: 0.9778

 371/1688 [=====>........................] - ETA: 5s - loss: 0.0699 - accuracy: 0.9778

 383/1688 [=====>........................] - ETA: 5s - loss: 0.0701 - accuracy: 0.9778

Epoch 3/3


   1/1688 [..............................] - ETA: 9s - loss: 0.2954 - accuracy: 0.9375

  13/1688 [..............................] - ETA: 7s - loss: 0.0623 - accuracy: 0.9832

  24/1688 [..............................] - ETA: 7s - loss: 0.0625 - accuracy: 0.9805

  36/1688 [..............................] - ETA: 7s - loss: 0.0646 - accuracy: 0.9792

  48/1688 [..............................] - ETA: 7s - loss: 0.0603 - accuracy: 0.9792

  60/1688 [>.............................] - ETA: 7s - loss: 0.0594 - accuracy: 0.9781

  72/1688 [>.............................] - ETA: 7s - loss: 0.0626 - accuracy: 0.9766

  84/1688 [>.............................] - ETA: 7s - loss: 0.0611 - accuracy: 0.9773

  96/1688 [>.............................] - ETA: 7s - loss: 0.0598 - accuracy: 0.9779

 108/1688 [>.............................] - ETA: 6s - loss: 0.0593 - accuracy: 0.9783

 120/1688 [=>............................] - ETA: 6s - loss: 0.0607 - accuracy: 0.9779

 131/1688 [=>............................] - ETA: 6s - loss: 0.0605 - accuracy: 0.9783

 143/1688 [=>............................] - ETA: 6s - loss: 0.0600 - accuracy: 0.9788

 155/1688 [=>............................] - ETA: 6s - loss: 0.0614 - accuracy: 0.9786

 167/1688 [=>............................] - ETA: 6s - loss: 0.0620 - accuracy: 0.9789

 179/1688 [==>...........................] - ETA: 6s - loss: 0.0600 - accuracy: 0.9801

 191/1688 [==>...........................] - ETA: 6s - loss: 0.0596 - accuracy: 0.9802

 203/1688 [==>...........................] - ETA: 6s - loss: 0.0596 - accuracy: 0.9798

 215/1688 [==>...........................] - ETA: 6s - loss: 0.0591 - accuracy: 0.9802

 226/1688 [===>..........................] - ETA: 6s - loss: 0.0584 - accuracy: 0.9804

 238/1688 [===>..........................] - ETA: 6s - loss: 0.0576 - accuracy: 0.9807

 250/1688 [===>..........................] - ETA: 6s - loss: 0.0572 - accuracy: 0.9809

 262/1688 [===>..........................] - ETA: 6s - loss: 0.0567 - accuracy: 0.9810

 273/1688 [===>..........................] - ETA: 6s - loss: 0.0571 - accuracy: 0.9813

 285/1688 [====>.........................] - ETA: 6s - loss: 0.0583 - accuracy: 0.9808

 297/1688 [====>.........................] - ETA: 6s - loss: 0.0579 - accuracy: 0.9810

 309/1688 [====>.........................] - ETA: 6s - loss: 0.0580 - accuracy: 0.9809

 321/1688 [====>.........................] - ETA: 6s - loss: 0.0584 - accuracy: 0.9807

 333/1688 [====>.........................] - ETA: 6s - loss: 0.0587 - accuracy: 0.9808

 345/1688 [=====>........................] - ETA: 5s - loss: 0.0586 - accuracy: 0.9808

 357/1688 [=====>........................] - ETA: 5s - loss: 0.0579 - accuracy: 0.9807

 369/1688 [=====>........................] - ETA: 5s - loss: 0.0577 - accuracy: 0.9809

 381/1688 [=====>........................] - ETA: 5s - loss: 0.0575 - accuracy: 0.9810

 393/1688 [=====>........................] - ETA: 5s - loss: 0.0575 - accuracy: 0.9810

<tf_keras.src.callbacks.History at 0x7f52bc36d3a0>

Define a helper function that calculates and prints the sparsity of each kernel in a model.

In [8]:
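def print_model_weights_sparsity(model):
    """Print the fraction of zero-valued entries in each kernel of `model`."""
    for layer in model.layers:
        # For wrapped layers (such as the clustering wrappers), inspect only the
        # trainable weights; otherwise inspect all of the layer's weights.
        if isinstance(layer, keras.layers.Wrapper):
            weights = layer.trainable_weights
        else:
            weights = layer.weights
        for weight in weights:
            # Report kernels only; skip biases and clustering centroids.
            if "kernel" not in weight.name or "centroid" in weight.name:
                continue
            weight_values = weight.numpy()
            weight_size = weight_values.size
            zero_num = np.count_nonzero(weight_values == 0)
            print(
                f"{weight.name}: {zero_num/weight_size:.2%} sparsity ",
                f"({zero_num}/{weight_size})",
            )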
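The same measurement can be made on plain NumPy arrays. This is handy for weights that have already been extracted from a model, for example with `{w.name: w.numpy() for w in model.weights}`. The helper name `weights_sparsity` is hypothetical and not part of Keras or TFMOT.

```python
import numpy as np

def weights_sparsity(named_weights):
    """Return {name: (zero_count, total_count)} for each array in the mapping."""
    report = {}
    for name, array in named_weights.items():
        array = np.asarray(array)
        report[name] = (int(np.count_nonzero(array == 0)), array.size)
    return report

example = {"kernel": np.array([[0.0, 1.5], [0.0, -2.0]])}
for name, (zeros, size) in weights_sparsity(example).items():
    print(f"{name}: {zeros / size:.2%} sparsity ({zeros}/{size})")
# kernel: 50.00% sparsity (2/4)
```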

Check that the model kernels were correctly pruned. We need to strip the pruning wrappers first. We also create a deep copy of the model (a clone with the same weights) to be used in the next step.

In [9]:
stripped_pruned_model = tfmot.sparsity.keras.strip_pruning(pruned_model)

print_model_weights_sparsity(stripped_pruned_model)

stripped_pruned_model_copy = keras.models.clone_model(stripped_pruned_model)
stripped_pruned_model_copy.set_weights(stripped_pruned_model.get_weights())

conv2d/kernel:0: 50.00% sparsity  (54/108)
dense/kernel:0: 50.00% sparsity  (10140/20280)


## Apply clustering and sparsity preserving clustering and check their effect on model sparsity

Next, we apply both clustering and sparsity preserving clustering to the pruned model and observe that the latter preserves the pruned model's sparsity. Note that we stripped the pruning wrappers from the pruned model with `tfmot.sparsity.keras.strip_pruning` before applying the clustering API.
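The sketch below illustrates why the `preserve_sparsity` flag matters, on a synthetic, 50%-sparse matrix shaped like the dense kernel. It uses plain 1-D k-means with linearly spaced initial centroids rather than the k-means++ initialization configured in the code below. It illustrates the idea only and is not TFMOT's implementation. The function names are hypothetical.

- **Plain mode:** the zeros are just more points to cluster. Nothing pins the centroid they share at exactly zero, and fine-tuning the centroids can move it.
- **Sparsity-preserving mode:** the zeros are excluded from clustering and left untouched.

```python
import numpy as np

def kmeans_1d(values, k, iters=20):
    """Plain 1-D k-means (Lloyd's algorithm) with linearly spaced initial centroids."""
    centroids = np.linspace(values.min(), values.max(), k)
    for _ in range(iters):
        # Assign every value to its nearest centroid.
        labels = np.argmin(np.abs(values[:, None] - centroids[None, :]), axis=1)
        # Move each non-empty cluster's centroid to the mean of its members.
        for c in range(k):
            members = values[labels == c]
            if members.size:
                centroids[c] = members.mean()
    labels = np.argmin(np.abs(values[:, None] - centroids[None, :]), axis=1)
    return centroids, labels

def cluster(weights, k, preserve_sparsity=False):
    """Replace each weight by its cluster centroid (simplified illustration)."""
    flat = weights.ravel()
    out = np.zeros_like(flat)
    # In sparsity-preserving mode, cluster only the non-zero weights; zeros stay zero.
    idx = np.flatnonzero(flat) if preserve_sparsity else np.arange(flat.size)
    if idx.size == 0:
        return out.reshape(weights.shape)
    centroids, labels = kmeans_1d(flat[idx], k)
    out[idx] = centroids[labels]
    return out.reshape(weights.shape)

# Synthetic 50%-sparse matrix shaped like this model's dense kernel (2028 x 10).
rng = np.random.default_rng(0)
w = rng.normal(size=(2028, 10))
w[np.abs(w) < np.median(np.abs(w))] = 0.0

plain = cluster(w, k=8)
sparse = cluster(w, k=8, preserve_sparsity=True)
print("original sparsity:", np.mean(w == 0))
print("plain clustering, distinct values:", len(np.unique(plain)))
print("sparsity-preserving sparsity:", np.mean(sparse == 0),
      "distinct non-zero values:", len(np.unique(sparse[sparse != 0])))
```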

In [10]:
# Clustering
cluster_weights = tfmot.clustering.keras.cluster_weights
CentroidInitialization = tfmot.clustering.keras.CentroidInitialization

clustering_params = {
  'number_of_clusters': 8,
  'cluster_centroids_init': CentroidInitialization.KMEANS_PLUS_PLUS
}

clustered_model = cluster_weights(stripped_pruned_model, **clustering_params)

clustered_model.compile(
    optimizer='adam',
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=['accuracy'])

print('Train clustering model:')
clustered_model.fit(train_images, train_labels, epochs=3, validation_split=0.1)


stripped_pruned_model.save("stripped_pruned_model_clustered.h5")

# Sparsity preserving clustering
from tensorflow_model_optimization.python.core.clustering.keras.experimental import (
    cluster,
)

cluster_weights = cluster.cluster_weights

clustering_params = {
  'number_of_clusters': 8,
  'cluster_centroids_init': CentroidInitialization.KMEANS_PLUS_PLUS,
  'preserve_sparsity': True
}

sparsity_clustered_model = cluster_weights(stripped_pruned_model_copy, **clustering_params)

sparsity_clustered_model.compile(
    optimizer='adam',
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=['accuracy'])

print('Train sparsity preserving clustering model:')
sparsity_clustered_model.fit(train_images, train_labels, epochs=3, validation_split=0.1)

Train clustering model:


Epoch 1/3


   1/1688 [..............................] - ETA: 15:36 - loss: 0.0178 - accuracy: 1.0000

  12/1688 [..............................] - ETA: 8s - loss: 0.0683 - accuracy: 0.9792   

  23/1688 [..............................] - ETA: 8s - loss: 0.0744 - accuracy: 0.9755

  34/1688 [..............................] - ETA: 7s - loss: 0.0655 - accuracy: 0.9798

  45/1688 [..............................] - ETA: 7s - loss: 0.0608 - accuracy: 0.9812

  56/1688 [..............................] - ETA: 7s - loss: 0.0584 - accuracy: 0.9827

  67/1688 [>.............................] - ETA: 7s - loss: 0.0608 - accuracy: 0.9832

  78/1688 [>.............................] - ETA: 7s - loss: 0.0602 - accuracy: 0.9836

  89/1688 [>.............................] - ETA: 7s - loss: 0.0566 - accuracy: 0.9842

 101/1688 [>.............................] - ETA: 7s - loss: 0.0575 - accuracy: 0.9845

 113/1688 [=>............................] - ETA: 7s - loss: 0.0560 - accuracy: 0.9848

 125/1688 [=>............................] - ETA: 7s - loss: 0.0539 - accuracy: 0.9850

 137/1688 [=>............................] - ETA: 7s - loss: 0.0520 - accuracy: 0.9859

 149/1688 [=>............................] - ETA: 7s - loss: 0.0524 - accuracy: 0.9851

 161/1688 [=>............................] - ETA: 6s - loss: 0.0528 - accuracy: 0.9849

 173/1688 [==>...........................] - ETA: 6s - loss: 0.0520 - accuracy: 0.9850

 185/1688 [==>...........................] - ETA: 6s - loss: 0.0507 - accuracy: 0.9853

 197/1688 [==>...........................] - ETA: 6s - loss: 0.0489 - accuracy: 0.9859

 209/1688 [==>...........................] - ETA: 6s - loss: 0.0475 - accuracy: 0.9861

 221/1688 [==>...........................] - ETA: 6s - loss: 0.0468 - accuracy: 0.9860

 233/1688 [===>..........................] - ETA: 6s - loss: 0.0463 - accuracy: 0.9862

 245/1688 [===>..........................] - ETA: 6s - loss: 0.0461 - accuracy: 0.9864

 257/1688 [===>..........................] - ETA: 6s - loss: 0.0456 - accuracy: 0.9865

 269/1688 [===>..........................] - ETA: 6s - loss: 0.0449 - accuracy: 0.9868

 281/1688 [===>..........................] - ETA: 6s - loss: 0.0455 - accuracy: 0.9862

 293/1688 [====>.........................] - ETA: 6s - loss: 0.0452 - accuracy: 0.9865

 305/1688 [====>.........................] - ETA: 6s - loss: 0.0450 - accuracy: 0.9865

 317/1688 [====>.........................] - ETA: 6s - loss: 0.0446 - accuracy: 0.9867

 329/1688 [====>.........................] - ETA: 6s - loss: 0.0449 - accuracy: 0.9869

 341/1688 [=====>........................] - ETA: 5s - loss: 0.0445 - accuracy: 0.9872

 353/1688 [=====>........................] - ETA: 5s - loss: 0.0447 - accuracy: 0.9872

 365/1688 [=====>........................] - ETA: 5s - loss: 0.0448 - accuracy: 0.9872

 377/1688 [=====>........................] - ETA: 5s - loss: 0.0441 - accuracy: 0.9874

 389/1688 [=====>........................] - ETA: 5s - loss: 0.0447 - accuracy: 0.9874

Epoch 2/3


   1/1688 [..............................] - ETA: 8s - loss: 0.0023 - accuracy: 1.0000

  13/1688 [..............................] - ETA: 7s - loss: 0.0184 - accuracy: 0.9952

  25/1688 [..............................] - ETA: 7s - loss: 0.0183 - accuracy: 0.9950

  37/1688 [..............................] - ETA: 7s - loss: 0.0254 - accuracy: 0.9916

  48/1688 [..............................] - ETA: 7s - loss: 0.0298 - accuracy: 0.9902

  60/1688 [>.............................] - ETA: 7s - loss: 0.0342 - accuracy: 0.9880

  72/1688 [>.............................] - ETA: 7s - loss: 0.0355 - accuracy: 0.9874

  84/1688 [>.............................] - ETA: 7s - loss: 0.0411 - accuracy: 0.9881

  95/1688 [>.............................] - ETA: 7s - loss: 0.0384 - accuracy: 0.9888

 107/1688 [>.............................] - ETA: 7s - loss: 0.0401 - accuracy: 0.9880

 119/1688 [=>............................] - ETA: 7s - loss: 0.0393 - accuracy: 0.9884

 131/1688 [=>............................] - ETA: 6s - loss: 0.0412 - accuracy: 0.9874

 143/1688 [=>............................] - ETA: 6s - loss: 0.0392 - accuracy: 0.9882

 155/1688 [=>............................] - ETA: 6s - loss: 0.0385 - accuracy: 0.9885

 166/1688 [=>............................] - ETA: 6s - loss: 0.0383 - accuracy: 0.9887

 177/1688 [==>...........................] - ETA: 6s - loss: 0.0386 - accuracy: 0.9887

 188/1688 [==>...........................] - ETA: 6s - loss: 0.0388 - accuracy: 0.9885

 200/1688 [==>...........................] - ETA: 6s - loss: 0.0374 - accuracy: 0.9891

 212/1688 [==>...........................] - ETA: 6s - loss: 0.0372 - accuracy: 0.9892

 223/1688 [==>...........................] - ETA: 6s - loss: 0.0362 - accuracy: 0.9895

 234/1688 [===>..........................] - ETA: 6s - loss: 0.0357 - accuracy: 0.9896

 246/1688 [===>..........................] - ETA: 6s - loss: 0.0363 - accuracy: 0.9895

 258/1688 [===>..........................] - ETA: 6s - loss: 0.0359 - accuracy: 0.9897

 270/1688 [===>..........................] - ETA: 6s - loss: 0.0358 - accuracy: 0.9898

 282/1688 [====>.........................] - ETA: 6s - loss: 0.0355 - accuracy: 0.9898

 294/1688 [====>.........................] - ETA: 6s - loss: 0.0353 - accuracy: 0.9898

 306/1688 [====>.........................] - ETA: 6s - loss: 0.0348 - accuracy: 0.9899

 317/1688 [====>.........................] - ETA: 6s - loss: 0.0349 - accuracy: 0.9897

 329/1688 [====>.........................] - ETA: 6s - loss: 0.0350 - accuracy: 0.9895

 341/1688 [=====>........................] - ETA: 6s - loss: 0.0347 - accuracy: 0.9894

 352/1688 [=====>........................] - ETA: 5s - loss: 0.0344 - accuracy: 0.9893

 364/1688 [=====>........................] - ETA: 5s - loss: 0.0354 - accuracy: 0.9891

 375/1688 [=====>........................] - ETA: 5s - loss: 0.0357 - accuracy: 0.9889

 386/1688 [=====>........................] - ETA: 5s - loss: 0.0358 - accuracy: 0.9887

Epoch 3/3


   1/1688 [..............................] - ETA: 9s - loss: 0.0195 - accuracy: 1.0000

  13/1688 [..............................] - ETA: 7s - loss: 0.0195 - accuracy: 0.9976

  25/1688 [..............................] - ETA: 7s - loss: 0.0227 - accuracy: 0.9937

  36/1688 [..............................] - ETA: 7s - loss: 0.0304 - accuracy: 0.9913

  48/1688 [..............................] - ETA: 7s - loss: 0.0256 - accuracy: 0.9928

  59/1688 [>.............................] - ETA: 7s - loss: 0.0294 - accuracy: 0.9915

  70/1688 [>.............................] - ETA: 7s - loss: 0.0308 - accuracy: 0.9915

  82/1688 [>.............................] - ETA: 7s - loss: 0.0304 - accuracy: 0.9912

  93/1688 [>.............................] - ETA: 7s - loss: 0.0349 - accuracy: 0.9899

 105/1688 [>.............................] - ETA: 7s - loss: 0.0354 - accuracy: 0.9902

 116/1688 [=>............................] - ETA: 7s - loss: 0.0349 - accuracy: 0.9900

 128/1688 [=>............................] - ETA: 7s - loss: 0.0346 - accuracy: 0.9897

 140/1688 [=>............................] - ETA: 7s - loss: 0.0356 - accuracy: 0.9891

 152/1688 [=>............................] - ETA: 6s - loss: 0.0352 - accuracy: 0.9891

 164/1688 [=>............................] - ETA: 6s - loss: 0.0355 - accuracy: 0.9889

 176/1688 [==>...........................] - ETA: 6s - loss: 0.0352 - accuracy: 0.9890

 187/1688 [==>...........................] - ETA: 6s - loss: 0.0371 - accuracy: 0.9881

 198/1688 [==>...........................] - ETA: 6s - loss: 0.0364 - accuracy: 0.9883

 209/1688 [==>...........................] - ETA: 6s - loss: 0.0369 - accuracy: 0.9883

 220/1688 [==>...........................] - ETA: 6s - loss: 0.0372 - accuracy: 0.9884

 232/1688 [===>..........................] - ETA: 6s - loss: 0.0381 - accuracy: 0.9875

 244/1688 [===>..........................] - ETA: 6s - loss: 0.0372 - accuracy: 0.9878

 256/1688 [===>..........................] - ETA: 6s - loss: 0.0369 - accuracy: 0.9879

 268/1688 [===>..........................] - ETA: 6s - loss: 0.0372 - accuracy: 0.9876

 280/1688 [===>..........................] - ETA: 6s - loss: 0.0378 - accuracy: 0.9874

 291/1688 [====>.........................] - ETA: 6s - loss: 0.0373 - accuracy: 0.9873

 303/1688 [====>.........................] - ETA: 6s - loss: 0.0367 - accuracy: 0.9875

 315/1688 [====>.........................] - ETA: 6s - loss: 0.0374 - accuracy: 0.9871

 327/1688 [====>.........................] - ETA: 6s - loss: 0.0374 - accuracy: 0.9870

 339/1688 [=====>........................] - ETA: 6s - loss: 0.0374 - accuracy: 0.9872

 351/1688 [=====>........................] - ETA: 6s - loss: 0.0370 - accuracy: 0.9874

 363/1688 [=====>........................] - ETA: 5s - loss: 0.0370 - accuracy: 0.9874

 375/1688 [=====>........................] - ETA: 5s - loss: 0.0370 - accuracy: 0.9875

 387/1688 [=====>........................] - ETA: 5s - loss: 0.0364 - accuracy: 0.9878


Train sparsity preserving clustering model:




Epoch 1/3


   1/1688 [..............................] - ETA: 17:32 - loss: 0.0251 - accuracy: 1.0000

  12/1688 [..............................] - ETA: 8s - loss: 0.0789 - accuracy: 0.9766   

  23/1688 [..............................] - ETA: 8s - loss: 0.0765 - accuracy: 0.9783

  34/1688 [..............................] - ETA: 8s - loss: 0.0795 - accuracy: 0.9743

  45/1688 [..............................] - ETA: 8s - loss: 0.0762 - accuracy: 0.9736

  56/1688 [..............................] - ETA: 7s - loss: 0.0705 - accuracy: 0.9760

  67/1688 [>.............................] - ETA: 7s - loss: 0.0697 - accuracy: 0.9743

  78/1688 [>.............................] - ETA: 7s - loss: 0.0664 - accuracy: 0.9764

  89/1688 [>.............................] - ETA: 7s - loss: 0.0659 - accuracy: 0.9768

 100/1688 [>.............................] - ETA: 7s - loss: 0.0667 - accuracy: 0.9766

 111/1688 [>.............................] - ETA: 7s - loss: 0.0657 - accuracy: 0.9769

 122/1688 [=>............................] - ETA: 7s - loss: 0.0663 - accuracy: 0.9775

 134/1688 [=>............................] - ETA: 7s - loss: 0.0640 - accuracy: 0.9781

 145/1688 [=>............................] - ETA: 7s - loss: 0.0636 - accuracy: 0.9778

 156/1688 [=>............................] - ETA: 7s - loss: 0.0621 - accuracy: 0.9786

 168/1688 [=>............................] - ETA: 7s - loss: 0.0617 - accuracy: 0.9786

 180/1688 [==>...........................] - ETA: 7s - loss: 0.0626 - accuracy: 0.9785

 191/1688 [==>...........................] - ETA: 7s - loss: 0.0615 - accuracy: 0.9791

 202/1688 [==>...........................] - ETA: 7s - loss: 0.0619 - accuracy: 0.9791

 213/1688 [==>...........................] - ETA: 6s - loss: 0.0607 - accuracy: 0.9795

 224/1688 [==>...........................] - ETA: 6s - loss: 0.0602 - accuracy: 0.9799

 235/1688 [===>..........................] - ETA: 6s - loss: 0.0592 - accuracy: 0.9803

 246/1688 [===>..........................] - ETA: 6s - loss: 0.0590 - accuracy: 0.9806

 257/1688 [===>..........................] - ETA: 6s - loss: 0.0587 - accuracy: 0.9807

 269/1688 [===>..........................] - ETA: 6s - loss: 0.0597 - accuracy: 0.9803

 280/1688 [===>..........................] - ETA: 6s - loss: 0.0589 - accuracy: 0.9806

 291/1688 [====>.........................] - ETA: 6s - loss: 0.0583 - accuracy: 0.9809

 303/1688 [====>.........................] - ETA: 6s - loss: 0.0581 - accuracy: 0.9809

 314/1688 [====>.........................] - ETA: 6s - loss: 0.0569 - accuracy: 0.9815

 325/1688 [====>.........................] - ETA: 6s - loss: 0.0560 - accuracy: 0.9819

 337/1688 [====>.........................] - ETA: 6s - loss: 0.0568 - accuracy: 0.9816

 349/1688 [=====>........................] - ETA: 6s - loss: 0.0561 - accuracy: 0.9820

 360/1688 [=====>........................] - ETA: 6s - loss: 0.0557 - accuracy: 0.9822

 371/1688 [=====>........................] - ETA: 6s - loss: 0.0550 - accuracy: 0.9825

 382/1688 [=====>........................] - ETA: 6s - loss: 0.0545 - accuracy: 0.9825

 393/1688 [=====>........................] - ETA: 6s - loss: 0.0542 - accuracy: 0.9823


Epoch 2/3


   1/1688 [..............................] - ETA: 8s - loss: 0.0058 - accuracy: 1.0000

  13/1688 [..............................] - ETA: 7s - loss: 0.0506 - accuracy: 0.9880

  24/1688 [..............................] - ETA: 7s - loss: 0.0381 - accuracy: 0.9909

  35/1688 [..............................] - ETA: 7s - loss: 0.0440 - accuracy: 0.9866

  47/1688 [..............................] - ETA: 7s - loss: 0.0431 - accuracy: 0.9867

  59/1688 [>.............................] - ETA: 7s - loss: 0.0436 - accuracy: 0.9862

  71/1688 [>.............................] - ETA: 7s - loss: 0.0410 - accuracy: 0.9877

  82/1688 [>.............................] - ETA: 7s - loss: 0.0430 - accuracy: 0.9870

  94/1688 [>.............................] - ETA: 7s - loss: 0.0446 - accuracy: 0.9857

 105/1688 [>.............................] - ETA: 7s - loss: 0.0441 - accuracy: 0.9857

 116/1688 [=>............................] - ETA: 7s - loss: 0.0429 - accuracy: 0.9865

 127/1688 [=>............................] - ETA: 7s - loss: 0.0446 - accuracy: 0.9850

 138/1688 [=>............................] - ETA: 7s - loss: 0.0448 - accuracy: 0.9848

 150/1688 [=>............................] - ETA: 6s - loss: 0.0438 - accuracy: 0.9848

 161/1688 [=>............................] - ETA: 6s - loss: 0.0426 - accuracy: 0.9854

 172/1688 [==>...........................] - ETA: 6s - loss: 0.0426 - accuracy: 0.9853

 183/1688 [==>...........................] - ETA: 6s - loss: 0.0423 - accuracy: 0.9851

 194/1688 [==>...........................] - ETA: 6s - loss: 0.0422 - accuracy: 0.9853

 205/1688 [==>...........................] - ETA: 6s - loss: 0.0410 - accuracy: 0.9857

 217/1688 [==>...........................] - ETA: 6s - loss: 0.0424 - accuracy: 0.9852

 228/1688 [===>..........................] - ETA: 6s - loss: 0.0429 - accuracy: 0.9851

 240/1688 [===>..........................] - ETA: 6s - loss: 0.0428 - accuracy: 0.9850

 251/1688 [===>..........................] - ETA: 6s - loss: 0.0427 - accuracy: 0.9851

 262/1688 [===>..........................] - ETA: 6s - loss: 0.0419 - accuracy: 0.9852

 274/1688 [===>..........................] - ETA: 6s - loss: 0.0424 - accuracy: 0.9849

 285/1688 [====>.........................] - ETA: 6s - loss: 0.0428 - accuracy: 0.9850

 297/1688 [====>.........................] - ETA: 6s - loss: 0.0425 - accuracy: 0.9854

 308/1688 [====>.........................] - ETA: 6s - loss: 0.0425 - accuracy: 0.9855

 319/1688 [====>.........................] - ETA: 6s - loss: 0.0419 - accuracy: 0.9857

 331/1688 [====>.........................] - ETA: 6s - loss: 0.0422 - accuracy: 0.9856

 342/1688 [=====>........................] - ETA: 6s - loss: 0.0419 - accuracy: 0.9859

 353/1688 [=====>........................] - ETA: 6s - loss: 0.0418 - accuracy: 0.9858

 365/1688 [=====>........................] - ETA: 6s - loss: 0.0415 - accuracy: 0.9859

 376/1688 [=====>........................] - ETA: 5s - loss: 0.0415 - accuracy: 0.9860

 388/1688 [=====>........................] - ETA: 5s - loss: 0.0415 - accuracy: 0.9859

Epoch 3/3


   1/1688 [..............................] - ETA: 8s - loss: 0.0346 - accuracy: 1.0000

  12/1688 [..............................] - ETA: 7s - loss: 0.0440 - accuracy: 0.9896

  23/1688 [..............................] - ETA: 7s - loss: 0.0404 - accuracy: 0.9891

  35/1688 [..............................] - ETA: 7s - loss: 0.0453 - accuracy: 0.9884

  47/1688 [..............................] - ETA: 7s - loss: 0.0404 - accuracy: 0.9907

  58/1688 [>.............................] - ETA: 7s - loss: 0.0396 - accuracy: 0.9903

  70/1688 [>.............................] - ETA: 7s - loss: 0.0430 - accuracy: 0.9875

  81/1688 [>.............................] - ETA: 7s - loss: 0.0426 - accuracy: 0.9880

  92/1688 [>.............................] - ETA: 7s - loss: 0.0411 - accuracy: 0.9885

 103/1688 [>.............................] - ETA: 7s - loss: 0.0426 - accuracy: 0.9876

 115/1688 [=>............................] - ETA: 7s - loss: 0.0403 - accuracy: 0.9883

 127/1688 [=>............................] - ETA: 7s - loss: 0.0399 - accuracy: 0.9887

 138/1688 [=>............................] - ETA: 7s - loss: 0.0385 - accuracy: 0.9891

 150/1688 [=>............................] - ETA: 6s - loss: 0.0395 - accuracy: 0.9890

 162/1688 [=>............................] - ETA: 6s - loss: 0.0399 - accuracy: 0.9892

 174/1688 [==>...........................] - ETA: 6s - loss: 0.0395 - accuracy: 0.9894

 186/1688 [==>...........................] - ETA: 6s - loss: 0.0386 - accuracy: 0.9896

 198/1688 [==>...........................] - ETA: 6s - loss: 0.0381 - accuracy: 0.9897

 210/1688 [==>...........................] - ETA: 6s - loss: 0.0370 - accuracy: 0.9902

 222/1688 [==>...........................] - ETA: 6s - loss: 0.0363 - accuracy: 0.9901

 233/1688 [===>..........................] - ETA: 6s - loss: 0.0369 - accuracy: 0.9898

 244/1688 [===>..........................] - ETA: 6s - loss: 0.0372 - accuracy: 0.9896

 256/1688 [===>..........................] - ETA: 6s - loss: 0.0367 - accuracy: 0.9897

 267/1688 [===>..........................] - ETA: 6s - loss: 0.0365 - accuracy: 0.9896

 278/1688 [===>..........................] - ETA: 6s - loss: 0.0362 - accuracy: 0.9897

 289/1688 [====>.........................] - ETA: 6s - loss: 0.0361 - accuracy: 0.9896

 301/1688 [====>.........................] - ETA: 6s - loss: 0.0359 - accuracy: 0.9896

 312/1688 [====>.........................] - ETA: 6s - loss: 0.0361 - accuracy: 0.9896

 323/1688 [====>.........................] - ETA: 6s - loss: 0.0353 - accuracy: 0.9898

 334/1688 [====>.........................] - ETA: 6s - loss: 0.0356 - accuracy: 0.9897

 345/1688 [=====>........................] - ETA: 6s - loss: 0.0356 - accuracy: 0.9899

 356/1688 [=====>........................] - ETA: 6s - loss: 0.0357 - accuracy: 0.9898

 367/1688 [=====>........................] - ETA: 6s - loss: 0.0355 - accuracy: 0.9898

 378/1688 [=====>........................] - ETA: 5s - loss: 0.0350 - accuracy: 0.9898

 389/1688 [=====>........................] - ETA: 5s - loss: 0.0363 - accuracy: 0.9896


<tf_keras.src.callbacks.History at 0x7f52a0150a60>

Check sparsity for both models.

In [11]:
print("Clustered Model sparsity:\n")
print_model_weights_sparsity(clustered_model)
print("\nSparsity preserved clustered Model sparsity:\n")
print_model_weights_sparsity(sparsity_clustered_model)

Clustered Model sparsity:

conv2d/kernel:0: 8.33% sparsity  (9/108)
dense/kernel:0: 5.59% sparsity  (1133/20280)

Sparsity preserved clustered Model sparsity:

conv2d/kernel:0: 50.00% sparsity  (54/108)
dense/kernel:0: 50.00% sparsity  (10140/20280)
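
`print_model_weights_sparsity` is defined earlier in the notebook. The hypothetical helper below is a rough sketch of what the percentages above measure. It computes the fraction of exactly-zero entries in a kernel and prints it in the same `sparsity  (zeros/total)` format. It also counts the distinct nonzero values, which for a clustered kernel should not exceed the number of clusters. It works on plain NumPy arrays and is not the notebook's own helper. The toy kernel and its centroid values are assumptions chosen for illustration.

```python
import numpy as np


def describe_kernel(name, kernel):
    """Hypothetical helper: report sparsity and distinct nonzero values."""
    kernel = np.asarray(kernel)
    total = kernel.size
    zeros = int(np.count_nonzero(kernel == 0))
    # Distinct values among the nonzero weights; a clustered kernel should
    # have at most one per cluster centroid.
    n_unique_nonzero = int(np.unique(kernel[kernel != 0]).size)
    line = f"{name}: {100 * zeros / total:.2f}% sparsity  ({zeros}/{total})"
    return line, n_unique_nonzero


# Toy kernel with 108 weights: half set to zero, the rest drawn from
# 4 hypothetical centroid values.
rng = np.random.default_rng(0)
centroids = np.array([-0.3, -0.1, 0.1, 0.3], dtype=np.float32)
kernel = rng.choice(centroids, size=108)
kernel[rng.permutation(108)[:54]] = 0.0
kernel = kernel.reshape(3, 3, 1, 12)

line, n_unique = describe_kernel("conv2d/kernel:0", kernel)
print(line)
print("distinct nonzero values:", n_unique)
```

Pruning to 50% sparsity followed by plain clustering lets zeros drift to the nearest centroid. This is why the plain clustered model above retains only a few percent zeros.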


## Create 1.6x smaller models from clustering

Define a helper function that returns the compressed size of a model file.

In [12]:
def get_gzipped_model_size(file):
  # Returns the size of the model file after ZIP (DEFLATE) compression, in
  # kilobytes (1 KB = 1000 bytes).
  fd, zipped_file = tempfile.mkstemp('.zip')
  os.close(fd)  # Only the path is needed; close the descriptor mkstemp opened.
  with zipfile.ZipFile(zipped_file, 'w', compression=zipfile.ZIP_DEFLATED) as f:
    f.write(file)

  return os.path.getsize(zipped_file)/1000
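
`get_gzipped_model_size` measures a file on disk, so it needs a saved model. The sketch below performs the same measurement on bytes already held in memory, such as the `bytes` object returned by `converter.convert()`. `zipped_size_kb` is a hypothetical name and is not part of TensorFlow or TFMOT.

```python
import io
import os
import zipfile


def zipped_size_kb(data: bytes, arcname: str = "model.bin") -> float:
    """Hypothetical in-memory counterpart of get_gzipped_model_size.

    Writes `data` into a ZIP archive held in a BytesIO buffer, using the same
    DEFLATE compression, and returns the archive size in kilobytes
    (1 KB = 1000 bytes).
    """
    buffer = io.BytesIO()
    with zipfile.ZipFile(buffer, "w", compression=zipfile.ZIP_DEFLATED) as zf:
        zf.writestr(arcname, data)
    # ZipFile does not close a file object it was given, so the buffer is
    # still readable here.
    return len(buffer.getvalue()) / 1000


zeros_kb = zipped_size_kb(bytes(100_000))       # repetitive data: compresses well
noise_kb = zipped_size_kb(os.urandom(100_000))  # random data: barely compresses
print(f"100 KB of zeros -> {zeros_kb} KB zipped")
print(f"100 KB of noise -> {noise_kb} KB zipped")
```

The two calls illustrate the property this section relies on: repeated values, and long runs of zeros in particular, shrink dramatically under DEFLATE, while unstructured data does not.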

In [13]:
# Clustered model
clustered_model_file = 'clustered_model.h5'

# Save the model.
clustered_model.save(clustered_model_file)

# Sparsity preserving clustered model
sparsity_clustered_model_file = 'sparsity_clustered_model.h5'

# Save the model.
sparsity_clustered_model.save(sparsity_clustered_model_file)

print("Clustered Model size: ", get_gzipped_model_size(clustered_model_file), ' KB')
print("Sparsity preserved clustered Model size: ", get_gzipped_model_size(sparsity_clustered_model_file), ' KB')

Clustered Model size:  233.187  KB
Sparsity preserved clustered Model size:  148.963  KB
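
From the output above, 233.187 / 148.963 ≈ 1.57, which the heading rounds to 1.6x. The sketch below illustrates where the gap comes from; it does not reproduce the notebook's models. It assumes a toy weight array of 20,280 values (the size of `dense/kernel` above), each snapped to one of 16 hypothetical centroids. It then compares zlib (DEFLATE, the same algorithm as `ZIP_DEFLATED`) output for that array with and without half of its entries set to zero.

```python
import zlib

import numpy as np

# Assumed toy setup: 20,280 weights, each snapped to one of 16 hypothetical
# cluster centroids.
rng = np.random.default_rng(42)
centroids = rng.normal(0.0, 0.1, size=16).astype(np.float32)
clustered = rng.choice(centroids, size=20_280)

# Same weights with half of them forced to exactly zero, mimicking the 50%
# sparsity kept by sparsity preserving clustering.
sparse_clustered = clustered.copy()
sparse_clustered[rng.permutation(clustered.size)[:clustered.size // 2]] = 0.0

dense_size = len(zlib.compress(clustered.tobytes(), 9))
sparse_size = len(zlib.compress(sparse_clustered.tobytes(), 9))
print("clustered:", dense_size, "bytes; sparse + clustered:", sparse_size, "bytes")

# Ratio of the two model sizes reported in the notebook output above.
print("reported size ratio:", round(233.187 / 148.963, 2))
```

Both arrays contain only a handful of distinct values. The sparse one, however, is dominated by a single value (zero), which lowers its entropy and lets DEFLATE encode it more compactly.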


## Create a TFLite model by combining sparsity preserving weight clustering and post-training quantization

Strip the clustering wrappers, then convert the model to TFLite with the default optimizations, which apply post-training quantization.

In [14]:
stripped_sparsity_clustered_model = tfmot.clustering.keras.strip_clustering(sparsity_clustered_model)

converter = tf.lite.TFLiteConverter.from_keras_model(stripped_sparsity_clustered_model)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
sparsity_clustered_quant_model = converter.convert()

_, pruned_and_clustered_tflite_file = tempfile.mkstemp('.tflite')

with open(pruned_and_clustered_tflite_file, 'wb') as f:
  f.write(sparsity_clustered_quant_model)

print("Sparsity preserved clustered Model size: ", get_gzipped_model_size(sparsity_clustered_model_file), ' KB')
print("Sparsity preserved clustered and quantized TFLite model size:",
       get_gzipped_model_size(pruned_and_clustered_tflite_file), ' KB')

INFO:tensorflow:Assets written to: /tmpfs/tmp/tmpovhqwz69/assets


Sparsity preserved clustered Model size:  148.963  KB
Sparsity preserved clustered and quantized TFLite model size: 8.125  KB


W0000 00:00:1750507096.750501   44854 tf_tfl_flatbuffer_helpers.cc:365] Ignored output_format.
W0000 00:00:1750507096.750539   44854 tf_tfl_flatbuffer_helpers.cc:368] Ignored drop_control_dependency.
I0000 00:00:1750507096.756116   44854 mlir_graph_optimization_pass.cc:425] MLIR V1 optimization pass is not enabled


## Verify that accuracy persists from TF to TFLite

In [15]:
def eval_model(interpreter):
  input_index = interpreter.get_input_details()[0]["index"]
  output_index = interpreter.get_output_details()[0]["index"]

  # Run predictions on every image in the "test" dataset.
  prediction_digits = []
  for i, test_image in enumerate(test_images):
    if i % 1000 == 0:
      print(f"Evaluated on {i} results so far.")
    # Pre-processing: add batch dimension and convert to float32 to match with
    # the model's input data format.
    test_image = np.expand_dims(test_image, axis=0).astype(np.float32)
    interpreter.set_tensor(input_index, test_image)

    # Run inference.
    interpreter.invoke()

    # Post-processing: remove batch dimension and find the digit with highest
    # probability.
    output = interpreter.tensor(output_index)
    digit = np.argmax(output()[0])
    prediction_digits.append(digit)

  print('\n')
  # Compare prediction results with ground truth labels to calculate accuracy.
  prediction_digits = np.array(prediction_digits)
  accuracy = (prediction_digits == test_labels).mean()
  return accuracy
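Because `eval_model` feeds one image at a time, each output tensor has a leading batch dimension of 1, which `output()[0]` drops before the arg-max. The post-processing and accuracy computation can be sketched without NumPy or TensorFlow as follows; the function names and the logits are made up for illustration.

```python
def argmax(scores):
    """Index of the largest score (the first one on ties), like np.argmax."""
    return max(range(len(scores)), key=scores.__getitem__)


def accuracy(batch_outputs, labels):
    """Fraction of examples whose top-scoring class matches the label.

    Each batch_outputs[i] has shape [1, num_classes], mirroring the batch
    dimension of a single-image interpreter output.
    """
    predictions = [argmax(out[0]) for out in batch_outputs]
    correct = sum(p == y for p, y in zip(predictions, labels))
    return correct / len(labels)


# Three made-up outputs with a batch dimension of 1; the last one is wrong.
outputs = [[[0.1, 2.3, -1.0]], [[3.0, 0.2, 0.1]], [[0.0, 0.5, 0.4]]]
labels = [1, 0, 2]
print(accuracy(outputs, labels))
```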

Evaluate the model, which has been pruned, clustered, and quantized, and confirm that the accuracy from TensorFlow persists in the TFLite backend.

In [16]:
# Keras model evaluation
stripped_sparsity_clustered_model.compile(
    optimizer='adam',
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=['accuracy'])
_, sparsity_clustered_keras_accuracy = stripped_sparsity_clustered_model.evaluate(
    test_images, test_labels, verbose=0)

# TFLite model evaluation
interpreter = tf.lite.Interpreter(pruned_and_clustered_tflite_file)
interpreter.allocate_tensors()

sparsity_clustered_tflite_accuracy = eval_model(interpreter)

print('Pruned, clustered and quantized Keras model accuracy:', sparsity_clustered_keras_accuracy)
print('Pruned, clustered and quantized TFLite model accuracy:', sparsity_clustered_tflite_accuracy)

    TF 2.20. Please use the LiteRT interpreter from the ai_edge_litert package.
    See the [migration guide](https://ai.google.dev/edge/litert/migration)
    for details.
    
INFO: Created TensorFlow Lite XNNPACK delegate for CPU.


Evaluated on 0 results so far.
Evaluated on 1000 results so far.
Evaluated on 2000 results so far.
Evaluated on 3000 results so far.
Evaluated on 4000 results so far.
Evaluated on 5000 results so far.
Evaluated on 6000 results so far.
Evaluated on 7000 results so far.


Evaluated on 8000 results so far.
Evaluated on 9000 results so far.


Pruned, clustered and quantized Keras model accuracy: 0.9805999994277954
Pruned, clustered and quantized TFLite model accuracy: 0.9805
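The two accuracies differ by 0.0001. Assuming a 10,000-image test set (as in MNIST, and consistent with the progress log above), that is exactly one image. The long Keras value is consistent with the metric being computed in float32: 9806 / 10000 rounded to the nearest float32 prints as 0.9805999994277954. The arithmetic can be checked with the standard library:

```python
import struct

num_test_images = 10_000  # assumption: size of the test set

keras_accuracy = 0.9805999994277954
tflite_accuracy = 0.9805

# Number of correctly classified images for each backend.
keras_correct = round(keras_accuracy * num_test_images)    # 9806
tflite_correct = round(tflite_accuracy * num_test_images)  # 9805
print(keras_correct - tflite_correct)  # 1

# Rounding 9806 / 10000 to float32 reproduces the Keras value exactly.
as_float32 = struct.unpack('<f', struct.pack('<f', 9806 / 10_000))[0]
print(as_float32)  # 0.9805999994277954
```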


## Conclusion

In this tutorial, you learned how to create a model, prune it using the `prune_low_magnitude()` API, and apply sparsity preserving clustering so that the pruned weights stay at zero while the remaining weights are clustered. Comparing the sparsity preserving clustered model with a plainly clustered one showed that sparsity is preserved in the former and lost in the latter. Next, you converted the pruned and clustered model to TFLite to show the compression benefits of chaining the pruning and sparsity preserving clustering techniques. Finally, you evaluated the TFLite model to confirm that the accuracy persists in the TFLite backend.