@@ -33,10 +33,10 @@ def activate_formation_block_from_args(Z_x, ids, scalar_value_at_sp, sigmoid_slo
     sigm = bt.t.zeros((1, Z_x.shape[0]), dtype=BackendTensor.dtype_obj)

     for i in range(len(ids)):
-        sigm += _compute_sigmoid(Z_x, scalar_0_v[i], scalar_1_v[i], drift_0_v[i], drift_1_v[i], ids[i], sigmoid_slope)
-        # sigm += CustomSigmoidFunction.apply(Z_x, scalar_0_v[i], scalar_1_v[i], drift_0_v[i], drift_1_v[i], ids[i], sigmoid_slope)
-
-    if False: _add_relu()  # TODO: Add this
+        if LEGACY := False:
+            sigm += _compute_sigmoid(Z_x, scalar_0_v[i], scalar_1_v[i], drift_0_v[i], drift_1_v[i], ids[i], sigmoid_slope)
+        else:
+            sigm += HardSigmoid.apply(Z_x, scalar_0_v[i], scalar_1_v[i])
     return sigm


@@ -49,8 +49,11 @@ def _compute_sigmoid(Z_x, scale_0, scale_1, drift_0, drift_1, drift_id, sigmoid_

     sigmoid_slope_tensor = BackendTensor.t.array(sigmoid_slope, dtype=BackendTensor.dtype_obj)

-    active_sig = -scale_0.reshape((-1, 1)) / (1 + bt.tfnp.exp(-sigmoid_slope_tensor * (Z_x - drift_0)))
-    deactive_sig = -scale_1.reshape((-1, 1)) / (1 + bt.tfnp.exp(sigmoid_slope_tensor * (Z_x - drift_1)))
+    active_denominator = (1 + bt.tfnp.exp(-sigmoid_slope_tensor * (Z_x - drift_0)))
+    deactive_denominator = (1 + bt.tfnp.exp(sigmoid_slope_tensor * (Z_x - drift_1)))
+
+    active_sig = -scale_0.reshape((-1, 1)) / active_denominator
+    deactive_sig = -scale_1.reshape((-1, 1)) / deactive_denominator

     activation_sig = active_sig + deactive_sig

     sigm = activation_sig + drift_id.reshape((-1, 1))
@@ -65,8 +68,30 @@ def _add_relu():
     # formations_block += ReLU_down + ReLU_up
     pass

+
 # * This gets the scalar gradient
 import torch
+class HardSigmoid(torch.autograd.Function):
+    @staticmethod
+    def forward(ctx, input, a, b):
+        ctx.save_for_backward(input)
+        ctx.bounds = (a, b)
+        slope = 1 / (b - a)
+        return torch.clamp(slope * (input - a) + 0.5, min=0, max=1)
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        input, = ctx.saved_tensors
+        a, b = ctx.bounds
+        grad_input = grad_output.clone()
+        grad_input[input < a] = 0
+        grad_input[input > b] = 0
+        grad_input[(input >= a) & (input <= b)] = 1 / (b - a)
+        return grad_input, None, None
+
+
+
+
 class CustomSigmoidFunction(torch.autograd.Function):
     @staticmethod
     def forward(ctx, Z_x, scale_0, scale_1, drift_0, drift_1, drift_id, sigmoid_slope, epsilon=1e-7):
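A quick way to exercise the new HardSigmoid path outside the interpolation loop is a standalone smoke test like the sketch below. It is not part of the commit; the values for Z_x and the bounds a, b are made up, standing in for the scalar field and one pair of scalar_0_v[i] / scalar_1_v[i] values.

import torch

# Stand-in inputs: a small scalar-field sample and one pair of unit bounds.
Z_x = torch.linspace(-2.0, 2.0, steps=9, requires_grad=True)
a, b = -1.0, 1.0

y = HardSigmoid.apply(Z_x, a, b)
y.sum().backward()

print(y)        # forward: clamp((Z_x - a) / (b - a) + 0.5, 0, 1), a linear ramp clipped to [0, 1]
print(Z_x.grad) # custom backward with unit upstream gradients: 1 / (b - a) wherever a <= Z_x <= b, 0 elsewhere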