Skip to content

Commit 44ac5de

Browse files
committed
Add __builtin_wasm_test_function_pointer_signature
This uses ref.test to check whether the function pointer's runtime type matches its static type. If so, then calling it won't trap with "indirect call signature mismatch". This would be very useful here: https://github.com/python/cpython/blob/main/Python/emscripten_trampoline.c and would allow us to fix function pointer mismatches on the WASI target and the Emscripten target in a uniform way.
1 parent 6a4693a commit 44ac5de

File tree

9 files changed

+318
-0
lines changed

9 files changed

+318
-0
lines changed

clang/include/clang/Basic/BuiltinsWebAssembly.def

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -198,6 +198,12 @@ TARGET_BUILTIN(__builtin_wasm_ref_null_extern, "i", "nct", "reference-types")
198198
// return type.
199199
TARGET_BUILTIN(__builtin_wasm_ref_null_func, "i", "nct", "reference-types")
200200

201+
// Check if the static type of a function pointer matches its static type. Used
202+
// to avoid "function signature mismatch" traps. Takes a function pointer, uses
203+
// table.get to look up the pointer in __indirect_function_table and then
204+
// ref.test to test the type.
205+
TARGET_BUILTIN(__builtin_wasm_test_function_pointer_signature, "i.", "nct", "reference-types")
206+
201207
// Table builtins
202208
TARGET_BUILTIN(__builtin_wasm_table_set, "viii", "t", "reference-types")
203209
TARGET_BUILTIN(__builtin_wasm_table_get, "iii", "t", "reference-types")

clang/include/clang/Basic/DiagnosticSemaKinds.td

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7460,6 +7460,8 @@ def err_typecheck_illegal_increment_decrement : Error<
74607460
"cannot %select{decrement|increment}1 value of type %0">;
74617461
def err_typecheck_expect_int : Error<
74627462
"used type %0 where integer is required">;
7463+
def err_typecheck_expect_function_pointer
7464+
: Error<"used type %0 where function pointer is required">;
74637465
def err_typecheck_expect_hlsl_resource : Error<
74647466
"used type %0 where __hlsl_resource_t is required">;
74657467
def err_typecheck_arithmetic_incomplete_or_sizeless_type : Error<
@@ -12995,6 +12997,10 @@ def err_wasm_builtin_arg_must_match_table_element_type : Error <
1299512997
"%ordinal0 argument must match the element type of the WebAssembly table in the %ordinal1 argument">;
1299612998
def err_wasm_builtin_arg_must_be_integer_type : Error <
1299712999
"%ordinal0 argument must be an integer">;
13000+
def err_wasm_builtin_test_fp_sig_cannot_include_reference_type
13001+
: Error<"__builtin_wasm_test_function_pointer_signature not supported for "
13002+
"function pointers with reference types in their "
13003+
"%select{return|parameter}0 type">;
1299813004

1299913005
// OpenACC diagnostics.
1300013006
def warn_acc_routine_unimplemented

