Skip to content

Commit 6971c1b

Browse files
committed
[LoongArch] Add support for tail call optimization
This patch adds tail call support to the LoongArch backend. When appropriate, use the `b` or `jr` instruction for tail calls (the `pcalau12i+jirl` instruction pair when use medium codemodel). This patch also modifies the inappropriate operand name: simm26_bl -> simm26_symbol This has been modeled after RISCV's tail call opt. Reviewed By: SixWeining Differential Revision: https://reviews.llvm.org/D137889
1 parent 4b4250c commit 6971c1b

File tree

8 files changed

+349
-21
lines changed

8 files changed

+349
-21
lines changed

llvm/lib/Target/LoongArch/LoongArchExpandPseudoInsts.cpp

Lines changed: 27 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,8 @@ class LoongArchPreRAExpandPseudo : public MachineFunctionPass {
7777
MachineBasicBlock::iterator &NextMBBI);
7878
bool expandFunctionCALL(MachineBasicBlock &MBB,
7979
MachineBasicBlock::iterator MBBI,
80-
MachineBasicBlock::iterator &NextMBBI);
80+
MachineBasicBlock::iterator &NextMBBI,
81+
bool IsTailCall);
8182
};
8283

8384
char LoongArchPreRAExpandPseudo::ID = 0;
@@ -121,7 +122,9 @@ bool LoongArchPreRAExpandPseudo::expandMI(
121122
case LoongArch::PseudoLA_TLS_GD:
122123
return expandLoadAddressTLSGD(MBB, MBBI, NextMBBI);
123124
case LoongArch::PseudoCALL:
124-
return expandFunctionCALL(MBB, MBBI, NextMBBI);
125+
return expandFunctionCALL(MBB, MBBI, NextMBBI, /*IsTailCall=*/false);
126+
case LoongArch::PseudoTAIL:
127+
return expandFunctionCALL(MBB, MBBI, NextMBBI, /*IsTailCall=*/true);
125128
}
126129
return false;
127130
}
@@ -247,27 +250,43 @@ bool LoongArchPreRAExpandPseudo::expandLoadAddressTLSGD(
247250

248251
bool LoongArchPreRAExpandPseudo::expandFunctionCALL(
249252
MachineBasicBlock &MBB, MachineBasicBlock::iterator MBBI,
250-
MachineBasicBlock::iterator &NextMBBI) {
253+
MachineBasicBlock::iterator &NextMBBI, bool IsTailCall) {
251254
MachineFunction *MF = MBB.getParent();
252255
MachineInstr &MI = *MBBI;
253256
DebugLoc DL = MI.getDebugLoc();
254257
const MachineOperand &Func = MI.getOperand(0);
255258
MachineInstrBuilder CALL;
259+
unsigned Opcode;
256260

257261
switch (MF->getTarget().getCodeModel()) {
258262
default:
259263
report_fatal_error("Unsupported code model");
260264
break;
261-
case CodeModel::Small: // Default CodeModel.
262-
CALL = BuildMI(MBB, MBBI, DL, TII->get(LoongArch::BL)).add(Func);
265+
case CodeModel::Small: {
266+
// CALL:
267+
// bl func
268+
// TAIL:
269+
// b func
270+
Opcode = IsTailCall ? LoongArch::PseudoB_TAIL : LoongArch::BL;
271+
CALL = BuildMI(MBB, MBBI, DL, TII->get(Opcode)).add(Func);
263272
break;
273+
}
264274
case CodeModel::Medium: {
275+
// CALL:
265276
// pcalau12i $ra, %pc_hi20(func)
266277
// jirl $ra, $ra, %pc_lo12(func)
278+
// TAIL:
279+
// pcalau12i $scratch, %pc_hi20(func)
280+
// jirl $r0, $scratch, %pc_lo12(func)
281+
Opcode =
282+
IsTailCall ? LoongArch::PseudoJIRL_TAIL : LoongArch::PseudoJIRL_CALL;
283+
Register ScratchReg =
284+
IsTailCall
285+
? MF->getRegInfo().createVirtualRegister(&LoongArch::GPRRegClass)
286+
: LoongArch::R1;
267287
MachineInstrBuilder MIB =
268-
BuildMI(MBB, MBBI, DL, TII->get(LoongArch::PCALAU12I), LoongArch::R1);
269-
CALL = BuildMI(MBB, MBBI, DL, TII->get(LoongArch::PseudoJIRL_CALL))
270-
.addReg(LoongArch::R1);
288+
BuildMI(MBB, MBBI, DL, TII->get(LoongArch::PCALAU12I), ScratchReg);
289+
CALL = BuildMI(MBB, MBBI, DL, TII->get(Opcode)).addReg(ScratchReg);
271290
if (Func.isSymbol()) {
272291
const char *FnName = Func.getSymbolName();
273292
MIB.addExternalSymbol(FnName, LoongArchII::MO_PCREL_HI);

llvm/lib/Target/LoongArch/LoongArchISelLowering.cpp

Lines changed: 73 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,8 @@ using namespace llvm;
3131

3232
#define DEBUG_TYPE "loongarch-isel-lowering"
3333

34+
STATISTIC(NumTailCalls, "Number of tail calls");
35+
3436
static cl::opt<bool> ZeroDivCheck(
3537
"loongarch-check-zero-division", cl::Hidden,
3638
cl::desc("Trap on integer division by zero."),
@@ -1334,6 +1336,7 @@ const char *LoongArchTargetLowering::getTargetNodeName(unsigned Opcode) const {
13341336
// TODO: Add more target-dependent nodes later.
13351337
NODE_NAME_CASE(CALL)
13361338
NODE_NAME_CASE(RET)
1339+
NODE_NAME_CASE(TAIL)
13371340
NODE_NAME_CASE(SLL_W)
13381341
NODE_NAME_CASE(SRA_W)
13391342
NODE_NAME_CASE(SRL_W)
@@ -1808,6 +1811,48 @@ SDValue LoongArchTargetLowering::LowerFormalArguments(
18081811
return Chain;
18091812
}
18101813

1814+
// Check whether the call is eligible for tail call optimization.
1815+
bool LoongArchTargetLowering::isEligibleForTailCallOptimization(
1816+
CCState &CCInfo, CallLoweringInfo &CLI, MachineFunction &MF,
1817+
const SmallVectorImpl<CCValAssign> &ArgLocs) const {
1818+
1819+
auto CalleeCC = CLI.CallConv;
1820+
auto &Outs = CLI.Outs;
1821+
auto &Caller = MF.getFunction();
1822+
auto CallerCC = Caller.getCallingConv();
1823+
1824+
// Do not tail call opt if the stack is used to pass parameters.
1825+
if (CCInfo.getNextStackOffset() != 0)
1826+
return false;
1827+
1828+
// Do not tail call opt if any parameters need to be passed indirectly.
1829+
for (auto &VA : ArgLocs)
1830+
if (VA.getLocInfo() == CCValAssign::Indirect)
1831+
return false;
1832+
1833+
// Do not tail call opt if either caller or callee uses struct return
1834+
// semantics.
1835+
auto IsCallerStructRet = Caller.hasStructRetAttr();
1836+
auto IsCalleeStructRet = Outs.empty() ? false : Outs[0].Flags.isSRet();
1837+
if (IsCallerStructRet || IsCalleeStructRet)
1838+
return false;
1839+
1840+
// Do not tail call opt if either the callee or caller has a byval argument.
1841+
for (auto &Arg : Outs)
1842+
if (Arg.Flags.isByVal())
1843+
return false;
1844+
1845+
// The callee has to preserve all registers the caller needs to preserve.
1846+
const LoongArchRegisterInfo *TRI = Subtarget.getRegisterInfo();
1847+
const uint32_t *CallerPreserved = TRI->getCallPreservedMask(MF, CallerCC);
1848+
if (CalleeCC != CallerCC) {
1849+
const uint32_t *CalleePreserved = TRI->getCallPreservedMask(MF, CalleeCC);
1850+
if (!TRI->regmaskSubsetEqual(CallerPreserved, CalleePreserved))
1851+
return false;
1852+
}
1853+
return true;
1854+
}
1855+
18111856
static Align getPrefTypeAlign(EVT VT, SelectionDAG &DAG) {
18121857
return DAG.getDataLayout().getPrefTypeAlign(
18131858
VT.getTypeForEVT(*DAG.getContext()));
@@ -1829,7 +1874,7 @@ LoongArchTargetLowering::LowerCall(CallLoweringInfo &CLI,
18291874
bool IsVarArg = CLI.IsVarArg;
18301875
EVT PtrVT = getPointerTy(DAG.getDataLayout());
18311876
MVT GRLenVT = Subtarget.getGRLenVT();
1832-
CLI.IsTailCall = false;
1877+
bool &IsTailCall = CLI.IsTailCall;
18331878

18341879
MachineFunction &MF = DAG.getMachineFunction();
18351880

@@ -1839,6 +1884,16 @@ LoongArchTargetLowering::LowerCall(CallLoweringInfo &CLI,
18391884

18401885
analyzeOutputArgs(MF, ArgCCInfo, Outs, /*IsRet=*/false, &CLI, CC_LoongArch);
18411886

1887+
// Check if it's really possible to do a tail call.
1888+
if (IsTailCall)
1889+
IsTailCall = isEligibleForTailCallOptimization(ArgCCInfo, CLI, MF, ArgLocs);
1890+
1891+
if (IsTailCall)
1892+
++NumTailCalls;
1893+
else if (CLI.CB && CLI.CB->isMustTailCall())
1894+
report_fatal_error("failed to perform tail call elimination on a call "
1895+
"site marked musttail");
1896+
18421897
// Get a count of how many bytes are to be pushed on the stack.
18431898
unsigned NumBytes = ArgCCInfo.getNextStackOffset();
18441899

@@ -1860,12 +1915,13 @@ LoongArchTargetLowering::LowerCall(CallLoweringInfo &CLI,
18601915

18611916
Chain = DAG.getMemcpy(Chain, DL, FIPtr, Arg, SizeNode, Alignment,
18621917
/*IsVolatile=*/false,
1863-
/*AlwaysInline=*/false, /*isTailCall=*/false,
1918+
/*AlwaysInline=*/false, /*isTailCall=*/IsTailCall,
18641919
MachinePointerInfo(), MachinePointerInfo());
18651920
ByValArgs.push_back(FIPtr);
18661921
}
18671922

1868-
Chain = DAG.getCALLSEQ_START(Chain, NumBytes, 0, CLI.DL);
1923+
if (!IsTailCall)
1924+
Chain = DAG.getCALLSEQ_START(Chain, NumBytes, 0, CLI.DL);
18691925

18701926
// Copy argument values to their designated locations.
18711927
SmallVector<std::pair<Register, SDValue>> RegsToPass;
@@ -1932,6 +1988,8 @@ LoongArchTargetLowering::LowerCall(CallLoweringInfo &CLI,
19321988
RegsToPass.push_back(std::make_pair(VA.getLocReg(), ArgValue));
19331989
} else {
19341990
assert(VA.isMemLoc() && "Argument not register or memory");
1991+
assert(!IsTailCall && "Tail call not allowed if stack is used "
1992+
"for passing parameters");
19351993

19361994
// Work out the address of the stack slot.
19371995
if (!StackPtr.getNode())
@@ -1986,11 +2044,13 @@ LoongArchTargetLowering::LowerCall(CallLoweringInfo &CLI,
19862044
for (auto &Reg : RegsToPass)
19872045
Ops.push_back(DAG.getRegister(Reg.first, Reg.second.getValueType()));
19882046

1989-
// Add a register mask operand representing the call-preserved registers.
1990-
const TargetRegisterInfo *TRI = Subtarget.getRegisterInfo();
1991-
const uint32_t *Mask = TRI->getCallPreservedMask(MF, CallConv);
1992-
assert(Mask && "Missing call preserved mask for calling convention");
1993-
Ops.push_back(DAG.getRegisterMask(Mask));
2047+
if (!IsTailCall) {
2048+
// Add a register mask operand representing the call-preserved registers.
2049+
const TargetRegisterInfo *TRI = Subtarget.getRegisterInfo();
2050+
const uint32_t *Mask = TRI->getCallPreservedMask(MF, CallConv);
2051+
assert(Mask && "Missing call preserved mask for calling convention");
2052+
Ops.push_back(DAG.getRegisterMask(Mask));
2053+
}
19942054

19952055
// Glue the call to the argument copies, if any.
19962056
if (Glue.getNode())
@@ -1999,6 +2059,11 @@ LoongArchTargetLowering::LowerCall(CallLoweringInfo &CLI,
19992059
// Emit the call.
20002060
SDVTList NodeTys = DAG.getVTList(MVT::Other, MVT::Glue);
20012061

2062+
if (IsTailCall) {
2063+
MF.getFrameInfo().setHasTailCall();
2064+
return DAG.getNode(LoongArchISD::TAIL, DL, NodeTys, Ops);
2065+
}
2066+
20022067
Chain = DAG.getNode(LoongArchISD::CALL, DL, NodeTys, Ops);
20032068
DAG.addNoMergeSiteInfo(Chain.getNode(), CLI.NoMerge);
20042069
Glue = Chain.getValue(1);

llvm/lib/Target/LoongArch/LoongArchISelLowering.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,8 @@ enum NodeType : unsigned {
2929
// TODO: add more LoongArchISDs
3030
CALL,
3131
RET,
32+
TAIL,
33+
3234
// 32-bit shifts, directly matching the semantics of the named LoongArch
3335
// instructions.
3436
SLL_W,
@@ -204,6 +206,10 @@ class LoongArchTargetLowering : public TargetLowering {
204206
void LowerAsmOperandForConstraint(SDValue Op, std::string &Constraint,
205207
std::vector<SDValue> &Ops,
206208
SelectionDAG &DAG) const override;
209+
210+
bool isEligibleForTailCallOptimization(
211+
CCState &CCInfo, CallLoweringInfo &CLI, MachineFunction &MF,
212+
const SmallVectorImpl<CCValAssign> &ArgLocs) const;
207213
};
208214

209215
} // end namespace llvm

llvm/lib/Target/LoongArch/LoongArchInstrInfo.td

Lines changed: 29 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,9 @@ def loongarch_call : SDNode<"LoongArchISD::CALL", SDT_LoongArchCall,
5050
SDNPVariadic]>;
5151
def loongarch_ret : SDNode<"LoongArchISD::RET", SDTNone,
5252
[SDNPHasChain, SDNPOptInGlue, SDNPVariadic]>;
53+
def loongarch_tail : SDNode<"LoongArchISD::TAIL", SDT_LoongArchCall,
54+
[SDNPHasChain, SDNPOptInGlue, SDNPOutGlue,
55+
SDNPVariadic]>;
5356
def loongarch_sll_w : SDNode<"LoongArchISD::SLL_W", SDT_LoongArchIntBinOpW>;
5457
def loongarch_sra_w : SDNode<"LoongArchISD::SRA_W", SDT_LoongArchIntBinOpW>;
5558
def loongarch_srl_w : SDNode<"LoongArchISD::SRL_W", SDT_LoongArchIntBinOpW>;
@@ -232,8 +235,8 @@ def SImm26OperandBL: AsmOperandClass {
232235
let ParserMethod = "parseSImm26Operand";
233236
}
234237

235-
// A symbol or an imm used in BL/PseudoCALL.
236-
def simm26_bl : Operand<GRLenVT> {
238+
// A symbol or an imm used in BL/PseudoCALL/PseudoTAIL.
239+
def simm26_symbol : Operand<GRLenVT> {
237240
let ParserMatchClass = SImm26OperandBL;
238241
let EncoderMethod = "getImmOpValueAsr2";
239242
let DecoderMethod = "decodeSImmOperand<26, 2>";
@@ -455,7 +458,7 @@ def BNEZ : BrCCZ_1RI21<0b010001, "bnez">;
455458
def B : Br_I26<0b010100, "b">;
456459

457460
let isCall = 1, Defs=[R1] in
458-
def BL : FmtI26<0b010101, (outs), (ins simm26_bl:$imm26), "bl", "$imm26">;
461+
def BL : FmtI26<0b010101, (outs), (ins simm26_symbol:$imm26), "bl", "$imm26">;
459462
def JIRL : Fmt2RI16<0b010011, (outs GPR:$rd),
460463
(ins GPR:$rj, simm16_lsl2:$imm16), "jirl",
461464
"$rd, $rj, $imm16">;
@@ -934,7 +937,7 @@ def : Pat<(brind (add GPR:$rj, simm16_lsl2:$imm16)),
934937
(PseudoBRIND GPR:$rj, simm16_lsl2:$imm16)>;
935938

936939
let isCall = 1, Defs = [R1] in
937-
def PseudoCALL : Pseudo<(outs), (ins simm26_bl:$func)>;
940+
def PseudoCALL : Pseudo<(outs), (ins simm26_symbol:$func)>;
938941

939942
def : Pat<(loongarch_call tglobaladdr:$func), (PseudoCALL tglobaladdr:$func)>;
940943
def : Pat<(loongarch_call texternalsym:$func), (PseudoCALL texternalsym:$func)>;
@@ -953,6 +956,28 @@ let isBarrier = 1, isReturn = 1, isTerminator = 1 in
953956
def PseudoRET : Pseudo<(outs), (ins), [(loongarch_ret)]>,
954957
PseudoInstExpansion<(JIRL R0, R1, 0)>;
955958

959+
let isCall = 1, isTerminator = 1, isReturn = 1, isBarrier = 1, Uses = [R3] in
960+
def PseudoTAIL : Pseudo<(outs), (ins simm26_symbol:$dst)>;
961+
962+
def : Pat<(loongarch_tail (iPTR tglobaladdr:$dst)),
963+
(PseudoTAIL tglobaladdr:$dst)>;
964+
def : Pat<(loongarch_tail (iPTR texternalsym:$dst)),
965+
(PseudoTAIL texternalsym:$dst)>;
966+
967+
let isCall = 1, isTerminator = 1, isReturn = 1, isBarrier = 1, Uses = [R3] in
968+
def PseudoTAILIndirect : Pseudo<(outs), (ins GPRT:$rj),
969+
[(loongarch_tail GPRT:$rj)]>,
970+
PseudoInstExpansion<(JIRL R0, GPR:$rj, 0)>;
971+
972+
let isCall = 1, isTerminator = 1, isReturn = 1, isBarrier = 1, Uses = [R3] in
973+
def PseudoB_TAIL : Pseudo<(outs), (ins simm26_b:$imm26)>,
974+
PseudoInstExpansion<(B simm26_b:$imm26)>;
975+
976+
let isCall = 1, isTerminator = 1, isReturn = 1, isBarrier = 1, Uses = [R3] in
977+
def PseudoJIRL_TAIL : Pseudo<(outs), (ins GPR:$rj, simm16_lsl2:$imm16)>,
978+
PseudoInstExpansion<(JIRL R0, GPR:$rj,
979+
simm16_lsl2:$imm16)>;
980+
956981
let hasSideEffects = 0, mayLoad = 0, mayStore = 0 in
957982
def PseudoLA_PCREL : Pseudo<(outs GPR:$dst), (ins grlenimm:$src)>;
958983

llvm/lib/Target/LoongArch/LoongArchRegisterInfo.td

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,16 @@ def GPR : RegisterClass<"LoongArch", [GRLenVT], 32, (add
9898
let RegInfos = GRLenRI;
9999
}
100100

101+
// GPR for indirect tail calls. We can't use callee-saved registers, as they are
102+
// restored to the saved value before the tail call, which would clobber a call
103+
// address.
104+
def GPRT : RegisterClass<"LoongArch", [GRLenVT], 32, (add
105+
// a0...a7, t0...t8
106+
(sequence "R%u", 4, 20)
107+
)> {
108+
let RegInfos = GRLenRI;
109+
}
110+
101111
// Floating point registers
102112

103113
let RegAltNameIndices = [RegAliasName] in {

llvm/test/CodeGen/LoongArch/codemodel-medium.ll

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,3 +61,19 @@ entry:
6161
call void @llvm.memset.p0.i64(ptr %dst, i8 0, i64 1000, i1 false)
6262
ret void
6363
}
64+
65+
;; Tail call with different codemodel.
66+
declare i32 @callee_tail(i32 %i)
67+
define i32 @caller_tail(i32 %i) nounwind {
68+
; SMALL-LABEL: caller_tail:
69+
; SMALL: # %bb.0: # %entry
70+
; SMALL-NEXT: b %plt(callee_tail)
71+
;
72+
; MEDIUM-LABEL: caller_tail:
73+
; MEDIUM: # %bb.0: # %entry
74+
; MEDIUM-NEXT: pcalau12i $a1, %pc_hi20(callee_tail)
75+
; MEDIUM-NEXT: jirl $zero, $a1, %pc_lo12(callee_tail)
76+
entry:
77+
%r = tail call i32 @callee_tail(i32 %i)
78+
ret i32 %r
79+
}

llvm/test/CodeGen/LoongArch/nomerge.ll

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,4 +32,4 @@ attributes #0 = { nomerge }
3232
; CHECK: .LBB0_3: # %if.then2
3333
; CHECK-NEXT: bl %plt(bar)
3434
; CHECK: .LBB0_4: # %if.end3
35-
; CHECK: bl %plt(bar)
35+
; CHECK: b %plt(bar)

0 commit comments

Comments
 (0)