
Commit 09a3e0b

ordabayevy and fehiepsi authored
Funsor based TraceEnum_ELBO implementation (#1512)
* initial commit
* wip
* remove TraceMarkovEnum_ELBO
* pair coded
* clean up
* Make enum example work
* port tests from pyro to numpyro
* Add missing test file
* traceenum_elbo2
* pair coded
* pass more tests
* pass all tests
* organize
* lint
* fixes
* test_gradient
* fix TraceGraph_ELBO
* lint
* Revert tracegraph_elbo changes
* Address masked distribution
* revert changes at replay messenger
* refactor
* clean
* add validations
* fix comments
* lint
* fix validation
* fix
* fix enum_vars
* rm wordclouds.png
* address comments

Co-authored-by: Du Phan <phandu@google.com>
1 parent c46b0db commit 09a3e0b

File tree

8 files changed: +2745 −7 lines changed

numpyro/contrib/funsor/enum_messenger.py

Lines changed: 6 additions & 0 deletions
@@ -521,6 +521,12 @@ def _get_batch_shape(cond_indep_stack):
     def process_message(self, msg):
         if msg["type"] in ["to_funsor", "to_data"]:
             return super().process_message(msg)
+        if msg["type"] == "sample" and self.size != self.subsample_size:
+            plate_to_scale = msg.setdefault("plate_to_scale", {})
+            assert self.name not in plate_to_scale
+            plate_to_scale[self.name] = (
+                self.size / self.subsample_size if self.subsample_size else 1
+            )
         return OrigPlateMessenger.process_message(self, msg)

     def postprocess_message(self, msg):
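With this change, a plate that subsamples its data records a per-plate scale of `size / subsample_size` on every sample message it encloses, so that `TraceEnum_ELBO` can later fold subsampling scales into contracted cost terms. A minimal sketch of what gets recorded (model invented for illustration; the bookkeeping applies when tracing under the funsor `plate_to_enum_plate` interpretation):

import numpyro
import numpyro.distributions as dist

def model():
    # 100 data points, minibatches of 10: under the funsor plate messenger,
    # the "x" sample message carries plate_to_scale == {"data": 10.0},
    # i.e. size / subsample_size.
    with numpyro.plate("data", 100, subsample_size=10):
        numpyro.sample("x", dist.Normal(0.0, 1.0))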

numpyro/distributions/kl.py

Lines changed: 9 additions & 0 deletions
@@ -39,6 +39,7 @@
     Normal,
     Weibull,
 )
+from numpyro.distributions.discrete import CategoricalProbs
 from numpyro.distributions.distribution import (
     Delta,
     Distribution,
@@ -146,6 +147,14 @@ def kl_divergence(p, q):
     return t1 - t2 + t3


+@dispatch(CategoricalProbs, CategoricalProbs)
+def kl_divergence(p, q):
+    t = p.probs * (p.logits - q.logits)
+    t = jnp.where(q.probs == 0, jnp.inf, t)
+    t = jnp.where(p.probs == 0, 0.0, t)
+    return t.sum(-1)
+
+
 @dispatch(Dirichlet, Dirichlet)
 def kl_divergence(p, q):
     # From http://bariskurt.com/kullback-leibler-divergence-between-two-dirichlet-and-beta-distributions/
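The new rule is the standard discrete KL divergence, KL(p‖q) = Σᵢ pᵢ (log pᵢ − log qᵢ), with the conventions that zero-probability entries of p contribute 0 and support of p outside the support of q gives ∞. A quick sanity check (distributions invented for illustration):

import jax.numpy as jnp
from numpyro.distributions import CategoricalProbs
from numpyro.distributions.kl import kl_divergence

p = CategoricalProbs(jnp.array([0.5, 0.5, 0.0]))
q = CategoricalProbs(jnp.array([0.25, 0.25, 0.5]))
# 0.5 * log(0.5/0.25) twice, plus 0 for the zero-probability entry of p:
print(kl_divergence(p, q))  # ~0.6931 == log(2)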

numpyro/handlers.py

Lines changed: 7 additions & 0 deletions
@@ -609,6 +609,13 @@ def process_message(self, msg):
         msg["scale"] = (
             self.scale if msg.get("scale") is None else self.scale * msg["scale"]
         )
+        plate_to_scale = msg.setdefault("plate_to_scale", {})
+        scale = (
+            self.scale
+            if plate_to_scale.get(None) is None
+            else self.scale * plate_to_scale[None]
+        )
+        plate_to_scale[None] = scale


 class scope(Messenger):
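Plate-independent scales introduced by `handlers.scale` are accumulated under the `None` key of the same `plate_to_scale` dict that plates write into. A minimal sketch of how nested handlers compose (model and values invented for illustration):

import numpyro
import numpyro.distributions as dist
from numpyro import handlers

def model():
    with handlers.scale(scale=0.5), handlers.scale(scale=4.0):
        numpyro.sample("x", dist.Normal(0.0, 1.0))

# Both handlers multiply into the same slot, so the "x" message ends up
# with msg["scale"] == 2.0 and msg["plate_to_scale"] == {None: 2.0}.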

numpyro/infer/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -6,6 +6,7 @@
     ELBO,
     RenyiELBO,
     Trace_ELBO,
+    TraceEnum_ELBO,
     TraceGraph_ELBO,
     TraceMeanField_ELBO,
 )
@@ -49,6 +50,7 @@
     "SA",
     "SVI",
     "Trace_ELBO",
+    "TraceEnum_ELBO",
     "TraceGraph_ELBO",
     "TraceMeanField_ELBO",
 ]

numpyro/infer/elbo.py

Lines changed: 300 additions & 3 deletions
@@ -1,8 +1,8 @@
 # Copyright Contributors to the Pyro project.
 # SPDX-License-Identifier: Apache-2.0

-from collections import defaultdict
-from functools import partial
+from collections import OrderedDict, defaultdict
+from functools import partial, reduce
 from operator import itemgetter
 import warnings

@@ -11,10 +11,16 @@
 import jax.numpy as jnp
 from jax.scipy.special import logsumexp

+from numpyro.distributions import ExpandedDistribution, MaskedDistribution
 from numpyro.distributions.kl import kl_divergence
 from numpyro.distributions.util import scale_and_mask
 from numpyro.handlers import Messenger, replay, seed, substitute, trace
-from numpyro.infer.util import get_importance_trace, log_density
+from numpyro.infer.util import (
+    _without_rsample_stop_gradient,
+    get_importance_trace,
+    is_identically_one,
+    log_density,
+)
 from numpyro.ops.provenance import eval_provenance, get_provenance
 from numpyro.util import _validate_model, check_model_guide_match, find_stack_level

@@ -710,3 +716,294 @@ def single_particle_elbo(rng_key):
         else:
             rng_keys = random.split(rng_key, self.num_particles)
         return -jnp.mean(vmap(single_particle_elbo)(rng_keys))
+
+
+def get_importance_trace_enum(model, guide, args, kwargs, params, max_plate_nesting):
+    """
+    (EXPERIMENTAL) Returns traces from the enumerated guide and the enumerated model that is run against it.
+    The returned traces also store the log probability at each site and the log measure for measure vars.
+    """
+    import funsor
+    from numpyro.contrib.funsor import (
+        enum,
+        plate_to_enum_plate,
+        to_funsor,
+        trace as _trace,
+    )
+
+    with plate_to_enum_plate(), enum(
+        first_available_dim=(-max_plate_nesting - 1) if max_plate_nesting else None
+    ):
+        guide = substitute(guide, data=params)
+        with _without_rsample_stop_gradient():
+            guide_trace = _trace(guide).get_trace(*args, **kwargs)
+        model = substitute(replay(model, guide_trace), data=params)
+        model_trace = _trace(model).get_trace(*args, **kwargs)
+    guide_trace = {
+        name: site for name, site in guide_trace.items() if site["type"] == "sample"
+    }
+    model_trace = {
+        name: site for name, site in model_trace.items() if site["type"] == "sample"
+    }
+    for is_model, tr in zip((False, True), (guide_trace, model_trace)):
+        for name, site in tr.items():
+            if is_model and (site["is_observed"] or (site["name"] in guide_trace)):
+                site["is_measure"] = False
+            if "log_prob" not in site:
+                value = site["value"]
+                intermediates = site["intermediates"]
+                if intermediates:
+                    log_prob = site["fn"].log_prob(value, intermediates)
+                else:
+                    log_prob = site["fn"].log_prob(value)
+
+                dim_to_name = site["infer"]["dim_to_name"]
+                site["log_prob"] = to_funsor(
+                    log_prob, output=funsor.Real, dim_to_name=dim_to_name
+                )
+                if site.get("is_measure", True):
+                    # get rid of masking
+                    base_fn = site["fn"]
+                    batch_shape = base_fn.batch_shape
+                    while isinstance(
+                        base_fn, (MaskedDistribution, ExpandedDistribution)
+                    ):
+                        base_fn = base_fn.base_dist
+                    base_fn = base_fn.expand(batch_shape)
+                    if intermediates:
+                        log_measure = base_fn.log_prob(value, intermediates)
+                    else:
+                        log_measure = base_fn.log_prob(value)
+                    # dice factor
+                    if not site["infer"].get("enumerate") == "parallel":
+                        log_measure = log_measure - funsor.ops.detach(log_measure)
+                    site["log_measure"] = to_funsor(
+                        log_measure, output=funsor.Real, dim_to_name=dim_to_name
+                    )
+    return model_trace, guide_trace
+
+
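The `log_measure - funsor.ops.detach(log_measure)` line above is the DiCE trick: the resulting factor exponentiates to exactly 1 in the forward pass but contributes the score-function term to the gradient. A standalone JAX sketch of the identity (function and values invented for illustration):

import jax
import jax.numpy as jnp

def surrogate(theta):
    log_q = jnp.log(jax.nn.sigmoid(theta))  # stand-in for a site's log measure
    cost = 3.0                              # stand-in for a downstream cost term
    # exp(log_q - stop_gradient(log_q)) == 1.0 in value, but its gradient
    # inserts d(log_q)/d(theta), i.e. the score-function estimator.
    return cost * jnp.exp(log_q - jax.lax.stop_gradient(log_q))

print(surrogate(0.1))            # 3.0: the value is unchanged
print(jax.grad(surrogate)(0.1))  # 3.0 * d/dtheta log sigmoid(theta)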
+def _partition(model_sum_deps, sum_vars):
+    # Construct a bipartite graph between model_sum_deps and the sum_vars
+    neighbors = OrderedDict([(t, []) for t in model_sum_deps.keys()])
+    for key, deps in model_sum_deps.items():
+        for dim in deps:
+            if dim in sum_vars:
+                neighbors[key].append(dim)
+                neighbors.setdefault(dim, []).append(key)
+
+    # Partition the bipartite graph into connected components for contraction.
+    components = []
+    while neighbors:
+        v, pending = neighbors.popitem()
+        component = OrderedDict([(v, None)])  # used as an OrderedSet
+        for v in pending:
+            component[v] = None
+        while pending:
+            v = pending.pop()
+            for v in neighbors.pop(v):
+                if v not in component:
+                    component[v] = None
+                    pending.append(v)
+
+        # Split this connected component into factors and measures.
+        # Append only if component_factors is non-empty
+        component_factors = frozenset(v for v in component if v not in sum_vars)
+        if component_factors:
+            component_measures = frozenset(v for v in component if v in sum_vars)
+            components.append((component_factors, component_measures))
+    return components
+
+
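For intuition, a small worked example of `_partition` (inputs invented for illustration; element order inside each frozenset may vary):

# Cost terms "x" and "y" share the enumerated variable "a", so they are
# contracted together with measures {"a", "b"}; "z" has no enumerated
# dependencies and forms its own component with an empty measure set.
model_sum_deps = {"x": frozenset({"a"}), "y": frozenset({"a", "b"}), "z": frozenset()}
sum_vars = frozenset({"a", "b"})
print(_partition(model_sum_deps, sum_vars))
# [(frozenset({'z'}), frozenset()), (frozenset({'x', 'y'}), frozenset({'a', 'b'}))]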
+class TraceEnum_ELBO(ELBO):
+    """
+    A TraceEnum implementation of ELBO-based SVI. The gradient estimator
+    is constructed along the lines of reference [1] specialized to the case
+    of the ELBO. It supports arbitrary dependency structure for the model
+    and guide.
+    Fine-grained conditional dependency information as recorded in the
+    trace is used to reduce the variance of the gradient estimator.
+    In particular provenance tracking [2] is used to find the ``cost`` terms
+    that depend on each non-reparameterizable sample site.
+    Enumerated variables are eliminated using the TVE algorithm for plated
+    factor graphs [3].
+
+    References
+
+    [1] `Storchastic: A Framework for General Stochastic Automatic Differentiation`,
+        Emile van Krieken, Jakub M. Tomczak, Annette ten Teije
+
+    [2] `Nonstandard Interpretations of Probabilistic Programs for Efficient Inference`,
+        David Wingate, Noah Goodman, Andreas Stuhlmüller, Jeffrey Siskind
+
+    [3] `Tensor Variable Elimination for Plated Factor Graphs`,
+        Fritz Obermeyer, Eli Bingham, Martin Jankowiak, Justin Chiu,
+        Neeraj Pradhan, Alexander M. Rush, Noah Goodman
+    """
+
+    can_infer_discrete = True
+
+    def __init__(self, num_particles=1, max_plate_nesting=float("inf")):
+        if max_plate_nesting == float("inf"):
+            raise ValueError(
+                "Currently, we require `max_plate_nesting` to be a non-negative integer."
+            )
+        self.max_plate_nesting = max_plate_nesting
+        super().__init__(num_particles=num_particles)
+
+    def loss(self, rng_key, param_map, model, guide, *args, **kwargs):
+        def single_particle_elbo(rng_key):
+            import funsor
+            from numpyro.contrib.funsor import to_data, to_funsor
+
+            model_seed, guide_seed = random.split(rng_key)
+            seeded_model = seed(model, model_seed)
+            seeded_guide = seed(guide, guide_seed)
+
+            model_trace, guide_trace = get_importance_trace_enum(
+                seeded_model,
+                seeded_guide,
+                args,
+                kwargs,
+                param_map,
+                self.max_plate_nesting,
+            )
+            check_model_guide_match(model_trace, guide_trace)
+            _validate_model(model_trace, plate_warning="strict")
+
+            # Find dependencies on non-reparameterizable sample sites for
+            # each cost term in the model and the guide.
+            model_deps, guide_deps = get_provenance(
+                eval_provenance(
+                    partial(
+                        track_nonreparam(get_importance_log_probs),
+                        seeded_model,
+                        seeded_guide,
+                        args,
+                        kwargs,
+                        param_map,
+                    )
+                )
+            )
+
+            sum_vars = frozenset(
+                [
+                    name
+                    for name, site in model_trace.items()
+                    if site.get("is_measure", True)
+                ]
+            )
+            model_sum_deps = {
+                k: v & sum_vars for k, v in model_deps.items() if k not in sum_vars
+            }
+            model_deps = {
+                k: v - sum_vars for k, v in model_deps.items() if k not in sum_vars
+            }
+
+            elbo = 0.0
+            for group_names, group_sum_vars in _partition(model_sum_deps, sum_vars):
+                if not group_sum_vars:
+                    # uncontracted logp cost term
+                    assert len(group_names) == 1
+                    name = next(iter(group_names))
+                    cost = model_trace[name]["log_prob"]
+                    scale = model_trace[name]["scale"]
+                    deps = model_deps[name]
+                    dice_factors = [guide_trace[key]["log_measure"] for key in deps]
+                else:
+                    # compute contracted cost term
+                    group_factors = tuple(
+                        model_trace[name]["log_prob"] for name in group_names
+                    )
+                    group_factors += tuple(
+                        model_trace[var]["log_measure"] for var in group_sum_vars
+                    )
+                    group_factor_vars = frozenset().union(
+                        *[f.inputs for f in group_factors]
+                    )
+                    group_plates = group_factor_vars - frozenset(model_trace.keys())
+                    outermost_plates = frozenset.intersection(
+                        *(frozenset(f.inputs) & group_plates for f in group_factors)
+                    )
+                    elim_plates = group_plates - outermost_plates
+                    cost = funsor.sum_product.sum_product(
+                        funsor.ops.logaddexp,
+                        funsor.ops.add,
+                        group_factors,
+                        plates=group_plates,
+                        eliminate=group_sum_vars | elim_plates,
+                    )
+                    # incorporate the effects of subsampling and handlers.scale through a common scale factor
+                    group_scales = {}
+                    for name in group_names:
+                        for plate, value in (
+                            model_trace[name].get("plate_to_scale", {}).items()
+                        ):
+                            if plate in group_scales:
+                                if value != group_scales[plate]:
+                                    raise ValueError(
+                                        "Expected all enumerated sample sites to share a common scale factor, "
+                                        f"but found different scales at plate('{plate}')."
+                                    )
+                            else:
+                                group_scales[plate] = value
+                    scale = (
+                        reduce(lambda a, b: a * b, group_scales.values())
+                        if group_scales
+                        else None
+                    )
+                    # combine deps
+                    deps = frozenset().union(
+                        *[model_deps[name] for name in group_names]
+                    )
+                    # check model guide enumeration constraint
+                    for key in deps:
+                        site = guide_trace[key]
+                        if site["infer"].get("enumerate") == "parallel":
+                            for plate in (
+                                frozenset(site["log_measure"].inputs) & elim_plates
+                            ):
+                                raise ValueError(
+                                    "Expected model enumeration to be no more global than guide enumeration, but found "
+                                    f"model enumeration sites upstream of guide site '{key}' in plate('{plate}'). "
+                                    "Try converting some model enumeration sites to guide enumeration sites."
+                                )
+                    # combine dice factors
+                    dice_factors = [
+                        guide_trace[key]["log_measure"].reduce(
+                            funsor.ops.add,
+                            frozenset(guide_trace[key]["log_measure"].inputs)
+                            & elim_plates,
+                        )
+                        for key in deps
+                    ]
+
+                if dice_factors:
+                    dice_factor = reduce(lambda a, b: a + b, dice_factors)
+                    cost = cost * funsor.ops.exp(dice_factor)
+                if (scale is not None) and (not is_identically_one(scale)):
+                    cost = cost * to_funsor(scale)
+
+                elbo = elbo + cost.reduce(funsor.ops.add)
+
+            for name, deps in guide_deps.items():
+                # -logq cost term
+                cost = -guide_trace[name]["log_prob"]
+                scale = guide_trace[name]["scale"]
+                dice_factors = [guide_trace[key]["log_measure"] for key in deps]
+                if dice_factors:
+                    dice_factor = reduce(lambda a, b: a + b, dice_factors)
+                    cost = cost * funsor.ops.exp(dice_factor)
+                if (scale is not None) and (not is_identically_one(scale)):
+                    cost = cost * to_funsor(scale)
+                elbo = elbo + cost.reduce(funsor.ops.add)
+
+            return to_data(elbo)
+
+        # Return (-elbo) since by convention we do gradient descent on a loss and
+        # the ELBO is a lower bound that needs to be maximized.
+        if self.num_particles == 1:
+            return -single_particle_elbo(rng_key)
+        else:
+            rng_keys = random.split(rng_key, self.num_particles)
+            return -jnp.mean(vmap(single_particle_elbo)(rng_keys))
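For context, a minimal end-to-end sketch of how this loss might be used (the two-component mixture model, guide, and data below are invented for illustration; `config_enumerate` marks the discrete site for parallel enumeration):

import jax.numpy as jnp
from jax import random
import numpyro
import numpyro.distributions as dist
from numpyro.contrib.funsor import config_enumerate
from numpyro.infer import SVI, TraceEnum_ELBO
from numpyro.optim import Adam

@config_enumerate
def model(data):
    probs = numpyro.sample("probs", dist.Dirichlet(jnp.ones(2)))
    locs = numpyro.sample("locs", dist.Normal(0.0, 10.0).expand([2]).to_event(1))
    with numpyro.plate("data", len(data)):
        # discrete assignment "z" is enumerated out rather than sampled
        z = numpyro.sample("z", dist.Categorical(probs))
        numpyro.sample("obs", dist.Normal(locs[z], 1.0), obs=data)

def guide(data):
    alpha = numpyro.param("alpha", jnp.ones(2), constraint=dist.constraints.positive)
    numpyro.sample("probs", dist.Dirichlet(alpha))
    loc = numpyro.param("loc", jnp.zeros(2))
    numpyro.sample("locs", dist.Normal(loc, 1.0).to_event(1))

data = jnp.array([-1.0, 0.0, 5.0, 6.0])
svi = SVI(model, guide, Adam(0.01), TraceEnum_ELBO(max_plate_nesting=1))
svi_result = svi.run(random.PRNGKey(0), 200, data)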
