@@ -405,6 +405,18 @@ multiclass F3_fma_component<string OpcStr, SDNode OpNode> {
405
405
!strconcat(OpcStr, ".f32 \t$dst, $a, $b;"),
406
406
[(set f32:$dst, (OpNode f32:$a, fpimm:$b))]>,
407
407
Requires<[allowFMA]>;
408
// Packed two-lane f32 (v2f32) register-register form with the ".ftz"
// modifier. Both source operands and the result are Int64Regs. The asm string
// is OpcStr followed by ".ftz.f32x2" (no ".rn" modifier in this variant).
// Selected for (OpNode v2f32, v2f32) when allowFMA and doF32FTZ hold.
// hasF32x2Instructions is required as well, matching the ".rn.f32x2" defs of
// this multiclass, so the packed opcode is never selected on a subtarget that
// does not advertise f32x2 support.
def f32x2rr_ftz :
  NVPTXInst<(outs Int64Regs:$dst),
            (ins Int64Regs:$a, Int64Regs:$b),
            !strconcat(OpcStr, ".ftz.f32x2 \t$dst, $a, $b;"),
            [(set v2f32:$dst, (OpNode v2f32:$a, v2f32:$b))]>,
  Requires<[hasF32x2Instructions, allowFMA, doF32FTZ]>;
414
// Packed two-lane f32 (v2f32) register-register form without ".ftz".
// Both source operands and the result are Int64Regs. The asm string is
// OpcStr followed by ".f32x2" (no ".rn" modifier in this variant).
// Selected for (OpNode v2f32, v2f32) when allowFMA holds.
// hasF32x2Instructions is required as well, matching the ".rn.f32x2" defs of
// this multiclass, so the packed opcode is never selected on a subtarget that
// does not advertise f32x2 support.
def f32x2rr :
  NVPTXInst<(outs Int64Regs:$dst),
            (ins Int64Regs:$a, Int64Regs:$b),
            !strconcat(OpcStr, ".f32x2 \t$dst, $a, $b;"),
            [(set v2f32:$dst, (OpNode v2f32:$a, v2f32:$b))]>,
  Requires<[hasF32x2Instructions, allowFMA]>;
408
420
409
421
def f16rr_ftz :
410
422
NVPTXInst<(outs Int16Regs:$dst),
@@ -529,6 +541,18 @@ multiclass F3_fma_component<string OpcStr, SDNode OpNode> {
529
541
!strconcat(OpcStr, ".rn.bf16x2 \t$dst, $a, $b;"),
530
542
[(set v2bf16:$dst, (OpNode v2bf16:$a, v2bf16:$b))]>,
531
543
Requires<[hasBF16Math, noFMA]>;
544
// v2f32 register-register form that spells out ".rn" together with ".ftz".
// Result and both inputs are Int64Regs. Only selectable when
// hasF32x2Instructions, noFMA and doF32FTZ are all satisfied.
def _rnf32x2rr_ftz :
  NVPTXInst<(outs Int64Regs:$dst),
            (ins Int64Regs:$a, Int64Regs:$b),
            OpcStr # ".rn.ftz.f32x2 \t$dst, $a, $b;",
            [(set v2f32:$dst, (OpNode v2f32:$a, v2f32:$b))]>,
  Requires<[hasF32x2Instructions, noFMA, doF32FTZ]>;
550
// v2f32 register-register form that spells out ".rn" and omits ".ftz".
// Result and both inputs are Int64Regs. Only selectable when
// hasF32x2Instructions and noFMA are both satisfied.
def _rnf32x2rr :
  NVPTXInst<(outs Int64Regs:$dst),
            (ins Int64Regs:$a, Int64Regs:$b),
            OpcStr # ".rn.f32x2 \t$dst, $a, $b;",
            [(set v2f32:$dst, (OpNode v2f32:$a, v2f32:$b))]>,
  Requires<[hasF32x2Instructions, noFMA]>;
532
556
}
533
557
534
558
// Template for operations which take two f32 or f64 operands. Provides three
@@ -1432,6 +1456,13 @@ multiclass FMA_BF16<string OpcStr, ValueType T, RegisterClass RC, Predicate Pred
1432
1456
Requires<[hasBF16Math, Pred]>;
1433
1457
}
1434
1458
1459
// Three-operand fma on packed v2f32 values.
//   OpcStr - mnemonic prefix; ".f32x2" and the operand list are appended to it.
//   Pred   - predicate required in addition to hasF32x2Instructions.
// The result and all three sources are Int64Regs; the selection pattern is
// (fma v2f32:$a, v2f32:$b, v2f32:$c).
class FMA_F32x2<string OpcStr, Predicate Pred>
  : NVPTXInst<(outs Int64Regs:$res),
              (ins Int64Regs:$a, Int64Regs:$b, Int64Regs:$c),
              !strconcat(OpcStr, ".f32x2 \t$res, $a, $b, $c;"),
              [(set v2f32:$res, (fma v2f32:$a, v2f32:$b, v2f32:$c))]>,
    Requires<[hasF32x2Instructions, Pred]>;
1465
+
1435
1466
// FMA_F16 instantiations: scalar f16 in Int16Regs (".ftz" and plain) and
// packed v2f16 in Int32Regs (".ftz"). The *_ftz records pass doF32FTZ as the
// last argument; the plain record passes True.
// NOTE(review): FMA_F16 itself is not visible here — presumably the last
// argument is an extra predicate, as in FMA_BF16; confirm.
defm FMA16_ftz : FMA_F16<"fma.rn.ftz.f16", f16, Int16Regs, doF32FTZ>;
1436
1467
defm FMA16 : FMA_F16<"fma.rn.f16", f16, Int16Regs, True>;
1437
1468
defm FMA16x2_ftz : FMA_F16<"fma.rn.ftz.f16x2", v2f16, Int32Regs, doF32FTZ>;
@@ -1440,6 +1471,8 @@ defm BFMA16 : FMA_BF16<"fma.rn.bf16", bf16, Int16Regs, True>;
1440
1471
// FMA instantiations for packed bf16, scalar f32, packed f32x2 and f64.
// FMA32x2_ftz / FMA32x2 use `def` rather than `defm` because FMA_F32x2 is a
// class, not a multiclass; its first argument is only the mnemonic prefix
// (the class appends ".f32x2" itself).
// Each *_ftz record is listed before its non-ftz twin and passes doF32FTZ,
// while the twin passes True.
// NOTE(review): presumably this ordering decides which pattern wins when both
// predicates hold — confirm before reordering.
defm BFMA16x2 : FMA_BF16<"fma.rn.bf16x2", v2bf16, Int32Regs, True>;
1441
1472
defm FMA32_ftz : FMA<"fma.rn.ftz.f32", Float32Regs, f32imm, doF32FTZ>;
1442
1473
defm FMA32 : FMA<"fma.rn.f32", Float32Regs, f32imm, True>;
1474
+ def FMA32x2_ftz : FMA_F32x2<"fma.rn.ftz", doF32FTZ>;
1475
+ def FMA32x2 : FMA_F32x2<"fma.rn", True>;
1443
1476
defm FMA64 : FMA<"fma.rn.f64", Float64Regs, f64imm, True>;
1444
1477
1445
1478
// sin/cos
0 commit comments