Skip to content

Commit 2009bb9

Browse files
committed
Add __builtin_wasm_test_function_pointer_signature
This uses ref.test to check whether the function pointer's runtime type matches its static type. If so, then calling it won't trap with "indirect call signature mismatch". This would be very useful here: https://github.com/python/cpython/blob/main/Python/emscripten_trampoline.c and would allow us to fix function pointer mismatches on the WASI target and the Emscripten target in a uniform way.
1 parent 830c0b7 commit 2009bb9

File tree

9 files changed

+318
-0
lines changed

9 files changed

+318
-0
lines changed

clang/include/clang/Basic/BuiltinsWebAssembly.def

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -199,6 +199,12 @@ TARGET_BUILTIN(__builtin_wasm_ref_is_null_extern, "ii", "nct", "reference-types"
199199
// return type.
200200
TARGET_BUILTIN(__builtin_wasm_ref_null_func, "i", "nct", "reference-types")
201201

202+
// Check if the static type of a function pointer matches its static type. Used
203+
// to avoid "function signature mismatch" traps. Takes a function pointer, uses
204+
// table.get to look up the pointer in __indirect_function_table and then
205+
// ref.test to test the type.
206+
TARGET_BUILTIN(__builtin_wasm_test_function_pointer_signature, "i.", "nct", "reference-types")
207+
202208
// Table builtins
203209
TARGET_BUILTIN(__builtin_wasm_table_set, "viii", "t", "reference-types")
204210
TARGET_BUILTIN(__builtin_wasm_table_get, "iii", "t", "reference-types")

clang/include/clang/Basic/DiagnosticSemaKinds.td

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7566,6 +7566,8 @@ def err_typecheck_illegal_increment_decrement : Error<
75667566
"cannot %select{decrement|increment}1 value of type %0">;
75677567
def err_typecheck_expect_int : Error<
75687568
"used type %0 where integer is required">;
7569+
def err_typecheck_expect_function_pointer
7570+
: Error<"used type %0 where function pointer is required">;
75697571
def err_typecheck_expect_hlsl_resource : Error<
75707572
"used type %0 where __hlsl_resource_t is required">;
75717573
def err_typecheck_arithmetic_incomplete_or_sizeless_type : Error<
@@ -13164,6 +13166,10 @@ def err_wasm_builtin_arg_must_match_table_element_type : Error <
1316413166
"%ordinal0 argument must match the element type of the WebAssembly table in the %ordinal1 argument">;
1316513167
def err_wasm_builtin_arg_must_be_integer_type : Error <
1316613168
"%ordinal0 argument must be an integer">;
13169+
def err_wasm_builtin_test_fp_sig_cannot_include_reference_type
13170+
: Error<"__builtin_wasm_test_function_pointer_signature not supported for "
13171+
"function pointers with reference types in their "
13172+
"%select{return|parameter}0 type">;
1316713173

1316813174
// OpenACC diagnostics.
1316913175
def warn_acc_routine_unimplemented

clang/include/clang/Sema/SemaWasm.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ class SemaWasm : public SemaBase {
3737
bool BuiltinWasmTableGrow(CallExpr *TheCall);
3838
bool BuiltinWasmTableFill(CallExpr *TheCall);
3939
bool BuiltinWasmTableCopy(CallExpr *TheCall);
40+
bool BuiltinWasmTestFunctionPointerSignature(CallExpr *TheCall);
4041

4142
WebAssemblyImportNameAttr *
4243
mergeImportNameAttr(Decl *D, const WebAssemblyImportNameAttr &AL);

clang/lib/CodeGen/TargetBuiltins/WebAssembly.cpp

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,10 @@
1212

1313
#include "CGBuiltin.h"
1414
#include "clang/Basic/TargetBuiltins.h"
15+
#include "llvm/ADT/APInt.h"
16+
#include "llvm/IR/Constants.h"
1517
#include "llvm/IR/IntrinsicsWebAssembly.h"
18+
#include "llvm/Support/ErrorHandling.h"
1619

1720
using namespace clang;
1821
using namespace CodeGen;
@@ -218,6 +221,61 @@ Value *CodeGenFunction::EmitWebAssemblyBuiltinExpr(unsigned BuiltinID,
218221
Function *Callee = CGM.getIntrinsic(Intrinsic::wasm_ref_null_func);
219222
return Builder.CreateCall(Callee);
220223
}
224+
case WebAssembly::BI__builtin_wasm_test_function_pointer_signature: {
225+
Value *FuncRef = EmitScalarExpr(E->getArg(0));
226+
227+
// Get the function type from the argument's static type
228+
QualType ArgType = E->getArg(0)->getType();
229+
const PointerType *PtrTy = ArgType->getAs<PointerType>();
230+
assert(PtrTy && "Sema should have ensured this is a function pointer");
231+
232+
const FunctionType *FuncTy = PtrTy->getPointeeType()->getAs<FunctionType>();
233+
assert(FuncTy && "Sema should have ensured this is a function pointer");
234+
235+
// In the llvm IR, we won't have access anymore to the type of the function
236+
// pointer so we need to insert this type information somehow. We gave the
237+
// @llvm.wasm.ref.test.func varargs and here we add an extra 0 argument of
238+
// the type corresponding to the type of each argument of the function
239+
// signature. When we lower from the IR we'll use the types of these
240+
// arguments to determine the signature we want to test for.
241+
242+
// Make a type index constant with 0. This gets replaced by the actual type
243+
// in WebAssemblyMCInstLower.cpp.
244+
llvm::FunctionType *LLVMFuncTy =
245+
cast<llvm::FunctionType>(ConvertType(QualType(FuncTy, 0)));
246+
247+
uint NParams = LLVMFuncTy->getNumParams();
248+
std::vector<Value *> Args;
249+
Args.reserve(NParams + 1);
250+
// The only real argument is the FuncRef
251+
Args.push_back(FuncRef);
252+
253+
// Add the type information
254+
auto addType = [this, &Args](llvm::Type *T) {
255+
if (T->isVoidTy()) {
256+
// Use TokenTy as dummy for void b/c the verifier rejects a
257+
// void arg with 'Instruction operands must be first-class values!'
258+
// TokenTy isn't a first class value either but apparently the verifier
259+
// doesn't mind it.
260+
Args.push_back(
261+
UndefValue::get(llvm::Type::getTokenTy(getLLVMContext())));
262+
} else if (T->isFloatingPointTy()) {
263+
Args.push_back(ConstantFP::get(T, 0));
264+
} else if (T->isIntegerTy()) {
265+
Args.push_back(ConstantInt::get(T, 0));
266+
} else {
267+
// TODO: Handle reference types here. For now, we reject them in Sema.
268+
llvm_unreachable("Unhandled type");
269+
}
270+
};
271+
272+
addType(LLVMFuncTy->getReturnType());
273+
for (uint i = 0; i < NParams; i++) {
274+
addType(LLVMFuncTy->getParamType(i));
275+
}
276+
Function *Callee = CGM.getIntrinsic(Intrinsic::wasm_ref_test_func);
277+
return Builder.CreateCall(Callee, Args);
278+
}
221279
case WebAssembly::BI__builtin_wasm_swizzle_i8x16: {
222280
Value *Src = EmitScalarExpr(E->getArg(0));
223281
Value *Indices = EmitScalarExpr(E->getArg(1));

clang/lib/Sema/SemaWasm.cpp

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -227,6 +227,53 @@ bool SemaWasm::BuiltinWasmTableCopy(CallExpr *TheCall) {
227227
return false;
228228
}
229229

230+
bool SemaWasm::BuiltinWasmTestFunctionPointerSignature(CallExpr *TheCall) {
231+
if (SemaRef.checkArgCount(TheCall, 1))
232+
return true;
233+
234+
Expr *FuncPtrArg = TheCall->getArg(0);
235+
QualType ArgType = FuncPtrArg->getType();
236+
237+
// Check that the argument is a function pointer
238+
const PointerType *PtrTy = ArgType->getAs<PointerType>();
239+
if (!PtrTy) {
240+
return Diag(FuncPtrArg->getBeginLoc(),
241+
diag::err_typecheck_expect_function_pointer)
242+
<< ArgType << FuncPtrArg->getSourceRange();
243+
}
244+
245+
const FunctionProtoType *FuncTy =
246+
PtrTy->getPointeeType()->getAs<FunctionProtoType>();
247+
if (!FuncTy) {
248+
return Diag(FuncPtrArg->getBeginLoc(),
249+
diag::err_typecheck_expect_function_pointer)
250+
<< ArgType << FuncPtrArg->getSourceRange();
251+
}
252+
253+
// Check that the function pointer doesn't use reference types
254+
if (FuncTy->getReturnType().isWebAssemblyReferenceType()) {
255+
return Diag(
256+
FuncPtrArg->getBeginLoc(),
257+
diag::err_wasm_builtin_test_fp_sig_cannot_include_reference_type)
258+
<< 0 << FuncTy->getReturnType() << FuncPtrArg->getSourceRange();
259+
}
260+
auto NParams = FuncTy->getNumParams();
261+
for (unsigned I = 0; I < NParams; I++) {
262+
if (FuncTy->getParamType(I).isWebAssemblyReferenceType()) {
263+
return Diag(
264+
FuncPtrArg->getBeginLoc(),
265+
diag::
266+
err_wasm_builtin_test_fp_sig_cannot_include_reference_type)
267+
<< 1 << FuncPtrArg->getSourceRange();
268+
}
269+
}
270+
271+
// Set return type to int (the result of the test)
272+
TheCall->setType(getASTContext().IntTy);
273+
274+
return false;
275+
}
276+
230277
bool SemaWasm::CheckWebAssemblyBuiltinFunctionCall(const TargetInfo &TI,
231278
unsigned BuiltinID,
232279
CallExpr *TheCall) {
@@ -249,6 +296,8 @@ bool SemaWasm::CheckWebAssemblyBuiltinFunctionCall(const TargetInfo &TI,
249296
return BuiltinWasmTableFill(TheCall);
250297
case WebAssembly::BI__builtin_wasm_table_copy:
251298
return BuiltinWasmTableCopy(TheCall);
299+
case WebAssembly::BI__builtin_wasm_test_function_pointer_signature:
300+
return BuiltinWasmTestFunctionPointerSignature(TheCall);
252301
}
253302

254303
return false;

llvm/include/llvm/IR/IntrinsicsWebAssembly.td

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,10 @@ def int_wasm_ref_is_null_exn :
4343
DefaultAttrsIntrinsic<[llvm_i32_ty], [llvm_exnref_ty], [IntrNoMem],
4444
"llvm.wasm.ref.is_null.exn">;
4545

46+
def int_wasm_ref_test_func
47+
: DefaultAttrsIntrinsic<[llvm_i32_ty], [llvm_ptr_ty, llvm_vararg_ty],
48+
[IntrNoMem], "llvm.wasm.ref.test.func">;
49+
4650
//===----------------------------------------------------------------------===//
4751
// Table intrinsics
4852
//===----------------------------------------------------------------------===//

llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp

Lines changed: 114 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
#include "WebAssemblySubtarget.h"
1919
#include "WebAssemblyTargetMachine.h"
2020
#include "WebAssemblyUtilities.h"
21+
#include "llvm/BinaryFormat/Wasm.h"
2122
#include "llvm/CodeGen/CallingConvLower.h"
2223
#include "llvm/CodeGen/MachineFrameInfo.h"
2324
#include "llvm/CodeGen/MachineInstrBuilder.h"
@@ -501,6 +502,51 @@ MVT WebAssemblyTargetLowering::getScalarShiftAmountTy(const DataLayout & /*DL*/,
501502
return Result;
502503
}
503504

505+
static MachineBasicBlock *LowerRefTestFuncRef(MachineInstr &MI, DebugLoc DL,
506+
MachineBasicBlock *BB,
507+
const TargetInstrInfo &TII) {
508+
// Lower a REF_TEST_FUNCREF_PSEUDO instruction into a REF_TEST_FUNCREF
509+
// instruction by combining the signature info Imm operands that
510+
// SelectionDag/InstrEmitter.cpp makes into one CImm operand. Put this into
511+
// the type index placeholder for REF_TEST_FUNCREF
512+
Register ResultReg = MI.getOperand(0).getReg();
513+
Register FuncRefReg = MI.getOperand(1).getReg();
514+
515+
auto NParams = MI.getNumOperands() - 3;
516+
auto Sig = APInt(NParams * 64, 0);
517+
518+
{
519+
uint64_t V = MI.getOperand(2).getImm();
520+
Sig |= int64_t(V);
521+
}
522+
523+
for (unsigned I = 3; I < MI.getNumOperands(); I++) {
524+
const MachineOperand &MO = MI.getOperand(I);
525+
if (!MO.isImm()) {
526+
// I'm not really sure what these are or where they come from but it seems
527+
// to be okay to ignore them
528+
continue;
529+
}
530+
uint16_t V = MO.getImm();
531+
Sig <<= 64;
532+
Sig |= int64_t(V);
533+
}
534+
535+
ConstantInt *TypeInfo =
536+
ConstantInt::get(BB->getParent()->getFunction().getContext(), Sig);
537+
538+
// Put the type info first in the placeholder for the type index, then the
539+
// actual funcref arg
540+
BuildMI(*BB, MI, DL, TII.get(WebAssembly::REF_TEST_FUNCREF), ResultReg)
541+
.addCImm(TypeInfo)
542+
.addReg(FuncRefReg);
543+
544+
// Remove the original instruction
545+
MI.eraseFromParent();
546+
547+
return BB;
548+
}
549+
504550
// Lower an fp-to-int conversion operator from the LLVM opcode, which has an
505551
// undefined result on invalid/overflow, to the WebAssembly opcode, which
506552
// traps on invalid/overflow.
@@ -862,6 +908,8 @@ MachineBasicBlock *WebAssemblyTargetLowering::EmitInstrWithCustomInserter(
862908
switch (MI.getOpcode()) {
863909
default:
864910
llvm_unreachable("Unexpected instr type to insert");
911+
case WebAssembly::REF_TEST_FUNCREF_PSEUDO:
912+
return LowerRefTestFuncRef(MI, DL, BB, TII);
865913
case WebAssembly::FP_TO_SINT_I32_F32:
866914
return LowerFPToInt(MI, DL, BB, TII, false, false, false,
867915
WebAssembly::I32_TRUNC_S_F32);
@@ -2253,6 +2301,72 @@ SDValue WebAssemblyTargetLowering::LowerIntrinsic(SDValue Op,
22532301
DAG.getTargetExternalSymbol(TlsBase, PtrVT)),
22542302
0);
22552303
}
2304+
case Intrinsic::wasm_ref_test_func: {
2305+
// First emit the TABLE_GET instruction to convert function pointer ==>
2306+
// funcref
2307+
MachineFunction &MF = DAG.getMachineFunction();
2308+
auto PtrVT = getPointerTy(MF.getDataLayout());
2309+
MCSymbol *Table =
2310+
WebAssembly::getOrCreateFunctionTableSymbol(MF.getContext(), Subtarget);
2311+
SDValue TableSym = DAG.getMCSymbol(Table, PtrVT);
2312+
SDValue FuncRef =
2313+
SDValue(DAG.getMachineNode(WebAssembly::TABLE_GET_FUNCREF, DL,
2314+
MVT::funcref, TableSym, Op.getOperand(1)),
2315+
0);
2316+
2317+
SmallVector<SDValue, 4> Ops;
2318+
Ops.push_back(FuncRef);
2319+
2320+
// We want to encode the type information into an APInt which we'll put
2321+
// in a CImm. However, in SelectionDag/InstrEmitter.cpp there is no code
2322+
// path that emits a CImm. So we need a custom inserter to put it in.
2323+
2324+
// We'll put each type argument in a separate TargetConstant which gets
2325+
// lowered to a MachineInstruction Imm. We combine these into a CImm in our
2326+
// custom inserter because it creates a problem downstream to have all these
2327+
// extra immediates.
2328+
{
2329+
SDValue Operand = Op.getOperand(2);
2330+
MVT VT = Operand.getValueType().getSimpleVT();
2331+
WebAssembly::BlockType V;
2332+
if (VT == MVT::Untyped) {
2333+
V = WebAssembly::BlockType::Void;
2334+
} else if (VT == MVT::i32) {
2335+
V = WebAssembly::BlockType::I32;
2336+
} else if (VT == MVT::i64) {
2337+
V = WebAssembly::BlockType::I64;
2338+
} else if (VT == MVT::f32) {
2339+
V = WebAssembly::BlockType::F32;
2340+
} else if (VT == MVT::f64) {
2341+
V = WebAssembly::BlockType::F64;
2342+
} else {
2343+
llvm_unreachable("Unhandled type!");
2344+
}
2345+
Ops.push_back(DAG.getTargetConstant((int64_t)V, DL, MVT::i64));
2346+
}
2347+
2348+
for (unsigned i = 3; i < Op.getNumOperands(); ++i) {
2349+
SDValue Operand = Op.getOperand(i);
2350+
MVT VT = Operand.getValueType().getSimpleVT();
2351+
wasm::ValType V;
2352+
if (VT == MVT::i32) {
2353+
V = wasm::ValType::I32;
2354+
} else if (VT == MVT::i64) {
2355+
V = wasm::ValType::I64;
2356+
} else if (VT == MVT::f32) {
2357+
V = wasm::ValType::F32;
2358+
} else if (VT == MVT::f64) {
2359+
V = wasm::ValType::F64;
2360+
} else {
2361+
llvm_unreachable("Unhandled type!");
2362+
}
2363+
Ops.push_back(DAG.getTargetConstant((int64_t)V, DL, MVT::i64));
2364+
}
2365+
2366+
return SDValue(DAG.getMachineNode(WebAssembly::REF_TEST_FUNCREF_PSEUDO, DL,
2367+
MVT::i32, Ops),
2368+
0);
2369+
}
22562370
}
22572371
}
22582372

llvm/lib/Target/WebAssembly/WebAssemblyInstrRef.td

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,11 @@ multiclass REF_I<WebAssemblyRegClass rc, ValueType vt, string ht> {
3636
Requires<[HasReferenceTypes]>;
3737
}
3838

39+
let usesCustomInserter = 1, isPseudo = 1 in defm REF_TEST_FUNCREF_PSEUDO
40+
: I<(outs I32:$res), (ins TypeIndex:$type, FUNCREF:$ref, variable_ops),
41+
(outs), (ins TypeIndex:$type), [], "ref.test.pseudo\t$type, $ref",
42+
"ref.test.pseudo $type", -1>;
43+
3944
defm REF_TEST_FUNCREF :
4045
I<(outs I32: $res),
4146
(ins TypeIndex:$type, FUNCREF: $ref),

0 commit comments

Comments
 (0)