@@ -80,6 +80,46 @@ struct CastOpInterface
80
80
}
81
81
};
82
82
83
+ // / Bufferization of tensor.collapse_shape. Replace with memref.collapse_shape.
84
+ struct CollapseShapeOpInterface
85
+ : public BufferizableOpInterface::ExternalModel<CollapseShapeOpInterface,
86
+ tensor::CollapseShapeOp> {
87
+ bool bufferizesToMemoryRead (Operation *op, OpOperand &opOperand,
88
+ const BufferizationState &state) const {
89
+ return false ;
90
+ }
91
+
92
+ bool bufferizesToMemoryWrite (Operation *op, OpOperand &opOperand,
93
+ const BufferizationState &state) const {
94
+ return false ;
95
+ }
96
+
97
+ SmallVector<OpResult>
98
+ getAliasingOpResult (Operation *op, OpOperand &opOperand,
99
+ const BufferizationState &state) const {
100
+ if (&opOperand == &op->getOpOperand (0 ) /* src*/ )
101
+ return {op->getOpResult (0 )};
102
+ return {};
103
+ }
104
+
105
+ BufferRelation bufferRelation (Operation *op, OpResult opResult,
106
+ const BufferizationState &state) const {
107
+ return BufferRelation::Equivalent;
108
+ }
109
+
110
+ LogicalResult bufferize (Operation *op, RewriterBase &rewriter,
111
+ const BufferizationState &state) const {
112
+ auto collapseShapeOp = cast<tensor::CollapseShapeOp>(op);
113
+ Value buffer =
114
+ *state.getBuffer (rewriter, collapseShapeOp->getOpOperand (0 ) /* src*/ );
115
+ Type resultType =
116
+ getMemRefType (collapseShapeOp.getResultType (), state.getOptions ());
117
+ replaceOpWithNewBufferizedOp<memref::CollapseShapeOp>(
118
+ rewriter, op, resultType, buffer, collapseShapeOp.reassociation ());
119
+ return success ();
120
+ }
121
+ };
122
+
83
123
// / Bufferization of tensor.dim. Replace with memref.dim.
84
124
struct DimOpInterface
85
125
: public BufferizableOpInterface::ExternalModel<DimOpInterface,
@@ -109,6 +149,46 @@ struct DimOpInterface
109
149
}
110
150
};
111
151
152
+ // / Bufferization of tensor.expand_shape. Replace with memref.expand_shape.
153
+ struct ExpandShapeOpInterface
154
+ : public BufferizableOpInterface::ExternalModel<ExpandShapeOpInterface,
155
+ tensor::ExpandShapeOp> {
156
+ bool bufferizesToMemoryRead (Operation *op, OpOperand &opOperand,
157
+ const BufferizationState &state) const {
158
+ return false ;
159
+ }
160
+
161
+ bool bufferizesToMemoryWrite (Operation *op, OpOperand &opOperand,
162
+ const BufferizationState &state) const {
163
+ return false ;
164
+ }
165
+
166
+ SmallVector<OpResult>
167
+ getAliasingOpResult (Operation *op, OpOperand &opOperand,
168
+ const BufferizationState &state) const {
169
+ if (&opOperand == &op->getOpOperand (0 ) /* src*/ )
170
+ return {op->getOpResult (0 )};
171
+ return {};
172
+ }
173
+
174
+ BufferRelation bufferRelation (Operation *op, OpResult opResult,
175
+ const BufferizationState &state) const {
176
+ return BufferRelation::Equivalent;
177
+ }
178
+
179
+ LogicalResult bufferize (Operation *op, RewriterBase &rewriter,
180
+ const BufferizationState &state) const {
181
+ auto expandShapeOp = cast<tensor::ExpandShapeOp>(op);
182
+ Value buffer =
183
+ *state.getBuffer (rewriter, expandShapeOp->getOpOperand (0 ) /* src*/ );
184
+ Type resultType =
185
+ getMemRefType (expandShapeOp.getResultType (), state.getOptions ());
186
+ replaceOpWithNewBufferizedOp<memref::ExpandShapeOp>(
187
+ rewriter, op, resultType, buffer, expandShapeOp.reassociation ());
188
+ return success ();
189
+ }
190
+ };
191
+
112
192
// / Bufferization of tensor.extract_slice. Replace with memref.subview.
113
193
struct ExtractSliceOpInterface
114
194
: public BufferizableOpInterface::ExternalModel<ExtractSliceOpInterface,
@@ -635,7 +715,9 @@ struct RankOpInterface
635
715
void mlir::tensor::registerBufferizableOpInterfaceExternalModels (
636
716
DialectRegistry ®istry) {
637
717
registry.addOpInterface <CastOp, CastOpInterface>();
718
+ registry.addOpInterface <CollapseShapeOp, CollapseShapeOpInterface>();
638
719
registry.addOpInterface <DimOp, DimOpInterface>();
720
+ registry.addOpInterface <ExpandShapeOp, ExpandShapeOpInterface>();
639
721
registry.addOpInterface <ExtractSliceOp, ExtractSliceOpInterface>();
640
722
registry.addOpInterface <ExtractOp, ExtractOpInterface>();
641
723
registry.addOpInterface <FromElementsOp, FromElementsOpInterface>();
0 commit comments