[NVPTX][NFC] Move more TMA intrinsics lowering to tablegen #147576

Open: wants to merge 1 commit into main from durgadossr/nvptx_tma_nfc_refactor

Conversation

@durga4github (Contributor) commented Jul 8, 2025

This patch moves the lowering of the TMA Tensor prefetch
and S2G-copy intrinsics to tablegen itself. This is in preparation
for adding Blackwell-specific additions to these intrinsics.

The lowering of the TMA reduction intrinsics remains in C++, so the
related macro names are updated to reflect their now reduction-only usage.

The existing tests have full coverage and continue to pass as expected.
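
For reference, a minimal LLVM IR sketch of one of the relocated intrinsics. The signature is reconstructed from the operand-order comments in the C++ selection code removed below (src, tmap, dims, cache_hint, cache_hint_flag); the authoritative forms are in the existing cp-async-bulk-tensor-s2g.ll test.

; Hedged sketch: operand order per the removed selection-code comments.
declare void @llvm.nvvm.cp.async.bulk.tensor.s2g.tile.1d(ptr addrspace(3), ptr, i32, i64, i1)

define void @s2g_tile_1d(ptr addrspace(3) %src, ptr %tmap, i32 %d0, i64 %ch) {
  ; The trailing i1 flag selects the .L2::cache_hint variant of the
  ; instruction; with i1 0, the i64 cache-hint operand is ignored.
  call void @llvm.nvvm.cp.async.bulk.tensor.s2g.tile.1d(ptr addrspace(3) %src, ptr %tmap, i32 %d0, i64 %ch, i1 1)
  ret void
}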

@llvmbot (Member) commented Jul 8, 2025

@llvm/pr-subscribers-backend-nvptx

Author: Durgadoss R (durga4github)

Changes

This patch moves the lowering of the TMA Tensor prefetch
and S2G-copy intrinsics to tablegen itself. This is in preparation
for adding Blackwell-specific additions to these intrinsics.


Patch is 36.92 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/147576.diff

5 Files Affected:

  • (modified) llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp (+25-138)
  • (modified) llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.h (-2)
  • (modified) llvm/lib/Target/NVPTX/NVPTXIntrinsics.td (+100-65)
  • (modified) llvm/test/CodeGen/NVPTX/cp-async-bulk-tensor-prefetch.ll (+5-5)
  • (modified) llvm/test/CodeGen/NVPTX/cp-async-bulk-tensor-s2g.ll (+10-10)
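
For the prefetch path, the removed C++ code documents the operand layout as {tmap, d0...dN, (N-2) im2col offsets, cache_hint, cache_hint_flag}. Below is a hedged IR sketch of the 3D im2col form, reconstructed from those comments; see cp-async-bulk-tensor-prefetch.ll for the checked-in forms.

; Hedged sketch: layout per the removed selection-code comments.
declare void @llvm.nvvm.cp.async.bulk.tensor.prefetch.im2col.3d(ptr, i32, i32, i32, i16, i64, i1)

define void @prefetch_im2col_3d(ptr %tmap, i32 %d0, i32 %d1, i32 %d2, i16 %off0) {
  ; im2col mode carries dim-2 = 1 i16 offset when dim = 3; the final i1 0
  ; requests the variant without an L2 cache hint, so the i64 is ignored.
  call void @llvm.nvvm.cp.async.bulk.tensor.prefetch.im2col.3d(ptr %tmap, i32 %d0, i32 %d1, i32 %d2, i16 %off0, i64 0, i1 0)
  ret void
}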
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp b/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp
index 5631342ecc13e..e4e337b7f167c 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp
@@ -2157,16 +2157,9 @@ bool NVPTXScopes::empty() const { return Scopes.size() == 0; }
        ? NVPTX::CP_ASYNC_BULK_TENSOR_##dir##_##dim##_SHARED32_##mode##suffix   \
        : NVPTX::CP_ASYNC_BULK_TENSOR_##dir##_##dim##_##mode##suffix)
 
-#define CP_ASYNC_BULK_TENSOR_OPCODE_S2G_IMPL(op, dim, mode, is_ch, is_s32)     \
-  (is_ch ? (CP_ASYNC_BULK_TENSOR_OPCODE(op, dim, mode, is_s32, _CH))           \
-         : (CP_ASYNC_BULK_TENSOR_OPCODE(op, dim, mode, is_s32, )))
-
-#define GET_CP_ASYNC_BULK_TENSOR_OPCODE_S2G(dim, mode, is_reduce, is_ch,       \
-                                            is_s32)                            \
-  (is_reduce                                                                   \
-       ? (CP_ASYNC_BULK_TENSOR_OPCODE_S2G_IMPL(RED, dim, mode, is_ch, is_s32)) \
-       : (CP_ASYNC_BULK_TENSOR_OPCODE_S2G_IMPL(S2G, dim, mode, is_ch,          \
-                                               is_s32)))
+#define GET_CP_ASYNC_BULK_TENSOR_OPCODE_S2G_RED(dim, mode, is_ch, is_s32)      \
+  (is_ch ? (CP_ASYNC_BULK_TENSOR_OPCODE(RED, dim, mode, is_s32, _CH))          \
+         : (CP_ASYNC_BULK_TENSOR_OPCODE(RED, dim, mode, is_s32, )))
 
 #define GET_CP_ASYNC_BULK_TENSOR_OPCODE_G2S(dim, mode, is_mc, is_ch, is_s32)   \
   [&]() -> auto {                                                              \
@@ -2179,24 +2172,21 @@ bool NVPTXScopes::empty() const { return Scopes.size() == 0; }
     return CP_ASYNC_BULK_TENSOR_OPCODE(G2S, dim, mode, is_s32, );              \
   }()
 
-#define GET_CP_ASYNC_BULK_TENSOR_OPCODE_PREFETCH(dim, mode, is_ch)             \
-  (is_ch ? NVPTX::CP_ASYNC_BULK_TENSOR_PREFETCH_##dim##_##mode##_CH            \
-         : NVPTX::CP_ASYNC_BULK_TENSOR_PREFETCH_##dim##_##mode)
-
-static unsigned GetCpAsyncBulkTensorS2GOpcode(size_t Dim, bool IsShared32,
-                                              bool IsCacheHint, bool IsIm2Col,
-                                              bool IsReduce = false) {
+static unsigned GetCpAsyncBulkTensorS2GReductionOpcode(size_t Dim,
+                                                       bool IsShared32,
+                                                       bool IsCacheHint,
+                                                       bool IsIm2Col) {
   if (IsIm2Col) {
     switch (Dim) {
     case 3:
-      return GET_CP_ASYNC_BULK_TENSOR_OPCODE_S2G(3D, IM2COL, IsReduce,
-                                                 IsCacheHint, IsShared32);
+      return GET_CP_ASYNC_BULK_TENSOR_OPCODE_S2G_RED(3D, IM2COL, IsCacheHint,
+                                                     IsShared32);
     case 4:
-      return GET_CP_ASYNC_BULK_TENSOR_OPCODE_S2G(4D, IM2COL, IsReduce,
-                                                 IsCacheHint, IsShared32);
+      return GET_CP_ASYNC_BULK_TENSOR_OPCODE_S2G_RED(4D, IM2COL, IsCacheHint,
+                                                     IsShared32);
     case 5:
-      return GET_CP_ASYNC_BULK_TENSOR_OPCODE_S2G(5D, IM2COL, IsReduce,
-                                                 IsCacheHint, IsShared32);
+      return GET_CP_ASYNC_BULK_TENSOR_OPCODE_S2G_RED(5D, IM2COL, IsCacheHint,
+                                                     IsShared32);
     default:
       llvm_unreachable("Invalid Dimension in im2col mode for "
                        "GetCpAsyncBulkTensorS2GOpcode.");
@@ -2204,20 +2194,20 @@ static unsigned GetCpAsyncBulkTensorS2GOpcode(size_t Dim, bool IsShared32,
   } else {
     switch (Dim) {
     case 1:
-      return GET_CP_ASYNC_BULK_TENSOR_OPCODE_S2G(1D, TILE, IsReduce,
-                                                 IsCacheHint, IsShared32);
+      return GET_CP_ASYNC_BULK_TENSOR_OPCODE_S2G_RED(1D, TILE, IsCacheHint,
+                                                     IsShared32);
     case 2:
-      return GET_CP_ASYNC_BULK_TENSOR_OPCODE_S2G(2D, TILE, IsReduce,
-                                                 IsCacheHint, IsShared32);
+      return GET_CP_ASYNC_BULK_TENSOR_OPCODE_S2G_RED(2D, TILE, IsCacheHint,
+                                                     IsShared32);
     case 3:
-      return GET_CP_ASYNC_BULK_TENSOR_OPCODE_S2G(3D, TILE, IsReduce,
-                                                 IsCacheHint, IsShared32);
+      return GET_CP_ASYNC_BULK_TENSOR_OPCODE_S2G_RED(3D, TILE, IsCacheHint,
+                                                     IsShared32);
     case 4:
-      return GET_CP_ASYNC_BULK_TENSOR_OPCODE_S2G(4D, TILE, IsReduce,
-                                                 IsCacheHint, IsShared32);
+      return GET_CP_ASYNC_BULK_TENSOR_OPCODE_S2G_RED(4D, TILE, IsCacheHint,
+                                                     IsShared32);
     case 5:
-      return GET_CP_ASYNC_BULK_TENSOR_OPCODE_S2G(5D, TILE, IsReduce,
-                                                 IsCacheHint, IsShared32);
+      return GET_CP_ASYNC_BULK_TENSOR_OPCODE_S2G_RED(5D, TILE, IsCacheHint,
+                                                     IsShared32);
     default:
       llvm_unreachable(
           "Invalid Dimension in tile mode for GetCpAsyncBulkTensorS2GOpcode.");
@@ -2267,39 +2257,6 @@ static unsigned GetCpAsyncBulkTensorG2SOpcode(size_t Dim, bool IsShared32,
   }
 }
 
-static unsigned GetCpAsyncBulkTensorPrefetchOpcode(size_t Dim, bool IsCacheHint,
-                                                   bool IsIm2Col) {
-  if (IsIm2Col) {
-    switch (Dim) {
-    case 3:
-      return GET_CP_ASYNC_BULK_TENSOR_OPCODE_PREFETCH(3D, IM2COL, IsCacheHint);
-    case 4:
-      return GET_CP_ASYNC_BULK_TENSOR_OPCODE_PREFETCH(4D, IM2COL, IsCacheHint);
-    case 5:
-      return GET_CP_ASYNC_BULK_TENSOR_OPCODE_PREFETCH(5D, IM2COL, IsCacheHint);
-    default:
-      llvm_unreachable("Invalid Dimension in im2col mode for "
-                       "GetCpAsyncBulkTensorPrefetchOpcode.");
-    }
-  } else {
-    switch (Dim) {
-    case 1:
-      return GET_CP_ASYNC_BULK_TENSOR_OPCODE_PREFETCH(1D, TILE, IsCacheHint);
-    case 2:
-      return GET_CP_ASYNC_BULK_TENSOR_OPCODE_PREFETCH(2D, TILE, IsCacheHint);
-    case 3:
-      return GET_CP_ASYNC_BULK_TENSOR_OPCODE_PREFETCH(3D, TILE, IsCacheHint);
-    case 4:
-      return GET_CP_ASYNC_BULK_TENSOR_OPCODE_PREFETCH(4D, TILE, IsCacheHint);
-    case 5:
-      return GET_CP_ASYNC_BULK_TENSOR_OPCODE_PREFETCH(5D, TILE, IsCacheHint);
-    default:
-      llvm_unreachable("Invalid Dimension in tile mode for "
-                       "GetCpAsyncBulkTensorPrefetchOpcode.");
-    }
-  }
-}
-
 static size_t GetDimsFromIntrinsic(unsigned IID) {
   switch (IID) {
   case Intrinsic::nvvm_cp_async_bulk_tensor_g2s_im2col_3d:
@@ -2364,52 +2321,6 @@ void NVPTXDAGToDAGISel::SelectCpAsyncBulkTensorG2SCommon(SDNode *N,
   ReplaceNode(N, CurDAG->getMachineNode(Opcode, DL, N->getVTList(), Ops));
 }
 
-void NVPTXDAGToDAGISel::SelectCpAsyncBulkTensorS2GCommon(SDNode *N,
-                                                         bool IsIm2Col) {
-  // We have {Chain, Intrinsic-ID} followed by the actual intrisic args:
-  // src, dst, dims{d0...dN}, cache_hint, cache_hint_flag
-  // NumOperands = {Chain, IID} + {Actual intrinsic args}
-  //             = {2}          + {4 + dims}
-  size_t NumOps = N->getNumOperands();
-  size_t NumDims = NumOps - 6;
-  bool IsCacheHint = N->getConstantOperandVal(NumOps - 1) == 1;
-  size_t NumArgs = NumDims + (IsCacheHint ? 3 : 2); // src, dst, cache_hint
-
-  SDLoc DL(N);
-  SmallVector<SDValue, 8> Ops(N->ops().slice(2, NumArgs));
-  Ops.push_back(N->getOperand(0)); // Chain operand
-
-  bool IsShared32 =
-      CurDAG->getDataLayout().getPointerSizeInBits(ADDRESS_SPACE_SHARED) == 32;
-  unsigned Opcode =
-      GetCpAsyncBulkTensorS2GOpcode(NumDims, IsShared32, IsCacheHint, IsIm2Col);
-  ReplaceNode(N, CurDAG->getMachineNode(Opcode, DL, N->getVTList(), Ops));
-}
-
-void NVPTXDAGToDAGISel::SelectCpAsyncBulkTensorPrefetchCommon(SDNode *N,
-                                                              bool IsIm2Col) {
-  // We have {Chain, Intrinsic-ID} followed by the actual intrisic args:
-  // {src, dims{d0...dN}, im2col_offsets{dims-2}
-  // cache_hint, cache_hint_flag}
-  // NumOperands = {Chain, IID} + {Actual intrinsic args}
-  //             = {2}          + {3 + dims + im2col_offsets}
-  size_t NumOps = N->getNumOperands();
-  size_t NumDims = IsIm2Col ? GetDimsFromIntrinsic(N->getConstantOperandVal(1))
-                            : (NumOps - 5);
-  // Offsets is always 'NumDims - 2' and only for im2col mode
-  size_t NumOffsets = IsIm2Col ? (NumDims - 2) : 0;
-  bool IsCacheHint = N->getConstantOperandVal(NumOps - 1) == 1;
-  size_t NumArgs = NumDims + NumOffsets + (IsCacheHint ? 2 : 1);
-
-  SDLoc DL(N);
-  SmallVector<SDValue, 12> Ops(N->ops().slice(2, NumArgs));
-  Ops.push_back(N->getOperand(0)); // Chain operand
-
-  unsigned Opcode =
-      GetCpAsyncBulkTensorPrefetchOpcode(NumDims, IsCacheHint, IsIm2Col);
-  ReplaceNode(N, CurDAG->getMachineNode(Opcode, DL, N->getVTList(), Ops));
-}
-
 void NVPTXDAGToDAGISel::SelectCpAsyncBulkTensorReduceCommon(SDNode *N,
                                                             unsigned RedOp,
                                                             bool IsIm2Col) {
@@ -2429,8 +2340,8 @@ void NVPTXDAGToDAGISel::SelectCpAsyncBulkTensorReduceCommon(SDNode *N,
 
   bool IsShared32 =
       CurDAG->getDataLayout().getPointerSizeInBits(ADDRESS_SPACE_SHARED) == 32;
-  unsigned Opcode = GetCpAsyncBulkTensorS2GOpcode(
-      NumDims, IsShared32, IsCacheHint, IsIm2Col, /*IsReduce=*/true);
+  unsigned Opcode = GetCpAsyncBulkTensorS2GReductionOpcode(
+      NumDims, IsShared32, IsCacheHint, IsIm2Col);
   ReplaceNode(N, CurDAG->getMachineNode(Opcode, DL, N->getVTList(), Ops));
 }
 
@@ -2550,18 +2461,6 @@ bool NVPTXDAGToDAGISel::tryIntrinsicVoid(SDNode *N) {
   switch (IID) {
   default:
     return false;
-  case Intrinsic::nvvm_cp_async_bulk_tensor_s2g_tile_1d:
-  case Intrinsic::nvvm_cp_async_bulk_tensor_s2g_tile_2d:
-  case Intrinsic::nvvm_cp_async_bulk_tensor_s2g_tile_3d:
-  case Intrinsic::nvvm_cp_async_bulk_tensor_s2g_tile_4d:
-  case Intrinsic::nvvm_cp_async_bulk_tensor_s2g_tile_5d:
-    SelectCpAsyncBulkTensorS2GCommon(N);
-    return true;
-  case Intrinsic::nvvm_cp_async_bulk_tensor_s2g_im2col_3d:
-  case Intrinsic::nvvm_cp_async_bulk_tensor_s2g_im2col_4d:
-  case Intrinsic::nvvm_cp_async_bulk_tensor_s2g_im2col_5d:
-    SelectCpAsyncBulkTensorS2GCommon(N, /*IsIm2Col=*/true);
-    return true;
   case Intrinsic::nvvm_cp_async_bulk_tensor_g2s_tile_1d:
   case Intrinsic::nvvm_cp_async_bulk_tensor_g2s_tile_2d:
   case Intrinsic::nvvm_cp_async_bulk_tensor_g2s_tile_3d:
@@ -2574,18 +2473,6 @@ bool NVPTXDAGToDAGISel::tryIntrinsicVoid(SDNode *N) {
   case Intrinsic::nvvm_cp_async_bulk_tensor_g2s_im2col_5d:
     SelectCpAsyncBulkTensorG2SCommon(N, /*IsIm2Col=*/true);
     return true;
-  case Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_1d:
-  case Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_2d:
-  case Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_3d:
-  case Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_4d:
-  case Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_5d:
-    SelectCpAsyncBulkTensorPrefetchCommon(N);
-    return true;
-  case Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_3d:
-  case Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_4d:
-  case Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_5d:
-    SelectCpAsyncBulkTensorPrefetchCommon(N, /*IsIm2Col=*/true);
-    return true;
   case Intrinsic::nvvm_cp_async_bulk_tensor_reduce_add_tile_1d:
   case Intrinsic::nvvm_cp_async_bulk_tensor_reduce_add_tile_2d:
   case Intrinsic::nvvm_cp_async_bulk_tensor_reduce_add_tile_3d:
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.h b/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.h
index 0e4dec1adca67..016a2a349f9f5 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.h
+++ b/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.h
@@ -92,8 +92,6 @@ class LLVM_LIBRARY_VISIBILITY NVPTXDAGToDAGISel : public SelectionDAGISel {
   void SelectV2I64toI128(SDNode *N);
   void SelectI128toV2I64(SDNode *N);
   void SelectCpAsyncBulkTensorG2SCommon(SDNode *N, bool IsIm2Col = false);
-  void SelectCpAsyncBulkTensorS2GCommon(SDNode *N, bool IsIm2Col = false);
-  void SelectCpAsyncBulkTensorPrefetchCommon(SDNode *N, bool IsIm2Col = false);
   void SelectCpAsyncBulkTensorReduceCommon(SDNode *N, unsigned RedOp,
                                            bool IsIm2Col = false);
   void SelectTcgen05Ld(SDNode *N, bool hasOffset = false);
diff --git a/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td b/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
index cc1fd027d8515..6f82dee2d91a6 100644
--- a/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
+++ b/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
@@ -560,6 +560,27 @@ defm CP_ASYNC_BULK_PREFETCH_CH : CP_ASYNC_BULK_PREFETCH_INTR<has_ch = 1>;
 // TMA Async Bulk Tensor Copy Functions
 //-------------------------------------
 
+class TMA_DIMS_UTIL<int dim> {
+  dag ins_dag = !dag(ins, !listsplat(B32, dim), !foreach(i, !range(dim), "d" # i));
+  string base_str = !interleave(!foreach(i, !range(dim), "$d" # i), ", ");
+}
+
+class TMA_IM2COL_UTIL<int dim, string mode> {
+  // For im2col_w/w_128 modes, number of offsets is always 2.
+  // For im2col mode, offsets is (dim - 2).
+  // For non-im2col modes (i.e. tile) there are no offsets.
+  int offsets = !cond(
+                  !eq(mode, "im2col") : !sub(dim, 2),
+                  !eq(mode, "im2col_w") : 2,
+                  !eq(mode, "im2col_w_128") : 2,
+                  true : 0); // for all other modes
+
+  dag ins_dag = !if(!gt(offsets, 0),
+    !dag(ins, !listsplat(B16, offsets), !foreach(i, !range(offsets), "im2col" # i)),
+    (ins));
+  string base_str = !interleave(!foreach(i, !range(offsets), "$im2col" # i), ", ");
+}
+
 // From Global to Shared memory (G2S)
 class G2S_STRINGS<int dim, string mode, bit mc, bit ch, bit is_shared32 = 0> {
   string prefix = "cp.async.bulk.tensor";
@@ -628,39 +649,44 @@ foreach dim = [1, 2, 3, 4, 5] in {
   }
 }
 
-// From Shared to Global memory (S2G)
-class S2G_STRINGS<int dim, string mode, bit ch,
-                  bit is_shared32 = 0, bit is_reduce = 0> {
-  string dir = "global.shared::cta";
-  string completion = "bulk_group";
-  string inst_name = !if(is_reduce, "cp.reduce", "cp")
-                     # ".async.bulk.tensor"
-                     # "." # dim # "d"
-                     # "." # dir
-                     # "." # mode
-                     # "." # completion
-                     # !if(ch, ".L2::cache_hint", "");
-  string intr_name = "CP_ASYNC_BULK_TENSOR_"
-                     # !if(is_reduce, "RED_", "S2G_")
-                     # dim # "D"
-                     # !if(is_shared32, "_SHARED32", "")
-                     # !if(!eq(mode, "tile"), "_TILE", "_IM2COL");
-}
-
-multiclass CP_ASYNC_BULK_TENSOR_S2G_INTR<int dim, bit shared32, string mode> {
-  defvar dims_dag = !dag(ins, !listsplat(B32, dim), !foreach(i, !range(dim), "d" # i));
-  defvar dims_str = !interleave(!foreach(i, !range(dim), "$d" # i), ", ");
+multiclass TMA_TENSOR_S2G_INTR<int dim, string mode,
+                               list<Predicate> pred = [hasPTX<80>, hasSM<90>]> {
+  defvar dims_dag = TMA_DIMS_UTIL<dim>.ins_dag;
+  defvar dims_str = TMA_DIMS_UTIL<dim>.base_str;
   defvar asm_str = " [$tmap, {{" # dims_str # "}}], [$src]";
-  defvar rc = !if(shared32, B32, B64);
+
+  defvar intr = !cast<Intrinsic>(
+                  "int_nvvm_cp_async_bulk_tensor_s2g_" # mode # "_" # dim # "d");
+  defvar intr_dag = !con(
+                      (intr addr:$src, B64:$tmap),
+                      !setdagop(dims_dag, intr),
+                      (intr B64:$ch));
+
+  // For im2col mode, the actual asm_str is "im2col_no_offs"
+  defvar mode_asm_str = !if(!eq(mode, "im2col"),
+                            "im2col_no_offs", mode);
+  defvar prefix = "cp.async.bulk.tensor"
+                  # "." # dim # "d"
+                  # ".global.shared::cta"
+                  # "." # mode_asm_str
+                  # ".bulk_group";
 
   def "" : NVPTXInst<(outs),
-            !con((ins rc:$src, B64:$tmap), dims_dag),
-            !strconcat(S2G_STRINGS<dim, mode, 0>.inst_name, asm_str, ";"), []>,
-            Requires<[hasPTX<80>, hasSM<90>]>;
+             !con((ins ADDR:$src, B64:$tmap), dims_dag, (ins B64:$ch)),
+             prefix # asm_str # ";",
+             [!con(intr_dag, (intr 0))]>,
+             Requires<pred>;
   def _CH : NVPTXInst<(outs),
-                  !con((ins rc:$src, B64:$tmap), dims_dag, (ins B64:$ch)),
-                  !strconcat(S2G_STRINGS<dim, mode, 1>.inst_name, asm_str, ", $ch;"), []>,
-                  Requires<[hasPTX<80>, hasSM<90>]>;
+              !con((ins ADDR:$src, B64:$tmap), dims_dag, (ins B64:$ch)),
+              prefix # ".L2::cache_hint" # asm_str # ", $ch;",
+              [!con(intr_dag, (intr -1))]>,
+              Requires<pred>;
+}
+foreach dim = 1...5 in {
+  foreach mode = !if(!ge(dim, 3), ["tile", "im2col"], ["tile"]) in {
+    defvar suffix = !toupper(mode) # "_" # dim # "D";
+    defm TMA_TENSOR_S2G_ # suffix : TMA_TENSOR_S2G_INTR<dim, mode>;
+  }
 }
 
 def TMAReductionFlags : Operand<i32> {
@@ -674,8 +700,11 @@ multiclass CP_ASYNC_BULK_TENSOR_REDUCE_INTR<int dim, bit shared32, string mode>
   defvar asm_str = " [$tmap, {{" # dims_str # "}}], [$src]";
   defvar rc = !if(shared32, B32, B64);
 
+  // For im2col mode, the actual asm_str is "im2col_no_offs"
+  defvar mode_asm_str = !if(!eq(mode, "im2col"),
+                            "im2col_no_offs", mode);
   defvar prefix = "cp.reduce.async.bulk.tensor" # "." # dim # "d" # ".global.shared::cta";
-  defvar suffix = "." # mode # ".bulk_group";
+  defvar suffix = "." # mode_asm_str # ".bulk_group";
 
   def "" : NVPTXInst<(outs),
             !con((ins rc:$src, B64:$tmap), dims_dag, (ins TMAReductionFlags:$red_op)),
@@ -689,10 +718,11 @@ multiclass CP_ASYNC_BULK_TENSOR_REDUCE_INTR<int dim, bit shared32, string mode>
 
 foreach dim = [1, 2, 3, 4, 5] in {
   foreach shared32 = [true, false] in {
-    foreach mode = !if(!ge(dim, 3), ["tile", "im2col_no_offs"], ["tile"]) in {
-      defm S2G_STRINGS<dim, mode, 0, shared32>.intr_name :
-        CP_ASYNC_BULK_TENSOR_S2G_INTR<dim, shared32, mode>;
-      defm S2G_STRINGS<dim, mode, 0, shared32, 1>.intr_name :
+    foreach mode = !if(!ge(dim, 3), ["tile", "im2col"], ["tile"]) in {
+      defvar suffix = dim # "D"
+                      # !if(shared32, "_SHARED32", "")
+                      # "_" # !toupper(mode);
+      defm CP_ASYNC_BULK_TENSOR_RED_ # suffix :
         CP_ASYNC_BULK_TENSOR_REDUCE_INTR<dim, shared32, mode>;
     }
   }
@@ -707,40 +737,45 @@ class PREFETCH_STRINGS<int dim, string mode, bit ch> {
                      # "." # dir
                      # "." # mode
                      # !if(ch, ".L2::cache_hint", "");
-  string intr_name = "CP_ASYNC_BULK_TENSOR_PREFETCH_"
-                     # dim # "D"
-                     # !if(!eq(mode, "tile"), "_TILE", "_IM2COL");
-}
-
-multiclass CP_ASYNC_BULK_TENSOR_PREFETCH_INTR<int dim, string mode> {
-  defvar dims_dag = !dag(ins, !listsplat(B32, dim), !foreach(i, !range(dim), "d" # i));
-  defvar dims_str = !interleave(!foreach(i, !range(dim), "$d" # i), ", ");
-  defvar asm_str_default = " [$tmap, {{" # dims_str # "}}]";
-
-  defvar num_im2col = !if(!ge(dim, 3), !add(dim, -2), 0);
-  defvar im2col_dag = !if(!eq(mode, "im2col"),
-    !dag(ins, !listsplat(B16, num_im2col), !foreach(i, !range(num_im2col), "im2col" # i)),
-    (ins));
-  defvar im2col_str = !interleave(!foreach(i, !range(num_im2col), "$im2col" # i), ", ");
-  defvar im2col_asm_str = ", {{" # im2col_str # "}}";
-
-  defvar asm_str = !if(!eq(mode, "im2col"),
-    !strconcat(asm_str_default, im2col_asm_str), asm_str_default);
-
-  def "" : NVPTXInst<(outs),
-            !con((ins B64:$tmap), dims_dag, im2col_dag),
-            !strconcat(PREFETCH_STRINGS<dim, mode, 0>.inst_name, asm_str, ";"), []>,
-            Requires<[hasPTX<80>, hasSM<90>]>;
-  def _CH : NVPTXInst<(outs),
-                  !con((ins B64:$tmap), dims_dag, im2col_dag, (ins B64:$ch)),
-                  !strconcat(PREFETCH_STRINGS<dim, mode, 1>.inst_name, asm_st...
[truncated]
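
Note why the reduction path stays in C++: several reduce intrinsics (add, min, max, ...) funnel into the shared CP_ASYNC_BULK_TENSOR_RED_* instructions, and SelectCpAsyncBulkTensorReduceCommon injects the per-intrinsic RedOp as a TMAReductionFlags immediate ($red_op) during selection. A hedged IR sketch of one such intrinsic, assuming a signature analogous to the S2G one above:

; Hedged sketch: assumed to mirror the s2g signature shown earlier.
declare void @llvm.nvvm.cp.async.bulk.tensor.reduce.add.tile.1d(ptr addrspace(3), ptr, i32, i64, i1)

define void @reduce_add_tile_1d(ptr addrspace(3) %src, ptr %tmap, i32 %d0) {
  ; Selected in C++; the reduction kind (add) becomes the $red_op
  ; immediate on the shared CP_ASYNC_BULK_TENSOR_RED_* instruction.
  call void @llvm.nvvm.cp.async.bulk.tensor.reduce.add.tile.1d(ptr addrspace(3) %src, ptr %tmap, i32 %d0, i64 0, i1 0)
  ret void
}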

durga4github requested a review from Artem-B on July 8, 2025 17:59
durga4github force-pushed the durgadossr/nvptx_tma_nfc_refactor branch from 6baf029 to 0535ac2 on July 8, 2025 18:05
@Artem-B (Member) left a comment

LGTM with a few nits.

durga4github force-pushed the durgadossr/nvptx_tma_nfc_refactor branch 2 times, most recently from 2373fb3 to 96682f8 on July 9, 2025 12:07
This patch moves the lowering of the TMA Tensor prefetch
and S2G-copy intrinsics to tablegen itself. This is in
preparation for adding Blackwell-specific additions to
these intrinsics.

Signed-off-by: Durgadoss R <durgadossr@nvidia.com>
durga4github force-pushed the durgadossr/nvptx_tma_nfc_refactor branch from 96682f8 to a4b5960 on July 9, 2025 12:31