Skip to content

Commit 3ce2e4d

Browse files
authored
[NVPTX] Add tcgen05.cp/shift intrinsics (llvm#127669)
This patch adds intrinsics for tcgen05.cp and tcgen05.shift instructions. lit tests are added and verified with a ptxas-12.8 executable. Docs are updated in the NVPTXUsage.rst file. Signed-off-by: Durgadoss R <durgadossr@nvidia.com>
1 parent 73d0679 commit 3ce2e4d

File tree

5 files changed

+532
-0
lines changed

5 files changed

+532
-0
lines changed

llvm/docs/NVPTXUsage.rst

Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1183,6 +1183,93 @@ operations.
11831183
For more information, refer to the PTX ISA
11841184
`<https://docs.nvidia.com/cuda/parallel-thread-execution/#tensorcore-5th-generation-instructions-tcgen05-fence>`_.
11851185

1186+
'``llvm.nvvm.tcgen05.shift``'
1187+
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
1188+
1189+
Syntax:
1190+
"""""""
1191+
1192+
.. code-block:: llvm
1193+
1194+
declare void @llvm.nvvm.tcgen05.shift.down.cg1(ptr addrspace(6) %tmem_addr)
1195+
declare void @llvm.nvvm.tcgen05.shift.down.cg2(ptr addrspace(6) %tmem_addr)
1196+
1197+
Overview:
1198+
"""""""""
1199+
1200+
The '``@llvm.nvvm.tcgen05.shift.{cg1/cg2}``' intrinsics correspond to
1201+
the ``tcgen05.shift.{cg1/cg2}`` PTX instructions. The ``tcgen05.shift``
1202+
is an asynchronous instruction which initiates the shifting of 32-byte
1203+
elements downwards across all the rows, except the last, by one row.
1204+
The address operand ``%tmem_addr`` specifies the base address of the
1205+
matrix in the Tensor Memory whose rows must be down shifted.
1206+
1207+
For more information, refer to the PTX ISA
1208+
`<https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-instructions-tcgen05-shift>`_.
1209+
1210+
'``llvm.nvvm.tcgen05.cp``'
1211+
^^^^^^^^^^^^^^^^^^^^^^^^^^
1212+
1213+
Syntax:
1214+
"""""""
1215+
1216+
.. code-block:: llvm
1217+
1218+
declare void @llvm.nvvm.tcgen05.cp.4x256b.{cg1,cg2}(ptr addrspace(6) %tmem_addr, i64 %sdesc)
1219+
declare void @llvm.nvvm.tcgen05.cp.128x256b.{cg1,cg2}(ptr addrspace(6) %tmem_addr, i64 %sdesc)
1220+
declare void @llvm.nvvm.tcgen05.cp.128x128b.{cg1,cg2}(ptr addrspace(6) %tmem_addr, i64 %sdesc)
1221+
declare void @llvm.nvvm.tcgen05.cp.32x128b_warpx4.{cg1,cg2}(ptr addrspace(6) %tmem_addr, i64 %sdesc)
1222+
declare void @llvm.nvvm.tcgen05.cp.64x128b_warpx2_02_13.{cg1,cg2}(ptr addrspace(6) %tmem_addr, i64 %sdesc)
1223+
declare void @llvm.nvvm.tcgen05.cp.64x128b_warpx2_01_23.{cg1,cg2}(ptr addrspace(6) %tmem_addr, i64 %sdesc)
1224+
1225+
declare void @llvm.nvvm.tcgen05.cp.4x256b.b6x16_p32.{cg1,cg2}(ptr addrspace(6) %tmem_addr, i64 %sdesc)
1226+
declare void @llvm.nvvm.tcgen05.cp.128x256b.b6x16_p32.{cg1,cg2}(ptr addrspace(6) %tmem_addr, i64 %sdesc)
1227+
declare void @llvm.nvvm.tcgen05.cp.128x128b.b6x16_p32.{cg1,cg2}(ptr addrspace(6) %tmem_addr, i64 %sdesc)
1228+
declare void @llvm.nvvm.tcgen05.cp.32x128b_warpx4.b6x16_p32.{cg1,cg2}(ptr addrspace(6) %tmem_addr, i64 %sdesc)
1229+
declare void @llvm.nvvm.tcgen05.cp.64x128b_warpx2_02_13.b6x16_p32.{cg1,cg2}(ptr addrspace(6) %tmem_addr, i64 %sdesc)
1230+
declare void @llvm.nvvm.tcgen05.cp.64x128b_warpx2_01_23.b6x16_p32.{cg1,cg2}(ptr addrspace(6) %tmem_addr, i64 %sdesc)
1231+
1232+
declare void @llvm.nvvm.tcgen05.cp.4x256b.b4x16_p64.{cg1,cg2}(ptr addrspace(6) %tmem_addr, i64 %sdesc)
1233+
declare void @llvm.nvvm.tcgen05.cp.128x256b.b4x16_p64.{cg1,cg2}(ptr addrspace(6) %tmem_addr, i64 %sdesc)
1234+
declare void @llvm.nvvm.tcgen05.cp.128x128b.b4x16_p64.{cg1,cg2}(ptr addrspace(6) %tmem_addr, i64 %sdesc)
1235+
declare void @llvm.nvvm.tcgen05.cp.32x128b_warpx4.b4x16_p64.{cg1,cg2}(ptr addrspace(6) %tmem_addr, i64 %sdesc)
1236+
declare void @llvm.nvvm.tcgen05.cp.64x128b_warpx2_02_13.b4x16_p64.{cg1,cg2}(ptr addrspace(6) %tmem_addr, i64 %sdesc)
1237+
declare void @llvm.nvvm.tcgen05.cp.64x128b_warpx2_01_23.b4x16_p64.{cg1,cg2}(ptr addrspace(6) %tmem_addr, i64 %sdesc)
1238+
1239+
Overview:
1240+
"""""""""
1241+
1242+
The '``@llvm.nvvm.tcgen05.cp.{shape}.{src_fmt}.{cg1/cg2}``' intrinsics
1243+
correspond to the ``tcgen05.cp.*`` family of PTX instructions.
1244+
The ``tcgen05.cp`` instruction initiates an asynchronous copy operation from
1245+
shared memory to the location specified by ``%tmem_addr`` in Tensor Memory.
1246+
The 64-bit register operand ``%sdesc`` is the matrix descriptor representing
1247+
the source matrix in shared memory that needs to be copied.
1248+
1249+
The valid shapes for the copy operation are:
1250+
{128x256b, 4x256b, 128x128b, 64x128b_warpx2_02_13, 64x128b_warpx2_01_23, 32x128b_warpx4}.
1251+
1252+
Shapes ``64x128b`` and ``32x128b`` require dedicated multicast qualifiers,
1253+
which are appended to the corresponding intrinsic names.
1254+
1255+
Optionally, the data can be decompressed from the source format in the shared memory
1256+
to the destination format in Tensor Memory during the copy operation. Currently,
1257+
only ``.b8x16`` is supported as destination format. The valid source formats are
1258+
``.b6x16_p32`` and ``.b4x16_p64``.
1259+
1260+
When the source format is ``.b6x16_p32``, a contiguous set of 16 elements of 6-bits
1261+
each followed by four bytes of padding (``_p32``) in shared memory is decompressed
1262+
into 16 elements of 8-bits (``.b8x16``) each in the Tensor Memory.
1263+
1264+
When the source format is ``.b4x16_p64``, a contiguous set of 16 elements of 4-bits
1265+
each followed by eight bytes of padding (``_p64``) in shared memory is decompressed
1266+
into 16 elements of 8-bits (``.b8x16``) each in the Tensor Memory.
1267+
1268+
For more information on the decompression schemes, refer to the PTX ISA
1269+
`<https://docs.nvidia.com/cuda/parallel-thread-execution/#optional-decompression>`_.
1270+
1271+
For more information on the tcgen05.cp instruction, refer to the PTX ISA
1272+
`<https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-instructions-tcgen05-cp>`_.
11861273

11871274
Other Intrinsics
11881275
----------------

llvm/include/llvm/IR/IntrinsicsNVVM.td

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,14 @@ def llvm_tmem_ptr_ty : LLVMQualPointerType<6>; // (tensor memory)ptr
5555
// MISC
5656
//
5757

58+
// Helper class that concatenates list elements with
59+
// a given separator 'sep' and returns the result.
60+
// Handles empty strings.
61+
class StrJoin<string sep, list<string> str_list> {
62+
string ret = !foldl("", str_list, a, b,
63+
!if(!eq(a, ""), b, !if(!eq(b, ""), a, !strconcat(a, sep, b))));
64+
}
65+
5866
// Helper class that represents a 'fragment' of an NVPTX *MMA instruction.
5967
// Geom: m<M>n<N>k<K>. E.g. m8n32k16
6068
// Frag: [a|b|c|d] ([x1|x2|x4] for ldmatrix)
@@ -5140,6 +5148,11 @@ foreach cta_group = ["cg1", "cg2"] in {
51405148
[llvm_shared_ptr_ty, llvm_i16_ty], // mbar_ptr, cta_mask
51415149
[IntrConvergent, IntrInaccessibleMemOrArgMemOnly,
51425150
NoCapture<ArgIndex<0>>]>;
5151+
5152+
def int_nvvm_tcgen05_shift_down_ # cta_group : Intrinsic<[],
5153+
[llvm_tmem_ptr_ty], // tmem_addr
5154+
[IntrConvergent, IntrArgMemOnly,
5155+
NoCapture<ArgIndex<0>>]>;
51435156
}
51445157

51455158
// Tcgen05 wait_ld/st intrinsics
@@ -5154,4 +5167,23 @@ def int_nvvm_tcgen05_fence_before_thread_sync : Intrinsic<[], [],
51545167
def int_nvvm_tcgen05_fence_after_thread_sync : Intrinsic<[], [],
51555168
[IntrNoMem, IntrHasSideEffects]>;
51565169

5170+
// Tcgen05 cp intrinsics
5171+
foreach cta_group = ["cg1", "cg2"] in {
5172+
foreach src_fmt = ["", "b6x16_p32", "b4x16_p64"] in {
5173+
foreach shape = ["128x256b", "4x256b", "128x128b",
5174+
"64x128b_warpx2_02_13",
5175+
"64x128b_warpx2_01_23",
5176+
"32x128b_warpx4"] in {
5177+
defvar intr_suffix = StrJoin<"_", [shape, src_fmt, cta_group]>.ret;
5178+
defvar name_suffix = StrJoin<".", [shape, src_fmt, cta_group]>.ret;
5179+
5180+
def int_nvvm_tcgen05_cp_ # intr_suffix : Intrinsic<[],
5181+
[llvm_tmem_ptr_ty, // tmem_addr
5182+
llvm_i64_ty], // smem descriptor
5183+
[IntrConvergent, IntrInaccessibleMemOrArgMemOnly, NoCapture<ArgIndex<0>>],
5184+
"llvm.nvvm.tcgen05.cp." # name_suffix>;
5185+
}
5186+
}
5187+
}
5188+
51575189
} // let TargetPrefix = "nvvm"

