11
11
#include " mlir/Dialect/Vector/IR/VectorOps.h"
12
12
#include " mlir/IR/BuiltinTypeInterfaces.h"
13
13
#include " mlir/IR/ImplicitLocOpBuilder.h"
14
+ #include " mlir/IR/Location.h"
14
15
#include " mlir/IR/TypeUtilities.h"
15
16
#include " mlir/Transforms/DialectConversion.h"
17
+ #include " llvm/ADT/SmallVectorExtras.h"
18
+ #include < cstdint>
16
19
17
20
namespace mlir {
18
21
namespace arith {
@@ -34,6 +37,18 @@ static Value createConst(Location loc, Type type, int value,
34
37
return rewriter.create <arith::ConstantOp>(loc, attr);
35
38
}
36
39
40
+ // / Create a float constant.
41
+ static Value createFloatConst (Location loc, Type type, APFloat value,
42
+ PatternRewriter &rewriter) {
43
+ auto attr = rewriter.getFloatAttr (getElementTypeOrSelf (type), value);
44
+ if (auto shapedTy = dyn_cast<ShapedType>(type)) {
45
+ return rewriter.create <arith::ConstantOp>(
46
+ loc, DenseElementsAttr::get (shapedTy, attr));
47
+ }
48
+
49
+ return rewriter.create <arith::ConstantOp>(loc, attr);
50
+ }
51
+
37
52
// / Creates shapedType using shape from cloneFrom and base type from cloneTo
38
53
static Type cloneToShapedType (Type cloneFrom, Type cloneTo) {
39
54
if (auto shapedTy = dyn_cast<ShapedType>(cloneFrom)) {
@@ -322,6 +337,100 @@ struct BFloat16TruncFOpConverter : public OpRewritePattern<arith::TruncFOp> {
322
337
}
323
338
};
324
339
340
+ // / In this implementation of extf we take advantage of some key patterns we
341
+ // / notice between the binary representation of an F4E2M1 value and its
342
+ // / corresponding value in F32.
343
+ // /
344
+ // / Note: x is sign bit
345
+ // / | Binary | F4E2M1 | f32[23:32]
346
+ // / | x000 | 0.0 | x000 0000 00
347
+ // / | x001 | 0.5 | x011 1111 00
348
+ // / | x010 | 1.0 | x011 1111 10
349
+ // / | x011 | 1.5 | x011 1111 11
350
+ // / | x100 | 2.0 | x010 0000 00
351
+ // / | x101 | 3.0 | x010 0000 01
352
+ // / | x110 | 4.0 | x010 0000 10
353
+ // / | x111 | 6.0 | x010 0000 11
354
+ // /
355
+ // / 1) There are only two versions of bits [25:31] in the f32 result
356
+ // / F4E2M1 bits[2:3] decide whether:
357
+ // / - F32 bits[25:31] = 0011 1111
358
+ // / - F32 bits[25:31] = 0010 0000
359
+ // / Exception is zero where
360
+ // / - F32 bits[25:31] = 0000 0000
361
+ // /
362
+ // / 2) F4E2M1 bits[1:2] = F32 bits[23:24]
363
+ // / Exception is 0.5 where
364
+ // / - F4E2M1 bits[1:2] = 01, F32 bits[23:24] = 00
365
+ // /
366
+ // / 3) F4E2M1 bits[4] = F32 bits[32] (sign bits are equal)
367
+ // /
368
+ // / 4) F32 bits[1:22] = 0
369
+ struct F4E2M1ExtFOpConverter : public OpRewritePattern <arith::ExtFOp> {
370
+ using OpRewritePattern::OpRewritePattern;
371
+ LogicalResult matchAndRewrite (arith::ExtFOp op,
372
+ PatternRewriter &rewriter) const final {
373
+ Location loc = op.getLoc ();
374
+ ImplicitLocOpBuilder b (loc, rewriter);
375
+ Value operand = op.getOperand ();
376
+ Type operandTy = operand.getType ();
377
+ Type resultTy = op.getType ();
378
+ Type operandETy = getElementTypeOrSelf (operandTy);
379
+ Type resultETy = getElementTypeOrSelf (resultTy);
380
+
381
+ if (!isa<Float4E2M1FNType>(operandETy))
382
+ return rewriter.notifyMatchFailure (op, " not a ext of F4E2M1FN" );
383
+
384
+ Type f32Ty = cloneToShapedType (operandTy, b.getF32Type ());
385
+ Type i4Ty = cloneToShapedType (operandTy, b.getI4Type ());
386
+ Type i32Ty = cloneToShapedType (operandTy, b.getI32Type ());
387
+ Value i4Bits = b.create <arith::BitcastOp>(i4Ty, operand);
388
+
389
+ Value c0x0 = createConst (loc, i4Ty, 0x0 , rewriter);
390
+ Value c0x1 = createConst (loc, i4Ty, 0x1 , rewriter);
391
+ Value c0x2 = createConst (loc, i4Ty, 0x2 , rewriter);
392
+ Value c0x4 = createConst (loc, i4Ty, 0x4 , rewriter);
393
+
394
+ // Set last Exponent bit and Mantissa.
395
+ Value c0x00000014 = createConst (loc, i32Ty, 0x14 , rewriter);
396
+ Value bits1To24 = b.create <arith::ShLIOp>(i4Bits, c0x2);
397
+ Value isHalf =
398
+ b.create <arith::CmpIOp>(arith::CmpIPredicate::eq, i4Bits, c0x1);
399
+ bits1To24 = b.create <arith::SelectOp>(isHalf, c0x0, bits1To24);
400
+ bits1To24 = b.create <arith::ExtUIOp>(i32Ty, bits1To24);
401
+ bits1To24 = b.create <arith::ShLIOp>(bits1To24, c0x00000014);
402
+
403
+ // Set first 7 bits of Exponent.
404
+ Value zeroExpBits = createConst (loc, i32Ty, 0x00000000 , rewriter);
405
+ Value highExpBits = createConst (loc, i32Ty, 0x40000000 , rewriter);
406
+ Value lowExpBits = createConst (loc, i32Ty, 0x3f000000 , rewriter);
407
+ Value useLargerExp =
408
+ b.create <arith::CmpIOp>(arith::CmpIPredicate::uge, i4Bits, c0x4);
409
+ Value bits25To31 =
410
+ b.create <arith::SelectOp>(useLargerExp, highExpBits, lowExpBits);
411
+ Value zeroExp =
412
+ b.create <arith::CmpIOp>(arith::CmpIPredicate::eq, i4Bits, c0x0);
413
+ bits25To31 = b.create <arith::SelectOp>(zeroExp, zeroExpBits, bits25To31);
414
+
415
+ // Set sign.
416
+ Value c0x80000000 = createConst (loc, i32Ty, 0x80000000 , rewriter);
417
+ Value c0x8 = createConst (loc, i4Ty, 0x8 , rewriter);
418
+ Value negative =
419
+ b.create <arith::CmpIOp>(arith::CmpIPredicate::uge, i4Bits, c0x8);
420
+ Value bit32 = b.create <arith::SelectOp>(negative, c0x80000000, zeroExpBits);
421
+
422
+ // Add segments together.
423
+ Value bits1To31 = b.create <arith::AddIOp>(bits1To24, bits25To31);
424
+ Value bits1To32 = b.create <arith::AddIOp>(bits1To31, bit32);
425
+ Value result = b.create <arith::BitcastOp>(f32Ty, bits1To32);
426
+ if (!isa<Float32Type>(resultETy))
427
+ result = b.create <arith::TruncFOp>(resultTy, result);
428
+
429
+ rewriter.replaceOp (op, result);
430
+ return success ();
431
+ }
432
+ };
433
+
325
434
struct F8E8M0ExtFOpConverter : public OpRewritePattern <arith::ExtFOp> {
326
435
using OpRewritePattern::OpRewritePattern;
327
436
LogicalResult matchAndRewrite (arith::ExtFOp op,
@@ -366,6 +475,130 @@ struct F8E8M0ExtFOpConverter : public OpRewritePattern<arith::ExtFOp> {
366
475
}
367
476
};
368
477
478
+ // / Conversion from F32 to F4E2M1 according to the OCP Spec:
479
+ // / www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf
480
+ // /
481
+ // / The spec requiers us to perform Round to Nearest, Ties to Even.
482
+ // /
483
+ // / This means that after rounding, we should break ties by choosing the option
484
+ // / which results in a mantissa of 0 in the least significant digit.
485
+ // /
486
+ // / Table of representable values in F4E2M1:
487
+ // /
488
+ // / Note: x is sign bit
489
+ // / | Binary | F4E2M1 | F32[23:32]
490
+ // / | x000 | 0.0 | x000 0000 00
491
+ // / | x001 | 0.5 | x011 1111 00
492
+ // / | x010 | 1.0 | x011 1111 10
493
+ // / | x011 | 1.5 | x011 1111 11
494
+ // / | x100 | 2.0 | x010 0000 00
495
+ // / | x101 | 3.0 | x010 0000 01
496
+ // / | x110 | 4.0 | x010 0000 10
497
+ // / | x111 | 6.0 | x010 0000 11
498
+ // /
499
+ // / Conversion procedure:
500
+ // / Step 1: Clamp to representable bounds.
501
+ // / Step 2: Convert exponent by adjusting bias.
502
+ // / Step 3: Set mantissa to first bit.
503
+ // / Step 4: Special consideration for subnormal and zero exponent.
504
+ // / Step 5: Round up if necessary, if mantissa[1:] greater than 1000000 or
505
+ // / subnormal.
506
+ struct F4E2M1TruncFOpConverter : public OpRewritePattern <arith::TruncFOp> {
507
+ using OpRewritePattern::OpRewritePattern;
508
+ LogicalResult matchAndRewrite (arith::TruncFOp op,
509
+ PatternRewriter &rewriter) const final {
510
+ Location loc = op.getLoc ();
511
+ ImplicitLocOpBuilder b (loc, rewriter);
512
+ Value operand = op.getOperand ();
513
+ Type operandTy = operand.getType ();
514
+ Type resultTy = op.getType ();
515
+ Type operandETy = getElementTypeOrSelf (operandTy);
516
+ Type resultETy = getElementTypeOrSelf (resultTy);
517
+
518
+ Type i4Ty = cloneToShapedType (operandTy, b.getI4Type ());
519
+ Type i8Ty = cloneToShapedType (operandTy, b.getI8Type ());
520
+ Type i32Ty = cloneToShapedType (operandTy, b.getI32Type ());
521
+ Type f32Ty = cloneToShapedType (operandTy, b.getF32Type ());
522
+
523
+ if (!isa<Float32Type>(operandETy))
524
+ operand = b.create <arith::ExtFOp>(f32Ty, operand);
525
+ if (!isa<Float4E2M1FNType>(resultETy))
526
+ return rewriter.notifyMatchFailure (op, " not a trunc of F4E2M1FN" );
527
+
528
+ Value c0x1 = createConst (loc, i4Ty, 1 , rewriter);
529
+ Value c0x3 = createConst (loc, i4Ty, 3 , rewriter);
530
+ Value c0x00000016 = createConst (loc, i32Ty, 22 , rewriter);
531
+ Value c0x00 = createConst (loc, i8Ty, 0x00 , rewriter);
532
+ Value c0xff = createConst (loc, i8Ty, 0xff , rewriter);
533
+ Value zeroExpBits = createConst (loc, i32Ty, 0 , rewriter);
534
+
535
+ // Step 0: Clamp to bounds.
536
+ Value cHigherBound = createFloatConst (loc, f32Ty, APFloat (6 .0f ), rewriter);
537
+ Value cLowerBound = createFloatConst (loc, f32Ty, APFloat (-6 .0f ), rewriter);
538
+ Value operandClamped = b.create <arith::MinNumFOp>(cHigherBound, operand);
539
+ operandClamped = b.create <arith::MaxNumFOp>(cLowerBound, operandClamped);
540
+ Value f32Bits = b.create <arith::BitcastOp>(i32Ty, operandClamped);
541
+
542
+ // Step 1: Set sign bit.
543
+ Value cF32ExpManWidth = createConst (loc, i32Ty, 31 , rewriter); // 23
544
+ Value f32Sign = b.create <arith::ShRUIOp>(f32Bits, cF32ExpManWidth);
545
+ Value f4Sign = b.create <arith::TruncIOp>(i4Ty, f32Sign);
546
+ Value f4Bits = b.create <arith::ShLIOp>(f4Sign, c0x3);
547
+
548
+ // Step 2: Convert exponent by adjusting bias.
549
+ Value biasAdjustment = createConst (loc, i32Ty, 0x7e , rewriter);
550
+ Value cF4MantissaWidth = c0x1; // 1
551
+ Value cF32MantissaWidth = createConst (loc, i32Ty, 23 , rewriter); // 23
552
+ Value f32SignExp = b.create <arith::ShRUIOp>(f32Bits, cF32MantissaWidth);
553
+ Value biasAdjustedSignExp =
554
+ b.create <arith::SubIOp>(f32SignExp, biasAdjustment);
555
+ Value f4Exp = b.create <arith::TruncIOp>(i4Ty, biasAdjustedSignExp);
556
+ f4Exp = b.create <arith::ShLIOp>(f4Exp, cF4MantissaWidth);
557
+ f4Bits = b.create <arith::AddIOp>(f4Bits, f4Exp);
558
+
559
+ // Step 3: Set mantissa to first bit.
560
+ Value cF32FirstBitMask = createConst (loc, i32Ty, 0x400000 , rewriter);
561
+ Value man1Bit = b.create <arith::AndIOp>(f32Bits, cF32FirstBitMask);
562
+ man1Bit = b.create <arith::ShRUIOp>(man1Bit, c0x00000016);
563
+ Value f4Man = b.create <arith::TruncIOp>(i4Ty, man1Bit);
564
+ f4Bits = b.create <arith::AddIOp>(f4Bits, f4Man);
565
+
566
+ // Step 4: Special consideration for conversion to 0.5.
567
+ Value cF32MantissaMask = createConst (loc, i32Ty, 0x7fffff , rewriter);
568
+ Value f8Exp = b.create <arith::TruncIOp>(i8Ty, biasAdjustedSignExp);
569
+ Value isSubnormal =
570
+ b.create <arith::CmpIOp>(arith::CmpIPredicate::sle, f8Exp, c0x00);
571
+ Value isNegOneExp =
572
+ b.create <arith::CmpIOp>(arith::CmpIPredicate::eq, f8Exp, c0xff);
573
+ Value man23Bits = b.create <arith::AndIOp>(f32Bits, cF32MantissaMask);
574
+ Value isNonZeroMan = b.create <arith::CmpIOp>(arith::CmpIPredicate::ugt,
575
+ man23Bits, zeroExpBits);
576
+ Value roundToHalf = b.create <arith::AndIOp>(isNegOneExp, isNonZeroMan);
577
+ Value isZeroExp =
578
+ b.create <arith::CmpIOp>(arith::CmpIPredicate::eq, f8Exp, c0x00);
579
+ Value subnormalF4Bits = createConst (loc, i4Ty, 0xf , rewriter);
580
+ Value halfF4Bits = createConst (loc, i4Ty, 0x0 , rewriter);
581
+ Value subResult =
582
+ b.create <arith::SelectOp>(isSubnormal, subnormalF4Bits, f4Bits);
583
+ subResult = b.create <arith::SelectOp>(roundToHalf, halfF4Bits, subResult);
584
+ f4Bits = b.create <arith::SelectOp>(isZeroExp, f4Bits, subResult);
585
+
586
+ // Step 5: Round up if necessary.
587
+ Value cF32Last22BitMask = createConst (loc, i32Ty, 0x3fffff , rewriter);
588
+ Value cRound = createConst (loc, i32Ty, 0x200000 , rewriter); // 010 0000...
589
+ Value man22Bits = b.create <arith::AndIOp>(f32Bits, cF32Last22BitMask);
590
+ Value shouldRound =
591
+ b.create <arith::CmpIOp>(arith::CmpIPredicate::uge, man22Bits, cRound);
592
+ shouldRound = b.create <arith::OrIOp>(shouldRound, isSubnormal);
593
+ Value roundedF4Bits = b.create <arith::AddIOp>(f4Bits, c0x1);
594
+ f4Bits = b.create <arith::SelectOp>(shouldRound, roundedF4Bits, f4Bits);
595
+
596
+ Value result = b.create <arith::BitcastOp>(resultTy, f4Bits);
597
+ rewriter.replaceOp (op, result);
598
+ return success ();
599
+ }
600
+ };
601
+
369
602
/*
370
603
TruncF to F8E8M0 is expected to extract exponent bits out of F32 type
371
604
Since All kinds of Infs and NaNs are mapped to same exponent bits in F32 type,
@@ -498,6 +731,8 @@ struct ArithExpandOpsPass
498
731
arith::populateArithExpandOpsPatterns (patterns);
499
732
500
733
target.addLegalDialect <arith::ArithDialect>();
734
+ target.addLegalDialect <vector::VectorDialect>();
735
+
501
736
// clang-format off
502
737
target.addIllegalOp <
503
738
arith::CeilDivSIOp,
@@ -515,22 +750,24 @@ struct ArithExpandOpsPass
515
750
arith::ScalingTruncFOp
516
751
>();
517
752
518
- if (includeBf16) {
753
+ if (includeBf16)
519
754
arith::populateExpandBFloat16Patterns (patterns);
520
- }
521
- if (includeF8E8M0) {
755
+ if (includeF8E8M0)
522
756
arith::populateExpandF8E8M0Patterns (patterns);
523
- }
757
+ if (includeF4E2M1)
758
+ arith::populateExpandF4E2M1Patterns (patterns);
524
759
525
760
target.addDynamicallyLegalOp <arith::ExtFOp>(
526
761
[=](arith::ExtFOp op) {
527
762
Type inETy = getElementTypeOrSelf (op.getOperand ().getType ());
528
763
Type outETy = getElementTypeOrSelf (op.getType ());
529
764
bool legalTypes = true ;
530
- if (includeBf16)
765
+ if (includeBf16)
531
766
legalTypes &= !(inETy.isBF16 () && outETy.isF32 ());
532
767
if (includeF8E8M0)
533
768
legalTypes &= !llvm::isa<Float8E8M0FNUType>(inETy);
769
+ if (includeF4E2M1)
770
+ legalTypes &= !llvm::isa<Float4E2M1FNType>(inETy);
534
771
return legalTypes;
535
772
});
536
773
@@ -539,10 +776,12 @@ struct ArithExpandOpsPass
539
776
Type inETy = getElementTypeOrSelf (op.getOperand ().getType ());
540
777
Type outETy = getElementTypeOrSelf (op.getType ());
541
778
bool legalTypes = true ;
542
- if (includeBf16)
779
+ if (includeBf16)
543
780
legalTypes &= !(inETy.isF32 () && outETy.isBF16 ());
544
- if (includeF8E8M0)
781
+ if (includeF8E8M0)
545
782
legalTypes &= !(llvm::isa<Float8E8M0FNUType>(outETy));
783
+ if (includeF4E2M1)
784
+ legalTypes &= !llvm::isa<Float4E2M1FNType>(outETy);
546
785
return legalTypes;
547
786
});
548
787
@@ -567,6 +806,11 @@ void mlir::arith::populateExpandBFloat16Patterns(RewritePatternSet &patterns) {
567
806
patterns.getContext ());
568
807
}
569
808
809
+ void mlir::arith::populateExpandF4E2M1Patterns (RewritePatternSet &patterns) {
810
+ patterns.add <F4E2M1ExtFOpConverter, F4E2M1TruncFOpConverter>(
811
+ patterns.getContext ());
812
+ }
813
+
570
814
void mlir::arith::populateExpandF8E8M0Patterns (RewritePatternSet &patterns) {
571
815
patterns.add <F8E8M0ExtFOpConverter, F8E8M0TruncFOpConverter>(
572
816
patterns.getContext ());
0 commit comments