3
3
import numpy as np
4
4
5
5
from ...core .backend_tensor import BackendTensor as bt , BackendTensor
6
+ from ...core .data .kernel_classes .kernel_functions import dtype
6
7
7
8
try :
8
9
import torch
@@ -53,7 +54,7 @@ def _final_faults_segmentation(Z, edges, sigmoid_slope):
53
54
54
55
def _lith_segmentation (Z , edges , ids , sigmoid_slope ):
55
56
# 1) per-edge temperatures τ_k = |Δ_k|/(4·m)
56
- jumps = bt .t .abs (ids [1 :] - ids [:- 1 ]) # shape (K-1,)
57
+ jumps = bt .t .abs (ids [1 :] - ids [:- 1 ], dtype = bt . dtype_obj ) # shape (K-1,)
57
58
tau_k = jumps / float (sigmoid_slope ) # shape (K-1,)
58
59
# 2) first bin (-∞, e1) via σ((e1 - Z)/τ₁)
59
60
first = _sigmoid (
@@ -90,4 +91,31 @@ def _lith_segmentation(Z, edges, ids, sigmoid_slope):
90
91
91
92
92
93
def _sigmoid (scalar_field , edges , tau_k ):
93
- return 1.0 / (1.0 + bt .t .exp (- (scalar_field - edges ) / tau_k ))
94
+ x = - (scalar_field - edges ) / tau_k
95
+ return 1.0 / (1.0 + bt .t .exp (x ))
96
+
97
+
98
+ def _sigmoid_stable (scalar_field , edges , tau_k ):
99
+ """
100
+ Numerically‐stable sigmoid of (scalar_field - edges)/tau_k,
101
+ only exponentiates on the needed slice.
102
+ """
103
+ x = (scalar_field - edges ) / tau_k
104
+ # allocate output
105
+ out = bt .t .empty_like (x )
106
+
107
+ # mask which positions are >=0 or <0
108
+ pos = x >= 0
109
+ neg = ~ pos
110
+
111
+ # for x>=0: safe to compute exp(-x)
112
+ x_pos = x [pos ]
113
+ exp_neg = bt .t .exp (- x_pos ) # no overflow since -x_pos <= 0
114
+ out [pos ] = 1.0 / (1.0 + exp_neg )
115
+
116
+ # for x<0: safe to compute exp(x)
117
+ x_neg = x [neg ]
118
+ exp_pos = bt .t .exp (x_neg ) # no overflow since x_neg < 0
119
+ out [neg ] = exp_pos / (1.0 + exp_pos )
120
+
121
+ return out
0 commit comments