Skip to content

[WebAssembly,llvm] Add llvm.wasm.ref.test.func intrinsic #147076

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 8 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions llvm/include/llvm/IR/IntrinsicsWebAssembly.td
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,10 @@ def int_wasm_ref_is_null_exn :
DefaultAttrsIntrinsic<[llvm_i32_ty], [llvm_exnref_ty], [IntrNoMem],
"llvm.wasm.ref.is_null.exn">;

def int_wasm_ref_test_func
: DefaultAttrsIntrinsic<[llvm_i32_ty], [llvm_ptr_ty, llvm_vararg_ty],
[IntrNoMem], "llvm.wasm.ref.test.func">;

//===----------------------------------------------------------------------===//
// Table intrinsics
//===----------------------------------------------------------------------===//
Expand Down
115 changes: 115 additions & 0 deletions llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
#include "WebAssemblySubtarget.h"
#include "WebAssemblyTargetMachine.h"
#include "WebAssemblyUtilities.h"
#include "llvm/BinaryFormat/Wasm.h"
#include "llvm/CodeGen/CallingConvLower.h"
#include "llvm/CodeGen/MachineFrameInfo.h"
#include "llvm/CodeGen/MachineInstrBuilder.h"
Expand Down Expand Up @@ -501,6 +502,51 @@ MVT WebAssemblyTargetLowering::getScalarShiftAmountTy(const DataLayout & /*DL*/,
return Result;
}

static MachineBasicBlock *LowerRefTestFuncRef(MachineInstr &MI, DebugLoc DL,
MachineBasicBlock *BB,
const TargetInstrInfo &TII) {
// Lower a REF_TEST_FUNCREF_PSEUDO instruction into a REF_TEST_FUNCREF
// instruction by combining the signature info Imm operands that
// SelectionDag/InstrEmitter.cpp makes into one CImm operand. Put this into
// the type index placeholder for REF_TEST_FUNCREF
Register ResultReg = MI.getOperand(0).getReg();
Register FuncRefReg = MI.getOperand(1).getReg();

auto NParams = MI.getNumOperands() - 3;
auto Sig = APInt(NParams * 64, 0);

{
uint64_t V = MI.getOperand(2).getImm();
Sig |= int64_t(V);
}

for (unsigned I = 3; I < MI.getNumOperands(); I++) {
const MachineOperand &MO = MI.getOperand(I);
if (!MO.isImm()) {
// I'm not really sure what these are or where they come from but it seems
// to be okay to ignore them
continue;
}
uint16_t V = MO.getImm();
Sig <<= 64;
Sig |= int64_t(V);
}

ConstantInt *TypeInfo =
ConstantInt::get(BB->getParent()->getFunction().getContext(), Sig);

// Put the type info first in the placeholder for the type index, then the
// actual funcref arg
BuildMI(*BB, MI, DL, TII.get(WebAssembly::REF_TEST_FUNCREF), ResultReg)
.addCImm(TypeInfo)
.addReg(FuncRefReg);

// Remove the original instruction
MI.eraseFromParent();

return BB;
}

