Skip to content

Commit 4f0136f

Browse files
committed
Check for operand types supported by lowering.
1 parent 37e62a0 commit 4f0136f

File tree

1 file changed

+22
-0
lines changed

1 file changed

+22
-0
lines changed

mlir/lib/Conversion/XeVMToLLVM/XeVMToLLVM.cpp

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -265,6 +265,28 @@ class MMAToOCLPattern : public OpConversionPattern<xevm::MMAOp> {
265265
if (!op.getC()) {
266266
return rewriter.notifyMatchFailure(op, "OCL requires C operand");
267267
}
268+
auto precisionA = op.getTypes().getA();
269+
auto precisionB = op.getTypes().getB();
270+
auto precisionC = op.getTypes().getC();
271+
auto precisionD = op.getTypes().getD();
272+
if (precisionC != precisionD) {
273+
return rewriter.notifyMatchFailure(op, "type of C and D need to match");
274+
}
275+
if (precisionC != xevm::ElemType::S32 &&
276+
precisionC != xevm::ElemType::F32 &&
277+
precisionC != xevm::ElemType::F16 &&
278+
precisionC != xevm::ElemType::BF16) {
279+
return rewriter.notifyMatchFailure(
280+
op, "type of C and D must be S32, F32, F16 or BF16");
281+
}
282+
if (precisionA == xevm::ElemType::S32 ||
283+
precisionA == xevm::ElemType::F32) {
284+
return rewriter.notifyMatchFailure(op, "type of A cannot be S32 or F32");
285+
}
286+
if (precisionB == xevm::ElemType::S32 ||
287+
precisionB == xevm::ElemType::F32) {
288+
return rewriter.notifyMatchFailure(op, "type of B cannot be S32 or F32");
289+
}
268290
constexpr uint32_t bitWidthPackedA{16};
269291
constexpr uint32_t bitWidthPackedB{32};
270292
auto loc = op.getLoc();

0 commit comments

Comments
 (0)