Skip to content

Commit 893fa06

Browse files
authored
[RISC-V] Adjust trampoline code for branch control flow protection (#141949)
Trampoline will use a alternative sequence when branch CFI is on. The stack of the test is organized as follow ``` 56 $ra 44 $a0 f 36 $a1 p 32 00038067 jalr t2 28 010e3e03 ld t3, 16(t3) 24 018e3383 ld t2, 24(t3) 20 00000e17 auipc t3, 0 sp+16 00000023 lpad 0 ```
1 parent d570409 commit 893fa06

File tree

2 files changed

+178
-31
lines changed

2 files changed

+178
-31
lines changed

llvm/lib/Target/RISCV/RISCVISelLowering.cpp

Lines changed: 79 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -8362,9 +8362,23 @@ SDValue RISCVTargetLowering::lowerINIT_TRAMPOLINE(SDValue Op,
83628362
// 16: <StaticChainOffset>
83638363
// 24: <FunctionAddressOffset>
83648364
// 32:
8365-
8366-
constexpr unsigned StaticChainOffset = 16;
8367-
constexpr unsigned FunctionAddressOffset = 24;
8365+
// Offset with branch control flow protection enabled:
8366+
// 0: lpad <imm20>
8367+
// 4: auipc t3, 0
8368+
// 8: ld t2, 28(t3)
8369+
// 12: ld t3, 20(t3)
8370+
// 16: jalr t2
8371+
// 20: <StaticChainOffset>
8372+
// 28: <FunctionAddressOffset>
8373+
// 36:
8374+
8375+
const bool HasCFBranch =
8376+
Subtarget.hasStdExtZicfilp() &&
8377+
DAG.getMachineFunction().getFunction().getParent()->getModuleFlag(
8378+
"cf-protection-branch");
8379+
const unsigned StaticChainIdx = HasCFBranch ? 5 : 4;
8380+
const unsigned StaticChainOffset = StaticChainIdx * 4;
8381+
const unsigned FunctionAddressOffset = StaticChainOffset + 8;
83688382

83698383
const MCSubtargetInfo *STI = getTargetMachine().getMCSubtargetInfo();
83708384
assert(STI);
@@ -8376,38 +8390,70 @@ SDValue RISCVTargetLowering::lowerINIT_TRAMPOLINE(SDValue Op,
83768390
return Encoding;
83778391
};
83788392

8379-
SDValue OutChains[6];
8380-
8381-
uint32_t Encodings[] = {
8382-
// auipc t2, 0
8383-
// Loads the current PC into t2.
8384-
GetEncoding(MCInstBuilder(RISCV::AUIPC).addReg(RISCV::X7).addImm(0)),
8385-
// ld t0, 24(t2)
8386-
// Loads the function address into t0. Note that we are using offsets
8387-
// pc-relative to the first instruction of the trampoline.
8388-
GetEncoding(
8389-
MCInstBuilder(RISCV::LD).addReg(RISCV::X5).addReg(RISCV::X7).addImm(
8390-
FunctionAddressOffset)),
8391-
// ld t2, 16(t2)
8392-
// Load the value of the static chain.
8393-
GetEncoding(
8394-
MCInstBuilder(RISCV::LD).addReg(RISCV::X7).addReg(RISCV::X7).addImm(
8395-
StaticChainOffset)),
8396-
// jalr t0
8397-
// Jump to the function.
8398-
GetEncoding(MCInstBuilder(RISCV::JALR)
8399-
.addReg(RISCV::X0)
8400-
.addReg(RISCV::X5)
8401-
.addImm(0))};
8393+
SmallVector<SDValue> OutChains;
8394+
8395+
SmallVector<uint32_t> Encodings;
8396+
if (!HasCFBranch) {
8397+
Encodings.append(
8398+
{// auipc t2, 0
8399+
// Loads the current PC into t2.
8400+
GetEncoding(MCInstBuilder(RISCV::AUIPC).addReg(RISCV::X7).addImm(0)),
8401+
// ld t0, 24(t2)
8402+
// Loads the function address into t0. Note that we are using offsets
8403+
// pc-relative to the first instruction of the trampoline.
8404+
GetEncoding(MCInstBuilder(RISCV::LD)
8405+
.addReg(RISCV::X5)
8406+
.addReg(RISCV::X7)
8407+
.addImm(FunctionAddressOffset)),
8408+
// ld t2, 16(t2)
8409+
// Load the value of the static chain.
8410+
GetEncoding(MCInstBuilder(RISCV::LD)
8411+
.addReg(RISCV::X7)
8412+
.addReg(RISCV::X7)
8413+
.addImm(StaticChainOffset)),
8414+
// jalr t0
8415+
// Jump to the function.
8416+
GetEncoding(MCInstBuilder(RISCV::JALR)
8417+
.addReg(RISCV::X0)
8418+
.addReg(RISCV::X5)
8419+
.addImm(0))});
8420+
} else {
8421+
Encodings.append(
8422+
{// auipc x0, <imm20> (lpad <imm20>)
8423+
// Landing pad.
8424+
GetEncoding(MCInstBuilder(RISCV::AUIPC).addReg(RISCV::X0).addImm(0)),
8425+
// auipc t3, 0
8426+
// Loads the current PC into t3.
8427+
GetEncoding(MCInstBuilder(RISCV::AUIPC).addReg(RISCV::X28).addImm(0)),
8428+
// ld t2, (FunctionAddressOffset - 4)(t3)
8429+
// Loads the function address into t2. Note that we are using offsets
8430+
// pc-relative to the SECOND instruction of the trampoline.
8431+
GetEncoding(MCInstBuilder(RISCV::LD)
8432+
.addReg(RISCV::X7)
8433+
.addReg(RISCV::X28)
8434+
.addImm(FunctionAddressOffset - 4)),
8435+
// ld t3, (StaticChainOffset - 4)(t3)
8436+
// Load the value of the static chain.
8437+
GetEncoding(MCInstBuilder(RISCV::LD)
8438+
.addReg(RISCV::X28)
8439+
.addReg(RISCV::X28)
8440+
.addImm(StaticChainOffset - 4)),
8441+
// jalr t2
8442+
// Software-guarded jump to the function.
8443+
GetEncoding(MCInstBuilder(RISCV::JALR)
8444+
.addReg(RISCV::X0)
8445+
.addReg(RISCV::X7)
8446+
.addImm(0))});
8447+
}
84028448

84038449
// Store encoded instructions.
84048450
for (auto [Idx, Encoding] : llvm::enumerate(Encodings)) {
84058451
SDValue Addr = Idx > 0 ? DAG.getNode(ISD::ADD, dl, MVT::i64, Trmp,
84068452
DAG.getConstant(Idx * 4, dl, MVT::i64))
84078453
: Trmp;
8408-
OutChains[Idx] = DAG.getTruncStore(
8454+
OutChains.push_back(DAG.getTruncStore(
84098455
Root, dl, DAG.getConstant(Encoding, dl, MVT::i64), Addr,
8410-
MachinePointerInfo(TrmpAddr, Idx * 4), MVT::i32);
8456+
MachinePointerInfo(TrmpAddr, Idx * 4), MVT::i32));
84118457
}
84128458

84138459
// Now store the variable part of the trampoline.
@@ -8423,16 +8469,18 @@ SDValue RISCVTargetLowering::lowerINIT_TRAMPOLINE(SDValue Op,
84238469
{StaticChainOffset, StaticChain},
84248470
{FunctionAddressOffset, FunctionAddress},
84258471
};
8426-
for (auto [Idx, OffsetValue] : llvm::enumerate(OffsetValues)) {
8472+
for (auto &OffsetValue : OffsetValues) {
84278473
SDValue Addr =
84288474
DAG.getNode(ISD::ADD, dl, MVT::i64, Trmp,
84298475
DAG.getConstant(OffsetValue.Offset, dl, MVT::i64));
84308476
OffsetValue.Addr = Addr;
8431-
OutChains[Idx + 4] =
8477+
OutChains.push_back(
84328478
DAG.getStore(Root, dl, OffsetValue.Value, Addr,
8433-
MachinePointerInfo(TrmpAddr, OffsetValue.Offset));
8479+
MachinePointerInfo(TrmpAddr, OffsetValue.Offset)));
84348480
}
84358481

8482+
assert(OutChains.size() == StaticChainIdx + 2 &&
8483+
"Size of OutChains mismatch");
84368484
SDValue StoreToken = DAG.getNode(ISD::TokenFactor, dl, MVT::Other, OutChains);
84378485

84388486
// The end of instructions of trampoline is the same as the static chain
Lines changed: 99 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,99 @@
1+
; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
2+
; RUN: llc -O0 -mtriple=riscv64 -mattr=+experimental-zicfilp -verify-machineinstrs < %s \
3+
; RUN: | FileCheck -check-prefix=RV64 %s
4+
; RUN: llc -O0 -mtriple=riscv64-unknown-linux-gnu -mattr=+experimental-zicfilp -verify-machineinstrs < %s \
5+
; RUN: | FileCheck -check-prefix=RV64-LINUX %s
6+
7+
declare void @llvm.init.trampoline(ptr, ptr, ptr)
8+
declare ptr @llvm.adjust.trampoline(ptr)
9+
declare i64 @f(ptr nest, i64)
10+
11+
define i64 @test0(i64 %n, ptr %p) nounwind {
12+
; RV64-LABEL: test0:
13+
; RV64: # %bb.0:
14+
; RV64-NEXT: lpad 0
15+
; RV64-NEXT: addi sp, sp, -64
16+
; RV64-NEXT: sd ra, 56(sp) # 8-byte Folded Spill
17+
; RV64-NEXT: sd a0, 0(sp) # 8-byte Folded Spill
18+
; RV64-NEXT: lui a0, %hi(f)
19+
; RV64-NEXT: addi a0, a0, %lo(f)
20+
; RV64-NEXT: sw a0, 44(sp)
21+
; RV64-NEXT: srli a0, a0, 32
22+
; RV64-NEXT: sw a0, 48(sp)
23+
; RV64-NEXT: sw a1, 36(sp)
24+
; RV64-NEXT: srli a0, a1, 32
25+
; RV64-NEXT: sw a0, 40(sp)
26+
; RV64-NEXT: li a0, 23
27+
; RV64-NEXT: sw a0, 16(sp)
28+
; RV64-NEXT: lui a0, 56
29+
; RV64-NEXT: addi a0, a0, 103
30+
; RV64-NEXT: sw a0, 32(sp)
31+
; RV64-NEXT: lui a0, 4324
32+
; RV64-NEXT: addi a0, a0, -509
33+
; RV64-NEXT: sw a0, 28(sp)
34+
; RV64-NEXT: lui a0, 6371
35+
; RV64-NEXT: addi a0, a0, 899
36+
; RV64-NEXT: sw a0, 24(sp)
37+
; RV64-NEXT: lui a0, 1
38+
; RV64-NEXT: addi a0, a0, -489
39+
; RV64-NEXT: sw a0, 20(sp)
40+
; RV64-NEXT: addi a1, sp, 36
41+
; RV64-NEXT: addi a0, sp, 16
42+
; RV64-NEXT: sd a0, 8(sp) # 8-byte Folded Spill
43+
; RV64-NEXT: call __clear_cache
44+
; RV64-NEXT: ld a0, 0(sp) # 8-byte Folded Reload
45+
; RV64-NEXT: ld a1, 8(sp) # 8-byte Folded Reload
46+
; RV64-NEXT: jalr a1
47+
; RV64-NEXT: ld ra, 56(sp) # 8-byte Folded Reload
48+
; RV64-NEXT: addi sp, sp, 64
49+
; RV64-NEXT: ret
50+
;
51+
; RV64-LINUX-LABEL: test0:
52+
; RV64-LINUX: # %bb.0:
53+
; RV64-LINUX-NEXT: lpad 0
54+
; RV64-LINUX-NEXT: addi sp, sp, -64
55+
; RV64-LINUX-NEXT: sd ra, 56(sp) # 8-byte Folded Spill
56+
; RV64-LINUX-NEXT: sd a0, 0(sp) # 8-byte Folded Spill
57+
; RV64-LINUX-NEXT: lui a0, %hi(f)
58+
; RV64-LINUX-NEXT: addi a0, a0, %lo(f)
59+
; RV64-LINUX-NEXT: sw a0, 44(sp)
60+
; RV64-LINUX-NEXT: srli a0, a0, 32
61+
; RV64-LINUX-NEXT: sw a0, 48(sp)
62+
; RV64-LINUX-NEXT: sw a1, 36(sp)
63+
; RV64-LINUX-NEXT: srli a0, a1, 32
64+
; RV64-LINUX-NEXT: sw a0, 40(sp)
65+
; RV64-LINUX-NEXT: li a0, 23
66+
; RV64-LINUX-NEXT: sw a0, 16(sp)
67+
; RV64-LINUX-NEXT: lui a0, 56
68+
; RV64-LINUX-NEXT: addi a0, a0, 103
69+
; RV64-LINUX-NEXT: sw a0, 32(sp)
70+
; RV64-LINUX-NEXT: lui a0, 4324
71+
; RV64-LINUX-NEXT: addi a0, a0, -509
72+
; RV64-LINUX-NEXT: sw a0, 28(sp)
73+
; RV64-LINUX-NEXT: lui a0, 6371
74+
; RV64-LINUX-NEXT: addi a0, a0, 899
75+
; RV64-LINUX-NEXT: sw a0, 24(sp)
76+
; RV64-LINUX-NEXT: lui a0, 1
77+
; RV64-LINUX-NEXT: addi a0, a0, -489
78+
; RV64-LINUX-NEXT: sw a0, 20(sp)
79+
; RV64-LINUX-NEXT: addi a1, sp, 36
80+
; RV64-LINUX-NEXT: addi a0, sp, 16
81+
; RV64-LINUX-NEXT: sd a0, 8(sp) # 8-byte Folded Spill
82+
; RV64-LINUX-NEXT: li a2, 0
83+
; RV64-LINUX-NEXT: call __riscv_flush_icache
84+
; RV64-LINUX-NEXT: ld a0, 0(sp) # 8-byte Folded Reload
85+
; RV64-LINUX-NEXT: ld a1, 8(sp) # 8-byte Folded Reload
86+
; RV64-LINUX-NEXT: jalr a1
87+
; RV64-LINUX-NEXT: ld ra, 56(sp) # 8-byte Folded Reload
88+
; RV64-LINUX-NEXT: addi sp, sp, 64
89+
; RV64-LINUX-NEXT: ret
90+
%alloca = alloca [36 x i8], align 8
91+
call void @llvm.init.trampoline(ptr %alloca, ptr @f, ptr %p)
92+
%tramp = call ptr @llvm.adjust.trampoline(ptr %alloca)
93+
%ret = call i64 %tramp(i64 %n)
94+
ret i64 %ret
95+
}
96+
97+
!llvm.module.flags = !{!0}
98+
99+
!0 = !{i32 8, !"cf-protection-branch", i32 1}

0 commit comments

Comments
 (0)