Skip to content

Commit 6d4519d

Browse files
committed
[CLN/ENH/WIP] Figuring out extracting model definition
1 parent b1fe4ce commit 6d4519d

File tree

3 files changed

+44
-48
lines changed

3 files changed

+44
-48
lines changed

gempy_probability/modules/model_definition/__init__.py

Whitespace-only changes.
Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
import gempy.core.data
2+
import gempy_engine
3+
from gempy.modules.data_manipulation.engine_factory import interpolation_input_from_structural_frame
4+
5+
6+
def model(geo_model: gempy.core.data.GeoModel, sp_coords_copy, y_obs_list):
7+
"""
8+
This Pyro model represents the probabilistic aspects of the geological model.
9+
It defines a prior distribution for the top layer's location and
10+
computes the thickness of the geological layer as an observed variable.
11+
"""
12+
# Define prior for the top layer's location:
13+
prior_mean = sp_coords_copy[0, 2]
14+
import pyro
15+
import pyro.distributions as dist
16+
import torch
17+
mu_top = pyro.sample(r'$\mu_{top}$', dist.Normal(prior_mean, torch.tensor(0.02, dtype=torch.float64)))
18+
19+
# Update the model with the new top layer's location
20+
interpolation_input = interpolation_input_from_structural_frame(geo_model)
21+
interpolation_input.surface_points.sp_coords = torch.index_put(
22+
interpolation_input.surface_points.sp_coords,
23+
(torch.tensor([0]), torch.tensor([2])),
24+
mu_top
25+
)
26+
27+
# Compute the geological model
28+
geo_model.solutions = gempy_engine.compute_model(
29+
interpolation_input=interpolation_input,
30+
options=geo_model.interpolation_options,
31+
data_descriptor=geo_model.input_data_descriptor,
32+
geophysics_input=geo_model.geophysics_input,
33+
)
34+
35+
# Compute and observe the thickness of the geological layer
36+
simulated_well = geo_model.solutions.octrees_output[0].last_output_center.custom_grid_values
37+
thickness = simulated_well.sum()
38+
pyro.deterministic(r'$\mu_{thickness}$', thickness.detach())
39+
y_thickness = pyro.sample(r'$y_{thickness}$', dist.Normal(thickness, 50), obs=y_obs_list)

tests/test_prob_model/test_prob_I.py

Lines changed: 5 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -44,57 +44,13 @@ def test_basic_gempy_I():
4444

4545
from gempy_engine.core.data.interpolation_input import InterpolationInput
4646
from gempy_engine.core.backend_tensor import BackendTensor
47-
4847
from gempy.modules.data_manipulation.engine_factory import interpolation_input_from_structural_frame
4948

5049
interpolation_input_copy: InterpolationInput = interpolation_input_from_structural_frame(geo_model)
5150
sp_coords_copy = interpolation_input_copy.surface_points.sp_coords
5251
# Change the backend to PyTorch for probabilistic modeling
5352
BackendTensor.change_backend_gempy(engine_backend=gp.data.AvailableBackends.PYTORCH)
5453

55-
# Defining the Probabilistic Model
56-
# --------------------------------
57-
# The Pyro model represents the probabilistic aspects of the geological model.
58-
# It defines a prior distribution for the top layer's location and computes the thickness
59-
# of the geological layer as an observed variable.
60-
61-
def model(y_obs_list):
62-
"""
63-
This Pyro model represents the probabilistic aspects of the geological model.
64-
It defines a prior distribution for the top layer's location and
65-
computes the thickness of the geological layer as an observed variable.
66-
"""
67-
# Define prior for the top layer's location:
68-
prior_mean = sp_coords_copy[0, 2]
69-
import pyro
70-
import pyro.distributions as dist
71-
import torch
72-
mu_top = pyro.sample(r'$\mu_{top}$', dist.Normal(prior_mean, torch.tensor(0.02, dtype=torch.float64)))
73-
74-
# Update the model with the new top layer's location
75-
interpolation_input = interpolation_input_from_structural_frame(geo_model)
76-
interpolation_input.surface_points.sp_coords = torch.index_put(
77-
interpolation_input.surface_points.sp_coords,
78-
(torch.tensor([0]), torch.tensor([2])),
79-
mu_top
80-
)
81-
82-
# Compute the geological model
83-
geo_model.solutions = gempy_engine.compute_model(
84-
interpolation_input=interpolation_input,
85-
options=geo_model.interpolation_options,
86-
data_descriptor=geo_model.input_data_descriptor,
87-
geophysics_input=geo_model.geophysics_input,
88-
)
89-
90-
# Compute and observe the thickness of the geological layer
91-
simulated_well = geo_model.solutions.octrees_output[0].last_output_center.custom_grid_values
92-
thickness = simulated_well.sum()
93-
pyro.deterministic(r'$\mu_{thickness}$', thickness.detach())
94-
y_thickness = pyro.sample(r'$y_{thickness}$', dist.Normal(thickness, 50), obs=y_obs_list)
95-
96-
97-
9854
# %%
9955
# Running Prior Sampling and Visualization
10056
# ----------------------------------------
@@ -112,8 +68,9 @@ def model(y_obs_list):
11268
import pyro
11369
import arviz as az
11470
import matplotlib.pyplot as plt
115-
116-
prior = Predictive(model, num_samples=50)(y_obs_list)
71+
72+
from gempy_probability.modules.model_definition.model_examples import model
73+
prior = Predictive(model, num_samples=50)(geo_model, sp_coords_copy, y_obs_list)
11774

11875
data = az.from_pyro(prior=prior)
11976
az.plot_trace(data.prior)
@@ -133,11 +90,11 @@ def model(y_obs_list):
13390
init_strategy=init_to_mean
13491
)
13592
mcmc = MCMC(nuts_kernel, num_samples=200, warmup_steps=50, disable_validation=False)
136-
mcmc.run(y_obs_list)
93+
mcmc.run(geo_model, sp_coords_copy, y_obs_list)
13794

13895

13996
posterior_samples = mcmc.get_samples()
140-
posterior_predictive = Predictive(model, posterior_samples)(y_obs_list)
97+
posterior_predictive = Predictive(model, posterior_samples)(geo_model, sp_coords_copy, y_obs_list)
14198
data = az.from_pyro(posterior=mcmc, prior=prior, posterior_predictive=posterior_predictive)
14299
az.plot_trace(data)
143100
plt.show()

0 commit comments

Comments
 (0)