Skip to content

Commit 774ccab

Browse files
committed
[TEST][ENH] Add probabilistic geomodeling test with Pyro
Introduce a test script for probabilistic geomodeling using Pyro, incorporating prior sampling, posterior inference, and visualization. This enhances testing coverage with a probabilistic approach while demonstrating Pyro integration with GemPy.
1 parent c439e32 commit 774ccab

File tree

2 files changed

+143
-0
lines changed

2 files changed

+143
-0
lines changed

tests/test_prob_model/test_prob_I.py

Lines changed: 143 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,143 @@
1+
import os
2+
import gempy as gp
3+
import gempy_engine
4+
import numpy as np
5+
6+
def test_basic_gempy_I():
7+
current_dir = os.path.dirname(os.path.abspath(__file__))
8+
data_path = os.path.abspath(os.path.join(current_dir, '..', '..', 'examples', 'tutorials', 'data'))
9+
geo_model = gp.create_geomodel(
10+
project_name='Wells',
11+
extent=[0, 12000, -500, 500, 0, 4000],
12+
refinement=3,
13+
importer_helper=gp.data.ImporterHelper(
14+
path_to_orientations=os.path.join(data_path, "2-layers", "2-layers_orientations.csv"),
15+
path_to_surface_points=os.path.join(data_path, "2-layers", "2-layers_surface_points.csv")
16+
)
17+
)
18+
19+
# TODO: Convert this into an options preset
20+
geo_model.interpolation_options.uni_degree = 0
21+
geo_model.interpolation_options.mesh_extraction = False
22+
geo_model.interpolation_options.sigmoid_slope = 1100.
23+
24+
25+
x_loc = 6000
26+
y_loc = 0
27+
z_loc = np.linspace(0, 4000, 100)
28+
xyz_coord = np.array([[x_loc, y_loc, z] for z in z_loc])
29+
gp.set_custom_grid(geo_model.grid, xyz_coord=xyz_coord)
30+
31+
# TODO: Make sure only the custom grid ins active
32+
33+
gp.compute_model(
34+
gempy_model=geo_model,
35+
engine_config=gp.data.GemPyEngineConfig(backend=gp.data.AvailableBackends.numpy)
36+
)
37+
38+
# TODO: This is the part that has to go to a function no question
39+
# Probabilistic Geomodeling with Pyro
40+
# -----------------------------------
41+
# In this section, we introduce a probabilistic approach to geological modeling.
42+
# By using Pyro, a probabilistic programming language, we define a model that integrates
43+
# geological data with uncertainty quantification.
44+
45+
from gempy_engine.core.data.interpolation_input import InterpolationInput
46+
from gempy_engine.core.backend_tensor import BackendTensor
47+
48+
from gempy.modules.data_manipulation.engine_factory import interpolation_input_from_structural_frame
49+
50+
interpolation_input_copy: InterpolationInput = interpolation_input_from_structural_frame(geo_model)
51+
sp_coords_copy = interpolation_input_copy.surface_points.sp_coords
52+
# Change the backend to PyTorch for probabilistic modeling
53+
BackendTensor.change_backend_gempy(engine_backend=gp.data.AvailableBackends.PYTORCH)
54+
55+
# Defining the Probabilistic Model
56+
# --------------------------------
57+
# The Pyro model represents the probabilistic aspects of the geological model.
58+
# It defines a prior distribution for the top layer's location and computes the thickness
59+
# of the geological layer as an observed variable.
60+
61+
def model(y_obs_list):
62+
"""
63+
This Pyro model represents the probabilistic aspects of the geological model.
64+
It defines a prior distribution for the top layer's location and
65+
computes the thickness of the geological layer as an observed variable.
66+
"""
67+
# Define prior for the top layer's location:
68+
prior_mean = sp_coords_copy[0, 2]
69+
import pyro
70+
import pyro.distributions as dist
71+
import torch
72+
mu_top = pyro.sample(r'$\mu_{top}$', dist.Normal(prior_mean, torch.tensor(0.02, dtype=torch.float64)))
73+
74+
# Update the model with the new top layer's location
75+
interpolation_input = interpolation_input_from_structural_frame(geo_model)
76+
interpolation_input.surface_points.sp_coords = torch.index_put(
77+
interpolation_input.surface_points.sp_coords,
78+
(torch.tensor([0]), torch.tensor([2])),
79+
mu_top
80+
)
81+
82+
# Compute the geological model
83+
geo_model.solutions = gempy_engine.compute_model(
84+
interpolation_input=interpolation_input,
85+
options=geo_model.interpolation_options,
86+
data_descriptor=geo_model.input_data_descriptor,
87+
geophysics_input=geo_model.geophysics_input,
88+
)
89+
90+
# Compute and observe the thickness of the geological layer
91+
simulated_well = geo_model.solutions.octrees_output[0].last_output_center.custom_grid_values
92+
thickness = simulated_well.sum()
93+
pyro.deterministic(r'$\mu_{thickness}$', thickness.detach())
94+
y_thickness = pyro.sample(r'$y_{thickness}$', dist.Normal(thickness, 50), obs=y_obs_list)
95+
96+
97+
98+
# %%
99+
# Running Prior Sampling and Visualization
100+
# ----------------------------------------
101+
# Prior sampling is an essential step in probabilistic modeling.
102+
# It helps in understanding the distribution of our prior assumptions before observing any data.
103+
104+
# %%
105+
# Prepare observation data
106+
import torch
107+
y_obs_list = torch.tensor([200, 210, 190])
108+
109+
# %%
110+
# Run prior sampling and visualization
111+
from pyro.infer import Predictive
112+
import pyro
113+
import arviz as az
114+
import matplotlib.pyplot as plt
115+
116+
prior = Predictive(model, num_samples=50)(y_obs_list)
117+
118+
data = az.from_pyro(prior=prior)
119+
az.plot_trace(data.prior)
120+
plt.show()
121+
122+
123+
from pyro.infer import NUTS
124+
from pyro.infer import MCMC
125+
from pyro.infer.autoguide import init_to_mean
126+
pyro.primitives.enable_validation(is_validate=True)
127+
nuts_kernel = NUTS(
128+
model,
129+
step_size=0.0085,
130+
adapt_step_size=True,
131+
target_accept_prob=0.9,
132+
max_tree_depth=10,
133+
init_strategy=init_to_mean
134+
)
135+
mcmc = MCMC(nuts_kernel, num_samples=200, warmup_steps=50, disable_validation=False)
136+
mcmc.run(y_obs_list)
137+
138+
139+
posterior_samples = mcmc.get_samples()
140+
posterior_predictive = Predictive(model, posterior_samples)(y_obs_list)
141+
data = az.from_pyro(posterior=mcmc, prior=prior, posterior_predictive=posterior_predictive)
142+
az.plot_trace(data)
143+
plt.show()

0 commit comments

Comments
 (0)