Here transfer learning is employed to leverage the knowledge of ResNet50 in a new classification task – the detection of COVID-19. Transfer learning includes two stages: freezing and fine-tuning.
In this series of articles, we’ll apply a Deep Learning (DL) network, ResNet50, to diagnose COVID-19 in chest X-ray images. We’ll use the TensorFlow library for Python to train the neural network in a Jupyter Notebook.
The tools and libraries you’ll need for this project are:
IDE: Jupyter Notebook
Libraries: TensorFlow (tf.keras), NumPy, Matplotlib, OpenCV (cv2)
We assume that you are familiar with deep learning in Python and with Jupyter notebooks. If you're new to Python or to Jupyter, start with an introductory tutorial before continuing.
In the previous article, we restructured the ResNet50 model for a new classification task: distinguishing COVID-19 chest X-rays from Normal ones. In this article, we’ll train and fine-tune the model to reach the expected performance and accuracy.
Training the Model
In this project, transfer learning is employed to leverage the knowledge of ResNet50 in a new classification task – the detection of COVID-19. Transfer learning includes two stages: freezing and fine-tuning.
In the freezing stage, the publicly available weights and learned parameters of the pretrained model are frozen and used "as is." Fine-tuning begins by removing the fully connected (FC) layer of ResNet50 and replacing it with three fully connected layers. The output layer has two neurons, corresponding to the COVID-19 and Normal chest X-rays. The weights of the new FC layers are initialized randomly before training. The weights of the remaining layers stay frozen so that they act as a strong extractor of high-level image features, because they were already trained on more than a million images from the ImageNet dataset.
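The freezing and re-architecting steps described above can be sketched in Keras as follows. This is a minimal sketch, not the exact model from the previous article. The function name build_covid_model, the global average pooling, and the sizes of the two hidden layers (256 and 128 units) are assumptions for illustration. Only the frozen ResNet50 base, the three new FC layers and the two-neuron softmax output come from the text.

```python
import tensorflow as tf


def build_covid_model(weights="imagenet", input_shape=(224, 224, 3), num_classes=2):
    """Sketch: frozen ResNet50 feature extractor followed by three new FC layers.

    The pooling choice and the hidden-layer sizes (256, 128) are illustrative assumptions.
    """
    # Pretrained ResNet50 without its original 1000-class ImageNet head
    base = tf.keras.applications.ResNet50(
        weights=weights, include_top=False, input_shape=input_shape, pooling="avg"
    )
    base.trainable = False  # freezing stage: keep the pretrained weights "as is"

    # New head: three fully connected layers; Keras initializes them randomly
    x = tf.keras.layers.Dense(256, activation="relu")(base.output)
    x = tf.keras.layers.Dense(128, activation="relu")(x)
    outputs = tf.keras.layers.Dense(num_classes, activation="softmax")(x)

    # Building on base.input keeps ResNet50's layers (e.g. conv5_block3_out)
    # directly reachable through model.get_layer(), as the Grad-CAM code needs
    return tf.keras.Model(inputs=base.input, outputs=outputs)
```

Calling build_covid_model() downloads the ImageNet weights the first time. build_covid_model(weights=None) builds the same architecture with random weights, which is convenient for quick shape checks.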
The network is trained on 1,590 images using stochastic gradient descent (SGD), with a batch size of 64 images per iteration. The code that reads and selects the training data was discussed in the second article of this series.
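With 1,590 training images and a batch size of 64, the number of steps per epoch follows from simple arithmetic. The plain-Python snippet below uses the two numbers given in the text. The floor division used in the training code runs 24 steps per epoch, which is 1,536 images, 54 fewer than the full training set. Rounding up gives 25 steps, the last holding a smaller batch.

```python
import math

NUM_TRAIN_IMAGES = 1590  # training set size given in the article
BATCH_SIZE = 64          # batch size given in the article

# Floor division, as in steps_per_epoch=train_generator.n // train_generator.batch_size
steps_floor = NUM_TRAIN_IMAGES // BATCH_SIZE
images_per_epoch = steps_floor * BATCH_SIZE
left_over = NUM_TRAIN_IMAGES - images_per_epoch

# Rounding up adds one smaller batch so every image fits into one epoch
steps_ceil = math.ceil(NUM_TRAIN_IMAGES / BATCH_SIZE)

print(f"floor: {steps_floor} steps, {images_per_epoch} images per epoch, {left_over} left over")
print(f"ceil:  {steps_ceil} steps")
# floor: 24 steps, 1536 images per epoch, 54 left over
# ceil:  25 steps
```

If steps_per_epoch is omitted, Keras uses len(train_generator), which for these directory iterators is the rounded-up value.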
To minimize the cost function during training, the initial learning rate and the reducing factor of the fully connected layers are set to 0.0001 and 0.1, respectively. Choosing the number of epochs is not trivial, because it determines how many passes over the training data the optimizer makes. If the number of epochs is too high, the network may overfit the training data and perform poorly on new images.
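The text gives an initial learning rate of 0.0001 and a reducing factor of 0.1, but not the rule that triggers the reduction. Two common ways to express this in Keras are sketched below:

* a fixed step schedule through LearningRateScheduler;
* ReduceLROnPlateau, which cuts the rate when the validation loss stops improving.

The step interval (EPOCHS_PER_DROP = 2), patience=1 and min_lr are assumptions, not values from the article.

```python
import tensorflow as tf

INITIAL_LR = 1e-4    # initial learning rate given in the article
DROP_FACTOR = 0.1    # reducing factor given in the article
EPOCHS_PER_DROP = 2  # assumption: the article does not say when the rate is reduced


def step_decay(epoch, lr=None):
    """Learning rate for a 0-indexed epoch: multiplied by DROP_FACTOR every EPOCHS_PER_DROP epochs."""
    return INITIAL_LR * DROP_FACTOR ** (epoch // EPOCHS_PER_DROP)


# Option 1: fixed schedule; Keras calls step_decay(epoch, current_lr) before each epoch
lr_schedule = tf.keras.callbacks.LearningRateScheduler(step_decay)

# Option 2: multiply the rate by DROP_FACTOR when val_loss has not improved for
# `patience` epochs (needs validation data; patience and min_lr are assumptions)
reduce_on_plateau = tf.keras.callbacks.ReduceLROnPlateau(
    monitor="val_loss", factor=DROP_FACTOR, patience=1, min_lr=1e-7
)
```

Either callback is passed to training through model.fit(..., callbacks=[...]).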
To avoid overfitting, we monitored the loss and accuracy on validation images. ResNet50 achieved its highest training accuracy, as well as its best generalization capability, at epoch 3. Table 1 shows the network's training performance after 3 epochs: 98.7% accuracy (Figure 6).
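import tensorflow as tf

# Stops training once val_loss stops improving; val_loss is only computed when
# validation data is passed to fit() together with callbacks=[early_stopping]
early_stopping = tf.keras.callbacks.EarlyStopping(
    monitor='val_loss', min_delta=0, patience=0, verbose=0,
    mode='auto', baseline=None, restore_best_weights=False
)
# optimizer='sgd' uses the Keras default learning rate (0.01); for the 0.0001 stated
# above, pass optimizer=tf.keras.optimizers.SGD(learning_rate=1e-4) instead
model.compile(optimizer='sgd', loss='categorical_crossentropy', metrics=['accuracy'])
# model.fit accepts generators directly (fit_generator is deprecated); 3 epochs, as found above
history = model.fit(train_generator, steps_per_epoch=train_generator.n // train_generator.batch_size, epochs=3)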
Figure 6: Training process of the network
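Monitoring the validation loss, as described above, only works when fit receives validation data. The self-contained sketch below shows the mechanics on a toy stand-in, not the X-ray model or its images. A tiny classifier is trained on synthetic data with a 20% validation split and EarlyStopping(restore_best_weights=True). The data, the model, patience=2 and the cap of 30 epochs are assumptions chosen for illustration.

```python
import numpy as np
import tensorflow as tf

rng = np.random.default_rng(0)
# Synthetic, learnable two-class data standing in for image features (illustration only)
x = rng.normal(size=(400, 16)).astype("float32")
y = tf.keras.utils.to_categorical((x[:, 0] + x[:, 1] > 0).astype(int), num_classes=2)

toy_model = tf.keras.Sequential([
    tf.keras.Input(shape=(16,)),
    tf.keras.layers.Dense(8, activation="relu"),
    tf.keras.layers.Dense(2, activation="softmax"),
])
toy_model.compile(optimizer=tf.keras.optimizers.SGD(learning_rate=0.1),
                  loss="categorical_crossentropy", metrics=["accuracy"])

# Stop when val_loss has not improved for 2 epochs; when that happens,
# restore the weights from the epoch with the lowest val_loss
early_stopping = tf.keras.callbacks.EarlyStopping(
    monitor="val_loss", patience=2, restore_best_weights=True)

history = toy_model.fit(x, y, validation_split=0.2, epochs=30, batch_size=32,
                        callbacks=[early_stopping], verbose=0)

val_loss = history.history["val_loss"]
best_epoch = int(np.argmin(val_loss)) + 1  # 1-indexed, like the Keras training log
print(f"ran {len(val_loss)} epochs, best val_loss at epoch {best_epoch}")
```

With the real model, pass a validation generator as validation_data instead of using validation_split, which Keras does not support for generator inputs.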
To evaluate the model's performance during training, let’s plot accuracy and loss against the number of epochs (Figures 7 and 8):
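import matplotlib.pyplot as plt

acc = history.history['accuracy']
loss = history.history['loss']
epochs = range(1, len(acc) + 1)  # number epochs from 1, as in the training log

plt.figure()
plt.plot(epochs, acc, label='Training Accuracy')
plt.xlabel('epoch')
plt.ylabel('Accuracy')
plt.title('Training Accuracy')
plt.legend()

plt.figure()
plt.plot(epochs, loss, label='Training Loss')
plt.xlabel('epoch')
plt.ylabel('Loss')
plt.title('Training Loss')
plt.legend()
plt.show()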
Figure 7: Training Accuracy vs the number of epochs
Figure 8: Loss vs the number of epochs
As Figures 7 and 8 show, the training accuracy increases with the number of epochs, while the training loss decreases with every epoch. This is the expected behavior of gradient descent: each update moves the weights in the direction that reduces the loss. On the training data, the error should therefore generally shrink as training proceeds, although with stochastic mini-batches it need not drop at every single iteration.
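Training curves alone cannot reveal overfitting, which shows up as a gap between the training and validation curves. The helper below plots the training curve and, when available, the validation curve side by side. The function name plot_history is hypothetical. It expects a Keras history.history dictionary, and validation keys such as val_loss exist only when fit received validation data.

```python
import matplotlib.pyplot as plt


def plot_history(history_dict, metrics=("accuracy", "loss")):
    """Plot each metric per epoch, adding the val_ curve when it is present."""
    fig, axes = plt.subplots(1, len(metrics), figsize=(5 * len(metrics), 4))
    if len(metrics) == 1:
        axes = [axes]  # plt.subplots returns a single Axes for one column
    for ax, metric in zip(axes, metrics):
        epochs = range(1, len(history_dict[metric]) + 1)
        ax.plot(epochs, history_dict[metric], label=f"Training {metric}")
        val_key = f"val_{metric}"
        if val_key in history_dict:
            ax.plot(epochs, history_dict[val_key], label=f"Validation {metric}")
        ax.set_xlabel("epoch")
        ax.set_ylabel(metric)
        ax.legend()
    fig.tight_layout()
    return fig
```

Usage: plot_history(history.history) followed by plt.show().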
Evaluating the Model
Once the model is trained with good accuracy, we can start testing it on our testing dataset, which consists of 895 images that were not part of the training set. The testing set was loaded in the second article of the series as test_generator.
To test our model, let’s use the model.evaluate function. This function computes the loss and the metrics specified at compile time (here, accuracy) on the input data you pass to it. We passed the whole test set as test_generator and evaluated the model:
Testresults = model.evaluate(test_generator)
print("test loss, test acc:", Testresults)
Figure 9: A snapshot of the testing accuracy
As Figure 9 shows, our model generalizes well when sorting new images into COVID-19 or Normal, with an accuracy of 95% and a test loss of 0.108.
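For a single-output model like this one, model.evaluate returns the loss followed by the compiled metrics. Passing return_dict=True returns the same values keyed by name, which avoids relying on their order. The sketch below shows both forms on a toy stand-in model with random data, not the X-ray model.

```python
import numpy as np
import tensorflow as tf

rng = np.random.default_rng(1)
# Random stand-in test data (illustration only)
x_test = rng.normal(size=(64, 8)).astype("float32")
y_test = tf.keras.utils.to_categorical(rng.integers(0, 2, size=64), num_classes=2)

toy_model = tf.keras.Sequential([
    tf.keras.Input(shape=(8,)),
    tf.keras.layers.Dense(2, activation="softmax"),
])
toy_model.compile(optimizer="sgd", loss="categorical_crossentropy", metrics=["accuracy"])

# Default: a list [loss, accuracy]
test_loss, test_acc = toy_model.evaluate(x_test, y_test, verbose=0)
# return_dict=True: the same values keyed by name
results = toy_model.evaluate(x_test, y_test, verbose=0, return_dict=True)
print(f"test loss: {results['loss']:.3f}, test accuracy: {results['accuracy']:.2%}")
```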
We can also evaluate our model using the model.predict function. Unlike model.evaluate, which returns the aggregate loss and metrics, model.predict returns the model's output for each input you feed it (here, the softmax probabilities of the two classes). To obtain predictions for all test images with model.predict:
predictions = model.predict(test_generator, batch_size=None, verbose=0, steps=None, callbacks=None)
print(predictions)
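The softmax output layer produces one probability per class for each image. We convert these probabilities to class labels (0 or 1) with np.argmax, which selects the class with the highest probability. The predicted labels are stored in classes, while the true labels are available in test_generator.labels. Comparing the two position by position is only meaningful if test_generator was created with shuffle=False. Otherwise, the order of the predictions does not match the order of test_generator.labels.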
classes = np.argmax(predictions, axis = 1)
print(classes)
print(test_generator.labels)
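With the predicted and true labels side by side, the accuracy and a confusion matrix can be computed directly with NumPy. The sketch assumes two classes, 0 = COVID-19 and 1 = Normal, and predictions in the same order as the labels (shuffle=False on the test generator). The function names and the toy probabilities are illustrative only. For two classes, np.argmax is equivalent to thresholding the Normal probability at 0.5, apart from exact ties.

```python
import numpy as np


def confusion_matrix(y_true, y_pred, num_classes=2):
    """Rows = true class, columns = predicted class (0 = COVID-19, 1 = Normal)."""
    y_true = np.asarray(y_true)
    y_pred = np.asarray(y_pred)
    cm = np.zeros((num_classes, num_classes), dtype=int)
    np.add.at(cm, (y_true, y_pred), 1)  # count each (true, predicted) pair
    return cm


def summarize(cm):
    """Accuracy, COVID-19 recall (sensitivity) and Normal recall (specificity)."""
    accuracy = np.trace(cm) / cm.sum()
    covid_recall = cm[0, 0] / cm[0].sum()   # true COVID-19 images predicted as COVID-19
    normal_recall = cm[1, 1] / cm[1].sum()  # true Normal images predicted as Normal
    return accuracy, covid_recall, normal_recall


# Toy softmax outputs, as model.predict would return them -> labels via argmax
probs = np.array([[0.9, 0.1], [0.4, 0.6], [0.2, 0.8], [0.7, 0.3]])
pred = np.argmax(probs, axis=1)  # -> [0, 1, 1, 0]
true = np.array([0, 0, 1, 0])
cm = confusion_matrix(true, pred)
print(cm)
print(summarize(cm))
```

With the article's variables, the call would be confusion_matrix(test_generator.labels, classes).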
We also used model.predict to classify a single image. We first printed the labels of all test images using print(test_generator.labels). Then, we selected an image that belongs to class 0 (COVID-19) and fed it to model.predict. Indexing test_generator returns a whole batch of images and labels, so the image we use is the first one in the batch at index 5:
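import numpy as np

images, labels = test_generator[5]       # one batch: (images, one-hot labels)
predictions = model.predict(images[:1])  # predict only the first image of that batch
print('First prediction:', predictions[0])
category = np.argmax(predictions[0])     # most probable class: 0 = COVID-19, 1 = Normal
print("the image class is:", category)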
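The predicted index can be turned into a readable class name with the generator's class_indices attribute. Keras's directory iterators fill it with a {folder_name: index} dictionary. When the classes are inferred from folder names, the indices follow the alphanumeric order of the subfolders. The helper below is a small hypothetical sketch. The actual folder names depend on how the dataset was organized in the second article.

```python
def index_to_class_name(class_indices, index):
    """Invert a {class_name: index} dict (e.g. test_generator.class_indices) and look up one index."""
    index_to_name = {idx: name for name, idx in class_indices.items()}
    return index_to_name[int(index)]  # int() also accepts NumPy integers from np.argmax


# Usage with the variables from the article:
# print("the image class is:", index_to_class_name(test_generator.class_indices, category))
```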
Gradient-weighted Class Activation Mapping (Grad-CAM)
To better interpret and evaluate the model, we can use visualization techniques such as Grad-CAM. We visualized Grad-CAM heatmaps of some correctly classified COVID-19 and Normal test images.
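Grad-CAM visualizes the regions of an image that the network relied on to predict a given class. It takes the gradients of the class score with respect to the feature maps of the last convolutional layer and averages them over each feature map, giving one importance weight per channel. It then combines the feature maps with these weights into a coarse localization map. In the resulting heatmap, the regions most associated with the predicted class appear in deep red and the least associated ones in deep blue.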
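In the standard Grad-CAM formulation, the channel weights and the heatmap are computed as shown below. A^k is the k-th feature map of the chosen convolutional layer, H and W are its height and width, Z = H × W is the number of spatial positions, and y^c is the score of class c. The original method differentiates the pre-softmax score. The code in this article differentiates the softmax output predictions[:, CLASS_INDEX], which is a common variant.

```latex
\alpha_k^c = \frac{1}{Z} \sum_{i=1}^{H} \sum_{j=1}^{W} \frac{\partial y^c}{\partial A_{ij}^k},
\qquad
L^c_{\text{Grad-CAM}} = \mathrm{ReLU}\!\left( \sum_k \alpha_k^c \, A^k \right)
```

The code maps onto these formulas as follows:

* the weights \alpha_k^c correspond to weights = tf.reduce_mean(grads, axis=(0, 1));
* the weighted sum of the feature maps builds cam;
* np.maximum(cam, 0) plays the role of the ReLU.

The map is then resized to 224 × 224 and normalized to [0, 1] before coloring.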
Figure 10 shows Grad-CAM heatmaps of some test images that ResNet50 classified correctly. The upper row shows the Normal examples, while the lower row shows the COVID-19 images. In addition to TensorFlow, we used OpenCV (cv2) to resize, color, and overlay the heatmaps, so cv2 has to be imported.
We selected some test images and computed their Grad-CAM heatmaps. To process another image, just change the file name in IMAGE_PATH. LAYER_NAME should be the last convolutional layer (conv5_block3_out in ResNet50). Its feature maps carry the most abstract features while keeping their spatial layout, and they are what the classification head bases its final decision on. CLASS_INDEX is the class whose evidence we want to visualize: 0 for COVID-19 and 1 for Normal. It is set to 0 (COVID-19) in this example. To explain the prediction for a Normal image, such as the one in IMAGE_PATH, set it to 1.
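import cv2
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf

IMAGE_PATH = r'C:\Users\abdul\Desktop\ContentLab\P1\archive\COVID-19 Radiography Database\test\NORMAL\NORMAL (2).png'
LAYER_NAME = 'conv5_block3_out'  # last convolutional block of ResNet50
CLASS_INDEX = 0                  # class to explain: 0 = COVID-19, 1 = Normal

# Load the image at the network's input size. If the data generators rescaled or
# otherwise preprocessed the images during training, apply the same step here.
img = tf.keras.preprocessing.image.load_img(IMAGE_PATH, target_size=(224, 224))
img = tf.keras.preprocessing.image.img_to_array(img)

# Model mapping the input to the last conv feature maps and the class probabilities
grad_model = tf.keras.models.Model(model.inputs, [model.get_layer(LAYER_NAME).output, model.output])

with tf.GradientTape() as tape:
    conv_outputs, predictions = grad_model(np.array([img]))
    loss = predictions[:, CLASS_INDEX]  # score of the class being explained

output = conv_outputs[0]                      # feature maps, shape (7, 7, 2048)
grads = tape.gradient(loss, conv_outputs)[0]  # gradient of the score w.r.t. the feature maps
weights = tf.reduce_mean(grads, axis=(0, 1))  # one importance weight per channel

# Weighted sum of the feature maps, upsampled to the input size; keep positive evidence only
cam = tf.reduce_sum(output * weights, axis=-1).numpy()
cam = cv2.resize(cam, (224, 224))
cam = np.maximum(cam, 0)
heatmap = (cam - cam.min()) / (cam.max() - cam.min() + 1e-8)  # normalize to [0, 1]
cam = cv2.applyColorMap(np.uint8(255 * heatmap), cv2.COLORMAP_JET)  # BGR color image

# Blend with the original image in BGR, then convert to RGB so Matplotlib shows red as red
output_image = cv2.addWeighted(cv2.cvtColor(img.astype('uint8'), cv2.COLOR_RGB2BGR), 0.5, cam, 1, 0)
plt.imshow(cv2.cvtColor(output_image, cv2.COLOR_BGR2RGB))
plt.axis('off')
plt.show()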
Figure 10: Heatmaps of test images
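Instead of hard-coding LAYER_NAME, you can find the last layer that still outputs a spatial feature map by scanning the model's layers from the end. The helper below is a hypothetical convenience function, not part of Keras. It treats any layer with a 4D output (batch, height, width, channels) as convolutional. For the ResNet50-based model it should return conv5_block3_out, provided the new head contains no other layer with a 4D output, such as a 2D pooling layer placed before flattening.

```python
import tensorflow as tf


def find_last_conv_layer(model):
    """Return the name of the last layer whose output is a 4D feature map (batch, H, W, C)."""
    for layer in reversed(model.layers):
        if len(layer.output.shape) == 4:
            return layer.name
    raise ValueError("No layer with a 4D output found.")


# Usage with the article's model:
# LAYER_NAME = find_last_conv_layer(model)
```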
Next Step
In the next article, we’ll discuss the COVID-19 detection results we’ve achieved with our model and compare these results with those of other models. Stay tuned!