|
1 | 1 | # integration with LinearAlgebra stdlib
|
2 | 2 |
|
| 3 | +using LinearAlgebra: MulAddMul |
| 4 | + |
# Make `wrap` available in this file. When LinearAlgebra does not define it, provide a
# local fallback that maps a character tag to the corresponding wrapper of `A`:
#   'N' -> `A` itself        'T' -> `transpose(A)`      'C' -> `adjoint(A)`
#   'H' -> `Hermitian(A, :U)`                           'h' -> `Hermitian(A, :L)`
#   'S' -> `Symmetric(A, :U)`   any other tag (expected: 's') -> `Symmetric(A, :L)`
if !isdefined(LinearAlgebra, :wrap)
    function wrap(A::AbstractVecOrMat, tA::AbstractChar)
        tA == 'N' && return A
        tA == 'T' && return transpose(A)
        tA == 'C' && return adjoint(A)
        tA == 'H' && return Hermitian(A, :U)
        tA == 'h' && return Hermitian(A, :L)
        tA == 'S' && return Symmetric(A, :U)
        return Symmetric(A, :L) # tA == 's'
    end
else # i.e., VERSION >= v"1.10.0-DEV.1365"
    using LinearAlgebra: wrap
end
| 26 | + |
3 | 27 | ## transpose and adjoint
|
4 | 28 |
|
5 | 29 | function LinearAlgebra.transpose!(B::AbstractGPUVector, A::AbstractGPUMatrix)
|
@@ -319,28 +343,46 @@ function generic_matmatmul!(C::AbstractArray{R}, A::AbstractArray{T}, B::Abstrac
|
319 | 343 | C
|
320 | 344 | end
|
321 | 345 |
|
322 |
| -LinearAlgebra.mul!(C::AbstractGPUVecOrMat, A::AbstractGPUVecOrMat, B::AbstractGPUVecOrMat, a::Number, b::Number) = generic_matmatmul!(C, A, B, a, b) |
323 |
| -LinearAlgebra.mul!(C::AbstractGPUVecOrMat, A::AbstractGPUVecOrMat, B::LinearAlgebra.Adjoint{<:Any, <:AbstractGPUVecOrMat}, a::Number, b::Number) = generic_matmatmul!(C, A, B, a, b) |
324 |
| -LinearAlgebra.mul!(C::AbstractGPUVecOrMat, A::AbstractGPUVecOrMat, B::LinearAlgebra.Transpose{<:Any, <:AbstractGPUVecOrMat}, a::Number, b::Number) = generic_matmatmul!(C, A, B, a, b) |
325 |
| -LinearAlgebra.mul!(C::AbstractGPUVecOrMat, A::LinearAlgebra.Adjoint{<:Any, <:AbstractGPUVecOrMat}, B::AbstractGPUVecOrMat, a::Number, b::Number) = generic_matmatmul!(C, A, B, a, b) |
326 |
| -LinearAlgebra.mul!(C::AbstractGPUVecOrMat, A::LinearAlgebra.Transpose{<:Any, <:AbstractGPUVecOrMat}, B::AbstractGPUVecOrMat, a::Number, b::Number) = generic_matmatmul!(C, A, B, a, b) |
327 |
| -LinearAlgebra.mul!(C::AbstractGPUVecOrMat, A::LinearAlgebra.Transpose{<:Any, <:AbstractGPUVecOrMat}, B::LinearAlgebra.Adjoint{<:Any, <:AbstractGPUVecOrMat}, a::Number, b::Number) = generic_matmatmul!(C, A, B, a, b) |
328 |
| -LinearAlgebra.mul!(C::AbstractGPUVecOrMat, A::LinearAlgebra.Adjoint{<:Any, <:AbstractGPUVecOrMat}, B::LinearAlgebra.Transpose{<:Any, <:AbstractGPUVecOrMat}, a::Number, b::Number) = generic_matmatmul!(C, A, B, a, b) |
329 |
| -LinearAlgebra.mul!(C::AbstractGPUVecOrMat, A::LinearAlgebra.Adjoint{<:Any, <:AbstractGPUVecOrMat}, B::LinearAlgebra.Adjoint{<:Any, <:AbstractGPUVecOrMat}, a::Number, b::Number) = generic_matmatmul!(C, A, B, a, b) |
330 |
| -LinearAlgebra.mul!(C::AbstractGPUVecOrMat, A::LinearAlgebra.Transpose{<:Any, <:AbstractGPUVecOrMat}, B::LinearAlgebra.Transpose{<:Any, <:AbstractGPUVecOrMat}, a::Number, b::Number) = generic_matmatmul!(C, A, B, a, b) |
331 |
| - |
332 |
| -# specificity hacks |
333 |
| -LinearAlgebra.mul!(C::AbstractGPUVecOrMat, A::AbstractGPUVecOrMat, B::AbstractGPUVecOrMat, a::Real, b::Real) = generic_matmatmul!(C, A, B, a, b) |
334 |
| -LinearAlgebra.mul!(C::AbstractGPUVecOrMat, A::AbstractGPUVecOrMat, B::LinearAlgebra.Adjoint{<:Any, <:AbstractGPUVecOrMat}, a::Real, b::Real) = generic_matmatmul!(C, A, B, a, b) |
335 |
| -LinearAlgebra.mul!(C::AbstractGPUVecOrMat, A::AbstractGPUVecOrMat, B::LinearAlgebra.Transpose{<:Any, <:AbstractGPUVecOrMat}, a::Real, b::Real) = generic_matmatmul!(C, A, B, a, b) |
336 |
| -LinearAlgebra.mul!(C::AbstractGPUVecOrMat, A::LinearAlgebra.Adjoint{<:Any, <:AbstractGPUVecOrMat}, B::AbstractGPUVecOrMat, a::Real, b::Real) = generic_matmatmul!(C, A, B, a, b) |
337 |
| -LinearAlgebra.mul!(C::AbstractGPUVecOrMat, A::LinearAlgebra.Transpose{<:Any, <:AbstractGPUVecOrMat}, B::AbstractGPUVecOrMat, a::Real, b::Real) = generic_matmatmul!(C, A, B, a, b) |
338 |
| -LinearAlgebra.mul!(C::AbstractGPUVecOrMat, A::LinearAlgebra.Transpose{<:Any, <:AbstractGPUVecOrMat}, B::LinearAlgebra.Adjoint{<:Any, <:AbstractGPUVecOrMat}, a::Real, b::Real) = generic_matmatmul!(C, A, B, a, b) |
339 |
| -LinearAlgebra.mul!(C::AbstractGPUVecOrMat, A::LinearAlgebra.Adjoint{<:Any, <:AbstractGPUVecOrMat}, B::LinearAlgebra.Transpose{<:Any, <:AbstractGPUVecOrMat}, a::Real, b::Real) = generic_matmatmul!(C, A, B, a, b) |
340 |
| -LinearAlgebra.mul!(C::AbstractGPUVecOrMat, A::LinearAlgebra.Adjoint{<:Any, <:AbstractGPUVecOrMat}, B::LinearAlgebra.Adjoint{<:Any, <:AbstractGPUVecOrMat}, a::Real, b::Real) = generic_matmatmul!(C, A, B, a, b) |
341 |
| -LinearAlgebra.mul!(C::AbstractGPUVecOrMat, A::LinearAlgebra.Transpose{<:Any, <:AbstractGPUVecOrMat}, B::LinearAlgebra.Transpose{<:Any, <:AbstractGPUVecOrMat}, a::Real, b::Real) = generic_matmatmul!(C, A, B, a, b) |
# Matrix-vector entry point for GPU arrays: wraps `A` according to the tag `tA` and
# forwards to this file's `generic_matmatmul!`, passing the `alpha` and `beta` fields of
# `_add` as the scaling factors.
function LinearAlgebra.generic_matvecmul!(
    C::AbstractGPUVector,
    tA::AbstractChar,
    A::AbstractGPUMatrix,
    B::AbstractGPUVector,
    _add::MulAddMul = MulAddMul()
)
    wrappedA = wrap(A, tA)
    return generic_matmatmul!(C, wrappedA, B, _add.alpha, _add.beta)
end
342 | 349 |
|
# Tag-based matrix-matrix entry point for GPU arrays: wraps `A` and `B` according to
# `tA` and `tB`, then forwards to this file's `generic_matmatmul!`, passing the `alpha`
# and `beta` fields of `_add` as the scaling factors.
function LinearAlgebra.generic_matmatmul!(
    C::AbstractGPUVecOrMat,
    tA,
    tB,
    A::AbstractGPUVecOrMat,
    B::AbstractGPUVecOrMat,
    _add::MulAddMul = MulAddMul()
)
    wrappedA = wrap(A, tA)
    wrappedB = wrap(B, tB)
    return generic_matmatmul!(C, wrappedA, wrappedB, _add.alpha, _add.beta)
end
343 | 353 |
|
| 354 | +if VERSION < v"1.10.0-DEV.1365" |
| 355 | +# catch other functions that are called by LinearAlgebra's mul! |
# `gemv!` for GPU arrays: wraps `A` according to `tA` and forwards to this file's
# `generic_matmatmul!` with the scalars `a` and `b`.
function LinearAlgebra.gemv!(
    C::AbstractGPUVector,
    tA::AbstractChar,
    A::AbstractGPUMatrix,
    B::AbstractGPUVector,
    a::Number,
    b::Number
)
    wrappedA = wrap(A, tA)
    return generic_matmatmul!(C, wrappedA, B, a, b)
end
| 359 | +# disambiguation |
# Same forwarding as the untyped `gemv!` method, restricted to arrays that share a
# `BlasFloat` element type `T`; exists for method disambiguation.
function LinearAlgebra.gemv!(
    C::AbstractGPUVector{T},
    tA::AbstractChar,
    A::AbstractGPUMatrix{T},
    B::AbstractGPUVector{T},
    a::Number,
    b::Number
) where {T<:LinearAlgebra.BlasFloat}
    wrappedA = wrap(A, tA)
    return generic_matmatmul!(C, wrappedA, B, a, b)
end
| 363 | + |
# `gemm_wrapper!` for GPU arrays: passes all arguments through, unchanged, to the
# tag-based `LinearAlgebra.generic_matmatmul!`.
function LinearAlgebra.gemm_wrapper!(
    C::AbstractGPUVecOrMat,
    tA::AbstractChar,
    tB::AbstractChar,
    A::AbstractGPUVecOrMat,
    B::AbstractGPUVecOrMat,
    _add::MulAddMul
)
    return LinearAlgebra.generic_matmatmul!(C, tA, tB, A, B, _add)
end
| 366 | +# disambiguation |
# Same pass-through as the untyped `gemm_wrapper!` method, restricted to arrays that
# share a `BlasFloat` element type `T`; exists for method disambiguation.
function LinearAlgebra.gemm_wrapper!(
    C::AbstractGPUVecOrMat{T},
    tA::AbstractChar,
    tB::AbstractChar,
    A::AbstractGPUVecOrMat{T},
    B::AbstractGPUVecOrMat{T},
    _add::MulAddMul
) where {T<:LinearAlgebra.BlasFloat}
    return LinearAlgebra.generic_matmatmul!(C, tA, tB, A, B, _add)
end
| 369 | + |
# `syrk_wrapper!` for GPU arrays: multiplies `A` with itself through the tag-based
# `LinearAlgebra.generic_matmatmul!`. With `tA == 'T'` the tags are ('T', 'N');
# with any other `tA` (expected: 'N') they are ('N', 'T').
function LinearAlgebra.syrk_wrapper!(
    C::AbstractGPUMatrix,
    tA::AbstractChar,
    A::AbstractGPUVecOrMat,
    _add::MulAddMul = MulAddMul()
)
    leftTag, rightTag = tA == 'T' ? ('T', 'N') : ('N', 'T')
    return LinearAlgebra.generic_matmatmul!(C, leftTag, rightTag, A, A, _add)
end
# `herk_wrapper!` for GPU arrays: multiplies `A` with itself through the tag-based
# `LinearAlgebra.generic_matmatmul!`. With `tA == 'C'` the tags are ('C', 'N');
# with any other `tA` (expected: 'N') they are ('N', 'C').
function LinearAlgebra.herk_wrapper!(
    C::AbstractGPUMatrix,
    tA::AbstractChar,
    A::AbstractGPUVecOrMat,
    _add::MulAddMul = MulAddMul()
)
    leftTag, rightTag = tA == 'C' ? ('C', 'N') : ('N', 'C')
    return LinearAlgebra.generic_matmatmul!(C, leftTag, rightTag, A, A, _add)
end
| 384 | +end # VERSION |
| 385 | + |
344 | 386 | function generic_rmul!(X::AbstractArray, s::Number)
|
345 | 387 | gpu_call(X, s; name="rmul!") do ctx, X, s
|
346 | 388 | i = @linearidx X
|
|
0 commit comments