#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/LLVMIR/ROCDLDialect.h"
+#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
+#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
@@ -32,6 +35,7 @@ using namespace mlir::amdgpu;
namespace {
// Define commonly used chipsets versions for convenience.
constexpr Chipset kGfx942 = Chipset(9, 4, 2);
+constexpr Chipset kGfx950 = Chipset(9, 5, 0);

struct ArithToAMDGPUConversionPass final
    : impl::ArithToAMDGPUConversionPassBase<ArithToAMDGPUConversionPass> {
@@ -73,6 +77,28 @@ struct TruncfToFloat16RewritePattern final
                                PatternRewriter &rewriter) const override;
};

+struct ScalingExtFRewritePattern final
+    : OpRewritePattern<arith::ScalingExtFOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  ScalingExtFRewritePattern(MLIRContext *ctx)
+      : OpRewritePattern::OpRewritePattern(ctx) {}
+
+  LogicalResult matchAndRewrite(arith::ScalingExtFOp op,
+                                PatternRewriter &rewriter) const override;
+};
+
+struct ScalingTruncFRewritePattern final
+    : OpRewritePattern<arith::ScalingTruncFOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  ScalingTruncFRewritePattern(MLIRContext *ctx)
+      : OpRewritePattern::OpRewritePattern(ctx) {}
+
+  LogicalResult matchAndRewrite(arith::ScalingTruncFOp op,
+                                PatternRewriter &rewriter) const override;
+};
+
} // end namespace

static bool isSupportedF8(Type elementType, Chipset chipset) {
@@ -395,6 +421,247 @@ LogicalResult TruncfToFloat16RewritePattern::matchAndRewrite(
  return success();
}

+/// Get the broadcasted / splatted value for a chain of ops.
+static Value getOriginalVectorValue(Value value) {
+  Value current = value;
+  while (Operation *definingOp = current.getDefiningOp()) {
+    bool skipOp = llvm::TypeSwitch<Operation *, bool>(definingOp)
+                      .Case<vector::ShapeCastOp>([&current](auto op) {
+                        current = op.getSource();
+                        return true;
+                      })
+                      .Case<vector::BroadcastOp>([&current](auto op) {
+                        current = op.getSource();
+                        return false;
+                      })
+                      .Case<vector::SplatOp>([&current](auto op) {
+                        current = op.getInput();
+                        return false;
+                      })
+                      .Default([](Operation *) { return false; });
+
+    if (!skipOp) {
+      break;
+    }
+  }
+  return current;
+}
+
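+// Lower arith::ScalingExtFOp by tiling the input into blocks that share a
+// single scale value and extending them two elements at a time with
+// amdgpu::ScaledExtPackedOp.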
+LogicalResult
+ScalingExtFRewritePattern::matchAndRewrite(arith::ScalingExtFOp op,
+                                           PatternRewriter &rewriter) const {
+  Location loc = op.getLoc();
+  constexpr int64_t opWidth = 2;
+
+  Value in = op.getIn();
+  Value scale = op.getScale();
+  Value out = op.getOut();
+
+  Type f32 = rewriter.getF32Type();
+  Type inType = getElementTypeOrSelf(in);
+  Type scaleType = getElementTypeOrSelf(scale);
+  Type outType = getElementTypeOrSelf(out);
+
+  VectorType outVecType = dyn_cast<VectorType>(out.getType());
+  VectorType scaleVecType = dyn_cast<VectorType>(scale.getType());
+
+  if (outVecType && outVecType.isScalable())
+    return failure();
+
+  Type scaleF32Type =
+      scaleVecType ? VectorType::get(scaleVecType.getShape(), f32) : f32;
+  if (scaleType.getIntOrFloatBitWidth() < 32)
+    scale = rewriter.create<arith::ExtFOp>(loc, scaleF32Type, scale);
+  else if (scaleType.getIntOrFloatBitWidth() > 32)
+    scale = rewriter.create<arith::TruncFOp>(loc, scaleF32Type, scale);
+
+  VectorType extScaleResultType = VectorType::get(opWidth, outType);
+
+  if (!outVecType) {
+    Value inCast =
+        rewriter.create<vector::SplatOp>(loc, VectorType::get(1, inType), in);
+    // TODO: replace this with non-packed ScaledExtOp
+    Value scaleExt = rewriter.create<amdgpu::ScaledExtPackedOp>(
+        loc, extScaleResultType, inCast, scale, 0);
+    scaleExt = rewriter.replaceOpWithNewOp<vector::ExtractOp>(op, scaleExt, 0);
+    return success();
+  }
+
+  VectorType inVecType = cast<VectorType>(in.getType());
+  Value origScale = getOriginalVectorValue(op.getScale());
+
+  ArrayRef<int64_t> inShape = inVecType.getShape();
+  SmallVector<int64_t> originalScaleShape;
+  if (auto origScaleVecType = dyn_cast<VectorType>(origScale.getType()))
+    llvm::append_range(originalScaleShape, origScaleVecType.getShape());
+
+  originalScaleShape.insert(originalScaleShape.end(),
+                            inShape.size() - originalScaleShape.size(), 1);
+
+  auto maybeRatio = computeShapeRatio(inShape, originalScaleShape);
+  assert(maybeRatio &&
+         "failed to derive block size from broadcast or splat operation");
+
+  SmallVector<int64_t> ratio =
+      maybeRatio.value_or(SmallVector<int64_t>(inShape.size(), 1));
+
+  int64_t blockSize = computeProduct(ratio);
+
+  Value zero = rewriter.create<arith::ConstantOp>(
+      loc, outType, rewriter.getFloatAttr(outType, 0.0));
+  Value result = rewriter.createOrFold<vector::SplatOp>(loc, outVecType, zero);
+
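+  // Each iteration processes one block of `ratio` input elements that all
+  // share the scale value stored at `offsets`.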
+  for (SmallVector<int64_t> offsets : StaticTileOffsetRange(inShape, ratio)) {
+    SmallVector<int64_t> strides(offsets.size(), 1);
+    Value block = rewriter.create<vector::ExtractStridedSliceOp>(
+        loc, in, offsets, ratio, strides);
+    VectorType block1DType = VectorType::get(blockSize, inType);
+    Value block1D =
+        rewriter.create<vector::ShapeCastOp>(loc, block1DType, block);
+    Value uniformScale =
+        rewriter.create<vector::ExtractOp>(loc, scale, offsets);
+
+    VectorType blockResultType = VectorType::get(blockSize, outType);
+    Value blockResult =
+        rewriter.createOrFold<vector::SplatOp>(loc, blockResultType, zero);
+
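+    // Extend the block opWidth elements at a time; a trailing slice narrower
+    // than opWidth keeps only the valid lanes of the packed result.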
+    for (int64_t i = 0, sliceWidth = std::min(opWidth, blockSize - i);
+         i < blockSize;
+         i += sliceWidth, sliceWidth = std::min(opWidth, blockSize - i)) {
+      Value slice = rewriter.create<vector::ExtractStridedSliceOp>(
+          loc, block1D, i, sliceWidth, 1);
+      // TODO: replace this with non-packed ScaledExtOp for sliceWidth == 1
+      Value scaleExt = rewriter.create<amdgpu::ScaledExtPackedOp>(
+          loc, extScaleResultType, slice, uniformScale, 0);
+      if (sliceWidth != opWidth)
+        scaleExt = rewriter.create<vector::ExtractStridedSliceOp>(
+            loc, scaleExt, 0, sliceWidth, 1);
+      blockResult = rewriter.create<vector::InsertStridedSliceOp>(
+          loc, scaleExt, blockResult, i, 1);
+    }
+
+    VectorType resultType = VectorType::get(ratio, outType);
+    Value cast =
+        rewriter.create<vector::ShapeCastOp>(loc, resultType, blockResult);
+    result = rewriter.create<vector::InsertStridedSliceOp>(loc, cast, result,
+                                                           offsets, strides);
+  }
+
+  rewriter.replaceOp(op, result);
+
+  return success();
+}
+
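+// Lower arith::ScalingTruncFOp the same way: blocks that share a scale are
+// narrowed through amdgpu::PackedScaledTruncOp, which produces a 32-bit
+// packed result vector.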
+LogicalResult
+ScalingTruncFRewritePattern::matchAndRewrite(arith::ScalingTruncFOp op,
+                                             PatternRewriter &rewriter) const {
+  Location loc = op.getLoc();
+  constexpr int64_t opWidth = 2;
+
+  Value in = op.getIn();
+  Value scale = op.getScale();
+  Value out = op.getOut();
+
+  Type f32 = rewriter.getF32Type();
+  Type inType = getElementTypeOrSelf(in);
+  Type scaleType = getElementTypeOrSelf(scale);
+  Type outType = getElementTypeOrSelf(out);
+
+  VectorType outVecType = dyn_cast<VectorType>(out.getType());
+  VectorType scaleVecType = dyn_cast<VectorType>(scale.getType());
+
+  if (outVecType && outVecType.isScalable())
+    return failure();
+
+  Type scaleF32Type =
+      scaleVecType ? VectorType::get(scaleVecType.getShape(), f32) : f32;
+  if (scaleType.getIntOrFloatBitWidth() < 32)
+    scale = rewriter.create<arith::ExtFOp>(loc, scaleF32Type, scale);
+  else if (scaleType.getIntOrFloatBitWidth() > 32)
+    scale = rewriter.create<arith::TruncFOp>(loc, scaleF32Type, scale);
+
+  Value zero = rewriter.create<arith::ConstantOp>(
+      loc, outType, rewriter.getFloatAttr(outType, 0.0));
+  unsigned numPackedElem = 32 / outType.getIntOrFloatBitWidth();
+  VectorType truncScaleResultType = VectorType::get(numPackedElem, outType);
+
+  if (!outVecType) {
+    Type inVecType = VectorType::get(1, inType);
+    Value inCast = rewriter.create<vector::SplatOp>(loc, inVecType, in);
+    // TODO: replace this with non-packed ScaledTruncOp
+    Value scaleTrunc = rewriter.create<amdgpu::PackedScaledTruncOp>(
+        loc, truncScaleResultType, inCast, scale, 0, /*existing=*/nullptr);
+    scaleTrunc =
+        rewriter.replaceOpWithNewOp<vector::ExtractOp>(op, scaleTrunc, 0);
+    return success();
+  }
+
+  VectorType inVecType = cast<VectorType>(in.getType());
+  Value origScale = getOriginalVectorValue(op.getScale());
+
+  ArrayRef<int64_t> inShape = inVecType.getShape();
+  SmallVector<int64_t> originalScaleShape;
+  if (auto origScaleVecType = dyn_cast<VectorType>(origScale.getType()))
+    llvm::append_range(originalScaleShape, origScaleVecType.getShape());
+
+  originalScaleShape.insert(originalScaleShape.end(),
+                            inShape.size() - originalScaleShape.size(), 1);
+
+  auto maybeRatio = computeShapeRatio(inShape, originalScaleShape);
+  assert(maybeRatio &&
+         "failed to derive block size from broadcast or splat operation");
+
+  SmallVector<int64_t> ratio =
+      maybeRatio.value_or(SmallVector<int64_t>(inShape.size(), 1));
+
+  int64_t blockSize = computeProduct(ratio);
+
+  Value result = rewriter.createOrFold<vector::SplatOp>(loc, outVecType, zero);
+
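+  // As in the extension pattern, each iteration processes one block of
+  // `ratio` input elements that share the scale value stored at `offsets`.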
+  for (SmallVector<int64_t> offsets : StaticTileOffsetRange(inShape, ratio)) {
+    SmallVector<int64_t> strides(offsets.size(), 1);
+    Value block = rewriter.create<vector::ExtractStridedSliceOp>(
+        loc, in, offsets, ratio, strides);
+    VectorType block1DType = VectorType::get(blockSize, inType);
+    Value block1D =
+        rewriter.create<vector::ShapeCastOp>(loc, block1DType, block);
+    Value uniformScale =
+        rewriter.create<vector::ExtractOp>(loc, scale, offsets);
+
+    VectorType blockResultType = VectorType::get(blockSize, outType);
+    Value blockResult =
+        rewriter.createOrFold<vector::SplatOp>(loc, blockResultType, zero);
+
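+    // Truncate opWidth elements per packed op. The packed result always has
+    // numPackedElem lanes, so when that differs from opWidth it is narrowed
+    // back to sliceWidth before being inserted into the block result.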
+    for (int64_t i = 0, sliceWidth = std::min(opWidth, blockSize - i);
+         i < blockSize;
+         i += sliceWidth, sliceWidth = std::min(opWidth, blockSize - i)) {
+      Value slice = rewriter.create<vector::ExtractStridedSliceOp>(
+          loc, block1D, i, sliceWidth, 1);
+      // TODO: replace this with non-packed ScaledTruncOp for sliceWidth == 1
+      Value scaleTrunc = rewriter.create<amdgpu::PackedScaledTruncOp>(
+          loc, truncScaleResultType, slice, uniformScale, 0,
+          /*existing=*/nullptr);
+      int64_t packedWidth =
+          cast<VectorType>(scaleTrunc.getType()).getNumElements();
+      if (packedWidth != opWidth)
+        scaleTrunc = rewriter.create<vector::ExtractStridedSliceOp>(
+            loc, scaleTrunc, 0, sliceWidth, 1);
+      blockResult = rewriter.create<vector::InsertStridedSliceOp>(
+          loc, scaleTrunc, blockResult, i, 1);
+    }
+
+    VectorType resultType = VectorType::get(ratio, outType);
+    Value cast =
+        rewriter.create<vector::ShapeCastOp>(loc, resultType, blockResult);
+    result = rewriter.create<vector::InsertStridedSliceOp>(loc, cast, result,
+                                                           offsets, strides);
+  }
+
+  rewriter.replaceOp(op, result);
+
+  return success();
+}
+
void mlir::arith::populateArithToAMDGPUConversionPatterns(
    RewritePatternSet &patterns, bool convertFP8Arithmetic,
    bool saturateFP8Truncf, bool allowPackedF16Rtz, Chipset chipset) {
@@ -406,6 +673,11 @@ void mlir::arith::populateArithToAMDGPUConversionPatterns(
  }
  if (allowPackedF16Rtz)
    patterns.add<TruncfToFloat16RewritePattern>(patterns.getContext());
+
+  if (chipset >= kGfx950) {
+    patterns.add<ScalingExtFRewritePattern>(patterns.getContext());
+    patterns.add<ScalingTruncFRewritePattern>(patterns.getContext());
+  }
}

void ArithToAMDGPUConversionPass::runOnOperation() {