@@ -35,22 +35,16 @@ def test_basic_gempy_I() -> None:
35
35
engine_config = gp .data .GemPyEngineConfig (backend = gp .data .AvailableBackends .numpy )
36
36
)
37
37
38
- # TODO: This is the part that has to go to a function no question
39
- # Probabilistic Geomodeling with Pyro
40
- # -----------------------------------
41
- # In this section, we introduce a probabilistic approach to geological modeling.
42
- # By using Pyro, a probabilistic programming language, we define a model that integrates
43
- # geological data with uncertainty quantification.
44
-
45
- from gempy_engine .core .data .interpolation_input import InterpolationInput
46
38
from gempy_engine .core .backend_tensor import BackendTensor
47
- from gempy .modules .data_manipulation .engine_factory import interpolation_input_from_structural_frame
48
-
49
- interpolation_input_copy : InterpolationInput = interpolation_input_from_structural_frame (geo_model )
50
- sp_coords_copy = interpolation_input_copy .surface_points .sp_coords
51
- # Change the backend to PyTorch for probabilistic modeling
52
39
BackendTensor .change_backend_gempy (engine_backend = gp .data .AvailableBackends .PYTORCH )
40
+
41
+ import pyro .distributions as dist
42
+ import torch
53
43
44
+ normal = dist .Normal (
45
+ loc = (geo_model .surface_points_copy_transformed .xyz [0 , 2 ]),
46
+ scale = torch .tensor (0.1 , dtype = torch .float64 )
47
+ )
54
48
# %%
55
49
# Running Prior Sampling and Visualization
56
50
# ----------------------------------------
@@ -75,7 +69,7 @@ def test_basic_gempy_I() -> None:
75
69
num_samples = 50
76
70
)
77
71
78
- prior = predictive (geo_model , sp_coords_copy , y_obs_list )
72
+ prior = predictive (geo_model , normal , y_obs_list )
79
73
80
74
data = az .from_pyro (prior = prior )
81
75
az .plot_trace (data .prior )
@@ -100,7 +94,7 @@ def test_basic_gempy_I() -> None:
100
94
warmup_steps = 50 ,
101
95
disable_validation = False
102
96
)
103
- mcmc .run (geo_model , sp_coords_copy , y_obs_list )
97
+ mcmc .run (geo_model , normal , y_obs_list )
104
98
105
99
posterior_samples = mcmc .get_samples ()
106
100
@@ -109,7 +103,7 @@ def test_basic_gempy_I() -> None:
109
103
posterior_samples = posterior_samples
110
104
)
111
105
112
- posterior_predictive = posterior_predictive_fn (geo_model , sp_coords_copy , y_obs_list )
106
+ posterior_predictive = posterior_predictive_fn (geo_model , normal , y_obs_list )
113
107
114
108
data = az .from_pyro (posterior = mcmc , prior = prior , posterior_predictive = posterior_predictive )
115
109
az .plot_trace (data )
0 commit comments