clang/include/clang/Sema/SemaWasm.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@ class SemaWasm : public SemaBase {
3636
bool BuiltinWasmTableGrow(CallExpr *TheCall);
3737
bool BuiltinWasmTableFill(CallExpr *TheCall);
3838
bool BuiltinWasmTableCopy(CallExpr *TheCall);
39+
bool BuiltinWasmTestFunctionPointerSignature(CallExpr *TheCall);
3940

4041
WebAssemblyImportNameAttr *
4142
mergeImportNameAttr(Decl *D, const WebAssemblyImportNameAttr &AL);

clang/lib/CodeGen/TargetBuiltins/WebAssembly.cpp

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,10 @@
1212

1313
#include "CGBuiltin.h"
1414
#include "clang/Basic/TargetBuiltins.h"
15+
#include "llvm/ADT/APInt.h"
16+
#include "llvm/IR/Constants.h"
1517
#include "llvm/IR/IntrinsicsWebAssembly.h"
18+
#include "llvm/Support/ErrorHandling.h"
1619

1720
using namespace clang;
1821
using namespace CodeGen;
@@ -213,6 +216,61 @@ Value *CodeGenFunction::EmitWebAssemblyBuiltinExpr(unsigned BuiltinID,
213216
Function *Callee = CGM.getIntrinsic(Intrinsic::wasm_ref_null_func);
214217
return Builder.CreateCall(Callee);
215218
}
219+
case WebAssembly::BI__builtin_wasm_test_function_pointer_signature: {
220+
Value *FuncRef = EmitScalarExpr(E->getArg(0));
221+
222+
// Get the function type from the argument's static type
223+
QualType ArgType = E->getArg(0)->getType();
224+
const PointerType *PtrTy = ArgType->getAs<PointerType>();
225+
assert(PtrTy && "Sema should have ensured this is a function pointer");
226+
227+
const FunctionType *FuncTy = PtrTy->getPointeeType()->getAs<FunctionType>();
228+
assert(FuncTy && "Sema should have ensured this is a function pointer");
229+
230+
// In the llvm IR, we won't have access anymore to the type of the function
231+
// pointer so we need to insert this type information somehow. We gave the
232+
// @llvm.wasm.ref.test.func varargs and here we add an extra 0 argument of
233+
// the type corresponding to the type of each argument of the function
234+
// signature. When we lower from the IR we'll use the types of these
235+
// arguments to determine the signature we want to test for.
236+
237+
// Make a type index constant with 0. This gets replaced by the actual type
238+
// in WebAssemblyMCInstLower.cpp.
239+
llvm::FunctionType *LLVMFuncTy =
240+
cast<llvm::FunctionType>(ConvertType(QualType(FuncTy, 0)));
241+
242+
uint NParams = LLVMFuncTy->getNumParams();
243+
std::vector<Value *> Args;
244+
Args.reserve(NParams + 1);
245+
// The only real argument is the FuncRef
246+
Args.push_back(FuncRef);
247+
248+
// Add the type information
249+
auto addType = [this, &Args](llvm::Type *T) {
250+
if (T->isVoidTy()) {
251+
// Use TokenTy as dummy for void b/c the verifier rejects a
252+
// void arg with 'Instruction operands must be first-class values!'
253+
// TokenTy isn't a first class value either but apparently the verifier
254+
// doesn't mind it.
255+
Args.push_back(
256+
UndefValue::get(llvm::Type::getTokenTy(getLLVMContext())));
257+
} else if (T->isFloatingPointTy()) {
258+
Args.push_back(ConstantFP::get(T, 0));
259+
} else if (T->isIntegerTy()) {
260+
Args.push_back(ConstantInt::get(T, 0));
261+
} else {
262+
// TODO: Handle reference types here. For now, we reject them in Sema.
263+
llvm_unreachable("Unhandled type");
264+
}
265+
};
266+
267+
addType(LLVMFuncTy->getReturnType());
268+
for (uint i = 0; i < NParams; i++) {
269+
addType(LLVMFuncTy->getParamType(i));
270+
}
271+
Function *Callee = CGM.getIntrinsic(Intrinsic::wasm_ref_test_func);
272+
return Builder.CreateCall(Callee, Args);
273+
}
216274
case WebAssembly::BI__builtin_wasm_swizzle_i8x16: {
217275
Value *Src = EmitScalarExpr(E->getArg(0));
218276
Value *Indices = EmitScalarExpr(E->getArg(1));

clang/lib/Sema/SemaWasm.cpp

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -216,6 +216,53 @@ bool SemaWasm::BuiltinWasmTableCopy(CallExpr *TheCall) {
216216
return false;
217217
}
218218

219+
bool SemaWasm::BuiltinWasmTestFunctionPointerSignature(CallExpr *TheCall) {
220+
if (SemaRef.checkArgCount(TheCall, 1))
221+
return true;
222+
223+
Expr *FuncPtrArg = TheCall->getArg(0);
224+
QualType ArgType = FuncPtrArg->getType();
225+
226+
// Check that the argument is a function pointer
227+
const PointerType *PtrTy = ArgType->getAs<PointerType>();
228+
if (!PtrTy) {
229+
return Diag(FuncPtrArg->getBeginLoc(),
230+
diag::err_typecheck_expect_function_pointer)
231+
<< ArgType << FuncPtrArg->getSourceRange();
232+
}
233+
234+
const FunctionProtoType *FuncTy =
235+
PtrTy->getPointeeType()->getAs<FunctionProtoType>();
236+
if (!FuncTy) {
237+
return Diag(FuncPtrArg->getBeginLoc(),
238+
diag::err_typecheck_expect_function_pointer)
239+
<< ArgType << FuncPtrArg->getSourceRange();
240+
}
241+
242+
// Check that the function pointer doesn't use reference types
243+
if (FuncTy->getReturnType().isWebAssemblyReferenceType()) {
244+
return Diag(
245+
FuncPtrArg->getBeginLoc(),
246+
diag::err_wasm_builtin_test_fp_sig_cannot_include_reference_type)
247+
<< 0 << FuncTy->getReturnType() << FuncPtrArg->getSourceRange();
248+
}
249+
auto NParams = FuncTy->getNumParams();
250+
for (unsigned I = 0; I < NParams; I++) {
251+
if (FuncTy->getParamType(I).isWebAssemblyReferenceType()) {
252+
return Diag(
253+
FuncPtrArg->getBeginLoc(),
254+
diag::
255+
err_wasm_builtin_test_fp_sig_cannot_include_reference_type)
256+
<< 1 << FuncPtrArg->getSourceRange();
257+
}
258+
}
259+
260+
// Set return type to int (the result of the test)
261+
TheCall->setType(getASTContext().IntTy);
262+
263+
return false;
264+
}
265+
219266
bool SemaWasm::CheckWebAssemblyBuiltinFunctionCall(const TargetInfo &TI,
220267
unsigned BuiltinID,
221268
CallExpr *TheCall) {
@@ -236,6 +283,8 @@ bool SemaWasm::CheckWebAssemblyBuiltinFunctionCall(const TargetInfo &TI,
236283
return BuiltinWasmTableFill(TheCall);
237284
case WebAssembly::BI__builtin_wasm_table_copy:
238285
return BuiltinWasmTableCopy(TheCall);
286+
case WebAssembly::BI__builtin_wasm_test_function_pointer_signature:
287+
return BuiltinWasmTestFunctionPointerSignature(TheCall);
239288
}
240289

241290
return false;

llvm/include/llvm/IR/IntrinsicsWebAssembly.td

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,10 @@ def int_wasm_ref_is_null_exn :
4343
DefaultAttrsIntrinsic<[llvm_i32_ty], [llvm_exnref_ty], [IntrNoMem],
4444
"llvm.wasm.ref.is_null.exn">;
4545

46+
def int_wasm_ref_test_func
47+
: DefaultAttrsIntrinsic<[llvm_i32_ty], [llvm_ptr_ty, llvm_vararg_ty],
48+
[IntrNoMem], "llvm.wasm.ref.test.func">;
49+
4650
//===----------------------------------------------------------------------===//
4751
// Table intrinsics
4852
//===----------------------------------------------------------------------===//

llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp

Lines changed: 114 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
#include "WebAssemblySubtarget.h"
1919
#include "WebAssemblyTargetMachine.h"
2020
#include "WebAssemblyUtilities.h"
21+
#include "llvm/BinaryFormat/Wasm.h"
2122
#include "llvm/CodeGen/CallingConvLower.h"
2223
#include "llvm/CodeGen/MachineFrameInfo.h"
2324
#include "llvm/CodeGen/MachineInstrBuilder.h"
@@ -505,6 +506,51 @@ MVT WebAssemblyTargetLowering::getScalarShiftAmountTy(const DataLayout & /*DL*/,
505506
return Result;
506507
}
507508

509+
static MachineBasicBlock *LowerRefTestFuncRef(MachineInstr &MI, DebugLoc DL,
510+
MachineBasicBlock *BB,
511+
const TargetInstrInfo &TII) {
512+
// Lower a REF_TEST_FUNCREF_PSEUDO instruction into a REF_TEST_FUNCREF
513+
// instruction by combining the signature info Imm operands that
514+
// SelectionDag/InstrEmitter.cpp makes into one CImm operand. Put this into
515+
// the type index placeholder for REF_TEST_FUNCREF
516+
Register ResultReg = MI.getOperand(0).getReg();
517+
Register FuncRefReg = MI.getOperand(1).getReg();
518+
519+
auto NParams = MI.getNumOperands() - 3;
520+
auto Sig = APInt(NParams * 64, 0);
521+
522+
{
523+
uint64_t V = MI.getOperand(2).getImm();
524+
Sig |= int64_t(V);
525+
}
526+
527+
for (unsigned I = 3; I < MI.getNumOperands(); I++) {
528+
const MachineOperand &MO = MI.getOperand(I);
529+
if (!MO.isImm()) {
530+
// I'm not really sure what these are or where they come from but it seems
531+
// to be okay to ignore them
532+
continue;
533+
}
534+
uint16_t V = MO.getImm();
535+
Sig <<= 64;
536+
Sig |= int64_t(V);
537+
}
538+
539+
ConstantInt *TypeInfo =
540+
ConstantInt::get(BB->getParent()->getFunction().getContext(), Sig);
541+
542+
// Put the type info first in the placeholder for the type index, then the
543+
// actual funcref arg
544+
BuildMI(*BB, MI, DL, TII.get(WebAssembly::REF_TEST_FUNCREF), ResultReg)
545+
.addCImm(TypeInfo)
546+
.addReg(FuncRefReg);
547+
548+
// Remove the original instruction
549+
MI.eraseFromParent();
550+
551+
return BB;
552+
}
553+
508554
// Lower an fp-to-int conversion operator from the LLVM opcode, which has an
509555
// undefined result on invalid/overflow, to the WebAssembly opcode, which
510556
// traps on invalid/overflow.
@@ -866,6 +912,8 @@ MachineBasicBlock *WebAssemblyTargetLowering::EmitInstrWithCustomInserter(
866912
switch (MI.getOpcode()) {
867913
default:
868914
llvm_unreachable("Unexpected instr type to insert");
915+
case WebAssembly::REF_TEST_FUNCREF_PSEUDO:
916+
return LowerRefTestFuncRef(MI, DL, BB, TII);
869917
case WebAssembly::FP_TO_SINT_I32_F32:
870918
return LowerFPToInt(MI, DL, BB, TII, false, false, false,
871919
WebAssembly::I32_TRUNC_S_F32);
@@ -2260,6 +2308,72 @@ SDValue WebAssemblyTargetLowering::LowerIntrinsic(SDValue Op,
22602308
DAG.getTargetExternalSymbol(TlsBase, PtrVT)),
22612309
0);
22622310
}
2311+
case Intrinsic::wasm_ref_test_func: {
2312+
// First emit the TABLE_GET instruction to convert function pointer ==>
2313+
// funcref
2314+
MachineFunction &MF = DAG.getMachineFunction();
2315+
auto PtrVT = getPointerTy(MF.getDataLayout());
2316+
MCSymbol *Table =
2317+
WebAssembly::getOrCreateFunctionTableSymbol(MF.getContext(), Subtarget);
2318+
SDValue TableSym = DAG.getMCSymbol(Table, PtrVT);
2319+
SDValue FuncRef =
2320+
SDValue(DAG.getMachineNode(WebAssembly::TABLE_GET_FUNCREF, DL,
2321+
MVT::funcref, TableSym, Op.getOperand(1)),
2322+
0);
2323+
2324+
SmallVector<SDValue, 4> Ops;
2325+
Ops.push_back(FuncRef);
2326+
2327+
// We want to encode the type information into an APInt which we'll put
2328+
// in a CImm. However, in SelectionDag/InstrEmitter.cpp there is no code
2329+
// path that emits a CImm. So we need a custom inserter to put it in.
2330+
2331+
// We'll put each type argument in a separate TargetConstant which gets
2332+
// lowered to a MachineInstruction Imm. We combine these into a CImm in our
2333+
// custom inserter because it creates a problem downstream to have all these
2334+
// extra immediates.
2335+
{
2336+
SDValue Operand = Op.getOperand(2);
2337+
MVT VT = Operand.getValueType().getSimpleVT();
2338+
WebAssembly::BlockType V;
2339+
if (VT == MVT::Untyped) {
2340+
V = WebAssembly::BlockType::Void;
2341+
} else if (VT == MVT::i32) {
2342+
V = WebAssembly::BlockType::I32;
2343+
} else if (VT == MVT::i64) {
2344+
V = WebAssembly::BlockType::I64;
2345+
} else if (VT == MVT::f32) {
2346+
V = WebAssembly::BlockType::F32;
2347+
} else if (VT == MVT::f64) {
2348+
V = WebAssembly::BlockType::F64;
2349+
} else {
2350+
llvm_unreachable("Unhandled type!");
2351+
}
2352+
Ops.push_back(DAG.getTargetConstant((int64_t)V, DL, MVT::i64));
2353+
}
2354+
2355+
for (unsigned i = 3; i < Op.getNumOperands(); ++i) {
2356+
SDValue Operand = Op.getOperand(i);
2357+
MVT VT = Operand.getValueType().getSimpleVT();
2358+
wasm::ValType V;
2359+
if (VT == MVT::i32) {
2360+
V = wasm::ValType::I32;
2361+
} else if (VT == MVT::i64) {
2362+
V = wasm::ValType::I64;
2363+
} else if (VT == MVT::f32) {
2364+
V = wasm::ValType::F32;
2365+
} else if (VT == MVT::f64) {
2366+
V = wasm::ValType::F64;
2367+
} else {
2368+
llvm_unreachable("Unhandled type!");
2369+
}
2370+
Ops.push_back(DAG.getTargetConstant((int64_t)V, DL, MVT::i64));
2371+
}
2372+
2373+
return SDValue(DAG.getMachineNode(WebAssembly::REF_TEST_FUNCREF_PSEUDO, DL,
2374+
MVT::i32, Ops),
2375+
0);
2376+
}
22632377
}
22642378
}
22652379

llvm/lib/Target/WebAssembly/WebAssemblyInstrRef.td

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,11 @@ multiclass REF_I<WebAssemblyRegClass rc, ValueType vt, string ht> {
3636
Requires<[HasReferenceTypes]>;
3737
}
3838

39+
let usesCustomInserter = 1, isPseudo = 1 in defm REF_TEST_FUNCREF_PSEUDO
40+
: I<(outs I32:$res), (ins TypeIndex:$type, FUNCREF:$ref, variable_ops),
41+
(outs), (ins TypeIndex:$type), [], "ref.test.pseudo\t$type, $ref",
42+
"ref.test.pseudo $type", -1>;
43+
3944
defm REF_TEST_FUNCREF :
4045
I<(outs I32: $res),
4146
(ins TypeIndex:$type, FUNCREF: $ref),

0 commit comments

Comments
 (0)