Skip to content

[NVPTX][NFC] Move more TMA intrinsics lowering to tablegen #147576

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
169 changes: 28 additions & 141 deletions llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2157,16 +2157,9 @@ bool NVPTXScopes::empty() const { return Scopes.size() == 0; }
? NVPTX::CP_ASYNC_BULK_TENSOR_##dir##_##dim##_SHARED32_##mode##suffix \
: NVPTX::CP_ASYNC_BULK_TENSOR_##dir##_##dim##_##mode##suffix)

// Expands to the S2G-family opcode for `op` (S2G or RED), picking the
// cache-hint variant (_CH suffix) when `is_ch` is set; defers the
// shared32-vs-shared64 choice to CP_ASYNC_BULK_TENSOR_OPCODE via `is_s32`.
#define CP_ASYNC_BULK_TENSOR_OPCODE_S2G_IMPL(op, dim, mode, is_ch, is_s32) \
(is_ch ? (CP_ASYNC_BULK_TENSOR_OPCODE(op, dim, mode, is_s32, _CH)) \
: (CP_ASYNC_BULK_TENSOR_OPCODE(op, dim, mode, is_s32, )))

// Selects between the reduction (RED) and plain shared-to-global (S2G)
// tensor-copy opcode families based on `is_reduce`, then forwards the
// cache-hint and shared32 flags to CP_ASYNC_BULK_TENSOR_OPCODE_S2G_IMPL.
#define GET_CP_ASYNC_BULK_TENSOR_OPCODE_S2G(dim, mode, is_reduce, is_ch, \
is_s32) \
(is_reduce \
? (CP_ASYNC_BULK_TENSOR_OPCODE_S2G_IMPL(RED, dim, mode, is_ch, is_s32)) \
: (CP_ASYNC_BULK_TENSOR_OPCODE_S2G_IMPL(S2G, dim, mode, is_ch, \
is_s32)))
// Expands to the tensor-reduction (RED) copy opcode for the given dimension
// and mode, appending the _CH cache-hint suffix when `is_ch` is set.
#define GET_CP_ASYNC_BULK_TENSOR_OPCODE_S2G_RED(dim, mode, is_ch, is_s32) \
(is_ch ? (CP_ASYNC_BULK_TENSOR_OPCODE(RED, dim, mode, is_s32, _CH)) \
: (CP_ASYNC_BULK_TENSOR_OPCODE(RED, dim, mode, is_s32, )))

#define GET_CP_ASYNC_BULK_TENSOR_OPCODE_G2S(dim, mode, is_mc, is_ch, is_s32) \
[&]() -> auto { \
Expand All @@ -2179,48 +2172,45 @@ bool NVPTXScopes::empty() const { return Scopes.size() == 0; }
return CP_ASYNC_BULK_TENSOR_OPCODE(G2S, dim, mode, is_s32, ); \
}()

