Skip to content

Commit 49d1043

Browse files
committed
partial unification of in-place and out-of-place matrix multiplication
1 parent e725e4d commit 49d1043

File tree

3 files changed

+34
-28
lines changed

3 files changed

+34
-28
lines changed

src/matrix_multiply.jl

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ element access.
4040
4141
The name of the matrix to test is indicated by `asym`.
4242
"""
43-
function gen_by_access(expr_gen, a::Type{<:StaticMatrix}, asym = :a)
43+
function gen_by_access(expr_gen, a::Type{<:StaticVecOrMat}, asym = :a)
4444
return expr_gen(:any)
4545
end
4646
function gen_by_access(expr_gen, a::Type{<:Symmetric{<:Any, <:StaticMatrix}}, asym = :a)
@@ -67,10 +67,10 @@ end
6767
function gen_by_access(expr_gen, a::Type{<:LowerTriangular{<:Any, <:StaticMatrix}}, asym = :a)
6868
return expr_gen(:lower_triangular)
6969
end
70-
function gen_by_access(expr_gen, a::Type{<:Transpose{<:Any, <:StaticMatrix}}, asym = :a)
70+
function gen_by_access(expr_gen, a::Type{<:Transpose{<:Any, <:StaticVecOrMat}}, asym = :a)
7171
return expr_gen(:transpose)
7272
end
73-
function gen_by_access(expr_gen, a::Type{<:Adjoint{<:Any, <:StaticMatrix}}, asym = :a)
73+
function gen_by_access(expr_gen, a::Type{<:Adjoint{<:Any, <:StaticVecOrMat}}, asym = :a)
7474
return expr_gen(:adjoint)
7575
end
7676
"""
@@ -206,9 +206,9 @@ function uplo_access(sa, asym, k, j, uplo)
206206
return :(zero(T))
207207
end
208208
elseif uplo == :transpose
209-
return :($asym[$(LinearIndices(sa)[j, k])])
209+
return :($asym[$(LinearIndices(reverse(sa))[j, k])])
210210
elseif uplo == :ajoint
211-
return :(adjoint($asym[$(LinearIndices(sa)[j, k])]))
211+
return :(adjoint($asym[$(LinearIndices(reverse(sa))[j, k])]))
212212
end
213213
end
214214

src/matrix_multiply_add.jl

