
Commit 88cf224

[CLN] Cleaning complex network
[WIP] Testing new activation functions
1 parent 9a867fa commit 88cf224

File tree

1 file changed: +31 -6 lines


gempy_engine/modules/activator/activator_interface.py

Lines changed: 31 additions & 6 deletions
@@ -33,10 +33,10 @@ def activate_formation_block_from_args(Z_x, ids, scalar_value_at_sp, sigmoid_slo
     sigm = bt.t.zeros((1, Z_x.shape[0]), dtype=BackendTensor.dtype_obj)

     for i in range(len(ids)):
-        sigm += _compute_sigmoid(Z_x, scalar_0_v[i], scalar_1_v[i], drift_0_v[i], drift_1_v[i], ids[i], sigmoid_slope)
-        # sigm += CustomSigmoidFunction.apply(Z_x, scalar_0_v[i], scalar_1_v[i], drift_0_v[i], drift_1_v[i], ids[i], sigmoid_slope)
-
-        if False: _add_relu() # TODO: Add this
+        if LEGACY:=False:
+            sigm += _compute_sigmoid(Z_x, scalar_0_v[i], scalar_1_v[i], drift_0_v[i], drift_1_v[i], ids[i], sigmoid_slope)
+        else:
+            sigm += HardSigmoid.apply(Z_x, scalar_0_v[i], scalar_1_v[i])
     return sigm

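Aside: `if LEGACY:=False:` uses the walrus operator, so `LEGACY` is bound to False and tested in one expression; the legacy `_compute_sigmoid` branch is therefore dead code for now and every iteration takes the `HardSigmoid` path. A minimal standalone illustration of that construct (plain Python, nothing gempy-specific assumed):

    # The walrus operator assigns the flag and immediately evaluates it.
    if LEGACY := False:
        print("legacy soft-sigmoid path")   # skipped while the flag is hard-coded to False
    else:
        print("hard-sigmoid path")          # always taken
    print(LEGACY)  # the name stays bound after the if: prints False
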
@@ -49,8 +49,11 @@ def _compute_sigmoid(Z_x, scale_0, scale_1, drift_0, drift_1, drift_id, sigmoid_

     sigmoid_slope_tensor = BackendTensor.t.array(sigmoid_slope, dtype=BackendTensor.dtype_obj)

-    active_sig = -scale_0.reshape((-1, 1)) / (1 + bt.tfnp.exp(-sigmoid_slope_tensor * (Z_x - drift_0)))
-    deactive_sig = -scale_1.reshape((-1, 1)) / (1 + bt.tfnp.exp(sigmoid_slope_tensor * (Z_x - drift_1)))
+    active_denominator = (1 + bt.tfnp.exp(-sigmoid_slope_tensor * (Z_x - drift_0)))
+    deactive_denominator = (1 + bt.tfnp.exp(sigmoid_slope_tensor * (Z_x - drift_1)))
+
+    active_sig = -scale_0.reshape((-1, 1)) / active_denominator
+    deactive_sig = -scale_1.reshape((-1, 1)) / deactive_denominator
     activation_sig = active_sig + deactive_sig

     sigm = activation_sig + drift_id.reshape((-1, 1))
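The `_compute_sigmoid` change above only pulls the two denominators out into named variables; the returned values are unchanged. A minimal NumPy sketch of that equivalence (illustrative inputs, plain NumPy standing in for the `bt.tfnp` backend wrapper):

    import numpy as np

    # Hypothetical stand-ins for the scalar field and per-formation parameters.
    Z_x = np.linspace(-1.0, 1.0, 5)
    scale_0 = np.array([2.0])
    scale_1 = np.array([2.0])
    drift_0, drift_1, sigmoid_slope = -0.5, 0.5, 50.0

    # Previous form: denominators inlined in the division.
    old = (-scale_0.reshape((-1, 1)) / (1 + np.exp(-sigmoid_slope * (Z_x - drift_0)))
           - scale_1.reshape((-1, 1)) / (1 + np.exp(sigmoid_slope * (Z_x - drift_1))))

    # New form: denominators named first, then divided.
    active_denominator = 1 + np.exp(-sigmoid_slope * (Z_x - drift_0))
    deactive_denominator = 1 + np.exp(sigmoid_slope * (Z_x - drift_1))
    new = (-scale_0.reshape((-1, 1)) / active_denominator
           - scale_1.reshape((-1, 1)) / deactive_denominator)

    assert np.allclose(old, new)
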
@@ -65,8 +68,30 @@ def _add_relu():
     # formations_block += ReLU_down + ReLU_up
     pass

+
 # * This gets the scalar gradient
 import torch
+class HardSigmoid(torch.autograd.Function):
+    @staticmethod
+    def forward(ctx, input, a, b):
+        ctx.save_for_backward(input)
+        ctx.bounds = (a, b)
+        slope = 1 / (b - a)
+        return torch.clamp(slope * (input - a) + 0.5, min=0, max=1)
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        input, = ctx.saved_tensors
+        a, b = ctx.bounds
+        grad_input = grad_output.clone()
+        grad_input[input < a] = 0
+        grad_input[input > b] = 0
+        grad_input[(input >= a) & (input <= b)] = 1 / (b - a)
+        return grad_input, None, None
+
+
+
+
 class CustomSigmoidFunction(torch.autograd.Function):
     @staticmethod
     def forward(ctx, Z_x, scale_0, scale_1, drift_0, drift_1, drift_id, sigmoid_slope, epsilon=1e-7):
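For reference, a minimal usage sketch of the `HardSigmoid` function added in this commit (standalone PyTorch, with the class definition above in scope; the bounds `a` and `b` play the role of `scalar_0_v[i]` / `scalar_1_v[i]`, and the values here are illustrative):

    import torch

    Z_x = torch.linspace(0.0, 2.0, steps=5, requires_grad=True)
    a, b = 0.5, 1.5

    out = HardSigmoid.apply(Z_x, a, b)  # linear ramp clamped to [0, 1]: 0 below a, 1 above b
    out.sum().backward()

    # As committed, backward zeroes the gradient outside [a, b] and overwrites the
    # in-range entries with the constant slope 1 / (b - a), rather than scaling the
    # upstream gradient by that slope.
    print(out)       # tensor([0.0, 0.5, 1.0, 1.0, 1.0], ...)
    print(Z_x.grad)  # tensor([0., 1., 1., 1., 0.])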
