
2180 IEEE TRANSACTIONS ON SIGNAL PROCESSING, VOL. 70, 2022

Federated Generalized Bayesian Learning via Distributed Stein Variational Gradient Descent
Rahif Kassab and Osvaldo Simeone, Fellow, IEEE

Abstract: This paper introduces Distributed Stein Variational Gradient Descent (DSVGD), a non-parametric generalized Bayesian inference framework for federated learning. DSVGD maintains a number of non-random and interacting particles at a central server to represent the current iterate of the model global posterior. The particles are iteratively downloaded and updated by a subset of agents with the end goal of minimizing the global free energy. By varying the number of particles, DSVGD enables a flexible trade-off between per-iteration communication load and number of communication rounds. DSVGD is shown to compare favorably to benchmark frequentist and Bayesian federated learning strategies in terms of accuracy and scalability with respect to the number of agents, while also providing well-calibrated, and hence trustworthy, predictions.
Index Terms: Federated learning, Bayesian learning, variational inference.

Manuscript received September 26, 2021; revised February 3, 2022; accepted April 12, 2022. Date of publication April 26, 2022; date of current version May 5, 2022. The associate editor coordinating the review of this manuscript and approving it for publication was Dr. Kobi Cohen. This work was supported by the European Research Council under the European Union Horizon 2020 research and innovation program under Grant 725731. (Corresponding author: Rahif Kassab.)
The authors are with the King's Communications, Learning and Information Processing (KCLIP) Lab, King's College London, London WC2B4BG, U.K. (e-mail: rahif.kassab@kcl.ac.uk; osvaldo.simeone@kcl.ac.uk).
Digital Object Identifier 10.1109/TSP.2022.3168490
1053-587X © 2022 IEEE. Personal use is permitted, but republication/redistribution requires IEEE permission. See https://www.ieee.org/publications/rights/index.html for more information.

I. INTRODUCTION
A. Motivation
Federated learning refers to the collaborative training of a machine learning model across agents with distinct local data sets, and it applies at different scales, from industrial data silos to mobile devices [1]. While some common challenges exist, such as the general statistical heterogeneity ("non-iidness") of the local data sets, each setting also brings its own distinct problems. As illustrated in Fig. 1, in this paper, we are specifically interested in a small-scale federated learning setting consisting of mobile or embedded devices, each having a limited data set and running a small-sized model due to memory constraints. As an example, consider the deployment of health monitors based on ECG data from smart-watches, or applications such as motion tracking and eye tracking [2]. In all these settings, there is a high intrinsic cost of data collection due to the necessity of human intervention for selection and labelling, as well as an inherent limit on on-device storage.
Fig. 1. Illustration of federated learning with K = 3 agents: 1) model training is carried out at each agent using a local dataset; 2) information about the updated models, e.g., neural network weights, is shared by agents with the server; 3) the server aggregates the received information; and 4) the server feeds back the aggregated model information to the agents.
In this context, we argue that it is essential to tackle the following challenges, which are largely not addressed by existing solutions:
• Trustworthiness: In many of the mentioned applications, including personal health assistants and medical AI [3], the learning agents' recommendations need to be reliable and trustworthy, e.g., to decide when to contact a doctor in case of a possible emergency;
• Number of communication rounds: When models are small, the payload per communication round may not be the main contributor to the overall latency of the training process. In contrast, accommodating many communication rounds, each requiring arbitration of channel access among multiple devices, may be the main cause of slow wall-clock convergence [4].
Trustworthiness may refer to several different criteria, broadly classifiable in terms of fairness, explainability, auditability and safety [5]. In this paper, we focus on a key requirement for safety, namely the capacity of an algorithm to quantify uncertainty [6], [7]. The problem of quantifying uncertainty is particularly relevant and challenging in the settings of interest, which are characterized by the availability of limited data [2], [8]–[10].
Bayesian learning is a promising solution to the problem at hand of uncertainty-aware training with small data sets, as it properly accounts for the epistemic uncertainty arising from constraints in the availability of data [11]. This paper proposes a novel federated learning approach that implements Bayesian
learning, extending the advantages of Bayesian methods to federated settings.

B. Context
Most existing federated learning algorithms, such as Federated Averaging (FedAvg) [12], are based on frequentist learning principles, relying on the identification of a single model parameter vector, e.g., the weights of a neural network (see Fig. 1). Frequentist learning is known to be unable to capture epistemic uncertainty, yielding overconfident decisions [13]. This is due to the fact that frequentist learning commits to a single value of the model parameter vector, without accounting for the fact that there may be several models that are equally effective at explaining the data, while differing significantly in their functional behavior, e.g., predictions, outside it. In contrast, Bayesian learning accounts for the contribution of all models that are consistent with the available data, enabling a principled quantification of uncertainty [14], [15]. Apart from quantifying uncertainty, Bayesian learning offers further advantages, such as providing a natural and principled way of encoding and accounting for prior knowledge and being inherently robust to overfitting via ensembling [11].
Federated Bayesian learning has the general aim of computing the global posterior distribution in the model parameter space. Existing decentralized, or federated, Bayesian learning protocols are based on approximate Bayesian inference via Variational Inference (VI) [16]–[19] or Monte Carlo (MC) sampling [20]–[22]. VI converts the problem of Bayesian inference into an optimization over a parametric family of models, while MC sampling transforms it into a problem of sampling from an unnormalized distribution. The performance of VI-based protocols is generally limited by the bias entailed by the variational approximation, while MC sampling is slow and suffers from the difficulty of assessing convergence [16].
The state of the art for federated VI methods is set by Partitioned Variational Inference (PVI), a recently introduced unifying distributed and continual VI framework. For MC sampling, the current state of the art is given by Distributed Stochastic Gradient Langevin Dynamics (DSGLD), which maintains a number of Markov chains updated via local Stochastic Gradient Descent (SGD) with the addition of Gaussian noise [20], [23], [24].
Stein Variational Gradient Descent (SVGD) has been introduced in [25] as a non-parametric Bayesian method that approximates a target posterior distribution via non-random and interacting particles. SVGD inherits the flexibility of non-parametric Bayesian inference methods, while improving the convergence speed of MC sampling [25]. By controlling the number of particles, SVGD can provide flexible performance in terms of bias, convergence speed, and per-iteration complexity.

C. Contributions
This paper introduces a novel non-parametric distributed learning algorithm, termed Distributed Stein Variational Gradient Descent (DSVGD). DSVGD extends SVGD to a federated setting with the goal of optimizing trustworthy machine learning models in the presence of decentralized, limited data sets. While the focus of most existing works on federated learning is on reducing the load per communication round via compression [1], DSVGD aims at reducing the number of communication rounds at the cost of a larger per-iteration communication load.
As illustrated in Fig. 2, DSVGD maintains a number of non-random and interacting particles at a central server to represent the current iterate of the global posterior. At each iteration, the particles are downloaded and updated by one of the agents by minimizing a local free energy functional before being uploaded to the server. Through the outlined non-parametric approach, DSVGD is shown to enable (i) a trade-off between per-iteration communication load and number of communication rounds by varying the number of particles; while (ii) being able to make trustworthy decisions through Bayesian inference.
The extension of SVGD entails several novel technical challenges, in terms of both algorithm development and analysis, which are at the core of our contribution. Specifically, our main contributions are summarized as follows:
• We propose DSVGD, a federated, non-parametric, deterministic Bayesian learning algorithm. DSVGD addresses the problem of global free energy minimization in a distributed manner by leveraging SVGD to tackle local free energy minimization problems on devices.
• While a naïve generalization of SVGD to a federated setting would require memory sizes at the server and devices that scale linearly with the number of communication rounds, DSVGD integrates a novel distillation step at the server to ensure constant and flexible storage requirements.
• We derive fixed-point properties of the proposed method, as well as a novel bound on the decrease of the global free energy as a function of reductions in local free energy metrics, which yields insight into the convergence of DSVGD.
• We provide extensive experimental results to substantiate the gains of DSVGD in terms of accuracy and calibration with respect to the state of the art.
The rest of the paper is organized as follows. In Section II, we review related work on distributed learning and SVGD, while Section III and Section IV present the system set-up and the distributed VI problem. Section V reviews the necessary background, while DSVGD is proposed in Section VI. The complexity of DSVGD is analyzed in Section VII. Experimental results are presented in Section VIII and conclusions are drawn in Section IX.

II. RELATED WORK
Frequentist federated learning: Since the pioneering work [12] introducing the FedAvg algorithm, federated learning has been extensively studied both from the viewpoint of signal processing and with the goal of exploring practical applications. Most of these works have focused on frequentist federated learning, and recent reviews can be found in [26]–[28]. Earlier work on signal processing for distributed learning and optimization is described in [29]. Highlighting some specific contributions that are particularly relevant for signal processing, references [30],
[31] focus on communication efficiency by proposing communication schemes based on quantized gradients for FL; while the works [32], [33] aim to improve performance of FL in the presence of non-iid datasets by leveraging a small shared dataset across all devices. The related work [33] adopts meta-learning to optimize personalized models in such settings, and the authors of [34] make use of reinforcement learning. Another line of work, exemplified by [35]–[37], investigates the convergence speed under different constraints such as radio resources and differential privacy. None of these prior works address Bayesian learning and uncertainty quantification.
Extensions of SVGD: Since its introduction, SVGD has been extended in various directions. Most related to this work is [38], which proposes a message-passing SVGD solution for high-dimensional latent spaces by leveraging conditional independence properties in the variational posterior; and [39], which uses SVGD as the per-task base learner in a meta-learning algorithm approximating expectation maximization.
Generalized Bayesian Inference: Owing to their reliance on point estimates in the model parameter space, frequentist learning methods, such as Federated Stochastic Gradient Descent (FedSGD), FedAvg, and their extensions [35], [40]–[42], are limited in their capacity to combat overfitting and quantify uncertainty [13], [43]–[45]. This contrasts with Bayesian learning, which produces distributional, rather than point, estimates by optimizing the free energy functional, which is a theoretically principled bound on the generalization performance [46], [47]. Practical algorithms for Bayesian learning can leverage computationally efficient scalable solutions based on either MC sampling or VI methods [16], [48].
Distributed MC Sampling: The design of algorithms for distributed Bayesian learning has been mostly focused on one-shot, or "embarrassingly parallel," solutions under ideal communications [49]. These implement distributed MC "consensus" protocols, whereby samples from the global posterior are approximately synthesized by combining particles from local posteriors [50], [51]. Iterative extensions, such as Weierstrass sampling [52], [53], impose consistency constraints across devices and iterations in a way similar to the Alternating Direction Method of Multipliers (ADMM) [16]. State-of-the-art results have been obtained via DSGLD [20].
Distributed VI Learning: Considering local models fusion, Bayesian methods have been used to deal with parameter invariance and weight matching [54], [55]. Iterative VI schemes such as Streaming Variational Bayes (SVB) [18] provide a VI-based framework for the exponential family to combine local models. PVI offers a general framework that can implement SVB, as well as online VI [56], and has been extended to multi-task learning in [19].

Fig. 2. Federated learning across K agents equipped with local datasets and assisted by a central server: (a) in Distributed Variational Inference (DVI) agents exchange the current model posterior $q^{(i)}(\theta)$ with the server, while (b) in DSVGD agents exchange particles $\{\theta_n\}_{n=1}^{N}$ providing a non-parametric estimate of the posterior.

III. SYSTEM SET-UP AND PROBLEM DEFINITION
We consider the federated learning set-up in Fig. 2, where each agent $k = 1, \ldots, K$ has a distinct local dataset with associated training loss $L_k(\theta)$ for model parameter $\theta$. The agents communicate through a central node with the goal of computing the global posterior distribution $q(\theta)$ over the shared model parameter $\theta \in \mathbb{R}^d$ for some prior distribution $p_0(\theta)$ [16]. Specifically, following the generalized Bayesian learning framework [47], the agents aim at obtaining the distribution $q(\theta)$ that minimizes the global free energy
$$\min_{q(\theta)} F(q(\theta)) \quad \text{where} \quad F(q(\theta)) = \sum_{k=1}^{K} \mathbb{E}_{\theta \sim q(\theta)}[L_k(\theta)] + \alpha D(q(\theta)\,\|\,p_0(\theta)), \tag{1}$$
where $\alpha > 0$ is a temperature parameter. The (generalized, or Gibbs) global posterior $q_{\mathrm{opt}}(\theta)$ solving problem (1) must strike a balance between minimizing the sum loss function (first term in $F(q)$) and the model complexity defined by the divergence from a reference prior (second term in $F(q)$). It is given as
$$q_{\mathrm{opt}}(\theta) = \frac{\tilde{q}_{\mathrm{opt}}(\theta)}{Z} \quad \text{where} \quad \tilde{q}_{\mathrm{opt}}(\theta) = p_0(\theta) \exp\Big(-\frac{1}{\alpha}\sum_{k=1}^{K} L_k(\theta)\Big), \tag{2}$$
where we denoted as $Z$ the normalization constant. It is useful to note that the global free energy can also be written as the scaled KL divergence $F(q(\theta)) = \alpha D(q(\theta)\,\|\,\tilde{q}_{\mathrm{opt}}(\theta))$.
The main challenge in computing the optimal posterior $q_{\mathrm{opt}}(\theta)$ in a distributed manner is that each agent $k$ is only aware of its local loss $L_k(\theta)$. By exchanging information through the server, the $K$ agents wish to obtain an estimate of the global posterior (2) without disclosing their local datasets either to the server or to the other agents. In this paper, we introduce a novel non-parametric distributed generalized Bayesian learning framework that addresses this challenge by integrating Distributed VI (DVI) and SVGD [25].

IV. DISTRIBUTED VARIATIONAL INFERENCE
In this section, we describe a general DVI algorithm to compute the global posterior in a federated fashion [19], [56]. DVI starts from the observation that the posterior (2) factorizes as the product
$$q(\theta) = p_0(\theta) \prod_{k=1}^{K} t_k(\theta), \tag{3}$$
where the term $t_k(\cdot)$ is given by the scaled local likelihood $\exp(-\alpha^{-1} L_k(\theta))/Z_k$ for any constants $\{Z_k\}_{k=1}^{K}$ such that we have $Z = \prod_{k=1}^{K} Z_k$. Since the normalization constant $Z$ depends on all data sets, the true scaled local likelihood $t_k(\cdot)$ cannot be directly computed at agent $k$. The idea of DVI is to iteratively update approximate likelihood factors $t_k(\theta)$ for $k = 1, \ldots, K$ by means of local optimization steps at the agents and communication through the server, with the aim of minimizing the global free energy (1) over distribution (3).
We first give here the standard implementation of DVI in which a single agent is scheduled at each time, and then briefly discuss the more general parallel implementation. Accordingly, at each communication round $i = 1, 2, \ldots$, the server maintains the current iterate $q^{(i-1)}(\theta)$ of the global posterior, and schedules an agent $k \in \{1, 2, \ldots, K\}$, which proceeds as follows:
1) Agent $k$ downloads the current global variational posterior distribution $q^{(i-1)}(\theta)$ from the server (see Fig. 2(a), step 1);
2) Agent $k$ updates the global posterior by minimizing the local free energy $F_k^{(i)}(q(\theta))$ (see Fig. 2(a), step 2)
$$q^{(i)}(\theta) = \arg\min_{q(\theta)} F_k^{(i)}(q(\theta)), \tag{4}$$
$$\text{where} \quad F_k^{(i)}(q(\theta)) = \mathbb{E}_{\theta \sim q(\theta)}[L_k(\theta)] + \alpha D(q(\theta)\,\|\,\hat{p}_k^{(i)}(\theta)), \tag{5}$$
where we have defined the (unnormalized) cavity distribution $\hat{p}_k^{(i)}(\theta)$ as
$$\hat{p}_k^{(i)}(\theta) = \frac{q^{(i-1)}(\theta)}{t_k^{(i-1)}(\theta)}. \tag{6}$$
The cavity distribution $\hat{p}_k^{(i)}(\theta)$, which removes the contribution of the current approximate likelihood of agent $k$ from the current global posterior iterate, serves as a prior for the update in (5). In a manner similar to (2), the local free energy is minimized by the tilted distribution $p_k^{(i)}(\theta) \propto \tilde{p}_k^{(i)}(\theta)$ with
$$\tilde{p}_k^{(i)}(\theta) = \hat{p}_k^{(i)}(\theta) \exp\Big(-\frac{1}{\alpha} L_k(\theta)\Big); \tag{7}$$
3) Agent $k$ sends the updated posterior $q^{(i)}(\cdot) = p_k^{(i)}(\cdot)$ to the server (see Fig. 2(a), step 3), and, due to the factorization in (3), updates its approximate likelihood as
$$t_k^{(i)}(\theta) = \frac{q^{(i)}(\theta)}{q^{(i-1)}(\theta)}\, t_k^{(i-1)}(\theta). \tag{8}$$
Finally, non-scheduled agents $k' \neq k$ set $t_{k'}^{(i)}(\theta) = t_{k'}^{(i-1)}(\theta)$, and the server sets the next iterate as $q^{(i)}(\theta)$.
Generalizing the described DVI algorithm, in the parallel implementation, at each iteration $i$, a subset $\mathcal{K}^{(i)}$ of agents is scheduled. The next iterate is set as $q^{(i)}(\theta) = p_0(\theta) \prod_{k \in \mathcal{K}^{(i)}} t_k^{(i)}(\theta) \prod_{k' \notin \mathcal{K}^{(i)}} t_{k'}^{(i)}(\theta)$, where we have $t_{k'}^{(i)}(\theta) = t_{k'}^{(i-1)}(\theta)$ for $k' \notin \mathcal{K}^{(i)}$, while $t_k^{(i)}(\theta)$ is updated by using (8) for $k \in \mathcal{K}^{(i)}$.
Theorem 1: Assuming that all devices are periodically scheduled, the global posterior $q_{\mathrm{opt}}(\theta)$ in (2) is the unique fixed point of the DVI algorithm for any choice of the subsets $\mathcal{K}^{(i)}$.
The fixed-point property in Theorem 1 can be verified directly by setting $q^{(i-1)}(\theta) = q_{\mathrm{opt}}(\theta)$ and $t_k^{(i-1)}(\theta) = \exp(-\alpha^{-1} L_k(\theta))/Z_k$ for any $\{Z_k\}_{k=1}^{K}$ such that $\prod_{k=1}^{K} Z_k = Z$, and by observing that this leads to the fixed-point condition $q^{(i)}(\theta) = q^{(i-1)}(\theta) = q_{\mathrm{opt}}(\theta)$. The proof is provided in Sec. A of the Appendix.

V. PRELIMINARIES
In this section, we briefly review PVI, which serves as an important benchmark, and SVGD, on which we build the proposed Bayesian federated learning solution.

A. Partitioned Variational Inference (PVI)
The exact minimization of the local free energy function (5) assumed by DVI is often not tractable. To address this problem, in its most typical form, PVI constrains the local free energy minimization (5) to the space of parametric distributions that factorize as $q(\theta|\eta) = p_0(\theta|\eta_0) \prod_{k=1}^{K} t_k(\theta|\eta_k)$, where prior $p_0(\cdot|\eta_0) = \mathrm{ExpFam}(\cdot|\eta_0)$ and approximate likelihood $t_k(\cdot|\eta_k) = \mathrm{ExpFam}(\cdot|\eta_k)$ are selected from the same exponential-family distribution, with natural parameters $\eta_0$ and $\eta_k$, respectively. PVI follows the same steps as DVI with the caveat that the local free energy (5) for agent $k$ is minimized over the natural parameter $\eta$. This can be done efficiently, albeit approximately, using natural gradient descent [57]. The bias imposed by the parametrization in PVI significantly affects the quality of the approximation of the obtained posterior $q(\theta)$ with respect to the true global posterior $q_{\mathrm{opt}}(\theta)$ in the presence of model misspecification. Note that, in this case, the fixed-point property in Theorem 1 no longer applies.

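The statement in Section IV that the local free energy (5) is minimized by the tilted distribution in (7) follows from a short rearrangement, sketched here under the assumption that $D(q\,\|\,p) = \mathbb{E}_{\theta \sim q}[\log(q(\theta)/p(\theta))]$ is also used when the second argument is unnormalized. The symbol $\tilde{Z}_k^{(i)}$ is introduced only for this sketch and does not appear in the paper.

```latex
\begin{align*}
F_k^{(i)}(q)
&= \mathbb{E}_{\theta\sim q}\big[L_k(\theta)\big]
 + \alpha\,\mathbb{E}_{\theta\sim q}\!\left[\log\frac{q(\theta)}{\hat p_k^{(i)}(\theta)}\right] \\
&= \alpha\,\mathbb{E}_{\theta\sim q}\!\left[\log\frac{q(\theta)}
   {\hat p_k^{(i)}(\theta)\exp\!\big(-\tfrac{1}{\alpha}L_k(\theta)\big)}\right]
 = \alpha\,D\big(q\,\|\,\tilde p_k^{(i)}\big) \\
&= \alpha\,D\big(q\,\|\,p_k^{(i)}\big) - \alpha\log\tilde Z_k^{(i)},
\qquad \tilde Z_k^{(i)} = \int \tilde p_k^{(i)}(\theta)\,d\theta .
\end{align*}
```

Since the last term does not depend on $q$ and the KL divergence between normalized distributions is non-negative with equality only for identical arguments, the minimizer is $q = p_k^{(i)} \propto \tilde{p}_k^{(i)}$. The same steps with $\hat{p}_k^{(i)}$ replaced by $p_0$ and $L_k$ replaced by $\sum_{k} L_k$ give the identity $F(q) = \alpha D(q\,\|\,\tilde{q}_{\mathrm{opt}})$ noted in Section III.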

B. Stein Variational Gradient Descent (SVGD)
SVGD [25] tackles the minimization of the (scaled) free energy functional $D(q(\theta)\,\|\,\tilde{p}(\theta))$, for an unnormalized target distribution $\tilde{p}(\theta)$, over a non-parametric generalized posterior $q(\theta)$ defined over the model parameters $\theta \in \mathbb{R}^d$. The posterior $q(\theta)$ is represented by a set of particles $\{\theta_n\}_{n=1}^{N}$, with $\theta_n \in \mathbb{R}^d$. In practice, an approximation of $q(\theta)$ can be obtained from the particles $\{\theta_n\}_{n=1}^{N}$ through a Kernel Density Estimator (KDE) as $q(\theta) = N^{-1} \sum_{n=1}^{N} K(\theta, \theta_n)$ for some kernel function $K(\cdot, \cdot)$ [58]. The particles are iteratively updated through a series of transformations that are optimized to minimize the free energy. The transformations are restricted to lie within the unit ball of a Reproducing Kernel Hilbert Space (RKHS) $\mathcal{H}^d = \mathcal{H} \times \ldots \times \mathcal{H}$. It is shown by [25] that this optimization yields the SVGD update
$$\theta_n^{[l]} \leftarrow \theta_n^{[l-1]} + \frac{\epsilon}{N} \sum_{j=1}^{N} \Big[ k\big(\theta_j^{[l-1]}, \theta_n^{[l-1]}\big) \nabla_{\theta_j} \log \tilde{p}\big(\theta_j^{[l-1]}\big) + \nabla_{\theta_j} k\big(\theta_j^{[l-1]}, \theta_n^{[l-1]}\big) \Big] \tag{9}$$
for $n = 1, \ldots, N$, where $\epsilon > 0$ is the learning rate and $k(\cdot, \cdot)$ is the positive definite kernel associated with RKHS $\mathcal{H}$. The first term in the update (9) drives the particles towards the regions of the target distribution $\tilde{p}(\theta)$ with high probability, while the second term drives the particles away from each other, encouraging exploration in the model parameter space. It is known that, in the asymptotic limit of a large number $N$ of particles, the empirical distribution encoded by the particles $\{\theta_n^{[l]}\}_{n=1}^{N}$ converges to the normalized target distribution $p(\theta) \propto \tilde{p}(\theta)$ [59].

VI. DISTRIBUTED STEIN VARIATIONAL GRADIENT DESCENT (DSVGD)
In this section, we introduce DSVGD, a novel distributed algorithm that tackles the generalized Bayesian inference problem (1) via DVI over a non-parametric particle-based representation of the global posterior. As illustrated in Fig. 2(b), DSVGD is based on the iterative optimization of local free energy functionals (5) via SVGD (see Section V), and on the exchange of particles between the central server and agents. Given the flexibility of the non-parametric form of the posterior, DSVGD does not suffer from the bias caused by the parametrization assumed by PVI. As a result, in the limit of a sufficiently large number of particles, DSVGD benefits from the fixed-point property of DVI stated in Theorem 1, recovering the true global posterior as a fixed point of its iterations. Furthermore, as we will discuss, DSVGD enables devices to exchange more informative messages regarding the current iterate of the posterior by increasing the number of particles. This can in turn reduce the number of communication rounds and the overall communication load to convergence, at the cost of a larger per-round load. In this regard, we note that, in practice, a small number of particles is sufficient to obtain state-of-the-art performance [25], as verified in Section VIII.
In order to facilitate the presentation, we first introduce a simpler version of DSVGD that has the practical drawback of requiring each agent to store a number of particles that increases linearly with the number of iterations in which the agent is scheduled. Then, we present a more practical algorithm, for which the memory requirements do not scale with the number of iterations as each agent must only memorize a set of $N$ local particles across different iterations. An algorithmic table for Unconstrained-DSVGD (U-DSVGD), together with discussions on complexity and convergence, can be found in the extended version of the paper [60]. A direct extension of DSVGD where multiple agents are scheduled per round can be found in Sec. B of the Appendix.

A. U-DSVGD
In this section, we present a simplified DSVGD variant, which we refer to as U-DSVGD. We follow the standard implementation of DVI with a single agent $k$ scheduled at each communication round $i = 1, 2, \ldots$, although, as discussed, parallel implementations are also possible. Let us define as $\mathcal{I}_k^{(i)} \subseteq \{1, \ldots, i\}$ the subset of rounds at which agent $k$ is scheduled prior to, and including, iteration $i$. At the beginning of each round $i$, the server maintains the iterate of the current global particles $\{\theta_n^{(i-1)}\}_{n=1}^{N}$, while each agent $k$ keeps a local buffer of particles $\{\theta_n^{(j-1)}, \theta_n^{(j)}\}_{n=1}^{N}$ for all previous rounds $j \in \mathcal{I}_k^{(i-1)}$ at which agent $k$ was scheduled. The growing memory requirements at the agents will be dealt with by the final version of DSVGD to be introduced in Section VI-B. Furthermore, as illustrated in Fig. 2(b), at each iteration $i$, U-DSVGD schedules an agent $k \in \{1, 2, \ldots, K\}$ and carries out the following steps.
1) Agent $k$ downloads the current global particles $\{\theta_n^{(i-1)}\}_{n=1}^{N}$ from the server (see Fig. 2(b), step 1) and includes them in the local buffer.
2) Agent $k$ updates each downloaded particle using SVGD as
$$\theta_n^{[l]} \leftarrow \theta_n^{[l-1]} + \epsilon\, \phi\big(\theta_n^{[l-1]}\big), \quad \text{for } l = 1, \ldots, L, \tag{10}$$
where $L$ is the number of local iterations; $[l]$ denotes the local iteration index; we have the initialization $\theta_n^{[0]} = \theta_n^{(i-1)}$; and the function $\phi(\cdot)$ is to be optimized within the unit ball of an RKHS $\mathcal{H}^d$. The function $\phi(\cdot)$ is specifically optimized to maximize the steepest descent decrease of a particle-based approximation of the local free energy (5). To elaborate, we denote as $q^{(i-1)}(\theta) = \sum_{n=1}^{N} K(\theta, \theta_n^{(i-1)})$ the KDE of the current global posterior iterate encoded by particles $\{\theta_n^{(i-1)}\}_{n=1}^{N}$. Adopting the factorization (3) for the global posterior (cf. (8)), we define the current local approximate likelihood
$$t_k^{(i-1)}(\theta) = \prod_{j \in \mathcal{I}_k^{(i-1)}} \frac{q^{(j)}(\theta)}{q^{(j-1)}(\theta)} = \frac{q^{(i-1)}(\theta)}{q^{(i-2)}(\theta)}\, t_k^{(i-2)}(\theta). \tag{11}$$
Note that (11) can be computed using all the particles in the buffer at agent $k$ at iteration $i$. Finally, the (unnormalized)
tilted distribution $\tilde{p}_k^{(i)}$ (cf. (7)) is written as
$$\tilde{p}_k^{(i)}(\theta) = \frac{q^{(i-1)}(\theta)}{t_k^{(i-1)}(\theta)} \exp\Big(-\frac{1}{\alpha} L_k(\theta)\Big). \tag{12}$$
Following SVGD, the update (10) is optimized to maximize the steepest descent decrease of the Kullback–Leibler (KL) divergence between the approximate global posterior $q_\phi^{[l]}(\theta)$ encoded via particles $\{\theta_n^{[l]}\}_{n=1}^{N}$ and the tilted distribution $\tilde{p}_k^{(i)}(\theta)$ in (12) (see Fig. 2(b), step 2), i.e.,
$$\phi(\cdot) \leftarrow \arg\max_{\phi(\cdot) \in \mathcal{H}^d} \Big\{ -\frac{d}{d\epsilon} D\big(q_\phi^{[l-1]}(\theta)\,\|\,\tilde{p}_k^{(i)}(\theta)\big) : \|\phi\|_{\mathcal{H}^d} \leq 1 \Big\}. \tag{13}$$
Thus, recalling (9), the particles are updated as
$$\theta_n^{[l]} \leftarrow \theta_n^{[l-1]} + \frac{\epsilon}{N} \sum_{j=1}^{N} \Big[ k\big(\theta_j^{[l-1]}, \theta_n^{[l-1]}\big) \nabla_{\theta_j} \log \tilde{p}_k^{(i)}\big(\theta_j^{[l-1]}\big) + \nabla_{\theta_j} k\big(\theta_j^{[l-1]}, \theta_n^{[l-1]}\big) \Big], \quad \text{for } l = 1, \ldots, L. \tag{14}$$
3) Agent $k$ sets $\theta_n^{(i)} = \theta_n^{[L]}$ for $n = 1, \ldots, N$. Particles $\{\theta_n^{(i)}\}_{n=1}^{N}$ are added to the buffer and sent to the server (see Fig. 2(b), step 3), which updates the current global particles as $\{\theta_n\}_{n=1}^{N} = \{\theta_n^{(i)}\}_{n=1}^{N}$.
In order to implement the described U-DSVGD algorithm, we need to compute the gradient in (14) at agent $k$. First, by (12), we have
$$\nabla_\theta \log \tilde{p}_k^{(i)}(\theta) = \nabla_\theta \log q^{(i-1)}(\theta) - \nabla_\theta \log t_k^{(i-1)}(\theta) - \frac{1}{\alpha} \nabla_\theta L_k(\theta). \tag{15}$$
Using (11), the second gradient term can be obtained in a recursive manner using the local buffer as
$$\nabla_\theta \log t_k^{(i-1)}(\theta) = \begin{cases} \nabla_\theta \log t_k^{(i-2)}(\theta) & \text{if agent } k \text{ not scheduled at iteration } (i-1) \\ \nabla_\theta \log t_k^{(i-2)}(\theta) + \nabla_\theta \log q^{(i-1)}(\theta) - \nabla_\theta \log q^{(i-2)}(\theta) & \text{otherwise.} \end{cases} \tag{16}$$
Finally, the gradients $\nabla_\theta \log q^{(j)}(\theta)$ can be directly computed from the KDE expression of $q^{(j)}(\theta)$, with initializations $t^{(0)}(\theta) = 1$ and $q^{(0)}(\theta) = p_0(\theta)$.
The inner loop of U-DSVGD inherits the asymptotic convergence properties of SVGD in terms of local free energies, but existing results do not imply that the global free energy decreases across the iterations. This result is provided in the next theorem, whose precise formulation can be found in Sec. A of the Appendix.
Theorem 2: (Guaranteed per-iteration decrease of the global free energy.) The decrease in the global free energy from local iteration $l$ to $l + 1$ during communication round $i$ for which agent $k$ is scheduled can be lower bounded as
$$F(q^{[l]}(\theta)) - F(q^{[l+1]}(\theta)) \geq \alpha \epsilon\, S\big(q^{[l]}, \tilde{p}_k^{(i)}\big)(1 - \epsilon\gamma) - 2\alpha(K-1)\, l_{\max}^{(i)} \sqrt{2 D\big(q^{[l+1]}\,\|\,q^{[l]}\big)}, \tag{17}$$
where $l_{\max}^{(i)} = \sup_\theta \max_{m \neq k} \big| \log(t_m^{(i-1)}(\theta)) \cdot \exp\big(\tfrac{1}{\alpha} L_m(\theta)\big) \big|$, $S(q, p)$ denotes the Kernelized Stein Discrepancy between distributions $q$ and $p$ [61], and $\gamma$ is a constant depending on the RKHS kernel and the target distribution.
The first term in bound (17) quantifies the decrease in the local free energy at agent $k$, which depends on the "distance" between current iterate $q^{[l]}$ and the local target given by the tilted distribution $p_k^{(i)}(\theta)$; while the second term quantifies the effect of the update on the local free energies of other agents. Interestingly, in the presence of only one agent, the second term reduces to zero, and one recovers the upper bound on the guaranteed per-iteration improvement for SVGD derived in [62].

Algorithm 1: Distributed Stein Variational Gradient Descent (DSVGD).
1: Input: prior $p_0(\theta)$, local loss functions $\{L_k(\theta)\}_{k=1}^{K}$, temperature $\alpha > 0$, kernels $K(\cdot, \cdot)$ and $k(\cdot, \cdot)$
2: Initialize: $q^{(0)}(\theta) = p_0(\theta)$; $\{\theta_n^{(0)}\}_{n=1}^{N} \overset{\text{i.i.d.}}{\sim} p_0(\theta)$; $\{\theta_{k,n}^{(0)} = \theta_n^{(0)}\}_{n=1}^{N}$ and $t_k^{(0)}(\theta) = 1$ for $k = 1, \ldots, K$
3: Output: global approximate posterior $q(\theta)$
4: for $i = 1$ to $I$ do
5:   Server schedules an agent $k$
6:   Agent $k$ downloads current global particles $\{\theta_n^{(i-1)}\}_{n=1}^{N}$ from server
7:   Agent $k$ obtains updated global particles $\{\theta_n^{(i)}\}_{n=1}^{N}$ using (14), $\{\theta_n^{(i-1)}\}_{n=1}^{N}$ and $\{\theta_{k,n}^{(i-1)}\}_{n=1}^{N}$
8:   Agent $k$ sends the updated global particles $\{\theta_n^{(i)}\}_{n=1}^{N}$ to the server
9:   Agent $k$ carries out distillation to obtain $\{\theta_{k,n}^{(i)}\}_{n=1}^{N}$ encoding $t_k^{(i)}(\theta)$ using (18) and $\{\theta_n^{(i)}\}_{n=1}^{N}$
10: end for
11: return $q(\theta) = N^{-1} \sum_{n=1}^{N} K(\theta, \theta_n^{(I)})$

B. DSVGD
In this section, we describe the final version of DSVGD, which, unlike U-DSVGD, requires each agent $k$ to maintain only $N$ local particles $\{\theta_{k,n}^{(i)}\}_{n=1}^{N}$ across the communication rounds $i = 1, 2, \ldots$. To this end, in each round $i$, at the end of the $L$ local SVGD updates in (14), DSVGD carries out a form of model distillation [63], [64] via SVGD. Specifically, $L'$ additional SVGD steps are used to approximate the term $t_k^{(i)}(\theta)$ using the $N$ local particles $\{\theta_{k,n}^{(i)}\}_{n=1}^{N}$. It is noted that this approximation step is not necessarily harmful to the overall performance, since describing the factor $t_k^{(i)}(\theta)$ with fewer particles can have a denoising effect acting as a regularizer.
DSVGD operates as U-DSVGD apart from the computation of the gradient in (15) and the management of the local particle buffers. The key idea is that, instead of using the recursion (16) to compute (15), DSVGD computes the gradient $\nabla_\theta \log t_k^{(i-1)}(\theta)$

from the KDE $t_k^{(i-1)}(\theta) = \sum_{n=1}^{N} K(\theta, \theta_{k,n}^{(i-1)})$ based on the local particles $\{\theta_{k,n}^{(i-1)}\}_{n=1}^N$ in the buffer. At the end of each round $i$, the local particles $\{\theta_{k,n}^{(i-1)}\}_{n=1}^N$ are updated by running $L'$ local SVGD iterations with target given by the updated local factor $t_k^{(i)}(\theta) = \frac{q^{(i)}(\theta)}{q^{(i-1)}(\theta)} t_k^{(i-1)}(\theta)$. This amounts to the updates

$$\theta_{k,n}^{[l']} \leftarrow \theta_{k,n}^{[l'-1]} + \frac{\epsilon'}{N} \sum_{j=1}^{N} \Big[ k\big(\theta_{k,j}^{[l'-1]}, \theta_{k,n}^{[l'-1]}\big) \nabla_{\theta_j} \log t_k^{(i)}\big(\theta_{k,j}^{[l'-1]}\big) + \nabla_{\theta_j} k\big(\theta_{k,j}^{[l'-1]}, \theta_{k,n}^{[l'-1]}\big) \Big], \tag{18}$$

for $l' = 1, \ldots, L'$ and some learning rate $\epsilon'$, where the gradient $\nabla_\theta \log t_k^{(i)}(\theta) = \nabla_\theta \log q^{(i)}(\theta) + \nabla_\theta \log t_k^{(i-1)}(\theta) - \nabla_\theta \log q^{(i-1)}(\theta)$ can be directly computed using KDE based on the available particles $\{\theta_n^{(i)}\}_{n=1}^N$ (updated global particles), $\{\theta_{k,n}^{(i-1)}\}_{n=1}^N$ (local particles) and $\{\theta_n^{(i-1)}\}_{n=1}^N$ (downloaded global particles). Finally, we note that the distillation operation can be performed after sending the updated global particles to the server, thus enabling pipelining of the $L'$ local iterations with operations at the server and other agents. DSVGD is summarized in Algorithm 1 while Parallel-Distributed Stein Variational Gradient Descent (P-DSVGD) is detailed in Sec. B and Algorithm 2 in the Appendix.

Using DSVGD, the communication load between a scheduled agent and the central server is of the order $O(Nd)$ since $N$ particles of dimension $d$ need to be exchanged at each communication round. In contrast, the communication load of PVI depends on the selected parametrization. For instance, one can use PVI with a fully factorized Gaussian approximate posterior, which requires only $2d$ parameters to be shared with the server, namely mean and variance of each of the $d$ parameters, at the price of having lower accuracy. That said, while DSVGD increases the communication load compared to state-of-the-art FL algorithms, the number of particles required is very low in practice [25], and the overall communication load across the iterations turns out to be of the same order as FedAvg in all our experiments (see Section VIII).

VII. CONVERGENCE AND SPACE-TIME COMPLEXITY

This section summarizes known properties of SVGD in terms of convergence and complexity.

Convergence: The two local SVGD loops in Algorithm 1 produce a set of global and local particles, respectively, that are convergent to their respective targets as the number $N$ of particles increases [59]. Furthermore, as discussed, a fixed point of the set of local free energy minimization problems is guaranteed to be a local optimum for the global free energy problem (see Property 3 in [56]). This property hence carries over to DSVGD in the limit of a large number of particles. Theorem 1 shows the uniqueness of such a point. However, the convergence rate to a fixed point is an open question for PVI, and consequently also for DSVGD. Nevertheless, Theorem 2 bounds the per-iteration change of the global free energy for U-DSVGD.

Time Complexity: When scheduled, an agent has to perform $O(\max(L, L')N^2)$ operations, with $O(LN^2)$ operations to update the global particles (line 5 in Algorithm 1) with $L$ local iterations and $O(L'N^2)$ operations to carry out distillation (line 9 in Algorithm 1) with $L'$ local iterations. Furthermore, the $L'$ distillation iterations can be performed by the scheduled agent after it has sent its global particles to the central server. This enables the pipelining of the second loop with the operations at the server and at other agents, which can potentially reduce the wall-clock time per communication round. It is important to note that, from the perspective of each agent, the time complexity of DSVGD is similar to that of SVGD.

Space Complexity: DSVGD inherits the space complexity of SVGD. In particular, DSVGD requires the computation of the kernel matrix $k(\cdot, \cdot)$ between all particles at each local iteration, which can then be deleted before the next iteration. This requires $O(N^2)$ space. As pointed out by [25] and noticed in our experiments reported in the next section, for sufficiently small problems of practical interest for mobile embedded applications, few particles are enough to obtain state-of-the-art performance. Furthermore, $N$ particles of dimension $d$ need to be saved in the local buffer, requiring $O(Nd)$ space. Given that $N$ is generally much lower than the number of data samples, saving the particles in the local buffer is not problematic.

VIII. EXPERIMENTS

As in [25], for all our experiments¹ with SVGD and DSVGD, we use the Radial Basis Function (RBF) kernel $k(x, x_0) = \exp(-\|x - x_0\|_2^2 / h)$. The bandwidth $h$ is adapted to the set of particles used in each update by setting $h = \mathrm{med}^2 / \log n$, where med is the median of the pairwise distances between the particles in the current iterate. The Gaussian kernel $K(\cdot, \cdot)$ used for the KDEs has a bandwidth equal to 0.55. Unless specified otherwise, we use AdaGrad with momentum to choose the learning rates $\epsilon$ and $\epsilon'$ for DSVGD. Throughout, we fix the temperature parameter $\alpha = 1$ in (1). Finally, to ensure a fair comparison with distributed schemes, we run centralized schemes for the same total number $I \times L$ of iterations across all experiments. Additional results can be found in the extended version of the paper [60], which also includes additional implementation details, robustness to data heterogeneity (non-i.i.d.) experiments and a comparison with PVI and U-DSVGD.

¹ Our code is made public and can be found at https://github.com/kclip/DSVGD.

Gaussian 1D mixture toy example: We start by considering a simple one-dimensional mixture model in which the unnormalized local posteriors $p_k(\theta) = p_0(\theta) \exp(-\alpha^{-1} L_k(\theta))$ at each agent $k$ are defined as $p_1(\theta) = p_0(\theta) \mathcal{N}(\theta | 1, 4)$ and $p_2(\theta) = p_0(\theta) (\mathcal{N}(\theta | -3, 1) + \mathcal{N}(\theta | 3, 2))$, and the prior $p_0(\theta)$ is uniform over $[-6, 6]$, i.e., $p_0(\theta) = \mathcal{U}(\theta | -6, 6)$. The local posteriors are shown in Fig. 3 as dashed lines, along with the global posterior $q_{opt}(\theta) \propto \tilde{q}_{opt}(\theta)$ in (2), which is represented as a shaded area. We fix the number of particles to $N = 200$. The approximate posteriors obtained from the KDE over the global particles are plotted in Fig. 3 as solid lines. It can be


observed that at each round, the global posterior updated by DSVGD integrates the local likelihood of the scheduled agent, while still preserving information about the likelihood of the other agent from prior iterates, until (approximate) convergence to the true global posterior $q_{opt}$, which is a normalized version of $\tilde{q}_{opt}$ in (2), is reached.

Fig. 3. Gaussian toy example with uniform prior and K = 2. Dashed lines represent local posteriors, the shaded area represents the true global posterior, while the solid black line is the approximate posterior obtained using a KDE over the particles. DSVGD schedules agents 1 and 2 at odd and even communication rounds i, respectively.

Fig. 4. Accuracy for Bayesian logistic regression with (left) K = 2 agents and (right) K = 20 agents as function of the number of communication rounds i (N = 6 particles, L = L′ = 200).
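The toy example above can be reproduced in outline with a few lines of NumPy. The following is a minimal sketch, not the authors' implementation: it runs plain centralized SVGD (not DSVGD) on the global posterior of the toy model, using the RBF kernel with the median bandwidth rule described earlier, and then builds a Gaussian KDE over the particles. Several choices are assumptions made only for illustration: the second argument of $\mathcal{N}(\theta | \mu, \cdot)$ is read as a variance, the KDE bandwidth 0.55 is read as a standard deviation, a fixed step size replaces AdaGrad, 50 particles are used instead of 200, and the uniform prior is handled crudely by clipping particles to $[-6, 6]$. All function names are hypothetical.

```python
import numpy as np

def normal_pdf(x, mean, var):
    """Gaussian density with the second parameter read as a variance (assumption)."""
    return np.exp(-(x - mean) ** 2 / (2.0 * var)) / np.sqrt(2.0 * np.pi * var)

def grad_log_global(theta):
    """Gradient of log[N(theta|1,4) * (N(theta|-3,1) + N(theta|3,2))] inside [-6, 6].
    The uniform prior contributes zero gradient on its support."""
    a = normal_pdf(theta, -3.0, 1.0)
    b = normal_pdf(theta, 3.0, 2.0)
    grad_mix = (-(theta + 3.0) / 1.0 * a - (theta - 3.0) / 2.0 * b) / (a + b)
    return -(theta - 1.0) / 4.0 + grad_mix

def svgd_step(theta, grad_log_p, eps):
    """One SVGD update for 1-D particles with an RBF kernel exp(-d^2 / h)."""
    n = len(theta)
    diff = theta[:, None] - theta[None, :]        # diff[j, m] = theta_j - theta_m
    med = np.median(np.abs(diff[np.triu_indices(n, k=1)]))
    h = med ** 2 / np.log(n)                      # bandwidth rule h = med^2 / log n
    k = np.exp(-diff ** 2 / h)                    # k[j, m] = k(theta_j, theta_m)
    grad_k = -2.0 * diff / h * k                  # derivative of k w.r.t. theta_j
    # Average over j of the attraction term and the repulsion term.
    phi = (k * grad_log_p(theta)[:, None] + grad_k).mean(axis=0)
    # Clipping is a crude stand-in for the uniform prior support (assumption).
    return np.clip(theta + eps * phi, -6.0, 6.0)

def kde(grid, particles, bandwidth=0.55):
    """Gaussian KDE over the particles; bandwidth treated as a standard deviation."""
    z = (grid[:, None] - particles[None, :]) / bandwidth
    norm = len(particles) * bandwidth * np.sqrt(2.0 * np.pi)
    return np.exp(-0.5 * z ** 2).sum(axis=1) / norm

rng = np.random.default_rng(0)
particles = rng.uniform(-6.0, 6.0, size=50)       # particles drawn from the prior
for _ in range(1000):
    particles = svgd_step(particles, grad_log_global, eps=0.05)

# Reference global posterior, normalized numerically on a grid.
grid = np.linspace(-6.0, 6.0, 1201)
dx = grid[1] - grid[0]
unnorm = normal_pdf(grid, 1.0, 4.0) * (
    normal_pdf(grid, -3.0, 1.0) + normal_pdf(grid, 3.0, 2.0)
)
q_opt = unnorm / (unnorm.sum() * dx)
q_kde = kde(grid, particles)                      # KDE approximation from particles
```

Plotting `q_opt` and `q_kde` against `grid` gives a centralized counterpart of the shaded area and solid line in Fig. 3; in DSVGD the same kind of update is applied by one scheduled agent at a time, with the tilted distribution as target.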
Bayesian logistic regression: We now consider Bayesian logistic regression for binary classification using the same setting as in [65]. The model parameters $\theta = [w, \log(\xi)]$ include the regression weights $w \in \mathbb{R}^d$ along with the logarithm of a precision parameter $\xi$. The prior is given as $p_0(w, \xi) = p_0(w | \xi) p_0(\xi)$, with $p_0(w | \xi) = \mathcal{N}(w | 0, \xi^{-1} I_d)$ and $p_0(\xi) = \mathrm{Gamma}(\xi | a, b)$ with $a = 1$ and $b = 0.01$. The local training loss $L_k(\theta)$ at each agent $k$ is given as $L_k(\theta) = \sum_{(x_k, y_k) \in D_k} l(x_k, y_k, w)$, where $D_k$ is the dataset at agent $k$ with covariates $x_k \in \mathbb{R}^d$ and label $y_k \in \{-1, 1\}$, and the loss function $l(x_k, y_k, w)$ is the cross-entropy. Point decisions are taken based on the maximum of the average predictive distribution. We consider the datasets Covertype and Twonorm [65]. We randomly split the training dataset into partitions of equal size among the $K$ agents. We also include FedAvg, Stochastic Gradient Langevin Dynamics (SGLD), FedBe [64] and DSGLD for comparison. We note that FedAvg is implemented here for consistency with the other schemes by scheduling a single agent at each step.

In Fig. 4, we study how the accuracy evolves as a function of the number of communication rounds $i$ across different datasets, using $N = 2$ and $N = 6$ particles. We observe that DSVGD consistently outperforms the mentioned decentralized benchmarks and that, in contrast to FedAvg and DSGLD, its performance scales well with the number $K$ of agents. Furthermore, the number $N$ of particles is seen to control the trade-off between the communication load, which increases with $N$, and the convergence speed, which improves as $N$ grows larger. Through reduction of the number of communication rounds, DSVGD can also reduce the overall communication load. For example, in the third plot in Fig. 4, DSVGD reaches an accuracy of 70% after 5 communication rounds with $N = 6$, requiring the exchange of 30 particles. In contrast, FedAvg requires around 100 rounds to obtain the same accuracy, making the total communication load much higher than that of DSVGD.

In order to shed light on DSVGD’s computational advantage, in Fig. 5, we plot the accuracy for Bayesian logistic regression as a function of the number of local iterations $L$ for a fixed number of communication rounds. It can be seen that DSVGD has a higher computation efficiency than the other decentralized schemes (FedAvg and DSGLD), as it needs a lower number of local iterations to achieve the same level of accuracy.

Fig. 5. Bayesian logistic regression accuracy for covertype (left) and twonorm (right) datasets with K = 20 agents using the setting in [65] comparing U-DSVGD and DSVGD to distributed (DSGLD) and centralized (SVGD and SGLD) schemes as function of the local iterations number L (N = 6 particles, and I = 20).

Bayesian Neural Networks: We now consider regression and multi-label classification with Bayesian Neural Networks (BNN) models. The experimental setup is the same as in [66], with the only exception that the prior of the weights is set to $p_0(w) = \mathcal{N}(w | 0, \lambda^{-1} I_d)$ with a fixed precision $\lambda = e$. We plot


the average Root Mean Square Error (RMSE) for K = 2 and K = 20 agents in Fig. 6 for regression over the Kin8nm and Year datasets, and accuracy for multi-label classification on the MNIST and Fashion MNIST datasets in Fig. 7. Confirming the results for logistic regression, DSVGD consistently outperforms the other decentralized benchmarks in terms of RMSE and accuracy, while being more robust in terms of convergence speed to an increase in the number of agents.

Fig. 6. Average RMSE as a function of the number of communication rounds for regression using Bayesian Neural Networks (BNN) with a single hidden layer of ReLUs with (left) K = 2 and (right) K = 20 agents (N = 20, L = 200, 100 hidden neurons for Year Prediction and 50 for Kin8nm).

Fig. 7. Multi-label classification accuracy using BNN with a single hidden layer of 100 neurons as function of i, or number of communication rounds, using MNIST and Fashion MNIST with (left) K = 2 agents and (right) K = 20 agents (N = 20, L = L′ = 200).

Calibration: Reliability plots are a visual tool used to assess model calibration [13]. They report the average sample accuracy as a function of the confidence level. Perfect calibration yields an accuracy equal to the corresponding confidence (dashed line in Fig. 8). Fig. 8 shows the reliability plots for FedAvg and DSVGD on the Fashion MNIST dataset for the BNN setting. While increasing the number of neurons negatively affects FedAvg due to overfitting, DSVGD enjoys excellent calibration even for large models and is hence able to make trustworthy predictions.

Fig. 8. Reliability plots for classification using BNN with variable number of hidden neurons using fashion MNIST (N = 20, I = 10, L = L′ = 200, K = 20).

In Fig. 9, we plot the accuracy (black lines) and the maximum calibration error (MCE) (red lines) [13]. The MCE is defined as [13]

$$\mathrm{MCE} = \max_{m = 1, \ldots, M} \left| \mathrm{accuracy}(B_m) - \mathrm{confidence}(B_m) \right|,$$

where bin $B_m$ is the $m$-th bin containing data samples having confidence in the range $[(m - 1)/M, m/M)$. Here, $\mathrm{accuracy}(B_m)$ is the fraction of correctly classified samples in bin $B_m$, while $\mathrm{confidence}(B_m)$ is the average confidence of the predictions for the samples in bin $B_m$. It can be seen that the MCE for Bayesian learning schemes (dashed and solid red lines) is lower than that of the frequentist counterpart (dotted red line), while maintaining a high accuracy. For example, with N = 2 particles, DSVGD achieves an MCE of 0.06 on the covertype dataset while FedAvg achieves an MCE of 0.13.

Fig. 9. Accuracy and maximum calibration error (MCE) as function of the number of particles N for Bayesian neural networks (I = 10, L = L′ = 200 and K = 20 agents).

IX. CONCLUSION

This paper has introduced DSVGD, a non-parametric distributed variational inference algorithm for generalized Bayesian federated learning. DSVGD enables a flexible trade-off


between per-iteration communication load and number of communication rounds, while being able to make trustworthy decisions via Bayesian inference. Communication efficiency could be further improved by implementing compression and quantization schemes [30]. In this work, we have not investigated this aspect, which we leave for future work. However, we observe that communication load reductions due to quantization and compression are expected to apply equally to both frequentist and Bayesian particle-based schemes. Quantization noise, if modelled as Gaussian noise, may even be advantageous in improving the performance of SVGD [67].

APPENDIX

A Proofs

In this section, we prove Theorems 1 and 2.

Proof: Consider the general implementation of DVI, whereby a set $\mathcal{K}$ of agents are scheduled in parallel. DVI is equivalent to the following functional mapping

$$\prod_{i \in \mathcal{K}} t_i(\theta) \rightarrow \prod_{i \in \mathcal{K}} t_i'(\theta)$$

$$\{t_k(\theta)\}_{k \in \mathcal{K}} \rightarrow t_k'(\theta) \propto \frac{q(\theta)}{t_k(\theta)} \exp\Big(-\frac{1}{\alpha} L_k(\theta)\Big)$$

$$q(\theta) = p_0(\theta) \prod_{i=1}^{K} t_i(\theta) \rightarrow q'(\theta) = p_0(\theta) \prod_{i \notin \mathcal{K}} t_i(\theta) \prod_{k \in \mathcal{K}} t_k'(\theta)$$

Therefore, assuming that all devices $k$ are periodically scheduled, $q(\theta)$ is a fixed point of DVI if and only if the following equality holds

$$t_k'(\theta) = t_k(\theta) \quad \text{for } k = 1, \ldots, K.$$

This condition is satisfied by setting $q(\theta) = q_{opt}(\theta)$, along with $t_k(\theta) = \frac{1}{Z_k} \exp(-L_k(\theta)/\alpha)$, and by no other distribution. This concludes the proof. ∎

We move now to Theorem 2 for U-DSVGD. We leave the analysis of the impact of the additional distillation step used by DSVGD for future work. The analysis builds on the following result from [62], which is restated here using our notation. Denote by $\|\cdot\|_{\mathcal{H}}$ the norm in the RKHS $\mathcal{H}$ defined by the positive definite kernel $k(\theta, \theta')$. We assume that the kernel satisfies the following technical condition: there exists a constant $B > 0$ such that

$$\|k(\theta, \cdot)\|_{\mathcal{H}} \le B \quad \text{and} \quad \sum_{j=1}^{d} \left\| \frac{\partial k(\theta, \cdot)}{\partial \theta_j} \right\|_{\mathcal{H}}^2 \le B^2. \tag{19}$$

This condition is for instance satisfied by the RBF kernel with $B = 1$ [68]. Furthermore, we define the kernelized Stein discrepancy [61] between two distributions $p$ and $q$ as $S(p, q)$, and the total variation distance as $\|q - p\|_{TV} = \frac{1}{2} \int |q(\theta) - p(\theta)| d\theta$.

Lemma 1: (Guaranteed per-iteration decrease of the local free energy.) [62] For a kernel satisfying (19), assume that, at a given communication round $i$ and local iteration $l$, with agent $k$ scheduled, we have:
• the maximum absolute eigenvalue of the Hessian $-\nabla^2 \log \tilde{p}_k^{(i)}(\theta)$ is upper bounded by a constant $M > 0$; and
• the inequality $S(q^{[l]}(\theta), \tilde{p}_k^{(i)}) < C$ holds for some $C > 0$.
For learning rate $\epsilon \le (\beta - 1)/(\beta B C^{1/2})$ with any $\beta > 1$, the decrease in the local KL divergence from local iteration $l$ to $l + 1$ satisfies the inequality

$$F(q^{[l+1]}(\theta)) - F(q^{[l]}(\theta)) \le -\alpha \epsilon S\big(q^{[l]}, \tilde{p}_k^{(i)}\big) (1 - \epsilon \gamma), \tag{20}$$

where $\gamma = ((\beta^2 + M) B^2)/2$.

Lemma 1 shows that by choosing a learning rate $\epsilon \le \min(\gamma^{-1}, (\beta - 1)/(\beta B C^{1/2}))$, one can guarantee a per-iteration decrease in the local free energy, i.e., in the KL divergence between the particles’ distribution and the target tilted distribution $\tilde{p}_k^{(i)}(\theta)$, that depends on the kernelized Stein discrepancy $S(q^{[l]}, \tilde{p}_k^{(i)})$ at the iteration before the update.

Lemma 2: (Relationship between global and local free energy.) The global free energy $F(q(\theta))$ in (1) is related to the local free energy $F_k^{(i)}(q(\theta))$ in (5) of the $k$-th scheduled agent as

$$F(q(\theta)) = F_k^{(i)}(q(\theta)) + \alpha \sum_{m \ne k} \mathbb{E}_{q(\theta)} \left[ \log \frac{t_m^{(i-1)}(\theta)}{\exp(-\frac{1}{\alpha} L_m(\theta))} \right]. \tag{21}$$

Proof: The global free energy (1) can be written as

$$\begin{aligned}
F(q(\theta)) &= \alpha \mathbb{E}_{q(\theta)} \left[ \log \frac{q(\theta)}{p_0(\theta) \exp(-\frac{1}{\alpha} \sum_{m=1}^{K} L_m(\theta))} \right] \\
&= \alpha \mathbb{E}_{q(\theta)} \left[ \log \left( \frac{q(\theta)}{\tilde{p}_k^{(i)}(\theta)} \cdot \frac{q^{(i-1)}(\theta) / t_k^{(i-1)}(\theta)}{p_0(\theta) \exp(-\frac{1}{\alpha} \sum_{m \ne k} L_m(\theta))} \right) \right] \\
&= \alpha \mathbb{E}_{q(\theta)} \left[ \log \frac{q(\theta)}{\tilde{p}_k^{(i)}(\theta)} \right] + \alpha \mathbb{E}_{q(\theta)} \left[ \log \frac{p_0(\theta) \prod_{m \ne k} t_m^{(i-1)}(\theta)}{p_0(\theta) \exp(-\frac{1}{\alpha} \sum_{m \ne k} L_m(\theta))} \right] \\
&= F_k^{(i)}(q(\theta)) + \alpha \sum_{m \ne k} \mathbb{E}_{q(\theta)} \left[ \log \frac{t_m^{(i-1)}(\theta)}{\exp(-\frac{1}{\alpha} L_m(\theta))} \right],
\end{aligned} \tag{22}$$

where in the second equality we have used (12); and in the third equality we have used the equality $q^{(i-1)}(\theta) = p_0(\theta) \prod_{m=1}^{K} t_m^{(i-1)}(\theta)$, which is guaranteed by the U-DSVGD update (11) and (12) (see [56, Property 2]). ∎

We know from Lemma 1 that a learning rate $\epsilon \le \min(\gamma^{-1}, (\beta - 1)/(\beta B C^{1/2}))$ is sufficient to ensure a per-iteration decrease in the local free energy. Given that the KL divergence in the second term in (17) generally increases with $\epsilon$, Theorem 2 demonstrates that, in order to guarantee a reduction of the global free energy, a smaller learning rate may be required. We also note that the KL divergence term $D(q^{[l+1]} \| q^{[l]})$ may be explicitly related to the learning rate by following [69, Sec. 8], but we do not further pursue this aspect here. We finally remark


that, in the presence of $K = 1$ agent, the upper bound (20) in [62] is recovered. This is because, in the presence of one agent, the global free energy reduces to the local free energy (see (21)) and accordingly U-DSVGD reduces to SVGD.

Proof: We wish to obtain an upper bound on the decrease of the global free energy $F(q^{[l+1]}(\theta)) - F(q^{[l]}(\theta))$ across each local SVGD iteration during communication round $i$. Using (21), the decrease in the global free energy can be written as

$$F(q^{[l+1]}(\theta)) - F(q^{[l]}(\theta)) = \underbrace{F_k^{(i)}\big(q^{[l+1]}(\theta)\big) - F_k^{(i)}\big(q^{[l]}(\theta)\big)}_{(a)} + \alpha \sum_{m \ne k} \underbrace{\left( \mathbb{E}_{q^{[l+1]}(\theta)} \left[ \log \frac{t_m^{(i-1)}(\theta)}{\exp(-\frac{1}{\alpha} L_m(\theta))} \right] - \mathbb{E}_{q^{[l]}(\theta)} \left[ \log \frac{t_m^{(i-1)}(\theta)}{\exp(-\frac{1}{\alpha} L_m(\theta))} \right] \right)}_{(b)}. \tag{23}$$

We now derive upper bounds for (a) and (b). Using Lemma 1 and the definition of the local free energy in (5), we have the following upper bound on (a)

$$(a) = F_k^{(i)}(q^{[l+1]}(\theta)) - F_k^{(i)}(q^{[l]}(\theta)) \le -\alpha \epsilon S\big(q^{[l]}(\theta), \tilde{p}_k^{(i)}\big) (1 - \epsilon \gamma),$$

while (b) can be rewritten and upper bounded by using the properties of the total variation distance as

$$(b) = \int \big(q^{[l+1]}(\theta) - q^{[l]}(\theta)\big) \log \frac{t_m^{(i-1)}(\theta)}{\exp(-\frac{1}{\alpha} L_m(\theta))} d\theta \le 2 l_{\max}^{(i)} \|q^{[l+1]} - q^{[l]}\|_{TV}.$$

Using Pinsker’s inequality [70], the term (b) can be further upper bounded as

$$(b) \le 2 l_{\max}^{(i)} \sqrt{2 D(q^{[l+1]} \| q^{[l]})}.$$

Accordingly, the global energy dissipation in (23) can be upper bounded as in (17). ∎

B Parallel-DSVGD

In this section, we present a direct extension of DSVGD in which multiple agents can be scheduled in parallel during the same communication round. In Parallel-DSVGD (P-DSVGD), each agent in the set $\mathcal{K}^{(i)}$ of scheduled agents at round $i$ applies the same steps as in DSVGD except that it shares the local particles $\{\theta_{k,n}^{(i)}\}_{n=1}^N$ with the server instead of the global ones. Then, the server distills the received local particles into a set of $N$ server-side particles $\{\phi_n^{(i)}\}_{n=1}^N$ using SVGD to obtain the next iterate of the global posterior.

As highlighted in Section IV of the main text, a parallel implementation requires the $i$-th iterate of the global posterior to be obtained as

$$q^{(i)}(\theta) = p_0(\theta) \prod_{k \in \mathcal{K}^{(i)}} t_k^{(i)}(\theta) \prod_{k' \notin \mathcal{K}^{(i)}} t_{k'}^{(i)}(\theta), \tag{24}$$

where $t_{k'}^{(i)}(\theta) = t_{k'}^{(i-1)}(\theta)$ for $k' \notin \mathcal{K}^{(i)}$. To replicate this same behaviour while preserving the non-parametric property of DSVGD, in P-DSVGD, each agent $k \in \mathcal{K}^{(i)}$ shares its local particles $\{\theta_{k,n}^{(i)}\}_{n=1}^N$ representing the approximate likelihood, where $t_k^{(i)}(\theta) = N^{-1} \sum_{n} K(\theta, \theta_{k,n}^{(i)})$. Then, to approximate $q^{(i)}(\theta)$ in (24), using SVGD, the server carries out $L_s$ SVGD updates as

$$\phi_n^{[l]} \leftarrow \phi_n^{[l-1]} + \frac{\epsilon}{N} \sum_{j=1}^{N} \Big[ k\big(\phi_j^{[l-1]}, \phi_n^{[l-1]}\big) \nabla_{\phi_j} \log q^{(i)}\big(\phi_j^{[l-1]}\big) + \nabla_{\phi_j} k\big(\phi_j^{[l-1]}, \phi_n^{[l-1]}\big) \Big], \quad \text{for } l = 1, \ldots, L_s. \tag{25}$$

For the $(i + 1)$-th communication round, scheduled agents in $\mathcal{K}^{(i+1)}$ download particles $\{\phi_n^{(i+1)}\}_{n=1}^N = \{\phi_n^{[L_s]}\}_{n=1}^N$ that are treated in a similar fashion as in DSVGD. The full algorithmic table for P-DSVGD is provided in Algorithm 2. Numerical results for P-DSVGD are provided in [60].

Algorithm 2: Parallel-Distributed Stein Variational Gradient Descent (P-DSVGD).
1: Input: prior $p_0(\theta)$, local loss functions $\{L_k(\theta)\}_{k=1}^K$, temperature $\alpha > 0$, kernels $K(\cdot, \cdot)$ and $k(\cdot, \cdot)$
2: Output: global approximate posterior $q(\theta) = N^{-1} \sum_{n=1}^{N} K(\theta, \phi_n)$
3: Initialize: $q^{(0)}(\theta) = p_0(\theta)$; $\{\theta_n^{(0)}\}_{n=1}^N \overset{\text{i.i.d.}}{\sim} p_0(\theta)$; $\{\phi_n^{(0)} = \theta_{k,n}^{(0)} = \theta_n^{(0)}\}_{n=1}^N$ and $t_k^{(0)}(\theta) = 1$ for $k = 1, \ldots, K$
4: for $i = 1$ to $I$ do
5:   Server schedules a set $\mathcal{K}^{(i)}$ of agents in parallel
6:   Agents download current server particles $\{\phi_n^{(i-1)}\}_{n=1}^N$ from the server
7:   Agents obtain updated global particles $\{\theta_n^{(i)}\}_{n=1}^N$ using (14), $\{\theta_n^{(i-1)} = \phi_n^{(i-1)}\}_{n=1}^N$ and $\{\theta_{k,n}^{(i-1)}\}_{n=1}^N$
8:   Agents carry out distillation to obtain $\{\theta_{k,n}^{(i)}\}_{n=1}^N$ encoding $t_k^{(i)}(\theta)$ using (18) and $\{\theta_n^{(i)}\}_{n=1}^N$
9:   Agents send the obtained local particles $\{\theta_{k,n}^{(i)}\}_{n=1}^N$ for $k \in \mathcal{K}^{(i)}$ to the server
10:  Server obtains $\{\phi_n^{(i)}\}_{n=1}^N$ using (25), $\{\phi_n^{(i-1)}\}_{n=1}^N$ and $\{\theta_{k,n}^{(i)}\}_{n=1}^N$ for $k \in \mathcal{K}^{(i)}$
11: end for
12: return $q(\theta) = N^{-1} \sum_{n=1}^{N} K(\theta, \phi_n^{(I)})$

REFERENCES

[1] P. Kairouz et al., “Advances and open problems in federated learning,” Foundations Trends Mach. Learn., vol. 14, no. 1/2, pp. 1–210, 2021, doi: 10.1561/2200000083.


[2] A. Vabalas, E. Gowen, E. Poliakoff, and A. J. Casson, “Machine learning algorithm validation with a limited sample size,” PLoS One, vol. 14, no. 11, 2019, Art. no. e0224365.
[3] T. Grote, “Trustworthy medical AI systems need to know when they don’t know,” J. Med. Ethics, vol. 47, no. 5, pp. 337–338, 2021. [Online]. Available: https://jme.bmj.com/content/47/5/337
[4] F. P.-C. Lin, C. G. Brinton, and N. Michelusi, “Federated learning with communication delay in edge networks,” in Proc. IEEE Global Commun. Conf., 2020, pp. 1–6, doi: 10.1109/GLOBECOM42002.2020.9322592.
[5] E. Toreini, M. Aitken, K. Coopamootoo, K. Elliott, C. G. Zelaya, and A. van Moorsel, “The relationship between trust in AI and trustworthy machine learning technologies,” in Proc. Conf. Fairness, Accountability, Transparency, New York, NY, USA, 2020, pp. 272–283. [Online]. Available: https://doi.org/10.1145/3351095.3372834
[6] E. Hüllermeier and W. Waegeman, “Aleatoric and epistemic uncertainty in machine learning: An introduction to concepts and methods,” Mach. Learn., vol. 110, no. 3, pp. 457–506, 2021. [Online]. Available: https://doi.org/10.1007/s10994-021-05946-3
[7] X. Wang, H. Liu, C. Shi, and C. Yang, “Be confident! towards trustworthy graph neural networks via confidence calibration,” in Proc. Adv. Neural Inf. Process. Syst., A. Beygelzimer, Y. Dauphin, P. Liang, and J. W. Vaughan, Eds., 2021. [Online]. Available: https://openreview.net/forum?id=9c-IsSptbmA
[8] G. Varoquaux, “Cross-validation failure: Small sample sizes lead to large error bars,” NeuroImage, vol. 180, pp. 68–77, 2018. [Online]. Available: https://www.sciencedirect.com/science/article/pii/S1053811917305311
[9] M. Olson, A. Wyner, and R. Berk, “Modern neural networks generalize on small data sets,” in Proc. Adv. Neural Inf. Process. Syst., S. Garnett, Ed., 2018, vol. 31. [Online]. Available: https://proceedings.neurips.cc/paper/2018/file/fface8385abbf94b4593a0ed53a0c70f-Paper.pdf
[10] S. Arora, S. S. Du, Z. Li, R. Salakhutdinov, R. Wang, and D. Yu, “Harnessing the power of infinitely wide deep nets on small-data tasks,” in Proc. Int. Conf. Learn. Representations, 2020. [Online]. Available: https://openreview.net/forum?id=rkl8sJBYvH
[11] D. J. MacKay and D. J. Mac Kay, Information Theory, Inference and Learning Algorithms. Cambridge, U.K.: Cambridge Univ. Press, 2003.
[12] B. McMahan, E. Moore, D. Ramage, S. Hampson, and B. A. y Arcas, “Communication-efficient learning of deep networks from decentralized data,” in Proc. 20th Int. Conf. Artif. Intell. Statist., 2017, vol. 54, pp. 1273–1282. [Online]. Available: http://proceedings.mlr.press/v54/mcmahan17a.html
[13] C. Guo, G. Pleiss, Y. Sun, and K. Q. Weinberger, “On calibration of modern neural networks,” in Proc. 34th Int. Conf. Mach. Learn.-Volume 70, 2017, pp. 1321–1330.
[14] W. Maddox, T. Garipov, P. Izmailov, D. Vetrov, and A. G. Wilson, “Fast uncertainty estimates and bayesian model averaging of DNNs,” in Proc. Uncertainty Deep Learn. Workshop at UAI, 2018, vol. 8.
[15] W. J. Maddox, T. Garipov, P. Izmailov, D. Vetrov, and A. G. Wilson, A Simple Baseline for Bayesian Uncertainty in Deep Learning. Red Hook, NY, USA: Curran Associates Inc., 2019.
[16] E. Angelino, M. J. Johnson, and R. P. Adams, “Patterns of scalable Bayesian inference,” Foundations Trends Mach. Learn., vol. 9, no. 2/3, pp. 119–247, 2016, doi: 10.1561/2200000052.
[17] W. Neiswanger, C. Wang, and E. Xing, “Embarrassingly parallel variational inference in nonconjugate models,” 2015, arXiv:1510.04163.
[18] T. Broderick, N. Boyd, A. Wibisono, A. C. Wilson, and M. I. Jordan, “Streaming variational bayes,” in Proc. Adv. Neural Inf. Process. Syst., 2013, pp. 1727–1735. [Online]. Available: http://papers.nips.cc/paper/4980-streaming-variational-bayes.pdf
[19] L. Corinzia and J. M. Buhmann, “Variational federated multi-task learning,” 2019, arXiv:1906.06268.
[20] S. Ahn, B. Shahbaba, and M. Welling, “Distributed stochastic gradient MCMC,” in Proc. 31st Int. Conf. Mach. Learn., 2014, vol. 32, pp. 1044–1052. [Online]. Available: http://proceedings.mlr.press/v32/ahn14.html
[21] D. Mesquita, P. Blomstedt, and S. Kaski, “Embarrassingly parallel MCMC using deep invertible transformations,” in Proc. Mach. Learn. Res., 2020, vol. 115, pp. 1244–1252. [Online]. Available: http://proceedings.mlr.press/v115/mesquita20a.html
[22] Z. Wei and E. M. Conlon, “Parallel Markov chain Monte Carlo for Bayesian hierarchical models with big data, in two stages,” J. Appl. Statist., vol. 46, no. 11, pp. 1917–1936, 2019.
[23] M. Welling and Y. W. Teh, “Bayesian learning via stochastic gradient Langevin dynamics,” in Proc. 28th Int. Conf. Mach. Learn., 2011, pp. 681–688.
[24] D. Liu and O. Simeone, “Wireless federated langevin Monte Carlo: Repurposing channel noise for Bayesian sampling and privacy,” 2021, arXiv:2108.07644.
[25] Q. Liu and D. Wang, “Stein variational gradient descent: A general purpose Bayesian inference algorithm,” in Proc. Adv. Neural Inf. Process. Syst., 2016, pp. 2378–2386.
[26] T. Li, A. K. Sahu, A. Talwalkar, and V. Smith, “Federated learning: Challenges, methods, and future directions,” IEEE Signal Process. Mag., vol. 37, no. 3, pp. 50–60, May 2020.
[27] T.-H. Chang, M. Hong, H.-T. Wai, X. Zhang, and S. Lu, “Distributed learning in the nonconvex world: From batch data to streaming and beyond,” IEEE Signal Process. Mag., vol. 37, no. 3, pp. 26–38, May 2020.
[28] P. M. Djurić and Y. Wang, “Distributed Bayesian learning in multiagent systems: Improving our understanding of its capabilities and limitations,” IEEE Signal Process. Mag., vol. 29, no. 2, pp. 65–76, Mar. 2012.
[29] P. Djuric and C. Richard, Cooperative and Graph Signal Processing: Principles and Applications. Cambridge, MA, USA: Academic Press, 2018.
[30] N. Shlezinger, M. Chen, Y. C. Eldar, H. V. Poor, and S. Cui, “Federated learning with quantization constraints,” in Proc. IEEE Int. Conf. Acoust., Speech Signal Process., 2020, pp. 8851–8855.
[31] M. M. Amiri and D. Gündüz, “Over-the-air machine learning at the wireless edge,” in Proc. IEEE 20th Int. Workshop Signal Process. Adv. Wireless Commun., 2019, pp. 1–5.
[32] Y. Zhao, M. Li, L. Lai, N. Suda, D. Civin, and V. Chandra, “Federated learning with non-IID data,” 2018, arXiv:1806.00582.
[33] A. Fallah, A. Mokhtari, and A. Ozdaglar, “Personalized federated learning with theoretical guarantees: A model-agnostic meta-learning approach,” in Proc. Adv. Neural Inf. Process. Syst., 2020, pp. 3557–3568. [Online]. Available: https://proceedings.neurips.cc/paper/2020/file/24389bfe4fe2eba8bf9aa9203a44cdad-Paper.pdf
[34] H. Wang, Z. Kaplan, D. Niu, and B. Li, “Optimizing federated learning on Non-IID data with reinforcement learning,” in Proc. IEEE Conf. Comput. Commun., 2020, pp. 1698–1707.
[35] H. T. Nguyen, V. Sehwag, S. Hosseinalipour, C. G. Brinton, M. Chiang, and H. V. Poor, “Fast-convergent federated learning,” IEEE J. Sel. Areas Commun., vol. 39, no. 1, pp. 201–218, 2021, doi: 10.1109/JSAC.2020.3036952.
[36] W. Xia, T. Q. S. Quek, K. Guo, W. Wen, H. H. Yang, and H. Zhu, “Multi-armed bandit-based client scheduling for federated learning,” IEEE Trans. Wireless Commun., vol. 19, no. 11, pp. 7108–7123, Nov. 2020.
[37] T. Nishio and R. Yonetani, “Client selection for federated learning with heterogeneous resources in mobile edge,” in Proc. IEEE Int. Conf. Commun., 2019, pp. 1–7.
[38] J. Zhuo, C. Liu, J. Shi, J. Zhu, N. Chen, and B. Zhang, “Message passing stein variational gradient descent,” in Proc. Int. Conf. Mach. Learn., 2018, pp. 6018–6027.
[39] J. Yoon, T. Kim, O. Dia, S. Kim, Y. Bengio, and S. Ahn, “Bayesian model-agnostic meta-learning,” in Proc. Adv. Neural Inf. Process. Syst., 2018, pp. 7332–7342.
[40] T. Li, A. K. Sahu, M. Zaheer, M. Sanjabi, A. Talwalkar, and V. Smith, “Federated optimization in heterogeneous networks,” in Proc. Mach. Learn. Syst., 2020, vol. 2, pp. 429–450.
[41] X. Zhang, M. Hong, S. Dhople, W. Yin, and Y. Liu, “FedPD: A federated learning framework with optimal rates and adaptivity to Non-IID data,” 2020, arXiv:2005.11418.
[42] R. Pathak and M. J. Wainwright, “FedSplit: An algorithmic framework for fast federated optimization,” Adv. Neural Inf. Process. Syst., vol. 33, pp. 7057–7066, 2020.
[43] D. J. C. MacKay, Information Theory, Inference & Learning Algorithms. USA: Cambridge Univ. Press, 2002.
[44] R. M. Neal, Bayesian Learning for Neural Networks, vol. 118. Berlin, Germany: Springer, 2012.
[45] J. Mitros and B. Mac Namee, “On the validity of Bayesian neural networks for uncertainty estimation,” 2019, arXiv:1912.01530.
[46] T. Zhang, “Information-theoretic upper and lower bounds for statistical estimation,” IEEE Trans. Inf. Theory, vol. 52, no. 4, pp. 1307–1321, Apr. 2006.
[47] J. Knoblauch, J. Jewson, and T. Damoulas, “Generalized variational inference,” 2019, arXiv:1904.02063.
[48] P. Alquier, J. Ridgway, and N. Chopin, “On the properties of variational approximations of Gibbs posteriors,” J. Mach. Learn. Res., vol. 17, no. 236, pp. 1–41, 2016. [Online]. Available: http://jmlr.org/papers/v17/15-290.html
[49] M. I. Jordan, J. D. Lee, and Y. Yang, “Communication-efficient distributed statistical inference,” J. Amer. Stat. Assoc., vol. 114, no. 526, pp. 668–681, 2019. [Online]. Available: https://doi.org/10.1080/01621459.2018.1429274
[50] Q. Liu and A. Ihler, “Distributed estimation, information loss and exponential families,” in Proc. 27th Int. Conf. Neural Inf. Process. Syst., vol. 1, Cambridge, MA, USA: MIT Press, 2014, pp. 1098–1106.
[51] S. L. Scott, A. W. Blocker, F. V. Bonassi, H. A. Chipman, E. I. George, and R. E. McCulloch, “Bayes and Big Data: The consensus Monte Carlo algorithm,” Int. J. Manage. Sci. Eng. Manage., vol. 11, pp. 78–88, 2016. [Online]. Available: http://www.tandfonline.com/doi/full/10.1080/17509653.2016.1142191
[52] X. Wang and D. B. Dunson, “Parallel MCMC via Weierstrass sampler,” 2013, arXiv:1312.4605.
[53] L. J. Rendell, A. M. Johansen, A. Lee, and N. Whiteley, “Global consensus Monte Carlo,” J. Comput. Graphical Statist., vol. 30, no. 2, pp. 249–259, 2020.
[54] M. Yurochkin, M. Agarwal, S. Ghosh, K. Greenewald, N. Hoang, and Y. Khazaeni, “Bayesian nonparametric federated learning of neural networks,” in Proc. 36th Int. Conf. Mach. Learn., Long Beach, California, USA, 2019, vol. 97, pp. 7252–7261.
[55] S. Claici, M. Yurochkin, S. Ghosh, and J. Solomon, “Model fusion with Kullback-Leibler divergence,” in Proc. Int. Conf. Mach. Learn., 2020, pp. 2038–2047.
[56] T. D. Bui, C. V. Nguyen, S. Swaroop, and R. E. Turner, “Partitioned variational inference: A unified framework encompassing federated and continual learning,” 2018, arXiv:1811.11206.
[57] S.-I. Amari, “Natural gradient works efficiently in learning,” Neural Comput., vol. 10, no. 2, pp. 251–276, 1998.
[58] C. M. Bishop, Pattern Recognition and Machine Learning (Information Science and Statistics). Berlin, Heidelberg: Springer-Verlag, 2006.
[59] Q. Liu, “Stein variational gradient descent as gradient flow,” in Proc. Adv. Neural Inf. Process. Syst., 2017, pp. 3115–3123.
[60] R. Kassab and O. Simeone, “Federated generalized Bayesian learning via distributed Stein variational gradient descent,” 2020, arXiv:2009.06419.
[61] Q. Liu, J. Lee, and M. Jordan, “A kernelized Stein discrepancy for goodness-of-fit tests,” in Proc. Int. Conf. Mach. Learn., 2016, pp. 276–284.
[62] A. Korba, A. Salim, M. Arbel, G. Luise, and A. Gretton, “A non-asymptotic analysis for Stein variational gradient descent,” in Proc. Adv. Neural Inf. Process. Syst., H. Larochelle, M. Ranzato, R. Hadsell, M. F. Balcan, and H. Lin, Eds., 2020, vol. 33, pp. 4672–4682.
[63] G. Hinton, O. Vinyals, and J. Dean, “Distilling the knowledge in a neural network,” 2015, arXiv:1503.02531.
[64] H.-Y. Chen and W.-L. Chao, “FedDistill: Making Bayesian model ensemble applicable to federated learning,” in Proc. Int. Conf. Learn. Representations, 2021. [Online]. Available: https://openreview.net/forum?id=dgtpE6gKjHn
[65] S. J. Gershman, M. D. Hoffman, and D. M. Blei, “Nonparametric variational inference,” in Proc. 29th Int. Conf. Mach. Learn., Madison, WI, USA: Omnipress, 2012, pp. 235–242.
[66] J. M. Hernández-Lobato and R. P. Adams, “Probabilistic backpropagation for scalable learning of Bayesian neural networks,” in Proc. 32nd Int. Conf. Mach. Learn., vol. 37, 2015, pp. 1861–1869.
[67] M. Ye, T. Ren, and Q. Liu, “Stein self-repulsive dynamics: Benefits from past samples,” Adv. Neural Inf. Process. Syst., vol. 33, pp. 241–252, 2020.
[68] D.-X. Zhou, “Derivative reproducing properties for kernel methods in learning theory,” J. Comput. Appl. Math., vol. 220, no. 1-2, pp. 456–463, 2008.
[69] T. Pinder, C. Nemeth, and D. Leslie, “Stein variational Gaussian processes,” 2020, arXiv:2009.12141.
[70] M. S. Pinsker, Information and Information Stability of Random Variables and Processes. San Francisco, CA, USA: Holden-Day, 1964.

Rahif Kassab received two Engineering degrees from Telecom Paris, Paris, France, and Lebanese University, Beirut, Lebanon, in 2017, the M.Sc. degree in advanced communication networks jointly from École Polytechnique, France, and Telecom Paris, and the Ph.D. degree in computer science from King’s College London, London, U.K., in 2021. His research interests include communication theory, optimization, and machine learning. His industrial experience includes a six-month internship with Nokia Bell Labs and a summer internship with Huawei’s Mathematical and Algorithmic Sciences Lab, Paris, France. He received the 2018 IEEE Globecom Student Travel Grant, the Ile-de-France Masters Scholarship, from 2015 to 2017, and a Ph.D. Fellowship, from 2018 to 2020, awarded by King’s College London and funded by the European Research Council. He was a reviewer for many journals, including the IEEE JOURNAL ON SELECTED AREAS IN COMMUNICATIONS, IEEE JOURNAL OF SELECTED TOPICS IN SIGNAL PROCESSING, and IEEE TRANSACTIONS ON WIRELESS COMMUNICATIONS.

Osvaldo Simeone (Fellow, IEEE) received the M.Sc. (Hons.) and Ph.D. degrees in information engineering from the Politecnico di Milano, Milan, Italy, in 2001 and 2005, respectively. From 2006 to 2017, he was a Faculty Member with the Electrical and Computer Engineering Department, New Jersey Institute of Technology, Newark, NJ, USA, where he was affiliated with the Center for Wireless Information Processing. He is currently a Professor of information engineering with the Centre for Telecommunications Research, Department of Informatics, King’s College London, London, U.K. He has coauthored two monographs, an edited book published by Cambridge University Press, and more than one hundred research journal papers. His research interests include wireless communications, information theory, optimization, and machine learning. He is a Fellow of IET. He was the co-recipient of the 2018 IEEE Signal Processing Best Paper Award, 2017 JCN Best Paper Award, 2015 IEEE Communication Society Best Tutorial Paper Award, and Best Paper Awards of the IEEE SPAWC 2007 and the IEEE WRECOM 2007. He was awarded a Consolidator Grant by the European Research Council (ERC), in 2016. His research has been supported by the U.S. NSF, ERC, Vienna Science and Technology Fund, and also by a number of industrial collaborations. He is currently on the Editorial Board of IEEE Signal Processing Magazine, and he is a Distinguished Lecturer of the IEEE Information Theory Society.