Variable elimination

  • What it sounds like -- eliminating variables from a factor graph.
    • Why? Suppose we want to calculate $p(z | x) = p(x | z) p(z) / p(x)$. We need to find $p(x)$ in order to do this, i.e., we need to marginalize out $z$, i.e., calculate $$ p(x) = \sum_z p(x|z) p(z) $$
  • Okay, what's the problem? Two-fold...
    1. This sum might actually be an integral that we have to compute via sampling (e.g., what if part of $z$ is continuous-valued?)
    2. Even if each component of $z$ has only $n$ possible values, if the length of $z$ is $m$, then naive summation has complexity $O(n^m)$, which is REALLY BAD (e.g., 100 binary rvs produce $2^{100} \approx 1.27 \times 10^{30}$ terms in the sum if computed naively...)
  • Worst case, there is nothing to be done about this. In practice, we observe that the degree of each factor node is often quite small, i.e., each factor touches only a few variables, so the $m$ that matters for any single sum is not too large.
  • Exploit structure: sum out variables one at a time in order to dynamically create new factors, which may in turn be summed out later on.
    • Not a panacea: finding the optimal order in which to sum out variables is an NP-hard problem...
  • Example model $$ p(a, b, c, d, e) = p(a| b, c)\ p(b|c, d)\ p(c|e)\ p(d | e)\ p(e) $$ Want to find $p(a)$ (the marginal distribution). Summing out in topological order (here this is a good ordering, not always true!):
    • $p(a, b, c, d) = p(a| b, c)\ p(b|c, d)\ \sum_e p(c|e)\ p(d | e)\ p(e) = p(a| b, c)\ p(b|c, d)\ f(c, d)$
    • $p(a, b, c) = p(a| b, c)\ \sum_d p(b|c, d)\ f(c, d) = p(a | b, c)f(b, c)$
    • $p(a, b) = \sum_c p(a|b, c) f(b, c) = f(a, b)$
    • $p(a) = \sum_b f(a, b)$ (whew)
    • Equivalent to just writing $$ p(a) = \sum_b \sum_c p(a| b, c)\ \sum_d p(b|c, d)\ \sum_e p(c|e)\ p(d | e)\ p(e) $$
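The elimination steps above can be checked numerically. The following is a minimal sketch, assuming all five variables are binary and using randomly generated conditional probability tables; the helpers `random_cpt` and `prob` and the table names are hypothetical and exist only for this illustration. It computes $p(a)$ once by brute-force summation over $(b, c, d, e)$ and once by eliminating $e$, $d$, $c$, $b$ in turn.

```python
import itertools
import random

random.seed(0)
VALS = (0, 1)  # assumption: every variable is binary in this sketch

def random_cpt(n_parents):
    """Random table of P(child = 1 | parents), keyed by the tuple of parent values."""
    return {pa: random.random() for pa in itertools.product(VALS, repeat=n_parents)}

def prob(cpt, child, *parents):
    """Look up P(child | parents) from a table of P(child = 1 | parents)."""
    p1 = cpt[parents]
    return p1 if child == 1 else 1.0 - p1

# Hypothetical tables for p(e), p(d|e), p(c|e), p(b|c,d), p(a|b,c)
pe, pd, pc = random_cpt(0), random_cpt(1), random_cpt(1)
pb, pa = random_cpt(2), random_cpt(2)

def joint(a, b, c, d, e):
    return (prob(pa, a, b, c) * prob(pb, b, c, d) * prob(pc, c, e)
            * prob(pd, d, e) * prob(pe, e))

# Brute force: for each a, sum over all 2^4 settings of (b, c, d, e)
brute = {a: sum(joint(a, b, c, d, e)
                for b, c, d, e in itertools.product(VALS, repeat=4))
         for a in VALS}

# Variable elimination: sum out e, then d, then c, then b
f_cd = {(c, d): sum(prob(pc, c, e) * prob(pd, d, e) * prob(pe, e) for e in VALS)
        for c in VALS for d in VALS}
f_bc = {(b, c): sum(prob(pb, b, c, d) * f_cd[(c, d)] for d in VALS)
        for b in VALS for c in VALS}
f_ab = {(a, b): sum(prob(pa, a, b, c) * f_bc[(b, c)] for c in VALS)
        for a in VALS for b in VALS}
ve = {a: sum(f_ab[(a, b)] for b in VALS) for a in VALS}
```

Both `brute` and `ve` hold the same marginal; the difference is that each intermediate factor in the VE version is a table over at most two variables, so no single sum ever ranges over the full joint.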
  • VE can be used in combination with other inference algorithms (that we haven't talked about yet, but will soon)
  • Consider the following program:
    function f()
      b ~ Beta(2.0, 2.0)
      z ~ Bernoulli(b)
      if z
          x ~ Normal(1.0, 1.0)
      else
          x ~ Normal(0.0, 1.0)
      end
    end
  • The joint density is given by $p(x, z, b) = p(x | z)\ p(z|b)\ p(b)$. Instead of sampling from the entire joint density $(b, z, x) \sim p(x, z, b)$ via trace(f).get_trace(), we can use VE to sample from a reduced space $(b, x) \sim p(x, b) = \sum_z p(x|z)p(z|b) p(b)$.
  • Belief propagation reuses ideas from variable elimination to enable efficient inference on large factor graphs of discrete latent variables -- we won't cover this in class but maybe in HW
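For the Beta-Bernoulli-Normal program above, summing out $z$ can be written down by hand: the result is a two-component normal mixture in $x$ multiplied by the prior on $b$. Here is a small sketch that checks this against the explicit sum over $z \in \{0, 1\}$; the function names (`normal_pdf`, `beta22_pdf`, `joint`, `reduced`) are illustrative and not part of any library.

```python
import math

def normal_pdf(x, loc, scale=1.0):
    """Density of Normal(loc, scale) at x."""
    return math.exp(-0.5 * ((x - loc) / scale) ** 2) / (scale * math.sqrt(2.0 * math.pi))

def beta22_pdf(b):
    """Density of Beta(2, 2) at b: b (1 - b) / B(2, 2), with B(2, 2) = 1/6."""
    return 6.0 * b * (1.0 - b)

def joint(x, z, b):
    """p(x, z, b) = p(x | z) p(z | b) p(b) for the program above."""
    p_z = b if z == 1 else 1.0 - b
    return normal_pdf(x, float(z)) * p_z * beta22_pdf(b)

def reduced(x, b):
    """p(x, b) with z eliminated: a normal mixture in x times p(b)."""
    mixture = b * normal_pdf(x, 1.0) + (1.0 - b) * normal_pdf(x, 0.0)
    return mixture * beta22_pdf(b)

# The reduced density agrees with the explicit sum over z
x_val, b_val = 0.3, 0.7
explicit = joint(x_val, 0, b_val) + joint(x_val, 1, b_val)
```

Any inference algorithm that only needs to score $(b, x)$ can call `reduced` directly and never has to propose values for the discrete site $z$.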

Sampling

  • Sampling methods allow us to perform inference on models with dynamic structure like the ones we've been exploring the last few classes.
  • We will cover two sampling methods that are fully general, meaning that they can be applied to probabilistic programs with any structure. Other methods exist that are far more efficient but also more limited in scope. (These are good topics for final projects.) However, we will focus on...
    • Importance sampling: intelligently-weighted sampling from a joint distribution
    • Markov-chain Monte Carlo: incremental modification of a trace

Importance sampling

  • Denote the return value of a probabilistic program by $r(z)$ where $z$ are the latent random variables. The posterior expectation of $r$ is given by $E_{z \sim p(z | x)}[r(z)]$.
  • Notice that we can rewrite this integral: $$ E_{z \sim p(z | x)}[r(z)] = \int\ dz\ r(z) p(z|x) = \int\ dz\ r(z) q(z) \frac{p(z|x)}{q(z)} = E_{z \sim q(z)}\left[r(z) \frac{p(z|x)}{q(z)}\right], $$ where $q(z)$ is another pdf defined over the same domain as $p(z|x)$ (there is an additional technical restriction on $q$ that we won't cover here).
  • We don't know what $p(z|x)$ is (that's what we're trying to find!) but we do know something proportional to it: $p(x, z)$. Substituting, we have $$ E_{z \sim q(z)}\left[r(z) \frac{p(z|x)}{q(z)}\right] = E_{z \sim q(z)}\left[ r(z) \frac{p(x, z)}{p(x)q(z)} \right] = \frac{1}{p(x)}E_{z \sim q(z)}\left[ r(z) \frac{p(x, z)}{q(z)} \right] \approx \frac{1}{p(x)}\frac{1}{N}\sum_n W_n r(z_n), $$ where we define $\log W_n = \log p(x, z_n) - \log q(z_n)$ with $z_n \sim q(z)$ to be the unnormalized importance weights. Note that the observed data $x$ is held fixed; only the $z_n$ vary from sample to sample.
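To make the estimator concrete, here is a minimal sketch on a toy model chosen because everything is known in closed form: $z \sim \text{Normal}(0, 1)$, $x | z \sim \text{Normal}(z, 1)$, with a single observation $x = 1$. In that model $p(x)$ is $\text{Normal}(0, \sqrt{2})$ and the posterior mean is $0.5$. The model, the observation, and the proposal $q(z) = \text{Normal}(0, 2)$ are all assumptions made for illustration.

```python
import math
import random

random.seed(0)
x_obs = 1.0  # hypothetical observation

def log_normal(v, loc, scale):
    """Log density of Normal(loc, scale) at v."""
    return -0.5 * ((v - loc) / scale) ** 2 - math.log(scale) - 0.5 * math.log(2.0 * math.pi)

def log_joint(z):
    """log p(x_obs, z) = log p(z) + log p(x_obs | z)."""
    return log_normal(z, 0.0, 1.0) + log_normal(x_obs, z, 1.0)

# In this conjugate toy model the evidence is known exactly
log_evidence = log_normal(x_obs, 0.0, math.sqrt(2.0))

N = 20000
q_loc, q_scale = 0.0, 2.0  # proposal q(z), deliberately wider than the posterior
zs = [random.gauss(q_loc, q_scale) for _ in range(N)]

# Unnormalized log weights: log W_n = log p(x, z_n) - log q(z_n)
log_W = [log_joint(z) - log_normal(z, q_loc, q_scale) for z in zs]

# r(z) = z, so this estimates the posterior mean E[z | x]
weighted_sum = sum(math.exp(lw) * z for lw, z in zip(log_W, zs))
estimate = weighted_sum / N / math.exp(log_evidence)
```

Here `estimate` should land close to the exact posterior mean of $0.5$. In a realistic model $p(x)$ is not available in closed form and has to be estimated from the same weights.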

  • We can use this result to do multiple things:
    1. Compute an estimate for the evidence $p(x)$: set $r(z) = 1$, so that $E_z [r(z)] = 1$ for any distribution over $z$. Then: $$ 1 = \frac{1}{p(x)}E_{z \sim q(z)}\left[ 1 \cdot \frac{p(x, z)}{q(z)} \right] \implies p(x) = E_{z \sim q(z)}\left[ 1 \cdot \frac{p(x, z)}{q(z)} \right] \approx \frac{1}{N}\sum_n W_n. $$ (Remember, we would usually instead want to calculate $\log p(x)$ using the logsumexp function that is more numerically stable.)
    2. Get a nonparametric posterior density estimate simply by normalizing the un-normalized weights: $w_n = W_n / Z(W)$ where $Z(W) = \sum_n W_n$. A draw from the posterior can be approximated by a draw from the set of $z_n$s, where each $z_n$ is drawn with probability $w_n$. This is equivalent to saying that $p(z|x) \approx \text{Categorical}(w_1,..., w_N)$ over the sampled points $z_1, ..., z_N$. (Exercise: how can we generate a continuous relaxation of this distribution?)
  • This is not the most efficient way of sampling from the posterior (HW question) but it always works, which is very valuable. At worst we can always implement importance sampling and try to increase its efficiency in a particular use case somehow (HW question). It's also pretty easy to implement from scratch.
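As a from-scratch illustration of both uses, here is a short sketch in plain Python. It assumes a toy model ($z \sim \text{Normal}(0, 1)$, $x | z \sim \text{Normal}(z, 1)$, one observation $x = 1$) and uses the prior as the proposal; neither choice comes from the notes, and the `logsumexp` helper is written out by hand rather than taken from a library.

```python
import math
import random

random.seed(0)
x_obs = 1.0  # hypothetical observation

def log_normal(v, loc, scale):
    """Log density of Normal(loc, scale) at v."""
    return -0.5 * ((v - loc) / scale) ** 2 - math.log(scale) - 0.5 * math.log(2.0 * math.pi)

def logsumexp(values):
    """Numerically stable log(sum(exp(values)))."""
    m = max(values)
    return m + math.log(sum(math.exp(v - m) for v in values))

N = 20000
zs = [random.gauss(0.0, 1.0) for _ in range(N)]  # proposal q(z) = prior p(z)

# With q equal to the prior, log W_n = log p(x, z_n) - log q(z_n) = log p(x | z_n)
log_W = [log_normal(x_obs, z, 1.0) for z in zs]

# (1) Evidence estimate in log space: log p(x) ~ logsumexp(log W) - log N
log_evidence_hat = logsumexp(log_W) - math.log(N)

# (2) Normalized weights w_n = W_n / Z(W)
log_Z = logsumexp(log_W)
w = [math.exp(lw - log_Z) for lw in log_W]

# Posterior mean from the weighted samples, and approximate posterior draws
posterior_mean = sum(wn * z for wn, z in zip(w, zs))
posterior_draws = random.choices(zs, weights=w, k=5000)
```

For this model the exact values are $\log p(x) = \log \text{Normal}(1 | 0, \sqrt{2})$ and $E[z | x] = 0.5$, so both estimates can be compared against a known answer.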
In [2]:
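import matplotlib.pyplot as plt
import pyro
import pyro.distributions as dist
import torch

# seed_me and clean_trace_view are notebook helpers, presumably defined in an
# earlier cell that is not shown here.

def model(observations=dict(x=None), shape=(1,)):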
    b = pyro.sample("b", dist.Beta(2.0, 2.0))
    z = pyro.sample("z", dist.Bernoulli(b))
    loc = z.type(torch.float)
    x = pyro.sample("x", dist.Normal(loc, 1.0).expand(shape), obs=observations["x"])
    return x

seed_me()
trace = pyro.poutine.trace(model).get_trace()
trace.compute_log_prob()
clean_trace_view(trace.nodes)
Out[2]:
OrderedDict([('input', {}),
             ('b',
              {'fn': Beta(), 'value': tensor(0.7176), 'lpdf': tensor(0.1955)}),
             ('z',
              {'fn': Bernoulli(probs: 0.7175814509391785, logits: 0.9324962496757507),
               'value': tensor(1.),
               'lpdf': tensor(-0.3319)}),
             ('x',
              {'fn': Normal(loc: tensor([1.]), scale: tensor([1.])),
               'value': tensor([1.1450]),
               'lpdf': tensor(-0.9294)}),
             ('output', {'value': tensor([1.1450])})])
In [3]:
seed_me()
data = torch.distributions.Normal(0.0, 1.0).sample(sample_shape=(50,))
posterior = pyro.infer.Importance(model, num_samples=1000).run(
    observations=dict(x=data),
    shape=data.shape
)
# get posterior samples for cluster probability
b_marginal = pyro.infer.EmpiricalMarginal(posterior, "b")
b_marginal_samples = torch.stack([
    b_marginal() for _ in range(1000)
])
In [5]:
plt.hist(dist.Beta(2.0, 2.0).sample((1000,)).numpy(), histtype="step", bins='sqrt',
        label='prior', density=True,)
plt.hist(b_marginal_samples.numpy(), histtype="step", bins='sqrt',
        label='posterior', density=True,)
plt.xlabel("$b$"); plt.ylabel("$p(b | x)$")
plt.legend();

MCMC

  • Markov-chain Monte Carlo: the idea is to incrementally modify a trace in order to obtain a better approximation to the posterior density
  • MCMC sampling uses a proposal distribution (like $q(z)$ above) that proposes a new value $z'$ for the latent variables given the "current" value $z$. The proposal is usually local in the sense that $z'$ is usually "close" to $z$.
  • Though we will discuss MCMC only in the context of sampling from probability distributions, the method is useful for many other things (numerical integration and optimization are two other common use cases).

Metropolis-Hastings

  • Probably the most common MCMC algorithm and very, very useful. Easily summarized:

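    z ~ ...                        # initialize, e.g., with a draw from the prior
    samples = []
    
    for i = 1,...,N:
      z' ~ q(z'|z)                 # propose a new value given the current one
      # x is the observed data; it stays fixed and is only scored, never resampled
      Q = q(z|z'); Q' = q(z'|z); P = p(x, z); P' = p(x, z')
      a = (P' / Q') / (P / Q)      # acceptance ratio
      if min(1, a) > uniform():    # accept with probability min(1, a)
          z = z'
      samples.append(z)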
  • There are two principal difficulties:
    1. How should we define a proposal kernel $q(z'|z)$ that can be evaluated at both $z'$ and $z$? Involved in this is how to best define "closeness" between the points $z'$ and $z$.
    2. How should we deal with a program that emits traces of different lengths?
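The pseudocode above can be turned into a runnable sketch for a fixed-length trace. The example below assumes a toy model ($z \sim \text{Normal}(0, 1)$, $x | z \sim \text{Normal}(z, 1)$, one observation $x = 1$) whose posterior is $\text{Normal}(0.5, \sqrt{0.5})$, and a normal random-walk kernel with unit step size. The model, the step size, and the burn-in length are illustrative choices, not recommendations. It works in log space to avoid underflow.

```python
import math
import random

random.seed(0)
x_obs = 1.0  # hypothetical observation, held fixed throughout

def log_joint(z):
    """log p(x_obs, z) up to an additive constant."""
    return -0.5 * z ** 2 - 0.5 * (x_obs - z) ** 2

def log_q(z_to, z_from, sigma=1.0):
    """log q(z_to | z_from) for a normal random walk, up to an additive constant."""
    return -0.5 * ((z_to - z_from) / sigma) ** 2

z = 0.0
samples = []
for _ in range(20000):
    z_new = random.gauss(z, 1.0)  # z' ~ q(z' | z)
    # log a = log[(P' / Q') / (P / Q)] with Q' = q(z'|z) and Q = q(z|z')
    log_a = (log_joint(z_new) - log_q(z_new, z)) - (log_joint(z) - log_q(z, z_new))
    # 1 - random() lies in (0, 1], so the log is always defined
    if math.log(1.0 - random.random()) < log_a:
        z = z_new
    samples.append(z)

kept = samples[2000:]  # drop an initial burn-in segment
post_mean = sum(kept) / len(kept)
post_var = sum((s - post_mean) ** 2 for s in kept) / len(kept)
```

Because this kernel is symmetric, the two `log_q` terms cancel exactly; they are kept here only so that the code mirrors the general acceptance ratio. Additive constants dropped from `log_joint` and `log_q` also cancel in `log_a`.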

  • Solving question (1) is pretty simple -- or, at least, there are simple answers that aren't too terrible. There are some common symmetric proposal kernels that work well when the domain is unbounded or when we expect the posterior mass to lie far from any boundary:
    • Symmetric normal: $z' \sim \text{Normal}(z' | z, \sigma^2)$, where $\sigma^2$ is usually calibrated during sampling
    • Symmetric categorical: $z' \sim \text{Categorical}(z' | \{z - 1, z, z + 1\})$, with probability $p = 1/3$ for each of the three values
  • Any solution like this can be easily modified to incorporate boundary conditions (as long as probability density / mass is modified accordingly)
  • These univariate proposals can also be used to sample from multivariate posterior distributions -- at each iteration, we choose a random site and propose only to that one, keeping the other sites in the trace constant (see Algorithm 17 on p177 of this text for more details). In exchange for ease in calculation, we'll need to acquire more samples in order to cover the posterior mass.
  • We recommend checking out Gen and / or Pyro to see how this might be implemented in practice.
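Before turning to those libraries, the single-site idea can be sketched in a few lines of plain Python. The example assumes a hypothetical fixed-structure trace with one continuous site `mu` (target $\text{Normal}(0, 1)$) and one non-negative integer site `k` (target $\text{Poisson}(3)$). It uses the symmetric normal kernel for `mu` and the symmetric three-point kernel for `k`. The boundary at $k = 0$ is handled by assigning zero target mass to $k < 0$, so such proposals are always rejected. This is not the algorithm from the referenced text, only an illustration of the same idea.

```python
import math
import random

random.seed(0)

def log_target(trace):
    """Unnormalized log density of a hypothetical two-site trace."""
    mu, k = trace["mu"], trace["k"]
    if k < 0:
        return -math.inf  # outside the support, so the move is always rejected
    log_p_mu = -0.5 * mu ** 2                         # Normal(0, 1), up to a constant
    log_p_k = k * math.log(3.0) - math.lgamma(k + 1)  # Poisson(3), up to a constant
    return log_p_mu + log_p_k

def propose(trace, sigma=1.0):
    """Pick one site at random and perturb only that site."""
    new = dict(trace)
    site = random.choice(sorted(trace))
    if site == "mu":
        new["mu"] = random.gauss(trace["mu"], sigma)       # symmetric normal
    else:
        new["k"] = trace["k"] + random.choice((-1, 0, 1))  # symmetric categorical
    return new

trace = {"mu": 0.0, "k": 3}
samples = []
for _ in range(40000):
    candidate = propose(trace)
    # Both kernels are symmetric, so the q terms cancel in the acceptance ratio
    log_a = log_target(candidate) - log_target(trace)
    if math.log(1.0 - random.random()) < log_a:
        trace = candidate
    samples.append(trace)

mean_mu = sum(s["mu"] for s in samples) / len(samples)
mean_k = sum(s["k"] for s in samples) / len(samples)
```

Each iteration changes at most one site, which is why more iterations are needed than with a joint proposal. The sample means should approach $0$ for `mu` and $3$ for `k`.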