Skip to content

Commit 213e0fd

Browse files
Refactor rewrites and add a general sampler constructor function
1 parent 589520d commit 213e0fd

10 files changed

+1227
-524
lines changed

aemcmc/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,7 @@
11
from . import _version
22

33
__version__ = _version.get_versions()["version"]
4+
5+
# Register rewrite databases
6+
import aemcmc.conjugates
7+
import aemcmc.gibbs

aemcmc/basic.py

Lines changed: 123 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,123 @@
1+
from typing import Dict, Tuple
2+
3+
from aesara.graph.basic import Variable
4+
from aesara.graph.fg import FunctionGraph
5+
from aesara.tensor.random.utils import RandomStream
6+
from aesara.tensor.var import TensorVariable
7+
8+
from aemcmc.opt import (
9+
SamplerTracker,
10+
construct_ir_fgraph,
11+
expand_subsumptions,
12+
sampler_rewrites_db,
13+
)
14+
15+
16+
def construct_sampler(
    obs_rvs_to_values: Dict[TensorVariable, TensorVariable], srng: RandomStream
) -> Tuple[
    Dict[TensorVariable, TensorVariable],
    Dict[Variable, Variable],
    Dict[TensorVariable, TensorVariable],
]:
    r"""Eagerly construct a sampler for a given set of observed variables and their observations.

    Parameters
    ==========
    obs_rvs_to_values
        A ``dict`` of variables that maps stochastic elements
        (e.g. `RandomVariable`\s) to symbolic `Variable`\s representing their
        observed values.
    srng
        The `RandomStream` passed to the sampler-discovery rewrites.

    Returns
    =======
    A three-tuple containing: a ``dict`` that maps each random variable to
    its sampler step, a ``dict`` of shared-variable updates generated by the
    sampler steps, and a ``dict`` that maps each random variable to the
    initial-value variable used by its sampler step.

    Raises
    ======
    NotImplementedError
        If no sampler could be found for one or more random variables.

    """
    fgraph, obs_rvs_to_values, memo, new_to_old_rvs = construct_ir_fgraph(
        obs_rvs_to_values
    )

    # The feature collects candidate sampler steps as the rewrites run
    fgraph.attach_feature(SamplerTracker(srng))

    _ = sampler_rewrites_db.query("+basic").optimize(fgraph)

    random_vars = tuple(rv for rv in fgraph.outputs if rv not in obs_rvs_to_values)

    discovered_samplers = fgraph.sampler_mappings.rvs_to_samplers

    rvs_to_init_vals = {rv: rv.clone() for rv in random_vars}
    posterior_sample_steps = rvs_to_init_vals.copy()
    # Replace occurrences of observed variables with their observed values
    posterior_sample_steps.update(obs_rvs_to_values)

    # TODO FIXME: Get/extract `Scan`-generated updates
    posterior_updates: Dict[Variable, Variable] = {}

    rvs_without_samplers = set()

    for rv in fgraph.outputs:

        if rv in obs_rvs_to_values:
            continue

        rv_steps = discovered_samplers.get(rv)

        if not rv_steps:
            rvs_without_samplers.add(rv)
            continue

        # TODO FIXME: Just choosing one for now, but we should consider them all.
        step_desc, step, updates = rv_steps.pop()

        if updates:
            update_keys, update_values = zip(*updates.items())
        else:
            update_keys, update_values = tuple(), tuple()

        # Bundle the step and its updates into one graph so that replacements
        # and rewrites apply to all of them consistently
        sfgraph = FunctionGraph(
            outputs=(step,) + tuple(update_keys) + tuple(update_values),
            clone=False,
            copy_inputs=False,
            copy_orphans=False,
        )

        # Update the other sampled random variables in this step's graph
        sfgraph.replace_all(list(posterior_sample_steps.items()), import_missing=True)

        # Expand subsumed `DimShuffle`d inputs to `Elemwise`s
        expand_subsumptions.optimize(sfgraph)

        step = sfgraph.outputs[0]

        posterior_sample_steps[rv] = step

        if updates:
            # Recover the (possibly rewritten) update pairs from the graph's
            # outputs; output 0 is the step itself
            keys_offset = len(update_keys) + 1
            update_keys = sfgraph.outputs[1:keys_offset]
            update_values = sfgraph.outputs[keys_offset:]
            updates = dict(zip(update_keys, update_values))
            posterior_updates.update(updates)

    if rvs_without_samplers:
        # TODO: Assign NUTS to these
        raise NotImplementedError(
            f"Could not find posterior samplers for {rvs_without_samplers}"
        )

    # TODO: Track/handle "auxiliary/augmentation" variables introduced by sample
    # steps?

    return (
        {
            new_to_old_rvs[rv]: step
            for rv, step in posterior_sample_steps.items()
            if rv not in obs_rvs_to_values
        },
        posterior_updates,
        {new_to_old_rvs[rv]: init_var for rv, init_var in rvs_to_init_vals.items()},
    )

aemcmc/conjugates.py

Lines changed: 52 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,16 @@
11
import aesara.tensor as at
2-
from aeppl.opt import NoCallbackEquilibriumDB
3-
from aesara.graph.kanren import KanrenRelationSub
2+
from aesara.graph.opt import in2out, local_optimizer
3+
from aesara.graph.optdb import LocalGroupDB
4+
from aesara.graph.unify import eval_if_etuple
5+
from aesara.tensor.random.basic import BinomialRV
46
from etuples import etuple, etuplize
5-
from kanren import eq, lall
7+
from kanren import eq, lall, run
68
from unification import var
79

