Skip to content

Commit bfab205

Browse files
authored
Band indexing for adj/trans (#1299)
The advantage of this is that we may forward the indexing to the parent, and if the parent is a structured matrix and the band is a compile-time constant, the value may be evaluated as a constant. E.g.: ```julia julia> A = UnitUpperTriangular(reshape(1:9, 3, 3)) 3×3 UnitUpperTriangular{Int64, Base.ReshapedArray{Int64, 2, UnitRange{Int64}, Tuple{}}}: 1 4 7 ⋅ 1 8 ⋅ ⋅ 1 julia> function f(A, i, ::Val{band}) where {band} x = Adjoint(A)[LinearAlgebra.BandIndex(band,i)] Val(x) end f (generic function with 1 method) julia> @inferred f(A, 1, Val(0)) Val{1}() julia> @inferred f(A, 1, Val(1)) Val{0}() ``` In this case, the band index `0` is forwarded to the parent (`UnitUpperTriangular`), and the `getindex` is evaluated at compile-time to return `1` (if within bounds). Similarly, the band index `1` corresponds to the zeros of the triangular matrix, so the value is evaluated at compile-time. This may help with removing branches in indexing.
1 parent e7a8a15 commit bfab205

File tree

2 files changed

+53
-0
lines changed

2 files changed

+53
-0
lines changed

src/adjtrans.jl

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -352,6 +352,17 @@ IndexStyle(::Type{<:AdjOrTransAbsVec}) = IndexLinear()
352352
@propagate_inbounds getindex(v::AdjOrTransAbsVec, ::Colon, is::AbstractArray{Int}) = wrapperop(v)(v.parent[is])
353353
@propagate_inbounds getindex(v::AdjOrTransAbsVec, ::Colon, ::Colon) = wrapperop(v)(v.parent[:])
354354

355+
# band indexing
356+
@propagate_inbounds function getindex(A::AdjOrTransAbsMat{T}, b::BandIndex) where {T}
357+
require_one_based_indexing(A)
358+
wrapperop(A)(A.parent[BandIndex(-b.band, b.index)])::T
359+
end
360+
@propagate_inbounds function setindex!(A::AdjOrTransAbsMat, x, b::BandIndex)
361+
require_one_based_indexing(A)
362+
setindex!(A.parent, _wrapperop(A)(x), BandIndex(-b.band, b.index))
363+
return A
364+
end
365+
355366
# conversion of underlying storage
356367
convert(::Type{Adjoint{T,S}}, A::Adjoint) where {T,S} = Adjoint{T,S}(convert(S, A.parent))::Adjoint{T,S}
357368
convert(::Type{Transpose{T,S}}, A::Transpose) where {T,S} = Transpose{T,S}(convert(S, A.parent))::Transpose{T,S}

test/adjtrans.jl

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -749,4 +749,46 @@ end
749749
end
750750
end
751751

752+
@testset "band indexing" begin
753+
n = 3
754+
A = UnitUpperTriangular(Matrix(reshape(1:n^2, n, n)))
755+
@testset "every index" begin
756+
Aadj = Adjoint(A)
757+
for k in -(n-1):n-1
758+
di = diagind(Aadj, k, IndexStyle(Aadj))
759+
for (i,d) in enumerate(di)
760+
@test Aadj[LinearAlgebra.BandIndex(k,i)] == Aadj[d]
761+
if k < 0 # the adjoint is a unit lower triangular
762+
Aadj[LinearAlgebra.BandIndex(k,i)] = n^2 + i
763+
@test Aadj[d] == n^2 + i
764+
end
765+
end
766+
end
767+
end
768+
@testset "inference for structured matrices" begin
769+
function f(A, i, ::Val{band}) where {band}
770+
x = Adjoint(A)[LinearAlgebra.BandIndex(band,i)]
771+
Val(x)
772+
end
773+
v = @inferred f(A, 1, Val(0))
774+
@test v == Val(1)
775+
v = @inferred f(A, 1, Val(1))
776+
@test v == Val(0)
777+
end
778+
@testset "non-square matrix" begin
779+
r = reshape(1:6, 2, 3)
780+
for d in (r, r*im)
781+
@test d'[LinearAlgebra.BandIndex(1,1)] == adjoint(d[2,1])
782+
@test d'[LinearAlgebra.BandIndex(-1,2)] == adjoint(d[2,3])
783+
@test transpose(d)[LinearAlgebra.BandIndex(1,1)] == transpose(d[2,1])
784+
@test transpose(d)[LinearAlgebra.BandIndex(-1,2)] == transpose(d[2,3])
785+
end
786+
end
787+
@testset "block matrix" begin
788+
B = reshape([[1 2; 3 4]*i for i in 1:4], 2, 2)
789+
@test B'[LinearAlgebra.BandIndex(1,1)] == adjoint(B[2,1])
790+
@test transpose(B)[LinearAlgebra.BandIndex(1,1)] == transpose(B[2,1])
791+
end
792+
end
793+
752794
end # module TestAdjointTranspose

0 commit comments

Comments
 (0)