
Commit 8eebdcb

Add dispatch path for FP16 batched mul
1 parent ee909e6 commit 8eebdcb

File tree

3 files changed (+26, -7 lines)

ext/NNlibAMDGPUExt/NNlibAMDGPUExt.jl

Lines changed: 2 additions & 6 deletions
@@ -35,12 +35,6 @@ Base.show(io::IO, x::AnyROCBatchedAdjOrTrans) = show(io, adapt(Array, x))
 
 Base.display(x::AnyROCBatchedAdjOrTrans) = display(adapt(Array, x))
 
-function NNlib._batched_gemm!(
-    ::Type{<: ROCArray}, transA::Char, transB::Char, α, A, B, β, C,
-)
-    AMDGPU.rocBLAS.gemm_batched!(transA, transB, α, A, B, β, C)
-end
-
 function nnlib_padding(dims)
     pd = NNlib.padding(dims)
     if !all(pd[1:2:end] .== pd[2:2:end])
@@ -52,6 +46,8 @@ function nnlib_padding(dims)
     pd[1:2:end]
 end
 
+include("batched_mul.jl")
+
 @static if AMDGPU.functional(:MIOpen)
     using AMDGPU.MIOpen
 
ext/NNlibAMDGPUExt/batched_mul.jl

Lines changed: 24 additions & 0 deletions
@@ -0,0 +1,24 @@
+function _blas_at(x)
+    Base.stride(x, 1) == 1 && return x, 'N'
+    Base.stride(x, 2) == 1 && return batched_transpose(x), 'T'
+    throw(ArgumentError("""
+        Unsupported array layout for batched mul.
+        - Size: $(size(x))
+        - Strides: $(strides(x))
+    """))
+end
+
+function NNlib._batched_mul!(
+    ::Type{AT}, C, A, B, α::Float16, β::Float16,
+) where AT <: ROCArray{Float16}
+    blasA, transA = _blas_at(A)
+    blasB, transB = _blas_at(B)
+    NNlib._batched_gemm!(AT, transA, transB, α, blasA, blasB, β, C)
+    C
+end
+
+function NNlib._batched_gemm!(
+    ::Type{<:ROCArray{T}}, transA::Char, transB::Char, α::T, A, B, β::T, C,
+) where T <: Union{MIOPENFloat, Float64}
+    AMDGPU.rocBLAS.gemm_batched!(transA, transB, α, A, B, β, C)
+end
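
For reference, a minimal usage sketch of the new FP16 path (not part of this commit; array names and sizes are illustrative, and it assumes a functional ROCm setup so the NNlibAMDGPUExt extension is loaded):

using AMDGPU, NNlib

# Two batches of 8×8 Float16 matrices on the GPU.
A = ROCArray(rand(Float16, 8, 8, 4))
B = ROCArray(rand(Float16, 8, 8, 4))

# batched_mul should now route through the ROCArray{Float16} method above,
# which reaches rocBLAS.gemm_batched! via _batched_gemm!; _blas_at inspects
# each operand's strides to pick the 'N' or 'T' transpose flag.
C = NNlib.batched_mul(A, B)    # 8×8×4 ROCArray{Float16}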

src/batched/batchedmul.jl

Lines changed: 0 additions & 1 deletion
@@ -223,7 +223,6 @@ _batched_mul!(::Type{DT}, C, A, B, α::Number, β::Number) where {DT<:DenseArray
     _batched_try_gemm!(DT, C, A, B, α, β)
 
 function _batched_try_gemm!(::Type{DT}, C, A, B, α::Number, β::Number) where {DT<:DenseArray{T}} where {T<:BlasFloat}
-
     alpha, beta = promote(α, β, zero(T))
     alpha isa T && beta isa T || return batched_mul_generic!(C, A, B, α, β)
 
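
Aside from the deleted blank line, the guard shown above is what decides between the BLAS route and the generic fallback: α and β must promote cleanly to the element type T. Note also that Float16 is not a LinearAlgebra.BlasFloat, so this DenseArray{T<:BlasFloat} method never applies to Float16 arrays, which is presumably why the extension adds a dedicated ROCArray{Float16} method. A standalone illustration of the check (plain Julia, not part of this commit):

using LinearAlgebra

T = Float32
α, β = 2.0, 0.0                        # Float64 scalars
alpha, beta = promote(α, β, zero(T))   # promoted to the common type Float64
alpha isa T && beta isa T              # false -> batched_mul_generic! is used

Float16 <: LinearAlgebra.BlasFloat     # false -> Float16 never takes this route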