@@ -265,6 +265,28 @@ class MMAToOCLPattern : public OpConversionPattern<xevm::MMAOp> {
265
265
if (!op.getC ()) {
266
266
return rewriter.notifyMatchFailure (op, " OCL requires C operand" );
267
267
}
268
+ auto precisionA = op.getTypes ().getA ();
269
+ auto precisionB = op.getTypes ().getB ();
270
+ auto precisionC = op.getTypes ().getC ();
271
+ auto precisionD = op.getTypes ().getD ();
272
+ if (precisionC != precisionD) {
273
+ return rewriter.notifyMatchFailure (op, " type of C and D need to match" );
274
+ }
275
+ if (precisionC != xevm::ElemType::S32 &&
276
+ precisionC != xevm::ElemType::F32 &&
277
+ precisionC != xevm::ElemType::F16 &&
278
+ precisionC != xevm::ElemType::BF16) {
279
+ return rewriter.notifyMatchFailure (
280
+ op, " type of C and D must be S32, F32, F16 or BF16" );
281
+ }
282
+ if (precisionA == xevm::ElemType::S32 ||
283
+ precisionA == xevm::ElemType::F32) {
284
+ return rewriter.notifyMatchFailure (op, " type of A cannot be S32 or F32" );
285
+ }
286
+ if (precisionB == xevm::ElemType::S32 ||
287
+ precisionB == xevm::ElemType::F32) {
288
+ return rewriter.notifyMatchFailure (op, " type of B cannot be S32 or F32" );
289
+ }
268
290
constexpr uint32_t bitWidthPackedA{16 };
269
291
constexpr uint32_t bitWidthPackedB{32 };
270
292
auto loc = op.getLoc ();
0 commit comments