
Commit 9990029

[WIP] Towards a more stable activation function... hopefully
1 parent 51c5fb5 commit 9990029

File tree

3 files changed: +96 -27 lines

gempy_engine/core/backend_tensor.py

Lines changed: 1 addition & 0 deletions

@@ -178,6 +178,7 @@ def _array(array_like, dtype=None):
        cls.tfnp.array = _array
        cls.tfnp.to_numpy = lambda tensor: tensor.detach().numpy()
        cls.tfnp.min = lambda tensor, axis: tensor.min(axis=axis)[0]
+        cls.tfnp.max = lambda tensor, axis: tensor.max(axis=axis)[0]
        cls.tfnp.rint = lambda tensor: tensor.round().type(torch.int32)
        cls.tfnp.vstack = lambda tensors: torch.cat(tensors, dim=0)
        cls.tfnp.copy = lambda tensor: tensor.clone()
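
A note on why the new tfnp.max wrapper indexes into the result: unlike np.max, PyTorch's Tensor.max(axis=...) returns a (values, indices) pair, so the wrapper takes element [0] to behave like NumPy. A minimal standalone sketch (assuming a recent PyTorch; not part of the diff):

import numpy as np
import torch

t = torch.tensor([[1.0, 5.0], [3.0, 2.0]])

# torch returns a (values, indices) named tuple for a reduction along an axis...
values, indices = t.max(axis=0)

# ...so the wrapper's [0] keeps only the values, matching np.max's behaviour.
assert np.allclose(values.numpy(), np.max(t.numpy(), axis=0))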

gempy_engine/modules/activator/activator_interface.py

Lines changed: 87 additions & 19 deletions

@@ -11,7 +11,7 @@ def activate_formation_block(exported_fields: ExportedFields, ids: np.ndarray, s
    Z_x: np.ndarray = exported_fields.scalar_field_everywhere
    scalar_value_at_sp: np.ndarray = exported_fields.scalar_field_at_surface_points

-    if LEGACY :=False:
+    if LEGACY := False:
        sigm = activate_formation_block_from_args(Z_x, ids, scalar_value_at_sp, sigmoid_slope)
    else:
        sigm = activate_formation_block_from_args_hard_sigmoid(Z_x, ids, scalar_value_at_sp, sigmoid_slope)
@@ -44,13 +44,12 @@ def activate_formation_block_from_args_hard_sigmoid(Z_x, ids, scalar_value_at_sp
    element_0 = bt.t.array([0], dtype=BackendTensor.dtype_obj)

    min_Z_x = BackendTensor.t.min(Z_x, axis=0).reshape(-1)  # ? Is this as good as it gets?
-    max_Z_x = BackendTensor.t.max(Z_x, axis=0)[0].reshape(-1)  # ? Is this as good as it gets?
-
+    max_Z_x = BackendTensor.t.max(Z_x, axis=0).reshape(-1)  # ? Is this as good as it gets?
+
    # Add 5%
    min_Z_x = min_Z_x - 0.05 * (max_Z_x - min_Z_x)
    max_Z_x = max_Z_x + 0.05 * (max_Z_x - min_Z_x)
-
-
+
    drift_0_v = bt.tfnp.concatenate([min_Z_x, scalar_value_at_sp])
    drift_1_v = bt.tfnp.concatenate([scalar_value_at_sp, max_Z_x])
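
To make the hunk above easier to follow: the padded field minimum, the scalar-field values at the surface points, and the padded maximum are concatenated into per-unit lower and upper bound vectors. A small numeric sketch with made-up values (not from the commit):

import numpy as np

min_Z_x = np.array([0.0])                   # padded field minimum
max_Z_x = np.array([10.0])                  # padded field maximum
scalar_value_at_sp = np.array([3.0, 7.0])   # scalar field at the two surfaces

drift_0_v = np.concatenate([min_Z_x, scalar_value_at_sp])  # [0., 3., 7.]   lower bound per unit
drift_1_v = np.concatenate([scalar_value_at_sp, max_Z_x])  # [3., 7., 10.]  upper bound per unit

# Unit i is active where drift_0_v[i] < Z_x <= drift_1_v[i].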

@@ -64,10 +63,66 @@ def activate_formation_block_from_args_hard_sigmoid(Z_x, ids, scalar_value_at_sp
    # * Iterate over surface
    sigm = bt.t.zeros((1, Z_x.shape[0]), dtype=BackendTensor.dtype_obj)

-    for i in range(len(ids)):
-        # if (i == 3):
-        sigm += ids[i] * HardSigmoidModified.apply(Z_x, drift_0_v[i], drift_1_v[i])
-    return sigm.view(1, -1)
+    for i in range(len(ids) - 1):
+        if False:
+            sigm += HardSigmoidModified2.apply(
+                Z_x,
+                (drift_0_v[i] + drift_1_v[i]) / 2,
+                (drift_0_v[i + 1] + drift_1_v[i + 1]) / 2,
+                ids[i]
+            )
+            return sigm.reshape(1, -1)
+
+        else:
+            output = bt.t.zeros_like(Z_x)
+            a = (drift_0_v[i] + drift_1_v[i]) / 2
+            b = (drift_0_v[i + 1] + drift_1_v[i + 1]) / 2
+
+            slope_up = 1 / (b - a)
+
+            # For x in the range [a, b]
+            b_ = (Z_x > a) & (Z_x <= b)
+            pos = slope_up * (Z_x[b_] - a)
+
+            output[b_] = ids[i] + pos
+            sigm += output
+    return sigm.reshape(1, -1)
+
+
+import torch
+
+
+class HardSigmoidModified2(torch.autograd.Function):
+    @staticmethod
+    def forward(ctx, input, a, b, id):
+        ctx.save_for_backward(input)
+        ctx.bounds = (a, b)
+        ctx.id = id
+        output = bt.t.zeros_like(input)
+        slope_up = 1 / (b - a)
+
+        # For x in the range [a, b]
+        b_ = (input > a) & (input <= b)
+        pos = slope_up * (input[b_] - a)
+
+        output[b_] = id + pos
+
+        return output
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        input = ctx.saved_tensors[0]
+        a, b = ctx.bounds
+        slope_up = 1 / (b - a)
+
+        b_ = (input > a) & (input <= b)
+
+        grad_input = grad_output.clone()
+        # Apply gradient only within the range [a, b]
+        grad_input[b_] = grad_input[b_] * slope_up
+
+        return grad_input, None, None, None
+


def _compute_sigmoid(Z_x, scale_0, scale_1, drift_0, drift_1, drift_id, sigmoid_slope):
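
For readers less familiar with custom autograd ops: HardSigmoidModified2 above follows the standard torch.autograd.Function pattern, building a linear ramp from id to id + 1 between two scalar-field bounds in forward and scaling the incoming gradient inside that range in backward. Below is a self-contained, simplified re-implementation (illustrative names and values, not the repo's code) that can be run on its own:

import torch


class RampUnit(torch.autograd.Function):
    """Illustrative stand-in for HardSigmoidModified2: ramp from unit_id to unit_id + 1 on (a, b]."""

    @staticmethod
    def forward(ctx, x, a, b, unit_id):
        ctx.save_for_backward(x)
        ctx.bounds = (a, b)
        out = torch.zeros_like(x)
        inside = (x > a) & (x <= b)
        out[inside] = unit_id + (x[inside] - a) / (b - a)
        return out

    @staticmethod
    def backward(ctx, grad_output):
        (x,) = ctx.saved_tensors
        a, b = ctx.bounds
        inside = (x > a) & (x <= b)
        grad = grad_output.clone()
        grad[inside] = grad[inside] / (b - a)  # constant ramp slope inside (a, b]
        return grad, None, None, None          # one gradient per forward argument


x = torch.linspace(0.0, 10.0, steps=11, requires_grad=True)
y = RampUnit.apply(x, 3.0, 7.0, 1.0)
y.sum().backward()
print(y.detach())  # 0 outside (3, 7]; ramps from 1.25 at x=4 up to 2.0 at x=7
print(x.grad)      # 0.25 inside the ramp; untouched grad_output (1.0) elsewhere, as in the WIP backward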
@@ -100,28 +155,41 @@ def _add_relu():


# * This gets the scalar gradient
-import torch


class HardSigmoidModified(torch.autograd.Function):
    @staticmethod
-    def forward(ctx, input, a, b):
+    def forward(ctx, input, a, b, id):
        ctx.save_for_backward(input)
        ctx.bounds = (a, b)
        output = torch.zeros_like(input)
        slope_up = 100 / (b - a)
+        midpoint = (a + b) / 2

        # For x in the range [a, b]
-        output[(input >= a) & (input <= b)] += torch.clamp(slope_up * (input[(input >= a) & (input <= b)] - a), min=0, max=1)
+        b_ = (input > a) & (input <= b)
+
+        pos = slope_up * (input[b_] - a)
+
+        neg = -slope_up * (input[b_] - b)
+
+        print("Max min:", pos.max(), pos.min())
+        foo = id * pos - (id - 1) * neg
+
+        # output[b_] = id * pos
+        output[b_] = id + pos
+
+        # output[(input >= a) & (input <= b)] = torch.clamp(neg, min=0, max=1)
+        # output[(input >= a) & (input <= b)] = torch.clamp(pos + neg, min=0, max=1)
+        # output[(input >= a) & (input <= b)] = torch.clamp(pos + neg, min=0, max=1)

-        output[(input >= a) & (input <= b)] += torch.clamp(-slope_up * (input[(input >= a) & (input <= b)] - b), min=0, max=1)
-
        # Clamping the values outside the range [a, c] to zero
-        output[input < a] = 0
-        output[input >= b] = 0
+        # output[input < a] = 0
+        # output[input >= b] = 0

-        return output
+        # output[b_] *= id

+        return output

    @staticmethod
    def backward(ctx, grad_output):
@@ -136,15 +204,15 @@ def backward(ctx, grad_output):
        grad_input[(input >= a) & (input < midpoint)] = 1 / (b - a)
        grad_input[(input > midpoint) & (input <= b)] = -1 / (b - a)

-        return grad_input, None, None
+        return grad_input, None, None, None


class HardSigmoid(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input, a, b, c):
        ctx.save_for_backward(input)
        ctx.bounds = (a, b)
-        slope = 1000 / (b - a)
+        slope = 1 / (b - a)
        return torch.clamp(slope * (input - a) + 0.5, min=0, max=1)

    @staticmethod
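
The HardSigmoid.forward change above only lowers the slope of the clamped linear unit. A quick way to see what torch.clamp(slope * (input - a) + 0.5, min=0, max=1) does for the two slope values (illustrative numbers, not taken from the tests):

import torch

a, b = 3.0, 7.0
x = torch.linspace(0.0, 10.0, steps=11)

steep = torch.clamp(1000 / (b - a) * (x - a) + 0.5, min=0, max=1)  # old slope: almost a step at x = a
soft = torch.clamp(1 / (b - a) * (x - a) + 0.5, min=0, max=1)      # new slope: gentle ramp around x = a

print(steep)  # jumps from 0 to 1 within a tiny window around x = 3
print(soft)   # 0.5 at x = 3, reaching 1 at x = 5, with a smoother transition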

tests/test_common/test_modules/test_activator.py

Lines changed: 8 additions & 8 deletions

@@ -88,14 +88,14 @@ def test_activator_3_layers(simple_model_3_layers, simple_grid_3d_more_points_gr
    interpolation_input.surface_points.sp_coords = interpolation_input.surface_points.sp_coords.detach().numpy()

    if plot:
-        plt.contourf(Z_x.reshape(50, 5, 50)[:, 0, :].T, N=40, cmap="autumn",
-                     extent=(.25, .75, .25, .75))
-
-        xyz = interpolation_input.surface_points.sp_coords
-        plt.plot(xyz[:, 0], xyz[:, 2], "o")
-        plt.colorbar()
-
-        plt.show()
+        # plt.contourf(Z_x.reshape(50, 5, 50)[:, 0, :].T, N=40, cmap="autumn",
+        #              extent=(.25, .75, .25, .75))
+        #
+        # xyz = interpolation_input.surface_points.sp_coords
+        # plt.plot(xyz[:, 0], xyz[:, 2], "o")
+        # plt.colorbar()
+        #
+        # plt.show()

        plt.contourf(
            ids_block[:-4].reshape(50, 5, 50)[:, 2, :].T,
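
As a side note on the reshape used in the remaining plot call: the flattened block is reshaped back onto the regular (50, 5, 50) grid and a single y-section is taken before contouring. A tiny illustrative sketch with a random array standing in for ids_block (assumed shapes only):

import numpy as np
import matplotlib.pyplot as plt

block = np.random.rand(50 * 5 * 50)              # stand-in for the flattened ids_block
section = block.reshape(50, 5, 50)[:, 2, :].T    # x-z section at y-index 2, transposed for plotting

plt.contourf(section, levels=40, cmap="autumn")
plt.colorbar()
plt.show()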
