1
+ #:include "common.fypp"
2
+ #:set R_KINDS_TYPES = [KT for KT in REAL_KINDS_TYPES if KT[0] in ["sp","dp"]]
3
+
4
+ module test_specialfunctions_activation
5
+ use testdrive, only : new_unittest, unittest_type, error_type, check
6
+ use stdlib_kinds
7
+ use stdlib_specialfunctions
8
+
9
+ implicit none
10
+ private
11
+
12
+ public :: collect_specialfunctions_activation
13
+
14
+ #:for k1, t1 in R_KINDS_TYPES
15
+ ${t1}$, parameter :: tol_${k1}$ = 1000 * epsilon(1.0_${k1}$)
16
+ #:endfor
17
+
18
+ contains
19
+
20
+ subroutine collect_specialfunctions_activation(testsuite)
21
+ type(unittest_type), allocatable, intent(out) :: testsuite(:)
22
+
23
+ testsuite = [ &
24
+ new_unittest("softmax", test_softmax) &
25
+ ]
26
+ end subroutine collect_specialfunctions_activation
27
+
28
+ subroutine test_softmax(error)
29
+ type(error_type), allocatable, intent(out) :: error
30
+
31
+ real(sp) :: x(3,3,3), y(3,3,3), y_ref(3,3,3)
32
+
33
+ x = reshape( [ 0.82192878, 0.76998032, 0.98611263,&
34
+ 0.8621334 , 0.65358045, 0.26387113,&
35
+ 0.12743663, 0.35237132, 0.23801647,&
36
+
37
+ 0.69773567, 0.40568874, 0.44789482,&
38
+ 0.42930753, 0.49579193, 0.53139985,&
39
+ 0.03035799, 0.65293157, 0.47613957,&
40
+
41
+ 0.21088634, 0.9356926 , 0.0991312 ,&
42
+ 0.46070181, 0.02943479, 0.17557538,&
43
+ 0.10541313, 0.33946349, 0.34804323 ] ,[3,3,3] )
44
+
45
+ !> Softmax on dim = 1
46
+ y = Softmax(x,dim=1)
47
+
48
+ y_ref = reshape( [ 0.319712639, 0.303528070, 0.376759291,&
49
+ 0.423455358, 0.343743294, 0.232801422,&
50
+ 0.296809316, 0.371676773, 0.331513911,&
51
+
52
+ 0.395936400, 0.295658976, 0.308404684,&
53
+ 0.314838648, 0.336482018, 0.348679334,&
54
+ 0.225966826, 0.421138495, 0.352894694,&
55
+
56
+ 0.252614945, 0.521480858, 0.225904226,&
57
+ 0.416388273, 0.270521373, 0.313090324,&
58
+ 0.282621205, 0.357150704, 0.360228121 ] ,[3,3,3] )
59
+
60
+ call check(error, norm2(y-y_ref) < tol_sp )
61
+ if (allocated(error)) return
62
+
63
+ !> Softmax on dim = 2
64
+ y = Softmax(x,dim=2)
65
+
66
+ y_ref = reshape( [ 0.393646270, 0.392350882, 0.510482967,&
67
+ 0.409795105, 0.349239051, 0.247922391,&
68
+ 0.196558580, 0.258410037, 0.241594598,&
69
+
70
+ 0.439052343, 0.296315849, 0.320951223,&
71
+ 0.335690796, 0.324254662, 0.348903090,&
72
+ 0.225256786, 0.379429489, 0.330145657,&
73
+
74
+ 0.314101219, 0.511530280, 0.297435701,&
75
+ 0.403239518, 0.206675291, 0.321064562,&
76
+ 0.282659233, 0.281794399, 0.381499708 ] ,[3,3,3] )
77
+
78
+ call check(error, norm2(y-y_ref) < tol_sp )
79
+ if (allocated(error)) return
80
+
81
+ !> Softmax on dim = 3
82
+ y = Softmax(x,dim=3)
83
+
84
+ y_ref = reshape( [ 0.412202179, 0.347835541, 0.501081109,&
85
+ 0.431399941, 0.418453932, 0.310344934,&
86
+ 0.346536130, 0.299599379, 0.295405835,&
87
+
88
+ 0.364060789, 0.241637364, 0.292525023,&
89
+ 0.279837668, 0.357372403, 0.405537367,&
90
+ 0.314476222, 0.404643506, 0.374830246,&
91
+
92
+ 0.223737061, 0.410527140, 0.206393898,&
93
+ 0.288762331, 0.224173695, 0.284117699,&
94
+ 0.338987619, 0.295757085, 0.329763889 ] ,[3,3,3] )
95
+
96
+ call check(error, norm2(y-y_ref) < tol_sp )
97
+ if (allocated(error)) return
98
+
99
+ end subroutine test_softmax
100
+
101
+
102
+ end module test_specialfunctions_activation
103
+
104
+ program tester
105
+ use, intrinsic :: iso_fortran_env, only : error_unit
106
+ use testdrive, only : run_testsuite, new_testsuite, testsuite_type
107
+ use test_specialfunctions_activation, only : collect_specialfunctions_activation
108
+ implicit none
109
+ integer :: stat, is
110
+ type(testsuite_type), allocatable :: testsuites(:)
111
+ character(len=*), parameter :: fmt = '("#", *(1x, a))'
112
+
113
+ stat = 0
114
+
115
+ testsuites = [new_testsuite("activation functions", &
116
+ collect_specialfunctions_activation)]
117
+
118
+ do is = 1, size(testsuites)
119
+ write(error_unit, fmt) "Testing:", testsuites(is)%name
120
+ call run_testsuite(testsuites(is)%collect, error_unit, stat)
121
+ end do
122
+
123
+ if (stat > 0) then
124
+ write(error_unit, '(i0, 1x, a)') stat, "test(s) failed!"
125
+ error stop
126
+ end if
127
+ end program tester
0 commit comments