Skip to content

Add probabilistic modeling API with model factory and runner #11

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 14 commits into from
Jun 12, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions docs/dev_logs/2025_May.md
Original file line number Diff line number Diff line change
Expand Up @@ -46,12 +46,15 @@
### Likelihood functions
- ? Should I start to create a catalog of likelihood functions that goes with gempy?
- This would be useful to be able to visualize whatever the likelihood function represents independent of the probabilistic model.
-

### What would I want to have to do StonePark example as flowless as possible?
1. [x] A way to save the model already with the correct **nugget effects** instead of having to "load nuggets" (TODO in gempy)
2. [ ] Likelihood function plug and play
3. [ ] Some sort of pre-analysis if the model looks good. Maybe some metrics about how long it takes to run the model.
1. Condition number
2. Acceptance rate of the posterior
4. Easy pykeops setting using .env
5. Some sort of easy set-up for having a reduced number of priors
5. Some sort of easy set-up for having a reduced number of priors


## Probabilistic Modeling definition
71 changes: 0 additions & 71 deletions examples/examples/gravity_inversion_I.py

This file was deleted.

4 changes: 4 additions & 0 deletions gempy_probability/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1,5 @@
from .modules.model_definition.prob_model_factory import make_gempy_pyro_model, GemPyPyroModel
from .modules import likelihoods
from .api.model_runner import run_predictive, run_mcmc_for_NUTS, run_nuts_inference

from ._version import __version__
1 change: 1 addition & 0 deletions gempy_probability/api/model_runner/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from ._pyro_runner import run_predictive, run_nuts_inference, run_mcmc_for_NUTS
90 changes: 90 additions & 0 deletions gempy_probability/api/model_runner/_pyro_runner.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
import arviz as az
import matplotlib.pyplot as plt
import torch
from pyro.infer import MCMC
from pyro.infer import NUTS
from pyro.infer import Predictive
from pyro.infer.autoguide import init_to_mean

import gempy as gp
import gempy_probability as gpp
from ...core.samplers_data import NUTSConfig


def run_predictive(prob_model: gpp.GemPyPyroModel, geo_model: gp.data.GeoModel,
y_obs_list: torch.Tensor, n_samples: int, plot_trace:bool=False) -> az.InferenceData:
predictive = Predictive(
model=prob_model,
num_samples=n_samples
)
prior: dict[str, torch.Tensor] = predictive(geo_model, y_obs_list)
# print("Number of interpolations: ", geo_model.counter)

data: az.InferenceData = az.from_pyro(prior=prior)
if plot_trace:
az.plot_trace(data.prior)
plt.show()

return data

def run_nuts_inference(prob_model: gpp.GemPyPyroModel, geo_model: gp.data.GeoModel,
y_obs_list: torch.Tensor, config: NUTSConfig, plot_trace:bool=False,
run_posterior_predictive:bool=False) -> az.InferenceData:

nuts_kernel = NUTS(
prob_model,
step_size=config.step_size,
adapt_step_size=config.adapt_step_size,
target_accept_prob=config.target_accept_prob,
max_tree_depth=config.max_tree_depth,
init_strategy=init_to_mean
)
data = run_mcmc_for_NUTS(
geo_model=geo_model,
nuts_kernel=nuts_kernel,
prob_model=prob_model,
y_obs_list=y_obs_list,
num_samples=config.num_samples,
warmup_steps=config.warmup_steps,
plot_trace=plot_trace,
run_posterior_predictive=run_posterior_predictive,
)

return data


def run_mcmc_for_NUTS(
geo_model: gp.data.GeoModel,
nuts_kernel: NUTS,
prob_model: gpp.GemPyPyroModel,
y_obs_list: torch.Tensor,
num_samples: int,
warmup_steps: int,
plot_trace: bool = False,
run_posterior_predictive: bool = False,
) -> az.InferenceData:
mcmc = MCMC(
kernel=nuts_kernel,
num_samples=num_samples,
warmup_steps=warmup_steps,
disable_validation=False,
)
mcmc.run(geo_model, y_obs_list)

