@@ -5,7 +5,7 @@ module test_specialfunctions_activation
5
5
use testdrive, only : new_unittest, unittest_type, error_type, check
6
6
use stdlib_kinds
7
7
use stdlib_specialfunctions
8
-
8
+ use stdlib_math, only: linspace
9
9
implicit none
10
10
private
11
11
@@ -21,10 +21,46 @@ contains
21
21
type(unittest_type), allocatable, intent(out) :: testsuite(:)
22
22
23
23
testsuite = [ &
24
+ new_unittest("sigmoid", test_sigmoid), &
25
+ new_unittest("gelu" , test_gelu ), &
24
26
new_unittest("softmax", test_softmax) &
25
27
]
26
28
end subroutine collect_specialfunctions_activation
27
29
30
+ subroutine test_sigmoid(error)
31
+ type(error_type), allocatable, intent(out) :: error
32
+ integer, parameter :: n = 10
33
+ real(sp) :: x(n), y(n), y_ref(n)
34
+
35
+ y_ref = [0.119202919304371, 0.174285307526588, 0.247663781046867,&
36
+ 0.339243650436401, 0.444671928882599, 0.555328071117401,&
37
+ 0.660756349563599, 0.752336204051971, 0.825714707374573,&
38
+ 0.880797028541565]
39
+ x = linspace(-2._sp, 2._sp, n)
40
+ y = sigmoid( x )
41
+ call check(error, norm2(y-y_ref) < n*tol_sp )
42
+ if (allocated(error)) return
43
+ end subroutine
44
+
45
+ subroutine test_gelu(error)
46
+ type(error_type), allocatable, intent(out) :: error
47
+ integer, parameter :: n = 10
48
+ real(sp) :: x(n), y(n), y_ref(n)
49
+
50
+ y_ref = [-0.0455002784729 , -0.093188509345055, -0.148066952824593,&
51
+ -0.168328359723091, -0.0915712043643 , 0.130650997161865,&
52
+ 0.498338282108307, 0.963044226169586, 1.462367057800293,&
53
+ 1.9544997215271 ]
54
+ x = linspace(-2._sp, 2._sp, n)
55
+ y = gelu( x )
56
+ call check(error, norm2(y-y_ref) < n*tol_sp )
57
+ if (allocated(error)) return
58
+
59
+ y = gelu_approx( x )
60
+ call check(error, norm2(y-y_ref) < n*tol_sp )
61
+ if (allocated(error)) return
62
+ end subroutine
63
+
28
64
subroutine test_softmax(error)
29
65
type(error_type), allocatable, intent(out) :: error
30
66
@@ -88,7 +124,7 @@ contains
88
124
0.364060789, 0.241637364, 0.292525023,&
89
125
0.279837668, 0.357372403, 0.405537367,&
90
126
0.314476222, 0.404643506, 0.374830246,&
91
-
127
+
92
128
0.223737061, 0.410527140, 0.206393898,&
93
129
0.288762331, 0.224173695, 0.284117699,&
94
130
0.338987619, 0.295757085, 0.329763889 ] ,[3,3,3] )
0 commit comments