Skip to content

Commit f5d088c

Browse files
authored
Add posterior analysis functionality with saving and visualization (#6)
# Adding test to visualize gempy posterior models Added a new test module for posterior analysis that visualizes GemPy models using posterior samples. The test loads posterior data from an Arviz netcdf file and plots the mean surface positions across multiple iterations. Key changes: - Created new test directory for posterior analysis - Added functionality to save Arviz data to netcdf for later analysis - Moved backend initialization before model computation for consistency - Added visualization of posterior mean surface positions using gempy_viewer - Enhanced dev logs with notes about saving posteriors and likelihood functions The test demonstrates how to load posterior samples and visualize their impact on geological models, providing a foundation for more advanced posterior analysis workflows.
2 parents 4dacef7 + 774b63d commit f5d088c

File tree

4 files changed

+111
-5
lines changed

4 files changed

+111
-5
lines changed

docs/dev_logs/2025_May.md

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,4 +24,11 @@
2424
- Priors
2525
- Likelihood
2626
----
27-
- Having a runner or something like that?
27+
- Having a runner or something like that?
28+
29+
### Saving posteriors gempy models for later analysis
30+
- ? Do we want to try to save the full input as versioning so that we can just have one system to go back and forth?
31+
32+
### Likelihood functions
33+
- ? Should I start to create a catalog of likelihood functions that goes with gempy?
34+
- This would be useful to be able to visualize whatever the likelihood function represents independent of the probabilistic model.

tests/test_posterior_analysis/__init__.py

Whitespace-only changes.
Lines changed: 97 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,97 @@
1+
import arviz as az
2+
import matplotlib.pyplot as plt
3+
import numpy as np
4+
import os
5+
6+
import gempy as gp
7+
import gempy_viewer as gpv
8+
from gempy_engine.core.backend_tensor import BackendTensor
9+
from gempy_viewer.API._plot_2d_sections_api import plot_sections
10+
from gempy_viewer.core.data_to_show import DataToShow
11+
12+
13+
def test_gempy_posterior_I():
14+
current_dir = os.path.dirname(os.path.abspath(__file__))
15+
data_path = os.path.abspath(os.path.join(current_dir, '..', '..', 'examples', 'tutorials', 'data'))
16+
geo_model = gp.create_geomodel(
17+
project_name='Wells',
18+
extent=[0, 12000, -500, 500, 0, 4000],
19+
refinement=4,
20+
importer_helper=gp.data.ImporterHelper(
21+
path_to_orientations=os.path.join(data_path, "2-layers", "2-layers_orientations.csv"),
22+
path_to_surface_points=os.path.join(data_path, "2-layers", "2-layers_surface_points.csv")
23+
)
24+
)
25+
26+
BackendTensor.change_backend_gempy(engine_backend=gp.data.AvailableBackends.PYTORCH)
27+
28+
geo_model.interpolation_options.uni_degree = 0
29+
geo_model.interpolation_options.mesh_extraction = False
30+
31+
data = az.from_netcdf("../arviz_data.nc")
32+
33+
if PLOT := False:
34+
_plot_posterior(data)
35+
36+
gp.compute_model(gempy_model=geo_model)
37+
p2d = gpv.plot_2d(
38+
model=geo_model,
39+
show_topography=False,
40+
legend=False,
41+
show_lith=False,
42+
show=False
43+
)
44+
45+
posterior_top_mean_z: np.ndarray = (data.posterior[r'$\mu_{top}$'].values[0, :])
46+
xyz = np.zeros((posterior_top_mean_z.shape[0], 3))
47+
xyz[:, 2] = posterior_top_mean_z
48+
world_coord = geo_model.input_transform.apply_inverse(xyz)
49+
i = 0
50+
for i in range(0, 200, 5):
51+
gp.modify_surface_points(
52+
geo_model=geo_model,
53+
slice=0,
54+
Z=world_coord[i, 2]
55+
)
56+
gp.compute_model(gempy_model=geo_model)
57+
58+
plot_sections(
59+
gempy_model=geo_model,
60+
sections_data=p2d.section_data_list,
61+
data_to_show=DataToShow(
62+
n_axis=1,
63+
show_data=True,
64+
show_surfaces=True,
65+
show_lith=False
66+
),
67+
kwargs_boundaries={
68+
"linewidth": 0.5,
69+
"alpha" : 0.1,
70+
},
71+
)
72+
73+
p2d.fig.show()
74+
75+
pass
76+
77+
78+
def _plot_posterior(data):
79+
az.plot_trace(data)
80+
plt.show()
81+
from gempy_probability.modules.plot.plot_posterior import default_red, default_blue
82+
az.plot_density(
83+
data=[data.posterior_predictive, data.prior_predictive],
84+
shade=.9,
85+
var_names=[r'$\mu_{thickness}$'],
86+
data_labels=["Posterior Predictive", "Prior Predictive"],
87+
colors=[default_red, default_blue],
88+
)
89+
plt.show()
90+
az.plot_density(
91+
data=[data, data.prior],
92+
shade=.9,
93+
hdi_prob=.99,
94+
data_labels=["Posterior", "Prior"],
95+
colors=[default_red, default_blue],
96+
)
97+
plt.show()

tests/test_prob_model/test_prob_I.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,9 @@ def test_basic_gempy_I() -> None:
2020
)
2121
)
2222

23+
# TODO: Make this a more elegant way
24+
BackendTensor.change_backend_gempy(engine_backend=gp.data.AvailableBackends.PYTORCH)
25+
2326
# TODO: Convert this into an options preset
2427
geo_model.interpolation_options.uni_degree = 0
2528
geo_model.interpolation_options.mesh_extraction = False
@@ -37,8 +40,6 @@ def test_basic_gempy_I() -> None:
3740
assert geo_model.grid.values.shape[0] == 100, "Custom grid should have 100 cells"
3841
gp.compute_model(gempy_model=geo_model)
3942

40-
# TODO: Make this a more elegant way
41-
BackendTensor.change_backend_gempy(engine_backend=gp.data.AvailableBackends.PYTORCH)
4243

4344
normal = dist.Normal(
4445
loc=(geo_model.surface_points_copy_transformed.xyz[0, 2]),
@@ -105,6 +106,9 @@ def _prob_run(geo_model: gp.data.GeoModel, prob_model: callable,
105106
data = az.from_pyro(posterior=mcmc, prior=prior, posterior_predictive=posterior_predictive)
106107
# endregion
107108

109+
if True: # * Save the arviz data
110+
data.to_netcdf("arviz_data.nc")
111+
108112
az.plot_trace(data)
109113
plt.show()
110114
from gempy_probability.modules.plot.plot_posterior import default_red, default_blue
@@ -140,5 +144,3 @@ def _prob_run(geo_model: gp.data.GeoModel, prob_model: callable,
140144
print(f"Thickness mean: {posterior_thickness_mean}")
141145
print("MCMC convergence diagnostics:")
142146
print(f"Divergences: {float(data.sample_stats.diverging.sum())}")
143-
print(f"Acceptance rate: {float(data.sample_stats.acceptance_rate.mean())}")
144-
print(f"Mean tree depth: {float(data.sample_stats.mean_tree_depth.mean())}")

0 commit comments

Comments
 (0)