1
1
import warnings
2
2
3
- from gempy_engine .config import DEBUG_MODE , AvailableBackends
4
- from gempy_engine .core .backend_tensor import BackendTensor as bt , BackendTensor
3
+ from ...config import DEBUG_MODE , AvailableBackends
4
+ from ...core .backend_tensor import BackendTensor as bt , BackendTensor
5
+ from ...core .data .exported_fields import ExportedFields
6
+ from ._soft_segment import soft_segment_unbounded
7
+
5
8
import numpy as np
6
9
import numbers
7
10
8
- from gempy_engine .core .data .exported_fields import ExportedFields
9
-
10
11
11
12
def activate_formation_block (exported_fields : ExportedFields , ids : np .ndarray ,
12
13
sigmoid_slope : float ) -> np .ndarray :
@@ -23,9 +24,17 @@ def activate_formation_block(exported_fields: ExportedFields, ids: np.ndarray,
23
24
sigmoid_slope = sigmoid_slope
24
25
)
25
26
else :
27
+ sigm = soft_segment_unbounded (
28
+ Z = Z_x ,
29
+ edges = scalar_value_at_sp ,
30
+ ids = ids ,
31
+ sigmoid_slope = sigmoid_slope
32
+ )
33
+ return sigm
34
+
26
35
match BackendTensor .engine_backend :
27
36
case AvailableBackends .PYTORCH :
28
- sigm = soft_segment_unbounded (
37
+ sigm = soft_segment_unbounded_torch (
29
38
Z = Z_x ,
30
39
edges = scalar_value_at_sp ,
31
40
ids = ids ,
@@ -85,7 +94,7 @@ def _compute_sigmoid(Z_x, scale_0, scale_1, drift_0, drift_1, drift_id, sigmoid_
85
94
import torch
86
95
87
96
88
- def soft_segment_unbounded (Z , edges , ids , sigmoid_slope ):
97
+ def soft_segment_unbounded_torch (Z , edges , ids , sigmoid_slope ):
89
98
"""
90
99
Z: (...,) tensor of scalar values
91
100
edges: (K-1,) tensor of finite split points [e1, e2, ..., e_{K-1}]
@@ -124,7 +133,7 @@ def soft_segment_unbounded(Z, edges, ids, sigmoid_slope):
124
133
125
134
# weighted sum by the ids
126
135
ids__sum = (membership * ids ).sum (dim = - 1 )
127
-
136
+
128
137
# make it at least 2d
129
138
ids__sum = ids__sum [None , :]
130
139
@@ -151,10 +160,9 @@ def soft_segment_unbounded_np(Z, edges, ids, sigmoid_slope):
151
160
case np .ndarray ():
152
161
membership = _final_faults_segmentation (Z , edges , sigmoid_slope )
153
162
case numbers .Number ():
154
- membership = _lith_segmentation (Z , edges , ids , sigmoid_slope )
163
+ membership = _lith_segmentation (Z , edges , ids , sigmoid_slope )
155
164
case _:
156
- raise ValueError ("sigmoid_slope must be a float or an array" )
157
-
165
+ raise ValueError ("sigmoid_slope must be a float or an array" )
158
166
159
167
ids__sum = np .sum (membership * ids , axis = - 1 )
160
168
return np .atleast_2d (ids__sum )
@@ -179,7 +187,6 @@ def _final_faults_segmentation(Z, edges, sigmoid_slope):
179
187
180
188
181
189
def _lith_segmentation (Z , edges , ids , sigmoid_slope ):
182
-
183
190
# 1) per-edge temperatures τ_k = |Δ_k|/(4·m)
184
191
jumps = np .abs (ids [1 :] - ids [:- 1 ]) # shape (K-1,)
185
192
tau_k = jumps / float (sigmoid_slope ) # shape (K-1,)
0 commit comments