Skip to content

[WebAssembly,llvm] Add llvm.wasm.ref.test.func intrinsic #147486

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 8 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions llvm/include/llvm/IR/IntrinsicsWebAssembly.td
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,10 @@ def int_wasm_ref_is_null_exn :
DefaultAttrsIntrinsic<[llvm_i32_ty], [llvm_exnref_ty], [IntrNoMem],
"llvm.wasm.ref.is_null.exn">;

def int_wasm_ref_test_func
: DefaultAttrsIntrinsic<[llvm_i32_ty], [llvm_ptr_ty, llvm_vararg_ty],
[IntrNoMem], "llvm.wasm.ref.test.func">;

//===----------------------------------------------------------------------===//
// Table intrinsics
//===----------------------------------------------------------------------===//
Expand Down
7 changes: 6 additions & 1 deletion llvm/lib/CodeGen/SelectionDAG/InstrEmitter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -402,7 +402,12 @@ void InstrEmitter::AddOperand(MachineInstrBuilder &MIB, SDValue Op,
AddRegisterOperand(MIB, Op, IIOpNum, II, VRBaseMap,
IsDebug, IsClone, IsCloned);
} else if (ConstantSDNode *C = dyn_cast<ConstantSDNode>(Op)) {
MIB.addImm(C->getSExtValue());
if (C->getAPIntValue().getBitWidth() <= 64) {
MIB.addImm(C->getSExtValue());
} else {
MIB.addCImm(
ConstantInt::get(MF->getFunction().getContext(), C->getAPIntValue()));
}
} else if (ConstantFPSDNode *F = dyn_cast<ConstantFPSDNode>(Op)) {
MIB.addFPImm(F->getConstantFPValue());
} else if (RegisterSDNode *R = dyn_cast<RegisterSDNode>(Op)) {
Expand Down
68 changes: 68 additions & 0 deletions llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
#include "WebAssemblySubtarget.h"
#include "WebAssemblyTargetMachine.h"
#include "WebAssemblyUtilities.h"
#include "llvm/BinaryFormat/Wasm.h"
#include "llvm/CodeGen/CallingConvLower.h"
#include "llvm/CodeGen/MachineFrameInfo.h"
#include "llvm/CodeGen/MachineInstrBuilder.h"
Expand Down Expand Up @@ -794,6 +795,7 @@ LowerCallResults(MachineInstr &CallResults, DebugLoc DL, MachineBasicBlock *BB,

if (IsIndirect) {
// Placeholder for the type index.
// This gets replaced with the correct value in WebAssemblyMCInstLower.cpp
MIB.addImm(0);
// The table into which this call_indirect indexes.
MCSymbolWasm *Table = IsFuncrefCall
Expand Down Expand Up @@ -2253,6 +2255,72 @@ SDValue WebAssemblyTargetLowering::LowerIntrinsic(SDValue Op,
DAG.getTargetExternalSymbol(TlsBase, PtrVT)),
0);
}
case Intrinsic::wasm_ref_test_func: {
// First emit the TABLE_GET instruction to convert function pointer ==>
// funcref
MachineFunction &MF = DAG.getMachineFunction();
auto PtrVT = getPointerTy(MF.getDataLayout());
MCSymbol *Table =
WebAssembly::getOrCreateFunctionTableSymbol(MF.getContext(), Subtarget);
SDValue TableSym = DAG.getMCSymbol(Table, PtrVT);
SDValue FuncRef =
SDValue(DAG.getMachineNode(WebAssembly::TABLE_GET_FUNCREF, DL,
MVT::funcref, TableSym, Op.getOperand(1)),
0);

// Encode the signature information into the type index placeholder.
// This gets decoded and converted into the actual type signature in
// WebAssemblyMCInstLower.cpp.
auto NParams = Op.getNumOperands() - 2;
auto BitWidth = (NParams + 1) * 64;
auto Sig = APInt(BitWidth, 0);
// The return type has to be a BlockType since it can be void.
{
SDValue Operand = Op.getOperand(2);
MVT VT = Operand.getValueType().getSimpleVT();
WebAssembly::BlockType V;
if (VT == MVT::Untyped) {
V = WebAssembly::BlockType::Void;
} else if (VT == MVT::i32) {
V = WebAssembly::BlockType::I32;
} else if (VT == MVT::i64) {
V = WebAssembly::BlockType::I64;
} else if (VT == MVT::f32) {
V = WebAssembly::BlockType::F32;
} else if (VT == MVT::f64) {
V = WebAssembly::BlockType::F64;
} else {
llvm_unreachable("Unhandled type!");
}
Sig |= (int64_t)V;
}
for (unsigned i = 3; i < Op.getNumOperands(); ++i) {
SDValue Operand = Op.getOperand(i);
MVT VT = Operand.getValueType().getSimpleVT();
wasm::ValType V;
if (VT == MVT::i32) {
V = wasm::ValType::I32;
} else if (VT == MVT::i64) {
V = wasm::ValType::I64;
} else if (VT == MVT::f32) {
V = wasm::ValType::F32;
} else if (VT == MVT::f64) {
V = wasm::ValType::F64;
} else {
llvm_unreachable("Unhandled type!");
}
Sig <<= 64;
Sig |= (int64_t)V;
}

SmallVector<SDValue, 4> Ops;
Ops.push_back(DAG.getTargetConstant(
Sig, DL, EVT::getIntegerVT(*DAG.getContext(), BitWidth)));
Ops.push_back(FuncRef);
return SDValue(
DAG.getMachineNode(WebAssembly::REF_TEST_FUNCREF, DL, MVT::i32, Ops),
0);
}
}
}

