Skip to content

Commit 03450f4

Browse files
committed
[DOC] Adding a ton of docs to construct the model
1 parent 8464715 commit 03450f4

File tree

1 file changed

+90
-8
lines changed

1 file changed

+90
-8
lines changed

gempy_probability/modules/model_definition/prob_model_factory.py

Lines changed: 90 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -9,30 +9,112 @@
99

1010
def make_gempy_pyro_model(
1111
priors: Dict[str, Distribution],
12-
set_interp_input_fn: Callable[..., gp.data.Solutions],
12+
set_interp_input_fn: Callable[
13+
[Dict[str, torch.Tensor], gp.data.GeoModel],
14+
gp.data.InterpolationInput
15+
],
1316
likelihood_fn: Callable[[gp.data.Solutions], Distribution],
1417
obs_name: str = "obs"
15-
):
18+
) -> Callable[[gp.data.GeoModel, torch.Tensor], None]:
19+
"""
20+
Factory to produce a Pyro model for GemPy forward simulations.
1621
17-
def model(geo_model, obs_data):
18-
# 1) Sample each prior
19-
samples = {}
22+
This returns a `model(geo_model, obs_data)` function that you can hand
23+
directly to Pyro's inference APIs (`Predictive`, `MCMC`, …). Internally
24+
the generated model does:
25+
26+
1. Samples each key/value in `priors` via `pyro.sample(name, dist)`.
27+
2. Calls `set_interp_input_fn(samples, geo_model)` to inject those samples
28+
into a new `InterpolationInput`.
29+
3. Runs the GemPy forward solver with `run_gempy_forward(...)`.
30+
4. Wraps the resulting `Solutions` in `likelihood_fn(...)` to get a Pyro
31+
distribution, then observes `obs_data` under that distribution.
32+
33+
Parameters
34+
----------
35+
priors : Dict[str, Distribution]
36+
A mapping from Pyro sample‐site names to Pyro Distribution objects
37+
defining your prior beliefs over each parameter.
38+
set_interp_input_fn : Callable[[samples, geo_model], InterpolationInput]
39+
A user function which receives:
40+
41+
* `samples` (Dict[str, Tensor]): the values drawn from your priors
42+
* `geo_model` (gempy.core.data.GeoModel): the base geological model
43+
44+
and must return a ready‐to‐use `InterpolationInput` for GemPy.
45+
likelihood_fn : Callable[[Solutions], Distribution]
46+
A function mapping the GemPy forward‐model output (`Solutions`)
47+
to a Pyro Distribution representing the likelihood of the observed data.
48+
obs_name : str, optional
49+
The Pyro site name under which `obs_data` will be observed.
50+
Defaults to `"obs"`.
51+
52+
Returns
53+
-------
54+
model : Callable[[GeoModel, Tensor], None]
55+
A Pyro‐compatible model taking:
56+
57+
* `geo_model`: your GemPy GeoModel object
58+
* `obs_data`: a torch.Tensor of observations
59+
60+
This function has no return value; it registers its random draws
61+
via `pyro.sample`.
62+
63+
Example
64+
-------
65+
>>> import pyro.distributions as dist
66+
>>> from gempy_probability import make_gempy_pyro_model
67+
>>> from gempy.modules.data_manipulation import interpolation_input_from_structural_frame
68+
>>> import gempy_probability
69+
>>>
70+
>>> # 1) Define a simple Normal prior on μ
71+
>>> priors = {"μ": dist.Normal(0., 1.)}
72+
>>>
73+
>>> # 2) Function to inject μ into your interp-input
74+
>>> def set_input(samples, gm):
75+
... inp = interpolation_input_from_structural_frame(gm)
76+
... inp.surface_points.sp_coords = torch.index_put(
77+
... input=inp.surface_points.sp_coords,
78+
... indices=(torch.tensor([0]), torch.tensor([2])), # * This has to be Tensors
79+
... values=(samples["μ"])
80+
... )
81+
... return inp
82+
>>>
83+
>>>
84+
>>> # 4) Build the model
85+
>>> pyro_model = make_gempy_pyro_model(
86+
... priors=priors,
87+
... set_interp_input_fn=set_input,
88+
... likelihood_fn=gempy_probability.likelihoods.thickness_likelihood,
89+
... obs_name="y")
90+
>>>
91+
>>>
92+
>>> # Now this can be used with Predictive or MCMC directly:
93+
>>> # Predictive(pyro_model, num_samples=100)(geo_model, obs_tensor)
94+
"""
95+
def model(geo_model: gp.data.GeoModel, obs_data: torch.Tensor):
96+
# 1) Sample from the user‐supplied priors
97+
samples: Dict[str, torch.Tensor] = {}
2098
for name, dist in priors.items():
2199
samples[name] = pyro.sample(name, dist)
22100

23-
# 3) Run your forward geological model
101+
# 2) Build the new interpolation input
102+
interp_input = set_interp_input_fn(samples, geo_model)
103+
104+
# 3) Run GemPy forward simulation
24105
simulated: gp.data.Solutions = run_gempy_forward(
25-
interp_input=set_interp_input_fn(samples, geo_model),
106+
interp_input=interp_input,
26107
geo_model=geo_model
27108
)
28109

29-
# 4) Build likelihood and observe
110+
# 4) Wrap in likelihood & observe
30111
lik_dist = likelihood_fn(simulated)
31112
pyro.sample(obs_name, lik_dist, obs=obs_data)
32113

33114
return model
34115

35116

117+
36118
def make_generic_pyro_model(
37119
priors: Dict[str, Distribution],
38120
forward_fn: Callable[..., torch.Tensor],

0 commit comments

Comments
 (0)