Skip to content

Commit c278c96

Browse files
committed
address comments
1 parent ae5a4fb commit c278c96

File tree

3 files changed

+63
-50
lines changed

3 files changed

+63
-50
lines changed

llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp

Lines changed: 48 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -1364,26 +1364,6 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
13641364
return DAG.getConstant(I, dl, MVT::i32);
13651365
};
13661366

1367-
// Variadic arguments.
1368-
//
1369-
// Normally, for each argument, we declare a param scalar or a param
1370-
// byte array in the .param space, and store the argument value to that
1371-
// param scalar or array starting at offset 0.
1372-
//
1373-
// In the case of the first variadic argument, we declare a vararg byte array
1374-
// with size 0. The exact size of this array isn't known at this point, so
1375-
// it'll be patched later. All the variadic arguments will be stored to this
1376-
// array at a certain offset (which gets tracked by 'VAOffset'). The offset is
1377-
// initially set to 0, so it can be used for non-variadic arguments (which use
1378-
// 0 offset) to simplify the code.
1379-
//
1380-
// After all vararg is processed, 'VAOffset' holds the size of the
1381-
// vararg byte array.
1382-
1383-
SDValue VADeclareParam = SDValue(); // vararg byte array
1384-
const unsigned FirstVAArg = CLI.NumFixedArgs; // position of first variadic
1385-
unsigned VAOffset = 0; // current offset in the param array
1386-
13871367
const unsigned UniqueCallSite = GlobalUniqueCallSite++;
13881368
const SDValue CallChain = CLI.Chain;
13891369
const SDValue StartChain =
@@ -1392,7 +1372,7 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
13921372

13931373
SmallVector<SDValue, 16> CallPrereqs{StartChain};
13941374

