@@ -234,53 +234,100 @@ bool RISCVTargetLowering::lowerInterleavedStore(StoreInst *SI,
234
234
return true ;
235
235
}
236
236
237
+ static bool isMultipleOfN (const Value *V, const DataLayout &DL, unsigned N) {
238
+ assert (N);
239
+ if (N == 1 )
240
+ return true ;
241
+
242
+ using namespace PatternMatch ;
243
+ // Right now we're only recognizing the simplest pattern.
244
+ uint64_t C;
245
+ if (match (V, m_CombineOr (m_ConstantInt (C),
246
+ m_c_Mul (m_Value (), m_ConstantInt (C)))) &&
247
+ C && C % N == 0 )
248
+ return true ;
249
+
250
+ if (isPowerOf2_32 (N)) {
251
+ KnownBits KB = llvm::computeKnownBits (V, DL);
252
+ return KB.countMinTrailingZeros () >= Log2_32 (N);
253
+ }
254
+
255
+ return false ;
256
+ }
257
+
237
258
bool RISCVTargetLowering::lowerDeinterleaveIntrinsicToLoad (
238
- LoadInst *LI, ArrayRef<Value *> DeinterleaveValues) const {
259
+ Instruction *Load, Value *Mask,
260
+ ArrayRef<Value *> DeinterleaveValues) const {
239
261
const unsigned Factor = DeinterleaveValues.size ();
240
262
if (Factor > 8 )
241
263
return false ;
242
264
243
- assert (LI->isSimple ());
244
- IRBuilder<> Builder (LI);
265
+ IRBuilder<> Builder (Load);
245
266
246
267
Value *FirstActive =
247
268
*llvm::find_if (DeinterleaveValues, [](Value *V) { return V != nullptr ; });
248
269
VectorType *ResVTy = cast<VectorType>(FirstActive->getType ());
249
270
250
- const DataLayout &DL = LI->getDataLayout ();
271
+ const DataLayout &DL = Load->getDataLayout ();
272
+ auto *XLenTy = Type::getIntNTy (Load->getContext (), Subtarget.getXLen ());
251
273
252
- if (!isLegalInterleavedAccessType (ResVTy, Factor, LI->getAlign (),
253
- LI->getPointerAddressSpace (), DL))
274
+ Value *Ptr, *VL;
275
+ Align Alignment;
276
+ if (auto *LI = dyn_cast<LoadInst>(Load)) {
277
+ assert (LI->isSimple ());
278
+ Ptr = LI->getPointerOperand ();
279
+ Alignment = LI->getAlign ();
280
+ assert (!Mask && " Unexpected mask on a load\n " );
281
+ Mask = Builder.getAllOnesMask (ResVTy->getElementCount ());
282
+ VL = isa<FixedVectorType>(ResVTy)
283
+ ? Builder.CreateElementCount (XLenTy, ResVTy->getElementCount ())
284
+ : Constant::getAllOnesValue (XLenTy);
285
+ } else {
286
+ auto *VPLoad = cast<VPIntrinsic>(Load);
287
+ assert (VPLoad->getIntrinsicID () == Intrinsic::vp_load &&
288
+ " Unexpected intrinsic" );
289
+ Ptr = VPLoad->getMemoryPointerParam ();
290
+ Alignment = VPLoad->getPointerAlignment ().value_or (
291
+ DL.getABITypeAlign (ResVTy->getElementType ()));
292
+
293
+ assert (Mask && " vp.load needs a mask!" );
294
+
295
+ Value *WideEVL = VPLoad->getVectorLengthParam ();
296
+ // Conservatively check if EVL is a multiple of factor, otherwise some
297
+ // (trailing) elements might be lost after the transformation.
298
+ if (!isMultipleOfN (WideEVL, Load->getDataLayout (), Factor))
299
+ return false ;
300
+
301
+ VL = Builder.CreateZExt (
302
+ Builder.CreateUDiv (WideEVL,
303
+ ConstantInt::get (WideEVL->getType (), Factor)),
304
+ XLenTy);
305
+ }
306
+
307
+ Type *PtrTy = Ptr->getType ();
308
+ unsigned AS = PtrTy->getPointerAddressSpace ();
309
+ if (!isLegalInterleavedAccessType (ResVTy, Factor, Alignment, AS, DL))
254
310
return false ;
255
311
256
312
Value *Return;
257
- Type *PtrTy = LI->getPointerOperandType ();
258
- Type *XLenTy = Type::getIntNTy (LI->getContext (), Subtarget.getXLen ());
259
-
260
313
if (isa<FixedVectorType>(ResVTy)) {
261
- Value *VL = Builder.CreateElementCount (XLenTy, ResVTy->getElementCount ());
262
- Value *Mask = Builder.getAllOnesMask (ResVTy->getElementCount ());
263
314
Return = Builder.CreateIntrinsic (FixedVlsegIntrIds[Factor - 2 ],
264
- {ResVTy, PtrTy, XLenTy},
265
- {LI->getPointerOperand (), Mask, VL});
315
+ {ResVTy, PtrTy, XLenTy}, {Ptr, Mask, VL});
266
316
} else {
267
317
unsigned SEW = DL.getTypeSizeInBits (ResVTy->getElementType ());
268
318
unsigned NumElts = ResVTy->getElementCount ().getKnownMinValue ();
269
319
Type *VecTupTy = TargetExtType::get (
270
- LI ->getContext (), " riscv.vector.tuple" ,
271
- ScalableVectorType::get (Type::getInt8Ty (LI ->getContext ()),
320
+ Load ->getContext (), " riscv.vector.tuple" ,
321
+ ScalableVectorType::get (Type::getInt8Ty (Load ->getContext ()),
272
322
NumElts * SEW / 8 ),
273
323
Factor);
274
- Value *VL = Constant::getAllOnesValue (XLenTy);
275
- Value *Mask = Builder.getAllOnesMask (ResVTy->getElementCount ());
276
-
277
324
Function *VlsegNFunc = Intrinsic::getOrInsertDeclaration (
278
- LI ->getModule (), ScalableVlsegIntrIds[Factor - 2 ],
325
+ Load ->getModule (), ScalableVlsegIntrIds[Factor - 2 ],
279
326
{VecTupTy, PtrTy, Mask->getType (), VL->getType ()});
280
327
281
328
Value *Operands[] = {
282
329
PoisonValue::get (VecTupTy),
283
- LI-> getPointerOperand () ,
330
+ Ptr ,
284
331
Mask,
285
332
VL,
286
333
ConstantInt::get (XLenTy,
@@ -290,7 +337,7 @@ bool RISCVTargetLowering::lowerDeinterleaveIntrinsicToLoad(
290
337
CallInst *Vlseg = Builder.CreateCall (VlsegNFunc, Operands);
291
338
292
339
SmallVector<Type *, 2 > AggrTypes{Factor, ResVTy};
293
- Return = PoisonValue::get (StructType::get (LI ->getContext (), AggrTypes));
340
+ Return = PoisonValue::get (StructType::get (Load ->getContext (), AggrTypes));
294
341
for (unsigned i = 0 ; i < Factor; ++i) {
295
342
Value *VecExtract = Builder.CreateIntrinsic (
296
343
Intrinsic::riscv_tuple_extract, {ResVTy, VecTupTy},
@@ -370,27 +417,6 @@ bool RISCVTargetLowering::lowerInterleaveIntrinsicToStore(
370
417
return true ;
371
418
}
372
419
373
- static bool isMultipleOfN (const Value *V, const DataLayout &DL, unsigned N) {
374
- assert (N);
375
- if (N == 1 )
376
- return true ;
377
-
378
- using namespace PatternMatch ;
379
- // Right now we're only recognizing the simplest pattern.
380
- uint64_t C;
381
- if (match (V, m_CombineOr (m_ConstantInt (C),
382
- m_c_Mul (m_Value (), m_ConstantInt (C)))) &&
383
- C && C % N == 0 )
384
- return true ;
385
-
386
- if (isPowerOf2_32 (N)) {
387
- KnownBits KB = llvm::computeKnownBits (V, DL);
388
- return KB.countMinTrailingZeros () >= Log2_32 (N);
389
- }
390
-
391
- return false ;
392
- }
393
-
394
420
// / Lower an interleaved vp.load into a vlsegN intrinsic.
395
421
// /
396
422
// / E.g. Lower an interleaved vp.load (Factor = 2):
0 commit comments