Skip to content

Commit 79d0a6a

Browse files
authored
Fix return type of complex norm (#382)
1 parent 79cc61b commit 79d0a6a

File tree

2 files changed

+6
-4
lines changed

2 files changed

+6
-4
lines changed

src/host/linalg.jl

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -204,13 +204,14 @@ LinearAlgebra.permutedims!(dest::AbstractGPUArray, src::AbstractGPUArray, perm)
204204
## norm
205205

206206
function LinearAlgebra.norm(v::AbstractGPUArray{T}, p::Real=2) where {T}
207-
if p == Inf
207+
norm_x = if p == Inf
208208
maximum(abs.(v))
209209
elseif p == -Inf
210210
minimum(abs.(v))
211211
else
212-
mapreduce(x->abs(x)^p, +, v; init=zero(T))^(1/p)
212+
mapreduce(x->abs(x)^p, +, v; init=float(zero(T)))^(1/p)
213213
end
214+
return real(norm_x)
214215
end
215216

216217

test/testsuite/linalg.jl

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -148,10 +148,11 @@ end
148148
@testset "$p-norm($sz x $T)" for sz in [(2,), (2,2), (2,2,2)],
149149
p in Any[1, 2, 3, Inf, -Inf],
150150
T in eltypes
151-
if T <: Complex || T == Int8
151+
if T == Int8
152152
continue
153153
end
154-
range = T <: Integer ? (T(1):T(10)) : T # prevent integer overflow
154+
range = real(T) <: Integer ? (T.(1:10)) : T # prevent integer overflow
155155
@test compare(norm, AT, rand(range, sz), Ref(p))
156+
@test typeof(norm(rand(range, sz))) <: Real
156157
end
157158
end

0 commit comments

Comments
 (0)