Skip to content

Commit 895d762

Browse files
committed
[ENH/WIP] New activation function
1 parent 9990029 commit 895d762

File tree

2 files changed

+9
-13
lines changed

2 files changed

+9
-13
lines changed

gempy_engine/modules/activator/activator_interface.py

Lines changed: 9 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -47,8 +47,8 @@ def activate_formation_block_from_args_hard_sigmoid(Z_x, ids, scalar_value_at_sp
4747
max_Z_x = BackendTensor.t.max(Z_x, axis=0).reshape(-1) # ? Is this as good as it gets?
4848

4949
# Add 5%
50-
min_Z_x = min_Z_x - 0.05 * (max_Z_x - min_Z_x)
51-
max_Z_x = max_Z_x + 0.05 * (max_Z_x - min_Z_x)
50+
min_Z_x = min_Z_x - 0.5 * (max_Z_x - min_Z_x)
51+
max_Z_x = max_Z_x + 0.5 * (max_Z_x - min_Z_x)
5252

5353
drift_0_v = bt.tfnp.concatenate([min_Z_x, scalar_value_at_sp])
5454
drift_1_v = bt.tfnp.concatenate([scalar_value_at_sp, max_Z_x])
@@ -59,7 +59,7 @@ def activate_formation_block_from_args_hard_sigmoid(Z_x, ids, scalar_value_at_sp
5959
#
6060
# scalar_1_v = bt.t.copy(ids)
6161
# scalar_1_v[-1] = 0
62-
62+
ids = ids.flip(0)
6363
# * Iterate over surface
6464
sigm = bt.t.zeros((1, Z_x.shape[0]), dtype=BackendTensor.dtype_obj)
6565

@@ -71,22 +71,20 @@ def activate_formation_block_from_args_hard_sigmoid(Z_x, ids, scalar_value_at_sp
7171
(drift_0_v[i + 1] + drift_1_v[i + 1]) / 2,
7272
ids[i]
7373
)
74-
return sigm.reshape(1, -1)
75-
7674
else:
7775
output = bt.t.zeros_like(Z_x)
7876
a = (drift_0_v[i] + drift_1_v[i]) / 2
79-
b = (drift_0_v[i + 1] + drift_1_v[i + 1]) / 2
80-
81-
slope_up = 1 / (b - a)
77+
b = (drift_0_v[i + 1] + drift_1_v[i + 1]) / 2
78+
79+
slope_up = -1 / (b - a)
8280

8381
# For x in the range [a, b]
8482
b_ = (Z_x > a) & (Z_x <= b)
8583
pos = slope_up * (Z_x[b_] - a)
8684

87-
output[b_] = ids[i] + pos
85+
output[b_] = ids[i] + 0.5 + pos
8886
sigm += output
89-
return sigm.reshape(1, -1)
87+
return sigm.reshape(1, -1)
9088

9189

9290
import torch
@@ -119,12 +117,11 @@ def backward(ctx, grad_output):
119117

120118
grad_input = grad_output.clone()
121119
# Apply gradient only within the range [a, b]
122-
grad_input[b_] = grad_input[b_] * slope_up
120+
grad_input[b_] = grad_input[b_] * slope_up
123121

124122
return grad_input, None, None, None
125123

126124

127-
128125
def _compute_sigmoid(Z_x, scale_0, scale_1, drift_0, drift_1, drift_id, sigmoid_slope):
129126
# TODO: Test to remove reshape once multiple values are implemented
130127

tests/test_common/test_modules/test_activator.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,6 @@ def test_activator_3_layers(simple_model_3_layers, simple_grid_3d_more_points_gr
7575
print(Z_x, Z_x.shape[0])
7676
print(sasp)
7777

78-
7978
ids_block = activate_formation_block(
8079
exported_fields=exported_fields,
8180
ids= ids,

0 commit comments

Comments
 (0)