
Fluffy Animal Detector: Recognizing Custom Objects in the Browser by Transfer Learning in TensorFlow.js

10 Jul 2020, CPOL, 5 min read
In this article, we will build a Fluffy Animal Detector and show a way to leverage a pre-trained Convolutional Neural Network (CNN) model such as MobileNet.
Here we look at transfer learning on the MobileNet v1 architecture, modifying the model, and training the new model.

TensorFlow + JavaScript. The most popular, cutting-edge AI framework now supports the most widely used programming language on the planet, so let’s make magic happen through deep learning right in our web browser, GPU-accelerated via WebGL using TensorFlow.js!

This is the third article in our series of six:

  1. Getting Started With Deep Learning in Your Browser Using TensorFlow.js
  2. Dogs and Pizza: Computer Vision in the Browser With TensorFlow.js
  3. Fluffy Animal Detector: Recognizing Custom Objects in the Browser by Transfer Learning in TensorFlow.js
  4. Face Touch Detection With TensorFlow.js Part 1: Using Real-Time Webcam Data With Deep Learning
  5. Face Touch Detection With TensorFlow.js Part 2: Using BodyPix
  6. Interpreting Hand Gestures and Sign Language in the Webcam With AI Using TensorFlow.js

How about some more computer vision right within a web browser? This time, we will build a Fluffy Animal Detector, and I will show you a way to leverage a pre-trained Convolutional Neural Network (CNN) model like MobileNet. The model comes already trained on more than a million images, an effort that took heavy processing power; through transfer learning with TensorFlow.js, we will bootstrap it to quickly learn to recognize other types of objects for your specific scenario.

Image 1

Starting Point

To start training for custom object recognition based on the pre-trained MobileNet model, we need to:

  • Gather sample images categorized into "fluffy" and "not-fluffy," including some images that aren’t part of the MobileNet pre-trained categories (the images I used in this project are from pexels.com)
  • Import TensorFlow.js
  • Define Fluffy vs. Not-Fluffy category labels
  • Randomly pick and load one of the images
  • Show the prediction result in text
  • Load a pre-trained MobileNet model and classify images

This will be our starting point for this project:

JavaScript
<html>
    <head>
        <title>Fluffy Animal Detector: Recognizing Custom Objects in the Browser by Transfer Learning in TensorFlow.js</title>
        <script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs@2.0.0/dist/tf.min.js"></script>
        <style>
            img {
                object-fit: cover;
            }
        </style>
    </head>
    <body>
        <img id="image" src="" width="224" height="224" />
        <h1 id="status">Loading...</h1>
        <script>
        const fluffy = [
            "web/dalmation.jpg", // https://www.pexels.com/photo/selective-focus-photography-of-woman-holding-adult-dalmatian-dog-1852225/
            "web/maltese.jpg", // https://www.pexels.com/photo/white-long-cot-puppy-on-lap-167085/
            "web/pug.jpg", // https://www.pexels.com/photo/a-wrinkly-pug-sitting-in-a-wooden-table-3475680/
            "web/pomeranians.jpg", // https://www.pexels.com/photo/photo-of-pomeranian-puppies-4065609/
            "web/kitty.jpg", // https://www.pexels.com/photo/eyes-cat-coach-sofa-96938/
            "web/upsidedowncat.jpg", // https://www.pexels.com/photo/silver-tabby-cat-1276553/
            "web/babychick.jpg", // https://www.pexels.com/photo/animal-easter-chick-chicken-5145/
            "web/chickcute.jpg", // https://www.pexels.com/photo/animal-bird-chick-cute-583677/
            "web/beakchick.jpg", // https://www.pexels.com/photo/animal-beak-blur-chick-583675/
            "web/easterchick.jpg", // https://www.pexels.com/photo/cute-animals-easter-chicken-5143/
            "web/closeupchick.jpg", // https://www.pexels.com/photo/close-up-photo-of-chick-2695703/
            "web/yellowcute.jpg", // https://www.pexels.com/photo/nature-bird-yellow-cute-55834/
            "web/chickbaby.jpg", // https://www.pexels.com/photo/baby-chick-58906/
        ];

        const notfluffy = [
            "web/pizzaslice.jpg", // https://www.pexels.com/photo/pizza-on-brown-wooden-board-825661/
            "web/pizzaboard.jpg", // https://www.pexels.com/photo/pizza-on-brown-wooden-board-825661/
            "web/squarepizza.jpg", // https://www.pexels.com/photo/pizza-with-bacon-toppings-1435900/
            "web/pizza.jpg", // https://www.pexels.com/photo/pizza-on-plate-with-slicer-and-fork-2260200/
            "web/salad.jpg", // https://www.pexels.com/photo/vegetable-salad-on-plate-1059905/
            "web/salad2.jpg", // https://www.pexels.com/photo/vegetable-salad-with-wheat-bread-on-the-side-1213710/
        ];

        // Create the ultimate, combined list of images
        const images = fluffy.concat( notfluffy );

        // Newly defined Labels
        const labels = [
            "So Cute & Fluffy!",
            "Not Fluffy"
        ];

        function pickImage() {
            document.getElementById( "image" ).src = images[ Math.floor( Math.random() * images.length ) ];
        }

        function setText( text ) {
            document.getElementById( "status" ).innerText = text;
        }

        async function predictImage() {
            let result = tf.tidy( () => {
                const img = tf.browser.fromPixels( document.getElementById( "image" ) ).toFloat();
                const normalized = img.div( 127.5 ).sub( 1 ); // Normalize from [0,255] to [-1,1]
                const input = normalized.reshape( [ 1, 224, 224, 3 ] );
                return model.predict( input );
            });
            let prediction = await result.data();
            result.dispose();
            // Get the index of the highest value in the prediction
            let id = prediction.indexOf( Math.max( ...prediction ) );
            setText( labels[ id ] );
        }

        // Mobilenet v1 0.25 224x224 model
        const mobilenet = "https://storage.googleapis.com/tfjs-models/tfjs/mobilenet_v1_0.25_224/model.json";

        let model = null;

        (async () => {
            // Load the model
            model = await tf.loadLayersModel( mobilenet );
            setInterval( pickImage, 5000 );
            document.getElementById( "image" ).onload = predictImage;
        })();
        </script>
    </body>
</html>

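You can change the array of images to match the filenames of your test images. Once opened in a browser, this page will show a randomly selected image every five seconds. At this stage, the model still outputs MobileNet’s original 1000 categories, so the status text won’t show meaningful labels until we replace the classification layers later on.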
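The two numeric steps inside predictImage are easy to check outside the browser. The plain-JavaScript sketch below needs no TensorFlow.js; normalizePixel and argMax are illustrative helper names, not part of the article’s code. It applies the same pixel normalization and picks the index of the highest score. The divisor matters: 127.5 maps 255 to exactly 1, while 127 would overshoot slightly.

```javascript
// Map a pixel value from [0, 255] to [-1, 1], the input range noted in predictImage's comment.
function normalizePixel( value ) {
    return value / 127.5 - 1;
}

// Index of the largest score; same result as prediction.indexOf( Math.max( ...prediction ) ),
// but without spreading a large array into function arguments.
function argMax( scores ) {
    let best = 0;
    for( let i = 1; i < scores.length; i++ ) {
        if( scores[ i ] > scores[ best ] ) {
            best = i;
        }
    }
    return best;
}

console.log( normalizePixel( 0 ), normalizePixel( 127.5 ), normalizePixel( 255 ) ); // -1 0 1
console.log( 255 / 127 - 1 ); // a little above 1
console.log( argMax( new Float32Array( [ 0.1, 0.7, 0.2 ] ) ) ); // 1
```

Like indexOf, the loop keeps the first index when two scores tie.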

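Before going further, note that for this project to run properly, the web page and images must be served from a web server because of HTML5 canvas security restrictions: browsers generally won’t let a script read the pixels of images opened directly from the file system, and tf.browser.fromPixels needs to read them. See a full explanation in the previous article.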
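If you don’t already have a local web server, one option (assuming Node.js is installed; this is a sketch, not the setup from the previous article) is a small static file server built only on Node’s http, fs, and path modules. The file name serve.js and the SERVE_PORT environment variable are hypothetical choices for this example.

```javascript
// Minimal static file server for local development (Node.js built-in modules only).
const http = require( "http" );
const fs = require( "fs" );
const path = require( "path" );

// Content types for the kinds of files this project serves.
const MIME_TYPES = {
    ".html": "text/html; charset=utf-8",
    ".js": "text/javascript; charset=utf-8",
    ".json": "application/json",
    ".jpg": "image/jpeg",
    ".jpeg": "image/jpeg",
    ".png": "image/png",
};

function contentTypeFor( filePath ) {
    return MIME_TYPES[ path.extname( filePath ).toLowerCase() ] || "application/octet-stream";
}

// Map a request URL onto a file inside rootDir; returns null if the path would escape rootDir.
function resolveSafePath( rootDir, requestUrl ) {
    const root = path.resolve( rootDir );
    let urlPath;
    try {
        urlPath = decodeURIComponent( requestUrl.split( "?" )[ 0 ] );
    } catch {
        return null; // malformed percent-encoding
    }
    if( urlPath.endsWith( "/" ) ) {
        urlPath += "index.html";
    }
    const fullPath = path.resolve( root, "." + urlPath );
    return fullPath.startsWith( root + path.sep ) ? fullPath : null;
}

function createStaticServer( rootDir ) {
    return http.createServer( ( req, res ) => {
        const filePath = resolveSafePath( rootDir, req.url );
        if( !filePath ) {
            res.writeHead( 403 );
            res.end( "Forbidden" );
            return;
        }
        fs.readFile( filePath, ( err, data ) => {
            if( err ) {
                res.writeHead( 404 );
                res.end( "Not found" );
                return;
            }
            res.writeHead( 200, { "Content-Type": contentTypeFor( filePath ) } );
            res.end( data );
        } );
    } );
}

// Start listening only when a port is given, e.g. SERVE_PORT=8080 node serve.js
const port = Number( process.env.SERVE_PORT );
if( port ) {
    createStaticServer( __dirname ).listen( port, () => {
        console.log( `Serving ${ __dirname } at http://localhost:${ port }/` );
    } );
}
```

Saved as serve.js next to your HTML page and the web folder, it can be started on macOS or Linux with `SERVE_PORT=8080 node serve.js`; the page is then available at http://localhost:8080/ followed by its file name. The path check refuses requests such as `/../secret.txt` that would reach outside the project folder.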

Transfer Learning on MobileNet v1 Architecture

It’s important to understand the neural network architecture of the MobileNet model before we can apply any transfer learning.

MobileNets lend themselves well to transfer learning: they pass the input through a straightforward, sequential stack of convolutional layers and then hand the result to a final set of classification layers, which produce the output across 1000 categories.

Here is an abridged printout of the architecture from model.summary():

_________________________________________________________________
Layer (type)                 Output shape              Param #
=================================================================
input_1 (InputLayer)         [null,224,224,3]          0
_________________________________________________________________
conv1 (Conv2D)               [null,112,112,8]          216
_________________________________________________________________
conv1_bn (BatchNormalization [null,112,112,8]          32
_________________________________________________________________
conv1_relu (Activation)      [null,112,112,8]          0
_________________________________________________________________
....
_________________________________________________________________
conv_pw_13_bn (BatchNormaliz [null,7,7,256]            1024
_________________________________________________________________
conv_pw_13_relu (Activation) [null,7,7,256]            0
_________________________________________________________________
global_average_pooling2d_1 ( [null,256]                0
_________________________________________________________________
reshape_1 (Reshape)          [null,1,1,256]            0
_________________________________________________________________
dropout (Dropout)            [null,1,1,256]            0
_________________________________________________________________
conv_preds (Conv2D)          [null,1,1,1000]           257000
_________________________________________________________________
act_softmax (Activation)     [null,1,1,1000]           0
_________________________________________________________________
reshape_2 (Reshape)          [null,1000]               0
=================================================================
Total params: 475544
Trainable params: 470072
Non-trainable params: 5472

The layers near the top of this listing, whose names begin with conv (everything from conv1 through conv_pw_13_relu), are the network layers that look at the spatial information in the pixels. Their output is condensed by global_average_pooling2d_1 into a 256-value feature vector, the start of the classification stage, which finally passes through the conv_preds layer that outputs the 1000 original categories MobileNet was trained to predict.
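The totals in that printout can be reproduced by hand, which is a useful sanity check on what each layer holds. The sketch below is plain JavaScript and assumes the standard MobileNet v1 layout (a 3x3 stem convolution followed by 13 depthwise/pointwise pairs, with no biases except in conv_preds) with the 0.25 width multiplier applied to every channel count; the channel list is my reconstruction of that layout rather than something read from model.json.

```javascript
// Channel counts for MobileNet v1 with width multiplier 0.25 (standard v1 layout scaled by 0.25):
// conv1 outputs 8 channels, followed by 13 depthwise (3x3) + pointwise (1x1) pairs.
const CONV1_CHANNELS = 8;
const POINTWISE_OUTPUTS = [ 16, 32, 32, 64, 64, 128, 128, 128, 128, 128, 128, 256, 256 ];
const NUM_CLASSES = 1000;

function countMobileNetParams() {
    let convWeights = 3 * 3 * 3 * CONV1_CHANNELS; // conv1: 3x3 kernel over 3 RGB channels, no bias
    let bnChannels = CONV1_CHANNELS;              // conv1_bn normalizes conv1's 8 channels
    let inChannels = CONV1_CHANNELS;
    for( const outChannels of POINTWISE_OUTPUTS ) {
        convWeights += 3 * 3 * inChannels;        // depthwise 3x3: one filter per input channel, no bias
        convWeights += inChannels * outChannels;  // pointwise 1x1: mixes channels, no bias
        bnChannels += inChannels + outChannels;   // one BatchNormalization after each of the two convs
        inChannels = outChannels;
    }
    // Each BatchNormalization channel has gamma and beta (trainable)
    // plus a moving mean and moving variance (non-trainable).
    const bnTrainable = 2 * bnChannels;
    const bnNonTrainable = 2 * bnChannels;
    // conv_preds: 1x1 conv from the 256 pooled features to 1000 classes, with bias.
    const convPreds = inChannels * NUM_CLASSES + NUM_CLASSES;
    const trainable = convWeights + bnTrainable + convPreds;
    return { total: trainable + bnNonTrainable, trainable, nonTrainable: bnNonTrainable, convPreds };
}

console.log( countMobileNetParams() );
// total 475544, trainable 470072, nonTrainable 5472, convPreds 257000
```

The result matches the summary: 5,472 non-trainable parameters are the BatchNormalization moving means and variances, and conv_preds alone accounts for 257,000 of the 475,544 parameters, more than half of the model. That is exactly the part we are about to replace.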

We are going to intercept this model right before the conv_preds layer (i.e., at the "dropout" layer), attach new classification layers to the "top," and train only these layers to predict just two categories, Fluffy vs. Not-Fluffy, while keeping the pre-trained spatial layers intact.

Let’s get to it!

Modifying the Model

After loading the pre-trained MobileNet model, we can find our "bottleneck" layer and create a new, truncated base model:

JavaScript
const bottleneck = model.getLayer( "dropout" ); // This is the final layer before the conv_preds pre-trained classification layer
const baseModel = tf.model({
    inputs: model.inputs,
    outputs: bottleneck.output
});

Next, let’s freeze all of the "pre-bottleneck" layers to preserve their pre-trained weights, so that we can build on all the processing power that has already gone into this chunk of the model.

JavaScript
// Freeze the convolutional base
for( const layer of baseModel.layers ) {
    layer.trainable = false;
}

Then we can attach our custom classification head, consisting of multiple dense layers, to the output of the base model to create a new TensorFlow model that is ready for training.

The final dense layer contains only two units, corresponding to the Fluffy vs. Not-Fluffy categories, and uses a softmax activation, which normalizes the outputs so that they sum to 1.0. This means we can treat each output as the model’s confidence for that category.
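The softmax step itself is short. Here is a plain-JavaScript sketch of the same formula, applied to hypothetical raw scores arriving at the final layer’s activation; subtracting the largest score first is a standard trick that avoids overflow without changing the result.

```javascript
// softmax(z)[i] = exp(z[i]) / sum_j exp(z[j]): positive outputs that add up to 1.
function softmax( logits ) {
    const maxLogit = Math.max( ...logits ); // subtracting the max avoids overflow in Math.exp
    const exps = logits.map( ( z ) => Math.exp( z - maxLogit ) );
    const sum = exps.reduce( ( a, b ) => a + b, 0 );
    return exps.map( ( e ) => e / sum );
}

const labels = [ "So Cute & Fluffy!", "Not Fluffy" ];
const probs = softmax( [ 2.0, 0.5 ] ); // hypothetical raw scores for [ fluffy, not fluffy ]
const best = probs.indexOf( Math.max( ...probs ) );
console.log( labels[ best ], probs[ best ].toFixed( 3 ) ); // So Cute & Fluffy! 0.818
```

With two categories, a score gap of 1.5 already translates into roughly 82% confidence for the higher one.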

JavaScript
// Add a classification head
const newHead = tf.sequential();
newHead.add( tf.layers.flatten( {
    inputShape: baseModel.outputs[ 0 ].shape.slice( 1 )
} ) );
newHead.add( tf.layers.dense( { units: 100, activation: 'relu' } ) );
newHead.add( tf.layers.dense( { units: 100, activation: 'relu' } ) );
newHead.add( tf.layers.dense( { units: 10, activation: 'relu' } ) );
newHead.add( tf.layers.dense( {
    units: 2,
    kernelInitializer: 'varianceScaling',
    useBias: false,
    activation: 'softmax'
} ) );
// Build the new model
const newOutput = newHead.apply( baseModel.outputs[ 0 ] );
const newModel = tf.model( { inputs: baseModel.inputs, outputs: newOutput } );
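To get a feel for how little we are actually training, we can count the head’s parameters by hand. This plain-JavaScript sketch (denseParams is an illustrative helper) uses the layer sizes from the code above and the totals from the model.summary() printout earlier.

```javascript
// Parameters in a dense layer: one weight per (input, unit) pair, plus one bias per unit.
function denseParams( inputs, units, useBias = true ) {
    return inputs * units + ( useBias ? units : 0 );
}

// The bottleneck output shape [null, 1, 1, 256] flattens to 256 features per image.
const flattenedFeatures = 1 * 1 * 256;

const headLayers = [
    denseParams( flattenedFeatures, 100 ), // dense, 100 units, relu
    denseParams( 100, 100 ),               // dense, 100 units, relu
    denseParams( 100, 10 ),                // dense, 10 units, relu
    denseParams( 10, 2, false ),           // dense, 2 units, softmax, useBias: false
];
const headTotal = headLayers.reduce( ( sum, n ) => sum + n, 0 );

// Frozen base: total MobileNet params from model.summary() minus conv_preds,
// the only removed layer that has weights.
const frozenBase = 475544 - 257000;

console.log( headLayers, headTotal, frozenBase ); // [ 25700, 10100, 1010, 20 ] 36830 218544
```

So fitting updates about 37 thousand weights in the new head, while the roughly 219 thousand parameters in the frozen base are excluded from gradient updates.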

To keep the code clean, we can put it into a function and run it right after we load the MobileNet model:

JavaScript
function createTransferModel( model ) {
    // Create the truncated base model (remove the "top" layers, classification + bottleneck layers)
    const bottleneck = model.getLayer( "dropout" ); // This is the final layer before the conv_preds pre-trained classification layer
    const baseModel = tf.model({
        inputs: model.inputs,
        outputs: bottleneck.output
    });
    // Freeze the convolutional base
    for( const layer of baseModel.layers ) {
        layer.trainable = false;
    }
    // Add a classification head
    const newHead = tf.sequential();
    newHead.add( tf.layers.flatten( {
        inputShape: baseModel.outputs[ 0 ].shape.slice( 1 )
    } ) );
    newHead.add( tf.layers.dense( { units: 100, activation: 'relu' } ) );
    newHead.add( tf.layers.dense( { units: 100, activation: 'relu' } ) );
    newHead.add( tf.layers.dense( { units: 10, activation: 'relu' } ) );
    newHead.add( tf.layers.dense( {
        units: 2,
        kernelInitializer: 'varianceScaling',
        useBias: false,
        activation: 'softmax'
    } ) );
    // Build the new model
    const newOutput = newHead.apply( baseModel.outputs[ 0 ] );
    const newModel = tf.model( { inputs: baseModel.inputs, outputs: newOutput } );
    return newModel;
}

...

(async () => {
    // Load the model
    model = await tf.loadLayersModel( mobilenet );
    model = createTransferModel( model );
    setInterval( pickImage, 2000 );
    document.getElementById( "image" ).onload = predictImage;
})();

Training the New Model

We’re almost there. There’s only one more step to go, which is to train our new TensorFlow model on our custom training data.

To generate training data tensors from the custom images, let’s create a function that loads an image into the page’s image element and returns it as a normalized tensor:

JavaScript
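async function getTrainingImage( url ) {
    return new Promise( ( resolve, reject ) => {
        const imageElement = document.getElementById( "image" );
        // Set up the handlers first, then start loading by assigning src
        imageElement.onload = () => {
            // tf.tidy disposes the intermediate tensors and keeps only the returned one
            const normalized = tf.tidy( () =>
                tf.browser.fromPixels( imageElement ).toFloat().div( 127.5 ).sub( 1 ) // [0,255] -> [-1,1]
            );
            resolve( normalized );
        };
        imageElement.onerror = () => reject( new Error( "Failed to load " + url ) );
        imageElement.src = url;
    });
}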

Now we can use this function to create our stacks of input and target tensors. You might remember these as the xs and ys we used for training in the first article of the series. We will use only about half of the images from each category for training, holding the rest back so we can check that the new model also makes sensible predictions on images it hasn’t seen.

JavaScript
// Setup training data
const imageSamples = [];
const targetSamples = [];
for( let i = 0; i < fluffy.length / 2; i++ ) {
    let result = await getTrainingImage( fluffy[ i ] );
    imageSamples.push( result );
    targetSamples.push( tf.tensor1d( [ 1, 0 ] ) );
}
for( let i = 0; i < notfluffy.length / 2; i++ ) {
    let result = await getTrainingImage( notfluffy[ i ] );
    imageSamples.push( result );
    targetSamples.push( tf.tensor1d( [ 0, 1 ] ) );
}
const xs = tf.stack( imageSamples );
const ys = tf.stack( targetSamples );
tf.dispose( [ imageSamples, targetSamples ] );
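The split performed by those two loops can be checked without a browser. This plain-JavaScript sketch (splitHalf, buildSamples, and the placeholder file names are hypothetical) mirrors the `i < length / 2` condition and the one-hot targets, using the same list sizes as the article’s 13 fluffy and 6 not-fluffy images.

```javascript
// Split one category's list like the training loops do: indices with i < length / 2
// are used for training, the remaining images are held out.
function splitHalf( files ) {
    const cut = Math.ceil( files.length / 2 ); // `i < n / 2` admits ceil( n / 2 ) indices
    return { train: files.slice( 0, cut ), heldOut: files.slice( cut ) };
}

// One-hot targets, matching tf.tensor1d( [ 1, 0 ] ) and tf.tensor1d( [ 0, 1 ] ) in the code above.
const FLUFFY_TARGET = [ 1, 0 ];
const NOT_FLUFFY_TARGET = [ 0, 1 ];

function buildSamples( fluffy, notfluffy ) {
    const f = splitHalf( fluffy );
    const n = splitHalf( notfluffy );
    return {
        train: [
            ...f.train.map( ( file ) => ( { file, target: FLUFFY_TARGET } ) ),
            ...n.train.map( ( file ) => ( { file, target: NOT_FLUFFY_TARGET } ) ),
        ],
        heldOut: [ ...f.heldOut, ...n.heldOut ],
    };
}

// Placeholder file names with the same counts as the article's lists.
const fluffy = Array.from( { length: 13 }, ( _, i ) => `fluffy${ i }.jpg` );
const notfluffy = Array.from( { length: 6 }, ( _, i ) => `notfluffy${ i }.jpg` );
const samples = buildSamples( fluffy, notfluffy );
console.log( samples.train.length, samples.heldOut.length ); // 10 9
```

So 10 images are used for training and 9 are held out. Keep in mind that pickImage still draws from the full images list afterward, so some of the predictions you see on screen are for images the model was trained on.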

Finally, we compile the model and fit it to the data. Thanks to all of MobileNet’s pre-training, this time we need only about 30 epochs (instead of 100) to reliably distinguish between the categories.

JavaScript
model.compile( { loss: "meanSquaredError", optimizer: "adam", metrics: [ "acc" ] } );

// Train the model on new image samples
await model.fit( xs, ys, {
    epochs: 30,
    shuffle: true,
    callbacks: {
        onEpochEnd: ( epoch, logs ) => {
            console.log( "Epoch #", epoch, logs );
        }
    }
});
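Each onEpochEnd log reports loss and acc values. As a rough, framework-free illustration (my own helpers, not TensorFlow.js internals), this is what those two numbers amount to for one-hot targets like ours: the mean of the squared differences across all outputs, and, in effect for a two-unit softmax output, the fraction of samples whose highest-scoring output matches the target. The predictions below are made-up values.

```javascript
// Mean squared error over every output of every sample in the batch.
function meanSquaredError( targets, predictions ) {
    let sum = 0;
    let count = 0;
    targets.forEach( ( target, i ) => {
        target.forEach( ( value, j ) => {
            sum += ( value - predictions[ i ][ j ] ) ** 2;
            count++;
        } );
    } );
    return sum / count;
}

// Index of the largest value in an array.
const argMax = ( values ) => values.indexOf( Math.max( ...values ) );

// Fraction of samples where the predicted category matches the one-hot target.
function accuracy( targets, predictions ) {
    const correct = targets.filter( ( target, i ) => argMax( target ) === argMax( predictions[ i ] ) ).length;
    return correct / targets.length;
}

// Made-up predictions for three samples: fluffy, not fluffy, fluffy.
const ys = [ [ 1, 0 ], [ 0, 1 ], [ 1, 0 ] ];
const preds = [ [ 0.9, 0.1 ], [ 0.3, 0.7 ], [ 0.4, 0.6 ] ];
console.log( meanSquaredError( ys, preds ).toFixed( 3 ), accuracy( ys, preds ).toFixed( 3 ) ); // 0.153 0.667
```

The third sample is misclassified, which costs a third of the accuracy but only nudges the loss, since MSE also rewards getting closer to the target without crossing the 0.5 line.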

Applying Marie Kondo’s KonMari method to code, let’s spark some joy by moving all of the above code into a function and calling it at startup:

JavaScript
async function trainModel() {
    setText( "Training..." );

    // Setup training data
    const imageSamples = [];
    const targetSamples = [];
    for( let i = 0; i < fluffy.length / 2; i++ ) {
        let result = await getTrainingImage( fluffy[ i ] );
        imageSamples.push( result );
        targetSamples.push( tf.tensor1d( [ 1, 0 ] ) );
    }
    for( let i = 0; i < notfluffy.length / 2; i++ ) {
        let result = await getTrainingImage( notfluffy[ i ] );
        imageSamples.push( result );
        targetSamples.push( tf.tensor1d( [ 0, 1 ] ) );
    }
    const xs = tf.stack( imageSamples );
    const ys = tf.stack( targetSamples );
    tf.dispose( [ imageSamples, targetSamples ] );

    model.compile( { loss: "meanSquaredError", optimizer: "adam", metrics: [ "acc" ] } );

    // Train the model on new image samples
    await model.fit( xs, ys, {
        epochs: 30,
        shuffle: true,
        callbacks: {
            onEpochEnd: ( epoch, logs ) => {
                console.log( "Epoch #", epoch, logs );
            }
        }
    });
}
...
(async () => {
    // Load the model
    model = await tf.loadLayersModel( mobilenet );
    model = createTransferModel( model );
    await trainModel();
    setInterval( pickImage, 2000 );
    document.getElementById( "image" ).onload = predictImage;
})();

Running Object Recognition

With all the pieces in place, we should be able to run our Fluffy Animal Detector and see it learn to recognize the fluffiness! Check out some of these results from my laptop:

Image 2

Image 3

Image 4

Finish Line

To wrap up our project, here is the final code:

JavaScript
<html>
    <head>
        <title>Fluffy Animal Detector: Recognizing Custom Objects in the Browser by Transfer Learning in TensorFlow.js</title>
        <script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs@2.0.0/dist/tf.min.js"></script>
        <style>
            img {
                object-fit: cover;
            }
        </style>
    </head>
    <body>
        <img id="image" src="" width="224" height="224" />
        <h1 id="status">Loading...</h1>
        <script>
        const fluffy = [
            "web/dalmation.jpg", // https://www.pexels.com/photo/selective-focus-photography-of-woman-holding-adult-dalmatian-dog-1852225/
            "web/maltese.jpg", // https://www.pexels.com/photo/white-long-cot-puppy-on-lap-167085/
            "web/pug.jpg", // https://www.pexels.com/photo/a-wrinkly-pug-sitting-in-a-wooden-table-3475680/
            "web/pomeranians.jpg", // https://www.pexels.com/photo/photo-of-pomeranian-puppies-4065609/
            "web/kitty.jpg", // https://www.pexels.com/photo/eyes-cat-coach-sofa-96938/
            "web/upsidedowncat.jpg", // https://www.pexels.com/photo/silver-tabby-cat-1276553/
            "web/babychick.jpg", // https://www.pexels.com/photo/animal-easter-chick-chicken-5145/
            "web/chickcute.jpg", // https://www.pexels.com/photo/animal-bird-chick-cute-583677/
            "web/beakchick.jpg", // https://www.pexels.com/photo/animal-beak-blur-chick-583675/
            "web/easterchick.jpg", // https://www.pexels.com/photo/cute-animals-easter-chicken-5143/
            "web/closeupchick.jpg", // https://www.pexels.com/photo/close-up-photo-of-chick-2695703/
            "web/yellowcute.jpg", // https://www.pexels.com/photo/nature-bird-yellow-cute-55834/
            "web/chickbaby.jpg", // https://www.pexels.com/photo/baby-chick-58906/
        ];

        const notfluffy = [
            "web/pizzaslice.jpg", // https://www.pexels.com/photo/pizza-on-brown-wooden-board-825661/
            "web/pizzaboard.jpg", // https://www.pexels.com/photo/pizza-on-brown-wooden-board-825661/
            "web/squarepizza.jpg", // https://www.pexels.com/photo/pizza-with-bacon-toppings-1435900/
            "web/pizza.jpg", // https://www.pexels.com/photo/pizza-on-plate-with-slicer-and-fork-2260200/
            "web/salad.jpg", // https://www.pexels.com/photo/vegetable-salad-on-plate-1059905/
            "web/salad2.jpg", // https://www.pexels.com/photo/vegetable-salad-with-wheat-bread-on-the-side-1213710/
        ];

        // Create the ultimate, combined list of images
        const images = fluffy.concat( notfluffy );

        // Newly defined Labels
        const labels = [
            "So Cute & Fluffy!",
            "Not Fluffy"
        ];

        function pickImage() {
            document.getElementById( "image" ).src = images[ Math.floor( Math.random() * images.length ) ];
        }

        function setText( text ) {
            document.getElementById( "status" ).innerText = text;
        }

        async function predictImage() {
            let result = tf.tidy( () => {
                const img = tf.browser.fromPixels( document.getElementById( "image" ) ).toFloat();
                const normalized = img.div( 127.5 ).sub( 1 ); // Normalize from [0,255] to [-1,1]
                const input = normalized.reshape( [ 1, 224, 224, 3 ] );
                return model.predict( input );
            });
            let prediction = await result.data();
            result.dispose();
            // Get the index of the highest value in the prediction
            let id = prediction.indexOf( Math.max( ...prediction ) );
            setText( labels[ id ] );
        }

        function createTransferModel( model ) {
            // Create the truncated base model (remove the "top" layers, classification + bottleneck layers)
            const bottleneck = model.getLayer( "dropout" ); // This is the final layer before the conv_preds pre-trained classification layer
            const baseModel = tf.model({
                inputs: model.inputs,
                outputs: bottleneck.output
            });
            // Freeze the convolutional base
            for( const layer of baseModel.layers ) {
                layer.trainable = false;
            }
            // Add a classification head
            const newHead = tf.sequential();
            newHead.add( tf.layers.flatten( {
                inputShape: baseModel.outputs[ 0 ].shape.slice( 1 )
            } ) );
            newHead.add( tf.layers.dense( { units: 100, activation: 'relu' } ) );
            newHead.add( tf.layers.dense( { units: 100, activation: 'relu' } ) );
            newHead.add( tf.layers.dense( { units: 10, activation: 'relu' } ) );
            newHead.add( tf.layers.dense( {
                units: 2,
                kernelInitializer: 'varianceScaling',
                useBias: false,
                activation: 'softmax'
            } ) );
            // Build the new model
            const newOutput = newHead.apply( baseModel.outputs[ 0 ] );
            const newModel = tf.model( { inputs: baseModel.inputs, outputs: newOutput } );
            return newModel;
        }

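        async function getTrainingImage( url ) {
            return new Promise( ( resolve, reject ) => {
                const imageElement = document.getElementById( "image" );
                // Set up the handlers first, then start loading by assigning src
                imageElement.onload = () => {
                    // tf.tidy disposes the intermediate tensors and keeps only the returned one
                    const normalized = tf.tidy( () =>
                        tf.browser.fromPixels( imageElement ).toFloat().div( 127.5 ).sub( 1 ) // [0,255] -> [-1,1]
                    );
                    resolve( normalized );
                };
                imageElement.onerror = () => reject( new Error( "Failed to load " + url ) );
                imageElement.src = url;
            });
        }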

        async function trainModel() {
            setText( "Training..." );

            // Setup training data
            const imageSamples = [];
            const targetSamples = [];
            for( let i = 0; i < fluffy.length / 2; i++ ) {
                let result = await getTrainingImage( fluffy[ i ] );
                imageSamples.push( result );
                targetSamples.push( tf.tensor1d( [ 1, 0 ] ) );
            }
            for( let i = 0; i < notfluffy.length / 2; i++ ) {
                let result = await getTrainingImage( notfluffy[ i ] );
                imageSamples.push( result );
                targetSamples.push( tf.tensor1d( [ 0, 1 ] ) );
            }
            const xs = tf.stack( imageSamples );
            const ys = tf.stack( targetSamples );
            tf.dispose( [ imageSamples, targetSamples ] );

            model.compile( { loss: "meanSquaredError", optimizer: "adam", metrics: [ "acc" ] } );

            // Train the model on new image samples
            await model.fit( xs, ys, {
                epochs: 30,
                shuffle: true,
                callbacks: {
                    onEpochEnd: ( epoch, logs ) => {
                        console.log( "Epoch #", epoch, logs );
                    }
                }
            });
        }

        // Mobilenet v1 0.25 224x224 model
        const mobilenet = "https://storage.googleapis.com/tfjs-models/tfjs/mobilenet_v1_0.25_224/model.json";

        let model = null;

        (async () => {
            // Load the model
            model = await tf.loadLayersModel( mobilenet );
            model = createTransferModel( model );
            await trainModel();
            setInterval( pickImage, 2000 );
            document.getElementById( "image" ).onload = predictImage;
        })();
        </script>
    </body>
</html>

What’s Next? Can We Detect Faces?

Are you amazed yet at what’s possible with deep learning inside a web page, and at how fast and easy it is? Next up, we’ll use the browser’s HTML5 webcam API to train and run predictions on real-time images.

Follow along with the next article in this series, Face Touch Detection with TensorFlow.js Part 1: Using Real-Time Webcam Data with Deep Learning.

Image 5

License

This article, along with any associated source code and files, is licensed under The Code Project Open License (CPOL)