Skip to content

Commit 5b76cdb

Browse files
authored
[VPlan] Handle AnyOf when unrolling. (#145340)
AnyOf is currently not handled correctly during unrolling, which causes miscompiles when vectorizing early-exit loops with interleaving forced (even though selectInterleaveCount currently only picks IC = 1 unless interleaving is forced by the user). This patch updates the handling of AnyOf to be analogous to computing final reduction results: during unrolling, the copies created for its original operand are added as additional operands, so AnyOf always produces the reduced value across all unrolled iterations. Note that the generated code is still incorrect, because FirstActiveLane and ExtractElement with FirstActiveLane operands also need to be handled. I will share patches for those soon as well. PR: #145340
1 parent fe4b403 commit 5b76cdb

File tree

6 files changed

+238
-21
lines changed

6 files changed

+238
-21
lines changed

llvm/lib/Transforms/Vectorize/VPlan.h

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -959,8 +959,10 @@ class VPInstruction : public VPRecipeWithIRFlags,
959959
// operand). Only generates scalar values (either for the first lane only or
960960
// for all lanes, depending on its uses).
961961
PtrAdd,
962-
// Returns a scalar boolean value, which is true if any lane of its (only
963-
// boolean) vector operand is true.
962+
// Returns a scalar boolean value, which is true if any lane of its
963+
// (boolean) vector operand is true. It produces the reduced value across
964+
// all unrolled iterations. Unrolling will add all copies of its original
965+
// operand as additional operands.
964966
AnyOf,
965967
// Calculates the first active lane index of the vector predicate operand.
966968
FirstActiveLane,

llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -850,8 +850,10 @@ Value *VPInstruction::generate(VPTransformState &State) {
850850
return Builder.CreatePtrAdd(Ptr, Addend, Name, getGEPNoWrapFlags());
851851
}
852852
case VPInstruction::AnyOf: {
853-
Value *A = State.get(getOperand(0));
854-
return Builder.CreateOrReduce(A);
853+
Value *Res = State.get(getOperand(0));
854+
for (VPValue *Op : drop_begin(operands()))
855+
Res = Builder.CreateOr(Res, State.get(Op));
856+
return Builder.CreateOrReduce(Res);
855857
}
856858
case VPInstruction::FirstActiveLane: {
857859
Value *Mask = State.get(getOperand(0));

llvm/lib/Transforms/Vectorize/VPlanUnroll.cpp

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -345,10 +345,11 @@ void UnrollState::unrollBlock(VPBlockBase *VPB) {
345345
if (ToSkip.contains(&R) || isa<VPIRInstruction>(&R))
346346
continue;
347347

348-
// Add all VPValues for all parts to ComputeReductionResult which combines
349-
// the parts to compute the final reduction value.
348+
// Add all VPValues for all parts to AnyOf and Compute*Result which combine
349+
// all parts to compute the final value.
350350
VPValue *Op1;
351-
if (match(&R, m_VPInstruction<VPInstruction::ComputeAnyOfResult>(
351+
if (match(&R, m_VPInstruction<VPInstruction::AnyOf>(m_VPValue(Op1))) ||
352+
match(&R, m_VPInstruction<VPInstruction::ComputeAnyOfResult>(
352353
m_VPValue(), m_VPValue(), m_VPValue(Op1))) ||
353354
match(&R, m_VPInstruction<VPInstruction::ComputeReductionResult>(
354355
m_VPValue(), m_VPValue(Op1))) ||

llvm/test/Transforms/LoopVectorize/AArch64/single-early-exit-interleave.ll

Lines changed: 31 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,13 +31,43 @@ define i64 @same_exit_block_pre_inc_use1() #0 {
3131
; CHECK-NEXT: [[OFFSET_IDX:%.*]] = add i64 3, [[INDEX1]]
3232
; CHECK-NEXT: [[TMP7:%.*]] = getelementptr inbounds i8, ptr [[P1]], i64 [[OFFSET_IDX]]
3333
; CHECK-NEXT: [[TMP8:%.*]] = getelementptr inbounds i8, ptr [[TMP7]], i32 0
34+
; CHECK-NEXT: [[TMP18:%.*]] = call i64 @llvm.vscale.i64()
35+
; CHECK-NEXT: [[TMP19:%.*]] = mul nuw i64 [[TMP18]], 16
36+
; CHECK-NEXT: [[TMP29:%.*]] = getelementptr inbounds i8, ptr [[TMP7]], i64 [[TMP19]]
37+
; CHECK-NEXT: [[TMP36:%.*]] = call i64 @llvm.vscale.i64()
38+
; CHECK-NEXT: [[TMP37:%.*]] = mul nuw i64 [[TMP36]], 32
39+
; CHECK-NEXT: [[TMP38:%.*]] = getelementptr inbounds i8, ptr [[TMP7]], i64 [[TMP37]]
40+
; CHECK-NEXT: [[TMP39:%.*]] = call i64 @llvm.vscale.i64()
41+
; CHECK-NEXT: [[TMP40:%.*]] = mul nuw i64 [[TMP39]], 48
42+
; CHECK-NEXT: [[TMP41:%.*]] = getelementptr inbounds i8, ptr [[TMP7]], i64 [[TMP40]]
3443
; CHECK-NEXT: [[WIDE_LOAD:%.*]] = load <vscale x 16 x i8>, ptr [[TMP8]], align 1
44+
; CHECK-NEXT: [[WIDE_LOAD5:%.*]] = load <vscale x 16 x i8>, ptr [[TMP29]], align 1
45+
; CHECK-NEXT: [[WIDE_LOAD3:%.*]] = load <vscale x 16 x i8>, ptr [[TMP38]], align 1
46+
; CHECK-NEXT: [[WIDE_LOAD4:%.*]] = load <vscale x 16 x i8>, ptr [[TMP41]], align 1
3547
; CHECK-NEXT: [[TMP9:%.*]] = getelementptr inbounds i8, ptr [[P2]], i64 [[OFFSET_IDX]]
3648
; CHECK-NEXT: [[TMP10:%.*]] = getelementptr inbounds i8, ptr [[TMP9]], i32 0
49+
; CHECK-NEXT: [[TMP20:%.*]] = call i64 @llvm.vscale.i64()
50+
; CHECK-NEXT: [[TMP21:%.*]] = mul nuw i64 [[TMP20]], 16
51+
; CHECK-NEXT: [[TMP22:%.*]] = getelementptr inbounds i8, ptr [[TMP9]], i64 [[TMP21]]
52+
; CHECK-NEXT: [[TMP23:%.*]] = call i64 @llvm.vscale.i64()
53+
; CHECK-NEXT: [[TMP24:%.*]] = mul nuw i64 [[TMP23]], 32
54+
; CHECK-NEXT: [[TMP25:%.*]] = getelementptr inbounds i8, ptr [[TMP9]], i64 [[TMP24]]
55+
; CHECK-NEXT: [[TMP26:%.*]] = call i64 @llvm.vscale.i64()
56+
; CHECK-NEXT: [[TMP27:%.*]] = mul nuw i64 [[TMP26]], 48
57+
; CHECK-NEXT: [[TMP28:%.*]] = getelementptr inbounds i8, ptr [[TMP9]], i64 [[TMP27]]
3758
; CHECK-NEXT: [[WIDE_LOAD2:%.*]] = load <vscale x 16 x i8>, ptr [[TMP10]], align 1
59+
; CHECK-NEXT: [[WIDE_LOAD6:%.*]] = load <vscale x 16 x i8>, ptr [[TMP22]], align 1
60+
; CHECK-NEXT: [[WIDE_LOAD7:%.*]] = load <vscale x 16 x i8>, ptr [[TMP25]], align 1
61+
; CHECK-NEXT: [[WIDE_LOAD8:%.*]] = load <vscale x 16 x i8>, ptr [[TMP28]], align 1
3862
; CHECK-NEXT: [[TMP11:%.*]] = icmp ne <vscale x 16 x i8> [[WIDE_LOAD]], [[WIDE_LOAD2]]
63+
; CHECK-NEXT: [[TMP30:%.*]] = icmp ne <vscale x 16 x i8> [[WIDE_LOAD5]], [[WIDE_LOAD6]]
64+
; CHECK-NEXT: [[TMP31:%.*]] = icmp ne <vscale x 16 x i8> [[WIDE_LOAD3]], [[WIDE_LOAD7]]
65+
; CHECK-NEXT: [[TMP32:%.*]] = icmp ne <vscale x 16 x i8> [[WIDE_LOAD4]], [[WIDE_LOAD8]]
3966
; CHECK-NEXT: [[INDEX_NEXT3]] = add nuw i64 [[INDEX1]], [[TMP5]]
40-
; CHECK-NEXT: [[TMP12:%.*]] = call i1 @llvm.vector.reduce.or.nxv16i1(<vscale x 16 x i1> [[TMP11]])
67+
; CHECK-NEXT: [[TMP33:%.*]] = or <vscale x 16 x i1> [[TMP11]], [[TMP30]]
68+
; CHECK-NEXT: [[TMP34:%.*]] = or <vscale x 16 x i1> [[TMP33]], [[TMP31]]
69+
; CHECK-NEXT: [[TMP35:%.*]] = or <vscale x 16 x i1> [[TMP34]], [[TMP32]]
70+
; CHECK-NEXT: [[TMP12:%.*]] = call i1 @llvm.vector.reduce.or.nxv16i1(<vscale x 16 x i1> [[TMP35]])
4171
; CHECK-NEXT: [[TMP13:%.*]] = icmp eq i64 [[INDEX_NEXT3]], [[N_VEC]]
4272
; CHECK-NEXT: [[TMP14:%.*]] = or i1 [[TMP12]], [[TMP13]]
4373
; CHECK-NEXT: br i1 [[TMP14]], label [[MIDDLE_SPLIT:%.*]], label [[LOOP]], !llvm.loop [[LOOP0:![0-9]+]]

0 commit comments

Comments
 (0)