Skip to content

Commit 9fc0400

Browse files
committed
Use result type instead of C type for MMA.
1 parent 49efd76 commit 9fc0400

File tree

1 file changed

+6
-5
lines changed

1 file changed

+6
-5
lines changed

mlir/lib/Conversion/XeVMToLLVM/XeVMToLLVM.cpp

Lines changed: 6 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -303,13 +303,14 @@ class MMAToOCLPattern : public OpConversionPattern<xevm::MMAOp> {
303303

304304
Value c = op.getC();
305305
VectorType cOrigTy = cast<VectorType>(c.getType());
306-
assert(cOrigTy == op->getResultTypes()[0] &&
307-
"Accumulator and result type mismatch");
306+
VectorType resOrigTy = cast<VectorType>(op->getResultTypes()[0]);
307+
assert(cOrigTy == resOrigTy && "Accumulator and result type mismatch");
308308
// OCL builtins encode bfloat16 as int16
309309
VectorType cTy =
310310
cOrigTy.getElementType().isBF16()
311311
? VectorType::get(cOrigTy.getShape(), rewriter.getIntegerType(16))
312312
: cOrigTy;
313+
VectorType resTy = cTy;
313314
if (cOrigTy != cTy)
314315
c = rewriter.create<LLVM::BitcastOp>(loc, cTy, c);
315316

@@ -332,12 +333,12 @@ class MMAToOCLPattern : public OpConversionPattern<xevm::MMAOp> {
332333
auto funcAttrs = convergentNoUnwindWillReturnAttrs;
333334
funcAttrs.memEffectsAttr = memAttr;
334335
Value result =
335-
createDeviceFunctionCall(rewriter, fnName, cTy, argTypes, args, {},
336+
createDeviceFunctionCall(rewriter, fnName, resTy, argTypes, args, {},
336337
funcAttrs, op.getOperation())
337338
->getResult(0);
338339

339-
if (cOrigTy != cTy)
340-
result = rewriter.create<LLVM::BitcastOp>(loc, cOrigTy, result);
340+
if (resOrigTy != resTy)
341+
result = rewriter.create<LLVM::BitcastOp>(loc, resOrigTy, result);
341342

342343
rewriter.replaceOp(op, result);
343344
return success();

0 commit comments

Comments
 (0)