Skip to content

[WebAssembly,llvm] Add llvm.wasm.ref.test.func intrinsic #147486

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 8 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions llvm/include/llvm/IR/IntrinsicsWebAssembly.td
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,10 @@ def int_wasm_ref_is_null_exn :
DefaultAttrsIntrinsic<[llvm_i32_ty], [llvm_exnref_ty], [IntrNoMem],
"llvm.wasm.ref.is_null.exn">;

def int_wasm_ref_test_func
: DefaultAttrsIntrinsic<[llvm_i32_ty], [llvm_ptr_ty, llvm_vararg_ty],
[IntrNoMem]>;

//===----------------------------------------------------------------------===//
// Table intrinsics
//===----------------------------------------------------------------------===//
Expand Down
7 changes: 6 additions & 1 deletion llvm/lib/CodeGen/SelectionDAG/InstrEmitter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -402,7 +402,12 @@ void InstrEmitter::AddOperand(MachineInstrBuilder &MIB, SDValue Op,
AddRegisterOperand(MIB, Op, IIOpNum, II, VRBaseMap,
IsDebug, IsClone, IsCloned);
} else if (ConstantSDNode *C = dyn_cast<ConstantSDNode>(Op)) {
MIB.addImm(C->getSExtValue());
if (C->getAPIntValue().getBitWidth() <= 64) {
MIB.addImm(C->getSExtValue());
} else {
MIB.addCImm(
ConstantInt::get(MF->getFunction().getContext(), C->getAPIntValue()));
}
} else if (ConstantFPSDNode *F = dyn_cast<ConstantFPSDNode>(Op)) {
MIB.addFPImm(F->getConstantFPValue());
} else if (RegisterSDNode *R = dyn_cast<RegisterSDNode>(Op)) {
Expand Down
79 changes: 79 additions & 0 deletions llvm/lib/Target/WebAssembly/WebAssemblyISelDAGToDAG.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,14 @@
#include "WebAssembly.h"
#include "WebAssemblyISelLowering.h"
#include "WebAssemblyTargetMachine.h"
#include "WebAssemblyUtilities.h"
#include "llvm/CodeGen/MachineFrameInfo.h"
#include "llvm/CodeGen/SelectionDAGISel.h"
#include "llvm/CodeGen/WasmEHFuncInfo.h"
#include "llvm/IR/DiagnosticInfo.h"
#include "llvm/IR/Function.h" // To access function attributes.
#include "llvm/IR/IntrinsicsWebAssembly.h"
#include "llvm/MC/MCSymbolWasm.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/KnownBits.h"
#include "llvm/Support/raw_ostream.h"
Expand Down Expand Up @@ -118,6 +120,47 @@ static SDValue getTagSymNode(int Tag, SelectionDAG *DAG) {
return DAG->getTargetExternalSymbol(SymName, PtrVT);
}

static APInt encodeFunctionSignature(SelectionDAG *DAG, SDLoc &DL,
SmallVector<MVT, 4> &Params,
SmallVector<MVT, 1> &Returns) {
auto toWasmValType = [&DAG, &DL](MVT VT) {
if (VT == MVT::i32) {
return wasm::ValType::I32;
}
if (VT == MVT::i64) {
return wasm::ValType::I64;
}
if (VT == MVT::f32) {
return wasm::ValType::F32;
}
if (VT == MVT::f64) {
return wasm::ValType::F64;
}
DAG->getContext()->diagnose(
DiagnosticInfoUnsupported(DAG->getMachineFunction().getFunction(),
"Unhandled type!", DL.getDebugLoc()));
};
auto NParams = Params.size();
auto NReturns = Returns.size();
auto BitWidth = (NParams + NReturns + 2) * 64;
auto Sig = APInt(BitWidth, 0);

Sig |= NParams;
for (auto &Param : Params) {
auto V = toWasmValType(Param);
Sig <<= 64;
Sig |= (int64_t)V;
}
Sig <<= 64;
Sig |= NReturns;
for (auto &Return : Returns) {
auto V = toWasmValType(Return);
Sig <<= 64;
Sig |= (int64_t)V;
}
return Sig;
}

void WebAssemblyDAGToDAGISel::Select(SDNode *Node) {
// If we have a custom node, we already have selected!
if (Node->isMachineOpcode()) {
Expand Down Expand Up @@ -189,6 +232,42 @@ void WebAssemblyDAGToDAGISel::Select(SDNode *Node) {
ReplaceNode(Node, TLSAlign);
return;
}
case Intrinsic::wasm_ref_test_func: {
// First emit the TABLE_GET instruction to convert function pointer ==>
// funcref
MachineFunction &MF = CurDAG->getMachineFunction();
auto PtrVT = MVT::getIntegerVT(MF.getDataLayout().getPointerSizeInBits());
MCSymbol *Table = WebAssembly::getOrCreateFunctionTableSymbol(
MF.getContext(), Subtarget);
SDValue TableSym = CurDAG->getMCSymbol(Table, PtrVT);
SDValue FuncRef = SDValue(
CurDAG->getMachineNode(WebAssembly::TABLE_GET_FUNCREF, DL,
MVT::funcref, TableSym, Node->getOperand(1)),
0);

// Encode the signature information into the type index placeholder.
// This gets decoded and converted into the actual type signature in
// WebAssemblyMCInstLower.cpp.
SmallVector<MVT, 4> Params;
SmallVector<MVT, 1> Results;

MVT VT = Node->getOperand(2).getValueType().getSimpleVT();
if (VT != MVT::Untyped) {
Params.push_back(VT);
}
for (unsigned I = 3; I < Node->getNumOperands(); ++I) {
MVT VT = Node->getOperand(I).getValueType().getSimpleVT();
Results.push_back(VT);
}
auto Sig = encodeFunctionSignature(CurDAG, DL, Params, Results);

auto SigOp = CurDAG->getTargetConstant(
Sig, DL, EVT::getIntegerVT(*CurDAG->getContext(), Sig.getBitWidth()));
MachineSDNode *RefTestNode = CurDAG->getMachineNode(
WebAssembly::REF_TEST_FUNCREF, DL, MVT::i32, {SigOp, FuncRef});
ReplaceNode(Node, RefTestNode);
return;
}
}
break;
}
Expand Down
1 change: 1 addition & 0 deletions llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -794,6 +794,7 @@ LowerCallResults(MachineInstr &CallResults, DebugLoc DL, MachineBasicBlock *BB,

if (IsIndirect) {
// Placeholder for the type index.
// This gets replaced with the correct value in WebAssemblyMCInstLower.cpp
MIB.addImm(0);
// The table into which this call_indirect indexes.
MCSymbolWasm *Table = IsFuncrefCall
Expand Down
48 changes: 48 additions & 0 deletions llvm/lib/Target/WebAssembly/WebAssemblyMCInstLower.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,18 @@
#include "WebAssemblyMCInstLower.h"
#include "MCTargetDesc/WebAssemblyMCAsmInfo.h"
#include "MCTargetDesc/WebAssemblyMCTargetDesc.h"
#include "MCTargetDesc/WebAssemblyMCTypeUtilities.h"
#include "TargetInfo/WebAssemblyTargetInfo.h"
#include "Utils/WebAssemblyTypeUtilities.h"
#include "WebAssemblyAsmPrinter.h"
#include "WebAssemblyMachineFunctionInfo.h"
#include "WebAssemblyUtilities.h"
#include "llvm/ADT/APInt.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/BinaryFormat/Wasm.h"
#include "llvm/CodeGen/AsmPrinter.h"
#include "llvm/CodeGen/MachineFunction.h"
#include "llvm/CodeGen/MachineOperand.h"
#include "llvm/IR/Constants.h"
#include "llvm/MC/MCAsmInfo.h"
#include "llvm/MC/MCContext.h"
Expand Down Expand Up @@ -152,6 +157,29 @@ MCOperand WebAssemblyMCInstLower::lowerTypeIndexOperand(
return MCOperand::createExpr(Expr);
}

MCOperand
WebAssemblyMCInstLower::lowerEncodedFunctionSignature(const APInt &Sig) const {
auto NumWords = Sig.getNumWords();
SmallVector<wasm::ValType, 4> Params;
SmallVector<wasm::ValType, 2> Returns;

int Idx = NumWords;

auto GetWord = [&Idx, &Sig]() {
Idx--;
return Sig.extractBitsAsZExtValue(64, 64 * Idx);
};
int NParams = GetWord();
for (int I = 0; I < NParams; I++) {
Params.push_back(static_cast<wasm::ValType>(GetWord()));
}
int NReturns = GetWord();
for (int I = 0; I < NReturns; I++) {
Returns.push_back(static_cast<wasm::ValType>(GetWord()));
}
return lowerTypeIndexOperand(std::move(Params), std::move(Returns));
}

static void getFunctionReturns(const MachineInstr *MI,
SmallVectorImpl<wasm::ValType> &Returns) {
const Function &F = MI->getMF()->getFunction();
Expand Down Expand Up @@ -196,11 +224,30 @@ void WebAssemblyMCInstLower::lower(const MachineInstr *MI,
MCOp = MCOperand::createReg(WAReg);
break;
}
case llvm::MachineOperand::MO_CImmediate: {
// Lower type index placeholder for ref.test
// Currently this is the only way that CImmediates show up so panic if we
// get confused.
unsigned DescIndex = I - NumVariadicDefs;
if (DescIndex >= Desc.NumOperands) {
llvm_unreachable("unexpected CImmediate operand");
}
const MCOperandInfo &Info = Desc.operands()[DescIndex];
if (Info.OperandType != WebAssembly::OPERAND_TYPEINDEX) {
llvm_unreachable("unexpected CImmediate operand");
}
MCOp = lowerEncodedFunctionSignature(MO.getCImm()->getValue());
break;
}
case MachineOperand::MO_Immediate: {
unsigned DescIndex = I - NumVariadicDefs;
if (DescIndex < Desc.NumOperands) {
const MCOperandInfo &Info = Desc.operands()[DescIndex];
// Replace type index placeholder with actual type index. The type index
// placeholders are Immediates and have an operand type of
// OPERAND_TYPEINDEX or OPERAND_SIGNATURE.
if (Info.OperandType == WebAssembly::OPERAND_TYPEINDEX) {
// Lower type index placeholder for a CALL_INDIRECT instruction
SmallVector<wasm::ValType, 4> Returns;
SmallVector<wasm::ValType, 4> Params;

Expand Down Expand Up @@ -228,6 +275,7 @@ void WebAssemblyMCInstLower::lower(const MachineInstr *MI,
break;
}
if (Info.OperandType == WebAssembly::OPERAND_SIGNATURE) {
// Lower type index placeholder for blocks
auto BT = static_cast<WebAssembly::BlockType>(MO.getImm());
assert(BT != WebAssembly::BlockType::Invalid);
if (BT == WebAssembly::BlockType::Multivalue) {
Expand Down
1 change: 1 addition & 0 deletions llvm/lib/Target/WebAssembly/WebAssemblyMCInstLower.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ class LLVM_LIBRARY_VISIBILITY WebAssemblyMCInstLower {
MCOperand lowerSymbolOperand(const MachineOperand &MO, MCSymbol *Sym) const;
MCOperand lowerTypeIndexOperand(SmallVectorImpl<wasm::ValType> &&,
SmallVectorImpl<wasm::ValType> &&) const;
MCOperand lowerEncodedFunctionSignature(const APInt &Sig) const;

public:
WebAssemblyMCInstLower(MCContext &ctx, WebAssemblyAsmPrinter &printer)
Expand Down
66 changes: 66 additions & 0 deletions llvm/test/CodeGen/WebAssembly/ref-test-func.ll
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
; RUN: llc < %s -mcpu=mvp -mattr=+reference-types | FileCheck %s

target triple = "wasm32-unknown-unknown"

; CHECK-LABEL: test_fpsig_1:
; CHECK: local.get 0
; CHECK-NEXT: table.get __indirect_function_table
; CHECK-NEXT: ref.test (f32, f64, i32) -> (f32)
; CHECK-NEXT: call use
; Function Attrs: nounwind
define void @test_fpsig_1(ptr noundef %func) local_unnamed_addr #0 {
entry:
%res = tail call i32 (ptr, ...) @llvm.wasm.ref.test.func(ptr %func, float 0.000000e+00, float 0.000000e+00, double 0.000000e+00, i32 0)
tail call void @use(i32 noundef %res) #3
ret void
}

; CHECK-LABEL: test_fpsig_2:
; CHECK: local.get 0
; CHECK-NEXT: table.get __indirect_function_table
; CHECK-NEXT: ref.test (f32, f64, i32) -> (i32)
; CHECK-NEXT: call use
define void @test_fpsig_2(ptr noundef %func) local_unnamed_addr #0 {
entry:
%res = tail call i32 (ptr, ...) @llvm.wasm.ref.test.func(ptr %func, i32 0, float 0.000000e+00, double 0.000000e+00, i32 0)
tail call void @use(i32 noundef %res) #3
ret void
}

; CHECK-LABEL: test_fpsig_3:
; CHECK: local.get 0
; CHECK-NEXT: table.get __indirect_function_table
; CHECK-NEXT: ref.test (i32, i32, i32) -> (i32)
; CHECK-NEXT: call use
define void @test_fpsig_3(ptr noundef %func) local_unnamed_addr #0 {
entry:
%res = tail call i32 (ptr, ...) @llvm.wasm.ref.test.func(ptr %func, i32 0, i32 0, i32 0, i32 0)
tail call void @use(i32 noundef %res) #3
ret void
}

; CHECK-LABEL: test_fpsig_4:
; CHECK: local.get 0
; CHECK-NEXT: table.get __indirect_function_table
; CHECK-NEXT: ref.test (i32, i32, i32) -> ()
; CHECK-NEXT: call use
define void @test_fpsig_4(ptr noundef %func) local_unnamed_addr #0 {
entry:
%res = tail call i32 (ptr, ...) @llvm.wasm.ref.test.func(ptr %func, token poison, i32 0, i32 0, i32 0)
tail call void @use(i32 noundef %res) #3
ret void
}

; CHECK-LABEL: test_fpsig_5:
; CHECK: local.get 0
; CHECK-NEXT: table.get __indirect_function_table
; CHECK-NEXT: ref.test () -> ()
; CHECK-NEXT: call use
define void @test_fpsig_5(ptr noundef %func) local_unnamed_addr #0 {
entry:
%res = tail call i32 (ptr, ...) @llvm.wasm.ref.test.func(ptr %func, token poison)
tail call void @use(i32 noundef %res) #3
ret void
}

declare void @use(i32 noundef) local_unnamed_addr #1
Loading