Skip to content

Commit 379a609

Browse files
[mlir][arith][transforms] Adds f4E2M1FN support to truncf and extf (#144157)
See work detail: iree-org/iree#20920 Add support for f4E2M1FN in `arith.truncf` and `arith.extf` ops though a software emulation --------- Signed-off-by: Muzammiluddin Syed <muzasyed@amd.com>
1 parent 940ff11 commit 379a609

File tree

5 files changed

+370
-8
lines changed

5 files changed

+370
-8
lines changed

mlir/include/mlir/Dialect/Arith/Transforms/Passes.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,9 @@ void populateCeilFloorDivExpandOpsPatterns(RewritePatternSet &patterns);
5959
/// Add patterns to expand Arith bf16 patterns to lower level bitcasts/shifts.
6060
void populateExpandBFloat16Patterns(RewritePatternSet &patterns);
6161

62+
/// Add patterns to expand Arith f4e2m1 patterns to lower level bitcasts/shifts.
63+
void populateExpandF4E2M1Patterns(RewritePatternSet &patterns);
64+
6265
/// Add patterns to expand Arith f8e8m0 patterns to lower level bitcasts/shifts.
6366
void populateExpandF8E8M0Patterns(RewritePatternSet &patterns);
6467

mlir/include/mlir/Dialect/Arith/Transforms/Passes.td

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,8 @@ def ArithExpandOpsPass : Pass<"arith-expand"> {
1919
"Enable the BF16 expansion patterns">,
2020
Option<"includeF8E8M0", "include-f8e8m0", "bool", /*default=*/"false",
2121
"Enable the F8E8M0 expansion patterns">,
22+
Option<"includeF4E2M1", "include-f4e2m1", "bool", /*default=*/"false",
23+
"Enable the F4E2M1 expansion patterns">,
2224
];
2325
}
2426

mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp

Lines changed: 251 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,11 @@
1111
#include "mlir/Dialect/Vector/IR/VectorOps.h"
1212
#include "mlir/IR/BuiltinTypeInterfaces.h"
1313
#include "mlir/IR/ImplicitLocOpBuilder.h"
14+
#include "mlir/IR/Location.h"
1415
#include "mlir/IR/TypeUtilities.h"
1516
#include "mlir/Transforms/DialectConversion.h"
17+
#include "llvm/ADT/SmallVectorExtras.h"
18+
#include <cstdint>
1619

