@@ -76,14 +76,16 @@ class ConvertOpToGpuRuntimeCallPattern : public ConvertOpToLLVMPattern<OpTy> {
76
76
/// Returns the number of elements of the memref described by `desc`, as a
/// value of the index type used by this pattern.
///
/// When `type` has a static shape, the count is emitted as an index constant
/// taken from `type.getNumElements()`. Otherwise it is computed at runtime as
/// the product of every dimension size read from the descriptor; strides are
/// not consulted.
Value getNumElements(ConversionPatternRewriter &rewriter, Location loc,
                     MemRefType type, MemRefDescriptor desc) const {
  Type indexType = ConvertToLLVMPattern::getIndexType();
  if (type.hasStaticShape()) {
    return ConvertToLLVMPattern::createIndexAttrConstant(
        rewriter, loc, indexType, type.getNumElements());
  }

  // Dynamic shape: fold the sizes of all dimensions into a running product.
  // NOTE(review): dimension 0 is read unconditionally, so this path assumes
  // rank >= 1 (a rank-0 type presumably always takes the static branch
  // above) -- confirm.
  Value product = desc.size(rewriter, loc, /*pos=*/0);
  for (unsigned dim = 1, rank = type.getRank(); dim < rank; ++dim) {
    Value dimSize = desc.size(rewriter, loc, /*pos=*/dim);
    product = rewriter.create<LLVM::MulOp>(loc, product, dimSize);
  }
  return product;
}
88
90
89
91
MLIRContext *context = &this ->getTypeConverter ()->getContext();
0 commit comments