Skip to content

Commit e33d627

Browse files
authored
[ENH] Add soft segment activation function for improved layer segmentation (#11)
# Improved Activation Functions with Soft Segmentation This PR introduces a new soft segmentation approach for activation functions with the following changes: - Added `_soft_segment.py` module with `soft_segment_unbounded` function that provides improved sigmoid-based segmentation - Implemented specialized segmentation functions for lithology and fault handling - Updated the activation interface to use the new soft segmentation by default - Added support for both scalar and array-based sigmoid slope parameters - Incorporated per-edge adaptive temperature calculations for better control of sigmoid transitions - Added deprecation warning to the legacy hard sigmoid activation function - Created comprehensive test cases to validate the new segmentation functions The new implementation provides more precise control over the transition between geological formations and supports both scalar field values and fault handling with improved numerical stability.
2 parents e143236 + c20d02c commit e33d627

File tree

7 files changed

+223
-16
lines changed

7 files changed

+223
-16
lines changed
Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,93 @@
1+
import numbers
2+
3+
import numpy as np
4+
5+
from ...core.backend_tensor import BackendTensor as bt, BackendTensor
6+
7+
try:
8+
import torch
9+
except ModuleNotFoundError:
10+
pass
11+
12+
13+
def soft_segment_unbounded(Z, edges, ids, sigmoid_slope):
14+
"""
15+
Z: array of shape (...,) of scalar values
16+
edges: array of shape (K-1,) of finite split points [e1, e2, ..., e_{K-1}]
17+
ids: array of shape (K,) of the id for each of the K bins
18+
sigmoid_slope: scalar target peak slope m > 0
19+
returns: array of shape (...,) of the soft-assigned id
20+
"""
21+
ids = bt.t.array(ids[::-1].copy())
22+
23+
# Check if sigmoid function is num or array
24+
match sigmoid_slope:
25+
case numbers.Number():
26+
membership = _lith_segmentation(Z, edges, ids, sigmoid_slope)
27+
case _ if isinstance(sigmoid_slope, (np.ndarray, torch.Tensor)):
28+
membership = _final_faults_segmentation(Z, edges, sigmoid_slope)
29+
case _:
30+
raise ValueError("sigmoid_slope must be a float or an array")
31+
32+
ids__sum = bt.t.sum(membership * ids, axis=-1)
33+
return ids__sum[None, :]
34+
35+
36+
def _final_faults_segmentation(Z, edges, sigmoid_slope):
37+
first = _sigmoid(
38+
scalar_field=Z,
39+
edges=edges[0],
40+
tau_k=1 / sigmoid_slope
41+
) # shape (...,)
42+
last = _sigmoid(
43+
scalar_field=Z,
44+
edges=edges[-1],
45+
tau_k=1 / sigmoid_slope
46+
)
47+
membership = bt.t.concatenate(
48+
[first[..., None], last[..., None]],
49+
axis=-1
50+
) # shape (...,K)
51+
return membership
52+
53+
54+
def _lith_segmentation(Z, edges, ids, sigmoid_slope):
55+
# 1) per-edge temperatures τ_k = |Δ_k|/(4·m)
56+
jumps = bt.t.abs(ids[1:] - ids[:-1]) # shape (K-1,)
57+
tau_k = jumps / float(sigmoid_slope) # shape (K-1,)
58+
# 2) first bin (-∞, e1) via σ((e1 - Z)/τ₁)
59+
first = _sigmoid(
60+
scalar_field=-Z,
61+
edges=-edges[0],
62+
tau_k=tau_k[0]
63+
) # shape (...,)
64+
# 3) last bin [e_{K-1}, ∞) via σ((Z - e_{K-1})/τ_{K-1})
65+
# last = 1.0 / (1.0 + np.exp(-(Z - edges[-1]) / tau_k[-1])) # shape (...,)
66+
last = _sigmoid(
67+
scalar_field=Z,
68+
edges=edges[-1],
69+
tau_k=tau_k[-1]
70+
)
71+
# 4) middle bins [e_i, e_{i+1}): σ((Z - e_i)/τ_i) - σ((Z - e_{i+1})/τ_{i+1})
72+
# shape (...,1)
73+
left = _sigmoid(
74+
scalar_field=(Z[..., None]),
75+
edges=edges[:-1],
76+
tau_k=tau_k[:-1]
77+
)
78+
right = _sigmoid(
79+
scalar_field=(Z[..., None]),
80+
edges=edges[1:],
81+
tau_k=tau_k[1:]
82+
)
83+
middle = left - right # (...,K-2)
84+
# 5) assemble memberships and weight by ids
85+
membership = bt.t.concatenate(
86+
[first[..., None], middle, last[..., None]],
87+
axis=-1
88+
) # shape (...,K)
89+
return membership
90+
91+
92+
def _sigmoid(scalar_field, edges, tau_k):
93+
return 1.0 / (1.0 + bt.t.exp(-(scalar_field - edges) / tau_k))

gempy_engine/modules/activator/activator_interface.py

Lines changed: 20 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,25 +1,34 @@
11
import warnings
22

3-
from gempy_engine.config import DEBUG_MODE, AvailableBackends
4-
from gempy_engine.core.backend_tensor import BackendTensor as bt, BackendTensor
5-
import numpy as np
3+
from ...config import DEBUG_MODE, AvailableBackends
4+
from ...core.backend_tensor import BackendTensor as bt, BackendTensor
5+
from ...core.data.exported_fields import ExportedFields
6+
from ._soft_segment import soft_segment_unbounded
67

7-
from gempy_engine.core.data.exported_fields import ExportedFields
8+
import numpy as np
89

910

1011
def activate_formation_block(exported_fields: ExportedFields, ids: np.ndarray,
1112
sigmoid_slope: float) -> np.ndarray:
1213
Z_x: np.ndarray = exported_fields.scalar_field_everywhere
1314
scalar_value_at_sp: np.ndarray = exported_fields.scalar_field_at_surface_points
1415

15-
sigmoid_slope_negative = isinstance(sigmoid_slope, float) and sigmoid_slope < 0 # * sigmoid_slope can be array for finite faultskA
16-
17-
if LEGACY := True and not sigmoid_slope_negative: # * Here we branch to the experimental activation function with hard sigmoid
18-
sigm = activate_formation_block_from_args(Z_x, ids, scalar_value_at_sp, sigmoid_slope)
19-
else:
20-
from .torch_activation import activate_formation_block_from_args_hard_sigmoid
21-
sigm = activate_formation_block_from_args_hard_sigmoid(Z_x, ids, scalar_value_at_sp)
16+
sigmoid_slope_negative = isinstance(sigmoid_slope, float) and sigmoid_slope < 0 # * sigmoid_slope can be array for finite faultskA
2217

18+
if LEGACY := False and not sigmoid_slope_negative: # * Here we branch to the experimental activation function with hard sigmoid
19+
sigm = activate_formation_block_from_args(
20+
Z_x=Z_x,
21+
ids=ids,
22+
scalar_value_at_sp=scalar_value_at_sp,
23+
sigmoid_slope=sigmoid_slope
24+
)
25+
else:
26+
sigm = soft_segment_unbounded(
27+
Z=Z_x,
28+
edges=scalar_value_at_sp,
29+
ids=ids,
30+
sigmoid_slope=sigmoid_slope
31+
)
2332
return sigm
2433

2534

gempy_engine/modules/activator/torch_activation.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
import warnings
2+
13
import torch
24
from ...core.backend_tensor import BackendTensor as bt, BackendTensor
35

@@ -14,6 +16,9 @@
1416

1517

1618
def activate_formation_block_from_args_hard_sigmoid(Z_x, ids, scalar_value_at_sp):
19+
20+
warnings.warn(DeprecationWarning("This function is deprecated. Use activate_formation_block instead."))
21+
1722
element_0 = bt.t.array([0], dtype=BackendTensor.dtype_obj)
1823

1924
min_Z_x = BackendTensor.t.min(Z_x, axis=0).reshape(-1) # ? Is this as good as it gets?

tests/fixtures/complex_geometries.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ def one_fault_model():
3535

3636
spi: SurfacePoints = SurfacePoints(sp_coords)
3737
ori: Orientations = Orientations(dip_postions, dip_gradients)
38-
ids = np.array([1, 2, 3, 4, 5, 6])
38+
ids = np.array([1, 2, 3, 4, 5, 6, 7])
3939

4040
resolution = [2, 2, 2]
4141
extent = np.array([-500, 500., -500, 500, -450, 550]) / rescaling_factor
@@ -172,7 +172,7 @@ def graben_fault_model():
172172

173173
spi = SurfacePoints(sp_coords)
174174
ori = Orientations(dip_postions, dip_gradients)
175-
ids = np.array([1, 2, 3, 4, 5, 6])
175+
ids = np.array([1, 2, 3, 4, 5, 6, 7])
176176

177177
resolution = [2, 2, 2]
178178
extent = np.array([-500, 500., -500, 500, -450, 550]) / rescaling_factor

tests/test_common/test_api/test_faults/test_faults_graben.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ def test_graben_fault_model(graben_fault_model):
2727
options.evaluation_options.dual_conturing_fancy = True
2828
options.debug=True
2929

30-
options.evaluation_options.number_octree_levels = 4
30+
options.evaluation_options.number_octree_levels = 5
3131
solutions: Solutions = compute_model(interpolation_input, options, structure)
3232

3333
outputs: list[OctreeLevel] = solutions.octrees_output

tests/test_common/test_api/test_faults/test_one_fault.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
from gempy_engine.plugins.plotting.helper_functions import plot_block_and_input_2d, plot_scalar_and_input_2d
2222

2323

24-
def test_one_fault_model(one_fault_model, n_oct_levels=3):
24+
def test_one_fault_model(one_fault_model, n_oct_levels=5):
2525
interpolation_input: InterpolationInput
2626
structure: InputDataDescriptor
2727
options: InterpolationOptions
@@ -44,7 +44,7 @@ def test_one_fault_model(one_fault_model, n_oct_levels=3):
4444
gempy_v2_cov = _covariance_for_one_fault_model_from_gempy_v2()
4545
diff = last_cov - gempy_v2_cov
4646

47-
if plot_2d := False:
47+
if plot_2d := True:
4848
_plot_stack_raw(interpolation_input, outputs, structure)
4949
_plot_stack_squeezed_mask(interpolation_input, outputs, structure)
5050
_plot_stack_mask_component(interpolation_input, outputs, structure)
Lines changed: 100 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,100 @@
1+
import dataclasses
2+
import os
3+
4+
import matplotlib.pyplot as plt
5+
import numpy as np
6+
7+
from gempy_engine.API.interp_single._interp_scalar_field import _solve_interpolation, _evaluate_sys_eq
8+
from gempy_engine.API.interp_single._interp_single_feature import input_preprocess
9+
from gempy_engine.config import AvailableBackends
10+
from gempy_engine.core.data.internal_structs import SolverInput
11+
from gempy_engine.modules.activator.activator_interface import activate_formation_block
12+
from gempy_engine.core.backend_tensor import BackendTensor
13+
14+
dir_name = os.path.dirname(__file__)
15+
16+
plot = True
17+
18+
19+
def test_activator_3_layers_segmentation_function(simple_model_3_layers, simple_grid_3d_more_points_grid):
20+
Z_x, grid, ids_block, interpolation_input = _run_test(
21+
backend=AvailableBackends.numpy,
22+
ids=np.array([1, 20, 3, 4]),
23+
simple_grid_3d_more_points_grid=simple_grid_3d_more_points_grid,
24+
simple_model_3_layers=simple_model_3_layers
25+
)
26+
27+
if plot:
28+
_plot_continious(grid, ids_block, interpolation_input)
29+
30+
31+
def test_activator_3_layers_segmentation_function_II(simple_model_3_layers, simple_grid_3d_more_points_grid):
32+
Z_x, grid, ids_block, interpolation_input = _run_test(
33+
backend=AvailableBackends.numpy,
34+
ids=np.array([1, 2, 3, 4]),
35+
simple_grid_3d_more_points_grid=simple_grid_3d_more_points_grid,
36+
simple_model_3_layers=simple_model_3_layers
37+
)
38+
39+
BackendTensor.change_backend_gempy(AvailableBackends.numpy)
40+
41+
if plot:
42+
_plot_continious(grid, ids_block, interpolation_input)
43+
44+
45+
def test_activator_3_layers_segmentation_function_torch(simple_model_3_layers, simple_grid_3d_more_points_grid):
46+
Z_x, grid, ids_block, interpolation_input = _run_test(
47+
backend=AvailableBackends.PYTORCH,
48+
ids=np.array([1, 2, 3, 4]),
49+
simple_grid_3d_more_points_grid=simple_grid_3d_more_points_grid,
50+
simple_model_3_layers=simple_model_3_layers
51+
)
52+
53+
BackendTensor.change_backend_gempy(AvailableBackends.numpy)
54+
if plot:
55+
_plot_continious(grid, ids_block, interpolation_input)
56+
57+
58+
def _run_test(backend, ids, simple_grid_3d_more_points_grid, simple_model_3_layers):
59+
interpolation_input = simple_model_3_layers[0]
60+
options = simple_model_3_layers[1]
61+
data_shape = simple_model_3_layers[2].tensors_structure
62+
grid = dataclasses.replace(simple_grid_3d_more_points_grid)
63+
interpolation_input.set_temp_grid(grid)
64+
interp_input: SolverInput = input_preprocess(data_shape, interpolation_input)
65+
weights = _solve_interpolation(interp_input, options.kernel_options)
66+
exported_fields = _evaluate_sys_eq(interp_input, weights, options)
67+
exported_fields.set_structure_values(
68+
reference_sp_position=data_shape.reference_sp_position,
69+
slice_feature=interpolation_input.slice_feature,
70+
grid_size=interpolation_input.grid.len_all_grids)
71+
Z_x: np.ndarray = exported_fields.scalar_field
72+
sasp = exported_fields.scalar_field_at_surface_points
73+
print(Z_x, Z_x.shape[0])
74+
print(sasp)
75+
BackendTensor.change_backend_gempy(backend)
76+
ids_block = activate_formation_block(
77+
exported_fields=exported_fields,
78+
ids=ids,
79+
sigmoid_slope=500 * 4
80+
)[0, :-7]
81+
return Z_x, grid, ids_block, interpolation_input
82+
83+
84+
def _plot_continious(grid, ids_block, interpolation_input):
85+
block__ = ids_block[grid.dense_grid_slice]
86+
unique = np.unique(block__)
87+
t = block__.reshape(50, 5, 50)[:, 2, :].T
88+
unique = np.unique(t)
89+
90+
levels = np.linspace(t.min(), t.max(), 40)
91+
plt.contourf(
92+
t,
93+
levels=levels,
94+
cmap="jet",
95+
extent=(.25, .75, .25, .75)
96+
)
97+
xyz = interpolation_input.surface_points.sp_coords
98+
plt.plot(xyz[:, 0], xyz[:, 2], "o")
99+
plt.colorbar()
100+
plt.show()

0 commit comments

Comments
 (0)