Skip to content

relaxed simd fma #147487

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
67 changes: 67 additions & 0 deletions llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
#include "llvm/CodeGen/SDPatternMatch.h"
#include "llvm/CodeGen/SelectionDAG.h"
#include "llvm/CodeGen/SelectionDAGNodes.h"
#include "llvm/CodeGenTypes/MachineValueType.h"
#include "llvm/IR/DiagnosticInfo.h"
#include "llvm/IR/DiagnosticPrinter.h"
#include "llvm/IR/Function.h"
Expand Down Expand Up @@ -182,6 +183,12 @@ WebAssemblyTargetLowering::WebAssemblyTargetLowering(
// SIMD-specific configuration
if (Subtarget->hasSIMD128()) {

// Enable fma optimization for wasm relaxed simd
if (Subtarget->hasRelaxedSIMD()) {
setTargetDAGCombine(ISD::FADD);
setTargetDAGCombine(ISD::FMA);
}

// Combine partial.reduce.add before legalization gets confused.
setTargetDAGCombine(ISD::INTRINSIC_WO_CHAIN);

Expand Down Expand Up @@ -3411,6 +3418,62 @@ static SDValue performSETCCCombine(SDNode *N,
}
return SDValue();
}
/// Returns true when the floating-point node \p N may be rewritten as a
/// relaxed-SIMD multiply-add.
///
/// \param N   The node being combined; only its first result type and its
///            fast-math flags are inspected.
/// \param DCI Combiner state, used to ask the target whether the result type
///            is legal.
/// \returns true if the result is a legal vector type and \p N carries the
///          `contract` flag; false otherwise.
static bool canRelaxSimd(SDNode *N, TargetLowering::DAGCombinerInfo &DCI) {
  EVT VecVT = N->getValueType(0);

  // INFO: WebAssembly doesn't have scalar fma yet
  // https://github.com/WebAssembly/design/issues/1391
  if (!VecVT.isVector())
    return false;

  // Fusing a multiply and an add is only permitted when the node allows
  // floating-point contraction.
  if (!N->getFlags().hasAllowContract())
    return false;

  // Only form the intrinsic on vector types the target can hold directly.
  // This rejects vectors wider than 128 bits (the previous size check) and
  // also narrower or otherwise unsupported vectors, which would produce an
  // intrinsic node with an illegal result type. Such nodes are expected to be
  // revisited once type legalization has split or widened them.
  return DCI.DAG.getTargetLoweringInfo().isTypeLegal(VecVT);
}
/// Fold a contractable vector `fadd(Addend, fmul(X, Y))` (in either operand
/// order) into the `wasm_relaxed_madd` intrinsic.
///
/// The intrinsic's operands are ordered (multiplicand, multiplicand, addend),
/// matching how performFMACombine forwards the operands of ISD::FMA, so the
/// product operands must come first and the addend last.
///
/// \param N   The ISD::FADD node to combine.
/// \param DCI Combiner state providing the SelectionDAG.
/// \returns the replacement intrinsic node, or an empty SDValue when the
///          pattern does not apply.
static SDValue performFAddCombine(SDNode *N,
                                  TargetLowering::DAGCombinerInfo &DCI) {
  assert(N->getOpcode() == ISD::FADD);

  // canRelaxSimd also rejects scalar types; WebAssembly doesn't have scalar
  // fma yet: https://github.com/WebAssembly/design/issues/1391
  if (!canRelaxSimd(N, DCI))
    return SDValue();

  EVT VecVT = N->getValueType(0);
  SDLoc DL(N);
  SelectionDAG &DAG = DCI.DAG;

  // fadd is commutative, so the multiply may be either operand. The right-hand
  // operand is tried first.
  for (unsigned MulIdx : {1u, 0u}) {
    SDValue Mul = N->getOperand(MulIdx);
    if (Mul.getOpcode() != ISD::FMUL)
      continue;

    // The multiply itself must allow contraction, not just the add.
    if (!Mul->getFlags().hasAllowContract())
      continue;

    // If the product has other users the fmul stays alive, and fusing would
    // only duplicate the multiplication.
    if (!Mul.hasOneUse())
      continue;

    SDValue Addend = N->getOperand(1 - MulIdx);
    return DAG.getNode(
        ISD::INTRINSIC_WO_CHAIN, DL, VecVT,
        {DAG.getConstant(Intrinsic::wasm_relaxed_madd, DL, MVT::i32),
         Mul.getOperand(0), Mul.getOperand(1), Addend});
  }

  return SDValue();
}

/// Rewrite an ISD::FMA node as the `wasm_relaxed_madd` intrinsic.
///
/// The three FMA operands are forwarded to the intrinsic in their original
/// order. Nothing is changed unless canRelaxSimd accepts the node.
///
/// \param N   The ISD::FMA node to combine.
/// \param DCI Combiner state providing the SelectionDAG.
/// \returns the replacement intrinsic node, or an empty SDValue when the node
///          is not eligible.
static SDValue performFMACombine(SDNode *N,
                                 TargetLowering::DAGCombinerInfo &DCI) {
  assert(N->getOpcode() == ISD::FMA);

  if (canRelaxSimd(N, DCI)) {
    SelectionDAG &DAG = DCI.DAG;
    SDLoc Loc(N);
    // Intrinsic ID first, then the FMA operands unchanged.
    SDValue IntrinsicOps[] = {
        DAG.getConstant(Intrinsic::wasm_relaxed_madd, Loc, MVT::i32),
        N->getOperand(0), N->getOperand(1), N->getOperand(2)};
    return DAG.getNode(ISD::INTRINSIC_WO_CHAIN, Loc, N->getValueType(0),
                       IntrinsicOps);
  }

  return SDValue();
}

static SDValue performMulCombine(SDNode *N, SelectionDAG &DAG) {
assert(N->getOpcode() == ISD::MUL);
Expand Down Expand Up @@ -3529,6 +3592,10 @@ WebAssemblyTargetLowering::PerformDAGCombine(SDNode *N,
return AnyAllCombine;
return performLowerPartialReduction(N, DCI.DAG);
}
case ISD::FADD:
return performFAddCombine(N, DCI);
case ISD::FMA:
return performFMACombine(N, DCI);
case ISD::MUL:
return performMulCombine(N, DCI.DAG);
}
Expand Down
130 changes: 130 additions & 0 deletions llvm/test/CodeGen/WebAssembly/simd-relaxed-fma.ll
Original file line number Diff line number Diff line change
@@ -0,0 +1,130 @@
; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5

; RUN: llc < %s -disable-wasm-fallthrough-return-opt -wasm-disable-explicit-locals -wasm-keep-registers -mattr=+simd128,+relaxed-simd | FileCheck %s
target triple = "wasm32"
; Separate `fast` fmul and fadd on <4 x float>: the expectations below require
; the pair to be selected as a single f32x4.relaxed_madd.
; NOTE(review): the IR computes (%b * %a) + %c, but the expected operands are
; $2, $1, $0. In fma_vector_4xf32_llvm, llvm.fma(%a, %b, %c) is expected as
; $0, $1, $2, i.e. the addend is the last operand. Assuming $0/$1/$2 name
; %a/%b/%c, the expectation here puts the addend first - confirm the order.
define <4 x float> @fma_vector_4xf32_seperate(<4 x float> %a, <4 x float> %b, <4 x float> %c) {
; CHECK-LABEL: fma_vector_4xf32_seperate:
; CHECK: .functype fma_vector_4xf32_seperate (v128, v128, v128) -> (v128)
; CHECK-NEXT: # %bb.0: # %entry
; CHECK-NEXT: f32x4.relaxed_madd $push0=, $2, $1, $0
; CHECK-NEXT: return $pop0
entry:
%mul.i = fmul fast <4 x float> %b, %a
%add.i = fadd fast <4 x float> %mul.i, %c
ret <4 x float> %add.i
}

; `fast` llvm.fma on <4 x float>: the expectations below require a single
; f32x4.relaxed_madd whose operands appear in the same order as the call's
; arguments.
define <4 x float> @fma_vector_4xf32_llvm(<4 x float> %a, <4 x float> %b, <4 x float> %c) {
; CHECK-LABEL: fma_vector_4xf32_llvm:
; CHECK: .functype fma_vector_4xf32_llvm (v128, v128, v128) -> (v128)
; CHECK-NEXT: # %bb.0: # %entry
; CHECK-NEXT: f32x4.relaxed_madd $push0=, $0, $1, $2
; CHECK-NEXT: return $pop0
entry:
%fma = tail call fast <4 x float> @llvm.fma(<4 x float> %a, <4 x float> %b, <4 x float> %c)
ret <4 x float> %fma
}


; Separate `fast` fmul and fadd on <8 x float>, which is wider than one v128.
; The expectations below require two f32x4.relaxed_madd instructions, each
; result being stored through the i32 argument $0.
; NOTE(review): the IR computes (%b * %a) + %c. Assuming $1-$2, $3-$4 and
; $5-$6 hold the halves of %a, %b and %c, the expected operands ($6, $4, $2 and
; $5, $3, $1) put the %c half first, whereas fma_vector_8xf32_llvm expects the
; addend last - confirm the order.
define <8 x float> @fma_vector_8xf32_seperate(<8 x float> %a, <8 x float> %b, <8 x float> %c) {
; CHECK-LABEL: fma_vector_8xf32_seperate:
; CHECK: .functype fma_vector_8xf32_seperate (i32, v128, v128, v128, v128, v128, v128) -> ()
; CHECK-NEXT: # %bb.0: # %entry
; CHECK-NEXT: f32x4.relaxed_madd $push0=, $6, $4, $2
; CHECK-NEXT: v128.store 16($0), $pop0
; CHECK-NEXT: f32x4.relaxed_madd $push1=, $5, $3, $1
; CHECK-NEXT: v128.store 0($0), $pop1
; CHECK-NEXT: return
entry:
%mul.i = fmul fast <8 x float> %b, %a
%add.i = fadd fast <8 x float> %mul.i, %c
ret <8 x float> %add.i
}

; `fast` llvm.fma on <8 x float>, which is wider than one v128. The
; expectations below require two f32x4.relaxed_madd instructions, each result
; being stored through the i32 argument $0.
define <8 x float> @fma_vector_8xf32_llvm(<8 x float> %a, <8 x float> %b, <8 x float> %c) {
; CHECK-LABEL: fma_vector_8xf32_llvm:
; CHECK: .functype fma_vector_8xf32_llvm (i32, v128, v128, v128, v128, v128, v128) -> ()
; CHECK-NEXT: # %bb.0: # %entry
; CHECK-NEXT: f32x4.relaxed_madd $push0=, $2, $4, $6
; CHECK-NEXT: v128.store 16($0), $pop0
; CHECK-NEXT: f32x4.relaxed_madd $push1=, $1, $3, $5
; CHECK-NEXT: v128.store 0($0), $pop1
; CHECK-NEXT: return
entry:
%fma = tail call fast <8 x float> @llvm.fma(<8 x float> %a, <8 x float> %b, <8 x float> %c)
ret <8 x float> %fma
}


; Separate `fast` fmul and fadd on <2 x double>: the expectations below require
; the pair to be selected as a single f64x2.relaxed_madd.
; NOTE(review): the IR computes (%b * %a) + %c, but the expected operands are
; $2, $1, $0. In fma_vector_2xf64_llvm, llvm.fma(%a, %b, %c) is expected as
; $0, $1, $2, i.e. the addend is the last operand. Assuming $0/$1/$2 name
; %a/%b/%c, the expectation here puts the addend first - confirm the order.
define <2 x double> @fma_vector_2xf64_seperate(<2 x double> %a, <2 x double> %b, <2 x double> %c) {
; CHECK-LABEL: fma_vector_2xf64_seperate:
; CHECK: .functype fma_vector_2xf64_seperate (v128, v128, v128) -> (v128)
; CHECK-NEXT: # %bb.0: # %entry
; CHECK-NEXT: f64x2.relaxed_madd $push0=, $2, $1, $0
; CHECK-NEXT: return $pop0
entry:
%mul.i = fmul fast <2 x double> %b, %a
%add.i = fadd fast <2 x double> %mul.i, %c
ret <2 x double> %add.i
}

; `fast` llvm.fma on <2 x double>: the expectations below require a single
; f64x2.relaxed_madd whose operands appear in the same order as the call's
; arguments.
define <2 x double> @fma_vector_2xf64_llvm(<2 x double> %a, <2 x double> %b, <2 x double> %c) {
; CHECK-LABEL: fma_vector_2xf64_llvm:
; CHECK: .functype fma_vector_2xf64_llvm (v128, v128, v128) -> (v128)
; CHECK-NEXT: # %bb.0: # %entry
; CHECK-NEXT: f64x2.relaxed_madd $push0=, $0, $1, $2
; CHECK-NEXT: return $pop0
entry:
%fma = tail call fast <2 x double> @llvm.fma(<2 x double> %a, <2 x double> %b, <2 x double> %c)
ret <2 x double> %fma
}


; Negative test: scalar `fast` fmul and fadd on float must not be fused. The
; expectations below keep a separate f32.mul and f32.add.
define float @fma_scalar_f32_seperate(float %a, float %b, float %c) {
; CHECK-LABEL: fma_scalar_f32_seperate:
; CHECK: .functype fma_scalar_f32_seperate (f32, f32, f32) -> (f32)
; CHECK-NEXT: # %bb.0: # %entry
; CHECK-NEXT: f32.mul $push0=, $1, $0
; CHECK-NEXT: f32.add $push1=, $pop0, $2
; CHECK-NEXT: return $pop1
entry:
%mul.i = fmul fast float %b, %a
%add.i = fadd fast float %mul.i, %c
ret float %add.i
}

; Negative test: scalar `fast` llvm.fma on float is not turned into a
; relaxed-SIMD instruction. The expectations below require a call to fmaf.
define float @fma_scalar_f32_llvm(float %a, float %b, float %c) {
; CHECK-LABEL: fma_scalar_f32_llvm:
; CHECK: .functype fma_scalar_f32_llvm (f32, f32, f32) -> (f32)
; CHECK-NEXT: # %bb.0: # %entry
; CHECK-NEXT: call $push0=, fmaf, $0, $1, $2
; CHECK-NEXT: return $pop0
entry:
%fma = tail call fast float @llvm.fma(float %a, float %b, float %c)
ret float %fma
}


; Negative test: scalar `fast` fmul and fadd on double must not be fused. The
; expectations below keep a separate f64.mul and f64.add.
define double @fma_scalar_f64_seperate(double %a, double %b, double %c) {
; CHECK-LABEL: fma_scalar_f64_seperate:
; CHECK: .functype fma_scalar_f64_seperate (f64, f64, f64) -> (f64)
; CHECK-NEXT: # %bb.0: # %entry
; CHECK-NEXT: f64.mul $push0=, $1, $0
; CHECK-NEXT: f64.add $push1=, $pop0, $2
; CHECK-NEXT: return $pop1
entry:
%mul.i = fmul fast double %b, %a
%add.i = fadd fast double %mul.i, %c
ret double %add.i
}

; Negative test: scalar `fast` llvm.fma on double is not turned into a
; relaxed-SIMD instruction. The expectations below require a call to fma.
define double @fma_scalar_f64_llvm(double %a, double %b, double %c) {
; CHECK-LABEL: fma_scalar_f64_llvm:
; CHECK: .functype fma_scalar_f64_llvm (f64, f64, f64) -> (f64)
; CHECK-NEXT: # %bb.0: # %entry
; CHECK-NEXT: call $push0=, fma, $0, $1, $2
; CHECK-NEXT: return $pop0
entry:
%fma = tail call fast double @llvm.fma(double %a, double %b, double %c)
ret double %fma
}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Test scalar cases and other vector widths