relaxed simd fma #147487

Open: wants to merge 4 commits into main

badumbatish
Contributor

Fixes llvm#121311 by folding a multiply followed by an add into wasm.fma when
both -mrelaxed-simd and -ffast-math are enabled.

Also attempts to use wasm.fma instead of the built-in llvm.fma.
llvmbot added the backend:WebAssembly and llvm:SelectionDAG (SelectionDAGISel as well) labels Jul 8, 2025
Member

llvmbot commented Jul 8, 2025

@llvm/pr-subscribers-backend-webassembly

Author: jjasmine (badumbatish)

Changes
  • Precommit test for #121311
  • [WASM] Optimize fma when relaxed and ffast-math

Full diff: https://github.com/llvm/llvm-project/pull/147487.diff

3 Files Affected:

  • (modified) llvm/include/llvm/CodeGen/SelectionDAGNodes.h (+1)
  • (modified) llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp (+41)
  • (added) llvm/test/CodeGen/WebAssembly/simd-relaxed-fma.ll (+43)
diff --git a/llvm/include/llvm/CodeGen/SelectionDAGNodes.h b/llvm/include/llvm/CodeGen/SelectionDAGNodes.h
index a3675eecfea3f..ec566b168bc3d 100644
--- a/llvm/include/llvm/CodeGen/SelectionDAGNodes.h
+++ b/llvm/include/llvm/CodeGen/SelectionDAGNodes.h
@@ -475,6 +475,7 @@ struct SDNodeFlags {
   bool hasAllowReassociation() const { return Flags & AllowReassociation; }
   bool hasNoFPExcept() const { return Flags & NoFPExcept; }
   bool hasUnpredictable() const { return Flags & Unpredictable; }
+  bool hasFastMath() const { return Flags & FastMathFlags; }
 
   bool operator==(const SDNodeFlags &Other) const {
     return Flags == Other.Flags;
diff --git a/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp b/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp
index bf2e04caa0a61..ef0146f28aba1 100644
--- a/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp
+++ b/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp
@@ -182,6 +182,12 @@ WebAssemblyTargetLowering::WebAssemblyTargetLowering(
   // SIMD-specific configuration
   if (Subtarget->hasSIMD128()) {
 
+    // Enable fma optimization for wasm relaxed simd
+    if (Subtarget->hasRelaxedSIMD()) {
+      setTargetDAGCombine(ISD::FADD);
+      setTargetDAGCombine(ISD::FMA);
+    }
+
     // Combine partial.reduce.add before legalization gets confused.
     setTargetDAGCombine(ISD::INTRINSIC_WO_CHAIN);
 
@@ -3412,6 +3418,37 @@ static SDValue performSETCCCombine(SDNode *N,
   return SDValue();
 }
 
+static SDValue performFAddCombine(SDNode *N, SelectionDAG &DAG) {
+  assert(N->getOpcode() == ISD::FADD);
+  using namespace llvm::SDPatternMatch;
+  if (!N->getFlags().hasFastMath())
+    return SDValue();
+
+  SDLoc DL(N);
+  SDValue A, B, C;
+  EVT VecVT = N->getValueType(0);
+  if (sd_match(N, m_FAdd(m_Value(A), m_FMul(m_Value(B), m_Value(C)))))
+    return DAG.getNode(
+        ISD::INTRINSIC_WO_CHAIN, DL, VecVT,
+        {DAG.getConstant(Intrinsic::wasm_relaxed_madd, DL, MVT::i32), A, B, C});
+
+  return SDValue();
+}
+
+static SDValue performFMACombine(SDNode *N, SelectionDAG &DAG) {
+  assert(N->getOpcode() == ISD::FMA);
+  if (!N->getFlags().hasFastMath())
+    return SDValue();
+
+  SDLoc DL(N);
+  SDValue A = N->getOperand(0), B = N->getOperand(1), C = N->getOperand(2);
+  EVT VecVT = N->getValueType(0);
+
+  return DAG.getNode(
+      ISD::INTRINSIC_WO_CHAIN, DL, VecVT,
+      {DAG.getConstant(Intrinsic::wasm_relaxed_madd, DL, MVT::i32), A, B, C});
+}
+
 static SDValue performMulCombine(SDNode *N, SelectionDAG &DAG) {
   assert(N->getOpcode() == ISD::MUL);
   EVT VT = N->getValueType(0);
@@ -3529,6 +3566,10 @@ WebAssemblyTargetLowering::PerformDAGCombine(SDNode *N,
       return AnyAllCombine;
     return performLowerPartialReduction(N, DCI.DAG);
   }
+  case ISD::FADD:
+    return performFAddCombine(N, DCI.DAG);
+  case ISD::FMA:
+    return performFMACombine(N, DCI.DAG);
   case ISD::MUL:
     return performMulCombine(N, DCI.DAG);
   }
diff --git a/llvm/test/CodeGen/WebAssembly/simd-relaxed-fma.ll b/llvm/test/CodeGen/WebAssembly/simd-relaxed-fma.ll
new file mode 100644
index 0000000000000..fe5e8573f12b4
--- /dev/null
+++ b/llvm/test/CodeGen/WebAssembly/simd-relaxed-fma.ll
@@ -0,0 +1,43 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
+
+; RUN: llc < %s -verify-machineinstrs -disable-wasm-fallthrough-return-opt -wasm-disable-explicit-locals -wasm-keep-registers  -mattr=+simd128,+relaxed-simd | FileCheck %s
+target triple = "wasm32"
+define void @fma_seperate(ptr %a, ptr %b, ptr %c, ptr %dest) {
+; CHECK-LABEL: fma_seperate:
+; CHECK:         .functype fma_seperate (i32, i32, i32, i32) -> ()
+; CHECK-NEXT:  # %bb.0: # %entry
+; CHECK-NEXT:    v128.load $push2=, 0($2):p2align=0
+; CHECK-NEXT:    v128.load $push1=, 0($1):p2align=0
+; CHECK-NEXT:    v128.load $push0=, 0($0):p2align=0
+; CHECK-NEXT:    f32x4.relaxed_madd $push3=, $pop2, $pop1, $pop0
+; CHECK-NEXT:    v128.store 0($3):p2align=0, $pop3
+; CHECK-NEXT:    return
+entry:
+  %0 = load <4 x float>, ptr %a, align 1
+  %1 = load <4 x float>, ptr %b, align 1
+  %2 = load <4 x float>, ptr %c, align 1
+  %mul.i = fmul fast <4 x float> %1, %0
+  %add.i = fadd fast <4 x float> %mul.i, %2
+  store <4 x float> %add.i, ptr %dest, align 1
+  ret void
+}
+
+; Function Attrs: mustprogress nofree norecurse nosync nounwind willreturn memory(argmem: readwrite)
+define void @fma_llvm(ptr %a, ptr %b, ptr %c, ptr %dest) {
+; CHECK-LABEL: fma_llvm:
+; CHECK:         .functype fma_llvm (i32, i32, i32, i32) -> ()
+; CHECK-NEXT:  # %bb.0: # %entry
+; CHECK-NEXT:    v128.load $push2=, 0($0):p2align=0
+; CHECK-NEXT:    v128.load $push1=, 0($1):p2align=0
+; CHECK-NEXT:    v128.load $push0=, 0($2):p2align=0
+; CHECK-NEXT:    f32x4.relaxed_madd $push3=, $pop2, $pop1, $pop0
+; CHECK-NEXT:    v128.store 0($3):p2align=0, $pop3
+; CHECK-NEXT:    return
+entry:
+  %0 = load <4 x float>, ptr %a, align 1
+  %1 = load <4 x float>, ptr %b, align 1
+  %2 = load <4 x float>, ptr %c, align 1
+  %fma = tail call fast <4 x float> @llvm.fma.v4f32(<4 x float> %0, <4 x float> %1, <4 x float> %2)
+  store <4 x float> %fma, ptr %dest, align 1
+  ret void
+}

Member

llvmbot commented Jul 8, 2025

@llvm/pr-subscribers-llvm-selectiondag


Comment on lines 3424 to 3425
+  if (!N->getFlags().hasFastMath())
+    return SDValue();
Contributor

This does not require all fast math flags

Contributor Author
badumbatish commented Jul 8, 2025

When I see something such as fadd fast ... in the SelectionDAG input, I'm not sure which flag I should check for.

I will check if (!N->getFlags().hasAllowContract()) instead.

From the LLVM LangRef:

contract
Allow floating-point contraction (e.g. fusing a multiply followed by an addition into a fused multiply-and-add). This does not enable reassociation to form arbitrary contractions. For example, (a*b) + (c*d) + e can not be transformed into (a*b) + ((c*d) + e) to create two fma operations.
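
To make the contraction described above concrete, here is a minimal C++ sketch (not part of this PR) that compares a separately rounded multiply and add against a fused one. The function names and the input values are illustrative assumptions; the inputs are picked so that the exact product a * a needs one more bit than a float can hold.

```cpp
#include <cmath>

// Unfused: the product is rounded to float before the add (two roundings).
// The volatile store keeps the compiler from contracting the expression
// into a fused operation on its own.
float mul_then_add(float a, float b, float c) {
  volatile float product = a * b;
  return product + c;
}

// Fused: std::fma computes a * b + c exactly and rounds once.
float fused_mul_add(float a, float b, float c) {
  return std::fma(a, b, c);
}

// Illustrative inputs: a * a = 1 + 2^-11 + 2^-24, and the 2^-24 term is
// lost when the product alone is rounded to float.
float example_a() { return 1.0f + std::ldexp(1.0f, -12); }
float example_c() { return -(1.0f + std::ldexp(1.0f, -11)); }
```

With these inputs the unfused form returns 0 while the fused form returns 2^-24, which is why a fold of this kind should only happen when the contract flag permits it.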

@@ -475,6 +475,7 @@ struct SDNodeFlags {
   bool hasAllowReassociation() const { return Flags & AllowReassociation; }
   bool hasNoFPExcept() const { return Flags & NoFPExcept; }
   bool hasUnpredictable() const { return Flags & Unpredictable; }
+  bool hasFastMath() const { return Flags & FastMathFlags; }
Contributor

Should not add this. This does not depend on all the flags, and you should only depend on the exact set of flags required for a transform
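
The point about depending only on the required flags can be illustrated with a toy bitmask. This is a sketch under stated assumptions: the enum values below are hypothetical and are not LLVM's actual SDNodeFlags encoding, and it assumes FastMathFlags is a mask covering several bits. Under that assumption, an expression shaped like Flags & FastMathFlags is true when any fast-math bit is set, which is neither "all flags" nor "the contract flag".

```cpp
#include <cstdint>

// Hypothetical flag bits for illustration only; not LLVM's real values.
enum : std::uint32_t {
  AllowReassociation = 1u << 0,
  AllowContract = 1u << 1,
  NoNaNs = 1u << 2,
  // Mask covering every fast-math bit in this toy model.
  FastMathMask = AllowReassociation | AllowContract | NoNaNs,
};

// Same shape as `Flags & FastMathFlags`: true if ANY masked bit is set.
bool hasAnyFastMath(std::uint32_t flags) {
  return (flags & FastMathMask) != 0;
}

// True only if EVERY masked bit is set.
bool hasAllFastMath(std::uint32_t flags) {
  return (flags & FastMathMask) == FastMathMask;
}

// Checks exactly the one flag a multiply-add fold needs.
bool canContract(std::uint32_t flags) {
  return (flags & AllowContract) != 0;
}
```

In this model a node carrying only NoNaNs passes hasAnyFastMath but fails canContract, so a fold guarded by the former would fire without contraction being allowed.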

@@ -0,0 +1,43 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
+
+; RUN: llc < %s -verify-machineinstrs -disable-wasm-fallthrough-return-opt -wasm-disable-explicit-locals -wasm-keep-registers -mattr=+simd128,+relaxed-simd | FileCheck %s
Contributor

Don't need -verify-machineinstrs

Comment on lines 3438 to 3450
+static SDValue performFMACombine(SDNode *N, SelectionDAG &DAG) {
+  assert(N->getOpcode() == ISD::FMA);
+  if (!N->getFlags().hasFastMath())
+    return SDValue();
+
+  SDLoc DL(N);
+  SDValue A = N->getOperand(0), B = N->getOperand(1), C = N->getOperand(2);
+  EVT VecVT = N->getValueType(0);
+
+  return DAG.getNode(
+      ISD::INTRINSIC_WO_CHAIN, DL, VecVT,
+      {DAG.getConstant(Intrinsic::wasm_relaxed_madd, DL, MVT::i32), A, B, C});
+}
Contributor

I doubt it's appropriate to turn an FMA into a not-FMA, regardless of fast math flags. What are the semantics of wasm_relaxed_madd?

Contributor Author

Hi @arsenm, I think wasm_relaxed_madd actually is a fused multiply-add.

This is from https://github.com/WebAssembly/relaxed-simd/blob/main/proposals/relaxed-simd/Overview.md:

Relaxed fused multiply-add and fused negative multiply-add

relaxed f32x4.madd
relaxed f32x4.nmadd
relaxed f64x2.madd
relaxed f64x2.nmadd

All the instructions take 3 operands, a, b, c, perform a * b + c or -(a * b) + c:

relaxed f32x4.madd(a, b, c) = a * b + c
relaxed f32x4.nmadd(a, b, c) = -(a * b) + c
relaxed f64x2.madd(a, b, c) = a * b + c
relaxed f64x2.nmadd(a, b, c) = -(a * b) + c

where:

  • the intermediate a * b is rounded first, and the final result rounded again (for a total of 2 roundings), or
  • the entire expression is evaluated with higher precision and then rounded only once (if supported by hardware).
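
The operand roles in the definition above can be sketched lane by lane in plain C++. This is an illustrative model, not engine or LLVM code: F32x4, relaxedMadd, relaxedNmadd, and foldAddOfMul are hypothetical names. It also suggests one thing to check in the patch: performFAddCombine binds A to the addend and B, C to the factors, so under the a * b + c definition the intrinsic operands would presumably need to be ordered (B, C, A). Treat that reading as an assumption to verify against the final patch.

```cpp
#include <array>
#include <cstddef>

using F32x4 = std::array<float, 4>;

// Lane-wise model of relaxed f32x4.madd(a, b, c) = a * b + c.
// Which rounding an engine picks does not matter for small integer
// inputs, since every intermediate value is exactly representable.
F32x4 relaxedMadd(const F32x4 &a, const F32x4 &b, const F32x4 &c) {
  F32x4 r{};
  for (std::size_t i = 0; i < r.size(); ++i)
    r[i] = a[i] * b[i] + c[i];
  return r;
}

// Lane-wise model of relaxed f32x4.nmadd(a, b, c) = -(a * b) + c.
F32x4 relaxedNmadd(const F32x4 &a, const F32x4 &b, const F32x4 &c) {
  F32x4 r{};
  for (std::size_t i = 0; i < r.size(); ++i)
    r[i] = -(a[i] * b[i]) + c[i];
  return r;
}

// Hypothetical helper: rewrite addend + (x * y) as a single madd.
// The two factors come first and the addend goes last.
F32x4 foldAddOfMul(const F32x4 &addend, const F32x4 &x, const F32x4 &y) {
  return relaxedMadd(x, y, addend);
}
```

Passing the addend as the first operand instead would compute addend * x + y, which differs from addend + x * y for most inputs.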

Comment on lines 16 to 18
+  %0 = load <4 x float>, ptr %a, align 1
+  %1 = load <4 x float>, ptr %b, align 1
+  %2 = load <4 x float>, ptr %c, align 1
Contributor

Better to use argument values instead of loading the sample values

+  %fma = tail call fast <4 x float> @llvm.fma.v4f32(<4 x float> %0, <4 x float> %1, <4 x float> %2)
+  store <4 x float> %fma, ptr %dest, align 1
+  ret void
+}
Contributor

Test scalar cases and other vector widths

- Fix inefficient wasm test case.
- Add scalar test cases and more floating-point types.
- Replace the check for all fast-math flags with a check for allowContract.
- Add support for <8 x f32>.
- Refactor the relaxed-SIMD condition into a separate function.