
Preparing Data for AI Fashion Classification

In this series of articles, we’ll show you how to use transfer learning to fine-tune the VGG19 model to classify fashion clothing categories.
In this article, we show you how to load the DeepFashion dataset and how to restructure the VGG16 model to fit our clothing classification task.

Introduction

The availability of datasets like DeepFashion opens up new possibilities for the fashion industry. In this series of articles, we’ll showcase an AI-powered deep learning system that can revolutionize the fashion design industry by helping us better understand customers’ needs.

In this project, we’ll use:

- Jupyter Notebook as our development environment
- TensorFlow and Keras for loading the data and building and fine-tuning the model
- NumPy and Matplotlib for handling and visualizing the images
- A subset of the DeepFashion dataset
- The VGG16 network, pre-trained on ImageNet

We are assuming that you are familiar with the concepts of deep learning, as well as with Jupyter Notebooks and TensorFlow. If you’re new to Jupyter Notebooks, start with this tutorial. You are welcome to download the project code.

In the previous article, we discussed the data subset to be used and formulated the problem. In this article, we’ll apply transfer learning to the VGG16 deep network to classify clothes in images from the DeepFashion dataset. We’ll fine-tune the VGG16 pre-trained model to fit the task of classifying clothes into 15 different categories. The network will be trained on a subset containing 9,935 images.

Loading Dataset

In this project, we’ll use TensorFlow and Keras to fine-tune VGG16, as Keras provides easy-to-use tools for loading data, loading pre-trained models, and fine-tuning. The ImageDataGenerator class will help us load, normalize, resize, and rescale the data.

To start, let’s import the libraries that we’ll need:

Python
import os
import matplotlib.pyplot as plt
import matplotlib.image as img
import tensorflow.keras as keras
import numpy as np
import tensorflow as tf
# Allow TensorFlow to allocate GPU memory as needed instead of all at once
config = tf.compat.v1.ConfigProto()
config.gpu_options.allow_growth = True
tf.compat.v1.keras.backend.set_session(tf.compat.v1.Session(config=config))

Now that we have the basics, let’s import ImageDataGenerator, add our data directory, and start loading the training and validation data. The validation data will amount to only 10% of our training data: we only need the validation data to fine-tune the hyperparameters of the VGG16 during training. We set our batch size in this part of the code. You can set its value to fit your machine capabilities and memory.

Python
datasetdir = r'C:\Users\myuser\Desktop\DeepFashion\Train'
os.chdir(datasetdir)

from tensorflow.keras.preprocessing.image import ImageDataGenerator
# vgg16 provides the preprocess_input function used when loading the data
from tensorflow.keras.applications import vgg16

batch_size = 3

def DataLoad(shape, preprocessing): 
    '''Create the training and validation datasets for 
    a given image shape.
    '''
    imgdatagen = ImageDataGenerator(
        preprocessing_function = preprocessing,
        horizontal_flip = True, 
        validation_split = 0.1,
    )

    height, width = shape

    train_dataset = imgdatagen.flow_from_directory(
        os.getcwd(),
        target_size = (height, width), 
        classes = ['Blazer', 'Blouse', 'Cardigan', 'Dress', 'Jacket',
                 'Jeans', 'Jumpsuit', 'Romper', 'Shorts', 'Skirts', 'Sweater', 'Sweatpants', 'Tank', 'Tee', 'Top'],
        batch_size = batch_size,
        subset = 'training', 
    )

    val_dataset = imgdatagen.flow_from_directory(
        os.getcwd(),
        target_size = (height, width), 
        classes = ['Blazer', 'Blouse', 'Cardigan', 'Dress', 'Jacket',
                 'Jeans', 'Jumpsuit', 'Romper', 'Shorts', 'Skirts', 'Sweater', 'Sweatpants', 'Tank', 'Tee', 'Top'],
        batch_size = batch_size,
        subset = 'validation'
    )
    return train_dataset, val_dataset

Now that the DataLoad function is set, let’s use it to extract our training and validation data and resize the images to the shape that fits our pre-trained model: 224 x 224 x 3.

Python
train_dataset, val_dataset = DataLoad((224,224), preprocessing=vgg16.preprocess_input)

We can now call the next function on the iterators returned by flow_from_directory to load one batch of images at a time. Using next, you’ll have your training images and labels stored in the X_train and y_train variables, respectively. You can apply the same call to the validation and testing data.

Python
# Load a single batch of training images and their one-hot labels
X_train, y_train = next(train_dataset)

As you can see from the flow_from_directory output, we have 7,656 training images belonging to 15 different categories, as well as 842 validation images.
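It’s a good idea to take a quick look at a batch before training. The snippet below is a minimal sketch for doing that in the notebook; the plot_batch helper is ours, not part of the original code, and it approximately undoes vgg16.preprocess_input so the images display with natural colors.

Python
# Class names in the same order as the one-hot label indices
class_names = list(train_dataset.class_indices.keys())

def plot_batch(images, labels):
    '''Display a batch of images with their class names.'''
    fig, axes = plt.subplots(1, len(images), figsize=(4 * len(images), 4))
    # vgg16.preprocess_input converts RGB to BGR and subtracts the ImageNet
    # channel means, so reverse both steps (approximately) for display
    mean = np.array([103.939, 116.779, 123.68])
    for ax, image, label in zip(axes, images, labels):
        ax.imshow(np.clip((image + mean)[..., ::-1] / 255.0, 0, 1))
        ax.set_title(class_names[np.argmax(label)])
        ax.axis('off')
    plt.show()

plot_batch(X_train, y_train)

# The same call works for a validation batch:
X_val, y_val = next(val_dataset)
plot_batch(X_val, y_val)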

Loading Pretrained Model (VGG16) from Keras

Time to load the VGG16 network from Keras and display its baseline architecture:

Python
vgg16 = keras.applications.vgg16
conv_model = vgg16.VGG16(weights='imagenet', include_top=False)
conv_model.summary()

Next, we reload the network with its ImageNet weights, this time specifying an input shape that matches our resized images, so that we can use the pre-trained weights during transfer learning.

Python
conv_model = vgg16.VGG16(weights='imagenet', include_top=False, input_shape=(224,224,3))

With the network baseline and its corresponding weights loaded, let’s start restructuring VGG16 to classify the 15 different clothing categories. To do so, we’ll add a flattening layer, three dense layers of 100 nodes each, and a final dense layer of 15 nodes, one per clothing category, with a softmax activation that outputs the class (category) probabilities.

Python
# flatten the output of the convolutional part: 
x = keras.layers.Flatten()(conv_model.output)
# three hidden layers
x = keras.layers.Dense(100, activation='relu')(x)
x = keras.layers.Dense(100, activation='relu')(x)
x = keras.layers.Dense(100, activation='relu')(x)
# final softmax layer with 15 categories
predictions = keras.layers.Dense(15, activation='softmax')(x)

# creating the full model:
full_model = keras.models.Model(inputs=conv_model.input, outputs=predictions)
full_model.summary()

[Figure: full_model.summary() output showing the restructured VGG16 architecture]
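When fine-tuning, it’s common to first freeze the convolutional base so that only the newly added dense layers learn from our clothing images. The following is just an illustrative sketch of that step; the actual training setup is covered in the next article, and the optimizer and learning rate shown here are example values, not ones prescribed by this series.

Python
# Freeze the VGG16 convolutional base; only the new dense layers will train
for layer in conv_model.layers:
    layer.trainable = False

# Example compilation settings (illustrative values)
full_model.compile(loss='categorical_crossentropy',
                   optimizer=keras.optimizers.Adam(learning_rate=1e-4),
                   metrics=['accuracy'])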

Next Steps

In the next article, we’ll show you how to train VGG19 to recognize what people are wearing. Stay tuned!

License

This article, along with any associated source code and files, is licensed under The Code Project Open License (CPOL)