diff --git a/llvm/include/llvm/IR/IntrinsicsWebAssembly.td b/llvm/include/llvm/IR/IntrinsicsWebAssembly.td index f592ff287a0e3..fb61d8a11e5c0 100644 --- a/llvm/include/llvm/IR/IntrinsicsWebAssembly.td +++ b/llvm/include/llvm/IR/IntrinsicsWebAssembly.td @@ -43,6 +43,10 @@ def int_wasm_ref_is_null_exn : DefaultAttrsIntrinsic<[llvm_i32_ty], [llvm_exnref_ty], [IntrNoMem], "llvm.wasm.ref.is_null.exn">; +def int_wasm_ref_test_func + : DefaultAttrsIntrinsic<[llvm_i32_ty], [llvm_ptr_ty, llvm_vararg_ty], + [IntrNoMem], "llvm.wasm.ref.test.func">; + //===----------------------------------------------------------------------===// // Table intrinsics //===----------------------------------------------------------------------===// diff --git a/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp b/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp index bf2e04caa0a61..4e71e12653473 100644 --- a/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp +++ b/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp @@ -18,6 +18,7 @@ #include "WebAssemblySubtarget.h" #include "WebAssemblyTargetMachine.h" #include "WebAssemblyUtilities.h" +#include "llvm/BinaryFormat/Wasm.h" #include "llvm/CodeGen/CallingConvLower.h" #include "llvm/CodeGen/MachineFrameInfo.h" #include "llvm/CodeGen/MachineInstrBuilder.h" @@ -501,6 +502,51 @@ MVT WebAssemblyTargetLowering::getScalarShiftAmountTy(const DataLayout & /*DL*/, return Result; } +static MachineBasicBlock *LowerRefTestFuncRef(MachineInstr &MI, DebugLoc DL, + MachineBasicBlock *BB, + const TargetInstrInfo &TII) { + // Lower a REF_TEST_FUNCREF_PSEUDO instruction into a REF_TEST_FUNCREF + // instruction by combining the signature info Imm operands that + // SelectionDag/InstrEmitter.cpp makes into one CImm operand. Put this into + // the type index placeholder for REF_TEST_FUNCREF + Register ResultReg = MI.getOperand(0).getReg(); + Register FuncRefReg = MI.getOperand(1).getReg(); + + auto NParams = MI.getNumOperands() - 3; + auto Sig = APInt(NParams * 64, 0); + + { + uint64_t V = MI.getOperand(2).getImm(); + Sig |= int64_t(V); + } + + for (unsigned I = 3; I < MI.getNumOperands(); I++) { + const MachineOperand &MO = MI.getOperand(I); + if (!MO.isImm()) { + // I'm not really sure what these are or where they come from but it seems + // to be okay to ignore them + continue; + } + uint16_t V = MO.getImm(); + Sig <<= 64; + Sig |= int64_t(V); + } + + ConstantInt *TypeInfo = + ConstantInt::get(BB->getParent()->getFunction().getContext(), Sig); + + // Put the type info first in the placeholder for the type index, then the + // actual funcref arg + BuildMI(*BB, MI, DL, TII.get(WebAssembly::REF_TEST_FUNCREF), ResultReg) + .addCImm(TypeInfo) + .addReg(FuncRefReg); + + // Remove the original instruction + MI.eraseFromParent(); + + return BB; +} + // Lower an fp-to-int conversion operator from the LLVM opcode, which has an // undefined result on invalid/overflow, to the WebAssembly opcode, which // traps on invalid/overflow. @@ -794,6 +840,7 @@ LowerCallResults(MachineInstr &CallResults, DebugLoc DL, MachineBasicBlock *BB, if (IsIndirect) { // Placeholder for the type index. + // This gets replaced with the correct value in WebAssemblyMCInstLower.cpp MIB.addImm(0); // The table into which this call_indirect indexes. MCSymbolWasm *Table = IsFuncrefCall @@ -862,6 +909,8 @@ MachineBasicBlock *WebAssemblyTargetLowering::EmitInstrWithCustomInserter( switch (MI.getOpcode()) { default: llvm_unreachable("Unexpected instr type to insert"); + case WebAssembly::REF_TEST_FUNCREF_PSEUDO: + return LowerRefTestFuncRef(MI, DL, BB, TII); case WebAssembly::FP_TO_SINT_I32_F32: return LowerFPToInt(MI, DL, BB, TII, false, false, false, WebAssembly::I32_TRUNC_S_F32); @@ -2253,6 +2302,72 @@ SDValue WebAssemblyTargetLowering::LowerIntrinsic(SDValue Op, DAG.getTargetExternalSymbol(TlsBase, PtrVT)), 0); } + case Intrinsic::wasm_ref_test_func: { + // First emit the TABLE_GET instruction to convert function pointer ==> + // funcref + MachineFunction &MF = DAG.getMachineFunction(); + auto PtrVT = getPointerTy(MF.getDataLayout()); + MCSymbol *Table = + WebAssembly::getOrCreateFunctionTableSymbol(MF.getContext(), Subtarget); + SDValue TableSym = DAG.getMCSymbol(Table, PtrVT); + SDValue FuncRef = + SDValue(DAG.getMachineNode(WebAssembly::TABLE_GET_FUNCREF, DL, + MVT::funcref, TableSym, Op.getOperand(1)), + 0); + + SmallVector Ops; + Ops.push_back(FuncRef); + + // We want to encode the type information into an APInt which we'll put + // in a CImm. However, in SelectionDag/InstrEmitter.cpp there is no code + // path that emits a CImm. So we need a custom inserter to put it in. + + // We'll put each type argument in a separate TargetConstant which gets + // lowered to a MachineInstruction Imm. We combine these into a CImm in our + // custom inserter because it creates a problem downstream to have all these + // extra immediates. + { + SDValue Operand = Op.getOperand(2); + MVT VT = Operand.getValueType().getSimpleVT(); + WebAssembly::BlockType V; + if (VT == MVT::Untyped) { + V = WebAssembly::BlockType::Void; + } else if (VT == MVT::i32) { + V = WebAssembly::BlockType::I32; + } else if (VT == MVT::i64) { + V = WebAssembly::BlockType::I64; + } else if (VT == MVT::f32) { + V = WebAssembly::BlockType::F32; + } else if (VT == MVT::f64) { + V = WebAssembly::BlockType::F64; + } else { + llvm_unreachable("Unhandled type!"); + } + Ops.push_back(DAG.getTargetConstant((int64_t)V, DL, MVT::i64)); + } + + for (unsigned i = 3; i < Op.getNumOperands(); ++i) { + SDValue Operand = Op.getOperand(i); + MVT VT = Operand.getValueType().getSimpleVT(); + wasm::ValType V; + if (VT == MVT::i32) { + V = wasm::ValType::I32; + } else if (VT == MVT::i64) { + V = wasm::ValType::I64; + } else if (VT == MVT::f32) { + V = wasm::ValType::F32; + } else if (VT == MVT::f64) { + V = wasm::ValType::F64; + } else { + llvm_unreachable("Unhandled type!"); + } + Ops.push_back(DAG.getTargetConstant((int64_t)V, DL, MVT::i64)); + } + + return SDValue(DAG.getMachineNode(WebAssembly::REF_TEST_FUNCREF_PSEUDO, DL, + MVT::i32, Ops), + 0); + } } } diff --git a/llvm/lib/Target/WebAssembly/WebAssemblyInstrRef.td b/llvm/lib/Target/WebAssembly/WebAssemblyInstrRef.td index 40b87a084c687..0c61f5770e748 100644 --- a/llvm/lib/Target/WebAssembly/WebAssemblyInstrRef.td +++ b/llvm/lib/Target/WebAssembly/WebAssemblyInstrRef.td @@ -36,6 +36,11 @@ multiclass REF_I { Requires<[HasReferenceTypes]>; } +let usesCustomInserter = 1, isPseudo = 1 in defm REF_TEST_FUNCREF_PSEUDO + : I<(outs I32:$res), (ins TypeIndex:$type, FUNCREF:$ref, variable_ops), + (outs), (ins TypeIndex:$type), [], "ref.test.pseudo\t$type, $ref", + "ref.test.pseudo $type", -1>; + defm REF_TEST_FUNCREF : I<(outs I32: $res), (ins TypeIndex:$type, FUNCREF: $ref), diff --git a/llvm/lib/Target/WebAssembly/WebAssemblyMCInstLower.cpp b/llvm/lib/Target/WebAssembly/WebAssemblyMCInstLower.cpp index cc36244e63ff5..f725ec344d922 100644 --- a/llvm/lib/Target/WebAssembly/WebAssemblyMCInstLower.cpp +++ b/llvm/lib/Target/WebAssembly/WebAssemblyMCInstLower.cpp @@ -15,13 +15,17 @@ #include "WebAssemblyMCInstLower.h" #include "MCTargetDesc/WebAssemblyMCAsmInfo.h" #include "MCTargetDesc/WebAssemblyMCTargetDesc.h" +#include "MCTargetDesc/WebAssemblyMCTypeUtilities.h" #include "TargetInfo/WebAssemblyTargetInfo.h" #include "Utils/WebAssemblyTypeUtilities.h" #include "WebAssemblyAsmPrinter.h" #include "WebAssemblyMachineFunctionInfo.h" #include "WebAssemblyUtilities.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/BinaryFormat/Wasm.h" #include "llvm/CodeGen/AsmPrinter.h" #include "llvm/CodeGen/MachineFunction.h" +#include "llvm/CodeGen/MachineOperand.h" #include "llvm/IR/Constants.h" #include "llvm/MC/MCAsmInfo.h" #include "llvm/MC/MCContext.h" @@ -196,11 +200,80 @@ void WebAssemblyMCInstLower::lower(const MachineInstr *MI, MCOp = MCOperand::createReg(WAReg); break; } + case llvm::MachineOperand::MO_CImmediate: { + // Lower type index placeholder for ref.test + // Currently this is the only way that CImmediates show up so panic if we + // get confused. + unsigned DescIndex = I - NumVariadicDefs; + if (DescIndex >= Desc.NumOperands) { + llvm_unreachable("unexpected CImmediate operand"); + } + const MCOperandInfo &Info = Desc.operands()[DescIndex]; + if (Info.OperandType != WebAssembly::OPERAND_TYPEINDEX) { + llvm_unreachable("unexpected CImmediate operand"); + } + auto CImm = MO.getCImm()->getValue(); + auto NumWords = CImm.getNumWords(); + // Extract the type data we packed into the CImm in LowerRefTestFuncRef. + // We need to load the words from most significant to least significant + // order because of the way we bitshifted them in from the right. + // The return type needs special handling because it could be void. + auto ReturnType = static_cast( + CImm.extractBitsAsZExtValue(64, (NumWords - 1) * 64)); + SmallVector Returns; + switch (ReturnType) { + case WebAssembly::BlockType::Invalid: + llvm_unreachable("Invalid return type"); + case WebAssembly::BlockType::I32: + Returns = {wasm::ValType::I32}; + break; + case WebAssembly::BlockType::I64: + Returns = {wasm::ValType::I64}; + break; + case WebAssembly::BlockType::F32: + Returns = {wasm::ValType::F32}; + break; + case WebAssembly::BlockType::F64: + Returns = {wasm::ValType::F64}; + break; + case WebAssembly::BlockType::Void: + Returns = {}; + break; + case WebAssembly::BlockType::Exnref: + Returns = {wasm::ValType::EXNREF}; + break; + case WebAssembly::BlockType::Externref: + Returns = {wasm::ValType::EXTERNREF}; + break; + case WebAssembly::BlockType::Funcref: + Returns = {wasm::ValType::FUNCREF}; + break; + case WebAssembly::BlockType::V128: + Returns = {wasm::ValType::V128}; + break; + case WebAssembly::BlockType::Multivalue: { + llvm_unreachable("Invalid return type"); + } + } + SmallVector Params; + + for (int I = NumWords - 2; I >= 0; I--) { + auto Val = CImm.extractBitsAsZExtValue(64, 64 * I); + auto ParamType = static_cast(Val); + Params.push_back(ParamType); + } + MCOp = lowerTypeIndexOperand(std::move(Returns), std::move(Params)); + break; + } case MachineOperand::MO_Immediate: { unsigned DescIndex = I - NumVariadicDefs; if (DescIndex < Desc.NumOperands) { const MCOperandInfo &Info = Desc.operands()[DescIndex]; + // Replace type index placeholder with actual type index. The type index + // placeholders are Immediates and have an operand type of + // OPERAND_TYPEINDEX or OPERAND_SIGNATURE. if (Info.OperandType == WebAssembly::OPERAND_TYPEINDEX) { + // Lower type index placeholder for a CALL_INDIRECT instruction SmallVector Returns; SmallVector Params; @@ -228,6 +301,7 @@ void WebAssemblyMCInstLower::lower(const MachineInstr *MI, break; } if (Info.OperandType == WebAssembly::OPERAND_SIGNATURE) { + // Lower type index placeholder for blocks auto BT = static_cast(MO.getImm()); assert(BT != WebAssembly::BlockType::Invalid); if (BT == WebAssembly::BlockType::Multivalue) { diff --git a/llvm/test/CodeGen/WebAssembly/ref-test-func.ll b/llvm/test/CodeGen/WebAssembly/ref-test-func.ll new file mode 100644 index 0000000000000..3fc848cd167f9 --- /dev/null +++ b/llvm/test/CodeGen/WebAssembly/ref-test-func.ll @@ -0,0 +1,42 @@ +; RUN: llc < %s -mcpu=mvp -mattr=+reference-types | FileCheck %s + +target triple = "wasm32-unknown-unknown" + +; CHECK-LABEL: test_function_pointer_signature_void: +; CHECK-NEXT: .functype test_function_pointer_signature_void (i32) -> () +; CHECK-NEXT: .local funcref +; CHECK: local.get 0 +; CHECK-NEXT: table.get __indirect_function_table +; CHECK-NEXT: local.tee 1 +; CHECK-NEXT: ref.test (f32, f64, i32) -> (f32) +; CHECK-NEXT: call use +; CHECK-NEXT: local.get 1 +; CHECK-NEXT: ref.test (f32, f64, i32) -> (i32) +; CHECK-NEXT: call use +; CHECK-NEXT: local.get 1 +; CHECK-NEXT: ref.test (i32, i32, i32) -> (i32) +; CHECK-NEXT: call use +; CHECK-NEXT: local.get 1 +; CHECK-NEXT: ref.test (i32, i32, i32) -> () +; CHECK-NEXT: call use +; CHECK-NEXT: local.get 1 +; CHECK-NEXT: ref.test () -> () +; CHECK-NEXT: call use + +; Function Attrs: nounwind +define void @test_function_pointer_signature_void(ptr noundef %func) local_unnamed_addr #0 { +entry: + %0 = tail call i32 (ptr, ...) @llvm.wasm.ref.test.func(ptr %func, float 0.000000e+00, float 0.000000e+00, double 0.000000e+00, i32 0) + tail call void @use(i32 noundef %0) #3 + %1 = tail call i32 (ptr, ...) @llvm.wasm.ref.test.func(ptr %func, i32 0, float 0.000000e+00, double 0.000000e+00, i32 0) + tail call void @use(i32 noundef %1) #3 + %2 = tail call i32 (ptr, ...) @llvm.wasm.ref.test.func(ptr %func, i32 0, i32 0, i32 0, i32 0) + tail call void @use(i32 noundef %2) #3 + %3 = tail call i32 (ptr, ...) @llvm.wasm.ref.test.func(ptr %func, token poison, i32 0, i32 0, i32 0) + tail call void @use(i32 noundef %3) #3 + %4 = tail call i32 (ptr, ...) @llvm.wasm.ref.test.func(ptr %func, token poison) + tail call void @use(i32 noundef %4) #3 + ret void +} + +declare void @use(i32 noundef) local_unnamed_addr #1