Skip to content

Commit d709820

Browse files
committed
- [DOC/WIP] Adapting the code to the new Solutions classes
1 parent dd4c79d commit d709820

File tree

1 file changed

+46
-38
lines changed

1 file changed

+46
-38
lines changed

gempy_plugins/kriging/kriging.py

Lines changed: 46 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -7,32 +7,39 @@
77
"""
88

99
import warnings
10+
from typing import Optional, Sequence
11+
12+
import gempy as gp
13+
14+
from gempy_engine.core.data.raw_arrays_solution import RawArraysSolution
15+
1016
try:
1117
from scipy.spatial.distance import cdist
1218
except ImportError:
1319
warnings.warn('scipy.spatial package is not installed.')
1420

1521
import numpy as np
1622
import pandas as pd
17-
from gempy_viewer import _plot
18-
from gempy_viewer import helpers, _visualization_2d
1923
import matplotlib.cm as cm
2024
import matplotlib.pyplot as plt
2125
from copy import deepcopy
2226

23-
class domain(object):
2427

25-
def __init__(self, model, domain=None, data=None, set_mean=None):
28+
class Domain:
29+
def __init__(self, model_solutions: gp.data.Solutions, domain: Optional[Sequence] = None, data=None, set_mean=None):
2630

2731
# set model from a gempy solution
2832
# TODO: Check if I actually need all this or if it's easier to just get grid and lith of the solution
29-
self.sol = model
33+
self.sol: RawArraysSolution = model_solutions.raw_arrays
3034

3135
# set kriging surfaces, basically in which lithologies to do all this, default is everything
3236
# TODO: Maybe also allow to pass a gempy regular grid object
3337
if domain is None:
3438
domain = np.unique(self.sol.lith_block)
35-
self.set_domain(domain)
39+
self.set_domain(
40+
domain=domain,
41+
grid_values=model_solutions.octrees_output[-1].grid_centers.regular_grid.original_values
42+
)
3643

3744
# set data, default is None
3845
# TODO: need to figure out a way to then set mean and variance for the SGS and SK
@@ -49,7 +56,7 @@ def __init__(self, model, domain=None, data=None, set_mean=None):
4956
self.inp_var = np.var(data[:, 3])
5057
self.inp_std = np.sqrt(self.inp_var)
5158

52-
def set_domain(self, domain):
59+
def set_domain(self, domain: np.ndarray, grid_values: np.ndarray):
5360
"""
5461
Method to cut domain by array of surfaces. Simply masking the lith_block with array of input lithologies
5562
applying mask to grid.
@@ -63,11 +70,12 @@ def set_domain(self, domain):
6370
self.domain = domain
6471

6572
# mask by array of input surfaces (by id, can be from different series)
73+
6674
self.mask = np.isin(self.sol.lith_block, self.domain)
6775

6876
# Apply mask to lith_block and grid
6977
self.krig_lith = self.sol.lith_block[self.mask]
70-
self.krig_grid = self.sol.grid.values[self.mask]
78+
self.krig_grid = grid_values[self.mask]
7179

