Skip to content

Commit 9d7eb7c

Browse files
committed
add tests for sigmoid and gelu
1 parent f4ad250 commit 9d7eb7c

File tree

1 file changed

+38
-2
lines changed

1 file changed

+38
-2
lines changed

test/specialfunctions/test_specialfunctions_activations.fypp

Lines changed: 38 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ module test_specialfunctions_activation
55
use testdrive, only : new_unittest, unittest_type, error_type, check
66
use stdlib_kinds
77
use stdlib_specialfunctions
8-
8+
use stdlib_math, only: linspace
99
implicit none
1010
private
1111

@@ -21,10 +21,46 @@ contains
2121
type(unittest_type), allocatable, intent(out) :: testsuite(:)
2222

2323
testsuite = [ &
24+
new_unittest("sigmoid", test_sigmoid), &
25+
new_unittest("gelu" , test_gelu ), &
2426
new_unittest("softmax", test_softmax) &
2527
]
2628
end subroutine collect_specialfunctions_activation
2729

30+
subroutine test_sigmoid(error)
31+
type(error_type), allocatable, intent(out) :: error
32+
integer, parameter :: n = 10
33+
real(sp) :: x(n), y(n), y_ref(n)
34+
35+
y_ref = [0.119202919304371, 0.174285307526588, 0.247663781046867,&
36+
0.339243650436401, 0.444671928882599, 0.555328071117401,&
37+
0.660756349563599, 0.752336204051971, 0.825714707374573,&
38+
0.880797028541565]
39+
x = linspace(-2._sp, 2._sp, n)
40+
y = sigmoid( x )
41+
call check(error, norm2(y-y_ref) < n*tol_sp )
42+
if (allocated(error)) return
43+
end subroutine
44+
45+
subroutine test_gelu(error)
46+
type(error_type), allocatable, intent(out) :: error
47+
integer, parameter :: n = 10
48+
real(sp) :: x(n), y(n), y_ref(n)
49+
50+
y_ref = [-0.0455002784729 , -0.093188509345055, -0.148066952824593,&
51+
-0.168328359723091, -0.0915712043643 , 0.130650997161865,&
52+
0.498338282108307, 0.963044226169586, 1.462367057800293,&
53+
1.9544997215271 ]
54+
x = linspace(-2._sp, 2._sp, n)
55+
y = gelu( x )
56+
call check(error, norm2(y-y_ref) < n*tol_sp )
57+
if (allocated(error)) return
58+
59+
y = gelu_approx( x )
60+
call check(error, norm2(y-y_ref) < n*tol_sp )
61+
if (allocated(error)) return
62+
end subroutine
63+
2864
subroutine test_softmax(error)
2965
type(error_type), allocatable, intent(out) :: error
3066

@@ -88,7 +124,7 @@ contains
88124
0.364060789, 0.241637364, 0.292525023,&
89125
0.279837668, 0.357372403, 0.405537367,&
90126
0.314476222, 0.404643506, 0.374830246,&
91-
127+
92128
0.223737061, 0.410527140, 0.206393898,&
93129
0.288762331, 0.224173695, 0.284117699,&
94130
0.338987619, 0.295757085, 0.329763889 ] ,[3,3,3] )

0 commit comments

Comments
 (0)