Skip to content

Commit e44a59b

Browse files
author
Andy Ferris
committed
Hopefully fixed the multiplication ambiguities
1 parent c287eff commit e44a59b

File tree

2 files changed

+122
-11
lines changed

2 files changed

+122
-11
lines changed

src/matrix_multiply.jl

Lines changed: 106 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -91,9 +91,46 @@ end
9191
end
9292

9393

94-
@generated function *(A::StaticMatrix, b::StaticVector)
95-
TA = eltype(A)
96-
Tb = eltype(b)
94+
@generated function *{TA,Tb}(A::StaticMatrix{TA}, b::StaticVector{Tb})
95+
sA = size(A)
96+
sb = size(b)
97+
98+
s = (sA[1],)
99+
T = promote_op(matprod, TA, Tb)
100+
#println(T)
101+
102+
if sb[1] != sA[2]
103+
error("Dimension mismatch")
104+
end
105+
106+
if s == sb
107+
if T == Tb
108+
newtype = b
109+
else
110+
newtype = similar_type(b, T)
111+
end
112+
else
113+
if T == Tb
114+
newtype = similar_type(b, s)
115+
else
116+
newtype = similar_type(b, T, s)
117+
end
118+
end
119+
120+
if sA[2] != 0
121+
exprs = [reduce((ex1,ex2) -> :(+($ex1,$ex2)), [:(A[$(sub2ind(sA, k, j))]*b[$j]) for j = 1:sA[2]]) for k = 1:sA[1]]
122+
else
123+
exprs = [zero(T) for k = 1:sA[1]]
124+
end
125+
126+
return quote
127+
$(Expr(:meta,:inline))
128+
@inbounds return $(Expr(:call, newtype, Expr(:tuple, exprs...)))
129+
end
130+
end
131+
132+
# For an ambiguity relating to the below two functions
133+
@generated function *{TA<:Base.LinAlg.BlasFloat,Tb}(A::StaticMatrix{TA}, b::StaticVector{Tb})
97134
sA = size(A)
98135
sb = size(b)
99136

@@ -132,9 +169,7 @@ end
132169
end
133170

134171
# This happens to be size-inferrable from A
135-
@generated function *(A::StaticMatrix, b::AbstractVector)
136-
TA = eltype(A)
137-
Tb = eltype(b)
172+
@generated function *{TA,Tb}(A::StaticMatrix{TA}, b::AbstractVector{Tb})
138173
sA = size(A)
139174
#sb = size(b)
140175

@@ -457,11 +492,52 @@ end
457492
7
458493

459494
# TODO aliasing problems if c === b?
460-
@generated function A_mul_B!(c::StaticVector, A::StaticMatrix, b::StaticVector)
495+
@generated function A_mul_B!{T1,T2,T3}(c::StaticVector{T1}, A::StaticMatrix{T2}, b::StaticVector{T3})
496+
sA = size(A)
497+
sb = size(b)
498+
s = size(c)
499+
500+
if sb[1] != sA[2] || s[1] != sA[1]
501+
error("Dimension mismatch")
502+
end
503+
504+
if sA[2] != 0
505+
exprs = [:(c[$k] = $(reduce((ex1,ex2) -> :(+($ex1,$ex2)), [:(A[$(sub2ind(sA, k, j))]*b[$j]) for j = 1:sA[2]]))) for k = 1:sA[1]]
506+
else
507+
exprs = [:(c[$k] = $(zero(T1))) for k = 1:sA[1]]
508+
end
509+
510+
return quote
511+
$(Expr(:meta,:inline))
512+
@inbounds $(Expr(:block, exprs...))
513+
end
514+
end
515+
516+
# These two for ambiguity with a BLAS calling function
517+
@generated function A_mul_B!{T<:Union{Float32, Float64}}(c::StaticVector{T}, A::StaticMatrix{T}, b::StaticVector{T})
518+
sA = size(A)
519+
sb = size(b)
520+
s = size(c)
521+
522+
if sb[1] != sA[2] || s[1] != sA[1]
523+
error("Dimension mismatch")
524+
end
525+
526+
if sA[2] != 0
527+
exprs = [:(c[$k] = $(reduce((ex1,ex2) -> :(+($ex1,$ex2)), [:(A[$(sub2ind(sA, k, j))]*b[$j]) for j = 1:sA[2]]))) for k = 1:sA[1]]
528+
else
529+
exprs = [:(c[$k] = $(zero(T))) for k = 1:sA[1]]
530+
end
531+
532+
return quote
533+
$(Expr(:meta,:inline))
534+
@inbounds $(Expr(:block, exprs...))
535+
end
536+
end
537+
@generated function A_mul_B!{T<:Union{Complex{Float32}, Complex{Float64}}}(c::StaticVector{T}, A::StaticMatrix{T}, b::StaticVector{T})
461538
sA = size(A)
462539
sb = size(b)
463540
s = size(c)
464-
T = eltype(c)
465541

