Digit Recognition With TensorFlow.js

Description

OCR (Optical Character Recognition) for digit recognition converts images of handwritten or printed digits into machine-readable text. Here we train a digit recognition model using TensorFlow.js and PeerMR.

Implementation

This is a good example of porting an existing TensorFlow.js model that runs in a single web browser so that it runs across multiple browsers using PeerMR. It uses the vanilla TensorFlow.js library and does not require our fork of TensorFlow.js.

The following is adapted from the TensorFlow.js handwritten digit recognition lab here.

The input defines parameters for three model variants. A single map function is used, and each of the three workers trains one model. The prediction results are the final output of the job.

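// Create the PeerMR job. The 'gcs' argument presumably selects the storage
// backend for job data (the dataset below is served from a Google Cloud
// Storage bucket). Scripts listed in execution.scripts are loaded into each
// worker, which is how vanilla TensorFlow.js becomes available inside map().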
const execution = new JobExecution('gcs');
execution.scripts = ["https://cdn.jsdelivr.net/npm/@tensorflow/tfjs@latest/dist/tf.min.js"];

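// The input function returns one hyperparameter object per model variant.
// PeerMR turns this array into keyed inputs for the map stage (the key k
// seen in map() is presumably the element's index).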
execution.setInputFn(async function inputs() {
  return [
    {
      "kernel-size-1": 5,
      "filters-1": 8,
      "strides-1": 1,
      "kernel-size-2": 5,
      "filters-2": 16,
      "strides-2": 1
    },
    {
      "kernel-size-1": 4,
      "filters-1": 8,
      "strides-1": 1,
      "kernel-size-2": 4,
      "filters-2": 16,
      "strides-2": 1
    },
    {
      "kernel-size-1": 5,
      "filters-1": 16,
      "strides-1": 1,
      "kernel-size-2": 5,
      "filters-2": 32,
      "strides-2": 1
    }
  ];
});

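// Each map worker receives its share of the inputs as an async iterable of
// [key, value] pairs. The `context` object is provided by the PeerMR worker
// runtime and is used here for logging and for emitting results.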
execution.addStage(new MapStage(async function map(mapInputs) {
  const IMAGE_WIDTH = 28;
  const IMAGE_HEIGHT = 28;
  const IMAGE_CHANNELS = 1;

  function getModel(modelParams) {
    const model = tf.sequential();

    // In the first layer of our convolutional neural network we have
    // to specify the input shape. Then we specify some parameters for
    // the convolution operation that takes place in this layer.
    model.add(tf.layers.conv2d({
      inputShape: [IMAGE_WIDTH, IMAGE_HEIGHT, IMAGE_CHANNELS],
      kernelSize: modelParams['kernel-size-1'],
      filters: modelParams['filters-1'],
      strides: modelParams['strides-1'],
      activation: 'relu',
      kernelInitializer: 'varianceScaling'
    }));

    // The MaxPooling layer acts as a sort of downsampling using max values
    // in a region instead of averaging.
    model.add(tf.layers.maxPooling2d({poolSize: [2, 2], strides: [2, 2]}));

    // Repeat another conv2d + maxPooling stack.
    // Note that we have more filters in the convolution.
    model.add(tf.layers.conv2d({
      kernelSize: modelParams['kernel-size-2'],
      filters: modelParams['filters-2'],
      strides: modelParams['strides-2'],
      activation: 'relu',
      kernelInitializer: 'varianceScaling'
    }));
    model.add(tf.layers.maxPooling2d({poolSize: [2, 2], strides: [2, 2]}));

    // Now we flatten the output from the 2D filters into a 1D vector to prepare
    // it for input into our last layer. This is common practice when feeding
    // higher dimensional data to a final classification output layer.
    model.add(tf.layers.flatten());

    // Our last layer is a dense layer which has 10 output units, one for each
    // output class (i.e. 0, 1, 2, 3, 4, 5, 6, 7, 8, 9).
    const NUM_OUTPUT_CLASSES = 10;
    model.add(tf.layers.dense({
      units: NUM_OUTPUT_CLASSES,
      kernelInitializer: 'varianceScaling',
      activation: 'softmax'
    }));

    // Choose an optimizer, loss function and accuracy metric,
    // then compile and return the model
    const optimizer = tf.train.adam();
    model.compile({
      optimizer: optimizer,
      loss: 'categoricalCrossentropy',
      metrics: ['accuracy'],
    });

    return model;
  }

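  // Train on 5500 examples for 10 epochs, validating against 1000 held-out
  // test examples. Labels arrive one-hot encoded from MnistData, which is
  // why the model is compiled with categoricalCrossentropy above.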
  async function train(model, data) {
    const BATCH_SIZE = 512;
    const TRAIN_DATA_SIZE = 5500;
    const TEST_DATA_SIZE = 1000;

    const [trainXs, trainYs] = tf.tidy(() => {
      const d = data.nextTrainBatch(TRAIN_DATA_SIZE);
      return [
        d.xs.reshape([TRAIN_DATA_SIZE, 28, 28, 1]),
        d.labels
      ];
    });

    const [testXs, testYs] = tf.tidy(() => {
      const d = data.nextTestBatch(TEST_DATA_SIZE);
      return [
        d.xs.reshape([TEST_DATA_SIZE, 28, 28, 1]),
        d.labels
      ];
    });

    const info = await model.fit(trainXs, trainYs, {
      batchSize: BATCH_SIZE,
      validationData: [testXs, testYs],
      epochs: 10,
      shuffle: true
    });
    // history.acc holds one accuracy value per epoch; log the final one.
    context.log('Final accuracy: ' + info.history.acc[info.history.acc.length - 1]);
  }

  const IMAGE_SIZE = 784;
  const NUM_CLASSES = 10;
  const NUM_DATASET_ELEMENTS = 65000;

  const NUM_TRAIN_ELEMENTS = 55000;
  const NUM_TEST_ELEMENTS = NUM_DATASET_ELEMENTS - NUM_TRAIN_ELEMENTS;

  const MNIST_IMAGES_SPRITE_PATH =
    'https://storage.googleapis.com/peermr/689a5731-c29b-468e-94a6-24b6e117d3bb/_data/mnist_images.png';
  const MNIST_LABELS_PATH =
    'https://storage.googleapis.com/peermr/689a5731-c29b-468e-94a6-24b6e117d3bb/_data/mnist_labels_uint8';

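  // Loads MNIST from two files: a PNG sprite expected to hold one flattened
  // 28x28 grayscale image per 784-pixel row (65,000 rows total), and a uint8
  // buffer of one-hot labels (10 bytes per image).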
  class MnistData {
    constructor() {
      this.shuffledTrainIndex = 0;
      this.shuffledTestIndex = 0;
    }

    async load() {
      const imgResponse = await fetch(MNIST_IMAGES_SPRITE_PATH);
      const imgBlob = await imgResponse.blob();
      const imgBitmap = await createImageBitmap(imgBlob);

      const datasetBytesBuffer =
        new ArrayBuffer(NUM_DATASET_ELEMENTS * IMAGE_SIZE * 4);

      // Decode the sprite in chunks of rows to limit peak memory. The canvas
      // must span the full sprite width so that each 784-pixel row (one
      // flattened 28x28 image) is read back intact.
      const chunkSize = 5000;
      const canvas = new OffscreenCanvas(imgBitmap.width, chunkSize);
      const ctx = canvas.getContext('2d');

      for (let i = 0; i < NUM_DATASET_ELEMENTS / chunkSize; i++) {
        const datasetBytesView = new Float32Array(
          datasetBytesBuffer, i * IMAGE_SIZE * chunkSize * 4,
          IMAGE_SIZE * chunkSize);
        ctx.drawImage(
          imgBitmap, 0, i * chunkSize, imgBitmap.width, chunkSize, 0, 0, imgBitmap.width,
          chunkSize);

        const imageData = ctx.getImageData(0, 0, canvas.width, canvas.height);

        for (let j = 0; j < imageData.data.length / 4; j++) {
          // All channels hold an equal value since the image is grayscale, so
          // just read the red channel.
          datasetBytesView[j] = imageData.data[j * 4] / 255;
        }
      }
      this.datasetImages = new Float32Array(datasetBytesBuffer);

      const labelsResponse = await fetch(MNIST_LABELS_PATH);
      this.datasetLabels = new Uint8Array(await labelsResponse.arrayBuffer());

      // Create shuffled indices into the train/test set for when we select a
      // random dataset element for training / validation.
      this.trainIndices = tf.util.createShuffledIndices(NUM_TRAIN_ELEMENTS);
      this.testIndices = tf.util.createShuffledIndices(NUM_TEST_ELEMENTS);

      // Slice the images and labels into train and test sets.
      this.trainImages =
        this.datasetImages.slice(0, IMAGE_SIZE * NUM_TRAIN_ELEMENTS);
      this.testImages = this.datasetImages.slice(IMAGE_SIZE * NUM_TRAIN_ELEMENTS);
      this.trainLabels =
        this.datasetLabels.slice(0, NUM_CLASSES * NUM_TRAIN_ELEMENTS);
      this.testLabels =
        this.datasetLabels.slice(NUM_CLASSES * NUM_TRAIN_ELEMENTS);
    }

    nextTrainBatch(batchSize) {
      return this.nextBatch(
        batchSize, [this.trainImages, this.trainLabels], () => {
          this.shuffledTrainIndex =
            (this.shuffledTrainIndex + 1) % this.trainIndices.length;
          return this.trainIndices[this.shuffledTrainIndex];
        });
    }

    nextTestBatch(batchSize) {
      return this.nextBatch(batchSize, [this.testImages, this.testLabels], () => {
        this.shuffledTestIndex =
          (this.shuffledTestIndex + 1) % this.testIndices.length;
        return this.testIndices[this.shuffledTestIndex];
      });
    }

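    // Assemble a batch by repeatedly pulling the next shuffled index from
    // the index() callback and copying that image and its one-hot label into
    // preallocated typed arrays.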
    nextBatch(batchSize, data, index) {
      const batchImagesArray = new Float32Array(batchSize * IMAGE_SIZE);
      const batchLabelsArray = new Uint8Array(batchSize * NUM_CLASSES);

      for (let i = 0; i < batchSize; i++) {
        const idx = index();

        const image =
          data[0].slice(idx * IMAGE_SIZE, idx * IMAGE_SIZE + IMAGE_SIZE);
        batchImagesArray.set(image, i * IMAGE_SIZE);

        const label =
          data[1].slice(idx * NUM_CLASSES, idx * NUM_CLASSES + NUM_CLASSES);
        batchLabelsArray.set(label, i * NUM_CLASSES);
      }

      const xs = tf.tensor2d(batchImagesArray, [batchSize, IMAGE_SIZE]);
      const labels = tf.tensor2d(batchLabelsArray, [batchSize, NUM_CLASSES]);

      return {xs, labels};
    }
  }

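  // argMax(-1) collapses one-hot labels and softmax outputs to class indices
  // (0-9) so predictions and ground truth can be compared directly.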
  function doPrediction(model, data, testDataSize = 500) {
    const testData = data.nextTestBatch(testDataSize);
    const testxs = testData.xs.reshape([testDataSize, IMAGE_WIDTH, IMAGE_HEIGHT, 1]);
    const labels = testData.labels.argMax(-1);
    const preds = model.predict(testxs).argMax(-1);

    testxs.dispose();
    return [preds, labels];
  }

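  // Full per-input pipeline: load the data, build and train the model
  // described by modelParams, then emit the mean squared error between
  // predicted and true class indices under the input's key.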
  async function trainAndPredict(k, modelParams) {
    context.log('loading mnist data');
    const data = new MnistData();
    await data.load();

    context.log('defining model with params: ' + JSON.stringify(modelParams));
    const model = getModel(modelParams);

    context.log('training model with params: ' + JSON.stringify(modelParams));
    await train(model, data);

    context.log('getting label predictions');
    const [preds, labels] = doPrediction(model, data);
    context.log('computing mse');
    const mse = tf.metrics.meanSquaredError(labels, preds);
    context.emit(k, mse.arraySync());
  }

  for await (const kv of mapInputs) {
    const [k, modelParams] = kv;
    await trainAndPredict(k, modelParams);
  }
  await context.onComplete();
}));

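// One worker per hyperparameter set, so the three models train in parallel.
// jobRunner is assumed to be provided by the surrounding PeerMR page/runtime.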
execution.workerCount = 3;
execution.start(jobRunner);
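
The job's output is three key/value pairs, one per parameter set, where each value is the mean squared error between the predicted and actual digit classes on 500 held-out test images.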