// Expands to the tensor-prefetch opcode for the given dimension and mode,
// selecting the _CH cache-hint variant when `is_ch` is set. Unlike the
// S2G/G2S macros there is no shared32 flag: prefetch takes no shared-memory
// destination operand.
#define GET_CP_ASYNC_BULK_TENSOR_OPCODE_PREFETCH(dim, mode, is_ch) \
(is_ch ? NVPTX::CP_ASYNC_BULK_TENSOR_PREFETCH_##dim##_##mode##_CH \
: NVPTX::CP_ASYNC_BULK_TENSOR_PREFETCH_##dim##_##mode)

static unsigned GetCpAsyncBulkTensorS2GOpcode(size_t Dim, bool IsShared32,
bool IsCacheHint, bool IsIm2Col,
bool IsReduce = false) {
static unsigned GetCpAsyncBulkTensorS2GReductionOpcode(size_t Dim,
bool IsShared32,
bool IsCacheHint,
bool IsIm2Col) {
if (IsIm2Col) {
switch (Dim) {
case 3:
return GET_CP_ASYNC_BULK_TENSOR_OPCODE_S2G(3D, IM2COL, IsReduce,
IsCacheHint, IsShared32);
return GET_CP_ASYNC_BULK_TENSOR_OPCODE_S2G_RED(3D, IM2COL, IsCacheHint,
IsShared32);
case 4:
return GET_CP_ASYNC_BULK_TENSOR_OPCODE_S2G(4D, IM2COL, IsReduce,
IsCacheHint, IsShared32);
return GET_CP_ASYNC_BULK_TENSOR_OPCODE_S2G_RED(4D, IM2COL, IsCacheHint,
IsShared32);
case 5:
return GET_CP_ASYNC_BULK_TENSOR_OPCODE_S2G(5D, IM2COL, IsReduce,
IsCacheHint, IsShared32);
return GET_CP_ASYNC_BULK_TENSOR_OPCODE_S2G_RED(5D, IM2COL, IsCacheHint,
IsShared32);
default:
llvm_unreachable("Invalid Dimension in im2col mode for "
"GetCpAsyncBulkTensorS2GOpcode.");
"GetCpAsyncBulkTensorS2GReductionOpcode.");
}
} else {
switch (Dim) {
case 1:
return GET_CP_ASYNC_BULK_TENSOR_OPCODE_S2G(1D, TILE, IsReduce,
IsCacheHint, IsShared32);
return GET_CP_ASYNC_BULK_TENSOR_OPCODE_S2G_RED(1D, TILE, IsCacheHint,
IsShared32);
case 2:
return GET_CP_ASYNC_BULK_TENSOR_OPCODE_S2G(2D, TILE, IsReduce,
IsCacheHint, IsShared32);
return GET_CP_ASYNC_BULK_TENSOR_OPCODE_S2G_RED(2D, TILE, IsCacheHint,
IsShared32);
case 3:
return GET_CP_ASYNC_BULK_TENSOR_OPCODE_S2G(3D, TILE, IsReduce,
IsCacheHint, IsShared32);
return GET_CP_ASYNC_BULK_TENSOR_OPCODE_S2G_RED(3D, TILE, IsCacheHint,
IsShared32);
case 4:
return GET_CP_ASYNC_BULK_TENSOR_OPCODE_S2G(4D, TILE, IsReduce,
IsCacheHint, IsShared32);
return GET_CP_ASYNC_BULK_TENSOR_OPCODE_S2G_RED(4D, TILE, IsCacheHint,
IsShared32);
case 5:
return GET_CP_ASYNC_BULK_TENSOR_OPCODE_S2G(5D, TILE, IsReduce,
IsCacheHint, IsShared32);
return GET_CP_ASYNC_BULK_TENSOR_OPCODE_S2G_RED(5D, TILE, IsCacheHint,
IsShared32);
default:
llvm_unreachable(
"Invalid Dimension in tile mode for GetCpAsyncBulkTensorS2GOpcode.");
llvm_unreachable("Invalid Dimension in tile mode for "
"GetCpAsyncBulkTensorS2GReductionOpcode.");
}
}
}
Expand Down Expand Up @@ -2267,39 +2257,6 @@ static unsigned GetCpAsyncBulkTensorG2SOpcode(size_t Dim, bool IsShared32,
}
}

// Returns the NVPTX machine opcode for a cp.async.bulk.tensor prefetch,
// selected by tensor dimensionality (Dim), addressing mode (tile vs.
// im2col), and whether a cache-hint operand is present (IsCacheHint).
// Note the supported dimension ranges differ by mode: im2col is 3D-5D
// only, while tile supports 1D-5D; any other Dim is a selection bug.
static unsigned GetCpAsyncBulkTensorPrefetchOpcode(size_t Dim, bool IsCacheHint,
bool IsIm2Col) {
if (IsIm2Col) {
switch (Dim) {
case 3:
return GET_CP_ASYNC_BULK_TENSOR_OPCODE_PREFETCH(3D, IM2COL, IsCacheHint);
case 4:
return GET_CP_ASYNC_BULK_TENSOR_OPCODE_PREFETCH(4D, IM2COL, IsCacheHint);
case 5:
return GET_CP_ASYNC_BULK_TENSOR_OPCODE_PREFETCH(5D, IM2COL, IsCacheHint);
default:
llvm_unreachable("Invalid Dimension in im2col mode for "
"GetCpAsyncBulkTensorPrefetchOpcode.");
}
} else {
switch (Dim) {
case 1:
return GET_CP_ASYNC_BULK_TENSOR_OPCODE_PREFETCH(1D, TILE, IsCacheHint);
case 2:
return GET_CP_ASYNC_BULK_TENSOR_OPCODE_PREFETCH(2D, TILE, IsCacheHint);
case 3:
return GET_CP_ASYNC_BULK_TENSOR_OPCODE_PREFETCH(3D, TILE, IsCacheHint);
case 4:
return GET_CP_ASYNC_BULK_TENSOR_OPCODE_PREFETCH(4D, TILE, IsCacheHint);
case 5:
return GET_CP_ASYNC_BULK_TENSOR_OPCODE_PREFETCH(5D, TILE, IsCacheHint);
default:
llvm_unreachable("Invalid Dimension in tile mode for "
"GetCpAsyncBulkTensorPrefetchOpcode.");
}
}
}

