Skip to content

Commit 6e8a113

Browse files
committed
[ENH/CLN] Refactor probabilistic workflow and add validations
Refactored probabilistic model setup by modularizing `_prob_run` for clarity and reusability. Added posterior mean value checks, MCMC convergence diagnostics, and cleaned redundant comments to enhance readability and test robustness.
1 parent 7bced3e commit 6e8a113

File tree

2 files changed

+60
-44
lines changed

2 files changed

+60
-44
lines changed

gempy_probability/modules/model_definition/model_examples.py

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -26,12 +26,16 @@ def model(geo_model: gempy.core.data.GeoModel, normal, y_obs_list):
2626

2727
# * Update the model with the new top layer's location
2828
interpolation_input = interpolation_input_from_structural_frame(geo_model)
29-
interpolation_input.surface_points.sp_coords = torch.index_put(
30-
input=interpolation_input.surface_points.sp_coords,
31-
indices=(torch.tensor([0]), torch.tensor([2])),
32-
values=mu_top
33-
)
34-
# interpolation_input.surface_points.sp_coords[0, 2] = mu_top
29+
30+
if False: # ?? I need to figure out if we need the index_put or not
31+
indices__ = (torch.tensor([0]), torch.tensor([2])) # * This has to be Tensors
32+
interpolation_input.surface_points.sp_coords = torch.index_put(
33+
input=interpolation_input.surface_points.sp_coords,
34+
indices=indices__,
35+
values=mu_top
36+
)
37+
else:
38+
interpolation_input.surface_points.sp_coords[0, 2] = mu_top
3539

3640
# endregion
3741

@@ -50,13 +54,14 @@ def model(geo_model: gempy.core.data.GeoModel, normal, y_obs_list):
5054
thickness = simulated_well.sum()
5155
pyro.deterministic(
5256
name=r'$\mu_{thickness}$',
53-
value=thickness.detach()
57+
value=thickness.detach() # * This is only for az to track progress
5458
)
5559

5660
# endregion
5761

62+
posterior_dist_normal = dist.Normal(thickness, 25)
5863
y_thickness = pyro.sample(
5964
name=r'$y_{thickness}$',
60-
fn=dist.Normal(thickness, 25),
65+
fn=posterior_dist_normal,
6166
obs=y_obs_list
6267
)

tests/test_prob_model/test_prob_I.py

Lines changed: 47 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,11 @@
1+
import numpy as np
12
import os
3+
import pyro.distributions as dist
4+
import torch
5+
from pyro.distributions import TorchDistributionMixin
6+
27
import gempy as gp
3-
import gempy_engine
4-
import numpy as np
8+
from gempy_engine.core.backend_tensor import BackendTensor
59

610

711
def test_basic_gempy_I() -> None:
@@ -20,68 +24,72 @@ def test_basic_gempy_I() -> None:
2024
# TODO: Convert this into an options preset
2125
geo_model.interpolation_options.uni_degree = 0
2226
geo_model.interpolation_options.mesh_extraction = False
23-
geo_model.interpolation_options.sigmoid_slope = 1100.
27+
geo_model.interpolation_options.sigmoid_slope = 1100.
2428

29+
# region Minimal grid for the specific likelihood function
2530
x_loc = 6000
2631
y_loc = 0
2732
z_loc = np.linspace(0, 4000, 100)
2833
xyz_coord = np.array([[x_loc, y_loc, z] for z in z_loc])
2934
gp.set_custom_grid(geo_model.grid, xyz_coord=xyz_coord)
35+
# endregion
3036

3137
# TODO: Make sure only the custom grid is active
32-
38+
39+
3340
gp.compute_model(
3441
gempy_model=geo_model,
35-
engine_config=gp.data.GemPyEngineConfig(backend=gp.data.AvailableBackends.numpy)
42+
engine_config=gp.data.GemPyEngineConfig(
43+
backend=gp.data.AvailableBackends.numpy
44+
)
3645
)
37-
38-
from gempy_engine.core.backend_tensor import BackendTensor
39-
BackendTensor.change_backend_gempy(engine_backend=gp.data.AvailableBackends.PYTORCH)
4046

