Skip to content

Commit 5165fd3

Browse files
authored
Unwrap triangular matrices in broadcast (#1332)
In broadcasting over triangular matrices, we loop only over the stored indices. For these indices, indexing into a triangular matrix is equivalent to indexing into its parent. We may therefore replace an `UpperOrLowerTriangular` matrix by its parent, which removes the branch in `getindex`. This improves performance:

```julia
julia> L = LowerTriangular(zeros(600,600));

julia> L2 = copy(L);

julia> @btime broadcast!(+, $L2, $L, $L);
  161.176 μs (0 allocations: 0 bytes) # master
  80.894 μs (0 allocations: 0 bytes) # this PR
```

This replacement is performed recursively on a `Broadcasted` object by looping over its `args`, and non-triangular elements are left untouched. Only `UpperOrLowerTriangular` matrices will be replaced by their `parent`s.
1 parent 9d26faf commit 5165fd3

File tree

3 files changed

+24
-6
lines changed

3 files changed

+24
-6
lines changed

src/structuredbroadcast.jl

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -269,13 +269,24 @@ function copyto!(dest::Tridiagonal, bc::Broadcasted{<:StructuredMatrixStyle})
269269
return dest
270270
end
271271

272+
# Recursively replace wrapped matrices by their parents to improve broadcasting performance
273+
# We may do this because the indexing within `copyto!` is restricted to the stored indices
274+
preprocess_broadcasted(::Type{T}, A) where {T} = _preprocess_broadcasted(T, A)
275+
function preprocess_broadcasted(::Type{T}, bc::Broadcasted) where {T}
276+
args = map(x -> preprocess_broadcasted(T, x), bc.args)
277+
Broadcast.Broadcasted(bc.f, args, bc.axes)
278+
end
279+
_preprocess_broadcasted(::Type{LowerTriangular}, A) = lowertridata(A)
280+
_preprocess_broadcasted(::Type{UpperTriangular}, A) = uppertridata(A)
281+
272282
function copyto!(dest::LowerTriangular, bc::Broadcasted{<:StructuredMatrixStyle})
273283
isvalidstructbc(dest, bc) || return copyto!(dest, convert(Broadcasted{Nothing}, bc))
274284
axs = axes(dest)
275285
axes(bc) == axs || Broadcast.throwdm(axes(bc), axs)
286+
bc_unwrapped = preprocess_broadcasted(LowerTriangular, bc)
276287
for j in axs[2]
277288
for i in j:axs[1][end]
278-
@inbounds dest.data[i,j] = bc[CartesianIndex(i, j)]
289+
@inbounds dest.data[i,j] = bc_unwrapped[CartesianIndex(i, j)]
279290
end
280291
end
281292
return dest
@@ -285,9 +296,10 @@ function copyto!(dest::UpperTriangular, bc::Broadcasted{<:StructuredMatrixStyle}
285296
isvalidstructbc(dest, bc) || return copyto!(dest, convert(Broadcasted{Nothing}, bc))
286297
axs = axes(dest)
287298
axes(bc) == axs || Broadcast.throwdm(axes(bc), axs)
299+
bc_unwrapped = preprocess_broadcasted(UpperTriangular, bc)
288300
for j in axs[2]
289301
for i in 1:j
290-
@inbounds dest.data[i,j] = bc[CartesianIndex(i, j)]
302+
@inbounds dest.data[i,j] = bc_unwrapped[CartesianIndex(i, j)]
291303
end
292304
end
293305
return dest

src/triangular.jl

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -702,10 +702,8 @@ end
702702

703703
uppertridata(A) = A
704704
lowertridata(A) = A
705-
# we restrict these specializations only to strided matrices to avoid cases where an UpperTriangular type
706-
# doesn't share its indexing with the parent
707-
uppertridata(A::UpperTriangular{<:Any, <:StridedMatrix}) = parent(A)
708-
lowertridata(A::LowerTriangular{<:Any, <:StridedMatrix}) = parent(A)
705+
uppertridata(A::UpperTriangular) = parent(A)
706+
lowertridata(A::LowerTriangular) = parent(A)
709707

710708
@inline _rscale_add!(A::AbstractTriangular, B::AbstractTriangular, C::Number, alpha::Number, beta::Number) =
711709
@stable_muladdmul _triscale!(A, B, C, MulAddMul(alpha, beta))

test/structuredbroadcast.jl

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -388,4 +388,12 @@ end
388388
@test ind == CartesianIndex(1,1)
389389
end
390390

391+
@testset "nested triangular broadcast" begin
392+
for T in (LowerTriangular, UpperTriangular)
393+
L = T(rand(Int,4,4))
394+
M = Matrix(L)
395+
@test L .+ L .+ 0 .+ L .+ 0 .- L == 2M
396+
end
397+
end
398+
391399
end

0 commit comments

Comments
 (0)