Skip to content

Commit 94f0548

Browse files
committed
[CLN/ENH] Better definition for specific gempy_pyro models
1 parent ab4be33 commit 94f0548

File tree

7 files changed

+112
-67
lines changed

7 files changed

+112
-67
lines changed

gempy_probability/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1 +1,3 @@
1+
from .modules.model_definition.prob_model_factory import make_gempy_pyro_model
2+
13
from ._version import __version__
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
from ._gempy_fw import run_gempy_forward
Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
import torch
2+
from pyro.distributions import Distribution
3+
4+
import gempy_engine
5+
from gempy.modules.data_manipulation import interpolation_input_from_structural_frame
6+
import gempy as gp
7+
from gempy_engine.core.data.interpolation_input import InterpolationInput
8+
9+
10+
def run_gempy_forward(interp_input: InterpolationInput, geo_model: gp.data.GeoModel) -> gp.data.Solutions:
11+
12+
# compute model
13+
sol = gempy_engine.compute_model(
14+
interpolation_input=interp_input,
15+
options=geo_model.interpolation_options,
16+
data_descriptor=geo_model.input_data_descriptor,
17+
geophysics_input=geo_model.geophysics_input
18+
)
19+
20+
return sol
21+
22+
Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
import torch
2+
import pyro.distributions as dist
3+
import gempy as gp
4+
from . import apparent_thickness_likelihood
5+
6+
7+
def gravity_likelihood(geo_model: gp.data.Solutions):
8+
# e.g. multivariate normal with precomputed cov matrix
9+
raise NotImplementedError("This function is not yet implemented")
10+
return dist.MultivariateNormal(simulated_gravity, covariance_matrix)
11+
12+
def thickness_likelihood(solutions: gp.data.Solutions) -> dist:
13+
# e.g. multivariate normal with precomputed cov matrix
14+
simulated_thickness = apparent_thickness_likelihood(solutions)
15+
return dist.Normal(simulated_thickness, 25)

gempy_probability/modules/model_definition/model_examples.py

Lines changed: 0 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,6 @@
77
import pyro
88
import torch
99

10-
from gempy_probability.modules.model_definition.prob_model_factory import make_pyro_model
11-
1210