41-
import pyro.distributions as dist
42-
import torch
47+
# TODO: Do this in a more elegant way
48+
BackendTensor.change_backend_gempy(engine_backend=gp.data.AvailableBackends.PYTORCH)
4349

4450
normal = dist.Normal(
4551
loc=(geo_model.surface_points_copy_transformed.xyz[0, 2]),
4652
scale=torch.tensor(0.1, dtype=torch.float64)
4753
)
48-
# %%
49-
# Running Prior Sampling and Visualization
50-
# ----------------------------------------
51-
# Prior sampling is an essential step in probabilistic modeling.
52-
# It helps in understanding the distribution of our prior assumptions before observing any data.
53-
54-
# %%
55-
# Prepare observation data
56-
import torch
57-
y_obs_list = torch.tensor([200, 210, 190])
58-
59-
# %%
54+
55+
from gempy_probability.modules.model_definition.model_examples import model
56+
_prob_run(
57+
geo_model=geo_model,
58+
prob_model=model,
59+
normal=normal,
60+
y_obs_list=torch.tensor([200, 210, 190])
61+
)
62+
63+
64+
def _prob_run(geo_model: gp.data.GeoModel, prob_model: callable,
65+
normal: TorchDistributionMixin, y_obs_list: torch.Tensor) -> None:
6066
# Run prior sampling and visualization
6167
from pyro.infer import Predictive
6268
import pyro
6369
import arviz as az
6470
import matplotlib.pyplot as plt
6571

66-
from gempy_probability.modules.model_definition.model_examples import model
72+
from pyro.infer import NUTS
73+
from pyro.infer import MCMC
74+
from pyro.infer.autoguide import init_to_mean
75+
76+
# region prior sampling
6777
predictive = Predictive(
68-
model=model,
78+
model=prob_model,
6979
num_samples=50
7080
)
71-
7281
prior = predictive(geo_model, normal, y_obs_list)
7382

7483
data = az.from_pyro(prior=prior)
7584
az.plot_trace(data.prior)
7685
plt.show()
7786

78-
from pyro.infer import NUTS
79-
from pyro.infer import MCMC
80-
from pyro.infer.autoguide import init_to_mean
87+
# endregion
8188

89+
# region inference
8290
pyro.primitives.enable_validation(is_validate=True)
8391
nuts_kernel = NUTS(
84-
model,
92+
prob_model,
8593
step_size=0.0085,
8694
adapt_step_size=True,
8795
target_accept_prob=0.9,
@@ -95,20 +103,24 @@ def test_basic_gempy_I() -> None:
95103
disable_validation=False
96104
)
97105
mcmc.run(geo_model, normal, y_obs_list)
98-
99106
posterior_samples = mcmc.get_samples()
100-
101107
posterior_predictive_fn = Predictive(
102-
model=model,
108+
model=prob_model,
103109
posterior_samples=posterior_samples
104110
)
105-
106111
posterior_predictive = posterior_predictive_fn(geo_model, normal, y_obs_list)
107-
108112
data = az.from_pyro(posterior=mcmc, prior=prior, posterior_predictive=posterior_predictive)
113+
# Test posterior mean values
114+
posterior_top_mean = float(data.posterior[r'$\mu_{top}$'].mean())
115+
assert 0.0070 < posterior_top_mean < 0.0071, f"Top layer mean {posterior_top_mean} outside expected range"
116+
posterior_thickness_mean = float(data.posterior_predictive[r'$\mu_{thickness}$'].mean())
117+
assert 220 < posterior_thickness_mean < 225, f"Thickness mean {posterior_thickness_mean} outside expected range"
118+
# Test convergence diagnostics
119+
assert float(data.sample_stats.diverging.sum()) == 0, "MCMC sampling has divergences"
120+
# endregion
121+
109122
az.plot_trace(data)
110123
plt.show()
111-
112124
from gempy_probability.modules.plot.plot_posterior import default_red, default_blue
113125
az.plot_density(
114126
data=[data.posterior_predictive, data.prior_predictive],
@@ -118,7 +130,6 @@ def test_basic_gempy_I() -> None:
118130
colors=[default_red, default_blue],
119131
)
120132
plt.show()
121-
122133
az.plot_density(
123134
data=[data, data.prior],
124135
shade=.9,

0 commit comments

Comments
 (0)