Skip to content

Commit 32e5a0d

Browse files
committed
[CLN]
1 parent 0a151ed commit 32e5a0d

File tree

2 files changed

+37
-83
lines changed

2 files changed

+37
-83
lines changed

gempy_probability/api/model_runner/_pyro_runner.py

Lines changed: 6 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,25 +1,14 @@
1-
import numpy as np
2-
import os
3-
import pyro
4-
import pyro.distributions as dist
5-
import torch
6-
from arviz import InferenceData
7-
from pyro.distributions import Distribution
8-
9-
import gempy as gp
10-
from gempy_engine.core.backend_tensor import BackendTensor
11-
import gempy_probability as gpp
12-
from gempy_engine.core.data.interpolation_input import InterpolationInput
13-
from pyro.infer import Predictive
14-
import pyro
151
import arviz as az
162
import matplotlib.pyplot as plt
17-
18-
from pyro.infer import NUTS
3+
import torch
194
from pyro.infer import MCMC
5+
from pyro.infer import NUTS
6+
from pyro.infer import Predictive
207
from pyro.infer.autoguide import init_to_mean
218

22-
from gempy_probability.core.samplers_data import NUTSConfig
9+
import gempy as gp
10+
import gempy_probability as gpp
11+
from ...core.samplers_data import NUTSConfig
2312

2413

2514
def run_predictive(prob_model: gpp.GemPyPyroModel, geo_model: gp.data.GeoModel,

tests/test_prob_model/test_prob_factory.py

Lines changed: 31 additions & 66 deletions
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,12 @@
44
import pyro.distributions as dist
55
import torch
66
from pyro.distributions import Distribution
7+
import arviz as az
8+
import matplotlib.pyplot as plt
79

810
import gempy as gp
9-
from gempy_engine.core.backend_tensor import BackendTensor
1011
import gempy_probability as gpp
12+
from gempy_engine.core.backend_tensor import BackendTensor
1113
from gempy_engine.core.data.interpolation_input import InterpolationInput
1214
from gempy_probability.core.samplers_data import NUTSConfig
1315

@@ -29,9 +31,6 @@ def test_prob_model_factory() -> None:
2931
)
3032
)
3133

32-
# TODO: Make this a more elegant way
33-
BackendTensor.change_backend_gempy(engine_backend=gp.data.AvailableBackends.PYTORCH)
34-
3534
# TODO: Convert this into an options preset
3635
geo_model.interpolation_options.uni_degree = 0
3736
geo_model.interpolation_options.mesh_extraction = False
@@ -65,15 +64,36 @@ def test_prob_model_factory() -> None:
6564
priors=model_priors,
6665
set_interp_input_fn=modify_z_for_surface_point1,
6766
likelihood_fn=gpp.likelihoods.thickness_likelihood,
68-
obs_name="obs_gravity"
67+
obs_name="wells_thickness"
6968
)
7069

