Skip to content

Commit 9debf48

Browse files
committed
[TEST] Adding test to visualize gempy posterior models
1 parent d13fbbf commit 9debf48

File tree

3 files changed

+60
-4
lines changed

3 files changed

+60
-4
lines changed

tests/test_posterior_analysis/__init__.py

Whitespace-only changes.
Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
import numpy as np
2+
import os
3+
import pyro.distributions as dist
4+
import torch
5+
6+
import gempy as gp
7+
from gempy_engine.core.backend_tensor import BackendTensor
8+
import arviz as az
9+
import matplotlib.pyplot as plt
10+
11+
12+
def test_gempy_posterior_I():
13+
current_dir = os.path.dirname(os.path.abspath(__file__))
14+
data_path = os.path.abspath(os.path.join(current_dir, '..', '..', 'examples', 'tutorials', 'data'))
15+
geo_model = gp.create_geomodel(
16+
project_name='Wells',
17+
extent=[0, 12000, -500, 500, 0, 4000],
18+
refinement=3,
19+
importer_helper=gp.data.ImporterHelper(
20+
path_to_orientations=os.path.join(data_path, "2-layers", "2-layers_orientations.csv"),
21+
path_to_surface_points=os.path.join(data_path, "2-layers", "2-layers_surface_points.csv")
22+
)
23+
)
24+
25+
geo_model.interpolation_options.uni_degree = 0
26+
geo_model.interpolation_options.mesh_extraction = False
27+
geo_model.interpolation_options.sigmoid_slope = 1100.
28+
29+
data = az.from_netcdf("../arviz_data.nc")
30+
31+
if PLOT:= False:
32+
_plot_posterior(data)
33+
34+
35+
def _plot_posterior(data):
36+
az.plot_trace(data)
37+
plt.show()
38+
from gempy_probability.modules.plot.plot_posterior import default_red, default_blue
39+
az.plot_density(
40+
data=[data.posterior_predictive, data.prior_predictive],
41+
shade=.9,
42+
var_names=[r'$\mu_{thickness}$'],
43+
data_labels=["Posterior Predictive", "Prior Predictive"],
44+
colors=[default_red, default_blue],
45+
)
46+
plt.show()
47+
az.plot_density(
48+
data=[data, data.prior],
49+
shade=.9,
50+
hdi_prob=.99,
51+
data_labels=["Posterior", "Prior"],
52+
colors=[default_red, default_blue],
53+
)
54+
plt.show()

tests/test_prob_model/test_prob_I.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,9 @@ def test_basic_gempy_I() -> None:
2020
)
2121
)
2222

23+
# TODO: Make this a more elegant way
24+
BackendTensor.change_backend_gempy(engine_backend=gp.data.AvailableBackends.PYTORCH)
25+
2326
# TODO: Convert this into an options preset
2427
geo_model.interpolation_options.uni_degree = 0
2528
geo_model.interpolation_options.mesh_extraction = False
@@ -37,8 +40,6 @@ def test_basic_gempy_I() -> None:
3740
assert geo_model.grid.values.shape[0] == 100, "Custom grid should have 100 cells"
3841
gp.compute_model(gempy_model=geo_model)
3942

40-
# TODO: Make this a more elegant way
41-
BackendTensor.change_backend_gempy(engine_backend=gp.data.AvailableBackends.PYTORCH)
4243

4344
normal = dist.Normal(
4445
loc=(geo_model.surface_points_copy_transformed.xyz[0, 2]),
@@ -105,6 +106,9 @@ def _prob_run(geo_model: gp.data.GeoModel, prob_model: callable,
105106
data = az.from_pyro(posterior=mcmc, prior=prior, posterior_predictive=posterior_predictive)
106107
# endregion
107108

109+
if True: # * Save the arviz data
110+
data.to_netcdf("arviz_data.nc")
111+
108112
az.plot_trace(data)
109113
plt.show()
110114
from gempy_probability.modules.plot.plot_posterior import default_red, default_blue
@@ -140,5 +144,3 @@ def _prob_run(geo_model: gp.data.GeoModel, prob_model: callable,
140144
print(f"Thickness mean: {posterior_thickness_mean}")
141145
print("MCMC convergence diagnostics:")
142146
print(f"Divergences: {float(data.sample_stats.diverging.sum())}")
143-
print(f"Acceptance rate: {float(data.sample_stats.acceptance_rate.mean())}")
144-
print(f"Mean tree depth: {float(data.sample_stats.mean_tree_depth.mean())}")

0 commit comments

Comments
 (0)