Skip to content

Commit c287eff

Browse files
author
Andy Ferris
committed
Attempt to resolve ambiguities with BLAS funcs
1 parent eaf6a55 commit c287eff

File tree

1 file changed

+66
-18
lines changed

1 file changed

+66
-18
lines changed

src/matrix_multiply.jl

Lines changed: 66 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -162,6 +162,34 @@ end
162162
end
163163
end
164164

165+
# Ambiguity with BLAS function
166+
# Matrix-vector product of a statically sized matrix `A` (element type `TA`, a BLAS
# float) with a strided vector `b` (element type `Tb`). This method exists to resolve a
# dispatch ambiguity with the BLAS-specialised `*` in Base.
#
# The generator emits one summed expression per row of `A`, so the product is fully
# unrolled from the static size of `A`. The result is a static vector of length
# `size(A, 1)` whose element type is `promote_op(matprod, TA, Tb)`.
#
# Raises `error("Dimension mismatch")` at run time if `length(b) != size(A, 2)`.
@generated function *{TA <: Base.LinAlg.BlasFloat, Tb}(A::StaticMatrix{TA}, b::StridedVector{Tb})
    sA = size(A)

    s = (sA[1],)
    T = promote_op(matprod, TA, Tb)

    # `similar_type(A, s)` keeps the element type of `A`, so it may only be used when
    # the promoted result type equals `TA`. Comparing against `Tb` would build a
    # result with eltype `TA` even when promotion requires `Tb`.
    if T == TA
        newtype = similar_type(A, s)
    else
        newtype = similar_type(A, T, s)
    end

    if sA[2] != 0
        # One expression per output row k: A[k,1]*b[1] + A[k,2]*b[2] + ...
        # `A` is indexed linearly via sub2ind on its static size.
        exprs = [reduce((ex1,ex2) -> :(+($ex1,$ex2)), [:(A[$(sub2ind(sA, k, j))]*b[$j]) for j = 1:sA[2]]) for k = 1:sA[1]]
    else
        # No columns: every output entry is an empty sum.
        exprs = [zero(T) for k = 1:sA[1]]
    end

    return quote
        $(Expr(:meta,:inline))
        # The length of `b` is only known at run time, so check it before the
        # @inbounds accesses below.
        if length(b) != $(sA[2])
            error("Dimension mismatch")
        end
        @inbounds return $(Expr(:call, newtype, Expr(:tuple, exprs...)))
    end
end
192+
165193
@generated function *(a::StaticVector, B::StaticMatrix)
166194
Ta = eltype(a)
167195
TB = eltype(B)
@@ -426,25 +454,9 @@ end
426454
@inbounds return $(Expr(:call, newtype, Expr(:tuple, exprs...)))
427455
end
428456
end
457+
429458

430-
431-
#function A_mul_B_blas(a, b, c, A, B)
432-
#q
433-
#end
434-
435-
# The idea here is to get pointers to stack variables and call BLAS.
436-
# This saves an awful lot of time compared to copying SArrays to Ref{SArray{...}}
437-
# and using BLAS should be fastest for (very) large SArrays
438-
439-
# Here is an LLVM function that gets the pointer to its input, %x
440-
# After this we would make the ccall above.
441-
#
442-
# define i8* @f(i32 %x) #0 {
443-
# %1 = alloca i32, align 4
444-
# store i32 %x, i32* %1, align 4
445-
# ret i32* %1
446-
# }
447-
459+
# TODO aliasing problems if c === b?
448460
@generated function A_mul_B!(c::StaticVector, A::StaticMatrix, b::StaticVector)
449461
sA = size(A)
450462
sb = size(b)
@@ -487,6 +499,25 @@ end
487499
end
488500
end
489501

502+
# Ambiguity with a BLAS specialized function
503+
# In-place matrix-vector product `c = A*b` for a statically sized matrix `A` and
# strided vectors `b` and `c`, all sharing the same BLAS float element type `T`. This
# method exists to resolve a dispatch ambiguity with the BLAS-specialised `A_mul_B!`
# in Base.
#
# The generator emits one assignment `c[k] = A[k,1]*b[1] + A[k,2]*b[2] + ...` per row
# of `A`, unrolled from the static size of `A`.
#
# Raises `error("Dimension mismatch")` at run time if `length(b) != size(A, 2)` or
# `length(c) != size(A, 1)`.
@generated function Base.A_mul_B!{T<:Base.LinAlg.BlasFloat}(c::StridedVector{T}, A::StaticMatrix{T}, b::StridedVector{T})
    sA = size(A)

    if sA[2] != 0
        # Each term is exactly A[k,j]*b[j]; `A` is indexed linearly via sub2ind on
        # its static size.
        exprs = [:(c[$k] = $(reduce((ex1,ex2) -> :(+($ex1,$ex2)), [:(A[$(sub2ind(sA, k, j))]*b[$j]) for j = 1:sA[2]]))) for k = 1:sA[1]]
    else
        # No columns: every output entry is an empty sum.
        exprs = [:(c[$k] = $(zero(T))) for k = 1:sA[1]]
    end

    return quote
        $(Expr(:meta,:inline))
        # The lengths of `b` and `c` are only known at run time, so check them
        # before the @inbounds accesses below.
        if length(b) != $(sA[2]) || length(c) != $(sA[1])
            error("Dimension mismatch")
        end
        @inbounds $(Expr(:block, exprs...))
    end
end
520+
490521

491522
@generated function A_mul_B!(C::StaticMatrix, A::StaticMatrix, B::StaticMatrix)
492523
if isbits(C)
@@ -647,3 +678,20 @@ end
647678
exprs...
648679
)
649680
end
681+
682+
#function A_mul_B_blas(a, b, c, A, B)
683+
#q
684+
#end
685+
686+
# The idea here is to get pointers to stack variables and call BLAS.
687+
# This saves an awful lot of time compared to copying SArrays to Ref{SArray{...}}
688+
# and using BLAS should be fastest for (very) large SArrays
689+
690+
# Here is an LLVM function that gets the pointer to its input, %x
691+
# After this we would make the ccall above.
692+
#
693+
# define i8* @f(i32 %x) #0 {
694+
# %1 = alloca i32, align 4
695+
# store i32 %x, i32* %1, align 4
696+
# ret i32* %1
697+
# }

0 commit comments

Comments
 (0)