Skip to content

Commit 3611349

Browse files
committed
legalize v2f32 as i64 reg and add test cases
1 parent 213d0d2 commit 3611349

File tree

6 files changed

+414
-3
lines changed

6 files changed

+414
-3
lines changed

llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1029,6 +1029,7 @@ static std::optional<unsigned> pickOpcodeForVT(
10291029
case MVT::i32:
10301030
return Opcode_i32;
10311031
case MVT::i64:
1032+
case MVT::v2f32:
10321033
return Opcode_i64;
10331034
case MVT::f16:
10341035
case MVT::bf16:

llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -331,8 +331,8 @@ static void ComputePTXValueVTs(const TargetLowering &TLI, const DataLayout &DL,
331331
// TargetLoweringBase::getVectorTypeBreakdown() which is invoked in
332332
// ComputePTXValueVTs() cannot currently break down non-power-of-2 sized
333333
// vectors.
334-
if ((Is16bitsType(EltVT.getSimpleVT())) && NumElts % 2 == 0 &&
335-
isPowerOf2_32(NumElts)) {
334+
if ((Is16bitsType(EltVT.getSimpleVT()) || EltVT == MVT::f32) &&
335+
NumElts % 2 == 0 && isPowerOf2_32(NumElts)) {
336336
// Vectors with an even number of f16 elements will be passed to
337337
// us as an array of v2f16/v2bf16 elements. We must match this so we
338338
// stay in sync with Ins/Outs.
@@ -346,6 +346,9 @@ static void ComputePTXValueVTs(const TargetLowering &TLI, const DataLayout &DL,
346346
case MVT::i16:
347347
EltVT = MVT::v2i16;
348348
break;
349+
case MVT::f32:
350+
EltVT = MVT::v2f32;
351+
break;
349352
default:
350353
llvm_unreachable("Unexpected type");
351354
}
@@ -612,6 +615,7 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
612615
addRegisterClass(MVT::v2f16, &NVPTX::Int32RegsRegClass);
613616
addRegisterClass(MVT::bf16, &NVPTX::Int16RegsRegClass);
614617
addRegisterClass(MVT::v2bf16, &NVPTX::Int32RegsRegClass);
618+
addRegisterClass(MVT::v2f32, &NVPTX::Int64RegsRegClass);
615619

616620
// Conversion to/from FP16/FP16x2 is always legal.
617621
setOperationAction(ISD::BUILD_VECTOR, MVT::v2f16, Custom);
@@ -877,6 +881,8 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
877881
setBF16OperationAction(Op, MVT::bf16, Legal, Promote);
878882
if (getOperationAction(Op, MVT::bf16) == Promote)
879883
AddPromotedToType(Op, MVT::bf16, MVT::f32);
884+
if (STI.hasF32x2Instructions())
885+
setOperationAction(Op, MVT::v2f32, Legal);
880886
}
881887

882888
// On SM80, we select add/mul/sub as fma to avoid promotion to float
@@ -3568,6 +3574,8 @@ SDValue NVPTXTargetLowering::LowerFormalArguments(
35683574
// vectors which contain v2f16 or v2bf16 elements. So we must load
35693575
// using i32 here and then bitcast back.
35703576
LoadVT = MVT::i32;
3577+
else if (EltVT == MVT::v2f32)
3578+
LoadVT = MVT::i64;
35713579

35723580
EVT VecVT = EVT::getVectorVT(F->getContext(), LoadVT, NumElts);
35733581
SDValue VecAddr =

llvm/lib/Target/NVPTX/NVPTXInstrInfo.td

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -158,6 +158,7 @@ def hasHWROT32 : Predicate<"Subtarget->hasHWROT32()">;
158158
def noHWROT32 : Predicate<"!Subtarget->hasHWROT32()">;
159159
def hasDotInstructions : Predicate<"Subtarget->hasDotInstructions()">;
160160
def hasTcgen05Instructions : Predicate<"Subtarget->hasTcgen05Instructions()">;
161+
def hasF32x2Instructions : Predicate<"Subtarget->hasF32x2Instructions()">;
161162

162163
def True : Predicate<"true">;
163164
def False : Predicate<"false">;
@@ -2858,6 +2859,9 @@ let hasSideEffects = false in {
28582859
def V2F32toF64 : NVPTXInst<(outs Float64Regs:$d),
28592860
(ins Float32Regs:$s1, Float32Regs:$s2),
28602861
"mov.b64 \t$d, {{$s1, $s2}};", []>;
2862+
def V2F32toI64 : NVPTXInst<(outs Int64Regs:$d),
2863+
(ins Float32Regs:$s1, Float32Regs:$s2),
2864+
"mov.b64 \t$d, {{$s1, $s2}};", []>;
28612865

28622866
// unpack a larger int register to a set of smaller int registers
28632867
def I64toV4I16 : NVPTXInst<(outs Int16Regs:$d1, Int16Regs:$d2,
@@ -2941,6 +2945,8 @@ def : Pat<(v2bf16 (build_vector bf16:$a, bf16:$b)),
29412945
(V2I16toI32 $a, $b)>;
29422946
def : Pat<(v2i16 (build_vector i16:$a, i16:$b)),
29432947
(V2I16toI32 $a, $b)>;
2948+
def : Pat<(v2f32 (build_vector f32:$a, f32:$b)),
2949+
(V2F32toI64 $a, $b)>;
29442950

29452951
def: Pat<(v2i16 (scalar_to_vector i16:$a)),
29462952
(CVT_u32_u16 $a, CvtNONE)>;

llvm/lib/Target/NVPTX/NVPTXRegisterInfo.td

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,9 @@ def Int16Regs : NVPTXRegClass<[i16, f16, bf16], 16, (add (sequence "RS%u", 0, 4)
6060
def Int32Regs : NVPTXRegClass<[i32, v2f16, v2bf16, v2i16, v4i8, f32], 32,
6161
(add (sequence "R%u", 0, 4),
6262
VRFrame32, VRFrameLocal32)>;
63-
def Int64Regs : NVPTXRegClass<[i64, f64], 64, (add (sequence "RL%u", 0, 4), VRFrame64, VRFrameLocal64)>;
63+
def Int64Regs : NVPTXRegClass<[i64, f64, v2f32], 64,
64+
(add (sequence "RL%u", 0, 4),
65+
VRFrame64, VRFrameLocal64)>;
6466
// 128-bit regs are not defined as general regs in NVPTX. They are used for inlineASM only.
6567
def Int128Regs : NVPTXRegClass<[i128], 128, (add (sequence "RQ%u", 0, 4))>;
6668

llvm/lib/Target/NVPTX/NVPTXSubtarget.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -117,6 +117,10 @@ class NVPTXSubtarget : public NVPTXGenSubtargetInfo {
117117
return HasTcgen05 && PTXVersion >= 86;
118118
}
119119

120+
bool hasF32x2Instructions() const {
121+
return SmVersion >= 100 && PTXVersion >= 86;
122+
}
123+
120124
// Prior to CUDA 12.3 ptxas did not recognize that the trap instruction
121125
// terminates a basic block. Instead, it would assume that control flow
122126
// continued to the next instruction. The next instruction could be in the

0 commit comments

Comments
 (0)