Skip to content

Commit 89813ee

Browse files
authored
Improve type inference for adjoint/transpose of nested matrices (JuliaArrays#1181)
* Improve type inference for adjoint/transpose of nested matrices * Add comment
1 parent ebcbea5 commit 89813ee

File tree

2 files changed

+20
-2
lines changed

2 files changed

+20
-2
lines changed

src/linalg.jl

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,15 @@ end
6262
#--------------------------------------------------
6363
# Matrix algebra
6464

65+
# _adjointtype returns the eltype of the container when computing the adjoint/transpose
66+
# of a static array. Using this method instead of calling `Base.promote_op` directly
67+
# helps with type-inference, particularly for nested static arrays,
68+
# where the adjoint is applied recursively.
69+
@inline _adjointtype(f, ::Type{T}) where {T} = Base.promote_op(f, T)
70+
for S in (:SMatrix, :MMatrix)
71+
@eval @inline _adjointtype(f, ::Type{$S{M,N,T,L}}) where {M,N,T,L} = $S{N,M,_adjointtype(f, T),L}
72+
end
73+
6574
# Transpose, etc
6675
@inline transpose(m::StaticMatrix) = _transpose(Size(m), m)
6776
# note: transpose of StaticVector is a Transpose, handled by Base
@@ -74,7 +83,7 @@ end
7483
return quote
7584
$(Expr(:meta, :inline))
7685
elements = tuple($(exprs...))
77-
@inbounds return similar_type($m, Base.promote_op(transpose, T), Size($(n2,n1)))(elements)
86+
@inbounds return similar_type($m, _adjointtype(transpose, T), Size($(n2,n1)))(elements)
7887
end
7988
end
8089

@@ -88,7 +97,7 @@ end
8897
return quote
8998
$(Expr(:meta, :inline))
9099
elements = tuple($(exprs...))
91-
@inbounds return similar_type($m, Base.promote_op(adjoint, T), Size($(n2,n1)))(elements)
100+
@inbounds return similar_type($m, _adjointtype(adjoint, T), Size($(n2,n1)))(elements)
92101
end
93102
end
94103

test/linalg.jl

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -237,6 +237,15 @@ end
237237
end
238238
@test adjoint(SMatrix{0,0,Vector{Int}}()) isa SMatrix{0,0,Adjoint{Int,Vector{Int}}}
239239
@test transpose(SMatrix{0,0,Vector{Int}}()) isa SMatrix{0,0,Transpose{Int,Vector{Int}}}
240+
241+
@testset "inference for nested matrices" begin
242+
A = reshape([reshape([complex(i,2i)*j for i in 1:2], 1, 2) for j in 1:6], 3, 2)
243+
for TA in (SMatrix, MMatrix), TB in (SMatrix, MMatrix)
244+
S = TA{3,2}(TB{1,2}.(A)) # static matrix of static matrices
245+
@test @inferred(transpose(S)) == transpose(A)
246+
@test @inferred(adjoint(S)) == adjoint(A)
247+
end
248+
end
240249
end
241250

242251
@testset "normalization" begin

0 commit comments

Comments
 (0)