From 449c0371dfdf1b052ab54a3f29d0311df9954670 Mon Sep 17 00:00:00 2001 From: Carsten Bauer Date: Tue, 1 Mar 2022 21:08:56 +0100 Subject: [PATCH 1/7] first naive attempt --- src/device/intrinsics/wmma.jl | 21 ++++++++++++++++++--- 1 file changed, 18 insertions(+), 3 deletions(-) diff --git a/src/device/intrinsics/wmma.jl b/src/device/intrinsics/wmma.jl index c12fe526ca..c9089fc513 100644 --- a/src/device/intrinsics/wmma.jl +++ b/src/device/intrinsics/wmma.jl @@ -14,6 +14,7 @@ const map_ptx_to_jl_array = Dict( "s8" => Int8, "s32" => Int32, "f16" => Float16, + "tf32" => Float32, "f32" => Float32 ) @@ -23,6 +24,7 @@ const map_ptx_to_jl_frag = Dict( "s8" => UInt32, "s32" => Int32, "f16" => NTuple{2, VecElement{Float16}}, + "tf32" => Float32, "f32" => Float32 ) @@ -40,6 +42,8 @@ const map_frag_sizes = Dict( "a.f16.m16n16k16" => 8, "a.f16.m8n32k16" => 8, "a.f16.m32n8k16" => 8, + + "a.tf32.m16n16k8" => 8, # B "b.u8.m16n16k16" => 2, "b.u8.m8n32k16" => 4, @@ -52,6 +56,8 @@ const map_frag_sizes = Dict( "b.f16.m16n16k16" => 8, "b.f16.m8n32k16" => 8, "b.f16.m32n8k16" => 8, + + "b.tf32.m16n16k8" => 8, # C "c.s32.m16n16k16" => 8, "c.s32.m8n32k16" => 8, @@ -64,6 +70,8 @@ const map_frag_sizes = Dict( "c.f32.m16n16k16" => 8, "c.f32.m8n32k16" => 8, "c.f32.m32n8k16" => 8, + + "c.f32.m16n16k8" => 8, # D "d.s32.m16n16k16" => 8, "d.s32.m8n32k16" => 8, @@ -76,6 +84,8 @@ const map_frag_sizes = Dict( "d.f32.m16n16k16" => 8, "d.f32.m8n32k16" => 8, "d.f32.m32n8k16" => 8, + + "d.f32.m16n16k8" => 8, ) # Maps PTX AS to CUDA.AS @@ -87,6 +97,10 @@ const map_ptx_as_to_as_ty = Dict( # Valid WMMA Operation configurations: Shape (M,N,K), Matrix, Element Type +# TF32-Precision Floating Point +const ldst_tf32_ab_ops = [(16,16,8)], ["a", "b"], ["tf32"] +const ldst_tf32_cd_ops = [(16,16,8)], ["c", "d"], ["f32"] +const wmma_tf32_ops = [(16,16,8)], ["tf32"], ["f32"], ["f32"] # Half-Precision Floating Point const ldst_half_ab_ops = [(16,16,16), (32,8,16), (8,32,16)], ["a", "b"], ["f16"] const 
ldst_half_cd_ops = [(16,16,16), (32,8,16), (8,32,16)], ["c", "d"], ["f16", "f32"] @@ -97,11 +111,12 @@ const ldst_int_cd_ops = [(16,16,16), (32,8,16), (8,32,16)], ["c", "d"], ["s32"] const wmma_int_ops = [(16,16,16), (32,8,16), (8,32,16)], ["s8", "u8"], ["s32"], ["s32"] const all_ldst_ops = vcat(ldst_half_ab_ops, ldst_half_cd_ops, - ldst_int_ab_ops, ldst_int_cd_ops) -const all_wmma_ops = vcat(wmma_half_ops, wmma_int_ops) + ldst_int_ab_ops, ldst_int_cd_ops, + ldst_tf32_ab_ops, ldst_tf32_cd_ops) +const all_wmma_ops = vcat(wmma_half_ops, wmma_int_ops, wmma_tf32_ops) # Valid WMMA operation shapes -const valid_shapes = [(16, 16, 16), (32, 8, 16), (8, 32, 16)] +const valid_shapes = [(16, 16, 16), (32, 8, 16), (8, 32, 16), (16,16,8)] ################################################################################ # HELPER FUNCTIONS From afea81642271ccabb80009c96b1f5a9b353501a2 Mon Sep 17 00:00:00 2001 From: Carsten Bauer Date: Wed, 2 Mar 2022 18:38:15 +0100 Subject: [PATCH 2/7] fix tf32 mma llvm intrinsic --- src/device/intrinsics/wmma.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/device/intrinsics/wmma.jl b/src/device/intrinsics/wmma.jl index c9089fc513..ecc08eab92 100644 --- a/src/device/intrinsics/wmma.jl +++ b/src/device/intrinsics/wmma.jl @@ -324,7 +324,7 @@ for ops in all_wmma_ops, # Name of the LLVM intrinsic # If integer/sub-byte/bit A/B types, name is determined by A/B types - if d_elem_type == "s32" + if d_elem_type == "s32" || a_elem_type == "tf32" llvm_intr = "llvm.nvvm.wmma.$shape.mma.$a_layout.$b_layout.$a_elem_type" # Name of the Julia wrapper function func_name = Symbol(join(filter(!isempty, ["llvm", "wmma", "mma", a_layout, b_layout, shape, a_elem_type]), "_")) From bd7e2e0c7f467526e8cde9bf31d93058ee3278bd Mon Sep 17 00:00:00 2001 From: Carsten Bauer Date: Fri, 4 Mar 2022 16:12:02 +0100 Subject: [PATCH 3/7] fix wmma tf32 fragment sizes for a and b --- src/device/intrinsics/wmma.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 
deletions(-) diff --git a/src/device/intrinsics/wmma.jl b/src/device/intrinsics/wmma.jl index ecc08eab92..66e5294e9b 100644 --- a/src/device/intrinsics/wmma.jl +++ b/src/device/intrinsics/wmma.jl @@ -43,7 +43,7 @@ const map_frag_sizes = Dict( "a.f16.m8n32k16" => 8, "a.f16.m32n8k16" => 8, - "a.tf32.m16n16k8" => 8, + "a.tf32.m16n16k8" => 4, # B "b.u8.m16n16k16" => 2, "b.u8.m8n32k16" => 4, @@ -57,7 +57,7 @@ const map_frag_sizes = Dict( "b.f16.m8n32k16" => 8, "b.f16.m32n8k16" => 8, - "b.tf32.m16n16k8" => 8, + "b.tf32.m16n16k8" => 4, # C "c.s32.m16n16k16" => 8, "c.s32.m8n32k16" => 8, From 17af4d1aa44fdf919d3cac1e6656941190dd1fb5 Mon Sep 17 00:00:00 2001 From: Carsten Bauer Date: Mon, 13 Jun 2022 23:39:57 +0200 Subject: [PATCH 4/7] wmma tf32 tests --- test/device/intrinsics/wmma.jl | 21 +++++++++++---------- 1 file changed, 11 insertions(+), 10 deletions(-) diff --git a/test/device/intrinsics/wmma.jl b/test/device/intrinsics/wmma.jl index e0d47e34ca..9ca214d5fd 100644 --- a/test/device/intrinsics/wmma.jl +++ b/test/device/intrinsics/wmma.jl @@ -6,8 +6,9 @@ map_ptx_to_jl_frag = Dict( "u32" => UInt32(42), "s32" => Int32(42), "f16" => ntuple(i -> VecElement{Float16}(42), 2), - "f32" => Float32(42) - ) + "f32" => Float32(42), + "tf32" => Float32(42) + ) # Return specific matrix shape given operation configuration function get_array_shape(mat, mnk, layout) if !(mat in ["a","b","c","d"]) @@ -46,13 +47,13 @@ end # Type-dependent variables array_ty = CUDA.WMMA.map_ptx_to_jl_array[elem_type] expected = map_ptx_to_jl_frag[elem_type] - + # Address-space dependent variables do_shared_test = (addr_space == "_shared") # Get the function name func = Symbol("llvm_wmma_load_$(mat)_$(layout)_$(shape)$(addr_space)_stride_$(elem_type)") - + input_shape = get_array_shape(mat, mnk, layout) input = array_ty(42) * ones(array_ty, input_shape) input_dev = CuArray(input) @@ -96,7 +97,7 @@ end elem_type in ops[3], addr_space in ["", "_global", "_shared"], stride in ["stride"] - + # Skip all but d 
matrices if mat != "d" continue @@ -169,9 +170,9 @@ end ldc_func = getfield(Main, Symbol("llvm_wmma_load_c_col_$(shape)_global_stride_$(c_elem_type)")) # Account for half and int/subint mma different naming conventions # Int/subint mma functions are distinguished by the a/b element type - mma_sym = d_ty == Int32 ? Symbol("llvm_wmma_mma_$(a_layout)_$(b_layout)_$(shape)_$(ab_elem_type)") : + mma_sym = (d_ty == Int32 || ab_elem_type == "tf32") ? Symbol("llvm_wmma_mma_$(a_layout)_$(b_layout)_$(shape)_$(ab_elem_type)") : Symbol("llvm_wmma_mma_$(a_layout)_$(b_layout)_$(shape)_$(d_elem_type)_$(c_elem_type)") - mma_func = getfield(Main, mma_sym) + mma_func = getfield(Main, mma_sym) std_func = getfield(Main, Symbol("llvm_wmma_store_d_col_$(shape)_global_stride_$(d_elem_type)")) a_shape = get_array_shape("a", mnk, a_layout) @@ -205,9 +206,9 @@ end new_a = (a_layout == "col" ? a : transpose(a)) new_b = (b_layout == "col" ? b : transpose(b)) # Alter test depending on a/b element Type - if ab_ty == Float16 + if ab_ty == Float16 || ab_elem_type == "tf32" @test new_a * new_b + c ≈ Array(d_dev) rtol=Base.rtoldefault(Float16) - else # Cast a and b to prevent UInt8 rollover of resultant data + else # Cast a and b to prevent UInt8 rollover of resultant data @test Int32.(new_a) * Int32.(new_b) + c == Array(d_dev) end end @@ -344,4 +345,4 @@ end @test !occursin(r"wmma.store.d.sync(.aligned)?.col.m16n16k16.f32", ptx) @test occursin(r"wmma.store.d.sync(.aligned)?.col.m16n16k16.shared.f32", ptx) end -end \ No newline at end of file +end From f8c5583c8b3dd14f8a776ca6f760bd3fa3368b8b Mon Sep 17 00:00:00 2001 From: Carsten Bauer Date: Wed, 17 Aug 2022 17:36:04 +0200 Subject: [PATCH 5/7] exclude TF32 tests for Julia <= 1.7 --- test/device/intrinsics/wmma.jl | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/test/device/intrinsics/wmma.jl b/test/device/intrinsics/wmma.jl index 9ca214d5fd..4dfb149bc1 100644 --- a/test/device/intrinsics/wmma.jl +++ 
b/test/device/intrinsics/wmma.jl @@ -42,6 +42,11 @@ end continue end + if mnk == (16,16,8) && VERSION <= v"1.7" + # TensorFloat-32 tests require at least Julia 1.8 + continue + end + shape = CUDA.WMMA.get_hl_shape(mnk[1], mnk[2], mnk[3]) # Type-dependent variables @@ -103,6 +108,11 @@ end continue end + if mnk == (16,16,8) && VERSION <= v"1.7" + # TensorFloat-32 tests require at least Julia 1.8 + continue + end + shape = CUDA.WMMA.get_hl_shape(mnk[1], mnk[2], mnk[3]) # Type-dependent variables @@ -157,6 +167,11 @@ end d_elem_type in ops[4], c_elem_type in ops[3] + if mnk == (16,16,8) && VERSION <= v"1.7" + # TensorFloat-32 tests require at least Julia 1.8 + continue + end + # Type-dependent variables d_ty = CUDA.WMMA.map_ptx_to_jl_array[d_elem_type] c_ty = CUDA.WMMA.map_ptx_to_jl_array[c_elem_type] From dd5a187857bac1269d8b631adf8d1c1910059bc7 Mon Sep 17 00:00:00 2001 From: Carsten Bauer Date: Wed, 17 Aug 2022 17:58:52 +0200 Subject: [PATCH 6/7] add tf32 to docstrings --- src/device/intrinsics/wmma.jl | 316 ++++++++++++++++------------------ 1 file changed, 146 insertions(+), 170 deletions(-) diff --git a/src/device/intrinsics/wmma.jl b/src/device/intrinsics/wmma.jl index 66e5294e9b..a1632e80aa 100644 --- a/src/device/intrinsics/wmma.jl +++ b/src/device/intrinsics/wmma.jl @@ -10,113 +10,89 @@ using Core: LLVMPtr # Maps PTX types to Julia array types const map_ptx_to_jl_array = Dict( - "u8" => UInt8, - "s8" => Int8, - "s32" => Int32, - "f16" => Float16, - "tf32" => Float32, - "f32" => Float32 - ) + "u8" => UInt8, + "s8" => Int8, + "s32" => Int32, + "f16" => Float16, + "tf32" => Float32, + "f32" => Float32 +) # Maps PTX types to Julia fragment types const map_ptx_to_jl_frag = Dict( - "u8" => UInt32, - "s8" => UInt32, - "s32" => Int32, - "f16" => NTuple{2, VecElement{Float16}}, - "tf32" => Float32, - "f32" => Float32 - ) + "u8" => UInt32, + "s8" => UInt32, + "s32" => Int32, + "f16" => NTuple{2,VecElement{Float16}}, + "tf32" => Float32, + "f32" => Float32 +) # Maps matrix & 
PTX types to fragment sizes const map_frag_sizes = Dict( - # A - "a.u8.m16n16k16" => 2, - "a.u8.m8n32k16" => 1, - "a.u8.m32n8k16" => 4, - - "a.s8.m16n16k16" => 2, - "a.s8.m8n32k16" => 1, - "a.s8.m32n8k16" => 4, - - "a.f16.m16n16k16" => 8, - "a.f16.m8n32k16" => 8, - "a.f16.m32n8k16" => 8, - - "a.tf32.m16n16k8" => 4, - # B - "b.u8.m16n16k16" => 2, - "b.u8.m8n32k16" => 4, - "b.u8.m32n8k16" => 1, - - "b.s8.m16n16k16" => 2, - "b.s8.m8n32k16" => 4, - "b.s8.m32n8k16" => 1, - - "b.f16.m16n16k16" => 8, - "b.f16.m8n32k16" => 8, - "b.f16.m32n8k16" => 8, - - "b.tf32.m16n16k8" => 4, - # C - "c.s32.m16n16k16" => 8, - "c.s32.m8n32k16" => 8, - "c.s32.m32n8k16" => 8, - - "c.f16.m16n16k16" => 4, - "c.f16.m8n32k16" => 4, - "c.f16.m32n8k16" => 4, - - "c.f32.m16n16k16" => 8, - "c.f32.m8n32k16" => 8, - "c.f32.m32n8k16" => 8, - - "c.f32.m16n16k8" => 8, - # D - "d.s32.m16n16k16" => 8, - "d.s32.m8n32k16" => 8, - "d.s32.m32n8k16" => 8, - - "d.f16.m16n16k16" => 4, - "d.f16.m8n32k16" => 4, - "d.f16.m32n8k16" => 4, - - "d.f32.m16n16k16" => 8, - "d.f32.m8n32k16" => 8, - "d.f32.m32n8k16" => 8, - - "d.f32.m16n16k8" => 8, - ) + # A + "a.u8.m16n16k16" => 2, + "a.u8.m8n32k16" => 1, + "a.u8.m32n8k16" => 4, "a.s8.m16n16k16" => 2, + "a.s8.m8n32k16" => 1, + "a.s8.m32n8k16" => 4, "a.f16.m16n16k16" => 8, + "a.f16.m8n32k16" => 8, + "a.f16.m32n8k16" => 8, "a.tf32.m16n16k8" => 4, + # B + "b.u8.m16n16k16" => 2, + "b.u8.m8n32k16" => 4, + "b.u8.m32n8k16" => 1, "b.s8.m16n16k16" => 2, + "b.s8.m8n32k16" => 4, + "b.s8.m32n8k16" => 1, "b.f16.m16n16k16" => 8, + "b.f16.m8n32k16" => 8, + "b.f16.m32n8k16" => 8, "b.tf32.m16n16k8" => 4, + # C + "c.s32.m16n16k16" => 8, + "c.s32.m8n32k16" => 8, + "c.s32.m32n8k16" => 8, "c.f16.m16n16k16" => 4, + "c.f16.m8n32k16" => 4, + "c.f16.m32n8k16" => 4, "c.f32.m16n16k16" => 8, + "c.f32.m8n32k16" => 8, + "c.f32.m32n8k16" => 8, "c.f32.m16n16k8" => 8, + # D + "d.s32.m16n16k16" => 8, + "d.s32.m8n32k16" => 8, + "d.s32.m32n8k16" => 8, "d.f16.m16n16k16" => 4, + "d.f16.m8n32k16" => 4, + 
"d.f16.m32n8k16" => 4, "d.f32.m16n16k16" => 8, + "d.f32.m8n32k16" => 8, + "d.f32.m32n8k16" => 8, "d.f32.m16n16k8" => 8, +) # Maps PTX AS to CUDA.AS const map_ptx_as_to_as_ty = Dict( - "" => AS.Generic, - "shared" => AS.Shared, - "global" => AS.Global - ) + "" => AS.Generic, + "shared" => AS.Shared, + "global" => AS.Global +) # Valid WMMA Operation configurations: Shape (M,N,K), Matrix, Element Type # TF32-Precision Floating Point -const ldst_tf32_ab_ops = [(16,16,8)], ["a", "b"], ["tf32"] -const ldst_tf32_cd_ops = [(16,16,8)], ["c", "d"], ["f32"] -const wmma_tf32_ops = [(16,16,8)], ["tf32"], ["f32"], ["f32"] +const ldst_tf32_ab_ops = [(16, 16, 8)], ["a", "b"], ["tf32"] +const ldst_tf32_cd_ops = [(16, 16, 8)], ["c", "d"], ["f32"] +const wmma_tf32_ops = [(16, 16, 8)], ["tf32"], ["f32"], ["f32"] # Half-Precision Floating Point -const ldst_half_ab_ops = [(16,16,16), (32,8,16), (8,32,16)], ["a", "b"], ["f16"] -const ldst_half_cd_ops = [(16,16,16), (32,8,16), (8,32,16)], ["c", "d"], ["f16", "f32"] -const wmma_half_ops = [(16,16,16), (32,8,16), (8,32,16)], ["f16"], ["f16", "f32"], ["f16", "f32"] +const ldst_half_ab_ops = [(16, 16, 16), (32, 8, 16), (8, 32, 16)], ["a", "b"], ["f16"] +const ldst_half_cd_ops = [(16, 16, 16), (32, 8, 16), (8, 32, 16)], ["c", "d"], ["f16", "f32"] +const wmma_half_ops = [(16, 16, 16), (32, 8, 16), (8, 32, 16)], ["f16"], ["f16", "f32"], ["f16", "f32"] # Integer -const ldst_int_ab_ops = [(16,16,16), (32,8,16), (8,32,16)], ["a", "b"], ["u8", "s8"] -const ldst_int_cd_ops = [(16,16,16), (32,8,16), (8,32,16)], ["c", "d"], ["s32"] -const wmma_int_ops = [(16,16,16), (32,8,16), (8,32,16)], ["s8", "u8"], ["s32"], ["s32"] +const ldst_int_ab_ops = [(16, 16, 16), (32, 8, 16), (8, 32, 16)], ["a", "b"], ["u8", "s8"] +const ldst_int_cd_ops = [(16, 16, 16), (32, 8, 16), (8, 32, 16)], ["c", "d"], ["s32"] +const wmma_int_ops = [(16, 16, 16), (32, 8, 16), (8, 32, 16)], ["s8", "u8"], ["s32"], ["s32"] const all_ldst_ops = vcat(ldst_half_ab_ops, ldst_half_cd_ops, - 
ldst_int_ab_ops, ldst_int_cd_ops, - ldst_tf32_ab_ops, ldst_tf32_cd_ops) + ldst_int_ab_ops, ldst_int_cd_ops, + ldst_tf32_ab_ops, ldst_tf32_cd_ops) const all_wmma_ops = vcat(wmma_half_ops, wmma_int_ops, wmma_tf32_ops) # Valid WMMA operation shapes -const valid_shapes = [(16, 16, 16), (32, 8, 16), (8, 32, 16), (16,16,8)] +const valid_shapes = [(16, 16, 16), (32, 8, 16), (8, 32, 16), (16, 16, 8)] ################################################################################ # HELPER FUNCTIONS @@ -132,10 +108,10 @@ end # Returns (Julia array type, Julia fragment type, fragment size) get_frag_info(matrix, ptx_el_type, shape) = ( - map_ptx_to_jl_array[ptx_el_type], - map_ptx_to_jl_frag[ptx_el_type], - map_frag_sizes["$matrix.$ptx_el_type.$shape"] - ) + map_ptx_to_jl_array[ptx_el_type], + map_ptx_to_jl_frag[ptx_el_type], + map_frag_sizes["$matrix.$ptx_el_type.$shape"] +) get_addrspace_info(addr_space) = convert(Int, map_ptx_as_to_as_ty[addr_space]) @@ -152,7 +128,7 @@ for N in unique(values(map_frag_sizes)) Base.Cartesian.@nexprs $N i -> x_i::T end - @eval Base.convert(::Type{NTuple{$N, T}}, x::$struct_ty{T}) where {T} = ntuple(i -> getfield(x, i), $N) + @eval Base.convert(::Type{NTuple{$N,T}}, x::$struct_ty{T}) where {T} = ntuple(i -> getfield(x, i), $N) end ################################################################################ @@ -175,10 +151,10 @@ Wrapper around the LLVM intrinsic `@llvm.nvvm.wmma.load.{matrix}.sync.{layout}.{ # Placeholders - `{matrix}`: The matrix to load. Can be `a`, `b` or `c`. - `{layout}`: The storage layout for the matrix. Can be `row` or `col`, for row major (C style) or column major (Julia style), respectively. -- `{shape}`: The overall shape of the MAC operation. Valid values are `m16n16k16`, `m32n8k16`, and `m8n32k16`. +- `{shape}`: The overall shape of the MAC operation. Valid values are `m16n16k8`, `m16n16k16`, `m32n8k16`, and `m8n32k16`. - `{addr_space}`: The address space of `src_addr`. 
Can be empty (generic addressing), `shared` or `global`. - `{elem_type}`: The type of each element in the matrix. For `a` and `b` matrices, valid values are `u8` (byte unsigned integer), - `s8` (byte signed integer), and `f16` (half precision floating point). For `c` and `d` matrices, valid values are + `s8` (byte signed integer), `f16` (half precision floating point), and `tf32` (TensorFloat-32). For `c` and `d` matrices, valid values are `s32` (32-bit signed integer), `f16` (half precision floating point), and `f32` (full precision floating point). """ llvm_wmma_load() = error("Cannot call llvm_wmma_load without values for placeholders!") @@ -208,10 +184,10 @@ for ops in all_ldst_ops, ccall_name = "extern $llvm_intr" - ptr_ty = LLVMPtr{arr_ty, addr_space_int} + ptr_ty = LLVMPtr{arr_ty,addr_space_int} struct_ty = Symbol("LLVMStruct$sz") - @eval $func_name(src_addr, stride) = convert(NTuple{$sz, $frag_ty}, ccall($ccall_name, llvmcall, $struct_ty{$frag_ty}, ($ptr_ty, Int32), src_addr, stride)) + @eval $func_name(src_addr, stride) = convert(NTuple{$sz,$frag_ty}, ccall($ccall_name, llvmcall, $struct_ty{$frag_ty}, ($ptr_ty, Int32), src_addr, stride)) @eval export $func_name @eval @doc (@doc llvm_wmma_load) $func_name end @@ -232,22 +208,22 @@ Wrapper around the LLVM intrinsic `@llvm.nvvm.wmma.store.d.sync.{layout}.{shape} # Placeholders - `{layout}`: The storage layout for the matrix. Can be `row` or `col`, for row major (C style) or column major (Julia style), respectively. -- `{shape}`: The overall shape of the MAC operation. Valid values are `m16n16k16`, `m32n8k16`, and `m8n32k16`. +- `{shape}`: The overall shape of the MAC operation. Valid values are `m16n16k8`, `m16n16k16`, `m32n8k16`, and `m8n32k16`. - `{addr_space}`: The address space of `src_addr`. Can be empty (generic addressing), `shared` or `global`. - `{elem_type}`: The type of each element in the matrix. 
For `a` and `b` matrices, valid values are `u8` (byte unsigned integer), - `s8` (byte signed integer), and `f16` (half precision floating point). For `c` and `d` matrices, valid values are + `s8` (byte signed integer), `f16` (half precision floating point), and `tf32` (TensorFloat-32). For `c` and `d` matrices, valid values are `s32` (32-bit signed integer), `f16` (half precision floating point), and `f32` (full precision floating point). """ llvm_wmma_store() = error("Cannot call llvm_wmma_store without values for placeholders!") export llvm_wmma_store - for ops in all_ldst_ops, - mnk in ops[1], - mat in ops[2], - elem_type in ops[3], - layout in ["col", "row"], - addr_space in ["", "shared", "global"], - stride in ["stride"] +for ops in all_ldst_ops, + mnk in ops[1], + mat in ops[2], + elem_type in ops[3], + layout in ["col", "row"], + addr_space in ["", "shared", "global"], + stride in ["stride"] if mat != "d" continue @@ -272,7 +248,7 @@ export llvm_wmma_store frag_types = ntuple(i -> frag_ty, sz) frag_vars = ntuple(i -> :(data[$i]), sz) - ptr_ty = LLVMPtr{arr_ty, addr_space_int} + ptr_ty = LLVMPtr{arr_ty,addr_space_int} @eval $func_name(dst_addr, data, stride) = ccall($ccall_name, llvmcall, Nothing, ($ptr_ty, $(frag_types...), Int32), dst_addr, $(frag_vars...), stride) @eval export $func_name @@ -298,8 +274,8 @@ For all other operations: wrapper around the LLVM intrinsic `@llvm.nvvm.wmma.mma # Placeholders - `{a_layout}`: The storage layout for matrix ``A``. Can be `row` or `col`, for row major (C style) or column major (Julia style), respectively. Note that this must match the layout used in the load operation. - `{b_layout}`: The storage layout for matrix ``B``. Can be `row` or `col`, for row major (C style) or column major (Julia style), respectively. Note that this must match the layout used in the load operation. -- `{shape}`: The overall shape of the MAC operation. Valid values are `m16n16k16`, `m32n8k16`, and `m8n32k16`. 
-- `{a_elem_type}`: The type of each element in the ``A`` matrix. Valid values are `u8` (byte unsigned integer), `s8` (byte signed integer), and `f16` (half precision floating point). +- `{shape}`: The overall shape of the MAC operation. Valid values are `m16n16k8`, `m16n16k16`, `m32n8k16`, and `m8n32k16`. +- `{a_elem_type}`: The type of each element in the ``A`` matrix. Valid values are `u8` (byte unsigned integer), `s8` (byte signed integer), `f16` (half precision floating point), and `tf32` (TensorFloat-32). - `{d_elem_type}`: The type of each element in the resultant ``D`` matrix. Valid values are `s32` (32-bit signed integer), `f16` (half precision floating point), and `f32` (full precision floating point). - `{c_elem_type}`: The type of each element in the ``C`` matrix. Valid values are `s32` (32-bit signed integer), `f16` (half precision floating point), and `f32` (full precision floating point). @@ -352,7 +328,7 @@ for ops in all_wmma_ops, struct_ty = Symbol("LLVMStruct$d_sz") - @eval $func_name(a, b, c) = convert(NTuple{$d_sz, $d_frag_ty}, ccall($ccall_name, llvmcall, $struct_ty{$d_frag_ty}, ($(a_types...), $(b_types...), $(c_types...)), $(a_vars...), $(b_vars...), $(c_vars...))) + @eval $func_name(a, b, c) = convert(NTuple{$d_sz,$d_frag_ty}, ccall($ccall_name, llvmcall, $struct_ty{$d_frag_ty}, ($(a_types...), $(b_types...), $(c_types...)), $(a_vars...), $(b_vars...), $(c_vars...))) @eval export $func_name @eval @doc (@doc llvm_wmma_mma) $func_name end @@ -366,11 +342,11 @@ flatten_recurse(typ, e) = [:($e)] unflatten_recurse(typ, e, idx) = :($e[$idx]), idx + 1 # VecElements -flatten_recurse(typ::Type{VecElement{T}}, e) where T = [:($e.value)] -unflatten_recurse(typ::Type{VecElement{T}}, e, idx) where T = :(VecElement{$T}($e[$idx])), idx + 1 +flatten_recurse(typ::Type{VecElement{T}}, e) where {T} = [:($e.value)] +unflatten_recurse(typ::Type{VecElement{T}}, e, idx) where {T} = :(VecElement{$T}($e[$idx])), idx + 1 # NTuples -function 
flatten_recurse(typ::Type{NTuple{N, T}}, e) where {N, T} +function flatten_recurse(typ::Type{NTuple{N,T}}, e) where {N,T} ret = Expr[] for (i, eltyp) in enumerate(typ.types) @@ -380,7 +356,7 @@ function flatten_recurse(typ::Type{NTuple{N, T}}, e) where {N, T} return ret end -function unflatten_recurse(typ::Type{NTuple{N, T}}, e, idx) where {N, T} +function unflatten_recurse(typ::Type{NTuple{N,T}}, e, idx) where {N,T} ret = Expr(:tuple) for (i, eltyp) in enumerate(typ.types) @@ -391,8 +367,8 @@ function unflatten_recurse(typ::Type{NTuple{N, T}}, e, idx) where {N, T} return ret, idx end -@generated flatten(x::typ) where typ = Expr(:tuple, flatten_recurse(typ, :x)...) -@generated unflatten(::Type{typ}, x) where typ = unflatten_recurse(typ, :x, 1)[1] +@generated flatten(x::typ) where {typ} = Expr(:tuple, flatten_recurse(typ, :x)...) +@generated unflatten(::Type{typ}, x) where {typ} = unflatten_recurse(typ, :x, 1)[1] ################################################################################ # HIGH LEVEL (CUDA-STYLE API) @@ -456,8 +432,8 @@ Type that represents per-thread intermediate results of WMMA operations. You can access individual elements using the `x` member or `[]` operator, but beware that the exact ordering of elements is unspecified. 
""" -struct Fragment{M, N, K, FS, T, L <: FragmentLayout, U <: FragmentUse} - x::NTuple{FS, T} +struct Fragment{M,N,K,FS,T,L<:FragmentLayout,U<:FragmentUse} + x::NTuple{FS,T} end # ---------------------- @@ -492,7 +468,7 @@ julia> config = WMMA.Config{16, 16, 16, Float32} CUDA.WMMA.Config{16, 16, 16, Float32} ``` """ -struct Config{M, N, K, d_type} end +struct Config{M,N,K,d_type} end # --------- # Constants @@ -506,27 +482,27 @@ const map_as_ty_to_str = Dict(val => key for (key, val) in map_ptx_as_to_as_ty) # Maps layout types to string const map_layout_ty_to_str = Dict( - RowMajor => "row", - ColMajor => "col" - ) + RowMajor => "row", + ColMajor => "col" +) # Maps matrix & type to number of elements (size after flattening) const map_num_elems = Dict( - ("a", Float16) => 16, - ("b", Float16) => 16, - ("c", Float16) => 8, - ("c", Float32) => 8, - ("d", Float16) => 8, - ("d", Float32) => 8 - ) + ("a", Float16) => 16, + ("b", Float16) => 16, + ("c", Float16) => 8, + ("c", Float32) => 8, + ("d", Float16) => 8, + ("d", Float32) => 8 +) # Maps matrix to its use const map_matrix_to_use = Dict( - "a" => MatrixA, - "b" => MatrixB, - "c" => Accumulator, - "d" => Accumulator - ) + "a" => MatrixA, + "b" => MatrixB, + "c" => Accumulator, + "d" => Accumulator +) # ---------------- # Helper functions @@ -561,9 +537,9 @@ function get_hl_frag_info(matrix, T, shape) try return (map_num_elems[(matrix, T)], - map_frag_sizes["$matrix.$ptx_ty.$shape"], - map_ptx_to_jl_frag[ptx_ty], - ptx_ty) + map_frag_sizes["$matrix.$ptx_ty.$shape"], + map_ptx_to_jl_frag[ptx_ty], + ptx_ty) catch error("Invalid type $T for matrix $matrix") end @@ -600,24 +576,24 @@ load_a, load_b, load_c for mat in ["a", "b", "c"] func_name = Symbol("load_$mat") - @eval @generated function $func_name(addr::LLVMPtr{T, AS}, - stride::Number, - layout::Type{L}, - config::Type{Config{M, N, K, D_TYPE}}) where {T, AS, L, M, N, K, D_TYPE} + @eval @generated function $func_name(addr::LLVMPtr{T,AS}, + stride::Number, + 
layout::Type{L}, + config::Type{Config{M,N,K,D_TYPE}}) where {T,AS,L,M,N,K,D_TYPE} - as_str = get_hl_as_info(AS) - layout = get_hl_layout(L) - shape = get_hl_shape(M, N, K) + as_str = get_hl_as_info(AS) + layout = get_hl_layout(L) + shape = get_hl_shape(M, N, K) num_els, _, _, arr_str = get_hl_frag_info($mat, T, shape) - U = get_hl_mat_use($mat) - L_ret = ($mat == "c") ? Unspecified : L + U = get_hl_mat_use($mat) + L_ret = ($mat == "c") ? Unspecified : L # Name of the Julia wrapper wrapper = Symbol(join(filter(!isempty, ["llvm", "wmma", "load", $mat, layout, shape, as_str, "stride", arr_str]), "_")) return quote x = flatten($wrapper(addr, stride)) - return Fragment{$M, $N, $K, $num_els, $T, $L_ret, $U}(x) + return Fragment{$M,$N,$K,$num_els,$T,$L_ret,$U}(x) end end end @@ -648,32 +624,32 @@ Perform the matrix multiply-accumulate operation ``D = A \\cdot B + C``. """ mma -@generated function mma(a::Fragment{M, N, K, A_SZ, A_T, A_L, MatrixA}, - b::Fragment{M, N, K, B_SZ, B_T, B_L, MatrixB}, - c::Fragment{M, N, K, C_SZ, C_T, Unspecified, Accumulator}, - config::Type{Config{M, N, K, D_T}}) where {M, N, K, A_SZ, A_T, A_L, B_SZ, B_T, B_L, C_SZ, C_T, D_T} +@generated function mma(a::Fragment{M,N,K,A_SZ,A_T,A_L,MatrixA}, + b::Fragment{M,N,K,B_SZ,B_T,B_L,MatrixB}, + c::Fragment{M,N,K,C_SZ,C_T,Unspecified,Accumulator}, + config::Type{Config{M,N,K,D_T}}) where {M,N,K,A_SZ,A_T,A_L,B_SZ,B_T,B_L,C_SZ,C_T,D_T} a_layout = get_hl_layout(A_L) b_layout = get_hl_layout(B_L) shape = get_hl_shape(M, N, K) - _, a_frag_sz, a_frag_ty, _ = get_hl_frag_info("a", A_T, shape) - _, b_frag_sz, b_frag_ty, _ = get_hl_frag_info("b", B_T, shape) + _, a_frag_sz, a_frag_ty, _ = get_hl_frag_info("a", A_T, shape) + _, b_frag_sz, b_frag_ty, _ = get_hl_frag_info("b", B_T, shape) _, c_frag_sz, c_frag_ty, c_arr_str = get_hl_frag_info("c", C_T, shape) - d_num_els, _, _, d_arr_str = get_hl_frag_info("d", D_T, shape) + d_num_els, _, _, d_arr_str = get_hl_frag_info("d", D_T, shape) + - # Name of the Julia 
wrapper wrapper = Symbol(join(filter(!isempty, ["llvm", "wmma", "mma", a_layout, b_layout, shape, d_arr_str, c_arr_str]), "_")) return quote - a_unfl = unflatten(NTuple{$a_frag_sz, $a_frag_ty}, a.x) - b_unfl = unflatten(NTuple{$b_frag_sz, $b_frag_ty}, b.x) - c_unfl = unflatten(NTuple{$c_frag_sz, $c_frag_ty}, c.x) + a_unfl = unflatten(NTuple{$a_frag_sz,$a_frag_ty}, a.x) + b_unfl = unflatten(NTuple{$b_frag_sz,$b_frag_ty}, b.x) + c_unfl = unflatten(NTuple{$c_frag_sz,$c_frag_ty}, c.x) x = flatten($wrapper(a_unfl, b_unfl, c_unfl)) - return Fragment{$M, $N, $K, $d_num_els, $D_T, Unspecified, Accumulator}(x) + return Fragment{$M,$N,$K,$d_num_els,$D_T,Unspecified,Accumulator}(x) end end @@ -705,22 +681,22 @@ See also: [`WMMA.Fragment`](@ref), [`WMMA.FragmentLayout`](@ref), [`WMMA.Config` """ store_d -@generated function store_d(addr::LLVMPtr{T, AS}, - d::Fragment{M, N, K, D_SZ, T, Unspecified, Accumulator}, - stride::Number, - layout::Type{L}, - config::Type{Config{M, N, K, T}}) where {T, AS, M, N, K, D_SZ, L} +@generated function store_d(addr::LLVMPtr{T,AS}, + d::Fragment{M,N,K,D_SZ,T,Unspecified,Accumulator}, + stride::Number, + layout::Type{L}, + config::Type{Config{M,N,K,T}}) where {T,AS,M,N,K,D_SZ,L} - as_str = get_hl_as_info(AS) - layout = get_hl_layout(L) - shape = get_hl_shape(M, N, K) + as_str = get_hl_as_info(AS) + layout = get_hl_layout(L) + shape = get_hl_shape(M, N, K) num_els, frag_sz, frag_ty, arr_str = get_hl_frag_info("d", T, shape) # Name of the Julia wrapper wrapper = Symbol(join(filter(!isempty, ["llvm", "wmma", "store", "d", layout, shape, as_str, "stride", arr_str]), "_")) return quote - d_unfl = unflatten(NTuple{$frag_sz, $frag_ty}, d.x) + d_unfl = unflatten(NTuple{$frag_sz,$frag_ty}, d.x) $wrapper(addr, d_unfl, stride) return nothing end @@ -747,18 +723,18 @@ This operation is useful if you want to implement a matrix multiplication (and t fill_c @generated function fill_c(value::T, - config::Type{Config{M, N, K, D_TYPE}}) where {T, M, N, K, D_TYPE} 
+ config::Type{Config{M,N,K,D_TYPE}}) where {T,M,N,K,D_TYPE} # We can't use closures in @generated functions, so we'll have to do it this way instead of # ntuple(i -> val, $num_els) shape = get_hl_shape(M, N, K) num_els, _, _ = get_hl_frag_info("c", T, shape) - args = [:value for i=1:num_els] + args = [:value for i = 1:num_els] expr = :(tuple($(args...))) return quote - return Fragment{$M, $N, $K, $num_els, $T, Unspecified, Accumulator}($expr) + return Fragment{$M,$N,$K,$num_els,$T,Unspecified,Accumulator}($expr) end end From f4b0bb3d2e04da16020c28d267a80ed84a72358c Mon Sep 17 00:00:00 2001 From: Carsten Bauer Date: Thu, 18 Aug 2022 12:23:09 +0200 Subject: [PATCH 7/7] undo formatting --- src/device/intrinsics/wmma.jl | 292 ++++++++++++++++++---------------- 1 file changed, 157 insertions(+), 135 deletions(-) diff --git a/src/device/intrinsics/wmma.jl b/src/device/intrinsics/wmma.jl index a1632e80aa..fe3ddbaf76 100644 --- a/src/device/intrinsics/wmma.jl +++ b/src/device/intrinsics/wmma.jl @@ -10,66 +10,88 @@ using Core: LLVMPtr # Maps PTX types to Julia array types const map_ptx_to_jl_array = Dict( - "u8" => UInt8, - "s8" => Int8, - "s32" => Int32, - "f16" => Float16, - "tf32" => Float32, - "f32" => Float32 -) + "u8" => UInt8, + "s8" => Int8, + "s32" => Int32, + "f16" => Float16, + "tf32" => Float32, + "f32" => Float32 + ) # Maps PTX types to Julia fragment types const map_ptx_to_jl_frag = Dict( - "u8" => UInt32, - "s8" => UInt32, - "s32" => Int32, - "f16" => NTuple{2,VecElement{Float16}}, - "tf32" => Float32, - "f32" => Float32 -) + "u8" => UInt32, + "s8" => UInt32, + "s32" => Int32, + "f16" => NTuple{2, VecElement{Float16}}, + "tf32" => Float32, + "f32" => Float32 + ) # Maps matrix & PTX types to fragment sizes const map_frag_sizes = Dict( - # A - "a.u8.m16n16k16" => 2, - "a.u8.m8n32k16" => 1, - "a.u8.m32n8k16" => 4, "a.s8.m16n16k16" => 2, - "a.s8.m8n32k16" => 1, - "a.s8.m32n8k16" => 4, "a.f16.m16n16k16" => 8, - "a.f16.m8n32k16" => 8, - "a.f16.m32n8k16" => 8, 
"a.tf32.m16n16k8" => 4, - # B - "b.u8.m16n16k16" => 2, - "b.u8.m8n32k16" => 4, - "b.u8.m32n8k16" => 1, "b.s8.m16n16k16" => 2, - "b.s8.m8n32k16" => 4, - "b.s8.m32n8k16" => 1, "b.f16.m16n16k16" => 8, - "b.f16.m8n32k16" => 8, - "b.f16.m32n8k16" => 8, "b.tf32.m16n16k8" => 4, - # C - "c.s32.m16n16k16" => 8, - "c.s32.m8n32k16" => 8, - "c.s32.m32n8k16" => 8, "c.f16.m16n16k16" => 4, - "c.f16.m8n32k16" => 4, - "c.f16.m32n8k16" => 4, "c.f32.m16n16k16" => 8, - "c.f32.m8n32k16" => 8, - "c.f32.m32n8k16" => 8, "c.f32.m16n16k8" => 8, - # D - "d.s32.m16n16k16" => 8, - "d.s32.m8n32k16" => 8, - "d.s32.m32n8k16" => 8, "d.f16.m16n16k16" => 4, - "d.f16.m8n32k16" => 4, - "d.f16.m32n8k16" => 4, "d.f32.m16n16k16" => 8, - "d.f32.m8n32k16" => 8, - "d.f32.m32n8k16" => 8, "d.f32.m16n16k8" => 8, -) + # A + "a.u8.m16n16k16" => 2, + "a.u8.m8n32k16" => 1, + "a.u8.m32n8k16" => 4, + + "a.s8.m16n16k16" => 2, + "a.s8.m8n32k16" => 1, + "a.s8.m32n8k16" => 4, + + "a.f16.m16n16k16" => 8, + "a.f16.m8n32k16" => 8, + "a.f16.m32n8k16" => 8, + + "a.tf32.m16n16k8" => 4, + # B + "b.u8.m16n16k16" => 2, + "b.u8.m8n32k16" => 4, + "b.u8.m32n8k16" => 1, + + "b.s8.m16n16k16" => 2, + "b.s8.m8n32k16" => 4, + "b.s8.m32n8k16" => 1, + + "b.f16.m16n16k16" => 8, + "b.f16.m8n32k16" => 8, + "b.f16.m32n8k16" => 8, + + "b.tf32.m16n16k8" => 4, + # C + "c.s32.m16n16k16" => 8, + "c.s32.m8n32k16" => 8, + "c.s32.m32n8k16" => 8, + + "c.f16.m16n16k16" => 4, + "c.f16.m8n32k16" => 4, + "c.f16.m32n8k16" => 4, + + "c.f32.m16n16k16" => 8, + "c.f32.m8n32k16" => 8, + "c.f32.m32n8k16" => 8, + "c.f32.m16n16k8" => 8, + # D + "d.s32.m16n16k16" => 8, + "d.s32.m8n32k16" => 8, + "d.s32.m32n8k16" => 8, + + "d.f16.m16n16k16" => 4, + "d.f16.m8n32k16" => 4, + "d.f16.m32n8k16" => 4, + + "d.f32.m16n16k16" => 8, + "d.f32.m8n32k16" => 8, + "d.f32.m32n8k16" => 8, + "d.f32.m16n16k8" => 8, + ) # Maps PTX AS to CUDA.AS const map_ptx_as_to_as_ty = Dict( - "" => AS.Generic, - "shared" => AS.Shared, - "global" => AS.Global -) + "" => AS.Generic, + "shared" => 
AS.Shared, + "global" => AS.Global + ) # Valid WMMA Operation configurations: Shape (M,N,K), Matrix, Element Type @@ -78,17 +100,17 @@ const ldst_tf32_ab_ops = [(16, 16, 8)], ["a", "b"], ["tf32"] const ldst_tf32_cd_ops = [(16, 16, 8)], ["c", "d"], ["f32"] const wmma_tf32_ops = [(16, 16, 8)], ["tf32"], ["f32"], ["f32"] # Half-Precision Floating Point -const ldst_half_ab_ops = [(16, 16, 16), (32, 8, 16), (8, 32, 16)], ["a", "b"], ["f16"] -const ldst_half_cd_ops = [(16, 16, 16), (32, 8, 16), (8, 32, 16)], ["c", "d"], ["f16", "f32"] -const wmma_half_ops = [(16, 16, 16), (32, 8, 16), (8, 32, 16)], ["f16"], ["f16", "f32"], ["f16", "f32"] +const ldst_half_ab_ops = [(16,16,16), (32,8,16), (8,32,16)], ["a", "b"], ["f16"] +const ldst_half_cd_ops = [(16,16,16), (32,8,16), (8,32,16)], ["c", "d"], ["f16", "f32"] +const wmma_half_ops = [(16,16,16), (32,8,16), (8,32,16)], ["f16"], ["f16", "f32"], ["f16", "f32"] # Integer -const ldst_int_ab_ops = [(16, 16, 16), (32, 8, 16), (8, 32, 16)], ["a", "b"], ["u8", "s8"] -const ldst_int_cd_ops = [(16, 16, 16), (32, 8, 16), (8, 32, 16)], ["c", "d"], ["s32"] -const wmma_int_ops = [(16, 16, 16), (32, 8, 16), (8, 32, 16)], ["s8", "u8"], ["s32"], ["s32"] +const ldst_int_ab_ops = [(16,16,16), (32,8,16), (8,32,16)], ["a", "b"], ["u8", "s8"] +const ldst_int_cd_ops = [(16,16,16), (32,8,16), (8,32,16)], ["c", "d"], ["s32"] +const wmma_int_ops = [(16,16,16), (32,8,16), (8,32,16)], ["s8", "u8"], ["s32"], ["s32"] const all_ldst_ops = vcat(ldst_half_ab_ops, ldst_half_cd_ops, - ldst_int_ab_ops, ldst_int_cd_ops, - ldst_tf32_ab_ops, ldst_tf32_cd_ops) + ldst_int_ab_ops, ldst_int_cd_ops, + ldst_tf32_ab_ops, ldst_tf32_cd_ops) const all_wmma_ops = vcat(wmma_half_ops, wmma_int_ops, wmma_tf32_ops) # Valid WMMA operation shapes @@ -108,10 +130,10 @@ end # Returns (Julia array type, Julia fragment type, fragment size) get_frag_info(matrix, ptx_el_type, shape) = ( - map_ptx_to_jl_array[ptx_el_type], - map_ptx_to_jl_frag[ptx_el_type], - 
map_frag_sizes["$matrix.$ptx_el_type.$shape"] -) + map_ptx_to_jl_array[ptx_el_type], + map_ptx_to_jl_frag[ptx_el_type], + map_frag_sizes["$matrix.$ptx_el_type.$shape"] + ) get_addrspace_info(addr_space) = convert(Int, map_ptx_as_to_as_ty[addr_space]) @@ -128,7 +150,7 @@ for N in unique(values(map_frag_sizes)) Base.Cartesian.@nexprs $N i -> x_i::T end - @eval Base.convert(::Type{NTuple{$N,T}}, x::$struct_ty{T}) where {T} = ntuple(i -> getfield(x, i), $N) + @eval Base.convert(::Type{NTuple{$N, T}}, x::$struct_ty{T}) where {T} = ntuple(i -> getfield(x, i), $N) end ################################################################################ @@ -184,10 +206,10 @@ for ops in all_ldst_ops, ccall_name = "extern $llvm_intr" - ptr_ty = LLVMPtr{arr_ty,addr_space_int} + ptr_ty = LLVMPtr{arr_ty, addr_space_int} struct_ty = Symbol("LLVMStruct$sz") - @eval $func_name(src_addr, stride) = convert(NTuple{$sz,$frag_ty}, ccall($ccall_name, llvmcall, $struct_ty{$frag_ty}, ($ptr_ty, Int32), src_addr, stride)) + @eval $func_name(src_addr, stride) = convert(NTuple{$sz, $frag_ty}, ccall($ccall_name, llvmcall, $struct_ty{$frag_ty}, ($ptr_ty, Int32), src_addr, stride)) @eval export $func_name @eval @doc (@doc llvm_wmma_load) $func_name end @@ -217,13 +239,13 @@ Wrapper around the LLVM intrinsic `@llvm.nvvm.wmma.store.d.sync.{layout}.{shape} llvm_wmma_store() = error("Cannot call llvm_wmma_store without values for placeholders!") export llvm_wmma_store -for ops in all_ldst_ops, - mnk in ops[1], - mat in ops[2], - elem_type in ops[3], - layout in ["col", "row"], - addr_space in ["", "shared", "global"], - stride in ["stride"] + for ops in all_ldst_ops, + mnk in ops[1], + mat in ops[2], + elem_type in ops[3], + layout in ["col", "row"], + addr_space in ["", "shared", "global"], + stride in ["stride"] if mat != "d" continue @@ -248,7 +270,7 @@ for ops in all_ldst_ops, frag_types = ntuple(i -> frag_ty, sz) frag_vars = ntuple(i -> :(data[$i]), sz) - ptr_ty = LLVMPtr{arr_ty,addr_space_int} + 
ptr_ty = LLVMPtr{arr_ty, addr_space_int} @eval $func_name(dst_addr, data, stride) = ccall($ccall_name, llvmcall, Nothing, ($ptr_ty, $(frag_types...), Int32), dst_addr, $(frag_vars...), stride) @eval export $func_name @@ -328,7 +350,7 @@ for ops in all_wmma_ops, struct_ty = Symbol("LLVMStruct$d_sz") - @eval $func_name(a, b, c) = convert(NTuple{$d_sz,$d_frag_ty}, ccall($ccall_name, llvmcall, $struct_ty{$d_frag_ty}, ($(a_types...), $(b_types...), $(c_types...)), $(a_vars...), $(b_vars...), $(c_vars...))) + @eval $func_name(a, b, c) = convert(NTuple{$d_sz, $d_frag_ty}, ccall($ccall_name, llvmcall, $struct_ty{$d_frag_ty}, ($(a_types...), $(b_types...), $(c_types...)), $(a_vars...), $(b_vars...), $(c_vars...))) @eval export $func_name @eval @doc (@doc llvm_wmma_mma) $func_name end @@ -342,11 +364,11 @@ flatten_recurse(typ, e) = [:($e)] unflatten_recurse(typ, e, idx) = :($e[$idx]), idx + 1 # VecElements -flatten_recurse(typ::Type{VecElement{T}}, e) where {T} = [:($e.value)] -unflatten_recurse(typ::Type{VecElement{T}}, e, idx) where {T} = :(VecElement{$T}($e[$idx])), idx + 1 +flatten_recurse(typ::Type{VecElement{T}}, e) where T = [:($e.value)] +unflatten_recurse(typ::Type{VecElement{T}}, e, idx) where T = :(VecElement{$T}($e[$idx])), idx + 1 # NTuples -function flatten_recurse(typ::Type{NTuple{N,T}}, e) where {N,T} +function flatten_recurse(typ::Type{NTuple{N, T}}, e) where {N, T} ret = Expr[] for (i, eltyp) in enumerate(typ.types) @@ -356,7 +378,7 @@ function flatten_recurse(typ::Type{NTuple{N,T}}, e) where {N,T} return ret end -function unflatten_recurse(typ::Type{NTuple{N,T}}, e, idx) where {N,T} +function unflatten_recurse(typ::Type{NTuple{N, T}}, e, idx) where {N, T} ret = Expr(:tuple) for (i, eltyp) in enumerate(typ.types) @@ -367,8 +389,8 @@ function unflatten_recurse(typ::Type{NTuple{N,T}}, e, idx) where {N,T} return ret, idx end -@generated flatten(x::typ) where {typ} = Expr(:tuple, flatten_recurse(typ, :x)...) 
-@generated unflatten(::Type{typ}, x) where {typ} = unflatten_recurse(typ, :x, 1)[1] +@generated flatten(x::typ) where typ = Expr(:tuple, flatten_recurse(typ, :x)...) +@generated unflatten(::Type{typ}, x) where typ = unflatten_recurse(typ, :x, 1)[1] ################################################################################ # HIGH LEVEL (CUDA-STYLE API) @@ -432,8 +454,8 @@ Type that represents per-thread intermediate results of WMMA operations. You can access individual elements using the `x` member or `[]` operator, but beware that the exact ordering of elements is unspecified. """ -struct Fragment{M,N,K,FS,T,L<:FragmentLayout,U<:FragmentUse} - x::NTuple{FS,T} +struct Fragment{M, N, K, FS, T, L <: FragmentLayout, U <: FragmentUse} + x::NTuple{FS, T} end # ---------------------- @@ -468,7 +490,7 @@ julia> config = WMMA.Config{16, 16, 16, Float32} CUDA.WMMA.Config{16, 16, 16, Float32} ``` """ -struct Config{M,N,K,d_type} end +struct Config{M, N, K, d_type} end # --------- # Constants @@ -482,27 +504,27 @@ const map_as_ty_to_str = Dict(val => key for (key, val) in map_ptx_as_to_as_ty) # Maps layout types to string const map_layout_ty_to_str = Dict( - RowMajor => "row", - ColMajor => "col" -) + RowMajor => "row", + ColMajor => "col" + ) # Maps matrix & type to number of elements (size after flattening) const map_num_elems = Dict( - ("a", Float16) => 16, - ("b", Float16) => 16, - ("c", Float16) => 8, - ("c", Float32) => 8, - ("d", Float16) => 8, - ("d", Float32) => 8 -) + ("a", Float16) => 16, + ("b", Float16) => 16, + ("c", Float16) => 8, + ("c", Float32) => 8, + ("d", Float16) => 8, + ("d", Float32) => 8 + ) # Maps matrix to its use const map_matrix_to_use = Dict( - "a" => MatrixA, - "b" => MatrixB, - "c" => Accumulator, - "d" => Accumulator -) + "a" => MatrixA, + "b" => MatrixB, + "c" => Accumulator, + "d" => Accumulator + ) # ---------------- # Helper functions @@ -537,9 +559,9 @@ function get_hl_frag_info(matrix, T, shape) try return (map_num_elems[(matrix, 
T)], - map_frag_sizes["$matrix.$ptx_ty.$shape"], - map_ptx_to_jl_frag[ptx_ty], - ptx_ty) + map_frag_sizes["$matrix.$ptx_ty.$shape"], + map_ptx_to_jl_frag[ptx_ty], + ptx_ty) catch error("Invalid type $T for matrix $matrix") end @@ -576,24 +598,24 @@ load_a, load_b, load_c for mat in ["a", "b", "c"] func_name = Symbol("load_$mat") - @eval @generated function $func_name(addr::LLVMPtr{T,AS}, - stride::Number, - layout::Type{L}, - config::Type{Config{M,N,K,D_TYPE}}) where {T,AS,L,M,N,K,D_TYPE} + @eval @generated function $func_name(addr::LLVMPtr{T, AS}, + stride::Number, + layout::Type{L}, + config::Type{Config{M, N, K, D_TYPE}}) where {T, AS, L, M, N, K, D_TYPE} - as_str = get_hl_as_info(AS) - layout = get_hl_layout(L) - shape = get_hl_shape(M, N, K) + as_str = get_hl_as_info(AS) + layout = get_hl_layout(L) + shape = get_hl_shape(M, N, K) num_els, _, _, arr_str = get_hl_frag_info($mat, T, shape) - U = get_hl_mat_use($mat) - L_ret = ($mat == "c") ? Unspecified : L + U = get_hl_mat_use($mat) + L_ret = ($mat == "c") ? Unspecified : L # Name of the Julia wrapper wrapper = Symbol(join(filter(!isempty, ["llvm", "wmma", "load", $mat, layout, shape, as_str, "stride", arr_str]), "_")) return quote x = flatten($wrapper(addr, stride)) - return Fragment{$M,$N,$K,$num_els,$T,$L_ret,$U}(x) + return Fragment{$M, $N, $K, $num_els, $T, $L_ret, $U}(x) end end end @@ -624,19 +646,19 @@ Perform the matrix multiply-accumulate operation ``D = A \\cdot B + C``. 
""" mma -@generated function mma(a::Fragment{M,N,K,A_SZ,A_T,A_L,MatrixA}, - b::Fragment{M,N,K,B_SZ,B_T,B_L,MatrixB}, - c::Fragment{M,N,K,C_SZ,C_T,Unspecified,Accumulator}, - config::Type{Config{M,N,K,D_T}}) where {M,N,K,A_SZ,A_T,A_L,B_SZ,B_T,B_L,C_SZ,C_T,D_T} +@generated function mma(a::Fragment{M, N, K, A_SZ, A_T, A_L, MatrixA}, + b::Fragment{M, N, K, B_SZ, B_T, B_L, MatrixB}, + c::Fragment{M, N, K, C_SZ, C_T, Unspecified, Accumulator}, + config::Type{Config{M, N, K, D_T}}) where {M, N, K, A_SZ, A_T, A_L, B_SZ, B_T, B_L, C_SZ, C_T, D_T} a_layout = get_hl_layout(A_L) b_layout = get_hl_layout(B_L) shape = get_hl_shape(M, N, K) - _, a_frag_sz, a_frag_ty, _ = get_hl_frag_info("a", A_T, shape) - _, b_frag_sz, b_frag_ty, _ = get_hl_frag_info("b", B_T, shape) + _, a_frag_sz, a_frag_ty, _ = get_hl_frag_info("a", A_T, shape) + _, b_frag_sz, b_frag_ty, _ = get_hl_frag_info("b", B_T, shape) _, c_frag_sz, c_frag_ty, c_arr_str = get_hl_frag_info("c", C_T, shape) - d_num_els, _, _, d_arr_str = get_hl_frag_info("d", D_T, shape) + d_num_els, _, _, d_arr_str = get_hl_frag_info("d", D_T, shape) @@ -644,12 +666,12 @@ mma wrapper = Symbol(join(filter(!isempty, ["llvm", "wmma", "mma", a_layout, b_layout, shape, d_arr_str, c_arr_str]), "_")) return quote - a_unfl = unflatten(NTuple{$a_frag_sz,$a_frag_ty}, a.x) - b_unfl = unflatten(NTuple{$b_frag_sz,$b_frag_ty}, b.x) - c_unfl = unflatten(NTuple{$c_frag_sz,$c_frag_ty}, c.x) + a_unfl = unflatten(NTuple{$a_frag_sz, $a_frag_ty}, a.x) + b_unfl = unflatten(NTuple{$b_frag_sz, $b_frag_ty}, b.x) + c_unfl = unflatten(NTuple{$c_frag_sz, $c_frag_ty}, c.x) x = flatten($wrapper(a_unfl, b_unfl, c_unfl)) - return Fragment{$M,$N,$K,$d_num_els,$D_T,Unspecified,Accumulator}(x) + return Fragment{$M, $N, $K, $d_num_els, $D_T, Unspecified, Accumulator}(x) end end @@ -681,22 +703,22 @@ See also: [`WMMA.Fragment`](@ref), [`WMMA.FragmentLayout`](@ref), [`WMMA.Config` """ store_d -@generated function store_d(addr::LLVMPtr{T,AS}, - 
d::Fragment{M,N,K,D_SZ,T,Unspecified,Accumulator}, - stride::Number, - layout::Type{L}, - config::Type{Config{M,N,K,T}}) where {T,AS,M,N,K,D_SZ,L} +@generated function store_d(addr::LLVMPtr{T, AS}, + d::Fragment{M, N, K, D_SZ, T, Unspecified, Accumulator}, + stride::Number, + layout::Type{L}, + config::Type{Config{M, N, K, T}}) where {T, AS, M, N, K, D_SZ, L} - as_str = get_hl_as_info(AS) - layout = get_hl_layout(L) - shape = get_hl_shape(M, N, K) + as_str = get_hl_as_info(AS) + layout = get_hl_layout(L) + shape = get_hl_shape(M, N, K) num_els, frag_sz, frag_ty, arr_str = get_hl_frag_info("d", T, shape) # Name of the Julia wrapper wrapper = Symbol(join(filter(!isempty, ["llvm", "wmma", "store", "d", layout, shape, as_str, "stride", arr_str]), "_")) return quote - d_unfl = unflatten(NTuple{$frag_sz,$frag_ty}, d.x) + d_unfl = unflatten(NTuple{$frag_sz, $frag_ty}, d.x) $wrapper(addr, d_unfl, stride) return nothing end @@ -723,18 +745,18 @@ This operation is useful if you want to implement a matrix multiplication (and t fill_c @generated function fill_c(value::T, - config::Type{Config{M,N,K,D_TYPE}}) where {T,M,N,K,D_TYPE} + config::Type{Config{M, N, K, D_TYPE}}) where {T, M, N, K, D_TYPE} # We can't use closures in @generated functions, so we'll have to do it this way instead of # ntuple(i -> val, $num_els) shape = get_hl_shape(M, N, K) num_els, _, _ = get_hl_frag_info("c", T, shape) - args = [:value for i = 1:num_els] + args = [:value for i=1:num_els] expr = :(tuple($(args...))) return quote - return Fragment{$M,$N,$K,$num_els,$T,Unspecified,Accumulator}($expr) + return Fragment{$M, $N, $K, $num_els, $T, Unspecified, Accumulator}($expr) end end