-
Notifications
You must be signed in to change notification settings - Fork 14.4k
[NVPTX] Use cvt.sat to lower min/max clamping to i8 and i16 ranges #143016
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from 1 commit
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -28,6 +28,7 @@ | |
#include "llvm/CodeGen/MachineFunction.h" | ||
#include "llvm/CodeGen/MachineJumpTableInfo.h" | ||
#include "llvm/CodeGen/MachineMemOperand.h" | ||
#include "llvm/CodeGen/SDPatternMatch.h" | ||
#include "llvm/CodeGen/SelectionDAG.h" | ||
#include "llvm/CodeGen/SelectionDAGNodes.h" | ||
#include "llvm/CodeGen/TargetCallingConv.h" | ||
|
@@ -74,6 +75,7 @@ | |
#define DEBUG_TYPE "nvptx-lower" | ||
|
||
using namespace llvm; | ||
using namespace llvm::SDPatternMatch; | ||
|
||
static cl::opt<bool> sched4reg( | ||
"nvptx-sched4reg", | ||
|
@@ -659,6 +661,11 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM, | |
setOperationAction(ISD::BR_CC, VT, Expand); | ||
} | ||
|
||
setOperationAction({ISD::TRUNCATE_SSAT_S, ISD::TRUNCATE_SSAT_U}, MVT::i16, | ||
Legal); | ||
setOperationAction({ISD::TRUNCATE_SSAT_S, ISD::TRUNCATE_SSAT_U}, MVT::i8, | ||
Custom); | ||
|
||
// Some SIGN_EXTEND_INREG can be done using cvt instruction. | ||
// For others we will expand to a SHL/SRA pair. | ||
setOperationAction(ISD::SIGN_EXTEND_INREG, MVT::i64, Legal); | ||
|
@@ -836,7 +843,8 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM, | |
// We have some custom DAG combine patterns for these nodes | ||
setTargetDAGCombine({ISD::ADD, ISD::AND, ISD::EXTRACT_VECTOR_ELT, ISD::FADD, | ||
ISD::MUL, ISD::SHL, ISD::SREM, ISD::UREM, ISD::VSELECT, | ||
ISD::BUILD_VECTOR, ISD::ADDRSPACECAST}); | ||
ISD::BUILD_VECTOR, ISD::ADDRSPACECAST, ISD::SMIN, | ||
ISD::SMAX}); | ||
|
||
// setcc for f16x2 and bf16x2 needs special handling to prevent | ||
// legalizer's attempt to scalarize it due to v2i1 not being legal. | ||
|
@@ -1081,6 +1089,8 @@ const char *NVPTXTargetLowering::getTargetNodeName(unsigned Opcode) const { | |
MAKE_CASE(NVPTXISD::PseudoUseParam) | ||
MAKE_CASE(NVPTXISD::UNPACK_VECTOR) | ||
MAKE_CASE(NVPTXISD::BUILD_VECTOR) | ||
MAKE_CASE(NVPTXISD::TRUNCATE_SSAT_U_I8) | ||
MAKE_CASE(NVPTXISD::TRUNCATE_SSAT_S_I8) | ||
MAKE_CASE(NVPTXISD::RETURN) | ||
MAKE_CASE(NVPTXISD::CallSeqBegin) | ||
MAKE_CASE(NVPTXISD::CallSeqEnd) | ||
|
@@ -5667,6 +5677,49 @@ static SDValue combineADDRSPACECAST(SDNode *N, | |
return SDValue(); | ||
} | ||
|
||
static SDValue combineMINMAX(SDNode *N, TargetLowering::DAGCombinerInfo &DCI) { | ||
|
||
EVT VT = N->getValueType(0); | ||
if (!(VT == MVT::i32 || VT == MVT::i64 || VT == MVT::i16)) | ||
return SDValue(); | ||
|
||
SDValue Val; | ||
APInt Ceil, Floor; | ||
if (!(sd_match(N, m_SMin(m_SMax(m_Value(Val), m_ConstInt(Floor)), | ||
m_ConstInt(Ceil))) || | ||
sd_match(N, m_SMax(m_SMin(m_Value(Val), m_ConstInt(Ceil)), | ||
m_ConstInt(Floor))))) | ||
return SDValue(); | ||
|
||
const unsigned BitWidth = VT.getSizeInBits(); | ||
SDLoc DL(N); | ||
auto MatchTuncSat = [&](MVT DestVT) { | ||
const unsigned DestBitWidth = DestVT.getSizeInBits(); | ||
bool IsSigned; | ||
if (Ceil == APInt::getSignedMaxValue(DestBitWidth).sext(BitWidth) && | ||
Floor == APInt::getSignedMinValue(DestBitWidth).sext(BitWidth)) | ||
IsSigned = true; | ||
else if (Ceil == APInt::getMaxValue(DestBitWidth).zext(BitWidth) && | ||
Floor == APInt::getMinValue(BitWidth)) | ||
IsSigned = false; | ||
else | ||
return SDValue(); | ||
|
||
unsigned Opcode = IsSigned ? ISD::TRUNCATE_SSAT_S : ISD::TRUNCATE_SSAT_U; | ||
SDValue Trunc = DCI.DAG.getNode(Opcode, DL, DestVT, Val); | ||
return DCI.DAG.getExtOrTrunc(IsSigned, Trunc, DL, VT); | ||
}; | ||
|
||
if (VT != MVT::i16) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I'd fold it into the MatchTruncSat as There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yep, that is cleaner. I've moved the check. |
||
if (auto Res = MatchTuncSat(MVT::i16)) | ||
return Res; | ||
|
||
if (auto Res = MatchTuncSat(MVT::i8)) | ||
return Res; | ||
|
||
return SDValue(); | ||
} | ||
|
||
SDValue NVPTXTargetLowering::PerformDAGCombine(SDNode *N, | ||
DAGCombinerInfo &DCI) const { | ||
CodeGenOptLevel OptLevel = getTargetMachine().getOptLevel(); | ||
|
@@ -5685,6 +5738,9 @@ SDValue NVPTXTargetLowering::PerformDAGCombine(SDNode *N, | |
case ISD::UREM: | ||
case ISD::SREM: | ||
return PerformREMCombine(N, DCI, OptLevel); | ||
case ISD::SMIN: | ||
case ISD::SMAX: | ||
return combineMINMAX(N, DCI); | ||
case ISD::SETCC: | ||
return PerformSETCCCombine(N, DCI, STI.getSmVersion()); | ||
case NVPTXISD::StoreRetval: | ||
|
@@ -6045,6 +6101,20 @@ static void ReplaceCopyFromReg_128(SDNode *N, SelectionDAG &DAG, | |
Results.push_back(NewValue.getValue(3)); | ||
} | ||
|
||
static void replaceTruncateSSat(SDNode *N, SelectionDAG &DAG, | ||
SmallVectorImpl<SDValue> &Results) { | ||
SDLoc DL(N); | ||
|
||
const bool IsSigned = N->getOpcode() == ISD::TRUNCATE_SSAT_S; | ||
const unsigned Opcode = | ||
IsSigned ? NVPTXISD::TRUNCATE_SSAT_S_I8 : NVPTXISD::TRUNCATE_SSAT_U_I8; | ||
SDValue NewTrunc = DAG.getNode(Opcode, DL, MVT::i16, N->getOperand(0)); | ||
SDValue Assert = DAG.getNode(IsSigned ? ISD::AssertSext : ISD::AssertZext, DL, | ||
MVT::i16, NewTrunc, DAG.getValueType(MVT::i8)); | ||
|
||
Results.push_back(DAG.getNode(ISD::TRUNCATE, DL, MVT::i8, Assert)); | ||
} | ||
|
||
void NVPTXTargetLowering::ReplaceNodeResults( | ||
SDNode *N, SmallVectorImpl<SDValue> &Results, SelectionDAG &DAG) const { | ||
switch (N->getOpcode()) { | ||
|
@@ -6062,6 +6132,10 @@ void NVPTXTargetLowering::ReplaceNodeResults( | |
case ISD::CopyFromReg: | ||
ReplaceCopyFromReg_128(N, DAG, Results); | ||
return; | ||
case ISD::TRUNCATE_SSAT_U: | ||
case ISD::TRUNCATE_SSAT_S: | ||
replaceTruncateSSat(N, DAG, Results); | ||
return; | ||
} | ||
} | ||
|
||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,177 @@ | ||
; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5 | ||
; RUN: llc < %s -mtriple=nvptx64 -mcpu=sm_20 -verify-machineinstrs | FileCheck %s | ||
; RUN: %if ptxas %{ llc < %s -mtriple=nvptx64 -mcpu=sm_20 -verify-machineinstrs | %ptxas-verify %} | ||
|
||
target triple = "nvptx-unknown-cuda" | ||
|
||
|
||
define i64 @trunc_ssat_i64_u16(i64 %a) { | ||
; CHECK-LABEL: trunc_ssat_i64_u16( | ||
; CHECK: { | ||
; CHECK-NEXT: .reg .b16 %rs<2>; | ||
; CHECK-NEXT: .reg .b64 %rd<3>; | ||
; CHECK-EMPTY: | ||
; CHECK-NEXT: // %bb.0: | ||
; CHECK-NEXT: ld.param.b64 %rd1, [trunc_ssat_i64_u16_param_0]; | ||
; CHECK-NEXT: cvt.sat.u16.s64 %rs1, %rd1; | ||
; CHECK-NEXT: cvt.u64.u16 %rd2, %rs1; | ||
; CHECK-NEXT: st.param.b64 [func_retval0], %rd2; | ||
; CHECK-NEXT: ret; | ||
%v1 = call i64 @llvm.smax.i64(i64 %a, i64 0) | ||
%v2 = call i64 @llvm.smin.i64(i64 %v1, i64 65535) | ||
ret i64 %v2 | ||
} | ||
|
||
define i32 @trunc_ssat_i32_u16(i32 %a) { | ||
; CHECK-LABEL: trunc_ssat_i32_u16( | ||
; CHECK: { | ||
; CHECK-NEXT: .reg .b16 %rs<2>; | ||
; CHECK-NEXT: .reg .b32 %r<3>; | ||
; CHECK-EMPTY: | ||
; CHECK-NEXT: // %bb.0: | ||
; CHECK-NEXT: ld.param.b32 %r1, [trunc_ssat_i32_u16_param_0]; | ||
; CHECK-NEXT: cvt.sat.u16.s32 %rs1, %r1; | ||
; CHECK-NEXT: cvt.u32.u16 %r2, %rs1; | ||
; CHECK-NEXT: st.param.b32 [func_retval0], %r2; | ||
; CHECK-NEXT: ret; | ||
%v1 = call i32 @llvm.smax.i32(i32 %a, i32 0) | ||
%v2 = call i32 @llvm.smin.i32(i32 %v1, i32 65535) | ||
ret i32 %v2 | ||
} | ||
|
||
define i64 @trunc_ssat_i64_s16(i64 %a) { | ||
; CHECK-LABEL: trunc_ssat_i64_s16( | ||
; CHECK: { | ||
; CHECK-NEXT: .reg .b16 %rs<2>; | ||
; CHECK-NEXT: .reg .b64 %rd<3>; | ||
; CHECK-EMPTY: | ||
; CHECK-NEXT: // %bb.0: | ||
; CHECK-NEXT: ld.param.b64 %rd1, [trunc_ssat_i64_s16_param_0]; | ||
; CHECK-NEXT: cvt.sat.s16.s64 %rs1, %rd1; | ||
; CHECK-NEXT: cvt.s64.s16 %rd2, %rs1; | ||
; CHECK-NEXT: st.param.b64 [func_retval0], %rd2; | ||
; CHECK-NEXT: ret; | ||
%v1 = call i64 @llvm.smax.i64(i64 %a, i64 -32768) | ||
%v2 = call i64 @llvm.smin.i64(i64 %v1, i64 32767) | ||
ret i64 %v2 | ||
} | ||
|
||
define i32 @trunc_ssat_i32_s16(i32 %a) { | ||
; CHECK-LABEL: trunc_ssat_i32_s16( | ||
; CHECK: { | ||
; CHECK-NEXT: .reg .b16 %rs<2>; | ||
; CHECK-NEXT: .reg .b32 %r<3>; | ||
; CHECK-EMPTY: | ||
; CHECK-NEXT: // %bb.0: | ||
; CHECK-NEXT: ld.param.b32 %r1, [trunc_ssat_i32_s16_param_0]; | ||
; CHECK-NEXT: cvt.sat.s16.s32 %rs1, %r1; | ||
; CHECK-NEXT: cvt.s32.s16 %r2, %rs1; | ||
; CHECK-NEXT: st.param.b32 [func_retval0], %r2; | ||
; CHECK-NEXT: ret; | ||
%v1 = call i32 @llvm.smax.i32(i32 %a, i32 -32768) | ||
%v2 = call i32 @llvm.smin.i32(i32 %v1, i32 32767) | ||
ret i32 %v2 | ||
} | ||
|
||
define i64 @trunc_ssat_i64_u8(i64 %a) { | ||
; CHECK-LABEL: trunc_ssat_i64_u8( | ||
; CHECK: { | ||
; CHECK-NEXT: .reg .b16 %rs<2>; | ||
; CHECK-NEXT: .reg .b64 %rd<3>; | ||
; CHECK-EMPTY: | ||
; CHECK-NEXT: // %bb.0: | ||
; CHECK-NEXT: ld.param.b64 %rd1, [trunc_ssat_i64_u8_param_0]; | ||
; CHECK-NEXT: cvt.sat.u8.u64 %rs1, %rd1; | ||
; CHECK-NEXT: cvt.u64.u16 %rd2, %rs1; | ||
; CHECK-NEXT: st.param.b64 [func_retval0], %rd2; | ||
; CHECK-NEXT: ret; | ||
%v1 = call i64 @llvm.smax.i64(i64 %a, i64 0) | ||
%v2 = call i64 @llvm.smin.i64(i64 %v1, i64 255) | ||
ret i64 %v2 | ||
} | ||
|
||
define i32 @trunc_ssat_i32_u8(i32 %a) { | ||
; CHECK-LABEL: trunc_ssat_i32_u8( | ||
; CHECK: { | ||
; CHECK-NEXT: .reg .b16 %rs<2>; | ||
; CHECK-NEXT: .reg .b32 %r<3>; | ||
; CHECK-EMPTY: | ||
; CHECK-NEXT: // %bb.0: | ||
; CHECK-NEXT: ld.param.b32 %r1, [trunc_ssat_i32_u8_param_0]; | ||
; CHECK-NEXT: cvt.sat.u8.u32 %rs1, %r1; | ||
; CHECK-NEXT: cvt.u32.u16 %r2, %rs1; | ||
; CHECK-NEXT: st.param.b32 [func_retval0], %r2; | ||
; CHECK-NEXT: ret; | ||
%v1 = call i32 @llvm.smax.i32(i32 %a, i32 0) | ||
%v2 = call i32 @llvm.smin.i32(i32 %v1, i32 255) | ||
ret i32 %v2 | ||
} | ||
|
||
define i16 @trunc_ssat_i16_u8(i16 %a) { | ||
; CHECK-LABEL: trunc_ssat_i16_u8( | ||
; CHECK: { | ||
; CHECK-NEXT: .reg .b16 %rs<3>; | ||
; CHECK-NEXT: .reg .b32 %r<2>; | ||
; CHECK-EMPTY: | ||
; CHECK-NEXT: // %bb.0: | ||
; CHECK-NEXT: ld.param.b16 %rs1, [trunc_ssat_i16_u8_param_0]; | ||
; CHECK-NEXT: cvt.sat.u8.u16 %rs2, %rs1; | ||
; CHECK-NEXT: cvt.u32.u16 %r1, %rs2; | ||
; CHECK-NEXT: st.param.b32 [func_retval0], %r1; | ||
; CHECK-NEXT: ret; | ||
%v1 = call i16 @llvm.smax.i16(i16 %a, i16 0) | ||
%v2 = call i16 @llvm.smin.i16(i16 %v1, i16 255) | ||
ret i16 %v2 | ||
} | ||
|
||
define i64 @trunc_ssat_i64_s8(i64 %a) { | ||
; CHECK-LABEL: trunc_ssat_i64_s8( | ||
; CHECK: { | ||
; CHECK-NEXT: .reg .b16 %rs<2>; | ||
; CHECK-NEXT: .reg .b64 %rd<3>; | ||
; CHECK-EMPTY: | ||
; CHECK-NEXT: // %bb.0: | ||
; CHECK-NEXT: ld.param.b64 %rd1, [trunc_ssat_i64_s8_param_0]; | ||
; CHECK-NEXT: cvt.sat.s8.s64 %rs1, %rd1; | ||
; CHECK-NEXT: cvt.s64.s16 %rd2, %rs1; | ||
; CHECK-NEXT: st.param.b64 [func_retval0], %rd2; | ||
; CHECK-NEXT: ret; | ||
%v1 = call i64 @llvm.smax.i64(i64 %a, i64 -128) | ||
%v2 = call i64 @llvm.smin.i64(i64 %v1, i64 127) | ||
ret i64 %v2 | ||
} | ||
|
||
define i32 @trunc_ssat_i32_s8(i32 %a) { | ||
; CHECK-LABEL: trunc_ssat_i32_s8( | ||
; CHECK: { | ||
; CHECK-NEXT: .reg .b16 %rs<2>; | ||
; CHECK-NEXT: .reg .b32 %r<3>; | ||
; CHECK-EMPTY: | ||
; CHECK-NEXT: // %bb.0: | ||
; CHECK-NEXT: ld.param.b32 %r1, [trunc_ssat_i32_s8_param_0]; | ||
; CHECK-NEXT: cvt.sat.s8.s32 %rs1, %r1; | ||
; CHECK-NEXT: cvt.s32.s16 %r2, %rs1; | ||
; CHECK-NEXT: st.param.b32 [func_retval0], %r2; | ||
; CHECK-NEXT: ret; | ||
%v1 = call i32 @llvm.smax.i32(i32 %a, i32 -128) | ||
%v2 = call i32 @llvm.smin.i32(i32 %v1, i32 127) | ||
ret i32 %v2 | ||
} | ||
|
||
define i16 @trunc_ssat_i16_s8(i16 %a) { | ||
; CHECK-LABEL: trunc_ssat_i16_s8( | ||
; CHECK: { | ||
; CHECK-NEXT: .reg .b16 %rs<3>; | ||
; CHECK-NEXT: .reg .b32 %r<2>; | ||
; CHECK-EMPTY: | ||
; CHECK-NEXT: // %bb.0: | ||
; CHECK-NEXT: ld.param.b16 %rs1, [trunc_ssat_i16_s8_param_0]; | ||
; CHECK-NEXT: cvt.sat.s8.s16 %rs2, %rs1; | ||
; CHECK-NEXT: cvt.u32.u16 %r1, %rs2; | ||
; CHECK-NEXT: st.param.b32 [func_retval0], %r1; | ||
; CHECK-NEXT: ret; | ||
%v1 = call i16 @llvm.smax.i16(i16 %a, i16 -128) | ||
%v2 = call i16 @llvm.smin.i16(i16 %v1, i16 127) | ||
ret i16 %v2 | ||
} | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nit: Perhaps rename
MatchTruncSat
toTryToLowerAsSaturatedConversion
as we're not just matching the graph nodes, but also constructing their replacement.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sounds good. Fixed