Commit 6f291cb

[mlir][amdgpu] Add conversion from arith.scaling_extf / arith.scaling_truncf to amdgpu (#146372)
- Add conversion from arith.scaling_extf to amdgpu.scaled_ext_packed
- Add conversion from arith.scaling_truncf to amdgpu.packed_scaled_trunc
1 parent 08ac3b3 commit 6f291cb
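The new patterns are gated on gfx950 in populateArithToAMDGPUConversionPatterns (see the diff below). As a rough illustration of how a downstream pipeline might pick them up, here is a minimal sketch; it assumes the standard greedy rewrite driver entry point (applyPatternsGreedily) and the amdgpu::Chipset::parse helper, and the wrapper name convertScalingOps plus the flag values are made up for the example, not part of this commit.

// Sketch only: populate and apply the Arith->AMDGPU patterns for a gfx950 target.
// convertScalingOps is a hypothetical helper, not part of this commit.
#include "mlir/Conversion/ArithToAMDGPU/ArithToAMDGPU.h"
#include "mlir/Dialect/AMDGPU/Utils/Chipset.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

using namespace mlir;

static LogicalResult convertScalingOps(Operation *root) {
  FailureOr<amdgpu::Chipset> chipset = amdgpu::Chipset::parse("gfx950");
  if (failed(chipset))
    return failure();
  RewritePatternSet patterns(root->getContext());
  arith::populateArithToAMDGPUConversionPatterns(
      patterns, /*convertFP8Arithmetic=*/true, /*saturateFP8Truncf=*/false,
      /*allowPackedF16Rtz=*/false, *chipset);
  return applyPatternsGreedily(root, std::move(patterns));
}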

File tree: 3 files changed, +727 -0 lines changed


mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp

Lines changed: 272 additions & 0 deletions
@@ -14,7 +14,10 @@
 #include "mlir/Dialect/Arith/Utils/Utils.h"
 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
 #include "mlir/Dialect/LLVMIR/ROCDLDialect.h"
+#include "mlir/Dialect/Utils/IndexingUtils.h"
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
+#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/IR/TypeUtilities.h"
@@ -32,6 +35,7 @@ using namespace mlir::amdgpu;
 namespace {
 // Define commonly used chipsets versions for convenience.
 constexpr Chipset kGfx942 = Chipset(9, 4, 2);
+constexpr Chipset kGfx950 = Chipset(9, 5, 0);

 struct ArithToAMDGPUConversionPass final
     : impl::ArithToAMDGPUConversionPassBase<ArithToAMDGPUConversionPass> {
@@ -73,6 +77,28 @@ struct TruncfToFloat16RewritePattern final
                                 PatternRewriter &rewriter) const override;
 };

+struct ScalingExtFRewritePattern final
+    : OpRewritePattern<arith::ScalingExtFOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  ScalingExtFRewritePattern(MLIRContext *ctx)
+      : OpRewritePattern::OpRewritePattern(ctx) {}
+
+  LogicalResult matchAndRewrite(arith::ScalingExtFOp op,
+                                PatternRewriter &rewriter) const override;
+};
+
+struct ScalingTruncFRewritePattern final
+    : OpRewritePattern<arith::ScalingTruncFOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  ScalingTruncFRewritePattern(MLIRContext *ctx)
+      : OpRewritePattern::OpRewritePattern(ctx) {}
+
+  LogicalResult matchAndRewrite(arith::ScalingTruncFOp op,
+                                PatternRewriter &rewriter) const override;
+};
+
 } // end namespace

 static bool isSupportedF8(Type elementType, Chipset chipset) {
@@ -395,6 +421,247 @@ LogicalResult TruncfToFloat16RewritePattern::matchAndRewrite(
   return success();
 }

+/// Get the broadcasted / splatted value for a chain of ops.
+static Value getOriginalVectorValue(Value value) {
+  Value current = value;
+  while (Operation *definingOp = current.getDefiningOp()) {
+    bool skipOp = llvm::TypeSwitch<Operation *, bool>(definingOp)
+                      .Case<vector::ShapeCastOp>([&current](auto op) {
+                        current = op.getSource();
+                        return true;
+                      })
+                      .Case<vector::BroadcastOp>([&current](auto op) {
+                        current = op.getSource();
+                        return false;
+                      })
+                      .Case<vector::SplatOp>([&current](auto op) {
+                        current = op.getInput();
+                        return false;
+                      })
+                      .Default([](Operation *) { return false; });
+
+    if (!skipOp) {
+      break;
+    }
+  }
+  return current;
+}
+
+LogicalResult
+ScalingExtFRewritePattern::matchAndRewrite(arith::ScalingExtFOp op,
+                                           PatternRewriter &rewriter) const {
+  Location loc = op.getLoc();
+  constexpr int64_t opWidth = 2;
+
+  Value in = op.getIn();
+  Value scale = op.getScale();
+  Value out = op.getOut();
+
+  Type f32 = rewriter.getF32Type();
+  Type inType = getElementTypeOrSelf(in);
+  Type scaleType = getElementTypeOrSelf(scale);
+  Type outType = getElementTypeOrSelf(out);
+
+  VectorType outVecType = dyn_cast<VectorType>(out.getType());
+  VectorType scaleVecType = dyn_cast<VectorType>(scale.getType());
+
+  if (outVecType && outVecType.isScalable())
+    return failure();
+
+  Type scaleF32Type =
+      scaleVecType ? VectorType::get(scaleVecType.getShape(), f32) : f32;
+  if (scaleType.getIntOrFloatBitWidth() < 32)
+    scale = rewriter.create<arith::ExtFOp>(loc, scaleF32Type, scale);
+  else if (scaleType.getIntOrFloatBitWidth() > 32)
+    scale = rewriter.create<arith::TruncFOp>(loc, scaleF32Type, scale);
+
+  VectorType extScaleResultType = VectorType::get(opWidth, outType);
+
+  if (!outVecType) {
+    Value inCast =
+        rewriter.create<vector::SplatOp>(loc, VectorType::get(1, inType), in);
+    // TODO: replace this with non-packed ScaledExtOp
+    Value scaleExt = rewriter.create<amdgpu::ScaledExtPackedOp>(
+        loc, extScaleResultType, inCast, scale, 0);
+    scaleExt = rewriter.replaceOpWithNewOp<vector::ExtractOp>(op, scaleExt, 0);
+    return success();
+  }
+
+  VectorType inVecType = cast<VectorType>(in.getType());
+  Value origScale = getOriginalVectorValue(op.getScale());
+
+  ArrayRef<int64_t> inShape = inVecType.getShape();
+  SmallVector<int64_t> originalScaleShape;
+  if (auto origScaleVecType = dyn_cast<VectorType>(origScale.getType()))
+    llvm::append_range(originalScaleShape, origScaleVecType.getShape());
+
+  originalScaleShape.insert(originalScaleShape.end(),
+                            inShape.size() - originalScaleShape.size(), 1);
+
+  auto maybeRatio = computeShapeRatio(inShape, originalScaleShape);
+  assert(maybeRatio &&
+         "failed to derive block size from broadcast or splat operation");
+
+  SmallVector<int64_t> ratio =
+      maybeRatio.value_or(SmallVector<int64_t>(inShape.size(), 1));
+
+  int64_t blockSize = computeProduct(ratio);
+
+  Value zero = rewriter.create<arith::ConstantOp>(
+      loc, outType, rewriter.getFloatAttr(outType, 0.0));
+  Value result = rewriter.createOrFold<vector::SplatOp>(loc, outVecType, zero);
+
+  for (SmallVector<int64_t> offsets : StaticTileOffsetRange(inShape, ratio)) {
+    SmallVector<int64_t> strides(offsets.size(), 1);
+    Value block = rewriter.create<vector::ExtractStridedSliceOp>(
+        loc, in, offsets, ratio, strides);
+    VectorType block1DType = VectorType::get(blockSize, inType);
+    Value block1D =
+        rewriter.create<vector::ShapeCastOp>(loc, block1DType, block);
+    Value uniformScale =
+        rewriter.create<vector::ExtractOp>(loc, scale, offsets);
+
+    VectorType blockResultType = VectorType::get(blockSize, outType);
+    Value blockResult =
+        rewriter.createOrFold<vector::SplatOp>(loc, blockResultType, zero);
+
+    for (int64_t i = 0, sliceWidth = std::min(opWidth, blockSize - i);
+         i < blockSize;
+         i += sliceWidth, sliceWidth = std::min(opWidth, blockSize - i)) {
+      Value slice = rewriter.create<vector::ExtractStridedSliceOp>(
+          loc, block1D, i, sliceWidth, 1);
+      // TODO: replace this with non-packed ScaledExtOp for sliceWidth == 1
+      Value scaleExt = rewriter.create<amdgpu::ScaledExtPackedOp>(
+          loc, extScaleResultType, slice, uniformScale, 0);
+      if (sliceWidth != opWidth)
+        scaleExt = rewriter.create<vector::ExtractStridedSliceOp>(
+            loc, scaleExt, 0, sliceWidth, 1);
+      blockResult = rewriter.create<vector::InsertStridedSliceOp>(
+          loc, scaleExt, blockResult, i, 1);
+    }
+
+    VectorType resultType = VectorType::get(ratio, outType);
+    Value cast =
+        rewriter.create<vector::ShapeCastOp>(loc, resultType, blockResult);
+    result = rewriter.create<vector::InsertStridedSliceOp>(loc, cast, result,
+                                                           offsets, strides);
+  }
+
+  rewriter.replaceOp(op, result);
+
+  return success();
+}
+
+LogicalResult
+ScalingTruncFRewritePattern::matchAndRewrite(arith::ScalingTruncFOp op,
+                                             PatternRewriter &rewriter) const {
+  Location loc = op.getLoc();
+  constexpr int64_t opWidth = 2;
+
+  Value in = op.getIn();
+  Value scale = op.getScale();
+  Value out = op.getOut();
+
+  Type f32 = rewriter.getF32Type();
+  Type inType = getElementTypeOrSelf(in);
+  Type scaleType = getElementTypeOrSelf(scale);
+  Type outType = getElementTypeOrSelf(out);
+
+  VectorType outVecType = dyn_cast<VectorType>(out.getType());
+  VectorType scaleVecType = dyn_cast<VectorType>(scale.getType());
+
+  if (outVecType && outVecType.isScalable())
+    return failure();
+
+  Type scaleF32Type =
+      scaleVecType ? VectorType::get(scaleVecType.getShape(), f32) : f32;
+  if (scaleType.getIntOrFloatBitWidth() < 32)
+    scale = rewriter.create<arith::ExtFOp>(loc, scaleF32Type, scale);
+  else if (scaleType.getIntOrFloatBitWidth() > 32)
+    scale = rewriter.create<arith::TruncFOp>(loc, scaleF32Type, scale);
+
+  Value zero = rewriter.create<arith::ConstantOp>(
+      loc, outType, rewriter.getFloatAttr(outType, 0.0));
+  unsigned numPackedElem = 32 / outType.getIntOrFloatBitWidth();
+  VectorType truncScaleResultType = VectorType::get(numPackedElem, outType);
+
+  if (!outVecType) {
+    Type inVecType = VectorType::get(1, inType);
+    Value inCast = rewriter.create<vector::SplatOp>(loc, inVecType, in);
+    // TODO: replace this with non-packed ScaledTruncOp
+    Value scaleTrunc = rewriter.create<amdgpu::PackedScaledTruncOp>(
+        loc, truncScaleResultType, inCast, scale, 0, /*existing=*/nullptr);
+    scaleTrunc =
+        rewriter.replaceOpWithNewOp<vector::ExtractOp>(op, scaleTrunc, 0);
+    return success();
+  }
+
+  VectorType inVecType = cast<VectorType>(in.getType());
+  Value origScale = getOriginalVectorValue(op.getScale());
+
+  ArrayRef<int64_t> inShape = inVecType.getShape();
+  SmallVector<int64_t> originalScaleShape;
+  if (auto origScaleVecType = dyn_cast<VectorType>(origScale.getType()))
+    llvm::append_range(originalScaleShape, origScaleVecType.getShape());
+
+  originalScaleShape.insert(originalScaleShape.end(),
+                            inShape.size() - originalScaleShape.size(), 1);
+
+  auto maybeRatio = computeShapeRatio(inShape, originalScaleShape);
+  assert(maybeRatio &&
+         "failed to derive block size from broadcast or splat operation");
+
+  SmallVector<int64_t> ratio =
+      maybeRatio.value_or(SmallVector<int64_t>(inShape.size(), 1));
+
+  int64_t blockSize = computeProduct(ratio);
+
+  Value result = rewriter.createOrFold<vector::SplatOp>(loc, outVecType, zero);
+
+  for (SmallVector<int64_t> offsets : StaticTileOffsetRange(inShape, ratio)) {
+    SmallVector<int64_t> strides(offsets.size(), 1);
+    Value block = rewriter.create<vector::ExtractStridedSliceOp>(
+        loc, in, offsets, ratio, strides);
+    VectorType block1DType = VectorType::get(blockSize, inType);
+    Value block1D =
+        rewriter.create<vector::ShapeCastOp>(loc, block1DType, block);
+    Value uniformScale =
+        rewriter.create<vector::ExtractOp>(loc, scale, offsets);
+
+    VectorType blockResultType = VectorType::get(blockSize, outType);
+    Value blockResult =
+        rewriter.createOrFold<vector::SplatOp>(loc, blockResultType, zero);
+
+    for (int64_t i = 0, sliceWidth = std::min(opWidth, blockSize - i);
+         i < blockSize;
+         i += sliceWidth, sliceWidth = std::min(opWidth, blockSize - i)) {
+      Value slice = rewriter.create<vector::ExtractStridedSliceOp>(
+          loc, block1D, i, sliceWidth, 1);
+      // TODO: replace this with non-packed ScaledTruncOp for sliceWidth == 1
+      Value scaleTrunc = rewriter.create<amdgpu::PackedScaledTruncOp>(
+          loc, truncScaleResultType, slice, uniformScale, 0,
+          /*existing=*/nullptr);
+      int64_t packedWidth =
+          cast<VectorType>(scaleTrunc.getType()).getNumElements();
+      if (packedWidth != opWidth)
+        scaleTrunc = rewriter.create<vector::ExtractStridedSliceOp>(
+            loc, scaleTrunc, 0, sliceWidth, 1);
+      blockResult = rewriter.create<vector::InsertStridedSliceOp>(
+          loc, scaleTrunc, blockResult, i, 1);
+    }
+
+    VectorType resultType = VectorType::get(ratio, outType);
+    Value cast =
+        rewriter.create<vector::ShapeCastOp>(loc, resultType, blockResult);
+    result = rewriter.create<vector::InsertStridedSliceOp>(loc, cast, result,
+                                                           offsets, strides);
+  }
+
+  rewriter.replaceOp(op, result);
+
+  return success();
+}
+
 void mlir::arith::populateArithToAMDGPUConversionPatterns(
     RewritePatternSet &patterns, bool convertFP8Arithmetic,
     bool saturateFP8Truncf, bool allowPackedF16Rtz, Chipset chipset) {
@@ -406,6 +673,11 @@ void mlir::arith::populateArithToAMDGPUConversionPatterns(
   }
   if (allowPackedF16Rtz)
    patterns.add<TruncfToFloat16RewritePattern>(patterns.getContext());
+
+  if (chipset >= kGfx950) {
+    patterns.add<ScalingExtFRewritePattern>(patterns.getContext());
+    patterns.add<ScalingTruncFRewritePattern>(patterns.getContext());
+  }
 }

 void ArithToAMDGPUConversionPass::runOnOperation() {
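For intuition about the block decomposition used by both rewrites above: the scale operand is traced back through shape_cast/broadcast/splat ops, its shape is right-padded with 1s to the input rank, and the elementwise shape ratio gives the per-scale block that is then lowered in packed slices of width 2. A small standalone sketch of that shape arithmetic, using the same computeShapeRatio / computeProduct utilities on purely illustrative shapes (4x32 input, one scale per row; not taken from the commit's tests):

// Illustrative only: the block-size derivation on example shapes.
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/raw_ostream.h"
#include <optional>

int main() {
  llvm::SmallVector<int64_t> inShape{4, 32};
  llvm::SmallVector<int64_t> scaleShape{4, 1}; // already padded to input rank
  std::optional<llvm::SmallVector<int64_t>> ratio =
      mlir::computeShapeRatio(inShape, scaleShape); // -> {1, 32}
  int64_t blockSize = mlir::computeProduct(*ratio); // -> 32 elements per scale
  llvm::outs() << "block size: " << blockSize
               << ", packed ops per block: " << blockSize / 2 << "\n";
  return 0;
}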