466542
if sb[1] != sA[2] || s[1] != sA[1]
467543
error("Dimension mismatch")
@@ -480,7 +556,7 @@ end
480556
end
481557

482558
# The unrolled code is inferrable from the size of A
483-
@generated function A_mul_B!(c::AbstractVector, A::StaticMatrix, b::AbstractVector)
559+
@generated function A_mul_B!{T1,T2,T3}(c::AbstractVector{T1}, A::StaticMatrix{T2}, b::AbstractVector{T3})
484560
sA = size(A)
485561
T = eltype(c)
486562

@@ -500,11 +576,30 @@ end
500576
end
501577

502578
# Ambiguity with a BLAS specialized function
503-
@generated function Base.A_mul_B!{T<:Base.LinAlg.BlasFloat}(c::StridedVector{T}, A::StaticMatrix{T}, b::StridedVector{T})
579+
# Also possible bug makes this harder to resolve (see https://github.com/JuliaLang/julia/issues/19124)
580+
# (problem being that I can't use T<:BlasFloat)
581+
@generated function A_mul_B!{T<:Union{Float64,Float32}}(c::StridedVector{T}, A::StaticMatrix{T}, b::StridedVector{T})
582+
sA = size(A)
583+
584+
if sA[2] != 0
585+
exprs = [:(c[$k] = $(reduce((ex1,ex2) -> :(+($ex1,$ex2)), [:(A[$(sub2ind(sA, k, j))]*b[$j]) for j = 1:sA[2]]))) for k = 1:sA[1]]
586+
else
587+
exprs = [:(c[$k] = $(zero(T))) for k = 1:sA[1]]
588+
end
589+
590+
return quote
591+
$(Expr(:meta,:inline))
592+
if length(b) != $(sA[2]) || length(c) != $(sA[1])
593+
error("Dimension mismatch")
594+
end
595+
@inbounds $(Expr(:block, exprs...))
596+
end
597+
end
598+
@generated function A_mul_B!{T<:Union{Complex{Float64},Complex{Float32}}}(c::StridedVector{T}, A::StaticMatrix{T}, b::StridedVector{T})
504599
sA = size(A)
505600

506601
if sA[2] != 0
507-
exprs = [:(c[$k] = $(reduce((ex1,ex2) -> :(+($ex1,$ex2)), [:(2*A[$(sub2ind(sA, k, j))]*b[$j]) for j = 1:sA[2]]))) for k = 1:sA[1]]
602+
exprs = [:(c[$k] = $(reduce((ex1,ex2) -> :(+($ex1,$ex2)), [:(A[$(sub2ind(sA, k, j))]*b[$j]) for j = 1:sA[2]]))) for k = 1:sA[1]]
508603
else
509604
exprs = [:(c[$k] = $(zero(T))) for k = 1:sA[1]]
510605
end

test/matrix_multiply.jl

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,10 @@
2323
m4 = @MArray [1 2; 3 4]
2424
v6 = @MArray [1, 2]
2525
@test (m4*v6)::MArray == @MArray [5, 11]
26+
27+
m5 = @SMatrix [1.0 2.0; 3.0 4.0]
28+
v7 = [1.0, 2.0]
29+
@test (m5*v7)::SVector @SVector [5.0, 11.0]
2630
end
2731

2832
@testset "Vector-matrix" begin
@@ -194,5 +198,17 @@
194198
a = MMatrix{16,16,Int}()
195199
A_mul_B!(a, m, n)
196200
@test a a_array
201+
202+
# Float64
203+
vf = @SVector [2.0, 4.0]
204+
vf2 = [2.0, 4.0]
205+
mf = @SMatrix [1.0 2.0; 3.0 4.0]
206+
207+
outvecf = MVector{2,Float64}()
208+
A_mul_B!(outvecf, mf, vf)
209+
@test outvecf @MVector [10.0, 22.0]
210+
outvec2f = Vector{Float64}(2)
211+
A_mul_B!(outvec2f, mf, vf2)
212+
@test outvec2f [10.0, 22.0]
197213
end
198214
end

0 commit comments

Comments
 (0)