Skip to content

Commit 51c5fb5

Browse files
committed
[WIP] Improving activation function
1 parent 84c82ae commit 51c5fb5

File tree

3 files changed

+102
-18
lines changed

3 files changed

+102
-18
lines changed

gempy_engine/modules/activator/activator_interface.py

Lines changed: 77 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,10 @@ def activate_formation_block(exported_fields: ExportedFields, ids: np.ndarray, s
1111
Z_x: np.ndarray = exported_fields.scalar_field_everywhere
1212
scalar_value_at_sp: np.ndarray = exported_fields.scalar_field_at_surface_points
1313

14-
sigm = activate_formation_block_from_args(Z_x, ids, scalar_value_at_sp, sigmoid_slope)
14+
if LEGACY :=False:
15+
sigm = activate_formation_block_from_args(Z_x, ids, scalar_value_at_sp, sigmoid_slope)
16+
else:
17+
sigm = activate_formation_block_from_args_hard_sigmoid(Z_x, ids, scalar_value_at_sp, sigmoid_slope)
1518

1619
return sigm
1720

@@ -33,13 +36,40 @@ def activate_formation_block_from_args(Z_x, ids, scalar_value_at_sp, sigmoid_slo
3336
sigm = bt.t.zeros((1, Z_x.shape[0]), dtype=BackendTensor.dtype_obj)
3437

3538
for i in range(len(ids)):
36-
if LEGACY:=True:
37-
sigm += _compute_sigmoid(Z_x, scalar_0_v[i], scalar_1_v[i], drift_0_v[i], drift_1_v[i], ids[i], sigmoid_slope)
38-
else:
39-
sigm += HardSigmoid.apply(Z_x, scalar_0_v[i], scalar_1_v[i])
39+
sigm += _compute_sigmoid(Z_x, scalar_0_v[i], scalar_1_v[i], drift_0_v[i], drift_1_v[i], ids[i], sigmoid_slope)
4040
return sigm
4141

4242

43+
def activate_formation_block_from_args_hard_sigmoid(Z_x, ids, scalar_value_at_sp, sigmoid_slope):
44+
element_0 = bt.t.array([0], dtype=BackendTensor.dtype_obj)
45+
46+
min_Z_x = BackendTensor.t.min(Z_x, axis=0).reshape(-1) # ? Is this as good as it gets?
47+
max_Z_x = BackendTensor.t.max(Z_x, axis=0)[0].reshape(-1) # ? Is this as good as it gets?
48+
49+
# Add 5%
50+
min_Z_x = min_Z_x - 0.05 * (max_Z_x - min_Z_x)
51+
max_Z_x = max_Z_x + 0.05 * (max_Z_x - min_Z_x)
52+
53+
54+
drift_0_v = bt.tfnp.concatenate([min_Z_x, scalar_value_at_sp])
55+
drift_1_v = bt.tfnp.concatenate([scalar_value_at_sp, max_Z_x])
56+
57+
ids = bt.t.array(ids, dtype="int32")
58+
scalar_0_v = bt.t.copy(ids)
59+
scalar_0_v[0] = 0
60+
#
61+
# scalar_1_v = bt.t.copy(ids)
62+
# scalar_1_v[-1] = 0
63+
64+
# * Iterate over surface
65+
sigm = bt.t.zeros((1, Z_x.shape[0]), dtype=BackendTensor.dtype_obj)
66+
67+
for i in range(len(ids)):
68+
# if (i == 3):
69+
sigm += ids[i] * HardSigmoidModified.apply(Z_x, drift_0_v[i], drift_1_v[i])
70+
return sigm.view(1, -1)
71+
72+
4373
def _compute_sigmoid(Z_x, scale_0, scale_1, drift_0, drift_1, drift_id, sigmoid_slope):
4474
# TODO: Test to remove reshape once multiple values are implemented
4575

@@ -49,9 +79,9 @@ def _compute_sigmoid(Z_x, scale_0, scale_1, drift_0, drift_1, drift_id, sigmoid_
4979

5080
sigmoid_slope_tensor = BackendTensor.t.array(sigmoid_slope, dtype=BackendTensor.dtype_obj)
5181

52-
active_denominator = (1 + bt.tfnp.exp(-sigmoid_slope_tensor * (Z_x - drift_0)))
82+
active_denominator = (1 + bt.tfnp.exp(-sigmoid_slope_tensor * (Z_x - drift_0)))
5383
deactive_denominator = (1 + bt.tfnp.exp(sigmoid_slope_tensor * (Z_x - drift_1)))
54-
84+
5585
active_sig = -scale_0.reshape((-1, 1)) / active_denominator
5686
deactive_sig = -scale_1.reshape((-1, 1)) / deactive_denominator
5787
activation_sig = active_sig + deactive_sig
@@ -71,25 +101,61 @@ def _add_relu():
71101

72102
# * This gets the scalar gradient
73103
import torch
74-
class HardSigmoid(torch.autograd.Function):
104+
105+
106+
class HardSigmoidModified(torch.autograd.Function):
75107
@staticmethod
76108
def forward(ctx, input, a, b):
77109
ctx.save_for_backward(input)
78110
ctx.bounds = (a, b)
79-
slope = 1 / (b - a)
80-
return torch.clamp(slope * (input - a) + 0.5, min=0, max=1)
111+
output = torch.zeros_like(input)
112+
slope_up = 100 / (b - a)
113+
114+
# For x in the range [a, b]
115+
output[(input >= a) & (input <= b)] += torch.clamp(slope_up * (input[(input >= a) & (input <= b)] - a), min=0, max=1)
116+
117+
output[(input >= a) & (input <= b)] += torch.clamp(-slope_up * (input[(input >= a) & (input <= b)] - b), min=0, max=1)
118+
119+
# Clamping the values outside the range [a, c] to zero
120+
output[input < a] = 0
121+
output[input >= b] = 0
122+
123+
return output
124+
81125

82126
@staticmethod
83127
def backward(ctx, grad_output):
84128
input, = ctx.saved_tensors
85129
a, b = ctx.bounds
130+
midpoint = (a + b) / 2
86131
grad_input = grad_output.clone()
132+
133+
# Gradient is 1/(b-a) for x in [a, midpoint), -1/(b-a) for x in (midpoint, b], and 0 elsewhere
87134
grad_input[input < a] = 0
88135
grad_input[input > b] = 0
89-
grad_input[(input >= a) & (input <= b)] = 1 / (b - a)
136+
grad_input[(input >= a) & (input < midpoint)] = 1 / (b - a)
137+
grad_input[(input > midpoint) & (input <= b)] = -1 / (b - a)
138+
90139
return grad_input, None, None
91140

92141

142+
class HardSigmoid(torch.autograd.Function):
143+
@staticmethod
144+
def forward(ctx, input, a, b, c):
145+
ctx.save_for_backward(input)
146+
ctx.bounds = (a, b)
147+
slope = 1000 / (b - a)
148+
return torch.clamp(slope * (input - a) + 0.5, min=0, max=1)
149+
150+
@staticmethod
151+
def backward(ctx, grad_output):
152+
input, = ctx.saved_tensors
153+
a, b = ctx.bounds
154+
grad_input = grad_output.clone()
155+
grad_input[input < a] = 0
156+
grad_input[input > b] = 0
157+
grad_input[(input >= a) & (input <= b)] = 1 / (b - a)
158+
return grad_input, None, None
93159

94160

95161
class CustomSigmoidFunction(torch.autograd.Function):

tests/conftest.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
from tests.fixtures.heavy_models import *
1717

1818
pykeops_enabled = False
19-
backend = AvailableBackends.numpy
19+
backend = AvailableBackends.PYTORCH
2020
use_gpu = False
2121
plot_pyvista = False # ! Set here if you want to plot the results
2222

tests/test_common/test_modules/test_activator.py

Lines changed: 24 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,15 @@
11
import dataclasses
22
import os
3-
import pytest
3+
44
import matplotlib.pyplot as plt
55
import numpy as np
66

77
from gempy_engine.API.interp_single._interp_scalar_field import _solve_interpolation, _evaluate_sys_eq
88
from gempy_engine.API.interp_single._interp_single_feature import input_preprocess
9+
from gempy_engine.config import AvailableBackends
910
from gempy_engine.core.data.internal_structs import SolverInput
10-
from gempy_engine.core.data.interp_output import InterpOutput
1111
from gempy_engine.modules.activator.activator_interface import activate_formation_block
12-
from gempy_engine.API.interp_single.interp_features import interpolate_single_field
12+
from gempy_engine.core.backend_tensor import BackendTensor
1313

1414
dir_name = os.path.dirname(__file__)
1515

@@ -27,13 +27,22 @@ def test_activator(simple_model_values_block_output):
2727
ids_block = activate_formation_block(simple_model_values_block_output.exported_fields, ids, 50000)[:, :-7]
2828
print(ids_block)
2929

30+
if BackendTensor.engine_backend == AvailableBackends.PYTORCH:
31+
ids_block = ids_block.detach().numpy()
32+
Z_x = Z_x.detach().numpy()
33+
3034
if plot:
3135
plt.contourf(Z_x.reshape(50, 5, 50)[:, 2, :].T, N=40, cmap="autumn")
3236
plt.colorbar()
3337

3438
plt.show()
3539

36-
plt.contourf(ids_block[0].reshape(50, 5, 50)[:, 2, :].T, N=40, cmap="viridis")
40+
plt.contourf(
41+
ids_block[0].reshape(50, 5, 50)[:, 2, :].T,
42+
N=40,
43+
cmap="viridis",
44+
# levels=[-1, 0.5, 1, 1.5, 2.5]
45+
)
3746
plt.colorbar()
3847

3948
plt.show()
@@ -71,8 +80,13 @@ def test_activator_3_layers(simple_model_3_layers, simple_grid_3d_more_points_gr
7180
exported_fields=exported_fields,
7281
ids= ids,
7382
sigmoid_slope=50000
74-
)[:, :-7]
83+
)[0, :-7]
7584

85+
if BackendTensor.engine_backend == AvailableBackends.PYTORCH:
86+
ids_block = ids_block.detach().numpy()
87+
Z_x = Z_x.detach().numpy()
88+
interpolation_input.surface_points.sp_coords = interpolation_input.surface_points.sp_coords.detach().numpy()
89+
7690
if plot:
7791
plt.contourf(Z_x.reshape(50, 5, 50)[:, 0, :].T, N=40, cmap="autumn",
7892
extent=(.25, .75, .25, .75))
@@ -83,7 +97,11 @@ def test_activator_3_layers(simple_model_3_layers, simple_grid_3d_more_points_gr
8397

8498
plt.show()
8599

86-
plt.contourf(ids_block[0, :-4].reshape(50, 5, 50)[:, 2, :].T, N=40, cmap="viridis")
100+
plt.contourf(
101+
ids_block[:-4].reshape(50, 5, 50)[:, 2, :].T,
102+
N=250,
103+
cmap="viridis"
104+
)
87105
plt.colorbar()
88106

89107
plt.show()

0 commit comments

Comments
 (0)