Skip to content

Commit 99b494b

Browse files
committed
[WIP] Developing interface for probabilistic models
1 parent 65021df commit 99b494b

File tree

4 files changed

+247
-5
lines changed

4 files changed

+247
-5
lines changed

docs/dev_logs/2025_May.md

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -46,12 +46,15 @@
4646
### Likelihood functions
4747
- ? Should I start to create a catalog of likelihood functions that goes with gempy?
4848
- This would be useful to be able to visualize whatever the likelihood function represents independent of the probabilistic model.
49-
-
49+
5050
### What would I want to have to run the StonePark example as smoothly as possible?
5151
1. [x] A way to save the model already with the correct **nugget effects** instead of having to "load nuggets" (TODO in gempy)
5252
2. [ ] Likelihood function plug and play
5353
3. [ ] Some sort of pre-analysis if the model looks good. Maybe some metrics about how long it takes to run the model.
5454
1. Condition number
5555
2. Acceptance rate of the posterior
5656
4. Easy pykeops setup using .env
57-
5. Some sort of easy set-up for having a reduced number of priors
57+
5. Some sort of easy set-up for having a reduced number of priors
58+
59+
60+
## Probabilistic Modeling definition

gempy_probability/modules/model_definition/model_examples.py

Lines changed: 56 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,14 @@
11
import gempy as gp
22
import gempy.core.data
33
import gempy_engine
4-
from gempy.modules.data_manipulation.engine_factory import interpolation_input_from_structural_frame
5-
from gempy_probability.modules.likelihoods import apparent_thickness_likelihood
4+
from gempy.modules.data_manipulation import interpolation_input_from_structural_frame
5+
from ..likelihoods import apparent_thickness_likelihood
66

77
import pyro
8-
import pyro.distributions as dist
98
import torch
109

10+
from gempy_probability.modules.model_definition.prob_model_factory import make_pyro_model
11+
1112

