Visualizing training with tfjs-vis

tfjs-vis is a small set of visualization utilities that make it easier to understand what is going on inside your tfjs models. It is designed to work alongside regular web apps. This page uses some of the features of tfjs-vis to illustrate what is going on with a convolutional model that will be trained (in the browser) to recognize handwritten digits.

tfjs-vis provides two main things:

  1. A place to put visualizations that tries not to interfere with your web page. We call this place a visor.
  2. Some built-in visualizations that we have found useful when working with TensorFlow.js.

The Visor

Let's take a look at the first. Calling tfvis.visor() will create a visor if it doesn't exist or return the existing one. Click the button below to show the visor.
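For example, a minimal sketch of creating the visor and toggling its visibility programmatically:

```js
// Create the visor, or fetch the existing instance if one was already created.
const visor = tfvis.visor();

// Show or hide the panel from code (the on-page controls do the same thing).
visor.toggle();
```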

Notice the panel that is now displayed on the right. It hovers over your page's content and shouldn't disturb the flow of your page's DOM elements. You can see a few controls for showing or hiding the visor, but by default it also supports the following keyboard shortcuts:

  1. ` (backtick): Shows or hides the visor.
  2. ~ (tilde, shift+backtick): Toggles the visor between its default width and a larger size.

The API allows you to disable (unbind) these keyboard shortcuts.
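A sketch of unbinding the shortcuts, and rebinding them later if needed:

```js
const visor = tfvis.visor();

// Disable the default keyboard shortcuts...
visor.unbindKeys();

// ...and re-enable them later.
visor.bindKeys();
```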

Surfaces

To add content to the visor we need a surface. We make a surface with the following function call:
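For example (the surface and tab names here are arbitrary):

```js
// Create a surface named 'My Surface' on a tab named 'My Tab',
// or fetch it if it already exists.
const surface = tfvis.visor().surface({name: 'My Surface', tab: 'My Tab'});
```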

To create a surface we must give it a name; we can also optionally specify a tab name that the surface should be put on. visor().surface() creates a surface if it doesn't exist, or fetches it if it does. This API returns an object that has pointers to 3 DOM elements:

  1. container: the element containing the whole surface.
  2. label: the element that renders the surface's name.
  3. drawArea: the element you can draw your own content into.

Our Data

We will use the MNIST database as our training set. It is comprised of about 60k images of handwritten digits, all cropped to 28x28 px. Let's take a look at a few examples; we'll use the surface we created earlier.

The code to render these examples isn't built into tfjs-vis, but because you have full access to the DOM element for each surface, you can draw whatever you would like into it. This allows easy integration of custom visualizations into the visor.

Here is the code for the "Show Example Digits" button above:
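A sketch of what that code can look like. The `examples` tensor and however it was loaded are hypothetical here; any tensor of flattened 28x28 grayscale images with values in [0, 1] would work:

```js
// Render a batch of example digits into a surface's drawArea.
// `examples` is assumed to be a tensor of shape [numExamples, 784].
async function showExamples(examples) {
  const surface = tfvis.visor().surface(
      {name: 'Input Data Examples', tab: 'Input Data'});
  const drawArea = surface.drawArea;
  drawArea.innerHTML = '';  // clear any previous render

  const numExamples = examples.shape[0];
  for (let i = 0; i < numExamples; i++) {
    // Slice out one image and reshape it to [28, 28, 1] for rendering.
    const imageTensor = tf.tidy(() =>
        examples.slice([i, 0], [1, examples.shape[1]]).reshape([28, 28, 1]));

    const canvas = document.createElement('canvas');
    canvas.width = 28;
    canvas.height = 28;
    canvas.style.margin = '4px';
    await tf.browser.toPixels(imageTensor, canvas);
    drawArea.appendChild(canvas);

    imageTensor.dispose();
  }
}
```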

Training Our Model

Our goal is to train a model to recognize similar digits. We have already written a tutorial on how to do that, so in this article we are going to focus on monitoring that training and then look at how well our model performs.

First, let's define a helper function to do our training.

We can use the show.fitCallbacks method to get functions that will plot the loss after every batch and epoch.
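A sketch of such a helper, assuming `model` is a compiled tf.LayersModel and `trainXs`/`trainYs` are prepared tensors (these names, and the batch size and epoch count, are placeholders):

```js
function train(model, trainXs, trainYs) {
  const surface = {name: 'Training Performance', tab: 'Training'};

  // fitCallbacks returns callback functions that plot the requested metrics
  // at the end of every batch and every epoch.
  const callbacks = tfvis.show.fitCallbacks(
      surface, ['loss', 'acc'], {callbacks: ['onBatchEnd', 'onEpochEnd']});

  return model.fit(trainXs, trainYs, {
    batchSize: 64,
    epochs: 5,
    shuffle: true,
    callbacks,
  });
}
```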

Another option is to wait for the training to complete and render the loss curve when it is done.
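A sketch of that approach using show.history, which renders the metrics collected in the history object that model.fit resolves with (tensor names are placeholders):

```js
// Train without live plotting, then render the curves afterwards.
const history = await model.fit(trainXs, trainYs, {batchSize: 64, epochs: 5});

tfvis.show.history(
    {name: 'Training History', tab: 'Training'}, history, ['loss', 'acc']);
```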

Customizing Training Charts

The show.fitCallbacks function is designed to help you quickly plot training behavior with reasonable defaults. If you want to customize the rendering of these charts, you can use the render.linechart function. An example that plots accuracy values at the end of every epoch using custom colors and a custom y-axis domain is shown below.
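A sketch of such a callback. `yAxisDomain` pins the y-axis to [0, 1]; the `seriesColors` option is assumed here to set a custom series color:

```js
// Collect {x, y} points as training progresses and re-render the chart
// at the end of every epoch.
const accuracyValues = [];

const customCallbacks = {
  onEpochEnd: async (epoch, logs) => {
    accuracyValues.push({x: epoch, y: logs.acc});
    tfvis.render.linechart(
        {name: 'Accuracy', tab: 'Training'},
        {values: [accuracyValues], series: ['acc']},
        {
          xLabel: 'Epoch',
          yLabel: 'Accuracy',
          yAxisDomain: [0, 1],        // custom y-axis domain
          seriesColors: ['#2e7d32'],  // custom color (assumed option name)
        });
  },
};
```

These callbacks can then be passed to model.fit in place of the ones returned by show.fitCallbacks.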

Evaluating Our Model

Now that our model is trained, we should evaluate its performance. For a classification task like this one we can use the `perClassAccuracy` and `confusionMatrix` functions. These are demonstrated below.
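A sketch of that evaluation, assuming `model`, `testXs`, and a 1D `testLabels` tensor of integer class labels (these names are placeholders):

```js
const classNames = ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9'];

async function evaluateModel(model, testXs, testLabels) {
  // Predicted class indices for the test set.
  const preds = model.predict(testXs).argMax(-1);

  // Per-class accuracy table.
  const classAccuracy = await tfvis.metrics.perClassAccuracy(testLabels, preds);
  tfvis.show.perClassAccuracy(
      {name: 'Accuracy', tab: 'Evaluation'}, classAccuracy, classNames);

  // Confusion matrix.
  const confMatrix = await tfvis.metrics.confusionMatrix(testLabels, preds);
  tfvis.render.confusionMatrix(
      {name: 'Confusion Matrix', tab: 'Evaluation'},
      {values: confMatrix, tickLabels: classNames});

  preds.dispose();
}
```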