Skip to content

Commit 6e36b6f

Browse files
committed
extend kahan sum for rank N arrays
1 parent aaa68bc commit 6e36b6f

File tree

4 files changed

+110
-9
lines changed

4 files changed

+110
-9
lines changed

doc/specs/stdlib_intrinsics.md

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ If `dim` is absent, the output is a scalar of the same `type` and `kind` as to t
4848

4949
#### Description
5050

51-
The `stdlib_sum_kahan` function can replace the intrinsic `sum` for 1D `real` or `complex` arrays. It follows a chunked implementation which maximizes vectorization potential, complemented by an `elemental` kernel based on the [kahan summation](https://en.wikipedia.org/wiki/Kahan_summation_algorithm) strategy to reduce the round-off error:
51+
The `stdlib_sum_kahan` function can replace the intrinsic `sum` for `real` or `complex` arrays. It follows a chunked implementation which maximizes vectorization potential complemented by an `elemental` kernel based on the [kahan summation](https://en.wikipedia.org/wiki/Kahan_summation_algorithm) strategy to reduce the round-off error:
5252

5353
```fortran
5454
elemental subroutine kahan_kernel_<kind>(a,s,c)
@@ -67,6 +67,8 @@ end subroutine
6767

6868
`res = ` [[stdlib_intrinsics(module):stdlib_sum_kahan(interface)]] ` (x [,mask] )`
6969

70+
`res = ` [[stdlib_intrinsics(module):stdlib_sum_kahan(interface)]] ` (x, dim [,mask] )`
71+
7072
#### Status
7173

7274
Experimental
@@ -79,11 +81,13 @@ Pure function.
7981

8082
`x`: 1D array of either `real` or `complex` type. This argument is `intent(in)`.
8183

82-
`mask` (optional): 1D array of `logical` values. This argument is `intent(in)`.
84+
`dim` (optional): scalar of type `integer` with a value in the range from 1 to n, where n equals the rank of `x`.
85+
86+
`mask` (optional): N-D array of `logical` values, with the same shape as `x`. This argument is `intent(in)`.
8387

8488
#### Output value or Result value
8589

86-
The output is a scalar of `type` and `kind` same as to that of `x`.
90+
If `dim` is absent, the output is a scalar of the same `type` and `kind` as to that of `x`. Otherwise, an array of rank n-1, where n equals the rank of `x`, and a shape similar to that of `x` with dimension `dim` dropped is returned.
8791

8892
#### Example
8993

src/stdlib_intrinsics.fypp

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -56,13 +56,14 @@ module stdlib_intrinsics
5656
!! version: experimental
5757
!!
5858
!!### Summary
59-
!! Sum elements of rank 1 arrays.
59+
!! Sum elements of rank N arrays.
6060
!! ([Specification](../page/specs/stdlib_intrinsics.html#stdlib_sum_kahan))
6161
!!
6262
!!### Description
6363
!!
64-
!! This interface provides standard conforming call for sum of elements of rank 1.
64+
!! This interface provides standard conforming call for sum of elements of any rank.
6565
!! The 1-D base implementation follows a chunked approach combined with a kahan kernel for optimizing performance and increasing accuracy.
66+
!! The `N-D` interfaces calls upon the `(N-1)-D` implementation.
6667
!! Supported data types include `real` and `complex`.
6768
!!
6869
#:for rk, rt, rs in RC_KINDS_TYPES
@@ -75,6 +76,19 @@ module stdlib_intrinsics
7576
logical, intent(in) :: mask(:)
7677
${rt}$ :: s
7778
end function
79+
#:for rank in RANKS
80+
pure module function stdlib_sum_kahan_${rank}$d_${rs}$( x, mask ) result( s )
81+
${rt}$, intent(in) :: x${ranksuffix(rank)}$
82+
logical, intent(in), optional :: mask${ranksuffix(rank)}$
83+
${rt}$ :: s
84+
end function
85+
pure module function stdlib_sum_kahan_${rank}$d_dim_${rs}$( x , dim, mask ) result( s )
86+
${rt}$, intent(in) :: x${ranksuffix(rank)}$
87+
integer, intent(in):: dim
88+
logical, intent(in), optional :: mask${ranksuffix(rank)}$
89+
${rt}$ :: s${reduced_shape('x', rank, 'dim')}$
90+
end function
91+
#:endfor
7892
#:endfor
7993
end interface
8094
public :: stdlib_sum_kahan

src/stdlib_intrinsics_sum.fypp

Lines changed: 80 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,15 +9,14 @@ submodule(stdlib_intrinsics) stdlib_intrinsics_sum
99
use stdlib_kinds
1010
use stdlib_constants
1111
implicit none
12-
13-
integer, parameter :: chunk = 64
1412

1513
contains
1614

1715
!================= 1D Base implementations ============
1816
! This implementation is based on https://github.com/jalvesz/fast_math
1917
#:for rk, rt, rs in RC_KINDS_TYPES
2018
pure module function stdlib_sum_1d_${rs}$(a) result(s)
19+
integer, parameter :: chunk = 64
2120
${rt}$, intent(in) :: a(:)
2221
${rt}$ :: s
2322
${rt}$ :: abatch(chunk)
@@ -39,6 +38,7 @@ pure module function stdlib_sum_1d_${rs}$(a) result(s)
3938
end function
4039

4140
pure module function stdlib_sum_1d_${rs}$_mask(a,mask) result(s)
41+
integer, parameter :: chunk = 64
4242
${rt}$, intent(in) :: a(:)
4343
logical, intent(in) :: mask(:)
4444
${rt}$ :: s
@@ -61,13 +61,14 @@ pure module function stdlib_sum_1d_${rs}$_mask(a,mask) result(s)
6161
end function
6262

6363
pure module function stdlib_sum_kahan_1d_${rs}$(a) result(s)
64+
integer, parameter :: chunk = 64
6465
${rt}$, intent(in) :: a(:)
6566
${rt}$ :: s
6667
${rt}$ :: sbatch(chunk)
6768
${rt}$ :: cbatch(chunk)
6869
integer :: i, dr, rr
6970
! -----------------------------
70-
dr = size(a)/(chunk)
71+
dr = size(a)/chunk
7172
rr = size(a) - dr*chunk
7273
sbatch = zero_${rs}$
7374
cbatch = zero_${rs}$
@@ -83,14 +84,15 @@ pure module function stdlib_sum_kahan_1d_${rs}$(a) result(s)
8384
end function
8485

8586
pure module function stdlib_sum_kahan_1d_${rs}$_mask(a,mask) result(s)
87+
integer, parameter :: chunk = 64
8688
${rt}$, intent(in) :: a(:)
8789
logical, intent(in) :: mask(:)
8890
${rt}$ :: s
8991
${rt}$ :: sbatch(chunk)
9092
${rt}$ :: cbatch(chunk)
9193
integer :: i, dr, rr
9294
! -----------------------------
93-
dr = size(a)/(chunk)
95+
dr = size(a)/chunk
9496
rr = size(a) - dr*chunk
9597
sbatch = zero_${rs}$
9698
cbatch = zero_${rs}$
@@ -181,4 +183,78 @@ end function
181183
#:endfor
182184
#:endfor
183185

186+
#:for rk, rt, rs in RC_KINDS_TYPES
187+
#:for rank in RANKS
188+
pure module function stdlib_sum_kahan_${rank}$d_${rs}$( x , mask ) result( s )
189+
${rt}$, intent(in) :: x${ranksuffix(rank)}$
190+
logical, intent(in), optional :: mask${ranksuffix(rank)}$
191+
${rt}$ :: s
192+
if(.not.present(mask)) then
193+
s = sum_recast(x,size(x))
194+
else
195+
s = sum_recast_mask(x,mask,size(x))
196+
end if
197+
contains
198+
pure ${rt}$ function sum_recast(b,n)
199+
integer, intent(in) :: n
200+
${rt}$, intent(in) :: b(n)
201+
sum_recast = stdlib_sum_kahan(b)
202+
end function
203+
pure ${rt}$ function sum_recast_mask(b,m,n)
204+
integer, intent(in) :: n
205+
${rt}$, intent(in) :: b(n)
206+
logical, intent(in) :: m(n)
207+
sum_recast_mask = stdlib_sum_kahan(b,m)
208+
end function
209+
end function
210+
211+
pure module function stdlib_sum_kahan_${rank}$d_dim_${rs}$( x , dim, mask ) result( s )
212+
${rt}$, intent(in) :: x${ranksuffix(rank)}$
213+
integer, intent(in):: dim
214+
logical, intent(in), optional :: mask${ranksuffix(rank)}$
215+
${rt}$ :: s${reduced_shape('x', rank, 'dim')}$
216+
integer :: j
217+
218+
if(.not.present(mask)) then
219+
if(dim<${rank}$)then
220+
do j = 1, size(x,dim=${rank}$)
221+
#:if rank == 2
222+
s${select_subarray(rank-1, [(rank-1, 'j')])}$ = stdlib_sum_kahan( x${select_subarray(rank, [(rank, 'j')])}$ )
223+
#:else
224+
s${select_subarray(rank-1, [(rank-1, 'j')])}$ = stdlib_sum_kahan( x${select_subarray(rank, [(rank, 'j')])}$, dim=dim )
225+
#:endif
226+
end do
227+
else
228+
do j = 1, size(x,dim=1)
229+
#:if rank == 2
230+
s${select_subarray(rank-1, [(1, 'j')])}$ = stdlib_sum_kahan( x${select_subarray(rank, [(1, 'j')])}$ )
231+
#:else
232+
s${select_subarray(rank-1, [(1, 'j')])}$ = stdlib_sum_kahan( x${select_subarray(rank, [(1, 'j')])}$, dim=${rank-1}$ )
233+
#:endif
234+
end do
235+
end if
236+
else
237+
if(dim<${rank}$)then
238+
do j = 1, size(x,dim=${rank}$)
239+
#:if rank == 2
240+
s${select_subarray(rank-1, [(rank-1, 'j')])}$ = stdlib_sum_kahan( x${select_subarray(rank, [(rank, 'j')])}$, mask=mask${select_subarray(rank, [(rank, 'j')])}$ )
241+
#:else
242+
s${select_subarray(rank-1, [(rank-1, 'j')])}$ = stdlib_sum_kahan( x${select_subarray(rank, [(rank, 'j')])}$, dim=dim, mask=mask${select_subarray(rank, [(rank, 'j')])}$ )
243+
#:endif
244+
end do
245+
else
246+
do j = 1, size(x,dim=1)
247+
#:if rank == 2
248+
s${select_subarray(rank-1, [(1, 'j')])}$ = stdlib_sum_kahan( x${select_subarray(rank, [(1, 'j')])}$, mask=mask${select_subarray(rank, [(1, 'j')])}$ )
249+
#:else
250+
s${select_subarray(rank-1, [(1, 'j')])}$ = stdlib_sum_kahan( x${select_subarray(rank, [(1, 'j')])}$, dim=${rank-1}$, mask=mask${select_subarray(rank, [(1, 'j')])}$ )
251+
#:endif
252+
end do
253+
end if
254+
end if
255+
256+
end function
257+
#:endfor
258+
#:endfor
259+
184260
end submodule stdlib_intrinsics_sum

test/intrinsics/test_intrinsics.fypp

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -127,11 +127,18 @@ subroutine test_sum(error)
127127
call check(error, abs( sum(x) - stdlib_sum(x) )<tolerance*size(x) , "KO: full ndarray stdlib_sum" )
128128
if (allocated(error)) return
129129

130+
call check(error, abs( sum(x) - stdlib_sum_kahan(x) )<tolerance*size(x) , "KO: full ndarray stdlib_sum_kahan" )
131+
if (allocated(error)) return
132+
130133
!> sum over specific rank dim
131134
do i = 1, rank(x)
132135
call check(error, norm2( sum(x,dim=i) - stdlib_sum(x,dim=i) )<tolerance*size(x) ,&
133136
"KO: ndarray stdlib_sum over dim "//to_string(i) )
134137
if (allocated(error)) return
138+
139+
call check(error, norm2( sum(x,dim=i) - stdlib_sum_kahan(x,dim=i) )<tolerance*size(x) ,&
140+
"KO: ndarray stdlib_sum_kahan over dim "//to_string(i) )
141+
if (allocated(error)) return
135142
end do
136143
end block ndarray
137144

0 commit comments

Comments
 (0)