7280
def set_data(self, data):
7381
"""
@@ -86,7 +94,7 @@ def set_data(self, data):
8694
self.data_df = pd.DataFrame(data=d)
8795

8896

89-
class variogram_model(object):
97+
class VariogramModel(object):
9098

9199
# class containing all the variogram functionality
92100

@@ -182,36 +190,35 @@ def plot(self, type_='variogram', show_parameters=True):
182190

183191
if show_parameters == True:
184192
plt.axhline(self.sill, color='black', lw=1)
185-
plt.text(self.range_*2, self.sill, 'sill', fontsize=12, va='center', ha='center', backgroundcolor='w')
193+
plt.text(self.range_ * 2, self.sill, 'sill', fontsize=12, va='center', ha='center', backgroundcolor='w')
186194
plt.axvline(self.range_, color='black', lw=1)
187-
plt.text(self.range_, self.sill/2, 'range', fontsize=12, va='center', ha='center', backgroundcolor='w')
195+
plt.text(self.range_, self.sill / 2, 'range', fontsize=12, va='center', ha='center', backgroundcolor='w')
188196

189197
if type_ == 'variogram':
190-
d = np.arange(0, self.range_*4, self.range_/1000)
198+
d = np.arange(0, self.range_ * 4, self.range_ / 1000)
191199
plt.plot(d, self.calculate_semivariance(d), label=self.theoretical_model + " variogram model")
192200
plt.ylabel('semivariance')
193201
plt.title('Variogram model')
194202
plt.legend()
195203

196204
if type_ == 'covariance':
197-
d = np.arange(0, self.range_*4, self.range_/1000)
205+
d = np.arange(0, self.range_ * 4, self.range_ / 1000)
198206
plt.plot(d, self.calculate_covariance(d), label=self.theoretical_model + " covariance model")
199207
plt.ylabel('covariance')
200208
plt.title('Covariance model')
201209
plt.legend()
202210

203211
if type_ == 'both':
204-
d = np.arange(0, self.range_*4, self.range_/1000)
212+
d = np.arange(0, self.range_ * 4, self.range_ / 1000)
205213
plt.plot(d, self.calculate_semivariance(d), label=self.theoretical_model + " variogram model")
206214
plt.plot(d, self.calculate_covariance(d), label=self.theoretical_model + " covariance model")
207215
plt.ylabel('semivariance/covariance')
208216
plt.title('Models of spatial correlation')
209217
plt.legend()
210218

211219
plt.xlabel('lag distance')
212-
plt.ylim(0-self.sill/20, self.sill+self.sill/20)
213-
plt.xlim(0, self.range_*4)
214-
220+
plt.ylim(0 - self.sill / 20, self.sill + self.sill / 20)
221+
plt.xlim(0, self.range_ * 4)
215222

216223

217224
class field_solution(object):
@@ -240,7 +247,7 @@ def plot_results(self, geo_data, prop='val', direction='y', result='interpolatio
240247
Returns:
241248
242249
"""
243-
a = np.full_like(self.domain.mask, np.nan, dtype=np.double) #array like lith_block but with nan if outside domain
250+
a = np.full_like(self.domain.mask, np.nan, dtype=np.double) # array like lith_block but with nan if outside domain
244251