if run_posterior_predictive:
posterior_predictive = Predictive(
model=prob_model,
posterior_samples=mcmc.get_samples(),
)(geo_model, y_obs_list)
data: az.InferenceData = az.from_pyro(
posterior=mcmc,
posterior_predictive=posterior_predictive,
)
else:
data: az.InferenceData = az.from_pyro(posterior=mcmc)

if plot_trace:
az.plot_trace(data)
plt.show()

return data
12 changes: 12 additions & 0 deletions gempy_probability/core/samplers_data.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
import dataclasses


@dataclasses.dataclass
class NUTSConfig:
step_size: float = 0.0085
adapt_step_size: bool = True
target_accept_prob: float = 0.9
max_tree_depth: int = 10
init_strategy: str = 'auto'
num_samples: int = 200
warmup_steps: int = 50
1 change: 1 addition & 0 deletions gempy_probability/modules/forwards/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from ._gempy_fw import run_gempy_forward
19 changes: 19 additions & 0 deletions gempy_probability/modules/forwards/_gempy_fw.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
import gempy as gp
import gempy_engine
from gempy_engine.core.backend_tensor import BackendTensor
from gempy_engine.core.data.interpolation_input import InterpolationInput


def run_gempy_forward(interp_input: InterpolationInput, geo_model: gp.data.GeoModel) -> gp.data.Solutions:

# compute model
sol = gempy_engine.compute_model(
interpolation_input=interp_input,
options=geo_model.interpolation_options,
data_descriptor=geo_model.input_data_descriptor,
geophysics_input=geo_model.geophysics_input
)

return sol


3 changes: 1 addition & 2 deletions gempy_probability/modules/likelihoods/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,2 @@
from ._apparent_thickness import apparent_thickness_likelihood, apparent_thickness_likelihood_II
from ._gravity_inv import gravity_inversion_likelihood
from ._likelihood_functions import thickness_likelihood

9 changes: 0 additions & 9 deletions gempy_probability/modules/likelihoods/_gravity_inv.py

This file was deleted.

14 changes: 14 additions & 0 deletions gempy_probability/modules/likelihoods/_likelihood_functions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
import pyro.distributions as dist
import gempy as gp
from ._apparent_thickness import apparent_thickness_likelihood


def gravity_likelihood(geo_model: gp.data.Solutions):
# e.g. multivariate normal with precomputed cov matrix
raise NotImplementedError("This function is not yet implemented")
return dist.MultivariateNormal(simulated_gravity, covariance_matrix)

def thickness_likelihood(solutions: gp.data.Solutions) -> dist:
# e.g. multivariate normal with precomputed cov matrix
simulated_thickness = apparent_thickness_likelihood(solutions)
return dist.Normal(simulated_thickness, 25)

This file was deleted.

12 changes: 6 additions & 6 deletions gempy_probability/modules/model_definition/model_examples.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,15 @@
import gempy as gp
import gempy.core.data
import gempy_engine
from gempy.modules.data_manipulation.engine_factory import interpolation_input_from_structural_frame
from gempy_probability.modules.likelihoods import apparent_thickness_likelihood
from gempy.modules.data_manipulation import interpolation_input_from_structural_frame

import pyro
import pyro.distributions as dist
from pyro import distributions as dist
import torch

from gempy_probability.modules.likelihoods._apparent_thickness import apparent_thickness_likelihood

def two_wells_prob_model_I(geo_model: gempy.core.data.GeoModel, normal, y_obs_list):

def two_wells_prob_model_I(geo_model: gp.data.GeoModel, normal, y_obs_list):
"""
This Pyro model represents the probabilistic aspects of the geological model.
It defines a prior distribution for the top layer's location and
Expand Down Expand Up @@ -44,7 +44,7 @@ def two_wells_prob_model_I(geo_model: gempy.core.data.GeoModel, normal, y_obs_li

# region Forward model computation

geo_model.counter +=1
# geo_model.counter +=1

# * Compute the geological model
geo_model.solutions = gempy_engine.compute_model(
Expand Down
Loading