Skip to content

Commit e6f6916

Browse files
[mlir][bufferize] Support tensor.expand_shape and tensor.collapse_shape
Differential Revision: https://reviews.llvm.org/D112512
1 parent a65b9dd commit e6f6916

File tree

2 files changed

+110
-0
lines changed

2 files changed

+110
-0
lines changed

mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp

Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,46 @@ struct CastOpInterface
8080
}
8181
};
8282

83+
/// Bufferization of tensor.collapse_shape. Replace with memref.collapse_shape.
84+
struct CollapseShapeOpInterface
85+
: public BufferizableOpInterface::ExternalModel<CollapseShapeOpInterface,
86+
tensor::CollapseShapeOp> {
87+
bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
88+
const BufferizationState &state) const {
89+
return false;
90+
}
91+
92+
bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
93+
const BufferizationState &state) const {
94+
return false;
95+
}
96+
97+
SmallVector<OpResult>
98+
getAliasingOpResult(Operation *op, OpOperand &opOperand,
99+
const BufferizationState &state) const {
100+
if (&opOperand == &op->getOpOperand(0) /*src*/)
101+
return {op->getOpResult(0)};
102+
return {};
103+
}
104+
105+
BufferRelation bufferRelation(Operation *op, OpResult opResult,
106+
const BufferizationState &state) const {
107+
return BufferRelation::Equivalent;
108+
}
109+
110+
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
111+
const BufferizationState &state) const {
112+
auto collapseShapeOp = cast<tensor::CollapseShapeOp>(op);
113+
Value buffer =
114+
*state.getBuffer(rewriter, collapseShapeOp->getOpOperand(0) /*src*/);
115+
Type resultType =
116+
getMemRefType(collapseShapeOp.getResultType(), state.getOptions());
117+
replaceOpWithNewBufferizedOp<memref::CollapseShapeOp>(
118+
rewriter, op, resultType, buffer, collapseShapeOp.reassociation());
119+
return success();
120+
}
121+
};
122+
83123
/// Bufferization of tensor.dim. Replace with memref.dim.
84124
struct DimOpInterface
85125
: public BufferizableOpInterface::ExternalModel<DimOpInterface,
@@ -109,6 +149,46 @@ struct DimOpInterface
109149
}
110150
};
111151

152+
/// Bufferization of tensor.expand_shape. Replace with memref.expand_shape.
153+
struct ExpandShapeOpInterface
154+
: public BufferizableOpInterface::ExternalModel<ExpandShapeOpInterface,
155+
tensor::ExpandShapeOp> {
156+
bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
157+
const BufferizationState &state) const {
158+
return false;
159+
}
160+
161+
bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
162+
const BufferizationState &state) const {
163+
return false;
164+
}
165+
166+
SmallVector<OpResult>
167+
getAliasingOpResult(Operation *op, OpOperand &opOperand,
168+
const BufferizationState &state) const {
169+
if (&opOperand == &op->getOpOperand(0) /*src*/)
170+
return {op->getOpResult(0)};
171+
return {};
172+
}
173+
174+
BufferRelation bufferRelation(Operation *op, OpResult opResult,
175+
const BufferizationState &state) const {
176+
return BufferRelation::Equivalent;
177+
}
178+
179+
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
180+
const BufferizationState &state) const {
181+
auto expandShapeOp = cast<tensor::ExpandShapeOp>(op);
182+
Value buffer =
183+
*state.getBuffer(rewriter, expandShapeOp->getOpOperand(0) /*src*/);
184+
Type resultType =
185+
getMemRefType(expandShapeOp.getResultType(), state.getOptions());
186+
replaceOpWithNewBufferizedOp<memref::ExpandShapeOp>(
187+
rewriter, op, resultType, buffer, expandShapeOp.reassociation());
188+
return success();
189+
}
190+
};
191+
112192
/// Bufferization of tensor.extract_slice. Replace with memref.subview.
113193
struct ExtractSliceOpInterface
114194
: public BufferizableOpInterface::ExternalModel<ExtractSliceOpInterface,
@@ -635,7 +715,9 @@ struct RankOpInterface
635715
void mlir::tensor::registerBufferizableOpInterfaceExternalModels(
636716
DialectRegistry &registry) {
637717
registry.addOpInterface<CastOp, CastOpInterface>();
718+
registry.addOpInterface<CollapseShapeOp, CollapseShapeOpInterface>();
638719
registry.addOpInterface<DimOp, DimOpInterface>();
720+
registry.addOpInterface<ExpandShapeOp, ExpandShapeOpInterface>();
639721
registry.addOpInterface<ExtractSliceOp, ExtractSliceOpInterface>();
640722
registry.addOpInterface<ExtractOp, ExtractOpInterface>();
641723
registry.addOpInterface<FromElementsOp, FromElementsOpInterface>();

mlir/test/Dialect/Tensor/bufferize.mlir

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -301,3 +301,31 @@ func @tensor.insert(%t1: tensor<5xf32>, %idx1: index, %f: f32) -> tensor<5xf32>
301301
// CHECK: return %[[r]]
302302
return %0 : tensor<5xf32>
303303
}
304+
305+
// CHECK-LABEL: func @tensor.expand_shape(
306+
// CHECK-SAME: %[[t1:.*]]: tensor<?x10xf32>
307+
func @tensor.expand_shape(%t1: tensor<?x10xf32>) -> tensor<2x?x10xf32> {
308+
// CHECK: %[[m1:.*]] = bufferization.to_memref %[[t1]] : memref<?x10xf32>
309+
// CHECK: %[[expanded:.*]] = memref.expand_shape %[[m1]] [
310+
// CHECK-SAME: [0, 1], [2]] : memref<?x10xf32> into memref<2x?x10xf32>
311+
%0 = tensor.expand_shape %t1 [[0, 1], [2]]
312+
: tensor<?x10xf32> into tensor<2x?x10xf32>
313+
314+
// CHECK: %[[r:.*]] = bufferization.to_tensor %[[expanded]]
315+
// CHECK: return %[[r]]
316+
return %0 : tensor<2x?x10xf32>
317+
}
318+
319+
// CHECK-LABEL: func @tensor.collapse_shape(
320+
// CHECK-SAME: %[[t1:.*]]: tensor<2x?x?xf32>
321+
func @tensor.collapse_shape(%t1: tensor<2x?x?xf32>) -> tensor<?x?xf32> {
322+
// CHECK: %[[m1:.*]] = bufferization.to_memref %[[t1]] : memref<2x?x?xf32>
323+
// CHECK: %[[collapsed:.*]] = memref.collapse_shape %[[m1]] [
324+
// CHECK-SAME: [0, 1], [2]] : memref<2x?x?xf32> into memref<?x?xf32>
325+
%0 = tensor.collapse_shape %t1 [[0, 1], [2]]
326+
: tensor<2x?x?xf32> into tensor<?x?xf32>
327+
328+
// CHECK: %[[r:.*]] = bufferization.to_tensor %[[collapsed]]
329+
// CHECK: return %[[r]]
330+
return %0 : tensor<?x?xf32>
331+
}

0 commit comments

Comments
 (0)