Skip to content

Commit f20d471

Browse files
authored
Use oneMKL with Float64 matmul. (#416)
1 parent fadcd8d commit f20d471

File tree

3 files changed

+5
-5
lines changed

3 files changed

+5
-5
lines changed

lib/mkl/linalg.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -220,7 +220,7 @@ function LinearAlgebra.generic_matmatmul!(C::oneStridedMatrix, tA, tB, A::oneStr
220220
end
221221

222222
if all(in(('N', 'T', 'C')), (tA, tB))
223-
if T <: onemklFloat && eltype(A) == eltype(B) == T
223+
if T <: Union{onemklFloat, onemklComplex, onemklHalf} && eltype(A) == eltype(B) == T
224224
return gemm!(tA, tB, alpha, A, B, beta, C)
225225
end
226226
end

lib/mkl/oneMKL.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,10 +17,10 @@ using LinearAlgebra.LAPACK: chkargsok, chklapackerror, chktrans, chkside, chkdia
1717

1818
using SparseArrays
1919

20-
# Exclude Float16 for now, since many oneMKL functions - copy, scal, do not take Float16
20+
# Exclude Float16 for now, since many oneMKL functions do not take Float16
2121
const onemklFloat = Union{Float64,Float32,ComplexF64,ComplexF32}
2222
const onemklComplex = Union{ComplexF32,ComplexF64}
23-
const onemklHalf = Union{Float16,ComplexF16}
23+
const onemklHalf = Float16
2424

2525
include("array.jl")
2626
include("utils.jl")

lib/mkl/wrappers_blas.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -537,7 +537,7 @@ for (fname, elty, cty, sty, supty) in ((:onemklSrot,:Float32,:Float32,:Float32,:
537537
end
538538
end
539539

540-
function axpy!(n::Integer,
540+
function axpy!(n::Integer,
541541
alpha::Number,
542542
x::oneStridedArray{ComplexF16},
543543
y::oneStridedArray{ComplexF16})
@@ -1260,7 +1260,7 @@ function dgmm(mode::Char, A::oneStridedMatrix{T}, X::oneStridedVector{T}) where
12601260
dgmm!( mode, A, X, similar(A, (m,n) ) )
12611261
end
12621262

1263-
for (fname, elty) in
1263+
for (fname, elty) in
12641264
((:onemklSgemmBatchStrided, Float32),
12651265
(:onemklDgemmBatchStrided, Float64),
12661266
(:onemklCgemmBatchStrided, ComplexF32),

0 commit comments

Comments
 (0)