Unbalanced and Small Sample Deep Learning For COVID X-Ray Classification
1 Introduction
Datasets [] of COVID-positive chest X-rays are limited and cannot be used directly to
train deep learning models for diagnosis.
In this paper, we propose a reliable deep learning model that can explain its
predictions, and we address the imbalanced data classification task at two levels: at
the data level, through data sampling techniques, and at the architecture level, through
an appropriate network, regularization such as dropout, and suitable learning strategies [ ].
To the best of our knowledge, we are the first to employ an attention-based
convolutional neural network together with learning strategies such as the LDAM loss,
achieving almost 100% recall while maintaining high accuracy on the COVIDx dataset.
We also visualize the classification task through density maps to aid visual inference
of the model's decisions.
The rest of the paper is organized as follows. Section 2 describes our proposed
approach, including the datasets, learning techniques, specialized losses, explainability
approach and architecture used. In Section 3, we describe the experiments run on these
datasets, and quantify and analyze the effectiveness of the different techniques in
enabling learning with limited, imbalanced data. Finally, Section 4 presents conclusions
and proposed future work.
2 Proposed Approach
2.1 Overview
2.2 Datasets
This work involves three major datasets. The NIH ChestX-ray14 dataset [7], with over
100,000 labelled chest X-ray images and 14 pathology labels, is the largest and most
diverse publicly available chest X-ray dataset. It has been studied extensively and used
to build many robust deep learning architectures [9, 11, 12]. However, only around
67,000 of these images are posteroanterior (PA) views, and these are used to pretrain
the proposed deep backbone model on a multi-label classification task.
To evaluate few-shot learning and COVID-19 diagnosis, the RSNA Pneumonia dataset
[13], with over 8,000 normal and 9,000 pneumonia-affected chest X-rays, is used
together with the largest publicly available COVID chest X-ray dataset [8], which
contains only 142 PA COVID chest X-rays. This combined, imbalanced dataset is named
COVIDx and is used to train the proposed robust deep learning model for diagnosing
COVID-19 with high specificity and sensitivity, guided by transfer learning,
undersampling and specialized loss functions.
Transfer Learning. State-of-the-art deep convolutional neural networks require
hundreds of thousands of images to train their parameters from random initialization.
Fortunately, recent work [?] has shown that deep CNNs trained on one task can also
perform well on a comparable task with very little additional training. We exploit this
transfer of knowledge from one domain to another by pretraining the model on the NIH
multi-label classification task before training it on the COVIDx three-class
classification task. During fine-tuning it is also common to use dropout [?], which has
been shown to reduce overfitting considerably.
Data Augmentation and Preprocessing. By far the most widely studied and used
technique [?] for dealing with limited datasets is augmenting existing data with simple
affine transformations, crops, flips and noise. Augmentation is also used with large
datasets to improve model robustness and generalizability, especially in the medical
domain [?]. However, when data is extremely scarce, as in many medical imaging tasks,
augmentation is no substitute for more data, because the augmented samples lack diversity.
Apart from these techniques, simple preprocessing steps can sometimes improve
performance. Denoising, binarization and normalization are simple techniques borrowed
from classical machine learning pipelines. Histogram equalization is one such
technique: applied to grayscale inputs, it spreads the intensity values over the full
range to enhance contrast, which can help deep models learn better features.
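As an illustration of the histogram equalization step, here is a minimal NumPy sketch of global equalization for an 8-bit grayscale X-ray. It assumes images are loaded as `uint8` arrays; the function name `equalize_histogram` is hypothetical, and the paper does not state which implementation it used (OpenCV's `cv2.equalizeHist` performs a similar global operation).

```python
import numpy as np


def equalize_histogram(image: np.ndarray) -> np.ndarray:
    """Global histogram equalization of an 8-bit grayscale image."""
    if image.dtype != np.uint8:
        raise ValueError("expected a uint8 grayscale image")
    # Count how often each of the 256 intensity levels occurs.
    hist = np.bincount(image.ravel(), minlength=256)
    cdf = np.cumsum(hist)
    # Smallest non-zero CDF value: the darkest occupied level maps to 0.
    cdf_min = cdf[cdf > 0][0]
    if cdf_min == image.size:
        return image.copy()  # constant image, nothing to stretch
    # Map each level through the normalised CDF onto the full 0..255 range.
    lut = np.round((cdf - cdf_min) / (image.size - cdf_min) * 255)
    lut = np.clip(lut, 0, 255).astype(np.uint8)
    return lut[image]


# A low-contrast patch whose intensities only span 100..103.
patch = np.array([[100, 101], [102, 103]], dtype=np.uint8)
equalized = equalize_histogram(patch)
```

The four nearly identical gray levels of the toy patch are spread across the whole 0 to 255 range after equalization.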
Sampling. Two approaches are typically used to tackle class imbalance: oversampling
the minority class and undersampling the majority classes [?, ?]. Oversampling the
minority class removes the inherent bias towards the majority classes by including
duplicate copies of minority samples in each batch. However, it is prone to overfitting
the small number of minority samples, which leads to poor generalization.
Undersampling the majority classes also removes the dataset bias, but it discards
genuinely diverse information and leaves only a small number of samples for training,
which is often not enough to train deep models.
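The two sampling strategies can be sketched at the level of dataset indices. This pure-Python version is illustrative only: the helper names are hypothetical, and the per-class target used for undersampling in this work is not stated, so `target_per_class` is an assumption.

```python
import random
from collections import defaultdict


def group_by_class(labels):
    """Map each class label to the list of dataset indices carrying it."""
    groups = defaultdict(list)
    for idx, label in enumerate(labels):
        groups[label].append(idx)
    return groups


def random_undersample(labels, target_per_class, seed=0):
    """Keep at most `target_per_class` randomly chosen samples of every class."""
    rng = random.Random(seed)
    kept = []
    for idxs in group_by_class(labels).values():
        if len(idxs) > target_per_class:
            idxs = rng.sample(idxs, target_per_class)
        kept.extend(idxs)
    return sorted(kept)


def random_oversample(labels, seed=0):
    """Duplicate minority samples (drawn with replacement) up to the majority count."""
    rng = random.Random(seed)
    groups = group_by_class(labels)
    majority = max(len(idxs) for idxs in groups.values())
    resampled = []
    for idxs in groups.values():
        extra = [rng.choice(idxs) for _ in range(majority - len(idxs))]
        resampled.extend(idxs + extra)
    return resampled


labels = ["normal"] * 50 + ["pneumonia"] * 60 + ["covid"] * 5
under = random_undersample(labels, target_per_class=20)
over = random_oversample(labels)
```

In this toy example, undersampling discards 70 majority samples outright, while oversampling repeats the 5 minority samples about twelve times each, which mirrors the trade-offs described above.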
Specialized Loss Functions. Categorical cross entropy is the standard loss for
classification tasks when datasets are balanced. When datasets are imbalanced, however,
it is biased against the minority class and prone to overfitting its small set of
samples. The deep learning community has proposed specialized losses to tackle dataset
imbalance effectively, used either together with the cross-entropy loss or on their own.
Key ideas behind these cost functions are to address prior probability bias, minimize
intraclass variance and maximize interclass variance [].
Weighted Loss. Weighting samples in inverse proportion to the size of their class is
the traditional approach to tackling imbalance; it addresses the difference in prior
probabilities. Variants such as the Class-Balanced Loss [15] have been proposed to handle
imbalance ratios of up to 1:200. However, with extremely large imbalance and very few
minority samples, these weightings can make training highly unstable and can also lead
to overfitting.
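To make the weighting schemes concrete, the sketch below computes inverse-frequency weights and Class-Balanced weights, which are based on the effective number of samples (1 - β^n)/(1 - β) from [15], for class counts roughly matching the COVIDx composition described in Section 2.2 (8,000 normal, 9,000 pneumonia, 142 COVID). The counts, β = 0.999 and the normalization so that the weights sum to the number of classes are illustrative assumptions, not values reported in this work.

```python
import math


def inverse_frequency_weights(counts):
    """w_c proportional to 1 / n_c, rescaled so the weights sum to the number of classes."""
    raw = [1.0 / n for n in counts]
    scale = len(counts) / sum(raw)
    return [w * scale for w in raw]


def class_balanced_weights(counts, beta=0.999):
    """Class-Balanced weights: w_c proportional to (1 - beta) / (1 - beta ** n_c)."""
    raw = [(1.0 - beta) / (1.0 - beta ** n) for n in counts]
    scale = len(counts) / sum(raw)
    return [w * scale for w in raw]


def weighted_cross_entropy(probs, label, weights):
    """Cross entropy of one sample, scaled by the weight of its true class."""
    return -weights[label] * math.log(probs[label])


counts = [8000, 9000, 142]  # normal, pneumonia, COVID (approximate)
inv = inverse_frequency_weights(counts)
cb = class_balanced_weights(counts)
```

With these counts, inverse-frequency weighting gives a COVID sample roughly 56 times the weight of a normal sample, whereas the Class-Balanced weights give a ratio of about 7.5; concentrating very large weights on a handful of samples is plausibly one source of the instability mentioned above.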
ArcFace Loss []. ArcFace was introduced for face recognition to obtain highly
discriminative features. Its key objective is to minimize intraclass variance and
maximize interclass variance, which it does in angular space by adding an angular
margin between the classes. The ArcFace loss is defined as
$$L = -\frac{1}{N}\sum_{i=1}^{N}\log\frac{e^{s\cos(\theta_{y_i}+m)}}{e^{s\cos(\theta_{y_i}+m)}+\sum_{j=1,\,j\neq y_i}^{n}e^{s\cos\theta_j}} \tag{1}$$
where N is the number of samples, n is the number of classes, y_i is the label of sample i,
θj is the angle between the weight vector of class j and the feature vector, m is the
additive angular margin and s is a scale factor.
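A minimal NumPy sketch of Eq. (1) follows, assuming the feature vectors and class weight vectors are L2-normalized so that their dot product equals cos θj, as in the original ArcFace formulation. The defaults s = 64 and m = 0.5 are the values commonly used for face recognition and are assumptions here; the sketch also omits the fallback that practical implementations apply when θ_{y_i} + m exceeds π.

```python
import numpy as np


def arcface_loss(features, class_weights, labels, s=64.0, m=0.5):
    """Mean ArcFace loss over a batch, following Eq. (1).

    features:      (N, d) feature vectors, one per sample
    class_weights: (n, d) one weight vector per class
    labels:        (N,) integer class index y_i of each sample
    """
    labels = np.asarray(labels)
    # Normalise both sides so that the dot product is cos(theta_j).
    x = features / np.linalg.norm(features, axis=1, keepdims=True)
    w = class_weights / np.linalg.norm(class_weights, axis=1, keepdims=True)
    cos = np.clip(x @ w.T, -1.0, 1.0)
    rows = np.arange(len(labels))
    logits = s * cos
    # Additive angular margin on the target class only: s * cos(theta_{y_i} + m).
    theta_y = np.arccos(cos[rows, labels])
    logits[rows, labels] = s * np.cos(theta_y + m)
    # Numerically stable log-softmax, then negative log-likelihood of y_i.
    logits = logits - logits.max(axis=1, keepdims=True)
    log_probs = logits - np.log(np.exp(logits).sum(axis=1, keepdims=True))
    return -log_probs[rows, labels].mean()
```

With m = 0 and s = 1 the expression reduces to ordinary softmax cross entropy on the cosine similarities; a positive m lowers the target logit, so the loss stays high until each feature lies well inside the angular region of its class.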
Label Distribution Aware Margin Loss (LDAM) []. The key idea of this optimization
objective is to learn a simpler decision function for the minority classes, leaving more
room for generalization, and a more complex one for the majority classes. This
regularizes the minority classes more strongly and improves model robustness on unseen data.
Its class-dependent margins are incorporated into a well-established classification
loss, the cross-entropy loss, shifting the margin in favor of the minority classes:
$$L((x, y); f) = -\log\frac{e^{z_y-\Delta_y}}{e^{z_y-\Delta_y}+\sum_{j\neq y}e^{z_j}} \tag{2}$$
where z_j is the model output (logit) for class j, ∆j = C / nj^{1/4} for j ∈ {1, 2, ..., k},
nj is the number of training samples in class j, C is a hyperparameter and k is the total
number of classes.
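The sketch below implements Eq. (2) in NumPy, again using the approximate COVIDx class counts from Section 2.2 for illustration. The value C = 0.5 is an arbitrary assumption; published LDAM implementations typically also scale the logits by a constant before the softmax, which is omitted here.

```python
import numpy as np


def ldam_margins(class_counts, C=0.5):
    """Per-class margins Delta_j = C / n_j ** (1/4); rarer classes get larger margins."""
    counts = np.asarray(class_counts, dtype=float)
    return C / counts ** 0.25


def ldam_loss(logits, labels, margins):
    """Mean LDAM loss over a batch, following Eq. (2).

    logits:  (N, k) raw scores z_j
    labels:  (N,) true class index y of each sample
    margins: (k,) per-class margins Delta_j
    """
    z = np.array(logits, dtype=float)  # copy so the caller's array is untouched
    labels = np.asarray(labels)
    rows = np.arange(len(labels))
    # Subtract the class-dependent margin from the true-class logit only.
    z[rows, labels] -= margins[labels]
    # Cross entropy on the adjusted logits (stable log-softmax).
    z -= z.max(axis=1, keepdims=True)
    log_probs = z - np.log(np.exp(z).sum(axis=1, keepdims=True))
    return -log_probs[rows, labels].mean()


margins = ldam_margins([8000, 9000, 142])  # normal, pneumonia, COVID (approximate)
```

With these counts the COVID margin is about 2.7 times the normal-class margin, since (8000/142)^{1/4} ≈ 2.74, so COVID samples must be classified with a wider margin before their loss becomes small.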
merely based on other artifacts or inherent biases in the images. When available
datasets are sparse, these biases tend to get magnified by deep models and visualizing
the model’s reasoning is essential.
In that regard, we use an attention-based architecture [9] that globally predicts how
much attention to give each local region of the image and then evaluates the image with
these additional weights, in a coarse-to-fine manner. These attention weights are
compared against existing explainability approaches such as [5].
Building on [9], we construct an attention-based CNN with Global, Local and Fusion
branches. The Global and Local branches are based on the standard DenseNet121 [16]
architecture. The architecture operates in three stages, and three-class classification
is performed at each of them.
In stage 1, the chest X-ray image is resized to 224x224, converted to grayscale and
fed to the Global DenseNet, which produces an attention weight map and a
1024-dimensional feature vector. Unlike in [9], the real-valued attention map is
multiplied directly with the input image, and the result is passed through the Local
DenseNet for stage 2 evaluation, which produces another 1024-dimensional feature
vector. This enables more effective feature fusion and better training of the attention
map for identifying the region of interest (RoI) for the local branch. Finally, in stage 3,
the two feature vectors are concatenated and passed through the Fusion branch, a single
dense layer, to produce the three class probabilities.
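The three-stage data flow described above can be sketched as follows. To keep the example self-contained, the two DenseNet121 branches are replaced by hypothetical stand-in functions (`global_branch`, `local_branch`) with the same output shapes, the fusion layer uses random weights, and the per-branch classifiers of stages 1 and 2 are omitted; only the tensor shapes, the attention multiplication and the concatenation follow the text.

```python
import numpy as np

FEAT_DIM = 1024   # per-branch feature size stated in the text
NUM_CLASSES = 3   # normal, pneumonia, COVID
rng = np.random.default_rng(0)


def global_branch(image):
    """Hypothetical stand-in for the Global DenseNet121.

    Returns an attention map in (0, 1) with the image's shape and a 1024-d feature vector.
    """
    attention = 1.0 / (1.0 + np.exp(-(image - image.mean())))  # placeholder map
    feature = np.resize(image.mean(axis=0), FEAT_DIM)           # placeholder feature
    return attention, feature


def local_branch(image):
    """Hypothetical stand-in for the Local DenseNet121 (1024-d feature vector)."""
    return np.resize(image.mean(axis=1), FEAT_DIM)


# Fusion branch: a single dense layer over the concatenated 2048-d vector.
W_fusion = rng.normal(scale=0.01, size=(NUM_CLASSES, 2 * FEAT_DIM))
b_fusion = np.zeros(NUM_CLASSES)


def forward(image):
    # Stage 1: the 224x224 grayscale image goes through the global branch.
    attention, f_global = global_branch(image)
    # Stage 2: the attention map is multiplied element-wise with the input image.
    f_local = local_branch(image * attention)
    # Stage 3: concatenate both features, apply the dense fusion layer, then softmax.
    logits = W_fusion @ np.concatenate([f_global, f_local]) + b_fusion
    exp = np.exp(logits - logits.max())
    return exp / exp.sum()


x_ray = rng.random((224, 224))  # placeholder for a preprocessed chest X-ray
probs = forward(x_ray)
```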
The model is trained stage-wise, similar to [9]: first the Global branch, then the Local
branch and finally the Fusion branch. The whole network is first pretrained on the NIH
dataset and then fine-tuned on the COVIDx dataset. The entire architecture was built
using PyTorch [?]. Hyperparameters such as the number of epochs, learning rate, batch
size and optimizer were tuned experimentally for each experiment.
Fig. 1. Model Architectural Setup
Since the COVIDx dataset [] is highly imbalanced, our objective was to improve recall
for the minority class so that no COVID patient is misclassified. Training the model
directly on the COVIDx dataset would typically result in low sensitivity to the COVID
class. We therefore used transfer learning to pretrain the model on the NIH multi-label
classification task before training it on the COVIDx three-class classification task.
During fine-tuning we made extensive use of dropout to prevent overfitting. We then
studied the model's performance with learning methods designed to address data
imbalance. The detailed study, with quantitative results, is presented in Sections 3.1
and 3.2. In Section 3.3, we explain the model's predictions using Grad-CAM [] and
visualize the attention heat maps on the input images.
3.1 Experiments
Experiment A. In this experiment, we performed an ablation study on loss functions,
observing the model's performance under different learning strategies. First, at the
data level, the majority classes were undersampled to mitigate the imbalance, yielding a
moderately imbalanced dataset. Then the learning strategy [] was applied to help the
model generalize well to new, unseen data, particularly of the minority class.
Since there is no common benchmark dataset on which to compare chest X-ray-based
COVID diagnosis methods, Table 1 lists the scores reported by several state-of-the-art
methods on their own closed-source data, together with our model's performance.
Table 2 shows the data augmentation results, reporting recall and accuracy on the
aforementioned imbalanced test dataset with the NLL and LDAM loss functions used separately.
Model / loss    Recall (%)    Accuracy (%)
NLL             80            92.68
LDAM            84            92.68
Local NLL       84            89
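The two metrics in Table 2 can diverge sharply on an imbalanced test set, which is why both are reported. The sketch below computes them from true and predicted labels; the function name is hypothetical, the toy labels are invented for illustration, and the assumption that recall is measured on the COVID class follows the objective stated at the start of this section rather than an explicit statement in the table.

```python
def recall_and_accuracy(y_true, y_pred, positive="covid"):
    """Return (recall of the `positive` class, overall accuracy), both in percent."""
    pairs = list(zip(y_true, y_pred))
    support = sum(1 for t, _ in pairs if t == positive)
    true_pos = sum(1 for t, p in pairs if t == positive and p == positive)
    correct = sum(1 for t, p in pairs if t == p)
    recall = 100.0 * true_pos / support if support else float("nan")
    accuracy = 100.0 * correct / len(pairs)
    return recall, accuracy


# Toy test set: 5 COVID and 15 normal images; one COVID image is missed.
y_true = ["covid"] * 5 + ["normal"] * 15
y_pred = ["covid"] * 4 + ["normal"] * 16
recall, accuracy = recall_and_accuracy(y_true, y_pred)
```

Missing a single COVID case costs 20 points of recall but only 5 points of accuracy, because the minority class contributes few test images.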
To obtain visual explanations of the network's classifications, we plot Grad-CAM []
class-discriminative feature maps for test images from all three classes. These maps are
shown in Figure 2 and highlight the image regions on which each decision was based.
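For reference, Grad-CAM [5] weights each feature map A^k of a convolutional layer by the spatially averaged gradient of the class score with respect to it, sums the weighted maps and applies a ReLU. The NumPy sketch below assumes the activations and gradients have already been extracted from the network (for example, with framework hooks), rescales the map to [0, 1] for display, and omits the upsampling to the 224x224 input size that is needed before overlaying the map on the X-ray.

```python
import numpy as np


def grad_cam(activations, gradients):
    """Grad-CAM heat map for one image and one target class.

    activations: (K, H, W) feature maps A^k of the chosen convolutional layer
    gradients:   (K, H, W) gradients of the class score y^c with respect to A^k
    Returns an (H, W) map rescaled to [0, 1] for display.
    """
    # One weight per channel: global average pooling of the gradients.
    alphas = gradients.mean(axis=(1, 2))
    # Weighted sum of the feature maps over channels, then ReLU.
    cam = np.maximum(np.tensordot(alphas, activations, axes=1), 0.0)
    peak = cam.max()
    return cam / peak if peak > 0 else cam
```

Only regions whose feature maps increase the class score survive the ReLU, which is what lets the heat maps in Figure 2 point to class-specific evidence.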
For the pneumonia class, the heat map highlights many regions, indicating that several
areas of the lungs are affected. This is consistent with medical findings [?] suggesting
that pneumonia does not affect the lungs in a localized manner.
For normal and COVID samples, the heat maps are concentrated in a specific region and
are similar to each other, which suggests that the network learns to distinguish between
these two classes based on local features.
The attention maps generated by the global branch and fed as input to the local branch
highlight similar affected regions, which supports the model's ability to explain its
predictions.
Fig. 2. Demonstrating GradCAM feature map (row 2) and model’s attention feature map (row
3) for the input images (row 1). Left is normal, middle is pneumonia and right is Covid.
4 Conclusion
We believe these techniques can be applied widely across other medical imaging
applications in which publicly available data are scarce. Especially in uncertain times
such as the ongoing pandemic, global access to open-source diagnosed images cannot be
taken for granted, and these techniques are therefore essential for building robust
deep learning models.
1. Bai, Harrison & Hsieh, Ben & Xiong, Zeng & Halsey, Kasey & Choi, Ji & Tran, Thi &
Pan, Ian & Shi, Lin-Bo & Wang, Dong-Cui & Mei, Ji & Jiang, Xiao-Long & Zeng,
Qiu-Hua & Egglin, Thomas & Hu, Ping-Feng & Agarwal, Saurabh & Xie, Fangfang & Li,
Sha & Healey, Terrance & Atalay, Michael & Liao, Wei-Hua: Performance of radiologists
in differentiating COVID-19 from viral pneumonia on chest CT. Radiology, 200823 (2020).
2. Wang, Xiaosong & Peng, Yifan & Lu, Le & Lu, Zhiyong & Bagheri, Mohammadhadi &
Summers, Ronald: ChestX-ray14: Hospital-scale Chest X-ray Database and Benchmarks
on Weakly-Supervised Classification and Localization of Common Thorax Diseases.
(2017).
3. Joseph Paul Cohen and Paul Morrison and Lan Dao: COVID-19 image data collection.
https://github.com/ieee8023/covid-chestxray-dataset, arXiv:2003.11597 (2020).
4. Guan, Qingji & Huang, Yaping & Zhong, Zhun & Zheng, Zhedong & Zheng, Liang &
Yang, Yi: Thorax Disease Classification with Attention Guided Convolutional Neural
Network. Pattern Recognition Letters 131. (2019).
5. R. R. Selvaraju, M. Cogswell, A. Das, R. Vedantam, D. Parikh and D. Batra: Grad-CAM:
Visual Explanations from Deep Networks via Gradient-Based Localization. In Proceedings
of the 2017 IEEE International Conference on Computer Vision, pp. 618-626. IEEE,
Venice (2017).
6. Rajpurkar, Pranav & Irvin, Jeremy & Zhu, Kaylie & Yang, Brandon & Mehta, Hershel &
Duan, Tony & Ding, Daisy & Bagul, Aarti & Langlotz, Curtis & Shpanskaya, Katie &
Lungren, Matthew & Ng, Andrew: CheXNet: Radiologist-Level Pneumonia Detection on
Chest X-Rays with Deep Learning. (2017).
7. K. Wang, X. Zhang and S. Huang: KGZNet: Knowledge-Guided Deep Zoom Neural
Networks for Thoracic Disease Classification. In Proceedings of the 2019 IEEE
International Conference on Bioinformatics and Biomedicine, pp. 1396-1401. IEEE, San
Diego (2019).
8. RSNA Dataset, https://www.kaggle.com/c/rsna-pneumonia-detection-challenge/data
1. Jaehyung Kim and Jongheon Jeong and Jinwoo Shin: M2m: Imbalanced Classification via
Major-to-minor Translation. arXiv:2003.11597 (2020).
2. Cui, Yin & Jia, Menglin & Lin, Tsung-Yi & Song, Yang & Belongie, Serge:
Class-Balanced Loss Based on Effective Number of Samples. In Proceedings of the 2019
IEEE Conference on Computer Vision and Pattern Recognition, pp. 9260-9269. (2019).
3. G. Huang, Z. Liu, L. Van Der Maaten and K. Q. Weinberger: Densely Connected
Convolutional Networks. In Proceedings of the 2017 IEEE Conference on Computer
Vision and Pattern Recognition, pp. 2261-2269. IEEE, Honolulu (2017).
9. Ozturk, T., Talo, M., Yildirim, E.A., Baloglu, U.B., Yildirim, O. and Acharya, U.R., 2020.
Automated detection of COVID-19 cases using deep neural networks with X-ray
images. Computers in Biology and Medicine, p.103792.
10. Oh, Y., Park, S. and Ye, J.C., 2020. Deep learning covid-19 features on cxr using limited
training data sets. IEEE Transactions on Medical Imaging.
11. Apostolopoulos, I.D. and Mpesiana, T.A., 2020. Covid-19: automatic detection from x-ray
images utilizing transfer learning with convolutional neural networks. Physical and
Engineering Sciences in Medicine, p.1.
12. Khan, A.I., Shah, J.L. and Bhat, M.M., 2020. Coronet: A deep neural network for
detection and diagnosis of COVID-19 from chest x-ray images. Computer Methods and
Programs in Biomedicine, p.105581.
13. C. Szegedy, V. Vanhoucke, S. Ioffe, J. Shlens, Z. Wojna, Rethinking the inception
architecture for computer vision. In: Proceedings of the IEEE Conference on Computer
Vision and Pattern Recognition 2016 (pp. 2818–2826).
14. Mahmud, T., Rahman, M.A. and Fattah, S.A., 2020. CovXNet: A multi-dilation
convolutional neural network for automatic COVID-19 and other pneumonia detection
from chest X-ray images with transferable multi-receptive feature optimization. Computers
in Biology and Medicine, p.103869.
15. Chawla, N.V., Bowyer, K.W., Hall, L.O., Kegelmeyer, W.P.: SMOTE: Synthetic Minority
Over-sampling Technique. Journal of Artificial Intelligence Research 16(1), 321–357
(2002).
16. Sonak, A. and Patankar, R.A., 2015. A survey on methods to handle imbalance
dataset. Int. J. Comput. Sci. Mobile Comput, 4(11), pp.338-343.
17. Drummond, C. and Holte, R.C., 2003, August. C4.5, class imbalance, and cost sensitivity:
why under-sampling beats over-sampling. In Workshop on learning from imbalanced
datasets II (Vol. 11, pp. 1-8). Washington DC: Citeseer.
18. Liu, X.Y., Wu, J. and Zhou, Z.H., 2008. Exploratory undersampling for class-imbalance
learning. IEEE Transactions on Systems, Man, and Cybernetics, Part B
(Cybernetics), 39(2), pp.539-550.
19. Ramyachitra, D. and Manikandan, P., 2014. Imbalanced dataset classification and
solutions: a review. International Journal of Computing and Business Research
(IJCBR), 5(4).
20. Elkan, C., 2001, August. The foundations of cost-sensitive learning. In International joint
conference on artificial intelligence (Vol. 17, No. 1, pp. 973-978). Lawrence Erlbaum
Associates Ltd.
21. Khan, S.H., Hayat, M., Bennamoun, M., Sohel, F.A. and Togneri, R., 2017. Cost-sensitive
learning of deep feature representations from imbalanced data. IEEE transactions on
neural networks and learning systems, 29(8), pp.3573-3587.
22. Mikołajczyk, A. and Grochowski, M., 2018, May. Data augmentation for improving deep
learning in image classification problem. In 2018 international interdisciplinary PhD
workshop (IIPhDW) (pp. 117-122). IEEE.
23. Perez, L. and Wang, J., 2017. The effectiveness of data augmentation in image
classification using deep learning. arXiv preprint arXiv:1712.04621.
24. Cao, Kaidi, et al. "Learning imbalanced datasets with label-distribution-aware margin
loss." Advances in Neural Information Processing Systems. 2019.
25. Wang, Feng, et al. "Additive margin softmax for face verification." IEEE Signal
Processing Letters 25.7 (2018): 926-930.
26. Guan, Qingji, et al. "Diagnose like a radiologist: Attention guided convolutional neural
network for thorax disease classification." arXiv preprint arXiv:1801.09927 (2018).
27. Wang, Y., Yao, Q., Kwok, J.T. and Ni, L.M., 2020. Generalizing from a few examples: A
survey on few-shot learning. ACM Computing Surveys (CSUR), 53(3), pp.1-34.
28. Redmon, J., Divvala, S., Girshick, R. and Farhadi, A., 2016. You only look once: Unified,
real-time object detection. In Proceedings of the IEEE conference on computer vision and
pattern recognition (pp. 779-788).