
Advances in Intelligent Systems and Computing 1232

M. Arif Wani
Taghi M. Khoshgoftaar
Vasile Palade Editors

Deep Learning
Applications,
Volume 2
Advances in Intelligent Systems and Computing

Volume 1232

Series Editor
Janusz Kacprzyk, Systems Research Institute, Polish Academy of Sciences,
Warsaw, Poland

Advisory Editors
Nikhil R. Pal, Indian Statistical Institute, Kolkata, India
Rafael Bello Perez, Faculty of Mathematics, Physics and Computing,
Universidad Central de Las Villas, Santa Clara, Cuba
Emilio S. Corchado, University of Salamanca, Salamanca, Spain
Hani Hagras, School of Computer Science and Electronic Engineering,
University of Essex, Colchester, UK
László T. Kóczy, Department of Automation, Széchenyi István University,
Gyor, Hungary
Vladik Kreinovich, Department of Computer Science, University of Texas
at El Paso, El Paso, TX, USA
Chin-Teng Lin, Department of Electrical Engineering, National Chiao
Tung University, Hsinchu, Taiwan
Jie Lu, Faculty of Engineering and Information Technology,
University of Technology Sydney, Sydney, NSW, Australia
Patricia Melin, Graduate Program of Computer Science, Tijuana Institute
of Technology, Tijuana, Mexico
Nadia Nedjah, Department of Electronics Engineering, University of Rio de Janeiro,
Rio de Janeiro, Brazil
Ngoc Thanh Nguyen, Faculty of Computer Science and Management,
Wrocław University of Technology, Wrocław, Poland
Jun Wang, Department of Mechanical and Automation Engineering,
The Chinese University of Hong Kong, Shatin, Hong Kong
The series “Advances in Intelligent Systems and Computing” contains publications
on theory, applications, and design methods of Intelligent Systems and Intelligent
Computing. Virtually all disciplines such as engineering, natural sciences, computer
and information science, ICT, economics, business, e-commerce, environment,
healthcare, life science are covered. The list of topics spans all the areas of modern
intelligent systems and computing such as: computational intelligence, soft comput-
ing including neural networks, fuzzy systems, evolutionary computing and the fusion
of these paradigms, social intelligence, ambient intelligence, computational neuro-
science, artificial life, virtual worlds and society, cognitive science and systems,
Perception and Vision, DNA and immune based systems, self-organizing and
adaptive systems, e-Learning and teaching, human-centered and human-centric
computing, recommender systems, intelligent control, robotics and mechatronics
including human-machine teaming, knowledge-based paradigms, learning para-
digms, machine ethics, intelligent data analysis, knowledge management, intelligent
agents, intelligent decision making and support, intelligent network security, trust
management, interactive entertainment, Web intelligence and multimedia.
The publications within “Advances in Intelligent Systems and Computing” are
primarily proceedings of important conferences, symposia and congresses. They
cover significant recent developments in the field, both of a foundational and
applicable character. An important characteristic feature of the series is the short
publication time and world-wide distribution. This permits a rapid and broad
dissemination of research results.
** Indexing: The books of this series are submitted to ISI Proceedings,
EI-Compendex, DBLP, SCOPUS, Google Scholar and Springerlink **

More information about this series at http://www.springer.com/series/11156


M. Arif Wani · Taghi M. Khoshgoftaar · Vasile Palade
Editors

Deep Learning Applications,


Volume 2

Editors

M. Arif Wani
Department of Computer Science
University of Kashmir
Srinagar, India

Taghi M. Khoshgoftaar
Computer and Electrical Engineering
Florida Atlantic University
Boca Raton, FL, USA

Vasile Palade
Faculty of Engineering and Computing
Coventry University
Coventry, UK

Advances in Intelligent Systems and Computing
ISSN 2194-5357          ISSN 2194-5365 (electronic)
ISBN 978-981-15-6758-2          ISBN 978-981-15-6759-9 (eBook)
https://doi.org/10.1007/978-981-15-6759-9

© The Editor(s) (if applicable) and The Author(s), under exclusive license to Springer Nature
Singapore Pte Ltd. 2021
This work is subject to copyright. All rights are solely and exclusively licensed by the Publisher, whether
the whole or part of the material is concerned, specifically the rights of translation, reprinting, reuse of
illustrations, recitation, broadcasting, reproduction on microfilms or in any other physical way, and
transmission or information storage and retrieval, electronic adaptation, computer software, or by similar
or dissimilar methodology now known or hereafter developed.
The use of general descriptive names, registered names, trademarks, service marks, etc. in this
publication does not imply, even in the absence of a specific statement, that such names are exempt from
the relevant protective laws and regulations and therefore free for general use.
The publisher, the authors and the editors are safe to assume that the advice and information in this
book are believed to be true and accurate at the date of publication. Neither the publisher nor the
authors or the editors give a warranty, expressed or implied, with respect to the material contained
herein or for any errors or omissions that may have been made. The publisher remains neutral with regard
to jurisdictional claims in published maps and institutional affiliations.

This Springer imprint is published by the registered company Springer Nature Singapore Pte Ltd.
The registered company address is: 152 Beach Road, #21-01/04 Gateway East, Singapore 189721,
Singapore
Preface

Machine learning algorithms have influenced many aspects of our day-to-day living
and transformed major industries around the world. Fueled by an exponential
growth of data, improvements in computer hardware, scalable cloud resources, and
accessible open-source frameworks, machine learning technology is being used by
companies big and small for innumerable applications. At home, machine
learning models are suggesting TV shows, movies, and music for entertainment,
providing personalized ecommerce suggestions, shaping our digital social net-
works, and improving the efficiency of our appliances. At work, these data-driven
methods are filtering our emails, forecasting trends in productivity and sales, tar-
geting customers with advertisements, improving the quality of video conferences,
and guiding critical decisions. At the frontier of machine learning innovation are
deep learning systems, a class of multi-layered networks capable of automatically
learning meaningful hierarchical representations from a variety of structured and
unstructured data. Breakthroughs in deep learning allow us to generate new rep-
resentations, extract knowledge, and draw inferences from raw images, video
streams, text and speech, time series, and other complex data types. These powerful
deep learning methods are being applied to new and exciting real-world problems in
medical diagnostics, factory automation, public safety, environmental sciences,
autonomous transportation, military applications, and much more.
The family of deep learning architectures continues to grow as new methods and
techniques are developed to address a wide variety of problems. A deep learning
network is composed of multiple layers that form universal approximators capable
of learning any function. For example, the convolutional layers in Convolutional
Neural Networks use shared weights and spatial invariance to efficiently learn
hierarchical representations from images, natural language, and temporal data.
Recurrent Neural Networks use backpropagation through time to learn from vari-
able length sequential data. Long Short-Term Memory networks are a type of
recurrent network capable of learning order dependence in sequence prediction
problems. Deep Belief Networks, Autoencoders, and other unsupervised models
generate meaningful latent features for downstream tasks and model the underlying
concepts of distributions by reconstructing their inputs. Generative Adversarial


Networks simultaneously learn generative models capable of producing new data


from a distribution and discriminative models that can distinguish between real and
artificial images. Transformer Networks combine encoders and decoders with
attention layers for improved sequence-to-sequence learning. Network architecture
search automates the design of these deep models by optimizing performance over
the hyperparameter space. As a result of these advances, and many others, deep
learning is revolutionizing complex problem domains with state-of-the-art results
and, in some cases, surpasses human performance.
This book explores some of the latest applications in deep learning and includes
a variety of architectures and novel deep learning techniques. Deep models are
trained to recommend products, diagnose medical conditions or faults in industrial
machines, detect when a human falls, and recognize solar panels in aerial images.
Sequence models are used to capture driving behaviors and identify radio trans-
mitters from temporal data. Residual networks are used to detect human targets in
indoor environments, an algorithm incorporating a thresholding strategy is used to
identify fraud within highly imbalanced data, and hybrid methods are used to locate
vehicles during satellite outages. A multi-adversarial variational autoencoder network
is used for image synthesis and classification, and finally a parameter continuation
method is used for non-convex optimization of deep neural networks. We believe
that these recent deep learning methods and applications illustrated in this book
capture some of the most exciting advances in deep learning.

Srinagar, India M. Arif Wani


Boca Raton, USA Taghi M. Khoshgoftaar
Coventry, UK Vasile Palade
Contents

Deep Learning-Based Recommender Systems . . . . . . . . . . . . . . . . . . . . 1


Meshal Alfarhood and Jianlin Cheng
A Comprehensive Set of Novel Residual Blocks for Deep Learning
Architectures for Diagnosis of Retinal Diseases from Optical
Coherence Tomography Images . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . 25
Sharif Amit Kamran, Sourajit Saha, Ali Shihab Sabbir,
and Alireza Tavakkoli
Three-Stream Convolutional Neural Network for Human
Fall Detection . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . 49
Guilherme Vieira Leite, Gabriel Pellegrino da Silva, and Helio Pedrini
Diagnosis of Bearing Faults in Electrical Machines Using Long
Short-Term Memory (LSTM) . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . 81
Russell Sabir, Daniele Rosato, Sven Hartmann, and Clemens Gühmann
Automatic Solar Panel Detection from High-Resolution Orthoimagery
Using Deep Learning Segmentation Networks . . . . . . . . . . . . . . . . . . . . 101
Tahir Mujtaba and M. Arif Wani
Training Deep Learning Sequence Models to Understand
Driver Behavior . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . 123
Shokoufeh Monjezi Kouchak and Ashraf Gaffar
Exploiting Spatio-Temporal Correlation in RF Data
Using Deep Learning . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . 143
Debashri Roy, Tathagata Mukherjee, and Eduardo Pasiliao
Human Target Detection and Localization with Radars
Using Deep Learning . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . 173
Michael Stephan, Avik Santra, and Georg Fischer


Thresholding Strategies for Deep Learning with Highly Imbalanced


Big Data . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . 199
Justin M. Johnson and Taghi M. Khoshgoftaar
Vehicular Localisation at High and Low Estimation Rates During
GNSS Outages: A Deep Learning Approach . . . . . . . . . . . . . . . . . . . . . 229
Uche Onyekpe, Stratis Kanarachos, Vasile Palade,
and Stavros-Richard G. Christopoulos
Multi-Adversarial Variational Autoencoder Nets for Simultaneous
Image Generation and Classification . . . . . . . . . . . . . . . . . . . . . . . . . . . 249
Abdullah-Al-Zubaer Imran and Demetri Terzopoulos
Non-convex Optimization Using Parameter Continuation Methods
for Deep Neural Networks . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . 273
Harsh Nilesh Pathak and Randy Clinton Paffenroth

Author Index . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . 299


Editors and Contributors

About the Editors

Dr. M. Arif Wani is a Professor at the University of Kashmir, having previously


served as a Professor at California State University, Bakersfield. He completed his
M.Tech. in Computer Technology at the Indian Institute of Technology, Delhi, and
his Ph.D. in Computer Vision at Cardiff University, UK. His research interests are
in the area of machine learning, with a focus on neural networks, deep learning,
inductive learning, and support vector machines, and with application to areas that
include computer vision, pattern recognition, classification, prediction, and analysis
of gene expression datasets. He has published many papers in reputed journals and
conferences in these areas. Dr. Wani has co-authored the book ‘Advances in Deep
Learning,’ co-edited the book ‘Deep Learning Applications,’ and co-edited 17
conference proceeding books in machine learning and applications area. He is a
member of many academic and professional bodies, e.g., the Indian Society for
Technical Education, Computer Society of India, and IEEE USA.

Dr. Taghi M. Khoshgoftaar is the Motorola Endowed Chair Professor of the
Department of Computer and Electrical Engineering and Computer Science, Florida
Atlantic University, and the Director of NSF Big Data Training and Research
Laboratory. His research interests are in big data analytics, data mining and
machine learning, health informatics and bioinformatics, social network mining,
and software engineering. He has published more than 750 refereed journal and
conference papers in these areas. He was the Conference Chair of the IEEE
International Conference on Machine Learning and Applications (ICMLA 2019).
He is the Co-Editor-in-Chief of the Journal of Big Data. He has served on orga-
nizing and technical program committees of various international conferences,
symposia, and workshops. He has been a Keynote Speaker at multiple international


conferences and has given many invited talks at various venues. Also, he has served
as North American Editor of the Software Quality Journal, was on the editorial
boards of the journals Multimedia Tools and Applications, Knowledge and
Information Systems, and Empirical Software Engineering, and is on the editorial
boards of the journals Software Quality, Software Engineering and Knowledge
Engineering, and Social Network Analysis and Mining.

Dr. Vasile Palade is currently a Professor of Artificial Intelligence and Data


Science at Coventry University, UK. He previously held several academic and
research positions at the University of Oxford—UK, University of Hull—UK, and
the University of Galati—Romania. His research interests are in the area of machine
learning, with a focus on neural networks and deep learning, and with main
application to image processing, social network data analysis and web mining,
smart cities, health, among others. Dr. Palade is author and co-author of more than
170 papers in journals and conference proceedings as well as several books on
machine learning and applications. He is an Associate Editor for several reputed
journals, such as Knowledge and Information Systems and Neurocomputing. He
has delivered keynote talks to international conferences on machine learning and
applications. Dr. Vasile Palade is an IEEE Senior Member.

Contributors

Meshal Alfarhood Department of Electrical Engineering and Computer Science,


University of Missouri-Columbia, Columbia, USA
M. Arif Wani Department of Computer Science, University of Kashmir, Srinagar,
India
Jianlin Cheng Department of Electrical Engineering and Computer Science,
University of Missouri-Columbia, Columbia, USA
Stavros-Richard G. Christopoulos Institute for Future Transport and Cities,
Coventry University, Coventry, UK;
Faculty of Engineering, Coventry University, Coventry, UK
Randy Clinton Paffenroth Worcester Polytechnic Institute, Mathematical
Sciences Computer Science & Data Science, Worcester, MA, USA
Gabriel Pellegrino da Silva Institute of Computing, University of Campinas,
Campinas, SP, Brazil
Georg Fischer Friedrich-Alexander-University Erlangen-Nuremberg, Erlangen,
Germany
Ashraf Gaffar Arizona State University, Tempe, USA

Clemens Gühmann Chair of Electronic Measurement and Diagnostic Technology


& Technische Universität Berlin, Berlin, Germany
Sven Hartmann SEG Automotive Germany GmbH, Stuttgart, Germany
Abdullah-Al-Zubaer Imran University of California, Los Angeles, CA, USA
Justin M. Johnson Florida Atlantic University, Boca Raton, FL, USA
Sharif Amit Kamran University of Nevada, Reno, NV, USA
Stratis Kanarachos Faculty of Engineering, Coventry University, Coventry, UK
Taghi M. Khoshgoftaar Florida Atlantic University, Boca Raton, FL, USA
Shokoufeh Monjezi Kouchak Arizona State University, Tempe, USA
Guilherme Vieira Leite Institute of Computing, University of Campinas,
Campinas, SP, Brazil
Tahir Mujtaba Department of Computer Science, University of Kashmir,
Srinagar, India
Tathagata Mukherjee Computer Science, University of Alabama, Huntsville,
AL, USA
Harsh Nilesh Pathak Expedia Group, Seattle, WA, USA
Uche Onyekpe Institute for Future Transport and Cities, Coventry University,
Coventry, UK;
Research Center for Data Science, Coventry University, Coventry, UK
Vasile Palade Research Center for Data Science, Coventry University, Coventry,
UK
Eduardo Pasiliao Munitions Directorate, Air Force Research Laboratory, Eglin
AFB, Valparaiso, FL, USA
Helio Pedrini Institute of Computing, University of Campinas, Campinas, SP,
Brazil
Daniele Rosato SEG Automotive Germany GmbH, Stuttgart, Germany
Debashri Roy Computer Science, University of Central Florida, Orlando, FL,
USA
Russell Sabir SEG Automotive Germany GmbH, Stuttgart, Germany;
Chair of Electronic Measurement and Diagnostic Technology & Technische
Universität Berlin, Berlin, Germany
Ali Shihab Sabbir Center for Cognitive Skill Enhancement, Independent
University Bangladesh, Dhaka, Bangladesh

Sourajit Saha Center for Cognitive Skill Enhancement, Independent University


Bangladesh, Dhaka, Bangladesh
Avik Santra Infineon Technologies AG, Neubiberg, Germany
Michael Stephan Infineon Technologies AG, Neubiberg, Germany;
Friedrich-Alexander-University Erlangen-Nuremberg, Erlangen, Germany
Alireza Tavakkoli University of Nevada, Reno, NV, USA
Demetri Terzopoulos University of California, Los Angeles, CA, USA
Deep Learning-Based Recommender
Systems

Meshal Alfarhood and Jianlin Cheng

Abstract The term “information overload” has gained popularity over the last few
years. It defines the difficulties people face in finding what they want from a huge
volume of available information. Recommender systems have been recognized to be
an effective solution to such issues, such that suggestions are made based on users’
preferences. This chapter introduces an application of deep learning techniques in
the domain of recommender systems. Generally, collaborative filtering approaches,
and Matrix Factorization (MF) techniques in particular, are widely known for their
convincing performance in recommender systems. We introduce a Collaborative
Attentive Autoencoder (CATA) that improves the matrix factorization performance
by leveraging an item’s contextual data. Specifically, CATA learns the proper features
from scientific articles through the attention mechanism that can capture the most
pertinent parts of information in order to make better recommendations. The learned
features are then incorporated into the learning process of MF. Comprehensive exper-
iments on three real-world datasets have shown our method performs better than other
state-of-the-art methods according to various evaluation metrics. The source code of
our model is available at: https://github.com/jianlin-cheng/CATA.

This chapter is an extended version of our published paper at the IEEE ICMLA conference 2019
[1]. This chapter incorporates new experimental contributions compared to the original conference paper.

M. Alfarhood (B) · J. Cheng


Department of Electrical Engineering and Computer Science,
University of Missouri-Columbia, Columbia, USA
e-mail: may82@missouri.edu
J. Cheng
e-mail: chengji@missouri.edu

© The Editor(s) (if applicable) and The Author(s), under exclusive license
to Springer Nature Singapore Pte Ltd. 2021
M. A. Wani et al. (eds.), Deep Learning Applications, Volume 2,
Advances in Intelligent Systems and Computing 1232,
https://doi.org/10.1007/978-981-15-6759-9_1

1 Introduction

The era of e-commerce has vastly changed people’s lifestyles during the first part
of the twenty-first century. People today tend to do many of their daily routines
online, such as shopping, reading the news, and watching movies. Nevertheless,
consumers often face difficulties while exploring related items such as new fashion
trends because they are not aware of their existence due to the overwhelming amount
of information available online. This phenomenon is widely known as “information
overload”. Therefore, Recommender Systems (RSs) are a critical solution for helping
users make decisions when there are lots of choices. RSs have been integrated into
and have become an essential part of every website due to their impact on increasing
customer interactions, attracting new customers, and growing businesses’ revenue.
Scientific article recommendation is a very common application for RSs. It keeps
researchers updated on recent related work in their field. One traditional way to
find relevant articles is to go through the references section in other articles. Yet,
this approach is biased toward heavily cited articles, such that new relevant articles
with higher impact have less chance to be found. Another method is to search for
articles using keywords. Although this technique is popular among researchers, they
must filter out a tremendous number of articles from the search results to retrieve
the most suitable articles. Moreover, all users get the same search results with the
same keywords, and these results are not personalized based on the users’ personal
interests. Thus, recommendation systems can address this issue and help scientists
and researchers find valuable articles while being aware of recent related work.
Over the last few decades, a lot of effort has been made by both academia and
industry on proposing new ideas and solutions for RSs, which ultimately help ser-
vice providers in adopting such models in their system architecture. The research in
RSs has evolved remarkably following the Netflix prize competition1 in 2006, where
the company offered one million dollars for any team that could improve their rec-
ommendation accuracy by 10%. Since that time, collaborative filtering models and
matrix factorization techniques in particular have become the most common models
due to their effective performance. Generally, recommendation models are classified
into three categories: Collaborative Filtering Models (CF), Content-Based Filter-
ing models (CBF), and hybrid models. CF models [2–4] focus on users’ histories,
such that users with similar past behaviors tend to have similar future tastes. On the
other hand, CBF models work by learning the item’s features from its informational
description, such that two items are possibly similar to each other if they share more
characteristics. For example, two songs are similar to each other if they both share
the same artist, genre, tempo, energy, etc. However, similarities between items in CF
models are different such that two items are likely similar to each other once they are
rated by multiple users in the same manner, even though those items have different
characteristics.

1 www.netflixprize.com.

Generally, CF models function better than CBF models. However, CF performance


drops substantially when users or items have an insufficient amount of feedback
data. This problem is defined as the data sparsity problem. To tackle data sparseness,
hybrid models have been widely proposed in recent works [5–8], in which content
information, used in CBF models, is incorporated with CF models to improve the
system performance. Hybrid models are divided into two sub-categories according
to how models are trained: loosely coupled models and tightly coupled models [7].
Loosely coupled models train CF and CBF models separately, like ensembles, and
then determine the final score based on the scores of the two separated models. On
the other hand, the tightly coupled models train both CF and CBF models jointly.
In joint training, both models cooperate with one another to calculate the final score
under the same loss function.
Even though traditional recommendation approaches have achieved great success
over the last years, they still have shortcomings in accurately modeling complex
(e.g., non-linear) relationship between users and items. Alternatively, deep neural
networks are universal function approximators that are capable of modeling any con-
tinuous function. Recently, Deep Learning (DL) has become an effective approach
for most data mining problems. DL has found its way into recommendation systems in the last few
years. One of the first works that applied the DL concept to CF recommendations was
Restricted Boltzmann Machines (RBM) [4]. However, this approach was not deep
enough (two layers only) to learn users’ tastes from their histories, and it also did not
take contextual information into consideration. Recently, Collaborative Deep Learn-
ing (CDL) [7] has become a very popular deep learning technique in RSs due to its
promising performance. CDL can be viewed as an updated version of Collaborative
Topic Regression (CTR) [5] by substituting the Latent Dirichlet Allocation (LDA)
topic modeling with a Stacked Denoising Autoencoder (SDAE) to learn from item
contents, and then integrating the learned latent features into a Probabilistic Matrix
Factorization (PMF). Lately, Collaborative Variational Autoencoder (CVAE) [8] has
been proposed to learn deep item latent features via a variational autoencoder. The
authors show that their model learns better item features than CDL because their
model infers the latent variable distribution in latent space instead of observation
space. However, both the CDL and CVAE models assume that all parts of the item
content contribute equally to the final predictions.
Hence, in this work, we propose a deep learning-based model named Collaborative
Attentive Autoencoder (CATA) for recommending scientific articles. In particular,
we integrate the attention mechanism into our unsupervised deep learning process
to identify an item’s features. We learn the item’s features from the article’s textual
information (e.g., the article’s title and abstract) to enhance the recommendation
quality. The compressed low-dimensional representation learned by the unsupervised
model is then incorporated into the matrix factorization approach for our ultimate
recommendation. To demonstrate the capability of our proposed model to generate
more relevant recommendations, we conduct inclusive experiments on three real-
world datasets, which are taken from the CiteULike2 website, to evaluate CATA

2 www.citeulike.org.

against multiple recent works. The experimental results prove that our model can
extract more constructive information from an article’s contextual data than other
models. More importantly, CATA performs very well where the data sparsity is
extremely high.
The remainder of this chapter is organized in the following manner. First, we
demonstrate the matrix factorization method in Sect. 2. We introduce our model,
CATA, in Sect. 3. The experimental results of our model against the state-of-the-art
models are discussed thoroughly in Sect. 4. We then conclude our work in Sect. 5.

2 Background

Our work is designed and evaluated on recommendations with implicit feedback.


Thus, in this section, we describe the well-known collaborative filtering approach,
Matrix Factorization, for implicit feedback problems.

2.1 Matrix Factorization

Matrix Factorization (MF) [2] is the most popular CF method, mainly due to its
simplicity and efficiency. The idea behind MF is to decompose the user-item matrix,
$R \in \mathbb{R}^{n\times m}$, into two lower dimensional matrices, $U \in \mathbb{R}^{n\times d}$ and $V \in \mathbb{R}^{m\times d}$, such
that the inner product of $U$ and $V$ will approximate the original matrix $R$, where $d$
is the dimension of the latent factors, such that $d \ll \min(n, m)$. $n$ and $m$ correspond
to the number of users and items in the system. Figure 1 illustrates the MF process.

$$R \approx U \cdot V^T \qquad (1)$$

MF optimizes the values of U and V by minimizing the sum of the squared


difference between the actual values and the model predictions, while adding two
regularization terms, as shown here:

$$L = \sum_{i,j \in R} \frac{I_{ij}}{2}\left(r_{ij} - u_i v_j^T\right)^2 + \frac{\lambda_u}{2}\|u_i\|^2 + \frac{\lambda_v}{2}\|v_j\|^2 \qquad (2)$$

where $I_{ij}$ is an indicator function that equals 1 if user $i$ has rated item $j$, and 0
otherwise. Also, $\|U\|$ and $\|V\|$ are the Euclidean norms, and $\lambda_u$, $\lambda_v$ are two regu-
larization terms preventing the values of $U$ and $V$ from being too large. This avoids
model overfitting.
Explicit data, such as ratings ($r_{ij}$), are not always available. Therefore, Weighted
Regularized Matrix Factorization (WRMF) [9] introduces two modifications to the
previous objective function to make it work for implicit feedback. The optimization

Fig. 1 Matrix factorization illustration

process in this case runs through all user-item pairs with different confidence levels
assigned to each pair, as in the following:

$$L = \sum_{i,j \in R} \frac{c_{ij}}{2}\left(p_{ij} - u_i v_j^T\right)^2 + \frac{\lambda_u}{2}\|u_i\|^2 + \frac{\lambda_v}{2}\|v_j\|^2 \qquad (3)$$

where $p_{ij}$ is the user preference score with a value of 1 when user $i$ and item $j$ have
an interaction, and 0 otherwise. $c_{ij}$ is a confidence variable whose value shows
how confident we are that the user likes the item. In general, $c_{ij} = a$ when $p_{ij} = 1$, and $c_{ij} = b$
when $p_{ij} = 0$, such that $a > b > 0$.
Stochastic Gradient Descent (SGD) [10] and Alternating Least Squares (ALS) [11]
are two optimization methods that can be used to minimize the objective function
of MF in Eq. 2. The first method, SGD, loops over each single training sample and
then computes the prediction error as $e_{ij} = r_{ij} - u_i v_j^T$. The gradient of the objective
function with respect to $u_i$ and $v_j$ can be computed as follows:

$$\frac{\partial L}{\partial u_i} = -\sum_j I_{ij}\left(r_{ij} - u_i v_j^T\right)v_j + \lambda_u u_i \qquad (4)$$
$$\frac{\partial L}{\partial v_j} = -\sum_i I_{ij}\left(r_{ij} - u_i v_j^T\right)u_i + \lambda_v v_j$$

After calculating the gradient, SGD updates the user and item latent factors in the
opposite direction of the gradient using the following equations:
$$u_i \leftarrow u_i + \alpha\left(\sum_j I_{ij}\, e_{ij}\, v_j - \lambda_u u_i\right) \qquad (5)$$
$$v_j \leftarrow v_j + \alpha\left(\sum_i I_{ij}\, e_{ij}\, u_i - \lambda_v v_j\right)$$

where α is the learning rate.
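To make the update rules concrete, here is a minimal NumPy sketch (not taken from the chapter) of the per-sample form of the SGD updates in Eq. 5; the function name, latent dimension, and hyperparameter values are illustrative assumptions.

```python
import numpy as np

def sgd_epoch(ratings, U, V, alpha=0.01, lam_u=0.1, lam_v=0.1):
    """One SGD pass over the observed (i, j, r_ij) training triples.

    U : (n, d) user latent factors; V : (m, d) item latent factors.
    Only observed ratings are visited, i.e., pairs with I_ij = 1.
    """
    for i, j, r_ij in ratings:
        e_ij = r_ij - U[i] @ V[j]                      # prediction error e_ij
        U[i] += alpha * (e_ij * V[j] - lam_u * U[i])   # user update (Eq. 5)
        V[j] += alpha * (e_ij * U[i] - lam_v * V[j])   # item update (Eq. 5)
    return U, V
```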


Even though SGD is easy to implement and faster than ALS in some
cases, it is not suitable for implicit feedback, since looping over each single
training sample is not practical. ALS works better in this case. ALS iteratively opti-
mizes U while V is fixed, and vice versa. This optimization process is repeated until
the model converges.
To determine what user and item vector values minimize the objective function
for implicit data (Eq. 3), we first take the derivative of L with respect to u i .

∂L 
=− ci j ( pi j − u i v Tj )v j + λu u i
∂u i j

0 = −Ci (Pi − u i V T )V + λu u i
0 = −Ci V Pi + Ci V u i V T + λu u i (6)
V Ci Pi = u i V Ci V + λu u i
T

V Ci Pi = u i (V Ci V T + λu I )
ui = V Ci Pi (V Ci V T + λu I )−1
ui = (V Ci V T + λu I )−1 V Ci Pi

where I is the identity matrix.


Similarly, taking the derivative of $L$ with respect to $v_j$ leads to

$$v_j = \left(U C_j U^T + \lambda_v I\right)^{-1} U C_j P_j \qquad (7)$$
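As an illustration only, the following NumPy sketch implements one alternating pass of the closed-form updates in Eqs. 6 and 7 for implicit feedback; because $U$ and $V$ are stored here with rows as latent vectors, the products appear as $V^T C_i V$ rather than $V C_i V^T$, and all names and hyperparameter values are assumptions.

```python
import numpy as np

def als_pass(P, C, U, V, lam_u=10.0, lam_v=0.1):
    """One ALS pass: update all user factors with items fixed, then all item factors.

    P : (n, m) binary preference matrix; C : (n, m) confidence weights;
    U : (n, d) user factors; V : (m, d) item factors (rows are latent vectors).
    """
    n, m = P.shape
    d = U.shape[1]
    I_d = np.eye(d)
    for i in range(n):                       # Eq. 6: update u_i with V fixed
        Ci = np.diag(C[i])
        A = V.T @ Ci @ V + lam_u * I_d
        b = V.T @ Ci @ P[i]
        U[i] = np.linalg.solve(A, b)
    for j in range(m):                       # Eq. 7: update v_j with U fixed
        Cj = np.diag(C[:, j])
        A = U.T @ Cj @ U + lam_v * I_d
        b = U.T @ Cj @ P[:, j]
        V[j] = np.linalg.solve(A, b)
    return U, V
```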

3 Proposed Model

In this section, we illustrate our proposed model in depth. The intuition behind our
model is to learn the latent factors of items in PMF with the use of available side
textual contents. We use an attentive unsupervised model to capture richer
information from the available data. The architecture of our model is displayed in
Fig. 2. We first define the problem with implicit feedback before we go through the
details of our model.
Fig. 2 Collaborative attentive autoencoder architecture

3.1 Problem Definition

User-item interaction data is the primary source for training recommendation


engines. This data can be either collected in an explicit or implicit manner. In explicit
data, users directly express their opinion about an item using the rating system to
show how much they like that item. The user’s ratings usually vary from one-star to
five-stars with five being very interested and one being not interested. This type of
data is very useful and reliable due to the fact that it represents the actual feeling of
users about items. However, users’ ratings occasionally are not available due to the
difficulty of obtaining users’ explicit opinions. In this case, implicit feedback can be
obtained indirectly from the user’s behavior such as user clicks, bookmarks, or the
time spent viewing an item. For instance, if a user listens to a song 10 times in the
last two days, he or she most likely likes this song. Thus, implicit data is more preva-
lent and easier to collect, but it is generally less reliable than explicit data. Also, all
the observed interactions in implicit data constitute positive feedback, but negative
feedback is missing. This problem is also defined as the one-class problem.
There are multiple previous works aiming to deal with the one-class problem. A
simple solution is to treat all missing data as negative feedback. However, this is
not true because a missing (unobserved) interaction could be positive if the user were
aware of the item's existence. Therefore, using this strategy to build a model might result
in a misleading model due to faulty assumptions at the outset. On the contrary, if

we treat all missing data as unobserved data without considering including negative
feedback in the model training, the corresponding trained model is probably useless
since it is only trained on positive data. As a result, sampling negative feedback
from positive feedback is one practical solution for this problem, which has been
proposed by [12]. In addition, Weighted Regularized Matrix Factorization (WRMF)
[9] is another proposed solution that introduces a confidence variable that works as
a weight to measure how likely a user is to like an item.
In general, the recommendation problem with implicit data is usually formulated
as follows:
$$R_{nm} = \begin{cases} 1, & \text{if there is a user-item interaction} \\ 0, & \text{otherwise} \end{cases} \qquad (8)$$

where the ones in implicit feedback represent all the positive feedback. However,
it is important to note that a value of 0 does not always imply negative feedback.
It may be that users are not aware of the existence of those items. In addition,
the user-item interaction matrix (R) is usually highly imbalanced, such that the
number of the observed interactions is much less than the number of the unobserved
interactions. In other words, matrix R is very sparse, meaning that users only interact
explicitly or implicitly with a very small number of items compared to the total
number of items in this matrix. Sparsity is one frequent problem in RSs, which brings
a real challenge for any proposed model to have the capability to provide effective
personalized recommendations under this situation. The following sections explain
our methodology, where we aim to eliminate the influence of the aforementioned
problems.
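As a small, self-contained sketch of this formulation (with assumed variable names, and dense matrices used only for clarity), the preference matrix of Eq. 8 and the confidence weights $c_{ij}$ could be built from a list of observed user-article pairs as follows; the values $a = 1$ and $b = 0.01$ mirror the settings reported later in the experiments.

```python
import numpy as np

def build_implicit_matrices(interactions, n_users, n_items, a=1.0, b=0.01):
    """interactions : iterable of (user_id, item_id) pairs of implicit feedback."""
    P = np.zeros((n_users, n_items))
    for u, i in interactions:
        P[u, i] = 1.0                     # observed interaction -> positive feedback (Eq. 8)
    C = np.where(P == 1.0, a, b)          # confidence: a for observed pairs, b otherwise
    return P, C
```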

3.2 The Attentive Autoencoder

Autoencoder [13] is an unsupervised learning neural network that is useful for com-
pressing high-dimensional input data into a lower dimensional representation while
preserving the abstract nature of the data. The autoencoder network is generally
composed of two main components, i.e., the encoder and the decoder. The encoder
takes the input and encodes it through multiple hidden layers and then generates a
compressed representative vector, $Z_j$. The encoding function can be formulated as
$Z_j = f(X_j)$. Subsequently, the decoder can then be used to reconstruct and estimate
the original input, $\hat{X}_j$, using the representative vector, $Z_j$. The decoder function can
be formulated as $\hat{X}_j = f(Z_j)$. The encoder and the decoder usually each consist of
the same number of hidden layers and neurons. The output of each hidden layer is
computed as follows:

$$h^{(\ell)} = \sigma\left(h^{(\ell-1)} W^{(\ell)} + b^{(\ell)}\right) \qquad (9)$$



where $\ell$ is the layer number, $W$ is the weights matrix, $b$ is the bias vector, and $\sigma$
is a non-linear activation function. We use the Rectified Linear Unit (ReLU) as the
activation function.
Our model takes input from the article's textual data, $X_j = \{x_1, x_2, \ldots, x_s\}$,
where $x_i$ is a value in $[0, 1]$ and $s$ represents the vocabulary size of the arti-
cles' titles and abstracts. In other words, the input of our autoencoder network is
a normalized bag-of-words histogram of filtered vocabularies of the articles' titles
and abstracts.
Batch Normalization (BN) [14] has been proven to be a proper solution for the
internal covariate shift problem, where the layer's input distribution in deep neural
networks changes across the time of training, and causes difficulty in training the
model. In addition, BN can work as a regularization procedure like Dropout [15]
in deep neural networks. Accordingly, we apply a batch normalization layer after
each hidden layer in our autoencoder to obtain a stable distribution from each layer’s
output.
Furthermore, we use the idea of the attention mechanism to work between the
encoder and the decoder, such that only the relevant parts of the encoder output are
selected for the input reconstruction. Attention in deep learning can be described
simply as a vector of weights to show the importance of the input elements. Thus,
the intuition behind attention is that not all parts of the input are equally significant,
i.e., only a few parts are significant for the model. We first calculate the scores as the
probability distribution of the encoder's output using the $\mathrm{softmax}(\cdot)$ function.

$$f(z_c) = \frac{e^{z_c}}{\sum_d e^{z_d}} \qquad (10)$$

The probability distribution and the encoder output are then multiplied using an
element-wise multiplication function to get $Z_j$.
We use the attentive autoencoder to pretrain the items' contextual information
and then integrate the compressed representation, $Z_j$, in computing the items' latent
factors, $V_j$, from the matrix factorization method. The dimension space of $Z_j$ and $V_j$
are set to be equal to each other. Finally, we adopt the binary cross-entropy (Eq. 11)
as the loss function we want to minimize in our attentive autoencoder model.

$$L = -\sum_k \left[ y_k \log(p_k) + (1 - y_k)\log(1 - p_k) \right] \qquad (11)$$

where $y_k$ corresponds to the correct labels and $p_k$ corresponds to the predicted
values.
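For illustration, the sketch below builds an attentive autoencoder of this kind with the Keras functional API: ReLU hidden layers followed by batch normalization (Eq. 9), a softmax attention vector (Eq. 10) multiplied element-wise with the encoder output to obtain $Z_j$, and a sigmoid reconstruction trained with binary cross-entropy (Eq. 11). The layer sizes follow the 400-200-100-50 shape reported in the experiments; the optimizer and other details are assumptions, not the authors' exact implementation.

```python
import tensorflow as tf
from tensorflow.keras import layers, Model

def build_attentive_autoencoder(vocab_size, latent_dim=50):
    x_in = layers.Input(shape=(vocab_size,), name="bow_input")
    h = x_in
    # Encoder: Dense + BatchNormalization after every hidden layer (Eq. 9)
    for units in (400, 200, 100, latent_dim):
        h = layers.Dense(units, activation="relu")(h)
        h = layers.BatchNormalization()(h)
    # Attention: softmax over the encoded vector (Eq. 10), applied element-wise
    attn = layers.Softmax(name="attention_weights")(h)
    z = layers.Multiply(name="z")([h, attn])          # compressed representation Z_j
    # Decoder mirrors the encoder and reconstructs the normalized bag-of-words input
    d = z
    for units in (100, 200, 400):
        d = layers.Dense(units, activation="relu")(d)
        d = layers.BatchNormalization()(d)
    x_hat = layers.Dense(vocab_size, activation="sigmoid", name="x_hat")(d)
    autoencoder = Model(x_in, x_hat)
    encoder = Model(x_in, z)                          # theta(X_j) = Z_j for the PMF part
    autoencoder.compile(optimizer="adam", loss="binary_crossentropy")  # Eq. 11
    return autoencoder, encoder
```

After pretraining, calling the encoder on the bag-of-words inputs yields the $Z_j$ vectors that are fed into the matrix factorization step described next.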

3.3 Probabilistic Matrix Factorization

Probabilistic Matrix Factorization (PMF) [3] is a probabilistic linear model where


the prior distributions of the latent factors and users’ preferences are drawn from
Gaussian normal distribution.

$$u_i \sim \mathcal{N}(0, \lambda_u^{-1} I)$$
$$v_j \sim \mathcal{N}(0, \lambda_v^{-1} I) \qquad (12)$$
$$p_{ij} \sim \mathcal{N}(u_i v_j^T, \sigma^2)$$

We integrate the items’ contents, trained through the attentive autoencoder, into
PMF. Therefore, the objective function in Eq. 3 has been changed slightly to become

$$L = \sum_{i,j \in R} \frac{c_{ij}}{2}\left(p_{ij} - u_i v_j^T\right)^2 + \frac{\lambda_u}{2}\|u_i\|^2 + \frac{\lambda_v}{2}\|v_j - \theta(X_j)\|^2 \qquad (13)$$

where $\theta(X_j) = \mathrm{Encoder}(X_j) = Z_j$.
Thus, taking the partial derivative of our previous objective function with respect
to both $u_i$ and $v_j$ results in the following equations that minimize our objective
function the most

$$u_i = \left(V C_i V^T + \lambda_u I\right)^{-1} V C_i P_i \qquad (14)$$
$$v_j = \left(U C_j U^T + \lambda_v I\right)^{-1} \left(U C_j P_j + \lambda_v \theta(X_j)\right)$$

We optimize the values of $u_i$ and $v_j$ using the Alternating Least Squares (ALS)
optimization method.
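A minimal NumPy sketch of the modified item update in Eq. 14 is given below (again with rows as latent vectors, so the transposes are rearranged accordingly, and all names are assumptions); the only difference from the plain WRMF update is the $\lambda_v \theta(X_j)$ term that pulls $v_j$ toward the encoder output $Z_j$.

```python
import numpy as np

def update_item_factors(P, C, U, V, Z, lam_v=0.1):
    """U : (n, d) user factors; V : (m, d) item factors; Z : (m, d) encoder outputs Z_j."""
    n, m = P.shape
    d = U.shape[1]
    I_d = np.eye(d)
    for j in range(m):
        Cj = np.diag(C[:, j])
        A = U.T @ Cj @ U + lam_v * I_d
        b = U.T @ Cj @ P[:, j] + lam_v * Z[j]   # content prior term from Eq. 14
        V[j] = np.linalg.solve(A, b)
    return V
```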

3.4 Prediction

After our model has been trained and the latent factors of users and articles, $U$ and
$V$, are identified, we calculate our model's prediction scores of user $i$ and each article
as the dot product of vector $u_i$ with all vectors in $V$, i.e., $scores_i = u_i V^T$. Then, we
sort all articles based on our model prediction scores in descending order, and then
recommend the top-K articles for that user $i$. We go through all users in $U$ in our
evaluation and report the average performance among all users. The overall process
of our approach is illustrated in Algorithm 1.

Algorithm 1: CATA algorithm


1  pretrain autoencoder with input X;
2  Z ← θ(X);
3  U, V ← initialize with random values;
4  while <NOT converged> do
5      for <each user i> do
6          u_i ← update using Equation 14;
7      end for
8      for <each article j> do
9          v_j ← update using Equation 14;
10     end for
11 end while
12 for <each user i> do
13     scores_i ← u_i V^T;
14     sort(scores_i) in descending order;
15 end for
16 Evaluate the top-K recommendations;
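Lines 12-16 of Algorithm 1 correspond to a simple scoring-and-ranking step; a possible NumPy sketch is shown below, where excluding a user's training articles from the ranked list is an assumption made for evaluation purposes.

```python
import numpy as np

def recommend_top_k(U, V, train_items, k=10):
    """train_items[i] is the set of articles already in user i's training library."""
    scores = U @ V.T                                   # scores_i = u_i V^T for every user
    recommendations = []
    for i in range(U.shape[0]):
        ranked = np.argsort(-scores[i])                # sort articles in descending order
        ranked = [j for j in ranked if j not in train_items[i]]
        recommendations.append(ranked[:k])             # top-K articles for user i
    return recommendations
```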

4 Experiments

In this section, we conduct extensive experiments aiming to answer the following


research questions:
• RQ1: How does our proposed model, CATA, perform against state-of-the-art meth-
ods?
• RQ2: Does adding the attention mechanism actually improve our model perfor-
mance?
• RQ3: How could different values of the regularization parameters (λu and λv )
affect CATA performance?
• RQ4: What is the impact of different dimension values of users and items’ latent
space on CATA performance?
• RQ5: How many training epochs are sufficient for pretraining our autoencoder?
Before answering these research questions, we first describe the datasets used in our
evaluations, the evaluation metrics, and the baseline approaches we use to evaluate
our model against.

4.1 Datasets

Three scientific article datasets are used to evaluate our model against the state-of-
the-art methods. All datasets are collected from CiteULike website. The first dataset
is called Citeulike-a, which is collected by [5]. It has 5,551 users, 16,980 articles, and
204,986 user-article pairs. The sparseness of this dataset is extremely high, where
only around 0.22% of the user-article matrix has interactions. Each user has at least

ten articles in his or her library. On average, each user has 37 articles in his or her
library and each article has been added to 12 users’ libraries. The second dataset is
called Citeulike-t, which is collected by [6]. It has 7,947 users, 25,975 articles, and
134,860 user-article pairs. This dataset is actually sparser than the first one with only
0.07% available user-article interactions. Each user has at least three articles in his
or her library. On average, each user has 17 articles in his or her library and each
article has been added to five users’ libraries. Lastly, Citeulike-2004–2007 is the third
dataset, and it is collected by [16]. It is three times bigger than the previous ones with
regard to the user-article matrix. It has 3,039 users, 210,137 articles, and 284,960
user-article pairs. This dataset is the sparsest in this experiment, with a sparsity equal
to 99.95%. Each user has at least ten articles in his or her library. On average, each
user has 94 articles in his or her library and each article has been added only to one
user library. Brief statistics of the datasets are shown in Table 1.
Title and abstract of each article are given in each dataset. The average number
of words per article in both title and abstract after our text preprocessing is 67 words
in Citeulike-a, 19 words in Citeulike-t, and 55 words in Citeulike-2004–2007. We
follow the same preprocessing techniques as the state-of-the-art models in [5, 7,
8]. A five-stage procedure to preprocess the textual content is displayed in Fig. 3.
Each article title and abstract are combined together and then are preprocessed such
that stop words are removed. After that, top-N distinct words based on the TF-IDF
measurement are picked out. 8,000 distinct words are selected for the Citeulike-a
dataset, 20,000 distinct words are selected for the Citeulike-t dataset, and 19,871
distinct words are selected for the Citeulike-2004–2007 dataset to form the bag-of-
words histogram, which are then normalized into values between 0 and 1 based on
the vocabularies’ occurrences.
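A possible scikit-learn implementation of this five-stage pipeline is sketched below; the exact tokenization and TF-IDF settings used by the authors are not specified, so the choices here (English stop-word list, ranking words by their total TF-IDF weight, normalizing each histogram by its maximum count) are assumptions.

```python
import numpy as np
from sklearn.feature_extraction.text import TfidfVectorizer, CountVectorizer

def preprocess_articles(titles, abstracts, top_n=8000):
    # Stage 1: combine each article's title and abstract
    texts = [t + " " + a for t, a in zip(titles, abstracts)]
    # Stages 2-3: remove stop words and score every distinct word by TF-IDF
    tfidf = TfidfVectorizer(stop_words="english")
    weights = tfidf.fit_transform(texts)
    totals = np.asarray(weights.sum(axis=0)).ravel()
    terms = np.array(tfidf.get_feature_names_out())
    vocab = terms[np.argsort(-totals)[:top_n]]          # top-N words by TF-IDF weight
    # Stage 4: bag-of-words histogram over the selected vocabulary
    counts = CountVectorizer(vocabulary=vocab).fit_transform(texts).toarray().astype(float)
    # Stage 5: normalize every histogram into [0, 1]
    X = counts / np.maximum(counts.max(axis=1, keepdims=True), 1.0)
    return X, vocab
```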

Table 1 Descriptions of citeulike datasets


Dataset #Users #Articles #Pairs Sparsity (%)
Citeulike-a 5,551 16,980 204,986 99.78
Citeulike-t 7,947 25,975 134,860 99.93
Citeulike-2004–2007 3,039 210,137 284,960 99.95

Fig. 3 A five-stage procedure for preprocessing articles’ titles and abstracts



Fig. 4 Ratio of articles that have been added to ≤N users’ libraries

Figure 4 shows the ratio of articles that have been added to five or fewer users’
libraries. For example, 15, 77, and 99% of the articles in Citeulike-a, Citeulike-t, and
Citeulike-2004–2007, respectively, are added to five or fewer users’ libraries. Also,
only 1% of the articles in Citeulike-a have been added only to one user library, while
the rest of the articles have been added to more than this number. On the contrary,
13, and 77% of the articles in Citeulike-t and Citeulike-2004–2007 have been added
only to one user library. This illustrates the increasing sparseness of the data with regard to articles
as we go from one dataset to another.

4.2 Evaluation Methodology

We follow the state-of-the-art techniques [6–8] to generate our training and testing
sets. For each dataset, we create two versions of the dataset for sparse and dense
settings. In total, six dataset cases are used in our evaluation. To form the sparse
(P = 1) and the dense (P = 10) datasets, P items are randomly selected from each
user library to generate the training set while the remaining items from each user
library are used to generate the testing set. As a result, when P = 1, only 2.7, 5.9,
and 1.1% of the data entries are used to generate the training set in Citeulike-a,
Citeulike-t, and Citeulike-2004–2007, respectively. Similarly, 27.1, 39.6, and 10.7%
of the data entries are used to generate the training set when P = 10 as Fig. 5 shows.

Fig. 5 The percentage of the data entries that forms the training and testing sets in all citeulike
datasets: a Citeulike-a, b Citeulike-t, c Citeulike-2004–2007

We use recall and Discounted Cumulative Gain (DCG) as our evaluation metrics
to test how our model performs. Recall is usually used to evaluate recommender
systems with implicit feedback. However, precision is not favorable to use with
implicit feedback because the zero value in the user-article interaction matrix has
two meanings: either the user is not interested in the article, or the user is not aware
of the existence of this article. Therefore, using the precision metric only assumes
that for each zero value the user is not interested in the article, which is not the case.
Recall per user can be measured using the following formula:

$$\mathrm{recall@K} = \frac{\text{Relevant Articles} \cap \text{K Recommended Articles}}{\text{Relevant Articles}} \qquad (15)$$
where K is set manually in the experiment and represents the top K articles of each
user. We set K = 10, 50, 100, 150, 200, 250, and 300 in our evaluations. The overall
recall can be calculated as the average recall among all users. If K equals the number
of articles in the dataset, recall will have a value of 1.
Recall, however, does not take into account the ranking of articles within the
top-K recommendations, as long as they are in the top-K list. However, DCG does.
DCG shows the capability of the recommendation engine to recommend articles at
the top of the ranking list. Articles in higher ranked K positions have more value than
others. The DCG among all users can be measured using the following equation:
$$\mathrm{DCG@K} = \frac{1}{|U|}\sum_{u=1}^{|U|}\sum_{i=1}^{K}\frac{rel(i)}{\log_2(i+1)} \qquad (16)$$

where $|U|$ is the total number of users, $i$ is the rank of the top-K articles recommended
by the model, and $rel(i)$ is an indicator function that outputs 1 if the article at rank $i$
is a relevant article, and 0 otherwise.
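For clarity, a per-user sketch of these two metrics is given below; `ranked` is the model's ranked article list for a user and `relevant` is the set of held-out articles in that user's test library (variable names are assumptions).

```python
import numpy as np

def recall_at_k(ranked, relevant, k):
    hits = sum(1 for article in ranked[:k] if article in relevant)
    return hits / len(relevant)                         # Eq. 15

def dcg_at_k(ranked, relevant, k):
    # rel(i) = 1 when the article at rank i is relevant; i is 0-based here, hence log2(i + 2)
    return sum(1.0 / np.log2(i + 2)
               for i, article in enumerate(ranked[:k]) if article in relevant)  # Eq. 16
```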

4.3 Baselines

We evaluate our approach against the following baselines described below:


• POP: Popular predictor is a non-personalized recommender system. It recom-
mends the most popular articles in the training set, such that all users get identical
recommendations. It is widely used as the baseline for personalized recommen-
dation models.
• CDL: Collaborative Deep Learning (CDL) [7] is a deep Bayesian model that
jointly models both user-item interaction data and items’ content via a Stacked
Denoising Autoencoder (SDAE) with a Probabilistic Matrix Factorization (PMF).
• CML+F: Collaborative Metric Learning (CML) [17] is a metric learning model
that pulls items liked by a user closer to that user. Recommendations are then
generated based on the K-Nearest Neighbor of each user. CML+F additionally
uses a neural network with two fully connected layers to train items’ features
(articles’ tags in this chapter) to update items’ location. CML+F has been shown
to have a better performance than CML.
• CVAE: Collaborative Variational Autoencoder (CVAE) [8] is a probabilistic
model that jointly models both user-item interaction data and items’ content
through a Variational Autoencoder (VAE) with a Probabilistic Matrix Factoriza-
tion (PMF). It can be considered as the baseline of our proposed approach since
CVAE and CATA share the same strategy.
For hyper-parameter settings, we set the confidence variables (i.e., a and b) to
a = 1, and b = 0.01. These are the same values used in CDL and CVAE as well.
Also, a four-layer network is used to construct our attentive autoencoder. The four-
layer network has the following shape “#Vocabularies-400-200-100-50-100-200-
400-#Vocabularies”.

4.4 Experimental Results

For each dataset, we repeat the data splitting four times with different random splits
of training and testing set, which has been previously described in the evaluation
methodology section. We use one split as a validation experiment to find the optimal
parameters of λu and λv for our model and the state-of-the-art models as well. We
search a grid of the following values {0.01, 0.1, 1, 10, 100} and the best values on
the validation experiment have been reported in Table 2. The other three splits are
used to report the average performance of our model against the baselines. In this
section, we address the research questions that have been previously defined in the
beginning of this section.

Table 2 Parameter settings for λu and λv based on the validation experiment


Approach   Citeulike-a         Citeulike-t         Citeulike-2004–2007
           Sparse     Dense    Sparse     Dense    Sparse     Dense
           λu    λv   λu   λv  λu    λv   λu   λv  λu    λv   λu   λv
CDL        0.01  10   0.01 10  0.01  10   0.01 10  0.01  10   0.01 10
CVAE       0.1   10   1    10  0.1   10   0.1  10  0.1   10   0.1  10
CATA       10    0.1  10   0.1 10    0.1  10   0.1 10    0.1  10   0.1

4.4.1 RQ1

To evaluate how our model performs, we conduct quantitative and qualitative com-
parisons to answer this question. Figures 6, 7, 8, and 9 show the performance of the
top-K recommendations under the sparse and dense settings in terms of recall and
DCG. First, the sparse cases are very challenging for any proposed model since there
is less data for training. In the sparse setting where there is only one article in each
user’s library in the training set, our model, CATA, outperforms the baselines in all
datasets in terms of recall and DCG, as Figs. 6 and 7 show. More importantly, CATA
outperforms the baselines by a wide margin in the Citeulike-2004–2007 dataset,
where it is actually sparser and contains a huge number of articles. This validates the
robustness of our model against data sparsity.

Fig. 6 Recall performance under the sparse setting, P = 1: a Citeulike-a, b Citeulike-t, c Citeulike-2004–2007

Fig. 7 DCG performance under the sparse setting, P = 1: a Citeulike-a, b Citeulike-t, c Citeulike-2004–2007



Fig. 8 Recall performance under the dense setting, P = 10: a Citeulike-a, b Citeulike-t, c Citeulike-2004–2007

Fig. 9 DCG performance under the dense setting, P = 10: a Citeulike-a, b Citeulike-t, c Citeulike-2004–2007

Second, with the dense setting where there are more articles in each user’s library
in the training set, our model performs comparably to other baselines in Citeulike-a
and Citeulike-t datasets as Figs. 8 and 9 show. As a matter of fact, many of the existing
models actually work well under this setting, but poorly under the sparse setting. For
example, CML+F achieves a competitive performance on the dense data; however, it
fails on the sparse data since its metric space needs more interactions for users to
capture their preferences. On the other hand, CATA outperforms the other baselines
under this setting in the Citeulike-2004–2007 dataset. As a result, this experiment
demonstrates the capability of our model for making more relevant recommendations
under both sparse and dense data conditions.
In addition to the previous quantitative comparisons, some qualitative results are
reported in Table 3 as well. The table shows the top ten recommendations gener-
ated by our model (CATA) and the state-of-the-art model (CVAE) for one randomly
selected user under the sparse setting using the Citeulike-2004–2007 dataset. In this
example, user 20 has only one article in his training library, entitled “Assessment of
Attention Deficit/ Hyperactivity Disorder in Adult Alcoholics”. From this example,
this user seems to be interested in the treatment of Attention-Deficit/Hyperactivity
Disorder (ADHD) among alcohol- and drug-using populations. Comparing the rec-
ommendation results between the two models, our model recommends more relevant
articles based on the user’s interests. For instance, most of the recommended articles
using the CATA model are related to the same topic, i.e., alcohol- and drug-users
with ADHD. However, there are some irrelevant articles recommended by CVAE,

Table 3 The top-10 recommendations for one selected random user under the sparse setting, P = 1,
using the citeulike-2004–2007 dataset
User ID: 20
Articles in the training set: assessment of attention deficit/ hyperactivity disorder in adult
alcoholics
CATA In user
library?
1. A Double-blind, placebo-controlled withdrawal trial of dexmethylphenidate No
hydrochloride in children with ADHD
2. Double-blind placebo-controlled trial of methylphenidate in the treatment of Yes
adult ADHD patients with comorbid cocaine...
3. Methylphenidate treatment for cocaine abusers with adult Yes
attention-deficit/hyperactivity disorder: a pilot study
4. A controlled trial of methylphenidate in adults with attention Yes
deficit/hyperactivity disorder and substance use disorders
5. Treatment of cocaine dependent treatment seekers with adult ADHD: Yes
double-blind comparison of methylphenidate and...
6. A large, double-blind, randomized clinical trial of methylphenidate in the Yes
treatment of adults with ADHD
7. Patterns of inattentive and hyperactive symptomatology in cocaine-addicted Yes
and non-cocaine-addicted smokers diagnosed...
8. Frontocortical activity in children with comorbidity of tic disorder and No
attention-deficit hyperactivity disorder
9. Gender effects on attention-deficit/hyperactivity disorder in adults, revisited Yes
10. Association between dopamine transporter (DAT1) genotype, left-sided Yes
inattention, and an enhanced response to...
CVAE In user
library?
1. Psycho-social correlates of unwed mothers No
2. A randomized, controlled trial of integrated home-school behavioral treatment No
for ADHD, predominantly inattentive type
3. Age and gender differences in children’s and adolescents’ adaptation to sexual No
abuse
4. Distress in individuals facing predictive DNA testing for autosomal dominant No
late-onset disorders: Comparing questionnaire results...
5. Combined treatment with sertraline and liothyronine in major depression: a No
randomized, double-blind, placebo-controlled trial
6. An open-label pilot study of methylphenidate in the treatment of cocaine Yes
dependent patients with adult ADHD
7. Treatment of cocaine dependent treatment seekers with adult ADHD: Yes
double-blind comparison of methylphenidate and...
8. A Double-Blind, Placebo-Controlled Withdrawal Trial of Dexmethylphenidate No
Hydrochloride in Children with ADHD
9. Methylphenidate treatment for cocaine abusers with adult Yes
attention-deficit/hyperactivity disorder: a pilot study
10. A large, double-blind, randomized clinical trial of methylphenidate in the Yes
treatment of adults with ADHD

Table 4 Performance comparisons on sparse data with the attention layer (CATA) and without
it (CATA–). Best values are marked in bold
Approach Citeulike-a Citeulike-t
Recall@300 DCG@300 Recall@300 DCG@300
CATA– 0.3003 1.6644 0.2260 0.4661
CATA 0.3060 1.7206 0.2425 0.5160

such as recommended articles numbers 1, 3, and 4. From this example and other
users’ examples we have examined, we can state that our model detects the major
elements of articles’ contents and users’ preferences more accurately.

4.4.2 RQ2

To examine the importance of adding the attention layer into our autoencoder, we
create another variant of our model that has the same architecture, but lacks the atten-
tion layer, which we call CATA–. We evaluate this model on the sparse cases using
Citeulike-a and Citeulike-t datasets. The performance comparisons are reported in
Table 4. As the table shows, adding the attention mechanism boosts the performance.
The attention mechanism places more weight on the informative parts of the encoded
vocabulary of each article, better representing the contextual data and eventually
leading to increased recommendation quality.
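To make the role of the attention layer concrete, the following is a minimal Keras-style sketch of an attention-weighted encoding step. It is illustrative only: the layer sizes, the softmax scoring, and the function name attentive_encode are assumptions rather than the exact CATA implementation.

```python
import tensorflow as tf
from tensorflow.keras import layers

def attentive_encode(item_text, hidden_dim=200):
    """Hypothetical sketch: encode an article's bag-of-words vector and re-weight
    the encoded features with a softmax attention layer."""
    encoded = layers.Dense(hidden_dim, activation="relu")(item_text)  # encoder output
    scores = layers.Dense(hidden_dim)(encoded)                        # unnormalized attention scores
    weights = layers.Softmax()(scores)                                # attention weights sum to 1
    return layers.Multiply()([encoded, weights])                      # emphasize salient dimensions

# Usage sketch with an assumed 8,000-word vocabulary per article.
inputs = tf.keras.Input(shape=(8000,))
model = tf.keras.Model(inputs, attentive_encode(inputs))
```

In this sketch the softmax produces one weight per encoded dimension, so the dimensions that matter more for representing an article's content receive proportionally larger activations.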

4.4.3 RQ3

There are two regularization parameters (λu and λv ) that are used in the objective
function of the matrix factorization method to prevent the latent vectors’ magni-
tude from being too large, which eventually prevents the model from overfitting the
training data. Our previously reported results are obtained by setting λu and λv to
the numbers in Table 2 based on the validation experiment. However, we perform
multiple experiments to show the impact of different values of λu and λv and how
they affect our model's performance. We set both parameters using values from the
set {0.01, 0.1, 1, 10, 100}. Figure 10 visualizes how our model
performs under each combination of λu and λv . We find that our model has a lower
performance when the value of λv is considerably large under the dense setting, as
Figs. 10b, d show. On the other hand where the data is sparser in Figs. 10a, c, e, f, a
very small value of λu (e.g., 0.01) tends to have the lowest performance among all
other numbers. Even though Fig. 10f shows the performance under the dense setting
for the Citeulike-2004–2007 dataset, it still exemplifies the sparsity with regard to
articles, as we indicated before in Fig. 4, where 80% of the articles have been
added to only one user's library. Generally, we observe that optimal performance happens
in all datasets when λu = 10 and λv = 0.1. We can conclude that when there is suffi-

Fig. 10 The impact of λu and λv on CATA performance: a Citeulike-a, P=1; b Citeulike-a, P=10;
c Citeulike-t, P=1; d Citeulike-t, P=10; e Citeulike-2004–2007, P=1; f Citeulike-2004–2007, P=10

cient users' feedback, items' contextual information is no longer essential to obtain
users' preferences, and vice versa.

4.4.4 RQ4

The vectors of the latent features (U and V ) represent the characteristics of users and
items that a model tries to learn from data. We examine the impact of the size of
these vectors on the performance of our model. In other words, we examine how
many dimensions in the latent space can represent the user and item features more
accurately. It is worth mentioning that our reported results in the RQ1 section use 50
dimensions, which is similar to the size used by the state-of-the-art model (CVAE)
in order to have fair comparisons. However, we run our model again using five
dimension sizes from the following values {25, 50, 100, 200, 400}. Figure 11 shows
how our model performs in terms of recall@100 under each dimension size. We
observe that increasing the dimension size in dense data always leads to a gradual
increase in our model's performance, as shown in Fig. 11b. Larger dimension sizes
are recommended for sparse data as well; however, they do not necessarily improve the
model's performance all the time (e.g., the Citeulike-t dataset in Fig. 11a). Generally,
dimension sizes between 100 and 200 are suggested for the latent space.

Fig. 11 The performance of CATA model with respect to different dimension values of the latent
space under a sparse data (P=1) and b dense data (P=10)

4.4.5 RQ5

We first pretrain our autoencoder until the loss value converges sufficiently. The
loss value is the error computed by the autoencoder's loss function and indicates
how well the model reconstructs its inputs. Figure 12 visualizes the number of
training epochs needed to render the loss value sufficiently stable. We
find that 200 epochs are sufficient for pretraining our autoencoder.
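A minimal pretraining sketch is given below, assuming a bag-of-words item representation and a symmetric encoder/decoder. The vocabulary size, layer widths, and reconstruction loss are assumptions; only the 200 pretraining epochs follow the observation above.

```python
from tensorflow import keras
from tensorflow.keras import layers

input_dim = 8000                                   # assumed vocabulary size
inputs = keras.Input(shape=(input_dim,))
h = layers.Dense(400, activation="relu")(inputs)   # encoder
code = layers.Dense(200, activation="relu")(h)     # compressed representation
h = layers.Dense(400, activation="relu")(code)     # decoder
outputs = layers.Dense(input_dim, activation="sigmoid")(h)

autoencoder = keras.Model(inputs, outputs)
autoencoder.compile(optimizer="adam", loss="binary_crossentropy")
# X_items holds one bag-of-words row per article; 200 epochs per the observation above.
# autoencoder.fit(X_items, X_items, epochs=200, batch_size=128)
```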

Fig. 12 The reduction in the loss values versus the number of training epochs

5 Conclusion

In this chapter, we present a Collaborative Attentive Autoencoder (CATA) for rec-
ommending scientific articles. We utilize an article's textual data to learn a better
compressed representation of the data through the attention mechanism, which can
guide the training process to focus on the relevant part of the encoder output in order
to improve model predictions. CATA shows superiority over other state-of-the-art
methods on three scientific article datasets. The performance improvement of CATA
increases consistently as data sparsity increases. The qualitative results also reflect
the good quality of our model recommendations.
For potential future work, user data can be gathered and then used to update the
user latent factors in the same way as we update the item latent factors. Even though
user data is often not available due to privacy concerns (e.g., CiteULike datasets do
not have user data), we believe that item data, together with user-item interaction data,
can be used to infer user information. In addition, other variants of deep autoencoders
discussed in [18] could be investigated to replace the attentive autoencoder. Another
possible direction for future work is to explore new metric learning algorithms to
substitute the Matrix Factorization (MF) technique, because the dot product in MF
does not guarantee that items are placed correctly in the latent space with respect to
the triangle inequality between items.

References

1. M. Alfarhood, J. Cheng, Collaborative attentive autoencoder for scientific article recommen-
dation. in 2019 18th IEEE International Conference on Machine Learning and Applications
(ICMLA) (IEEE, 2019)
2. Y. Koren, R. Bell, C. Volinsky, Matrix factorization techniques for recommender systems.
Computer 8, 30–37 (2009)
3. A. Mnih, R. Salakhutdinov, Probabilistic matrix factorization. in Advances in Neural Informa-
tion Processing Systems (2008), pp. 1257–1264
4. R. Salakhutdinov, A. Mnih, G. Hinton, Restricted Boltzmann machines for collaborative filter-
ing. in Proceedings of the 24th International Conference on Machine Learning (ACM, 2007),
pp. 791–798
5. C. Wang, D. Blei, Collaborative topic modeling for recommending scientific articles. in Pro-
ceedings of the 17th ACM SIGKDD International Conference on Knowledge Discovery and
Data Mining (ACM, 2011), pp. 448–456
6. H. Wang, B. Chen, W. Li, Collaborative topic regression with social regularization for tag
recommendation. IJCAI (2013)
7. H. Wang, N. Wang, D. Yeung, Collaborative deep learning for recommender systems. in Pro-
ceedings of the 21th ACM SIGKDD International Conference on Knowledge Discovery and
Data Mining (ACM, 2015), pp. 1235–1244
8. X. Li, J. She, Collaborative variational autoencoder for recommender systems. in Proceedings
of the 23rd ACM SIGKDD International Conference on Knowledge Discovery and Data Mining
(ACM, 2017)
9. Y. Hu, Y. Koren, C. Volinsky, Collaborative filtering for implicit feedback datasets. in Proceed-
ings of the 8th IEEE International Conference on Data Mining (ICDM) (IEEE, 2008)

10. S. Funk, Netflix update: try this at home (2006). https://sifter.org/simon/journal/20061211.html. Accessed 13 Nov 2019
11. Y. Zhou, D. Wilkinson, R. Schreiber, R. Pan, Large-scale parallel collaborative filtering for the
netflix prize. in Proceedings of the International Conference on Algorithmic Applications in
Management (Springer, 2008)
12. R. Pan, Y. Zhou, B. Cao, N. Liu, R. Lukose, M. Scholz, Q. Yang, One-class collaborative
filtering. in Eighth IEEE International Conference on Data Mining, 2008. ICDM’08 (IEEE,
2008)
13. G. Hinton, R. Salakhutdinov, Reducing the dimensionality of data with neural networks. Sci-
ence 313(5786), 504–507 (2006)
14. S. Ioffe, C. Szegedy, Batch normalization: accelerating deep network training by reducing
internal covariate shift (2015). arXiv preprint arXiv:1502.03167
15. N. Srivastava, G. Hinton, A. Krizhevsky, I. Sutskever, R. Salakhutdinov, Dropout: a simple way
to prevent neural networks from overfitting. J. Mach. Learn. Res. (JMLR) 15(1), 1929–1958
(2014)
16. A. Alzogbi, Time-aware collaborative topic regression: towards higher relevance in textual
item recommendation. BIRNDL@ SIGIR (2018)
17. C. Hsieh, L. Yang, Y. Cui, T. Lin, S. Belongie, D. Estrin, Collaborative metric learning. in
Proceedings of the 26th International Conference on World Wide Web. International World
Wide Web Conferences Steering Committee (2017)
18. M. Wani, F. Bhat, S. Afzal, A. Khan, Advances in Deep Learning (Springer, 2020)
A Comprehensive Set of Novel Residual
Blocks for Deep Learning Architectures
for Diagnosis of Retinal Diseases from
Optical Coherence Tomography Images

Sharif Amit Kamran, Sourajit Saha, Ali Shihab Sabbir, and Alireza Tavakkoli

Abstract Spectral Domain Optical Coherence Tomography (SD-OCT) is a demand-
ing imaging technique by which diagnosticians detect retinal diseases. Automating
the procedure for early detection and diagnosis of retinal diseases has been proposed
in many intricate ways through the use of image processing, machine learning, and
deep learning algorithms. Unfortunately, the traditional methods are erroneous in
nature and quite expensive as they require additional participation from the human
diagnosticians. In this chapter, we propose a comprehensive set of novel blocks for
building a deep learning architecture to effectively differentiate between different
pathologies causing retinal degeneration. We further show how integrating these
novel blocks within a novel network architecture gives a better classification accu-
racy for these diseases and addresses the preexisting problems with gradient explosion
in the deep residual architectures. The technique proposed in this chapter achieves
better accuracy compared to the state of the art for two separately hosted Retinal
OCT image data-sets. Furthermore, we illustrate a real-time prediction system that,
by exploiting this deep residual architecture consisting of one of these novel blocks,
outperforms expert ophthalmologists.

S. A. Kamran (B) · A. Tavakkoli


University of Nevada, Reno, NV, USA
e-mail: skamran@nevada.unr.edu
A. Tavakkoli
e-mail: tavakkol@unr.edu
S. Saha · A. S. Sabbir
Center for Cognitive Skill Enhancement, Independent University Bangladesh,
Dhaka, Bangladesh
e-mail: sourajit@iub.edu.bd
A. S. Sabbir
e-mail: asabbir@iub.edu.bd


1 Introduction

Diabetes, one of the most pressing global health concerns, affects up to 7.2% of the
world population, and the number of diabetic patients can potentially mount up to 600
million by the year 2040 [6, 38]. With the pervasiveness of diabetes, statistically one
out of every three diabetic patients develops Diabetic Retinopathy (DR) [33]. Diabetic
Retinopathy causes severe damage to human vision, which in turn engenders vision loss
that affects nearly 2.8% of the world population [3]. In developed countries there exist
efficacious vision tests for DR screening and early treatment; yet a by-product of such
systems is fallacious results. Moreover, identifying the false positives and false neg-
atives still remains a challenging impediment for diagnosticians. On the other hand,
DR is often mistreated in many developing and poorer economies, where access to
trained ophthalmologists and eye-care machinery may be insufficient. Therefore, it is
imperative to shift DR diagnostic technology to a system that is autonomous, in the
pursuit of more accurate and faster test results on DR and other related retinal
diseases, and inexpensive, so that more people have access to it. This research pro-
poses a novel architecture based on a convolutional neural network which can identify
Diabetic Retinopathy, while being able to categorize multiple retinal diseases with
near perfect accuracy in real time.
There are different retinal diseases other than Diabetic Retinopathy, such as Mac-
ular Degeneration. The macula is a retinal sensor found in the central region of the
retina in human eyes. The retinal lens perceives light emitted from outside sources and
transforms it into neural signals, a process otherwise known as vision. The macula
plays an integral role in human vision, from processing light via photo-receptor nerve
cells to aggregating the encoded neural signals sent directly to the brain through the
optic nerves. Retinal diseases such as Macular Degeneration, Diabetic Retinopathy,
and Choroidal Neovascularization are the leading causes of eye diseases and vision
loss worldwide.
In ophthalmology a technique called Spectral Domain Optical Coherence Tomog-
raphy (SD-OCT) is used for viewing the morphology of the retinal layers [32]. Fur-
thermore, another way to treat these diseases is to use depth-resolved tissue formation
data encoded in the magnitude and delay of the back-scattered light by spectral analy-
sis [1]. While retrieving the retinal image is performed by the computational process
of SD-OCT, differential diagnosis is conducted by human ophthalmologists. Conse-
quently, this leaves room for human error while performing differential diagnosis.
Hence, an autonomous expert system is beneficial for ophthalmologists to distin-
guish among different retinal diseases more precisely with fewer mistakes in a more
timely manner.
One of the predominant factors for misclassification of retinal maladies is
the stark similarity between Diabetic Retinopathy and other retinal diseases. They can
be grouped into three major categories: (i) Diabetic Macular Edema (DME) and Age-
related degeneration of retinal layers (AMD), (ii) Drusen, a condition where lipid or
protein build-up occurs in the retinal layer, and (iii) Choroidal Neovascularization
(CNV), a growth of new blood vessels in sub-retinal space. The most common retinal

diseases worldwide are Diabetic Retinopathy and Age-related Macular Degeneration
[9]. On the other hand, Drusen acts as an underlying cause that can trigger DR or
AMD over a prolonged time-frame. Choroidal Neovascularization (CNV), however, is
an advanced stage of Age-related Macular Degeneration that affects about 200,000
people worldwide every year [8, 37].
Despite a decade of improvements to existing algorithms, identification of retinal
diseases still produces erroneous results and requires expert intervention. To address
this problem, we propose a novel deep neural network architecture which not only
identifies retinal diseases in real time but also performs better than human experts for
specific tasks. While our emphasis is to train a deep neural network model to mini-
mize the classification error as much as possible, we grapple with the challenges of
over-fitting on data, gradient explosion and gradient vanishing, as evident in many
deep neural network models. In this work, we propose a newly designed residual
Convolutional Neural Network (CNN) block that helps us reduce the memory foot-
print while we are able to train a much deeper CNN for better accuracy. Furthermore,
we propose a novel building block in our CNN architecture that contains a signal
attenuation mechanism, a newly written function to conjoin previous input layers
before passing onto the next one in the network. We further show how these pro-
posed signal propagation techniques can lead to building a deeper and high-precision
network without succumbing to weight degradation, gradient explosion, and over-
fitting. In the following sections, we elaborate our principal contributions and also
provide a comparative analysis of different approaches.

2 Literature Review

In this section, we discuss a series of computational approaches that have been his-
torically adopted to diagnose SD-OCT imagery. These approaches date back to the
earlier developments in image processing as well as the segmentation algorithms in
pre-deep learning era of computer vision. While those developments are important
and made tremendous strides in their specific domains, we also discuss the learning-
based approaches and deep learning models that were created and trained both with
and without transfer learning by other researchers in the pursuit of classifying retinal
diseases from SD-OCT images. In this section, we further explain the pros and cons
of the other deep learning classification models on SD-OCT images and the develop-
ments and contributions we achieved with our proposed deep learning architecture.

2.1 Traditional Image Analysis

The earliest approach traced in the pursuit of classifying retinal diseases from images
combines multiple image processing techniques followed by feature extraction and
classification [24]. For instance, one such study was conducted where retinal dis-

eases are classified from images by finding abnormalities such as microaneurysms,
haemorrhages, exudates, and cotton-wool spots in Retinal Fundus images [7]. This
approach exploits a noise reduction algorithm and blurring to branch out the four-
class problem into two cases of a two-class problem. They first perform
background subtraction followed by shape estimation as the feature extractor, and
then use these extracted features to classify each of the four abnormali-
ties. In parallel, similar techniques with engineered features were adopted to detect
Diabetic Macular Edema (DME) and Choroidal Neovascularization (CNV). The
images were manipulated on five discrete parameters: Retinal Thickness, augmen-
tation of Retinal Thickening, Macular volume, retinal morphology, and vitreoretinal
relationship [25]. There exists another efficacious method that compounded statisti-
cal classification with edge detection algorithms to detect sharp edges [28]. Sanchez
et al.’s [28] algorithm achieved a sensitivity score of 79.6% while classifying Diabeitc
Retionpathy. Ege et al.’s [7] approach incorporating Mahalanobis classifier detected
microaneurysms, haemorrhages, exudates, and cottonwool spots with a sensitivity
of 69, 83, 99, and 80%, respectively. It is evident that each of these techniques shows
promising improvements over the others; however, they are not on par with human
diagnosticians in terms of precision. More effectual detection accuracy, therefore,
is still required for these systems to be of assistance to human diagnosticians and
ophthalmologists.

2.2 Segmentation-Based Approaches

One of the most pronounced ways to identify a patient with Diabetic Macular Edema
is the enlarged macular density in the retinal layer [1, 5]. Naturally, several systems have
been proposed and implemented which comprise retinal layer segmentation. Due to
evidence of liquids building up in the sub-retinal space as determined by the segmen-
tation algorithms, further identification of factors that engender specific diseases is
made possible [17, 22, 23]. In [20, 26], the authors proposed the idea of segmenting
the intra-retinal layers into ten parts and then extracted the texture and depth infor-
mation from each layer. Subsequently, any aberrant retinal features are detected by
classifying the dissimilarity between healthy retinas and the diseased ones. More-
over, Niemeijer et al. [22] introduced a technique for 3D segmentation of regions
containing fluid in OCT images using a graph-based implementation. A graph-cut
algorithm is applied to get the final predictions from the information initially retrieved
from layer-based segmentation of fluid regions. Even though implementations based
on a previous segmentation of retinal layers have registered high-scoring prediction
results, the initial step is reportedly laborious, prolonged, and erroneous [10, 13]. As
reported in [19], retinal thickness measurements obtained by different systems show
stark dissimilarity. Therefore, it is neither efficient nor optimal to compare
retinal depth information retrieved by separate machines, despite the
improved prediction accuracy over the feature engineering methods with traditional

image analysis discussed earlier. These observations further enforce the fact that
segmentation-based approaches are not effective as a universal retinal disease recog-
nition system.

2.3 Machine Learning and Deep Learning Techniques

Recently, a combination of optimization algorithms from machine learning and a
series of deep neural network architectures have become a popular choice among
researchers and engineers, in regard to achieving state-of-the-art accuracy for recog-
nizing various retinal diseases [18, 21, 34]. Various deep learning architectures with
necessary modifications described in [36] can be employed to classify retinal dis-
eases from SD-OCT images as well. Awais et al. proposed a combination of VGG16
[31] with KNN and Random forest classifier with 100 trees to create a deep classifi-
cation model in order to differentiate between Normal Retina and Diabetic Macular
Edema. Lee et al. trained a standalone VGG16 architecture with a binary output to
classify Age-related Macular Degeneration (AMD) and cases with no retinal degeneration
[18]. While these systems exploit learning-based feature extraction from the vicinity
of large-scale image data-sets, the neural network models are computationally inef-
fective and suffer from a large memory footprint. On the contrary, transfer learning
methods depend on weeks of training on millions of images and are not ideal for
finding stark differences between retinal diseases. To help alleviate all of these
challenges, an architecture is necessary which is specially catered to identifying
retinal diseases with high precision, speed, and low memory usage.

2.4 Our Contributions

In this work, we propose a novel convolutional neural network which specializes
in identifying retinal diseases with near perfect precision. Moreover, through this
architecture we are proposing (a) a new residual unit subsuming Atrous Separable
Convolution, (b) a novel building block, and (c) a mechanism to prevent gradient
degradation. The proposed network outperforms other architectures with respect to
the number of parameters, accuracy, and memory size. Our proposed architecture is
trained from scratch and bench-marked on two publicly available data-sets: OCT2017
[16] and Srinivasan2014 [32]. Hence, it does not require any pre-trained
weights, reducing the training and deployment time of the model by many folds. We
believe with the deployment of this model, rapid identification and treatment can be
carried out with near perfect certainty. Additionally, it will aid the ophthalmologist
to get a second expert opinion in the course of differential diagnosis.
This work is an extension of our previous work [14] where we experiment with
different novel residual convolutional block architectures and achieve state-of-the-art
performance on both the OCT2017 [16] and Srinivasan2014 [32] data-sets. In

this chapter, we illustrate an exploratory analysis of laterally distinguishable novel


residual blocks and discuss the methodologies and observations that help us select
the optimal block in order to create our CNN architecture. In this chapter, we further
show our results on fivefold training and demonstrate the efficacy of our model and
discuss how our architectural design prevents the model from over-fitting the data at
hand. Along with that, we illustrate the differences of training on different variants
of our proposed model and further show our deployment pipeline in this chapter.

3 Proposed Methodology

In this section, we discuss our proposed methodologies and observations adopted
toward designing the proposed CNN architecture. We first elaborate how we train
both data-sets on different residual units, each with unique lateral propagation archi-
tecture. We then discuss how we select our proposed residual unit based on the
observations from training on the other variants. Sequentially, we illustrate how we
join our proposed residual unit with a signal attenuation mechanism and a newly
written signal propagation function to prevent gradient degradation. Subsequently,
we then demonstrate our proposed CNN architecture and the efficacy and novelty of
the model.
In Fig. 1 we exemplify different variants of residual unit, their attributes, and how
we arrive at our proposed variant of the residual block. Figure 2 illustrates the Deep
Convolutional Neural Network (CNN) architecture we propose for the classification
of retinal diseases from Optical Coherence Tomography (OCT) images. In Fig. 2a we
delineate how the proposed Residual Learning Unit improves feature learning capa-
bilities while discussing the techniques we adopt to reduce computational complexity

Fig. 1 Different variants of residual unit and our proposed residual unit with a novel lateral prop-
agation of neurons

Fig. 2 Illustration of the Building Blocks of our proposed CNN [ OpticNet-71 ]. Only the very
first convolution (7 × 7) layer in the CNN and the very last convolution (1 × 1) layer in Stage [2,
3, 4]: Residual Convolutional Unit uses stride 2, while all other convolution operations use stride 1

for such performance enhancement. While Fig. 2b depicts the proposed mechanism
to handle gradient degradation, Fig. 2c narrates the entire CNN architecture. We
discuss the constituent segments of our CNN architecture, Optic-Net, and the philosophy
behind our lateral design schemes over the following subsections.

3.1 Lateral and Operation-Wise Variants of Residual Unit

To design a model with higher efficacy, lower memory footprint and portable compu-
tational complexity we experimented with different types of convolution operations
and various lateral choices in the residual unit of our proposed neural network.

3.1.1 Vanilla Residual Unit

Figure 1a represents the propagation of neurons in a vanilla residual unit with pre-
activation [11, 12]. While achieving a high prediction accuracy, the first row of
Table 1, however, reports a higher memory and computational footprint resulting from
such a design. The natural progression was to include a kernel with a wide receptive
field, hence the inclusion of dilation in the kernel operation as shown by Yu et al. [39].
The following Sect. 3.1.2 addresses the mechanism of this block in detail.

Table 1 Comparison between different convolution operations used in the middle portion of resid-
ual unit

Type of convolution used in the middle of residual unit | Approximate number of parameters^a | Depletion factor for parameters, Δp (%) | Accuracy^b (%)
Regular convolution | f² × D^[i] × D^[i−1] = 36,864 | 100 | 99.30
Atrous convolution | (f − 1)² × D^[i] × D^[i−1] = 16,384 | 44.9 | 97.20
Separable convolution | (f² + D^[i]) × D^[i−1] = 4,672 | 12.5 | 98.10
Atrous separable convolution | ((f − 1)² + D^[i]) × D^[i−1] = 4,352 | 11.6 | 96.70
Atrous convolution and atrous separable convolution branched | ½((f − 1)²(1 + ½D^[i]) + ½D^[i]) × D^[i−1] = 5,248 | 14.4 | 99.80

^a Here, kernel size, (f, f) = (3, 3). Depth (# kernels) in the residual unit's middle operation, D^[i] = 64, and first operation, D^[i−1] = 64
^b The Test Accuracy reported in the table is obtained by training on the OCT2017 [16] data-set, while the backbone network is Optic-Net 71

3.1.2 Atrous Residual Unit

To reduce computational parameters, we then replace the middle 3 × 3 convolution


block in residual unit with a 2 × 2 atrous convolution with dilation rate two, as
detailed in Fig. 1b. Figure 3 further illustrates how atrous convolution with skipped
feature extraction capture minor details in the signal while reducing the number of
parameters by a reasonable margin, as detailed in the second row of Table 1. On the
other hand, the atrous residual unit registers a rather poor performance on our data-set.
What's more, by not incorporating any depth information while doing the convolution
operation, the spatial information overflows throughout the architecture, resulting in
more error-prone results for borderline diagnosis. To address this problem, depth-wise
(separable) convolution is used in the next Sect. 3.1.3.

3.1.3 Separable Residual Unit

Concurrently, we redesigned the residual unit with a separable convolution
block in the middle feature extraction module, outlined in Fig. 1c. With a depth-
wise convolution followed by a point-wise operation, we achieved a much lower
computational stress with a relatively better inference accuracy than we do on the atrous
residual unit, as we report in the third row of Table 1. The reason is that small depth
information dominates throughout the architecture. We address this problem
by incorporating larger receptive fields using dilation in the depth-wise convolution
layer inside the Separable Residual unit, which is discussed in detail in Sect. 3.1.4.

3.1.4 Atrous Separable Residual Unit

To further contract computational strain, we then take the separable residual unit and
replace the depth-wise 3 × 3 convolution block with an atrous convolution block of
2 × 2 with dilation rate two, as shown in Fig. 1d. Figure 3 further demonstrates the
mechanism of atrous separable convolution operation on signals. With this design
choice, we cut the parameter count by 87.4%, which is the lowest computational com-
plexity in our experiment (fourth row of Table 1). However, the repetitive use of such a
residual unit in our proposed neural network leads to the lowest prediction
accuracy on our data-set. Furthermore, using only separable convolution results in
depth-wise feature extraction in most of the identity blocks while not encompassing
any spatial information in any of those learnable layers. That effectively explains the
underlying reason for having such low accuracy out of all the novel residual units.
Based on the observations made on such trade-offs between computational complexity and
performance, we arrive at our proposed residual unit (Fig. 1e), which we discuss next
in Sect. 3.2.

Fig. 3 Atrous separable convolution. Exploiting (f − 1) × (f − 1) convolutions instead of f × f
yields more fine-grained and coarse features with better depth resolution compared to regular
atrous convolution

3.2 Proposed Residual Unit and Learning Mechanism

Historically, Residual Units [11, 12] used in Deep Residual Convolutional Neural
Networks (CNNs) process the incoming input through three convolution operations
while adding the incoming input to the processed output. These three convolutional
operations are (1 × 1), (3 × 3), and (1 × 1) convolutions. Therefore, replacing the
(3 × 3) convolution in the middle with other types of convolutional operations can
potentially change the learning behavior, computational complexity, and eventually
prediction performance, as demonstrated in Fig. 1.
We experimented with different convolution operations as replacement for the
(3 × 3) middle convolution and observed which choice contributes the most to reduce
the number of parameters, ergo computational complexity, as depicted in Table 1.
Furthermore, in Table 1, we use a depletion factor for parameters, Δp, which is the
ratio of the number of parameters in the replaced convolution to that of the regular convolution,
expressed in percent.
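The parameter counts in Table 1 follow directly from the footnote values (f = 3, D^[i] = D^[i−1] = 64), as the short sketch below reproduces for the single-branch variants; the ratios printed here are computed straight from the counts and may differ slightly from the rounded depletion factors in the table.

```python
# Sketch reproducing the single-branch parameter counts of Table 1
# for kernel size f = 3 and channel depths D_i = D_prev = 64.
f, d_in, d_out = 3, 64, 64

counts = {
    "regular convolution": f**2 * d_out * d_in,                   # 3x3 kernel: 36,864
    "atrous convolution": (f - 1)**2 * d_out * d_in,              # 2x2 kernel, dilation 2: 16,384
    "separable convolution": (f**2 + d_out) * d_in,               # depth-wise + point-wise: 4,672
    "atrous separable convolution": ((f - 1)**2 + d_out) * d_in,  # 4,352
}

regular = counts["regular convolution"]
for name, params in counts.items():
    print(f"{name:>30}: {params:6d} params, depletion {100 * params / regular:5.1f}%")
```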
In our proposed residual unit, we replace the middle (3 × 3) convolution operation
with two different operations running in parallel, as detailed in Fig. 2a. Whereas a
conventional residual unit uses D^[i] channels for the middle convolution,
we use ½D^[i] channels for each of the newly replaced operations to prevent
any surge in parameters. In the proposed branching operation we use a (2 × 2) Atrous
convolution (C2 ) with dilation rate, r = 2 to get a (3 × 3) receptive field in the left
branch while in the right branch we use a (2 × 2) Atrous separable convolution (C3 )
with dilation rate, r = 2 to get a (3 × 3) receptive field. Sequentially, the results are then
added together. Furthermore, separable convolution [30] disentangles the spatial and
depth-wise feature maps separately while Atrous convolutions inspect both spatial
and depth channels together. We hypothesize that adding two such feature maps that
are learned very differently shall help trigger more robust and subtle features.
 
$$X_{l+1} = X_l + \big((X_l \circledast C_1 \circledast C_2) + (X_l \circledast C_1 \circledast C_3)\big) \circledast C_4 = X_l + \hat{F}(X_l, W_l) \qquad (1)$$

Figure 3 shows how adding Atrous and Atrous separable feature maps help dis-
entangle the input image space with more depth information instead of activating
only the predominant edges. Moreover, the last row of Table 1 confirms that adopt-
ing this strategy still reduces the computational complexity by a reasonable margin,
while improving inference accuracy. Equation (1) further clarifies how input signals
$X_l$ travel through the proposed residual unit shown in Fig. 2a, where $\circledast$ refers to the
convolution operation.
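The following Keras-style sketch mirrors Eq. (1): a shared 1 × 1 convolution (C1), parallel 2 × 2 atrous (C2) and 2 × 2 atrous separable (C3) branches with ½D channels each, a closing 1 × 1 convolution (C4), and the identity skip. The exact placement of batch normalization and ReLU is an assumption here, not taken from the original implementation.

```python
from tensorflow.keras import layers

def proposed_residual_unit(x, depth):
    """Sketch of the branched residual unit of Eq. (1). `x` is assumed to already
    have `depth` channels so that the identity skip can be added directly."""
    shortcut = x
    h = layers.ReLU()(layers.BatchNormalization()(x))
    h = layers.Conv2D(depth // 2, 1, padding="same")(h)                    # C1: 1x1 reduction
    left = layers.Conv2D(depth // 2, 2, dilation_rate=2,
                         padding="same")(h)                                # C2: 2x2 atrous, r = 2
    right = layers.SeparableConv2D(depth // 2, 2, dilation_rate=2,
                                   padding="same")(h)                      # C3: 2x2 atrous separable
    h = layers.Add()([left, right])                                        # merge the two branches
    h = layers.ReLU()(layers.BatchNormalization()(h))
    h = layers.Conv2D(depth, 1, padding="same")(h)                         # C4: restore depth
    return layers.Add()([shortcut, h])                                     # identity skip (Eq. 1)
```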
A Comprehensive Set of Novel Residual Blocks for Deep Learning … 35

3.3 Proposed Building Block and Signal Propagation

In this section we discuss the proposed building block as a constituent part of Optic-
Net. As shown in Fig. 2b we split the input signal (X l ) into two branches (1) Stack
of Residual Units, (2) Signal Exhaustion. Later in this section, we explain how we
connect these two branches to propagate signals further in the network.


$$\alpha(X_l) = X_{l+N} = X_l + \sum_{i=l}^{N} \hat{F}(X_i, W_i) \qquad (2)$$

3.3.1 Stack of Residual Units

In order to initiate a novel learning chain for propagating signals through stacking
several of the proposed residual units linearly, we suggest combining the global residual
effects enhanced by pre-activation residual units [12] and our proposed set of convo-
lution operations (Fig. 2a). As shown in (1), F̂(X l , Wl ) denotes all the proposed set
of convolution operations inside a residual unit for input X l . We sequentially stack
these residual units N times over X l which is input to our proposed building block,
as narrated in Fig. 2b. Equation (2) illustrates the state of output signal denoted by
X l+N which is processed through a stack of residual units of length N . For the sake
of further demonstration we denote X l+N as α(X l ).

3.3.2 Signal Exhaustion

In the proposed building block, we propagate the input signal X_l through a
Max-pooling layer to achieve spatial down-sampling, which we then up-sample
through Bi-linear interpolation. Since the down-sampling module only for-
wards the strongest activations, the interpolated reconstruction makes a dense
spatial volume from the down-sampled representation—intrinsically exhaust-
ing the incoming signal X l . As detailed in Fig. 2b, we sequentially pass the
exhausted signal space through sigmoid activation, σ(X l ) = 1/(1 + e−X l ). Recent
research [35] has shown how auto-encoding with residual skip connections
[P_encoder(input|code) → P_decoder(code|input) + input] improves attention-
oriented classification performance. However, unlike auto-encoders, max-pooling
and Bi-linear interpolation functions are not enabled with a learning mechanism. In
Optic-Net, we capacitate the CNN to activate spikes from an exhausted signal space
because we use it as a mechanism to avert gradient degradation. For the sake of
further demonstration we denote the exhausted signal activation module, σ(X l ) as
β(X l ).

τ(X_l) = α(X_l) + β(X_l) + α(X_l) × β(X_l)   (3)



   
$$\frac{\partial \tau(X_l)}{\partial X_l} = \big(1 + \sigma(X_l)\big) \times \Big(1 + \frac{\partial}{\partial X_l}\sum_{i=l}^{N} \hat{F}(X_i, W_i)\Big) + \Big(1 + X_l + \sum_{i=l}^{N} \hat{F}(X_i, W_i)\Big) \times \sigma(X_l) \times \big(1 - \sigma(X_l)\big)$$
$$\qquad = \Big(1 + \frac{\sigma'(X_l)}{1 - \sigma(X_l)}\Big) \times \Big(1 + \frac{\partial}{\partial X_l}\sum_{i=l}^{N} \hat{F}(X_i, W_i)\Big) + \big(1 + X_{l+N}\big) \times \sigma'(X_l) \qquad (4)$$

3.3.3 Signal Propagation

As shown in Fig. 2b, we process the residual signal α(X_l) and the exhausted signal
β(X_l) following (3), and we denote the output signal propagated from the proposed
building block as τ(X_l). Our hypothesis behind such a design is that whenever one
of the branches falls prey to gradient degradation from a mini-batch, the other branch
manages to propagate signals unaffected by the mini-batch with an amplified absolute
gradient. To validate our hypothesis, (3) shows that τ(X_l) ≈ α(X_l) for β(X_l) ≈ 0
and τ(X_l) ≈ β(X_l) for α(X_l) ≈ 0, illustrating how the unaffected branch survives the
degradation in the affected branch. However, when neither branch is affected by
gradient amplification, the multiplication (α(X_l) × β(X_l)) balances out the increase
in signal propagation due to the addition of both branches. Equation (4) delineates the
gradient of the building block output τ(X_l) with respect to the building block input X_l,
calculated during back-propagation for optimization.
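A sketch of the whole building block is given below: the residual stack produces α(X_l), the max-pool/bilinear-upsample/sigmoid path produces β(X_l), and the two are combined as in Eq. (3). The pool size and the residual_unit callable are assumptions; the latter can be the unit sketched in Sect. 3.2.

```python
from tensorflow.keras import layers

def building_block(x, depth, num_units, residual_unit):
    """Sketch of the proposed building block. `residual_unit` is a callable such as
    the unit from the previous sketch; pool size 2 is an assumption."""
    alpha = x
    for _ in range(num_units):                       # alpha(X_l): stack of residual units
        alpha = residual_unit(alpha, depth)

    pooled = layers.MaxPooling2D(pool_size=2)(x)     # keep only the strongest activations
    restored = layers.UpSampling2D(size=2, interpolation="bilinear")(pooled)
    beta = layers.Activation("sigmoid")(restored)    # beta(X_l): exhausted signal in (0, 1)

    # tau = alpha + beta + alpha * beta, as in Eq. (3)
    return layers.Add()([alpha, beta, layers.Multiply()([alpha, beta])])
```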

3.4 CNN Architecture and The Optimization Chain

Figure 2c portrays the entire CNN architecture with all the building blocks and con-
stituent components joined together. First, the input batch (224 × 224 × 3) is propa-
gated through a 7 × 7 Conv with stride 2, followed by batch-normalization and ReLU
activation. Then we propagate the signals via a Residual Convolution Unit (same as
the unit used in [12]) which is then followed by our proposed building block. We
propagate the signals through this [Residual Convolution Unit → Building Block]
procedure S = 4 times, which we call stages 1, 2, 3, and 4, respectively. Then
global average pooling is applied to the signals, which then pass through two more Fully
Connected (FC) layers before the loss function, denoted by ξ.
In Table 2, we show the number of feature maps (Layer Depth) we use for each
layer in the network. The output shape of the input tensor after four consecutive
stages are (112×112×256), (56×56×512), (28×28×1024), and (14×14×2048),

Table 2 Architectural specifications for OpticNet-71 and layer-wise analysis of the number of feature
maps in comparison with ResNet50-v1 [11]
Layer name ResNet50 V1 [11] OpticNet71 [Ours]
Conv 7 × 7 [64] × 1 [64] × 1
Stage1: Res Conv [64, 64, 256] × 1 [64, 64, 256] × 1
Stage1: Res Unit [64, 64, 256] × 2 [32, 32, 32, 256] × 4
Stage2: Res Conv [128, 128, 512] × 1 [128, 128, 512] × 1
Stage2: Res Unit [128, 128, 512] × 3 [64, 64, 64, 512] × 4
Stage3: Res Conv [256, 256, 1024] × 1 [256, 256, 1024] × 1
Stage3: Res Unit [256, 256, 1024] × 5 [128, 128, 128, 1024] × 3
Stage4: Res Conv [512, 512, 2048] × 1 [512, 512, 2048] × 1
Stage4: Res Unit [512, 512, 2048] × 2 [256, 256, 256, 2048] × 3
Global Avg Pool 2048 2048
Dense layer 1 K (Classes) 256
Dense layer 2 – K (Classes)
Parameters 25.64 Million 12.50 Million
Required FLOPs 3.8 × 10⁹ 2.5 × 10⁷
CNN Memory 98.20 MB 48.80 MB

respectively. Moreover, FC1 ∈ ℝ^(2048×256) and FC2 ∈ ℝ^(256×K), where K = number of
classes.
$$\frac{\partial \xi}{\partial X_{l_{Stage}}} = \frac{\partial \xi}{\partial \tau(X_l)} \times \sum_{j=1}^{Stage} \frac{\partial \tau(X_l)_j}{\partial X_{l_j}} \qquad (5)$$

Equation (5) represents the gradient calculated for the entire network chain dis-
tributed over stages of optimization. As (4) suggests, the term (1 + σ(X_l)), in com-
parison with [12], works as an extra layer of protection to prevent possible gradient
explosion caused by the stacked residual units by multiplying nonzero activations
with the residual unit's gradients. Moreover, the term (1 + X_{l+N}) indicates that the
optimization chain still has access to signals from much earlier in the network, and
to prevent unwanted spikes in activations, the term σ'(X_l) can still mitigate gradient
expansion which could otherwise jeopardize learning.

4 Experiments

The following section contains information for training, validating, and testing the
architectures with different settings and hyper-parameters. Moreover, it gives a
detailed analysis of how the architectures were compared with previous techniques
and expert human diagnosticians. Additionally, the juxtaposition was drawn in terms
of speed, accuracy, sensitivity, specificity, memory usage, and penalty weighted met-

rics. All the files related to the experimentation and training can be found in the
following Code Repository: https://github.com/SharifAmit/OCT_Classification.

4.1 Specifications of Data-Sets and Preprocessing Techniques

We benchmark our model against two distinct data-sets (different scale, sample space,
etc.). The first data-set, OCT2017 [16], aims at correctly recognizing and differentiating
between four distinct retinal states: normal healthy retina, Drusen, Choroidal
Neovascularization (CNV), and Diabetic Macular Edema (DME). The OCT2017 [16]
data-set contains 84,484 images (provided in high-quality TIFF format with three
non-RGB color channels), which we split into an 83,484-image train-set and a
1,000-image test-set. The second data-set, Srinivasan2014 [32], consists of three classes
and aims at classifying normal healthy specimens of retina, Age-Related Macular
Degeneration (AMD), and Diabetic Macular Edema (DME). The Srinivasan2014 [32]
data-set consists of 3,231 image samples that we split into a 2,916-image train-set and
a 315-image test-set. We resize images from both data-sets to 224 × 224 × 3 for
both training and testing. For both data-sets, we perform fivefold cross-validation on the
training set to find the best models.
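As a sketch of the preprocessing just described, the snippet below resizes images to 224 × 224 × 3 and batches them; the directory layout, the Keras loader, and the [0, 1] rescaling are assumptions, not the authors' exact pipeline.

```python
import tensorflow as tf

IMG_SIZE = (224, 224)          # per Sect. 4.1
BATCH_SIZE = 8                 # matches the batch size reported in Sect. 4.4

# Assumed directory layout: one sub-folder per class (e.g., CNV/DME/DRUSEN/NORMAL).
train_ds = tf.keras.preprocessing.image_dataset_from_directory(
    "OCT2017/train", image_size=IMG_SIZE, batch_size=BATCH_SIZE, label_mode="categorical")
test_ds = tf.keras.preprocessing.image_dataset_from_directory(
    "OCT2017/test", image_size=IMG_SIZE, batch_size=BATCH_SIZE, label_mode="categorical")

# Assumed normalization: scale pixel intensities to [0, 1] before feeding the network.
rescale = tf.keras.layers.Rescaling(1.0 / 255)
train_ds = train_ds.map(lambda x, y: (rescale(x), y))
test_ds = test_ds.map(lambda x, y: (rescale(x), y))
```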

4.2 Performance Metrics

We calculated four standard metrics to evaluate our CNN model on both data-sets:
Accuracy (6), Sensitivity (7), Specificity (8) and a Special Weighted Error (9) from
[16]. Where N is the number of image samples and K is the number of classes. Here
TP, FP, FN, and TN denote True Positive, False Positive, False Negative, and True
Negative, respectively. We report True Positive Rate (TPR) or Sensitivity (7) and
True Negative Rate (TNR) or Specificity (8) for both the data-sets [16, 32]. For
this, we calculate the TPR and TNR for individual classes then sum all the values
and then divide that by the number of classes (K).

$$\text{Accuracy} = \frac{1}{N}\sum TP \qquad (6)$$
$$\text{Sensitivity} = \frac{1}{K}\sum \frac{TP}{TP + FN} \qquad (7)$$
$$\text{Specificity} = \frac{1}{K}\sum \frac{TN}{TN + FP} \qquad (8)$$
$$\text{Weighted Error} = \frac{1}{N}\sum_{i,j \in K} W_{ij} \cdot X_{ij} \qquad (9)$$

Table 3 Penalty weights proposed for Oct2017 [16]


Normal Drusen CNV1 DME2
Normal 0 1 1 1
Drusen 1 0 1 1
CNV1 4 2 0 1
DME2 4 2 1 0
1 CNV: Chorodial Neovascularization
2 DME: Diabetic Macular Edema

Fig. 4 Confusion matrix generated by OpticNet-71 for OCT2017[16] data-set

As reported in [16], the penalty points for incorrect categorization of a retinal
disease can be arbitrary. Table 3 shows the penalty weight values for misidentifying
a category as set by [16], which is specific to the OCT2017 [16] data-set. To calculate
the Weighted Error (9), we apply element-wise multiplication between the confusion matrix
generated by a specific model (Fig. 4 represents the confusion matrix generated by
OpticNet-71 on the OCT2017 [16] data-set) and the weight matrix in Table 3, and then
take an average over the number of samples. Here, the penalty weight values from
Table 3 are denoted by W and the model's prediction (confusion matrix) is denoted
by X, where i, j denote the rows and columns of the confusion matrix.
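The four metrics can be computed directly from a confusion matrix, as in the sketch below; the row/column orientation of the matrix and the helper name are assumptions.

```python
import numpy as np

def metrics_from_confusion(cm, penalty=None):
    """Accuracy (6), macro Sensitivity (7), macro Specificity (8) and Weighted Error (9)
    from a K x K confusion matrix cm; rows are assumed to be true classes and
    columns predicted classes."""
    cm = np.asarray(cm, dtype=float)
    n = cm.sum()
    tp = np.diag(cm)
    fn = cm.sum(axis=1) - tp
    fp = cm.sum(axis=0) - tp
    tn = n - tp - fn - fp

    accuracy = tp.sum() / n
    sensitivity = np.mean(tp / (tp + fn))
    specificity = np.mean(tn / (tn + fp))
    weighted_error = np.sum(np.asarray(penalty) * cm) / n if penalty is not None else None
    return accuracy, sensitivity, specificity, weighted_error

# Table 3 penalties in the order Normal, Drusen, CNV, DME.
W = [[0, 1, 1, 1], [1, 0, 1, 1], [4, 2, 0, 1], [4, 2, 1, 0]]
```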

4.3 Training OpticNet-71 and Obtained Results

4.3.1 OCT2017 Data-Set

In Table 4, we report a comprehensive study for OCT2017 [16] data-set evalu-


ated through testing standards such as Test Accuracy, Sensitivity, Specificity, and
Weighted Error. OpticNet-71 scores the highest Test Accuracy (99.80%) among

Table 4 Results on Oct2017 [16] data-set


Architectures Test accuracy Sensitivity Specificity Weighted error
InceptionV3 (limited) 93.40 96.60 94.00 12.70
Human expert 2 [16] 92.10 99.39 94.03 10.50
InceptionV3 [16] 96.60 97.80 97.40 6.60
ResNet50-v1 [11] 99.30 99.30 99.76 1.00
MobileNet-v2 [29] 99.40 99.40 99.80 0.60
Human expert 5 [16] 99.70 99.70 99.90 0.40
Xception [4] 99.70 99.70 99.90 0.30
OpticNet-71 [Ours] 99.80 99.80 99.93 0.20

other existing solutions, with a Sensitivity and Specificity of 99.80 and 99.93%.
Furthermore, the Weighted Error is reported to be a mere 0.20% which can be visu-
alized in Fig. 4 as our architecture misidentifies one Drusen and one DME sample as
CNV. However, the penalty weight is only 1 for each of the misclassifications, as we
report in Table 3. Sequentially, with our proposed OpticNet-71 we obtain state-of-
the-art results on OCT2017 [16] data-set across all four performance metrics, while
significantly surpassing human benchmarks as mentioned in Table 4.

4.3.2 Srinivasan2014 Data-Set

We benchmark OpticNet-71 against other methods in Table 5 while evaluating Srini-


vasan2014 [32] data-set through three metrics: Accuracy, Sensitivity, and Specificity.
Among the solutions mentioned in Table 5, Lee et al. [18] use a modified VGG-16,
Awais et al. [2] use a VGG architecture with KNN in the final layer, and Karri et al. [15]
use GoogleNet, while they all use weights from transfer learning on ImageNet [27].
As shown in Table 5, OpticNet-71 achieves state-of-the-art result by scoring 100%
Accuracy, Sensitivity, and Specificity.
Furthermore, we train ResNet50-v1 [11], ResNet50-v2 [12], MobileNet-v2 [29],
and Xception [4] using pre-trained weights from the 3.2-million-image ImageNet data-set
consisting of 1000 categories [27] to compare with our achieved results (Tables 4
and 5), while we train Optic-Net from scratch with randomly initialized weights.

4.4 Hyper-Parameter Tuning and Performance Evaluation

The hyper-parameters while training OpticNet-47, OpticNet-63, OpticNet-71,
MobileNet-v2 [29], XceptionNet [4], ResNet50-v2 [12], and ResNet50-v1 [11] are as fol-
lows: batch size, b = 8; epochs = 30; learning rate, α_lr = 1e−4; step decay, γ = 1e−1.
We use an adaptive learning rate and decrease it using α_lr^new = α_lr^current × γ if the validation

Table 5 Results on Srinivasan2014 [32] data-set


Architectures Test accuracy Sensitivity Specificity
Lee et al. [18] 87.63 84.63 91.54
Awais et al. [2] 93.00 87.00 100.00
ResNet50-v1 [11] 94.92 94.92 97.46
Karri et al. [15] 96.00 – –
MobileNet-v2 [29] 97.46 97.46 98.73
Xception [4] 99.36 99.36 99.68
OpticNet-71 [Ours] 100.00 100.00 100.00

loss doesn’t lower for six consecutive epochs. Moreover, we set the lowest learning
rate to αmin
lr
= 1e−8 . Furthermore, We use Adam optimizer with default parameters
of β1adam
= 0.90 and β2adam = 0.99 for all training schemes. We train OCT2017 [16]
data-set for 44 hours and Srinivasan2014 [32] data-set for 2 hours on a 8 GB NVIDIA
GTX 1070 GPU.
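A sketch of this training configuration in Keras is given below. The ReduceLROnPlateau callback is an assumed equivalent of the step-decay rule described above, and the model and dataset objects are placeholders.

```python
from tensorflow import keras

# model, train_ds and val_ds are placeholders; batches of size 8 are assumed to come
# from the data pipeline (Sect. 4.1).
optimizer = keras.optimizers.Adam(learning_rate=1e-4, beta_1=0.90, beta_2=0.99)
model.compile(optimizer=optimizer, loss="categorical_crossentropy", metrics=["accuracy"])

reduce_lr = keras.callbacks.ReduceLROnPlateau(
    monitor="val_loss", factor=0.1, patience=6, min_lr=1e-8)   # gamma = 0.1 after 6 stale epochs

model.fit(train_ds, validation_data=val_ds, epochs=30, callbacks=[reduce_lr])
```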
Inception-v3 models under-perform compared to both the pre-trained models and
OpticNet-71, as seen in Table 4. OpticNet-71 takes 0.03 seconds to make a prediction
on an OCT image, which is real time. While accomplishing state-of-the-art
results on the OCT2017 [16] and Srinivasan2014 [32] data-sets, our model also surpasses human-
level prediction on OCT images, as depicted in Table 4. The human experts are real
diagnosticians, as reported in [16]. In [16], there are six diagnosticians; the highest
performing one is Human Expert 5 while the lowest performing one is Human Expert
2. To validate our CNN architecture's optimization strength, we also train two smaller
versions of OpticNet-71 on both data-sets, which are OpticNet-47 ( [N1 N2 N3 N4] =
[2 2 2 2] ) and OpticNet-63 ( [N1 N2 N3 N4] = [3 3 3 3] ). In Fig. 5 we show how
all of our variants of OpticNet outperform the pre-trained CNNs on the Srinivasan2014
[32] data-set, while OpticNet-71 outperforms all the pre-trained CNNs on the OCT2017
[16] data-set in terms of accuracy as well as the performance-memory trade-off.

4.5 Analysis of Proposed Residual Interpolated Block

To understand how the Residual Interpolated Block works, we visualize features


by passing a test image through our CNN model. Figure 6a illustrates some of the
sharp signals propagated by Residual blocks while the interpolation reconstruction
routine propagates a weak signal activation, yet the resulting signal space is both
more sharp and fine-grained compared to their Residual counterparts. Since the conv
layers in the following stage activate the incoming signals first, we do not output
an activated signal space from a stage. Instead, we only activate the interpolation
counterpart and then multiply it with the last residual block's non-activated output
space, while also adding the raw signals to the multiplied signal, which we

Fig. 5 Test accuracy (%), CNN memory (Mega-Bytes) and model parameters (Millions) on
OCT2017 [16] data-set and Srinivasan2014 [32] data-set

consider as output from each stage as narrated in Fig. 6b. Furthermore, Fig. 6b
portrays how element-wise addition with the element-wise multiplication between
signals helps the learning propagation of OpticNet-71. Figure 6b precisely depicts
why this optimization chain is particularly significant: a zero activation cannot cancel
out a live signal channel from the residual counterpart (τ(X_l) = α(X_l) + β(X_l) ×
(1 + α(X_l))), while a dead signal channel also cannot cancel out a nonzero activation
from the interpolation counterpart (τ(X_l) = β(X_l) + α(X_l) × (1 + β(X_l))), thus
preventing all signals of a stage from dying and resulting in catastrophic optimization
failure due to dead weights or gradient explosion.

4.6 Fivefold Training on OCT2017 Data-set

We ran a fivefold cross-validation on OCT2017 [16] data-set to mitigate over-fitting


and assert our model’s generalization potency. We randomly split the 84,484 images
into 5 subsets, each containing approximately 16896 images. We validate our model
on one of the subsets after training the model on the remaining 4 subsets. We continue
this process for 5 times to validate on all different subsets. In Fig. 7 we report the
best training and validation accuracy among all five folds. Figure 7 also depicts the
arithmetic average of training and validation accuracy among all five folds. Further-
more, each accuracy is reported for a particular batch (size of 8) with all 30 epochs
registered on the X-axis of Fig. 7. In a similar manner, we report the best train-
ing loss, best validation loss, average training loss, and average validation loss for

Fig. 6 a Visualizing input images from each class through different layers of Optic-Net 71. As
shown, the feature maps at the end of each building block learns more fine-grained features by
focusing sometimes on the same shapes—rather in different regions of the image—learning to decide
what features lead the image to the ground truth. b The learning progression, however, shows how
exhausting the signal propagated with residual activation learns to detect more thin edges—delving
further into the Macular region to learn anomalies. While using the signal exhaustion mechanism
sometimes, important features can be lost during training. Our experiments show, by using more of
these building blocks we can reduce that risk of feature loss and improve overall optimization for
Optic-Net 71

the OCT2017 [16] data-set in Fig. 8. As Fig. 7 illustrates, the best and average accuracy
from both the training and validation sets reach a stable maximum after training, and this
phenomenon attests to the efficacy of our model. Furthermore, in Fig. 8 we report a
discernible discrepancy between the average validation loss and the best validation loss
among the five folds, which assures that our model's performance is not a result
of over-fitting.
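A sketch of the fivefold protocol is shown below using scikit-learn's KFold; the placeholder file list and the build_and_train helper are assumptions, and only the split logic reflects the procedure described above.

```python
import numpy as np
from sklearn.model_selection import KFold

# Placeholder file list standing in for the 84,484 OCT2017 images.
filenames = np.array([f"img_{i:05d}.jpeg" for i in range(84484)])

kfold = KFold(n_splits=5, shuffle=True, random_state=42)
for fold, (train_idx, val_idx) in enumerate(kfold.split(filenames), start=1):
    # roughly 16,896 validation images per fold, the rest used for training
    print(f"fold {fold}: {len(train_idx)} train / {len(val_idx)} val")
    # build_and_train(filenames[train_idx], filenames[val_idx])   # assumed helper
```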

4.7 Training different variants of OpticNet-71 on OCT2017

We experimented with different variants of residual units to find the best version
of OpticNet-71 for both data-sets. First, we ran the experiment with a vanilla
convolution on the OCT2017 [16] data-set, but with that we were under-performing
against Human Expert 5. After that, we incorporated dilation into the convolution layers,
with which we reached a specificity of 99.73% and a weighted error of 0.8%, as shown in
Table 6; this was worse than using residual units with vanilla convolution. Next, we
tried separable convolution and its dilated counterpart, with which we achieved
precision scores of 99.30% (vanilla) and 99.40% (dilated), respectively. Lastly, we
tried out the proposed novel residual unit consisting of dilated convolution on one
branch and separable convolution with dilation on the other. With this, we reached

Fig. 7 Average accuracy and best accuracy of the fivefold cross-validation on the OCT2017 [16]
validation and training

Fig. 8 Average loss and minimum loss of the fivefold cross-validation on the OCT2017 [16]
validation and training sets

our desired precision score of 99.93% and a weighted error of 0.2%, beating Human
Expert 5 and reaching state-of-the-art accuracy. It is worthwhile to mention that the
hyper-parameters for all the architectures were the same. Moreover, the training was
done with the same optimizer (Adam) for 30 epochs. So, it is quite evident that
OpticNet-71, comprising dilated convolution and dilated separable convolution,
was the optimum choice for deployment and prediction in the wild.

Table 6 Results on OCT2017 [16] using different variants of OpticNet-71


Architectures Test accuracy Sensitivity Specificity Weighted error
Optic-Net 71 99.40 99.40 99.80 0.60
(vanilla convolution)
Optic-Net 71 99.20 99.20 99.73 0.80
(dilated convolution)
Optic-Net 71 99.30 99.30 99.76 0.70
(separable convolution)
Optic-Net 71 99.40 99.40 99.80 0.60
(separable convolution
with dilation)
Optic-Net 71 99.80 99.80 99.93 0.20
(dilated convolution +
separable convolution
with dilation)

4.8 Deployment Pipeline of Our System

In this section, we expound the application pipeline we have used to deploy our
model for users to interact with. The application pipeline consists of two fragments
as depicted in Fig. 9: (a) The Front End, (b) The Back End. The user interacts with

Fig. 9 Application pipeline of Optic-Net



an app in the front end where one can upload an OCT image. The uploaded image
is then passed onto the back end, where it first goes through a pre-defined set of
preprocessing steps and then gets forwarded to our CNN model (OpticNet-71). All of
these processes take place on the server. The class prediction score outputted by our
model is then sent back to the app in response to the user's request for that input image.
Since our CNN is lightweight and capable of outputting a high-precision prediction
in real time, it facilitates a smooth user experience. Meanwhile, the uploaded image,
along with its prediction tensor, is simultaneously stored on cloud storage for further
fine-tuning and training of our model, which expedites the goal of heightening system
precision and widens new horizons to foster further research and development.
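For illustration, a minimal back-end sketch of the pipeline in Fig. 9 could look like the following Flask service; the route, saved-model file name, preprocessing, and class order are assumptions and not the deployed implementation.

```python
# Hypothetical back-end sketch of the pipeline in Fig. 9 using Flask.
import numpy as np
from PIL import Image
from flask import Flask, request, jsonify
from tensorflow import keras

app = Flask(__name__)
model = keras.models.load_model("opticnet71.h5")         # assumed saved model file
CLASSES = ["CNV", "DME", "DRUSEN", "NORMAL"]              # assumed OCT2017 class order

@app.route("/predict", methods=["POST"])
def predict():
    img = Image.open(request.files["image"]).convert("RGB").resize((224, 224))
    x = np.asarray(img, dtype=np.float32)[None, ...] / 255.0   # 1 x 224 x 224 x 3
    probs = model.predict(x)[0]
    return jsonify({"prediction": CLASSES[int(np.argmax(probs))],
                    "scores": {c: float(p) for c, p in zip(CLASSES, probs)}})

if __name__ == "__main__":
    app.run()
```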

5 Conclusion

In this chapter, we introduced a novel set of residual blocks that can be infused to
build a convolutional neural network for diagnosing retinal diseases with expert-level
precision. Additionally, by exploiting this architecture we devise a practical solution
that can address the problem of vanishing and exploding gradients. This work is an
extension of our previous work [14], which illustrates the exploratory analysis of
different novel blocks and how effective they are in the diagnosis of retinal degeneration.
In the future, we would like to expand on this research to address other sub-types of
retinal degeneration and isolate the boundaries of the macular subspace in the retina,
which in turn will assist the expert ophthalmologist in carrying out their differential
diagnosis.

Acknowledgments We would like to thank the UNR Computer Vision Laboratory (https://www.cse.unr.edu/CVL/) and the Center for Cognitive Skill Enhancement (http://ccse.iub.edu.bd/) for providing us with technical support.

References

1. K. Alsaih, G. Lemaitre, M. Rastgoo, J. Massich, D. Sidibé, F. Meriaudeau, Machine learning
techniques for diabetic macular edema (dme) classification on sd-oct images. Biomed. Eng.
Online 16(1), 68 (2017)
2. M. Awais, H. Müller, T.B. Tang, F. Meriaudeau, Classification of sd-oct images using a deep
learning approach, in 2017 IEEE International Conference on Signal and Image Processing
Applications (ICSIPA) (IEEE, 2017), pp. 489–492
3. R.R. Bourne, G.A. Stevens, R.A. White, J.L. Smith, S.R. Flaxman, H. Price, J.B. Jonas, J.
Keeffe, J. Leasher, K. Naidoo et al., Causes of vision loss worldwide, 1990–2010: a systematic
analysis. Lancet Glob. Health 1(6), e339–e349 (2013)
4. F. Chollet, Xception: deep learning with depthwise separable convolutions, in Proceedings of
the IEEE Conference on Computer Vision and Pattern Recognition (2017), pp. 1251–1258

5. R.A. Costa, M. Skaf, L.A. Melo Jr., D. Calucci, J.A. Cardillo, J.C. Castro, D. Huang, M.
Wojtkowski, Retinal assessment using optical coherence tomography. Prog. Retin. Eye Res.
25(3), 325–353 (2006)
6. Centers for Disease Control and Prevention, National diabetes statistics report, 2017 (2017)
7. B.M. Ege, O.K. Hejlesen, O.V. Larsen, K. Møller, B. Jennings, D. Kerr, D.A. Cavan, Screening
for diabetic retinopathy using computer based image analysis and statistical classification.
Comput. Methods Programs Biomed. 62(3), 165–175 (2000)
8. N. Ferrara, Vascular endothelial growth factor and age-related macular degeneration: from
basic science to therapy. Nat. Med. 16(10), 1107 (2010)
9. D.S. Friedman, B.J. O’Colmain, B. Munoz, S.C. Tomany, C. McCarty, P. De Jong, B. Nemesure,
P. Mitchell, J. Kempen et al., Prevalence of age-related macular degeneration in the united states.
Arch Ophthalmol 122(4), 564–572 (2004)
10. I. Ghorbel, F. Rossant, I. Bloch, S. Tick, M. Paques, Automated segmentation of macular
layers in OCT images and quantitative evaluation of performances. Pattern Recognit. 44(8),
1590–1603 (2011)
11. K. He, X. Zhang, S. Ren, J. Sun, Deep residual learning for image recognition, in Proceedings
of the IEEE Conference on Computer Vision and Pattern Recognition (2016), pp. 770–778
12. K. He, X. Zhang, S. Ren, J. Sun, Identity mappings in deep residual networks, in European
Conference on Computer Vision (Springer, 2016), pp. 630–645
13. R. Kafieh, H. Rabbani, S. Kermani, A review of algorithms for segmentation of optical coher-
ence tomography from retina. J. Med. Signals Sens. 3(1), 45 (2013)
14. S.A. Kamran, S. Saha, A.S. Sabbir, A. Tavakkoli, Optic-net: a novel convolutional neural
network for diagnosis of retinal diseases from optical tomography images, in 2019 18th IEEE
International Conference On Machine Learning And Applications (ICMLA) (2019), pp. 964–
971
15. S.P.K. Karri, D. Chakraborty, J. Chatterjee, Transfer learning based classification of opti-
cal coherence tomography images with diabetic macular edema and dry age-related macular
degeneration. Biomed. Opt. Express 8(2), 579–592 (2017)
16. D.S. Kermany, M. Goldbaum, W. Cai, C.C. Valentim, H. Liang, S.L. Baxter, A. McKeown, G.
Yang, X. Wu, F. Yan et al., Identifying medical diagnoses and treatable diseases by image-based
deep learning. Cell 172(5), 1122–1131 (2018)
17. A. Lang, A. Carass, M. Hauser, E.S. Sotirchos, P.A. Calabresi, H.S. Ying, J.L. Prince, Retinal
layer segmentation of macular oct images using boundary classification. Biomed. Opt. Express
4(7), 1133–1152 (2013)
18. C.S. Lee, D.M. Baughman, A.Y. Lee, Deep learning is effective for classifying normal versus
age-related macular degeneration oct images. Ophthalmol. Retin. 1(4), 322–327 (2017)
19. J.Y. Lee, S.J. Chiu, P.P. Srinivasan, J.A. Izatt, C.A. Toth, S. Farsiu, G.J. Jaffe, Fully automatic
software for retinal thickness in eyes with diabetic macular edema from images acquired by
cirrus and spectralis systems. Investig. ophthalmol. Vis. Sci. 54(12), 7595–7602 (2013)
20. K. Lee, M. Niemeijer, M.K. Garvin, Y.H. Kwon, M. Sonka, M.D. Abramoff, Segmentation
of the optic disc in 3-d OCT scans of the optic nerve head. IEEE Trans. Med. Imaging 29(1),
159–168 (2010)
21. G. Lemaître, M. Rastgoo, J. Massich, C.Y. Cheung, T.Y. Wong, E. Lamoureux, D. Milea, F.
Mériaudeau, D. Sidibé, Classification of sd-oct volumes using local binary patterns: experi-
mental validation for dme detection. J. Ophthalmol. 2016 (2016)
22. X. Chen, M. Niemeijer, L. Zhang, K. Lee, M.D. Abràmoff, M. Sonka, 3d segmentation of fluid-
associated abnormalities in retinal oct: Probability constrained graph-search-graph-cut. IEEE
Trans. Med. Imaging 31(8), 1521–1531 (2012)
23. A. Mishra, A. Wong, K. Bizheva, D.A. Clausi, Intra-retinal layer segmentation in optical
coherence tomography images. Opt. Express 17(26), 23719–23728 (2009)
24. H. Nguyen, A. Roychoudhry, A. Shannon, Classification of diabetic retinopathy lesions from
stereoscopic fundus images, in Proceedings of the 19th Annual International Conference of
the IEEE Engineering in Medicine and Biology Society.’Magnificent Milestones and Emerging
Opportunities in Medical Engineering (Cat. No. 97CH36136), vol. 1 (IEEE, 1997), pp. 426–
428

25. G. Panozzo, B. Parolini, E. Gusson, A. Mercanti, S. Pinackatt, G. Bertoldo, S. Pignatto, Diabetic
macular edema: an oct-based classification. Semin. Ophthalmol. 19, 13–20 (Taylor & Francis)
(2004)
26. G. Quellec, K. Lee, M. Dolejsi, M.K. Garvin, M.D. Abramoff, M. Sonka, Three-dimensional
analysis of retinal layer texture: identification of fluid-filled regions in sd-oct of the macula.
IEEE Trans. Med. imaging 29(6), 1321–1330 (2010)
27. O. Russakovsky, J. Deng, H. Su, J. Krause, S. Satheesh, S. Ma, Z. Huang, A. Karpathy, A.
Khosla, M. Bernstein et al., Imagenet large scale visual recognition challenge. Int. J. Comput.
Vis. 115(3), 211–252 (2015)
28. C.I. Sánchez, R. Hornero, M.I. Lopez, J. Poza, Retinal image analysis to detect and quantify
lesions associated with diabetic retinopathy, in The 26th Annual International Conference of
the IEEE Engineering in Medicine and Biology Society, vol. 1 (IEEE, 2004), pp. 1624–1627
29. M. Sandler, A. Howard, M. Zhu, A. Zhmoginov, L.C. Chen, Mobilenetv2: inverted residuals
and linear bottlenecks, in Proceedings of the IEEE Conference on Computer Vision and Pattern
Recognition, pp. 4510–4520 (2018)
30. L. Sifre, S. Mallat, Rigid-motion scattering for image classification. Ph.D. thesis, vol. 1, no. 3
(2014)
31. K. Simonyan, A. Zisserman, Very deep convolutional networks for large-scale image recogni-
tion (2014). arXiv preprint arXiv:1409.1556
32. P.P. Srinivasan, L.A. Kim, P.S. Mettu, S.W. Cousins, G.M. Comer, J.A. Izatt, S. Farsiu, Fully
automated detection of diabetic macular edema and dry age-related macular degeneration from
optical coherence tomography images. Biomed. Opt. Express 5(10), 3568–3577 (2014)
33. D.S.W. Ting, G.C.M. Cheung, T.Y. Wong, Diabetic retinopathy: global prevalence, major risk
factors, screening practices and public health challenges: a review. Clin. Exp. Ophthalmol.
44(4), 260–277 (2016)
34. M. Treder, J.L. Lauermann, N. Eter, Automated detection of exudative age-related macular
degeneration in spectral domain optical coherence tomography using deep learning. Graefe’s
Arch. Clin. Exp. Ophthalmol. 256(2), 259–265 (2018)
35. F. Wang, M. Jiang, C. Qian, S. Yang, C. Li, H. Zhang, X. Wang, X. Tang, Residual attention
network for image classification, in Proceedings of the IEEE Conference on Computer Vision
and Pattern Recognition, pp. 3156–3164 (2017)
36. M.A. Wani, F.A. Bhat, S. Afzal, A.I. Khan, Advances in Deep Learning (Springer, 2020)
37. W.L. Wong, X. Su, X. Li, C.M.G. Cheung, R. Klein, C.Y. Cheng, T.Y. Wong, Global preva-
lence of age-related macular degeneration and disease burden projection for 2020 and 2040: a
systematic review and meta-analysis. Lancet Glob. Health 2(2), e106–e116 (2014)
38. J.W. Yau, S.L. Rogers, R. Kawasaki, E.L. Lamoureux, J.W. Kowalski, T. Bek, S.J. Chen, J.M.
Dekker, A. Fletcher, J. Grauslund et al., Global prevalence and major risk factors of diabetic
retinopathy. Diabetes Care 35(3), 556–564 (2012)
39. F. Yu, V. Koltun, Multi-scale context aggregation by dilated convolutions (2015). arXiv preprint
arXiv:1511.07122
Three-Stream Convolutional Neural
Network for Human Fall Detection

Guilherme Vieira Leite, Gabriel Pellegrino da Silva, and Helio Pedrini

Abstract Lower child mortality rates, advances in medicine, and cultural changes
have increased life expectancy to above 60 years in developed countries. Some
countries expect that, by 2030, 20% of their population will be over 65 years old. The
quality of life at this advanced age is highly dictated by the individual’s health, which
will determine whether the elderly can engage in important activities to their well-
being, independence, and personal satisfaction. Old age is accompanied by health
problems caused by biological limitations and muscle weakness. This weakening
facilitates the occurrence of falls, which are responsible for the deaths of approxi-
mately 646,000 people worldwide every year. Even a minor fall can still cause
fractures, broken bones, or damage to soft tissues, which will not heal completely.
Injuries and damages of this nature, in turn, will consume the self-confidence of the
individual, diminishing their independence. In this work, we propose a method capa-
ble of detecting human falls in video sequences using multi-channel convolutional
neural networks (CNN). Our method makes use of a 3D CNN fed with features pre-
viously extracted from each frame to generate a vector for each channel. Then, the
vectors are concatenated, and a support vector machine (SVM) is applied to classify
the vectors and indicate whether or not there was a fall. We experiment with four
types of features, namely: (i) optical flow, (ii) visual rhythm, (iii) pose estimation, and
(iv) saliency map. The benchmarks used, the UR Fall Detection Dataset (URFD) [33]
and the Fall Detection Dataset (FDD) [12], are publicly available and our results are
compared to those in the literature. The metrics selected for evaluation are balanced

G. V. Leite · G. P. da Silva · H. Pedrini (B)


Institute of Computing, University of Campinas, Campinas, SP, Brazil
e-mail: helio@ic.unicamp.br
G. V. Leite
e-mail: guilherme.vieira.leite@gmail.com
G. P. da Silva
e-mail: gpsunicamp016@gmail.com
© The Editor(s) (if applicable) and The Author(s), under exclusive license
to Springer Nature Singapore Pte Ltd. 2021
M. A. Wani et al. (eds.), Deep Learning Applications, Volume 2,
Advances in Intelligent Systems and Computing 1232,
https://doi.org/10.1007/978-981-15-6759-9_3

accuracy, accuracy, sensitivity, and specificity. Our results are competitive with those
obtained by the state of the art on both URFD and FDD datasets. To the authors’
knowledge, we are the first to perform cross-tests between the datasets in question
and to report results for the balanced accuracy metric. The proposed method is able
to detect falls in the selected benchmarks. Fall detection, as well as activity classi-
fication in videos, is strongly related to the network’s ability to interpret temporal
information and, as expected, optical flow is the most relevant feature for detecting
falls.

1 Introduction

Developed countries have reached a life expectancy of over 60 years of age [76].
Some countries in the European Union and China expect 20% of their population to
be over 65 by 2030 [25]. According to the World Health Organization [25, 78], this
increase in longevity is a result of scientific advances, medical discoveries, reduced child
mortality, and cultural changes.
However, although human beings are living longer, the quality of this life span
is mainly defined by their health, since it dictates an individual’s independence,
satisfaction, and the possibility of engaging in activities that are important to their
well-being. The relation between health and quality of life inspired several research
groups to try and develop assisting technologies focusing on the elderly population.

1.1 Motivation

Naturally, health problems will appear along with the aging process, mainly due to
the biological limitations of the human body and muscle weakness. As a side effect,
this weakening increases the elderly's chances of suffering a fall. Falls are the second
leading cause of unintentional injury death worldwide, killing around 646,000 people every
year [25]. Reports indicate that between 28 and 35% of the population over 65 years
old falls at least once a year, and this percentage rises to 32–42% for individuals over
70 years old.
In an attempt to summarize which events may lead to a fall, Lusardi et al. [46]
reported risk factors, including how falls occur, who is falling, and some precautions
to avoid them. The effects of a fall accident on this fragile population can unleash
a chain reaction of physiological and psychological damages, which consequently
could decrease the elderly’s self-confidence and independence. Even if the accident
is a small fall, it can break or fracture bones and damage soft tissues that, due to old
age, may never fully recover.
To avoid serious injuries, the elderly population needs constant care and monitoring,
especially because an accident's effects are related to the time between
its occurrence and the beginning of adequate care. That being said, qualified care-

Fig. 1 Diagram illustrating the main components of a monitoring system

givers are expensive, especially when added to the already inflated budget of health
care in advanced age, which might lead families to relocate the elderly. Relocation
is a common practice, in which the surrounding family moves the elderly from their
dwelling place to the family’s home, causing discomfort and stress to adapt to a new
environment and routine.
Based on the scenario described previously, a technological solution can be
devised in the form of an emergency system that would reliably trigger qualified
assistance automatically. Thus, the response time between the accident and care
would be reduced, allowing the elderly to dwell in their own home, keeping their
independence and self-esteem. The diagram in Fig. 1 illustrates the components of
such system, also known as an assisted daily living (ADL) system.
The system is fed by a range of sensors, installed around the home or attached to
the subject’s body, which monitors information such as blood pressure, heartbeats,
body temperature, blood glucose levels, acceleration, and posture. These devices
are connected to a local processing center that, in addition to continuously building
technical reports on the health of the elderly, also alerts a remote operator in case of an
emergency. Upon receiving the alert, the operator performs a situation check
and, upon verifying the need, dispatches medical assistance.

1.2 Research Questions

Aware of the dangers and consequences that a fall represents for the elderly population
and, since a video-based fall detection system can benefit from deep learning techniques,
we propose a fall detection module, as part of an ADL system, which can detect falls
in video frame sequences. In implementing this proposal, we intend to answer the
following research questions:

1. Are the feature channels able to maintain enough temporal information so that a
network can learn its patterns?
2. Which feature channel contributes best to the problem of fall detection?
3. Does the effectiveness of the proposed method remain for other human fall
datasets?
4. Are three-dimensional (3D) architectures more discriminative than their two-
dimensional (2D) equivalents for this problem?

1.3 Contributions

The main contribution of this work is the multi-channel model to detect human
falls [8, 9, 37]. Each channel represents a feature that we judge to be descriptive of a
fall event and, upon being ensembled, they produce a better estimation of a fall.
The model was tested on two publicly available datasets, and is open-sourced [38].
In addition, we present an extensive comparison between different channel combina-
tions and the employed architecture, whose best results are comparable to the state of
the art. Finally, we also discuss the implications of simulated datasets as a benchmark
for fall detection methods.

1.4 Chapter Layout

The remainder of this work is structured as follows. Section 2 reviews the exist-
ing approaches in the related literature. Section 3 presents the main concepts that
were used in the implementation of this work. Section 4 describes in detail the pro-
posed methodology and its steps. Section 5 exhibits the selected datasets used in
the experiments, our evaluation metrics, the experiments carried out, their results,
and discussions. Finally, Sect. 6 presents some final considerations and proposals for
future work.

2 Related Work

In the following subsections, we review the work related to fall detection and
the methods employed. The works were split into two groups, based on the sensors used: (i)
methods without videos and (ii) methods with videos.

2.1 Methods Without Videos

The works in this group utilize various sensors to obtain data, which can be a watch,
accelerometer, gyroscope, heart rate sensor or a smartphone, that is, any sensor that
is not a camera.
Khin et al. [28] developed a triangulation system that uses several presence sensors
installed in the monitored room of a home. The presence of someone in the room
causes a disturbance in the sensors and a fall would trigger an activation pattern.
Despite not reporting the results, the authors stated that they tested this configuration
and that the sensors detected different patterns for different actions, indicating that
the solution could be used to detect falls.
Kukharenko and Romanenko [31] obtained their data from a wrist device. A
threshold was used to detect impact or a “weightless” state, and after this detection,
the algorithm waits a second and analyzes the information obtained. The method was
tested on some volunteers, who reported complaints such as forgetting to wear the
device and the discomfort it caused.
Kumar et al. [32] placed a sensor attached to a belt on the person’s waist and
compared four methods to detect falls: threshold, support vector machines (SVM),
K-nearest neighbors (KNN), and dynamic time warping (DTW). The authors also
commented on the importance of sensors attached to the body, since they monitor
the individual constantly and do not present blind spots, as opposed to
cameras.
Vallejo et al. [72] developed a deep neural network to classify the sensor’s data.
The chosen sensor is a gyroscope worn at the waist and the network is composed
of three hidden layers, with five neurons each. The authors carried out experiments
with adults aged 19–56 years.
Zhao et al. [84] collected data from a gyroscope attached to the individual’s waist
and used a decision tree to classify the information. The experiments were performed
on five random adults.
Zigel et al. [86] used accelerometers and microphones as sensors, however, the
sensors were installed in the environment, instead of being attached to the subject.
The sensors would detect vibrations and feed a quadratic classifier. The tests were
executed with a test dummy, which was released from an upright position.

2.2 Methods With Videos

In this section, we grouped the methods whose main data sources are video sequences
from cameras. Despite using similar sensors, these methods present a variety of
solutions, such as the following works that used threshold-based activation techniques.
To isolate the human silhouette, Lee and Mihailidis [36] performed a background
subtraction along with a region extraction. After this, the posture was determined
through thresholds on the silhouette's perimeter, the speed of the silhouette's
center, and the Feret diameter. The solution was tested on a dataset created by
the authors.
Nizam et al. [53] used two thresholds, the first verified if the body speed was high
and, if so, a second threshold verified whether the joints’ position was close to the
ground. The joints’ position was obtained after subtracting the background, with a
Kinect camera. The experiments were carried out on a dataset created by the authors.
Sase and Bhandari [57] applied a threshold in which a fall was detected if the
region of interest was smaller than one-third of the individual's height. The region
of interest was obtained by background extraction and the method was tested on the
basis URFD [33]. Bhandari et al. [5] applied a threshold on the speed and direction of
the region of interest. A combination of Shi-Tomasi and Lucas–Kanade was applied
to determine the region of interest. The authors tested the approach in the URFD [33]
set and reported 95% accuracy.
Another widely used classification technique is the SVM, used by Abobakr et
al. [2]. The method used depth information to subtract the background of the video
frame, applied a random forest algorithm to estimate the posture and classified it
with SVM. Fan et al. [17] also separated the picture between the foreground and
background and fitted an ellipse to the silhouette of the human body found. From the
ellipse, six features were extracted and served to a slow feature function. An SVM
classified the output of the slow feature function and the experiments were executed
on the SDUFall [47] dataset.
Harrou et al. [23] used an SVM classifier that received features extracted from
video frames. During the tests, the authors compared the SVM with a multivariate
exponentially weighted moving average (MEWMA) and tested the solution on the
datasets URFD [33] and FDD [12].
Mohd et al. [50] fed an SVM classifier with information on the height, speed,
acceleration, and position of the joints and performed tests on three datasets: TST Fall
Detection [20], URFD [33] and Fall Detection by Zhang [83]. Panahi and Ghods [56]
subtracted the background from the depth information, fitted an ellipse to the shape of
the individual, classified the ellipse with SVM and performed tests on the URFD [33]
dataset.
Concerned with privacy, the following works argued that solutions to detect falls
should offer anonymity options. Therefore, Edgcomb and Vahid [16] tested the effec-
tiveness of a binary tree classifier over a time series. The authors compared different
means to hide identity, such as blurring, extracting the silhouette, replacing the indi-
vidual with an opaque ellipse or an opaque box. They conducted tests on their dataset,

with 23 videos recorded. Lin et al. [42] investigated a solution focused on privacy
using a silhouette. They applied a K-NN classifier in addition to a timer that checks
whether the individual’s pose has returned to normal. The tests were performed by
laboratory volunteers.
Some studies used convolutional neural networks, such as the case of
Anishchenko [4], which implemented an adaptation of the AlexNet architecture
to detect falls in the FDD [12] dataset. Fan et al. [18] used a CNN to monitor and
assess the degree of completeness of an event. A stack of video frames was used in
a VGG-16 architecture and its result was associated with the first frame in the stack.
The method was tested on two datasets: FDD [12] and Hockey Fights [52]. Their
results were reported in terms of completeness of the falls.
Huang et al. [26] used the OpenPose algorithm to obtain the coordinates of the
body joints. Two classifiers (SVM and VGG-16) were compared to classify the coor-
dinates. The experiments were carried out on the datasets URFD [33] and FDD [12].
Li et al. [40] created a modification of CNN’s architecture, AlexNet. The solution
was tested on the URFD [33] dataset, also the authors reported that the solution
classified between ADLs and falls in real time.
Min et al. [49] used an R-CNN (CNN of regions) to analyze a scene, which
generates spatial relationships between furniture and the human being on the scene,
and then classified the spatial relationship between them. The authors experimented
on three datasets: URFD [33], KTH [58], and a dataset created by them. Núñez-
Marcos et al. [54] performed the classification with a VGG-16. The authors calculated
the dense optical flow, which served as a characteristic for the network to classify.
They tested the method on the URFD [33] and FDD [12] databases.
Coincidentally, all works that used recurrent neural networks used the same archi-
tecture, long-short term memory (LSTM). Lie et al. [41] applied a recurrent neu-
ral network, with LSTM cells, to classify the individual’s posture. The stance was
extracted by a CNN and the experiments were carried out on a dataset created by the
authors.
Shojaei-Hashemi et al. [60] used a Microsoft Kinect device to obtain the indi-
vidual’s posture information and an LSTM as a classifier. The experiments were
performed on the NTU RGB+D dataset. Furthermore, the authors reported one advan-
tage of using the Kinect, since the posture extraction could be achieved in real time.
Lu et al. [43] proposed the application of an LSTM right after a 3D CNN. The authors
performed tests on the URFD [33] and FDD [12] datasets.
Other machine learning algorithms, such as the K-nearest neighbors, were also
used to detect falls. Kwolek and Kepski [34] made use of a combination of an
accelerometer and Kinect. Once the accelerometer surpassed a threshold, a fall alert
was raised, and only then, the Kinect camera started capturing frames of the scene’s
depth, which was used by a second classifier. The authors compared the classification
of the frames between KNN and SVM and tested on two datasets, the URFD [33]
and on an independent one.
Sehairi et al. [59] developed a finite state machine to estimate the position of the
human head from the extracted silhouette. The tests were performed on the FDD [12]
dataset.

The application of Markov filters was also used to detect falls, as in the work
of Anderson et al. [3], in which the individual’s silhouette was extracted so that his
characteristics were classified by the Markov filter. The experiments were carried
out on their dataset.
Zerrouki and Houacine [81] described the characteristics of the body through
curvelet coefficients and the ratio between areas of the body. An SVM classifier
performed the posture classification and the Markov filter discriminated between
falls or not falls. The authors reported experiments on the URFD [33] and FDD [12]
datasets.
In addition to the methods mentioned above, the following works made use of
several techniques, such as Yu et al. [80], which obtained their characteristics by
applying head tracking techniques and analysis of shape variation. The characteristics
served as input to a Gaussian classifier. The authors created a dataset for the tests.
Zerrouki et al. [82] segmented the frames between the foreground and background
and applied another segmentation on the human body, dividing it into five partitions.
The body segmentations were fed to an AdaBoost classifier, which obtained 96%
accuracy on the URFD [33] dataset. Finally, Xu et al. [79] published a survey that
evaluates several fall detection systems.

3 Basic Concepts

In this section, we describe in detail the concepts necessary to understand the pro-
posed methodology.

3.1 Deep Neural Networks

Deep neural networks (DNNs) are a class of machine learning algorithms, in which
several layers of processing are used to extract and transform characteristics from
the input data and the backpropagation algorithm enables the network to learn the
complex patterns of the data. The input information for each layer is the same as the
output of the previous one, except for the first layer, in which data is input, and the
last layer, from which the outputs are extracted [22]. This structure is not necessarily
fixed; some layers can have two other layers as input or several outputs.
Deng and Yu [15] mentioned some reasons for the growing popularity of deep net-
works in recent years, which include their results in classification problems, improve-
ments in graphic processing units (GPUs), the appearance of tensor processing units
(TPUs), and the amount of data available digitally.
The layers of a deep network can be organized in different ways, to suit the task
at hand. The manner in which a DNN is organized is called its “architecture”, and some of

them have become well known because of their performance in image classifica-
tion competitions. Some of them are AlexNet [30], LeNet [35], VGGNet [62] and
ResNet [24].

3.2 Convolutional Neural Networks

Convolutional neural networks (CNNs) are a subtype of deep networks; their structure
is similar to that of a DNN, such that information flows from one layer to the next.
However, in a CNN the data is also processed by convolutional layers, which apply
various convolution operations and resize the data before sending it on to the next
layer.
These convolution operations allow the network to learn low-level features in
the first layers and merge them in the following layers to learn high-level features.
Although not mandatory, usually at the very end of a convolutional network there
are some fully connected layers.
To this work’s scope, two CNN architectures are relevant: (i) VGG-16 [62] and
(ii) Inception [66]. The VGG-16 won the localization category of the ImageNet Large Scale Visual
Recognition Challenge (ILSVRC) 2014, with a 7.3% top-5 error in the classification
category. Its choice of using small filters, convolutions of 3 × 3, stride 1, padding 1 and
max-pooling of 2 × 2 with stride 2, allowed the network to be deeper, without being
computationally prohibitive. The VGG-16 has 16 layers and 138 million parameters,
which is considered small for deep networks. The largest load of computations in this
network occurs in the first layers, since, after them, the layers of pooling considerably
reduce the load to the deeper layers. Figure 2 illustrates the VGG-16 architecture.
The second architecture, Inception V1 [66], was the winner of ILSVRC 2014,
in the same year as VGG-16, however, on the classification category, with an error
of 6.7%. This network was developed to be deeper and, at the same time, more
computationally efficient. The Inception architecture has 22 layers and only 5 million
parameters. Its construction consists of stacking several modules, called Inception,
illustrated in Fig. 3a.
The modules were designed to create something of a network within the network,
in which several convolutions and max-pooling operations are performed in parallel
and, at the end of these, the features are concatenated to be sent to the next module.
However, if the network was composed of Inception modules, as illustrated in Fig. 3a,
it would perform 850 million operations total. To reduce this number, bottlenecks
Fig. 2 Layout of the VGG-16 architecture



Fig. 3 Inception modules: a module without bottleneck; b module with bottleneck. Adapted from Szegedy et al. [66]

were created. Bottlenecks reduce the number of operations to 358 million. They are
1 × 1 convolutions that preserve the spatial dimension while decreasing the depth
of the features. To do so, they were placed before the convolutions 3 × 3, 5 × 5 and
after the max-pooling 3 × 3, as illustrated in Fig. 3b.
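As a rough illustration of such a module, the sketch below builds one Inception block with 1 × 1 bottlenecks using the Keras functional API; the filter counts are illustrative choices, not necessarily those used in [66].

```python
from tensorflow.keras import layers

def inception_module(x, f1=64, f3r=96, f3=128, f5r=16, f5=32, fpool=32):
    """One Inception module with bottlenecks (cf. Fig. 3b)."""
    b1 = layers.Conv2D(f1, 1, padding="same", activation="relu")(x)
    b3 = layers.Conv2D(f3r, 1, padding="same", activation="relu")(x)      # 1x1 bottleneck
    b3 = layers.Conv2D(f3, 3, padding="same", activation="relu")(b3)
    b5 = layers.Conv2D(f5r, 1, padding="same", activation="relu")(x)      # 1x1 bottleneck
    b5 = layers.Conv2D(f5, 5, padding="same", activation="relu")(b5)
    bp = layers.MaxPooling2D(3, strides=1, padding="same")(x)
    bp = layers.Conv2D(fpool, 1, padding="same", activation="relu")(bp)   # bottleneck after pooling
    # Branch outputs are concatenated along the channel axis and passed on.
    return layers.Concatenate()([b1, b3, b5, bp])
```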

3.2.1 Transfer Learning

Transfer learning consists of reusing a network whose weights were trained in another
context, usually a related and more extensive context. In addition to usually improving
a network's performance, the technique also helps to decrease the convergence
time and is useful in scenarios without enough training data available [85].
Typically in classification problems, the transfer happens from the ImageNet [14]
dataset, which is one of the largest and best-known datasets. The goal is for the

network to learn complex patterns that are generic enough to be used in
another context.
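A minimal Keras sketch of this initialization is shown below, assuming an Inception backbone and a binary (fall / not fall) output; as described later in Sect. 4.2, no layers are frozen afterwards. The 3D streams used in our method would adapt such 2D ImageNet weights, which is not shown here.

```python
from tensorflow.keras import layers, models
from tensorflow.keras.applications import InceptionV3

# Backbone initialized with ImageNet weights (the transfer step in Fig. 11).
base = InceptionV3(weights="imagenet", include_top=False, pooling="avg",
                   input_shape=(299, 299, 3))
outputs = layers.Dense(2, activation="softmax")(base.output)  # fall / not fall
model = models.Model(base.input, outputs)
# All layers remain trainable and are subsequently fine-tuned on the fall dataset.
```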

3.3 Definition of Fall

The literature does not provide a universal definition of fall [29, 45, 67], however,
some health agencies have created their own definitions, which can be used to describe a
general idea of a fall.
The Joint Commission [68] defines a fall as “[...] a fall may be described as an
unintentional change in position coming to rest on the ground, floor, or onto the
next lower surface (e.g., onto a bed, chair or bedside mat). [...]”. The World Health
Organization [77] defines it as “[...] an event which results in a person coming to
rest inadvertently on the ground or floor or other lower level. [...]”. The definition
of the National Center for Veterans Affairs [70] is as “loss of upright position that
results in landing on the floor, ground, or an object or furniture, or a sudden, uncon-
trolled, unintentional, non-purposeful, downward displacement of the body to the
floor/ground or hitting another object such as a chair or stair [...]”. In this work, we
describe a fall as an involuntary movement that results in an individual landing on
the ground.

3.4 Optical Flow

Optical flow is a technique that deduces pixel movement, caused by the displacement
of the object or the camera. It is a vector that represents the movement of a region,
extracted from a sequence of frames. It assumes that the pixels will not leave the
frame region, and is a local method, hence the difficulty of computing it on uniform
regions.
The extraction of the optical flow is performed by comparing two consecutive
frames, and its representation is a vector of direction and magnitude. Consider I
a video frame and I(x, y, t) a pixel in this frame. A frame analyzed at a future
time dt is described in Eq. 1 as a function of the pixel I(x, y, t) displaced
by (dX, dY). Equation 2 is obtained from a Taylor expansion divided by dt and
involves the gradients f_x, f_y, and f_t (Eq. 3). The components of the optical flow are
the values of u and v (Eq. 4) and can be obtained by several methods such as Lucas–
Kanade [44] or Farnebäck [19]. Figure 4 illustrates some optical flow frames.

Fig. 4 Examples of extracted optical flow. Each optical flow frame was extracted from the above
frame and its next in sequence. The pixels colors indicate the movement direction, and its brightness
relates to the magnitude of the movement

$$I(x, y, t) = I(x + dX,\, y + dY,\, t + dT) \qquad (1)$$

$$f_x u + f_y v + f_t = 0 \qquad (2)$$

$$f_x = \frac{\partial f}{\partial x}, \quad f_y = \frac{\partial f}{\partial y}, \quad f_t = \frac{\partial f}{\partial t} \qquad (3)$$

$$u = \frac{\partial x}{\partial t}, \quad v = \frac{\partial y}{\partial t} \qquad (4)$$
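As an illustration, dense optical flow can be extracted with OpenCV's implementation of Farnebäck's method [19]; the sketch below uses commonly used parameter values and a hypothetical input file, and converts the (u, v) field into the HSV visualization shown in Fig. 4.

```python
import cv2
import numpy as np

cap = cv2.VideoCapture("video.avi")            # hypothetical input video
ok, prev = cap.read()
prev_gray = cv2.cvtColor(prev, cv2.COLOR_BGR2GRAY)

while True:
    ok, frame = cap.read()
    if not ok:
        break
    gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
    flow = cv2.calcOpticalFlowFarneback(prev_gray, gray, None,
                                        0.5, 3, 15, 3, 5, 1.2, 0)
    # Hue encodes the direction of (u, v); value encodes its magnitude.
    mag, ang = cv2.cartToPolar(flow[..., 0], flow[..., 1])
    hsv = np.zeros_like(frame)
    hsv[..., 0] = ang * 180 / np.pi / 2
    hsv[..., 1] = 255
    hsv[..., 2] = cv2.normalize(mag, None, 0, 255, cv2.NORM_MINMAX)
    flow_frame = cv2.cvtColor(hsv, cv2.COLOR_HSV2BGR)
    prev_gray = gray
```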

3.5 Visual Rhythm

Visual rhythm is an encoding technique that aims at creating a temporal relation
between frames without losing their spatial information. Its representation consists
of a single image summarizing the entire video, in a way that each video frame
contributes as a column on the final image [51, 69, 71].
The construction of the visual rhythm happens as each video frame is traversed in
a zigzag pattern, from its lower left diagonal to its upper right diagonal, as illustrated
in Fig. 5a. Each frame processed in zigzag generates a column of pixels, which
is concatenated with the other columns to form the visual rhythm (Fig. 5b). The
dimensions of the rhythm image are W × H , in which the width W is the number
of frames in the video and the height H is the length of the zigzag path. Figure 6
illustrates some extracted visual rhythms.
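A simplified sketch of this construction is given below; for brevity it samples each frame along its main anti-diagonal (lower left to upper right) instead of the full zigzag path described above, which would simply produce a longer column.

```python
import numpy as np

def visual_rhythm(frames):
    """Build a visual-rhythm image with one column per grayscale frame."""
    columns = []
    for frame in frames:
        h, w = frame.shape
        n = min(h, w)
        rows = np.linspace(h - 1, 0, n).astype(int)   # bottom to top
        cols = np.linspace(0, w - 1, n).astype(int)   # left to right
        columns.append(frame[rows, cols])             # anti-diagonal sample
    # Width = number of frames, height = length of the sampling path.
    return np.stack(columns, axis=1)
```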

3.6 Saliency Map

In the context of image processing, a saliency map is a feature of the image that
represents regions of interest in an image. The map is generally presented in shades

Fig. 5 Visual rhythm construction process. On the left, the zigzag manner in which each frame is
traversed. On the right, the rhythm construction through the column concatenation

Fig. 6 Visual rhythm examples. Each frame on the first row illustrates a different video and, bellow,
the visual rhythm for the corresponding entire video

of gray so that regions with low interest are black and regions with high interest are
white. In the context of deep learning, the saliency map is the activation map of a
classifier, highlighting the regions that greater contributed to the output classification.
In its origin, the salience map was extracted as a way of understanding what the deep
networks were learning and it is still used in that way, as in the work of Li et al. [39].
The saliency map was used by Zuo et al. [87] as a feature to classify actions
from an egocentric point of view, in which the source of information is a camera
that corresponds to the subject's first-person view. Using the saliency map in the
egocentric context is rooted in the assumption that important aspects of the
action take place in the foreground, instead of the background.
The saliency map can be obtained in several ways, as shown by Smilkov et al. [64],
Sundararajan et al. [65] and Simonyan et al. [63]. Figure 7 illustrates the extraction
of the saliency map.

Fig. 7 Saliency map examples. Pixels varying from black to white, according to the region impor-
tance to the external classifier
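A rough SmoothGrad-style sketch [64] in TensorFlow is given below: the gradient of the class score is averaged over several noisy copies of the input. The noise level and the number of samples are illustrative choices, and `model` stands for any Keras classifier.

```python
import tensorflow as tf

def smoothgrad(model, image, class_idx, n_samples=25, noise_std=0.15):
    """Return an H x W saliency map for `image` (H x W x C float array)."""
    image = tf.convert_to_tensor(image[None, ...], dtype=tf.float32)
    grads = tf.zeros_like(image)
    for _ in range(n_samples):
        noisy = image + tf.random.normal(tf.shape(image), stddev=noise_std)
        with tf.GradientTape() as tape:
            tape.watch(noisy)
            score = model(noisy)[:, class_idx]     # score of the target class
        grads += tape.gradient(score, noisy)
    # Average over samples and collapse the channel axis.
    return tf.reduce_max(tf.abs(grads / n_samples), axis=-1)[0].numpy()
```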

3.7 Posture Estimation

Posture estimation is a technique to derive the posture of one or more human beings. Different sensors
are used as input to this technique, such as depth sensors from a Microsoft Kinect or
images from a camera.
The algorithm proposed by Cao et al. [7], OpenPose, is notable for its effectiveness
in estimating the pose of individuals in video frames. OpenPose operates with a two-
stage network in search of 18 body joints. In its first stage, the method creates
probability maps of the joints' positions and the second stage predicts affinity fields
between the limbs found. The affinity is represented by a 2D vector, which encodes
the position and orientation of each body limb. Figure 8 shows the posture estimation
of some frames.

Fig. 8 Example of the posture estimation. Each circle represents a joint, whereas the edges represent
the limbs. Each limb has a specific color assigned to it

4 Proposed Method

In this section, we describe the proposed method, an Inception 3D multi-channel
network to detect falls in video frames.
Figure 9 shows the overview of the method, that is based on the hypothesis raised
by Goodale and Milner [21], in which the human visual cortex consists of two
parts that focus on processing different aspects of vision. This same hypothesis
inspired Simonyan and Zisserman [61] to test neural networks with various channels
of information to simulate the visual cortex.
In the methodology illustrated in Fig. 9, each stream is a separate neural network,
which was trained on a specific feature. For instance, the optical flow stream is a
network fed exclusively with the optical flow extracted from the frames, whereas
the saliency stream is a different network fed exclusively with the extracted saliency
frames, and so on. Since the streams are independent of one another, we explored
different architectures and stream combinations in our experiments. Instead of always
employing three streams, we also tested whether using only two streams would
produce a better classifier.

4.1 Preprocessing

The deep learning literature consistently reports the positive influence of
preprocessing steps. In this sense, some processes were
executed to better tackle the task in question.
In this work, the preprocessing step, represented by the green block in Fig. 9,
consists of extracting features that can capture the various aspects of a fall and
applying data augmentation techniques.

4.1.1 Feature Extraction

The following features were extracted and later inputted to the network in a specific
manner. The posture estimation was extracted using a bottom-up approach, with the
OpenPose algorithm by Cao et al. [7]. The extracted frames were fed into the network
one at a time and the inference results are obtained frame by frame (Fig. 8). Regarding
the visual rhythm, an algorithm was implemented by us, such that each video has only
one visual rhythm. This rhythm frame was fed to the network repeatedly, so that its
inference output could be paired with the other features (Fig. 6). The saliency map was
obtained using the SmoothGrad technique, proposed and implemented by Smilkov
et al. [64], which acts on top of a previously existing technique by Sundararajan et
al. [65]. The frames were fed to the network, once again, one by one (Fig. 7).
The optical flow extraction was carried out with the algorithm proposed by
Farnebäck [19], which describes the dense optical flow (Fig. 4). As a fall event hap-


Fig. 9 Overview of the proposed method that illustrates the training and test phases, their steps
and the information flow throughout the method


Fig. 10 Illustration of the sliding window and how it moves to create each stack

pens throughout several frames and the flow represents only the relationship between
two of them, we employed a sliding window approach, suggested by Wang et al. [74].
The sliding window feeds the network with a stack of ten frames of optical flow. The
first stack contains frames 1 to 10, the second stack frames 2 to 11, and so on with a
stride of 1 (Fig. 10), so each video has N − 10 + 1 stacks, where N is the number of
frames in the video; if fewer than ten frames remain at the end of a video, they do not
contribute to the evaluation. The resulting inference of each stack was associated with
the first frame of the stack, so that the optical flow could be paired with the other
channels in the network.
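A minimal sketch of this stacking is shown below; `flow_frames` is assumed to be the list of extracted H × W × 2 optical flow fields of one video.

```python
import numpy as np

def flow_stacks(flow_frames, length=10):
    """Slide a window of `length` flow frames with stride 1 over the video."""
    stacks, anchors = [], []
    for i in range(len(flow_frames) - length + 1):    # N - 10 + 1 stacks
        stacks.append(np.concatenate(flow_frames[i:i + length], axis=-1))
        anchors.append(i)        # the stack's label is tied to its first frame
    return np.array(stacks), anchors
```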

4.1.2 Data Augmentation

Data augmentation techniques were used, when applicable, in the training phase.
The following augmentations were employed: vertical axis mirroring, perspective
transform, cropping and adding mirrored borders, adding values between −20 and
20 to pixels, and adding values between −15 and 15 to the hue and saturation.
The whole augmentation process was done only over the RGB channel, as the other
channels would suffer negatively from it. For instance, the optical flow information
depends strictly on the relationship between frames and its magnitude is expressed by
the brightness of the pixel; mirroring an optical flow frame would break the continuity
between frames, and adding values to its pixels would distort the magnitude of the vector.
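A hedged sketch of such a pipeline, using the imgaug library and applied only to the RGB frames, could look as follows; the transform ranges mirror the values listed above, while the remaining parameters are illustrative.

```python
import numpy as np
import imgaug.augmenters as iaa

augmenter = iaa.Sequential([
    iaa.Fliplr(0.5),                                          # vertical-axis mirroring
    iaa.PerspectiveTransform(scale=(0.01, 0.10)),             # perspective transform
    iaa.CropAndPad(percent=(-0.1, 0.1), pad_mode="reflect"),  # crop / mirrored borders
    iaa.Add((-20, 20)),                                       # shift pixel values
    iaa.AddToHueAndSaturation((-15, 15)),                     # shift hue and saturation
])

rgb_frames = [np.zeros((240, 320, 3), dtype=np.uint8)]        # placeholder frames
augmented = augmenter(images=rgb_frames)
```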

4.2 Training

Due to the small amount of available data to experiment on, the training phase of
our method requires the application of transfer learning techniques. Thus, the model
was first trained on the ImageNet [14] dataset and, later, on our selected fall dataset;
this whole process is illustrated in Fig. 11. Traditionally, other works might freeze some
layers between the transfer learning and the training; this was not the case here, as all the
layers were trained with the fall dataset.


Fig. 11 Transfer learning process. From left to right, initially the model has no trained weights,
then it is trained on the ImageNet [14] dataset and, finally, it is trained on the selected fall dataset

The same transfer learning was done for all feature channels, meaning that, inde-
pendently of the feature a channel would be trained on, its starting point was
the ImageNet training. After this, each channel was trained with its extracted feature
frames. Although we selected four features, we did not combine all of them at the
same time and kept it to at most three combined features at a time.

4.3 Test

The selected features in this work were previously used by some others in the related
literature [7, 69], however, they were employed in a single-stream manner. These
works, along with the work of Simonyan et al. [61], motivated us to propose
a multi-stream methodology that would join the different aspects of each feature, so
that a more robust description of the event could be achieved. This ensemble can be
accomplished in several ways, varying from a simple average between the channels,
through a weighted average, to some automatic methods, like the one used in this
work, the application of an SVM classifier.
The workflow of the test phase, illustrated on the right of Fig. 9, begins with the
feature extraction, just as in the training phase. After that, the weights obtained in
the training phase were loaded, and all layers of the model were frozen. Then, each
channel, with its specific trained model, performed inference on its input data. The
output vectors were concatenated and sent to the SVM classifier, which in turn
classified each frame as fall or not fall.
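A minimal sketch of this fusion step is given below; the per-stream outputs and labels are random placeholders with one row per frame, standing in for the vectors produced by the trained streams.

```python
import numpy as np
from sklearn.svm import SVC

# Placeholder per-stream prediction vectors (one row per frame).
flow_preds = np.random.rand(100, 2)
rgb_preds = np.random.rand(100, 2)
rhythm_preds = np.random.rand(100, 2)
labels = np.random.randint(0, 2, 100)          # 1 = fall, 0 = not fall

# Concatenate the stream outputs and classify each frame with an SVM.
features = np.concatenate([flow_preds, rgb_preds, rhythm_preds], axis=1)
svm = SVC(kernel="rbf").fit(features, labels)
frame_predictions = svm.predict(features)
```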

5 Experimental Results

In this section, we describe the datasets used in the experiments, the metrics selected
to evaluate our method, the executed experiments, and how our method stands against
others proposed.

5.1 Datasets

Upon reviewing the related literature, a few datasets were found; however, some were
not publicly available, hyperlinks to the data were inactive, or the authors did not
answer our contact attempts. This led us to select the following human fall datasets:
(i) URFD [33] and (ii) FDD [12].

5.1.1 URFD

Published by Kwolek and Kepski [33], the URFD dataset (University of Rzeszow Fall
Detection Dataset) is made up of 70 video sequences, 30 of which are falls and 40
of everyday activities. Each video has 30 frames per second (FPS), with a resolution
of 640 × 240 pixels and varying lengths.
The fall sequences were recorded with an accelerometer and two Microsoft Kinect
cameras, one camera has a horizontal view of the scene and one with a top–down view,
from the ceiling. The activities of daily living were recorded with a single horizontal
view camera and an accelerometer. The accelerometer information was excluded
from the experiments as it went beyond the scope of the project. As illustrated in
Fig. 12, the dataset has five ADL scenarios but a single fall scenario, in which the
camera angle and background are the same and only the actors in the scene change;
this lack of variety is further discussed in the experiments.
The dataset is annotated with the following information:
• Subject’s posture (not lying, lying on the floor and transition).
• Ratio between height and width of the bounding box.
• Ratio between maximum and minimum axes.
• Ratio of the subject’s occupancy in the bounding box.
• Standard deviation of the pixels to the centroid of the X and Z axes.
• Ratio between the subject’s height in the frame and the subject’s standing height.
• Subject’s height.
• Distance from the subject’s center to the floor.

Fig. 12 URFD’s [33] environments. a Fall scenarios; b ADLs scenarios

5.1.2 FDD

The FDD dataset (Fall Detection Dataset) was published by Charfi et al. [12] and
contains 191 video sequences, with 143 being falls and 48 being day-to-day activities.
Each video has 25 FPS, with a resolution of 320 × 240 pixels and varying lengths.
All sequences were recorded with a single camera, in four different environments:
home, coffee room, office, and classroom, illustrated in Fig. 13. In addition, the dataset
presents three experimentation protocols: (i) in which training and testing are created
with videos from the home and coffee room environments, (ii) in which the training
consists of videos from the coffee room and the test with videos from the office and
the classroom and (iii) where the training contains videos from the coffee room,
the office, and the classroom and the test contains videos from the office and the
classroom.
The dataset is annotated with the following information:
• Initial frame of the fall.
• Final frame of the fall.
• Height, width, and coordinates of the center of the bounding box in each frame.

Fig. 13 FDD’s [12] environments. a Fall scenarios; b ADLs scenarios

5.2 Evaluation Metrics

In this work, we approach the problem of detecting falls as a binary classification,
in which a classifier must decide whether a video frame corresponds to a fall or
not. To that end, the chosen metrics and their respective equations were as follows:
(i) precision (Eq. 5), (ii) sensitivity (Eq. 6), (iii) accuracy (Eq. 7), and (iv) balanced
accuracy (Eq. 8). In the following equations, the abbreviations correspond to: TP
true positive, FP false positive, TN true negative, and FN false negative; in Eq. 9,
y_i corresponds to the true value of the i-th sample, and w_i corresponds to its sample
weight.

$$\text{Precision} = \frac{TP}{TP + FP} \qquad (5)$$

$$\text{Sensitivity} = \frac{TP}{TP + FN} \qquad (6)$$

$$\text{Accuracy} = \frac{TP + TN}{TP + TN + FP + FN} \qquad (7)$$

$$\text{Balanced Accuracy} = \frac{1}{\sum_i \hat{w}_i} \sum_i 1(\hat{y}_i = y_i)\, \hat{w}_i \qquad (8)$$

in which

$$\hat{w}_i = \frac{w_i}{\sum_j 1(y_j = y_i)\, w_j} \qquad (9)$$

These metrics were chosen because of the need to compare our results with those
found in the literature, which, for the most part, reported only precision, sensitivity,
and accuracy.
Considering that both datasets are unbalanced, with the negative class having more
than twice as many samples as the positive class (falls), we chose to use the balanced
accuracy instead of some other balanced metric because, as stated in its name, it
balances the samples and, in doing so, takes the false negatives into account. False
negatives are especially important in fall detection, as ignoring a fall incident can
lead to the health problems described in Sect. 1.
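For reference, the same metrics can be computed with scikit-learn, whose balanced_accuracy_score corresponds to Eqs. 8 and 9 with unit sample weights; the labels below are a toy example.

```python
from sklearn.metrics import (accuracy_score, balanced_accuracy_score,
                             precision_score, recall_score)

y_true = [0, 0, 0, 1, 1]       # toy labels: 1 = fall, 0 = not fall
y_pred = [0, 0, 1, 1, 0]

print("Precision:        ", precision_score(y_true, y_pred))
print("Sensitivity:      ", recall_score(y_true, y_pred))      # recall = sensitivity
print("Accuracy:         ", accuracy_score(y_true, y_pred))
print("Balanced accuracy:", balanced_accuracy_score(y_true, y_pred))
```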

5.3 Computational Resources

This method was implemented using the Python programming language [73], which
was chosen due to its wide availability of libraries for image analysis and deep
learning applications. Moreover, some libraries were used, such as: SciKit [27],
NumPy [55], OpenCV [6] and Keras [13], and the TensorFlow [1] framework.
Deep learning algorithms are known to be computationally intensive. Their train-
ing and experiments require more computational power than a conventional notebook
can provide and, therefore, were carried out in the cloud on a rental machine from
Amazon AWS, g2.2xlarge, with the following specifications: 1x Nvidia GRID K520
GPU (Kepler), 8x vCPUs, and 15GB of RAM.

5.4 Experiments

Next, the performed experiments, their results, and a discussion are pre-
sented. The experiments were split between multi-channel tests, cross-tests, and literature
comparisons. To the knowledge of the authors, the cross-tests, in which the model was
trained on one dataset and tested on the other, are unprecedented among the selected
datasets. The data were split in proportions of 65% for training, 15% for validation,
and 20% for testing.
All our experiments were executed using the following parameters: 500 epochs
along with early stopping with a patience of 10, a learning rate of 10^-5, mini-batches
of 192, 50% of dropout, and the Adam optimizer; training minimized the validation
loss. The results are reported first on the URFD dataset, followed by those on
the FDD set.
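The sketch below illustrates this configuration for a single stream in Keras; the tiny placeholder network and random data are only there to make the snippet self-contained, and the 15% validation split mirrors the data partitioning above.

```python
import numpy as np
import tensorflow as tf
from tensorflow.keras import layers, models

# Placeholder stream network; each real stream is an Inception-based model.
model = models.Sequential([layers.Flatten(input_shape=(64, 64, 3)),
                           layers.Dropout(0.5),                  # 50% dropout
                           layers.Dense(2, activation="softmax")])
x = np.random.rand(384, 64, 64, 3).astype("float32")
y = tf.keras.utils.to_categorical(np.random.randint(0, 2, 384), 2)

model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=1e-5),
              loss="categorical_crossentropy", metrics=["accuracy"])
model.fit(x, y, validation_split=0.15, epochs=500, batch_size=192,
          callbacks=[tf.keras.callbacks.EarlyStopping(monitor="val_loss",
                                                      patience=10)])
```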
The results obtained on the URFD base are shown in Table 1, and are organized in
decreasing order of the balanced accuracy. In this first experiment, the combination

Table 1 3D multi-channel results on URFD dataset


Channels    Precision (%)    Sensitivity (%)    Accuracy (%)    Balanced accuracy (%)
OF RGB 1.00 1.00 0.97 0.98
OF VR 0.99 0.96 0.95 0.97
RGB VR 0.99 0.90 0.96 0.97
OF RGB SA 1.00 0.99 0.98 0.94
OF RGB VR 0.99 0.99 0.99 0.94
OF RGB PE 0.99 0.96 0.96 0.91
OF SA 0.99 0.98 0.94 0.91
SA VR 0.99 0.94 0.94 0.90
RGB SA 0.99 0.95 0.96 0.89
SA PE 0.99 1.00 0.91 0.89
RGB PE 0.99 0.99 0.92 0.89
VR PE 0.99 0.96 0.94 0.88
OF PE 0.99 0.97 0.92 0.87
OF: Optical flow, VR: Visual rhythm, SA: Saliency, PE: Posture estimation. The results are in
decreasing order of balanced accuracy, and the best result of each column is highlighted in bold

of the optical flow and RGB channels obtained the best result, whereas the pose
estimation channels obtained the worst results.
Table 2 shows once again the efficacy of the optical flow channel; however, in this
instance, the RGB channels suffered a slight drop in performance, in contrast to the
pose estimation ones, which rose from the worst results.
The first cross-test was performed with training done on the URFD dataset
and testing on the FDD set; its results are shown in Table 3. When executing a
cross-test, it is expected that the model will not perform as well, since it faces a completely
new dataset. The highest balanced accuracy was 68%, a sharp drop from the 98%
of Table 2; furthermore, the majority of the channels could not surpass the 50% mark,
indicating that the model was barely able to generalize its learning.
We believe that this drop in the balanced accuracy is an effect of the training
dataset quality. As explained previously, the URFD dataset is very limited in the
variability of its fall scenarios, presenting the same scenario repeatedly. Returning to the multi-
channel test on the URFD dataset (Table 1), it is possible to notice that the channels
with access to the background were among the best results, namely, visual rhythm
and RGB, but in the cross-test the best results were obtained from channels without
access to the background information. This could indicate that channels with access
to the background learned some features present in the scenario and were not able to
detect falls when the scenario changed, in contrast to those without access to the
background, which were still able to detect falls upon scenario changes.
The second cross-test was the opposite experiment, with training done on the
FDD dataset, and test on the URFD set (Table 4). Although there is, once again,
a sharp drop in the balanced accuracy, it performed much better than the previous

Table 2 3D multi-channel results on FDD dataset


Channels    Precision (%)    Sensitivity (%)    Accuracy (%)    Balanced accuracy (%)
OF SA 0.99 0.99 0.98 0.98
SA VR 1.00 0.98 0.98 0.96
SA PE 1.00 1.00 0.97 0.96
RGB SA 1.00 0.99 0.99 0.95
OF PE 0.99 0.96 0.97 0.93
OF RGB VR 0.99 0.95 0.95 0.91
OF VR 0.99 0.93 0.94 0.91
RGB VR 0.99 0.94 0.89 0.91
OF RGB 0.99 0.91 0.99 0.90
VR PE 0.99 0.89 0.91 0.88
OF RGB PE 0.99 0.84 0.93 0.87
OF RGB SA 0.99 0.85 0.97 0.86
RGB PE 0.99 0.80 0.91 0.84
OF: Optical flow, VR: Visual rhythm, SA: Saliency, PE: Posture estimation. The results are in
decreasing order of balanced accuracy, and the best result of each column is highlighted in bold

Table 3 3D cross-test results URFD to FDD dataset


Channels    Precision (%)    Sensitivity (%)    Accuracy (%)    Balanced accuracy (%)
OF PE 0.97 1.00 0.97 0.68
OF SA 0.96 1.00 0.96 0.57
SA PE 0.95 1.00 0.96 0.55
OF RGB PE 0.95 1.00 0.96 0.50
OF RGB VR 0.95 1.00 0.96 0.50
OF RGB SA 0.95 1.00 0.96 0.50
OF RGB 0.95 1.00 0.96 0.50
OF VR 0.95 1.00 0.96 0.50
RGB PE 0.95 1.00 0.96 0.50
RGB VR 0.95 1.00 0.96 0.50
RGB SA 0.95 1.00 0.96 0.50
SA VR 0.95 1.00 0.96 0.50
VR PE 0.95 1.00 0.96 0.50
OF: Optical flow, VR: Visual rhythm, SA: Saliency, PE: Posture estimation. The results are in
decreasing order of balanced accuracy, and the best result of each column is highlighted in bold

test, reaching a maximum balanced accuracy of 84% with the optical flow and pose
estimation channels. Moreover, a similar pattern is observable in both cross-tests:
the channels without access to the background were better able to discriminate between
fall and non-fall.

Table 4 3D cross-test results FDD to URFD dataset


Channels    Precision (%)    Sensitivity (%)    Accuracy (%)    Balanced accuracy (%)
OF PE 0.97 0.90 0.89 0.84
VR PE 0.96 0.92 0.90 0.75
OF RGB PE 0.96 0.99 0.97 0.72
SA PE 0.96 0.93 0.95 0.60
OF SA 0.95 0.99 0.91 0.54
RGB PE 0.95 0.99 0.91 0.54
OF RGB SA 0.95 1.00 0.95 0.50
OF RGB VR 0.95 1.00 0.95 0.50
OF RGB 0.95 1.00 0.95 0.50
OF VR 0.95 1.00 0.95 0.50
RGB VR 0.95 1.00 0.95 0.50
RGB SA 0.95 1.00 0.95 0.50
SA VR 0.95 1.00 0.95 0.50
OF: Optical flow, VR: Visual rhythm, SA: Saliency, PE: Posture estimation. The results are in
decreasing order of balanced accuracy, and the best result of each column is highlighted in bold

Given that on the second cross-test even the channels with access to the background
showed an improvement in their performance, one might argue that the background
had nothing to do with it. However, it is important to reiterate that the FDD dataset
is more heterogeneous than the URFD, to the point that a fall scenario and
an ADL activity were recorded in the same scenario. This variability, added to the
intrinsic nature of the 3D model, which creates an internal temporal relation between
the input data, probably allowed the channels with access to the background to focus
on features of the fall movement itself, hence the larger number of channels with
more than 50% of balanced accuracy.
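As a side note, the pattern seen in Tables 3 and 4, where several channel combinations reach around 95% precision and 96% accuracy yet only 50% balanced accuracy, is exactly what a classifier that labels almost every clip as a fall produces on an imbalanced test set. The following minimal sketch, which assumes scikit-learn and uses made-up label counts rather than our data, illustrates why we rely on balanced accuracy:

```python
# Illustrative only: a degenerate "always fall" classifier on an imbalanced test set.
from sklearn.metrics import accuracy_score, balanced_accuracy_score

y_true = [1] * 95 + [0] * 5   # 95 fall clips, 5 non-fall clips (made-up counts)
y_pred = [1] * 100            # classifier that predicts "fall" for every clip

print(accuracy_score(y_true, y_pred))           # 0.95 -- looks deceptively good
print(balanced_accuracy_score(y_true, y_pred))  # 0.50 -- reveals no real discrimination
```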

5.4.1 Our Method Versus the Literature

Finally, we compared our best results and those from our previous work with the results
found in the literature. As stated before, we did not compare our balanced accuracy,
since no other work reported theirs. The results regarding the URFD dataset are
reported in Table 5, while those on the FDD set are shown in Table 6; both are sorted
in decreasing order of accuracy.
Our Inception 3D architecture surpassed or matched the reviewed works on both
datasets, as well as our previous method using the VGG-16 architecture. However, the
VGG-16 was not able to surpass the work of Lu et al. [43], which, notably, also employs
a 3D method together with an RNN of LSTM architecture, which explains its performance.

Table 5 Ours versus literature on URFD dataset


Approaches                    Precision (%)    Sensitivity (%)    Accuracy (%)
Our inception 3D              0.99             0.99               0.99
Lu et al. [43]                –                –                  0.99
Previous VGG-16               1.00             0.98               0.98
Panahi and Ghods [56]         0.97             0.97               0.97
Zerrouki and Houacine [81]    –                –                  0.96
Harrou et al. [23]            –                –                  0.96
Abobakr et al. [2]            1.00             0.91               0.96
Bhandari et al. [5]           0.96             –                  0.95
Kwolek and Kepski [34]        1.00             0.92               0.95
Núñez-Marcos et al. [54]      1.00             0.92               0.95
Sase and Bhandari [57]        0.81             –                  0.90
Metrics that were not reported by the authors are exhibited as a hyphen (–). The results are sorted
by accuracy in decreasing order. The best result of each column is highlighted in bold

Table 6 Ours versus literature on FDD dataset


Approaches                    Precision (%)    Sensitivity (%)    Accuracy (%)
Our inception 3D              1.00             0.99               0.99
Previous VGG-16               0.99             0.99               0.99
Lu et al. [43]                –                –                  0.99
Sehairi et al. [59]           –                –                  0.98
Zerrouki and Houacine [81]    –                –                  0.97
Harrou et al. [23]            –                –                  0.97
Núñez-Marcos et al. [54]      0.99             0.97               0.97
Charfi et al. [11]            0.98             0.99               –
Metrics that were not reported by the authors are exhibited as a hyphen (–). The results are sorted
by accuracy in decreasing order. The best result of each column is highlighted in bold

6 Conclusions and Future Work

In this work, we presented and compared our deep neural network method to detect
human falls in video sequences, using a 3D architecture, the Inception V1 3D net-
work. The training and evaluation of the method were performed on two public
datasets and, regarding its effectiveness, the method outperformed or matched the related
literature.

The results pointed out the importance of temporal information in the detection
of falls, both because the temporal channels were always among the best results,
especially the optical flow, and because of the improvement obtained by the 3D method when
compared to our previous work on the VGG-16 architecture. The 3D method was also
able to generalize its learning to a previously unseen dataset, and this ability to generalize
indicates that the 3D method can be considered a strong candidate to
compose an elderly monitoring system.
This is evidenced throughout the results shown in Sect. 5, since the temporal
channels are always among the most effective, except for two cases shown in Tables 2
and 3, in which the third-best result was the combination of the spatial channels of
saliency and pose. This can be attributed to the fact that the 3D architecture itself
provides a temporal relationship between the data.
Our conclusion about the importance of temporal information for fall classifi-
cation is corroborated by other works found in the literature, such as those by Meng
et al. [48] and Carreira and Zisserman [10], who stated the same for the classification
of actions in videos. This indicates that other deep learning architectures, such as
those described in [75], could also be used for this application.
In addition, as our results surpassed the reviewed works, the method demonstrated
itself to be effective in detecting falls. In a specific instance, our method matched
the results of the work of Lu et al. [43], in which the authors make use of an LSTM
architecture that, like ours, creates temporal relationships between the input data.
As a novel contribution, cross-tests were performed between the datasets. The results of these
tests exposed a known but interesting facet of neural networks, in which the
minimization function finds a local minimum that does not correspond to the initial
objective of the solution: during training, some channels of the network learned
aspects of the image background to classify the falls, instead of focusing on
aspects of the person's movement in the video.
The method developed in our work is supported by several factors, such as the
evaluation through the balanced accuracy, the tests being performed on two different
datasets, the heterogeneity of the FDD dataset, the execution of the cross-tests, and
the comparisons between the various channel combinations. On the other hand, the
work also faces some difficulties, such as (i) the low variability of the fall videos
in the URFD set, (ii) the fact that in the cross-tests many combinations of channels
obtained only 50% of balanced accuracy, and (iii) the use of simple accuracy as a
means of comparison with the literature. Even so, the proposed method overcomes
these limitations and remains relevant to the problem at hand.
The effectiveness of the method shows that, if trained on a robust enough dataset,
it can extract the temporal patterns necessary to classify scenarios between fall and
non-fall. Admittedly, there is an expected drop in balanced accuracy in the cross
tests. Regarding fall detection, this work is one of the most accurate approaches and
would be a great contribution as a module in an integrated system of assistance to
the elderly.
Concerning future work, some points can be mentioned: (i) exploring other
datasets that may contain a greater variety of scenarios and actions, (ii) integrat-
ing fall detection into a multi-class system, (iii) experimenting with cheaper features

to be extracted, (iv) adapting the method to work in real time, either through cheaper
channels or a lighter architecture, and (v) dealing with input as a stream of videos,
instead of clips, because in a real scenario, the camera would continuously feed the
system, which would further unbalance classes.
The contributions of this work are presented in the form of a human fall detection
method, implemented and publicly available in the repository [38], as well as the
experimentation using multi-channels and different datasets, which generated not
only a discussion about which metrics are more appropriate for evaluating fall solu-
tions, but also a discussion of the quality of the datasets used in these experiments.

Acknowledgements The authors are thankful to FAPESP (grant #2017/12646-3), CNPq (grant
#309330/2018-7), and CAPES for their financial support, as well as Semantix Brasil for the infras-
tructure and support provided during the development of the present work.

References

1. M. Abadi, A. Agarwal, P. Barham, E. Brevdo, Z. Chen, C. Citro, TensorFlow: large-scale


machine learning on heterogeneous systems (2015). https://www.tensorflow.org
2. A. Abobakr, M. Hossny, S. Nahavandi, A skeleton-free fall detection system from depth images
using random decision forest. IEEE Syst. J. 12(3), 2994–3005 (2017)
3. D.T. Anderson, J.M. Keller, M. Skubic, X. Chen, Z. He, Recognizing falls from silhouettes,
in International Conference of the IEEE Engineering in Medicine and Biology Society (2006),
pp. 6388–6391
4. L. Anishchenko, Machine learning in video surveillance for fall detection, in Ural Symposium
on Biomedical Engineering, Radioelectronics and Information Technology (IEEE, 2018), pp.
99–102
5. S. Bhandari, N. Babar, P. Gupta, N. Shah, S. Pujari, A novel approach for fall detection in home
environment, in IEEE 6th Global Conference on Consumer Electronics (IEEE, 2017), pp. 1–5
6. G. Bradski, The openCV library. Dobb’s J. Softw. Tools 120, 122–125 (2000)
7. Z. Cao, T. Simon, S.E. Wei, Y. Sheikh, Realtime multi-person 2D pose estimation using part
affinity fields, in IEEE Conference on Computer Vision and Pattern Recognition (2017), pp.
7291–7299
8. S. Carneiro, G. Silva, G. Leite, R. Moreno, S. Guimaraes, H. Pedrini, Deep convolutional
multi-stream network detection system applied to fall identification in video sequences, in
15th International Conference on Machine Learning and Data Mining (2019a), pp. 681–695
9. S. Carneiro, G. Silva, G. Leite, R. Moreno, S. Guimaraes, H. Pedrini, Multi-stream deep
convolutional network using high-level features applied to fall detection in video sequences, in
26th International Conference on Systems, Signals and Image Processing (2019b), pp. 293–298
10. J. Carreira, A. Zisserman, Quo vadis, action recognition? a new model and the kinetics dataset,
in Conference on Computer Vision and Pattern Recognition (IEEE, 2017), pp. 6299–6308
11. I. Charfi, J. Miteran, J. Dubois, M. Atri, R. Tourki, Definition and performance evaluation
of a robust svm based fall detection solution, in International Conference on Signal Image
Technology and Internet Based Systems, vol. 12 (2012), pp. 218–224
12. I. Charfi, J. Miteran, J. Dubois, M. Atri, R. Tourki, Optimized spatio-temporal descriptors for
real-time fall detection: comparison of support vector machine and adaboost-based classifica-
tion. J. Electron. Imaging 22(4), 041106 (2013)
13. F. Chollet, Keras (2015). https://keras.io
14. J. Deng, W. Dong, R. Socher, L.J. Li, K. Li, L. Fei-Fei, Imagenet: a large–scale hierarchical
image database, in IEEE Conference on Computer Vision and Pattern Recognition (2009), pp.
248–255

15. L. Deng, D. Yu, Deep learning: methods and applications. Found. Trends Signal Process.
7(3–4), 197–387 (2014)
16. A. Edgcomb, F. Vahid, Automated fall detection on privacy-enhanced video, in Annual Inter-
national Conference of the IEEE Engineering in Medicine and Biology Society (2012), pp.
252–255
17. K. Fan, P. Wang, S. Zhuang, Human fall detection using slow feature analysis. Multimed. Tools
Appl. 78(7), 9101–9128 (2018a)
18. Y. Fan, G. Wen, D. Li, S. Qiu, M.D. Levine, Early event detection based on dynamic images
of surveillance videos. J. Vis. Commun. Image Represent. 51, 70–75 (2018b)
19. G. Farnebäck, Two–frame motion estimation based on polynomial expansion, in Scandinavian
Conference on Image Analysis (2003), pp. 363–370
20. S. Gasparrini, E. Cippitelli, E. Gambi, S. Spinsante, J. Wåhslén, I. Orhan, T. Lindh, Proposal
and experimental evaluation of fall detection solution based on wearable and depth data fusion,
in International Conference on ICT Innovations (Springer, 2015), pp. 99–108
21. M.A. Goodale, A.D. Milner, Separate visual pathways for perception and action. Trends Neu-
rosci. 15(1), 20–25 (1992)
22. I. Goodfellow, Y. Bengio, A. Courville, Y. Bengio, Deep Learning (MIT Press, 2016)
23. F. Harrou, N. Zerrouki, Y. Sun, A. Houacine, Vision-based fall detection system for improving
safety of elderly people. IEEE Instrum. & Meas. Mag. 20(6), 49–55 (2017)
24. K. He, X. Zhang, S. Ren, J. Sun, Deep residual learning for image recognition, in IEEE
Conference on Computer Vision and Pattern Recognition (2016), pp. 770–778
25. D.L. Heymann, T. Prentice, L.T. Reinders, The World Health Report: a Safer Future: global
Public Health Security in the 21st Century (World Health Organization, 2007)
26. Z. Huang, Y. Liu, Y. Fang, B.K. Horn, Video-based fall detection for seniors with human pose
estimation, in 4th International Conference on Universal Village (IEEE, 2018), pp. 1–4
27. E. Jones, T. Oliphant, P. Peterson, SciPy: open source scientific tools for python (2001). http://
www.scipy.org
28. O.O. Khin, Q.M. Ta, C.C. Cheah, Development of a wireless sensor network for human fall
detection, in International Conference on Real-Time Computing and Robotics (IEEE, 2017),
pp. 273–278
29. Y. Kong, J. Huang, S. Huang, Z. Wei, S. Wang, Learning spatiotemporal representations for
human fall detection in surveillance video. J. Vis. Commun. Image Represent. 59, 215–230
(2019)
30. A. Krizhevsky, I. Sutskever, G.E. Hinton, Imagenet classification with deep convolutional
neural networks. Adv. Neural Inf. Process. Syst. 25, 1097–1105 (2012)
31. T. Kukharenko, V. Romanenko, Picking a human fall detection algorithm for wrist–worn elec-
tronic device, in IEEE First Ukraine Conference on Electrical and Computer Engineering
(2017), pp. 275–277
32. V.S. Kumar, K.G. Acharya, B. Sandeep, T. Jayavignesh, A. Chaturvedi, Wearable sensor–based
human fall detection wireless system, in Wireless Communication Networks and Internet of
Things (Springer, 2018), pp. 217–234
33. B. Kwolek, M. Kepski, Human fall detection on embedded platform using depth maps and
wireless accelerometer. Comput. Methods Programs Biomed. 117(3), 489–501 (2014)
34. B. Kwolek, M. Kepski, Improving fall detection by the use of depth sensor and accelerometer.
Neurocomputing 168, 637–645 (2015)
35. Y. LeCun, L. Bottou, Y. Bengio, P. Haffner, Gradient-based learning applied to document
recognition. Proc. IEEE 86(11), 2278–2324 (1998)
36. T. Lee, A. Mihailidis, An intelligent emergency response system: preliminary development and
testing of automated fall Detection. J. Telemed. Telecare 11(4), 194–198 (2005)
37. G. Leite, G. Silva, H. Pedrini, Fall detection in video sequences based on a three-stream
convolutional neural network, in 18th IEEE International Conference on Machine Learning
and Applications (ICMLA) (Boca Raton-FL, USA, 2019), pp. 191–195
38. G. Leite, G. Silva, H. Pedrini, Fall detection (2020). https://github.com/Lupins/fall_detection

39. H. Li, K. Mueller, X. Chen, Beyond saliency: understanding convolutional neural networks
from saliency prediction on layer-wise relevance propagation. Comput. Res. Repos. (2017a)
40. X. Li, T. Pang, W. Liu, T. Wang, Fall detection for elderly person care using convolutional
neural networks, in 10th International Congress on Image and Signal Processing, BioMedical
Engineering and Informatics (2017b), pp. 1–6
41. W.N. Lie, A.T. Le, G.H. Lin, Human fall-down event detection based on 2D skeletons and deep
learning approach, in International Workshop on Advanced Image Technology (2018), pp. 1–4
42. B.S. Lin, J.S. Su, H. Chen, C.Y. Jan, A fall detection system based on human body silhouette,
in 9th International Conference on Intelligent Information Hiding and Multimedia Signal
Processing (IEEE, 2013), pp. 49–52
43. N. Lu, Y. Wu, L. Feng, J. Song, Deep learning for fall detection: 3D-CNN combined with
LSTM on video kinematic data. IEEE J. Biomed. Health Inform. 23(1), 314–323 (2018)
44. B.D. Lucas, T. Kanade, An iterative image registration technique with an application to stereo
vision, in International Joint Conference on Artificial Inteligence (1981), pp. 121–130
45. F. Luna-Perejon, J. Civit-Masot, I. Amaya-Rodriguez, L. Duran-Lopez, J.P. Dominguez-
Morales, A. Civit-Balcells, A. Linares-Barranco, An automated fall detection system using
recurrent neural networks, in Conference on Artificial Intelligence in Medicine in Europe
(Springer, 2019), pp. 36–41
46. M.M. Lusardi, S. Fritz, A. Middleton, L. Allison, M. Wingood, E. Phillips, Determining risk of
falls in community dwelling older adults: a systematic review and meta-analysis using posttest
probability. J. Geriatr. Phys. Ther. 40(1), 1–36 (2017)
47. X. Ma, H. Wang, B. Xue, M. Zhou, B. Ji, Y. Li, Depth-based human fall detection via shape
features and improved extreme learning machine. J. Biomed. Health Inform. 18(6), 1915–1922
(2014)
48. L. Meng, B. Zhao, B. Chang, G. Huang, W. Sun, F. Tung, L. Sigal, Interpretable
Spatio-Temporal Attention for Video Action Recognition (2018), pp. 1–10. arXiv preprint
arXiv:181004511
49. W. Min, H. Cui, H. Rao, Z. Li, L. Yao, Detection of human falls on furniture using scene analysis
based on deep learning and activity Characteristics. IEEE Access 6, 9324–9335 (2018)
50. M.N.H. Mohd, Y. Nizam, S. Suhaila, M.M.A. Jamil, An optimized low computational algorithm
for human fall detection from depth images based on support vector machine classification,
in IEEE International Conference on Signal and Image Processing Applications (2017), pp.
407–412
51. T.P. Moreira, D. Menotti, H. Pedrini, First-person action recognition through visual rhythm
texture description, in International Conference on Acoustics (Speech and Signal Processing,
IEEE, 2017), pp. 2627–2631
52. E.B. Nievas, O.D. Suarez, G.B. García, R. Sukthankar, Violence detection in video using
computer vision techniques, in International Conference on Computer Analysis of Images and
Patterns (Springer, 2011), pp. 332–339
53. Y. Nizam, M.N.H. Mohd, M.M.A. Jamil, Human fall detection from depth images using position
and velocity of subject. Procedia Comput. Sci. 105, 131–137 (2017)
54. A. Núñez-Marcos, G. Azkune, I. Arganda-Carreras, Vision-based fall detection with convolu-
tional neural networks. Wirel. Commun. Mob. Comput. 2017, 1–16 (2017)
55. T.E. Oliphant, Guide to NumPy, 2nd edn. (CreateSpace Independent Publishing Platform, USA, 2015)
56. L. Panahi, V. Ghods, Human fall detection using machine vision techniques on RGB-D images.
Biomed. Signal Process. Control 44, 146–153 (2018)
57. P.S. Sase, S.H. Bhandari, Human fall detection using depth videos, in 5th International Con-
ference on Signal Processing and Integrated Networks (IEEE, 2018), pp. 546–549
58. C. Schuldt, I. Laptev, B. Caputo, Recognizing human actions: a local SVM approach, in 17th
International Conference on Pattern Recognition, vol. 3 (IEEE, 2004), pp 32–36
59. K. Sehairi, F. Chouireb, J. Meunier, Elderly fall detection system based on multiple shape
features and motion analysis, in International Conference on Intelligent Systems and Computer
Vision (IEEE, 2018), pp. 1–8

60. A. Shojaei-Hashemi, P. Nasiopoulos, J.J. Little, M.T. Pourazad, Video–based human fall detec-
tion in smart homes using deep learning, in IEEE International Symposium on Circuits and
Systems (2018), pp. 1–5
61. K. Simonyan, A. Zisserman, Two-stream convolutional networks for action recognition in
videos. Adv. Neural Inf. Process. Syst. 27, 568–576 (2014a)
62. K. Simonyan, A. Zisserman, Very deep convolutional networks for large-scale image recogni-
tion (2014b), pp. 1–14. arXiv, arXiv:14091556
63. K. Simonyan, A. Vedaldi, A. Zisserman, Deep inside convolutional networks: visualising image
classification models and saliency maps. Comput. Res. Repos. (2013)
64. D. Smilkov, N. Thorat, B. Kim, F. Viégas, M. Wattenberg, Smoothgrad: removing noise by
adding noise (2017), pp. 1–10. arXiv preprint arXiv:170603825
65. M. Sundararajan, A. Taly, Q. Yan, Axiomatic attribution for deep networks, in 34th Interna-
tional Conference on Machine Learning, vol. 70, pp. 3319–3328 (JMLR.org, 2017)
66. C. Szegedy, W. Liu, Y. Jia, P. Sermanet, S. Reed, D. Anguelov, going deeper with convolutions,
in IEEE Conference on Computer Vision and Pattern Recognition (2015), pp. 1–9
67. S.K. Tasoulis, G.I. Mallis, S.V. Georgakopoulos, A.G. Vrahatis, V.P. Plagianakos, I.G. Maglo-
giannis, Deep learning and change detection for fall recognition, in Engineering Applications
of Neural Networks, ed. by J. Macintyre, L. Iliadis, I. Maglogiannis, C. Jayne (Springer Inter-
national Publishing, Cham, 2019), pp. 262–273
68. The Joint Commission, Fall reduction program—definition of a fall (2001)
69. B.S. Torres, H. Pedrini, Detection of complex video events through visual rhythm. Vis. Comput.
34(2), 145–165 (2018)
70. US Department of Veterans Affairs, Falls policy overview (2019). http://www.patientsafety.
va.gov/docs/fallstoolkit14/05_falls_policy_overview_v5.docx
71. F.B. Valio, H. Pedrini, N.J. Leite, Fast rotation-invariant video caption detection based on visual
rhythm. in Iberoamerican Congress on Pattern Recognition (Springer, 2011), pp. 157–164
72. M. Vallejo, C.V. Isaza, J.D. Lopez, Artificial neural networks as an alternative to traditional
fall detection methods, in 35th Annual International Conference of the IEEE Engineering in
Medicine and Biology Society (2013), pp. 1648–1651
73. G. Van Rossum, F.L. Jr Drake, Python reference manual. Tech. Rep. Report CS-R9525, Centrum
voor Wiskunde en Informatica, Amsterdam (1995)
74. L. Wang, Y. Xiong, Z. Wang, Y. Qiao, Towards good practices for very deep two-stream
convnets (2015), pp. 1–5. arXiv preprint arXiv:150702159
75. M. Wani, F. Bhat, S. Afzal, A. Khan, Advances in Deep Learning (Springer, 2020)
76. World Health Organization, Global Health and Aging (2011)
77. World Health Organization, Fact sheet falls (2012)
78. World Health Organization, World Report on Ageing and Health (2015)
79. T. Xu, Y. Zhou, J. Zhu, New advances and challenges of fall detection systems: a survey. Appl.
Sci. 8(3), 418 (2018)
80. M. Yu, S.M. Naqvi, J. Chambers, A robust fall detection system for the elderly in a smart
room, in IEEE International Conference on Acoustics Speech and Signal Processing (2010),
pp. 1666–1669
81. N. Zerrouki, A. Houacine, Combined curvelets and hidden Markov models for human fall
detection. Multimed. Tools Appl. 77(5), 6405–6424 (2018)
82. N. Zerrouki, F. Harrou, Y. Sun, A. Houacine, Vision-based human action classification using
adaptive boosting algorithm. IEEE Sens. J. 18(12), 5115–5121 (2018)
83. Z. Zhang, V. Athitsos, Fall detection by Zhong Zhang and Vassilis Athitsos (2020). http://vlm1.uta.edu/~zhangzhong/fall_detection/
84. S. Zhao, W. Li, W. Niu, R. Gravina, G. Fortino, Recognition of human fall events based on
single tri–axial gyroscope, in IEEE 15th International Conference on Networking, Sensing and
Control (2018), pp. 1–6
85. F. Zhuang, Z. Qi, K. Duan, D. Xi, Y. Zhu, H. Zhu, H. Xiong, Q. He, A comprehensive survey
on transfer learning (2019), pp. 1–27. arXiv preprint arXiv:191102685

86. Y. Zigel, D. Litvak, I. Gannot, A method for automatic fall detection of elderly people using floor
vibrations and sound-proof of concept on human mimicking doll falls. IEEE Trans. Biomed.
Eng. 56(12), 2858–2867 (2009)
87. Z. Zuo, B. Wei, F. Chao, Y. Qu, Y. Peng, L. Yang, Enhanced gradient-based local feature
descriptors by saliency map for egocentric action recognition. Appl. Syst. Innov. 2(1), 1–14
(2019)
Diagnosis of Bearing Faults in Electrical
Machines Using Long Short-Term
Memory (LSTM)

Russell Sabir, Daniele Rosato, Sven Hartmann, and Clemens Gühmann

Abstract Rolling element bearings are very important components in electrical
machines. Almost 50% of the faults that occur in electrical machines occur in the
bearings, which makes bearings one of the most critical components in electrical
machinery. Bearing fault diagnosis has therefore drawn the attention of many researchers.
Generally, vibration signals from the machine's accelerometer are used for the diag-
nosis of bearing faults. In the literature, the application of Deep Learning algorithms to
these vibration signals has resulted in fault detection accuracies close to
100%. Although fault detection using vibration signals from the machine is ideal,
the measurement of vibration signals requires an additional sensor, which is absent in
many machines, especially low-voltage machines, as it significantly adds to their cost.
Alternatively, bearing fault diagnosis with the help of the stator current or Motor
Current Signal (MCS) is gaining popularity. This paper uses the MCS for the diag-
nosis of bearing inner raceway and outer raceway faults. Diagnosis using the MCS is diffi-
cult, as the fault signatures are buried beneath the noise in the current signal. Hence,
signal-processing techniques are employed for the extraction of the fault features.
The paper uses the Paderborn University damaged bearing dataset, which contains
stator current data from healthy bearings and from bearings with real inner raceway and
outer raceway damages of different fault severities. Fault features are extracted from the MCS
by first filtering out the redundant frequencies from the signal and then extracting
eight features from the filtered signal, which include three features from the time domain

R. Sabir (B) · D. Rosato · S. Hartmann


SEG Automotive Germany GmbH, Lotterbergstraße 30, 70499 Stuttgart, Germany
e-mail: russell.sabir@seg-automotive.com
D. Rosato
e-mail: daniele.rosato@seg-automotive.com
S. Hartmann
e-mail: sven.hartmann2@seg-automotive.com
R. Sabir · C. Gühmann
Chair of Electronic Measurement and Diagnostic Technology & Technische Universität Berlin,
Berlin, Germany
e-mail: clemens.guehmann@tu-berlin.de


and five features from the time–frequency domain obtained using the Wavelet Packet Decom-
position (WPD). After the extraction of these eight features, the well-known Deep
Learning algorithm Long Short-Term Memory (LSTM) is used for bearing fault clas-
sification. The LSTM algorithm is mostly used in speech recognition
due to its time coherence, but in this paper its ability is also demon-
strated for this task, with a fault classification accuracy of 97%. A comparison of the proposed
algorithm with traditional Machine Learning techniques is carried out, and it is shown
that the proposed methodology outperforms all the traditional algorithms used
for the classification of bearing faults from the MCS. The method developed is
independent of the speed and the loading conditions.

Keywords Ball bearings · AC machines · Fault diagnosis · Fault detection ·


Discrete wavelet transforms · Wavelet packets · Wavelet coefficients · Learning
(artificial intelligence) · Machine learning · Deep learning · LSTM

1 Introduction

Rolling element bearings are key components in rotating machinery. They support
radial and axial loads while reducing rotational friction and hence carry heavy loads in
the machinery. They are critical to the optimal performance of the machinery, but their
failure can lead to machine downtime and cause significant economic
losses [1, 2]. In an electrical machine, due to overloading stress and misalign-
ment, bearings are the components most susceptible to damage. Bearing faults
are caused by insufficient lubrication, fatigue, incorrect mounting, corrosion, elec-
trical damage, or foreign particles (contamination), and normally appear as wear,
indentation, and spalling [3, 4]. Bearing failures contribute approximately
50% of the failures in electrical machines [4]. The structure of a rolling element
bearing is shown in Fig. 1. $N_b$ is the number of balls, $D_b$ is the diameter of the balls,
$D_c$ is the diameter of the cage, and $\beta$ is the contact angle between the ball and
Fig. 1 Rolling element bearing [18]



the raceway. Rolling element bearings are comprised of two rings, an inner ring
and an outer ring, as shown in Fig. 1. The inner race is normally mounted on the
motor shaft. Between the two rings, a number of balls or rollers are located. Continued
stress results in the loss of material fragments on the inner and outer races. Due to the
susceptibility of bearings to damage, bearing fault diagnosis has attracted the
attention of many researchers. Typical bearing fault diagnosis includes the analysis of
different signals measured from the electrical machine with damaged bearings by
using signal-processing methods.
In this paper, only the inner and outer raceway faults will be analyzed. Each of these
faults corresponds to a characteristic frequency $f_c$, which is given by (1) and (2),
where $f_r$ is the rotating frequency of the rotor.
 
Outer raceway: $f_0 = \dfrac{N_b}{2}\, f_r \left(1 - \dfrac{D_b}{D_c} \cos\beta\right)$   (1)

Inner raceway: $f_i = \dfrac{N_b}{2}\, f_r \left(1 + \dfrac{D_b}{D_c} \cos\beta\right)$   (2)

These frequencies arise when the rolling balls pass over and hit the defect, causing an
impulsive effect. The characteristic frequencies are a function of the bearing geometry
and the rotor frequency. The impulsive effect generates rotor eccentricities at the
characteristic frequencies, causing the air-gap between the rotor and the stator to
vary. Because of the variation of the air-gap length, the machine inductances
also vary. These characteristic frequencies and their sidebands are derived in [5].
However, they can only be reliably detected under specific conditions, as they are
highly influenced by external noise and operating conditions. Most of the time, these
frequencies only appear in the frequency spectrum for faults of high severity.
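For illustration, the two characteristic frequencies of (1) and (2) can be computed directly from the bearing geometry and the rotor speed. The sketch below is a minimal example; the geometry values used are placeholders, not the parameters of the bearing studied later in this paper.

```python
# Illustrative sketch of Eqs. (1) and (2); geometry values are placeholders.
import math

def characteristic_frequencies(f_r, n_balls, d_ball, d_cage, beta_deg):
    """Return (outer-raceway, inner-raceway) fault frequencies in Hz."""
    ratio = (d_ball / d_cage) * math.cos(math.radians(beta_deg))
    f_outer = (n_balls / 2.0) * f_r * (1.0 - ratio)   # Eq. (1)
    f_inner = (n_balls / 2.0) * f_r * (1.0 + ratio)   # Eq. (2)
    return f_outer, f_inner

# Example: rotor at 1500 rpm -> f_r = 25 Hz, 8 balls, D_b = 7 mm, D_c = 23 mm, beta = 0 deg
print(characteristic_frequencies(25.0, 8, 7.0, 23.0, 0.0))
```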

2 Literature Review

Bearing fault diagnosis using vibration signals, acoustic emissions, temperature
measurements, chemical monitoring, electric currents, and the shock pulse method has
been an important research area during the last few decades. The most widely used
methods for bearing fault diagnosis employ vibration signals. Fault diag-
nosis using vibration signals is the most effective method to date, with fault detection
accuracies reaching close to 100%. With conventional Deep Learning
methods applied to the vibration signals, researchers have been able to achieve accuracies
well above 99%, e.g., 99.6% accuracy using a CNN–LSTM (Convolutional Neural
Network–Long Short-Term Memory) [6], 99.83% accuracy using an SDAE (Stacked
Denoising Autoencoder) [7], 99.15% using an EDAE (Ensemble Deep Autoencoder) [8], and
many more. However, in most industrial applications, the cheap electrical machines used
have a power rating of 2 kW or less, and an additional sensor such as an accelerom-
eter is not economically appealing because putting an extra sensor adds to its

cost. Furthermore, [9] argues that the vibration signal can be easily influenced by
loose screws and resonance frequencies of different parts of the machine, especially
the housing, which may lead to incorrect or misleading fault diagnosis.
According to [5], the stator current or Motor Current Signal (MCS) can also be effec-
tively used for bearing fault diagnosis, since currents can be easily measured from
the existing frequency inverters, so no additional sensors are required. As already
discussed, damages in the bearings, e.g., a hole, pit, crack, or missing material, result
in the creation of shock pulses when the ball passes over those defects. The shock
pulses result in vibration at the bearing characteristic frequencies. This causes the
average air-gap length of the rotor to vary, resulting in changes in the flux density
(or machine inductances) that consequently appear in the stator current at the frequencies
$f_{bf}$ given by (3) and modulate the rotating frequency of the machine.

$f_{bf} = |f_s \pm k f_c|$   (3)

where $f_s$ is the electrical frequency of the stator and k = 1, 2, 3, …
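As a small illustration of (3), the following sketch lists the first few sideband frequencies around the stator supply frequency at which a bearing fault signature would be expected; the numerical values of $f_s$ and $f_c$ are assumptions chosen only for the example.

```python
# Illustrative sketch of Eq. (3): candidate fault-related frequencies in the MCS spectrum.
def fault_sidebands(f_s, f_c, k_max=3):
    """Return |f_s ± k·f_c| for k = 1..k_max (values in Hz)."""
    return sorted(abs(f_s + sign * k * f_c) for k in range(1, k_max + 1) for sign in (1, -1))

print(fault_sidebands(f_s=50.0, f_c=76.3))  # assumed supply and characteristic frequencies
```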


Schoen et al. [5] further state that these characteristic frequency components can be detected
by analyzing the MCS in the frequency domain. However, similar
characteristic frequencies start to appear in the frequency spectrum of the MCS
as a result of rotor eccentricity, broken rotor bars, or other conditions that may
vary the air-gap length between the rotor and stator. Other scenarios that can hinder the
detection of the characteristic fault frequencies are when the fault is in its early
stages or when the measured signal has a low signal-to-noise ratio. Also, the location
of these characteristic frequencies in the frequency spectrum depends on the
machine speed, the bearing geometry, and the location of the fault. Although bearing
fault detection using vibration signals may have advantages over diagnosis
with the MCS, excellent fault diagnosis using the MCS can be achieved with the help of Deep
Learning algorithms. Normally, for Deep Learning techniques, feature
extraction is not needed, but in this case the modulating components or features are
buried in the noise, and they have to be detected and extracted using signal-processing
techniques. Hence, diagnosis using the MCS is not an easy task. In [10], fault signatures
are extracted using the Discrete Wavelet Transform (DWT), which denoises the
signal, reconstructs it back in the time domain, and performs spectral analysis to identify
the faulty peaks. [11] identifies the fault signatures in the MCS using the Welch PSD
estimation, but the signatures are more correctly identified using the RMS values of
the $C_{8,7}$ and $C_{7,6}$ coefficients of the SWPT (Stationary Wavelet Packet Transform).
Nevertheless, in all these approaches only fault signature identification is presented,
and no algorithm is demonstrated that is able to automatically identify the faults
from the data. [12] evaluates the pros and cons of using the MCS for diagnosis over the
vibration signal, and concludes that the diagnosis from the MCS is not always reliable
due to the low amplitude of the fault signatures. Hence, such methods
alone are not enough for correct fault classification, and for such cases a powerful approach
such as Deep Learning must be adopted, because Deep Learning models have
the ability to perform classification even from weak features.

In the literature, several Machine Learning and Deep Learning methods have been employed,
each with its own advantages and disadvantages. For example, in [13], an
unsupervised classification technique called artificial ant clustering is described,
but it requires three voltage and three current sensors, which makes the method not
economically appealing. [14, 15] describe the Machine Learning method SVM (Support
Vector Machine) for fault diagnosis, but in [14] only outer raceway faults are consid-
ered, and the accuracy drops from 100% to 95% when multi-fault classification is done. In
[15], the inner raceway, outer raceway, and cage faults are consid-
ered; the algorithm performs quite well at lower speeds with an average accuracy of
99%, but the accuracy drops to 92% at higher speeds. In [16], a Convolutional Neural
Network (CNN) is trained with the amplitudes of selected frequency components
from the motor's current spectrum. The results show that the accuracy is about 86%,
and only the bearing outer race fault is considered. In [17], Sparse Autoencoders (AE)
are used to extract fault patterns in the stator current and the vibration signals, and
an SVM is used for classification. The AE gives 96% accuracy for high-severity bearing
faults but only 64% for low-severity faults, and only the outer race fault is considered.
Finally, in [18], two feature extraction methods are used, a 1D-CNN
and the WPT (Wavelet Packet Transform). In that analysis, two types of bearing faults
are classified, an inner ring fault and a fault due to aluminum powder in the bearing. The
method presented covers a wide speed range and fault severity with 98.8% accuracy,
but fails to include load variation and also requires very high-specification training
hardware.
In this paper, the stator current is used to diagnose two bearing faults, inner and
outer raceway, using the well-known Deep Learning algorithm, the LSTM network.
The bearing fault data considered take into account different operating conditions
such as speed, load variations, and varying fault severity levels. Hence, the method
developed is independent of the speed and the loading conditions. In the remainder
of the paper, the LSTM network is first described, then feature extraction and LSTM
training are discussed, and finally the results from the trained LSTM are analyzed.

3 LSTM Network

An RNN (Recurrent Neural Network) is a class of artificial neural networks
used to identify patterns in sequential data, e.g., in speech recognition, handwriting
recognition, natural language processing, and video analysis. An RNN can
take time and sequence into account because it possesses two forms of input,
the present input and the input from the recent past [19]. The key difference between
an RNN and a feedforward network is that the RNN has a feedback loop connected to
its past decisions.
From Fig. 2, it can be seen that at a certain time t the recurrent hidden layer
neurons receive input not only from the input layer $x_t$ but also from their own state at instant
t − 1, i.e., from $h_{t-1}$. The output is therefore a combination of the present and the past.
The process is described by (4).

Fig. 2 a Architecture of an RNN. b RNN over a timestep [20, 21]

$h_t = f(w_{hx} x_t + w_{hh} h_{t-1} + b_h)$
$y_t = f(w_{yh} h_t + b_y)$   (4)

where
$x_t$ is the input to the RNN at time t
$h_t$ is the state of the hidden layer at time t
$h_{t-1}$ is the state of the neural network at t − 1
$b_h$ is the bias of the hidden layer
$b_y$ is the bias of the output
$w_{hx}$ are the weights between the hidden and the input layer
$w_{hh}$ are the weights between the hidden layer at t − 1 and the hidden layer at t.
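As an illustration of (4), the following minimal numpy sketch performs one forward step of a vanilla RNN; the use of tanh for the generic activation f and the layer sizes are assumptions made only for the example.

```python
# Minimal sketch of Eq. (4): one forward step of a vanilla RNN (tanh assumed for f).
import numpy as np

def rnn_step(x_t, h_prev, W_hx, W_hh, W_yh, b_h, b_y):
    h_t = np.tanh(W_hx @ x_t + W_hh @ h_prev + b_h)   # hidden state at time t
    y_t = np.tanh(W_yh @ h_t + b_y)                   # output at time t
    return h_t, y_t

rng = np.random.default_rng(0)
n_in, n_hid, n_out = 8, 32, 3                         # illustrative sizes
x_t, h_prev = rng.normal(size=n_in), np.zeros(n_hid)
W_hx = rng.normal(size=(n_hid, n_in))
W_hh = rng.normal(size=(n_hid, n_hid))
W_yh = rng.normal(size=(n_out, n_hid))
b_h, b_y = np.zeros(n_hid), np.zeros(n_out)
h_t, y_t = rnn_step(x_t, h_prev, W_hx, W_hh, W_yh, b_h, b_y)
print(h_t.shape, y_t.shape)                           # (32,) (3,)
```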
An RNN is trained across the timesteps using Backpropagation Through Time (BPTT)
[21]. However, because gradients are multiplied across the timesteps, the gradient
value becomes smaller and smaller or larger and larger, and as a result the RNN
learning process encounters the issue of vanishing or exploding gradients.
Due to this issue, RNNs are used only in limited
applications. The issue is solved by using the LSTM (Long Short-Term Memory),
which includes a memory cell that replaces the hidden RNN units [22].
The memory cell is composed of four gates: the forget gate (how much to keep from
the previous cell state), the input gate (whether to write to the cell), the input modulation gate
(how much to write to the cell), and the output gate (how much to reveal from
the cell) [23]. LSTMs preserve the error that is backpropagated through time and
layers. Figure 3 shows a typical LSTM cell. In Fig. 3, σ (sigmoid) represents the
gate activation function, whereas ϕ (tanh) is the input or output node activation. The
LSTM model [21] presented in Fig. 3 is described by (5).
 
$g_t = \varphi(w_{gx} x_t + w_{gh} h_{t-1} + b_g)$
$i_t = \sigma(w_{ix} x_t + w_{ih} h_{t-1} + b_i)$
$f_t = \sigma(w_{fx} x_t + w_{fh} h_{t-1} + b_f)$
$o_t = \sigma(w_{ox} x_t + w_{oh} h_{t-1} + b_o)$
$s_t = g_t \, i_t + s_{t-1} \, f_t$
$h_t = \varphi(s_t) \, o_t$   (5)

Fig. 3 LSTM memory cell

where
• $w_{gx}$, $w_{ix}$, $w_{fx}$, and $w_{ox}$ are the weights at time t between the input and hidden layer
• $w_{gh}$, $w_{ih}$, $w_{fh}$, and $w_{oh}$ are the weights at time t and t − 1 between the hidden layers
• $b_g$, $b_i$, $b_f$, and $b_o$ are the biases of the gates
• $h_{t-1}$ is the value of the hidden layer at time t − 1
• $f_t$, $g_t$, $i_t$, and $o_t$ are the output values of the forget gate, input modulation gate, input gate, and output gate, respectively
• $s_t$ and $s_{t-1}$ are the current states at time t and t − 1, respectively.
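For concreteness, the equations in (5) map to a short numpy sketch of a single LSTM forward step; the weight shapes and the dictionary-based parameter layout are illustrative assumptions, not the implementation used later in this paper.

```python
# Minimal sketch of Eq. (5): one forward step of an LSTM memory cell.
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def lstm_step(x_t, h_prev, s_prev, W, b):
    g_t = np.tanh(W["gx"] @ x_t + W["gh"] @ h_prev + b["g"])   # input modulation gate
    i_t = sigmoid(W["ix"] @ x_t + W["ih"] @ h_prev + b["i"])   # input gate
    f_t = sigmoid(W["fx"] @ x_t + W["fh"] @ h_prev + b["f"])   # forget gate
    o_t = sigmoid(W["ox"] @ x_t + W["oh"] @ h_prev + b["o"])   # output gate
    s_t = g_t * i_t + s_prev * f_t                             # new cell state
    h_t = np.tanh(s_t) * o_t                                   # new hidden state
    return h_t, s_t

rng = np.random.default_rng(0)
n_in, n_hid = 8, 32                                            # illustrative sizes
W = {k + "x": rng.normal(size=(n_hid, n_in)) for k in "gifo"}
W.update({k + "h": rng.normal(size=(n_hid, n_hid)) for k in "gifo"})
b = {k: np.zeros(n_hid) for k in "gifo"}
h, s = lstm_step(rng.normal(size=n_in), np.zeros(n_hid), np.zeros(n_hid), W, b)
print(h.shape, s.shape)                                        # (32,) (32,)
```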

4 Paderborn University Dataset

In [24, 25], a benchmark dataset for the purpose of bearing fault diagnosis is devel-
oped. The dataset is composed of synchronously measured currents of two phases
along with vibration signals. The stator currents are measured with a current trans-
ducer, sampled at 64 kHz, and then filtered with a 25 kHz low-pass filter. The dataset is
composed of tests on six undamaged bearings and 26 damaged bearings. Of the 26
damaged bearings, 12 have artificially induced damages and 14 have real damages.
In most of the research done on bearing fault diagnosis, only artificial bearing
faults are induced and data is collected to build Machine Learning models, because
such faults are easy to generate. Nevertheless, these models fail to deliver the expected
diagnosis when used in practical industry applications. In [24, 25], Machine Learning
models were trained with artificially damaged bearings and tested with real damaged
bearings. These models were not able to accurately classify the real damaged bearings.
The reason is that it is not possible to accurately replicate the real damages to the

bearings artificially. Obtaining real damaged bearings from machinery is not easy:
due to the long lifetimes of bearings, they are generally replaced before
failure. Hence, it is difficult to find large quantities of real damaged bearings. In
[24, 25], by using scientific test rigs, real damaged bearings are produced by way of
accelerated lifetime tests. An advantage of such a technique is that the bearing damages
can be generated under reproducible conditions. However, the disadvantage is that a lot
of time and effort is required for such a process. The real damage in the bearing is
generated by applying a very high radial force to the bearing on a test
rig. The radial force applied is far greater than what the bearing
can endure, which causes damages to appear much sooner in the bearing. Also,
to further accelerate the process, low-viscosity oil is used, which results in improper
lubrication and speeds up the process. Though the dataset is highly valuable as it
provides test data for real and artificial bearing damages, this paper
will only focus on the real bearing damages for the purpose of demonstrating the
diagnosis algorithm; the approach could of course be extended to the bearings
with artificial damages. The datasets of the three classes (healthy, outer race, and
inner race) are described in Table 1. The damages to the bearings are of varying
severity levels. A detailed description of these datasets and of the geometry of the used
bearing can be found in [24, 25]. All the damaged bearing datasets are composed of
single-point damages, which are a result of fatigue, except for the dataset KA30, which is
composed of distributed bearing damage due to plastic deformation.
Each dataset (e.g., K001) has 80 measurements (20 measurements corresponding
to each operating condition) of 4 s each. The different operating conditions described
in Table 2 are used so that the Deep Learning algorithm incorporates variations
in speed and load and the model does not depend on only certain operating
conditions. The operating parameters that are varied include the speed, the radial force,

Table 1 Dataset of healthy bearings and bearings with real damages

Healthy (Class 1)    Inner ring damage (Class 2)    Outer ring damage (Class 3)
K001                 KI04                           KA04
K002                 KI14                           KA15
K003                 KI16                           KA16
K004                 KI18                           KA22
K005                 KI21                           KA30

Table 2 Operating parameters of each dataset

No.    Rotational speed [rpm]    Load torque [Nm]    Radial force [N]
0      1500                      0.7                 1000
1      900                       0.7                 1000
2      1500                      0.1                 1000
3      1500                      0.7                 400

and the load torque. For three settings, a speed of 1500 rpm is used, and in one
setting a speed of 900 rpm is used. Similarly, a 0.7 Nm load torque is used for three settings
and 0.1 Nm for one setting, and a 1000 N radial force is used for three settings and
400 N for one setting. All the datasets have damages that are less
than 2 mm in size, except for the datasets KI16, KI18, and KA16, which have damages
greater than 2 mm. The temperature is kept constant at about 50 °C throughout all
experiments.

5 Data Processing and Feature Extraction

Before the Machine or Deep Learning algorithms are applied, the data is prepro-
cessed, and then the important features are extracted to be used as input to the
algorithm for fault classification. It is therefore necessary to employ techniques that
extract the best and most useful features and discard irrelevant or redundant informa-
tion. Removal of noisy, irrelevant, and misleading features gives a more compact
representation and improves the quality of detection and diagnosis [13].
Figure 4a shows the stator current signal of phase 1 of the machine, and Fig. 4b
shows its frequency spectrum. Looking closely at Fig. 4a, slight variations of the
sine wave amplitude can be observed. These small amplitude variations could
contain the bearing characteristic frequencies; hence, they are further analyzed for
feature extraction. The spectrum shows the two dominant frequencies $\omega_0$ and $5\omega_0$.
These frequencies offer no new information and are present in all the current signals
(whether belonging to a healthy or to a damaged bearing). Therefore, removing

Fig. 4 a Stator current, b frequency spectrum of stator current, c stator current after filtration of
ω0 and 5ω0 component, d the frequency spectrum of the filtered stator current

the frequencies from the current signal will give more focus to the frequencies that
result from the bearing faults. To remove the unwanted frequencies ($\omega_0$ and $5\omega_0$),
a signal-processing filter is designed in MATLAB that suppresses the frequencies
$\omega_0$ and $5\omega_0$ in the current signal. Figure 4c shows the filtered current signal, and
Fig. 4d shows its spectrum. This filtered signal is then used for feature extraction.
The spectrum of the filtered signal now contains noise and the characteristic bearing
frequencies (if they are present).
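The filtering step is implemented in MATLAB in this work; purely as an illustration, a comparable sketch in Python using scipy notch filters is shown below. The assumed fundamental frequency and the quality factor Q are placeholders, not the exact filter design used here.

```python
# Illustrative notch-filtering sketch (not the MATLAB filter used in this work).
import numpy as np
from scipy.signal import iirnotch, filtfilt

FS = 64_000.0   # sampling rate of the stator current [Hz], per the dataset description
F0 = 50.0       # assumed fundamental frequency omega_0 [Hz]

def suppress_fundamental_and_fifth(current, fs=FS, f0=F0, q=30.0):
    """Remove the omega_0 and 5*omega_0 components before feature extraction."""
    for f in (f0, 5.0 * f0):
        b, a = iirnotch(w0=f, Q=q, fs=fs)
        current = filtfilt(b, a, current)
    return current

t = np.arange(0, 0.2, 1.0 / FS)
demo = np.sin(2 * np.pi * F0 * t) + 0.3 * np.sin(2 * np.pi * 5 * F0 * t) + 0.01 * np.random.randn(t.size)
filtered = suppress_fundamental_and_fifth(demo)
```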
Huo et al. [26] evaluate the statistical features of the vibration signals from the
bearings and conclude that the two most sensitive parameters for
detecting bearing faults are the kurtosis and the impulse factor. This conclusion applies well
to the current signals, because these two statistical features adequately capture the
impulsive behavior of faulty bearings. Another useful time-domain feature is
the clearance factor, which is also sensitive to faults in the bearings. The
remaining features are extracted from the time–frequency domain using a third-level
WPD (Wavelet Packet Decomposition).
WPD is a well-known signal-processing technique that is widely used
in fault diagnosis. WPD provides useful information in both the time and frequency
domains. WPD decomposes the time-domain signal into wavelets of various scales
with variable-sized windows and reveals the local structure in the time–frequency
domain. The advantages of WPD over other signal-processing techniques are that
transient features can be effectively extracted and that features from the full
spectrum can be extracted without the requirement of a specific frequency band. WPD
uses mother wavelets, which are basic wavelet functions, to expand, compress, and trans-
late the signal by varying the scale (frequency) and the time shift of the wavelet. This
enables short windows at low scale (high frequency) and long windows at high
scale (low frequency). For example, higher resolution in time can be achieved for
high-frequency components and higher resolution in frequency for lower frequency components. The
wavelet function must meet the admissibility requirement in (6).

$C_\psi = \int_{\mathbb{R}} \dfrac{|\psi(\omega)|^2}{|\omega|}\, d\omega < \infty$   (6)

where ψ(ω) is the Fourier transform of the mother wavelet ψ(t)


The Continuous Wavelet Transform (CWT) is described by (7).
 
$\psi_{a,b}(t) = \dfrac{1}{\sqrt{|a|}}\, \psi\!\left(\dfrac{t-b}{a}\right)$   (7)

where $\psi_{a,b}(t)$ is the continuous mother wavelet, scaled by the factor a and
translated by the factor b, and $1/\sqrt{|a|}$ is used for energy preservation.
$\psi_{a,b}(t)$ acts as a window function whose frequency and time location can be
adjusted by a and b. For example, higher resolution in time can be achieved
with smaller values of a; this helps when extracting the higher frequency components of

the signal. The scaling and translation parameters make wavelet analysis ideal for
non-stationary and non-linear signals in both the time and frequency domains. a and b
are continuous and can take any value, but when a and b are both discretized we get
the DWT (Discrete Wavelet Transform).
Although the DWT is an excellent method for time–frequency domain analysis of the
signal, it further decomposes only the low-frequency part and neglects the high-frequency part,
resulting in very poor resolution at high frequencies. Therefore, in our case the Wavelet
Packet Decomposition (WPD), which is an extension of the DWT, is used for the time–
frequency analysis of the signal. The difference from the DWT is that in WPD all the
detail signals (high-frequency part) are further decomposed into two signals, a detail
signal and an approximation signal (low-frequency part). Hence, using WPD, a full-
scale analysis in the time–frequency domain can be done. The scaling function $\phi(t)$
and wavelet function $\psi(t)$ of WPD are described by (8) and (9), respectively.

$\phi(t) = \sqrt{2}\, \sum_k h(k)\, \phi(2t - k)$   (8)

$\psi(t) = \sqrt{2}\, \sum_k g(k)\, \phi(2t - k)$   (9)

where h(k) and g(k) are low- and high-pass filters, respectively.
The WPD from level j to j + 1 at node n of a signal $s_{j,n}$ is given by (10) and
(11). $s_{j,n}$ can be either the $cA_j$ approximation (low-frequency) or the $cD_j$ detail (high-
frequency) coefficients.

$s_{j+1,2n}(k) = \sum_m h(m - 2k)\, s_{j,n}(m)$   (10)

$s_{j+1,2n+1}(k) = \sum_m g(m - 2k)\, s_{j,n}(m)$   (11)

where m is the wavelet coefficient number


The process of the jth-level WPD is shown in Fig. 5.

Fig. 5 Schematic of WPD (Wavelet Packet Decomposition) process



With the help of WPD, the signal is decomposed into different frequency bands
[27]. When damage occurs in a bearing, the effect creates resonances at different
frequencies, causing energy to be distributed in different frequency bands depending
on the bearing fault. The energy $E(j, n)$ of the WPD coefficients $s_{j,n}$ of the jth level and
nth node can be computed as shown in (12).

$E(j, n) = \sum_k s_{j,n}(k)^2$   (12)

To perform WPD, a mother wavelet must be selected; the selection of the wavelet
depends on the application, as no wavelet is the absolute best. The wavelet that
works best has properties similar to the signal of the application. [28]
demonstrates the training of bearing fault data with different wavelets and concludes
that the mother wavelet that adequately captures the bearing fault signatures is the
Daubechies 6 (db6) wavelet. The db6 wavelet also works best in our
case and is therefore used for extracting the time–frequency domain features
in this paper.
In [25], a detailed feature selection method based on the maximum separation distance
is discussed. Using this method, the relevant features are selected, as not all
features are useful for fault classification. Figure 6 shows the third-level WPD of
the signal x. From this decomposition, the $cA_{3,0}$, $cA_{3,1}$, $cA_{3,2}$, and $cD_{3,2}$ coefficients are
used, and their energies are calculated. Table 3 shows the detailed list of the eight
features that are selected and used in the diagnosis algorithm. These eight features
from the filtered current signal are able to describe the bearing fault signatures quite
well. From the datasets presented in Table 1, the following steps are followed to
prepare the data for training and testing of the LSTM network.

Fig. 6 Coefficients of the third-level WPD (Wavelet packet decomposition)



Table 3 Description of the features

Feature no.   Description
1   Kurtosis: $\mathrm{Kurt} = \dfrac{\sum_{i=1}^{n}(x_i - \bar{x})^4}{(n-1)\,\sigma^4}$
2   Impulse factor: $\mathrm{IF} = \dfrac{\max(|x_i|)}{\frac{1}{n}\sum_{i=1}^{n}|x_i|}$
3   Clearance factor: $\mathrm{Clf} = \dfrac{\max(|x_i|)}{\left(\frac{1}{n}\sum_{i=1}^{n}\sqrt{|x_i|}\right)^{2}}$
4   RMS of third-level WPD approximation coefficient (3, 0): $\mathrm{RMS} = \sqrt{\frac{1}{n}\sum_{i=1}^{n} cA_{3,0_i}^{2}}$
5   Energy of third-level WPD approximation coefficient (3, 0): $E_{A3,0} = \sum_{i=1}^{n} cA_{3,0_i}^{2}$
6   Energy of third-level WPD approximation coefficient (3, 1): $E_{A3,1} = \sum_{i=1}^{n_2} cA_{3,1_i}^{2}$
7   Energy of third-level WPD approximation coefficient (3, 2): $E_{A3,2} = \sum_{i=1}^{n_3} cA_{3,2_i}^{2}$
8   Energy of third-level WPD detail coefficient (3, 2): $E_{D3,2} = \sum_{i=1}^{n_4} cD_{3,2_i}^{2}$

1. Step I: The stator current of phase 1 is taken, which contains a 4 s measurement
(resulting in 1200 signals from all the datasets, 400 signals each for the healthy, inner
race fault, and outer race fault classes). The signals are filtered to remove the $\omega_0$ and
$5\omega_0$ components. Then, the features presented in Table 3 are extracted (a code sketch of this step is shown after this list).
2. Step II: All the features extracted in Step I are scaled between 0 and 1, for better
convergence of the algorithm.
3. Step III: The feature data of each class is shuffled within the class itself, so that
training and test data incorporate all the different operating points and conditions,
as each dataset has different fault severities.
4. Step IV: 20% of the data points (i.e., 80 points) from each class are kept for
testing, and 80% of the data points (i.e., 320 data points) from each class are
used for training. Therefore, in total, there are 960 data points in the training set
and 240 points in the testing set.
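To make Steps I and II concrete, the following minimal Python sketch computes the eight features of Table 3 for one filtered current segment using PyWavelets; the mapping of the $cA_{3,0}$–$cD_{3,2}$ nodes of Fig. 6 to PyWavelets path names, and the use of scipy for the kurtosis, are assumptions made only for illustration.

```python
# Illustrative sketch of Step I: third-level 'db6' wavelet packet features (Table 3).
import numpy as np
import pywt
from scipy.stats import kurtosis

def extract_features(filtered_current):
    x = np.asarray(filtered_current, dtype=float)
    kurt = kurtosis(x, fisher=False)                                   # feature 1
    impulse = np.max(np.abs(x)) / np.mean(np.abs(x))                   # feature 2
    clearance = np.max(np.abs(x)) / np.mean(np.sqrt(np.abs(x))) ** 2   # feature 3

    wp = pywt.WaveletPacket(x, wavelet="db6", maxlevel=3)
    # Assumed mapping of the chapter's cA(3,0)..cD(3,2) nodes to pywt paths.
    nodes = {"A30": "aaa", "A31": "aad", "A32": "ada", "D32": "add"}
    coeffs = {name: wp[path].data for name, path in nodes.items()}

    rms_a30 = np.sqrt(np.mean(coeffs["A30"] ** 2))                     # feature 4
    energies = [np.sum(coeffs[n] ** 2) for n in ("A30", "A31", "A32", "D32")]  # features 5-8
    return np.array([kurt, impulse, clearance, rms_a30, *energies])
```

The min–max scaling of Step II and the shuffling and splitting of Steps III–IV can then be applied to the resulting feature vectors.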

6 LSTM Network Training and Results

The LSTM network architecture is described in Table 4. The LSTM network is
composed of eight input nodes that correspond to the eight selected features. Adding
more features does not affect the accuracy much, but decreasing the number of features
causes the accuracy to fall. The input nodes are followed by four hidden layers with 32
hidden nodes in each layer. There is no absolute rule for the choice of the number
of layers and hidden nodes; four layers and 32 nodes have been chosen because
decreasing the number of layers and nodes results in a decrease in accuracy, while
increasing them does not affect the accuracy to a great degree. After each LSTM hidden
layer, a 50% dropout from the previous layer is added to prevent the network from memo-
rizing or overfitting the training data. Finally, a dense layer with softmax activation

Table 4 LSTM architecture description

Layers                                   Nodes
Input layer                              8 input nodes
LSTM layer 1                             32 hidden nodes
                                         50% dropout
LSTM layer 2                             32 hidden nodes
                                         50% dropout
LSTM layer 3                             32 hidden nodes
                                         50% dropout
LSTM layer 4                             32 hidden nodes
                                         50% dropout
Dense layer with softmax activation      3 output nodes

Table 5 LSTM training parameter details

Parameters          Value
Training samples    960
Testing samples     240
Batch size          64
Epochs              2500
Dropout ratio       0.5

comprising three output nodes (corresponding to the three output classes healthy, inner
ring damage, and outer ring damage) is added. The binary cross-entropy function is
used as the loss function, and the ADAM optimizer [29] is used to train the parameters
of the LSTM network. In the code implementation, training is done on the computer's
GPU so that the training process is accelerated.
After the features have been extracted from the datasets and preprocessed, the
feature vector is input to the LSTM Network for training. The parameters of the
training are displayed in Table 5.
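Purely as an illustration, the architecture of Table 4 and the training setup of Table 5 could be expressed in Keras roughly as follows; the input shaping (the eight features treated as a length-8 sequence) and the remaining layer arguments are assumptions, since only the layer sizes, the dropout ratio, the loss, and the optimizer are specified here.

```python
# Hedged Keras sketch of Tables 4 and 5 (input shaping is an assumption).
from tensorflow import keras
from tensorflow.keras import layers

model = keras.Sequential([
    layers.Input(shape=(8, 1)),              # 8 selected features fed as a sequence
    layers.LSTM(32, return_sequences=True),
    layers.Dropout(0.5),
    layers.LSTM(32, return_sequences=True),
    layers.Dropout(0.5),
    layers.LSTM(32, return_sequences=True),
    layers.Dropout(0.5),
    layers.LSTM(32),
    layers.Dropout(0.5),
    layers.Dense(3, activation="softmax"),   # healthy / inner race / outer race
])
model.compile(optimizer="adam", loss="binary_crossentropy", metrics=["accuracy"])
# model.fit(X_train, y_train, epochs=2500, batch_size=64, validation_data=(X_test, y_test))
```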
Using these parameters, the network is trained for 2500 epochs with a batch size of
64. After the training, the LSTM network was able to achieve an accuracy of 97% on
the testing data and 100% accuracy on the training data. Figure 7 shows the confusion
matrix of the testing results of the LSTM model, with rows representing the predicted
class and columns the true class. False positives are displayed in the left column, and
the upper row displays the false negatives. All of the normal or healthy bearing points
are correctly classified. However, some points of the inner race fault are misclassified
as outer race faults, and some points of the outer race fault are misclassified as inner race
faults. Nevertheless, the LSTM network with the proposed methodology provides
excellent results in diagnosing the bearing faults and classifying healthy, inner race,
and outer race faults to a great degree of accuracy.

Fig. 7 Confusion matrix of the results of LSTM testing for class 1 (normal), class 2 (inner race damaged), and class 3 (outer race damaged)

7 Comparison with Traditional Methods

In the previous section, LSTM network was used in the diagnosis of the bearing inner
and outer race faults, and a testing accuracy of 97% was achieved, which is higher than the testing accuracy of 93.3% achieved by the ensemble of Machine Learning
algorithms in [25]. Moreover, the Machine Learning methods used in their analysis were not exposed to bearing measurements of different severities, which led to a
lower accuracy in their analysis. In conventional Machine Learning techniques, the
features are manually extracted from the data so that the data patterns are clearer to
the Machine Learning algorithm. For Deep Learning approaches, the algorithm is
capable of automatically extracting high-level features from the data. This approach
of feeding the data directly to the Deep Learning algorithm works for the majority of
the cases. However, in the case of bearing fault diagnosis, feeding direct MCS to
the proposed LSTM network failed to give good performance and resulted in a very
poor classification accuracy. The reason for the poor performance of the proposed
LSTM methodology is that the fault characteristic magnitudes are comparable to, or lower than, the signal noise. As a result, the LSTM network is unsuccessful
in extracting these characteristic frequency features. Therefore, manual extraction of
features from MCS is done, and signal-processing techniques such as wavelet trans-
form are required for this task. Since eight features are extracted from the filtered current signal, the traditional Machine Learning methods can also be applied to these
features and compared with the proposed LSTM method. Table 6 shows the comparison of the traditional
methods with the proposed method. In this testing, all the datasets were combined
and shuffled, and five-fold cross-validation was done. Most of the Machine Learning

Table 6 Comparison of traditional methods with the proposed LSTM method using five-fold
cross-validation
Algorithm Classification accuracy (%)
Multilayer perceptron (layers 150, 100, 50 with ADAM solver and 95.4
ReLU activation)
SVM (Support Vector Machine) 66.1
k nearest neighbor 91.2
Linear regression 53.0
Linear discriminant analysis 56.8
CART (Classification And Regression Trees) 93.6
Gaussian Naive Bayes 46.0
LSTM (Proposed method) 97.0

methods showed poor performance; only MLP (Multilayer Perceptron), kNN, and
CART methods showed promising results. The reason for the better performance of the LSTM method is its memorizing ability: since these 8 features form a sequence, the algorithm learns the patterns across them. Hence, the proposed LSTM methodology outperforms all the other traditional methods, demonstrating the superiority of Deep Learning techniques over the traditional methods.
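The comparison in Table 6 can be sketched with scikit-learn as follows; the feature matrix X, the label vector y, and the default hyperparameters of most classifiers are assumptions, while the MLP layer sizes (150, 100, 50), ADAM solver, and ReLU activation come from Table 6.

from sklearn.model_selection import cross_val_score
from sklearn.neural_network import MLPClassifier
from sklearn.svm import SVC
from sklearn.neighbors import KNeighborsClassifier
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
from sklearn.tree import DecisionTreeClassifier
from sklearn.naive_bayes import GaussianNB

# X: (1200, 8) scaled feature matrix, y: (1200,) class labels (assumed to exist).
classifiers = {
    "Multilayer perceptron": MLPClassifier(hidden_layer_sizes=(150, 100, 50),
                                           solver="adam", activation="relu", max_iter=2000),
    "SVM": SVC(),
    "k nearest neighbor": KNeighborsClassifier(),
    "Linear discriminant analysis": LinearDiscriminantAnalysis(),
    "CART": DecisionTreeClassifier(),
    "Gaussian Naive Bayes": GaussianNB(),
}

for name, clf in classifiers.items():
    scores = cross_val_score(clf, X, y, cv=5)              # five-fold cross-validation
    print(f"{name}: {100 * scores.mean():.1f}% accuracy")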
Although the proposed LSTM methodology outperformed the traditional Machine Learning methods, one disadvantage of Deep Learning algorithms is their high computational requirements. The proposed LSTM network model was trained on a GPU to accelerate the training. In contrast, training the traditional Machine Learning algorithms required only normal CPU processing and less time; therefore, Deep Learning algorithms should be preferred when sufficient computational resources are available.
The results from the LSTM model are quite promising; however, the dataset is not large enough, and the trained model may not deliver the same performance if it is exposed to data points from a new environment during testing. This happens
due to the randomness of background conditions and different noise conditions of new
data points. Hence, more research is needed in this area, so that the Deep Learning
models are able to adapt to a new environment without sacrificing their performance.
Therefore, it would be helpful to have a broader analysis with a larger dataset.

8 Conclusion and Future Work

The paper focused on the diagnosis of the bearing inner and outer race fault. Conven-
tional bearing fault diagnosis involves the use of vibrational signal from the machine’s
accelerometer. However, in this paper, MCS is used instead, for the fault diagnosis. By
using MCS, its potential in effective bearing fault diagnosis has been demonstrated.
With this methodology, the need for an additional sensor for monitoring the vibrational

signals is eliminated, reducing the cost of the fault monitoring system. The Paderborn
University real damaged bearing dataset was considered. From the datasets, the stator
current of one phase was used and processed by removing the redundant frequen-
cies ω0 and 5ω0 . Then, eight features are extracted from this filtered signal, three
features from the time domain, and five features from the time–frequency domain
using third-level WPD. These eight features were scaled and then fed to the Deep
Learning LSTM network with four hidden layers. The LSTM network was able to
show excellent results with a classification accuracy of 97%, even with the dataset
containing data at different operating conditions, i.e., different speed and load, which
proves that the method developed is independent of the machine operating condi-
tions. In the end, the proposed LSTM methodology is compared with the traditional
Machine Learning methods by using five-fold cross-validation, and the proposed
methodology outperformed the traditional methods by achieving more than 1.5% higher
accuracy than the best-performing algorithm. Hence, it has been shown that fault
diagnosis with MCS by using the proposed LSTM algorithm is able to give similar,
if not better, performance than the diagnosis with vibrational signals.
For future work, diagnosis of other bearing faults, e.g., ball fault and cage fault
using the stator current will be considered. Secondly, other Deep Learning methods
that are listed in [30] such as CNN-LSTM networks, algorithms that incorporate
the randomness of different working conditions and algorithms that are able to
adapt to different operating environments without compromising on the method’s
performance will be explored. Furthermore, more research will be done on increasing
the classification performance of the network by considering denoising LSTMs,
regularization methods, and a larger training dataset.

References

1. R. Sabir, S. Hartmann, C. Gühmann, Open and short circuit fault detection in alternators using
the rectified DC output voltage, in 2018 IEEE 4th Southern Power Electronics Conference
(SPEC) (Singapore, 2018), pp. 1–7
2. R. Sabir, D. Rosato, S. Hartmann, C. Gühmann, Detection and localization of electrical faults
in a three phase synchronous generator with rectifier, in 19th International Conference on
Electrical Drives & Power Electronics (EDPE 2019) (Slovakia, 2019)
3. Common causes of bearing failure | applied. Applied (2019). https://www.applied.com/bearingfailure
4. I.Y. Onel, M.E.H. Benbouzid, Induction motors bearing failures detection and diagnosis: park
and concordia transform approaches comparative study, in 2007 IEEE International Electric
Machines & Drives Conference (Antalya, 2007), pp. 1073–1078
5. R.R. Schoen, T.G. Habetler, F. Kamran, R.G. Bartfield, Motor bearing damage detection using
stator current monitoring. IEEE Trans. Ind. Appl. 31(6), 1274–1279 (1995). https://doi.org/10.
1109/28.475697
6. H. Pan, X. He, S. Tang, F. Meng, An improved bearing fault diagnosis method using one-
dimensional CNN and LSTM. J. Mech. Eng. 64(7–8), 443–452 (2018)
7. X. Guo, C. Shen, L. Chen, Deep fault recognizer: an integrated model to denoise and extract
features for fault diagnosis in rotating machinery. Appl. Sci. 7(41), 1–17 (2017)

8. H. Shao, H. Jiang, Y. Lin, X. Li, A novel method for intelligent fault diagnosis of rolling
bearings using ensemble deep autoencoders. Knowl.-Based Syst. 119, 200–220 (2018)
9. D. Filbert, C. Guehmann, Fault diagnosis on bearings of electric motors by estimating the
current spectrum. IFAC Proc. 27(5), 689–694 (1994)
10. S. Yeolekar, G.N. Mulay, J.B. Helonde, Outer race bearing fault identification of induction
motor based on stator current signature by wavelet transform, in 2017 2nd IEEE Interna-
tional Conference on Recent Trends in Electronics, Information & Communication Technology
(RTEICT) (Bangalore, 2017), pp. 2011–2015
11. F. Ben Abid, A. Braham, Advanced signal processing techniques for bearing fault detection in
induction motors, in 2018 15th International Multi-Conference on Systems, Signals & Devices
(SSD) (Hammamet, 2018), pp. 882–887
12. A. Bellini, F. Immovilli, R. Rubini, C. Tassoni, Diagnosis of bearing faults of induction
machines by vibration or current signals: a critical comparison, in 2008 IEEE Industry
Applications Society Annual Meeting (Edmonton, AB, 2008), pp. 1–8
13. A. Soualhi, G. Clerc, H. Razik, Detection and diagnosis of faults in induction motor using
an improved artificial ant clustering technique. IEEE Trans. Ind. Electron. 60(9), 4053–4062
(2013)
14. S. Gunasekaran, S.E. Pandarakone, K. Asano, Y. Mizuno, H. Nakamura, Condition monitoring
and diagnosis of outer raceway bearing fault using support vector machine, in 2018 Condition
Monitoring and Diagnosis (CMD) (Perth, WA, 2018), pp. 1–6
15. I. Andrijauskas, R. Adaskevicius, SVM based bearing fault diagnosis in induction motors
using frequency spectrum features of stator current, in 2018 23rd International Conference on
Methods & Models in Automation & Robotics (MMAR) (Miedzyzdroje, 2018), pp. 826–831
16. S.E. Pandarakone, M. Masuko, Y. Mizuno, H. Nakamura, Deep neural network based bearing
fault diagnosis of induction motor using fast fourier transform analysis, in 2018 IEEE Energy
Conversion Congress and Exposition (ECCE) (Portland, OR, 2018), pp. 3214–3221
17. J.S. Lal Senanayaka, H. Van Khang, K.G. Robbersmyr, Autoencoders and data fusion based
hybrid health indicator for detecting bearing and stator winding faults in electric motors, in
2018 21st International Conference on Electrical Machines and Systems (ICEMS) (Jeju, 2018),
pp. 531–536
18. I. Kao, W. Wang, Y. Lai, J. Perng, Analysis of permanent magnet synchronous motor fault
diagnosis based on learning. IEEE Trans. Instrum. Meas. 68(2), 310–324 (2019)
19. A beginner’s guide to LSTMs and recurrent neural networks (Skymind, 2019). https://skymind.
ai/wiki/lstm
20. S. Zhang, S. Zhang, B. Wang, T.G. Habetler, Machine learning and deep learning algorithms
for bearing fault diagnostics-a comprehensive review (2019). arXiv preprint arXiv:1901.08247
21. Z.C. Lipton, J. Berkowitz, C. Elkan, A critical review of recurrent neural networks for sequence
learning (2015). arXiv preprint arXiv:1506.00019
22. S. Hochreiter, J. Schmidhuber, Long short-term memory. Neural Comput. 9(8), 1735–1780 (1997)
23. F. Immovilli, A. Bellini, R. Rubini et al., Diagnosis of bearing faults of induction machines by
vibration or current signals: a critical comparison. IEEE Trans. Ind. Appl. 46(4), 1350–1359
(2010)
24. Konstruktions-und Antriebstechnik (KAt)—Data Sets and Download (Universität Pader-
born), Mb.uni-paderborn.de (2019). https://mb.uni-paderborn.de/kat/forschung/datacenter/bearing-datacenter/data-sets-and-download/
25. C. Lessmeier, J.K. Kimotho, D. Zimmer, W. Sextro, Condition monitoring of bearing damage
in electromechanical drive systems by using motor current signals of electric motors: a bench-
mark data set for data-driven classification, in Proceedings of the European Conference of the
Prognostics and Health Management Society (2016), pp. 05–08
26. Z. Huo, Y. Zhang, P. Francq, L. Shu, J. Huang, Incipient fault diagnosis of roller bearing
using optimized wavelet transform based multi-speed vibration signatures. IEEE Access 5,
19442–19456 (2017)

27. X. Wang, Z. Lu, J. Wei, Y. Zhang, Fault diagnosis for rail vehicle axle-box bearings based on
energy feature reconstruction and composite multiscale permutation entropy. Entropy 21(9),
865 (2019)
28. S. Djaballah, K. Meftah, K. Khelil, M. Tedjini, L. Sedira, Detection and diagnosis of fault
bearing using wavelet packet transform and neural network. Frattura ed Integrità Strutturale
13(49), 291–301 (2019)
29. D.P. Kingma, J. Ba, Adam: a method for stochastic optimization (2014). arXiv preprint arXiv:
1412.6980
30. M.A. Wani, F.A. Bhat, S. Afzal, A.L. Khan, Advances in Deep Learning (Springer, 2020)
Automatic Solar Panel Detection
from High-Resolution Orthoimagery
Using Deep Learning Segmentation
Networks

Tahir Mujtaba and M. Arif Wani

Abstract Solar panel detection from aerial or satellite imagery is a very convenient
and economical technique for counting the number of solar panels on the rooftops
in a region or city and also for estimating the solar potential of the installed solar
panels. Detection of accurate shapes and sizes of solar panels is a prerequisite for
successful capacity and energy generation estimation from solar panels over a region
or a city. Such an approach is helpful for the government to build policies to inte-
grate solar panels installed at homes, offices, and buildings with the electric grids. This
study explores the use of various deep learning segmentation algorithms for auto-
matic solar panel detection from high-resolution ortho-rectified RGB imagery with
resolution of 0.3 m. We compare and evaluate the performance of six deep learning
segmentation networks in automatic detection of the distributed solar panel arrays
from satellite imagery. The networks are tested on real data and augmented data.
Results indicate that deep learning segmentation networks work well for automatic
solar panel detection from high-resolution orthoimagery.

1 Introduction

The world is exploring ways to utilize more renewable energy sources as the non-renewable energy sources are getting depleted. One of the most important and abundantly available renewable energy sources is solar energy. For utilizing solar energy, solar panels are installed on the ground and rooftops of buildings to convert solar energy into electric
energy. Throughout the world, there has been a surge in installing solar panels to get
the most out of this form of energy source. One of the challenges is to detect solar
panels installed on the ground and on buildings from aerial imagery. Accurate detection

T. Mujtaba (B) · M. A. Wani


Department of Computer Science, University of Kashmir, Srinagar, India
e-mail: mjtbatahir@gmail.com
M. A. Wani
e-mail: awani@uok.edu.in


of solar panels with accurate shapes and sizes is a prerequisite for determining the
capacity of energy generation from these panels. For this problem, semantic segmen-
tation is required to accurately detect the solar panels from the aerial imagery. For the
past several years, machine learning approaches have been used for many applica-
tions like classification, clustering, and image recognition [1–7], and some research
has been done to analyze and assess objects like roads, buildings, and vehicles
present in satellite imagery.
In the recent past, deep learning has been found to be more effective in image recog-
nition problems than the traditional machine learning techniques [8]. The image
recognition problems include tasks like image classification, object detection, and
semantic segmentation. Deep learning has outperformed the traditional machine
learning techniques in these image recognition tasks. Among these image recog-
nition tasks, the problem of semantic segmentation is one of the key research areas.
Semantic segmentation is the process of classifying each pixel of the image to a set
of predefined classes, therefore dividing an image into set of regions where each
region consists of pixels belonging to one common class. A number of traditional
machine learning techniques have been used in segmentation of images in the past.
The general procedure of traditional machine learning techniques is to use well-
established feature descriptors to extract features from an image and use a classifier
like support vector machines or random forest classifier on each pixel to determine
its likeliness of belonging to one of the predefined classes. Such techniques heavily
depend on the procedure that is used to extract features and usually need skilled
feature engineering for designing such procedures. It is necessary to determine which
features are important for an image. When the number of classes increases, feature extraction becomes a tedious task. Moreover, such techniques do not use one end-to-end approach; feature extraction and classification are done separately. In comparison, deep learning models are end-to-end models where both feature
extraction and classification are done through a single algorithm. Further, model
parameters are learnt automatically during the process of learning. Because of the
above reasons, this work has used the deep learning approach for the problem of
solar panel detection.
The deep learning segmentation process has evolved from converting a deep learning image classification model into a segmentation model [9]. The segmen-
tation process has been improved with the introduction of UNet [10] model. The
model essentially consists of an encoder and a decoder with skip connections from
lower layers of encoder to higher layers of decoder to transfer the learned features.
The purpose of encoder is to extract features that are up-sampled in the decoder and
features in the decoder are concatenated with the features extracted in the encoder.
Up-sampling is done by using transposed convolutions. An index-based up-sampling
technique has been introduced in [11]. The technique remembers the index of pixel
values with maximum intensity during the max-pooling operation and then uses the same indices during the up-sampling process. Dilated convolution that uses dilated filters has
been tested in [12–14]. Dilated convolution helps in reduction of learnable parame-
ters and it also helps in aggregating the context of the image. The aim of this work is

to analyze and compare the performance of various segmentation architectures for


the problem of solar panel detection from aerial imagery.
This chapter is organized as follows. Section 2 presents an overview of the
related work done on the problem of solar panel detection. Section 3 describes deep
learning-based segmentation. Section 4 discusses different deep learning segmenta-
tion networks used in this study. Section 5 presents results and discussion. Conclusion
is finally presented in Sect. 6.

2 Related Work

With the availability of high-resolution satellite images and the emergence of deep
learning networks, it has become possible to analyze the high-resolution images for
accurate object detection. Detection of the objects like buildings, roads, and vehicles
from aerial/satellite images has been studied more widely during the past many years
than the detection of solar panels. A few studies [15–17] have been reported in the
literature that explore detection of solar panels from satellite or orthoimagery using
deep learning techniques. VggNet with pretraining has been used in [16] to detect
the solar panels from the aerial imagery dataset given in [18]. The authors have used
a very basic VggNet architecture consisting of six convolutional layers and two fully
connected layers. The model is applied on every pixel to detect whether it belongs to
solar panel or not. Further processing is performed to connect contagious pixels and
declare these as regions. Though this model has been found effective in detecting
solar panels, it does not give the shape and size of panel arrays explicitly.
Authors in [17] proposed a fully convolutional neural network for solar panel
detection. It uses seven convolutional layers with different numbers of filters and filter sizes. It first extracts features from the image using several convolution and max-pooling layers; it then up-samples the extracted features, uses skip connections from the shallow layers, which contain fine-grained features, and concatenates them with the coarse layers. Although the work has been able to detect solar panels, the
standard metric used for segmentation is missing.
A deep learning segmentation architecture called SegNet has been used in [15]
for automatic detection of solar panels from ortho-rectified images given is [18].
Again it has used very small image patches of size 41 × 41 for training purposes.
These small patches may not contain much of the diverse background like buildings,
roads, trees, and vehicles, which may restrict the algorithm’s capability to perform
in diverse backgrounds. Further, each image in the dataset, which is of size 5000 ×
5000, would generate approximately 14,872 image patches of size 41 × 41 on which
training and testing are to be done. Very little work has been done on solar panel
detection using deep learning-based segmentation.

3 Deep Learning-Based Segmentation

Segmentation of an image involves classifying every pixel of an image into a given


set of classes. Unlike some applications where the aim is to classify the whole image
into one of the given classes, semantic segmentation makes dense predictions by
inferring labels to every pixel. The contemporary classification architectures like
VggNet [19], ResNet [20], and DenseNets [21] can be converted into architectures
that are suitable for segmentation. The segmentation process is shown in Fig. 1
where general encoder–decoder architecture for deep learning segmentation process
is shown.
A segmentation model generally consists of an encoder and a decoder. The encoder
is usually a pretrained classification network with its fully connected layers removed.
Its task is to extract features and produce a heatmap with low resolution. The task of
the decoder is to up-sample the heatmap successively to the original resolution of the
input. The deep learning segmentation algorithms differ in the way encoders extract
features and the way decoders perform up-sampling with different skip connection
strategies. The different encoding, decoding, and up-sampling techniques used in
deep learning-based segmentation are discussed below.

3.1 Encoding Techniques

The main purpose of the encoder is to extract features from the images through succes-
sive convolution and pooling layers. The encoding part usually comprises networks
like VggNet, ResNet, and DenseNet with their fully connected layers removed.

Fig. 1 General encoder–decoder architecture for deep learning segmentation process



a. VggNet

VggNet [19] has various variants and the prominent one is VggNet-16, which consists of 13 convolution layers, 5 max-pooling layers, and 3 fully connected layers. VggNet-19 is also prominent and consists of 16 convolutional layers, 5 max-pooling layers,
and 3 fully connected layers. VggNet uses smaller sized 3 × 3 convolution filters
as compared with the AlexNet which uses 7 × 7 convolution filters. Within a given
receptive field, using multiple number of 3 × 3 filters is better than using one larger
sized 7 × 7 filter as it involves less parameters and reduces computational effort.

b. ResNet

Residual Network (ResNet) [20] uses residual connections or identity shortcut


connections that skip one or more layers, which makes it possible to train deeper networks and increase performance as the depth of the network grows. The network consists of stacked residual blocks with 3 × 3 convolutions. Various variants of this network are ResNet-18, ResNet-34, ResNet-50, ResNet-101, and ResNet-152, consisting of 18, 34, 50, 101, and 152 layers, respectively. Such a network can be easily adopted in the segmentation
process with fully connected layers removed.
c. DenseNet

The main characteristic feature of DenseNet [21] is that its every layer is connected to
every other layer in a feedforward manner resulting in (L(L + 1))/2 direct connec-
tions where L is the number of layers in the network. For each and every layer,
the feature maps of all preceding layers are used as input and its own produced
feature maps are used in all subsequent layers. DenseNet alleviates vanishing gradient
problem, strengthens feature propagation, and encourages feature reuse and reduces
the number of parameters. The network consists of various dense blocks where a
layer is connected to every subsequent layer. Each layer in a dense block consists
of batch normalization, ReLU activation and 3 × 3 convolution. The layers between
two dense blocks are transitions layers consisting of batch normalization, 1 × 1
convolution and average pooling. Such a network can be easily adopted for
semantic segmentation.

3.2 Decoding Techniques

A decoder increases the image resolution by using up-sampling techniques. Some of


the decoding techniques commonly used in deep learning segmentation process are
summarized below.

a. Skip Connection-based Decoders

i. FCN Based decoding

FCN-based decoder [9] uses transpose convolution (deconvolution) for up-sampling.


The decoder has different variants depending on from which pooling layer a skip
connection is added: (a) FCN-32 has no skip connection, (b) FCN-16 uses one skip
connection from the fourth pooling layer, and (c) FCN-8 uses two skip connections
from the third and fourth pooling layers. After last convolution, FCN decoder uses
1 × 1 convolution layer and softmax function to classify image pixels. The number
of 1 × 1 convolutions equals the number of classes into which the image pixels are
to be categorized.
ii. UNet-Based decoding

The UNet-based decoder [10] uses transpose convolution for up-sampling followed
by convolutional layers and ReLU activation. UNet extends the concept of skip
connection used in the FCN decoder. Here every encoder is connected to its corre-
sponding decoder unit through a skip connection. The features learnt in the encoder
block are carried over and concatenated with the features of the decoder block.
b. Max-pooling index-based Decoder

Max-pooling index-based decoder [11] uses the indices stored during the max-
pooling process of the encoder to up-sample the feature maps in the corresponding
decoder during the decoding process. Unlike the FCN and UNet decoders, no features
are carried to the decoders through skip connections. The up-sampled features are
then convolved with trainable filters.
c. Un-pooling-based Decoder

Un-pooling-based decoder [22] uses un-pooling and deconvolutional layers in its


decoder blocks. The un-pooling reverses the pooling operation performed during
encoding. In the encoding process, it remembers the locations of the maximum activations during the max-pooling operation, and during decoding it uses the un-pooling operation
to restore the resolution of the activations. The un-pooling operation is performed
by using switch variables which remember the location of the maximum activations
during max pooling. After un-pooling, a deconvolutional layer is used to densify the
sparse feature maps produced by un-pooling.

3.3 Up-Sampling Techniques

The various up-sampling techniques that have been used to increase the resolution
of the feature maps in the decoder part are summarized below.

a. Nearest Neighbor
The nearest neighbor up-sampling technique simply assigns each output pixel the value of its nearest input pixel.
b. Bed of Nails
Bed of Nails places each input pixel value at a fixed position in the output, and the rest of the positions are filled with the value 0.
c. Bilinear up-sampling
It calculates a pixel value by interpolating the values of the known nearest pixels; unlike the nearest neighbor technique, the contribution of each nearby pixel matters here and is inversely proportional to its distance.
d. Max Un-pooling
It remembers the index of the maximum activation during the max-pooling oper-
ation and uses the same index to position the pixel value in the output during the
up-sampling.
e. Transpose Convolution
Transpose convolution is the most effective and most commonly used technique
for image up-sampling in deep learning semantic segmentation because it’s a
learnable up-sampling. The input is padded with zeros when convolution is
applied.
f. Dilated Convolutions
Also known as atrous convolution, it was first developed for the efficient computation of the undecimated wavelet transform. Dilated convolution is a normal
convolution with a wider kernel. The kernel in the convolution is exponentially
expanded to capture more context of the image without increasing the number
of parameters. A normal convolution is a dilated convolution with a dilation rate
equal to 1.
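Several of these operations are available as standard layers; the following is a minimal sketch assuming tensorflow.keras, where the input size and filter counts are illustrative assumptions.

from tensorflow.keras import Input, Model, layers

x = Input((56, 56, 64))

nearest = layers.UpSampling2D(size=2, interpolation="nearest")(x)    # nearest neighbor
bilinear = layers.UpSampling2D(size=2, interpolation="bilinear")(x)  # bilinear up-sampling
learned = layers.Conv2DTranspose(64, kernel_size=3, strides=2,
                                 padding="same")(x)                  # transpose (learnable) convolution
dilated = layers.Conv2D(64, kernel_size=3, dilation_rate=2,
                        padding="same")(x)                           # dilated (atrous) convolution

# The three up-sampled outputs are 112 x 112; the dilated output keeps the 56 x 56
# resolution but has a larger receptive field.
model = Model(x, [nearest, bilinear, learned, dilated])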

4 Deep Learning Segmentation Architectures Used

4.1 UNet

A fully convolutional segmentation architecture for medical image segmentation


reported in [10] is explored here for segmentation of solar panel images. A UNet
model essentially consists of an encoder and a decoder. The encoder part uses various
convolution and pooling operations to extract the features from the image. The output
of the encoder is a heatmap which serves as input to the decoder. The purpose of
the decoder is to up-sample the heatmap so that spatial dimensions match the input,
densify the segmentation, classify the pixels, and produce a segmentation map. The
decoder semantically projects the fine features learnt in the beginning layers into the
higher layers to produce a dense segmentation. The encoder and decoder architecture
forms a U-shaped structure that gives it the name UNet. The contracting path acquires

Fig. 2 UNet architecture used in this study

context and the expanding path facilitates accurate localization. Up-sampling in the decoder


is done using transpose convolutions. Architecture of UNet used in this study is
shown in Fig. 2.
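A minimal UNet-style sketch, assuming tensorflow.keras, is given below; the filter counts and depth are illustrative assumptions and do not reproduce the exact configuration of Fig. 2.

from tensorflow.keras import Input, Model, layers

def conv_block(x, filters):
    # Two 3 x 3 convolutions with ReLU, the standard UNet building block.
    x = layers.Conv2D(filters, 3, padding="same", activation="relu")(x)
    return layers.Conv2D(filters, 3, padding="same", activation="relu")(x)

def up_concat(x, skip, filters):
    # Transpose convolution up-samples; the skip connection concatenates the
    # corresponding encoder features with the decoder features.
    x = layers.Conv2DTranspose(filters, 2, strides=2, padding="same")(x)
    return conv_block(layers.Concatenate()([x, skip]), filters)

inputs = Input((224, 224, 3))
e1 = conv_block(inputs, 32)                                # encoder
e2 = conv_block(layers.MaxPooling2D()(e1), 64)
e3 = conv_block(layers.MaxPooling2D()(e2), 128)
b = conv_block(layers.MaxPooling2D()(e3), 256)             # bottleneck heatmap
d3 = up_concat(b, e3, 128)                                 # decoder with skip connections
d2 = up_concat(d3, e2, 64)
d1 = up_concat(d2, e1, 32)
outputs = layers.Conv2D(1, 1, activation="sigmoid")(d1)    # binary solar panel mask
model = Model(inputs, outputs)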

4.2 SegNet

The segmentation architecture SegNet [11] for scene understanding applications that
is efficient in terms of memory and computational time is explored here for automatic
detection of solar panels from satellite images. The SegNet architecture consists of
an encoder and decoder like UNet but differs in how up-sampling is done in the
decoder part. The deconvolutional layers used for up-sampling in decoder part of
UNet are time and memory consuming because up-sampling is performed using a
learnable model, which implies filters for up-sampling are learned during the training
process. The SegNet architecture replaces the learnable up-sampling by computing
and memorizing the max-pool indices and later uses these indices to up-sample the
features in the corresponding decoder block to produce sparse feature maps. It then
uses normal convolution with trainable filters to densify these sparse feature maps.
However, there are no skip connections for feature transfer like in UNet. The use of
max-pool indices results in reduced number of model parameters which eventually
takes less time to get trained. The architecture of SegNet used in this study is given in
Fig. 3. The max-pooling index concept is illustrated in Fig. 4 which also distinguishes
between the process of up-sampling used in UNet and SegNet architectures.

Fig. 3 SegNet architecture used in this study

Fig. 4 Decoders used in FCN and SegNet



Fig. 5 Multicontext aggregation architecture (Dilated Net) used in this study

4.3 Dilated Net

The use of dilated convolution for context aggregation in semantic segmentation


discussed in [13] is explored for automatic detection of solar panels in satellite
images. The feature context aggregation is done without losing resolution or using
rescaled versions of images. It introduces a context module consisting of 7 layers
with 3 × 3 convolution and dilated convolutions applied at different rates—1, 1, 2, 4,
8, 16, 1. The last convolution is a 1 × 1 × C convolution to produce the final output
of the module. The architecture uses two types of context modules—a basic module
and a large context module. The basic module contains the same number of channels (C) throughout the module, whereas the large context module contains an increasing number of channels (C) as input. The architecture introduces another module known as the front module, which is constructed over a VggNet by removing the last two pooling and striding layers and adding dilated convolutions in the layers that follow. Finally,
it adds a context module to the front module for dense semantic prediction and
contextualization. The architecture is shown in Fig. 5. A 2D dilated convolution with
different dilation rates is shown in Fig. 6.

4.4 PSPNet

PSPNet [23], which captures global context information by introducing a pyramid pooling module for better classification of small objects, is explored here for automatic detection of
solar panels in satellite images. Small objects are hard to find but have a great impor-
tance in overall scene categorization. The pyramid pooling module gathers global
context information along with sub-region context for categorization of different

Fig. 6 Dilated convolution in 2D with different dilation rates. a Dilation rate = 1. b Dilation rate
= 2, c dilation rate = 3

objects. The pyramid pooling module uses four different pooling operations: 1 × 1,
2 × 2, 3 × 3, and 6 × 6. The pooling operations generate feature maps of different
sub-regions and form pooled representation for different locations. The output feature
maps from these pooling operations are of varied sizes. It then uses bilinear inter-
polation for up-sampling these features to a size of input resolution. The number of
pyramid levels and size of each level can be modified. The architecture of PSPNet
is shown in Fig. 7.
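A minimal sketch of the pyramid pooling module, assuming tensorflow.keras, is given below; the per-branch channel count is an illustrative assumption, and the input spatial size is assumed to be divisible by each bin size, while the 1 × 1, 2 × 2, 3 × 3, and 6 × 6 bins and the bilinear up-sampling follow the description above.

from tensorflow.keras import Input, Model, layers

def pyramid_pooling(feature_map, bins=(1, 2, 3, 6), channels=64):
    h, w = feature_map.shape[1], feature_map.shape[2]
    branches = [feature_map]
    for b in bins:
        # Pool the feature map into a b x b grid of sub-regions.
        x = layers.AveragePooling2D(pool_size=(h // b, w // b))(feature_map)
        x = layers.Conv2D(channels, 1, activation="relu")(x)
        # Bilinearly up-sample each pooled map back to the input resolution.
        x = layers.UpSampling2D(size=(h // b, w // b), interpolation="bilinear")(x)
        branches.append(x)
    # Concatenate global and sub-region context with the original features.
    return layers.Concatenate()(branches)

inputs = Input((36, 36, 256))          # e.g. the encoder output; 36 is divisible by 1, 2, 3, and 6
context = pyramid_pooling(inputs)
model = Model(inputs, context)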

Fig. 7 PSPNet Architecture used in this study



Fig. 8 Deep Lab v3+ Architecture used in this study

4.5 DeepLab v3+

DeepLab v3+ [12], which makes use of an encoder–decoder structure for dense semantic segmentation, is explored here for automatic detection of solar panels in satellite images.
The encoder–decoder structure has two advantages: (i) it is capable of encoding
multi-scale contextual information by probing the incoming features with filters
or pooling operations at multiple rates and multiple effective fields-of-view, (ii) it
can capture sharper object boundaries by gradually recovering the spatial informa-
tion through the use of skip connections. It has an additional simple and effective
decoder module to refine the segmentation results especially along object boundaries.
It further applies depth-wise separable convolution to both Atrous Spatial Pyramid
Pooling and decoder modules, resulting in a faster and stronger encoder–decoder
network. The detailed encoder and decoder structure is given in Fig. 8.

4.6 Dilated Residual Network

The Dilated Residual Network [14], which uses dilated convolutions in a Residual Network for classification and segmentation tasks, is explored here for automatic detection of solar panels in satellite images. In convolutional networks, the spatial size of feature maps
gets continuously reduced due to multiple use of pooling and striding operations.
Such a loss in spatial structure limits the model’s ability to produce good results in

Fig. 9 Dilated residual network architecture used in this study

classification and segmentation tasks. Dilated Residual Network introduces dilated


convolution in a residual network to increase the receptive field of the feature maps
without increasing the parameters. Dilated Residual Network also develops an effec-
tive dilation strategy for dealing with the gridding pattern problem that occurs due to
increase of dilation rates at successive layers. A residual network is converted into
dilated residual network by removing the striding in the first layers of block 4 and block 5, introducing a dilation rate of 2 in the rest of the layers of block 4 and the first layer of block 5, and a dilation rate of 4 in the rest of the layers of block 5. Predictions are
produced by 1 × 1 × C layer where C is the number of classes. The feature responses
so produced have a resolution of 1/8 of the original image resolution and are bilinearly
up-sampled to get the same resolution as input image. A Dilated residual network is
shown in Fig. 9.
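A minimal sketch of a residual block with dilated convolutions, in the spirit of the dilation strategy described above and assuming tensorflow.keras, is given below; the filter count and dilation rate are illustrative assumptions, and the input is assumed to already have the same number of channels as the block.

from tensorflow.keras import layers

def dilated_residual_block(x, filters=256, dilation=2):
    # Two 3 x 3 convolutions with the given dilation rate keep the spatial
    # resolution (no striding) while enlarging the receptive field.
    y = layers.Conv2D(filters, 3, padding="same", dilation_rate=dilation)(x)
    y = layers.BatchNormalization()(y)
    y = layers.Activation("relu")(y)
    y = layers.Conv2D(filters, 3, padding="same", dilation_rate=dilation)(y)
    y = layers.BatchNormalization()(y)
    # Identity shortcut connection of the residual network.
    y = layers.Add()([x, y])
    return layers.Activation("relu")(y)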

5 Results and Discussion

5.1 Dataset Used

The purpose of this study is to detect the location of solar panels in satellite images
of buildings and earth’s surface using deep learning segmentation techniques. The
training step of a segmentation process requires a dataset containing both the images
as well as their corresponding masks which are used as ground truth. The masks
highlight the pixels which correspond to the solar panels. As the dataset containing
both the images and corresponding masks is not publicly available, it was decided to

use the dataset described in [18] and prepare masks for this dataset before training
the models. The dataset has 601 TIF orthoimages of four cities of California and
contains geospatial coordinates and vertices of about 19,000 solar panels spread
across all images. This work has used images of Fresno city. Each image is 5000-
by-5000 pixels and covers an area of 2.25 km2 . The images are of urban, suburban,
and rural landscape type, allowing the model to get trained on diverse images. The
vertices of the solar panels associated with each image have been utilized to create
polygon areas of white pixels corresponding to solar panels and setting the remaining
pixels as black pixels to represent background. Figure 10 shows a sample image of
size 5000-by-5000 pixels and its sub-image of size 224-by-224 pixels.
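A minimal sketch of how such masks can be rasterized from the polygon vertices, assuming Pillow and NumPy, is given below; the variable names and the per-image list of vertex lists are assumptions, only the white-panel/black-background convention follows the description above.

import numpy as np
from PIL import Image, ImageDraw

def make_mask(polygons, size=(5000, 5000)):
    # polygons: list of [(x1, y1), (x2, y2), ...] vertex lists for one image (assumed format).
    mask = Image.new("L", size, 0)                 # black background
    draw = ImageDraw.Draw(mask)
    for vertices in polygons:
        draw.polygon(vertices, fill=255)           # white pixels inside each solar panel polygon
    return np.array(mask)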
To make the models robust, data augmentation has been performed to reduce over-
fitting and improve generalization. The augmentation was achieved by performing
horizontal flip and vertical flip on images. These two augmentations proved useful
in training the models and increasing the segmentation accuracy.
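The flip augmentation can be sketched as follows, assuming NumPy arrays of image crops and masks; the array names and shapes are assumptions, only the horizontal and vertical flips follow the text.

import numpy as np

def augment(images, masks):
    # images, masks: arrays of shape (N, 224, 224, C); each flip is applied to the
    # image and its mask together so that the ground truth stays aligned.
    h_img, h_msk = images[:, :, ::-1], masks[:, :, ::-1]     # horizontal flip
    v_img, v_msk = images[:, ::-1, :], masks[:, ::-1, :]     # vertical flip
    return (np.concatenate([images, h_img, v_img]),
            np.concatenate([masks, h_msk, v_msk]))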

Fig. 10 First row shows an image and its mask of size 5000 × 5000. Second row shows an image
with its mask of size 224 × 224

5.2 Training

All the architectures have been implemented in Python and trained on workstation
with Nvidia Tesla K40 (12 GB) GPU and 30 GB RAM. As image size of 5000-by-
5000 is huge for training purposes and needs a high GPU and RAM configuration, the
images and their corresponding masks have been cropped to size 224 × 224. A total
of 1118 cropped images were selected for augmentation and training. The testing
was done on image crops of the same size. Adam’s learning algorithm with a fixed
learning rate of 1 × 10−5 for ResNet-based models and 1 × 10−4 for VggNet-based
models has been used. Training was done from scratch without any pretraining or
transfer learning. An early stopping criterion of 15 epochs has been used to stop the model training.

5.3 Performance Metric and Loss Function

The performance measure metric used in this study is the dice coefficient (F1 score), which is one of the most widely used metrics in segmentation. This metric is used to quantify how closely the ground truth annotated segmentation region matches the predicted segmentation region of the model. It is defined as twice the intersection (overlap) of the two regions divided by the total size of the two regions. Given two sets of pixels denoted by X and Y, the dice coefficient is given by:

DC = (2 ∗ |X ∩ Y|)/(|X| + |Y|)    (1)

The value of dice coefficient ranges from 0 to 1. Value close to 1 means more
overlap and similarity between the two regions, hence more accurate the predicted
segmentation from the model.
The loss function used in this study is the dice loss (DL), which is derived from the dice coefficient and was used in [24]; it is defined as

DL = 1 − DC

where DL is the dice loss and DC is the dice coefficient defined above in (1). The
dice loss is used to optimize the value of DC during the training process.
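A minimal implementation of the dice coefficient of (1) and the dice loss, assuming tensorflow.keras backend functions, is given below; the smoothing term is a common assumption to avoid division by zero and is not part of the definitions above.

from tensorflow.keras import backend as K

def dice_coef(y_true, y_pred, smooth=1.0):
    y_true_f = K.flatten(y_true)
    y_pred_f = K.flatten(y_pred)
    intersection = K.sum(y_true_f * y_pred_f)
    # DC = 2|X ∩ Y| / (|X| + |Y|), as in Eq. (1).
    return (2.0 * intersection + smooth) / (K.sum(y_true_f) + K.sum(y_pred_f) + smooth)

def dice_loss(y_true, y_pred):
    return 1.0 - dice_coef(y_true, y_pred)      # DL = 1 - DC

# A model would then be compiled so that training minimizes DL and thereby
# maximizes DC (Adam learning rates as in Sect. 5.2):
# model.compile(optimizer="adam", loss=dice_loss, metrics=[dice_coef])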

5.4 Experimental Results

The experimental results have been obtained by using augmented as well as original
datasets. The dice coefficients and loss results of training, validation, and testing of
augmented and original datasets are reported here.

Table 1 Results on UNet model


Model Dataset type Training results Validation results Testing results
DC Loss DC Loss DC Loss
UNet Augmented 0.8826 0.1173 0.8867 0.1132 0.8963 0.1036
Original 0.8750 0.124 0.8819 0.1181 0.8943 0.1056

Table 1 shows dice coefficient and loss results of training, validation, and testing
augmented and original datasets on UNet model. It can be seen from Table 1 that
augmented dataset has helped in improving training, validation, and testing accuracy
results. Figure 11 shows the dice coefficient bar graphs of training, validation, and
testing augmented and original datasets on UNet model. As can be observed from
Fig. 11, the data augmentation has produced better DC values for training, validation,
and testing of datasets on UNet model.
Table 2 shows dice coefficient and loss results of training, validation, and testing
augmented and original datasets on SegNet model. It can be seen from Table 2 that
augmented dataset has helped in improving training, validation, and testing accuracy
results. The training, validation, and testing results of dice coefficients for augmented
datasets on SegNet model improves by a margin of about 9–10% when compared
with the results of the original dataset on the SegNet model. Figure 12 shows the dice
coefficient bar graphs of training, validation, and testing augmented and original
datasets on SegNet model. As can be observed from Fig. 12, the data augmentation

Fig. 11 Dice Coefficients of augmented and original datasets on UNet model. a Results of training
process. b Results of validation process. c Results of testing process

Table 2 Results on SegNet model


Model Dataset type Training results Validation results Testing results
DC Loss DC Loss DC Loss
SegNet Augmented 0.8167 0.1832 0.7775 0.2224 0.7471 0.2528
Original 0.7102 0.2897 0.6828 0.3171 0.6425 0.3574

Fig. 12 Dice Coefficients of augmented and original datasets on SegNet model. a Results of training
process. b Results of validation process. c Results of testing process

has produced better DC values for training, validation, and testing of datasets on
SegNet model.
Table 3 shows dice coefficient and loss results of training, validation, and testing
augmented and original datasets on Dilated Net model. It can be seen from Table 3
that augmented dataset has helped in improving training, validation, and testing
accuracy results. The training, validation, and testing results of dice coefficients for
augmented datasets on Dilated Net model improves by a margin of about 1–2% when
compared with the results of the original dataset on the Dilated Net model. Figure 13 shows the dice coefficient bar graphs of training, validation, and testing augmented
and original datasets on Dilated Net model. As can be observed from Fig. 13, the

Table 3 Results on Dilated net


Model Dataset type Training results Validation results Testing results
DC Loss DC Loss DC Loss
Dilated net Augmented 0.6956 0.3043 0.6591 0.3408 0.6732 0.3267
Original 0.6862 0.3137 0.6465 0.3534 0.6615 0.3384

Fig. 13 Dice coefficients of augmented and original datasets on dilated net model. a Results of
training process. b Results of validation process. c Results of testing process

Table 4 Results on PSPNet


Model Dataset type Training results Validation results Testing results
DC Loss DC Loss DC Loss
PSPNet Augmented 0.6025 0.397 0.5527 0.4472 0.5091 0.4908
Original 0.6122 0.3877 0.4102 0.5897 0.4181 0.5818

data augmentation has produced better DC values for training, validation, and testing
of datasets on Dilated Net model.
Table 4 shows dice coefficient and loss results of training, validation, and testing
augmented and original datasets on PSPNet model. It can be seen from Table 4 that
augmented dataset has helped in improving validation and testing accuracy results.
The validation and testing results of dice coefficients for augmented datasets on
PSPNet model improves by a margin of about 9–15% when compared with the results
of original dataset on PSPNet model. However, training results of dice coefficients
for augmented datasets on PSPNet model decreases. This implies more epochs are
required to train the PSPNet with larger datasets. Figure 14 shows the dice
coefficient bar graphs of training, validation, and testing augmented and original
datasets on PSPNet model. As can be observed from Fig. 14, the data augmentation
has produced better DC values for validation, and testing of datasets on PSPNet
model.

Fig. 14 Dice coefficients of augmented and original datasets on PSPNet model. a Results of training
process. b Results of validation process. c Results of testing process

Table 5 Results on DeepLab v3+ model


Model Dataset type Training results Validation results Testing results
DC Loss DC Loss DC Loss
DeepLab v3+ Augmented 0.7572 0.2427 0.6410 0.3589 0.6610 0.3389
Original 0.7877 0.2122 0.5654 0.4345 0.5713 0.4286

Table 5 shows dice coefficient and loss results of training, validation, and testing
augmented and original datasets on DeepLab v3+ model. It can be seen from Table 5
that augmented dataset has helped in improving validation and testing accuracy
results. The validation and testing results of dice coefficients for augmented datasets
on the DeepLab v3+ model improves by a margin of about 7–9% when compared with
the results of original dataset on DeepLab v3+ model. However, training results of
dice coefficients for augmented datasets on DeepLab v3+ model decreases. This
implies more epochs are required to train the DeepLab v3+ with larger datasets.
Figure 15 shows the dice coefficient bar graphs of training, validation, and testing
augmented and original datasets on DeepLab v3+ model. As can be observed from
Fig. 15, the data augmentation has produced better DC values for validation, and
testing of datasets on DeepLab v3+ model.
Table 6 shows dice coefficient and loss results of training, validation, and testing
augmented and original datasets on Dilated ResNet model. It can be seen from Table 6
that augmented dataset has helped in improving validation and testing accuracy
results. The validation and testing results of dice coefficients for augmented datasets
on Dilated ResNet model improves by a margin of about 5–9% when compared with
the results of original dataset on Dilated ResNet model. However, training results of
dice coefficients for augmented datasets on Dilated ResNet model decreases. This
implies more epochs are required to train the Dilated ResNet with larger datasets.
Figure 16 shows the dice coefficient bar graphs of training, validation, and testing
augmented and original datasets on Dilated ResNet model. As can be observed from

Fig. 15 Dice Coefficients of augmented and original datasets on DeepLab v3+ model. a Results
of training process. b Results of validation process. c Results of testing process

Table 6 Results on Dilated ResNet model


Model Dataset type Training results Validation results Testing results
DC Loss DC Loss DC Loss
Dilated ResNet Augmented 0.7164 0.2835 0.6766 0.3233 0.6307 0.3692
Original 0.7498 0.2501 0.6203 0.3796 0.5434 0.4565

Fig. 16 Dice Coefficients of augmented and original datasets on Dilated ResNet model. a Results
of training process. b Results of validation process. c Results of testing process

Fig. 16, the data augmentation has produced better DC values for validation, and
testing of datasets on Dilated Resnet model.
Dice coefficient results of testing augmented and original datasets on all the six
models have been summarized in Fig. 17 in the form of bar graphs. The bar graphs
indicate that the UNet model produces the best value of DC, implying that the best
segmentation accuracy results are produced by UNet model, followed by SegNet and
DilatedNet models.

Fig. 17 DC values of testing augmented and original datasets on all the six models

6 Conclusion

This work described the automatic detection of solar panels from satellite imagery by
using deep learning segmentation models. The study thoroughly discussed various state-of-the-art deep learning segmentation architectures and the encoding, decoding, and up-sampling techniques used in the deep learning segmentation process. The six archi-
tectures for automatic detection of solar panels used were UNet, SegNet, Dilated Net,
PSPNet, DeepLab v3+, and Dilated Residual Net. The dataset comprised satellite
images of four cities of California. Image size of 224 × 224 was used for training the
models. The results showed that the UNet deep learning architecture that uses skip
connections with encoder and decoder modules produced the best segmentation accu-
racy results. Moreover, dataset augmentation helped to improve the segmentation
accuracy results further.

References

1. M.A. Wani, Incremental hybrid approach for microarray classification, in 2008 Seventh
International Conference on Machine Learning and Applications (IEEE, 2008), pp. 514–520
2. M.A. Wani, R. Riyaz, A new cluster validity index using maximum cluster spread based
compactness measure. Int. J. Intell. Comput. Cybern. (2016)
3. M.A. Wani, R. Riyaz, A novel point density based validity index for clustering gene expression
datasets. Int. J. Data Mining Bioinf. 17(1), 66–84 (2017)
4. R. Riyaz, M.A. Wani, Local and global data spread based index for determining number of
clusters in a dataset, in 2016 15th IEEE International Conference on Machine Learning and
Applications (ICMLA) (IEEE, 2016), pp. 651–656
5. F.A. Bhat, M.A. Wani, Performance comparison of major classical face recognition tech-
niques, in 2014 13th International Conference on Machine Learning and Applications (IEEE,
2014), pp. 521–528
6. M.A. Wani, M. Yesilbudak, Recognition of wind speed patterns using multi-scale subspace
grids with decision trees. Int. J. Renew. Res. (IJRER) 3(2), 458–462 (2013)
7. M.R. Wani, M.A. Wani, R. Riyaz, Cluster based approach for mining patterns to predict wind
speed, in 2016 IEEE International Conference on Renewable Energy Research and Applications
(ICRERA) (IEEE, 2016), pp. 1046–1050
8. M.A. Wani, F.A. Bhat, S. Afzal, A.L. Khan, Advances in Deep Learning (Springer, 2020)
9. J. Long, E. Shelhamer, T. Darrell, Fully convolutional networks for semantic segmentation,
in Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, pp. 3431–
3440 (2015)
10. O. Ronneberger, P. Fischer, T. Brox, U-net: convolutional networks for biomedical image
segmentation, in International Conference on Medical Image Computing and Computer-
Assisted Intervention (Springer, Cham, 2015), pp. 234–241
11. V. Badrinarayanan, A. Kendall, R. Cipolla, Segnet: a deep convolutional encoder-decoder
architecture for image segmentation. IEEE Trans. Pattern Anal. Mach. Intell. 39(12), 2481–
2495 (2017)
12. L.C. Chen, Y. Zhu, G. Papandreou, F. Schroff, H. Adam, Encoder-decoder with atrous separable
convolution for semantic image segmentation, in Proceedings of the European Conference on
Computer Vision (ECCV) (2018), pp. 801–818
13. F. Yu, V. Koltun, Multi-scale context aggregation by dilated convolutions, in ICLR (2016)

14. F. Yu, V. Koltun, T. Funkhouser, Dilated residual networks, in Proceedings of the IEEE
Conference on Computer vision and Pattern Recognition (2017), pp. 472–480
15. J. Camilo, R. Wang, L.M. Collins, K. Bradbury, J.M. Malof, Application of a semantic segmen-
tation convolutional neural network for accurate automatic detection and mapping of solar
photovoltaic arrays in aerial imagery (2018). arXiv preprint arXiv:1801.04018
16. J.M. Malof, L.M. Collins, K. Bradbury, A deep convolutional neural network, with pre-training,
for solar photovoltaic array detection in aerial imagery, in 2017 IEEE International Geoscience
and Remote Sensing Symposium (IGARSS) (IEEE, 2017), pp. 874–877
17. J. Yuan, H.H.L. Yang, O.A. Omitaomu, B.L. Bhaduri, Large-scale solar panel mapping from
aerial images using deep convolutional networks, in 2016 IEEE International Conference on
Big Data (Big Data) (IEEE, 2016), pp. 2703–2708
18. K. Bradbury, R. Saboo, T.L. Johnson, J.M. Malof, A. Devarajan, W. Zhang, R.G. Newell,
Distributed solar photovoltaic array location and extent dataset for remote sensing object
identification. Sci. Data 3, 160106 (2016)
19. K. Simonyan, A. Zisserman, Very deep convolutional networks for large-scale image
recognition (2014). arXiv preprint arXiv:1409.1556
20. K. He, X. Zhang, S. Ren, J. Sun, Deep residual learning for image recognition, in Proceedings
of the IEEE Conference on Computer Vision and Pattern Recognition (2016), pp. 770–778
21. G. Huang, Z. Liu, L. Van Der Maaten, K.Q. Weinberger, Densely connected convolutional
networks, in Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition
(2017), pp. 4700–4708
22. H. Noh, S. Hong, B. Han, Learning deconvolution network for semantic segmentation,
in Proceedings of the IEEE International Conference on Computer Vision (2015), pp. 1520–
1528
23. H. Zhao, J. Shi, X. Qi, X. Wang, J. Jia, Pyramid scene parsing network, in Proceedings of the
IEEE International Conference on Computer Vision and Pattern Recognition (Honolulu, HI,
USA, 2017), pp. 2881–2890
24. F. Milletari, N. Navab, S.A. Ahmadi, V-net: fully convolutional neural networks for volumetric
medical image segmentation, in 2016 Fourth International Conference on 3D Vision (3DV)
(IEEE, 2016), pp. 565–571
Training Deep Learning Sequence
Models to Understand Driver Behavior

Shokoufeh Monjezi Kouchak and Ashraf Gaffar

Abstract Driver distraction is one of the leading causes of fatal car accidents in the
U.S. Analyzing driver behavior using machine learning and deep learning models
is an emerging solution to detect abnormal behavior and alarm the driver. Models
with memory such as LSTM networks outperform memoryless models in car safety
applications since driving is a continuous task and considering information in the
sequence of driving data can increase the model’s performance. In this work, we used
time-sequenced driving data that we collected in eight driving contexts to measure
the driver distraction. Our model is also capable of detecting the type of behavior that
caused distraction. We used the driver interaction with the car infotainment system as
the distracting activity. A multilayer neural network (MLP) was used as the baseline
and two types of LSTM networks including the LSTM model with attention network
and the encoder–decoder model with attention were built and trained to analyze the
effect of memory and attention on the computational expense and performance of the
model. We compare the performance of these two complex networks to that of the
MLP in estimating driver behavior. We show that our encoder–decoder with attention model outperforms the LSTM model with attention, while using LSTM networks with attention enhanced the training process compared to the MLP network.

Keywords Bidirectional · LSTM network · Attention network · Driver


distraction · Deep learning · Encoder–decoder

S. M. Kouchak · A. Gaffar (B)


Arizona State University, Tempe, USA
e-mail: agaffar@asu.edu
S. M. Kouchak
e-mail: smonjezi@asu.edu


1 Introduction

Driver distraction is one of the leading causes of fatal car accidents in the U.S.,
and, while it is becoming an epidemic, it is preventable [1]. Based on the NHTSA
report, 3,166 people died in fatal car accidents caused by distracted drivers
in 2017, which constitutes 8.5% of all fatalities [2, 3]. Distracted driving is any task
that diverts the driver's attention from the primary task of driving, such as texting,
reading, talking to other passengers, or using the car infotainment system [4, 5].
Texting is the most distracting task since the driver needs to take their eyes off the
road for an estimated five seconds to read or send a text. At 45 mph, that is enough
time to cover a distance equal to the length of a football field without seeing
the road [4].
Driver distraction is considered a failure in driving task prioritization: the driver puts
more attention on secondary tasks, such as reading a text or tuning the radio, while
ignoring the primary task of driving. There are four types of distraction.
Visual distraction is any task that makes the driver take their eyes off the road.
Manual distraction happens when the driver takes their hands off the steering wheel
[6].
Cognitive distraction is any task that diverts the driver's attention from the primary task
of driving [7].
Audio distraction is any noise that obscures important sounds inside the car, such as
alarms, or outside it, such as ambulance sirens [8].
Designing a car-friendly user interface [9], advanced driver assistant systems
(ADAS) [10–12], and autonomous vehicles (AV) [13] are some approaches to solve
the problem. Driver assistant systems such as lane keeping and adaptive cruise control
reduce human error by automating difficult or repetitive tasks. Although they can
enhance driver safety, they have some limitations as they often work only in prede-
fined situations. For instance, collision mitigation systems are based on algorithms
that have been defined by systems developers to cover specific situations that could
lead to a collision in traffic. These algorithms can’t identify a full range of dangerous
situations that could lead to an accident. They typically detect and prevent only a
limited number of predefined menace situations [14]. An autonomous vehicle can
sense its environment using a variety of sensors such as radar and LIDAR and move
the car with limited or no human contribution. Autonomous vehicles still face
challenges such as dealing with human-driven cars, unsuitable road infrastructure,
and cybersecurity. In the long run, when these challenges are solved, autonomous
vehicles could boost car safety and decrease commuting time and cost [15].
Monitoring and analyzing driver behavior using machine learning methods to
detect abnormal behavior is an emerging solution to enhance car safety [16]. Human
error, which can increase due to the mental workload induced by distraction, is one
of the leading causes of car crashes. Workloads are hard to observe and quantify,
so analyzing driver behavior to distinguish normal from aggressive driving is a
suitable approach to detect driver distraction and alarm the driver
to take control of the car for a short time [17]. Driving data and driver status,
which can be collected from internal sources (like the vehicle's onboard computer and
the data bus) as well as external devices such as cameras, can be used to train a
variety of machine learning methods such as Markov models, neural networks, and
decision trees to learn driving patterns and detect abnormal driver behavior [18–22].
Deep learning methods such as convolutional neural networks, LSTM networks, and
encoder–decoders outperform other machine learning methods in car safety applications
such as pedestrian detection and driver behavior classification [23–26]. Some
machine learning and deep learning methods make an instantaneous decision based only
on the current inputs of the model. In some real-world applications like driving, each
data sample affects the next several samples, so adding a temporal dimension,
memory, and attention can extract more behavior-representative information
from a sequence of data, which could improve the accuracy of these applications.
In this work, we use three models including a multilayer neural network (MLP),
an LSTM network with attention layer, and an encoder–decoder model to predict the
driver status using only driving data. No intrusive devices such as cameras or driver-
wearable sensors are used. This makes our approach more user-friendly and increases
its potential to be adopted in industry. The MLP model was considered
as the baseline and memoryless model. Two other models consider the weighted
dependency between input and output [27, 28]. This provides instant feedback on the
progress of the network model, which increases its prediction accuracy and reduces
its training time. We started with a single-input–single-output neural network and
compared the accuracy and training process of this model with an LSTM model with
attention, which is multiple-input–single-output and has both memory and attention.
The LSTM model with attention achieved less train and test error in a smaller number
of training epochs. The average number of training epochs is 400 for MLP and 100
for LSTM attention. Besides, the LSTM attention model has a smaller number of
layers. After that, we used the encoder–decoder with attention to estimate a sequence
of driver behavior vectors using a multiple-input–multiple-output model. We compared the
achieved results with those of the other two models. The train and test error of this model is
lower than that of both, and it can estimate multiple driver behavior vectors in one run.
Section 2 discusses some related works. Section 3 describes the experiment.
Section 4 explains the features of the collected data in the experiment. We discuss
methodology in Sect. 5. Section 6 talks about the details of the three models. Section 7
is the results and Sect. 8 is the conclusion.

2 Related Works

Wöllmer et al. [29] discussed a driver distraction detection method. The goal of this
method is to model head tracking and the context of driving data using an LSTM
network. An experiment was conducted with 30 volunteers who drove an Audi on a
straight road in Germany with one lane in each direction and a 100 km/h speed limit.
Distracted driving data were collected by an interface to measure the vehicle’s CAN-
Bus data and a head tracking system installed in the car cockpit. Eight distracting
tasks were chosen including radio, CD, phone book, navigation point of interest,
phone, navigation, TV, and navigation sound. Distracting functions were available
through eight hard keys, which were located on the left and right sides of the interface.
In all, 53 non-distracted and 220 distracted runs were done. Collected data from the
CAN-Bus and head tracking system were fed to an LSTM network to predict the
driver's state continuously. The accuracy of the model reached 96.6%. Although this
approach has high accuracy, it needs external devices that are not available in
all cars and are often considered intrusive by drivers. Additionally, it considers
only one driving context and a straight road. More work is needed to test how well the
model generalizes to complex driving contexts; its performance might degrade when
other types of roads and driving conditions are introduced, as these would produce
more complex contextual data and hence a larger number of patterns to
differentiate between.
Xu et al. [30] introduced an image captioning model with attention mechanisms.
It is inspired by the language translation model. The input of the model is a raw
image and it produces a caption, which describes the input image. In this model,
the encoder extracts features of the image using the middle layers instead of using a
fully connected layer. The decoder is an LSTM network that produces the caption as
a sequence of words. The model shows good results on three benchmark datasets
using the METEOR and BLEU metrics.
Xiao et al. [31] discussed an image classification model. It uses a two-level
attention model for fine-grained image classification in deep convolutional neural
networks. Fine-grained classification is detecting subordinate level categories under
some basic categories. This model is based on an intuition that for fine-grained clas-
sification, first the object needs to be detected and in the next step the discriminative
parts of the object should be detected. At the object detection level, the model uses raw
images and removes noise to detect the objects and classify them. At the second level,
it filters the object using mid-level CNN filters to detect parts of the image, and an
SVM classifier is used to classify the image's parts. The model was validated on a
subset of the ILSVRC2012 dataset and on the CUB200-2011 dataset and showed
good performance under the weakest supervision condition.
Huang et al. [32] proposed a machine translation model with attention to translate
image captions from English to German. Additional information from the image is
considered by the model to solve the problem of ambiguity in languages. A convolu-
tional neural network is used to extract the image features, and the model adds these
features to the text features to enhance the performance of the LSTM network that is
used to generate the caption in the target language. Regional features of the image are
used instead of general features. In sum, the best performance of the model shows a
2% improvement in the BLEU score and a 2.3% improvement in the METEOR score
compared to models that only consider text information.
Lv et al. [33] introduced a deep learning model for traffic flow prediction that
considers both temporal and spatial correlations inherently. The model used a stacked
autoencoder to detect and learn features of traffic flow. The traffic data was collected
from 15000 detectors in the freeway system across California during weekdays in
the first three months of 2013. Collected data in the first two months were used
as the training data and the test dataset was the collected data in the third month.
The model outperformed previous models in medium and high traffic, but it didn't
perform well in low traffic. The authors compared the performance of the model with
the Random Walk Forecast method, Support Vector Machine, Backpropagation Neural
Network, and Radial Basis Function. For 15-min traffic flow prediction, the stacked
autoencoder model reached more than 90% accuracy on 86% of the highways and
outperformed the four shallow models.
Saleh et al. [34] discussed a novel method for driver behavior classification using
a stacked LSTM network. Nine sensor streams were captured using the internal sensors
of a cellphone during realistic driving sessions. Three driving behaviors, normal,
drowsy, and aggressive, were defined, and the driver behavior classification problem was
modeled as time series classification. A sequence of driving feature vectors was used
as the input of a stacked LSTM network, and the network classified the driver behavior
accurately. Besides, the model achieved better results than the baseline approach on
UAH-DriveSet, a naturalistic driver behavior dataset. This model was compared with
other common driver behavior classification methods including decision tree (DT) and
multilayer perceptron (MLP). DT and MLP achieved 51% and 75% accuracy,
respectively, while the proposed method outperformed them with 86% accuracy.
These works used different types of LSTM networks to detect driver behavior
and driving patterns. In this work, we use two attention-based recurrent models, a
bidirectional LSTM network with an attention layer and an encoder–decoder network
with attention, to predict driver behavior using only driving data. We use a memoryless
MLP network as the baseline against which the results of these two networks are compared.

3 Experiment

We conducted an experiment to observe and model driver's behavior using a simulated modern car environment and collect a large body of driving patterns under
different conditions to be used for training three models including MLP, LSTM
network with attention, and encoder–decoder with attention. The experiment was
conducted using a Drive Safety Research simulator DS-600, a fully integrated, high
performance, high fidelity driving simulation system which includes multi-channel
audio/visual system, a minimum 180° wraparound display, full-width automobile
cab (Ford Focus) including windshield, driver and passenger seats, center console,
dash and instrumentation as well as real-time vehicle motion simulation. The simu-
lated immersive view of the road was provided by three large screens in front and
both sides of the car as well as three synchronized mirrors, a rear-view and two side
mirrors. The simulator provides different types of roads and driving contexts. We
designed an urban road with three left turns and a curved highway section. Figure 1 shows
the designed road. Flash A shows the start and end point of this road and flash B
shows the highway part of the road.

Fig. 1 Designed road

Four contexts of driving were defined for this experiment including Day, Night,
Fog, and Fog and Night. We used an Android application to simulate the car info-
tainment system. The application was hosted on an Android v4.4.2 based Samsung
Galaxy Tab4 8.0 which was connected to the Hyper Drive simulator. This setup
allowed us to control the height and angle of the infotainment system. The removable
tablet further allowed the control of screen size and contents.
Our minimalist design was used in the interface design of this application. In this
design, the main screen of the application shows six groups of car applications and the
driver has access to more features under each group. In each step of the navigation,
a maximum of six icons were displayed on the screen, which was tested earlier
and shown to be a suitable number in infotainment UI design [35]. The minimalist
design interaction allowed each application in this interface to be reached within a
maximum of four steps of navigation from the main screen, hence conforming to
NHTSA guideline DOT HS 812 108 on acceptable distraction [3]. This helped us
to normalize the driver’s interaction tasks and standardize them between different
driving contexts.

3.1 Participants

We invited 35 volunteers to participate in our simulator experiment. Each volunteer was asked to take a 45-min simulated drive, which was divided into eight driving
scenarios including four distracted and four non-distracted ones. Volunteers were
graduate and undergraduate students of our university in the range of 20–35 years
old and had at least two years of driving experience. Before starting the experiment,
each volunteer was trained for 10–15 min until they became familiar with the car
interface and with driving in the simulator. We defined four driving contexts including
Day, Fog, Night, and Fog and Night. For each context, we had one distracted and one
non-distracted scenario. In the non-distracted scenarios, the volunteers were asked to
focus on driving and to try to drive as realistically as possible. In distracted scenarios, we
distracted drivers by asking them to perform some tasks on the simulator’s infotain-
ment system. In this experiment, a distracting task is defined as reaching a specific
application on the designed car interface. We classified all possible applications in the
interface into three groups based on the number of navigation steps that the driver needs
to pass from the main screen to reach them and called them two-step, three-step, and
four-step tasks.
In each distracting scenario, we chose some tasks from each group and asked the
driver to do them. We put equal time intervals between tasks. It means when the
driver finished a task, we waited a few seconds before asking for the next task. In
distracted scenarios, we observed the driver behavior and collected four features for
each task including the number of errors that the driver did during the task, response
time which shows the time range from the moment that the driver was asked to do a
task and the time the task was completed, the mode of driving which is the current
driving scenario, and the number of navigation steps that the driver needs to pass
from the main screen to a specific application in the interface to complete the task.

4 Data

The simulator collected an average of 19,000 data vectors per trip, with 53 driving-
related features each. In sum, 5.3 million data vectors were collected during 280 trips.
We used paired t-tests to select the most significant features. Based on the results of
the paired t-tests, we chose 10 driving features: velocity, speed, steering wheel,
brake, lateral acceleration, headway distance, headway time, acceleration, longitudinal
acceleration, and lane position. We set the master data sampling rate at 60 samples
per second. While the high sample rate was useful in other experiments, it made the
model computationally expensive, so we compressed the collected data by averaging
every 20 samples into one vector. This produced a dataset of driving-related data
vectors with 10 features, the Vc. In distracted scenarios, volunteers performed 2,025
tasks in total. For each task, we defined four features describing driver behavior while
interacting with our designed interface, the Vh:
1. The name of the scenario, which shows whether the driver is distracted or not; based
   on our previous experiments and the currently collected data, the adverse (Fog,
   Night) and double adverse (Fog and Night) driving contexts adversely affect
   drivers' performance.
2. The number of driver errors during the task. We defined an error as touching a wrong
   icon or not following driving rules while interacting with the interface.
3. Response time, which is the length of the task from the moment that we ask the
   driver to start the task to the moment that the task is done.
4. The number of navigation steps that the driver needs to pass to reach the application
   and finish the task. All tasks could be completed within four steps or less.

For each driver behavior data vector (Vh), several car data vectors (Vc) were collected by
the simulator. The number of Vc vectors linked to each task depends on the response
time of the task. To map the Vc vectors to the corresponding Vh, we divided the Vc vectors
of each trip into blocks of different lengths, proportional to the response times of the
tasks in that trip, and mapped each block to one Vh vector. In (1), N is the number of
tasks executed in a trip and per(task_i) is the approximate percentage of Vc vectors in
the trip that are related to task_i.

$$\text{per}(\text{task}_i) = \frac{\text{ResponseTime}_i}{\sum_{k=1}^{N} \text{ResponseTime}_k} \times 100 \qquad (1)$$

We used driving data (Vc) as input of a multi-input–single-output bidirectional LSTM network to predict the correlated driver behavior feature vector (Vh). Then we added an attention layer to the model and analyzed its effect on the performance of the model.
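As an illustration of this preprocessing, the sketch below compresses the raw simulator samples and then splits the compressed Vc vectors of one trip into per-task blocks according to Eq. (1). The function and variable names are ours, and the rounding of block sizes is an assumption; the original pipeline may have partitioned the data differently.

```python
import numpy as np

def compress_samples(raw_vc, window=20):
    """Average every `window` raw simulator samples (rows) into one Vc vector."""
    n = (len(raw_vc) // window) * window               # drop the incomplete tail
    return raw_vc[:n].reshape(-1, window, raw_vc.shape[1]).mean(axis=1)

def split_vc_by_task(vc_vectors, response_times):
    """Assign each task a block of Vc vectors proportional to its response time, Eq. (1)."""
    response_times = np.asarray(response_times, dtype=float)
    fractions = response_times / response_times.sum()     # per(task_i) / 100
    counts = np.round(fractions * len(vc_vectors)).astype(int)
    counts[-1] = len(vc_vectors) - counts[:-1].sum()      # absorb rounding error
    return np.split(vc_vectors, np.cumsum(counts)[:-1])   # one block per Vh vector

# Toy example: 6000 raw samples with 10 features, compressed and split over a 4-task trip
vc = compress_samples(np.random.rand(6000, 10))
blocks = split_vc_by_task(vc, response_times=[5.2, 8.1, 3.4, 6.0])
print([b.shape for b in blocks])
```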

5 Methodology

Feedforward neural networks have been used in many car-related applications such
as lane departure and sign detection [22–25]. In feedforward neural networks, data
travels in one direction from the input layer to the output and the training process is
based on the assumption that input data samples are fully independent of each other.
Feedforward networks have no notion of order in a sequence or of time, and only
the current sample is considered when tuning the network's parameters [36]. We use
a feedforward neural network as the baseline model to estimate the driver behavior
using driving data. The model is a single-input–single-output network that uses the
mean of driving data during each task as the input and estimates the corresponding
driver behavior vector.
In some real-life applications, such as speech recognition, driving, and image
captioning, input samples are not independent and there is valuable information in
the sequence, so models with memory such as recurrent neural networks can
outperform memoryless models. The LSTM model uses both the current sample
and previously observed data in each step of training and combines the two to
produce the network's output [37]. The difference between feedforward networks
and RNNs is that the feedback loop in an RNN provides information from previous
steps to the current one, adding memory to the network, which is preserved in hidden
states. LSTM with attention is a recurrent model that combines a bidirectional LSTM
layer and an attention layer. The LSTM attention model
uses the weighted effect of all driving data during each task to estimate the driver
behavior vector [38]. This model has both memory and attention and considers all
driving data during the task so our assumption is that this model would be more
accurate compared to the feedforward model.
If we want to estimate a sequence of driver behavior vectors with sequence length N, we have three options:
• Running the feedforward model N times.
• Running the LSTM attention model N times.
• Using a sequence-to-sequence model.
By using the memoryless feedforward model, we lose the dependency between
samples and each output is calculated independently. The LSTM attention model
would be more accurate for each driver behavior vector because of the memory and
attention in the model, but we lose the dependencies within the output sequence.
Sequence-to-sequence learning considers the dependencies and information in both
the input and output sequences. In sequence-to-sequence learning using an encoder–decoder model, the
encoder encodes the input sequence to a context vector and the decoder uses the
context vector to produce the output sequence. Similar to the multi-input–single-
output LSTM model, attention can be added to the model to detect and learn the
weight of each input in the input sequence on the current output more accurately
[39]. We decided to use the encoder–decoder attention model with equal input and
output sequence lengths. This model considers the mean of driving data during each
task as the input and the correlated driver behavior as the output similar to the feed-
forward model. If we want to consider all driving data similar to the LSTM attention
model, the sequence-to-sequence model would be very computationally expensive,
so we tried the same input and output sequence length.
Recent work [40] has investigated different types of machine learning, including
unsupervised learning and convolutional neural networks. As the need to address
other kinds of problems grows, different types of machine learning solutions
are attempted. One prohibitive challenge is the increasing need for computational
power. Hence new optimization approaches are used to reduce the computational
demand in order to make the new solutions technically tractable and scalable.

6 Models

6.1 Multilayer Neural Network Model

We built and trained a memoryless multilayer (MLP) neural network using both
scaled and unscaled data. This feedforward neural network was considered as the
baseline to compare with the LSTM attention network and encoder–decoder attention
model that both have attention and memory.
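As a rough sketch of what such a baseline might look like (the layer widths, activation choice, and training settings here are assumptions drawn from the ranges reported in Sect. 7, not the exact configuration used), a Keras version could be:

```python
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense

def build_mlp(n_features=10, n_targets=4, hidden_layers=4, units=150):
    """Baseline feedforward model: mean driving vector (Vc) in, driver-behavior vector (Vh) out."""
    model = Sequential()
    model.add(Dense(units, activation='relu', input_shape=(n_features,)))
    for _ in range(hidden_layers - 1):
        model.add(Dense(units, activation='relu'))
    model.add(Dense(n_targets))                  # linear output for regression
    model.compile(optimizer='adam', loss='mae')  # mean absolute error, as reported later
    return model
```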

Fig. 2 Attention mechanism

6.2 Bidirectional LSTM Network with Attention Layer

Attention is one of the fundamental parts of cognition and intelligence in humans [41,
42]. It helps reduce the amount of information processing and complexity [43, 44].
We can loosely define attention as directing some human senses (like vision), and
hence the mind, to a specific source or object rather than scanning the entire input
space [45]. This is an essential human perception component that helps increase
processing power while reducing demand on resources [46]. In the neural network
area, attention is primarily used as a memory mechanism that determines which
part of the input sequence has more effect on the final output [47, 48]. The attention
mechanism considers the weighted effect of each input on the model’s output instead
of producing one context vector from all samples of the input sequence (Fig. 2) [49,
50]. We used the attention mechanism to estimate the driver behavior using a sequence
of driving data vectors with 10 features and considering the weighted effect of each
input driving data vector on the driver behavior. Our assumption is that using the
attention mechanism decreases the model's training and test error and enhances the
training process.
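For concreteness, a minimal Keras sketch of such a multi-input–single-output model is shown below. The SeqSelfAttention layer is from the third-party keras-self-attention package, and the pooling step, layer sizes, and training settings are our own assumptions rather than the exact architecture used here.

```python
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Bidirectional, LSTM, Dense, GlobalAveragePooling1D
from keras_self_attention import SeqSelfAttention   # third-party keras-self-attention package

def build_lstm_attention(seq_len, n_features=10, n_targets=4, units=40):
    """Multi-input-single-output sketch: a sequence of driving vectors in, one Vh vector out."""
    model = Sequential([
        Bidirectional(LSTM(units, return_sequences=True), input_shape=(seq_len, n_features)),
        SeqSelfAttention(attention_activation='sigmoid'),   # weight every time step
        GlobalAveragePooling1D(),                            # collapse the sequence to one vector
        Dense(n_targets),
    ])
    model.compile(optimizer='adam', loss='mae')
    return model
```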

6.3 Encoder–Decoder Attention

The encoder–decoder model is a multi-input–multi-output neural network architecture. The encoder compresses the input sequence to a representative context vector
and the decoder uses this vector to produce the output sequence. In the encoder–
decoder with attention network, the encoder provides one context vector that is
Training Deep Learning Sequence Models … 133

filtered specifically for each output in the output sequence. Equation (2) shows the
output of the encoder in the encoder–decoder without attention. In (2), h is the output
vector of the encoder and contains information about all samples in the input sequence.
In attention networks, the encoder produces one vector for each output; (3) shows the
encoder output in the attention network [49].

$$h = \text{Encoder}(x_1, x_2, x_3, \ldots, x_T) \qquad (2)$$

$$[h_1, h_2, h_3, \ldots, h_T] = \text{Encoder}(x_1, x_2, x_3, \ldots, x_T) \qquad (3)$$

The decoder produces one output at a time, and the model scores how well each
encoded input matches the current output. Equation (4) shows the scoring formula
for encoded input i at step t. In (4), S_{t-1} is the decoder output from the previous
step and h_i is the result of encoding input x_i. In the next step, the scores are
normalized using (5), which is a softmax function. The context vector for each time
step is calculated using (6).

$$e_{ti} = a(S_{t-1}, h_i) \qquad (4)$$

$$a_{ti} = \frac{\exp(e_{ti})}{\sum_{j=0}^{T} \exp(e_{tj})} \qquad (5)$$

$$C_t = \sum_{j=0}^{T} a_{tj}\, h_j \qquad (6)$$
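A small NumPy sketch of Eqs. (4)–(6) for a single decoder step is given below. The dot-product score is only a stand-in for the generic scoring function a(S_{t-1}, h_i), which is learned in practice.

```python
import numpy as np

def softmax(x):
    e = np.exp(x - x.max())
    return e / e.sum()

def attention_context(encoder_states, prev_decoder_state):
    """Compute the context vector C_t of Eqs. (4)-(6) for one decoder step.

    encoder_states: array of shape (T, d), the h_1 ... h_T of Eq. (3)
    prev_decoder_state: array of shape (d,), the S_{t-1} of Eq. (4)
    """
    scores = encoder_states @ prev_decoder_state   # e_{ti}, Eq. (4), dot-product score
    weights = softmax(scores)                      # a_{ti}, Eq. (5)
    return weights @ encoder_states                # C_t,   Eq. (6)
```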

Fig. 3 Encoder–decoder with attention

Figure 3 shows the encoder–decoder with attention model. The bidirectional LSTM
layer provides access to the previous and next samples of the current sample in
each step of training. The "SeqSelfAttention" layer of Keras was used in the model.
This layer calculates the weighted effect of each input sample in the input sequence
on each output sample of the output sequence. We considered equal lengths for the
input and output sequences, as explained earlier. Three different models were built
and trained: a three-input–three-output-step model, a four-input–four-output-step
model, and a five-input–five-output-step model, as elaborated below.
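A minimal sketch of how such an equal-length sequence model might be assembled is given below; it uses the SeqSelfAttention layer from the keras-self-attention package, and the layer sizes, activations, and output head are assumptions rather than the authors' exact architecture.

```python
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Bidirectional, LSTM, Dense, TimeDistributed
from keras_self_attention import SeqSelfAttention   # keras-self-attention package

def build_seq2seq_attention(steps=3, n_features=10, n_targets=4, units=100):
    """Equal-length sequence model sketch: `steps` driving vectors in, `steps` Vh vectors out."""
    model = Sequential([
        Bidirectional(LSTM(units, return_sequences=True), input_shape=(steps, n_features)),
        SeqSelfAttention(attention_activation='sigmoid'),  # weight inputs per output step
        TimeDistributed(Dense(n_targets)),                 # one driver-behavior estimate per step
    ])
    model.compile(optimizer='adam', loss='mae')
    return model
```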

7 Results

7.1 MLP Results

We built an MLP and trained it with both scaled and unscaled data. In this model,
80% of the data was used for training and 20% for testing. We tried different numbers
of layers in the range of 2–6 and a variety of hidden layer sizes in the range of 50–500
for this neural network. Table 1 shows some of the best results achieved with scaled
data and Table 2 shows the results achieved with unscaled data. The large difference
between train and test error indicates that the model overfits with both scaled and
unscaled data and does not generalize well in most cases.

Table 1  MLP with scaled data

Layers   Neurons   Train MAE   Test MAE
Two      50        0.11        0.27
Two      150       0.082       0.23
Three    50        0.096       0.24
Three    150       0.056       0.32
Four     50        0.089       0.24
Four     150       0.012       0.2

Table 2  MLP with unscaled data

Layers   Neurons   Train MAE   Test MAE
Two      150       0.54        1.41
Two      300       0.39        1.51
Three    150       0.33        1.48
Three    300       0.31        1.34
Four     150       0.4         1.46
Four     300       0.17        1.39

7.2 LSTM Attention Results

In the next step, we trained an LSTM network with attention. We used the Adam
optimizer and mean absolute error (MAE) as the error metric of the model. We tried
a wide range of LSTM neurons from 10 to 500. In this model,
80% of the dataset was used as the training set and 20% as the testing set.
Table 3 shows some results with unscaled data. For unscaled data, the best result
was achieved with 20 neurons: 0.85 train MAE and 0.96 test MAE. This model achieved
a lower test error than the MLP model while using a smaller number of hidden neurons.
The best MLP result was achieved with four hidden layers and 300 neurons in each
layer, so the MLP model is less accurate and more computationally expensive than
the LSTM attention model. Besides, one LSTM layer plus an attention layer performed
better than a large MLP model with four fully connected layers. In addition, the
training process of the MLP model took around 400–600 epochs while the LSTM
converged in around 100–200 epochs.
Table 4 shows the best results achieved with scaled data. The model with 40
neurons reached the minimum test error, which is lower than all MLP results with
scaled data except one, the four-hidden-layer model with 150 neurons. In general,
the LSTM attention model with scaled data generalized better than the MLP in all
cases and achieved better performance with a smaller number of hidden neurons
and a smaller network.

Table 3  Bidirectional LSTM network with attention layer, unscaled data

LSTM neurons   Train MAE   Test MAE
20             0.85        0.96
30             0.85        0.99
40             0.85        1.01
100            0.85        1
200            0.85        0.99
300            0.96        0.99

Table 4  Bidirectional LSTM network with attention layer, scaled data

LSTM neurons   Train MAE   Test MAE
30             0.23        0.24
40             0.22        0.22
60             0.22        0.23
100            0.24        0.25
150            0.22        0.23
200            0.23        0.23

7.3 Encoder–Decoder Attention Model

As we mentioned earlier, to have a sequence of driver behavior data vectors we can use
a sequence-to-sequence model instead of running a multi-input–single-output model
multiple times. We chose the encoder–decoder model as a suitable sequence-to-
sequence model. We built and trained encoder–decoder attention models including
three-step, four-step, and five-step models with both scaled and unscaled data. In
these models, 80% of data was used as a training dataset and 20% of them were used
for testing the model. Different combinations of batch size, optimizer, and
the number of LSTM neurons were tested. Finally, we chose Adam as the optimizer
and 100 as the batch size. We tried a range of LSTM neurons from 20 to 500.
Besides, we tried different lengths of input and output sequences from 2 to 6. After
the four-step model, increasing the length of the sequence didn’t have a positive
effect on the model’s performance. Table 5 shows some of the best-achieved results
with unscaled data.
Mean absolute error (MAE) was used as the loss function of this model. The three-
step model reached the minimum error, a train and test MAE of 1.5. The four-step
and five-step models have almost the same performance. The three-step model with
unscaled data and 100 LSTM neurons showed the best performance. Figure 4 shows
this model's mean absolute error. Table 6 shows the results achieved by the three
models with scaled data. The three-step model with 250 LSTM neurons reached the
minimum error (Fig. 5).
The test error of the encoder–decoder attention model with unscaled data is close
to that of the MLP model with scaled data, but the generalization of this model is much better

Table 5  Encoder–decoder attention with unscaled data

LSTM neurons   Sequence   MAE train   MAE test
100            3          1.5         1.5
250            3          1.51        1.52
100            4          1.5         1.52
150            4          1.52        1.56
50             5          1.52        1.52
250            5          1.53        1.56

Fig. 4 The mean absolute error of a three-step model with 100 LSTM neurons and unscaled data

Table 6  Encoder–decoder attention with scaled data

LSTM neurons   Sequence   MAE train   MAE test
150            3          0.1683      0.16
250            3          0.1641      0.16
100            4          0.1658      0.16
300            4          0.1656      0.16
150            5          0.1655      0.16
400            5          0.1646      0.17

Fig. 5 The mean absolute error of a three-step model with 250 LSTM neurons and scaled data

than the MLP model and the network is smaller than the MLP model. Besides, it
converges in around 50 epochs on average which is much faster than the MLP model
that needs around 400 epochs training on average. Moreover, this model estimates
multiple driver behavior vectors in one run. The error of this model with not scaled
data is more than LSTM attention model but the generalization of this model is
better and the number of input samples is much less than LSTM attention model
since this model consider one driving data for each driver behavior vector, so it is
computationally less expensive.
The encoder–decoder attention with scaled data outperformed both the MLP
model and the LSTM attention model. This model has memory and attention similar
to the LSTM attention model. Besides, it considers the dependency between samples
of the output sequence, which is not possible if we run a multi-input–single-output model
multiple times to obtain a sequence of driver behavior vectors. Besides, this model is
computationally less expensive than the LSTM attention model since it considers one
input vector corresponding to each output, similar to the MLP model. The minimum
test error of this model with scaled data is 0.06 less than the minimum test error of the
LSTM attention model and 0.04 less than the minimum test error of the MLP model.
The encoder–decoder attention model converges with less error and takes less time
than the two other models. The average number of epochs for this model with unscaled
and scaled data was between 50 and 100, which is half of the average epochs
of the LSTM attention model and one-fourth of the MLP epochs. Besides, the model
uses a smaller number of layers than the MLP model and a smaller number of input
samples than the LSTM attention model, making it computationally less expensive.
It also estimates multiple driver behavior data vectors in each run.

8 Conclusion

Using machine learning methods to monitor driver behavior and detect the driver’s
inattention that can lead to distraction is an emerging solution to detect and reduce
driver distraction. Different deep learning methods such as CNN, LSTM, and RNN
networks were used in several car safety applications. Some of these methods have
memory, so they can extract and learn information in a sequence of data. There
is some information in the sequence of driving data that cannot be captured by
processing the samples individually. Driving data samples are not independent of each other,
so methods that have memory and attention such as recurrent models are a better
choice for higher intelligence and hence more reliable car safety applications.
These methods utilize different mechanisms to perceive the dependency between
samples and extract the latent information in the sequence. We chose the MLP
network which is a simple memoryless deep neural network as the baseline for
our model. Then we trained an LSTM attention model that has memory and atten-
tion. This model outperforms the MLP model with both scaled and unscaled data.
The model trained at least two times faster than the MLP model and achieved better
performance with fewer hidden neurons, a smaller network, and a smaller number of
training epochs.
In order to obtain a sequence of driver behavior vectors, we have two options: running
a multi-input–single-output model multiple times or using a sequence-to-sequence model.
We built and trained an encoder–decoder attention model with both scaled and unscaled
data to produce a sequence of driver behavior data vectors. This model outperforms the
MLP model with both scaled and unscaled data. Besides, this model outperformed the
LSTM attention model with scaled data. The encoder–decoder attention model trained
at least two times faster than the LSTM attention model and four times faster than
the MLP model. Besides, in each run, it estimates multiple driver behavior vectors, and
it had the best generalization and the minimum difference between train and test error.
Our work shows that this would be a viable and scalable option for deep neural
network models that work in real-life complex driving contexts without the need
to use intrusive devices. It also provides an objective measurement of the added
advantages of using attention networks to reliably detect driver behavior.

References

1. B. Darrow, Distracted driving is now an epidemic in the U.S., Fortune (2016). http://fortune.
com/2016/09/14/distracted-driving-epidemic/
2. National Center for Statistics and Analysis, Distracted driving in fatal crashes, 2017, (Traffic
Safety Facts Research Note, Report No. DOT HS 812 700) (Washington, DC, National Highway
Traffic Safety Administration, 2019)
3. N. Chaudhary, J. Connolly, J. Tison, M. Solomon, K. Elliott, Evaluation of the NHTSA
distracted driving high-visibility enforcement demonstration projects in California and
Delaware. (Report No. DOT HS 812 108) (Washington, DC, National Highway Traffic Safety
Administration, 2015)
4. National Center for Statistics and Analysis, Distracted driving in fatal crashes, 2017, (Traffic
safety facts research Note, Report No. DOT HS 812 700), (Washington, DC: National Highway
Traffic Safety Administration, 2019)
5. S. Monjezi Kouchak, A. Gaffar, Driver distraction detection using deep neural network, in The
Fifth International Conference on Machine Learning, Optimization, and Data Science (Siena,
Tuscany, Italy, 2019)
6. J. Lee, Dynamics of driver distraction: the process of engaging and disengaging. Assoc. Adv.
Autom. Med. 58, 24–35 (2014)
7. T. Hirayama, K. Mase, K. Takeda, Analysis of temporal relationships between eye gaze and
peripheral vehicle behavior for detecting driver distraction. Hindawi Publ. Corp. Int. J. Veh.
Technol. 2013, 8 (2013)
8. National Highway Traffic Safety Administration. Blueprint for Ending Distracted Driving.
Washington, DC: U.S. Department of Transportation. National Highway Traffic Safety
Administration, DOT HS 811 629 (2012)
9. T.B. Sheridan, R. Parasuraman, Human-automation interaction. reviews of human factors and
ergonomics, vol. 1, pp. 89–129 (2015). https://doi.org/10.1518/155723405783703082
10. U. Hamid, F. Zakuan, K. Zulkepli, M. ZulfaqarAzmi, H. Zamzuri, M. Rahman, M. Zakaria,
Autonomous Emergency Braking System with Potential Field Risk Assessment for Frontal
Collision Mitigation (IEEE ICSPC, Malaysia, 2017)
11. L. Li, D. Wen, N. Zheng, L. Shen, Cognitive cars: a new frontier for ADAS research. IEEE
Trans. Intell. Transp. Syst. 13 (2012)
12. S. Monjezi Kouchak, A. Gaffar, Estimating the driver status using long short term memory,
in Machine Learning and Knowledge Extraction, Third IFIP TC 5, TC 12, WG 8.4, WG 8.9,
WG 12.9 International Cross-Domain Conference, CD-MAKE 2019 (2019). https://doi.org/10.
1007/978-3-030-29726-8_5
13. P. Koopman, M. Wagner, Autonomous vehicle safety: an interdisciplinary challenge. IEEE
Intell. Transp. Syst. Mag. 9, 90–96 (2017)
14. M. Benmimoun, A. Pütz, A. Zlocki, L. Eckstein, euroFOT: field operational test and impact
assessment of advanced driver assistance systems: final results, in SAE-China, FISITA (eds)
Proceedings of the FISITA 2012 World Automotive Congress. Lecture Notes in Electrical
Engineering, vol. 197 (Springer, Berlin, Heidelberg, 2013)
15. S. Monjezi Kouchak, A. Gaffar, Determinism in future cars: why autonomous trucks are easier
to design, in IEEE Advanced and Trusted Computing (ATC 2017) (San Francisco Bay Area,
USA, 2017)
16. S. Kaplan, M.A. Guvensan, A.G. Yavuz, Y. Karalurt, Driver behavior analysis for safe driving:
a survey, IEEE Trans. Intell. Transp. Syst. 16, 3017–3032 (2015)
17. A. Aksjonov, P. Nedoma, V. Vodovozov, E. Petlenkov, M. Herrmann, Detection and evaluation
of driver distraction using machine learning and fuzzy logic. IEEE Trans. Intell. Transp. Syst.
1–12 (2018). https://doi.org/10.1109/tits.2018.2857222
18. R. Harb, X. Yan, E. Radwan, X. Su, Exploring precrash maneuvers using classification trees
and random forests, Accid. Anal. Prev. 41, 98–107 (2009)
19. A. Alvarez, F. Garcia, J. Naranjo, J. Anaya, F. Jimenez, Modeling the driving behavior of
electric vehicles using smartphones and neural networks. IEEE Intell. Transp. Syst. Mag. 6,
44–53 (2014)
20. J. Morton, T. Wheeler, M. Kochenderfer, Analysis of recurrent neural networks for probabilistic
modeling of driver behavior. IEEE Trans. Intell. Transp. Syst. 18, 1289–1298 (2017)
21. A. Sathyanarayana, P. Boyraz, J. Hansen, Driver behavior analysis and route recognition by
Hidden Markov models, IEEE International Conference on Vehicular Electronics and Safety
(2008)
22. J. Li, X. Mei, D. Prokhorov, D. Tao, Deep neural network for structural prediction and lane
detection in traffic scene. IEEE Trans. Neural Netw. Learn. Syst. 28, 14 (2017)
23. S. Monjezi Kouchak, A. Gaffar, Non-intrusive distraction pattern detection using behavior
triangulation method, in 4th Annual Conference on Computational Science and Computational
Intelligence CSCI-ISAI (USA, 2017)
24. S. Su, B. Nugraha, Fahmizal, Towards self-driving car using convolutional neural network and
road lane detector, in 2017 2nd International Conference on Automation, Cognitive Science,
Optics, Micro Electro-Mechanical System, and Information Technology (ICACOMIT) (Jakarta,
Indonesia, 2017), p. 5
25. S. Hung, I. Choi, Y. Kim, Real-time categorization of driver’s gaze zone using the deep learning
techniques, in 2016 International Conference on Big Data and Smart Computing (BigComp)
(2016), pp. 143–148
26. A. Koesdwiady, S. Bedavi, C. Ou, F. Karray, End-to-end deep learning for driver distraction
recognition. Springer International Publishing AG 2017 (2017), p. 8
27. I. Goodfellow, Y. Bengio, A. Courville, Deep Learning (The MIT press, 2016), ISBN:
9780262035613
28. J. Schmidhuber, Deep learning in neural networks: an overview, vol. 61, (Elsevier, 2015),
pp. 85–117
29. M. Wöllmer, C. Blaschke, T. Schindl, B. Schuller, B. Färber, S. Mayer, B. Trefflich, Online
driver distraction detection using long short-term memory. IEEE Trans. Intell. Transp. Syst.
2(2), 574–582 (2011)
30. K. Xu, J.L. Bay, R. Kirosy, K. Cho, A. Courville, R. Salakhutdinovy, R.S. Zemely, Y.
Bengio, Show, attend and tell: neural image caption generation with visual attention, in 32
nd International Conference on Machine Learning (Lille, France, 2015)
31. T. Xiao, Y. Xu, K. Yang, J. Zhang, Y. Peng, Z. Zhang, The application of two-level attention
models in deep convolutional neural network for fine-grained image classification, in The IEEE
Conference on Computer Vision and Pattern Recognition (CVPR) (2015), pp. 842–850
32. P. Huang, F. Liu, S. Shiang, J. Oh, C. Dyer, Attention-based multimodal neural machine trans-
lation, in Proceedings of the First Conference on Machine Translation, Shared Task Papers,
vol. 2, (Berlin, Germany, 2016), pp. 639–645
33. Y. Lv, Y. Duan, W. Kang, Z. Li, F. Wang, Traffic flow prediction with big data: a deep learning
approach. IEEE Trans. Intell. Transp. Syst. 16, 865–873 (2015)
34. K. Saleh, M. Hossny, S. Nahavandi, Driving behavior classification based on sensor data fusion
using LSTM recurrent neural networks, in IEEE 20th International Conference on Intelligent
Transportation Systems (ITSC) (2017)
35. A. Gaffar, S. Monjezi Kouchak, Minimalist design: an optimized solution for intelligent inter-
active infotainment systems, in IEEE IntelliSys, the International Conference on Intelligent
Systems and Artificial Intelligence (London, UK, 2017)
36. C. Bishop, Pattern Recognition and Machine Learning (Springer). ISBN-13: 978-0387310732
37. M. Magic, Action recognition using Python and recurrent neural network, First edn. (2019).
ISBN: 978-1798429044
38. D. Mandic, J. Chambers, Recurrent neural networks for prediction: learning algorithms,
architectures and stability, First edn. (Wiley, 2001). ISBN: 978-0471495178
39. J. Rogerson, Theory, Concepts and Methods of Recurrent Neural Networks and Soft Computing
(2015). ISBN-13: 978-1632404930
40. M.A. Wani, F.A. Bhat, S. Afzal, A.I. Khan, Advances in Deep Learning (Springer, 2020)
41. A. Gaffar, E. M. Darwish, A. Tridane, Structuring heterogeneous big data for scalability and
accuracy. Int. J. Digit. Inf. Wirel. Commun. 4, 10–23 (2014)
42. A. Gaffar, H. Javahery, A. Seffah, D. Sinnig, A pattern framework for eliciting and deliv-
ering UCD knowledge and practices, in Proceedings of the Tenth International Conference on
Human-Computer Interaction (2003), pp. 108–112
43. A. Gaffar, Enumerating mobile enterprise complexity 21 complexity factors to enhance the
design process, in Proceedings of the 2009 Conference of the Center for Advanced Studies on
Collaborative Research (2009), pp. 270–282
44. A. Gaffar, The 7C’s: an iterative process for generating pattern components, in 11th
International Conference on Human-Computer Interaction (2005)
45. J. Bermudez, Cognitive Science: An Introduction to the Science of the Mind, 2nd edn. (2014).
978-1107653351
46. B. Garrett, G. Hough, Brain & Behavior: an Introduction to Behavioral Neuroscience, 5th edn.
(SAGE). ISBN: 978-1506349206
47. Y. Wang, M. Huang, L. Zhao, X. Zhu, Attention-based LSTM for aspect-level sentiment clas-
sification, in Proceedings of the 2016 Conference on Empirical Methods in Natural Language
Processing (2016), pp. 606–615
48. I. Sutskever, O. Vinyals, Q. Le, Sequence to sequence learning with neural networks, in
Advances in Neural Information Processing Systems 27 (NIPS 2014) (2014)
49. S. Frintrop, E. Rome, H. Christenson, Computational visual attention systems and their cogni-
tive foundations: a survey, ACM Trans. Appl. Percept. (TAP) 7 (2010). https://doi.org/10.1145/
1658349.1658355
50. Z. Yang, D. Yang, C. Dyer, X. He, A Smola, E. Hovy, Hierarchical attention networks for
document classification, in NAACL-HLT 2016 (San Diego, California, 2016), pp. 1480–1489
Exploiting Spatio-Temporal Correlation
in RF Data Using Deep Learning

Debashri Roy, Tathagata Mukherjee, and Eduardo Pasiliao

Abstract The pervasive presence of wireless services and applications have become
an integral part of our lives. We depend on wireless technologies not only for our
smartphones but also for other applications like surveillance, navigation, jamming,
anti-jamming, radar to name a few areas of applications. These recent advances of
wireless technologies in radio frequency (RF) environments have warranted more
autonomous deployments of wireless systems. With such large scale dependence on
use of the RF spectrum, it becomes imperative to understand the ambient signal char-
acteristics for optimal deployment of wireless infrastructure and efficient resource
provisioning. In order to make the best use of such radio resources in both the spa-
tial and time domains, past and current knowledge of the RF signals are important.
Although sensing mechanisms can be leveraged to assess the current environment,
learning techniques are typically used for analyzing past observations and predicting
future occurrences of events in a given RF environment.
(ML) techniques, having already proven useful in various domains, are also being
sought for characterizing and understanding the RF environment. Some of the goals
of the learning techniques in the RF domain are transmitter or emitter fingerprinting,
emitter localization, modulation recognition, feature learning, attention and saliency,
autonomous RF sensor configuration and waveform synthesis. Moreover, in large-
scale autonomous deployments of wireless communication networks, the signals
received from one component play a crucial role in the decision-making process of
other components. In order to efficiently implement such systems, each component

of the network should be uniquely identifiable. ML techniques that include recurrent structures have shown promise in creating such autonomous deployments using the
idea of radio frequency machine learning (RFML). Deep learning (DL) techniques
with the ability to automatically learn features, can be used for characterization
and recognition of different RF properties by automatically exploiting the inherent
features in the signal data. In this chapter, we present an application of such deep
learning techniques to the task of RF transmitter fingerprinting. The first section
concentrates on the application areas in the field of RF where deep learning can
be leveraged for futuristic autonomous deployments. Section 2 presents discussion
of different deep learning approaches for the task of transmitter fingerprinting as
well as the significance of leveraging recurrent structures through the use of recur-
rent neural network (RNN) models. Once we have established the basic knowledge
and motivation, we dive deep into the application of deep learning for transmitter
fingerprinting. Hence, a transmitter fingerprinting technique for radio device iden-
tification using recurrent structures, by exploiting the spatio-temporal properties of
the received radio signal, is discussed in Sects. 3 and 4. We present three types of
recurrent neural networks (RNNs) using different types of cell models: (i) long short-
term memory (LSTM), (ii) gated recurrent unit (GRU), and (iii) convolutional long
short-term memory (ConvLSTM) for that task. The proposed models are also validated
with real data and evaluated using a framework implemented in Python.
Section 5 describes the testbed setup and experimental design. The experimental
results, computational complexity analysis, and comparison with state of the art are
discussed in Sect. 6. The last section summarizes the chapter.

Keywords RF fingerprinting · Recurrent neural network · Supervised learning · Software-defined radios

D. Roy (B)
Computer Science, University of Central Florida, Orlando, FL 32826, USA
e-mail: debashri@cs.ucf.edu
T. Mukherjee
Computer Science, University of Alabama, Huntsville, AL 35899, USA
e-mail: tm0130@uah.edu
E. Pasiliao
Munitions Directorate, Air Force Research Laboratory, Eglin AFB, Valparaiso, FL 32542, USA
e-mail: eduardo.pasiliao@us.af.mil

© The Editor(s) (if applicable) and The Author(s), under exclusive license
to Springer Nature Singapore Pte Ltd. 2021
M. A. Wani et al. (eds.), Deep Learning Applications, Volume 2,
Advances in Intelligent Systems and Computing 1232,
https://doi.org/10.1007/978-981-15-6759-9_7

1 Applications of Deep Learning in the RF Domain

We are living in a world where the distances are shrinking every day, thanks to
an explosion in the use of connected devices. The ubiquitous usage of wirelessly
connected Internet-of-Things (IoT) [56] along with the deployment of wireless
autonomous systems has ushered in a new era of industrial-scale deployment of
RF devices. This prevalence of large-scale peer-to-peer communication and nature
of the underlying ubiquitous network brings forth the challenge of accurately identi-
fying an RF transmitter. Every device that is part of a large network needs to be able to
identify its peers with high confidence in order to set up secure communication chan-
nels. One of the ways in which this is done is through the interchange of “keys” [42]
for host identification. However, such schemes are prone to breaches by malicious
agents [50] because often the actual implementations of such systems are not crypto-
graphically sound. In order to get around the problem of faulty implementations, one
can use the transmitter’s intrinsic characteristics to create a “fingerprint” that can be
used by a transmitter identification system. Every transmitter, no matter how similar,
has intrinsic characteristics because of the imperfections in its underlying components such as amplifiers, filters, frequency mixers as well as the physical properties
of the transmitting antenna; these characteristics are unique to a specific transmitter.
The inaccuracies present in the manufacturing process and the idiosyncrasies of the
hardware circuitry also contribute to the spatial and temporal characteristics of the
signal transmitted through a particular device.

1.1 Transmitter Identification

This inherent heterogeneity can be exploited to create unique identifiers for the
transmitters. One such property is the imbalance in the Inphase (I ) and Quadrature
(Q) phase components of the signal (I/Q data). However, because of the sheer number
of the transmitters involved, manually “fingerprinting” each and every transmitter
is not a feasible task [6]. Thus, in order to build such a system, there needs to
be an “automatic” method of extracting the transmitter characteristics and using the
resulting “fingerprint” for the differentiation process. One way of achieving this, is by
learning the representation of the transmitter in an appropriate “feature space” that has
enough discriminating capability so as to be able to differentiate between “apparently
identical” transmitters. Hence, the concept of “transmitter fingerprinting” came to
light for the task of transmitter classification or identification. For the rest of the
chapter we use the phrases “transmitter classification” and “transmitter recognition”
interchangeably.
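As a toy illustration of how handcrafted features might try to capture such I/Q imbalance (real fingerprinting systems use far richer features or, as discussed below, learn them directly; the specific statistics here are our own choices, not a technique from the literature cited above), consider:

```python
import numpy as np

def crude_iq_features(samples):
    """Toy handcrafted features from complex baseband samples: a gain-imbalance
    estimate (ratio of I and Q spreads) and an I/Q correlation coefficient that
    grows when the two branches are not in perfect quadrature."""
    i, q = samples.real, samples.imag
    gain_ratio = np.std(i) / np.std(q)
    iq_correlation = np.corrcoef(i, q)[0, 1]
    return np.array([gain_ratio, iq_correlation])

# Example on synthetic samples
x = np.random.randn(1024) + 1j * np.random.randn(1024)
print(crude_iq_features(x))
```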

1.2 Modulation Recognition

The radio frequency domain consists of a vast range of electromagnetic frequencies between 3 Hz and 3 THz [19]. Different wireless communication types, such as local
area network, broadcast radio, wide area network, voice radios, radar, and several
others use different bands of radio frequencies depending on the requirements [52].
They also use variable types of modulation schemes depending on the used frequency
band and application area. Modulation is the technique to change the data to be
transmitted on an RF carrier. Modulation recognition is the task of identifying the
modulation type of the received signal. To that end, the modulation recognition can be
leveraged to identify the type of communication the transmitter is using, even without
the knowledge of used frequency band. Deep convolution networks have shown great
prospect in successfully classifying among different types of modulation schemes,
details of which are presented in next section.

2 Approaches for Implementing Deep Learning in the RF Domain

In this section, we present the evolution of research on transmitter identification and modulation recognition from traditional approaches to the deep learning-based ones. First, we discuss a few of the traditional methods for the transmitter identification
task. These traditional methods use manual feature engineering and leverage differ-
ent radio attributes like transients, or spurious modulations to create discriminating
feature sets. A transient signal is transmitted when a transmitter is powered up or
powered down. During this short period (typically a few microseconds), capacitive
loads charge or discharge. Different classification approaches using transient-based
recognition were proposed in [18, 45, 51]. In [51], the authors proposed a
genetic algorithm-based solution for transmitter classification based on transients.
A multifractal segmentation technique was proposed in [45] using the same concept
of transients. Another transient-based transmitter classification was proposed
in [18] using a k-nearest neighbor discriminatory classifier. However, these traditional
approaches have extra overhead due to the feature extraction step and furthermore
the quality of the solution is constrained by the type of feature selected, and therefore
by the knowledge of the expert making that decision.
To avoid such overheads, deep learning [55]-based methods can provide an effi-
cient and automatic way of learning and characterizing the feature space within the
inherent properties of transmitters. They are able to learn and analyze the inherent
properties of large deployments and use it to predict and characterize the associated
parameters for the task of automatic feature learning for classification (or regression).
Moreover, the task of classification is equivalent to learning the decision boundary,
and neural networks are a natural candidate for such a learning algorithm. To
that end, neural networks have also previously been used for modulation recognition
and transmitter identification [31, 33, 40] and are particularly attractive since they
can generate accurate models without a priori knowledge of the data distribution.
Next, we demonstrate various existing efforts of such neural network-based meth-
ods for different types of applications in the RF domain. We divide our discussion
into three parts: (i) deep neural networks, (ii) convolutional neural networks, and (iii)
recurrent neural network.

2.1 Deep Neural Networks

We refer to deep neural networks (DNNs) as multi-layer feedforward networks trained
with backpropagation. DNNs have revolutionized the field of artificial intelligence in the
last few years. The problems tackled by DNNs range over computer vision, natural
language processing, speech processing, and so on. They have been demonstrated to
perform better than humans for some of these problems. They have also been shown

to be effective for automatically learning discriminating features from data for various
tasks [13]. With proper choice of the neural network architecture and associated
parameters, they can compute arbitrarily good function approximations [21].
In general, the use of deep learning in the RF domain has been limited in the past,
with only a few applications in recent times [30]. However, DNNs have already
proven their applicability for the task of transmitter identification. A DNN-based
approach was proposed in [36], where the authors classified among eight differ-
ent transmitters with ∼95% accuracy. For modulation recognition, a real-life
demonstration of modulation classification using a smartphone app was presented
in [49]. The authors designed the mobile application DeepRadio™ to distinguish
between six different modulation schemes by implementing a four-layer DNN.

2.2 Deep Convolutional Neural Networks

Fully connected DNNs are like standard neural networks but with a lot more hidden
layers. However, these networks can be augmented with convolutional layers for
faster training and for enabling the network to learn more compact and meaningful
representations. Deep convolutional neural networks (DCNN) have been shown to
be effective for several different tasks in communication. There have been quite a few
attempts of using DCNN for learning inherent RF parameters for both modulation
recognition and transmitter identification.
In [29], the authors demonstrated the use of convolutional neural networks
for modulation detection. Apart from the results, an interesting aspect of the work
is the way I/Q values were used as input to the neural network. More precisely,
given N I/Q values, the authors used a vector of size 2N as the input to the neural
network, effectively using the I and Q components as a tuple representing a point
in the complex plane. This representation proves to be useful for using the I/Q
data in different learning models. However, the authors exploited only the spatial
correlations in their DCNN implementation on synthetic RF data [27].
Further research was carried out in [31] with the same concept using real RF data, and further
investigations on modulation recognition were presented in [47] using DCNNs.
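As a concrete illustration of this input representation (our sketch, not code from [29]), N complex I/Q samples can be flattened into a 2N-element real-valued vector; the interleaved ordering below follows the layout we use later in Sect. 5.2, and the exact ordering used in [29] may differ.

```python
# Illustrative sketch: packing N complex I/Q samples into a length-2N real vector,
# treating each (I, Q) pair as a point in the complex plane.
import numpy as np

def iq_to_input_vector(iq):
    """iq: complex array of length N -> real array [I0, Q0, I1, Q1, ...] of length 2N."""
    out = np.empty(2 * len(iq), dtype=np.float32)
    out[0::2] = iq.real   # in-phase components
    out[1::2] = iq.imag   # quadrature components
    return out

x = iq_to_input_vector(np.array([0.3 + 0.1j, -0.2 + 0.4j]))   # -> [0.3, 0.1, -0.2, 0.4]
```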
As for the transmitter identification task, an application of transmitter fingerprint-
ing was presented in [14] for detecting new types of transmitters using an existing
trained neural network model. Similarly, in [44], Sankhe et al. presented a DCNN-
based large-scale transmitter identification technique with 80–90% accuracy. How-
ever, they could only achieve competitive accuracy by factoring in the controlled
impairments at the transmitter. It is to be noted that all prior works have only exploited
the spatial correlation of the signal data, though a continuous signal can be repre-
sented as a time series, having both temporal and spatial properties [54].

2.3 Recurrent Neural Networks

Recurrent neural networks (RNNs) are capable of making predictions from time series data.
They have been used extensively for modeling temporal data such as speech [34].
There is a limited amount of work that recognizes the potential of using recurrent
structures in the RF domain. RNNs have also been shown to be useful for capturing
and exploiting the temporal correlations of time series data [13]. There are a few
variants of recurrent neural networks: (i) Long Short-Term Memory (LSTM) [16],
(ii) Gated Recurrent Unit (GRU) [5], and (iii) Convolutional Long Short-Term Mem-
ory (ConvLSTM) [46]. All these variants are designed to learn the long-term temporal
dependencies and are capable of avoiding the “vanishing” or “exploding” gradient
problems [8].
In [26], O'Shea et al. presented an RNN model that extracted high-level protocol
information from the low-level physical layer representation for the task of radio
traffic sequence classification. A radio anomaly detection technique was presented
in [28], where the authors used an LSTM-based RNN as a time series predictor, using
the error component to detect anomalies in real signals. Another application of RNN
was proposed in [41], where the authors used a deep recurrent neural network to learn
the time-varying probability distribution of received powers on a channel and used
the same to predict the suitability of sharing that channel with other users. Bai et al.
proposed an RF path fingerprinting [2] method using two RNNs in order to learn the
spatial or temporal pattern. Both simulated and real-world data were used to improve
the positioning accuracy and robustness of moving RF devices. All these discussed
approaches prove that the temporal property of the RF data can be leveraged through
RNN models to learn and analyze different RF properties. Moreover, the working
principle of RNN is rather simple compared to many complicated learning algorithms
such as multi-stage training (MST). In [58], the authors compared several learning
paradigms for the task of transmitter identification. More precisely, they looked at
feedforward deep neural nets, convolutional neural nets, support vector machines,
and deep neural nets with MST. Though they achieved 99% accuracy for classifying
12 transmitters using MST, on the downside, MST is a complex procedure and needs
an easily parallelizable training environment. On the other hand, RNNs use a first-
order update rule (stochastic gradient descent) and are comparatively simple procedures.
However, there has not been much effort on identifying RF transmitters by exploiting
the recurrent structure within RF data.
As far as the modulation recognition technique is concerned, a method was pro-
posed in [33] for a distributed wireless spectrum sensing network. The authors pro-
posed a recurrent neural network using long short-term memory (LSTM) cells, yield-
ing 90% accuracy on a synthetic dataset [32].
In the light of recent developments, one could argue that deep learning is a
natural choice for exploiting different RF parameters to build systems for RF
application-oriented tasks. It is also clear that the paradigm of applying deep learning
in the RF domain gradually shifted from DNNs to CNNs to RNNs, while in terms of
application areas, it shifted from modulation recognition to transmitter identification.
It is to be noted that none of these methods address the problem of providing an
end-to-end solution using raw signal data for transmitter identification with
automatically extracted “fingerprints”. Hence, we propose a robust end-to-end
“radio fingerprinting” solution based on different types of RNN models.

3 Highlights of the Proposed RNN Models

As mentioned earlier, we use three variants of RNNs with time series of I/Q data
to present an end-to-end solution for transmitter identification using the concept of
“transmitter fingerprinting”. The main highlights for the rest of the chapter are as
follows:
1. We exploit the temporal properties of I/Q data by using a supervised learning
approach for transmitter identification using recurrent neural networks. We use
two approaches: first, we exploit only the temporal property and, then we exploit
the spatio-temporal property. We use RNNs with LSTM and GRU cells for the first
approach, while we use a convLSTM model for the latter. Although transmitter
fingerprinting has been studied before, to the best of our knowledge this is the
first work which leverages the spatio-temporal property of the over-the-air signal
data for this task.
2. To examine the performance of the proposed networks, we test them on an indoor
testbed. We transmit raw signal data from eight universal software radio periph-
eral (USRP) B210s [10] and collect over-the-air signals using an RTL-SDR [24].
We use the I/Q values from each USRP for fingerprinting the corresponding trans-
mitter.
3. We collect I/Q data from several different types of SDRs, a couple of them made
by the same manufacturer and a couple made by different manufacturers. More
precisely we use USRP-B210 [10], USRP-B200 [9] and USRP-X310 [11] from
Ettus Research as well as ADALM-PLUTO [7] and BLADE-RF [25]. We show
that the spatio-temporal property is more pronounced (and thus easier to exploit,
both explicitly as well as implicitly) when different types of SDRs (from both
same or different manufacturers) are used as transmitters.
4. We also collect three additional datasets of I/Q values from eight USRP B210s
with varying signal-to-noise-ratio (SNR). We use distance and multi-path as the
defining factors for SNR variation during data collection.
5. We train the proposed RNN models and present a competitive analysis of the
performance of our models against the state-of-the-art techniques for transmitter
classification. Results reveal that the proposed methods outperform the existing
ones, thus establishing that exploiting the spatio-temporal property of I/Q
data can offset the need for pre-processing the raw I/Q data as required by traditional
transmitter fingerprinting approaches.

6. The novelty of the proposed implementations lies in accurately modeling and imple-
menting different types of RNNs to build a robust transmitter fingerprinting system
using over-the-air signal data by exploiting spatio-temporal correlations.

4 Proposed RNN Models for Classification

In order to estimate the noise in an RF channel, the system needs to “listen” to
the underlying signal for some time and “remember” the same. Previously, neural
networks lacked this capability when used in the context of temporal data. Another
issue with using neural networks on temporal data was the problem of vanishing
gradients when using backpropagation. Both these problems were solved by
the introduction of recurrent neural networks (RNNs) [20]. Moreover, inspired by the
success of deep learning systems for the task of characterizing RF environments [30]
and the successful use of RNNs for the task of analyzing time series data [34], we
propose to use deep recurrent structures for learning transmitter “fingerprints” for
the task of transmitter classification or identification. These proposed models are
extended versions of the work presented in [39].
Formulation of temporal property of RF data
Given $T$ training samples (for $T$ timestamps), where each training sample is of size $M$ and consists of a vector of tuples of the form $(I, Q) \in \mathbb{C}$ representing a number in the complex plane, we represent a single sample as $x_t = [[(I, Q)_i]_t;\ i = 1, 2, \ldots, M] \in \mathbb{C}^M$ for each timestamp $t = 1, 2, \ldots, T$, and we use it as an input to the neural network. We use a sample size ($M$) of 1024 as a default. We want to find the probability of the input vector for the next time step ($x_{t+1}$) belonging to class $C_k$, where $k \in \{1, 2, \ldots, K\}$, $K$ being the number of classes. The probability $P(C_k \mid x_{t+1})$ can be written as

$$P(C_k \mid x_{t+1}) = \frac{P(x_t \mid C_k)\, P(C_k)}{P(x_t x_{t+1})} \qquad (1)$$

where $P(x_t \mid C_k)$ is the conditional probability of $x_t$ given the class $C_k$ and $P(x_t x_{t+1})$ is the probability of $x_t$ and $x_{t+1}$ occurring in order.

4.1 Long Short-Term Memory (LSTM) Cell Model

Though LSTM cells can be modeled and designed in various ways depending on the
need, we use the cells as shown in Fig. 1. In one LSTM cell, there are (i) three types
of gates: input ($i$), forget ($f$), and output ($o$); and (ii) a state update of the internal cell
memory. The most interesting part of the LSTM cell is the “forget” gate, which at
time $t$ is denoted by $f_t$. The forget gate decides whether to keep the cell state memory
($c_t$) or not. The forget gate is computed as per Eq. (2) from the input value $x_t$ at
time $t$ and the output $h_{t-1}$ at time $t-1$.

Fig. 1 LSTM cell architecture used in the proposed RNN model

$$f_t = \sigma(W_{xf} x_t + W_{hf} h_{t-1} + b_f) \qquad (2)$$

Note that $W_{xf}$ and $b_f$ represent the associated weight and bias, respectively, between the
input ($x$) and the forget gate ($f$), and $\sigma$ denotes the sigmoid activation function. Once $f_t$
determines which memories to forget, the input gate $i_t$ decides which candidate cell states
($\tilde{c}_t$) to use for the update, as per Eqs. (3) and (4).

$$i_t = \sigma(W_{xi} x_t + W_{hi} h_{t-1} + b_i) \qquad (3)$$

$$\tilde{c}_t = \tanh(W_{xc} x_t + W_{hc} h_{t-1} + b_c) \qquad (4)$$

In Eq. (5), the old cell state ($c_{t-1}$) is updated to the new cell state ($c_t$) using the forget
gate ($f_t$) and the input gate ($i_t$).

$$c_t = f_t \circ c_{t-1} + i_t \circ \tilde{c}_t \qquad (5)$$

Here $\circ$ is the Hadamard product. Finally, we filter the output values through the
output gate ($o_t$) based on the cell state ($c_t$) as per Eqs. (6) and (7).

$$o_t = \sigma(W_{xo} x_t + W_{ho} h_{t-1} + b_o) \qquad (6)$$

$$h_t = o_t \circ \tanh(c_t) \qquad (7)$$
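For concreteness, the following is a minimal NumPy sketch of a single LSTM cell step implementing Eqs. (2)–(7); the helper names and weight shapes are illustrative assumptions and not part of our Keras-based implementation described in Sect. 6.

```python
# Minimal NumPy sketch of one LSTM cell step, following Eqs. (2)-(7).
import numpy as np

def sigmoid(a):
    return 1.0 / (1.0 + np.exp(-a))

def lstm_step(x_t, h_prev, c_prev, W, b):
    """x_t: input at time t; h_prev, c_prev: previous hidden and cell states;
    W: dict of weight matrices (e.g., W['xf'], W['hf']); b: dict of bias vectors."""
    f_t = sigmoid(W['xf'] @ x_t + W['hf'] @ h_prev + b['f'])    # Eq. (2): forget gate
    i_t = sigmoid(W['xi'] @ x_t + W['hi'] @ h_prev + b['i'])    # Eq. (3): input gate
    c_hat = np.tanh(W['xc'] @ x_t + W['hc'] @ h_prev + b['c'])  # Eq. (4): candidate cell state
    c_t = f_t * c_prev + i_t * c_hat                            # Eq. (5): cell state update (Hadamard products)
    o_t = sigmoid(W['xo'] @ x_t + W['ho'] @ h_prev + b['o'])    # Eq. (6): output gate
    h_t = o_t * np.tanh(c_t)                                    # Eq. (7): hidden state
    return h_t, c_t
```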

4.2 Gated Recurrent Unit (GRU) Model

The main drawback of using LSTM cells is the need for additional memory. GRUs [5]
have one less gate for the same purpose, thus having a reduced memory and CPU
footprint. The GRU cells control the flow of information just like the LSTM cells,

Fig. 2 GRU cell architecture used in the proposed RNN model

but without the need for a separate memory unit. It simply exposes the full hidden content
without any control. It has an “update gate” ($z_t$), a “reset gate” ($r_t$), and a candidate state
memory ($\tilde{c}_t$), as shown in Fig. 2. The reset gate determines whether to combine the
new input with the previous hidden state or not, while the update gate decides how much
of the previous hidden state to retain. Eqs. (8)–(11), related to the different gates and states, are given
below.

$$z_t = \sigma(W_{xz} x_t + W_{hz} h_{t-1} + b_z) \qquad (8)$$

$$r_t = \sigma(W_{xr} x_t + W_{hr} h_{t-1} + b_r) \qquad (9)$$

$$\tilde{c}_t = \tanh(W_{xc} x_t + W_{hc} (r_t \circ h_{t-1})) \qquad (10)$$

$$h_t = (1 - z_t) \circ \tilde{c}_t + z_t \circ h_{t-1} \qquad (11)$$
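Analogously, a minimal NumPy sketch of one GRU cell step following Eqs. (8)–(11) is given below; again, the helper names are illustrative assumptions.

```python
# Minimal NumPy sketch of one GRU cell step, following Eqs. (8)-(11).
import numpy as np

def sigmoid(a):
    return 1.0 / (1.0 + np.exp(-a))

def gru_step(x_t, h_prev, W, b):
    z_t = sigmoid(W['xz'] @ x_t + W['hz'] @ h_prev + b['z'])     # Eq. (8): update gate
    r_t = sigmoid(W['xr'] @ x_t + W['hr'] @ h_prev + b['r'])     # Eq. (9): reset gate
    c_hat = np.tanh(W['xc'] @ x_t + W['hc'] @ (r_t * h_prev))    # Eq. (10): candidate state
    h_t = (1.0 - z_t) * c_hat + z_t * h_prev                     # Eq. (11): new hidden state
    return h_t
```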

4.3 Convolutional LSTM Network Model

The LSTM and GRU cell models described above exploit only the temporal correlation
in the data. To also capture the spatial correlation, we use a convolution within the recurrent structure of the
RNN. We first discuss the spatio-temporal property of RF data and then model a
convolutional LSTM network to exploit the same.

4.3.1 Formulation of Spatio-temporal property for RF data

Suppose that a radio signal is represented as a time-varying series over a spatial
region using $R$ rows and $C$ columns. Here $R$ represents the time-varying nature of
the signal; in our case it is the total number of timestamps at which the signal was
sampled ($T$). $C$, on the other hand, represents the total number of features sampled
at each timestamp (in our case it is 2048, since there are 1024 features sampled, each
of dimension 2). Note that each cell corresponding to one value of $R$ and one value
of $C$ represents a particular feature ($I$ or $Q$) at a given point in time.
In order to capture the temporal property only, we use a sequence of vectors cor-
responding to different timestamps $1, 2, \ldots, t$ as $x_1, x_2, \ldots, x_t$. However, to capture
both spatial and temporal properties, we introduce a new vector $\chi_{t,t+\gamma}$, which is
formulated as $\chi_{t,t+\gamma} = [x_t, x_{t+1}, \ldots, x_{t+\gamma-1}]$. So the vector $\chi_{t,t+\gamma}$ eventually pre-
serves the spatial properties with an increment of $\gamma$ in time. So, we get a sequence
of new vectors $\chi_{1,\gamma}, \chi_{\gamma,2\gamma}, \ldots, \chi_{t,t+\gamma}, \ldots, \chi_{t+(\beta-1)\gamma,\,t+\beta\gamma}$, where $\beta = R/\gamma$, and the
goal is to create a model to classify them into one of the $K$ classes (corresponding to
the transmitters). We model the class-conditional densities given by $P(\chi_{t-\gamma,t} \mid C_k)$,
where $k \in \{1, \ldots, K\}$. We formulate the probability of the next $\gamma$-length sequence to
be in class $C_k$ as per Eq. (12). The marginal probability is modeled as $P(\chi_{t,t+\gamma})$.

$$P(C_k \mid \chi_{t,t+\gamma}) = \frac{P(\chi_{t-\gamma,t} \mid C_k)\, P(C_k)}{P(\chi_{t,t+\gamma})} \qquad (12)$$
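As an illustration of this construction (a sketch under the definitions above, with assumed helper names), the per-timestamp feature vectors can be grouped into $\gamma$-length spatio-temporal blocks as follows.

```python
# Sketch: grouping per-timestamp feature vectors x_t (rows of X) into
# gamma-length spatio-temporal blocks chi_{t,t+gamma} = [x_t, ..., x_{t+gamma-1}].
import numpy as np

def make_chi_sequence(X, gamma):
    """X: array of shape (R, C), R timestamps with C = 2*M features each.
    Returns an array of shape (beta, gamma, C), where beta = R // gamma."""
    R, C = X.shape
    beta = R // gamma
    return X[:beta * gamma].reshape(beta, gamma, C)

X = np.zeros((320_000, 2048), dtype=np.float32)   # e.g., the 8-transmitter dataset dimensions
chi = make_chi_sequence(X, gamma=8)               # shape (40_000, 8, 2048)
```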

4.3.2 The Model

The cell model, as shown in Fig. 3, is similar to an LSTM cell, but the input transfor-
mations and recurrent transformations are both convolutional in nature [46]. We for-
mulate the input values, cell state, and hidden states as 3-dimensional vectors, where
the first dimension is the number of measurements, which varies with the time interval
$\gamma$, and the last two dimensions contain the spatial information (rows ($R$) and columns
($C$)). We represent these as: (i) the inputs: $\chi_{1,\gamma}, \chi_{\gamma,2\gamma}, \ldots, \chi_{t,t+\gamma}, \ldots, \chi_{t+(\beta-1)\gamma,\,t+\beta\gamma}$
(previously stated); (ii) cell outputs: $C_1, \ldots, C_t$; and (iii) hidden states: $H_1, \ldots, H_t$.
We represent the gates in a similar manner as in the LSTM model. The parameters
$i_t$, $f_t$, $o_t$, $W$, and $b$ hold the same meaning as in Sect. 4.1. The key operations are defined
in Eqs. (13)–(17). The probability of the next $\gamma$-sequence to be in a particular class (from
Eq. (12)) is used within the implementation and execution of the model.

$$i_t = \sigma(W_{xi}\, \chi_{t,t+\gamma} + W_{hi}\, H_{t-1} + b_i) \qquad (13)$$

$$f_t = \sigma(W_{xf}\, \chi_{t,t+\gamma} + W_{hf}\, H_{t-1} + b_f) \qquad (14)$$

$$C_t = f_t \circ C_{t-1} + i_t \circ \tanh(W_{xc}\, \chi_{t,t+\gamma} + W_{hc}\, H_{t-1} + b_c) \qquad (15)$$

$$o_t = \sigma(W_{xo}\, \chi_{t,t+\gamma} + W_{ho}\, H_{t-1} + b_o) \qquad (16)$$

$$H_t = o_t \circ \tanh(C_t) \qquad (17)$$

Fig. 3 ConvLSTM cell architecture used in the proposed RNN model

5 Testbed Evaluation

In order to validate the proposed models, we collected raw signal data from eight
different universal software radio peripheral (USRP) B210s [10]. We collected the
data in an indoor lab environment with a signal-to-noise ratio of 30 dB, and used the
dataset to discriminate between four and eight transmitters, as mentioned in [40]. We
also collected data with varied SNRs (20, 10, and 0 dB) for the same 8 USRP B210
transmitters. Finally, we collected data from five different types of transmitters, each
transmitter being implemented using an SDR, from several different manufacturers.

5.1 Signal Generation and Data Collection

In order to evaluate our methods for learning the inherent spatio-temporal features of
a transmitter, we used different types of SDRs as transmitters. The signal generation
and reception are shown in Fig. 4. We used GNURadio [12] to generate random
signals and modulated them with quadrature phase shift keying (QPSK). We
programmed the transmitter SDRs to transmit the modulated signal over-the-air and
sensed it using a DVB-T dongle (RTL-SDR) [24]. We generated the entire
dataset from the “over-the-air” data sensed by the RTL-SDR using the rtlsdr python
library.
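For illustration, a minimal capture loop with the pyrtlsdr library could look as follows; the sample rate, gain setting, and overall script structure are assumptions and not our exact collection script (the frequency and sample counts follow Table 1).

```python
# Sketch of I/Q data collection with pyrtlsdr: 1024 complex samples per row,
# stored as 2048 interleaved I/Q values (parameter choices are illustrative).
import numpy as np
from rtlsdr import RtlSdr

sdr = RtlSdr()
sdr.sample_rate = 1e6       # assumed capture rate
sdr.center_freq = 904e6     # 904 MHz ISM band (Table 1)
sdr.gain = 'auto'

rows = []
for _ in range(40_000):                  # 40,000 training examples per transmitter
    iq = sdr.read_samples(1024)          # complex array of length 1024
    row = np.empty(2048, dtype=np.float32)
    row[0::2] = iq.real                  # I values
    row[1::2] = iq.imag                  # Q values
    rows.append(row)
sdr.close()
dataset = np.stack(rows)                 # shape (40_000, 2048)
```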
We collected I/Q signal data with a sample size of 1024 at each timestamp. Each
data sample had 2048 entries consisting of the I and Q values for the 1024 samples.
Note that a larger sample size would mean more training examples for the neural
network. Our choice of 1024 samples was sufficient to capture the spatio-temporal
properties while at the same time keeping the training computationally tractable. We
collected 40,000 training examples from each transmitter to avoid the data skewness
problem observed in machine learning. The configuration parameters that were used
are given in Table 1. We collected two different datasets at 30 dB SNR and three
datasets having three different SNRs, as discussed below. The different types of SNR

Fig. 4 Over-the-air signal generation and data collection technique

Table 1 Transmission configuration parameters

Parameters              Values
Transmitter gain        45 dB
Transmitter frequency   904 MHz (ISM)
Bandwidth               200 kHz
Sample size             1024
Samples/transmitter     40,000
# Transmitters          4 and 8

levels were achieved in an indoor lab environment by changing the propagation delay,
multi-path, and shadowing effects. We also collected a “heterogeneous" dataset using
several different types of SDRs. Note that we intend to make the dataset publicly
available upon publication of the chapter.

5.1.1 Homogeneous Dataset

For the “homogeneous” dataset, we used eight radios from the same manufacturer,
namely, the USRP-B210 from Ettus Research [10], as transmitters. We collected two
sets of data: (i) using 4 USRP B210 transmitters: 6.8 GB size, 160K rows, and 2048
columns and (ii) using 8 USRP B210 transmitters: 13.45 GB size, 320K rows, and
2048 columns. Note that the SNR was 30 dB in each case.

5.1.2 Heterogeneous Dataset

In order to investigate the spatio-temporal correlation in the I/Q data from different
types of SDRs from varied manufacturers, we collected a “heterogeneous” dataset
as well. We used three different SDRs from the same manufacturer and two SDRs from
two other manufacturers. We used the USRP B210 [10], USRP B200 [9] and USRP
X310 [11] from Ettus Research. We also used BLADE RF [25] by Nuand and PLUTO
SDR [7] by Analog Devices as two different SDRs from two different manufacturers.

The signal generation procedure is similar to Fig. 4 with different SDR models as
transmitters. The SNR remains 30 dB, same as earlier. The “heterogeneous” datasets
were obtained using (i) 5 USRP B210 transmitters: 8.46 GB size, 200K rows, and
2048 columns and (ii) 1 USRP B210, 1 USRP B200, 1 USRP X310, 1 BLADERF,
and 1 PLUTO transmitter: 6.42 GB size, 200K rows, and 2048 columns.
We include the 5 USRP B210 homogeneous data in this dataset to perform a
fair comparison of five heterogeneous radios with five homogeneous ones in
Sect. 6.6, as the homogeneous datasets mentioned in the previous paragraph contain
either four or eight homogeneous radios.

5.1.3 Varying SNR Datasets

We collected three more datasets with 8 USRP B210 transmitters with SNRs of 20
dB, 10 dB, and 0 dB, respectively. Each dataset is of size ∼13 GB with 320K rows
and 2048 columns.

5.2 Spatial Correlation in the Homogeneous Dataset

Correlation between data samples plays a crucial role in the process of transmitter
identification. We represent the I and Q values of each training sample at time (t)
as: [I0 Q 0 I1 Q 1 I2 Q 2 I3 Q 3 I4 Q 4 . . . I1023 Q 1023 ]t . We used the QPSK modulation [48]
which means that the spatial correlation should be between every fourth value, i.e.,
between I0 and I4 , and Q 0 and Q 4 . So we calculate the correlation coefficient of
I0 I1 I2 I3 and I4 I5 I6 I7 . Similarly, for Q 0 Q 1 Q 2 Q 3 and Q 4 Q 5 Q 6 Q 7 . We take the aver-
age of all the correlation coefficients for each sample.
We use numpy.corrcoef for this purpose which uses Pearson product-moment
correlation coefficients, denoted by r . The Pearson’s method for a sample is given
by
(M−1)
i=0 (Ii − I )(Q i − Q̄)
r =   (18)
(M−1) (M−1)
i=1 (I i − I ) 2
i=0 (Q i − Q̄) 2

where M is the sample size, Ii and Q i are the sample values indexed with i. The
1 (M−1)
sample mean is I¯ = Ii .
M i=0
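A short sketch of this computation (assumed helper name, following the description above) is given below.

```python
# Sketch: average spatial correlation of one training sample, using numpy.corrcoef
# between consecutive groups of four I values and four Q values.
import numpy as np

def avg_spatial_correlation(row):
    """row: 2048 interleaved I/Q values [I0, Q0, I1, Q1, ...] of one sample."""
    I, Q = row[0::2], row[1::2]
    coeffs = []
    for k in range(0, len(I) - 7, 4):
        coeffs.append(np.corrcoef(I[k:k + 4], I[k + 4:k + 8])[0, 1])   # I-group correlation
        coeffs.append(np.corrcoef(Q[k:k + 4], Q[k + 4:k + 8])[0, 1])   # Q-group correlation
    return np.nanmean(coeffs)                                          # average over the whole sample
```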
The spatial correlations of all the samples for the different transmitters are shown
in Fig. 5. We observe that for most of the transmitters, the correlation is ∼0.42, with
a standard deviation of ∼0.2. However, transmitter 3 exhibits minimal correlation
between these samples, which implies that the spatial property of transmitter 3 is
different from the other transmitters. As a result transmitter 3 should be easily dis-
tinguishable from the others. This claim will be validated later in the experimental

Fig. 5 Spatial correlation in the datasets

result section, where we see 0% false positives and false negatives for transmitter 3 for
all the three proposed models. This observation gives us the motivation to exploit the
spatial property as well as the temporal property for the collected time series data.

5.3 Spatial Correlation in the Heterogeneous Dataset

The calculated average spatial correlations of samples for five different types of
transmitters (as mentioned in Sect. 5.1.2 (ii)) are shown in Fig. 6. We observe that
data from USRP-B210 and PLUTO-SDR have better correlations than the other three
types. It is to be noted that we calculated the spatial correlation of this data using the
same technique described in the previous section. It is also evident from the figure that
none of the transmitters exhibits strong correlations; however, each has correla-
tions in a different range. This phenomenon bolsters our claim that the spatio-temporal
property in the heterogeneous data will be more distinguishable than in the homogeneous
data, as validated later in the experimental results section (Sect. 6.6).

Fig. 6 Spatial correlation in the heterogeneous datasets

5.4 Experimental Setup and Performance Metrics

We conducted the experiments on a Ryzen 8-core system with 64 GB RAM and a GTX
1080 Ti GPU with 11 GB of memory. We use Keras [4] as the frontend and
TensorFlow [1] as the backend for our implementations. During the training phase,
we use data from each transmitter to train the neural network model. In order to test
the resulting trained model, we use test data collected from one of the transmitters
and present it to the trained network. In general, to measure the effectiveness
of any learning algorithm, “accuracy” is used as the typical performance metric.
However, accuracy can sometimes be misleading and incomplete when the data is
skewed. For the task of classification, a confusion matrix overcomes this problem
by showing how confused the learned model is in its predictions. It provides more
insight into the performance by identifying not only the number of errors, but also,
more importantly, the types of errors.
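Computing such a confusion matrix is straightforward, e.g., with scikit-learn (a generic illustration, not tied to our data):

```python
# Generic illustration of computing a confusion matrix for a K-class classifier.
import numpy as np
from sklearn.metrics import confusion_matrix

y_true = np.array([0, 0, 1, 2, 2, 3])    # true transmitter labels
y_pred = np.array([0, 1, 1, 2, 2, 3])    # predicted labels
cm = confusion_matrix(y_true, y_pred)     # rows: true class, columns: predicted class
print(cm)
```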

6 Model Implementations and Results

In this section we discuss the implementation of each of the proposed recurrent neural
networks. We train each network for transmitter classification with K classes. For
the sake of robustness and statistical significance, we present the results for each
model after averaging over several runs.

6.1 Implementation with LSTM Cells

As discussed earlier, the recurrent structure of the neural network can be used to
exploit the temporal correlation in the data. To that end, we first implemented a
recurrent neural network with LSTM cells and trained it on the collected dataset
using the paradigm as shown in Fig. 7. We used two LSTM layers with 1024 and 256
units sequentially. We also used a dropout rate of 0.5 in between these two LSTM
layers. Next we used two fully connected (Dense) layers with 512 and 256 nodes,
respectively. We apply a dropout rate of 0.2, and add batch normalization [17] on
the output, finally passing it through a Dense layer having eight nodes. We use ReLU
[23] as the activation function for the LSTM layers and tanh [3] for the Dense layers.

Fig. 7 RNN implementation with LSTM cells for transmitter classification
Lastly, we use stochastic gradient descent [3]-based optimization with categorical
cross-entropy training. Note that the neural network architecture was finalized over
several iterations of experimentation with the data and we are only reporting the
final architecture here. We achieved 97.17% and 92.00% testing accuracy for four
and eight transmitters, respectively. The accuracy plots and confusion matrices are
shown in Figs. 8 and 9, respectively. Note that the number of nodes in the last layer
is equal to the number of classes in the dataset. It is also to be noted that during the
process of designing the RNN architecture, we also fine-tuned the hyper-parameters
based on the generalization ability of the current network (as determined by comparing the
training and validation errors). We also limited the number of recurrent layers and
fully connected layers for each model for faster training [15], since no significant

Fig. 8 Accuracy plots for transmitter classification using LSTM cells

Fig. 9 Confusion matrices for transmitter classification using LSTM cells



increase in the validation accuracy was observed after increasing the number of
layers.
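For concreteness, a minimal Keras sketch consistent with this description is shown below; the reshaping of each 2048-entry row into a (1024, 2) I/Q sequence and the softmax output are our assumptions, since only an input shape of (None, 2048) and a final eight-node Dense layer are specified above.

```python
# Keras sketch of the LSTM-based classifier described in Sect. 6.1 (assumptions noted above).
from tensorflow.keras import layers, models, optimizers

model = models.Sequential([
    layers.Input(shape=(2048,)),
    layers.Reshape((1024, 2)),                                # 1024 timesteps of (I, Q) pairs (assumed)
    layers.LSTM(1024, activation='relu', return_sequences=True),
    layers.Dropout(0.5),
    layers.LSTM(256, activation='relu'),
    layers.Dense(512, activation='tanh'),
    layers.Dense(256, activation='tanh'),
    layers.Dropout(0.2),
    layers.BatchNormalization(),
    layers.Dense(8, activation='softmax'),                    # one node per transmitter class
])
model.compile(optimizer=optimizers.SGD(),
              loss='categorical_crossentropy',
              metrics=['accuracy'])
```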
The rows and columns of the confusion matrix correspond to the number of
transmitters (classes) and the cell values show the recall or sensitivity and false
negative rate for each of the transmitters. Note that recall or sensitivity represents
the true positive rates for each of the prediction classes.

6.2 Implementation with GRU Cells

Next we implemented another variation of the RNN model using GRU cells for
leveraging temporal correlation. We used the same architecture as the LSTM imple-
mentation, presented in Fig. 10. The proposed GRU implementation needs fewer
parameters than the LSTM model. A quantitative comparison is given in Sect. 6.4.
The only difference is that we use two GRU layers with 1024 and 256 units instead
of using LSTM cells. We achieved 97.76% and 95.30% testing accuracy for four and
eight transmitters, respectively. The accuracy plots and confusion matrices are given
in Figs. 11 and 12. The GRU implementation provided a slight improvement over
the accuracy obtained using LSTM, for each run of the models, for both the datasets.

Fig. 10 RNN implementation with GRU cells for transmitter classification

Fig. 11 Accuracy plots for transmitter classification using GRU cells



Fig. 12 Confusion matrices for transmitter classification using GRU cells

6.3 Implementation with ConvLSTM2D Cells

Finally, in order to exploit the spatio-temporal property of the signal data, we imple-
mented another variation of the LSTM model with convolutional filters (transforma-
tions). The implemented architecture is shown in Fig. 13. ConvLSTM2D uses two-
dimensional convolutions for both input transformations and recurrent transforma-
tions. We first use two layers of convLSTM2D with 1024 and 256 filters, respectively,
and a dropout rate of 0.5 in between. We use a kernel size of (2, 2) and a stride of (2, 2)
at each ConvLSTM2D layer. Next, we add two fully connected (Dense) layers having
512 and 256 nodes, respectively, after flattening the convolutional output. ReLU [23]
and tanh [3] activation functions are used for the ConvLSTM2D and Dense layers,
respectively. ADADELTA [59], with a learning rate of $10^{-4}$ and a decay rate of 0.9,
is used as the optimizer with categorical cross-entropy training. We achieved 98.9%
and 97.2% testing accuracy for four and eight transmitters, respectively. The accu-
racy plots and confusion matrices are given in Figs. 14 and 15, respectively. Being
able to exploit the spatio-temporal correlation, the ConvLSTM implementation provides
an improvement over the accuracies obtained using the LSTM and GRU models, for
both datasets.
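A corresponding Keras sketch of the ConvLSTM2D model is given below; the 5-dimensional input layout (time steps, rows, columns, channels) is not specified above, so the (8, 16, 16, 1) reshape and the softmax output are purely illustrative assumptions.

```python
# Keras sketch of the ConvLSTM2D-based classifier of Sect. 6.3 (input layout assumed).
from tensorflow.keras import layers, models, optimizers

model = models.Sequential([
    layers.Input(shape=(2048,)),
    layers.Reshape((8, 16, 16, 1)),                           # assumed spatio-temporal layout
    layers.ConvLSTM2D(1024, kernel_size=(2, 2), strides=(2, 2),
                      activation='relu', return_sequences=True),
    layers.Dropout(0.5),
    layers.ConvLSTM2D(256, kernel_size=(2, 2), strides=(2, 2),
                      activation='relu'),
    layers.Flatten(),
    layers.Dense(512, activation='tanh'),
    layers.Dense(256, activation='tanh'),
    layers.Dense(8, activation='softmax'),                    # one node per transmitter class
])
model.compile(optimizer=optimizers.Adadelta(learning_rate=1e-4, rho=0.9),
              loss='categorical_crossentropy',
              metrics=['accuracy'])
```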

Fig. 13 RNN implementation with ConvLSTM cells for transmitter classification



Fig. 14 Accuracy plots for transmitter classification using ConvLSTM cells

Fig. 15 Confusion matrices for transmitter classification using ConvLSTM cells

6.4 Comparisons of LSTM/GRU/ConvLSTM Implementations

We used 90%, 5%, and 5% of the data to train, validate, and test, respectively. We
ran each model for 50 epochs with early-stopping on the validation set. One epoch
consists of a forward pass and a backward pass through the implemented architec-
ture for the entire dataset. The overall accuracies of the different implementations
are shown in Table 2. We find that the implementation of convolutional layers within a
recurrent structure (ConvLSTM2D) exhibits the best accuracy for transmitter classi-
fication, which clearly shows the advantage of using the spatio-temporal correlation
present in the collected datasets. In Fig. 16, we present a visual comparison of the
achieved classification accuracies for the different implemented models.
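The split and training loop can be sketched as follows (reusing the model and dataset arrays assumed in the earlier sketches; the early-stopping patience is an assumption):

```python
# Sketch of the 90/5/5 split and 50-epoch training with early stopping on the validation set.
import numpy as np
from tensorflow.keras.callbacks import EarlyStopping

def split_90_5_5(X, y, seed=0):
    idx = np.random.default_rng(seed).permutation(len(X))
    n_tr, n_va = int(0.90 * len(X)), int(0.05 * len(X))
    tr, va, te = np.split(idx, [n_tr, n_tr + n_va])
    return (X[tr], y[tr]), (X[va], y[va]), (X[te], y[te])

(X_tr, y_tr), (X_va, y_va), (X_te, y_te) = split_90_5_5(dataset, labels)
model.fit(X_tr, y_tr,
          validation_data=(X_va, y_va),
          epochs=50,
          callbacks=[EarlyStopping(monitor='val_loss', patience=5, restore_best_weights=True)])
test_loss, test_acc = model.evaluate(X_te, y_te)
```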

Table 2 Accuracy for different implementations

#Trans   Models                #Parameters (M)   Acc (%)
4        LSTM (6 layers)       14.2              97.17
4        GRU (6 layers)        10.7              97.76
4        ConvLSTM (6 layers)   14.2              98.90
8        LSTM (6 layers)       14.2              92.00
8        GRU (6 layers)        10.7              95.30
8        ConvLSTM (6 layers)   14.2              97.20

Fig. 16 Comparison of testing accuracies of different types of recurrent neural networks

6.5 Computational Complexities

In this section, we concentrate on the computational time complexities of one epoch
of the training phase only, as the trained model gives the output within constant
time ($O(1)$) during the deployment phase. Understanding the time complexity of
training an RNN is still an evolving research area. The proposed RNN models are
a combination of two recurrent layers and four fully connected layers. Hence, we
analyze the time complexities of both types separately. In [22], the authors proved that
a fully connected NN of depth $\delta$ can be learned in $\mathrm{poly}(s^{2^{\delta}})$ time, where $s$ is the
dimension of the input ($= T$ in the proposed models), and $\mathrm{poly}(\cdot)$ takes a constant

time depending on the configuration of the system. However, the complexity of


LSTM and GRU layers depends on the total number of parameters in the network
[43]. It can be expressed as O(Plstm × T ) for LSTM layers, where Plstm is the total
number of parameters in LSTM network, T is number of timesteps, or total number
of training data samples. Similarly for the GRU layers it will be O(Pgr u × T ), where
Pgr u is the total number of parameters in GRU network. However, the computational
complexity of ConvLSTM layer will depend on the complexity of convolution as
well as LSTM layers. In [53], the authors mentioned that the time complexity for

training all the convolutional layers is O( τ =1 (ητ −1 ντ2 .ητ ρ2τ ), where ζ is the number
of convolutional layers, τ is the index of a convolutional layer, ητ −1 is the number
of input channels of the τ th layer, ντ is the spatial size of the filters at the τ th layer,
ητ is the number of filters at the τ th layer, and ρτ is the size of the output features of
the τ th layer. In the proposed ConvLSTM model, we have two ConvLSTM layers,
and four fully connected layers, therefore, we add in additional time complexity for
training those convolutional layers.
The time complexities for each implemented RNN model on the homogeneous
datasets are presented in Table 3, using the aforementioned results on the time complexity
of neural network training. The total number of parameters used in each network
is shown in Table 2. The numbers within the parentheses in the second column
represent the total number of layers for a particular model. Note that we have two
different datasets of 160K and 320K rows and, as mentioned earlier, we use
95% of the data for training and validation purposes. For example, the complexity for
ConvLSTM with six layers using 95% of 160e3 data samples for training and val-
idation is $O\big(\sum_{\tau=1}^{2} (\eta_{\tau-1} \cdot \nu_{\tau}^2 \cdot \eta_{\tau} \cdot \rho_{\tau}^2) + 0.95 \times 160e3\big)$ for the two ConvLSTM layers
and $\mathrm{poly}\big((0.95 \times 160e3)^{2^4}\big)$ for the four fully connected layers. Similarly, the computa-
tional complexity of LSTM with six layers using 95% of 160e3 data samples for
training and validation is $O(14.2e6 \times 0.95 \times 160e3)$ for the two LSTM layers (where
14.2e6 is the number of parameters) and $\mathrm{poly}\big((0.95 \times 160e3)^{2^4}\big)$ for the four fully con-
nected layers.

Table 3 Computational complexities for training one epoch of the proposed implementations

#Trans   Models         Complexity
4        LSTM (6)       O(14.2e6 × 0.95 × 160e3) + poly((0.95 × 160e3)^(2^4))
4        GRU (6)        O(10.7e6 × 0.95 × 160e3) + poly((0.95 × 160e3)^(2^4))
4        ConvLSTM (6)   O(Σ_{τ=1}^{2}(η_{τ−1} · ν_τ² · η_τ · ρ_τ²) + 0.95 × 160e3) + poly((0.95 × 160e3)^(2^4))
8        LSTM (6)       O(14.2e6 × 0.95 × 320e3) + poly((0.95 × 320e3)^(2^4))
8        GRU (6)        O(10.7e6 × 0.95 × 320e3) + poly((0.95 × 320e3)^(2^4))
8        ConvLSTM (6)   O(Σ_{τ=1}^{2}(η_{τ−1} · ν_τ² · η_τ · ρ_τ²) + 0.95 × 320e3) + poly((0.95 × 320e3)^(2^4))

6.6 Experiments with Heterogeneous Dataset

So far, we have implemented the proposed RNN models for “homogeneous” datasets,
where transmitter SDRs were from the same manufacturer. However, in reality the
transmitters can be of different models from either the same manufacturer or several
different manufacturers. Now, we want to explore how the accuracy of transmitter
identification would change if “heterogeneous” data (as was discussed in Sect. 5.1.2)
obtained from different types of transmitters (manufacturers) were used. From the
testing accuracies shown in Table 4, we observe that all the RNNs perform better
when the transmitters are of different models, either from the same or from different manufacturers,
and hence are fundamentally of different types. The performance of LSTM, GRU,
and ConvLSTM increases by 5%, 3.37%, and 1.51%, respectively, when classifying het-
erogeneous radios rather than homogeneous ones. This confirms the intuition that radios
manufactured using different processes (from different manufacturers) contain easily
exploitable characteristics in their I/Q samples that can be implicitly learned using
an RNN. The comparison of confusion matrices for all three proposed models is
presented in Figs. 17, 18, and 19. The false positives and false negatives are observed to
be considerably lower for the 5-HETERO results than for the 5-B210s. It is to be noted that

Table 4 Comparison of testing accuracies for different classification models for homogeneous and
heterogeneous datasets

Models     5-B210s (Acc) (%)   5-HETERO (Acc) (%)   Change
LSTM       95.61               99.89                5%↑
GRU        96.72               99.97                3.37%↑
ConvLSTM   98.5                99.99                1.51%↑

Fig. 17 Confusion matrices for transmitter classification using LSTM cells for heterogeneous
dataset

Fig. 18 Confusion matrices for transmitter classification using GRU cells for heterogeneous dataset

Fig. 19 Confusion matrices for transmitter classification using ConvLSTM cells for heterogeneous
dataset

we used the same proposed RNN models for the heterogeneous data too, thus implying
the robustness of the proposed models.

6.7 Comparisons of Proposed and Existing Approaches

Next, we present two comparative studies of our proposed implementations with
state-of-the-art techniques. We introduce a differential analysis of different RNN-
based implementations in the RF domain in Table 5. Another comparative study of
different transmitter classification techniques is shown in Table 6.

Table 5 Comparison of proposed approach with the existing RNN implementations

Approaches                                   Model      SNR (dB)   Acc (%)   Inputs
Traffic sequence recognition [26]            LSTM       20         31.20     Hybrid real-synthetic dataset
Automatic modulation classification [33]     LSTM       20         90        Synthetic dataset [32]
Transmitter classification (Ours)            ConvLSTM   30         97.20     Raw signal
Hetero-transmitter classification (Ours)     ConvLSTM   30         99.99     Raw signal

Table 6 Comparison of our implementation with the existing transmitter classification approaches

Approach                                         #Trans   SNR (dB)        Acc (%)   Inputs
Orthogonal component reconstruction (OCR) [57]   3        20              62–71     Spurious modulation
Genetic algorithm [51]                           5        25              85–98     Transients
Multifractal segmentation [45]                   8        Not mentioned   92.50     Transients
k-NN [18]                                        8        30              97.20     Transients
Ours                                             8        30              97.20     Raw signal
Ours-hetero                                      5        30              99.99     Raw signal

The “Inputs” column in both tables refers to the type of inputs used for the
methods under consideration. Table 5 shows a comparison of our ConvLSTM-based
RNN for transmitter classification with other RNN-based implementations for sep-
arate tasks like modulation recognition and traffic sequence recognition. Table 6
establishes the efficacy of our ConvLSTM-based RNN model for the task of trans-
mitter classification in terms of testing accuracies. It is to be noted that all the other
methods use expert-crafted features as inputs [18, 45, 51, 57] or work with synthetic
datasets [26, 33]. Our method, on the other hand, achieves superior accuracy for both
homogeneous (97.20%) and heterogeneous (99.99%) datasets, using features auto-
matically learned from the raw signal data, thereby paving the way for real-time
deployment of large-scale transmitter identification systems.

6.8 Performance Comparison for Varying SNR

In this section, we present the results of transmitter fingerprinting for varying SNR
values. We compare the accuracies of the proposed RNN models for 8 USRP
B210s with 30 dB SNR against 3 other datasets collected at 0, 10, and 20 dB SNR
with the same number of transmitters (8 B210s), as shown in Table 7. We achieve
better accuracies with all the models for higher SNR values, which is intuitive. It is
to be mentioned that the proposed ConvLSTM RNN model gives more than 93%
accuracy even at 0 dB SNR, whereas the GRU model gives less than that, and LSTM
fails to achieve a comparable level.
Moreover, the proposed RNN models can be trained using raw signal data from any
type of radio transmitter operating in indoor as well as outdoor environments. We
would also like to point out that though our data was collected in a lab environment,
we had no control over the environment: there were other transmissions in progress,
people were moving in and out of the lab, and there was a lot of multi-path due to
the location and design of the lab. Furthermore, the power of the transmitters was
low, which compounded the problem further. Given this, though we say that
the data was collected in a lab environment, in reality it was an uncontrolled daily-
use environment reflective of our surroundings. Thus we can safely say that these
methods will work in any real-world deployment of a large-scale radio network. In
summary,
• Exploiting temporal correlation only, recurrent neural networks yield 95–97%
accuracy for transmitter classification using LSTM or GRU cells. The RNN imple-
mentation with GRU cells needs fewer parameters than that with LSTM cells, as shown in
Table 2.
• Exploiting spatio-temporal correlation, the implementation of RNN using Con-
vLSTM2D cells provides better accuracy (97–98%) for transmitter classification,
thus providing a potential tool for building automatic real-world transmitter iden-
tification systems.
• The spatio-temporal correlation is more noticeable (with 1.5–5% improvement
in classification accuracies) in the proposed RNN models for heterogeneous
transmitters, i.e., either different models from the same manufacturer or different models
from different manufacturers.

Table 7 Accuracies for different recurrent neural network models with varying SNRs

SNR (dB)   LSTM (%)   GRU (%)   ConvLSTM (%)
0          84.23      90.3      93.3
10         90.21      92.64     95.64
20         91.89      94.02     97.02
30         92.00      95.30     97.20

• The proposed RNN models give better accuracies with increasing SNRs of the
data collection environment. However, the ConvLSTM model is able to classify
with 93% accuracy at 0 dB SNR too, proving the robustness of spatio-temporal
property exploitation.
• We present a comparative study of the proposed spatio-temporal property-based
fingerprinting with the existing traditional and neural network-based models. This
clearly shows that the proposed model achieves better accuracies compared to
any of the existing methods for transmitter identification.

7 Summary

With more and more autonomous deployments of wireless networks, accurate knowl-
edge of the RF environment is becoming indispensable. In recent years, there has
been a proliferation of autonomous systems that use deep learning algorithms on
large-scale historical data. To that end, the inherent recurrent structures within the
RF historical data can also be leveraged by deep learning algorithms for reliable
future prognosis.
In this chapter, we addressed some of the fundamental challenges of how to
effectively apply different learning techniques in the RF domain. We presented a
robust transmitter identification technique by exploiting both the inherent spatial
and temporal properties of RF signal data. The testbed implementation and result
analysis prove the effectiveness of the proposed deep learning models. A future
step forward would be to apply these methods to the identification of actual infrastructure
transmitters (for example, FM, AM, and GSM) in real-world settings.

8 Further Reading

More details on deep learning algorithms can be found in [55]. Advanced applications
of deep learning in RF domain involving adversaries can be found in [35, 36, 40].
Deep learning applications in advanced fields of RF, such as dynamic spectrum
access, are discussed in [37, 38].

References

1. M. Abadi et al., Tensorflow: large-scale machine learning on heterogeneous distributed systems.
CoRR (2016)
2. S. Bai, M. Yan, Y. Luo, Q. Wan, RFedRNN: an end-to-end recurrent neural network for radio
frequency path fingerprinting, in Recent Trends and Future Technology in Applied Intelligence
(2018), pp. 560–571

3. C.M. Bishop, Pattern Recognition and Machine Learning (Information Science and Statistics)
(Springer, 2006)
4. F. Chollet, et al., Keras: the python deep learning library (2015). https://keras.io
5. J. Chung, C. Gülçehre, K. Cho, Y. Bengio, Empirical evaluation of gated recurrent neural
networks on sequence modeling. CoRR (2014). arXiv:abs/1412.3555
6. B. Danev, S. Capkun, Transient-based identification of wireless sensor nodes, in International
Conference on Information Processing in Sensor Networks (2009), pp. 25–36
7. A. Devices, ADALM-PLUTO overview (2020). https://wiki.analog.com/university/tools/pluto
8. R. Dey, F.M. Salemt, Gate-variants of gated recurrent unit (GRU) neural networks, in 2017
IEEE 60th International Midwest Symposium on Circuits and Systems (MWSCAS) (2017), pp.
1597–1600
9. Ettus Research: USRP B200 (2020). https://www.ettus.com/all-products/ub200-kit/
10. Ettus Research: USRP B210 (2020). https://www.ettus.com/product/details/UB210-KIT/
11. Ettus Research: USRP X310 (2020). https://www.ettus.com/all-products/x310-kit/
12. GNURadio: GNU Radio (2020). https://www.gnuradio.org
13. I. Goodfellow, Y. Bengio, A. Courville, Deep Learning. MIT press (2016)
14. A. Gritsenko, Z. Wang, T. Jian, J. Dy, K. Chowdhury, S. Ioannidis, Finding a ‘New’ needle
in the haystack: unseen radio detection in large populations using deep learning, in IEEE
International Symposium on Dynamic Spectrum Access Networks (DySPAN) (2019), pp. 1–10
15. K. He, J. Sun, Convolutional neural networks at constrained time cost, in IEEE CVPR (2015)
16. S. Hochreiter, J. Schmidhuber, Long short-term memory. Neural Comput. 9(8), 1735–1780
(1997)
17. S. Ioffe, C. Szegedy, Batch normalization: accelerating deep network training by reducing
internal covariate shift. CoRR (2015). arXiv:abs/1502.03167
18. I.O. Kennedy, P. Scanlon, F.J. Mullany, M.M. Buddhikot, K.E. Nolan, T.W. Rondeau, Radio
transmitter fingerprinting: a steady state frequency domain approach, in IEEE Vehicular Tech-
nology Conference (2008), pp. 1–5
19. B.P. Lathi, Modern Digital and Analog Communication Systems, 3rd edn. (Oxford University
Press Inc, USA, 1998)
20. Y. LeCun, Y. Bengio, G. Hinton, Deep learning. Nature 521(7553), 436–444 (2015)
21. H.W. Lin, M. Tegmark, D. Rolnick, Why does deep and cheap learning work so well? J. Stat.
Phys. 168(6), 1223–1247 (2017)
22. R. Livni, S. Shalev-Shwartz, O. Shamir, On the computational efficiency of training neural
networks, in Advances in Neural Information Processing Systems (2014), pp. 855–863
23. V. Nair, G.E. Hinton, Rectified linear units improve restricted boltzmann machines, in Proceed-
ings of International Conference on International Conference on Machine Learning (2010),
pp. 807–814
24. NooElec: NESDR Mini (RTL2832U/R820T RTL-SDR) (2018). http://www.nooelec.com/store/sdr/sdr-receivers/nesdr-mini-rtl2832-r820t.html
25. Nuand: bladeRF 2.0 micro xA4 (2020). https://www.nuand.com/product/bladeRF-xA4/
26. T.J. O’Shea, S. Hitefield, J. Corgan, End-to-end Radio traffic sequence recognition with recur-
rent neural networks, in IEEE Global Conference on Signal and Information Processing (Glob-
alSIP) (2016), pp. 277–281
27. T. O’Shea, N. West, Radio machine learning dataset generation with GNU radio. Proc. GNU
Radio Conf. 1(1) (2016)
28. T.J. O’Shea, T.C. Clancy, R.W. McGwier, Recurrent neural radio anomaly detection. CoRR
(2016). arXiv:abs/1611.00301
29. T.J. O’Shea, J. Corgan, T.C. Clancy, Convolutional radio modulation recognition networks, in
Engineering Applications of Neural Networks (2016), pp. 213–226
30. T. O’Shea, J. Hoydis, An introduction to deep learning for the physical layer. IEEE Trans.
Cogn. Commun. Netw. 3(4), 563–575 (2017)
31. T.L. O’Shea, T. Roy, T.C. Clancy, Over-the-air deep learning based radio signal classification.
IEEE J. Sel. Top. Signal Process. 12(1), 168–179 (2018)
32. radioML: RFML 2016 (2016). https://github.com/radioML/dataset

33. S. Rajendran et al., Deep learning models for wireless signal classification with distributed
low-cost spectrum sensors. IEEE Trans. Cogn. Commun. Netw. 4(3), 433–445 (2018)
34. J.S. Ren, Y. Hu, Y.W. Tai, C. Wang, L. Xu, W. Sun, Q. Yan, Look, listen and learn-a multimodal
LSTM for speaker identification, in AAAI (2016), pp. 3581–3587
35. D. Roy, T. Mukherjee, M. Chatterjee, Machine learning in adversarial RF environments. IEEE
Commun. Mag. 57(5), 82–87 (2019)
36. D. Roy, T. Mukherjee, M. Chatterjee, E. Blasch, E. Pasiliao, RFAL: adversarial learning for
RF transmitter identification and classification. IEEE Trans. Cogn. Commun. Netw. (2019)
37. D. Roy, T. Mukherjee, M. Chatterjee, E. Pasiliao, Defense against PUE attacks in DSA networks
using GAN based learning, in IEEE Global Communications Conference (2019)
38. D. Roy, T. Mukherjee, M. Chatterjee, E. Pasiliao, Primary user activity prediction in DSA
networks using recurrent structures, in IEEE International Symposium on Dynamic Spectrum
Access Networks (2019), pp. 1–10
39. D. Roy, T. Mukherjee, M. Chatterjee, E. Pasiliao, RF transmitter fingerprinting exploiting
spatio-temporal properties in raw signal data, in IEEE International Conference on Machine
Learning and Applications (2019), pp. 89–96
40. D. Roy, T. Mukherjee, M. Chatterjee, E. Pasiliao, Detection of rogue RF transmitters using
generative adversarial nets, in IEEE Wireless Communications and Networking Conference
(WCNC) (2019)
41. H. Rutagemwa, A. Ghasemi, S. Liu, Dynamic spectrum assignment for land mobile radio
with deep recurrent neural networks, in IEEE International Conference on Communications
Workshops (ICC Workshops) (2018), pp. 1–6
42. Y.B. Saied, A. Olivereau, D-HIP: a distributed key exchange scheme for HIP-based internet of
things, in World of Wireless, Mobile and Multimedia Networks (WoWMoM) (2012), pp. 1–7
43. H. Sak, A.W. Senior, F. Beaufays, Long short-term memory based recurrent neural network
architectures for large vocabulary speech recognition. CoRR (2014). arXiv:abs/1402.1128
44. K. Sankhe, M. Belgiovine, F. Zhou, L. Angioloni, F. Restuccia, S. D’Oro, T. Melodia, S.
Ioannidis, K. Chowdhury, No radio left behind: radio fingerprinting through deep learning of
physical-layer hardware impairments. IEEE Trans. Cogn. Commun. Netw. 1 (2019)
45. D. Shaw, W. Kinsner, Multifractal modelling of radio transmitter transients for classification,
in IEEE WESCANEX (1997), pp. 306–312
46. X. Shi, Z. Chen, H. Wang, D.Y. Yeung, W.K. Wong, W.C. Woo, Convolutional LSTM network:
a machine learning approach for precipitation nowcasting, in Proceedings of the 28th Interna-
tional Conference on Neural Information Processing Systems, vol. 1 (2015), pp. 802–810
47. Y. Shi, K. Davaslioglu, Y.E. Sagduyu, W.C. Headley, M. Fowler, G. Green, Deep learning for RF
signal classification in unknown and dynamic spectrum environments, in IEEE International
Symposium on Dynamic Spectrum Access Networks (DySPAN) (2019), pp. 1–10
48. S.W. Smith, The Scientist and Engineer’s Guide to Digital Signal Processing (California Tech-
nical Publishing, San Diego, CA, USA, 1997)
49. S. Soltani, Y.E. Sagduyu, R. Hasan, K. Davaslioglu, H. Deng, T. Erpek, Real-time experimen-
tation of deep learning-based RF signal classifier on FPGA, in IEEE International Symposium
on Dynamic Spectrum Access Networks (DySPAN) (2019), pp. 1–2
50. M. Stanislav, T. Beardsley, Hacking IoT: a case study on baby monitor exposures and vulner-
abilities. Rapid 7 (2015)
51. J. Toonstra, W. Kinsner, A radio transmitter fingerprinting system ODO-1, in Canadian Con-
ference on Electrical and Computer Engineering, vol. 1 (1996), pp. 60–63
52. D. Tse, P. Viswanath, Fundamentals of Wireless Communication (Oxford University Press Inc,
USA, 2005)
53. E. Tsironi, P. Barros, C. Weber, S. Wermter, An analysis of convolutional long short-term
memory recurrent neural networks for gesture recognition. Neurocomputing 268, 76–86 (2017)
54. N. Wagle, E. Frew, Spatio-temporal characterization of airborne radio frequency environments,
in IEEE GLOBECOM Workshops (GC Wkshps) (2011), pp. 1269–1273
55. M.A. Wani, F.A. Bhat, S. Afzal, A.I. Khan, Advances in Deep Learning (Springer, 2020)
56. R.H. Weber, R. Weber, Internet of Things, vol. 12 (Springer, 2010)

57. S. Xu, L. Xu, Z. Xu, B. Huang, Individual radio transmitter identification based on spurious
modulation characteristics of signal envelop, in IEEE MILCOM (2008), pp. 1–5
58. K. Youssef, L. Bouchard, K. Haigh, J. Silovsky, B. Thapa, C.V. Valk, Machine learning approach
to RF transmitter identification. IEEE J. Radio Freq. Identif. 2(4), 197–205 (2018)
59. M.D. Zeiler, ADADELTA: an adaptive learning rate method. CoRR (2012).
arXiv:abs/1212.5701
Human Target Detection and Localization with Radars Using Deep Learning

Michael Stephan, Avik Santra, and Georg Fischer

Abstract In this contribution, we present a novel radar pipeline based on deep
learning for detection and localization of human targets in indoor environments. The
detection of human targets can assist in energy savings in commercial buildings, pub-
lic spaces, and smart homes by automatic control of lighting, heating, ventilation, and
air conditioning (HVAC). Such smart sensing applications can facilitate monitoring,
controlling, and thus saving energy. Conventionally, the detection of radar targets is
performed either in the range-Doppler domain or in the range-angle domain. Based
on the application and the radar sensor, the angle or Doppler is estimated subsequently
to finally localize the human target in 2D space. When the detection is performed on
the range-Doppler domain, the processing pipeline includes moving target indica-
tors (MTI) to remove static targets on range-Doppler images (RDI), maximal ratio
combining (MRC) to integrate data across antennas, followed by constant false alarm
rate (CFAR)-based detectors and clustering algorithms to generate the processed RDI
detections. In the other case, the pipeline replaces MRC with Capon or minimum
variance distortionless response (MVDR) beamforming to transform the raw RDI
from multiple receive channels into raw range-angle images (RAI), which is then
followed by CFAR and clustering algorithm to generate the processed RAI detec-
tions. However, in the conventional pipeline, particularly in case of indoor human
target detection, both domains suffer from ghost targets and multipath reflections
from static objects such as walls, furniture, etc. Further, conventional parametric
clustering algorithms lead to single target splits, and adjacent target merges in the
target range-Doppler and range-angle detections. To overcome such issues, we pro-
pose a deep learning-based architecture based on the deep residual U-net model and
deep complex U-net model to generate human target detections directly from the
raw RDI. We demonstrate that the proposed deep residual U-net and complex U-net
models are capable of generating accurate target detections in the range-Doppler
and the range-angle domain, respectively. To train these networks, we record RDIs
from a variety of indoor scenes with different configurations and multiple humans
performing several regular activities. We devise a custom loss function and apply
augmentation strategies to generalize this model during real-time inference. We
demonstrate that the proposed networks can efficiently learn to detect and localize
human targets correctly under different indoor environments in scenarios where the
conventional signal processing pipeline fails.

Keywords mm-wave radar sensor · Deep residual U-Net · Deep complex U-Net ·
Receiver operating characteristics · Detection strategy · DBSCAN · People
sensing · Localization · Human target detection · Occupancy detection

M. Stephan · A. Santra
Infineon Technologies AG, Neubiberg, Germany
e-mail: avik.santra@infineon.com

M. Stephan (B) · G. Fischer (B)
Friedrich-Alexander-University Erlangen-Nuremberg, Erlangen, Germany
e-mail: michael.stephan@fau.de

G. Fischer
e-mail: georg.fischer@fau.de

1 Introduction

The ever-increasing energy consumption has given rise to a search for energy-efficient
smart home technologies that can monitor and save energy, thus enhancing sustain-
ability and reducing the carbon footprint. Depending on baseline and operation,
several studies show that energy consumption can be significantly reduced in resi-
dential, commercial, or public spaces by 25–75% [1] by monitoring occupancy or
counting the number of people and accordingly regulating artificial light and HVAC
systems [2]. Frequency modulated continuous wave (FMCW) radars can provide a
ubiquitous solution to sense, monitor, and thus control the household appliances’
energy consumption.
Radar has evolved from automotive applications such as driver assistance systems,
safety and driver alert systems, and autonomous driving systems to low-cost solu-
tions, penetrating industrial and consumer market segments. Radar has been used
for perimeter intrusion detection systems [3], gesture recognition [4–6], human–
machine interfaces [7], outdoor positioning and localization [8], and indoor people
counting [9].
Radar responses from human targets are, in general, spread across Doppler due
to the macro-Doppler component of the torso and associated micro-Doppler com-
ponents due to hand, shoulder, and leg movements. With the use of higher sweep
bandwidths, the radar echoes from targets are not received as point targets but are
spread across range and are referred to as range-extended targets. Thus, human tar-
gets are perceived as doubly spread targets [10] across range and Doppler. Thereby,
human target detection using radars with higher sweep bandwidths requires several
adaptations in the standard signal processing pipeline before feeding the data into
application-specific processing, e.g., target counting or target tracking. Depending
on the application and the system used, there are two signal processing pipelines. In the
first, the processing is performed in the range-Doppler domain, and the objective
is to perform target detection and localization in the same domain.

In such a case, the radar RDI processing pipeline involves MTI to remove static
targets and MRC to integrate data across antennas. We refer to the RDI after this
stage as raw RDI. The raw RDI is then fed into a constant false alarm rate (2D-CFAR)
detection algorithm, which then detects whether a cell under test (CUT) is a valid
target or not by evaluating the statistics of the reference cells. The constant false alarm
rate detector adaptively estimates the noise floor from the statistics around the CUT and
guarantees a fixed false alarm rate; this acceptable false alarm rate sets the threshold
multiplier with which the estimated noise floor is scaled.
together using a clustering algorithm such that reflections from the same human
target are detected as a single cluster and likewise for different humans. Alternately,
in the case of multiple virtual channels, the radar processing pipeline involves MTI to
remove static targets, followed by receiver angle of arrival estimation, through Capon
or minimum variance distortionless response (MVDR) beamforming to transform
the raw RDI data to range-angle image (RAI). Generating the RAI provides a means
to map the reflective distribution of the target from multi-frequency, multi-aspect
data onto the angular beams enabling to localize the target in the spatial domain.
Following RAI, the detection of targets is achieved through 2D-CFAR on the range-
angle domain, and eventually, targets are clustered using a clustering algorithm. The
former processing pipeline is typically applied in the case of single receive channel
sensor systems, where angular information is missing or in case of applications such
as activity or fall classification where relevant information lies in the range-Doppler
domain. Other applications, which require localizing the target in 2D space require
beamforming and detection in the range-angle domain.
However, human target detection in indoor environments poses further challenges,
such as ghost targets from static objects like walls, chairs, furniture, etc., and also
spurious radar responses due to multi-path reflections from multiple targets [11].
Further, strong reflecting or closer human targets often occlude less reflecting or
farther human targets at the CFAR detector output. While the earlier phenomenon
leads to overestimating the count of humans, the latter phenomenon leads to underes-
timating the count of humans in the room. This results in increased false alarms and
low detection probabilities, leading to poor radar receiver operating characteristics
and inaccurate application-specific decisions based on human target counts or target
tracking.
With the advent of deep learning, a variety of computer vision tasks, namely
face recognition [12, 13], object detection [14, 15], and segmentation [16, 17], have
seen substantial advances in state-of-the-art performance. The book in [18] gives a good
overview of basic and advanced concepts in deep neural networks, especially in
computer vision tasks. Image segmentation is a task wherein given an image, all the
pixels belonging to a particular class are extracted and indicated in the output image.
In [17], authors have proposed the deep U-Net architecture, which concatenates
feature maps from different levels, i.e., from low-level to high-level to achieve image
segmentation. In [19], authors have proposed deep residual connections to facilitate
training speed and improve classification accuracy. Recently, deep residual U-net
like network architectures have been proposed for identifying road lines from SAR
images [20]. We have re-purposed the deep residual U-Net to process raw RDIs
and generate target detected RDIs or RAIs. The target detected RDIs and RAIs are
demonstrated to suppress reflections from ghost targets, reject spurious targets due to
multi-path reflections, avoid target occlusions and achieve accurate target clustering
in the detected RDIs and RAIs, thus resulting in reliable and accurate human target
detections.
In [21], authors have proposed a fully convolutional neural network to achieve
object detection and estimation of a vehicle target in 3D space by replacing the
conventional signal processing pipeline. In [22], authors have proposed a deep neu-
ral network to distinguish the presence or absence of a target and demonstrate its
improved performance compared to conventional signal processing. In [23], we have
proposed using a deep residual U-net architecture to process raw RDIs into processed
RDIs, achieving reliable target detections in the range-Doppler domain.
In this contribution, we use a 60-GHz frequency modulated continuous wave
(FMCW) radar sensor to demonstrate the performance of our proposed solution
in both the range-Doppler and the range-angle domain. To handle our specific challenges
and use-case, we define and train our proposed deep residual U-net model using
an appropriate loss function for processing in the output range-Doppler domain.
We also propose a complex U-Net model to process the raw RDI from two receive
antennas to construct an RAI. For the complex neural network layers, we use the
TensorFlow Keras (an open-source neural network library) implementation of the
complex layers by Dramsch and Contributors [24], as described in [25]. We demonstrate
the performance of our proposed system with real data with up to four persons in
indoor environments for both cases and compare the results to the conventional signal
processing approach. The paper is outlined as follows. Section 2 presents the sys-
tem design, with details of the hardware chipset in Sect. 2.1, conventional processing
pipeline in Sect. 2.2, challenges and contributions in Sect. 2.3. We present the models
of our proposed residual U-Net and complex U-Net architecture in Sects. 3 and 4.
We present and describe the dataset in Sect. 5.1, the loss function in Sect. 5.2 and the
design considerations for training in Sect. 5.3. We present the results and discussions
in Sect. 6 and we conclude in Sect. 7 also outlining possible future directions.

2 System Design

2.1 Radar Chipset

The prototype indoor people counting system is based on Infineon's BGT60TR24B
FMCW radar chipset, shown in Fig. 1a. The functional block diagram of the bistatic
FMCW radar with one transmit and one receive antenna, representing the scenario
with people, furniture, and walls in indoor environments, is depicted in Fig. 1b.
BGT60TR24B operates in frequencies ranging from 57 to 64 GHz wherein the chirp
duration can be configured. The chip features an external phase-locked loop that
controls the linear frequency sweep. The tune voltage output that controls the loop is
Fig. 1 a Infineon's BGT60TR24B 60-GHz radar sensor (chipset). b Representational figure of the
radar scene and functional block diagram of the FMCW radar RF signal chain, depicting one Tx
and one Rx channel

varied from 1 to 4.5 V to enable the voltage-controlled oscillator to generate highly
linear frequency chirps over its bandwidth.
The radar response from the target in the field of view is mixed with a replica
of the transmitted signal followed by low-pass filtering the resulting mixed-signal,
which is then sampled at the analog to digital converter. The digital intermediate
frequency (IF) signal contains target information such as range, Doppler, and angle,
which can be estimated by digital signal processing algorithms.
The chipset BGT60TR24 is configured with the system parameters given in
Table 1. Consecutive sawtooth chirps are transmitted within a frame and processed
by generating the RDI. Each chirp has NTS = 256 ADC samples, B = 4 GHz represents
the frequency sweep bandwidth, and Tc = 261 µs represents the chirp duration. PN = 32
and TPRT = 520 µs are the number of chirps transmitted within a frame and the chirp
repetition time, respectively. The range resolution is determined as $\delta r = \frac{c}{2B} = 3.75$ cm and
the maximum theoretical range is $R_{\max} = (NTS/2) \times \delta r = 4.8$ m; the division by 2 arises
since the BGT60TR24 has only an I (in-phase) channel. The maximum unambiguous velocity is

$$v_{\max} = \frac{c}{2 f_c T_{\mathrm{PRT}}} = 4.8 \ \text{m/s} \qquad (1)$$

and the velocity resolution is

$$\delta v = \frac{c}{2 f_c (PN/2) T_{\mathrm{PRT}}} = 0.3 \ \text{m/s} \qquad (2)$$

Table 1 Operating parameters

Parameter, symbol                          Value
Ramp start frequency, f_min                58 GHz
Ramp stop frequency, f_max                 62 GHz
Bandwidth, B                               4 GHz
Range resolution, δr                       3.75 cm
Number of samples per chirp, NTS           256
Maximum range, R_max                       4.8 m
Sampling frequency, f_s                    2 MHz
Chirp time, T_c                            261 µs
Chirp repetition time, T_PRT               520 µs
Maximum Doppler velocity, v_max            4.8 m/s
Number of chirps, PN                       32
Doppler resolution, δv                     0.3 m/s
Number of Tx antennas, N_Tx                1
Number of Rx antennas, N_Rx                3
Elevation beamwidth per radar, θ_elev      70°
Azimuth beamwidth per radar, θ_azim        70°

We enabled 1 Tx and 3 Rx antennas from the L-shape configuration to cover both
the elevation and azimuth angle calculations, although the chip contains 2 Tx and
4 Rx antennas. Both the elevation and azimuth 3 dB half-power beamwidths of the
BGT60TR24 are 70°.

2.2 Processing Pipeline

Owing to the FMCW waveform and its ambiguity function properties, accurate
range and velocity estimates can be obtained by decoupling them during the generation
of the RDIs. The IF signal from a chirp with NTS = 256 samples is received, and PN
consecutive chirps are collected and arranged in the form of a 2D matrix with
dimensions of PN × NTS. The RDI is generated in two steps. The
first step involves calculating and subtracting the mean along fast time, followed by
applying 1D window function, zero-padding, and then 1D Fast Fourier Transform
(FFT) along fast time for all the PN chirps to obtain the range transformations. The
fast time refers to the dimension of NTS, which represents the chirp time. Then in
the second step, the mean along slow time is calculated and subtracted, a 1D window
function and zero-padding are applied, followed by 1D FFT along slow time to obtain
the Doppler transformation for all range bins. The slow time refers to the dimension
along PN, which represents the inter-chirp time. The mean subtraction across the
fast time removes the Tx-Rx leakage, and the subtraction across slow time removes
the reflections from any static object in the field of view of the radar. Since short
chirp times are used, the frequency shift along fast time is mainly due to the two-way
propagation delay from the target, which is due to the distance of the target to the
sensor. The amplitude and phase shift across slow time is due to the Doppler shift of
the target.
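To make the two processing steps concrete, the following is a minimal NumPy sketch of the RDI generation described above. The function name, the Blackman window, and the omission of zero-padding are our illustrative assumptions, not the authors' implementation.

```python
import numpy as np

def compute_rdi(frame):
    """Compute a range-Doppler image from one frame of IF data.

    frame: (PN, NTS) real-valued IF samples, i.e., PN chirps of NTS samples each.
    Returns a complex (PN, NTS // 2) range-Doppler image.
    """
    pn, nts = frame.shape
    # Step 1: range processing along fast time (axis 1).
    fast = frame - frame.mean(axis=1, keepdims=True)          # removes Tx-Rx leakage
    fast = fast * np.blackman(nts)                             # 1D window function
    range_fft = np.fft.fft(fast, axis=1)[:, :nts // 2]         # keep half (I-only channel)

    # Step 2: Doppler processing along slow time (axis 0).
    slow = range_fft - range_fft.mean(axis=0, keepdims=True)   # removes static reflections
    slow = slow * np.blackman(pn)[:, None]
    rdi = np.fft.fftshift(np.fft.fft(slow, axis=0), axes=0)    # zero Doppler in the center
    return rdi
```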
Based on the application and radar system, the processing pipeline can be either
of the following:
• raw absolute RDI −→ 2D CFAR −→ DBSCAN −→ Processed RDI
• raw complex RDI (multiple channels) −→ MVDR −→ raw RAI −→ 2D CFAR
−→ DBSCAN −→ Processed RAI
In the first pipeline, the target detection and clustering operation is performed on
the range-Doppler domain. The detection and clustering is applied on the absolute
raw RDI, and the output is the processed RDI. This is typically applied in the case of
1 Tx–1 Rx radar sensors or when the application is to extract target Doppler signatures
for classification, such as human activity classification, vital sensing, etc. While in
the alternate pipeline, the raw complex RDI from multiple channels is processed
through the MVDR algorithm to generate the RAI first, followed by 2D CFAR and
clustering operation in that domain. The latter pipeline is employed where target
detection through localization in 2D space is necessary for the application.
In the former pipeline, in case of multiple receive channels the raw RDIs are
optionally combined through maximal ratio combining (MRC) to gain diversity and
improve signal quality. The objective of MRC is to construct single RDIs by weighted
combinations of the RDIs across multiple channels. The gains for the weighted
averaging is determined by estimating the signal-to-noise ratio (SNR) for each RDI
across antennas. The effective RDI is computed as
$$I = \frac{\sum_{rx=1}^{N_{\mathrm{Rx}}} g^{rx} \, |I^{rx}|}{\sum_{rx=1}^{N_{\mathrm{Rx}}} g^{rx}} \qquad (3)$$

where $I^{rx}$ is the complex RDI of the $rx$-th receive channel, and the gain is adaptively
calculated as

$$g^{rx} = \frac{NTS \times PN \cdot \max\{|I^{rx}|^2\}}{\sum_{l=1}^{NTS} \sum_{m=1}^{PN} |I^{rx}(m,l)|^2 - \max\{|I^{rx}|^2\}} \qquad (4)$$

where max{·} represents the maximum value of the 2D function, and $g^{rx}$ represents
the estimated SNR at the $rx$-th receive channel. Thus the (PN × NTS × N_Rx) RDI tensor
is transformed into a (PN × NTS) RDI matrix. Alternately, in the latter pipeline,
instead of MRC the raw complex RDI is converted to an RAI using Capon or MVDR
algorithm.
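A possible NumPy sketch of the MRC combination in Eqs. (3) and (4) is given below; the variable names and the antenna-first stacking convention are our assumptions.

```python
import numpy as np

def mrc_combine(rdi_stack):
    """Combine per-antenna RDIs by maximal ratio combining.

    rdi_stack: complex array of shape (N_Rx, PN, NTS), one raw RDI per antenna.
    Returns a single (PN, NTS) magnitude RDI, Eq. (3).
    """
    power = np.abs(rdi_stack) ** 2
    peak = power.reshape(power.shape[0], -1).max(axis=1)      # max{|I^rx|^2}
    total = power.reshape(power.shape[0], -1).sum(axis=1)
    n_cells = power.shape[1] * power.shape[2]                  # NTS * PN
    gains = n_cells * peak / (total - peak)                    # Eq. (4), estimated SNR per antenna
    combined = np.tensordot(gains, np.abs(rdi_stack), axes=(0, 0)) / gains.sum()
    return combined
```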

MVDR: Denoting the 3D positional coordinates of the Tx element as $d^{Tx}$ and the
Rx elements as $d_n^{Rx}$, n = 1, 2 in space, and assuming far-field conditions, the
signal propagation path from the Tx element $d^{Tx}$ to a point scatterer p and subsequently the
reflection from p to Rx element $d_n^{Rx}$ can be approximated as $2x + d \sin(\theta)$, where x
is the base distance of the scatterer to the center of the virtual linear array, d refers to the
distance between receive elements, and θ is the incident angle to the center of the
array with respect to bore-sight.
Assuming that far-field conditions are satisfied, the time delay of the radar return
from a scatterer at base distance x from the center of the virtual linear array can be
expressed as
$$\tau_n = \frac{2x}{c} + \frac{d \sin(\theta)}{c} \qquad (5)$$
The receiving steering vector is

$$a_n^{Rx}(\theta) = \exp\!\left(-j 2\pi \frac{d_n^{Rx} \sin(\theta)}{\lambda}\right); \quad n = 1, 2 \qquad (6)$$
where λ is the wavelength of the transmit signal. The two received de-ramped beat
signal are used to resolve the relative angle θ of the scatterer.
The azimuth imaging profile for each range bin can be generated using the Capon
spectrum from the beamformer. The Capon beamformer is computed by minimizing
the variance/power of noise while maintaining a distortionless response toward a
desired angle. The corresponding quadratic optimization problem is

$$\min_{w} \; w^H C w \quad \text{s.t.} \quad w^H a^{Rx}(\theta) = 1 \qquad (7)$$

where C is the noise covariance matrix. The above optimization has a closed-form
solution given as $w_{\mathrm{capon}} = \frac{C^{-1} a(\theta)}{a^H(\theta) C^{-1} a(\theta)}$, with θ being the desired angle. On
substituting $w_{\mathrm{capon}}$ into the objective function of (7), the spatial spectrum is given as


$$P_l(\theta) = \frac{1}{a^{Rx}(\theta)^H C_l^{-1} a^{Rx}(\theta)}, \quad l = 0, \ldots, L \qquad (8)$$

where L is the number of range bins.


However, estimation of the noise covariance at each range bin l is difficult in practice;
hence $\hat{C}_l$, which contains the signal component as well, is estimated using the sample
matrix inversion (SMI) technique, $\hat{C}_l = \frac{1}{K} \sum_{k=1}^{K} s_l^{IF}(k) \, s_l^{IF}(k)^H$,
where K denotes the number of chirps in a frame used for the signal-plus-noise covariance
estimation and $s_l^{IF}(k)$ is the de-ramped intermediate frequency signal at range
bin l [26].
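The Capon/MVDR range-angle image of Eqs. (5)–(8) can be sketched as follows. The angle grid, the half-wavelength element spacing, and the diagonal loading of the sample covariance are illustrative assumptions on our part.

```python
import numpy as np

def mvdr_range_angle(rdi_stack, d=0.5, n_angles=64, diag_load=1e-3):
    """Capon/MVDR range-angle image from per-antenna complex RDIs.

    rdi_stack: complex array of shape (N_Rx, PN, NTS); Doppler bins act as snapshots.
    d: receive element spacing in wavelengths (lambda/2 assumed here).
    Returns an (NTS, n_angles) range-angle map, Eq. (8).
    """
    n_rx, _, n_range = rdi_stack.shape
    angles = np.deg2rad(np.linspace(-50, 50, n_angles))
    # Steering vectors, Eq. (6): a_n(theta) = exp(-j 2*pi d_n sin(theta) / lambda)
    steer = np.exp(-1j * 2 * np.pi * d * np.arange(n_rx)[:, None] * np.sin(angles)[None, :])

    rai = np.zeros((n_range, n_angles))
    for l in range(n_range):
        snapshots = rdi_stack[:, :, l]                                 # (N_Rx, PN)
        C = snapshots @ snapshots.conj().T / snapshots.shape[1]        # sample covariance
        C += diag_load * np.trace(C).real / n_rx * np.eye(n_rx)        # diagonal loading
        C_inv = np.linalg.inv(C)
        # Capon spectrum: P(theta) = 1 / (a^H C^-1 a)
        rai[l] = 1.0 / np.real(np.einsum('ij,jk,ki->i', steer.conj().T, C_inv, steer))
    return rai
```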

2D CFAR: Following the earlier operation in each pipeline, the raw RDI in the
former case and raw RAI in the latter case is fed into a CFAR detection algorithm to
generate a hit matrix that indicates the detected targets. The target detection problem
can be expressed as

$$I_{\mathrm{det}}(\mathrm{cut}) = \begin{cases} 1 & \text{if } |I(\mathrm{cut})|^2 > \mu\sigma^2 \\ 0 & \text{if } |I(\mathrm{cut})|^2 < \mu\sigma^2 \end{cases} \qquad (9)$$

where $I_{\mathrm{det}}(\mathrm{cut})$ is the detection output for the cell under test (CUT), depending on
the input image I. The product of μ and $\sigma^2$, the estimated noise variance, represents the
threshold for detection. The threshold multiplier μ for the CFAR detection is set by
an acceptable probability of false alarm, which for cell-averaging CFAR (CA-CFAR)
is provided as

$$\mu = N \left( P_{fa}^{-1/N} - 1 \right) \qquad (10)$$

where $P_{fa}$ is the probability of false alarm, and N is the window size of the so-called
"reference" cells used for noise power estimation. The noise variance $\sigma^2$ is estimated
from the "reference" cells around the CUT. Thus, the estimated noise variance takes
into account the local interference-plus-noise power around the CUT.
While CA-CFAR is the most common detector used in case of point targets, in
case of doubly spread targets, such as humans with wide radar bandwidth, it leads to
poor detections since the spread targets are also present in the reference cells, lead-
ing to high noise power estimations and missed detections. Additionally, CA-CFAR
elevates the noise threshold near a strong target, thus occluding nearby weaker tar-
gets. Alternately for such doubly spread targets, order-statistics CFAR (OS-CFAR)
is used to avoid such issues since the ordered statistic is robust to any outliers, in this
case, target’s spread, in the reference cells. Hence, instead of the mean power in the
reference cells, the kth ordered data is selected as the estimated noise variance σ 2 .
A detailed description of OS-CFAR can be found in [27].
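A simple (unoptimized) 2D OS-CFAR sketch is shown below. The guard and training window sizes, the ordered-statistic rank, and the threshold multiplier alpha are illustrative choices rather than the values used in this chapter.

```python
import numpy as np

def os_cfar_2d(image, guard=2, train=8, k_frac=0.75, alpha=5.0):
    """2D ordered-statistics CFAR on a magnitude image.

    For every cell under test, the noise level is taken as the k-th ordered
    statistic of the reference cells surrounding the guard region, and the
    detection test of Eq. (9) is applied with threshold multiplier alpha.
    """
    power = np.abs(image) ** 2
    hits = np.zeros_like(power, dtype=bool)
    half = guard + train
    for r in range(half, power.shape[0] - half):
        for c in range(half, power.shape[1] - half):
            window = power[r - half:r + half + 1, c - half:c + half + 1].copy()
            window[train:-train, train:-train] = np.nan      # exclude guard cells and the CUT
            ref = window[~np.isnan(window)]
            noise = np.sort(ref)[int(k_frac * ref.size)]     # k-th ordered statistic
            hits[r, c] = power[r, c] > alpha * noise
    return hits
```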

DBSCAN: Contrary to a point target, in the case of doubly extended targets the
output of the detection algorithm is not a single detection in the RDI for a target but is
spread across range and Doppler. Thus, a clustering algorithm is required to group
the detections from a single target, based on its size, into a single cluster. For this, the
density-based spatial clustering of applications with noise (DBSCAN) algorithm is
used, one of the most widely used unsupervised clustering algorithms in the machine
learning community. Given a set of target detections from one or multiple targets in the
RAI, DBSCAN groups detections that are closely packed together, while at the same
time removing as outliers detections that lie alone in low-density regions. To do this,
DBSCAN classifies each point as either a core point, edge point, or noise. Two input
parameters are needed for the DBSCAN clustering algorithm, the neighborhood
radius d, and the minimum number of neighbors M. A point is defined as a core
point if it has at least M − 1 neighbors, i.e., points within the distance d. An edge
point has less than M − 1 neighbors, but at least one of its neighbors is a core point.
All points that have less than M − 1 neighbors and no core point as a neighbor do
not belong to any cluster and will be classified as noise [28].
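Using scikit-learn, the clustering of CFAR detections can be sketched as below; the eps and min_samples values are illustrative and would need tuning per scenario, which is exactly the difficulty discussed in the next subsection.

```python
import numpy as np
from sklearn.cluster import DBSCAN

def cluster_detections(hit_matrix, eps=3.0, min_samples=4):
    """Group CFAR detections into targets with DBSCAN.

    hit_matrix: boolean (range_bins, doppler_bins) detection map.
    Returns a list of (range_bin, doppler_bin) cluster centers; noise points
    (label -1) are discarded.
    """
    points = np.argwhere(hit_matrix)
    if len(points) == 0:
        return []
    labels = DBSCAN(eps=eps, min_samples=min_samples).fit_predict(points)
    return [points[labels == lab].mean(axis=0) for lab in set(labels) if lab != -1]
```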

2.3 Challenges and Contributions

Indoor environments typically pose a multitude of challenges to human presence
detection, localization, and people counting systems using radar sensors. The pri-
mary error sources are multi-path reflections, ghost targets, target occlusion, merging
targets, and split targets. Multipath reflections and ghost targets occur due to reflec-
tions from static objects like walls, chairs, or tables to the human target and back to the
radar sensor. These ghost targets and spurious reflections appear as false alarms at the
output of the CFAR detection algorithm. Figure 2a presents the indoor environment
setup wherein a human walks close to the wall, while Fig. 2b depicts the target detected
RDI using a conventional processing pipeline, where the true target is marked in
green and ghost targets and multi-path artifacts are marked in red. Target occlusion
effects may happen if another object is in the line of sight of the target to the radar
sensor, if two targets are close together, or if the reflections from the target are weak-
ened through other reasons, such as translation range migration, rotational range
migration, or speckle [29]. Adaptive CFAR algorithms fail to detect the true target
in such scenarios. The problem of merging targets and split targets partly originates
from the parametric clustering algorithms. With DBSCAN clustering specifically,
clusters may merge if the neighborhood radius is set too high. However, when set
too low, arms, legs, or the head of a human target may be recognized as separate tar-
gets. Based on the indoor environment and activity of the human, the radar response
from a target will have varying points on the range-Doppler domain. Thus setting a
neighborhood radius that works for all the scenarios is very difficult if not impossible.
The traditional processing pipeline with ordered statistics CFAR (OS-CFAR) and
DBSCAN is depicted in Fig. 3a, whereas the proposed deep residual U-net architec-
ture to process the target detected RDI is presented in Fig. 3b. Inspired by the deep
residual U-Net for image segmentation problems, we in this contribution propose to
use the deep residual U-Net architecture to generate detection RDIs, while addition-
ally removing ghost targets, multi-path reflections, preventing target occlusion, and
achieving accurate target clustering. After the target detected RDIs are computed,
the number of targets and their parameters, i.e., range, velocity, angle, are estimated.
The traditional processing pipeline, which takes raw RDIs from multiple receive
antennas as input and processes the raw RDIs through MVDR, OS-CFAR, and
DBSCAN is depicted in Fig. 4a. On the contrary, in the proposed deep complex
U-net model, the raw RDIs from different channels are processed directly by the
neural network to reconstruct the target detected RAI, which is presented in Fig. 4b.
The objective of the proposed deep complex U-Net model is to reconstruct detec-
tion RAIs, whereby the ghost targets are removed, multi-path reflections removed,
preventing target occlusion, and achieving accurate target clustering.
After the target detected RDIs or RAIs are reconstructed, the number of targets
and their parameters, i.e., range, velocity, angle, are estimated. The list of targets
with their parameters is further fed to the application-specific processing or people
tracking algorithm.

Fig. 2 a Indoor room environment with a human walking around in the room. b Corresponding
processed RDI (range in m vs. velocity in m/s) as sensed by the radar with OS-CFAR and DBSCAN.
The processed RDI depicts the true target (in green) and a ghost target (in red) due to reflections
from walls, etc.

3 Network Architecture—Deep Residual U-Net

Figure 5 illustrates the network architecture. It has an encoder and a decoder path,
each with three resolution steps. In the encoder path, each block contains two 3 × 3
convolutions, each followed by a rectified linear unit (ReLu), and a 3 × 3 max pooling
with strides of two in every dimension. In the decoder path, each block consists of
an upconvolution of 2 × 2 with strides of two in each dimension, followed by a
ReLu and two 3 × 3 convolutions, each followed by a ReLu. The up-convolutions
are implemented as an upsampling layer followed by a convolutional layer. Skip
connections from layers of equal resolution in the encoder path provide the high-
resolution features to the decoder path, like the standard U-net. The bridge connection
between the encoder and the decoder network consists of two convolutional layers,
each followed by a ReLu, with a Dropout Layer in between. The dropout is set
to 0.5 during training to reduce overfitting. In the last layer, a 1 × 1 convolution

Fig. 3 a Traditional processing pipeline with OS-CFAR detection and DBSCAN clustering to
generate target detected RDIs (range in m vs. velocity in m/s). b Processing pipeline using the
proposed deep residual U-Net to suppress ghost targets and multi-path reflections, mitigate target
occlusions, and achieve accurate clustering

Fig. 4 a Traditional processing pipeline with MVDR, OS-CFAR, and DBSCAN to generate target
detected RAIs from input RDIs across channels. b Processing pipeline using the proposed deep
complex U-Net to suppress ghost targets and multi-path reflections, mitigate target occlusions, and
achieve accurate clustering

Fig. 5 Proposed RDI presence detection architecture for a depth 3 network. Each box corresponds
to one or more layers (legend: input image, Conv + ReLU + BN, MaxPool, Dropout, Concat,
Concat + UpConv + ReLU, Concat + Conv + Softmax, output image)

Fig. 6 Proposed RAI presence detection and localization architecture for a depth 4 network. Each
box corresponds to one or more layers (legend: input image, complex Conv + ReLU + BN, strided
complex Conv + ReLU, UpConv + ReLU, Conv + ReLU, Concat, output image)

reduces the number of output channels to the number of classes, which is 2, target
present/absent, in our case. The architecture has 40752 trainable parameters in total.
As suggested in the literature, bottlenecks are prevented by doubling the number of
channels before max pooling [30]. We also adopt the same scheme in the decoder
path to avoid bottlenecks. The input to the network is a 128 × 32 × 1 raw RDI. Our
output is a 128 × 32 × 1 image with pixel values between 0 and 1, representing the
probability of target presence/absence in each pixel.
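For illustration, a Keras sketch of an encoder-decoder of the kind described in this section is given below. The filter counts, the exact placement of batch normalization, the shape of the residual shortcut, and the single-channel sigmoid output are our assumptions; this is not the authors' implementation (which has 40,752 trainable parameters).

```python
import tensorflow as tf
from tensorflow.keras import layers, Model

def residual_block(x, filters):
    # Two 3x3 Conv + ReLU + BN layers with a 1x1 projection shortcut.
    shortcut = layers.Conv2D(filters, 1, padding='same')(x)
    y = layers.Conv2D(filters, 3, padding='same', activation='relu')(x)
    y = layers.BatchNormalization()(y)
    y = layers.Conv2D(filters, 3, padding='same', activation='relu')(y)
    y = layers.BatchNormalization()(y)
    return layers.Add()([shortcut, y])

def build_rdi_unet(input_shape=(128, 32, 1), base_filters=8, depth=3):
    inputs = layers.Input(shape=input_shape)
    x, skips = inputs, []
    for d in range(depth):                                    # encoder
        x = residual_block(x, base_filters * 2 ** d)
        skips.append(x)
        x = layers.MaxPooling2D(2)(x)
    x = residual_block(x, base_filters * 2 ** depth)          # bridge
    x = layers.Dropout(0.5)(x)                                # dropout 0.5 during training
    for d in reversed(range(depth)):                          # decoder
        x = layers.UpSampling2D(2)(x)                         # up-convolution: upsample + conv
        x = layers.Conv2D(base_filters * 2 ** d, 2, padding='same', activation='relu')(x)
        x = layers.Concatenate()([x, skips[d]])               # skip connection
        x = residual_block(x, base_filters * 2 ** d)
    outputs = layers.Conv2D(1, 1, activation='sigmoid')(x)    # per-pixel target probability
    return Model(inputs, outputs)
```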

4 Network Architecture—Deep Complex U-Net

Figure 6 shows the network architecture used for the localization task. While it looks
quite similar to Fig. 5, it differs significantly in the details. It is still a U-Net like archi-
tecture, but it uses fully complex operations and 3D-convolutions. A brief description
of complex convolutional layer and complex activation layer is provided below:

Fig. 7 Illustration of complex convolution in a 2D CNN from one layer to the next

1. Complex Convolutional Layer:
Figure 7 illustrates the 2D CNN operation using complex kernels and the input image
map at any layer i. The complex convolutional layer generates feature maps from
the range-Doppler dimensions of both receive channels. The kth feature map in the
ith layer can be expressed as

$$\hat{A}_{i,k} + j\hat{B}_{i,k} = (A_{i-1,k} * C_{i,k} - B_{i-1,k} * D_{i,k}) + j\,(A_{i-1,k} * D_{i,k} + B_{i-1,k} * C_{i,k}) \qquad (11)$$

where $A_{i-1,k}$, $B_{i-1,k}$ are the real and imaginary parts of the kth feature map of the
previous layer, and $\hat{A}_{i,k}$, $\hat{B}_{i,k}$ are the real and imaginary parts after the convolution.
$C_{i,k}$ and $D_{i,k}$ are the real and imaginary components of the kth kernel, respectively. The
filter dimensions are $Q_i \times Q_i \times K_i$, where $Q_i$ is applied on the range-Doppler
dimensions and $K_i$ along the spatial channels.
2. Complex Activation Function:
The complex 2D layers progressively extract deeper feature representation. The
activation function introduces non-linearity into the representation. The complex
ReLU is implemented as

$$A_{i,k} + jB_{i,k} = \mathrm{ReLU}(\hat{A}_{i,k}) + j\,\mathrm{ReLU}(\hat{B}_{i,k}) \qquad (12)$$

In effect, complex RELU maps the second quadrant data to the positive half of
the imaginary axis, the third quadrant data to the origin, and fourth quadrant data
to the positive half of the real axis.
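The complex convolution of Eq. (11) and the complex ReLU of Eq. (12) can be emulated with two real-valued convolutions, as sketched below. This is only an illustration of the operations; the chapter itself relies on the keras-complex implementation [24, 25].

```python
import tensorflow as tf
from tensorflow.keras import layers

class ComplexConv2D(layers.Layer):
    """Complex 2D convolution, Eq. (11), built from two real-valued convolutions."""

    def __init__(self, filters, kernel_size, **kwargs):
        super().__init__(**kwargs)
        self.conv_re = layers.Conv2D(filters, kernel_size, padding='same')  # kernel C_{i,k}
        self.conv_im = layers.Conv2D(filters, kernel_size, padding='same')  # kernel D_{i,k}

    def call(self, inputs):
        a, b = inputs                                  # real and imaginary feature maps
        real = self.conv_re(a) - self.conv_im(b)       # A*C - B*D
        imag = self.conv_im(a) + self.conv_re(b)       # A*D + B*C
        return real, imag

def complex_relu(a, b):
    """Complex ReLU, Eq. (12): ReLU applied separately to real and imaginary parts."""
    return tf.nn.relu(a), tf.nn.relu(b)

# Example usage on real/imaginary RDI tensors of shape (batch, 128, 32, 1):
# re, im = ComplexConv2D(8, 3)((re, im)); re, im = complex_relu(re, im)
```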
In the proposed architecture, the complex max pooling is avoided since it leads
to model instabilities, and thus progressive range-Doppler dimension reduction was
achieved through strided convolutions. In the encoder path, each block contains two
3 × 3 × 1 complex convolutional layers, each followed by a complex ReLu, and a
3 × 3 × 1 strided convolutional layer with a 2 × 2 × 1 stride. In the decoder path, the
up-convolutional layers are implemented as a 2 × 2 × 1 upsampling layer followed
by a 2 × 2 × 1 convolutional layer. Between each block of the same depth in the
encoder path and the decoder path are skip connections to provide the feature location
information to the decoder. Batch normalization is done after every convolutional
layer in the encoder and the decoder. In the encoder path, the size of the channel
dimension is doubled with every block, while the sizes of the image dimensions are
halved. The opposite is true for the decoder path, where the number of channels
is halved, while the image is upsampled by a factor of two in every block. Due to
the size of the 3D convolution filters, the network only combines the information
provided by the individual antennas in the last complex convolutional layer, where a
filter size of 1 × 1 × 2 is used without padding the input. The last convolutional layer,
drawn in purple in Fig. 6, is a normal convolutional layer that combines the real
and imaginary outputs of the previous layers into a single-channel output. The
network, with four blocks each in the encoder and the decoder, has 68,181 trainable
parameters. The inputs to the neural network are two 128 × 32 complex RDIs from
two receiving antennas. The output is one 128 × 32 RAI with pixel values between
0 and 1, representing the probability of target presence/absence in each range-angle
pixel.

5 Implementation Details

5.1 Dataset

To create the labeled dataset, the raw RDIs are processed with the respective tradi-
tional processing pipelines, with MVDR beamforming to create the labeled RAIs
and MRC instead to create the labeled RDIs. During recording, a camera provided
reference feedback, with which we removed ghost targets and multi-path reflections
and added detections whenever targets were occluded by other humans or by static
humans close to the wall. This was done by changing the parameters
for the detection and clustering algorithms in the conventional pipeline, so that the
probability of detection approaches 100%, also resulting in a high probability of false
alarm. This means decreasing the CFAR scaling factor in case of target occlusion,
and reducing/increasing the maximum neighbor distance for cluster detection with
DBSCAN in case of merged/separated targets. All falsely detected clusters are then
manually removed using the camera data as a reference. The described process is
relatively simple for one-target measurements, as the correct cluster is generally the
one closest in range to the radar sensor. The dataset comprises recordings with one
to four humans in the room.
The RDIs are augmented to increase the dataset and achieve a broader gener-
alization of the model. Due to the sharply increasing difficulty in creating labeled
data with multiple humans present, we synthetically computed RDIs with multiple
targets by superimposing several raw one-target RDIs after translation and rotation.
With this technique, a large number of RDIs, limited only by the number of possible
combinations of the one-target measurements, can be synthesized. Some caution is
required in having a large enough basis of one-target measurements, as the network
may otherwise overfit on the number of possible target positions. To increase the
pool of one-target measurements for the RAI presence detection and localization
task, the one-target measurements were augmented by multiplying the RDIs from
both antennas by the same complex values with an amplitude close to one and a
random phase. This operation changes the input values but should not change the
estimated ranges and angles, aside from a negligible range error on the order of half
the wavelength.
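The phase-based augmentation and the superposition of one-target recordings can be sketched as follows; the amplitude jitter range and the random choice of targets are illustrative assumptions, and the translation/rotation step is omitted.

```python
import numpy as np

rng = np.random.default_rng()

def augment_phase(rdi_pair, amp_jitter=0.05):
    """Multiply the RDIs of both antennas by the same complex factor.

    The factor has an amplitude close to one and a uniformly random phase,
    so target ranges and relative angles are preserved.
    """
    factor = (1.0 + rng.uniform(-amp_jitter, amp_jitter)) * np.exp(1j * rng.uniform(0, 2 * np.pi))
    return [rdi * factor for rdi in rdi_pair]

def synthesize_multi_target(one_target_rdis, max_targets=4):
    """Superimpose randomly chosen one-target complex RDIs into one multi-target RDI."""
    idx = rng.choice(len(one_target_rdis), size=rng.integers(2, max_targets + 1), replace=False)
    return sum(one_target_rdis[i] for i in idx)
```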

5.2 Loss Function

Given a set of training images, i.e., raw RDIs $I_i$ and the corresponding ground-truth
processed RDIs or RAIs $P_i$, the training procedure estimates the parameters W of the
network such that the model generalizes to reconstruct accurate RDIs or RAIs. This
is achieved by minimizing the loss between the true RDIs/RAIs $P_i$ and those
generated by $h(I_i; W)$. Here, we use a weighted combination of focal loss [31] and
hinge loss, as given in Eq. (13), as the loss function to train the proposed models

$$HL(p) = 1 - y(2p - 1)$$
$$FL(p_t) = -(1 - p_t)^{\gamma} \log(p_t)$$
$$L(p_t) = \alpha \left( FL(p_t) + \eta\, HL(p) \right) \qquad (13)$$

The variables y ∈ {±1}, p ∈ [0, 1], and pt ∈ [0, 1], specify the class label, the
estimated probability for the class with label y = 1, and the probability that a pixel
was correctly classified, as defined in Eq. (14) as

$$p_t = \begin{cases} p & \text{if } y = 1 \\ 1 - p & \text{otherwise} \end{cases} \qquad (14)$$

The parameters γ, η, and α influence the shape of the focal loss, the weight of the
hinge loss, and the class weighting, respectively. For training the deep residual U-Net
model for reconstructing the processed range-Doppler image, these parameters were
chosen as γ = 2 and α = 0.25, and η was step-wise increased to η = 0.15. For
training the deep complex U-Net model for reconstructing the processed range-angle
image, the same parameters were used, except that η = 0. The chosen parameters led
to the best learned model in terms of F1 score for both models.
The reason for using the focal loss is the class imbalance in the training data due
to the nature of the target setup in picture Fig. 2a. In most cases, the frequency of
pixels labeled as “target absent” is much higher than the frequency of those labeled
with “target present”. Additionally, the class frequencies may vary widely between
single training samples, especially as the training set contains RDIs with different
numbers of targets. The focal loss places a higher weight on the cross-entropy loss for
misclassified pixels and a much lower weight on well-classified ones. Thus, pixels
belonging to the rarer class are generally weighted higher. The hinge loss is added
to the focal loss with a factor η to force the network to make clearer classification
decisions. The value for η is chosen in such a way that the focal loss dominates the
training for the first few epochs before the hinge loss becomes relevant.
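A TensorFlow sketch of the combined loss of Eqs. (13) and (14) is given below; the clipping of the predicted probabilities is added for numerical stability and is our assumption, not part of the original formulation.

```python
import tensorflow as tf

def focal_hinge_loss(gamma=2.0, alpha=0.25, eta=0.15):
    """Weighted focal + hinge loss, Eq. (13); a sketch, not the authors' exact code."""
    def loss(y_true, y_pred):
        # y_true: per-pixel labels in {0, 1}; y_pred: per-pixel probabilities in [0, 1].
        y_pred = tf.clip_by_value(y_pred, 1e-7, 1.0 - 1e-7)
        p_t = tf.where(tf.equal(y_true, 1.0), y_pred, 1.0 - y_pred)     # Eq. (14)
        focal = -tf.pow(1.0 - p_t, gamma) * tf.math.log(p_t)            # FL(p_t)
        y_signed = 2.0 * y_true - 1.0                                    # map {0, 1} to {-1, +1}
        hinge = 1.0 - y_signed * (2.0 * y_pred - 1.0)                    # HL(p), non-negative here
        return tf.reduce_mean(alpha * (focal + eta * hinge))
    return loss
```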

5.3 Design Consideration

The weight initialization of the network is performed with a Xavier uniform initial-
izer, which draws samples from a uniform distribution within [−limit, limit], where
the limit is calculated by taking the square root of six divided by the total number of
input and output units in the weight tensor. The respective biases were initialized as
zeros. For the backpropagation, the Adam [32] optimizer was used with the default
learning rate (alpha) of 0.001; the exponential decay rate for the first moment (beta1)
was set to 0.9 and that for the second moment (beta2) to 0.999. The epsilon that counters
divide-by-zero problems is set to $10^{-7}$.
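In Keras terms, the stated initialization and optimizer settings correspond roughly to the configuration below; this is an illustrative sketch (GlorotUniform is Keras' name for the Xavier uniform initializer), not the authors' code.

```python
import tensorflow as tf
from tensorflow.keras import layers

# Xavier uniform weights and zero biases, as would be passed to every convolutional layer.
conv = layers.Conv2D(
    filters=16, kernel_size=3, padding='same',
    kernel_initializer=tf.keras.initializers.GlorotUniform(),
    bias_initializer='zeros')

# Adam with the stated hyper-parameters.
optimizer = tf.keras.optimizers.Adam(
    learning_rate=0.001, beta_1=0.9, beta_2=0.999, epsilon=1e-7)
```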

6 Results and Discussions

The most common metric for evaluating the detection-clustering performance is to
evaluate the radar receiver operating characteristics (ROC) and the corresponding
area under curve (AUC). In the case of target detection problems such as the one
presented, classification accuracy can be a misleading metric since it doesn’t fully
capture the model performance in this case. Precision helps when the costs of false
positives, i.e., ghost targets or multi-path reflections, are high. Recall helps when the
cost of false negatives, i.e., target occlusions, are substantial. F1 score is an overall
measure of a model’s performance that combines precision and recall. A good F1
score, close to 1, means low false positives, i.e., ghost target detections and low false
negatives, i.e., target occlusions, thus indicating correct target detections without
being disturbed by false alarms. A total of 2000 test images are used for evaluating
the performance of our proposed approach.
The test set consists of one to four human target raw RDIs from two receive anten-
nas, where the data was collected from different room configurations with humans
performing several regular activities. In our experiments, we observed that using the
RDIs after MTI as inputs to the neural network (NN) allows the network to generalize
better in terms of different target scenes. Without removing the static targets from
the input RDIs, the network appears prone to overfitting on the target scene used to
create the training data. However, networks trained using static target removed RDIs,
as presented in this contribution, do not suffer from such generalization issues.
In order to evaluate the probability of detection and the probability of false alarm
for the NN-based and the traditional signal processing approach, the respective pro-
cessed RDI outputs are compared to the labeled data. Due to the difficulties in cre-
ating the labeled data, it is likely that only parts of each human are classified as
“target present” in the RDI/RAI. Therefore, when defining missed detections and
false alarms, small positional errors and variations in the cluster size should be dis-
counted. We did this by computing the center of mass for each cluster in the labeled
data and the processed RDIs. Targets are identified as correctly detected only if the
distance between the cluster center of masses of the processed and the labeled RDIs
is smaller than 20 cm in range, and 1.5 m/s in velocity. Additionally, we enforce that
each cluster in the network output can only be assigned to exactly one cluster in the
labeled data and vice-versa.
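The matching rule used for counting detections can be sketched as follows; the greedy nearest-assignment is our simplification of the one-to-one assignment described above.

```python
def match_detections(pred_centers, label_centers, max_dr=0.2, max_dv=1.5):
    """Greedy one-to-one matching of predicted and labeled cluster centers.

    Centers are (range [m], velocity [m/s]) pairs; a pair matches if it differs by
    less than max_dr in range and max_dv in velocity. Returns (TP, FP, FN) counts,
    from which precision, recall, and the F1 score follow.
    """
    unmatched = list(range(len(label_centers)))
    tp = 0
    for pr, pv in pred_centers:
        for idx in unmatched:
            lr, lv = label_centers[idx]
            if abs(pr - lr) < max_dr and abs(pv - lv) < max_dv:
                tp += 1
                unmatched.remove(idx)
                break
    fp = len(pred_centers) - tp
    fn = len(label_centers) - tp
    return tp, fp, fn
```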
Table 2 presents the performance of the proposed approach in terms of F1 score
in comparison to the traditional processing chain. Results for the proposed approach
are shown for a depth three network (NN_d3), as shown in Fig. 5, and also for a
deeper network with five residual blocks (NN_d5) in both encoder and decoder. For
the depth 5 network, two more residual blocks were added in both the encoder and
the decoder, compared to the structure displayed in Fig. 5. The input to the bridge
block between encoder and decoder path then has the dimension 4 × 1 × 125. The
proposed approach gives a much better detection performance, an F1 score of 0.89,
than the traditional processing pipeline with an F1 score of 0.71. The deeper net-
work (NN_d5), with around 690 thousand trainable parameters, shows some further
improvements in terms of detection performance with an F1 score of 0.91.
Figure 8 presents the ROC curve of the proposed U-net architecture with depth
three and five in comparison with traditional processing. The ROC curves for the
depth three and the depth five networks are obtained by varying a hard threshold param-
eter, which describes the minimum pixel value for a point to be classified as a target
in the NN output. In the ROC curve representing the traditional processing chain, the
scaling factor for the CFAR detection method was varied. The curves are extrapolated
for detection probabilities close to one. As depicted in the ROC curve, the proposed
residual U-net architecture provides a much better AUC performance compared to the
traditional processing pipeline. Similarly, Fig. 9 shows a better AUC performance
for the complex U-net compared to the traditional processing.
Figure 10a presents the raw RDI, Fig. 10b presents the detected RDI using tra-
ditional approaches, Fig. 10c presents the detected RDI using the proposed deep
residual U-Net approach for a synthetic four target measurement. Originally, the

Table 2 Comparison of the detection performance of the traditional pipeline with the proposed
U-net architecture with a depth of 3 and depth of 5 for localization in the range-Doppler image
Approach                   Description            F1-score   Model size
Traditional                OS-CFAR with DBSCAN    0.71       –
Proposed U-net, depth 3    Proposed loss          0.89       616 kB
Proposed U-net, depth 5    Proposed loss          0.91       2.8 MB

[Fig. 8 plot: ROC curves (pD vs. pFA) for NN_d5, NN_d3, and the traditional pipeline]

Fig. 8 Radar receiver operating characteristics (ROC) comparison between the proposed residual
deep U-net and the traditional signal processing approach

[Fig. 9 plot: ROC curves, probability of detection vs. probability of false alarm]

Fig. 9 Radar receiver operating characteristics (ROC) comparison between the proposed complex
deep U-net and the traditional signal processing approach. Dashed parts indicate extrapolation

NN outputs classification probabilities between zero and one for each pixel. To get
the shown output, we used a hard threshold of 0.5 on the NN output, so that any
pixels with corresponding values of 0.5 and higher are classified as belonging to a
target. With the traditional approach, one target, at around 2 m distance, and −2 m/s
in velocity, was mistakenly split into two separate clusters. Additionally, the two
targets between 2 and 3 m were completely missed.

Fig. 10 a Raw RDI image with four human targets, b processed RDI using the traditional approach
wherein one target is split and two targets are occluded, c processed RDI using proposed approach
wherein all targets are detected accurately


Fig. 11 a Raw RDI image with four human targets, b processed RDI using the traditional approach
wherein one target is occluded, c processed RDI using proposed approach wherein all targets are
detected accurately


Fig. 12 a Raw RDI image with four human targets, b processed RDI using the traditional approach
wherein two targets are merged into one by the DBSCAN clustering algorithm in addition to two
ghost targets, c processed RDI using proposed approach wherein all targets are detected accurately
and with proper clustering

Figure 11a–c presents the target occlusion problem on synthetic data for the con-
ventional processing chain. While with the traditional approach one target at around
3 m distance was missed, the proposed approach in Fig. 11c is able to reliably detect
all the targets. Figures 12a–c and 13a–c show two different scenarios that the
traditional approach struggles with.


Fig. 13 a Raw RDI image with four human targets, b processed RDI using the traditional approach
wherein a ghost target appears, c processed RDI using proposed approach wherein all targets are
detected accurately

In Fig. 12b, the two detected targets around the one-meter mark are too close
together for the DBSCAN clustering algorithm to detect them as two distinct
clusters. Therefore, these two clusters merge, and one target is missed. Figure 13b
showcases the ghost target problem. Here, one target at around 4 m in distance to the
radar sensor was wrongly detected. In both cases, the U-Net-based approach correctly
displays all the distinct targets, as seen in Figs. 12c and 13c. However, it has to be
mentioned that while our proposed approach outperforms the traditional processing
chain, missed detections, and false alarms may still occur in similar scenarios. From
our experiments we have observed, that the proposed approach excels in discarding
multi-path reflections and ghost targets caused by the reflections from static objects
in the scene, does well in preventing splitting or merging targets, but does not really
show improvements for the case of occluded targets if many humans are in front of
each other. The cause of this lies in the nature of how most of the training data was
created, where one-target measurements were superimposed over each other in order
to synthesize the multi-target RDIs.
In our experiments, we noticed that the loss function plays a crucial role in achiev-
ing excellent detection results. While the current loss function deals well with the
class imbalance problem and accelerates training for more important pixels in the
RDI, it could be improved by penalizing merged or split targets more severely,
and by allowing small positional errors in the clusters without increasing the loss.
The evaluation of the proposed method for target detection and localization in
terms of range and angle is again done via ROC-curves and the F1-scores. We use
the same set of test measurements as described earlier, but with the complex RDIs
from two receiving antennas. For the evaluation of the range-angle output maps of the
neural network, a simple clustering algorithm is used on these output images. In our
case, we use DBSCAN with small values for the minimum number of neighbors and
the maximum neighbor distance so that every nearly continuous cluster is recognized
as one target. We then compute the center of masses of each cluster and compare
them to the center of masses computed from our labeled data, as described earlier.
In this case, a target still counts as detected if the distance between the two
cluster centers of mass is below a threshold, here 7.8° in angle or 37.5 cm in range.
Table 3 shows a comparison of F1-scores for the proposed method and a traditional

Table 3 Comparison of the detection performance of the traditional pipeline with the proposed
U-net architecture with a depth of 4 for localization in the range-angle image
Approach                   Description            F1-score   Model size
Traditional                OS-CFAR with DBSCAN    0.61       –
Complex U-net, depth 3     Focal loss             0.72       529 kB
Complex U-net, depth 4     Focal loss             0.77       1.23 MB

signal processing chain. Compared to the F1-score of 0.62 for the classical method,
the proposed method shows a clear improvement with an F1-score of 0.77. The
ROC curves in Fig. 9 stop at values smaller than one due to how the detections were
evaluated. The dashed lines indicate an extrapolation. If the threshold or the CFAR
scaling factor is set close to zero, then every pixel will be detected as a target, which
would then be identified as one single huge cluster by the DBSCAN algorithm.
Therefore, a probability of false alarm of 1 is not achievable. In the evaluation, we
saw that the proposed method performs better for fewer targets, as expected, while
its performance worsens mainly with rising target density. However, even if the
neural network is trained with only one and two target measurements, it will still
be able to correctly identify the positions of three or four targets in a lot of cases.
Comparing the F1-scores from Table 3 to those in Table 2, it seems like the network
has a harder time doing the range-angle localization task. This has several reasons.
First, only two receiving antennas were used, making an accurate angle estimation
more difficult in general. The second explanation, which likely has a bigger impact,
is that the evaluation methods were not the same for both experiments. Specifically,
the minimum center-of-mass distance differs, as a difference in velocity is not comparable
to a difference in angle. If we increase the angle threshold from 7.8° to 15.6°, we
get an F1-score of about 0.9 and 0.68 for the proposed and the traditional method,
respectively. Therefore, the proposed method does well in removing ghost targets and
estimating the correct target ranges but is not quite as good in accurately estimating
the angles. The traditional method does not gain as much due to this change since
most of its errors are due to ghost targets or missed detections.
In Figs. 14, 15, 16, and 17 some examples with the input image, the output image
from the traditional processing chain, and the output image from the neural network
are shown. The input image here is the RDI from one of the antennas, as the complex
two antenna input is hard to visualize. It is mainly there to illustrate the difficulty
of the range-angle presence detection and localization. In Fig. 14, two targets were
missed by the classical signal processing chain, whereas all four targets were detected
by the neural network with one additional false alarm. In Fig. 15, all four targets were
detected by the network, while one was missed by the traditional approach. In Fig.
16, only one target was correctly identified by the classical approach, three targets
missed, and one false alarm included. Here, the neural network only missed one
target. In the last example, the network again identified all targets at the correct


Fig. 14 a Raw RDI image with four human targets, b processed RAI using the traditional approach
with two missed detections, c processed RAI using proposed complex U-net approach wherein one-
target split occurs


Fig. 15 a Raw RDI image with four human targets, b processed RAI using the traditional approach
with one missed detection, c processed RAI using proposed complex U-net approach wherein all
targets are detected accurately


Fig. 16 a Raw RDI image with four human targets, b processed RAI using the traditional approach
with one ghost target and three missed detections, c processed RAI using proposed complex U-net
approach wherein one target was missed

positions, while the traditional approach has two missed detections and one-target
split, resulting in one false alarm.


Fig. 17 a Raw RDI image with three human targets, b processed RAI using the traditional approach
with one-target split and two missed detections, c processed RAI using proposed complex U-net
approach wherein all targets are detected accurately

7 Conclusion

The traditional radar signal processing pipeline for detecting targets on either RDI
or RAI is prone to ghost targets, multi-path reflections from static objects, and target
occlusion in the case of human detection, especially in indoor environments. Further,
parametric clustering algorithms suffer from single-target splits and from multiple targets
merging into single clusters. To overcome such artifacts and facilitate accurate human
detection, localization, and counting, we, in this contribution, proposed to use a deep
residual U-Net model and a deep complex U-Net model to generate accurate human
detections in the RDI and RAI domain, respectively, in indoor scenarios. We trained
the models using a custom loss function, the proposed architectural designs, and a training
strategy with data augmentation to achieve accurate processed RDIs and RAIs.
We demonstrated the superior detection and clustering results in terms of F1 score
and ROC characterization compared to the conventional signal processing approach.
As future work, variational autoencoder generative adversarial network (VAE-GAN)
architecture can be deployed to minimize the sensitivity of U-Net models to variations
in data due to sensor noise and interference.

References

1. EPRI, Occupancy sensors: positive on/off lighting control, in Rep. EPRIBR-100323 (1994)
2. V. Garg, N. Bansal, Smart occupancy sensors to reduce energy consumption. Energy Build.
32, 81–87 (2000)
3. W. Butler, P. Poitevin, J. Bjomholt, Benefits of wide area intrusion detection systems using
FMCW radar (2007), pp. 176–182
4. J. Lien, N. Gillian, M. Emre Karagozler, P. Amihood, C. Schwesig, E. Olson, H. Raja,
I. Poupyrev, Soli: ubiquitous gesture sensing with millimeter wave radar. ACM Trans. Graph.
35, 1–19 (2016)
5. S. Hazra, A. Santra, Robust gesture recognition using millimetric-wave radar system. IEEE
Sens. Lett. PP, 1 (2018)

6. S. Hazra, A. Santra, Short-range radar-based gesture recognition system using 3D CNN with
triplet loss. IEEE Access 7, 125623–125633 (2019)
7. M. Arsalan, A. Santra, Character recognition in air-writing based on network of radars for
human-machine interface. IEEE Sens. J. PP, 1 (2019)
8. C. Will, P. Vaishnav, A. Chakraborty, A. Santra, Human target detection, tracking, and classi-
fication using 24 GHZ FMCW radar. IEEE Sens. J. PP, 1 (2019)
9. A. Santra, R. Vagarappan Ulaganathan, T. Finke, Short-range millimetric-wave radar system
for occupancy sensing application. IEEE Sens. Lett. PP, 1 (2018)
10. H.L.V. Trees, Detection, Estimation, and Modulation Theory, Part I (Wiley, 2004)
11. A. Santra, I. Nasr, J. Kim, Reinventing radar: the power of 4D sensing. Microw. J. 61, 26–38
(2018)
12. F. Schroff, D. Kalenichenko, J. Philbin, Facenet: a unified embedding for face recognition and
clustering (2015), pp. 815–823
13. O. M. Parkhi, A. Vedaldi, A. Zisserman, Deep face recognition, vol. 1 (2015), pp. 41.1–41.12
14. S. Ren, K. He, R. Girshick, J. Sun, Faster r-cnn: towards real-time object detection with region
proposal networks. IEEE Trans. Pattern Anal. Mach. Intell. 39, 06 (2015)
15. J. Redmon, S. Divvala, R. Girshick, A. Farhadi, You only look once: unified, real-time object
detection (2016), pp. 779–788
16. L.-C. Chen, G. Papandreou, I. Kokkinos, K. Murphy, A.L. Yuille, Deeplab: semantic image
segmentation with deep convolutional nets, atrous convolution, and fully connected crfs. IEEE
Trans. Pattern Anal. Mach. Intell. PP (2016)
17. O. Ronneberger, P. Fischer, T. Brox, U-net: convolutional networks for biomedical image
segmentation (2015)
18. M.A. Wani, F.A. Bhat, S.Afzal, A.I. Khan, Advances in Deep Learning (Springer, Singapore,
2020)
19. K. He, X. Zhang, S. Ren, J. Sun, Deep residual learning for image recognition (2016), pp.
770–778
20. Z. Zhang, Q. Liu, Y. Wang, Road extraction by deep residual U-net. IEEE Geosci. Remote
Sens. Lett. PP (2017)
21. G. Zhang, H. Li, F. Wenger, Object detection and 3D estimation via an FMCW radar using a
fully convolutional network (2019). arXiv preprint arXiv:1902.05394
22. L. Wang, J. Tang, Q. Liao, A study on radar target detection based on deep neural networks.
IEEE Sens. Lett. 3(3), 1–4 (2019)
23. M. Stephan, A. Santra, Radar-based human target detection using deep residual U-net for
smart home applications, in 18th IEEE International Conference on Machine Learning And
Applications (ICMLA) (IEEE, 2019), pp. 175–182
24. J.S. Dramsch, Contributors, Complex-valued neural networks in keras with tensorflow (2019)
25. C. Trabelsi, O. Bilaniuk et al., Deep complex networks (2017). arXiv preprint arXiv:1705.09792
26. L. Xu, J. Li, P. Stoica, Adaptive techniques for MIMO radar, in Fourth IEEE Workshop on
Sensor Array and Multichannel Processing, vol. 2006 (IEEE, 2006), pp. 258–262
27. H. Rohling, Radar CFAR thresholding in clutter and multiple target situations. IEEE Trans.
Aerosp. Electron. Syst. 19, 608–621 (1983)
28. M. Ester, H.-P. Kriegel, J. Sander, X. Xu, A density-based algorithm for discovering clusters
in large spatial databases with noise, in KDD (1996)
29. A. Santra, R. Santhanakumar, K. Jadia, R. Srinivasan, SINR performance of matched illumi-
nation signals with dynamic target models (2016)
30. C. Szegedy, V. Vanhoucke, S. Ioffe, J. Shlens, Z. Wojna, Rethinking the inception architecture
for computer vision (2016)
31. T.-Y. Lin, P. Goyal, R. Girshick, K. He, P. Dollar, Focal loss for dense object detection (2017),
pp. 2999–3007
32. D. Kingma, J. Ba, Adam: a method for stochastic optimization, vol. 12 (2014)
Thresholding Strategies for Deep Learning with Highly Imbalanced Big Data

Justin M. Johnson and Taghi M. Khoshgoftaar

Abstract A variety of data-level, algorithm-level, and hybrid methods have been


used to address the challenges associated with training predictive models with class-
imbalanced data. While many of these techniques have been extended to deep neural
network (DNN) models, there are relatively fewer studies that emphasize the signif-
icance of output thresholding. In this chapter, we relate DNN outputs to Bayesian a
posteriori probabilities and suggest that the Default threshold of 0.5 is almost never
optimal when training data is imbalanced. We simulate a wide range of class imbal-
ance levels using three real-world data sets, i.e. positive class sizes of 0.03–90%,
and we compare Default threshold results to two alternative thresholding strategies.
The Optimal threshold strategy uses validation data or training data to search for
the classification threshold that maximizes the geometric mean. The Prior threshold
strategy requires no optimization, and instead sets the classification threshold to be
the prior probability of the positive class. Multiple deep architectures are explored
and all experiments are repeated 30 times to account for random error. Linear mod-
els and visualizations show that the Optimal threshold is strongly correlated with
the positive class prior. Confidence intervals show that the Default threshold only
performs well when training data is balanced and Optimal thresholds perform sig-
nificantly better when training data is skewed. Surprisingly, statistical results show
that the Prior threshold performs consistently as well as the Optimal threshold across
all distributions. The contributions of this chapter are twofold: (1) illustrating the
side effects of training deep models with highly imbalanced big data and (2) com-
paring multiple thresholding strategies for maximizing class-wise performance with
imbalanced training data.

J. M. Johnson (B) · T. M. Khoshgoftaar


Florida Atlantic University, Boca Raton, FL 33431, USA
e-mail: jjohn273@fau.edu
T. M. Khoshgoftaar
e-mail: khoshgof@fau.edu

1 Introduction

Class imbalance exists when the total number of samples from one class, or category,
is significantly larger than any other category within the data set. This phenomenon
arises in many critical industries, e.g. financial [1], biomedical [2], and environmental
[3]. In each of these examples, the positive class of interest is the smaller class, i.e.
the minority group, and there is an abundance of less-interesting negative samples.
In this study, we focus specifically on binary classification problems that contain a
positive and negative class. The concepts presented can be extended to the multi-
class problem, however, because multi-class problems can be converted into a set of
two-class problems through class decomposition [4].
Imbalanced data sets have been shown to degrade the performance of classifica-
tion models, often causing models to over-predict the majority group. As a result,
instances belonging to the minority group are incorrectly classified as negative sam-
ples and positive class performance suffers. To make matters worse, popular eval-
uation metrics like accuracy are liable to mislead analysts with high scores that
incorrectly indicate good prediction performance. For example, given a binary data
set with a positive class size of just 1%, a simple learner that always outputs the negative
class will score 99% accuracy. The extent of this performance degradation depends
on problem complexity, data set size, and the level of class imbalance [5]. In this
study, we use deep neural network (DNN) models to make predictions with complex
data sets that are characterized by both big data and high class imbalance.
Generally, DNN model performance degrades as the level of class imbalance
increases and the relative size of the positive class decreases [6]. We denote the level
of class imbalance in a binary data set using the ratio of negative samples to positive
samples, i.e. nneg : npos . For example, the imbalance of a data set with 400 negative
samples and 100 positive instances is denoted by 80:20. Equivalently, we sometimes
refer to the positive class’s prior probability, e.g. the 80:20 distribution has a positive
class prior of 0.2 or 20%. We classify a data set as highly imbalanced when the
positive class prior is ≤ 0.01 [7]. When the total number of positive occurrences
becomes even more infrequent, we describe the data set as exhibiting rarity. Weiss et
al. [8] distinguish between absolute rarity and relative rarity. Absolute rarity occurs
when there are very few samples for a given class, regardless of the size of the
majority class. Unlike absolute rarity, a relatively rare class can make up a small
percentage of a data set and still have many occurrences when the overall data set
is very large. For example, given a data set with 10 million records and a relative
positive class size of 1%, there are still 100,000 positive cases to use for training
machine learning models. Therefore, we can sometimes achieve good performance
with relatively rare classes, especially when working with large volumes of data.
The experiments in this chapter explore a wide range of class imbalance levels, e.g.
0.03–90%, and include cases of relative rarity.
The challenges of working with class-imbalanced data are often compounded by
the challenges of big data [9]. Big data commonly refers to data which exceeds the
capabilities of standard data storage and processing. These data sets are also defined

using the four Vs: volume, variety, velocity, and veracity [10, 11]. The large vol-
umes of data being collected require highly scalable hardware and efficient analysis
tools, often demanding distributed implementations. In addition to adding architec-
ture and network overhead, distributed systems have been shown to exacerbate the
negative effects of class-imbalanced data [12]. The variety of big data corresponds
to the mostly unstructured, diverse, and inconsistent representations that arise as
data is consumed from multiple sources over extended periods of time. Advanced
techniques for quickly processing incoming data streams and maintaining appropri-
ate turnaround times are required to keep up with the rate at which data is being
generated, i.e. data velocity. Finally, the veracity of big data, i.e. its accuracy and
trustworthiness, must be regularly validated to ensure results do not become cor-
rupted. MapReduce [13] and Apache Spark [14] are two popular frameworks that
address these big data challenges by operating on partitioned data in parallel. Neural
networks can also be trained in a distributed fashion using either data parallelism or
model parallelism techniques [15, 16]. The three data sets used in this study include
many of these big data characteristics.
Data-level and algorithm-level techniques for addressing class imbalance have
been studied extensively. Data methods use sampling to alter the distribution of the
training data, effectively reducing the level of class imbalance for model training.
Random over-sampling (ROS) and random under-sampling (RUS) are the two sim-
plest data-level methods for addressing class imbalance. ROS increases the size of
the minority class by randomly copying minority samples, and RUS decreases the
size of the majority class by randomly discarding samples from the majority class.
From these fundamental data-level methods, many more advanced variants have
been developed [17–20]. Algorithm-level methods modify the training process to
increase the impact of the minority class and reduce bias toward the majority class.
Direct algorithm-level methods modify a machine learning algorithm’s underlying
learner, usually by incorporating class costs or weights. Meta-learner methods use
a wrapper to convert non-cost-sensitive learners into cost-sensitive learners. Cost-
sensitive learning and output thresholding are examples of direct and meta-learner
algorithm-level methods, respectively [21]. Output thresholding is the process of
changing the decision threshold that is used to assign class labels to a model’s pos-
terior probabilities[22, 23]. Finally, there are a number of hybrid methods that com-
bine two or more data-level methods and algorithm-level methods [7, 24–26]. In this
chapter, we explore output thresholding as it relates to deep learning.
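
For illustration, output thresholding amounts to a one-line change at prediction time. The following Python sketch, with hypothetical probability estimates, shows how lowering the threshold recovers positive predictions that the Default threshold of 0.5 would miss.

import numpy as np

def apply_threshold(probs, threshold=0.5):
    # Assign class labels to posterior probability estimates.
    return (probs > threshold).astype(int)

probs = np.array([0.01, 0.04, 0.20, 0.60])   # hypothetical model outputs
print(apply_threshold(probs))                # Default threshold of 0.5 -> [0 0 0 1]
print(apply_threshold(probs, 0.035))         # lowered threshold -> [0 1 1 1]
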
Deep learning is a subfield of machine learning that uses artificial neural networks
(ANNs) with two or more hidden layers to approximate some function f*, where
f* can be used to map input data to new representations or make predictions [27].
The ANN, inspired by the biological neural network, is a set of interconnected
neurons, or nodes, where connections are weighted and each neuron transforms its
input into a single output by applying a nonlinear activation function to the sum of
its weighted inputs. In a feedforward network, input data propagates through the
network in a forward pass, each hidden layer receiving its input from the previous
layer’s output, producing a final output that is dependent on the input data, the choice
of activation function, and the weight parameters [28]. Gradient descent optimization

adjusts the network’s weight parameters in order to minimize the loss function, i.e.
the error between expected output and actual output. Composing multiple nonlinear
transformations creates hierarchical representations of the input data, increasing the
level of abstraction through each transformation. The deep learning architecture,
i.e. deep neural network (DNN), achieves its power through this composition of
increasingly complex abstract representations [27]. Deep learning methods have
proven very successful in solving complex problems related to natural language and
vision [29]. These recent successes can be attributed to an increased availability
of data, improvements in hardware and software [30–34], and various algorithmic
breakthroughs that speed up training and improve generalization to new data [35].
Ideally, we would like to leverage the power of these deep models to improve the
classification of highly imbalanced big data. While output thresholding has been
used with traditional learners to treat class imbalance in both big and non-big data
problems, we found there is little work that properly evaluates its use in DNN models.
Most commonly, the Default threshold of 0.5 is used to assign class labels to a
classifier’s posterior probability estimates. In this chapter, however, we argue that
the Default threshold is rarely optimal when neural network models are trained with
class-imbalanced data. This intuition is drawn from the fact that neural networks
have been shown to estimate Bayesian a posteriori probabilities when trained with
sufficient data [36]. In other words, a well-trained DNN model is expected to output
a posterior probability estimate for input x that corresponds to yc (x) (Eq. 1), where
factors p(c) and p(x) are prior probabilities and p(x | c) is the conditional probability
of observing instance x given class c. We do not need to compute each factor indi-
vidually since neural networks do not estimate these factors directly, but we can use
Bayes Theorem and the positive class prior from our training data to better under-
stand the posterior estimates produced by our models. The estimated positive class
prior p(cpos ) is the probability that a random sample drawn from the training data
belongs to the positive class, and it is equal to the number of positive training sam-
ples divided by the total number of training samples. In highly imbalanced problems,
e.g. p(cpos ) ≤ 0.01, small positive class priors can significantly decrease posterior
estimates and, in some cases, yc (x) may never exceed 0.5. If yc (x) ≤ 0.5 for all x, the
Default threshold will incorrectly assign all positive samples to the negative class.
We can account for this imbalance by identifying Optimal thresholds that balance
positive and negative class performance. We build on this concept empirically by
exploring two thresholding strategies with three real-world data sets. The Optimal
thresholding strategy identifies the threshold that maximizes performance on training
or validation data and then uses the optimal value to make predictions on the test
set. The Prior thresholding strategy estimates p(cpos ) from the training data and then
uses p(cpos ) as the classification threshold for making predictions on the test set.
yc(x) = p(c | x) = p(c) · p(x | c) / p(x)    (1)
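
As a hypothetical numeric illustration of Eq. 1 (the likelihood values are assumptions chosen purely for illustration), an instance that is 150 times more likely under the positive class than under the negative class can still receive a posterior estimate well below 0.5 when p(cpos) = 0.001, so the Default threshold labels it negative.

# Hypothetical illustration of Eq. 1 (values are assumptions, not estimates from this study).
p_pos = 0.001                # positive class prior p(c_pos)
p_x_given_pos = 0.30         # assumed likelihood p(x | c_pos)
p_x_given_neg = 0.002        # assumed likelihood p(x | c_neg)

p_x = p_pos * p_x_given_pos + (1 - p_pos) * p_x_given_neg
posterior = p_pos * p_x_given_pos / p_x
print(round(posterior, 3))   # ~0.131: below the Default threshold of 0.5
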

The first two data sets, Medicare Part B [37] and Part D [38], summarize claims
that medical providers have submitted to Medicare and include a class label that

indicates whether or not the provider is known to be fraudulent. Medicare is a United


States healthcare program that provides affordable health insurance to individuals 65
years and older, and other select individuals with permanent disabilities [39]. The Part
B and Part D data sets have 4.6 and 3.6 million observations, respectively, with
positive class sizes < 0.04%. They were made publicly available by the Centers
for Medicare & Medicaid Services (CMS) in order to increase transparency and
reduce fraud. The third data set was first published by the Evolutionary Computation
for Big Data and Big Learning (ECBDL) workshop in 2014 and is now publicly
available [40]. Samples within the ECBDL’14 data set contain features that describe
pairs of amino acids and labels that indicate whether or not amino acid pairs are
spatially close in three-dimensional space. The accurate prediction of these protein
contact maps enables further inferences on the three-dimensional shape of proteins.
We use a subset of the ECBDL’14 data set consisting of 3.5 million instances and
maintain the original positive class size of 0.2%. The ECBDL’14 data set is not as
highly imbalanced as the Medicare data set, so we use data sampling techniques to
simulate high class imbalance.
The primary contribution of this study is in providing a unique and thorough
analysis of treating class imbalance with DNN models and output thresholding.
These thresholding techniques are applied to deep feedforward networks and can be
extended to other deep learning architectures described in “Advances in Deep Learn-
ing” [41]. ROS, RUS, and a hybrid ROS-RUS are used to create over 30 training
distributions with positive class priors ranging from 0.03 to 90%. Optimal thresh-
olds are computed for each distribution by maximizing performance on training and
validation sets. As expected, linear models reveal a strong relationship between the
positive class size and the Optimal threshold. In another experiment, we compute
Optimal thresholds for each training epoch and visualize its stability over time. Clas-
sification experiments are repeated 30 times to account for random error, and multiple
deep architectures are used to determine if results generalize to deeper models. Per-
formance results, confidence intervals, and figures are used to show that the Default
threshold should not be used when training DNN models with class-imbalanced data.
Finally, Tukey’s HSD (honestly significant difference) test [42] results show that the
Prior threshold performs as well as the Optimal threshold on average.
The remainder of this chapter is outlined as follows. Section 2 discusses several
methods for training deep models with class-imbalanced data and other related works
that have used the Medicare and ECBDL’14 data sets. In Sects. 3 and 4, we describe
the data sets used in this study and present our experiment design, respectively.
Results are discussed in Sect. 5, and Sect. 6 closes with suggestions for future works.

2 Related Work

This section begins by summarizing other works that address class imbalance with
DNN models. We then introduce works related to Medicare fraud prediction and
contact map prediction.

2.1 Deep Learning with Class-Imbalanced Data

The effects of class imbalance on backpropagation and optimization were studied by


Anand et al. [43] using shallow neural networks. During optimization, the network
loss is dominated by the majority class, and the minority group has very little influence
on the total loss and network weight updates. This tends to reduce the error of the
majority group very quickly during early iterations while consequently increasing
the error of the minority group.
In recent studies, we explored a variety of data-level, algorithm-level, and hybrid
methods for addressing this phenomenon with deep neural networks [44, 45]. Several
authors explored data sampling methods [6, 46, 47] and found ROS to outperform
RUS and baseline models. Others employed cost-sensitive loss functions or proposed
new loss functions that reduce the bias toward the majority class [48–51]. Some of
the best results were achieved by more complex hybrid methods that leverage deep
feature learning and custom loss functions [52–54]. Output thresholding, which is
commonly used with traditional learners to maximize class performance, has received
very little attention in related deep learning works.
Lin et al. [51] proposed the Focal Loss function for addressing the severe class
imbalance found in object detection problems. While their study is not specifically
about thresholding, they do disclose using a threshold of 0.05 to speed up inference.
Dong et al. [54] also present a loss function for addressing class imbalance, i.e.
the Class Rectification Loss. They compare their proposed loss function to a num-
ber of alternative methods, including thresholding. Results from Dong et al. show
that thresholding outperforms ROS, RUS, cost-sensitive learning, and other baseline
models on the imbalanced X-Domain [55] image data set. These studies were not
intended to showcase thresholding, yet, their results clearly indicate that thresholding
plays an important role in classifying imbalanced data with deep models.
To the best of our knowledge, Buda et al. [6] were the only authors to explicitly
isolate the thresholding method and study its ability to improve the classification of
imbalanced data with deep models. ROS and RUS were used to create training dis-
tributions with varying levels of class imbalance from the MNIST [56] and CIFAR-
10 [57] benchmarks, and the authors evaluated minority class sizes between 0.02
and 50%. Thresholding was achieved by dividing CNN outputs by prior class prob-
abilities, and the accuracy performance metric was used to show how thresholding
improves class-wise performance in nearly all cases. In addition to outperforming
ROS, RUS, and the baseline CNN, the authors show that combining thresholding
with ROS performs exceptionally well and outperforms plain ROS. We expand on
the work by Buda et al. by incorporating statistical analysis and complementary per-
formance metrics, e.g. geometric mean (G-Mean), true positive rate (TPR), and true
negative rate (TNR). In addition, we provide a unique analysis that compares the
Optimal decision threshold to the positive class size of training distributions across
a wide range of class imbalance levels.

2.2 Medicare Fraud Detection

The big data, big value, and high class imbalance inherent in Medicare fraud predic-
tion make it an excellent candidate for evaluating methods designed to address class
imbalance. Bauder and Khoshgoftaar [58] use a subset of the 2012–2013 Medicare
Part B data, i.e. Florida claims only, to model expected amounts paid to providers for
services rendered to patients. In another study, Bauder and Khoshgoftaar [59] pro-
posed an outlier detection method that uses Bayesian inference to identify outliers,
and successfully validated their model using claims data of a known Florida provider
that was under criminal investigation for excessive billing. This experiment used a
subset of 2012–2014 Medicare Part B data that included dermatology and optom-
etry claims from Florida office clinics. Another paper by Bauder et al. [60] uses a
Naive Bayes classifier to predict provider specialty types, and then flag providers
that are practicing outside their expected specialty type as fraudulent. Results show
that specialties with unique billing procedures, e.g. audiologist or chiropractic, are
able to be classified with high precision and recall. Herland et al. [61] expanded
on the work from [60] by incorporating 2014 Medicare Part B data and real-world
fraud labels defined by the List of Excluded Individuals and Entities (LEIE) [62] data
set. The authors find that grouping similar specialty types, e.g. Ophthalmology and
Optometry, improves overall performance. Bauder and Khoshgoftaar [63] merge the
2012–2015 Medicare Part B data sets, label fraudulent providers using the LEIE data
set, and compare multiple traditional machine learning classifiers. Class imbalance
is addressed with RUS, and various class distributions are generated to identify the
optimal imbalance ratio for training. The C4.5 decision tree and logistic regression
(LR) learners significantly outperform the support vector machine (SVM), and the
80:20 class distribution is shown to outperform 50:50, 65:35, and 75:25 distributions.
In summary, these studies show that Medicare Part B and Part D claims data contains
sufficient variability to detect fraudulent providers and that the LEIE data set can be
reliably used for ground truth fraud labels.
The Medicare experiments in our study leverage data sets curated by Herland
et al. [64]. In one study, Herland et al. [64] used Medicare Part B, Part D, and
DMEPOS claims data from the years 2012–2016. Cross-validation and ROC AUC
scores are used to compare LR, Random Forest (RF), and Gradient Boosted Tree
(GBT) learners. Results show that Part B data sets score significantly better on ROC
AUC than the Part D data set, and the LR learner outperforms the GBT and RF
learners with a max AUC of 0.816. In a second paper, Herland et al. [65] used these
same Medicare data sets to study the effect of class rarity with LR, RF, and GBT
learners. In this study, the authors create an absolutely rare positive class by using
subsets of the positive class to form new training sets. They reduced the positive class
size to 1000, 400, 200, and 100, and then used RUS methods to treat imbalance and
compare AUC scores. Results show that smaller positive class counts degrade model
performance, and the LR learner with an RUS distribution of 90:10 performs best.
Several other research groups have taken interest in detecting Medicare fraud
using the CMS Medicare and LEIE data sets. Feldman and Chawla [66] looked for

anomalies in the relationship between medical school training and the procedures that
physicians perform in practice by linking 2012 Medicare Part B data with provider
medical school data obtained through the CMS physician compare data set [67].
Significant procedures for each school were used to evaluate school similarities and
present a geographical analysis of procedure charges and payment distributions. Ko
et al. [68] used the 2012 CMS data to analyze the variability of service utilization
and payments. Ko et al. found that the number of patient visits is strongly correlated
with Medicare reimbursement, and concluded that there is a possible 9% savings
within the field of Urology alone. Chandola et al. [69] used claims data and fraud
labels from the Texas Office of Inspector General’s exclusion database to detect
anomalies. The authors confirm the importance of including provider specialty types
in fraud detection, showing that the inclusion of specialty attributes increases AUC
scores from 0.716 to 0.814. Branting et al. [70] propose a graph-based method for
estimating healthcare fraud risk within the 2012–2014 CMS and LEIE data sets.
The authors leverage the NPPES [71] registry to look up providers that are missing
from the LEIE database, increasing their total fraudulent provider count to 12,000.
Branting et al. combine these fraudulent providers with a subset of 12,000 non-
fraudulent providers and employ a J48 decision tree learner to classify fraud with a
mean AUC of 0.96.

2.3 ECBDL’14 Contact Map Prediction

The Community-Wide Experiment on the Critical Assessment of Techniques for


Protein Structure Prediction (CASP) has been assessing teams' abilities to predict
protein structure publicly since 1994 [72]. A total of 13 CASP experiments have
been held thus far, and the CASP14 conference is scheduled for December of 2020.
CASP notes that substantial progress has been made in the area of residue–residue
contact prediction over recent years and descriptions of results are available in their
Proceedings [72].
Triguero et al. [73] won the ECBDL’14 competition by employing random over-
sampling, evolutionary feature selection, and an RF learner (ROSEFW-RF) within
the MapReduce framework. In their paper, the authors explore a range of hyperpa-
rameters and achieve their best results with an over-sampling rate of 170% and a
subset of 90 features. We use the results from Triguero et al., i.e. a balanced TPR and
TNR of 0.730, to evaluate the DNN thresholding results obtained in our ECBDL’14
experiments.
Since the competition, other groups have used ECBDL’14 data to evaluate meth-
ods for treating class imbalance and big data. Fernández et al. [74] compared the
performance of ROS, RUS, and SMOTE using subsets of ECBDL’14 containing 0.6
and 12 million instances and 90 features. Apache Spark and Hadoop frameworks
were used to distribute RF and decision tree models across partition sizes of 1, 8, 16,
32, and 64. Results from Fernández et al. show that ROS and RUS perform better
than SMOTE, and ROS tends to perform better as the number of partitions increases.

The best performing learner scored a G-Mean of 0.706 using the subset of 12 million
instances and the ROS strategy, suggesting that the subset of 0.6 million instances
is not representative enough. Río et al. [75] also used the Hadoop framework to
explore RF learner performance with ROS and RUS methods for addressing class
imbalance. They too found that ROS outperforms RUS, and suggested that this is
due to the already underrepresented minority class being split across many partitions.
Río et al. achieved their best performance with 64 partitions and an over-sampling
rate of 130%. Unfortunately, results show a relatively low TPR (0.705) compared to
the reported TNR (0.725) and the winning competition results (0.730).
Similar to these related works, we use a subset of ECBDL’14 data (3.5 million
instances) to evaluate thresholding strategies for DNN classification. Unlike related
works, however, we did not find it necessary to use a distributed training environ-
ment. Instead of training individual models in a distributed fashion, we use multiple
compute nodes with sufficient resources to train multiple models independently and
in parallel. Also, unlike these related works, we do not rely on data sampling tech-
niques to balance TPR and TNR scores. Rather, we show how a simple thresholding
technique can be used to optimize class-wise performance regardless of the class
imbalance level.

3 Data Sets

This section summarizes the data sets that are used to evaluate thresholding tech-
niques for addressing class imbalance with deep neural networks. We begin with
two Medicare fraud data sets that were first curated by Herland et al. [64]. We then
incorporate a large protein contact map prediction data set that was published by the
Evolutionary Computation for Big Data and Big Learning (ECBDL) workshop in
2014 [40]. All three data sets were obtained from publicly available resources and
exhibit big data and class imbalance characteristics.

3.1 CMS Medicare Data

Two publicly available Medicare fraud data sets are obtained from CMS: (1) Medicare
Provider Utilization and Payment Data: Physician and Other Supplier (Part B) [37],
and (2) Medicare Provider Utilization and Payment Data: Part D Prescriber (Part
D) [38]. The healthcare claims span years 2012–2016 and 2013–2017 for Part B
and Part D data sets, respectively. Physicians are identified within each data set by
their National Provider Identifier (NPI), a unique 10-digit number that is used to
identify healthcare providers [76]. Using the NPI, Herland et al. map fraud labels to
the Medicare data from the LEIE repository. The LEIE is maintained by the Office of
Inspector General and it lists providers that are prohibited from practicing. Additional
attributes of LEIE data include the reason for exclusion and provider reinstatement

Table 1 Description of Part B features [64]


Feature Description
Npi Unique provider identification number
Provider_type Medical provider’s specialty (or practice)
Nppes_provider_gender Provider’s gender
Line_srvc_cnt Number of procedures/services the provider performed
Bene_unique_cnt Number of Medicare beneficiaries receiving the service
Bene_day_srvc_cnt Number of Medicare beneficiaries/per day services
Avg_submitted_chrg_amt Avg. of the charges that a provider submitted for service
Avg_medicare_payment_amt Avg. payment made to a provider per claim for service
Exclusion Fraud labels from the LEIE data set

dates, where applicable. Providers that have been excluded for fraudulent activity
are labeled as fraudulent within the Medicare Part B and Part D data sets.
The Part B claims data set describes the services and procedures that healthcare
professionals provide to Medicare’s Fee-For-Service beneficiaries. Records within
the data set contain various provider-level attributes, e.g. NPI, first and last name,
gender, credentials, and provider type. More importantly, records contain specific
claims details that describe a provider’s activity within Medicare. Examples of claims
data include the procedure performed, the average charge submitted to Medicare, the
average amount paid by Medicare, and the place of service. The procedures rendered
are encoded using the Healthcare Common Procedures Coding System (HCPCS)
[77]. For example, HCPCS codes 99219 and 88346 are used to bill for hospital
observation care and antibody evaluation, respectively. Part B data is aggregated by
(1) provider NPI, (2) HCPCS code, and (3) place of service. The list of Part B features
used for training are provided in Table 1.
Similarly, the Part D data set contains a variety of provider-level attributes, e.g.
NPI, name, and provider type. More importantly, the Part D data set contains specific
details about medications prescribed by Medicare providers. Examples of prescrip-
tion attributes include drug names, costs, quantities prescribed, and the number of
beneficiaries receiving the medication. CMS aggregates the Part D data over (1) pre-
scriber NPI and (2) drug name. Table 2 summarizes all Part D predictors used for
classification.

3.2 ECBDL’14 Data

The ECBDL’14 data set was originally generated to train a predictor for the residue–
residue contact prediction track of the 9th Community-Wide Experiment on the
Critical Assessment of Techniques for Protein Structure Prediction competition
(CASP9) [72]. Protein contact map prediction is a subproblem of protein structure

Table 2 Description of Part D features [64]


Feature Description
Npi Unique provider identification number
Provider_type Medical provider’s specialty (or practice)
Bene_count Number of distinct beneficiaries receiving a drug
Total_claim_count Number of drugs administered by a provider
Total_30_day_fill_count Number of standardized 30-day fills
Total_day_supply Number of day’s supply
Total_drug_cost Cost paid for all associated claims
Exclusion Fraud labels from the LEIE data set

prediction. This subproblem entails predicting whether any two residues in a pro-
tein sequence are spatially close to each other [78]. The three-dimensional structure
of a protein can then be inferred from these residue–residue contact map predic-
tions. This is a fundamental problem in medical domains, e.g. drug design, as the
three-dimensional structure of a protein determines its function [79].
Each instance of the ECBDL’14 data set is a pair of amino acids represented by 539
continuous attributes, 92 categorical attributes, and a binary label that distinguishes
pairs that are in contact. Triguero et al. [73] provide a thorough description of the
data set and the methods used to win the competition. Attributes include detailed
information and statistics about the protein sequence and the segments connecting
the target pair of amino acids. Additional predictors include the length of the protein
sequence and a statistical contact propensity between the target pair of amino acid
types [73]. The training partition contains 32 million instances and a positive class
size of 2%. We refer readers to the original paper by Triguero et al. for a deeper
understanding of the amino acid representation. The characteristics of this data set
have made it a popular choice for evaluating methods of treating class imbalance and
big data.

3.3 Data Preprocessing

Both Medicare Part B and Part D data sets were curated and cleaned by Herland et
al. [64] and required minimal preprocessing. First, test sets were created by using
stratified random sampling to hold out 20% of each Medicare data set. The held-
out test sets remain constant throughout all experiments, and when applicable, data
sampling is only applied to the training data to create new class imbalance levels.
These methods for creating varying levels of class imbalance are described in detail
in Sect. 4. Finally, all features were normalized to continuous values in the range
[0, 1]. Train and test sets were normalized by fitting a min-max scaler to the fit data
and applying the scaler to the train and test sets separately.

Table 3 Train and test set sizes


Data set Sample count Feature count Positive count Positive (%)
Part B Train 3,753,896 125 1206 0.032
Test 938,474 302 0.032
Part D Train 2,917,961 133 1028 0.035
Test 729,491 257 0.035
ECBDL’14 Train 2,800,000 200 59,960 2.141
Test 700,000 15,017 2.145

A subset of 3.5 million records was taken from the ECBDL’14 data with ran-
dom under-sampling. Categorical features were encoded using one-hot encoding,
resulting in a final set of 985 features. Similar to the Medicare data, we normalized
all attributes to the range [0, 1] and set aside a 20% test set using stratified random
sampling. To improve efficiency, we applied Chi-square feature selection [80] and
selected the best 200 features. Preliminary experiments suggested that exceeding 200
features provided limited returns on validation performance.
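
A sketch of these preprocessing steps using scikit-learn (the package reported above) is given below; the toy data, random seeds, and the value of k are placeholders rather than the study's actual inputs.

import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import MinMaxScaler
from sklearn.feature_selection import SelectKBest, chi2

# Placeholder data: 1000 samples, 50 features, ~2% positive class.
rng = np.random.default_rng(0)
X = rng.random((1000, 50))
y = (rng.random(1000) < 0.02).astype(int)

# Stratified 80-20 train/test split; the test set is held constant for evaluation.
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, stratify=y, random_state=0)

# Fit the min-max scaler to the training data only, then apply it to both splits.
scaler = MinMaxScaler().fit(X_train)
X_train, X_test = scaler.transform(X_train), scaler.transform(X_test)

# Chi-square feature selection (e.g. the best 200 features for ECBDL'14; 20 here).
selector = SelectKBest(chi2, k=20).fit(X_train, y_train)
X_train, X_test = selector.transform(X_train), selector.transform(X_test)
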
Table 3 lists the sizes of train and test sets, the total number of predictors, and
the size of the positive class for each data set. Train sets have between 2.8 and 3.7
million samples and relatively fewer positive instances. The ECBDL’14 data set has
a positive class size of 2% and is arguably not considered to be highly imbalanced,
i.e. positive class is greater than 1% of the data set. In Sect. 4, we explain how high
class imbalance is simulated by under-sampling the positive class. The Medicare
data sets, on the other hand, have just between 3 and 4 positive instances for every
10,000 observations and are intrinsically severely imbalanced.

4 Methods

We evaluate thresholding strategies for addressing high class imbalance with deep
neural networks across a wide range of class imbalance levels. A stratified random
80–20% split is used to create train and test partitions from each of the three data sets.
Training sets are sampled to simulate the various levels of class imbalance, and test
sets are held constant for evaluation purposes. All validation and hyperparameter
tuning are executed on random partitions of the training data. After configuring
hyperparameters, each model is fit to the training data and scored on the test set with
30 repetitions. This repetition accounts for any variation in results that may be caused
by random sampling and allows for added statistical analysis.
All experiments are performed on a high-performance computing environment
running Scientific Linux 7.4 (Nitrogen) [81]. Neural networks are implemented using
the Keras [32] open-source deep learning library written in Python with the Tensor-
Flow [30] backend. The specific library implementations used in this study are the

default configurations of Keras 2.1.6-tf and TensorFlow 1.10.0. The scikit-learn


package [82] (version 0.21.1) is used for preprocessing data.
The remainder of this section describes (1) the DNN architectures used for each
data set, (2) the data sampling methods used to vary class imbalance levels, (3) the
Optimal and Prior thresholding strategy procedures, and (4) performance evaluation
criteria.

4.1 Baseline Models

Baseline architectures and hyperparameters were discovered through a random


search procedure. For each data set, we set aside a random 10% of the fit data for vali-
dation, trained models on the remaining 90%, and then scored them on the validation
set. This process was repeated 10 times for each hyperparameter configuration. The
number of hidden layers, the number of neurons per layer, and regularization tech-
niques were the primary focus of hyperparameter tuning. Experiments were restricted
to deep fully connected models, i.e. neural networks containing two or more hidden
layers. We first sought a model with sufficient capacity to learn the training data and
then applied regularization techniques to reduce overfitting and improve generaliza-
tion to validation sets. We used the area under the Receiver Operating Characteristic
curve (ROC AUC) [83] performance metric to assess validation results. We prefer
the ROC AUC metric for comparing models because it is threshold agnostic. If one
model achieves a higher AUC score, then there exists an operating point (threshold)
that will also achieve higher class-wise performance.
Model validation results led us to select the following hyperparameter configu-
rations. Mini-batch stochastic gradient descent with mini-batch sizes of 256 is used
for all three data sets. This is preferred over batch gradient descent because it is
computationally expensive to compute the loss over the entire data set, and increas-
ing the number of samples that contribute to the gradient provides less than linear
returns [27]. It has also been suggested that smaller batch sizes offer a regularization
effect by introducing noise into the learning process [84]. We employ an advanced
form of stochastic gradient descent (SGD) that adapts parameter-specific learning
rates through training, i.e. the Adam optimizer, as it has been shown to outperform
other popular optimizers [85]. The default learning rate (lr = 0.001) is used along
with default moment estimate decay rates of β1 = 0.9 and β2 = 0.999. The Rectified
Linear Unit (ReLU) activation function is used in all hidden layer neurons, and the
sigmoid activation function is used at the output layer to estimate posterior probabil-
ities [86]. The non-saturating ReLU activation function has been shown to alleviate
the vanishing gradient problem and allow for faster training [35].
Network topologies were defined by first iteratively increasing architecture depth
and width while monitoring training and validation performance. For both Medicare
data sets, we determined that two hidden layers containing 32 neurons per layer pro-
vided sufficient capacity to overfit the model to the training data. For the ECBDL’14
data set, a larger network with four hidden layers containing between 128 and 32

Table 4 Medicare Part B two-layer architecture


Layer type # of neurons # of parameters
Input 125 0
Dense 32 4032
Batch normalization 32 128
ReLU activation 32 0
Dropout P = 0.5 32 0
Dense 32 1056
Batch normalization 32 128
ReLU activation 32 0
Dropout P = 0.5 32 0
Dense 1 33
Sigmoid activation 1 0

neurons each was required to fit the training data. We then explored regularization
techniques to eliminate overfitting and improve validation performance. One way to
reduce overfitting is to reduce the total number of learnable parameters, i.e. reduc-
ing network depth or width. L1 or L2 regularization methods, or weight decay, add
parameter penalties to the objective function that constrain the network’s weights
to lie within a region that is defined by a coefficient α [27]. Dropout simulates the
ensembling of many models by randomly disabling non-output neurons with a prob-
ability P ∈ [0, 1] during each iteration, preventing neurons from co-adapting and
forcing the model to learn more robust features [87]. Although originally designed
to address internal covariate shift and speed up training, batch normalization has also
been shown to add regularizing effects to neural networks [88]. Batch normalization
is similar to normalizing input data to have a fixed mean and variance, except that it
normalizes the inputs to hidden layers across each batch. We found that a combination of
dropout and batch normalization yielded the best performance for all three data sets. For the
Medicare models, we use a dropout rate of P = 0.5 and for the ECBDL’14 data set
we use a dropout rate of P = 0.8. Batch normalization is applied before the activation
function in each hidden unit.
Table 4 describes the two-layer baseline architecture for the Medicare Part B data
set. To determine how the number of hidden layers affects performance, we extended
this model to four hidden layers following the same pattern, i.e. using 32 neuron lay-
ers, batch normalization, ReLU activations, and dropout in each hidden layer. We did
not find it necessary to select new hyperparameters for the Medicare Part D data set.
Instead, we just changed the size of the input layer to match the total number of fea-
tures in each respective data set. The architecture for the ECBDL’14 data set follows
this same basic pattern but contains four hidden layers with 128, 128, 64, and 32 neu-
rons in each consecutive layer. With the increased feature count and network width,
the ECBDL’14 network contains 54 K tunable parameters and is approximately 10×
larger than the two-layer architecture used in Medicare experiments.
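
A minimal Keras sketch of the two-layer baseline summarized in Table 4 is given below. It assumes a recent tf.keras API (the study used Keras 2.1.6-tf with TensorFlow 1.10, where the Adam argument is lr rather than learning_rate) and is not the authors' exact code.

from tensorflow import keras
from tensorflow.keras import layers

def build_two_layer_dnn(n_features=125):
    # Two hidden blocks of Dense -> BatchNorm -> ReLU -> Dropout, then a sigmoid output.
    model = keras.Sequential()
    model.add(keras.Input(shape=(n_features,)))
    for _ in range(2):
        model.add(layers.Dense(32))
        model.add(layers.BatchNormalization())      # applied before the activation
        model.add(layers.Activation("relu"))
        model.add(layers.Dropout(0.5))
    model.add(layers.Dense(1, activation="sigmoid"))  # posterior probability estimate
    model.compile(
        optimizer=keras.optimizers.Adam(learning_rate=0.001, beta_1=0.9, beta_2=0.999),
        loss="binary_crossentropy")
    return model

model = build_two_layer_dnn()
# Mini-batch training with batches of 256 would follow, e.g.:
# model.fit(X_train, y_train, batch_size=256, epochs=50, validation_split=0.1)
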

4.2 Data Sampling Strategies

We use data sampling to alter training distributions and evaluate thresholding strate-
gies across a wide range of class imbalance levels. The ROS method randomly dupli-
cates samples from the minority class until the desired positive class prior is achieved.
RUS randomly removes samples from the majority class without replacement until
the desired level of imbalance is reached. The hybrid ROS-RUS method first under-
samples from the majority class without replacement and then over-samples the
minority class until classes are balanced. A combination of ROS, RUS, and ROS-
RUS is used to create 17 new training distributions from each Medicare data set and
12 new training distributions from the ECBDL’14 data set.
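
The following NumPy sketch illustrates one way such distributions can be generated; it is an illustrative implementation of ROS, RUS, and hybrid ROS-RUS, not the authors' code, and the function names are our own.

import numpy as np

rng = np.random.default_rng(42)

def sample_to_prior(X, y, pos_prior, method="ros"):
    # Resample a binary training set (labels 0/1) to a target positive class prior.
    # "ros": duplicate minority samples (with replacement) until the prior is reached.
    # "rus": discard majority samples (without replacement) until the prior is reached.
    pos_idx, neg_idx = np.where(y == 1)[0], np.where(y == 0)[0]
    if method == "ros":
        n_pos = int(pos_prior * len(neg_idx) / (1 - pos_prior))
        pos_idx = rng.choice(pos_idx, size=n_pos, replace=True)
    elif method == "rus":
        n_neg = int(len(pos_idx) * (1 - pos_prior) / pos_prior)
        neg_idx = rng.choice(neg_idx, size=n_neg, replace=False)
    idx = rng.permutation(np.concatenate([pos_idx, neg_idx]))
    return X[idx], y[idx]

def ros_rus(X, y, keep_neg=0.25):
    # Hybrid ROS-RUS: e.g. keep 25% of the majority class (remove 75%), then
    # over-sample the minority class until the distribution is balanced 50:50.
    neg_idx = np.where(y == 0)[0]
    keep = rng.choice(neg_idx, size=int(keep_neg * len(neg_idx)), replace=False)
    mask = np.concatenate([np.where(y == 1)[0], keep])
    return sample_to_prior(X[mask], y[mask], pos_prior=0.5, method="ros")
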
Table 5 describes the 34 training distributions that were created from the Medicare
Part B and Part D data sets. Of these new distributions, 24 contain low to severe
levels of class imbalance and 10 have balanced positive and negative classes. The
ROS-RUS-1, ROS-RUS-2, and ROS-RUS-3 distributions use RUS to remove 50, 75, and
90% of the majority class, respectively, and then over-sample the minority class until both classes are
balanced 50:50.

Table 5 Description of medicare distributions


Distribution type  Positive prior  CMS Part B (nneg, npos)  CMS Part D (nneg, npos)
Baseline 0.0003 3,377,421 1,085 2,916,933 1,028
RUS-1 0.001 773,092 1,085 1,027,052 1,028
RUS-2 0.005 194,202 1,085 204,477 1,028
RUS-3 0.01 107,402 1,085 101,801 1,028
RUS-4 0.20 4,390 1,085 4,084 1,028
RUS-5 0.40 1,620 1,085 1,546 1,028
RUS-6 0.50 1,085 1,085 1,028 1,028
RUS-7 0.60 710 1,085 671 1,028
ROS-1 0.001 3,377,421 3,385 2,916,933 2,920
ROS-2 0.005 3,377,421 16,969 2,916,933 14,659
ROS-3 0.01 3,377,421 33,635 2,916,933 29,401
ROS-4 0.20 3,377,421 844,130 2,916,933 729,263
ROS-5 0.40 3,377,421 2,251,375 2,916,933 1,944,626
ROS-6 0.50 3,377,421 3,377,421 2,916,933 2,916,929
ROS-7 0.60 3,377,421 5,064,780 2,916,933 4,375,404
ROS-RUS-1 0.50 1,688,710 1,688,710 1,458,466 1,458,466
ROS-RUS-2 0.50 844,355 844,355 729,233 729,233
ROS-RUS-3 0.50 337,742 337,742 291,693 291,693

input : targets y, probability estimates p
output: optimal threshold
best_thresh ← curr_thresh ← max_gmean ← 0;
delta_thresh ← 0.0005;
while curr_thresh < 1.0 do
    ŷ ← ApplyThreshold(p, curr_thresh);
    tpr, tnr, gmean ← CalcPerformance(y, ŷ);
    if tpr < tnr then
        return best_thresh;
    end
    if gmean > max_gmean then
        max_gmean ← gmean;
        best_thresh ← curr_thresh;
    end
    curr_thresh ← curr_thresh + delta_thresh;
end
return best_thresh;
Algorithm 1: Optimal threshold procedure

The original ECBDL’14 data set has a positive class size of 2% and is not classified
as highly imbalanced. Therefore, we first simulate two highly imbalanced distribu-
tions by combining the entire majority class with two subsets of the minority class.
By randomly under-sampling the minority class, we achieve two new distributions
that have positive class sizes of 1% and 0.5%. We create additional distributions with
positive class sizes of 5, 10, 20, 30, 40, 50, 60, 70, 80, and 90% by using RUS to
reduce the size of the negative class. As a result, we are able to evaluate thresholding
strategies on 12 class-imbalanced distributions of ECBDL’14 data and one balanced
distribution of ECBDL’14 data. The size of each positive and negative class can be
inferred from these strategies and the training data sizes from Table 3.

4.3 Thresholding Strategies

The Optimal threshold strategy is used to approximately balance the TPR and TNR.
We accomplish this by using a range of decision thresholds to score models on
validation and training data and then select the threshold which maximizes the G-
Mean. We also add a constraint that the Optimal threshold yields a TPR that is
greater than the TNR, because we are more interested in detecting positive instances
than negative instances. Threshold selection can be modified to optimize any other
performance metric, e.g. precision, and the metrics used to optimize thresholds should
ultimately be guided by problem requirements. If false positives are very costly, for
example, then a threshold that maximizes TNR would be more appropriate. The
performance metrics used in this study are explained in Sect. 4.4 and the procedure
used to compute these Optimal thresholds is defined in Algorithm 1. Once model
training is complete, Algorithm 1 takes ground truth labels and probability estimates
from the trained model, iterates over a range of possible threshold values, and returns
the threshold that maximizes the G-Mean.
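
For reference, Algorithm 1 can be transcribed almost line-for-line into Python; the NumPy-based helper below is our illustrative sketch, and the function and variable names are not from the original implementation.

import numpy as np

def optimal_threshold(y_true, probs, delta=0.0005):
    # Sweep thresholds and return the one that maximizes the G-Mean,
    # subject to the constraint that TPR stays at or above TNR.
    best_thresh, max_gmean, thresh = 0.0, 0.0, 0.0
    pos, neg = (y_true == 1), (y_true == 0)
    while thresh < 1.0:
        y_pred = (probs > thresh).astype(int)
        tpr = np.mean(y_pred[pos] == 1) if pos.any() else 0.0
        tnr = np.mean(y_pred[neg] == 0) if neg.any() else 0.0
        if tpr < tnr:                 # stop once TPR drops below TNR
            return best_thresh
        gmean = np.sqrt(tpr * tnr)
        if gmean > max_gmean:
            max_gmean, best_thresh = gmean, thresh
        thresh += delta
    return best_thresh
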

For Medicare experiments, Optimal decision thresholds are computed during the
validation phase. The validation step for each training distribution entails training
10 models and scoring them on random 10% partitions of the fit data, i.e. validation
sets. Optimal thresholds are computed on each validation set, averaged, and then the
average from each distribution is used to score models on the test set. While this
use of validation data should reduce the risk of overfitting, we do not use validation
data to compute Optimal thresholds for ECBDL’14 experiments. Instead, ECBDL’14
Optimal thresholds are computed by maximizing performance on fit data. Models are
trained for 50 epochs using all fit data, and then Algorithm 1 searches for the threshold
that maximizes performance on the training labels and corresponding model predic-
tions. Unlike the Medicare Optimal thresholds, the ECBDL’14 Optimal thresholds
can be computed with just one extra pass over the training data and do not require
validation partitions. This is beneficial when training data is limited or when the
positive class is absolutely rare.
The Prior thresholding strategy estimates the positive class prior from the training
data, i.e. p(cpos ) from Eq. 1, and uses its value to assign class labels to posterior
scores on the test set. Given a training set with p(cpos ) = 0.1, for example, the Prior
thresholding strategy will assign all test samples with probability scores > 0.1 to
the positive class and those with scores ≤ 0.1 to the negative class. Since the Prior
threshold can be calculated from the training data, and no optimization is required,
we believe it is a good candidate for preliminary experiments with imbalanced data.
Due to time constraints, this method was only explored using ECBDL’14 data.
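
A minimal sketch of the Prior thresholding strategy, using toy labels and hypothetical model outputs, is shown below.

import numpy as np

y_train = np.array([0, 0, 0, 0, 0, 0, 0, 0, 0, 1])   # toy training labels, 10% positive
test_probs = np.array([0.02, 0.15, 0.55, 0.08])       # hypothetical model outputs

prior = y_train.mean()                  # p(c_pos) estimated from the training data
test_preds = (test_probs > prior).astype(int)
print(prior, test_preds)                # 0.1 [0 1 1 0]
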

4.4 Performance Evaluation

The confusion matrix (Table 6) is created by comparing predicted labels to ground


truth labels, where predicted labels are dependent on model outputs and the decision
threshold. From the confusion matrix, we compute the TPR (Eq. 2), TNR (Eq. 3),
and G-Mean (Eq. 4) performance metrics. We compare results using the Default,
Optimal, and Prior thresholding strategies using 95% confidence intervals. From
these confidence intervals, we are able to determine which thresholding strategies
perform significantly better than others. We do not consider the ROC AUC metric
because it is threshold agnostic, and we do not consider accuracy or error rate because
they are misleading when working with imbalanced data.

TPR = Recall = TP / (TP + FN)    (2)

Table 6 Confusion matrix


Actual positive Actual negative
Predicted positive True positive (TP) False positive (FP)
Predicted negative False negative (FN) True negative (TN)

TNR = Selectivity = TN / (TN + FP)    (3)

G-Mean = √(TPR × TNR)    (4)
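
For illustration, these metrics can be computed from a confusion matrix in a few lines of Python; the sketch below uses scikit-learn's confusion_matrix with toy labels and is not code from the study.

import numpy as np
from sklearn.metrics import confusion_matrix

y_true = np.array([1, 1, 0, 0, 0, 0, 1, 0])   # toy ground truth labels
y_pred = np.array([1, 0, 0, 0, 1, 0, 1, 0])   # toy predicted labels

tn, fp, fn, tp = confusion_matrix(y_true, y_pred).ravel()
tpr = tp / (tp + fn)                 # Eq. 2: recall / true positive rate
tnr = tn / (tn + fp)                 # Eq. 3: selectivity / true negative rate
gmean = np.sqrt(tpr * tnr)           # Eq. 4: geometric mean
print(tpr, tnr, round(gmean, 3))
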

We also use Tukey’s HSD test (α = 0.05) to estimate the significance of


ECBDL’14 results. Tukey’s HSD test is a multiple comparison procedure that deter-
mines which method means are statistically different from each other by identifying
differences that are greater than the expected standard error. Result sets are assigned
to alphabetic groups based on the statistical difference of performance means, e.g.
group a performs significantly better than group b.
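
As an illustration of the procedure only (not of this study's results), Tukey's HSD test can be run with the statsmodels package on per-run scores grouped by strategy; the score values below are hypothetical placeholders.

import numpy as np
from statsmodels.stats.multicomp import pairwise_tukeyhsd

# Hypothetical G-Mean scores from repeated runs of three thresholding strategies.
scores = np.array([0.71, 0.72, 0.70, 0.75, 0.76, 0.74, 0.75, 0.74, 0.76])
groups = np.array(["default"] * 3 + ["optimal"] * 3 + ["prior"] * 3)

print(pairwise_tukeyhsd(scores, groups, alpha=0.05))
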

5 Results and Discussion

This section presents the DNN thresholding results that were obtained using the
Medicare and ECBDL’14 data sets. We begin by illustrating the relationship between
Optimal decision thresholds and positive class sizes using confidence intervals (C.I.)
and linear models. Next, Default and Optimal thresholds are used to compare G-Mean
scores across all Medicare distributions. ECBDL’14 results make similar compar-
isons and incorporate a third Prior thresholding strategy that proves effective. Finally,
a statistical analysis of TPR and TNR scores is used to estimate the significance of
each method’s results.

5.1 The Effect of Priors on Optimal Thresholds

Classification thresholds are optimized in Medicare Part B and Part D experiments by


training models and maximizing performance on a validation set using Algorithm 1.
To account for random error and enable statistical analysis, this process is repeated
10 times for each distribution and architecture pair. Results from the two-layer archi-
tecture are listed in Table 7. Confidence intervals (α = 0.05) are provided separately
for each Medicare data set, and bold-typed intervals indicate those which overlap the
Default threshold of 0.5.
Medicare results from the two-layer network suggest that the Optimal threshold
varies significantly with the positive class prior. For example, optimal thresholds
range from 0.0002 to 0.6478 as the positive class prior increases from 0.0003 to 0.6.
More specifically, most distributions have Optimal thresholds that are approximately
equal to their positive class prior. Following this pattern, we observe that Optimal
threshold intervals only overlap the Default threshold when the positive class prior
is equal to 0.5. We also observe that the Part B and Part D threshold intervals for
respective distributions overlap each other in 17 of 18 cases. This suggests that the

Table 7 Medicare optimal thresholds


Distribution type  Pos. class prior  Optimal threshold 95% C.I. (Medicare Part B)  Optimal threshold 95% C.I. (Medicare Part D)
Baseline 0.0003 (0.0002, 0.0003) (0.0002, 0.0004)
RUS-1 0.001 (0.0007, 0.0011) (0.0005, 0.0009)
RUS-2 0.005 (0.0059, 0.0069) (0.0049, 0.0056)
RUS-3 0.01 (0.0095, 0.0125) (0.0107, 0.0130)
RUS-4 0.2 (0.2502, 0.2858) (0.1998, 0.2516)
RUS-5 0.4 (0.3959, 0.4441) (0.4030, 0.4665)
RUS-6 0.5 (0.4704, 0.5236) (0.4690, 0.5326)
RUS-7 0.6 (0.5400, 0.6060) (0.5495, 0.6478)
ROS-1 0.001 (0.0005, 0.0009) (0.0005, 0.0008)
ROS-2 0.005 (0.0051, 0.0073) (0.0051, 0.0064)
ROS-3 0.01 (0.0087, 0.0132) (0.0112, 0.0139)
ROS-4 0.2 (0.2135, 0.2685) (0.1958, 0.2613)
ROS-5 0.4 (0.3691, 0.4469) (0.3409, 0.4197)
ROS-6 0.5 (0.4150, 0.4910) (0.3795, 0.4925)
ROS-7 0.6 (0.5169, 0.6091) (0.5189, 0.5707)
ROS-RUS-1 0.5 (0.4554, 0.5146) (0.3807, 0.4889)
ROS-RUS-2 0.5 (0.4940, 0.5497) (0.4111, 0.5203)
ROS-RUS-3 0.5 (0.4771, 0.5409) (0.4119, 0.4774)

Fig. 1 Positive class size versus optimal decision threshold

relationship between the positive class size and the Optimal threshold is both linear
and independent of the data set.
In Fig. 1, Medicare Optimal threshold results are grouped by architecture type and
plotted against the positive class size of the training distribution. Plots are enhanced
with horizontal jitter, and linear models are fit to the data using Ordinary Least
Squares [89] and 95% confidence bands. For both Medicare data sets and network
architectures, there is a strong linear relationship between the positive class size and
the Optimal decision threshold. The relationship is strongest for the two-layer
networks, with r² ≥ 0.980 and p ≤ 9.73e−145. The four-layer network results share
these linear characteristics, albeit weaker with r² ≤ 0.965 and visibly larger
confidence bands.

Fig. 2 ECBDL training epochs versus optimal thresholds
We also compute Optimal thresholds after each training epoch using ECBDL’14
data to determine how the Optimal threshold varies during model training. We trained
models for 50 epochs and repeated each experiment three times. As illustrated in
Fig. 2, the Optimal threshold is relatively stable and consistent throughout training.
Similar to Medicare results, the ECBDL’14 Optimal thresholds correspond closely
to the positive class prior of the training distribution.
This section concludes that the positive class size has a strong linear effect on the
Optimal decision threshold. We also found that the Optimal threshold for a given
distribution may vary between deep learning architectures. In the next section, we
consider the significance of these thresholds by using them to evaluate performance
on unseen test data.

5.2 Medicare Classification Results

Optimal classification threshold performance is first evaluated against Default threshold
performance using the Medicare Part B and Part D data sets. Threshold results
are compared over a range of class imbalance levels, i.e. 0.03–60%, using G-Mean,
TPR, and TNR performance metrics. G-Mean results from the two- and four-layer
networks are consolidated by aggregating on the distribution for each data set, and
TPR and TNR results are averaged across both Medicare data sets.

Table 8 Medicare Part B G-mean scores


Distribution type Pos. class prior G-mean 95% C.I. (Optimal threshold) G-mean 95% C.I. (Default threshold)
Baseline 0.0003 (0.7281, 0.7321) (0.0000, 0.0000)
RUS-1 0.001 (0.7206, 0.7280) (0.0000, 0.0000)
RUS-2 0.005 (0.7379, 0.7425) (0.0000, 0.0000)
RUS-3 0.01 (0.7351, 0.7415) (0.0000, 0.0000)
RUS-4 0.2 (0.7322, 0.7353) (0.0903, 0.1939)
RUS-5 0.4 (0.7171, 0.7253) (0.7307, 0.7333)
RUS-6 0.5 (0.7148, 0.7242) (0.7155, 0.7225)
RUS-7 0.6 (0.7109, 0.7199) (0.6542, 0.6798)
ROS-1 0.001 (0.7151, 0.7235) (0.0000, 0.0000)
ROS-2 0.005 (0.7459, 0.7543) (0.0000, 0.0000)
ROS-3 0.01 (0.7197, 0.7479) (0.0000, 0.0000)
ROS-4 0.2 (0.7449, 0.7649) (0.6070, 0.6466)
ROS-5 0.4 (0.7435, 0.7729) (0.7563, 0.7695)
ROS-6 0.5 (0.7665, 0.7719) (0.7563, 0.7695)
ROS-7 0.6 (0.7673, 0.7729) (0.7434, 0.7686)
ROS-RUS-1 0.5 (0.7576, 0.7754) (0.7488, 0.7744)
ROS-RUS-2 0.5 (0.7680, 0.7740) (0.7472, 0.7726)
ROS-RUS-3 0.5 (0.7506, 0.7744) (0.7497, 0.7749)

Tables 8 and 9 list the 95% G-Mean confidence intervals for all Medicare Part B
and Part D distributions, respectively. Intervals listed in bold indicate those which
are significantly greater than the alternative. Our first observation is that the Default
classification threshold of 0.5 never performs significantly better than the Optimal
threshold. In fact, the Default threshold only yields acceptable G-Mean scores when
classes are virtually balanced, e.g. priors of 0.4–0.6. In all other distributions, the per-
formance of the Default threshold degrades as the level of class imbalance increases.
The Optimal threshold, however, yields relatively stable G-Mean scores across all
distributions. Even the baseline distribution, with a positive class size of just 0.03%,
yields acceptable G-Mean scores > 0.72 when using an Optimal classification thresh-
old. Overall, these results discourage using the Default classification threshold when
training DNN models with class-imbalanced data.
Results also indicate an increase in G-Mean scores among the ROS and ROS-
RUS methods. This is not due to the threshold procedure, but rather, the under-
sampling procedure used to create the RUS distributions. Our previous work shows
that using RUS with these highly imbalanced big data classification tasks tends to
underrepresent the majority group and degrade performance [90].
Figure 3 presents the combined TPR and TNR scores for both Medicare data
sets and DNN architectures. Optimal classification thresholds (left) produce stable

Table 9 Medicare Part D G-mean scores


Distribution type Pos. class prior G-mean 95% C.I. (Optimal threshold) G-mean 95% C.I. (Default threshold)
Baseline 0.0003 (0.6986, 0.7022) (0.0000, 0.0000)
RUS-1 0.001 (0.7058, 0.7128) (0.0000, 0.0000)
RUS-2 0.005 (0.7262, 0.7300) (0.0000, 0.0000)
RUS-3 0.01 (0.7305, 0.7345) (0.0000, 0.0000)
RUS-4 0.2 (0.7052, 0.7120) (0.1932, 0.2814)
RUS-5 0.4 (0.6815, 0.6849) (0.6659, 0.6835)
RUS-6 0.5 (0.6870, 0.6928) (0.6870, 0.6926)
RUS-7 0.6 (0.6785, 0.6843) (0.6486, 0.6606)
ROS-1 0.001 (0.8058, 0.8128) (0.0000, 0.0000)
ROS-2 0.005 (0.7262, 0.7300) (0.0000, 0.0089)
ROS-3 0.01 (0.7305, 0.7345) (0.0030, 0.0220)
ROS-4 0.2 (0.7378, 0.7478) (0.6040, 0.6258)
ROS-5 0.4 (0.7424, 0.7494) (0.7337, 0.7405)
ROS-6 0.5 (0.7262, 0.7402) (0.7431, 0.7497)
ROS-7 0.6 (0.7450, 0.7516) (0.7395, 0.7481)
ROS-RUS-1 0.5 (0.7398, 0.7490) (0.7463, 0.7537)
ROS-RUS-2 0.5 (0.7453, 0.7515) (0.7487, 0.7547)
ROS-RUS-3 0.5 (0.7456, 0.7532) (0.7479, 0.7535)

Fig. 3 Medicare class-wise performance

TPR and TNR scores across all positive class sizes. Furthermore, we observe that the
TPR is always greater than the TNR when using the Optimal classification threshold.
This suggests that our threshold selection procedure (Algorithm 1) is effective, and
that we can expect performance trade-offs optimized during validation to generalize
to unseen test data. Default threshold results (right), however, are unstable as the
positive class size varies. When class imbalance levels are high, for example, the
Default threshold assigns all test samples to the negative class. It is only when classes
are mostly balanced that the Default threshold achieves high TPR and TNR. Even

when Default threshold performance is reasonably well balanced, e.g. positive class
sizes of 40%, we lose the ability to maximize TPR over TNR.
In summary, two highly imbalanced Medicare data sets were used to compare
DNN prediction performance using Optimal classification thresholds and Default
classification thresholds. For each data set, Part B and Part D, 18 distributions were
created using ROS and RUS to cover a wide range of class imbalance levels, i.e. 0.03–
60%. For each of these 36 distributions, 30 two-layer networks and 30 four-layer
networks were trained and scored on test sets. With evidence from over 2,000 DNN
models, statistical results show that the Default threshold is suboptimal whenever
models are trained with imbalanced data. Even in the most severely imbalanced
distributions, e.g. positive class size of 0.03%, scoring with an Optimal threshold
yields consistently favorable G-Mean, TPR, and TNR scores. In the next section, we
expand on these results with the ECBDL’14 data set and consider a third thresholding
method, the Prior threshold.

5.3 ECBDL’14 Classification Results

ECBDL’14 experiments incorporate a third thresholding strategy, the Prior threshold,


which uses the prior probability of the positive class from the training distribution
as the classification threshold. We first present G-Mean scores from each respec-
tive thresholding strategy over a wide range of class imbalance levels, i.e. 0.5–90%.
G-Mean, TPR, and TNR results are then averaged across all distributions and sum-
marized using Tukey’s HSD test.
Figure 4 illustrates the G-Mean score for each thresholding strategy and distri-
bution. Similar to Medicare results, the Default threshold performance is acceptable
when classes are mostly balanced, e.g. positive class sizes of 40–60%, but deteriorates

Fig. 4 ECBDL’14 G-mean results



Table 10 ECBDL’14 class-wise performance and HSD groups


Threshold strategy Geometric mean (Mean, Std., Group) True positive rate (Mean, Std., Group) True negative rate (Mean, Std., Group)
Default 0.4421 0.28 b 0.4622 0.37 b 0.7823 0.26 a
Optimal 0.7333 0.01 a 0.7276 0.02 a 0.7399 0.03 b
Prior 0.7320 0.02 a 0.7195 0.04 a 0.7470 0.05 b

quickly when the positive class prior is ≥ 0.7 or ≤ 0.3. Unlike the Default thresh-
old, the Optimal threshold yields relatively stable G-Mean scores (≥ 0.7) across all
training distributions. Most interestingly, the Prior threshold also yields stable G-
Mean scores across all training distributions. In addition to performing on par with the
Optimal threshold, the Prior threshold strategy has the advantage of being derived
directly from the training data without requiring optimization. Most importantly, we
were able to achieve TPR and TNR scores on par with those of the competition
winners [73] without the added costs of over-sampling.
Table 10 lists ECBDL’14 G-Mean, TPR, and TNR scores averaged across all
training distributions. Tukey’s HSD groups are used to identify results with signif-
icantly different means, i.e. group a performs significantly better than group b. For
the G-Mean metric, the Optimal and Prior threshold methods are placed into group
a with mean scores of 0.7333 and 0.7320, respectively. The Default threshold is
placed into group b with a mean score of 0.4421. These results indicate that the
Optimal and Prior thresholds balance class-wise performance significantly better
than the Default threshold. Equivalently, TPR results place the Optimal and Prior
threshold methods into group a and the Default threshold into group b. Put another
way, the non-default threshold strategies are significantly better at capturing the pos-
itive class of interest. We expect this behavior from the Optimal threshold, as the
threshold selection procedure that we employed (Algorithm 1) explicitly optimizes
thresholds by maximizing TPR and G-Mean scores on the training data. The Prior
method, however, was derived directly from the training data with zero training or
optimization, and surprisingly, performed equally as well as the Optimal threshold.
We believe these qualities make the Prior thresholding strategy a great candidate
for preliminary experiments and baseline models. If requirements call for specific
class-wise performance trade-offs, the Prior threshold can still offer an approximate
baseline threshold to begin optimization.
On further inspection, we see that the Default threshold achieves the highest
average TNR score. While the negative class is typically not the class of interest,
it is still important to minimize false positive predictions. The Default threshold’s
TNR score is misleadingly high, however, and is a result of having more imbalanced
distributions with positive class sizes < 0.5 than there are > 0.5. Recall from Fig. 3
that when the positive class prior is small, models tend to assign all test samples
to the negative class. The Default threshold scores highly on TNR because most
of the models trained with imbalanced data are assigning all test samples to the

negative class and achieving a 100% TNR. Therefore, we rely on the G-Mean scores
to ensure class-wise performance is balanced and conclude that the Default threshold
is suboptimal when high class imbalance exists within the training data.
The ECBDL’14 results presented in this section align with those from the Medicare
experiments. For all imbalanced training distributions, Optimal thresholds consis-
tently outperform the Default threshold. The Prior threshold, although not optimized,
performed statistically as well as the Optimal threshold by all performance criteria.

6 Conclusion

This chapter explored the effects of highly imbalanced big data on training and scor-
ing deep neural network classifiers. We trained models on a wide range of class
imbalance levels (0.03–90%) and compared the results of two output thresholding
strategies to Default threshold results. The Optimal threshold technique used train-
ing or validation data to find thresholds that maximize the G-Mean performance
metric. The Prior threshold technique uses the positive class prior as the classifica-
tion threshold for assigning labels to test instances. As suggested by Bayes theorem
(Eq. 1), we found that all Optimal thresholds are proportional to the positive class
priors of the training data. As a result, the Default threshold of 0.5 only performed
well when the training data was relatively balanced, i.e. positive priors between
0.4–0.6. For all other distributions, the Optimal and Prior thresholding strategies
performed significantly better based on the G-Mean criterion. Furthermore, Tukey’s
HSD test results suggest that there is no difference between Optimal and Prior thresh-
old results. These Optimal threshold results are dependent on the threshold selection
criteria (Algorithm 1). This Optimal threshold procedure should be guided by the
classification task requirements, and selecting a different performance criterion may yield
Optimal thresholds that are significantly different from the Prior threshold.
Future works should evaluate these thresholding strategies across a wider range of
domains and network architectures, e.g. natural language processing and computer
vision. Additionally, the threshold selection procedure should be modified to opti-
mize alternative performance metrics, and statistical tests should be used to identify
significant differences.

References

1. W. Wei, J. Li, L. Cao, Y. Ou, J. Chen, Effective detection of sophisticated online banking fraud
on extremely imbalanced data. World Wide Web 16, 449–475 (2013)
2. A.N. Richter, T.M. Khoshgoftaar, Sample size determination for biomedical big data with
limited labels. Netw. Model. Anal. Health Inf. Bioinf. 9, 1–13 (2020)
3. M. Kubat, R.C. Holte, S. Matwin, Machine learning for the detection of oil spills in satellite
radar images. Mach. Learn. 30, 195–215 (1998)

4. S. Wang, X. Yao, Multiclass imbalance problems: analysis and potential solutions. IEEE Trans.
Syst. Man Cyb. Part B (Cybern.) 42, 1119–1130 (2012)
5. N. Japkowicz, The class imbalance problem: significance and strategies, in Proceedings of the
International Conference on Artificial Intelligence (2000)
6. M. Buda, A. Maki, M.A. Mazurowski, A systematic study of the class imbalance problem in
convolutional neural networks. Neural Netw. 106, 249–259 (2018)
7. H. He, E.A. Garcia, Learning from imbalanced data, IEEE Trans. Knowl. Data Eng. 21, 1263–
1284 (2009)
8. G.M. Weiss, Mining with rarity: a unifying framework. SIGKDD Explor. Newsl. 6, 7–19 (2004)
9. R.A. Bauder, T.M. Khoshgoftaar, T. Hasanin, An empirical study on class rarity in big data,
in 2018 17th IEEE International Conference on Machine Learning and Applications (ICMLA)
(2018), pp. 785–790
10. E. Dumbill, What is big data? an introduction to the big data landscape (2012). http://radar.
oreilly.com/2012/01/what-is-big-data.html
11. S.E. Ahmed, Perspectives on Big Data Analysis: methodologies and Applications (Amer Math-
ematical Society, USA, 2014)
12. J.L. Leevy, T.M. Khoshgoftaar, R.A. Bauder, N. Seliya, A survey on addressing high-class
imbalance in big data. J. Big Data 5, 42 (2018)
13. J. Dean, S. Ghemawat, Mapreduce: simplified data processing on large clusters. Commun.
ACM 51, 107–113 (2008)
14. M. Zaharia, M. Chowdhury, M. J. Franklin, S. Shenker, I. Stoica, Spark: cluster computing
with working sets, in Proceedings of the 2Nd USENIX Conference on Hot Topics in Cloud
Computing, HotCloud’10, (Berkeley, CA, USA), USENIX Association (2010), p. 10
15. K. Chahal, M. Grover, K. Dey, R.R. Shah, A hitchhiker’s guide on distributed training of deep
neural networks. J. Parallel Distrib. Comput. 10 (2019)
16. R.K.L. Kennedy, T.M. Khoshgoftaar, F. Villanustre, T. Humphrey, A parallel and distributed
stochastic gradient descent implementation using commodity clusters. J. Big Data 6(1), 16
(2019)
17. D.L. Wilson, Asymptotic properties of nearest neighbor rules using edited data. IEEE Trans.
Syst. Man Cybern. SMC-2, 408–421 (1972)
18. N.V. Chawla, K.W. Bowyer, L.O. Hall, W.P. Kegelmeyer, Smote: synthetic minority over-
sampling technique. J. Artif. Int. Res. 16, 321–357 (2002)
19. H. Han, W.-Y. Wang, B.-H. Mao, Borderline-smote: a new over-sampling method in imbalanced
data sets learning, in Advances in Intelligent Computing ed. by D.-S. Huang, X.-P. Zhang, G.-B.
Huang (Springer, Berlin, Heidelberg, 2005), pp. 878–887
20. T. Jo, N. Japkowicz, Class imbalances versus small disjuncts. SIGKDD Explor. Newsl. 6, 40–49
(2004)
21. C. Ling, V. Sheng, Cost-sensitive learning and the class imbalance problem, in Encyclopedia
of Machine Learning (2010)
22. J.J Chen, C.-A. Tsai, H. Moon, H. Ahn, J.J. Young, C.-H. Chen, Decision threshold adjustment
in class prediction, in SAR and QSAR in Environmental Research, vol. 17 (2006), pp. 337–352
23. Q. Zou, S. Xie, Z. Lin, M. Wu, Y. Ju, Finding the best classification threshold in imbalanced
classification. Big Data Res. 5, 2–8 (2016)
24. X. Liu, J. Wu, Z. Zhou, Exploratory undersampling for class-imbalance learning. IEEE Trans.
Syst. Man Cybern. Part B (Cybern.) 39, 539–550 (2009)
25. N.V. Chawla, A. Lazarevic, L.O. Hall, K.W. Bowyer, Smoteboost: improving prediction of
the minority class in boosting, in Knowledge Discovery in Databases: PKDD 2003 ed. by
N. Lavrač, D. Gamberger, L. Todorovski, H. Blockeel, (Springer, Berlin, Heidelberg, 2003),
pp. 107–119
26. Y. Sun, Cost-sensitive Boosting for Classification of Imbalanced Data. Ph.D. thesis, Waterloo,
Ont., Canada, Canada, 2007. AAINR34548
27. I. Goodfellow, Y. Bengio, A. Courville, Deep Learning (The MIT Press, Cambridge, MA,
2016)

28. I.H. Witten, E. Frank, M.A. Hall, C.J. Pal, Data Mining, Fourth Edition: practical Machine
Learning Tools and Techniques, 4th edn. (San Francisco, CA, USA, Morgan Kaufmann Pub-
lishers Inc., 2016)
29. Y. LeCun, Y. Bengio, G. Hinton, Deep learning. Nature. 521, 436 (2015)
30. M. Abadi, A. Agarwal, P. Barham, E. Brevdo, Z. Chen, C. Citro, G.S. Corrado, A. Davis,
J. Dean, M. Devin, S. Ghemawat, I. Goodfellow, A. Harp, G. Irving, M. Isard, Y. Jia, R. Joze-
fowicz, L. Kaiser, M. Kudlur, J. Levenberg, D. Mané, R. Monga, S. Moore, D. Murray, C. Olah,
M. Schuster, J. Shlens, B. Steiner, I. Sutskever, K. Talwar, P. Tucker, V. Vanhoucke, V. Vasude-
van, F. Viégas, O. Vinyals, P. Warden, M. Wattenberg, M. Wicke, Y. Yu, X. Zheng, TensorFlow:
large-scale machine learning on heterogeneous systems (2015)
31. Theano Development Team, Theano: a python framework for fast computation of mathematical
expressions (2016). arXiv:abs/1605.02688
32. F. Chollet et al., Keras (2015). https://keras.io
33. A. Paszke, S. Gross, S. Chintala, G. Chanan, E. Yang, Z. DeVito, Z. Lin, A. Desmaison,
L. Antiga, A. Lerer, Automatic differentiation in pytorch, in NIPS-W (2017)
34. S. Chetlur, C. Woolley, P. Vandermersch, J. Cohen, J. Tran, B. Catanzaro, E. Shelhamer, cudnn:
efficient primitives for deep learning (2014)
35. A. Krizhevsky, I. Sutskever, G.E. Hinton, Imagenet classification with deep convolutional
neural networks. Neural Inform. Process. Syst. 25, 01 (2012)
36. M.D. Richard, R.P. Lippmann, Neural network classifiers estimate bayesian a posteriori prob-
abilities. Neural Comput. 3(4), 461–483 (1991)
37. Centers For Medicare & Medicaid Services, Medicare provider utilization and payment data:
physician and other supplier (2018)
38. Centers For Medicare & Medicaid Services, Medicare provider utilization and payment data:
part D prescriber (2018)
39. U.S. Government, U.S. Centers for Medicare & Medicaid Services, The official U.S. govern-
ment site for medicare
40. Evolutionary Computation for Big Data and Big Learning Workshop, Data mining competition
2014: self-deployment track
41. M. Wani, F. Bhat, S. Afzal, A. Khan, Advances in Deep Learning (Springer, 2020)
42. J.W. Tukey, Comparing individual means in the analysis of variance. Biometrics 5(2), 99–114
(1949)
43. R. Anand, K.G. Mehrotra, C.K. Mohan, S. Ranka, An improved algorithm for neural network
classification of imbalanced training sets. IEEE Trans. Neural Netw. 4, 962–969 (1993)
44. J.M. Johnson, T.M. Khoshgoftaar, Survey on deep learning with class imbalance. J. Big Data
6, 27 (2019)
45. J.M. Johnson, T.M. Khoshgoftaar, Medicare fraud detection using neural networks. J. Big Data
6(1), 63 (2019)
46. D. Masko, P. Hensman, The impact of imbalanced training data for convolutional neural net-
works, in 2015. KTH, School of Computer Science and Communication (CSC)
47. H. Lee, M. Park, J. Kim, Plankton classification on imbalanced large scale database via con-
volutional neural networks with transfer learning, in 2016 IEEE International Conference on
Image Processing (ICIP) (2016), pp. 3713–3717
48. S. Wang, W. Liu, J. Wu, L. Cao, Q. Meng, P. J. Kennedy, Training deep neural networks on
imbalanced data sets, in 2016 International Joint Conference on Neural Networks (IJCNN)
(2016), pp. 4368–4374
49. H. Wang, Z. Cui, Y. Chen, M. Avidan, A. B. Abdallah, A. Kronzer, Predicting hospital read-
mission via cost-sensitive deep learning. IEEE/ACM Trans. Comput. Biol. Bioinf. 1 (2018)
50. S.H. Khan, M. Hayat, M. Bennamoun, F.A. Sohel, R. Togneri, Cost-sensitive learning of deep
feature representations from imbalanced data. IEEE Trans. Neural Netw. Learn. Syst. 29, 3573–
3587 (2018)
51. T.-Y. Lin, P. Goyal, R. B. Girshick, K. He, P. Dollár, Focal loss for dense object detection, 2017
IEEE International Conference on Computer Vision (ICCV) (2017), pp. 2999–3007

52. C. Huang, Y. Li, C. C. Loy, X. Tang, Learning deep representation for imbalanced classifica-
tion, in 2016 IEEE Conference on Computer Vision and Pattern Recognition (CVPR) (2016),
pp. 5375–5384
53. S. Ando, C.Y. Huang, Deep over-sampling framework for classifying imbalanced data, in
Machine Learning and Knowledge Discovery in Databases, ed. by M. Ceci, J. Hollmén,
L. Todorovski, C. Vens, S. Džeroski (Springer International Publishing, Cham, 2017), pp. 770–
785
54. Q. Dong, S. Gong, X. Zhu, Imbalanced deep learning by minority class incremental rectifica-
tion. IEEE Trans. Pattern Anal. Mach. Intell. 1 (2018)
55. Q. Chen, J. Huang, R. Feris, L.M. Brown, J. Dong, S. Yan, Deep domain adaptation for describ-
ing people based on fine-grained clothing attributes, in 2015 IEEE Conference on Computer
Vision and Pattern Recognition (CVPR) (2015), pp. 5315–5324
56. Y. LeCun, C. Cortes, MNIST handwritten digit database (2010). http://yann.lecun.com/exdb/
mnist/, Accessed 15 Nov 2018
57. A. Krizhevsky, V. Nair, G. Hinton, Cifar-10 (canadian institute for advanced research). http://
www.cs.toronto.edu/kriz/cifar.html
58. R.A. Bauder, T.M. Khoshgoftaar, A novel method for fraudulent medicare claims detection from
expected payment deviations (application paper), in 2016 IEEE 17th International Conference
on Information Reuse and Integration (IRI) (2016), pp. 11–19
59. R.A. Bauder, T.M. Khoshgoftaar, A probabilistic programming approach for outlier detection
in healthcare claims, in 2016 15th IEEE International Conference on Machine Learning and
Applications (ICMLA) (2016), pp. 347–354
60. R.A. Bauder, T.M. Khoshgoftaar, A. Richter, M. Herland, Predicting medical provider spe-
cialties to detect anomalous insurance claims, in 2016 IEEE 28th International Conference on
Tools with Artificial Intelligence (ICTAI) (2016), pp. 784–790
61. M. Herland, R.A. Bauder, T.M. Khoshgoftaar, Medical provider specialty predictions for the
detection of anomalous medicare insurance claims, in 2017 IEEE International Conference on
Information Reuse and Integration (IRI) (2017), pp. 579–588
62. Office of Inspector General, LEIE downloadable databases (2019)
63. R.A. Bauder, T.M. Khoshgoftaar, The detection of medicare fraud using machine learning
methods with excluded provider labels, in FLAIRS Conference (2018)
64. M. Herland, T.M. Khoshgoftaar, R.A. Bauder, Big data fraud detection using multiple medicare
data sources. J. Big Data 5, 29 (2018)
65. M. Herland, R.A. Bauder, T.M. Khoshgoftaar, The effects of class rarity on the evaluation of
supervised healthcare fraud detection models. J. Big Data 6(1), 21 (2019)
66. K. Feldman, N.V. Chawla, Does medical school training relate to practice? evidence from big
data. Big Data (2015)
67. Centers for Medicare & Medicaid Services, Physician compare datasets (2019)
68. J. Ko, H. Chalfin, B. Trock, Z. Feng, E. Humphreys, S.-W. Park, B. Carter, K.D. Frick, M. Han,
Variability in medicare utilization and payment among urologists. Urology 85, 03 (2015)
69. V. Chandola, S.R. Sukumar, J.C. Schryver, Knowledge discovery from massive healthcare
claims data, in KDD (2013)
70. L.K. Branting, F. Reeder, J. Gold, T. Champney, Graph analytics for healthcare fraud risk
estimation, in 2016 IEEE/ACM International Conference on Advances in Social Networks
Analysis and Mining (ASONAM) (2016), pp. 845–851
71. National Plan & Provider Enumeration System, NPPES NPI registry (2019)
72. P.S.P. Center, 9th community wide experiment on the critical assessment of techniques for
protein structure prediction
73. I. Triguero, S. del Río, V. López, J. Bacardit, J. Benítez, F. Herrera, ROSEFW-RF: the winner
algorithm for the ECBDL'14 big data competition: an extremely imbalanced big data bioinformatics
problem. Knowl.-Based Syst. 87 (2015)
74. A. Fernández, S. del Río, N.V. Chawla, F. Herrera, An insight into imbalanced big data classi-
fication: outcomes and challenges. Complex Intell. Syst. 3(2), 105–120 (2017)

75. S. del Río, J.M. Benítez, F. Herrera, Analysis of data preprocessing increasing the over-
sampling ratio for extremely imbalanced big data classification, in 2015 IEEE Trust-
com/BigDataSE/ISPA, vol. 2 (2015), pp. 180–185
76. Centers for Medicare & Medicaid Services, National provider identifier standard (NPI) (2019)
77. Centers For Medicare & Medicaid Services, HCPCS general information (2018)
78. P. Di Lena, K. Nagata, P. Baldi, Deep architectures for protein contact map prediction, Bioin-
formatics (Oxford, England) 28, 2449–57 (2012)
79. J. Berg, J. Tymoczko, L. Stryer, Chapter 3, protein structure and function, in Biochemistry, 5th
edn. (W H Freeman, New York, 2002)
80. Z. Zhao, F. Morstatter, S. Sharma, S. Alelyani, A. Anand, H. Liu, Advancing feature selection
research, ASU Feature Selection Repository (2010), pp. 1–28
81. S. Linux, About (2014)
82. F. Pedregosa, G. Varoquaux, A. Gramfort, V. Michel, B. Thirion, O. Grisel, M. Blondel, P.
Prettenhofer, R. Weiss, V. Dubourg, J. Vanderplas, A. Passos, D. Cournapeau, M. Brucher,
M. Perrot, E. Duchesnay, Scikit-learn: machine learning in python. J. Mach. Learn. Res. 12,
2825–2830 (2011)
83. F. Provost, T. Fawcett, Analysis and visualization of classifier performance: comparison under
imprecise class and cost distributions, in Proceedings of the Third International Conference
on Knowledge Discovery and Data Mining, vol. 43–48 (1999), p. 12
84. D. Wilson, T. Martinez, The general inefficiency of batch training for gradient descent learning.
Neural Netw.: Off. J. Int. Neural Netw. Soc. 16, 1429–51 (2004)
85. D.P. Kingma, J. Ba, Adam: a method for stochastic optimization. CoRR (2015).
arXiv:abs/1412.6980
86. R.P. Lippmann, Neural networks, bayesian a posteriori probabilities, and pattern classification,
in From Statistics to Neural Networks, ed. by V. Cherkassky, J.H. Friedman, H. Wechsler
(Springer, Berlin, Heidelberg, 1994), pp. 83–104
87. N. Srivastava, G. Hinton, A. Krizhevsky, I. Sutskever, R. Salakhutdinov, Dropout: a simple
way to prevent neural networks from overfitting. J. Mach. Learn. Res. 15, 1929–1958 (2014)
88. S. Ioffe, C. Szegedy, Batch normalization: accelerating deep network training by reducing
internal covariate shift, in Proceedings of the 32Nd International Conference on International
Conference on Machine Learning, ICML’15, vol. 37 (JMLR.org, 2015), pp. 448–456
89. B. Zdaniuk, Ordinary Least-Squares (OLS) Model (Dordrecht, Springer Netherlands, 2014),
pp. 4515–4517
90. J.M. Johnson, T.M. Khoshgoftaar, Deep learning and data sampling with imbalanced big data,
2019 IEEE 20th International Conference on Information Reuse and Integration for Data
Science (IRI) (2019), pp. 175–183
Vehicular Localisation at High and Low
Estimation Rates During GNSS Outages:
A Deep Learning Approach

Uche Onyekpe, Stratis Kanarachos, Vasile Palade,


and Stavros-Richard G. Christopoulos

Abstract Road localisation of autonomous vehicles is reliant on consistent accurate


GNSS (Global Navigation Satellite System) positioning information. Commercial
GNSS receivers usually sample at 1 Hz, which is not sufficient to robustly and accu-
rately track a vehicle in certain scenarios, such as driving on the highway, where
the vehicle could travel at medium to high speeds, or in safety-critical scenarios. In
addition, the GNSS relies on a number of satellites to perform triangulation and may
experience signal loss around tall buildings, bridges, tunnels and trees. An approach
to overcoming this problem involves integrating the GNSS with a vehicle-mounted
Inertial Navigation Sensor (INS) system to provide a continuous and more reliable
high rate positioning solution. INSs are however plagued by unbounded exponential
error drifts during the double integration of the acceleration to displacement. Several
deep learning algorithms have been employed to learn the error drift for a better posi-
tioning prediction. We therefore investigate in this chapter the performance of Long
Short-Term Memory (LSTM), Input Delay Neural Network (IDNN), Multi-Layer
Neural Network (MLNN) and Kalman Filter (KF) for high data rate positioning. We
show that Deep Neural Network-based solutions can exhibit better performances for

U. Onyekpe (B)
Research Center for Data Science, Institute for Future Transport and Cities, Coventry University,
Gulson Road, Coventry, UK
e-mail: onyekpeu@uni.coventry.ac.uk
S. Kanarachos
Faculty of Engineering, Coventry University, Gulson Road, Coventry, UK
e-mail: ab8522@coventry.ac.uk
V. Palade
Research Center for Data Science, Coventry University, Gulson Road, Coventry, UK
e-mail: ab5839@coventry.ac.uk
S.-R. G. Christopoulos
Institute for Future Transport and Cities, Faculty of Engineering, Coventry University, Gulson
Road, Coventry, UK
e-mail: ac0966@coventry.ac.uk

© The Editor(s) (if applicable) and The Author(s), under exclusive license
to Springer Nature Singapore Pte Ltd. 2021
M. A. Wani et al. (eds.), Deep Learning Applications, Volume 2,
Advances in Intelligent Systems and Computing 1232,
https://doi.org/10.1007/978-981-15-6759-9_10

high data rate positioning of vehicles in comparison to commonly used approaches


like the Kalman filter.

Keywords Inertial navigation · INS · INS/GPS-integrated navigation · GPS


outage · Autonomous vehicle navigation · Deep learning · Neural networks · High
sampling rate

1 Introduction

1.1 Importance of Autonomous Vehicles

It is estimated that the UK’s autonomous vehicle market will be worth an approximate
value of £28 billion by 2035 [1]. A major motivation towards the development of
these vehicles is the need to improve road safety. According to [2], 75% of traffic-
related road accidents in the UK are due to human driving errors, rising to 94%
in the United States. The introduction of autonomous vehicles has the potential to
reduce these accidents [3]. Even though these vehicles could introduce new kinds of
accidents, there is the drive to ensure they are as safe as possible.
Sensory systems are key to the performance of autonomous vehicles as they
help the vehicle understand its environment [4]. Examples of sensors found on the
outside of the vehicle include LIDARs, cameras and ultrasonic systems. Several data
processing and analysis systems are also present inside the vehicle, which use the
sensor data to make decisions a human driver would normally make. LIDARs and
cameras are imaging systems used to identify objects, structures, potential collision
hazards and pedestrians in the vehicle’s trajectory [5]. Cameras are furthermore
essential to the identification of road signs and markings on structured roads. As proof
of the crucial role imaging systems play in the operation of the vehicle, a good number
of sophisticated versions of such systems are already employed [6]. Nevertheless,
although imaging systems can be used to assess the vehicle’s environment as well
as determine the position of objects or markings relative to it, there is the need to
continuously and robustly localise a vehicle with reference to a defined co-ordinate
system. Information on how the vehicle navigates through its environment is also
needed such that real-time follow-up decisions can be made.

1.2 GNSS (Global Navigation Satellite System) Issues

A GNSS receiver performs complex analysis on the signals received from at least
four of the many satellites orbiting the earth and is known to be one of the best when
it comes to position estimation, as it has no competition in terms of cost or coverage
[7]. Despite the wide acceptance of GNSS, it is far from being a perfect positioning
system. There can be instances of GNSS failures in outdoor environments, as there has

to be a direct line of sight between the satellites and the GNSS antennae. GNSS can
prove difficult to use in metropolitan cities and similar environments characterised
by tall buildings, bridges, tunnels or trees, as its line of sight may be blocked during
signal transmission [7]. More so, GNSS signal can be jammed and this leaves the
vehicle with no information about its position [8]. As such, a GNSS cannot act as a
standalone navigation system.
The GNSS is used to localise the autonomous vehicle to a road. To achieve lane
accuracy, GNSS is combined with high accuracy LIDARs, cameras, RADAR and
High Definition (HD) maps. There are however times when the camera and LIDAR
could be uninformative or unavailable for use. The accuracy of low-cost LIDARs and
cameras could be compromised when there is heavy fog, snow, sleet or rain [9]. This
is a well-recognised issue in the field. The cost of high accuracy LIDAR also makes
them a theft attractive item as they are worth several thousands of pounds. Hence, the
use of LIDARs on autonomous vehicles would make the vehicles more expensive.
Camera-based positioning systems could also face low accuracies depending on the
external light intensity and the objects in the camera’s scene. In level 4 self-driving
applications, as tested by Waymo LLC and Cruise LLC, the LIDAR scan is matched
onto an HD map in real time. Based on this, the system is able to precisely position
the vehicle within its environment [10]. However, this method is computationally
intensive. Furthermore, changes in the driving environment and infrastructure could
make an HD map temporarily outdated and as such not useful for navigation.

1.3 Navigation Using Inertial Measurement Sensors

An Inertial Navigation Sensor (INS) however, unlike other sensors found in an


autonomous vehicle, does not need to interact with the external environment to
perform localisation, making it unique to the other sensors employed on the vehicle.
This independence makes it vital for both sensor fusion and safety. An Inertial
Measuring Unit (IMU) measures a vehicle’s linear acceleration and rotational rate
components in the x, y and z-axis and computes velocity, positioning and orientation
by continuous dead reckoning. It functions to provide localisation data, which is
needed by the vehicle to position itself within its environment. As production vehi-
cles are already equipped with anywhere from one third to a full INS [11], the IMU
can be used to localise the vehicle temporarily in the absence of the GNSS signals.
The IMUs can also be used to compare positions and estimate in order to introduce
certainty to the final localisation output. In the absence of an IMU, it would be diffi-
cult to know when the localization accuracy of the LIDAR may have deteriorated
[10].
Through a complex mathematical analysis, the position of the vehicle can be
computed using the INS to dead reckon during the GNSS outage. However, the
sensors are plagued by exponential error drifts manifested by the double integra-
tion of the acceleration to displacement. These errors are unboundedly cascaded
over time, leading to a very poor positioning estimate. Commonly, the

GNSS in what could be described as a mutually symbiotic relationship calibrates


the INS periodically during signal coverage to help improve the positioning estima-
tion accuracy. Traditionally, Kalman filters are used to model the error between the
GPS position and the INS position solution. Kalman filters have limitations when
modelling highly non-linear dependencies, non-Gaussian noise measurements and
stochastic relationships.
The use of artificial neural network techniques in place of Kalman filters to model
the errors has been recently explored by some researchers, as they are capable of
learning non-linear relationships within. Compared to Kalman filters, deep learning
techniques have proven to perform better in longer GPS signal losses. Rashad et al.
proposed a radial basis function neural network to model the position error between
the GPS and the INS position [12]. A Multi-Layer Feed-Forward Neural Network
(MLNN) was applied to a DGPS, and tactical grade INS-integrated architecture
for navigation [13]. The authors of [15] utilised an MLNN on a single point positioning GPS
and IMU integrated architecture. Malleswaran et al. suggested the use of bidirec-
tional and hetero-associative neural networks on an INS/GPS-integrated system [16].
Noureldin et al. proposed the use of an Input Delay Neural Network (IDNN) on the
INS/GPS problem by utilising inputs from previous timesteps [17]. Malleswaran
et al. investigated the use of a Sigma-Pi neural network on the navigation problem
[18]. The performance of these techniques, as demonstrated in published literature,
highlights the potential of intelligent algorithms in autonomous vehicle navigation.
However, a direct comparison of the performances of these techniques is not possible,
as information on the vehicle dynamics studied is not publicly available.
Most researchers have employed a prediction frequency of 1 Hz, as commercial
GPS receivers mostly update their location information every second [19]. However,
on the motorway and other safety-critical applications, predicting at such frequency
is not sufficient to efficiently track the vehicle. With a speed of 70 mph, a vehicle
could cover a distance of 32 m in a second with the vehicle’s motion unaccounted
for, where a 2.5 m lateral displacement could mean a completely different lane. More
so, assessments of vehicle-related accidents by automotive insurers would require
high resolution positioning estimation [19]. A high data rate positioning technique
would more accurately monitor the vehicle’s motion between GPS signal updates and
signal losses. We therefore comparatively investigate in this chapter the performance
of LSTM, IDNN, MLNN and the Kalman filter for high data rate positioning.

2 INS/GPS Fusion Techniques Used for Vehicle Positioning

2.1 Kalman Filter (KF) Approach

Traditionally, Kalman filters are utilised to perform INS/GPS integration. The


Kalman filter is used to estimate a linear system's instantaneous state affected by Gaus-
sian white noise. It has become a standard technique for use in INS/GPS applications

[20]. Despite the wide popularity of the Kalman filter, it does possess some draw-
backs. For an INS/GPS-integrated application, the Kalman filter requires stochastic
models to represent the INS errors, but these stochastic models are difficult to deter-
mine for most gyroscopes and accelerometers [13]. Even more, there is the need for
an accurate a priori information of the covariant matrices of the noises associated
with the INS. More so, the INS/GPS problem is one of a non-linear nature. As a
result, other types of filters have been studied [20].
The Kalman filter functions in two stages: the prediction stage, which involves
the computation of the errors between the measurement and the prediction, and the
update (innovation) stage, where the Kalman filter uses the inputs, measurements
and process model to correct the predictions.
In modelling the error between the INS and GNSS position, we model the
prediction and update stages in discrete time as

X_t = A_{t|t−1} X_{t−1} + B_{t−1} U_{t−1} + w_{t−1}   (1)

Z_t = H_t X_t + V_t   (2)

Prediction Stage:

X̂_t^− = A_{t|t−1} X̂_{t−1}   (3)

P_t^− = A_{t|t−1} P_{t−1} A_{t|t−1}^T + B_{t−1} Q_{t−1} B_{t−1}^T   (4)

Innovation Stage: This stage involves the computation of the errors between the
predicted states and the measurements.

K_t = P_t^− H_t^T (H_t P_t^− H_t^T + R_t)^{−1}   (5)

X̂_t = X̂_t^− + K_t (Z_t − H_t X̂_t^−)   (6)

P_t = (I − K_t H_t) P_t^−   (7)

where X_t is the error state vector, U_t is the input/control vector, w_t is the system
noise (Gaussian), V_t is the measurement noise (Gaussian), Z_t is the measurement
vector, A_t is the state transition matrix, B_t is the system noise coefficient matrix, H_t
relates the state to the measurement, Q_t is the system noise covariance matrix, R_t is
the measurement noise covariance matrix, X̂_t^− is the a priori (predicted) state estimate,
X̂_t is the a posteriori (updated) state estimate, K_t is the Kalman gain matrix, P_t^− is the
a priori error covariance and P_t is the a posteriori error covariance.
More information on the Kalman filter can be found in [13, 21].
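As a minimal sketch of one filter cycle following Eqs. (3)–(7), assuming generic system matrices (the actual INS/GPS error-state model, i.e. the choice of A, B, H, Q and R, is application specific and not specified here):

```python
import numpy as np

def kalman_step(x_post, P_post, z, A, B, H, Q, R):
    """One predict/innovate cycle of the discrete Kalman filter (Eqs. 3-7)."""
    # Prediction (a priori) stage
    x_prior = A @ x_post                                   # Eq. (3)
    P_prior = A @ P_post @ A.T + B @ Q @ B.T               # Eq. (4)
    # Innovation (update) stage
    S = H @ P_prior @ H.T + R
    K = P_prior @ H.T @ np.linalg.inv(S)                   # Eq. (5)
    x_new = x_prior + K @ (z - H @ x_prior)                # Eq. (6)
    P_new = (np.eye(P_prior.shape[0]) - K @ H) @ P_prior   # Eq. (7)
    return x_new, P_new
```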

2.2 Multi-layer Neural Networks (MLNN) and Deep


Learning

An MLNN consists of an interconnected system of neurons with the ability to map


non-linear relationships between the inputs and the outputs. This capability is partic-
ularly of interest as vehicles’ dynamics is non-linear in nature. The neurons are
connected by weights, with the output defined by a function of the sum of the
neuron’s input and transformed non-linearly through an activation function. The
neuron’s input is computed from the product of a weight factor matrix and the input
matrix, and a bias. The output from a neuron layer becomes the input vector for the
neurons in the next layer. Through the continuous backpropagation of errors signals,
the weights are adjusted in what is referred to as the training phase of the MLNN.
An adjustable learning rate and momentum can be used to prevent the MLNN from
getting trapped in a local minimum while backpropagating the errors [22].
The feed-forward layer operation is governed by
 
y = σ(xw + b)   (8)

where y is the layer output vector, x is the input vector, w is the weight matrix and
b is the bias.
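As a minimal NumPy rendering of Eq. (8) (the sigmoid used here is only one possible choice of the activation σ):

```python
import numpy as np

def dense_layer(x, w, b):
    """Feed-forward layer output y = sigma(x w + b), Eq. (8), with a sigmoid activation."""
    z = x @ w + b
    return 1.0 / (1.0 + np.exp(-z))
```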
Deep learning algorithms employ the use of multiple layers to extract features
progressively from the raw input through multiple levels of non-linear abstractions.
They try to find good representations of the input–output relationship by exploiting
the input distribution structure [23]. The success of deep learning algorithms on
several challenging problems has been demonstrated in many published literatures,
for example [24–27].
Some popular deep learning architectures used today, especially in computer
vision applications, are ResNet [28] and Yolo [29], which show good performances in
visual recognition tasks. Aside from the deep learning’s success in image processing
tasks, sequential data such as audio and texts are now processed with deep neural
networks, which are able to achieve state-of-the-art performances in speech recog-
nition and natural language processing tasks [30]. These successes lend support to
the potential of deep learning algorithms to learn complex abstractions present in the
noisy sensor signals in the application under study in this chapter.

2.3 Input Delay Neural Network (IDNN)

The position errors of INS are accumulative and follow a certain pattern [25]. There-
fore, previous positional sequences are required for the model to capture the error
trend. It is however difficult to utilise a static neural network to model this pattern. A
dynamic model can be employed by using an architecture that presents the previous
‘t’ values of the signal as inputs to the network, thus capturing the error trend present
Vehicular Localisation at High and Low Estimation Rates … 235

Fig. 1 Illustration of an IDNN’s general architecture [17]

in the previous t timesteps [17]. The model can also be trained to learn time-varying or
sequential trends through the introduction of memory and associator units to the input
layer. The memory unit of the dynamic model can store previous INS samples and
forecast using the associator unit. The use of such a dynamic model has a significant
influence on the accuracy of the INS position prediction in the absence of GPS [17].
Figure 1 illustrates an IDNN’s general architecture, with p being the tapped delay
line memory length, Ui the hidden layer neurons, W the weights, G the activation
function, Y the target vector and D the delay operator.
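The tapped delay line of Fig. 1 can be emulated by stacking the previous p samples into each input vector; the following small sketch (variable names are ours) shows one way such IDNN inputs can be prepared:

```python
import numpy as np

def tapped_delay_inputs(signal, p):
    """Stack the previous p samples of a 1-D signal so a static network sees the recent trend."""
    signal = np.asarray(signal)
    return np.stack([signal[t - p + 1:t + 1] for t in range(p - 1, len(signal))])

# Example: window length p = 5 over a dummy 10 Hz error sequence
X = tapped_delay_inputs(np.arange(20.0), p=5)   # shape (16, 5): rows are [x_{t-4}, ..., x_t]
```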

2.4 Long Short-Term Neural Networks (LSTM)

Recurrent Neural Networks (RNNs) have been proven to learn more useful features
on sequential problems. Long Short-Term Memory (LSTM) networks are a variant
of RNN created to tackle its shortfall. They are specifically created to solve the
long-term dependency problem, hence enabling them to recall information for long
periods. Due to the accumulative and patterned nature of the INS positional errors,
the LSTM can be used to learn error patterns from previous sequences to provide
a better position estimation.

Fig. 2 Unrolled RNN architecture

The operation of the LSTM is regulated by gates; the
forget, input and output gates, and it operates as shown below:
   
f_t = σ(W_f [h_{t−1}, x_t] + b_f)   (9)

i_t = σ(W_i [h_{t−1}, x_t] + b_i)   (10)

ĉ_t = tanh(W_c [h_{t−1}, x_t] + b_c)   (11)

c_t = f_t ∗ c_{t−1} + i_t ∗ ĉ_t   (12)

o_t = σ(W_o [h_{t−1}, x_t] + b_o)   (13)

h_t = o_t ∗ tanh(c_t)   (14)

where ∗ is the Hadamard product, f t is the forget gate, i t is the input gate, ot is the
output gate, ct is the cell state, h t−1 is the previous state, W is the weight matrix and
σ is the sigmoid activation (non-linearity) function. Figures 2 and 3 show the RNN’s
architecture and LSTM’s cell structure.
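To make Eqs. (9)–(14) concrete, a single cell update can be written directly in NumPy; the weight shapes and names below are ours, and in practice a framework implementation such as a Keras LSTM layer would be used instead:

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def lstm_step(x_t, h_prev, c_prev, W_f, W_i, W_c, W_o, b_f, b_i, b_c, b_o):
    """One LSTM cell update following Eqs. (9)-(14); each W acts on [h_{t-1}, x_t]."""
    hx = np.concatenate([h_prev, x_t])
    f_t = sigmoid(W_f @ hx + b_f)        # forget gate, Eq. (9)
    i_t = sigmoid(W_i @ hx + b_i)        # input gate, Eq. (10)
    c_hat = np.tanh(W_c @ hx + b_c)      # candidate cell state, Eq. (11)
    c_t = f_t * c_prev + i_t * c_hat     # new cell state, Eq. (12)
    o_t = sigmoid(W_o @ hx + b_o)        # output gate, Eq. (13)
    h_t = o_t * np.tanh(c_t)             # new hidden state, Eq. (14)
    return h_t, c_t
```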

3 Problem Formulation

In this section, we discuss the formulation of the learning task using the four
modelling techniques introduced in Sect. 2. Furthermore, we present the experiments
and comparative results on the inertial tracking problem in Sect. 4.

Fig. 3 LSTM cell structure

3.1 Inertial Tracking

Vehicular tracking using inertial sensors are governed by the Newtonian laws of
motion. Given an initial pose, the attitude rate from the gyroscope can be integrated
to provide continuous orientation information. The acceleration of the vehicle can
also be integrated to determine its velocity and, provided an initial velocity, the
vehicle’s displacement can be estimated through the integration of its velocity.
Several co-ordinate systems are used in vehicle tracking with the positioning
estimation expressed relative to a reference. Usually, measurements in the sensors’
co-ordinate system (body frame) would need to be transformed to the navigation
frame [31]. The body frame has its axis coincident to the sensors’ input axis.
The navigation frame, also known as the local frame, has its origin as the sensor
frames’ origin. Its x-axis points towards the geodetic north, with the z-axis orthogonal
to the ellipsoidal plane and the y-axis completing the orthogonal frame. The rotation
matrix from the body frame to the navigation frame is expressed in Eq. (15) [20].
R^nb = [ cos θ cos Ψ   −cos φ sin Ψ + sin φ sin θ cos Ψ   sin φ sin Ψ + cos φ sin θ cos Ψ
         cos θ sin Ψ    cos φ cos Ψ + sin φ sin θ sin Ψ   −sin φ cos Ψ + cos φ sin θ sin Ψ
        −sin θ          sin φ cos θ                        cos φ cos θ ]   (15)

where θ is the pitch, Ψ is the yaw and φ is the roll.


As the problem is considered to be the tracking of a vehicle in a horizontal plane,
the roll and pitch of the vehicle can be considered negligible and thus assumed to be
zero. Hence, the rotation matrix becomes
R^nb = [ cos Ψ   −sin Ψ   0
         sin Ψ    cos Ψ   0
         0        0       1 ]   (16)

An Inertial Measuring Unit usually consists of a three orthogonal axis accelerometer
and a three orthogonal axis gyroscope. The accelerometer measures acceleration
in the x, y and z-axis. It measures the specific force f on the sensor in the body frame
b [20]. This can be expressed as in Eq. (17), where R bn is the rotation matrix from the
navigation frame to the body frame, g n represents the gravity vector and a n denotes
the linear acceleration of the sensor expressed in the navigation frame [31].
 
f^b = R^bn (a^n − g^n)   (17)

given: a^n = a^n_nb + 2ω^n_ie × V^n_n + ω^n_ie × (ω^n_ie × p^n)   (18)

where a^n_nb is the acceleration of interest, 2ω^n_ie × V^n_n is the Coriolis acceleration,
ω^n_ie × (ω^n_ie × p^n) is the centrifugal acceleration and V^n_n is the velocity of the vehicle in
the navigation frame [31]. In this application, the Coriolis acceleration is considered
negligible due to its small magnitude compared to the accelerometers' measurements,
and the centrifugal acceleration is considered to be absorbed in the local gravity
vector. Thus, a^n = a^n_nb.
The gyroscope measures the attitude change in roll, yaw and pitch. It measures
the angular velocity of the body frame (vehicle's frame) with respect to the inertial
frame, as expressed in the body frame. Represented by ω^b_ib, the attitude rate can be
expressed as

ω^b_ib = R^bn (ω^n_ie + ω^n_en) + ω^b_nb   (19)

where ω^n_ie is the angular velocity of the earth frame with respect to the inertial frame,
estimated to be approximately 7.29 × 10^−5 rad/s [31]. The navigation frame is
defined stationary with respect to the earth, thus ω^n_en = 0, with the angular velocity of
interest ω^b_nb representing the rate of rotation of the vehicle in the navigation frame.
If initial conditions are known, ω^b_nb may be integrated over time to determine the
vehicle's orientation (yaw), as shown in Eq. (20):

Ψ_INS = Ψ_0 + ∫_{t−1}^{t} ω^b_nb dt   (20)

where Ψ0 is the last known yaw of the vehicle.
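Under the planar-motion assumption above, the heading propagation of Eq. (20) and the body-to-navigation resolution implied by Eq. (16) reduce to a few lines; the sketch below uses hypothetical variable names and a fixed sampling interval dt:

```python
import numpy as np

def propagate_yaw(yaw0, yaw_rates, dt):
    """Integrate gyroscope yaw rates from the last known yaw (Eq. 20)."""
    return yaw0 + np.cumsum(np.asarray(yaw_rates)) * dt

def body_to_nav(x_body, yaw):
    """Resolve a longitudinal body-frame quantity into north/east components (planar Eq. 16)."""
    return x_body * np.cos(yaw), x_body * np.sin(yaw)
```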

3.2 Deep Learning Task Formulation

The accelerometer measurement (specific force) f^b at each time instant t is typically
assumed to be corrupted by a bias δ^b_INS and noise ε^b_a. Thus, the corrupted sensor's
measurement can be represented as F^b_INS, where

F^b_INS = f^b + δ^b_INS + ε^b_a   (21)

Furthermore, the accelerometers' noise is typically quite Gaussian¹ and can be
modelled as ε^b_a ∼ N(0, Σ_a). The accelerometers' bias is slowly time varying and as
such can be modelled either as a constant parameter or as part of a time-varying state.


The specific force measurement can be expanded from Eq. (17) as shown below:

a^b = f^b + g^b   (22)

F^b_INS = a^b + δ^b_INS,a + ε^b_a   (23)

a^b = F^b_INS − δ^b_INS,a − ε^b_a   (24)

F^b_INS − δ^b_INS,a = a^b + ε^b_a   (25)

However, a^b_INS = F^b_INS − δ^b_INS,a   (26)

a^b_INS = a^b + ε^b_a   (27)

Through the integration of Eq. (27), the velocity of the vehicle in the body frame
can be determined:

v^b_INS = ∫_{t−1}^{t} a^b dt + ε^b_v   (28)

The displacement of the sensor in the body frame at time t from t − 1, x^b_INS,
can also be estimated by the double integration of Eq. (27), provided an initial
velocity:

x^b_INS = ∬_{t−1}^{t} a^b dt² + ε^b_x   (29)

where δ^b_INS,a is the bias in the body frame, taken to be a constant parameter computed
as the average reading from a stationary accelerometer run for 20 min, F^b_INS is the
corrupted measurement provided directly by the accelerometer sensor at time t
(sampling time), g is the gravity vector, and a^b, ∫_{t−1}^{t} a^b dt and ∬_{t−1}^{t} a^b dt² are the true
(uncorrupted) longitudinal acceleration, velocity and displacement of the vehicle,
respectively. The true displacement of the vehicle is thus expressed as x^b_GPS ≈ ∬_{t−1}^{t} a^b dt².
Furthermore, ε^b_x can be derived by

ε^b_x ≈ x^b_GPS − x^b_INS   (30)

¹ The vehicle's dynamics are non-linear, especially when cornering or braking hard; thus, a linear or
inaccurate noise model would not sufficiently capture the non-linear relationship.

The noise ε^b_x, displacement x^b_INS, velocity v^b_INS and acceleration a^b_INS of the vehicle
in the body frame within window t − 1 to t can thus be transformed to the navigation
frame using the rotation matrix R^nb and defined by the North-East-Down (NED)
system, as shown in Eqs. (31)–(34). However, the down axis is not considered in this
study.

from R^nb_INS · a^b_INS → a^n_INS → (a^b_INS · cos Ψ_INS, a^b_INS · sin Ψ_INS)   (31)

from R^nb_INS · v^b_INS → v^n_INS → (v^b_INS · cos Ψ_INS, v^b_INS · sin Ψ_INS)   (32)

from R^nb_INS · x^b_INS → x^n_INS → (x^b_INS · cos Ψ_INS, x^b_INS · sin Ψ_INS)   (33)

from R^nb · ε^b_x → ε^n_x → (ε^b_x · cos Ψ, ε^b_x · sin Ψ)   (34)

where:

R^nb_INS = [ cos Ψ_INS   −sin Ψ_INS   0
             sin Ψ_INS    cos Ψ_INS   0
             0            0           1 ]   (35)
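Putting Eqs. (27)–(30) and (33)–(34) together, the training targets can be formed roughly as follows; this is a simplified sketch assuming bias-corrected longitudinal acceleration and the planar model above, with all variable names ours:

```python
import numpy as np

def ins_window_displacement(a_body, dt, v0):
    """Dead-reckon body-frame displacement over one window by double integration (Eqs. 28-29)."""
    v = v0 + np.cumsum(np.asarray(a_body)) * dt   # velocity samples, Eq. (28)
    return np.sum(v) * dt                         # displacement over the window, Eq. (29)

def displacement_error_targets(x_gps_b, x_ins_b, yaw):
    """Displacement error (Eq. 30) resolved into north/east components (Eqs. 33-34)."""
    eps_b = x_gps_b - x_ins_b
    return eps_b * np.cos(yaw), eps_b * np.sin(yaw)
```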

3.3 Vehicle’s True Displacement Estimation

This section presents how to estimate the vehicle's true displacement x^b_GPS. x^b_GPS is
useful in the determination of the target error ε^b_x as detailed in Eq. (30).
In estimating the distance travelled between two points on the earth’s surface, it
becomes obvious that the shape of the earth is neither a perfect sphere nor ellipse, but
rather that of an oblate ellipsoid. Due to the unique shape of the earth, complications
exist as there is no geometric shape it can be categorised under for analysis. The
Haversine formula applies perfectly to the calculations of distances on spherical
shapes, while the Vincenty’s formula applies to elliptic shapes [32].

3.3.1 Haversine’s Formula

The Haversine’s formula is used to calculate the distance between two points on the
earth’s surface specified in longitude and latitude. It assumes a spherical earth [33].
    
x^b_GPS = 2r sin^{−1} ( √( sin²((∅_t − ∅_{t−1})/2) + cos(∅_{t−1}) cos(∅_t) sin²((ϕ_t − ϕ_{t−1})/2) ) )   (36)

where x^b_GPS is the distance travelled within t − 1, t with longitude and latitude (ϕ, ∅)
as obtained from the GPS, and r is the radius of the earth.
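A direct Python transcription of Eq. (36), using a mean earth radius (itself an approximation), might look as follows:

```python
import numpy as np

EARTH_RADIUS_M = 6371000.0   # mean earth radius (approximation)

def haversine_m(lat1, lon1, lat2, lon2):
    """Great-circle distance in metres between two GPS fixes given in degrees (Eq. 36)."""
    phi1, phi2 = np.radians(lat1), np.radians(lat2)
    dphi = phi2 - phi1
    dlam = np.radians(lon2 - lon1)
    a = np.sin(dphi / 2.0) ** 2 + np.cos(phi1) * np.cos(phi2) * np.sin(dlam / 2.0) ** 2
    return 2.0 * EARTH_RADIUS_M * np.arcsin(np.sqrt(a))
```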

3.3.2 Vincenty’s Inverse Formula

The Vincenty’s formula is used to calculate the distance between two points on the
earth’s surface specified in longitude and latitude. It assumes an ellipsoidal earth
[34]. The distance between two points is calculated as shown in Eqs. (37)–(55).

Given:

f = 1 / 298.257223563   (37)

b = (1 − f) a   (38)

U_1 = arctan[(1 − f) tan φ_1]   (39)

Δϕ = ϕ_2 − ϕ_1   (40)

U_2 = arctan[(1 − f) tan φ_2]   (41)

sin σ = √[ (cos U_2 sin λ)² + (cos U_1 sin U_2 − sin U_1 cos U_2 cos λ)² ]   (42)

cos σ = sin U_1 sin U_2 + cos U_1 cos U_2 cos λ   (43)

σ = arctan2(sin σ, cos σ)   (44)

sin α = (cos U_1 cos U_2 sin λ) / sin σ   (45)

cos(2σ_m) = cos σ − (2 sin U_1 sin U_2) / cos² α   (46)

C = (f / 16) cos² α [4 + f (4 − 3 cos² α)]   (47)

λ = Δϕ + (1 − C) f sin α {σ + C sin σ [cos(2σ_m) + C cos σ (−1 + 2 cos²(2σ_m))]}   (48)

u² = cos² α (a² − b²) / b²   (49)

k_1 = (√(1 + u²) − 1) / (√(1 + u²) + 1)   (50)

A = (1 + ¼ k_1²) / (1 − k_1)   (51)

B = k_1 (1 − (3/8) k_1²)   (52)

Δσ = B sin σ { cos(2σ_m) + ¼ B [ cos σ (−1 + 2 cos²(2σ_m)) − (B/6) cos(2σ_m) (−3 + 4 sin² σ) (−3 + 4 cos²(2σ_m)) ] }   (53)

α1 = arctan2(cos U2 sin λ, cos U1 sin U2 − sin U1 cos U2 cos λ) (54)

α2 = arctan2(cos U1 sin λ, sin U1 cos U2 − cos U1 sin U2 cos λ) (55)

where x̂^b_t is the distance travelled within t − 1 and t with longitude and latitude
(ϕ, ∅), a is the radius of the earth at the equator, f is the flattening of the ellipsoid, b
is the length of the ellipsoid semi-minor axis, U1 and U2 are the reduced latitude at t
and t − 1, respectively, λ is the change in longitude along the auxiliary spheres, s is
the ellipsoidal distance between the position at t − 1 and t, σ1 is the angle between
the position at t − 1 and t, σ is the angle between the position at t − 1 and t and σm
is the angle between the equator and midpoint of the line.
The Vincenty’s formula is used in this work, as it provides a more accurate
solution compared to Haversine and other great circle formulas [32]. The Python
implementation of Vincenty’s Inverse Formula is used here [35].
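The implementation referenced in [35] is not reproduced here; as an indication of how little user code an ellipsoidal distance requires, the geopy package exposes a WGS-84 geodesic distance (computed with Karney's method rather than Vincenty's, but serving the same purpose):

```python
from geopy.distance import geodesic

p_prev = (52.4081, -1.5106)   # hypothetical (latitude, longitude) GPS fixes
p_curr = (52.4083, -1.5101)

d_metres = geodesic(p_prev, p_curr).meters   # ellipsoidal distance between consecutive fixes
print(d_metres)
```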

3.4 Vehicle's Displacement Error Prediction Learning Scheme

The neural networks introduced in Sect. 2 are exploited to learn the relationship between the input features, namely the displacement $x^{n}_{INS}$, velocity $v^{n}_{INS}$ and acceleration $a^{n}_{INS}$, and the target displacement error $\varepsilon^{n}_{x}$ (as presented in Sects. 3.2 and 3.3) in the northwards and eastwards directions, as shown in Fig. 4. The predicted displacement error is used to correct the INS-derived displacement to provide a better positioning solution.

Fig. 4 Learning scheme for the northwards and eastwards displacement error prediction

4 Experiments, Results and Discussion

4.1 Data set

The data used is the V-Vw12 aggressive driving vehicle benchmark data set, which describes about 107 s of an approximately straight-line trajectory, of which the first 105 s are used for our analysis [36]. The sensors are in-built, and the data are captured from the vehicle's ECU using the Vbox video H2 data acquisition system sampling at a frequency of 10 Hz. The longitudinal acceleration of the vehicle, as well as its rate of rotation about the z-axis (yaw rate), heading (yaw) and the GPS co-ordinates (latitude and longitude) of the vehicle, are captured at each time instance. Figure 5 shows the vehicle used for the data collection and the location of its sensors.

Fig. 5 Data collection vehicle [36]



Table 1 LSTM, IDNN and MLNN training parameters


Parameters                                     LSTM                 IDNN                 MLNN
Learning rate                                  0.09                 0.09                 0.09
L1 regulariser                                 0.9                  –                    –
L2 regulariser                                 0.99                 –                    –
Recurrent dropout                              5%                   –                    –
Dropout                                        –                    5%                   5%
Sequence length for sample periods 0.1–0.3 s   5                    10 × Sample Period   –
Sequence length for sample periods 0.4–1 s     10 × Sample Period   10 × Sample Period   –
Hidden layers                                  2                    2                    2
Hidden neurons per layer                       32                   32                   32
Batch size                                     32                   32                   32
Epochs                                         500                  500                  500

4.2 Training

The training is done using the first 75 s of the data, and the model is then tested on the 10 s as well as the 30 s of data following the training data. The Keras–TensorFlow framework was used for training, with a mean absolute error loss function and an Adamax optimiser with a learning rate of 0.09. 5% of the units were dropped from the hidden layers of the IDNN and MLNN and from the recurrent layers of the LSTM to prevent the neural networks from overfitting [37]. Furthermore, all features fed to the models were standardised between 0 and 100 to avoid biased learning. Forty models were trained for each deep learning architecture, and the model providing the smallest position errors was selected. The parameters defining the training of the neural network models are listed in Table 1. The general architecture of the models is as shown in Fig. 4. The objective of the training exercise is to teach the neural network to learn the positioning error between the low-cost INS and the GPS.
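As an illustration of the training set-up described above, a minimal Keras–TensorFlow sketch of the LSTM error-prediction model with the Table 1 hyper-parameters is shown below; the exact layer arrangement and the placement of the L1/L2 regularisers in the original models are assumptions.

```python
# Sketch of the LSTM displacement-error model trained with MAE loss and Adamax.
import tensorflow as tf
from tensorflow.keras import layers, models, optimizers, regularizers

def build_lstm_model(seq_len, n_features=3):
    model = models.Sequential([
        layers.LSTM(32, return_sequences=True, recurrent_dropout=0.05,
                    kernel_regularizer=regularizers.l1_l2(l1=0.9, l2=0.99),
                    input_shape=(seq_len, n_features)),
        layers.LSTM(32, recurrent_dropout=0.05),
        layers.Dense(1),                     # predicted displacement error
    ])
    model.compile(optimizer=optimizers.Adamax(learning_rate=0.09),
                  loss='mean_absolute_error')
    return model

# model = build_lstm_model(seq_len=5)
# model.fit(x_train, y_train, batch_size=32, epochs=500)
```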

4.3 Testing

GPS outages were assumed on the 10 s as well as the 30 s of data following the training data to help analyse the performance of the prediction models. Using the information on the vehicle's orientation and velocity at the last known GPS signal before the outage, the northwards and eastwards components of the vehicle's displacement $x^{n}_{INS,t}$, velocity $v^{n}_{INS,t}$ and acceleration $a^{n}_{INS,t}$ are fed into the respective models to predict the north and east components of the displacement error $\varepsilon^{n}_{x,t}$.

Table 2 Position error after 10 s of GPS outages


Sampling     LSTM position error (m)    IDNN position error (m)    MLNN position error (m)    KF position error (m)
period (s)   North  East  Total         North  East  Total          North  East  Total          North  East  Total
1 1.21 0.61 1.36 1.63 0.64 1.75 25.51 6.55 26.34 1.39 0.54 1.49
0.9 0.59 0.7 0.92 1.22 1.08 1.63 22.02 4.28 22.43 1.05 0.93 1.40
0.8 0.39 0.47 0.61 0.37 0.93 1.00 19.38 6.69 20.51 0.32 0.81 0.87
0.7 0.29 0.53 0.60 0.62 0.98 1.16 16.97 6.63 18.22 0.55 0.86 1.02
0.6 1.04 0.36 1.10 1.37 0.81 1.59 16.36 4.38 16.94 1.22 0.72 1.42
0.5 0.62 0.54 0.82 1.67 1.05 1.97 12.09 5.88 13.44 1.50 0.95 1.78
0.4 1.81 0.6 1.91 2.17 1.11 2.44 9.79 5.9 11.43 1.97 1.01 2.22
0.3 2.24 0.89 2.24 2.64 1.15 2.88 6.98 5.93 9.16 2.43 1.06 2.65
0.2 1.45 0.71 1.61 2.83 1.05 3.02 4.73 5.08 6.94 2.80 1.04 2.99
0.1 0.94 0.90 1.30 1.12 1.41 2.80 2.14 4.49 4.97 1.01 1.40 1.73

4.4 Results and Discussion

To evaluate the performance of the LSTM, IDNN, MLNN and Kalman filter
techniques, two GPS outage scenarios are explored: 10 s and 30 s.

4.4.1 10 s Outage Experiment Result

The performance of the LSTM, IDNN, MLNN and Kalman filter solutions during
the 10 s outage is studied. From Table 2, it can be seen that at all sampling periods
the LSTM algorithm performed best at estimating the positioning error, followed
closely by the Kalman filter and the IDNN. The MLNN performs worst in comparison. Comparing all sampling periods, the LSTM method produces the best
error estimation of 0.60 m, at a sampling period of 0.7 s, over about 233 m of travel
within the 10 s studied.

4.4.2 30 s Outage Experiment Result

A study of the LSTM, IDNN, MLNN and Kalman filter performances during the 30 s GPS outage scenario reveals that, unlike in the 10 s experiment, the Kalman filter performs poorly in comparison to the LSTM and IDNN approaches. The LSTM
performs the best at all sampling frequencies, with the Kalman filter outperforming
the MLNN. A comparison of the sampling periods shows that the LSTM approach
provides the best error estimation of 4.86 m at 10 Hz, over about 724 m of travel
within the 30 s investigated (Table 3).

Table 3 Position error after 30 s of GPS outages


Sampling     LSTM position error (m)    IDNN position error (m)    MLNN position error (m)    KF position error (m)
period (s)   North  East  Total         North  East  Total          North  East  Total          North  East  Total
1 4.81 3.07 5.71 4.87 6.57 8.18 30.28 10.26 31.97 22.71 7.70 23.98
0.9 5.33 5.19 7.44 5.22 7.30 8.97 25.42 9.15 27.02 19.57 7.05 20.80
0.8 6.25 7.79 9.99 5.93 8.30 10.20 22.64 12.08 25.66 17.88 9.54 20.27
0.7 6.65 7.69 10.17 6.33 8.71 10.77 20.22 12.60 23.82 16.38 10.21 19.30
0.6 6.93 7.36 10.11 7.03 8.56 11.08 20.35 10.70 22.99 16.89 8.88 19.08
0.5 6.43 4.45 7.82 7.3 8.96 11.56 15.38 12.36 19.73 13.07 10.51 16.77
0.4 7.52 7.69 10.76 7.49 9.29 11.93 12.77 12.82 18.09 11.11 11.15 15.74
0.3 5.46 4.62 8.72 8.16 10.38 13.20 9.98 13.53 16.81 8.88 12.04 14.96
0.2 5.46 4.46 7.05 8.29 9.87 12.89 7.64 12.52 14.67 6.95 11.39 13.35
0.1 3.74 3.11 4.86 8.90 10.22 13.55 5.19 12.2 13.26 4.83 11.35 12.33

5 Conclusions

Effective vehicular services and the safety of autonomous vehicles depend on accurate and reliable positioning of the vehicle. Most commercial GPS receivers, however, operate at a rather low sampling rate (1 Hz) and face reliability problems in urban canyons, tunnels, etc. An INS can fill in for the GPS to provide continuous positioning information between GPS signal receptions. To this end, the LSTM, IDNN, MLNN
and Kalman filter techniques were investigated over several sampling scenarios in
GPS signal outages of 10 s and 30 s. The results of the study show that during short-
term outages (less than 10 s) and longer GPS outages (about 30 s) the LSTM approach
provides the best positioning solution. Furthermore, our findings show that sampling
at lower rates during long-term GPS outages provides relatively poorer position
estimates. There is however the need to explore the performance of the LSTM model
on more complex driving scenarios, as a means to assess its robustness. This will be
the subject of our future research.

References

1. I. Dowd, The future of autonomous vehicles, Open Access Government (2019). https://www.
openaccessgovernment.org/future-of-autonomous-vehicles/57772/, Accessed 04 June 2019
2. P. Liu, R. Yang, Z. Xu, How safe is safe enough for self-driving vehicles? Risk Anal. 39(2),
315–325 (2019)
3. A. Papadoulis, M. Quddus, M. Imprialou, Evaluating the safety impact of connected and
autonomous vehicles on motorways. Accid. Anal. Prev. 124, 12–22 (2019)
4. S.-J. Babak, S.A. Hussain, B. Karakas, S. Cetin, Control of autonomous ground vehicles:
a brief technical review—IOPscience (2017). https://iopscience.iop.org/article/10.1088/1757-
899X/224/1/012029, Accessed 22 Mar 2020


5. K. Onda, T. Oishi, Y. Kuroda, Dynamic environment recognition for autonomous navigation
with wide FOV 3D-LiDAR. IFAC-PapersOnLine 51(22), 530–535 (2018)
6. S. Ahmed, M.N. Huda, S. Rajbhandari, C. Saha, M. Elshaw, S. Kanarachos, Pedestrian and
cyclist detection and intent estimation for autonomous vehicles: a survey. Appl. Sci. 9(11),
2335 (2019)
7. W. Yao et al., GPS signal loss in the wide area monitoring system: prevalence, impact, and
solution, Electr. Power Syst. Res. 147(C), 254–262 (2017)
8. G. O’Dwyer, Finland, Norway press Russia on suspected GPS jamming during NATO
drill (2018). https://www.defensenews.com/global/europe/2018/11/16/finland-norway-press-
russia-on-suspected-gps-jamming-during-nato-drill/, Accessed 04 June 2019
9. B. Templeton, Cameras or lasers? (2017). http://www.templetons.com/brad/robocars/cameras-
lasers.html, Accessed 04 June 2019
10. L. Teschler, Inertial measurement units will keep self-driving cars on track (2018). https://
www.microcontrollertips.com/inertial-measurement-units-will-keep-self-driving-cars-on-
track-faq/, Accessed 05 June 2019
11. OXTS, Why integrate an INS with imaging systems on an autonomous vehicle (2016).
https://www.oxts.com/technical-notes/why-use-ins-with-autonomous-vehicle/, Accessed 04
June 2019
12. R. Sharaf, A. Noureldin, A. Osman, N. El-Sheimy, Online INS/GPS integration with a radial
basis function neural network. IEEE Aerosp. Electron. Syst. Mag. 20(3), 8–14 (2005)
13. K.-W. Chiang, N. El-Sheimy, INS/GPS integration using neural networks for land vehicle
navigation applications (2002), pp. 535–544
14. K.W. Chiang, A. Noureldin, N. El-Sheimy, Multisensor integration using neuron computing
for land-vehicle navigation. GPS Solut. 6(4), 209–218 (2003)
15. K.-W. Chiang, The utilization of single point positioning and multi-layers feed-forward network
for INS/GPS integration (2003), pp. 258–266
16. M. Malleswaran, V. Vaidehi, M. Jebarsi, Neural networks review for performance enhancement
in GPS/INS integration, in 2012 International Conference on Recent Trends in Information
Technology ICRTIT 2012, no. 1 (2012), pp. 34–39
17. A. Noureldin, A. El-Shafie, M. Bayoumi, GPS/INS integration utilizing dynamic neural
networks for vehicular navigation. Inf. Fusion 12(1), 48–57 (2011)
18. M. Malleswaran, V. Vaidehi, A. Saravanaselvan, M. Mohankumar, Performance analysis of
various artificial intelligent neural networks for GPS/INS integration. Appl. Artif. Intell. 27(5),
367–407 (2013)
19. A.S. El-Wakeel, A. Noureldin, N. Zorba, H.S. Hassanein, A framework for adaptive resolution
geo-referencing in intelligent vehicular services, in IEEE Vehicular Technology Conference,
vol. 2019 (2019)
20. K. Chiang, INS/GPS integration using neural networks for land vehicular navigation UCGE
reports number 20209 Department of Geomatics Engineering INS/GPS Integration using
Neural Networks for Land Vehicular Navigation Applications by Kai-Wei Chiang (2004)
21. T.P. Van, T.N. Van, D.A. Nguyen, T.C. Duc, T.T. Duc, 15-state extended kalman filter design
for INS/GPS navigation system. J. Autom. Control Eng. 3(2), 109–114 (2015)
22. M.W. Gardner, S.R. Dorling, Artificial neural networks (the multilayer perceptron)—a review
of applications in the atmospheric sciences. Atmos. Environ. 32(14–15), 2627–2636 (1998)
23. M.A. Wani, F.A. Bhat, S. Afzal, A.I. Khan, Advances in Deep Learning, vol. 57 (Springer,
Singapore, 2020)
24. A. Krizhevsky, I. Sutskever, G.E. Hinton, ImageNet classification with deep convolutional
neural networks (2012). https://papers.nips.cc/paper/4824-imagenet-classification-with-deep-
convolutional-neural-networks.pdf
25. C. Chen, X. Lu, A. Markham, N. Trigoni, IONet: learning to cure the curse of drift in inertial
odometry (2018), pp. 6468–6476
26. P. Kasnesis, C.Z. Patrikakis, I.S. Venieris, PerceptionNet: a deep convolutional neural network
for late sensor fusion (2018)

27. W. Fang et al., A LSTM algorithm estimating pseudo measurements for aiding INS during
GNSS signal outages. Remote Sens. 12(2), 256 (2020)
28. K. He, X. Zhang, S. Ren, J. Sun, Deep residual learning for image recognition, in Proceedings
of the IEEE Computer Society Conference on Computer Vision and Pattern Recognition, vol.
2016, pp. 770–778 (2016)
29. J. Redmon, S. Divvala, R. Girshick, A. Farhadi, You only look once: unified, real-time object
detection (2016). arXiv preprint arXiv:1506.02640
30. H. Ismail Fawaz, G. Forestier, J. Weber, L. Idoumghar, P.A. Muller, Deep learning for time
series classification: a review. Data Min. Knowl. Discov. 33(4), 917–963 (2019)
31. M. Kok, J.D. Hol, T.B. Schön, Using inertial sensors for position and orientation estimation.
Found. Trends Signal Process. 11(2), 1–153 (2017)
32. H. Mahmoud, N. Akkari, Shortest path calculation: a comparative study for location-based
recommender system, in Proceedings—2016 World Symposium on Computer Applications
and Research, WSCAR 2016, pp. 1–5 (2016)
33. C.M. Thomas, W.E. Featherstone, Validation of Vincenty’s formulas for the geodesic using a
new fourth-order extension of Kivioja’s formula. J. Surv. Eng. 131(1), 20–26 (2005)
34. T. Vincenty, Direct and inverse solutions of geodesics on the ellipsoid with application of nested
equations. Surv. Rev. 23(176), 88–93 (1975)
35. vincenty PyPI. https://pypi.org/project/vincenty/. Accessed 08 May 2020
36. U. Onyekpe, V. Palade, S. Kanarachos, A. Szkolnik, IO-VNBD: inertial and odometry
benchmark dataset for ground vehicle positioning (2020). arXiv preprint arXiv:2005.01701
37. Y. Gal, Z. Ghahramani, A theoretically grounded application of dropout in recurrent neural
networks (2016). arXiv preprint arXiv:1512.05287
Multi-Adversarial Variational
Autoencoder Nets for Simultaneous
Image Generation and Classification

Abdullah-Al-Zubaer Imran and Demetri Terzopoulos

Abstract Discriminative deep-learning models are often reliant on copious labeled


training data. By contrast, from relatively small corpora of training data, deep genera-
tive models can learn to generate realistic images approximating real-world distribu-
tions. In particular, the proper training of Generative Adversarial Networks (GANs)
and Variational AutoEncoders (VAEs) enables them to perform semi-supervised
image classification. Combining the power of these two models, we introduce Multi-
Adversarial Variational autoEncoder Networks (MAVENs), a novel deep generative
model that incorporates an ensemble of discriminators in a VAE-GAN network in
order to perform simultaneous adversarial learning and variational inference. We
apply MAVENs to the generation of synthetic images and propose a new distribution
measure to quantify the quality of these images. Our experimental results with only
10% labeled training data from the computer vision and medical imaging domains
demonstrate performance competitive to state-of-the-art semi-supervised models in
simultaneous image generation and classification tasks.

1 Introduction

Training deep neural networks usually requires copious data, yet obtaining large,
accurately labeled datasets for image classification and other tasks remains a fun-
damental challenge [36]. Although there has been explosive progress in the produc-
tion of vast quantities of high resolution images, large collections of labeled data
required for supervised learning remain scarce. Especially in domains such as med-
ical imaging, datasets are often limited in size due to privacy issues, and annotation
by medical experts is expensive, time-consuming, and prone to human subjectivity,


Fig. 1 Image generation based on the CIFAR-10 dataset [19]: a Relatively good images generated
by a GAN. b Blurry images generated by a VAE. Based on the SVHN dataset [24]: c mode collapsed
images generated by a GAN

inconsistency, and error. Even when large labeled datasets become available, they
are often highly imbalanced and non-uniformly distributed. In an imbalanced med-
ical dataset there will be an over-representation of common medical problems and
an under-representation of rarer conditions. Such biases make the training of neural
networks across multiple classes with consistent effectiveness very challenging.
The small-training-data problem is traditionally mitigated through simplistic and
cumbersome data augmentation, often by creating new training examples through
translation, rotation, flipping, etc. The missing or mismatched label problem may
be addressed by evaluating similarity measures over the training examples. This is
not always robust and its effectiveness depends largely on the performance of the
similarity measuring algorithms.
With the advent of deep generative models such as Variational AutoEncoders
(VAEs) [18] and Generative Adversarial Networks (GANs) [9], the ability to learn
underlying data distributions from training samples has become practical in common
scenarios where there is an abundance of unlabeled data. With minimal annotation,
efficient semi-supervised learning could be the preferred approach [16]. More specif-
ically, based on small quantities of annotation, realistic new training images may be
generated by models that have learned real-world data distributions (Fig. 1a). Both
VAEs and GANs may be employed for this purpose.
VAEs can learn dimensionality-reduced representations of training data and, with
an explicit density estimation, can generate new samples. Although VAEs can per-
form fast variational inference, VAE-generated samples are usually blurry (Fig. 1b).
On the other hand, despite their successes in generating images and semi-supervised
classifications, GAN frameworks remain difficult to train and there are challenges
in using GAN models, such as non-convergence due to unstable training, dimin-
ished gradient issues, overfitting, sensitivity to hyper-parameters, and mode collapsed
image generation (Fig. 1c).
Despite the recent progress in high-quality image generation with GANs and
VAEs, accuracy and image quality are usually not ensured by the same model, espe-
cially in multiclass image classification tasks. To tackle this shortcoming, we propose

a novel method that can simultaneously learn image generation and multiclass image
classification. Specifically, our work makes the following contributions:
1. The Multi-Adversarial Variational autoEncoder Network, or MAVEN, a novel
multiclass image classification model incorporating an ensemble of discriminators
in a combined VAE-GAN network. An ensemble layer combines the feedback
from multiple discriminators at the end of each batch. With the inclusion of
ensemble learning at the end of a VAE-GAN, both generated image quality and
classification accuracy are improved simultaneously.
2. A simplified version of the Descriptive Distribution Distance (DDD) [14] for eval-
uating generative models, which better represents the distribution of the generated
data and measures its closeness to the real data.
3. Extensive experimental results utilizing two computer vision and two medical
imaging datasets.¹ These confirm that our MAVEN model improves upon the simultaneous image generation and classification performance of a GAN and of a VAE-GAN with the same set of hyper-parameters.

¹ This chapter significantly expands upon our ICMLA 2019 publication [15], which excluded our experiments on medical imaging datasets.

2 Related Work

Several techniques have been proposed to stabilize GAN training and avoid mode
collapse. Nguyen et al. [26] proposed a model where a single generator is used
alongside dual discriminators. Durugkar et al. [7] proposed a model with a single
generator and feedback aggregated over several discriminators, considering either
the average loss over all discriminators or only the discriminator with the maximum
loss in relation to the generator’s output. Neyshabur et al. [25] proposed a framework
in which a single generator simultaneously trains against an array of discrimina-
tors, each of which operates on a different low-dimensional projection of the data.
Mordido et al. [23], arguing that all the previous approaches restrict the discrimina-
tor’s architecture thereby compromising extensibility, proposed the Dropout-GAN,
where a single generator is trained against a dynamically changing ensemble of
discriminators. However, there is a risk of dropping out all the discriminators. Fea-
ture matching and minibatch discrimination techniques have been proposed [32] for
eliminating mode collapse and preventing overfitting in GAN training.
Realistic image generation helps address problems due to the scarcity of labeled
data. Various architectures of GANs and their variants have been applied in ongoing
efforts to improve the accuracy and effectiveness of image classification. The GAN
framework has been utilized as a generic approach to generating realistic train-
ing images that synthetically augment datasets in order to combat overfitting; e.g.,
for synthetic data augmentation in liver lesions [8], retinal fundi [10], histopathol-
ogy [13], and chest X-rays [16, 31]. Calimeri et al. [3] employed a LAPGAN [6] and
Han et al. [11] used a WGAN [1] to generate synthetic brain MR images. Bermudez
et al. [2] used a DCGAN [29] to generate 2D brain MR images followed by an


autoencoder for image denoising. Chuquicusma et al. [4] utilized a DCGAN to gen-
erate lung nodules and then conducted a Turing test to evaluate the quality of the
generated samples. GAN frameworks have also been shown to improve accuracy
of image classification via the generation of new synthetic training images. Frid et
al. [8] used a DCGAN and an ACGAN [27] to generate images of three liver lesion
classes to synthetically augment the limited dataset and improve the performance of a
Convolutional Neural Net (CNN) in liver lesion classification. Similarly, Salehinejad
et al. [31] employed a DCGAN to artificially simulate pathology across five classes
of chest X-rays in order to augment the original imbalanced dataset and improve the
performance of a CNN in chest pathology classification.
The GAN framework has also been utilized in semi-supervised learning architec-
tures to leverage unlabeled data alongside limited labeled data. The following efforts
demonstrate how incorporating unlabeled data in the GAN framework has led to sig-
nificant improvements in the accuracy of image-level classification. Madani et al. [20]
used an order of magnitude less labeled data with a DCGAN in semi-supervised
learning yet showed comparable performance to a traditional supervised CNN clas-
sifier and furthermore demonstrated reduced domain overfitting by simply supplying
unlabeled test domain images. Springenberg et al. [33] combined a WGAN and Cat-
GAN [35] for unsupervised and semi-supervised learning of feature representation
of dermoscopy images.
Despite the aforecited successes, GAN frameworks remain challenging to train,
as we discussed above. Our MAVEN framework mitigates the difficulties of training
GANs by enabling training on a limited quantity of labeled data, preventing overfit-
ting to a specific data domain source, and preventing mode collapse, while supporting
multiclass image classification.

3 The MAVEN Architecture

Figure 2 illustrates the models that serve as precursors to our MAVEN architecture.
The VAE is an explicit generative model that uses two neural nets, an encoder E and a decoder D′. Network E learns an efficient compression of real data x into a lower dimensional latent representation space z(x); i.e., $q_\lambda(z \mid x)$. With neural network likelihoods, computing the gradient becomes intractable; however, via differentiable, non-centered re-parameterization, sampling is performed from an approximate function $q_\lambda(z \mid x) = \mathcal{N}(z; \mu_\lambda, \sigma_\lambda^2)$, where $z = \mu_\lambda + \sigma_\lambda \odot \hat{\varepsilon}$ with $\hat{\varepsilon} \sim \mathcal{N}(0, 1)$. Encoder E yields μ and σ, and with the re-parameterization trick, z is sampled from a Gaussian distribution. Then, with D′, new samples are generated or real data samples are reconstructed; i.e., D′ provides parameters for the real data distribution $p_\phi(x \mid z)$. Subsequently, a sample drawn from $p_\phi(x \mid z)$ may be used to reconstruct the real data by marginalizing out z.
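A minimal TensorFlow sketch of the re-parameterization step described above, assuming a diagonal-Gaussian encoder output, is:

```python
# Draw z = mu + sigma * eps with eps ~ N(0, 1) while keeping the sample differentiable.
import tensorflow as tf

def reparameterize(mu, log_var):
    eps = tf.random.normal(shape=tf.shape(mu))
    return mu + tf.exp(0.5 * log_var) * eps
```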
Fig. 2 Our MAVEN architecture compared to those of the VAE, GAN, and VAE-GAN. In the MAVEN, inputs to D can be real data X, or generated data X̂ or X̃. An ensemble ensures the combined feedback from the discriminators to the generator

The GAN is an implicit generative model where a generator G and a discriminator D compete in a minimax game over the training data in order to improve their performance. Generator G tries to approximate the underlying distribution of the training data and generates synthetic samples, while discriminator D learns to discriminate synthetic samples from real samples. The GAN model is trained on the following objectives:

$$\max_D V(D) = \mathbb{E}_{x \sim p_{data}(x)}[\log D(x)] + \mathbb{E}_{z \sim p_g(z)}[\log(1 - D(G(z)))]; \qquad (1)$$

$$\min_G V(G) = \mathbb{E}_{z \sim p_z(z)}[\log(1 - D(G(z)))]. \qquad (2)$$

G takes a noise sample z ∼ pg (z) and learns to map it into image space as if it
comes from the original data distribution pdata (x), while D takes as input either real
image data or generated image data and provides feedback to G as to whether that
input is real or generated. On the one hand, D wants to maximize the likelihood for
real samples and minimize the likelihood of generated samples; on the other hand,
G wants D to maximize the likelihood of generated samples. A Nash equilibrium
results when D can no longer distinguish real and generated samples, meaning that
the model distribution matches the data distribution.
Makhzani et al. [21] proposed the adversarial training of VAEs; i.e., VAE-GANs.
Although they kept both D  and G, one can merge these networks since both can
generate data samples from the noise samples of the representation z. In this case, D
receives real data samples x and generated samples x̃ or x̂ via G. Although G and D
compete against each other, the feedback from D eventually becomes predictable for
G and it keeps generating samples from the same class, at which point the generated
samples lack heterogeneity. Figure 1c shows an example where all the generated
images are of the same class. Durugkar et al. [7] proposed that using multiple dis-
criminators in a GAN model helps improve performance, especially for resolving
this mode collapse. Moreover, a dynamic ensemble of multiple discriminators has
recently been proposed to address the issue [23] (Fig. 3).
Fig. 3 The three convolutional neural networks, E, G, and D, in the MAVEN

As in a VAE-GAN, our MAVEN has three components, E, G, and D; all are CNNs with convolutional or transposed convolutional layers. First, E takes real samples x and generates a dimensionality-reduced representation z(x). Second, G can input


samples from noise distribution z ∼ pg (z) or sampled noise z(x) ∼ qλ (x) and it
produces generated samples. Third, D takes inputs from distributions of real labeled
data, real unlabeled data, and generated data. Fractionally strided convolutions are
performed in G to obtain the image dimension from the latent code. The goal of an
autoencoder is to maximize the Evidence Lower Bound (ELBO). The intuition here
is to show the network more real data. The greater the quantity of real data that it
sees, the more evidence is available to it and, as a result, the ELBO can be maximized
faster.

In our MAVEN architecture (Fig. 2), the VAE-GAN combination is extended to


include multiple discriminators aggregated in an ensemble layer. K discriminators
are collected and the combined feedback

$$V(D) = \frac{1}{K} \sum_{k=1}^{K} w_k D_k \qquad (3)$$

is passed to G. Alternatively, in order to randomize the feedback from the multiple discriminators, a single randomly selected discriminator provides the feedback.
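A minimal TensorFlow sketch of the two feedback modes, assuming equal weights w_k unless they are supplied, is:

```python
# Combined feedback of Eq. (3) ('mean') or feedback from one random discriminator.
import tensorflow as tf

def ensemble_feedback(d_outputs, weights=None, mode='mean'):
    K = len(d_outputs)
    if mode == 'mean':
        if weights is None:
            weights = [1.0] * K                 # assumed equal weighting
        weighted = [w * d for w, d in zip(weights, d_outputs)]
        return tf.add_n(weighted) / float(K)    # Eq. (3)
    k = tf.random.uniform([], maxval=K, dtype=tf.int32)
    return tf.gather(tf.stack(d_outputs), k)    # randomly selected discriminator
```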

4 Semi-Supervised Learning

Algorithm 1 presents the overall training procedure of our MAVEN model. In the forward pass, different real samples x fed into E and noise samples z fed into G provide different inputs for each of the multiple discriminators. In the backward pass, the combined feedback from the discriminators is computed and passed to G and E.
In the conventional image generator GAN, D works as a binary classifier—it
classifies the input image as real or generated. To facilitate the training for an n-class
classifier, D assumes the role of an (n + 1)-classifier. For multiple logit generation,
the sigmoid function is replaced by a softmax function. Now, it can receive an image
x as input and output an (n + 1)-dimensional vector of logits {l1 , . . . , ln , ln+1 }, which
are finally transformed into class probabilities for the n labels in the real data while
class (n + 1) denotes the generated data. The probability that x is real and belongs
to class 1 ≤ i ≤ n is
$$p(y = i \mid x) = \frac{\exp(l_i)}{\sum_{j=1}^{n+1} \exp(l_j)} \qquad (4)$$

while the probability that x is generated corresponds to i = n + 1 in (4). As a semi-


supervised classifier, the model takes labels only for a small portion of the training
data. It is trained via supervised learning from the labeled data, while it learns in an
unsupervised manner from the unlabeled data. The advantage comes from generating
new samples. The model learns the classifier by generating samples from different
classes.
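A minimal sketch of the (n + 1)-way discriminator output described above is given below; the first n softmax entries are the class probabilities of Eq. (4) and the last entry is the probability that the input is generated.

```python
# Split the (n+1)-class softmax of Eq. (4) into real-class and "generated" probabilities.
import tensorflow as tf

def class_probabilities(logits):
    probs = tf.nn.softmax(logits, axis=-1)     # logits: (batch, n + 1)
    return probs[:, :-1], probs[:, -1]         # p(y = i | x) for i <= n, p(y = n+1 | x)
```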

4.1 Losses

Three networks, E, G, and D, are trained on different objectives: E is trained to maximize the ELBO, G is trained to generate realistic samples, and D is trained to learn a classifier that distinguishes generated samples from real samples and assigns the real samples to their particular classes.

Algorithm 1 MAVEN training procedure.
m is the number of samples; B is the minibatch size; and K is the number of discriminators.
steps ← m/B
for each epoch do
    for each step in steps do
        for k = 1 to K do
            Sample minibatch z^(1), ..., z^(m) from p_g(z)
            Sample minibatch x^(1), ..., x^(m) from p_data(x)
            Update D_k by ascending along its gradient:
                ∇_{D_k} (1/m) Σ_{i=1}^{m} [log D_k(x_i) + log(1 − D_k(G(z_i)))]
        end for
        Sample minibatch z_k^(1), ..., z_k^(m) from p_g(z)
        if ensemble is 'mean' then
            Assign weights w_k to the D_k
            Determine the mean discriminator: D_μ = (1/K) Σ_k w_k D_k
        end if
        Update G by descending along its gradient from the ensemble D_μ:
            ∇_G (1/m) Σ_{i=1}^{m} log(1 − D_μ(G(z_i)))
        Sample minibatch x^(1), ..., x^(m) from p_data(x)
        Update E along its expectation function:
            ∇_E E_{q_λ}[log(p(z) / q_λ(z | x))]
    end for
end for

4.1.1 D Loss

Since the model is trained on both labeled and unlabeled training data, the loss
function of D includes both supervised and unsupervised losses. When the model
receives real labeled data, it is the standard supervised learning loss

$$\mathcal{L}_{D_{\text{supervised}}} = -\mathbb{E}_{x, y \sim p_{data}} \log[p(y = i \mid x)], \quad i < n + 1. \qquad (5)$$

When it receives unlabeled data from three different sources, the unsupervised loss
contains the original GAN loss for real and generated data from two different sources:

synG directly from G and synE from E via G. The three losses,

$$\mathcal{L}_{D_{\text{real}}} = -\mathbb{E}_{x \sim p_{data}} \log[1 - p(y = n + 1 \mid x)], \qquad (6)$$

$$\mathcal{L}_{D_{\text{synG}}} = -\mathbb{E}_{\hat{x} \sim G} \log[p(y = n + 1 \mid \hat{x})], \qquad (7)$$

$$\mathcal{L}_{D_{\text{synE}}} = -\mathbb{E}_{\tilde{x} \sim G} \log[p(y = n + 1 \mid \tilde{x})], \qquad (8)$$

are combined as the unsupervised loss in D:

$$\mathcal{L}_{D_{\text{unsupervised}}} = \mathcal{L}_{D_{\text{real}}} + \mathcal{L}_{D_{\text{synG}}} + \mathcal{L}_{D_{\text{synE}}}. \qquad (9)$$
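A minimal TensorFlow sketch of Eqs. (5)–(9) is shown below; it assumes an (n + 1)-way softmax head as above, with the labeled real batch, the unlabeled real batch, and the two generated batches (synG, synE) as inputs.

```python
# Supervised and unsupervised discriminator losses, Eqs. (5)-(9).
import tensorflow as tf

def d_losses(logits_lab, y_lab, logits_unl, logits_synG, logits_synE, eps=1e-8):
    def probs(logits):
        p = tf.nn.softmax(logits, axis=-1)
        return p[:, :-1], p[:, -1]                        # real-class probs, p(generated)

    p_lab, _ = probs(logits_lab)                          # y_lab: int32 labels in [0, n)
    idx = tf.stack([tf.range(tf.shape(y_lab)[0]), y_lab], axis=1)
    loss_sup = -tf.reduce_mean(tf.math.log(tf.gather_nd(p_lab, idx) + eps))   # Eq. (5)

    _, p_gen_real = probs(logits_unl)
    _, p_gen_synG = probs(logits_synG)
    _, p_gen_synE = probs(logits_synE)
    loss_unsup = (-tf.reduce_mean(tf.math.log(1.0 - p_gen_real + eps))        # Eq. (6)
                  - tf.reduce_mean(tf.math.log(p_gen_synG + eps))             # Eq. (7)
                  - tf.reduce_mean(tf.math.log(p_gen_synE + eps)))            # Eq. (8)
    return loss_sup, loss_unsup                                               # Eq. (9)
```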

4.1.2 G Loss

For G, the feature loss is used along with the original GAN loss. Activation f (x) from
an intermediate layer of D is used to match the feature between real and generated
samples. Feature matching has shown much potential in semi-supervised learning
[32]. The goal of feature matching is to encourage G to generate data that matches
real data statistics. It is natural for D to find the most discriminative features in real
data relative to data generated by the model:
$$\mathcal{L}_{G_{\text{feature}}} = \left\| \mathbb{E}_{x \sim p_{data}} f(x) - \mathbb{E}_{\hat{x} \sim G} f(\hat{x}) \right\|_2^2. \qquad (10)$$

The total G loss becomes the combined feature loss (10) plus the cost of maximizing
the log-probability of D making a mistake on the generated data (synG / synE); i.e.,

$$\mathcal{L}_G = \mathcal{L}_{G_{\text{feature}}} + \mathcal{L}_{G_{\text{synG}}} + \mathcal{L}_{G_{\text{synE}}}, \qquad (11)$$

where

$$\mathcal{L}_{G_{\text{synG}}} = -\mathbb{E}_{\hat{x} \sim G} \log[1 - p(y = n + 1 \mid \hat{x})], \qquad (12)$$

and

$$\mathcal{L}_{G_{\text{synE}}} = -\mathbb{E}_{\tilde{x} \sim G} \log[1 - p(y = n + 1 \mid \tilde{x})]. \qquad (13)$$
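A minimal TensorFlow sketch of the feature-matching term in Eq. (10), with f_real and f_fake denoting the intermediate discriminator activations on real and generated batches, is:

```python
# Feature-matching loss of Eq. (10): squared L2 distance between mean activations.
import tensorflow as tf

def feature_matching_loss(f_real, f_fake):
    mean_real = tf.reduce_mean(f_real, axis=0)
    mean_fake = tf.reduce_mean(f_fake, axis=0)
    return tf.reduce_sum(tf.square(mean_real - mean_fake))
```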

4.1.3 E Loss

In the encoder E, the maximization of ELBO is equivalent to minimizing the KL-


divergence, allowing approximate posterior inferences. Therefore the loss function
includes the KL-divergence and also a feature loss to match the features in the synE
data with the real data distribution. The loss for the encoder is

$$\mathcal{L}_E = \mathcal{L}_{E_{KL}} + \mathcal{L}_{E_{\text{feature}}}, \qquad (14)$$

where

$$\mathcal{L}_{E_{KL}} = -\,\mathrm{KL}\left[q_\lambda(z \mid x) \,\|\, p(z)\right] = \mathbb{E}_{q_\lambda(z \mid x)}\left[\log \frac{p(z)}{q_\lambda(z \mid x)}\right], \qquad (15)$$

which is approximated by sampling $z \sim q_\lambda(z \mid x)$, and

$$\mathcal{L}_{E_{\text{feature}}} = \left\| \mathbb{E}_{x \sim p_{data}} f(x) - \mathbb{E}_{\tilde{x} \sim G} f(\tilde{x}) \right\|_2^2. \qquad (16)$$

5 Experiments

Applying semi-supervised learning using training data that is only partially labeled,
we evaluated our MAVEN model in image generation and classification tasks in a
number of experiments. For all our experiments, we used 10% labeled and 90%
unlabeled training data.

5.1 Data

We employed the following four image datasets:


1. The Street View House Numbers (SVHN) dataset [24] (Fig. 4a). There are 73,257
digit images for training and 26,032 digit images for testing. Out of two versions
of the images, we used the version which has MNIST-like 32 × 32 pixel RGB
color images centered around a single digit. Each image is labeled as belonging
to one of 10 classes: digits 0–9.
2. The CIFAR-10 dataset [19] (Fig. 4b). It consists of 60,000 32 × 32 pixel RGB
color images in 10 classes. There are 50,000 training images and 10,000 test
images. Each image is labeled as belonging to one of 10 classes: plane, auto, bird,
cat, deer, dog, frog, horse, ship, and truck.
3. The anterior-posterior Chest X-Ray (CXR) dataset [17] (Fig. 4c). The dataset
contains 5,216 training and 624 test images. Each image is labeled as belonging
to one of three classes: normal, bacterial pneumonia (b-pneumonia), and viral
pneumonia (v-pneumonia).
4. The skin lesion classification (SLC) dataset (Fig. 4d). We employed 2,000 RGB
skin images from the ISIC 2017 dermoscopy image dataset [5]; of which we used
1,600 for training and 400 for testing. Each image is labeled as belonging to one
of two classes: non-melanoma and melanoma.
For the SVHN and CIFAR-10 datasets, the images were normalized and provided
to the models in their original (32 × 32 × 3) pixel sizes. For the CXR dataset, the
images were normalized and resized to 128 × 128 × 1 pixels. For the SLC dataset,
the images were resized to 128 × 128 × 3 pixels.

Fig. 4 Example images of each class in the four datasets

5.2 Implementation Details

To compare the image generation and multiclass classification performance of


our MAVEN model, we used two baselines, the Deep Convolutional GAN (DC-
GAN) [29] and the VAE-GAN. The same generator and discriminator architectures
were used for DC-GAN and MAVEN models and the same encoder was used for the
VAE-GAN and MAVEN models. For our MAVENs, we experimented with 2, 3, and
5 discriminators. In addition to using the mean feedback of the multiple discrimi-
nators, we also experimented with feedback from a randomly selected discrimina-
tor. The six MAVEN variants are therefore denoted MAVEN-m2D, MAVEN-m3D,
MAVEN-m5D, MAVEN-r2D, MAVEN-r3D, and MAVEN-r5D, where “m” indicates
mean feedback while “r” indicates random feedback.

All the models were implemented in TensorFlow and run on a single Nvidia Titan GTX (12 GB) GPU. For the discriminator, after every convolutional layer, a dropout layer was added with a dropout rate of 0.4. For all the models, we consistently used the Adam optimizer with a learning rate of 2.0 × 10⁻⁴ for G and D, and 1.0 × 10⁻⁵ for E, with a momentum of 0.9. All the convolutional layers were followed by batch normalizations. Leaky ReLU activations were used with α = 0.2.
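For illustration, a minimal Keras sketch of one discriminator convolution block following these settings is given below; the filter count, kernel size, and strides are assumptions.

```python
# One discriminator block: convolution, batch normalization, LeakyReLU (0.2), dropout (0.4).
import tensorflow as tf
from tensorflow.keras import layers

def disc_block(x, filters):
    x = layers.Conv2D(filters, kernel_size=4, strides=2, padding='same')(x)
    x = layers.BatchNormalization()(x)
    x = layers.LeakyReLU(alpha=0.2)(x)
    return layers.Dropout(0.4)(x)
```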

5.3 Evaluation

5.3.1 Image Generation Performance Metrics

There are no perfect performance metrics for measuring the quality of generated sam-
ples. However, to assess the quality of the generated images, we employed the widely
used Fréchet Inception Distance (FID) [12] and a simplified version of the Descrip-
tive Distribution Distance (DDD) [14]. To measure the Fréchet distance between two
multivariate Gaussians, the generated samples and real data samples are compared
through their distribution statistics:
$$\mathrm{FID} = \left\| \mu_{data} - \mu_{syn} \right\|^2 + \mathrm{Tr}\left(\Sigma_{data} + \Sigma_{syn} - 2\left(\Sigma_{data}\Sigma_{syn}\right)^{1/2}\right). \qquad (17)$$

Two distribution samples, $X_{data} \sim \mathcal{N}(\mu_{data}, \Sigma_{data})$ and $X_{syn} \sim \mathcal{N}(\mu_{syn}, \Sigma_{syn})$, for real and model data, respectively, are calculated from the 2,048-dimensional activations of the pool3 layer of Inception-v3 [32]. DDD measures the closeness of
a generated data distribution to a real data distribution by comparing descriptive
parameters from the two distributions. We propose a simplified version based on the
first four moments of the distributions, computed as the weighted sum of normalized
differences of moments, as follows:


$$\mathrm{DDD} = -\log \sum_{i=1}^{4} w_i \left| \mu_{data_i} - \mu_{syn_i} \right|, \qquad (18)$$

where the μdatai are the moments of the data distribution, the μsyni are the moments of
the model distribution, and the wi are the corresponding weights found in an exhaus-
tive search. The higher order moments are weighted more in order to emphasize the
stability of a distribution. For both the FID and DDD, lower scores are better.
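A minimal NumPy/SciPy sketch of the simplified DDD in Eq. (18) is shown below; it assumes the four moments are the mean, variance, skewness, and kurtosis, uses absolute moment differences, and takes placeholder weights, since the chapter determines the actual weights by exhaustive search.

```python
# Simplified Descriptive Distribution Distance of Eq. (18).
import numpy as np
from scipy import stats

def ddd(real, generated, weights=(0.1, 0.2, 0.3, 0.4)):
    def moments(x):
        x = np.asarray(x).ravel()
        return np.array([x.mean(), x.var(), stats.skew(x), stats.kurtosis(x)])
    diff = np.abs(moments(real) - moments(generated))
    return -np.log(np.sum(np.array(weights) * diff))
```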

5.3.2 Image Classification Performance Metrics

To evaluate model performance in classification, we used two measures, image-level


classification accuracy and class-wise F1 scoring. The F1 score is

$$F1 = \frac{2 \times \text{precision} \times \text{recall}}{\text{precision} + \text{recall}}, \qquad (19)$$

with
$$\text{precision} = \frac{TP}{TP + FP} \quad \text{and} \quad \text{recall} = \frac{TP}{TP + FN}, \qquad (20)$$

where TP, FP, and FN are the number of true positives, false positives, and false
negatives, respectively.

5.4 Results

We measured the image classification performances of the models with cross-


validation and in the following sections report the average scores from running each
model 10 times.

5.4.1 SVHN

For the SVHN dataset, we randomly selected 7,326 labeled images and they along
with the remaining 65,931 unlabeled images were provided to the network as training
data. All the models were trained for 300 epochs and then evaluated. We generated
new images equal in number to the training set size. Figure 5 presents a visual
comparison of a random selection of images generated by the DC-GAN, VAE-GAN,
and MAVEN models and real training images. Figure 6 compares the image intensity
histograms of 10K randomly sampled real images and equally many images sampled
from among those generated by each of the different models.
Generally speaking, our MAVEN models generate images that are more realistic
than those generated by the DC-GAN and VAE-GAN models. This was further
corroborated by randomly sampling 10K generated images and 10K real images. The
generated image quality measurement was performed for the eight different models.
Table 1 reports the resulting FID and DDD scores. For the FID score calculation, the
score is reported after running the pre-trained Inception-v3 network for 20 epochs
for each model. The MAVEN-r3D model achieved the best FID score and the best
DDD score was achieved by the MAVEN-m5D model.
Table 2 compares the classification performance of all the models for the SVHN
dataset. The MAVEN model consistently outperformed the DC-GAN and VAE-GAN
classifiers both in classification accuracy and class-wise F1 scores. Among all the
models, our MAVEN-m2D and MAVEN-m3D models were the most accurate.

Fig. 5 Visual comparison of image samples from the SVHN dataset against those generated by the
different models

Fig. 6 Histograms of the real SVHN training data, and of the data generated by the DC-GAN and VAE-GAN models and by our MAVEN models with mean and random feedback from 2, 3, and 5 discriminators
Table 1 Minimum FID and DDD scores achieved by the DC-GAN, VAE-GAN, and MAVEN models for the CIFAR-10, SVHN, CXR, and SLC datasets

              CIFAR-10                 SVHN                     CXR                       SLC
Model         FID             DDD      FID             DDD      FID              DDD      FID            DDD
DC-GAN        61.293±0.209    0.265    16.789±0.303    0.343    152.511±0.370    0.145    1.828±0.370    0.795
VAE-GAN       15.511±0.125    0.224    13.252±0.001    0.329    141.422±0.580    0.107    1.828±0.580    0.795
MAVEN-m2D     12.743±0.242    0.223    11.675±0.001    0.309    141.339±0.420    0.138    1.874±0.270    0.802
MAVEN-m3D     11.316±0.808    0.190    11.515±0.065    0.300    140.865±0.983    0.018    0.304±0.018    0.249
MAVEN-m5D     12.123±0.140    0.207    10.909±0.001    0.294    147.316±1.169    0.100    1.518±0.190    0.793
MAVEN-r2D     12.820±0.584    0.194    11.384±0.001    0.316    154.501±0.345    0.038    1.505±0.130    0.789
MAVEN-r3D     12.620±0.001    0.202    10.791±0.029    0.357    158.749±0.297    0.179    0.336±0.080    0.783
MAVEN-r5D     18.509±0.001    0.215    11.052±0.751    0.323    152.778±1.254    0.180    1.812±0.014    0.795
DO-GAN [23]   88.60±0.08      –
TTUR [12]     36.9            –
C-GAN [34]    27.300          –
AIQN [28]     49.500          –
SN-GAN [22]   21.700          –
LM [30]       18.9            –

Table 2 Average cross-validation accuracy and class-wise F1 scores in the semi-supervised clas-
sification performance comparison of the DC-GAN, VAE-GAN, and MAVEN models using the
SVHN dataset
Model Accuracy F1 scores
0 1 2 3 4 5 6 7 8 9
DC-GAN 0.876 0.860 0.920 0.890 0.840 0.890 0.870 0.830 0.890 0.820 0.840
VAE-GAN 0.901 0.900 0.940 0.930 0.860 0.920 0.900 0.860 0.910 0.840 0.850
MAVEN-m2D 0.909 0.890 0.930 0.940 0.890 0.930 0.900 0.870 0.910 0.870 0.890
MAVEN-m3D 0.909 0.910 0.940 0.940 0.870 0.920 0.890 0.870 0.920 0.870 0.860
MAVEN-m5D 0.905 0.910 0.930 0.930 0.870 0.930 0.900 0.860 0.910 0.860 0.870
MAVEN-r2D 0.905 0.910 0.930 0.940 0.870 0.930 0.890 0.860 0.920 0.850 0.860
MAVEN-r3D 0.907 0.890 0.910 0.920 0.870 0.900 0.870 0.860 0.900 0.870 0.890
MAVEN-r5D 0.903 0.910 0.930 0.940 0.860 0.910 0.890 0.870 0.920 0.850 0.870

5.4.2 CIFAR-10

For the CIFAR-10 dataset, we used 50 K training images, only 5 K of them labeled.
All the models were trained for 300 epochs and then evaluated. We generated new
images equal in number to the training set size. Figure 7 visually compares a random
selection of images generated by the DC-GAN, VAE-GAN, and MAVEN models
and real training images. Figure 8 compares the image intensity histograms of 10K
randomly sampled real images and equally many images sampled from among those
generated by each of the different models. Table 1 reports the FID and DDD scores.
As the tabulated results suggest, our MAVEN models achieved better FID scores than
some of the recently published models. Note that those models were implemented
in different settings.
As for the visual comparison, the FID and DDD scores confirmed more realistic
image generation by our MAVEN models compared to the DC-GAN and VAE-GAN
models. The MAVEN models have smaller FID scores, except for MAVEN-r5D.
MAVEN-m3D has the smallest FID and DDD scores among all the models.
Table 3 compares the classification performance of all the models with the CIFAR-
10 dataset. All the MAVEN models performed better than the DC-GAN and VAE-
GAN models. In particular, MAVEN-m5D achieved the best classification accuracy
and F1 scores.

5.4.3 CXR

With the CXR dataset, we used 522 labeled images and 4,694 unlabeled images. All
the models were trained for 150 epochs and then evaluated. We generated new images equal in number to the training set size. Figure 9 presents a visual comparison
of a random selection of generated and real images. The FID and DDD measurements were performed for the distributions of generated and real training samples, indicating that more realistic images were generated by the MAVEN models than by the GAN and VAE-GAN models. The FID and DDD scores presented in Table 1 show that the mean MAVEN-m3D model has the smallest FID and DDD scores.

Fig. 7 Visual comparison of image samples from the CIFAR-10 dataset against those generated by the different models

Fig. 8 Histograms of the real CIFAR-10 training data, and of the data generated by the DC-GAN and VAE-GAN models and by our MAVEN models with mean and random feedback from 2, 3, and 5 discriminators

Table 3 Average cross-validation accuracy and class-wise F1 scores in the semi-supervised classification performance comparison of the DC-GAN, VAE-GAN, and MAVEN models using the CIFAR-10 dataset

Model        Accuracy   F1 scores
                        Plane   Auto    Bird    Cat     Deer    Dog     Frog    Horse   Ship    Truck
DC-GAN       0.713      0.760   0.840   0.560   0.510   0.660   0.590   0.780   0.780   0.810   0.810
VAE-GAN      0.743      0.770   0.850   0.640   0.560   0.690   0.620   0.820   0.770   0.860   0.830
MAVEN-m2D    0.761      0.800   0.860   0.650   0.590   0.750   0.680   0.810   0.780   0.850   0.850
MAVEN-m3D    0.759      0.770   0.860   0.670   0.580   0.700   0.690   0.800   0.810   0.870   0.830
MAVEN-m5D    0.771      0.800   0.860   0.650   0.610   0.710   0.640   0.810   0.790   0.880   0.820
MAVEN-r2D    0.757      0.780   0.860   0.650   0.530   0.720   0.650   0.810   0.800   0.870   0.860
MAVEN-r3D    0.756      0.780   0.860   0.640   0.580   0.720   0.650   0.830   0.800   0.870   0.830
MAVEN-r5D    0.762      0.810   0.850   0.680   0.600   0.720   0.660   0.840   0.800   0.850   0.820
The classification performance reported in Table 4 suggests that our MAVEN
model-based classifiers are more accurate than the baseline GAN and VAE-GAN
classifiers. Among all the models, the MAVEN-m3D classifier was the most accurate.

5.4.4 SLC

For the SLC dataset, we used 160 labeled images and 1,440 unlabeled images. All the
models were trained for 150 epochs and then evaluated. We generated new images
equal in number to the training set size. Figure 10 presents a visual comparison of
randomly selected generated and real image samples.
The FID and DDD measurements for the distributions of generated and real train-
ing samples indicate that more realistic images were generated by the MAVEN mod-
els than by the GAN and VAE-GAN models. The FID and DDD scores presented
in Table 1 show that the mean MAVEN-m3D model has the smallest FID and DDD
scores.
The classification performance reported in Table 5 suggests that our MAVEN
model-based classifiers are more accurate than the baseline GAN and VAE-GAN
classifiers. Among all the models, MAVEN-r3D is the most accurate in discriminating
between non-melanoma and melanoma lesion images.

Fig. 9 Visual comparison of image samples from the CXR dataset against those generated by the
different models
Table 4 Average cross-validation accuracy and class-wise F1 scores for the semi-supervised clas-
sification performance comparison of the DC-GAN, VAE-GAN, and MAVEN models using the
CXR dataset
Model Accuracy F1 scores
Normal B-Pneumonia V-Pneumonia
DC-GAN 0.461 0.300 0.520 0.480
VAE-GAN 0.467 0.220 0.640 0.300
MAVEN-m2D 0.469 0.310 0.620 0.260
MAVEN-m3D 0.525 0.640 0.480 0.480
MAVEN-m5D 0.477 0.380 0.480 0.540
MAVEN-r2D 0.478 0.280 0.630 0.310
MAVEN-r3D 0.506 0.440 0.630 0.220
MAVEN-r5D 0.483 0.170 0.640 0.240

Fig. 10 Visual comparison of image samples from the SLC dataset against those generated by the
different models

Table 5 Average cross-validation accuracy and class-wise F1 scores for the semi-supervised clas-
sification performance comparison of the DC-GAN, VAE-GAN, and MAVEN models using the
SLC dataset
Model Accuracy F1 scores
Non-melanoma Melanoma
DC-GAN 0.802 0.890 0.120
VAE-GAN 0.810 0.890 0.012
MAVEN-m2D 0.815 0.900 0.016
MAVEN-m3D 0.814 0.900 0.110
MAVEN-m5D 0.812 0.900 0.140
MAVEN-r2D 0.808 0.890 0.260
MAVEN-r3D 0.821 0.900 0.020
MAVEN-r5D 0.797 0.890 0.040

6 Conclusions

We have introduced a novel generative modeling approach, called Multi-Adversarial


Variational autoEncoder Networks, or MAVENs, which demonstrates the advantage
of an ensemble of discriminators in the adversarial learning of variational autoen-
coders. We have shown that training our MAVEN models on small, labeled datasets
and allowing them to leverage large numbers of unlabeled training examples enables
them to achieve superior performance relative to prior GAN and VAE-GAN-based
classifiers, suggesting that MAVENs can be very effective in simultaneously gen-
erating high-quality realistic images and improving multiclass image classification
performance. Furthermore, unlike conventional GAN-based semi-supervised clas-
sification, improvements in the classification of natural and medical images do not
compromise the quality of the generated images. Future work with MAVENs should
explore more complex image analysis tasks beyond classification and include more
extensive experimentation spanning additional domains.

References

1. M. Arjovsky, S. Chintala, L. Bottou, Wasserstein GAN (2017). arXiv preprint


arXiv:1701.07875
2. C. Bermudez, A.J. Plassard, L.T. Davis, A.T. Newton, S.M. Resnick, B.A. Landman, Learning
implicit brain MRI manifolds with deep learning, in Medical Imaging 2018: Image Processing,
vol. 10574 (2018), p. 105741L
3. F. Calimeri, A. Marzullo, C. Stamile, G. Terracina, Biomedical data augmentation using gen-
erative adversarial neural networks, in International Conference on Artificial Neural Networks
(2017), pp. 626–634
4. M.J. Chuquicusma, S. Hussein, J. Burt, U. Bagci, How to fool radiologists with generative
adversarial networks? A visual turing test for lung cancer diagnosis, in IEEE International
Symposium on Biomedical Imaging (ISBI) (2018), pp. 240–244

5. N.C. Codella, D. Gutman, M.E. Celebi, B. Helba, M.A. Marchetti, S.W. Dusza, A. Kalloo,
K. Liopyris, N. Mishra, H. Kittler et al., Skin lesion analysis toward melanoma detection: a
challenge at the 2017 ISBI, hosted by ISIC, in IEEE International Symposium on Biomedical
Imaging (ISBI 2018) (2018), pp. 168–172
6. E.L. Denton, S. Chintala, A. Szlam, R. Fergus, Deep generative image models using a Lapla-
cian pyramid of adversarial networks, in Advances in Neural Information Processing Systems
(NeurIPS) (2015)
7. I. Durugkar, I. Gemp, S. Mahadevan, Generative multi-adversarial networks (2016). arXiv
preprint arXiv:1611.01673
8. M. Frid-Adar, I. Diamant, E. Klang, M. Amitai, J. Goldberger, H. Greenspan, GAN-based syn-
thetic medical image augmentation for increased CNN performance in liver lesion classification
(2018). arXiv preprint arXiv:1803.01229
9. I. Goodfellow, J. Pouget-Abadie, M. Mirza, B. Xu, D. Warde-Farley, S. Ozair, A. Courville, Y.
Bengio, Generative adversarial nets, in Advances in Neural Information Processing Systems
(NeurIPS) (2014), pp. 2672–2680
10. J.T. Guibas, T.S. Virdi, P.S. Li, Synthetic medical images from dual generative adversarial
networks (2017). arXiv preprint arXiv:1709.01872
11. C. Han, H. Hayashi, L. Rundo, R. Araki, W. Shimoda, S. Muramatsu, Y. Furukawa, G. Mauri,
H. Nakayama, GAN-based synthetic brain MR image generation, in IEEE International Sym-
posium on Biomedical Imaging (ISBI) (2018), pp. 734–738
12. M. Heusel, H. Ramsauer, T. Unterthiner, B. Nessler, S. Hochreiter, GANs trained by a two
time-scale update rule converge to a local Nash equilibrium, in Advances in Neural Information
Processing Systems (NeurIPS) (2017), pp. 6626–6637
13. L. Hou, A. Agarwal, D. Samaras, T.M. Kurc, R.R. Gupta, J.H. Saltz, Unsupervised histopathol-
ogy image synthesis (2017). arXiv preprint arXiv:1712.05021
14. A.A.Z. Imran, P.R. Bakic, A.D. Maidment, D.D. Pokrajac, Optimization of the simulation
parameters for improving realism in anthropomorphic breast phantoms, in Proceedings of the
SPIE, vol. 10132 (2017)
15. A.A.Z. Imran, D. Terzopoulos, Multi-adversarial variational autoencoder networks, in IEEE
International Conference on Machine Learning and Applications (ICMLA) (Boca Raton, FL,
2019), pp. 777–782
16. A.A.Z. Imran, D. Terzopoulos, Semi-supervised multi-task learning with chest X-ray images
(2019). arXiv preprint arXiv:1908.03693
17. D.S. Kermany, M. Goldbaum, W. Cai, C.C. Valentim, H. Liang, S.L. Baxter, A. McKeown, G.
Yang, X. Wu, F. Yan et al., Identifying medical diagnoses and treatable diseases by image-based
deep learning. Cell 172(5), 1122–1131 (2018)
18. D.P. Kingma, M. Welling, Auto-encoding variational Bayes (2013). arXiv preprint
arXiv:1312.6114
19. A. Krizhevsky, Learning multiple layers of features from tiny images. Master’s thesis, Univer-
sity of Toronto, Dept. of Computer Science (2009)
20. A. Madani, M. Moradi, A. Karargyris, T. Syeda-Mahmood, Semi-supervised learning with
generative adversarial networks for chest X-ray classification with ability of data domain adap-
tation, in IEEE International Symposium on Biomedical Imaging (ISBI) (2018), pp. 1038–1042
21. A. Makhzani, J. Shlens, N. Jaitly, I. Goodfellow, B. Frey, Adversarial autoencoders (2015).
arXiv preprint arXiv:1511.05644
22. T. Miyato, T. Kataoka, M. Koyama, Y. Yoshida, Spectral normalization for generative adver-
sarial networks (2018). arXiv preprint arXiv:1802.05957
23. G. Mordido, H. Yang, C. Meinel, Dropout-GAN: learning from a dynamic ensemble of dis-
criminators (2018). arXiv preprint arXiv:1807.11346
24. Y. Netzer, T. Wang, A. Coates, A. Bissacco, B. Wu, A.Y. Ng, Reading digits in natural images
with unsupervised feature learning, in NIPS Workshop on Deep Learning and Unsupervised
Feature Learning, vol. 2011 (2011), pp. 1–9
25. B. Neyshabur, S. Bhojanapalli, A. Chakrabarti, Stabilizing GAN training with multiple random
projections (2017). arXiv preprint arXiv:1705.07831

26. T. Nguyen, T. Le, H. Vu, D. Phung, Dual discriminator generative adversarial nets, in Advances
in Neural Information Processing Systems (NeurIPS) (2017), pp. 2670–2680
27. A. Odena, C. Olah, J. Shlens, Conditional image synthesis with auxiliary classifier GANs (2017).
arXiv preprint arXiv:1610.09585
28. G. Ostrovski, W. Dabney, R. Munos, Autoregressive quantile networks for generative modeling
(2018). arXiv preprint arXiv:1806.05575
29. A. Radford, L. Metz, S. Chintala, Unsupervised representation learning with deep convolutional
generative adversarial networks (2015). arXiv preprint arXiv:1511.06434
30. S. Ravuri, S. Mohamed, M. Rosca, O. Vinyals, Learning implicit generative models with the
method of learned moments (2018). arXiv preprint arXiv:1806.11006
31. H. Salehinejad, S. Valaee, T. Dowdell, E. Colak, J. Barfett, Generalization of deep neural
networks for chest pathology classification in X-rays using generative adversarial networks, in
IEEE International Conference on Acoustics, Speech and Signal Processing (ICASSP) (2018),
pp. 990–994
32. T. Salimans, I. Goodfellow, W. Zaremba, V. Cheung, A. Radford, X. Chen, Improved techniques
for training GANs, in Advances in Neural Information Processing Systems (NeurIPS) (2016),
pp. 2234–2242
33. J.T. Springenberg, Unsupervised and semi-supervised learning with categorical generative
adversarial networks (2015). arXiv preprint arXiv:1511.06390
34. T. Unterthiner, B. Nessler, C. Seward, G. Klambauer, M. Heusel, H. Ramsauer, S. Hochreiter,
Coulomb GANs: provably optimal nash equilibria via potential fields (2017). arXiv preprint
arXiv:1708.08819
35. S. Wang, L. Zhang, CatGAN: coupled adversarial transfer for domain generation (2017). arXiv
preprint arXiv:1711.08904
36. M.A. Wani, F.A. Bhat, S. Afzal, A.I. Khan (eds.), Advances in Deep Learning (Springer, 2020)
Non-convex Optimization Using
Parameter Continuation Methods for
Deep Neural Networks

Harsh Nilesh Pathak and Randy Clinton Paffenroth

Abstract Numerical parameter continuation methods are widely used to optimize non-convex problems. They have a long history of applications in physics and mathematical analysis, such as the bifurcation study of dynamical systems. However, as far as we know, such efficient methods have seen relatively limited use in the optimization of neural networks. In this chapter, we propose a novel training method for deep neural networks based on ideas from parameter continuation methods and compare it with widely practiced methods such as Stochastic Gradient Descent (SGD), AdaGrad, RMSProp and ADAM. Transfer and curriculum learning have recently shown exceptional performance enhancements in deep learning and are intuitively similar to homotopy or continuation techniques; however, our proposed methods leverage decades of theoretical and computational work and can be viewed as an initial bridge between those techniques and deep neural networks. In particular, we illustrate a method that we call Natural Parameter Adaptive Continuation with Secant approximation (NPACS). Herein we transform regularly used activation functions into their homotopic versions. Such a formulation allows one to decompose the complex optimization problem into a sequence of problems, each of which is provided with a good initial guess based upon the solution of the previous problem. NPACS combines this scheme with ADAM to obtain faster convergence. We demonstrate the effectiveness of our method on standard benchmark problems, computing local minima more rapidly and achieving lower generalization error than contemporary techniques in a majority of cases.

H. Nilesh Pathak (B)


Expedia Group, 1111 Expedia Group Way W, Seattle, WA 98119, USA
e-mail: hpathak@expedia.com
R. Clinton Paffenroth
Worcester Polytechnic Institute, Mathematical Sciences Computer Science & Data Science,
100 Institute Rd, Worcester, MA 01609, USA
e-mail: rcpaffenroth@wpi.edu

1 Introduction

In many machine learning and deep learning problems, the key task is an optimization
problem, with the objective to learn useful properties of the data given a model. The
parameters of the model are used to learn linear or non-linear features of the data
and can be used to perform inference and prediction. In this chapter, we will study a challenging non-convex optimization task, namely the ability of deep learning models to learn non-linear or complex representations from data.
This chapter is an extension of the conference paper [52] with additional results,
additional context for the theory of the algorithms, a detailed literature survey listing
some common limitations of curriculum strategies, and discussions on open research
questions.
Often, the objective of a deep learning model is to approximate a function f_true, the true function that maps inputs x to targets y, such that y = f_true(x) [15, 50]. More formally, a deep neural network defines a mapping y = f(x; θ), where θ is the set of parameters. These parameters θ are estimated by minimizing an objective function, such that the mapping f best approximates f_true. However, training a deep neural network to find a good solution is a challenging optimization task [3, 17, 44], where by a good solution we mean achieving low generalization error in few training steps. Even with state-of-the-art techniques, machines configured with a large number of processing units and high memory can spend days, or even longer, to solve such problems [50].
Neural networks use a composition of functions which, generally speaking, are non-convex and difficult to optimize [3, 8]. Even deep linear networks are non-convex in parameter space [17]. Deep learning research has advanced rapidly in the past decade, and quality optimization algorithms have been proposed for such non-convex problems: Stochastic Gradient Descent (SGD) [15], RMSProp [25], AdaGrad [12] and ADAM [35] are widely practiced by the deep learning community and have significantly advanced the state of the art in text mining [43], super-resolution [16, 40, 51], image recognition [38], speech recognition [24] and many more. However, their success usually depends on the choice of hyperparameters and the quality of the initialization [60]. Previously, researchers have shown that different initialization methods may lead to dramatically different solution curve geometries on a cost surface [8, 17, 27]. In other words, a small shift in the initialization may lead the model to converge to a different minimum [50]. In this chapter, we attempt to improve the solution of such non-convex optimization problems by rethinking the usual approach. In particular, many of the current state-of-the-art optimization techniques work on a fixed loss surface. We instead propose to transform the loss surface continuously in order to design an effective training strategy.
The contributions of this chapter are the following. We derive a homotopy formulation of common activation functions that implicitly decomposes the deep neural network optimization problem into a sequence of problems, which then enables one to design and apply parameter continuation methods to rapidly converge to a better solution. Further, we develop a set of Natural Parameter Continuation (NPC) [52] techniques, apply them to neural networks, and empirically observe improvements in generalization performance compared to standard optimization techniques. After performing a curated list of experiments, we observed that a naive NPC method requires careful tuning depending on the network hyperparameters, data and activation functions [52]. Accordingly, we designed an enhanced strategy to adaptively adjust the continuation step size during training, which we call Natural Parameter Adaptive Continuation (NPAC). As motivated by the continuation literature, Natural Parameter Adaptive Continuation with Secant approximation (NPACS) is our prime contribution; herein we improve the quality of initialization and speed of convergence by employing the secant method. Further, we perform a set of experiments demonstrating that continuation methods achieve lower generalization and training loss, which implies faster learning compared to conventional optimization methods. In some of our experiments, we observed the standard optimization methods trapped at a local minimum which our proposed methods swiftly surpassed. Finally, we discuss open research problems and provide insights into how continuation methods can serve as an efficient strategy for optimizing and analyzing neural networks. These ideas are demonstrated using benchmarks and extensive empirical evaluations from multiple sources.
The remainder of this chapter is structured as follows: Sect. 2 provides a preliminary discussion of homotopies and related ideas. Section 2.2 provides classical theory that connects to our work via the implicit function theorem. Section 3 covers related work; in particular, we discuss advantages and limitations of techniques that use activation functions as a medium of continuation or curriculum. Section 4 describes our methodology and approach, including a detailed description of our novel optimization and initialization strategy. In Sect. 5 we discuss our experimental results. Additionally, we present a discussion of open problems in Sect. 6, which could help open new paths for deep learning research. Finally, Sect. 7 concludes the chapter. In this extended chapter, Sects. 2, 6 and 4.3 are newly added. Moreover, additional insights and results are added in Sects. 3, 4.6 and 5.

2 Background

Non-convex optimization is a challenging task which arises in almost all deep learning models [15, 52]. A historically effective class of algorithms for solving such non-convex problems are numerical continuation methods [2]. Continuation methods can be utilized to organize the training of a neural network and assist in improving the quality of an initial guess that may accelerate the network convergence [52]. The fundamental idea is to start from an easier version of the desired problem and gradually transform it into the original version. Formally, this may be described as follows: given a loss function J(θ), define a sequence of loss functions {J^(0)(θ), J^(1)(θ), J^(2)(θ), J^(3)(θ), ..., J^(n)(θ)} such that J^(i)(θ) is easier to optimize than J^(i+1)(θ) [15]. Here J^(i)(θ) and J^(i+1)(θ) are sufficiently similar, so that the solution of J^(i)(θ) can be used as an initial estimate for the parameters of J^(i+1)(θ) [2, 11, 50].
Based upon a literature survey [3, 44, 49, 50], there appear to be two main thrusts
for improving optimization by continuation:
• How to choose the simplified problem?
• How to transform the simplified problem into the original task?
In this chapter, we provide an approach that addresses both of these challenges. We also observe that the results in this chapter build upon results in several disparate domains, and good entryways into the literature include [15, 26] for Autoencoders (AEs), [2, 34] for continuation methods and [3] for curriculum learning.

2.1 Homotopies

One possible way to incorporate continuation methods in the context of neural networks is to determine homotopy [2] formulations of activation functions [52], defined in the following manner:

h(x, λ) = (1 − λ) · h_1(x) + λ · h_2(x)    (1)

where λ ∈ [0, 1] is known as a homotopy parameter. Herein, we consider h_1(x) to be some simple activation function and h_2(x) to be some more complex activation function. For example, with h_1(x) := x and h_2(x) := 1/(1 + e^{−x}), h(x, λ) provides a continuous transformation between the identity and the sigmoid. While a seemingly simple idea, such methods have a long history and have made substantial impacts in other contexts, for example, the solution of boundary value problems in ODEs [2, 11]. As we will detail in the next section, when such ideas are applied to neural networks, one can transform the minimum of a simple neural network into a minimum of a much more complicated neural network [52].
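For concreteness, the short sketch below (plain NumPy, illustrative only) evaluates the homotopy of Eq. (1) with h_1(x) = x and h_2(x) the logistic sigmoid, and checks that the two extremes λ = 0 and λ = 1 recover the simple and the complex activation, respectively.

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def homotopy(x, lam, h1=lambda v: v, h2=sigmoid):
    """Eq. (1): a convex combination of a simple and a complex activation."""
    return (1.0 - lam) * h1(x) + lam * h2(x)

x = np.linspace(-10.0, 10.0, 5)
print(np.allclose(homotopy(x, 0.0), x))           # True: the simple (linear) problem
print(np.allclose(homotopy(x, 1.0), sigmoid(x)))  # True: the original problem
print(homotopy(x, 0.5))                           # an intermediate activation on the path
```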

2.2 The Implicit Function Theorem

The Implicit Function Theorem (IFT) is an important tool for the theoretical understanding of continuation methods [34]. While a full treatment of these ideas is beyond the scope of this chapter, in this section we provide a measure of intuition for the interested reader to help bridge the gap between continuation methods and deep learning. Let us begin with the objective of a simple AE [34]:

J(X, g(f(X)); θ) = 0    (2)

G(X; θ) = 0    (3)
where X ∈ ℝ^{m×n} is the input data, f is the encoder function, g is the decoder function and θ ∈ ℝ^N are the network parameters. We want the objective to be zero, for example in the case of mean squared error. We can then represent this objective as a function G(X; θ), which can be used in a continuation method framework [34].
Note that in the above equation X is not a parameter; it is fixed by our choice of training data. Accordingly, when we 'count parameters' below, only θ serves as degrees of freedom. Now, say we add another parameter λ, which we call a homotopy parameter, to the system of neural network equations. In our work, this single parameter controls the transformation of the problem from a simple problem to a hard problem, and such homotopy parameters will be the focus of Sect. 4. We rewrite the system to get

G(θ, λ) = 0    (4)

Now (4) can be directly related to the system of equations in [34], for which the implicit function theorem can be applied under certain assumptions.1 The IFT then guarantees that, in the neighbourhood of a known solution where the derivative is non-singular, a smooth and continuous homotopy path must exist. The interested reader can see [2, 34] for more details.
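As a sketch of the standard statement (in the notation above; see [2, 34] for the precise hypotheses): if a root of (4) is known and the Jacobian with respect to θ is non-singular there, the IFT gives a locally unique smooth branch θ(λ), and differentiating G(θ(λ), λ) = 0 gives the tangent that continuation methods follow.

```latex
G(\theta_0,\lambda_0)=0,\quad G_\theta(\theta_0,\lambda_0)\ \text{non-singular}
\;\Longrightarrow\;
\exists\,\theta(\lambda)\ \text{smooth near}\ \lambda_0:\;
\theta(\lambda_0)=\theta_0,\quad G(\theta(\lambda),\lambda)=0,
\qquad
\frac{d\theta}{d\lambda} = -\,G_\theta(\theta,\lambda)^{-1}\,G_\lambda(\theta,\lambda).
```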

3 Related Work

Optimization techniques such as SGD or ADAM play a crucial role in the quality of convergence. Finding better initializations, in order to reach a superior local minimum, is also popular in the research community. Here we draw a few observations from popular research works that have demonstrated strong performance. Unsupervised pre-training [26] is one such strategy to train neural networks for supervised and unsupervised learning tasks. The idea is to use an Autoencoder (AE) to greedily train the hidden layers one layer at a time. This method has two advantages. First, it provides a more stable initialization than random initialization, as the model has already learned properties of the data. Second, it increases the regularization effect [15, 50]. Fine-tuning may be seen as an evolution of the above-mentioned approach, where the learned layers are trained further, or fine-tuned, for the final task [15, 39]. In transfer learning, one is provided with two distinct tasks, such that learning from one distribution can be transferred to another. This technique is broadly used in vision and language tasks such as recommendation systems, image super-resolution, etc. In summary, these ideas suggest that incorporating prior knowledge into the network, or taking a few simple learning steps first, can enhance the final training performance [13, 17].

1 Assumptions:
• G : ℝ^N × ℝ → ℝ^N is a smooth map
• ||G(θ_0, λ_0)|| ≤ c
• the Jacobian G_θ is non-singular at the known root (θ_0, λ_0)
See [34] for the IFT theorem and proofs for local continuation.
The standard work closest to the research in this chapter is curriculum learning [3, 23]. Smoothing [22, 44] is a common approach for changing the qualitative behaviour of activation functions over the course of training, and it has been adopted by many researchers in distinct ways [5, 21, 22, 44, 50]. Smoothing can be seen as a convex envelope of the non-convex problem [45] and may reveal the global curvature of the loss function [3, 44]. This smoothing can then be continuously decreased to get a sequence of loss functions with increasing complexity: the smoothed objective function is minimized first, and the smoothing is progressively reduced as training proceeds to obtain the terminal solution [3, 44].
Next, Mollifying Networks added normally distributed noise to the weights of each layer of the neural network [22] and thus mollified the objective function. Again, this noise is steadily reduced as training proceeds. Linearization of activation functions is another way of achieving the mollification effect [5, 21, 22]. Similarly, gradually deforming from a linear to a non-linear representation has been implemented with different noise injection techniques [21, 22, 44].
Initialization is one of the main advantages of continuation methods [2, 34]. In many deep learning applications, random initialization is widely practiced. Continuation methods, however, suggest starting with a problem whose solution is easily available and then incrementally advancing towards the problem one actually wants to solve. Naturally, this can save computational cost: while other methods may require many iterations to attain a reasonable solution, in the paradigm of parameter continuation the initialization can usually be deterministic, and we then progressively solve a sequence of problems to obtain the final solution. For example, in [44] one performs accurate initialization with almost no computational cost. Further advantages were discussed for diffusion methods [44], where the authors conceptually showed that popular techniques such as dropout [58], layerwise pre-training [26, 39], annealed learning rates, careful initialization [60] and noise injection (directly to the loss or the weights of the network) [3] can naturally arise from continuation methods.
We discuss a few limitations and future research directions in Sect. 6. In comparison to previous approaches that add noise to activation functions [21, 22] to serve as a proxy for a continuation method, we derive homotopies to obtain linearization, which we discuss in Sect. 4. Our intent is to simplify the original task and start with a good initialization. Despite the above advances, there is a need for systematic empirical evaluation showing how continuation methods may help to avoid early local minima and saddle points [17, 52]. Unlike most prior research, which focused on classification or prediction tasks, herein we focus on unsupervised learning. However, there is nothing in the formulation of the homotopy function that depends on the particular structure of AEs, so we expect similar results are possible with arbitrary neural networks.

4 Methodology

In this section, we discuss our method to construct a sequence of objective functions with increasing complexity. We illustrate how we transform standard activation functions through homotopy. This leads us to strategically solve deep neural network optimization, following the principles of parameter continuation [2]. We discuss three continuation methods, namely, NPC, NPAC and NPACS [52].

4.1 Continuation Activation

A Continuation Activation (C-Activation) is a homotopy formulation of a standard activation function. The homotopy adds the ability for a network to learn linear, non-linear and any intermediate characteristics of the data. Activation functions can be reformulated according to (5), namely,

φ_C-Activation(v, λ) = (1 − λ) · v + λ · φ(v),    λ ∈ [0, 1]    (5)

where φ can be any activation function such as Sigmoid, ReLU, Tanh, etc., λ is the homotopy parameter and v is the value of the output from the current layer. For a standard AE, the loss function can be represented as the following

J(X, g_θ(f_θ(X))) = argmin_θ ||X − g_θ(f_θ(X))||^2    (6)

where X ∈ ℝ^{m×n} is the input data, f is the encoder function, g is the decoder function and θ the network parameters. With the addition of a homotopy parameter λ the optimization can be rewritten in the following manner:

J(X, g_{θ,λ}(f_{θ,λ}(X))) = argmin_{θ,λ} ||X − g_{θ,λ}(f_{θ,λ}(X))||^2    (7)

In the above optimization problem, at any continuation step λ_i (λ ∈ [0, 1] and i ∈ ℕ_0) we solve the problem J(X, g_{θ_i,λ_i}(f_{θ_i,λ_i}(X))) and obtain a solution (θ_i, λ_i). Every value of λ represents a different optimization problem and a corresponding degree of non-linearity [52].
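As an illustration of (5), the following NumPy sketch wraps standard activations into C-Activations and applies one to a single dense layer; the layer widths, random weights and names here are illustrative assumptions, not the authors' implementation.

```python
import numpy as np

def c_activation(phi):
    """Wrap a standard activation phi into its homotopy version, Eq. (5)."""
    return lambda v, lam: (1.0 - lam) * v + lam * phi(v)

c_sigmoid = c_activation(lambda v: 1.0 / (1.0 + np.exp(-v)))
c_relu    = c_activation(lambda v: np.maximum(v, 0.0))
c_tanh    = c_activation(np.tanh)

# One dense layer whose degree of non-linearity is controlled by lam.
rng = np.random.default_rng(0)
W, b = rng.normal(size=(784, 64)) * 0.01, np.zeros(64)  # illustrative shapes
x = rng.normal(size=(32, 784))                          # a minibatch
h_linear    = c_relu(x @ W + b, 0.0)   # lam = 0: the layer is purely linear
h_nonlinear = c_relu(x @ W + b, 1.0)   # lam = 1: the ordinary ReLU layer
```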

4.2 Intuition for Extreme Values of λ

To provide intuition, first let us consider the extremes (i.e. λ = 0 or λ = 1). When λ = 1, the objective functions in (6) and (7) are exactly the same; thus λ = 1 indicates the use of conventional activation functions. Second, in a deep neural network as shown in Fig. 6, consider all the activation functions to be C-Activations with λ_0 = 0; then the neural network is linear. Here, the network would attempt to find a linear projection in order to minimize the cost function. We know such a solution can be found in closed form using Principal Component Analysis (PCA) or the Singular Value Decomposition (SVD) [50]. Hence, the PCA solution and the solution of our optimization problem at λ_0 = 0 should span the same subspace [15, 50]. Thus, we leverage this observation and initialize our network using PCA, which we discuss further in Sect. 4.3. Effectively, λ = 0 is analogous to solving a simpler version of the optimization problem (a linear projection) and as λ → 1 the problem becomes harder (a non-linear and non-convex problem) (Fig. 1).

Fig. 1 This figure is a good example of manifold learning [69]. Points in blue are true data points that clearly lie on a non-linear manifold. The green points show an optimized non-linear projection of the data. The red points show the linear manifold projected by PCA. Learning a linear manifold can be considered a simplified version of non-linear manifold learning

Thus, the homotopy parameter λ defines the scheme for the optimization, which can be referred to as the homotopy path, as shown in Fig. 2. As λ : 0 → 1, we solve an increasingly difficult optimization problem, where the degree of non-linearity learned from the data increases. However, we need a technique to traverse this homotopy path and find the solutions along it; deriving such techniques will be our focus in the following sections.

4.3 Stable Initialization of Autoencoder Through PCA

Principal Component Analysis (PCA) [47] is a classic dimension reduction technique that provides a lower dimensional linear representation of the data. It is well-known that PCA projections can be computed using the Singular Value Decomposition (SVD) [47] by computing matrices U, Σ, V^T from the given data as shown in (8).

X = U Σ V^T    (8)
where X is our normalized training data,2 U and V are unitary matrices, and Σ is a diagonal matrix.

Fig. 2 This figure provides intuition for NPC with a hypothetical homotopy path (blue curve), which connects all solutions in the multidimensional space (θ) at every λ ∈ [0, 1]. Here we show how the solution θ_{λ_1} is used to initialize the parameters of the network at λ_2, θ^init_{λ_2} ← θ_{λ_1}, using a small Δλ_{2,1}. Further taking some ADAM steps (shown by the orange arrow), we find a minimum θ_{λ_2}

In Eq. (9), we observe that V can be seen as a mapping of the data from the high dimensional space to a lower dimensional linear subspace (an encoder in the language of AEs). In addition, V^T is a mapping from the lower dimensional linear subspace back to the original space of the provided data (a decoder in the language of AEs). This behaviour of the SVD enables us to initialize the weights of an appropriately defined AE. In particular, when λ = 0 in (7) we have that f and g are, in fact, linear functions and (7) can be minimized using the SVD [47]. More precisely, we use the first n columns of V for an encoder layer of width n, and for the decoder layer we use the transpose of the weights used for the encoder, as shown in (12) and (13) and as in [50].

X = U Σ V^T    (9)
X V = U Σ V^T V    (10)
X V = U Σ    (11)
W^n_encoder = V_{n-columns}    (12)
W^n_decoder = (W^n_encoder)^T    (13)

where W^n represents the weight matrix of the encoder layer with n as its width.
There are multiple advantages to initializing an AE using PCA. First, we start the deep learning training from the solution of a linear problem rather than trying to solve a non-linear, and often non-convex, optimization problem; having such a straightforward initialization procedure allows the optimization of the AE to begin from a known global optimum. Second, in the NPACS method, the C-Activation defines the homotopy as a continuous deformation from a linear to a non-linear network. Accordingly, the global optimum of the linear problem can be transformed into a solution of the non-linear and non-convex problem of interest in a principled fashion. Finally, as PCA provides a deterministic initialization which does not require any sampling of the parameter space [10, 44, 60], our proposed method also has a computational advantage. Of course, the idea of initializing a deep feedforward network with PCA is not novel and has been independently explored in the literature [7, 36, 57]. However, our proposed algorithm is unique in that it leverages powerful techniques for solving homotopy problems using parameter continuation.

2 For PCA the data is normalized by making the mean of each column 0.
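A minimal NumPy sketch of the initialization in (8)-(13) for one encoder/decoder pair of width n is shown below; the function and variable names are illustrative, and the data is column-centred as in footnote 2.

```python
import numpy as np

def pca_init(X, n):
    """Initialize encoder/decoder weights from the SVD of the centred data X (m x d)."""
    Xc = X - X.mean(axis=0)                        # centre each column (footnote 2)
    U, S, Vt = np.linalg.svd(Xc, full_matrices=False)
    W_enc = Vt.T[:, :n]                            # first n columns of V, Eq. (12)
    W_dec = W_enc.T                                # decoder is the transpose, Eq. (13)
    return W_enc, W_dec

rng = np.random.default_rng(0)
X = rng.normal(size=(1000, 784))
W_enc, W_dec = pca_init(X, n=2)                    # a two-dimensional code, as in Sect. 5
code  = (X - X.mean(axis=0)) @ W_enc               # linear encoding: the lambda = 0 solution
recon = code @ W_dec                               # linear reconstruction spanning the PCA subspace
```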

4.4 Natural Parameter Continuation (NPC)

Natural Parameter Continuation (NPC) is an adaptation of standard iterative optimization solvers to a homotopy-type problem [52]. Continuing our explanation from Sect. 4.2, the solution at λ_i can be used as the initial guess for the solution at λ_{i+1} = λ_i + Δλ_i. With Δλ_i being sufficiently small, the iteration applied to this initial guess should converge to a local minimum [2]. In Fig. 2, we show the method used to traverse along the homotopy path. We are asked to find the solution of every problem J(x, g(f(x)); θ, λ_i) along the way. The solution set of such an N-dimensional (θ) problem is a one-dimensional manifold (the homotopy path) embedded in an (N+1)-dimensional space [48, 52], as we show schematically in the 2D plot of Fig. 2. Note that, using PCA, we know the solution at λ_0 = 0; if Δλ_0 is made sufficiently small, then our solution at λ_1, for example, should be very close to the PCA solution. Hence, we initialize the optimization problem at λ_1 with the solution at λ_0. We keep repeating this iteratively until the end of the homotopy to find the solution to the problem of interest. Generally, this may be referred to as Natural Parameter Continuation (NPC) [2, 48], as shown in Fig. 2, where the solution of the optimization problem at λ_i is used to initialize the problem at λ_{i+1}. After initializing, we may require some standard optimization technique, such as a few steps of the ADAM algorithm, to find the optimal solution at λ_{i+1} (Fig. 2).
The main challenge in the deep learning framework is that we have many hyperparameters: different network architectures, activation functions, batch sizes, etc., and coming up with a rule of thumb for the step size is difficult. In particular, a fixed step size Δλ has two main disadvantages: (1) we may end up taking many steps where they are not required, and (2) too few steps where cautious updates are required, depending on the nature of the homotopy path [52]. To better understand these disadvantages, see Fig. 3. To address this, our next method, in Sect. 4.5, is inspired by the parameter continuation literature: we use an adaptive solution that can reasonably determine the nature of the homotopy path and update Δλ accordingly. The NPC warm start itself is simply

θ^init_{λ_{i+1}} ← θ_{λ_i}    (14)

Fig. 3 In this figure, we provide intuition for our adaptive methods. We show two possible sets of steps, where green is an adaptive Δλ and grey is a fixed Δλ. We want to take larger steps when the homotopy path is flat and shorter steps where the homotopy path is more oscillatory
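Before moving to the adaptive variants, the following sketch summarizes the NPC loop described above in plain Python. It assumes a user-supplied loss_and_grad(theta, lam) that evaluates the homotopy objective (7) and its gradient for a NumPy parameter vector; for self-containedness, plain gradient steps stand in for the ADAM updates the chapter actually uses.

```python
def npc_train(theta0, loss_and_grad, d_lam=8e-3, inner_steps=10, lr=1e-3):
    """Natural Parameter Continuation with a fixed step size d_lam (illustrative sketch).

    theta0 is the PCA solution at lam = 0; each sub-problem is warm-started with
    the previous solution, Eq. (14), and solved approximately with a few steps.
    """
    theta, lam = theta0.copy(), 0.0
    while True:
        for _ in range(inner_steps):           # approximately solve the problem at this lam
            _, grad = loss_and_grad(theta, lam)
            theta = theta - lr * grad
        if lam >= 1.0:                          # the original (fully non-linear) problem solved
            break
        lam = min(1.0, lam + d_lam)             # Eq. (14): reuse theta as the next initial guess
    return theta
```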

4.5 NPC with Adaptation for Continuation Steps (NPAC)

In NPC we had a user-specified configuration with a fixed Δλ. In practice, we found that coming up with a precise Δλ is difficult. In the field of continuation methods and bifurcation analysis, the idea of adaptive step sizes is well-known and is used in software packages such as AUTO2000 [48] and AUTO-07P [11]. Accordingly, we borrow this idea in our implementation of neural network training. First, we observe that the behaviour of the activation function plays an important role in the nature of the homotopy path. For example, in Fig. 4, at λ = 0.7 we observe that C-Sigmoid is still mostly linear while C-ReLU is gradually transforming into a ReLU, demonstrating two different possible behaviours while transforming through the homotopy path. Accordingly, at λ = 0.7 the algorithm should make longer Δλ updates in the case of C-Sigmoid, but shorter ones in the case of C-ReLU.
In our experiments, we show similar observations in Table 1. For the NPC method we used fixed updates that we computed empirically, such as Δλ = 0.008 for C-ReLU; however, for C-Sigmoid and C-Tanh we had to be careful in determining the Δλ values: for λ < 0.8 we used Δλ = 0.02, and for λ ≥ 0.8 we used Δλ = 8e−4. This is because we observed that C-Sigmoid has almost linear behaviour until λ = 0.8, after which it rapidly adapts to the sigmoid.

Fig. 4 This figure, from left to right, illustrates how C-Sigmoid, C-ReLU and C-Tanh behave at λ = 0.7 on uniformly distributed points between −10 and 10

Table 1 This table shows the train and test loss values of different optimization techniques using a specified network. All these experiments were computed for 50,000 backpropagation steps, and we report the averages of the last 100 loss values for both training and testing. Perhaps not surprisingly, SGD does the worst in almost all cases. More interestingly, RMSProp and ADAM both do well in some cases, and quite badly in others. Note that the various parameter continuation methods all have quite stable properties and achieve a low loss in all cases and the lowest loss in most cases
Network          SGD      RMSProp  ADAM     NPC      NPAC     NPACS
Fashion-MNIST (Train)
  AE-8 Sigmoid   0.1133   0.03915  0.03370  0.03402  0.03370  0.03388
  AE-8 ReLU      0.11122  0.03582  0.03318  0.03171  0.03188  0.03191
  AE-8 Tanh      0.10741  0.03459  0.03515  0.03573  0.03552  0.03559
  AE-16 Sigmoid  0.11397  0.06714  0.06714  0.03418  0.04505  0.03461
  AE-16 ReLU     0.11394  0.06714  0.03436  0.03474  0.03445  0.03659
  AE-16 Tanh     0.10889  0.03419  0.03540  0.03753  0.03722  0.03622
CIFAR-10 (Train)
  AE-8 Sigmoid   0.28440  0.03861  0.03352  0.03275  0.03224  0.03238
  AE-8 ReLU      0.07689  0.03467  0.03421  0.03459  0.03461  0.03302
  AE-8 Tanh      0.27565  0.03355  0.03421  0.03343  0.03392  0.03408
  AE-16 Sigmoid  0.28717  0.06223  0.06223  0.03480  0.03310  0.03517
  AE-16 ReLU     0.07722  0.03512  0.03419  0.03400  0.03456  0.03463
  AE-16 Tanh     0.27884  0.03496  0.03452  0.03637  0.03815  0.03405
Maximum (Train)  0.28440  0.06714  0.06714  0.03753  0.04505  0.03659
Fashion-MNIST (Test)
  AE-8 Sigmoid   0.11333  0.08508  0.08525  0.08257  0.08324  0.08170
  AE-8 ReLU      0.11123  0.08202  0.08076  0.08225  0.08160  0.08154
  AE-8 Tanh      0.10742  0.08571  0.07904  0.07225  0.07306  0.07447
  AE-16 Sigmoid  0.11397  0.11396  0.11383  0.07405  0.09050  0.07912
  AE-16 ReLU     0.11394  0.11396  0.08035  0.08103  0.07993  0.08441
  AE-16 Tanh     0.10891  0.07736  0.07713  0.07899  0.07519  0.08025
CIFAR-10 (Test)
  AE-8 Sigmoid   0.28440  0.07835  0.06526  0.06344  0.06417  0.06522
  AE-8 ReLU      0.28440  0.03861  0.03352  0.03276  0.03225  0.03239
  AE-8 Tanh      0.27568  0.09993  0.08604  0.06801  0.07580  0.07544
  AE-16 Sigmoid  0.28718  0.28671  0.28648  0.07735  0.07588  0.10232
  AE-16 ReLU     0.07724  0.05507  0.05321  0.05188  0.06009  0.05564
  AE-16 Tanh     0.27887  0.08614  0.08939  0.10874  0.12128  0.08714
Maximum (Test)   0.28718  0.28671  0.28648  0.10874  0.12128  0.10232
For each row, the largest (worst) loss is shown in red, and the lowest (best) loss is shown in green

Therefore, we needed an adaptive method that can reasonably tune this Δλ update for different activation functions [52] and also adapt to the nature of the homotopy path. We elaborate on these Δλ choices in Sect. 5.
Accordingly, we developed an adaptive method for determining the Δλ update by utilizing gradient information during backpropagation, which we present as Algorithm 1. NPAC has two benefits: first, it solves the issue of hand-picking the value of Δλ; second, it provides a reasonable approach to determine how close the next suitable problem (i.e. the next λ value) should be, for which the current solution (say at λ = 0.25) would be a good initialization.

Algorithm 1 Adaptive Δλ
Require: norm_grads, a list of the norms of the gradients of the objective function from the previous t steps; Δλ_{i,i−1}, the previous step size; scale_up and scale_down factors.
1: avg_norm_p ← mean of the first half of norm_grads
2: avg_norm_c ← mean of the second half of norm_grads
3: condition1 ← (avg_norm_p − avg_norm_c) / avg_norm_p < (−tolerance)
4: condition2 ← (avg_norm_p − avg_norm_c) / avg_norm_p > (tolerance)
5: if condition1 then
6:   Δλ_{i+1,i} ← Δλ_{i,i−1} · scale_up
7: else if condition2 then
8:   Δλ_{i+1,i} ← Δλ_{i,i−1} / scale_down
9: else
10:  Δλ_{i+1,i} ← Δλ_{i,i−1}
11: end if
12: return Δλ_{i+1,i}
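A direct Python transcription of Algorithm 1 is sketched below; the tolerance and scale factors are tuning hyperparameters, and norm_grads is assumed to hold the gradient norms of the previous t steps (at least two entries).

```python
def adaptive_delta_lambda(norm_grads, prev_delta, scale_up=2.0,
                          scale_down=2.0, tolerance=0.05):
    """Algorithm 1: adjust the continuation step size from recent gradient norms."""
    half = len(norm_grads) // 2
    avg_norm_p = sum(norm_grads[:half]) / half                      # mean of the first half
    avg_norm_c = sum(norm_grads[half:]) / (len(norm_grads) - half)  # mean of the second half
    rel_change = (avg_norm_p - avg_norm_c) / avg_norm_p
    if rel_change < -tolerance:                 # condition 1 of Algorithm 1
        return prev_delta * scale_up
    elif rel_change > tolerance:                # condition 2 of Algorithm 1
        return prev_delta / scale_down
    return prev_delta                           # otherwise keep the previous step size
```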

4.6 Natural Parameter Adaptive Continuation with Secant Approximation (NPACS)

NPACS is a more advanced version of NPAC, where we enhance our method with a secant step in the (multidimensional) θ space along with the adaptive Δλ update. A secant line to a curve is one that passes through at least two distinct points [31], and we use this method to find a linear approximation of the homotopy path in a particular neighbourhood [2]. In the previous two methods, we simply assigned the previous solution as the initialization for the current problem, whereas in NPACS we take the previous two solutions and apply a secant approximation to initialize the next step. An important detail is to properly normalize the secant step for the parameters (θ) of a neural network, which commonly has thousands or even millions of dimensions, depending on the network size [52]. In Fig. 5, we illustrate the geometric interpretation of Eq. (15). A clear advantage of NPACS is that a secant update follows the homotopy curve more closely, approximating the derivative of the curve [52], and initializes the subsequent problem accordingly.

Fig. 5 This figure shows the secant update in our NPACS method. Unlike other continuation methods, here we utilize the previous two solutions to draw a secant vector in the multidimensional space (θ)

We also present Algorithm 2, which demonstrates all the required steps to implement NPACS. Here we perform model continuation for an AE (of depth 8 or 16), using ADAM updates to solve a particular optimization problem at λ_i. Depending on the problem, any other optimizer may also be selected.

θ^init_{λ_{i+1}} ← θ_{λ_i} + (θ_{λ_i} − θ_{λ_{i−1}}) · (Δλ_{i+1,i} / Δλ_{i,i−1})    (15)
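As a short sketch, the secant warm start of (15) can be written in a single line; theta_curr and theta_prev are the flattened parameter vectors obtained at the two previous continuation steps, and the ratio of the Δλ values rescales the secant direction (all names and values are illustrative).

```python
import numpy as np

def secant_init(theta_curr, theta_prev, d_lam_next, d_lam_prev):
    """Eq. (15): extrapolate along the secant of the homotopy path."""
    return theta_curr + (theta_curr - theta_prev) * (d_lam_next / d_lam_prev)

theta_prev = np.array([0.10, -0.20])   # solution at lambda_{i-1} (illustrative values)
theta_curr = np.array([0.12, -0.25])   # solution at lambda_i
theta_init = secant_init(theta_curr, theta_prev, d_lam_next=0.01, d_lam_prev=0.02)
```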

5 Experimental Results

We used two popular datasets for our experiments, namely CIFAR-10 [37] and Fashion-MNIST [67]. Fashion-MNIST has 55,000 training images and 10,000 test images with ten different classes, and each image has 784 dimensions. Similarly, for CIFAR-10 we used 40,000 images for training and 10,000 as a test/validation set. Each image in this dataset has 3072 (32 × 32 × 3) dimensions with ten different classes. CIFAR-10 is a more challenging dataset than Fashion-MNIST [52] and is also widely used by researchers. Next, we use autoencoders (AEs), an unsupervised learning method, to test our optimization technique. The employed AE is shown in Fig. 6. In the case of the Fashion-MNIST dataset, the input is a 784-dimensional image (or a 3072-dimensional image for the CIFAR-10 dataset). The AE is then used to reconstruct the image from only a two-dimensional representation, which is encoded in the hidden layer. In particular, two neural networks are evaluated, namely AE-8 and AE-16, of depth 8 and 16, respectively, for all our experiments.
We compare our parameter continuation methods, NPC (at Δλ = 8e−3), NPAC and NPACS, against existing methods such as ADAM, SGD and RMSProp. Task consistency plays a key role in providing conclusive empirical evaluations, and we achieve it by keeping the data, architecture and hyperparameters fixed. As explained in Sect. 4.3, the SVD is used for the initialization of our network consistently across all our experiments.

Algorithm 2 NPACS: model continuation for an AE using ADAM
Require: Learning rate ε; number of ADAM steps to perform after every continuation step, u_freq; initial Δλ; initial neural network parameters θ obtained using PCA; initial homotopy parameter λ; t, the number of steps after which the adaptive behaviour starts.
1: k ← 1
2: i ← 1
3: norm_grads ← [ ]
4: loss_history ← [ ]
5: while stopping criteria not met do
6:   Sample a minibatch x_i from the data
7:   Compute loss: loss ← Σ_j J(x, g(f(x)); θ; λ_i)
8:   loss_history ← Append loss
9:   Compute gradient estimate ĝ ← ADAM()
10:  norm_grads ← Append ||ĝ||_2
11:  Apply gradient θ ← θ − ε · ĝ
12:  k ← k + 1
13:  if k % u_freq == 0 and u_freq > 0 then
14:    if k % t == 0 and k > 0 then
15:      Δλ_{i+1,i} ← Compute Adaptive Δλ() (Algorithm 1)
16:    end if
17:    θ^init_{λ_{i+1}} ← θ_{λ_i} + (θ_{λ_i} − θ_{λ_{i−1}}) · (Δλ_{i+1,i} / Δλ_{i,i−1})
18:    i ← i + 1
19:    Compute loss ← Σ_j J(x, g(f(x)); θ; λ_i)
20:    loss_history ← Append loss
21:    k ← k + 1
22:  end if
23: end while

Fig. 6 This figure shows the Autoencoder (AE-8) we used in our experiments. It is an eight-layer deep network with the widths specified in the block diagram. Additionally, we apply one particular activation function to all the hidden layers of the network, except the code and the output layer. We specify these while reporting results

This provides us with an exact vector of parameters (θ) at λ = 0. Selection of Δλ was carried out using a line search between 8e−5 and 2e−2, and we used the best performing Δλ = 8e−3 for the NPC method. For the NPAC and NPACS methods, Δλ was chosen by our adaptive Algorithm 1. Another important hyperparameter is the number of ADAM (or any other solver) steps between two continuation steps. Again, we performed a search over values between 5 and 500, found 10 ADAM steps to work best, and used this value consistently. Also note that Sigmoid, Tanh and ReLU are used with the SGD, ADAM and RMSProp optimization techniques, but when implementing the continuation methods their continuation counterparts were used, such as C-Sigmoid and C-ReLU from (5). Finally, we open-source our code and all the hyperparameter choices on GitHub.3
Three important metrics for testing a new optimizer are the qualitative analysis of training loss, generalization loss and convergence speed [28]. Table 1 depicts the training and validation loss of our network (six variants or tasks) with six different optimization methods and two popular datasets. Table 1 indicates that our continuation methods consistently perform better on both validation and training loss. A few more interesting conclusions can be drawn from Table 1. First, as expected, ADAM, RMSProp and the continuation methods performed better than SGD, which also shows that the optimization tasks for the experiment are not trivial. Second, our network variant AE-16 Sigmoid turns out to be the most challenging task, where ADAM and RMSProp get stuck in a bad local minimum, which empirically shows such networks are difficult to train. However, our methods were able to skip these sub-optimal local minima and achieved 49.18% lower training loss with 34.94% better generalization as compared to ADAM, as shown in Table 1. Similar results were observed with the Fashion-MNIST dataset. Another optimization bottleneck was seen in the case of AE-16 ReLU with the Fashion-MNIST data: here RMSProp was unable to avoid a sub-optimal local minimum, while our methods had 48.88% lower training loss and 30.08% better generalization (testing) error. Finally, we report the maximum loss attained by the various optimizers at the final step; as Table 1 shows, our parameter continuation methods have a much lower maximum loss across all distinct tasks.
Further, to report convergence speed we analyze the convergence plots of all the distinct tasks. Out of all plausible ways to define the speed of convergence, in this chapter an optimizer is considered faster if it obtains a lower training loss at an earlier step than a competing optimizer [52]. Hence, we studied the convergence plots of the various optimizers and found that the continuation methods, i.e. NPC, NPAC and NPACS, converged faster in the majority of the tasks. In Fig. 7, we can see that for the AE-16 and AE-8 Sigmoid networks the continuation methods have a much lower training error from the very beginning of training as compared to both the RMSProp and ADAM optimizers, consistent with our results in Table 1. Additionally, we extend our results to three standard activation functions: C-Sigmoid in Fig. 7, C-ReLU in Fig. 8 and C-Tanh in Fig. 9. Through these figures we not only show that our generalization loss is better in most cases but also that the continuation methods converge faster (i.e. achieve lower training loss in fewer ADAM updates).

3 https://github.com/harsh306/NPACS.

Fig. 7 The figures above demonstrate the convergence of various optimization methods when C-
Sigmoid is used as an activation function. In all cases, the X-axis shows the number of iterations,
and the Y-axis shows the log-loss or log-validation-loss. Note that the methods we propose provide
some of the lowest loss values, for training and testing, throughout the optimization procedure

Fig. 8 The figures above demonstrate the convergence of various optimization methods when C-
ReLU is used as an activation function. In all cases, the X-axis shows the number of iterations,
and the Y-axis shows the log-loss or log-validation-loss. Note that the methods we propose provide
some of the lowest loss values, for training and testing, throughout the optimization procedure

Fig. 9 The figures above demonstrate the convergence of various optimization methods when C-
Tanh is used as an activation function. In all cases, the X-axis shows the number of iterations, and
the Y-axis shows the log-loss or log-validation-loss. Note that the methods we propose provide
some of the lowest loss values, for training and testing, throughout the optimization procedure

6 Open Problems and Future Directions

Continuation methods have a long history of solving non-convex optimization problems. Recently there has been great progress in enhancing deep learning training using methods that are akin to continuation methods, for example, data curriculum strategies [3, 23, 32, 66]. We believe this trend will continue and have an enormous impact on further fields such as reinforcement learning and meta-learning. In the AI-generating algorithms (AI-GAs) paper [9], the authors proposed three pillars for the future of general AI, namely, (1) meta-learning algorithms, (2) meta-learning architectures and (3) generating effective learning environments. We believe that continuation methods are a good candidate to assist in these directions in a principled manner. Next, we discuss a few open problems for the deep learning community.

6.1 Model Continuation

In [3, 50], the authors illustrated that continuation strategies can be introduced into deep learning through model and data continuation. The methods illustrated in this chapter are classical examples of model continuation. In Sect. 3, we discussed different model continuation methods [21, 22, 44] and their advantages. These methods compared their training strategies against a target model and reported better generalization error. However, we observe that there are some limitations in these approaches. Ideally, the proposed continuation techniques would be empirically evaluated using different neural architectures and a variety of tasks; we understand that such detailed evaluations may require substantial effort. First, we observed that many different types of neural networks have not yet been systematically tested. For example, diffusion methods were applied only to RNNs [44], MLPs and LSTMs were tested by [21], and mollification was only tested on CNNs, RNNs and MLPs [22, 54]. In future research, it will be interesting to see how the proposed curricula affect the convergence of different types of networks. Second, some methods [3, 22] were tested with limited depth and may not necessarily be categorized as deep networks. Third, applications of curriculum learning may be improved with more empirical evaluations of the methods in the literature. In particular, it will be interesting to see comparisons with techniques such as ResNets [38], Dropout [58] and different normalizations, for more tasks such as classification, language modelling, regression and reconstruction [22].
As far as this chapter is concerned, there are several natural extensions. First, a C-Activation can wrap any activation function, and in the case of multiple activation functions in one network, we may add a separate homotopy parameter for each of them. Also, as stated earlier, there is nothing in these techniques that makes them specific to the AE case; in particular, we would also like to train Convolutional and Recurrent Neural Networks, as described in [15, 39, 65], using our method. Second, our adaptive method can be improved to conform to the homotopy path more accurately by using pseudo-arclength methods [34, 48].

6.2 Data Continuation

Many researchers have shown that learning data from a designed curriculum leads to better learning and robustness to out-of-distribution error. Some of the popular methods are progressive training [32, 50], Curriculum Learning [3, 23], Curriculum Learning by Transfer Learning [66], C-SMOTE [50], etc. These methods have shown improvements in generalization performance. However, there is usually some limitation because of the data type. For example, the design of Progressive GANs [32] was applicable to images but may prove difficult to apply to a text dataset. For future research, it will be interesting to see how data curriculum methods work with different kinds of data, such as images, text, sound and time-series. As a first step in this direction, we collect and share a list of datasets.4 These datasets are indexed from different sources, and we hope they prove to be a useful resource for fellow researchers. Recent research illustrates how a data curriculum can be learned via continuation [50], and how automation can be introduced by measuring a model's learning [18, 30, 55, 59, 66, 68]; these works use limited benchmarks and could likely be extended in several directions.

6.3 Loss Surface Understanding

A Deep Neural Network’s loss surface is usually non-convex [8, 17, 27], and depend-
ing on the neural architecture, many attempts have been made to theoretically catego-
rize the loss surfaces [17, 33, 41]. Continuation methods or curriculum learning may
provide a unique perspective to understand the loss surfaces of the neural network. In
particular, we plan to extend our work to perform Bifurcation Analysis of Neural
Networks. Bifurcations [2, 11, 49] are the dynamic changes in the behaviour of the
system that may occur while we track the solution through the homotopy path. In
our case, as we change our activation from linear to non-linear, detected bifurcations
[34] of neural networks may help us explain and identify local minima and they may
also help us to understand the so-called “black-box” of neural networks better.

6.4 Environment Generation Curricula

In reinforcement learning, agents take actions in an environment, which then provides rewards or feedback that informs the next action. Similar to data continuation, we can think of generating training environments of varied complexity for reinforcement learning. For example, Reward Shaping [13, 20] is a method where a user defines a curriculum of environments for the agent. However, while a curriculum may be essential for some tasks in principle, in practice it is challenging to know the right curriculum for any given task [63].
4 Dataset collection: https://github.com/harsh306/curriculum-datasets.


In [63], the authors empirically show that sometimes an intuitively sound curriculum may even harm training (i.e. learning harder tasks can lead to better solutions [56]). As a result, many research works have designed self-generating curricula of training environments that introduce progressively difficult and diverse scenarios for an agent and thus increase the robustness of learning [1, 59, 63]. One can easily imagine continuation methods having a role to play in improving such methods.

6.5 Multi-task Learning

In multi-task learning [6], a model is designed to perform multiple tasks with the same set of parameters. Designing such a system is more difficult, as learning from one task may improve or degrade the learning of another task and may also lead to catastrophic forgetting [4]. Synchronizing a neural network to perform multiple tasks is an open research problem, and recently curriculum learning has been used to meet this challenge [53, 64]. These approaches can be broadly categorized in two ways. First, design a curriculum that dynamically provides selected batches of data so that they benefit the multi-task objective [61, 64]; usually one needs to determine the importance of each instance for learning a task in order to implement such an automated system, for which Bayesian optimization is a classic choice [64]. Second, determine the best order of tasks to be learned [53]; the latter approach is restricted to being implemented sequentially, following the determined curriculum [53]. We expect progressive research in this direction may significantly enhance multi-task learning. Also, among the above-mentioned approaches, the first is similar to data continuation, and the second is similar to model continuation. We believe that we can draw on these similarities and make impactful progress in the field of multi-task learning.

6.6 Hyperparameter Optimization

Meta-learning is an active area of research [14] which was recently surveyed in [9]. We understand there are many aspects of meta-learning [14, 46, 62], but for the scope of our discussion we limit our focus to hyperparameter search. The search space of different hyperparameters can be traced efficiently using continuation methods, or even multi-parameter continuation [2, 11, 34]. One can imagine two hyperparameters λ_m and λ_d, where λ_m corresponds to model continuation and λ_d enables data continuation. Then, we can apply continuation methods to optimize both strategies simultaneously for a given neural network task. Recently, researchers used implicit differentiation for hyperparameter optimization [42], in which millions of network weights and hyperparameters can be jointly tuned.
Multi-parameter Continuation Optimization is a promising area of research in fields such as mathematical analysis [2, 11, 34] and others [19, 29]. One may define and optimize most of the components or hyperparameters of a deep learning framework via a continuation scheme. We may derive more efficient methods for hyperparameter search that could be capable of evolving a generalized network via the addition or pruning of layers and similar ideas.

7 Conclusions

In this chapter, we exhibited a novel training strategy for deep neural networks. As observed in most of the experiments, all three continuation methods achieved faster convergence, lower loss and better generalization than standard optimization techniques. Also, we empirically showed that the proposed method works with popular activation functions, deeper networks and distinct datasets. Finally, we examined some possible improvements and provided various future directions for furthering deep learning research using continuation methods.

References

1. I. Akkaya, M. Andrychowicz, M. Chociej, M. Litwin, B. McGrew, A. Petron, A. Paino, M. Plappert, G. Powell, R. Ribas et al., Solving Rubik's cube with a robot hand (2019). arXiv preprint arXiv:1910.07113
2. E. Allgower, K. Georg, Introduction to numerical continuation methods. Soc. Ind. Appl. Math.
(2003). https://epubs.siam.org/doi/abs/10.1137/1.9780898719154
3. Y. Bengio, J. Louradour, R. Collobert, J. Weston, Curriculum learning (2009)
4. Y. Bengio, M. Mirza, I. Goodfellow, A. Courville, X. Da, An empirical investigation of catastrophic forgetting in gradient-based neural networks (2013)
5. Z. Cao, M. Long, J. Wang, P.S. Yu, Hashnet: deep learning to hash by continuation. CoRR
(2017). arXiv:abs/1702.00758
6. R. Caruana, Multitask learning. Mach. Learn. 28(1), 41–75 (1997)
7. T. Chan, K. Jia, S. Gao, J. Lu, Z. Zeng, Y. Ma, Pcanet: a simple deep learning baseline for
image classification? IEEE Trans. Image Process. 24(12), 5017–5032 (2015). https://doi.org/
10.1109/TIP.2015.2475625
8. A. Choromanska, M. Henaff, M. Mathieu, G.B. Arous, Y. LeCun, The loss surface of multilayer
networks. CoRR (2014). arXiv:abs/1412.0233
9. J. Clune, Ai-gas: ai-generating algorithms, an alternate paradigm for producing general artificial
intelligence. CoRR (2019). arXiv:abs/1905.10985
10. T. Dick, E. Wong, C. Dann, How many random restarts are enough
11. E.J. Doedel, T.F. Fairgrieve, B. Sandstede, A.R. Champneys, Y.A. Kuznetsov, X. Wang, Auto-
07p: continuation and bifurcation software for ordinary differential equations (2007)
12. J. Duchi, E. Hazan, Y. Singer, Adaptive subgradient methods for online learning and stochas-
tic optimization. J. Mach. Learn. Res. 12, 2121–2159 (2011). http://dl.acm.org/citation.cfm?
id=1953048.2021068
13. T. Erez, W.D. Smart, What does shaping mean for computational reinforcement learning? in
2008 7th IEEE International Conference on Development and Learning (2008), pp. 215–219.
https://doi.org/10.1109/DEVLRN.2008.4640832
14. C. Finn, P. Abbeel, S. Levine, Model-agnostic meta-learning for fast adaptation of deep net-
works, in Proceedings of the 34th International Conference on Machine Learning, vol. 70.
(JMLR. org, 2017), pp. 1126–1135

15. I. Goodfellow, Y. Bengio, A. Courville, Deep Learning (MIT Press, 2016). http://www.
deeplearningbook.org
16. I.J. Goodfellow, NIPS 2016 tutorial: generative adversarial networks. NIPS (2017).
arXiv:abs/1701.00160
17. I.J. Goodfellow, O. Vinyals, Qualitatively characterizing neural network optimization problems.
CoRR (2014). arXiv:abs/1412.6544
18. A. Graves, M.G. Bellemare, J. Menick, R. Munos, K. Kavukcuoglu, Automated curriculum
learning for neural networks. CoRR (2017). arXiv:abs/1704.03003
19. C. Grenat, S. Baguet, C.H. Lamarque, R. Dufour, A multi-parametric recursive continuation
method for nonlinear dynamical systems. Mech. Syst. Signal Process. 127, 276–289 (2019)
20. M. Grzes, D. Kudenko, Theoretical and empirical analysis of reward shaping in reinforcement
learning, in 2009 International Conference on Machine Learning and Applications (2009), pp.
337–344. 10.1109/ICMLA.2009.33
21. C. Gülçehre, M. Moczulski, M. Denil, Y. Bengio, Noisy activation functions. CoRR (2016).
arXiv:abs/1603.00391
22. C. Gülçehre, M. Moczulski, F. Visin, Y. Bengio, Mollifying networks. CoRR (2016).
arXiv:abs/1608.04980
23. G. Hacohen, D. Weinshall, On the power of curriculum learning in training deep networks.
CoRR (2019). arXiv:abs/1904.03626
24. G. Hinton, L. Deng, D. Yu, G.E. Dahl, A. Mohamed, N. Jaitly, A. Senior, V. Vanhoucke, P.
Nguyen, T.N. Sainath, B. Kingsbury, Deep neural networks for acoustic modeling in speech
recognition: the shared views of four research groups. IEEE Signal Process. Mag. 29(6), 82–97
(2012). https://doi.org/10.1109/MSP.2012.2205597
25. G. Hinton, N. Srivastava, K. Swersky, Rmsprop: divide the gradient by a running average of
its recent magnitude. Neural networks for machine learning, Coursera lecture 6e (2012)
26. G.E. Hinton, R.R. Salakhutdinov, Reducing the dimensionality of data with neural networks.
Science 313(5786), 504–507 (2006). https://doi.org/10.1126/science.1127647, http://science.
sciencemag.org/content/313/5786/504
27. D.J. Im, M. Tao, K. Branson, An empirical analysis of deep network loss surfaces. CoRR
(2016). arXiv:abs/1612.04010
28. D. Jakubovitz, R. Giryes, M.R. Rodrigues, Generalization error in deep learning, in Compressed
Sensing and Its Applications (Springer, 2019), pp. 153–193
29. F. Jalali, J. Seader, Homotopy continuation method in multi-phase multi-reaction equilibrium
systems. Comput. Chem. Eng. 23(9), 1319–1331 (1999)
30. L. Jiang, Z. Zhou, T. Leung, L.J. Li, L. Fei-Fei, Mentornet: learning data-driven curriculum
for very deep neural networks on corrupted labels, in ICML (2018)
31. R. Johnson, F. Kiokemeister, Calculus, with Analytic Geometry (Allyn and Bacon, 1964).
https://books.google.com/books?id=X4_UAQAACAAJ
32. T. Karras, T. Aila, S. Laine, J. Lehtinen, Progressive growing of gans for improved quality,
stability, and variation. CoRR (2017). arXiv:abs/1710.10196
33. K. Kawaguchi, L.P. Kaelbling, Elimination of all bad local minima in deep learning. CoRR
(2019)
34. H.B. Keller, Numerical solution of bifurcation and nonlinear eigenvalue problems, in Applica-
tions of Bifurcation Theory, ed. by P.H. Rabinowitz (Academic Press, New York, 1977), pp.
359–384
35. D.P. Kingma, J. Ba, Adam: a method for stochastic optimization. CoRR (2014).
arXiv:abs/1412.6980
36. P. Krähenbühl, C. Doersch, J. Donahue, T. Darrell, Data-dependent initializations of convolu-
tional neural networks. CoRR (2015). arXiv:abs/1511.06856
37. A. Krizhevsky, V. Nair, G. Hinton, Cifar-10 (canadian institute for advanced research). http://
www.cs.toronto.edu/kriz/cifar.html
38. A. Krizhevsky, I. Sutskever, G.E. Hinton, Imagenet classification with deep convolutional
neural networks, in Advances in Neural Information Processing Systems (2012)
39. Y. LeCun, Y. Bengio, G. Hinton, Deep learning. Nature (2015). https://www.nature.com/
articles/nature14539
40. C. Ledig, L. Theis, F. Huszár, J. Caballero, A. Cunningham, A. Acosta, A. Aitken, A. Tejani, J.
Totz, Z. Wang et al., Photo-realistic single image super-resolution using a generative adversarial
network, in Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition
(2017), pp. 4681–4690
41. S. Liang, R. Sun, J.D. Lee, R. Srikant, Adding one neuron can eliminate all bad local minima,
in S. Bengio, H. Wallach, H. Larochelle, K. Grauman, N. Cesa-Bianchi, R. Garnett (eds.),
Advances in Neural Information Processing Systems, vol. 31 (Curran Associates, Inc., 2018),
pp. 4350–4360. http://papers.nips.cc/paper/7688-adding-one-neuron-can-eliminate-all-bad-
local-minima.pdf
42. J. Lorraine, P. Vicol, D. Duvenaud, Optimizing millions of hyperparameters by implicit differ-
entiation (2019). arXiv preprint arXiv:1910.07113
43. T. Mikolov, I. Sutskever, K. Chen, G.S. Corrado, J. Dean, Distributed representations of words
and phrases and their compositionality, in C.J.C. Burges, L. Bottou, M. Welling, Z. Ghahra-
mani, K.Q. Weinberger (eds.), Advances in Neural Information Processing Systems, vol. 26
(Curran Associates, Inc., 2013), pp. 3111–3119. http://papers.nips.cc/paper/5021-distributed-
representations-of-words-and-phrases-and-their-compositionality.pdf
44. H. Mobahi, Training recurrent neural networks by diffusion. CoRR (2016).
arXiv:abs/1601.04114
45. H. Mobahi, J.W. Fisher III, On the link between Gaussian homotopy continuation and convex
envelopes, in Lecture Notes in Computer Science (EMMCVPR 2015) (Springer, 2015)
46. A. Nagabandi, I. Clavera, S. Liu, R.S. Fearing, P. Abbeel, S. Levine, C. Finn, Learning to adapt in
dynamic, real-world environments through meta-reinforcement learning (2018). arXiv preprint
arXiv:1803.11347
47. K. Nordhausen, The elements of statistical learning: data mining, inference, and prediction,
2nd edn., by T. Hastie, R. Tibshirani, J. Friedman. Int. Stat. Rev. 77(3), 482–482 (2009)
48. R. Paffenroth, E. Doedel, D. Dichmann, Continuation of periodic orbits around Lagrange points
and AUTO2000, in AAS/AIAA Astrodynamics Specialist Conference (Quebec City, Canada, 2001)
49. R.C. Paffenroth, Mathematical visualization, parameter continuation, and steered computa-
tions. Ph.D. thesis, AAI9926816 (College Park, MD, USA, 1999)
50. H.N. Pathak, Parameter continuation with secant approximation for deep neural networks
(2018)
51. H.N. Pathak, X. Li, S. Minaee, B. Cowan, Efficient super resolution for large-scale images
using attentional GAN, in 2018 IEEE International Conference on Big Data (Big Data) (IEEE,
2018), pp. 1777–1786
52. H.N. Pathak, R. Paffenroth, Parameter continuation methods for the optimization of deep neural
networks, in 2019 18th IEEE International Conference on Machine Learning And Applications
(ICMLA) (IEEE, 2019), pp. 1637–1643
53. A. Pentina, V. Sharmanska, C.H. Lampert, Curriculum learning of multiple tasks, in Pro-
ceedings of the IEEE Conference on Computer Vision and Pattern Recognition (2015), pp.
5492–5500
54. J. Rojas-Delgado, R. Trujillo-Rasúa, R. Bello, A continuation approach for training
artificial neural networks with meta-heuristics. Pattern Recognit. Lett. 125, 373–380
(2019). https://doi.org/10.1016/j.patrec.2019.05.017, http://www.sciencedirect.com/science/
article/pii/S0167865519301667
55. S. Saxena, O. Tuzel, D. DeCoste, Data parameters: a new family of parameters for learning a
differentiable curriculum (2019)
56. B. Settles, Active Learning Literature Survey, Tech. rep. (University of Wisconsin-Madison
Department of Computer Sciences, 2009)
57. M. Seuret, M. Alberti, R. Ingold, M. Liwicki, PCA-initialized deep neural networks applied to
document image analysis. CoRR (2017). arXiv:abs/1702.00177
58. N. Srivastava, G. Hinton, A. Krizhevsky, I. Sutskever, R. Salakhutdinov, Dropout: a simple way
to prevent neural networks from overfitting. J. Mach. Learn. Res. 15(1), 1929–1958 (2014)
59. F.P. Such, A. Rawal, J. Lehman, K. Stanley, J. Clune, Generative teaching networks: acceler-
ating neural architecture search by learning to generate synthetic training data (2020)
60. I. Sutskever, J. Martens, G. Dahl, G. Hinton, On the importance of initialization and momentum
in deep learning, in International Conference on Machine Learning (2013), pp. 1139–1147
61. Y. Tsvetkov, M. Faruqui, W. Ling, B. MacWhinney, C. Dyer, Learning the curriculum with
Bayesian optimization for task-specific word representation learning, in Proceedings of the
54th Annual Meeting of the Association for Computational Linguistics, Long Papers, vol. 1
(Association for Computational Linguistics, Berlin, Germany, 2016), pp. 130–139. https://doi.
org/10.18653/v1/P16-1013, https://www.aclweb.org/anthology/P16-1013
62. R. Vilalta, Y. Drissi, A perspective view and survey of meta-learning. Artif. Intell. Rev. 18(2),
77–95 (2002)
63. R. Wang, J. Lehman, J. Clune, K.O. Stanley, Paired open-ended trailblazer (POET): endlessly
generating increasingly complex and diverse learning environments and their solutions. CoRR
(2019). arXiv:abs/1901.01753
64. W. Wang, Y. Tian, J. Ngiam, Y. Yang, I. Caswell, Z. Parekh, Learning a multitask curriculum
for neural machine translation (2019). arXiv preprint arXiv:1908.10940
65. M.A. Wani, F.A. Bhat, S. Afzal, A.I. Khan, Advances in deep learning, in Advances in Deep
Learning (Springer, 2020), pp. 1–11
66. D. Weinshall, G. Cohen, Curriculum learning by transfer learning: theory and experiments with
deep networks. CoRR (2018). arXiv:abs/1802.03796
67. H. Xiao, K. Rasul, R. Vollgraf, Fashion-MNIST: a novel image dataset for benchmarking machine
learning algorithms (2017). https://github.com/zalandoresearch/fashion-mnist
68. H. Xuan, A. Stylianou, R. Pless, Improved embeddings with easy positive triplet mining (2019)
69. C. Zhou, R.C. Paffenroth, Anomaly detection with robust deep autoencoders, in Proceedings of
the 23rd ACM SIGKDD International Conference on Knowledge Discovery and Data Mining
(ACM, 2017), pp. 665–674
Author Index

A
Alfarhood, Meshal, 1
Arif Wani, M., 101

C
Cheng, Jianlin, 1
Christopoulos, Stavros-Richard G., 229
Clinton Paffenroth, Randy, 273

D
da Silva, Gabriel Pellegrino, 49

F
Fischer, Georg, 173

G
Gaffar, Ashraf, 123
Gühmann, Clemens, 81

H
Hartmann, Sven, 81

I
Imran, Abdullah-Al-Zubaer, 249

J
Johnson, Justin M., 199

K
Kamran, Sharif Amit, 25
Kanarachos, Stratis, 229
Khoshgoftaar, Taghi M., 199
Kouchak, Shokoufeh Monjezi, 123

L
Leite, Guilherme Vieira, 49

M
Mujtaba, Tahir, 101
Mukherjee, Tathagata, 143

N
Nilesh Pathak, Harsh, 273

O
Onyekpe, Uche, 229

P
Palade, Vasile, 229
Pasiliao, Eduardo, 143
Pedrini, Helio, 49

R
Rosato, Daniele, 81
Roy, Debashri, 143
S
Sabbir, Ali Shihab, 25
Sabir, Russell, 81
Saha, Sourajit, 25
Santra, Avik, 173
Stephan, Michael, 173

T
Tavakkoli, Alireza, 25
Terzopoulos, Demetri, 249