7 Best Techniques To Improve The Accuracy of CNN W/O Overfitting


!"#$%&'()*%+ ,-(./+%+01.

2./0"/3((4
,.5*678*696: ; <*=%+*/(.) ; -%&4(+

!"#$%&"'$()*+,-$%"'."/012.3$"')$"4((-25(6".7
899":;<"<3$27+&&+*=
>.4.&(4*?*@ABCD*:9

!'E4E*#5*,.44'(F*2(+/5*G/E=*H"/&4

While developing Convolutional Neural Networks (CNNs) to classify images, it is often observed that the model starts overfitting when we try to improve the accuracy. This is very frustrating, hence I list below the techniques that improve model performance without overfitting on the training data.

1. Data normalization
We normalized the image tensors by subtracting the mean and dividing by
the standard deviation of pixels across each channel. Normalizing the data
prevents the pixel values from any one channel from disproportionately
affecting the losses and gradients. Learn more

2. Data augmentation
We applied random transformations while loading images from the training
dataset. Specifically, we pad each image by 4 pixels, take a random crop of
size 32 x 32 pixels, and then flip the image horizontally with a 50%
probability (both techniques are sketched below). Learn more
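Here is a minimal sketch of what these first two techniques can look like with torchvision. The channel means/stds below are commonly quoted CIFAR-10 statistics and are an assumption here, not values taken from this article:

```python
import torchvision.transforms as T

# Commonly quoted CIFAR-10 channel statistics (an assumption;
# recompute them for your own data).
stats = ((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616))

train_tfms = T.Compose([
    T.RandomCrop(32, padding=4, padding_mode='reflect'),  # pad by 4 px, random 32x32 crop
    T.RandomHorizontalFlip(p=0.5),                        # flip with 50% probability
    T.ToTensor(),                                         # PIL image -> tensor in [0, 1]
    T.Normalize(*stats),                                  # (x - mean) / std per channel
])

# Validation images are only normalized, never augmented.
valid_tfms = T.Compose([T.ToTensor(), T.Normalize(*stats)])
```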

3. Batch normalization
After each convolutional layer, we added a batch normalization layer, which
normalizes the outputs of the previous layer. This is somewhat similar to data
normalization, except it’s applied to the outputs of a layer, and the mean and
standard deviation are learned parameters. Learn more
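As an illustration, a convolutional block with batch normalization might look like this (the layer sizes are placeholders, not this article's exact architecture):

```python
import torch.nn as nn

# BatchNorm2d normalizes the conv outputs channel-wise, with a
# learnable scale and shift per channel.
def conv_block(in_channels, out_channels):
    return nn.Sequential(
        nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
        nn.BatchNorm2d(out_channels),
        nn.ReLU(inplace=True),
    )
```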

4. Learning rate scheduling
Instead of using a fixed learning rate, we use a learning rate scheduler,
which changes the learning rate after every batch of training. There are
many strategies for varying the learning rate during training; we used
the “One Cycle Learning Rate Policy”. Learn more
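A minimal sketch of the One Cycle policy in PyTorch, assuming a `model` and a `train_dl` DataLoader exist; the peak of 0.01 matches the learning-rate curve discussed later in this article:

```python
import torch

optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

# OneCycleLR warms the learning rate up to max_lr and then anneals it;
# it must be stepped once per batch, not once per epoch.
sched = torch.optim.lr_scheduler.OneCycleLR(
    optimizer, max_lr=0.01, epochs=8, steps_per_epoch=len(train_dl))

# Inside the training loop, after each optimizer.step():
#     sched.step()
```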

5. Weight decay
We added weight decay to the optimizer, yet another regularization
technique, which prevents the weights from becoming too large by adding an
additional penalty term to the loss function (a combined sketch with
techniques 6 and 7 appears below). Learn more

6. Gradient clipping
We also added gradient clipping, which limits the values of gradients to a
small range to prevent undesirable changes in model parameters due to large
gradient values during training. Learn more

7. Adam optimizer
Instead of SGD (stochastic gradient descent), we used the Adam optimizer,
which uses techniques like momentum and adaptive learning rates for faster
training. There are many other optimizers to choose from and experiment
with. Learn more
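Taken together, techniques 5-7 amount to only a few lines in the training step. A hedged sketch (the hyperparameter values and the `model`/`train_dl` names are illustrative assumptions):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# Technique 7: Adam optimizer. Technique 5: weight decay via the
# built-in weight_decay argument.
optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=1e-4)

for images, labels in train_dl:
    loss = F.cross_entropy(model(images), labels)
    loss.backward()
    # Technique 6: clamp every gradient value to [-0.1, 0.1].
    nn.utils.clip_grad_value_(model.parameters(), clip_value=0.1)
    optimizer.step()
    optimizer.zero_grad()
```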

A CNN model without the above techniques gives an accuracy of about 75%. You can learn more about that model on this blog. I then built a CNN implementing all 7 techniques and improved the model accuracy to 90% without overfitting on the training set.

!"#$%#&''(
L6*BE$$EF(/&
4>.-&"?5&5%$&
>.4.*JK%(+K(*!/.K4%4%E+(/*M*,.K'%+(*-(./+%+0*M
The CIFAR-10 dataset (Canadian Institute For Advanced Research) is a collection N("/.$*N(4FE/O&*M*!5PE/K'*M*P(+&E/B$EF

of images that are commonly used to train machine learning and computer
BE$$EF
vision algorithms. It is one of the most widely used datasets for machine learning
research. The CIFAR-10 dataset contains 60,000 32x32 color images in 10
different classes. The 10 different classes represent airplanes, cars, birds, cats, )*#'+,#*-+)'./%-

deer, dogs, frogs, horses, ships, and trucks. There are 6,000 images of each class. Q.KE#*B(/"&

85*"*$-25@"*$&J.2V%"-*C$2%&5*C
+*&$=$2%W

'5>@$".7"(.*&$*&% C/.&'*R.=.+0%/

1. Introduction &25+*+*=Q"12$C+(&+.*Q"5*C
$35@-5&+.*".7"05()+*$"3+%+.*%
2. Data Pre Processing 0.C$@%"+*"35@+C5&+.*"0.C$
2.1 Loading the required libraries
C#'%&'(O*J%+0'

3. Applying Data normalization and Data augmentation E2$C+(&+*="P+*$":53$"<-&1-&"5*C


X+%-5@+H+*="&)$"?$$1"F$52*+*=
3.1 Building Data transforms 9$&J.2V
3.2 Applying the transforms to the dataset
J'.&'%O.+4*!/.&.)
4. Accessing few sample images
'2+1@$&"F.%%".*"/05=$9$&"?5&5%$&
5. Accessing GPU

6. Model configuration
4.1 Setting up Accuracy function and Image classification base class
2($S J4.4"& T/%4(/& H$E0 @./((/& !/%U.K5 P(/=& C#E"4
4.2 Implementing Batch normalization and Dropouts R+EF.#$(

4.3 Implementing Weight Decay, Gradient clipping, Adam optimizer


4.4 Moving the model to GPU

7. Training the model & Result Analysis


6.1 Setting up the parameters before training
6.2 Running the model for 8 epochs
6.3 Accuracy vs No of epochs
6.4 Loss vs epochs
6.5 Learning rate with batch number

8. Running prediction over test dataset

9. Summary

10. Future Work

11. References

№AB"/*&2.C-(&+.*
The CIFAR-10 dataset contains 60,000 32x32 color images in 10 different
classes. CIFAR-10 is a set of images that can be used to teach a FFNN how to
recognize objects. Since the images in CIFAR-10 are low-resolution (32x32), this
dataset can allow researchers to quickly try different algorithms to see what
works.

List of classes under the CIFAR 10 dataset —

1. Airplanes

2. Cars

3. Birds

4. Cats

5. Deer

6. Dogs

7. Frogs

8. Horses

9. Ships

10. Trucks

№DB"?5&5"E2$"E2.($%%+*=

F.5C+*="2$,-+2$C"@+>252+$%
Since we are using PyTorch to build the neural network. I import all the related
library in single go.
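A typical import cell for this kind of notebook might look as follows (a sketch; the original notebook's exact imports may differ slightly):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as T
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader
from torchvision.datasets import CIFAR10
from torchvision.utils import make_grid
```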

№GB"411@6+*="?5&5"*.205@+H5&+.*"5*C"?5&5"5-=0$*&5&+.*
Data normalization and data augmentation is implemented while loading the
dataset with in transforms functions.

#-+@C+*="?5&5"&25*%7.20%

411@6+*="&)$"&25*%7.20%"&."&)$"C5&5%$&

Let us access the training and validation set

Now I built 2 DataLoaders for testing and validation. To do this I use the
DataLoader method
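A minimal sketch of both steps, reusing the `train_tfms`/`valid_tfms` transforms from earlier (the batch size of 400 is an assumption; use whatever fits your GPU memory):

```python
# Attach the transforms while loading the dataset, so normalization and
# augmentation run on the fly every time an image is read.
train_ds = CIFAR10(root='data/', train=True, download=True, transform=train_tfms)
valid_ds = CIFAR10(root='data/', train=False, download=True, transform=valid_tfms)

batch_size = 400
train_dl = DataLoader(train_ds, batch_size, shuffle=True,
                      num_workers=2, pin_memory=True)
valid_dl = DataLoader(valid_ds, batch_size * 2,
                      num_workers=2, pin_memory=True)
```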

№IB"4(($%%+*="7$J"%501@$"+05=$%
Before we work on building the model let us also access sample images. Note the
data present in DataLoaders is in normalized form displaying such data wont
make any sense. Hence a demoralize function is developed to revert back the
changes so that original images can be accessed.
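A sketch of such a denormalize helper, together with a grid preview of the first batch (it assumes the `stats` tuple defined alongside the transforms):

```python
def denormalize(images, means, stds):
    # Undo the channel-wise (x - mean) / std so images are viewable again.
    means = torch.tensor(means).reshape(1, 3, 1, 1)
    stds = torch.tensor(stds).reshape(1, 3, 1, 1)
    return images * stds + means

def show_batch(dl):
    # Display the first batch of the DataLoader as an 8-wide image grid.
    for images, labels in dl:
        fig, ax = plt.subplots(figsize=(12, 12))
        ax.set_xticks([]); ax.set_yticks([])
        denorm = denormalize(images, *stats)
        ax.imshow(make_grid(denorm[:64], nrow=8).permute(1, 2, 0).clamp(0, 1))
        break

show_batch(train_dl)
```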

!/(&(+4*%+*4/.%+%+0*).4.&(4

№KB"4(($%%+*="LEM
You can use a Graphics Processing Unit (GPU) to train your models faster if your
execution platform is connected to a GPU manufactured by NVIDIA. Follow these
instructions to use a GPU on the platform of your choice:

Google Colab: Use the menu option “Runtime > Change Runtime Type” and
select “GPU” from the “Hardware Accelerator” dropdown.

Kaggle: In the “Settings” section of the sidebar, select “GPU” from the
“Accelerator” dropdown. Use the button on the top-right to open the sidebar.

Binder: Notebooks running on Binder cannot use a GPU, as the machines
powering Binder aren’t connected to any GPUs.

Linux: If your laptop/desktop has an NVIDIA GPU (graphics card), make sure
you have installed the NVIDIA CUDA drivers.

Windows: If your laptop/desktop has an NVIDIA GPU (graphics card), make
sure you have installed the NVIDIA CUDA drivers.

macOS: macOS is not compatible with NVIDIA GPUs.

If you do not have access to a GPU or aren’t sure what it is, don’t worry, you
can execute all the code in this tutorial just fine without a GPU.

Once the device is available, let us define helper functions to move our data and model to the GPU.
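The helpers below follow the device-handling pattern from the “Deep Learning with PyTorch: Zero to GANs” course listed in the references; treat them as a sketch rather than this article's exact code:

```python
def get_default_device():
    """Pick the GPU if available, else the CPU."""
    return torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

def to_device(data, device):
    """Move tensor(s) to the chosen device."""
    if isinstance(data, (list, tuple)):
        return [to_device(x, device) for x in data]
    return data.to(device, non_blocking=True)

class DeviceDataLoader():
    """Wrap a DataLoader to move each batch to the device as it is yielded."""
    def __init__(self, dl, device):
        self.dl = dl
        self.device = device

    def __iter__(self):
        for b in self.dl:
            yield to_device(b, self.device)

    def __len__(self):
        return len(self.dl)

device = get_default_device()
train_dl = DeviceDataLoader(train_dl, device)
valid_dl = DeviceDataLoader(valid_dl, device)
```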

№NB"O.C$@"8.*7+=-25&+.*

P$&&+*="-1"4((-25(6"7-*(&+.*"5*C"/05=$8@5%%+7+(5&+.*#5%$"(@5%%
The accuracy function provides the accuracy of the model by comparing the
model output against the labels. In ImageClassificationBase we have functions to
calculate ‘loss’ and helper function to calculate combined losses and accuracies
for each epochs.
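A sketch of both pieces, again following the pattern of the referenced course notebooks:

```python
def accuracy(outputs, labels):
    # Fraction of predictions that match the labels.
    _, preds = torch.max(outputs, dim=1)
    return torch.tensor(torch.sum(preds == labels).item() / len(preds))

class ImageClassificationBase(nn.Module):
    def training_step(self, batch):
        images, labels = batch
        out = self(images)                   # forward pass
        return F.cross_entropy(out, labels)  # training loss

    def validation_step(self, batch):
        images, labels = batch
        out = self(images)
        return {'val_loss': F.cross_entropy(out, labels).detach(),
                'val_acc': accuracy(out, labels)}

    def validation_epoch_end(self, outputs):
        # Combine per-batch losses and accuracies into epoch-level numbers.
        return {'val_loss': torch.stack([x['val_loss'] for x in outputs]).mean().item(),
                'val_acc': torch.stack([x['val_acc'] for x in outputs]).mean().item()}
```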

/01@$0$*&+*="#5&()"*.205@+H5&+.*"5*C"?2.1.-&%
We use nn.sequential to chain the layer’s of neural network. I implement Batch
normalization at the end of each layer. Also observe the implementation of
dropout of 20%.
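The article's exact architecture isn't reproduced here, but a representative model in this style might look like the following; the layer sizes are illustrative:

```python
class Cifar10CnnModel(ImageClassificationBase):
    def __init__(self):
        super().__init__()
        self.network = nn.Sequential(
            nn.Conv2d(3, 32, kernel_size=3, padding=1),
            nn.BatchNorm2d(32),   # batch norm after the conv layer
            nn.ReLU(),
            nn.MaxPool2d(2),      # 32x32 -> 16x16
            nn.Dropout(0.2),      # 20% dropout

            nn.Conv2d(32, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.MaxPool2d(2),      # 16x16 -> 8x8
            nn.Dropout(0.2),

            nn.Flatten(),
            nn.Linear(64 * 8 * 8, 10),  # 10 output classes
        )

    def forward(self, xb):
        return self.network(xb)
```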

/01@$0$*&+*=":$+=)&"?$(56Q"L25C+$*&"(@+11+*=Q"4C50".1&+0+H$2
I slightly modify the fit function to accept weight decay, gradient clipping
parameters. Note that fit function is generic function and can be used as it is in
other neural networks
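A sketch of the modified fit function, combining the One Cycle schedule with weight decay and gradient clipping:

```python
@torch.no_grad()
def evaluate(model, val_loader):
    model.eval()
    outputs = [model.validation_step(batch) for batch in val_loader]
    return model.validation_epoch_end(outputs)

def fit_one_cycle(epochs, max_lr, model, train_loader, val_loader,
                  weight_decay=0, grad_clip=None, opt_func=torch.optim.SGD):
    history = []
    optimizer = opt_func(model.parameters(), max_lr, weight_decay=weight_decay)
    sched = torch.optim.lr_scheduler.OneCycleLR(
        optimizer, max_lr, epochs=epochs, steps_per_epoch=len(train_loader))

    for epoch in range(epochs):
        model.train()
        for batch in train_loader:
            loss = model.training_step(batch)
            loss.backward()
            if grad_clip:
                nn.utils.clip_grad_value_(model.parameters(), grad_clip)
            optimizer.step()
            optimizer.zero_grad()
            sched.step()  # update the learning rate after every batch
        history.append(evaluate(model, val_loader))
    return history
```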

O.3+*="&)$"0.C$@"&."LEM

Before we train let us check how model performs

I observe an accuracy of about 10%. As we have ten classes hence the chance of
getting the prediction right is 1 out of 10 classes. Hence the model seems to be
randomly guessing.

№!B"'25+*+*="&)$"0.C$@"R"S$%-@&"4*5@6%+%

P$&&+*="-1"&)$"15250$&$2%">$7.2$"&25+*+*=
I pass the following parameter before I start training. Try experimenting with
different parameter. You can initially start with large values and switch to
smaller values when your model start achieving higher accuracy values.
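For reference, the configuration can look like this. The epochs = 8 and max_lr = 0.01 values are stated elsewhere in this article; the grad_clip and weight_decay values are plausible assumptions, not the article's exact numbers:

```python
model = to_device(Cifar10CnnModel(), device)  # move the model to the GPU

epochs = 8
max_lr = 0.01
grad_clip = 0.1        # an assumption
weight_decay = 1e-4    # an assumption
opt_func = torch.optim.Adam

history = fit_one_cycle(epochs, max_lr, model, train_dl, valid_dl,
                        weight_decay=weight_decay, grad_clip=grad_clip,
                        opt_func=opt_func)
```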

S-**+*="&)$"0.C$@"7.2"T"$1.()%

If you observe the above results. We achieved the accuracy of about 90% by just
training the model under 3 mins. This is awesome.

4((-25(6"3%"9.".7"$1.()%

F.%%"3%"$1.()%
We can also plot the training and validation losses to study the trend.

It’s clear from the trend that our model isn’t overfitting to the training data just
yet.

F$52*+*="25&$"J+&)">5&()"*-0>$2

As expected, the learning rate starts at a low value, and gradually increases for
30% of the iterations to a maximum value of 0.01 , and then gradually decreases

to a very small value.

№TB"'$%&+*="J+&)"+*C+3+C-5@"+05=$%
While we have been tracking the overall accuracy of a model so far, it’s also a
good idea to look at model’s results on some sample images. Let’s test out our
model with some images from the predefined test dataset of 10000 images.
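A sketch of a single-image prediction helper, reusing `to_device` and `device` from the GPU section (it assumes the model has already been moved to the device):

```python
def predict_image(img, model):
    xb = to_device(img.unsqueeze(0), device)  # add a batch dimension
    _, preds = torch.max(model(xb), dim=1)
    return valid_ds.classes[preds[0].item()]

img, label = valid_ds[0]
print('Label:', valid_ds.classes[label],
      '- Predicted:', predict_image(img, model))
```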

The model seems to be performing really well. Some images are difficult to
identify even with human eyes, but the model does a great job of classifying
those as well.

№UB"P-00526
Here is the brief summary of the article and step by step process we followed in
applying all the 7 techniques to improve the model performance.

1. We briefly learned about all 7 techniques:
- Data normalization
- Data augmentation
- Batch normalization
- Learning rate scheduling
- Weight decay
- Gradient clipping
- Adam optimizer

2. We learned to build transforms and implemented data normalization and
data augmentation within those.

3. We learned how to access the GPU and load the training and validation
datasets onto it.

4. We implemented batch normalization and dropout while constructing the
neural net.

5. We implemented weight decay, gradient clipping, and the Adam optimizer
while training the model.

6. We achieved an accuracy of 90% with 3 mins of training.

7. We spot-checked the model performance by running it on a few test
samples.

№AYB"Z-&-2$":.2V
1. Try removing batch normalization, data augmentation and dropouts one by
one to study their effect on overfitting.

2. Build the model using TensorFlow and try implementing all these
techniques.

3. Our model trained to over 90% accuracy in under 3 minutes! Try playing
around with the data augmentations, network architecture &
hyperparameters to achieve the following results:
- 94% accuracy in under 10 minutes (easy)
- 90% accuracy in under 2 minutes (intermediate)
- 94% accuracy in under 5 minutes (hard)

№AAB"S$7$2$*($%
1. You can access and execute the complete notebook on this link —
https://jovian.ai/hargurjeet/cfar-10-dataset-6e9d9

2. https://pytorch.org/

3. https://jovian.ai/learn/deep-learning-with-pytorch-zero-to-gans

With this, we complete the series of all 3 articles: we started by building an
ANN, then moved to a CNN for better performance, and finally applied some
techniques to further enhance the model performance. You can access my
previous blogs here — blog1, blog2

I really hope you learned something from this post. Feel free to applaud if you
liked what you learnt. Let me know if there is anything you need my help with.

')5*V%"7.2"2$5C+*="&)+%"52&+(@$["\5116"F$52*+*="