Skip to content

Commit 7517f62

Browse files
authored
Add probabilistic modeling API with model factory and runner (#11)
# Probabilistic Modeling Framework for GemPy

This PR introduces a comprehensive framework for probabilistic modeling in GemPy, making it easier to define, run, and analyze probabilistic geological models.

## Key Features

- Added a factory function `make_gempy_pyro_model` to create Pyro models for GemPy with clean separation of concerns
- Introduced `GemPyPyroModel` type alias for better type clarity throughout the codebase
- Created standardized API for running probabilistic models with `run_predictive`, `run_nuts_inference`, and `run_mcmc_for_NUTS`
- Implemented a configurable NUTS sampler with `NUTSConfig` dataclass
- Reorganized likelihood functions into a more modular structure
- Added comprehensive documentation for model creation and usage

## Implementation Details

- Moved backend configuration to `run_gempy_forward` for consistent execution
- Refactored likelihood modules for better organization and reusability
- Added utility functions for common probabilistic modeling tasks
- Created test cases demonstrating the new framework's capabilities
- Removed outdated gravity inversion example in favor of the new approach

This PR lays the groundwork for more sophisticated probabilistic modeling in GemPy, making it easier for users to define custom models while maintaining a clean separation between model definition, execution, and analysis.
2 parents eff535c + 619caaf commit 7517f62

File tree

20 files changed

+488
-96
lines changed

20 files changed

+488
-96
lines changed

docs/dev_logs/2025_May.md

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -46,12 +46,15 @@
4646
### Likelihood functions
4747
- ? Should I start to create a catalog of likelihood functions that goes with gempy?
4848
- This would be useful to be able to visualize whatever the likelihood function represents independent of the probabilistic model.
49-
-
49+
5050
### What would I want to have in order to do the StonePark example as flawlessly as possible?
5151
1. [x] A way to save the model already with the correct **nugget effects** instead of having to "load nuggets" (TODO in gempy)
5252
2. [ ] Likelihood function plug and play
5353
3. [ ] Some sort of pre-analysis if the model looks good. Maybe some metrics about how long it takes to run the model.
5454
1. Condition number
5555
2. Acceptance rate of the posterior
5656
4. Easy pykeops setting using .env
57-
5. Some sort of easy set-up for having a reduced number of priors
57+
5. Some sort of easy set-up for having a reduced number of priors
58+
59+
60+
## Probabilistic Modeling definition

examples/examples/gravity_inversion_I.py

Lines changed: 0 additions & 71 deletions
This file was deleted.

gempy_probability/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1 +1,5 @@
1+
from .modules.model_definition.prob_model_factory import make_gempy_pyro_model, GemPyPyroModel
2+
from .modules import likelihoods
3+
from .api.model_runner import run_predictive, run_mcmc_for_NUTS, run_nuts_inference
4+
15
from ._version import __version__
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
from ._pyro_runner import run_predictive, run_nuts_inference, run_mcmc_for_NUTS
Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,90 @@
1+
import arviz as az
2+
import matplotlib.pyplot as plt
3+
import torch
4+
from pyro.infer import MCMC
5+
from pyro.infer import NUTS
6+
from pyro.infer import Predictive
7+
from pyro.infer.autoguide import init_to_mean
8+
9+
import gempy as gp
10+
import gempy_probability as gpp
11+
from ...core.samplers_data import NUTSConfig
12+
13+
14+
def run_predictive(prob_model: gpp.GemPyPyroModel, geo_model: gp.data.GeoModel,
                   y_obs_list: torch.Tensor, n_samples: int, plot_trace: bool = False) -> az.InferenceData:
    """Sample ``prob_model`` with Pyro's ``Predictive`` and return the draws as ArviZ data.

    Args:
        prob_model: Pyro model handed to ``Predictive`` as ``model``.
        geo_model: GemPy model passed as the first positional argument to the
            probabilistic model.
        y_obs_list: Observations passed as the second positional argument to
            the probabilistic model.
        n_samples: Number of samples requested from ``Predictive``.
        plot_trace: When true, plot the trace of the ``prior`` group and show
            the matplotlib figure before returning.

    Returns:
        The ``InferenceData`` built by ``az.from_pyro`` with the draws supplied
        as ``prior``.
    """
    sampler = Predictive(model=prob_model, num_samples=n_samples)
    prior_draws: dict[str, torch.Tensor] = sampler(geo_model, y_obs_list)

    inference_data: az.InferenceData = az.from_pyro(prior=prior_draws)
    if not plot_trace:
        return inference_data

    az.plot_trace(inference_data.prior)
    plt.show()
    return inference_data
29+
30+
def run_nuts_inference(prob_model: gpp.GemPyPyroModel, geo_model: gp.data.GeoModel,
                       y_obs_list: torch.Tensor, config: NUTSConfig, plot_trace: bool = False,
                       run_posterior_predictive: bool = False) -> az.InferenceData:
    """Build a NUTS kernel from ``config`` and run MCMC with it.

    Args:
        prob_model: Pyro model to sample from.
        geo_model: GemPy model forwarded to the probabilistic model.
        y_obs_list: Observations forwarded to the probabilistic model.
        config: Sampler settings. ``step_size``, ``adapt_step_size``,
            ``target_accept_prob``, ``max_tree_depth`` and ``init_strategy``
            configure the kernel; ``num_samples`` and ``warmup_steps`` are
            forwarded to ``run_mcmc_for_NUTS``.
        plot_trace: Forwarded to ``run_mcmc_for_NUTS``.
        run_posterior_predictive: Forwarded to ``run_mcmc_for_NUTS``.

    Returns:
        The ``InferenceData`` produced by ``run_mcmc_for_NUTS``.

    Raises:
        ValueError: If ``config.init_strategy`` is a string that does not name
            a supported initialisation strategy.
    """
    # Function-scope import: only ``init_to_mean`` is imported at module level.
    from pyro.infer.autoguide import (
        init_to_feasible,
        init_to_median,
        init_to_sample,
        init_to_uniform,
    )

    # ``config.init_strategy`` used to be ignored in favour of a hard-coded
    # ``init_to_mean``. 'auto' keeps that behaviour so default configs sample
    # exactly as before.
    strategies = {
        'auto': init_to_mean,
        'mean': init_to_mean,
        'median': init_to_median,
        'sample': init_to_sample,
        'uniform': init_to_uniform,
        'feasible': init_to_feasible,
    }
    requested = config.init_strategy
    if callable(requested):
        # Allow passing a Pyro init function directly.
        init_strategy = requested
    elif requested in strategies:
        init_strategy = strategies[requested]
    else:
        raise ValueError(
            f"Unknown init_strategy {requested!r}; expected one of {sorted(strategies)}"
        )

    nuts_kernel = NUTS(
        prob_model,
        step_size=config.step_size,
        adapt_step_size=config.adapt_step_size,
        target_accept_prob=config.target_accept_prob,
        max_tree_depth=config.max_tree_depth,
        init_strategy=init_strategy,
    )
    data = run_mcmc_for_NUTS(
        geo_model=geo_model,
        nuts_kernel=nuts_kernel,
        prob_model=prob_model,
        y_obs_list=y_obs_list,
        num_samples=config.num_samples,
        warmup_steps=config.warmup_steps,
        plot_trace=plot_trace,
        run_posterior_predictive=run_posterior_predictive,
    )

    return data
54+
55+
56+
def run_mcmc_for_NUTS(
        geo_model: gp.data.GeoModel,
        nuts_kernel: NUTS,
        prob_model: gpp.GemPyPyroModel,
        y_obs_list: torch.Tensor,
        num_samples: int,
        warmup_steps: int,
        plot_trace: bool = False,
        run_posterior_predictive: bool = False,
) -> az.InferenceData:
    """Run Pyro MCMC with the given kernel and convert the result to ArviZ data.

    Args:
        geo_model: GemPy model passed as the first positional argument to the
            MCMC run (and to the posterior predictive, if requested).
        nuts_kernel: Kernel handed to ``MCMC``.
        prob_model: Pyro model used only for the posterior predictive.
        y_obs_list: Observations passed as the second positional argument.
        num_samples: Number of samples for ``MCMC``.
        warmup_steps: Number of warm-up steps for ``MCMC``.
        plot_trace: When true, plot the trace of the resulting data and show
            the matplotlib figure before returning.
        run_posterior_predictive: When true, additionally run ``Predictive``
            on the MCMC samples and attach the result as
            ``posterior_predictive``.

    Returns:
        The ``InferenceData`` built by ``az.from_pyro``.
    """
    sampler = MCMC(
        kernel=nuts_kernel,
        num_samples=num_samples,
        warmup_steps=warmup_steps,
        disable_validation=False,
    )
    sampler.run(geo_model, y_obs_list)

    # Collect the ``az.from_pyro`` arguments once; the posterior predictive is
    # only added when it was requested.
    from_pyro_kwargs = {"posterior": sampler}
    if run_posterior_predictive:
        predictive = Predictive(
            model=prob_model,
            posterior_samples=sampler.get_samples(),
        )
        from_pyro_kwargs["posterior_predictive"] = predictive(geo_model, y_obs_list)

    inference_data: az.InferenceData = az.from_pyro(**from_pyro_kwargs)

    if plot_trace:
        az.plot_trace(inference_data)
        plt.show()

    return inference_data
Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
import dataclasses
2+
3+
4+
@dataclasses.dataclass
class NUTSConfig:
    """Settings bundle for a NUTS sampling run.

    All fields have defaults, so ``NUTSConfig()`` yields a ready-to-use
    configuration. Field order is part of the positional constructor.
    """

    # --- kernel settings ---
    step_size: float = 0.0085
    adapt_step_size: bool = True
    target_accept_prob: float = 0.9
    max_tree_depth: int = 10
    # Initialisation-strategy selector, given as a string.
    init_strategy: str = "auto"

    # --- chain length ---
    num_samples: int = 200
    warmup_steps: int = 50
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
from ._gempy_fw import run_gempy_forward
Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
import gempy as gp
2+
import gempy_engine
3+
from gempy_engine.core.backend_tensor import BackendTensor
4+
from gempy_engine.core.data.interpolation_input import InterpolationInput
5+
6+
7+
def run_gempy_forward(interp_input: InterpolationInput, geo_model: gp.data.GeoModel) -> gp.data.Solutions:
    """Compute the GemPy solution for ``interp_input`` using ``geo_model``'s settings.

    Args:
        interp_input: Interpolation input passed to the engine as
            ``interpolation_input``.
        geo_model: Supplies ``interpolation_options``,
            ``input_data_descriptor`` and ``geophysics_input`` for the call.

    Returns:
        Whatever ``gempy_engine.compute_model`` returns for these inputs.
    """
    return gempy_engine.compute_model(
        interpolation_input=interp_input,
        options=geo_model.interpolation_options,
        data_descriptor=geo_model.input_data_descriptor,
        geophysics_input=geo_model.geophysics_input,
    )
18+
19+
Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,2 @@
1-
from ._apparent_thickness import apparent_thickness_likelihood, apparent_thickness_likelihood_II
2-
from ._gravity_inv import gravity_inversion_likelihood
1+
from ._likelihood_functions import thickness_likelihood
32

gempy_probability/modules/likelihoods/_gravity_inv.py

Lines changed: 0 additions & 9 deletions
This file was deleted.

0 commit comments

Comments
 (0)