Lines changed: 25 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -24,19 +24,19 @@ Should pair with `parent`.
2424
"""
2525
struct TSize{S,T}
2626
function TSize{S,T}() where {S,T}
27-
new{S::Tuple{Vararg{StaticDimension}},T::Bool}()
27+
new{S::Tuple{Vararg{StaticDimension}},T::Symbol}()
2828
end
2929
end
30-
TSize(A::Type{<:Transpose{<:Any,<:StaticArray}}) = TSize{size(A),true}()
31-
TSize(A::Type{<:Adjoint{<:Real,<:StaticArray}}) = TSize{size(A),true}() # can't handle complex adjoints yet
32-
TSize(A::Type{<:StaticArray}) = TSize{size(A),false}()
30+
TSize(A::Type{<:StaticArrayLike}) = TSize{size(A), gen_by_access(identity, A)}()
3331
TSize(A::StaticArrayLike) = TSize(typeof(A))
34-
TSize(S::Size{s}, T=false) where s = TSize{s,T}()
32+
TSize(S::Size{s}, T=:any) where s = TSize{s,T}()
3533
TSize(s::Number) = TSize(Size(s))
36-
istranpose(::TSize{<:Any,T}) where T = T
34+
# Return `true` only when the access tag carried by the `TSize` is `:transpose`.
function istranspose(::TSize{<:Any,T}) where T
    return T === :transpose
end
3735
size(::TSize{S}) where S = S
3836
Size(::TSize{S}) where S = Size{S}()
39-
Base.transpose(::TSize{S,T}) where {S,T} = TSize{reverse(S),!T}()
37+
# Return the access tag (second type parameter) stored in the `TSize`.
function access_type(::TSize{<:Any,T}) where T
    return T
end
38+
# Transposing a `:transpose`-tagged TSize yields the reversed size tagged `:any`.
# Only `S` is bound here; the access tag is matched literally, so no second type variable is declared.
Base.transpose(::TSize{S,:transpose}) where {S} = TSize{reverse(S),:any}()
39+
# Transposing an `:any`-tagged TSize yields the reversed size tagged `:transpose`.
# Only `S` is bound here; the access tag is matched literally, so no second type variable is declared.
Base.transpose(::TSize{S,:any}) where {S} = TSize{reverse(S),:transpose}()
4040

4141
# Get the parent of transposed arrays, or the array itself if it has no parent
4242
# Different from Base.parent because we only want to get rid of Transpose and Adjoint
@@ -97,13 +97,11 @@ end
9797
"Obtain an expression for the linear index of var[k,j], taking transposes into account"
9898
@inline _lind(A::Type{<:TSize}, k::Int, j::Int) = _lind(:a, A, k, j)
9999
function _lind(var::Symbol, A::Type{TSize{sa,tA}}, k::Int, j::Int) where {sa,tA}
100-
if tA
101-
return :($var[$(LinearIndices(reverse(sa))[j, k])])
102-
else
103-
return :($var[$(LinearIndices(sa)[k, j])])
104-
end
100+
return uplo_access(sa, var, k, j, tA)
105101
end
106102

103+
104+
107105
# Matrix-vector multiplication
108106
@generated function _mul!(Sc::TSize{sc}, c::StaticVecOrMatLike, Sa::TSize{sa}, Sb::TSize{sb},
109107
a::StaticMatrix, b::StaticVector, _add::MulAddMul,
@@ -133,14 +131,21 @@ end
133131
end
134132

135133
# Outer product
136-
@generated function _mul!(::TSize{sc}, c::StaticMatrix, ::TSize{sa,false}, ::TSize{sb,true},
134+
@generated function _mul!(::TSize{sc}, c::StaticMatrix, ::TSize{sa,:any}, tsb::Union{TSize{sb,:transpose},TSize{sb,:adjoint}},
137135
a::StaticVector, b::StaticVector, _add::MulAddMul) where {sa, sb, sc}
138136
if sc[1] != sa[1] || sc[2] != sb[2]
139137
throw(DimensionMismatch("Tried to multiply arrays of size $sa and $sb and assign to array of size $sc"))
140138
end
141139

140+
conjugate_b = isa(tsb, TSize{sb,:adjoint})
141+
142142
lhs = [:(c[$(LinearIndices(sc)[i,j])]) for i = 1:sa[1], j = 1:sb[2]]
143-
ab = [:(a[$i] * b[$j]) for i = 1:sa[1], j = 1:sb[2]]
143+
if conjugate_b
144+
ab = [:(a[$i] * adjoint(b[$j])) for i = 1:sa[1], j = 1:sb[2]]
145+
else
146+
ab = [:(a[$i] * b[$j]) for i = 1:sa[1], j = 1:sb[2]]
147+
end
148+
144149
exprs = _muladd_expr(lhs, ab, _add)
145150

146151
return quote
@@ -267,17 +272,18 @@ end
267272
@inline _get_raw_data(A::SizedArray) = A.data
268273
@inline _get_raw_data(A::StaticArray) = A
269274

270-
function mul_blas!(::TSize{<:Any,false}, c::StaticMatrix, ::TSize{<:Any,tA}, ::TSize{<:Any,tB},
271-
a::StaticMatrix, b::StaticMatrix, _add::MulAddMul) where {tA,tB}
272-
mat_char(tA) = tA ? 'T' : 'N'
275+
function mul_blas!(::TSize{<:Any,:any}, c::StaticMatrix,
276+
Sa::Union{TSize{<:Any,:any}, TSize{<:Any,:transpose}}, Sb::Union{TSize{<:Any,:any}, TSize{<:Any,:transpose}},
277+
a::StaticMatrix, b::StaticMatrix, _add::MulAddMul)
278+
mat_char(s) = istranspose(s) ? 'T' : 'N'
273279
T = eltype(a)
274280
A = _get_raw_data(a)
275281
B = _get_raw_data(b)
276282
C = _get_raw_data(c)
277-
BLAS.gemm!(mat_char(tA), mat_char(tB), T(alpha(_add)), A, B, T(beta(_add)), C)
283+
BLAS.gemm!(mat_char(Sa), mat_char(Sb), T(alpha(_add)), A, B, T(beta(_add)), C)
278284
end
279285

280286
# if C is transposed, transpose the entire expression
281-
@inline mul_blas!(Sc::TSize{<:Any,true}, c::StaticMatrix, Sa::TSize, Sb::TSize,
287+
@inline mul_blas!(Sc::TSize{<:Any,:transpose}, c::StaticMatrix, Sa::TSize, Sb::TSize,
282288
a::StaticMatrix, b::StaticMatrix, _add::MulAddMul) =
283289
mul_blas!(transpose(Sc), c, transpose(Sb), transpose(Sa), b, a, _add)

test/matrix_multiply_add.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -49,17 +49,17 @@ function test_multiply_add(N1,N2,ArrayType=MArray)
4949

5050
# TSize
5151
ta = StaticArrays.TSize(A)
52-
@test !StaticArrays.istranpose(ta)
52+
@test !StaticArrays.istranspose(ta)
5353
@test size(ta) == (N1,N2)
5454
@test Size(ta) == Size(N1,N2)
5555
ta = StaticArrays.TSize(At)
56-
@test StaticArrays.istranpose(ta)
56+
@test StaticArrays.istranspose(ta)
5757
@test size(ta) == (N2,N1)
5858
@test Size(ta) == Size(N2,N1)
5959
tb = StaticArrays.TSize(b')
60-
@test StaticArrays.istranpose(tb)
60+
@test StaticArrays.access_type(tb) === :adjoint
6161
ta = transpose(ta)
62-
@test !StaticArrays.istranpose(ta)
62+
@test !StaticArrays.istranspose(ta)
6363
@test size(ta) == (N1,N2)
6464
@test Size(ta) == Size(N1,N2)
6565

0 commit comments

Comments
 (0)