@@ -47,8 +47,8 @@ def activate_formation_block_from_args_hard_sigmoid(Z_x, ids, scalar_value_at_sp
47
47
max_Z_x = BackendTensor .t .max (Z_x , axis = 0 ).reshape (- 1 ) # ? Is this as good as it gets?
48
48
49
49
# Add 5%
50
- min_Z_x = min_Z_x - 0.05 * (max_Z_x - min_Z_x )
51
- max_Z_x = max_Z_x + 0.05 * (max_Z_x - min_Z_x )
50
+ min_Z_x = min_Z_x - 0.5 * (max_Z_x - min_Z_x )
51
+ max_Z_x = max_Z_x + 0.5 * (max_Z_x - min_Z_x )
52
52
53
53
drift_0_v = bt .tfnp .concatenate ([min_Z_x , scalar_value_at_sp ])
54
54
drift_1_v = bt .tfnp .concatenate ([scalar_value_at_sp , max_Z_x ])
@@ -59,7 +59,7 @@ def activate_formation_block_from_args_hard_sigmoid(Z_x, ids, scalar_value_at_sp
59
59
#
60
60
# scalar_1_v = bt.t.copy(ids)
61
61
# scalar_1_v[-1] = 0
62
-
62
+ ids = ids . flip ( 0 )
63
63
# * Iterate over surface
64
64
sigm = bt .t .zeros ((1 , Z_x .shape [0 ]), dtype = BackendTensor .dtype_obj )
65
65
@@ -71,22 +71,20 @@ def activate_formation_block_from_args_hard_sigmoid(Z_x, ids, scalar_value_at_sp
71
71
(drift_0_v [i + 1 ] + drift_1_v [i + 1 ]) / 2 ,
72
72
ids [i ]
73
73
)
74
- return sigm .reshape (1 , - 1 )
75
-
76
74
else :
77
75
output = bt .t .zeros_like (Z_x )
78
76
a = (drift_0_v [i ] + drift_1_v [i ]) / 2
79
- b = (drift_0_v [i + 1 ] + drift_1_v [i + 1 ]) / 2
80
-
81
- slope_up = 1 / (b - a )
77
+ b = (drift_0_v [i + 1 ] + drift_1_v [i + 1 ]) / 2
78
+
79
+ slope_up = - 1 / (b - a )
82
80
83
81
# For x in the range [a, b]
84
82
b_ = (Z_x > a ) & (Z_x <= b )
85
83
pos = slope_up * (Z_x [b_ ] - a )
86
84
87
- output [b_ ] = ids [i ] + pos
85
+ output [b_ ] = ids [i ] + 0.5 + pos
88
86
sigm += output
89
- return sigm .reshape (1 , - 1 )
87
+ return sigm .reshape (1 , - 1 )
90
88
91
89
92
90
import torch
@@ -119,12 +117,11 @@ def backward(ctx, grad_output):
119
117
120
118
grad_input = grad_output .clone ()
121
119
# Apply gradient only within the range [a, b]
122
- grad_input [b_ ] = grad_input [b_ ] * slope_up
120
+ grad_input [b_ ] = grad_input [b_ ] * slope_up
123
121
124
122
return grad_input , None , None , None
125
123
126
124
127
-
128
125
def _compute_sigmoid (Z_x , scale_0 , scale_1 , drift_0 , drift_1 , drift_id , sigmoid_slope ):
129
126
# TODO: Test to remove reshape once multiple values are implemented
130
127
0 commit comments