// Lower an fp-to-int conversion operator from the LLVM opcode, which has an
// undefined result on invalid/overflow, to the WebAssembly opcode, which
// traps on invalid/overflow.
Expand Down Expand Up @@ -794,6 +840,7 @@ LowerCallResults(MachineInstr &CallResults, DebugLoc DL, MachineBasicBlock *BB,

if (IsIndirect) {
// Placeholder for the type index.
// This gets replaced with the correct value in WebAssemblyMCInstLower.cpp
MIB.addImm(0);
// The table into which this call_indirect indexes.
MCSymbolWasm *Table = IsFuncrefCall
Expand Down Expand Up @@ -862,6 +909,8 @@ MachineBasicBlock *WebAssemblyTargetLowering::EmitInstrWithCustomInserter(
switch (MI.getOpcode()) {
default:
llvm_unreachable("Unexpected instr type to insert");
case WebAssembly::REF_TEST_FUNCREF_PSEUDO:
return LowerRefTestFuncRef(MI, DL, BB, TII);
case WebAssembly::FP_TO_SINT_I32_F32:
return LowerFPToInt(MI, DL, BB, TII, false, false, false,
WebAssembly::I32_TRUNC_S_F32);
Expand Down Expand Up @@ -2253,6 +2302,72 @@ SDValue WebAssemblyTargetLowering::LowerIntrinsic(SDValue Op,
DAG.getTargetExternalSymbol(TlsBase, PtrVT)),
0);
}
case Intrinsic::wasm_ref_test_func: {
// First emit the TABLE_GET instruction to convert function pointer ==>
// funcref
MachineFunction &MF = DAG.getMachineFunction();
auto PtrVT = getPointerTy(MF.getDataLayout());
MCSymbol *Table =
WebAssembly::getOrCreateFunctionTableSymbol(MF.getContext(), Subtarget);
SDValue TableSym = DAG.getMCSymbol(Table, PtrVT);
SDValue FuncRef =
SDValue(DAG.getMachineNode(WebAssembly::TABLE_GET_FUNCREF, DL,
MVT::funcref, TableSym, Op.getOperand(1)),
0);

SmallVector<SDValue, 4> Ops;
Ops.push_back(FuncRef);

// We want to encode the type information into an APInt which we'll put
// in a CImm. However, in SelectionDag/InstrEmitter.cpp there is no code
// path that emits a CImm. So we need a custom inserter to put it in.

// We'll put each type argument in a separate TargetConstant which gets
// lowered to a MachineInstruction Imm. We combine these into a CImm in our
// custom inserter because it creates a problem downstream to have all these
// extra immediates.
{
SDValue Operand = Op.getOperand(2);
MVT VT = Operand.getValueType().getSimpleVT();
WebAssembly::BlockType V;
if (VT == MVT::Untyped) {
V = WebAssembly::BlockType::Void;
} else if (VT == MVT::i32) {
V = WebAssembly::BlockType::I32;
} else if (VT == MVT::i64) {
V = WebAssembly::BlockType::I64;
} else if (VT == MVT::f32) {
V = WebAssembly::BlockType::F32;
} else if (VT == MVT::f64) {
V = WebAssembly::BlockType::F64;
} else {
llvm_unreachable("Unhandled type!");
}
Ops.push_back(DAG.getTargetConstant((int64_t)V, DL, MVT::i64));
}

for (unsigned i = 3; i < Op.getNumOperands(); ++i) {
SDValue Operand = Op.getOperand(i);
MVT VT = Operand.getValueType().getSimpleVT();
wasm::ValType V;
if (VT == MVT::i32) {
V = wasm::ValType::I32;
} else if (VT == MVT::i64) {
V = wasm::ValType::I64;
} else if (VT == MVT::f32) {
V = wasm::ValType::F32;
} else if (VT == MVT::f64) {
V = wasm::ValType::F64;
} else {
llvm_unreachable("Unhandled type!");
}
Ops.push_back(DAG.getTargetConstant((int64_t)V, DL, MVT::i64));
}

return SDValue(DAG.getMachineNode(WebAssembly::REF_TEST_FUNCREF_PSEUDO, DL,
MVT::i32, Ops),
0);
}
}
}

Expand Down
5 changes: 5 additions & 0 deletions llvm/lib/Target/WebAssembly/WebAssemblyInstrRef.td
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,11 @@ multiclass REF_I<WebAssemblyRegClass rc, ValueType vt, string ht> {
Requires<[HasReferenceTypes]>;
}

let usesCustomInserter = 1, isPseudo = 1 in defm REF_TEST_FUNCREF_PSEUDO
: I<(outs I32:$res), (ins TypeIndex:$type, FUNCREF:$ref, variable_ops),
(outs), (ins TypeIndex:$type), [], "ref.test.pseudo\t$type, $ref",
"ref.test.pseudo $type", -1>;

