Skip to content

Commit 4dacef7

Browse files
authored
Add probabilistic model examples and tests for GemPy (#5)
# [DOC] Adding development roadmap and probabilistic modeling enhancements This PR introduces a comprehensive development roadmap and implements probabilistic modeling capabilities with Pyro integration. ## Key Changes: - Added development log for May 2025 with prioritized tasks and ongoing work - Created a modular probabilistic model definition structure in `gempy_probability/modules/model_definition/` - Implemented a reusable Pyro model for geological layer thickness estimation - Added test infrastructure for probabilistic modeling with MCMC sampling - Implemented visualization of prior and posterior distributions using ArviZ - Added robust validation with convergence diagnostics and posterior checks The probabilistic workflow now supports sampling from priors, MCMC inference with NUTS, and comprehensive posterior analysis. The implementation includes proper separation of model definition, likelihood computation, and inference procedures.
2 parents 56afdb3 + 6f2f591 commit 4dacef7

File tree

5 files changed

+251
-0
lines changed

5 files changed

+251
-0
lines changed

docs/dev_logs/2025_May.md

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
# TODO:
2+
3+
--- General
4+
- [ ] Create a proper suit of tests
5+
- [ ] Keep extracting code into modules and functions
6+
- [ ] Posterior analysis for gempy
7+
- [ ] Investigate speed
8+
- [ ] Investigate segmentation function
9+
- [ ] Create a robust api to make sure we do not mess up assigning priors to gempy parameters
10+
- [ ] do we need to use index_put or just normal indexing?
11+
--- Vector examples
12+
- [ ] Bring the Vector example here...
13+
- [ ] Using categories as likelihood function (maybe I need to check out with Sam)
14+
--- MinEye
15+
- [ ] Revive magnetics
16+
- [ ] Add Simon's example
17+
18+
### Doing:
19+
- [ ] Compressing code
20+
- [ ] GemPy posterior visualization
21+
22+
### Possible Design:
23+
- Geological Model
24+
- Priors
25+
- Likelihood
26+
----
27+
- Having a runner or something like that?
Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,80 @@
1+
import gempy as gp
2+
import gempy.core.data
3+
import gempy_engine
4+
from gempy.modules.data_manipulation.engine_factory import interpolation_input_from_structural_frame
5+
6+
import pyro
7+
import pyro.distributions as dist
8+
import torch
9+
10+
11+
def model(geo_model: gempy.core.data.GeoModel, normal, y_obs_list):
12+
"""
13+
This Pyro model represents the probabilistic aspects of the geological model.
14+
It defines a prior distribution for the top layer's location and
15+
computes the thickness of the geological layer as an observed variable.
16+
"""
17+
# Define prior for the top layer's location:
18+
# region Prior definition
19+
20+
mu_top = pyro.sample(
21+
name=r'$\mu_{top}$',
22+
fn=normal
23+
)
24+
# endregion
25+
26+
# region Prior injection into the gempy model
27+
28+
# * Update the model with the new top layer's location
29+
interpolation_input = interpolation_input_from_structural_frame(geo_model)
30+
31+
if True: # ?? I need to figure out if we need the index_put or not
32+
indices__ = (torch.tensor([0]), torch.tensor([2])) # * This has to be Tensors
33+
new_tensor: torch.Tensor = torch.index_put(
34+
input=interpolation_input.surface_points.sp_coords,
35+
indices=indices__,
36+
values=mu_top
37+
)
38+
interpolation_input.surface_points.sp_coords = new_tensor
39+
else:
40+
interpolation_input.surface_points.sp_coords[0, 2] = mu_top
41+
42+
# endregion
43+
44+
# region Forward model computation
45+
46+
# * Compute the geological model
47+
geo_model.solutions = gempy_engine.compute_model(
48+
interpolation_input=interpolation_input,
49+
options=geo_model.interpolation_options,
50+
data_descriptor=geo_model.input_data_descriptor,
51+
geophysics_input=geo_model.geophysics_input,
52+
)
53+
54+
# Compute and observe the thickness of the geological layer
55+
model_solutions: gp.data.Solutions = geo_model.solutions
56+
thickness = _likelihood(model_solutions)
57+
58+
# endregion
59+
60+
posterior_dist_normal = dist.Normal(thickness, 25)
61+
62+
# * This is used automagically by pyro to compute the log-probability
63+
y_thickness = pyro.sample(
64+
name=r'$y_{thickness}$',
65+
fn=posterior_dist_normal,
66+
obs=y_obs_list
67+
)
68+
69+
70+
def _likelihood(model_solutions: gp.data.Solutions) -> torch.Tensor:
71+
"""
72+
This function computes the thickness of the geological layer.
73+
"""
74+
simulated_well = model_solutions.octrees_output[0].last_output_center.custom_grid_values
75+
thickness = simulated_well.sum()
76+
pyro.deterministic(
77+
name=r'$\mu_{thickness}$',
78+
value=thickness.detach() # * This is only for az to track progress
79+
)
80+
return thickness

tests/test_prob_model/__init__.py

Whitespace-only changes.

tests/test_prob_model/test_prob_I.py

Lines changed: 144 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,144 @@
1+
import numpy as np
2+
import os
3+
import pyro.distributions as dist
4+
import torch
5+
6+
import gempy as gp
7+
from gempy_engine.core.backend_tensor import BackendTensor
8+
9+
10+
def test_basic_gempy_I() -> None:
11+
current_dir = os.path.dirname(os.path.abspath(__file__))
12+
data_path = os.path.abspath(os.path.join(current_dir, '..', '..', 'examples', 'tutorials', 'data'))
13+
geo_model = gp.create_geomodel(
14+
project_name='Wells',
15+
extent=[0, 12000, -500, 500, 0, 4000],
16+
refinement=3,
17+
importer_helper=gp.data.ImporterHelper(
18+
path_to_orientations=os.path.join(data_path, "2-layers", "2-layers_orientations.csv"),
19+
path_to_surface_points=os.path.join(data_path, "2-layers", "2-layers_surface_points.csv")
20+
)
21+
)
22+
23+
# TODO: Convert this into an options preset
24+
geo_model.interpolation_options.uni_degree = 0
25+
geo_model.interpolation_options.mesh_extraction = False
26+
geo_model.interpolation_options.sigmoid_slope = 1100.
27+
28+
# region Minimal grid for the specific likelihood function
29+
x_loc = 6000
30+
y_loc = 0
31+
z_loc = np.linspace(0, 4000, 100)
32+
xyz_coord = np.array([[x_loc, y_loc, z] for z in z_loc])
33+
gp.set_custom_grid(geo_model.grid, xyz_coord=xyz_coord)
34+
# endregion
35+
36+
geo_model.grid.active_grids = gp.data.Grid.GridTypes.CUSTOM
37+
assert geo_model.grid.values.shape[0] == 100, "Custom grid should have 100 cells"
38+
gp.compute_model(gempy_model=geo_model)
39+
40+
# TODO: Make this a more elegant way
41+
BackendTensor.change_backend_gempy(engine_backend=gp.data.AvailableBackends.PYTORCH)
42+
43+
normal = dist.Normal(
44+
loc=(geo_model.surface_points_copy_transformed.xyz[0, 2]),
45+
scale=torch.tensor(0.1, dtype=torch.float64)
46+
)
47+
48+
from gempy_probability.modules.model_definition.model_examples import model
49+
_prob_run(
50+
geo_model=geo_model,
51+
prob_model=model,
52+
normal=normal,
53+
y_obs_list=torch.tensor([200, 210, 190])
54+
)
55+
56+
57+
def _prob_run(geo_model: gp.data.GeoModel, prob_model: callable,
58+
normal: torch.distributions.Distribution, y_obs_list: torch.Tensor) -> None:
59+
# Run prior sampling and visualization
60+
from pyro.infer import Predictive
61+
import pyro
62+
import arviz as az
63+
import matplotlib.pyplot as plt
64+
65+
from pyro.infer import NUTS
66+
from pyro.infer import MCMC
67+
from pyro.infer.autoguide import init_to_mean
68+
69+
# region prior sampling
70+
predictive = Predictive(
71+
model=prob_model,
72+
num_samples=50
73+
)
74+
prior = predictive(geo_model, normal, y_obs_list)
75+
76+
data = az.from_pyro(prior=prior)
77+
az.plot_trace(data.prior)
78+
plt.show()
79+
80+
# endregion
81+
82+
# region inference
83+
pyro.primitives.enable_validation(is_validate=True)
84+
nuts_kernel = NUTS(
85+
prob_model,
86+
step_size=0.0085,
87+
adapt_step_size=True,
88+
target_accept_prob=0.9,
89+
max_tree_depth=10,
90+
init_strategy=init_to_mean
91+
)
92+
mcmc = MCMC(
93+
kernel=nuts_kernel,
94+
num_samples=200,
95+
warmup_steps=50,
96+
disable_validation=False
97+
)
98+
mcmc.run(geo_model, normal, y_obs_list)
99+
posterior_samples = mcmc.get_samples()
100+
posterior_predictive_fn = Predictive(
101+
model=prob_model,
102+
posterior_samples=posterior_samples
103+
)
104+
posterior_predictive = posterior_predictive_fn(geo_model, normal, y_obs_list)
105+
data = az.from_pyro(posterior=mcmc, prior=prior, posterior_predictive=posterior_predictive)
106+
# endregion
107+
108+
az.plot_trace(data)
109+
plt.show()
110+
from gempy_probability.modules.plot.plot_posterior import default_red, default_blue
111+
az.plot_density(
112+
data=[data.posterior_predictive, data.prior_predictive],
113+
shade=.9,
114+
var_names=[r'$\mu_{thickness}$'],
115+
data_labels=["Posterior Predictive", "Prior Predictive"],
116+
colors=[default_red, default_blue],
117+
)
118+
plt.show()
119+
az.plot_density(
120+
data=[data, data.prior],
121+
shade=.9,
122+
hdi_prob=.99,
123+
data_labels=["Posterior", "Prior"],
124+
colors=[default_red, default_blue],
125+
)
126+
plt.show()
127+
128+
# Test posterior mean values
129+
posterior_top_mean = float(data.posterior[r'$\mu_{top}$'].mean())
130+
target_top = 0.00875
131+
target_thickness = 223
132+
assert abs(posterior_top_mean - target_top) < 0.0200, f"Top layer mean {posterior_top_mean} outside expected range"
133+
posterior_thickness_mean = float(data.posterior_predictive[r'$\mu_{thickness}$'].mean())
134+
assert abs(posterior_thickness_mean - target_thickness) < 5, f"Thickness mean {posterior_thickness_mean} outside expected range"
135+
# Test convergence diagnostics
136+
assert float(data.sample_stats.diverging.sum()) == 0, "MCMC sampling has divergences"
137+
138+
print("Posterior mean values:")
139+
print(f"Top layer mean: {posterior_top_mean}")
140+
print(f"Thickness mean: {posterior_thickness_mean}")
141+
print("MCMC convergence diagnostics:")
142+
print(f"Divergences: {float(data.sample_stats.diverging.sum())}")
143+
print(f"Acceptance rate: {float(data.sample_stats.acceptance_rate.mean())}")
144+
print(f"Mean tree depth: {float(data.sample_stats.mean_tree_depth.mean())}")

0 commit comments

Comments
 (0)