diff --git a/llvm/include/llvm/IR/IntrinsicsWebAssembly.td b/llvm/include/llvm/IR/IntrinsicsWebAssembly.td index f592ff287a0e3..c1e4b97e96bc8 100644 --- a/llvm/include/llvm/IR/IntrinsicsWebAssembly.td +++ b/llvm/include/llvm/IR/IntrinsicsWebAssembly.td @@ -43,6 +43,10 @@ def int_wasm_ref_is_null_exn : DefaultAttrsIntrinsic<[llvm_i32_ty], [llvm_exnref_ty], [IntrNoMem], "llvm.wasm.ref.is_null.exn">; +def int_wasm_ref_test_func + : DefaultAttrsIntrinsic<[llvm_i32_ty], [llvm_ptr_ty, llvm_vararg_ty], + [IntrNoMem]>; + //===----------------------------------------------------------------------===// // Table intrinsics //===----------------------------------------------------------------------===// diff --git a/llvm/lib/CodeGen/SelectionDAG/InstrEmitter.cpp b/llvm/lib/CodeGen/SelectionDAG/InstrEmitter.cpp index 03d3e8eab35d0..4d56803ba492a 100644 --- a/llvm/lib/CodeGen/SelectionDAG/InstrEmitter.cpp +++ b/llvm/lib/CodeGen/SelectionDAG/InstrEmitter.cpp @@ -402,7 +402,12 @@ void InstrEmitter::AddOperand(MachineInstrBuilder &MIB, SDValue Op, AddRegisterOperand(MIB, Op, IIOpNum, II, VRBaseMap, IsDebug, IsClone, IsCloned); } else if (ConstantSDNode *C = dyn_cast(Op)) { - MIB.addImm(C->getSExtValue()); + if (C->getAPIntValue().getBitWidth() <= 64) { + MIB.addImm(C->getSExtValue()); + } else { + MIB.addCImm( + ConstantInt::get(MF->getFunction().getContext(), C->getAPIntValue())); + } } else if (ConstantFPSDNode *F = dyn_cast(Op)) { MIB.addFPImm(F->getConstantFPValue()); } else if (RegisterSDNode *R = dyn_cast(Op)) { diff --git a/llvm/lib/Target/WebAssembly/WebAssemblyISelDAGToDAG.cpp b/llvm/lib/Target/WebAssembly/WebAssemblyISelDAGToDAG.cpp index ac819cf5c1801..93dbc4f157188 100644 --- a/llvm/lib/Target/WebAssembly/WebAssemblyISelDAGToDAG.cpp +++ b/llvm/lib/Target/WebAssembly/WebAssemblyISelDAGToDAG.cpp @@ -15,12 +15,14 @@ #include "WebAssembly.h" #include "WebAssemblyISelLowering.h" #include "WebAssemblyTargetMachine.h" +#include "WebAssemblyUtilities.h" #include "llvm/CodeGen/MachineFrameInfo.h" #include "llvm/CodeGen/SelectionDAGISel.h" #include "llvm/CodeGen/WasmEHFuncInfo.h" #include "llvm/IR/DiagnosticInfo.h" #include "llvm/IR/Function.h" // To access function attributes. #include "llvm/IR/IntrinsicsWebAssembly.h" +#include "llvm/MC/MCSymbolWasm.h" #include "llvm/Support/Debug.h" #include "llvm/Support/KnownBits.h" #include "llvm/Support/raw_ostream.h" @@ -118,6 +120,47 @@ static SDValue getTagSymNode(int Tag, SelectionDAG *DAG) { return DAG->getTargetExternalSymbol(SymName, PtrVT); } +static APInt encodeFunctionSignature(SelectionDAG *DAG, SDLoc &DL, + SmallVector &Params, + SmallVector &Returns) { + auto toWasmValType = [&DAG, &DL](MVT VT) { + if (VT == MVT::i32) { + return wasm::ValType::I32; + } + if (VT == MVT::i64) { + return wasm::ValType::I64; + } + if (VT == MVT::f32) { + return wasm::ValType::F32; + } + if (VT == MVT::f64) { + return wasm::ValType::F64; + } + DAG->getContext()->diagnose( + DiagnosticInfoUnsupported(DAG->getMachineFunction().getFunction(), + "Unhandled type!", DL.getDebugLoc())); + }; + auto NParams = Params.size(); + auto NReturns = Returns.size(); + auto BitWidth = (NParams + NReturns + 2) * 64; + auto Sig = APInt(BitWidth, 0); + + Sig |= NParams; + for (auto &Param : Params) { + auto V = toWasmValType(Param); + Sig <<= 64; + Sig |= (int64_t)V; + } + Sig <<= 64; + Sig |= NReturns; + for (auto &Return : Returns) { + auto V = toWasmValType(Return); + Sig <<= 64; + Sig |= (int64_t)V; + } + return Sig; +} + void WebAssemblyDAGToDAGISel::Select(SDNode *Node) { // If we have a custom node, we already have selected! if (Node->isMachineOpcode()) { @@ -189,6 +232,42 @@ void WebAssemblyDAGToDAGISel::Select(SDNode *Node) { ReplaceNode(Node, TLSAlign); return; } + case Intrinsic::wasm_ref_test_func: { + // First emit the TABLE_GET instruction to convert function pointer ==> + // funcref + MachineFunction &MF = CurDAG->getMachineFunction(); + auto PtrVT = MVT::getIntegerVT(MF.getDataLayout().getPointerSizeInBits()); + MCSymbol *Table = WebAssembly::getOrCreateFunctionTableSymbol( + MF.getContext(), Subtarget); + SDValue TableSym = CurDAG->getMCSymbol(Table, PtrVT); + SDValue FuncRef = SDValue( + CurDAG->getMachineNode(WebAssembly::TABLE_GET_FUNCREF, DL, + MVT::funcref, TableSym, Node->getOperand(1)), + 0); + + // Encode the signature information into the type index placeholder. + // This gets decoded and converted into the actual type signature in + // WebAssemblyMCInstLower.cpp. + SmallVector Params; + SmallVector Results; + + MVT VT = Node->getOperand(2).getValueType().getSimpleVT(); + if (VT != MVT::Untyped) { + Params.push_back(VT); + } + for (unsigned I = 3; I < Node->getNumOperands(); ++I) { + MVT VT = Node->getOperand(I).getValueType().getSimpleVT(); + Results.push_back(VT); + } + auto Sig = encodeFunctionSignature(CurDAG, DL, Params, Results); + + auto SigOp = CurDAG->getTargetConstant( + Sig, DL, EVT::getIntegerVT(*CurDAG->getContext(), Sig.getBitWidth())); + MachineSDNode *RefTestNode = CurDAG->getMachineNode( + WebAssembly::REF_TEST_FUNCREF, DL, MVT::i32, {SigOp, FuncRef}); + ReplaceNode(Node, RefTestNode); + return; + } } break; } diff --git a/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp b/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp index bf2e04caa0a61..081d09e5b9d31 100644 --- a/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp +++ b/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp @@ -794,6 +794,7 @@ LowerCallResults(MachineInstr &CallResults, DebugLoc DL, MachineBasicBlock *BB, if (IsIndirect) { // Placeholder for the type index. + // This gets replaced with the correct value in WebAssemblyMCInstLower.cpp MIB.addImm(0); // The table into which this call_indirect indexes. MCSymbolWasm *Table = IsFuncrefCall diff --git a/llvm/lib/Target/WebAssembly/WebAssemblyMCInstLower.cpp b/llvm/lib/Target/WebAssembly/WebAssemblyMCInstLower.cpp index cc36244e63ff5..6ca046b22f503 100644 --- a/llvm/lib/Target/WebAssembly/WebAssemblyMCInstLower.cpp +++ b/llvm/lib/Target/WebAssembly/WebAssemblyMCInstLower.cpp @@ -15,13 +15,18 @@ #include "WebAssemblyMCInstLower.h" #include "MCTargetDesc/WebAssemblyMCAsmInfo.h" #include "MCTargetDesc/WebAssemblyMCTargetDesc.h" +#include "MCTargetDesc/WebAssemblyMCTypeUtilities.h" #include "TargetInfo/WebAssemblyTargetInfo.h" #include "Utils/WebAssemblyTypeUtilities.h" #include "WebAssemblyAsmPrinter.h" #include "WebAssemblyMachineFunctionInfo.h" #include "WebAssemblyUtilities.h" +#include "llvm/ADT/APInt.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/BinaryFormat/Wasm.h" #include "llvm/CodeGen/AsmPrinter.h" #include "llvm/CodeGen/MachineFunction.h" +#include "llvm/CodeGen/MachineOperand.h" #include "llvm/IR/Constants.h" #include "llvm/MC/MCAsmInfo.h" #include "llvm/MC/MCContext.h" @@ -152,6 +157,29 @@ MCOperand WebAssemblyMCInstLower::lowerTypeIndexOperand( return MCOperand::createExpr(Expr); } +MCOperand +WebAssemblyMCInstLower::lowerEncodedFunctionSignature(const APInt &Sig) const { + auto NumWords = Sig.getNumWords(); + SmallVector Params; + SmallVector Returns; + + int Idx = NumWords; + + auto GetWord = [&Idx, &Sig]() { + Idx--; + return Sig.extractBitsAsZExtValue(64, 64 * Idx); + }; + int NParams = GetWord(); + for (int I = 0; I < NParams; I++) { + Params.push_back(static_cast(GetWord())); + } + int NReturns = GetWord(); + for (int I = 0; I < NReturns; I++) { + Returns.push_back(static_cast(GetWord())); + } + return lowerTypeIndexOperand(std::move(Params), std::move(Returns)); +} + static void getFunctionReturns(const MachineInstr *MI, SmallVectorImpl &Returns) { const Function &F = MI->getMF()->getFunction(); @@ -196,11 +224,30 @@ void WebAssemblyMCInstLower::lower(const MachineInstr *MI, MCOp = MCOperand::createReg(WAReg); break; } + case llvm::MachineOperand::MO_CImmediate: { + // Lower type index placeholder for ref.test + // Currently this is the only way that CImmediates show up so panic if we + // get confused. + unsigned DescIndex = I - NumVariadicDefs; + if (DescIndex >= Desc.NumOperands) { + llvm_unreachable("unexpected CImmediate operand"); + } + const MCOperandInfo &Info = Desc.operands()[DescIndex]; + if (Info.OperandType != WebAssembly::OPERAND_TYPEINDEX) { + llvm_unreachable("unexpected CImmediate operand"); + } + MCOp = lowerEncodedFunctionSignature(MO.getCImm()->getValue()); + break; + } case MachineOperand::MO_Immediate: { unsigned DescIndex = I - NumVariadicDefs; if (DescIndex < Desc.NumOperands) { const MCOperandInfo &Info = Desc.operands()[DescIndex]; + // Replace type index placeholder with actual type index. The type index + // placeholders are Immediates and have an operand type of + // OPERAND_TYPEINDEX or OPERAND_SIGNATURE. if (Info.OperandType == WebAssembly::OPERAND_TYPEINDEX) { + // Lower type index placeholder for a CALL_INDIRECT instruction SmallVector Returns; SmallVector Params; @@ -228,6 +275,7 @@ void WebAssemblyMCInstLower::lower(const MachineInstr *MI, break; } if (Info.OperandType == WebAssembly::OPERAND_SIGNATURE) { + // Lower type index placeholder for blocks auto BT = static_cast(MO.getImm()); assert(BT != WebAssembly::BlockType::Invalid); if (BT == WebAssembly::BlockType::Multivalue) { diff --git a/llvm/lib/Target/WebAssembly/WebAssemblyMCInstLower.h b/llvm/lib/Target/WebAssembly/WebAssemblyMCInstLower.h index 9f08499e5cde1..34404d93434bb 100644 --- a/llvm/lib/Target/WebAssembly/WebAssemblyMCInstLower.h +++ b/llvm/lib/Target/WebAssembly/WebAssemblyMCInstLower.h @@ -36,6 +36,7 @@ class LLVM_LIBRARY_VISIBILITY WebAssemblyMCInstLower { MCOperand lowerSymbolOperand(const MachineOperand &MO, MCSymbol *Sym) const; MCOperand lowerTypeIndexOperand(SmallVectorImpl &&, SmallVectorImpl &&) const; + MCOperand lowerEncodedFunctionSignature(const APInt &Sig) const; public: WebAssemblyMCInstLower(MCContext &ctx, WebAssemblyAsmPrinter &printer) diff --git a/llvm/test/CodeGen/WebAssembly/ref-test-func.ll b/llvm/test/CodeGen/WebAssembly/ref-test-func.ll new file mode 100644 index 0000000000000..2f6b1f42f47ea --- /dev/null +++ b/llvm/test/CodeGen/WebAssembly/ref-test-func.ll @@ -0,0 +1,66 @@ +; RUN: llc < %s -mcpu=mvp -mattr=+reference-types | FileCheck %s + +target triple = "wasm32-unknown-unknown" + +; CHECK-LABEL: test_fpsig_1: +; CHECK: local.get 0 +; CHECK-NEXT: table.get __indirect_function_table +; CHECK-NEXT: ref.test (f32, f64, i32) -> (f32) +; CHECK-NEXT: call use +; Function Attrs: nounwind +define void @test_fpsig_1(ptr noundef %func) local_unnamed_addr #0 { +entry: + %res = tail call i32 (ptr, ...) @llvm.wasm.ref.test.func(ptr %func, float 0.000000e+00, float 0.000000e+00, double 0.000000e+00, i32 0) + tail call void @use(i32 noundef %res) #3 + ret void +} + +; CHECK-LABEL: test_fpsig_2: +; CHECK: local.get 0 +; CHECK-NEXT: table.get __indirect_function_table +; CHECK-NEXT: ref.test (f32, f64, i32) -> (i32) +; CHECK-NEXT: call use +define void @test_fpsig_2(ptr noundef %func) local_unnamed_addr #0 { +entry: + %res = tail call i32 (ptr, ...) @llvm.wasm.ref.test.func(ptr %func, i32 0, float 0.000000e+00, double 0.000000e+00, i32 0) + tail call void @use(i32 noundef %res) #3 + ret void +} + +; CHECK-LABEL: test_fpsig_3: +; CHECK: local.get 0 +; CHECK-NEXT: table.get __indirect_function_table +; CHECK-NEXT: ref.test (i32, i32, i32) -> (i32) +; CHECK-NEXT: call use +define void @test_fpsig_3(ptr noundef %func) local_unnamed_addr #0 { +entry: + %res = tail call i32 (ptr, ...) @llvm.wasm.ref.test.func(ptr %func, i32 0, i32 0, i32 0, i32 0) + tail call void @use(i32 noundef %res) #3 + ret void +} + +; CHECK-LABEL: test_fpsig_4: +; CHECK: local.get 0 +; CHECK-NEXT: table.get __indirect_function_table +; CHECK-NEXT: ref.test (i32, i32, i32) -> () +; CHECK-NEXT: call use +define void @test_fpsig_4(ptr noundef %func) local_unnamed_addr #0 { +entry: + %res = tail call i32 (ptr, ...) @llvm.wasm.ref.test.func(ptr %func, token poison, i32 0, i32 0, i32 0) + tail call void @use(i32 noundef %res) #3 + ret void +} + +; CHECK-LABEL: test_fpsig_5: +; CHECK: local.get 0 +; CHECK-NEXT: table.get __indirect_function_table +; CHECK-NEXT: ref.test () -> () +; CHECK-NEXT: call use +define void @test_fpsig_5(ptr noundef %func) local_unnamed_addr #0 { +entry: + %res = tail call i32 (ptr, ...) @llvm.wasm.ref.test.func(ptr %func, token poison) + tail call void @use(i32 noundef %res) #3 + ret void +} + +declare void @use(i32 noundef) local_unnamed_addr #1