Skip to content

Commit b9328f5

Browse files
committed
[ENH] Refactor and enhance posterior analysis testing
Added 2D plotting functionality using gempy_viewer, improved test coverage by including posterior mean visualization, and adjusted geomodel refinement. Simplified readability by reorganizing imports, and replaced deprecated or less efficient operations.
1 parent 9debf48 commit b9328f5

File tree

1 file changed

+50
-7
lines changed

1 file changed

+50
-7
lines changed

tests/test_posterior_analysis/test_gempy_posterior_I.py

Lines changed: 50 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,13 @@
1+
import arviz as az
2+
import matplotlib.pyplot as plt
13
import numpy as np
24
import os
3-
import pyro.distributions as dist
4-
import torch
55

66
import gempy as gp
7+
import gempy_viewer as gpv
78
from gempy_engine.core.backend_tensor import BackendTensor
8-
import arviz as az
9-
import matplotlib.pyplot as plt
9+
from gempy_viewer.API._plot_2d_sections_api import plot_sections
10+
from gempy_viewer.core.data_to_show import DataToShow
1011

1112

1213
def test_gempy_posterior_I():
@@ -15,22 +16,64 @@ def test_gempy_posterior_I():
1516
geo_model = gp.create_geomodel(
1617
project_name='Wells',
1718
extent=[0, 12000, -500, 500, 0, 4000],
18-
refinement=3,
19+
refinement=4,
1920
importer_helper=gp.data.ImporterHelper(
2021
path_to_orientations=os.path.join(data_path, "2-layers", "2-layers_orientations.csv"),
2122
path_to_surface_points=os.path.join(data_path, "2-layers", "2-layers_surface_points.csv")
2223
)
2324
)
25+
26+
BackendTensor.change_backend_gempy(engine_backend=gp.data.AvailableBackends.PYTORCH)
2427

2528
geo_model.interpolation_options.uni_degree = 0
2629
geo_model.interpolation_options.mesh_extraction = False
27-
geo_model.interpolation_options.sigmoid_slope = 1100.
2830

2931
data = az.from_netcdf("../arviz_data.nc")
3032

31-
if PLOT:= False:
33+
if PLOT := False:
3234
_plot_posterior(data)
3335

36+
gp.compute_model(gempy_model=geo_model)
37+
p2d = gpv.plot_2d(
38+
model=geo_model,
39+
show_topography=False,
40+
legend=False,
41+
show_lith=False,
42+
show=False
43+
)
44+
45+
posterior_top_mean_z: np.ndarray = (data.posterior[r'$\mu_{top}$'].values[0, :])
46+
xyz = np.zeros((posterior_top_mean_z.shape[0], 3))
47+
xyz[:, 2] = posterior_top_mean_z
48+
world_coord = geo_model.input_transform.apply_inverse(xyz)
49+
i = 0
50+
for i in range(0, 200, 5):
51+
gp.modify_surface_points(
52+
geo_model=geo_model,
53+
slice=0,
54+
Z=world_coord[i, 2]
55+
)
56+
gp.compute_model(gempy_model=geo_model)
57+
58+
plot_sections(
59+
gempy_model=geo_model,
60+
sections_data=p2d.section_data_list,
61+
data_to_show=DataToShow(
62+
n_axis=1,
63+
show_data=True,
64+
show_surfaces=True,
65+
show_lith=False
66+
),
67+
kwargs_boundaries={
68+
"linewidth": 0.5,
69+
"alpha" : 0.1,
70+
},
71+
)
72+
73+
p2d.fig.show()
74+
75+
pass
76+
3477

3578
def _plot_posterior(data):
3679
az.plot_trace(data)

0 commit comments

Comments
 (0)