Here we’re going to extend the pre-trained model using the technique of transfer learning. We will use a KNN classifier to recognize images of grumpy facial expressions.
We’ve seen in the previous article how easy it is to load a pre-trained model. In this article, we’re going to extend the pre-trained model using transfer learning. We will build on the model with our own training set and use a K-nearest neighbor (KNN) module to categorize images of facial expressions as either grumpy or neutral.
Before we dig into any code, let’s quickly talk about how KNN and transfer learning work.
The KNN Classifier
The KNN algorithm is a simple, easy-to-implement supervised machine learning algorithm that can be used to solve both classification and regression problems.
The algorithm assumes that similar things exist close to each other. For example, shades of red are more similar to each other than to other colors such as yellow or black. KNN applies the same idea of similarity: it classifies a new case by measuring how close it is to already-classified cases using a distance function (for example, cosine or Hamming distance). It then assigns the new case the class that is most common among its K closest cases, its "nearest neighbors".
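To make the idea concrete, here is a minimal sketch of KNN in plain JavaScript. It is not the TensorFlow.js implementation; the function names (cosineSimilarity, knnClassify), the { vector, label } shape, and the toy color data are assumptions made for illustration only.

```javascript
// Cosine similarity between two equal-length numeric vectors.
function cosineSimilarity(a, b) {
  let dot = 0, normA = 0, normB = 0;
  for (let i = 0; i < a.length; i++) {
    dot += a[i] * b[i];
    normA += a[i] * a[i];
    normB += b[i] * b[i];
  }
  return dot / (Math.sqrt(normA) * Math.sqrt(normB));
}

// Classify `query` by majority vote among its k most similar examples.
// `examples` is an array of { vector, label } objects (hypothetical shape).
function knnClassify(examples, query, k = 3) {
  const nearest = examples
    .map(ex => ({ label: ex.label, score: cosineSimilarity(ex.vector, query) }))
    .sort((x, y) => y.score - x.score) // most similar first
    .slice(0, k);

  // Count how many of the k neighbors carry each label.
  const votes = {};
  for (const { label } of nearest) {
    votes[label] = (votes[label] || 0) + 1;
  }
  // Return the label with the most votes.
  return Object.keys(votes).reduce((best, label) =>
    votes[label] > votes[best] ? label : best);
}

// Toy RGB data: shades of red versus shades of yellow.
const toyColors = [
  { vector: [255, 0, 0], label: 'red' },
  { vector: [200, 30, 30], label: 'red' },
  { vector: [150, 10, 20], label: 'red' },
  { vector: [255, 255, 0], label: 'yellow' },
  { vector: [240, 230, 40], label: 'yellow' },
  { vector: [200, 200, 10], label: 'yellow' },
];

const colorGuess = knnClassify(toyColors, [220, 20, 25], 3);
console.log(colorGuess); // prints: red
```

With K = 3, the three closest examples to the query are all shades of red, so the vote is unanimous. A larger K smooths out noisy neighbors but can blur the boundary between small classes.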
TensorFlow.js offers a KNN classifier package that implements this algorithm. One thing to note is that it doesn’t provide a pre-built model; instead, it is a utility for constructing a KNN model from the activations of another model or from any other tensors. You can read more about it here.
Transfer Learning
Transfer learning is a machine learning technique that allows you to reuse a model developed for a certain task as the starting point or base for a model on another task.
Transfer learning is especially popular in deep learning, where pre-trained models serve as the starting point for computer vision tasks. Because training neural networks for these tasks from scratch requires huge amounts of computation and time, reusing a pre-trained model lets you get a working system with far less data and training effort.
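The split between a reused base and a newly trained part can be sketched in a few lines of plain JavaScript. Everything here is a toy assumption: frozenBase stands in for a pre-trained network such as MobileNet, and the new "head" is a simple nearest-centroid classifier rather than the KNN classifier used later in this article.

```javascript
// Stand-in for a pre-trained network: a fixed function that maps raw input
// to a feature vector. Nothing in it is changed by our training data.
// (Hypothetical; a real base would be a model such as MobileNet.)
function frozenBase(pixels) {
  const mean = pixels.reduce((s, p) => s + p, 0) / pixels.length;
  const variance = pixels.reduce((s, p) => s + (p - mean) ** 2, 0) / pixels.length;
  return [mean, Math.sqrt(variance)];
}

// The new "head": the only part that learns from our own labelled samples.
// It stores one average feature vector (centroid) per label.
function trainHead(samples) {
  const sums = {};
  for (const { pixels, label } of samples) {
    const features = frozenBase(pixels);
    if (!sums[label]) sums[label] = { total: features.map(() => 0), count: 0 };
    features.forEach((f, i) => { sums[label].total[i] += f; });
    sums[label].count++;
  }
  const centroids = {};
  for (const label of Object.keys(sums)) {
    centroids[label] = sums[label].total.map(t => t / sums[label].count);
  }
  return centroids;
}

// Predict by finding the centroid closest to the base model's features.
function predictWithHead(centroids, pixels) {
  const features = frozenBase(pixels);
  let bestLabel = null;
  let bestDist = Infinity;
  for (const [label, c] of Object.entries(centroids)) {
    const dist = Math.hypot(...features.map((f, i) => f - c[i]));
    if (dist < bestDist) {
      bestDist = dist;
      bestLabel = label;
    }
  }
  return bestLabel;
}

const toyHead = trainHead([
  { pixels: [10, 12, 9, 11], label: 'dark' },
  { pixels: [20, 18, 22, 19], label: 'dark' },
  { pixels: [240, 250, 245, 238], label: 'bright' },
  { pixels: [230, 228, 235, 240], label: 'bright' },
]);
console.log(predictWithHead(toyHead, [15, 14, 16, 13])); // prints: dark
```

Only trainHead sees the new labels; the base is reused unchanged, which is why a handful of examples per class can be enough.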
Our Technology Stack
For this example, we’re going to use the following technology stack:
- TensorFlow.js – A machine learning framework that makes it possible to do machine learning on the client side on the web.
- MobileNet Model – A pre-trained TensorFlow.js model used for image classification.
- KNN Classifier – A basic TensorFlow.js classifier that can be used to customize the image classification.
You can use other technology stacks like React or Angular if you’d like. Also feel free to extend the example.
Setting Up
Let’s start off by loading the required libraries and models:
<script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs/dist/tf.min.js"> </script>
<script src="https://cdn.jsdelivr.net/npm/@tensorflow-models/mobilenet"></script>
<script src="https://cdn.jsdelivr.net/npm/@tensorflow-models/knn-classifier"></script>
The next thing we need to do is define a canvas element with a specific width and height:
<canvas width="224" height="224"></canvas>
We use these dimensions because MobileNet was trained on images of 224×224 pixels. Matching that size means we don’t have to resize the image before feeding it into the classifier.
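One caveat: photos are rarely square, so drawing one straight into a 224×224 canvas stretches it. A possible way around this, sketched below, is to compute a centered square crop and pass it to the nine-argument form of drawImage. The helper name centerCropRect and the constant TARGET_SIZE are illustrative assumptions, not part of the original example.

```javascript
const TARGET_SIZE = 224;

// Returns the largest centered square inside a width x height image.
function centerCropRect(width, height) {
  const side = Math.min(width, height);
  return {
    sx: Math.floor((width - side) / 2),  // left edge of the crop
    sy: Math.floor((height - side) / 2), // top edge of the crop
    side,
  };
}

const sampleCrop = centerCropRect(640, 480);
console.log(sampleCrop); // { sx: 80, sy: 0, side: 480 }

// In the browser, with a 2D `context` and a loaded `image`:
// const { sx, sy, side } = centerCropRect(image.width, image.height);
// context.drawImage(image, sx, sy, side, side, 0, 0, TARGET_SIZE, TARGET_SIZE);
```

Cropping discards the edges of the picture, so it suits portraits where the face is roughly centered; stretching keeps everything but distorts proportions.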
Since we’re building a classifier that will classify images of a human face as having either a grumpy or neutral expression, we create a Grumpy and a Neutral button to manually classify images and add to our training data, and a "Predict" button to predict an image’s classification:
<button class="grumpy">Grumpy</button>
<button class="neutral">Neutral</button>
<button class="predict">Predict</button>
Now we attach event listeners to the buttons:
const grumpy = document.querySelector('.grumpy');
const neutral = document.querySelector('.neutral');
grumpy.addEventListener('click', () => addExamples('grumpy'));
neutral.addEventListener('click', () => addExamples('neutral'));
document.querySelector('.predict').addEventListener('click', predict);
To keep it simple and easy, we’re going to make our canvas accept images by drag and drop:
const canvas = document.querySelector("canvas");
const context = canvas.getContext("2d");
canvas.addEventListener('dragover', e => e.preventDefault(), false);
canvas.addEventListener('drop', onImageDrop, false);
The last thing we need is our function to handle dropped files:
const onImageDrop = e => {
  e.preventDefault();
  // Take the first dropped file and read it as a data URL
  const imageFile = e.dataTransfer.files[0];
  const imageReader = new FileReader();
  imageReader.onload = event => {
    const image = new Image();
    // Once the image has loaded, draw it scaled to the 224x224 canvas
    image.onload = () => {
      context.drawImage(image, 0, 0, 224, 224);
    };
    image.src = event.target.result;
  };
  imageReader.readAsDataURL(imageFile);
};
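The handler above assumes that the first dropped item is an image. If nothing usable is dropped, imageFile can be undefined and readAsDataURL will throw. A small guard such as the following sketch could be called before reading the file; isImageFile is a hypothetical helper name, and it relies only on the standard type property of File objects.

```javascript
// Returns true if the dropped item looks like an image, judging by its MIME type.
function isImageFile(file) {
  return Boolean(file) &&
    typeof file.type === 'string' &&
    file.type.startsWith('image/');
}

// Plain objects stand in for File instances here.
console.log(isImageFile({ type: 'image/png' }));  // true
console.log(isImageFile({ type: 'text/plain' })); // false
console.log(isImageFile(undefined));              // false

// In the drop handler, before creating the FileReader:
// if (!isImageFile(imageFile)) return;
```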
Once everything is in place, this is how our HTML document looks:
<!DOCTYPE html>
<html lang="en">
  <head>
    <meta charset="UTF-8" />
    <title>Image classification with TensorFlow.js</title>
    <script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs/dist/tf.min.js"></script>
    <script src="https://cdn.jsdelivr.net/npm/@tensorflow-models/mobilenet"></script>
    <script src="https://cdn.jsdelivr.net/npm/@tensorflow-models/knn-classifier"></script>
  </head>
  <body>
    <h1>Custom Image Classifier using TensorFlow.js</h1>
    <canvas style="border: 2px dashed #34495e; margin: auto;" width="224" height="224"></canvas>
    <h3>Train classifier with examples</h3>
    <button class="grumpy">Grumpy</button>
    <button class="neutral">Neutral</button>
    <button class="predict">Predict</button>
    <script src="knnClassifier.js"></script>
    <script>
      const canvas = document.querySelector("canvas");
      const context = canvas.getContext("2d");
      const grumpy = document.querySelector('.grumpy');
      const neutral = document.querySelector('.neutral');

      // Read the dropped file and draw it onto the canvas
      const onImageDrop = e => {
        e.preventDefault();
        const imageFile = e.dataTransfer.files[0];
        const imageReader = new FileReader();
        imageReader.onload = event => {
          const image = new Image();
          image.onload = () => {
            context.drawImage(image, 0, 0, 224, 224);
          };
          image.src = event.target.result;
        };
        imageReader.readAsDataURL(imageFile);
      };

      canvas.addEventListener('dragover', e => e.preventDefault(), false);
      canvas.addEventListener('drop', onImageDrop, false);
      grumpy.addEventListener('click', () => addExamples('grumpy'));
      neutral.addEventListener('click', () => addExamples('neutral'));
      document.querySelector('.predict').addEventListener('click', predict);
    </script>
  </body>
</html>
You might have noticed we’re also using the knnClassifier.js file. This file will contain the functions to create the classifier, load the model and handle prediction. Let’s first create the KNN classifier and load the MobileNet model.
// Shared by the training and prediction functions
let knn;
let model;
const loadKnnClassifier = async () => {
  knn = knnClassifier.create();
  console.log("Model is Loading...");
  model = await mobilenet.load();
  console.log("Model Loaded successfully!");
};
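Because loadKnnClassifier runs asynchronously, a button click can arrive before model has been assigned. One possible way to guard against that, sketched here under our own assumptions, is to wrap the loader so that every caller awaits the same promise. The names loadOnce and getModel are hypothetical, and a fake loader stands in for mobilenet.load() so the snippet runs outside the browser.

```javascript
// Wraps an async loader so it runs at most once; every caller
// receives the same promise.
function loadOnce(loader) {
  let pending = null;
  return () => {
    if (pending === null) {
      pending = loader();
    }
    return pending;
  };
}

// Stand-in loader for demonstration. In the page this could be
// () => mobilenet.load().
let loadCalls = 0;
const getModel = loadOnce(async () => {
  loadCalls++;
  return { name: 'fake-model' };
});

const firstRequest = getModel();
const secondRequest = getModel();
console.log(firstRequest === secondRequest, loadCalls); // true 1

// Inside an async click handler, one would then write:
// const model = await getModel();
```

The design choice here is to store the promise rather than the resolved model, so that calls made while loading is still in progress do not trigger a second download.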
Using the KNN Classifier
As mentioned earlier, we need to train the classifier on custom images. The KNN classifier has an addExample method that takes two arguments:
- example – Usually an activation from another model, added to the dataset as an example.
- label – The class name of the example.
Here’s our function to add to the training data:
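// Number of examples added so far: [grumpy, neutral]
let trainingDataSets = [0, 0];
const addExamples = label => {
  // Turn the canvas pixels into a tensor and get MobileNet's activation for it
  const img = tf.browser.fromPixels(canvas);
  const attribute = model.infer(img, 'conv_preds');
  // Store the activation under the chosen label
  knn.addExample(attribute, label);
  context.clearRect(0, 0, canvas.width, canvas.height);
  if (label === 'grumpy') {
    grumpy.innerText = `Grumpy (${++trainingDataSets[0]})`;
  } else {
    neutral.innerText = `Neutral (${++trainingDataSets[1]})`;
  }
  console.log(`Trained classifier with ${label}`);
  img.dispose();
};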
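The function above keeps its own counters in trainingDataSets. The KNN classifier also exposes getClassExampleCount(), which returns an object that maps each label to its number of examples, so the button text could be derived from that instead. The sketch below shows one way to do this; buttonText is a hypothetical helper, and a plain object stands in for the classifier's return value.

```javascript
// Builds a button caption such as "Grumpy (3)" from a label-to-count map.
function buttonText(caption, label, counts) {
  const n = counts[label] || 0; // labels without examples yet count as 0
  return `${caption} (${n})`;
}

// Plain object standing in for the result of knn.getClassExampleCount().
const sampleCounts = { grumpy: 3 };
console.log(buttonText('Grumpy', 'grumpy', sampleCounts));   // Grumpy (3)
console.log(buttonText('Neutral', 'neutral', sampleCounts)); // Neutral (0)

// In the page, after knn.addExample(...):
// grumpy.innerText = buttonText('Grumpy', 'grumpy', knn.getClassExampleCount());
```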
Last but not least is our prediction function:
const predict = async () => {
  // Only predict once at least one class has examples
  if (knn.getNumClasses() > 0) {
    const img = tf.browser.fromPixels(canvas);
    const attribute = model.infer(img, 'conv_preds');
    const prediction = await knn.predictClass(attribute);
    context.clearRect(0, 0, canvas.width, canvas.height);
    console.log(`Prediction: ${prediction.label}`);
    // Free the tensors that are no longer needed
    img.dispose();
    attribute.dispose();
  }
};
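Besides label, the object returned by predictClass also carries a classIndex and a confidences map with one score per label. As a small sketch, the helper below formats that extra information for logging; describePrediction is a hypothetical name, and the sample object is hand-written rather than real classifier output.

```javascript
// Formats a prediction object of the shape { label, classIndex, confidences }.
function describePrediction(prediction) {
  const confidence = prediction.confidences[prediction.label];
  const percent = Math.round(confidence * 100);
  return `${prediction.label} (confidence ${percent}%)`;
}

// Hand-written sample in the same shape as a predictClass result.
const samplePrediction = {
  label: 'grumpy',
  classIndex: 0,
  confidences: { grumpy: 2 / 3, neutral: 1 / 3 },
};
console.log(describePrediction(samplePrediction)); // grumpy (confidence 67%)

// In predict():
// console.log(`Prediction: ${describePrediction(prediction)}`);
```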
Putting the Code Together
The final look of our code is as follows:
let knn;
let model;
// Number of examples added so far: [grumpy, neutral]
let trainingDataSets = [0, 0];

// Create the KNN classifier and load the pre-trained MobileNet model
const loadKnnClassifier = async () => {
  knn = knnClassifier.create();
  console.log("Model is Loading...");
  model = await mobilenet.load();
  console.log("Model Loaded successfully!");
};

// Add the image currently on the canvas to the training data under `label`
const addExamples = label => {
  const img = tf.browser.fromPixels(canvas);
  const attribute = model.infer(img, 'conv_preds');
  knn.addExample(attribute, label);
  context.clearRect(0, 0, canvas.width, canvas.height);
  if (label === 'grumpy') {
    grumpy.innerText = `Grumpy (${++trainingDataSets[0]})`;
  } else {
    neutral.innerText = `Neutral (${++trainingDataSets[1]})`;
  }
  console.log(`Trained classifier with ${label}`);
  img.dispose();
};

// Classify the image currently on the canvas
const predict = async () => {
  if (knn.getNumClasses() > 0) {
    const img = tf.browser.fromPixels(canvas);
    const attribute = model.infer(img, 'conv_preds');
    const prediction = await knn.predictClass(attribute);
    context.clearRect(0, 0, canvas.width, canvas.height);
    console.log(`Prediction: ${prediction.label}`);
    // Free the tensors that are no longer needed
    img.dispose();
    attribute.dispose();
  }
};

loadKnnClassifier();
Testing it Out
Open the HTML document in the browser and drag & drop an image file onto the canvas, then click the Grumpy or Neutral button to categorize it.
Once you’ve trained the classifier with a few images, drag in another image and click the Predict button to get the prediction.
The final console output should be similar to the following:
What’s Next?
In this article, we extended the pre-trained MobileNet model with the help of a KNN classifier using transfer learning. We trained a custom classifier to classify the human expressions in image files as grumpy or neutral. We did it all in the browser, but we used static images to train our model. What if we’re interested in real time custom classification?
Follow along with the next article in the series, where we will extend our model to do custom classification in real time using a webcam.