4
4
import pyro .distributions as dist
5
5
import torch
6
6
from pyro .distributions import Distribution
7
+ import arviz as az
8
+ import matplotlib .pyplot as plt
7
9
8
10
import gempy as gp
9
- from gempy_engine .core .backend_tensor import BackendTensor
10
11
import gempy_probability as gpp
12
+ from gempy_engine .core .backend_tensor import BackendTensor
11
13
from gempy_engine .core .data .interpolation_input import InterpolationInput
12
14
from gempy_probability .core .samplers_data import NUTSConfig
13
15
@@ -29,9 +31,6 @@ def test_prob_model_factory() -> None:
29
31
)
30
32
)
31
33
32
- # TODO: Make this a more elegant way
33
- BackendTensor .change_backend_gempy (engine_backend = gp .data .AvailableBackends .PYTORCH )
34
-
35
34
# TODO: Convert this into an options preset
36
35
geo_model .interpolation_options .uni_degree = 0
37
36
geo_model .interpolation_options .mesh_extraction = False
@@ -65,15 +64,36 @@ def test_prob_model_factory() -> None:
65
64
priors = model_priors ,
66
65
set_interp_input_fn = modify_z_for_surface_point1 ,
67
66
likelihood_fn = gpp .likelihoods .thickness_likelihood ,
68
- obs_name = "obs_gravity "
67
+ obs_name = "wells_thickness "
69
68
)
70
69
71
- _prob_run (
70
+ data : az . InferenceData = _prob_run (
72
71
geo_model = geo_model ,
73
72
prob_model = pyro_gravity_model ,
74
73
y_obs_list = torch .tensor ([200 , 210 , 190 ])
75
74
)
76
75
76
+ if False : # * Save the arviz data
77
+ data .to_netcdf ("arviz_data.nc" )
78
+
79
+ _plot (data )
80
+
81
+ # Test posterior mean values
82
+ posterior_top_mean = float (data .posterior [r'$\mu_{top}$' ].mean ())
83
+ target_top = 0.00875
84
+ target_thickness = 223
85
+ assert abs (posterior_top_mean - target_top ) < 0.0200 , f"Top layer mean { posterior_top_mean } outside expected range"
86
+ posterior_thickness_mean = float (data .posterior_predictive [r'$\mu_{thickness}$' ].mean ())
87
+ assert abs (posterior_thickness_mean - target_thickness ) < 5 , f"Thickness mean { posterior_thickness_mean } outside expected range"
88
+ # Test convergence diagnostics
89
+ assert float (data .sample_stats .diverging .sum ()) == 0 , "MCMC sampling has divergences"
90
+
91
+ print ("Posterior mean values:" )
92
+ print (f"Top layer mean: { posterior_top_mean } " )
93
+ print (f"Thickness mean: { posterior_thickness_mean } " )
94
+ print ("MCMC convergence diagnostics:" )
95
+ print (f"Divergences: { float (data .sample_stats .diverging .sum ())} " )
96
+
77
97
78
98
def modify_z_for_surface_point1 (
79
99
samples : dict [str , Distribution ],
@@ -87,23 +107,15 @@ def modify_z_for_surface_point1(
87
107
new_tensor : torch .Tensor = torch .index_put (
88
108
input = interp_input .surface_points .sp_coords ,
89
109
indices = (torch .tensor ([0 ]), torch .tensor ([2 ])), # * This has to be Tensors
90
- values = (samples [prior_key ])
91
- )
110
+ values = (samples [prior_key ]
111
+ )
92
112
interp_input .surface_points .sp_coords = new_tensor
93
113
return interp_input
94
114
95
115
96
116
def _prob_run (geo_model : gp .data .GeoModel , prob_model : gpp .GemPyPyroModel ,
97
- y_obs_list : torch .Tensor ) -> None :
117
+ y_obs_list : torch .Tensor ) -> az . InferenceData :
98
118
# Run prior sampling and visualization
99
- from pyro .infer import Predictive
100
- import pyro
101
- import arviz as az
102
- import matplotlib .pyplot as plt
103
-
104
- from pyro .infer import NUTS
105
- from pyro .infer import MCMC
106
- from pyro .infer .autoguide import init_to_mean
107
119
108
120
# region prior sampling
109
121
prior_inference = gpp .run_predictive (
@@ -133,42 +145,12 @@ def _prob_run(geo_model: gp.data.GeoModel, prob_model: gpp.GemPyPyroModel,
133
145
plot_trace = False ,
134
146
run_posterior_predictive = True
135
147
)
136
- # pyro.primitives.enable_validation(is_validate=True)
137
- # nuts_kernel = NUTS(
138
- # prob_model,
139
- # step_size=0.0085,
140
- # adapt_step_size=True,
141
- # target_accept_prob=0.9,
142
- # max_tree_depth=10,
143
- # init_strategy=init_to_mean
144
- # )
145
- # mcmc = MCMC(
146
- # kernel=nuts_kernel,
147
- # num_samples=200,
148
- # warmup_steps=50,
149
- # disable_validation=False
150
- # )
151
- # mcmc.run(geo_model, y_obs_list)
152
- # posterior_samples = mcmc.get_samples()
153
- # posterior_predictive_fn = Predictive(
154
- # model=prob_model,
155
- # posterior_samples=posterior_samples
156
- # )
157
- # posterior_predictive = posterior_predictive_fn(geo_model, y_obs_list)
158
- #
159
- # data = az.from_pyro(
160
- # posterior=mcmc,
161
- # # prior=prior.to_dict()["prior"],
162
- # posterior_predictive=posterior_predictive
163
- # )
164
148
data .extend (prior_inference )
149
+ return data
165
150
# endregion
166
151
167
- # print("Number of interpolations: ", geo_model.counter)
168
-
169
- if False : # * Save the arviz data
170
- data .to_netcdf ("arviz_data.nc" )
171
152
153
+ def _plot (data ):
172
154
az .plot_trace (data )
173
155
plt .show ()
174
156
from gempy_probability .modules .plot .plot_posterior import default_red , default_blue
@@ -188,7 +170,6 @@ def _prob_run(geo_model: gp.data.GeoModel, prob_model: gpp.GemPyPyroModel,
188
170
colors = [default_red , default_blue ],
189
171
)
190
172
plt .show ()
191
-
192
173
az .plot_density (
193
174
data = [data .posterior_predictive , data .prior ],
194
175
shade = .9 ,
@@ -197,19 +178,3 @@ def _prob_run(geo_model: gp.data.GeoModel, prob_model: gpp.GemPyPyroModel,
197
178
colors = [default_red , default_blue ],
198
179
)
199
180
plt .show ()
200
-
201
- # Test posterior mean values
202
- posterior_top_mean = float (data .posterior [r'$\mu_{top}$' ].mean ())
203
- target_top = 0.00875
204
- target_thickness = 223
205
- assert abs (posterior_top_mean - target_top ) < 0.0200 , f"Top layer mean { posterior_top_mean } outside expected range"
206
- posterior_thickness_mean = float (data .posterior_predictive [r'$\mu_{thickness}$' ].mean ())
207
- assert abs (posterior_thickness_mean - target_thickness ) < 5 , f"Thickness mean { posterior_thickness_mean } outside expected range"
208
- # Test convergence diagnostics
209
- assert float (data .sample_stats .diverging .sum ()) == 0 , "MCMC sampling has divergences"
210
-
211
- print ("Posterior mean values:" )
212
- print (f"Top layer mean: { posterior_top_mean } " )
213
- print (f"Thickness mean: { posterior_thickness_mean } " )
214
- print ("MCMC convergence diagnostics:" )
215
- print (f"Divergences: { float (data .sample_stats .diverging .sum ())} " )
0 commit comments