@@ -7084,61 +7084,136 @@ class HorizontalReduction {
7084
7084
7085
7085
} // end anonymous namespace
7086
7086
7087
+ static Optional<unsigned > getAggregateSize (Instruction *InsertInst) {
7088
+ if (auto *IE = dyn_cast<InsertElementInst>(InsertInst))
7089
+ return cast<FixedVectorType>(IE->getType ())->getNumElements ();
7090
+
7091
+ unsigned AggregateSize = 1 ;
7092
+ auto *IV = cast<InsertValueInst>(InsertInst);
7093
+ Type *CurrentType = IV->getType ();
7094
+ do {
7095
+ if (auto *ST = dyn_cast<StructType>(CurrentType)) {
7096
+ for (auto *Elt : ST->elements ())
7097
+ if (Elt != ST->getElementType (0 )) // check homogeneity
7098
+ return None;
7099
+ AggregateSize *= ST->getNumElements ();
7100
+ CurrentType = ST->getElementType (0 );
7101
+ } else if (auto *AT = dyn_cast<ArrayType>(CurrentType)) {
7102
+ AggregateSize *= AT->getNumElements ();
7103
+ CurrentType = AT->getElementType ();
7104
+ } else if (auto *VT = dyn_cast<FixedVectorType>(CurrentType)) {
7105
+ AggregateSize *= VT->getNumElements ();
7106
+ return AggregateSize;
7107
+ } else if (CurrentType->isSingleValueType ()) {
7108
+ return AggregateSize;
7109
+ } else {
7110
+ return None;
7111
+ }
7112
+ } while (true );
7113
+ }
7114
+
7115
+ static Optional<unsigned > getOperandIndex (Instruction *InsertInst,
7116
+ unsigned OperandOffset) {
7117
+ unsigned OperandIndex = OperandOffset;
7118
+ if (auto *IE = dyn_cast<InsertElementInst>(InsertInst)) {
7119
+ if (auto *CI = dyn_cast<ConstantInt>(IE->getOperand (2 ))) {
7120
+ auto *VT = cast<FixedVectorType>(IE->getType ());
7121
+ OperandIndex *= VT->getNumElements ();
7122
+ OperandIndex += CI->getZExtValue ();
7123
+ return OperandIndex;
7124
+ }
7125
+ return None;
7126
+ }
7127
+
7128
+ auto *IV = cast<InsertValueInst>(InsertInst);
7129
+ Type *CurrentType = IV->getType ();
7130
+ for (unsigned int Index : IV->indices ()) {
7131
+ if (auto *ST = dyn_cast<StructType>(CurrentType)) {
7132
+ OperandIndex *= ST->getNumElements ();
7133
+ CurrentType = ST->getElementType (Index);
7134
+ } else if (auto *AT = dyn_cast<ArrayType>(CurrentType)) {
7135
+ OperandIndex *= AT->getNumElements ();
7136
+ CurrentType = AT->getElementType ();
7137
+ } else {
7138
+ return None;
7139
+ }
7140
+ OperandIndex += Index;
7141
+ }
7142
+ return OperandIndex;
7143
+ }
7144
+
7145
+ static bool findBuildAggregate_rec (Instruction *LastInsertInst,
7146
+ TargetTransformInfo *TTI,
7147
+ SmallVectorImpl<Value *> &BuildVectorOpds,
7148
+ SmallVectorImpl<Value *> &InsertElts,
7149
+ unsigned OperandOffset) {
7150
+ do {
7151
+ Value *InsertedOperand = LastInsertInst->getOperand (1 );
7152
+ Optional<unsigned > OperandIndex =
7153
+ getOperandIndex (LastInsertInst, OperandOffset);
7154
+ if (!OperandIndex)
7155
+ return false ;
7156
+ if (isa<InsertElementInst>(InsertedOperand) ||
7157
+ isa<InsertValueInst>(InsertedOperand)) {
7158
+ if (!findBuildAggregate_rec (cast<Instruction>(InsertedOperand), TTI,
7159
+ BuildVectorOpds, InsertElts, *OperandIndex))
7160
+ return false ;
7161
+ } else {
7162
+ BuildVectorOpds[*OperandIndex] = InsertedOperand;
7163
+ InsertElts[*OperandIndex] = LastInsertInst;
7164
+ }
7165
+ if (isa<UndefValue>(LastInsertInst->getOperand (0 )))
7166
+ return true ;
7167
+ LastInsertInst = dyn_cast<Instruction>(LastInsertInst->getOperand (0 ));
7168
+ } while (LastInsertInst != nullptr &&
7169
+ (isa<InsertValueInst>(LastInsertInst) ||
7170
+ isa<InsertElementInst>(LastInsertInst)) &&
7171
+ LastInsertInst->hasOneUse ());
7172
+ return false ;
7173
+ }
7174
+
7087
7175
// / Recognize construction of vectors like
7088
7176
// / %ra = insertelement <4 x float> undef, float %s0, i32 0
7089
7177
// / %rb = insertelement <4 x float> %ra, float %s1, i32 1
7090
7178
// / %rc = insertelement <4 x float> %rb, float %s2, i32 2
7091
7179
// / %rd = insertelement <4 x float> %rc, float %s3, i32 3
7092
7180
// / starting from the last insertelement or insertvalue instruction.
7093
7181
// /
7094
- // / Also recognize aggregates like {<2 x float>, <2 x float>},
7182
+ // / Also recognize homogeneous aggregates like {<2 x float>, <2 x float>},
7095
7183
// / {{float, float}, {float, float}}, [2 x {float, float}] and so on.
7096
7184
// / See llvm/test/Transforms/SLPVectorizer/X86/pr42022.ll for examples.
7097
7185
// /
7098
7186
// / Assume LastInsertInst is of InsertElementInst or InsertValueInst type.
7099
7187
// /
7100
7188
// / \return true if it matches.
7101
- static bool findBuildAggregate (Value *LastInsertInst, TargetTransformInfo *TTI,
7189
+ static bool findBuildAggregate (Instruction *LastInsertInst,
7190
+ TargetTransformInfo *TTI,
7102
7191
SmallVectorImpl<Value *> &BuildVectorOpds,
7103
7192
SmallVectorImpl<Value *> &InsertElts) {
7193
+
7104
7194
assert ((isa<InsertElementInst>(LastInsertInst) ||
7105
7195
isa<InsertValueInst>(LastInsertInst)) &&
7106
7196
" Expected insertelement or insertvalue instruction!" );
7107
- do {
7108
- Value *InsertedOperand;
7109
- auto *IE = dyn_cast<InsertElementInst>(LastInsertInst);
7110
- if (IE) {
7111
- InsertedOperand = IE->getOperand (1 );
7112
- LastInsertInst = IE->getOperand (0 );
7113
- } else {
7114
- auto *IV = cast<InsertValueInst>(LastInsertInst);
7115
- InsertedOperand = IV->getInsertedValueOperand ();
7116
- LastInsertInst = IV->getAggregateOperand ();
7117
- }
7118
- if (isa<InsertElementInst>(InsertedOperand) ||
7119
- isa<InsertValueInst>(InsertedOperand)) {
7120
- SmallVector<Value *, 8 > TmpBuildVectorOpds;
7121
- SmallVector<Value *, 8 > TmpInsertElts;
7122
- if (!findBuildAggregate (InsertedOperand, TTI, TmpBuildVectorOpds,
7123
- TmpInsertElts))
7124
- return false ;
7125
- BuildVectorOpds.append (TmpBuildVectorOpds.rbegin (),
7126
- TmpBuildVectorOpds.rend ());
7127
- InsertElts.append (TmpInsertElts.rbegin (), TmpInsertElts.rend ());
7128
- } else {
7129
- BuildVectorOpds.push_back (InsertedOperand);
7130
- InsertElts.push_back (IE);
7131
- }
7132
- if (isa<UndefValue>(LastInsertInst))
7133
- break ;
7134
- if ((!isa<InsertValueInst>(LastInsertInst) &&
7135
- !isa<InsertElementInst>(LastInsertInst)) ||
7136
- !LastInsertInst->hasOneUse ())
7137
- return false ;
7138
- } while (true );
7139
- std::reverse (BuildVectorOpds.begin (), BuildVectorOpds.end ());
7140
- std::reverse (InsertElts.begin (), InsertElts.end ());
7141
- return true ;
7197
+
7198
+ assert ((BuildVectorOpds.empty () && InsertElts.empty ()) &&
7199
+ " Expected empty result vectors!" );
7200
+
7201
+ Optional<unsigned > AggregateSize = getAggregateSize (LastInsertInst);
7202
+ if (!AggregateSize)
7203
+ return false ;
7204
+ BuildVectorOpds.resize (*AggregateSize);
7205
+ InsertElts.resize (*AggregateSize);
7206
+
7207
+ if (findBuildAggregate_rec (LastInsertInst, TTI, BuildVectorOpds, InsertElts,
7208
+ 0 )) {
7209
+ llvm::erase_if (BuildVectorOpds,
7210
+ [](const Value *V) { return V == nullptr ; });
7211
+ llvm::erase_if (InsertElts, [](const Value *V) { return V == nullptr ; });
7212
+ if (BuildVectorOpds.size () >= 2 )
7213
+ return true ;
7214
+ }
7215
+
7216
+ return false ;
7142
7217
}
7143
7218
7144
7219
static bool PhiTypeSorterFunc (Value *V, Value *V2) {
@@ -7308,8 +7383,7 @@ bool SLPVectorizerPass::vectorizeInsertValueInst(InsertValueInst *IVI,
7308
7383
7309
7384
SmallVector<Value *, 16 > BuildVectorOpds;
7310
7385
SmallVector<Value *, 16 > BuildVectorInsts;
7311
- if (!findBuildAggregate (IVI, TTI, BuildVectorOpds, BuildVectorInsts) ||
7312
- BuildVectorOpds.size () < 2 )
7386
+ if (!findBuildAggregate (IVI, TTI, BuildVectorOpds, BuildVectorInsts))
7313
7387
return false ;
7314
7388
7315
7389
LLVM_DEBUG (dbgs () << " SLP: array mappable to vector: " << *IVI << " \n " );
@@ -7324,7 +7398,6 @@ bool SLPVectorizerPass::vectorizeInsertElementInst(InsertElementInst *IEI,
7324
7398
SmallVector<Value *, 16 > BuildVectorInsts;
7325
7399
SmallVector<Value *, 16 > BuildVectorOpds;
7326
7400
if (!findBuildAggregate (IEI, TTI, BuildVectorOpds, BuildVectorInsts) ||
7327
- BuildVectorOpds.size () < 2 ||
7328
7401
(llvm::all_of (BuildVectorOpds,
7329
7402
[](Value *V) { return isa<ExtractElementInst>(V); }) &&
7330
7403
isShuffle (BuildVectorOpds)))
0 commit comments