245252
est_vals = self.results_df['estimated value'].values
246253
est_var = self.results_df['estimation variance'].values
@@ -257,13 +264,13 @@ def plot_results(self, geo_data, prop='val', direction='y', result='interpolatio
257264
else:
258265
print('prop must be val var or both')
259266

260-
#create plot object
267+
# create plot object
261268
p = _visualization_2d.PlotSolution(geo_data)
262269
_a, _b, _c, extent_val, x, y = p._slice(direction, cell_number)[:-2]
263270

264-
#colors
271+
# colors
265272
cmap = cm.get_cmap(cmap)
266-
cmap.set_bad(color='w', alpha=alpha) #define color and alpha for nan values
273+
cmap.set_bad(color='w', alpha=alpha) # define color and alpha for nan values
267274

268275
# plot
269276
if prop is not 'both':
@@ -299,6 +306,7 @@ def plot_results(self, geo_data, prop='val', direction='y', result='interpolatio
299306
helpers.add_colorbar(im2, label='variance[]')
300307
plt.tight_layout()
301308

309+
302310
# TODO: check with new ordinary kriging and nugget effect
303311
def simple_kriging(a, b, prop, var_mod, inp_mean):
304312
'''
@@ -320,11 +328,11 @@ def simple_kriging(a, b, prop, var_mod, inp_mean):
320328
w = np.zeros((shape))
321329

322330
# Filling matrices with covariances based on calculated distances
323-
C[:shape, :shape] = var_mod.calculate_covariance(b) #? cov or semiv
324-
c[:shape] = var_mod.calculate_covariance(a) #? cov or semiv
331+
C[:shape, :shape] = var_mod.calculate_covariance(b) # ? cov or semiv
332+
c[:shape] = var_mod.calculate_covariance(a) # ? cov or semiv
325333

326334
# nugget effect for simple kriging - don't remember why I set this actively; should be the same
327-
#np.fill_diagonal(C, self.sill)
335+
# np.fill_diagonal(C, self.sill)
328336

329337
# TODO: find way to check quality of matrix and solutions for instability
330338
# Solve Kriging equations
@@ -337,6 +345,7 @@ def simple_kriging(a, b, prop, var_mod, inp_mean):
337345

338346
return result, pred_var
339347

348+
340349
def ordinary_kriging(a, b, prop, var_mod):
341350
'''
342351
Method for ordinary kriging calculation.
@@ -381,6 +390,7 @@ def ordinary_kriging(a, b, prop, var_mod):
381390

382391
return result, pred_var
383392

393+
384394
def create_kriged_field(domain, variogram_model, distance_type='euclidian',
385395
moving_neighbourhood='all', kriging_type='OK', n_closest_points=20):
386396
'''
@@ -449,14 +459,15 @@ def create_kriged_field(domain, variogram_model, distance_type='euclidian',
449459

450460
# create dataframe of results data for calling
451461
d = {'X': domain.krig_grid[:, 0], 'Y': domain.krig_grid[:, 1], 'Z': domain.krig_grid[:, 2],
452-
'estimated value': kriging_result_vals, 'estimation variance': kriging_result_vars}
462+
'estimated value': kriging_result_vals, 'estimation variance': kriging_result_vars}
453463

454464
results_df = pd.DataFrame(data=d)
455465

456466
return field_solution(domain, variogram_model, results_df, field_type='interpolation')
457467

468+
458469
def create_gaussian_field(domain, variogram_model, distance_type='euclidian',
459-
moving_neighbourhood='all', kriging_type='OK', n_closest_points=20):
470+
moving_neighbourhood='all', kriging_type='OK', n_closest_points=20):
460471
'''
461472
Method to create a kriged field over the defined grid of the gempy solution depending on the defined
462473
input data (conditioning).
@@ -472,9 +483,9 @@ def create_gaussian_field(domain, variogram_model, distance_type='euclidian',
472483
np.random.shuffle(shuffled_grid)
473484

474485
# append shuffled grid to input locations
475-
sgs_locations = np.vstack((domain.data[:,:3],shuffled_grid))
486+
sgs_locations = np.vstack((domain.data[:, :3], shuffled_grid))
476487
# create array for input properties
477-
sgs_prop_updating = domain.data[:,3] # use this and then always stack new ant end
488+
sgs_prop_updating = domain.data[:, 3] # use this and then always stack new ant end
478489

479490
# container for estimation variances
480491
estimation_var = np.zeros(len(shuffled_grid))
@@ -493,9 +504,9 @@ def create_gaussian_field(domain, variogram_model, distance_type='euclidian',
493504
for i in range(len(domain.krig_grid)):
494505
# STEP 1: cut update distance matrix to correct size
495506
# HAVE TO CHECK IF THIS IS REALLY CORRECT
496-
active_distance_matrix = dist_all_to_all[:active_data,:active_data]
497-
active_distance_vector = dist_all_to_all[:,active_data] #basically next point to be simulated
498-
active_distance_vector = active_distance_vector[:active_data] #cut to left or diagonal
507+
active_distance_matrix = dist_all_to_all[:active_data, :active_data]
508+
active_distance_vector = dist_all_to_all[:, active_data] # basically next point to be simulated
509+
active_distance_vector = active_distance_vector[:active_data] # cut to left or diagonal
499510

500511
# TODO: NEED PART FOR ZERO INPUT OR NO POINTS IN RANGE OR LESS THAN N POINTS
501512

@@ -512,7 +523,7 @@ def create_gaussian_field(domain, variogram_model, distance_type='euclidian',
512523
# This seems to work
513524
if len(sgs_prop_updating) <= n_closest_points:
514525
a = active_distance_vector[:active_data]
515-
b = active_distance_matrix[:active_data,:active_data]
526+
b = active_distance_matrix[:active_data, :active_data]
516527
prop = sgs_prop_updating
517528

518529
# this does not # DAMN THIS STILL HAS ITSELF RIGHT? PROBLEM!
@@ -552,22 +563,19 @@ def create_gaussian_field(domain, variogram_model, distance_type='euclidian',
552563

553564
# append to prop:
554565
sgs_prop_updating = np.append(sgs_prop_updating, estimate)
555-
estimation_var[i]= var
566+
estimation_var[i] = var
556567

557568
# at end of loop: include simulated point for next step
558569
active_data += 1
559570

560571
# delete original input data from results
561-
simulated_prop = sgs_prop_updating[len(domain.data[:,3]):] # check if this works like intended
572+
simulated_prop = sgs_prop_updating[len(domain.data[:, 3]):] # check if this works like intended
562573

563574
# create dataframe of results data for calling
564575
d = {'X': shuffled_grid[:, 0], 'Y': shuffled_grid[:, 1], 'Z': shuffled_grid[:, 2],
565576
'estimated value': simulated_prop, 'estimation variance': estimation_var}
566577

567578
results_df = pd.DataFrame(data=d)
568-
results_df = results_df.sort_values(['X','Y','Z'])
579+
results_df = results_df.sort_values(['X', 'Y', 'Z'])
569580

570581
return field_solution(domain, variogram_model, results_df, field_type='simulation')
571-
572-
573-

0 commit comments

Comments
 (0)