-
Notifications
You must be signed in to change notification settings - Fork 14.4k
relaxed simd fma #147487
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
relaxed simd fma #147487
Conversation
badumbatish
commented
Jul 8, 2025
- Precommit test for Fma not optimized for wasm relaxed-simd #121311
- [WASM] Optimize fma when relaxed and ffast-math
Fixes llvm#121311, which folds a series of multiply and add to wasm.fma when we have -mrelaxed-simd and -ffast-math. Also attempted to use wasm.fma instead of the built in llvm.fma
@llvm/pr-subscribers-backend-webassembly Author: jjasmine (badumbatish) Changes
Full diff: https://github.com/llvm/llvm-project/pull/147487.diff 3 Files Affected:
diff --git a/llvm/include/llvm/CodeGen/SelectionDAGNodes.h b/llvm/include/llvm/CodeGen/SelectionDAGNodes.h
index a3675eecfea3f..ec566b168bc3d 100644
--- a/llvm/include/llvm/CodeGen/SelectionDAGNodes.h
+++ b/llvm/include/llvm/CodeGen/SelectionDAGNodes.h
@@ -475,6 +475,7 @@ struct SDNodeFlags {
bool hasAllowReassociation() const { return Flags & AllowReassociation; }
bool hasNoFPExcept() const { return Flags & NoFPExcept; }
bool hasUnpredictable() const { return Flags & Unpredictable; }
+ bool hasFastMath() const { return Flags & FastMathFlags; }
bool operator==(const SDNodeFlags &Other) const {
return Flags == Other.Flags;
diff --git a/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp b/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp
index bf2e04caa0a61..ef0146f28aba1 100644
--- a/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp
+++ b/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp
@@ -182,6 +182,12 @@ WebAssemblyTargetLowering::WebAssemblyTargetLowering(
// SIMD-specific configuration
if (Subtarget->hasSIMD128()) {
+ // Enable fma optimization for wasm relaxed simd
+ if (Subtarget->hasRelaxedSIMD()) {
+ setTargetDAGCombine(ISD::FADD);
+ setTargetDAGCombine(ISD::FMA);
+ }
+
// Combine partial.reduce.add before legalization gets confused.
setTargetDAGCombine(ISD::INTRINSIC_WO_CHAIN);
@@ -3412,6 +3418,37 @@ static SDValue performSETCCCombine(SDNode *N,
return SDValue();
}
+static SDValue performFAddCombine(SDNode *N, SelectionDAG &DAG) {
+ assert(N->getOpcode() == ISD::FADD);
+ using namespace llvm::SDPatternMatch;
+ if (!N->getFlags().hasFastMath())
+ return SDValue();
+
+ SDLoc DL(N);
+ SDValue A, B, C;
+ EVT VecVT = N->getValueType(0);
+ if (sd_match(N, m_FAdd(m_Value(A), m_FMul(m_Value(B), m_Value(C)))))
+ return DAG.getNode(
+ ISD::INTRINSIC_WO_CHAIN, DL, VecVT,
+ {DAG.getConstant(Intrinsic::wasm_relaxed_madd, DL, MVT::i32), A, B, C});
+
+ return SDValue();
+}
+
+static SDValue performFMACombine(SDNode *N, SelectionDAG &DAG) {
+ assert(N->getOpcode() == ISD::FMA);
+ if (!N->getFlags().hasFastMath())
+ return SDValue();
+
+ SDLoc DL(N);
+ SDValue A = N->getOperand(0), B = N->getOperand(1), C = N->getOperand(2);
+ EVT VecVT = N->getValueType(0);
+
+ return DAG.getNode(
+ ISD::INTRINSIC_WO_CHAIN, DL, VecVT,
+ {DAG.getConstant(Intrinsic::wasm_relaxed_madd, DL, MVT::i32), A, B, C});
+}
+
static SDValue performMulCombine(SDNode *N, SelectionDAG &DAG) {
assert(N->getOpcode() == ISD::MUL);
EVT VT = N->getValueType(0);
@@ -3529,6 +3566,10 @@ WebAssemblyTargetLowering::PerformDAGCombine(SDNode *N,
return AnyAllCombine;
return performLowerPartialReduction(N, DCI.DAG);
}
+ case ISD::FADD:
+ return performFAddCombine(N, DCI.DAG);
+ case ISD::FMA:
+ return performFMACombine(N, DCI.DAG);
case ISD::MUL:
return performMulCombine(N, DCI.DAG);
}
diff --git a/llvm/test/CodeGen/WebAssembly/simd-relaxed-fma.ll b/llvm/test/CodeGen/WebAssembly/simd-relaxed-fma.ll
new file mode 100644
index 0000000000000..fe5e8573f12b4
--- /dev/null
+++ b/llvm/test/CodeGen/WebAssembly/simd-relaxed-fma.ll
@@ -0,0 +1,43 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
+
+; RUN: llc < %s -verify-machineinstrs -disable-wasm-fallthrough-return-opt -wasm-disable-explicit-locals -wasm-keep-registers -mattr=+simd128,+relaxed-simd | FileCheck %s
+target triple = "wasm32"
+define void @fma_seperate(ptr %a, ptr %b, ptr %c, ptr %dest) {
+; CHECK-LABEL: fma_seperate:
+; CHECK: .functype fma_seperate (i32, i32, i32, i32) -> ()
+; CHECK-NEXT: # %bb.0: # %entry
+; CHECK-NEXT: v128.load $push2=, 0($2):p2align=0
+; CHECK-NEXT: v128.load $push1=, 0($1):p2align=0
+; CHECK-NEXT: v128.load $push0=, 0($0):p2align=0
+; CHECK-NEXT: f32x4.relaxed_madd $push3=, $pop2, $pop1, $pop0
+; CHECK-NEXT: v128.store 0($3):p2align=0, $pop3
+; CHECK-NEXT: return
+entry:
+ %0 = load <4 x float>, ptr %a, align 1
+ %1 = load <4 x float>, ptr %b, align 1
+ %2 = load <4 x float>, ptr %c, align 1
+ %mul.i = fmul fast <4 x float> %1, %0
+ %add.i = fadd fast <4 x float> %mul.i, %2
+ store <4 x float> %add.i, ptr %dest, align 1
+ ret void
+}
+
+; Function Attrs: mustprogress nofree norecurse nosync nounwind willreturn memory(argmem: readwrite)
+define void @fma_llvm(ptr %a, ptr %b, ptr %c, ptr %dest) {
+; CHECK-LABEL: fma_llvm:
+; CHECK: .functype fma_llvm (i32, i32, i32, i32) -> ()
+; CHECK-NEXT: # %bb.0: # %entry
+; CHECK-NEXT: v128.load $push2=, 0($0):p2align=0
+; CHECK-NEXT: v128.load $push1=, 0($1):p2align=0
+; CHECK-NEXT: v128.load $push0=, 0($2):p2align=0
+; CHECK-NEXT: f32x4.relaxed_madd $push3=, $pop2, $pop1, $pop0
+; CHECK-NEXT: v128.store 0($3):p2align=0, $pop3
+; CHECK-NEXT: return
+entry:
+ %0 = load <4 x float>, ptr %a, align 1
+ %1 = load <4 x float>, ptr %b, align 1
+ %2 = load <4 x float>, ptr %c, align 1
+ %fma = tail call fast <4 x float> @llvm.fma.v4f32(<4 x float> %0, <4 x float> %1, <4 x float> %2)
+ store <4 x float> %fma, ptr %dest, align 1
+ ret void
+}
|
@llvm/pr-subscribers-llvm-selectiondag Author: jjasmine (badumbatish) Changes
Full diff: https://github.com/llvm/llvm-project/pull/147487.diff 3 Files Affected:
diff --git a/llvm/include/llvm/CodeGen/SelectionDAGNodes.h b/llvm/include/llvm/CodeGen/SelectionDAGNodes.h
index a3675eecfea3f..ec566b168bc3d 100644
--- a/llvm/include/llvm/CodeGen/SelectionDAGNodes.h
+++ b/llvm/include/llvm/CodeGen/SelectionDAGNodes.h
@@ -475,6 +475,7 @@ struct SDNodeFlags {
bool hasAllowReassociation() const { return Flags & AllowReassociation; }
bool hasNoFPExcept() const { return Flags & NoFPExcept; }
bool hasUnpredictable() const { return Flags & Unpredictable; }
+ bool hasFastMath() const { return Flags & FastMathFlags; }
bool operator==(const SDNodeFlags &Other) const {
return Flags == Other.Flags;
diff --git a/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp b/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp
index bf2e04caa0a61..ef0146f28aba1 100644
--- a/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp
+++ b/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp
@@ -182,6 +182,12 @@ WebAssemblyTargetLowering::WebAssemblyTargetLowering(
// SIMD-specific configuration
if (Subtarget->hasSIMD128()) {
+ // Enable fma optimization for wasm relaxed simd
+ if (Subtarget->hasRelaxedSIMD()) {
+ setTargetDAGCombine(ISD::FADD);
+ setTargetDAGCombine(ISD::FMA);
+ }
+
// Combine partial.reduce.add before legalization gets confused.
setTargetDAGCombine(ISD::INTRINSIC_WO_CHAIN);
@@ -3412,6 +3418,37 @@ static SDValue performSETCCCombine(SDNode *N,
return SDValue();
}
+static SDValue performFAddCombine(SDNode *N, SelectionDAG &DAG) {
+ assert(N->getOpcode() == ISD::FADD);
+ using namespace llvm::SDPatternMatch;
+ if (!N->getFlags().hasFastMath())
+ return SDValue();
+
+ SDLoc DL(N);
+ SDValue A, B, C;
+ EVT VecVT = N->getValueType(0);
+ if (sd_match(N, m_FAdd(m_Value(A), m_FMul(m_Value(B), m_Value(C)))))
+ return DAG.getNode(
+ ISD::INTRINSIC_WO_CHAIN, DL, VecVT,
+ {DAG.getConstant(Intrinsic::wasm_relaxed_madd, DL, MVT::i32), A, B, C});
+
+ return SDValue();
+}
+
+static SDValue performFMACombine(SDNode *N, SelectionDAG &DAG) {
+ assert(N->getOpcode() == ISD::FMA);
+ if (!N->getFlags().hasFastMath())
+ return SDValue();
+
+ SDLoc DL(N);
+ SDValue A = N->getOperand(0), B = N->getOperand(1), C = N->getOperand(2);
+ EVT VecVT = N->getValueType(0);
+
+ return DAG.getNode(
+ ISD::INTRINSIC_WO_CHAIN, DL, VecVT,
+ {DAG.getConstant(Intrinsic::wasm_relaxed_madd, DL, MVT::i32), A, B, C});
+}
+
static SDValue performMulCombine(SDNode *N, SelectionDAG &DAG) {
assert(N->getOpcode() == ISD::MUL);
EVT VT = N->getValueType(0);
@@ -3529,6 +3566,10 @@ WebAssemblyTargetLowering::PerformDAGCombine(SDNode *N,
return AnyAllCombine;
return performLowerPartialReduction(N, DCI.DAG);
}
+ case ISD::FADD:
+ return performFAddCombine(N, DCI.DAG);
+ case ISD::FMA:
+ return performFMACombine(N, DCI.DAG);
case ISD::MUL:
return performMulCombine(N, DCI.DAG);
}
diff --git a/llvm/test/CodeGen/WebAssembly/simd-relaxed-fma.ll b/llvm/test/CodeGen/WebAssembly/simd-relaxed-fma.ll
new file mode 100644
index 0000000000000..fe5e8573f12b4
--- /dev/null
+++ b/llvm/test/CodeGen/WebAssembly/simd-relaxed-fma.ll
@@ -0,0 +1,43 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
+
+; RUN: llc < %s -verify-machineinstrs -disable-wasm-fallthrough-return-opt -wasm-disable-explicit-locals -wasm-keep-registers -mattr=+simd128,+relaxed-simd | FileCheck %s
+target triple = "wasm32"
+define void @fma_seperate(ptr %a, ptr %b, ptr %c, ptr %dest) {
+; CHECK-LABEL: fma_seperate:
+; CHECK: .functype fma_seperate (i32, i32, i32, i32) -> ()
+; CHECK-NEXT: # %bb.0: # %entry
+; CHECK-NEXT: v128.load $push2=, 0($2):p2align=0
+; CHECK-NEXT: v128.load $push1=, 0($1):p2align=0
+; CHECK-NEXT: v128.load $push0=, 0($0):p2align=0
+; CHECK-NEXT: f32x4.relaxed_madd $push3=, $pop2, $pop1, $pop0
+; CHECK-NEXT: v128.store 0($3):p2align=0, $pop3
+; CHECK-NEXT: return
+entry:
+ %0 = load <4 x float>, ptr %a, align 1
+ %1 = load <4 x float>, ptr %b, align 1
+ %2 = load <4 x float>, ptr %c, align 1
+ %mul.i = fmul fast <4 x float> %1, %0
+ %add.i = fadd fast <4 x float> %mul.i, %2
+ store <4 x float> %add.i, ptr %dest, align 1
+ ret void
+}
+
+; Function Attrs: mustprogress nofree norecurse nosync nounwind willreturn memory(argmem: readwrite)
+define void @fma_llvm(ptr %a, ptr %b, ptr %c, ptr %dest) {
+; CHECK-LABEL: fma_llvm:
+; CHECK: .functype fma_llvm (i32, i32, i32, i32) -> ()
+; CHECK-NEXT: # %bb.0: # %entry
+; CHECK-NEXT: v128.load $push2=, 0($0):p2align=0
+; CHECK-NEXT: v128.load $push1=, 0($1):p2align=0
+; CHECK-NEXT: v128.load $push0=, 0($2):p2align=0
+; CHECK-NEXT: f32x4.relaxed_madd $push3=, $pop2, $pop1, $pop0
+; CHECK-NEXT: v128.store 0($3):p2align=0, $pop3
+; CHECK-NEXT: return
+entry:
+ %0 = load <4 x float>, ptr %a, align 1
+ %1 = load <4 x float>, ptr %b, align 1
+ %2 = load <4 x float>, ptr %c, align 1
+ %fma = tail call fast <4 x float> @llvm.fma.v4f32(<4 x float> %0, <4 x float> %1, <4 x float> %2)
+ store <4 x float> %fma, ptr %dest, align 1
+ ret void
+}
|
if (!N->getFlags().hasFastMath()) | ||
return SDValue(); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This does not require all fast math flags
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm not sure when I see sth such as fadd fast ....
in the selection dag input, which flag should i check for?
I will perform if (!N->getFlags().hasAllowContract())
,
from llvm lang ref
contract
Allow floating-point contraction (e.g. fusing a multiply followed by an addition into a fused multiply-and-add). This does not enable reassociation to form arbitrary contractions. For example, (a*b) + (c*d) + e can not be transformed into (a*b) + ((c*d) + e) to create two fma operations.
@@ -475,6 +475,7 @@ struct SDNodeFlags { | |||
bool hasAllowReassociation() const { return Flags & AllowReassociation; } | |||
bool hasNoFPExcept() const { return Flags & NoFPExcept; } | |||
bool hasUnpredictable() const { return Flags & Unpredictable; } | |||
bool hasFastMath() const { return Flags & FastMathFlags; } |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Should not add this. This does not depend on all the flags, and you should only depend on the exact set of flags required for a transform
@@ -0,0 +1,43 @@ | |||
; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5 | |||
|
|||
; RUN: llc < %s -verify-machineinstrs -disable-wasm-fallthrough-return-opt -wasm-disable-explicit-locals -wasm-keep-registers -mattr=+simd128,+relaxed-simd | FileCheck %s |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Don't need -verify-machineinstrs
static SDValue performFMACombine(SDNode *N, SelectionDAG &DAG) { | ||
assert(N->getOpcode() == ISD::FMA); | ||
if (!N->getFlags().hasFastMath()) | ||
return SDValue(); | ||
|
||
SDLoc DL(N); | ||
SDValue A = N->getOperand(0), B = N->getOperand(1), C = N->getOperand(2); | ||
EVT VecVT = N->getValueType(0); | ||
|
||
return DAG.getNode( | ||
ISD::INTRINSIC_WO_CHAIN, DL, VecVT, | ||
{DAG.getConstant(Intrinsic::wasm_relaxed_madd, DL, MVT::i32), A, B, C}); | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I doubt it's appropriate to turn an FMA into a not-FMA, regardless of fast math flags. What are the semantics of wasm_relaxed_madd?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
hi @arsenm, I think wasm_relaxed_madd is actually fused multiply add.
This is from https://github.com/WebAssembly/relaxed-simd/blob/main/proposals/relaxed-simd/Overview.md :
Relaxed fused multiply-add and fused negative multiply-add
relaxed f32x4.madd
relaxed f32x4.nmadd
relaxed f64x2.madd
relaxed f64x2.nmadd
All the instructions take 3 operands, a, b, c, perform a * b + c or -(a * b) + c:
relaxed f32x4.madd(a, b, c) = a * b + c
relaxed f32x4.nmadd(a, b, c) = -(a * b) + c
relaxed f64x2.madd(a, b, c) = a * b + c
relaxed f64x2.nmadd(a, b, c) = -(a * b) + c
where:
the intermediate a * b is be rounded first, and the final result rounded again (for a total of 2 roundings), or
the entire expression evaluated with higher precision and then only rounded once (if supported by hardware).
%0 = load <4 x float>, ptr %a, align 1 | ||
%1 = load <4 x float>, ptr %b, align 1 | ||
%2 = load <4 x float>, ptr %c, align 1 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Better to use argument values instead of loading the sample values
%fma = tail call fast <4 x float> @llvm.fma.v4f32(<4 x float> %0, <4 x float> %1, <4 x float> %2) | ||
store <4 x float> %fma, ptr %dest, align 1 | ||
ret void | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Test scalar cases and other vector widths
- Fix inefficient wasm test case. - Added scalar test case and more floating type. - Remove total ffast checking -> allowContract
- Added support for <8 x f32>. - Refactored out condition for relaxed simd to a seperate function.