diff --git a/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp b/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp index bf2e04caa0a61..15038df6d5f6c 100644 --- a/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp +++ b/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp @@ -27,6 +27,7 @@ #include "llvm/CodeGen/SDPatternMatch.h" #include "llvm/CodeGen/SelectionDAG.h" #include "llvm/CodeGen/SelectionDAGNodes.h" +#include "llvm/CodeGenTypes/MachineValueType.h" #include "llvm/IR/DiagnosticInfo.h" #include "llvm/IR/DiagnosticPrinter.h" #include "llvm/IR/Function.h" @@ -182,6 +183,12 @@ WebAssemblyTargetLowering::WebAssemblyTargetLowering( // SIMD-specific configuration if (Subtarget->hasSIMD128()) { + // Enable fma optimization for wasm relaxed simd + if (Subtarget->hasRelaxedSIMD()) { + setTargetDAGCombine(ISD::FADD); + setTargetDAGCombine(ISD::FMA); + } + // Combine partial.reduce.add before legalization gets confused. setTargetDAGCombine(ISD::INTRINSIC_WO_CHAIN); @@ -3411,6 +3418,62 @@ static SDValue performSETCCCombine(SDNode *N, } return SDValue(); } +static bool canRelaxSimd(SDNode *N, TargetLowering::DAGCombinerInfo &DCI) { + EVT VecVT = N->getValueType(0); + + // INFO: WebAssembly doesn't have scalar fma yet + // https://github.com/WebAssembly/design/issues/1391 + if (!VecVT.isVector()) + return false; + + // Allows fp fusing + if (!N->getFlags().hasAllowContract()) + return false; + + if (N->getValueType(0).bitsGT(MVT::f128)) + return false; + + return true; +} +static SDValue performFAddCombine(SDNode *N, + TargetLowering::DAGCombinerInfo &DCI) { + assert(N->getOpcode() == ISD::FADD); + using namespace llvm::SDPatternMatch; + + // INFO: WebAssembly doesn't have scalar fma yet + // https://github.com/WebAssembly/design/issues/1391 + EVT VecVT = N->getValueType(0); + if (!VecVT.isVector()) + return SDValue(); + + if (!canRelaxSimd(N, DCI)) + return SDValue(); + + SDLoc DL(N); + SDValue A, B, C; + SelectionDAG &DAG = DCI.DAG; + if (sd_match(N, m_FAdd(m_Value(A), m_FMul(m_Value(B), m_Value(C))))) + return DAG.getNode( + ISD::INTRINSIC_WO_CHAIN, DL, VecVT, + {DAG.getConstant(Intrinsic::wasm_relaxed_madd, DL, MVT::i32), A, B, C}); + + return SDValue(); +} + +static SDValue performFMACombine(SDNode *N, + TargetLowering::DAGCombinerInfo &DCI) { + assert(N->getOpcode() == ISD::FMA); + + if (!canRelaxSimd(N, DCI)) + return SDValue(); + + SDLoc DL(N); + SDValue A = N->getOperand(0), B = N->getOperand(1), C = N->getOperand(2); + SelectionDAG &DAG = DCI.DAG; + return DAG.getNode( + ISD::INTRINSIC_WO_CHAIN, DL, N->getValueType(0), + {DAG.getConstant(Intrinsic::wasm_relaxed_madd, DL, MVT::i32), A, B, C}); +} static SDValue performMulCombine(SDNode *N, SelectionDAG &DAG) { assert(N->getOpcode() == ISD::MUL); @@ -3529,6 +3592,10 @@ WebAssemblyTargetLowering::PerformDAGCombine(SDNode *N, return AnyAllCombine; return performLowerPartialReduction(N, DCI.DAG); } + case ISD::FADD: + return performFAddCombine(N, DCI); + case ISD::FMA: + return performFMACombine(N, DCI); case ISD::MUL: return performMulCombine(N, DCI.DAG); } diff --git a/llvm/test/CodeGen/WebAssembly/simd-relaxed-fma.ll b/llvm/test/CodeGen/WebAssembly/simd-relaxed-fma.ll new file mode 100644 index 0000000000000..e4bd6a3a8cda6 --- /dev/null +++ b/llvm/test/CodeGen/WebAssembly/simd-relaxed-fma.ll @@ -0,0 +1,130 @@ +; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5 + +; RUN: llc < %s -disable-wasm-fallthrough-return-opt -wasm-disable-explicit-locals -wasm-keep-registers -mattr=+simd128,+relaxed-simd | FileCheck %s +target triple = "wasm32" +define <4 x float> @fma_vector_4xf32_seperate(<4 x float> %a, <4 x float> %b, <4 x float> %c) { +; CHECK-LABEL: fma_vector_4xf32_seperate: +; CHECK: .functype fma_vector_4xf32_seperate (v128, v128, v128) -> (v128) +; CHECK-NEXT: # %bb.0: # %entry +; CHECK-NEXT: f32x4.relaxed_madd $push0=, $2, $1, $0 +; CHECK-NEXT: return $pop0 +entry: + %mul.i = fmul fast <4 x float> %b, %a + %add.i = fadd fast <4 x float> %mul.i, %c + ret <4 x float> %add.i +} + +define <4 x float> @fma_vector_4xf32_llvm(<4 x float> %a, <4 x float> %b, <4 x float> %c) { +; CHECK-LABEL: fma_vector_4xf32_llvm: +; CHECK: .functype fma_vector_4xf32_llvm (v128, v128, v128) -> (v128) +; CHECK-NEXT: # %bb.0: # %entry +; CHECK-NEXT: f32x4.relaxed_madd $push0=, $0, $1, $2 +; CHECK-NEXT: return $pop0 +entry: + %fma = tail call fast <4 x float> @llvm.fma(<4 x float> %a, <4 x float> %b, <4 x float> %c) + ret <4 x float> %fma +} + + +define <8 x float> @fma_vector_8xf32_seperate(<8 x float> %a, <8 x float> %b, <8 x float> %c) { +; CHECK-LABEL: fma_vector_8xf32_seperate: +; CHECK: .functype fma_vector_8xf32_seperate (i32, v128, v128, v128, v128, v128, v128) -> () +; CHECK-NEXT: # %bb.0: # %entry +; CHECK-NEXT: f32x4.relaxed_madd $push0=, $6, $4, $2 +; CHECK-NEXT: v128.store 16($0), $pop0 +; CHECK-NEXT: f32x4.relaxed_madd $push1=, $5, $3, $1 +; CHECK-NEXT: v128.store 0($0), $pop1 +; CHECK-NEXT: return +entry: + %mul.i = fmul fast <8 x float> %b, %a + %add.i = fadd fast <8 x float> %mul.i, %c + ret <8 x float> %add.i +} + +define <8 x float> @fma_vector_8xf32_llvm(<8 x float> %a, <8 x float> %b, <8 x float> %c) { +; CHECK-LABEL: fma_vector_8xf32_llvm: +; CHECK: .functype fma_vector_8xf32_llvm (i32, v128, v128, v128, v128, v128, v128) -> () +; CHECK-NEXT: # %bb.0: # %entry +; CHECK-NEXT: f32x4.relaxed_madd $push0=, $2, $4, $6 +; CHECK-NEXT: v128.store 16($0), $pop0 +; CHECK-NEXT: f32x4.relaxed_madd $push1=, $1, $3, $5 +; CHECK-NEXT: v128.store 0($0), $pop1 +; CHECK-NEXT: return +entry: + %fma = tail call fast <8 x float> @llvm.fma(<8 x float> %a, <8 x float> %b, <8 x float> %c) + ret <8 x float> %fma +} + + +define <2 x double> @fma_vector_2xf64_seperate(<2 x double> %a, <2 x double> %b, <2 x double> %c) { +; CHECK-LABEL: fma_vector_2xf64_seperate: +; CHECK: .functype fma_vector_2xf64_seperate (v128, v128, v128) -> (v128) +; CHECK-NEXT: # %bb.0: # %entry +; CHECK-NEXT: f64x2.relaxed_madd $push0=, $2, $1, $0 +; CHECK-NEXT: return $pop0 +entry: + %mul.i = fmul fast <2 x double> %b, %a + %add.i = fadd fast <2 x double> %mul.i, %c + ret <2 x double> %add.i +} + +define <2 x double> @fma_vector_2xf64_llvm(<2 x double> %a, <2 x double> %b, <2 x double> %c) { +; CHECK-LABEL: fma_vector_2xf64_llvm: +; CHECK: .functype fma_vector_2xf64_llvm (v128, v128, v128) -> (v128) +; CHECK-NEXT: # %bb.0: # %entry +; CHECK-NEXT: f64x2.relaxed_madd $push0=, $0, $1, $2 +; CHECK-NEXT: return $pop0 +entry: + %fma = tail call fast <2 x double> @llvm.fma(<2 x double> %a, <2 x double> %b, <2 x double> %c) + ret <2 x double> %fma +} + + +define float @fma_scalar_f32_seperate(float %a, float %b, float %c) { +; CHECK-LABEL: fma_scalar_f32_seperate: +; CHECK: .functype fma_scalar_f32_seperate (f32, f32, f32) -> (f32) +; CHECK-NEXT: # %bb.0: # %entry +; CHECK-NEXT: f32.mul $push0=, $1, $0 +; CHECK-NEXT: f32.add $push1=, $pop0, $2 +; CHECK-NEXT: return $pop1 +entry: + %mul.i = fmul fast float %b, %a + %add.i = fadd fast float %mul.i, %c + ret float %add.i +} + +define float @fma_scalar_f32_llvm(float %a, float %b, float %c) { +; CHECK-LABEL: fma_scalar_f32_llvm: +; CHECK: .functype fma_scalar_f32_llvm (f32, f32, f32) -> (f32) +; CHECK-NEXT: # %bb.0: # %entry +; CHECK-NEXT: call $push0=, fmaf, $0, $1, $2 +; CHECK-NEXT: return $pop0 +entry: + %fma = tail call fast float @llvm.fma(float %a, float %b, float %c) + ret float %fma +} + + +define double @fma_scalar_f64_seperate(double %a, double %b, double %c) { +; CHECK-LABEL: fma_scalar_f64_seperate: +; CHECK: .functype fma_scalar_f64_seperate (f64, f64, f64) -> (f64) +; CHECK-NEXT: # %bb.0: # %entry +; CHECK-NEXT: f64.mul $push0=, $1, $0 +; CHECK-NEXT: f64.add $push1=, $pop0, $2 +; CHECK-NEXT: return $pop1 +entry: + %mul.i = fmul fast double %b, %a + %add.i = fadd fast double %mul.i, %c + ret double %add.i +} + +define double @fma_scalar_f64_llvm(double %a, double %b, double %c) { +; CHECK-LABEL: fma_scalar_f64_llvm: +; CHECK: .functype fma_scalar_f64_llvm (f64, f64, f64) -> (f64) +; CHECK-NEXT: # %bb.0: # %entry +; CHECK-NEXT: call $push0=, fma, $0, $1, $2 +; CHECK-NEXT: return $pop0 +entry: + %fma = tail call fast double @llvm.fma(double %a, double %b, double %c) + ret double %fma +}