From 2a1ec30d818a902857d22ac2995e17798b2c0b01 Mon Sep 17 00:00:00 2001
From: Quinn Dawkins
Date: Thu, 10 Jul 2025 13:58:14 -0400
Subject: [PATCH] [mlir][MemRef] Add support for emulating narrow floats

This enables memref.load/store + vector.load/store support for sub-byte
float types. Since only the bit width of the element type matters to the
emulation, the memrefs are still converted to the same integer types of
equivalent width, with a few extra bitcasts inserted around certain
operations.
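For example (mirroring the new memref_load_f4 test below with an i8
load/store width; the SSA names here are schematic), a narrow-float load
is emulated by loading the backing i8 data, extracting the bits as an i4,
and bitcasting the result to the float type:

  // Source IR:
  %1 = memref.load %0[%idx] : memref<5xf4E2M1FN>

  // Emulated IR (sketch):
  %loaded  = memref.load %alloc[%byteIdx] : memref<3xi8>
  %shifted = arith.shrsi %loaded, %bitOffset : i8
  %bits    = arith.trunci %shifted : i8 to i4
  %1       = arith.bitcast %bits : i4 to f4E2M1FN

Stores take the reverse path: the float value is first bitcast to an
integer of the same width and then handled exactly like the existing
integer store emulation.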
---
 .../MemRef/Transforms/EmulateNarrowType.cpp   | 44 ++++++++---
 .../Transforms/VectorEmulateNarrowType.cpp    | 14 +++-
 .../Dialect/MemRef/emulate-narrow-type.mlir   | 58 ++++++++++++++
 .../Vector/vector-emulate-narrow-type.mlir    | 78 +++++++++++++++++++
 4 files changed, 181 insertions(+), 13 deletions(-)

diff --git a/mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp b/mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp
index d2a032688fb6d..ec2bc95291455 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp
@@ -323,19 +323,28 @@ struct ConvertMemRefLoad final : OpConversionPattern<memref::LoadOp> {
     // It is not clear if this case actually happens in practice, but we keep
     // the operations just in case. Otherwise, if the arith computation bitwidth
     // is different from the emulated bitwidth we truncate the result.
-    Operation *result;
+    Value result;
     auto resultTy = getTypeConverter()->convertType(oldElementType);
-    if (resultTy == convertedElementType) {
+    auto conversionTy =
+        resultTy.isInteger()
+            ? resultTy
+            : IntegerType::get(rewriter.getContext(),
+                               resultTy.getIntOrFloatBitWidth());
+    if (conversionTy == convertedElementType) {
       auto mask = rewriter.create<arith::ConstantOp>(
           loc, convertedElementType,
           rewriter.getIntegerAttr(convertedElementType, (1 << srcBits) - 1));

       result = rewriter.create<arith::AndIOp>(loc, bitsLoad, mask);
     } else {
-      result = rewriter.create<arith::TruncIOp>(loc, resultTy, bitsLoad);
+      result = rewriter.create<arith::TruncIOp>(loc, conversionTy, bitsLoad);
     }

-    rewriter.replaceOp(op, result->getResult(0));
+    if (conversionTy != resultTy) {
+      result = rewriter.create<arith::BitcastOp>(loc, resultTy, result);
+    }
+
+    rewriter.replaceOp(op, result);
     return success();
   }
 };
@@ -415,8 +424,18 @@ struct ConvertMemrefStore final : OpConversionPattern<memref::StoreOp> {
     }

     Location loc = op.getLoc();
-    Value extendedInput = rewriter.create<arith::ExtUIOp>(loc, dstIntegerType,
-                                                          adaptor.getValue());
+
+    // Pad the input value with 0s on the left.
+    Value input = adaptor.getValue();
+    if (!input.getType().isInteger()) {
+      input = rewriter.create<arith::BitcastOp>(
+          loc,
+          IntegerType::get(rewriter.getContext(),
+                           input.getType().getIntOrFloatBitWidth()),
+          input);
+    }
+    Value extendedInput =
+        rewriter.create<arith::ExtUIOp>(loc, dstIntegerType, input);

     // Special case 0-rank memref stores. No need for masking.
     if (convertedType.getRank() == 0) {
@@ -619,11 +638,11 @@ void memref::populateMemRefNarrowTypeEmulationConversions(
     arith::NarrowTypeEmulationConverter &typeConverter) {
   typeConverter.addConversion(
       [&typeConverter](MemRefType ty) -> std::optional<Type> {
-        auto intTy = dyn_cast<IntegerType>(ty.getElementType());
-        if (!intTy)
+        Type elementType = ty.getElementType();
+        if (!elementType.isIntOrFloat())
           return ty;

-        unsigned width = intTy.getWidth();
+        unsigned width = elementType.getIntOrFloatBitWidth();
         unsigned loadStoreWidth = typeConverter.getLoadStoreBitwidth();
         if (width >= loadStoreWidth)
           return ty;
@@ -636,8 +655,11 @@ void memref::populateMemRefNarrowTypeEmulationConversions(
         if (!strides.empty() && strides.back() != 1)
          return nullptr;

-        auto newElemTy = IntegerType::get(ty.getContext(), loadStoreWidth,
-                                          intTy.getSignedness());
+        auto newElemTy = IntegerType::get(
+            ty.getContext(), loadStoreWidth,
+            elementType.isInteger()
+                ? cast<IntegerType>(elementType).getSignedness()
+                : IntegerType::SignednessSemantics::Signless);
         if (!newElemTy)
           return nullptr;

diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
index 004beadc9ec7d..0fe08417f818f 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
@@ -1268,8 +1268,18 @@ struct ConvertVectorTransferRead final
     bool isDivisibleInSize =
         fitsInMultiByteContainerTy(op.getVectorType(), containerElemTy);

-    auto newPadding = rewriter.create<arith::ExtUIOp>(loc, containerElemTy,
-                                                      adaptor.getPadding());
+    // Pad the padding value with 0s on the left. These bits are discarded and
+    // thus their values don't matter.
+    Value padding = adaptor.getPadding();
+    if (!padding.getType().isInteger()) {
+      padding = rewriter.create<arith::BitcastOp>(
+          loc,
+          IntegerType::get(rewriter.getContext(),
+                           padding.getType().getIntOrFloatBitWidth()),
+          padding);
+    }
+    auto newPadding =
+        rewriter.create<arith::ExtUIOp>(loc, containerElemTy, padding);

     auto stridedMetadata =
         rewriter.create<memref::ExtractStridedMetadataOp>(loc, op.getBase());
diff --git a/mlir/test/Dialect/MemRef/emulate-narrow-type.mlir b/mlir/test/Dialect/MemRef/emulate-narrow-type.mlir
index 3378d329e8205..0cce8c18a40bc 100644
--- a/mlir/test/Dialect/MemRef/emulate-narrow-type.mlir
+++ b/mlir/test/Dialect/MemRef/emulate-narrow-type.mlir
@@ -61,6 +61,41 @@ func.func @memref_load_i4(%arg0: index) -> i4 {

 // -----

+func.func @memref_load_f4(%arg0: index) -> f4E2M1FN {
+  %0 = memref.alloc() : memref<5xf4E2M1FN>
+  %1 = memref.load %0[%arg0] : memref<5xf4E2M1FN>
+  return %1 : f4E2M1FN
+}
+// CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0] -> (s0 floordiv 2)>
+// CHECK-DAG: #[[MAP1:.+]] = affine_map<()[s0] -> (s0 * 4 - (s0 floordiv 2) * 8)
+// CHECK: func @memref_load_f4(
+// CHECK-SAME: %[[ARG0:.+]]: index
+// CHECK: %[[ALLOC:.+]] = memref.alloc() : memref<3xi8>
+// CHECK: %[[INDEX:.+]] = affine.apply #[[MAP0]]()[%[[ARG0]]]
+// CHECK: %[[LOADVAL:.+]] = memref.load %[[ALLOC]][%[[INDEX]]]
+// CHECK: %[[BITOFFSET:.+]] = affine.apply #[[MAP1]]()[%[[ARG0]]]
+// CHECK: %[[CAST:.+]] = arith.index_cast %[[BITOFFSET]] : index to i8
+// CHECK: %[[SHIFTRT:.+]] = arith.shrsi %[[LOADVAL]], %[[CAST]]
+// CHECK: %[[TRUNC:.+]] = arith.trunci %[[SHIFTRT]] : i8 to i4
+// CHECK: %[[BC:.+]] = arith.bitcast %[[TRUNC]] : i4 to f4E2M1FN
+// CHECK: return %[[BC]]
+
+// CHECK32-DAG: #[[MAP0:.+]] = affine_map<()[s0] -> (s0 floordiv 8)>
+// CHECK32-DAG: #[[MAP1:.+]] = affine_map<()[s0] -> (s0 * 4 - (s0 floordiv 8) * 32)
+// CHECK32: func @memref_load_f4(
+// CHECK32-SAME: %[[ARG0:.+]]: index
+// CHECK32: %[[ALLOC:.+]] = memref.alloc() : memref<1xi32>
+// CHECK32: %[[INDEX:.+]] = affine.apply #[[MAP0]]()[%[[ARG0]]]
+// CHECK32: %[[LOADVAL:.+]] = memref.load %[[ALLOC]][%[[INDEX]]]
+// CHECK32: %[[BITOFFSET:.+]] = affine.apply #[[MAP1]]()[%[[ARG0]]]
+// CHECK32: %[[CAST:.+]] = arith.index_cast %[[BITOFFSET]] : index to i32
+// CHECK32: %[[SHIFTRT:.+]] = arith.shrsi %[[LOADVAL]], %[[CAST]]
+// CHECK32: %[[TRUNC:.+]] = arith.trunci %[[SHIFTRT]] : i32 to i4
+// CHECK32: %[[BC:.+]] = arith.bitcast %[[TRUNC]] : i4 to f4E2M1FN
+// CHECK32: return %[[BC]]
+
+// -----
+
 func.func @memref_load_i4_rank2(%arg0: index, %arg1: index) -> i4 {
   %0 = memref.alloc() : memref<3x125xi4>
   %align0 = memref.assume_alignment %0, 64 : memref<3x125xi4>
@@ -470,6 +505,29 @@ func.func @rank_zero_memref_store(%arg0: i4) -> () {

 // -----

+func.func @rank_zero_memref_store_f4(%arg0: f4E2M1FN) -> () {
+  %0 = memref.alloc() : memref<f4E2M1FN>
+  memref.store %arg0, %0[] : memref<f4E2M1FN>
+  return
+}
+// CHECK-LABEL: func @rank_zero_memref
+// CHECK-SAME: %[[ARG0:.+]]: f4E2M1FN
+// CHECK: %[[ALLOC:.+]] = memref.alloc() : memref<i8>
+// CHECK: %[[BC:.+]] = arith.bitcast %[[ARG0]] : f4E2M1FN to i4
+// CHECK: %[[EXTUI:.+]] = arith.extui %[[BC]] : i4 to i8
+// CHECK: %[[WRITE_RMW:.+]] = memref.atomic_rmw assign %[[EXTUI]], %[[ALLOC]][] : (i8, memref<i8>) -> i8
+// CHECK: return
+
+// CHECK32-LABEL: func @rank_zero_memref
+// CHECK32-SAME: %[[ARG0:.+]]: f4E2M1FN
+// CHECK32: %[[ALLOC:.+]] = memref.alloc() : memref<i32>
+// CHECK32: %[[BC:.+]] = arith.bitcast %[[ARG0]] : f4E2M1FN to i4
+// CHECK32: %[[EXTUI:.+]] = arith.extui %[[BC]] : i4 to i32
+// CHECK32: %[[WRITE_RMW:.+]] = memref.atomic_rmw assign %[[EXTUI]], %[[ALLOC]][] : (i32, memref<i32>) -> i32
+// CHECK32: return
+
+// -----
+
 func.func @memref_collapse_shape_i4(%idx0 : index, %idx1 : index) -> i4 {
   %arr = memref.alloc() : memref<32x8x128xi4>
   %collapse = memref.collapse_shape %arr[[0, 1], [2]] : memref<32x8x128xi4> into memref<256x128xi4>
diff --git a/mlir/test/Dialect/Vector/vector-emulate-narrow-type.mlir b/mlir/test/Dialect/Vector/vector-emulate-narrow-type.mlir
index 6c924492b513e..98b1f07ef5fb0 100644
--- a/mlir/test/Dialect/Vector/vector-emulate-narrow-type.mlir
+++ b/mlir/test/Dialect/Vector/vector-emulate-narrow-type.mlir
@@ -53,6 +53,31 @@ func.func @vector_load_i4(%arg1: index, %arg2: index) -> vector<3x8xi4> {

 // -----

+func.func @vector_load_f4(%arg1: index, %arg2: index) -> vector<3x8xf4E2M1FN> {
+  %0 = memref.alloc() : memref<3x8xf4E2M1FN>
+  %cst = arith.constant dense<0.0> : vector<3x8xf4E2M1FN>
+  %1 = vector.load %0[%arg1, %arg2] : memref<3x8xf4E2M1FN>, vector<8xf4E2M1FN>
+  %2 = vector.insert %1, %cst [0] : vector<8xf4E2M1FN> into vector<3x8xf4E2M1FN>
+  return %2 : vector<3x8xf4E2M1FN>
+}
+// CHECK-DAG: #[[MAP:.+]] = affine_map<()[s0, s1] -> (s0 * 4 + s1 floordiv 2)>
+// CHECK: func @vector_load_f4
+// CHECK-SAME: (%[[ARG0:[a-zA-Z0-9]+]]: index, %[[ARG1:[a-zA-Z0-9]+]]: index)
+// CHECK: %[[ALLOC:.+]] = memref.alloc() : memref<12xi8>
+// CHECK: %[[INDEX:.+]] = affine.apply #[[MAP]]()[%[[ARG0]], %[[ARG1]]]
+// CHECK: %[[VEC:.+]] = vector.load %[[ALLOC]][%[[INDEX]]] : memref<12xi8>, vector<4xi8>
+// CHECK: %[[VEC_F4:.+]] = vector.bitcast %[[VEC]] : vector<4xi8> to vector<8xf4E2M1FN>
+
+// CHECK32-DAG: #[[MAP:.+]] = affine_map<()[s0, s1] -> (s0 + s1 floordiv 8)>
+// CHECK32: func @vector_load_f4
+// CHECK32-SAME: (%[[ARG0:[a-zA-Z0-9]+]]: index, %[[ARG1:[a-zA-Z0-9]+]]: index)
+// CHECK32: %[[ALLOC:.+]] = memref.alloc() : memref<3xi32>
+// CHECK32: %[[INDEX:.+]] = affine.apply #[[MAP]]()[%[[ARG0]], %[[ARG1]]]
+// CHECK32: %[[VEC:.+]] = vector.load %[[ALLOC]][%[[INDEX]]] : memref<3xi32>, vector<1xi32>
+// CHECK32: %[[VEC_F4:.+]] = vector.bitcast %[[VEC]] : vector<1xi32> to vector<8xf4E2M1FN>
+
+// -----
+
 func.func @vector_load_i4_dynamic(%arg0 : index, %arg1 : index, %arg2 : index, %arg3 : index) -> vector<8xi4> {
   %0 = memref.alloc(%arg0, %arg1) : memref<?x?xi4>
   %1 = vector.load %0[%arg2, %arg3] : memref<?x?xi4>, vector<8xi4>
@@ -119,6 +144,37 @@ func.func @vector_transfer_read_i4(%arg1: index, %arg2: index) -> vector<8xi4> {

 // -----

+func.func @vector_transfer_read_f4(%arg1: index, %arg2: index) -> vector<8xf4E2M1FN> {
+  %c0 = arith.constant 0.0 : f4E2M1FN
+  %0 = memref.alloc() : memref<3x8xf4E2M1FN>
+  %1 = vector.transfer_read %0[%arg1, %arg2], %c0 {in_bounds = [true]} :
+    memref<3x8xf4E2M1FN>, vector<8xf4E2M1FN>
+  return %1 : vector<8xf4E2M1FN>
+}
+// CHECK-DAG: #[[MAP:.+]] = affine_map<()[s0, s1] -> (s0 * 4 + s1 floordiv 2)>
+// CHECK: func @vector_transfer_read_f4
+// CHECK-SAME: (%[[ARG0:[a-zA-Z0-9]+]]: index, %[[ARG1:[a-zA-Z0-9]+]]: index)
+// CHECK: %[[CONST:.+]] = arith.constant 0.{{0+}}e+00 : f4E2M1FN
+// CHECK: %[[ALLOC:.+]] = memref.alloc() : memref<12xi8>
+// CHECK: %[[BC:.+]] = arith.bitcast %[[CONST]] : f4E2M1FN to i4
+// CHECK: %[[PAD:.+]] = arith.extui %[[BC]] : i4 to i8
+// CHECK: %[[INDEX:.+]] = affine.apply #[[MAP]]()[%[[ARG0]], %[[ARG1]]]
+// CHECK: %[[VEC:.+]] = vector.transfer_read %[[ALLOC]][%[[INDEX]]], %[[PAD]] : memref<12xi8>, vector<4xi8>
+// CHECK: %[[VEC_F4:.+]] = vector.bitcast %[[VEC]] : vector<4xi8> to vector<8xf4E2M1FN>
+
+// CHECK32-DAG: #[[MAP:.+]] = affine_map<()[s0, s1] -> (s0 + s1 floordiv 8)>
+// CHECK32: func @vector_transfer_read_f4
+// CHECK32-SAME: (%[[ARG0:[a-zA-Z0-9]+]]: index, %[[ARG1:[a-zA-Z0-9]+]]: index)
+// CHECK32: %[[CONST:.+]] = arith.constant 0.{{0+}}e+00 : f4E2M1FN
+// CHECK32: %[[ALLOC:.+]] = memref.alloc() : memref<3xi32>
+// CHECK32: %[[BC:.+]] = arith.bitcast %[[CONST]] : f4E2M1FN to i4
+// CHECK32: %[[PAD:.+]] = arith.extui %[[BC]] : i4 to i32
+// CHECK32: %[[INDEX:.+]] = affine.apply #[[MAP]]()[%[[ARG0]], %[[ARG1]]]
+// CHECK32: %[[VEC:.+]] = vector.transfer_read %[[ALLOC]][%[[INDEX]]], %[[PAD]] : memref<3xi32>, vector<1xi32>
+// CHECK32: %[[VEC_F4:.+]] = vector.bitcast %[[VEC]] : vector<1xi32> to vector<8xf4E2M1FN>
+
+// -----
+
 ///----------------------------------------------------------------------------------------
 /// vector.maskedload
 ///----------------------------------------------------------------------------------------
@@ -439,6 +495,28 @@ func.func @vector_store_i4(%arg0: vector<8xi4>, %arg1: index, %arg2: index) {

 // -----

+func.func @vector_store_f4(%arg0: vector<8xf4E2M1FN>, %arg1: index, %arg2: index) {
+  %0 = memref.alloc() : memref<4x8xf4E2M1FN>
+  vector.store %arg0, %0[%arg1, %arg2] :memref<4x8xf4E2M1FN>, vector<8xf4E2M1FN>
+  return
+}
+
+// CHECK-DAG: #[[MAP:.+]] = affine_map<()[s0, s1] -> (s0 * 4 + s1 floordiv 2)>
+// CHECK: func @vector_store_f4
+// CHECK: %[[ALLOC:.+]] = memref.alloc() : memref<16xi8>
+// CHECK: %[[INDEX:.+]] = affine.apply #[[MAP]]()[%[[ARG1]], %[[ARG2]]]
+// CHECK: %[[VEC_I8:.+]] = vector.bitcast %[[ARG0]] : vector<8xf4E2M1FN> to vector<4xi8>
+// CHECK: vector.store %[[VEC_I8:.+]], %[[ALLOC:.+]][%[[INDEX:.+]]] : memref<16xi8>, vector<4xi8>
+
+// CHECK32-DAG: #[[MAP:.+]] = affine_map<()[s0, s1] -> (s0 + s1 floordiv 8)>
+// CHECK32: func @vector_store_f4
+// CHECK32: %[[ALLOC:.+]] = memref.alloc() : memref<4xi32>
+// CHECK32: %[[INDEX:.+]] = affine.apply #[[MAP]]()[%[[ARG1]], %[[ARG2]]]
+// CHECK32: %[[VEC_I32:.+]] = vector.bitcast %[[ARG0]] : vector<8xf4E2M1FN> to vector<1xi32>
+// CHECK32: vector.store %[[VEC_I32:.+]], %[[ALLOC:.+]][%[[INDEX:.+]]] : memref<4xi32>, vector<1xi32>
+
+// -----
+
 // FIXME: This example assumes that the store happens at a byte boundary, but
 // that's not guaranteed. Below is a counter-example with specific dimensions:
 //    vector.store %arg0, %0[0, 3] : memref<2x13xi4>, vector<8xi4>