Commit 592bd16

[NVPTX] update combiner rule for more types of loads
Handle more loads, including ones with multiple proxy registers:

- i64 = LOAD
- i64 = LoadParam
- v2f32,v2f32 = LoadParamV2

Also update the test cases. Because this is an optimization, it is not
triggered for those tests that are compiled without optimizations.
Parent: 2d61d1b
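As an illustration of the newly handled multi-proxy-register shape, here is a minimal IR sketch (hypothetical, not taken from the commit's test files): a call returning <4 x float> is read back from the call sequence as two v2f32 values (v2f32,v2f32 = LoadParamV2), each behind its own ProxyReg, which the combiner can now scalarize into individual f32s.

declare <4 x float> @callee(<4 x float>)  ; hypothetical callee

define <4 x float> @caller(<4 x float> %v) {
  ; the return value comes back as v2f32,v2f32 = LoadParamV2,
  ; i.e. two results guarded by two proxy registers
  %r = call <4 x float> @callee(<4 x float> %v)
  ret <4 x float> %r
}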

File tree: 8 files changed, +553 −507 lines

llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp

Lines changed: 97 additions & 53 deletions
@@ -5189,12 +5189,14 @@ PerformFADDCombineWithOperands(SDNode *N, SDValue N0, SDValue N1,
   return SDValue();
 }
 
+/// OverrideVT - allows overriding result and memory type
 static std::optional<std::pair<SDValue, SDValue>>
 convertVectorLoad(SDNode *N, SelectionDAG &DAG, const NVPTXSubtarget &STI,
-                  bool BuildVector) {
+                  bool BuildVector,
+                  std::optional<EVT> OverrideVT = std::nullopt) {
   LoadSDNode *LD = cast<LoadSDNode>(N);
-  const EVT ResVT = LD->getValueType(0);
-  const EVT MemVT = LD->getMemoryVT();
+  const EVT ResVT = OverrideVT.value_or(LD->getValueType(0));
+  const EVT MemVT = OverrideVT.value_or(LD->getMemoryVT());
 
   // If we're doing sign/zero extension as part of the load, avoid lowering to
   // a LoadV node. TODO: consider relaxing this restriction.
@@ -5251,8 +5253,8 @@ convertVectorLoad(SDNode *N, SelectionDAG &DAG, const NVPTXSubtarget &STI,
   // pass along the extension information
   OtherOps.push_back(DAG.getIntPtrConstant(LD->getExtensionType(), DL));
 
-  SDValue NewLD = DAG.getMemIntrinsicNode(
-      Opcode, DL, LdResVTs, OtherOps, LD->getMemoryVT(), LD->getMemOperand());
+  SDValue NewLD = DAG.getMemIntrinsicNode(Opcode, DL, LdResVTs, OtherOps, MemVT,
+                                          LD->getMemOperand());
 
   SmallVector<SDValue> ScalarRes;
   if (EltVT.isVector()) {
@@ -5277,6 +5279,26 @@ convertVectorLoad(SDNode *N, SelectionDAG &DAG, const NVPTXSubtarget &STI,
   SDValue LoadChain = NewLD.getValue(NumElts);
 
   if (BuildVector) {
+    SmallVector<SDValue> ScalarRes;
+    if (EltVT.isVector()) {
+      assert(EVT(EltVT.getVectorElementType()) == ResVT.getVectorElementType());
+      assert(NumElts * EltVT.getVectorNumElements() ==
+             ResVT.getVectorNumElements());
+      // Generate EXTRACT_VECTOR_ELTs to split v2[i,f,bf]16/v4i8 subvectors back
+      // into individual elements.
+      for (const unsigned I : llvm::seq(NumElts)) {
+        SDValue SubVector = NewLD.getValue(I);
+        DAG.ExtractVectorElements(SubVector, ScalarRes);
+      }
+    } else {
+      for (const unsigned I : llvm::seq(NumElts)) {
+        SDValue Res = NewLD.getValue(I);
+        if (LoadEltVT != EltVT)
+          Res = DAG.getNode(ISD::TRUNCATE, DL, EltVT, Res);
+        ScalarRes.push_back(Res);
+      }
+    }
+
     const MVT BuildVecVT =
         MVT::getVectorVT(EltVT.getScalarType(), ScalarRes.size());
     SDValue BuildVec = DAG.getBuildVector(BuildVecVT, DL, ScalarRes);
@@ -5292,23 +5314,20 @@ static SDValue PerformLoadCombine(SDNode *N,
                                   TargetLowering::DAGCombinerInfo &DCI,
                                   const NVPTXSubtarget &STI) {
   auto *MemN = cast<MemSDNode>(N);
-  EVT MemVT = MemN->getMemoryVT();
-
-  // ignore volatile loads
-  if (MemN->isVolatile())
-    return SDValue();
-
   // only operate on vectors of f32s / i64s
-  if (!MemVT.isVector())
+  if (EVT MemVT = MemN->getMemoryVT();
+      !(MemVT == MVT::i64 ||
+        (MemVT.isVector() && (MemVT.getVectorElementType() == MVT::f32 ||
+                              MemVT.getVectorElementType() == MVT::i64))))
     return SDValue();
 
-  EVT ElementVT = MemVT.getVectorElementType();
-  if (!(ElementVT == MVT::f32 ||
-        (ElementVT == MVT::i64 && N->getOpcode() != ISD::LOAD)))
-    return SDValue();
+  const unsigned OrigNumResults =
+      llvm::count_if(N->values(), [](const auto &VT) {
+        return VT == MVT::i64 || VT == MVT::f32 || VT.isVector();
+      });
 
   SmallDenseMap<SDNode *, unsigned> ExtractElts;
-  SDNode *ProxyReg = nullptr;
+  SmallVector<SDNode *> ProxyRegs(OrigNumResults, nullptr);
   SmallVector<std::pair<SDNode *, unsigned /*offset*/>> WorkList{{N, 0}};
   while (!WorkList.empty()) {
     auto [V, Offset] = WorkList.pop_back_val();
@@ -5321,8 +5340,14 @@ static SDValue PerformLoadCombine(SDNode *N,
 
       SDNode *User = U.getUser();
       if (User->getOpcode() == NVPTXISD::ProxyReg) {
+        Offset = U.getResNo() * 2;
+        SDNode *&ProxyReg = ProxyRegs[Offset / 2];
+
+        // We shouldn't have multiple proxy regs for the same value from the
+        // load, but bail out anyway since we don't handle this.
         if (ProxyReg)
-          return SDValue(); // bail out if we've seen a proxy reg?
+          return SDValue();
+
         ProxyReg = User;
       } else if (User->getOpcode() == ISD::BITCAST &&
                  User->getValueType(0) == MVT::v2f32 &&
@@ -5412,10 +5437,18 @@ static SDValue PerformLoadCombine(SDNode *N,
     if (NewGlueIdx)
       NewGlue = NewLoad.getValue(*NewGlueIdx);
   } else if (N->getOpcode() == ISD::LOAD) { // rewrite a load
-    if (auto Result =
-            convertVectorLoad(N, DCI.DAG, STI, /*BuildVector=*/false)) {
+    std::optional<EVT> CastToType;
+    EVT ResVT = N->getValueType(0);
+    if (ResVT == MVT::i64) {
+      // ld.b64 is treated as a vector by subsequent code
+      CastToType = MVT::v2f32;
+    }
+    if (auto Result = convertVectorLoad(N, DCI.DAG, STI, /*BuildVector=*/false,
+                                        CastToType)) {
       std::tie(NewLoad, NewChain) = *Result;
-      NumElts = MemVT.getVectorNumElements();
+      NumElts =
+          CastToType.value_or(cast<MemSDNode>(NewLoad.getNode())->getMemoryVT())
+              .getVectorNumElements();
       if (NewLoad->getValueType(NewLoad->getNumValues() - 1) == MVT::Glue)
         NewGlue = NewLoad.getValue(NewLoad->getNumValues() - 1);
     }
@@ -5427,54 +5460,65 @@ static SDValue PerformLoadCombine(SDNode *N,
   // (3) begin rewriting uses
   SmallVector<SDValue> NewOutputsF32;
 
-  if (ProxyReg) {
-    // scalarize proxyreg, but first rewrite all uses of chain and glue from the
-    // old load to the new load
+  if (llvm::any_of(ProxyRegs, [](const SDNode *PR) { return PR != nullptr; })) {
+    // scalarize proxy regs, but first rewrite all uses of chain and glue from
+    // the old load to the new load
     DCI.DAG.ReplaceAllUsesOfValueWith(OldChain, NewChain);
     DCI.DAG.ReplaceAllUsesOfValueWith(OldGlue, NewGlue);
 
-    // Update the new chain and glue to be old inputs to the proxyreg, if they
-    // came from an intervening instruction between this proxyreg and the
-    // original load (ex: callseq_end). Other than bitcasts and extractelts, we
-    // followed all other nodes by chain and glue accesses.
-    if (SDValue OldInChain = ProxyReg->getOperand(0); OldInChain.getNode() != N)
+    for (unsigned ProxyI = 0, ProxyE = ProxyRegs.size(); ProxyI != ProxyE;
+         ++ProxyI) {
+      SDNode *ProxyReg = ProxyRegs[ProxyI];
+
+      // no proxy reg might mean this result is unused
+      if (!ProxyReg)
+        continue;
+
+      // Update the new chain and glue to be old inputs to the proxyreg, if they
+      // came from an intervening instruction between this proxyreg and the
+      // original load (ex: callseq_end). Other than bitcasts and extractelts,
+      // we followed all other nodes by chain and glue accesses.
+      if (SDValue OldInChain = ProxyReg->getOperand(0);
+          OldInChain.getNode() != N)
       NewChain = OldInChain;
-    if (SDValue OldInGlue = ProxyReg->getOperand(2); OldInGlue.getNode() != N)
+      if (SDValue OldInGlue = ProxyReg->getOperand(2); OldInGlue.getNode() != N)
       NewGlue = OldInGlue;
 
-    // update OldChain, OldGlue to the outputs of ProxyReg, which we will
-    // replace later
-    OldChain = SDValue(ProxyReg, 1);
-    OldGlue = SDValue(ProxyReg, 2);
-
-    // generate the scalar proxy regs
-    for (unsigned I = 0, E = NumElts; I != E; ++I) {
-      SDValue ProxyRegElem =
-          DCI.DAG.getNode(NVPTXISD::ProxyReg, SDLoc(ProxyReg),
-                          DCI.DAG.getVTList(MVT::f32, MVT::Other, MVT::Glue),
-                          {NewChain, NewLoad.getValue(I), NewGlue});
-      NewChain = ProxyRegElem.getValue(1);
-      NewGlue = ProxyRegElem.getValue(2);
-      NewOutputsF32.push_back(ProxyRegElem);
+      // update OldChain, OldGlue to the outputs of ProxyReg, which we will
+      // replace later
+      OldChain = SDValue(ProxyReg, 1);
+      OldGlue = SDValue(ProxyReg, 2);
+
+      // generate the scalar proxy regs
+      for (unsigned I = 0, E = 2; I != E; ++I) {
+        SDValue ProxyRegElem = DCI.DAG.getNode(
+            NVPTXISD::ProxyReg, SDLoc(ProxyReg),
+            DCI.DAG.getVTList(MVT::f32, MVT::Other, MVT::Glue),
+            {NewChain, NewLoad.getValue(ProxyI * 2 + I), NewGlue});
+        NewChain = ProxyRegElem.getValue(1);
+        NewGlue = ProxyRegElem.getValue(2);
+        NewOutputsF32.push_back(ProxyRegElem);
+      }
+
+      // replace all uses of the glue and chain from the old proxy reg
+      DCI.DAG.ReplaceAllUsesOfValueWith(OldChain, NewChain);
+      DCI.DAG.ReplaceAllUsesOfValueWith(OldGlue, NewGlue);
     }
   } else {
     for (unsigned I = 0, E = NumElts; I != E; ++I)
      if (NewLoad->getValueType(I) == MVT::f32)
        NewOutputsF32.push_back(NewLoad.getValue(I));
+
+    // replace all glue and chain nodes
+    DCI.DAG.ReplaceAllUsesOfValueWith(OldChain, NewChain);
+    if (OldGlue)
+      DCI.DAG.ReplaceAllUsesOfValueWith(OldGlue, NewGlue);
   }
 
-  // now, for all extractelts, replace them with one of the new outputs
+  // replace all extractelts with the new outputs
   for (auto &[Extract, Index] : ExtractElts)
     DCI.CombineTo(Extract, NewOutputsF32[Index], false);
 
-  // now replace all glue and chain nodes
-  DCI.DAG.ReplaceAllUsesOfValueWith(OldChain, NewChain);
-  if (OldGlue)
-    DCI.DAG.ReplaceAllUsesOfValueWith(OldGlue, NewGlue);
-
-  // cleanup
-  if (ProxyReg)
-    DCI.recursivelyDeleteUnusedNodes(ProxyReg);
   return SDValue();
 }
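The CastToType path above recasts a scalar i64 load as v2f32 so that the rest of the combine can treat it as a two-element vector. A minimal sketch of a load with this shape (hypothetical, not from the commit's tests):

define <2 x float> @load_i64_as_v2f32(ptr %p) {
  ; i64 = LOAD; the combiner may rewrite it as a v2f32 load and split
  ; the result into two scalar f32 values
  %v = load i64, ptr %p, align 8
  %f = bitcast i64 %v to <2 x float>
  ret <2 x float> %f
}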

llvm/test/CodeGen/NVPTX/aggregate-return.ll

Lines changed: 1 addition & 3 deletions
@@ -27,9 +27,7 @@ define void @test_v3f32(<3 x float> %input, ptr %output) {
 ; CHECK-NOT: ld.param.b32 [[E3:%r[0-9]+]], [retval0+12];
   store <3 x float> %call, ptr %output, align 8
 ; CHECK-DAG: st.b32 [{{%rd[0-9]}}+8],
-; -- This is suboptimal. We should do st.v2.f32 instead
-; of combining 2xf32 info i64.
-; CHECK-DAG: st.b64 [{{%rd[0-9]}}],
+; CHECK-DAG: st.v2.b32 [{{%rd[0-9]}}], {[[E0]], [[E1]]}
 ; CHECK: ret;
   ret void
 }
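The new CHECK line reflects the improvement the removed "This is suboptimal" note asked for: two f32 elements are now stored with a single st.v2.b32 instead of being packed into an i64 and stored with st.b64. A minimal sketch of the store shape (hypothetical, simplified from @test_v3f32):

define void @store_two_lanes(<3 x float> %v, ptr %out) {
  ; the low <2 x float> slice should now lower to st.v2.b32 rather than st.b64
  %lo = shufflevector <3 x float> %v, <3 x float> poison, <2 x i32> <i32 0, i32 1>
  store <2 x float> %lo, ptr %out, align 8
  ret void
}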

llvm/test/CodeGen/NVPTX/bf16-instructions.ll

Lines changed: 48 additions & 48 deletions
@@ -712,25 +712,25 @@ define <8 x float> @test_extload_bf16x8(ptr addrspace(3) noundef %arg) #0 {
 ; SM70-NEXT:    // %bb.0:
 ; SM70-NEXT:    ld.param.b64 %rd1, [test_extload_bf16x8_param_0];
 ; SM70-NEXT:    ld.shared.v4.b32 {%r1, %r2, %r3, %r4}, [%rd1];
-; SM70-NEXT:    mov.b32 {%rs1, %rs2}, %r1;
-; SM70-NEXT:    mov.b32 {%rs3, %rs4}, %r2;
-; SM70-NEXT:    mov.b32 {%rs5, %rs6}, %r3;
-; SM70-NEXT:    mov.b32 {%rs7, %rs8}, %r4;
-; SM70-NEXT:    cvt.u32.u16 %r5, %rs8;
+; SM70-NEXT:    mov.b32 {%rs1, %rs2}, %r4;
+; SM70-NEXT:    cvt.u32.u16 %r5, %rs2;
 ; SM70-NEXT:    shl.b32 %r29, %r5, 16;
-; SM70-NEXT:    cvt.u32.u16 %r8, %rs7;
+; SM70-NEXT:    cvt.u32.u16 %r8, %rs1;
 ; SM70-NEXT:    shl.b32 %r30, %r8, 16;
-; SM70-NEXT:    cvt.u32.u16 %r11, %rs6;
+; SM70-NEXT:    mov.b32 {%rs3, %rs4}, %r3;
+; SM70-NEXT:    cvt.u32.u16 %r11, %rs4;
 ; SM70-NEXT:    shl.b32 %r31, %r11, 16;
-; SM70-NEXT:    cvt.u32.u16 %r14, %rs5;
+; SM70-NEXT:    cvt.u32.u16 %r14, %rs3;
 ; SM70-NEXT:    shl.b32 %r32, %r14, 16;
-; SM70-NEXT:    cvt.u32.u16 %r17, %rs4;
+; SM70-NEXT:    mov.b32 {%rs5, %rs6}, %r2;
+; SM70-NEXT:    cvt.u32.u16 %r17, %rs6;
 ; SM70-NEXT:    shl.b32 %r33, %r17, 16;
-; SM70-NEXT:    cvt.u32.u16 %r20, %rs3;
+; SM70-NEXT:    cvt.u32.u16 %r20, %rs5;
 ; SM70-NEXT:    shl.b32 %r34, %r20, 16;
-; SM70-NEXT:    cvt.u32.u16 %r23, %rs2;
+; SM70-NEXT:    mov.b32 {%rs7, %rs8}, %r1;
+; SM70-NEXT:    cvt.u32.u16 %r23, %rs8;
 ; SM70-NEXT:    shl.b32 %r35, %r23, 16;
-; SM70-NEXT:    cvt.u32.u16 %r26, %rs1;
+; SM70-NEXT:    cvt.u32.u16 %r26, %rs7;
 ; SM70-NEXT:    shl.b32 %r36, %r26, 16;
 ; SM70-NEXT:    st.param.v4.b32 [func_retval0], {%r36, %r35, %r34, %r33};
 ; SM70-NEXT:    st.param.v4.b32 [func_retval0+16], {%r32, %r31, %r30, %r29};
@@ -745,18 +745,18 @@ define <8 x float> @test_extload_bf16x8(ptr addrspace(3) noundef %arg) #0 {
 ; SM80-NEXT:    // %bb.0:
 ; SM80-NEXT:    ld.param.b64 %rd1, [test_extload_bf16x8_param_0];
 ; SM80-NEXT:    ld.shared.v4.b32 {%r1, %r2, %r3, %r4}, [%rd1];
-; SM80-NEXT:    mov.b32 {%rs1, %rs2}, %r1;
-; SM80-NEXT:    mov.b32 {%rs3, %rs4}, %r2;
-; SM80-NEXT:    mov.b32 {%rs5, %rs6}, %r3;
-; SM80-NEXT:    mov.b32 {%rs7, %rs8}, %r4;
-; SM80-NEXT:    cvt.f32.bf16 %r5, %rs8;
-; SM80-NEXT:    cvt.f32.bf16 %r6, %rs7;
-; SM80-NEXT:    cvt.f32.bf16 %r7, %rs6;
-; SM80-NEXT:    cvt.f32.bf16 %r8, %rs5;
-; SM80-NEXT:    cvt.f32.bf16 %r9, %rs4;
-; SM80-NEXT:    cvt.f32.bf16 %r10, %rs3;
-; SM80-NEXT:    cvt.f32.bf16 %r11, %rs2;
-; SM80-NEXT:    cvt.f32.bf16 %r12, %rs1;
+; SM80-NEXT:    mov.b32 {%rs1, %rs2}, %r4;
+; SM80-NEXT:    cvt.f32.bf16 %r5, %rs2;
+; SM80-NEXT:    cvt.f32.bf16 %r6, %rs1;
+; SM80-NEXT:    mov.b32 {%rs3, %rs4}, %r3;
+; SM80-NEXT:    cvt.f32.bf16 %r7, %rs4;
+; SM80-NEXT:    cvt.f32.bf16 %r8, %rs3;
+; SM80-NEXT:    mov.b32 {%rs5, %rs6}, %r2;
+; SM80-NEXT:    cvt.f32.bf16 %r9, %rs6;
+; SM80-NEXT:    cvt.f32.bf16 %r10, %rs5;
+; SM80-NEXT:    mov.b32 {%rs7, %rs8}, %r1;
+; SM80-NEXT:    cvt.f32.bf16 %r11, %rs8;
+; SM80-NEXT:    cvt.f32.bf16 %r12, %rs7;
 ; SM80-NEXT:    st.param.v4.b32 [func_retval0], {%r12, %r11, %r10, %r9};
 ; SM80-NEXT:    st.param.v4.b32 [func_retval0+16], {%r8, %r7, %r6, %r5};
 ; SM80-NEXT:    ret;
@@ -770,18 +770,18 @@ define <8 x float> @test_extload_bf16x8(ptr addrspace(3) noundef %arg) #0 {
 ; SM80-FTZ-NEXT:    // %bb.0:
 ; SM80-FTZ-NEXT:    ld.param.b64 %rd1, [test_extload_bf16x8_param_0];
 ; SM80-FTZ-NEXT:    ld.shared.v4.b32 {%r1, %r2, %r3, %r4}, [%rd1];
-; SM80-FTZ-NEXT:    mov.b32 {%rs1, %rs2}, %r1;
-; SM80-FTZ-NEXT:    mov.b32 {%rs3, %rs4}, %r2;
-; SM80-FTZ-NEXT:    mov.b32 {%rs5, %rs6}, %r3;
-; SM80-FTZ-NEXT:    mov.b32 {%rs7, %rs8}, %r4;
-; SM80-FTZ-NEXT:    cvt.ftz.f32.bf16 %r5, %rs8;
-; SM80-FTZ-NEXT:    cvt.ftz.f32.bf16 %r6, %rs7;
-; SM80-FTZ-NEXT:    cvt.ftz.f32.bf16 %r7, %rs6;
-; SM80-FTZ-NEXT:    cvt.ftz.f32.bf16 %r8, %rs5;
-; SM80-FTZ-NEXT:    cvt.ftz.f32.bf16 %r9, %rs4;
-; SM80-FTZ-NEXT:    cvt.ftz.f32.bf16 %r10, %rs3;
-; SM80-FTZ-NEXT:    cvt.ftz.f32.bf16 %r11, %rs2;
-; SM80-FTZ-NEXT:    cvt.ftz.f32.bf16 %r12, %rs1;
+; SM80-FTZ-NEXT:    mov.b32 {%rs1, %rs2}, %r4;
+; SM80-FTZ-NEXT:    cvt.ftz.f32.bf16 %r5, %rs2;
+; SM80-FTZ-NEXT:    cvt.ftz.f32.bf16 %r6, %rs1;
+; SM80-FTZ-NEXT:    mov.b32 {%rs3, %rs4}, %r3;
+; SM80-FTZ-NEXT:    cvt.ftz.f32.bf16 %r7, %rs4;
+; SM80-FTZ-NEXT:    cvt.ftz.f32.bf16 %r8, %rs3;
+; SM80-FTZ-NEXT:    mov.b32 {%rs5, %rs6}, %r2;
+; SM80-FTZ-NEXT:    cvt.ftz.f32.bf16 %r9, %rs6;
+; SM80-FTZ-NEXT:    cvt.ftz.f32.bf16 %r10, %rs5;
+; SM80-FTZ-NEXT:    mov.b32 {%rs7, %rs8}, %r1;
+; SM80-FTZ-NEXT:    cvt.ftz.f32.bf16 %r11, %rs8;
+; SM80-FTZ-NEXT:    cvt.ftz.f32.bf16 %r12, %rs7;
 ; SM80-FTZ-NEXT:    st.param.v4.b32 [func_retval0], {%r12, %r11, %r10, %r9};
 ; SM80-FTZ-NEXT:    st.param.v4.b32 [func_retval0+16], {%r8, %r7, %r6, %r5};
 ; SM80-FTZ-NEXT:    ret;
@@ -795,18 +795,18 @@ define <8 x float> @test_extload_bf16x8(ptr addrspace(3) noundef %arg) #0 {
 ; SM90-NEXT:    // %bb.0:
 ; SM90-NEXT:    ld.param.b64 %rd1, [test_extload_bf16x8_param_0];
 ; SM90-NEXT:    ld.shared.v4.b32 {%r1, %r2, %r3, %r4}, [%rd1];
-; SM90-NEXT:    mov.b32 {%rs1, %rs2}, %r1;
-; SM90-NEXT:    mov.b32 {%rs3, %rs4}, %r2;
-; SM90-NEXT:    mov.b32 {%rs5, %rs6}, %r3;
-; SM90-NEXT:    mov.b32 {%rs7, %rs8}, %r4;
-; SM90-NEXT:    cvt.f32.bf16 %r5, %rs8;
-; SM90-NEXT:    cvt.f32.bf16 %r6, %rs7;
-; SM90-NEXT:    cvt.f32.bf16 %r7, %rs6;
-; SM90-NEXT:    cvt.f32.bf16 %r8, %rs5;
-; SM90-NEXT:    cvt.f32.bf16 %r9, %rs4;
-; SM90-NEXT:    cvt.f32.bf16 %r10, %rs3;
-; SM90-NEXT:    cvt.f32.bf16 %r11, %rs2;
-; SM90-NEXT:    cvt.f32.bf16 %r12, %rs1;
+; SM90-NEXT:    mov.b32 {%rs1, %rs2}, %r4;
+; SM90-NEXT:    cvt.f32.bf16 %r5, %rs2;
+; SM90-NEXT:    cvt.f32.bf16 %r6, %rs1;
+; SM90-NEXT:    mov.b32 {%rs3, %rs4}, %r3;
+; SM90-NEXT:    cvt.f32.bf16 %r7, %rs4;
+; SM90-NEXT:    cvt.f32.bf16 %r8, %rs3;
+; SM90-NEXT:    mov.b32 {%rs5, %rs6}, %r2;
+; SM90-NEXT:    cvt.f32.bf16 %r9, %rs6;
+; SM90-NEXT:    cvt.f32.bf16 %r10, %rs5;
+; SM90-NEXT:    mov.b32 {%rs7, %rs8}, %r1;
+; SM90-NEXT:    cvt.f32.bf16 %r11, %rs8;
+; SM90-NEXT:    cvt.f32.bf16 %r12, %rs7;
 ; SM90-NEXT:    st.param.v4.b32 [func_retval0], {%r12, %r11, %r10, %r9};
 ; SM90-NEXT:    st.param.v4.b32 [func_retval0+16], {%r8, %r7, %r6, %r5};
 ; SM90-NEXT:    ret;
