Skip to content

Commit 1ee6ec1

Browse files
committed
Unwrap triangular matrices in broadcast (#1332)
In broadcasting over triangular matrices, we loop only over the stored indices. For these indices, indexing into a triangular matrix is equivalent to indexing into its parent. We may therefore replace an `UpperOrLowerTriangular` matrix by its parent, which removes the branch in `getindex`. This improves performance:

```julia
julia> L = LowerTriangular(zeros(600,600));

julia> L2 = copy(L);

julia> @btime broadcast!(+, $L2, $L, $L);
  161.176 μs (0 allocations: 0 bytes) # master
  80.894 μs (0 allocations: 0 bytes) # this PR
```

This replacement is performed recursively on a `Broadcasted` object by looping over its `args`, and non-triangular elements are left untouched. Only `UpperOrLowerTriangular` matrices will be replaced by their `parent`s. (cherry picked from commit 5165fd3)
1 parent 6d36322 commit 1ee6ec1

File tree

3 files changed

+24
-6
lines changed

3 files changed

+24
-6
lines changed

src/structuredbroadcast.jl

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -264,13 +264,24 @@ function copyto!(dest::Tridiagonal, bc::Broadcasted{<:StructuredMatrixStyle})
264264
return dest
265265
end
266266

267+
# Recursively replace wrapped matrices by their parents to improve broadcasting performance
268+
# We may do this because the indexing within `copyto!` is restricted to the stored indices
269+
preprocess_broadcasted(::Type{T}, A) where {T} = _preprocess_broadcasted(T, A)
270+
function preprocess_broadcasted(::Type{T}, bc::Broadcasted) where {T}
271+
args = map(x -> preprocess_broadcasted(T, x), bc.args)
272+
Broadcast.Broadcasted(bc.f, args, bc.axes)
273+
end
274+
_preprocess_broadcasted(::Type{LowerTriangular}, A) = lowertridata(A)
275+
_preprocess_broadcasted(::Type{UpperTriangular}, A) = uppertridata(A)
276+
267277
function copyto!(dest::LowerTriangular, bc::Broadcasted{<:StructuredMatrixStyle})
268278
isvalidstructbc(dest, bc) || return copyto!(dest, convert(Broadcasted{Nothing}, bc))
269279
axs = axes(dest)
270280
axes(bc) == axs || Broadcast.throwdm(axes(bc), axs)
281+
bc_unwrapped = preprocess_broadcasted(LowerTriangular, bc)
271282
for j in axs[2]
272283
for i in j:axs[1][end]
273-
@inbounds dest.data[i,j] = bc[CartesianIndex(i, j)]
284+
@inbounds dest.data[i,j] = bc_unwrapped[CartesianIndex(i, j)]
274285
end
275286
end
276287
return dest
@@ -280,9 +291,10 @@ function copyto!(dest::UpperTriangular, bc::Broadcasted{<:StructuredMatrixStyle}
280291
isvalidstructbc(dest, bc) || return copyto!(dest, convert(Broadcasted{Nothing}, bc))
281292
axs = axes(dest)
282293
axes(bc) == axs || Broadcast.throwdm(axes(bc), axs)
294+
bc_unwrapped = preprocess_broadcasted(UpperTriangular, bc)
283295
for j in axs[2]
284296
for i in 1:j
285-
@inbounds dest.data[i,j] = bc[CartesianIndex(i, j)]
297+
@inbounds dest.data[i,j] = bc_unwrapped[CartesianIndex(i, j)]
286298
end
287299
end
288300
return dest

src/triangular.jl

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -658,10 +658,8 @@ end
658658

659659
uppertridata(A) = A
660660
lowertridata(A) = A
661-
# we restrict these specializations only to strided matrices to avoid cases where an UpperTriangular type
662-
# doesn't share its indexing with the parent
663-
uppertridata(A::UpperTriangular{<:Any, <:StridedMatrix}) = parent(A)
664-
lowertridata(A::LowerTriangular{<:Any, <:StridedMatrix}) = parent(A)
661+
uppertridata(A::UpperTriangular) = parent(A)
662+
lowertridata(A::LowerTriangular) = parent(A)
665663

666664
@inline _rscale_add!(A::AbstractTriangular, B::AbstractTriangular, C::Number, alpha::Number, beta::Number) =
667665
@stable_muladdmul _triscale!(A, B, C, MulAddMul(alpha, beta))

test/structuredbroadcast.jl

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -385,4 +385,12 @@ end
385385
@test ind == CartesianIndex(1,1)
386386
end
387387

388+
@testset "nested triangular broadcast" begin
389+
for T in (LowerTriangular, UpperTriangular)
390+
L = T(rand(Int,4,4))
391+
M = Matrix(L)
392+
@test L .+ L .+ 0 .+ L .+ 0 .- L == 2M
393+
end
394+
end
395+
388396
end

0 commit comments

Comments
 (0)