Expand Down
74 changes: 74 additions & 0 deletions llvm/lib/Target/WebAssembly/WebAssemblyMCInstLower.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,17 @@
#include "WebAssemblyMCInstLower.h"
#include "MCTargetDesc/WebAssemblyMCAsmInfo.h"
#include "MCTargetDesc/WebAssemblyMCTargetDesc.h"
#include "MCTargetDesc/WebAssemblyMCTypeUtilities.h"
#include "TargetInfo/WebAssemblyTargetInfo.h"
#include "Utils/WebAssemblyTypeUtilities.h"
#include "WebAssemblyAsmPrinter.h"
#include "WebAssemblyMachineFunctionInfo.h"
#include "WebAssemblyUtilities.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/BinaryFormat/Wasm.h"
#include "llvm/CodeGen/AsmPrinter.h"
#include "llvm/CodeGen/MachineFunction.h"
#include "llvm/CodeGen/MachineOperand.h"
#include "llvm/IR/Constants.h"
#include "llvm/MC/MCAsmInfo.h"
#include "llvm/MC/MCContext.h"
Expand Down Expand Up @@ -196,11 +200,80 @@ void WebAssemblyMCInstLower::lower(const MachineInstr *MI,
MCOp = MCOperand::createReg(WAReg);
break;
}
case llvm::MachineOperand::MO_CImmediate: {
// Lower type index placeholder for ref.test
// Currently this is the only way that CImmediates show up so panic if we
// get confused.
unsigned DescIndex = I - NumVariadicDefs;
if (DescIndex >= Desc.NumOperands) {
llvm_unreachable("unexpected CImmediate operand");
}
const MCOperandInfo &Info = Desc.operands()[DescIndex];
if (Info.OperandType != WebAssembly::OPERAND_TYPEINDEX) {
llvm_unreachable("unexpected CImmediate operand");
}
auto CImm = MO.getCImm()->getValue();
auto NumWords = CImm.getNumWords() - 1;
// Extract the type data we packed into the CImm in LowerRefTestFuncRef.
// We need to load the words from most significant to least significant
// order because of the way we bitshifted them in from the right.
// The return type needs special handling because it could be void.
auto ReturnType = static_cast<WebAssembly::BlockType>(
CImm.extractBitsAsZExtValue(64, (NumWords - 1) * 64));
SmallVector<wasm::ValType, 2> Returns;
switch (ReturnType) {
case WebAssembly::BlockType::Invalid:
llvm_unreachable("Invalid return type");
case WebAssembly::BlockType::I32:
Returns = {wasm::ValType::I32};
break;
case WebAssembly::BlockType::I64:
Returns = {wasm::ValType::I64};
break;
case WebAssembly::BlockType::F32:
Returns = {wasm::ValType::F32};
break;
case WebAssembly::BlockType::F64:
Returns = {wasm::ValType::F64};
break;
case WebAssembly::BlockType::Void:
Returns = {};
break;
case WebAssembly::BlockType::Exnref:
Returns = {wasm::ValType::EXNREF};
break;
case WebAssembly::BlockType::Externref:
Returns = {wasm::ValType::EXTERNREF};
break;
case WebAssembly::BlockType::Funcref:
Returns = {wasm::ValType::FUNCREF};
break;
case WebAssembly::BlockType::V128:
Returns = {wasm::ValType::V128};
break;
case WebAssembly::BlockType::Multivalue: {
llvm_unreachable("Invalid return type");
}
}
SmallVector<wasm::ValType, 4> Params;

for (int I = NumWords - 2; I >= 0; I--) {
auto Val = CImm.extractBitsAsZExtValue(64, 64 * I);
auto ParamType = static_cast<wasm::ValType>(Val);
Params.push_back(ParamType);
}
MCOp = lowerTypeIndexOperand(std::move(Returns), std::move(Params));
break;
}
case MachineOperand::MO_Immediate: {
unsigned DescIndex = I - NumVariadicDefs;
if (DescIndex < Desc.NumOperands) {
const MCOperandInfo &Info = Desc.operands()[DescIndex];
// Replace type index placeholder with actual type index. The type index
// placeholders are Immediates and have an operand type of
// OPERAND_TYPEINDEX or OPERAND_SIGNATURE.
if (Info.OperandType == WebAssembly::OPERAND_TYPEINDEX) {
// Lower type index placeholder for a CALL_INDIRECT instruction
SmallVector<wasm::ValType, 4> Returns;
SmallVector<wasm::ValType, 4> Params;

Expand Down Expand Up @@ -228,6 +301,7 @@ void WebAssemblyMCInstLower::lower(const MachineInstr *MI,
break;
}
if (Info.OperandType == WebAssembly::OPERAND_SIGNATURE) {
// Lower type index placeholder for blocks
auto BT = static_cast<WebAssembly::BlockType>(MO.getImm());
assert(BT != WebAssembly::BlockType::Invalid);
if (BT == WebAssembly::BlockType::Multivalue) {
Expand Down
42 changes: 42 additions & 0 deletions llvm/test/CodeGen/WebAssembly/ref-test-func.ll
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
; RUN: llc < %s -mcpu=mvp -mattr=+reference-types | FileCheck %s

target triple = "wasm32-unknown-unknown"

; CHECK-LABEL: test_function_pointer_signature_void:
; CHECK-NEXT: .functype test_function_pointer_signature_void (i32) -> ()
; CHECK-NEXT: .local funcref
; CHECK: local.get 0
; CHECK-NEXT: table.get __indirect_function_table
; CHECK-NEXT: local.tee 1
; CHECK-NEXT: ref.test (f32, f64, i32) -> (f32)
; CHECK-NEXT: call use
; CHECK-NEXT: local.get 1
; CHECK-NEXT: ref.test (f32, f64, i32) -> (i32)
; CHECK-NEXT: call use
; CHECK-NEXT: local.get 1
; CHECK-NEXT: ref.test (i32, i32, i32) -> (i32)
; CHECK-NEXT: call use
; CHECK-NEXT: local.get 1
; CHECK-NEXT: ref.test (i32, i32, i32) -> ()
; CHECK-NEXT: call use
; CHECK-NEXT: local.get 1
; CHECK-NEXT: ref.test () -> ()
; CHECK-NEXT: call use

; Function Attrs: nounwind
define void @test_function_pointer_signature_void(ptr noundef %func) local_unnamed_addr #0 {
entry:
%0 = tail call i32 (ptr, ...) @llvm.wasm.ref.test.func(ptr %func, float 0.000000e+00, float 0.000000e+00, double 0.000000e+00, i32 0)
tail call void @use(i32 noundef %0) #3
%1 = tail call i32 (ptr, ...) @llvm.wasm.ref.test.func(ptr %func, i32 0, float 0.000000e+00, double 0.000000e+00, i32 0)
tail call void @use(i32 noundef %1) #3
%2 = tail call i32 (ptr, ...) @llvm.wasm.ref.test.func(ptr %func, i32 0, i32 0, i32 0, i32 0)
tail call void @use(i32 noundef %2) #3
%3 = tail call i32 (ptr, ...) @llvm.wasm.ref.test.func(ptr %func, token poison, i32 0, i32 0, i32 0)
tail call void @use(i32 noundef %3) #3
%4 = tail call i32 (ptr, ...) @llvm.wasm.ref.test.func(ptr %func, token poison)
tail call void @use(i32 noundef %4) #3
ret void
}

declare void @use(i32 noundef) local_unnamed_addr #1
Loading