Skip to content

Commit de4f8d3

Browse files
committed
2-norm: use BLAS on contiguous or strided arrays if possible
- add nonstandard-named `complex` norms to `nrm2` interface - test sliced and reshaped 2-norm
1 parent 437b96e commit de4f8d3

File tree

4 files changed

+99
-9
lines changed

4 files changed

+99
-9
lines changed

src/stdlib_linalg.fypp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1221,7 +1221,7 @@ module stdlib_linalg
12211221
#:for rank in range(1, MAXRANK + 1)
12221222
pure module subroutine norm_${rank}$D_${ii}$_${ri}$(a, nrm, order, err)
12231223
!> Input ${rank}$-d matrix a${ranksuffix(rank)}$
1224-
${rt}$, intent(in) :: a${ranksuffix(rank)}$
1224+
${rt}$, intent(in), target :: a${ranksuffix(rank)}$
12251225
!> Norm of the matrix.
12261226
real(${rk}$), intent(out) :: nrm
12271227
!> Order of the matrix norm being computed.

src/stdlib_linalg_blas.fypp

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -974,12 +974,26 @@ module stdlib_linalg_blas
974974
#else
975975
module procedure stdlib_dnrm2
976976
#endif
977+
#ifdef STDLIB_EXTERNAL_BLAS
978+
pure real(dp) function dznrm2( n, x, incx )
979+
import sp,dp,qp,ilp,lk
980+
implicit none(type,external)
981+
integer(ilp), intent(in) :: incx,n
982+
complex(dp), intent(in) :: x(*)
983+
end function dznrm2
984+
#else
985+
module procedure stdlib_dznrm2
986+
#endif
977987
#:for rk,rt,ri in REAL_KINDS_TYPES
978988
#:if not rk in ["sp","dp"]
979989
module procedure stdlib_${ri}$nrm2
980-
981990
#:endif
982991
#:endfor
992+
#:for rk,rt,ri in CMPLX_KINDS_TYPES
993+
#:if not rk in ["sp","dp"]
994+
module procedure stdlib_${c2ri(ri)}$znrm2
995+
#:endif
996+
#:endfor
983997
#ifdef STDLIB_EXTERNAL_BLAS
984998
pure real(sp) function snrm2( n, x, incx )
985999
import sp,dp,qp,ilp,lk
@@ -989,6 +1003,16 @@ module stdlib_linalg_blas
9891003
end function snrm2
9901004
#else
9911005
module procedure stdlib_snrm2
1006+
#endif
1007+
#ifdef STDLIB_EXTERNAL_BLAS
1008+
pure real(sp) function scnrm2( n, x, incx )
1009+
import sp,dp,qp,ilp,lk
1010+
implicit none(type,external)
1011+
integer(ilp), intent(in) :: incx,n
1012+
complex(sp), intent(in) :: x(*)
1013+
end function scnrm2
1014+
#else
1015+
module procedure stdlib_scnrm2
9921016
#endif
9931017
end interface nrm2
9941018

src/stdlib_linalg_norms.fypp

Lines changed: 37 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -8,10 +8,11 @@
88
! Vector norms
99
submodule(stdlib_linalg) stdlib_linalg_norms
1010
use stdlib_linalg_constants
11-
use stdlib_linalg_blas, only: nrm2
11+
use stdlib_linalg_blas!, only: nrm2
1212
use stdlib_linalg_lapack, only: lange
1313
use stdlib_linalg_state, only: linalg_state_type, linalg_error_handling, LINALG_ERROR, &
1414
LINALG_INTERNAL_ERROR, LINALG_VALUE_ERROR
15+
use iso_c_binding, only: c_intptr_t,c_char,c_loc
1516
implicit none(type,external)
1617

1718
character(*), parameter :: this = 'norm'
@@ -29,6 +30,13 @@ submodule(stdlib_linalg) stdlib_linalg_norms
2930
module procedure parse_norm_type_character
3031
end interface parse_norm_type
3132

33+
34+
! Generic name for the kind-specific 1-d stride helpers: the fypp loop below
! emits one `stride_1d_<suffix>` module procedure per entry of ALL_KINDS_TYPES,
! so callers can write `stride_1d(a)` regardless of the element type of `a`.
interface stride_1d
35+
#:for rk,rt,ri in ALL_KINDS_TYPES
36+
module procedure stride_1d_${ri}$
37+
#:endfor
38+
end interface stride_1d
39+
3240
contains
3341

3442
!> Parse norm type from an integer user input
@@ -93,6 +101,25 @@ submodule(stdlib_linalg) stdlib_linalg_norms
93101
end subroutine parse_norm_type_character
94102

95103
#:for rk,rt,ri in ALL_KINDS_TYPES
104+
105+
! Compute the stride of a 1d array, measured in array elements.
!
! The stride is derived from the addresses of the first two elements:
!   stride = (address(a(2)) - address(a(1))) / (size of one element).
! Arrays with fewer than two elements report a unit stride.
!
! Notes:
! - The result carries the sign of the address difference, so it is negative
!   when a(2) is stored at a lower address than a(1).
! - The integer division assumes that the address difference is an exact
!   multiple of the element size; NOTE(review): confirm that callers never pass
!   a section whose element spacing is not a whole number of elements.
pure integer(ilp) function stride_1d_${ri}$(a) result(stride)
    !> Input 1-d array
    ${rt}$, intent(in), target :: a(:)

    !> Addresses of the first and second elements, as integers
    integer(c_intptr_t) :: a1,a2
    !> Number of bits in one address unit (one C char) and in one array element
    integer(ilp) :: bits_per_byte,bits_per_element

    if (size(a,kind=ilp)<=1_ilp) then
       ! No second element to measure against
       stride = 1_ilp
    else
       ! Query the width of a C char directly: `c_char` is a character kind,
       ! so it must not be used as the kind of an integer literal.
       bits_per_byte    = storage_size(c_char_'a', kind=ilp)
       bits_per_element = storage_size(a, kind=ilp)

       a1 = transfer(c_loc(a(1)),a1)
       a2 = transfer(c_loc(a(2)),a2)

       ! Address difference is in bytes: convert to bits, then to elements
       stride = bits_per_byte*int(a2-a1, ilp)/bits_per_element
    endif

