Skip to content

Commit 6becb4f

Browse files
authored
Add specializations for istriu/istril to speed up isdiag. (#502)
1 parent 15bf446 commit 6becb4f

File tree

2 files changed

+48
-0
lines changed

2 files changed

+48
-0
lines changed

src/host/linalg.jl

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -165,6 +165,38 @@ function LinearAlgebra.triu!(A::AbstractGPUMatrix{T}, d::Integer = 0) where T
165165
return A
166166
end
167167

168+
# check if upper triangular starting from the kth superdiagonal.
169+
function LinearAlgebra.istriu(A::AbstractGPUMatrix, k::Integer = 0)
170+
function mapper(a, I)
171+
row, col = Tuple(I)
172+
if col < row + k
173+
return iszero(a)
174+
else
175+
true
176+
end
177+
end
178+
function reducer(a, b)
179+
a && b
180+
end
181+
mapreduce(mapper, reducer, A, eachindex(IndexCartesian(), A); init=true)
182+
end
183+
184+
# check if lower triangular starting from the kth subdiagonal.
185+
function LinearAlgebra.istril(A::AbstractGPUMatrix, k::Integer = 0)
186+
function mapper(a, I)
187+
row, col = Tuple(I)
188+
if col > row + k
189+
return iszero(a)
190+
else
191+
true
192+
end
193+
end
194+
function reducer(a, b)
195+
a && b
196+
end
197+
mapreduce(mapper, reducer, A, eachindex(IndexCartesian(), A); init=true)
198+
end
199+
168200

169201
## diagonal
170202

test/testsuite/linalg.jl

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,22 @@
105105
@test gpu_c isa TR
106106
end
107107
end
108+
109+
@testset "istril" begin
110+
for rows in 3:4, cols in 3:4, diag in -4:4
111+
A = tril(rand(Float32, rows,cols), diag)
112+
B = AT(A)
113+
@test istril(A) == istril(B)
114+
end
115+
end
116+
117+
@testset "istriu" begin
118+
for rows in 3:4, cols in 3:4, diag in -4:4
119+
A = triu(rand(Float32, rows,cols), diag)
120+
B = AT(A)
121+
@test istriu(A) == istriu(B)
122+
end
123+
end
108124
end
109125

110126
@testset "diagonal" begin

0 commit comments

Comments
 (0)