defm REF_TEST_FUNCREF :
I<(outs I32: $res),
(ins TypeIndex:$type, FUNCREF: $ref),
Expand Down
74 changes: 74 additions & 0 deletions llvm/lib/Target/WebAssembly/WebAssemblyMCInstLower.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,17 @@
#include "WebAssemblyMCInstLower.h"
#include "MCTargetDesc/WebAssemblyMCAsmInfo.h"
#include "MCTargetDesc/WebAssemblyMCTargetDesc.h"
#include "MCTargetDesc/WebAssemblyMCTypeUtilities.h"
#include "TargetInfo/WebAssemblyTargetInfo.h"
#include "Utils/WebAssemblyTypeUtilities.h"
#include "WebAssemblyAsmPrinter.h"
#include "WebAssemblyMachineFunctionInfo.h"
#include "WebAssemblyUtilities.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/BinaryFormat/Wasm.h"
#include "llvm/CodeGen/AsmPrinter.h"
#include "llvm/CodeGen/MachineFunction.h"
#include "llvm/CodeGen/MachineOperand.h"
#include "llvm/IR/Constants.h"
#include "llvm/MC/MCAsmInfo.h"
#include "llvm/MC/MCContext.h"
Expand Down Expand Up @@ -196,11 +200,80 @@ void WebAssemblyMCInstLower::lower(const MachineInstr *MI,
MCOp = MCOperand::createReg(WAReg);
break;
}
case llvm::MachineOperand::MO_CImmediate: {
// Lower type index placeholder for ref.test
// Currently this is the only way that CImmediates show up so panic if we
// get confused.
unsigned DescIndex = I - NumVariadicDefs;
if (DescIndex >= Desc.NumOperands) {
llvm_unreachable("unexpected CImmediate operand");
}
const MCOperandInfo &Info = Desc.operands()[DescIndex];
if (Info.OperandType != WebAssembly::OPERAND_TYPEINDEX) {
llvm_unreachable("unexpected CImmediate operand");
}
auto CImm = MO.getCImm()->getValue();
auto NumWords = CImm.getNumWords();
// Extract the type data we packed into the CImm in LowerRefTestFuncRef.
// We need to load the words from most significant to least significant
// order because of the way we bitshifted them in from the right.
// The return type needs special handling because it could be void.
auto ReturnType = static_cast<WebAssembly::BlockType>(
CImm.extractBitsAsZExtValue(64, (NumWords - 1) * 64));
SmallVector<wasm::ValType, 2> Returns;
switch (ReturnType) {
case WebAssembly::BlockType::Invalid:
llvm_unreachable("Invalid return type");
case WebAssembly::BlockType::I32:
Returns = {wasm::ValType::I32};
break;
case WebAssembly::BlockType::I64:
Returns = {wasm::ValType::I64};
break;
case WebAssembly::BlockType::F32:
Returns = {wasm::ValType::F32};
break;
case WebAssembly::BlockType::F64:
Returns = {wasm::ValType::F64};
break;
case WebAssembly::BlockType::Void:
Returns = {};
break;
case WebAssembly::BlockType::Exnref:
Returns = {wasm::ValType::EXNREF};
break;
case WebAssembly::BlockType::Externref:
Returns = {wasm::ValType::EXTERNREF};
break;
case WebAssembly::BlockType::Funcref:
Returns = {wasm::ValType::FUNCREF};
break;
case WebAssembly::BlockType::V128:
Returns = {wasm::ValType::V128};
break;
case WebAssembly::BlockType::Multivalue: {
llvm_unreachable("Invalid return type");
}
}
SmallVector<wasm::ValType, 4> Params;

for (int I = NumWords - 2; I >= 0; I--) {
auto Val = CImm.extractBitsAsZExtValue(64, 64 * I);
auto ParamType = static_cast<wasm::ValType>(Val);
Params.push_back(ParamType);
}
MCOp = lowerTypeIndexOperand(std::move(Returns), std::move(Params));
break;
}
case MachineOperand::MO_Immediate: {
unsigned DescIndex = I - NumVariadicDefs;
if (DescIndex < Desc.NumOperands) {
const MCOperandInfo &Info = Desc.operands()[DescIndex];
// Replace type index placeholder with actual type index. The type index
// placeholders are Immediates and have an operand type of
// OPERAND_TYPEINDEX or OPERAND_SIGNATURE.
if (Info.OperandType == WebAssembly::OPERAND_TYPEINDEX) {
// Lower type index placeholder for a CALL_INDIRECT instruction
SmallVector<wasm::ValType, 4> Returns;
SmallVector<wasm::ValType, 4> Params;

Expand Down Expand Up @@ -228,6 +301,7 @@ void WebAssemblyMCInstLower::lower(const MachineInstr *MI,
break;
}
if (Info.OperandType == WebAssembly::OPERAND_SIGNATURE) {
// Lower type index placeholder for blocks
auto BT = static_cast<WebAssembly::BlockType>(MO.getImm());
assert(BT != WebAssembly::BlockType::Invalid);
if (BT == WebAssembly::BlockType::Multivalue) {
Expand Down
42 changes: 42 additions & 0 deletions llvm/test/CodeGen/WebAssembly/ref-test-func.ll
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
; RUN: llc < %s -mcpu=mvp -mattr=+reference-types | FileCheck %s

target triple = "wasm32-unknown-unknown"

; CHECK-LABEL: test_function_pointer_signature_void:
; CHECK-NEXT: .functype test_function_pointer_signature_void (i32) -> ()
; CHECK-NEXT: .local funcref
; CHECK: local.get 0
; CHECK-NEXT: table.get __indirect_function_table
; CHECK-NEXT: local.tee 1
; CHECK-NEXT: ref.test (f32, f64, i32) -> (f32)
; CHECK-NEXT: call use
; CHECK-NEXT: local.get 1
; CHECK-NEXT: ref.test (f32, f64, i32) -> (i32)
; CHECK-NEXT: call use
; CHECK-NEXT: local.get 1
; CHECK-NEXT: ref.test (i32, i32, i32) -> (i32)
; CHECK-NEXT: call use
; CHECK-NEXT: local.get 1
; CHECK-NEXT: ref.test (i32, i32, i32) -> ()
; CHECK-NEXT: call use
; CHECK-NEXT: local.get 1
; CHECK-NEXT: ref.test () -> ()
; CHECK-NEXT: call use

; Function Attrs: nounwind
define void @test_function_pointer_signature_void(ptr noundef %func) local_unnamed_addr #0 {
entry:
%0 = tail call i32 (ptr, ...) @llvm.wasm.ref.test.func(ptr %func, float 0.000000e+00, float 0.000000e+00, double 0.000000e+00, i32 0)
tail call void @use(i32 noundef %0) #3
%1 = tail call i32 (ptr, ...) @llvm.wasm.ref.test.func(ptr %func, i32 0, float 0.000000e+00, double 0.000000e+00, i32 0)
tail call void @use(i32 noundef %1) #3
%2 = tail call i32 (ptr, ...) @llvm.wasm.ref.test.func(ptr %func, i32 0, i32 0, i32 0, i32 0)
tail call void @use(i32 noundef %2) #3
%3 = tail call i32 (ptr, ...) @llvm.wasm.ref.test.func(ptr %func, token poison, i32 0, i32 0, i32 0)
tail call void @use(i32 noundef %3) #3
%4 = tail call i32 (ptr, ...) @llvm.wasm.ref.test.func(ptr %func, token poison)
tail call void @use(i32 noundef %4) #3
ret void
}

declare void @use(i32 noundef) local_unnamed_addr #1