Skip to content

Commit 65a3057

Browse files
committed
- [DOC] Topology example revived
1 parent 47b88b8 commit 65a3057

File tree

1 file changed

+30
-67
lines changed

1 file changed

+30
-67
lines changed

gempy_plugins/topology_analysis/topology.py

Lines changed: 30 additions & 67 deletions
Original file line numberDiff line numberDiff line change
@@ -24,17 +24,14 @@
2424

2525

2626
def _get_nunconf(geo_model) -> int:
27-
return np.count_nonzero( geo_model._stack.df.BottomRelation == "Erosion" ) - 2 # TODO -2 n other lith series
27+
return np.count_nonzero(geo_model._stack.df.BottomRelation == "Erosion") - 2 # TODO -2 n other lith series
2828

2929

3030
def _get_nfaults(geo_model) -> int:
3131
return np.count_nonzero(geo_model._faults.df.isFault)
3232

3333

3434
def _get_fault_blocks(geo_model: gp.data.GeoModel) -> np.ndarray:
35-
# n_unconf = _get_nunconf(geo_model)
36-
# n_faults = _get_nfaults(geo_model)
37-
3835
fault_blocks = geo_model.solutions.raw_arrays.block_matrix[geo_model.structural_frame.group_is_fault]
3936
resolution = geo_model.solutions.octrees_output[-1].grid_centers.regular_grid.resolution
4037

@@ -43,7 +40,6 @@ def _get_fault_blocks(geo_model: gp.data.GeoModel) -> np.ndarray:
4340

4441

4542
def _get_lith_blocks(geo_model: gp.data.GeoModel) -> np.ndarray:
46-
4743
lith_blocks = geo_model.solutions.raw_arrays.block_matrix[[not x for x in geo_model.structural_frame.group_is_fault]]
4844
resolution = geo_model.solutions.octrees_output[-1].grid_centers.regular_grid.resolution
4945

