Skip to content

Commit 22f7cf2

Browse files
committed
[ENH] Refactor probability tests and model thickness logic
Streamlined probability testing by refining grid logic and posterior checks, added print diagnostics for clarity, and adjusted posterior value assertions. Moved geological layer thickness computation to a separate `_likelihood` function for better modularity and maintainability.
1 parent 6e8a113 commit 22f7cf2

File tree

3 files changed

+53
-28
lines changed

3 files changed

+53
-28
lines changed

docs/dev_logs/2025_May.md

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,3 +14,11 @@
1414
--- MinEye
1515
- [ ] Revive magnetics
1616
- [ ] Add Simon's example
17+
18+
19+
### Possible Design:
20+
- Geological Model
21+
- Priors
22+
- Likelihood
23+
----
24+
- Having a runner or something like that?

gempy_probability/modules/model_definition/model_examples.py

Lines changed: 21 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import gempy as gp
12
import gempy.core.data
23
import gempy_engine
34
from gempy.modules.data_manipulation.engine_factory import interpolation_input_from_structural_frame
@@ -27,13 +28,14 @@ def model(geo_model: gempy.core.data.GeoModel, normal, y_obs_list):
2728
# * Update the model with the new top layer's location
2829
interpolation_input = interpolation_input_from_structural_frame(geo_model)
2930

30-
if False: # ?? I need to figure out if we need the index_put or not
31+
if True: # ?? I need to figure out if we need the index_put or not
3132
indices__ = (torch.tensor([0]), torch.tensor([2])) # * This has to be Tensors
32-
interpolation_input.surface_points.sp_coords = torch.index_put(
33+
new_tensor: torch.Tensor = torch.index_put(
3334
input=interpolation_input.surface_points.sp_coords,
3435
indices=indices__,
3536
values=mu_top
3637
)
38+
interpolation_input.surface_points.sp_coords = new_tensor
3739
else:
3840
interpolation_input.surface_points.sp_coords[0, 2] = mu_top
3941

@@ -50,18 +52,29 @@ def model(geo_model: gempy.core.data.GeoModel, normal, y_obs_list):
5052
)
5153

5254
# Compute and observe the thickness of the geological layer
53-
simulated_well = geo_model.solutions.octrees_output[0].last_output_center.custom_grid_values
54-
thickness = simulated_well.sum()
55-
pyro.deterministic(
56-
name=r'$\mu_{thickness}$',
57-
value=thickness.detach() # * This is only for az to track progress
58-
)
55+
model_solutions: gp.data.Solutions = geo_model.solutions
56+
thickness = _likelihood(model_solutions)
5957

6058
# endregion
6159

6260
posterior_dist_normal = dist.Normal(thickness, 25)
61+
62+
# * This is used automagically by pyro to compute the log-probability
6363
y_thickness = pyro.sample(
6464
name=r'$y_{thickness}$',
6565
fn=posterior_dist_normal,
6666
obs=y_obs_list
6767
)
68+
69+
70+
def _likelihood(model_solutions: gp.data.Solutions) -> torch.Tensor:
71+
"""
72+
This function computes the thickness of the geological layer.
73+
"""
74+
simulated_well = model_solutions.octrees_output[0].last_output_center.custom_grid_values
75+
thickness = simulated_well.sum()
76+
pyro.deterministic(
77+
name=r'$\mu_{thickness}$',
78+
value=thickness.detach() # * This is only for az to track progress
79+
)
80+
return thickness

tests/test_prob_model/test_prob_I.py

Lines changed: 24 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22
import os
33
import pyro.distributions as dist
44
import torch
5-
from pyro.distributions import TorchDistributionMixin
65

76
import gempy as gp
87
from gempy_engine.core.backend_tensor import BackendTensor
@@ -24,7 +23,7 @@ def test_basic_gempy_I() -> None:
2423
# TODO: Convert this into an options preset
2524
geo_model.interpolation_options.uni_degree = 0
2625
geo_model.interpolation_options.mesh_extraction = False
27-
geo_model.interpolation_options.sigmoid_slope = 1100.
26+
geo_model.interpolation_options.sigmoid_slope = 1100.
2827