1311
def two_wells_prob_model_I(geo_model: gempy.core.data.GeoModel, normal, y_obs_list):
1412
"""
@@ -73,59 +71,3 @@ def two_wells_prob_model_I(geo_model: gempy.core.data.GeoModel, normal, y_obs_li
7371
)
7472

7573

76-
77-
import pyro.distributions as dist
78-
79-
# 1) Define your priors
80-
gravity_priors = {
81-
# e.g. density contrast of layer index 3
82-
# "mu_density": dist.Normal(2.62, 0.5),
83-
r'$\mu_{top}$': dist.Normal(
84-
loc=torch.tensor(0.0487, dtype=torch.float64),
85-
scale=torch.tensor(0.1, dtype=torch.float64)
86-
)
87-
}
88-
89-
# 2) Wrap your forward pass into a small fn
90-
def run_gempy_forward(samples, geo_model):
91-
92-
mu_top = samples["$\mu_{top}$"]
93-
94-
interp_input = interpolation_input_from_structural_frame(geo_model)
95-
96-
indices__ = (torch.tensor([0]), torch.tensor([2])) # * This has to be Tensors
97-
new_tensor: torch.Tensor = torch.index_put(
98-
input=interp_input.surface_points.sp_coords,
99-
indices=indices__,
100-
values=mu_top
101-
)
102-
interp_input.surface_points.sp_coords = new_tensor
103-
104-
# compute model
105-
sol = gempy_engine.compute_model(
106-
interpolation_input=interp_input,
107-
options=geo_model.interpolation_options,
108-
data_descriptor=geo_model.input_data_descriptor,
109-
geophysics_input=geo_model.geophysics_input
110-
)
111-
112-
thickness = apparent_thickness_likelihood(sol)
113-
return thickness
114-
115-
# 3) Define your likelihood factory
116-
def gravity_likelihood(simulated_gravity):
117-
# e.g. multivariate normal with precomputed cov matrix
118-
return dist.MultivariateNormal(simulated_gravity, covariance_matrix)
119-
120-
def thickness_likelihood(simulated_thickness):
121-
# e.g. multivariate normal with precomputed cov matrix
122-
return dist.Normal(simulated_thickness, 25)
123-
124-
# 4) Build the model
125-
pyro_gravity_model = make_pyro_model(
126-
priors=gravity_priors,
127-
forward_fn=run_gempy_forward,
128-
likelihood_fn=thickness_likelihood,
129-
obs_name="obs_gravity"
130-
)
131-

gempy_probability/modules/model_definition/prob_model_factory.py

Lines changed: 32 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,37 @@
33
from typing import Callable, Dict
44
from pyro.distributions import Distribution
55

6-
def make_pyro_model(
6+
from gempy_probability.modules.forwards import run_gempy_forward
7+
import gempy as gp
8+
9+
10+
def make_gempy_pyro_model(
11+
priors: Dict[str, Distribution],
12+
set_interp_input_fn: Callable[..., gp.data.Solutions],
13+
likelihood_fn: Callable[[gp.data.Solutions], Distribution],
14+
obs_name: str = "obs"
15+
):
16+
17+
def model(geo_model, obs_data):
18+
# 1) Sample each prior
19+
samples = {}
20+
for name, dist in priors.items():
21+
samples[name] = pyro.sample(name, dist)
22+
23+
# 3) Run your forward geological model
24+
simulated: gp.data.Solutions = run_gempy_forward(
25+
interp_input=set_interp_input_fn(samples, geo_model),
26+
geo_model=geo_model
27+
)
28+
29+
# 4) Build likelihood and observe
30+
lik_dist = likelihood_fn(simulated)
31+
pyro.sample(obs_name, lik_dist, obs=obs_data)
32+
33+
return model
34+
35+
36+
def make_generic_pyro_model(
737
priors: Dict[str, Distribution],
838
forward_fn: Callable[..., torch.Tensor],
939
likelihood_fn: Callable[[torch.Tensor], Distribution],
@@ -29,4 +59,4 @@ def model(geo_model, obs_data):
2959
lik_dist = likelihood_fn(simulated)
3060
pyro.sample(obs_name, lik_dist, obs=obs_data)
3161

32-
return model
62+
return model

tests/test_prob_model/test_prob_factory.py

Lines changed: 40 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,18 @@
11
import numpy as np
22
import os
3+
import pyro
34
import pyro.distributions as dist
45
import torch
6+
from pyro.distributions import Distribution
57

68
import gempy as gp
79
from gempy_engine.core.backend_tensor import BackendTensor
10+
import gempy_probability as gpp
11+
from gempy_engine.core.data.interpolation_input import InterpolationInput
12+
13+
seed = 123456
14+
torch.manual_seed(seed)
15+
pyro.set_rng_seed(seed)
816

917
def test_prob_model_factory() -> None:
1018
current_dir = os.path.dirname(os.path.abspath(__file__))
@@ -44,23 +52,48 @@ def test_prob_model_factory() -> None:
4452
)
4553
BackendTensor.change_backend_gempy(engine_backend=gp.data.AvailableBackends.PYTORCH)
4654

47-
normal = dist.Normal(
48-
loc=(geo_model.surface_points_copy_transformed.xyz[0, 2]),
49-
scale=torch.tensor(0.1, dtype=torch.float64)
55+
# 1) Define your priors
56+
model_priors = {
57+
r'$\mu_{top}$': dist.Normal(
58+
loc=geo_model.surface_points_copy_transformed.xyz[0, 2],
59+
scale=torch.tensor(0.1, dtype=torch.float64)
60+
)
61+
}
62+
63+
from gempy_probability.modules.likelihoods._likelihood_functions import thickness_likelihood
64+
65+
pyro_gravity_model = gpp.make_gempy_pyro_model(
66+
priors=model_priors,
67+
set_interp_input_fn=modify_z_for_surface_point1,
68+
likelihood_fn=thickness_likelihood,
69+
obs_name="obs_gravity"
5070
)
5171

52-
from gempy_probability.modules.model_definition.model_examples import two_wells_prob_model_I, pyro_gravity_model
53-
5472
_prob_run(
5573
geo_model=geo_model,
5674
prob_model=pyro_gravity_model,
57-
normal=normal,
5875
y_obs_list=torch.tensor([200, 210, 190])
5976
)
77+
78+
def modify_z_for_surface_point1(
79+
samples: dict[str, Distribution],
80+
geo_model: gp.data.GeoModel,
81+
) -> InterpolationInput:
82+
prior_key = r'$\mu_{top}$'
83+
84+
from gempy.modules.data_manipulation import interpolation_input_from_structural_frame
85+
interp_input = interpolation_input_from_structural_frame(geo_model)
86+
new_tensor: torch.Tensor = torch.index_put(
87+
input=interp_input.surface_points.sp_coords,
88+
indices=(torch.tensor([0]), torch.tensor([2])), # * This has to be Tensors
89+
values=(samples[prior_key])
90+
)
91+
interp_input.surface_points.sp_coords = new_tensor
92+
return interp_input
6093

6194

6295
def _prob_run(geo_model: gp.data.GeoModel, prob_model: callable,
63-
normal: torch.distributions.Distribution, y_obs_list: torch.Tensor) -> None:
96+
y_obs_list: torch.Tensor) -> None:
6497
# Run prior sampling and visualization
6598
from pyro.infer import Predictive
6699
import pyro

0 commit comments

Comments
 (0)