end function stride_1d_${ri}$
121+
122+
96123
#:for it,ii in INPUT_OPTIONS
97124

98125
!==============================================
@@ -132,7 +159,7 @@ submodule(stdlib_linalg) stdlib_linalg_norms
132159
! Internal implementation
133160
pure module subroutine norm_${rank}$D_${ii}$_${ri}$(a, nrm, order, err)
134161
!> Input ${rank}$-d matrix a${ranksuffix(rank)}$
135-
${rt}$, intent(in) :: a${ranksuffix(rank)}$
162+
${rt}$, intent(in), target :: a${ranksuffix(rank)}$
136163
!> Norm of the matrix.
137164
real(${rk}$), intent(out) :: nrm
138165
!> Order of the matrix norm being computed.
@@ -142,9 +169,10 @@ submodule(stdlib_linalg) stdlib_linalg_norms
142169

143170
type(linalg_state_type) :: err_
144171

145-
integer(ilp) :: sze,norm_request
172+
integer(ilp) :: sze,norm_request,str
146173
real(${rk}$) :: rorder
147-
intrinsic :: abs, sum, sqrt, norm2, maxval, minval, conjg
174+
${rt}$, pointer :: a1d(:)
175+
intrinsic :: abs, sum, sqrt, maxval, minval, conjg
148176

149177
sze = size(a,kind=ilp)
150178

@@ -169,10 +197,12 @@ submodule(stdlib_linalg) stdlib_linalg_norms
169197
case(NORM_ONE)
170198
nrm = sum( abs(a) )
171199
case(NORM_TWO)
172-
#:if rt.startswith('complex')
173-
nrm = sqrt( real( sum( a * conjg(a) ), ${rk}$) )
200+
#:if rank==1
201+
nrm = nrm2(sze,a,incx=stride_1d(a))
202+
#:elif rt.startswith('complex')
203+
nrm = sqrt( real( sum( a * conjg(a) ), ${rk}$) )
174204
#:else
175-
nrm = norm2( a )
205+
nrm = norm2(a)
176206
#:endif
177207
case(NORM_INF)
178208
nrm = maxval( abs(a) )

test/linalg/test_linalg_norm.fypp

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ module test_linalg_norm
3232
allocate(tests(0))
3333

3434
#:for rk,rt,ri in RC_KINDS_TYPES
35+
tests = [tests,new_unittest("strided_1d_norm_${ri}$",test_strided_1d_${ri}$)]
3536
#:for rank in range(1, MAXRANK)
3637
tests = [tests,new_unittest("norm_${ri}$_${rank}$d",test_norm_${ri}$_${rank}$d)]
3738
#:endfor
@@ -46,6 +47,41 @@ module test_linalg_norm
4647
end subroutine test_vector_norms
4748

4849
#:for rk,rt,ri in RC_KINDS_TYPES
50+
51+
!> Test the 2-norm on a strided array section and on a rank-2 view of the data.
!>
!> The norm of a non-contiguous section is compared against the norm of a
!> contiguous copy of the same elements, and the norm of a rank-2 pointer
!> remapping of the data is compared against the norm of the rank-1 array.
subroutine test_strided_1d_${ri}$(error)
    !> Set by `check` when a comparison fails
    type(error_type), allocatable, intent(out) :: error

    integer(ilp), parameter :: m = 8_ilp
    integer(ilp), parameter :: n = m**2
    real(${rk}$), parameter :: tol = 10*sqrt(epsilon(0.0_${rk}$))
    ${rt}$, target :: a(n)
    ${rt}$, allocatable :: slice(:)
    ${rt}$, pointer :: twod(:,:)
    real(${rk}$) :: rea(n),ima(n)

    ! Fill the array with random data (random real and imaginary parts if complex)
    call random_number(rea)
#:if rt.startswith('real')
    a = rea
#:else
    call random_number(ima)
    a = cmplx(rea,ima,kind=${rk}$)
#:endif

    ! Test sliced array results.
    ! The section is lower:upper:stride, i.e. elements 4,11,...,53 (8 elements,
    ! stride 7). `slice` is a contiguous copy of the same elements, so both
    ! norms must agree to within the tolerance.
    slice = a(4:59:7)
    call check(error,abs(norm(a(4:59:7),2)-norm(slice,2))<tol*max(1.0_${rk}$,norm(slice,2)), &
                     'sliced ${rt}$ norm(a(4:59:7),2)')
    if (allocated(error)) return

    ! Test 2d array results: view the same storage as an m-by-m matrix
    twod(1:m,1:m) => a
    call check(error,abs(norm(twod,2)-norm(a,2))<tol*max(1.0_${rk}$,norm(twod,2)), &
                     '2d-reshaped ${rt}$ norm(a,2)')
    if (allocated(error)) return

end subroutine test_strided_1d_${ri}$
84+
4985
#:for rank in range(1, MAXRANK)
5086

5187
!> Test several norms with different dimensions

0 commit comments

Comments
 (0)