A brief primer on Variational Inference


Bayesian inference using Markov chain Monte Carlo methods can be notoriously slow. In this blog
post, we reframe Bayesian inference as an optimization problem using variational inference,
markedly speeding up computation. We derive the variational objective function, implement
coordinate ascent mean-field variational inference for a simple linear regression example in R, and
compare our results to results obtained via variational and exact inference using Stan. Sounds like
word salad? Then let’s start unpacking!

Preliminaries
Bayes’ rule states that

$$
\underbrace{p(z \mid x)}_{\text{Posterior}} = \underbrace{p(z)}_{\text{Prior}} \times \frac{\overbrace{p(x \mid z)}^{\text{Likelihood}}}{\underbrace{\int p(x \mid z) \, p(z) \, dz}_{\text{Marginal Likelihood}}} \, ,
$$

where z denotes the latent parameters we want to infer and x denotes the data.1 Bayes’ rule is, in general,
difficult to apply because it requires dealing with a potentially high-dimensional integral — the
marginal likelihood. Optimization, which involves taking derivatives instead of integrating, is much
easier and generally faster, and so our goal will be to reframe this integration problem as an
optimization problem.

Variational objective
We want to get at the posterior distribution, but instead of sampling we simply try to find a density
q*(z) from a family of densities Q that best approximates the posterior distribution:

$$
q^\star(z) = \underset{q(z) \in \mathcal{Q}}{\text{argmin}} \; \text{KL}\left(q(z) \, || \, p(z \mid x)\right) \, ,
$$

where KL(· || ·) denotes the Kullback-Leibler divergence:

$$
\text{KL}\left(q(z) \, || \, p(z \mid x)\right) = \int q(z) \log \frac{q(z)}{p(z \mid x)} \, dz \, .
$$

We cannot compute this Kullback-Leibler divergence because it still depends on the nasty integral
p(x) = ∫ p(x ∣ z) p(z) dz . To see this dependency, observe that:

$$
\begin{aligned}
\text{KL}\left(q(z) \, || \, p(z \mid x)\right) &= \mathbb{E}_{q(z)}\left[\log \frac{q(z)}{p(z \mid x)}\right] \\[.5em]
&= \mathbb{E}_{q(z)}\left[\log q(z)\right] - \mathbb{E}_{q(z)}\left[\log p(z \mid x)\right] \\[.5em]
&= \mathbb{E}_{q(z)}\left[\log q(z)\right] - \mathbb{E}_{q(z)}\left[\log \frac{p(z, x)}{p(x)}\right] \\[.5em]
&= \mathbb{E}_{q(z)}\left[\log q(z)\right] - \mathbb{E}_{q(z)}\left[\log p(z, x)\right] + \mathbb{E}_{q(z)}\left[\log p(x)\right] \\[.5em]
&= \mathbb{E}_{q(z)}\left[\log q(z)\right] - \mathbb{E}_{q(z)}\left[\log p(z, x)\right] + \int q(z) \log p(x) \, dz \\[.5em]
&= \mathbb{E}_{q(z)}\left[\log q(z)\right] - \mathbb{E}_{q(z)}\left[\log p(z, x)\right] + \underbrace{\log p(x)}_{\text{Nemesis}} \, ,
\end{aligned}
$$
where we have expanded the expectation to more clearly behold our nemesis. In doing so, we have
seen that log p(x) is actually a constant with respect to q(z); this means that we can ignore it in
our optimization problem. Moreover, minimizing a quantity means maximizing its negative, and so
we maximize the following quantity:

$$
\begin{aligned}
\text{ELBO}(q) &= -\left(\text{KL}\left(q(z) \, || \, p(z \mid x)\right) - \log p(x)\right) \\[.5em]
&= -\Big(\mathbb{E}_{q(z)}\left[\log q(z)\right] - \mathbb{E}_{q(z)}\left[\log p(z, x)\right] + \underbrace{\log p(x) - \log p(x)}_{\text{Nemesis perishes}}\Big) \\[.5em]
&= \mathbb{E}_{q(z)}\left[\log p(z, x)\right] - \mathbb{E}_{q(z)}\left[\log q(z)\right] \, .
\end{aligned}
$$

We can expand the joint probability to get more insight into this equation:

$$
\begin{aligned}
\text{ELBO}(q) &= \underbrace{\mathbb{E}_{q(z)}\left[\log p(x \mid z)\right] + \mathbb{E}_{q(z)}\left[\log p(z)\right]}_{\mathbb{E}_{q(z)}\left[\log p(z, x)\right]} - \mathbb{E}_{q(z)}\left[\log q(z)\right] \\[.5em]
&= \mathbb{E}_{q(z)}\left[\log p(x \mid z)\right] + \mathbb{E}_{q(z)}\left[\log \frac{p(z)}{q(z)}\right] \\[.5em]
&= \mathbb{E}_{q(z)}\left[\log p(x \mid z)\right] - \mathbb{E}_{q(z)}\left[\log \frac{q(z)}{p(z)}\right] \\[.5em]
&= \mathbb{E}_{q(z)}\left[\log p(x \mid z)\right] - \text{KL}\left(q(z) \, || \, p(z)\right) \, .
\end{aligned}
$$

This is cool. It says that maximizing the ELBO finds an approximate distribution q(z) for latent
quantities z that allows the data to be predicted well, i.e., leads to a high expected log likelihood,
but that a penalty is incurred if q(z) strays far away from the prior p(z). This mirrors the usual
balance in Bayesian inference between likelihood and prior (Blei, Kucukelbir, & McAuliffe, 2017).

ELBO stands for evidence lower bound. The marginal likelihood is sometimes called evidence, and we
see that ELBO is indeed a lower bound for the evidence:

$$
\begin{aligned}
\text{ELBO}(q) &= -\left(\text{KL}\left(q(z) \, || \, p(z \mid x)\right) - \log p(x)\right) \\[.5em]
\log p(x) &= \text{ELBO}(q) + \text{KL}\left(q(z) \, || \, p(z \mid x)\right) \\[.5em]
\log p(x) &\geq \text{ELBO}(q) \, ,
\end{aligned}
$$

since the Kullback-Leibler divergence is non-negative. Heuristically, one might then use the ELBO
as a way to select between models. For more on predictive model selection, see this and this blog
post.
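To make the identity log p(x) = ELBO(q) + KL(q(z) || p(z ∣ x)) concrete, here is a small numerical check in R. It uses a toy conjugate model — x ∼ N(z, 1) with prior z ∼ N(0, 1), so that the posterior and the marginal likelihood are available in closed form — together with an arbitrary Gaussian q(z); the model and the particular values are illustrative choices, not part of the derivation above.

# Numerical check of log p(x) = ELBO(q) + KL(q || posterior) for a toy
# conjugate model: x ~ N(z, 1), z ~ N(0, 1), so everything is tractable.
# The variational density q(z) = N(m, s^2) is deliberately not the exact posterior.
x <- 1.5
m <- 0.3
s <- 0.8

E_q <- function(fn) {  # expectation of fn(z) under q(z), via quadrature
  integrate(function(z) dnorm(z, m, s) * fn(z), -Inf, Inf)$value
}

elbo <- E_q(function(z) dnorm(x, z, 1, log = TRUE)) +  # E_q[log p(x | z)]
        E_q(function(z) dnorm(z, 0, 1, log = TRUE)) -  # E_q[log p(z)]
        E_q(function(z) dnorm(z, m, s, log = TRUE))    # E_q[log q(z)]

kl <- E_q(function(z) dnorm(z, m, s, log = TRUE) -
                      dnorm(z, x/2, sqrt(1/2), log = TRUE))  # KL(q || p(z | x))

log_evidence <- dnorm(x, 0, sqrt(2), log = TRUE)  # log p(x) in closed form
c(elbo + kl, log_evidence)  # the two numbers agree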

Why variational?
Our optimization problem is about finding q*(z) that best approximates the posterior distribution.
This is in contrast to more familiar optimization problems such as maximum likelihood estimation,
where one wants to find, for example, the single best value that maximizes the log likelihood. For
such a problem, one can use standard calculus (see for example this blog post). In our setting, we do
not want to find a single best value but rather a single best function. To do this, we can use variational
calculus, from which variational inference derives its name (Bishop, 2006, p. 462).

A function takes an input value and returns an output value. A functional, in contrast, takes a whole
function and returns an output value. The entropy of a probability distribution is a widely used
functional:

$$
\mathbb{H}[p] = -\int p(x) \log p(x) \, dx \, ,
$$

which takes as input the probability distribution p(x) and returns a single value, its entropy. In
variational inference, we want to find the function that maximizes the ELBO, which is a functional.

In order to make this optimization problem more manageable, we need to constrain the functions in
some way. One could, for example, assume that q(z) is a Gaussian distribution with parameter
vector ω. The ELBO then becomes a function of ω, and we employ standard optimization methods to
solve this problem. Instead of restricting the parametric form of the variational distribution q(z), in
the next section we use an independence assumption to manage the inference problem.
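As a minimal sketch of this fixed-form approach — reusing the toy model from the check above, with illustrative values rather than anything from the regression example that follows — one can parameterize q(z) = N(m, s²) and hand the ELBO to a standard optimizer:

# Fixed-form variational inference for the toy model x ~ N(z, 1), z ~ N(0, 1):
# restrict q(z) to N(m, s^2) and maximize the ELBO over omega = (m, log s).
x <- 1.5

elbo_gaussian <- function(omega) {
  m <- omega[1]
  s <- exp(omega[2])  # optimize on the log scale to keep s positive
  E_q <- function(fn) integrate(function(z) dnorm(z, m, s) * fn(z), -Inf, Inf)$value
  E_q(function(z) dnorm(x, z, 1, log = TRUE)) +    # E_q[log p(x | z)]
    E_q(function(z) dnorm(z, 0, 1, log = TRUE)) -  # E_q[log p(z)]
    E_q(function(z) dnorm(z, m, s, log = TRUE))    # E_q[log q(z)]
}

res <- optim(c(0, 0), elbo_gaussian, control = list(fnscale = -1))  # maximize
res$par[1]       # close to the exact posterior mean x/2 = 0.75
exp(res$par[2])  # close to the exact posterior sd sqrt(1/2)

Because the Gaussian family here contains the exact posterior, the optimum recovers it; in general the restricted family only approximates it.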

Mean-field variational family


A frequently used approximation is to assume that the latent variables z_j for j = 1, … , m are
mutually independent, each governed by their own variational density:

$$
q(z) = \prod_{j=1}^m q_j(z_j) \, .
$$

Note that this mean-field variational family cannot model correlations in the posterior distribution;
by construction, the latent parameters are mutually independent. Observe that we do not make any
parametric assumption about the individual q_j(z_j). Instead, their parametric form is derived for
every particular inference problem.

We start from our definition of the ELBO and apply the mean-field assumption:

$$
\begin{aligned}
\text{ELBO}(q) &= \mathbb{E}_{q(z)}\left[\log p(z, x)\right] - \mathbb{E}_{q(z)}\left[\log q(z)\right] \\[.5em]
&= \int \prod_{i=1}^m q_i(z_i) \log p(z, x) \, dz - \int \prod_{i=1}^m q_i(z_i) \log \prod_{i=1}^m q_i(z_i) \, dz \, .
\end{aligned}
$$

In the following, we optimize the ELBO with respect to a single variational density q_j(z_j) and
assume that all others are fixed:

$$
\begin{aligned}
\text{ELBO}(q_j) &= \int \prod_{i=1}^m q_i(z_i) \log p(z, x) \, dz - \int \prod_{i=1}^m q_i(z_i) \log \prod_{i=1}^m q_i(z_i) \, dz \\[.5em]
&= \int \prod_{i=1}^m q_i(z_i) \log p(z, x) \, dz - \int q_j(z_j) \log q_j(z_j) \, dz_j - \underbrace{\int \prod_{i \neq j} q_i(z_i) \log \prod_{i \neq j} q_i(z_i) \, dz_{-j}}_{\text{Constant with respect to } q_j(z_j)} \\[.5em]
&\propto \int \prod_{i=1}^m q_i(z_i) \log p(z, x) \, dz - \int q_j(z_j) \log q_j(z_j) \, dz_j \\[.5em]
&= \int q_j(z_j) \left(\int \prod_{i \neq j} q_i(z_i) \log p(z, x) \, dz_{-j}\right) dz_j - \int q_j(z_j) \log q_j(z_j) \, dz_j \\[.5em]
&= \int q_j(z_j) \, \mathbb{E}_{q(z_{-j})}\left[\log p(z, x)\right] dz_j - \int q_j(z_j) \log q_j(z_j) \, dz_j \, .
\end{aligned}
$$

One could use variational calculus to derive the optimal variational density q*_j(z_j); instead, we
follow Bishop (2006, p. 465) and define the distribution

$$
\log \tilde{p}(x, z_j) = \mathbb{E}_{q(z_{-j})}\left[\log p(z, x)\right] - Z \, ,
$$

where we need to make sure that it integrates to one by subtracting the (log) normalizing constant
Z. With this in mind, observe that:

$$
\begin{aligned}
\text{ELBO}(q_j) &\propto \int q_j(z_j) \log \tilde{p}(x, z_j) \, dz_j - \int q_j(z_j) \log q_j(z_j) \, dz_j \\[.5em]
&= \int q_j(z_j) \log \frac{\tilde{p}(x, z_j)}{q_j(z_j)} \, dz_j \\[.5em]
&= -\int q_j(z_j) \log \frac{q_j(z_j)}{\tilde{p}(x, z_j)} \, dz_j \\[.5em]
&= -\text{KL}\left(q_j(z_j) \, || \, \tilde{p}(x, z_j)\right) \, .
\end{aligned}
$$
Thus, maximizing the ELBO with respect to q_j(z_j) means minimizing the Kullback-Leibler divergence
between q_j(z_j) and p̃(x, z_j); this divergence is zero when the two distributions are equal. Therefore, under the
mean-field assumption, the optimal variational density q*_j(z_j) is given by:

$$
\begin{aligned}
q_j^\star(z_j) &= \exp\left(\mathbb{E}_{q_{-j}(z_{-j})}\left[\log p(x, z)\right] - Z\right) \\[.5em]
&= \frac{\exp\left(\mathbb{E}_{q_{-j}(z_{-j})}\left[\log p(x, z)\right]\right)}{\int \exp\left(\mathbb{E}_{q_{-j}(z_{-j})}\left[\log p(x, z)\right]\right) dz_j} \, ,
\end{aligned}
$$
see also Bishop (2006, p. 466). This is not an explicit solution, however, since each optimal
variational density depends on all others. This calls for an iterative procedure in which we first
initialize all factors q_j(z_j) and then cycle through them, updating each one conditional on the current
values of the others. Such a procedure is known as Coordinate Ascent Variational Inference (CAVI). Further,
note that

$$
p(z_j \mid z_{-j}, x) = \frac{p(z_j, z_{-j}, x)}{p(z_{-j}, x)} \propto p(z_j, z_{-j}, x) \, ,
$$

which allows us to write the updates in terms of the conditional posterior distribution of z_j given all
other factors z_{-j}. This looks a lot like Gibbs sampling, which we discussed in detail in a previous
blog post; a tiny self-contained illustration of the CAVI updates follows below. In the next section, we
implement CAVI for a simple linear regression problem.
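Here is the promised toy illustration of CAVI, following the classic bivariate Gaussian example in Bishop (2006, Section 10.1.2): the target is a correlated two-dimensional Gaussian, the mean-field updates are available in closed form, and the particular numbers below are arbitrary illustrative choices rather than part of the regression example.

# Mean-field CAVI for a correlated bivariate Gaussian target (Bishop, Sec. 10.1.2):
# approximate N(mu, Lambda^{-1}) by a product q(z1) q(z2) of univariate Gaussians.
mu <- c(0, 0)
Lambda <- solve(matrix(c(1, 0.8, 0.8, 1), 2, 2))  # precision matrix of the target

m <- c(-1, 1)  # initial variational means
for (iter in 1:50) {
  m[1] <- mu[1] - (Lambda[1, 2] / Lambda[1, 1]) * (m[2] - mu[2])
  m[2] <- mu[2] - (Lambda[2, 1] / Lambda[2, 2]) * (m[1] - mu[1])
}

m                 # converges to the true mean (0, 0)
1 / diag(Lambda)  # variational variances: smaller than the true marginal variances (1, 1)

Note how the variational variances 1/Λ₁₁ and 1/Λ₂₂ come out smaller than the true marginal variances — a first glimpse of the well-known tendency of mean-field variational inference to underestimate posterior variance.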

Application: Linear regression


In a previous blog post, we traced the history of least squares and applied it to the most basic
problem: fitting a straight line to a number of points. Here, we study the same problem but swap the
estimation procedure: instead of least squares or maximum likelihood, we use variational
inference. Our linear regression setup is:

$$
\begin{aligned}
y &\sim \mathcal{N}(\beta x, \sigma^2) \\[.5em]
\beta &\sim \mathcal{N}(0, \sigma^2 \tau^2) \\[.5em]
p(\sigma^2) &\propto \frac{1}{\sigma^2} \, ,
\end{aligned}
$$

where we assume that the population mean of y is zero (i.e., β₀ = 0); we assign the error
variance σ² an improper Jeffreys’ prior and β a Gaussian prior with variance σ²τ². We scale the
prior of β by the error variance so that we can reason in terms of a standardized effect size β/σ, since with this
specification:

$$
\text{Var}\left[\frac{\beta}{\sigma}\right] = \frac{1}{\sigma^2}\text{Var}\left[\beta\right] = \frac{\sigma^2\tau^2}{\sigma^2} = \tau^2 \, .
$$
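A quick simulation makes this scaling argument tangible; the distribution used to draw σ² below is an arbitrary stand-in (the actual prior on σ² is improper), chosen only to show that the standardized effect size has variance τ² regardless of how σ² is distributed:

# Check that Var[beta / sigma] = tau^2 no matter how sigma^2 is distributed
tau <- 0.50
sigma2 <- rgamma(1e5, shape = 2, rate = 2)  # arbitrary stand-in draws for sigma^2
beta <- rnorm(1e5, 0, sqrt(sigma2) * tau)   # beta | sigma^2 ~ N(0, sigma^2 * tau^2)
var(beta / sqrt(sigma2))                    # close to tau^2 = 0.25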

As a heads up, we have to do a surprising amount of calculations to implement variational inference
even for this simple problem. In the next section, we start our journey by deriving the variational
density for σ².

Variational density for σ²

Our optimal variational density q*(σ²) is given by:

$$
q^\star(\sigma^2) \propto \exp\left(\mathbb{E}_{q(\beta)}\left[\log p(\sigma^2 \mid y, \beta)\right]\right) \, .
$$

To get started, we need to derive the conditional posterior distribution p(σ² ∣ y, β). We write:
$$
\begin{aligned}
p(\sigma^2 \mid y, \beta) &\propto p(y \mid \sigma^2, \beta) \, p(\beta) \, p(\sigma^2) \\[.5em]
&= \prod_{i=1}^n (2\pi)^{-\frac{1}{2}} (\sigma^2)^{-\frac{1}{2}} \exp\left(-\frac{1}{2\sigma^2}(y_i - \beta x_i)^2\right) \underbrace{(2\pi)^{-\frac{1}{2}} (\sigma^2\tau^2)^{-\frac{1}{2}} \exp\left(-\frac{1}{2\sigma^2\tau^2}\beta^2\right)}_{p(\beta)} \, \underbrace{(\sigma^2)^{-1}}_{p(\sigma^2)} \\[.5em]
&= (2\pi)^{-\frac{n+1}{2}} (\sigma^2)^{-\frac{n+1}{2}-1} (\tau^2)^{-\frac{1}{2}} \exp\left(-\frac{1}{2\sigma^2}\left(\sum_{i=1}^n (y_i - \beta x_i)^2 + \frac{\beta^2}{\tau^2}\right)\right) \\[.5em]
&\propto (\sigma^2)^{-\frac{n+1}{2}-1} \exp\Bigg(-\frac{1}{2\sigma^2}\underbrace{\left(\sum_{i=1}^n (y_i - \beta x_i)^2 + \frac{\beta^2}{\tau^2}\right)}_{A}\Bigg) \, ,
\end{aligned}
$$
which is proportional to an inverse Gamma distribution. Moving on, we exploit the linearity of the
expectation and write:

$$
\begin{aligned}
q^\star(\sigma^2) &\propto \exp\left(\mathbb{E}_{q(\beta)}\left[\log p(\sigma^2 \mid y, \beta)\right]\right) \\[.5em]
&= \exp\left(\mathbb{E}_{q(\beta)}\left[\log (\sigma^2)^{-\frac{n+1}{2}-1} - \frac{1}{2\sigma^2}A\right]\right) \\[.5em]
&= \exp\left(\mathbb{E}_{q(\beta)}\left[\log (\sigma^2)^{-\frac{n+1}{2}-1}\right] - \mathbb{E}_{q(\beta)}\left[\frac{1}{2\sigma^2}A\right]\right) \\[.5em]
&= \exp\left(\log (\sigma^2)^{-\frac{n+1}{2}-1} - \frac{1}{\sigma^2}\mathbb{E}_{q(\beta)}\left[\frac{1}{2}A\right]\right) \\[.5em]
&= (\sigma^2)^{-\frac{n+1}{2}-1} \exp\left(-\frac{1}{\sigma^2}\mathbb{E}_{q(\beta)}\left[\frac{1}{2}A\right]\right) \, .
\end{aligned}
$$

This, too, looks like an inverse Gamma distribution! Plugging in the normalizing constant, we arrive
at:

$$
q^\star(\sigma^2) = \frac{\left(\mathbb{E}_{q(\beta)}\left[\frac{1}{2}A\right]\right)^{\frac{n+1}{2}}}{\Gamma\left(\frac{n+1}{2}\right)} \, (\sigma^2)^{-\frac{n+1}{2}-1} \exp\Bigg(-\frac{1}{\sigma^2}\underbrace{\mathbb{E}_{q(\beta)}\left[\frac{1}{2}A\right]}_{\nu}\Bigg) \, .
$$

Note that this quantity depends on β. In the next section, we derive the variational density for β.

Variational density for β


Our optimal variational density q*(β) is given by:

$$
q^\star(\beta) \propto \exp\left(\mathbb{E}_{q(\sigma^2)}\left[\log p(\beta \mid y, \sigma^2)\right]\right) \, ,
$$

and so we again have to derive the conditional posterior distribution p(β ∣ y, σ²). We write:
$$
\begin{aligned}
p(\beta \mid y, \sigma^2) &\propto p(y \mid \beta, \sigma^2) \, p(\beta) \, p(\sigma^2) \\[.5em]
&= (2\pi)^{-\frac{n+1}{2}} (\sigma^2)^{-\frac{n+1}{2}-1} (\tau^2)^{-\frac{1}{2}} \exp\left(-\frac{1}{2\sigma^2}\left(\sum_{i=1}^n (y_i - \beta x_i)^2 + \frac{\beta^2}{\tau^2}\right)\right) \\[.5em]
&= (2\pi)^{-\frac{n+1}{2}} (\sigma^2)^{-\frac{n+1}{2}-1} (\tau^2)^{-\frac{1}{2}} \exp\left(-\frac{1}{2\sigma^2}\left(\sum_{i=1}^n y_i^2 - 2\beta \sum_{i=1}^n y_i x_i + \beta^2 \sum_{i=1}^n x_i^2 + \frac{\beta^2}{\tau^2}\right)\right) \\[.5em]
&\propto \exp\left(-\frac{1}{2\sigma^2}\left(\beta^2\left(\sum_{i=1}^n x_i^2 + \frac{1}{\tau^2}\right) - 2\beta \sum_{i=1}^n y_i x_i\right)\right) \\[.5em]
&= \exp\left(-\frac{\sum_{i=1}^n x_i^2 + \frac{1}{\tau^2}}{2\sigma^2}\left(\beta^2 - 2\beta \, \frac{\sum_{i=1}^n y_i x_i}{\sum_{i=1}^n x_i^2 + \frac{1}{\tau^2}}\right)\right) \\[.5em]
&\propto \exp\left(-\frac{\sum_{i=1}^n x_i^2 + \frac{1}{\tau^2}}{2\sigma^2}\left(\beta - \frac{\sum_{i=1}^n y_i x_i}{\sum_{i=1}^n x_i^2 + \frac{1}{\tau^2}}\right)^2\right) \, ,
\end{aligned}
$$

where we have “completed the square” (see also this blog post) and realized that the conditional
posterior is Gaussian. We continue by taking expectations:

$$
\begin{aligned}
q^\star(\beta) &\propto \exp\left(\mathbb{E}_{q(\sigma^2)}\left[\log p(\beta \mid y, \sigma^2)\right]\right) \\[.5em]
&= \exp\left(\mathbb{E}_{q(\sigma^2)}\left[-\frac{\sum_{i=1}^n x_i^2 + \frac{1}{\tau^2}}{2\sigma^2}\left(\beta - \frac{\sum_{i=1}^n y_i x_i}{\sum_{i=1}^n x_i^2 + \frac{1}{\tau^2}}\right)^2\right]\right) \\[.5em]
&= \exp\left(-\frac{\sum_{i=1}^n x_i^2 + \frac{1}{\tau^2}}{2}\,\mathbb{E}_{q(\sigma^2)}\left[\frac{1}{\sigma^2}\right]\left(\beta - \frac{\sum_{i=1}^n y_i x_i}{\sum_{i=1}^n x_i^2 + \frac{1}{\tau^2}}\right)^2\right) \, ,
\end{aligned}
$$

which is again proportional to a Gaussian distribution! Plugging in the normalizing constant yields:
$$
q^\star(\beta) = \Bigg(2\pi\underbrace{\frac{\mathbb{E}_{q(\sigma^2)}\left[\frac{1}{\sigma^2}\right]^{-1}}{\sum_{i=1}^n x_i^2 + \frac{1}{\tau^2}}}_{\sigma_\beta^2}\Bigg)^{-\frac{1}{2}} \exp\Bigg(-\frac{\sum_{i=1}^n x_i^2 + \frac{1}{\tau^2}}{2}\,\mathbb{E}_{q(\sigma^2)}\left[\frac{1}{\sigma^2}\right]\Bigg(\beta - \underbrace{\frac{\sum_{i=1}^n y_i x_i}{\sum_{i=1}^n x_i^2 + \frac{1}{\tau^2}}}_{\mu_\beta}\Bigg)^2\Bigg) \, .
$$

Note that while the variance of this distribution, σ²_β, depends on q(σ²), its mean μ_β does not.

To recap, instead of assuming a parametric form for the variational densities, we have derived the
optimal densities under the mean-field assumption, that is, under the assumption that the
parameters are independent: q(β, σ²) = q(β) q(σ²). Assigning β a Gaussian prior and σ² a
Jeffreys’s prior, we have found that the variational density for σ² is an inverse Gamma distribution
and that the variational density for β is a Gaussian distribution. We noted that these variational
densities depend on each other. However, this is not the end of the manipulation of symbols; both
distributions still feature an expectation we need to remove. In the next section, we expand the
remaining expectations.

Removing expectations
Now that we know the parametric form of both variational densities, we can expand the terms that
involve an expectation. In particular, to remove the expectation in the variational density for σ², we
write:

$$
\begin{aligned}
\mathbb{E}_{q(\beta)}\left[A\right] &= \mathbb{E}_{q(\beta)}\left[\sum_{i=1}^n (y_i - \beta x_i)^2 + \frac{\beta^2}{\tau^2}\right] \\[.5em]
&= \sum_{i=1}^n y_i^2 - 2\sum_{i=1}^n y_i x_i \, \mathbb{E}_{q(\beta)}\left[\beta\right] + \sum_{i=1}^n x_i^2 \, \mathbb{E}_{q(\beta)}\left[\beta^2\right] + \frac{1}{\tau^2}\mathbb{E}_{q(\beta)}\left[\beta^2\right] \, .
\end{aligned}
$$

Noting that E_{q(β)}[β] = μ_β and using the fact that

$$
\mathbb{E}_{q(\beta)}\left[\beta^2\right] = \text{Var}_{q(\beta)}\left[\beta\right] + \mathbb{E}_{q(\beta)}\left[\beta\right]^2 = \sigma_\beta^2 + \mu_\beta^2 \, ,
$$

the expectation becomes:

$$
\mathbb{E}_{q(\beta)}\left[A\right] = \sum_{i=1}^n y_i^2 - 2\sum_{i=1}^n y_i x_i \, \mu_\beta + \left(\sigma_\beta^2 + \mu_\beta^2\right)\left(\sum_{i=1}^n x_i^2 + \frac{1}{\tau^2}\right) \, .
$$
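Closed-form expectations like this are easy to get wrong, so a small Monte Carlo cross-check can be reassuring. The data and variational parameter values below are made up purely for the check; only the formula itself comes from the derivation above.

# Cross-check the closed-form E_q(beta)[A] against a Monte Carlo estimate
set.seed(1)
x_check <- rnorm(20)
y_check <- 0.30*x_check + rnorm(20)
mu_beta <- 0.25; sd_beta <- 0.10; tau2 <- 0.25  # hypothetical variational parameters

# closed-form expression derived above
E_A_closed <- sum(y_check^2) - 2*sum(y_check*x_check)*mu_beta +
  (sd_beta^2 + mu_beta^2) * (sum(x_check^2) + 1/tau2)

# Monte Carlo version: draw beta from q(beta) and average A
beta_draws <- rnorm(1e5, mu_beta, sd_beta)
E_A_mc <- mean(sapply(beta_draws, function(b) sum((y_check - b*x_check)^2) + b^2/tau2))

c(E_A_closed, E_A_mc)  # should agree up to Monte Carlo error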

For the expectation which features in the variational distribution for β, things are slightly less
elaborate, although the result also looks unwieldy. We write:
$$
\begin{aligned}
\mathbb{E}_{q(\sigma^2)}\left[\frac{1}{\sigma^2}\right] &= \int \frac{1}{\sigma^2} \, \frac{\nu^{\frac{n+1}{2}}}{\Gamma\left(\frac{n+1}{2}\right)} (\sigma^2)^{-\frac{n+1}{2}-1} \exp\left(-\frac{1}{\sigma^2}\nu\right) d\sigma^2 \\[.5em]
&= \frac{\nu^{\frac{n+1}{2}}}{\Gamma\left(\frac{n+1}{2}\right)} \int (\sigma^2)^{-\left(\frac{n+1}{2}+1\right)-1} \exp\left(-\frac{1}{\sigma^2}\nu\right) d\sigma^2 \\[.5em]
&= \frac{\nu^{\frac{n+1}{2}}}{\Gamma\left(\frac{n+1}{2}\right)} \, \frac{\Gamma\left(\frac{n+1}{2}+1\right)}{\nu^{\frac{n+1}{2}+1}} \\[.5em]
&= \frac{n+1}{2}\left(\frac{1}{2}\mathbb{E}_{q(\beta)}\left[A\right]\right)^{-1} \\[.5em]
&= \frac{n+1}{2}\left(\frac{1}{2}\left(\sum_{i=1}^n y_i^2 - 2\sum_{i=1}^n y_i x_i \, \mu_\beta + \left(\sigma_\beta^2 + \mu_\beta^2\right)\left(\sum_{i=1}^n x_i^2 + \frac{1}{\tau^2}\right)\right)\right)^{-1} \, .
\end{aligned}
$$
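The same kind of sanity check works here, using MCMCpack's inverse Gamma functions (which the implementation below relies on anyway) with arbitrary illustrative values for n and ν:

library('MCMCpack')

# E[1/sigma^2] under the inverse Gamma variational density with shape (n + 1)/2 and rate nu
n <- 20
nu <- 8
mean(1 / rinvgamma(1e5, (n + 1)/2, nu))  # Monte Carlo estimate
((n + 1)/2) / nu                         # closed-form result derived above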

Monitoring convergence
The algorithm works by first specifying initial values for the parameters of the variational densities
and then iteratively updating them until the ELBO does not change anymore. This requires us to
compute the ELBO, which we still need to derive, on each update. We write:

$$
\begin{aligned}
\text{ELBO}(q) &= \mathbb{E}_{q(\beta,\sigma^2)}\left[\log p(y, \beta, \sigma^2)\right] - \mathbb{E}_{q(\beta,\sigma^2)}\left[\log q(\beta, \sigma^2)\right] \\[.5em]
&= \mathbb{E}_{q(\beta,\sigma^2)}\left[\log p(y \mid \beta, \sigma^2)\right] + \mathbb{E}_{q(\beta,\sigma^2)}\left[\log p(\beta, \sigma^2)\right] - \mathbb{E}_{q(\beta,\sigma^2)}\left[\log q(\beta, \sigma^2)\right] \\[.5em]
&= \mathbb{E}_{q(\beta,\sigma^2)}\left[\log p(y \mid \beta, \sigma^2)\right] + \underbrace{\mathbb{E}_{q(\beta,\sigma^2)}\left[\log \frac{p(\beta, \sigma^2)}{q(\beta, \sigma^2)}\right]}_{-\text{KL}\left(q(\beta,\sigma^2) \, || \, p(\beta,\sigma^2)\right)} \, .
\end{aligned}
$$

Let’s take a deep breath and tackle the second term first:
$$
\begin{aligned}
\mathbb{E}_{q(\beta,\sigma^2)}\left[\log \frac{p(\beta, \sigma^2)}{q(\beta, \sigma^2)}\right] &= \mathbb{E}_{q(\sigma^2)}\left[\mathbb{E}_{q(\beta)}\left[\log \frac{p(\beta \mid \sigma^2)}{q(\beta)}\right] + \log \frac{p(\sigma^2)}{q(\sigma^2)}\right] \\[.5em]
&= \mathbb{E}_{q(\sigma^2)}\left[\mathbb{E}_{q(\beta)}\left[\log \frac{(2\pi\sigma^2\tau^2)^{-\frac{1}{2}} \exp\left(-\frac{1}{2\sigma^2\tau^2}\beta^2\right)}{(2\pi\sigma_\beta^2)^{-\frac{1}{2}} \exp\left(-\frac{1}{2\sigma_\beta^2}(\beta - \mu_\beta)^2\right)}\right] + \log \frac{p(\sigma^2)}{q(\sigma^2)}\right] \\[.5em]
&= \mathbb{E}_{q(\sigma^2)}\left[\mathbb{E}_{q(\beta)}\left[\frac{1}{2}\log \frac{\sigma_\beta^2}{\sigma^2\tau^2} - \frac{\beta^2}{2\sigma^2\tau^2} + \frac{(\beta - \mu_\beta)^2}{2\sigma_\beta^2}\right] + \log \frac{p(\sigma^2)}{q(\sigma^2)}\right] \\[.5em]
&= \mathbb{E}_{q(\sigma^2)}\left[\frac{1}{2}\log \frac{\sigma_\beta^2}{\sigma^2\tau^2} - \frac{\sigma_\beta^2 + \mu_\beta^2}{2\sigma^2\tau^2} + \frac{1}{2}\right] + \mathbb{E}_{q(\sigma^2)}\left[\log \frac{p(\sigma^2)}{q(\sigma^2)}\right] \\[.5em]
&= \frac{1}{2}\log \frac{\sigma_\beta^2}{\tau^2} - \frac{1}{2}\,\mathbb{E}_{q(\sigma^2)}\left[\log \sigma^2\right] - \frac{\sigma_\beta^2 + \mu_\beta^2}{2\tau^2}\,\mathbb{E}_{q(\sigma^2)}\left[\frac{1}{\sigma^2}\right] + \frac{1}{2} + \mathbb{E}_{q(\sigma^2)}\left[\log p(\sigma^2)\right] - \mathbb{E}_{q(\sigma^2)}\left[\log q(\sigma^2)\right] \, .
\end{aligned}
$$

Note that there are three expectations left. However, we really deserve a break, and so instead of
analytically deriving the expectations we compute E_{q(σ²)}[log σ²] and E_{q(σ²)}[log p(σ²)]
numerically using Gaussian quadrature. This fails for E_{q(σ²)}[log q(σ²)], which we compute using
Monte Carlo integration:

$$
\mathbb{E}_{q(\sigma^2)}\left[\log q(\sigma^2)\right] = \int q(\sigma^2) \log q(\sigma^2) \, d\sigma^2 \approx \frac{1}{N}\sum_{i=1}^N \log q\left(\sigma^2_{(i)}\right) \, , \quad \sigma^2_{(i)} \sim q(\sigma^2) \, .
$$
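To see both integration strategies side by side, here is a small comparison for E_{q(σ²)}[log σ²] with arbitrary illustrative shape and rate values (the values actually used in the algorithm are computed from the data):

library('MCMCpack')

shape <- 10; nu <- 8
# numerical quadrature via integrate()
integrate(function(s2) dinvgamma(s2, shape, nu) * log(s2), 0, Inf)$value
# Monte Carlo integration with draws from the inverse Gamma density
mean(log(rinvgamma(1e5, shape, nu)))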

We are left with the expected log likelihood. Instead of filling this blog post with more equations, we
again resort to numerical methods. However, we refactor the expression so that numerical
integration is more efficient:

$$
\begin{aligned}
\mathbb{E}_{q(\beta,\sigma^2)}\left[\log p(y \mid \beta, \sigma^2)\right] &= \int\!\!\int q(\beta) \, q(\sigma^2) \log p(y \mid \beta, \sigma^2) \, d\sigma^2 \, d\beta \\[.5em]
&= \int q(\beta) \int q(\sigma^2) \log\left((2\pi\sigma^2)^{-\frac{n}{2}} \exp\left(-\frac{1}{2\sigma^2}\sum_{i=1}^n (y_i - x_i\beta)^2\right)\right) d\sigma^2 \, d\beta \\[.5em]
&= \frac{n}{4}\log(2\pi) \int q(\beta)\left(\sum_{i=1}^n (y_i - x_i\beta)^2\right) d\beta \int q(\sigma^2) \, \frac{1}{\sigma^2}\log(\sigma^2) \, d\sigma^2 \, .
\end{aligned}
$$

Since we have solved a similar problem already above, we evaluate the expectation with respect to
q(β) analytically:

$$
\mathbb{E}_{q(\beta)}\left[\sum_{i=1}^n (y_i - x_i\beta)^2\right] = \sum_{i=1}^n y_i^2 - 2\sum_{i=1}^n y_i x_i \, \mu_\beta + \left(\sigma_\beta^2 + \mu_\beta^2\right)\sum_{i=1}^n x_i^2 \, .
$$

In the next section, we implement the algorithm for our linear regression problem in R.

Implementation in R
Now that we have derived the optimal densities, we know how they are parameterized. Therefore,
the ELBO is a function of these variational parameters and the parameters of the priors, which in
our case is just τ². We write a function that computes the ELBO:

library('MCMCpack')

#' Computes the ELBO for the linear regression example
#'
#' @param y univariate outcome variable
#' @param x univariate predictor variable
#' @param beta_mu mean of the variational density for \beta
#' @param beta_sd standard deviation of the variational density for \beta
#' @param nu parameter of the variational density for \sigma^2
#' @param tau2 prior variance for the standardized effect size
#' @param nr_samples number of samples for the Monte Carlo integration
#' @returns ELBO
compute_elbo <- function(y, x, beta_mu, beta_sd, nu, tau2, nr_samples = 1e4) {
  n <- length(y)
  sum_y2 <- sum(y^2)
  sum_x2 <- sum(x^2)
  sum_yx <- sum(x*y)
  
  # Takes a function and computes its expectation with respect to q(\beta)
  E_q_beta <- function(fn) {
    integrate(function(beta) {
      dnorm(beta, beta_mu, beta_sd) * fn(beta)
    }, -Inf, Inf)$value
  }
  
  # Takes a function and computes its expectation with respect to q(\sigma^2)
  E_q_sigma2 <- function(fn) {
    integrate(function(sigma) {
      dinvgamma(sigma^2, (n + 1)/2, nu) * fn(sigma)
    }, 0, Inf)$value
  }
  
  # Compute the expectation of log p(\sigma^2)
  E_log_p_sigma2 <- E_q_sigma2(function(sigma) log(1/sigma^2))
  
  # Compute the expectation of log p(\beta \mid \sigma^2)
  E_log_p_beta <- (
    log(tau2 / beta_sd^2) * E_q_sigma2(function(sigma) log(sigma^2)) +
    (beta_sd^2 + tau2) / (tau2) * E_q_sigma2(function(sigma) 1/sigma^2)
  )
  
  # Compute the expectation of the log variational density q(\beta)
  E_log_q_beta <- E_q_beta(function(beta) dnorm(beta, beta_mu, beta_sd, log = TRUE))
  # E_log_q_sigma2 <- E_q_sigma2(function(x) log(dinvgamma(x, (n + 1)/2, nu))) # fails
  
  # Compute the expectation of the log variational density q(\sigma^2)
  sigma2 <- rinvgamma(nr_samples, (n + 1)/2, nu)
  E_log_q_sigma2 <- mean(log(dinvgamma(sigma2, (n + 1)/2, nu)))
  
  # Compute the expected log likelihood
  E_log_y_b <- sum_y2 - 2*sum_yx*beta_mu + (beta_sd^2 + beta_mu^2)*sum_x2
  E_log_y_sigma2 <- E_q_sigma2(function(sigma) log(sigma^2) * 1/sigma^2)
  E_log_y <- n/4 * log(2*pi) * E_log_y_b * E_log_y_sigma2
  
  # Compute and return the ELBO
  ELBO <- E_log_y + E_log_p_beta + E_log_p_sigma2 - E_log_q_beta - E_log_q_sigma2
  ELBO
}

The function below implements coordinate ascent mean-field variational inference for our simple
linear regression problem. Recall that the variational parameters are:

$$
\begin{aligned}
\nu &= \frac{1}{2}\left(\sum_{i=1}^n y_i^2 - 2\sum_{i=1}^n y_i x_i \, \mu_\beta + \left(\sigma_\beta^2 + \mu_\beta^2\right)\left(\sum_{i=1}^n x_i^2 + \frac{1}{\tau^2}\right)\right) \\[.5em]
\mu_\beta &= \frac{\sum_{i=1}^n y_i x_i}{\sum_{i=1}^n x_i^2 + \frac{1}{\tau^2}} \\[.5em]
\sigma_\beta^2 &= \frac{\left(\frac{n+1}{2}\right)\nu^{-1}}{\sum_{i=1}^n x_i^2 + \frac{1}{\tau^2}} \, .
\end{aligned}
$$

The following function implements the iterative updating of these variational parameters until the
ELBO has converged.

#' Implements CAVI for the linear regression example
#'
#' @param y univariate outcome variable
#' @param x univariate predictor variable
#' @param tau2 prior variance for the standardized effect size
#' @returns parameters for the variational densities and ELBO
lmcavi <- function(y, x, tau2, nr_samples = 1e5, epsilon = 1e-2) {
  n <- length(y)
  sum_y2 <- sum(y^2)
  sum_x2 <- sum(x^2)
  sum_yx <- sum(x*y)
  
  # beta_mu is not updated through variational inference!
  beta_mu <- sum_yx / (sum_x2 + 1/tau2)
  
  res <- list()
  res[['nu']] <- 5
  res[['beta_mu']] <- beta_mu
  res[['beta_sd']] <- 1
  res[['ELBO']] <- 0
  
  j <- 1
  has_converged <- function(x, y) abs(x - y) < epsilon
  ELBO <- compute_elbo(y, x, beta_mu, 1, 5, tau2, nr_samples = nr_samples)
  
  # while the ELBO has not converged
  while (!has_converged(res[['ELBO']][j], ELBO)) {
    
    nu_prev <- res[['nu']][j]
    beta_sd_prev <- res[['beta_sd']][j]
    
    # used in the updates of beta_sd and nu
    E_qA <- sum_y2 - 2*sum_yx*beta_mu + (beta_sd_prev^2 + beta_mu^2)*(sum_x2 + 1/tau2)
    
    # update the variational parameters for sigma2 and beta
    nu <- 1/2 * E_qA
    beta_sd <- sqrt(((n + 1) / E_qA) / (sum_x2 + 1/tau2))
    
    # update the results object
    res[['nu']] <- c(res[['nu']], nu)
    res[['beta_sd']] <- c(res[['beta_sd']], beta_sd)
    res[['ELBO']] <- c(res[['ELBO']], ELBO)
    
    # compute the new ELBO
    j <- j + 1
    ELBO <- compute_elbo(y, x, beta_mu, beta_sd, nu, tau2, nr_samples = nr_samples)
  }
  
  res
}

Let’s run this on a simulated data set of size n = 100 with a true coefficient of β = 0.30 and a true
error variance of σ² = 1. We assign the standardized effect size β/σ a Gaussian prior with variance τ² = 0.25,
so that values of |β/σ| smaller than one standard deviation (0.50) receive about 0.68 prior probability.
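As a quick check of that prior statement (a one-liner, using τ = 0.50):

pnorm(0.50, 0, 0.50) - pnorm(-0.50, 0, 0.50)  # roughly 0.68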

gen_dat <- function(n, beta, sigma) {
  x <- rnorm(n)
  y <- 0 + beta*x + rnorm(n, 0, sigma)
  data.frame(x = x, y = y)
}

set.seed(1)
dat <- gen_dat(100, 0.30, 1)

mc <- lmcavi(dat$y, dat$x, tau2 = 0.50^2)
mc

## $nu
## [1] 5.00000 88.17995 45.93875 46.20205 46.19892 46.19895
##
## $beta_mu
## [1] 0.2800556
##
## $beta_sd
## [1] 1.00000000 0.08205605 0.11368572 0.11336132 0.11336517 0.11336512
##
## $ELBO
## [1] 0.0000 -297980.0495 493.4807 -281.4578 -265.1289 -265.3197

From the output, we see that the ELBO and the variational parameters have converged. In the next
section, we compare these results to results obtained with Stan.
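Before that comparison, note that the converged variational densities can be used directly for approximate posterior summaries; for example, a rough 95% credible interval for β under q(β), using the fitted mc object above (a small illustrative sketch):

beta_mu <- tail(mc$beta_mu, 1)
beta_sd <- tail(mc$beta_sd, 1)
qnorm(c(0.025, 0.975), beta_mu, beta_sd)  # approximate 95% credible interval for beta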

Comparison with Stan


Whenever one goes down a rabbit hole of calculations, it is good to sanity check one’s results. Here,
we use Stan’s variational inference scheme to check whether our results are comparable. It assumes
a Gaussian variational density for each parameter after transforming them to the real line and
automates inference in a “black-box” way so that no problem-specific calculations are required (see
Kucukelbir, Ranganath, Gelman, & Blei, 2015). Subsequently, we compare our results to the exact
posteriors obtained via Markov chain Monte Carlo. The simple linear regression model in Stan is:

data {
  int<lower=0> n;
  vector[n] y;
  vector[n] x;
  real tau;
}

parameters {
  real b;
  real<lower=0> sigma;
}

model {
  target += -log(sigma);
  target += normal_lpdf(b | 0, sigma*tau);
  target += normal_lpdf(y | b*x, sigma);
}

We use Stan’s black-box variational inference scheme:

library('rstan')

# save the above model to a file and compile it
model <- stan_model(file = 'stan-compiled/variational-regression.stan')

stan_dat <- list('n' = nrow(dat), 'x' = dat$x, 'y' = dat$y, 'tau' = 0.50)

fit <- vb(
  model, data = stan_dat, output_samples = 20000, adapt_iter = 10000,
  init = list('b' = 0.30, 'sigma' = 1), refresh = FALSE, seed = 1
)

This gives estimates similar to ours:

fit

## Inference for Stan model: variational-regression.
## 1 chains, each with iter=20000; warmup=0; thin=1;
## post-warmup draws per chain=20000, total post-warmup draws=20000.
##
## mean sd 2.5% 25% 50% 75% 97.5%
## b 0.28 0.13 0.02 0.19 0.28 0.37 0.54
## sigma 0.99 0.09 0.82 0.92 0.99 1.05 1.18
## lp__ 0.00 0.00 0.00 0.00 0.00 0.00 0.00
##
## Approximate samples were drawn using VB(meanfield) at Thu Mar 19 10:45:28 2020.

## We recommend genuine 'sampling' from the posterior distribution for final inferences!

Their recommendation is prudent. If you run the code with different seeds, you can get quite
different results. For example, the posterior mean of β can range from 0.12 to 0.45, and the
posterior standard deviation can be as low as 0.03; in all these settings, Stan indicates that the
ELBO has converged, but it seems that it has converged to a different local optimum for each run.
(For seed = 3, Stan gives completely nonsensical results). Stan warns that the algorithm is
experimental and may be unstable, and it is probably wise to not use it in production.

Update: As Ben Goodrich points out in the comments, there is some cool work on providing
diagnostics for variational inference; see this blog post and the paper by Yao, Vehtari, Simpson, &
Gelman (2018) as well as the paper by Huggins, Kasprzak, Campbell, & Broderick (2019).

Although the posterior distribution for β and σ² is available in closed form (see the Post Scriptum),
we check our results against exact inference using Markov chain Monte Carlo by visual inspection.

fit <- sampling(model, data = stan_dat, iter = 8000, refresh = FALSE, seed = 1)

The figure below overlays our closed-form results on the histogram of posterior samples obtained
using Stan.

Note that the posterior variance of β is slightly overestimated by our variational scheme. This is in
contrast to the fact that variational inference generally underestimates variances. Note also that
Bayesian inference using Markov chain Monte Carlo is very fast on this simple problem. However,
the comparative advantage of variational inference becomes clear when increasing the sample size:
for sample sizes as large as n = 100000, our variational inference scheme takes less than a tenth of
a second!
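The timing claim is easy to check on your own machine (exact numbers will of course differ by hardware and by the convergence tolerance and number of Monte Carlo samples used):

# Time CAVI on a large simulated data set, reusing gen_dat() and lmcavi() from above
dat_large <- gen_dat(1e5, 0.30, 1)
system.time(mc_large <- lmcavi(dat_large$y, dat_large$x, tau2 = 0.50^2))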

Conclusion
In this blog post, we have seen how to turn an integration problem into an optimization problem
using variational inference. Assuming that the variational densities are independent, we have
derived the optimal variational densities for a simple linear regression problem with one predictor.
While using variational inference for this problem is unnecessary since everything is available in
closed-form, I have focused on such a simple problem so as to not confound this introduction to
variational inference by the complexity of the model. Still, the derivations were quite lengthy. They
were also entirely specific to our particular problem, and thus generic “black-box” algorithms which
avoid problem-specific calculations hold great promise.

We also implemented coordinate ascent mean-field variational inference (CAVI) in R and compared
our results to results obtained via variational and exact inference using Stan. We have found that
one probably should not trust Stan’s variational inference implementation, and that our results
closely correspond to the exact procedure. For more on variational inference, I recommend the
excellent review article by Blei, Kucukelbir, and McAuliffe (2017).

I would like to thank Don van den Bergh for helpful comments on this blog post.

Post Scriptum
Normal-inverse-gamma Distribution

The posterior distribution is a Normal-inverse-gamma distribution:

$$
p(\beta, \sigma^2 \mid y) = \frac{\gamma^\alpha}{\Gamma(\alpha)} (\sigma^2)^{-\alpha-1} \exp\left(-\frac{2\gamma + \lambda(\beta - \mu)^2}{2\sigma^2}\right) \, ,
$$

where

$$
\begin{aligned}
\mu &= \frac{\sum_{i=1}^n y_i x_i}{\sum_{i=1}^n x_i^2 + \frac{1}{\tau^2}} \\[.5em]
\lambda &= \sum_{i=1}^n x_i^2 + \frac{1}{\tau^2} \\[.5em]
\alpha &= \frac{n+1}{2} \\[.5em]
\gamma &= \frac{1}{2}\left(\sum_{i=1}^n y_i^2 - \frac{\left(\sum_{i=1}^n y_i x_i\right)^2}{\sum_{i=1}^n x_i^2 + \frac{1}{\tau^2}}\right) \, .
\end{aligned}
$$

Note that the marginal posterior distribution for β is actually a Student-t distribution, contrary to
what we assume in our variational inference scheme.

References
Blei, D. M., Kucukelbir, A., & McAuliffe, J. D. (2017). Variational inference: A review for
statisticians. Journal of the American Statistical Association, 112(518), 859-877.
Huggins, J. H., Kasprzak, M., Campbell, T., & Broderick, T. (2019). Practical Posterior Error
Bounds from Variational Objectives. arXiv preprint arXiv:1910.04102.
Kucukelbir, A., Ranganath, R., Gelman, A., & Blei, D. (2015). Automatic variational inference in
Stan. In Advances in Neural Information Processing Systems (pp. 568-576).
Kucukelbir, A., Tran, D., Ranganath, R., Gelman, A., & Blei, D. M. (2017). Automatic
differentiation variational inference. The Journal of Machine Learning Research, 18(1), 430-474.
Yao, Y., Vehtari, A., Simpson, D., & Gelman, A. (2018). Yes, but did it work?: Evaluating
variational inference. arXiv preprint arXiv:1802.02538.

Footnotes
1. The first part of this blog post draws heavily on the excellent review article by Blei,
Kucukelbir, and McAuliffe (2017), and so I use their (machine learning) notation. ↩

Written on October 30th, 2019 by Fabian Dablander

Feel free to share!

  