|
1 | 1 | # Copyright Contributors to the Pyro project.
|
2 | 2 | # SPDX-License-Identifier: Apache-2.0
|
3 | 3 |
|
4 |
| -from collections import defaultdict |
5 |
| -from functools import partial |
| 4 | +from collections import OrderedDict, defaultdict |
| 5 | +from functools import partial, reduce |
6 | 6 | from operator import itemgetter
|
7 | 7 | import warnings
|
8 | 8 |
|
|
11 | 11 | import jax.numpy as jnp
|
12 | 12 | from jax.scipy.special import logsumexp
|
13 | 13 |
|
| 14 | +from numpyro.distributions import ExpandedDistribution, MaskedDistribution |
14 | 15 | from numpyro.distributions.kl import kl_divergence
|
15 | 16 | from numpyro.distributions.util import scale_and_mask
|
16 | 17 | from numpyro.handlers import Messenger, replay, seed, substitute, trace
|
17 |
| -from numpyro.infer.util import get_importance_trace, log_density |
| 18 | +from numpyro.infer.util import ( |
| 19 | + _without_rsample_stop_gradient, |
| 20 | + get_importance_trace, |
| 21 | + is_identically_one, |
| 22 | + log_density, |
| 23 | +) |
18 | 24 | from numpyro.ops.provenance import eval_provenance, get_provenance
|
19 | 25 | from numpyro.util import _validate_model, check_model_guide_match, find_stack_level
|
20 | 26 |
|
@@ -710,3 +716,294 @@ def single_particle_elbo(rng_key):
|
710 | 716 | else:
|
711 | 717 | rng_keys = random.split(rng_key, self.num_particles)
|
712 | 718 | return -jnp.mean(vmap(single_particle_elbo)(rng_keys))
|
| 719 | + |
| 720 | + |
def get_importance_trace_enum(model, guide, args, kwargs, params, max_plate_nesting):
    """
    (EXPERIMENTAL) Trace the guide under enumeration, replay the model against
    that guide trace, and annotate the sample sites of both traces with funsor
    terms.

    Every sample site that does not already carry a ``"log_prob"`` entry gets
    one. Among those sites, the ones still treated as measures (``is_measure``
    not set to ``False``) additionally get a ``"log_measure"`` entry.

    :param model: model callable, invoked as ``model(*args, **kwargs)``.
    :param guide: guide callable, invoked as ``guide(*args, **kwargs)``.
    :param args: positional arguments forwarded to the guide and the model.
    :param kwargs: keyword arguments forwarded to the guide and the model.
    :param params: values handed to ``substitute`` (as ``data``) for both the
        guide and the model.
    :param max_plate_nesting: when truthy, ``-max_plate_nesting - 1`` is used
        as ``first_available_dim`` of the ``enum`` handler; otherwise ``None``
        is passed.
    :return: tuple ``(model_trace, guide_trace)``, each restricted to sites of
        type ``"sample"``.
    """
    import funsor
    from numpyro.contrib.funsor import (
        enum,
        plate_to_enum_plate,
        to_funsor,
        trace as _trace,
    )

    first_dim = (-max_plate_nesting - 1) if max_plate_nesting else None
    with plate_to_enum_plate(), enum(first_available_dim=first_dim):
        guide = substitute(guide, data=params)
        with _without_rsample_stop_gradient():
            guide_trace = _trace(guide).get_trace(*args, **kwargs)
        # The model is replayed against the guide trace before tracing.
        model = substitute(replay(model, guide_trace), data=params)
        model_trace = _trace(model).get_trace(*args, **kwargs)

    # Only sample sites are kept in the returned traces.
    guide_trace = {
        key: entry for key, entry in guide_trace.items() if entry["type"] == "sample"
    }
    model_trace = {
        key: entry for key, entry in model_trace.items() if entry["type"] == "sample"
    }

    # The guide trace is annotated first, then the model trace.
    for current_trace, from_model in ((guide_trace, False), (model_trace, True)):
        for site in current_trace.values():
            # Model sites that are observed or that also appear in the guide
            # are flagged as not being measures.
            if from_model and (site["is_observed"] or site["name"] in guide_trace):
                site["is_measure"] = False

            if "log_prob" in site:
                continue

            value = site["value"]
            intermediates = site["intermediates"]
            if intermediates:
                log_prob = site["fn"].log_prob(value, intermediates)
            else:
                log_prob = site["fn"].log_prob(value)

            dim_to_name = site["infer"]["dim_to_name"]
            site["log_prob"] = to_funsor(
                log_prob, output=funsor.Real, dim_to_name=dim_to_name
            )

            if not site.get("is_measure", True):
                continue

            # Get rid of masking: unwrap Masked/Expanded wrappers, then expand
            # the innermost distribution back to the wrapper's batch shape.
            base_fn = site["fn"]
            batch_shape = base_fn.batch_shape
            while isinstance(base_fn, (MaskedDistribution, ExpandedDistribution)):
                base_fn = base_fn.base_dist
            base_fn = base_fn.expand(batch_shape)

            if intermediates:
                log_measure = base_fn.log_prob(value, intermediates)
            else:
                log_measure = base_fn.log_prob(value)

            # Dice factor: unless the site is enumerated in parallel, subtract
            # the detached value from the log measure.
            if site["infer"].get("enumerate") != "parallel":
                log_measure = log_measure - funsor.ops.detach(log_measure)

            site["log_measure"] = to_funsor(
                log_measure, output=funsor.Real, dim_to_name=dim_to_name
            )

    return model_trace, guide_trace
