Skip to content

Commit f4ad250

Browse files
committed
add tests for activations
1 parent 1b3bf4f commit f4ad250

File tree

2 files changed

+128
-0
lines changed

2 files changed

+128
-0
lines changed

test/specialfunctions/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
# Create a list of the files to be preprocessed
44
set(fppFiles
5+
test_specialfunctions_activations.fypp
56
test_specialfunctions_gamma.fypp
67
)
78

Lines changed: 127 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,127 @@
1+
#:include "common.fypp"
2+
#:set R_KINDS_TYPES = [KT for KT in REAL_KINDS_TYPES if KT[0] in ["sp","dp"]]
3+
4+
module test_specialfunctions_activation
5+
use testdrive, only : new_unittest, unittest_type, error_type, check
6+
use stdlib_kinds
7+
use stdlib_specialfunctions
8+
9+
implicit none
10+
private
11+
12+
public :: collect_specialfunctions_activation
13+
14+
#:for k1, t1 in R_KINDS_TYPES
15+
${t1}$, parameter :: tol_${k1}$ = 1000 * epsilon(1.0_${k1}$)
16+
#:endfor
17+
18+
contains
19+
20+
subroutine collect_specialfunctions_activation(testsuite)
21+
type(unittest_type), allocatable, intent(out) :: testsuite(:)
22+
23+
testsuite = [ &
24+
new_unittest("softmax", test_softmax) &
25+
]
26+
end subroutine collect_specialfunctions_activation
27+
28+
subroutine test_softmax(error)
29+
type(error_type), allocatable, intent(out) :: error
30+
31+
real(sp) :: x(3,3,3), y(3,3,3), y_ref(3,3,3)
32+
33+
x = reshape( [ 0.82192878, 0.76998032, 0.98611263,&
34+
0.8621334 , 0.65358045, 0.26387113,&
35+
0.12743663, 0.35237132, 0.23801647,&
36+
37+
0.69773567, 0.40568874, 0.44789482,&
38+
0.42930753, 0.49579193, 0.53139985,&
39+
0.03035799, 0.65293157, 0.47613957,&
40+
41+
0.21088634, 0.9356926 , 0.0991312 ,&
42+
0.46070181, 0.02943479, 0.17557538,&
43+
0.10541313, 0.33946349, 0.34804323 ] ,[3,3,3] )
44+
45+
!> Softmax on dim = 1
46+
y = Softmax(x,dim=1)
47+
48+
y_ref = reshape( [ 0.319712639, 0.303528070, 0.376759291,&
49+
0.423455358, 0.343743294, 0.232801422,&
50+
0.296809316, 0.371676773, 0.331513911,&
51+
52+
0.395936400, 0.295658976, 0.308404684,&
53+
0.314838648, 0.336482018, 0.348679334,&
54+
0.225966826, 0.421138495, 0.352894694,&
55+
56+
0.252614945, 0.521480858, 0.225904226,&
57+
0.416388273, 0.270521373, 0.313090324,&
58+
0.282621205, 0.357150704, 0.360228121 ] ,[3,3,3] )
59+
60+
call check(error, norm2(y-y_ref) < tol_sp )
61+
if (allocated(error)) return
62+
63+
!> Softmax on dim = 2
64+
y = Softmax(x,dim=2)
65+
66+
y_ref = reshape( [ 0.393646270, 0.392350882, 0.510482967,&
67+
0.409795105, 0.349239051, 0.247922391,&
68+
0.196558580, 0.258410037, 0.241594598,&
69+
70+
0.439052343, 0.296315849, 0.320951223,&
71+
0.335690796, 0.324254662, 0.348903090,&
72+
0.225256786, 0.379429489, 0.330145657,&
73+
74+
0.314101219, 0.511530280, 0.297435701,&
75+
0.403239518, 0.206675291, 0.321064562,&
76+
0.282659233, 0.281794399, 0.381499708 ] ,[3,3,3] )
77+
78+
call check(error, norm2(y-y_ref) < tol_sp )
79+
if (allocated(error)) return
80+
81+
!> Softmax on dim = 3
82+
y = Softmax(x,dim=3)
83+
84+
y_ref = reshape( [ 0.412202179, 0.347835541, 0.501081109,&
85+
0.431399941, 0.418453932, 0.310344934,&
86+
0.346536130, 0.299599379, 0.295405835,&
87+
88+
0.364060789, 0.241637364, 0.292525023,&
89+
0.279837668, 0.357372403, 0.405537367,&
90+
0.314476222, 0.404643506, 0.374830246,&
91+
92+
0.223737061, 0.410527140, 0.206393898,&
93+
0.288762331, 0.224173695, 0.284117699,&
94+
0.338987619, 0.295757085, 0.329763889 ] ,[3,3,3] )
95+
96+
call check(error, norm2(y-y_ref) < tol_sp )
97+
if (allocated(error)) return
98+
99+
end subroutine test_softmax
100+
101+
102+
end module test_specialfunctions_activation
103+
104+
program tester
105+
use, intrinsic :: iso_fortran_env, only : error_unit
106+
use testdrive, only : run_testsuite, new_testsuite, testsuite_type
107+
use test_specialfunctions_activation, only : collect_specialfunctions_activation
108+
implicit none
109+
integer :: stat, is
110+
type(testsuite_type), allocatable :: testsuites(:)
111+
character(len=*), parameter :: fmt = '("#", *(1x, a))'
112+
113+
stat = 0
114+
115+
testsuites = [new_testsuite("activation functions", &
116+
collect_specialfunctions_activation)]
117+
118+
do is = 1, size(testsuites)
119+
write(error_unit, fmt) "Testing:", testsuites(is)%name
120+
call run_testsuite(testsuites(is)%collect, error_unit, stat)
121+
end do
122+
123+
if (stat > 0) then
124+
write(error_unit, '(i0, 1x, a)') stat, "test(s) failed!"
125+
error stop
126+
end if
127+
end program tester

0 commit comments

Comments
 (0)