static size_t GetDimsFromIntrinsic(unsigned IID) {
switch (IID) {
case Intrinsic::nvvm_cp_async_bulk_tensor_g2s_im2col_3d:
Expand Down Expand Up @@ -2364,52 +2321,6 @@ void NVPTXDAGToDAGISel::SelectCpAsyncBulkTensorG2SCommon(SDNode *N,
ReplaceNode(N, CurDAG->getMachineNode(Opcode, DL, N->getVTList(), Ops));
}

// Lowers a cp.async.bulk.tensor.s2g.* intrinsic node to the matching NVPTX
// machine instruction, for both tile and im2col modes. The dimension count
// and cache-hint presence are recovered from the node's operand list; the
// cache-hint operand is forwarded only when its flag operand is 1.
void NVPTXDAGToDAGISel::SelectCpAsyncBulkTensorS2GCommon(SDNode *N,
bool IsIm2Col) {
// We have {Chain, Intrinsic-ID} followed by the actual intrinsic args:
// src, dst, dims{d0...dN}, cache_hint, cache_hint_flag
// NumOperands = {Chain, IID} + {Actual intrinsic args}
// = {2} + {4 + dims}
size_t NumOps = N->getNumOperands();
size_t NumDims = NumOps - 6;
// The last operand is the cache_hint_flag; the hint value itself is only
// meaningful (and only forwarded) when this flag is 1.
bool IsCacheHint = N->getConstantOperandVal(NumOps - 1) == 1;
size_t NumArgs = NumDims + (IsCacheHint ? 3 : 2); // src, dst (+ cache_hint)

SDLoc DL(N);
// Copy the intrinsic args (skipping Chain and IID), then append the chain.
SmallVector<SDValue, 8> Ops(N->ops().slice(2, NumArgs));
Ops.push_back(N->getOperand(0)); // Chain operand

bool IsShared32 =
CurDAG->getDataLayout().getPointerSizeInBits(ADDRESS_SPACE_SHARED) == 32;
unsigned Opcode =
GetCpAsyncBulkTensorS2GOpcode(NumDims, IsShared32, IsCacheHint, IsIm2Col);
ReplaceNode(N, CurDAG->getMachineNode(Opcode, DL, N->getVTList(), Ops));
}

// Lowers a cp.async.bulk.tensor.prefetch.* intrinsic node to the matching
// NVPTX machine instruction. In im2col mode the dimension count cannot be
// derived from the operand count alone (extra im2col offset operands are
// present), so it is looked up from the intrinsic ID instead.
void NVPTXDAGToDAGISel::SelectCpAsyncBulkTensorPrefetchCommon(SDNode *N,
bool IsIm2Col) {
// We have {Chain, Intrinsic-ID} followed by the actual intrinsic args:
// {src, dims{d0...dN}, im2col_offsets{dims-2}
// cache_hint, cache_hint_flag}
// NumOperands = {Chain, IID} + {Actual intrinsic args}
// = {2} + {3 + dims + im2col_offsets}
size_t NumOps = N->getNumOperands();
size_t NumDims = IsIm2Col ? GetDimsFromIntrinsic(N->getConstantOperandVal(1))
: (NumOps - 5);
// Offsets is always 'NumDims - 2' and only for im2col mode
size_t NumOffsets = IsIm2Col ? (NumDims - 2) : 0;
// The last operand is the cache_hint_flag; the hint value is only
// forwarded when this flag is 1.
bool IsCacheHint = N->getConstantOperandVal(NumOps - 1) == 1;
size_t NumArgs = NumDims + NumOffsets + (IsCacheHint ? 2 : 1);

SDLoc DL(N);
// Copy the intrinsic args (skipping Chain and IID), then append the chain.
SmallVector<SDValue, 12> Ops(N->ops().slice(2, NumArgs));
Ops.push_back(N->getOperand(0)); // Chain operand

unsigned Opcode =
GetCpAsyncBulkTensorPrefetchOpcode(NumDims, IsCacheHint, IsIm2Col);
ReplaceNode(N, CurDAG->getMachineNode(Opcode, DL, N->getVTList(), Ops));
}

void NVPTXDAGToDAGISel::SelectCpAsyncBulkTensorReduceCommon(SDNode *N,
unsigned RedOp,
bool IsIm2Col) {
Expand All @@ -2429,8 +2340,8 @@ void NVPTXDAGToDAGISel::SelectCpAsyncBulkTensorReduceCommon(SDNode *N,

bool IsShared32 =
CurDAG->getDataLayout().getPointerSizeInBits(ADDRESS_SPACE_SHARED) == 32;
unsigned Opcode = GetCpAsyncBulkTensorS2GOpcode(
NumDims, IsShared32, IsCacheHint, IsIm2Col, /*IsReduce=*/true);
unsigned Opcode = GetCpAsyncBulkTensorS2GReductionOpcode(
NumDims, IsShared32, IsCacheHint, IsIm2Col);
ReplaceNode(N, CurDAG->getMachineNode(Opcode, DL, N->getVTList(), Ops));
}

Expand Down Expand Up @@ -2550,18 +2461,6 @@ bool NVPTXDAGToDAGISel::tryIntrinsicVoid(SDNode *N) {
switch (IID) {
default:
return false;
case Intrinsic::nvvm_cp_async_bulk_tensor_s2g_tile_1d:
case Intrinsic::nvvm_cp_async_bulk_tensor_s2g_tile_2d:
case Intrinsic::nvvm_cp_async_bulk_tensor_s2g_tile_3d:
case Intrinsic::nvvm_cp_async_bulk_tensor_s2g_tile_4d:
case Intrinsic::nvvm_cp_async_bulk_tensor_s2g_tile_5d:
SelectCpAsyncBulkTensorS2GCommon(N);
return true;
case Intrinsic::nvvm_cp_async_bulk_tensor_s2g_im2col_3d:
case Intrinsic::nvvm_cp_async_bulk_tensor_s2g_im2col_4d:
case Intrinsic::nvvm_cp_async_bulk_tensor_s2g_im2col_5d:
SelectCpAsyncBulkTensorS2GCommon(N, /*IsIm2Col=*/true);
return true;
case Intrinsic::nvvm_cp_async_bulk_tensor_g2s_tile_1d:
case Intrinsic::nvvm_cp_async_bulk_tensor_g2s_tile_2d:
case Intrinsic::nvvm_cp_async_bulk_tensor_g2s_tile_3d:
Expand All @@ -2574,18 +2473,6 @@ bool NVPTXDAGToDAGISel::tryIntrinsicVoid(SDNode *N) {
case Intrinsic::nvvm_cp_async_bulk_tensor_g2s_im2col_5d:
SelectCpAsyncBulkTensorG2SCommon(N, /*IsIm2Col=*/true);
return true;
case Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_1d:
case Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_2d:
case Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_3d:
case Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_4d:
case Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_5d:
SelectCpAsyncBulkTensorPrefetchCommon(N);
return true;
case Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_3d:
case Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_4d:
case Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_5d:
SelectCpAsyncBulkTensorPrefetchCommon(N, /*IsIm2Col=*/true);
return true;
case Intrinsic::nvvm_cp_async_bulk_tensor_reduce_add_tile_1d:
case Intrinsic::nvvm_cp_async_bulk_tensor_reduce_add_tile_2d:
case Intrinsic::nvvm_cp_async_bulk_tensor_reduce_add_tile_3d:
Expand Down
2 changes: 0 additions & 2 deletions llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.h
Original file line number Diff line number Diff line change
Expand Up @@ -92,8 +92,6 @@ class LLVM_LIBRARY_VISIBILITY NVPTXDAGToDAGISel : public SelectionDAGISel {
void SelectV2I64toI128(SDNode *N);
void SelectI128toV2I64(SDNode *N);
void SelectCpAsyncBulkTensorG2SCommon(SDNode *N, bool IsIm2Col = false);
void SelectCpAsyncBulkTensorS2GCommon(SDNode *N, bool IsIm2Col = false);
void SelectCpAsyncBulkTensorPrefetchCommon(SDNode *N, bool IsIm2Col = false);
void SelectCpAsyncBulkTensorReduceCommon(SDNode *N, unsigned RedOp,
bool IsIm2Col = false);
void SelectTcgen05Ld(SDNode *N, bool hasOffset = false);
Expand Down
Loading