Skip to content

[WebAssembly,llvm] Add llvm.wasm.ref.test.func intrinsic #147486

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 8 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions llvm/include/llvm/CodeGen/ISDOpcodes.h
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,7 @@ enum NodeType {
/// materialized in registers.
TargetConstant,
TargetConstantFP,
TargetConstantAP,

/// TargetGlobalAddress - Like GlobalAddress, but the DAG does no folding or
/// anything else with this node, and this is valid in the target-specific
Expand Down
10 changes: 8 additions & 2 deletions llvm/include/llvm/CodeGen/SelectionDAG.h
Original file line number Diff line number Diff line change
Expand Up @@ -683,7 +683,8 @@ class SelectionDAG {
LLVM_ABI SDValue getConstant(uint64_t Val, const SDLoc &DL, EVT VT,
bool isTarget = false, bool isOpaque = false);
LLVM_ABI SDValue getConstant(const APInt &Val, const SDLoc &DL, EVT VT,
bool isTarget = false, bool isOpaque = false);
bool isTarget = false, bool isOpaque = false,
bool isArbitraryPrecision = false);

LLVM_ABI SDValue getSignedConstant(int64_t Val, const SDLoc &DL, EVT VT,
bool isTarget = false,
Expand All @@ -694,7 +695,8 @@ class SelectionDAG {
bool IsOpaque = false);

LLVM_ABI SDValue getConstant(const ConstantInt &Val, const SDLoc &DL, EVT VT,
bool isTarget = false, bool isOpaque = false);
bool isTarget = false, bool isOpaque = false,
bool isArbitraryPrecision = false);
LLVM_ABI SDValue getIntPtrConstant(uint64_t Val, const SDLoc &DL,
bool isTarget = false);
LLVM_ABI SDValue getShiftAmountConstant(uint64_t Val, EVT VT,
Expand All @@ -712,6 +714,10 @@ class SelectionDAG {
bool isOpaque = false) {
return getConstant(Val, DL, VT, true, isOpaque);
}
SDValue getTargetConstantAP(const APInt &Val, const SDLoc &DL, EVT VT,
bool isOpaque = false) {
return getConstant(Val, DL, VT, true, isOpaque, true);
}
SDValue getTargetConstant(const ConstantInt &Val, const SDLoc &DL, EVT VT,
bool isOpaque = false) {
return getConstant(Val, DL, VT, true, isOpaque);
Expand Down
12 changes: 7 additions & 5 deletions llvm/include/llvm/CodeGen/SelectionDAGNodes.h
Original file line number Diff line number Diff line change
Expand Up @@ -1742,10 +1742,11 @@ class ConstantSDNode : public SDNode {

const ConstantInt *Value;

ConstantSDNode(bool isTarget, bool isOpaque, const ConstantInt *val,
SDVTList VTs)
: SDNode(isTarget ? ISD::TargetConstant : ISD::Constant, 0, DebugLoc(),
VTs),
ConstantSDNode(bool isTarget, bool isOpaque, bool isAPTarget,
const ConstantInt *val, SDVTList VTs)
: SDNode(isAPTarget ? ISD::TargetConstantAP
: (isTarget ? ISD::TargetConstant : ISD::Constant),
0, DebugLoc(), VTs),
Value(val) {
assert(!isa<VectorType>(val->getType()) && "Unexpected vector type!");
ConstantSDNodeBits.IsOpaque = isOpaque;
Expand All @@ -1772,7 +1773,8 @@ class ConstantSDNode : public SDNode {

static bool classof(const SDNode *N) {
return N->getOpcode() == ISD::Constant ||
N->getOpcode() == ISD::TargetConstant;
N->getOpcode() == ISD::TargetConstant ||
N->getOpcode() == ISD::TargetConstantAP;
}
};

Expand Down
4 changes: 4 additions & 0 deletions llvm/include/llvm/IR/IntrinsicsWebAssembly.td
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,10 @@ def int_wasm_ref_is_null_exn :
DefaultAttrsIntrinsic<[llvm_i32_ty], [llvm_exnref_ty], [IntrNoMem],
"llvm.wasm.ref.is_null.exn">;

def int_wasm_ref_test_func
: DefaultAttrsIntrinsic<[llvm_i32_ty], [llvm_ptr_ty, llvm_vararg_ty],
[IntrNoMem], "llvm.wasm.ref.test.func">;

//===----------------------------------------------------------------------===//
// Table intrinsics
//===----------------------------------------------------------------------===//
Expand Down
7 changes: 6 additions & 1 deletion llvm/lib/CodeGen/SelectionDAG/InstrEmitter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -402,7 +402,12 @@ void InstrEmitter::AddOperand(MachineInstrBuilder &MIB, SDValue Op,
AddRegisterOperand(MIB, Op, IIOpNum, II, VRBaseMap,
IsDebug, IsClone, IsCloned);
} else if (ConstantSDNode *C = dyn_cast<ConstantSDNode>(Op)) {
MIB.addImm(C->getSExtValue());
if (C->getOpcode() == ISD::TargetConstantAP) {
MIB.addCImm(
ConstantInt::get(MF->getFunction().getContext(), C->getAPIntValue()));
} else {
MIB.addImm(C->getSExtValue());
}
} else if (ConstantFPSDNode *F = dyn_cast<ConstantFPSDNode>(Op)) {
MIB.addFPImm(F->getConstantFPValue());
} else if (RegisterSDNode *R = dyn_cast<RegisterSDNode>(Op)) {
Expand Down
6 changes: 4 additions & 2 deletions llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -968,6 +968,7 @@ void SelectionDAGLegalize::LegalizeOp(SDNode *Node) {

// Allow illegal target nodes and illegal registers.
if (Node->getOpcode() == ISD::TargetConstant ||
Node->getOpcode() == ISD::TargetConstantAP ||
Node->getOpcode() == ISD::Register)
return;

Expand All @@ -979,10 +980,11 @@ void SelectionDAGLegalize::LegalizeOp(SDNode *Node) {

for (const SDValue &Op : Node->op_values())
assert((TLI.getTypeAction(*DAG.getContext(), Op.getValueType()) ==
TargetLowering::TypeLegal ||
TargetLowering::TypeLegal ||
Op.getOpcode() == ISD::TargetConstant ||
Op->getOpcode() == ISD::TargetConstantAP ||
Op.getOpcode() == ISD::Register) &&
"Unexpected illegal type!");
"Unexpected illegal type!");
#endif

// Figure out the correct action; the way to query this varies by opcode
Expand Down
13 changes: 7 additions & 6 deletions llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1664,14 +1664,14 @@ SDValue SelectionDAG::getConstant(uint64_t Val, const SDLoc &DL, EVT VT,
}

SDValue SelectionDAG::getConstant(const APInt &Val, const SDLoc &DL, EVT VT,
bool isT, bool isO) {
return getConstant(*ConstantInt::get(*Context, Val), DL, VT, isT, isO);
bool isT, bool isO, bool isAP) {
return getConstant(*ConstantInt::get(*Context, Val), DL, VT, isT, isO, isAP);
}

SDValue SelectionDAG::getConstant(const ConstantInt &Val, const SDLoc &DL,
EVT VT, bool isT, bool isO) {
EVT VT, bool isT, bool isO, bool isAP) {
assert(VT.isInteger() && "Cannot create FP integer constant!");

isT |= isAP;
EVT EltVT = VT.getScalarType();
const ConstantInt *Elt = &Val;

Expand Down Expand Up @@ -1760,7 +1760,8 @@ SDValue SelectionDAG::getConstant(const ConstantInt &Val, const SDLoc &DL,

assert(Elt->getBitWidth() == EltVT.getSizeInBits() &&
"APInt size does not match type size!");
unsigned Opc = isT ? ISD::TargetConstant : ISD::Constant;
unsigned Opc = isAP ? ISD::TargetConstantAP
: (isT ? ISD::TargetConstant : ISD::Constant);
SDVTList VTs = getVTList(EltVT);
FoldingSetNodeID ID;
AddNodeIDNode(ID, Opc, VTs, {});
Expand All @@ -1773,7 +1774,7 @@ SDValue SelectionDAG::getConstant(const ConstantInt &Val, const SDLoc &DL,
return SDValue(N, 0);

if (!N) {
N = newSDNode<ConstantSDNode>(isT, isO, Elt, VTs);
N = newSDNode<ConstantSDNode>(isT, isO, isAP, Elt, VTs);
CSEMap.InsertNode(N, IP);
InsertNode(N);
NewSDValueDbgMsg(SDValue(N, 0), "Creating constant: ", this);
Expand Down
1 change: 1 addition & 0 deletions llvm/lib/CodeGen/SelectionDAG/SelectionDAGISel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3255,6 +3255,7 @@ void SelectionDAGISel::SelectCodeCommon(SDNode *NodeToMatch,
case ISD::HANDLENODE:
case ISD::MDNODE_SDNODE:
case ISD::TargetConstant:
case ISD::TargetConstantAP:
case ISD::TargetConstantFP:
case ISD::TargetConstantPool:
case ISD::TargetFrameIndex:
Expand Down
67 changes: 67 additions & 0 deletions llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
#include "WebAssemblySubtarget.h"
#include "WebAssemblyTargetMachine.h"
#include "WebAssemblyUtilities.h"
#include "llvm/BinaryFormat/Wasm.h"
#include "llvm/CodeGen/CallingConvLower.h"
#include "llvm/CodeGen/MachineFrameInfo.h"
#include "llvm/CodeGen/MachineInstrBuilder.h"
Expand Down Expand Up @@ -794,6 +795,7 @@ LowerCallResults(MachineInstr &CallResults, DebugLoc DL, MachineBasicBlock *BB,

if (IsIndirect) {
// Placeholder for the type index.
// This gets replaced with the correct value in WebAssemblyMCInstLower.cpp
MIB.addImm(0);
// The table into which this call_indirect indexes.
MCSymbolWasm *Table = IsFuncrefCall
Expand Down Expand Up @@ -2253,6 +2255,71 @@ SDValue WebAssemblyTargetLowering::LowerIntrinsic(SDValue Op,
DAG.getTargetExternalSymbol(TlsBase, PtrVT)),
0);
}
case Intrinsic::wasm_ref_test_func: {
// First emit the TABLE_GET instruction to convert function pointer ==>
// funcref
MachineFunction &MF = DAG.getMachineFunction();
auto PtrVT = getPointerTy(MF.getDataLayout());
MCSymbol *Table =
WebAssembly::getOrCreateFunctionTableSymbol(MF.getContext(), Subtarget);
SDValue TableSym = DAG.getMCSymbol(Table, PtrVT);
SDValue FuncRef =
SDValue(DAG.getMachineNode(WebAssembly::TABLE_GET_FUNCREF, DL,
MVT::funcref, TableSym, Op.getOperand(1)),
0);

// Encode the signature information into the type index placeholder.
// This gets decoded and converted into the actual type signature in
// WebAssemblyMCInstLower.cpp.
auto NParams = Op.getNumOperands() - 2;
auto Sig = APInt(NParams * 64, 0);
// The return type has to be a BlockType since it can be void.
{
SDValue Operand = Op.getOperand(2);
MVT VT = Operand.getValueType().getSimpleVT();
WebAssembly::BlockType V;
if (VT == MVT::Untyped) {
V = WebAssembly::BlockType::Void;
} else if (VT == MVT::i32) {
V = WebAssembly::BlockType::I32;
} else if (VT == MVT::i64) {
V = WebAssembly::BlockType::I64;
} else if (VT == MVT::f32) {
V = WebAssembly::BlockType::F32;
} else if (VT == MVT::f64) {
V = WebAssembly::BlockType::F64;
} else {
llvm_unreachable("Unhandled type!");
}
Sig |= (int64_t)V;
}
for (unsigned i = 3; i < Op.getNumOperands(); ++i) {
SDValue Operand = Op.getOperand(i);
MVT VT = Operand.getValueType().getSimpleVT();
wasm::ValType V;
if (VT == MVT::i32) {
V = wasm::ValType::I32;
} else if (VT == MVT::i64) {
V = wasm::ValType::I64;
} else if (VT == MVT::f32) {
V = wasm::ValType::F32;
} else if (VT == MVT::f64) {
V = wasm::ValType::F64;
} else {
llvm_unreachable("Unhandled type!");
}
Sig <<= 64;
Sig |= (int64_t)V;
}

SmallVector<SDValue, 4> Ops;
Ops.push_back(DAG.getTargetConstantAP(
Sig, DL, EVT::getIntegerVT(*DAG.getContext(), NParams * 64)));
Ops.push_back(FuncRef);
return SDValue(
DAG.getMachineNode(WebAssembly::REF_TEST_FUNCREF, DL, MVT::i32, Ops),
0);
}
}
}

Expand Down
74 changes: 74 additions & 0 deletions llvm/lib/Target/WebAssembly/WebAssemblyMCInstLower.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,17 @@
#include "WebAssemblyMCInstLower.h"
#include "MCTargetDesc/WebAssemblyMCAsmInfo.h"
#include "MCTargetDesc/WebAssemblyMCTargetDesc.h"
#include "MCTargetDesc/WebAssemblyMCTypeUtilities.h"
#include "TargetInfo/WebAssemblyTargetInfo.h"
#include "Utils/WebAssemblyTypeUtilities.h"
#include "WebAssemblyAsmPrinter.h"
#include "WebAssemblyMachineFunctionInfo.h"
#include "WebAssemblyUtilities.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/BinaryFormat/Wasm.h"
#include "llvm/CodeGen/AsmPrinter.h"
#include "llvm/CodeGen/MachineFunction.h"
#include "llvm/CodeGen/MachineOperand.h"
#include "llvm/IR/Constants.h"
#include "llvm/MC/MCAsmInfo.h"
#include "llvm/MC/MCContext.h"
Expand Down Expand Up @@ -196,11 +200,80 @@ void WebAssemblyMCInstLower::lower(const MachineInstr *MI,
MCOp = MCOperand::createReg(WAReg);
break;
}
case llvm::MachineOperand::MO_CImmediate: {
// Lower type index placeholder for ref.test
// Currently this is the only way that CImmediates show up so panic if we
// get confused.
unsigned DescIndex = I - NumVariadicDefs;
if (DescIndex >= Desc.NumOperands) {
llvm_unreachable("unexpected CImmediate operand");
}
const MCOperandInfo &Info = Desc.operands()[DescIndex];
if (Info.OperandType != WebAssembly::OPERAND_TYPEINDEX) {
llvm_unreachable("unexpected CImmediate operand");
}
auto CImm = MO.getCImm()->getValue();
auto NumWords = CImm.getNumWords();
// Extract the type data we packed into the CImm in LowerRefTestFuncRef.
// We need to load the words from most significant to least significant
// order because of the way we bitshifted them in from the right.
// The return type needs special handling because it could be void.
auto ReturnType = static_cast<WebAssembly::BlockType>(
CImm.extractBitsAsZExtValue(64, (NumWords - 1) * 64));
SmallVector<wasm::ValType, 2> Returns;
switch (ReturnType) {
case WebAssembly::BlockType::Invalid:
llvm_unreachable("Invalid return type");
case WebAssembly::BlockType::I32:
Returns = {wasm::ValType::I32};
break;
case WebAssembly::BlockType::I64:
Returns = {wasm::ValType::I64};
break;
case WebAssembly::BlockType::F32:
Returns = {wasm::ValType::F32};
break;
case WebAssembly::BlockType::F64:
Returns = {wasm::ValType::F64};
break;
case WebAssembly::BlockType::Void:
Returns = {};
break;
case WebAssembly::BlockType::Exnref:
Returns = {wasm::ValType::EXNREF};
break;
case WebAssembly::BlockType::Externref:
Returns = {wasm::ValType::EXTERNREF};
break;
case WebAssembly::BlockType::Funcref:
Returns = {wasm::ValType::FUNCREF};
break;
case WebAssembly::BlockType::V128:
Returns = {wasm::ValType::V128};
break;
case WebAssembly::BlockType::Multivalue: {
llvm_unreachable("Invalid return type");
}
}
SmallVector<wasm::ValType, 4> Params;

for (int I = NumWords - 2; I >= 0; I--) {
auto Val = CImm.extractBitsAsZExtValue(64, 64 * I);
auto ParamType = static_cast<wasm::ValType>(Val);
Params.push_back(ParamType);
}
MCOp = lowerTypeIndexOperand(std::move(Returns), std::move(Params));
break;
}
case MachineOperand::MO_Immediate: {
unsigned DescIndex = I - NumVariadicDefs;
if (DescIndex < Desc.NumOperands) {
const MCOperandInfo &Info = Desc.operands()[DescIndex];
// Replace type index placeholder with actual type index. The type index
// placeholders are Immediates and have an operand type of
// OPERAND_TYPEINDEX or OPERAND_SIGNATURE.
if (Info.OperandType == WebAssembly::OPERAND_TYPEINDEX) {
// Lower type index placeholder for a CALL_INDIRECT instruction
SmallVector<wasm::ValType, 4> Returns;
SmallVector<wasm::ValType, 4> Params;

Expand Down Expand Up @@ -228,6 +301,7 @@ void WebAssemblyMCInstLower::lower(const MachineInstr *MI,
break;
}
if (Info.OperandType == WebAssembly::OPERAND_SIGNATURE) {
// Lower type index placeholder for blocks
auto BT = static_cast<WebAssembly::BlockType>(MO.getImm());
assert(BT != WebAssembly::BlockType::Invalid);
if (BT == WebAssembly::BlockType::Multivalue) {
Expand Down
42 changes: 42 additions & 0 deletions llvm/test/CodeGen/WebAssembly/ref-test-func.ll
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
; RUN: llc < %s -mcpu=mvp -mattr=+reference-types | FileCheck %s

target triple = "wasm32-unknown-unknown"

; CHECK-LABEL: test_function_pointer_signature_void:
; CHECK-NEXT: .functype test_function_pointer_signature_void (i32) -> ()
; CHECK-NEXT: .local funcref
; CHECK: local.get 0
; CHECK-NEXT: table.get __indirect_function_table
; CHECK-NEXT: local.tee 1
; CHECK-NEXT: ref.test (f32, f64, i32) -> (f32)
; CHECK-NEXT: call use
; CHECK-NEXT: local.get 1
; CHECK-NEXT: ref.test (f32, f64, i32) -> (i32)
; CHECK-NEXT: call use
; CHECK-NEXT: local.get 1
; CHECK-NEXT: ref.test (i32, i32, i32) -> (i32)
; CHECK-NEXT: call use
; CHECK-NEXT: local.get 1
; CHECK-NEXT: ref.test (i32, i32, i32) -> ()
; CHECK-NEXT: call use
; CHECK-NEXT: local.get 1
; CHECK-NEXT: ref.test () -> ()
; CHECK-NEXT: call use

; Function Attrs: nounwind
define void @test_function_pointer_signature_void(ptr noundef %func) local_unnamed_addr #0 {
entry:
%0 = tail call i32 (ptr, ...) @llvm.wasm.ref.test.func(ptr %func, float 0.000000e+00, float 0.000000e+00, double 0.000000e+00, i32 0)
tail call void @use(i32 noundef %0) #3
%1 = tail call i32 (ptr, ...) @llvm.wasm.ref.test.func(ptr %func, i32 0, float 0.000000e+00, double 0.000000e+00, i32 0)
tail call void @use(i32 noundef %1) #3
%2 = tail call i32 (ptr, ...) @llvm.wasm.ref.test.func(ptr %func, i32 0, i32 0, i32 0, i32 0)
tail call void @use(i32 noundef %2) #3
%3 = tail call i32 (ptr, ...) @llvm.wasm.ref.test.func(ptr %func, token poison, i32 0, i32 0, i32 0)
tail call void @use(i32 noundef %3) #3
%4 = tail call i32 (ptr, ...) @llvm.wasm.ref.test.func(ptr %func, token poison)
tail call void @use(i32 noundef %4) #3
ret void
}

declare void @use(i32 noundef) local_unnamed_addr #1
Loading