Skip to content

Commit e46f256

Browse files
committed
legalize v2f32 as i64 reg and add test cases
1 parent 227328f commit e46f256

File tree

6 files changed

+414
-3
lines changed

6 files changed

+414
-3
lines changed

llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1026,6 +1026,7 @@ pickOpcodeForVT(MVT::SimpleValueType VT, unsigned Opcode_i8,
10261026
case MVT::i32:
10271027
return Opcode_i32;
10281028
case MVT::i64:
1029+
case MVT::v2f32:
10291030
return Opcode_i64;
10301031
case MVT::f16:
10311032
case MVT::bf16:

llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -295,8 +295,8 @@ static void ComputePTXValueVTs(const TargetLowering &TLI, const DataLayout &DL,
295295
// TargetLoweringBase::getVectorTypeBreakdown() which is invoked in
296296
// ComputePTXValueVTs() cannot currently break down non-power-of-2 sized
297297
// vectors.
298-
if ((Is16bitsType(EltVT.getSimpleVT())) && NumElts % 2 == 0 &&
299-
isPowerOf2_32(NumElts)) {
298+
if ((Is16bitsType(EltVT.getSimpleVT()) || EltVT == MVT::f32) &&
299+
NumElts % 2 == 0 && isPowerOf2_32(NumElts)) {
300300
// Vectors with an even number of 16-bit or f32 elements will be passed
301301
// to us as an array of v2f16/v2bf16/v2i16/v2f32 elements. We must match
302302
// this so we stay in sync with Ins/Outs.
@@ -310,6 +310,9 @@ static void ComputePTXValueVTs(const TargetLowering &TLI, const DataLayout &DL,
310310
case MVT::i16:
311311
EltVT = MVT::v2i16;
312312
break;
313+
case MVT::f32:
314+
EltVT = MVT::v2f32;
315+
break;
313316
default:
314317
llvm_unreachable("Unexpected type");
315318
}
@@ -576,6 +579,7 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
576579
addRegisterClass(MVT::v2f16, &NVPTX::Int32RegsRegClass);
577580
addRegisterClass(MVT::bf16, &NVPTX::Int16RegsRegClass);
578581
addRegisterClass(MVT::v2bf16, &NVPTX::Int32RegsRegClass);
582+
addRegisterClass(MVT::v2f32, &NVPTX::Int64RegsRegClass);
579583

580584
// Conversion to/from FP16/FP16x2 is always legal.
581585
setOperationAction(ISD::BUILD_VECTOR, MVT::v2f16, Custom);
@@ -841,6 +845,8 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
841845
setBF16OperationAction(Op, MVT::bf16, Legal, Promote);
842846
if (getOperationAction(Op, MVT::bf16) == Promote)
843847
AddPromotedToType(Op, MVT::bf16, MVT::f32);
848+
if (STI.hasF32x2Instructions())
849+
setOperationAction(Op, MVT::v2f32, Legal);
844850
}
845851

846852
// On SM80, we select add/mul/sub as fma to avoid promotion to float
@@ -3465,6 +3471,8 @@ SDValue NVPTXTargetLowering::LowerFormalArguments(
34653471
// vectors which contain v2f16 or v2bf16 elements. So we must load
34663472
// using i32 here and then bitcast back.
34673473
LoadVT = MVT::i32;
3474+
else if (EltVT == MVT::v2f32)
3475+
LoadVT = MVT::i64;
34683476

34693477
EVT VecVT = EVT::getVectorVT(F->getContext(), LoadVT, NumElts);
34703478
SDValue VecAddr =

llvm/lib/Target/NVPTX/NVPTXInstrInfo.td

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -161,6 +161,7 @@ def hasHWROT32 : Predicate<"Subtarget->hasHWROT32()">;
161161
def noHWROT32 : Predicate<"!Subtarget->hasHWROT32()">;
162162
def hasDotInstructions : Predicate<"Subtarget->hasDotInstructions()">;
163163
def hasTcgen05Instructions : Predicate<"Subtarget->hasTcgen05Instructions()">;
164+
// Selection predicate: true when Subtarget->hasF32x2Instructions() holds.
def hasF32x2Instructions : Predicate<"Subtarget->hasF32x2Instructions()">;
164165

165166
def True : Predicate<"true">;
166167
def False : Predicate<"false">;
@@ -2786,6 +2787,9 @@ let hasSideEffects = false in {
27862787
def V2F32toF64 : NVPTXInst<(outs Float64Regs:$d),
27872788
(ins Float32Regs:$s1, Float32Regs:$s2),
27882789
"mov.b64 \t$d, {{$s1, $s2}};", []>;
2790+
// Packs two Float32Regs operands into a single Int64Regs result using
// mov.b64 with a {$s1, $s2} vector operand. The instruction itself carries
// an empty pattern list.
def V2F32toI64 : NVPTXInst<(outs Int64Regs:$d),
2791+
(ins Float32Regs:$s1, Float32Regs:$s2),
2792+
"mov.b64 \t$d, {{$s1, $s2}};", []>;
27892793

27902794
// unpack a larger int register to a set of smaller int registers
27912795
def I64toV4I16 : NVPTXInst<(outs Int16Regs:$d1, Int16Regs:$d2,
@@ -2869,6 +2873,8 @@ def : Pat<(v2bf16 (build_vector bf16:$a, bf16:$b)),
28692873
(V2I16toI32 $a, $b)>;
28702874
def : Pat<(v2i16 (build_vector i16:$a, i16:$b)),
28712875
(V2I16toI32 $a, $b)>;
2876+
// Select a v2f32 build_vector of two f32 values as V2F32toI64.
def : Pat<(v2f32 (build_vector f32:$a, f32:$b)),
2877+
(V2F32toI64 $a, $b)>;
28722878

28732879
def: Pat<(v2i16 (scalar_to_vector i16:$a)),
28742880
(CVT_u32_u16 $a, CvtNONE)>;

llvm/lib/Target/NVPTX/NVPTXRegisterInfo.td

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,9 @@ def Int16Regs : NVPTXRegClass<[i16, f16, bf16], 16, (add (sequence "RS%u", 0, 4)
6262
def Int32Regs : NVPTXRegClass<[i32, v2f16, v2bf16, v2i16, v4i8], 32,
6363
(add (sequence "R%u", 0, 4),
6464
VRFrame32, VRFrameLocal32)>;
65-
def Int64Regs : NVPTXRegClass<[i64], 64, (add (sequence "RL%u", 0, 4), VRFrame64, VRFrameLocal64)>;
65+
// Register class whose value types are i64 and v2f32. Its members are
// RL0..RL4 plus VRFrame64 and VRFrameLocal64.
def Int64Regs : NVPTXRegClass<[i64, v2f32], 64,
66+
(add (sequence "RL%u", 0, 4),
67+
VRFrame64, VRFrameLocal64)>;
6668
// 128-bit regs are not defined as general regs in NVPTX. They are used for inlineASM only.
6769
def Int128Regs : NVPTXRegClass<[i128], 128, (add (sequence "RQ%u", 0, 4))>;
6870
def Float32Regs : NVPTXRegClass<[f32], 32, (add (sequence "F%u", 0, 4))>;

llvm/lib/Target/NVPTX/NVPTXSubtarget.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -112,6 +112,10 @@ class NVPTXSubtarget : public NVPTXGenSubtargetInfo {
112112
return HasTcgen05 && PTXVersion >= 86;
113113
}
114114

115+
// Returns true when SmVersion >= 100 and PTXVersion >= 86.
// NOTE(review): presumably these thresholds gate the packed f32x2
// instructions (sm_100 / PTX 8.6) -- confirm against the PTX ISA.
bool hasF32x2Instructions() const {
116+
return SmVersion >= 100 && PTXVersion >= 86;
117+
}
118+
115119
// Prior to CUDA 12.3 ptxas did not recognize that the trap instruction
116120
// terminates a basic block. Instead, it would assume that control flow
117121
// continued to the next instruction. The next instruction could be in the

0 commit comments

Comments
 (0)