(Advances in Intelligent Systems and Computing 1232) M. Arif Wani, Taghi M. Khoshgoftaar, Vasile Palade - Deep Learning Applications, Volume 2-Springer Singapore_Springer (2021)
M. Arif Wani
Taghi M. Khoshgoftaar
Vasile Palade Editors
Deep Learning
Applications,
Volume 2
Advances in Intelligent Systems and Computing
Volume 1232
Series Editor
Janusz Kacprzyk, Systems Research Institute, Polish Academy of Sciences,
Warsaw, Poland
Advisory Editors
Nikhil R. Pal, Indian Statistical Institute, Kolkata, India
Rafael Bello Perez, Faculty of Mathematics, Physics and Computing,
Universidad Central de Las Villas, Santa Clara, Cuba
Emilio S. Corchado, University of Salamanca, Salamanca, Spain
Hani Hagras, School of Computer Science and Electronic Engineering,
University of Essex, Colchester, UK
László T. Kóczy, Department of Automation, Széchenyi István University,
Gyor, Hungary
Vladik Kreinovich, Department of Computer Science, University of Texas
at El Paso, El Paso, TX, USA
Chin-Teng Lin, Department of Electrical Engineering, National Chiao
Tung University, Hsinchu, Taiwan
Jie Lu, Faculty of Engineering and Information Technology,
University of Technology Sydney, Sydney, NSW, Australia
Patricia Melin, Graduate Program of Computer Science, Tijuana Institute
of Technology, Tijuana, Mexico
Nadia Nedjah, Department of Electronics Engineering, University of Rio de Janeiro,
Rio de Janeiro, Brazil
Ngoc Thanh Nguyen , Faculty of Computer Science and Management,
Wrocław University of Technology, Wrocław, Poland
Jun Wang, Department of Mechanical and Automation Engineering,
The Chinese University of Hong Kong, Shatin, Hong Kong
The series “Advances in Intelligent Systems and Computing” contains publications
on theory, applications, and design methods of Intelligent Systems and Intelligent
Computing. Virtually all disciplines such as engineering, natural sciences, computer
and information science, ICT, economics, business, e-commerce, environment,
healthcare, life science are covered. The list of topics spans all the areas of modern
intelligent systems and computing such as: computational intelligence, soft comput-
ing including neural networks, fuzzy systems, evolutionary computing and the fusion
of these paradigms, social intelligence, ambient intelligence, computational neuro-
science, artificial life, virtual worlds and society, cognitive science and systems,
Perception and Vision, DNA and immune based systems, self-organizing and
adaptive systems, e-Learning and teaching, human-centered and human-centric
computing, recommender systems, intelligent control, robotics and mechatronics
including human-machine teaming, knowledge-based paradigms, learning para-
digms, machine ethics, intelligent data analysis, knowledge management, intelligent
agents, intelligent decision making and support, intelligent network security, trust
management, interactive entertainment, Web intelligence and multimedia.
The publications within “Advances in Intelligent Systems and Computing” are
primarily proceedings of important conferences, symposia and congresses. They
cover significant recent developments in the field, both of a foundational and
applicable character. An important characteristic feature of the series is the short
publication time and world-wide distribution. This permits a rapid and broad
dissemination of research results.
** Indexing: The books of this series are submitted to ISI Proceedings,
EI-Compendex, DBLP, SCOPUS, Google Scholar and Springerlink **
Editors

M. Arif Wani, Department of Computer Science, University of Kashmir, Srinagar, India

Taghi M. Khoshgoftaar, Computer and Electrical Engineering, Florida Atlantic University, Boca Raton, FL, USA

Vasile Palade, Faculty of Engineering and Computing, Coventry University, Coventry, UK
© The Editor(s) (if applicable) and The Author(s), under exclusive license to Springer Nature
Singapore Pte Ltd. 2021
This work is subject to copyright. All rights are solely and exclusively licensed by the Publisher, whether
the whole or part of the material is concerned, specifically the rights of translation, reprinting, reuse of
illustrations, recitation, broadcasting, reproduction on microfilms or in any other physical way, and
transmission or information storage and retrieval, electronic adaptation, computer software, or by similar
or dissimilar methodology now known or hereafter developed.
The use of general descriptive names, registered names, trademarks, service marks, etc. in this
publication does not imply, even in the absence of a specific statement, that such names are exempt from
the relevant protective laws and regulations and therefore free for general use.
The publisher, the authors and the editors are safe to assume that the advice and information in this
book are believed to be true and accurate at the date of publication. Neither the publisher nor the
authors or the editors give a warranty, expressed or implied, with respect to the material contained
herein or for any errors or omissions that may have been made. The publisher remains neutral with regard
to jurisdictional claims in published maps and institutional affiliations.
This Springer imprint is published by the registered company Springer Nature Singapore Pte Ltd.
The registered company address is: 152 Beach Road, #21-01/04 Gateway East, Singapore 189721,
Singapore
Preface
Machine learning algorithms have influenced many aspects of our day-to-day living
and transformed major industries around the world. Fueled by an exponential
growth of data, improvements in computer hardware, scalable cloud resources, and
accessible open-source frameworks, machine learning technology is being used by
companies big and small alike for innumerable applications. At home, machine
learning models are suggesting TV shows, movies, and music for entertainment,
providing personalized e-commerce suggestions, shaping our digital social networks,
and improving the efficiency of our appliances. At work, these data-driven
methods are filtering our emails, forecasting trends in productivity and sales, tar-
geting customers with advertisements, improving the quality of video conferences,
and guiding critical decisions. At the frontier of machine learning innovation are
deep learning systems, a class of multi-layered networks capable of automatically
learning meaningful hierarchical representations from a variety of structured and
unstructured data. Breakthroughs in deep learning allow us to generate new rep-
resentations, extract knowledge, and draw inferences from raw images, video
streams, text and speech, time series, and other complex data types. These powerful
deep learning methods are being applied to new and exciting real-world problems in
medical diagnostics, factory automation, public safety, environmental sciences,
autonomous transportation, military applications, and much more.
The family of deep learning architectures continues to grow as new methods and
techniques are developed to address a wide variety of problems. A deep learning
network is composed of multiple layers that form universal approximators capable
of learning any function. For example, the convolutional layers in Convolutional
Neural Networks use shared weights and spatial invariance to efficiently learn
hierarchical representations from images, natural language, and temporal data.
Recurrent Neural Networks use backpropagation through time to learn from vari-
able length sequential data. Long Short-Term Memory networks are a type of
recurrent network capable of learning order dependence in sequence prediction
problems. Deep Belief Networks, Autoencoders, and other unsupervised models
generate meaningful latent features for downstream tasks and model the underlying
concepts of distributions by reconstructing their inputs. Generative Adversarial
conferences and has given many invited talks at various venues. Also, he has served
as North American Editor of the Software Quality Journal, was on the editorial
boards of the journals Multimedia Tools and Applications, Knowledge and
Information Systems, and Empirical Software Engineering, and is on the editorial
boards of the journals Software Quality, Software Engineering and Knowledge
Engineering, and Social Network Analysis and Mining.
Contributors
Abstract The term “information overload” has gained popularity over the last few
years. It defines the difficulties people face in finding what they want from a huge
volume of available information. Recommender systems have been recognized to be
an effective solution to such issues, such that suggestions are made based on users’
preferences. This chapter introduces an application of deep learning techniques in
the domain of recommender systems. Generally, collaborative filtering approaches,
and Matrix Factorization (MF) techniques in particular, are widely known for their
convincing performance in recommender systems. We introduce a Collaborative
Attentive Autoencoder (CATA) that improves the matrix factorization performance
by leveraging an item’s contextual data. Specifically, CATA learns the proper features
from scientific articles through the attention mechanism that can capture the most
pertinent parts of information in order to make better recommendations. The learned
features are then incorporated into the learning process of MF. Comprehensive
experiments on three real-world datasets have shown that our method performs better
than other state-of-the-art methods according to various evaluation metrics. The source code of
our model is available at: https://github.com/jianlin-cheng/CATA.
This chapter is an extended version of our paper published at the IEEE ICMLA conference 2019
[1], and incorporates new experimental contributions compared to the original conference paper.
© The Editor(s) (if applicable) and The Author(s), under exclusive license
to Springer Nature Singapore Pte Ltd. 2021
M. A. Wani et al. (eds.), Deep Learning Applications, Volume 2,
Advances in Intelligent Systems and Computing 1232,
https://doi.org/10.1007/978-981-15-6759-9_1
2 M. Alfarhood and J. Cheng
1 Introduction
The era of e-commerce has vastly changed people’s lifestyles during the first part
of the twenty-first century. People today tend to do many of their daily routines
online, such as shopping, reading the news, and watching movies. Nevertheless,
consumers often face difficulties discovering relevant items, such as new fashion
trends, because they are unaware that those items exist amid the overwhelming amount
of information available online. This phenomenon is widely known as “information
overload”. Therefore, Recommender Systems (RSs) are a critical solution for helping
users make decisions when there are lots of choices. RSs have been integrated into
and have become an essential part of every website due to their impact on increasing
customer interactions, attracting new customers, and growing businesses’ revenue.
Scientific article recommendation is a very common application for RSs. It keeps
researchers updated on recent related work in their field. One traditional way to
find relevant articles is to go through the references section of other articles. Yet,
this approach is biased toward heavily cited articles, so newer relevant articles,
even high-impact ones, have less chance of being found. Another method is to search for
articles using keywords. Although this technique is popular among researchers, they
must filter out a tremendous number of articles from the search results to retrieve
the most suitable articles. Moreover, all users get the same search results with the
same keywords, and these results are not personalized based on the users’ personal
interests. Thus, recommendation systems can address this issue and help scientists
and researchers find valuable articles while being aware of recent related work.
Over the last few decades, a lot of effort has been made by both academia and
industry to propose new ideas and solutions for RSs, which ultimately helps service
providers adopt such models in their system architectures. The research in
RSs has evolved remarkably following the Netflix prize competition1 in 2006, where
the company offered one million dollars for any team that could improve their rec-
ommendation accuracy by 10%. Since that time, collaborative filtering models and
matrix factorization techniques in particular have become the most common models
due to their effective performance. Generally, recommendation models are classified
into three categories: Collaborative Filtering Models (CF), Content-Based Filter-
ing models (CBF), and hybrid models. CF models [2–4] focus on users’ histories,
such that users with similar past behaviors tend to have similar future tastes. On the
other hand, CBF models work by learning the item’s features from its informational
description, such that two items are possibly similar to each other if they share more
characteristics. For example, two songs are similar to each other if they both share
the same artist, genre, tempo, energy, etc. However, similarities between items in CF
models are different such that two items are likely similar to each other once they are
rated by multiple users in the same manner, even though those items have different
characteristics.
1 www.netflixprize.com.
Deep Learning-Based Recommender Systems 3
2 www.citeulike.org.
against multiple recent works. The experimental results prove that our model can
extract more constructive information from an article’s contextual data than other
models. More importantly, CATA performs very well where the data sparsity is
extremely high.
The remainder of this chapter is organized in the following manner. First, we
demonstrate the matrix factorization method in Sect. 2. We introduce our model,
CATA, in Sect. 3. The experimental results of our model against the state-of-the-art
models are discussed thoroughly in Sect. 4. We then conclude our work in Sect. 5.
2 Background
Matrix Factorization (MF) [2] is the most popular CF method, mainly due to its
simplicity and efficiency. The idea behind MF is to decompose the user-item matrix,
R ∈ R^{n×m}, into two lower-dimensional matrices, U ∈ R^{n×d} and V ∈ R^{m×d}, such
that the inner product of U and V approximates the original matrix R. Here d
is the dimension of the latent factors, with d ≪ min(n, m), and n and m correspond
to the number of users and items in the system. Figure 1 illustrates the MF process.
R \approx U \cdot V^T    (1)
The standard MF objective minimizes the regularized squared error over the observed ratings:

L = \frac{1}{2} \sum_{i,j} I_{ij} (r_{ij} - u_i v_j^T)^2 + \frac{\lambda_u}{2} \|U\|^2 + \frac{\lambda_v}{2} \|V\|^2    (2)

where I_{ij} is an indicator function that equals 1 if user_i has rated item_j, and 0
otherwise. Also, ||U|| and ||V|| are the Euclidean norms, and λ_u, λ_v are two
regularization terms that prevent the values of U and V from growing too large,
which avoids model overfitting.
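The regularized MF objective can be evaluated in a few lines of NumPy. This is an illustrative sketch on toy data, not the chapter's implementation; the sizes `n`, `m`, `d` and the λ values are arbitrary choices.

```python
import numpy as np

def mf_loss(R, I, U, V, lam_u, lam_v):
    """Regularized MF objective: squared error on observed entries
    plus L2 penalties on both factor matrices."""
    E = I * (R - U @ V.T)              # errors, masked to observed entries
    return (0.5 * np.sum(E ** 2)
            + 0.5 * lam_u * np.sum(U ** 2)
            + 0.5 * lam_v * np.sum(V ** 2))

rng = np.random.default_rng(0)
n, m, d = 4, 5, 2                      # toy sizes: users, items, latent dim
U = rng.normal(size=(n, d))
V = rng.normal(size=(m, d))
R = U @ V.T                            # ratings generated by the true factors
I = (rng.random((n, m)) < 0.5).astype(float)   # observed-entry indicator

# With the true factors, the squared-error term vanishes and only the
# regularization terms remain.
assert np.isclose(mf_loss(R, I, U, V, 0.1, 0.1),
                  0.05 * np.sum(U ** 2) + 0.05 * np.sum(V ** 2))
```

Note that the indicator mask `I` is what restricts the fit to observed ratings; without it, the unobserved zeros would dominate the loss.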
Explicit data, such as ratings (r_{ij}), are not always available. Therefore, Weighted
Regularized Matrix Factorization (WRMF) [9] introduces two modifications to the
previous objective function to make it work for implicit feedback. The optimization
process in this case runs through all user-item pairs with different confidence levels
assigned to each pair, as in the following:
L = \sum_{i,j \in R} \frac{c_{ij}}{2} (p_{ij} - u_i v_j^T)^2 + \frac{\lambda_u}{2} \|u_i\|^2 + \frac{\lambda_v}{2} \|v_j\|^2    (3)
where p_{ij} is the user preference score, with a value of 1 when user_i and item_j have
an interaction, and 0 otherwise. c_{ij} is a confidence variable whose value reflects
how confident we are that user_i likes item_j. In general, c_{ij} = a when p_{ij} = 1, and
c_{ij} = b when p_{ij} = 0, such that a > b > 0.
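The preference/confidence construction can be sketched as follows; the values a = 1 and b = 0.01 are illustrative choices, not the chapter's tuned settings.

```python
import numpy as np

def wrmf_inputs(R, a=1.0, b=0.01):
    """Build WRMF inputs from an implicit-feedback matrix:
    p_ij = 1 iff an interaction exists; c_ij = a for positives and
    b for unobserved entries, with a > b > 0."""
    P = (R > 0).astype(float)
    C = np.where(P == 1.0, a, b)
    return P, C

R = np.array([[3.0, 0.0],
              [0.0, 1.0]])             # e.g. interaction counts
P, C = wrmf_inputs(R)
# → P = [[1, 0], [0, 1]], C = [[1.0, 0.01], [0.01, 1.0]]
```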
Stochastic Gradient Descent (SGD) [10] and Alternating Least Squares (ALS) [11]
are two optimization methods that can be used to minimize the objective function
of MF in Eq. 2. The first method, SGD, loops over each training sample and
computes the prediction error as e_{ij} = r_{ij} - u_i v_j^T. The gradient of the objective
function with respect to u_i and v_j can be computed as follows:
\frac{\partial L}{\partial u_i} = -\sum_j I_{ij} (r_{ij} - u_i v_j^T) v_j + \lambda_u u_i

\frac{\partial L}{\partial v_j} = -\sum_i I_{ij} (r_{ij} - u_i v_j^T) u_i + \lambda_v v_j    (4)
After calculating the gradient, SGD updates the user and item latent factors in the
opposite direction of the gradient using the following equations:
u_i \leftarrow u_i + \alpha \Big( \sum_j I_{ij} e_{ij} v_j - \lambda_u u_i \Big)

v_j \leftarrow v_j + \alpha \Big( \sum_i I_{ij} e_{ij} u_i - \lambda_v v_j \Big)    (5)
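The SGD updates of Eqs. 4–5 translate directly into a training loop. The sketch below uses toy data and arbitrary hyperparameters (`d`, `alpha`, `lam`, `epochs`), not the chapter's experimental setup.

```python
import numpy as np

def sgd_mf(R, I, d=2, alpha=0.01, lam=0.1, epochs=200, seed=0):
    """Loop over observed (i, j) pairs, compute e_ij = r_ij - u_i v_j^T,
    then move u_i and v_j against the gradient (Eq. 5)."""
    rng = np.random.default_rng(seed)
    n, m = R.shape
    U = 0.1 * rng.normal(size=(n, d))
    V = 0.1 * rng.normal(size=(m, d))
    observed = list(zip(*np.nonzero(I)))
    for _ in range(epochs):
        for i, j in observed:
            e = R[i, j] - U[i] @ V[j]
            u_old = U[i].copy()        # use the pre-update u_i for v_j's step
            U[i] += alpha * (e * V[j] - lam * U[i])
            V[j] += alpha * (e * u_old - lam * V[j])
    return U, V

R = np.array([[5.0, 3.0, 0.0],
              [4.0, 0.0, 1.0],
              [0.0, 1.0, 5.0]])
I = (R > 0).astype(float)
U, V = sgd_mf(R, I)
# Training error on observed entries falls well below the all-zero baseline.
assert np.sum((I * (R - U @ V.T)) ** 2) < np.sum(R ** 2)
```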
For implicit feedback, the gradient of the WRMF objective (Eq. 3) with respect to u_i is

\frac{\partial L}{\partial u_i} = -\sum_j c_{ij} (p_{ij} - u_i v_j^T) v_j + \lambda_u u_i

Writing this in matrix form, with C_i the diagonal matrix of confidences c_{ij} for user_i and P_i the vector of preferences p_{ij}, and setting the gradient to zero gives

0 = -V C_i P_i + (V C_i V^T) u_i + \lambda_u u_i    (6)

V C_i P_i = (V C_i V^T + \lambda_u I) u_i

u_i = (V C_i V^T + \lambda_u I)^{-1} V C_i P_i

and, symmetrically, for the item latent factors:

v_j = (U C_j U^T + \lambda_v I)^{-1} U C_j P_j    (7)
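The closed-form user update of Eq. 7 can be sketched in NumPy. Note one convention difference: here `V` is stored row-wise (m × d), so the book's V C_i V^T becomes V^T C_i V; the sizes and confidence values are toy assumptions.

```python
import numpy as np

def als_user_update(V, C_i, P_i, lam_u):
    """Solve (V^T C_i V + lam_u I) u_i = V^T C_i P_i for one user,
    where C_i holds the confidences c_ij and P_i the preferences p_ij."""
    d = V.shape[1]
    A = V.T @ (C_i[:, None] * V) + lam_u * np.eye(d)   # diag(C_i) applied cheaply
    b = V.T @ (C_i * P_i)
    return np.linalg.solve(A, b)

rng = np.random.default_rng(1)
V = rng.normal(size=(6, 3))            # 6 items, 3 latent dimensions
P_i = (rng.random(6) < 0.5).astype(float)
C_i = np.where(P_i == 1.0, 1.0, 0.01)  # a = 1, b = 0.01

u_i = als_user_update(V, C_i, P_i, lam_u=0.1)

# u_i satisfies the normal equations derived in Eq. 6.
A = V.T @ (C_i[:, None] * V) + 0.1 * np.eye(3)
assert np.allclose(A @ u_i, V.T @ (C_i * P_i))
```

Using `np.linalg.solve` instead of forming the explicit inverse is the standard, numerically safer way to apply the (·)^{-1} in the update.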
3 Proposed Model
In this section, we illustrate our proposed model in depth. The intuition behind our
model is to learn the latent factors of items in PMF with the use of available side
textual contents. We use an attentive unsupervised model to capture richer
information from the available data. The architecture of our model is displayed in
Fig. 2. We first define the problem with implicit feedback before we go through the
details of our model.
[Fig. 2 The architecture of CATA. The encoder maps the article content X_j through hidden layers; a softmax attention step produces the compressed representation Z_j; the decoder reconstructs X̂_j. Z_j is coupled with the item latent factor V_j, which interacts with the user factor U_i (regularized by λ_u and λ_v) to model R_ij, for i = 1:n and j = 1:m.]
If we treat all missing data as unobserved, without including any negative
feedback in the model training, the resulting model is probably useless,
since it is trained only on positive data. As a result, sampling negative feedback
from the unobserved data is one practical solution to this problem, which has been
proposed by [12]. In addition, Weighted Regularized Matrix Factorization (WRMF)
[9] is another proposed solution that introduces a confidence variable that works as
a weight to measure how likely a user is to like an item.
In general, the recommendation problem with implicit data is usually formulated
as follows:
R_{ij} = \begin{cases} 1, & \text{if there is a user-item interaction} \\ 0, & \text{otherwise} \end{cases}    (8)
where the ones in implicit feedback represent all the positive feedback. However,
it is important to note that a value of 0 does not always imply negative feedback;
it may be that users are simply not aware of the existence of those items. In addition,
the user-item interaction matrix (R) is usually highly imbalanced, such that the
number of the observed interactions is much less than the number of the unobserved
interactions. In other words, matrix R is very sparse, meaning that users only interact
explicitly or implicitly with a very small number of items compared to the total
number of items in this matrix. Sparsity is one frequent problem in RSs, which brings
a real challenge for any proposed model to have the capability to provide effective
personalized recommendations under this situation. The following sections explain
our methodology, where we aim to eliminate the influence of the aforementioned
problems.
An autoencoder [13] is an unsupervised neural network that is useful for compressing
high-dimensional input data into a lower-dimensional representation while
preserving the abstract nature of the data. The autoencoder network is generally
composed of two main components: the encoder and the decoder. The encoder
takes the input, encodes it through multiple hidden layers, and generates a
compressed representative vector, Z_j; the encoding function can be formulated as
Z_j = f(X_j). Subsequently, the decoder reconstructs an estimate of the original
input, X̂_j, from the representative vector: X̂_j = g(Z_j). The encoder and the
decoder usually consist of the same number of hidden layers and neurons.
The output of each hidden layer is computed as follows:
h^{(\ell)} = \sigma\big( W^{(\ell)} h^{(\ell-1)} + b^{(\ell)} \big)    (9)

where \ell is the layer number, W^{(\ell)} is the weight matrix, b^{(\ell)} is the bias vector, and \sigma
is a non-linear activation function. We use the Rectified Linear Unit (ReLU) as the
activation function.
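One hidden layer of the form h = σ(W h + b) with a ReLU activation can be sketched as below; the shapes and random weights are toy illustrations, not the trained network.

```python
import numpy as np

def relu(x):
    return np.maximum(0.0, x)

def hidden_layer(h_prev, W, b):
    """One autoencoder layer: h = ReLU(W h_prev + b)."""
    return relu(W @ h_prev + b)

rng = np.random.default_rng(0)
x = rng.random(8)                      # toy normalized bag-of-words input
W1, b1 = rng.normal(size=(4, 8)), np.zeros(4)   # encoder layer: 8 -> 4
h1 = hidden_layer(x, W1, b1)
assert h1.shape == (4,) and np.all(h1 >= 0)     # ReLU outputs are non-negative
```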
Our model takes as input the article's textual data, X_j = {x_1, x_2, ..., x_s},
where each x_i is a value in [0, 1] and s is the vocabulary size of the articles'
titles and abstracts. In other words, the input of our autoencoder network is
a normalized bag-of-words histogram of the filtered vocabulary of each article's
title and abstract.
Batch Normalization (BN) [14] has been proven to be a proper solution for the
internal covariate shift problem, where the distribution of a layer's inputs in deep
neural networks changes over the course of training and makes the model difficult
to train. In addition, BN can act as a regularization procedure, like Dropout [15],
in deep neural networks. Accordingly, we apply a batch normalization layer after
each hidden layer in our autoencoder to obtain a stable distribution of each layer's
output.
Furthermore, we use the idea of the attention mechanism to work between the
encoder and the decoder, such that only the relevant parts of the encoder output are
selected for the input reconstruction. Attention in deep learning can be described
simply as a vector of weights that reflects the importance of the input elements.
The intuition behind attention is that not all parts of the input are equally
significant; only a few parts matter to the model. We first calculate the attention
scores as the probability distribution of the encoder's output using the softmax(·) function:
f(z_c) = \frac{e^{z_c}}{\sum_d e^{z_d}}    (10)
The probability distribution and the encoder output are then multiplied
element-wise to obtain Z_j.
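This attention step can be sketched as: score the encoder output with a softmax (Eq. 10), then reweight the output element-wise. The max-shift inside the softmax is a standard numerical-stability assumption, not part of the chapter.

```python
import numpy as np

def softmax(z):
    """Eq. 10, shifted by max(z) for numerical stability."""
    e = np.exp(z - z.max())
    return e / e.sum()

def attend(enc_out):
    """Element-wise product of the encoder output with its own
    softmax scores, yielding the attended representation Z_j."""
    return softmax(enc_out) * enc_out

enc_out = np.array([2.0, 0.5, -1.0, 0.0])
Z_j = attend(enc_out)
assert np.isclose(softmax(enc_out).sum(), 1.0)   # scores form a distribution
assert Z_j.argmax() == enc_out.argmax()          # strongest unit keeps most weight
```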
We use the attentive autoencoder to pretrain on the items' contextual information
and then integrate the compressed representation, Z_j, into the computation of the
items' latent factors, V_j, in the matrix factorization method. The dimensions of Z_j
and V_j are set equal to each other. Finally, we adopt the binary cross-entropy
(Eq. 11) as the loss function we minimize in our attentive autoencoder model.
L = -\sum_k \big[ y_k \log(p_k) + (1 - y_k) \log(1 - p_k) \big]    (11)
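Eq. 11 in a few lines of NumPy; the clipping constant `eps` is a numerical-safety assumption, not part of the chapter.

```python
import numpy as np

def bce(y, p, eps=1e-12):
    """Binary cross-entropy summed over the reconstructed entries."""
    p = np.clip(p, eps, 1.0 - eps)     # guard against log(0)
    return -np.sum(y * np.log(p) + (1.0 - y) * np.log(1.0 - p))

y = np.array([1.0, 0.0])
# A maximally uncertain reconstruction (p = 0.5) costs log(2) per entry.
assert np.isclose(bce(y, np.array([0.5, 0.5])), 2.0 * np.log(2.0))
```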
u_i \sim N(0, \lambda_u^{-1} I)

v_j \sim N(0, \lambda_v^{-1} I)    (12)

p_{ij} \sim N(u_i v_j^T, \sigma^2)
We integrate the items’ contents, trained through the attentive autoencoder, into
PMF. Therefore, the objective function in Eq. 3 has been changed slightly to become
L = \sum_{i,j \in R} \frac{c_{ij}}{2} (p_{ij} - u_i v_j^T)^2 + \frac{\lambda_u}{2} \|u_i\|^2 + \frac{\lambda_v}{2} \|v_j - \theta(X_j)\|^2    (13)
where θ (X j ) = Encoder (X j ) = Z j .
Thus, taking the partial derivative of this objective function with respect to
both u_i and v_j yields the following update equations that minimize the
objective:
u_i = (V C_i V^T + \lambda_u I)^{-1} V C_i P_i

v_j = (U C_j U^T + \lambda_v I)^{-1} (U C_j P_j + \lambda_v \theta(X_j))    (14)
We optimize the values of u i and v j using the Alternating Least Squares (ALS)
optimization method.
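The content-aware item update of Eq. 14 differs from plain ALS only in the λ_v θ(X_j) offset. A sketch with `U` stored row-wise (n × d) and toy values; `theta_Xj` stands in for the pretrained encoder output Z_j.

```python
import numpy as np

def als_item_update(U, C_j, P_j, theta_Xj, lam_v):
    """v_j = (U^T C_j U + lam_v I)^{-1} (U^T C_j P_j + lam_v theta(X_j)),
    where theta(X_j) is the encoder output Z_j for article j."""
    d = U.shape[1]
    A = U.T @ (C_j[:, None] * U) + lam_v * np.eye(d)
    b = U.T @ (C_j * P_j) + lam_v * theta_Xj     # content acts as a prior mean
    return np.linalg.solve(A, b)

rng = np.random.default_rng(2)
U = rng.normal(size=(5, 3))            # 5 users, 3 latent dimensions
P_j = (rng.random(5) < 0.4).astype(float)
C_j = np.where(P_j == 1.0, 1.0, 0.01)
theta_Xj = rng.normal(size=3)          # pretrained content representation Z_j

v_j = als_item_update(U, C_j, P_j, theta_Xj, lam_v=0.1)

# v_j satisfies the normal equations behind Eq. 14.
A = U.T @ (C_j[:, None] * U) + 0.1 * np.eye(3)
assert np.allclose(A @ v_j, U.T @ (C_j * P_j) + 0.1 * theta_Xj)
```

The design intuition: for cold items with few interactions, the U^T C_j U term is weak, so v_j is pulled toward the content representation θ(X_j).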
3.4 Prediction
After our model has been trained and the latent factors of users and articles, U and
V, are identified, we calculate the model's prediction scores for user_i over all
articles as the dot product of u_i with all vectors in V: scores_i = u_i V^T. We then
sort all articles by their prediction scores in descending order and recommend the
top-K articles to user_i. We go through all users in U in our evaluation and report
the average performance over all users. The overall process of our approach is
illustrated in Algorithm 1.
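The prediction step amounts to one matrix-vector product and a sort; a minimal sketch with hypothetical factor values:

```python
import numpy as np

def recommend_top_k(u_i, V, k):
    """scores_i = u_i V^T; return the indices of the k highest-scoring
    articles in descending score order."""
    scores = V @ u_i
    return np.argsort(-scores)[:k]

u_i = np.array([1.0, 0.0])
V = np.array([[0.2, 0.9],
              [0.8, 0.1],
              [0.5, 0.5]])
# scores = [0.2, 0.8, 0.5], so the ranking is article 1, then 2, then 0.
assert list(recommend_top_k(u_i, V, 2)) == [1, 2]
```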
4 Experiments
4.1 Datasets
Three scientific article datasets are used to evaluate our model against the state-of-
the-art methods. All datasets are collected from the CiteULike website. The first dataset,
called Citeulike-a, was collected by [5]. It has 5,551 users, 16,980 articles, and
204,986 user-article pairs. The dataset is extremely sparse: only around 0.22% of
the user-article matrix contains interactions. Each user has at least
ten articles in his or her library. On average, each user has 37 articles in his or her
library and each article has been added to 12 users’ libraries. The second dataset is
called Citeulike-t, which is collected by [6]. It has 7,947 users, 25,975 articles, and
134,860 user-article pairs. This dataset is actually sparser than the first one with only
0.07% available user-article interactions. Each user has at least three articles in his
or her library. On average, each user has 17 articles in his or her library and each
article has been added to five users’ libraries. Lastly, Citeulike-2004–2007 is the third
dataset, and it is collected by [16]. It is three times bigger than the previous ones with
regard to the user-article matrix. It has 3,039 users, 210,137 articles, and 284,960
user-article pairs. This dataset is the sparsest in this experiment, with a sparsity equal
to 99.95%. Each user has at least ten articles in his or her library. On average, each
user has 94 articles in his or her library and each article has been added only to one
user library. Brief statistics of the datasets are shown in Table 1.
Title and abstract of each article are given in each dataset. The average number
of words per article in both title and abstract after our text preprocessing is 67 words
in Citeulike-a, 19 words in Citeulike-t, and 55 words in Citeulike-2004–2007. We
follow the same preprocessing techniques as the state-of-the-art models in [5, 7,
8]. A five-stage procedure to preprocess the textual content is displayed in Fig. 3.
Each article's title and abstract are combined and then preprocessed such
that stop words are removed. After that, the top-N distinct words based on the TF-IDF
measurement are picked out: 8,000 distinct words for the Citeulike-a dataset,
20,000 for the Citeulike-t dataset, and 19,871 for the Citeulike-2004–2007 dataset.
These words form the bag-of-words histograms, which are then normalized into
values between 0 and 1 based on the vocabularies' occurrences.
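The vocabulary selection and normalization can be sketched roughly as below. Both the simple TF-IDF score and the max-count normalization are assumptions for illustration, not the chapter's exact preprocessing pipeline.

```python
import numpy as np
from collections import Counter

def build_bow(docs, top_n):
    """Score words by term frequency x inverse document frequency,
    keep the top_n, and build histograms normalized into [0, 1]."""
    tokens = [d.split() for d in docs]
    df = Counter(w for t in tokens for w in set(t))   # document frequency
    tf = Counter(w for t in tokens for w in t)        # corpus term frequency
    score = {w: tf[w] * np.log(len(docs) / df[w]) for w in tf}
    vocab = sorted(score, key=score.get, reverse=True)[:top_n]
    X = np.array([[t.count(w) for w in vocab] for t in tokens], dtype=float)
    return vocab, X / max(X.max(), 1.0)               # normalize counts into [0, 1]

docs = ["deep learning for article recommendation",
        "attention helps deep networks focus"]
vocab, X = build_bow(docs, top_n=4)
assert len(vocab) == 4 and 0.0 <= X.min() and X.max() <= 1.0
```

Real preprocessing would also lowercase, strip punctuation, and remove stop words before counting, as described above.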
Figure 4 shows the ratio of articles that have been added to five or fewer users'
libraries: 15%, 77%, and 99% of the articles in Citeulike-a, Citeulike-t, and
Citeulike-2004–2007, respectively, are added to five or fewer users' libraries. Also,
only 1% of the articles in Citeulike-a have been added to exactly one user library,
while the rest of the articles have been added more often. In contrast, 13% and 77%
of the articles in Citeulike-t and Citeulike-2004–2007, respectively, have been added
to only one user library. This illustrates how article-level sparsity increases as we
move from one dataset to the next.
We follow the state-of-the-art techniques [6–8] to generate our training and testing
sets. For each dataset, we create two versions for sparse and dense settings, so six
dataset cases are used in our evaluation in total. To form the sparse (P = 1) and
dense (P = 10) datasets, P items are randomly selected from each user's library for
the training set, while the remaining items from each user's library form the testing
set. As a result, when P = 1, only 2.7%, 5.9%, and 1.1% of the data entries in
Citeulike-a, Citeulike-t, and Citeulike-2004–2007, respectively, are used for
training. Similarly, 27.1%, 39.6%, and 10.7% of the data entries are used for
training when P = 10, as Fig. 5 shows.
Fig. 5 The percentage of the data entries that form the training and testing sets in all CiteULike
datasets
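The P-split procedure above can be sketched as follows; a toy version, whereas the chapter's splits are drawn per dataset and repeated four times.

```python
import numpy as np

def p_split(user_libraries, P, seed=0):
    """For each user, draw P library items at random for training;
    the remaining items form that user's test set."""
    rng = np.random.default_rng(seed)
    train, test = {}, {}
    for user, items in user_libraries.items():
        shuffled = rng.permutation(list(items)).tolist()
        train[user], test[user] = shuffled[:P], shuffled[P:]
    return train, test

libs = {0: [10, 11, 12], 1: [20, 21]}
train, test = p_split(libs, P=1)
assert all(len(v) == 1 for v in train.values())          # sparse setting, P = 1
assert all(set(train[u]) | set(test[u]) == set(libs[u]) for u in libs)
```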
We use recall and Discounted Cumulative Gain (DCG) as our evaluation metrics
to test how our model performs. Recall is commonly used to evaluate recommender
systems with implicit feedback. Precision, in contrast, is not suitable for implicit
feedback because a zero value in the user-article interaction matrix has two possible
meanings: either the user is not interested in the article, or the user is not aware
of its existence. The precision metric would treat every zero value as a lack of
interest, which is not the case. Recall per user can be measured using the following
formula:

recall@K = \frac{\text{number of articles the user likes in the top } K}{\text{total number of articles the user likes}}

We also report the average DCG over all users to account for the ranking positions of relevant articles:

DCG@K = \frac{1}{|U|} \sum_{u=1}^{|U|} \sum_{i=1}^{K} \frac{rel(i)}{\log_2(i+1)}

where |U| is the total number of users, i is the rank within the top-K articles recommended
by the model, and rel(i) is an indicator function that outputs 1 if the article at rank i
is a relevant article, and 0 otherwise.
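Both metrics can be sketched in a few lines; an illustrative per-user implementation, not the authors' evaluation code.

```python
import numpy as np

def recall_at_k(ranked, relevant, k):
    """Fraction of the user's relevant articles found in the top k."""
    return len(set(ranked[:k]) & set(relevant)) / len(relevant)

def dcg_at_k(ranked, relevant, k):
    """Sum rel(i) / log2(i + 1) over ranks i = 1..k, where rel(i) = 1
    iff the article at rank i is in the user's test library."""
    rel = set(relevant)
    return sum(1.0 / np.log2(i + 1)
               for i, a in enumerate(ranked[:k], start=1) if a in rel)

ranked, relevant = [3, 1, 4, 2], [1, 2]
assert recall_at_k(ranked, relevant, 2) == 0.5            # only article 1 in top 2
assert np.isclose(dcg_at_k(ranked, relevant, 4),
                  1 / np.log2(3) + 1 / np.log2(5))        # hits at ranks 2 and 4
```

Averaging these per-user values over all users gives the figures reported in the experiments.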
4.3 Baselines
For each dataset, we repeat the data splitting four times with different random splits
into training and testing sets, as described in the evaluation methodology section.
We use one split as a validation experiment to find the optimal values of λ_u and λ_v
for our model and for the state-of-the-art models. We search a grid of the values
{0.01, 0.1, 1, 10, 100}; the best values on the validation experiment are reported
in Table 2. The other three splits are used to report the average performance of our
model against the baselines. In this section, we address the research questions
defined at the beginning of this section.
4.4.1 RQ1
To evaluate how our model performs, we conduct quantitative and qualitative com-
parisons to answer this question. Figures 6, 7, 8, and 9 show the performance of the
top-K recommendations under the sparse and dense settings in terms of recall and
DCG. First, the sparse cases are very challenging for any proposed model since there
is less data for training. In the sparse setting where there is only one article in each
user’s library in the training set, our model, CATA, outperforms the baselines in all
datasets in terms of recall and DCG, as Figs. 6 and 7 show. More importantly, CATA
outperforms the baselines by a wide margin in the Citeulike-2004–2007 dataset,
where it is actually sparser and contains a huge number of articles. This validates the
robustness of our model against data sparsity.
Second, with the dense setting where there are more articles in each user’s library
in the training set, our model performs comparably to other baselines in Citeulike-a
and Citeulike-t datasets, as Figs. 8 and 9 show. As a matter of fact, many existing
models work well under this setting but poorly under the sparse setting. For
example, CML+F achieves competitive performance on the dense data; however, it
fails on the sparse data since its metric space needs more interactions per user to
capture preferences. On the other hand, CATA outperforms the other baselines
under this setting in the Citeulike-2004–2007 dataset. As a result, this experiment
demonstrates the capability of our model for making more relevant recommendations
under both sparse and dense data conditions.
In addition to the previous quantitative comparisons, some qualitative results are
reported in Table 3 as well. The table shows the top ten recommendations gener-
ated by our model (CATA) and the state-of-the-art model (CVAE) for one randomly
selected user under the sparse setting using the Citeulike-2004–2007 dataset. In this
example, user 20 has only one article in his training library, entitled “Assessment of
Attention Deficit/ Hyperactivity Disorder in Adult Alcoholics”. From this example,
this user seems to be interested in the treatment of Attention-Deficit/Hyperactivity
Disorder (ADHD) among alcohol- and drug-using populations. Comparing the rec-
ommendation results between the two models, our model recommends more relevant
articles based on the user’s interests. For instance, most of the recommended articles
using the CATA model are related to the same topic, i.e., alcohol- and drug-users
with ADHD. However, there are some irrelevant articles recommended by CVAE,
Table 3 The top-10 recommendations for one randomly selected user under the sparse setting, P = 1,
using the Citeulike-2004–2007 dataset

User ID: 20
Articles in the training set: Assessment of Attention Deficit/Hyperactivity Disorder in Adult Alcoholics

CATA | In user library?
1. A double-blind, placebo-controlled withdrawal trial of dexmethylphenidate hydrochloride in children with ADHD | No
2. Double-blind placebo-controlled trial of methylphenidate in the treatment of adult ADHD patients with comorbid cocaine... | Yes
3. Methylphenidate treatment for cocaine abusers with adult attention-deficit/hyperactivity disorder: a pilot study | Yes
4. A controlled trial of methylphenidate in adults with attention deficit/hyperactivity disorder and substance use disorders | Yes
5. Treatment of cocaine dependent treatment seekers with adult ADHD: double-blind comparison of methylphenidate and... | Yes
6. A large, double-blind, randomized clinical trial of methylphenidate in the treatment of adults with ADHD | Yes
7. Patterns of inattentive and hyperactive symptomatology in cocaine-addicted and non-cocaine-addicted smokers diagnosed... | Yes
8. Frontocortical activity in children with comorbidity of tic disorder and attention-deficit hyperactivity disorder | No
9. Gender effects on attention-deficit/hyperactivity disorder in adults, revisited | Yes
10. Association between dopamine transporter (DAT1) genotype, left-sided inattention, and an enhanced response to... | Yes

CVAE | In user library?
1. Psycho-social correlates of unwed mothers | No
2. A randomized, controlled trial of integrated home-school behavioral treatment for ADHD, predominantly inattentive type | No
3. Age and gender differences in children's and adolescents' adaptation to sexual abuse | No
4. Distress in individuals facing predictive DNA testing for autosomal dominant late-onset disorders: Comparing questionnaire results... | No
5. Combined treatment with sertraline and liothyronine in major depression: a randomized, double-blind, placebo-controlled trial | No
6. An open-label pilot study of methylphenidate in the treatment of cocaine dependent patients with adult ADHD | Yes
7. Treatment of cocaine dependent treatment seekers with adult ADHD: double-blind comparison of methylphenidate and... | Yes
8. A double-blind, placebo-controlled withdrawal trial of dexmethylphenidate hydrochloride in children with ADHD | No
9. Methylphenidate treatment for cocaine abusers with adult attention-deficit/hyperactivity disorder: a pilot study | Yes
10. A large, double-blind, randomized clinical trial of methylphenidate in the treatment of adults with ADHD | Yes
Deep Learning-Based Recommender Systems 19
Table 4 Performance comparisons on sparse data with the attention layer (CATA) and without it
(CATA–). Best values are marked in bold

Approach | Citeulike-a Recall@300 | Citeulike-a DCG@300 | Citeulike-t Recall@300 | Citeulike-t DCG@300
CATA– | 0.3003 | 1.6644 | 0.2260 | 0.4661
CATA | 0.3060 | 1.7206 | 0.2425 | 0.5160
such as the articles ranked 1, 3, and 4. From this example and others we have
examined, we conclude that our model captures the major elements of articles'
contents and users' preferences more accurately.
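The ranking metrics reported throughout this section (recall@K and DCG@K, with K = 300 here) can be sketched as follows; the function names are ours, and the binary-relevance DCG is an assumption about the exact variant the authors used:

```python
import numpy as np

def recall_at_k(ranked_items, relevant_items, k):
    """Fraction of a user's held-out articles found in the top-k list."""
    hits = len(set(ranked_items[:k]) & set(relevant_items))
    return hits / len(relevant_items) if relevant_items else 0.0

def dcg_at_k(ranked_items, relevant_items, k):
    """Discounted cumulative gain: hits near the top of the list count more."""
    relevant = set(relevant_items)
    return sum(1.0 / np.log2(rank + 2)          # rank 0 -> discount log2(2) = 1
               for rank, item in enumerate(ranked_items[:k])
               if item in relevant)

# Toy example: 2 of the user's 4 held-out articles appear in the top-5 list.
ranked = [10, 3, 7, 42, 5]
held_out = [3, 42, 99, 100]
print(recall_at_k(ranked, held_out, 5))  # 0.5
print(dcg_at_k(ranked, held_out, 5))
```

In practice these values are averaged over all test users, which is how the per-dataset numbers in Table 4 and Figs. 8–9 are typically produced.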
4.4.2 RQ2
To examine the importance of adding the attention layer to our autoencoder, we
create another variant of our model that has the same architecture but lacks the
attention layer, which we call CATA–. We evaluate this variant on the sparse cases of the
Citeulike-a and Citeulike-t datasets. The performance comparisons are reported in
Table 4. As the table shows, adding the attention mechanism boosts the performance.
The attention mechanism focuses on the parts of the encoded vocabulary of each
article that best represent its content, ultimately
leading to increased recommendation quality.
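As a hedged illustration of the idea (this excerpt does not spell out CATA's exact attention equations), a minimal dot-product attention over an article's encoded words might look like the following; the learned context vector and helper names are our assumptions:

```python
import numpy as np

def softmax(x):
    e = np.exp(x - x.max())
    return e / e.sum()

def attend(word_vecs, query):
    """Weight each encoded word vector by its relevance, then pool.

    word_vecs: (n_words, d) encoded vocabulary of one article
    query:     (d,) learned context vector -- a stand-in for CATA's
               learned attention parameters, which are not given here
    """
    scores = word_vecs @ query            # one relevance score per word
    weights = softmax(scores)             # normalized attention weights
    return weights @ word_vecs, weights   # weighted sum = article summary

rng = np.random.default_rng(0)
vecs = rng.normal(size=(6, 4))            # 6 encoded words, 4 dimensions
ctx = rng.normal(size=4)
summary, w = attend(vecs, ctx)
assert np.isclose(w.sum(), 1.0) and summary.shape == (4,)
```

The weighted sum replaces uniform pooling, so words that matter for the article's topic dominate the representation, which is consistent with the gain reported in Table 4.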
4.4.3 RQ3
There are two regularization parameters (λu and λv) in the objective
function of the matrix factorization method. They prevent the magnitudes of the
latent vectors from growing too large, which in turn prevents the model from overfitting the
training data. Our previously reported results were obtained by setting λu and λv to
the values in Table 2, chosen through validation experiments. Here, we perform
multiple experiments to show the impact of different values of λu and λv on
our model's performance, drawing both parameters
from the range {0.01, 0.1, 1, 10, 100}. Figure 10 visualizes how our model
performs under each combination of λu and λv. We find that our model performs
worse when λv is considerably large under the dense setting, as
Figs. 10b, d show. On the other hand, where the data is sparser (Figs. 10a, c, e, f), a
very small value of λu (e.g., 0.01) tends to yield the lowest performance among all
the tested values. Even though Fig. 10f shows the performance under the dense setting
for the Citeulike-2004–2007 dataset, it still exemplifies sparsity with regard to
articles, as we indicated earlier in Fig. 4, where 80% of the articles have been
added to only one user's library. Generally, we observe that optimal performance occurs
in all datasets when λu = 10 and λv = 0.1. We can conclude that when there is suffi-
Fig. 10 The impact of λu and λv on CATA performance for a, b Citeulike-a, c, d Citeulike-t, and
e, f Citeulike-2004–2007 datasets
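The regularized matrix factorization objective discussed above can be sketched as follows; the SGD update rule, learning rate, and helper name are our assumptions, with only λu and λv coming from the text:

```python
import numpy as np

def mf_sgd_step(U, V, R, lr=0.01, lam_u=10.0, lam_v=0.1):
    """One SGD pass over the observed (non-NaN) entries of R, minimizing

        sum (R_uv - U_u . V_v)^2 + lam_u * ||U||^2 + lam_v * ||V||^2.

    The defaults lam_u = 10 and lam_v = 0.1 are the best settings reported
    in the text; the update itself is a generic sketch, not necessarily the
    authors' exact optimization procedure.
    """
    for u, v in zip(*np.nonzero(~np.isnan(R))):
        err = R[u, v] - U[u] @ V[v]
        U[u] += lr * (err * V[v] - lam_u * U[u])
        V[v] += lr * (err * U[u] - lam_v * V[v])
    return U, V
```

Repeated passes shrink the latent vectors toward the origin at a rate set by λu and λv, which is why very large λv hurts on dense data while very small λu underperforms on sparse data.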
4.4.4 RQ4
The latent feature vectors (U and V) represent the characteristics of users and
items that a model tries to learn from data. We examine the impact of the size of
these vectors on the performance of our model; in other words, how
many dimensions in the latent space represent the user and item features most
accurately. It is worth mentioning that our reported results in the RQ1 section use 50
dimensions, the same size used by the state-of-the-art model (CVAE),
in order to have fair comparisons. Here, we run our model again using five
dimension sizes: {25, 50, 100, 200, 400}. Figure 11 shows
how our model performs in terms of recall@100 under each dimension size. We
observe that increasing the dimension size on dense data always leads to a gradual
increase in our model's performance, as shown in Fig. 11b. Larger dimension sizes
also help on sparse data, although they do not necessarily improve the
model's performance in every case (e.g., the Citeulike-t dataset in Fig. 11a). Generally,
dimension sizes between 100 and 200 are suggested for the latent space.
Fig. 11 The performance of the CATA model with respect to different dimension values of the latent
space under a sparse data and b dense data
4.4.5 RQ5
We pretrain our autoencoder first until its loss value converges sufficiently.
The loss value is the error computed by the autoencoder's loss function; it reflects
how well the model reconstructs outputs from inputs. Figure 12 visualizes
the number of training epochs needed to render the loss value sufficiently stable. We
find that 200 epochs are sufficient for pretraining our autoencoder.
Fig. 12 The reduction in the loss values versus the number of training epochs
5 Conclusion
References
Sharif Amit Kamran, Sourajit Saha, Ali Shihab Sabbir, and Alireza Tavakkoli
© The Editor(s) (if applicable) and The Author(s), under exclusive license 25
to Springer Nature Singapore Pte Ltd. 2021
M. A. Wani et al. (eds.), Deep Learning Applications, Volume 2,
Advances in Intelligent Systems and Computing 1232,
https://doi.org/10.1007/978-981-15-6759-9_2
26 S. A. Kamran et al.
1 Introduction
Diabetes, one of the most crucial health concerns, affects up to 7.2% of the world's
population, and this number can potentially mount up to 600 million by the year
2040 [6, 38]. With the pervasiveness of diabetes, statistically one out of every
three diabetic patients develops Diabetic Retinopathy (DR) [33]. Diabetic Retinopathy
causes severe damage to human vision, engendering vision loss that
affects nearly 2.8% of the world's population [3]. In developed countries there exist
efficacious vision tests for DR screening and early treatment; yet a by-product of such
systems is fallacious results, and identifying the false positives and false negatives
still remains a challenging impediment for diagnosticians. On the other hand,
DR is often mistreated in many developing and poorer economies, where access to
trained ophthalmologists and eye-care machinery may be insufficient. Therefore, it
is imperative to shift DR diagnosis to a system that is autonomous,
in the pursuit of more accurate and faster test results on DR and other related retinal
diseases, and inexpensive, so that more people have access to it. This research proposes
a novel architecture based on a convolutional neural network which can identify
Diabetic Retinopathy while categorizing multiple retinal diseases with
near-perfect accuracy in real time.
There are different retinal diseases other than Diabetic Retinopathy, such as Macular
Degeneration. The macula is a retinal sensor found in the central region of the
retina in human eyes. The retinal lens perceives light emitted from outside sources and
transforms it into neural signals, a process otherwise known as vision. The macula
plays an integral role in human vision, from processing light via photo-receptor nerve
cells to aggregating the encoded neural signals sent directly to the brain through the
optic nerves. Retinal diseases such as Macular Degeneration, Diabetic Retinopathy,
and Choroidal Neovascularization are the leading causes of eye disease and vision
loss worldwide.
In ophthalmology, a technique called Spectral Domain Optical Coherence Tomography
(SD-OCT) is used for viewing the morphology of the retinal layers [32]. Furthermore,
another way to assess these diseases is to use the depth-resolved tissue-formation
data encoded in the magnitude and delay of the back-scattered light through spectral
analysis [1]. While retrieving the retinal image is performed by the computational process
of SD-OCT, differential diagnosis is conducted by human ophthalmologists, which
leaves room for human error. Hence, an autonomous expert system would benefit
ophthalmologists in distinguishing among different retinal diseases more precisely,
with fewer mistakes, and in a more timely manner.
One of the predominant factors in the misclassification of retinal maladies is
the stark similarity between Diabetic Retinopathy and other retinal diseases. These can
be grouped into three major categories: (i) Diabetic Macular Edema (DME) and age-related
degeneration of the retinal layers (AMD), (ii) Drusen, a condition in which lipid or
protein build-up occurs in the retinal layer, and (iii) Choroidal Neovascularization
(CNV), a growth of new blood vessels in the sub-retinal space. The most common retinal
A Comprehensive Set of Novel Residual Blocks for Deep Learning … 27
2 Literature Review
In this section, we discuss a series of computational approaches that have historically
been adopted to diagnose SD-OCT imagery. These approaches date back to the
earlier developments in image processing, as well as the segmentation algorithms of the
pre-deep-learning era of computer vision. While those developments are important
and made tremendous strides in their specific domains, we also discuss the learning-based
approaches and deep learning models that were created and trained, both with
and without transfer learning, by other researchers in the pursuit of classifying retinal
diseases from SD-OCT images. We further explain the pros and cons
of other deep learning classification models on SD-OCT images and the developments
and contributions we achieved with our proposed deep learning architecture.
The earliest approaches traced in the pursuit of classifying retinal diseases from images
combine multiple image processing techniques followed by feature extraction and
classification [24]. Evidently, one such research was conducted where retinal dis-
One of the most pronounced ways to identify a patient with Diabetic Macular Edema
is by enlarged macular density in the retinal layer [1, 5]. Naturally, several systems have
been proposed and implemented which comprise retinal layer segmentation. From the
evidence of fluid building up in the sub-retinal space, as determined by the segmentation
algorithms, further identification of the factors that engender specific diseases is
made possible [17, 22, 23]. In [20, 26], the authors proposed segmenting
the intra-retinal layers into ten parts and then extracting the texture and depth
information from each layer. Subsequently, aberrant retinal features are detected by
classifying the dissimilarity between healthy retinas and diseased ones. Moreover,
Niemeijer et al. [22] introduced a technique for 3D segmentation of regions
containing fluid in OCT images using a graph-based implementation: a graph-cut
algorithm is applied to get the final predictions from the information initially retrieved
from a layer-based segmentation of fluid regions. Even though implementations based
on a prior segmentation of the retinal layers have registered high-scoring prediction
results, the initial step is reportedly laborious, prolonged, and error-prone [10, 13]. As
reported in [19], retinal thickness measurements obtained by different systems show
stark dissimilarity. Therefore, it is neither efficient nor optimal to compare
retinal depth information retrieved by separate machines, despite the
improved prediction accuracy over the feature engineering methods with traditional
image analysis discussed earlier. These observations further reinforce the fact that
segmentation-based approaches are not effective as a universal retinal disease
recognition system.
3 Proposed Methodology
Fig. 1 Different variants of the residual unit and our proposed residual unit with a novel lateral
propagation of neurons
Fig. 2 Illustration of the building blocks of our proposed CNN [OpticNet-71]. Only the very
first convolution (7 × 7) layer in the CNN and the very last convolution (1 × 1) layer in Stage [2,
3, 4]: Residual Convolutional Unit use stride 2, while all other convolution operations use stride 1
for such performance enhancement. While Fig. 2b depicts the proposed mechanism
to handle gradient degradation, Fig. 2c narrates the entire CNN architecture. We
discuss the constituent segments of our CNN architecture, OpticNet, and the philosophy
behind our lateral design schemes over the following subsections.
To design a model with higher efficacy, a lower memory footprint, and modest
computational complexity, we experimented with different types of convolution operations
and various lateral choices in the residual unit of our proposed neural network.
Figure 1a represents the propagation of neurons in a vanilla residual unit with pre-activation
[11, 12]. While achieving a high prediction accuracy, the first row of
Table 1 reports the higher memory and computational footprint resulting from
such a design. The natural progression was to include a kernel with a wide receptive
field, hence the inclusion of dilation in the kernel operation, as shown by Yu et al. [39].
Section 3.1.2 addresses the mechanism of this block in detail.
Table 1 Comparison between different convolution operations used in the middle portion of the
residual unit

Type of convolution used in the middle of the residual unit | Approximate number of parameters^a | Depletion factor for parameters, p (%) | Accuracy^b (%)
Regular convolution | f² × D[i] × D[i−1] = 36,864 | 100 | 99.30
Atrous convolution | (f − 1)² × D[i] × D[i−1] = 16,384 | 44.9 | 97.20
Separable convolution | (f² + D[i]) × D[i−1] = 4,672 | 12.5 | 98.10
Atrous separable convolution | ((f − 1)² + D[i]) × D[i−1] = 4,352 | 11.6 | 96.70
Atrous convolution and atrous separable convolution branched | ½[(f − 1)²(1 + ½D[i]) + ½D[i]] × D[i−1] = 5,248 | 14.4 | 99.80

^a Here, kernel size (f, f) = (3, 3). Depth (# kernels) in the residual unit's middle operation, D[i] = 64,
and first operation, D[i−1] = 64.
^b The Test Accuracy reported in the table is obtained by training on OCT2017 [16] data-set, while
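The printed parameter counts can be checked numerically. The formula used for the branched row below is our reconstruction of a garbled original, chosen so that it matches the printed count of 5,248:

```python
# Parameter counts from Table 1, recomputed with f = 3, D_i = D_prev = 64.
f, Di, Dprev = 3, 64, 64

regular          = f**2 * Di * Dprev                                      # 36,864
atrous           = (f - 1)**2 * Di * Dprev                                # 16,384
separable        = (f**2 + Di) * Dprev                                    #  4,672
atrous_separable = ((f - 1)**2 + Di) * Dprev                              #  4,352
# Branched unit: two half-width branches (reconstructed formula).
branched         = ((f - 1)**2 * (1 + Di // 2) + Di // 2) * Dprev // 2    #  5,248

print(regular, atrous, separable, atrous_separable, branched)
```

The depletion factors follow directly as each count divided by the regular-convolution count (e.g., 16,384 / 36,864 ≈ 44%), up to the rounding used in the printed table.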
To further reduce the computational strain, we then take the separable residual unit and
replace the depth-wise 3 × 3 convolution block with a 2 × 2 atrous convolution block
with dilation rate two, as shown in Fig. 1d. Figure 3 further demonstrates the
mechanism of the atrous separable convolution operation on signals. With this design
choice, we cut the parameter count by 87.4%, giving the lowest computational
complexity in our experiment (fourth row of Table 1). However, the repetitive use of such
a residual unit in our proposed neural network accumulates to the lowest prediction
accuracy on our data-set. Furthermore, using only separable convolution results in
depth-wise feature extraction in most of the identity blocks without encompassing
any spatial information in those learnable layers, which explains the
low accuracy of this unit among all the novel residual units.
Based on observations of such trade-offs between computational complexity and
performance, we arrive at our proposed residual unit (Fig. 1e), which we discuss next
in Sect. 3.2.
Historically, the Residual Units [11, 12] used in Deep Residual Convolutional Neural
Networks (CNNs) process the incoming input through three convolution operations
while adding the incoming input to the processed output. These three convolutional
operations are (1 × 1), (3 × 3), and (1 × 1) convolutions. Therefore, replacing the
(3 × 3) convolution in the middle with other types of convolutional operations can
potentially change the learning behavior, the computational complexity, and eventually
the prediction performance, as demonstrated in Fig. 1.
We experimented with different convolution operations as replacements for the
(3 × 3) middle convolution and observed which choice contributes the most to reducing
the number of parameters, and hence the computational complexity, as depicted in Table 1.
Furthermore, in Table 1 we use a depletion factor for parameters, p, which is the
ratio of the number of parameters in the replacement convolution to that of the regular
convolution, expressed as a percentage.
In our proposed residual unit, we replace the middle (3 × 3) convolution operation
with two different operations running in parallel, as detailed in Fig. 2a. Whereas a
conventional residual unit uses D[i] channels for the middle convolution,
we use ½D[i] channels for each of the two replacement operations to prevent
any surge in parameters. In the proposed branching operation, we use a (2 × 2) atrous
convolution (C2) with dilation rate r = 2 to get a (3 × 3) receptive field in the left
branch, while in the right branch we use a (2 × 2) atrous separable convolution (C3)
with dilation rate r = 2 to get a (3 × 3) receptive field. The results of the two branches
are then added together. Furthermore, separable convolution [30] disentangles the spatial and
depth-wise feature maps separately, while atrous convolution inspects both the spatial
and depth channels together. We hypothesize that adding two such feature maps, which
are learned very differently, helps trigger more robust and subtle features.
X_{l+1} = X_l + \big[(X_l * C_1 * C_2) + (X_l * C_1 * C_3)\big] * C_4 = X_l + \hat{F}(X_l, W_l) \qquad (1)
Figure 3 shows how adding the atrous and atrous separable feature maps helps
disentangle the input image space with more depth information instead of activating
only the predominant edges. Moreover, the last row of Table 1 confirms that adopting
this strategy still reduces the computational complexity by a reasonable margin
while improving inference accuracy. Equation (1) further clarifies how the input signal
X_l travels through the proposed residual unit shown in Fig. 2a, where ∗ refers to
the convolution operation.
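The receptive-field claim above, that a (2 × 2) atrous kernel with dilation rate r = 2 covers a (3 × 3) area, can be illustrated with a minimal 1-D analogue; the helper below is our sketch, not the authors' implementation:

```python
import numpy as np

def dilated_conv1d(x, w, rate):
    """'Atrous' 1-D convolution: the taps of w are spaced `rate` apart,
    so a 2-tap kernel with rate=2 covers a receptive field of 3 samples
    (the 1-D analogue of the 2x2, r=2 kernels in the proposed unit)."""
    k = len(w)
    span = (k - 1) * rate + 1                 # effective receptive field
    out = np.zeros(len(x) - span + 1)
    for i in range(len(out)):
        out[i] = sum(w[j] * x[i + j * rate] for j in range(k))
    return out

x = np.arange(8, dtype=float)
y = dilated_conv1d(x, np.array([1.0, 1.0]), rate=2)   # computes x[i] + x[i+2]
assert np.allclose(y, x[:-2] + x[2:])
# 2-tap kernel with rate 2 -> span of 3 samples, analogous to 3x3 in 2-D
assert len(x) - len(y) == 2
```

Dilation therefore buys the receptive field of a larger kernel at the parameter cost of a smaller one, which is exactly the trade-off exploited in Table 1.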
In this section, we discuss the proposed building block as a constituent part of OpticNet.
As shown in Fig. 2b, we split the input signal (X_l) into two branches: (1) a stack
of residual units and (2) signal exhaustion. Later in this section, we explain how we
connect these two branches to propagate signals further in the network.
\alpha(X_l) = X_{l+N} = X_l + \sum_{i=l}^{N} \hat{F}(X_i, W_i) \qquad (2)
In order to initiate a novel learning chain for propagating signals through a linear stack of
the proposed residual units, we combine the global residual
effects enhanced by pre-activation residual units [12] with our proposed set of convolution
operations (Fig. 2a). As shown in (1), F̂(X_l, W_l) denotes the whole set
of proposed convolution operations inside a residual unit for input X_l. We sequentially stack
these residual units N times over X_l, the input to our proposed building block,
as narrated in Fig. 2b. Equation (2) gives the state of the output signal, denoted
X_{l+N}, after processing through a stack of residual units of length N. For the sake
of further demonstration, we denote X_{l+N} as α(X_l).
\frac{\partial \tau(X_l)}{\partial X_l}
= \bigl(1 + \sigma(X_l)\bigr)\Bigl(1 + \frac{\partial}{\partial X_l}\sum_{i=l}^{N} \hat{F}(X_i, W_i)\Bigr)
+ \Bigl(1 + X_l + \sum_{i=l}^{N} \hat{F}(X_i, W_i)\Bigr)\,\sigma(X_l)\bigl(1 - \sigma(X_l)\bigr)
= \bigl(1 + \sigma(X_l)\bigr)\Bigl(1 + \frac{\partial}{\partial X_l}\sum_{i=l}^{N} \hat{F}(X_i, W_i)\Bigr)
+ \bigl(1 + X_{l+N}\bigr)\,\sigma'(X_l)
\qquad (4)
As shown in Fig. 2b, we combine the residual signal α(X_l) and the exhausted signal
β(X_l) following (3), and we denote the output signal propagated from the proposed
building block as τ(X_l). Our hypothesis behind this design is that whenever one
of the branches falls prey to gradient degradation from a mini-batch, the other branch
manages to propagate signals unaffected by that mini-batch, with an amplified absolute
gradient. Validating this hypothesis, (3) shows that τ(X_l) ≈ α(X_l) for β(X_l) ≈ 0
and τ(X_l) ≈ β(X_l) for α(X_l) ≈ 0, illustrating how the unaffected branch survives the
degradation in the affected branch. However, when neither branch is affected by
gradient amplification, the multiplication (α(X_l) × β(X_l)) balances out the increase
in signal propagation due to the addition of the two branches. Equation (4) delineates the
gradient of the building block output τ(X_l) with respect to the building block input X_l,
calculated during back-propagation for optimization.
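The display of Eq. (3) is lost in this excerpt, but the identities quoted later in the text, τ(X_l) = α(X_l) + β(X_l) × (1 + α(X_l)) = β(X_l) + α(X_l) × (1 + β(X_l)), pin down the combination rule; as a reconstruction under that reading:

```latex
\tau(X_l) \;=\; \alpha(X_l) \;+\; \beta(X_l) \;+\; \alpha(X_l)\,\beta(X_l)
\;=\; \alpha(X_l) + \beta(X_l)\bigl(1 + \alpha(X_l)\bigr)
\;=\; \beta(X_l) + \alpha(X_l)\bigl(1 + \beta(X_l)\bigr)
```

With the exhausted branch taken as β(X_l) = σ(X_l), as the σ(X_l)(1 − σ(X_l)) factor in Eq. (4) suggests, differentiating this expression with respect to X_l reproduces Eq. (4) term by term.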
Figure 2c portrays the entire CNN architecture with all the building blocks and constituent
components joined together. First, the input batch (224 × 224 × 3) is propagated
through a 7 × 7 convolution with stride 2, followed by batch normalization and ReLU
activation. Then we propagate the signals through a Residual Convolution Unit (the same as
the unit used in [12]), which is followed by our proposed building block. We
repeat this [Residual Convolution Unit → Building Block]
procedure S = 4 times, calling the repetitions stages 1, 2, 3, and 4, respectively. Then
global average pooling is applied to the signals, which pass through two more Fully
Connected (FC) layers to the loss function, which is denoted by ξ.
In Table 2, we show the number of feature maps (layer depth) we use for each
layer in the network. The output shapes of the input tensor after the four consecutive
stages are (112 × 112 × 256), (56 × 56 × 512), (28 × 28 × 1024), and (14 × 14 × 2048),
respectively.
Table 2 Architectural specifications for OpticNet-71 and layer-wise analysis of the number of feature
maps, in comparison with ResNet50-v1 [11]

Layer name | ResNet50 V1 [11] | OpticNet71 [Ours]
Conv 7 × 7 | [64] × 1 | [64] × 1
Stage 1: Res Conv | [64, 64, 256] × 1 | [64, 64, 256] × 1
Stage 1: Res Unit | [64, 64, 256] × 2 | [32, 32, 32, 256] × 4
Stage 2: Res Conv | [128, 128, 512] × 1 | [128, 128, 512] × 1
Stage 2: Res Unit | [128, 128, 512] × 3 | [64, 64, 64, 512] × 4
Stage 3: Res Conv | [256, 256, 1024] × 1 | [256, 256, 1024] × 1
Stage 3: Res Unit | [256, 256, 1024] × 5 | [128, 128, 128, 1024] × 3
Stage 4: Res Conv | [512, 512, 2048] × 1 | [512, 512, 2048] × 1
Stage 4: Res Unit | [512, 512, 2048] × 2 | [256, 256, 256, 2048] × 3
Global Avg Pool | 2048 | 2048
Dense layer 1 | K (Classes) | 256
Dense layer 2 | – | K (Classes)
Parameters | 25.64 Million | 12.50 Million
Required FLOPs | 3.8 × 10⁹ | 2.5 × 10⁷
CNN Memory | 98.20 MB | 48.80 MB
Equation (5) represents the gradient calculated for the entire network chain, distributed
over the stages of optimization. As (4) suggests, the term (1 + σ(X_l)) (in comparison
with [12]) works as an extra layer of protection against possible gradient
explosion caused by the stacked residual units, by multiplying nonzero activations
with the residual unit's gradients. Moreover, the term (1 + X_{l+N}) indicates that the
optimization chain still has access to signals from much earlier in the network, and,
to prevent unwanted spikes in activations, the term σ′(X_l) can still mitigate gradient
expansion which could otherwise jeopardize learning.
4 Experiments
The following section contains information for training, validating, and testing the
architectures with different settings and hyper-parameters. Moreover, it gives a
detailed analysis of how the architectures were compared with previous techniques
and expert human diagnosticians. Additionally, the juxtaposition was drawn in terms
of speed, accuracy, sensitivity, specificity, memory usage, and penalty weighted met-
38 S. A. Kamran et al.
rics. All the files related to the experimentation and training can be found in the
following Code Repository: https://github.com/SharifAmit/OCT_Classification.
We benchmark our model against two distinct data-sets (of different scale, sample space,
etc.). The first aims at correctly recognizing and differentiating between four
distinct retinal states, provided by the OCT2017 [16] data-set: normal healthy retina,
Drusen, Choroidal Neovascularization (CNV), and Diabetic Macular Edema (DME).
The OCT2017 [16] data-set contains 84,484 images (provided in high-quality TIFF
format with three non-RGB color channels), which we split into an 83,484-image
train-set and a 1,000-image test-set. The second data-set, Srinivasan2014 [32],
consists of three classes and aims at classifying normal healthy retina specimens,
Age-Related Macular Degeneration (AMD), and Diabetic Macular Edema (DME).
The Srinivasan2014 [32] data-set consists of 3,231 image samples that we split into a
2,916-image train-set and a 315-image test-set. We resize images from both data-sets
to 224 × 224 × 3 for both training and testing. For both data-sets, we do fivefold
cross-validation on the training set to find the best models.
We calculated four standard metrics to evaluate our CNN model on both data-sets:
Accuracy (6), Sensitivity (7), Specificity (8), and a special Weighted Error (9) from
[16], where N is the number of image samples and K is the number of classes. Here
TP, FP, FN, and TN denote True Positive, False Positive, False Negative, and True
Negative, respectively. We report the True Positive Rate (TPR), or Sensitivity (7), and
the True Negative Rate (TNR), or Specificity (8), for both data-sets [16, 32]. For
this, we calculate the TPR and TNR for the individual classes, sum all the values,
and then divide by the number of classes (K).
\mathrm{Accuracy} = \frac{1}{N} \sum \mathrm{TP} \qquad (6)

\mathrm{Sensitivity} = \frac{1}{K} \sum_{k=1}^{K} \frac{\mathrm{TP}_k}{\mathrm{TP}_k + \mathrm{FN}_k} \qquad (7)

\mathrm{Specificity} = \frac{1}{K} \sum_{k=1}^{K} \frac{\mathrm{TN}_k}{\mathrm{TN}_k + \mathrm{FP}_k} \qquad (8)

\mathrm{Weighted\;Error} = \frac{1}{N} \sum_{i,j \in K} W_{ij} \cdot X_{ij} \qquad (9)
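Equations (6)–(8) can be sketched from a confusion matrix as follows; the helper and the macro-averaging layout are ours, and the Weighted Error of Eq. (9) additionally needs the penalty matrix W from [16], which is not reproduced here:

```python
import numpy as np

def metrics(cm):
    """Macro-averaged metrics from a KxK confusion matrix cm,
    where cm[i, j] counts samples of true class i predicted as class j."""
    K, N = cm.shape[0], cm.sum()
    tp = np.diag(cm).astype(float)
    fn = cm.sum(axis=1) - tp                  # missed samples of each class
    fp = cm.sum(axis=0) - tp                  # false alarms for each class
    tn = N - tp - fn - fp
    accuracy = tp.sum() / N                   # Eq. (6)
    sensitivity = np.mean(tp / (tp + fn))     # Eq. (7), macro-averaged TPR
    specificity = np.mean(tn / (tn + fp))     # Eq. (8), macro-averaged TNR
    return accuracy, sensitivity, specificity

cm = np.array([[48, 2], [4, 46]])             # toy 2-class confusion matrix
acc, sen, spe = metrics(cm)
```

For this toy matrix all three metrics come out to 0.94; on the real data-sets the confusion matrix is K × K with K = 4 (OCT2017) or K = 3 (Srinivasan2014).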
other existing solutions, with a Sensitivity and Specificity of 99.80% and 99.93%.
Furthermore, the Weighted Error is reported to be a mere 0.20%, which can be visualized
in Fig. 4: our architecture misidentifies one Drusen and one DME sample as
CNV. However, the penalty weight is only 1 for each of these misclassifications, as we
report in Table 3. Thus, with our proposed OpticNet-71 we obtain state-of-the-art
results on the OCT2017 [16] data-set across all four performance metrics, while
significantly surpassing the human benchmarks mentioned in Table 4.
loss doesn’t lower for six consecutive epochs. Moreover, we set the lowest learning
rate to αmin
lr
= 1e−8 . Furthermore, We use Adam optimizer with default parameters
of β1adam
= 0.90 and β2adam = 0.99 for all training schemes. We train OCT2017 [16]
data-set for 44 hours and Srinivasan2014 [32] data-set for 2 hours on a 8 GB NVIDIA
GTX 1070 GPU.
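The learning-rate schedule described here can be mimicked in a few lines of plain Python; the reduction factor below is our assumption, since the text only states the patience (six epochs) and the floor α_lr^min = 1e−8:

```python
class PlateauLR:
    """Minimal mock of the described schedule: if the validation loss fails
    to improve for `patience` consecutive epochs, scale the learning rate
    down by `factor` (assumed), never going below `min_lr`."""

    def __init__(self, lr=1e-3, patience=6, factor=0.1, min_lr=1e-8):
        self.lr, self.patience = lr, patience
        self.factor, self.min_lr = factor, min_lr
        self.best, self.wait = float("inf"), 0

    def step(self, val_loss):
        if val_loss < self.best:              # improvement: reset the counter
            self.best, self.wait = val_loss, 0
        else:
            self.wait += 1
            if self.wait >= self.patience:    # plateau: cut the learning rate
                self.lr = max(self.lr * self.factor, self.min_lr)
                self.wait = 0
        return self.lr

sched = PlateauLR()
for loss in [1.0] * 7:                        # no improvement for 6 epochs
    lr = sched.step(loss)
```

After the six-epoch plateau the rate drops from 1e−3 to 1e−4, and repeated plateaus eventually clamp it at the 1e−8 floor.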
Inception-v3 models under-perform compared to both the pre-trained models and
OpticNet-71, as seen in Table 4. OpticNet-71 takes 0.03 seconds to make a prediction
on an OCT image, which is real time; while accomplishing state-of-the-art
results on the OCT2017 [16] and Srinivasan2014 [32] data-sets, our model also surpasses
human-level prediction on OCT images, as depicted in Table 4. The human experts are real
diagnosticians, as reported in [16]: there are six diagnosticians, of whom the highest
performing is Human Expert 5 and the lowest performing is Human Expert
2. To validate our CNN architecture's optimization strength, we also train two smaller
versions of OpticNet-71 on both data-sets: OpticNet-47 ([N1 N2 N3 N4] =
[2 2 2 2]) and OpticNet-63 ([N1 N2 N3 N4] = [3 3 3 3]). In Fig. 5, we show how
all of our OpticNet variants outperform the pre-trained CNNs on the Srinivasan2014
[32] data-set, while OpticNet-71 outperforms all the pre-trained CNNs on the OCT2017
[16] data-set in terms of accuracy as well as the performance-memory trade-off.
Fig. 5 Test accuracy (%), CNN memory (Mega-Bytes) and model parameters (Millions) on
OCT2017 [16] data-set and Srinivasan2014 [32] data-set
consider as output from each stage, as narrated in Fig. 6b. Furthermore, Fig. 6b
portrays how element-wise addition combined with element-wise multiplication between
signals helps the learning propagation of OpticNet-71. Figure 6b precisely depicts
why this optimization chain is particularly significant: a zero activation cannot cancel
out a live signal channel from the residual counterpart (τ(X_l) = α(X_l) + β(X_l) ×
(1 + α(X_l))), nor can a dead signal channel cancel out a nonzero activation
from the interpolation counterpart (τ(X_l) = β(X_l) + α(X_l) × (1 + β(X_l))), thus
preventing all signals of a stage from dying and the resulting catastrophic optimization
failure due to dead weights or gradient explosion.
Fig. 6 a Visualizing input images from each class through different layers of OpticNet-71. As
shown, the feature maps at the end of each building block learn more fine-grained features by
focusing sometimes on the same shapes, rather in different regions of the image, learning to decide
which features lead the image to the ground truth. b The learning progression, however, shows how
exhausting the signal propagated with the residual activation learns to detect thinner edges, delving
further into the macular region to learn anomalies. When using the signal exhaustion mechanism,
important features can sometimes be lost during training. Our experiments show that, by using more of
these building blocks, we can reduce that risk of feature loss and improve the overall optimization of
OpticNet-71
OCT2017 [16] data-set in Fig. 8. As Fig. 7 illustrates, the best and average accuracy
on both the training and validation sets reach a stable maximum after training, and this
phenomenon attests to the efficacy of our model. Furthermore, in Fig. 8 we report a
discernible gap between the average validation loss and the best validation loss
among the five folds, which assures that our model's performance is not a result
of over-fitting.
We experimented with different variants of residual units to find the best version
of OpticNet-71 for both data-sets. First, we ran the experiment with vanilla
convolution on the OCT2017 [16] data-set, but with that we under-performed
against Human Expert 5. We then incorporated dilation in the convolution layers,
reaching a sensitivity of 99.73% and a weighted error of 0.8%, as shown in
Table 6; this was still worse than using residual units with vanilla convolution. Next, we
tried separable convolution and its dilated counterpart, with which we achieved
precision scores of 99.30% (vanilla) and 99.40% (dilated), respectively. Lastly, we
tried the proposed novel residual unit, consisting of dilated convolution on one
branch and separable convolution with dilation on the other. With this, we reached
44 S. A. Kamran et al.
Fig. 7 Average accuracy and best accuracy of the fivefold cross-validation on the OCT2017 [16]
validation and training sets
Fig. 8 Average accuracy and minimum loss of the fivefold cross-validation on the OCT2017 [16]
validation and training sets
our desired precision score of 99.93% and a weighted error of 0.2%, beating Human
Expert 5 and reaching state-of-the-art accuracy. It is worth mentioning that the
hyper-parameters for all the architectures were the same. Moreover, the training was
done with the same optimizer (Adam) for 30 epochs. It is therefore quite evident that
OpticNet-71, comprising dilated convolution and dilated separable convolution,
was the optimal choice for deployment and prediction in the wild.
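The winning design can be sketched as follows. This is an illustrative PyTorch reconstruction, not the authors' released code: the channel count, kernel sizes, normalization placement, and additive fusion are all assumptions.

```python
# Sketch of a residual unit with a dilated convolution on one branch and a
# dilated separable (depthwise + pointwise) convolution on the other.
import torch
import torch.nn as nn

class DilatedSeparableResidualUnit(nn.Module):
    def __init__(self, channels: int, dilation: int = 2):
        super().__init__()
        pad = dilation  # keeps the spatial size for 3x3 kernels
        # Branch A: plain dilated 3x3 convolution
        self.branch_a = nn.Sequential(
            nn.Conv2d(channels, channels, 3, padding=pad, dilation=dilation),
            nn.BatchNorm2d(channels),
            nn.ReLU(inplace=True),
        )
        # Branch B: dilated depthwise 3x3 followed by a pointwise 1x1
        self.branch_b = nn.Sequential(
            nn.Conv2d(channels, channels, 3, padding=pad,
                      dilation=dilation, groups=channels),  # depthwise
            nn.Conv2d(channels, channels, 1),               # pointwise
            nn.BatchNorm2d(channels),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        # identity skip plus the two convolutional branches
        return x + self.branch_a(x) + self.branch_b(x)

unit = DilatedSeparableResidualUnit(16)
out = unit(torch.randn(1, 16, 32, 32))
```

The depthwise convolution (`groups=channels`) plus the 1 × 1 pointwise convolution is the standard decomposition behind "separable" convolution; dilation widens each branch's receptive field without extra parameters.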
A Comprehensive Set of Novel Residual Blocks for Deep Learning … 45
In this section, we expound the application pipeline we used to deploy our
model for users to interact with. The pipeline consists of two fragments,
as depicted in Fig. 9: (a) the front end and (b) the back end. The user interacts with
an app in the front end, where one can upload an OCT image. The uploaded image
is then passed to the back end, where it first goes through a pre-defined set of
preprocessing steps and is then forwarded to our CNN model (Optic-Net 71). All of
these processes take place on the server. The class-prediction score output by our
model is then sent back to the app in response to the user's request for that input
image. Because our CNN is lightweight and capable of producing a high-precision
prediction in real time, it facilitates a smooth user experience. Meanwhile, the
uploaded image, along with its prediction tensor, is simultaneously stored in cloud
storage for further fine-tuning and training of our model, which furthers the goal
of heightening system precision and opens new horizons for further research and
development.
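The back-end flow described above can be sketched as plain functions. The names `preprocess`, `predict`, and `handle_upload` are hypothetical stand-ins (the chapter does not specify the server framework or the preprocessing steps); only the OCT2017 class names are taken from the cited data-set.

```python
# Sketch of the back-end request flow: preprocess -> CNN -> JSON response.
import json

def preprocess(image_bytes):
    # stand-in for the pre-defined preprocessing steps (decode, resize,
    # normalize); here just a toy normalization of a few bytes
    return [b / 255.0 for b in image_bytes[:4]]

def predict(tensor):
    # stand-in for the OpticNet-71 forward pass: one score per class
    classes = ["CNV", "DME", "DRUSEN", "NORMAL"]  # OCT2017 [16] classes
    return {c: 1.0 / len(classes) for c in classes}

def handle_upload(image_bytes):
    tensor = preprocess(image_bytes)
    scores = predict(tensor)
    # in the deployed system, the image/prediction pair would also be
    # archived in cloud storage for later fine-tuning
    return json.dumps({"scores": scores})

response = handle_upload(bytes([0, 128, 255, 64]))
```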
5 Conclusion
In this chapter, we introduced a novel set of residual blocks that can be used to
build a convolutional neural network for diagnosing retinal diseases with expert-level
precision. Additionally, by exploiting this architecture, we devised a practical
solution that addresses the problem of vanishing and exploding gradients. This work
is an extension of our previous work [14], which illustrates the exploratory analysis
of the different novel blocks and how effective they are in the diagnosis of retinal
degeneration. In the future, we would like to expand this research to address other
sub-types of retinal degeneration and to isolate the boundaries of the macular
subspace in the retina, which in turn will assist expert ophthalmologists in carrying
out their differential diagnosis.
References
5. R.A. Costa, M. Skaf, L.A. Melo Jr., D. Calucci, J.A. Cardillo, J.C. Castro, D. Huang, M.
Wojtkowski, Retinal assessment using optical coherence tomography. Prog. Retin. Eye Res.
25(3), 325–353 (2006)
6. Centers for Disease Control and Prevention, National diabetes statistics report, 2017 (2017)
7. B.M. Ege, O.K. Hejlesen, O.V. Larsen, K. Møller, B. Jennings, D. Kerr, D.A. Cavan, Screening
for diabetic retinopathy using computer based image analysis and statistical classification.
Comput. Methods Programs Biomed. 62(3), 165–175 (2000)
8. N. Ferrara, Vascular endothelial growth factor and age-related macular degeneration: from
basic science to therapy. Nat. Med. 16(10), 1107 (2010)
9. D.S. Friedman, B.J. O’Colmain, B. Munoz, S.C. Tomany, C. McCarty, P. De Jong, B. Nemesure,
P. Mitchell, J. Kempen et al., Prevalence of age-related macular degeneration in the united states.
Arch Ophthalmol 122(4), 564–572 (2004)
10. I. Ghorbel, F. Rossant, I. Bloch, S. Tick, M. Paques, Automated segmentation of macular
layers in OCT images and quantitative evaluation of performances. Pattern Recognit. 44(8),
1590–1603 (2011)
11. K. He, X. Zhang, S. Ren, J. Sun, Deep residual learning for image recognition, in Proceedings
of the IEEE Conference on Computer Vision and Pattern Recognition (2016), pp. 770–778
12. K. He, X. Zhang, S. Ren, J. Sun, Identity mappings in deep residual networks, in European
Conference on Computer Vision (Springer, 2016), pp. 630–645
13. R. Kafieh, H. Rabbani, S. Kermani, A review of algorithms for segmentation of optical coher-
ence tomography from retina. J. Med. Signals Sens. 3(1), 45 (2013)
14. S.A. Kamran, S. Saha, A.S. Sabbir, A. Tavakkoli, Optic-net: a novel convolutional neural
network for diagnosis of retinal diseases from optical tomography images, in 2019 18th IEEE
International Conference On Machine Learning And Applications (ICMLA) (2019), pp. 964–
971
15. S.P.K. Karri, D. Chakraborty, J. Chatterjee, Transfer learning based classification of opti-
cal coherence tomography images with diabetic macular edema and dry age-related macular
degeneration. Biomed. Opt. Express 8(2), 579–592 (2017)
16. D.S. Kermany, M. Goldbaum, W. Cai, C.C. Valentim, H. Liang, S.L. Baxter, A. McKeown, G.
Yang, X. Wu, F. Yan et al., Identifying medical diagnoses and treatable diseases by image-based
deep learning. Cell 172(5), 1122–1131 (2018)
17. A. Lang, A. Carass, M. Hauser, E.S. Sotirchos, P.A. Calabresi, H.S. Ying, J.L. Prince, Retinal
layer segmentation of macular oct images using boundary classification. Biomed. Opt. Express
4(7), 1133–1152 (2013)
18. C.S. Lee, D.M. Baughman, A.Y. Lee, Deep learning is effective for classifying normal versus
age-related macular degeneration oct images. Ophthalmol. Retin. 1(4), 322–327 (2017)
19. J.Y. Lee, S.J. Chiu, P.P. Srinivasan, J.A. Izatt, C.A. Toth, S. Farsiu, G.J. Jaffe, Fully automatic
software for retinal thickness in eyes with diabetic macular edema from images acquired by
cirrus and spectralis systems. Investig. ophthalmol. Vis. Sci. 54(12), 7595–7602 (2013)
20. K. Lee, M. Niemeijer, M.K. Garvin, Y.H. Kwon, M. Sonka, M.D. Abramoff, Segmentation
of the optic disc in 3-d OCT scans of the optic nerve head. IEEE Trans. Med. Imaging 29(1),
159–168 (2010)
21. G. Lemaître, M. Rastgoo, J. Massich, C.Y. Cheung, T.Y. Wong, E. Lamoureux, D. Milea, F.
Mériaudeau, D. Sidibé, Classification of sd-oct volumes using local binary patterns: experi-
mental validation for dme detection. J. Ophthalmol. 2016 (2016)
22. X. Chen, M. Niemeijer, L. Zhang, K. Lee, M.D. Abràmoff, M. Sonka, Three-dimensional
segmentation of fluid-associated abnormalities in retinal OCT: probability constrained
graph-search-graph-cut. IEEE Trans. Med. Imaging 31(8), 1521–1531 (2012)
23. A. Mishra, A. Wong, K. Bizheva, D.A. Clausi, Intra-retinal layer segmentation in optical
coherence tomography images. Opt. Express 17(26), 23719–23728 (2009)
24. H. Nguyen, A. Roychoudhry, A. Shannon, Classification of diabetic retinopathy lesions from
stereoscopic fundus images, in Proceedings of the 19th Annual International Conference of
the IEEE Engineering in Medicine and Biology Society.’Magnificent Milestones and Emerging
Opportunities in Medical Engineering (Cat. No. 97CH36136), vol. 1 (IEEE, 1997), pp. 426–
428
Abstract Lower child mortality rates, advances in medicine, and cultural changes
have increased life expectancy to above 60 years in developed countries. Some
countries expect that, by 2030, 20% of their population will be over 65 years old. The
quality of life at this advanced age is highly dictated by the individual’s health, which
will determine whether the elderly can engage in important activities to their well-
being, independence, and personal satisfaction. Old age is accompanied by health
problems caused by biological limitations and muscle weakness. This weakening
facilitates the occurrence of falls, which are responsible for the deaths of approximately
646,000 people worldwide each year; even a minor fall can cause fractures,
broken bones, or soft-tissue damage that will not heal completely.
Injuries of this nature, in turn, erode the self-confidence of the
individual, diminishing their independence. In this work, we propose a method capa-
ble of detecting human falls in video sequences using multi-channel convolutional
neural networks (CNN). Our method uses a 3D CNN fed with features previously
extracted from each frame to generate a vector for each channel. The
vectors are then concatenated, and a support vector machine (SVM) is applied to classify
them and indicate whether or not there was a fall. We experiment with four
types of features, namely: (i) optical flow, (ii) visual rhythm, (iii) pose estimation, and
(iv) saliency map. The benchmarks used (the UR Fall Detection Dataset (URFD) [33]
and the Fall Detection Dataset (FDD) [12]) are publicly available, and our results are
compared to those in the literature. The metrics selected for evaluation are balanced
accuracy, accuracy, sensitivity, and specificity. Our results are competitive with those
obtained by the state of the art on both URFD and FDD datasets. To the authors’
knowledge, we are the first to perform cross-tests between the datasets in question
and to report results for the balanced accuracy metric. The proposed method is able
to detect falls in the selected benchmarks. Fall detection, as well as activity classi-
fication in videos, is strongly related to the network’s ability to interpret temporal
information and, as expected, optical flow is the most relevant feature for detecting
falls.
1 Introduction
Developed countries have reached a life expectancy of over 60 years of age [76].
Some countries in the European Union and China expect 20% of their population to
be over 65 by 2030 [25]. According to the World Health Organization [25, 78], this
recent level is a side effect of scientific advances, medical discoveries, reduced child
mortality, and cultural changes.
However, although human beings are living longer, the quality of this life span
is mainly defined by their health, since it dictates an individual’s independence,
satisfaction, and the possibility of engaging in activities that are important to their
well-being. The relation between health and quality of life has inspired several research
groups to develop assistive technologies focused on the elderly population.
1.1 Motivation
Naturally, health problems will appear along with the aging process, mainly due to
the biological limitations of the human body and muscle weakness. As a side effect,
this weakening increases the chances of the elderly suffering a fall. Falls are the second
leading cause of unintentional-injury death worldwide, killing around 646,000 people
every year [25]. Reports indicate that between 28 and 35% of the population over 65 years
old falls at least once a year, and this percentage rises to 32–42% for individuals over
70 years old.
In an attempt to summarize which events may lead to a fall, Lusardi et al. [46]
reported risk factors, including how falls occur, who falls, and some precautions
to avoid them. The effects of a fall on this fragile population can unleash
a chain reaction of physiological and psychological damage, which consequently
can decrease the elderly's self-confidence and independence. Even if the accident
is a small fall, it can break or fracture bones and damage soft tissues that, due to old
age, may never fully recover.
To avoid serious injuries, the elderly population needs constant care and monitoring,
especially since an accident's effects are related to the time between
its occurrence and the beginning of adequate care.
Three-Stream Convolutional Neural Network … 51
Fig. 1 Components of an assisted daily living (ADL) system: sensors (blood pressure, heartbeat, IMU) and a camera feed a classifier in a local processing center, which produces technical reports and fall alerts for a remote operator
That being said, qualified caregivers are expensive, especially when added to the already inflated budget of health
care in advanced age, which might lead families to relocate the elderly. Relocation
is a common practice in which the family moves the elderly from their
dwelling place to the family's home, causing discomfort and the stress of adapting
to a new environment and routine.
Based on the scenario described previously, a technological solution can be
devised in the form of an emergency system that would reliably trigger qualified
assistance automatically. Thus, the response time between the accident and care
would be reduced, allowing the elderly to dwell in their own home, keeping their
independence and self-esteem. The diagram in Fig. 1 illustrates the components of
such system, also known as an assisted daily living (ADL) system.
The system is fed by a range of sensors, installed around the home or attached to
the subject's body, which monitor information such as blood pressure, heartbeat,
body temperature, blood glucose levels, acceleration, and posture. These devices
are connected to a local processing center that, in addition to continuously building
technical reports on the health of the elderly, also alerts a remote operator in an
emergency. Upon receiving the alert, the operator performs a situation check
and, upon verifying the need, dispatches medical assistance.
52 G. V. Leite et al.
Aware of the dangers and consequences that a fall represents for the elderly population,
and since a video-based fall detection system can benefit from deep learning techniques,
we propose a fall detection module, as part of an ADL system, which can detect falls
in video frame sequences. In implementing this proposal, we intend to answer the
following research questions:
1. Are the feature channels able to maintain enough temporal information for a
network to learn their patterns?
2. Which feature channel contributes most to the problem of fall detection?
3. Does the effectiveness of the proposed method hold for other human fall
datasets?
4. Are three-dimensional (3D) architectures more discriminative than their two-
dimensional (2D) equivalents for this problem?
1.3 Contributions
The main contribution of this work is the multi-channel model to detect human
falls [8, 9, 37]. Each channel represents a feature that we judge to be descriptive of a
fall event and, once ensembled, the channels produce a better estimation of a fall.
The model was tested on two publicly available datasets, and is open-sourced [38].
In addition, we present an extensive comparison between different channel combina-
tions and the employed architecture, whose best results are comparable to the state of
the art. Finally, we also discuss the implications of simulated datasets as a benchmark
for fall detection methods.
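The fusion step behind this multi-channel model — per-channel feature vectors concatenated and classified by an SVM — can be sketched as follows. The feature dimensions and the random stand-in data are purely illustrative, and scikit-learn's `SVC` is used as a generic SVM; the actual per-channel vectors would come from the 3D CNNs.

```python
# Sketch of channel fusion: concatenate per-channel vectors, classify with SVM.
import numpy as np
from sklearn.svm import SVC

rng = np.random.default_rng(0)
n = 40
# stand-ins for vectors produced by the 3D CNN for three of the channels
flow_feats   = rng.normal(size=(n, 8))   # optical flow channel
rhythm_feats = rng.normal(size=(n, 8))   # visual rhythm channel
pose_feats   = rng.normal(size=(n, 8))   # pose estimation channel
# synthetic fall / no-fall labels, separable by the first flow feature
labels = (flow_feats[:, 0] > 0).astype(int)

fused = np.concatenate([flow_feats, rhythm_feats, pose_feats], axis=1)
clf = SVC(kernel="linear").fit(fused, labels)
pred = clf.predict(fused)
```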
The remainder of this work is structured as follows. Section 2 reviews the existing
approaches in the related literature. Section 3 presents the main concepts used
in the implementation of this work. Section 4 describes the proposed methodology
and its steps in detail. Section 5 presents the selected datasets used in
the experiments, our evaluation metrics, the experiments carried out, their results,
and discussions. Finally, Sect. 6 presents some final considerations and proposals for
future work.
2 Related Work
In the following subsections, we review the work related to fall detection and the
methods employed. The works are split into two groups, based on the sensors used:
(i) methods without videos and (ii) methods with videos.
The works in this group utilize various sensors to obtain data: a watch, accelerometer,
gyroscope, heart rate sensor, or a smartphone; that is, any sensor that
is not a camera.
Khin et al. [28] developed a triangulation system that uses several presence sensors
installed in the monitored room of a home. The presence of someone in the room
causes a disturbance in the sensors and a fall would trigger an activation pattern.
Despite not reporting the results, the authors stated that they tested this configuration
and that the sensors detected different patterns for different actions, indicating that
the solution could be used to detect falls.
Kukharenko and Romanenko [31] obtained their data from a wrist device. A
threshold was used to detect impact or a “weightless” state, and after this detection,
the algorithm waits a second and analyzes the information obtained. The method was
tested on some volunteers, who reported complaints such as forgetting to wear the
device and the discomfort it caused.
Kumar et al. [32] placed a sensor attached to a belt on the person’s waist and
compared four methods to detect falls: threshold, support vector machines (SVM),
K -nearest neighbors (KNN) and dynamic time warping (DTW). The authors also
commented on the importance of sensors attached to the body, since they would
monitor the individual constantly as it would not present blind spots, as opposed to
cameras.
Vallejo et al. [72] developed a deep neural network to classify the sensor’s data.
The chosen sensor is a gyroscope worn at the waist and the network is composed
of three hidden layers, with five neurons each. The authors carried out experiments
with adults aged 19–56 years.
Zhao et al. [84] collected data from a gyroscope attached to the individual’s waist
and used a decision tree to classify the information. The experiments were performed
on five random adults.
Zigel et al. [86] used accelerometers and microphones as sensors; however, the
sensors were installed in the environment instead of being attached to the subject.
The sensors detected vibrations and fed a quadratic classifier. The tests were
executed with a test dummy, which was released from an upright position.
In this section, we group the methods whose main data sources are video sequences
from cameras. Despite this common sensor, the methods presented a variety of
solutions, such as the following works, which used threshold-based activation techniques.
To isolate the human silhouette, Lee and Mihailidis [36] performed background
subtraction alongside a region extraction. After this, the posture was determined
by thresholding the values of the silhouette's perimeter, the speed of the silhouette's
center, and the Feret's diameter. The solution was tested on a dataset created by
the authors.
Nizam et al. [53] used two thresholds, the first verified if the body speed was high
and, if so, a second threshold verified whether the joints’ position was close to the
ground. The joints’ position was obtained after subtracting the background, with a
Kinect camera. The experiments were carried out on a dataset created by the authors.
Sase and Bhandari [57] applied a threshold by which a fall was declared if the
region of interest was smaller than one-third of the individual's height. The region
of interest was obtained by background extraction, and the method was tested on the
URFD [33] dataset. Bhandari et al. [5] applied a threshold on the speed and direction of
the region of interest. A combination of Shi-Tomasi and Lucas–Kanade was applied
to determine the region of interest. The authors tested the approach in the URFD [33]
set and reported 95% accuracy.
Another widely used classification technique is the SVM, used by Abobakr et
al. [2]. The method used depth information to subtract the background of the video
frame, applied a random forest algorithm to estimate the posture and classified it
with SVM. Fan et al. [17] also separated the picture into foreground and
background and fitted an ellipse to the silhouette of the human body found. From the
ellipse, six features were extracted and fed to a slow feature function. An SVM
classified the output of the slow feature function, and the experiments were executed
on the SDUFall [47] dataset.
Harrou et al. [23] used an SVM classifier that received features extracted from
video frames. During the tests, the authors compared the SVM with a multivariate
exponentially weighted moving average (MEWMA) and tested the solution on the
datasets URFD [33] and FDD [12].
Mohd et al. [50] fed an SVM classifier with information on the height, speed,
acceleration, and position of the joints and performed tests on three datasets: TST Fall
Detection [20], URFD [33] and Fall Detection by Zhang [83]. Panahi and Ghods [56]
subtracted the background from the depth information, fitted an ellipse to the shape of
the individual, classified the ellipse with SVM and performed tests on the URFD [33]
dataset.
Concerned with privacy, the following works argued that solutions to detect falls
should offer anonymity options. Therefore, Edgcomb and Vahid [16] tested the effec-
tiveness of a binary tree classifier over a time series. The authors compared different
means to hide identity, such as blurring, extracting the silhouette, replacing the indi-
vidual with an opaque ellipse or an opaque box. They conducted tests on their dataset,
with 23 videos recorded. Lin et al. [42] investigated a solution focused on privacy
using a silhouette. They applied a K-NN classifier in addition to a timer that checks
whether the individual’s pose has returned to normal. The tests were performed by
laboratory volunteers.
Some studies used convolutional neural networks, as in the case of
Anishchenko [4], who implemented an adaptation of the AlexNet architecture
to detect falls on the FDD [12] dataset. Fan et al. [18] used a CNN to monitor and
assess the degree of completeness of an event. A stack of video frames was used in
a VGG-16 architecture and its result was associated with the first frame in the stack.
The method was tested on two datasets: FDD [12] and Hockey Fights [52]. Their
results were reported in terms of completeness of the falls.
Huang et al. [26] used the OpenPose algorithm to obtain the coordinates of the
body joints. Two classifiers (SVM and VGG-16) were compared to classify the coor-
dinates. The experiments were carried out on the datasets URFD [33] and FDD [12].
Li et al. [40] created a modified version of the AlexNet CNN architecture. The
solution was tested on the URFD [33] dataset, and the authors reported that it
distinguished between ADLs and falls in real time.
Min et al. [49] used an R-CNN (region-based CNN) to analyze a scene, generating
spatial relationships between the furniture and the human being in the scene,
and then classified the spatial relationship between them. The authors experimented
on three datasets: URFD [33], KTH [58], and a dataset created by them. Núñez-
Marcos et al. [54] performed the classification with a VGG-16. The authors calculated
the dense optical flow, which served as a characteristic for the network to classify.
They tested the method on the URFD [33] and FDD [12] databases.
Coincidentally, all works that used recurrent neural networks used the same
architecture, long short-term memory (LSTM). Lie et al. [41] applied a recurrent
neural network, with LSTM cells, to classify the individual's posture. The posture was
extracted by a CNN and the experiments were carried out on a dataset created by the
authors.
Shojaei-Hashemi et al. [60] used a Microsoft Kinect device to obtain the indi-
vidual’s posture information and an LSTM as a classifier. The experiments were
performed on the NTU RGB+D dataset. Furthermore, the authors reported one advan-
tage of using the Kinect, since the posture extraction could be achieved in real time.
Lu et al. [43] proposed the application of an LSTM right after a 3D CNN. The authors
performed tests on the URFD [33] and FDD [12] datasets.
Other machine learning algorithms, such as the K-nearest neighbors, were also
used to detect falls. Kwolek and Kepski [34] made use of a combination of an
accelerometer and Kinect. Once the accelerometer surpassed a threshold, a fall alert
was raised, and only then, the Kinect camera started capturing frames of the scene’s
depth, which was used by a second classifier. The authors compared the classification
of the frames between KNN and SVM and tested on two datasets, the URFD [33]
and on an independent one.
Sehairi et al. [59] developed a finite state machine to estimate the position of the
human head from the extracted silhouette. The tests were performed on the FDD [12]
dataset.
Markov filters were also used to detect falls, as in the work of Anderson et al. [3],
in which the individual's silhouette was extracted so that its features could be
classified by the Markov filter. The experiments were carried out on the authors'
own dataset.
Zerrouki and Houacine [81] described the characteristics of the body through
curvelet coefficients and the ratio between areas of the body. An SVM classifier
performed the posture classification and the Markov filter discriminated between
falls and non-falls. The authors reported experiments on the URFD [33] and FDD [12]
datasets.
In addition to the methods mentioned above, the following works made use of
several techniques. Yu et al. [80] obtained their features by applying head-tracking
techniques and analysis of shape variation. The features served as input to a
Gaussian classifier. The authors created a dataset for the tests.
Zerrouki et al. [82] segmented the frames into foreground and background
and applied a further segmentation to the human body, dividing it into five partitions.
The body segmentations were fed to an AdaBoost classifier, which obtained 96%
accuracy on the URFD [33] dataset. Finally, Xu et al. [79] published a survey that
evaluates several fall detection systems.
3 Basic Concepts
In this section, we describe in detail the concepts necessary to understand the pro-
posed methodology.
Deep neural networks (DNNs) are a class of machine learning algorithms in which
several layers of processing are used to extract and transform features from
the input data, and the backpropagation algorithm enables the network to learn the
complex patterns of the data. The input of each layer is the output of the previous
one, except for the first layer, which receives the data, and the last layer, from
which the outputs are extracted [22]. This structure is not necessarily fixed;
some layers can have two other layers as input or several outputs.
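This layer-to-layer flow can be sketched with two toy dense layers; the sizes and random weights are arbitrary illustrations, not any specific architecture.

```python
# Each layer consumes the previous layer's output: x -> h -> y.
import numpy as np

rng = np.random.default_rng(0)

def dense(x, w, b):
    # one processing layer: affine transform followed by a ReLU
    return np.maximum(w @ x + b, 0.0)

x = rng.normal(size=3)                              # input data
w1, b1 = rng.normal(size=(4, 3)), np.zeros(4)
w2, b2 = rng.normal(size=(2, 4)), np.zeros(2)

h = dense(x, w1, b1)   # first layer receives the input data
y = dense(h, w2, b2)   # next layer receives the previous layer's output
```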
Deng and Yu [15] mentioned some reasons for the growing popularity of deep net-
works in recent years, which include their results in classification problems, improve-
ments in graphic processing units (GPUs), the appearance of tensor processing units
(TPUs), and the amount of data available digitally.
The layers of a deep network can be organized in different ways to suit the task
at hand. The manner in which a DNN is organized is called its "architecture", and some
architectures have become well known because of their performance in image
classification competitions, among them AlexNet [30], LeNet [35], VGGNet [62] and
ResNet [24].
Convolutional neural networks (CNNs) are a subtype of deep networks. Their structure
is similar to that of a DNN, in that information flows from one layer to the next;
however, in a CNN the data is also processed by convolutional layers, which apply
various convolution operations and resize the data before sending it on to the next
layer.
These convolution operations allow the network to learn low-level features in
the first layers and merge them in the following layers to learn high-level features.
Although not mandatory, usually at the very end of a convolutional network there
are some fully connected layers.
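A single convolution of the kind these layers apply can be sketched in NumPy. The Sobel kernel used here is a classic hand-crafted analogue of the low-level edge features the first layers learn; a CNN would learn such kernels from data.

```python
# Naive "valid" 2D convolution (cross-correlation) over a toy image.
import numpy as np

def conv2d(img, kernel):
    kh, kw = kernel.shape
    H, W = img.shape
    out = np.zeros((H - kh + 1, W - kw + 1))
    for i in range(out.shape[0]):
        for j in range(out.shape[1]):
            # weighted sum of the kernel over the current window
            out[i, j] = np.sum(img[i:i + kh, j:j + kw] * kernel)
    return out

img = np.zeros((6, 6))
img[:, 3:] = 1.0                                   # a vertical edge
sobel_x = np.array([[1., 0., -1.],
                    [2., 0., -2.],
                    [1., 0., -1.]])                # responds to vertical edges
resp = conv2d(img, sobel_x)                        # strongest at the edge
```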
Within this work's scope, two CNN architectures are relevant: (i) VGG-16 [62] and
(ii) Inception [66]. The VGG-16 won the localization category of the ImageNet Large
Scale Visual Recognition Challenge (ILSVRC) 2014 and achieved a 7.3% top-5 error
in classification. Its choice of small filters (3 × 3 convolutions with stride 1 and
padding 1, and 2 × 2 max-pooling with stride 2) allowed the network to be deeper
without being computationally prohibitive. The VGG-16 has 16 layers and 138 million
parameters, a heavy parameter count among deep networks. The largest computational
load in this network occurs in the first layers since, after them, the pooling layers
considerably reduce the load on the deeper layers. Figure 2 illustrates the VGG-16
architecture.
The second architecture, Inception V1 [66], was the winner of ILSVRC 2014,
in the same year as VGG-16, however, on the classification category, with an error
of 6.7%. This network was developed to be deeper and, at the same time, more
computationally efficient. The Inception architecture has 22 layers and only 5 million
parameters. Its construction consists of stacking several modules, called Inception,
illustrated in Fig. 3a.
The modules were designed to create something of a network within the network,
in which several convolutions and max-pooling operations are performed in parallel
and, at the end of these, the features are concatenated to be sent to the next module.
However, if the network was composed of Inception modules, as illustrated in Fig. 3a,
it would perform 850 million operations in total. To reduce this number, bottlenecks
were created.
Fig. 2 The VGG-16 architecture: the input, thirteen convolutional layers (Conv 1-1 through Conv 5-3) interleaved with five pooling layers, three fully connected (FC) layers, and the output
Fig. 3 The Inception module: a the naive version, in which 1 × 1, 3 × 3, and 5 × 5 convolutions and a max-pooling operation are applied in parallel to the previous layer and their filters concatenated; b the version with 1 × 1 bottleneck convolutions added
Bottlenecks reduce the number of operations to 358 million. They are
1 × 1 convolutions that preserve the spatial dimension while decreasing the depth
of the features. To do so, they were placed before the convolutions 3 × 3, 5 × 5 and
after the max-pooling 3 × 3, as illustrated in Fig. 3b.
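The savings from a 1 × 1 bottleneck can be reproduced with a back-of-the-envelope multiply count. The feature-map size and channel counts below are illustrative assumptions, not the exact figures behind the 850/358 million totals.

```python
# Multiply counts for a 5x5 convolution, with and without a 1x1 bottleneck.
H = W = 28          # feature-map size (assumed)
c_in = 192          # input channels (assumed)
c_out = 32          # output channels of the 5x5 branch (assumed)
k = 5               # kernel size

# direct 5x5 convolution: every output pixel mixes all input channels
direct = H * W * c_out * c_in * k * k

# bottleneck: 1x1 reduces depth first, then the 5x5 runs on fewer channels
bottleneck_ch = 16
with_bottleneck = (H * W * bottleneck_ch * c_in          # 1x1 reduction
                   + H * W * c_out * bottleneck_ch * k * k)  # 5x5 conv
```

With these numbers the bottleneck cuts the multiply count by roughly a factor of ten, which is the same effect that takes the module totals from 850 to 358 million operations.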
Transfer learning consists of reusing a network whose weights were trained in another
context, usually a related and more extensive one. In addition to usually improving
a network's performance, the technique also helps decrease the convergence
time and helps in scenarios where not enough training data is available [85].
Typically, in classification problems, the transfer happens from the ImageNet [14]
dataset, which is one of the largest and best-known datasets. The goal is for the
network to learn complex patterns that are generic enough to be used in
another context.
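The mechanics can be sketched in PyTorch: freeze the transferred layers and train only a new task-specific head. The tiny backbone below is a stand-in (loading real ImageNet weights is omitted), and all sizes are illustrative.

```python
# Sketch of transfer learning: frozen backbone + trainable new head.
import torch
import torch.nn as nn

# Stand-in backbone; imagine its weights were learned on a large source
# dataset such as ImageNet.
backbone = nn.Sequential(
    nn.Conv2d(3, 8, 3, padding=1), nn.ReLU(),
    nn.AdaptiveAvgPool2d(1), nn.Flatten(),
)
for p in backbone.parameters():
    p.requires_grad = False          # freeze the transferred layers

head = nn.Linear(8, 2)               # new classifier for the target task
model = nn.Sequential(backbone, head)

out = model(torch.randn(4, 3, 16, 16))
# only the head's weight and bias would receive gradient updates
trainable = [p for p in model.parameters() if p.requires_grad]
```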
The literature does not provide a universal definition of a fall [29, 45, 67]; however,
some health agencies have created their own definitions, which can be used to convey
a general idea of a fall.
The Joint Commission [68] defines a fall as “[...] a fall may be described as an
unintentional change in position coming to rest on the ground, floor, or onto the
next lower surface (e.g., onto a bed, chair or bedside mat). [...]”. The World Health
Organization [77] defines it as “[...] an event which results in a person coming to
rest inadvertently on the ground or floor or other lower level. [...]”. The definition
of the National Center for Veterans Affairs [70] reads: "loss of upright position that
results in landing on the floor, ground, or an object or furniture, or a sudden,
uncontrolled, unintentional, non-purposeful, downward displacement of the body to the
floor/ground or hitting another object such as a chair or stair [...]”. In this work, we
describe a fall as an involuntary movement that results in an individual landing on
the ground.
Optical flow is a technique that deduces pixel movement caused by the displacement
of the object or the camera. It is a vector that represents the movement of a region,
extracted from a sequence of frames. It assumes that the pixels will not leave the
frame region and, being a local method, it is difficult to compute in uniform
regions.
The extraction of the optical flow is performed by comparing two consecutive
frames, and its representation is a vector of direction and magnitude. Consider I
a video frame and I(x, y, t) a pixel in this frame. A frame analyzed at a future
time dt is described in Eq. 1 as a function of the displacement (dX, dY) of the pixel
I(x, y, t). Eq. 2 is obtained from a first-order Taylor expansion divided by dt and
involves the gradients f_t, f_x, and f_y (Eq. 3). The components of the optical flow are
the values of u and v (Eq. 4), which can be obtained by several methods, such as Lucas–
Kanade [44] or Farnebäck [19]. Figure 4 illustrates some optical flow frames.

$$I(x, y, t) = I(x + dX, y + dY, t + dt) \tag{1}$$

$$f_x u + f_y v + f_t = 0 \tag{2}$$

$$f_x = \frac{\partial I}{\partial x}, \quad f_y = \frac{\partial I}{\partial y}, \quad f_t = \frac{\partial I}{\partial t} \tag{3}$$

$$u = \frac{dX}{dt}, \quad v = \frac{dY}{dt} \tag{4}$$
60 G. V. Leite et al.
Fig. 4 Examples of extracted optical flow. Each optical flow frame was extracted from the frame
above and its successor in the sequence. The pixel colors indicate the movement direction, and the
brightness relates to the magnitude of the movement
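The brightness-constancy reasoning above can be sketched numerically. The following is a minimal NumPy estimate of a single (u, v) vector in the spirit of Lucas–Kanade [44]; the synthetic frames and helper name are illustrative, not taken from the chapter's code:

```python
import numpy as np

def lucas_kanade_uv(frame0, frame1):
    """Estimate one (u, v) motion vector for a textured region by
    solving f_x * u + f_y * v + f_t = 0 in the least-squares sense."""
    # Spatial gradients f_y, f_x of the first frame, temporal gradient f_t.
    fy, fx = np.gradient(frame0.astype(float))
    ft = frame1.astype(float) - frame0.astype(float)
    # One equation per pixel: [f_x f_y] [u v]^T = -f_t.
    A = np.stack([fx.ravel(), fy.ravel()], axis=1)
    b = -ft.ravel()
    (u, v), *_ = np.linalg.lstsq(A, b, rcond=None)
    return u, v

# Synthetic example: a low-frequency texture shifted one pixel to the right.
ys, xs = np.mgrid[0:64, 0:64]
frame0 = np.sin(0.3 * xs) * np.cos(0.2 * ys)
frame1 = np.roll(frame0, 1, axis=1)      # pattern moves +1 pixel in x
u, v = lucas_kanade_uv(frame0, frame1)   # expect u close to 1, v close to 0
```

In the chapter itself the dense flow is computed with Farnebäck's method (e.g., OpenCV's `calcOpticalFlowFarneback`), which solves the same constraint per pixel; the sketch above only illustrates Eqs. 2–4 on a single region.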
Fig. 5 Visual rhythm construction process. On the left, the zigzag manner in which each frame is
traversed. On the right, the rhythm construction through the column concatenation

Fig. 6 Visual rhythm examples. Each frame on the first row illustrates a different video and, below,
the visual rhythm for the corresponding entire video

In the context of image processing, a saliency map is a feature of the image that
represents regions of interest in an image. The map is generally presented in shades
of gray, so that regions with low interest are black and regions with high interest
are white. In the context of deep learning, the saliency map is the activation map of a
classifier, highlighting the regions that contributed most to the output classification.
Originally, the saliency map was extracted as a way of understanding what deep
networks were learning, and it is still used in that way, as in the work of Li et al. [39].
The saliency map was used by Zuo et al. [87] as a feature to classify actions
from an egocentric point of view, in which the source of information is a camera
that corresponds to the subject's first-person view. Using the saliency map in the
egocentric context is rooted in the assumption that important aspects of the
action take place in the foreground instead of the background.
The saliency map can be obtained in several ways, as shown by Smilkov et al. [64],
Sundararajan et al. [65] and Simonyan et al. [63]. Figure 7 illustrates the extraction
of the saliency map.
Fig. 7 Saliency map examples. Pixels vary from black to white, according to the region importance
to the external classifier
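The SmoothGrad idea of Smilkov et al. [64] — averaging gradients over noisy copies of the input — can be sketched independently of any particular network. Here a plain analytic gradient function stands in for a classifier's backward pass; this is an illustrative simplification, not the chapter's implementation:

```python
import numpy as np

def smoothgrad(grad_fn, x, n_samples=50, sigma=0.1, seed=0):
    """SmoothGrad: average grad_fn over n_samples noisy copies of x."""
    rng = np.random.default_rng(seed)
    grads = [grad_fn(x + rng.normal(0.0, sigma, size=x.shape))
             for _ in range(n_samples)]
    return np.mean(grads, axis=0)

# Toy 'score' f(x) = sum(x ** 2) has gradient 2x; the added noise averages out,
# leaving a less noisy estimate of the true sensitivity map.
grad_fn = lambda x: 2.0 * x
x = np.array([1.0, -2.0, 0.5])
saliency = smoothgrad(grad_fn, x)   # close to [2, -4, 1]
```

With a real network, `grad_fn` would be the gradient of the class score with respect to the input image, obtained from the deep learning framework's automatic differentiation.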
Pose estimation is a technique to derive the posture of one or more human beings.
Different sensors can be used as input to this technique, such as depth sensors from
a Microsoft Kinect or images from a camera.
The OpenPose algorithm, proposed by Cao et al. [7], is notable for its effectiveness
in estimating the pose of individuals in video frames. OpenPose operates with a two-stage
network in search of 18 body joints. Its first stage creates probability maps of
the joint positions, and the second stage predicts affinity fields between the limbs
found. The affinity is represented by a 2D vector, which encodes the position and
orientation of each body limb. Figure 8 shows the posture estimation of some frames.
Fig. 8 Example of the posture estimation. Each circle represents a joint, whereas the edges represent
the limbs. Each limb has a specific color assigned to it
4 Proposed Method
4.1 Preprocessing
In the deep learning literature, the positive influence of preprocessing steps is
ubiquitously acknowledged. In this sense, some processes were executed to better
tackle the task in question.
In this work, the preprocessing step, represented by the green block in Fig. 9,
consists of extracting features that can capture the various aspects of a fall and
applying data augmentation techniques.
The following features were extracted and later fed to the network in a specific
manner. The posture estimation was extracted using a bottom-up approach, with the
OpenPose algorithm by Cao et al. [7]. The extracted frames were fed into the network
one at a time and the inference results were obtained frame by frame (Fig. 8). Regarding
the visual rhythm, we implemented an algorithm such that each video has only
one visual rhythm. This rhythm frame was fed to the network repeatedly, so that its
inference output could be paired with the other features (Fig. 6). The saliency map was
obtained using the SmoothGrad technique, proposed and implemented by Smilkov
et al. [64], which acts on top of a previously existing technique by Sundararajan et
al. [65]. The frames were fed to the network, once again, one by one (Fig. 7).
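The visual rhythm construction of Fig. 5 — traversing each frame and concatenating the results as columns — can be sketched as follows. The zigzag row traversal is our reading of the figure; the helper names are illustrative:

```python
import numpy as np

def zigzag_flatten(frame):
    """Traverse a 2-D frame row by row, reversing every other row,
    producing one zigzag column as suggested by Fig. 5."""
    rows = [row if i % 2 == 0 else row[::-1] for i, row in enumerate(frame)]
    return np.concatenate(rows)

def visual_rhythm(frames):
    """One column per frame; concatenating the columns yields a single
    2-D image that summarizes the entire video."""
    return np.stack([zigzag_flatten(f) for f in frames], axis=1)

# Four tiny 2x3 frames standing in for a video.
video = [np.arange(6).reshape(2, 3) + t for t in range(4)]
vr = visual_rhythm(video)   # shape: (2 * 3, 4) -- pixels x frames
```

Because the whole video collapses into one image, the rhythm frame is fed to the network repeatedly so its output can be paired frame by frame with the other channels, as described above.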
The optical flow extraction was carried out with the algorithm proposed by
Farnebäck [19], which computes the dense optical flow (Fig. 4).

Fig. 9 Overview of the proposed method that illustrates the training and test phases, their steps
and the information flow throughout the method

Fig. 10 Illustration of the sliding window and how it moves to create each stack

As a fall event happens throughout several frames and the flow represents only the
relationship between two of them, we employed a sliding window approach, suggested
by Wang et al. [74].
The sliding window feeds the network with a stack of ten frames of optical flow. The
first stack contains frames 1 to 10, the second stack frames 2 to 11, and so on with
a stride of 1 (Fig. 10); hence, each video yields N − 10 + 1 stacks, where N is the
number of frames in the video. If fewer than ten frames remain at the end of a video,
they do not contribute to the evaluation. The resulting inference of each stack was
associated with the first frame of the stack, so that the optical flow could be paired
with the other channels of the network.
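The sliding-window stacking described above can be sketched in pure Python (function name and return shape are illustrative):

```python
def optical_flow_stacks(flow_frames, window=10, stride=1):
    """Group flow frames into overlapping stacks of `window` frames.
    Leftover frames at the end (fewer than `window`) are discarded, and
    each stack is labelled with the index of its first frame."""
    n = len(flow_frames)
    return [(start, flow_frames[start:start + window])
            for start in range(0, n - window + 1, stride)]

frames = list(range(12))             # 12 flow frames, indices 0..11
stacks = optical_flow_stacks(frames) # 12 - 10 + 1 = 3 stacks
```

Labelling each stack with its first frame is what lets the flow stream's prediction be paired with the per-frame predictions of the other streams.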
Data augmentation techniques were used, when applicable, in the training phase.
The following augmentations were employed: vertical axis mirroring, perspective
transform, cropping and adding mirrored borders, adding values between −20 and
20 to pixels, and adding values between −15 and 15 to the hue and saturation.
The whole augmentation process was applied only to the RGB channel, as the other
channels would suffer negatively from it. For instance, the optical flow information
depends strictly on the relationship between frames, and its magnitude is expressed by
the brightness of the pixels; mirroring an optical flow frame would break the continuity
between frames, and adding values to pixels would distort the magnitude of the vectors.
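Two of the listed augmentations — vertical-axis mirroring and adding a value to the pixels — can be sketched with NumPy. As in the text, these would be applied only to RGB frames; the value range follows the chapter, and the function names are ours:

```python
import numpy as np

def mirror_vertical_axis(frame):
    """Flip the frame left-right (mirroring over the vertical axis)."""
    return frame[:, ::-1]

def add_pixel_value(frame, delta):
    """Add a value in [-20, 20] to every pixel, clipping to the valid
    uint8 range so the result is still a displayable image."""
    return np.clip(frame.astype(int) + delta, 0, 255).astype(np.uint8)

rng = np.random.default_rng(0)
frame = rng.integers(0, 256, size=(4, 4, 3), dtype=np.uint8)
augmented = add_pixel_value(mirror_vertical_axis(frame), delta=20)
```

Note how the same two operations applied to a flow frame would be harmful: the flip breaks temporal continuity and the added value distorts the encoded magnitude, which is exactly why augmentation was restricted to the RGB stream.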
4.2 Training
Due to the small amount of available data to experiment on, the training phase of
our method requires the application of transfer learning techniques. Thus, the model
was first trained on the ImageNet [14] dataset and, later, on our selected fall dataset;
this whole process is illustrated in Fig. 11. Traditionally, other works might freeze
some layers between the transfer learning and the training; this was not the case here,
and all the layers were trained with the fall dataset.
Fig. 11 Transfer learning process. From left to right, initially the model has no trained weights,
then it is trained on the ImageNet [14] dataset and, finally, it is trained on the selected fall dataset
The same transfer learning was done for all feature channels, meaning that, independently
of the feature a channel will be trained on, its starting point was the ImageNet
training. After this, each channel was trained with its extracted feature frames.
Although we selected four features, we did not combine all of them at the same time,
keeping at most three combined features at a time.
4.3 Test
The selected features in this work were previously used in the related literature
[7, 69]; however, they were employed in a single-stream manner. These works, along
with the work of Simonyan et al. [61], motivated us to propose a multi-stream
methodology that would join the different aspects of each feature, so that a more
robust description of the event could be achieved. This ensemble can be accomplished
in several ways, varying from a simple average between the channels, through a
weighted average, to automatic methods, like the one used in this work: the application
of an SVM classifier.
The workflow of the test phase, illustrated on the right of Fig. 9, begins with the
feature extraction, identical to the training phase. After that, the weights obtained in
the training phase were loaded, and all layers of the model were frozen. Then, each
channel with its specific trained model performed inferences on its input data. The
output vectors were concatenated and sent to the SVM classifier, which in turn would
classify each frame as fall or not fall.
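The fusion step — concatenating each stream's output vector and classifying the result with an SVM — can be sketched with scikit-learn. The random vectors below are stand-ins for the real per-stream softmax outputs; the separation between classes is artificial and only serves the illustration:

```python
import numpy as np
from sklearn.svm import SVC

rng = np.random.default_rng(0)

def stream_outputs(n, fall):
    """Stand-in for the concatenated per-frame outputs of three streams
    (3 streams x 2 classes = 6 values per frame)."""
    center = 0.8 if fall else 0.2
    return rng.normal(center, 0.05, size=(n, 3 * 2))

# Concatenated per-frame vectors for both classes.
X = np.vstack([stream_outputs(50, fall=True), stream_outputs(50, fall=False)])
y = np.array([1] * 50 + [0] * 50)   # 1 = fall, 0 = not fall

clf = SVC(kernel="rbf").fit(X, y)
pred = clf.predict(stream_outputs(5, fall=True))
```

Compared to a simple or weighted average of the streams, the SVM learns how much to trust each stream's scores directly from the data.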
5 Experimental Results
In this section, we describe the datasets used in the experiments, the metrics selected
to evaluate our method, the executed experiments, and how our method stands against
other proposed methods.
5.1 Datasets
Upon reviewing the related literature, a few datasets were found; however, some were
not publicly available, hyperlinks to the data were inactive, or the authors did not
answer our contact attempts. This led us to select the following human fall datasets:
(i) URFD [33] and (ii) FDD [12].
5.1.1 URFD
Published by Kwolek and Kepski [33], the URFD (University of Rzeszow Fall
Detection) dataset is made up of 70 video sequences, 30 of which are falls and 40
everyday activities. Each video has 30 frames per second (FPS), a resolution of
640 × 240 pixels, and varying length.
The fall sequences were recorded with an accelerometer and two Microsoft Kinect
cameras, one with a horizontal view of the scene and one with a top-down view from
the ceiling. The activities of daily living (ADL) were recorded with a single horizontal-view
camera and an accelerometer. The accelerometer information was excluded
from the experiments, as it went beyond the scope of the project. As illustrated in
Fig. 12, the dataset has five ADL scenarios but a single fall scenario, in which the
camera angle and background are the same and only the actors in the scene change;
this lack of variety is further discussed in the experiments.
The dataset is annotated with the following information:
• Subject’s posture (not lying, lying on the floor and transition).
• Ratio between height and width of the bounding box.
• Ratio between maximum and minimum axes.
• Ratio of the subject’s occupancy in the bounding box.
• Standard deviation of the pixels to the centroid of the X and Z axes.
• Ratio between the subject’s height in the frame and the subject’s standing height.
• Subject’s height.
• Distance from the subject’s center to the floor.
5.1.2 FDD
The FDD (Fall Detection Dataset) was published by Charfi et al. [12] and contains
191 video sequences, 143 being falls and 48 day-to-day activities. Each video has
25 FPS, a resolution of 320 × 240 pixels, and varying length.
All sequences were recorded with a single camera in four different environments:
home, coffee room, office, and classroom, illustrated in Fig. 13. In addition, the
dataset presents three experimentation protocols: (i) training and testing are created
with videos from the home and coffee room environments; (ii) the training consists
of videos from the coffee room and the test of videos from the office and the
classroom; and (iii) the training contains videos from the coffee room, the office,
and the classroom, and the test contains videos from the office and the classroom.
The dataset is annotated with the following information:
• Initial frame of the fall.
• Final frame of the fall.
• Height, width, and coordinates of the center of the bounding box in each frame.
The metrics used to evaluate the method are defined in Eqs. 5 to 9, in which TP, TN,
FP, and FN denote true positives, true negatives, false positives, and false negatives,
respectively:

$$\text{Precision} = \frac{TP}{TP + FP} \tag{5}$$

$$\text{Sensitivity} = \frac{TP}{TP + FN} \tag{6}$$

$$\text{Accuracy} = \frac{TP + TN}{TP + TN + FP + FN} \tag{7}$$

$$\text{Balanced Accuracy} = \frac{1}{\sum_i \hat{w}_i} \sum_i 1(\hat{y}_i = y_i)\, \hat{w}_i \tag{8}$$

in which

$$\hat{w}_i = \frac{w_i}{\sum_j 1(y_j = y_i)\, w_j} \tag{9}$$
These metrics were chosen because of the need to compare our results with those
found in the literature, which, for the most part, reported only precision, sensitivity,
and accuracy.
Considering that both datasets are unbalanced, with the negative class having more
than twice as many samples as the positive class (falls), we chose the balanced
accuracy over other metrics because, as its name states, it balances the sample
weights and, in doing so, takes the false negatives into account. False negatives
are especially important in fall detection, as ignoring a fall incident can lead to the
health problems described in Sect. 1.
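Balanced accuracy as defined in Eqs. 8 and 9 can be computed directly; with uniform per-sample weights it reduces to the average of the per-class recalls, which is why it penalizes missed falls that plain accuracy hides:

```python
import numpy as np

def balanced_accuracy(y_true, y_pred, w=None):
    """Eqs. 8-9: accuracy with each sample reweighted by the inverse
    (weighted) frequency of its true class."""
    y_true, y_pred = np.asarray(y_true), np.asarray(y_pred)
    w = np.ones(len(y_true)) if w is None else np.asarray(w, float)
    # Eq. 9: w_hat_i = w_i / sum_j 1(y_j == y_i) w_j
    class_mass = {c: w[y_true == c].sum() for c in np.unique(y_true)}
    w_hat = w / np.array([class_mass[c] for c in y_true])
    # Eq. 8: weighted fraction of correct predictions.
    return ((y_pred == y_true) * w_hat).sum() / w_hat.sum()

# Unbalanced toy labels: 6 negatives, 2 positives (falls).
y_true = [0, 0, 0, 0, 0, 0, 1, 1]
y_pred = [0, 0, 0, 0, 0, 0, 1, 0]   # misses one of the two falls
bacc = balanced_accuracy(y_true, y_pred)   # (6/6 + 1/2) / 2 = 0.75
```

On this toy example plain accuracy would be 7/8 = 0.875, while the balanced accuracy of 0.75 reflects that half of the falls were missed.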
This method was implemented using the Python programming language [73], which
was chosen due to its wide availability of libraries for image analysis and deep
learning applications. Moreover, several libraries were used, such as SciPy [27],
NumPy [55], OpenCV [6], and Keras [13], along with the TensorFlow [1] framework.
Deep learning algorithms are known to be computationally intensive. Their training
and experiments require more computational power than a conventional notebook
can provide and, therefore, were carried out in the cloud on a rented Amazon AWS
g2.2xlarge machine with the following specifications: 1× Nvidia GRID K520 GPU
(Kepler), 8 vCPUs, and 15 GB of RAM.
5.4 Experiments
Next, the performed experiments, their results, and a discussion are presented. The
experiments were split between multi-channel tests, cross-tests, and literature comparisons.
To the knowledge of the authors, the cross-tests, in which the model was
trained on one dataset and tested on the other, are unprecedented among the selected
datasets. The data were split in proportions of 65% for training, 15% for validation,
and 20% for testing.
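The 65/15/20 split can be sketched as follows (the shuffling, seed, and function name are our illustrative choices; the chapter does not specify how the split was drawn):

```python
import random

def split_indices(n, train=0.65, val=0.15, seed=0):
    """Shuffle sample indices and split them 65/15/20 into
    training, validation, and test partitions."""
    idx = list(range(n))
    random.Random(seed).shuffle(idx)
    n_train, n_val = int(n * train), int(n * val)
    return (idx[:n_train],
            idx[n_train:n_train + n_val],
            idx[n_train + n_val:])

tr, va, te = split_indices(100)   # 65 / 15 / 20 samples
```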
All our experiments were executed with the following parameters: 500 epochs with
early stopping and a patience of 10, a learning rate of 10−5, mini-batches of 192,
dropout of 50%, and the Adam optimizer, training to minimize the validation loss.
The results are reported first on the URFD dataset, followed by those on the FDD
set.
The results obtained on the URFD dataset are shown in Table 1, organized in
decreasing order of balanced accuracy. In this first experiment, the combination
of the optical flow and RGB channels obtained the best result, whereas the pose
estimation channels obtained the worst results.
Table 2 shows once again the efficacy of the optical flow channel; however, in this
instance, the RGB channels suffered a slight drop in performance, in contrast to the
pose estimation ones, which rose from the worst results.
The first cross-test was performed with training done on the URFD dataset and
testing on the FDD set; its results are shown in Table 3. Upon executing a cross-test,
it is expected that the model will not perform as well, since it faces a completely
new dataset. The highest balanced accuracy was 68%, a sharp drop from the 98%
of Table 2; furthermore, a majority of the channels could not surpass the 50% mark,
indicating that the model was barely able to generalize its learning.
We believe that this drop in the balanced accuracy is an effect of the training
dataset quality. As explained previously, the URFD dataset is very limited in the
variability of its fall scenarios, presenting the same scenario repeatedly. Returning to
the multi-channel test in the URFD dataset (Table 2), it is possible to notice that the
channels with access to the background were among the best results, namely, visual
rhythm and RGB, but in the cross-test the best results were obtained from channels
without access to the background information. This could indicate that channels with
access to the background learned some features present in the scenario and were not
able to detect falls when the scenario changed, in contrast to those without access to
the background, which were still able to detect falls upon scenario changes.
The second cross-test was the opposite experiment, with training done on the
FDD dataset and testing on the URFD set (Table 4). Although there is, once again,
a sharp drop in the balanced accuracy, it performed much better than the previous
test, reaching a maximum balanced accuracy of 84% due to the optical flow and pose
estimation channels. Moreover, a similar scenario is observable in both cross-tests:
the channels without access to the background were better able to discriminate
between fall and not fall.
Given that in the second cross-test even the channels with access to the background
showed an improvement in their performance, one might argue that the background
had nothing to do with it. However, it is important to reiterate that the FDD dataset
is more heterogeneous than the URFD, to the point that falls and ADL activities
were recorded in the same environment. This variability, added to the intrinsic nature
of the 3D model, which creates an internal temporal relation between the input
data, probably allowed the channels with access to the background to focus on
features of the fall movement itself; hence the larger number of channels with more
than 50% balanced accuracy.
Finally, we compared our best results and the ones from our previous work with those
found in the literature. As stated before, we did not compare our balanced accuracy,
since no other work reported theirs. The results regarding the URFD dataset are
reported in Table 5, while those on the FDD set are shown in Table 6; both are sorted
in decreasing order of accuracy.
Our Inception 3D architecture surpassed or matched the reviewed works on both
datasets, as well as our previous method using the VGG-16 architecture. However,
the VGG-16 was not able to surpass the work of Lu et al. [43], which, curiously,
employs a 3D method, a 3D CNN combined with an LSTM architecture, explaining
its performance.
In this work, we presented and compared our deep neural network method to detect
human falls in video sequences, using a 3D architecture, the Inception V1 3D
network. The training and evaluation of the method were performed on two public
datasets and, regarding effectiveness, it outperformed or matched the related
literature.
The results pointed out the importance of temporal information in the detection
of falls, both in the fact that the temporal channels were always among the best
results, especially the optical flow, and in the improvement obtained by the 3D
method when compared to our previous work on the VGG-16 architecture. The 3D
method was also able to generalize its learning to a never-seen dataset, and this
ability to generalize indicates that the 3D method can be considered a strong
candidate to compose an elderly monitoring system.
This is evidenced throughout the results shown in Sect. 5, since the temporal
channels are always among the most effective, except for two cases shown in Tables 2
and 3, in which the third-best result was the combination of the spatial channels of
saliency and pose. This can be attributed to the fact that the 3D architecture itself
provides a temporal relationship between the data.
Our conclusion about the importance of temporal information to the fall classifi-
cation is corroborated by other works found in the literature, such as those by Meng
et al. [48] and Carreira and Zisserman [10], who stated the same for the classification
of actions in videos. This indicates that other deep learning architectures, such as
those described in [75], could also be used for this application.
In addition, as our results surpassed the reviewed works, the method demonstrated
itself to be effective in detecting falls. In a specific instance, our method matched
the results of the work of Lu et al. [43], in which the author makes use of an LSTM
architecture that, like ours, creates temporal relationships between the input data.
Innovatively, cross-tests were performed between the datasets. The results of these
tests showed a known, yet interesting, facet of neural networks, in which the
minimization function finds a local minimum that does not correspond to the initial
objective of the solution. During the training, some channels of the network learned
aspects of the background of the images to classify the falls, instead of focusing on
aspects of the person's movement in the video.
The method developed in our work is supported by several factors, such as the
evaluation through the balanced accuracy, the tests being performed on two different
datasets, the heterogeneity of the FDD dataset, the execution of the cross-tests, and
the comparisons between the various channel combinations. On the other hand, the
work also deals with some difficulties, such as (i) the low variability of the fall videos
in the URFD set, (ii) the fact that in the cross-tests many combinations of channels
obtained only 50% balanced accuracy, and (iii) the use of simple accuracy as a
means of comparison with the literature. However, the proposed method overcomes
these drawbacks, remaining relevant to the problem at hand.
The effectiveness of the method shows that, if trained on a robust enough dataset,
it can extract the temporal patterns necessary to classify scenarios between fall and
non-fall. Admittedly, there is an expected drop in balanced accuracy in the cross
tests. Regarding fall detection, this work is one of the most accurate approaches and
would be a great contribution as a module in an integrated system of assistance to
the elderly.
Concerning future work, some points can be mentioned: (i) exploring other
datasets that may contain a greater variety of scenarios and actions, (ii) integrating
fall detection into a multi-class system, (iii) experimenting with features that are
cheaper to extract, (iv) adapting the method to work in real time, either through
cheaper channels or a lighter architecture, and (v) dealing with the input as a stream
of video, instead of clips, because in a real scenario the camera would continuously
feed the system, which would further unbalance the classes.
The contributions of this work are presented in the form of a human fall detection
method, implemented and publicly available in the repository [38], as well as the
experimentation using multi-channels and different datasets, which generated not
only a discussion about which metrics are more appropriate for evaluating fall solu-
tions, but also a discussion of the quality of the datasets used in these experiments.
Acknowledgements The authors are thankful to FAPESP (grant #2017/12646-3), CNPq (grant
#309330/2018-7), and CAPES for their financial support, as well as Semantix Brasil for the infras-
tructure and support provided during the development of the present work.
References
15. L. Deng, D. Yu, Deep learning: methods and applications. Found. Trends Signal Process.
7(3–4), 197–387 (2014)
16. A. Edgcomb, F. Vahid, Automated fall detection on privacy-enhanced video, in Annual Inter-
national Conference of the IEEE Engineering in Medicine and Biology Society (2012), pp.
252–255
17. K. Fan, P. Wang, S. Zhuang, Human fall detection using slow feature analysis. Multimed. Tools
Appl. 78(7), 9101–9128 (2018a)
18. Y. Fan, G. Wen, D. Li, S. Qiu, M.D. Levine, Early event detection based on dynamic images
of surveillance videos. J. Vis. Commun. Image Represent. 51, 70–75 (2018b)
19. G. Farnebäck, Two–frame motion estimation based on polynomial expansion, in Scandinavian
Conference on Image Analysis (2003), pp. 363–370
20. S. Gasparrini, E. Cippitelli, E. Gambi, S. Spinsante, J. Wåhslén, I. Orhan, T. Lindh, Proposal
and experimental evaluation of fall detection solution based on wearable and depth data fusion,
in International Conference on ICT Innovations (Springer, 2015), pp. 99–108
21. M.A. Goodale, A.D. Milner, Separate visual pathways for perception and action. Trends Neu-
rosci. 15(1), 20–25 (1992)
22. I. Goodfellow, Y. Bengio, A. Courville, Y. Bengio, Deep Learning (MIT Press, 2016)
23. F. Harrou, N. Zerrouki, Y. Sun, A. Houacine, Vision-based fall detection system for improving
safety of elderly people. IEEE Instrum. & Meas. Mag. 20(6), 49–55 (2017)
24. K. He, X. Zhang, S. Ren, J. Sun, Deep residual learning for image recognition, in IEEE
Conference on Computer Vision and Pattern Recognition (2016), pp. 770–778
25. D.L. Heymann, T. Prentice, L.T. Reinders, The World Health Report: a Safer Future: global
Public Health Security in the 21st Century (World Health Organization, 2007)
26. Z. Huang, Y. Liu, Y. Fang, B.K. Horn, Video-based fall detection for seniors with human pose
estimation, in 4th International Conference on Universal Village (IEEE, 2018), pp. 1–4
27. E. Jones, T. Oliphant, P. Peterson, SciPy: open source scientific tools for python (2001). http://
www.scipy.org
28. O.O. Khin, Q.M. Ta, C.C. Cheah, Development of a wireless sensor network for human fall
detection, in International Conference on Real-Time Computing and Robotics (IEEE, 2017),
pp. 273–278
29. Y. Kong, J. Huang, S. Huang, Z. Wei, S. Wang, Learning spatiotemporal representations for
human fall detection in surveillance video. J. Vis. Commun. Image Represent. 59, 215–230
(2019)
30. A. Krizhevsky, I. Sutskever, G.E. Hinton, Imagenet classification with deep convolutional
neural networks. Adv. Neural Inf. Process. Syst. 25, 1097–1105 (2012)
31. T. Kukharenko, V. Romanenko, Picking a human fall detection algorithm for wrist–worn elec-
tronic device, in IEEE First Ukraine Conference on Electrical and Computer Engineering
(2017), pp. 275–277
32. V.S. Kumar, K.G. Acharya, B. Sandeep, T. Jayavignesh, A. Chaturvedi, Wearable sensor–based
human fall detection wireless system, in Wireless Communication Networks and Internet of
Things (Springer, 2018), pp. 217–234
33. B. Kwolek, M. Kepski, Human fall detection on embedded platform using depth maps and
wireless accelerometer. Comput. Methods Programs Biomed. 117(3), 489–501 (2014)
34. B. Kwolek, M. Kepski, Improving fall detection by the use of depth sensor and accelerometer.
Neurocomputing 168, 637–645 (2015)
35. Y. LeCun, L. Bottou, Y. Bengio, P. Haffner, Gradient-based learning applied to document
recognition. Proc. IEEE 86(11), 2278–2324 (1998)
36. T. Lee, A. Mihailidis, An intelligent emergency response system: preliminary development and
testing of automated fall detection. J. Telemed. Telecare 11(4), 194–198 (2005)
37. G. Leite, G. Silva, H. Pedrini, Fall detection in video sequences based on a three-stream
convolutional neural network, in 18th IEEE International Conference on Machine Learning
and Applications (ICMLA) (Boca Raton-FL, USA, 2019), pp. 191–195
38. G. Leite, G. Silva, H. Pedrini, Fall detection (2020). https://github.com/Lupins/fall_detection
39. H. Li, K. Mueller, X. Chen, Beyond saliency: understanding convolutional neural networks
from saliency prediction on layer-wise relevance propagation. Comput. Res. Repos. (2017a)
40. X. Li, T. Pang, W. Liu, T. Wang, Fall detection for elderly person care using convolutional
neural networks, in 10th International Congress on Image and Signal Processing, BioMedical
Engineering and Informatics (2017b), pp. 1–6
41. W.N. Lie, A.T. Le, G.H. Lin, Human fall-down event detection based on 2D skeletons and deep
learning approach, in International Workshop on Advanced Image Technology (2018), pp. 1–4
42. B.S. Lin, J.S. Su, H. Chen, C.Y. Jan, A fall detection system based on human body silhouette,
in 9th International Conference on Intelligent Information Hiding and Multimedia Signal
Processing (IEEE, 2013), pp. 49–52
43. N. Lu, Y. Wu, L. Feng, J. Song, Deep learning for fall detection: 3D-CNN combined with
LSTM on video kinematic data. IEEE J. Biomed. Health Inform. 23(1), 314–323 (2018)
44. B.D. Lucas, T. Kanade, An iterative image registration technique with an application to stereo
vision, in International Joint Conference on Artificial Inteligence (1981), pp. 121–130
45. F. Luna-Perejon, J. Civit-Masot, I. Amaya-Rodriguez, L. Duran-Lopez, J.P. Dominguez-
Morales, A. Civit-Balcells, A. Linares-Barranco, An automated fall detection system using
recurrent neural networks, in Conference on Artificial Intelligence in Medicine in Europe
(Springer, 2019), pp. 36–41
46. M.M. Lusardi, S. Fritz, A. Middleton, L. Allison, M. Wingood, E. Phillips, Determining risk of
falls in community dwelling older adults: a systematic review and meta-analysis using posttest
probability. J. Geriatr. Phys. Ther. 40(1), 1–36 (2017)
47. X. Ma, H. Wang, B. Xue, M. Zhou, B. Ji, Y. Li, Depth-based human fall detection via shape
features and improved extreme learning machine. J. Biomed. Health Inform. 18(6), 1915–1922
(2014)
48. L. Meng, B. Zhao, B. Chang, G. Huang, W. Sun, F. Tung, L. Sigal, Interpretable
Spatio-Temporal Attention for Video Action Recognition (2018), pp. 1–10. arXiv preprint
arXiv:181004511
49. W. Min, H. Cui, H. Rao, Z. Li, L. Yao, Detection of human falls on furniture using scene analysis
based on deep learning and activity characteristics. IEEE Access 6, 9324–9335 (2018)
50. M.N.H. Mohd, Y. Nizam, S. Suhaila, M.M.A. Jamil, An optimized low computational algorithm
for human fall detection from depth images based on support vector machine classification,
in IEEE International Conference on Signal and Image Processing Applications (2017), pp.
407–412
51. T.P. Moreira, D. Menotti, H. Pedrini, First-person action recognition through visual rhythm
texture description, in International Conference on Acoustics (Speech and Signal Processing,
IEEE, 2017), pp. 2627–2631
52. E.B. Nievas, O.D. Suarez, G.B. García, R. Sukthankar, Violence detection in video using
computer vision techniques, in International Conference on Computer Analysis of Images and
Patterns (Springer, 2011), pp. 332–339
53. Y. Nizam, M.N.H. Mohd, M.M.A. Jamil, Human fall detection from depth images using position
and velocity of subject. Procedia Comput. Sci. 105, 131–137 (2017)
54. A. Núñez-Marcos, G. Azkune, I. Arganda-Carreras, Vision-based fall detection with convolu-
tional neural networks. Wirel. Commun. Mob. Comput. 2017, 1–16 (2017)
55. T.E. Oliphant, Guide to NumPy, 2nd edn. (CreateSpace Independent Publishing Platform, USA,
2015)
56. L. Panahi, V. Ghods, Human fall detection using machine vision techniques on RGB-D images.
Biomed. Signal Process. Control 44, 146–153 (2018)
57. P.S. Sase, S.H. Bhandari, Human fall detection using depth videos, in 5th International Con-
ference on Signal Processing and Integrated Networks (IEEE, 2018), pp. 546–549
58. C. Schuldt, I. Laptev, B. Caputo, Recognizing human actions: a local SVM approach, in 17th
International Conference on Pattern Recognition, vol. 3 (IEEE, 2004), pp 32–36
59. K. Sehairi, F. Chouireb, J. Meunier, Elderly fall detection system based on multiple shape
features and motion analysis, in International Conference on Intelligent Systems and Computer
Vision (IEEE, 2018), pp. 1–8
Three-Stream Convolutional Neural Network … 79
60. A. Shojaei-Hashemi, P. Nasiopoulos, J.J. Little, M.T. Pourazad, Video–based human fall detec-
tion in smart homes using deep learning, in IEEE International Symposium on Circuits and
Systems (2018), pp. 1–5
61. K. Simonyan, A. Zisserman, Two-stream convolutional networks for action recognition in
videos. Adv. Neural Inf. Process. Syst. 27, 568–576 (2014a)
62. K. Simonyan, A. Zisserman, Very deep convolutional networks for large-scale image recogni-
tion (2014b), pp. 1–14. arXiv, arXiv:14091556
63. K. Simonyan, A. Vedaldi, A. Zisserman, Deep inside convolutional networks: visualising image
classification models and saliency maps. Comput. Res. Repos. (2013)
64. D. Smilkov, N. Thorat, B. Kim, F. Viégas, M. Wattenberg, Smoothgrad: removing noise by
adding noise (2017), pp. 1–10. arXiv preprint arXiv:170603825
65. M. Sundararajan, A. Taly, Q. Yan, Axiomatic attribution for deep networks, in 34th Interna-
tional Conference on Machine Learning, vol. 70, pp. 3319–3328 (JMLR.org, 2017)
66. C. Szegedy, W. Liu, Y. Jia, P. Sermanet, S. Reed, D. Anguelov, going deeper with convolutions,
in IEEE Conference on Computer Vision and Pattern Recognition (2015), pp. 1–9
67. S.K. Tasoulis, G.I. Mallis, S.V. Georgakopoulos, A.G. Vrahatis, V.P. Plagianakos, I.G. Maglo-
giannis, Deep learning and change detection for fall recognition, in Engineering Applications
of Neural Networks, ed. by J. Macintyre, L. Iliadis, I. Maglogiannis, C. Jayne (Springer Inter-
national Publishing, Cham, 2019), pp. 262–273
68. The Joint Commission, Fall reduction program—definition of a fall (2001)
69. B.S. Torres, H. Pedrini, Detection of complex video events through visual rhythm. Vis. Comput.
34(2), 145–165 (2018)
70. US Department of Veterans Affairs, Falls policy overview (2019). http://www.patientsafety.
va.gov/docs/fallstoolkit14/05_falls_policy_overview_v5.docx
71. F.B. Valio, H. Pedrini, N.J. Leite, Fast rotation-invariant video caption detection based on visual
rhythm. in Iberoamerican Congress on Pattern Recognition (Springer, 2011), pp. 157–164
72. M. Vallejo, C.V. Isaza, J.D. Lopez, Artificial neural networks as an alternative to traditional
fall detection methods, in 35th Annual International Conference of the IEEE Engineering in
Medicine and Biology Society (2013), pp. 1648–1651
73. G. Van Rossum, F.L. Jr Drake, Python reference manual. Tech. Rep. Report CS-R9525, Centrum
voor Wiskunde en Informatica, Amsterdam (1995)
74. L. Wang, Y. Xiong, Z. Wang, Y. Qiao, Towards good practices for very deep two-stream
convnets (2015), pp. 1–5. arXiv preprint arXiv:150702159
75. M. Wani, F. Bhat, S. Afzal, A. Khan, Advances in Deep Learning (Springer, 2020)
76. World Health Organization, Global Health and Aging (2011)
77. World Health Organization, Fact sheet falls (2012)
78. World Health Organization, World Report on Ageing and Health (2015)
79. T. Xu, Y. Zhou, J. Zhu, New advances and challenges of fall detection systems: a survey. Appl.
Sci. 8(3), 418 (2018)
80. M. Yu, S.M. Naqvi, J. Chambers, A robust fall detection system for the elderly in a smart
room, in IEEE International Conference on Acoustics Speech and Signal Processing (2010),
pp. 1666–1669
81. N. Zerrouki, A. Houacine, Combined curvelets and hidden Markov models for human fall
detection. Multimed. Tools Appl. 77(5), 6405–6424 (2018)
82. N. Zerrouki, F. Harrou, Y. Sun, A. Houacine, Vision-based human action classification using
adaptive boosting algorithm. IEEE Sens. J. 18(12), 5115–5121 (2018)
83. Z. Zhang, V. Athitsos, Fall detection by zhong zhang and vassilis athitsos (2020). http://vlm1.
uta.edu/~zhangzhong/fall_detection/
84. S. Zhao, W. Li, W. Niu, R. Gravina, G. Fortino, Recognition of human fall events based on
single tri–axial gyroscope, in IEEE 15th International Conference on Networking, Sensing and
Control (2018), pp. 1–6
85. F. Zhuang, Z. Qi, K. Duan, D. Xi, Y. Zhu, H. Zhu, H. Xiong, Q. He, A comprehensive survey
on transfer learning (2019), pp. 1–27. arXiv preprint arXiv:191102685
80 G. V. Leite et al.
86. Y. Zigel, D. Litvak, I. Gannot, A method for automatic fall detection of elderly people using floor
vibrations and sound-proof of concept on human mimicking doll falls. IEEE Trans. Biomed.
Eng. 56(12), 2858–2867 (2009)
87. Z. Zuo, B. Wei, F. Chao, Y. Qu, Y. Peng, L. Yang, Enhanced gradient-based local feature
descriptors by saliency map for egocentric action recognition. Appl. Syst. Innov. 2(1), 1–14
(2019)
Diagnosis of Bearing Faults in Electrical
Machines Using Long Short-Term
Memory (LSTM)
© The Editor(s) (if applicable) and The Author(s), under exclusive license 81
to Springer Nature Singapore Pte Ltd. 2021
M. A. Wani et al. (eds.), Deep Learning Applications, Volume 2,
Advances in Intelligent Systems and Computing 1232,
https://doi.org/10.1007/978-981-15-6759-9_4
82 R. Sabir et al.
and five features from time–frequency domain by using the Wavelet Packet Decom-
position (WPD). After the extraction of these eight features, the well-known Deep
Learning algorithm Long Short-Term Memory (LSTM) is used for bearing fault classification. The LSTM algorithm is mostly used in speech recognition due to its ability to model temporal dependencies; in this paper, its capability is also demonstrated for bearing fault classification, where it achieves an accuracy of 97%. The proposed algorithm is compared with traditional Machine Learning techniques, and it is shown that the proposed methodology outperforms all the traditional algorithms used for the classification of bearing faults with MCS. The method developed is
independent of the speed and the loading conditions.
1 Introduction
Rolling element bearings are key components in rotating machinery. They support radial and axial loads while reducing rotational friction, allowing the machinery to carry heavy loads. They are critical to the optimal performance of the machinery, but their
failures can lead to the downtime of the machinery causing significant economic
losses [1, 2]. In an electrical machine, bearings are the components most susceptible to damage due to overloading stress and misalignment. Bearing faults are caused by insufficient lubrication, fatigue, incorrect mounting, corrosion, electrical damage, or foreign particles (contamination), and normally appear as wear, indentation, and spalling [3, 4]. Bearing failures account for approximately 50% of the failures in electrical machines [4]. The structure of the rolling element
bearing is shown in Fig. 1. Nb is the number of balls, Db is the diameter of the balls,
Dc is the diameter of the cage, and β is the contact angle between the ball and
the raceway. Rolling element bearings are comprised of two rings, an inner ring
and an outer ring as shown in Fig. 1. The inner race is normally mounted on the
motor shaft. Between the two rings, a set of balls or rollers is located. Continued stress results in the loss of material fragments from the inner and outer races. Due to the
susceptibility of bearings being damaged, bearing fault diagnosis has attracted the
attention of many researchers. Typical bearing fault diagnosis includes analysis of
different measured signals from the electrical machine with damaged bearings by
using signal-processing methods.
In this paper, only the inner and outer raceway faults will be analyzed. Each of these faults corresponds to a characteristic frequency f_c, which is given by (1) and (2), where f_r is the rotating frequency of the rotor.
Outer raceway: f_0 = (N_b/2) f_r (1 − (D_b/D_c) cos β)   (1)

Inner raceway: f_i = (N_b/2) f_r (1 + (D_b/D_c) cos β)   (2)
These frequencies arise when the rotating balls hit the defect, causing an impulsive effect. The characteristic frequencies are a function of the bearing geometry and the rotor frequency. The impulsive effect generates rotor eccentricities at the
characteristic frequencies, causing the air-gap between the rotor and the stator to
be varied. Because of the variation of the air-gap length, the machine inductances
are also varied. These characteristic frequencies and sidebands are derived in [5].
However, these can only be reliably detected under specific conditions, as they are highly influenced by external noise and operating conditions. Most of the time, these frequencies only appear in the frequency spectrum for high-severity faults.
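The two characteristic frequencies in (1) and (2) can be computed directly from the bearing geometry. A minimal numpy sketch follows; the geometry values are illustrative, not the chapter's test bearing:

```python
import numpy as np

def outer_race_freq(n_b, d_b, d_c, beta, f_r):
    # Eq. (1): f_0 = (N_b / 2) f_r (1 - (D_b / D_c) cos(beta))
    return n_b / 2 * f_r * (1 - d_b / d_c * np.cos(beta))

def inner_race_freq(n_b, d_b, d_c, beta, f_r):
    # Eq. (2): f_i = (N_b / 2) f_r (1 + (D_b / D_c) cos(beta))
    return n_b / 2 * f_r * (1 + d_b / d_c * np.cos(beta))

# Illustrative geometry: 8 balls, ball diameter 7.9 mm, cage diameter 39 mm,
# zero contact angle, rotor speed 1500 rpm.
f_r = 1500 / 60  # 25 Hz
f_0 = outer_race_freq(8, 7.9e-3, 39e-3, 0.0, f_r)
f_i = inner_race_freq(8, 7.9e-3, 39e-3, 0.0, f_r)
```

Note that by construction f_0 + f_i = N_b f_r, and the inner raceway frequency is always the higher of the two.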
2 Literature Review
cost. Furthermore, [9] argues that the vibration signal can be easily influenced by
loose screws and resonance frequencies of different parts of the machine, especially
the housing, which may lead to incorrect or misleading fault diagnosis.
According to [5], the stator current or Motor Current Signal (MCS) can also be effectively used for bearing fault diagnosis, since currents can be easily measured from the existing frequency inverters and no additional sensors are required. As already
discussed, damages in the bearings, e.g., a hole, pit, crack, or missing material result
in the creation of shock pulses when the ball passes over those defects. The shock
pulses result in vibration at the bearing characteristic frequencies. This causes the
average air-gap length of the rotor to be varied, resulting in changes in the flux density
(or machine inductances) that consequently appear in the stator current at the frequencies f_bf given by (3) and modulate the rotating frequency of the machine.

f_bf = |f_s ± k f_c|   (3)
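Equation (3) can be enumerated with a small helper that lists the first few sideband pairs around the supply frequency; the function name and example values are illustrative:

```python
# Eq. (3): fault components appear in the current spectrum at f_bf = |f_s +/- k f_c|.
def current_fault_freqs(f_s, f_c, k_max=3):
    # Collect both signs for each harmonic order k and return them sorted.
    return sorted({abs(f_s + sign * k * f_c)
                   for k in range(1, k_max + 1) for sign in (-1, 1)})

# Example: 50 Hz supply, 80 Hz characteristic frequency, first two orders.
sidebands = current_fault_freqs(f_s=50.0, f_c=80.0, k_max=2)
```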
In the literature, several Machine Learning and Deep Learning methods are employed; each method has its own advantages and disadvantages. For example, in [13], an unsupervised classification technique called artificial ant clustering is described, but it requires three voltage and three current sensors, which makes the method economically unappealing. References [14, 15] describe the Support Vector Machine (SVM) method for fault diagnosis, but in [14] only outer raceway faults are considered, and the accuracy drops from 100% to 95% when multi-fault classification is done. In [15], although the inner raceway, outer raceway, and cage faults are all considered, the algorithm performs well only at lower speeds, with an average accuracy of 99% that drops to 92% at higher speeds. In [16], a Convolutional Neural
Network (CNN) is trained with the amplitude of the selected frequency components
from the motor’s current spectrum. The results show that the accuracy is about 86%,
and only the bearing outer race fault is considered. In [17], Sparse Autoencoders (AE)
are used to extract fault patterns in the stator current and the vibration signals, and
SVM is used for classification. AE gives a 96% accuracy for high severity bearing
faults but 64% for low severity faults, and only the outer race fault is considered.
Finally, in [18], two feature extraction methods are used which include a 1D-CNN
and WPT (Wavelet Packet transform). In the analysis, two types of bearing faults
are classified, inner ring fault and fault due to aluminum powder in the bearing. The
method presented covers a wide speed range and fault severities with 98.8% accuracy, but it fails to include load variation and also requires high-specification training hardware.
In this paper, stator current is used to diagnose two bearing faults, inner and
outer raceway, using the well-known Deep Learning algorithm, LSTM network.
The bearing fault data considered takes into account different operating conditions
such as speed, load variations, and varying fault severity levels. Hence, the method
developed is independent of the speed and the loading conditions. In the first section
of the paper, the LSTM network is described, then feature extraction and LSTM
training are discussed, and finally the results from the trained LSTM are analyzed.
3 LSTM Network
where
x_t is the input to the RNN at time t
h_t is the state of the hidden layer at time t
h_{t−1} is the state of the hidden layer at time t − 1
b_h is the bias of the hidden layer
b_y is the bias of the output
w_hx are the weights between the input and the hidden layer
w_hh are the weights between the hidden layer at t − 1 and the hidden layer at t.
RNN is trained across the timesteps using Backpropagation Through Time (BPTT)
[21]. However, because gradients are repeatedly multiplied across timesteps, the gradient value shrinks or grows exponentially, and as a result the RNN learning process encounters vanishing or exploding gradients. Due to this issue, RNNs are used only in limited applications. The issue is solved by the LSTM (Long Short-Term Memory),
which includes a memory cell that replaces the hidden RNN units [22].
The memory cell is composed of four gates, forget gate (how much to keep from
previous cell state), input gate (whether to write in the cell), input modulation gate
(how much quantity to write in the cell), and output gate (how much to reveal from
the cell) [23]. LSTMs preserve the error that is backpropagated through time and
layers. Figure 3 shows a typical LSTM cell. In Fig. 3, σ (sigmoid) represents the
gate activation function, whereas ϕ (tanh) is the input or output node activation. The
LSTM model [21] presented in Fig. 3 is described by (5).
g_t = ϕ(w_gx x_t + w_gh h_{t−1} + b_g)
i_t = σ(w_ix x_t + w_ih h_{t−1} + b_i)
f_t = σ(w_fx x_t + w_fh h_{t−1} + b_f)
o_t = σ(w_ox x_t + w_oh h_{t−1} + b_o)
s_t = g_t i_t + s_{t−1} f_t
h_t = ϕ(s_t) o_t   (5)
where
• w_gx, w_ix, w_fx, and w_ox are the weights at time t between the input and hidden layer
• w_gh, w_ih, w_fh, and w_oh are the weights between the hidden layer at time t − 1 and the hidden layer at time t
• b_g, b_i, b_f, and b_o are the biases of the gates
• h_{t−1} is the value of the hidden layer at time t − 1
• f_t, g_t, i_t, and o_t are the output values of the forget gate, input modulation gate, input gate, and output gate, respectively
• s_t and s_{t−1} are the cell states at time t and t − 1, respectively.
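The gate equations in (5) can be sketched as a single numpy step; the layer sizes and weight initialization below are illustrative, not the configuration used in the chapter:

```python
import numpy as np

def lstm_step(x, h_prev, s_prev, W, b):
    """One LSTM cell update following Eq. (5)."""
    sigma = lambda z: 1 / (1 + np.exp(-z))          # gate activation
    g = np.tanh(W["gx"] @ x + W["gh"] @ h_prev + b["g"])  # input modulation gate
    i = sigma(W["ix"] @ x + W["ih"] @ h_prev + b["i"])    # input gate
    f = sigma(W["fx"] @ x + W["fh"] @ h_prev + b["f"])    # forget gate
    o = sigma(W["ox"] @ x + W["oh"] @ h_prev + b["o"])    # output gate
    s = g * i + s_prev * f                                # cell state update
    h = np.tanh(s) * o                                    # hidden state
    return h, s

# Illustrative sizes: 8 input features, 4 hidden units.
rng = np.random.default_rng(1)
n_in, n_hid = 8, 4
W = {k + x: rng.normal(scale=0.1, size=(n_hid, n_in if x == "x" else n_hid))
     for k in "gifo" for x in "xh"}
b = {k: np.zeros(n_hid) for k in "gifo"}
h, s = lstm_step(rng.normal(size=n_in), np.zeros(n_hid), np.zeros(n_hid), W, b)
```

Because the output gate is a sigmoid and the cell state passes through tanh, every component of h stays strictly inside (−1, 1).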
In [24, 25], a benchmark dataset for the purpose of bearing fault diagnosis is devel-
oped. The dataset is composed of synchronously measured currents of two phases
along with vibration signals. The stator currents are measured with a current transducer, sampled at 64 kHz, and then filtered with a 25 kHz low-pass filter. The dataset is
composed of tests on six undamaged bearings and 26 damaged bearings. In the 26
damaged bearings, 12 have artificially induced damages and 14 have real damages.
In most of the research done on bearing fault diagnosis, only artificial bearing
faults are induced, and data is collected to make Machine Learning models because
they are easy to generate. Nevertheless, these models fail to deliver the expected
diagnosis when used in practical industry applications. In [24, 25], Machine Learning
models were trained with artificially damaged bearings and tested with real damaged
bearing. These models were not able to accurately classify the real damaged bearings.
The reason is that it is not possible to accurately replicate real damages to the bearings artificially. Obtaining real damaged bearings from machinery is not easy. Due to their long lifetimes, bearings are generally replaced before
failure. Hence, it is difficult to find large quantities of real damaged bearings. In
[24, 25], by using scientific test rigs, real damaged bearings are produced by way of
accelerated lifetime tests. An advantage of such a technique is that the bearing damages can be generated under reproducible conditions. The disadvantage, however, is that a lot of time and effort is required for such a process. The real damage in the bearing is generated by applying a very high radial force to the bearing on a test rig. The applied radial force is far greater than what the bearing can endure, which causes damage to appear much sooner. To further accelerate the process, low-viscosity oil is used, which results in improper lubrication. Though the dataset is invaluable as it provides test data for both real and artificial bearing damages, this paper will only focus on the real bearing damages for the purpose of demonstrating the diagnosis algorithm; the approach could of course be extended to the bearings
with artificial damages. The datasets of the three classes (healthy, outer race, and
inner race) are described in Table 1. The damages to the bearings are of varying
severity levels. The detailed description of these datasets and geometry of the used
bearing can be found in [24, 25]. All the damaged bearing datasets are composed of single-point damages resulting from fatigue, except the dataset KA30, which is composed of distributed bearing damage due to plastic deformation.
Each dataset (e.g., K001) has 80 measurements (20 measurements corresponding
to each operating condition) of 4 s each. The different operating conditions described in Table 2 are used so that the Deep Learning algorithm incorporates variations in speed and load, and the model does not depend on specific operating conditions. The operating parameters that are varied include the speed, radial force,
Table 2 Operating parameters of each dataset

No.  Rotational speed [rpm]  Load torque [Nm]  Radial force [N]
0    1500                    0.7               1000
1    900                     0.7               1000
2    1500                    0.1               1000
3    1500                    0.7               400
and the load torque. For three settings, the speed of 1500 rpm is used, and in one
setting, a speed of 900 rpm is used. Similarly, a load torque of 0.7 Nm is used for three settings and 0.1 Nm for one setting, and a radial force of 1000 N is used for three settings and 400 N for one setting. All the datasets have damages that are less
than 2 mm in size, except for datasets KI16, KI18, and KA16 which have damages
greater than 2 mm. The temperature is kept constant at about 50 °C throughout all
experiments.
Before the Machine or Deep Learning algorithms are applied, the data is prepro-
cessed, and then the important features are extracted to be used as an input to the
algorithm for fault classification. It is therefore necessary to employ techniques that extract the most useful features and discard irrelevant or redundant information. Removal of noisy, irrelevant, and misleading features gives a more compact
representation and improves the quality of detection and diagnosis [13].
Figure 4a shows the stator current signal of phase 1 of the machines, and Fig. 4b
shows its frequency spectrum. Looking closely at Fig. 4a, slight variations of the sine wave amplitude can be observed. These small amplitude variations could contain the bearing characteristic frequencies. Hence, they are further analyzed for
feature extraction. The spectrum shows the two dominant frequencies ω0 and 5ω0 .
These frequencies offer no new information and are present in all the current signals
(whether belonging to the healthy or to the damaged bearing). Therefore, removing these frequencies from the current signal will give more focus to the frequencies that result from the bearing faults. To remove the unwanted frequencies (ω0 and 5ω0), a signal-processing filter is designed in MATLAB that suppresses them in the current signal. Figure 4c shows the filtered current signal, and Fig. 4d shows its spectrum. This filtered signal is then used for feature extraction. The spectrum of the filtered signal now contains only noise and the characteristic bearing frequencies (if present).

Fig. 4 a Stator current, b frequency spectrum of stator current, c stator current after filtration of the ω0 and 5ω0 components, d frequency spectrum of the filtered stator current
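The suppression step can be sketched in numpy with a simple FFT-domain notch; this is an assumed stand-in for the chapter's MATLAB filter, and the signal parameters below are illustrative:

```python
import numpy as np

def suppress_harmonics(x, fs, f0, half_width=2.0):
    """Zero the FFT bins around f0 and 5*f0 and reconstruct the signal."""
    X = np.fft.rfft(x)
    freqs = np.fft.rfftfreq(len(x), d=1 / fs)
    for f in (f0, 5 * f0):
        X[np.abs(freqs - f) <= half_width] = 0.0  # notch around each harmonic
    return np.fft.irfft(X, n=len(x))

# 1 s of synthetic current: 50 Hz supply, its 5th harmonic, plus a weak
# 87 Hz tone standing in for a bearing fault component.
fs = 6400
t = np.arange(fs) / fs
x = (np.sin(2 * np.pi * 50 * t) + 0.5 * np.sin(2 * np.pi * 250 * t)
     + 0.1 * np.sin(2 * np.pi * 87 * t))
y = suppress_harmonics(x, fs, f0=50.0)
```

After filtering, the dominant supply components vanish from the spectrum while the weak fault tone is preserved.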
Huo et al. [26] evaluate the statistical features of the vibration signals from the bearings and conclude that the two most sensitive parameters for detecting bearing faults are the kurtosis and the impulse factor. This conclusion applies well to the current signals, because these two statistical features adequately capture the impulsive behavior of faulty bearings. Another useful time domain feature is the clearance factor, which is also sensitive to bearing faults. The
remaining features are extracted from the time–frequency domain using third-level
WPD (Wavelet Packet Decomposition).
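The three time-domain features can be computed with numpy; since the chapter does not spell out its exact formulas, the standard textbook definitions are assumed here:

```python
import numpy as np

def kurtosis(x):
    # Fourth central moment normalized by the squared variance.
    x = np.asarray(x, dtype=float)
    return np.mean((x - x.mean()) ** 4) / np.var(x) ** 2

def impulse_factor(x):
    # Peak value divided by the mean absolute value.
    a = np.abs(np.asarray(x, dtype=float))
    return a.max() / a.mean()

def clearance_factor(x):
    # Peak value divided by the squared mean of the square roots.
    a = np.abs(np.asarray(x, dtype=float))
    return a.max() / np.mean(np.sqrt(a)) ** 2
```

All three indicators grow when an isolated impulse is superimposed on an otherwise smooth signal, which is exactly the signature a bearing defect leaves in the filtered current.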
WPD is one of the well-known signal-processing techniques widely used in fault diagnosis. It provides useful information in both the time and frequency domains. WPD decomposes the time domain signal into wavelets of various scales with variable-sized windows and reveals the local structure in the time–frequency domain. The advantages of WPD over other signal-processing techniques are that transient features can be effectively extracted and that features can be extracted from the full spectrum without requiring a specific frequency band. WPD uses mother wavelets, which are basic wavelet functions that expand, compress, and translate the signal by varying the scale (frequency) and the time shift of the wavelet. This enables short windows at low scale for high frequencies and long windows at high scale for low frequencies: short windows give higher resolution in time for high-frequency components, while long windows give higher resolution in frequency for low-frequency components. The
wavelet function must meet the following requirement in (6).
C_ψ = ∫_R |ψ(ω)|² / |ω| dω < ∞   (6)
where ψ_{a,b}(t) is the continuous mother wavelet, scaled by a factor a and translated by a factor b:

ψ_{a,b}(t) = (1/√|a|) ψ((t − b)/a)   (7)

in which the factor 1/√|a| is used for energy preservation. ψ_{a,b}(t) acts as a window function whose frequency and time location can be adjusted by a and b. For example, smaller values of a compress the wavelet, which helps when extracting the higher frequency components of the signal. The scaling and translation parameters make wavelet analysis ideal for non-stationary and non-linear signals in both the time and frequency domains. a and b are continuous and can take any value, but when both are discretized, we get the DWT (Discrete Wavelet Transform).
Although the DWT is an excellent method for time–frequency domain analysis of a signal, it only decomposes the low-frequency part and neglects the high-frequency part, resulting in poor resolution at high frequencies. Therefore, in our case the Wavelet Packet Decomposition (WPD), an extension of the DWT, is used for the time–frequency analysis of the signal. The difference from the DWT is that in WPD the detail signals (high-frequency part) are also further decomposed into two signals, a detail signal and an approximation signal (low-frequency part). Hence, using WPD, a full-scale analysis in the time–frequency domain can be done. The scaling function φ(t) and wavelet function ψ(t) of WPD are described by (8) and (9), respectively.
φ(t) = √2 Σ_k h(k) φ(2t − k)   (8)

ψ(t) = √2 Σ_k g(k) φ(2t − k)   (9)
where h(k) and g(k) are low- and high-pass filters, respectively.
The WPD from level j to j + 1 at node n of a signal s(j, n) is given by (10) and (11), where s(j, n) can be either approximation (low-frequency, cA_j) or detail (high-frequency, cD_j) coefficients.
s_{j+1,2n}(k) = Σ_m h(m − 2k) s_{j,n}(m)   (10)

s_{j+1,2n+1}(k) = Σ_m g(m − 2k) s_{j,n}(m)   (11)
With the help of WPD, the signal is decomposed into different frequency bands
[27]. When damage occurs in a bearing, the effect creates resonances at different frequencies, causing energy to be distributed in different frequency bands depending
on the bearing fault. The energy E(j, n) of the WPD coefficients s(j, n) at the jth level and nth node can be computed as shown in (12).

E(j, n) = Σ_k s_{j,n}(k)²   (12)
To perform WPD, a mother wavelet must be selected; the selection depends on the application, as no wavelet is the absolute best. The wavelet that works best is the one whose properties are most similar to the signal under analysis. Reference [28] demonstrates the training of bearing fault data with different wavelets and concludes that the mother wavelet that adequately captures the bearing fault signatures is the Daubechies 6 (db6) wavelet. The db6 wavelet also works best in our case and is therefore used for extracting the time–frequency domain features in this paper.
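Equations (10) through (12) can be illustrated with a small numpy sketch. For brevity, the two-tap Haar filter pair stands in for db6, so the filter coefficients below are an assumption for illustration, not the chapter's actual wavelet:

```python
import numpy as np

# Haar analysis filters (illustrative stand-in for db6).
h = np.array([1.0, 1.0]) / np.sqrt(2)   # low-pass
g = np.array([1.0, -1.0]) / np.sqrt(2)  # high-pass

def wpd_step(s, filt):
    # s_{j+1}(k) = sum_m filt(m - 2k) s_j(m), Eqs. (10)/(11), for a 2-tap filter.
    return np.array([filt[0] * s[2 * k] + filt[1] * s[2 * k + 1]
                     for k in range(len(s) // 2)])

def wpd(signal, level):
    """Full wavelet packet tree: every node is split at every level."""
    nodes = [np.asarray(signal, dtype=float)]
    for _ in range(level):
        nodes = [child for s in nodes
                 for child in (wpd_step(s, h), wpd_step(s, g))]
    return nodes

def packet_energy(s):
    # Eq. (12): sum of squared coefficients in one packet.
    return float(np.sum(s ** 2))

x = np.sin(np.arange(64) / 4.0)
nodes = wpd(x, level=3)                       # 8 packets at level 3
energies = [packet_energy(s) for s in nodes]  # per-band energies
```

Because the Haar filter pair is orthonormal, the packet energies at any level sum exactly to the energy of the original signal, which is a useful sanity check for any WPD implementation.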
In [25], a detailed feature selection method based on maximum separation distance is discussed. Using this method, the relevant features are selected, as not all the features are useful for fault classification. Figure 6 shows the third-level WPD of the signal x. From this decomposition, the cA3,0, cA3,1, cA3,2, and cD3,2 coefficients are used, and their energies are calculated. Table 3 shows the detailed list of the eight features that are selected and used in the diagnosis algorithm. These eight features from the filtered current signal are able to describe the bearing fault signatures quite well. From the datasets presented in Table 1, the following steps are considered for preparing the data for training and testing of the LSTM Network.
4 RMS of third-level WPD approximation coefficient 3,0: RMS = √((1/n) Σ_{i=1}^{n} cA3,0_i²)
5 Energy of third-level WPD approximation coefficient 3,0: E_A3,0 = Σ_{i=1}^{n1} cA3,0_i²
6 Energy of third-level WPD approximation coefficient 3,1: E_A3,1 = Σ_{i=1}^{n2} cA3,1_i²
7 Energy of third-level WPD approximation coefficient 3,2: E_A3,2 = Σ_{i=1}^{n3} cA3,2_i²
8 Energy of third-level WPD detail coefficient 3,2: E_D3,2 = Σ_{i=1}^{n4} cD3,2_i²
1. Step I: The stator current of phase 1, containing 4 s measurements, is taken (resulting in 1200 signals from all the datasets: 400 signals each for the healthy, inner race fault, and outer race fault classes). The signals are filtered to remove the ω0 and 5ω0 components. Then, the features presented in Table 3 are extracted.
2. Step II: All the features extracted in Step I are scaled between 0 and 1, for better
convergence of the algorithm.
3. Step III: The feature data of each class is shuffled within the class itself, so that
training and test data incorporate all the different operating points and conditions,
as each dataset has different fault severities.
4. Step IV: 20% of the data points (i.e., 80 points) from each class are kept for
testing, and 80% of the data points (i.e., 320 data points) from each class are
used for training. Therefore, in total, there are 960 data points in the training set
and 240 points in the testing set.
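Steps II through IV can be sketched in numpy on synthetic feature data; the class names, random features, and split sizes below mirror the description above but are otherwise illustrative:

```python
import numpy as np

# Synthetic stand-in for the extracted features: 8 features per signal,
# 400 signals per class (values are random, for illustration only).
rng = np.random.default_rng(0)
classes = {c: rng.normal(size=(400, 8)) for c in ("healthy", "inner", "outer")}

train_parts, test_parts = [], []
for name, feats in classes.items():
    # Step II: scale each feature to [0, 1].
    feats = (feats - feats.min(axis=0)) / (feats.max(axis=0) - feats.min(axis=0))
    # Step III: shuffle within the class so all operating points are mixed.
    rng.shuffle(feats)
    # Step IV: per-class 80/20 split (320 training / 80 testing points).
    train_parts.append(feats[:320])
    test_parts.append(feats[320:])

train_X = np.vstack(train_parts)  # 960 x 8 training matrix
test_X = np.vstack(test_parts)    # 240 x 8 testing matrix
```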
composed of three output nodes (corresponding to the three output classes: healthy, inner ring damage, and outer ring damage) is added. The binary cross-entropy function is used as the loss function, and the ADAM optimizer [29] is used to train the parameters of the LSTM network. In the code implementation, training is done on the computer's GPU to accelerate the training process.
After the features have been extracted from the datasets and preprocessed, the
feature vector is input to the LSTM Network for training. The parameters of the
training are displayed in Table 5.
Using these parameters, the network is trained for 2500 epochs with a batch size of 64. After training, the LSTM network was able to achieve an accuracy of 97% on the testing data and 100% accuracy on the training data. Figure 7 shows the confusion matrix of the testing results of the LSTM model, with rows representing the predicted class and columns the true class. False positives are displayed in the left column, and the upper row displays the false negatives. All of the normal or healthy bearing points
are correctly classified. However, some points of the inner race fault are misclassified
as outer race fault, and some points of the outer race are misclassified as inner race
fault. Nevertheless, the LSTM Network with the proposed methodology provides
excellent results in diagnosing the bearing faults and classifying healthy, inner race,
and outer race faults to a great degree of accuracy.
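The confusion-matrix layout described above (rows as predicted class, columns as true class, as in Fig. 7) can be reproduced with a few lines of numpy; the label vectors here are synthetic, chosen only to illustrate the inner/outer mix-up pattern:

```python
import numpy as np

def confusion_matrix(y_pred, y_true, n_classes=3):
    # Rows index the predicted class, columns the true class.
    cm = np.zeros((n_classes, n_classes), dtype=int)
    for p, t in zip(y_pred, y_true):
        cm[p, t] += 1
    return cm

# Classes: 0 = healthy, 1 = inner race fault, 2 = outer race fault.
y_true = [0, 0, 1, 1, 2, 2]
y_pred = [0, 0, 1, 2, 2, 1]   # one inner/outer confusion in each direction
cm = confusion_matrix(y_pred, y_true)
accuracy = np.trace(cm) / cm.sum()
```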
In the previous section, the LSTM network was used for the diagnosis of the bearing inner and outer race faults, and a testing accuracy of 97% was achieved, which is greater than the testing accuracy of 93.3% achieved by the ensemble of Machine Learning algorithms in [25]. Moreover, the Machine Learning methods used in their analysis were not exposed to bearing measurements of different severities, which contributed to the lower accuracy of their analysis. In conventional Machine Learning techniques, the
features are manually extracted from the data so that the data patterns are clearer to
the Machine Learning algorithm. In Deep Learning approaches, the algorithm is capable of automatically extracting high-level features from the data. This approach of feeding the data directly to the Deep Learning algorithm works in the majority of cases. However, in the case of bearing fault diagnosis, feeding the raw MCS to the proposed LSTM network failed to give good performance and resulted in very poor classification accuracy. The reason is that the fault characteristic magnitudes are comparable to or lower than the signal noise. As a result, the LSTM network is unable to extract these characteristic frequency features. Therefore, features are extracted manually from the MCS, and signal-processing techniques such as the wavelet transform are required for this task. As eight features are extracted from the filtered current
signal, the traditional Machine Learning methods can also be applied and compared with the proposed LSTM method. Table 6 shows the comparison of the traditional methods with the proposed method. In this testing, all the datasets were combined and shuffled, and five-fold cross-validation was done. Most of the Machine Learning
Table 6 Comparison of traditional methods with the proposed LSTM method using five-fold cross-validation

Algorithm                                                                         Classification accuracy (%)
Multilayer perceptron (layers 150, 100, 50 with ADAM solver and ReLU activation)  95.4
SVM (Support Vector Machine)                                                      66.1
k-nearest neighbor                                                                91.2
Linear regression                                                                 53.0
Linear discriminant analysis                                                      56.8
CART (Classification And Regression Trees)                                        93.6
Gaussian Naive Bayes                                                              46.0
LSTM (proposed method)                                                            97.0
methods showed poor performance; only the MLP (Multilayer Perceptron), kNN, and CART methods showed promising results. The reason for the better performance of the LSTM method is its memorizing ability: the algorithm learns the patterns in the eight features, which form a sequence. Hence, the proposed LSTM methodology outperforms all the other traditional methods, demonstrating the advantage of Deep Learning techniques over traditional methods.
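The five-fold cross-validation used for Table 6 amounts to partitioning the 1200 combined data points into five disjoint folds; the chapter does not name its implementation, so a generic numpy sketch is assumed:

```python
import numpy as np

def five_fold_indices(n, seed=0):
    """Shuffle n indices and split them into five disjoint folds."""
    idx = np.random.default_rng(seed).permutation(n)
    return np.array_split(idx, 5)

# 1200 points total (400 per class), giving five folds of 240 points each;
# each fold serves once as the test set while the rest train the model.
folds = five_fold_indices(1200)
```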
Although the proposed LSTM methodology outperformed the traditional Machine Learning methods, one disadvantage of Deep Learning algorithms is their high computational requirements. The proposed LSTM network model was trained on a GPU to accelerate training, whereas training the Machine Learning algorithms required only normal CPU processing and less time. Therefore, if sufficient computational resources are available, Deep Learning algorithms should be opted for.
The results from the LSTM model are quite promising; however, the dataset is not large enough, and the trained model may not achieve the same performance if it is exposed to data points from a new environment during testing. This is due to the randomness of the background and noise conditions of new data points. Hence, more research is needed in this area so that Deep Learning models are able to adapt to a new environment without sacrificing performance. A broader analysis with a larger dataset would therefore be helpful.
The paper focused on the diagnosis of bearing inner- and outer-race faults. Conven-
tional bearing fault diagnosis uses the vibration signal from the machine's
accelerometer. In this paper, however, the MCS is used instead for fault diagnosis,
demonstrating its potential for effective bearing fault diagnosis.
Diagnosis of Bearing Faults in Electrical Machines … 97
With this methodology, the need for an additional sensor for monitoring the vibration
signals is eliminated, reducing the cost of the fault monitoring system. The Paderborn
University real damaged bearing dataset was considered. From this dataset, the stator
current of one phase was used and processed by removing the redundant frequencies
ω0 and 5ω0. Then, eight features were extracted from this filtered signal: three
features from the time domain and five features from the time–frequency domain
using third-level WPD. These eight features were scaled and then fed to the Deep
Learning LSTM network with four hidden layers. The LSTM network was able to
show excellent results with a classification accuracy of 97%, even with the dataset
containing data at different operating conditions, i.e., different speed and load, which
proves that the method developed is independent of the machine operating condi-
tions. In the end, the proposed LSTM methodology was compared with the traditional
Machine Learning methods using five-fold cross-validation, and the proposed
methodology outperformed the traditional methods, achieving an accuracy more than
1.5% higher than that of the best-performing algorithm. Hence, it has been shown
that fault diagnosis from the MCS using the proposed LSTM algorithm gives similar,
if not better, performance than diagnosis with vibration signals.
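The eight-feature pipeline summarized above can be illustrated with a self-contained sketch. The Haar filter, and the choice of RMS, peak, and crest factor as the three time-domain features, are assumptions made for illustration; the paper's exact wavelet and feature definitions are not reproduced here.

```python
import numpy as np

def haar_wpd(x, level):
    # Full wavelet packet decomposition with the orthonormal Haar filter pair:
    # each band splits into a low-pass (sum) and a high-pass (difference) half.
    bands = [np.asarray(x, dtype=float)]
    for _ in range(level):
        nxt = []
        for b in bands:
            nxt.append((b[0::2] + b[1::2]) / np.sqrt(2.0))  # approximation
            nxt.append((b[0::2] - b[1::2]) / np.sqrt(2.0))  # detail
        bands = nxt
    return bands  # 2**level sub-bands

def feature_vector(x):
    # Three time-domain features (RMS, peak, crest factor are placeholders for
    # the paper's choices) plus five sub-band energies from a third-level WPD.
    x = np.asarray(x, dtype=float)
    rms = np.sqrt(np.mean(x ** 2))
    peak = np.max(np.abs(x))
    crest = peak / rms
    energies = [np.sum(b ** 2) for b in haar_wpd(x, 3)]
    return np.array([rms, peak, crest] + energies[:5])  # 8 features in total

sig = np.sin(2 * np.pi * 50 * np.arange(1024) / 1024.0)
print(feature_vector(sig).shape)  # (8,)
```

Because the Haar split is orthonormal, the energies of the eight level-3 sub-bands sum to the energy of the original signal, which is why sub-band energies are a common time–frequency feature choice.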
For future work, the diagnosis of other bearing faults, e.g., ball and cage faults,
using the stator current will be considered. Secondly, other Deep Learning methods
listed in [30], such as CNN-LSTM networks, algorithms that incorporate
the randomness of different working conditions, and algorithms that are able to
adapt to different operating environments without compromising the method's
performance, will be explored. Furthermore, more research will be done on increasing
the classification performance of the network by considering denoising LSTMs,
regularization methods, and a larger training dataset.
References
1. R. Sabir, S. Hartmann, C. Gühmann, Open and short circuit fault detection in alternators using
the rectified DC output voltage, in 2018 IEEE 4th Southern Power Electronics Conference
(SPEC) (Singapore, 2018), pp. 1–7
2. R. Sabir, D. Rosato, S. Hartmann, C. Gühmann, Detection and localization of electrical faults
in a three phase synchronous generator with rectifier, in 19th International Conference on
Electrical Drives & Power Electronics (EDPE 2019) (Slovakia, 2019)
3. Common causes of bearing failure | Applied (2019). https://www.applied.com/bearingfailure
4. I.Y. Onel, M.E.H. Benbouzid, Induction motors bearing failures detection and diagnosis: park
and concordia transform approaches comparative study, in 2007 IEEE International Electric
Machines & Drives Conference (Antalya, 2007), pp. 1073–1078
5. R.R. Schoen, T.G. Habetler, F. Kamran, R.G. Bartfield, Motor bearing damage detection using
stator current monitoring. IEEE Trans. Ind. Appl. 31(6), 1274–1279 (1995). https://doi.org/10.
1109/28.475697
6. H. Pan, X. He, S. Tang, F. Meng, An improved bearing fault diagnosis method using one-
dimensional CNN and LSTM. J. Mech. Eng. 64(7–8), 443–452 (2018)
7. X. Guo, C. Shen, L. Chen, Deep fault recognizer: an integrated model to denoise and extract
features for fault diagnosis in rotating machinery. Appl. Sci. 7(41), 1–17 (2017)
98 R. Sabir et al.
8. H. Shao, H. Jiang, Y. Lin, X. Li, A novel method for intelligent fault diagnosis of rolling
bearings using ensemble deep autoencoders. Knowl.-Based Syst. 119, 200–220 (2018)
9. D. Filbert, C. Guehmann, Fault diagnosis on bearings of electric motors by estimating the
current spectrum. IFAC Proc. 27(5), 689–694 (1994)
10. S. Yeolekar, G.N. Mulay, J.B. Helonde, Outer race bearing fault identification of induction
motor based on stator current signature by wavelet transform, in 2017 2nd IEEE Interna-
tional Conference on Recent Trends in Electronics, Information & Communication Technology
(RTEICT) (Bangalore, 2017), pp. 2011–2015
11. F. Ben Abid, A. Braham, Advanced signal processing techniques for bearing fault detection in
induction motors, in 2018 15th International Multi-Conference on Systems, Signals & Devices
(SSD) (Hammamet, 2018), pp. 882–887
12. A. Bellini, F. Immovilli, R. Rubini, C. Tassoni, Diagnosis of bearing faults of induction
machines by vibration or current signals: a critical comparison, in 2008 IEEE Industry
Applications Society Annual Meeting (Edmonton, AB, 2008), pp. 1–8
13. A. Soualhi, G. Clerc, H. Razik, Detection and diagnosis of faults in induction motor using
an improved artificial ant clustering technique. IEEE Trans. Ind. Electron. 60(9), 4053–4062
(2013)
14. S. Gunasekaran, S.E. Pandarakone, K. Asano, Y. Mizuno, H. Nakamura, Condition monitoring
and diagnosis of outer raceway bearing fault using support vector machine, in 2018 Condition
Monitoring and Diagnosis (CMD) (Perth, WA, 2018), pp. 1–6
15. I. Andrijauskas, R. Adaskevicius, SVM based bearing fault diagnosis in induction motors
using frequency spectrum features of stator current, in 2018 23rd International Conference on
Methods & Models in Automation & Robotics (MMAR) (Miedzyzdroje, 2018), pp. 826–831
16. S.E. Pandarakone, M. Masuko, Y. Mizuno, H. Nakamura, Deep neural network based bearing
fault diagnosis of induction motor using fast fourier transform analysis, in 2018 IEEE Energy
Conversion Congress and Exposition (ECCE) (Portland, OR, 2018), pp. 3214–3221
17. J.S. Lal Senanayaka, H. Van Khang, K.G. Robbersmyr, Autoencoders and data fusion based
hybrid health indicator for detecting bearing and stator winding faults in electric motors, in
2018 21st International Conference on Electrical Machines and Systems (ICEMS) (Jeju, 2018),
pp. 531–536
18. I. Kao, W. Wang, Y. Lai, J. Perng, Analysis of permanent magnet synchronous motor fault
diagnosis based on learning. IEEE Trans. Instrum. Meas. 68(2), 310–324 (2019)
19. A beginner’s guide to LSTMs and recurrent neural networks (Skymind, 2019). https://skymind.ai/wiki/lstm
20. S. Zhang, S. Zhang, B. Wang, T.G. Habetler, Machine learning and deep learning algorithms
for bearing fault diagnostics-a comprehensive review (2019). arXiv preprint arXiv:1901.08247
21. Z.C. Lipton, J. Berkowitz, C. Elkan, A critical review of recurrent neural networks for sequence
learning (2015). arXiv preprint arXiv:1506.00019
22. S. Hochreiter, J. Schmidhuber, Long short-term memory. Neural Comput. 9(8), 1735–1780 (1997)
23. F. Immovilli, A. Bellini, R. Rubini et al., Diagnosis of bearing faults of induction machines by
vibration or current signals: a critical comparison. IEEE Trans. Ind. Appl. 46(4), 1350–1359
(2010)
24. Konstruktions- und Antriebstechnik (KAt)—Data Sets and Download (Universität Paderborn),
Mb.uni-paderborn.de (2019). https://mb.uni-paderborn.de/kat/forschung/datacenter/bearing-datacenter/data-sets-and-download/
25. C. Lessmeier, J.K. Kimotho, D. Zimmer, W. Sextro, Condition monitoring of bearing damage
in electromechanical drive systems by using motor current signals of electric motors: a bench-
mark data set for data-driven classification, in Proceedings of the European Conference of the
Prognostics and Health Management Society (2016), pp. 05–08
26. Z. Huo, Y. Zhang, P. Francq, L. Shu, J. Huang, Incipient fault diagnosis of roller bearing
using optimized wavelet transform based multi-speed vibration signatures. IEEE Access 5,
19442–19456 (2017)
27. X. Wang, Z. Lu, J. Wei, Y. Zhang, Fault diagnosis for rail vehicle axle-box bearings based on
energy feature reconstruction and composite multiscale permutation entropy. Entropy 21(9),
865 (2019)
28. S. Djaballah, K. Meftah, K. Khelil, M. Tedjini, L. Sedira, Detection and diagnosis of fault
bearing using wavelet packet transform and neural network. Frattura ed Integrità Strutturale
13(49), 291–301 (2019)
29. D.P. Kingma, J. Ba, Adam: a method for stochastic optimization (2014). arXiv preprint arXiv:
1412.6980
30. M.A. Wani, F.A. Bhat, S. Afzal, A.L. Khan, Advances in Deep Learning (Springer, 2020)
Automatic Solar Panel Detection
from High-Resolution Orthoimagery
Using Deep Learning Segmentation
Networks
Abstract Solar panel detection from aerial or satellite imagery is a very convenient
and economical technique for counting the number of solar panels on the rooftops
in a region or city and also for estimating the solar potential of the installed solar
panels. Detection of accurate shapes and sizes of solar panels is a prerequisite for
successful capacity and energy generation estimation from solar panels over a region
or a city. Such an approach is helpful for the government to build policies to inte-
grate solar panels installed at home, offices, and buildings with the electric grids. This
study explores the use of various deep learning segmentation algorithms for auto-
matic solar panel detection from high-resolution ortho-rectified RGB imagery with
resolution of 0.3 m. We compare and evaluate the performance of six deep learning
segmentation networks in automatic detection of the distributed solar panel arrays
from satellite imagery. The networks are tested on real data and augmented data.
Results indicate that deep learning segmentation networks work well for automatic
solar panel detection from high-resolution orthoimagery.
1 Introduction
The world is exploring ways to utilize more renewable energy sources as non-renewable
energy sources are being depleted. One of the most important and abundantly avail-
able renewable energy sources is solar energy. To utilize solar energy, solar panels
are installed on the ground and on the rooftops of buildings to convert solar energy
into electric energy. Throughout the world, there has been a surge in installing solar
panels to get the most out of this form of energy source. One of the challenges is to
detect solar panels installed on the ground and on buildings from aerial imagery.
Accurate detection
© The Editor(s) (if applicable) and The Author(s), under exclusive license
to Springer Nature Singapore Pte Ltd. 2021
M. A. Wani et al. (eds.), Deep Learning Applications, Volume 2,
Advances in Intelligent Systems and Computing 1232,
https://doi.org/10.1007/978-981-15-6759-9_5
102 T. Mujtaba and M. A. Wani
of solar panels with accurate shapes and sizes is a prerequisite for determining the
capacity of energy generation from these panels. For this problem, semantic segmen-
tation is required to accurately detect the solar panels from the aerial imagery. For the
past several years, machine learning approaches have been used for many applica-
tions like classification, clustering, and image recognition [1–7], and some research
has been done to analyze and assess objects like roads, buildings, and vehicles
present in satellite imagery.
In recent past, deep learning has been found to be more effective in image recog-
nition problems than the traditional machine learning techniques [8]. The image
recognition problems include tasks like image classification, object detection, and
semantic segmentation. Deep learning has outperformed the traditional machine
learning techniques in these image recognition tasks. Among these image recog-
nition tasks, the problem of semantic segmentation is one of the key research areas.
Semantic segmentation is the process of classifying each pixel of the image to a set
of predefined classes, therefore dividing an image into set of regions where each
region consists of pixels belonging to one common class. A number of traditional
machine learning techniques have been used in segmentation of images in the past.
The general procedure of traditional machine learning techniques is to use well-
established feature descriptors to extract features from an image and use a classifier
like support vector machines or random forest classifier on each pixel to determine
its likelihood of belonging to one of the predefined classes. Such techniques heavily
depend on the procedure that is used to extract features and usually need skilled
feature engineering for designing such procedures. It is necessary to determine which
features are important for an image. When the number of classes increases, feature
extraction becomes a tedious task. Moreover, such techniques are not end-to-end:
feature extraction and classification are done separately. In
comparison, deep learning models are end-to-end models where both feature
extraction and classification are done through a single algorithm. Further, model
parameters are learnt automatically during the process of learning. Because of the
above reasons, this work has used the deep learning approach for the problem of
solar panel detection.
The deep learning segmentation process evolved from converting a deep learning image
classification model into a segmentation model [9]. The segmen-
tation process has been improved with the introduction of UNet [10] model. The
model essentially consists of an encoder and a decoder with skip connections from
lower layers of encoder to higher layers of decoder to transfer the learned features.
The purpose of encoder is to extract features that are up-sampled in the decoder and
features in the decoder are concatenated with the features extracted in the encoder.
Up-sampling is done by using transposed convolutions. An index-based up-sampling
technique has been introduced in [11]. The technique remembers the index of pixel
values with maximum intensity during max-pooling operation and then uses the same
indexes during up-sampling process. Dilated convolution that uses dilated filters has
been tested in [12–14]. Dilated convolution helps in reduction of learnable parame-
ters and it also helps in aggregating the context of the image. The aim of this work is
to explore deep learning segmentation networks for automatic solar panel detection
from high-resolution orthoimagery.
2 Related Work
With the availability of high-resolution satellite images and the emergence of deep
learning networks, it has become possible to analyze the high-resolution images for
accurate object detection. Detection of the objects like buildings, roads, and vehicles
from aerial/satellite images has been studied more widely during the past many years
than the detection of solar panels. A few studies [15–17] have been reported in the
literature that explore detection of solar panels from satellite or orthoimagery using
deep learning techniques. VggNet with pretraining has been used in [16] to detect
the solar panels from the aerial imagery dataset given in [18]. The authors have used
a very basic VggNet architecture consisting of six convolutional layers and two fully
connected layers. The model is applied on every pixel to detect whether it belongs to
solar panel or not. Further processing is performed to connect contiguous pixels and
declare these as regions. Though this model has been found effective in detecting
solar panels, it does not give the shape and size of panel arrays explicitly.
Authors in [17] proposed a fully convolutional neural network for solar panel
detection. It uses seven convolutional layers with different number of filters and filter
sizes. It first extracts features from the image by using several convolution and max-
pooling layers; it then up-samples the extracted features and uses skip connections
from the shallow layers, which contain fine-grained features, and concatenates them
with coarse layers. Although the work has been able to detect solar panels, the
standard metric used for segmentation is missing.
A deep learning segmentation architecture called SegNet has been used in [15]
for automatic detection of solar panels from ortho-rectified images given in [18].
Again it has used very small image patches of size 41 × 41 for training purposes.
These small patches may not contain much of the diverse background like buildings,
roads, trees, and vehicles, which may restrict the algorithm’s capability to perform
in diverse backgrounds. Further, each image in the dataset, which is of size 5000 ×
5000, would generate approximately 14,872 image patches of size 41 × 41 on which
training and testing are to be done. Very little work has been done on solar panel
detection using deep learning-based segmentation.
The main purpose of the encoder is to extract features from the images through succes-
sive convolution and pooling layers. The encoding part usually comprises networks
like VggNet, ResNet, and DenseNet with their fully connected layers removed.
a. VggNet
VggNet [19] has various variants; the prominent one is VggNet-16, which consists of
13 convolution layers, 5 max-pooling layers, and 3 fully connected layers. VggNet-19
is also prominent and consists of 16 convolutional layers, 5 max-pooling layers,
and 3 fully connected layers. VggNet uses small 3 × 3 convolution filters, in contrast
to earlier networks such as AlexNet, which use larger filters (e.g., 7 × 7 or 11 × 11).
Within a given receptive field, using a stack of 3 × 3 filters is better than using one
larger 7 × 7 filter, as it involves fewer parameters and reduces computational effort.
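The parameter saving can be checked with a simple count (weight tensors only, biases ignored): three stacked 3 × 3 layers cover the same 7 × 7 receptive field with 27C² weights instead of 49C².

```python
# Weights of a k x k conv layer mapping c_in channels to c_out channels.
def conv_params(k, c_in, c_out):
    return k * k * c_in * c_out

C = 64
stack_3x3 = 3 * conv_params(3, C, C)   # three stacked 3x3 convs: 27 * C^2
single_7x7 = conv_params(7, C, C)      # one 7x7 conv: 49 * C^2
print(stack_3x3, single_7x7)  # 110592 200704
```

The stacked design also interleaves three non-linearities instead of one, which is a second advantage noted in the VGG design.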
b. ResNet
ResNet is built from residual blocks whose skip connections add a layer's input to its
output, which eases the training of very deep networks.
c. DenseNet
The main characteristic feature of DenseNet [21] is that its every layer is connected to
every other layer in a feedforward manner resulting in (L(L + 1))/2 direct connec-
tions where L is the number of layers in the network. For each and every layer,
the features maps of all preceding layers are used as input and its own produced
feature maps are used in all subsequent layers. DenseNet alleviates vanishing gradient
problem, strengthens feature propagation, and encourages feature reuse and reduces
the number of parameters. The network consists of various dense blocks where a
layer is connected to every subsequent layer. Each layer in a dense block consists
of batch normalization, ReLU activation and 3 × 3 convolution. The layers between
two dense blocks are transitions layers consisting of batch normalization, 1 × 1
convolution and average pooling. Such type of network can be easily adopted for
semantic segmentation.
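The quoted connection count can be verified directly:

```python
# Direct connections in a dense block with L layers: every layer receives the
# feature maps of all preceding layers (counting the block input as a source),
# giving L(L + 1) / 2 links.
def dense_connections(L):
    return L * (L + 1) // 2

print([dense_connections(L) for L in (1, 2, 3, 4)])  # [1, 3, 6, 10]
```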
a. UNet-based Decoder
The UNet-based decoder [10] uses transpose convolution for up-sampling followed
by convolutional layers and ReLU activation. UNet extends the concept of skip
connection used in the FCN decoder. Here every encoder is connected to its corre-
sponding decoder unit through a skip connection. The features learnt in the encoder
block are carried over and concatenated with the features of the decoder block.
b. Max-pooling index-based Decoder
Max-pooling index-based decoder [11] uses the indices stored during the max-
pooling process of the encoder to up-sample the feature maps in the corresponding
decoder during the decoding process. Unlike the FCN and UNet decoders, no features
are carried to the decoders through skip connections. The up-sampled features are
then convolved with trainable filters.
c. Un-pooling-based Decoder
The various up-sampling techniques that have been used to increase the resolution
of the feature maps in the decoder part are summarized below
a. Nearest Neighbor
Nearest neighbor up-sampling simply copies each input pixel's value into its
neighboring output pixels.
b. Bed of Nail
Bed of Nails puts value of a pixel in a particular/fixed position in the output and
the rest of the positions are filled with value 0.
c. Bilinear up-sampling
It calculates a pixel value by interpolating the values of the nearest known
pixels; unlike the nearest neighbor technique, the contribution of each nearby
pixel matters here and is inversely proportional to its distance.
d. Max Un-pooling
It remembers the index of the maximum activation during the max-pooling oper-
ation and uses the same index to position the pixel value in the output during the
up-sampling.
e. Transpose Convolution
Transpose convolution is the most commonly used technique for image up-sampling
in deep learning semantic segmentation because it is a learnable up-sampling:
the filters are learned during training. The input is padded with zeros when the
convolution is applied.
f. Dilated Convolutions
Dilated convolution, also known as atrous convolution, was first developed for the
efficient computation of the undecimated wavelet transform. A dilated convolution
is a normal convolution whose kernel elements are spaced apart by a dilation rate,
which widens the receptive field and captures more context of the image without
increasing the number of parameters. A normal convolution is a dilated convolution
with a dilation rate equal to 1.
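A few of the techniques above (nearest neighbor, bed of nails, and max un-pooling) can be sketched in NumPy for a 2 × 2 input up-sampled to 4 × 4; the stored argmax positions used for un-pooling are made up for this example.

```python
import numpy as np

x = np.array([[1., 2.],
              [3., 4.]])

# a. Nearest neighbor: copy each pixel into a 2x2 block of the output.
nearest = np.repeat(np.repeat(x, 2, axis=0), 2, axis=1)

# b. Bed of nails: place each value at a fixed corner of its 2x2 block,
# fill the rest with zeros.
nails = np.zeros((4, 4))
nails[::2, ::2] = x

# d. Max un-pooling: like bed of nails, but each value goes to the position
# remembered from the earlier max-pooling step (indices assumed here).
unpooled = np.zeros((4, 4))
rows, cols = np.array([0, 1, 3, 2]), np.array([1, 2, 0, 3])  # stored argmax positions
unpooled[rows, cols] = x.ravel()

print(nearest[0])  # [1. 1. 2. 2.]
```

Bilinear interpolation and transpose convolution differ only in how the empty output positions are filled: by distance-weighted averages in the former, and by learned filter weights in the latter.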
4.1 UNet
4.2 SegNet
The segmentation architecture SegNet [11], which was designed for scene understanding
applications and is efficient in terms of memory and computational time, is explored
here for automatic detection of solar panels from satellite images. The SegNet
architecture consists of
an encoder and decoder like UNet but differs in how up-sampling is done in the
decoder part. The deconvolutional layers used for up-sampling in decoder part of
UNet are time and memory consuming because up-sampling is performed using a
learnable model, which implies filters for up-sampling are learned during the training
process. The SegNet architecture replaces the learnable up-sampling by computing
and memorizing the max-pool indices and later uses these indices to up-sample the
features in the corresponding decoder block to produce sparse feature maps. It then
uses normal convolution with trainable filters to densify these sparse feature maps.
However, there are no skip connections for feature transfer as in UNet. The use of
max-pool indices results in a reduced number of model parameters, which eventually
reduces training time. The architecture of SegNet used in this study is given in
Fig. 3. The max-pooling index concept is illustrated in Fig. 4 which also distinguishes
between the process of up-sampling used in UNet and SegNet architectures.
4.4 PSPNet
Fig. 6 Dilated convolution in 2D with different dilation rates. a Dilation rate = 1. b Dilation rate
= 2, c dilation rate = 3
objects. The pyramid pooling module uses four different pooling operations: 1 × 1,
2 × 2, 3 × 3, and 6 × 6. The pooling operations generate feature maps of different
sub-regions and form pooled representation for different locations. The output feature
maps from these pooling operations are of varied sizes. It then uses bilinear inter-
polation to up-sample these features to the input resolution. The number of
pyramid levels and the size of each level can be modified. The architecture of PSPNet
is shown in Fig. 7.
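A minimal sketch of the pooling module follows, assuming a square feature map whose side is divisible by every bin size (as with the 1, 2, 3, and 6 levels on a 6 × 6 map) and using nearest-neighbor instead of bilinear up-sampling for brevity:

```python
import numpy as np

def pyramid_pool(fmap, bins=(1, 2, 3, 6)):
    # Average-pool a square feature map into n x n sub-regions for each pyramid
    # level, then up-sample each pooled map back to the input size
    # (nearest neighbor here; PSPNet uses bilinear interpolation).
    h = fmap.shape[0]
    levels = []
    for n in bins:
        s = h // n                                   # assumes n divides h
        pooled = fmap.reshape(n, s, n, s).mean(axis=(1, 3))
        levels.append(np.repeat(np.repeat(pooled, s, axis=0), s, axis=1))
    return levels  # concatenated with fmap itself in the real network

fmap = np.arange(36, dtype=float).reshape(6, 6)
levels = pyramid_pool(fmap)
print([lv.shape for lv in levels])  # four maps, each (6, 6)
```

The 1 × 1 level carries the global average of the map, while the finer levels keep progressively more local context, which is what lets the module pool representations for different locations.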
4.5 DeepLab v3+
DeepLab v3+ [12], which makes use of an encoder–decoder structure for dense semantic
segmentation, is explored for automatic detection of solar panels in satellite images.
The encoder–decoder structure has two advantages: (i) it is capable of encoding
multi-scale contextual information by probing the incoming features with filters
or pooling operations at multiple rates and multiple effective fields-of-view, (ii) it
can capture sharper object boundaries by gradually recovering the spatial informa-
tion through the use of skip connections. It has an additional simple and effective
decoder module to refine the segmentation results especially along object boundaries.
It further applies depth-wise separable convolution to both Atrous Spatial Pyramid
Pooling and decoder modules, resulting in a faster and stronger encoder–decoder
network. The detailed encoder and decoder structure is given in Fig. 8.
4.6 Dilated ResNet
The Dilated Residual Network [14], which uses dilated convolutions in a Residual
Network for classification and segmentation tasks, is explored for automatic detection of solar
panels in satellite images. In convolutional networks, the spatial size of feature maps
gets continuously reduced due to multiple use of pooling and striding operations.
Such a loss in spatial structure limits the model’s ability to produce good results in
5.1 Dataset
The purpose of this study is to detect the location of solar panels in satellite images
of buildings and earth’s surface using deep learning segmentation techniques. The
training step of a segmentation process requires a dataset containing both the images
as well as their corresponding masks which are used as ground truth. The masks
highlight the pixels which correspond to the solar panels. As the dataset containing
both the images and corresponding masks is not publicly available, it was decided to
use the dataset described in [18] and prepare masks for this dataset before training
the models. The dataset has 601 TIF orthoimages of four cities of California and
contains geospatial coordinates and vertices of about 19,000 solar panels spread
across all images. This work has used images of Fresno city. Each image is 5000-
by-5000 pixels and covers an area of 2.25 km2 . The images are of urban, suburban,
and rural landscape type, allowing the model to get trained on diverse images. The
vertices of the solar panels associated with each image have been utilized to create
polygon areas of white pixels corresponding to solar panels and setting the remaining
pixels as black pixels to represent background. Figure 10 shows a sample image of
size 5000-by-5000 pixels and its sub-image of size 224-by-224 pixels.
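How a 5000-by-5000 image could be tiled into network-sized crops can be sketched as follows. Non-overlapping tiles and discarding the partial border are assumptions made here; the paper does not state its exact cropping scheme.

```python
import numpy as np

def tile_image(img, tile=224):
    # Split a large square image into non-overlapping tile x tile crops,
    # discarding the partial border (5000 // 224 = 22 tiles per side).
    n = img.shape[0] // tile
    return [img[r * tile:(r + 1) * tile, c * tile:(c + 1) * tile]
            for r in range(n) for c in range(n)]

big = np.zeros((5000, 5000), dtype=np.uint8)  # stand-in for one orthoimage
crops = tile_image(big)
print(len(crops), crops[0].shape)  # 484 crops of 224 x 224
```

The corresponding mask is tiled with the same indices so each crop keeps its aligned ground truth.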
To make the models robust, data augmentation has been performed to reduce over-
fitting and improve generalization. The augmentation was achieved by performing
horizontal flip and vertical flip on images. These two augmentations proved useful
in training the models and increasing the segmentation accuracy.
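The two flip augmentations are straightforward to apply; the key detail is that the image and its mask must be flipped together so the pixel labels stay aligned:

```python
import numpy as np

def augment(image, mask):
    # Horizontal and vertical flips applied jointly to image and mask.
    pairs = [(image, mask)]
    pairs.append((np.fliplr(image), np.fliplr(mask)))  # horizontal flip
    pairs.append((np.flipud(image), np.flipud(mask)))  # vertical flip
    return pairs  # the original plus two augmented copies

img = np.arange(16).reshape(4, 4)
msk = (img > 7).astype(np.uint8)
pairs = augment(img, msk)
print(len(pairs))  # 3
```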
Fig. 10 First row shows an image and its mask of size 5000 × 5000. Second row shows an image
with its mask of size 224 × 224
Automatic Solar Panel Detection from High-Resolution … 115
5.2 Training
All the architectures have been implemented in Python and trained on a workstation
with an Nvidia Tesla K40 (12 GB) GPU and 30 GB RAM. As an image size of 5000-by-
5000 is too large for training and needs a high GPU and RAM configuration, the
images and their corresponding masks have been cropped to size 224 × 224. A total
of 1118 cropped images were selected for augmentation and training. The testing
was done on image crops of the same size. The Adam learning algorithm with a fixed
learning rate of 1 × 10−5 for ResNet-based models and 1 × 10−4 for VggNet-based
models has been used. Training was done from scratch without any pretraining or
transfer learning. An early stopping criterion with a patience of 15 epochs has been
used to stop model training.
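The early stopping rule can be sketched as a simple patience counter over validation losses; this is a generic illustration, not the authors' training loop:

```python
def train_with_early_stopping(val_losses, patience=15):
    # Stop when validation loss has not improved for `patience` epochs;
    # return the (0-based) epoch index at which training stops.
    best, best_epoch = float("inf"), 0
    for epoch, loss in enumerate(val_losses):
        if loss < best:
            best, best_epoch = loss, epoch
        elif epoch - best_epoch >= patience:
            return epoch
    return len(val_losses) - 1

# Loss improves for 5 epochs, then plateaus: training stops 15 epochs later.
losses = [1.0, 0.8, 0.6, 0.5, 0.4] + [0.45] * 30
print(train_with_early_stopping(losses))  # 19
```

In practice the model weights from the best epoch, rather than the last one, are kept for evaluation.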
The performance metric used in this study is the dice coefficient (F1 score), which
is one of the most widely used metrics in segmentation. This metric quantifies
how closely the ground truth annotated segmentation region matches the predicted
segmentation region of the model. It is defined as twice the size of the intersection
(overlap) of the two regions divided by the sum of the sizes of the two regions.
Given two sets of pixels denoted by X and Y, the dice coefficient is given by:
DC = 2|X ∩ Y| / (|X| + |Y|) (1)
The value of dice coefficient ranges from 0 to 1. Value close to 1 means more
overlap and similarity between the two regions, hence more accurate the predicted
segmentation from the model.
The loss function used in this study is the dice loss (DL), which originates from the
dice coefficient and was used in [24]. It is defined as
DL = 1 − DC (2)
where DL is the dice loss and DC is the dice coefficient defined in (1). The dice
loss is minimized during training, which optimizes the value of DC.
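Both quantities from Eqs. (1) and (2) can be computed on binary masks as follows; a small epsilon is added so that two empty masks score 1 instead of dividing by zero:

```python
import numpy as np

def dice_coefficient(pred, truth, eps=1e-7):
    # DC = 2|X ∩ Y| / (|X| + |Y|) on binary masks.
    pred, truth = pred.astype(bool), truth.astype(bool)
    inter = np.logical_and(pred, truth).sum()
    return (2.0 * inter + eps) / (pred.sum() + truth.sum() + eps)

def dice_loss(pred, truth):
    # DL = 1 - DC, the quantity minimized during training.
    return 1.0 - dice_coefficient(pred, truth)

a = np.array([[1, 1], [0, 0]])
b = np.array([[1, 0], [0, 0]])
print(round(dice_coefficient(a, b), 3))  # 2*1 / (2 + 1) = 0.667
```

Unlike per-pixel accuracy, the dice score ignores the (dominant) background class, which makes it a better fit for sparse targets such as solar panels.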
The experimental results have been obtained by using augmented as well as original
datasets. The dice coefficients and loss results of training, validation, and testing of
augmented and original datasets are reported here.
Table 1 shows dice coefficient and loss results of training, validation, and testing
augmented and original datasets on UNet model. It can be seen from Table 1 that
augmented dataset has helped in improving training, validation, and testing accuracy
results. Figure 11 shows the dice coefficient bar graphs of training, validation, and
testing augmented and original datasets on UNet model. As can be observed from
Fig. 11, the data augmentation has produced better DC values for training, validation,
and testing of datasets on UNet model.
Table 2 shows dice coefficient and loss results of training, validation, and testing
augmented and original datasets on SegNet model. It can be seen from Table 2 that
augmented dataset has helped in improving training, validation, and testing accuracy
results. The training, validation, and testing dice coefficients for the augmented
dataset on the SegNet model improve by a margin of about 9–10% when compared
with the results of the original dataset. Figure 12 shows the dice
coefficient bar graphs of training, validation, and testing of the augmented and original
datasets on SegNet model. As can be observed from Fig. 12, the data augmentation
Fig. 11 Dice Coefficients of augmented and original datasets on UNet model. a Results of training
process. b Results of validation process. c Results of testing process
Fig. 12 Dice Coefficients of augmented and original datasets on SegNet model. a Results of training
process. b Results of validation process. c Results of testing process
has produced better DC values for training, validation, and testing of datasets on
SegNet model.
Table 3 shows dice coefficient and loss results of training, validation, and testing
augmented and original datasets on Dilated Net model. It can be seen from Table 3
that augmented dataset has helped in improving training, validation, and testing
accuracy results. The training, validation, and testing dice coefficients for the
augmented dataset on the Dilated Net model improve by a margin of about 1–2% when
compared with the results of the original dataset. Figure 13
shows the dice coefficient bar graphs of training, validation, and testing of the augmented
and original datasets on Dilated Net model. As can be observed from Fig. 13, the
Fig. 13 Dice coefficients of augmented and original datasets on dilated net model. a Results of
training process. b Results of validation process. c Results of testing process
data augmentation has produced better DC values for training, validation, and testing
of datasets on Dilated Net model.
Table 4 shows dice coefficient and loss results of training, validation, and testing
augmented and original datasets on PSPNet model. It can be seen from Table 4 that
augmented dataset has helped in improving validation and testing accuracy results.
The validation and testing dice coefficients for the augmented dataset on the
PSPNet model improve by a margin of about 9–15% when compared with the results
of the original dataset. However, the training dice coefficients for the
augmented dataset decrease, which implies that more epochs are
required to train the PSPNet with larger datasets. Figure 14 shows the dice
coefficient bar graphs of training, validation, and testing of the augmented and original
datasets on PSPNet model. As can be observed from Fig. 14, the data augmentation
has produced better DC values for validation, and testing of datasets on PSPNet
model.
Fig. 14 Dice coefficients of augmented and original datasets on PSPNet model. a Results of training
process. b Results of validation process. c Results of testing process
Table 5 shows the dice coefficient and loss results of training, validation, and testing the augmented and original datasets on the DeepLab v3+ model. It can be seen from Table 5 that the augmented dataset helped improve validation and testing accuracy. The validation and testing dice coefficients for the augmented dataset on the DeepLab v3+ model improve by about 7–9% over those of the original dataset. However, the training dice coefficient for the augmented dataset decreases, which implies that more epochs are required to train DeepLab v3+ on the larger dataset. Figure 15 shows bar graphs of the training, validation, and testing dice coefficients for the augmented and original datasets on the DeepLab v3+ model; as can be observed from Fig. 15, data augmentation produced better DC values for validation and testing on the DeepLab v3+ model.
Table 6 shows the dice coefficient and loss results of training, validation, and testing the augmented and original datasets on the Dilated ResNet model. It can be seen from Table 6 that the augmented dataset helped improve validation and testing accuracy. The validation and testing dice coefficients for the augmented dataset on the Dilated ResNet model improve by about 5–9% over those of the original dataset. However, the training dice coefficient for the augmented dataset decreases, which implies that more epochs are required to train the Dilated ResNet on the larger dataset. Figure 16 shows bar graphs of the training, validation, and testing dice coefficients for the augmented and original datasets on the Dilated ResNet model; as can be observed from Fig. 16, data augmentation produced better DC values for validation and testing on the Dilated ResNet model.
Fig. 15 Dice coefficients of augmented and original datasets on DeepLab v3+ model. a Results of training process. b Results of validation process. c Results of testing process
Fig. 16 Dice coefficients of augmented and original datasets on Dilated ResNet model. a Results of training process. b Results of validation process. c Results of testing process
The dice coefficient results for testing the augmented and original datasets on all six models are summarized as bar graphs in Fig. 17. The bar graphs indicate that the UNet model produces the best DC value, implying that UNet gives the best segmentation accuracy, followed by the SegNet and Dilated Net models.
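The dice coefficient (DC) reported throughout these comparisons can be computed directly from a predicted and a ground-truth segmentation mask. A minimal sketch in Python/NumPy (the function name and the epsilon smoothing term are our own, not from the chapter):

```python
import numpy as np

def dice_coefficient(pred, target, eps=1e-7):
    # DC = 2|A ∩ B| / (|A| + |B|) for binary masks A and B;
    # eps avoids division by zero when both masks are empty.
    pred = pred.astype(bool)
    target = target.astype(bool)
    intersection = np.logical_and(pred, target).sum()
    return (2.0 * intersection + eps) / (pred.sum() + target.sum() + eps)
```

For the solar panel task, `pred` and `target` would be the binary model output and label mask at 224 × 224 resolution.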
Fig. 17 DC values of testing augmented and original datasets on all six models
Automatic Solar Panel Detection from High-Resolution … 121
6 Conclusion
This work described the automatic detection of solar panels from satellite imagery using deep learning segmentation models. The study discussed various state-of-the-art deep learning segmentation architectures and the encoding, decoding, and up-sampling techniques used in the deep learning segmentation process. The six architectures used for automatic detection of solar panels were UNet, SegNet, Dilated Net, PSPNet, DeepLab v3+, and Dilated Residual Net. The dataset comprised satellite images of four cities in California, and an image size of 224 × 224 was used to train the models. The results showed that the UNet architecture, which uses skip connections between its encoder and decoder modules, produced the best segmentation accuracy. Moreover, dataset augmentation helped to further improve segmentation accuracy.
Training Deep Learning Sequence
Models to Understand Driver Behavior
Abstract Driver distraction is one of the leading causes of fatal car accidents in the
U.S. Analyzing driver behavior using machine learning and deep learning models
is an emerging solution to detect abnormal behavior and alarm the driver. Models
with memory such as LSTM networks outperform memoryless models in car safety
applications since driving is a continuous task and considering information in the
sequence of driving data can increase the model’s performance. In this work, we used
time-sequenced driving data that we collected in eight driving contexts to measure driver distraction. Our model is also capable of detecting the type of behavior that
caused distraction. We used the driver interaction with the car infotainment system as
the distracting activity. A multilayer perceptron (MLP) network was used as the baseline, and two types of LSTM networks, an LSTM model with an attention network and an encoder–decoder model with attention, were built and trained to analyze the effect of memory and attention on the computational expense and performance of the model. We compare the performance of these two complex networks to that of the MLP in estimating driver behavior. We show that our encoder–decoder with attention model outperforms the LSTM attention model, while both LSTM networks with attention improved on the training process of the MLP network.
© The Editor(s) (if applicable) and The Author(s), under exclusive license 123
to Springer Nature Singapore Pte Ltd. 2021
M. A. Wani et al. (eds.), Deep Learning Applications, Volume 2,
Advances in Intelligent Systems and Computing 1232,
https://doi.org/10.1007/978-981-15-6759-9_6
124 S. M. Kouchak and A. Gaffar
1 Introduction
Driver distraction is one of the leading causes of fatal car accidents in the U.S.,
and, while it is becoming an epidemic, it is preventable [1]. According to the NHTSA report, 3,166 people died in fatal car accidents caused by distracted drivers in 2017, constituting 8.5% of all traffic fatalities [2, 3]. Distracted driving is any task
that diverts the driver’s attention from the primary task of driving such as texting,
reading, talking to other passengers, and using the car infotainment system [4, 5].
Texting is the most distracting task, since the driver needs to take their eyes off the road for an estimated five seconds to read or send a text. At 45 mph, the driver covers a distance equal to a football field's length in that time without seeing the road [4].
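The distance claim above is easy to verify with a little unit arithmetic (a quick back-of-the-envelope check, not code from the chapter):

```python
# Distance covered during 5 s of not watching the road at 45 mph.
# 1 mile = 5280 ft, 1 h = 3600 s.
speed_ft_per_s = 45 * 5280 / 3600   # = 66 ft/s
distance_ft = speed_ft_per_s * 5    # = 330 ft
```

330 ft is roughly the length of a football field (300 ft between the goal lines, 360 ft including the end zones).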
Driver distraction is considered a failure of driving task prioritization: the driver puts more attention on secondary tasks, such as reading a text or tuning the radio, while ignoring the primary task of driving. There are four types of distraction:
1. Visual distraction: any task that makes the driver take their eyes off the road.
2. Manual distraction: when the driver takes their hands off the steering wheel [6].
3. Cognitive distraction: any task that diverts the driver's attention from the primary task of driving [7].
4. Audio distraction: any noise that obscures important sounds inside the car, such as alarms, or outside, such as ambulance sirens [8].
Designing a car-friendly user interface [9], advanced driver assistance systems (ADAS) [10–12], and autonomous vehicles (AV) [13] are some approaches to solving the problem. Driver assistance systems such as lane keeping and adaptive cruise control reduce human error by automating difficult or repetitive tasks. Although they can enhance driver safety, they have limitations, as they often work only in predefined situations. For instance, collision mitigation systems are based on algorithms defined by system developers to cover specific situations that could lead to a collision in traffic. These algorithms cannot identify the full range of dangerous situations that could lead to an accident; they typically detect and prevent only a limited number of predefined threat situations [14]. An autonomous vehicle can sense its environment using a variety of sensors, such as radar and LIDAR, and move the car with limited or no human contribution. Autonomous vehicles face challenges such as dealing with human-driven cars, unsuitable road infrastructure, and cybersecurity. In the long run, when these challenges are solved, autonomous vehicles could boost car safety and decrease commuting time and cost [15].
Monitoring and analyzing driver behavior using machine learning methods to
detect abnormal behavior is an emerging solution to enhance car safety [16]. Human
error, which can increase due to the mental workload induced by distraction, is one of the leading causes of car crashes. Workload is hard to observe and quantify, so analyzing driver behavior to distinguish normal from aggressive driving is a suitable approach for detecting driver distraction and alerting the driver to take control of the car for a short time [17]. Driving data and driver status
Training Deep Learning Sequence Models … 125
that can be collected from internal sources (such as the vehicle's onboard computer and the data bus) as well as external devices such as cameras can be used to train a variety of machine learning methods, such as Markov models, neural networks, and decision trees, to learn driving patterns and detect abnormal driver behavior [18–22]. Deep learning methods such as convolutional neural networks, LSTM networks, and encoder–decoder models outperform other machine learning methods in car safety applications such as pedestrian detection and driver behavior classification [23–26]. Some machine learning and deep learning methods make an instantaneous decision based only on the current inputs of the model. In real-world applications like driving, each data sample affects several subsequent samples, so using a temporal dimension and adding memory and attention can extract more behavior-representative information from a sequence of data, which can improve accuracy in these applications.
In this work, we use three models, a multilayer perceptron (MLP) network, an LSTM network with an attention layer, and an encoder–decoder model, to predict driver status using only driving data. No intrusive devices such as cameras or driver-wearable sensors are used, which makes our approach more user-friendly and increases its potential to be adopted by industry. The MLP model was considered the baseline, memoryless model. The two other models consider the weighted dependency between input and output [27, 28], which provides instant feedback on the progress of the network model, increases its prediction accuracy, and reduces its training time. We started with a single-input–single-output neural network and compared the accuracy and training process of this model with an LSTM model with attention, which is multiple-input–single-output and has both memory and attention. The LSTM model with attention achieved lower training and test error in fewer training epochs: the average number of training epochs is 400 for the MLP and 100 for the LSTM attention model. The LSTM attention model also has fewer layers. We then used the encoder–decoder with attention to estimate a sequence of driver behavior vectors using a multiple-input–multiple-output model and compared the results with the two other models. The training and test error of this model is lower than that of the other two models, and it can estimate multiple driver behavior vectors.
Section 2 discusses related work. Section 3 describes the experiment, and Sect. 4 explains the features of the data collected in the experiment. We discuss the methodology in Sect. 5, and Sect. 6 gives the details of the three models. Section 7 presents the results, and Sect. 8 concludes.
2 Related Works
Wöllmer et al. [29] discussed a driver distraction detection method whose goal is to model head tracking and driving-context data using an LSTM network. An experiment was conducted with 30 volunteers who drove an Audi on a straight road in Germany with one lane in each direction and a 100 km/h speed limit. Distracted driving data were collected via an interface measuring the vehicle's CAN-Bus data and a head tracking system installed in the car cockpit. Eight distracting tasks were chosen: radio, CD, phone book, navigation point of interest, phone, navigation, TV, and navigation sound. The distracting functions were available through eight hard keys located on the left and right sides of the interface. In all, 53 non-distracted and 220 distracted runs were performed. The collected CAN-Bus and head tracking data were fed to an LSTM network to predict the driver's state continuously, and the accuracy of the model reached 96.6%. Although this approach has high accuracy, it needs external devices that are not available in all cars and is often considered an intrusion by drivers. Additionally, it considers only one driving context and a straight road; the model's accuracy might degrade when other types of roads and driving conditions are introduced, since these would produce more complex contextual data and hence a larger number of patterns to differentiate between.
Xu et al. [30] introduced an image captioning model with an attention mechanism, inspired by language translation models. The input of the model is a raw image, and it produces a caption describing that image. The encoder extracts image features from the middle layers of the network instead of using a fully connected layer, and the decoder is an LSTM network that produces the caption as a sequence of words. The model showed good results on three benchmark datasets using the METEOR and BLEU metrics.
Xiao et al. [31] discussed an image classification model that uses a two-level attention model for fine-grained image classification in deep convolutional neural networks. Fine-grained classification means detecting subordinate-level categories under some basic categories. The model is based on the intuition that fine-grained classification requires first detecting the object and then detecting its discriminative parts. At the object detection level, the model uses raw images and removes noise to detect and classify objects. At the second level, it filters the object using mid-level CNN filters to detect parts of the image, and an SVM classifier is used to classify the image's parts. The model was validated on a subset of the ILSVRC2012 dataset and on the CUB200-2011 dataset, and it showed good performance under the weakest supervision condition.
Huang et al. [32] proposed a machine translation model with attention to translate image captions from English to German. Additional information from the image is considered by the model to resolve ambiguity in language. A convolutional neural network extracts the image features, and the model adds these to the text features to enhance the performance of the LSTM network that generates the caption in the target language. Regional features of the image are used instead of global features. The best configuration of the model shows a 2% improvement in BLEU score and a 2.3% improvement in METEOR score compared to models that consider only text information.
Lv et al. [33] introduced a deep learning model for traffic flow prediction that considers temporal and spatial correlations inherently. The model used a stacked autoencoder to detect and learn features of traffic flow. The traffic data was collected from 15,000 detectors across the California freeway system during weekdays in the first three months of 2013; data from the first two months were used for training, and data from the third month for testing. The model outperformed previous models in medium and high traffic, but it did not perform well in low traffic. The authors compared the model with the Random Walk Forecast method, Support Vector Machine, Backpropagation Neural Network, and Radial Basis Function. For 15-min traffic flow prediction, the stacked autoencoder model reached more than 90% accuracy on 86% of highways and outperformed the four shallow models.
Saleh et al. [34] discussed a novel method for driver behavior classification using a stacked LSTM network. Nine types of sensory data were captured using a cellphone's internal sensors during realistic driving sessions. Three driving behaviors, normal, drowsy, and aggressive, were defined, and the driver behavior classification problem was modeled as time series classification. A sequence of driving feature vectors was used as the input of a stacked LSTM network, and the network classified the driver behavior accurately. The model also achieved better results than the baseline approach on UAH-DriveSet, a naturalistic driver behavior dataset. The model was compared with other common driver behavior classification methods, including decision tree (DT) and multilayer perceptron (MLP): DT and MLP achieved 51% and 75% accuracy, respectively, while the proposed method outperformed them with 86% accuracy.
These works used different types of LSTM networks to detect driver behavior and driving patterns. In this work, we use two types of LSTM networks, a bidirectional LSTM network and an LSTM network with an attention layer, to predict driver behavior using only driving data, and we compare their results against a simple LSTM network used as the baseline.
3 Experiment
Four driving contexts were defined for this experiment: Day, Night, Fog, and Fog and Night. We used an Android application to simulate the car infotainment system. The application was hosted on an Android v4.4.2-based Samsung Galaxy Tab4 8.0 connected to the Hyper Drive simulator. This setup
allowed us to control the height and angle of the infotainment system. The removable
tablet further allowed the control of screen size and contents.
Our minimalist design was used in the interface design of this application. In this
design, the main screen of the application shows six groups of car applications and the
driver has access to more features under each group. In each step of the navigation,
a maximum of six icons was displayed on the screen, a number tested earlier and shown to be suitable in infotainment UI design [35]. The minimalist
design allowed each application in this interface to be reached within a maximum of four navigation steps from the main screen, conforming to NHTSA guideline DOT HS 812 108 on acceptable distraction [3]. This helped us
to normalize the driver’s interaction tasks and standardize them between different
driving contexts.
3.1 Participants
We distracted drivers by asking them to perform tasks on the simulator's infotainment system. In this experiment, a distracting task is defined as reaching a specific application on the designed car interface. We classified all possible applications in the interface into three groups, based on the number of navigation steps the driver needs to take from the main screen to reach them, and called them two-step, three-step, and four-step tasks.
In each distracting scenario, we chose some tasks from each group and asked the driver to do them, with equal time intervals between tasks: when the driver finished a task, we waited a few seconds before asking for the next one. In the distracted scenarios, we observed the driver's behavior and collected four features for each task: the number of errors the driver made during the task; the response time, i.e., the time from the moment the driver was asked to do the task until the task was completed; the mode of driving, i.e., the current driving scenario; and the number of navigation steps the driver needs to take from the main screen to the specific application to complete the task.
4 Data
The simulator collected an average of 19,000 data vectors per trip, each with 53 driving-related features. In sum, 5.3 million data vectors were collected during 280 trips. We used paired t-tests to select the most significant features and, based on the results, chose 10 driving features: velocity, speed, steering wheel, brake, lateral acceleration, headway distance, headway time, acceleration, longitudinal acceleration, and lane position. We set the master data sampling rate at 60 samples per second. While this high sample rate was useful in other experiments, it made the model computationally expensive, so we compressed the collected data by averaging every 20 samples into one vector, producing a dataset of driving-related data vectors with 10 features each, the Vc. In the distracted scenarios, volunteers performed 2,025 tasks in total. For each task, we defined four features describing driver behavior while interacting with our designed interface, the Vh:
1. The name of the scenario, which shows whether or not the driver is distracted. Based on our previous experiments and the currently collected data, the adverse (Fog, Night) and doubly adverse (Fog and Night) driving contexts adversely affect drivers' performance.
2. The number of driver errors during the task. We defined an error as touching a wrong icon or not following driving rules while interacting with the interface.
3. The response time, which is the length of the task from the moment we ask the driver to start until the moment the task is done.
4. The number of navigation steps the driver needs to take to reach the application and finish the task. All tasks could be completed in four steps or fewer.
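The 20-sample compression described above (60 Hz raw data averaged into one Vc vector per window) can be sketched as follows; `compress` is a hypothetical helper name, assuming the raw samples arrive as a (num_samples, num_features) array:

```python
import numpy as np

def compress(samples, window=20):
    # Average every `window` consecutive rows into one row,
    # dropping any incomplete tail window.
    n = len(samples) // window * window
    return samples[:n].reshape(-1, window, samples.shape[1]).mean(axis=1)
```

At 60 samples per second, a 20-sample window yields three averaged Vc vectors per second of driving.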
For each driver behavior vector (Vh), several car data vectors (Vc) were collected by the simulator; the number of Vc vectors linked to each task depends on the task's response time. To map Vc vectors to the corresponding Vh, we divided each trip's Vc vectors into blocks of different lengths, proportional to the tasks' response times, and mapped each block to one Vh vector. In (1), N is the number of tasks executed in a trip and per(task_i) is the approximate percentage of the trip's Vc vectors related to task_i:

per(task_i) = (ResponseTime_i / Σ_{k=1}^{N} ResponseTime_k) × 100    (1)
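Equation (1) and the block mapping it drives can be implemented as a short sketch; the function names and the rounding scheme used to turn percentages into block sizes are our own assumptions, not the authors' code:

```python
import numpy as np

def task_percentages(response_times):
    # Eq. (1): per(task_i) = ResponseTime_i / sum_k ResponseTime_k * 100
    rt = np.asarray(response_times, dtype=float)
    return rt / rt.sum() * 100.0

def split_vc_by_task(vc, response_times):
    # Partition a trip's Vc vectors into per-task blocks whose lengths are
    # proportional to each task's share of the total response time.
    per = task_percentages(response_times)
    counts = np.round(per / 100.0 * len(vc)).astype(int)
    counts[-1] = len(vc) - counts[:-1].sum()   # absorb rounding error
    return np.split(vc, np.cumsum(counts)[:-1])
```

Each resulting block is then paired with the Vh vector of the corresponding task.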
5 Methodology
Feedforward neural networks have been used in many car-related applications, such as lane departure and sign detection [22–25]. In a feedforward neural network, data travels in one direction from the input layer to the output, and the training process is based on the assumption that input samples are fully independent of each other. Feedforward networks have no notion of order in sequence or time; only the current sample is considered when tuning the network's parameters [36]. We use a feedforward neural network as the baseline model for estimating driver behavior from driving data. The model is a single-input–single-output network that uses the mean of the driving data during each task as input and estimates the corresponding driver behavior vector.
In some real-life applications, such as speech recognition, driving, and image captioning, input samples are not independent, and there is valuable information in the sequence, so models with memory, such as recurrent neural networks, can outperform memoryless models. An LSTM model uses both the current sample and previously observed data at each step of training, combining the two to produce the network's output [37]. The difference between feedforward networks and RNNs is that the feedback loop in an RNN provides information from previous steps to the current one, adding memory to the network, which is preserved in hidden states of the network. LSTM with attention is a recurrent model that combines a bidirectional LSTM layer with an attention layer. The LSTM attention model uses the weighted effect of all driving data during each task to estimate the driver behavior vector [38]. This model has both memory and attention and considers all driving data during the task, so our assumption is that it will be more accurate than the feedforward model.
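The memory argument above can be illustrated with a minimal recurrent state update (a plain RNN cell in NumPy, a simplification of the LSTM actually used in this work; all weights here are random stand-ins):

```python
import numpy as np

rng = np.random.default_rng(0)
n_features, n_hidden = 10, 8
Wx = rng.normal(size=(n_hidden, n_features)) * 0.1   # input weights
Wh = rng.normal(size=(n_hidden, n_hidden)) * 0.1     # recurrent weights

def rnn_step(x_t, h_prev):
    # The new hidden state depends on both the current input and the
    # previous hidden state, which is what gives the network its memory.
    return np.tanh(Wx @ x_t + Wh @ h_prev)

h = np.zeros(n_hidden)
for x_t in rng.normal(size=(5, n_features)):   # a short driving-data sequence
    h = rnn_step(x_t, h)
```

A feedforward network, by contrast, would map each `x_t` to an output with no `h_prev` term at all.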
6 Models
We built and trained a memoryless multilayer perceptron (MLP) neural network using both scaled and unscaled data. This feedforward network was considered the baseline for comparison with the LSTM attention network and the encoder–decoder attention model, both of which have attention and memory.
Attention is one of the fundamental parts of cognition and intelligence in humans [41,
42]. It helps reduce the amount of information processing and complexity [43, 44].
We can loosely define attention as directing some human senses (like vision), and
hence the mind, to a specific source or object rather than scanning the entire input
space [45]. This is an essential human perception component that helps increase
processing power while reducing demand on resources [46]. In the neural network
area, attention is primarily used as a memory mechanism that determines which
part of the input sequence has more effect on the final output [47, 48]. The attention
mechanism considers the weighted effect of each input on the model’s output instead
of producing one context vector from all samples of the input sequence (Fig. 2) [49,
50]. We used the attention mechanism to estimate driver behavior from a sequence of 10-feature driving data vectors, considering the weighted effect of each input driving vector on the driver behavior. Our assumption is that using the attention mechanism decreases the model's training and test error and enhances the training process.
In an attention network, the encoded input is filtered specifically for each output in the output sequence. Equation (2) shows the output of the encoder in an encoder–decoder without attention: h is a single output vector containing information about all samples in the input sequence. In an attention network, the encoder instead produces one vector per input step, as shown in (3) [49].

h = Encoder(x_1, x_2, x_3, …, x_T)    (2)

[h_1, h_2, h_3, …, h_T] = Encoder(x_1, x_2, x_3, …, x_T)    (3)
The decoder produces one output at a time, and the model scores how well each encoded input matches the current output. Equation (4) shows the scoring formula for encoded input i at step t, where s_{t−1} is the decoder output from the previous step and h_i is the result of encoding input x_i:

e_{ti} = score(s_{t−1}, h_i)    (4)

The scores are then normalized using (5), a softmax function, and the context vector for each time step is calculated as a weighted sum using (6):

a_{ti} = exp(e_{ti}) / Σ_{j=1}^{T} exp(e_{tj})    (5)

C_t = Σ_{j=1}^{T} a_{tj} · h_j    (6)
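Equations (4)–(6) can be sketched for one decoder step as follows; the bilinear scoring matrix `Wa` stands in for whatever learned alignment function the model uses, so treat this as an illustration of the mechanism rather than the exact implementation:

```python
import numpy as np

def softmax(e):
    e = e - e.max()                      # shift for numerical stability
    return np.exp(e) / np.exp(e).sum()

def attention_context(s_prev, H, Wa):
    # Eq. (4): score each encoder output h_i against the previous
    # decoder state s_{t-1} (here with a simple bilinear form).
    e = np.array([s_prev @ Wa @ h_i for h_i in H])
    a = softmax(e)                       # Eq. (5): normalized weights
    return a @ H, a                      # Eq. (6): context vector C_t
```

With the T encoder outputs stacked row-wise in `H`, the product `a @ H` is exactly the weighted sum of Eq. (6).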
7 Results
We built an MLP and trained it with both scaled and unscaled data, using 80% of the data for training and 20% for testing. We tried different numbers of layers, in the range 2–6, and a variety of hidden neurons, in the range 50–500. Table 1 shows some of the best results achieved with scaled data, and Table 2 shows the results with unscaled data. The large difference between training and test error means the model has an overfitting problem with both scaled and unscaled data and does not generalize well in most cases.
In the next step, we trained an LSTM network with attention. We used Adam opti-
mizer as the model’s optimizer and mean absolute error (MAE) as the accuracy of
the model. We tried a wide range of LSTM neurons from 10 to 500. In this model,
80% of the dataset was used as the training set and 20% as the testing set.
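As a concrete sketch of this evaluation setup (our own illustration; the function and variable names are not from the chapter), the 80/20 split and the MAE metric could be implemented as:

```python
import numpy as np

def train_test_split_80_20(X, y, seed=0):
    """Shuffle and split paired arrays into 80% training / 20% testing."""
    rng = np.random.default_rng(seed)
    idx = rng.permutation(len(X))
    cut = int(0.8 * len(X))
    return X[idx[:cut]], y[idx[:cut]], X[idx[cut:]], y[idx[cut:]]

def mean_absolute_error(y_true, y_pred):
    """MAE, the error metric used for the models in this chapter."""
    return np.mean(np.abs(np.asarray(y_true) - np.asarray(y_pred)))

# toy data: 10 samples with 2 features each
X = np.arange(20).reshape(10, 2)
y = np.arange(10)
X_tr, y_tr, X_te, y_te = train_test_split_80_20(X, y)
```

Comparing the MAE on the training split with the MAE on the held-out split is exactly the train/test gap used throughout this section to diagnose overfitting.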
Table 3 shows some results with unscaled data. For unscaled data, the best result was achieved with 20 neurons: 0.85 training error and 0.96 test error. This model achieved
less test error compared to the MLP model using a smaller number of hidden neurons.
The best MLP result was achieved with four hidden layers and 300 neurons in each layer, so the MLP model is less accurate and more computationally expensive than the LSTM attention model. Besides, one layer of the LSTM model plus
an attention layer had better performance compared to a large MLP model with
four fully connected layers. In addition, the training process of the MLP model took
around 400–600 epochs while the LSTM converged in around 100–200 epochs.
Table 4 shows the best-achieved results with scaled data. The model with 40 neurons reached the minimum test error, which is lower than in all MLP cases with scaled data except one: the four-hidden-layer model with 150 neurons. In general, the LSTM attention model with scaled data generalized better than the MLP in all cases and achieved better performance with fewer hidden neurons and a smaller network.
As we mentioned earlier, to obtain a sequence of driver behavior data vectors we can use a sequence-to-sequence model instead of running a multi-input–single-output model multiple times. We chose the encoder–decoder model as a suitable sequence-to-sequence model. We built and trained encoder–decoder attention models, including three-step, four-step, and five-step models, with both scaled and unscaled data. In these models, 80% of the data was used for training and 20% for testing. Different combinations of batch size, optimizer, and number of LSTM neurons were tested. Finally, we chose Adam as the optimizer and 100 as the batch size. We tried a range of LSTM neurons from 20 to 500.
Besides, we tried different lengths of input and output sequences from 2 to 6. After
the four-step model, increasing the length of the sequence didn’t have a positive
effect on the model’s performance. Table 5 shows some of the best-achieved results
with unscaled data.
Mean absolute error (MAE) was used as the loss function of this model. The three-step model reached the minimum error, with 1.5 training and test mean absolute error.
The four-step and the five-step models have almost the same performance. The three-
step model with unscaled data and 100 LSTM neurons showed the best performance.
Figure 4 shows this model’s mean absolute error. Table 6 shows the achieved results
for the three models with scaled data. The three-step model with 250 LSTM neurons
reached the minimum error (Fig. 5).
The test error of the encoder–decoder attention model with unscaled data is close to that of the MLP model with scaled data, but the generalization of this model is much better
Table 5 Encoder–decoder attention with unscaled data

LSTM neurons  Sequence  MAE train  MAE test
100           3         1.5        1.5
250           3         1.51       1.52
100           4         1.5        1.52
150           4         1.52       1.56
50            5         1.52       1.52
250           5         1.53       1.56
Training Deep Learning Sequence Models … 137
Fig. 4 The mean absolute error of a three-step model with 100 LSTM neurons and unscaled data
Table 6 Encoder–decoder attention with scaled data

LSTM neurons  Sequence  MAE train  MAE test
150           3         0.1683     0.16
250           3         0.1641     0.16
100           4         0.1658     0.16
300           4         0.1656     0.16
150           5         0.1655     0.16
400           5         0.1646     0.17
Fig. 5 The mean absolute error of a three-step model with 250 LSTM neurons and scaled data
than the MLP model, and the network is smaller than the MLP model. Besides, it converges in around 50 epochs on average, which is much faster than the MLP model that needs around 400 training epochs on average. Moreover, this model estimates multiple driver behavior vectors in one run. The error of this model with unscaled data is higher than that of the LSTM attention model, but its generalization is better and the number of input samples is much smaller than for the LSTM attention model, since this model considers one driving data vector for each driver behavior vector; it is therefore computationally less expensive.
The encoder–decoder attention with scaled data outperformed both the MLP
model and the LSTM attention model. This model has memory and attention similar
to the LSTM attention model. Besides, it considers the dependency between samples
of output sequence which is not possible if we run a multi-input–single-output model
multiple times to have a sequence of driver behavior vectors. Besides, this model is
computationally less expensive than the LSTM attention model since it considers one
input vector corresponding to each output, similar to the MLP model. The minimum
test error of this model with scaled data is 0.06 less than the minimum test error of the
LSTM attention model and 0.04 less than the minimum test error of the MLP model.
The encoder–decoder attention model converges with less error and takes less time
than the two other models. The average number of epochs for this model with unscaled and scaled data was between 50 and 100, which is half the average epochs of the LSTM attention model and one-fourth of the MLP's. Besides, the model
uses a smaller number of layers compared to the MLP model and a smaller number
of input samples compared to the LSTM attention model making it computationally
less expensive. It also estimates multiple driver behavior data vectors in each run.
8 Conclusion
Using machine learning methods to monitor driver behavior and detect the driver’s
inattention that can lead to distraction is an emerging solution to detect and reduce
driver distraction. Different deep learning methods such as CNN, LSTM, and RNN
networks were used in several car safety applications. Some of these methods have
memory, so they can extract and learn information from a sequence of data. There is information in a sequence of driving data that cannot be extracted by processing the samples manually. Driving data samples are not independent of each other, so methods that have memory and attention, such as recurrent models, are a better choice for higher intelligence and hence more reliable car safety applications.
These methods utilize different mechanisms to perceive the dependency between
samples and extract the latent information in the sequence. We chose the MLP
network which is a simple memoryless deep neural network as the baseline for
our model. Then we trained an LSTM attention model that has memory and atten-
tion. This model outperforms the MLP model with both scaled and unscaled data.
The model trained at least two times faster than the MLP model and achieved better
performance with fewer hidden neurons, a smaller network, and fewer training epochs.
In order to obtain a sequence of driving data, we have two options: running a multi-input–single-output model multiple times or using a sequence-to-sequence model. We
built and trained an encoder–decoder attention model with both scaled and unscaled
data to have a sequence of driver behavior data vectors. This model outperforms the
MLP model with both scaled and unscaled data. Besides, this model outperformed the
LSTM attention model with scaled data. The encoder–decoder attention model trained
at least two times faster than the LSTM attention model and four times faster than
the MLP model. Besides, in each run, it estimates multiple driving data vectors and
it had the best generalization and minimum difference between train and test error.
Our work shows that this would be a viable and scalable option for deep neural
network models that work in real-life complex driving contexts without the need
to use intrusive devices. It also provides an objective measurement of the added
advantages of using attention networks to reliably detect driver behavior.
References
1. B. Darrow, Distracted driving is now an epidemic in the U.S., Fortune (2016). http://fortune.
com/2016/09/14/distracted-driving-epidemic/
2. National Center for Statistics and Analysis, Distracted driving in fatal crashes, 2017, (Traffic
Safety Facts Research Note, Report No. DOT HS 812 700) (Washington, DC, National Highway
Traffic Safety Administration, 2019)
3. N. Chaudhary, J. Connolly, J. Tison, M. Solomon, K. Elliott, Evaluation of the NHTSA
distracted driving high-visibility enforcement demonstration projects in California and
Delaware. (Report No. DOT HS 812 108) (Washington, DC, National Highway Traffic Safety
Administration, 2015)
4. National Center for Statistics and Analysis, Distracted driving in fatal crashes, 2017, (Traffic
safety facts research Note, Report No. DOT HS 812 700), (Washington, DC: National Highway
Traffic Safety Administration, 2019)
5. S. Monjezi Kouchak, A. Gaffar, Driver distraction detection using deep neural network, in The
Fifth International Conference on Machine Learning, Optimization, and Data Science (Siena,
Tuscany, Italy, 2019)
6. J. Lee, Dynamics of driver distraction: the process of engaging and disengaging. Assoc. Adv.
Autom. Med. 58, 24–35 (2014)
7. T. Hirayama, K. Mase, K. Takeda, Analysis of temporal relationships between eye gaze and
peripheral vehicle behavior for detecting driver distraction. Hindawi Publ. Corp. Int. J. Veh.
Technol. 2013, 8 (2013)
8. National Highway Traffic Safety Administration. Blueprint for Ending Distracted Driving.
Washington, DC: U.S. Department of Transportation. National Highway Traffic Safety
Administration, DOT HS 811 629 (2012)
9. T.B. Sheridan, R. Parasuraman, Human-automation interaction. Reviews of Human Factors and Ergonomics, vol. 1, pp. 89–129 (2015). https://doi.org/10.1518/155723405783703082
10. U. Hamid, F. Zakuan, K. Zulkepli, M. ZulfaqarAzmi, H. Zamzuri, M. Rahman, M. Zakaria,
Autonomous Emergency Braking System with Potential Field Risk Assessment for Frontal
Collision Mitigation (IEEE ICSPC, Malaysia, 2017)
11. L. Li, D. Wen, N. Zheng, L. Shen, Cognitive cars: a new frontier for ADAS research. IEEE
Trans. Intell. Transp. Syst. 13 (2012)
12. S. Monjezi Kouchak, A. Gaffar, Estimating the driver status using long short term memory,
in Machine Learning and Knowledge Extraction, Third IFIP TC 5, TC 12, WG 8.4, WG 8.9,
WG 12.9 International Cross-Domain Conference, CD-MAKE 2019 (2019). https://doi.org/10.
1007/978-3-030-29726-8_5
13. P. Koopman, M. Wagner, Autonomous vehicle safety: an interdisciplinary challenge. IEEE
Intell. Transp. Syst. Mag. 9, 90–96 (2017)
14. M. Benmimoun, A. Pütz, A. Zlocki, L. Eckstein, euroFOT: field operational test and impact
assessment of advanced driver assistance systems: final results, in SAE-China, FISITA (eds)
Proceedings of the FISITA 2012 World Automotive Congress. Lecture Notes in Electrical
Engineering, vol. 197 (Springer, Berlin, Heidelberg, 2013)
15. S. Monjezi Kouchak, A. Gaffar, Determinism in future cars: why autonomous trucks are easier
to design, in IEEE Advanced and Trusted Computing (ATC 2017) (San Francisco Bay Area,
USA, 2017)
16. S. Kaplan, M.A. Guvensan, A.G. Yavuz, Y. Karalurt, Driver behavior analysis for safe driving:
a survey, IEEE Trans. Intell. Transp. Syst. 16, 3017–3032 (2015)
17. A. Aksjonov, P. Nedoma, V. Vodovozov, E. Petlenkov, M. Herrmann, Detection and evaluation
of driver distraction using machine learning and fuzzy logic. IEEE Trans. Intell. Transp. Syst.
1–12 (2018). https://doi.org/10.1109/tits.2018.2857222
18. R. Harb, X. Yan, E. Radwan, X. Su, Exploring precrash maneuvers using classification trees
and random forests, Accid. Anal. Prev. 41, 98–107 (2009)
19. A. Alvarez, F. Garcia, J. Naranjo, J. Anaya, F. Jimenez, Modeling the driving behavior of
electric vehicles using smartphones and neural networks. IEEE Intell. Transp. Syst. Mag. 6,
44–53 (2014)
20. J. Morton, T. Wheeler, M. Kochenderfer, Analysis of recurrent neural networks for probabilistic
modeling of driver behavior. IEEE Trans. Intell. Transp. Syst. 18, 1289–1298 (2017)
21. A. Sathyanarayana, P. Boyraz, J. Hansen, Driver behavior analysis and route recognition by
Hidden Markov models, IEEE International Conference on Vehicular Electronics and Safety
(2008)
22. J. Li, X. Mei, D. Prokhorov, D. Tao, Deep neural network for structural prediction and lane
detection in traffic scene. IEEE Trans. Neural Netw. Learn. Syst. 28, 14 (2017)
23. S. Monjezi Kouchak, A. Gaffar, Non-intrusive distraction pattern detection using behavior
triangulation method, in 4th Annual Conference on Computational Science and Computational
Intelligence CSCI-ISAI (USA, 2017)
24. S. Su, B. Nugraha, Fahmizal, Towards self-driving car using convolutional neural network and
road lane detector, in 2017 2nd International Conference on Automation, Cognitive Science,
Optics, Micro Electro-Mechanical System, and Information Technology (ICACOMIT) (Jakarta,
Indonesia, 2017), p. 5
25. S. Hung, I. Choi, Y. Kim, Real-time categorization of driver’s gaze zone using the deep learning
techniques, in 2016 International Conference on Big Data and Smart Computing (BigComp)
(2016), pp. 143–148
26. A. Koesdwiady, S. Bedavi, C. Ou, F. Karray, End-to-end deep learning for driver distraction
recognition. Springer International Publishing AG 2017 (2017), p. 8
27. I. Goodfellow, Y. Bengio, A. Courville, Deep Learning (The MIT press, 2016), ISBN:
9780262035613
28. J. Schmidhuber, Deep learning in neural networks: an overview, vol. 61, (Elsevier, 2015),
pp. 85–117
29. M. Wöllmer, C. Blaschke, T. Schindl, B. Schuller, B. Färber, S. Mayer, B. Trefflich, Online
driver distraction detection using long short-term memory. IEEE Trans. Intell. Transp. Syst.
2(2), 574–582 (2011)
30. K. Xu, J.L. Ba, R. Kiros, K. Cho, A. Courville, R. Salakhutdinov, R.S. Zemel, Y. Bengio, Show, attend and tell: neural image caption generation with visual attention, in 32nd International Conference on Machine Learning (Lille, France, 2015)
31. T. Xiao, Y. Xu, K. Yang, J. Zhang, Y. Peng, Z. Zhang, The application of two-level attention
models in deep convolutional neural network for fine-grained image classification, in The IEEE
Conference on Computer Vision and Pattern Recognition (CVPR) (2015), pp. 842–850
32. P. Huang, F. Liu, S. Shiang, J. Oh, C. Dyer, Attention-based multimodal neural machine trans-
lation, in Proceedings of the First Conference on Machine Translation, Shared Task Papers,
vol. 2, (Berlin, Germany, 2016), pp. 639–645
33. Y. Lv, Y. Duan, W. Kang, Z. Li, F. Wang, Traffic flow prediction with big data: a deep learning
approach. IEEE Trans. Intell. Transp. Syst. 16, 865–873 (2015)
34. K. Saleh, M. Hossny, S. Nahavandi, Driving behavior classification based on sensor data fusion
using LSTM recurrent neural networks, in IEEE 20th International Conference on Intelligent
Transportation Systems (ITSC) (2017)
35. A. Gaffar, S. Monjezi Kouchak, Minimalist design: an optimized solution for intelligent inter-
active infotainment systems, in IEEE IntelliSys, the International Conference on Intelligent
Systems and Artificial Intelligence (London, UK, 2017)
36. C. Bishop, Pattern Recognition and Machine Learning (Springer). ISBN-13: 978-0387310732
37. M. Magic, Action recognition using Python and recurrent neural network, First edn. (2019).
ISBN: 978-1798429044
38. D. Mandic, J. Chambers, Recurrent neural networks for prediction: learning algorithms,
architectures and stability, First edn. (Wiley, 2001). ISBN: 978-0471495178
39. J. Rogerson, Theory, Concepts and Methods of Recurrent Neural Networks and Soft Computing
(2015). ISBN-13: 978-1632404930
40. M.A. Wani, F.A. Bhat, S. Afzal, A.I. Khan, Advances in Deep Learning (Springer, 2020)
41. A. Gaffar, E. M. Darwish, A. Tridane, Structuring heterogeneous big data for scalability and
accuracy. Int. J. Digit. Inf. Wirel. Commun. 4, 10–23 (2014)
42. A. Gaffar, H. Javahery, A. Seffah, D. Sinnig, A pattern framework for eliciting and deliv-
ering UCD knowledge and practices, in Proceedings of the Tenth International Conference on
Human-Computer Interaction (2003), pp. 108–112
43. A. Gaffar, Enumerating mobile enterprise complexity 21 complexity factors to enhance the
design process, in Proceedings of the 2009 Conference of the Center for Advanced Studies on
Collaborative Research (2009), pp. 270–282
44. A. Gaffar, The 7C’s: an iterative process for generating pattern components, in 11th
International Conference on Human-Computer Interaction (2005)
45. J. Bermudez, Cognitive Science: An Introduction to the Science of the Mind, 2nd edn. (2014).
978-1107653351
46. B. Garrett, G. Hough, Brain & Behavior: an Introduction to Behavioral Neuroscience, 5th edn.
(SAGE). ISBN: 978-1506349206
47. Y. Wang, M. Huang, L. Zhao, X. Zhu, Attention-based LSTM for aspect-level sentiment clas-
sification, in Proceedings of the 2016 Conference on Empirical Methods in Natural Language
Processing (2016), pp. 606–615
48. I. Sutskever, O. Vinyals, Q. Le, Sequence to sequence learning with neural networks, in
Advances in Neural Information Processing Systems 27 (NIPS 2014) (2014)
49. S. Frintrop, E. Rome, H. Christenson, Computational visual attention systems and their cogni-
tive foundations: a survey, ACM Trans. Appl. Percept. (TAP) 7 (2010). https://doi.org/10.1145/
1658349.1658355
50. Z. Yang, D. Yang, C. Dyer, X. He, A Smola, E. Hovy, Hierarchical attention networks for
document classification, in NAACL-HLT 2016 (San Diego, California, 2016), pp. 1480–1489
Exploiting Spatio-Temporal Correlation
in RF Data Using Deep Learning
Abstract The pervasive presence of wireless services and applications has become
an integral part of our lives. We depend on wireless technologies not only for our
smartphones but also for other applications like surveillance, navigation, jamming, anti-jamming, and radar, to name a few. These recent advances of
wireless technologies in radio frequency (RF) environments have warranted more
autonomous deployments of wireless systems. With such large-scale dependence on the use of the RF spectrum, it becomes imperative to understand the ambient signal characteristics for optimal deployment of wireless infrastructure and efficient resource
provisioning. In order to make the best use of such radio resources in both the spatial and time domains, past and current knowledge of the RF signals is important. Although sensing mechanisms can be leveraged to assess the current environment, learning techniques are typically used for analyzing past observations and for predicting the future occurrences of events in a given RF environment. Machine learning
(ML) techniques, having already proven useful in various domains, are also being
sought for characterizing and understanding the RF environment. Some of the goals
of the learning techniques in the RF domain are transmitter or emitter fingerprinting,
emitter localization, modulation recognition, feature learning, attention and saliency,
autonomous RF sensor configuration and waveform synthesis. Moreover, in large-
scale autonomous deployments of wireless communication networks, the signals
received from one component play a crucial role in the decision-making process of
other components. In order to efficiently implement such systems, each component
D. Roy (B)
Computer Science, University of Central Florida, Orlando, FL 32826, USA
e-mail: debashri@cs.ucf.edu
T. Mukherjee
Computer Science, University of Alabama, Huntsville, AL 35899, USA
e-mail: tm0130@uah.edu
E. Pasiliao
Munitions Directorate, Air Force Research Laboratory, Eglin AFB,
Valparaiso, FL 32542, USA
e-mail: eduardo.pasiliao@us.af.mil
© The Editor(s) (if applicable) and The Author(s), under exclusive license 143
to Springer Nature Singapore Pte Ltd. 2021
M. A. Wani et al. (eds.), Deep Learning Applications, Volume 2,
Advances in Intelligent Systems and Computing 1232,
https://doi.org/10.1007/978-981-15-6759-9_7
144 D. Roy et al.
We are living in a world where the distances are shrinking every day, thanks to
an explosion in the use of connected devices. The ubiquitous usage of wirelessly
connected Internet-of-Things (IoT) [56] along with the deployment of wireless
autonomous systems has ushered in a new era of industrial-scale deployment of
RF devices. This prevalence of large-scale peer-to-peer communication and nature
of the underlying ubiquitous network brings forth the challenge of accurately identi-
fying a RF transmitter. Every device that is part of a large network needs to be able to
identify its peers with high confidence in order to set up secure communication chan-
nels. One of the ways in which this is done is through the interchange of “keys” [42]
for host identification. However, such schemes are prone to breaches by malicious
agents [50] because often the actual implementations of such systems are not crypto-
graphically sound. In order to get around the problem of faulty implementations, one
can use the transmitter’s intrinsic characteristics to create a “fingerprint” that can be
used by a transmitter identification system. Every transmitter, no matter how similar, exhibits slight variations in its hardware components.
Exploiting Spatio-Temporal Correlation in RF Data … 145
This inherent heterogeneity can be exploited to create unique identifiers for the
transmitters. One such property is the imbalance in the Inphase (I ) and Quadrature
(Q) phase components of the signal (I/Q data). However, because of the sheer number
of the transmitters involved, manually “fingerprinting” each and every transmitter
is not a feasible task [6]. Thus, in order to build such a system, there needs to
be an “automatic” method of extracting the transmitter characteristics and using the
resulting “fingerprint” for the differentiation process. One way of achieving this is by learning the representation of the transmitter in an appropriate “feature space” that has enough discriminating capability to differentiate between “apparently identical” transmitters. Hence, the concept of “transmitter fingerprinting” came to light for the task of transmitter classification or identification. For the rest of the
chapter we use the phrases “transmitter classification” and “transmitter recognition”
interchangeably.
We refer to deep neural networks (DNNs) as multi-layer feedforward networks with backpropagation. DNNs have revolutionized the field of artificial intelligence in the
last few years. The problems tackled by DNNs range over Computer Vision, Natural
Language Processing, Speech Processing, and so on. They have been demonstrated to
perform better than humans, for some of these problems. They have also been shown
to be effective for automatically learning discriminating features from data for various
tasks [13]. With proper choice of the neural network architecture and associated
parameters, they can compute arbitrarily good function approximations [21].
In general, the use of deep learning in the RF domain has been limited in the past, with only a few applications in recent times [30]. However, DNNs have already proven their applicability for the task of transmitter identification. A DNN-based approach was proposed in [36], where the authors classified among eight different transmitters with ∼95% accuracy. For modulation recognition, a real-life
demonstration of modulation classification using a smartphone app was presented
in [49]. The authors designed the mobile application DeepRadioTM to distinguish
between six different modulation schemes by implementing a four-layer DNN.
Fully connected DNNs are like standard neural networks but with a lot more hidden
layers. However, these networks can be augmented with convolutional layers for
faster training and for enabling the network to learn more compact and meaningful
representations. Deep convolutional neural networks (DCNN) have been shown to
be effective for several different tasks in communication. There have been quite a few
attempts of using DCNN for learning inherent RF parameters for both modulation
recognition and transmitter identification.
In [29], the authors have demonstrated the use of convolutional neural networks
for modulation detection. Apart from the results, an interesting aspect of the work
is the way I/Q values were used as input to the neural network. More precisely,
given N I/Q values, the authors used a vector of size 2N as an input to the neural
network, effectively using the I and Q components as a tuple representing a point
in the complex plane. This representation proves to be useful for using the I/Q
data in different learning models. However, the authors considered only the spatial correlations to be exploited by the DCNN implementation, on synthetic RF data [27]. More research was done in [31] with the same concept using real RF data. Further investigations on modulation recognition were presented in [47] using DCNN.
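The input representation described above for [29] — N complex I/Q samples flattened into a real vector of size 2N, each sample contributing an (I, Q) pair — can be sketched as follows. This is our own illustration, not the cited authors' code:

```python
import numpy as np

def iq_to_vector(samples):
    """Flatten N complex I/Q samples into a length-2N real vector.

    Each complex sample contributes an (I, Q) pair, i.e. a point in the
    complex plane, so the network input is [I1, Q1, I2, Q2, ...].
    """
    samples = np.asarray(samples, dtype=complex)
    out = np.empty(2 * len(samples))
    out[0::2] = samples.real  # Inphase components
    out[1::2] = samples.imag  # Quadrature components
    return out

# two complex samples become a length-4 real feature vector
v = iq_to_vector([1 + 2j, 3 - 4j])
```

The same flat vector can feed a fully connected DNN directly, or be reshaped into an (N, 2) array when a convolutional or recurrent model expects structured input.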
As for the transmitter identification task, an application of transmitter fingerprinting was presented in [14] for detecting new types of transmitters using an existing trained neural network model. Similarly, in [44], Sankhe et al. presented a DCNN-based large-scale transmitter identification technique with 80–90% accuracy. However, they could only achieve competitive accuracy by factoring in the controlled impairments at the transmitter. It is to be noted that all prior works have exploited only the spatial correlation of the signal data, though a continuous signal can be represented as a time series, having both temporal and spatial properties [54].
Recurrent neural networks (RNNs) are capable of making predictions with time series data. They have been used extensively for modeling temporal data such as speech [34]. There is a limited amount of work that recognizes the potential of using recurrent structures in the RF domain. RNNs have also been shown to be useful for capturing and exploiting the temporal correlations of time series data [13]. There are a few
variants of recurrent neural networks: (i) Long Short-Term Memory (LSTM) [16],
(ii) Gated Recurrent Unit (GRU) [5], and (iii) Convolutional Long Short-Term Mem-
ory (ConvLSTM) [46]. All these variants are designed to learn the long-term temporal
dependencies and are capable of avoiding the “vanishing” or “exploding” gradient
problems [8].
In [26], O’Shea et al. presented an RNN model that extracted high-level protocol
information from the low-level physical layer representation for the task of radio
traffic sequence classification. A radio anomaly detection technique was presented
in [28], where the authors used an LSTM-based RNN as a time series predictor using
the error component to detect anomaly from real signals. Another application of RNN
was proposed in [41], where the authors used a deep recurrent neural network to learn
the time-varying probability distribution of received powers on a channel and used
the same to predict the suitability of sharing that channel with other users. Bai et al.
proposed an RF path fingerprinting [2] method using two RNNs in order to learn the
spatial or temporal pattern. Both simulated and real-world data were used to improve
the positioning accuracy and robustness of moving RF devices. All these discussed
approaches prove that the temporal property of the RF data can be leveraged through
RNN models to learn and analyze different RF properties. Moreover, the working
principle of RNN is rather simple compared to many complicated learning algorithms
such as multi-stage training (MST). In [58], the authors compared several learning
paradigms for the task of transmitter identification. More precisely, they looked at
feedforward deep neural nets, convolutional neural nets, support vector machines,
and deep neural nets with MST. Though they achieved 99% accuracy for classifying 12 transmitters using MST, on the downside, MST is a complex procedure and needs an easily parallelizable training environment. On the other hand, RNNs use a first-order update rule (stochastic gradient descent) and are comparatively simple procedures.
However, there has not been much effort on identifying RF transmitters by exploiting
the recurrent structure within RF data.
As far as the modulation recognition technique is concerned, a method was proposed in [33] for a distributed wireless spectrum sensing network. The authors proposed a recurrent neural network using long short-term memory (LSTM) cells, yielding 90% accuracy on a synthetic dataset [32].
In the light of recent developments, one could argue that deep learning can be a
natural choice for implementing any system for exploiting different RF parameters
to build systems for RF application-oriented tasks. It is also clear that the paradigm
of applying deep learning in the RF domain gradually shifted from applying DNN
to CNN to RNN, while for different application areas, it shifted from modulation recognition to transmitter identification.
As mentioned earlier, we use three variants of RNNs with time series of I/Q data
to present an end-to-end solution for transmitter identification using the concept of
“transmitter fingerprinting”. The main highlights for the rest of the chapter are as follows:
1. We exploit the temporal properties of I/Q data by using a supervised learning
approach for transmitter identification using recurrent neural networks. We use
two approaches: first, we exploit only the temporal property, and then we exploit the spatio-temporal property. We use RNNs with LSTM and GRU cells for the first
approach, while we use a convLSTM model for the latter. Although transmitter
fingerprinting has been studied before, to the best of our knowledge this is the
first work which leverages the spatio-temporal property of the over-the-air signal
data for this task.
2. To examine the performance of the proposed networks, we test them on an indoor
testbed. We transmit raw signal data from eight universal software radio periph-
eral (USRP) B210s [10] and collect over-the-air signals using an RTL-SDR [24].
We use the I/Q values from each USRP for fingerprinting the corresponding trans-
mitter.
3. We collect I/Q data from several different types of SDRs, a couple of them made
by the same manufacturer and a couple made by different manufacturers. More
precisely we use USRP-B210 [10], USRP-B200 [9] and USRP-X310 [11] from
Ettus Research as well as ADALM-PLUTO [7] and BLADE-RF [25]. We show
that the spatio-temporal property is more pronounced (and thus easier to exploit,
both explicitly as well as implicitly) when different types of SDRs (from both
same or different manufacturers) are used as transmitters.
4. We also collect three additional datasets of I/Q values from eight USRP B210s
with varying signal-to-noise-ratio (SNR). We use distance and multi-path as the
defining factors for SNR variation during data collection.
5. We train the proposed RNN models and present a competitive analysis of the
performance of our models against the state-of-the-art techniques for transmitter
classification. Results reveal that the proposed methods outperform the existing
ones, thus establishing the fact that exploiting the spatio-temporal property of I/Q
data can offset the necessity of pre-processing the raw I/Q data used for traditional
transmitter fingerprinting approaches.
where P(x_t | C_k) is the conditional probability of x_t given the class C_k and P(x_t, x_{t+1}) is the probability of x_t and x_{t+1} occurring in order.
Though LSTM cells can be modeled and designed in various ways depending on the need, we use the cells as shown in Fig. 1. In one LSTM cell, there are (i) three types of gates: input (i), forget (f), and output (o); and (ii) a state update of the internal cell memory. The most interesting part of the LSTM cell is the “forget” gate, which at time t is denoted by f_t. The forget gate decides whether to keep the cell state memory (c_t) or not. The forget gate is computed as per Eq. (2) from the input value x_t at time t and the output h_{t−1} at time (t − 1).

f_t = σ(W_{xf} x_t + W_{hf} h_{t−1} + b_f)  (2)
Exploiting Spatio-Temporal Correlation in RF Data … 151
Note that W_xf and b_f represent the associated weight and bias, respectively, between the input (x) and the forget gate (f), and σ denotes the sigmoid activation function. Once f_t determines which memories to forget, the input gates (i_t) decide which cell states (c_t) to update as per Eqs. (3) and (4).
In Eq. (5), the old cell state (ct−1 ) is updated to the new cell state (ct ) using forget
gates ( f t ) and input gates (i t ).
Here ◦ is the Hadamard product. Finally, we filter the output values through the
output gates (ot ) based on the cell states (ct ) as per Eqs. (6) and (7).
h_t = o_t · tanh(c_t)    (7)
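For concreteness, one LSTM step can be sketched in NumPy as follows. This is a minimal illustration, not the chapter's implementation; Eqs. (2)–(6) are not all reproduced in the text above, so their standard forms are assumed here, and the weight layout (one matrix per gate over the concatenated input and previous output) is an assumption.

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def lstm_step(x_t, h_prev, c_prev, W, b):
    """One LSTM step following the gate structure of Eqs. (2)-(7).

    W maps the concatenated [x_t, h_prev] to each gate's
    pre-activation; b holds the corresponding bias vectors.
    """
    z = np.concatenate([x_t, h_prev])
    f_t = sigmoid(W["f"] @ z + b["f"])      # forget gate, Eq. (2)
    i_t = sigmoid(W["i"] @ z + b["i"])      # input gate, Eq. (3)
    c_hat = np.tanh(W["c"] @ z + b["c"])    # candidate state, Eq. (4)
    c_t = f_t * c_prev + i_t * c_hat        # state update, Eq. (5) (Hadamard products)
    o_t = sigmoid(W["o"] @ z + b["o"])      # output gate, Eq. (6)
    h_t = o_t * np.tanh(c_t)                # cell output, Eq. (7)
    return h_t, c_t
```

Since o_t ∈ (0, 1) and tanh(c_t) ∈ (−1, 1), the output h_t is always bounded in (−1, 1).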
The main drawback of using LSTM cells is the need for additional memory. GRUs [5]
have one less gate for the same purpose, thus having a reduced memory and CPU
footprint. The GRU cells control the flow of information just like the LSTM cells,
but without the need for a separate memory unit: the full hidden content is exposed without any output control. A GRU cell has an “update gate” (z_t), a “reset gate” (r_t), and a cell state memory (c_t), as shown in Fig. 2. The reset gate determines whether to combine the new input with the previous memory or not, while the update gate decides how much of the previous hidden state to carry over (Eq. (11)). Eqs. (8)–(11), relating the different gates and states, are given below.
z_t = σ(W_xz x_t + W_hz h_{t−1} + b_z)    (8)
h_t = (1 − z_t) · c_t + z_t · h_{t−1}    (11)
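A single GRU step can likewise be sketched in NumPy. Eqs. (9) and (10) are not reproduced in the text above; the standard forms r_t = σ(W_xr x_t + W_hr h_{t−1} + b_r) and c_t = tanh(W_xc x_t + W_hc (r_t ∘ h_{t−1}) + b_c) are assumed here, so this is an illustrative sketch rather than the chapter's code.

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def gru_step(x_t, h_prev, W, b):
    """One GRU step following Eqs. (8)-(11), with the standard
    forms assumed for the reset gate and candidate state."""
    z_t = sigmoid(W["xz"] @ x_t + W["hz"] @ h_prev + b["z"])          # update gate, Eq. (8)
    r_t = sigmoid(W["xr"] @ x_t + W["hr"] @ h_prev + b["r"])          # reset gate, assumed Eq. (9)
    c_t = np.tanh(W["xc"] @ x_t + W["hc"] @ (r_t * h_prev) + b["c"])  # candidate, assumed Eq. (10)
    h_t = (1.0 - z_t) * c_t + z_t * h_prev                            # Eq. (11)
    return h_t
```

Eq. (11) makes h_t an elementwise convex combination of the candidate state and the previous hidden state, which is how the update gate controls memory retention.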
To mitigate this problem, we use a convolution within the recurrent structure of the
RNN. We first discuss the spatio-temporal property of RF data and then model a
convolutional LSTM network to exploit the same.
total number of features sampled at each timestamp (in our case 2048, since 1024 samples are collected, each of dimension 2). Note that each cell corresponding to one value of R and one value of C represents a particular feature (I or Q) at a given point in time.
In order to capture the temporal property only, we use a sequence of vectors cor-
responding to different timestamps 1, 2, . . . , t as x1 , x2 , . . . , xt . However, to capture
both spatial and temporal properties, we introduce a new vector χt,t+γ , which is
formulated as: χt,t+γ = [xt , xt+1 , . . . , xt+γ−1 ]. So the vector χt,t+γ eventually pre-
serves the spatial properties with an increment of γ in time. So, we get a sequence
of new vectors χ1,γ , χγ,2γ , . . . χt,t+γ , . . . , χt+(β−1)γ,t+βγ , where β is R/γ, and the
goal is to create a model to classify them into one of the K classes (corresponding to
the transmitters). We model the class-conditional densities given by P(χt−γ,t |Ck ),
where k ∈ {1, . . . , K}. We formulate the probability of the next γ-length sequence to be in class C_k as per Eq. (12). The marginal probability is modeled as P(χ_{t,t+γ}).
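The construction of the χ vectors from a sequence of per-timestamp feature vectors can be sketched as follows. This is a minimal NumPy illustration; the array shapes and the truncation of any trailing incomplete block are assumptions, not details stated in the chapter.

```python
import numpy as np

def build_chi_sequence(X, gamma):
    """Group a sequence of feature vectors x_1, ..., x_T (the rows
    of X) into spatio-temporal inputs chi_{t,t+gamma} =
    [x_t, ..., x_{t+gamma-1}], as in the formulation above.

    Returns an array of shape (beta, gamma, F), where beta = T // gamma
    is the number of chi vectors and F the per-timestamp feature count.
    """
    T, F = X.shape
    beta = T // gamma                        # number of chi vectors
    return X[: beta * gamma].reshape(beta, gamma, F)
```

Each returned block stacks γ consecutive timestamps, so spatial structure within a block and temporal structure across blocks are both preserved.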
The cell model, as shown in Fig. 3, is similar to an LSTM cell, but the input transfor-
mations and recurrent transformations are both convolutional in nature [46]. We for-
mulate the input values, cell state, and hidden states as a 3-dimensional vector, where
the first dimension is the number of measurements which varies with the time interval
γ and the last two dimensions contain the spatial information (rows (R) and columns
(C)). We represent these as: (i) the inputs: χ1,γ , χγ,2γ , · · · χt,t+γ , · · · , χt+(β−1)γ,t+βγ
(previously stated); (ii) cell outputs: C1 , · · · , Ct , and (iii) hidden states: H1 , · · · , Ht .
We represent the gates in a similar manner as in the LSTM model. The parameters i_t, f_t, o_t, W, and b hold the same meaning as in Sect. 4.1. The key operations are defined
in Eqs. 13–17. The probability of the next γ-sequence to be in a particular class (from
Eq. 12) is used within the implementation and execution of the model.
H_t = o_t · tanh(C_t)    (17)
5 Testbed Evaluation
In order to validate the proposed models, we collected raw signal data from eight
different universal software radio peripheral (USRP) B210s [10]. We collected the
data in an indoor lab environment with a signal-to-noise ratio of 30 dB, and used the
dataset to discriminate between four and eight transmitters, as mentioned in [40]. We
also collected data with varied SNRs (20, 10, and 0 dB) for the same 8 USRP B210
transmitters. Finally, we collected data from five different types of transmitters, each
transmitter being implemented using an SDR, from several different manufacturers.
In order to evaluate our methods for learning the inherent spatio-temporal features of
a transmitter, we used different types of SDRs as transmitters. The signal generation
and reception are shown in Fig. 4. We used GNURadio [12] to randomly generate a signal and modulated it with quadrature phase shift keying (QPSK). We
programmed the transmitter SDRs to transmit the modulated signal over-the-air and
sensed the same using a DVB-T dongle (RTL-SDR) [24]. We generated the entire
dataset from “over-the-air” data as sensed by the RTL-SDR using the rtlsdr python
library.
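The interleaving of the captured complex samples into a training row can be sketched as follows. The capture portion is shown only as a commented sketch since it requires the RTL-SDR hardware and the pyrtlsdr package; the tuning parameters in it are illustrative placeholders, not the chapter's settings.

```python
import numpy as np

def iq_row(samples):
    """Interleave complex baseband samples into a single
    [I0 Q0 I1 Q1 ...] row, so that 1024 complex samples yield
    one 2048-entry training example."""
    row = np.empty(2 * len(samples))
    row[0::2] = samples.real   # I (in-phase) values
    row[1::2] = samples.imag   # Q (quadrature) values
    return row

# Capture sketch (assumes an RTL-SDR dongle and the pyrtlsdr package;
# frequency and rate values here are hypothetical):
# from rtlsdr import RtlSdr
# sdr = RtlSdr()
# sdr.sample_rate = 2.048e6
# sdr.center_freq = 862e6
# samples = sdr.read_samples(1024)   # 1024 complex I/Q samples
# row = iq_row(samples)              # one training example
```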
We collected I/Q signal data with a sample size of 1024 at each timestamp. Each
data sample had 2048 entries consisting of the I and Q values for the 1024 samples.
Note that a larger sample size would mean more training examples for the neural
network. Our choice of 1024 samples was sufficient to capture the spatial-temporal
properties while at the same time the training was not computationally intensive. We
collected 40,000 training examples from each transmitter to avoid the data skewness
problem observed in machine learning. The configuration parameters that were used
are given in Table 1. We collected two different datasets at 30 dB SNR and three
datasets having three different SNRs, as discussed below. The different types of SNR levels were achieved in an indoor lab environment by changing the propagation delay,
multi-path, and shadowing effects. We also collected a “heterogeneous” dataset using
several different types of SDRs. Note that we intend to make the dataset publicly
available upon publication of the chapter.
For the “homogeneous” dataset, we used eight radios from the same manufacturer,
namely, the USRP-B210 from Ettus Research [10], as transmitters. We collected two
sets of data: (i) using 4 USRP B210 transmitters: 6.8 GB size, 160K rows, and 2048
columns and (ii) using 8 USRP B210 transmitters: 13.45 GB size, 320K rows, and
2048 columns. Note that the SNR was 30 dB in each case.
In order to investigate the spatio-temporal correlation in the I/Q data from different
types of SDRs from varied manufacturers, we collected a “heterogeneous” dataset
as well. We used three different SDRs from the same manufacturer and two SDRs from two other manufacturers: the USRP B210 [10], USRP B200 [9], and USRP X310 [11] from Ettus Research, plus the BLADE RF [25] by Nuand and the PLUTO SDR [7] by Analog Devices.
The signal generation procedure is similar to Fig. 4 with different SDR models as
transmitters. The SNR remains 30 dB, same as earlier. The “heterogeneous” datasets
were obtained using (i) 5 USRP B210 transmitters: 8.46 GB size, 200K rows, and
2048 columns and (ii) 1 USRP B210, 1 USRP B200, 1 USRP X310, 1 BLADERF,
and 1 PLUTO transmitter: 6.42 GB size, 200K rows, and 2048 columns.
We include the 5-USRP-B210 homogeneous data in this dataset to perform a fair comparison between five heterogeneous radios and five homogeneous ones in Sect. 6.6, since the homogeneous datasets mentioned in the previous paragraph contain either four or eight radios.
We collected three more datasets with 8 USRP B210 transmitters with SNRs of 20
dB, 10 dB, and 0 dB, respectively. Each dataset is of size ∼13 GB with 320K rows
and 2048 columns.
Correlation between data samples plays a crucial role in the process of transmitter
identification. We represent the I and Q values of each training sample at time (t)
as: [I0 Q 0 I1 Q 1 I2 Q 2 I3 Q 3 I4 Q 4 . . . I1023 Q 1023 ]t . We used the QPSK modulation [48]
which means that the spatial correlation should be between every fourth value, i.e., between I_0 and I_4, and between Q_0 and Q_4. So we calculate the correlation coefficient of (I_0, I_1, I_2, I_3) and (I_4, I_5, I_6, I_7), and similarly of (Q_0, Q_1, Q_2, Q_3) and (Q_4, Q_5, Q_6, Q_7). We take the average of all the correlation coefficients for each sample.
We use numpy.corrcoef for this purpose, which computes Pearson product-moment correlation coefficients, denoted by r. Pearson's method for a sample is given by

r = Σ_{i=0}^{M−1} (I_i − Ī)(Q_i − Q̄) / √( Σ_{i=0}^{M−1} (I_i − Ī)² · Σ_{i=0}^{M−1} (Q_i − Q̄)² )    (18)

where M is the sample size and I_i, Q_i are the sample values indexed with i. The sample mean is Ī = (1/M) Σ_{i=0}^{M−1} I_i, and Q̄ is defined analogously.
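The per-sample correlation procedure described above can be sketched as follows. This is an illustrative NumPy implementation, not the chapter's code; it assumes the groups are consecutive, non-overlapping runs of four values.

```python
import numpy as np

def avg_spatial_correlation(sample):
    """Average Pearson correlation between consecutive 4-value groups
    of the I components and of the Q components of one interleaved
    training sample [I0 Q0 I1 Q1 ...]."""
    I, Q = sample[0::2], sample[1::2]
    rs = []
    for seq in (I, Q):
        # Split into consecutive groups of four values
        groups = seq[: len(seq) // 4 * 4].reshape(-1, 4)
        for a, b in zip(groups[:-1], groups[1:]):
            rs.append(np.corrcoef(a, b)[0, 1])   # Pearson r, as in Eq. (18)
    return float(np.mean(rs))
```

For a sample whose I and Q sequences repeat exactly every four values, every group pair is identical and the average correlation is 1.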
The spatial correlations of all the samples for the different transmitters are shown
in Fig. 5. We observe that for most of the transmitters, the correlation is ∼0.42, with
a standard deviation of ∼0.2. However, transmitter 3 exhibits minimal correlation
between these samples, which implies that the spatial property of transmitter 3 is
different from the other transmitters. As a result, transmitter 3 should be easily distinguishable from the others. This claim will be validated later in the experimental result section, where we see zero false positives and false negatives for transmitter 3 for all three proposed models. This observation motivates us to exploit the
spatial property as well as the temporal property for the collected time series data.
The calculated average spatial correlations of samples for five different types of
transmitters (as mentioned in Sect. 5.1.2 (ii)) are shown in Fig. 6. We observe that
data from USRP-B210 and PLUTO-SDR have better correlations than the other three
types. It is to be noted that we calculated the spatial correlation of this data using the same technique described in the previous section. It is also evident from the figure that none of the transmitters exhibits a strong correlation; however, each exhibits correlations in a different range. This bolsters our claim that the spatio-temporal property of heterogeneous data will be more distinguishable than that of homogeneous data, which is validated later in the experimental result section (Sect. 6.6).
In this section we discuss the implementation of each of the proposed recurrent neural
networks. We train each network for transmitter classification with K classes. For
the sake of robustness and statistical significance, we present the results for each
model after averaging over several runs.
As discussed earlier, the recurrent structure of the neural network can be used to
exploit the temporal correlation in the data. To that end, we first implemented a
recurrent neural network with LSTM cells and trained it on the collected dataset
using the paradigm as shown in Fig. 7. We used two LSTM layers with 1024 and 256
units sequentially. We also used a dropout rate of 0.5 in between these two LSTM
layers. Next we used two fully connected (Dense) layers with 512 and 256 nodes,
respectively. We apply a dropout rate of 0.2, and add batch normalization [17] on
the output, finally passing it through a Dense layer having eight nodes. We use ReLU
[23] as the activation function for the LSTM layers and tanh [3] for the Dense layers.
Lastly, we use stochastic gradient descent [3]-based optimization with categorical
cross-entropy training. Note that the neural network architecture was finalized over
several iterations of experimentation with the data and we are only reporting the
final architecture here. We achieved 97.17% and 92.00% testing accuracy for four
and eight transmitters, respectively. The accuracy plots and confusion matrices are
shown in Figs. 8 and 9, respectively. Note that the number of nodes in the last layer
is equal to the number of classes in the dataset. It is also to be noted that during the process of designing the RNN architecture, we fine-tuned the hyper-parameters based on the generalization ability of the current network (as determined by comparing the training and validation errors). We also limited the number of recurrent layers and fully connected layers in each model for faster training [15], since no significant increase in the validation accuracy was observed when increasing the number of layers.
The rows and columns of the confusion matrix correspond to the number of
transmitters (classes) and the cell values show the recall or sensitivity and false
negative rate for each of the transmitters. Note that recall or sensitivity represents
the true positive rates for each of the prediction classes.
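The recall values reported in the confusion matrices can be computed directly from the matrix, as sketched below with NumPy; the example matrix values are illustrative, not taken from the chapter's figures.

```python
import numpy as np

def recall_per_class(conf):
    """Recall (true positive rate) for each class, from a confusion
    matrix whose rows are true classes and columns are predictions:
    diagonal (correct predictions) divided by the row sums."""
    conf = np.asarray(conf, dtype=float)
    return np.diag(conf) / conf.sum(axis=1)
```

The false negative rate for each class is simply 1 minus its recall, which is why each confusion-matrix row pairs these two quantities.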
Next we implemented another variation of the RNN model using GRU cells for leveraging temporal correlation. We used the same architecture as the LSTM implementation, presented in Fig. 10; the only difference is that we use two GRU layers with 1024 and 256 units instead of the LSTM layers. The proposed GRU implementation needs fewer parameters than the LSTM model; a quantitative comparison is given in Sect. 6.4. We achieved 97.76% and 95.30% testing accuracy for four and eight transmitters, respectively. The accuracy plots and confusion matrices are given in Figs. 11 and 12. The GRU implementation provided a slight improvement over the accuracy obtained using LSTM, for each run of the models, for both datasets.
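The parameter saving of GRU over LSTM follows directly from the gate counts: an LSTM layer has four weight blocks (Eqs. (2)–(7)) while a textbook GRU has three (Eqs. (8)–(11)). The sketch below counts parameters under two stated assumptions: the 2048-entry sample is fed as 1024 timestamps of 2 features (the chapter does not state the exact reshape), and the textbook bias layout is used (framework implementations such as Keras may differ slightly in bias handling).

```python
def lstm_params(d, h):
    """Trainable parameters of one LSTM layer with input size d and
    h units: four gate blocks, each with input weights (d*h),
    recurrent weights (h*h), and a bias (h)."""
    return 4 * (d * h + h * h + h)

def gru_params(d, h):
    """A textbook GRU layer has one gate block fewer than LSTM."""
    return 3 * (d * h + h * h + h)

# First recurrent layer, assuming 1024 timestamps of 2 features:
print(lstm_params(2, 1024))  # 4,206,592
print(gru_params(2, 1024))   # 3,154,944 -- 25% fewer parameters
```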
Finally, in order to exploit the spatio-temporal property of the signal data, we imple-
mented another variation of the LSTM model with convolutional filters (transforma-
tions). The implemented architecture is shown in Fig. 13. ConvLSTM2D uses two-
dimensional convolutions for both input transformations and recurrent transforma-
tions. We first use two layers of convLSTM2D with 1024 and 256 filters, respectively,
and a dropout rate of 0.5 in between. We use a kernel size of (2, 2) and a stride of (2, 2)
at each ConvLSTM2D layer. Next we add two fully connected (Dense) layers having
512 and 256 nodes, respectively, after flattening the convolutional output. ReLU [23],
and tanh [3] activation functions are used for the convLSTM2D and Dense layers,
respectively. ADADELTA [59] with a learning rate of 10−4 and a decay rate of 0.9,
is used as the optimizer with categorical cross-entropy training. We achieved 98.9%
and 97.2% testing accuracy for four and eight transmitters, respectively. The accu-
racy plots and confusion matrices are given in Figs. 14 and 15, respectively. By exploiting the spatio-temporal correlation, the ConvLSTM implementation improves upon the accuracies obtained using the LSTM and GRU models, for both datasets.
We used 90%, 5%, and 5% of the data to train, validate, and test, respectively. We
ran each model for 50 epochs with early-stopping on the validation set. One epoch
consists of a forward pass and a backward pass through the implemented architec-
ture for the entire dataset. The overall accuracies of the different implementations
are shown in Table 2. We find that the implementation of convolutional layers within the recurrent structure (ConvLSTM2D) exhibits the best accuracy for transmitter classification, which clearly shows the advantage of using the spatio-temporal correlation present in the collected datasets. Fig. 16 further illustrates the achieved classification accuracies for the different implemented models.
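The 90% / 5% / 5% split described above can be sketched as follows; the shuffling and the fixed seed are illustrative assumptions, since the chapter does not describe how the split was drawn.

```python
import numpy as np

def split_dataset(X, y, seed=0):
    """Shuffle the examples and split them 90% / 5% / 5% into
    train, validation, and test sets."""
    rng = np.random.default_rng(seed)
    idx = rng.permutation(len(X))
    n_train = int(0.90 * len(X))
    n_val = int(0.05 * len(X))
    tr, va, te = np.split(idx, [n_train, n_train + n_val])
    return (X[tr], y[tr]), (X[va], y[va]), (X[te], y[te])
```

Shuffling before splitting matters here because samples collected consecutively from one transmitter are correlated in time; a contiguous split could leave some transmitters under-represented in the validation and test sets.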
Fig. 16 Comparison of testing accuracies of different types of recurrent neural networks (four / eight transmitters): LSTM 97.17% / 92.00%, GRU 97.76% / 95.30%, ConvLSTM 98.90% / 97.20%
So far, we have implemented the proposed RNN models for “homogeneous” datasets,
where the transmitter SDRs were from the same manufacturer. However, in reality the transmitters can be of different models, either from the same manufacturer or from several different manufacturers. Now, we want to explore how the accuracy of transmitter
identification would change if “heterogeneous” data (as was discussed in Sect. 5.1.2)
obtained from different types of transmitters (manufacturers) were used. From the
testing accuracies as shown in Table 4, we observe that all the RNNs perform better
when transmitters are of different models either from same or different manufacturers
and hence are fundamentally of different types. The accuracies of LSTM, GRU, and ConvLSTM increase by 5%, 3.37%, and 1.51%, respectively, when classifying heterogeneous radios rather than homogeneous ones. This confirms the intuition that radios manufactured using different processes (by different manufacturers) exhibit easily exploitable characteristics in their I/Q samples that can be implicitly learned by an RNN. The comparison of confusion matrices for all three proposed models is presented in Figs. 17, 18, and 19. The false positives and false negatives are observed to be considerably lower for the 5-HETERO results than for the 5-B210s. It is to be noted that
Table 4 Comparison of testing accuracies for different classification models for homogeneous and
heterogeneous datasets
Models 5-B210s (Acc) (%) 5-HETERO (Acc) (%) Change
LSTM 95.61 99.89 5%↑
GRU 96.72 99.97 3.37%↑
ConvLSTM 98.5 99.99 1.51%↑
Fig. 17 Confusion matrices for transmitter classification using LSTM cells for heterogeneous
dataset
Fig. 18 Confusion matrices for transmitter classification using GRU cells for heterogeneous dataset
Fig. 19 Confusion matrices for transmitter classification using ConvLSTM cells for heterogeneous
dataset
we used the same proposed RNN models for the heterogeneous data too, which implies the robustness of the proposed models.
Table 6 Comparison of our implementation with existing transmitter classification approaches
Approach #Trans SNR (dB) Acc (%) Inputs
Orthogonal component 3 20 62–71 Spurious
reconstruction modulation
(OCR) [57]
Genetic Algorithm [51] 5 25 85–98 Transients
Multifractal 8 Not mentioned 92.50 Transients
segmentation [45]
k-NN [18] 8 30 97.20 Transients
Ours 8 30 97.20 Raw signal
Ours-hetero 5 30 99.99 Raw signal
The “Inputs” column in both tables refers to the type of inputs used for the methods under consideration. Table 5 shows a comparison of our ConvLSTM-based
RNN for transmitter classification with other RNN-based implementations for sep-
arate tasks like modulation recognition and traffic sequence recognition. Table 6
establishes the efficacy of our ConvLSTM-based RNN model for the task of trans-
mitter classification in terms of testing accuracies. It is to be noted that all the other methods use expert-crafted features as inputs [18, 45, 51, 57] or work with synthetic datasets [26, 33]. Our method, on the other hand, achieves superior accuracy for both homogeneous (97.20%) and heterogeneous (99.99%) datasets using features automatically learned from the raw signal data, thereby paving the way for real-time deployment of large-scale transmitter identification systems.
In this section, we present the results of transmitter fingerprinting for varying SNR
values. We compare the accuracies for the proposed RNN models having 8 USRP
B210s with 30 dB SNR, with 3 other datasets collected at 0, 10, and 20 dB SNRs
having the same number of transmitters (8 B210s), as shown in Table 7. We achieve better accuracies with all the models for higher SNR values, which is intuitive. It is worth noting that the proposed ConvLSTM RNN model still gives more than 93% accuracy at 0 dB SNR, whereas the GRU model falls slightly below that and the LSTM model lags further behind.
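In the chapter the SNR levels were produced physically (via distance and multi-path), not synthetically; purely to illustrate the SNR definition used in Table 7, noise addition at a target SNR can be sketched as follows.

```python
import numpy as np

def add_awgn(iq, snr_db, rng=None):
    """Add complex white Gaussian noise to I/Q samples so that the
    resulting signal-to-noise ratio equals snr_db (illustration of
    the SNR definition only, not the chapter's data collection)."""
    rng = rng or np.random.default_rng()
    p_signal = np.mean(np.abs(iq) ** 2)
    p_noise = p_signal / (10 ** (snr_db / 10))       # SNR = P_s / P_n
    noise = np.sqrt(p_noise / 2) * (rng.standard_normal(len(iq))
                                    + 1j * rng.standard_normal(len(iq)))
    return iq + noise
```

The factor sqrt(p_noise / 2) splits the noise power evenly between the real (I) and imaginary (Q) components.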
Moreover, the proposed RNN models can be trained using raw signal data from any type of radio transmitter operating in both indoor and outdoor environments. We would also like to point out that though our data was collected in a lab environment, we had no control over that environment: there were other transmissions in progress, people were moving in and out of the lab, and there was considerable multi-path due to the location and design of the lab. Furthermore, the transmitter power was low, which compounded the problem further. Given this, though we describe the data as collected in a lab environment, in reality it was an uncontrolled, daily-use environment reflective of typical surroundings. Thus we can safely say that these methods will work in any real-world deployment of a large-scale radio network. In summary,
• Exploiting temporal correlation only, recurrent neural networks yield 95–97%
accuracy for transmitter classification using LSTM or GRU cells. RNN imple-
mentation with GRU cells needs fewer parameters than LSTM cells as shown in
Table 2.
• Exploiting spatio-temporal correlation, the implementation of RNN using Con-
vLSTM2D cells provides better accuracy (97–98%) for transmitter classification,
thus providing a potential tool for building automatic real world transmitter iden-
tification systems.
• The spatio-temporal correlation is more noticeable (yielding 1.5–5% improvement in classification accuracy) in the proposed RNN models for heterogeneous transmitters, i.e., different models from the same manufacturer or different models from different manufacturers.
Table 7 Accuracies for different recurrent neural network models with varying SNRs
SNR(dB) Accuracy (%)
LSTM GRU ConvLSTM
0 84.23 90.3 93.3
10 90.21 92.64 95.64
20 91.89 94.02 97.02
30 92.00 95.30 97.20
• The proposed RNN models give better accuracies with increasing SNRs of the
data collection environment. However, the ConvLSTM model is able to classify
with 93% accuracy at 0 dB SNR too, proving the robustness of spatio-temporal
property exploitation.
• We present a comparative study of the proposed spatio-temporal property-based fingerprinting against existing traditional and neural network-based models. This clearly shows that the proposed model achieves better accuracy than any of the existing methods for transmitter identification.
7 Summary
With more and more autonomous deployments of wireless networks, accurate knowl-
edge of the RF environment is becoming indispensable. In recent years, there has
been a proliferation of autonomous systems that use deep learning algorithms on
large-scale historical data. To that end, the inherent recurrent structures within the
RF historical data can also be leveraged by deep learning algorithms for reliable
future prognosis.
In this chapter, we addressed some of such fundamental challenges on how to
effectively apply different learning techniques in the RF domain. We presented a
robust transmitter identification technique by exploiting both the inherent spatial
and temporal properties of RF signal data. The testbed implementation and result
analysis prove the effectiveness of the proposed deep learning models. The future
step forward can be to apply these methods for identification of actual infrastructure
transmitters (for example FM, AM, and GSM) in real-world settings.
8 Further Reading
More details on deep learning algorithms can be found in [55]. Advanced applications
of deep learning in RF domain involving adversaries can be found in [35, 36, 40].
Deep learning applications in advanced RF fields, such as dynamic spectrum access, are discussed in [37, 38].
References
3. C.M. Bishop, Pattern Recognition and Machine Learning (Information Science and Statistics)
(Springer, 2006)
4. F. Chollet, et al., Keras: the python deep learning library (2015). https://keras.io
5. J. Chung, C. Gülçehre, K. Cho, Y. Bengio, Empirical evaluation of gated recurrent neural
networks on sequence modeling. CoRR (2014). arXiv:abs/1412.3555
6. B. Danev, S. Capkun, Transient-based identification of wireless sensor nodes, in International
Conference on Information Processing in Sensor Networks (2009), pp. 25–36
7. A. Devices, ADALM-PLUTO overview (2020). https://wiki.analog.com/university/tools/pluto
8. R. Dey, F.M. Salemt, Gate-variants of gated recurrent unit (GRU) neural networks, in 2017
IEEE 60th International Midwest Symposium on Circuits and Systems (MWSCAS) (2017), pp.
1597–1600
9. Ettus Research: USRP B200 (2020). https://www.ettus.com/all-products/ub200-kit/
10. Ettus Research: USRP B210 (2020). https://www.ettus.com/product/details/UB210-KIT/
11. Ettus Research: USRP X310 (2020). https://www.ettus.com/all-products/x310-kit/
12. GNURadio: GNU Radio (2020). https://www.gnuradio.org
13. I. Goodfellow, Y. Bengio, A. Courville, Deep Learning. MIT press (2016)
14. A. Gritsenko, Z. Wang, T. Jian, J. Dy, K. Chowdhury, S. Ioannidis, Finding a ‘New’ needle
in the haystack: unseen radio detection in large populations using deep learning, in IEEE
International Symposium on Dynamic Spectrum Access Networks (DySPAN) (2019), pp. 1–10
15. K. He, J. Sun, Convolutional neural networks at constrained time cost, in IEEE CVPR (2015)
16. S. Hochreiter, J. Schmidhuber, Long short-term memory. Neural Comput. 9(8), 1735–1780
(1997)
17. S. Ioffe, C. Szegedy, Batch normalization: accelerating deep network training by reducing
internal covariate shift. CoRR (2015). arXiv:abs/1502.03167
18. I.O. Kennedy, P. Scanlon, F.J. Mullany, M.M. Buddhikot, K.E. Nolan, T.W. Rondeau, Radio
transmitter fingerprinting: a steady state frequency domain approach, in IEEE Vehicular Tech-
nology Conference (2008), pp. 1–5
19. B.P. Lathi, Modern Digital and Analog Communication Systems, 3rd edn. (Oxford University
Press Inc, USA, 1998)
20. Y. LeCun, Y. Bengio, G. Hinton, Deep learning. Nature 521(7553), 436–444 (2015)
21. H.W. Lin, M. Tegmark, D. Rolnick, Why does deep and cheap learning work so well? J. Stat.
Phys. 168(6), 1223–1247 (2017)
22. R. Livni, S. Shalev-Shwartz, O. Shamir, On the computational efficiency of training neural
networks, in Advances in Neural Information Processing Systems (2014), pp. 855–863
23. V. Nair, G.E. Hinton, Rectified linear units improve restricted boltzmann machines, in Proceed-
ings of International Conference on International Conference on Machine Learning (2010),
pp. 807–814
24. NooElec: NESDR Mini (RTL-SDR) (2018). http://www.nooelec.com/store/sdr/sdr-receivers/nesdr-mini-rtl2832-r820t.html
25. Nuand: bladeRF 2.0 micro xA4 (2020). https://www.nuand.com/product/bladeRF-xA4/
26. T.J. O’Shea, S. Hitefield, J. Corgan, End-to-end Radio traffic sequence recognition with recur-
rent neural networks, in IEEE Global Conference on Signal and Information Processing (Glob-
alSIP) (2016), pp. 277–281
27. T. O’Shea, N. West, Radio machine learning dataset generation with GNU radio. Proc. GNU
Radio Conf. 1(1) (2016)
28. T.J. O’Shea, T.C. Clancy, R.W. McGwier, Recurrent neural radio anomaly detection. CoRR
(2016). arXiv:abs/1611.00301
29. T.J. O’Shea, J. Corgan, T.C. Clancy, Convolutional radio modulation recognition networks, in
Engineering Applications of Neural Networks (2016), pp. 213–226
30. T. O’Shea, J. Hoydis, An introduction to deep learning for the physical layer. IEEE Trans.
Cogn. Commun. Netw. 3(4), 563–575 (2017)
31. T.L. O’Shea, T. Roy, T.C. Clancy, Over-the-air deep learning based radio signal classification.
IEEE J. Sel. Top. Signal Process. 12(1), 168–179 (2018)
32. radioML: RFML 2016 (2016). https://github.com/radioML/dataset
33. S. Rajendran et al., Deep learning models for wireless signal classification with distributed
low-cost spectrum sensors. IEEE Trans. Cogn. Commun. Netw. 4(3), 433–445 (2018)
34. J.S. Ren, Y. Hu, Y.W. Tai, C. Wang, L. Xu, W. Sun, Q. Yan, Look, listen and learn-a multimodal
LSTM for speaker identification, in AAAI (2016), pp. 3581–3587
35. D. Roy, T. Mukherjee, M. Chatterjee, Machine learning in adversarial RF environments. IEEE
Commun. Mag. 57(5), 82–87 (2019)
36. D. Roy, T. Mukherjee, M. Chatterjee, E. Blasch, E. Pasiliao, RFAL: adversarial learning for
RF transmitter identification and classification. IEEE Trans. Cogn. Commun. Netw. (2019)
37. D. Roy, T. Mukherjee, M. Chatterjee, E. Pasiliao, Defense against PUE attacks in DSA networks
using GAN based learning, in IEEE Global Communications Conference (2019)
38. D. Roy, T. Mukherjee, M. Chatterjee, E. Pasiliao, Primary user activity prediction in DSA
networks using recurrent structures, in IEEE International Symposium on Dynamic Spectrum
Access Networks (2019), pp. 1–10
39. D. Roy, T. Mukherjee, M. Chatterjee, E. Pasiliao, RF transmitter fingerprinting exploiting
spatio-temporal properties in raw signal data, in IEEE International Conference on Machine
Learning and Applications (2019), pp. 89–96
40. D. Roy, T. Mukherjee, M. Chatterjee, E. Pasiliao, Detection of rogue RF transmitters using
generative adversarial nets, in IEEE Wireless Communications and Networking Conference
(WCNC) (2019)
41. H. Rutagemwa, A. Ghasemi, S. Liu, Dynamic spectrum assignment for land mobile radio
with deep recurrent neural networks, in IEEE International Conference on Communications
Workshops (ICC Workshops) (2018), pp. 1–6
42. Y.B. Saied, A. Olivereau, D-HIP: a distributed key exchange scheme for HIP-based internet of
things, in World of Wireless, Mobile and Multimedia Networks (WoWMoM) (2012), pp. 1–7
43. H. Sak, A.W. Senior, F. Beaufays, Long short-term memory based recurrent neural network architectures for large vocabulary speech recognition. CoRR (2014). arXiv:abs/1402.1128
44. K. Sankhe, M. Belgiovine, F. Zhou, L. Angioloni, F. Restuccia, S. D’Oro, T. Melodia, S.
Ioannidis, K. Chowdhury, No radio left behind: radio fingerprinting through deep learning of
Human Target Detection and
Localization with Radars Using Deep
Learning
M. Stephan · A. Santra
Infineon Technologies AG, Neubiberg, Germany
e-mail: avik.santra@infineon.com
M. Stephan (B) · G. Fischer (B)
Friedrich-Alexander-University Erlangen-Nuremberg, Erlangen, Germany
e-mail: michael.stephan@fau.de
G. Fischer
e-mail: georg.fischer@fau.de
© The Editor(s) (if applicable) and The Author(s), under exclusive license 173
to Springer Nature Singapore Pte Ltd. 2021
M. A. Wani et al. (eds.), Deep Learning Applications, Volume 2,
Advances in Intelligent Systems and Computing 1232,
https://doi.org/10.1007/978-981-15-6759-9_8
deep complex U-net model to generate human target detections directly from the
raw RDI. We demonstrate that the proposed deep residual U-net and complex U-net
models are capable of generating accurate target detections in the range-Doppler
and the range-angle domain, respectively. To train these networks, we record RDIs
from a variety of indoor scenes with different configurations and multiple humans
performing several regular activities. We devise a custom loss function and apply
augmentation strategies to generalize this model during real-time inference of the
model. We demonstrate that the proposed networks can efficiently learn to detect
and localize human targets correctly under different indoor environments in scenar-
ios where the conventional signal processing pipeline fails.
Keywords mm-wave radar sensor · Deep residual U-Net · Deep complex U-Net ·
Receiver operating characteristics · Detection strategy · DBSCAN · People
sensing · Localization · Human target detection · Occupancy detection
1 Introduction
The ever-increasing energy consumption has given rise to a search for energy-efficient
smart home technologies that can monitor and save energy, thus enhancing sustain-
ability and reducing the carbon footprint. Depending on baseline and operation,
several studies show that energy consumption can be significantly reduced in resi-
dential, commercial, or public spaces by 25–75% [1] by monitoring occupancy or
counting the number of people and accordingly regulating artificial light and HVAC
systems [2]. Frequency modulated continuous wave (FMCW) radars can provide a
ubiquitous solution to sense, monitor, and thus control the household appliances’
energy consumption.
Radar has evolved from automotive applications such as driver assistance systems,
safety and driver alert systems, and autonomous driving systems to low-cost solu-
tions, penetrating industrial and consumer market segments. Radar has been used
for perimeter intrusion detection systems [3], gesture recognition [4–6], human–
machine interfaces [7], outdoor positioning and localization [8], and indoor people
counting [9].
Radar responses from human targets are, in general, spread across Doppler due
to the macro-Doppler component of the torso and associated micro-Doppler com-
ponents due to hand, shoulder, and leg movements. With the use of higher sweep
bandwidths, the radar echoes from targets are not received as point targets but are
spread across range and are referred to as range-extended targets. Thus, human tar-
gets are perceived as doubly spread targets [10] across range and Doppler. Thereby,
human target detection using radars with higher sweep bandwidths requires several
adaptations in the standard signal processing pipeline before feeding the data into
application-specific processing, e.g., target counting or target tracking. Depending
on the application and the system used, there are two signal processing pipelines. In the first, the processing is performed in the range-Doppler domain, and the objective
is to perform target detection and localization in that same domain.
In such a case, the radar RDI processing pipeline involves moving target indication (MTI) to remove static
targets and maximal ratio combining (MRC) to integrate data across antennas. We refer to the RDI after this
stage as the raw RDI. The raw RDI is then fed into a 2D constant false alarm rate (2D-CFAR)
detection algorithm, which then detects whether a cell under test (CUT) is a valid
target or not by evaluating the statistics of the reference cells. The CFAR detector
adaptively estimates the noise floor from the statistics around the CUT and
guarantees a fixed false alarm rate, which in turn sets the threshold multiplier
by which the estimated noise floor is scaled. Further, the detections are grouped
together using a clustering algorithm such that reflections from the same human
target are detected as a single cluster and likewise for different humans. Alternately,
in the case of multiple virtual channels, the radar processing pipeline involves MTI to
remove static targets, followed by receiver angle of arrival estimation, through Capon
or minimum variance distortionless response (MVDR) beamforming to transform
the raw RDI data to range-angle image (RAI). Generating the RAI provides a means
to map the reflective distribution of the target from multi-frequency, multi-aspect
data onto the angular beams enabling to localize the target in the spatial domain.
Following RAI, the detection of targets is achieved through 2D-CFAR on the range-
angle domain, and eventually, targets are clustered using a clustering algorithm. The
former processing pipeline is typically applied in the case of single receive channel
sensor systems, where angular information is missing or in case of applications such
as activity or fall classification where relevant information lies in the range-Doppler
domain. Other applications, which require localizing the target in 2D space require
beamforming and detection in the range-angle domain.
However, human target detection in indoor environments poses further challenges,
such as ghost targets from static objects like walls, chairs, furniture, etc., and also
spurious radar responses due to multi-path reflections from multiple targets [11].
Further, strongly reflecting or closer human targets often occlude weakly reflecting or
farther human targets at the CFAR detector output. While the former phenomenon
leads to overestimating the count of humans in the room, the latter leads to
underestimating it. This results in increased false alarms and
low detection probabilities, leading to poor radar receiver operating characteristics
and inaccurate application-specific decisions based on human target counts or target
tracking.
With the advent of deep learning, a variety of computer vision tasks, namely,
face recognition [12, 13], object detection [14, 15], segmentation [16, 17], have
seen superior state-of-the-art performance advances. The book [18] gives a good
overview of basic and advanced concepts in deep neural networks, especially in
computer vision tasks. Image segmentation is a task wherein given an image, all the
pixels belonging to a particular class are extracted and indicated in the output image.
In [17], authors have proposed the deep U-Net architecture, which concatenates
feature maps from different levels, i.e., from low-level to high-level to achieve image
segmentation. In [19], authors have proposed deep residual connections to facilitate
training speed and improve classification accuracy. Recently, deep residual U-net
like network architectures have been proposed for identifying road lines from SAR
images [20]. We have re-purposed the deep residual U-Net to process raw RDIs
and generate target detected RDIs or RAIs. The target detected RDIs and RAIs are
demonstrated to suppress reflections from ghost targets, reject spurious targets due to
multi-path reflections, avoid target occlusions and achieve accurate target clustering
in the detected RDIs and RAIs, thus resulting in reliable and accurate human target
detections.
In [21], authors have proposed a fully convolutional neural network to achieve
object detection and estimation of a vehicle target in 3D space by replacing the
conventional signal processing pipeline. In [22], authors have proposed a deep neu-
ral network to distinguish the presence or absence of a target and demonstrate its
improved performance compared to conventional signal processing. In [23], we have
proposed using a deep residual U-net architecture to process raw RDIs into processed
RDIs, achieving reliable target detections in the range-Doppler domain.
In this contribution, we use a 60-GHz frequency modulated continuous wave
(FMCW) radar sensor to demonstrate the performance of our proposed solution
in both the range-Doppler and range-angle domains. To handle our specific challenges
and use-case, we define and train our proposed deep residual U-net model using
an appropriate loss function for processing in the output range-Doppler domain.
We also propose a complex U-Net model to process the raw RDI from two receive
antennas to construct an RAI. For the complex neural network layers, we use the
TensorFlow Keras implementation by Dramsch and Contributors [24] of the complex
layers as described in [25]. We demonstrate
the performance of our proposed system with real data with up to four persons in
indoor environments for both cases and compare the results to the conventional signal
processing approach. The paper is outlined as follows. Section 2 presents the sys-
tem design, with details of the hardware chipset in Sect. 2.1, conventional processing
pipeline in Sect. 2.2, challenges and contributions in Sect. 2.3. We present the models
of our proposed residual U-Net and complex U-Net architecture in Sects. 3 and 4.
We present and describe the dataset in Sect. 5.1, the loss function in Sect. 5.2 and the
design considerations for training in Sect. 5.3. We present the results and discussion
in Sect. 6, and we conclude in Sect. 7, also outlining possible future directions.
2 System Design
Fig. 1 a Infineon’s BGT60TR24B 60-GHz radar sensor. b Representational figure of the radar
scene and functional block diagram of the FMCW radar RF signal chain depicting 1 TX, 1 RX
channel
Owing to the FMCW waveform and its ambiguity function properties, accurate
range and velocity estimates can be obtained by decoupling them during the generation
of the RDIs. The IF signal of a chirp with NTS = 256 samples is
received; PN consecutive chirps are collected and arranged in the form of a 2D
matrix with dimensions PN × NTS. The RDI is generated in two steps. The
first step involves calculating and subtracting the mean along fast time, followed by
applying a 1D window function, zero-padding, and then a 1D Fast Fourier Transform
(FFT) along fast time for all PN chirps to obtain the range transformation. The
fast time refers to the dimension of NTS, which represents the time within a chirp. In
the second step, the mean along slow time is calculated and subtracted, a 1D window
function and zero-padding are applied, followed by a 1D FFT along slow time to obtain
the Doppler transformation for all range bins. The slow time refers to the dimension
along PN, which represents the inter-chirp time. The mean subtraction across the
fast time removes the Tx-Rx leakage, and the subtraction across slow time removes
the reflections from any static object in the field of view of the radar. Since short
chirp times are used, the frequency shift along fast time is mainly due to the two-way
propagation delay from the target, which is due to the distance of the target to the
sensor. The amplitude and phase shift across slow time is due to the Doppler shift of
the target.
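The two-step RDI generation described above can be sketched with NumPy as follows. The function name, the Hann window choice, and the FFT sizes are illustrative assumptions, not the authors' implementation:

```python
import numpy as np

def compute_rdi(iq_frame, range_fft=256, doppler_fft=64):
    """Two-step range-Doppler image (RDI) generation sketch.

    iq_frame: (PN, NTS) matrix of PN chirps x NTS samples per chirp.
    Mean subtraction along fast time removes Tx-Rx leakage; along
    slow time it removes reflections from static objects (MTI).
    """
    pn, nts = iq_frame.shape
    # Step 1: range transform along fast time (per chirp).
    x = iq_frame - iq_frame.mean(axis=1, keepdims=True)
    x = x * np.hanning(nts)[None, :]            # 1D window
    rng = np.fft.fft(x, n=range_fft, axis=1)    # zero-padded FFT
    # Step 2: Doppler transform along slow time (per range bin).
    y = rng - rng.mean(axis=0, keepdims=True)
    y = y * np.hanning(pn)[:, None]
    rdi = np.fft.fftshift(np.fft.fft(y, n=doppler_fft, axis=0), axes=0)
    return rdi                                   # (doppler_fft, range_fft)

frame = np.random.randn(64, 256) + 1j * np.random.randn(64, 256)
rdi = compute_rdi(frame)
```

The `fftshift` along the Doppler axis centers zero velocity, matching the symmetric velocity axes of the RDIs shown later.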
Based on the application and radar system, the processing pipeline can be either
of the following:
• raw absolute RDI −→ 2D CFAR −→ DBSCAN −→ Processed RDI
• raw complex RDI (multiple channels) −→ MVDR −→ raw RAI −→ 2D CFAR
−→ DBSCAN −→ Processed RAI
In the first pipeline, the target detection and clustering operation is performed on
the range-Doppler domain. The detection and clustering is applied on the absolute
raw RDI, and the output is the processed RDI. This is typically applied in the case of
1 Tx-1 Rx radar sensors or when the application is to extract target Doppler signatures
for classification, such as human activity classification, vital sensing, etc. While in
the alternate pipeline, the raw complex RDI from multiple channels is processed
through the MVDR algorithm to generate the RAI first, followed by 2D CFAR and
clustering operation in that domain. The latter pipeline is employed where target
detection through localization in 2D space is necessary for the application.
In the former pipeline, in case of multiple receive channels the raw RDIs are
optionally combined through maximal ratio combining (MRC) to gain diversity and
improve signal quality. The objective of MRC is to construct single RDIs by weighted
combinations of the RDIs across multiple channels. The gains for the weighted
averaging are determined by estimating the signal-to-noise ratio (SNR) for each RDI
across antennas. The effective RDI is computed as
I = \frac{\sum_{rx=1}^{N_{Rx}} g^{rx} \, |I^{rx}|}{\sum_{rx=1}^{N_{Rx}} g^{rx}} \qquad (3)
where $I^{rx}$ is the complex RDI of the $rx$-th receive channel, and the gain is adaptively
calculated as

g^{rx} = \frac{NTS \times PN \cdot \max\{|I^{rx}|^2\}}{\sum_{l=1}^{NTS} \sum_{m=1}^{PN} |I^{rx}(m, l)|^2 - \max\{|I^{rx}|^2\}} \qquad (4)
where max{·} returns the maximum value of the 2D function, and $g^{rx}$ represents
the estimated SNR at the $rx$-th receive channel. Thus, the $(PN \times NTS \times N_{Rx})$ RDI tensor
is transformed into a $(PN \times NTS)$ RDI matrix. Alternately, in the latter pipeline,
instead of MRC, the raw complex RDI is converted to an RAI using the Capon or MVDR
algorithm.
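A NumPy sketch of the MRC combination of Eqs. (3) and (4) might look as follows; function and variable names are our own illustration:

```python
import numpy as np

def mrc_combine(rdis):
    """Maximal ratio combining of per-antenna RDIs, per Eqs. (3)-(4).

    rdis: (N_rx, PN, NTS) complex RDI tensor. The per-channel gain is
    an SNR estimate: peak power over total power minus the peak.
    """
    gains = []
    for rdi in rdis:
        p = np.abs(rdi) ** 2
        peak = p.max()
        gains.append(rdi.size * peak / (p.sum() - peak))  # Eq. (4)
    gains = np.asarray(gains)
    # Eq. (3): gain-weighted average of the channel magnitudes.
    num = sum(g * np.abs(rdi) for g, rdi in zip(gains, rdis))
    return num / gains.sum()

rdis = np.random.randn(2, 64, 256) + 1j * np.random.randn(2, 64, 256)
combined = mrc_combine(rdis)  # (PN, NTS) real-valued effective RDI
```

Note that a channel whose energy is concentrated in a few strong cells receives a larger gain, so the combined RDI favors the antennas with the better SNR.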
a_{n_{Rx}}(\theta) = \exp\left(-j 2\pi \frac{d_{n_{Rx}} \sin(\theta)}{\lambda}\right); \quad n = 1, 2 \qquad (6)
where λ is the wavelength of the transmit signal. The two received de-ramped beat
signals are used to resolve the relative angle θ of the scatterer.
The azimuth imaging profile for each range bin can be generated using the Capon
spectrum from the beamformer. The Capon beamformer is computed by minimizing
the variance/power of noise while maintaining a distortionless response toward a
desired angle. The corresponding quadratic optimization problem is

\min_{w} \; w^H C w \quad \text{subject to} \quad w^H a(\theta) = 1

where C is the covariance matrix of the noise. The above optimization has a closed-form
expression given as $w_{\text{capon}} = \frac{C^{-1} a(\theta)}{a^H(\theta) C^{-1} a(\theta)}$, with θ being the desired angle.
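Under these definitions, a minimal NumPy sketch of the MVDR weight computation and the two-element steering vector of Eq. (6) could read as follows; the function names and the half-wavelength element spacing are illustrative assumptions:

```python
import numpy as np

def mvdr_weights(C, a):
    """Capon/MVDR weights: w = C^{-1} a / (a^H C^{-1} a).

    Minimizes w^H C w subject to the distortionless constraint
    w^H a(theta) = 1, with C the noise covariance matrix.
    """
    Cinv_a = np.linalg.solve(C, a)
    return Cinv_a / (a.conj() @ Cinv_a)

def steering(theta, d_over_lambda=0.5, n_rx=2):
    """Uniform linear array steering vector in the spirit of Eq. (6)."""
    n = np.arange(n_rx)
    return np.exp(-2j * np.pi * n * d_over_lambda * np.sin(theta))

a = steering(np.deg2rad(20.0))
w = mvdr_weights(np.eye(2), a)
# The RAI is then obtained by scanning theta over the field of view
# and evaluating the Capon spectrum for every range bin.
```

The distortionless constraint guarantees `w.conj() @ a == 1` for the steered angle, so power arriving from that direction passes unattenuated while noise power is minimized.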
2D CFAR: Following the earlier operation in each pipeline, the raw RDI in the
former case and raw RAI in the latter case is fed into a CFAR detection algorithm to
generate a hit matrix that indicates the detected targets. The target detection problem
can be expressed as
I_{det}(\mathrm{cut}) = \begin{cases} 1 & \text{if } |I(\mathrm{cut})|^2 > \mu\sigma^2 \\ 0 & \text{if } |I(\mathrm{cut})|^2 \le \mu\sigma^2 \end{cases} \qquad (9)
where $I_{det}(\mathrm{cut})$ is the detection output for the cell under test (CUT), depending on
the input image I. The product of μ and σ², the estimated noise variance, represents the
threshold for detection. The threshold multiplier μ for the CFAR detection is set by
an acceptable probability of false alarm, which for cell-averaging CFAR (CA-CFAR)
is provided as
\mu = N \left( P_{fa}^{-1/N} - 1 \right) \qquad (10)
where $P_{fa}$ is the probability of false alarm, and N is the number of so-called
“reference” cells used for noise power estimation. The noise variance σ² is estimated
from the “reference” cells around the CUT. Thus, the estimated noise variance takes
into account the local interference-plus-noise power around the CUT.
While CA-CFAR is the most common detector used in case of point targets, in
case of doubly spread targets, such as humans observed with wide radar bandwidth, it leads to
poor detections, since the spread targets are also present in the reference cells, leading
to high noise power estimates and missed detections. Additionally, CA-CFAR
elevates the noise threshold near a strong target, thus occluding nearby weaker targets.
Alternately, for such doubly spread targets, order-statistics CFAR (OS-CFAR)
is used to avoid these issues, since the ordered statistic is robust to outliers, in this
case the target’s spread, in the reference cells. Hence, instead of the mean power in the
reference cells, the k-th ordered value is selected as the estimated noise variance σ².
A detailed description of OS-CFAR can be found in [27].
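The CA-CFAR threshold multiplier of Eq. (10) and a one-dimensional OS-CFAR pass can be sketched as follows; the window sizes and the ordered index k are illustrative, and a real system would apply this in 2D over the RDI/RAI:

```python
import numpy as np

def ca_cfar_multiplier(pfa, n):
    """Threshold multiplier for cell-averaging CFAR, Eq. (10)."""
    return n * (pfa ** (-1.0 / n) - 1.0)

def os_cfar_1d(power, guard=2, train=8, k=6, mu=4.0):
    """Order-statistics CFAR sketch along one dimension.

    power: 1D array of cell powers |I|^2. For each cell under test
    (CUT), the k-th smallest of the 2*train reference cells (guard
    cells excluded) estimates the noise level sigma^2; a detection
    is declared when power > mu * sigma^2.
    """
    n = len(power)
    hits = np.zeros(n, dtype=bool)
    for cut in range(n):
        left = power[max(0, cut - guard - train):max(0, cut - guard)]
        right = power[cut + guard + 1:cut + guard + 1 + train]
        ref = np.concatenate([left, right])
        if len(ref) < k:
            continue
        sigma2 = np.sort(ref)[k - 1]   # k-th ordered statistic
        hits[cut] = power[cut] > mu * sigma2
    return hits

power = np.ones(64)
power[32] = 100.0                      # one strong target cell
hits = os_cfar_1d(power)
```

Because the ordered statistic ignores the few strong reference cells, a spread target leaking into the reference window does not inflate the noise estimate as it would with cell averaging.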
Figure 5 illustrates the network architecture. It has an encoder and a decoder path,
each with three resolution steps. In the encoder path, each block contains two 3 × 3
convolutions, each followed by a rectified linear unit (ReLU), and a 3 × 3 max pooling
with strides of two in every dimension. In the decoder path, each block consists of
an up-convolution of 2 × 2 with strides of two in each dimension, followed by a
ReLU and two 3 × 3 convolutions, each followed by a ReLU. The up-convolutions
are implemented as an upsampling layer followed by a convolutional layer. Skip
connections from layers of equal resolution in the encoder path provide the high-resolution
features to the decoder path, as in the standard U-Net. The bridge connection
between the encoder and the decoder network consists of two convolutional layers,
each followed by a ReLU, with a dropout layer in between. The dropout rate is set
to 0.5 during training to reduce overfitting. In the last layer, a 1 × 1 convolution
Fig. 3 a Traditional processing pipeline with OS-CFAR detection and DBSCAN clustering to generate target detected
RDIs. b Processing pipeline using the proposed deep residual U-Net to suppress ghost targets and
multi-path reflections, mitigate target occlusions, and achieve accurate clustering

Fig. 5 Proposed RDI presence detection architecture for a depth 3 network. Each box corresponds
to one or more layers

Fig. 6 Proposed RAI presence detection and localization architecture for a depth 4 network. Each
box corresponds to one or more layers
reduces the number of output channels to the number of classes, which is two
(target present/absent) in our case. The architecture has 40,752 trainable parameters in total.
As suggested in the literature, bottlenecks are prevented by doubling the number of
channels before max pooling [30]. We also adopt the same scheme in the decoder
path to avoid bottlenecks. The input to the network is a 128 × 32 × 1 raw RDI. Our
output is a 128 × 32 × 1 image with pixel values between 0 and 1, representing the
probability of target presence/absence in each pixel.
Figure 6 shows the network architecture used for the localization task. While it looks
quite similar to Fig. 5, it differs significantly in the details. It is still a U-Net like archi-
tecture, but it uses fully complex operations and 3D-convolutions. A brief description
of complex convolutional layer and complex activation layer is provided below:
where $A_{i-1,k}$, $B_{i-1,k}$ denote the real and imaginary parts of the feature map at the
i-th layer and k-th map after the convolution operation, and $C_{i,k}$, $D_{i,k}$ denote the
real and imaginary components of the k-th kernel, respectively. The
filter dimensions are $Q_i \times Q_i \times K_i$, where $Q_i$ is applied on the range-Doppler
dimension and $K_i$ along the spatial channels.
2. Complex Activation Function:
The complex 2D layers progressively extract deeper feature representation. The
activation function introduces non-linearity into the representation. The complex
ReLU is implemented as

\mathbb{C}\mathrm{ReLU}(z) = \mathrm{ReLU}(\Re(z)) + j \, \mathrm{ReLU}(\Im(z))

In effect, the complex ReLU maps second-quadrant data to the positive half of
the imaginary axis, third-quadrant data to the origin, and fourth-quadrant data
to the positive half of the real axis.
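The complex ReLU quadrant mapping and the real/imaginary decomposition underlying the complex convolution can be illustrated with NumPy; the multiply-accumulate is shown element-wise rather than as a full convolution, and the function names are our own:

```python
import numpy as np

def complex_relu(z):
    """Complex ReLU: ReLU applied to real and imaginary parts separately.

    Quadrant II (Re<0, Im>0) maps to the positive imaginary axis,
    quadrant III to the origin, quadrant IV to the positive real axis.
    """
    return np.maximum(z.real, 0.0) + 1j * np.maximum(z.imag, 0.0)

def complex_multiply_accumulate(x, w):
    """Core identity behind complex convolution, per feature map:
    (A + jB)(C + jD) = (A*C - B*D) + j(A*D + B*C),
    with A, B the real/imag feature maps and C, D the real/imag
    kernel components. In a complex layer, each real product is a
    real-valued convolution."""
    A, B = x.real, x.imag
    C, D = w.real, w.imag
    return (A * C - B * D) + 1j * (A * D + B * C)

z = np.array([1 + 1j, -1 + 1j, -1 - 1j, 1 - 1j])  # one point per quadrant
out = complex_relu(z)  # [1+1j, 1j, 0, 1]
```

Implementing the complex product as four real operations is what allows standard real-valued convolution kernels of a deep learning framework to realize the complex layers.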
In the proposed architecture, the complex max pooling is avoided since it leads
to model instabilities, and thus progressive range-Doppler dimension reduction was
achieved through strided convolutions. In the encoder path, each block contains two
3 × 3 × 1 complex convolutional layers, each followed by a complex ReLU, and a
3 × 3 × 1 strided convolutional layer with a 2 × 2 × 1 stride. In the decoder path, the
up-convolutional layers are implemented as a 2 × 2 × 1 upsampling layer followed
by a 2 × 2 × 1 convolutional layer. Between each block of the same depth in the
encoder path and the decoder path are skip connections to provide the feature location
5 Implementation Details
5.1 Dataset
To create the labeled dataset, the raw RDIs are processed with the respective tradi-
tional processing pipelines, with MVDR beamforming to create the labeled RAIs
and MRC instead to create the labeled RDIs. During recording, a camera was used
as a ground-truth reference; with its feedback, we removed ghost targets and multi-path
reflections and added detections whenever targets were occluded by other
humans or by static humans close to the wall. This was done by changing the parameters
for the detection and clustering algorithms in the conventional pipeline, so that the
probability of detection approaches 100%, also resulting in a high probability of false
alarm. This means decreasing the CFAR scaling factor in case of target occlusion,
and reducing/increasing the maximum neighbor distance for cluster detection with
DBSCAN in case of merged/separated targets. All falsely detected clusters are then
manually removed using the camera data as a reference. The described process is
relatively simple for one-target measurements, as the correct cluster is generally the
one closest in range to the radar sensor. The dataset comprises recordings of one to
four humans in the room.
The RDIs are augmented to enlarge the dataset and achieve a broader generalization
of the model. Due to the sharply increasing difficulty of creating labeled
data with multiple humans present, we synthetically computed RDIs with multiple
targets by superimposing several raw one-target RDIs after translation and rotation.
With this technique, a large number of RDIs, limited only by the number of possible
combinations of the one-target measurements, can be synthesized. Some caution is
required in having a large enough basis of one-target measurements, as the network
may otherwise overfit on the number of possible target positions. To increase the
pool of one-target measurements for the RAI presence detection and localization
task, the one-target measurements were augmented by multiplying the RDIs from
both antennas by the same complex value, with an amplitude close to one and a
random phase. This operation changes the input values but does not change the
estimated ranges and angles, aside from a negligible range error on the order of half
a wavelength.
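The described random-phase augmentation can be sketched as follows; the amplitude-jitter range and the function name are our assumptions:

```python
import numpy as np

rng = np.random.default_rng(0)

def augment_random_phase(rdi_two_ant, amp_jitter=0.05):
    """Multiply BOTH antenna RDIs by the same complex factor with an
    amplitude close to one and a random phase. The relative phase
    between the antennas -- and hence the estimated angle -- is
    preserved, while the raw input values change."""
    amp = 1.0 + rng.uniform(-amp_jitter, amp_jitter)
    phase = rng.uniform(0.0, 2.0 * np.pi)
    return rdi_two_ant * (amp * np.exp(1j * phase))

rdis = rng.standard_normal((2, 64, 32)) + 1j * rng.standard_normal((2, 64, 32))
aug = augment_random_phase(rdis)
```

Because a single common factor scales both channels, the inter-antenna phase difference that MVDR beamforming relies on is left intact.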
Given a set of training pairs $(I_i, P_i)$ of raw RDIs and the corresponding ground-truth
processed RDIs or RAIs, the training procedure estimates the parameters W of the
network such that the model generalizes to reconstruct accurate RDIs or RAIs. This
is achieved by minimizing the loss between the true RDIs/RAIs $P_i$ and those
generated by $h(I_i; W)$. Here, we use a weighted combination of focal loss [31] and
hinge loss, as given in Eq. (13), as the loss function to train the proposed models
HL(p) = 1 - y(2p - 1)
FL(p_t) = -(1 - p_t)^{\gamma} \log(p_t)
L(p_t) = \alpha \left( FL(p_t) + \eta \, HL(p) \right) \qquad (13)
The variables y ∈ {±1}, p ∈ [0, 1], and pt ∈ [0, 1], specify the class label, the
estimated probability for the class with label y = 1, and the probability that a pixel
was correctly classified, as defined in Eq. (14) as
p_t = \begin{cases} p & \text{if } y = 1 \\ 1 - p & \text{otherwise} \end{cases} \qquad (14)
The parameters γ , η, and α influence the shape of the focal loss, the weight of the
hinge loss, and the class weighting, respectively. For training the deep residual U-Net
model to reconstruct the processed range-Doppler image, these parameters were
chosen as γ = 2, α = 0.25, and η was increased step-wise to η = 0.15. For
training the deep complex U-Net model to reconstruct the processed range-angle
image, the same parameters were used, except that η = 0. The chosen parameters led
to the best learned model in terms of F1 score for both models.
The reason for using the focal loss is the class imbalance in the training data due
to the nature of the target setup shown in Fig. 2a. In most cases, the frequency of
pixels labeled as “target absent” is much higher than the frequency of those labeled
with “target present”. Additionally, the class frequencies may vary widely between
single training samples, especially as the training set contains RDIs with different
numbers of targets. The focal loss places a higher weight on the cross-entropy loss for
misclassified pixels and a much lower weight on well-classified ones. Thus, pixels
belonging to the rarer class are generally weighted higher. The hinge loss is added
to the focal loss with a factor η to force the network to make clearer classification
decisions. The value for η is chosen in such a way that the focal loss dominates the
training for the first few epochs before the hinge loss becomes relevant.
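A per-pixel NumPy sketch of the combined loss in Eqs. (13) and (14), written with the sign convention of the standard focal loss [31], follows; the epsilon guard is our addition for numerical stability:

```python
import numpy as np

def combined_loss(p, y, gamma=2.0, alpha=0.25, eta=0.15, eps=1e-7):
    """Weighted focal + hinge loss per pixel, Eqs. (13)-(14) sketch.

    p: predicted probability of the class y = 1, in [0, 1].
    y: ground-truth labels in {-1, +1}.
    """
    pt = np.where(y == 1, p, 1.0 - p)                # Eq. (14)
    focal = -((1.0 - pt) ** gamma) * np.log(pt + eps)
    hinge = 1.0 - y * (2.0 * p - 1.0)
    return alpha * (focal + eta * hinge)             # Eq. (13)

p = np.array([0.9, 0.1, 0.6])
y = np.array([1, -1, 1])
loss = combined_loss(p, y)
```

The `(1 - pt)**gamma` factor collapses the loss of well-classified pixels toward zero, which is what counteracts the "target absent" class dominating the image.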
The weight initialization of the network is performed with a Xavier uniform initializer,
which draws samples from a uniform distribution within [−limit, limit], where
the limit is calculated as the square root of six divided by the total number of
input and output units of the weight tensor. The respective biases were initialized as
zeros. For the backpropagation, the Adam [32] optimizer was used with the default
learning rate (alpha) of 0.001; the exponential decay rate for the first moment (beta1)
was set to 0.9 and that for the second moment (beta2) to 0.999. The epsilon that
counters divide-by-zero problems was set to 10^{-7}.
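The described Xavier uniform initialization can be sketched as follows; the fan computation shown is for a plain 2D weight matrix, and the function name and seed are illustrative:

```python
import numpy as np

def xavier_uniform(shape, seed=0):
    """Xavier/Glorot uniform initialization sketch: samples from
    U[-limit, limit] with limit = sqrt(6 / (fan_in + fan_out))."""
    rng = np.random.default_rng(seed)
    fan_in, fan_out = shape
    limit = np.sqrt(6.0 / (fan_in + fan_out))
    return rng.uniform(-limit, limit, size=shape)

w = xavier_uniform((64, 32))  # limit = sqrt(6/96) = 0.25
```

Keeping the initial weight variance tied to the layer fan-in and fan-out keeps activation magnitudes roughly constant across the depth of the U-Net at the start of training.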
Owing to the way the labeled data was created, it is likely that only parts of each human are classified as
“target present” in the RDI/RAI. Therefore, when defining missed detections and
false alarms, small positional errors and variations in the cluster size should be
discounted. We did this by computing the center of mass for each cluster in the labeled
data and the processed RDIs. Targets are identified as correctly detected only if the
distance between the cluster center of masses of the processed and the labeled RDIs
is smaller than 20 cm in range, and 1.5 m/s in velocity. Additionally, we enforce that
each cluster in the network output can only be assigned to exactly one cluster in the
labeled data and vice-versa.
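The described cluster matching can be sketched as a greedy one-to-one assignment; the thresholds follow the text, while the greedy matching order and the function name are our assumptions:

```python
def match_clusters(pred_centers, label_centers,
                   max_range_m=0.2, max_velocity_mps=1.5):
    """Greedy one-to-one matching of cluster centers of mass.

    A predicted cluster counts as a correct detection if it lies
    within 20 cm in range and 1.5 m/s in velocity of a not yet
    matched labeled cluster; each cluster is assigned at most once.
    Returns (true positives, false alarms, missed detections).
    """
    used = set()
    tp = 0
    for pr, pv in pred_centers:
        for i, (lr, lv) in enumerate(label_centers):
            if i in used:
                continue
            if abs(pr - lr) < max_range_m and abs(pv - lv) < max_velocity_mps:
                used.add(i)
                tp += 1
                break
    fp = len(pred_centers) - tp   # false alarms
    fn = len(label_centers) - tp  # missed detections
    return tp, fp, fn

tp, fp, fn = match_clusters([(1.0, 0.5), (3.0, -2.0)],
                            [(1.1, 0.4), (2.0, 1.0)])
```

The reported F1 score then follows directly from these counts as F1 = 2·tp / (2·tp + fp + fn).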
Table 2 presents the performance of the proposed approach in terms of F1 score
in comparison to the traditional processing chain. Results for the proposed approach
are given for a depth three network (NN_d3), as shown in Fig. 5, and for a deeper
network with five residual blocks (NN_d5) in both encoder and decoder. For
the depth 5 network, two more residual blocks were added in both the encoder and
the decoder compared to the structure displayed in Fig. 5. The input to the bridge
block between the encoder and decoder paths then has the dimension 4 × 1 × 125. The
proposed approach gives a much better detection performance, with an F1 score of 0.89,
than the traditional processing pipeline with an F1 score of 0.71. The deeper network
(NN_d5), with around 690 thousand trainable parameters, shows further
improvements in detection performance with an F1 score of 0.91.
Figure 8 presents the ROC curve of the proposed U-net architecture with depth
three and five in comparison with the traditional processing. The ROC curves for the
depth three and depth five networks are obtained by varying a hard threshold parameter,
which describes the minimum pixel value for a point to be classified as a target
in the NN output. For the ROC curve representing the traditional processing chain, the
scaling factor of the CFAR detection method was varied. The curves are extrapolated
for detection probabilities close to one. As depicted in the ROC curve, the proposed
residual U-net architecture provides a much better AUC performance compared to the
traditional processing pipeline. Similarly, Fig. 9 shows a better AUC performance
for the complex U-net compared to the traditional processing.
Figure 10a presents the raw RDI, Fig. 10b the detected RDI using the traditional
approach, and Fig. 10c the detected RDI using the proposed deep residual U-Net
approach for a synthetic four-target measurement. Originally, the
Table 2 Comparison of the detection performance of the traditional pipeline with the proposed
U-net architecture with a depth of 3 and a depth of 5 for localization in the range-Doppler image

Approach               | Description         | F1-score | Model size
Traditional            | OS-CFAR with DBSCAN | 0.71     | –
Proposed U-net depth 3 | Proposed loss       | 0.89     | 616 kB
Proposed U-net depth 5 | Proposed loss       | 0.91     | 2.8 MB
[ROC curves: probability of detection (pD) versus probability of false alarm (pFA) for NN_d5, NN_d3, and the traditional pipeline]
Fig. 8 Radar receiver operating characteristics (ROC) comparison between the proposed residual
deep U-net and the traditional signal processing approach
Fig. 9 Radar receiver operating characteristics (ROC) comparison between the proposed complex
deep U-net and the traditional signal processing approach. Dashed parts indicate extrapolation
NN outputs classification probabilities between zero and one for each pixel. To get
the shown output, we used a hard threshold of 0.5 on the NN output, so that any
pixels with corresponding values of 0.5 and higher are classified as belonging to a
target. With the traditional approach, one target, at around 2 m distance, and −2 m/s
in velocity, was mistakenly split into two separate clusters. Additionally, the two
targets between 2 and 3 m were completely missed.
Fig. 10 a Raw RDI image with four human targets, b processed RDI using the traditional approach
wherein one target is split and two targets are occluded, c processed RDI using proposed approach
wherein all targets are detected accurately
Fig. 11 a Raw RDI image with four human targets, b processed RDI using the traditional approach
wherein one target is occluded, c processed RDI using proposed approach wherein all targets are
detected accurately
Fig. 12 a Raw RDI image with four human targets, b processed RDI using the traditional approach
wherein two targets are merged into one by the DBSCAN clustering algorithm in addition to two
ghost targets, c processed RDI using proposed approach wherein all targets are detected accurately
and with proper clustering
Figure 11a–c presents the target occlusion problem on synthetic data for the conventional
processing chain. While the traditional approach misses one target at around
3 m distance, the proposed approach in Fig. 11c is able to reliably detect
all the targets. Figures 12a–c and 13a–c show two further scenarios that the
traditional approach struggles with.
Human Target Detection and Localization with Radars … 193
[Figure 13: three panels, range in m (0–4) vs. velocity in m/s (−4 to 4): (a) Raw RDI, (b) Processed RDI, traditional approach, (c) Processed RDI, proposed approach]
Fig. 13 a Raw RDI image with four human targets, b processed RDI using the traditional approach
wherein a ghost target appears, c processed RDI using proposed approach wherein all targets are
detected accurately
In Fig. 12b, the two detected targets around the one-meter mark are too close
together for the DBSCAN clustering algorithm to detect them as two distinct
clusters. Therefore, the two clusters merge and one target is missed. Figure 13b
showcases the ghost target problem: here, a target at around 4 m distance from the
radar sensor was wrongly detected. In both cases, the U-Net-based approach correctly
displays all the distinct targets, as seen in Figs. 12c and 13c. However, it has to be
mentioned that, while our proposed approach outperforms the traditional processing
chain, missed detections and false alarms may still occur in similar scenarios. From
our experiments, we have observed that the proposed approach excels at discarding
multi-path reflections and ghost targets caused by reflections from static objects
in the scene, and does well in preventing split or merged targets, but shows little
improvement for occluded targets when many humans stand in front of each other.
The cause lies in how most of the training data was created: one-target measurements
were superimposed on each other in order to synthesize the multi-target RDIs.
In our experiments, we noticed that the loss function plays a crucial role in achieving
excellent detection results. While the current loss function deals well with the
class imbalance problem and accelerates training for the more important pixels in the
RDI, it could be improved by penalizing merged or split targets more severely,
and by allowing small positional errors in the clusters without increasing the loss.
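Table 3 names focal loss as the training objective; the chapter's custom modifications are not reproduced here, but a generic binary focal loss (Lin et al. [31]) can be sketched as follows, with illustrative values:

```python
import numpy as np

def focal_loss(y_true, y_pred, gamma=2.0, alpha=0.25, eps=1e-7):
    """Binary focal loss: down-weights easy pixels so the rare target
    pixels dominate the gradient. gamma and alpha are the usual defaults."""
    y_pred = np.clip(y_pred, eps, 1.0 - eps)
    p_t = np.where(y_true == 1, y_pred, 1.0 - y_pred)
    alpha_t = np.where(y_true == 1, alpha, 1.0 - alpha)
    return float(np.mean(-alpha_t * (1.0 - p_t) ** gamma * np.log(p_t)))

# An easy, correctly classified background pixel contributes far less
# than a hard, misclassified target pixel.
easy = focal_loss(np.array([0.0]), np.array([0.05]))
hard = focal_loss(np.array([1.0]), np.array([0.05]))
print(easy < hard)  # prints: True
```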
The evaluation of the proposed method for target detection and localization in
terms of range and angle is again done via ROC-curves and the F1-scores. We use
the same set of test measurements as described earlier, but with the complex RDIs
from two receiving antennas. For the evaluation of the range-angle output maps of the
neural network, a simple clustering algorithm is used on these output images. In our
case, we use DBSCAN with small values for the minimum number of neighbors and
the maximum neighbor distance so that every nearly continuous cluster is recognized
as one target. We then compute the center of mass of each cluster and compare
it to the centers of mass computed from our labeled data, as described earlier.
A target still counts as detected if the distance between the two cluster
centers of mass is below a threshold; here, 7.8° in angle or 37.5 cm in range.
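A minimal sketch of this clustering-and-matching step, using scikit-learn's DBSCAN on a hypothetical binary output map (the blob positions and the matching threshold are illustrative):

```python
import numpy as np
from sklearn.cluster import DBSCAN

# Hypothetical binary output map: two compact blobs of detected pixels.
out = np.zeros((16, 16), dtype=np.uint8)
out[2:4, 2:4] = 1          # blob with center of mass (2.5, 2.5)
out[10:13, 11:13] = 1      # blob with center of mass (11.0, 11.5)

# Cluster nearly contiguous pixels: small eps and min_samples so every
# connected blob is recognized as one cluster.
coords = np.argwhere(out == 1)
labels = DBSCAN(eps=1.5, min_samples=2).fit_predict(coords)

# Center of mass of each cluster (noise points, label -1, are ignored).
centroids = {k: coords[labels == k].mean(axis=0) for k in set(labels) if k != -1}

# A detection counts as correct if its centroid lies within a distance
# threshold of a ground-truth centroid (illustrative threshold of 1 pixel).
truth = [np.array([2.5, 2.5]), np.array([11.0, 11.5])]
hits = sum(
    any(np.linalg.norm(c - t) < 1.0 for c in centroids.values()) for t in truth
)
print(len(centroids), hits)  # prints: 2 2
```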
Table 3 shows a comparison of F1-scores for the proposed method and a traditional
Table 3 Comparison of the detection performance of the traditional pipeline with the proposed
U-net architecture with a depth of 4 for localization in the range-angle image
Approach                  Description            F1-score   Model size
Traditional               OS-CFAR with DBSCAN    0.61       –
Complex U-net depth 3     Focal loss             0.72       529 kB
Complex U-net depth 4     Focal loss             0.77       1.23 MB
signal processing chain. Compared to the F1-score of 0.61 for the classical method,
the proposed method shows a clear improvement with an F1-score of 0.77. The
ROC curves in Fig. 9 stop at values smaller than one because of how the detections
were evaluated; the dashed lines indicate an extrapolation. If the threshold or the
CFAR scaling factor is set close to zero, then every pixel is detected as a target,
which is then identified as one single huge cluster by the DBSCAN algorithm.
Therefore, a probability of false alarm of 1 is not achievable. In the evaluation,
we saw that the proposed method performs better, as expected, for fewer targets,
while its performance worsens mainly with rising target density. However, even when
the neural network is trained with only one- and two-target measurements, it is
still able to correctly identify the positions of three or four targets in many cases.
Comparing the F1-scores in Table 3 to those in Table 2, the network appears to find
the range-angle localization task harder. This has several reasons. First, only two
receiving antennas were used, making accurate angle estimation more difficult in
general. The second explanation, which likely has a bigger impact, is that the
evaluation methods were not the same for both experiments; specifically, the minimum
center-of-mass distance differs, since a velocity difference is not directly
comparable to an angle difference. If we increase the angle threshold from 7.8° to
15.6°, we get F1-scores of about 0.9 and 0.68 for the proposed and the traditional
method, respectively. Therefore, the proposed method does well in removing ghost
targets and estimating the correct target ranges, but is not quite as good at
accurately estimating the angles. The traditional method does not gain as much from
this change, since most of its errors are due to ghost targets or missed detections.
In Figs. 14, 15, 16, and 17 some examples with the input image, the output image
from the traditional processing chain, and the output image from the neural network
are shown. The input image here is the RDI from one of the antennas, as the complex
two-antenna input is hard to visualize. It is mainly there to illustrate the difficulty
of the range-angle presence detection and localization. In Fig. 14, two targets were
missed by the classical signal processing chain, whereas all four targets were detected
by the neural network with one additional false alarm. In Fig. 15, all four targets were
detected by the network, while one was missed by the traditional approach. In Fig.
16, only one target was correctly identified by the classical approach, with three
targets missed and one false alarm. Here, the neural network only missed one
target. In the last example, the network again identified all targets at the correct
[Figure 14: three panels, vertical axis range in m (1.5–3)]
Fig. 14 a Raw RDI image with four human targets, b processed RAI using the traditional approach
with two missed detections, c processed RAI using proposed complex U-net approach wherein one-
target split occurs
[Figure 15: three panels, vertical axis range in m]
Fig. 15 a Raw RDI image with four human targets, b processed RAI using the traditional approach
with one missed detection, c processed RAI using proposed complex U-net approach wherein all
targets are detected accurately
[Figure 16: three panels, vertical axis range in m]
Fig. 16 a Raw RDI image with four human targets, b processed RAI using the traditional approach
with one ghost target and three missed detections, c processed RAI using proposed complex U-net
approach wherein one target was missed
positions, while the traditional approach has two missed detections and a one-target
split, resulting in one false alarm.
[Figure 17: three panels, vertical axis range in m (1.5–3)]
Fig. 17 a Raw RDI image with three human targets, b processed RAI using the traditional approach
with one-target split and two missed detections, c processed RAI using proposed complex U-net
approach wherein all targets are detected accurately
7 Conclusion
The traditional radar signal processing pipeline for detecting targets in either the
RDI or the RAI is prone to ghost targets, multi-path reflections from static objects,
and target occlusion when detecting humans, especially in indoor environments.
Further, parametric clustering algorithms suffer from single targets being split and
multiple targets being merged into single clusters. To overcome such artifacts and
facilitate accurate human detection, localization, and counting in indoor scenarios,
this contribution proposed a deep residual U-Net model and a deep complex U-Net
model to generate accurate human detections in the RDI and RAI domains, respectively.
We trained the models using a custom loss function, proposed architectural designs,
and a training strategy based on data augmentation to achieve accurately processed
RDIs and RAIs. We demonstrated superior detection and clustering results in terms of
F1-score and ROC characterization compared to the conventional signal processing
approach. As future work, a variational autoencoder generative adversarial network
(VAE-GAN) architecture could be deployed to reduce the sensitivity of the U-Net
models to variations in data caused by sensor noise and interference.
References
1. EPRI, Occupancy sensors: positive on/off lighting control, in Rep. EPRIBR-100323 (1994)
2. V. Garg, N. Bansal, Smart occupancy sensors to reduce energy consumption. Energy Build.
32, 81–87 (2000)
3. W. Butler, P. Poitevin, J. Bjomholt, Benefits of wide area intrusion detection systems using
FMCW radar (2007), pp. 176–182
4. J. Lien, N. Gillian, M. Emre Karagozler, P. Amihood, C. Schwesig, E. Olson, H. Raja,
I. Poupyrev, Soli: ubiquitous gesture sensing with millimeter wave radar. ACM Trans. Graph.
35, 1–19 (2016)
5. S. Hazra, A. Santra, Robust gesture recognition using millimetric-wave radar system. IEEE
Sens. Lett. PP, 1 (2018)
6. S. Hazra, A. Santra, Short-range radar-based gesture recognition system using 3D CNN with
triplet loss. IEEE Access 7, 125623–125633 (2019)
7. M. Arsalan, A. Santra, Character recognition in air-writing based on network of radars for
human-machine interface. IEEE Sens. J. PP, 1 (2019)
8. C. Will, P. Vaishnav, A. Chakraborty, A. Santra, Human target detection, tracking, and
classification using 24 GHz FMCW radar. IEEE Sens. J. PP, 1 (2019)
9. A. Santra, R. Vagarappan Ulaganathan, T. Finke, Short-range millimetric-wave radar system
for occupancy sensing application. IEEE Sens. Lett. PP, 1 (2018)
10. H.L. Van Trees, Detection, Estimation, and Modulation Theory, Part I (Wiley, 2004)
11. A. Santra, I. Nasr, J. Kim, Reinventing radar: the power of 4D sensing. Microw. J. 61, 26–38
(2018)
12. F. Schroff, D. Kalenichenko, J. Philbin, Facenet: a unified embedding for face recognition and
clustering (2015), pp. 815–823
13. O. M. Parkhi, A. Vedaldi, A. Zisserman, Deep face recognition, vol. 1 (2015), pp. 41.1–41.12
14. S. Ren, K. He, R. Girshick, J. Sun, Faster R-CNN: towards real-time object detection with region
proposal networks. IEEE Trans. Pattern Anal. Mach. Intell. 39, 06 (2015)
15. J. Redmon, S. Divvala, R. Girshick, A. Farhadi, You only look once: unified, real-time object
detection (2016), pp. 779–788
16. L.-C. Chen, G. Papandreou, I. Kokkinos, K. Murphy, A.L. Yuille, Deeplab: semantic image
segmentation with deep convolutional nets, atrous convolution, and fully connected crfs. IEEE
Trans. Pattern Anal. Mach. Intell. PP (2016)
17. O. Ronneberger, P. Fischer, T. Brox, U-net: convolutional networks for biomedical image
segmentation (2015)
18. M.A. Wani, F.A. Bhat, S. Afzal, A.I. Khan, Advances in Deep Learning (Springer, Singapore,
2020)
19. K. He, X. Zhang, S. Ren, J. Sun, Deep residual learning for image recognition (2016), pp.
770–778
20. Z. Zhang, Q. Liu, Y. Wang, Road extraction by deep residual U-net. IEEE Geosci. Remote
Sens. Lett. PP (2017)
21. G. Zhang, H. Li, F. Wenger, Object detection and 3D estimation via an FMCW radar using a
fully convolutional network (2019). arXiv preprint arXiv:1902.05394
22. L. Wang, J. Tang, Q. Liao, A study on radar target detection based on deep neural networks.
IEEE Sens. Lett. 3(3), 1–4 (2019)
23. M. Stephan, A. Santra, Radar-based human target detection using deep residual U-net for
smart home applications, in 18th IEEE International Conference on Machine Learning And
Applications (ICMLA) (IEEE, 2019), pp. 175–182
24. J.S. Dramsch et al., Complex-valued neural networks in Keras with TensorFlow (2019)
25. C. Trabelsi, O. Bilaniuk et al., Deep complex networks (2017). arXiv preprint arXiv:1705.09792
26. L. Xu, J. Li, P. Stoica, Adaptive techniques for MIMO radar, in Fourth IEEE Workshop on
Sensor Array and Multichannel Processing, vol. 2006 (IEEE, 2006), pp. 258–262
27. H. Rohling, Radar CFAR thresholding in clutter and multiple target situations. IEEE Trans.
Aerosp. Electron. Syst. 19, 608–621 (1983)
28. M. Ester, H.-P. Kriegel, J. Sander, X. Xu, A density-based algorithm for discovering clusters
in large spatial databases with noise, in KDD (1996)
29. A. Santra, R. Santhanakumar, K. Jadia, R. Srinivasan, SINR performance of matched illumi-
nation signals with dynamic target models (2016)
30. C. Szegedy, V. Vanhoucke, S. Ioffe, J. Shlens, Z. Wojna, Rethinking the inception architecture
for computer vision (2016)
31. T.-Y. Lin, P. Goyal, R. Girshick, K. He, P. Dollar, Focal loss for dense object detection (2017),
pp. 2999–3007
32. D. Kingma, J. Ba, Adam: a method for stochastic optimization, vol. 12 (2014)
Thresholding Strategies for Deep Learning with Highly Imbalanced Big Data
1 Introduction
Class imbalance exists when the total number of samples from one class, or category,
is significantly larger than that of any other category within the data set. This
phenomenon arises in many critical industries, e.g. financial [1], biomedical [2],
and environmental [3]. In each of these examples, the positive class of interest is
the smaller class, i.e. the minority group, and there is an abundance of less
interesting negative samples. In this study, we focus specifically on binary
classification problems that contain a positive and a negative class. The concepts
presented can, however, be extended to the multi-class problem, because multi-class
problems can be converted into a set of two-class problems through class
decomposition [4].
Imbalanced data sets have been shown to degrade the performance of classifica-
tion models, often causing models to over-predict the majority group. As a result,
instances belonging to the minority group are incorrectly classified as negative
samples and positive class performance suffers. To make matters worse, popular
evaluation metrics like accuracy are liable to mislead analysts with high scores that
incorrectly indicate good prediction performance. For example, given a binary data
set with a positive class size of just 1%, a simple learner that always outputs the
negative class will score 99% accuracy. The extent of this performance degradation depends
on problem complexity, data set size, and the level of class imbalance [5]. In this
study, we use deep neural network (DNN) models to make predictions with complex
data sets that are characterized by both big data and high class imbalance.
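The accuracy pitfall in the example above can be reproduced directly with scikit-learn metrics:

```python
import numpy as np
from sklearn.metrics import accuracy_score, f1_score

# Hypothetical binary data set: 9900 negatives, 100 positives (1% positive class).
y_true = np.array([0] * 9900 + [1] * 100)

# A trivial "learner" that always outputs the negative class.
y_pred = np.zeros_like(y_true)

# Accuracy looks excellent, yet every positive sample is missed.
print(accuracy_score(y_true, y_pred))             # prints: 0.99
print(f1_score(y_true, y_pred, zero_division=0))  # prints: 0.0
```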
Generally, DNN model performance degrades as the level of class imbalance
increases and the relative size of the positive class decreases [6]. We denote the level
of class imbalance in a binary data set using the ratio of negative samples to positive
samples, i.e. n_neg : n_pos. For example, the imbalance of a data set with 400 negative
samples and 100 positive instances is denoted by 80:20. Equivalently, we sometimes
refer to the positive class’s prior probability, e.g. the 80:20 distribution has a positive
class prior of 0.2 or 20%. We classify a data set as highly imbalanced when the
positive class prior is ≤ 0.01 [7]. When the total number of positive occurrences
becomes even more infrequent, we describe the data set as exhibiting rarity. Weiss et
al. [8] distinguish between absolute rarity and relative rarity. Absolute rarity occurs
when there are very few samples for a given class, regardless of the size of the
majority class. Unlike absolute rarity, a relatively rare class can make up a small
percentage of a data set and still have many occurrences when the overall data set
is very large. For example, given a data set with 10 million records and a relative
positive class size of 1%, there are still 100,000 positive cases to use for training
machine learning models. Therefore, we can sometimes achieve good performance
with relatively rare classes, especially when working with large volumes of data.
The experiments in this chapter explore a wide range of class imbalance levels, e.g.
0.03–90%, and include cases of relative rarity.
The challenges of working with class-imbalanced data are often compounded by
the challenges of big data [9]. Big data commonly refers to data which exceeds the
capabilities of standard data storage and processing. These data sets are also defined
Thresholding Strategies for Deep Learning … 201
using the four Vs: volume, variety, velocity, and veracity [10, 11]. The large vol-
umes of data being collected require highly scalable hardware and efficient analysis
tools, often demanding distributed implementations. In addition to adding architec-
ture and network overhead, distributed systems have been shown to exacerbate the
negative effects of class-imbalanced data [12]. The variety of big data corresponds
to the mostly unstructured, diverse, and inconsistent representations that arise as
data is consumed from multiple sources over extended periods of time. Advanced
techniques for quickly processing incoming data streams and maintaining appropri-
ate turnaround times are required to keep up with the rate at which data is being
generated, i.e. data velocity. Finally, the veracity of big data, i.e. its accuracy and
trustworthiness, must be regularly validated to ensure results do not become cor-
rupted. MapReduce [13] and Apache Spark [14] are two popular frameworks that
address these big data challenges by operating on partitioned data in parallel. Neural
networks can also be trained in a distributed fashion using either data parallelism or
model parallelism techniques [15, 16]. The three data sets used in this study include
many of these big data characteristics.
Data-level and algorithm-level techniques for addressing class imbalance have
been studied extensively. Data methods use sampling to alter the distribution of the
training data, effectively reducing the level of class imbalance for model training.
Random over-sampling (ROS) and random under-sampling (RUS) are the two sim-
plest data-level methods for addressing class imbalance. ROS increases the size of
the minority class by randomly copying minority samples, and RUS decreases the
size of the majority class by randomly discarding samples from the majority class.
From these fundamental data-level methods, many more advanced variants have
been developed [17–20]. Algorithm-level methods modify the training process to
increase the impact of the minority class and reduce bias toward the majority class.
Direct algorithm-level methods modify a machine learning algorithm’s underlying
learner, usually by incorporating class costs or weights. Meta-learner methods use
a wrapper to convert non-cost-sensitive learners into cost-sensitive learners. Cost-
sensitive learning and output thresholding are examples of direct and meta-learner
algorithm-level methods, respectively [21]. Output thresholding is the process of
changing the decision threshold that is used to assign class labels to a model's
posterior probabilities [22, 23]. Finally, there are a number of hybrid methods that
combine two or more data-level and algorithm-level methods [7, 24–26]. In this
chapter, we explore output thresholding as it relates to deep learning.
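The ROS and RUS data-level methods described above can be sketched with NumPy; the 80:20 index sets are hypothetical:

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical 80:20 data set: indices 0-399 negative, 400-499 positive.
neg_idx = np.arange(400)
pos_idx = np.arange(400, 500)

# ROS: copy randomly chosen minority samples (with replacement)
# until the classes are balanced.
ros_pos = rng.choice(pos_idx, size=len(neg_idx), replace=True)
ros_train = np.concatenate([neg_idx, ros_pos])

# RUS: discard randomly chosen majority samples (without replacement)
# until the classes are balanced.
rus_neg = rng.choice(neg_idx, size=len(pos_idx), replace=False)
rus_train = np.concatenate([rus_neg, pos_idx])

print(len(ros_train), len(rus_train))  # prints: 800 200
```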
Deep learning is a subfield of machine learning that uses artificial neural networks
(ANNs) with two or more hidden layers to approximate some function f*, where
f* can be used to map input data to new representations or to make predictions [27].
The ANN, inspired by the biological neural network, is a set of interconnected
neurons, or nodes, where connections are weighted and each neuron transforms its
input into a single output by applying a nonlinear activation function to the sum of
its weighted inputs. In a feedforward network, input data propagates through the
network in a forward pass, each hidden layer receiving its input from the previous
layer’s output, producing a final output that is dependent on the input data, the choice
of activation function, and the weight parameters [28]. Gradient descent optimization
202 J. M. Johnson and T. M. Khoshgoftaar
adjusts the network’s weight parameters in order to minimize the loss function, i.e.
the error between expected output and actual output. Composing multiple nonlinear
transformations creates hierarchical representations of the input data, increasing the
level of abstraction through each transformation. The deep learning architecture,
i.e. deep neural network (DNN), achieves its power through this composition of
increasingly complex abstract representations [27]. Deep learning methods have
proven very successful in solving complex problems related to natural language and
vision [29]. These recent successes can be attributed to an increased availability
of data, improvements in hardware and software [30–34], and various algorithmic
breakthroughs that speed up training and improve generalization to new data [35].
Ideally, we would like to leverage the power of these deep models to improve the
classification of highly imbalanced big data. While output thresholding has been
used with traditional learners to treat class imbalance in both big and non-big data
problems, we found there is little work that properly evaluates its use in DNN models.
Most commonly, the Default threshold of 0.5 is used to assign class labels to a
classifier’s posterior probability estimates. In this chapter, however, we argue that
the Default threshold is rarely optimal when neural network models are trained with
class-imbalanced data. This intuition is drawn from the fact that neural networks
have been shown to estimate Bayesian a posteriori probabilities when trained with
sufficient data [36]. In other words, a well-trained DNN model is expected to output
a posterior probability estimate for input x corresponding to y_c(x) (Eq. 1), where
p(c) is the class prior probability, p(x) is the evidence, and p(x | c) is the
conditional probability of observing instance x given class c. We do not need to
compute each factor individually, since neural networks do not estimate these factors
directly, but we can use Bayes' Theorem and the positive class prior from our training
data to better understand the posterior estimates produced by our models. The
estimated positive class prior p(c_pos) is the probability that a random sample drawn
from the training data belongs to the positive class, and it is equal to the number
of positive training samples divided by the total number of training samples. In
highly imbalanced problems, e.g. p(c_pos) ≤ 0.01, small positive class priors can
significantly decrease posterior estimates and, in some cases, y_c(x) may never
exceed 0.5. If y_c(x) ≤ 0.5 for all x, the Default threshold will incorrectly assign
all positive samples to the negative class.
We can account for this imbalance by identifying Optimal thresholds that balance
positive and negative class performance. We build on this concept empirically by
exploring two thresholding strategies with three real-world data sets. The Optimal
thresholding strategy identifies the threshold that maximizes performance on training
or validation data and then uses the optimal value to make predictions on the test
set. The Prior thresholding strategy estimates p(c_pos) from the training data and
then uses p(c_pos) as the classification threshold for making predictions on the test set.
y_c(x) = p(c | x) = p(c) · p(x | c) / p(x)    (1)
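A minimal sketch of the Default, Optimal, and Prior thresholding strategies on synthetic validation scores; the score distributions are hypothetical, chosen so that no posterior estimate reaches 0.5:

```python
import numpy as np
from sklearn.metrics import f1_score

rng = np.random.default_rng(42)

# Hypothetical validation set with a roughly 1% positive class prior.
y_val = (rng.random(10_000) < 0.01).astype(int)

# Hypothetical posterior estimates: positives score higher than negatives,
# but no score ever reaches the Default threshold of 0.5.
scores = np.where(y_val == 1,
                  rng.uniform(0.02, 0.30, y_val.size),
                  rng.uniform(0.00, 0.05, y_val.size))

# Default threshold: every sample is assigned to the negative class.
default_f1 = f1_score(y_val, (scores >= 0.5).astype(int), zero_division=0)

# Optimal strategy: pick the threshold that maximizes F1 on validation data.
grid = np.linspace(0.01, 0.49, 49)
best_t = grid[np.argmax([f1_score(y_val, (scores >= t).astype(int),
                                  zero_division=0) for t in grid])]

# Prior strategy: use the estimated positive class prior as the threshold.
prior_t = y_val.mean()

print(default_f1, round(float(best_t), 2), round(float(prior_t), 3))
```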
The first two data sets, Medicare Part B [37] and Part D [38], summarize claims
that medical providers have submitted to Medicare and include a class label that
2 Related Work
This section begins by summarizing other works that address class imbalance with
DNN models. We then introduce works related to Medicare fraud prediction and
contact map prediction.
The big data, big value, and high class imbalance inherent in Medicare fraud predic-
tion make it an excellent candidate for evaluating methods designed to address class
imbalance. Bauder and Khoshgoftaar [58] use a subset of the 2012–2013 Medicare
Part B data, i.e. Florida claims only, to model expected amounts paid to providers for
services rendered to patients. In another study, Bauder and Khoshgoftaar [59] pro-
posed an outlier detection method that uses Bayesian inference to identify outliers,
and successfully validated their model using claims data of a known Florida provider
that was under criminal investigation for excessive billing. This experiment used a
subset of 2012–2014 Medicare Part B data that included dermatology and optom-
etry claims from Florida office clinics. Another paper by Bauder et al. [60] uses a
Naive Bayes classifier to predict provider specialty types, and then flag providers
that are practicing outside their expected specialty type as fraudulent. Results show
that specialties with unique billing procedures, e.g. audiologist or chiropractic, are
able to be classified with high precision and recall. Herland et al. [61] expanded
on the work from [60] by incorporating 2014 Medicare Part B data and real-world
fraud labels defined by the List of Excluded Individuals and Entities (LEIE) [62] data
set. The authors find that grouping similar specialty types, e.g. Ophthalmology and
Optometry, improves overall performance. Bauder and Khoshgoftaar [63] merge the
2012–2015 Medicare Part B data sets, label fraudulent providers using the LEIE data
set, and compare multiple traditional machine learning classifiers. Class imbalance
is addressed with RUS, and various class distributions are generated to identify the
optimal imbalance ratio for training. The C4.5 decision tree and logistic regression
(LR) learners significantly outperform the support vector machine (SVM), and the
80:20 class distribution is shown to outperform 50:50, 65:35, and 75:25 distributions.
In summary, these studies show that Medicare Part B and Part D claims data contains
sufficient variability to detect fraudulent providers and that the LEIE data set can be
reliably used for ground truth fraud labels.
The Medicare experiments in our study leverage data sets curated by Herland
et al. [64]. In one study, Herland et al. [64] used Medicare Part B, Part D, and
DMEPOS claims data from the years 2012–2016. Cross-validation and ROC AUC
scores are used to compare LR, Random Forest (RF), and Gradient Boosted Tree
(GBT) learners. Results show that Part B data sets score significantly better on ROC
AUC than the Part D data set, and the LR learner outperforms the GBT and RF
learners with a max AUC of 0.816. In a second paper, Herland et al. [65] used these
same Medicare data sets to study the effect of class rarity with LR, RF, and GBT
learners. In this study, the authors create an absolutely rare positive class by using
subsets of the positive class to form new training sets. They reduced the positive class
size to 1000, 400, 200, and 100, and then used RUS methods to treat imbalance and
compare AUC scores. Results show that smaller positive class counts degrade model
performance, and the LR learner with an RUS distribution of 90:10 performs best.
Several other research groups have taken interest in detecting Medicare fraud
using the CMS Medicare and LEIE data sets. Feldman and Chawla [66] looked for
anomalies in the relationship between medical school training and the procedures that
physicians perform in practice by linking 2012 Medicare Part B data with provider
medical school data obtained through the CMS physician compare data set [67].
Significant procedures for each school were used to evaluate school similarities and
present a geographical analysis of procedure charges and payment distributions. Ko
et al. [68] used the 2012 CMS data to analyze the variability of service utilization
and payments. Ko et al. found that the number of patient visits is strongly correlated
with Medicare reimbursement, and concluded that there is a possible 9% savings
within the field of Urology alone. Chandola et al. [69] used claims data and fraud
labels from the Texas Office of Inspector General’s exclusion database to detect
anomalies. The authors confirm the importance of including provider specialty types
in fraud detection, showing that the inclusion of specialty attributes increases AUC
scores from 0.716 to 0.814. Branting et al. [70] propose a graph-based method for
estimating healthcare fraud risk within the 2012–2014 CMS and LEIE data sets.
The authors leverage the NPPES [71] registry to look up providers that are missing
from the LEIE database, increasing their total fraudulent provider count to 12,000.
Branting et al. combine these fraudulent providers with a subset of 12,000 non-
fraudulent providers and employ a J48 decision tree learner to classify fraud with a
mean AUC of 0.96.
The best performing learner scored a G-Mean of 0.706 using the subset of 12 million
instances and the ROS strategy, suggesting that the subset of 0.6 million instances
is not representative enough. Río et al. [75] also used the Hadoop framework to
explore RF learner performance with ROS and RUS methods for addressing class
imbalance. They too found that ROS outperforms RUS, and suggested that this is
due to the already underrepresented minority class being split across many partitions.
Río et al. achieved their best performance with 64 partitions and an over-sampling
rate of 130%. Unfortunately, results show a relatively low TPR (0.705) compared to
the reported TNR (0.725) and the winning competition results (0.730).
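The G-Mean quoted above is the geometric mean of TPR and TNR; a small self-contained sketch:

```python
import numpy as np

def gmean(y_true, y_pred):
    """Geometric mean of true positive rate and true negative rate."""
    tp = np.sum((y_true == 1) & (y_pred == 1))
    fn = np.sum((y_true == 1) & (y_pred == 0))
    tn = np.sum((y_true == 0) & (y_pred == 0))
    fp = np.sum((y_true == 0) & (y_pred == 1))
    tpr = tp / (tp + fn)  # recall on the positive class
    tnr = tn / (tn + fp)  # recall on the negative class
    return float(np.sqrt(tpr * tnr))

# Hypothetical predictions: 3 of 4 positives and 3 of 4 negatives correct.
y_true = np.array([1, 1, 1, 1, 0, 0, 0, 0])
y_pred = np.array([1, 1, 1, 0, 0, 0, 1, 0])
print(gmean(y_true, y_pred))  # TPR = TNR = 0.75, so G-Mean = 0.75
```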
Similar to these related works, we use a subset of ECBDL’14 data (3.5 million
instances) to evaluate thresholding strategies for DNN classification. Unlike related
works, however, we did not find it necessary to use a distributed training environ-
ment. Instead of training individual models in a distributed fashion, we use multiple
compute nodes with sufficient resources to train multiple models independently and
in parallel. Also, unlike these related works, we do not rely on data sampling tech-
niques to balance TPR and TNR scores. Rather, we show how a simple thresholding
technique can be used to optimize class-wise performance regardless of the class
imbalance level.
3 Data Sets
This section summarizes the data sets that are used to evaluate thresholding tech-
niques for addressing class imbalance with deep neural networks. We begin with
two Medicare fraud data sets that were first curated by Herland et al. [64]. We then
incorporate a large protein contact map prediction data set that was published by the
Evolutionary Computation for Big Data and Big Learning (ECBDL) workshop in
2014 [40]. All three data sets were obtained from publicly available resources and
exhibit big data and class imbalance characteristics.
Two publicly available Medicare fraud data sets are obtained from CMS: (1) Medicare
Provider Utilization and Payment Data: Physician and Other Supplier (Part B) [37],
and (2) Medicare Provider Utilization and Payment Data: Part D Prescriber (Part
D) [38]. The healthcare claims span years 2012–2016 and 2013–2017 for Part B
and Part D data sets, respectively. Physicians are identified within each data set by
their National Provider Identifier (NPI), a unique 10-digit number that is used to
identify healthcare providers [76]. Using the NPI, Herland et al. map fraud labels to
the Medicare data from the LEIE repository. The LEIE is maintained by the Office of
Inspector General and it lists providers that are prohibited from practicing. Additional
attributes of LEIE data include the reason for exclusion and provider reinstatement
208 J. M. Johnson and T. M. Khoshgoftaar
dates, where applicable. Providers that have been excluded for fraudulent activity
are labeled as fraudulent within the Medicare Part B and Part D data sets.
The Part B claims data set describes the services and procedures that healthcare
professionals provide to Medicare’s Fee-For-Service beneficiaries. Records within
the data set contain various provider-level attributes, e.g. NPI, first and last name,
gender, credentials, and provider type. More importantly, records contain specific
claims details that describe a provider’s activity within Medicare. Examples of claims
data include the procedure performed, the average charge submitted to Medicare, the
average amount paid by Medicare, and the place of service. The procedures rendered
are encoded using the Healthcare Common Procedures Coding System (HCPCS)
[77]. For example, HCPCS codes 99219 and 88346 are used to bill for hospital
observation care and antibody evaluation, respectively. Part B data is aggregated by
(1) provider NPI, (2) HCPCS code, and (3) place of service. The list of Part B features
used for training are provided in Table 1.
Similarly, the Part D data set contains a variety of provider-level attributes, e.g.
NPI, name, and provider type. More importantly, the Part D data set contains specific
details about medications prescribed by Medicare providers. Examples of prescrip-
tion attributes include drug names, costs, quantities prescribed, and the number of
beneficiaries receiving the medication. CMS aggregates the Part D data over (1) pre-
scriber NPI and (2) drug name. Table 2 summarizes all Part D predictors used for
classification.
The ECBDL’14 data set was originally generated to train a predictor for the residue–
residue contact prediction track of the 9th Community-Wide Experiment on the
Critical Assessment of Techniques for Protein Structure Prediction competition
(CASP9) [72]. Protein contact map prediction is a subproblem of protein structure
Thresholding Strategies for Deep Learning … 209
prediction. This subproblem entails predicting whether any two residues in a pro-
tein sequence are spatially close to each other [78]. The three-dimensional structure
of a protein can then be inferred from these residue–residue contact map predic-
tions. This is a fundamental problem in medical domains, e.g. drug design, as the
three-dimensional structure of a protein determines its function [79].
Each instance of the ECBDL’14 data set is a pair of amino acids represented by 539
continuous attributes, 92 categorical attributes, and a binary label that distinguishes
pairs that are in contact. Triguero et al. [73] provide a thorough description of the
data set and the methods used to win the competition. Attributes include detailed
information and statistics about the protein sequence and the segments connecting
the target pair of amino acids. Additional predictors include the length of the protein
sequence and a statistical contact propensity between the target pair of amino acid
types [73]. The training partition contains 32 million instances and a positive class
size of 2%. We refer readers to the original paper by Triguero et al. for a deeper
understanding of the amino acid representation. The characteristics of this data set
have made it a popular choice for evaluating methods of treating class imbalance and
big data.
Both Medicare Part B and Part D data sets were curated and cleaned by Herland et
al. [64] and required minimal preprocessing. First, test sets were created by using
stratified random sampling to hold out 20% of each Medicare data set. The held-
out test sets remain constant throughout all experiments, and when applicable, data
sampling is only applied to the training data to create new class imbalance levels.
These methods for creating varying levels of class imbalance are described in detail
in Sect. 4. Finally, all features were normalized to continuous values in the range
[0, 1]. Train and test sets were normalized by fitting a min-max scaler to the fit data
and applying the scaler to the train and test sets separately.
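The normalization step above can be sketched in a few lines of numpy; this is a minimal illustration (the array names and values are hypothetical), with the scaler's minimum and maximum fit on the training partition only and then applied to both partitions:

```python
import numpy as np

def fit_min_max(train):
    """Learn per-feature min and max from the training (fit) data only."""
    lo, hi = train.min(axis=0), train.max(axis=0)
    span = np.where(hi > lo, hi - lo, 1.0)  # guard against constant features
    return lo, span

def min_max_transform(data, lo, span):
    """Scale features into [0, 1] using statistics fit on the training data."""
    return np.clip((data - lo) / span, 0.0, 1.0)

# Hypothetical train/test matrices standing in for the Medicare features
X_train = np.array([[0.0, 10.0], [5.0, 20.0], [10.0, 30.0]])
X_test = np.array([[2.5, 15.0]])

lo, span = fit_min_max(X_train)
X_train_scaled = min_max_transform(X_train, lo, span)
X_test_scaled = min_max_transform(X_test, lo, span)
```

Fitting on the training partition alone, as described above, prevents test-set statistics from leaking into the model.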
210 J. M. Johnson and T. M. Khoshgoftaar
A subset of 3.5 million records was taken from the ECBDL’14 data with ran-
dom under-sampling. Categorical features were encoded using one-hot encoding,
resulting in a final set of 985 features. Similar to the Medicare data, we normalized
all attributes to the range [0, 1] and set aside a 20% test set using stratified random
sampling. To improve efficiency, we applied Chi-square feature selection [80] and
selected the best 200 features. Preliminary experiments suggested that exceeding 200
features provided limited returns on validation performance.
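For illustration, the Chi-square selection step can be sketched with scikit-learn's `SelectKBest`; the chapter does not name its implementation, so treat this as one possible realization, and note that the toy data and `k=3` are placeholders for the 985 features and `k=200` used above (`chi2` requires non-negative inputs, which the [0, 1] normalization guarantees):

```python
import numpy as np
from sklearn.feature_selection import SelectKBest, chi2

# Hypothetical stand-in for the normalized ECBDL'14 feature matrix
rng = np.random.default_rng(0)
X_train = rng.random((1000, 10))
y_train = (X_train[:, 0] > 0.5).astype(int)  # label depends on feature 0 only

# Keep the k features with the highest Chi-square scores
selector = SelectKBest(chi2, k=3).fit(X_train, y_train)
X_train_reduced = selector.transform(X_train)
```

The fitted selector is then applied unchanged to the held-out test set.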
Table 3 lists the sizes of train and test sets, the total number of predictors, and
the size of the positive class for each data set. Train sets have between 2.8 and 3.7
million samples and relatively fewer positive instances. The ECBDL’14 data set has
a positive class size of 2% and is arguably not highly imbalanced, i.e. its positive class comprises more than 1% of the data set. In Sect. 4, we explain how high class imbalance is simulated by under-sampling the positive class. The Medicare data sets, on the other hand, have only 3 to 4 positive instances for every 10,000 observations and are inherently and severely imbalanced.
4 Methods
We evaluate thresholding strategies for addressing high class imbalance with deep
neural networks across a wide range of class imbalance levels. A stratified random
80–20% split is used to create train and test partitions from each of the three data sets.
Training sets are sampled to simulate the various levels of class imbalance, and test
sets are held constant for evaluation purposes. All validation and hyperparameter
tuning are executed on random partitions of the training data. After configuring
hyperparameters, each model is fit to the training data and scored on the test set with
30 repetitions. This repetition accounts for any variation in results that may be caused
by random sampling and allows for added statistical analysis.
All experiments are performed on a high-performance computing environment
running Scientific Linux 7.4 (Nitrogen) [81]. Neural networks are implemented using
the Keras [32] open-source deep learning library written in Python with the TensorFlow [30] backend.
In preliminary experiments, we found that a baseline network of two hidden layers with 32 neurons each was required to fit the training data.
techniques to eliminate overfitting and improve validation performance. One way to
reduce overfitting is to reduce the total number of learnable parameters, i.e. reduc-
ing network depth or width. L1 or L2 regularization methods, or weight decay, add
parameter penalties to the objective function that constrain the network’s weights
to lie within a region that is defined by a coefficient α [27]. Dropout simulates the
ensembling of many models by randomly disabling non-output neurons with a prob-
ability P ∈ [0, 1] during each iteration, preventing neurons from co-adapting and
forcing the model to learn more robust features [87]. Although originally designed
to address internal covariate shift and speed up training, batch normalization has also
been shown to add regularizing effects to neural networks [88]. Batch normalization
is similar to normalizing input data to have a fixed mean and variance, except that it
normalizes the inputs to hidden layers across each mini-batch. We found that a combination of dropout and batch normalization yielded the best performance for all three data sets. For the
Medicare models, we use a dropout rate of P = 0.5 and for the ECBDL’14 data set
we use a dropout rate of P = 0.8. Batch normalization is applied before the activation
function in each hidden unit.
Table 4 describes the two-layer baseline architecture for the Medicare Part B data
set. To determine how the number of hidden layers affects performance, we extended
this model to four hidden layers following the same pattern, i.e. using 32 neuron lay-
ers, batch normalization, ReLU activations, and dropout in each hidden layer. We did
not find it necessary to select new hyperparameters for the Medicare Part D data set.
Instead, we just changed the size of the input layer to match the total number of fea-
tures in each respective data set. The architecture for the ECBDL’14 data set follows
this same basic pattern but contains four hidden layers with 128, 128, 64, and 32 neu-
rons in each consecutive layer. With the increased feature count and network width,
the ECBDL’14 network contains 54 K tunable parameters and is approximately 10×
larger than the two-layer architecture used in Medicare experiments.
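The layer pattern described above can be sketched with the Keras functional API (Keras and TensorFlow are the chapter's stated libraries); the layer widths and dropout rates follow the text, while the 100-feature Medicare input size is a placeholder:

```python
import tensorflow as tf
from tensorflow.keras import layers

def build_dnn(n_features, hidden_sizes, dropout_rate):
    """DNN following the chapter's pattern: for each hidden layer,
    Dense -> BatchNorm (applied before the activation) -> ReLU -> Dropout,
    with a single sigmoid output for binary classification."""
    inputs = tf.keras.Input(shape=(n_features,))
    x = inputs
    for units in hidden_sizes:
        x = layers.Dense(units)(x)
        x = layers.BatchNormalization()(x)  # before the activation, per the text
        x = layers.Activation("relu")(x)
        x = layers.Dropout(dropout_rate)(x)
    outputs = layers.Dense(1, activation="sigmoid")(x)
    return tf.keras.Model(inputs, outputs)

# Two-layer Medicare baseline (32 neurons per layer, dropout P = 0.5);
# the 100-feature input size is a placeholder, not the actual feature count.
medicare_dnn = build_dnn(100, [32, 32], dropout_rate=0.5)

# ECBDL'14 network: four hidden layers of 128, 128, 64, and 32 neurons,
# dropout P = 0.8, and the 200 selected input features.
ecbdl_dnn = build_dnn(200, [128, 128, 64, 32], dropout_rate=0.8)
```

With these sizes, the ECBDL'14 model comes to roughly 54 K parameters, consistent with the count reported above.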
We use data sampling to alter training distributions and evaluate thresholding strate-
gies across a wide range of class imbalance levels. The ROS method randomly dupli-
cates samples from the minority class until the desired positive class prior is achieved.
RUS randomly removes samples from the majority class without replacement until
the desired level of imbalance is reached. The hybrid ROS-RUS method first under-
samples from the majority class without replacement and then over-samples the
minority class until classes are balanced. A combination of ROS, RUS, and ROS-
RUS is used to create 17 new training distributions from each Medicare data set and
12 new training distributions from the ECBDL’14 data set.
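The three sampling methods above can be sketched as index-level operations; this is a minimal numpy illustration (the index arrays and class sizes are toy placeholders):

```python
import numpy as np

rng = np.random.default_rng(42)

def rus(neg_idx, pos_idx, pos_prior):
    """RUS: randomly remove majority (negative) samples without replacement
    until the positive class makes up `pos_prior` of the data."""
    n_neg = int(len(pos_idx) * (1 - pos_prior) / pos_prior)
    keep = rng.choice(neg_idx, size=n_neg, replace=False)
    return np.concatenate([keep, pos_idx])

def ros(neg_idx, pos_idx, pos_prior):
    """ROS: randomly duplicate minority (positive) samples until the
    desired positive class prior is achieved."""
    n_pos = int(len(neg_idx) * pos_prior / (1 - pos_prior))
    dup = rng.choice(pos_idx, size=n_pos, replace=True)
    return np.concatenate([neg_idx, dup])

def ros_rus(neg_idx, pos_idx, neg_keep_frac):
    """Hybrid ROS-RUS: under-sample the majority class without replacement,
    then over-sample the minority class until classes are balanced 50:50."""
    keep = rng.choice(neg_idx, size=int(len(neg_idx) * neg_keep_frac),
                      replace=False)
    dup = rng.choice(pos_idx, size=len(keep), replace=True)
    return np.concatenate([keep, dup])

# Toy index arrays: 10,000 negatives, 100 positives (1% positive prior)
neg = np.arange(10_000)
pos = np.arange(10_000, 10_100)

# ROS-RUS-1 style: drop 50% of the negatives, then balance with ROS
balanced = ros_rus(neg, pos, neg_keep_frac=0.5)
```

Training sets are then built by indexing the original feature matrix with the returned indices.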
Table 5 describes the 34 training distributions that were created from the Medicare
Part B and Part D data sets. Of these new distributions, 24 contain low to severe
levels of class imbalance and 10 have balanced positive and negative classes. The
ROS-RUS-1, ROS-RUS-2, and ROS-RUS-3 distributions use RUS to remove 50%, 75%, and 90% of the majority class, respectively, and then over-sample the minority class until both classes are balanced 50:50.
The original ECBDL’14 data set has a positive class size of 2% and is not classified
as highly imbalanced. Therefore, we first simulate two highly imbalanced distribu-
tions by combining the entire majority class with two subsets of the minority class.
By randomly under-sampling the minority class, we achieve two new distributions
that have positive class sizes of 1% and 0.5%. We create additional distributions with
positive class sizes of 5, 10, 20, 30, 40, 50, 60, 70, 80, and 90% by using RUS to
reduce the size of the negative class. As a result, we are able to evaluate thresholding
strategies on 12 class-imbalanced distributions of ECBDL’14 data and one balanced
distribution of ECBDL’14 data. The size of each positive and negative class can be
inferred from these strategies and the training data sizes from Table 3.
The Optimal threshold strategy is used to approximately balance the TPR and TNR.
We accomplish this by using a range of decision thresholds to score models on
validation and training data and then select the threshold which maximizes the G-
Mean. We also add a constraint that the Optimal threshold yields a TPR that is
greater than the TNR, because we are more interested in detecting positive instances
than negative instances. Threshold selection can be modified to optimize any other
performance metric, e.g. precision, and the metrics used to optimize thresholds should
ultimately be guided by problem requirements. If false positives are very costly, for
example, then a threshold that maximizes TNR would be more appropriate. The
performance metrics used in this study are explained in Sect. 4.4 and the procedure
used to compute these Optimal thresholds is defined in Algorithm 1. Once model
training is complete, Algorithm 1 takes ground truth labels and probability estimates
from the trained model, iterates over a range of possible threshold values, and returns
the threshold that maximizes the G-Mean.
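Algorithm 1 itself is not reproduced here; the following is a minimal sketch consistent with its description (iterate over candidate thresholds, keep the one with the highest G-Mean, preferring thresholds where TPR is at least TNR). The grid step and the synthetic labels and scores are assumptions for illustration:

```python
import numpy as np

def optimal_threshold(y_true, y_prob, step=0.005):
    """Search a grid of decision thresholds and return the one that
    maximizes the G-Mean, preferring thresholds where TPR >= TNR."""
    best_t, best_g = 0.5, -1.0
    for t in np.arange(step, 1.0, step):
        y_pred = (y_prob > t).astype(int)
        tp = np.sum((y_pred == 1) & (y_true == 1))
        fn = np.sum((y_pred == 0) & (y_true == 1))
        tn = np.sum((y_pred == 0) & (y_true == 0))
        fp = np.sum((y_pred == 1) & (y_true == 0))
        tpr = tp / (tp + fn) if tp + fn else 0.0
        tnr = tn / (tn + fp) if tn + fp else 0.0
        g = np.sqrt(tpr * tnr)
        # Keep thresholds whose TPR is not below TNR, per the chapter's constraint
        if g > best_g and tpr >= tnr:
            best_t, best_g = t, g
    return best_t, best_g

# Toy example: scores concentrated near a 10% positive prior
rng = np.random.default_rng(0)
y_true = (rng.random(5000) < 0.1).astype(int)
y_prob = np.clip(0.1 + 0.25 * (y_true - 0.1) + rng.normal(0, 0.05, 5000), 0, 1)
t, g = optimal_threshold(y_true, y_prob)
```

Swapping the G-Mean for another metric inside the loop adapts the search to other problem requirements, as noted above.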
For Medicare experiments, Optimal decision thresholds are computed during the
validation phase. The validation step for each training distribution entails training
10 models and scoring them on random 10% partitions of the fit data, i.e. validation
sets. Optimal thresholds are computed on each validation set, averaged, and then the
average from each distribution is used to score models on the test set. While this
use of validation data should reduce the risk of overfitting, we do not use validation
data to compute Optimal thresholds for ECBDL’14 experiments. Instead, ECBDL’14
Optimal thresholds are computed by maximizing performance on fit data. Models are
trained for 50 epochs using all fit data, and then Algorithm 1 searches for the threshold
that maximizes performance on the training labels and corresponding model predic-
tions. Unlike the Medicare Optimal thresholds, the ECBDL’14 Optimal thresholds
can be computed with just one extra pass over the training data and do not require
validation partitions. This is beneficial when training data is limited or when the
positive class is absolutely rare.
The Prior thresholding strategy estimates the positive class prior from the training
data, i.e. p(c_pos) from Eq. 1, and uses its value to assign class labels to posterior scores on the test set. Given a training set with p(c_pos) = 0.1, for example, the Prior
thresholding strategy will assign all test samples with probability scores > 0.1 to
the positive class and those with scores ≤ 0.1 to the negative class. Since the Prior
threshold can be calculated from the training data, and no optimization is required,
we believe it is a good candidate for preliminary experiments with imbalanced data.
Due to time constraints, this method was only explored using ECBDL’14 data.
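A minimal sketch of the Prior strategy (the label and score arrays are toy placeholders):

```python
import numpy as np

def prior_threshold_predict(y_train, test_probs):
    """Use the training set's positive class prior as the decision threshold:
    scores above p(c_pos) are labeled positive, the rest negative."""
    p_pos = np.mean(y_train)  # estimate p(c_pos) from the training labels
    return (test_probs > p_pos).astype(int)

# Toy example with a 10% positive prior
y_train = np.array([1] * 10 + [0] * 90)
test_probs = np.array([0.05, 0.09, 0.10, 0.11, 0.60])
preds = prior_threshold_predict(y_train, test_probs)
```

No validation pass or search is needed, which is what makes this strategy attractive as a baseline.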
TPR = Recall = TP / (TP + FN)    (2)

TNR = Selectivity = TN / (TN + FP)    (3)

G-Mean = √(TPR × TNR)    (4)
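For concreteness, Eqs. (2)–(4) can be computed directly from a confusion matrix; a small numpy sketch with toy labels and predictions:

```python
import numpy as np

def class_wise_metrics(y_true, y_pred):
    """Compute TPR (recall), TNR (selectivity), and their geometric mean
    (G-Mean) from binary labels and predictions, per Eqs. (2)-(4)."""
    tp = np.sum((y_true == 1) & (y_pred == 1))
    fn = np.sum((y_true == 1) & (y_pred == 0))
    tn = np.sum((y_true == 0) & (y_pred == 0))
    fp = np.sum((y_true == 0) & (y_pred == 1))
    tpr = tp / (tp + fn)
    tnr = tn / (tn + fp)
    return tpr, tnr, np.sqrt(tpr * tnr)

y_true = np.array([1, 1, 1, 1, 0, 0, 0, 0, 0, 0])
y_pred = np.array([1, 1, 1, 0, 0, 0, 0, 0, 1, 1])
tpr, tnr, gmean = class_wise_metrics(y_true, y_pred)
```

The G-Mean collapses to zero whenever either class is ignored entirely, which is why it is the primary balance criterion in this study.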
This section presents the DNN thresholding results that were obtained using the
Medicare and ECBDL’14 data sets. We begin by illustrating the relationship between
Optimal decision thresholds and positive class sizes using confidence intervals (C.I.)
and linear models. Next, Default and Optimal thresholds are used to compare G-Mean
scores across all Medicare distributions. ECBDL’14 results make similar compar-
isons and incorporate a third Prior thresholding strategy that proves effective. Finally,
a statistical analysis of TPR and TNR scores is used to estimate the significance of
each method’s results.
We first show that the relationship between the positive class size and the Optimal threshold is both linear and independent of the data set.
In Fig. 1, Medicare Optimal threshold results are grouped by architecture type and
plotted against the positive class size of the training distribution. Plots are enhanced
with horizontal jitter, and linear models are fit to the data using Ordinary Least
Squares [89] and 95% confidence bands. For both Medicare data sets and network
architectures, there is a strong linear relationship between the positive class size and
the Optimal decision threshold. The relationship is strongest for the two-layer networks, with r² ≥ 0.980 and p ≤ 9.73e−145. The four-layer network results share these linear characteristics, albeit weaker, with r² ≤ 0.965 and visibly larger confidence bands.
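The linear fit itself can be reproduced with ordinary least squares; a sketch using numpy, where the thresholds are synthetic stand-ins generated under the assumption (supported above) that the Optimal threshold tracks the positive prior:

```python
import numpy as np

# Synthetic stand-in: Optimal thresholds assumed roughly equal to the prior
rng = np.random.default_rng(1)
pos_prior = np.repeat([0.01, 0.05, 0.1, 0.2, 0.4, 0.5], 10)
thresholds = pos_prior + rng.normal(0, 0.01, pos_prior.size)

# Ordinary least squares fit: threshold ~ slope * prior + intercept
slope, intercept = np.polyfit(pos_prior, thresholds, deg=1)

# Coefficient of determination r^2 for the linear model
pred = slope * pos_prior + intercept
ss_res = np.sum((thresholds - pred) ** 2)
ss_tot = np.sum((thresholds - thresholds.mean()) ** 2)
r_squared = 1 - ss_res / ss_tot
```

A slope near 1 with a near-zero intercept is exactly the "threshold ≈ positive prior" relationship the figures above illustrate.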
We also compute Optimal thresholds after each training epoch using ECBDL’14
data to determine how the Optimal threshold varies during model training. We trained
models for 50 epochs and repeated each experiment three times. As illustrated in
Fig. 2, the Optimal threshold is relatively stable and consistent throughout training.
Similar to Medicare results, the ECBDL’14 Optimal thresholds correspond closely
to the positive class prior of the training distribution.
This section concludes that the positive class size has a strong linear effect on the
Optimal decision threshold. We also found that the Optimal threshold for a given
distribution may vary between deep learning architectures. In the next section, we
consider the significance of these thresholds by using them to evaluate performance
on unseen test data.
Tables 8 and 9 list the 95% G-Mean confidence intervals for all Medicare Part B
and Part D distributions, respectively. Intervals listed in bold indicate those which
are significantly greater than the alternative. Our first observation is that the Default
classification threshold of 0.5 never performs significantly better than the Optimal
threshold. In fact, the Default threshold only yields acceptable G-Mean scores when
classes are virtually balanced, e.g. priors of 0.4–0.6. In all other distributions, the per-
formance of the Default threshold degrades as the level of class imbalance increases.
The Optimal threshold, however, yields relatively stable G-Mean scores across all
distributions. Even the baseline distribution, with a positive class size of just 0.03%,
yields acceptable G-Mean scores > 0.72 when using an Optimal classification thresh-
old. Overall, these results discourage using the Default classification threshold when
training DNN models with class-imbalanced data.
Results also indicate an increase in G-Mean scores among the ROS and ROS-RUS methods relative to RUS. This difference is not due to the thresholding procedure, but rather to the under-sampling procedure used to create the RUS distributions. Our previous work shows
that using RUS with these highly imbalanced big data classification tasks tends to
underrepresent the majority group and degrade performance [90].
Figure 3 presents the combined TPR and TNR scores for both Medicare data
sets and DNN architectures. Optimal classification thresholds (left) produce stable
TPR and TNR scores across all positive class sizes. Furthermore, we observe that the
TPR is always greater than the TNR when using the Optimal classification threshold.
This suggests that our threshold selection procedure (Algorithm 1) is effective, and
that we can expect performance trade-offs optimized during validation to generalize
to unseen test data. Default threshold results (right), however, are unstable as the
positive class size varies. When class imbalance levels are high, for example, the
Default threshold assigns all test samples to the negative class. It is only when classes
are mostly balanced that the Default threshold achieves high TPR and TNR. Even
when Default threshold performance is reasonably well balanced, e.g. positive class
sizes of 40%, we lose the ability to maximize TPR over TNR.
In summary, two highly imbalanced Medicare data sets were used to compare
DNN prediction performance using Optimal classification thresholds and Default
classification thresholds. For each data set, Part B and Part D, 18 distributions were
created using ROS and RUS to cover a wide range of class imbalance levels, i.e. 0.03–
60%. For each of these 36 distributions, 30 two-layer networks and 30 four-layer
networks were trained and scored on test sets. With evidence from over 2,000 DNN
models, statistical results show that the Default threshold is suboptimal whenever
models are trained with imbalanced data. Even in the most severely imbalanced
distributions, e.g. positive class size of 0.03%, scoring with an Optimal threshold
yields consistently favorable G-Mean, TPR, and TNR scores. In the next section, we
expand on these results with the ECBDL’14 data set and consider a third thresholding
method, the Prior threshold.
Default threshold G-Mean scores degrade quickly when the positive class prior is ≥ 0.7 or ≤ 0.3. Unlike the Default threshold, the Optimal threshold yields relatively stable G-Mean scores (≥ 0.7) across all training distributions. Most interestingly, the Prior threshold also yields stable G-Mean scores across all training distributions. In addition to performing on par with the Optimal threshold, the Prior threshold strategy has the advantage of being derived directly from the training data without requiring optimization. Most importantly, we
were able to achieve TPR and TNR scores on par with those of the competition
winners [73] without the added costs of over-sampling.
Table 10 lists ECBDL’14 G-Mean, TPR, and TNR scores averaged across all
training distributions. Tukey’s HSD groups are used to identify results with signif-
icantly different means, i.e. group a performs significantly better than group b. For
the G-Mean metric, the Optimal and Prior threshold methods are placed into group
a with mean scores of 0.7333 and 0.7320, respectively. The Default threshold is
placed into group b with a mean score of 0.4421. These results indicate that the
Optimal and Prior thresholds balance class-wise performance significantly better
than the Default threshold. Equivalently, TPR results place the Optimal and Prior
threshold methods into group a and the Default threshold into group b. Put another
way, the non-default threshold strategies are significantly better at capturing the pos-
itive class of interest. We expect this behavior from the Optimal threshold, as the
threshold selection procedure that we employed (Algorithm 1) explicitly optimizes
thresholds by maximizing TPR and G-Mean scores on the training data. The Prior
method, however, was derived directly from the training data with zero training or
optimization, and surprisingly, performed equally as well as the Optimal threshold.
We believe these qualities make the Prior thresholding strategy a great candidate
for preliminary experiments and baseline models. If requirements call for specific
class-wise performance trade-offs, the Prior threshold can still offer an approximate
baseline threshold to begin optimization.
On further inspection, we see that the Default threshold achieves the highest
average TNR score. While the negative class is typically not the class of interest,
it is still important to minimize false positive predictions. The Default threshold’s
TNR score is misleadingly high, however, and is a result of having more imbalanced distributions with positive class sizes < 0.5 than with sizes > 0.5. Recall from Fig. 3
that when the positive class prior is small, models tend to assign all test samples
to the negative class. The Default threshold scores highly on TNR because most
of the models trained with imbalanced data are assigning all test samples to the
negative class and achieving a 100% TNR. Therefore, we rely on the G-Mean scores
to ensure class-wise performance is balanced and conclude that the Default threshold
is suboptimal when high class imbalance exists within the training data.
The ECBDL’14 results presented in this section align with those from the Medicare
experiments. For all imbalanced training distributions, Optimal thresholds consis-
tently outperform the Default threshold. The Prior threshold, although not optimized,
performed statistically as well as the Optimal threshold by all performance criteria.
6 Conclusion
This chapter explored the effects of highly imbalanced big data on training and scor-
ing deep neural network classifiers. We trained models on a wide range of class
imbalance levels (0.03–90%) and compared the results of two output thresholding
strategies to Default threshold results. The Optimal threshold technique used train-
ing or validation data to find thresholds that maximize the G-Mean performance
metric. The Prior threshold technique uses the positive class prior as the classifica-
tion threshold for assigning labels to test instances. As suggested by Bayes theorem
(Eq. 1), we found that all Optimal thresholds are proportional to the positive class
priors of the training data. As a result, the Default threshold of 0.5 only performed
well when the training data was relatively balanced, i.e. positive priors between
0.4–0.6. For all other distributions, the Optimal and Prior thresholding strategies
performed significantly better based on the G-Mean criterion. Furthermore, Tukey’s
HSD test results suggest that there is no difference between Optimal and Prior thresh-
old results. These Optimal threshold results are dependent on the threshold selection
criteria (Algorithm 1). This Optimal threshold procedure should be guided by the
classification task requirements, and selecting a new performance criterion may yield
Optimal thresholds that are significantly different from the Prior threshold.
Future work should evaluate these thresholding strategies across a wider range of
domains and network architectures, e.g. natural language processing and computer
vision. Additionally, the threshold selection procedure should be modified to opti-
mize alternative performance metrics, and statistical tests should be used to identify
significant differences.
References
1. W. Wei, J. Li, L. Cao, Y. Ou, J. Chen, Effective detection of sophisticated online banking fraud
on extremely imbalanced data. World Wide Web 16, 449–475 (2013)
2. A.N. Richter, T.M. Khoshgoftaar, Sample size determination for biomedical big data with
limited labels. Netw. Model. Anal. Health Inf. Bioinf. 9, 1–13 (2020)
3. M. Kubat, R.C. Holte, S. Matwin, Machine learning for the detection of oil spills in satellite
radar images. Mach. Learn. 30, 195–215 (1998)
4. S. Wang, X. Yao, Multiclass imbalance problems: analysis and potential solutions. IEEE Trans.
Syst. Man Cyb. Part B (Cybern.) 42, 1119–1130 (2012)
5. N. Japkowicz, The class imbalance problem: significance and strategies, in Proceedings of the
International Conference on Artificial Intelligence (2000)
6. M. Buda, A. Maki, M.A. Mazurowski, A systematic study of the class imbalance problem in
convolutional neural networks. Neural Netw. 106, 249–259 (2018)
7. H. He, E.A. Garcia, Learning from imbalanced data. IEEE Trans. Knowl. Data Eng. 21, 1263–
1284 (2009)
8. G.M. Weiss, Mining with rarity: a unifying framework. SIGKDD Explor. Newsl. 6, 7–19 (2004)
9. R.A. Bauder, T.M. Khoshgoftaar, T. Hasanin, An empirical study on class rarity in big data,
in 2018 17th IEEE International Conference on Machine Learning and Applications (ICMLA)
(2018), pp. 785–790
10. E. Dumbill, What is big data? an introduction to the big data landscape (2012). http://radar.
oreilly.com/2012/01/what-is-big-data.html
11. S.E. Ahmed, Perspectives on Big Data Analysis: Methodologies and Applications (American Mathematical Society, USA, 2014)
12. J.L. Leevy, T.M. Khoshgoftaar, R.A. Bauder, N. Seliya, A survey on addressing high-class
imbalance in big data. J. Big Data 5, 42 (2018)
13. J. Dean, S. Ghemawat, Mapreduce: simplified data processing on large clusters. Commun.
ACM 51, 107–113 (2008)
14. M. Zaharia, M. Chowdhury, M. J. Franklin, S. Shenker, I. Stoica, Spark: cluster computing
with working sets, in Proceedings of the 2nd USENIX Conference on Hot Topics in Cloud
Computing, HotCloud’10, (Berkeley, CA, USA), USENIX Association (2010), p. 10
15. K. Chahal, M. Grover, K. Dey, R.R. Shah, A hitchhiker’s guide on distributed training of deep
neural networks. J. Parallel Distrib. Comput. 10 (2019)
16. R.K.L. Kennedy, T.M. Khoshgoftaar, F. Villanustre, T. Humphrey, A parallel and distributed
stochastic gradient descent implementation using commodity clusters. J. Big Data 6(1), 16
(2019)
17. D.L. Wilson, Asymptotic properties of nearest neighbor rules using edited data. IEEE Trans.
Syst. Man Cybern. SMC-2, 408–421 (1972)
18. N.V. Chawla, K.W. Bowyer, L.O. Hall, W.P. Kegelmeyer, Smote: synthetic minority over-
sampling technique. J. Artif. Int. Res. 16, 321–357 (2002)
19. H. Han, W.-Y. Wang, B.-H. Mao, Borderline-smote: a new over-sampling method in imbalanced
data sets learning, in Advances in Intelligent Computing ed. by D.-S. Huang, X.-P. Zhang, G.-B.
Huang (Springer, Berlin, Heidelberg, 2005), pp. 878–887
20. T. Jo, N. Japkowicz, Class imbalances versus small disjuncts. SIGKDD Explor. Newsl. 6, 40–49
(2004)
21. C. Ling, V. Sheng, Cost-sensitive learning and the class imbalance problem, in Encyclopedia
of Machine Learning (2010)
22. J.J. Chen, C.-A. Tsai, H. Moon, H. Ahn, J.J. Young, C.-H. Chen, Decision threshold adjustment
in class prediction, in SAR and QSAR in Environmental Research, vol. 17 (2006), pp. 337–352
23. Q. Zou, S. Xie, Z. Lin, M. Wu, Y. Ju, Finding the best classification threshold in imbalanced
classification. Big Data Res. 5, 2–8 (2016)
24. X. Liu, J. Wu, Z. Zhou, Exploratory undersampling for class-imbalance learning. IEEE Trans.
Syst. Man Cybern. Part B (Cybern.) 39, 539–550 (2009)
25. N.V. Chawla, A. Lazarevic, L.O. Hall, K.W. Bowyer, Smoteboost: improving prediction of
the minority class in boosting, in Knowledge Discovery in Databases: PKDD 2003 ed. by
N. Lavrač, D. Gamberger, L. Todorovski, H. Blockeel, (Springer, Berlin, Heidelberg, 2003),
pp. 107–119
26. Y. Sun, Cost-sensitive Boosting for Classification of Imbalanced Data. Ph.D. thesis, Waterloo, Ont., Canada, 2007. AAINR34548
27. I. Goodfellow, Y. Bengio, A. Courville, Deep Learning (The MIT Press, Cambridge, MA,
2016)
28. I.H. Witten, E. Frank, M.A. Hall, C.J. Pal, Data Mining: Practical Machine Learning Tools and Techniques, 4th edn. (Morgan Kaufmann Publishers Inc., San Francisco, CA, USA, 2016)
29. Y. LeCun, Y. Bengio, G. Hinton, Deep learning. Nature 521, 436–444 (2015)
30. M. Abadi, A. Agarwal, P. Barham, E. Brevdo, Z. Chen, C. Citro, G.S. Corrado, A. Davis,
J. Dean, M. Devin, S. Ghemawat, I. Goodfellow, A. Harp, G. Irving, M. Isard, Y. Jia, R. Joze-
fowicz, L. Kaiser, M. Kudlur, J. Levenberg, D. Mané, R. Monga, S. Moore, D. Murray, C. Olah,
M. Schuster, J. Shlens, B. Steiner, I. Sutskever, K. Talwar, P. Tucker, V. Vanhoucke, V. Vasude-
van, F. Viégas, O. Vinyals, P. Warden, M. Wattenberg, M. Wicke, Y. Yu, X. Zheng, TensorFlow:
large-scale machine learning on heterogeneous systems (2015)
31. Theano Development Team, Theano: a python framework for fast computation of mathematical
expressions (2016). arXiv:abs/1605.02688
32. F. Chollet et al., Keras (2015). https://keras.io
33. A. Paszke, S. Gross, S. Chintala, G. Chanan, E. Yang, Z. DeVito, Z. Lin, A. Desmaison,
L. Antiga, A. Lerer, Automatic differentiation in pytorch, in NIPS-W (2017)
34. S. Chetlur, C. Woolley, P. Vandermersch, J. Cohen, J. Tran, B. Catanzaro, E. Shelhamer, cudnn:
efficient primitives for deep learning (2014)
35. A. Krizhevsky, I. Sutskever, G.E. Hinton, Imagenet classification with deep convolutional
neural networks. Neural Inform. Process. Syst. 25, 01 (2012)
36. M.D. Richard, R.P. Lippmann, Neural network classifiers estimate bayesian a posteriori prob-
abilities. Neural Comput. 3(4), 461–483 (1991)
37. Centers For Medicare & Medicaid Services, Medicare provider utilization and payment data:
physician and other supplier (2018)
38. Centers For Medicare & Medicaid Services, Medicare provider utilization and payment data:
part D prescriber (2018)
39. U.S. Government, U.S. Centers for Medicare & Medicaid Services, The official U.S. govern-
ment site for medicare
40. Evolutionary Computation for Big Data and Big Learning Workshop, Data mining competition
2014: self-deployment track
41. M. Wani, F. Bhat, S. Afzal, A. Khan, Advances in Deep Learning (Springer, 2020)
42. J.W. Tukey, Comparing individual means in the analysis of variance. Biometrics 5(2), 99–114
(1949)
43. R. Anand, K.G. Mehrotra, C.K. Mohan, S. Ranka, An improved algorithm for neural network
classification of imbalanced training sets. IEEE Trans. Neural Netw. 4, 962–969 (1993)
44. J.M. Johnson, T.M. Khoshgoftaar, Survey on deep learning with class imbalance. J. Big Data
6, 27 (2019)
45. J.M. Johnson, T.M. Khoshgoftaar, Medicare fraud detection using neural networks. J. Big Data
6(1), 63 (2019)
46. D. Masko, P. Hensman, The impact of imbalanced training data for convolutional neural
networks (KTH, School of Computer Science and Communication (CSC), 2015)
47. H. Lee, M. Park, J. Kim, Plankton classification on imbalanced large scale database via con-
volutional neural networks with transfer learning, in 2016 IEEE International Conference on
Image Processing (ICIP) (2016), pp. 3713–3717
48. S. Wang, W. Liu, J. Wu, L. Cao, Q. Meng, P. J. Kennedy, Training deep neural networks on
imbalanced data sets, in 2016 International Joint Conference on Neural Networks (IJCNN)
(2016), pp. 4368–4374
49. H. Wang, Z. Cui, Y. Chen, M. Avidan, A. B. Abdallah, A. Kronzer, Predicting hospital read-
mission via cost-sensitive deep learning. IEEE/ACM Trans. Comput. Biol. Bioinf. 1 (2018)
50. S.H. Khan, M. Hayat, M. Bennamoun, F.A. Sohel, R. Togneri, Cost-sensitive learning of deep
feature representations from imbalanced data. IEEE Trans. Neural Netw. Learn. Syst. 29, 3573–
3587 (2018)
51. T.-Y. Lin, P. Goyal, R. B. Girshick, K. He, P. Dollár, Focal loss for dense object detection, 2017
IEEE International Conference on Computer Vision (ICCV) (2017), pp. 2999–3007
226 J. M. Johnson and T. M. Khoshgoftaar
52. C. Huang, Y. Li, C. C. Loy, X. Tang, Learning deep representation for imbalanced classifica-
tion, in 2016 IEEE Conference on Computer Vision and Pattern Recognition (CVPR) (2016),
pp. 5375–5384
53. S. Ando, C.Y. Huang, Deep over-sampling framework for classifying imbalanced data, in
Machine Learning and Knowledge Discovery in Databases, ed. by M. Ceci, J. Hollmén,
L. Todorovski, C. Vens, S. Džeroski (Springer International Publishing, Cham, 2017), pp. 770–
785
54. Q. Dong, S. Gong, X. Zhu, Imbalanced deep learning by minority class incremental rectifica-
tion. IEEE Trans. Pattern Anal. Mach. Intell. 1 (2018)
55. Q. Chen, J. Huang, R. Feris, L.M. Brown, J. Dong, S. Yan, Deep domain adaptation for describ-
ing people based on fine-grained clothing attributes, in 2015 IEEE Conference on Computer
Vision and Pattern Recognition (CVPR) (2015), pp. 5315–5324
56. Y. LeCun, C. Cortes, MNIST handwritten digit database (2010). http://yann.lecun.com/exdb/
mnist/, Accessed 15 Nov 2018
57. A. Krizhevsky, V. Nair, G. Hinton, Cifar-10 (canadian institute for advanced research). http://
www.cs.toronto.edu/kriz/cifar.html
58. R.A. Bauder, T.M. Khoshgoftaar, A novel method for fraudulent medicare claims detection from
expected payment deviations (application paper), in 2016 IEEE 17th International Conference
on Information Reuse and Integration (IRI) (2016), pp. 11–19
59. R.A. Bauder, T.M. Khoshgoftaar, A probabilistic programming approach for outlier detection
in healthcare claims, in 2016 15th IEEE International Conference on Machine Learning and
Applications (ICMLA) (2016), pp. 347–354
60. R.A. Bauder, T.M. Khoshgoftaar, A. Richter, M. Herland, Predicting medical provider spe-
cialties to detect anomalous insurance claims, in 2016 IEEE 28th International Conference on
Tools with Artificial Intelligence (ICTAI) (2016), pp. 784–790
61. M. Herland, R.A. Bauder, T.M. Khoshgoftaar, Medical provider specialty predictions for the
detection of anomalous medicare insurance claims, in 2017 IEEE International Conference on
Information Reuse and Integration (IRI) (2017), pp. 579–588
62. Office of Inspector General, LEIE downloadable databases (2019)
63. R.A. Bauder, T.M. Khoshgoftaar, The detection of medicare fraud using machine learning
methods with excluded provider labels, in FLAIRS Conference (2018)
64. M. Herland, T.M. Khoshgoftaar, R.A. Bauder, Big data fraud detection using multiple medicare
data sources. J. Big Data 5, 29 (2018)
65. M. Herland, R.A. Bauder, T.M. Khoshgoftaar, The effects of class rarity on the evaluation of
supervised healthcare fraud detection models. J. Big Data 6(1), 21 (2019)
66. K. Feldman, N.V. Chawla, Does medical school training relate to practice? evidence from big
data. Big Data (2015)
67. Centers for Medicare & Medicaid Services, Physician compare datasets (2019)
68. J. Ko, H. Chalfin, B. Trock, Z. Feng, E. Humphreys, S.-W. Park, B. Carter, K.D. Frick, M. Han,
Variability in medicare utilization and payment among urologists. Urology 85, 03 (2015)
69. V. Chandola, S.R. Sukumar, J.C. Schryver, Knowledge discovery from massive healthcare
claims data, in KDD (2013)
70. L.K. Branting, F. Reeder, J. Gold, T. Champney, Graph analytics for healthcare fraud risk
estimation, in 2016 IEEE/ACM International Conference on Advances in Social Networks
Analysis and Mining (ASONAM) (2016), pp. 845–851
71. National Plan & Provider Enumeration System, NPPES NPI registry (2019)
72. Protein Structure Prediction Center, 9th community wide experiment on the critical assessment
of techniques for protein structure prediction
73. I. Triguero, S. del Río, V. López, J. Bacardit, J. Benítez, F. Herrera, ROSEFW-RF: the winner
algorithm for the ECBDL'14 big data competition: an extremely imbalanced big data bioinfor-
matics problem. Knowl.-Based Syst. 87 (2015)
74. A. Fernández, S. del Río, N.V. Chawla, F. Herrera, An insight into imbalanced big data classi-
fication: outcomes and challenges. Complex Intell. Syst. 3(2), 105–120 (2017)
Thresholding Strategies for Deep Learning … 227
75. S. del Río, J.M. Benítez, F. Herrera, Analysis of data preprocessing increasing the over-
sampling ratio for extremely imbalanced big data classification, in 2015 IEEE Trust-
com/BigDataSE/ISPA, vol. 2 (2015), pp. 180–185
76. Centers for Medicare & Medicaid Services, National provider identifier standard (NPI) (2019)
77. Centers For Medicare & Medicaid Services, HCPCS general information (2018)
78. P. Di Lena, K. Nagata, P. Baldi, Deep architectures for protein contact map prediction, Bioin-
formatics (Oxford, England) 28, 2449–57 (2012)
79. J. Berg, J. Tymoczko, L. Stryer, Chapter 3, protein structure and function, in Biochemistry, 5th
edn. (W H Freeman, New York, 2002)
80. Z. Zhao, F. Morstatter, S. Sharma, S. Alelyani, A. Anand, H. Liu, Advancing feature selection
research, ASU Feature Selection Repository (2010), pp. 1–28
81. Scientific Linux, About (2014)
82. F. Pedregosa, G. Varoquaux, A. Gramfort, V. Michel, B. Thirion, O. Grisel, M. Blondel, P.
Prettenhofer, R. Weiss, V. Dubourg, J. Vanderplas, A. Passos, D. Cournapeau, M. Brucher,
M. Perrot, E. Duchesnay, Scikit-learn: machine learning in python. J. Mach. Learn. Res. 12,
2825–2830 (2011)
83. F. Provost, T. Fawcett, Analysis and visualization of classifier performance: comparison under
imprecise class and cost distributions, in Proceedings of the Third International Conference
on Knowledge Discovery and Data Mining, vol. 43–48 (1999), p. 12
84. D. Wilson, T. Martinez, The general inefficiency of batch training for gradient descent learning.
Neural Netw.: Off. J. Int. Neural Netw. Soc. 16, 1429–51 (2004)
85. D.P. Kingma, J. Ba, Adam: a method for stochastic optimization. CoRR (2015).
arXiv:abs/1412.6980
86. R.P. Lippmann, Neural networks, bayesian a posteriori probabilities, and pattern classification,
in From Statistics to Neural Networks, ed. by V. Cherkassky, J.H. Friedman, H. Wechsler
(Springer, Berlin, Heidelberg, 1994), pp. 83–104
87. N. Srivastava, G. Hinton, A. Krizhevsky, I. Sutskever, R. Salakhutdinov, Dropout: a simple
way to prevent neural networks from overfitting. J. Mach. Learn. Res. 15, 1929–1958 (2014)
88. S. Ioffe, C. Szegedy, Batch normalization: accelerating deep network training by reducing
internal covariate shift, in Proceedings of the 32Nd International Conference on International
Conference on Machine Learning, ICML’15, vol. 37 (JMLR.org, 2015), pp. 448–456
89. B. Zdaniuk, Ordinary Least-Squares (OLS) Model (Dordrecht, Springer Netherlands, 2014),
pp. 4515–4517
90. J.M. Johnson, T.M. Khoshgoftaar, Deep learning and data sampling with imbalanced big data,
2019 IEEE 20th International Conference on Information Reuse and Integration for Data
Science (IRI) (2019), pp. 175–183
Vehicular Localisation at High and Low
Estimation Rates During GNSS Outages:
A Deep Learning Approach
U. Onyekpe (B)
Research Center for Data Science, Institute for Future Transport and Cities, Coventry University,
Gulson Road, Coventry, UK
e-mail: onyekpeu@uni.coventry.ac.uk
S. Kanarachos
Faculty of Engineering, Coventry University, Gulson Road, Coventry, UK
e-mail: ab8522@coventry.ac.uk
V. Palade
Research Center for Data Science, Coventry University, Gulson Road, Coventry, UK
e-mail: ab5839@coventry.ac.uk
S.-R. G. Christopoulos
Institute for Future Transport and Cities, Faculty of Engineering, Coventry University, Gulson
Road, Coventry, UK
e-mail: ac0966@coventry.ac.uk
© The Editor(s) (if applicable) and The Author(s), under exclusive license 229
to Springer Nature Singapore Pte Ltd. 2021
M. A. Wani et al. (eds.), Deep Learning Applications, Volume 2,
Advances in Intelligent Systems and Computing 1232,
https://doi.org/10.1007/978-981-15-6759-9_10
1 Introduction
It is estimated that the UK's autonomous vehicle market will be worth approximately
£28 billion by 2035 [1]. A major motivation towards the development of
these vehicles is the need to improve road safety. According to [2], 75% of traffic-
related road accidents in the UK are due to human driving errors, rising to 94%
in the United States. The introduction of autonomous vehicles has the potential to
reduce these accidents [3]. Even though these vehicles could introduce new kinds of
accidents, there is the drive to ensure they are as safe as possible.
Sensory systems are key to the performance of autonomous vehicles as they
help the vehicle understand its environment [4]. Examples of sensors found on the
outside of the vehicle include LIDARs, cameras and ultrasonic systems. Several data
processing and analysis systems are also present inside the vehicle, which use the
sensor data to make decisions a human driver would normally make. LIDARs and
cameras are imaging systems used to identify objects, structures, potential collision
hazards and pedestrians in the vehicle’s trajectory [5]. Cameras are furthermore
essential to the identification of road signs and markings on structured roads. As proof
of the crucial role imaging systems play in the operation of the vehicle, a good number
of sophisticated versions of such systems are already employed [6]. Nevertheless,
although imaging systems can be used to assess the vehicle’s environment as well
as determine the position of objects or markings relative to it, there is the need to
continuously and robustly localise a vehicle with reference to a defined co-ordinate
system. Information on how the vehicle navigates through its environment is also
needed such that real-time follow-up decisions can be made.
A GNSS receiver performs complex analysis on the signals received from at least
four of the many satellites orbiting the earth, and it is widely regarded as one of the best options
for position estimation, as it is unmatched in terms of cost and coverage
[7]. Despite the wide acceptance of GNSS, it is far from being a perfect positioning
system. There can be instances of GNSS failures in outdoor environments, as there has
to be a direct line of sight between the satellites and the GNSS antennae. GNSS can
prove difficult to use in metropolitan cities and similar environments characterised
by tall buildings, bridges, tunnels or trees, as its line of sight may be blocked during
signal transmission [7]. Moreover, the GNSS signal can be jammed, leaving the
vehicle with no information about its position [8]. As such, GNSS cannot act as a
standalone navigation system.
The GNSS is used to localise the autonomous vehicle to a road. To achieve lane
accuracy, GNSS is combined with high accuracy LIDARs, cameras, RADAR and
High Definition (HD) maps. There are however times when the camera and LIDAR
could be uninformative or unavailable for use. The accuracy of low-cost LIDARs and
cameras could be compromised when there is heavy fog, snow, sleet or rain [9]. This
is a well-recognised issue in the field. The cost of high-accuracy LIDARs also makes
them an attractive target for theft, as they are worth several thousand pounds. Hence, the
use of LIDARs on autonomous vehicles would make the vehicles more expensive.
Camera-based positioning systems could also face low accuracies depending on the
external light intensity and the objects in the camera’s scene. In level 4 self-driving
applications, as tested by Waymo LLC and Cruise LLC, the LIDAR scan is matched
onto an HD map in real time. Based on this, the system is able to precisely position
the vehicle within its environment [10]. However, this method is computationally
intensive. Furthermore, changes in the driving environment and infrastructure could
make an HD map temporarily outdated and as such not useful for navigation.
[20]. Despite the wide popularity of the Kalman filter, it does possess some draw-
backs. For an INS/GPS-integrated application, the Kalman filter requires stochastic
models to represent the INS errors, but these stochastic models are difficult to determine
for most gyroscopes and accelerometers [13]. Moreover, accurate a priori information
on the covariance matrices of the noises associated with the INS is needed. Furthermore,
the INS/GPS problem is non-linear in nature. As a result, other types of filters have been
studied [20].
The Kalman filter functions in two stages: the prediction stage, which projects the
state estimate and its error covariance forward using the inputs and the process
model, and the update (innovation) stage, which computes the error between the
measurement and the prediction and uses it to correct the predicted state.
In modelling the error between the INS and GNSS position, we model the
prediction and update stages in discrete time as

X_t = A_t X_{t−1} + B_t U_t + w_t    (1)

Z_t = H_t X_t + V_t    (2)

Prediction Stage:

X̂_t^− = A_t X̂_{t−1} + B_t U_t    (3)

P_t^− = A_t P_{t−1} A_t^T + Q_t    (4)

Innovation Stage: This stage involves the computation of the errors between the
predicted states and the measurements.

K_t = P_t^− H_t^T (H_t P_t^− H_t^T + R_t)^{−1}    (5)

X̂_t = X̂_t^− + K_t (Z_t − H_t X̂_t^−)    (6)

P_t = (I − K_t H_t) P_t^−    (7)
where X_t is the error state vector, U_t is the input/control vector, w_t is the system
noise (Gaussian), V_t is the measurement noise (Gaussian), Z_t is the measurement
vector, A_t is the state transition matrix, B_t is the system noise coefficient matrix, H_t
relates the state to the measurement, Q_t is the system noise covariance matrix, R_t is
the measurement noise covariance matrix, X̂_t^− is the a priori state estimate, X̂_t is
the a posteriori state estimate, K_t is the Kalman gain matrix, P_t^− is the a priori error
covariance and P_t is the a posteriori error covariance.
More information on the Kalman filter can be found in [13, 21].
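The two-stage cycle described above can be sketched in a few lines. This is a generic illustration of the discrete-time Kalman recursion using the symbols defined in the text, not the chapter's actual INS/GNSS filter:

```python
import numpy as np

def kalman_step(x_post, P_post, u, z, A, B, H, Q, R):
    """One predict/innovate cycle of a discrete-time Kalman filter.

    x_post, P_post : a posteriori state estimate and error covariance at t-1
    u, z           : control input and measurement at t
    A, B, H, Q, R  : state-transition, control, measurement, process-noise
                     and measurement-noise matrices
    """
    # Prediction stage: project the state and covariance forward.
    x_prior = A @ x_post + B @ u
    P_prior = A @ P_post @ A.T + Q

    # Innovation stage: correct the prediction with the measurement.
    K = P_prior @ H.T @ np.linalg.inv(H @ P_prior @ H.T + R)
    x_new = x_prior + K @ (z - H @ x_prior)
    P_new = (np.eye(len(x_post)) - K @ H) @ P_prior
    return x_new, P_new
```

Iterating this step over a stream of measurements drives the error covariance down as the filter converges on the underlying state.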
y = G(wx + b)    (8)

where y is the layer output vector, x is the input vector, w is the weight matrix,
b is the bias and G is the activation function.
Deep learning algorithms employ multiple layers to extract features
progressively from the raw input through multiple levels of non-linear abstraction.
They try to find good representations of the input–output relationship by exploiting
the structure of the input distribution [23]. The success of deep learning algorithms on
several challenging problems has been demonstrated in many published works,
for example, [24–27].
Some popular deep learning architectures used today, especially in computer
vision applications, are ResNet [28] and Yolo [29], which show good performances in
visual recognition tasks. Aside from deep learning's success in image processing
tasks, sequential data such as audio and texts are now processed with deep neural
networks, which are able to achieve state-of-the-art performances in speech recog-
nition and natural language processing tasks [30]. These successes lend support to
the potential of deep learning algorithms to learn complex abstractions present in the
noisy sensor signals in the application under study in this chapter.
The position errors of the INS are cumulative and follow a certain pattern [25]. There-
fore, previous positional sequences are required for the model to capture the error
trend. It is however difficult to utilise a static neural network to model this pattern. A
dynamic model can be employed by using an architecture that presents the previous
‘t’ values of the signal as inputs to the network, thus capturing the error trend present
in the previous t timesteps [17]. The model can also be trained to learn time-varying or
sequential trends through the introduction of memory and associator units to the input
layer. The memory unit of the dynamic model can store previous INS samples and
forecast using the associator unit. The use of such a dynamic model has a significant
influence on the accuracy of the INS position prediction in the absence of GPS [17].
Figure 1 illustrates an IDNN’s general architecture, with p being the tapped delay
line memory length, Ui the hidden layer neurons, W the weights, G the activation
function, Y the target vector and D the delay operator.
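The tapped-delay input of the IDNN can be illustrated with a small helper (a hypothetical sketch; `tapped_delay_inputs` is not from the chapter's code) that turns a sampled signal into rows of the p most recent values:

```python
import numpy as np

def tapped_delay_inputs(signal, p):
    """Build IDNN-style input rows from a 1-D signal: each row holds the
    p most recent samples [s[t-p+1], ..., s[t]] feeding the network at t."""
    rows = [signal[t - p + 1 : t + 1] for t in range(p - 1, len(signal))]
    return np.array(rows)

x = np.arange(6.0)              # toy stand-in for an INS sample stream
X = tapped_delay_inputs(x, p=3) # one row of 3 delayed samples per timestep
```

Each row of `X` is what the delay operator D presents to the hidden layer at one timestep, so the static feed-forward network sees a sliding window of the signal's recent history.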
Recurrent Neural Networks (RNNs) have been proven to learn useful features
on sequential problems. Long Short-Term Memory (LSTM) networks are a variant
of the RNN created to tackle its shortfalls. They are specifically designed to solve the
long-term dependency problem, enabling them to recall information over long
periods. Due to the cumulative and patterned nature of the INS positional errors,
the LSTM can be used to learn error patterns from previous sequences to provide
a better position estimation. The operation of the LSTM is regulated by gates (the
forget, input and output gates) and proceeds as shown below:

f_t = σ(W_f [h_{t−1}, x_t] + b_f)    (9)

i_t = σ(W_i [h_{t−1}, x_t] + b_i)    (10)

ĉ_t = tanh(W_c [h_{t−1}, x_t] + b_c)    (11)

c_t = f_t ∗ c_{t−1} + i_t ∗ ĉ_t    (12)

o_t = σ(W_o [h_{t−1}, x_t] + b_o)    (13)

h_t = o_t ∗ tanh(c_t)    (14)
where ∗ is the Hadamard product, f_t is the forget gate, i_t is the input gate, o_t is the
output gate, c_t is the cell state, h_{t−1} is the previous hidden state, W is the weight
matrix and σ is the sigmoid activation (non-linearity) function. Figures 2 and 3 show
the RNN's architecture and the LSTM's cell structure.
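The gate equations can be exercised directly. The following is a minimal NumPy sketch of a single LSTM cell step with randomly initialised (hypothetical) weights, not the trained model used later in the chapter:

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def lstm_step(h_prev, c_prev, x_t, W, b):
    """One LSTM cell step. W and b hold the forget/input/candidate/output
    parameters, each weight acting on the concatenation [h_{t-1}, x_t]."""
    hx = np.concatenate([h_prev, x_t])
    f_t = sigmoid(W["f"] @ hx + b["f"])       # forget gate
    i_t = sigmoid(W["i"] @ hx + b["i"])       # input gate
    c_hat = np.tanh(W["c"] @ hx + b["c"])     # candidate cell state
    c_t = f_t * c_prev + i_t * c_hat          # cell state update
    o_t = sigmoid(W["o"] @ hx + b["o"])       # output gate
    h_t = o_t * np.tanh(c_t)                  # hidden state
    return h_t, c_t

rng = np.random.default_rng(0)
n_h, n_x = 4, 3
W = {k: rng.standard_normal((n_h, n_h + n_x)) * 0.1 for k in "fico"}
b = {k: np.zeros(n_h) for k in "fico"}
h, c = np.zeros(n_h), np.zeros(n_h)
for x_t in rng.standard_normal((5, n_x)):     # run over a short sequence
    h, c = lstm_step(h, c, x_t, W, b)
```

The cell state c_t carries information across timesteps, which is what lets the network remember the slowly evolving INS error trend.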
3 Problem Formulation
In this section, we discuss the formulation of the learning task using the four
modelling techniques introduced in Sect. 2. Furthermore, we present the experiments
and comparative results on the inertial tracking problem in Sect. 4.
Vehicular tracking using inertial sensors is governed by the Newtonian laws of
motion. Given an initial pose, the attitude rate from the gyroscope can be integrated
to provide continuous orientation information. The acceleration of the vehicle can
also be integrated to determine its velocity and, provided an initial velocity, the
vehicle’s displacement can be estimated through the integration of its velocity.
Several co-ordinate systems are used in vehicle tracking with the positioning
estimation expressed relative to a reference. Usually, measurements in the sensors’
co-ordinate system (body frame) would need to be transformed to the navigation
frame [31]. The body frame has its axis coincident to the sensors’ input axis.
The navigation frame, also known as the local frame, has its origin at the origin of the
sensor frame. Its x-axis points towards the geodetic north, with the z-axis orthogonal
to the ellipsoidal plane and the y-axis completing the orthogonal frame. The rotation
matrix from the body frame to the navigation frame is expressed in Eq. (15) [20].
R_b^n = ⎡ cos θ cos Ψ    −cos φ sin Ψ + sin φ sin θ cos Ψ     sin φ sin Ψ + cos φ sin θ cos Ψ ⎤
        ⎢ cos θ sin Ψ     cos φ cos Ψ + sin φ sin θ sin Ψ    −sin φ cos Ψ + cos φ sin θ sin Ψ ⎥    (15)
        ⎣ −sin θ          sin φ cos θ                         cos φ cos θ ⎦
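As a quick sanity check on the rotation of Eq. (15), the matrix can be constructed for given Euler angles and verified to be orthonormal; `R_nb` is an illustrative helper with hypothetical angle values, assuming the standard ZYX (yaw-pitch-roll) convention:

```python
import numpy as np

def R_nb(roll, pitch, yaw):
    """Body-to-navigation rotation matrix of Eq. (15) for roll (phi),
    pitch (theta) and yaw (psi), all in radians."""
    sp, cp = np.sin(roll), np.cos(roll)
    st, ct = np.sin(pitch), np.cos(pitch)
    sy, cy = np.sin(yaw), np.cos(yaw)
    return np.array([
        [ct * cy, -cp * sy + sp * st * cy,  sp * sy + cp * st * cy],
        [ct * sy,  cp * cy + sp * st * sy, -sp * cy + cp * st * sy],
        [-st,      sp * ct,                 cp * ct],
    ])

R = R_nb(0.01, -0.02, np.deg2rad(45.0))     # hypothetical attitude
assert np.allclose(R.T @ R, np.eye(3))      # rotation matrices are orthonormal
```

Orthonormality (and unit determinant) is a useful unit test here, since a single sign error in one entry breaks both properties.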
the navigation frame [31]. In this application, the Coriolis acceleration is considered
negligible due to its small magnitude compared to the accelerometers' measurements,
and the centrifugal acceleration is considered to be absorbed in the local gravity
vector. Thus, a^n = a_nb^n.
The gyroscope measures the attitude change in roll, yaw and pitch. It measures
the angular velocity of the body frame (the vehicle's frame) with respect to the inertial
frame, as expressed in the body frame. Represented by ω_ib^b, the attitude rate can be
expressed as

ω_ib^b = R_n^b (ω_ie^n + ω_en^n) + ω_nb^b    (19)
where ω_ie^n is the angular velocity of the earth frame with respect to the inertial frame,
estimated to be approximately 7.29 × 10^−5 rad/s [31]. The navigation frame is
defined as stationary with respect to the earth, thus ω_en^n = 0, with the angular velocity of
interest, ω_nb^b, representing the rate of rotation of the vehicle in the navigation frame.
Ψ_INS = Ψ_0 + ∫_{t−1}^{t} ω_nb^b dt    (20)
The accelerometer measurement F_INS^b is modelled as

F_INS^b = f_INS^b + δ_INS^b + ε_a^b    (21)
Furthermore, the accelerometers' noise is typically quite Gaussian¹ and can be
modelled as ε_a^b ∼ N(0, Σ_a). The accelerometers' bias is slowly time varying and,
as such, is treated as a constant here. The vehicle's acceleration comprises the
measured specific force and gravity:

a^b = f^b + g^b    (22)

F_INS^b = a_INS^b + δ_INS,a^b + ε_a^b    (23)
a^b = F_INS^b − δ_INS,a^b − ε_a^b    (24)

F_INS^b − δ_INS,a^b = a^b + ε_a^b    (25)

However, a_INS^b = F_INS^b − δ_INS,a^b    (26)

Therefore, a_INS^b = a^b + ε_a^b    (27)
Through the integration of Eq. (27), the velocity of the vehicle in the body frame
can be determined:

v_INS^b = ∫_{t−1}^{t} a^b dt + ε_v^b    (28)

The displacement of the sensor in the body frame at time t from t − 1, x_INS^b,
can also be estimated through the double integration of Eq. (27), provided an initial
velocity:

x_INS^b = ∬_{t−1}^{t} a^b dt + ε_x^b    (29)
where δ_INS,a^b is the bias in the body frame, calculated as a constant parameter by
computing the average reading from a stationary accelerometer run for 20 min;
F_INS^b is the corrupted measurement provided directly by the accelerometer sensor at
time t (the sampling time); g is the gravity vector; and a^b, ∫_{t−1}^{t} a^b dt and
∬_{t−1}^{t} a^b dt are the true (uncorrupted) longitudinal acceleration, velocity and
displacement, respectively, of the vehicle. The true displacement of the vehicle is
thus expressed as x_GPS^b ≈ ∬_{t−1}^{t} a^b dt.

¹ The vehicle's dynamics are non-linear, especially when cornering or braking hard; thus, a linear or
inaccurate noise model would not sufficiently capture the non-linear relationship.
Furthermore, ε_x^b can be derived by

ε_x^b ≈ x_GPS^b − x_INS^b    (30)

The body-frame velocity and displacement are resolved into their northwards and
eastwards components in the navigation frame as

from R_INS^nb · v_INS^b → v_INS^n → (v_INS^b · cos Ψ_INS, v_INS^b · sin Ψ_INS)    (32)

from R_INS^nb · x_INS^b → x_INS^n → (x_INS^b · cos Ψ_INS, x_INS^b · sin Ψ_INS)    (33)

where:

R_INS^nb = ⎡ cos Ψ_INS   −sin Ψ_INS   0 ⎤
           ⎢ sin Ψ_INS    cos Ψ_INS   0 ⎥    (35)
           ⎣ 0            0           1 ⎦
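The dead-reckoning chain just described (integrate acceleration to velocity and displacement, then resolve along the heading) can be sketched numerically; the acceleration profile and heading below are hypothetical stand-ins, not values from the benchmark data:

```python
import numpy as np

dt = 0.1                                   # 10 Hz sampling, as in the data set
t = np.arange(0, 10, dt)
a_body = 0.5 * np.ones_like(t)             # hypothetical constant 0.5 m/s^2
psi = np.deg2rad(30.0)                     # hypothetical fixed heading

# Single and double integration of the body-frame acceleration (zero
# initial velocity), mirroring the velocity and displacement integrals.
v_body = np.cumsum(a_body) * dt
x_body = np.cumsum(v_body) * dt

# Resolve the body-frame displacement into north/east components via
# the heading, as in the planar rotation above.
x_north = x_body[-1] * np.cos(psi)
x_east = x_body[-1] * np.sin(psi)
```

With a noisy, biased accelerometer in place of the clean `a_body`, the double integration accumulates the displacement error ε_x^b that the neural networks are later trained to predict.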
This section presents how to estimate the vehicle's true displacement x_GPS^b, which
is useful in the determination of the target error ε_x^b, as detailed in Eq. (30).
In estimating the distance travelled between two points on the earth’s surface, it
becomes obvious that the shape of the earth is neither a perfect sphere nor ellipse, but
rather that of an oblate ellipsoid. Due to the unique shape of the earth, complications
exist as there is no geometric shape it can be categorised under for analysis. The
Haversine formula applies to the calculation of distances on spherical
shapes, while Vincenty's formula applies to ellipsoidal shapes [32].
The Haversine formula is used to calculate the distance between two points on the
earth's surface specified by longitude and latitude. It assumes a spherical earth [33].

x_GPS^b = 2r sin^{−1}(√(sin²((∅_t − ∅_{t−1})/2) + cos(∅_{t−1}) cos(∅_t) sin²((ϕ_t − ϕ_{t−1})/2)))    (36)

where x_GPS^b is the distance travelled within (t − 1, t) with longitude and latitude (ϕ, ∅)
as obtained from the GPS, and r is the radius of the earth.
Vincenty's formula is used to calculate the distance between two points on the
earth's surface specified by longitude and latitude. It assumes an ellipsoidal earth
[34]. The distance between two points is calculated as shown in Eqs. (37)–(55).

Given: f = 1/298.257223563    (37)

b = (1 − f) a    (38)

Δϕ = ϕ_2 − ϕ_1    (40)

sin σ = √((cos U_2 sin λ)² + (cos U_1 sin U_2 − sin U_1 cos U_2 cos λ)²)    (42)

C = (f/16) cos²α [4 + f (4 − 3 cos²α)]    (47)

λ = Δϕ + (1 − C) f sin α {σ + C sin σ [cos(2σ_m) + C cos σ (−1 + 2 cos²(2σ_m))]}    (48)

u² = cos²α (a² − b²)/b²    (49)

k_1 = (√(1 + u²) − 1)/(√(1 + u²) + 1)    (50)

A = (1 + ¼ k_1²)/(1 − k_1)    (51)

B = k_1 (1 − (3/8) k_1²)    (52)

Δσ = B sin σ {cos(2σ_m) + ¼ B [cos σ (−1 + 2 cos²(2σ_m)) − (B/6) cos(2σ_m)(−3 + 4 sin²σ)(−3 + 4 cos²(2σ_m))]}    (53)
where x_GPS^b is the distance travelled between t − 1 and t with longitude and latitude
(ϕ, ∅); a is the radius of the earth at the equator; f is the flattening of the ellipsoid; b
is the length of the ellipsoid's semi-minor axis; U_1 and U_2 are the reduced latitudes at t
and t − 1, respectively; λ is the change in longitude along the auxiliary sphere; s is
the ellipsoidal distance between the positions at t − 1 and t; σ_1 is the angular separation
between the equator and the position at t − 1; σ is the angular separation between the
positions at t − 1 and t; and σ_m is the angle between the equator and the midpoint of the line.
Vincenty's formula is used in this work, as it provides a more accurate
solution compared to the Haversine and other great-circle formulas [32]. The Python
implementation of Vincenty's inverse formula is used here [35].
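A self-contained sketch of Vincenty's iterative inverse formula, using the WGS-84 constants of Eqs. (37) and (38); note that the chapter uses the `vincenty` PyPI package [35], so this stand-alone implementation is for illustration only:

```python
import math

def vincenty_inverse(lat1, lon1, lat2, lon2):
    """Geodesic distance (m) on the WGS-84 ellipsoid via Vincenty's
    iterative inverse formula; inputs in degrees."""
    a, f = 6378137.0, 1 / 298.257223563      # equatorial radius, flattening
    b = (1 - f) * a                          # semi-minor axis
    U1 = math.atan((1 - f) * math.tan(math.radians(lat1)))  # reduced latitudes
    U2 = math.atan((1 - f) * math.tan(math.radians(lat2)))
    L = math.radians(lon2 - lon1)            # longitude difference
    lam = L
    for _ in range(200):                     # iterate lambda to convergence
        sin_s = math.sqrt((math.cos(U2) * math.sin(lam)) ** 2 +
                          (math.cos(U1) * math.sin(U2) -
                           math.sin(U1) * math.cos(U2) * math.cos(lam)) ** 2)
        if sin_s == 0:
            return 0.0                       # coincident points
        cos_s = (math.sin(U1) * math.sin(U2) +
                 math.cos(U1) * math.cos(U2) * math.cos(lam))
        sigma = math.atan2(sin_s, cos_s)
        sin_a = math.cos(U1) * math.cos(U2) * math.sin(lam) / sin_s
        cos2_a = 1 - sin_a ** 2
        cos_2sm = (cos_s - 2 * math.sin(U1) * math.sin(U2) / cos2_a
                   if cos2_a else 0.0)       # equatorial-line special case
        C = f / 16 * cos2_a * (4 + f * (4 - 3 * cos2_a))
        lam_prev = lam
        lam = L + (1 - C) * f * sin_a * (
            sigma + C * sin_s * (cos_2sm + C * cos_s * (-1 + 2 * cos_2sm ** 2)))
        if abs(lam - lam_prev) < 1e-12:
            break
    u2 = cos2_a * (a ** 2 - b ** 2) / b ** 2
    k1 = (math.sqrt(1 + u2) - 1) / (math.sqrt(1 + u2) + 1)
    A = (1 + k1 ** 2 / 4) / (1 - k1)
    B = k1 * (1 - 3 / 8 * k1 ** 2)
    d_sigma = B * sin_s * (cos_2sm + B / 4 * (
        cos_s * (-1 + 2 * cos_2sm ** 2) -
        B / 6 * cos_2sm * (-3 + 4 * sin_s ** 2) * (-3 + 4 * cos_2sm ** 2)))
    return b * A * (sigma - d_sigma)         # s = bA(sigma - delta sigma)
```

Along the equator the iteration collapses to the simple arc length a·ΔL, which makes a convenient sanity check against the Haversine value.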
Fig. 4 Learning scheme for the northwards and eastwards displacement error prediction

The neural networks introduced in Sect. 2 are exploited to learn the relationship
between the input features (displacement x_INS^n, velocity v_INS^n and acceleration a_INS^n)
and the target displacement error ε_x^n (as presented in Sects. 3.2 and 3.3) in the north-
wards and eastwards directions, as shown in Fig. 4. The predicted displacement error is
used to correct the INS-derived displacement to provide a better positioning solution.
The data used is the V-Vw12 aggressive driving vehicle benchmark data set
describing about 107 s of an approximate straight-line trajectory, of which the first
105 s is used for our analysis [36]. The sensors are inbuilt and the data is captured
from the vehicle’s ECU using the Vbox video H2 data acquisition sampling at a
frequency of 10 Hz. The longitudinal acceleration of the vehicle as well as its rate of
rotation about the z-axis (yaw rate), heading (yaw) and the GPS co-ordinates (lati-
tude and longitude) of the vehicle at each time instance is captured. Figure 5 shows
the vehicle used for the data collection and the location of its sensors.
4.2 Training
The training is done using the first 75 s of the data, and the model is then tested on the
next 10 s as well as on the 30 s after the training data. The Keras–TensorFlow framework
was used for training, with a mean absolute error loss function and an
Adamax optimiser with a learning rate of 0.09. 5% of the units were dropped from
the hidden layers of the IDNN and MLNN and from the recurrent layers of the LSTM
to prevent the neural networks from overfitting [37]. Furthermore, all features fed to
the models were scaled to values between 0 and 100 to avoid biased learning. Forty
models were trained for each deep learning architecture, and the model providing the
smallest position errors was selected. Parameters defining the training of the neural
network models are highlighted in Table 1. The general architecture of the models is
shown in Fig. 4. The objective of the training exercise is to teach the neural network to
learn the positioning error between the low-cost INS and GPS.
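The preprocessing just described can be sketched as follows; the stand-in signal is synthetic, and fitting the scaling statistics on the training window only is an assumption, as the chapter does not state how the scaler was fitted:

```python
import numpy as np

fs = 10                                    # 10 Hz sampling rate
rng = np.random.default_rng(1)
signal = rng.standard_normal(105 * fs)     # synthetic stand-in for one feature

# Chronological split: first 75 s for training, next 10 s for testing.
train, test = signal[: 75 * fs], signal[75 * fs : 85 * fs]

# Scale the feature to the range [0, 100] using training-set statistics
# only, so no test information leaks into the model.
lo, hi = train.min(), train.max()
train_s = 100 * (train - lo) / (hi - lo)
test_s = 100 * (test - lo) / (hi - lo)
```

Because the split is chronological rather than random, the test windows genuinely simulate an unseen GPS outage following the training period.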
4.3 Testing
GPS outages were assumed on the 10 s and 30 s of data following the training
data to help analyse the performance of the prediction models. With information
on the vehicle's orientation and velocity at the last known GPS signal before the
outage, the northwards and eastwards components of the vehicle's displacement
x_INS,t^n, velocity v_INS,t^n and acceleration a_INS,t^n are fed into the respective models
to predict the north and east components of the displacement error ε_x,t^n.
To evaluate the performance of the LSTM, IDNN, MLNN and Kalman filter
techniques, two GPS outage scenarios are explored: 10 s and 30 s.
The performance of the LSTM, IDNN, MLNN and Kalman filter solutions during
the 10 s outage is studied. From Table 2, it can be seen that at all sampling periods
the LSTM algorithm performed best at estimating the positioning error, followed
closely by the Kalman filter and the IDNN. The MLNN performs worst in
comparison. Comparing all sampling periods, the LSTM method produces the best
error estimation of 0.60 m, at a sampling period of 0.7 s, over about 233 m of travel
within the 10 s studied.
A study of the LSTM, IDNN, MLNN and Kalman Filter performances during the
30 s GPS outage scenario reveals that much unlike the 10 s experiment, the Kalman
filter performs poorly in comparison to the LSTM and IDNN approaches. The LSTM
performs the best at all sampling frequencies, with the Kalman filter outperforming
the MLNN. A comparison of the sampling periods shows that the LSTM approach
provides the best error estimation of 4.86 m at 10 Hz, over about 724 m of travel
within the 30 s investigated (Table 3).
5 Conclusions
References
1. I. Dowd, The future of autonomous vehicles, Open Access Government (2019). https://www.
openaccessgovernment.org/future-of-autonomous-vehicles/57772/, Accessed 04 June 2019
2. P. Liu, R. Yang, Z. Xu, How safe is safe enough for self-driving vehicles? Risk Anal. 39(2),
315–325 (2019)
3. A. Papadoulis, M. Quddus, M. Imprialou, Evaluating the safety impact of connected and
autonomous vehicles on motorways. Accid. Anal. Prev. 124, 12–22 (2019)
4. S.-J. Babak, S.A. Hussain, B. Karakas, S. Cetin, Control of autonomous ground vehicles:
a brief technical review—IOPscience (2017). https://iopscience.iop.org/article/10.1088/1757-
27. W. Fang et al., A LSTM algorithm estimating pseudo measurements for aiding INS during
GNSS signal outages. Remote Sens. 12(2), 256 (2020)
28. K. He, X. Zhang, S. Ren, J. Sun, Deep residual learning for image recognition, in Proceedings
of the IEEE Computer Society Conference on Computer Vision and Pattern Recognition, vol.
2016, pp. 770–778 (2016)
29. J. Redmon, S. Divvala, R. Girshick, A. Farhadi, You only look once: unified, real-time object
detection (2016). arXiv preprint arXiv:1506.02640
30. H. Ismail Fawaz, G. Forestier, J. Weber, L. Idoumghar, P.A. Muller, Deep learning for time
series classification: a review. Data Min. Knowl. Discov. 33(4), 917–963 (2019)
31. M. Kok, J.D. Hol, T.B. Schön, Using inertial sensors for position and orientation estimation.
Found. Trends Signal Process. 11(2), 1–153 (2017)
32. H. Mahmoud, N. Akkari, Shortest path calculation: a comparative study for location-based
recommender system, in Proceedings—2016 World Symposium on Computer Applications
and Research, WSCAR 2016, pp. 1–5 (2016)
33. C.M. Thomas, W.E. Featherstone, Validation of Vincenty’s formulas for the geodesic using a
new fourth-order extension of Kivioja’s formula. J. Surv. Eng. 131(1), 20–26 (2005)
34. T. Vincenty, Direct and inverse solutions of geodesics on the ellipsoid with application of nested
equations. Surv. Rev. 23(176), 88–93 (1975)
35. vincenty PyPI. https://pypi.org/project/vincenty/. Accessed 08 May 2020
36. U. Onyekpe, V. Palade, S. Kanarachos, A. Szkolnik, IO-VNBD: inertial and odometry
benchmark dataset for ground vehicle positioning (2020). arXiv preprint arXiv:2005.01701
37. Y. Gal, Z. Ghahramani, A theoretically grounded application of dropout in recurrent neural
networks (2016). arXiv preprint arXiv:1512.05287
Multi-Adversarial Variational
Autoencoder Nets for Simultaneous
Image Generation and Classification
1 Introduction
Training deep neural networks usually requires copious data, yet obtaining large,
accurately labeled datasets for image classification and other tasks remains a fun-
damental challenge [36]. Although there has been explosive progress in the produc-
tion of vast quantities of high resolution images, large collections of labeled data
required for supervised learning remain scarce. Especially in domains such as med-
ical imaging, datasets are often limited in size due to privacy issues, and annotation
by medical experts is expensive, time-consuming, and prone to human subjectivity,
inconsistency, and error. Even when large labeled datasets become available, they
are often highly imbalanced and non-uniformly distributed. In an imbalanced medical
dataset, common medical problems are over-represented and rarer conditions are
under-represented. Such biases make it very challenging to train neural networks
that perform consistently well across multiple classes.
Fig. 1 Image generation based on the CIFAR-10 dataset [19]: a Relatively good images generated
by a GAN. b Blurry images generated by a VAE. Based on the SVHN dataset [24]: c mode-collapsed
images generated by a GAN
The small-training-data problem is traditionally mitigated through simplistic and
cumbersome data augmentation, often by creating new training examples through
translation, rotation, flipping, etc. The missing or mismatched label problem may
be addressed by evaluating similarity measures over the training examples. This is
not always robust and its effectiveness depends largely on the performance of the
similarity measuring algorithms.
With the advent of deep generative models such as Variational AutoEncoders
(VAEs) [18] and Generative Adversarial Networks (GANs) [9], the ability to learn
underlying data distributions from training samples has become practical in common
scenarios where there is an abundance of unlabeled data. With minimal annotation,
efficient semi-supervised learning could be the preferred approach [16]. More specif-
ically, based on small quantities of annotation, realistic new training images may be
generated by models that have learned real-world data distributions (Fig. 1a). Both
VAEs and GANs may be employed for this purpose.
VAEs can learn dimensionality-reduced representations of training data and, with
an explicit density estimation, can generate new samples. Although VAEs can per-
form fast variational inference, VAE-generated samples are usually blurry (Fig. 1b).
On the other hand, despite their successes in generating images and semi-supervised
classifications, GAN frameworks remain difficult to train and there are challenges
in using GAN models, such as non-convergence due to unstable training, dimin-
ished gradient issues, overfitting, sensitivity to hyper-parameters, and mode collapsed
image generation (Fig. 1c).
Despite the recent progress in high-quality image generation with GANs and
VAEs, accuracy and image quality are usually not ensured by the same model, espe-
cially in multiclass image classification tasks. To tackle this shortcoming, we propose
Multi-Adversarial Variational Autoencoder Nets for Simultaneous … 251
a novel method that can simultaneously learn image generation and multiclass image
classification. Specifically, our work makes the following contributions:
1. The Multi-Adversarial Variational autoEncoder Network, or MAVEN, a novel
multiclass image classification model incorporating an ensemble of discriminators
in a combined VAE-GAN network. An ensemble layer combines the feedback
from multiple discriminators at the end of each batch. With the inclusion of
ensemble learning at the end of a VAE-GAN, both generated image quality and
classification accuracy are improved simultaneously.
2. A simplified version of the Descriptive Distribution Distance (DDD) [14] for eval-
uating generative models, which better represents the distribution of the generated
data and measures its closeness to the real data.
3. Extensive experimental results utilizing two computer vision and two medical
imaging datasets.1 These confirm that our MAVEN model improves upon the
simultaneous image generation and classification performance of a GAN and of
a VAE-GAN with the same set of hyper-parameters.
2 Related Work
Several techniques have been proposed to stabilize GAN training and avoid mode
collapse. Nguyen et al. [26] proposed a model where a single generator is used
alongside dual discriminators. Durugkar et al. [7] proposed a model with a single
generator and feedback aggregated over several discriminators, considering either
the average loss over all discriminators or only the discriminator with the maximum
loss in relation to the generator’s output. Neyshabur et al. [25] proposed a framework
in which a single generator simultaneously trains against an array of discrimina-
tors, each of which operates on a different low-dimensional projection of the data.
Mordido et al. [23], arguing that all the previous approaches restrict the discriminator's
architecture, thereby compromising extensibility, proposed the Dropout-GAN,
where a single generator is trained against a dynamically changing ensemble of
discriminators. However, there is a risk of dropping out all the discriminators. Fea-
ture matching and minibatch discrimination techniques have been proposed [32] for
eliminating mode collapse and preventing overfitting in GAN training.
Realistic image generation helps address problems due to the scarcity of labeled
data. Various architectures of GANs and their variants have been applied in ongoing
efforts to improve the accuracy and effectiveness of image classification. The GAN
framework has been utilized as a generic approach to generating realistic train-
ing images that synthetically augment datasets in order to combat overfitting; e.g.,
for synthetic data augmentation in liver lesions [8], retinal fundi [10], histopathol-
ogy [13], and chest X-rays [16, 31]. Calimeri et al. [3] employed a LAPGAN [6] and
Han et al. [11] used a WGAN [1] to generate synthetic brain MR images. Bermudez
1 This chapter significantly expands upon our ICMLA 2019 publication [15], which excluded our
experiments on medical imaging datasets.
252 A.-A.-Z. Imran and D. Terzopoulos
Figure 2 illustrates the models that serve as precursors to our MAVEN architecture.
The VAE is an explicit generative model that uses two neural nets, an encoder
E and decoder D . Network E learns an efficient compression of real data x into a
lower dimensional latent representation space z(x); i.e., qλ (z|x). With neural network
likelihoods, computing the gradient becomes intractable; however, via differentiable,
non-centered re-parameterization, sampling is performed from an approximate func-
tion qλ (z|x) = N (z; μλ , σλ2 ), where z = μλ + σλ ε̂ with ε̂ ∼ N (0, 1). Encoder E
yields μ and σ , and with the re-parameterization trick, z is sampled from a Gaus-
sian distribution. Then, with D , new samples are generated or real data samples
are reconstructed; i.e., D provides parameters for the real data distribution pφ(x|z).
Subsequently, a sample drawn from pφ(x|z) may be used to reconstruct the real data
by marginalizing out z.
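The non-centered re-parameterization described above can be sketched in a few lines of NumPy; here mu and log_var stand for the encoder outputs μ and log σ² (illustrative names, not the chapter's implementation):

```python
import numpy as np

def reparameterize(mu, log_var, rng=np.random.default_rng(0)):
    """Sample z = mu + sigma * eps with eps ~ N(0, 1); isolating the
    randomness in eps keeps the sampling step differentiable in mu, sigma."""
    sigma = np.exp(0.5 * log_var)          # log-variance -> standard deviation
    eps = rng.standard_normal(np.shape(mu))  # non-centered noise sample
    return mu + sigma * eps

# A batch of 4 samples in a 2-dimensional latent space
z = reparameterize(np.zeros((4, 2)), np.zeros((4, 2)))
print(z.shape)  # (4, 2)
```

Because the noise enters only through eps, gradients flow through μ and σ, which is what makes variational inference trainable by backpropagation.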
The GAN is an implicit generative model where a generator G and a discriminator
D compete in a minimax game over the training data in order to improve their
performance.
Fig. 2 Our MAVEN architecture compared to those of the VAE, GAN, and VAE-GAN. In the
MAVEN, inputs to D can be real data X, or generated data X̂ or X̃. An ensemble combines the
feedback from the discriminators before it is passed to the generator
G takes a noise sample z ∼ pg (z) and learns to map it into image space as if it
comes from the original data distribution pdata (x), while D takes as input either real
image data or generated image data and provides feedback to G as to whether that
input is real or generated. On the one hand, D wants to maximize the likelihood for
real samples and minimize the likelihood of generated samples; on the other hand,
G wants D to maximize the likelihood of generated samples. A Nash equilibrium
results when D can no longer distinguish real and generated samples, meaning that
the model distribution matches the data distribution.
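As a numerical illustration of this game (a sketch, not the chapter's code), the value E[log D(x)] + E[log(1 − D(G(z)))] can be evaluated from the discriminator's scores on real and generated batches; at the equilibrium just described, where D outputs 1/2 everywhere, it equals log(1/4):

```python
import numpy as np

def gan_value(d_real, d_fake, eps=1e-12):
    """Value of the minimax game V(D, G): D ascends it, while G descends
    the second term. d_real, d_fake are D's scores in (0, 1)."""
    d_real = np.asarray(d_real, dtype=float)
    d_fake = np.asarray(d_fake, dtype=float)
    return np.mean(np.log(d_real + eps)) + np.mean(np.log(1.0 - d_fake + eps))

# At the Nash equilibrium D outputs 1/2 everywhere, so V = log(1/4)
print(round(gan_value([0.5, 0.5], [0.5, 0.5]), 4))  # -1.3863
```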
Makhzani et al. [21] proposed the adversarial training of VAEs; i.e., VAE-GANs.
Although they kept both D and G, one can merge these networks since both can
generate data samples from the noise samples of the representation z. In this case, D
receives real data samples x and generated samples x̃ or x̂ via G. Although G and D
compete against each other, the feedback from D eventually becomes predictable for
G and it keeps generating samples from the same class, at which point the generated
samples lack heterogeneity. Figure 1c shows an example where all the generated
images are of the same class. Durugkar et al. [7] proposed that using multiple dis-
criminators in a GAN model helps improve performance, especially for resolving
this mode collapse. Moreover, a dynamic ensemble of multiple discriminators has
recently been proposed to address the issue [23] (Fig. 3).
As in a VAE-GAN, our MAVEN has three components, E, G, and D; all are CNNs
with convolutional or transposed convolutional layers. First, E takes real samples
V(D) = (1/K) Σ_{k=1}^{K} w_k D_k   (3)
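A minimal sketch of this ensemble step, assuming each D_k has already scored the same batch (the array layout and the uniform default weights are illustrative):

```python
import numpy as np

def ensemble_mean(d_outputs, weights=None):
    """Weighted mean of K discriminator outputs, V(D) = (1/K) Σ_k w_k D_k."""
    d_outputs = np.asarray(d_outputs, dtype=float)   # shape (K, batch)
    k = d_outputs.shape[0]
    weights = np.ones(k) if weights is None else np.asarray(weights, dtype=float)
    return (weights[:, None] * d_outputs).sum(axis=0) / k

# Three discriminators scoring the same batch of two samples
print(ensemble_mean([[0.9, 0.2], [0.7, 0.4], [0.8, 0.3]]))  # [0.8 0.3]
```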
4 Semi-Supervised Learning
Algorithm 1 presents the overall training procedure of our MAVEN model. In the
forward pass, different real samples x into E and noise samples z into G provide
different inputs for each of the multiple discriminators. In the backward pass, the
combined feedback from the discriminators is computed and passed to G and E.
In the conventional image generator GAN, D works as a binary classifier—it
classifies the input image as real or generated. To facilitate the training for an n-class
classifier, D assumes the role of an (n + 1)-classifier. For multiple logit generation,
the sigmoid function is replaced by a softmax function. Now, it can receive an image
x as input and output an (n + 1)-dimensional vector of logits {l1 , . . . , ln , ln+1 }, which
are finally transformed into class probabilities for the n labels in the real data while
class (n + 1) denotes the generated data. The probability that x is real and belongs
to class 1 ≤ i ≤ n is
p(y = i | x) = exp(l_i) / Σ_{j=1}^{n+1} exp(l_j)   (4)
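Equation (4) is an ordinary softmax over the n + 1 logits; a numerically stable sketch:

```python
import numpy as np

def class_probabilities(logits):
    """Stable softmax over n+1 logits; entries 0..n-1 are the real classes
    and entry n is the 'generated' class."""
    logits = np.asarray(logits, dtype=float)
    exp = np.exp(logits - logits.max())   # subtract max for numerical stability
    return exp / exp.sum()

# n = 3 real classes plus one 'generated' logit
p = class_probabilities([2.0, 1.0, 0.5, -1.0])
print(p.argmax(), round(p.sum(), 6))  # 0 1.0
```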
    Update each discriminator Dk by ascending along its gradient:
        ∇_Dk (1/m) Σ_{i=1}^{m} [log Dk(x_i) + log(1 − Dk(G(z_i)))]
    end for
    Sample minibatch z_k^(1), ..., z_k^(m) from p_g(z)
    if ensemble is ‘mean’ then
        Assign weights w_k to the Dk
        Determine the mean discriminator D_μ = (1/K) Σ_k w_k Dk
    end if
    Update G by descending along its gradient from the ensemble D_μ:
        ∇_G (1/m) Σ_{i=1}^{m} log(1 − D_μ(G(z_i)))
end for
end for
4.1 Losses
4.1.1 D Loss
Since the model is trained on both labeled and unlabeled training data, the loss
function of D includes both supervised and unsupervised losses. When the model
receives real labeled data, it is the standard supervised learning loss
When it receives unlabeled data from three different sources, the unsupervised loss
contains the original GAN loss for real and generated data from two different sources:
synG directly from G and synE from E via G. The three losses,
4.1.2 G Loss
For G, the feature loss is used along with the original GAN loss. Activation f (x) from
an intermediate layer of D is used to match the feature between real and generated
samples. Feature matching has shown much potential in semi-supervised learning
[32]. The goal of feature matching is to encourage G to generate data that matches
real data statistics. It is natural for D to find the most discriminative features in real
data relative to data generated by the model:
L_G_feature = || E_{x∼p_data} f(x) − E_{x̂∼G} f(x̂) ||₂².   (10)
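The feature matching loss of Eq. (10) can be sketched as follows, with f_real and f_fake standing for batches of intermediate-layer activations f(x) and f(x̂) (hypothetical arrays, not the chapter's implementation):

```python
import numpy as np

def feature_matching_loss(f_real, f_fake):
    """Squared L2 distance between the batch-mean features of real and
    generated samples, as in feature matching."""
    diff = (np.asarray(f_real, dtype=float).mean(axis=0)
            - np.asarray(f_fake, dtype=float).mean(axis=0))
    return float(diff @ diff)

# Identical feature statistics give zero loss
print(feature_matching_loss(np.ones((8, 16)), np.ones((8, 16))))  # 0.0
```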
The total G loss becomes the combined feature loss (10) plus the cost of maximizing
the log-probability of D making a mistake on the generated data (synG / synE); i.e.,
where
L_G_synG = −E_{x̂∼G} log[1 − p(y = n + 1 | x̂)],   (12)
and
L_G_synE = −E_{x̃∼G} log[1 − p(y = n + 1 | x̃)].   (13)
4.1.3 E Loss
where
L_E_KL = −KL[ q_λ(z | x) ‖ p(z) ] = E_{q_λ(z|x)} [ log p(z) − log q_λ(z | x) ],   (15)
which in practice is approximated by sampling z from q_λ(z | x), and
L_E_feature = || E_{x∼p_data} f(x) − E_{x̃∼G} f(x̃) ||₂².   (16)
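When q_λ(z|x) is Gaussian, as in the VAE described earlier, and the prior p(z) is standard normal, the KL term of Eq. (15) has a well-known closed form; a sketch (mu and log_var are assumed encoder outputs):

```python
import numpy as np

def neg_kl_gaussian(mu, log_var):
    """-KL[ N(mu, sigma^2) || N(0, I) ]
    = 0.5 * sum(1 + log sigma^2 - mu^2 - sigma^2)."""
    mu = np.asarray(mu, dtype=float)
    log_var = np.asarray(log_var, dtype=float)
    return 0.5 * np.sum(1.0 + log_var - mu ** 2 - np.exp(log_var))

# When q matches the prior exactly (mu = 0, sigma = 1) the term vanishes
print(neg_kl_gaussian(np.zeros(4), np.zeros(4)))  # 0.0
```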
5 Experiments
Applying semi-supervised learning using training data that is only partially labeled,
we evaluated our MAVEN model in image generation and classification tasks in a
number of experiments. For all our experiments, we used 10% labeled and 90%
unlabeled training data.
5.1 Data
All the models were implemented in TensorFlow and run on a single Nvidia Titan
GTX (12 GB) GPU. For the discriminator, after every convolutional layer, a dropout
layer was added with a dropout rate of 0.4. For all the models, we consistently
used the Adam optimizer with a learning rate of 2.0 × 10⁻⁴ for G and D, and 1.0 × 10⁻⁵
for E, with a momentum of 0.9. All the convolutional layers were followed by batch
normalization. Leaky ReLU activations were used with α = 0.2.
5.3 Evaluation
There are no perfect performance metrics for measuring the quality of generated sam-
ples. However, to assess the quality of the generated images, we employed the widely
used Fréchet Inception Distance (FID) [12] and a simplified version of the Descrip-
tive Distribution Distance (DDD) [14]. To measure the Fréchet distance between two
multivariate Gaussians, the generated samples and real data samples are compared
through their distribution statistics:
FID = ||μ_data − μ_syn||² + Tr(Σ_data + Σ_syn − 2(Σ_data Σ_syn)^{1/2}).   (17)
Two distribution samples, X_data ∼ N(μ_data, Σ_data) and X_syn ∼ N(μ_syn, Σ_syn), for
real and model data, respectively, are calculated from the 2,048-dimensional acti-
vations of the pool3 layer of Inception-v3 [32]. DDD measures the closeness of
a generated data distribution to a real data distribution by comparing descriptive
parameters from the two distributions. We propose a simplified version based on the
first four moments of the distributions, computed as the weighted sum of normalized
differences of moments, as follows:
DDD = −log Σ_{i=1}^{4} w_i |μ_data_i − μ_syn_i|,   (18)
where the μdatai are the moments of the data distribution, the μsyni are the moments of
the model distribution, and the wi are the corresponding weights found in an exhaus-
tive search. The higher order moments are weighted more in order to emphasize the
stability of a distribution. For both the FID and DDD, lower scores are better.
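Both metrics can be sketched compactly. The FID below is restricted to diagonal covariances so that the matrix square root reduces to an elementwise one, and the DDD weights are illustrative placeholders rather than the exhaustively searched values used in the chapter:

```python
import numpy as np

def fid_diagonal(mu1, var1, mu2, var2):
    """Frechet distance between two Gaussians with diagonal covariances,
    where (Sigma1 Sigma2)^{1/2} reduces to an elementwise square root."""
    mu1, var1 = np.asarray(mu1, float), np.asarray(var1, float)
    mu2, var2 = np.asarray(mu2, float), np.asarray(var2, float)
    return float(np.sum((mu1 - mu2) ** 2)
                 + np.sum(var1 + var2 - 2.0 * np.sqrt(var1 * var2)))

def ddd(real, syn, weights=(0.1, 0.2, 0.3, 0.4)):
    """Simplified DDD: -log of the weighted sum of absolute differences of
    the first four moments (mean, std, skewness, kurtosis)."""
    def moments(x):
        x = np.asarray(x, float)
        mu, sd = x.mean(), x.std()
        z = (x - mu) / sd
        return np.array([mu, sd, (z ** 3).mean(), (z ** 4).mean()])
    return float(-np.log(np.dot(weights, np.abs(moments(real) - moments(syn)))))

# Identical Gaussians are at zero Frechet distance
print(fid_diagonal([0, 0], [1, 1], [0, 0], [1, 1]))  # 0.0
```

The full FID uses a general matrix square root (e.g., scipy.linalg.sqrtm) over the 2,048-dimensional Inception covariances.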
F1 = (2 × precision × recall) / (precision + recall),   (19)
with
precision = TP / (TP + FP)   and   recall = TP / (TP + FN),   (20)
where TP, FP, and FN are the number of true positives, false positives, and false
negatives, respectively.
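Equations (19) and (20) combine into a few lines (a sketch computing F1 from raw counts):

```python
def f1_score(tp, fp, fn):
    """F1 from raw true-positive, false-positive, and false-negative counts."""
    precision = tp / (tp + fp)
    recall = tp / (tp + fn)
    return 2 * precision * recall / (precision + recall)

# 8 true positives, 2 false positives, 2 false negatives:
# precision = recall = 0.8, so F1 = 0.8
print(round(f1_score(8, 2, 2), 3))  # 0.8
```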
5.4 Results
5.4.1 SVHN
For the SVHN dataset, we randomly selected 7,326 labeled images; these, along
with the remaining 65,931 unlabeled images, were provided to the network as training
data. All the models were trained for 300 epochs and then evaluated. We generated
new images equal in number to the training set size. Figure 5 presents a visual
comparison of a random selection of images generated by the DC-GAN, VAE-GAN,
and MAVEN models and real training images. Figure 6 compares the image intensity
histograms of 10K randomly sampled real images and equally many images sampled
from among those generated by each of the different models.
Generally speaking, our MAVEN models generate images that are more realistic
than those generated by the DC-GAN and VAE-GAN models. This was further
corroborated by randomly sampling 10K generated images and 10K real images. The
generated image quality measurement was performed for the eight different models.
Table 1 reports the resulting FID and DDD scores. For the FID score calculation, the
score is reported after running the pre-trained Inception-v3 network for 20 epochs
for each model. The MAVEN-r3D model achieved the best FID score and the best
DDD score was achieved by the MAVEN-m5D model.
Table 2 compares the classification performance of all the models for the SVHN
dataset. The MAVEN model consistently outperformed the DC-GAN and VAE-GAN
classifiers both in classification accuracy and class-wise F1 scores. Among all the
models, our MAVEN-m2D and MAVEN-m3D models were the most accurate.
Fig. 5 Visual comparison of image samples from the SVHN dataset against those generated by the
different models
Fig. 6 Histograms of the real SVHN training data, and of the data generated by the DC-GAN and
VAE-GAN models and by our MAVEN models with mean and random feedback from 2, 3, to 5
discriminators
Table 1 Minimum FID and DDD scores achieved by the DC-GAN, VAE-GAN, and MAVEN models for the CIFAR-10, SVHN, CXR, and SLC datasets

Model       | CIFAR-10 FID  | DDD   | SVHN FID     | DDD   | CXR FID        | DDD   | SLC FID      | DDD
DC-GAN      | 61.293±0.209  | 0.265 | 16.789±0.303 | 0.343 | 152.511±0.370  | 0.145 | 1.828±0.370  | 0.795
VAE-GAN     | 15.511±0.125  | 0.224 | 13.252±0.001 | 0.329 | 141.422±0.580  | 0.107 | 1.828±0.580  | 0.795
MAVEN-m2D   | 12.743±0.242  | 0.223 | 11.675±0.001 | 0.309 | 141.339±0.420  | 0.138 | 1.874±0.270  | 0.802
MAVEN-m3D   | 11.316±0.808  | 0.190 | 11.515±0.065 | 0.300 | 140.865±0.983  | 0.018 | 0.304±0.018  | 0.249
MAVEN-m5D   | 12.123±0.140  | 0.207 | 10.909±0.001 | 0.294 | 147.316±1.169  | 0.100 | 1.518±0.190  | 0.793
MAVEN-r2D   | 12.820±0.584  | 0.194 | 11.384±0.001 | 0.316 | 154.501±0.345  | 0.038 | 1.505±0.130  | 0.789
MAVEN-r3D   | 12.620±0.001  | 0.202 | 10.791±0.029 | 0.357 | 158.749±0.297  | 0.179 | 0.336±0.080  | 0.783
MAVEN-r5D   | 18.509±0.001  | 0.215 | 11.052±0.751 | 0.323 | 152.778±1.254  | 0.180 | 1.812±0.014  | 0.795
DO-GAN [23] | 88.60±0.08    | –     |              |       |                |       |              |
Table 2 Average cross-validation accuracy and class-wise F1 scores in the semi-supervised clas-
sification performance comparison of the DC-GAN, VAE-GAN, and MAVEN models using the
SVHN dataset
Model Accuracy F1 scores
0 1 2 3 4 5 6 7 8 9
DC-GAN 0.876 0.860 0.920 0.890 0.840 0.890 0.870 0.830 0.890 0.820 0.840
VAE-GAN 0.901 0.900 0.940 0.930 0.860 0.920 0.900 0.860 0.910 0.840 0.850
MAVEN-m2D 0.909 0.890 0.930 0.940 0.890 0.930 0.900 0.870 0.910 0.870 0.890
MAVEN-m3D 0.909 0.910 0.940 0.940 0.870 0.920 0.890 0.870 0.920 0.870 0.860
MAVEN-m5D 0.905 0.910 0.930 0.930 0.870 0.930 0.900 0.860 0.910 0.860 0.870
MAVEN-r2D 0.905 0.910 0.930 0.940 0.870 0.930 0.890 0.860 0.920 0.850 0.860
MAVEN-r3D 0.907 0.890 0.910 0.920 0.870 0.900 0.870 0.860 0.900 0.870 0.890
MAVEN-r5D 0.903 0.910 0.930 0.940 0.860 0.910 0.890 0.870 0.920 0.850 0.870
5.4.2 CIFAR-10
For the CIFAR-10 dataset, we used 50K training images, only 5K of them labeled.
All the models were trained for 300 epochs and then evaluated. We generated new
images equal in number to the training set size. Figure 7 visually compares a random
selection of images generated by the DC-GAN, VAE-GAN, and MAVEN models
and real training images. Figure 8 compares the image intensity histograms of 10K
randomly sampled real images and equally many images sampled from among those
generated by each of the different models. Table 1 reports the FID and DDD scores.
As the tabulated results suggest, our MAVEN models achieved better FID scores than
some of the recently published models. Note that those models were implemented
in different settings.
As for the visual comparison, the FID and DDD scores confirmed more realistic
image generation by our MAVEN models compared to the DC-GAN and VAE-GAN
models. The MAVEN models have smaller FID scores, except for MAVEN-r5D.
MAVEN-m3D has the smallest FID and DDD scores among all the models.
Table 3 compares the classification performance of all the models with the CIFAR-
10 dataset. All the MAVEN models performed better than the DC-GAN and VAE-
GAN models. In particular, MAVEN-m5D achieved the best classification accuracy
and F1 scores.
5.4.3 CXR
With the CXR dataset, we used 522 labeled images and 4,694 unlabeled images. All
the models were trained for 150 epochs and then evaluated. We generated new
images equal in number to the training set size. Figure 9 presents a visual comparison
of a random selection of generated and real images. The FID and DDD measurements
were performed for the distributions of generated and real training samples, indicating
that more realistic images were generated by the MAVEN models than by the GAN
Fig. 7 Visual comparison of image samples from the CIFAR-10 dataset against those generated
by the different models
Fig. 8 Histograms of the real CIFAR-10 training data, and of the data generated by the DC-GAN
and VAE-GAN models and by our MAVEN models with mean and random feedback from 2, 3, to
5 discriminators
Table 3 Average cross-validation accuracy and class-wise F1 scores in the semi-supervised clas-
sification performance comparison of the DC-GAN, VAE-GAN, and MAVEN models using the
CIFAR-10 dataset
Model Accuracy F1 scores
Plane Auto Bird Cat Deer Dog Frog Horse Ship Truck
DC-GAN 0.713 0.760 0.840 0.560 0.510 0.660 0.590 0.780 0.780 0.810 0.810
VAE-GAN 0.743 0.770 0.850 0.640 0.560 0.690 0.620 0.820 0.770 0.860 0.830
MAVEN-m2D 0.761 0.800 0.860 0.650 0.590 0.750 0.680 0.810 0.780 0.850 0.850
MAVEN-m3D 0.759 0.770 0.860 0.670 0.580 0.700 0.690 0.800 0.810 0.870 0.830
MAVEN-m5D 0.771 0.800 0.860 0.650 0.610 0.710 0.640 0.810 0.790 0.880 0.820
MAVEN-r2D 0.757 0.780 0.860 0.650 0.530 0.720 0.650 0.810 0.800 0.870 0.860
MAVEN-r3D 0.756 0.780 0.860 0.640 0.580 0.720 0.650 0.830 0.800 0.870 0.830
MAVEN-r5D 0.762 0.810 0.850 0.680 0.600 0.720 0.660 0.840 0.800 0.850 0.820
and VAE-GAN models. The FID and DDD scores presented in Table 1 show that the
mean MAVEN-m3D model has the smallest FID and DDD scores.
The classification performance reported in Table 4 suggests that our MAVEN
model-based classifiers are more accurate than the baseline GAN and VAE-GAN
classifiers. Among all the models, the MAVEN-m3D classifier was the most accurate.
5.4.4 SLC
For the SLC dataset, we used 160 labeled images and 1,440 unlabeled images. All the
models were trained for 150 epochs and then evaluated. We generated new images
equal in number to the training set size. Figure 10 presents a visual comparison of
randomly selected generated and real image samples.
The FID and DDD measurements for the distributions of generated and real train-
ing samples indicate that more realistic images were generated by the MAVEN mod-
els than by the GAN and VAE-GAN models. The FID and DDD scores presented
in Table 1 show that the mean MAVEN-m3D model has the smallest FID and DDD
scores.
The classification performance reported in Table 5 suggests that our MAVEN
model-based classifiers are more accurate than the baseline GAN and VAE-GAN
classifiers. Among all the models, MAVEN-r3D is the most accurate in discriminating
between non-melanoma and melanoma lesion images.
Fig. 9 Visual comparison of image samples from the CXR dataset against those generated by the
different models
Table 4 Average cross-validation accuracy and class-wise F1 scores for the semi-supervised clas-
sification performance comparison of the DC-GAN, VAE-GAN, and MAVEN models using the
CXR dataset
Model Accuracy F1 scores
Normal B-Pneumonia V-Pneumonia
DC-GAN 0.461 0.300 0.520 0.480
VAE-GAN 0.467 0.220 0.640 0.300
MAVEN-m2D 0.469 0.310 0.620 0.260
MAVEN-m3D 0.525 0.640 0.480 0.480
MAVEN-m5D 0.477 0.380 0.480 0.540
MAVEN-r2D 0.478 0.280 0.630 0.310
MAVEN-r3D 0.506 0.440 0.630 0.220
MAVEN-r5D 0.483 0.170 0.640 0.240
Fig. 10 Visual comparison of image samples from the SLC dataset against those generated by the
different models
Table 5 Average cross-validation accuracy and class-wise F1 scores for the semi-supervised clas-
sification performance comparison of the DC-GAN, VAE-GAN, and MAVEN models using the
SLC dataset
Model Accuracy F1 scores
Non-melanoma Melanoma
DC-GAN 0.802 0.890 0.120
VAE-GAN 0.810 0.890 0.012
MAVEN-m2D 0.815 0.900 0.016
MAVEN-m3D 0.814 0.900 0.110
MAVEN-m5D 0.812 0.900 0.140
MAVEN-r2D 0.808 0.890 0.260
MAVEN-r3D 0.821 0.900 0.020
MAVEN-r5D 0.797 0.890 0.040
6 Conclusions
References
5. N.C. Codella, D. Gutman, M.E. Celebi, B. Helba, M.A. Marchetti, S.W. Dusza, A. Kalloo,
K. Liopyris, N. Mishra, H. Kittler et al., Skin lesion analysis toward melanoma detection: a
challenge at the 2017 ISBI, hosted by ISIC, in IEEE International Symposium on Biomedical
Imaging (ISBI 2018) (2018), pp. 168–172
6. E.L. Denton, S. Chintala, A. Szlam, R. Fergus, Deep generative image models using a Lapla-
cian pyramid of adversarial networks, in Advances in Neural Information Processing Systems
(NeurIPS) (2015)
7. I. Durugkar, I. Gemp, S. Mahadevan, Generative multi-adversarial networks (2016). arXiv
preprint arXiv:1611.01673
8. M. Frid-Adar, I. Diamant, E. Klang, M. Amitai, J. Goldberger, H. Greenspan, GAN-based syn-
thetic medical image augmentation for increased CNN performance in liver lesion classification
(2018). arXiv preprint arXiv:1803.01229
9. I. Goodfellow, J. Pouget-Abadie, M. Mirza, B. Xu, D. Warde-Farley, S. Ozair, A. Courville, Y.
Bengio, Generative adversarial nets, in Advances in Neural Information Processing Systems
(NeurIPS) (2014), pp. 2672–2680
10. J.T. Guibas, T.S. Virdi, P.S. Li, Synthetic medical images from dual generative adversarial
networks (2017). arXiv preprint arXiv:1709.01872
11. C. Han, H. Hayashi, L. Rundo, R. Araki, W. Shimoda, S. Muramatsu, Y. Furukawa, G. Mauri,
H. Nakayama, GAN-based synthetic brain MR image generation, in IEEE International Sym-
posium on Biomedical Imaging (ISBI) (2018), pp. 734–738
12. M. Heusel, H. Ramsauer, T. Unterthiner, B. Nessler, S. Hochreiter, GANs trained by a two
time-scale update rule converge to a local Nash equilibrium, in Advances in Neural Information
Processing Systems (NeurIPS) (2017), pp. 6626–6637
13. L. Hou, A. Agarwal, D. Samaras, T.M. Kurc, R.R. Gupta, J.H. Saltz, Unsupervised histopathol-
ogy image synthesis (2017). arXiv preprint arXiv:1712.05021
14. A.A.Z. Imran, P.R. Bakic, A.D. Maidment, D.D. Pokrajac, Optimization of the simulation
parameters for improving realism in anthropomorphic breast phantoms, in Proceedings of the
SPIE, vol. 10132 (2017)
15. A.A.Z. Imran, D. Terzopoulos, Multi-adversarial variational autoencoder networks, in IEEE
International Conference on Machine Learning and Applications (ICMLA) (Boca Raton, FL,
2019), pp. 777–782
16. A.A.Z. Imran, D. Terzopoulos, Semi-supervised multi-task learning with chest X-ray images
(2019). arXiv preprint arXiv:1908.03693
17. D.S. Kermany, M. Goldbaum, W. Cai, C.C. Valentim, H. Liang, S.L. Baxter, A. McKeown, G.
Yang, X. Wu, F. Yan et al., Identifying medical diagnoses and treatable diseases by image-based
deep learning. Cell 172(5), 1122–1131 (2018)
18. D.P. Kingma, M. Welling, Auto-encoding variational Bayes (2013). arXiv preprint
arXiv:1312.6114
19. A. Krizhevsky, Learning multiple layers of features from tiny images. Master’s thesis, Univer-
sity of Toronto, Dept. of Computer Science (2009)
20. A. Madani, M. Moradi, A. Karargyris, T. Syeda-Mahmood, Semi-supervised learning with
generative adversarial networks for chest X-ray classification with ability of data domain adap-
tation, in IEEE International Symposium on Biomedical Imaging (ISBI) (2018), pp. 1038–1042
21. A. Makhzani, J. Shlens, N. Jaitly, I. Goodfellow, B. Frey, Adversarial autoencoders (2015).
arXiv preprint arXiv:1511.05644
22. T. Miyato, T. Kataoka, M. Koyama, Y. Yoshida, Spectral normalization for generative adver-
sarial networks (2018). arXiv preprint arXiv:1802.05957
23. G. Mordido, H. Yang, C. Meinel, Dropout-GAN: learning from a dynamic ensemble of dis-
criminators (2018). arXiv preprint arXiv:1807.11346
24. Y. Netzer, T. Wang, A. Coates, A. Bissacco, B. Wu, A.Y. Ng, Reading digits in natural images
with unsupervised feature learning, in NIPS Workshop on Deep Learning and Unsupervised
Feature Learning, vol. 2011 (2011), pp. 1–9
25. B. Neyshabur, S. Bhojanapalli, A. Chakrabarti, Stabilizing GAN training with multiple random
projections (2017). arXiv preprint arXiv:1705.07831
26. T. Nguyen, T. Le, H. Vu, D. Phung, Dual discriminator generative adversarial nets, in Advances
in Neural Information Processing Systems (NeurIPS) (2017), pp. 2670–2680
27. A. Odena, C. Olah, J. Shlens, Conditional image synthesis with auxiliary classifier GANs (2017).
arXiv preprint arXiv:1610.09585
28. G. Ostrovski, W. Dabney, R. Munos, Autoregressive quantile networks for generative modeling
(2018). arXiv preprint arXiv:1806.05575
29. A. Radford, L. Metz, S. Chintala, Unsupervised representation learning with deep convolutional
generative adversarial networks (2015). arXiv preprint arXiv:1511.06434
30. S. Ravuri, S. Mohamed, M. Rosca, O. Vinyals, Learning implicit generative models with the
method of learned moments (2018). arXiv preprint arXiv:1806.11006
31. H. Salehinejad, S. Valaee, T. Dowdell, E. Colak, J. Barfett, Generalization of deep neural
networks for chest pathology classification in X-rays using generative adversarial networks, in
IEEE International Conference on Acoustics, Speech and Signal Processing (ICASSP) (2018),
pp. 990–994
32. T. Salimans, I. Goodfellow, W. Zaremba, V. Cheung, A. Radford, X. Chen, Improved techniques
for training GANs, in Advances in Neural Information Processing Systems (NeurIPS) (2016),
pp. 2234–2242
33. J.T. Springenberg, Unsupervised and semi-supervised learning with categorical generative
adversarial networks (2015). arXiv preprint arXiv:1511.06390
34. T. Unterthiner, B. Nessler, C. Seward, G. Klambauer, M. Heusel, H. Ramsauer, S. Hochreiter,
Coulomb GANs: provably optimal nash equilibria via potential fields (2017). arXiv preprint
arXiv:1708.08819
35. S. Wang, L. Zhang, CatGAN: coupled adversarial transfer for domain generation (2017). arXiv
preprint arXiv:1711.08904
36. M.A. Wani, F.A. Bhat, S. Afzal, A.I. Khan (eds.), Advances in Deep Learning (Springer, 2020)
Non-convex Optimization Using
Parameter Continuation Methods for
Deep Neural Networks
1 Introduction
In many machine learning and deep learning problems, the key task is an optimization
problem, with the objective to learn useful properties of the data given a model. The
parameters of the model are used to learn linear or non-linear features of the data
and can be used to perform inference and predictions. In this chapter, we will study
a challenging non-convex optimization task, i.e., the ability of deep learning models to
learn non-linear or complex representations from data.
This chapter is an extension of the conference paper [52] with additional results,
additional context for the theory of the algorithms, a detailed literature survey listing
some common limitations of curriculum strategies, and discussions on open research
questions.
Often, the objective of a deep learning model is to approximate a function f true ,
which is the true function that maps inputs x to the targets y, such that y = f true (x)
[15, 50]. More formally, a deep neural network defines a mapping y = f (x; θ ),
where θ is the set of parameters. These parameters θ are estimated by minimizing
an objective function, such that the mapping f best approximates f true. However,
training the deep neural network to find a good solution is a challenging optimization
task [3, 17, 44], where by a good solution we mean achieving low generalization error
in few training steps. Even with state-of-the-art techniques, machines configured with
a large number of processing units and high memory can spend days, or even longer,
to solve such problems [50].
Neural networks use a composition of functions which, generally speaking, are
non-convex and difficult to optimize [3, 8]. Even the loss functions of deep linear
networks are non-convex in parameter space [17]. Deep learning research has advanced
rapidly in the past decade, and quality optimization algorithms for non-convex
problems, such as Stochastic Gradient Descent (SGD) [15], RMSprop [25], Adagrad [12],
and Adam [35], are widely practiced by the deep learning community and have
significantly advanced the state of the art in text mining [43], super-resolution
[16, 40, 51], image recognition [38], speech recognition [24], and many more.
However, their success usually depends on the choice
of hyperparameters and the quality of the initialization [60]. Previously, researchers
have shown that different initialization methods may lead to dramatically different
solution curve geometries on a cost surface [8, 17, 27]. In other words, a small shift
in the initialization may lead the model to converge to a different minimum [50]. In
this chapter, we attempt to introduce and improve the solution of such non-convex
optimization problems by rethinking the usual approach. In particular, many of the
current state-of-the-art optimization techniques work on a fixed loss surface. We
instead propose to transform the loss surface continuously in order to design
an effective training strategy.
The contributions of this chapter are the following. We derive a homotopy
formulation of common activation functions that implicitly decomposes the deep neural
network optimization problem into a sequence of problems, which then enables one
to design and apply parameter continuation methods that rapidly converge to a better
solution.
Non-convex Optimization Using Parameter Continuation … 275
2 Background
Continuation methods decompose a hard optimization problem into a sequence of
sub-problems J^(0)(θ), . . . , J^(n)(θ) of gradually increasing difficulty, chosen such
that the solution of J^(i)(θ) can be used as an initial estimate for the parameters of
J^(i+1)(θ) [2, 11, 50].
Based upon a literature survey [3, 44, 49, 50], there appear to be two main thrusts
for improving optimization by continuation:
• How to choose the simplified problem?
• How to transform the simplified problem into the original task?
In this chapter, we provide an approach that addresses both of these challenges. We
also note that the results in this chapter build upon work in several disparate
domains; good entryways into the literature include [15, 26] for Autoencoders
(AE), [2, 34] for continuation methods, and [3] for Curriculum Learning.
2.1 Homotopies
One possible way to incorporate continuation methods in the context of neural
networks is to determine homotopy [2] formulations of the activation
functions [52].
The Implicit Function Theorem (IFT) is an important tool for the theoretical
understanding of continuation methods [34]. While a full treatment of these ideas is
beyond the scope of this chapter, in this section we provide a measure of intuition
for the interested reader to help bridge the gap between continuation methods and
deep learning. Let us begin with the objective of a simple AE [34].
where X ∈ R^{m×n} is the input data, f is the encoder function, g is the decoder function
and θ ∈ R^N are the network parameters; in the case of mean squared error, we want
this objective to be zero. We can then represent the objective as a function G(X; θ),
which can be used in a continuation method framework [34].
Note, that in the above equation X is not a parameter. It is fixed by our choice
of training data. Accordingly, when we ‘count parameters’ below, only θ serves
as degrees of freedom. Now, say we add another parameter λ, which we call a
homotopy parameter, to the system of neural network equations. In our work, this
single parameter controls the transformation of the problem from a simple problem
to a hard problem, and such homotopy parameters will be the focus of Sect. 4. We
rewrite to get
G(θ, λ) = 0 (4)
Now (4) can be directly related to the system of equations in [34], to which the
implicit function theorem can be applied under certain assumptions.1 The IFT then
guarantees that, in the neighbourhood of a known solution where the derivative is
non-singular, a smooth and continuous homotopy path must exist. The interested
reader can see [2, 34] for more details.
3 Related Work
The choice of optimization technique, such as SGD or ADAM, plays a crucial role in
the quality of convergence. Also, finding a better initialization in order to reach a
superior local minimum is popular in the research community. Here we draw a few observations
from some of the popular research works which have demonstrated quality perfor-
mance. Unsupervised pre-training [26] is one such strategy to train neural networks
for supervised and unsupervised learning tasks. The idea is to use an Autoencoder
(AE) to greedily train the hidden layers one layer at a time. This method has two
advantages. First, it provides a more stable initialization than random initialization,
as the model has implicitly learned properties of the data. Second, it increases the
regularization effect [15, 50]. Fine-tuning may be seen as an evolution of the above-
mentioned approach, where the learned layers are trained further, or fine-tuned, for the
final task [15, 39]. In transfer learning, one is provided with two distinct tasks,
such that learning from one distribution can be transferred to another. This technique
is broadly used in vision and language tasks such as recommendation systems and
image super-resolution. In summary, these ideas suggest that incorporating prior
1 Assumptions:
• G : R^N × R → R^N is a smooth map
• ||G(θ_0, λ_0)|| ≤ c
• the derivative G_θ is non-singular at the known root (θ_0, λ_0)
See [34] for the IFT theorem and proofs for local continuation.
278 H. Nilesh Pathak and R. Clinton Paffenroth
knowledge to the network, or taking a few simple learning steps first, can enhance the
final training performance [13, 17].
Some of the prior work closest to the research in this chapter is
curriculum learning [3, 23]. Smoothing [22, 44] is a common approach for changing
the qualitative behaviour of activation functions throughout the course of training,
and it has been adopted by many researchers in distinct ways [5, 21, 22, 44, 50].
Smoothing can be seen as a convex envelope of the non-convex problem [45], and
it may reveal the global curvature of the loss function [3, 44]. This smoothing
can be continuously decreased to get a sequence of loss functions with increasing
complexity. The smoothed objective function is minimized first, and then progres-
sively smoothing is reduced as training proceeds to obtain the terminal solution
[3, 44].
Next, Mollifying Networks added normally distributed noise to the weights of each
layer of the neural network [22] and thus, mollified the objective function. Again,
this noise is steadily reduced as the training proceeds. Linearization of activation
functions is another way of achieving the mollification effect [5, 21, 22]. Similarly,
gradually deforming from linear to non-linear representation has been implemented
with different noise injection techniques [21, 22, 44].
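As a hedged illustration of the noise-injection idea (a simplified sketch, not the exact mollification scheme of [22]), one can add annealed Gaussian noise to a layer's weights; the function name and the geometric schedule below are our own assumptions:

```python
import numpy as np

def mollified_forward(x, W, sigma, rng):
    """One 'mollified' linear layer: Gaussian noise of scale sigma is added
    to the weights, which smooths the effective objective. Annealing sigma
    toward 0 over training recovers the original, un-mollified layer."""
    W_noisy = W + sigma * rng.standard_normal(W.shape)
    return x @ W_noisy

rng = np.random.default_rng(0)
x = np.ones((1, 3))
W = np.arange(6.0).reshape(3, 2)

# sigma = 0 recovers the deterministic layer exactly.
clean = mollified_forward(x, W, 0.0, rng)

# A simple annealing schedule: sigma decays geometrically each step.
sigmas = [1.0 * (0.9 ** t) for t in range(10)]
```

The annealing of sigma plays the same role as the homotopy parameter λ in this chapter: a smoothed problem is solved first, and the original problem is recovered in the limit.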
Initialization is one of the main advantages of continuation methods [2, 34]. In
many deep learning applications, we observe random initialization is widely prac-
ticed. However, continuation methods suggest the following, namely, to start with a
problem whose solution is easily available and then incrementally advance towards
the problem you want to solve. Naturally, this method can save computational cost.
While other methods may require many iterations to attain a reasonable solution, in
the paradigm of parameter continuation methods the initialization can be
deterministic, and one then progressively solves harder problems to obtain the
desired solution. For example, in [44] one performs accurate initialization with almost
no computational cost. Furthermore, some more advantages were discussed for diffusion
methods [44]. In that work, the authors conceptually showed that popular
techniques such as dropout [58], layerwise pre-training [26, 39], annealed learning
rates, careful initialization [60] and noise injection (directly into the loss or the
weights of the network) [3] can naturally arise from continuation methods [44].
We discuss a few limitations and future research directions in Sect. 6. In com-
parison to the previous approaches that add noise to activation functions [21, 22] to
serve as a proxy for the continuation method, we derive homotopies to obtain lin-
earization, which we will discuss in Sect. 4. Our intent is to simplify the original task
and start with a good initialization. Despite the above advances, there is a need for
systematic empirical evaluation to show how the continuation method may help to
avoid early local minima and saddle points [17, 52]. Unlike most prior research,
which focused on classification or prediction tasks, herein we focus on unsupervised
learning. However, there is nothing special in the formulation of the homotopy
function that depends on the particular structure of AEs, so we expect similar results
are possible with arbitrary neural networks.
4 Methodology
where φ can be any activation function, such as Sigmoid, ReLU or Tanh, λ is the
homotopy parameter, and v is the value of the output from the current layer. For a
standard AE, the loss function can be represented as the following
where X ∈ R m×n is the input data, f is the encoder function, g is the decoder function
and θ the network parameters. With the addition of a homotopy parameter λ the
optimization can be rewritten in the following manner:
J(X, g_{θ,λ}(f_{θ,λ}(X))) = argmin_{θ,λ} ||X − g_{θ,λ}(f_{θ,λ}(X))||^2 .    (7)
To provide intuition, first let us consider the extremes, i.e. λ = 0 and λ = 1. When λ = 1,
the objective functions in (6) and (7) are exactly the same; thus λ = 1 indicates
the use of conventional activation functions. Second, in a deep neural network as
Fig. 1 This figure is a good example of Manifold Learning [69]. The points in blue are true data
points that clearly lie on a non-linear manifold. The green points show an optimized non-linear
projection of the data. The red points show the linear manifold projected by PCA. Learning a linear
manifold can be considered a simplified version of non-linear manifold learning
shown in Fig. 6, consider all the activation functions to be C-Activations with λ = 0;
then the neural network is linear. Here, the network would attempt to find a linear
projection in order to minimize the cost function. We know such a solution can be
found in closed form using Principal Component Analysis (PCA) or the Singular
Value Decomposition (SVD) [50]. Hence, the PCA solution and the solution of our
optimization problem at λ = 0 should span the same subspace [15, 50]. Thus, we
leverage this observation and initialize our network using PCA, which we discuss
further in Sect. 4.3. Effectively, λ = 0 is analogous to solving a simpler version
of the optimization problem (a linear projection), and as λ → 1 the problem becomes
harder (a non-linear and non-convex problem) (Fig. 1).
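The text describes a C-Activation that is the identity at λ = 0 and the conventional activation at λ = 1; this behaviour is consistent with a convex-combination homotopy, which the sketch below assumes (the function name and exact form are our illustration, not necessarily the authors' implementation):

```python
import numpy as np

def sigmoid(v):
    return 1.0 / (1.0 + np.exp(-v))

def c_activation(phi, v, lam):
    """Homotopy ('C-') version of an activation phi, assuming the
    convex-combination form: the identity map at lam = 0 (so the whole
    network is linear) and the ordinary activation at lam = 1."""
    return lam * phi(v) + (1.0 - lam) * v

v = np.linspace(-10.0, 10.0, 101)

# lam = 0: the layer is purely linear (identity).
assert np.allclose(c_activation(sigmoid, v, 0.0), v)

# lam = 1: the conventional Sigmoid is recovered.
assert np.allclose(c_activation(sigmoid, v, 1.0), sigmoid(v))

# Intermediate lam (e.g. 0.7, as in Fig. 4) interpolates between the two.
mid = c_activation(sigmoid, v, 0.7)
```

Because the combination is continuous in λ, sweeping λ from 0 to 1 continuously deforms the loss surface from the linear problem to the original non-convex one.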
Thus, the homotopy parameter λ defines the scheme for the optimization, which
can be referred to as the homotopy path, as shown in Fig. 2. As λ : 0 → 1, we solve
an increasingly difficult optimization problem, where the degree of non-linearity
learned from the data increases. However, we need a technique to traverse this
homotopy path and find the solutions along it; deriving such techniques is our focus
in the following section.
X = U Σ V^T    (8)
Fig. 2 This figure provides intuition for NPC with a hypothetical homotopy path (blue curve),
which connects all solutions in multidimensional space (θ) at every λ ∈ [0, 1]. Here we show how
the solution θ_{λ_1} is used to initialize the parameters of the network at λ_2, θ^init_{λ_2} ← θ_{λ_1},
using a small Δλ_{2,1}. Further taking some ADAM steps (shown by the orange arrow), we find
some minimum θ_{λ_2}
where X is our normalized training data,2 U and V are unitary matrices, and Σ is a
diagonal matrix of singular values.
In Eq. (9), we observe that V can be seen as a mapping of the data from a
high-dimensional space to a lower dimensional linear subspace (an encoder in the
language of AEs). In addition, V^T is a mapping from the lower dimensional linear
subspace back to the original space of the provided data (a decoder in the language
of AEs). This behaviour of the SVD enables us to initialize the weights of an
appropriately defined AE. In particular, when λ = 0 in (7) we have that f and g are,
in fact, linear functions and (7) can be minimized using the SVD [47].
More precisely, we use the first n columns of V for the encoder layer with width n,
and for the decoder layer we use the transpose of the weights used for the encoder,
as shown in (12) and (13) and as in [50].
X = U Σ V^T    (9)
X V = U Σ V^T V    (10)
X V = U Σ    (11)
W^n_encoder = V_n-columns    (12)
W^n_decoder = (W^n_encoder)^T    (13)
where W^n represents the weight matrix of the corresponding layer with n as its width.
There are multiple advantages to initializing an AE using PCA. First, we start
the deep learning training from the solution of a linear problem rather than trying
to solve a non-linear, and often non-convex, optimization problem. Having such
a straightforward initialization procedure allows the optimization of the AE to begin
2 For PCA, the data is normalized so that the mean of each column is 0.
from a known global optimum. Second, in the NPCS method, the C-Activation
defines the homotopy as a continuous deformation from a linear to a non-linear
network. Accordingly, the global optimum to the linear problem can be transformed
into a solution of the non-linear and non-convex problem of interest in a principled
fashion. Finally, as PCA provides a deterministic initialization which does not require
any sampling of the parameter space for initialization [10, 44, 60], our proposed
method also has a computational advantage. Of course, the idea of initializing a deep
feedforward network with PCA is not novel and has been independently explored
in the literature [7, 36, 57]. However, our proposed algorithm is unique in that
it leverages powerful techniques for solving homotopy problems using parameter
continuation.
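The initialization of Eqs. (9)–(13) can be sketched in a few lines of NumPy; the data shape, bottleneck width and variable names below are illustrative assumptions:

```python
import numpy as np

# Toy data: m samples, d features, columns normalized to zero mean
# (the normalization PCA requires, per the footnote).
rng = np.random.default_rng(1)
X = rng.standard_normal((200, 10))
X = X - X.mean(axis=0)

n = 2  # bottleneck width of the AE

# X = U S V^T; NumPy returns V^T (as Vh) directly.
U, S, Vh = np.linalg.svd(X, full_matrices=False)
V = Vh.T

# Encoder/decoder weights per Eqs. (12)-(13): first n columns of V,
# and the tied transpose for the decoder.
W_enc = V[:, :n]
W_dec = W_enc.T

# At lam = 0 the AE is linear, so encode-then-decode is exactly the
# rank-n PCA reconstruction of X.
X_hat = (X @ W_enc) @ W_dec
```

Because the columns of V are orthonormal, this reconstruction coincides with the best rank-n approximation of X, i.e. the known global optimum of the linear problem from which continuation starts.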
θ^init_{λ_{i+1}} ← θ_{λ_i}    (14)
Fig. 3 In this figure, we provide intuition for our adaptive methods. We show two possible sets of
steps, where green is an adaptive λ and grey is a fixed λ. We want to take larger steps when the
homotopy path is flat and shorter steps where the homotopy path is more oscillatory
Fig. 4 This figure, from left to right, illustrates how the C-Sigmoid, C-ReLU and C-Tanh behave
at λ = 0.7 on uniformly distributed points in [−10, 10]
Table 1 This table shows the train and test loss values of different optimization techniques using a
specified network. All these experiments were run for 50,000 backpropagation steps, and we
report the averages of the last 100 loss values for both training and testing. Perhaps not surprisingly,
SGD does the worst in almost all cases. More interestingly, RMSProp and ADAM both do well
in some cases, and quite badly in others. Note that the various parameter continuation methods all
have quite stable properties and achieve a low loss in all cases and the lowest loss in most cases
Network          SGD      RMSProp  ADAM     NPC      NPAC     NPACS
Fashion-MNIST (Train)
  AE-8 Sigmoid   0.1133   0.03915  0.03370  0.03402  0.03370  0.03388
  AE-8 ReLU      0.11122  0.03582  0.03318  0.03171  0.03188  0.03191
  AE-8 Tanh      0.10741  0.03459  0.03515  0.03573  0.03552  0.03559
  AE-16 Sigmoid  0.11397  0.06714  0.06714  0.03418  0.04505  0.03461
  AE-16 ReLU     0.11394  0.06714  0.03436  0.03474  0.03445  0.03659
  AE-16 Tanh     0.10889  0.03419  0.03540  0.03753  0.03722  0.03622
CIFAR-10 (Train)
  AE-8 Sigmoid   0.28440  0.03861  0.03352  0.03275  0.03224  0.03238
  AE-8 ReLU      0.07689  0.03467  0.03421  0.03459  0.03461  0.03302
  AE-8 Tanh      0.27565  0.03355  0.03421  0.03343  0.03392  0.03408
  AE-16 Sigmoid  0.28717  0.06223  0.06223  0.03480  0.03310  0.03517
  AE-16 ReLU     0.07722  0.03512  0.03419  0.03400  0.03456  0.03463
  AE-16 Tanh     0.27884  0.03496  0.03452  0.03637  0.03815  0.03405
Maximum (Train)  0.28440  0.06714  0.06714  0.03753  0.04505  0.03659
Fashion-MNIST (Test)
  AE-8 Sigmoid   0.11333  0.08508  0.08525  0.08257  0.08324  0.08170
  AE-8 ReLU      0.11123  0.08202  0.08076  0.08225  0.08160  0.08154
  AE-8 Tanh      0.10742  0.08571  0.07904  0.07225  0.07306  0.07447
  AE-16 Sigmoid  0.11397  0.11396  0.11383  0.07405  0.09050  0.07912
  AE-16 ReLU     0.11394  0.11396  0.08035  0.08103  0.07993  0.08441
  AE-16 Tanh     0.10891  0.07736  0.07713  0.07899  0.07519  0.08025
CIFAR-10 (Test)
  AE-8 Sigmoid   0.28440  0.07835  0.06526  0.06344  0.06417  0.06522
  AE-8 ReLU      0.28440  0.03861  0.03352  0.03276  0.03225  0.03239
  AE-8 Tanh      0.27568  0.09993  0.08604  0.06801  0.07580  0.07544
  AE-16 Sigmoid  0.28718  0.28671  0.28648  0.07735  0.07588  0.10232
  AE-16 ReLU     0.07724  0.05507  0.05321  0.05188  0.06009  0.05564
  AE-16 Tanh     0.27887  0.08614  0.08939  0.10874  0.12128  0.08714
Maximum (Test)   0.28718  0.28671  0.28648  0.10874  0.12128  0.10232
For each row, the largest (worst) loss is shown in red, and the lowest (best) loss is shown in green
We observed that C-Sigmoid has almost linear behaviour until λ = 0.8, after which it rapidly
adapts to the Sigmoid. Therefore, we needed an adaptive method that can reasonably
tune this λ update for different activation functions [52] and also adapt to the nature
of the homotopy path. We elaborate on these λ choices in Sect. 5.
Accordingly, we developed an adaptive method for determining the λ update by
utilizing information from the gradients during backpropagation, resulting in
Algorithm 1. NPAC has two benefits. First, it solves the issue of hand-picking the value
of λ. Second, this method provides a reasonable approach to determine how close
the next suitable problem (or λ value) should be, for which the current solution (say
at λ = 0.25) would be a good initialization.
Algorithm 1 Adaptive λ
Require: norm_grads - list of the norms of the gradients of the objective function from the previous
t steps, λ_{i,i−1} from the previous step, scale_up and scale_down factors.
1: avg_norm_p ← mean of first half of norm_grads
2: avg_norm_c ← mean of second half of norm_grads
3: condition1 ← (avg_norm_p − avg_norm_c) / avg_norm_p < −tolerance
4: condition2 ← (avg_norm_p − avg_norm_c) / avg_norm_p > tolerance
5: if condition1 then
6: λi+1,i ← λi,i−1 · scale_up
7: else if condition2 then
8: λ_{i+1,i} ← λ_{i,i−1} / scale_down
9: else
10: λi+1,i ← λi,i−1
11: end if
12: return λi+1,i
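Algorithm 1 transcribes almost directly into Python. The tolerance and scale factors below are illustrative defaults, not the chapter's tuned settings:

```python
def adaptive_lambda_step(norm_grads, lam_step, scale_up=2.0,
                         scale_down=2.0, tolerance=0.1):
    """Adaptive lambda update (a transcription of Algorithm 1).
    norm_grads holds the gradient norms of the objective from the
    previous t steps; lam_step is the previous lambda increment."""
    half = len(norm_grads) // 2
    avg_norm_p = sum(norm_grads[:half]) / half                      # earlier half
    avg_norm_c = sum(norm_grads[half:]) / (len(norm_grads) - half)  # recent half
    rel_change = (avg_norm_p - avg_norm_c) / avg_norm_p

    if rel_change < -tolerance:     # condition1: scale the step up
        return lam_step * scale_up
    elif rel_change > tolerance:    # condition2: scale the step down
        return lam_step / scale_down
    return lam_step                 # roughly flat: keep the step

# Rising gradient norms double the step; falling norms halve it.
bigger = adaptive_lambda_step([1, 1, 1, 1, 2, 2, 2, 2], 0.01)   # 0.02
smaller = adaptive_lambda_step([2, 2, 2, 2, 1, 1, 1, 1], 0.01)  # 0.005
```

The two conditions compare the relative change between the two halves of the gradient-norm window against the tolerance, exactly as in lines 3 and 4 of Algorithm 1.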
NPACS is a more advanced version of NPAC, where we enhance our method with
a secant step in the (multidimensional) θ space along with the adaptive λ update. A
secant line to a curve is one that passes through at least two distinct points [31],
and we use this method to find a linear approximation of the homotopy path in
a particular neighbourhood [2]. In the previous two methods, we simply assigned the
previous solution as the initialization for the current problem, whereas in NPACS
we take the previous two solutions and apply a secant approximation to initialize
the next step. An important step is to properly normalize the secant step for
the parameters (θ) of a neural network, which commonly number in the thousands or even
millions of dimensions, depending on the network size [52]. In Fig. 5, we illustrate
the geometric interpretation of Eq. (15). A clear advantage of NPACS is that the secant
update follows the homotopy curve more closely by approximating the derivative of
the curve [52] and initializes the subsequent problem accordingly. We also present
Algorithm 2, which demonstrates all the steps required to implement NPACS. Here we
perform model continuation for an AE (depth 8 and 16), using ADAM updates to
solve a particular optimization at λ_i. Depending on the problem, any other optimizer
may also be selected.
θ^init_{λ_{i+1}} ← θ_{λ_i} + (θ_{λ_i} − θ_{λ_{i−1}}) · (Δλ_{i+1,i} / Δλ_{i,i−1})    (15)
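The secant predictor of Eq. (15) is a one-line extrapolation on the flattened parameter vector; the numbers below form a hypothetical path, and the ADAM corrector step that would follow each prediction is omitted:

```python
import numpy as np

def secant_init(theta_i, theta_prev, dlam_next, dlam_prev):
    """Secant predictor of Eq. (15): extrapolate along the line through the
    last two solutions on the homotopy path, scaled by the ratio of the
    lambda step sizes. An ADAM corrector would then refine the result."""
    return theta_i + (theta_i - theta_prev) * (dlam_next / dlam_prev)

# Two previous solutions on a hypothetical homotopy path:
theta_prev = np.array([0.0, 1.0])   # solution at lambda_{i-1}
theta_i    = np.array([0.2, 1.1])   # solution at lambda_i

# Equal lambda steps simply continue the straight line through both points.
theta_init = secant_init(theta_i, theta_prev, dlam_next=0.1, dlam_prev=0.1)
# theta_init is [0.4, 1.2]
```

Scaling by the ratio of λ steps is what lets the adaptive step sizes of Algorithm 1 and the secant predictor work together: a half-sized λ step produces a half-sized extrapolation in θ.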
5 Experimental Results
We used two popular datasets for our experiments, namely, CIFAR-10 [37] and
Fashion-MNIST [67]. Fashion-MNIST has 55,000 training images and 10,000 test
images with ten different classes and each image has 784 dimensions. Similarly, for
CIFAR-10 we used 40,000 images for training and 10,000 as a test/validation set.
Each image in this dataset has 3072 (32 × 32 × 3) dimensions with ten different
classes. CIFAR-10 is a more challenging dataset than Fashion-MNIST [52] and is also
widely used by researchers. Next, we use autoencoders (AE), an unsupervised
learning method, to test our optimization technique. The employed AE is shown in
Fig. 6. In the case of the Fashion-MNIST dataset, the input is a 784-dimensional image
(or a 3072-dimensional image for the CIFAR-10 dataset). The AE is then used to
reconstruct the image from only a two-dimensional representation, which is
encoded in the hidden layer. In particular, two neural networks are evaluated, namely,
AE-8 and AE-16 of depth 8 and 16, respectively, for all our experiments.
We compare our parameter continuation methods, NPC (at λ = 8e−3), NPAC
and NPACS, against existing methods such as ADAM, SGD and RMSProp. Primarily,
task consistency plays a key role in providing conclusive empirical evaluations,
and we achieve it by keeping the data, architecture and hyperparameters fixed. As
explained in Sect. 4.3, SVD is used for the initialization of our network consistently.
Note that Sigmoid, Tanh and ReLU are used with the SGD, ADAM and RMSProp
optimization techniques, but when implementing continuation methods their
continuation counterparts were used, such as C-Sigmoid and C-ReLU from (5).
Finally, we open-source our code and all the hyperparameter choices on GitHub.3
Three important metrics for testing a new optimizer are the qualitative analysis of
training loss, generalization loss and convergence speed [28]. Table 1 depicts the
training and validation loss of our network (six variants or tasks) with six different
optimization methods and two popular datasets. Table 1 indicates that our continuation
methods consistently perform better on both validation and training loss. There are a
few more interesting conclusions that can be drawn from Table 1. First, as expected,
ADAM, RMSProp and continuation methods performed better than SGD. This also
shows the optimization tasks for the experiment are not trivial. Second, our network
variant AE-16 Sigmoid turns out to be the most challenging task, where ADAM and
RMSProp get stuck in a bad local minimum, which empirically shows these networks
are difficult to train. However, our methods were able to skip these sub-optimal local
minima and achieved a 49.18% lower training loss with 34.94% better generalization
compared to ADAM, as shown in Table 1. Similar results were observed with
the Fashion-MNIST dataset. Another optimization bottleneck was seen in the case of
AE-16 ReLU with the Fashion-MNIST data; here RMSProp was unable to avoid
a sub-optimal local minimum, while our methods achieved a 48.88% lower
training loss and a 30.08% better generalization (testing) error. Finally, we
qualitatively demonstrate the maximum loss attained by various optimizers at the
final step. Clearly, from Table 1, our parameter continuation methods have a much
lower maximum loss across all distinct tasks.
Further, to report convergence speed, we analyze the convergence plots of all the
distinct tasks. Of all plausible ways to define the speed of convergence, in this
chapter an optimizer is considered faster if it obtains a lower training loss at an earlier
step than a competing optimizer [52]. Hence, we studied the convergence plots
of the various optimizers and found that the continuation methods, i.e. NPC, NPAC and
NPACS, converged faster in the majority of the tasks. In Fig. 7, we can see that for the
AE-16 and AE-8 Sigmoid tasks, the continuation methods have a much lower training
error from the very beginning of training compared to both the RMSProp and
ADAM optimizers, consistent with our results in Table 1. Additionally, we extend our
results to three standard activation functions: C-Sigmoid in Fig. 7, C-ReLU in Fig. 8
and C-Tanh in Fig. 9. Through these figures we not only show that our generalization
loss is better in most cases but also that the continuation methods converge faster
(i.e. achieve a lower training loss in fewer ADAM updates).
3 https://github.com/harsh306/NPACS.
Fig. 7 The figures above demonstrate the convergence of various optimization methods when C-
Sigmoid is used as an activation function. In all cases, the X-axis shows the number of iterations,
and the Y-axis shows the log-loss or log-validation-loss. Note that the methods we propose provide
some of the lowest loss values, for training and testing, throughout the optimization procedure
Fig. 8 The figures above demonstrate the convergence of various optimization methods when C-
ReLU is used as an activation function. In all cases, the X-axis shows the number of iterations,
and the Y-axis shows the log-loss or log-validation-loss. Note that the methods we propose provide
some of the lowest loss values, for training and testing, throughout the optimization procedure
Fig. 9 The figures above demonstrate the convergence of various optimization methods when C-
Tanh is used as an activation function. In all cases, the X-axis shows the number of iterations, and
the Y-axis shows the log-loss or log-validation-loss. Note that the methods we propose provide
some of the lowest loss values, for training and testing, throughout the optimization procedure
6 Discussion and Open Research Questions
In [3, 50], the authors illustrated that continuation strategies can be introduced in
deep learning through model and data continuation. The methods illustrated in this
chapter are classical examples of model continuation. In Sect. 3, we discussed different
model continuation methods [21, 22, 44] and their advantages. These methods
have compared their training strategy with a target model and reported better gener-
alization error. However, we observe there are some limitations in these approaches.
Ideally, the proposed continuation techniques would be empirically evaluated
using different neural architectures and a variety of tasks. We
understand that such detailed evaluations may require substantial effort. First, we
observed that many different types of neural networks have not yet been system-
atically tested. For example, diffusion methods were applied on RNNs only [44].
MLPs and LSTMs were tested by [21], and only mollification was tested on CNNs,
RNNs and MLPs [22, 54]. In future research, it will be interesting to see how the
proposed curriculum affects the convergence of different types of networks. Second,
some methods [3, 22] were tested with limited depth and may not necessarily be
categorized as deep networks. Third, applications of curriculum learning may be
improved with more empirical evaluations from the methods in literature. In partic-
ular, it will be interesting to see comparisons with some techniques such as ResNets
[38], Dropout [58], and different normalizations for more tasks such as classification,
language modelling, regression and reconstruction [22].
As far as this chapter is concerned there are several natural extensions. First,
C-Activation can be any activation function, and in the case of multiple activation
functions in one network, we may add multiple homotopy parameters for each of
them. Also, as stated earlier, there is nothing in these techniques that makes them
specific in the AE case. In particular, we would also like to train Convolutional and
Recurrent Neural Networks described in [15, 39, 65] using our method. Second, our
adaptive method can be improved to conform to the homotopy path more accurately,
by using Pseudo-arclength methods [34, 48].
Many researchers have shown that learning data via a designed curriculum leads to
better learning and robustness to out-of-distribution error. Some of the popular methods
are progressive training [32, 50], Curriculum Learning [3, 23], Curriculum Learning
by Transfer Learning [66], C-SMOTE [50], etc. These methods have shown improvements
in generalization performance. However, there is usually some limitation
arising from the data type. For example, the design of Progressive GANs [32] was
applicable to images but may prove difficult to apply to a text dataset. For
future research it will be interesting to see how data curriculum methods work with
different kinds of data, such as images, text, sound and time series. As a first step
in this direction, we collect and share a list of datasets.4 These datasets are indexed
from different sources, and we hope they prove to be a useful resource for fellow
researchers. Recent research illustrates how a data curriculum can be learned
via continuation [50], and automation can be introduced by measuring a model's
learning [18, 30, 55, 59, 66, 68]; the existing benchmarks are limited and could
likely be extended in several directions.
A Deep Neural Network’s loss surface is usually non-convex [8, 17, 27], and depend-
ing on the neural architecture, many attempts have been made to theoretically catego-
rize the loss surfaces [17, 33, 41]. Continuation methods or curriculum learning may
provide a unique perspective to understand the loss surfaces of the neural network. In
particular, we plan to extend our work to perform Bifurcation Analysis of Neural
Networks. Bifurcations [2, 11, 49] are the dynamic changes in the behaviour of the
system that may occur while we track the solution through the homotopy path. In
our case, as we change our activations from linear to non-linear, detected bifurcations
[34] of neural networks may help us explain and identify local minima, and they may
also help us to better understand the so-called “black-box” of neural networks.
It can be difficult to design the right curriculum for any given task [63]. In [63], the
authors empirically show that a sometimes intuitively sound curriculum may even harm
training (i.e. learning harder tasks can lead to better solutions [56]). As a result, many research
works have designed self-generative curricula for training environments to introduce
progressively difficult and diverse scenarios for an agent and thus increase robust
learning [1, 59, 63]. One can easily imagine continuation methods having a role to
play for improving such methods.
In multi-task learning [6], a model is designed to perform multiple tasks through the
same set of parameters. Designing such a system is more difficult, as learning from
one task may elevate or degrade the learning of another task and may also lead to
catastrophic forgetting [4]. Synchronizing a neural network to perform multiple tasks
is an open research problem, and recently curriculum learning has been used to meet
this challenge [53, 64]. These approaches can be broadly categorized in two ways.
First, design a curriculum that dynamically provides selected batches of data so
that it benefits the multi-task objective [61, 64]; implementing such an automated
system usually requires determining the importance of each instance for learning a
task, for which Bayesian optimization is a classic choice [64]. Second, determine the
best order in which tasks should be learned [53]. The latter approach is restricted to a
sequential implementation, following the determined curriculum [53]. We expect
progressive research in this direction may significantly enhance multi-task learning.
Of the above-mentioned approaches, the first is similar to data continuation and the
second is similar to model continuation. We believe that we can draw similarities
between the two and make impactful progress in the field of multi-task learning.
Meta-learning is an active area of research [14] which was recently surveyed in [9]. There are many aspects of meta-learning [14, 46, 62], but for the scope of our discussion we limit our focus to hyperparameter search. The search space of different hyperparameters can be traversed efficiently using continuation methods, or even multi-parameter continuation [2, 11, 34]. One can imagine two hyperparameters λm and λd, where λm corresponds to model continuation and λd enables data continuation. We could then apply continuation methods to optimize both strategies simultaneously for a given neural network task. Recently, researchers used implicit differentiation for hyperparameter optimization [42], in which millions of network weights and hyperparameters can be tuned jointly.
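The two-hyperparameter idea can be sketched as a warm-started sweep along a path through (λm, λd) space. Here `train_step` is an assumed user-supplied routine, and tracing the two parameters in lockstep is only one of many possible continuation paths; the core idea is simply that each stage starts from the previous stage's solution:

```python
def two_parameter_continuation(train_step, lambdas_m, lambdas_d, params):
    """Warm-start training along a path through (λm, λd) space.

    `train_step(params, lam_m, lam_d)` returns updated parameters for
    one continuation stage. Reusing each stage's solution as the next
    stage's starting point is the essence of parameter continuation.
    """
    path = []
    for lam_m, lam_d in zip(lambdas_m, lambdas_d):
        params = train_step(params, lam_m, lam_d)
        path.append((lam_m, lam_d, params))
    return params, path
```

A genuine multi-parameter continuation method [2, 11, 34] would also adapt the step sizes along the path and track solution branches, which this sketch omits.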
Multi-parameter Continuation Optimization is a promising area of research
in fields such as mathematical analysis [2, 11, 34] and others [19, 29]. One may
Non-convex Optimization Using Parameter Continuation … 295
7 Conclusions
In this chapter, we presented a novel training strategy for deep neural networks. As observed in most of the experiments, all three continuation methods achieved faster convergence, lower loss, and better generalization than standard optimization techniques. We also empirically showed that the proposed method works with popular activation functions, deeper networks, and distinct datasets. Finally, we examined some possible improvements and provided various future directions for further advances in deep learning using continuation methods.
References
15. I. Goodfellow, Y. Bengio, A. Courville, Deep Learning (MIT Press, 2016). http://www.
deeplearningbook.org
16. I.J. Goodfellow, NIPS 2016 tutorial: generative adversarial networks. NIPS (2017).
arXiv:abs/1701.00160
17. I.J. Goodfellow, O. Vinyals, Qualitatively characterizing neural network optimization problems.
CoRR (2014). arXiv:abs/1412.6544
18. A. Graves, M.G. Bellemare, J. Menick, R. Munos, K. Kavukcuoglu, Automated curriculum
learning for neural networks. CoRR (2017). arXiv:abs/1704.03003
19. C. Grenat, S. Baguet, C.H. Lamarque, R. Dufour, A multi-parametric recursive continuation
method for nonlinear dynamical systems. Mech. Syst. Signal Process. 127, 276–289 (2019)
20. M. Grzes, D. Kudenko, Theoretical and empirical analysis of reward shaping in reinforcement
learning, in 2009 International Conference on Machine Learning and Applications (2009), pp.
337–344. https://doi.org/10.1109/ICMLA.2009.33
21. C. Gülçehre, M. Moczulski, M. Denil, Y. Bengio, Noisy activation functions. CoRR (2016).
arXiv:abs/1603.00391
22. C. Gülçehre, M. Moczulski, F. Visin, Y. Bengio, Mollifying networks. CoRR (2016).
arXiv:abs/1608.04980
23. G. Hacohen, D. Weinshall, On the power of curriculum learning in training deep networks.
CoRR (2019). arXiv:abs/1904.03626
24. G. Hinton, L. Deng, D. Yu, G.E. Dahl, A. Mohamed, N. Jaitly, A. Senior, V. Vanhoucke, P.
Nguyen, T.N. Sainath, B. Kingsbury, Deep neural networks for acoustic modeling in speech
recognition: the shared views of four research groups. IEEE Signal Process. Mag. 29(6), 82–97
(2012). https://doi.org/10.1109/MSP.2012.2205597
25. G. Hinton, N. Srivastava, K. Swersky, Rmsprop: divide the gradient by a running average of
its recent magnitude. Neural networks for machine learning, Coursera lecture 6e (2012)
26. G.E. Hinton, R.R. Salakhutdinov, Reducing the dimensionality of data with neural networks.
Science 313(5786), 504–507 (2006). https://doi.org/10.1126/science.1127647, http://science.
sciencemag.org/content/313/5786/504
27. D.J. Im, M. Tao, K. Branson, An empirical analysis of deep network loss surfaces. CoRR
(2016). arXiv:abs/1612.04010
28. D. Jakubovitz, R. Giryes, M.R. Rodrigues, Generalization error in deep learning, in Compressed
Sensing and Its Applications (Springer, 2019), pp. 153–193
29. F. Jalali, J. Seader, Homotopy continuation method in multi-phase multi-reaction equilibrium
systems. Comput. Chem. Eng. 23(9), 1319–1331 (1999)
30. L. Jiang, Z. Zhou, T. Leung, L.J. Li, L. Fei-Fei, Mentornet: learning data-driven curriculum
for very deep neural networks on corrupted labels, in ICML (2018)
31. R. Johnson, F. Kiokemeister, Calculus, with Analytic Geometry (Allyn and Bacon, 1964).
https://books.google.com/books?id=X4_UAQAACAAJ
32. T. Karras, T. Aila, S. Laine, J. Lehtinen, Progressive growing of gans for improved quality,
stability, and variation. CoRR (2017). arXiv:abs/1710.10196
33. K. Kawaguchi, L.P. Kaelbling, Elimination of all bad local minima in deep learning. CoRR
(2019)
34. H.B. Keller, Numerical solution of bifurcation and nonlinear eigenvalue problems, in Applica-
tions of Bifurcation Theory, ed. by P.H. Rabinowitz (Academic Press, New York, 1977), pp.
359–384
35. D.P. Kingma, J. Ba, Adam: a method for stochastic optimization. CoRR (2014).
arXiv:abs/1412.6980
36. P. Krähenbühl, C. Doersch, J. Donahue, T. Darrell, Data-dependent initializations of convolu-
tional neural networks. CoRR (2015). arXiv:abs/1511.06856
37. A. Krizhevsky, V. Nair, G. Hinton, Cifar-10 (canadian institute for advanced research). http://
www.cs.toronto.edu/kriz/cifar.html
38. A. Krizhevsky, I. Sutskever, G.E. Hinton, Imagenet classification with deep convolutional
neural networks, in Advances in Neural Information Processing Systems (2012)
59. F.P. Such, A. Rawal, J. Lehman, K. Stanley, J. Clune, Generative teaching networks: acceler-
ating neural architecture search by learning to generate synthetic training data (2020)
60. I. Sutskever, J. Martens, G. Dahl, G. Hinton, On the importance of initialization and momentum
in deep learning, in International Conference on Machine Learning (2013), pp. 1139–1147
61. Y. Tsvetkov, M. Faruqui, W. Ling, B. MacWhinney, C. Dyer, Learning the curriculum with
Bayesian optimization for task-specific word representation learning, in Proceedings of the
54th Annual Meeting of the Association for Computational Linguistics, Long Papers, vol. 1
(Association for Computational Linguistics, Berlin, Germany, 2016), pp. 130–139. https://doi.
org/10.18653/v1/P16-1013, https://www.aclweb.org/anthology/P16-1013
62. R. Vilalta, Y. Drissi, A perspective view and survey of meta-learning. Artif. Intell. Rev. 18(2),
77–95 (2002)
63. R. Wang, J. Lehman, J. Clune, K.O. Stanley, Paired open-ended trailblazer (POET): endlessly
generating increasingly complex and diverse learning environments and their solutions. CoRR
(2019). arXiv:abs/1901.01753
64. W. Wang, Y. Tian, J. Ngiam, Y. Yang, I. Caswell, Z. Parekh, Learning a multitask curriculum
for neural machine translation (2019). arXiv:abs/1908.10940
65. M.A. Wani, F.A. Bhat, S. Afzal, A.I. Khan, Advances in deep learning, in Advances in Deep
Learning (Springer, 2020), pp. 1–11
66. D. Weinshall, G. Cohen, Curriculum learning by transfer learning: theory and experiments with
deep networks. CoRR (2018). arXiv:abs/1802.03796
67. H. Xiao, K. Rasul, R. Vollgraf, Fashion-mnist: a novel image dataset for benchmarking machine
learning algorithms (2017). https://github.com/zalandoresearch/fashion-mnist
68. H. Xuan, A. Stylianou, R. Pless, Improved embeddings with easy positive triplet mining (2019)
69. C. Zhou, R.C. Paffenroth, Anomaly detection with robust deep autoencoders, in Proceedings of
the 23rd ACM SIGKDD International Conference on Knowledge Discovery and Data Mining
(ACM, 2017), pp. 665–674
Author Index
A
Alfarhood, Meshal, 1
Arif Wani, M., 101

C
Cheng, Jianlin, 1
Christopoulos, Stavros-Richard G., 229
Clinton Paffenroth, Randy, 273

D
da Silva, Gabriel Pellegrino, 49

F
Fischer, Georg, 173

G
Gaffar, Ashraf, 123
Gühmann, Clemens, 81

H
Hartmann, Sven, 81

I
Imran, Abdullah-Al-Zubaer, 249

J
Johnson, Justin M., 199

K
Kamran, Sharif Amit, 25
Kanarachos, Stratis, 229
Khoshgoftaar, Taghi M., 199
Kouchak, Shokoufeh Monjezi, 123

L
Leite, Guilherme Vieira, 49

M
Mujtaba, Tahir, 101
Mukherjee, Tathagata, 143

N
Nilesh Pathak, Harsh, 273

O
Onyekpe, Uche, 229

P
Palade, Vasile, 229
Pasiliao, Eduardo, 143
Pedrini, Helio, 49

R
Rosato, Daniele, 81
Roy, Debashri, 143
© The Editor(s) (if applicable) and The Author(s), under exclusive license 299
to Springer Nature Singapore Pte Ltd. 2021
M. A. Wani et al. (eds.), Deep Learning Applications, Volume 2,
Advances in Intelligent Systems and Computing 1232,
https://doi.org/10.1007/978-981-15-6759-9
S
Sabbir, Ali Shihab, 25
Sabir, Russell, 81
Saha, Sourajit, 25
Santra, Avik, 173
Stephan, Michael, 173

T
Tavakkoli, Alireza, 25
Terzopoulos, Demetri, 249