Skip to content

Commit 0a151ed

Browse files
committed
[WIP/CLN/ENH] Extracting samplers to the api to make it easier
1 parent 4625dad commit 0a151ed

File tree

6 files changed

+129
-34
lines changed

6 files changed

+129
-34
lines changed

gempy_probability/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from .modules.model_definition.prob_model_factory import make_gempy_pyro_model, GemPyPyroModel
22
from .modules import likelihoods
3-
from .api.model_runner import run_predictive
3+
from .api.model_runner import run_predictive, run_mcmc_for_NUTS, run_nuts_inference
44

55
from ._version import __version__
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
from ._pyro_runner import run_predictive
1+
from ._pyro_runner import run_predictive, run_nuts_inference, run_mcmc_for_NUTS

gempy_probability/api/model_runner/_pyro_runner.py

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,8 @@
1919
from pyro.infer import MCMC
2020
from pyro.infer.autoguide import init_to_mean
2121

22+
from gempy_probability.core.samplers_data import NUTSConfig
23+
2224

2325
def run_predictive(prob_model: gpp.GemPyPyroModel, geo_model: gp.data.GeoModel,
2426
y_obs_list: torch.Tensor, n_samples: int, plot_trace:bool=False) -> az.InferenceData:
@@ -35,3 +37,65 @@ def run_predictive(prob_model: gpp.GemPyPyroModel, geo_model: gp.data.GeoModel,
3537
plt.show()
3638

3739
return data
40+
41+
def run_nuts_inference(prob_model: gpp.GemPyPyroModel, geo_model: gp.data.GeoModel,
42+
y_obs_list: torch.Tensor, config: NUTSConfig, plot_trace:bool=False,
43+
run_posterior_predictive:bool=False) -> az.InferenceData:
44+
45+
nuts_kernel = NUTS(
46+
prob_model,
47+
step_size=config.step_size,
48+
adapt_step_size=config.adapt_step_size,
49+
target_accept_prob=config.target_accept_prob,
50+
max_tree_depth=config.max_tree_depth,
51+
init_strategy=init_to_mean
52+
)
53+
data = run_mcmc_for_NUTS(
54+
geo_model=geo_model,
55+
nuts_kernel=nuts_kernel,
56+
prob_model=prob_model,
57+
y_obs_list=y_obs_list,
58+
num_samples=config.num_samples,
59+
warmup_steps=config.warmup_steps,
60+
plot_trace=plot_trace,
61+
run_posterior_predictive=run_posterior_predictive,
62+
)
63+
64+
return data
65+
66+
67+
def run_mcmc_for_NUTS(
68+
geo_model: gp.data.GeoModel,
69+
nuts_kernel: NUTS,
70+
prob_model: gpp.GemPyPyroModel,
71+
y_obs_list: torch.Tensor,
72+
num_samples: int,
73+
warmup_steps: int,
74+
plot_trace: bool = False,
75+
run_posterior_predictive: bool = False,
76+
) -> az.InferenceData:
77+
mcmc = MCMC(
78+
kernel=nuts_kernel,
79+
num_samples=num_samples,
80+
warmup_steps=warmup_steps,
81+
disable_validation=False,
82+
)
83+
mcmc.run(geo_model, y_obs_list)
84+
85+
if run_posterior_predictive:
86+
posterior_predictive = Predictive(
87+
model=prob_model,
88+
posterior_samples=mcmc.get_samples(),
89+
)(geo_model, y_obs_list)
90+
data: az.InferenceData = az.from_pyro(
91+
posterior=mcmc,
92+
posterior_predictive=posterior_predictive,
93+
)
94+
else:
95+
data: az.InferenceData = az.from_pyro(posterior=mcmc)
96+
97+
if plot_trace:
98+
az.plot_trace(data)
99+
plt.show()
100+
101+
return data

gempy_probability/core/samplers_data.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33

44
@dataclasses.dataclass
5-
class NUTS_Args:
5+
class NUTSConfig:
66
step_size: float = 0.0085
77
adapt_step_size: bool = True
88
target_accept_prob: float = 0.9

gempy_probability/modules/forwards/_gempy_fw.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55

66

77
def run_gempy_forward(interp_input: InterpolationInput, geo_model: gp.data.GeoModel) -> gp.data.Solutions:
8-
BackendTensor.change_backend_gempy(engine_backend=gp.data.AvailableBackends.PYTORCH)
8+
99
# compute model
1010
sol = gempy_engine.compute_model(
1111
interpolation_input=interp_input,

tests/test_prob_model/test_prob_factory.py

Lines changed: 61 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -9,11 +9,13 @@
99
from gempy_engine.core.backend_tensor import BackendTensor
1010
import gempy_probability as gpp
1111
from gempy_engine.core.data.interpolation_input import InterpolationInput
12+
from gempy_probability.core.samplers_data import NUTSConfig
1213

1314
seed = 123456
1415
torch.manual_seed(seed)
1516
pyro.set_rng_seed(seed)
1617

18+
1719
def test_prob_model_factory() -> None:
1820
current_dir = os.path.dirname(os.path.abspath(__file__))
1921
data_path = os.path.abspath(os.path.join(current_dir, '..', '..', 'examples', 'tutorials', 'data'))
@@ -71,19 +73,20 @@ def test_prob_model_factory() -> None:
7173
prob_model=pyro_gravity_model,
7274
y_obs_list=torch.tensor([200, 210, 190])
7375
)
74-
76+
77+
7578
def modify_z_for_surface_point1(
7679
samples: dict[str, Distribution],
7780
geo_model: gp.data.GeoModel,
7881
) -> InterpolationInput:
7982
# TODO: We can make a factory for this type of functions
8083
prior_key = r'$\mu_{top}$'
81-
84+
8285
from gempy.modules.data_manipulation import interpolation_input_from_structural_frame
8386
interp_input = interpolation_input_from_structural_frame(geo_model)
8487
new_tensor: torch.Tensor = torch.index_put(
8588
input=interp_input.surface_points.sp_coords,
86-
indices=(torch.tensor([0]), torch.tensor([2])), # * This has to be Tensors
89+
indices=(torch.tensor([0]), torch.tensor([2])), # * This has to be Tensors
8790
values=(samples[prior_key])
8891
)
8992
interp_input.surface_points.sp_coords = new_tensor
@@ -103,7 +106,7 @@ def _prob_run(geo_model: gp.data.GeoModel, prob_model: gpp.GemPyPyroModel,
103106
from pyro.infer.autoguide import init_to_mean
104107

105108
# region prior sampling
106-
prior = gpp.run_predictive(
109+
prior_inference = gpp.run_predictive(
107110
prob_model=prob_model,
108111
geo_model=geo_model,
109112
y_obs_list=y_obs_list,
@@ -114,30 +117,51 @@ def _prob_run(geo_model: gp.data.GeoModel, prob_model: gpp.GemPyPyroModel,
114117
# endregion
115118

116119
# region inference
117-
pyro.primitives.enable_validation(is_validate=True)
118-
nuts_kernel = NUTS(
119-
prob_model,
120-
step_size=0.0085,
121-
adapt_step_size=True,
122-
target_accept_prob=0.9,
123-
max_tree_depth=10,
124-
init_strategy=init_to_mean
125-
)
126-
mcmc = MCMC(
127-
kernel=nuts_kernel,
128-
num_samples=200,
129-
warmup_steps=50,
130-
disable_validation=False
131-
)
132-
mcmc.run(geo_model, y_obs_list)
133-
posterior_samples = mcmc.get_samples()
134-
posterior_predictive_fn = Predictive(
135-
model=prob_model,
136-
posterior_samples=posterior_samples
120+
data = gpp.run_nuts_inference(
121+
prob_model=prob_model,
122+
geo_model=geo_model,
123+
y_obs_list=y_obs_list,
124+
config=NUTSConfig(
125+
step_size=0.0085,
126+
adapt_step_size=True,
127+
target_accept_prob=0.9,
128+
max_tree_depth=10,
129+
init_strategy='auto',
130+
num_samples=200,
131+
warmup_steps=50,
132+
),
133+
plot_trace=False,
134+
run_posterior_predictive=True
137135
)
138-
posterior_predictive = posterior_predictive_fn(geo_model, y_obs_list)
139-
140-
data = az.from_pyro(posterior=mcmc, prior=prior, posterior_predictive=posterior_predictive)
136+
# pyro.primitives.enable_validation(is_validate=True)
137+
# nuts_kernel = NUTS(
138+
# prob_model,
139+
# step_size=0.0085,
140+
# adapt_step_size=True,
141+
# target_accept_prob=0.9,
142+
# max_tree_depth=10,
143+
# init_strategy=init_to_mean
144+
# )
145+
# mcmc = MCMC(
146+
# kernel=nuts_kernel,
147+
# num_samples=200,
148+
# warmup_steps=50,
149+
# disable_validation=False
150+
# )
151+
# mcmc.run(geo_model, y_obs_list)
152+
# posterior_samples = mcmc.get_samples()
153+
# posterior_predictive_fn = Predictive(
154+
# model=prob_model,
155+
# posterior_samples=posterior_samples
156+
# )
157+
# posterior_predictive = posterior_predictive_fn(geo_model, y_obs_list)
158+
#
159+
# data = az.from_pyro(
160+
# posterior=mcmc,
161+
# # prior=prior.to_dict()["prior"],
162+
# posterior_predictive=posterior_predictive
163+
# )
164+
data.extend(prior_inference)
141165
# endregion
142166

143167
# print("Number of interpolations: ", geo_model.counter)
@@ -149,7 +173,7 @@ def _prob_run(geo_model: gp.data.GeoModel, prob_model: gpp.GemPyPyroModel,
149173
plt.show()
150174
from gempy_probability.modules.plot.plot_posterior import default_red, default_blue
151175
az.plot_density(
152-
data=[data.posterior_predictive, data.prior_predictive],
176+
data=[data.posterior_predictive, data.prior],
153177
shade=.9,
154178
var_names=[r'$\mu_{thickness}$'],
155179
data_labels=["Posterior Predictive", "Prior Predictive"],
@@ -165,6 +189,15 @@ def _prob_run(geo_model: gp.data.GeoModel, prob_model: gpp.GemPyPyroModel,
165189
)
166190
plt.show()
167191

192+
az.plot_density(
193+
data=[data.posterior_predictive, data.prior],
194+
shade=.9,
195+
hdi_prob=.99,
196+
data_labels=["Posterior", "Prior"],
197+
colors=[default_red, default_blue],
198+
)
199+
plt.show()
200+
168201
# Test posterior mean values
169202
posterior_top_mean = float(data.posterior[r'$\mu_{top}$'].mean())
170203
target_top = 0.00875
@@ -180,5 +213,3 @@ def _prob_run(geo_model: gp.data.GeoModel, prob_model: gpp.GemPyPyroModel,
180213
print(f"Thickness mean: {posterior_thickness_mean}")
181214
print("MCMC convergence diagnostics:")
182215
print(f"Divergences: {float(data.sample_stats.diverging.sum())}")
183-
184-

0 commit comments

Comments
 (0)