Skip to content

Commit 53dc76e

Browse files
committed
[WIP] A bit of model analysis
1 parent 25ca165 commit 53dc76e

File tree

3 files changed

+41
-5
lines changed

3 files changed

+41
-5
lines changed

docs/dev_logs/2025_May.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,12 @@
22

33
--- General
44
- [ ] Create a proper suit of tests
5+
- [ ] Keep extracting code into modules and functions
6+
- [ ] Posterior analysis for gempy
57
- [ ] Investigate speed
68
- [ ] Investigate segmentation function
79
- [ ] Create a robust api to make sure we do not mess up assigning priors to gempy parameters
10+
- [ ] do we need to use index_put or just normal indexing?
811
--- Vector examples
912
- [ ] Bring the Vector example here...
1013
- [ ] Using categories as likelihood function (maybe I need to check out with Sam)

gempy_probability/modules/model_definition/model_examples.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ def model(geo_model: gempy.core.data.GeoModel, sp_coords_copy, y_obs_list):
1818
prior_mean = sp_coords_copy[0, 2]
1919
normal = dist.Normal(
2020
loc=prior_mean,
21-
scale=torch.tensor(0.02, dtype=torch.float64)
21+
scale=torch.tensor(0.1, dtype=torch.float64)
2222
)
2323

2424
mu_top = pyro.sample(
@@ -36,6 +36,7 @@ def model(geo_model: gempy.core.data.GeoModel, sp_coords_copy, y_obs_list):
3636
indices=(torch.tensor([0]), torch.tensor([2])),
3737
values=mu_top
3838
)
39+
# interpolation_input.surface_points.sp_coords[0, 2] = mu_top
3940

4041
# endregion
4142

@@ -61,6 +62,6 @@ def model(geo_model: gempy.core.data.GeoModel, sp_coords_copy, y_obs_list):
6162

6263
y_thickness = pyro.sample(
6364
name=r'$y_{thickness}$',
64-
fn=dist.Normal(thickness, 50),
65+
fn=dist.Normal(thickness, 25),
6566
obs=y_obs_list
6667
)

tests/test_prob_model/test_prob_I.py

Lines changed: 35 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import numpy as np
55

66

7-
def test_basic_gempy_I():
7+
def test_basic_gempy_I() -> None:
88
current_dir = os.path.dirname(os.path.abspath(__file__))
99
data_path = os.path.abspath(os.path.join(current_dir, '..', '..', 'examples', 'tutorials', 'data'))
1010
geo_model = gp.create_geomodel(
@@ -84,6 +84,7 @@ def test_basic_gempy_I():
8484
from pyro.infer import NUTS
8585
from pyro.infer import MCMC
8686
from pyro.infer.autoguide import init_to_mean
87+
8788
pyro.primitives.enable_validation(is_validate=True)
8889
nuts_kernel = NUTS(
8990
model,
@@ -93,11 +94,42 @@ def test_basic_gempy_I():
9394
max_tree_depth=10,
9495
init_strategy=init_to_mean
9596
)
96-
mcmc = MCMC(nuts_kernel, num_samples=200, warmup_steps=50, disable_validation=False)
97+
mcmc = MCMC(
98+
kernel=nuts_kernel,
99+
num_samples=200,
100+
warmup_steps=50,
101+
disable_validation=False
102+
)
97103
mcmc.run(geo_model, sp_coords_copy, y_obs_list)
98104

99105
posterior_samples = mcmc.get_samples()
100-
posterior_predictive = Predictive(model, posterior_samples)(geo_model, sp_coords_copy, y_obs_list)
106+
107+
posterior_predictive_fn = Predictive(
108+
model=model,
109+
posterior_samples=posterior_samples
110+
)
111+
112+
posterior_predictive = posterior_predictive_fn(geo_model, sp_coords_copy, y_obs_list)
113+
101114
data = az.from_pyro(posterior=mcmc, prior=prior, posterior_predictive=posterior_predictive)
102115
az.plot_trace(data)
103116
plt.show()
117+
118+
from gempy_probability.modules.plot.plot_posterior import default_red, default_blue
119+
az.plot_density(
120+
data=[data.posterior_predictive, data.prior_predictive],
121+
shade=.9,
122+
var_names=[r'$\mu_{thickness}$'],
123+
data_labels=["Posterior Predictive", "Prior Predictive"],
124+
colors=[default_red, default_blue],
125+
)
126+
plt.show()
127+
128+
az.plot_density(
129+
data=[data, data.prior],
130+
shade=.9,
131+
hdi_prob=.99,
132+
data_labels=["Posterior", "Prior"],
133+
colors=[default_red, default_blue],
134+
)
135+
plt.show()

0 commit comments

Comments
 (0)