@@ -158,10 +154,8 @@ def _analyze_topology(
158154
fault_shift = fault_matrix_sum.min()
159155
fault_matrix_sum_shift = fault_matrix_sum - fault_shift
160156

161-
where = np.tile(lith_matrix, (n_lith, 1)) == np.unique(lith_matrix).reshape(
162-
-1, 1)
163-
lith_matrix_shift = np.sum(where * np.arange(n_lith).reshape(-1, 1),
164-
axis=0) + 1
157+
where = np.tile(lith_matrix, (n_lith, 1)) == np.unique(lith_matrix).reshape(-1, 1)
158+
lith_matrix_shift = np.sum(where * np.arange(n_lith).reshape(-1, 1), axis=0) + 1
165159

166160
topo_matrix = lith_matrix_shift + n_lith * fault_matrix_sum_shift
167161
topo_matrix_3D = topo_matrix.reshape(*res)
@@ -193,19 +187,14 @@ def _analyze_topology(
193187
else:
194188
z_edges = np.array([[], []])
195189

196-
edges = np.unique(
197-
np.concatenate((x_edges.T, y_edges.T, z_edges.T), axis=0), axis=0
198-
)
190+
edges = np.unique(np.concatenate((x_edges.T, y_edges.T, z_edges.T), axis=0), axis=0)
199191

200192
centroids = _get_centroids(topo_matrix_3D)
201193

202194
return edges, centroids
203195

204196

205-
def get_lot_node_to_lith_id(
206-
geo_model,
207-
centroids: Dict[int, np.ndarray]
208-
) -> Dict[int, int]:
197+
def get_lot_node_to_lith_id(geo_model, centroids: Dict[int, np.ndarray]) -> Dict[int, int]:
209198
"""Get look-up table to translate topology node id's back into GemPy lith
210199
id's.
211200
@@ -216,9 +205,8 @@ def get_lot_node_to_lith_id(
216205
Returns:
217206
Dict[int, int]: Look-up table translating node id -> lith id.
218207
"""
219-
lb = geo_model.solutions.lith_block.reshape(
220-
geo_model._grid.regular_grid.resolution
221-
).astype(int)
208+
resolution = geo_model.solutions.octrees_output[-1].grid_centers.regular_grid.resolution
209+
lb = geo_model.solutions.raw_arrays.lith_block.reshape(resolution).astype(int)
222210

223211
lot = {}
224212
for node, pos in centroids.items():
@@ -228,9 +216,7 @@ def get_lot_node_to_lith_id(
228216
return lot
229217

230218

231-
def get_lot_lith_to_node_id(
232-
lot: Dict[int, np.ndarray]
233-
) -> Dict[int, List[int]]:
219+
def get_lot_lith_to_node_id(lot: Dict[int, np.ndarray]) -> Dict[int, List[int]]:
234220
"""Get look-up table to translate lith id's back into topology node
235221
id's.
236222
@@ -250,10 +236,7 @@ def get_lot_lith_to_node_id(
250236
return lot2
251237

252238

253-
def get_lot_node_to_fault_block(
254-
geo_model,
255-
centroids: Dict[int, np.ndarray]
256-
) -> Dict[int, int]:
239+
def get_lot_node_to_fault_block( geo_model, centroids: Dict[int, np.ndarray] ) -> Dict[int, int]:
257240
"""Get a look-up table to access fault block id's for each topology node
258241
id.
259242
@@ -280,16 +263,14 @@ def get_fault_ids(geo_model) -> List[int]:
280263
Returns:
281264
List[int]: List of fault id's.
282265
"""
283-
f_series_names = geo_model._faults.df[geo_model._faults.df.isFault].index
284-
fault_ids = [0]
285-
for fsn in f_series_names:
286-
fid = geo_model._surfaces.df[
287-
geo_model._surfaces.df.series == fsn].id.values[0]
288-
fault_ids.append(fid)
266+
group_is_fault: list[bool] = geo_model.structural_frame.group_is_fault
267+
n_faults = np.sum(group_is_fault)
268+
fault_ids = [i for i in range(n_faults + 1)]
269+
289270
return fault_ids
290271

291272

292-
def get_lith_ids(geo_model, basement: bool = True) -> List[int]:
273+
def get_lith_ids(geo_model: gp.data.GeoModel) -> List[int]:
293274
""" Get lithology id's of all lithologies (except basement) in given
294275
geomodel.
295276
@@ -299,16 +280,13 @@ def get_lith_ids(geo_model, basement: bool = True) -> List[int]:
299280
Returns:
300281
List[int]: List of lithology id's.
301282
"""
302-
fmt_series_names = geo_model._faults.df[~geo_model._faults.df.isFault].index
303-
lith_ids = []
304-
for fsn in fmt_series_names:
305-
if not basement:
306-
if fsn == "Basement":
307-
continue
308-
lids = geo_model._surfaces.df[
309-
geo_model._surfaces.df.series == fsn].id.values
310-
for lid in lids:
311-
lith_ids.append(lid)
283+
# ! This is only working assuming that the faults are on top
284+
group_is_fault: list[bool] = geo_model.structural_frame.group_is_fault
285+
n_elements = geo_model.structural_frame.n_elements
286+
n_faults = np.sum(group_is_fault)
287+
288+
lith_ids = [i for i in range(n_faults + 1, n_elements + 1)]
289+
312290
return lith_ids
313291

314292

@@ -349,10 +327,7 @@ def get_detailed_labels(
349327
return edges_, centroids_
350328

351329

352-
def _get_edges(
353-
l: np.ndarray,
354-
r: np.ndarray
355-
) -> Optional[np.ndarray]:
330+
def _get_edges( l: np.ndarray, r: np.ndarray ) -> Optional[np.ndarray]:
356331
"""Get edges from given shifted arrays.
357332
358333
Args:
@@ -515,8 +490,7 @@ def plot_adjacency_matrix(
515490
n_faults = len(f_ids) // 2
516491
lith_ids = get_lith_ids(geo_model)
517492
n_liths = len(lith_ids)
518-
adj_matrix_labels, adj_matrix_lith_labels, adj_matrix_fault_labels = _get_adj_matrix_labels(
519-
geo_model)
493+
adj_matrix_labels, adj_matrix_lith_labels, adj_matrix_fault_labels = _get_adj_matrix_labels(geo_model)
520494
# ///////////////////////////////////////////////////////
521495
n = len(adj_matrix_labels)
522496
fig, ax = plt.subplots(figsize=(n // 2.5, n // 2.5))
@@ -536,13 +510,10 @@ def plot_adjacency_matrix(
536510

537511
# ///////////////////////////////////////////////////////
538512
# lith tick labels colors
539-
colors = list(geo_model._surfaces.colors.colordict.values())
540-
bboxkwargs = dict(
541-
edgecolor='none',
542-
)
543-
for xticklabel, yticklabel, l in zip(ax.xaxis.get_ticklabels(),
544-
ax.yaxis.get_ticklabels(),
545-
adj_matrix_labels[::1]):
513+
colors = geo_model.structural_frame.elements_colors
514+
# colors = list(geo_model._surfaces.colors.colordict.values())
515+
bboxkwargs = dict(edgecolor='none', )
516+
for xticklabel, yticklabel, l in zip(ax.xaxis.get_ticklabels(), ax.yaxis.get_ticklabels(), adj_matrix_labels[::1]):
546517
color = colors[l[0] - 1]
547518

548519
xticklabel.set_bbox(
@@ -569,8 +540,7 @@ def plot_adjacency_matrix(
569540
newax.spines['left'].set_position(('outward', 25))
570541
newax.set_ylim(0, n_faults * 2)
571542
newax.set_yticks(np.arange(1, n_faults * 2 + 1) - 0.5)
572-
newax.set_yticklabels(
573-
["FB " + str(i + 1) for i in range(n_faults * 2)][::1])
543+
newax.set_yticklabels( ["FB " + str(i + 1) for i in range(n_faults * 2)][::1])
574544

575545
# ///////////////////////////////////////////////////////
576546
# (dotted) lines for fb's
@@ -601,11 +571,7 @@ def plot_adjacency_matrix(
601571
return
602572

603573

604-
def check_adjacency(
605-
edges: set,
606-
n1: Union[int, str],
607-
n2: Union[int, str]
608-
) -> bool:
574+
def check_adjacency( edges: set, n1: Union[int, str], n2: Union[int, str] ) -> bool:
609575
"""Check if given nodes n1 and n2 are adjacent in given topology
610576
edge set.
611577
@@ -623,10 +589,7 @@ def check_adjacency(
623589
return False
624590

625591

626-
def get_adjacencies(
627-
edges: set,
628-
node: Union[int, str]
629-
) -> set:
592+
def get_adjacencies( edges: set, node: Union[int, str] ) -> set:
630593
"""Get node labels of all adjacent geobodies of geobody with given node
631594
in given set of edges.
632595

0 commit comments

Comments
 (0)