71-
_prob_run(
70+
data: az.InferenceData = _prob_run(
7271
geo_model=geo_model,
7372
prob_model=pyro_gravity_model,
7473
y_obs_list=torch.tensor([200, 210, 190])
7574
)
7675

76+
if False: # * Save the arviz data
77+
data.to_netcdf("arviz_data.nc")
78+
79+
_plot(data)
80+
81+
# Test posterior mean values
82+
posterior_top_mean = float(data.posterior[r'$\mu_{top}$'].mean())
83+
target_top = 0.00875
84+
target_thickness = 223
85+
assert abs(posterior_top_mean - target_top) < 0.0200, f"Top layer mean {posterior_top_mean} outside expected range"
86+
posterior_thickness_mean = float(data.posterior_predictive[r'$\mu_{thickness}$'].mean())
87+
assert abs(posterior_thickness_mean - target_thickness) < 5, f"Thickness mean {posterior_thickness_mean} outside expected range"
88+
# Test convergence diagnostics
89+
assert float(data.sample_stats.diverging.sum()) == 0, "MCMC sampling has divergences"
90+
91+
print("Posterior mean values:")
92+
print(f"Top layer mean: {posterior_top_mean}")
93+
print(f"Thickness mean: {posterior_thickness_mean}")
94+
print("MCMC convergence diagnostics:")
95+
print(f"Divergences: {float(data.sample_stats.diverging.sum())}")
96+
7797

7898
def modify_z_for_surface_point1(
7999
samples: dict[str, Distribution],
@@ -87,23 +107,15 @@ def modify_z_for_surface_point1(
87107
new_tensor: torch.Tensor = torch.index_put(
88108
input=interp_input.surface_points.sp_coords,
89109
indices=(torch.tensor([0]), torch.tensor([2])), # * This has to be Tensors
90-
values=(samples[prior_key])
91-
)
110+
values=(samples[prior_key]
111+
)
92112
interp_input.surface_points.sp_coords = new_tensor
93113
return interp_input
94114

95115

96116
def _prob_run(geo_model: gp.data.GeoModel, prob_model: gpp.GemPyPyroModel,
97-
y_obs_list: torch.Tensor) -> None:
117+
y_obs_list: torch.Tensor) -> az.InferenceData:
98118
# Run prior sampling and visualization
99-
from pyro.infer import Predictive
100-
import pyro
101-
import arviz as az
102-
import matplotlib.pyplot as plt
103-
104-
from pyro.infer import NUTS
105-
from pyro.infer import MCMC
106-
from pyro.infer.autoguide import init_to_mean
107119

108120
# region prior sampling
109121
prior_inference = gpp.run_predictive(
@@ -133,42 +145,12 @@ def _prob_run(geo_model: gp.data.GeoModel, prob_model: gpp.GemPyPyroModel,
133145
plot_trace=False,
134146
run_posterior_predictive=True
135147
)
136-
# pyro.primitives.enable_validation(is_validate=True)
137-
# nuts_kernel = NUTS(
138-
# prob_model,
139-
# step_size=0.0085,
140-
# adapt_step_size=True,
141-
# target_accept_prob=0.9,
142-
# max_tree_depth=10,
143-
# init_strategy=init_to_mean
144-
# )
145-
# mcmc = MCMC(
146-
# kernel=nuts_kernel,
147-
# num_samples=200,
148-
# warmup_steps=50,
149-
# disable_validation=False
150-
# )
151-
# mcmc.run(geo_model, y_obs_list)
152-
# posterior_samples = mcmc.get_samples()
153-
# posterior_predictive_fn = Predictive(
154-
# model=prob_model,
155-
# posterior_samples=posterior_samples
156-
# )
157-
# posterior_predictive = posterior_predictive_fn(geo_model, y_obs_list)
158-
#
159-
# data = az.from_pyro(
160-
# posterior=mcmc,
161-
# # prior=prior.to_dict()["prior"],
162-
# posterior_predictive=posterior_predictive
163-
# )
164148
data.extend(prior_inference)
149+
return data
165150
# endregion
166151

167-
# print("Number of interpolations: ", geo_model.counter)
168-
169-
if False: # * Save the arviz data
170-
data.to_netcdf("arviz_data.nc")
171152

153+
def _plot(data):
172154
az.plot_trace(data)
173155
plt.show()
174156
from gempy_probability.modules.plot.plot_posterior import default_red, default_blue
@@ -188,7 +170,6 @@ def _prob_run(geo_model: gp.data.GeoModel, prob_model: gpp.GemPyPyroModel,
188170
colors=[default_red, default_blue],
189171
)
190172
plt.show()
191-
192173
az.plot_density(
193174
data=[data.posterior_predictive, data.prior],
194175
shade=.9,
@@ -197,19 +178,3 @@ def _prob_run(geo_model: gp.data.GeoModel, prob_model: gpp.GemPyPyroModel,
197178
colors=[default_red, default_blue],
198179
)
199180
plt.show()
200-
201-
# Test posterior mean values
202-
posterior_top_mean = float(data.posterior[r'$\mu_{top}$'].mean())
203-
target_top = 0.00875
204-
target_thickness = 223
205-
assert abs(posterior_top_mean - target_top) < 0.0200, f"Top layer mean {posterior_top_mean} outside expected range"
206-
posterior_thickness_mean = float(data.posterior_predictive[r'$\mu_{thickness}$'].mean())
207-
assert abs(posterior_thickness_mean - target_thickness) < 5, f"Thickness mean {posterior_thickness_mean} outside expected range"
208-
# Test convergence diagnostics
209-
assert float(data.sample_stats.diverging.sum()) == 0, "MCMC sampling has divergences"
210-
211-
print("Posterior mean values:")
212-
print(f"Top layer mean: {posterior_top_mean}")
213-
print(f"Thickness mean: {posterior_thickness_mean}")
214-
print("MCMC convergence diagnostics:")
215-
print(f"Divergences: {float(data.sample_stats.diverging.sum())}")

0 commit comments

Comments
 (0)