Skip to content

Commit b68e8f1

Browse files
authored
[mlir][linalg] Allow promotion to use the original subview size (#144334)
linalg promotion attempts to compute a constant upper bound for the allocated buffer size. Only when failed to compute an upperbound it fallbacks to the original subview size, which may be dynamic. Adding a promotion option to use the original subview size by default, thus minimizing the allocation size. Fixes #144268.
1 parent 3c6cade commit b68e8f1

File tree

5 files changed

+80
-4
lines changed

5 files changed

+80
-4
lines changed

mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1402,6 +1402,7 @@ def PromoteOp : Op<Transform_Dialect, "structured.promote",
14021402
DefaultValuedAttr<I64ArrayAttr, "{}">:$operands_to_promote,
14031403
DefaultValuedAttr<BoolArrayAttr, "{}">:$use_full_tile_buffers,
14041404
UnitAttr:$use_full_tiles_by_default,
1405+
UnitAttr:$use_original_subview_size,
14051406
UnitAttr:$use_alloca,
14061407
OptionalAttr<AnyAttr>:$memory_space,
14071408
OptionalAttr<DeviceMappingArrayAttr>:$mapping,

mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -422,6 +422,13 @@ struct LinalgPromotionOptions {
422422
useFullTileBuffersDefault = use;
423423
return *this;
424424
}
425+
/// If true, buffers will be allocated with the original subview size. This
426+
/// may result in more dynamic allocations, in case of dynamic sizes.
427+
bool useOriginalSubviewSize = false;
428+
LinalgPromotionOptions &setUseOriginalSubviewSize(bool originalSize) {
429+
useOriginalSubviewSize = originalSize;
430+
return *this;
431+
}
425432
/// Alignment of promoted buffer. If `std::nullopt` do not specify alignment.
426433
std::optional<unsigned> alignment;
427434
LinalgPromotionOptions &setAlignment(unsigned align) {
@@ -796,7 +803,8 @@ FailureOr<LinalgOp> specializeGenericOp(RewriterBase &rewriter,
796803
GenericOp genericOp);
797804

798805
/// Create a new buffer using the `allocationFn` provided. The size of this
799-
/// buffer is the smallest constant bounding size along each dimension that
806+
/// buffer is either the original subview size when 'useOriginalSubviewSize' is
807+
/// set to true or the smallest constant bounding size along each dimension that
800808
/// can be computed for the size of the result of `subView`. Returns the
801809
/// allocated buffer as `fullLocalView` and the view that matches the size of
802810
/// the result of subview operation as `partialLocalView`.
@@ -806,6 +814,7 @@ struct PromotionInfo {
806814
};
807815
FailureOr<PromotionInfo>
808816
promoteSubviewAsNewBuffer(OpBuilder &b, Location loc, memref::SubViewOp subView,
817+
bool useOriginalSubviewSize,
809818
const AllocBufferCallbackFn &allocationFn,
810819
DataLayout &layout);
811820

mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2408,6 +2408,9 @@ transform::PromoteOp::applyToOne(transform::TransformRewriter &rewriter,
24082408
if (getUseFullTilesByDefault())
24092409
promotionOptions = promotionOptions.setUseFullTileBuffersByDefault(
24102410
getUseFullTilesByDefault());
2411+
if (getUseOriginalSubviewSize())
2412+
promotionOptions =
2413+
promotionOptions.setUseOriginalSubviewSize(getUseOriginalSubviewSize());
24112414
if (getUseAlloca())
24122415
promotionOptions = promotionOptions.setUseAlloca(getUseAlloca());
24132416
if (!getUseFullTileBuffers().empty())

mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -148,6 +148,9 @@ struct LinalgOpInstancePromotionOptions {
148148
llvm::SmallSet<int64_t, 4> operandsNumbersToCopyIn;
149149
/// True if the full view should be used for the promoted buffer.
150150
DenseMap<Value, bool> useFullTileBuffers;
151+
/// True if the original subview size should be used. This means the full tile
152+
/// buffer is the same size as the partial view.
153+
bool useOriginalSubviewSize;
151154

152155
/// Callback functions for allocation and deallocation of promoted buffers, as
153156
/// well as to copy the data into and out of these buffers.
@@ -170,6 +173,7 @@ LinalgOpInstancePromotionOptions::LinalgOpInstancePromotionOptions(
170173
options.useFullTileBuffers.value_or(llvm::SmallBitVector());
171174
vUseFullTileBuffers.resize(linalgOp->getNumOperands(),
172175
options.useFullTileBuffersDefault);
176+
useOriginalSubviewSize = options.useOriginalSubviewSize;
173177

174178
for (OpOperand &opOperand : linalgOp->getOpOperands()) {
175179
int64_t operandNumber = opOperand.getOperandNumber();
@@ -237,7 +241,8 @@ LinalgOpInstancePromotionOptions::LinalgOpInstancePromotionOptions(
237241
// by a partial `copy` op.
238242
FailureOr<PromotionInfo> mlir::linalg::promoteSubviewAsNewBuffer(
239243
OpBuilder &b, Location loc, memref::SubViewOp subView,
240-
const AllocBufferCallbackFn &allocationFn, DataLayout &layout) {
244+
bool useOriginalSubviewSize, const AllocBufferCallbackFn &allocationFn,
245+
DataLayout &layout) {
241246
auto viewType = subView.getType();
242247
auto rank = viewType.getRank();
243248
SmallVector<Value, 4> fullSizes;
@@ -254,7 +259,8 @@ FailureOr<PromotionInfo> mlir::linalg::promoteSubviewAsNewBuffer(
254259
// to look for the bound.
255260
LLVM_DEBUG(llvm::dbgs() << "Extract tightest: " << rangeValue.size << "\n");
256261
Value size;
257-
if (auto attr = llvm::dyn_cast_if_present<Attribute>(rangeValue.size)) {
262+
if (llvm::isa_and_present<Attribute>(rangeValue.size) ||
263+
useOriginalSubviewSize) {
258264
size = getValueOrCreateConstantIndexOp(b, loc, rangeValue.size);
259265
} else {
260266
FailureOr<int64_t> upperBound =
@@ -295,7 +301,8 @@ promoteSubViews(ImplicitLocOpBuilder &b,
295301
memref::SubViewOp subView =
296302
cast<memref::SubViewOp>(v.second.getDefiningOp());
297303
auto promotionInfo = promoteSubviewAsNewBuffer(
298-
b, b.getLoc(), subView, options.allocationFn, layout);
304+
b, b.getLoc(), subView, options.useOriginalSubviewSize,
305+
options.allocationFn, layout);
299306
if (failed(promotionInfo))
300307
return failure();
301308
promotionInfoMap[v.first] = *promotionInfo;

mlir/test/Dialect/Linalg/promotion_options.mlir

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,3 +42,59 @@ module attributes {transform.with_named_sequence} {
4242
transform.yield
4343
}
4444
}
45+
46+
// -----
47+
48+
func.func @matmul_f32(%A: memref<512x256xf32>, %B: memref<256x512xf32>, %C: memref<256x256xf32>, %s0: index, %s1: index, %s2: index) {
49+
%c0 = arith.constant 0 : index
50+
%c256 = arith.constant 256 : index
51+
%c512 = arith.constant 512 : index
52+
scf.for %arg4 = %c0 to %c512 step %s0 {
53+
scf.for %arg5 = %c0 to %c512 step %s1 {
54+
scf.for %arg6 = %c0 to %c256 step %s2 {
55+
%i0 = affine.min affine_map<(d0)[s0] -> (-d0 + 512, s0)>(%arg4)[%s0]
56+
%i1 = affine.min affine_map<(d0)[s0] -> (-d0 + 512, s0)>(%arg5)[%s1]
57+
%i2 = affine.min affine_map<(d0)[s0] -> (-d0 + 256, s0)>(%arg6)[%s2]
58+
%0 = memref.subview %A[%arg4, %arg6][%i0, %i2][1, 1] : memref<512x256xf32> to memref<?x?xf32, strided<[256, 1], offset: ?>>
59+
%1 = memref.subview %B[%arg6, %arg5][%i2, %i1][1, 1] : memref<256x512xf32> to memref<?x?xf32, strided<[512, 1], offset: ?>>
60+
%2 = memref.subview %C[%arg4, %arg5][%i0, %i1][1, 1] : memref<256x256xf32> to memref<?x?xf32, strided<[256, 1], offset: ?>>
61+
linalg.matmul
62+
ins(%0, %1: memref<?x?xf32, strided<[256, 1], offset: ?>>,
63+
memref<?x?xf32, strided<[512, 1], offset: ?>>)
64+
outs(%2: memref<?x?xf32, strided<[256, 1], offset: ?>>)
65+
}
66+
}
67+
}
68+
return
69+
}
70+
71+
// CHECK-LABEL: func.func @matmul_f32(
72+
// CHECK-SAME: %[[ARG0:.*]]: memref<512x256xf32>
73+
// CHECK-SAME: %[[ARG1:.*]]: memref<256x512xf32>
74+
// CHECK-SAME: %[[ARG2:.*]]: memref<256x256xf32>
75+
// CHECK-SAME: %[[ARG3:.*]]: index, %[[ARG4:.*]]: index, %[[ARG5:.*]]: index
76+
// CHECK: %[[C4:.*]] = arith.constant 4 : index
77+
78+
// CHECK: %[[i0:.*]] = affine.min
79+
// CHECK: %[[i1:.*]] = affine.min
80+
// CHECK: %[[i2:.*]] = affine.min
81+
82+
// CHECK: %[[VAL_13:.*]] = arith.muli %[[i0]], %[[i2]] : index
83+
// CHECK: %[[VAL_14:.*]] = arith.muli %[[VAL_13]], %[[C4]] : index
84+
// CHECK: %[[VAL_15:.*]] = memref.alloc(%[[VAL_14]]) : memref<?xi8>
85+
86+
// CHECK: %[[VAL_18:.*]] = arith.muli %[[i2]], %[[i1]] : index
87+
// CHECK: %[[VAL_19:.*]] = arith.muli %[[VAL_18]], %[[C4]] : index
88+
// CHECK: %[[VAL_20:.*]] = memref.alloc(%[[VAL_19]]) : memref<?xi8>
89+
90+
// CHECK: %[[VAL_23:.*]] = arith.muli %[[i0]], %[[i1]] : index
91+
// CHECK: %[[VAL_24:.*]] = arith.muli %[[VAL_23]], %[[C4]] : index
92+
// CHECK: %[[VAL_25:.*]] = memref.alloc(%[[VAL_24]]) : memref<?xi8>
93+
94+
module attributes {transform.with_named_sequence} {
95+
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
96+
%0 = transform.structured.match ops{["linalg.matmul"]} in %arg1 : (!transform.any_op) -> !transform.any_op
97+
%1 = transform.structured.promote %0 { use_original_subview_size } : (!transform.any_op) -> !transform.any_op
98+
transform.yield
99+
}
100+
}

0 commit comments

Comments
 (0)