@@ -4169,28 +4169,35 @@ class StridedSliceBroadcast final
4169
4169
auto dstVecType = llvm::cast<VectorType>(op.getType ());
4170
4170
unsigned dstRank = dstVecType.getRank ();
4171
4171
unsigned rankDiff = dstRank - srcRank;
4172
- // Check if the most inner dimensions of the source of the broadcast are the
4173
- // same as the destination of the extract. If this is the case we can just
4174
- // use a broadcast as the original dimensions are untouched.
4175
- bool lowerDimMatch = true ;
4172
+ // Source dimensions can be broadcasted (1 -> n with n > 1) or sliced
4173
+ // (n -> m with n > m). If they are originally both broadcasted *and*
4174
+ // sliced, this can be simplified to just broadcasting.
4175
+ bool needsSlice = false ;
4176
4176
for (unsigned i = 0 ; i < srcRank; i++) {
4177
- if (srcVecType.getDimSize (i) != dstVecType.getDimSize (i + rankDiff)) {
4178
- lowerDimMatch = false ;
4177
+ if (srcVecType.getDimSize (i) != 1 &&
4178
+ srcVecType.getDimSize (i) != dstVecType.getDimSize (i + rankDiff)) {
4179
+ needsSlice = true ;
4179
4180
break ;
4180
4181
}
4181
4182
}
4182
4183
Value source = broadcast.getSource ();
4183
- // If the inner dimensions don't match, it means we need to extract from the
4184
- // source of the original broadcast and then broadcast the extracted value.
4185
- // We also need to handle degenerated cases where the source is effectively
4186
- // just a single scalar.
4187
- bool isScalarSrc = (srcRank == 0 || srcVecType.getNumElements () == 1 );
4188
- if (!lowerDimMatch && !isScalarSrc) {
4184
+ if (needsSlice) {
4185
+ SmallVector<int64_t > offsets =
4186
+ getI64SubArray (op.getOffsets (), /* dropFront=*/ rankDiff);
4187
+ SmallVector<int64_t > sizes =
4188
+ getI64SubArray (op.getSizes (), /* dropFront=*/ rankDiff);
4189
+ for (unsigned i = 0 ; i < srcRank; i++) {
4190
+ if (srcVecType.getDimSize (i) == 1 ) {
4191
+ // In case this dimension was broadcasted *and* sliced, the offset
4192
+ // and size need to be updated now that there is no broadcast before
4193
+ // the slice.
4194
+ offsets[i] = 0 ;
4195
+ sizes[i] = 1 ;
4196
+ }
4197
+ }
4189
4198
source = rewriter.create <ExtractStridedSliceOp>(
4190
- op->getLoc (), source,
4191
- getI64SubArray (op.getOffsets (), /* dropFront=*/ rankDiff),
4192
- getI64SubArray (op.getSizes (), /* dropFront=*/ rankDiff),
4193
- getI64SubArray (op.getStrides (), /* dropFront=*/ rankDiff));
4199
+ op->getLoc (), source, offsets, sizes,
4200
+ getI64SubArray (op.getStrides (), /* dropFront=*/ rankDiff));
4194
4201
}
4195
4202
rewriter.replaceOpWithNewOp <BroadcastOp>(op, op.getType (), source);
4196
4203
return success ();
0 commit comments