From 6ed8086e12dbd3d2bbb238c61679ad0796df631f Mon Sep 17 00:00:00 2001
From: Isaia Nisoli
Date: Sat, 30 Nov 2024 09:41:17 -0300
Subject: [PATCH 01/15] fma and add working, sub not working

---
 docs/src/tutorials/exposing_new_intrinsics.jl | 84 +++++++++++++++++++
 src/device/intrinsics/math.jl                 | 34 ++++++++
 src/device/intrinsics/wmma.jl                 |  4 +-
 3 files changed, 121 insertions(+), 1 deletion(-)
 create mode 100644 docs/src/tutorials/exposing_new_intrinsics.jl

diff --git a/docs/src/tutorials/exposing_new_intrinsics.jl b/docs/src/tutorials/exposing_new_intrinsics.jl
new file mode 100644
index 0000000000..98de21cadb
--- /dev/null
+++ b/docs/src/tutorials/exposing_new_intrinsics.jl
@@ -0,0 +1,84 @@
+# # Introduction
+
+# *Adding new GPU intrinsics*
+
+# In this tutorial we will expose some GPU intrinsics that allow directed rounding in the fused multiply-add (fma)
+# floating-point operation.
+# We start by identifying the intrinsic we want to expose; to do so, we read the PTX (Parallel Thread Execution)
+# documentation at [PTX - Floating Point Instructions](https://docs.nvidia.com/cuda/parallel-thread-execution/#floating-point-instructions).
+# Table 32 presents a summary of the floating-point instructions, from which we can construct the intrinsic string.
+# The FMA instruction for Float32 is presented as `{mad,fma}.rnd.f32`, where `rnd` can assume the values `.rnd = { .rn, .rz, .rm, .rp }`:
+# `rn` rounds to nearest, `rz` towards zero, `rm` towards minus infinity, and `rp` towards plus infinity.
+# When building the name of the intrinsic to call, we need to replace the type suffix `.f64` with `.d` and `.f32` with `.f`.
+# Therefore, to call the round-towards-plus-infinity `fma` for `.f64` we need to call the intrinsic `llvm.nvvm.fma.rp.d`
+
+fma_rp(x::Float64, y::Float64, z::Float64) = ccall("llvm.nvvm.fma.rp.d", llvmcall, Cdouble, (Cdouble, Cdouble, Cdouble), x, y, z)
+
+# We inspect the PTX code
+CUDA.code_ptx(fma_rp, Tuple{Float64, Float64, Float64})
+
+# It is possible to see that the PTX code contains a call to the intrinsic `fma.rp.f64`; we add this function now
+# to src/device/intrins/math.jl
+
+function test_fma!(out, x, y, z)
+    I = threadIdx().x
+    if I%4 == 0
+        out[I] = CUDA.fma_rn(x, y, z)
+    elseif I%4 ==1
+        out[I] = CUDA.fma_rz(x, y, z)
+    elseif I%4 ==2
+        out[I] = CUDA.fma_rm(x, y, z)
+    elseif I%4 ==3
+        out[I] = CUDA.fma_rp(x, y, z)
+    end
+    return
+end
+
+# The first thread computes round to nearest and stores in the first entry, the second thread computes
+# round towards zero and store in the second, the third rounds towards minus infinity, the fourth towards plus infinity
+
+out_d = CuArray(zeros(4))
+@cuda threads = 4 test_fma!(out_d, 1.0, 1.0, 2^(-54))
+out_h = Array(out_d)
+
+out_d = CuArray(zeros(4))
+@cuda threads = 4 test_fma!(out_d, -1.0, 1.0, 2^(-54))
+out_h = Array(out_d)
+
+# The binary operations add, sub, mul and div have been implemented through a macro
+
+function test_add!(out, x, y)
+    I = threadIdx().x
+    if I%4 == 0
+        out[I] = CUDA.add_rn(x, y)
+    elseif I%4 ==1
+        out[I] = CUDA.add_rz(x, y)
+    elseif I%4 ==2
+        out[I] = CUDA.add_rm(x, y)
+    elseif I%4 ==3
+        out[I] = CUDA.add_rp(x, y)
+    end
+    return
+end
+
+out_d = CuArray(zeros(4))
+@cuda threads = 4 test_add!(out_d, 1.0, 2^(-53))
+out_h = Array(out_d)
+
+function test_sub!(out, x, y)
+    I = threadIdx().x
+    if I%4 == 0
+        out[I] = CUDA.sub_rn(x, y)
+    elseif I%4 ==1
+        out[I] = CUDA.sub_rz(x, y)
+    elseif I%4 ==2
+        out[I] = CUDA.sub_rm(x, y)
+    elseif I%4 ==3
+        out[I] = CUDA.sub_rp(x, y)
+    end
+    return
+end
+
+out_d = CuArray(zeros(4))
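+# A note on the launch below (a sketch of the expected values, assuming IEEE-754
+# semantics): 1.0 - 2^-53 equals prevfloat(1.0) exactly, so all four rounding modes
+# should agree here. To see the modes diverge, try y = 2^-54 instead: the exact
+# difference then lies between prevfloat(1.0) and 1.0, so `sub_rz`/`sub_rm` should
+# return prevfloat(1.0) while `sub_rp` returns 1.0. Host-side reference values can
+# be computed with BigFloat, e.g. Float64(big"1.0" - big"2.0"^-54, RoundDown).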
+@cuda threads = 4 test_sub!(out_d, 1.0, 2^(-53)) +out_h = Array(out_d) diff --git a/src/device/intrinsics/math.jl b/src/device/intrinsics/math.jl index a1d589721d..62e2a99c4c 100644 --- a/src/device/intrinsics/math.jl +++ b/src/device/intrinsics/math.jl @@ -399,9 +399,43 @@ end @device_override Base.hypot(x::Float64, y::Float64) = ccall("extern __nv_hypot", llvmcall, Cdouble, (Cdouble, Cdouble), x, y) @device_override Base.hypot(x::Float32, y::Float32) = ccall("extern __nv_hypotf", llvmcall, Cfloat, (Cfloat, Cfloat), x, y) + +for type in [:f, :d] + for round in [:rn, :rm, :rz, :rp] + for op in [:add, :sub, :mul, :div] + + inp_type = Symbol("Float64") + c_type = Symbol("Cdouble") + if type == :f + inp_type = Symbol("Float32") + c_type = Symbol("Cfloat") + end + + func_name = Symbol("$(op)_$(round)") + intrinsic_name = "llvm.nvvm.$(op).$(round).$(type)" + + @eval @device_function $func_name(x::$inp_type, y::$inp_type) = ccall($intrinsic_name, llvmcall, $c_type, ($c_type, $c_type), x, y) + end + end +end + @device_override Base.fma(x::Float64, y::Float64, z::Float64) = ccall("extern __nv_fma", llvmcall, Cdouble, (Cdouble, Cdouble, Cdouble), x, y, z) @device_override Base.fma(x::Float32, y::Float32, z::Float32) = ccall("extern __nv_fmaf", llvmcall, Cfloat, (Cfloat, Cfloat, Cfloat), x, y, z) @device_override Base.fma(x::Float16, y::Float16, z::Float16) = ccall("llvm.fma.f16", llvmcall, Float16, (Float16, Float16, Float16), x, y, z) +@device_function fma_rn(x::Float64, y::Float64, z::Float64) = ccall("llvm.nvvm.fma.rn.d", llvmcall, Cdouble, (Cdouble, Cdouble, Cdouble), x, y, z) +@device_function fma_rn(x::Float32, y::Float32, z::Float32) = ccall("llvm.nvvm.fma.rn.f", llvmcall, Cfloat, (Cfloat, Cfloat, Cfloat), x, y, z) +@device_function fma_rz(x::Float64, y::Float64, z::Float64) = ccall("llvm.nvvm.fma.rz.d", llvmcall, Cdouble, (Cdouble, Cdouble, Cdouble), x, y, z) +@device_function fma_rz(x::Float32, y::Float32, z::Float32) = ccall("llvm.nvvm.fma.rz.f", llvmcall, Cfloat, (Cfloat, Cfloat, Cfloat), x, y, z) +@device_function fma_rm(x::Float64, y::Float64, z::Float64) = ccall("llvm.nvvm.fma.rm.d", llvmcall, Cdouble, (Cdouble, Cdouble, Cdouble), x, y, z) +@device_function fma_rm(x::Float32, y::Float32, z::Float32) = ccall("llvm.nvvm.fma.rm.f", llvmcall, Cfloat, (Cfloat, Cfloat, Cfloat), x, y, z) +@device_function fma_rp(x::Float64, y::Float64, z::Float64) = ccall("llvm.nvvm.fma.rp.d", llvmcall, Cdouble, (Cdouble, Cdouble, Cdouble), x, y, z) +@device_function fma_rp(x::Float32, y::Float32, z::Float32) = ccall("llvm.nvvm.fma.rp.f", llvmcall, Cfloat, (Cfloat, Cfloat, Cfloat), x, y, z) +# @device_override Base.fma(x, y, z, ::RoundingMode{:Nearest}) = fma_rn(x, y, z) +# @device_override Base.fma(x, y, z, ::RoundingMode{:ToZero}) = fma_rz(x, y, z) +# @device_override Base.fma(x, y, z, ::RoundingMode{:Down}) = fma_rm(x, y, z) +# @device_override Base.fma(x, y, z, ::RoundingMode{:Up}) = fma_rp(x, y, z) + + @device_function sad(x::Int32, y::Int32, z::Int32) = ccall("extern __nv_sad", llvmcall, Int32, (Int32, Int32, Int32), x, y, z) @device_function sad(x::UInt32, y::UInt32, z::UInt32) = convert(UInt32, ccall("extern __nv_usad", llvmcall, Int32, (Int32, Int32, Int32), x, y, z)) diff --git a/src/device/intrinsics/wmma.jl b/src/device/intrinsics/wmma.jl index b40bbffe2d..c02b0370bf 100644 --- a/src/device/intrinsics/wmma.jl +++ b/src/device/intrinsics/wmma.jl @@ -491,7 +491,9 @@ julia> config = WMMA.Config{16, 16, 16, Float32} CUDA.WMMA.Config{16, 16, 16, Float32} ``` """ -struct Config{M, N, K, d_type} end 
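+# The new type carries the rounding mode as an extra type parameter; the alias below
+# keeps the original four-parameter `Config{M, N, K, d_type}` spelling working by
+# defaulting the rounding mode to `RoundNearest`.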
+struct ConfigRounding{M, N, K, d_type, rounding} end + +Config{M, N, K, d_type} = ConfigRounding{M, N, K, d_type, RoundNearest} # --------- # Constants From fbee09f75a451def8b7f95d119bd661dcd5f250e Mon Sep 17 00:00:00 2001 From: Isaia Nisoli Date: Mon, 2 Dec 2024 08:54:22 -0300 Subject: [PATCH 02/15] Added tests in tutorials --- docs/src/tutorials/exposing_new_intrinsics.jl | 20 ++++++++++++++++++- src/device/intrinsics/math.jl | 13 ++++++++++-- 2 files changed, 30 insertions(+), 3 deletions(-) diff --git a/docs/src/tutorials/exposing_new_intrinsics.jl b/docs/src/tutorials/exposing_new_intrinsics.jl index 98de21cadb..78774bb313 100644 --- a/docs/src/tutorials/exposing_new_intrinsics.jl +++ b/docs/src/tutorials/exposing_new_intrinsics.jl @@ -62,7 +62,7 @@ function test_add!(out, x, y) end out_d = CuArray(zeros(4)) -@cuda threads = 4 test_add!(out_d, 1.0, 2^(-53)) +@cuda threads = 4 test_add!(out_d, 1.0, 2^(-54)) out_h = Array(out_d) function test_sub!(out, x, y) @@ -82,3 +82,21 @@ end out_d = CuArray(zeros(4)) @cuda threads = 4 test_sub!(out_d, 1.0, 2^(-53)) out_h = Array(out_d) + +function test_mul!(out, x, y) + I = threadIdx().x + if I%4 == 0 + out[I] = CUDA.mul_rn(x, y) + elseif I%4 ==1 + out[I] = CUDA.mul_rz(x, y) + elseif I%4 ==2 + out[I] = CUDA.mul_rm(x, y) + elseif I%4 ==3 + out[I] = CUDA.mul_rp(x, y) + end + return +end + +out_d = CuArray(zeros(4)) +@cuda threads = 4 test_mul!(out_d, 1.0 - 2^(-52), 1.0 + 2^(-52)) +out_h = Array(out_d) diff --git a/src/device/intrinsics/math.jl b/src/device/intrinsics/math.jl index 62e2a99c4c..b2e75e261f 100644 --- a/src/device/intrinsics/math.jl +++ b/src/device/intrinsics/math.jl @@ -401,8 +401,8 @@ end for type in [:f, :d] - for round in [:rn, :rm, :rz, :rp] - for op in [:add, :sub, :mul, :div] + for round in [:rn, :rz, :rm, :rp] + for op in [:add, :mul, :div] inp_type = Symbol("Float64") c_type = Symbol("Cdouble") @@ -413,12 +413,21 @@ for type in [:f, :d] func_name = Symbol("$(op)_$(round)") intrinsic_name = "llvm.nvvm.$(op).$(round).$(type)" + #@info func_name, intrinsic_name @eval @device_function $func_name(x::$inp_type, y::$inp_type) = ccall($intrinsic_name, llvmcall, $c_type, ($c_type, $c_type), x, y) end end end + +@device_function sub_rn(x, y) = add_rn(x, -y) +@device_function sub_rz(x, y) = add_rz(x, -y) +@device_function sub_rm(x, y) = add_rm(x, -y) +@device_function sub_rp(x, y) = add_rp(x, -y) + + + @device_override Base.fma(x::Float64, y::Float64, z::Float64) = ccall("extern __nv_fma", llvmcall, Cdouble, (Cdouble, Cdouble, Cdouble), x, y, z) @device_override Base.fma(x::Float32, y::Float32, z::Float32) = ccall("extern __nv_fmaf", llvmcall, Cfloat, (Cfloat, Cfloat, Cfloat), x, y, z) @device_override Base.fma(x::Float16, y::Float16, z::Float16) = ccall("llvm.fma.f16", llvmcall, Float16, (Float16, Float16, Float16), x, y, z) From 3354104dad1fd2b4790cd9622fccaf2a2f469044 Mon Sep 17 00:00:00 2001 From: Isaia Nisoli Date: Mon, 2 Dec 2024 20:48:54 -0300 Subject: [PATCH 03/15] Different examples --- docs/src/tutorials/exposing_new_intrinsics.jl | 20 +++++++++---------- 1 file changed, 9 insertions(+), 11 deletions(-) diff --git a/docs/src/tutorials/exposing_new_intrinsics.jl b/docs/src/tutorials/exposing_new_intrinsics.jl index 78774bb313..5c863ec514 100644 --- a/docs/src/tutorials/exposing_new_intrinsics.jl +++ b/docs/src/tutorials/exposing_new_intrinsics.jl @@ -20,24 +20,22 @@ CUDA.code_ptx(fma_rp, Tuple{Float64, Float64, Float64}) # It is possible to see that the PTX code contains a call to the intrinsic `fma.rp.f64`; we add this 
function now # to src/device/intrins/math.jl -function test_fma!(out, x, y, z) +function test_fma!(out, x, y) I = threadIdx().x - if I%4 == 0 - out[I] = CUDA.fma_rn(x, y, z) - elseif I%4 ==1 - out[I] = CUDA.fma_rz(x, y, z) - elseif I%4 ==2 - out[I] = CUDA.fma_rm(x, y, z) - elseif I%4 ==3 - out[I] = CUDA.fma_rp(x, y, z) - end + z = typeof(x)(2)^(-(I+50)) + + out[I] = CUDA.fma_rn(x, y, z) + out[I+4] = CUDA.fma_rz(x, y, z) + out[I+8] = CUDA.fma_rm(x, y, z) + out[I+12] = CUDA.fma_rp(x, y, z) + return end # The first thread computes round to nearest and stores in the first entry, the second thread computes # round towards zero and store in the second, the third rounds towards minus infinity, the fourth towards plus infinity -out_d = CuArray(zeros(4)) +out_d = CuArray(zeros(16)) @cuda threads = 4 test_fma!(out_d, 1.0, 1.0, 2^(-54)) out_h = Array(out_d) From 244a39ac29fb05379bd462810ddbc89986e5c1c0 Mon Sep 17 00:00:00 2001 From: Isaia Nisoli Date: Sat, 7 Dec 2024 21:54:59 -0300 Subject: [PATCH 04/15] Using RoundToNearest etc... in intrinsic call --- docs/src/tutorials/exposing_new_intrinsics.jl | 60 +++++++++---------- src/device/intrinsics/math.jl | 32 +++++++--- 2 files changed, 54 insertions(+), 38 deletions(-) diff --git a/docs/src/tutorials/exposing_new_intrinsics.jl b/docs/src/tutorials/exposing_new_intrinsics.jl index 5c863ec514..4f058e0dbf 100644 --- a/docs/src/tutorials/exposing_new_intrinsics.jl +++ b/docs/src/tutorials/exposing_new_intrinsics.jl @@ -15,21 +15,21 @@ fma_rp(x::Float64, y::Float64, z::Float64) = ccall("llvm.nvvm.fma.rp.d", llvmcall, Cdouble, (Cdouble, Cdouble, Cdouble), x, y, z) # We inspect the PTX code -CUDA.code_ptx(fma_rp, Tuple{Float64, Float64, Float64}) +CUDA.code_ptx(fma_rp, Tuple{Float64,Float64,Float64}) # It is possible to see that the PTX code contains a call to the intrinsic `fma.rp.f64`; we add this function now # to src/device/intrins/math.jl function test_fma!(out, x, y) I = threadIdx().x - z = typeof(x)(2)^(-(I+50)) + z = typeof(x)(2)^(-(I + 50)) out[I] = CUDA.fma_rn(x, y, z) out[I+4] = CUDA.fma_rz(x, y, z) out[I+8] = CUDA.fma_rm(x, y, z) out[I+12] = CUDA.fma_rp(x, y, z) - return + return end # The first thread computes round to nearest and stores in the first entry, the second thread computes @@ -47,16 +47,16 @@ out_h = Array(out_d) function test_add!(out, x, y) I = threadIdx().x - if I%4 == 0 - out[I] = CUDA.add_rn(x, y) - elseif I%4 ==1 - out[I] = CUDA.add_rz(x, y) - elseif I%4 ==2 - out[I] = CUDA.add_rm(x, y) - elseif I%4 ==3 - out[I] = CUDA.add_rp(x, y) + if I == 1 + out[I] = CUDA.add(x, y, RoundNearest) + elseif I == 2 + out[I] = CUDA.add(x, y, RoundToZero) + elseif I == 3 + out[I] = CUDA.add(x, y, RoundUp) + elseif I == 4 + out[I] = CUDA.add(x, y, RoundDown) end - return + return end out_d = CuArray(zeros(4)) @@ -65,16 +65,16 @@ out_h = Array(out_d) function test_sub!(out, x, y) I = threadIdx().x - if I%4 == 0 - out[I] = CUDA.sub_rn(x, y) - elseif I%4 ==1 - out[I] = CUDA.sub_rz(x, y) - elseif I%4 ==2 - out[I] = CUDA.sub_rm(x, y) - elseif I%4 ==3 - out[I] = CUDA.sub_rp(x, y) + if I == 1 + out[I] = CUDA.sub(x, y, RoundNearest) + elseif I == 2 + out[I] = CUDA.sub(x, y, RoundToZero) + elseif I == 3 + out[I] = CUDA.sub(x, y, RoundUp) + elseif I == 4 + out[I] = CUDA.sub(x, y, RoundDown) end - return + return end out_d = CuArray(zeros(4)) @@ -83,16 +83,16 @@ out_h = Array(out_d) function test_mul!(out, x, y) I = threadIdx().x - if I%4 == 0 - out[I] = CUDA.mul_rn(x, y) - elseif I%4 ==1 - out[I] = CUDA.mul_rz(x, y) - elseif I%4 ==2 - out[I] = CUDA.mul_rm(x, 
y) - elseif I%4 ==3 - out[I] = CUDA.mul_rp(x, y) + if I == 1 + out[I] = CUDA.mul(x, y, RoundNearest) + elseif I == 2 + out[I] = CUDA.mul(x, y, RoundToZero) + elseif I == 3 + out[I] = CUDA.mul(x, y, RoundUp) + elseif I == 4 + out[I] = CUDA.mul(x, y, RoundDown) end - return + return end out_d = CuArray(zeros(4)) diff --git a/src/device/intrinsics/math.jl b/src/device/intrinsics/math.jl index b2e75e261f..bc9f82eb96 100644 --- a/src/device/intrinsics/math.jl +++ b/src/device/intrinsics/math.jl @@ -390,8 +390,6 @@ end @device_function normcdfinv(x::Float64) = ccall("extern __nv_normcdfinv", llvmcall, Cdouble, (Cdouble,), x) @device_function normcdfinv(x::Float32) = ccall("extern __nv_normcdfinvf", llvmcall, Cfloat, (Cfloat,), x) - - # # Unsorted # @@ -420,12 +418,31 @@ for type in [:f, :d] end end - @device_function sub_rn(x, y) = add_rn(x, -y) @device_function sub_rz(x, y) = add_rz(x, -y) @device_function sub_rm(x, y) = add_rm(x, -y) @device_function sub_rp(x, y) = add_rp(x, -y) +@device_function add(x::T, y::T, ::RoundingMode{:Nearest}) where {T <: Union{Float32, Float64}} = add_rn(x, y) +@device_function add(x::T, y::T, ::RoundingMode{:ToZero}) where {T <: Union{Float32, Float64}} = add_rz(x, y) +@device_function add(x::T, y::T, ::RoundingMode{:Down}) where {T <: Union{Float32, Float64}} = add_rm(x, y) +@device_function add(x::T, y::T, ::RoundingMode{:Up}) where {T <: Union{Float32, Float64}} = add_rp(x, y) + +@device_function sub(x::T, y::T, ::RoundingMode{:Nearest}) where {T <: Union{Float32, Float64}} = sub_rn(x, y) +@device_function sub(x::T, y::T, ::RoundingMode{:ToZero}) where {T <: Union{Float32, Float64}} = sub_rz(x, y) +@device_function sub(x::T, y::T, ::RoundingMode{:Down}) where {T <: Union{Float32, Float64}} = sub_rm(x, y) +@device_function sub(x::T, y::T, ::RoundingMode{:Up}) where {T <: Union{Float32, Float64}} = sub_rp(x, y) + +@device_function mul(x::T, y::T, ::RoundingMode{:Nearest}) where {T <: Union{Float32, Float64}} = mul_rn(x, y) +@device_function mul(x::T, y::T, ::RoundingMode{:ToZero}) where {T <: Union{Float32, Float64}} = mul_rz(x, y) +@device_function mul(x::T, y::T, ::RoundingMode{:Down}) where {T <: Union{Float32, Float64}} = mul_rm(x, y) +@device_function mul(x::T, y::T, ::RoundingMode{:Up}) where {T <: Union{Float32, Float64}} = mul_rp(x, y) + +@device_function div(x::T, y::T, ::RoundingMode{:Nearest}) where {T <: Union{Float32, Float64}} = div_rn(x, y) +@device_function div(x::T, y::T, ::RoundingMode{:ToZero}) where {T <: Union{Float32, Float64}} = div_rz(x, y) +@device_function div(x::T, y::T, ::RoundingMode{:Down}) where {T <: Union{Float32, Float64}} = div_rm(x, y) +@device_function div(x::T, y::T, ::RoundingMode{:Up}) where {T <: Union{Float32, Float64}} = div_rp(x, y) + @device_override Base.fma(x::Float64, y::Float64, z::Float64) = ccall("extern __nv_fma", llvmcall, Cdouble, (Cdouble, Cdouble, Cdouble), x, y, z) @@ -439,12 +456,11 @@ end @device_function fma_rm(x::Float32, y::Float32, z::Float32) = ccall("llvm.nvvm.fma.rm.f", llvmcall, Cfloat, (Cfloat, Cfloat, Cfloat), x, y, z) @device_function fma_rp(x::Float64, y::Float64, z::Float64) = ccall("llvm.nvvm.fma.rp.d", llvmcall, Cdouble, (Cdouble, Cdouble, Cdouble), x, y, z) @device_function fma_rp(x::Float32, y::Float32, z::Float32) = ccall("llvm.nvvm.fma.rp.f", llvmcall, Cfloat, (Cfloat, Cfloat, Cfloat), x, y, z) -# @device_override Base.fma(x, y, z, ::RoundingMode{:Nearest}) = fma_rn(x, y, z) -# @device_override Base.fma(x, y, z, ::RoundingMode{:ToZero}) = fma_rz(x, y, z) -# @device_override Base.fma(x, 
y, z, ::RoundingMode{:Down}) = fma_rm(x, y, z) -# @device_override Base.fma(x, y, z, ::RoundingMode{:Up}) = fma_rp(x, y, z) - +@device_override Base.fma(x::T, y::T, z::T, ::RoundingMode{:Nearest}) where {T <: Union{Float32, Float64}} = fma_rn(x, y, z) +@device_override Base.fma(x::T, y::T, z::T, ::RoundingMode{:ToZero}) where {T <: Union{Float32, Float64}} = fma_rz(x, y, z) +@device_override Base.fma(x::T, y::T, z::T, ::RoundingMode{:Down}) where {T <: Union{Float32, Float64}} = fma_rm(x, y, z) +@device_override Base.fma(x::T, y::T, z::T, ::RoundingMode{:Up}) where {T <: Union{Float32, Float64}} = fma_rp(x, y, z) @device_function sad(x::Int32, y::Int32, z::Int32) = ccall("extern __nv_sad", llvmcall, Int32, (Int32, Int32, Int32), x, y, z) @device_function sad(x::UInt32, y::UInt32, z::UInt32) = convert(UInt32, ccall("extern __nv_usad", llvmcall, Int32, (Int32, Int32, Int32), x, y, z)) From f8ec736afeeef6f117548acdfdb7030101d05071 Mon Sep 17 00:00:00 2001 From: Isaia Nisoli Date: Sat, 7 Dec 2024 21:56:07 -0300 Subject: [PATCH 05/15] Added to the tutorial --- docs/src/tutorials/exposing_new_intrinsics.jl | 19 ++++++++++--------- 1 file changed, 10 insertions(+), 9 deletions(-) diff --git a/docs/src/tutorials/exposing_new_intrinsics.jl b/docs/src/tutorials/exposing_new_intrinsics.jl index 4f058e0dbf..08440af1fe 100644 --- a/docs/src/tutorials/exposing_new_intrinsics.jl +++ b/docs/src/tutorials/exposing_new_intrinsics.jl @@ -13,6 +13,7 @@ # Therefore, to call the rounded towards infinity `fma` for `.f64` we need to call the intrinsic `llvm.nvvm.fma.rp.d` fma_rp(x::Float64, y::Float64, z::Float64) = ccall("llvm.nvvm.fma.rp.d", llvmcall, Cdouble, (Cdouble, Cdouble, Cdouble), x, y, z) +fma(x::T, y::T, z::T, ::RoundingMode{:Up}) where {T <: Union{Float32, Float64}} = fma_rp(x, y, z) # We inspect the PTX code CUDA.code_ptx(fma_rp, Tuple{Float64,Float64,Float64}) @@ -22,25 +23,25 @@ CUDA.code_ptx(fma_rp, Tuple{Float64,Float64,Float64}) function test_fma!(out, x, y) I = threadIdx().x - z = typeof(x)(2)^(-(I + 50)) + z = (2.0) ^ (-(I+53)) - out[I] = CUDA.fma_rn(x, y, z) - out[I+4] = CUDA.fma_rz(x, y, z) - out[I+8] = CUDA.fma_rm(x, y, z) - out[I+12] = CUDA.fma_rp(x, y, z) + out[I] = fma(x, y, z, RoundNearest) + out[I+4] = fma(x, y, z, RoundToZero) + out[I+8] = fma(x, y, z, RoundUp) + out[I+12] = fma(x, y, z, RoundDown) return end -# The first thread computes round to nearest and stores in the first entry, the second thread computes -# round towards zero and store in the second, the third rounds towards minus infinity, the fourth towards plus infinity +# The first four entries of the output are Rounded to Nearest, the entries 5 to 8 are rounded towards zero, +# etc... 
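+# Concretely, for the first launch below (x = y = 1.0) each thread's exact result
+# 1 + 2^-(I+53) lies strictly between 1.0 and nextfloat(1.0) = 1 + 2^-52, so (assuming
+# IEEE-754 semantics) entries 1-8 and 13-16 should hold 1.0, while entries 9-12
+# (RoundUp) should hold nextfloat(1.0) == 1.0000000000000002.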
out_d = CuArray(zeros(16)) -@cuda threads = 4 test_fma!(out_d, 1.0, 1.0, 2^(-54)) +@cuda threads = 4 test_fma!(out_d, 1.0, 1.0) out_h = Array(out_d) out_d = CuArray(zeros(4)) -@cuda threads = 4 test_fma!(out_d, -1.0, 1.0, 2^(-54)) +@cuda threads = 4 test_fma!(out_d, -1.0, 1.0) out_h = Array(out_d) # The binary operations as add, sub, mul, div have been implemented through a macro From 7346a6a77da7d9c4d650bb5e152d69f4e708f6f6 Mon Sep 17 00:00:00 2001 From: Isaia Nisoli Date: Sun, 8 Dec 2024 16:29:10 -0300 Subject: [PATCH 06/15] Directed rounding MMA exposed --- docs/src/tutorials/TODO.jl | 37 ++++++++++++ src/device/intrinsics/wmma.jl | 106 +++++++++++++++++++++++++++++++--- 2 files changed, 136 insertions(+), 7 deletions(-) create mode 100644 docs/src/tutorials/TODO.jl diff --git a/docs/src/tutorials/TODO.jl b/docs/src/tutorials/TODO.jl new file mode 100644 index 0000000000..99c7ae2c6c --- /dev/null +++ b/docs/src/tutorials/TODO.jl @@ -0,0 +1,37 @@ +# https://github.com/JuliaGPU/CUDA.jl/pull/1426 + +function kernel_wmma_f64_lowlevel(a_dev, b_dev, c_dev, d_dev) + a_frag = WMMA.llvm_wmma_load_a_col_m8n8k4_global_stride_f64(pointer(a_dev), 8) + b_frag = WMMA.llvm_wmma_load_b_col_m8n8k4_global_stride_f64(pointer(b_dev), 4) + c_frag = WMMA.llvm_wmma_load_c_col_m8n8k4_global_stride_f64(pointer(c_dev), 8) + + d_frag = WMMA.llvm_wmma_mma_col_col_m8n8k4_f64(a_frag, b_frag, c_frag) + d_frag = WMMA.llvm_wmma_mma_col_col_m8n8k4_f64(a_frag, b_frag, c_frag, RoundToZero) + d_frag = WMMA.llvm_wmma_mma_col_col_m8n8k4_f64(a_frag, b_frag, c_frag, RoundUp) + d_frag = WMMA.llvm_wmma_mma_col_col_m8n8k4_f64(a_frag, b_frag, c_frag, RoundDown) + #@cuprintln d_frag + + ccall("llvm.nvvm.wmma.m8n8k4.store.d.col.stride.f64", llvmcall, + Nothing, (Core.LLVMPtr{Float64, 1}, Float64, Float64, Int32), + pointer(d_dev), d_frag[1], d_frag[2], 8) + #WMMA.llvm_wmma_store_d_col_m8n8k4_global_stride_f64(pointer(d_dev), d_frag, 8) + return nothing +end + +function call_kernel() + m = n = 8 + k = 4 + dtype_a = dtype_b = Float64 + dtype_c = dtype_d = Float64 + + d_a = CUDA.rand(dtype_a, m, k) + d_b = CUDA.rand(dtype_b, k, n) + d_c = CUDA.rand(dtype_c, m, n) + d_d = CUDA.zeros(dtype_d, m, n) + + CUDA.@sync @cuda kernel_wmma_f64_lowlevel(d_a, d_b, d_c, d_d) + return nothing +end + +#https://github.com/llvm/llvm-project/blob/main/clang/test/CodeGen/builtins-nvptx-mma.cu + diff --git a/src/device/intrinsics/wmma.jl b/src/device/intrinsics/wmma.jl index c02b0370bf..a500274139 100644 --- a/src/device/intrinsics/wmma.jl +++ b/src/device/intrinsics/wmma.jl @@ -15,7 +15,8 @@ const map_ptx_to_jl_array = Dict( "s8" => Int8, "s32" => Int32, "f16" => Float16, - "f32" => Float32 + "f32" => Float32, + "f64" => Float64 ) # Maps PTX types to Julia fragment types @@ -24,10 +25,13 @@ const map_ptx_to_jl_frag = Dict( "s8" => UInt32, "s32" => Int32, "f16" => NTuple{2, VecElement{Float16}}, - "f32" => Float32 + "f32" => Float32, + "f64" => Float64 ) -# Maps matrix & PTX types to fragment sizes +# Maps matrix & PTX types to fragment sizes, information retrieved from +# https://docs.nvidia.com/cuda/parallel-thread-execution/index.html?highlight=wmma#matrix-fragments-for-wmma + const map_frag_sizes = Dict( # A "a.u8.m16n16k16" => 2, @@ -41,6 +45,9 @@ const map_frag_sizes = Dict( "a.f16.m16n16k16" => 8, "a.f16.m8n32k16" => 8, "a.f16.m32n8k16" => 8, + + "a.f64.m8n8k4" => 1, + # B "b.u8.m16n16k16" => 2, "b.u8.m8n32k16" => 4, @@ -53,6 +60,9 @@ const map_frag_sizes = Dict( "b.f16.m16n16k16" => 8, "b.f16.m8n32k16" => 8, "b.f16.m32n8k16" => 8, + + "b.f64.m8n8k4" 
=> 1, + # C "c.s32.m16n16k16" => 8, "c.s32.m8n32k16" => 8, @@ -65,6 +75,12 @@ const map_frag_sizes = Dict( "c.f32.m16n16k16" => 8, "c.f32.m8n32k16" => 8, "c.f32.m32n8k16" => 8, + + "c.f64.m8n8k4" => 2, # there is a clash of documentation here: + # https://docs.nvidia.com/cuda/parallel-thread-execution/#matrix-fragments-for-mma-m8n8k4-with-f64-floating-point-type + # says `A vector expression containing of two .f64 registers containing two .f64 elements from the matrix C.` + # while https://docs.nvidia.com/cuda/parallel-thread-execution/#matrix-fragments-for-wmma says 1 + # D "d.s32.m16n16k16" => 8, "d.s32.m8n32k16" => 8, @@ -77,6 +93,8 @@ const map_frag_sizes = Dict( "d.f32.m16n16k16" => 8, "d.f32.m8n32k16" => 8, "d.f32.m32n8k16" => 8, + + "d.f64.m8n8k4" => 2, ) # Maps PTX AS to CUDA.AS @@ -96,13 +114,19 @@ const wmma_half_ops = [(16,16,16), (32,8,16), (8,32,16)], ["f16"], ["f16", "f const ldst_int_ab_ops = [(16,16,16), (32,8,16), (8,32,16)], ["a", "b"], ["u8", "s8"] const ldst_int_cd_ops = [(16,16,16), (32,8,16), (8,32,16)], ["c", "d"], ["s32"] const wmma_int_ops = [(16,16,16), (32,8,16), (8,32,16)], ["s8", "u8"], ["s32"], ["s32"] - -const all_ldst_ops = vcat(ldst_half_ab_ops, ldst_half_cd_ops, - ldst_int_ab_ops, ldst_int_cd_ops) +# Double +const ldst_double_ab_ops = [(8, 8, 4)], ["a", "b"], ["f64"] +const ldst_double_cd_ops = [(8, 8, 4)], ["c", "d"], ["f64"] +const wmma_double_ops = [(8, 8, 4)], ["f64"], ["f64"], ["f64"] + +const all_ldst_ops = vcat(ldst_half_ab_ops, ldst_half_cd_ops, ldst_double_ab_ops, + ldst_int_ab_ops, ldst_int_cd_ops, ldst_double_cd_ops) + +# the wmma_double_ops will be treated separatedly due to rounding const all_wmma_ops = vcat(wmma_half_ops, wmma_int_ops) # Valid WMMA operation shapes -const valid_shapes = [(16, 16, 16), (32, 8, 16), (8, 32, 16)] +const valid_shapes = [(16, 16, 16), (32, 8, 16), (8, 32, 16), (8, 8, 4)] ################################################################################ # HELPER FUNCTIONS @@ -256,19 +280,24 @@ export llvm_wmma_store func_name = Symbol(join(filter(!isempty, ["llvm", "wmma", "store", mat, layout, shape, addr_space, stride, elem_type]), "_")) # Name of the LLVM intrinsic + #llvm.nvvm.wmma.m8n8k4.store.d.col.stride.f64 llvm_intr = "llvm.nvvm.wmma.$shape.store.$mat.$layout.stride.$elem_type.p$(addr_space_int)" if LLVM.version() < v"17" llvm_intr *= "i8" end + @info llvm_intr # Determine types + size for this (matrix, elem_type) combination arr_ty, frag_ty, sz = get_frag_info(mat, elem_type, shape) + @info arr_ty, frag_ty, sz ccall_name = "$llvm_intr" frag_types = ntuple(i -> frag_ty, sz) frag_vars = ntuple(i -> :(data[$i]), sz) + @info frag_types, frag_vars ptr_ty = :(LLVMPtr{$arr_ty, $addr_space_int}) + @info ptr_ty @eval $func_name(dst_addr, data, stride) = ccall($ccall_name, llvmcall, Nothing, ($ptr_ty, $(frag_types...), Int32), dst_addr, $(frag_vars...), stride) @eval export $func_name @@ -283,6 +312,7 @@ end WMMA.llvm_wmma_mma_{a_layout}_{b_layout}_{shape}_{d_elem_type}_{c_elem_type}(a, b, c) or WMMA.llvm_wmma_mma_{a_layout}_{b_layout}_{shape}_{a_elem_type}(a, b, c) +For double operations: wrapper around the LLVM intrinsic `@llvm.nvvm.wmma.mma.sync.{a_layout}.{b_layout}.{shape}.{rnd}.{d_elem_type}.{c_elem_type}` For floating point operations: wrapper around the LLVM intrinsic `@llvm.nvvm.wmma.mma.sync.{a_layout}.{b_layout}.{shape}.{d_elem_type}.{c_elem_type}` For all other operations: wrapper around the LLVM intrinsic `@llvm.nvvm.wmma.mma.sync.{a_layout}.{b_layout}.{shape}.{a_elem_type}` @@ -356,6 +386,68 @@ for ops 
in all_wmma_ops, @eval @doc (@doc llvm_wmma_mma) $func_name end +const wmma_double_rounding = ["", "rn", "rz", "rm", "rp"] + +for ops in [wmma_double_ops], + a_layout in ["col", "row"], + b_layout in ["col", "row"], + mnk in ops[1], + rnd in wmma_double_rounding + + a_elem_type = "f64" + b_elem_type = "f64" + c_elem_type = "f64" + d_elem_type = "f64" + + shape = get_hl_shape(mnk[1], mnk[2], mnk[3]) + + llvm_intr = "llvm.nvvm.wmma.$shape.mma.$a_layout.$b_layout.$rnd.f64" + if rnd == "" + llvm_intr = "llvm.nvvm.wmma.$shape.mma.$a_layout.$b_layout.f64" + end + # Name of the Julia wrapper function + func_name = Symbol(join(filter(!isempty, ["llvm", "wmma", "mma", a_layout, b_layout, shape, a_elem_type, rnd]), "_")) + + # Determine types + size for the (matrix, elem_type) combinations for matrix A, B, C and D + a_arr_ty, a_frag_ty, a_sz = get_frag_info("a", a_elem_type, shape) + b_arr_ty, b_frag_ty, b_sz = get_frag_info("b", b_elem_type, shape) + c_arr_ty, c_frag_ty, c_sz = get_frag_info("c", c_elem_type, shape) + d_arr_ty, d_frag_ty, d_sz = get_frag_info("d", d_elem_type, shape) + + ccall_name = "$llvm_intr" + + a_types = ntuple(i -> a_frag_ty, a_sz) + b_types = ntuple(i -> b_frag_ty, b_sz) + c_types = ntuple(i -> c_frag_ty, c_sz) + + a_vars = ntuple(i -> :(a[$i]), a_sz) + b_vars = ntuple(i -> :(b[$i]), b_sz) + c_vars = ntuple(i -> :(c[$i]), c_sz) + + if d_sz == 1 + @eval $func_name(a, b, c) = tuple(ccall($ccall_name, llvmcall, $d_frag_ty, ($(a_types...), $(b_types...), $(c_types...)), $(a_vars...), $(b_vars...), $(c_vars...))) + else + struct_ty = Symbol("LLVMStruct$d_sz") + @eval $func_name(a, b, c) = convert(NTuple{$d_sz, $d_frag_ty}, ccall($ccall_name, llvmcall, $struct_ty{$d_frag_ty}, ($(a_types...), $(b_types...), $(c_types...)), $(a_vars...), $(b_vars...), $(c_vars...))) + end + @eval export $func_name + @eval @doc (@doc llvm_wmma_mma) $func_name +end + +llvm_wmma_mma_col_col_m8n8k4_f64(a_frag, b_frag, c_frag, ::RoundingMode{:Nearest}) = llvm_wmma_mma_col_col_m8n8k4_f64_rn(a_frag, b_frag, c_frag) +llvm_wmma_mma_col_col_m8n8k4_f64(a_frag, b_frag, c_frag, ::RoundingMode{:ToZero}) = llvm_wmma_mma_col_col_m8n8k4_f64_rz(a_frag, b_frag, c_frag) +llvm_wmma_mma_col_col_m8n8k4_f64(a_frag, b_frag, c_frag, ::RoundingMode{:Up}) = llvm_wmma_mma_col_col_m8n8k4_f64_rp(a_frag, b_frag, c_frag) +llvm_wmma_mma_col_col_m8n8k4_f64(a_frag, b_frag, c_frag, ::RoundingMode{:Down}) = llvm_wmma_mma_col_col_m8n8k4_f64_rm(a_frag, b_frag, c_frag) + + + +# elseif d_elem_type == "f64" +# llvm_intr = "llvm.nvvm.wmma.$shape.mma.$a_layout.$b_layout.$rnd.f64.f64.f64.f64" +# # Name of the Julia wrapper function +# func_name = Symbol(join(filter(!isempty, ["llvm", "wmma", "mma", a_layout, b_layout, shape, a_elem_type, rnd]), "_")) + + + ################################################################################ # FLATTENING/UNFLATTENING LOGIC ################################################################################ From f8872a4344d35609ae1f8ebe7899bd28cdac5699 Mon Sep 17 00:00:00 2001 From: Isaia Nisoli Date: Mon, 9 Dec 2024 05:21:56 +0900 Subject: [PATCH 07/15] Removed TODO file with test --- docs/src/tutorials/TODO.jl | 37 ------------------------------------- 1 file changed, 37 deletions(-) delete mode 100644 docs/src/tutorials/TODO.jl diff --git a/docs/src/tutorials/TODO.jl b/docs/src/tutorials/TODO.jl deleted file mode 100644 index 99c7ae2c6c..0000000000 --- a/docs/src/tutorials/TODO.jl +++ /dev/null @@ -1,37 +0,0 @@ -# https://github.com/JuliaGPU/CUDA.jl/pull/1426 - -function 
kernel_wmma_f64_lowlevel(a_dev, b_dev, c_dev, d_dev) - a_frag = WMMA.llvm_wmma_load_a_col_m8n8k4_global_stride_f64(pointer(a_dev), 8) - b_frag = WMMA.llvm_wmma_load_b_col_m8n8k4_global_stride_f64(pointer(b_dev), 4) - c_frag = WMMA.llvm_wmma_load_c_col_m8n8k4_global_stride_f64(pointer(c_dev), 8) - - d_frag = WMMA.llvm_wmma_mma_col_col_m8n8k4_f64(a_frag, b_frag, c_frag) - d_frag = WMMA.llvm_wmma_mma_col_col_m8n8k4_f64(a_frag, b_frag, c_frag, RoundToZero) - d_frag = WMMA.llvm_wmma_mma_col_col_m8n8k4_f64(a_frag, b_frag, c_frag, RoundUp) - d_frag = WMMA.llvm_wmma_mma_col_col_m8n8k4_f64(a_frag, b_frag, c_frag, RoundDown) - #@cuprintln d_frag - - ccall("llvm.nvvm.wmma.m8n8k4.store.d.col.stride.f64", llvmcall, - Nothing, (Core.LLVMPtr{Float64, 1}, Float64, Float64, Int32), - pointer(d_dev), d_frag[1], d_frag[2], 8) - #WMMA.llvm_wmma_store_d_col_m8n8k4_global_stride_f64(pointer(d_dev), d_frag, 8) - return nothing -end - -function call_kernel() - m = n = 8 - k = 4 - dtype_a = dtype_b = Float64 - dtype_c = dtype_d = Float64 - - d_a = CUDA.rand(dtype_a, m, k) - d_b = CUDA.rand(dtype_b, k, n) - d_c = CUDA.rand(dtype_c, m, n) - d_d = CUDA.zeros(dtype_d, m, n) - - CUDA.@sync @cuda kernel_wmma_f64_lowlevel(d_a, d_b, d_c, d_d) - return nothing -end - -#https://github.com/llvm/llvm-project/blob/main/clang/test/CodeGen/builtins-nvptx-mma.cu - From 8e4cffdb0a688bf76fc3a4bb89a6a5c7eb81f7e4 Mon Sep 17 00:00:00 2001 From: Isaia Nisoli Date: Mon, 9 Dec 2024 05:30:00 +0900 Subject: [PATCH 08/15] Removed debug --- src/device/intrinsics/wmma.jl | 10 +++------- 1 file changed, 3 insertions(+), 7 deletions(-) diff --git a/src/device/intrinsics/wmma.jl b/src/device/intrinsics/wmma.jl index a500274139..2c7a96dc48 100644 --- a/src/device/intrinsics/wmma.jl +++ b/src/device/intrinsics/wmma.jl @@ -286,19 +286,15 @@ export llvm_wmma_store llvm_intr *= "i8" end - @info llvm_intr # Determine types + size for this (matrix, elem_type) combination arr_ty, frag_ty, sz = get_frag_info(mat, elem_type, shape) - @info arr_ty, frag_ty, sz - + ccall_name = "$llvm_intr" frag_types = ntuple(i -> frag_ty, sz) frag_vars = ntuple(i -> :(data[$i]), sz) - @info frag_types, frag_vars - + ptr_ty = :(LLVMPtr{$arr_ty, $addr_space_int}) - @info ptr_ty - + @eval $func_name(dst_addr, data, stride) = ccall($ccall_name, llvmcall, Nothing, ($ptr_ty, $(frag_types...), Int32), dst_addr, $(frag_vars...), stride) @eval export $func_name @eval @doc (@doc llvm_wmma_store) $func_name From 1042ebb7b9e1c0acde8de0761218b38ea61fd4bb Mon Sep 17 00:00:00 2001 From: Isaia Nisoli Date: Mon, 9 Dec 2024 16:04:04 -0300 Subject: [PATCH 09/15] Added rounding to WMMA config, need to propagate back! 
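
A sketch of the intended high-level API (mirrored by the test file
docs/src/tutorials/TODO.jl added in this patch; the rounding mode is stored in
the config but not yet propagated down to the llvm_wmma_mma_* wrappers):

    conf = WMMA.Config{8, 8, 4, Float64, RoundUp}
    a_frag = WMMA.load_a(pointer(a_dev), 8, ColMajor, conf)
    b_frag = WMMA.load_b(pointer(b_dev), 4, ColMajor, conf)
    c_frag = WMMA.load_c(pointer(c_dev), 8, ColMajor, conf)
    d_frag = WMMA.mma(a_frag, b_frag, c_frag, conf)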
---
 docs/make.jl                                  |  3 +
 .../exposing_new_intrinsics.jl                | 58 +------------------
 docs/src/tutorials/TODO.jl                    | 44 ++++++++++++++
 src/device/intrinsics/wmma.jl                 | 25 ++++----
 4 files changed, 64 insertions(+), 66 deletions(-)
 rename docs/src/{tutorials => hacking}/exposing_new_intrinsics.jl (60%)
 create mode 100644 docs/src/tutorials/TODO.jl

diff --git a/docs/make.jl b/docs/make.jl
index 3cbd62d523..9af70ae3e7 100644
--- a/docs/make.jl
+++ b/docs/make.jl
@@ -60,6 +60,9 @@ function main()
             "development/troubleshooting.md",
             "development/debugging.md",
         ],
+        "Hacking" => Any[
+            "hacking/exposing_new_intrinsics.jl"
+        ],
         "API reference" => Any[
             "api/essentials.md",
             "api/array.md",
diff --git a/docs/src/tutorials/exposing_new_intrinsics.jl b/docs/src/hacking/exposing_new_intrinsics.jl
similarity index 60%
rename from docs/src/tutorials/exposing_new_intrinsics.jl
rename to docs/src/hacking/exposing_new_intrinsics.jl
index 08440af1fe..662940af27 100644
--- a/docs/src/tutorials/exposing_new_intrinsics.jl
+++ b/docs/src/hacking/exposing_new_intrinsics.jl
@@ -11,6 +11,8 @@
 # `rn` rounds to nearest, `rz` towards zero, `rm` towards minus infinity, and `rp` towards plus infinity.
 # When building the name of the intrinsic to call, we need to replace the type suffix `.f64` with `.d` and `.f32` with `.f`.
 # Therefore, to call the round-towards-plus-infinity `fma` for `.f64` we need to call the intrinsic `llvm.nvvm.fma.rp.d`
+# Note that this is only possible if LLVM supports the intrinsic; the intrinsics exposed by LLVM
+# can be found by searching the [LLVM repository](https://github.com/llvm/llvm-project). In other cases you'd need @asmcall and inline PTX assembly.
 
 fma_rp(x::Float64, y::Float64, z::Float64) = ccall("llvm.nvvm.fma.rp.d", llvmcall, Cdouble, (Cdouble, Cdouble, Cdouble), x, y, z)
 fma(x::T, y::T, z::T, ::RoundingMode{:Up}) where {T <: Union{Float32, Float64}} = fma_rp(x, y, z)
@@ -21,6 +23,7 @@ CUDA.code_ptx(fma_rp, Tuple{Float64,Float64,Float64})
 # It is possible to see that the PTX code contains a call to the intrinsic `fma.rp.f64`; we add this function now
 # to src/device/intrins/math.jl
 
+using CUDA
 function test_fma!(out, x, y)
     I = threadIdx().x
     z = (2.0) ^ (-(I+53))
@@ -44,58 +47,3 @@ out_d = CuArray(zeros(4))
 @cuda threads = 4 test_fma!(out_d, -1.0, 1.0)
 out_h = Array(out_d)
 
-# The binary operations add, sub, mul and div have been implemented through a macro
-
-function test_add!(out, x, y)
-    I = threadIdx().x
-    if I == 1
-        out[I] = CUDA.add(x, y, RoundNearest)
-    elseif I == 2
-        out[I] = CUDA.add(x, y, RoundToZero)
-    elseif I == 3
-        out[I] = CUDA.add(x, y, RoundUp)
-    elseif I == 4
-        out[I] = CUDA.add(x, y, RoundDown)
-    end
-    return
-end
-
-out_d = CuArray(zeros(4))
-@cuda threads = 4 test_add!(out_d, 1.0, 2^(-54))
-out_h = Array(out_d)
-
-function test_sub!(out, x, y)
-    I = threadIdx().x
-    if I == 1
-        out[I] = CUDA.sub(x, y, RoundNearest)
-    elseif I == 2
-        out[I] = CUDA.sub(x, y, RoundToZero)
-    elseif I == 3
-        out[I] = CUDA.sub(x, y, RoundUp)
-    elseif I == 4
-        out[I] = CUDA.sub(x, y, RoundDown)
-    end
-    return
-end
-
-out_d = CuArray(zeros(4))
-@cuda threads = 4 test_sub!(out_d, 1.0, 2^(-53))
-out_h = Array(out_d)
-
-function test_mul!(out, x, y)
-    I = threadIdx().x
-    if I == 1
-        out[I] = CUDA.mul(x, y, RoundNearest)
-    elseif I == 2
-        out[I] = CUDA.mul(x, y, RoundToZero)
-    elseif I == 3
-        out[I] = CUDA.mul(x, y, RoundUp)
-    elseif I == 4
-        out[I] = CUDA.mul(x, y, RoundDown)
-    end
-    return
-end
-
-out_d = CuArray(zeros(4))
-@cuda threads = 4 test_mul!(out_d, 1.0 - 2^(-52), 1.0 + 2^(-52))
-out_h = Array(out_d) diff --git a/docs/src/tutorials/TODO.jl b/docs/src/tutorials/TODO.jl new file mode 100644 index 0000000000..587ddff7b2 --- /dev/null +++ b/docs/src/tutorials/TODO.jl @@ -0,0 +1,44 @@ +# https://github.com/JuliaGPU/CUDA.jl/pull/1426 + +function kernel_wmma_f64_lowlevel(a_dev, b_dev, c_dev, d_dev) + conf = WMMA.Config{8, 8, 4, Float64, RoundUp} + + # a_frag = WMMA.llvm_wmma_load_a_col_m8n8k4_global_stride_f64(pointer(a_dev), 8) + # b_frag = WMMA.llvm_wmma_load_b_col_m8n8k4_global_stride_f64(pointer(b_dev), 4) + # c_frag = WMMA.llvm_wmma_load_c_col_m8n8k4_global_stride_f64(pointer(c_dev), 8) + + a_frag = WMMA.load_a(pointer(a_dev), 8, ColMajor, conf) + b_frag = WMMA.load_b(pointer(b_dev), 4, ColMajor, conf) + c_frag = WMMA.load_b(pointer(c_dev), 8, ColMajor, conf) + + d_frag = WMMA.llvm_wmma_mma(a_frag, b_frag, c_frag, conf) + #d_frag = WMMA.llvm_wmma_mma_col_col_m8n8k4_f64(a_frag, b_frag, c_frag) + #d_frag = WMMA.llvm_wmma_mma_col_col_m8n8k4_f64(a_frag, b_frag, c_frag, RoundToZero) + #d_frag = WMMA.llvm_wmma_mma_col_col_m8n8k4_f64(a_frag, b_frag, c_frag, RoundUp) + #d_frag = WMMA.llvm_wmma_mma_col_col_m8n8k4_f64(a_frag, b_frag, c_frag, RoundDown) + #@cuprintln d_frag + WWMA.store_d(pointer(d_dev), d_frag, 8, ColMajor, conf) + #ccall("llvm.nvvm.wmma.m8n8k4.store.d.col.stride.f64", llvmcall, + # Nothing, (Core.LLVMPtr{Float64, 1}, Float64, Float64, Int32), + # pointer(d_dev), d_frag[1], d_frag[2], 8) + #WMMA.llvm_wmma_store_d_col_m8n8k4_global_stride_f64(pointer(d_dev), d_frag, 8) + return nothing +end + +function call_kernel() + m = n = 8 + k = 4 + dtype_a = dtype_b = Float64 + dtype_c = dtype_d = Float64 + + d_a = CUDA.rand(dtype_a, m, k) + d_b = CUDA.rand(dtype_b, k, n) + d_c = CUDA.rand(dtype_c, m, n) + d_d = CUDA.zeros(dtype_d, m, n) + + CUDA.@sync @cuda kernel_wmma_f64_lowlevel(d_a, d_b, d_c, d_d) + return nothing +end + +#https://github.com/llvm/llvm-project/blob/main/clang/test/CodeGen/builtins-nvptx-mma.cu + diff --git a/src/device/intrinsics/wmma.jl b/src/device/intrinsics/wmma.jl index 2c7a96dc48..958fec80fd 100644 --- a/src/device/intrinsics/wmma.jl +++ b/src/device/intrinsics/wmma.jl @@ -377,6 +377,10 @@ for ops in all_wmma_ops, else struct_ty = Symbol("LLVMStruct$d_sz") @eval $func_name(a, b, c) = convert(NTuple{$d_sz, $d_frag_ty}, ccall($ccall_name, llvmcall, $struct_ty{$d_frag_ty}, ($(a_types...), $(b_types...), $(c_types...)), $(a_vars...), $(b_vars...), $(c_vars...))) + @eval $func_name(a, b, c, ::RoundingMode{:Nearest}) = $func_name(a, b, c) + @eval $func_name(a, b, c, ::RoundingMode{:ToZero}) = $func_name(a, b, c) + @eval $func_name(a, b, c, ::RoundingMode{:Up}) = $func_name(a, b, c) + @eval $func_name(a, b, c, ::RoundingMode{:Down}) = $func_name(a, b, c) end @eval export $func_name @eval @doc (@doc llvm_wmma_mma) $func_name @@ -562,26 +566,25 @@ end export Config """ - WMMA.Config{M, N, K, d_type} + WMMA.Config{M, N, K, d_type, rounding} Type that contains all information for WMMA operations that cannot be inferred from the argument's types. WMMA instructions calculate the matrix multiply-accumulate operation ``D = A \\cdot B + C``, where ``A`` is a ``M \\times K`` matrix, ``B`` a ``K \\times N`` matrix, and ``C`` and ``D`` are ``M \\times N`` matrices. -`d_type` refers to the type of the elements of matrix ``D``, and can be either `Float16` or `Float32`. +`d_type` refers to the type of the elements of matrix ``D``, and can be either `Float16`, `Float32` or `Float64`. 
+`rounding` refers to a rounding mode between RoundNearest, RoundToZero, RoundUp and RoundDown, only works with `Float64` All WMMA operations take a `Config` as their final argument. # Examples ```jldoctest -julia> config = WMMA.Config{16, 16, 16, Float32} -CUDA.WMMA.Config{16, 16, 16, Float32} +config = WMMA.Config{16, 16, 16, Float64, RoundNearest} +CUDA.WMMA.Config{16, 16, 16, Float64, RoundingMode{:Nearest}()} ``` """ -struct ConfigRounding{M, N, K, d_type, rounding} end - -Config{M, N, K, d_type} = ConfigRounding{M, N, K, d_type, RoundNearest} +struct Config{M, N, K, d_type, rounding} end # --------- # Constants @@ -692,7 +695,7 @@ for mat in ["a", "b", "c"] @eval @generated function $func_name(addr::LLVMPtr{T, AS}, stride::Number, layout::Type{L}, - config::Type{Config{M, N, K, D_TYPE}}) where {T, AS, L, M, N, K, D_TYPE} + config::Type{Config{M, N, K, D_TYPE, rounding}}) where {T, AS, L, M, N, K, D_TYPE, rounding} as_str = get_hl_as_info(AS) layout = get_hl_layout(L) @@ -740,7 +743,7 @@ mma @generated function mma(a::Fragment{M, N, K, A_SZ, A_T, A_L, MatrixA}, b::Fragment{M, N, K, B_SZ, B_T, B_L, MatrixB}, c::Fragment{M, N, K, C_SZ, C_T, Unspecified, Accumulator}, - config::Type{Config{M, N, K, D_T}}) where {M, N, K, A_SZ, A_T, A_L, B_SZ, B_T, B_L, C_SZ, C_T, D_T} + config::Type{Config{M, N, K, D_T, rounding}}) where {M, N, K, A_SZ, A_T, A_L, B_SZ, B_T, B_L, C_SZ, C_T, D_T} a_layout = get_hl_layout(A_L) b_layout = get_hl_layout(B_L) @@ -761,7 +764,7 @@ mma b_unfl = unflatten(NTuple{$b_frag_sz, $b_frag_ty}, b.x) c_unfl = unflatten(NTuple{$c_frag_sz, $c_frag_ty}, c.x) - x = flatten($wrapper(a_unfl, b_unfl, c_unfl)) + x = flatten($wrapper(a_unfl, b_unfl, c_unfl, rounding)) return Fragment{$M, $N, $K, $d_num_els, $D_T, Unspecified, Accumulator}(x) end end @@ -798,7 +801,7 @@ store_d d::Fragment{M, N, K, D_SZ, T, Unspecified, Accumulator}, stride::Number, layout::Type{L}, - config::Type{Config{M, N, K, T}}) where {T, AS, M, N, K, D_SZ, L} + config::Type{Config{M, N, K, T, rounding}}) where {T, AS, M, N, K, D_SZ, L, rounding} as_str = get_hl_as_info(AS) layout = get_hl_layout(L) From 0e7fa43267f29718ec2d315fb3882ddda1583a35 Mon Sep 17 00:00:00 2001 From: Isaia Nisoli Date: Tue, 10 Dec 2024 20:51:15 +0900 Subject: [PATCH 10/15] Reverted to WMMA.Config without rounding --- src/device/intrinsics/wmma.jl | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/src/device/intrinsics/wmma.jl b/src/device/intrinsics/wmma.jl index 958fec80fd..f71959d0e7 100644 --- a/src/device/intrinsics/wmma.jl +++ b/src/device/intrinsics/wmma.jl @@ -566,7 +566,7 @@ end export Config """ - WMMA.Config{M, N, K, d_type, rounding} + WMMA.Config{M, N, K, d_type} Type that contains all information for WMMA operations that cannot be inferred from the argument's types. @@ -574,7 +574,6 @@ WMMA instructions calculate the matrix multiply-accumulate operation ``D = A \\c ``B`` a ``K \\times N`` matrix, and ``C`` and ``D`` are ``M \\times N`` matrices. `d_type` refers to the type of the elements of matrix ``D``, and can be either `Float16`, `Float32` or `Float64`. -`rounding` refers to a rounding mode between RoundNearest, RoundToZero, RoundUp and RoundDown, only works with `Float64` All WMMA operations take a `Config` as their final argument. 
@@ -584,7 +583,7 @@ config = WMMA.Config{16, 16, 16, Float64, RoundNearest} CUDA.WMMA.Config{16, 16, 16, Float64, RoundingMode{:Nearest}()} ``` """ -struct Config{M, N, K, d_type, rounding} end +struct Config{M, N, K, d_type} end # --------- # Constants @@ -743,7 +742,7 @@ mma @generated function mma(a::Fragment{M, N, K, A_SZ, A_T, A_L, MatrixA}, b::Fragment{M, N, K, B_SZ, B_T, B_L, MatrixB}, c::Fragment{M, N, K, C_SZ, C_T, Unspecified, Accumulator}, - config::Type{Config{M, N, K, D_T, rounding}}) where {M, N, K, A_SZ, A_T, A_L, B_SZ, B_T, B_L, C_SZ, C_T, D_T} + config::Type{Config{M, N, K, D_T}}) where {M, N, K, A_SZ, A_T, A_L, B_SZ, B_T, B_L, C_SZ, C_T, D_T} a_layout = get_hl_layout(A_L) b_layout = get_hl_layout(B_L) @@ -801,7 +800,7 @@ store_d d::Fragment{M, N, K, D_SZ, T, Unspecified, Accumulator}, stride::Number, layout::Type{L}, - config::Type{Config{M, N, K, T, rounding}}) where {T, AS, M, N, K, D_SZ, L, rounding} + config::Type{Config{M, N, K, T}}) where {T, AS, M, N, K, D_SZ, L} as_str = get_hl_as_info(AS) layout = get_hl_layout(L) From c40489c3f0a0baf65718c31a6e3c78863bdebb8d Mon Sep 17 00:00:00 2001 From: Isaia Nisoli Date: Tue, 10 Dec 2024 21:45:25 +0900 Subject: [PATCH 11/15] Added rounding as a default keyword argument in mma --- src/device/intrinsics/wmma.jl | 26 ++++++++++++++++++++------ 1 file changed, 20 insertions(+), 6 deletions(-) diff --git a/src/device/intrinsics/wmma.jl b/src/device/intrinsics/wmma.jl index f71959d0e7..14b023b5cf 100644 --- a/src/device/intrinsics/wmma.jl +++ b/src/device/intrinsics/wmma.jl @@ -407,7 +407,8 @@ for ops in [wmma_double_ops], end # Name of the Julia wrapper function func_name = Symbol(join(filter(!isempty, ["llvm", "wmma", "mma", a_layout, b_layout, shape, a_elem_type, rnd]), "_")) - + func_name_no_round = Symbol(join(filter(!isempty, ["llvm", "wmma", "mma", a_layout, b_layout, shape, a_elem_type, rnd]), "_")) + # Determine types + size for the (matrix, elem_type) combinations for matrix A, B, C and D a_arr_ty, a_frag_ty, a_sz = get_frag_info("a", a_elem_type, shape) b_arr_ty, b_frag_ty, b_sz = get_frag_info("b", b_elem_type, shape) @@ -434,11 +435,24 @@ for ops in [wmma_double_ops], @eval @doc (@doc llvm_wmma_mma) $func_name end +# TODO, rewrite this as a macro + llvm_wmma_mma_col_col_m8n8k4_f64(a_frag, b_frag, c_frag, ::RoundingMode{:Nearest}) = llvm_wmma_mma_col_col_m8n8k4_f64_rn(a_frag, b_frag, c_frag) llvm_wmma_mma_col_col_m8n8k4_f64(a_frag, b_frag, c_frag, ::RoundingMode{:ToZero}) = llvm_wmma_mma_col_col_m8n8k4_f64_rz(a_frag, b_frag, c_frag) llvm_wmma_mma_col_col_m8n8k4_f64(a_frag, b_frag, c_frag, ::RoundingMode{:Up}) = llvm_wmma_mma_col_col_m8n8k4_f64_rp(a_frag, b_frag, c_frag) llvm_wmma_mma_col_col_m8n8k4_f64(a_frag, b_frag, c_frag, ::RoundingMode{:Down}) = llvm_wmma_mma_col_col_m8n8k4_f64_rm(a_frag, b_frag, c_frag) - +llvm_wmma_mma_row_col_m8n8k4_f64(a_frag, b_frag, c_frag, ::RoundingMode{:Nearest}) = llvm_wmma_mma_row_col_m8n8k4_f64_rn(a_frag, b_frag, c_frag) +llvm_wmma_mma_row_col_m8n8k4_f64(a_frag, b_frag, c_frag, ::RoundingMode{:ToZero}) = llvm_wmma_mma_row_col_m8n8k4_f64_rz(a_frag, b_frag, c_frag) +llvm_wmma_mma_row_col_m8n8k4_f64(a_frag, b_frag, c_frag, ::RoundingMode{:Up}) = llvm_wmma_mma_row_col_m8n8k4_f64_rp(a_frag, b_frag, c_frag) +llvm_wmma_mma_row_col_m8n8k4_f64(a_frag, b_frag, c_frag, ::RoundingMode{:Down}) = llvm_wmma_mma_row_col_m8n8k4_f64_rm(a_frag, b_frag, c_frag) +llvm_wmma_mma_col_row_m8n8k4_f64(a_frag, b_frag, c_frag, ::RoundingMode{:Nearest}) = llvm_wmma_mma_col_row_m8n8k4_f64_rn(a_frag, b_frag, c_frag) 
+llvm_wmma_mma_col_row_m8n8k4_f64(a_frag, b_frag, c_frag, ::RoundingMode{:ToZero}) = llvm_wmma_mma_col_row_m8n8k4_f64_rz(a_frag, b_frag, c_frag) +llvm_wmma_mma_col_row_m8n8k4_f64(a_frag, b_frag, c_frag, ::RoundingMode{:Up}) = llvm_wmma_mma_col_row_m8n8k4_f64_rp(a_frag, b_frag, c_frag) +llvm_wmma_mma_col_row_m8n8k4_f64(a_frag, b_frag, c_frag, ::RoundingMode{:Down}) = llvm_wmma_mma_col_row_m8n8k4_f64_rm(a_frag, b_frag, c_frag) +llvm_wmma_mma_row_row_m8n8k4_f64(a_frag, b_frag, c_frag, ::RoundingMode{:Nearest}) = llvm_wmma_mma_row_row_m8n8k4_f64_rn(a_frag, b_frag, c_frag) +llvm_wmma_mma_row_row_m8n8k4_f64(a_frag, b_frag, c_frag, ::RoundingMode{:ToZero}) = llvm_wmma_mma_row_row_m8n8k4_f64_rz(a_frag, b_frag, c_frag) +llvm_wmma_mma_row_row_m8n8k4_f64(a_frag, b_frag, c_frag, ::RoundingMode{:Up}) = llvm_wmma_mma_row_row_m8n8k4_f64_rp(a_frag, b_frag, c_frag) +llvm_wmma_mma_row_row_m8n8k4_f64(a_frag, b_frag, c_frag, ::RoundingMode{:Down}) = llvm_wmma_mma_row_row_m8n8k4_f64_rm(a_frag, b_frag, c_frag) # elseif d_elem_type == "f64" @@ -579,8 +593,8 @@ All WMMA operations take a `Config` as their final argument. # Examples ```jldoctest -config = WMMA.Config{16, 16, 16, Float64, RoundNearest} -CUDA.WMMA.Config{16, 16, 16, Float64, RoundingMode{:Nearest}()} +config = WMMA.Config{16, 16, 16, Float64} +CUDA.WMMA.Config{16, 16, 16, Float64} ``` """ struct Config{M, N, K, d_type} end @@ -694,7 +708,7 @@ for mat in ["a", "b", "c"] @eval @generated function $func_name(addr::LLVMPtr{T, AS}, stride::Number, layout::Type{L}, - config::Type{Config{M, N, K, D_TYPE, rounding}}) where {T, AS, L, M, N, K, D_TYPE, rounding} + config::Type{Config{M, N, K, D_TYPE}}) where {T, AS, L, M, N, K, D_TYPE} as_str = get_hl_as_info(AS) layout = get_hl_layout(L) @@ -742,7 +756,7 @@ mma @generated function mma(a::Fragment{M, N, K, A_SZ, A_T, A_L, MatrixA}, b::Fragment{M, N, K, B_SZ, B_T, B_L, MatrixB}, c::Fragment{M, N, K, C_SZ, C_T, Unspecified, Accumulator}, - config::Type{Config{M, N, K, D_T}}) where {M, N, K, A_SZ, A_T, A_L, B_SZ, B_T, B_L, C_SZ, C_T, D_T} + config::Type{Config{M, N, K, D_T}}; rounding = RoundNearest) where {M, N, K, A_SZ, A_T, A_L, B_SZ, B_T, B_L, C_SZ, C_T, D_T} a_layout = get_hl_layout(A_L) b_layout = get_hl_layout(B_L) From 4de16b188994d625295af3efff463d8a07ae6791 Mon Sep 17 00:00:00 2001 From: Isaia Nisoli Date: Tue, 10 Dec 2024 21:46:07 +0900 Subject: [PATCH 12/15] Removed TODO --- docs/src/tutorials/TODO.jl | 44 -------------------------------------- 1 file changed, 44 deletions(-) delete mode 100644 docs/src/tutorials/TODO.jl diff --git a/docs/src/tutorials/TODO.jl b/docs/src/tutorials/TODO.jl deleted file mode 100644 index 587ddff7b2..0000000000 --- a/docs/src/tutorials/TODO.jl +++ /dev/null @@ -1,44 +0,0 @@ -# https://github.com/JuliaGPU/CUDA.jl/pull/1426 - -function kernel_wmma_f64_lowlevel(a_dev, b_dev, c_dev, d_dev) - conf = WMMA.Config{8, 8, 4, Float64, RoundUp} - - # a_frag = WMMA.llvm_wmma_load_a_col_m8n8k4_global_stride_f64(pointer(a_dev), 8) - # b_frag = WMMA.llvm_wmma_load_b_col_m8n8k4_global_stride_f64(pointer(b_dev), 4) - # c_frag = WMMA.llvm_wmma_load_c_col_m8n8k4_global_stride_f64(pointer(c_dev), 8) - - a_frag = WMMA.load_a(pointer(a_dev), 8, ColMajor, conf) - b_frag = WMMA.load_b(pointer(b_dev), 4, ColMajor, conf) - c_frag = WMMA.load_b(pointer(c_dev), 8, ColMajor, conf) - - d_frag = WMMA.llvm_wmma_mma(a_frag, b_frag, c_frag, conf) - #d_frag = WMMA.llvm_wmma_mma_col_col_m8n8k4_f64(a_frag, b_frag, c_frag) - #d_frag = WMMA.llvm_wmma_mma_col_col_m8n8k4_f64(a_frag, b_frag, c_frag, RoundToZero) 
- #d_frag = WMMA.llvm_wmma_mma_col_col_m8n8k4_f64(a_frag, b_frag, c_frag, RoundUp) - #d_frag = WMMA.llvm_wmma_mma_col_col_m8n8k4_f64(a_frag, b_frag, c_frag, RoundDown) - #@cuprintln d_frag - WWMA.store_d(pointer(d_dev), d_frag, 8, ColMajor, conf) - #ccall("llvm.nvvm.wmma.m8n8k4.store.d.col.stride.f64", llvmcall, - # Nothing, (Core.LLVMPtr{Float64, 1}, Float64, Float64, Int32), - # pointer(d_dev), d_frag[1], d_frag[2], 8) - #WMMA.llvm_wmma_store_d_col_m8n8k4_global_stride_f64(pointer(d_dev), d_frag, 8) - return nothing -end - -function call_kernel() - m = n = 8 - k = 4 - dtype_a = dtype_b = Float64 - dtype_c = dtype_d = Float64 - - d_a = CUDA.rand(dtype_a, m, k) - d_b = CUDA.rand(dtype_b, k, n) - d_c = CUDA.rand(dtype_c, m, n) - d_d = CUDA.zeros(dtype_d, m, n) - - CUDA.@sync @cuda kernel_wmma_f64_lowlevel(d_a, d_b, d_c, d_d) - return nothing -end - -#https://github.com/llvm/llvm-project/blob/main/clang/test/CodeGen/builtins-nvptx-mma.cu - From 48c36d07a7913fb888b4c162f04e05a7690114dd Mon Sep 17 00:00:00 2001 From: Isaia Nisoli Date: Fri, 20 Dec 2024 09:22:50 -0300 Subject: [PATCH 13/15] Finished reverting Rounding in Config --- src/device/intrinsics/wmma.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/device/intrinsics/wmma.jl b/src/device/intrinsics/wmma.jl index 14b023b5cf..34755138bd 100644 --- a/src/device/intrinsics/wmma.jl +++ b/src/device/intrinsics/wmma.jl @@ -756,7 +756,7 @@ mma @generated function mma(a::Fragment{M, N, K, A_SZ, A_T, A_L, MatrixA}, b::Fragment{M, N, K, B_SZ, B_T, B_L, MatrixB}, c::Fragment{M, N, K, C_SZ, C_T, Unspecified, Accumulator}, - config::Type{Config{M, N, K, D_T}}; rounding = RoundNearest) where {M, N, K, A_SZ, A_T, A_L, B_SZ, B_T, B_L, C_SZ, C_T, D_T} + config::Type{Config{M, N, K, D_T}}) where {M, N, K, A_SZ, A_T, A_L, B_SZ, B_T, B_L, C_SZ, C_T, D_T} a_layout = get_hl_layout(A_L) b_layout = get_hl_layout(B_L) From 5f8fff5bcf419ded759c134fb1476b016a154b24 Mon Sep 17 00:00:00 2001 From: Isaia Nisoli Date: Fri, 20 Dec 2024 10:08:05 -0300 Subject: [PATCH 14/15] Moved tutorial to hacking section --- docs/make.jl | 3 +++ docs/src/{tutorials => hacking}/exposing_new_intrinsics.jl | 0 2 files changed, 3 insertions(+) rename docs/src/{tutorials => hacking}/exposing_new_intrinsics.jl (100%) diff --git a/docs/make.jl b/docs/make.jl index 3cbd62d523..b75e790825 100644 --- a/docs/make.jl +++ b/docs/make.jl @@ -60,6 +60,9 @@ function main() "development/troubleshooting.md", "development/debugging.md", ], + "Hacking" => Any[ + "hacking/exposing_new_intrinsics.md", + ], "API reference" => Any[ "api/essentials.md", "api/array.md", diff --git a/docs/src/tutorials/exposing_new_intrinsics.jl b/docs/src/hacking/exposing_new_intrinsics.jl similarity index 100% rename from docs/src/tutorials/exposing_new_intrinsics.jl rename to docs/src/hacking/exposing_new_intrinsics.jl From 3c2d721978e3382a43b92dadee4f15644c925b54 Mon Sep 17 00:00:00 2001 From: Isaia Nisoli Date: Fri, 20 Dec 2024 10:19:35 -0300 Subject: [PATCH 15/15] Revert wmma.jl --- src/device/intrinsics/wmma.jl | 134 ++++------------------------------ 1 file changed, 16 insertions(+), 118 deletions(-) diff --git a/src/device/intrinsics/wmma.jl b/src/device/intrinsics/wmma.jl index 34755138bd..f6da7d90c9 100644 --- a/src/device/intrinsics/wmma.jl +++ b/src/device/intrinsics/wmma.jl @@ -15,8 +15,7 @@ const map_ptx_to_jl_array = Dict( "s8" => Int8, "s32" => Int32, "f16" => Float16, - "f32" => Float32, - "f64" => Float64 + "f32" => Float32 ) # Maps PTX types to Julia fragment types @@ -25,8 +24,7 @@ 
const map_ptx_to_jl_frag = Dict( "s8" => UInt32, "s32" => Int32, "f16" => NTuple{2, VecElement{Float16}}, - "f32" => Float32, - "f64" => Float64 + "f32" => Float32 ) # Maps matrix & PTX types to fragment sizes, information retrieved from @@ -45,9 +43,6 @@ const map_frag_sizes = Dict( "a.f16.m16n16k16" => 8, "a.f16.m8n32k16" => 8, "a.f16.m32n8k16" => 8, - - "a.f64.m8n8k4" => 1, - # B "b.u8.m16n16k16" => 2, "b.u8.m8n32k16" => 4, @@ -60,9 +55,6 @@ const map_frag_sizes = Dict( "b.f16.m16n16k16" => 8, "b.f16.m8n32k16" => 8, "b.f16.m32n8k16" => 8, - - "b.f64.m8n8k4" => 1, - # C "c.s32.m16n16k16" => 8, "c.s32.m8n32k16" => 8, @@ -75,12 +67,6 @@ const map_frag_sizes = Dict( "c.f32.m16n16k16" => 8, "c.f32.m8n32k16" => 8, "c.f32.m32n8k16" => 8, - - "c.f64.m8n8k4" => 2, # there is a clash of documentation here: - # https://docs.nvidia.com/cuda/parallel-thread-execution/#matrix-fragments-for-mma-m8n8k4-with-f64-floating-point-type - # says `A vector expression containing of two .f64 registers containing two .f64 elements from the matrix C.` - # while https://docs.nvidia.com/cuda/parallel-thread-execution/#matrix-fragments-for-wmma says 1 - # D "d.s32.m16n16k16" => 8, "d.s32.m8n32k16" => 8, @@ -93,8 +79,6 @@ const map_frag_sizes = Dict( "d.f32.m16n16k16" => 8, "d.f32.m8n32k16" => 8, "d.f32.m32n8k16" => 8, - - "d.f64.m8n8k4" => 2, ) # Maps PTX AS to CUDA.AS @@ -114,19 +98,13 @@ const wmma_half_ops = [(16,16,16), (32,8,16), (8,32,16)], ["f16"], ["f16", "f const ldst_int_ab_ops = [(16,16,16), (32,8,16), (8,32,16)], ["a", "b"], ["u8", "s8"] const ldst_int_cd_ops = [(16,16,16), (32,8,16), (8,32,16)], ["c", "d"], ["s32"] const wmma_int_ops = [(16,16,16), (32,8,16), (8,32,16)], ["s8", "u8"], ["s32"], ["s32"] -# Double -const ldst_double_ab_ops = [(8, 8, 4)], ["a", "b"], ["f64"] -const ldst_double_cd_ops = [(8, 8, 4)], ["c", "d"], ["f64"] -const wmma_double_ops = [(8, 8, 4)], ["f64"], ["f64"], ["f64"] - -const all_ldst_ops = vcat(ldst_half_ab_ops, ldst_half_cd_ops, ldst_double_ab_ops, - ldst_int_ab_ops, ldst_int_cd_ops, ldst_double_cd_ops) - -# the wmma_double_ops will be treated separatedly due to rounding + +const all_ldst_ops = vcat(ldst_half_ab_ops, ldst_half_cd_ops, + ldst_int_ab_ops, ldst_int_cd_ops) const all_wmma_ops = vcat(wmma_half_ops, wmma_int_ops) # Valid WMMA operation shapes -const valid_shapes = [(16, 16, 16), (32, 8, 16), (8, 32, 16), (8, 8, 4)] +const valid_shapes = [(16, 16, 16), (32, 8, 16), (8, 32, 16)] ################################################################################ # HELPER FUNCTIONS @@ -280,7 +258,6 @@ export llvm_wmma_store func_name = Symbol(join(filter(!isempty, ["llvm", "wmma", "store", mat, layout, shape, addr_space, stride, elem_type]), "_")) # Name of the LLVM intrinsic - #llvm.nvvm.wmma.m8n8k4.store.d.col.stride.f64 llvm_intr = "llvm.nvvm.wmma.$shape.store.$mat.$layout.stride.$elem_type.p$(addr_space_int)" if LLVM.version() < v"17" llvm_intr *= "i8" @@ -288,13 +265,13 @@ export llvm_wmma_store # Determine types + size for this (matrix, elem_type) combination arr_ty, frag_ty, sz = get_frag_info(mat, elem_type, shape) - + ccall_name = "$llvm_intr" frag_types = ntuple(i -> frag_ty, sz) frag_vars = ntuple(i -> :(data[$i]), sz) - + ptr_ty = :(LLVMPtr{$arr_ty, $addr_space_int}) - + @eval $func_name(dst_addr, data, stride) = ccall($ccall_name, llvmcall, Nothing, ($ptr_ty, $(frag_types...), Int32), dst_addr, $(frag_vars...), stride) @eval export $func_name @eval @doc (@doc llvm_wmma_store) $func_name @@ -308,7 +285,6 @@ end 
     WMMA.llvm_wmma_mma_{a_layout}_{b_layout}_{shape}_{d_elem_type}_{c_elem_type}(a, b, c)
 or
     WMMA.llvm_wmma_mma_{a_layout}_{b_layout}_{shape}_{a_elem_type}(a, b, c)
-For double operations: wrapper around the LLVM intrinsic `@llvm.nvvm.wmma.mma.sync.{a_layout}.{b_layout}.{shape}.{rnd}.{d_elem_type}.{c_elem_type}`
 For floating point operations: wrapper around the LLVM intrinsic `@llvm.nvvm.wmma.mma.sync.{a_layout}.{b_layout}.{shape}.{d_elem_type}.{c_elem_type}`
 For all other operations: wrapper around the LLVM intrinsic `@llvm.nvvm.wmma.mma.sync.{a_layout}.{b_layout}.{shape}.{a_elem_type}`
 
@@ -372,59 +348,6 @@ for ops in all_wmma_ops,
     b_vars = ntuple(i -> :(b[$i]), b_sz)
     c_vars = ntuple(i -> :(c[$i]), c_sz)
 
-    if d_sz == 1
-        @eval $func_name(a, b, c) = tuple(ccall($ccall_name, llvmcall, $d_frag_ty, ($(a_types...), $(b_types...), $(c_types...)), $(a_vars...), $(b_vars...), $(c_vars...)))
-    else
-        struct_ty = Symbol("LLVMStruct$d_sz")
-        @eval $func_name(a, b, c) = convert(NTuple{$d_sz, $d_frag_ty}, ccall($ccall_name, llvmcall, $struct_ty{$d_frag_ty}, ($(a_types...), $(b_types...), $(c_types...)), $(a_vars...), $(b_vars...), $(c_vars...)))
-        @eval $func_name(a, b, c, ::RoundingMode{:Nearest}) = $func_name(a, b, c)
-        @eval $func_name(a, b, c, ::RoundingMode{:ToZero}) = $func_name(a, b, c)
-        @eval $func_name(a, b, c, ::RoundingMode{:Up}) = $func_name(a, b, c)
-        @eval $func_name(a, b, c, ::RoundingMode{:Down}) = $func_name(a, b, c)
-    end
-    @eval export $func_name
-    @eval @doc (@doc llvm_wmma_mma) $func_name
-end
-
-const wmma_double_rounding = ["", "rn", "rz", "rm", "rp"]
-
-for ops in [wmma_double_ops],
-    a_layout in ["col", "row"],
-    b_layout in ["col", "row"],
-    mnk in ops[1],
-    rnd in wmma_double_rounding
-
-    a_elem_type = "f64"
-    b_elem_type = "f64"
-    c_elem_type = "f64"
-    d_elem_type = "f64"
-
-    shape = get_hl_shape(mnk[1], mnk[2], mnk[3])
-
-    llvm_intr = "llvm.nvvm.wmma.$shape.mma.$a_layout.$b_layout.$rnd.f64"
-    if rnd == ""
-        llvm_intr = "llvm.nvvm.wmma.$shape.mma.$a_layout.$b_layout.f64"
-    end
-    # Name of the Julia wrapper function
-    func_name = Symbol(join(filter(!isempty, ["llvm", "wmma", "mma", a_layout, b_layout, shape, a_elem_type, rnd]), "_"))
-    func_name_no_round = Symbol(join(filter(!isempty, ["llvm", "wmma", "mma", a_layout, b_layout, shape, a_elem_type, rnd]), "_"))
-
-    # Determine types + size for the (matrix, elem_type) combinations for matrix A, B, C and D
-    a_arr_ty, a_frag_ty, a_sz = get_frag_info("a", a_elem_type, shape)
-    b_arr_ty, b_frag_ty, b_sz = get_frag_info("b", b_elem_type, shape)
-    c_arr_ty, c_frag_ty, c_sz = get_frag_info("c", c_elem_type, shape)
-    d_arr_ty, d_frag_ty, d_sz = get_frag_info("d", d_elem_type, shape)
-
-    ccall_name = "$llvm_intr"
-
-    a_types = ntuple(i -> a_frag_ty, a_sz)
-    b_types = ntuple(i -> b_frag_ty, b_sz)
-    c_types = ntuple(i -> c_frag_ty, c_sz)
-
-    a_vars = ntuple(i -> :(a[$i]), a_sz)
-    b_vars = ntuple(i -> :(b[$i]), b_sz)
-    c_vars = ntuple(i -> :(c[$i]), c_sz)
-
     if d_sz == 1
         @eval $func_name(a, b, c) = tuple(ccall($ccall_name, llvmcall, $d_frag_ty, ($(a_types...), $(b_types...), $(c_types...)), $(a_vars...), $(b_vars...), $(c_vars...)))
     else
@@ -435,33 +358,6 @@ for ops in [wmma_double_ops],
     @eval @doc (@doc llvm_wmma_mma) $func_name
 end
 
-# TODO, rewrite this as a macro
-
-llvm_wmma_mma_col_col_m8n8k4_f64(a_frag, b_frag, c_frag, ::RoundingMode{:Nearest}) = llvm_wmma_mma_col_col_m8n8k4_f64_rn(a_frag, b_frag, c_frag)
-llvm_wmma_mma_col_col_m8n8k4_f64(a_frag, b_frag, c_frag, ::RoundingMode{:ToZero}) = llvm_wmma_mma_col_col_m8n8k4_f64_rz(a_frag, b_frag, c_frag)
-llvm_wmma_mma_col_col_m8n8k4_f64(a_frag, b_frag, c_frag, ::RoundingMode{:Up}) = llvm_wmma_mma_col_col_m8n8k4_f64_rp(a_frag, b_frag, c_frag)
-llvm_wmma_mma_col_col_m8n8k4_f64(a_frag, b_frag, c_frag, ::RoundingMode{:Down}) = llvm_wmma_mma_col_col_m8n8k4_f64_rm(a_frag, b_frag, c_frag)
-llvm_wmma_mma_row_col_m8n8k4_f64(a_frag, b_frag, c_frag, ::RoundingMode{:Nearest}) = llvm_wmma_mma_row_col_m8n8k4_f64_rn(a_frag, b_frag, c_frag)
-llvm_wmma_mma_row_col_m8n8k4_f64(a_frag, b_frag, c_frag, ::RoundingMode{:ToZero}) = llvm_wmma_mma_row_col_m8n8k4_f64_rz(a_frag, b_frag, c_frag)
-llvm_wmma_mma_row_col_m8n8k4_f64(a_frag, b_frag, c_frag, ::RoundingMode{:Up}) = llvm_wmma_mma_row_col_m8n8k4_f64_rp(a_frag, b_frag, c_frag)
-llvm_wmma_mma_row_col_m8n8k4_f64(a_frag, b_frag, c_frag, ::RoundingMode{:Down}) = llvm_wmma_mma_row_col_m8n8k4_f64_rm(a_frag, b_frag, c_frag)
-llvm_wmma_mma_col_row_m8n8k4_f64(a_frag, b_frag, c_frag, ::RoundingMode{:Nearest}) = llvm_wmma_mma_col_row_m8n8k4_f64_rn(a_frag, b_frag, c_frag)
-llvm_wmma_mma_col_row_m8n8k4_f64(a_frag, b_frag, c_frag, ::RoundingMode{:ToZero}) = llvm_wmma_mma_col_row_m8n8k4_f64_rz(a_frag, b_frag, c_frag)
-llvm_wmma_mma_col_row_m8n8k4_f64(a_frag, b_frag, c_frag, ::RoundingMode{:Up}) = llvm_wmma_mma_col_row_m8n8k4_f64_rp(a_frag, b_frag, c_frag)
-llvm_wmma_mma_col_row_m8n8k4_f64(a_frag, b_frag, c_frag, ::RoundingMode{:Down}) = llvm_wmma_mma_col_row_m8n8k4_f64_rm(a_frag, b_frag, c_frag)
-llvm_wmma_mma_row_row_m8n8k4_f64(a_frag, b_frag, c_frag, ::RoundingMode{:Nearest}) = llvm_wmma_mma_row_row_m8n8k4_f64_rn(a_frag, b_frag, c_frag)
-llvm_wmma_mma_row_row_m8n8k4_f64(a_frag, b_frag, c_frag, ::RoundingMode{:ToZero}) = llvm_wmma_mma_row_row_m8n8k4_f64_rz(a_frag, b_frag, c_frag)
-llvm_wmma_mma_row_row_m8n8k4_f64(a_frag, b_frag, c_frag, ::RoundingMode{:Up}) = llvm_wmma_mma_row_row_m8n8k4_f64_rp(a_frag, b_frag, c_frag)
-llvm_wmma_mma_row_row_m8n8k4_f64(a_frag, b_frag, c_frag, ::RoundingMode{:Down}) = llvm_wmma_mma_row_row_m8n8k4_f64_rm(a_frag, b_frag, c_frag)
-
-
-# elseif d_elem_type == "f64"
-#     llvm_intr = "llvm.nvvm.wmma.$shape.mma.$a_layout.$b_layout.$rnd.f64.f64.f64.f64"
-#     # Name of the Julia wrapper function
-#     func_name = Symbol(join(filter(!isempty, ["llvm", "wmma", "mma", a_layout, b_layout, shape, a_elem_type, rnd]), "_"))
-
-
 ################################################################################
 # FLATTENING/UNFLATTENING LOGIC
 ################################################################################
@@ -587,17 +483,19 @@ Type that contains all information for WMMA operations that cannot be inferred f
 WMMA instructions calculate the matrix multiply-accumulate operation ``D = A \\cdot B + C``, where ``A`` is a ``M \\times K``
 matrix, ``B`` a ``K \\times N`` matrix, and ``C`` and ``D`` are ``M \\times N`` matrices.
 
-`d_type` refers to the type of the elements of matrix ``D``, and can be either `Float16`, `Float32` or `Float64`.
+`d_type` refers to the type of the elements of matrix ``D``, and can be either `Float16` or `Float32`.
 
 All WMMA operations take a `Config` as their final argument.
# Examples ```jldoctest -config = WMMA.Config{16, 16, 16, Float64} -CUDA.WMMA.Config{16, 16, 16, Float64} +julia> config = WMMA.Config{16, 16, 16, Float32} +CUDA.WMMA.Config{16, 16, 16, Float32} ``` """ -struct Config{M, N, K, d_type} end +struct ConfigRounding{M, N, K, d_type, rounding} end + +Config{M, N, K, d_type} = ConfigRounding{M, N, K, d_type, RoundNearest} # --------- # Constants @@ -777,7 +675,7 @@ mma b_unfl = unflatten(NTuple{$b_frag_sz, $b_frag_ty}, b.x) c_unfl = unflatten(NTuple{$c_frag_sz, $c_frag_ty}, c.x) - x = flatten($wrapper(a_unfl, b_unfl, c_unfl, rounding)) + x = flatten($wrapper(a_unfl, b_unfl, c_unfl)) return Fragment{$M, $N, $K, $d_num_els, $D_T, Unspecified, Accumulator}(x) end end
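
# For reference, the test kernel deleted by the hunks above drove the generated f64 wrappers
# directly. The following is a minimal sketch of that use, reconstructed from the removed code:
# the `llvm_wmma_load_*` names are assumptions derived from the
# `llvm_wmma_{op}_{matrix}_{layout}_{shape}_{addrspace}_stride_{elem_type}` naming scheme used by
# the generator, and the trailing `RoundUp` argument relies on the reverted `RoundingMode` overloads.

using CUDA

function kernel_wmma_f64_rounded(a_dev, b_dev, c_dev, d_dev)
    # load the A (8x4), B (4x8) and C (8x8) tiles; the strides are the column-major leading dimensions
    a_frag = WMMA.llvm_wmma_load_a_col_m8n8k4_global_stride_f64(pointer(a_dev), 8)
    b_frag = WMMA.llvm_wmma_load_b_col_m8n8k4_global_stride_f64(pointer(b_dev), 4)
    c_frag = WMMA.llvm_wmma_load_c_col_m8n8k4_global_stride_f64(pointer(c_dev), 8)
    # the fourth argument selected the rounding variant (_rn, _rz, _rm or _rp) by dispatch
    d_frag = WMMA.llvm_wmma_mma_col_col_m8n8k4_f64(a_frag, b_frag, c_frag, RoundUp)
    WMMA.llvm_wmma_store_d_col_m8n8k4_global_stride_f64(pointer(d_dev), d_frag, 8)
    return nothing
end

# an m8n8k4 fragment is handled cooperatively by a full warp, so launch 32 threads
d_a = CUDA.rand(Float64, 8, 4)
d_b = CUDA.rand(Float64, 4, 8)
d_c = CUDA.rand(Float64, 8, 8)
d_d = CUDA.zeros(Float64, 8, 8)
CUDA.@sync @cuda threads = 32 kernel_wmma_f64_rounded(d_a, d_b, d_c, d_d)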
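
# With the revert in place, the high-level interface behaves as documented above: `Config` aliases
# `ConfigRounding` with `RoundNearest` fixed, so `WMMA.mma` takes no rounding argument. Below is a
# minimal sketch of the standard usage with Float16 inputs and a Float32 accumulator, using the
# stock `WMMA.load_a`/`load_b`/`load_c`, `WMMA.mma` and `WMMA.store_d` entry points:

using CUDA

function kernel_wmma_f32(a_dev, b_dev, c_dev, d_dev)
    conf = WMMA.Config{16, 16, 16, Float32}
    a_frag = WMMA.load_a(pointer(a_dev), 16, WMMA.ColMajor, conf)
    b_frag = WMMA.load_b(pointer(b_dev), 16, WMMA.ColMajor, conf)
    c_frag = WMMA.load_c(pointer(c_dev), 16, WMMA.ColMajor, conf)
    # always rounds to nearest after the revert; no rounding argument remains
    d_frag = WMMA.mma(a_frag, b_frag, c_frag, conf)
    WMMA.store_d(pointer(d_dev), d_frag, 16, WMMA.ColMajor, conf)
    return nothing
end

d_a = CUDA.rand(Float16, 16, 16)
d_b = CUDA.rand(Float16, 16, 16)
d_c = CUDA.rand(Float32, 16, 16)
d_d = CUDA.zeros(Float32, 16, 16)
CUDA.@sync @cuda threads = 32 kernel_wmma_f32(d_a, d_b, d_c, d_d)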