Skip to content

Commit 443dab4

Browse files
authored
Reduce number of mul! methods (#472)
1 parent e733292 commit 443dab4

File tree

1 file changed

+62
-20
lines changed

1 file changed

+62
-20
lines changed

src/host/linalg.jl

Lines changed: 62 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,29 @@
11
# integration with LinearAlgebra stdlib
22

3+
using LinearAlgebra: MulAddMul
4+
5+
if isdefined(LinearAlgebra, :wrap) # i.e., VERSION >= v"1.10.0-DEV.1365"
6+
using LinearAlgebra: wrap
7+
else
8+
function wrap(A::AbstractVecOrMat, tA::AbstractChar)
9+
if tA == 'N'
10+
return A
11+
elseif tA == 'T'
12+
return transpose(A)
13+
elseif tA == 'C'
14+
return adjoint(A)
15+
elseif tA == 'H'
16+
return Hermitian(A, :U)
17+
elseif tA == 'h'
18+
return Hermitian(A, :L)
19+
elseif tA == 'S'
20+
return Symmetric(A, :U)
21+
else # tA == 's'
22+
return Symmetric(A, :L)
23+
end
24+
end
25+
end
26+
327
## transpose and adjoint
428

529
function LinearAlgebra.transpose!(B::AbstractGPUVector, A::AbstractGPUMatrix)
@@ -319,28 +343,46 @@ function generic_matmatmul!(C::AbstractArray{R}, A::AbstractArray{T}, B::Abstrac
319343
C
320344
end
321345

322-
LinearAlgebra.mul!(C::AbstractGPUVecOrMat, A::AbstractGPUVecOrMat, B::AbstractGPUVecOrMat, a::Number, b::Number) = generic_matmatmul!(C, A, B, a, b)
323-
LinearAlgebra.mul!(C::AbstractGPUVecOrMat, A::AbstractGPUVecOrMat, B::LinearAlgebra.Adjoint{<:Any, <:AbstractGPUVecOrMat}, a::Number, b::Number) = generic_matmatmul!(C, A, B, a, b)
324-
LinearAlgebra.mul!(C::AbstractGPUVecOrMat, A::AbstractGPUVecOrMat, B::LinearAlgebra.Transpose{<:Any, <:AbstractGPUVecOrMat}, a::Number, b::Number) = generic_matmatmul!(C, A, B, a, b)
325-
LinearAlgebra.mul!(C::AbstractGPUVecOrMat, A::LinearAlgebra.Adjoint{<:Any, <:AbstractGPUVecOrMat}, B::AbstractGPUVecOrMat, a::Number, b::Number) = generic_matmatmul!(C, A, B, a, b)
326-
LinearAlgebra.mul!(C::AbstractGPUVecOrMat, A::LinearAlgebra.Transpose{<:Any, <:AbstractGPUVecOrMat}, B::AbstractGPUVecOrMat, a::Number, b::Number) = generic_matmatmul!(C, A, B, a, b)
327-
LinearAlgebra.mul!(C::AbstractGPUVecOrMat, A::LinearAlgebra.Transpose{<:Any, <:AbstractGPUVecOrMat}, B::LinearAlgebra.Adjoint{<:Any, <:AbstractGPUVecOrMat}, a::Number, b::Number) = generic_matmatmul!(C, A, B, a, b)
328-
LinearAlgebra.mul!(C::AbstractGPUVecOrMat, A::LinearAlgebra.Adjoint{<:Any, <:AbstractGPUVecOrMat}, B::LinearAlgebra.Transpose{<:Any, <:AbstractGPUVecOrMat}, a::Number, b::Number) = generic_matmatmul!(C, A, B, a, b)
329-
LinearAlgebra.mul!(C::AbstractGPUVecOrMat, A::LinearAlgebra.Adjoint{<:Any, <:AbstractGPUVecOrMat}, B::LinearAlgebra.Adjoint{<:Any, <:AbstractGPUVecOrMat}, a::Number, b::Number) = generic_matmatmul!(C, A, B, a, b)
330-
LinearAlgebra.mul!(C::AbstractGPUVecOrMat, A::LinearAlgebra.Transpose{<:Any, <:AbstractGPUVecOrMat}, B::LinearAlgebra.Transpose{<:Any, <:AbstractGPUVecOrMat}, a::Number, b::Number) = generic_matmatmul!(C, A, B, a, b)
331-
332-
# specificity hacks
333-
LinearAlgebra.mul!(C::AbstractGPUVecOrMat, A::AbstractGPUVecOrMat, B::AbstractGPUVecOrMat, a::Real, b::Real) = generic_matmatmul!(C, A, B, a, b)
334-
LinearAlgebra.mul!(C::AbstractGPUVecOrMat, A::AbstractGPUVecOrMat, B::LinearAlgebra.Adjoint{<:Any, <:AbstractGPUVecOrMat}, a::Real, b::Real) = generic_matmatmul!(C, A, B, a, b)
335-
LinearAlgebra.mul!(C::AbstractGPUVecOrMat, A::AbstractGPUVecOrMat, B::LinearAlgebra.Transpose{<:Any, <:AbstractGPUVecOrMat}, a::Real, b::Real) = generic_matmatmul!(C, A, B, a, b)
336-
LinearAlgebra.mul!(C::AbstractGPUVecOrMat, A::LinearAlgebra.Adjoint{<:Any, <:AbstractGPUVecOrMat}, B::AbstractGPUVecOrMat, a::Real, b::Real) = generic_matmatmul!(C, A, B, a, b)
337-
LinearAlgebra.mul!(C::AbstractGPUVecOrMat, A::LinearAlgebra.Transpose{<:Any, <:AbstractGPUVecOrMat}, B::AbstractGPUVecOrMat, a::Real, b::Real) = generic_matmatmul!(C, A, B, a, b)
338-
LinearAlgebra.mul!(C::AbstractGPUVecOrMat, A::LinearAlgebra.Transpose{<:Any, <:AbstractGPUVecOrMat}, B::LinearAlgebra.Adjoint{<:Any, <:AbstractGPUVecOrMat}, a::Real, b::Real) = generic_matmatmul!(C, A, B, a, b)
339-
LinearAlgebra.mul!(C::AbstractGPUVecOrMat, A::LinearAlgebra.Adjoint{<:Any, <:AbstractGPUVecOrMat}, B::LinearAlgebra.Transpose{<:Any, <:AbstractGPUVecOrMat}, a::Real, b::Real) = generic_matmatmul!(C, A, B, a, b)
340-
LinearAlgebra.mul!(C::AbstractGPUVecOrMat, A::LinearAlgebra.Adjoint{<:Any, <:AbstractGPUVecOrMat}, B::LinearAlgebra.Adjoint{<:Any, <:AbstractGPUVecOrMat}, a::Real, b::Real) = generic_matmatmul!(C, A, B, a, b)
341-
LinearAlgebra.mul!(C::AbstractGPUVecOrMat, A::LinearAlgebra.Transpose{<:Any, <:AbstractGPUVecOrMat}, B::LinearAlgebra.Transpose{<:Any, <:AbstractGPUVecOrMat}, a::Real, b::Real) = generic_matmatmul!(C, A, B, a, b)
346+
function LinearAlgebra.generic_matvecmul!(C::AbstractGPUVector, tA::AbstractChar, A::AbstractGPUMatrix, B::AbstractGPUVector, _add::MulAddMul = MulAddMul())
347+
generic_matmatmul!(C, wrap(A, tA), B, _add.alpha, _add.beta)
348+
end
342349

350+
function LinearAlgebra.generic_matmatmul!(C::AbstractGPUVecOrMat, tA, tB, A::AbstractGPUVecOrMat, B::AbstractGPUVecOrMat, _add::MulAddMul=MulAddMul())
351+
generic_matmatmul!(C, wrap(A, tA), wrap(B, tB), _add.alpha, _add.beta)
352+
end
343353

354+
if VERSION < v"1.10.0-DEV.1365"
355+
# catch other functions that are called by LinearAlgebra's mul!
356+
function LinearAlgebra.gemv!(C::AbstractGPUVector, tA::AbstractChar, A::AbstractGPUMatrix, B::AbstractGPUVector, a::Number, b::Number)
357+
generic_matmatmul!(C, wrap(A, tA), B, a, b)
358+
end
359+
# disambiguation
360+
function LinearAlgebra.gemv!(C::AbstractGPUVector{T}, tA::AbstractChar, A::AbstractGPUMatrix{T}, B::AbstractGPUVector{T}, a::Number, b::Number) where {T<:LinearAlgebra.BlasFloat}
361+
generic_matmatmul!(C, wrap(A, tA), B, a, b)
362+
end
363+
364+
LinearAlgebra.gemm_wrapper!(C::AbstractGPUVecOrMat, tA::AbstractChar, tB::AbstractChar, A::AbstractGPUVecOrMat, B::AbstractGPUVecOrMat, _add::MulAddMul) =
365+
LinearAlgebra.generic_matmatmul!(C, tA, tB, A, B, _add)
366+
# disambiguation
367+
LinearAlgebra.gemm_wrapper!(C::AbstractGPUVecOrMat{T}, tA::AbstractChar, tB::AbstractChar, A::AbstractGPUVecOrMat{T}, B::AbstractGPUVecOrMat{T}, _add::MulAddMul) where {T<:LinearAlgebra.BlasFloat} =
368+
LinearAlgebra.generic_matmatmul!(C, tA, tB, A, B, _add)
369+
370+
function LinearAlgebra.syrk_wrapper!(C::AbstractGPUMatrix, tA::AbstractChar, A::AbstractGPUVecOrMat, _add::MulAddMul = MulAddMul())
371+
if tA == 'T'
372+
LinearAlgebra.generic_matmatmul!(C, 'T', 'N', A, A, _add)
373+
else # tA == 'N'
374+
LinearAlgebra.generic_matmatmul!(C, 'N', 'T', A, A, _add)
375+
end
376+
end
377+
function LinearAlgebra.herk_wrapper!(C::AbstractGPUMatrix, tA::AbstractChar, A::AbstractGPUVecOrMat, _add::MulAddMul = MulAddMul())
378+
if tA == 'C'
379+
LinearAlgebra.generic_matmatmul!(C, 'C', 'N', A, A, _add)
380+
else # tA == 'N'
381+
LinearAlgebra.generic_matmatmul!(C, 'N', 'C', A, A, _add)
382+
end
383+
end
384+
end # VERSION
385+
344386
function generic_rmul!(X::AbstractArray, s::Number)
345387
gpu_call(X, s; name="rmul!") do ctx, X, s
346388
i = @linearidx X

0 commit comments

Comments
 (0)