@@ -303,13 +303,14 @@ class MMAToOCLPattern : public OpConversionPattern<xevm::MMAOp> {
303
303
304
304
Value c = op.getC ();
305
305
VectorType cOrigTy = cast<VectorType>(c.getType ());
306
- assert (cOrigTy == op->getResultTypes ()[0 ] &&
307
- " Accumulator and result type mismatch" );
306
+ VectorType resOrigTy = cast<VectorType>( op->getResultTypes ()[0 ]);
307
+ assert (cOrigTy == resOrigTy && " Accumulator and result type mismatch" );
308
308
// OCL builtins encode bfloat16 as int16
309
309
VectorType cTy =
310
310
cOrigTy.getElementType ().isBF16 ()
311
311
? VectorType::get (cOrigTy.getShape (), rewriter.getIntegerType (16 ))
312
312
: cOrigTy;
313
+ VectorType resTy = cTy;
313
314
if (cOrigTy != cTy)
314
315
c = rewriter.create <LLVM::BitcastOp>(loc, cTy, c);
315
316
@@ -332,12 +333,12 @@ class MMAToOCLPattern : public OpConversionPattern<xevm::MMAOp> {
332
333
auto funcAttrs = convergentNoUnwindWillReturnAttrs;
333
334
funcAttrs.memEffectsAttr = memAttr;
334
335
Value result =
335
- createDeviceFunctionCall (rewriter, fnName, cTy , argTypes, args, {},
336
+ createDeviceFunctionCall (rewriter, fnName, resTy , argTypes, args, {},
336
337
funcAttrs, op.getOperation ())
337
338
->getResult (0 );
338
339
339
- if (cOrigTy != cTy )
340
- result = rewriter.create <LLVM::BitcastOp>(loc, cOrigTy , result);
340
+ if (resOrigTy != resTy )
341
+ result = rewriter.create <LLVM::BitcastOp>(loc, resOrigTy , result);
341
342
342
343
rewriter.replaceOp (op, result);
343
344
return success ();
0 commit comments