Skip to content

Commit 0906515

Browse files
authored
Merge pull request #366 from martinholters/mh/mul_0.7
Fixes for mul! on Julia 0.7
2 parents 99181cf + 4ce18ee commit 0906515

File tree

1 file changed

+29
-10
lines changed

1 file changed

+29
-10
lines changed

src/matrix_multiply.jl

Lines changed: 29 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ const StaticVecOrMat{T} = Union{StaticVector{<:Any, T}, StaticMatrix{<:Any, <:An
2828
@inline At_mul_Bt!(dest::StaticVecOrMat, A::StaticVecOrMat, B::StaticVecOrMat) = mul!(dest, transpose(A), transpose(B))
2929
@inline At_mul_B!(dest::StaticVecOrMat, A::StaticVecOrMat, B::StaticVecOrMat) = mul!(dest, transpose(A), B)
3030
else
31-
import LinearAlgebra: BlasFloat, matprod
31+
import LinearAlgebra: BlasFloat, matprod, mul!
3232
end
3333

3434

@@ -342,6 +342,32 @@ end
342342
gemm = :cgemm_
343343
end
344344

345+
if VERSION < v"0.7-"
346+
blascall = quote
347+
ccall((Base.BLAS.@blasfunc($gemm), Base.BLAS.libblas), Nothing,
348+
(Ptr{UInt8}, Ptr{UInt8}, Ptr{Base.BLAS.BlasInt}, Ptr{Base.BLAS.BlasInt},
349+
Ptr{Base.BLAS.BlasInt}, Ptr{$T}, Ptr{$T}, Ptr{Base.BLAS.BlasInt},
350+
Ptr{$T}, Ptr{Base.BLAS.BlasInt}, Ptr{$T}, Ptr{$T},
351+
Ptr{Base.BLAS.BlasInt}),
352+
&transA, &transB, &m, &n,
353+
&ka, &alpha, a, &strideA,
354+
b, &strideB, &beta, c,
355+
&strideC)
356+
end
357+
else
358+
blascall = quote
359+
ccall((LinearAlgebra.BLAS.@blasfunc($gemm), LinearAlgebra.BLAS.libblas), Nothing,
360+
(Ref{UInt8}, Ref{UInt8}, Ref{LinearAlgebra.BLAS.BlasInt}, Ref{LinearAlgebra.BLAS.BlasInt},
361+
Ref{LinearAlgebra.BLAS.BlasInt}, Ref{$T}, Ptr{$T}, Ref{LinearAlgebra.BLAS.BlasInt},
362+
Ptr{$T}, Ref{LinearAlgebra.BLAS.BlasInt}, Ref{$T}, Ptr{$T},
363+
Ref{LinearAlgebra.BLAS.BlasInt}),
364+
transA, transB, m, n,
365+
ka, alpha, a, strideA,
366+
b, strideB, beta, c,
367+
strideC)
368+
end
369+
end
370+
345371
return quote
346372
alpha = one(T)
347373
beta = zero(T)
@@ -355,15 +381,8 @@ end
355381
strideB = $(sb[1])
356382
strideC = $(s[1])
357383

358-
ccall((Base.BLAS.@blasfunc($gemm), Base.BLAS.libblas), Nothing,
359-
(Ptr{UInt8}, Ptr{UInt8}, Ptr{Base.BLAS.BlasInt}, Ptr{Base.BLAS.BlasInt},
360-
Ptr{Base.BLAS.BlasInt}, Ptr{$T}, Ptr{$T}, Ptr{Base.BLAS.BlasInt},
361-
Ptr{$T}, Ptr{Base.BLAS.BlasInt}, Ptr{$T}, Ptr{$T},
362-
Ptr{Base.BLAS.BlasInt}),
363-
&transA, &transB, &m, &n,
364-
&ka, &alpha, a, &strideA,
365-
b, &strideB, &beta, c,
366-
&strideC)
384+
$blascall
385+
367386
return c
368387
end
369388
else

0 commit comments

Comments
 (0)