Chapter 19
Approximate Inference
Many probabilistic models are difficult to train because it is difficult to perform inference in them. In the context of deep learning, we usually have a set of visible variables v and a set of latent variables h. The challenge of inference usually refers to the difficult problem of computing p(h | v) or taking expectations with respect to it. Such operations are often necessary for tasks like maximum likelihood learning.
Many simple graphical models with only one hidden layer, such as restricted Boltzmann machines and probabilistic PCA, are defined in a way that makes inference operations like computing p(h | v), or taking expectations with respect to it, simple. Unfortunately, most graphical models with multiple layers of hidden variables have intractable posterior distributions. Exact inference requires an exponential amount of time in these models. Even some models with only a single layer, such as sparse coding, have this problem.
In this chapter, we introduce several of the techniques for confronting these
intractable inference problems. In chapter 20, we describe how to use these
techniques to train probabilistic models that would otherwise be intractable, such
as deep belief networks and deep Boltzmann machines.
Intractable inference problems in deep learning usually arise from interactions
between latent variables in a structured graphical model. See figure 19.1 for some
examples. These interactions may be due to direct interactions in undirected
models or “explaining away” interactions between mutual ancestors of the same
visible unit in directed models.
Figure 19.1: Intractable inference problems in deep learning are usually the result of interactions between latent variables in a structured graphical model. These interactions can be due to edges directly connecting one latent variable to another or longer paths that are activated when the child of a V-structure is observed. (Left) A semi-restricted Boltzmann machine (Osindero and Hinton, 2008) with connections between hidden units. These direct connections between latent variables make the posterior distribution intractable because of large cliques of latent variables. (Center) A deep Boltzmann machine, organized into layers of variables without intralayer connections, still has an intractable posterior distribution because of the connections between layers. (Right) This directed model has interactions between latent variables when the visible variables are observed, because every two latent variables are coparents. Some probabilistic models are able to provide tractable inference over the latent variables despite having one of the graph structures depicted above. This is possible if the conditional probability distributions are chosen to introduce additional independences beyond those described by the graph. For example, probabilistic PCA has the graph structure shown on the right yet still has simple inference because of special properties of the specific conditional distributions it uses (linear-Gaussian conditionals with mutually orthogonal basis vectors).
19.1 Inference as Optimization
Many approaches to confronting the problem of difficult inference make use of
the observation that exact inference can be described as an optimization problem.
Approximate inference algorithms may then be derived by approximating the
underlying optimization problem.
To construct the optimization problem, assume we have a probabilistic model consisting of observed variables v and latent variables h. We would like to compute the log-probability of the observed data, log p(v; θ). Sometimes it is too difficult to compute log p(v; θ) if it is costly to marginalize out h. Instead, we can compute a lower bound L(v, θ, q) on log p(v; θ). This bound is called the evidence lower bound (ELBO). Another commonly used name for this lower bound is the negative variational free energy. Specifically, the evidence lower bound is defined to be

\mathcal{L}(v, \theta, q) = \log p(v; \theta) - D_{\mathrm{KL}}(q(h \mid v) \,\|\, p(h \mid v; \theta)),    (19.1)

where q is an arbitrary probability distribution over h.
Because the difference between log p(v) and L(v, θ, q) is given by the KL divergence, and because the KL divergence is always nonnegative, we can see that L always has at most the same value as the desired log-probability. The two are equal if and only if q is the same distribution as p(h | v).
Surprisingly, L can be considerably easier to compute for some distributions q. Simple algebra shows that we can rearrange L into a much more convenient form:

\mathcal{L}(v, \theta, q) = \log p(v; \theta) - D_{\mathrm{KL}}(q(h \mid v) \,\|\, p(h \mid v; \theta))    (19.2)
= \log p(v; \theta) - \mathbb{E}_{h \sim q} \log \frac{q(h \mid v)}{p(h \mid v)}    (19.3)
= \log p(v; \theta) - \mathbb{E}_{h \sim q} \log \frac{q(h \mid v)}{\frac{p(h, v; \theta)}{p(v; \theta)}}    (19.4)
= \log p(v; \theta) - \mathbb{E}_{h \sim q} \left[ \log q(h \mid v) - \log p(h, v; \theta) + \log p(v; \theta) \right]    (19.5)
= -\mathbb{E}_{h \sim q} \left[ \log q(h \mid v) - \log p(h, v; \theta) \right].    (19.6)
This yields the more canonical definition of the evidence lower bound,

\mathcal{L}(v, \theta, q) = \mathbb{E}_{h \sim q} [\log p(h, v)] + H(q).    (19.7)
For an appropriate choice of q, L is tractable to compute. For any choice of q, L provides a lower bound on the likelihood. For q(h | v) that are better approximations of p(h | v), the lower bound L will be tighter, in other words, closer to log p(v). When q(h | v) = p(h | v), the approximation is perfect, and L(v, θ, q) = log p(v; θ).
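As an illustration of equation 19.7, the following sketch estimates L(v, θ, q) by Monte Carlo using samples from q. The callables sample_q, log_q, and log_p_joint are placeholders standing in for whatever model and approximate posterior one happens to use; they are not part of any particular library.

import numpy as np

def elbo_estimate(v, sample_q, log_q, log_p_joint, n_samples=1000):
    """Monte Carlo estimate of L(v, theta, q) = E_{h~q}[log p(h, v) - log q(h | v)].

    sample_q(v)       -- draws h ~ q(h | v)            (placeholder)
    log_q(h, v)       -- evaluates log q(h | v)        (placeholder)
    log_p_joint(h, v) -- evaluates log p(h, v; theta)  (placeholder)
    """
    total = 0.0
    for _ in range(n_samples):
        h = sample_q(v)
        total += log_p_joint(h, v) - log_q(h, v)
    return total / n_samples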
We can thus think of inference as the procedure for finding the q that maximizes L. Exact inference maximizes L perfectly by searching over a family of functions q that includes p(h | v). Throughout this chapter, we show how to derive different forms of approximate inference by using approximate optimization to find q. We can make the optimization procedure less expensive but approximate by restricting the family of distributions q that the optimization is allowed to search over or by using an imperfect optimization procedure that may not completely maximize L but may merely increase it by a significant amount.
No matter what choice of q we use, L is a lower bound. We can get tighter or looser bounds that are cheaper or more expensive to compute depending on how we choose to approach this optimization problem. We can obtain a poorly matched q but reduce the computational cost by using an imperfect optimization procedure, or by using a perfect optimization procedure over a restricted family of q distributions.
19.2 Expectation Maximization
The first algorithm we introduce based on maximizing a lower bound L is the expectation maximization (EM) algorithm, a popular training algorithm for models with latent variables. We describe here a view on the EM algorithm developed by Neal and Hinton (1999). Unlike most of the other algorithms we describe in this chapter, EM is not an approach to approximate inference, but rather an approach to learning with an approximate posterior.
The EM algorithm consists of alternating between two steps until convergence:

• The E-step (expectation step): Let θ^(0) denote the value of the parameters at the beginning of the step. Set q(h^(i) | v) = p(h^(i) | v^(i); θ^(0)) for all indices i of the training examples v^(i) we want to train on (both batch and minibatch variants are valid). By this we mean q is defined in terms of the current parameter value of θ^(0); if we vary θ, then p(h | v; θ) will change, but q(h | v) will remain equal to p(h | v; θ^(0)).

• The M-step (maximization step): Completely or partially maximize

\sum_i \mathcal{L}(v^{(i)}, \theta, q)    (19.8)

with respect to θ using your optimization algorithm of choice.
This can be viewed as a coordinate ascent algorithm to maximize L. On one step, we maximize L with respect to q, and on the other, we maximize L with respect to θ.
Stochastic gradient ascent on latent variable models can be seen as a special
case of the EM algorithm where the M-step consists of taking a single gradient
step. Other variants of the EM algorithm can make much larger steps. For some
model families, the M-step can even be performed analytically, jumping all the
way to the optimal solution for θ given the current q.
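The following sketch runs EM on a one-dimensional mixture of two Gaussians, a model family for which the M-step has a closed-form solution. It is meant only to make the alternation between the two steps concrete; the variable names and toy data are ours rather than anything from the text.

import numpy as np

def em_gmm_1d(x, k=2, n_iter=100, seed=0):
    rng = np.random.default_rng(seed)
    n = len(x)
    pi = np.full(k, 1.0 / k)                       # mixing proportions
    mu = rng.choice(x, size=k, replace=False)      # component means
    var = np.full(k, np.var(x))                    # component variances
    for _ in range(n_iter):
        # E-step: set q(h = j | x_i) = p(h = j | x_i; theta_old) (responsibilities)
        log_r = (np.log(pi) - 0.5 * np.log(2 * np.pi * var)
                 - 0.5 * (x[:, None] - mu) ** 2 / var)
        log_r -= log_r.max(axis=1, keepdims=True)
        r = np.exp(log_r)
        r /= r.sum(axis=1, keepdims=True)
        # M-step: maximize sum_i L(x_i, theta, q) analytically for this family
        nk = r.sum(axis=0)
        pi = nk / n
        mu = (r * x[:, None]).sum(axis=0) / nk
        var = (r * (x[:, None] - mu) ** 2).sum(axis=0) / nk
    return pi, mu, var

rng = np.random.default_rng(1)
x = np.concatenate([rng.normal(-2, 1, 500), rng.normal(3, 0.5, 500)])
print(em_gmm_1d(x))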
Even though the E-step involves exact inference, we can think of the EM algorithm as using approximate inference in some sense. Specifically, the M-step assumes that the same value of q can be used for all values of θ. This will introduce a gap between L and the true log p(v) as the M-step moves further and further away from the value θ^(0) used in the E-step. Fortunately, the E-step reduces the gap to zero again as we enter the loop for the next time.
The EM algorithm contains a few different insights. First, there is the basic structure of the learning process, in which we update the model parameters to improve the likelihood of a completed dataset, where all missing variables have their values provided by an estimate of the posterior distribution. This particular insight is not unique to the EM algorithm. For example, using gradient descent to maximize the log-likelihood also has this same property; the log-likelihood gradient computations require taking expectations with respect to the posterior distribution over the hidden units. Another key insight in the EM algorithm is that we can continue to use one value of q even after we have moved to a different value of θ. This particular insight is used throughout classical machine learning to derive large M-step updates. In the context of deep learning, most models are too complex to admit a tractable solution for an optimal large M-step update, so this second insight, which is more particular to the EM algorithm, is rarely used.
19.3 MAP Inference and Sparse Coding
We usually use the term inference to refer to computing the probability distribution over one set of variables given another. When training probabilistic models with latent variables, we are usually interested in computing p(h | v). An alternative form of inference is to compute the single most likely value of the missing variables, rather than to infer the entire distribution over their possible values. In the context of latent variable models, this means computing

h^* = \arg\max_h p(h \mid v).    (19.9)

This is known as maximum a posteriori inference, abbreviated as MAP inference.
MAP inference is usually not thought of as approximate inference, since it does compute the exact most likely value of h. However, if we wish to develop a learning process based on maximizing L(v, θ, q), then it is helpful to think of MAP inference as a procedure that provides a value of q. In this sense, we can think of MAP inference as approximate inference, because it does not provide the optimal q.
Recall from section 19.1 that exact inference consists of maximizing

\mathcal{L}(v, \theta, q) = \mathbb{E}_{h \sim q} [\log p(h, v)] + H(q)    (19.10)

with respect to q over an unrestricted family of probability distributions, using an exact optimization algorithm. We can derive MAP inference as a form of approximate inference by restricting the family of distributions q may be drawn from. Specifically, we require q to take on a Dirac distribution:

q(h \mid v) = \delta(h - \mu).    (19.11)

This means that we can now control q entirely via µ. Dropping terms of L that do not vary with µ, we are left with the optimization problem

\mu^* = \arg\max_\mu \log p(h = \mu, v),    (19.12)

which is equivalent to the MAP inference problem

h^* = \arg\max_h p(h \mid v).    (19.13)
We can thus justify a learning procedure similar to EM, in which we alternate between performing MAP inference to infer h^* and then updating θ to increase log p(h^*, v). As with EM, this is a form of coordinate ascent on L, where we alternate between using inference to optimize L with respect to q and using parameter updates to optimize L with respect to θ. The procedure as a whole can be justified by the fact that L is a lower bound on log p(v). In the case of MAP inference, this justification is rather vacuous, because the bound is infinitely loose, due to the Dirac distribution's differential entropy of negative infinity. Adding noise to µ would make the bound meaningful again.
MAP inference is commonly used in deep learning as both a feature extractor
and a learning mechanism. It is primarily used for sparse coding models.
Recall from section 13.4 that sparse coding is a linear factor model that imposes a sparsity-inducing prior on its hidden units. A common choice is a factorial Laplace prior, with

p(h_i) = \frac{\lambda}{2} e^{-\lambda |h_i|}.    (19.14)

The visible units are then generated by performing a linear transformation and adding noise:

p(v \mid h) = \mathcal{N}(v; W h + b, \beta^{-1} I).    (19.15)
Computing or even representing p(h | v) is difficult. Every pair of variables h_i and h_j are both parents of v. This means that when v is observed, the graphical model contains an active path connecting h_i and h_j. All the hidden units thus participate in one massive clique in p(h | v). If the model were Gaussian, then these interactions could be modeled efficiently via the covariance matrix, but the sparse prior makes these interactions non-Gaussian.
Because p(h | v) is intractable, so is the computation of the log-likelihood and its gradient. We thus cannot use exact maximum likelihood learning. Instead, we use MAP inference and learn the parameters by maximizing the ELBO defined by the Dirac distribution around the MAP estimate of h.
If we concatenate all the h vectors in the training set into a matrix H, and concatenate all the v vectors into a matrix V, then the sparse coding learning process consists of minimizing

J(H, W) = \sum_{i,j} |H_{i,j}| + \sum_{i,j} \left( V - H W^\top \right)^2_{i,j}.    (19.16)

Most applications of sparse coding also involve weight decay or a constraint on the norms of the columns of W, to prevent the pathological solution with extremely small H and large W.
We can minimize J by alternating between minimization with respect to H and minimization with respect to W. Both subproblems are convex. In fact, the minimization with respect to W is just a linear regression problem. Minimization of J with respect to both arguments, however, is usually not a convex problem. Minimization with respect to H requires specialized algorithms such as the feature-sign search algorithm (Lee et al., 2007).
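A minimal sketch of this alternating scheme is shown below, with rows of V and H holding individual v and h vectors. The H-step here uses a generic proximal-gradient (ISTA-style) solver rather than feature-sign search, and the column-norm constraint on W is enforced by a simple projection; both are our own illustrative choices rather than the methods used in the references.

import numpy as np

def h_step(V, H, W, lam=1.0, n_iter=100):
    """Minimize lam * sum|H| + ||V - H W^T||_F^2 over H by proximal gradient (ISTA-style)."""
    step = 1.0 / (2 * np.linalg.norm(W, 2) ** 2)            # 1 / Lipschitz constant of the gradient
    for _ in range(n_iter):
        grad = -2 * (V - H @ W.T) @ W                       # gradient of the quadratic term
        Z = H - step * grad
        H = np.sign(Z) * np.maximum(np.abs(Z) - step * lam, 0.0)   # soft threshold
    return H

def w_step(V, H):
    """Minimize ||V - H W^T||_F^2 over W (a linear regression), then project column norms to 1."""
    Wt, *_ = np.linalg.lstsq(H, V, rcond=None)
    W = Wt.T
    return W / np.maximum(np.linalg.norm(W, axis=0, keepdims=True), 1e-8)

rng = np.random.default_rng(0)
V = rng.normal(size=(100, 20))                              # toy data: 100 examples, 20 visible units
W = rng.normal(size=(20, 30))                               # 30 hidden units
W /= np.linalg.norm(W, axis=0, keepdims=True)
H = np.zeros((100, 30))
for _ in range(20):                                         # alternate the two convex subproblems
    H = h_step(V, H, W)
    W = w_step(V, H)

Each subproblem above is convex, but, as the text notes, the joint minimization over H and W together is not.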
19.4 Variational Inference and Learning
We have seen how the evidence lower bound L(v, θ, q) is a lower bound on log p(v; θ), how inference can be viewed as maximizing L with respect to q, and how learning can be viewed as maximizing L with respect to θ. We have seen that the EM algorithm enables us to make large learning steps with a fixed q and that learning algorithms based on MAP inference enable us to learn using a point estimate of p(h | v) rather than inferring the entire distribution. Now we develop the more general approach to variational learning.
The core idea behind variational learning is that we can maximize L over a restricted family of distributions q. This family should be chosen so that it is easy to compute E_q log p(h, v). A typical way to do this is to introduce assumptions about how q factorizes.

A common approach to variational learning is to impose the restriction that q is a factorial distribution:

q(h \mid v) = \prod_i q(h_i \mid v).    (19.17)
This is called the mean field approach. More generally, we can impose any graphical model structure we choose on q, to flexibly determine how many interactions we want our approximation to capture. This fully general graphical model approach is called structured variational inference (Saul and Jordan, 1996).
The beauty of the variational approach is that we do not need to specify a specific parametric form for q. We specify how it should factorize, but then the optimization problem determines the optimal probability distribution within those factorization constraints. For discrete latent variables, this just means that we use traditional optimization techniques to optimize a finite number of variables describing the q distribution. For continuous latent variables, this means that we use a branch of mathematics called calculus of variations to perform optimization over a space of functions and actually determine which function should be used to represent q. Calculus of variations is the origin of the names “variational learning” and “variational inference,” though these names apply even when the latent variables are discrete and calculus of variations is not needed. With continuous latent variables, calculus of variations is a powerful technique that removes much of the responsibility from the human designer of the model, who now must specify only how q factorizes, rather than needing to guess how to design a specific q that can accurately approximate the posterior.
Because L(v, θ, q) is defined to be log p(v; θ) − D_KL(q(h | v) ‖ p(h | v; θ)), we can think of maximizing L with respect to q as minimizing D_KL(q(h | v) ‖ p(h | v)). In this sense, we are fitting q to p. However, we are doing so with the opposite direction of the KL divergence than we are used to using for fitting an approximation. When we use maximum likelihood learning to fit a model to data, we minimize D_KL(p_data ‖ p_model). As illustrated in figure 3.6, this means that maximum likelihood encourages the model to have high probability everywhere that the data has high probability, while our optimization-based inference procedure encourages q to have low probability everywhere the true posterior has low probability. Both directions of the KL divergence can have desirable and undesirable properties. The choice of which to use depends on which properties are the highest priority for each application. In the inference optimization problem, we choose to use D_KL(q(h | v) ‖ p(h | v)) for computational reasons. Specifically, computing D_KL(q(h | v) ‖ p(h | v)) involves evaluating expectations with respect to q, so by designing q to be simple, we can simplify the required expectations. The opposite direction of the KL divergence would require computing expectations with respect to the true posterior. Because the form of the true posterior is determined by the choice of model, we cannot design a reduced-cost approach to computing D_KL(p(h | v) ‖ q(h | v)) exactly.
19.4.1 Discrete Latent Variables
Variational inference with discrete latent variables is relatively straightforward. We define a distribution q, typically one where each factor of q is just defined by a lookup table over discrete states. In the simplest case, h is binary and we make the mean field assumption that q factorizes over each individual h_i. In this case we can parametrize q with a vector ĥ whose entries are probabilities. Then q(h_i = 1 | v) = ĥ_i.
After determining how to represent q, we simply optimize its parameters. With discrete latent variables, this is just a standard optimization problem. In principle the selection of q could be done with any optimization algorithm, such as gradient descent.
Because this optimization must occur in the inner loop of a learning algorithm, it must be very fast. To achieve this speed, we typically use special optimization algorithms that are designed to solve comparatively small and simple problems in few iterations. A popular choice is to iterate fixed-point equations, in other words, to solve

\frac{\partial}{\partial \hat{h}_i} \mathcal{L} = 0    (19.18)

for ĥ_i. We repeatedly update different elements of ĥ until we satisfy a convergence criterion.
To make this more concrete, we show how to apply variational inference to
the
binary sparse coding model
(we present here the model developed by
Henniges et al. [2010] but demonstrate traditional, generic mean field applied to
the model, while they introduce a specialized algorithm). This derivation goes
into considerable mathematical detail and is intended for the reader who wishes to
fully resolve any ambiguity in the high-level conceptual description of variational
inference and learning we have presented so far. Readers who do not plan to derive
or implement variational learning algorithms may safely skip to the next section
without missing any new high-level concepts. Readers who proceed with the binary
sparse coding example are encouraged to review the list of useful properties of
functions that commonly arise in probabilistic models in section 3.10. We use
these properties liberally throughout the following derivations without highlighting
exactly where we use each one.
In the binary sparse coding model, the input v ∈ R^n is generated from the model by adding Gaussian noise to the sum of m different components, which can each be present or absent. Each component is switched on or off by the corresponding hidden unit in h ∈ {0, 1}^m:

p(h_i = 1) = \sigma(b_i),    (19.19)
p(v \mid h) = \mathcal{N}(v; W h, \beta^{-1}),    (19.20)

where b is a learnable set of biases, W is a learnable weight matrix, and β is a learnable, diagonal precision matrix.
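For concreteness, ancestral sampling from this generative model takes only a few lines; the sizes and parameter values below are arbitrary placeholders.

import numpy as np

rng = np.random.default_rng(0)
m, n = 4, 3                                   # number of hidden and visible units
W = rng.normal(size=(n, m))                   # weights
b = rng.normal(size=m)                        # biases on the hidden units
beta = np.full(n, 4.0)                        # diagonal precision of the visible noise

sigmoid = lambda x: 1.0 / (1.0 + np.exp(-x))
h = (rng.random(m) < sigmoid(b)).astype(float)        # h_i ~ Bernoulli(sigma(b_i))
v = rng.normal(W @ h, 1.0 / np.sqrt(beta))            # v ~ N(Wh, beta^{-1})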
Training this model with maximum likelihood requires taking the derivative with respect to the parameters. Consider the derivative with respect to one of the biases:

\frac{\partial}{\partial b_i} \log p(v)    (19.21)
= \frac{\frac{\partial}{\partial b_i} p(v)}{p(v)}    (19.22)
= \frac{\frac{\partial}{\partial b_i} \sum_h p(h, v)}{p(v)}    (19.23)
= \frac{\frac{\partial}{\partial b_i} \sum_h p(h) p(v \mid h)}{p(v)}    (19.24)
= \frac{\sum_h p(v \mid h) \frac{\partial}{\partial b_i} p(h)}{p(v)}    (19.25)
= \sum_h p(h \mid v) \frac{\frac{\partial}{\partial b_i} p(h)}{p(h)}    (19.26)
= \mathbb{E}_{h \sim p(h \mid v)} \frac{\partial}{\partial b_i} \log p(h).    (19.27)

Figure 19.2: The graph structure of a binary sparse coding model with four hidden units. (Left) The graph structure of p(h, v). Note that the edges are directed, and that every two hidden units are coparents of every visible unit. (Right) The graph structure of p(h | v). To account for the active paths between coparents, the posterior distribution needs an edge between all the hidden units.
This requires computing expectations with respect to p(h | v). Unfortunately, p(h | v) is a complicated distribution. See figure 19.2 for the graph structure of p(h, v) and p(h | v). The posterior distribution corresponds to the complete graph over the hidden units, so variable elimination algorithms do not help us to compute the required expectations any faster than brute force.
We can resolve this difficulty by using variational inference and variational
learning instead.
We can make a mean field approximation:

q(h \mid v) = \prod_i q(h_i \mid v).    (19.28)
The latent variables of the binary sparse coding model are binary, so to represent a factorial q we simply need to model m Bernoulli distributions q(h_i | v). A natural way to represent the means of the Bernoulli distributions is with a vector ĥ of probabilities, with q(h_i = 1 | v) = ĥ_i. We impose a restriction that ĥ_i is never equal to 0 or to 1, in order to avoid errors when computing, for example, log ĥ_i. We will see that the variational inference equations never assign 0 or 1 to ĥ_i analytically. In a software implementation, however, machine rounding error could result in 0 or 1 values. In software, we may wish to implement binary sparse coding using an unrestricted vector of variational parameters z and obtain ĥ via the relation ĥ = σ(z). We can thus safely compute log ĥ_i on a computer by using the identity log σ(z_i) = −ζ(−z_i), relating the sigmoid and the softplus.
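A small numeric illustration of this identity, using numpy's log-add-exp as a stable softplus ζ:

import numpy as np

def log_sigmoid(z):
    # log sigma(z) = -zeta(-z), with zeta(x) = log(1 + exp(x)) computed stably
    return -np.logaddexp(0.0, -z)

z = np.array([-800.0, 0.0, 800.0])
with np.errstate(over="ignore"):
    naive = np.log(1.0 / (1.0 + np.exp(-z)))   # overflows and yields -inf at z = -800
print(naive)
print(log_sigmoid(z))                          # [-800, -log 2, ~0]: finite and accurate everywhere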
To begin our derivation of variational learning in the binary sparse coding
model, we show that the use of this mean field approximation makes learning
tractable.
The evidence lower bound is given by

\mathcal{L}(v, \theta, q)    (19.29)
= \mathbb{E}_{h \sim q} [\log p(h, v)] + H(q)    (19.30)
= \mathbb{E}_{h \sim q} [\log p(h) + \log p(v \mid h) - \log q(h \mid v)]    (19.31)
= \mathbb{E}_{h \sim q} \left[ \sum_{i=1}^m \log p(h_i) + \sum_{i=1}^n \log p(v_i \mid h) - \sum_{i=1}^m \log q(h_i \mid v) \right]    (19.32)
= \sum_{i=1}^m \left[ \hat{h}_i \left( \log \sigma(b_i) - \log \hat{h}_i \right) + (1 - \hat{h}_i) \left( \log \sigma(-b_i) - \log(1 - \hat{h}_i) \right) \right]    (19.33)
  + \mathbb{E}_{h \sim q} \left[ \sum_{i=1}^n \log \sqrt{\frac{\beta_i}{2\pi}} \exp\left( -\frac{\beta_i}{2} (v_i - W_{i,:} h)^2 \right) \right]    (19.34)
= \sum_{i=1}^m \left[ \hat{h}_i \left( \log \sigma(b_i) - \log \hat{h}_i \right) + (1 - \hat{h}_i) \left( \log \sigma(-b_i) - \log(1 - \hat{h}_i) \right) \right]    (19.35)
  + \frac{1}{2} \sum_{i=1}^n \left[ \log \frac{\beta_i}{2\pi} - \beta_i \left( v_i^2 - 2 v_i W_{i,:} \hat{h} + \sum_j \left[ W_{i,j}^2 \hat{h}_j + \sum_{k \neq j} W_{i,j} W_{i,k} \hat{h}_j \hat{h}_k \right] \right) \right].    (19.36)
While these equations are somewhat unappealing aesthetically, they show that L can be expressed in a small number of simple arithmetic operations. The evidence lower bound L is therefore tractable. We can use L as a replacement for the intractable log-likelihood.
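A direct transcription of equation 19.36 into code is sketched below. The quadratic term is evaluated using E[h_j h_k] = ĥ_j ĥ_k for j ≠ k and E[h_j²] = ĥ_j, which is exactly the cross-term bookkeeping in the equation; the helper and variable names are our own.

import numpy as np

def log_sigmoid(x):
    return -np.logaddexp(0.0, -x)

def binary_sc_elbo(v, h_hat, W, b, beta):
    """Evidence lower bound of equation 19.36 for one example v and mean field parameters h_hat."""
    # Bernoulli cross-entropy and entropy terms
    bern = np.sum(h_hat * (log_sigmoid(b) - np.log(h_hat))
                  + (1 - h_hat) * (log_sigmoid(-b) - np.log(1 - h_hat)))
    # E_q[(W_{i,:} h)^2] = sum_j W_{ij}^2 h_hat_j + sum_{j != k} W_{ij} W_{ik} h_hat_j h_hat_k
    mean_Wh = W @ h_hat
    sq_Wh = W ** 2 @ h_hat + mean_Wh ** 2 - W ** 2 @ h_hat ** 2
    expected_sq_err = v ** 2 - 2 * v * mean_Wh + sq_Wh
    gauss = 0.5 * np.sum(np.log(beta / (2 * np.pi)) - beta * expected_sq_err)
    return bern + gauss

# Example usage with any W, b, beta, v of matching shapes, e.g. the toy model sampled earlier:
# print(binary_sc_elbo(v, np.full(W.shape[1], 0.5), W, b, beta))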
In principle, we could simply run gradient ascent on both θ and ĥ, and this would make a perfectly acceptable combined inference and training algorithm. Usually, however, we do not do this, for two reasons. First, this would require storing ĥ for each v. We typically prefer algorithms that do not require per-example memory. It is difficult to scale learning algorithms to billions of examples if we must remember a dynamically updated vector associated with each example. Second, we would like to be able to extract the features ĥ very quickly, in order to recognize the content of v. In a realistic deployed setting, we would need to be able to compute ĥ in real time.
For both these reasons, we typically do not use gradient descent to compute the mean field parameters ĥ. Instead, we rapidly estimate them with fixed-point equations.
The idea behind fixed-point equations is that we are seeking a local maximum with respect to ĥ, where ∇_ĥ L(v, θ, ĥ) = 0. We cannot efficiently solve this equation with respect to all of ĥ simultaneously. However, we can solve for a single variable:

\frac{\partial}{\partial \hat{h}_i} \mathcal{L}(v, \theta, \hat{h}) = 0.    (19.37)

We can then iteratively apply the solution to the equation for i = 1, . . . , m, and repeat the cycle until we satisfy a convergence criterion. Common convergence criteria include stopping when a full cycle of updates does not improve L by more than some tolerance amount, or when the cycle does not change ĥ by more than some amount.
Iterating mean field fixed-point equations is a general technique that can provide
fast variational inference in a broad variety of models. To make this more concrete,
we show how to derive the updates for the binary sparse coding model in particular.
First, we must write an expression for the derivatives with respect to ĥ_i. To do so, we substitute equation 19.36 into the left side of equation 19.37:

\frac{\partial}{\partial \hat{h}_i} \mathcal{L}(v, \theta, \hat{h})    (19.38)
= \frac{\partial}{\partial \hat{h}_i} \Bigg[ \sum_{j=1}^m \left[ \hat{h}_j \left( \log \sigma(b_j) - \log \hat{h}_j \right) + (1 - \hat{h}_j) \left( \log \sigma(-b_j) - \log(1 - \hat{h}_j) \right) \right]    (19.39)
  + \frac{1}{2} \sum_{j=1}^n \left[ \log \frac{\beta_j}{2\pi} - \beta_j \left( v_j^2 - 2 v_j W_{j,:} \hat{h} + \sum_k \left[ W_{j,k}^2 \hat{h}_k + \sum_{l \neq k} W_{j,k} W_{j,l} \hat{h}_k \hat{h}_l \right] \right) \right] \Bigg]    (19.40)
= \log \sigma(b_i) - \log \hat{h}_i - 1 + \log(1 - \hat{h}_i) + 1 - \log \sigma(-b_i)    (19.41)
  + \sum_{j=1}^n \beta_j \left[ v_j W_{j,i} - \frac{1}{2} W_{j,i}^2 - \sum_{k \neq i} W_{j,k} W_{j,i} \hat{h}_k \right]    (19.42)
= b_i - \log \hat{h}_i + \log(1 - \hat{h}_i) + v^\top \beta W_{:,i} - \frac{1}{2} W_{:,i}^\top \beta W_{:,i} - \sum_{j \neq i} W_{:,j}^\top \beta W_{:,i} \hat{h}_j.    (19.43)
To apply the fixed-point update inference rule, we solve for the ĥ_i that sets equation 19.43 to 0:

\hat{h}_i = \sigma\left( b_i + v^\top \beta W_{:,i} - \frac{1}{2} W_{:,i}^\top \beta W_{:,i} - \sum_{j \neq i} W_{:,j}^\top \beta W_{:,i} \hat{h}_j \right).    (19.44)
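Equation 19.44 translates directly into a coordinate sweep, sketched below. The Gram-style products W_{:,j}ᵀ β W_{:,i} are precomputed for clarity, and the stopping rule is a simple change threshold of our choosing.

import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def mean_field_sweeps(v, W, b, beta, n_sweeps=50, tol=1e-6):
    """Iterate the fixed-point updates of equation 19.44 for binary sparse coding."""
    m = W.shape[1]
    h_hat = np.full(m, 0.5)                          # initial variational parameters
    G = W.T @ (beta[:, None] * W)                    # G[j, i] = W_{:,j}^T beta W_{:,i}
    vbW = (v * beta) @ W                             # v^T beta W_{:,i} for every i
    for _ in range(n_sweeps):
        old = h_hat.copy()
        for i in range(m):
            cross = G[:, i] @ h_hat - G[i, i] * h_hat[i]     # sum over j != i
            h_hat[i] = sigmoid(b[i] + vbW[i] - 0.5 * G[i, i] - cross)
        if np.max(np.abs(h_hat - old)) < tol:
            break
    return h_hat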
At this point, we can see that there is a close connection between recurrent neural networks and inference in graphical models. Specifically, the mean field fixed-point equations define a recurrent neural network. The task of this network is to perform inference. We have described how to derive this network from a model description, but it is also possible to train the inference network directly. Several ideas based on this theme are described in chapter 20.

In the case of binary sparse coding, we can see that the recurrent network connection specified by equation 19.44 consists of repeatedly updating the hidden units based on the changing values of the neighboring hidden units. The input always sends a fixed message of v^⊤βW to the hidden units, but the hidden units constantly update the message they send to each other. Specifically, two units ĥ_i and ĥ_j inhibit each other when their weight vectors are aligned. This is a form of competition: between two hidden units that both explain the input, only the one that explains the input best will be allowed to remain active. This competition is the mean field approximation's attempt to capture the explaining away interactions in the binary sparse coding posterior. The explaining away effect actually should cause a multimodal posterior, so that if we draw samples from the posterior, some samples will have one unit active, other samples will have the other unit active, but very few samples will have both active. Unfortunately, explaining away interactions cannot be modeled by the factorial q used for mean field, so the mean field approximation is forced to choose one mode to model. This is an instance of the behavior illustrated in figure 3.6.
We can rewrite equation 19.44 into an equivalent form that reveals some further insights:

\hat{h}_i = \sigma\left( b_i + \left( v - \sum_{j \neq i} W_{:,j} \hat{h}_j \right)^\top \beta W_{:,i} - \frac{1}{2} W_{:,i}^\top \beta W_{:,i} \right).    (19.45)
In this reformulation, we see the input at each step as consisting of v − Σ_{j≠i} W_{:,j} ĥ_j rather than v. We can thus think of unit i as attempting to encode the residual error in v given the code of the other units. Sparse coding can thus be viewed as an iterative autoencoder, which repeatedly encodes and decodes its input, attempting to fix mistakes in the reconstruction after each iteration.
In this example, we have derived an update rule that updates a single unit at a time. It would be advantageous to be able to update more units simultaneously. Some graphical models, such as deep Boltzmann machines, are structured in such a way that we can solve for many entries of ĥ simultaneously. Unfortunately, binary sparse coding does not admit such block updates. Instead, we can use a heuristic technique called damping to perform block updates. In the damping approach, we solve for the individually optimal values of every element of ĥ, then move all the values in a small step in that direction. This approach is no longer guaranteed to increase L at each step, but it works well in practice for many models. See Koller and Friedman (2009) for more information about choosing the degree of synchrony and damping strategies in message-passing algorithms.
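In code, damping amounts to computing the individually optimal values for all units at once and then taking only a partial step toward them; the step size alpha below is a heuristic knob, not something prescribed by the text, and vbW and G are the precomputed products from the previous sketch.

import numpy as np

def damped_update(h_hat, b, vbW, G, alpha=0.2):
    """One damped block update of all units simultaneously (G and vbW as in the earlier sketch)."""
    diag = np.diag(G)
    cross = G @ h_hat - diag * h_hat                                   # sum over j != i, all units at once
    target = 1.0 / (1.0 + np.exp(-(b + vbW - 0.5 * diag - cross)))     # individually optimal values
    return (1 - alpha) * h_hat + alpha * target                        # partial step: the damping heuristic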
19.4.2 Calculus of Variations
Before continuing with our presentation of variational learning, we must briefly
introduce an important set of mathematical tools used in variational learning:
calculus of variations.
Many machine learning techniques are based on minimizing a function J(θ) by finding the input vector θ ∈ R^n for which it takes on its minimal value. This can be accomplished with multivariate calculus and linear algebra, by solving for the critical points where ∇_θ J(θ) = 0. In some cases, we actually want to solve for a function f(x), such as when we want to find the probability density function over some random variable. This is what calculus of variations enables us to do.
A function of a function f is known as a functional J[f]. Much as we can take partial derivatives of a function with respect to elements of its vector-valued argument, we can take functional derivatives, also known as variational derivatives, of a functional J[f] with respect to individual values of the function f(x) at any specific value of x. The functional derivative of the functional J with respect to the value of the function f at point x is denoted δJ/δf(x).
A complete formal development of functional derivatives is beyond the scope of this book. For our purposes, it is sufficient to state that for differentiable functions f(x) and differentiable functions g(y, x) with continuous derivatives,

\frac{\delta}{\delta f(x)} \int g(f(x), x) \, dx = \frac{\partial}{\partial y} g(f(x), x).    (19.46)
To gain some intuition for this identity, one can think of f(x) as being a vector with uncountably many elements, indexed by a real vector x. In this (somewhat incomplete) view, the identity providing the functional derivatives is the same as what we would obtain for a vector θ ∈ R^n indexed by positive integers:

\frac{\partial}{\partial \theta_i} \sum_j g(\theta_j, j) = \frac{\partial}{\partial \theta_i} g(\theta_i, i).    (19.47)
Many results in other machine learning publications are presented using the more general Euler-Lagrange equation, which allows g to depend on the derivatives of f as well as the value of f, but we do not need this fully general form for the results presented in this book.
To optimize a function with respect to a vector, we take the gradient of the
function with respect to the vector and solve for the point where every element of
the gradient is equal to zero. Likewise, we can optimize a functional by solving for
the function where the functional derivative at every point is equal to zero.
As an example of how this process works, consider the problem of finding the probability distribution function over x ∈ R that has maximal differential entropy. Recall that the entropy of a probability distribution p(x) is defined as

H[p] = -\mathbb{E}_x \log p(x).    (19.48)

For continuous values, the expectation is an integral:

H[p] = -\int p(x) \log p(x) \, dx.    (19.49)
We cannot simply maximize H[p] with respect to the function p(x), because the result might not be a probability distribution. Instead, we need to use Lagrange multipliers to add a constraint that p(x) integrate to 1. Also, the entropy increases without bound as the variance increases. This makes the question of which distribution has the greatest entropy uninteresting. Instead, we ask which distribution has maximal entropy for fixed variance σ². Finally, the problem is underdetermined because the distribution can be shifted arbitrarily without changing the entropy. To impose a unique solution, we add a constraint that the mean of the distribution be µ. The Lagrangian functional for this optimization problem is
\mathcal{L}[p] = \lambda_1 \left( \int p(x) \, dx - 1 \right) + \lambda_2 \left( \mathbb{E}[x] - \mu \right) + \lambda_3 \left( \mathbb{E}[(x - \mu)^2] - \sigma^2 \right) + H[p]    (19.50)
= \int \left( \lambda_1 p(x) + \lambda_2 p(x) x + \lambda_3 p(x) (x - \mu)^2 - p(x) \log p(x) \right) dx - \lambda_1 - \mu \lambda_2 - \sigma^2 \lambda_3.    (19.51)
To minimize the Lagrangian with respect to p, we set the functional derivatives equal to 0:

\forall x, \quad \frac{\delta}{\delta p(x)} \mathcal{L} = \lambda_1 + \lambda_2 x + \lambda_3 (x - \mu)^2 - 1 - \log p(x) = 0.    (19.52)
This condition now tells us the functional form of p(x). By algebraically rearranging the equation, we obtain

p(x) = \exp\left( \lambda_1 + \lambda_2 x + \lambda_3 (x - \mu)^2 - 1 \right).    (19.53)
We never assumed directly that p(x) would take this functional form; we obtained the expression itself by analytically minimizing a functional. To finish the minimization problem, we must choose the λ values to ensure that all our constraints are satisfied. We are free to choose any λ values, because the gradient of the Lagrangian with respect to the λ variables is zero as long as the constraints are satisfied. To satisfy all the constraints, we may set λ₁ = 1 − log σ√(2π), λ₂ = 0, and λ₃ = −1/(2σ²) to obtain

p(x) = \mathcal{N}(x; \mu, \sigma^2).    (19.54)
This is one reason for using the normal distribution when we do not know the
true distribution. Because the normal distribution has the maximum entropy, we
impose the least possible amount of structure by making this assumption.
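The stated λ values can be checked numerically: plugging them into equation 19.53 should recover a properly normalized Gaussian with the requested mean, variance, and entropy. The grid-based check below is only a sanity test, not part of the derivation, and the target values are arbitrary.

import numpy as np

mu, sigma = 0.5, 1.5                                   # arbitrary target mean and standard deviation
lam1 = 1.0 - np.log(sigma * np.sqrt(2 * np.pi))
lam2 = 0.0
lam3 = -1.0 / (2 * sigma ** 2)

x = np.linspace(mu - 12 * sigma, mu + 12 * sigma, 200001)
p = np.exp(lam1 + lam2 * x + lam3 * (x - mu) ** 2 - 1.0)   # equation 19.53
dx = x[1] - x[0]

print(np.sum(p) * dx)                                  # ~1: integrates to one
print(np.sum(x * p) * dx)                              # ~mu
print(np.sum((x - mu) ** 2 * p) * dx)                  # ~sigma^2
print(np.sum(-p * np.log(p)) * dx,                     # entropy of p ...
      0.5 * np.log(2 * np.pi * np.e * sigma ** 2))     # ... matches the Gaussian entropy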
While examining the critical points of the Lagrangian functional for the entropy, we found only one critical point, corresponding to maximizing the entropy for fixed variance. What about the probability distribution function that minimizes the entropy? Why did we not find a second critical point corresponding to the minimum? The reason is that no specific function achieves minimal entropy. As functions place more probability density on the two points x = µ + σ and x = µ − σ, and place less probability density on all other values of x, they lose entropy while maintaining the desired variance. However, any function placing exactly zero mass on all but two points does not integrate to one and is not a valid probability distribution. Thus there is no single minimal entropy probability distribution function, much as there is no single minimal positive real number. Instead, we can say that there is a sequence of probability distributions converging toward putting mass only on these two points. This degenerate scenario may be described as a mixture of Dirac distributions. Because Dirac distributions are not described by a single probability distribution function, no Dirac distribution or mixture of Dirac distributions corresponds to a single specific point in function space. These distributions are thus invisible to our method of solving for a specific point where the functional derivatives are zero. This is a limitation of the method. Distributions such as the Dirac must be found by other methods, such as guessing the solution and then proving that it is correct.
19.4.3 Continuous Latent Variables
When our graphical model contains continuous latent variables, we can still perform variational inference and learning by maximizing L. However, we must now use calculus of variations when maximizing L with respect to q(h | v).
In most cases, practitioners need not solve any calculus of variations problems themselves. Instead, there is a general equation for the mean field fixed-point updates. If we make the mean field approximation

q(h \mid v) = \prod_i q(h_i \mid v),    (19.55)
and fix q(h_j | v) for all j ≠ i, then the optimal q(h_i | v) may be obtained by normalizing the unnormalized distribution

\tilde{q}(h_i \mid v) = \exp\left( \mathbb{E}_{h_{-i} \sim q(h_{-i} \mid v)} \log \tilde{p}(v, h) \right),    (19.56)

as long as p does not assign 0 probability to any joint configuration of variables. Carrying out the expectation inside the equation will yield the correct functional form of q(h_i | v). Deriving functional forms of q directly using calculus of variations is only necessary if one wishes to develop a new form of variational learning; equation 19.56 yields the mean field approximation for any probabilistic model.
Equation 19.56 is a fixed-point equation, designed to be iteratively applied for each value of i repeatedly until convergence. However, it also tells us more than that. It tells us the functional form that the optimal solution will take, whether we arrive there by fixed-point equations or not. This means we can take the functional form from that equation but regard some of the values that appear in it as parameters, which we can optimize with any optimization algorithm we like.
As an example, consider a simple probabilistic model, with latent variables h ∈ R² and just one visible variable, v. Suppose that p(h) = N(h; 0, I) and p(v | h) = N(v; w^⊤h, 1). We could actually simplify this model by integrating out h; the result is just a Gaussian distribution over v. The model itself is not interesting; we have constructed it only to provide a simple demonstration of how calculus of variations can be applied to probabilistic modeling.
The true posterior is given, up to a normalizing constant, by

p(h \mid v)    (19.57)
\propto p(h, v)    (19.58)
= p(h_1) p(h_2) p(v \mid h)    (19.59)
\propto \exp\left( -\frac{1}{2} \left[ h_1^2 + h_2^2 + (v - h_1 w_1 - h_2 w_2)^2 \right] \right)    (19.60)
= \exp\left( -\frac{1}{2} \left[ h_1^2 + h_2^2 + v^2 + h_1^2 w_1^2 + h_2^2 w_2^2 - 2 v h_1 w_1 - 2 v h_2 w_2 + 2 h_1 w_1 h_2 w_2 \right] \right).    (19.61)

Because of the presence of the terms multiplying h₁ and h₂ together, we can see that the true posterior does not factorize over h₁ and h₂.
Applying equation 19.56, we find that

\tilde{q}(h_1 \mid v)    (19.62)
= \exp\left( \mathbb{E}_{h_2 \sim q(h_2 \mid v)} \log \tilde{p}(v, h) \right)    (19.63)
= \exp\Big( -\frac{1}{2} \mathbb{E}_{h_2 \sim q(h_2 \mid v)} \big[ h_1^2 + h_2^2 + v^2 + h_1^2 w_1^2 + h_2^2 w_2^2    (19.64)
  - 2 v h_1 w_1 - 2 v h_2 w_2 + 2 h_1 w_1 h_2 w_2 \big] \Big).    (19.65)
From this, we can see that there are effectively only two values we need to obtain from q(h₂ | v): E_{h₂∼q(h|v)}[h₂] and E_{h₂∼q(h|v)}[h₂²]. Writing these as ⟨h₂⟩ and ⟨h₂²⟩, we obtain

\tilde{q}(h_1 \mid v) = \exp\Big( -\frac{1}{2} \big[ h_1^2 + \langle h_2^2 \rangle + v^2 + h_1^2 w_1^2 + \langle h_2^2 \rangle w_2^2    (19.66)
  - 2 v h_1 w_1 - 2 v \langle h_2 \rangle w_2 + 2 h_1 w_1 \langle h_2 \rangle w_2 \big] \Big).    (19.67)
From this, we can see that q̃ has the functional form of a Gaussian. We can thus conclude that q(h | v) = N(h; µ, β⁻¹), where µ and diagonal β are variational parameters that we can optimize using any technique we choose. It is important to recall that we did not ever assume that q would be Gaussian; its Gaussian form was derived automatically by using calculus of variations to maximize L with respect to q. Using the same approach on a different model could yield a different functional form of q.

This was, of course, just a small case constructed for demonstration purposes. For examples of real applications of variational learning with continuous variables in the context of deep learning, see Goodfellow et al. (2013d).
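The worked example above can also be run numerically. Writing the true posterior as a Gaussian with precision Λ = I + wwᵀ, the coordinate updates below iterate the required moments; for a Gaussian target the fixed point recovers the exact posterior mean, while the factorized variances 1/Λ_ii understate the true marginal variances. The numbers are arbitrary choices for illustration.

import numpy as np

w = np.array([1.0, 0.5])
v = 2.0

Lam = np.eye(2) + np.outer(w, w)           # posterior precision of p(h | v)
eta = w * v                                 # linear term: log p(h | v) = -h^T Lam h / 2 + h^T eta + const
exact_mean = np.linalg.solve(Lam, eta)
exact_cov = np.linalg.inv(Lam)

mu = np.zeros(2)                            # mean field means <h_1>, <h_2>
for _ in range(100):                        # coordinate fixed-point updates from equation 19.56
    for i in range(2):
        j = 1 - i
        mu[i] = (eta[i] - Lam[i, j] * mu[j]) / Lam[i, i]

mf_var = 1.0 / np.diag(Lam)                 # each q(h_i | v) is Gaussian with precision Lam_ii

print(mu, exact_mean)                       # the means agree at the fixed point
print(mf_var, np.diag(exact_cov))           # mean field variances are smaller than the true marginals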
19.4.4 Interactions between Learning and Inference
Using approximate inference as part of a learning algorithm affects the learning process, and this in turn affects the accuracy of the inference algorithm. Specifically, the training algorithm tends to adapt the model in a way that makes the approximating assumptions underlying the approximate inference algorithm become more true. When training the parameters, variational learning increases

\mathbb{E}_{h \sim q} \log p(v, h).    (19.68)
For a specific v, this increases p(h | v) for values of h that have high probability under q(h | v) and decreases p(h | v) for values of h that have low probability under q(h | v).
This behavior causes our approximating assumptions to become self-fulfilling
prophecies. If we train the model with a unimodal approximate posterior, we will
obtain a model with a true posterior that is far closer to unimodal than we would
have obtained by training the model with exact inference.
Computing the true amount of harm imposed on a model by a variational approximation is thus very difficult. There exist several methods for estimating log p(v). We often estimate log p(v; θ) after training the model and find that the gap with L(v, θ, q) is small. From this, we can conclude that our variational approximation is accurate for the specific value of θ that we obtained from the learning process. We should not conclude that our variational approximation is accurate in general or that the variational approximation did little harm to the learning process. To measure the true amount of harm induced by the variational approximation, we would need to know θ* = arg max_θ log p(v; θ). It is possible for L(v, θ, q) ≈ log p(v; θ) and log p(v; θ) ≪ log p(v; θ*) to hold simultaneously. If max_q L(v, θ*, q) ≪ log p(v; θ*), because θ* induces too complicated a posterior distribution for our q family to capture, then the learning process will never approach θ*. Such a problem is very difficult to detect, because we can only know for sure that it happened if we have a superior learning algorithm that can find θ* for comparison.
19.5 Learned Approximate Inference
We have seen that inference can be thought of as an optimization procedure that increases the value of a function L. Explicitly performing optimization via iterative procedures such as fixed-point equations or gradient-based optimization is often very expensive and time consuming. Many approaches to inference avoid this expense by learning to perform approximate inference. Specifically, we can think of the optimization process as a function f that maps an input v to an approximate distribution q* = arg max_q L(v, q). Once we think of the multistep iterative optimization process as just being a function, we can approximate it with a neural network that implements an approximation f̂(v; θ).
19.5.1 Wake-Sleep
One of the main difficulties with training a model to infer h from v is that we do not have a supervised training set with which to train the model. Given a v, we do not know the appropriate h. The mapping from v to h depends on the choice of model family, and evolves throughout the learning process as θ changes. The wake-sleep algorithm (Hinton et al., 1995b; Frey et al., 1996) resolves this problem by drawing samples of both h and v from the model distribution. For example, in a directed model, this can be done cheaply by performing ancestral sampling beginning at h and ending at v. The inference network can then be trained to perform the reverse mapping: predicting which h caused the present v. The main drawback to this approach is that we will only be able to train the inference network on values of v that have high probability under the model. Early in learning, the model distribution will not resemble the data distribution, so the inference network will not have an opportunity to learn on samples that resemble data.
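A stripped-down sketch of this sampling-based training is shown below for a toy directed model with binary latent variables: the generative parameters are held fixed and only the recognition network is trained on ancestral samples. In the full wake-sleep algorithm a complementary wake phase would also update the generative parameters using the recognition network's inferences; that half is omitted here, and the model and dimensions are invented for illustration.

import numpy as np

rng = np.random.default_rng(0)
m, n = 5, 8
sigmoid = lambda x: 1.0 / (1.0 + np.exp(-x))

# Generative model p(h, v): h_i ~ Bernoulli(sigma(b_i)), v ~ N(Wh, I)   (a toy choice)
W = rng.normal(size=(n, m))
b = rng.normal(size=m)

# Recognition model q(h | v) = Bernoulli(sigma(R v + c)), trained to invert the generator
R = np.zeros((m, n))
c = np.zeros(m)

lr = 0.05
for step in range(5000):
    # "Sleep": ancestral sample (h, v) from the model, starting at h and ending at v
    h = (rng.random(m) < sigmoid(b)).astype(float)
    v = W @ h + rng.normal(size=n)
    # Train the recognition network to predict the h that produced this v
    q = sigmoid(R @ v + c)
    grad = h - q                      # gradient of the Bernoulli log-likelihood wrt the logits
    R += lr * np.outer(grad, v)
    c += lr * grad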
In section 18.2 we saw that one possible explanation for the role of dream sleep in human beings and animals is that dreams could provide the negative phase samples that Monte Carlo training algorithms use to approximate the negative gradient of the log partition function of undirected models. Another possible explanation for biological dreaming is that it is providing samples from p(h, v) which can be used to train an inference network to predict h given v. In some senses, this explanation is more satisfying than the partition function explanation. Monte Carlo algorithms generally do not perform well if they are run using only the positive phase of the gradient for several steps then with only the negative phase of the gradient for several steps. Human beings and animals are usually awake for several consecutive hours then asleep for several consecutive hours. It is not readily apparent how this schedule could support Monte Carlo training of an undirected model. Learning algorithms based on maximizing L can be run with prolonged periods of improving q and prolonged periods of improving θ, however. If the role of biological dreaming is to train networks for predicting q, then this explains how animals are able to remain awake for several hours (the longer they are awake, the greater the gap between L and log p(v), but L will remain a lower bound) and to remain asleep for several hours (the generative model itself is not modified during sleep) without damaging their internal models. Of course, these ideas are purely speculative, and there is no hard evidence to suggest that dreaming accomplishes either of these goals. Dreaming may also serve reinforcement learning rather than probabilistic modeling, by sampling synthetic experiences from the animal's transition model, on which to train the animal's policy. Or sleep may serve some other purpose not yet anticipated by the machine learning community.
19.5.2 Other Forms of Learned Inference
This strategy of learned approximate inference has also been applied to other
models. Salakhutdinov and Larochelle (2010) showed that a single pass in a
learned inference network could yield faster inference than iterating the mean field
fixed-point equations in a DBM. The training procedure is based on running the
inference network, then applying one step of mean field to improve its estimates,
and training the inference network to output this refined estimate instead of its
original estimate.
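A hedged sketch of that training idea, using a mean field sweep as the refinement operator: mean_field_sweep below is an assumed helper standing for one application of a fixed-point update such as equation 19.44 starting from the network's guess (for example, a single-sweep variant of the earlier binary sparse coding code), and the inference network is just a linear-sigmoid layer of our own choosing, not the DBM architecture used in the cited work.

import numpy as np

sigmoid = lambda x: 1.0 / (1.0 + np.exp(-x))

def train_inference_net_step(v, R, c, mean_field_sweep, lr=0.1):
    """One update: refine the net's guess with a single mean field sweep, then regress onto it."""
    h0 = sigmoid(R @ v + c)                 # fast feedforward estimate
    h1 = mean_field_sweep(v, h0.copy())     # one step of mean field refinement (assumed helper)
    grad = h0 - h1                          # cross-entropy gradient with respect to the logits R v + c
    R -= lr * np.outer(grad, v)
    c -= lr * grad
    return R, c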
We have already seen in section 14.8 that the predictive sparse decomposition
model trains a shallow encoder network to predict a sparse code for the input.
This can be seen as a hybrid between an autoencoder and sparse coding. It is
possible to devise probabilistic semantics for the model, under which the encoder
may be viewed as performing learned approximate MAP inference. Due to its
shallow encoder, PSD is not able to implement the kind of competition between
units that we have seen in mean field inference. However, that problem can be
remedied by training a deep encoder to perform learned approximate inference, as
in the ISTA technique (Gregor and LeCun, 2010b).
Learned approximate inference has recently become one of the dominant approaches to generative modeling, in the form of the variational autoencoder (Kingma, 2013; Rezende et al., 2014). In this elegant approach, there is no need to construct explicit targets for the inference network. Instead, the inference network is simply used to define L, and then the parameters of the inference network are adapted to increase L. This model is described in depth in section 20.10.3.
Using approximate inference, it is possible to train and use a wide variety of
models. Many of these models are described in the next chapter.