| 784 | + |
| 785 | + |
| 786 | +def _partition(model_sum_deps, sum_vars): |
| 787 | + # Construct a bipartite graph between model_sum_deps and the sum_vars |
| 788 | + neighbors = OrderedDict([(t, []) for t in model_sum_deps.keys()]) |
| 789 | + for key, deps in model_sum_deps.items(): |
| 790 | + for dim in deps: |
| 791 | + if dim in sum_vars: |
| 792 | + neighbors[key].append(dim) |
| 793 | + neighbors.setdefault(dim, []).append(key) |
| 794 | + |
| 795 | + # Partition the bipartite graph into connected components for contraction. |
| 796 | + components = [] |
| 797 | + while neighbors: |
| 798 | + v, pending = neighbors.popitem() |
| 799 | + component = OrderedDict([(v, None)]) # used as an OrderedSet |
| 800 | + for v in pending: |
| 801 | + component[v] = None |
| 802 | + while pending: |
| 803 | + v = pending.pop() |
| 804 | + for v in neighbors.pop(v): |
| 805 | + if v not in component: |
| 806 | + component[v] = None |
| 807 | + pending.append(v) |
| 808 | + |
| 809 | + # Split this connected component into factors and measures. |
| 810 | + # Append only if component_factors is non-empty |
| 811 | + component_factors = frozenset(v for v in component if v not in sum_vars) |
| 812 | + if component_factors: |
| 813 | + component_measures = frozenset(v for v in component if v in sum_vars) |
| 814 | + components.append((component_factors, component_measures)) |
| 815 | + return components |
| 816 | + |
| 817 | + |
class TraceEnum_ELBO(ELBO):
    """
    A TraceEnum implementation of ELBO-based SVI. The gradient estimator
    is constructed along the lines of reference [1] specialized to the case
    of the ELBO. It supports arbitrary dependency structure for the model
    and guide.
    Fine-grained conditional dependency information as recorded in the
    trace is used to reduce the variance of the gradient estimator.
    In particular provenance tracking [2] is used to find the ``cost`` terms
    that depend on each non-reparameterizable sample site.
    Enumerated variables are eliminated using the TVE algorithm for plated
    factor graphs [3].

    **References**

    [1] `Storchastic: A Framework for General Stochastic Automatic Differentiation`,
        Emile van Krieken, Jakub M. Tomczak, Annette ten Teije

    [2] `Nonstandard Interpretations of Probabilistic Programs for Efficient Inference`,
        David Wingate, Noah Goodman, Andreas Stuhlmüller, Jeffrey Siskind

    [3] `Tensor Variable Elimination for Plated Factor Graphs`,
        Fritz Obermeyer, Eli Bingham, Martin Jankowiak, Justin Chiu,
        Neeraj Pradhan, Alexander M. Rush, Noah Goodman

    :param num_particles: forwarded to the base ``ELBO``; when it is not 1 the
        loss is the mean over that many single-particle estimates.
    :param max_plate_nesting: must be given explicitly (the default
        ``float("inf")`` is rejected). It is negated to obtain the first
        dimension available for enumeration, so it has to be a non-negative
        integer.
    """

    can_infer_discrete = True

    def __init__(self, num_particles=1, max_plate_nesting=float("inf")):
        """
        :raises ValueError: if ``max_plate_nesting`` is left at ``float("inf")``.
        """
        if max_plate_nesting == float("inf"):
            # The value is later used as ``-max_plate_nesting - 1``, hence the
            # requirement is non-negative (not non-positive).
            raise ValueError(
                "Currently, we require `max_plate_nesting` to be a non-negative integer."
            )
        self.max_plate_nesting = max_plate_nesting
        super().__init__(num_particles=num_particles)

    def loss(self, rng_key, param_map, model, guide, *args, **kwargs):
        """
        Evaluate the negative ELBO estimate.

        :param rng_key: random key; split once per particle into a model seed
            and a guide seed.
        :param param_map: parameter values passed on to
            ``get_importance_trace_enum`` and to the provenance evaluation.
        :param model: model callable.
        :param guide: guide callable.
        :param args: positional arguments for the model and the guide.
        :param kwargs: keyword arguments for the model and the guide.
        :return: negative of the ELBO estimate (mean over particles when
            ``num_particles`` is not 1).
        :raises ValueError: if sites contracted together disagree on the scale
            of a plate, or if model enumeration sits upstream of a
            parallel-enumerated guide site inside an eliminated plate.
        """

        def single_particle_elbo(rng_key):
            import funsor
            from numpyro.contrib.funsor import to_data, to_funsor

            def _weighted_cost(cost, scale, dice_factors):
                # Shared tail of every cost term: multiply by the exponential
                # of the summed dice factors (if any), multiply by the scale
                # (if it is set and not identically one), then sum out all
                # remaining inputs.
                if dice_factors:
                    dice_factor = reduce(lambda a, b: a + b, dice_factors)
                    cost = cost * funsor.ops.exp(dice_factor)
                if (scale is not None) and (not is_identically_one(scale)):
                    cost = cost * to_funsor(scale)
                return cost.reduce(funsor.ops.add)

            model_seed, guide_seed = random.split(rng_key)
            seeded_model = seed(model, model_seed)
            seeded_guide = seed(guide, guide_seed)

            model_trace, guide_trace = get_importance_trace_enum(
                seeded_model,
                seeded_guide,
                args,
                kwargs,
                param_map,
                self.max_plate_nesting,
            )
            check_model_guide_match(model_trace, guide_trace)
            _validate_model(model_trace, plate_warning="strict")

            # Find dependencies on non-reparameterizable sample sites for
            # each cost term in the model and the guide.
            model_deps, guide_deps = get_provenance(
                eval_provenance(
                    partial(
                        track_nonreparam(get_importance_log_probs),
                        seeded_model,
                        seeded_guide,
                        args,
                        kwargs,
                        param_map,
                    )
                )
            )

            # Model sites still flagged as measures are the variables that
            # get summed out.
            sum_vars = frozenset(
                name
                for name, site in model_trace.items()
                if site.get("is_measure", True)
            )
            # Split every remaining cost term's dependencies into the part
            # that is summed out and the part that is not.
            model_sum_deps = {
                k: v & sum_vars for k, v in model_deps.items() if k not in sum_vars
            }
            model_deps = {
                k: v - sum_vars for k, v in model_deps.items() if k not in sum_vars
            }

            elbo = 0.0
            for group_names, group_sum_vars in _partition(model_sum_deps, sum_vars):
                if not group_sum_vars:
                    # uncontracted logp cost term
                    assert len(group_names) == 1
                    name = next(iter(group_names))
                    cost = model_trace[name]["log_prob"]
                    scale = model_trace[name]["scale"]
                    deps = model_deps[name]
                    dice_factors = [guide_trace[key]["log_measure"] for key in deps]
                else:
                    # compute contracted cost term
                    group_factors = tuple(
                        model_trace[name]["log_prob"] for name in group_names
                    )
                    group_factors += tuple(
                        model_trace[var]["log_measure"] for var in group_sum_vars
                    )
                    group_factor_vars = frozenset().union(
                        *(f.inputs for f in group_factors)
                    )
                    # Inputs that are not model sample sites are treated as plates.
                    group_plates = group_factor_vars - frozenset(model_trace)
                    # Plates shared by every factor are kept; the others are
                    # eliminated together with the sum variables.
                    outermost_plates = frozenset.intersection(
                        *(frozenset(f.inputs) & group_plates for f in group_factors)
                    )
                    elim_plates = group_plates - outermost_plates
                    cost = funsor.sum_product.sum_product(
                        funsor.ops.logaddexp,
                        funsor.ops.add,
                        group_factors,
                        plates=group_plates,
                        eliminate=group_sum_vars | elim_plates,
                    )
                    # incorporate the effects of subsampling and handlers.scale
                    # through a common scale factor
                    group_scales = {}
                    for name in group_names:
                        plate_scales = model_trace[name].get("plate_to_scale", {})
                        for plate, value in plate_scales.items():
                            if plate in group_scales:
                                if value != group_scales[plate]:
                                    raise ValueError(
                                        "Expected all enumerated sample sites to share a common scale factor, "
                                        f"but found different scales at plate('{plate}')."
                                    )
                            else:
                                group_scales[plate] = value
                    scale = (
                        reduce(lambda a, b: a * b, group_scales.values())
                        if group_scales
                        else None
                    )
                    # combine deps
                    deps = frozenset().union(
                        *(model_deps[name] for name in group_names)
                    )
                    # check model guide enumeration constraint
                    for key in deps:
                        site = guide_trace[key]
                        if site["infer"].get("enumerate") != "parallel":
                            continue
                        offending = frozenset(site["log_measure"].inputs) & elim_plates
                        if offending:
                            plate = next(iter(offending))
                            raise ValueError(
                                "Expected model enumeration to be no more global than guide enumeration, but found "
                                f"model enumeration sites upstream of guide site '{key}' in plate('{plate}'). "
                                "Try converting some model enumeration sites to guide enumeration sites."
                            )
                    # combine dice factors, summing each one over the plates
                    # that were eliminated from the cost
                    dice_factors = [
                        guide_trace[key]["log_measure"].reduce(
                            funsor.ops.add,
                            frozenset(guide_trace[key]["log_measure"].inputs)
                            & elim_plates,
                        )
                        for key in deps
                    ]

                elbo = elbo + _weighted_cost(cost, scale, dice_factors)

            for name, deps in guide_deps.items():
                # -logq cost term
                cost = -guide_trace[name]["log_prob"]
                scale = guide_trace[name]["scale"]
                dice_factors = [guide_trace[key]["log_measure"] for key in deps]
                elbo = elbo + _weighted_cost(cost, scale, dice_factors)

            return to_data(elbo)

        # Return (-elbo) since by convention we do gradient descent on a loss and
        # the ELBO is a lower bound that needs to be maximized.
        if self.num_particles == 1:
            return -single_particle_elbo(rng_key)
        else:
            rng_keys = random.split(rng_key, self.num_particles)
            return -jnp.mean(vmap(single_particle_elbo)(rng_keys))
0 commit comments