Skip to content

Commit 5eb65fd

Browse files
committed
[WebAssembly,llvm] Add llvm.wasm.ref.test.func intrinsic, option 2
To test whether or not a function pointer has the expected signature. Intended for adding a future clang builtin ` __builtin_wasm_test_function_pointer_signature` so we can test whether calling a function pointer will fail with function signature mismatch. This is an alternative to #147076, where instead of using a ref.test.pseudo instruction with a custom inserter, we teach SelectionDag a type of TargetConstantAP nodes that get converted to a CImm in the MCInst layer.
1 parent 830c0b7 commit 5eb65fd

File tree

11 files changed

+221
-16
lines changed

11 files changed

+221
-16
lines changed

llvm/include/llvm/CodeGen/ISDOpcodes.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -173,6 +173,7 @@ enum NodeType {
173173
/// materialized in registers.
174174
TargetConstant,
175175
TargetConstantFP,
176+
TargetConstantAP,
176177

177178
/// TargetGlobalAddress - Like GlobalAddress, but the DAG does no folding or
178179
/// anything else with this node, and this is valid in the target-specific

llvm/include/llvm/CodeGen/SelectionDAG.h

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -683,7 +683,8 @@ class SelectionDAG {
683683
LLVM_ABI SDValue getConstant(uint64_t Val, const SDLoc &DL, EVT VT,
684684
bool isTarget = false, bool isOpaque = false);
685685
LLVM_ABI SDValue getConstant(const APInt &Val, const SDLoc &DL, EVT VT,
686-
bool isTarget = false, bool isOpaque = false);
686+
bool isTarget = false, bool isOpaque = false,
687+
bool isArbitraryPrecision = false);
687688

688689
LLVM_ABI SDValue getSignedConstant(int64_t Val, const SDLoc &DL, EVT VT,
689690
bool isTarget = false,
@@ -694,7 +695,8 @@ class SelectionDAG {
694695
bool IsOpaque = false);
695696

696697
LLVM_ABI SDValue getConstant(const ConstantInt &Val, const SDLoc &DL, EVT VT,
697-
bool isTarget = false, bool isOpaque = false);
698+
bool isTarget = false, bool isOpaque = false,
699+
bool isArbitraryPrecision = false);
698700
LLVM_ABI SDValue getIntPtrConstant(uint64_t Val, const SDLoc &DL,
699701
bool isTarget = false);
700702
LLVM_ABI SDValue getShiftAmountConstant(uint64_t Val, EVT VT,
@@ -712,6 +714,10 @@ class SelectionDAG {
712714
bool isOpaque = false) {
713715
return getConstant(Val, DL, VT, true, isOpaque);
714716
}
717+
SDValue getTargetConstantAP(const APInt &Val, const SDLoc &DL, EVT VT,
718+
bool isOpaque = false) {
719+
return getConstant(Val, DL, VT, true, isOpaque, true);
720+
}
715721
SDValue getTargetConstant(const ConstantInt &Val, const SDLoc &DL, EVT VT,
716722
bool isOpaque = false) {
717723
return getConstant(Val, DL, VT, true, isOpaque);

llvm/include/llvm/CodeGen/SelectionDAGNodes.h

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1742,10 +1742,11 @@ class ConstantSDNode : public SDNode {
17421742

17431743
const ConstantInt *Value;
17441744

1745-
ConstantSDNode(bool isTarget, bool isOpaque, const ConstantInt *val,
1746-
SDVTList VTs)
1747-
: SDNode(isTarget ? ISD::TargetConstant : ISD::Constant, 0, DebugLoc(),
1748-
VTs),
1745+
ConstantSDNode(bool isTarget, bool isOpaque, bool isAPTarget,
1746+
const ConstantInt *val, SDVTList VTs)
1747+
: SDNode(isAPTarget ? ISD::TargetConstantAP
1748+
: (isTarget ? ISD::TargetConstant : ISD::Constant),
1749+
0, DebugLoc(), VTs),
17491750
Value(val) {
17501751
assert(!isa<VectorType>(val->getType()) && "Unexpected vector type!");
17511752
ConstantSDNodeBits.IsOpaque = isOpaque;
@@ -1772,7 +1773,8 @@ class ConstantSDNode : public SDNode {
17721773

17731774
static bool classof(const SDNode *N) {
17741775
return N->getOpcode() == ISD::Constant ||
1775-
N->getOpcode() == ISD::TargetConstant;
1776+
N->getOpcode() == ISD::TargetConstant ||
1777+
N->getOpcode() == ISD::TargetConstantAP;
17761778
}
17771779
};
17781780

llvm/include/llvm/IR/IntrinsicsWebAssembly.td

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,10 @@ def int_wasm_ref_is_null_exn :
4343
DefaultAttrsIntrinsic<[llvm_i32_ty], [llvm_exnref_ty], [IntrNoMem],
4444
"llvm.wasm.ref.is_null.exn">;
4545

46+
def int_wasm_ref_test_func
47+
: DefaultAttrsIntrinsic<[llvm_i32_ty], [llvm_ptr_ty, llvm_vararg_ty],
48+
[IntrNoMem], "llvm.wasm.ref.test.func">;
49+
4650
//===----------------------------------------------------------------------===//
4751
// Table intrinsics
4852
//===----------------------------------------------------------------------===//

llvm/lib/CodeGen/SelectionDAG/InstrEmitter.cpp

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -402,7 +402,12 @@ void InstrEmitter::AddOperand(MachineInstrBuilder &MIB, SDValue Op,
402402
AddRegisterOperand(MIB, Op, IIOpNum, II, VRBaseMap,
403403
IsDebug, IsClone, IsCloned);
404404
} else if (ConstantSDNode *C = dyn_cast<ConstantSDNode>(Op)) {
405-
MIB.addImm(C->getSExtValue());
405+
if (C->getOpcode() == ISD::TargetConstantAP) {
406+
MIB.addCImm(
407+
ConstantInt::get(MF->getFunction().getContext(), C->getAPIntValue()));
408+
} else {
409+
MIB.addImm(C->getSExtValue());
410+
}
406411
} else if (ConstantFPSDNode *F = dyn_cast<ConstantFPSDNode>(Op)) {
407412
MIB.addFPImm(F->getConstantFPValue());
408413
} else if (RegisterSDNode *R = dyn_cast<RegisterSDNode>(Op)) {

llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -968,6 +968,7 @@ void SelectionDAGLegalize::LegalizeOp(SDNode *Node) {
968968

969969
// Allow illegal target nodes and illegal registers.
970970
if (Node->getOpcode() == ISD::TargetConstant ||
971+
Node->getOpcode() == ISD::TargetConstantAP ||
971972
Node->getOpcode() == ISD::Register)
972973
return;
973974

@@ -979,10 +980,11 @@ void SelectionDAGLegalize::LegalizeOp(SDNode *Node) {
979980

980981
for (const SDValue &Op : Node->op_values())
981982
assert((TLI.getTypeAction(*DAG.getContext(), Op.getValueType()) ==
982-
TargetLowering::TypeLegal ||
983+
TargetLowering::TypeLegal ||
983984
Op.getOpcode() == ISD::TargetConstant ||
985+
Op->getOpcode() == ISD::TargetConstantAP ||
984986
Op.getOpcode() == ISD::Register) &&
985-
"Unexpected illegal type!");
987+
"Unexpected illegal type!");
986988
#endif
987989

988990
// Figure out the correct action; the way to query this varies by opcode

llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1664,14 +1664,14 @@ SDValue SelectionDAG::getConstant(uint64_t Val, const SDLoc &DL, EVT VT,
16641664
}
16651665

16661666
SDValue SelectionDAG::getConstant(const APInt &Val, const SDLoc &DL, EVT VT,
1667-
bool isT, bool isO) {
1668-
return getConstant(*ConstantInt::get(*Context, Val), DL, VT, isT, isO);
1667+
bool isT, bool isO, bool isAP) {
1668+
return getConstant(*ConstantInt::get(*Context, Val), DL, VT, isT, isO, isAP);
16691669
}
16701670

16711671
SDValue SelectionDAG::getConstant(const ConstantInt &Val, const SDLoc &DL,
1672-
EVT VT, bool isT, bool isO) {
1672+
EVT VT, bool isT, bool isO, bool isAP) {
16731673
assert(VT.isInteger() && "Cannot create FP integer constant!");
1674-
1674+
isT |= isAP;
16751675
EVT EltVT = VT.getScalarType();
16761676
const ConstantInt *Elt = &Val;
16771677

@@ -1760,7 +1760,8 @@ SDValue SelectionDAG::getConstant(const ConstantInt &Val, const SDLoc &DL,
17601760

17611761
assert(Elt->getBitWidth() == EltVT.getSizeInBits() &&
17621762
"APInt size does not match type size!");
1763-
unsigned Opc = isT ? ISD::TargetConstant : ISD::Constant;
1763+
unsigned Opc = isAP ? ISD::TargetConstantAP
1764+
: (isT ? ISD::TargetConstant : ISD::Constant);
17641765
SDVTList VTs = getVTList(EltVT);
17651766
FoldingSetNodeID ID;
17661767
AddNodeIDNode(ID, Opc, VTs, {});
@@ -1773,7 +1774,7 @@ SDValue SelectionDAG::getConstant(const ConstantInt &Val, const SDLoc &DL,
17731774
return SDValue(N, 0);
17741775

17751776
if (!N) {
1776-
N = newSDNode<ConstantSDNode>(isT, isO, Elt, VTs);
1777+
N = newSDNode<ConstantSDNode>(isT, isO, isAP, Elt, VTs);
17771778
CSEMap.InsertNode(N, IP);
17781779
InsertNode(N);
17791780
NewSDValueDbgMsg(SDValue(N, 0), "Creating constant: ", this);

llvm/lib/CodeGen/SelectionDAG/SelectionDAGISel.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3255,6 +3255,7 @@ void SelectionDAGISel::SelectCodeCommon(SDNode *NodeToMatch,
32553255
case ISD::HANDLENODE:
32563256
case ISD::MDNODE_SDNODE:
32573257
case ISD::TargetConstant:
3258+
case ISD::TargetConstantAP:
32583259
case ISD::TargetConstantFP:
32593260
case ISD::TargetConstantPool:
32603261
case ISD::TargetFrameIndex:

llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
#include "WebAssemblySubtarget.h"
1919
#include "WebAssemblyTargetMachine.h"
2020
#include "WebAssemblyUtilities.h"
21+
#include "llvm/BinaryFormat/Wasm.h"
2122
#include "llvm/CodeGen/CallingConvLower.h"
2223
#include "llvm/CodeGen/MachineFrameInfo.h"
2324
#include "llvm/CodeGen/MachineInstrBuilder.h"
@@ -794,6 +795,7 @@ LowerCallResults(MachineInstr &CallResults, DebugLoc DL, MachineBasicBlock *BB,
794795

795796
if (IsIndirect) {
796797
// Placeholder for the type index.
798+
// This gets replaced with the correct value in WebAssemblyMCInstLower.cpp
797799
MIB.addImm(0);
798800
// The table into which this call_indirect indexes.
799801
MCSymbolWasm *Table = IsFuncrefCall
@@ -2253,6 +2255,71 @@ SDValue WebAssemblyTargetLowering::LowerIntrinsic(SDValue Op,
22532255
DAG.getTargetExternalSymbol(TlsBase, PtrVT)),
22542256
0);
22552257
}
2258+
case Intrinsic::wasm_ref_test_func: {
2259+
// First emit the TABLE_GET instruction to convert function pointer ==>
2260+
// funcref
2261+
MachineFunction &MF = DAG.getMachineFunction();
2262+
auto PtrVT = getPointerTy(MF.getDataLayout());
2263+
MCSymbol *Table =
2264+
WebAssembly::getOrCreateFunctionTableSymbol(MF.getContext(), Subtarget);
2265+
SDValue TableSym = DAG.getMCSymbol(Table, PtrVT);
2266+
SDValue FuncRef =
2267+
SDValue(DAG.getMachineNode(WebAssembly::TABLE_GET_FUNCREF, DL,
2268+
MVT::funcref, TableSym, Op.getOperand(1)),
2269+
0);
2270+
2271+
// Encode the signature information into the type index placeholder.
2272+
// This gets decoded and converted into the actual type signature in
2273+
// WebAssemblyMCInstLower.cpp.
2274+
auto NParams = Op.getNumOperands() - 2;
2275+
auto Sig = APInt(NParams * 64, 0);
2276+
// The return type has to be a BlockType since it can be void.
2277+
{
2278+
SDValue Operand = Op.getOperand(2);
2279+
MVT VT = Operand.getValueType().getSimpleVT();
2280+
WebAssembly::BlockType V;
2281+
if (VT == MVT::Untyped) {
2282+
V = WebAssembly::BlockType::Void;
2283+
} else if (VT == MVT::i32) {
2284+
V = WebAssembly::BlockType::I32;
2285+
} else if (VT == MVT::i64) {
2286+
V = WebAssembly::BlockType::I64;
2287+
} else if (VT == MVT::f32) {
2288+
V = WebAssembly::BlockType::F32;
2289+
} else if (VT == MVT::f64) {
2290+
V = WebAssembly::BlockType::F64;
2291+
} else {
2292+
llvm_unreachable("Unhandled type!");
2293+
}
2294+
Sig |= (int64_t)V;
2295+
}
2296+
for (unsigned i = 3; i < Op.getNumOperands(); ++i) {
2297+
SDValue Operand = Op.getOperand(i);
2298+
MVT VT = Operand.getValueType().getSimpleVT();
2299+
wasm::ValType V;
2300+
if (VT == MVT::i32) {
2301+
V = wasm::ValType::I32;
2302+
} else if (VT == MVT::i64) {
2303+
V = wasm::ValType::I64;
2304+
} else if (VT == MVT::f32) {
2305+
V = wasm::ValType::F32;
2306+
} else if (VT == MVT::f64) {
2307+
V = wasm::ValType::F64;
2308+
} else {
2309+
llvm_unreachable("Unhandled type!");
2310+
}
2311+
Sig <<= 64;
2312+
Sig |= (int64_t)V;
2313+
}
2314+
2315+
SmallVector<SDValue, 4> Ops;
2316+
Ops.push_back(DAG.getTargetConstantAP(
2317+
Sig, DL, EVT::getIntegerVT(*DAG.getContext(), NParams * 64)));
2318+
Ops.push_back(FuncRef);
2319+
return SDValue(
2320+
DAG.getMachineNode(WebAssembly::REF_TEST_FUNCREF, DL, MVT::i32, Ops),
2321+
0);
2322+
}
22562323
}
22572324
}
22582325

llvm/lib/Target/WebAssembly/WebAssemblyMCInstLower.cpp

Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,13 +15,17 @@
1515
#include "WebAssemblyMCInstLower.h"
1616
#include "MCTargetDesc/WebAssemblyMCAsmInfo.h"
1717
#include "MCTargetDesc/WebAssemblyMCTargetDesc.h"
18+
#include "MCTargetDesc/WebAssemblyMCTypeUtilities.h"
1819
#include "TargetInfo/WebAssemblyTargetInfo.h"
1920
#include "Utils/WebAssemblyTypeUtilities.h"
2021
#include "WebAssemblyAsmPrinter.h"
2122
#include "WebAssemblyMachineFunctionInfo.h"
2223
#include "WebAssemblyUtilities.h"
24+
#include "llvm/ADT/SmallVector.h"
25+
#include "llvm/BinaryFormat/Wasm.h"
2326
#include "llvm/CodeGen/AsmPrinter.h"
2427
#include "llvm/CodeGen/MachineFunction.h"
28+
#include "llvm/CodeGen/MachineOperand.h"
2529
#include "llvm/IR/Constants.h"
2630
#include "llvm/MC/MCAsmInfo.h"
2731
#include "llvm/MC/MCContext.h"
@@ -196,11 +200,80 @@ void WebAssemblyMCInstLower::lower(const MachineInstr *MI,
196200
MCOp = MCOperand::createReg(WAReg);
197201
break;
198202
}
203+
case llvm::MachineOperand::MO_CImmediate: {
204+
// Lower type index placeholder for ref.test
205+
// Currently this is the only way that CImmediates show up so panic if we
206+
// get confused.
207+
unsigned DescIndex = I - NumVariadicDefs;
208+
if (DescIndex >= Desc.NumOperands) {
209+
llvm_unreachable("unexpected CImmediate operand");
210+
}
211+
const MCOperandInfo &Info = Desc.operands()[DescIndex];
212+
if (Info.OperandType != WebAssembly::OPERAND_TYPEINDEX) {
213+
llvm_unreachable("unexpected CImmediate operand");
214+
}
215+
auto CImm = MO.getCImm()->getValue();
216+
auto NumWords = CImm.getNumWords();
217+
// Extract the type data we packed into the CImm in LowerRefTestFuncRef.
218+
// We need to load the words from most significant to least significant
219+
// order because of the way we bitshifted them in from the right.
220+
// The return type needs special handling because it could be void.
221+
auto ReturnType = static_cast<WebAssembly::BlockType>(
222+
CImm.extractBitsAsZExtValue(64, (NumWords - 1) * 64));
223+
SmallVector<wasm::ValType, 2> Returns;
224+
switch (ReturnType) {
225+
case WebAssembly::BlockType::Invalid:
226+
llvm_unreachable("Invalid return type");
227+
case WebAssembly::BlockType::I32:
228+
Returns = {wasm::ValType::I32};
229+
break;
230+
case WebAssembly::BlockType::I64:
231+
Returns = {wasm::ValType::I64};
232+
break;
233+
case WebAssembly::BlockType::F32:
234+
Returns = {wasm::ValType::F32};
235+
break;
236+
case WebAssembly::BlockType::F64:
237+
Returns = {wasm::ValType::F64};
238+
break;
239+
case WebAssembly::BlockType::Void:
240+
Returns = {};
241+
break;
242+
case WebAssembly::BlockType::Exnref:
243+
Returns = {wasm::ValType::EXNREF};
244+
break;
245+
case WebAssembly::BlockType::Externref:
246+
Returns = {wasm::ValType::EXTERNREF};
247+
break;
248+
case WebAssembly::BlockType::Funcref:
249+
Returns = {wasm::ValType::FUNCREF};
250+
break;
251+
case WebAssembly::BlockType::V128:
252+
Returns = {wasm::ValType::V128};
253+
break;
254+
case WebAssembly::BlockType::Multivalue: {
255+
llvm_unreachable("Invalid return type");
256+
}
257+
}
258+
SmallVector<wasm::ValType, 4> Params;
259+
260+
for (int I = NumWords - 2; I >= 0; I--) {
261+
auto Val = CImm.extractBitsAsZExtValue(64, 64 * I);
262+
auto ParamType = static_cast<wasm::ValType>(Val);
263+
Params.push_back(ParamType);
264+
}
265+
MCOp = lowerTypeIndexOperand(std::move(Returns), std::move(Params));
266+
break;
267+
}
199268
case MachineOperand::MO_Immediate: {
200269
unsigned DescIndex = I - NumVariadicDefs;
201270
if (DescIndex < Desc.NumOperands) {
202271
const MCOperandInfo &Info = Desc.operands()[DescIndex];
272+
// Replace type index placeholder with actual type index. The type index
273+
// placeholders are Immediates and have an operand type of
274+
// OPERAND_TYPEINDEX or OPERAND_SIGNATURE.
203275
if (Info.OperandType == WebAssembly::OPERAND_TYPEINDEX) {
276+
// Lower type index placeholder for a CALL_INDIRECT instruction
204277
SmallVector<wasm::ValType, 4> Returns;
205278
SmallVector<wasm::ValType, 4> Params;
206279

@@ -228,6 +301,7 @@ void WebAssemblyMCInstLower::lower(const MachineInstr *MI,
228301
break;
229302
}
230303
if (Info.OperandType == WebAssembly::OPERAND_SIGNATURE) {
304+
// Lower type index placeholder for blocks
231305
auto BT = static_cast<WebAssembly::BlockType>(MO.getImm());
232306
assert(BT != WebAssembly::BlockType::Invalid);
233307
if (BT == WebAssembly::BlockType::Multivalue) {

0 commit comments

Comments
 (0)