2928
# region Minimal grid for the specific likelihood function
3029
x_loc = 6000
@@ -34,16 +33,10 @@ def test_basic_gempy_I() -> None:
3433
gp.set_custom_grid(geo_model.grid, xyz_coord=xyz_coord)
3534
# endregion
3635

37-
# TODO: Make sure only the custom grid ins active
38-
39-
40-
gp.compute_model(
41-
gempy_model=geo_model,
42-
engine_config=gp.data.GemPyEngineConfig(
43-
backend=gp.data.AvailableBackends.numpy
44-
)
45-
)
46-
36+
geo_model.grid.active_grids = gp.data.Grid.GridTypes.CUSTOM
37+
assert geo_model.grid.values.shape[0] == 100, "Custom grid should have 100 cells"
38+
gp.compute_model(gempy_model=geo_model)
39+
4740
# TODO: Make this a more elegant way
4841
BackendTensor.change_backend_gempy(engine_backend=gp.data.AvailableBackends.PYTORCH)
4942

@@ -62,7 +55,7 @@ def test_basic_gempy_I() -> None:
6255

6356

6457
def _prob_run(geo_model: gp.data.GeoModel, prob_model: callable,
65-
normal: TorchDistributionMixin, y_obs_list: torch.Tensor) -> None:
58+
normal: torch.distributions.Distribution, y_obs_list: torch.Tensor) -> None:
6659
# Run prior sampling and visualization
6760
from pyro.infer import Predictive
6861
import pyro
@@ -110,13 +103,6 @@ def _prob_run(geo_model: gp.data.GeoModel, prob_model: callable,
110103
)
111104
posterior_predictive = posterior_predictive_fn(geo_model, normal, y_obs_list)
112105
data = az.from_pyro(posterior=mcmc, prior=prior, posterior_predictive=posterior_predictive)
113-
# Test posterior mean values
114-
posterior_top_mean = float(data.posterior[r'$\mu_{top}$'].mean())
115-
assert 0.0070 < posterior_top_mean < 0.0071, f"Top layer mean {posterior_top_mean} outside expected range"
116-
posterior_thickness_mean = float(data.posterior_predictive[r'$\mu_{thickness}$'].mean())
117-
assert 220 < posterior_thickness_mean < 225, f"Thickness mean {posterior_thickness_mean} outside expected range"
118-
# Test convergence diagnostics
119-
assert float(data.sample_stats.diverging.sum()) == 0, "MCMC sampling has divergences"
120106
# endregion
121107

122108
az.plot_trace(data)
@@ -138,3 +124,21 @@ def _prob_run(geo_model: gp.data.GeoModel, prob_model: callable,
138124
colors=[default_red, default_blue],
139125
)
140126
plt.show()
127+
128+
# Test posterior mean values
129+
posterior_top_mean = float(data.posterior[r'$\mu_{top}$'].mean())
130+
target_top = 0.00875
131+
target_thickness = 223
132+
assert abs(posterior_top_mean - target_top) < 0.0200, f"Top layer mean {posterior_top_mean} outside expected range"
133+
posterior_thickness_mean = float(data.posterior_predictive[r'$\mu_{thickness}$'].mean())
134+
assert abs(posterior_thickness_mean - target_thickness) < 5, f"Thickness mean {posterior_thickness_mean} outside expected range"
135+
# Test convergence diagnostics
136+
assert float(data.sample_stats.diverging.sum()) == 0, "MCMC sampling has divergences"
137+
138+
print("Posterior mean values:")
139+
print(f"Top layer mean: {posterior_top_mean}")
140+
print(f"Thickness mean: {posterior_thickness_mean}")
141+
print("MCMC convergence diagnostics:")
142+
print(f"Divergences: {float(data.sample_stats.diverging.sum())}")
143+
print(f"Acceptance rate: {float(data.sample_stats.acceptance_rate.mean())}")
144+
print(f"Mean tree depth: {float(data.sample_stats.mean_tree_depth.mean())}")

0 commit comments

Comments
 (0)