Skip to content

Commit a3d24e4

Browse files
committed
extend fsum support for ndarrays
1 parent eaffa4a commit a3d24e4

File tree

3 files changed

+106
-1
lines changed

3 files changed

+106
-1
lines changed

src/stdlib_intrinsics.fypp

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
#:set R_KINDS_TYPES = list(zip(REAL_KINDS, REAL_TYPES, REAL_SUFFIX))
33
#:set C_KINDS_TYPES = list(zip(CMPLX_KINDS, CMPLX_TYPES, CMPLX_SUFFIX))
44
#:set RC_KINDS_TYPES = R_KINDS_TYPES + C_KINDS_TYPES
5+
#:set RANKS = range(1, MAXRANK + 1)
56

67
! This module is based on https://github.com/jalvesz/fast_math
78
module stdlib_intrinsics
@@ -22,6 +23,19 @@ module stdlib_intrinsics
2223
logical, intent(in) :: mask(:)
2324
${rt}$ :: s
2425
end function
26+
#:for rank in RANKS
27+
pure module function fsum_${rank}$d_${rs}$( x, mask ) result( s )
28+
${rt}$, intent(in) :: x${ranksuffix(rank)}$
29+
logical, intent(in), optional :: mask${ranksuffix(rank)}$
30+
${rt}$ :: s
31+
end function
32+
pure module function fsum_${rank}$d_dim_${rs}$( x , dim, mask ) result( s )
33+
${rt}$, intent(in) :: x${ranksuffix(rank)}$
34+
integer, intent(in):: dim
35+
logical, intent(in), optional :: mask${ranksuffix(rank)}$
36+
${rt}$ :: s${reduced_shape('x', rank, 'dim')}$
37+
end function
38+
#:endfor
2539
#:endfor
2640
end interface
2741
public :: fsum

src/stdlib_intrinsics_sum.fypp

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
#:set R_KINDS_TYPES = list(zip(REAL_KINDS, REAL_TYPES, REAL_SUFFIX))
33
#:set C_KINDS_TYPES = list(zip(CMPLX_KINDS, CMPLX_TYPES, CMPLX_SUFFIX))
44
#:set RC_KINDS_TYPES = R_KINDS_TYPES + C_KINDS_TYPES
5+
#:set RANKS = range(1, MAXRANK + 1)
56