1213
def two_wells_prob_model_I(geo_model: gempy.core.data.GeoModel, normal, y_obs_list):
1314
"""
@@ -72,3 +73,55 @@ def two_wells_prob_model_I(geo_model: gempy.core.data.GeoModel, normal, y_obs_li
7273
)
7374

7475

76+
77+
import pyro.distributions as dist
78+
79+
# 1) Declare the priors of the probabilistic model.
# Keys are pyro sample-site names; values are the prior distributions.
gravity_priors = dict(
    # e.g. density contrast of layer index 3
    mu_density=dist.Normal(2.62, 0.5),
)
84+
85+
# 2) Wrap the GemPy forward pass into a small function
def run_gempy_forward(samples, geo_model):
    """Run one GemPy forward simulation for a set of sampled parameters.

    Parameters
    ----------
    samples : dict
        Mapping of prior site names to sampled tensors; reads ``"mu_density"``.
    geo_model : gempy GeoModel
        Model whose structural frame provides the interpolation input.

    Returns
    -------
    torch.Tensor
        Apparent thickness derived from the computed solution.
    """
    sampled_value = samples["mu_density"]

    interp_input = interpolation_input_from_structural_frame(geo_model)

    # Overwrite one surface-point coordinate (row 0, column 2) with the
    # sampled value; index_put requires tensor indices.
    # NOTE(review): the prior is named "mu_density" but is written into
    # sp_coords — presumably a WIP placeholder; confirm the intent.
    target_idx = (torch.tensor([0]), torch.tensor([2]))  # * This has to be Tensors
    interp_input.surface_points.sp_coords = torch.index_put(
        input=interp_input.surface_points.sp_coords,
        indices=target_idx,
        values=sampled_value
    )

    # compute model with the perturbed surface points
    solution = gempy_engine.compute_model(
        interpolation_input=interp_input,
        options=geo_model.interpolation_options,
        data_descriptor=geo_model.input_data_descriptor,
        geophysics_input=geo_model.geophysics_input
    )

    return apparent_thickness_likelihood(solution)
110+
111+
# 3) Define your likelihood factory
def gravity_likelihood(simulated_gravity, covariance_matrix=None):
    """Build a multivariate-normal likelihood around the simulated gravity.

    Parameters
    ----------
    simulated_gravity : torch.Tensor
        Mean vector produced by the forward model.
    covariance_matrix : torch.Tensor, optional
        Observation covariance. The original code referenced an undefined
        module-level name ``covariance_matrix``, which raised a NameError on
        any call; it is now an explicit parameter defaulting to the identity.

    Returns
    -------
    pyro.distributions.MultivariateNormal
        Distribution over the observed gravity data.
    """
    if covariance_matrix is None:
        # Identity covariance as a safe default; callers with a precomputed
        # covariance matrix should pass it explicitly.
        covariance_matrix = torch.eye(
            simulated_gravity.shape[-1],
            dtype=simulated_gravity.dtype
        )
    return dist.MultivariateNormal(simulated_gravity, covariance_matrix)
115+
116+
def thickness_likelihood(simulated_thickness):
    """Return a Normal likelihood centred on the simulated thickness.

    The observation scale is a fixed constant (25, same units as the
    thickness values), e.g. standing in for a precomputed covariance.
    """
    observation_scale = 25
    return dist.Normal(loc=simulated_thickness, scale=observation_scale)
119+
120+
# 4) Build the model
# NOTE(review): despite the "gravity" name and the "obs_gravity" site, this
# wires in thickness_likelihood (gravity_likelihood above is unused) —
# presumably a WIP placeholder; confirm before relying on gravity semantics.
pyro_gravity_model = make_pyro_model(
    priors=gravity_priors,
    forward_fn=run_gempy_forward,
    likelihood_fn=thickness_likelihood,
    obs_name="obs_gravity"
)
127+
Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
import pyro
2+
import torch
3+
from typing import Callable, Dict
4+
from pyro.distributions import Distribution
5+
6+
def make_pyro_model(
        priors: Dict[str, Distribution],
        forward_fn: Callable[..., torch.Tensor],
        likelihood_fn: Callable[[torch.Tensor], Distribution],
        obs_name: str = "obs"
):
    """Create a Pyro model closure from priors, a forward pass and a likelihood.

    Parameters
    ----------
    priors : Dict[str, Distribution]
        Maps each sample-site name to its prior distribution.
    forward_fn : Callable
        ``forward_fn(samples_dict, geo_model)`` returning the simulated
        quantity as a ``torch.Tensor``.
    likelihood_fn : Callable
        ``likelihood_fn(simulated)`` returning a Pyro Distribution over data.
    obs_name : str
        Name of the observed sample site.

    Returns
    -------
    Callable
        ``model(geo_model, obs_data)`` suitable for Pyro inference.
    """
    def model(geo_model, obs_data):
        # 1) Draw one value per prior site.
        drawn = {name: pyro.sample(name, prior) for name, prior in priors.items()}

        # 2) Push the samples through the geological forward model.
        simulated = forward_fn(drawn, geo_model)

        # 3) Condition the likelihood on the observed data.
        pyro.sample(obs_name, likelihood_fn(simulated), obs=obs_data)

    return model
Lines changed: 154 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,154 @@
1+
import numpy as np
2+
import os
3+
import pyro.distributions as dist
4+
import torch
5+
6+
import gempy as gp
7+
from gempy_engine.core.backend_tensor import BackendTensor
8+
9+
def test_prob_model_factory() -> None:
    """End-to-end smoke test: build a two-layer GemPy model, switch to the
    PyTorch backend, and run the factory-built Pyro model through prior
    sampling + NUTS inference via ``_prob_run``."""
    # Load the 2-layer example data shipped under examples/tutorials/data.
    current_dir = os.path.dirname(os.path.abspath(__file__))
    data_path = os.path.abspath(os.path.join(current_dir, '..', '..', 'examples', 'tutorials', 'data'))
    geo_model = gp.create_geomodel(
        project_name='Wells',
        extent=[0, 12000, -500, 500, 0, 4000],
        refinement=3,
        importer_helper=gp.data.ImporterHelper(
            path_to_orientations=os.path.join(data_path, "2-layers", "2-layers_orientations.csv"),
            path_to_surface_points=os.path.join(data_path, "2-layers", "2-layers_surface_points.csv")
        )
    )

    # TODO: Make this a more elegant way
    BackendTensor.change_backend_gempy(engine_backend=gp.data.AvailableBackends.PYTORCH)

    # TODO: Convert this into an options preset
    geo_model.interpolation_options.uni_degree = 0
    geo_model.interpolation_options.mesh_extraction = False
    geo_model.interpolation_options.sigmoid_slope = 1100.

    # region Minimal grid for the specific likelihood function
    # Single vertical "well" at (6000, 0): 100 z-samples spanning the extent.
    x_loc = 6000
    y_loc = 0
    z_loc = np.linspace(0, 4000, 100)
    xyz_coord = np.array([[x_loc, y_loc, z] for z in z_loc])
    gp.set_custom_grid(geo_model.grid, xyz_coord=xyz_coord)
    # endregion

    geo_model.grid.active_grids = gp.data.Grid.GridTypes.CUSTOM
    assert geo_model.grid.values.shape[0] == 100, "Custom grid should have 100 cells"
    # geo_model.counter = 0
    gp.compute_model(
        gempy_model=geo_model,
        validate_serialization=False
    )
    # NOTE(review): the backend was already switched to PYTORCH above — this
    # second call looks redundant; confirm whether compute_model resets it.
    BackendTensor.change_backend_gempy(engine_backend=gp.data.AvailableBackends.PYTORCH)

    # Prior over the z-coordinate of the first transformed surface point.
    normal = dist.Normal(
        loc=(geo_model.surface_points_copy_transformed.xyz[0, 2]),
        scale=torch.tensor(0.1, dtype=torch.float64)
    )

    from gempy_probability.modules.model_definition.model_examples import two_wells_prob_model_I, pyro_gravity_model

    _prob_run(
        geo_model=geo_model,
        prob_model=pyro_gravity_model,
        normal=normal,
        y_obs_list=torch.tensor([200, 210, 190])
    )
60+
61+
62+
def _prob_run(geo_model: gp.data.GeoModel, prob_model: callable,
              normal: torch.distributions.Distribution, y_obs_list: torch.Tensor) -> None:
    """Run prior sampling, NUTS inference and posterior diagnostics for
    ``prob_model``, then assert posterior-mean targets and convergence."""
    # Run prior sampling and visualization
    from pyro.infer import Predictive
    import pyro
    import arviz as az
    import matplotlib.pyplot as plt

    from pyro.infer import NUTS
    from pyro.infer import MCMC
    from pyro.infer.autoguide import init_to_mean

    # region prior sampling
    predictive = Predictive(
        model=prob_model,
        num_samples=50
    )
    prior = predictive(geo_model, y_obs_list)
    # print("Number of interpolations: ", geo_model.counter)

    data = az.from_pyro(prior=prior)
    az.plot_trace(data.prior)
    plt.show()

    # endregion

    # region inference
    pyro.primitives.enable_validation(is_validate=True)
    nuts_kernel = NUTS(
        prob_model,
        step_size=0.0085,
        adapt_step_size=True,
        target_accept_prob=0.9,
        max_tree_depth=10,
        init_strategy=init_to_mean
    )
    mcmc = MCMC(
        kernel=nuts_kernel,
        num_samples=200,
        warmup_steps=50,
        disable_validation=False
    )
    # NOTE(review): the prior Predictive above calls prob_model with two args
    # (geo_model, y_obs_list), while mcmc.run and the posterior Predictive
    # below pass three (geo_model, normal, y_obs_list). A model built by
    # make_pyro_model accepts exactly (geo_model, obs_data) — one of these
    # call sites looks wrong; confirm the intended model signature.
    mcmc.run(geo_model, normal, y_obs_list)
    posterior_samples = mcmc.get_samples()
    posterior_predictive_fn = Predictive(
        model=prob_model,
        posterior_samples=posterior_samples
    )
    posterior_predictive = posterior_predictive_fn(geo_model, normal, y_obs_list)

    data = az.from_pyro(posterior=mcmc, prior=prior, posterior_predictive=posterior_predictive)
    # endregion

    # print("Number of interpolations: ", geo_model.counter)

    if True:  # * Save the arviz data
        data.to_netcdf("arviz_data.nc")

    az.plot_trace(data)
    plt.show()
    from gempy_probability.modules.plot.plot_posterior import default_red, default_blue
    az.plot_density(
        data=[data.posterior_predictive, data.prior_predictive],
        shade=.9,
        var_names=[r'$\mu_{thickness}$'],
        data_labels=["Posterior Predictive", "Prior Predictive"],
        colors=[default_red, default_blue],
    )
    plt.show()
    az.plot_density(
        data=[data, data.prior],
        shade=.9,
        hdi_prob=.99,
        data_labels=["Posterior", "Prior"],
        colors=[default_red, default_blue],
    )
    plt.show()

    # Test posterior mean values
    # NOTE(review): the expected values below (0.00875 top, 223 thickness)
    # are empirical regression targets — presumably from a reference run of
    # this exact model; confirm they still hold for pyro_gravity_model.
    posterior_top_mean = float(data.posterior[r'$\mu_{top}$'].mean())
    target_top = 0.00875
    target_thickness = 223
    assert abs(posterior_top_mean - target_top) < 0.0200, f"Top layer mean {posterior_top_mean} outside expected range"
    posterior_thickness_mean = float(data.posterior_predictive[r'$\mu_{thickness}$'].mean())
    assert abs(posterior_thickness_mean - target_thickness) < 5, f"Thickness mean {posterior_thickness_mean} outside expected range"
    # Test convergence diagnostics
    assert float(data.sample_stats.diverging.sum()) == 0, "MCMC sampling has divergences"

    print("Posterior mean values:")
    print(f"Top layer mean: {posterior_top_mean}")
    print(f"Thickness mean: {posterior_thickness_mean}")
    print("MCMC convergence diagnostics:")
    print(f"Divergences: {float(data.sample_stats.diverging.sum())}")

0 commit comments

Comments
 (0)