Skip to content

Commit efb261a

Browse files
committed
[ENH] Add run_predictive utility for probabilistic modeling
Introduces `run_predictive` in the API for running predictive sampling. Refactored existing code in `test_prob_factory` to use the new utility, enhancing modularity and reducing redundancy.
1 parent 393e204 commit efb261a

File tree

7 files changed

+53
-12
lines changed

7 files changed

+53
-12
lines changed

gempy_probability/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from .modules.model_definition.prob_model_factory import make_gempy_pyro_model, GemPyPyroModel
22
from .modules import likelihoods
3+
from .api.model_runner import run_predictive
34

45
from ._version import __version__
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
from ._pyro_runner import run_predictive
Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
import numpy as np
2+
import os
3+
import pyro
4+
import pyro.distributions as dist
5+
import torch
6+
from arviz import InferenceData
7+
from pyro.distributions import Distribution
8+
9+
import gempy as gp
10+
from gempy_engine.core.backend_tensor import BackendTensor
11+
import gempy_probability as gpp
12+
from gempy_engine.core.data.interpolation_input import InterpolationInput
13+
from pyro.infer import Predictive
14+
import pyro
15+
import arviz as az
16+
import matplotlib.pyplot as plt
17+
18+
from pyro.infer import NUTS
19+
from pyro.infer import MCMC
20+
from pyro.infer.autoguide import init_to_mean
21+
22+
23+
def run_predictive(prob_model: gpp.GemPyPyroModel, geo_model: gp.data.GeoModel,
24+
y_obs_list: torch.Tensor, n_samples: int, plot_trace:bool=False) -> az.InferenceData:
25+
predictive = Predictive(
26+
model=prob_model,
27+
num_samples=n_samples
28+
)
29+
prior: dict[str, torch.Tensor] = predictive(geo_model, y_obs_list)
30+
# print("Number of interpolations: ", geo_model.counter)
31+
32+
data: az.InferenceData = az.from_pyro(prior=prior)
33+
if plot_trace:
34+
az.plot_trace(data.prior)
35+
plt.show()
36+
37+
return data

gempy_probability/core/sampler.py

Whitespace-only changes.

gempy_probability/modules/model_definition/prob_model_factory.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from typing import Callable, Dict
55

66
import gempy as gp
7+
from gempy_engine.core.backend_tensor import BackendTensor
78
from gempy_probability.modules.forwards import run_gempy_forward
89

910
GemPyPyroModel = Callable[[gp.data.GeoModel, torch.Tensor], None]
@@ -13,7 +14,7 @@ def make_gempy_pyro_model(
1314
*,
1415
priors: Dict[str, Distribution],
1516
set_interp_input_fn: Callable[
16-
[Dict[str, torch.Tensor], gp.data.GeoModel],
17+
[Dict[str, Distribution], gp.data.GeoModel],
1718
gp.data.InterpolationInput
1819
],
1920
likelihood_fn: Callable[[gp.data.Solutions], Distribution],
@@ -95,6 +96,9 @@ def make_gempy_pyro_model(
9596
>>> # Now this can be used with Predictive or MCMC directly:
9697
>>> # Predictive(pyro_model, num_samples=100)(geo_model, obs_tensor)
9798
"""
99+
100+
BackendTensor.change_backend_gempy(engine_backend=gp.data.AvailableBackends.PYTORCH)
101+
98102
def model(geo_model: gp.data.GeoModel, obs_data: torch.Tensor):
99103
# 1) Sample from the user‐supplied priors
100104
samples: Dict[str, torch.Tensor] = {}

gempy_probability/modules/model_runner/__init__.py

Whitespace-only changes.

tests/test_prob_model/test_prob_factory.py

Lines changed: 9 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ def test_prob_model_factory() -> None:
3333
# TODO: Convert this into an options preset
3434
geo_model.interpolation_options.uni_degree = 0
3535
geo_model.interpolation_options.mesh_extraction = False
36-
geo_model.interpolation_options.sigmoid_slope = 1100.
36+
geo_model.interpolation_options.sigmoid_slope = 1100
3737

3838
# region Minimal grid for the specific likelihood function
3939
x_loc = 6000
@@ -77,7 +77,6 @@ def modify_z_for_surface_point1(
7777
geo_model: gp.data.GeoModel,
7878
) -> InterpolationInput:
7979
# TODO: We can make a factory for this type of functions
80-
8180
prior_key = r'$\mu_{top}$'
8281

8382
from gempy.modules.data_manipulation import interpolation_input_from_structural_frame
@@ -104,16 +103,13 @@ def _prob_run(geo_model: gp.data.GeoModel, prob_model: gpp.GemPyPyroModel,
104103
from pyro.infer.autoguide import init_to_mean
105104

106105
# region prior sampling
107-
predictive = Predictive(
108-
model=prob_model,
109-
num_samples=50
106+
prior = gpp.run_predictive(
107+
prob_model=prob_model,
108+
geo_model=geo_model,
109+
y_obs_list=y_obs_list,
110+
n_samples=50,
111+
plot_trace=True
110112
)
111-
prior = predictive(geo_model, y_obs_list)
112-
# print("Number of interpolations: ", geo_model.counter)
113-
114-
data = az.from_pyro(prior=prior)
115-
az.plot_trace(data.prior)
116-
plt.show()
117113

118114
# endregion
119115

@@ -184,3 +180,5 @@ def _prob_run(geo_model: gp.data.GeoModel, prob_model: gpp.GemPyPyroModel,
184180
print(f"Thickness mean: {posterior_thickness_mean}")
185181
print("MCMC convergence diagnostics:")
186182
print(f"Divergences: {float(data.sample_stats.diverging.sum())}")
183+
184+

0 commit comments

Comments
 (0)