llvm/lib/Target/NVPTX/NVPTXIntrinsics.td

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7704,6 +7704,48 @@ defm TCGEN05_COMMIT_S64_CG2 : TCGEN05_COMMIT_INTR<Int64Regs, "shared", "2">;
77047704
defm TCGEN05_COMMIT_S32_CG1 : TCGEN05_COMMIT_INTR<Int32Regs, "shared", "1">;
77057705
defm TCGEN05_COMMIT_S32_CG2 : TCGEN05_COMMIT_INTR<Int32Regs, "shared", "2">;
77067706

7707+
multiclass TCGEN05_SHIFT_INTR<string num, Intrinsic Intr> {
7708+
def NAME : NVPTXInst<(outs),
7709+
(ins Int32Regs:$tmem_addr),
7710+
!strconcat("tcgen05.shift.cta_group::", num, ".down [$tmem_addr];"),
7711+
[(Intr Int32Regs:$tmem_addr)]>,
7712+
Requires<[hasTcgen05Instructions]>;
7713+
}
7714+
defm TCGEN05_SHIFT_CG1: TCGEN05_SHIFT_INTR<"1", int_nvvm_tcgen05_shift_down_cg1>;
7715+
defm TCGEN05_SHIFT_CG2: TCGEN05_SHIFT_INTR<"2", int_nvvm_tcgen05_shift_down_cg2>;
7716+
7717+
multiclass TCGEN05_CP_INTR<string shape, string src_fmt, string mc = ""> {
7718+
defvar dst_fmt = !if(!eq(src_fmt, ""), "", ".b8x16");
7719+
defvar fmt_asm = StrJoin<".", [dst_fmt, src_fmt]>.ret;
7720+
defvar fmt_intr = StrJoin<"_", [src_fmt]>.ret;
7721+
7722+
defvar shape_mc_asm = StrJoin<".", [shape, mc]>.ret;
7723+
defvar shape_mc_intr = !subst("::", "_", !subst(".", "_", shape_mc_asm));
7724+
7725+
defvar intr_prefix = StrJoin<"_", ["int_nvvm_tcgen05_cp", shape_mc_intr, fmt_intr]>.ret;
7726+
defvar IntrCG1 = !cast<Intrinsic>(intr_prefix # "_cg1");
7727+
defvar IntrCG2 = !cast<Intrinsic>(intr_prefix # "_cg2");
7728+
7729+
def NAME # _cg1 : NVPTXInst<(outs),
7730+
(ins Int32Regs:$tmem_addr, Int64Regs:$sdesc),
7731+
"tcgen05.cp.cta_group::1." # shape_mc_asm # fmt_asm # " [$tmem_addr], $sdesc;",
7732+
[(IntrCG1 Int32Regs:$tmem_addr, Int64Regs:$sdesc)]>,
7733+
Requires<[hasTcgen05Instructions]>;
7734+
def NAME # _cg2 : NVPTXInst<(outs),
7735+
(ins Int32Regs:$tmem_addr, Int64Regs:$sdesc),
7736+
"tcgen05.cp.cta_group::2." # shape_mc_asm # fmt_asm # " [$tmem_addr], $sdesc;",
7737+
[(IntrCG2 Int32Regs:$tmem_addr, Int64Regs:$sdesc)]>,
7738+
Requires<[hasTcgen05Instructions]>;
7739+
}
7740+
7741+
foreach src_fmt = ["", "b6x16_p32", "b4x16_p64"] in {
7742+
defm TCGEN05_CP_128x256b # src_fmt : TCGEN05_CP_INTR<"128x256b", src_fmt>;
7743+
defm TCGEN05_CP_4x256b # src_fmt : TCGEN05_CP_INTR<"4x256b", src_fmt>;
7744+
defm TCGEN05_CP_128x128b # src_fmt : TCGEN05_CP_INTR<"128x128b", src_fmt>;
7745+
defm TCGEN05_CP_64x128_1 # src_fmt : TCGEN05_CP_INTR<"64x128b", src_fmt, "warpx2::02_13">;
7746+
defm TCGEN05_CP_64x128_2 # src_fmt : TCGEN05_CP_INTR<"64x128b", src_fmt, "warpx2::01_23">;
7747+
defm TCGEN05_CP_32x128 # src_fmt : TCGEN05_CP_INTR<"32x128b", src_fmt, "warpx4">;
7748+
}
77077749
} // isConvergent
77087750

77097751
let hasSideEffects = 1 in {

0 commit comments

Comments
 (0)