Trace-based probabilistic programming

  • We'll continue with our discussion from last time, talking first about modeling / expressiveness of languages, then about inference (if we get that far...)
  • Let's revisit the last example from last class:
    T = 100
    x ~ normal(0.0, 1.0)
    for t in 1:T:
      x ~ normal(x, 1.0)
      y ~ observe(normal(x, 0.5), obs)
    vs
    stop ~ bernoulli(0.01)
    x ~ normal(0.0, 1.0)
    while not stop:
      x ~ normal(x, 1.0)
      y ~ observe(normal(x, 0.5), obs)
      stop ~ bernoulli(0.01)
In [11]:
def hmm_static_model(y, T=2):
    prev_x = pyro.sample(
        "x0", dist.Normal(0.0, 1.0)
    )
    retval = torch.empty((T,))
    for t in pyro.poutine.markov(range(1, T + 1)):
        this_x = pyro.sample(
            f"x{t}", dist.Normal(prev_x, 1.0)
        )
        this_y = pyro.sample(
            f"y{t}", dist.Normal(this_x, 1.0), obs=y[t - 1]
        )
        prev_x = this_x
        retval[t - 1] = this_y
    return retval
In [25]:
seed_me()
T = 100
ax = ts_model_plot(
    hmm_static_model,
    [None] * T,
    n_draws=5,
    T=T
)
In [16]:
static_trace = pyro.poutine.trace(hmm_static_model).get_trace([None] * 3, T=3)
static_trace.compute_log_prob()
clean_trace_view(static_trace.nodes)
Out[16]:
OrderedDict([('input', {'T': 3}),
             ('x0',
              {'fn': Normal(loc: 0.0, scale: 1.0),
               'value': tensor(-0.4013),
               'lpdf': tensor(-0.9995)}),
             ('x1',
              {'fn': Normal(loc: -0.40129169821739197, scale: 1.0),
               'value': tensor(-1.8146),
               'lpdf': tensor(-1.9177)}),
             ('y1',
              {'fn': Normal(loc: -1.8146270513534546, scale: 1.0),
               'value': tensor(-0.6041),
               'lpdf': tensor(-1.6516)}),
             ('x2',
              {'fn': Normal(loc: -1.8146270513534546, scale: 1.0),
               'value': tensor(-2.2848),
               'lpdf': tensor(-1.0295)}),
             ('y2',
              {'fn': Normal(loc: -2.284834146499634, scale: 1.0),
               'value': tensor(-0.0183),
               'lpdf': tensor(-3.4876)}),
             ('x3',
              {'fn': Normal(loc: -2.284834146499634, scale: 1.0),
               'value': tensor(-3.0112),
               'lpdf': tensor(-1.1827)}),
             ('y3',
              {'fn': Normal(loc: -3.0112056732177734, scale: 1.0),
               'value': tensor(-3.4138),
               'lpdf': tensor(-1.0000)}),
             ('output', {'value': tensor([-0.6041, -0.0183, -3.4138])})])
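  • Aside (not from the original notebook): the per-site 'lpdf' values above are pointwise log densities, and the joint log density of the trace is just their sum. A minimal check, assuming Pyro's trace API (compute_log_prob populates each sample site's "log_prob", and log_prob_sum gives the total):
site_log_prob_total = sum(
    node["log_prob"].sum()
    for node in static_trace.nodes.values()
    if node["type"] == "sample"
)
# these two numbers should agree
print(site_log_prob_total, static_trace.log_prob_sum())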
In [38]:
def hmm_dynamic_model():
    prev_x = pyro.sample(
        "x0", dist.Normal(0.0, 1.0)
    )
    stop = pyro.sample(
        "stop0",
        dist.Bernoulli(0.1)
    )
    retval = []
    t = 1
    while not stop:
        this_x = pyro.sample(
            f"x{t}", dist.Normal(prev_x, 1.0)
        )
        this_y = pyro.sample(
            f"y{t}", dist.Normal(this_x, 1.0),
        )
        prev_x = this_x
        retval.append(this_y)
        stop = pyro.sample(
            f"stop{t}",
            dist.Bernoulli(0.1)
        )
        t += 1
    return retval
In [39]:
seed_me()
ax = ts_model_plot(
    hmm_dynamic_model,
    n_draws=5
)
In [59]:
seed_me()
trace_lengths = [
    get_timesteps(pyro.poutine.trace(hmm_dynamic_model).get_trace()) for _ in range(1000)
]
# trace lengths should be distributed Geometric with p = 0.1 (matching the Bernoulli(0.1) in the model)
xvals = torch.arange(0, 51)
plt.hist(trace_lengths, density=True, label='observed')
plt.plot(
    xvals, dist.Geometric(0.1).log_prob(xvals).exp(),
    label='theoretical'
)
plt.legend()
plt.xlim(0.0, 50.0);
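  • Quick sanity check (a hypothetical addition, assuming get_timesteps returns plain integer lengths): the loop length is the number of Bernoulli(0.1) failures before the first success, so its empirical mean should be close to $(1 - p)/p = 9$:
# empirical mean trace length vs. the Geometric(0.1) mean, (1 - 0.1) / 0.1 = 9
print(sum(trace_lengths) / len(trace_lengths))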
  • So, revisiting our earlier assertion that mathematical representation $\iff$ plate notation $\iff$ program representation: is that actually true?
  • NO! It is only true when the programmatic representation is in a first-order programming language (FOPL) (not Turing-complete, though it can express stochastic control flow). Otherwise, it is only true that mathematical representation $\iff$ plate notation $\iff$ unique trace address structure (see the sketch after this list).
    • "First order" in the sense of first-order logic, no universal quantifiers, not expressive enough to define unbounded sets. "Higher order" in the sense of universal quantifiers e.g. $\forall$.
  • Normative statement (quasi-opinion): if our practical modeling situation allows it, we should attempt to represent our model in a FOPL because that greatly eases inference, e.g., by allowing compilation to a graph (equivalent to plate notation!) and performing inference using the graph.
  • But sometimes a higher-order language is unavoidable, e.g., when the problem you're modeling truly has recursive structure!
    • Unbounded number of clusters, change points, or agents, unknown stopping time (as above), ...
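  • A concrete illustration of the "unique trace address structure" point, as a minimal sketch assuming the two models above and the stochastic_nodes property of Pyro traces: repeated traces of the bounded model always visit the same set of sample-site addresses, while the open-ended model's address set varies from run to run.
static_address_sets = {
    frozenset(
        pyro.poutine.trace(hmm_static_model).get_trace([None] * 3, T=3).stochastic_nodes
    )
    for _ in range(20)
}
dynamic_address_sets = {
    frozenset(pyro.poutine.trace(hmm_dynamic_model).get_trace().stochastic_nodes)
    for _ in range(20)
}
# the bounded model yields exactly one distinct address set; the unbounded
# one (almost surely) yields several
print(len(static_address_sets), len(dynamic_address_sets))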

Inference in probabilistic programming

  • There is an interesting dichotomy here. We simultaneously want to: a) entirely decouple modeling from inference, so that you never need to think about inference when writing a model; but also b) have them completely intertwined, so that the act of writing a valid model (the model compiles!) means that inference must be possible (though not necessarily with efficiency guarantees).
  • We will see that goal (b) is always possible (though, again, without efficiency guarantees) but goal (a) is harder and requires competent domain-specific language (DSL) design, into which we cannot dive deeply.
  • We'll work through multiple examples of inference using pyro, with pointers to other inference examples, as well as methods we won't implement (during class, at least, though they're great topics for midterm + final papers).
  • First, imagine we're using a FOPL (or at least that our models are guaranteed to be of ex ante bounded cardinality) and hence that we can compile them to a graph. (Sadly, because we have limited time in this course, we won't go over how to actually compile them, though you can and should attempt to figure this out yourself.)
  • There are fundamentally two types of graphs we can consider: DAGs (we've seen these already) and factor graphs, which we have not yet considered.
  • Let's think about inference on factor graphs...
    • Exact inference: variable elimination
    • Belief propagation (sum-product algorithm)

Variable elimination

  • What it sounds like -- eliminating variables from a factor graph.
    • Why? Suppose we want to calculate $p(z | x) = p(x | z) p(z) / p(x)$. We need to find $p(x)$ in order to do this, i.e., we need to marginalize out $z$, i.e., calculate $$ p(x) = \sum_z p(x|z)\ p(z) $$
  • Okay, what's the problem? Two-fold...
    1. This sum might actually be an integral that we have to compute via sampling (e.g., what if part of $z$ is a continuous-valued distribution?)
    2. Even if each component of $z$ has only $n$ possible values, if the length of $z$ is $m$, then this algorithm has complexity $n^m$ which is REALLY BAD (e.g. 100 binary rvs produces $2^{100} \approx 1.27 \times 10^{30}$ terms in the sum if computed naively...)
  • Worst case, there is nothing to be done about this. In practice, though, we observe that the degrees of the factor nodes are often quite small, i.e., the effective $m$ in each partial sum is not too large.
  • Exploit structure: sum out variables one at a time in order to dynamically create new factors, which may in turn be summed out later on.
    • Not a panacea: finding the optimal ordering in which to sum out variables is an NP-hard problem...
  • Example model $$ p(a, b, c, d, e) = p(a| b, c)\ p(b|c, d)\ p(c|e)\ p(d | e)\ p(e) $$ Want to find $p(a)$ (the marginal distribution). Summing out in topological order (a good ordering here, though not always!), with a small numeric check after the derivation:
    • $p(a, b, c, d) = p(a| b, c)\ p(b|c, d)\ \sum_e p(c|e)\ p(d | e)\ p(e) = p(a| b, c)\ p(b|c, d)\ f(c, d)$
    • $p(a, b, c) = p(a| b, c)\ \sum_d p(b|c, d)\ f(c, d) = p(a | b, c)f(b, c)$
    • $p(a, b) = \sum_c p(a|b, c) f(b, c) = f(a, b)$
    • $p(a) = \sum_b f(a, b)$ (whew)
    • Equivalent to just writing $$ p(a) = \sum_b \sum_c p(a| b, c)\ \sum_d p(b|c, d)\ \sum_e p(c|e)\ p(d | e)\ p(e) $$
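  • To make the elimination concrete, here's a small self-contained sketch (not from the original notes) using randomly generated binary CPTs and torch.einsum: it checks that eliminating $e, d, c, b$ in turn recovers the same $p(a)$ as naively building and summing the full joint.
import torch

torch.manual_seed(0)  # arbitrary seed, for reproducibility

def random_cpt(*shape):
    # random conditional probability table; the last axis sums to one
    t = torch.rand(*shape)
    return t / t.sum(-1, keepdim=True)

# binary CPTs for p(a,b,c,d,e) = p(a|b,c) p(b|c,d) p(c|e) p(d|e) p(e)
p_a = random_cpt(2, 2, 2)  # indexed [b, c, a]
p_b = random_cpt(2, 2, 2)  # indexed [c, d, b]
p_c = random_cpt(2, 2)     # indexed [e, c]
p_d = random_cpt(2, 2)     # indexed [e, d]
p_e = random_cpt(2)        # indexed [e]

# naive marginalization: build the full joint, then sum out b, c, d, e
joint = torch.einsum("bca,cdb,ec,ed,e->abcde", p_a, p_b, p_c, p_d, p_e)
p_a_naive = joint.sum(dim=(1, 2, 3, 4))

# variable elimination in the order e, d, c, b, as in the derivation above
f_cd = torch.einsum("ec,ed,e->cd", p_c, p_d, p_e)  # sum out e
f_bc = torch.einsum("cdb,cd->bc", p_b, f_cd)       # sum out d
f_ab = torch.einsum("bca,bc->ab", p_a, f_bc)       # sum out c
p_a_ve = f_ab.sum(dim=1)                           # sum out b

print(torch.allclose(p_a_naive, p_a_ve))  # True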