1395-
const auto DeclareScalarParam = [&](SDValue Symbol, unsigned Size) {
1375+
const auto MakeDeclareScalarParam = [&](SDValue Symbol, unsigned Size) {
13961376
// PTX ABI requires integral types to be at least 32 bits in size. FP16 is
13971377
// loaded/stored using i16, so it's handled here as well.
13981378
const unsigned SizeBits = promoteScalarArgumentSize(Size * 8);
@@ -1404,8 +1384,8 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
14041384
return Declare;
14051385
};
14061386

1407-
const auto DeclareArrayParam = [&](SDValue Symbol, Align Align,
1408-
unsigned Size) {
1387+
const auto MakeDeclareArrayParam = [&](SDValue Symbol, Align Align,
1388+
unsigned Size) {
14091389
SDValue Declare = DAG.getNode(
14101390
NVPTXISD::DeclareArrayParam, dl, {MVT::Other, MVT::Glue},
14111391
{StartChain, Symbol, GetI32(Align.value()), GetI32(Size), DeclareGlue});
@@ -1414,6 +1394,33 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
14141394
return Declare;
14151395
};
14161396

1397+
// Variadic arguments.
1398+
//
1399+
// Normally, for each argument, we declare a param scalar or a param
1400+
// byte array in the .param space, and store the argument value to that
1401+
// param scalar or array starting at offset 0.
1402+
//
1403+
// In the case of the first variadic argument, we declare a vararg byte array
1404+
// with size 0. The exact size of this array isn't known at this point, so
1405+
// it'll be patched later. All the variadic arguments will be stored to this
1406+
// array at a certain offset (which gets tracked by 'VAOffset'). The offset is
1407+
// initially set to 0, so it can be used for non-variadic arguments (which use
1408+
// 0 offset) to simplify the code.
1409+
//
1410+
// After all vararg is processed, 'VAOffset' holds the size of the
1411+
// vararg byte array.
1412+
assert((CLI.IsVarArg || CLI.Args.size() == CLI.NumFixedArgs) &&
1413+
"Non-VarArg function with extra arguments");
1414+
1415+
const unsigned FirstVAArg = CLI.NumFixedArgs; // position of first variadic
1416+
unsigned VAOffset = 0; // current offset in the param array
1417+
1418+
const SDValue VADeclareParam =
1419+
CLI.Args.size() > FirstVAArg
1420+
? MakeDeclareArrayParam(getCallParamSymbol(DAG, FirstVAArg, MVT::i32),
1421+
Align(STI.getMaxRequiredAlignment()), 0)
1422+
: SDValue();
1423+
14171424
// Args.size() and Outs.size() need not match.
14181425
// Outs.size() will be larger
14191426
// * if there is an aggregate argument with multiple fields (each field
@@ -1474,21 +1481,17 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
14741481
"type size mismatch");
14751482

14761483
const SDValue ArgDeclare = [&]() {
1477-
if (IsVAArg) {
1478-
if (ArgI == FirstVAArg)
1479-
VADeclareParam = DeclareArrayParam(
1480-
ParamSymbol, Align(STI.getMaxRequiredAlignment()), 0);
1484+
if (IsVAArg)
14811485
return VADeclareParam;
1482-
}
14831486

14841487
if (IsByVal || shouldPassAsArray(Arg.Ty))
1485-
return DeclareArrayParam(ParamSymbol, ArgAlign, TypeSize);
1488+
return MakeDeclareArrayParam(ParamSymbol, ArgAlign, TypeSize);
14861489

14871490
assert(ArgOuts.size() == 1 && "We must pass only one value as non-array");
14881491
assert((ArgOuts[0].VT.isInteger() || ArgOuts[0].VT.isFloatingPoint()) &&
14891492
"Only int and float types are supported as non-array arguments");
14901493

1491-
return DeclareScalarParam(ParamSymbol, TypeSize);
1494+
return MakeDeclareScalarParam(ParamSymbol, TypeSize);
14921495
}();
14931496

14941497
// PTX Interoperability Guide 3.3(A): [Integer] Values shorter
@@ -1548,7 +1551,7 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
15481551
if (NumElts == 1) {
15491552
Val = GetStoredValue(J, EltVT, CurrentAlign);
15501553
} else {
1551-
SmallVector<SDValue, 6> StoreVals;
1554+
SmallVector<SDValue, 8> StoreVals;
15521555
for (const unsigned K : llvm::seq(NumElts)) {
15531556
SDValue ValJ = GetStoredValue(J + K, EltVT, CurrentAlign);
15541557
if (ValJ.getValueType().isVector())
@@ -1589,9 +1592,9 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
15891592
const unsigned ResultSize = DL.getTypeAllocSize(RetTy);
15901593
if (shouldPassAsArray(RetTy)) {
15911594
const Align RetAlign = getArgumentAlignment(CB, RetTy, 0, DL);
1592-
DeclareArrayParam(RetSymbol, RetAlign, ResultSize);
1595+
MakeDeclareArrayParam(RetSymbol, RetAlign, ResultSize);
15931596
} else {
1594-
DeclareScalarParam(RetSymbol, ResultSize);
1597+
MakeDeclareScalarParam(RetSymbol, ResultSize);
15951598
}
15961599
}
15971600

@@ -1715,17 +1718,16 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
17151718

17161719
LoadChains.push_back(R.getValue(1));
17171720

1718-
if (NumElts == 1) {
1721+
if (NumElts == 1)
17191722
ProxyRegOps.push_back(R);
1720-
} else {
1723+
else
17211724
for (const unsigned J : llvm::seq(NumElts)) {
17221725
SDValue Elt = DAG.getNode(
17231726
LoadVT.isVector() ? ISD::EXTRACT_SUBVECTOR
17241727
: ISD::EXTRACT_VECTOR_ELT,
17251728
dl, LoadVT, R, DAG.getVectorIdxConstant(J * PackingAmt, dl));
17261729
ProxyRegOps.push_back(Elt);
17271730
}
1728-
}
17291731
I += NumElts;
17301732
}
17311733
}
@@ -5578,7 +5580,7 @@ static SDValue sinkProxyReg(SDValue R, SDValue Chain,
55785580
{Chain, R});
55795581
}
55805582
case ISD::BUILD_VECTOR: {
5581-
if (DCI.isAfterLegalizeDAG())
5583+
if (DCI.isBeforeLegalize())
55825584
return SDValue();
55835585

55845586
SmallVector<SDValue, 16> Ops;
@@ -5590,6 +5592,15 @@ static SDValue sinkProxyReg(SDValue R, SDValue Chain,
55905592
}
55915593
return DCI.DAG.getNode(ISD::BUILD_VECTOR, SDLoc(R), R.getValueType(), Ops);
55925594
}
5595+
case ISD::EXTRACT_VECTOR_ELT: {
5596+
if (DCI.isBeforeLegalize())
5597+
return SDValue();
5598+
5599+
if (SDValue V = sinkProxyReg(R.getOperand(0), Chain, DCI))
5600+
return DCI.DAG.getNode(ISD::EXTRACT_VECTOR_ELT, SDLoc(R), R.getValueType(),
5601+
V, R.getOperand(1));
5602+
return SDValue();
5603+
}
55935604
default:
55945605
return SDValue();
55955606
}

