**Copyright 2021 The TensorFlow Authors.**

In [1]:
#@title Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


# Sparsity and cluster preserving quantization aware training (PCQAT) Keras example

## Overview

This is an end-to-end example showing how to use the **sparsity and cluster preserving quantization aware training (PCQAT)** API, part of the TensorFlow Model Optimization Toolkit's collaborative optimization pipeline.
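
To make "sparsity preserving" concrete, the toy sketch below contrasts an ordinary fine-tuning update with one that re-applies a zero mask after the update. It is an illustration only and not the toolkit's implementation: the weights, gradients, learning rate, and the `sgd_step` and `sparsity` helpers are all made up for this example.

```python
def sgd_step(weights, grads, lr=0.5, mask=None):
    """One plain SGD step; if a 0/1 mask is given, masked positions are forced to zero."""
    updated = [w - lr * g for w, g in zip(weights, grads)]
    if mask is not None:
        updated = [u * m for u, m in zip(updated, mask)]
    return updated


def sparsity(weights):
    """Fraction of weights that are exactly zero."""
    return sum(1 for w in weights if w == 0.0) / len(weights)


# Hypothetical pruned weights: half of them are exactly zero.
pruned = [0.0, 0.8, 0.0, -0.4]
grads = [0.2, -0.1, 0.3, 0.1]  # made-up gradients for illustration
mask = [0.0 if w == 0.0 else 1.0 for w in pruned]

plain = sgd_step(pruned, grads)                 # zeros drift away from zero
preserved = sgd_step(pruned, grads, mask=mask)  # zeros are re-imposed

print(sparsity(pruned), sparsity(plain), sparsity(preserved))
```

The unmasked update leaves no zero weights, while the masked update keeps the original zero pattern. This mirrors the difference the tutorial goes on to demonstrate between plain QAT and PCQAT.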

### Other pages

For an introduction to the pipeline and other available techniques, see the [collaborative optimization overview page](https://www.tensorflow.org/model_optimization/guide/combine/collaborative_optimization).

### Contents

In the tutorial, you will:

1. Train a `keras` model for the MNIST dataset from scratch.
2. Fine-tune the model with pruning, check its accuracy, and observe that the model was successfully pruned.
3. Apply sparsity preserving clustering on the pruned model and observe that the sparsity applied earlier has been preserved.
4. Apply QAT and observe the loss of sparsity and clusters.
5. Apply PCQAT and observe that both sparsity and clustering applied earlier have been preserved.
6. Generate a TFLite model and observe the effects of applying PCQAT on it.
7. Compare the sizes of the different models to observe the compression benefits of applying sparsity followed by the collaborative optimization techniques of sparsity preserving clustering and PCQAT.
8. Compare the accuracy of the fully optimized model with that of the unoptimized baseline model.
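
Several of the steps above ask you to "observe" sparsity and clusters. As a rough sketch of what is being measured, the snippet below computes the fraction of zero weights and the number of distinct weight values for a toy weight list, before and after a simplified zero-preserving clustering step. The weights, the centroids, and the helper names `cluster_preserving_zeros` and `summarize` are assumptions made for illustration; they are not the toolkit's API.

```python
def cluster_preserving_zeros(weights, centroids):
    """Snap each non-zero weight to its nearest centroid; keep zeros as zeros."""
    return [
        0.0 if w == 0.0 else min(centroids, key=lambda c: abs(c - w))
        for w in weights
    ]


def summarize(weights):
    """Report sparsity (fraction of zeros) and the number of distinct values."""
    zeros = sum(1 for w in weights if w == 0.0)
    return {"sparsity": zeros / len(weights), "unique_values": len(set(weights))}


# Hypothetical pruned weights and hypothetical cluster centroids.
pruned = [0.0, 0.81, 0.0, -0.42, 0.77, 0.0, -0.39, 0.0]
centroids = [-0.4, 0.8]

clustered = cluster_preserving_zeros(pruned, centroids)

print(summarize(pruned))
print(summarize(clustered))
```

In this toy case the sparsity is unchanged by clustering, while the number of distinct values drops to the centroids plus zero. Fewer distinct values and more zeros are what make the final model compress well.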

## Setup

You can run this Jupyter notebook in a local [virtualenv](https://www.tensorflow.org/install/pip?lang=python3#2.-create-a-virtual-environment-recommended) or in [Colab](https://colab.sandbox.google.com/). For details on setting up dependencies, see the [installation guide](https://www.tensorflow.org/model_optimization/guide/install).

In [2]:
! pip install -q tensorflow-model-optimization

In [3]:
import tensorflow as tf
import tf_keras as keras

import numpy as np
import tempfile
import zipfile
import os

## Train a keras model for MNIST to be pruned and clustered
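
The model trained in this section is small enough to size by hand, which gives a useful baseline for the later size comparisons. The sketch below is an illustration rather than part of the original notebook. It assumes the Keras defaults (stride 1 and `valid` padding for `Conv2D`, stride equal to the pool size for `MaxPooling2D`), and the helper names are hypothetical.

```python
def conv2d_output(height, width, kernel, filters):
    """Output shape of a stride-1, 'valid'-padded square convolution."""
    return height - kernel + 1, width - kernel + 1, filters


def max_pool_output(height, width, channels, pool):
    """Output shape of a 'valid' max pool whose stride equals the pool size."""
    return height // pool, width // pool, channels


# 28x28x1 input -> Conv2D(12 filters, 3x3) -> MaxPooling2D(2x2) -> Flatten
h, w, c = conv2d_output(28, 28, kernel=3, filters=12)
h, w, c = max_pool_output(h, w, c, pool=2)
flat_units = h * w * c

# Each conv filter has 3*3*1 weights plus one bias.
conv_params = (3 * 3 * 1 + 1) * 12
# Each of the 10 dense units has one weight per flattened input plus one bias.
dense_params = (flat_units + 1) * 10
total_params = conv_params + dense_params

print((h, w, c), flat_units, conv_params, dense_params, total_params)
```

Nearly all parameters sit in the final `Dense` layer, so that layer is where pruning and clustering have the most room to reduce the model size.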

In [4]:
# Load MNIST dataset
mnist = keras.datasets.mnist
(train_images, train_labels), (test_images, test_labels) = mnist.load_data()

# Normalize the input images so that each pixel value is between 0 and 1.
train_images = train_images / 255.0
test_images  = test_images / 255.0

model = keras.Sequential([
  keras.layers.InputLayer(input_shape=(28, 28)),
  keras.layers.Reshape(target_shape=(28, 28, 1)),
  keras.layers.Conv2D(filters=12, kernel_size=(3, 3),
                      activation=tf.nn.relu),
  keras.layers.MaxPooling2D(pool_size=(2, 2)),
  keras.layers.Flatten(),
  keras.layers.Dense(10)
])

opt = keras.optimizers.Adam(learning_rate=1e-3)

# Train the digit classification model
model.compile(optimizer=opt,
              loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
              metrics=['accuracy'])

model.fit(
    train_images,
    train_labels,
    validation_split=0.1,
    epochs=10
)
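
The training log for this cell counts 1688 steps per epoch. As a quick sanity check, the arithmetic below reproduces that number. It assumes the standard 60,000-image MNIST training set and the Keras default batch size of 32, which applies because `fit()` is called without a `batch_size` argument.

```python
import math

num_train_images = 60_000  # size of the MNIST training set
validation_split = 0.1     # fraction held out by fit() for validation
batch_size = 32            # Keras default when fit() gets no batch_size

# Images left for training after the validation split is held out.
num_fit_images = round(num_train_images * (1 - validation_split))

# The last, partially filled batch still counts as a step.
steps_per_epoch = math.ceil(num_fit_images / batch_size)

print(num_fit_images, steps_per_epoch)
```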

2024-03-09 12:49:28.954689: E external/local_xla/xla/stream_executor/cuda/cuda_driver.cc:282] failed call to cuInit: CUDA_ERROR_NO_DEVICE: no CUDA-capable device is detected


Epoch 1/10


   1/1688 [..............................] - ETA: 6:33:56 - loss: 2.2772 - accuracy: 0.1250

  12/1688 [..............................] - ETA: 8s - loss: 2.2091 - accuracy: 0.2240     

  24/1688 [..............................] - ETA: 7s - loss: 2.0947 - accuracy: 0.3372

  37/1688 [..............................] - ETA: 7s - loss: 1.9543 - accuracy: 0.4383

  50/1688 [..............................] - ETA: 6s - loss: 1.7983 - accuracy: 0.5206

  63/1688 [>.............................] - ETA: 6s - loss: 1.6339 - accuracy: 0.5714

  76/1688 [>.............................] - ETA: 6s - loss: 1.4822 - accuracy: 0.6151

  89/1688 [>.............................] - ETA: 6s - loss: 1.3483 - accuracy: 0.6503

 103/1688 [>.............................] - ETA: 6s - loss: 1.2429 - accuracy: 0.6784

 116/1688 [=>............................] - ETA: 6s - loss: 1.1603 - accuracy: 0.6983

 130/1688 [=>............................] - ETA: 6s - loss: 1.0889 - accuracy: 0.7149

 143/1688 [=>............................] - ETA: 6s - loss: 1.0325 - accuracy: 0.7297

 156/1688 [=>............................] - ETA: 6s - loss: 0.9867 - accuracy: 0.7406

 170/1688 [==>...........................] - ETA: 6s - loss: 0.9378 - accuracy: 0.7529

 184/1688 [==>...........................] - ETA: 5s - loss: 0.9014 - accuracy: 0.7621

 198/1688 [==>...........................] - ETA: 5s - loss: 0.8651 - accuracy: 0.7708

 212/1688 [==>...........................] - ETA: 5s - loss: 0.8324 - accuracy: 0.7784

 225/1688 [==>...........................] - ETA: 5s - loss: 0.8033 - accuracy: 0.7861

 239/1688 [===>..........................] - ETA: 5s - loss: 0.7761 - accuracy: 0.7920

 252/1688 [===>..........................] - ETA: 5s - loss: 0.7554 - accuracy: 0.7969

 266/1688 [===>..........................] - ETA: 5s - loss: 0.7343 - accuracy: 0.8020

 280/1688 [===>..........................] - ETA: 5s - loss: 0.7137 - accuracy: 0.8071

 294/1688 [====>.........................] - ETA: 5s - loss: 0.6950 - accuracy: 0.8115

 307/1688 [====>.........................] - ETA: 5s - loss: 0.6794 - accuracy: 0.8159

 320/1688 [====>.........................] - ETA: 5s - loss: 0.6671 - accuracy: 0.8192

 334/1688 [====>.........................] - ETA: 5s - loss: 0.6502 - accuracy: 0.8238

 347/1688 [=====>........................] - ETA: 5s - loss: 0.6395 - accuracy: 0.8259

 361/1688 [=====>........................] - ETA: 5s - loss: 0.6275 - accuracy: 0.8288

 374/1688 [=====>........................] - ETA: 5s - loss: 0.6185 - accuracy: 0.8306

 387/1688 [=====>........................] - ETA: 5s - loss: 0.6094 - accuracy: 0.8330

Epoch 2/10


   1/1688 [..............................] - ETA: 7s - loss: 0.4145 - accuracy: 0.8750

  15/1688 [..............................] - ETA: 6s - loss: 0.1730 - accuracy: 0.9438

  29/1688 [..............................] - ETA: 6s - loss: 0.1407 - accuracy: 0.9558

  43/1688 [..............................] - ETA: 6s - loss: 0.1617 - accuracy: 0.9513

  56/1688 [..............................] - ETA: 6s - loss: 0.1472 - accuracy: 0.9570

  70/1688 [>.............................] - ETA: 6s - loss: 0.1537 - accuracy: 0.9554

  84/1688 [>.............................] - ETA: 6s - loss: 0.1496 - accuracy: 0.9561

  98/1688 [>.............................] - ETA: 6s - loss: 0.1493 - accuracy: 0.9557

 112/1688 [>.............................] - ETA: 5s - loss: 0.1438 - accuracy: 0.9576

 126/1688 [=>............................] - ETA: 5s - loss: 0.1396 - accuracy: 0.9588

 140/1688 [=>............................] - ETA: 5s - loss: 0.1398 - accuracy: 0.9592

 154/1688 [=>............................] - ETA: 5s - loss: 0.1405 - accuracy: 0.9586

 168/1688 [=>............................] - ETA: 5s - loss: 0.1387 - accuracy: 0.9583

 181/1688 [==>...........................] - ETA: 5s - loss: 0.1388 - accuracy: 0.9586

 195/1688 [==>...........................] - ETA: 5s - loss: 0.1375 - accuracy: 0.9596

 209/1688 [==>...........................] - ETA: 5s - loss: 0.1345 - accuracy: 0.9608

 223/1688 [==>...........................] - ETA: 5s - loss: 0.1365 - accuracy: 0.9603

 237/1688 [===>..........................] - ETA: 5s - loss: 0.1359 - accuracy: 0.9604

 251/1688 [===>..........................] - ETA: 5s - loss: 0.1372 - accuracy: 0.9598

 265/1688 [===>..........................] - ETA: 5s - loss: 0.1353 - accuracy: 0.9606

 279/1688 [===>..........................] - ETA: 5s - loss: 0.1354 - accuracy: 0.9607

 293/1688 [====>.........................] - ETA: 5s - loss: 0.1339 - accuracy: 0.9612

 307/1688 [====>.........................] - ETA: 5s - loss: 0.1329 - accuracy: 0.9615

 321/1688 [====>.........................] - ETA: 5s - loss: 0.1333 - accuracy: 0.9614

 335/1688 [====>.........................] - ETA: 5s - loss: 0.1344 - accuracy: 0.9611

 349/1688 [=====>........................] - ETA: 5s - loss: 0.1350 - accuracy: 0.9607

 363/1688 [=====>........................] - ETA: 4s - loss: 0.1344 - accuracy: 0.9611

 377/1688 [=====>........................] - ETA: 4s - loss: 0.1343 - accuracy: 0.9610

 391/1688 [=====>........................] - ETA: 4s - loss: 0.1342 - accuracy: 0.9613

Epoch 3/10


   1/1688 [..............................] - ETA: 7s - loss: 0.0194 - accuracy: 1.0000

  15/1688 [..............................] - ETA: 6s - loss: 0.0647 - accuracy: 0.9812

  29/1688 [..............................] - ETA: 6s - loss: 0.0670 - accuracy: 0.9795

  43/1688 [..............................] - ETA: 6s - loss: 0.0743 - accuracy: 0.9760

  57/1688 [>.............................] - ETA: 6s - loss: 0.0753 - accuracy: 0.9759

  71/1688 [>.............................] - ETA: 6s - loss: 0.0725 - accuracy: 0.9762

  85/1688 [>.............................] - ETA: 6s - loss: 0.0700 - accuracy: 0.9776

  98/1688 [>.............................] - ETA: 6s - loss: 0.0732 - accuracy: 0.9767

 112/1688 [>.............................] - ETA: 5s - loss: 0.0732 - accuracy: 0.9766

 126/1688 [=>............................] - ETA: 5s - loss: 0.0748 - accuracy: 0.9774

 139/1688 [=>............................] - ETA: 5s - loss: 0.0783 - accuracy: 0.9766

 153/1688 [=>............................] - ETA: 5s - loss: 0.0797 - accuracy: 0.9763

 167/1688 [=>............................] - ETA: 5s - loss: 0.0815 - accuracy: 0.9759

 181/1688 [==>...........................] - ETA: 5s - loss: 0.0800 - accuracy: 0.9763

 195/1688 [==>...........................] - ETA: 5s - loss: 0.0792 - accuracy: 0.9766

 209/1688 [==>...........................] - ETA: 5s - loss: 0.0784 - accuracy: 0.9768

 223/1688 [==>...........................] - ETA: 5s - loss: 0.0802 - accuracy: 0.9760

 237/1688 [===>..........................] - ETA: 5s - loss: 0.0794 - accuracy: 0.9763

 251/1688 [===>..........................] - ETA: 5s - loss: 0.0810 - accuracy: 0.9753

 265/1688 [===>..........................] - ETA: 5s - loss: 0.0813 - accuracy: 0.9751

 279/1688 [===>..........................] - ETA: 5s - loss: 0.0818 - accuracy: 0.9752

 293/1688 [====>.........................] - ETA: 5s - loss: 0.0803 - accuracy: 0.9758

 307/1688 [====>.........................] - ETA: 5s - loss: 0.0800 - accuracy: 0.9759

 321/1688 [====>.........................] - ETA: 5s - loss: 0.0797 - accuracy: 0.9758

 335/1688 [====>.........................] - ETA: 5s - loss: 0.0802 - accuracy: 0.9759

 349/1688 [=====>........................] - ETA: 5s - loss: 0.0800 - accuracy: 0.9761

 363/1688 [=====>........................] - ETA: 5s - loss: 0.0797 - accuracy: 0.9762

 377/1688 [=====>........................] - ETA: 4s - loss: 0.0798 - accuracy: 0.9759

 391/1688 [=====>........................] - ETA: 4s - loss: 0.0795 - accuracy: 0.9760

Epoch 4/10


   1/1688 [..............................] - ETA: 7s - loss: 0.1059 - accuracy: 0.9375

  15/1688 [..............................] - ETA: 6s - loss: 0.0819 - accuracy: 0.9750

  28/1688 [..............................] - ETA: 6s - loss: 0.0642 - accuracy: 0.9799

  42/1688 [..............................] - ETA: 6s - loss: 0.0643 - accuracy: 0.9821

  56/1688 [..............................] - ETA: 6s - loss: 0.0616 - accuracy: 0.9821

  70/1688 [>.............................] - ETA: 6s - loss: 0.0604 - accuracy: 0.9830

  84/1688 [>.............................] - ETA: 6s - loss: 0.0578 - accuracy: 0.9840

  98/1688 [>.............................] - ETA: 6s - loss: 0.0596 - accuracy: 0.9828

 112/1688 [>.............................] - ETA: 6s - loss: 0.0608 - accuracy: 0.9827

 126/1688 [=>............................] - ETA: 5s - loss: 0.0611 - accuracy: 0.9829

 140/1688 [=>............................] - ETA: 5s - loss: 0.0645 - accuracy: 0.9824

 154/1688 [=>............................] - ETA: 5s - loss: 0.0643 - accuracy: 0.9823

 168/1688 [=>............................] - ETA: 5s - loss: 0.0662 - accuracy: 0.9825

 182/1688 [==>...........................] - ETA: 5s - loss: 0.0662 - accuracy: 0.9825

 196/1688 [==>...........................] - ETA: 5s - loss: 0.0674 - accuracy: 0.9815

 210/1688 [==>...........................] - ETA: 5s - loss: 0.0670 - accuracy: 0.9812

 224/1688 [==>...........................] - ETA: 5s - loss: 0.0675 - accuracy: 0.9809

 238/1688 [===>..........................] - ETA: 5s - loss: 0.0679 - accuracy: 0.9808

 252/1688 [===>..........................] - ETA: 5s - loss: 0.0684 - accuracy: 0.9808

 266/1688 [===>..........................] - ETA: 5s - loss: 0.0670 - accuracy: 0.9812

 280/1688 [===>..........................] - ETA: 5s - loss: 0.0672 - accuracy: 0.9814

 293/1688 [====>.........................] - ETA: 5s - loss: 0.0668 - accuracy: 0.9811

 306/1688 [====>.........................] - ETA: 5s - loss: 0.0666 - accuracy: 0.9810

 320/1688 [====>.........................] - ETA: 5s - loss: 0.0660 - accuracy: 0.9811

 334/1688 [====>.........................] - ETA: 5s - loss: 0.0653 - accuracy: 0.9812

 348/1688 [=====>........................] - ETA: 5s - loss: 0.0641 - accuracy: 0.9816

 361/1688 [=====>........................] - ETA: 5s - loss: 0.0644 - accuracy: 0.9816

 375/1688 [=====>........................] - ETA: 4s - loss: 0.0640 - accuracy: 0.9818

 389/1688 [=====>........................] - ETA: 4s - loss: 0.0650 - accuracy: 0.9815

Epoch 5/10


   1/1688 [..............................] - ETA: 7s - loss: 0.0669 - accuracy: 0.9688

  15/1688 [..............................] - ETA: 6s - loss: 0.0481 - accuracy: 0.9854

  29/1688 [..............................] - ETA: 6s - loss: 0.0534 - accuracy: 0.9838

  43/1688 [..............................] - ETA: 6s - loss: 0.0649 - accuracy: 0.9818

  57/1688 [>.............................] - ETA: 6s - loss: 0.0586 - accuracy: 0.9841

  71/1688 [>.............................] - ETA: 6s - loss: 0.0597 - accuracy: 0.9820

  85/1688 [>.............................] - ETA: 6s - loss: 0.0607 - accuracy: 0.9820

  99/1688 [>.............................] - ETA: 6s - loss: 0.0597 - accuracy: 0.9814

 113/1688 [=>............................] - ETA: 5s - loss: 0.0665 - accuracy: 0.9801

 127/1688 [=>............................] - ETA: 5s - loss: 0.0667 - accuracy: 0.9798

 141/1688 [=>............................] - ETA: 5s - loss: 0.0642 - accuracy: 0.9812

 155/1688 [=>............................] - ETA: 5s - loss: 0.0639 - accuracy: 0.9804

 169/1688 [==>...........................] - ETA: 5s - loss: 0.0629 - accuracy: 0.9811

 183/1688 [==>...........................] - ETA: 5s - loss: 0.0638 - accuracy: 0.9812

 197/1688 [==>...........................] - ETA: 5s - loss: 0.0634 - accuracy: 0.9811

 211/1688 [==>...........................] - ETA: 5s - loss: 0.0617 - accuracy: 0.9818

 225/1688 [==>...........................] - ETA: 5s - loss: 0.0608 - accuracy: 0.9822

 239/1688 [===>..........................] - ETA: 5s - loss: 0.0595 - accuracy: 0.9825

 253/1688 [===>..........................] - ETA: 5s - loss: 0.0596 - accuracy: 0.9825

 267/1688 [===>..........................] - ETA: 5s - loss: 0.0597 - accuracy: 0.9822

 280/1688 [===>..........................] - ETA: 5s - loss: 0.0592 - accuracy: 0.9825

 294/1688 [====>.........................] - ETA: 5s - loss: 0.0587 - accuracy: 0.9826

 308/1688 [====>.........................] - ETA: 5s - loss: 0.0582 - accuracy: 0.9829

 322/1688 [====>.........................] - ETA: 5s - loss: 0.0583 - accuracy: 0.9827

 336/1688 [====>.........................] - ETA: 5s - loss: 0.0583 - accuracy: 0.9829

 350/1688 [=====>........................] - ETA: 5s - loss: 0.0574 - accuracy: 0.9835

 364/1688 [=====>........................] - ETA: 5s - loss: 0.0568 - accuracy: 0.9835

 378/1688 [=====>........................] - ETA: 4s - loss: 0.0572 - accuracy: 0.9833

 392/1688 [=====>........................] - ETA: 4s - loss: 0.0563 - accuracy: 0.9837

Epoch 6/10


   1/1688 [..............................] - ETA: 7s - loss: 0.0145 - accuracy: 1.0000

  15/1688 [..............................] - ETA: 6s - loss: 0.0451 - accuracy: 0.9854

  28/1688 [..............................] - ETA: 6s - loss: 0.0451 - accuracy: 0.9821

  41/1688 [..............................] - ETA: 6s - loss: 0.0456 - accuracy: 0.9832

  54/1688 [..............................] - ETA: 6s - loss: 0.0455 - accuracy: 0.9838

  67/1688 [>.............................] - ETA: 6s - loss: 0.0425 - accuracy: 0.9851

  80/1688 [>.............................] - ETA: 6s - loss: 0.0445 - accuracy: 0.9852

  93/1688 [>.............................] - ETA: 6s - loss: 0.0443 - accuracy: 0.9862

 106/1688 [>.............................] - ETA: 6s - loss: 0.0460 - accuracy: 0.9861

 119/1688 [=>............................] - ETA: 6s - loss: 0.0449 - accuracy: 0.9863

 132/1688 [=>............................] - ETA: 6s - loss: 0.0496 - accuracy: 0.9860

 146/1688 [=>............................] - ETA: 5s - loss: 0.0487 - accuracy: 0.9865

 160/1688 [=>............................] - ETA: 5s - loss: 0.0504 - accuracy: 0.9859

 173/1688 [==>...........................] - ETA: 5s - loss: 0.0501 - accuracy: 0.9855

 187/1688 [==>...........................] - ETA: 5s - loss: 0.0503 - accuracy: 0.9853

 201/1688 [==>...........................] - ETA: 5s - loss: 0.0494 - accuracy: 0.9859

 215/1688 [==>...........................] - ETA: 5s - loss: 0.0497 - accuracy: 0.9856

 228/1688 [===>..........................] - ETA: 5s - loss: 0.0496 - accuracy: 0.9856

 241/1688 [===>..........................] - ETA: 5s - loss: 0.0492 - accuracy: 0.9859

 254/1688 [===>..........................] - ETA: 5s - loss: 0.0481 - accuracy: 0.9862

 267/1688 [===>..........................] - ETA: 5s - loss: 0.0484 - accuracy: 0.9858

 280/1688 [===>..........................] - ETA: 5s - loss: 0.0501 - accuracy: 0.9855

 293/1688 [====>.........................] - ETA: 5s - loss: 0.0506 - accuracy: 0.9851

 306/1688 [====>.........................] - ETA: 5s - loss: 0.0502 - accuracy: 0.9855

 319/1688 [====>.........................] - ETA: 5s - loss: 0.0496 - accuracy: 0.9857

 332/1688 [====>.........................] - ETA: 5s - loss: 0.0488 - accuracy: 0.9859

 345/1688 [=====>........................] - ETA: 5s - loss: 0.0492 - accuracy: 0.9859

 359/1688 [=====>........................] - ETA: 5s - loss: 0.0490 - accuracy: 0.9858

 373/1688 [=====>........................] - ETA: 5s - loss: 0.0492 - accuracy: 0.9858

 387/1688 [=====>........................] - ETA: 5s - loss: 0.0488 - accuracy: 0.9859





























































































































































































Epoch 7/10


   1/1688 [..............................] - ETA: 7s - loss: 0.0511 - accuracy: 0.9688

  15/1688 [..............................] - ETA: 6s - loss: 0.0659 - accuracy: 0.9750

  29/1688 [..............................] - ETA: 6s - loss: 0.0646 - accuracy: 0.9774

  43/1688 [..............................] - ETA: 6s - loss: 0.0554 - accuracy: 0.9811

  57/1688 [>.............................] - ETA: 6s - loss: 0.0496 - accuracy: 0.9830

  71/1688 [>.............................] - ETA: 6s - loss: 0.0468 - accuracy: 0.9842

  85/1688 [>.............................] - ETA: 6s - loss: 0.0447 - accuracy: 0.9849

  99/1688 [>.............................] - ETA: 5s - loss: 0.0425 - accuracy: 0.9864

 113/1688 [=>............................] - ETA: 5s - loss: 0.0413 - accuracy: 0.9870

 127/1688 [=>............................] - ETA: 5s - loss: 0.0410 - accuracy: 0.9870

 141/1688 [=>............................] - ETA: 5s - loss: 0.0427 - accuracy: 0.9867

 154/1688 [=>............................] - ETA: 5s - loss: 0.0413 - accuracy: 0.9866

 168/1688 [=>............................] - ETA: 5s - loss: 0.0408 - accuracy: 0.9868

 182/1688 [==>...........................] - ETA: 5s - loss: 0.0402 - accuracy: 0.9873

 196/1688 [==>...........................] - ETA: 5s - loss: 0.0411 - accuracy: 0.9872

 210/1688 [==>...........................] - ETA: 5s - loss: 0.0411 - accuracy: 0.9874

 224/1688 [==>...........................] - ETA: 5s - loss: 0.0417 - accuracy: 0.9876

 238/1688 [===>..........................] - ETA: 5s - loss: 0.0419 - accuracy: 0.9875

 252/1688 [===>..........................] - ETA: 5s - loss: 0.0427 - accuracy: 0.9874

 266/1688 [===>..........................] - ETA: 5s - loss: 0.0425 - accuracy: 0.9875

 280/1688 [===>..........................] - ETA: 5s - loss: 0.0415 - accuracy: 0.9878

 294/1688 [====>.........................] - ETA: 5s - loss: 0.0417 - accuracy: 0.9880

 308/1688 [====>.........................] - ETA: 5s - loss: 0.0416 - accuracy: 0.9878

 322/1688 [====>.........................] - ETA: 5s - loss: 0.0414 - accuracy: 0.9875

 336/1688 [====>.........................] - ETA: 5s - loss: 0.0413 - accuracy: 0.9875

 350/1688 [=====>........................] - ETA: 5s - loss: 0.0417 - accuracy: 0.9876

 364/1688 [=====>........................] - ETA: 4s - loss: 0.0411 - accuracy: 0.9878

 378/1688 [=====>........................] - ETA: 4s - loss: 0.0411 - accuracy: 0.9877

 392/1688 [=====>........................] - ETA: 4s - loss: 0.0408 - accuracy: 0.9879



























































































































































































Epoch 8/10


   1/1688 [..............................] - ETA: 7s - loss: 0.0117 - accuracy: 1.0000

  15/1688 [..............................] - ETA: 6s - loss: 0.0427 - accuracy: 0.9875

  29/1688 [..............................] - ETA: 6s - loss: 0.0361 - accuracy: 0.9925

  43/1688 [..............................] - ETA: 6s - loss: 0.0398 - accuracy: 0.9913

  57/1688 [>.............................] - ETA: 6s - loss: 0.0438 - accuracy: 0.9896

  71/1688 [>.............................] - ETA: 6s - loss: 0.0398 - accuracy: 0.9894

  85/1688 [>.............................] - ETA: 6s - loss: 0.0418 - accuracy: 0.9890

  99/1688 [>.............................] - ETA: 5s - loss: 0.0464 - accuracy: 0.9883

 113/1688 [=>............................] - ETA: 5s - loss: 0.0474 - accuracy: 0.9873

 127/1688 [=>............................] - ETA: 5s - loss: 0.0458 - accuracy: 0.9877

 141/1688 [=>............................] - ETA: 5s - loss: 0.0469 - accuracy: 0.9874

 155/1688 [=>............................] - ETA: 5s - loss: 0.0454 - accuracy: 0.9875

 169/1688 [==>...........................] - ETA: 5s - loss: 0.0442 - accuracy: 0.9878

 183/1688 [==>...........................] - ETA: 5s - loss: 0.0425 - accuracy: 0.9882

 197/1688 [==>...........................] - ETA: 5s - loss: 0.0431 - accuracy: 0.9879

 211/1688 [==>...........................] - ETA: 5s - loss: 0.0422 - accuracy: 0.9882

 225/1688 [==>...........................] - ETA: 5s - loss: 0.0433 - accuracy: 0.9876

 238/1688 [===>..........................] - ETA: 5s - loss: 0.0426 - accuracy: 0.9875

 252/1688 [===>..........................] - ETA: 5s - loss: 0.0423 - accuracy: 0.9875

 266/1688 [===>..........................] - ETA: 5s - loss: 0.0411 - accuracy: 0.9879

 280/1688 [===>..........................] - ETA: 5s - loss: 0.0402 - accuracy: 0.9882

 294/1688 [====>.........................] - ETA: 5s - loss: 0.0397 - accuracy: 0.9885

 308/1688 [====>.........................] - ETA: 5s - loss: 0.0388 - accuracy: 0.9888

 322/1688 [====>.........................] - ETA: 5s - loss: 0.0390 - accuracy: 0.9889

 336/1688 [====>.........................] - ETA: 5s - loss: 0.0385 - accuracy: 0.9892

 350/1688 [=====>........................] - ETA: 5s - loss: 0.0386 - accuracy: 0.9891

 364/1688 [=====>........................] - ETA: 4s - loss: 0.0392 - accuracy: 0.9890

 378/1688 [=====>........................] - ETA: 4s - loss: 0.0392 - accuracy: 0.9893

 392/1688 [=====>........................] - ETA: 4s - loss: 0.0388 - accuracy: 0.9892





























































































































































































Epoch 9/10


   1/1688 [..............................] - ETA: 7s - loss: 0.0240 - accuracy: 1.0000

  15/1688 [..............................] - ETA: 6s - loss: 0.0267 - accuracy: 0.9958

  29/1688 [..............................] - ETA: 6s - loss: 0.0315 - accuracy: 0.9892

  43/1688 [..............................] - ETA: 6s - loss: 0.0393 - accuracy: 0.9876

  57/1688 [>.............................] - ETA: 6s - loss: 0.0416 - accuracy: 0.9868

  71/1688 [>.............................] - ETA: 6s - loss: 0.0430 - accuracy: 0.9868

  85/1688 [>.............................] - ETA: 6s - loss: 0.0425 - accuracy: 0.9871

  99/1688 [>.............................] - ETA: 6s - loss: 0.0395 - accuracy: 0.9880

 113/1688 [=>............................] - ETA: 5s - loss: 0.0373 - accuracy: 0.9881

 127/1688 [=>............................] - ETA: 5s - loss: 0.0352 - accuracy: 0.9889

 141/1688 [=>............................] - ETA: 5s - loss: 0.0327 - accuracy: 0.9900

 155/1688 [=>............................] - ETA: 5s - loss: 0.0321 - accuracy: 0.9903

 169/1688 [==>...........................] - ETA: 5s - loss: 0.0324 - accuracy: 0.9902

 183/1688 [==>...........................] - ETA: 5s - loss: 0.0320 - accuracy: 0.9899

 197/1688 [==>...........................] - ETA: 5s - loss: 0.0312 - accuracy: 0.9902

 211/1688 [==>...........................] - ETA: 5s - loss: 0.0314 - accuracy: 0.9901

 225/1688 [==>...........................] - ETA: 5s - loss: 0.0321 - accuracy: 0.9899

 239/1688 [===>..........................] - ETA: 5s - loss: 0.0323 - accuracy: 0.9897

 253/1688 [===>..........................] - ETA: 5s - loss: 0.0328 - accuracy: 0.9895

 267/1688 [===>..........................] - ETA: 5s - loss: 0.0321 - accuracy: 0.9897

 281/1688 [===>..........................] - ETA: 5s - loss: 0.0316 - accuracy: 0.9900

 295/1688 [====>.........................] - ETA: 5s - loss: 0.0322 - accuracy: 0.9898

 309/1688 [====>.........................] - ETA: 5s - loss: 0.0328 - accuracy: 0.9899

 323/1688 [====>.........................] - ETA: 5s - loss: 0.0321 - accuracy: 0.9900

 337/1688 [====>.........................] - ETA: 5s - loss: 0.0322 - accuracy: 0.9899

 351/1688 [=====>........................] - ETA: 5s - loss: 0.0321 - accuracy: 0.9898

 365/1688 [=====>........................] - ETA: 5s - loss: 0.0321 - accuracy: 0.9897

 379/1688 [=====>........................] - ETA: 4s - loss: 0.0318 - accuracy: 0.9899

 393/1688 [=====>........................] - ETA: 4s - loss: 0.0322 - accuracy: 0.9899



























































































































































































Epoch 10/10


   1/1688 [..............................] - ETA: 7s - loss: 0.0178 - accuracy: 1.0000

  15/1688 [..............................] - ETA: 6s - loss: 0.0285 - accuracy: 0.9937

  29/1688 [..............................] - ETA: 6s - loss: 0.0259 - accuracy: 0.9946

  43/1688 [..............................] - ETA: 6s - loss: 0.0277 - accuracy: 0.9942

  57/1688 [>.............................] - ETA: 6s - loss: 0.0266 - accuracy: 0.9945

  71/1688 [>.............................] - ETA: 6s - loss: 0.0283 - accuracy: 0.9938

  85/1688 [>.............................] - ETA: 6s - loss: 0.0278 - accuracy: 0.9930

  99/1688 [>.............................] - ETA: 5s - loss: 0.0295 - accuracy: 0.9915

 113/1688 [=>............................] - ETA: 5s - loss: 0.0304 - accuracy: 0.9917

 127/1688 [=>............................] - ETA: 5s - loss: 0.0301 - accuracy: 0.9919

 141/1688 [=>............................] - ETA: 5s - loss: 0.0292 - accuracy: 0.9922

 155/1688 [=>............................] - ETA: 5s - loss: 0.0294 - accuracy: 0.9919

 169/1688 [==>...........................] - ETA: 5s - loss: 0.0304 - accuracy: 0.9917

 183/1688 [==>...........................] - ETA: 5s - loss: 0.0313 - accuracy: 0.9915

 197/1688 [==>...........................] - ETA: 5s - loss: 0.0321 - accuracy: 0.9914

 211/1688 [==>...........................] - ETA: 5s - loss: 0.0316 - accuracy: 0.9914

 225/1688 [==>...........................] - ETA: 5s - loss: 0.0310 - accuracy: 0.9917

 239/1688 [===>..........................] - ETA: 5s - loss: 0.0322 - accuracy: 0.9911

 253/1688 [===>..........................] - ETA: 5s - loss: 0.0351 - accuracy: 0.9901

 267/1688 [===>..........................] - ETA: 5s - loss: 0.0351 - accuracy: 0.9901

 281/1688 [===>..........................] - ETA: 5s - loss: 0.0344 - accuracy: 0.9902

 295/1688 [====>.........................] - ETA: 5s - loss: 0.0340 - accuracy: 0.9900

 309/1688 [====>.........................] - ETA: 5s - loss: 0.0340 - accuracy: 0.9900

 323/1688 [====>.........................] - ETA: 5s - loss: 0.0333 - accuracy: 0.9900

 337/1688 [====>.........................] - ETA: 5s - loss: 0.0334 - accuracy: 0.9900

 351/1688 [=====>........................] - ETA: 5s - loss: 0.0326 - accuracy: 0.9902

 365/1688 [=====>........................] - ETA: 4s - loss: 0.0321 - accuracy: 0.9904

 379/1688 [=====>........................] - ETA: 4s - loss: 0.0317 - accuracy: 0.9907

 393/1688 [=====>........................] - ETA: 4s - loss: 0.0319 - accuracy: 0.9905



























































































































































































<tf_keras.src.callbacks.History at 0x7f615076beb0>

### Evaluate the baseline model and save it for later usage

In [5]:
_, baseline_model_accuracy = model.evaluate(
    test_images, test_labels, verbose=0)

print('Baseline test accuracy:', baseline_model_accuracy)

_, keras_file = tempfile.mkstemp('.h5')
print('Saving model to: ', keras_file)
keras.models.save_model(model, keras_file, include_optimizer=False)

Baseline test accuracy: 0.9835000038146973
Saving model to:  /tmpfs/tmp/tmpf70eijr3.h5

## Prune and fine-tune the model to 50% sparsity

Apply the `prune_low_magnitude()` API to obtain the pruned model that will be clustered in the next step. Refer to the [pruning comprehensive guide](https://www.tensorflow.org/model_optimization/guide/pruning/comprehensive_guide) for more information on the pruning API.

### Define the model and apply the sparsity API

Note that pruning is applied to the pre-trained baseline model, not to a newly initialized one.

In [6]:
import tensorflow_model_optimization as tfmot

prune_low_magnitude = tfmot.sparsity.keras.prune_low_magnitude

# Hold the target sparsity at 50% from step 0 onward, applying pruning every 100 steps.
pruning_params = {
    'pruning_schedule': tfmot.sparsity.keras.ConstantSparsity(
        0.5, begin_step=0, frequency=100)
}

# UpdatePruningStep is required during training so that the pruning step advances.
callbacks = [
  tfmot.sparsity.keras.UpdatePruningStep()
]

# Wrap the pre-trained baseline model for pruning.
pruned_model = prune_low_magnitude(model, **pruning_params)

# Use a smaller learning rate for fine-tuning.
opt = keras.optimizers.Adam(learning_rate=1e-5)

pruned_model.compile(
  loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
  optimizer=opt,
  metrics=['accuracy'])
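
To illustrate what a constant 50% sparsity target means for a single weight tensor, here is a minimal NumPy sketch of magnitude-based pruning. It is not the toolkit's implementation: `magnitude_prune_mask` is a hypothetical helper, the weight matrix is random stand-in data, and details such as tie handling and how often the mask is recomputed are simplified.

```python
import numpy as np


def magnitude_prune_mask(weights, target_sparsity):
    """Hypothetical helper: 0/1 mask that zeroes the smallest-magnitude weights."""
    flat = np.abs(weights).ravel()
    num_to_prune = int(round(target_sparsity * flat.size))
    mask = np.ones(flat.size, dtype=weights.dtype)
    if num_to_prune > 0:
        # Indices of the num_to_prune smallest magnitudes (order within the group is irrelevant).
        prune_idx = np.argpartition(flat, num_to_prune - 1)[:num_to_prune]
        mask[prune_idx] = 0
    return mask.reshape(weights.shape)


rng = np.random.default_rng(seed=0)
# Assumption: a small random matrix stands in for a layer kernel.
demo_weights = rng.normal(size=(4, 6)).astype(np.float32)

demo_mask = magnitude_prune_mask(demo_weights, target_sparsity=0.5)
demo_pruned = demo_weights * demo_mask

demo_sparsity = float(np.mean(demo_pruned == 0))
print('Sparsity after masking:', demo_sparsity)
```

In the sketch the mask is computed once. During fine-tuning with the real API, the schedule decides when pruning is applied, while the remaining weights keep training.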

### Fine-tune the model, check sparsity, and evaluate the accuracy against baseline

Fine-tune the model with pruning for 3 epochs.
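
To relate the three epochs to the pruning schedule, the arithmetic below estimates how many optimizer steps the run takes. It assumes the 60,000-image MNIST training set and the Keras default batch size of 32, neither of which is stated in this section; under those assumptions the result matches the 1688 steps per epoch shown in the training log. The last value only counts the multiples of `frequency` that fall inside the run, as a rough indication of how often pruning can be applied; it is not a trace of the library's internals.

```python
import math

num_train_examples = 60_000  # assumption: MNIST training set size
validation_split = 0.1       # from the fit() call
batch_size = 32              # assumption: Keras default when batch_size is not passed
epochs = 3                   # from the fit() call
frequency = 100              # from the ConstantSparsity schedule

# Examples left for training after the validation split is held out.
num_fit_examples = math.floor(num_train_examples * (1.0 - validation_split))
steps_per_epoch = math.ceil(num_fit_examples / batch_size)
total_steps = steps_per_epoch * epochs

# Steps 0, 100, 200, ... that fall inside the run.
schedule_checkpoints = len(range(0, total_steps, frequency))

print('Examples used for training:', num_fit_examples)
print('Steps per epoch:', steps_per_epoch)
print('Total fine-tuning steps:', total_steps)
print('Multiples of the pruning frequency:', schedule_checkpoints)
```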

In [7]:
# Fine-tune model
pruned_model.fit(
  train_images,
  train_labels,
  epochs=3,
  validation_split=0.1,
  callbacks=callbacks)

Epoch 1/3


   1/1688 [..............................] - ETA: 1:11:51 - loss: 0.0019 - accuracy: 1.0000

  13/1688 [..............................] - ETA: 7s - loss: 0.0116 - accuracy: 0.9976     

  26/1688 [..............................] - ETA: 7s - loss: 0.0223 - accuracy: 0.9928

  38/1688 [..............................] - ETA: 6s - loss: 0.0199 - accuracy: 0.9942

  50/1688 [..............................] - ETA: 6s - loss: 0.0225 - accuracy: 0.9931

  63/1688 [>.............................] - ETA: 6s - loss: 0.0243 - accuracy: 0.9926

  76/1688 [>.............................] - ETA: 6s - loss: 0.0240 - accuracy: 0.9926

  89/1688 [>.............................] - ETA: 6s - loss: 0.0257 - accuracy: 0.9926

 101/1688 [>.............................] - ETA: 6s - loss: 0.0271 - accuracy: 0.9920

 114/1688 [=>............................] - ETA: 6s - loss: 0.0380 - accuracy: 0.9882

 127/1688 [=>............................] - ETA: 6s - loss: 0.0498 - accuracy: 0.9835

 140/1688 [=>............................] - ETA: 6s - loss: 0.0516 - accuracy: 0.9837

 153/1688 [=>............................] - ETA: 6s - loss: 0.0544 - accuracy: 0.9820

 166/1688 [=>............................] - ETA: 6s - loss: 0.0580 - accuracy: 0.9804

 179/1688 [==>...........................] - ETA: 6s - loss: 0.0628 - accuracy: 0.9787

 192/1688 [==>...........................] - ETA: 6s - loss: 0.0641 - accuracy: 0.9784

 205/1688 [==>...........................] - ETA: 6s - loss: 0.0678 - accuracy: 0.9779

 218/1688 [==>...........................] - ETA: 6s - loss: 0.0705 - accuracy: 0.9768

 231/1688 [===>..........................] - ETA: 5s - loss: 0.0749 - accuracy: 0.9752

 244/1688 [===>..........................] - ETA: 5s - loss: 0.0757 - accuracy: 0.9750

 257/1688 [===>..........................] - ETA: 5s - loss: 0.0774 - accuracy: 0.9742

 270/1688 [===>..........................] - ETA: 5s - loss: 0.0790 - accuracy: 0.9734

 283/1688 [====>.........................] - ETA: 5s - loss: 0.0792 - accuracy: 0.9733

 296/1688 [====>.........................] - ETA: 5s - loss: 0.0800 - accuracy: 0.9731

 309/1688 [====>.........................] - ETA: 5s - loss: 0.0803 - accuracy: 0.9725

 322/1688 [====>.........................] - ETA: 5s - loss: 0.0801 - accuracy: 0.9726

 335/1688 [====>.........................] - ETA: 5s - loss: 0.0820 - accuracy: 0.9720

 348/1688 [=====>........................] - ETA: 5s - loss: 0.0821 - accuracy: 0.9717

 361/1688 [=====>........................] - ETA: 5s - loss: 0.0832 - accuracy: 0.9713

 374/1688 [=====>........................] - ETA: 5s - loss: 0.0836 - accuracy: 0.9710

 388/1688 [=====>........................] - ETA: 5s - loss: 0.0869 - accuracy: 0.9703











































































































































































































Epoch 2/3


   1/1688 [..............................] - ETA: 8s - loss: 0.0325 - accuracy: 1.0000

  14/1688 [..............................] - ETA: 6s - loss: 0.0539 - accuracy: 0.9844

  27/1688 [..............................] - ETA: 6s - loss: 0.0633 - accuracy: 0.9792

  40/1688 [..............................] - ETA: 6s - loss: 0.0755 - accuracy: 0.9750

  53/1688 [..............................] - ETA: 6s - loss: 0.0737 - accuracy: 0.9770

  66/1688 [>.............................] - ETA: 6s - loss: 0.0682 - accuracy: 0.9792

  79/1688 [>.............................] - ETA: 6s - loss: 0.0684 - accuracy: 0.9790

  92/1688 [>.............................] - ETA: 6s - loss: 0.0670 - accuracy: 0.9783

 105/1688 [>.............................] - ETA: 6s - loss: 0.0672 - accuracy: 0.9777

 118/1688 [=>............................] - ETA: 6s - loss: 0.0677 - accuracy: 0.9783

 131/1688 [=>............................] - ETA: 6s - loss: 0.0668 - accuracy: 0.9783

 144/1688 [=>............................] - ETA: 6s - loss: 0.0683 - accuracy: 0.9779

 157/1688 [=>............................] - ETA: 6s - loss: 0.0697 - accuracy: 0.9775

 170/1688 [==>...........................] - ETA: 5s - loss: 0.0695 - accuracy: 0.9774

 183/1688 [==>...........................] - ETA: 5s - loss: 0.0681 - accuracy: 0.9776

 196/1688 [==>...........................] - ETA: 5s - loss: 0.0674 - accuracy: 0.9780

 209/1688 [==>...........................] - ETA: 5s - loss: 0.0659 - accuracy: 0.9783

 222/1688 [==>...........................] - ETA: 5s - loss: 0.0651 - accuracy: 0.9785

 235/1688 [===>..........................] - ETA: 5s - loss: 0.0654 - accuracy: 0.9781

 248/1688 [===>..........................] - ETA: 5s - loss: 0.0657 - accuracy: 0.9777

 261/1688 [===>..........................] - ETA: 5s - loss: 0.0663 - accuracy: 0.9771

 274/1688 [===>..........................] - ETA: 5s - loss: 0.0671 - accuracy: 0.9767

 287/1688 [====>.........................] - ETA: 5s - loss: 0.0668 - accuracy: 0.9769

 300/1688 [====>.........................] - ETA: 5s - loss: 0.0660 - accuracy: 0.9771

 313/1688 [====>.........................] - ETA: 5s - loss: 0.0650 - accuracy: 0.9775

 326/1688 [====>.........................] - ETA: 5s - loss: 0.0647 - accuracy: 0.9775

 339/1688 [=====>........................] - ETA: 5s - loss: 0.0646 - accuracy: 0.9771

 352/1688 [=====>........................] - ETA: 5s - loss: 0.0643 - accuracy: 0.9773

 365/1688 [=====>........................] - ETA: 5s - loss: 0.0642 - accuracy: 0.9772

 378/1688 [=====>........................] - ETA: 5s - loss: 0.0633 - accuracy: 0.9777

 391/1688 [=====>........................] - ETA: 5s - loss: 0.0629 - accuracy: 0.9779









































































































































































































Epoch 3/3


   1/1688 [..............................] - ETA: 8s - loss: 0.0855 - accuracy: 0.9688

  14/1688 [..............................] - ETA: 6s - loss: 0.0586 - accuracy: 0.9777

  27/1688 [..............................] - ETA: 6s - loss: 0.0614 - accuracy: 0.9769

  40/1688 [..............................] - ETA: 6s - loss: 0.0525 - accuracy: 0.9805

  53/1688 [..............................] - ETA: 6s - loss: 0.0495 - accuracy: 0.9829

  66/1688 [>.............................] - ETA: 6s - loss: 0.0512 - accuracy: 0.9815

  79/1688 [>.............................] - ETA: 6s - loss: 0.0541 - accuracy: 0.9806

  92/1688 [>.............................] - ETA: 6s - loss: 0.0553 - accuracy: 0.9803

 105/1688 [>.............................] - ETA: 6s - loss: 0.0571 - accuracy: 0.9801

 118/1688 [=>............................] - ETA: 6s - loss: 0.0545 - accuracy: 0.9812

 131/1688 [=>............................] - ETA: 6s - loss: 0.0538 - accuracy: 0.9816

 144/1688 [=>............................] - ETA: 6s - loss: 0.0540 - accuracy: 0.9816

 157/1688 [=>............................] - ETA: 6s - loss: 0.0533 - accuracy: 0.9821

 170/1688 [==>...........................] - ETA: 5s - loss: 0.0537 - accuracy: 0.9820

 183/1688 [==>...........................] - ETA: 5s - loss: 0.0535 - accuracy: 0.9819

 196/1688 [==>...........................] - ETA: 5s - loss: 0.0531 - accuracy: 0.9825

 209/1688 [==>...........................] - ETA: 5s - loss: 0.0524 - accuracy: 0.9828

 222/1688 [==>...........................] - ETA: 5s - loss: 0.0515 - accuracy: 0.9830

 235/1688 [===>..........................] - ETA: 5s - loss: 0.0509 - accuracy: 0.9834

 248/1688 [===>..........................] - ETA: 5s - loss: 0.0504 - accuracy: 0.9836

 261/1688 [===>..........................] - ETA: 5s - loss: 0.0502 - accuracy: 0.9836

 274/1688 [===>..........................] - ETA: 5s - loss: 0.0501 - accuracy: 0.9835

 287/1688 [====>.........................] - ETA: 5s - loss: 0.0499 - accuracy: 0.9836

 300/1688 [====>.........................] - ETA: 5s - loss: 0.0503 - accuracy: 0.9835

 313/1688 [====>.........................] - ETA: 5s - loss: 0.0509 - accuracy: 0.9835

 326/1688 [====>.........................] - ETA: 5s - loss: 0.0511 - accuracy: 0.9835

 339/1688 [=====>........................] - ETA: 5s - loss: 0.0512 - accuracy: 0.9835

 352/1688 [=====>........................] - ETA: 5s - loss: 0.0511 - accuracy: 0.9836

 365/1688 [=====>........................] - ETA: 5s - loss: 0.0522 - accuracy: 0.9832

 378/1688 [=====>........................] - ETA: 5s - loss: 0.0517 - accuracy: 0.9834

 391/1688 [=====>........................] - ETA: 5s - loss: 0.0517 - accuracy: 0.9836









































































































































































































<tf_keras.src.callbacks.History at 0x7f60c8593ee0>

Define helper functions that calculate and print the sparsity and the number of clusters (unique weight values) of each kernel in the model.

In [8]:
def print_model_weights_sparsity(model):
    """Print the fraction of zero-valued entries in each kernel of the model."""
    for layer in model.layers:
        if isinstance(layer, keras.layers.Wrapper):
            weights = layer.trainable_weights
        else:
            weights = layer.weights
        for weight in weights:
            # only report kernels; skip biases and clustering centroids
            if "kernel" not in weight.name or "centroid" in weight.name:
                continue
            weight_size = weight.numpy().size
            zero_num = np.count_nonzero(weight == 0)
            print(
                f"{weight.name}: {zero_num/weight_size:.2%} sparsity ",
                f"({zero_num}/{weight_size})",
            )

def print_model_weight_clusters(model):
    """Print the number of unique values (clusters) in each kernel of the model."""
    for layer in model.layers:
        if isinstance(layer, keras.layers.Wrapper):
            weights = layer.trainable_weights
        else:
            weights = layer.weights
        for weight in weights:
            # ignore auxiliary quantization weights
            if "quantize_layer" in weight.name:
                continue
            if "kernel" in weight.name:
                unique_count = len(np.unique(weight))
                print(
                    f"{layer.name}/{weight.name}: {unique_count} clusters "
                )

Let's strip the pruning wrapper first, then check that the model kernels were correctly pruned.

In [9]:
stripped_pruned_model = tfmot.sparsity.keras.strip_pruning(pruned_model)

print_model_weights_sparsity(stripped_pruned_model)

conv2d/kernel:0: 50.00% sparsity  (54/108)
dense/kernel:0: 50.00% sparsity  (10140/20280)
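
The numbers printed above come down to two NumPy operations: counting exact zeros and counting unique values. The following is a minimal sketch on a made-up array (`toy_kernel`, `sparsity`, and `unique_values` are hypothetical names used only for illustration, not part of the notebook or of `tfmot`):

```python
import numpy as np

def sparsity(weights: np.ndarray) -> float:
    """Return the fraction of entries that are exactly zero."""
    return np.count_nonzero(weights == 0) / weights.size

def unique_values(weights: np.ndarray) -> int:
    """Return the number of distinct values, i.e. the cluster count."""
    return len(np.unique(weights))

# Hypothetical 3x4 kernel in which half of the entries were pruned to zero.
toy_kernel = np.array([
    [0.0, 0.3, 0.0, -0.7],
    [0.5, 0.0, 0.0, 0.2],
    [0.0, -0.1, 0.9, 0.0],
])

print(f"{sparsity(toy_kernel):.2%} sparsity")      # 6 of the 12 entries are zero
print(f"{unique_values(toy_kernel)} unique values")  # zero counts as one value
```

Note that zero itself counts as one of the unique values, so a pruned and clustered kernel reports the zero value as one of its clusters.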


## Apply sparsity preserving clustering and check its effect on model sparsity

Next, apply sparsity preserving clustering to the pruned model, then observe the number of clusters and check that the sparsity is preserved.

In [10]:
import tensorflow_model_optimization as tfmot
from tensorflow_model_optimization.python.core.clustering.keras.experimental import (
    cluster,
)

CentroidInitialization = tfmot.clustering.keras.CentroidInitialization

# Use the experimental cluster_weights, which accepts `preserve_sparsity`.
cluster_weights = cluster.cluster_weights

clustering_params = {
    'number_of_clusters': 8,
    'cluster_centroids_init': CentroidInitialization.KMEANS_PLUS_PLUS,
    'preserve_sparsity': True,
}

sparsity_clustered_model = cluster_weights(stripped_pruned_model, **clustering_params)

sparsity_clustered_model.compile(
    optimizer='adam',
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=['accuracy'])

print('Train sparsity preserving clustering model:')
sparsity_clustered_model.fit(train_images, train_labels, epochs=3, validation_split=0.1)

Train sparsity preserving clustering model:


Epoch 1/3


   1/1688 [..............................] - ETA: 42:22 - loss: 0.0442 - accuracy: 1.0000

  12/1688 [..............................] - ETA: 7s - loss: 0.0474 - accuracy: 0.9922   

  24/1688 [..............................] - ETA: 7s - loss: 0.0446 - accuracy: 0.9909

  36/1688 [..............................] - ETA: 7s - loss: 0.0626 - accuracy: 0.9826

  48/1688 [..............................] - ETA: 7s - loss: 0.0582 - accuracy: 0.9824

  60/1688 [>.............................] - ETA: 7s - loss: 0.0577 - accuracy: 0.9828

  72/1688 [>.............................] - ETA: 7s - loss: 0.0550 - accuracy: 0.9844

  84/1688 [>.............................] - ETA: 7s - loss: 0.0541 - accuracy: 0.9847

  96/1688 [>.............................] - ETA: 7s - loss: 0.0513 - accuracy: 0.9854

 108/1688 [>.............................] - ETA: 6s - loss: 0.0511 - accuracy: 0.9847

 120/1688 [=>............................] - ETA: 6s - loss: 0.0508 - accuracy: 0.9849

 132/1688 [=>............................] - ETA: 6s - loss: 0.0502 - accuracy: 0.9848

 144/1688 [=>............................] - ETA: 6s - loss: 0.0497 - accuracy: 0.9855

 156/1688 [=>............................] - ETA: 6s - loss: 0.0494 - accuracy: 0.9850

 169/1688 [==>...........................] - ETA: 6s - loss: 0.0482 - accuracy: 0.9858

 181/1688 [==>...........................] - ETA: 6s - loss: 0.0471 - accuracy: 0.9864

 193/1688 [==>...........................] - ETA: 6s - loss: 0.0479 - accuracy: 0.9861

 206/1688 [==>...........................] - ETA: 6s - loss: 0.0482 - accuracy: 0.9856

 219/1688 [==>...........................] - ETA: 6s - loss: 0.0482 - accuracy: 0.9850

 231/1688 [===>..........................] - ETA: 6s - loss: 0.0475 - accuracy: 0.9853

 243/1688 [===>..........................] - ETA: 6s - loss: 0.0460 - accuracy: 0.9860

 255/1688 [===>..........................] - ETA: 6s - loss: 0.0445 - accuracy: 0.9865

 267/1688 [===>..........................] - ETA: 6s - loss: 0.0436 - accuracy: 0.9869

 279/1688 [===>..........................] - ETA: 6s - loss: 0.0434 - accuracy: 0.9869

 291/1688 [====>.........................] - ETA: 5s - loss: 0.0433 - accuracy: 0.9869

 303/1688 [====>.........................] - ETA: 5s - loss: 0.0431 - accuracy: 0.9870

 315/1688 [====>.........................] - ETA: 5s - loss: 0.0423 - accuracy: 0.9871

 327/1688 [====>.........................] - ETA: 5s - loss: 0.0429 - accuracy: 0.9867

 340/1688 [=====>........................] - ETA: 5s - loss: 0.0421 - accuracy: 0.9869

 352/1688 [=====>........................] - ETA: 5s - loss: 0.0419 - accuracy: 0.9868

 364/1688 [=====>........................] - ETA: 5s - loss: 0.0414 - accuracy: 0.9871

 377/1688 [=====>........................] - ETA: 5s - loss: 0.0416 - accuracy: 0.9872

 390/1688 [=====>........................] - ETA: 5s - loss: 0.0422 - accuracy: 0.9868





















































































































































































































Epoch 2/3


   1/1688 [..............................] - ETA: 8s - loss: 0.0139 - accuracy: 1.0000

  13/1688 [..............................] - ETA: 7s - loss: 0.0238 - accuracy: 0.9904

  25/1688 [..............................] - ETA: 7s - loss: 0.0286 - accuracy: 0.9887

  38/1688 [..............................] - ETA: 6s - loss: 0.0352 - accuracy: 0.9868

  50/1688 [..............................] - ETA: 6s - loss: 0.0406 - accuracy: 0.9844

  62/1688 [>.............................] - ETA: 6s - loss: 0.0418 - accuracy: 0.9844

  75/1688 [>.............................] - ETA: 6s - loss: 0.0403 - accuracy: 0.9850

  87/1688 [>.............................] - ETA: 6s - loss: 0.0380 - accuracy: 0.9864

 100/1688 [>.............................] - ETA: 6s - loss: 0.0390 - accuracy: 0.9859

 113/1688 [=>............................] - ETA: 6s - loss: 0.0369 - accuracy: 0.9873

 125/1688 [=>............................] - ETA: 6s - loss: 0.0365 - accuracy: 0.9880

 137/1688 [=>............................] - ETA: 6s - loss: 0.0358 - accuracy: 0.9881

 149/1688 [=>............................] - ETA: 6s - loss: 0.0356 - accuracy: 0.9880

 162/1688 [=>............................] - ETA: 6s - loss: 0.0360 - accuracy: 0.9875

 175/1688 [==>...........................] - ETA: 6s - loss: 0.0358 - accuracy: 0.9873

 188/1688 [==>...........................] - ETA: 6s - loss: 0.0355 - accuracy: 0.9877

 201/1688 [==>...........................] - ETA: 6s - loss: 0.0350 - accuracy: 0.9882

 214/1688 [==>...........................] - ETA: 6s - loss: 0.0364 - accuracy: 0.9879

 226/1688 [===>..........................] - ETA: 6s - loss: 0.0356 - accuracy: 0.9884

 238/1688 [===>..........................] - ETA: 6s - loss: 0.0349 - accuracy: 0.9888

 250/1688 [===>..........................] - ETA: 5s - loss: 0.0355 - accuracy: 0.9886

 262/1688 [===>..........................] - ETA: 5s - loss: 0.0352 - accuracy: 0.9888

 275/1688 [===>..........................] - ETA: 5s - loss: 0.0363 - accuracy: 0.9887

 288/1688 [====>.........................] - ETA: 5s - loss: 0.0361 - accuracy: 0.9888

 300/1688 [====>.........................] - ETA: 5s - loss: 0.0363 - accuracy: 0.9889

 312/1688 [====>.........................] - ETA: 5s - loss: 0.0359 - accuracy: 0.9889

 325/1688 [====>.........................] - ETA: 5s - loss: 0.0354 - accuracy: 0.9890

 337/1688 [====>.........................] - ETA: 5s - loss: 0.0353 - accuracy: 0.9891

 350/1688 [=====>........................] - ETA: 5s - loss: 0.0366 - accuracy: 0.9887

 363/1688 [=====>........................] - ETA: 5s - loss: 0.0360 - accuracy: 0.9891

 376/1688 [=====>........................] - ETA: 5s - loss: 0.0357 - accuracy: 0.9891

 388/1688 [=====>........................] - ETA: 5s - loss: 0.0360 - accuracy: 0.9890

















































































































































































































Epoch 3/3


   1/1688 [..............................] - ETA: 8s - loss: 0.0640 - accuracy: 0.9688

  14/1688 [..............................] - ETA: 6s - loss: 0.0254 - accuracy: 0.9911

  27/1688 [..............................] - ETA: 6s - loss: 0.0269 - accuracy: 0.9896

  39/1688 [..............................] - ETA: 6s - loss: 0.0250 - accuracy: 0.9920

  52/1688 [..............................] - ETA: 6s - loss: 0.0382 - accuracy: 0.9886

  65/1688 [>.............................] - ETA: 6s - loss: 0.0413 - accuracy: 0.9875

  78/1688 [>.............................] - ETA: 6s - loss: 0.0389 - accuracy: 0.9872

  91/1688 [>.............................] - ETA: 6s - loss: 0.0376 - accuracy: 0.9880

 103/1688 [>.............................] - ETA: 6s - loss: 0.0362 - accuracy: 0.9885

 116/1688 [=>............................] - ETA: 6s - loss: 0.0345 - accuracy: 0.9890

 128/1688 [=>............................] - ETA: 6s - loss: 0.0356 - accuracy: 0.9890

 141/1688 [=>............................] - ETA: 6s - loss: 0.0353 - accuracy: 0.9891

 153/1688 [=>............................] - ETA: 6s - loss: 0.0368 - accuracy: 0.9884

 166/1688 [=>............................] - ETA: 6s - loss: 0.0364 - accuracy: 0.9883

 179/1688 [==>...........................] - ETA: 6s - loss: 0.0368 - accuracy: 0.9881

 192/1688 [==>...........................] - ETA: 6s - loss: 0.0361 - accuracy: 0.9884

 204/1688 [==>...........................] - ETA: 6s - loss: 0.0368 - accuracy: 0.9882

 216/1688 [==>...........................] - ETA: 6s - loss: 0.0366 - accuracy: 0.9881

 228/1688 [===>..........................] - ETA: 6s - loss: 0.0362 - accuracy: 0.9882

 240/1688 [===>..........................] - ETA: 6s - loss: 0.0364 - accuracy: 0.9882

 252/1688 [===>..........................] - ETA: 5s - loss: 0.0361 - accuracy: 0.9882

 265/1688 [===>..........................] - ETA: 5s - loss: 0.0378 - accuracy: 0.9881

 277/1688 [===>..........................] - ETA: 5s - loss: 0.0385 - accuracy: 0.9877

 290/1688 [====>.........................] - ETA: 5s - loss: 0.0384 - accuracy: 0.9876

 302/1688 [====>.........................] - ETA: 5s - loss: 0.0394 - accuracy: 0.9873

 315/1688 [====>.........................] - ETA: 5s - loss: 0.0390 - accuracy: 0.9874

 327/1688 [====>.........................] - ETA: 5s - loss: 0.0383 - accuracy: 0.9876

 339/1688 [=====>........................] - ETA: 5s - loss: 0.0382 - accuracy: 0.9876

 351/1688 [=====>........................] - ETA: 5s - loss: 0.0382 - accuracy: 0.9876

 363/1688 [=====>........................] - ETA: 5s - loss: 0.0387 - accuracy: 0.9877

 376/1688 [=====>........................] - ETA: 5s - loss: 0.0384 - accuracy: 0.9876

 389/1688 [=====>........................] - ETA: 5s - loss: 0.0376 - accuracy: 0.9879

















































































































































































































<tf_keras.src.callbacks.History at 0x7f6080153790>

Strip the clustering wrapper first, then check that the model is correctly pruned and clustered.

In [11]:
stripped_clustered_model = tfmot.clustering.keras.strip_clustering(sparsity_clustered_model)

print("Model sparsity:\n")
print_model_weights_sparsity(stripped_clustered_model)

print("\nModel clusters:\n")
print_model_weight_clusters(stripped_clustered_model)

Model sparsity:

kernel:0: 50.93% sparsity  (55/108)
kernel:0: 58.12% sparsity  (11787/20280)

Model clusters:

conv2d/kernel:0: 8 clusters 
dense/kernel:0: 8 clusters 
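
The two helper functions called above are defined earlier in this notebook. As a rough, library-free sketch of what such checks measure, sparsity is the fraction of exactly-zero weights and the cluster count is the number of distinct weight values. The function names `weight_sparsity` and `count_clusters` are hypothetical, and a flat Python list stands in for a kernel tensor:

```python
def weight_sparsity(weights):
    """Return (zero_count, total, percent_zero) for a flat list of weights."""
    zero_count = sum(1 for w in weights if w == 0.0)
    total = len(weights)
    return zero_count, total, 100.0 * zero_count / total


def count_clusters(weights):
    """Return the number of distinct weight values, i.e. the cluster count."""
    return len(set(weights))


# Toy kernel with the same zero/total ratio as the conv2d kernel above:
# 55 zeros plus 53 non-zero weights drawn from 7 shared (made-up) centroids.
centroids = [-0.3, -0.2, -0.1, 0.1, 0.2, 0.3, 0.4]
toy_kernel = [0.0] * 55 + [centroids[i % len(centroids)] for i in range(53)]

zeros, total, percent = weight_sparsity(toy_kernel)
print(f"toy kernel: {percent:.2f}% sparsity  ({zeros}/{total})")
print(f"toy kernel: {count_clusters(toy_kernel)} clusters")
```

In this toy example, zero counts as one of the distinct values, so 7 non-zero centroids plus zero give 8 clusters.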


## Apply QAT and PCQAT and check effect on model clusters and sparsity

Next, apply both QAT and PCQAT to the sparse, clustered model and observe that only PCQAT preserves the weight sparsity and clusters. Note that the stripped model is what gets passed to the QAT and PCQAT APIs.

In [12]:
# QAT
qat_model = tfmot.quantization.keras.quantize_model(stripped_clustered_model)

qat_model.compile(optimizer='adam',
              loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
              metrics=['accuracy'])
print('Train qat model:')
qat_model.fit(train_images, train_labels, batch_size=128, epochs=1, validation_split=0.1)

# PCQAT
quant_aware_annotate_model = tfmot.quantization.keras.quantize_annotate_model(
              stripped_clustered_model)
pcqat_model = tfmot.quantization.keras.quantize_apply(
              quant_aware_annotate_model,
              tfmot.experimental.combine.Default8BitClusterPreserveQuantizeScheme(preserve_sparsity=True))

pcqat_model.compile(optimizer='adam',
              loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
              metrics=['accuracy'])
print('Train pcqat model:')
pcqat_model.fit(train_images, train_labels, batch_size=128, epochs=1, validation_split=0.1)

Train qat model:


  1/422 [..............................] - ETA: 9:01 - loss: 0.0300 - accuracy: 0.9922

  9/422 [..............................] - ETA: 2s - loss: 0.0528 - accuracy: 0.9844  

 17/422 [>.............................] - ETA: 2s - loss: 0.0440 - accuracy: 0.9862

 26/422 [>.............................] - ETA: 2s - loss: 0.0384 - accuracy: 0.9880

 35/422 [=>............................] - ETA: 2s - loss: 0.0392 - accuracy: 0.9877

 44/422 [==>...........................] - ETA: 2s - loss: 0.0374 - accuracy: 0.9881

 53/422 [==>...........................] - ETA: 2s - loss: 0.0353 - accuracy: 0.9886

 62/422 [===>..........................] - ETA: 2s - loss: 0.0345 - accuracy: 0.9888

 71/422 [====>.........................] - ETA: 2s - loss: 0.0327 - accuracy: 0.9893

 80/422 [====>.........................] - ETA: 2s - loss: 0.0327 - accuracy: 0.9895

 89/422 [=====>........................] - ETA: 2s - loss: 0.0331 - accuracy: 0.9896

 98/422 [=====>........................] - ETA: 1s - loss: 0.0323 - accuracy: 0.9901











































































Train pcqat model:

  1/422 [..............................] - ETA: 11:38 - loss: 0.0170 - accuracy: 1.0000

 10/422 [..............................] - ETA: 2s - loss: 0.0286 - accuracy: 0.9914   

 19/422 [>.............................] - ETA: 2s - loss: 0.0262 - accuracy: 0.9926

 28/422 [>.............................] - ETA: 2s - loss: 0.0296 - accuracy: 0.9914

 37/422 [=>............................] - ETA: 2s - loss: 0.0298 - accuracy: 0.9920

 46/422 [==>...........................] - ETA: 2s - loss: 0.0295 - accuracy: 0.9917

 55/422 [==>...........................] - ETA: 2s - loss: 0.0298 - accuracy: 0.9913

 64/422 [===>..........................] - ETA: 2s - loss: 0.0303 - accuracy: 0.9916

 73/422 [====>.........................] - ETA: 2s - loss: 0.0312 - accuracy: 0.9917

 82/422 [====>.........................] - ETA: 2s - loss: 0.0315 - accuracy: 0.9917

 91/422 [=====>........................] - ETA: 2s - loss: 0.0326 - accuracy: 0.9916











































































<tf_keras.src.callbacks.History at 0x7f6050606e80>

In [13]:
print("QAT Model clusters:")
print_model_weight_clusters(qat_model)
print("\nQAT Model sparsity:")
print_model_weights_sparsity(qat_model)
print("\nPCQAT Model clusters:")
print_model_weight_clusters(pcqat_model)
print("\nPCQAT Model sparsity:")
print_model_weights_sparsity(pcqat_model)

QAT Model clusters:
quant_conv2d/conv2d/kernel:0: 100 clusters 
quant_dense/dense/kernel:0: 18251 clusters 

QAT Model sparsity:
conv2d/kernel:0: 8.33% sparsity  (9/108)
dense/kernel:0: 7.52% sparsity  (1525/20280)

PCQAT Model clusters:
quant_conv2d/conv2d/kernel:0: 8 clusters 
quant_dense/dense/kernel:0: 8 clusters 

PCQAT Model sparsity:
conv2d/kernel:0: 50.93% sparsity  (55/108)
dense/kernel:0: 58.16% sparsity  (11794/20280)
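
The gap between the two results comes from how the weights are allowed to move during fine-tuning. The following toy sketch is not the `tfmot` implementation. It only illustrates the idea under simplifying assumptions: a flat list of weights, made-up gradients, and hypothetical function names. Updating every weight independently scatters the values, while updating one shared value per cluster and leaving zeros untouched keeps both properties:

```python
import random


def naive_update(weights, grads, lr):
    """Plain gradient step: every weight moves independently."""
    return [w - lr * g for w, g in zip(weights, grads)]


def preserving_update(weights, grads, lr):
    """Move each non-zero cluster as a unit; keep zero weights at zero."""
    grads_by_value = {}
    for w, g in zip(weights, grads):
        grads_by_value.setdefault(w, []).append(g)
    new_value = {}
    for value, group in grads_by_value.items():
        if value == 0.0:
            new_value[value] = 0.0  # pruned weights stay pruned
        else:
            # One shared step per cluster, using the mean gradient of its members.
            new_value[value] = value - lr * sum(group) / len(group)
    return [new_value[w] for w in weights]


def summarize(weights):
    """Return (number of distinct values, number of zeros)."""
    return len(set(weights)), sum(1 for w in weights if w == 0.0)


rng = random.Random(0)
centroids = [0.0, -0.3, -0.2, -0.1, 0.1, 0.2, 0.3, 0.4]  # made-up values
weights = [centroids[i % len(centroids)] for i in range(200)]
grads = [rng.gauss(0.0, 1.0) for _ in weights]

print("start:     ", summarize(weights))
print("naive:     ", summarize(naive_update(weights, grads, lr=0.01)))
print("preserving:", summarize(preserving_update(weights, grads, lr=0.01)))
```

The naive update mirrors what the plain QAT run shows above: the number of distinct values grows and most zeros disappear. The preserving update keeps the cluster count and the zero count where they started.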


## See compression benefits of PCQAT model

Define a helper function that returns the size of the zipped model file.

In [14]:
def get_gzipped_model_size(file):
  """Returns the size of the zipped model file in kilobytes (1 KB = 1000 bytes)."""

  fd, zipped_file = tempfile.mkstemp('.zip')
  os.close(fd)  # mkstemp returns an open file descriptor; close it to avoid a leak.
  with zipfile.ZipFile(zipped_file, 'w', compression=zipfile.ZIP_DEFLATED) as f:
    f.write(file)

  return os.path.getsize(zipped_file)/1000

Convert both models to TFLite and compare their zipped sizes. Applying sparsity, clustering, and PCQAT yields a noticeably smaller file than applying QAT alone.

In [15]:
# QAT model
converter = tf.lite.TFLiteConverter.from_keras_model(qat_model)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
qat_tflite_model = converter.convert()
qat_model_file = 'qat_model.tflite'
# Save the model.
with open(qat_model_file, 'wb') as f:
    f.write(qat_tflite_model)

# PCQAT model
converter = tf.lite.TFLiteConverter.from_keras_model(pcqat_model)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
pcqat_tflite_model = converter.convert()
pcqat_model_file = 'pcqat_model.tflite'
# Save the model.
with open(pcqat_model_file, 'wb') as f:
    f.write(pcqat_tflite_model)

print("QAT model size: ", get_gzipped_model_size(qat_model_file), ' KB')
print("PCQAT model size: ", get_gzipped_model_size(pcqat_model_file), ' KB')

INFO:tensorflow:Assets written to: /tmpfs/tmp/tmpbd29dk98/assets

W0000 00:00:1709988717.237025   41361 tf_tfl_flatbuffer_helpers.cc:390] Ignored output_format.
W0000 00:00:1709988717.237075   41361 tf_tfl_flatbuffer_helpers.cc:393] Ignored drop_control_dependency.

INFO:tensorflow:Assets written to: /tmpfs/tmp/tmpy4q5o_1n/assets

W0000 00:00:1709988720.060897   41361 tf_tfl_flatbuffer_helpers.cc:390] Ignored output_format.
W0000 00:00:1709988720.060927   41361 tf_tfl_flatbuffer_helpers.cc:393] Ignored drop_control_dependency.

QAT model size:  13.958  KB
PCQAT model size:  7.876  KB
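
In the run above, the zipped PCQAT model is about 44% smaller than the zipped QAT model (7.876 KB versus 13.958 KB). One way to build intuition for this is to compress synthetic 8-bit weight buffers directly. The sketch below uses only the standard library. Both buffers are made up and are not the notebook's weights; the element count and the 58% zero fraction are borrowed from the dense kernel for illustration, and the unstructured buffer deliberately exaggerates how little redundancy plain QAT weights have:

```python
import random
import zlib

NUM_WEIGHTS = 20280  # same element count as the dense kernel above
rng = random.Random(0)

# Stand-in for plain QAT weights: 8-bit values with no repeated structure.
dense_bytes = bytes(rng.randrange(256) for _ in range(NUM_WEIGHTS))

# Stand-in for PCQAT weights: ~58% zeros, the rest drawn from 7 shared values.
nonzero_centroids = [17, 42, 77, 130, 171, 203, 240]  # arbitrary byte values
clustered_bytes = bytes(
    0 if rng.random() < 0.58 else rng.choice(nonzero_centroids)
    for _ in range(NUM_WEIGHTS)
)

# zlib implements DEFLATE, the same algorithm selected by zipfile.ZIP_DEFLATED.
dense_size = len(zlib.compress(dense_bytes, 9))
clustered_size = len(zlib.compress(clustered_bytes, 9))

print(f"unstructured buffer:      {dense_size} bytes after DEFLATE")
print(f"sparse, clustered buffer: {clustered_size} bytes after DEFLATE")
```

A buffer that uses only a handful of distinct values, most of them zero, carries far less information per byte, which is what lets DEFLATE shrink it.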


## See the persistence of accuracy from TF to TFLite

Define a helper function to evaluate the TFLite model on the test dataset.

In [16]:
def eval_model(interpreter):
  input_index = interpreter.get_input_details()[0]["index"]
  output_index = interpreter.get_output_details()[0]["index"]

  # Run predictions on every image in the "test" dataset.
  prediction_digits = []
  for i, test_image in enumerate(test_images):
    if i % 1000 == 0:
      print(f"Evaluated on {i} results so far.")
    # Pre-processing: add batch dimension and convert to float32 to match with
    # the model's input data format.
    test_image = np.expand_dims(test_image, axis=0).astype(np.float32)
    interpreter.set_tensor(input_index, test_image)

    # Run inference.
    interpreter.invoke()

    # Post-processing: remove batch dimension and find the digit with highest
    # probability.
    output = interpreter.tensor(output_index)
    digit = np.argmax(output()[0])
    prediction_digits.append(digit)

  print('\n')
  # Compare prediction results with ground truth labels to calculate accuracy.
  prediction_digits = np.array(prediction_digits)
  accuracy = (prediction_digits == test_labels).mean()
  return accuracy

Evaluate the pruned, clustered, and quantized model, and confirm that the accuracy from TensorFlow persists in the TFLite backend.

In [17]:
interpreter = tf.lite.Interpreter(pcqat_model_file)
interpreter.allocate_tensors()

pcqat_test_accuracy = eval_model(interpreter)

print('Pruned, clustered and quantized TFLite test_accuracy:', pcqat_test_accuracy)
print('Baseline TF test accuracy:', baseline_model_accuracy)

Evaluated on 0 results so far.
Evaluated on 1000 results so far.
Evaluated on 2000 results so far.
INFO: Created TensorFlow Lite XNNPACK delegate for CPU.
Evaluated on 3000 results so far.
Evaluated on 4000 results so far.
Evaluated on 5000 results so far.
Evaluated on 6000 results so far.
Evaluated on 7000 results so far.
Evaluated on 8000 results so far.
Evaluated on 9000 results so far.


Pruned, clustered and quantized TFLite test_accuracy: 0.9806
Baseline TF test accuracy: 0.9835000038146973
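
To put the two numbers side by side, the difference can be expressed in percentage points. This is plain arithmetic on the values printed above; the variable names are only for illustration:

```python
baseline_accuracy = 0.9835000038146973  # baseline TF model, from the output above
pcqat_tflite_accuracy = 0.9806          # PCQAT TFLite model, from the output above

# Difference between the two accuracies, expressed in percentage points.
drop_in_points = (baseline_accuracy - pcqat_tflite_accuracy) * 100
print(f"Accuracy drop: {drop_in_points:.2f} percentage points")
```

For this run the drop is well under one percentage point; exact values will vary between runs.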


## Conclusion

In this tutorial, you learned how to create a model, prune it using the `prune_low_magnitude()` API, and then cluster its weights using the `cluster_weights()` API in a way that preserves sparsity.

Next, sparsity and cluster preserving quantization aware training (PCQAT) was applied to preserve model sparsity and clusters while using QAT. The final PCQAT model was compared to the QAT one to show that sparsity and clusters are preserved in the former and lost in the latter.

Then, both models were converted to TFLite to show the compression benefits of chaining sparsity, clustering, and PCQAT, and the PCQAT TFLite model was evaluated to confirm that its accuracy persists in the TFLite backend.

Finally, the accuracy of the PCQAT TFLite model was compared to that of the pre-optimization baseline model, showing that collaborative optimization achieves these compression benefits while keeping accuracy close to that of the original model.