67
#:def cnjg(type,expression)
78
#:if 'complex' in type
@@ -66,6 +67,78 @@ pure module function fsum_1d_${rs}$_mask(a,mask) result(s)
6667
s = s + abatch(i)+abatch(chunk/2+i)
6768
end do
6869
end function
70+
71+
#:for rank in RANKS
72+
pure module function fsum_${rank}$d_${rs}$( x , mask ) result( s )
73+
${rt}$, intent(in) :: x${ranksuffix(rank)}$
74+
logical, intent(in), optional :: mask${ranksuffix(rank)}$
75+
${rt}$ :: s
76+
if(.not.present(mask)) then
77+
s = sum_recast(x,size(x))
78+
else
79+
s = sum_recast_mask(x,mask,size(x))
80+
end if
81+
contains
82+
pure ${rt}$ function sum_recast(b,n)
83+
integer, intent(in) :: n
84+
${rt}$, intent(in) :: b(n)
85+
sum_recast = fsum(b)
86+
end function
87+
pure ${rt}$ function sum_recast_mask(b,m,n)
88+
integer, intent(in) :: n
89+
${rt}$, intent(in) :: b(n)
90+
logical, intent(in) :: m(n)
91+
sum_recast_mask = fsum(b,m)
92+
end function
93+
end function
94+
95+
pure module function fsum_${rank}$d_dim_${rs}$( x , dim, mask ) result( s )
96+
${rt}$, intent(in) :: x${ranksuffix(rank)}$
97+
integer, intent(in):: dim
98+
logical, intent(in), optional :: mask${ranksuffix(rank)}$
99+
${rt}$ :: s${reduced_shape('x', rank, 'dim')}$
100+
integer :: j
101+
102+
if(.not.present(mask)) then
103+
if(dim<${rank}$)then
104+
do j = 1, size(x,dim=${rank}$)
105+
#:if rank == 2
106+
s${select_subarray(rank-1, [(rank-1, 'j')])}$ = fsum( x${select_subarray(rank, [(rank, 'j')])}$ )
107+
#:else
108+
s${select_subarray(rank-1, [(rank-1, 'j')])}$ = fsum( x${select_subarray(rank, [(rank, 'j')])}$, dim=dim )
109+
#:endif
110+
end do
111+
else
112+
do j = 1, size(x,dim=1)
113+
#:if rank == 2
114+
s${select_subarray(rank-1, [(1, 'j')])}$ = fsum( x${select_subarray(rank, [(1, 'j')])}$ )
115+
#:else
116+
s${select_subarray(rank-1, [(1, 'j')])}$ = fsum( x${select_subarray(rank, [(1, 'j')])}$, dim=${rank-1}$ )
117+
#:endif
118+
end do
119+
end if
120+
else
121+
if(dim<${rank}$)then
122+
do j = 1, size(x,dim=${rank}$)
123+
#:if rank == 2
124+
s${select_subarray(rank-1, [(rank-1, 'j')])}$ = fsum( x${select_subarray(rank, [(rank, 'j')])}$, mask=mask${select_subarray(rank, [(rank, 'j')])}$ )
125+
#:else
126+
s${select_subarray(rank-1, [(rank-1, 'j')])}$ = fsum( x${select_subarray(rank, [(rank, 'j')])}$, dim=dim, mask=mask${select_subarray(rank, [(rank, 'j')])}$ )
127+
#:endif
128+
end do
129+
else
130+
do j = 1, size(x,dim=1)
131+
#:if rank == 2
132+
s${select_subarray(rank-1, [(1, 'j')])}$ = fsum( x${select_subarray(rank, [(1, 'j')])}$, mask=mask${select_subarray(rank, [(1, 'j')])}$ )
133+
#:else
134+
s${select_subarray(rank-1, [(1, 'j')])}$ = fsum( x${select_subarray(rank, [(1, 'j')])}$, dim=${rank-1}$, mask=mask${select_subarray(rank, [(1, 'j')])}$ )
135+
#:endif
136+
end do
137+
end if
138+
end if
139+
140+
end function
141+
#:endfor
69142
#:endfor
70143

71144
#:for rk, rt, rs in RC_KINDS_TYPES

test/intrinsics/test_intrinsics.fypp

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -110,12 +110,30 @@ subroutine test_sum(error)
110110
xsum(3) = fsum(x,mask)+fsum(x,nmask) ! chunked summation
111111
err(1:ncalc) = abs(1._${k1}$-(xsum(1:ncalc)%re)/total_sum)
112112

113-
114113
call check(error, all(err(:)<tolerance) , "complex masked sum is not accurate" )
115114
if (allocated(error)) return
116115
end block
117116
#:endfor
118117

118+
ndarray : block
119+
use stdlib_strings, only: to_string
120+
real(sp), allocatable :: x(:,:,:)
121+
real(sp), parameter :: tolerance = epsilon(1._sp)*100
122+
integer :: i
123+
124+
allocate(x(100,100,10))
125+
call random_number(x)
126+
!> sum all elements
127+
call check(error, abs( sum(x) - fsum(x) )<tolerance*size(x) , "KO: full ndarray fsum" )
128+
if (allocated(error)) return
129+
130+
!> sum over specific rank dim
131+
do i = 1, rank(x)
132+
call check(error, norm2( sum(x,dim=i) - fsum(x,dim=i) )<tolerance*size(x) , "KO: ndarray fsum over dim "//to_string(i) )
133+
if (allocated(error)) return
134+
end do
135+
end block ndarray
136+
119137
end subroutine
120138

121139
subroutine test_dot_product(error)

0 commit comments

Comments
 (0)