Skip to content

Commit 4b81dc7

Browse files
authored
[IA] Use a single callback for lowerDeinterleaveIntrinsic [nfc] (#148978)
This essentially merges the handling for VPLoad — currently in lowerInterleavedVPLoad, which is shared between shuffle- and intrinsic-based interleaves — into the existing dedicated routine. My plan is that, if we like this factoring, I'll do the same for the intrinsic store paths, and then remove the excess generality from the shuffle paths, since we don't need to support both modes in the shared VPLoad/Store callbacks. We can probably even fold the VP versions into the non-VP shuffle variants in the analogous way.
1 parent 386f73d commit 4b81dc7

File tree

6 files changed

+88
-59
lines changed

6 files changed

+88
-59
lines changed

llvm/include/llvm/CodeGen/TargetLowering.h

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3249,10 +3249,11 @@ class LLVM_ABI TargetLoweringBase {
32493249
/// Return true on success. Currently only supports
32503250
/// llvm.vector.deinterleave{2,3,5,7}
32513251
///
3252-
/// \p LI is the accompanying load instruction.
3252+
/// \p Load is the accompanying load instruction. Can be either a plain load
3253+
/// instruction or a vp.load intrinsic.
32533254
/// \p DeinterleaveValues contains the deinterleaved values.
32543255
virtual bool
3255-
lowerDeinterleaveIntrinsicToLoad(LoadInst *LI,
3256+
lowerDeinterleaveIntrinsicToLoad(Instruction *Load, Value *Mask,
32563257
ArrayRef<Value *> DeinterleaveValues) const {
32573258
return false;
32583259
}

llvm/lib/CodeGen/InterleavedAccessPass.cpp

Lines changed: 7 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -634,37 +634,32 @@ bool InterleavedAccessImpl::lowerDeinterleaveIntrinsic(
634634
if (!LastFactor)
635635
return false;
636636

637+
Value *Mask = nullptr;
637638
if (auto *VPLoad = dyn_cast<VPIntrinsic>(LoadedVal)) {
638639
if (VPLoad->getIntrinsicID() != Intrinsic::vp_load)
639640
return false;
640641
// Check mask operand. Handle both all-true/false and interleaved mask.
641642
Value *WideMask = VPLoad->getOperand(1);
642-
Value *Mask =
643-
getMask(WideMask, Factor, cast<VectorType>(LastFactor->getType()));
643+
Mask = getMask(WideMask, Factor, cast<VectorType>(LastFactor->getType()));
644644
if (!Mask)
645645
return false;
646646

647647
LLVM_DEBUG(dbgs() << "IA: Found a vp.load with deinterleave intrinsic "
648648
<< *DI << " and factor = " << Factor << "\n");
649-
650-
// Since lowerInterleaveLoad expects Shuffles and LoadInst, use special
651-
// TLI function to emit target-specific interleaved instruction.
652-
if (!TLI->lowerInterleavedVPLoad(VPLoad, Mask, DeinterleaveValues))
653-
return false;
654-
655649
} else {
656650
auto *LI = cast<LoadInst>(LoadedVal);
657651
if (!LI->isSimple())
658652
return false;
659653

660654
LLVM_DEBUG(dbgs() << "IA: Found a load with deinterleave intrinsic " << *DI
661655
<< " and factor = " << Factor << "\n");
662-
663-
// Try and match this with target specific intrinsics.
664-
if (!TLI->lowerDeinterleaveIntrinsicToLoad(LI, DeinterleaveValues))
665-
return false;
666656
}
667657

658+
// Try and match this with target specific intrinsics.
659+
if (!TLI->lowerDeinterleaveIntrinsicToLoad(cast<Instruction>(LoadedVal), Mask,
660+
DeinterleaveValues))
661+
return false;
662+
668663
for (Value *V : DeinterleaveValues)
669664
if (V)
670665
DeadInsts.insert(cast<Instruction>(V));

llvm/lib/Target/AArch64/AArch64ISelLowering.cpp

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17476,12 +17476,17 @@ bool AArch64TargetLowering::lowerInterleavedStore(StoreInst *SI,
1747617476
}
1747717477

1747817478
bool AArch64TargetLowering::lowerDeinterleaveIntrinsicToLoad(
17479-
LoadInst *LI, ArrayRef<Value *> DeinterleavedValues) const {
17479+
Instruction *Load, Value *Mask,
17480+
ArrayRef<Value *> DeinterleavedValues) const {
1748017481
unsigned Factor = DeinterleavedValues.size();
1748117482
if (Factor != 2 && Factor != 4) {
1748217483
LLVM_DEBUG(dbgs() << "Matching ld2 and ld4 patterns failed\n");
1748317484
return false;
1748417485
}
17486+
auto *LI = dyn_cast<LoadInst>(Load);
17487+
if (!LI)
17488+
return false;
17489+
assert(!Mask && "Unexpected mask on a load\n");
1748517490

1748617491
Value *FirstActive = *llvm::find_if(DeinterleavedValues,
1748717492
[](Value *V) { return V != nullptr; });

llvm/lib/Target/AArch64/AArch64ISelLowering.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -219,7 +219,8 @@ class AArch64TargetLowering : public TargetLowering {
219219
unsigned Factor) const override;
220220

221221
bool lowerDeinterleaveIntrinsicToLoad(
222-
LoadInst *LI, ArrayRef<Value *> DeinterleaveValues) const override;
222+
Instruction *Load, Value *Mask,
223+
ArrayRef<Value *> DeinterleaveValues) const override;
223224

224225
bool lowerInterleaveIntrinsicToStore(
225226
StoreInst *SI, ArrayRef<Value *> InterleaveValues) const override;

llvm/lib/Target/RISCV/RISCVISelLowering.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -438,7 +438,8 @@ class RISCVTargetLowering : public TargetLowering {
438438
unsigned Factor) const override;
439439

440440
bool lowerDeinterleaveIntrinsicToLoad(
441-
LoadInst *LI, ArrayRef<Value *> DeinterleaveValues) const override;
441+
Instruction *Load, Value *Mask,
442+
ArrayRef<Value *> DeinterleaveValues) const override;
442443

443444
bool lowerInterleaveIntrinsicToStore(
444445
StoreInst *SI, ArrayRef<Value *> InterleaveValues) const override;

llvm/lib/Target/RISCV/RISCVInterleavedAccess.cpp

Lines changed: 68 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -234,53 +234,100 @@ bool RISCVTargetLowering::lowerInterleavedStore(StoreInst *SI,
234234
return true;
235235
}
236236

237+
static bool isMultipleOfN(const Value *V, const DataLayout &DL, unsigned N) {
238+
assert(N);
239+
if (N == 1)
240+
return true;
241+
242+
using namespace PatternMatch;
243+
// Right now we're only recognizing the simplest pattern.
244+
uint64_t C;
245+
if (match(V, m_CombineOr(m_ConstantInt(C),
246+
m_c_Mul(m_Value(), m_ConstantInt(C)))) &&
247+
C && C % N == 0)
248+
return true;
249+
250+
if (isPowerOf2_32(N)) {
251+
KnownBits KB = llvm::computeKnownBits(V, DL);
252+
return KB.countMinTrailingZeros() >= Log2_32(N);
253+
}
254+
255+
return false;
256+
}
257+
237258
bool RISCVTargetLowering::lowerDeinterleaveIntrinsicToLoad(
238-
LoadInst *LI, ArrayRef<Value *> DeinterleaveValues) const {
259+
Instruction *Load, Value *Mask,
260+
ArrayRef<Value *> DeinterleaveValues) const {
239261
const unsigned Factor = DeinterleaveValues.size();
240262
if (Factor > 8)
241263
return false;
242264

243-
assert(LI->isSimple());
244-
IRBuilder<> Builder(LI);
265+
IRBuilder<> Builder(Load);
245266

246267
Value *FirstActive =
247268
*llvm::find_if(DeinterleaveValues, [](Value *V) { return V != nullptr; });
248269
VectorType *ResVTy = cast<VectorType>(FirstActive->getType());
249270

250-
const DataLayout &DL = LI->getDataLayout();
271+
const DataLayout &DL = Load->getDataLayout();
272+
auto *XLenTy = Type::getIntNTy(Load->getContext(), Subtarget.getXLen());
251273

252-
if (!isLegalInterleavedAccessType(ResVTy, Factor, LI->getAlign(),
253-
LI->getPointerAddressSpace(), DL))
274+
Value *Ptr, *VL;
275+
Align Alignment;
276+
if (auto *LI = dyn_cast<LoadInst>(Load)) {
277+
assert(LI->isSimple());
278+
Ptr = LI->getPointerOperand();
279+
Alignment = LI->getAlign();
280+
assert(!Mask && "Unexpected mask on a load\n");
281+
Mask = Builder.getAllOnesMask(ResVTy->getElementCount());
282+
VL = isa<FixedVectorType>(ResVTy)
283+
? Builder.CreateElementCount(XLenTy, ResVTy->getElementCount())
284+
: Constant::getAllOnesValue(XLenTy);
285+
} else {
286+
auto *VPLoad = cast<VPIntrinsic>(Load);
287+
assert(VPLoad->getIntrinsicID() == Intrinsic::vp_load &&
288+
"Unexpected intrinsic");
289+
Ptr = VPLoad->getMemoryPointerParam();
290+
Alignment = VPLoad->getPointerAlignment().value_or(
291+
DL.getABITypeAlign(ResVTy->getElementType()));
292+
293+
assert(Mask && "vp.load needs a mask!");
294+
295+
Value *WideEVL = VPLoad->getVectorLengthParam();
296+
// Conservatively check if EVL is a multiple of factor, otherwise some
297+
// (trailing) elements might be lost after the transformation.
298+
if (!isMultipleOfN(WideEVL, Load->getDataLayout(), Factor))
299+
return false;
300+
301+
VL = Builder.CreateZExt(
302+
Builder.CreateUDiv(WideEVL,
303+
ConstantInt::get(WideEVL->getType(), Factor)),
304+
XLenTy);
305+
}
306+
307+
Type *PtrTy = Ptr->getType();
308+
unsigned AS = PtrTy->getPointerAddressSpace();
309+
if (!isLegalInterleavedAccessType(ResVTy, Factor, Alignment, AS, DL))
254310
return false;
255311

256312
Value *Return;
257-
Type *PtrTy = LI->getPointerOperandType();
258-
Type *XLenTy = Type::getIntNTy(LI->getContext(), Subtarget.getXLen());
259-
260313
if (isa<FixedVectorType>(ResVTy)) {
261-
Value *VL = Builder.CreateElementCount(XLenTy, ResVTy->getElementCount());
262-
Value *Mask = Builder.getAllOnesMask(ResVTy->getElementCount());
263314
Return = Builder.CreateIntrinsic(FixedVlsegIntrIds[Factor - 2],
264-
{ResVTy, PtrTy, XLenTy},
265-
{LI->getPointerOperand(), Mask, VL});
315+
{ResVTy, PtrTy, XLenTy}, {Ptr, Mask, VL});
266316
} else {
267317
unsigned SEW = DL.getTypeSizeInBits(ResVTy->getElementType());
268318
unsigned NumElts = ResVTy->getElementCount().getKnownMinValue();
269319
Type *VecTupTy = TargetExtType::get(
270-
LI->getContext(), "riscv.vector.tuple",
271-
ScalableVectorType::get(Type::getInt8Ty(LI->getContext()),
320+
Load->getContext(), "riscv.vector.tuple",
321+
ScalableVectorType::get(Type::getInt8Ty(Load->getContext()),
272322
NumElts * SEW / 8),
273323
Factor);
274-
Value *VL = Constant::getAllOnesValue(XLenTy);
275-
Value *Mask = Builder.getAllOnesMask(ResVTy->getElementCount());
276-
277324
Function *VlsegNFunc = Intrinsic::getOrInsertDeclaration(
278-
LI->getModule(), ScalableVlsegIntrIds[Factor - 2],
325+
Load->getModule(), ScalableVlsegIntrIds[Factor - 2],
279326
{VecTupTy, PtrTy, Mask->getType(), VL->getType()});
280327

281328
Value *Operands[] = {
282329
PoisonValue::get(VecTupTy),
283-
LI->getPointerOperand(),
330+
Ptr,
284331
Mask,
285332
VL,
286333
ConstantInt::get(XLenTy,
@@ -290,7 +337,7 @@ bool RISCVTargetLowering::lowerDeinterleaveIntrinsicToLoad(
290337
CallInst *Vlseg = Builder.CreateCall(VlsegNFunc, Operands);
291338

292339
SmallVector<Type *, 2> AggrTypes{Factor, ResVTy};
293-
Return = PoisonValue::get(StructType::get(LI->getContext(), AggrTypes));
340+
Return = PoisonValue::get(StructType::get(Load->getContext(), AggrTypes));
294341
for (unsigned i = 0; i < Factor; ++i) {
295342
Value *VecExtract = Builder.CreateIntrinsic(
296343
Intrinsic::riscv_tuple_extract, {ResVTy, VecTupTy},
@@ -370,27 +417,6 @@ bool RISCVTargetLowering::lowerInterleaveIntrinsicToStore(
370417
return true;
371418
}
372419

373-
static bool isMultipleOfN(const Value *V, const DataLayout &DL, unsigned N) {
374-
assert(N);
375-
if (N == 1)
376-
return true;
377-
378-
using namespace PatternMatch;
379-
// Right now we're only recognizing the simplest pattern.
380-
uint64_t C;
381-
if (match(V, m_CombineOr(m_ConstantInt(C),
382-
m_c_Mul(m_Value(), m_ConstantInt(C)))) &&
383-
C && C % N == 0)
384-
return true;
385-
386-
if (isPowerOf2_32(N)) {
387-
KnownBits KB = llvm::computeKnownBits(V, DL);
388-
return KB.countMinTrailingZeros() >= Log2_32(N);
389-
}
390-
391-
return false;
392-
}
393-
394420
/// Lower an interleaved vp.load into a vlsegN intrinsic.
395421
///
396422
/// E.g. Lower an interleaved vp.load (Factor = 2):

0 commit comments

Comments
 (0)