## Trace-based probabilistic programming

• We'll continue with our discussion from last time, talking first about modeling / expressiveness of languages, then about inference (if we get that far...)
• Let's revisit the last example from last class:
```
T = 100
x ~ normal(0.0, 1.0)
for t in 1:T:
    x ~ normal(x, 1.0)
    y ~ observe(normal(x, 0.5), obs)
```
vs.
```
stop ~ bernoulli(0.01)
x ~ normal(0.0, 1.0)
while not stop:
    x ~ normal(x, 1.0)
    y ~ observe(normal(x, 0.5), obs)
    stop ~ bernoulli(0.01)
```
In [11]:
```python
def hmm_static_model(y, T=2):
    prev_x = pyro.sample(
        "x0", dist.Normal(0.0, 1.0)
    )
    retval = torch.empty((T,))
    for t in pyro.poutine.markov(range(1, T + 1)):
        this_x = pyro.sample(
            f"x{t}", dist.Normal(prev_x, 1.0)
        )
        this_y = pyro.sample(
            f"y{t}", dist.Normal(this_x, 1.0), obs=y[t - 1]
        )
        prev_x = this_x
        retval[t - 1] = this_y
    return retval
```

In [25]:
```python
seed_me()
T = 100
ax = ts_model_plot(
    hmm_static_model,
    [None] * T,
    n_draws=5,
    T=T
)
```

In [16]:
```python
static_trace = pyro.poutine.trace(hmm_static_model).get_trace([None] * 3, T=3)
static_trace.compute_log_prob()
clean_trace_view(static_trace.nodes)
```

Out[16]:
```
OrderedDict([('input', {'T': 3}),
             ('x0',
              {'fn': Normal(loc: 0.0, scale: 1.0),
               'value': tensor(-0.4013),
               'lpdf': tensor(-0.9995)}),
             ('x1',
              {'fn': Normal(loc: -0.40129169821739197, scale: 1.0),
               'value': tensor(-1.8146),
               'lpdf': tensor(-1.9177)}),
             ('y1',
              {'fn': Normal(loc: -1.8146270513534546, scale: 1.0),
               'value': tensor(-0.6041),
               'lpdf': tensor(-1.6516)}),
             ('x2',
              {'fn': Normal(loc: -1.8146270513534546, scale: 1.0),
               'value': tensor(-2.2848),
               'lpdf': tensor(-1.0295)}),
             ('y2',
              {'fn': Normal(loc: -2.284834146499634, scale: 1.0),
               'value': tensor(-0.0183),
               'lpdf': tensor(-3.4876)}),
             ('x3',
              {'fn': Normal(loc: -2.284834146499634, scale: 1.0),
               'value': tensor(-3.0112),
               'lpdf': tensor(-1.1827)}),
             ('y3',
              {'fn': Normal(loc: -3.0112056732177734, scale: 1.0),
               'value': tensor(-3.4138),
               'lpdf': tensor(-1.0000)}),
             ('output', {'value': tensor([-0.6041, -0.0183, -3.4138])})])
```
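As a sanity check, each `'lpdf'` entry above is just the Normal log-density of the sampled value under its `'fn'`, and the trace's total log joint is their sum. A pure-Python sketch (no Pyro needed; the `(value, loc)` pairs are copied from the rounded trace display above, so the results agree only to display precision):

```python
import math

def normal_logpdf(x, loc, scale):
    """Log-density of Normal(loc, scale) at x."""
    z = (x - loc) / scale
    return -0.5 * z * z - math.log(scale) - 0.5 * math.log(2.0 * math.pi)

# (value, loc) pairs read off the trace above; every scale is 1.0
entries = [
    (-0.4013, 0.0),      # x0
    (-1.8146, -0.4013),  # x1 | x0
    (-0.6041, -1.8146),  # y1 | x1
    (-2.2848, -1.8146),  # x2 | x1
    (-0.0183, -2.2848),  # y2 | x2
    (-3.0112, -2.2848),  # x3 | x2
    (-3.4138, -3.0112),  # y3 | x3
]
lpdfs = [normal_logpdf(x, loc, 1.0) for x, loc in entries]
total = sum(lpdfs)  # the trace's total log joint, approx -11.27
print([round(l, 4) for l in lpdfs])
```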
In [38]:
```python
def hmm_dynamic_model():
    prev_x = pyro.sample(
        "x0", dist.Normal(0.0, 1.0)
    )
    stop = pyro.sample(
        "stop0",
        dist.Bernoulli(0.1)
    )
    retval = []
    t = 1
    while not stop:
        this_x = pyro.sample(
            f"x{t}", dist.Normal(prev_x, 1.0)
        )
        this_y = pyro.sample(
            f"y{t}", dist.Normal(this_x, 1.0),
        )
        prev_x = this_x
        retval.append(this_y)
        stop = pyro.sample(
            f"stop{t}",
            dist.Bernoulli(0.1)
        )
        t += 1
    return retval
```

In [39]:
```python
seed_me()
ax = ts_model_plot(
    hmm_dynamic_model,
    n_draws=5
)
```

In [59]:
```python
seed_me()
trace_lengths = [
    get_timesteps(pyro.poutine.trace(hmm_dynamic_model).get_trace())
    for _ in range(1000)
]
# should be distributed geometric with p = 0.1
xvals = torch.arange(0.0, 51.0)
plt.hist(trace_lengths, density=True, label='observed')
plt.plot(
    xvals, dist.Geometric(0.1).log_prob(xvals).exp(),
    label='theoretical'
)
plt.legend()
plt.xlim(0.0, 50.0);
```
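The same check works in plain Python without Pyro. Below, `sample_length` is a hypothetical stand-in that mirrors only the stopping logic of `hmm_dynamic_model` (a `stop0` draw is checked before any transition step): the trace length is a Geometric(0.1) count of failures before the first success, with mean $(1-p)/p = 9$.

```python
import random

random.seed(0)

def sample_length(p=0.1):
    """Number of transition steps taken before a stop draw fires,
    mirroring hmm_dynamic_model's control flow."""
    t = 0
    while random.random() >= p:  # Bernoulli(p) stop draw came up 0 -> take a step
        t += 1
    return t

lengths = [sample_length() for _ in range(20000)]
mean_len = sum(lengths) / len(lengths)
print(mean_len)  # close to (1 - 0.1) / 0.1 = 9
```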

• So, revisiting our earlier assertion that mathematical representation $\iff$ plate notation $\iff$ program representation: is that actually true?
• NO! It is only true when the programmatic representation is a first-order programming language (FOPL): not Turing-complete, though still able to express stochastic control flow. Otherwise, the most we can say is that mathematical representation $\iff$ plate notation $\iff$ unique trace address structure.
• "First order" is meant in the sense of first-order logic: no universal quantifiers, so not expressive enough to define unbounded sets. "Higher order" is meant in the sense of universal quantifiers, e.g., $\forall$.
• Normative statement (quasi-opinion): if our practical modeling situation allows it, we should attempt to represent our model in a FOPL because that greatly eases inference, e.g., by allowing compilation to a graph (equivalent to plate notation!) and performing inference using the graph.
• But sometimes a higher-order language is unavoidable, e.g., when the problem you're modeling truly has recursive structure!
• Unbounded number of clusters, change points, or agents, unknown stopping time (as above), ...
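
To make the "unique trace address structure" point concrete, here is a plain-Python sketch (no Pyro; the address-recording functions are hypothetical stand-ins for the two models above). The bounded model visits the same sample addresses on every run, while the unbounded model's address set is itself random:

```python
import random

random.seed(1)

def static_addresses(T=3):
    # FOPL-style model: the address set is fixed before execution
    addrs = ["x0"]
    for t in range(1, T + 1):
        addrs += [f"x{t}", f"y{t}"]
    return addrs

def dynamic_addresses(p=0.1):
    # higher-order model: the address set depends on sampled control flow
    addrs = ["x0", "stop0"]
    t = 1
    while random.random() >= p:  # the bernoulli 'stop' draw came up 0
        addrs += [f"x{t}", f"y{t}", f"stop{t}"]
        t += 1
    return addrs

print(static_addresses())  # identical on every call
print(len({tuple(dynamic_addresses()) for _ in range(50)}))  # several distinct address structures
```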

## Inference in probabilistic programming

• There is an interesting dichotomy here. We simultaneously want to: (a) entirely decouple modeling from inference, so that you never need to think about inference when writing a model; but also (b) have them completely intertwined, so that the act of writing a valid model (the model compiles!) guarantees that inference is possible (though not necessarily with efficiency guarantees).
• We will see that goal (b) is always possible (though, again, without efficiency guarantees) but goal (a) is harder and requires competent domain-specific language (DSL) design, into which we cannot dive deeply.
• Multiple examples of inference using pyro and pointers to other inference examples, as well as methods we won't implement (during class, at least, though they're great topics for midterm + final papers).
• First, imagine we're using a FOPL (or at least that our models are guaranteed to be of ex ante bounded cardinality), and hence that we can compile them to a graph. (Sadly, because we have limited time in this course, we won't go over how to actually compile them, though you can and should attempt to figure this out yourself.)
• There are fundamentally two types of graphs we can consider: DAGs (we've seen these already) and factor graphs, which we have not yet considered.
• Let's think about inference on factor graphs...
• Exact inference: variable elimination
• Belief propagation (sum-product algorithm)

## Variable elimination

• What it sounds like -- eliminating variables from a factor graph.
• Why? Suppose we want to calculate $p(z | x) = p(x | z)\ p(z) / p(x)$. We need to find $p(x)$ in order to do this, i.e., we need to marginalize out $z$, i.e., calculate $$p(x) = \sum_z p(x|z)\ p(z)$$
• Okay, what's the problem? Two-fold...
1. This sum might actually be an integral that we have to compute via sampling (e.g., what if part of $z$ is a continuous-valued distribution?)
2. Even if each component of $z$ has only $n$ possible values, if the length of $z$ is $m$, then this algorithm has complexity $n^m$ which is REALLY BAD (e.g. 100 binary rvs produces $2^{100} \approx 1.26 \times 10^{30}$ terms in the sum if computed naively...)
• Worst case, there is nothing to be done about this. In practice, we observe that the degree of each factor node is often quite small, i.e., $m$ is not too large.
• Exploit structure: sum out variables one at a time in order to dynamically create new factors, which may in turn be summed out later on.
• Not a panacea: optimal ordering to sum out variables is NP-hard problem...
• Example model $$p(a, b, c, d, e) = p(a| b, c)\ p(b|c, d)\ p(c|e)\ p(d | e)\ p(e)$$ Want to find $p(a)$ (the marginal distribution). Summing out in topological order (here this is a good ordering, not always true!):
• $p(a, b, c, d) = p(a| b, c)\ p(b|c, d)\ \sum_e p(c|e)\ p(d | e)\ p(e) = p(a| b, c)\ p(b|c, d)\ f(c, d)$
• $p(a, b, c) = p(a| b, c)\ \sum_d p(b|c, d)\ f(c, d) = p(a | b, c)f(b, c)$
• $p(a, b) = \sum_c p(a|b, c) f(b, c) = f(a, b)$
• $p(a) = \sum_b f(a, b)$ (whew)
• Equivalent to just writing $$p(a) = \sum_b \sum_c p(a| b, c)\ \sum_d p(b|c, d)\ \sum_e p(c|e)\ p(d | e)\ p(e)$$
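
The worked example above can be checked numerically. A minimal pure-Python sketch, assuming binary variables and randomly generated conditional probability tables (the helper `rand_cpt` and all table names are illustrative, not from the lecture code); it runs the elimination order $e, d, c, b$ from above and confirms it matches the naive $2^4$-term sum:

```python
import itertools
import random

random.seed(0)

def rand_cpt(n_parents):
    """Random conditional table p(v | parents) over binary variables."""
    table = {}
    for pa in itertools.product([0, 1], repeat=n_parents):
        p1 = random.random()
        table[pa] = {0: 1.0 - p1, 1: p1}
    return table

p_a = rand_cpt(2)  # p(a | b, c)
p_b = rand_cpt(2)  # p(b | c, d)
p_c = rand_cpt(1)  # p(c | e)
p_d = rand_cpt(1)  # p(d | e)
p_e = rand_cpt(0)  # p(e)

def joint(a, b, c, d, e):
    return (p_a[(b, c)][a] * p_b[(c, d)][b]
            * p_c[(e,)][c] * p_d[(e,)][d] * p_e[()][e])

# Naive marginal: sum over all 2^4 assignments of (b, c, d, e)
naive = {a: sum(joint(a, b, c, d, e)
                for b, c, d, e in itertools.product([0, 1], repeat=4))
         for a in [0, 1]}

# Variable elimination in the order e, d, c, b, creating new factors as we go
f_cd = {(c, d): sum(p_c[(e,)][c] * p_d[(e,)][d] * p_e[()][e] for e in [0, 1])
        for c, d in itertools.product([0, 1], repeat=2)}
f_bc = {(b, c): sum(p_b[(c, d)][b] * f_cd[(c, d)] for d in [0, 1])
        for b, c in itertools.product([0, 1], repeat=2)}
f_ab = {(a, b): sum(p_a[(b, c)][a] * f_bc[(b, c)] for c in [0, 1])
        for a, b in itertools.product([0, 1], repeat=2)}
ve = {a: sum(f_ab[(a, b)] for b in [0, 1]) for a in [0, 1]}

print(naive, ve)  # the two marginals agree
```

Note that elimination never touches more than three variables at a time, whereas the naive sum enumerates all five jointly; that gap is exactly what becomes catastrophic at $m = 100$.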