10+
from aemcmc.opt import sampler_finder_db
811

9-
def beta_binomial_conjugateo(observed_rv_expr, posterior_expr):
12+
13+
def beta_binomial_conjugateo(observed_val, observed_rv_expr, posterior_expr):
1014
r"""Produce a goal that represents the application of Bayes theorem
1115
for a beta prior with a binomial observation model.
1216
@@ -22,15 +26,15 @@ def beta_binomial_conjugateo(observed_rv_expr, posterior_expr):
2226
2327
Parameters
2428
----------
29+
observed_val
30+
The observed value.
2531
observed_rv_expr
26-
A tuple that contains expressions that represent the observed variable
27-
and it observed value respectively.
32+
An expression that represents the observed variable.
2833
posterior_expr
2934
An expression that represents the posterior distribution of the latent
3035
variable.
3136
3237
"""
33-
3438
# Beta-binomial observation model
3539
alpha_lv, beta_lv = var(), var()
3640
p_rng_lv = var()
@@ -42,12 +46,10 @@ def beta_binomial_conjugateo(observed_rv_expr, posterior_expr):
4246
n_lv = var()
4347
Y_et = etuple(etuplize(at.random.binomial), var(), var(), var(), n_lv, p_et)
4448

45-
y_lv = var() # observation
46-
4749
# Posterior distribution for p
48-
new_alpha_et = etuple(etuplize(at.add), alpha_lv, y_lv)
50+
new_alpha_et = etuple(etuplize(at.add), alpha_lv, observed_val)
4951
new_beta_et = etuple(
50-
etuplize(at.sub), etuple(etuplize(at.add), beta_lv, n_lv), y_lv
52+
etuplize(at.sub), etuple(etuplize(at.add), beta_lv, n_lv), observed_val
5153
)
5254
p_posterior_et = etuple(
5355
etuplize(at.random.beta),
@@ -59,13 +61,47 @@ def beta_binomial_conjugateo(observed_rv_expr, posterior_expr):
5961
)
6062

6163
return lall(
62-
eq(observed_rv_expr[0], Y_et),
63-
eq(observed_rv_expr[1], y_lv),
64+
eq(observed_rv_expr, Y_et),
6465
eq(posterior_expr, p_posterior_et),
6566
)
6667

6768

68-
conjugates_db = NoCallbackEquilibriumDB()
69-
conjugates_db.register(
70-
"beta_binomial", KanrenRelationSub(beta_binomial_conjugateo), -5, "basic"
69+
@local_optimizer([BinomialRV])
def local_beta_binomial_posterior(fgraph, node):
    """Register a beta-binomial conjugate posterior sampler for a binomial observation.

    Matches a binomial random variable whose probability parameter is
    beta-distributed and, when the match succeeds, records the closed-form
    beta posterior as a sampler step for that parameter.
    """
    sampler_mappings = getattr(fgraph, "sampler_mappings", None)

    rv_var = node.outputs[1]
    key = ("local_beta_binomial_posterior", rv_var)

    # Only proceed when sampler tracking is enabled and this variable has not
    # already been handled
    if sampler_mappings is None or key in sampler_mappings.rvs_seen:
        return None  # pragma: no cover

    q = var()

    rv_et = etuplize(rv_var)

    # Run the conjugacy relation; a result is the posterior expression for
    # the beta-distributed parameter
    res = next(run(None, q, beta_binomial_conjugateo(rv_var, rv_et, q)), None)

    if res is None:
        return None  # pragma: no cover

    beta_rv = rv_et[-1].evaled_obj
    beta_posterior = eval_if_etuple(res)

    sampler_mappings.rvs_to_samplers.setdefault(beta_rv, []).append(
        ("local_beta_binomial_posterior", beta_posterior, None)
    )
    sampler_mappings.rvs_seen.add(key)

    return rv_var.owner.outputs
99+
100+
101+
# Database of conjugate-posterior rewrites; each entry replaces a
# prior/likelihood pair with its closed-form posterior sampler.
conjugates_db = LocalGroupDB(apply_all_opts=True)
conjugates_db.name = "conjugates_db"
conjugates_db.register("beta_binomial", local_beta_binomial_posterior, "basic")

# NOTE(review): the in2out optimizer is registered under the name "gibbs",
# which may be a copy-paste from aemcmc.gibbs — confirm the intended name.
sampler_finder_db.register(
    "conjugates", in2out(conjugates_db.query("+basic"), name="gibbs"), "basic"
)

aemcmc/dists.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@ def multivariate_normal_rue2005(rng, b, Q):
5353
w = at.slinalg.solve_triangular(L, b, lower=True)
5454
u = at.slinalg.solve_triangular(L.T, w, lower=False)
5555
z = rng.standard_normal(size=L.shape[0])
56+
z.owner.outputs[0].name = "z_rng"
5657
v = at.slinalg.solve_triangular(L.T, z, lower=False)
5758
return u + v
5859

@@ -135,6 +136,7 @@ def multivariate_normal_cong2017(
135136
A_inv = 1 / A
136137
a_rows = A.shape[0]
137138
z = rng.standard_normal(size=a_rows + omega.shape[0])
139+
z.owner.outputs[0].name = "z_rng"
138140
y1 = at.sqrt(A_inv) * z[:a_rows]
139141
y2 = (1 / at.sqrt(omega)) * z[a_rows:]
140142
Ainv_phi = A_inv[:, None] * phi.T

0 commit comments

Comments
 (0)