llvm/test/CodeGen/NVPTX/convert-call-to-indirect.ll

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -173,8 +173,8 @@ define i64 @test_param_type_mismatch_variadic(ptr %p) {
173173
; CHECK-NEXT: // %bb.0:
174174
; CHECK-NEXT: ld.param.b64 %rd1, [test_param_type_mismatch_variadic_param_0];
175175
; CHECK-NEXT: { // callseq 4, 0
176-
; CHECK-NEXT: .param .b64 param0;
177176
; CHECK-NEXT: .param .align 8 .b8 param1[8];
177+
; CHECK-NEXT: .param .b64 param0;
178178
; CHECK-NEXT: .param .b64 retval0;
179179
; CHECK-NEXT: st.param.b64 [param0], %rd1;
180180
; CHECK-NEXT: st.param.b64 [param1], 7;
@@ -195,8 +195,8 @@ define i64 @test_param_count_mismatch_variadic(ptr %p) {
195195
; CHECK-NEXT: // %bb.0:
196196
; CHECK-NEXT: ld.param.b64 %rd1, [test_param_count_mismatch_variadic_param_0];
197197
; CHECK-NEXT: { // callseq 5, 0
198-
; CHECK-NEXT: .param .b64 param0;
199198
; CHECK-NEXT: .param .align 8 .b8 param1[8];
199+
; CHECK-NEXT: .param .b64 param0;
200200
; CHECK-NEXT: .param .b64 retval0;
201201
; CHECK-NEXT: st.param.b64 [param0], %rd1;
202202
; CHECK-NEXT: st.param.b64 [param1], 7;

llvm/test/CodeGen/NVPTX/param-load-store.ll

Lines changed: 13 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -243,7 +243,7 @@ define <4 x i8> @test_v4i8(<4 x i8> %a) {
243243
; CHECK: call.uni (retval0), test_v5i8,
244244
; CHECK-DAG: ld.param.b32 [[RE0:%r[0-9]+]], [retval0];
245245
; CHECK-DAG: ld.param.b8 [[RE4:%rs[0-9]+]], [retval0+4];
246-
; CHECK-DAG: st.param.b32 [func_retval0], {{%r[0-9]+}};
246+
; CHECK-DAG: st.param.b32 [func_retval0], [[RE0]];
247247
; CHECK-DAG: st.param.b8 [func_retval0+4], [[RE4]];
248248
; CHECK-NEXT: ret;
249249
define <5 x i8> @test_v5i8(<5 x i8> %a) {
@@ -311,8 +311,9 @@ define signext i16 @test_i16s(i16 signext %a) {
311311
; CHECK-DAG: st.param.b32 [param0], [[E0]];
312312
; CHECK-DAG: st.param.b16 [param0+4], [[E2]];
313313
; CHECK: call.uni (retval0), test_v3i16,
314-
; CHECK: ld.param.v2.b16 {[[RE0:%rs[0-9]+]], [[RE1:%rs[0-9]+]]}, [retval0];
314+
; CHECK: ld.param.b32 [[RE:%r[0-9]+]], [retval0];
315315
; CHECK: ld.param.b16 [[RE2:%rs[0-9]+]], [retval0+4];
316+
; CHECK-DAG: mov.b32 {[[RE0:%rs[0-9]+]], [[RE1:%rs[0-9]+]]}, [[RE]];
316317
; CHECK-DAG: st.param.v2.b16 [func_retval0], {[[RE0]], [[RE1]]};
317318
; CHECK-DAG: st.param.b16 [func_retval0+4], [[RE2]];
318319
; CHECK-NEXT: ret;
@@ -347,9 +348,9 @@ define <4 x i16> @test_v4i16(<4 x i16> %a) {
347348
; CHECK-DAG: st.param.v2.b32 [param0], {[[E0]], [[E1]]};
348349
; CHECK-DAG: st.param.b16 [param0+8], [[E4]];
349350
; CHECK: call.uni (retval0), test_v5i16,
350-
; CHECK-DAG: ld.param.v4.b16 {[[RE0:%rs[0-9]+]], [[RE1:%rs[0-9]+]], [[RE2:%rs[0-9]+]], [[RE3:%rs[0-9]+]]}, [retval0];
351+
; CHECK-DAG: ld.param.v2.b32 {[[RE0:%r[0-9]+]], [[RE1:%r[0-9]+]]}, [retval0];
351352
; CHECK-DAG: ld.param.b16 [[RE4:%rs[0-9]+]], [retval0+8];
352-
; CHECK-DAG: st.param.v4.b16 [func_retval0], {[[RE0]], [[RE1]], [[RE2]], [[RE3]]}
353+
; CHECK-DAG: st.param.v2.b32 [func_retval0], {[[RE0]], [[RE1]]}
353354
; CHECK-DAG: st.param.b16 [func_retval0+8], [[RE4]];
354355
; CHECK-NEXT: ret;
355356
define <5 x i16> @test_v5i16(<5 x i16> %a) {
@@ -432,8 +433,9 @@ define <2 x bfloat> @test_v2bf16(<2 x bfloat> %a) {
432433
; CHECK-DAG: st.param.b32 [param0], [[E0]];
433434
; CHECK-DAG: st.param.b16 [param0+4], [[E2]];
434435
; CHECK: call.uni (retval0), test_v3f16,
435-
; CHECK-DAG: ld.param.v2.b16 {[[R0:%rs[0-9]+]], [[R1:%rs[0-9]+]]}, [retval0];
436+
; CHECK-DAG: ld.param.b32 [[R:%r[0-9]+]], [retval0];
436437
; CHECK-DAG: ld.param.b16 [[R2:%rs[0-9]+]], [retval0+4];
438+
; CHECK-DAG: mov.b32 {[[R0:%rs[0-9]+]], [[R1:%rs[0-9]+]]}, [[R]];
437439
; CHECK-DAG: st.param.v2.b16 [func_retval0], {[[R0]], [[R1]]};
438440
; CHECK-DAG: st.param.b16 [func_retval0+4], [[R2]];
439441
; CHECK: ret;
@@ -468,9 +470,9 @@ define <4 x half> @test_v4f16(<4 x half> %a) {
468470
; CHECK-DAG: st.param.v2.b32 [param0], {[[E0]], [[E1]]};
469471
; CHECK-DAG: st.param.b16 [param0+8], [[E4]];
470472
; CHECK: call.uni (retval0), test_v5f16,
471-
; CHECK-DAG: ld.param.v4.b16 {[[R0:%rs[0-9]+]], [[R1:%rs[0-9]+]], [[R2:%rs[0-9]+]], [[R3:%rs[0-9]+]]}, [retval0];
473+
; CHECK-DAG: ld.param.v2.b32 {[[R0:%r[0-9]+]], [[R1:%r[0-9]+]]}, [retval0];
472474
; CHECK-DAG: ld.param.b16 [[R4:%rs[0-9]+]], [retval0+8];
473-
; CHECK-DAG: st.param.v4.b16 [func_retval0], {[[R0]], [[R1]], [[R2]], [[R3]]};
475+
; CHECK-DAG: st.param.v2.b32 [func_retval0], {[[R0]], [[R1]]};
474476
; CHECK-DAG: st.param.b16 [func_retval0+8], [[R4]];
475477
; CHECK: ret;
476478
define <5 x half> @test_v5f16(<5 x half> %a) {
@@ -506,11 +508,11 @@ define <8 x half> @test_v8f16(<8 x half> %a) {
506508
; CHECK-DAG: st.param.v2.b32 [param0+8], {[[E2]], [[E3]]};
507509
; CHECK-DAG: st.param.b16 [param0+16], [[E8]];
508510
; CHECK: call.uni (retval0), test_v9f16,
509-
; CHECK-DAG: ld.param.v4.b16 {[[R0:%rs[0-9]+]], [[R1:%rs[0-9]+]], [[R2:%rs[0-9]+]], [[R3:%rs[0-9]+]]}, [retval0];
510-
; CHECK-DAG: ld.param.v4.b16 {[[R4:%rs[0-9]+]], [[R5:%rs[0-9]+]], [[R6:%rs[0-9]+]], [[R7:%rs[0-9]+]]}, [retval0+8];
511+
; CHECK-DAG: ld.param.v2.b32 {[[R0:%r[0-9]+]], [[R1:%r[0-9]+]]}, [retval0];
512+
; CHECK-DAG: ld.param.v2.b32 {[[R2:%r[0-9]+]], [[R3:%r[0-9]+]]}, [retval0+8];
511513
; CHECK-DAG: ld.param.b16 [[R8:%rs[0-9]+]], [retval0+16];
512-
; CHECK-DAG: st.param.v4.b16 [func_retval0], {[[R0]], [[R1]], [[R2]], [[R3]]};
513-
; CHECK-DAG: st.param.v4.b16 [func_retval0+8], {[[R4]], [[R5]], [[R6]], [[R7]]};
514+
; CHECK-DAG: st.param.v2.b32 [func_retval0], {[[R0]], [[R1]]};
515+
; CHECK-DAG: st.param.v2.b32 [func_retval0+8], {[[R2]], [[R3]]};
514516
; CHECK-DAG: st.param.b16 [func_retval0+16], [[R8]];
515517
; CHECK: ret;
516518
define <9 x half> @test_v9f16(<9 x half> %a) {

0 commit comments

Comments
 (0)