Skip to content

Commit 69d78c2

Browse files
committed
legalize v2f32 as i64 reg and add test cases
1 parent 09a36c8 commit 69d78c2

File tree

6 files changed

+416
-5
lines changed

6 files changed

+416
-5
lines changed

llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1027,6 +1027,7 @@ pickOpcodeForVT(MVT::SimpleValueType VT, unsigned Opcode_i8,
10271027
case MVT::i32:
10281028
return Opcode_i32;
10291029
case MVT::i64:
1030+
case MVT::v2f32:
10301031
return Opcode_i64;
10311032
case MVT::f16:
10321033
case MVT::bf16:

llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -290,8 +290,8 @@ static void ComputePTXValueVTs(const TargetLowering &TLI, const DataLayout &DL,
290290
// TargetLoweringBase::getVectorTypeBreakdown() which is invoked in
291291
// ComputePTXValueVTs() cannot currently break down non-power-of-2 sized
292292
// vectors.
293-
if ((Is16bitsType(EltVT.getSimpleVT())) && NumElts % 2 == 0 &&
294-
isPowerOf2_32(NumElts)) {
293+
if ((Is16bitsType(EltVT.getSimpleVT()) || EltVT == MVT::f32) &&
294+
NumElts % 2 == 0 && isPowerOf2_32(NumElts)) {
295295
// Vectors with an even number of f16/bf16/i16/f32 elements will be passed
296296
// to us as an array of v2f16/v2bf16/v2i16/v2f32 elements. We must match
297297
// this so we stay in sync with Ins/Outs.
@@ -305,6 +305,9 @@ static void ComputePTXValueVTs(const TargetLowering &TLI, const DataLayout &DL,
305305
case MVT::i16:
306306
EltVT = MVT::v2i16;
307307
break;
308+
case MVT::f32:
309+
EltVT = MVT::v2f32;
310+
break;
308311
default:
309312
llvm_unreachable("Unexpected type");
310313
}
@@ -578,6 +581,7 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
578581
addRegisterClass(MVT::v2f16, &NVPTX::Int32RegsRegClass);
579582
addRegisterClass(MVT::bf16, &NVPTX::Int16RegsRegClass);
580583
addRegisterClass(MVT::v2bf16, &NVPTX::Int32RegsRegClass);
584+
addRegisterClass(MVT::v2f32, &NVPTX::Int64RegsRegClass);
581585

582586
// Conversion to/from FP16/FP16x2 is always legal.
583587
setOperationAction(ISD::BUILD_VECTOR, MVT::v2f16, Custom);
@@ -840,6 +844,8 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
840844
setBF16OperationAction(Op, MVT::bf16, Legal, Promote);
841845
if (getOperationAction(Op, MVT::bf16) == Promote)
842846
AddPromotedToType(Op, MVT::bf16, MVT::f32);
847+
if (STI.hasF32x2Instructions())
848+
setOperationAction(Op, MVT::v2f32, Legal);
843849
}
844850

845851
// On SM80, we select add/mul/sub as fma to avoid promotion to float
@@ -3315,6 +3321,8 @@ SDValue NVPTXTargetLowering::LowerFormalArguments(
33153321
// vectors which contain v2f16, v2bf16 or v2f32 elements. So we must
33163322
// load using i32 (i64 for v2f32) here and then bitcast back.
33173323
LoadVT = MVT::i32;
3324+
else if (EltVT == MVT::v2f32)
3325+
LoadVT = MVT::i64;
33183326

33193327
EVT VecVT = EVT::getVectorVT(F->getContext(), LoadVT, NumElts);
33203328
SDValue VecAddr =

llvm/lib/Target/NVPTX/NVPTXInstrInfo.td

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -165,6 +165,7 @@ def hasHWROT32 : Predicate<"Subtarget->hasHWROT32()">;
165165
def noHWROT32 : Predicate<"!Subtarget->hasHWROT32()">;
166166
def hasDotInstructions : Predicate<"Subtarget->hasDotInstructions()">;
167167
def hasTcgen05Instructions : Predicate<"Subtarget->hasTcgen05Instructions()">;
168+
def hasF32x2Instructions : Predicate<"Subtarget->hasF32x2Instructions()">;
168169

169170
def True : Predicate<"true">;
170171
def False : Predicate<"false">;
@@ -2631,13 +2632,13 @@ class LastCallArgInstVT<NVPTXRegClass regclass, ValueType vt> :
26312632
NVPTXInst<(outs), (ins regclass:$a), "$a",
26322633
[(LastCallArg (i32 0), vt:$a)]>;
26332634

2634-
def CallArgI64 : CallArgInst<Int64Regs>;
2635+
def CallArgI64 : CallArgInstVT<Int64Regs, i64>;
26352636
def CallArgI32 : CallArgInstVT<Int32Regs, i32>;
26362637
def CallArgI16 : CallArgInstVT<Int16Regs, i16>;
26372638
def CallArgF64 : CallArgInst<Float64Regs>;
26382639
def CallArgF32 : CallArgInst<Float32Regs>;
26392640

2640-
def LastCallArgI64 : LastCallArgInst<Int64Regs>;
2641+
def LastCallArgI64 : LastCallArgInstVT<Int64Regs, i64>;
26412642
def LastCallArgI32 : LastCallArgInstVT<Int32Regs, i32>;
26422643
def LastCallArgI16 : LastCallArgInstVT<Int16Regs, i16>;
26432644
def LastCallArgF64 : LastCallArgInst<Float64Regs>;
@@ -3154,6 +3155,9 @@ let hasSideEffects = false in {
31543155
def V2F32toF64 : NVPTXInst<(outs Float64Regs:$d),
31553156
(ins Float32Regs:$s1, Float32Regs:$s2),
31563157
"mov.b64 \t$d, {{$s1, $s2}};", []>;
3158+
// Pack two 32-bit float registers ($s1, $s2) into one 64-bit integer
// register ($d) with a single mov.b64. The pattern list is empty, so this
// instruction is only selected through explicit Pat<> definitions.
def V2F32toI64 : NVPTXInst<(outs Int64Regs:$d),
3159+
(ins Float32Regs:$s1, Float32Regs:$s2),
3160+
"mov.b64 \t$d, {{$s1, $s2}};", []>;
31573161

31583162
// unpack a larger int register to a set of smaller int registers
31593163
def I64toV4I16 : NVPTXInst<(outs Int16Regs:$d1, Int16Regs:$d2,
@@ -3218,6 +3222,8 @@ def : Pat<(v2bf16 (build_vector bf16:$a, bf16:$b)),
32183222
(V2I16toI32 $a, $b)>;
32193223
def : Pat<(v2i16 (build_vector i16:$a, i16:$b)),
32203224
(V2I16toI32 $a, $b)>;
3225+
// Select a v2f32 build_vector of two f32 scalars as a V2F32toI64 pack.
def : Pat<(v2f32 (build_vector f32:$a, f32:$b)),
3226+
(V2F32toI64 $a, $b)>;
32213227

32223228
def: Pat<(v2i16 (scalar_to_vector i16:$a)),
32233229
(CVT_u32_u16 $a, CvtNONE)>;

llvm/lib/Target/NVPTX/NVPTXRegisterInfo.td

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,9 @@ def Int16Regs : NVPTXRegClass<[i16, f16, bf16], 16, (add (sequence "RS%u", 0, 4)
6262
def Int32Regs : NVPTXRegClass<[i32, v2f16, v2bf16, v2i16, v4i8], 32,
6363
(add (sequence "R%u", 0, 4),
6464
VRFrame32, VRFrameLocal32)>;
65-
def Int64Regs : NVPTXRegClass<[i64], 64, (add (sequence "RL%u", 0, 4), VRFrame64, VRFrameLocal64)>;
65+
// 64-bit integer register class. Besides i64 it also carries v2f32 values,
// so a pair of f32 elements occupies a single 64-bit register.
def Int64Regs : NVPTXRegClass<[i64, v2f32], 64,
66+
(add (sequence "RL%u", 0, 4),
67+
VRFrame64, VRFrameLocal64)>;
6668
// 128-bit regs are not defined as general regs in NVPTX. They are used for inlineASM only.
6769
def Int128Regs : NVPTXRegClass<[i128], 128, (add (sequence "RQ%u", 0, 4))>;
6870
def Float32Regs : NVPTXRegClass<[f32], 32, (add (sequence "F%u", 0, 4))>;

llvm/lib/Target/NVPTX/NVPTXSubtarget.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -112,6 +112,10 @@ class NVPTXSubtarget : public NVPTXGenSubtargetInfo {
112112
return HasTcgen05 && PTXVersion >= 86;
113113
}
114114

115+
// Returns true when f32x2 instructions may be used: both SmVersion >= 100
// and PTXVersion >= 86 must hold.
bool hasF32x2Instructions() const {
116+
return SmVersion >= 100 && PTXVersion >= 86;
117+
}
118+
115119
// Prior to CUDA 12.3 ptxas did not recognize that the trap instruction
116120
// terminates a basic block. Instead, it would assume that control flow
117121
// continued to the next instruction. The next instruction could be in the

0 commit comments

Comments
 (0)