Skip to content

Commit b24f0b5

Browse files
GVignemaleadt
andauthored
Add iszero and isone for AbstractGPUMatrix (#419)
Co-authored-by: Tim Besard <tim.besard@gmail.com>
1 parent 219a0b8 commit b24f0b5

File tree

2 files changed

+28
-0
lines changed

2 files changed

+28
-0
lines changed

src/host/linalg.jl

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -418,3 +418,21 @@ LinearAlgebra.axpby!(alpha::Number, x::AbstractGPUArray,
418418
beta::Number, y::AbstractGPUArray) = y .= x.*alpha .+ y.*beta
419419

420420
LinearAlgebra.axpy!(alpha::Number, x::AbstractGPUArray, y::AbstractGPUArray) = y .+= x.*alpha
421+
422+
## identity and zero equality check
423+
424+
Base.iszero(x::AbstractGPUMatrix{T}) where {T} = all(iszero, x)
425+
function Base.isone(x::AbstractGPUMatrix{T}) where {T}
426+
n,m = size(x)
427+
m != n && return false
428+
429+
# lazily perform `x-I`
430+
bc = Broadcast.broadcasted(x, CartesianIndices(x)) do _x, inds
431+
_x - (inds[1] == inds[2] ? one(_x) : zero(_x))
432+
end
433+
# call `GPUArrays.mapreducedim!` directly, which supports Broadcasted inputs
434+
y = similar(x, Bool, 1)
435+
GPUArrays.mapreducedim!(iszero, &, y, Broadcast.instantiate(bc); init=true)
436+
437+
Array(y)[]
438+
end

test/testsuite/linalg.jl

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -222,6 +222,16 @@
222222
@test axpby!(alpha,x,beta,y) T.([7,10,13])
223223
@test axpy!(alpha,x,y) T.([8,12,16])
224224
end
225+
226+
@testset "iszero and isone" for T in eltypes
227+
A = one(AT(rand(T, 2, 2)))
228+
@test isone(A)
229+
@test iszero(A) == false
230+
231+
A = zero(AT(rand(T, 2, 2)))
232+
@test iszero(A)
233+
@test isone(A) == false
234+
end
225235
end
226236

227237
@testsuite "linalg/mul!/vector-matrix" (AT, eltypes)->begin

0 commit comments

Comments
 (0)