
Restructuring ResNet50 to Diagnose COVID-19

17 Feb 2021 · CPOL · 2 min read
In this article, we’ll work on restructuring ResNet50 to perform the new classification task.

In this series of articles, we’ll apply a Deep Learning (DL) network, ResNet50, to diagnose COVID-19 in chest X-ray images. We’ll use Python’s TensorFlow library to train the neural network in a Jupyter notebook.

The tools and libraries you’ll need for this project are:

IDE: Jupyter Notebook

Libraries: TensorFlow (with its built-in Keras API)

We are assuming that you are familiar with deep learning with Python and Jupyter notebooks. If you're new to Python, start with this tutorial. And if you aren't yet familiar with Jupyter, start here.

In the previous article, we loaded the base model and showed its layers. Now, we’ll make this base model fit a new classification task: COVID-19 and Normal chest X-rays. We’ll use the ResNet50 model with a dataset that contains 2,484 images, a small dataset compared to ImageNet. To adapt the model to this new task, we need to:

  • Remove the fully connected (classification) layers of the network and add a global average pooling layer to condense all the feature maps
  • Replace the removed fully connected layers with new ones
  • Add a new dense output layer with two nodes that represent the two target classes: COVID-19 and Normal
  • Freeze the weights of the pretrained layers in the feature extraction part, and randomly initialize the weights of the new fully connected layers
  • Train ResNet50 so that only the weights of the new fully connected layers are updated
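By default, Keras `Dense` layers initialize their kernels randomly with the Glorot (Xavier) uniform initializer and their biases with zeros. The NumPy sketch below mimics that scheme for the new 512-to-2 output layer to show what "randomly initialize" means here; the helper `glorot_uniform` is our own illustration, not the Keras implementation.

```python
import numpy as np

def glorot_uniform(fan_in, fan_out, rng):
    """Sample a (fan_in, fan_out) matrix from U(-limit, limit),
    with limit = sqrt(6 / (fan_in + fan_out))."""
    limit = np.sqrt(6.0 / (fan_in + fan_out))
    return rng.uniform(-limit, limit, size=(fan_in, fan_out))

rng = np.random.default_rng(seed=0)
kernel = glorot_uniform(512, 2, rng)  # weights of the new output layer
bias = np.zeros(2)                    # Dense biases start at zero by default

print(kernel.shape)  # (512, 2)
```

Because the limit shrinks as the layer gets wider, the initial activations keep a roughly constant scale from layer to layer, which helps the new head start training from a sensible point.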

Restructuring the Base Model

As mentioned above, the first step towards reshaping our ResNet50 for transfer learning is to remove the fully connected layers and add a global average pooling layer, which condenses all the feature maps from the base model. On top of it, we stack new fully connected (dense) layers: three with 1,024 nodes, one with 512 nodes, and a final output layer with 2 nodes that represent the two target classes.
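Global average pooling simply averages each feature map over its spatial dimensions, leaving one value per channel. Assuming 224 × 224 input images, ResNet50's last convolutional block outputs a 7 × 7 × 2048 tensor per image, which pooling condenses into a 2,048-element vector. The NumPy sketch below, with random numbers standing in for real feature maps, shows the computation `GlobalAveragePooling2D` performs in its default channels-last layout.

```python
import numpy as np

# Stand-in for a batch of 4 feature maps: (batch, height, width, channels)
rng = np.random.default_rng(seed=0)
feature_maps = rng.random((4, 7, 7, 2048))

# Global average pooling: average over height and width, keep one value per channel
pooled = feature_maps.mean(axis=(1, 2))

print(pooled.shape)  # (4, 2048)
```

Unlike flattening (7 × 7 × 2048 = 100,352 values), pooling keeps the input to the first new dense layer small, which keeps the number of new weights manageable for a 2,484-image dataset.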

Python
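import tensorflow as tf

# base_model is the ResNet50 loaded in the previous article without its
# classification head, so its output is the last convolutional feature map
x = base_model.output
# Condense each feature map into a single value per channel
x = tf.keras.layers.GlobalAveragePooling2D()(x)

# New fully connected layers (randomly initialized)
x = tf.keras.layers.Dense(1024, activation='relu')(x)
x = tf.keras.layers.Dense(1024, activation='relu')(x)
x = tf.keras.layers.Dense(1024, activation='relu')(x)
x = tf.keras.layers.Dense(512, activation='relu')(x)
# Output layer: one node per class (COVID-19, Normal)
preds = tf.keras.layers.Dense(2, activation='softmax')(x)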
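The two-node softmax output turns the raw scores (logits) of the last layer into class probabilities that sum to 1. The plain-Python sketch below illustrates this; the logit values are made up, and the class order `["COVID-19", "Normal"]` is an assumption, since the real order depends on how the training labels are encoded.

```python
import math

def softmax(logits):
    """Numerically stable softmax: subtract the max before exponentiating."""
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]

class_names = ["COVID-19", "Normal"]  # assumed label order
logits = [2.0, 0.5]                   # hypothetical scores from the last Dense layer
probs = softmax(logits)
predicted = class_names[probs.index(max(probs))]

print([round(p, 3) for p in probs], predicted)  # [0.818, 0.182] COVID-19
```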

Now we can create a model with the new structure: it takes the base model’s input, passes it through the pretrained feature extraction part, and ends with the new output structure (preds).

Python
model = tf.keras.models.Model(inputs=base_model.input, outputs=preds)
# summary() prints the layer table itself and returns None
model.summary()

Figure 5 shows that the reshaped model is similar to the base one. The differences are the added global average pooling layer and the new fully connected (dense) layers, which change the network output to fit our new two-class classification task.
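Once the pretrained layers are frozen, all trainable parameters come from the new dense layers (global average pooling has no weights). A dense layer with n inputs and m units has n·m weights plus m biases. The short sketch below tallies them for the head defined above, assuming the pooled feature vector has 2,048 elements, the channel count of ResNet50's final feature map.

```python
def dense_params(n_inputs, n_units):
    """Weights (n_inputs * n_units) plus one bias per unit."""
    return n_inputs * n_units + n_units

# Widths along the new head: pooled features -> 1024 -> 1024 -> 1024 -> 512 -> 2
widths = [2048, 1024, 1024, 1024, 512, 2]
per_layer = [dense_params(a, b) for a, b in zip(widths, widths[1:])]

print(per_layer)       # [2098176, 1049600, 1049600, 524800, 1026]
print(sum(per_layer))  # 4723202
```

After freezing, this total is what `model.summary()` should report as trainable parameters; the first 1,024-unit layer alone accounts for almost half of it.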


Figure 5: A snapshot of the reshaped ResNet50 model

Freezing Weights

Now it’s time to freeze the weights of all layers before the global average pooling layer.

We’ll use the same code as in the previous article to enumerate the layers, and then we’ll freeze the pretrained ones by setting their trainable attribute to False.

Python
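# List every layer with its index to see where the new layers begin
for i, layer in enumerate(model.layers):
    print(i, layer.name)

# Layers 0-174 (everything before global average pooling) are the pretrained base
for layer in model.layers[:175]:
    layer.trainable = False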

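After freezing the pretrained layers, we’ll make the newly added layers trainable by setting their trainable attribute to True. Newly created Keras layers are already trainable by default, so this step mainly makes the intent explicit.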

Python
# Layers from index 175 on (pooling and new dense layers) stay trainable
for layer in model.layers[175:]:
    layer.trainable = True
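Hard-coding the index 175 works only as long as the base model has exactly that many layers. A less brittle alternative, sketched below under the assumption that TensorFlow 2.x is installed, freezes exactly the layers that belong to `base_model` and then checks which weights remain trainable. To stay self-contained, it rebuilds the same architecture with `weights=None` (random weights, no download); in the notebook you would apply the loop to the ImageNet-pretrained `base_model` from the previous article instead.

```python
import tensorflow as tf

# Rebuild the same architecture; weights=None avoids downloading ImageNet weights
base_model = tf.keras.applications.ResNet50(weights=None, include_top=False)

x = tf.keras.layers.GlobalAveragePooling2D()(base_model.output)
for units in (1024, 1024, 1024, 512):
    x = tf.keras.layers.Dense(units, activation='relu')(x)
preds = tf.keras.layers.Dense(2, activation='softmax')(x)
model = tf.keras.models.Model(inputs=base_model.input, outputs=preds)

# Freeze every layer of the pretrained base, however many there are
for layer in base_model.layers:
    layer.trainable = False

# Only the kernel and bias of each of the 5 new Dense layers remain trainable
print(len(model.trainable_weights))  # 10
```

Because `model` reuses the very same layer objects as `base_model`, freezing them through `base_model.layers` also freezes them inside `model`.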

Next Step

In the next article, we’ll fine-tune our ResNet50 model. Stay tuned!

License

This article, along with any associated source code and files, is licensed under The Code Project Open License (CPOL)