|
9 | 9 |
|
10 | 10 | def make_gempy_pyro_model(
|
11 | 11 | priors: Dict[str, Distribution],
|
12 |
| - set_interp_input_fn: Callable[..., gp.data.Solutions], |
| 12 | + set_interp_input_fn: Callable[ |
| 13 | + [Dict[str, torch.Tensor], gp.data.GeoModel], |
| 14 | + gp.data.InterpolationInput |
| 15 | + ], |
13 | 16 | likelihood_fn: Callable[[gp.data.Solutions], Distribution],
|
14 | 17 | obs_name: str = "obs"
|
15 |
| -): |
| 18 | +) -> Callable[[gp.data.GeoModel, torch.Tensor], None]: |
| 19 | + """ |
| 20 | + Factory to produce a Pyro model for GemPy forward simulations. |
16 | 21 |
|
17 |
| - def model(geo_model, obs_data): |
18 |
| - # 1) Sample each prior |
19 |
| - samples = {} |
| 22 | + This returns a `model(geo_model, obs_data)` function that you can hand |
| 23 | + directly to Pyro's inference APIs (`Predictive`, `MCMC`, …). Internally |
| 24 | + the generated model does: |
| 25 | +
|
| 26 | + 1. Samples each key/value in `priors` via `pyro.sample(name, dist)`. |
| 27 | + 2. Calls `set_interp_input_fn(samples, geo_model)` to inject those samples |
| 28 | + into a new `InterpolationInput`. |
| 29 | + 3. Runs the GemPy forward solver with `run_gempy_forward(...)`. |
| 30 | + 4. Wraps the resulting `Solutions` in `likelihood_fn(...)` to get a Pyro |
| 31 | + distribution, then observes `obs_data` under that distribution. |
| 32 | +
|
| 33 | + Parameters |
| 34 | + ---------- |
| 35 | + priors : Dict[str, Distribution] |
| 36 | + A mapping from Pyro sample‐site names to Pyro Distribution objects |
| 37 | + defining your prior beliefs over each parameter. |
| 38 | + set_interp_input_fn : Callable[[samples, geo_model], InterpolationInput] |
| 39 | + A user function which receives: |
| 40 | + |
| 41 | + * `samples` (Dict[str, Tensor]): the values drawn from your priors |
| 42 | + * `geo_model` (gempy.core.data.GeoModel): the base geological model |
| 43 | + |
| 44 | + and must return a ready‐to‐use `InterpolationInput` for GemPy. |
| 45 | + likelihood_fn : Callable[[Solutions], Distribution] |
| 46 | + A function mapping the GemPy forward‐model output (`Solutions`) |
| 47 | + to a Pyro Distribution representing the likelihood of the observed data. |
| 48 | + obs_name : str, optional |
| 49 | + The Pyro site name under which `obs_data` will be observed. |
| 50 | + Defaults to `"obs"`. |
| 51 | +
|
| 52 | + Returns |
| 53 | + ------- |
| 54 | + model : Callable[[GeoModel, Tensor], None] |
| 55 | + A Pyro‐compatible model taking: |
| 56 | +
|
| 57 | + * `geo_model`: your GemPy GeoModel object |
| 58 | + * `obs_data`: a torch.Tensor of observations |
| 59 | +
|
| 60 | + This function has no return value; it registers its random draws |
| 61 | + via `pyro.sample`. |
| 62 | +
|
| 63 | + Example |
| 64 | + ------- |
| 65 | + >>> import pyro.distributions as dist |
| 66 | + >>> from gempy_probability import make_gempy_pyro_model |
| 67 | + >>> from gempy.modules.data_manipulation import interpolation_input_from_structural_frame |
| 68 | + >>> import gempy_probability |
| 69 | + >>> |
| 70 | + >>> # 1) Define a simple Normal prior on μ |
| 71 | + >>> priors = {"μ": dist.Normal(0., 1.)} |
| 72 | + >>> |
| 73 | + >>> # 2) Function to inject μ into your interp-input |
| 74 | + >>> def set_input(samples, gm): |
| 75 | + ... inp = interpolation_input_from_structural_frame(gm) |
| 76 | + ... inp.surface_points.sp_coords = torch.index_put( |
| 77 | + ... input=inp.surface_points.sp_coords, |
| 78 | + ... indices=(torch.tensor([0]), torch.tensor([2])), # * This has to be Tensors |
| 79 | + ... values=(samples["μ"]) |
| 80 | + ... ) |
| 81 | + ... return inp |
| 82 | + >>> |
| 83 | + >>> |
| 84 | + >>> # 4) Build the model |
| 85 | + >>> pyro_model = make_gempy_pyro_model( |
| 86 | + ... priors=priors, |
| 87 | + ... set_interp_input_fn=set_input, |
| 88 | + ... likelihood_fn=gempy_probability.likelihoods.thickness_likelihood, |
| 89 | + ... obs_name="y") |
| 90 | + >>> |
| 91 | + >>> |
| 92 | + >>> # Now this can be used with Predictive or MCMC directly: |
| 93 | + >>> # Predictive(pyro_model, num_samples=100)(geo_model, obs_tensor) |
| 94 | + """ |
| 95 | + def model(geo_model: gp.data.GeoModel, obs_data: torch.Tensor): |
| 96 | + # 1) Sample from the user‐supplied priors |
| 97 | + samples: Dict[str, torch.Tensor] = {} |
20 | 98 | for name, dist in priors.items():
|
21 | 99 | samples[name] = pyro.sample(name, dist)
|
22 | 100 |
|
23 |
| - # 3) Run your forward geological model |
| 101 | + # 2) Build the new interpolation input |
| 102 | + interp_input = set_interp_input_fn(samples, geo_model) |
| 103 | + |
| 104 | + # 3) Run GemPy forward simulation |
24 | 105 | simulated: gp.data.Solutions = run_gempy_forward(
|
25 |
| - interp_input=set_interp_input_fn(samples, geo_model), |
| 106 | + interp_input=interp_input, |
26 | 107 | geo_model=geo_model
|
27 | 108 | )
|
28 | 109 |
|
29 |
| - # 4) Build likelihood and observe |
| 110 | + # 4) Wrap in likelihood & observe |
30 | 111 | lik_dist = likelihood_fn(simulated)
|
31 | 112 | pyro.sample(obs_name, lik_dist, obs=obs_data)
|
32 | 113 |
|
33 | 114 | return model
|
34 | 115 |
|
35 | 116 |
|
| 117 | + |
36 | 118 | def make_generic_pyro_model(
|
37 | 119 | priors: Dict[str, Distribution],
|
38 | 120 | forward_fn: Callable[..., torch.Tensor],
|
|
0 commit comments