Skip to content

Commit 8744e49

Browse files
committed
Specialization for transpose variants
No code changes needed anymore, all handled by the new specialization classes. Fixing a bug in the previous matmul builder where the affine map was being ignored.
1 parent b2460ad commit 8744e49

File tree

3 files changed

+308
-51
lines changed

3 files changed

+308
-51
lines changed

mlir/include/mlir/Dialect/Linalg/IR/Linalg.h

Lines changed: 118 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -144,4 +144,122 @@ std::pair<int64_t, int64_t> getFmrFromWinogradConv2DFmr(WinogradConv2DFmr fmr);
144144
#define GET_OP_CLASSES
145145
#include "mlir/Dialect/Linalg/IR/LinalgRelayoutOps.h.inc"
146146

147+
namespace mlir::linalg {
148+
149+
/// Specialization of `linalg.matmul` op that has a transpose map on A
150+
class MatmulTransposeAOp : public MatmulOp {
151+
/// Create an affine map for a transpose-A matmul. Used only in the builders.
152+
static SmallVector<AffineMap> getAffineMaps(OpBuilder &builder);
153+
154+
public:
155+
using MatmulOp::MatmulOp;
156+
static ::mlir::TypeID resolveTypeID() { return TypeID::get<MatmulOp>(); }
157+
158+
/// Build a transpose A matmul.
159+
static void build(OpBuilder &builder, OperationState &result,
160+
ValueRange inputs, ValueRange outputs,
161+
ArrayRef<NamedAttribute> attributes = {});
162+
163+
/// Build a transpose A matmul with a specific result type.
164+
static void build(OpBuilder &builder, OperationState &result,
165+
TypeRange resultTensorTypes, ValueRange inputs,
166+
ValueRange outputs, ArrayRef<NamedAttribute> attributes = {});
167+
168+
/// Build a transpose A matmul with a specific result type and a cast type.
169+
static void build(OpBuilder &builder, OperationState &result,
170+
TypeRange resultTensorTypes, ValueRange inputs,
171+
ValueRange outputs, Attribute cast,
172+
ArrayRef<NamedAttribute> attributes = {});
173+
174+
static bool classof(Operation *op);
175+
};
176+
177+
/// Specialization of `linalg.matmul` op that has a transpose map on B
178+
class MatmulTransposeBOp : public MatmulOp {
179+
/// Create an affine map for a transpose-B matmul. Used only in the builders.
180+
static SmallVector<AffineMap> getAffineMaps(OpBuilder &builder);
181+
182+
public:
183+
using MatmulOp::MatmulOp;
184+
static ::mlir::TypeID resolveTypeID() { return TypeID::get<MatmulOp>(); }
185+
186+
/// Build a transpose B matmul.
187+
static void build(OpBuilder &builder, OperationState &result,
188+
ValueRange inputs, ValueRange outputs,
189+
ArrayRef<NamedAttribute> attributes = {});
190+
191+
/// Build a transpose B matmul with a specific result type.
192+
static void build(OpBuilder &builder, OperationState &result,
193+
TypeRange resultTensorTypes, ValueRange inputs,
194+
ValueRange outputs, ArrayRef<NamedAttribute> attributes = {});
195+
196+
/// Build a transpose B matmul with a specific result type and a cast type.
197+
static void build(OpBuilder &builder, OperationState &result,
198+
TypeRange resultTensorTypes, ValueRange inputs,
199+
ValueRange outputs, Attribute cast,
200+
ArrayRef<NamedAttribute> attributes = {});
201+
202+
static bool classof(Operation *op);
203+
};
204+
205+
/// Specialization of `linalg.batch_matmul` op that has a transpose map on A
206+
class BatchMatmulTransposeAOp : public BatchMatmulOp {
207+
/// Create an affine map for a transpose-A batch_matmul. Used only in the
208+
/// builders.
209+
static SmallVector<AffineMap> getAffineMaps(OpBuilder &builder);
210+
211+
public:
212+
using BatchMatmulOp::BatchMatmulOp;
213+
static ::mlir::TypeID resolveTypeID() { return TypeID::get<BatchMatmulOp>(); }
214+
215+
/// Build a transpose A matmul.
216+
static void build(OpBuilder &builder, OperationState &result,
217+
ValueRange inputs, ValueRange outputs,
218+
ArrayRef<NamedAttribute> attributes = {});
219+
220+
/// Build a transpose A matmul with a specific result type.
221+
static void build(OpBuilder &builder, OperationState &result,
222+
TypeRange resultTensorTypes, ValueRange inputs,
223+
ValueRange outputs, ArrayRef<NamedAttribute> attributes = {});
224+
225+
/// Build a transpose A matmul with a specific result type and a cast type.
226+
static void build(OpBuilder &builder, OperationState &result,
227+
TypeRange resultTensorTypes, ValueRange inputs,
228+
ValueRange outputs, Attribute cast,
229+
ArrayRef<NamedAttribute> attributes = {});
230+
231+
static bool classof(Operation *op);
232+
};
233+
234+
/// Specialization of `linalg.batch_matmul` op that has a transpose map on B
235+
class BatchMatmulTransposeBOp : public BatchMatmulOp {
236+
/// Create an affine map for a transpose-B batch_matmul. Used only in the
237+
/// builders.
238+
static SmallVector<AffineMap> getAffineMaps(OpBuilder &builder);
239+
240+
public:
241+
using BatchMatmulOp::BatchMatmulOp;
242+
static ::mlir::TypeID resolveTypeID() { return TypeID::get<BatchMatmulOp>(); }
243+
244+
/// Build a transpose A matmul.
245+
static void build(OpBuilder &builder, OperationState &result,
246+
ValueRange inputs, ValueRange outputs,
247+
ArrayRef<NamedAttribute> attributes = {});
248+
249+
/// Build a transpose A matmul with a specific result type.
250+
static void build(OpBuilder &builder, OperationState &result,
251+
TypeRange resultTensorTypes, ValueRange inputs,
252+
ValueRange outputs, ArrayRef<NamedAttribute> attributes = {});
253+
254+
/// Build a transpose A matmul with a specific result type and a cast type.
255+
static void build(OpBuilder &builder, OperationState &result,
256+
TypeRange resultTensorTypes, ValueRange inputs,
257+
ValueRange outputs, Attribute cast,
258+
ArrayRef<NamedAttribute> attributes = {});
259+
260+
static bool classof(Operation *op);
261+
};
262+
263+
} // namespace mlir::linalg
264+
147265
#endif // MLIR_DIALECT_LINALG_IR_LINALG_H

mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp

Lines changed: 170 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -193,9 +193,10 @@ static void buildMatmulOp(OpBuilder &b, OperationState &state,
193193
ArrayRef<AffineMap> indexingMaps) {
194194
// Initialize indexingMaps attribute, for MatmulOp.
195195
SmallVector<Attribute, 3> indexingMapsAttrVal;
196-
indexingMapsAttrVal = llvm::map_to_vector(
197-
MatmulOp::getDefaultIndexingMaps(b.getContext()),
198-
[](AffineMap map) -> Attribute { return AffineMapAttr::get(map); });
196+
indexingMapsAttrVal =
197+
llvm::map_to_vector(indexingMaps, [](AffineMap map) -> Attribute {
198+
return AffineMapAttr::get(map);
199+
});
199200
state.addAttribute("indexing_maps", b.getArrayAttr(indexingMapsAttrVal));
200201
return buildStructuredOp(b, state, resultTensorTypes, inputs, outputs,
201202
attributes, regionBuilder);
@@ -3881,6 +3882,172 @@ Speculation::Speculatability MatmulOp::getSpeculatability() {
38813882
return getGenericSpeculatabilityImpl(cast<LinalgOp>(getOperation()));
38823883
}
38833884

3885+
SmallVector<AffineMap> MatmulTransposeAOp::getAffineMaps(OpBuilder &builder) {
3886+
AffineExpr d0, d1, d2;
3887+
auto context = builder.getContext();
3888+
bindDims(context, d0, d1, d2);
3889+
AffineMap mapLHS = AffineMap::get(3, 0, {d2, d0}, context);
3890+
AffineMap mapRHS = AffineMap::get(3, 0, {d2, d1}, context);
3891+
AffineMap mapOut = AffineMap::get(3, 0, {d0, d1}, context);
3892+
SmallVector<AffineMap> affineMaps{mapLHS, mapRHS, mapOut};
3893+
return affineMaps;
3894+
}
3895+
3896+
void linalg::MatmulTransposeAOp::build(OpBuilder &builder,
3897+
OperationState &result,
3898+
ValueRange inputs, ValueRange outputs,
3899+
ArrayRef<NamedAttribute> attributes) {
3900+
buildMatmulOp(builder, result, std::nullopt, inputs, outputs, attributes,
3901+
MatmulOp::getRegionBuilder(), getAffineMaps(builder));
3902+
}
3903+
3904+
void linalg::MatmulTransposeAOp::build(OpBuilder &builder,
3905+
OperationState &result,
3906+
TypeRange resultTensorTypes,
3907+
ValueRange inputs, ValueRange outputs,
3908+
ArrayRef<NamedAttribute> attributes) {
3909+
buildMatmulOp(builder, result, resultTensorTypes, inputs, outputs, attributes,
3910+
MatmulOp::getRegionBuilder(), getAffineMaps(builder));
3911+
}
3912+
3913+
void linalg::MatmulTransposeAOp::build(OpBuilder &builder,
3914+
OperationState &result,
3915+
TypeRange resultTensorTypes,
3916+
ValueRange inputs, ValueRange outputs,
3917+
Attribute cast,
3918+
ArrayRef<NamedAttribute> attributes) {
3919+
result.addAttribute("cast", cast);
3920+
buildMatmulOp(builder, result, resultTensorTypes, inputs, outputs, attributes,
3921+
MatmulOp::getRegionBuilder(), getAffineMaps(builder));
3922+
}
3923+
3924+
bool MatmulTransposeAOp::classof(Operation *op) {
3925+
return dyn_cast_or_null<linalg::MatmulOp>(op);
3926+
}
3927+
3928+
SmallVector<AffineMap> MatmulTransposeBOp::getAffineMaps(OpBuilder &builder) {
3929+
AffineExpr d0, d1, d2;
3930+
auto context = builder.getContext();
3931+
bindDims(context, d0, d1, d2);
3932+
AffineMap mapLHS = AffineMap::get(3, 0, {d0, d2}, context);
3933+
AffineMap mapRHS = AffineMap::get(3, 0, {d1, d2}, context);
3934+
AffineMap mapOut = AffineMap::get(3, 0, {d0, d1}, context);
3935+
SmallVector<AffineMap> affineMaps{mapLHS, mapRHS, mapOut};
3936+
return affineMaps;
3937+
}
3938+
3939+
void linalg::MatmulTransposeBOp::build(OpBuilder &builder,
3940+
OperationState &result,
3941+
ValueRange inputs, ValueRange outputs,
3942+
ArrayRef<NamedAttribute> attributes) {
3943+
buildMatmulOp(builder, result, std::nullopt, inputs, outputs, attributes,
3944+
MatmulOp::getRegionBuilder(), getAffineMaps(builder));
3945+
}
3946+
3947+
void linalg::MatmulTransposeBOp::build(OpBuilder &builder,
3948+
OperationState &result,
3949+
TypeRange resultTensorTypes,
3950+
ValueRange inputs, ValueRange outputs,
3951+
ArrayRef<NamedAttribute> attributes) {
3952+
buildMatmulOp(builder, result, resultTensorTypes, inputs, outputs, attributes,
3953+
MatmulOp::getRegionBuilder(), getAffineMaps(builder));
3954+
}
3955+
3956+
void linalg::MatmulTransposeBOp::build(OpBuilder &builder,
3957+
OperationState &result,
3958+
TypeRange resultTensorTypes,
3959+
ValueRange inputs, ValueRange outputs,
3960+
Attribute cast,
3961+
ArrayRef<NamedAttribute> attributes) {
3962+
result.addAttribute("cast", cast);
3963+
buildMatmulOp(builder, result, resultTensorTypes, inputs, outputs, attributes,
3964+
MatmulOp::getRegionBuilder(), getAffineMaps(builder));
3965+
}
3966+
3967+
bool MatmulTransposeBOp::classof(Operation *op) {
3968+
return dyn_cast_or_null<linalg::MatmulOp>(op);
3969+
}
3970+
3971+
SmallVector<AffineMap>
3972+
BatchMatmulTransposeAOp::getAffineMaps(OpBuilder &builder) {
3973+
AffineExpr d0, d1, d2, d3;
3974+
auto context = builder.getContext();
3975+
bindDims(context, d0, d1, d2, d3);
3976+
AffineMap mapLHS = AffineMap::get(4, 0, {d0, d3, d1}, context);
3977+
AffineMap mapRHS = AffineMap::get(4, 0, {d0, d3, d2}, context);
3978+
AffineMap mapOut = AffineMap::get(4, 0, {d0, d1, d2}, context);
3979+
SmallVector<AffineMap> affineMaps{mapLHS, mapRHS, mapOut};
3980+
return affineMaps;
3981+
}
3982+
3983+
void linalg::BatchMatmulTransposeAOp::build(
3984+
OpBuilder &builder, OperationState &result, ValueRange inputs,
3985+
ValueRange outputs, ArrayRef<NamedAttribute> attributes) {
3986+
buildMatmulOp(builder, result, std::nullopt, inputs, outputs, attributes,
3987+
BatchMatmulOp::getRegionBuilder(), getAffineMaps(builder));
3988+
}
3989+
3990+
void linalg::BatchMatmulTransposeAOp::build(
3991+
OpBuilder &builder, OperationState &result, TypeRange resultTensorTypes,
3992+
ValueRange inputs, ValueRange outputs,
3993+
ArrayRef<NamedAttribute> attributes) {
3994+
buildMatmulOp(builder, result, resultTensorTypes, inputs, outputs, attributes,
3995+
BatchMatmulOp::getRegionBuilder(), getAffineMaps(builder));
3996+
}
3997+
3998+
void linalg::BatchMatmulTransposeAOp::build(
3999+
OpBuilder &builder, OperationState &result, TypeRange resultTensorTypes,
4000+
ValueRange inputs, ValueRange outputs, Attribute cast,
4001+
ArrayRef<NamedAttribute> attributes) {
4002+
result.addAttribute("cast", cast);
4003+
buildMatmulOp(builder, result, resultTensorTypes, inputs, outputs, attributes,
4004+
BatchMatmulOp::getRegionBuilder(), getAffineMaps(builder));
4005+
}
4006+
4007+
bool BatchMatmulTransposeAOp::classof(Operation *op) {
4008+
return dyn_cast_or_null<linalg::BatchMatmulOp>(op);
4009+
}
4010+
4011+
SmallVector<AffineMap>
4012+
BatchMatmulTransposeBOp::getAffineMaps(OpBuilder &builder) {
4013+
AffineExpr d0, d1, d2, d3;
4014+
auto context = builder.getContext();
4015+
bindDims(context, d0, d1, d2, d3);
4016+
AffineMap mapLHS = AffineMap::get(4, 0, {d0, d1, d3}, context);
4017+
AffineMap mapRHS = AffineMap::get(4, 0, {d0, d2, d3}, context);
4018+
AffineMap mapOut = AffineMap::get(4, 0, {d0, d1, d2}, context);
4019+
SmallVector<AffineMap> affineMaps{mapLHS, mapRHS, mapOut};
4020+
return affineMaps;
4021+
}
4022+
4023+
void linalg::BatchMatmulTransposeBOp::build(
4024+
OpBuilder &builder, OperationState &result, ValueRange inputs,
4025+
ValueRange outputs, ArrayRef<NamedAttribute> attributes) {
4026+
buildMatmulOp(builder, result, std::nullopt, inputs, outputs, attributes,
4027+
BatchMatmulOp::getRegionBuilder(), getAffineMaps(builder));
4028+
}
4029+
4030+
void linalg::BatchMatmulTransposeBOp::build(
4031+
OpBuilder &builder, OperationState &result, TypeRange resultTensorTypes,
4032+
ValueRange inputs, ValueRange outputs,
4033+
ArrayRef<NamedAttribute> attributes) {
4034+
buildMatmulOp(builder, result, resultTensorTypes, inputs, outputs, attributes,
4035+
BatchMatmulOp::getRegionBuilder(), getAffineMaps(builder));
4036+
}
4037+
4038+
void linalg::BatchMatmulTransposeBOp::build(
4039+
OpBuilder &builder, OperationState &result, TypeRange resultTensorTypes,
4040+
ValueRange inputs, ValueRange outputs, Attribute cast,
4041+
ArrayRef<NamedAttribute> attributes) {
4042+
result.addAttribute("cast", cast);
4043+
buildMatmulOp(builder, result, resultTensorTypes, inputs, outputs, attributes,
4044+
BatchMatmulOp::getRegionBuilder(), getAffineMaps(builder));
4045+
}
4046+
4047+
bool BatchMatmulTransposeBOp::classof(Operation *op) {
4048+
return dyn_cast_or_null<linalg::BatchMatmulOp>(op);
4049+
}
4050+
38844051
//===----------------------------------------------------------------------===//
38854052
// ContractOp
38864053
//===----------------------------------------------------------------------===//

0 commit comments

Comments
 (0)