Visualizing training with tfjs-vis
tfjs-vis is a small set of visualization utilities that make it easier to understand what is going on with your tfjs models. It is designed to work alongside regular web apps. This page uses some of the features of tfjs-vis to illustrate what is going on with a convolutional model that is trained (in the browser) to recognize handwritten digits.
tfjs-vis provides 2 main things:
- A place to put visualizations that tries not to interfere with your web page. We call this place a visor.
- Some built-in visualizations that we have found useful when working with TensorFlow.js.
The Visor
Let's take a look at the first. Calling tfvis.visor() creates a visor if one doesn't exist, or returns the existing one.
The visor is displayed as a panel on the right of the page. It hovers over your page's content and shouldn't disturb the flow of your page's DOM elements. It has a few controls for showing or hiding it, and by default it also supports the following keyboard shortcuts:
- ` (backtick): Shows or hides the visor
- ~ (tilde, shift+backtick): Toggles between the two sizes the visor supports
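The same show/hide behaviour can also be driven from code. Below is a minimal sketch, assuming `tfvis` is the tfjs-vis namespace (the `tfvis` global from a script tag, or `import * as tfvis from '@tensorflow/tfjs-vis'`); it is passed in as a parameter only so the helpers can be exercised outside a browser. `showVisor` and `bindVisorButton` are hypothetical helper names; `visor()`, `isOpen()`, `open()` and `toggle()` are methods of the tfjs-vis visor.

```javascript
// Hypothetical helper: make sure the visor exists and is visible.
// `tfvis` is assumed to be the tfjs-vis namespace.
function showVisor(tfvis) {
  // visor() creates the visor on the first call and returns the same instance afterwards.
  const visor = tfvis.visor();
  if (!visor.isOpen()) {
    visor.open();
  }
  return visor;
}

// Hypothetical helper: toggle the visor from a button, like the backtick key does.
function bindVisorButton(tfvis, button) {
  button.addEventListener('click', () => tfvis.visor().toggle());
}
```

In a page, `bindVisorButton(tfvis, document.querySelector('#show-visor'))` would wire up a button; the `show-visor` id is illustrative.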
Surfaces
To add content to the visor we need a surface. Calling visor().surface() creates a surface if it doesn't exist, or fetches it if it does. To create a surface we must give it a name; we can also optionally specify the name of a tab that the surface should be put on. This API returns an object that holds references to 3 DOM elements:
- container: The containing DOM element for the surface
- label: The label element
- drawArea: A DOM Element where we can render visualizations or other content.
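A minimal sketch of creating a surface and writing into it, assuming `tfvis` is the tfjs-vis namespace (passed in so the helper is testable without a browser). `writeToSurface` is a hypothetical helper, and the surface and tab names in the usage line are illustrative.

```javascript
// Hypothetical helper: fetch (or lazily create) a surface and write a short note into it.
// `tfvis` is assumed to be the tfjs-vis namespace.
function writeToSurface(tfvis, name, tab, text) {
  // surface() returns the matching existing surface, or creates a new one.
  const surface = tfvis.visor().surface({ name, tab });
  // container, label and drawArea are plain DOM elements, so any DOM API works on them.
  surface.drawArea.textContent = text;
  return surface;
}
```

For example, `writeToSurface(tfvis, 'My First Surface', 'Input Data', 'Hello from tfjs-vis')`. Because `drawArea` is an ordinary element, you can just as well append canvases, tables or other custom content to it.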
Our Data
We will use the MNIST database as our training set. It consists of about 60,000 images of handwritten digits, each cropped to 28x28 pixels. Let's take a look at a few examples, drawn into a surface in the visor.
The code to render these examples isn't built into tfjs-vis. But because you have full access to the DOM element for each surface, you can draw whatever you like into it. This makes it easy to integrate custom visualizations into the visor.
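The original rendering code isn't reproduced here. Below is one possible sketch, assuming each example digit is available as an array of 784 grayscale values in [0, 1] (a row-major 28x28 image); `digitToRGBA` and `drawDigits` are hypothetical names. `drawDigits` uses the standard canvas `ImageData` API and therefore only runs in a browser.

```javascript
const DIGIT_SIZE = 28;

// Hypothetical helper: convert one grayscale digit (784 values in [0, 1], row-major)
// into the RGBA byte layout used by the browser's ImageData.
function digitToRGBA(pixels) {
  if (pixels.length !== DIGIT_SIZE * DIGIT_SIZE) {
    throw new Error(`expected ${DIGIT_SIZE * DIGIT_SIZE} pixels, got ${pixels.length}`);
  }
  const rgba = new Uint8ClampedArray(pixels.length * 4);
  for (let i = 0; i < pixels.length; i++) {
    const v = Math.round(pixels[i] * 255);
    rgba[i * 4] = v;       // red
    rgba[i * 4 + 1] = v;   // green
    rgba[i * 4 + 2] = v;   // blue
    rgba[i * 4 + 3] = 255; // fully opaque
  }
  return rgba;
}

// Browser-only: draw each digit on its own <canvas> inside a surface's drawArea.
function drawDigits(drawArea, digits) {
  for (const pixels of digits) {
    const canvas = document.createElement('canvas');
    canvas.width = DIGIT_SIZE;
    canvas.height = DIGIT_SIZE;
    canvas.style.margin = '4px';
    const ctx = canvas.getContext('2d');
    ctx.putImageData(new ImageData(digitToRGBA(pixels), DIGIT_SIZE, DIGIT_SIZE), 0, 0);
    drawArea.appendChild(canvas);
  }
}
```

Keeping the pixel conversion separate from the DOM code makes the conversion easy to check on its own; the drawing function only needs the `drawArea` of whichever surface you want to fill.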
Training Our Model
Our goal is to train a model to recognize similar digits. We have already written a tutorial on how to do that, so this article focuses on monitoring the training and on looking at how well our model performs.
First, let us define a helper function to do our training. Inside it, we can use the show.fitCallbacks method to get callback functions that plot the loss after every batch and after every epoch.
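A minimal sketch of such a helper, assuming `tfvis` is the tfjs-vis namespace, `model` is a compiled `tf.LayersModel` (compiled with `metrics: ['accuracy']`, so accuracy appears in the logs as `acc`), and `trainXs`/`trainYs` are tensors of images and one-hot labels. The surface name, tab, batch size, epoch count and validation split are illustrative choices, not values from the original article.

```javascript
// Hypothetical training helper. Assumptions: `tfvis` is the tfjs-vis namespace,
// `model` is a compiled tf.LayersModel (metrics: ['accuracy'], logged as `acc`),
// and trainXs / trainYs are image and one-hot label tensors.
function train(tfvis, model, trainXs, trainYs) {
  // A surface description; the visor creates this surface if it doesn't exist yet.
  const container = { name: 'Model Training', tab: 'Training' };
  const metrics = ['loss', 'val_loss', 'acc', 'val_acc'];
  // fitCallbacks returns an object holding the requested callbacks; each one
  // redraws the charts for `metrics` when it fires.
  const callbacks = tfvis.show.fitCallbacks(container, metrics, {
    callbacks: ['onBatchEnd', 'onEpochEnd'],
  });
  // model.fit returns a Promise<History>, so callers should await train(...).
  return model.fit(trainXs, trainYs, {
    batchSize: 512,        // illustrative
    epochs: 10,            // illustrative
    validationSplit: 0.15, // illustrative
    shuffle: true,
    callbacks,
  });
}
```

Validation metrics such as `val_loss` are only produced at the end of each epoch, so the per-batch charts show the training metrics only.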
Another option is to wait for the training to complete and render the loss curve when it is done.
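One way to do this is `tfvis.show.history`, which plots the per-epoch values stored in the `History` object that `model.fit` resolves to. This sketch assumes `tfvis` is the tfjs-vis namespace and `model` is a compiled `tf.LayersModel`; `trainThenPlot` is a hypothetical name and the fit options are illustrative.

```javascript
// Hypothetical helper: train without live callbacks, then plot the loss curves once.
// Assumes `tfvis` is the tfjs-vis namespace and `model` is a compiled tf.LayersModel.
async function trainThenPlot(tfvis, model, trainXs, trainYs) {
  const history = await model.fit(trainXs, trainYs, {
    batchSize: 512,        // illustrative
    epochs: 10,            // illustrative
    validationSplit: 0.15, // illustrative; needed for val_loss to exist
  });
  // show.history plots the recorded per-epoch values for each listed metric.
  await tfvis.show.history({ name: 'Loss Curve', tab: 'Training' }, history, ['loss', 'val_loss']);
  return history;
}
```

This draws the chart once instead of on every callback, at the cost of seeing nothing until training has finished.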
Customizing training charts
The show.fitCallbacks function is designed to help you quickly plot training behaviour with reasonable defaults. If you want to customize the rendering of these charts, you can use the render.linechart function instead, for example to plot accuracy values at the end of every epoch using custom colors and a custom y-axis domain.
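A sketch of that idea as a custom callback object passed to `model.fit`. It assumes the model was compiled with `metrics: ['accuracy']` and trained with validation data, so each epoch's logs contain `acc` and `val_acc`; the helper name, colors and labels are illustrative. `render.linechart` takes `{ values, series }` data, where `values` holds one array of `{x, y}` points per series, plus options such as `seriesColors` and `yAxisDomain`.

```javascript
// Hypothetical helper: custom callbacks that redraw an accuracy chart after every epoch.
// Assumes `tfvis` is the tfjs-vis namespace and that epoch logs contain `acc` and `val_acc`.
function accuracyChartCallbacks(tfvis, container) {
  const acc = [];
  const valAcc = [];
  return {
    onEpochEnd: (epoch, logs) => {
      acc.push({ x: epoch, y: logs.acc });
      valAcc.push({ x: epoch, y: logs.val_acc });
      // One array of {x, y} points per series, in the same order as `series`.
      const data = { values: [acc, valAcc], series: ['accuracy', 'validation accuracy'] };
      // Returning the promise allows model.fit to await the redraw.
      return tfvis.render.linechart(container, data, {
        xLabel: 'Epoch',
        yLabel: 'Accuracy',
        yAxisDomain: [0, 1],                  // fixed domain instead of auto-scaling
        seriesColors: ['#1f77b4', '#ff7f0e'], // custom colors, one per series
      });
    },
  };
}
```

Usage might look like `model.fit(xs, ys, { epochs: 10, validationSplit: 0.15, callbacks: accuracyChartCallbacks(tfvis, { name: 'Accuracy', tab: 'Training' }) })`. Fixing the y-axis to [0, 1] keeps successive redraws on the same scale, so changes between epochs are easier to compare.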
Evaluating Our Model
Now that our model is trained, we should evaluate its performance. For a classification task like this one, we can use the `perClassAccuracy` and `confusionMatrix` functions from tfvis.metrics.
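A sketch of how these might be wired up, assuming `tfvis` is the tfjs-vis namespace, `model` is a trained `tf.LayersModel`, and `testXs`/`testYs` are test images and one-hot labels; `showEvaluation` and the surface names are hypothetical. Both metrics functions take 1-D tensors of true and predicted class indices and return promises. For intuition, the block also includes a plain-JavaScript version of what the two metrics compute; it is an illustration, not the tfjs-vis implementation (in this sketch, rows are true classes and columns are predicted classes).

```javascript
const NUM_CLASSES = 10;
const CLASS_NAMES = ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9'];

// Plain-JS illustration (not the tfjs-vis implementation): counts how often each
// true class (row) was predicted as each class (column).
function countConfusion(labels, preds, numClasses) {
  const matrix = Array.from({ length: numClasses }, () => new Array(numClasses).fill(0));
  labels.forEach((label, i) => {
    matrix[label][preds[i]] += 1;
  });
  return matrix;
}

// Plain-JS illustration: per class, the fraction of its examples predicted correctly.
function classAccuracies(labels, preds, numClasses) {
  return countConfusion(labels, preds, numClasses).map((row, c) => {
    const count = row.reduce((sum, n) => sum + n, 0);
    return { accuracy: count === 0 ? 0 : row[c] / count, count };
  });
}

// Browser sketch using tfjs-vis. Assumes a trained model and one-hot test labels.
async function showEvaluation(tfvis, model, testXs, testYs) {
  const labels = testYs.argMax(-1);   // true class index per example
  const probs = model.predict(testXs);
  const preds = probs.argMax(-1);     // predicted class index per example
  probs.dispose();

  const accuracy = await tfvis.metrics.perClassAccuracy(labels, preds, NUM_CLASSES);
  tfvis.show.perClassAccuracy({ name: 'Accuracy', tab: 'Evaluation' }, accuracy, CLASS_NAMES);

  const matrix = await tfvis.metrics.confusionMatrix(labels, preds, NUM_CLASSES);
  tfvis.render.confusionMatrix(
    { name: 'Confusion Matrix', tab: 'Evaluation' },
    { values: matrix, tickLabels: CLASS_NAMES },
  );

  labels.dispose();
  preds.dispose();
}
```

Per-class accuracy shows which digits the model struggles with; the confusion matrix additionally shows what those digits are mistaken for.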