1720
namespace mlir {
1821
namespace arith {
@@ -34,6 +37,18 @@ static Value createConst(Location loc, Type type, int value,
3437
return rewriter.create<arith::ConstantOp>(loc, attr);
3538
}
3639

40+
/// Create a float constant.
41+
static Value createFloatConst(Location loc, Type type, APFloat value,
42+
PatternRewriter &rewriter) {
43+
auto attr = rewriter.getFloatAttr(getElementTypeOrSelf(type), value);
44+
if (auto shapedTy = dyn_cast<ShapedType>(type)) {
45+
return rewriter.create<arith::ConstantOp>(
46+
loc, DenseElementsAttr::get(shapedTy, attr));
47+
}
48+
49+
return rewriter.create<arith::ConstantOp>(loc, attr);
50+
}
51+
3752
/// Creates shapedType using shape from cloneFrom and base type from cloneTo
3853
static Type cloneToShapedType(Type cloneFrom, Type cloneTo) {
3954
if (auto shapedTy = dyn_cast<ShapedType>(cloneFrom)) {
@@ -322,6 +337,100 @@ struct BFloat16TruncFOpConverter : public OpRewritePattern<arith::TruncFOp> {
322337
}
323338
};
324339

340+
/// In this implementation of extf we take advantage of some key patterns we
341+
/// notice between the binary representation of an F4E2M1 value and its
342+
/// corresponding value in F32.
343+
///
344+
/// Note: x is sign bit
345+
/// | Binary | F4E2M1 | f32[23:32]
346+
/// | x000 | 0.0 | x000 0000 00
347+
/// | x001 | 0.5 | x011 1111 00
348+
/// | x010 | 1.0 | x011 1111 10
349+
/// | x011 | 1.5 | x011 1111 11
350+
/// | x100 | 2.0 | x010 0000 00
351+
/// | x101 | 3.0 | x010 0000 01
352+
/// | x110 | 4.0 | x010 0000 10
353+
/// | x111 | 6.0 | x010 0000 11
354+
///
355+
/// 1) There are only two versions of bits [25:31] in the f32 result
356+
/// F4E2M1 bits[2:3] decide whether:
357+
/// - F32 bits[25:31] = 0011 1111
358+
/// - F32 bits[25:31] = 0010 0000
359+
/// Exception is zero where
360+
/// - F32 bits[25:31] = 0000 0000
361+
///
362+
/// 2) F4E2M1 bits[1:2] = F32 bits[23:24]
363+
/// Exception is 0.5 where
364+
/// - F4E2M1 bits[1:2] = 01, F32 bits[23:24] = 00
365+
///
366+
/// 3) F4E2M1 bits[4] = F32 bits[32] (sign bits are equal)
367+
///
368+
/// 4) F32 bits[1:22] = 0
369+
struct F4E2M1ExtFOpConverter : public OpRewritePattern<arith::ExtFOp> {
370+
using OpRewritePattern::OpRewritePattern;
371+
LogicalResult matchAndRewrite(arith::ExtFOp op,
372+
PatternRewriter &rewriter) const final {
373+
Location loc = op.getLoc();
374+
ImplicitLocOpBuilder b(loc, rewriter);
375+
Value operand = op.getOperand();
376+
Type operandTy = operand.getType();
377+
Type resultTy = op.getType();
378+
Type operandETy = getElementTypeOrSelf(operandTy);
379+
Type resultETy = getElementTypeOrSelf(resultTy);
380+
381+
if (!isa<Float4E2M1FNType>(operandETy))
382+
return rewriter.notifyMatchFailure(op, "not a ext of F4E2M1FN");
383+
384+
Type f32Ty = cloneToShapedType(operandTy, b.getF32Type());
385+
Type i4Ty = cloneToShapedType(operandTy, b.getI4Type());
386+
Type i32Ty = cloneToShapedType(operandTy, b.getI32Type());
387+
Value i4Bits = b.create<arith::BitcastOp>(i4Ty, operand);
388+
389+
Value c0x0 = createConst(loc, i4Ty, 0x0, rewriter);
390+
Value c0x1 = createConst(loc, i4Ty, 0x1, rewriter);
391+
Value c0x2 = createConst(loc, i4Ty, 0x2, rewriter);
392+
Value c0x4 = createConst(loc, i4Ty, 0x4, rewriter);
393+
394+
// Set last Exponent bit and Mantissa.
395+
Value c0x00000014 = createConst(loc, i32Ty, 0x14, rewriter);
396+
Value bits1To24 = b.create<arith::ShLIOp>(i4Bits, c0x2);
397+
Value isHalf =
398+
b.create<arith::CmpIOp>(arith::CmpIPredicate::eq, i4Bits, c0x1);
399+
bits1To24 = b.create<arith::SelectOp>(isHalf, c0x0, bits1To24);
400+
bits1To24 = b.create<arith::ExtUIOp>(i32Ty, bits1To24);
401+
bits1To24 = b.create<arith::ShLIOp>(bits1To24, c0x00000014);
402+
403+
// Set first 7 bits of Exponent.
404+
Value zeroExpBits = createConst(loc, i32Ty, 0x00000000, rewriter);
405+
Value highExpBits = createConst(loc, i32Ty, 0x40000000, rewriter);
406+
Value lowExpBits = createConst(loc, i32Ty, 0x3f000000, rewriter);
407+
Value useLargerExp =
408+
b.create<arith::CmpIOp>(arith::CmpIPredicate::uge, i4Bits, c0x4);
409+
Value bits25To31 =
410+
b.create<arith::SelectOp>(useLargerExp, highExpBits, lowExpBits);
411+
Value zeroExp =
412+
b.create<arith::CmpIOp>(arith::CmpIPredicate::eq, i4Bits, c0x0);
413+
bits25To31 = b.create<arith::SelectOp>(zeroExp, zeroExpBits, bits25To31);
414+
415+
// Set sign.
416+
Value c0x80000000 = createConst(loc, i32Ty, 0x80000000, rewriter);
417+
Value c0x8 = createConst(loc, i4Ty, 0x8, rewriter);
418+
Value negative =
419+
b.create<arith::CmpIOp>(arith::CmpIPredicate::uge, i4Bits, c0x8);
420+
Value bit32 = b.create<arith::SelectOp>(negative, c0x80000000, zeroExpBits);
421+
422+
// Add segments together.
423+
Value bits1To31 = b.create<arith::AddIOp>(bits1To24, bits25To31);
424+
Value bits1To32 = b.create<arith::AddIOp>(bits1To31, bit32);
425+
Value result = b.create<arith::BitcastOp>(f32Ty, bits1To32);
426+
if (!isa<Float32Type>(resultETy))
427+
result = b.create<arith::TruncFOp>(resultTy, result);
428+
429+
rewriter.replaceOp(op, result);
430+
return success();
431+
}
432+
};
433+
325434
struct F8E8M0ExtFOpConverter : public OpRewritePattern<arith::ExtFOp> {
326435
using OpRewritePattern::OpRewritePattern;
327436
LogicalResult matchAndRewrite(arith::ExtFOp op,
@@ -366,6 +475,130 @@ struct F8E8M0ExtFOpConverter : public OpRewritePattern<arith::ExtFOp> {
366475
}
367476
};
368477

478+
/// Conversion from F32 to F4E2M1 according to the OCP Spec:
479+
/// www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf
480+
///
481+
/// The spec requiers us to perform Round to Nearest, Ties to Even.
482+
///
483+
/// This means that after rounding, we should break ties by choosing the option
484+
/// which results in a mantissa of 0 in the least significant digit.
485+
///
486+
/// Table of representable values in F4E2M1:
487+
///
488+
/// Note: x is sign bit
489+
/// | Binary | F4E2M1 | F32[23:32]
490+
/// | x000 | 0.0 | x000 0000 00
491+
/// | x001 | 0.5 | x011 1111 00
492+
/// | x010 | 1.0 | x011 1111 10
493+
/// | x011 | 1.5 | x011 1111 11
494+
/// | x100 | 2.0 | x010 0000 00
495+
/// | x101 | 3.0 | x010 0000 01
496+
/// | x110 | 4.0 | x010 0000 10
497+
/// | x111 | 6.0 | x010 0000 11
498+
///
499+
/// Conversion procedure:
500+
/// Step 1: Clamp to representable bounds.
501+
/// Step 2: Convert exponent by adjusting bias.
502+
/// Step 3: Set mantissa to first bit.
503+
/// Step 4: Special consideration for subnormal and zero exponent.
504+
/// Step 5: Round up if necessary, if mantissa[1:] greater than 1000000 or
505+
/// subnormal.
506+
struct F4E2M1TruncFOpConverter : public OpRewritePattern<arith::TruncFOp> {
507+
using OpRewritePattern::OpRewritePattern;
508+
LogicalResult matchAndRewrite(arith::TruncFOp op,
509+
PatternRewriter &rewriter) const final {
510+
Location loc = op.getLoc();
511+
ImplicitLocOpBuilder b(loc, rewriter);
512+
Value operand = op.getOperand();
513+
Type operandTy = operand.getType();
514+
Type resultTy = op.getType();
515+
Type operandETy = getElementTypeOrSelf(operandTy);
516+
Type resultETy = getElementTypeOrSelf(resultTy);
517+
518+
Type i4Ty = cloneToShapedType(operandTy, b.getI4Type());
519+
Type i8Ty = cloneToShapedType(operandTy, b.getI8Type());
520+
Type i32Ty = cloneToShapedType(operandTy, b.getI32Type());
521+
Type f32Ty = cloneToShapedType(operandTy, b.getF32Type());
522+
523+
if (!isa<Float32Type>(operandETy))
524+
operand = b.create<arith::ExtFOp>(f32Ty, operand);
525+
if (!isa<Float4E2M1FNType>(resultETy))
526+
return rewriter.notifyMatchFailure(op, "not a trunc of F4E2M1FN");
527+
528+
Value c0x1 = createConst(loc, i4Ty, 1, rewriter);
529+
Value c0x3 = createConst(loc, i4Ty, 3, rewriter);
530+
Value c0x00000016 = createConst(loc, i32Ty, 22, rewriter);
531+
Value c0x00 = createConst(loc, i8Ty, 0x00, rewriter);
532+
Value c0xff = createConst(loc, i8Ty, 0xff, rewriter);
533+
Value zeroExpBits = createConst(loc, i32Ty, 0, rewriter);
534+
535+
// Step 0: Clamp to bounds.
536+
Value cHigherBound = createFloatConst(loc, f32Ty, APFloat(6.0f), rewriter);
537+
Value cLowerBound = createFloatConst(loc, f32Ty, APFloat(-6.0f), rewriter);
538+
Value operandClamped = b.create<arith::MinNumFOp>(cHigherBound, operand);
539+
operandClamped = b.create<arith::MaxNumFOp>(cLowerBound, operandClamped);
540+
Value f32Bits = b.create<arith::BitcastOp>(i32Ty, operandClamped);
541+
542+
// Step 1: Set sign bit.
543+
Value cF32ExpManWidth = createConst(loc, i32Ty, 31, rewriter); // 23
544+
Value f32Sign = b.create<arith::ShRUIOp>(f32Bits, cF32ExpManWidth);
545+
Value f4Sign = b.create<arith::TruncIOp>(i4Ty, f32Sign);
546+
Value f4Bits = b.create<arith::ShLIOp>(f4Sign, c0x3);
547+
548+
// Step 2: Convert exponent by adjusting bias.
549+
Value biasAdjustment = createConst(loc, i32Ty, 0x7e, rewriter);
550+
Value cF4MantissaWidth = c0x1; // 1
551+
Value cF32MantissaWidth = createConst(loc, i32Ty, 23, rewriter); // 23
552+
Value f32SignExp = b.create<arith::ShRUIOp>(f32Bits, cF32MantissaWidth);
553+
Value biasAdjustedSignExp =
554+
b.create<arith::SubIOp>(f32SignExp, biasAdjustment);
555+
Value f4Exp = b.create<arith::TruncIOp>(i4Ty, biasAdjustedSignExp);
556+
f4Exp = b.create<arith::ShLIOp>(f4Exp, cF4MantissaWidth);
557+
f4Bits = b.create<arith::AddIOp>(f4Bits, f4Exp);
558+
559+
// Step 3: Set mantissa to first bit.
560+
Value cF32FirstBitMask = createConst(loc, i32Ty, 0x400000, rewriter);
561+
Value man1Bit = b.create<arith::AndIOp>(f32Bits, cF32FirstBitMask);
562+
man1Bit = b.create<arith::ShRUIOp>(man1Bit, c0x00000016);
563+
Value f4Man = b.create<arith::TruncIOp>(i4Ty, man1Bit);
564+
f4Bits = b.create<arith::AddIOp>(f4Bits, f4Man);
565+
566+
// Step 4: Special consideration for conversion to 0.5.
567+
Value cF32MantissaMask = createConst(loc, i32Ty, 0x7fffff, rewriter);
568+
Value f8Exp = b.create<arith::TruncIOp>(i8Ty, biasAdjustedSignExp);
569+
Value isSubnormal =
570+
b.create<arith::CmpIOp>(arith::CmpIPredicate::sle, f8Exp, c0x00);
571+
Value isNegOneExp =
572+
b.create<arith::CmpIOp>(arith::CmpIPredicate::eq, f8Exp, c0xff);
573+
Value man23Bits = b.create<arith::AndIOp>(f32Bits, cF32MantissaMask);
574+
Value isNonZeroMan = b.create<arith::CmpIOp>(arith::CmpIPredicate::ugt,
575+
man23Bits, zeroExpBits);
576+
Value roundToHalf = b.create<arith::AndIOp>(isNegOneExp, isNonZeroMan);
577+
Value isZeroExp =
578+
b.create<arith::CmpIOp>(arith::CmpIPredicate::eq, f8Exp, c0x00);
579+
Value subnormalF4Bits = createConst(loc, i4Ty, 0xf, rewriter);
580+
Value halfF4Bits = createConst(loc, i4Ty, 0x0, rewriter);
581+
Value subResult =
582+
b.create<arith::SelectOp>(isSubnormal, subnormalF4Bits, f4Bits);
583+
subResult = b.create<arith::SelectOp>(roundToHalf, halfF4Bits, subResult);
584+
f4Bits = b.create<arith::SelectOp>(isZeroExp, f4Bits, subResult);
585+
586+
// Step 5: Round up if necessary.
587+
Value cF32Last22BitMask = createConst(loc, i32Ty, 0x3fffff, rewriter);
588+
Value cRound = createConst(loc, i32Ty, 0x200000, rewriter); // 010 0000...
589+
Value man22Bits = b.create<arith::AndIOp>(f32Bits, cF32Last22BitMask);
590+
Value shouldRound =
591+
b.create<arith::CmpIOp>(arith::CmpIPredicate::uge, man22Bits, cRound);
592+
shouldRound = b.create<arith::OrIOp>(shouldRound, isSubnormal);
593+
Value roundedF4Bits = b.create<arith::AddIOp>(f4Bits, c0x1);
594+
f4Bits = b.create<arith::SelectOp>(shouldRound, roundedF4Bits, f4Bits);
595+
596+
Value result = b.create<arith::BitcastOp>(resultTy, f4Bits);
597+
rewriter.replaceOp(op, result);
598+
return success();
599+
}
600+
};
601+
369602
/*
370603
TruncF to F8E8M0 is expected to extract exponent bits out of F32 type
371604
Since All kinds of Infs and NaNs are mapped to same exponent bits in F32 type,
@@ -498,6 +731,8 @@ struct ArithExpandOpsPass
498731
arith::populateArithExpandOpsPatterns(patterns);
499732

500733
target.addLegalDialect<arith::ArithDialect>();
734+
target.addLegalDialect<vector::VectorDialect>();
735+
501736
// clang-format off
502737
target.addIllegalOp<
503738
arith::CeilDivSIOp,
@@ -515,22 +750,24 @@ struct ArithExpandOpsPass
515750
arith::ScalingTruncFOp
516751
>();
517752

518-
if (includeBf16) {
753+
if (includeBf16)
519754
arith::populateExpandBFloat16Patterns(patterns);
520-
}
521-
if (includeF8E8M0) {
755+
if (includeF8E8M0)
522756
arith::populateExpandF8E8M0Patterns(patterns);
523-
}
757+
if (includeF4E2M1)
758+
arith::populateExpandF4E2M1Patterns(patterns);
524759

525760
target.addDynamicallyLegalOp<arith::ExtFOp>(
526761
[=](arith::ExtFOp op) {
527762
Type inETy = getElementTypeOrSelf(op.getOperand().getType());
528763
Type outETy = getElementTypeOrSelf(op.getType());
529764
bool legalTypes = true;
530-
if (includeBf16)
765+
if (includeBf16)
531766
legalTypes &= !(inETy.isBF16() && outETy.isF32());
532767
if (includeF8E8M0)
533768
legalTypes &= !llvm::isa<Float8E8M0FNUType>(inETy);
769+
if (includeF4E2M1)
770+
legalTypes &= !llvm::isa<Float4E2M1FNType>(inETy);
534771
return legalTypes;
535772
});
536773

@@ -539,10 +776,12 @@ struct ArithExpandOpsPass
539776
Type inETy = getElementTypeOrSelf(op.getOperand().getType());
540777
Type outETy = getElementTypeOrSelf(op.getType());
541778
bool legalTypes = true;
542-
if (includeBf16)
779+
if (includeBf16)
543780
legalTypes &= !(inETy.isF32() && outETy.isBF16());
544-
if (includeF8E8M0)
781+
if (includeF8E8M0)
545782
legalTypes &= !(llvm::isa<Float8E8M0FNUType>(outETy));
783+
if (includeF4E2M1)
784+
legalTypes &= !llvm::isa<Float4E2M1FNType>(outETy);
546785
return legalTypes;
547786
});
548787

@@ -567,6 +806,11 @@ void mlir::arith::populateExpandBFloat16Patterns(RewritePatternSet &patterns) {
567806
patterns.getContext());
568807
}
569808

809+
void mlir::arith::populateExpandF4E2M1Patterns(RewritePatternSet &patterns) {
810+
patterns.add<F4E2M1ExtFOpConverter, F4E2M1TruncFOpConverter>(
811+
patterns.getContext());
812+
}
813+
570814
void mlir::arith::populateExpandF8E8M0Patterns(RewritePatternSet &patterns) {
571815
patterns.add<F8E8M0ExtFOpConverter, F8E8M0TruncFOpConverter>(
572816
patterns.getContext());

mlir/test/Dialect/Arith/expand-ops.mlir

Lines changed: 41 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// RUN: mlir-opt %s -arith-expand="include-bf16=true include-f8e8m0=true" -verify-diagnostics -split-input-file | FileCheck %s
1+
// RUN: mlir-opt %s -arith-expand="include-bf16=true include-f8e8m0=true include-f4e2m1=true" -verify-diagnostics -split-input-file | FileCheck %s
22
// RUN: mlir-opt %s -arith-expand -split-input-file -verify-diagnostics | FileCheck %s --check-prefix=SCHECK
33

44
// Test ceil divide with signed integer
@@ -593,3 +593,43 @@ func.func @minui(%a: i32, %b: i32) -> i32 {
593593
// CHECK-NEXT: %[[CMP:.*]] = arith.cmpi ult, %[[LHS]], %[[RHS]] : i32
594594
// CHECK-NEXT: %[[RESULT:.*]] = arith.select %[[CMP]], %[[LHS]], %[[RHS]] : i32
595595
// CHECK-NEXT: return %[[RESULT]] : i32
596+
597+
// -----
598+
599+
func.func @truncf_f32_to_f4E2M1FN(%arg0 : f32) -> f4E2M1FN {
600+
%0 = arith.truncf %arg0 : f32 to f4E2M1FN
601+
return %0 : f4E2M1FN
602+
}
603+
604+
// CHECK-LABEL: @truncf_f32_to_f4E2M1FN
605+
// CHECK-NOT: arith.truncf
606+
607+
// -----
608+
609+
func.func @truncf_vector_f32_to_f4E2M1FN(%arg0 : vector<4xf32>) -> vector<4xf4E2M1FN> {
610+
%0 = arith.truncf %arg0 : vector<4xf32> to vector<4xf4E2M1FN>
611+
return %0 : vector<4xf4E2M1FN>
612+
}
613+
614+
// CHECK-LABEL: @truncf_vector_f32_to_f4E2M1FN
615+
// CHECK-NOT: arith.truncf
616+
617+
// -----
618+
619+
func.func @extf_f4E2M1FN_to_f32(%arg0 : f4E2M1FN) -> f32 {
620+
%0 = arith.extf %arg0 : f4E2M1FN to f32
621+
return %0 : f32
622+
}
623+
624+
// CHECK-LABEL: @extf_f4E2M1FN_to_f32
625+
// CHECK-NOT: arith.extf
626+
627+
// -----
628+
629+
func.func @extf_vector_f4E2M1FN_to_f32(%arg0 : vector<4xf4E2M1FN>) -> vector<4xf32> {
630+
%0 = arith.extf %arg0 : vector<4xf4E2M1FN> to vector<4xf32>
631+
return %0 : vector<4xf32>
632+
}
633+
634+
// CHECK-LABEL: @extf_vector_f4E2M1FN_to_f32
635+
// CHECK-NOT: arith.extf

0 commit comments

Comments
 (0)