Skip to content

Commit 25ca165

Browse files
committed
[CLN/ENH/WIP] Breaking down the different pieces of a prob model
1 parent 6d4519d commit 25ca165

File tree

2 files changed

+46
-16
lines changed

2 files changed

+46
-16
lines changed

gempy_probability/modules/model_definition/model_examples.py

Lines changed: 38 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,10 @@
22
import gempy_engine
33
from gempy.modules.data_manipulation.engine_factory import interpolation_input_from_structural_frame
44

5+
import pyro
6+
import pyro.distributions as dist
7+
import torch
8+
59

610
def model(geo_model: gempy.core.data.GeoModel, sp_coords_copy, y_obs_list):
711
"""
@@ -10,21 +14,34 @@ def model(geo_model: gempy.core.data.GeoModel, sp_coords_copy, y_obs_list):
1014
computes the thickness of the geological layer as an observed variable.
1115
"""
1216
# Define prior for the top layer's location:
17+
# region Prior definition
1318
prior_mean = sp_coords_copy[0, 2]
14-
import pyro
15-
import pyro.distributions as dist
16-
import torch
17-
mu_top = pyro.sample(r'$\mu_{top}$', dist.Normal(prior_mean, torch.tensor(0.02, dtype=torch.float64)))
19+
normal = dist.Normal(
20+
loc=prior_mean,
21+
scale=torch.tensor(0.02, dtype=torch.float64)
22+
)
23+
24+
mu_top = pyro.sample(
25+
name=r'$\mu_{top}$',
26+
fn=normal
27+
)
28+
# endregion
29+
30+
# region Prior injection into the gempy model
1831

19-
# Update the model with the new top layer's location
32+
# * Update the model with the new top layer's location
2033
interpolation_input = interpolation_input_from_structural_frame(geo_model)
2134
interpolation_input.surface_points.sp_coords = torch.index_put(
22-
interpolation_input.surface_points.sp_coords,
23-
(torch.tensor([0]), torch.tensor([2])),
24-
mu_top
35+
input=interpolation_input.surface_points.sp_coords,
36+
indices=(torch.tensor([0]), torch.tensor([2])),
37+
values=mu_top
2538
)
2639

27-
# Compute the geological model
40+
# endregion
41+
42+
# region Forward model computation
43+
44+
# * Compute the geological model
2845
geo_model.solutions = gempy_engine.compute_model(
2946
interpolation_input=interpolation_input,
3047
options=geo_model.interpolation_options,
@@ -35,5 +52,15 @@ def model(geo_model: gempy.core.data.GeoModel, sp_coords_copy, y_obs_list):
3552
# Compute and observe the thickness of the geological layer
3653
simulated_well = geo_model.solutions.octrees_output[0].last_output_center.custom_grid_values
3754
thickness = simulated_well.sum()
38-
pyro.deterministic(r'$\mu_{thickness}$', thickness.detach())
39-
y_thickness = pyro.sample(r'$y_{thickness}$', dist.Normal(thickness, 50), obs=y_obs_list)
55+
pyro.deterministic(
56+
name=r'$\mu_{thickness}$',
57+
value=thickness.detach()
58+
)
59+
60+
# endregion
61+
62+
y_thickness = pyro.sample(
63+
name=r'$y_{thickness}$',
64+
fn=dist.Normal(thickness, 50),
65+
obs=y_obs_list
66+
)

tests/test_prob_model/test_prob_I.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import gempy_engine
44
import numpy as np
55

6+
67
def test_basic_gempy_I():
78
current_dir = os.path.dirname(os.path.abspath(__file__))
89
data_path = os.path.abspath(os.path.join(current_dir, '..', '..', 'examples', 'tutorials', 'data'))
@@ -21,7 +22,6 @@ def test_basic_gempy_I():
2122
geo_model.interpolation_options.mesh_extraction = False
2223
geo_model.interpolation_options.sigmoid_slope = 1100.
2324

24-
2525
x_loc = 6000
2626
y_loc = 0
2727
z_loc = np.linspace(0, 4000, 100)
@@ -70,13 +70,17 @@ def test_basic_gempy_I():
7070
import matplotlib.pyplot as plt
7171

7272
from gempy_probability.modules.model_definition.model_examples import model
73-
prior = Predictive(model, num_samples=50)(geo_model, sp_coords_copy, y_obs_list)
74-
73+
predictive = Predictive(
74+
model=model,
75+
num_samples=50
76+
)
77+
78+
prior = predictive(geo_model, sp_coords_copy, y_obs_list)
79+
7580
data = az.from_pyro(prior=prior)
7681
az.plot_trace(data.prior)
7782
plt.show()
7883

79-
8084
from pyro.infer import NUTS
8185
from pyro.infer import MCMC
8286
from pyro.infer.autoguide import init_to_mean
@@ -92,7 +96,6 @@ def test_basic_gempy_I():
9296
mcmc = MCMC(nuts_kernel, num_samples=200, warmup_steps=50, disable_validation=False)
9397
mcmc.run(geo_model, sp_coords_copy, y_obs_list)
9498

95-
9699
posterior_samples = mcmc.get_samples()
97100
posterior_predictive = Predictive(model, posterior_samples)(geo_model, sp_coords_copy, y_obs_list)
98101
data = az.from_pyro(posterior=mcmc, prior=prior, posterior_predictive=posterior_predictive)

0 commit comments

Comments
 (0)