@@ -150,6 +150,11 @@ static cl::opt<bool> UseLIRCodeSizeHeurs(
150
150
" with -Os/-Oz" ),
151
151
cl::init(true ), cl::Hidden);
152
152
153
+ static cl::opt<bool > ForceMemsetPatternIntrinsic (
154
+ " loop-idiom-force-memset-pattern-intrinsic" ,
155
+ cl::desc (" Use memset.pattern intrinsic whenever possible" ), cl::init(false ),
156
+ cl::Hidden);
157
+
153
158
namespace {
154
159
155
160
class LoopIdiomRecognize {
@@ -323,10 +328,15 @@ bool LoopIdiomRecognize::runOnLoop(Loop *L) {
323
328
L->getHeader ()->getParent ()->hasOptSize () && UseLIRCodeSizeHeurs;
324
329
325
330
HasMemset = TLI->has (LibFunc_memset);
331
+ // TODO: Unconditionally enable use of the memset pattern intrinsic (or at
332
+ // least, opt-in via target hook) once we are confident it will never result
333
+ // in worse codegen than without. For now, use it only when the target
334
+ // supports memset_pattern16 libcall (or unless this is overridden by
335
+ // command line option).
326
336
HasMemsetPattern = TLI->has (LibFunc_memset_pattern16);
327
337
HasMemcpy = TLI->has (LibFunc_memcpy);
328
338
329
- if (HasMemset || HasMemsetPattern || HasMemcpy)
339
+ if (HasMemset || HasMemsetPattern || ForceMemsetPatternIntrinsic || HasMemcpy)
330
340
if (SE->hasLoopInvariantBackedgeTakenCount (L))
331
341
return runOnCountableLoop ();
332
342
@@ -378,11 +388,13 @@ static APInt getStoreStride(const SCEVAddRecExpr *StoreEv) {
378
388
}
379
389
380
390
// / getMemSetPatternValue - If a strided store of the specified value is safe to
381
- // / turn into a memset_pattern16 , return a ConstantArray of 16 bytes that should
382
- // / be passed in. Otherwise, return null.
391
+ // / turn into a memset.patternn intrinsic , return the Constant that should
392
+ // / be passed in. Otherwise, return null.
383
393
// /
384
- // / Note that we don't ever attempt to use memset_pattern8 or 4, because these
385
- // / just replicate their input array and then pass on to memset_pattern16.
394
+ // / TODO this function could allow more constants than it does today (e.g.
395
+ // / those over 16 bytes) now it has transitioned to being used for the
396
+ // / memset.pattern intrinsic rather than directly the memset_pattern16
397
+ // / libcall.
386
398
static Constant *getMemSetPatternValue (Value *V, const DataLayout *DL) {
387
399
// FIXME: This could check for UndefValue because it can be merged into any
388
400
// other valid pattern.
@@ -411,14 +423,12 @@ static Constant *getMemSetPatternValue(Value *V, const DataLayout *DL) {
411
423
if (Size > 16 )
412
424
return nullptr ;
413
425
414
- // If the constant is exactly 16 bytes, just use it.
415
- if (Size == 16 )
416
- return C;
426
+ // For now, don't handle types that aren't int, floats, or pointers.
427
+ Type *CTy = C->getType ();
428
+ if (!CTy->isIntOrPtrTy () && !CTy->isFloatingPointTy ())
429
+ return nullptr ;
417
430
418
- // Otherwise, we'll use an array of the constants.
419
- unsigned ArraySize = 16 / Size;
420
- ArrayType *AT = ArrayType::get (V->getType (), ArraySize);
421
- return ConstantArray::get (AT, std::vector<Constant *>(ArraySize, C));
431
+ return C;
422
432
}
423
433
424
434
LoopIdiomRecognize::LegalStoreKind
@@ -479,7 +489,8 @@ LoopIdiomRecognize::isLegalStore(StoreInst *SI) {
479
489
// It looks like we can use SplatValue.
480
490
return LegalStoreKind::Memset;
481
491
}
482
- if (!UnorderedAtomic && HasMemsetPattern && !DisableLIRP::Memset &&
492
+ if (!UnorderedAtomic && (HasMemsetPattern || ForceMemsetPatternIntrinsic) &&
493
+ !DisableLIRP::Memset &&
483
494
// Don't create memset_pattern16s with address spaces.
484
495
StorePtr->getType ()->getPointerAddressSpace () == 0 &&
485
496
getMemSetPatternValue (StoredVal, DL)) {
@@ -1061,50 +1072,81 @@ bool LoopIdiomRecognize::processLoopStridedStore(
1061
1072
return Changed;
1062
1073
1063
1074
// Okay, everything looks good, insert the memset.
1075
+ Value *SplatValue = isBytewiseValue (StoredVal, *DL);
1076
+ Constant *PatternValue = nullptr ;
1077
+ if (!SplatValue)
1078
+ PatternValue = getMemSetPatternValue (StoredVal, DL);
1079
+
1080
+ // MemsetArg is the number of bytes for the memset libcall, and the number
1081
+ // of pattern repetitions if the memset.pattern intrinsic is being used.
1082
+ Value *MemsetArg;
1083
+ std::optional<int64_t > BytesWritten;
1084
+
1085
+ if (PatternValue && (HasMemsetPattern || ForceMemsetPatternIntrinsic)) {
1086
+ const SCEV *TripCountS =
1087
+ SE->getTripCountFromExitCount (BECount, IntIdxTy, CurLoop);
1088
+ if (!Expander.isSafeToExpand (TripCountS))
1089
+ return Changed;
1090
+ const SCEVConstant *ConstStoreSize = dyn_cast<SCEVConstant>(StoreSizeSCEV);
1091
+ if (!ConstStoreSize)
1092
+ return Changed;
1093
+ Value *TripCount = Expander.expandCodeFor (TripCountS, IntIdxTy,
1094
+ Preheader->getTerminator ());
1095
+ uint64_t PatternRepsPerTrip =
1096
+ (ConstStoreSize->getValue ()->getZExtValue () * 8 ) /
1097
+ DL->getTypeSizeInBits (PatternValue->getType ());
1098
+ // If ConstStoreSize is not equal to the width of PatternValue, then
1099
+ // MemsetArg is TripCount * (ConstStoreSize/PatternValueWidth). Else
1100
+ // MemSetArg is just TripCount.
1101
+ MemsetArg =
1102
+ PatternRepsPerTrip == 1
1103
+ ? TripCount
1104
+ : Builder.CreateMul (TripCount,
1105
+ Builder.getIntN (IntIdxTy->getIntegerBitWidth (),
1106
+ PatternRepsPerTrip));
1107
+ if (auto *CI = dyn_cast<ConstantInt>(TripCount))
1108
+ BytesWritten =
1109
+ CI->getZExtValue () * ConstStoreSize->getValue ()->getZExtValue ();
1064
1110
1065
- const SCEV *NumBytesS =
1066
- getNumBytes (BECount, IntIdxTy, StoreSizeSCEV, CurLoop, DL, SE);
1067
-
1068
- // TODO: ideally we should still be able to generate memset if SCEV expander
1069
- // is taught to generate the dependencies at the latest point.
1070
- if (!Expander.isSafeToExpand (NumBytesS))
1071
- return Changed;
1111
+ } else {
1112
+ const SCEV *NumBytesS =
1113
+ getNumBytes (BECount, IntIdxTy, StoreSizeSCEV, CurLoop, DL, SE);
1072
1114
1073
- Value *NumBytes =
1074
- Expander.expandCodeFor (NumBytesS, IntIdxTy, Preheader->getTerminator ());
1115
+ // TODO: ideally we should still be able to generate memset if SCEV expander
1116
+ // is taught to generate the dependencies at the latest point.
1117
+ if (!Expander.isSafeToExpand (NumBytesS))
1118
+ return Changed;
1119
+ MemsetArg =
1120
+ Expander.expandCodeFor (NumBytesS, IntIdxTy, Preheader->getTerminator ());
1121
+ if (auto *CI = dyn_cast<ConstantInt>(MemsetArg))
1122
+ BytesWritten = CI->getZExtValue ();
1123
+ }
1124
+ assert (MemsetArg && " MemsetArg should have been set" );
1075
1125
1076
1126
AAMDNodes AATags = TheStore->getAAMetadata ();
1077
1127
for (Instruction *Store : Stores)
1078
1128
AATags = AATags.merge (Store->getAAMetadata ());
1079
- if (auto CI = dyn_cast<ConstantInt>(NumBytes) )
1080
- AATags = AATags.extendTo (CI-> getZExtValue ());
1129
+ if (BytesWritten )
1130
+ AATags = AATags.extendTo (BytesWritten. value ());
1081
1131
else
1082
1132
AATags = AATags.extendTo (-1 );
1083
1133
1084
1134
CallInst *NewCall;
1085
- if (Value * SplatValue = isBytewiseValue (StoredVal, *DL) ) {
1086
- NewCall = Builder.CreateMemSet (BasePtr, SplatValue, NumBytes ,
1135
+ if (SplatValue) {
1136
+ NewCall = Builder.CreateMemSet (BasePtr, SplatValue, MemsetArg ,
1087
1137
MaybeAlign (StoreAlignment),
1088
1138
/* isVolatile=*/ false , AATags);
1089
- } else if (isLibFuncEmittable (M, TLI, LibFunc_memset_pattern16)) {
1090
- // Everything is emitted in default address space
1091
- Type *Int8PtrTy = DestInt8PtrTy;
1092
-
1093
- StringRef FuncName = " memset_pattern16" ;
1094
- FunctionCallee MSP = getOrInsertLibFunc (M, *TLI, LibFunc_memset_pattern16,
1095
- Builder.getVoidTy (), Int8PtrTy, Int8PtrTy, IntIdxTy);
1096
- inferNonMandatoryLibFuncAttrs (M, FuncName, *TLI);
1097
-
1098
- // Otherwise we should form a memset_pattern16. PatternValue is known to be
1099
- // an constant array of 16-bytes. Plop the value into a mergable global.
1100
- Constant *PatternValue = getMemSetPatternValue (StoredVal, DL);
1101
- assert (PatternValue && " Expected pattern value." );
1102
- GlobalVariable *GV = new GlobalVariable (*M, PatternValue->getType (), true ,
1103
- GlobalValue::PrivateLinkage,
1104
- PatternValue, " .memset_pattern" );
1105
- GV->setUnnamedAddr (GlobalValue::UnnamedAddr::Global); // Ok to merge these.
1106
- GV->setAlignment (Align (16 ));
1107
- NewCall = Builder.CreateCall (MSP, {BasePtr, GV, NumBytes});
1139
+ } else if (ForceMemsetPatternIntrinsic ||
1140
+ isLibFuncEmittable (M, TLI, LibFunc_memset_pattern16)) {
1141
+ assert (isa<SCEVConstant>(StoreSizeSCEV) && " Expected constant store size" );
1142
+
1143
+ NewCall = Builder.CreateIntrinsic (
1144
+ Intrinsic::experimental_memset_pattern,
1145
+ {DestInt8PtrTy, PatternValue->getType (), IntIdxTy},
1146
+ {BasePtr, PatternValue, MemsetArg,
1147
+ ConstantInt::getFalse (M->getContext ())});
1148
+ if (StoreAlignment)
1149
+ cast<MemSetPatternInst>(NewCall)->setDestAlignment (*StoreAlignment);
1108
1150
NewCall->setAAMetadata (AATags);
1109
1151
} else {
1110
1152
// Neither a memset, nor memset_pattern16
0 commit comments