9
9
from gempy_engine .core .backend_tensor import BackendTensor
10
10
import gempy_probability as gpp
11
11
from gempy_engine .core .data .interpolation_input import InterpolationInput
12
+ from gempy_probability .core .samplers_data import NUTSConfig
12
13
13
14
seed = 123456
14
15
torch .manual_seed (seed )
15
16
pyro .set_rng_seed (seed )
16
17
18
+
17
19
def test_prob_model_factory () -> None :
18
20
current_dir = os .path .dirname (os .path .abspath (__file__ ))
19
21
data_path = os .path .abspath (os .path .join (current_dir , '..' , '..' , 'examples' , 'tutorials' , 'data' ))
@@ -71,19 +73,20 @@ def test_prob_model_factory() -> None:
71
73
prob_model = pyro_gravity_model ,
72
74
y_obs_list = torch .tensor ([200 , 210 , 190 ])
73
75
)
74
-
76
+
77
+
75
78
def modify_z_for_surface_point1 (
76
79
samples : dict [str , Distribution ],
77
80
geo_model : gp .data .GeoModel ,
78
81
) -> InterpolationInput :
79
82
# TODO: We can make a factory for this type of functions
80
83
prior_key = r'$\mu_{top}$'
81
-
84
+
82
85
from gempy .modules .data_manipulation import interpolation_input_from_structural_frame
83
86
interp_input = interpolation_input_from_structural_frame (geo_model )
84
87
new_tensor : torch .Tensor = torch .index_put (
85
88
input = interp_input .surface_points .sp_coords ,
86
- indices = (torch .tensor ([0 ]), torch .tensor ([2 ])), # * This has to be Tensors
89
+ indices = (torch .tensor ([0 ]), torch .tensor ([2 ])), # * This has to be Tensors
87
90
values = (samples [prior_key ])
88
91
)
89
92
interp_input .surface_points .sp_coords = new_tensor
@@ -103,7 +106,7 @@ def _prob_run(geo_model: gp.data.GeoModel, prob_model: gpp.GemPyPyroModel,
103
106
from pyro .infer .autoguide import init_to_mean
104
107
105
108
# region prior sampling
106
- prior = gpp .run_predictive (
109
+ prior_inference = gpp .run_predictive (
107
110
prob_model = prob_model ,
108
111
geo_model = geo_model ,
109
112
y_obs_list = y_obs_list ,
@@ -114,30 +117,51 @@ def _prob_run(geo_model: gp.data.GeoModel, prob_model: gpp.GemPyPyroModel,
114
117
# endregion
115
118
116
119
# region inference
117
- pyro .primitives .enable_validation (is_validate = True )
118
- nuts_kernel = NUTS (
119
- prob_model ,
120
- step_size = 0.0085 ,
121
- adapt_step_size = True ,
122
- target_accept_prob = 0.9 ,
123
- max_tree_depth = 10 ,
124
- init_strategy = init_to_mean
125
- )
126
- mcmc = MCMC (
127
- kernel = nuts_kernel ,
128
- num_samples = 200 ,
129
- warmup_steps = 50 ,
130
- disable_validation = False
131
- )
132
- mcmc .run (geo_model , y_obs_list )
133
- posterior_samples = mcmc .get_samples ()
134
- posterior_predictive_fn = Predictive (
135
- model = prob_model ,
136
- posterior_samples = posterior_samples
120
+ data = gpp .run_nuts_inference (
121
+ prob_model = prob_model ,
122
+ geo_model = geo_model ,
123
+ y_obs_list = y_obs_list ,
124
+ config = NUTSConfig (
125
+ step_size = 0.0085 ,
126
+ adapt_step_size = True ,
127
+ target_accept_prob = 0.9 ,
128
+ max_tree_depth = 10 ,
129
+ init_strategy = 'auto' ,
130
+ num_samples = 200 ,
131
+ warmup_steps = 50 ,
132
+ ),
133
+ plot_trace = False ,
134
+ run_posterior_predictive = True
137
135
)
138
- posterior_predictive = posterior_predictive_fn (geo_model , y_obs_list )
139
-
140
- data = az .from_pyro (posterior = mcmc , prior = prior , posterior_predictive = posterior_predictive )
136
+ # pyro.primitives.enable_validation(is_validate=True)
137
+ # nuts_kernel = NUTS(
138
+ # prob_model,
139
+ # step_size=0.0085,
140
+ # adapt_step_size=True,
141
+ # target_accept_prob=0.9,
142
+ # max_tree_depth=10,
143
+ # init_strategy=init_to_mean
144
+ # )
145
+ # mcmc = MCMC(
146
+ # kernel=nuts_kernel,
147
+ # num_samples=200,
148
+ # warmup_steps=50,
149
+ # disable_validation=False
150
+ # )
151
+ # mcmc.run(geo_model, y_obs_list)
152
+ # posterior_samples = mcmc.get_samples()
153
+ # posterior_predictive_fn = Predictive(
154
+ # model=prob_model,
155
+ # posterior_samples=posterior_samples
156
+ # )
157
+ # posterior_predictive = posterior_predictive_fn(geo_model, y_obs_list)
158
+ #
159
+ # data = az.from_pyro(
160
+ # posterior=mcmc,
161
+ # # prior=prior.to_dict()["prior"],
162
+ # posterior_predictive=posterior_predictive
163
+ # )
164
+ data .extend (prior_inference )
141
165
# endregion
142
166
143
167
# print("Number of interpolations: ", geo_model.counter)
@@ -149,7 +173,7 @@ def _prob_run(geo_model: gp.data.GeoModel, prob_model: gpp.GemPyPyroModel,
149
173
plt .show ()
150
174
from gempy_probability .modules .plot .plot_posterior import default_red , default_blue
151
175
az .plot_density (
152
- data = [data .posterior_predictive , data .prior_predictive ],
176
+ data = [data .posterior_predictive , data .prior ],
153
177
shade = .9 ,
154
178
var_names = [r'$\mu_{thickness}$' ],
155
179
data_labels = ["Posterior Predictive" , "Prior Predictive" ],
@@ -165,6 +189,15 @@ def _prob_run(geo_model: gp.data.GeoModel, prob_model: gpp.GemPyPyroModel,
165
189
)
166
190
plt .show ()
167
191
192
+ az .plot_density (
193
+ data = [data .posterior_predictive , data .prior ],
194
+ shade = .9 ,
195
+ hdi_prob = .99 ,
196
+ data_labels = ["Posterior" , "Prior" ],
197
+ colors = [default_red , default_blue ],
198
+ )
199
+ plt .show ()
200
+
168
201
# Test posterior mean values
169
202
posterior_top_mean = float (data .posterior [r'$\mu_{top}$' ].mean ())
170
203
target_top = 0.00875
@@ -180,5 +213,3 @@ def _prob_run(geo_model: gp.data.GeoModel, prob_model: gpp.GemPyPyroModel,
180
213
print (f"Thickness mean: { posterior_thickness_mean } " )
181
214
print ("MCMC convergence diagnostics:" )
182
215
print (f"Divergences: { float (data .sample_stats .diverging .sum ())} " )
183
-
184
-
0 commit comments