Skip to content

Commit c41b2c2

Browse files
committed
[WIP] Getting there
1 parent 99b494b commit c41b2c2

File tree

2 files changed

+9
-5
lines changed

2 files changed

+9
-5
lines changed

gempy_probability/modules/model_definition/model_examples.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -79,13 +79,17 @@ def two_wells_prob_model_I(geo_model: gempy.core.data.GeoModel, normal, y_obs_li
7979
# 1) Define your priors
8080
gravity_priors = {
8181
# e.g. density contrast of layer index 3
82-
"mu_density": dist.Normal(2.62, 0.5)
82+
# "mu_density": dist.Normal(2.62, 0.5),
83+
"mu_thickness": dist.Normal(
84+
loc=torch.tensor(0.0487, dtype=torch.float64),
85+
scale=torch.tensor(0.1, dtype=torch.float64)
86+
)
8387
}
8488

8589
# 2) Wrap your forward pass into a small fn
8690
def run_gempy_forward(samples, geo_model):
8791

88-
mu_top = samples["mu_density"]
92+
mu_top = samples["mu_thickness"]
8993

9094
interp_input = interpolation_input_from_structural_frame(geo_model)
9195

tests/test_prob_model/test_prob_factory.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -101,20 +101,20 @@ def _prob_run(geo_model: gp.data.GeoModel, prob_model: callable,
101101
warmup_steps=50,
102102
disable_validation=False
103103
)
104-
mcmc.run(geo_model, normal, y_obs_list)
104+
mcmc.run(geo_model, y_obs_list)
105105
posterior_samples = mcmc.get_samples()
106106
posterior_predictive_fn = Predictive(
107107
model=prob_model,
108108
posterior_samples=posterior_samples
109109
)
110-
posterior_predictive = posterior_predictive_fn(geo_model, normal, y_obs_list)
110+
posterior_predictive = posterior_predictive_fn(geo_model, y_obs_list)
111111

112112
data = az.from_pyro(posterior=mcmc, prior=prior, posterior_predictive=posterior_predictive)
113113
# endregion
114114

115115
# print("Number of interpolations: ", geo_model.counter)
116116

117-
if True: # * Save the arviz data
117+
if False: # * Save the arviz data
118118
data.to_netcdf("arviz_data.nc")
119119

120120
az.plot_trace(data)

0 commit comments

Comments
 (0)