Skip to content

Commit 26b7d9d

Browse files
committed
[LoopVectorize] Permit vectorisation of more select(cmp(), X, Y) reduction patterns
This patch adds further support for vectorisation of loops that involve selecting an integer value based on a previous comparison. Consider the following C++ loop:

  int r = a;
  for (int i = 0; i < n; i++) {
    if (src[i] > 3) {
      r = b;
    }
    src[i] += 2;
  }

We should be able to vectorise this loop because all we are doing is selecting between two states - 'a' and 'b' - both of which are loop invariant. This just involves building a vector of values that contain either 'a' or 'b', where the final reduced value will be 'b' if any lane contains 'b'.

The IR generated by clang typically looks like this:

  %phi = phi i32 [ %a, %entry ], [ %phi.update, %for.body ]
  ...
  %pred = icmp ugt i32 %val, 3
  %phi.update = select i1 %pred, i32 %b, i32 %phi

We already detect min/max patterns, which also involve a select + cmp. However, with the min/max patterns we are selecting loaded values (and hence loop variant) in the loop. In addition we only support certain cmp predicates. This patch adds a new pattern matching function (isSelectCmpPattern) and new RecurKind enums - SelectICmp & SelectFCmp. We only support selecting values that are integer and loop invariant, however we can support any kind of compare - integer or float.

Tests have been added here:

  Transforms/LoopVectorize/AArch64/sve-select-cmp.ll
  Transforms/LoopVectorize/select-cmp-predicated.ll
  Transforms/LoopVectorize/select-cmp.ll

Differential Revision: https://reviews.llvm.org/D108136
1 parent cd1bd95 commit 26b7d9d

File tree

11 files changed

+942
-48
lines changed

11 files changed

+942
-48
lines changed

llvm/include/llvm/Analysis/IVDescriptors.h

+42-22
Original file line numberDiff line numberDiff line change
@@ -36,20 +36,24 @@ class DominatorTree;
3636

3737
/// These are the kinds of recurrences that we support.
3838
enum class RecurKind {
39-
None, ///< Not a recurrence.
40-
Add, ///< Sum of integers.
41-
Mul, ///< Product of integers.
42-
Or, ///< Bitwise or logical OR of integers.
43-
And, ///< Bitwise or logical AND of integers.
44-
Xor, ///< Bitwise or logical XOR of integers.
45-
SMin, ///< Signed integer min implemented in terms of select(cmp()).
46-
SMax, ///< Signed integer max implemented in terms of select(cmp()).
47-
UMin, ///< Unsigned integer min implemented in terms of select(cmp()).
48-
UMax, ///< Unsigned integer max implemented in terms of select(cmp()).
49-
FAdd, ///< Sum of floats.
50-
FMul, ///< Product of floats.
51-
FMin, ///< FP min implemented in terms of select(cmp()).
52-
FMax ///< FP max implemented in terms of select(cmp()).
39+
None, ///< Not a recurrence.
40+
Add, ///< Sum of integers.
41+
Mul, ///< Product of integers.
42+
Or, ///< Bitwise or logical OR of integers.
43+
And, ///< Bitwise or logical AND of integers.
44+
Xor, ///< Bitwise or logical XOR of integers.
45+
SMin, ///< Signed integer min implemented in terms of select(cmp()).
46+
SMax, ///< Signed integer max implemented in terms of select(cmp()).
47+
UMin, ///< Unsigned integer min implemented in terms of select(cmp()).
48+
UMax, ///< Unsigned integer max implemented in terms of select(cmp()).
49+
FAdd, ///< Sum of floats.
50+
FMul, ///< Product of floats.
51+
FMin, ///< FP min implemented in terms of select(cmp()).
52+
FMax, ///< FP max implemented in terms of select(cmp()).
53+
SelectICmp, ///< Integer select(icmp(),x,y) where one of (x,y) is loop
54+
///< invariant
55+
SelectFCmp ///< Integer select(fcmp(),x,y) where one of (x,y) is loop
56+
///< invariant
5357
};
5458

5559
/// The RecurrenceDescriptor is used to identify recurrence variables in a
@@ -112,12 +116,14 @@ class RecurrenceDescriptor {
112116
};
113117

114118
/// Returns a struct describing if the instruction 'I' can be a recurrence
115-
/// variable of type 'Kind'. If the recurrence is a min/max pattern of
116-
/// select(icmp()) this function advances the instruction pointer 'I' from the
117-
/// compare instruction to the select instruction and stores this pointer in
118-
/// 'PatternLastInst' member of the returned struct.
119-
static InstDesc isRecurrenceInstr(Instruction *I, RecurKind Kind,
120-
InstDesc &Prev, FastMathFlags FuncFMF);
119+
/// variable of type 'Kind' for a Loop \p L and reduction PHI \p Phi.
120+
/// If the recurrence is a min/max pattern of select(icmp()) this function
121+
/// advances the instruction pointer 'I' from the compare instruction to the
122+
/// select instruction and stores this pointer in 'PatternLastInst' member of
123+
/// the returned struct.
124+
static InstDesc isRecurrenceInstr(Loop *L, PHINode *Phi, Instruction *I,
125+
RecurKind Kind, InstDesc &Prev,
126+
FastMathFlags FuncFMF);
121127

122128
/// Returns true if instruction I has multiple uses in Insts
123129
static bool hasMultipleUsesOf(Instruction *I,
@@ -135,13 +141,21 @@ class RecurrenceDescriptor {
135141
static InstDesc isMinMaxPattern(Instruction *I, RecurKind Kind,
136142
const InstDesc &Prev);
137143

144+
/// Returns a struct describing whether the instruction is either a
145+
/// Select(ICmp(A, B), X, Y), or
146+
/// Select(FCmp(A, B), X, Y)
147+
/// where one of (X, Y) is a loop invariant integer and the other is a PHI
148+
/// value. \p Prev specifies the description of an already processed select
149+
/// instruction, so its corresponding cmp can be matched to it.
150+
static InstDesc isSelectCmpPattern(Loop *Loop, PHINode *OrigPhi,
151+
Instruction *I, InstDesc &Prev);
152+
138153
/// Returns a struct describing if the instruction is a
139154
/// Select(FCmp(X, Y), (Z = X op PHINode), PHINode) instruction pattern.
140155
static InstDesc isConditionalRdxPattern(RecurKind Kind, Instruction *I);
141156

142157
/// Returns identity corresponding to the RecurrenceKind.
143-
static Constant *getRecurrenceIdentity(RecurKind K, Type *Tp,
144-
FastMathFlags FMF);
158+
Value *getRecurrenceIdentity(RecurKind K, Type *Tp, FastMathFlags FMF);
145159

146160
/// Returns the opcode corresponding to the RecurrenceKind.
147161
static unsigned getOpcode(RecurKind Kind);
@@ -221,6 +235,12 @@ class RecurrenceDescriptor {
221235
return isIntMinMaxRecurrenceKind(Kind) || isFPMinMaxRecurrenceKind(Kind);
222236
}
223237

238+
/// Returns true if the recurrence kind is of the form
239+
/// select(cmp(),x,y) where one of (x,y) is loop invariant.
240+
static bool isSelectCmpRecurrenceKind(RecurKind Kind) {
241+
return Kind == RecurKind::SelectICmp || Kind == RecurKind::SelectFCmp;
242+
}
243+
224244
/// Returns the type of the recurrence. This type can be narrower than the
225245
/// actual type of the Phi if the recurrence has been type-promoted.
226246
Type *getRecurrenceType() const { return RecurrenceType; }

llvm/include/llvm/Transforms/Utils/LoopUtils.h

+20-1
Original file line numberDiff line numberDiff line change
@@ -351,6 +351,15 @@ bool canSinkOrHoistInst(Instruction &I, AAResults *AA, DominatorTree *DT,
351351
/// Returns the comparison predicate used when expanding a min/max reduction.
352352
CmpInst::Predicate getMinMaxReductionPredicate(RecurKind RK);
353353

354+
/// See RecurrenceDescriptor::isSelectCmpPattern for a description of the
355+
/// pattern we are trying to match. In this pattern we are only ever selecting
356+
/// between two values: 1) an initial PHI start value, and 2) a loop invariant
357+
/// value. This function uses \p LoopExitInst to determine 2), which we then use
358+
/// to select between \p Left and \p Right. Any lane value in \p Left that
359+
/// matches 2) will be merged into \p Right.
360+
Value *createSelectCmpOp(IRBuilderBase &Builder, Value *StartVal, RecurKind RK,
361+
Value *Left, Value *Right);
362+
354363
/// Returns a Min/Max operation corresponding to MinMaxRecurrenceKind.
355364
/// The Builder's fast-math-flags must be set to propagate the expected values.
356365
Value *createMinMaxOp(IRBuilderBase &Builder, RecurKind RK, Value *Left,
@@ -378,12 +387,22 @@ Value *createSimpleTargetReduction(IRBuilderBase &B,
378387
RecurKind RdxKind,
379388
ArrayRef<Value *> RedOps = None);
380389

390+
/// Create a target reduction of the given vector \p Src for a reduction of the
391+
/// kind RecurKind::SelectICmp or RecurKind::SelectFCmp. The reduction operation
392+
/// is described by \p Desc.
393+
Value *createSelectCmpTargetReduction(IRBuilderBase &B,
394+
const TargetTransformInfo *TTI,
395+
Value *Src,
396+
const RecurrenceDescriptor &Desc,
397+
PHINode *OrigPhi);
398+
381399
/// Create a generic target reduction using a recurrence descriptor \p Desc
382400
/// The target is queried to determine if intrinsics or shuffle sequences are
383401
/// required to implement the reduction.
384402
/// Fast-math-flags are propagated using the RecurrenceDescriptor.
385403
Value *createTargetReduction(IRBuilderBase &B, const TargetTransformInfo *TTI,
386-
const RecurrenceDescriptor &Desc, Value *Src);
404+
const RecurrenceDescriptor &Desc, Value *Src,
405+
PHINode *OrigPhi = nullptr);
387406

388407
/// Create an ordered reduction intrinsic using the given recurrence
389408
/// descriptor \p Desc.

llvm/lib/Analysis/IVDescriptors.cpp

+94-9
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,8 @@ bool RecurrenceDescriptor::isIntegerRecurrenceKind(RecurKind Kind) {
6262
case RecurKind::SMin:
6363
case RecurKind::UMax:
6464
case RecurKind::UMin:
65+
case RecurKind::SelectICmp:
66+
case RecurKind::SelectFCmp:
6567
return true;
6668
}
6769
return false;
@@ -327,7 +329,8 @@ bool RecurrenceDescriptor::AddReductionVar(PHINode *Phi, RecurKind Kind,
327329
// the starting value (the Phi or an AND instruction if the Phi has been
328330
// type-promoted).
329331
if (Cur != Start) {
330-
ReduxDesc = isRecurrenceInstr(Cur, Kind, ReduxDesc, FuncFMF);
332+
ReduxDesc =
333+
isRecurrenceInstr(TheLoop, Phi, Cur, Kind, ReduxDesc, FuncFMF);
331334
if (!ReduxDesc.isRecurrence())
332335
return false;
333336
// FIXME: FMF is allowed on phi, but propagation is not handled correctly.
@@ -360,17 +363,18 @@ bool RecurrenceDescriptor::AddReductionVar(PHINode *Phi, RecurKind Kind,
360363

361364
// A reduction operation must only have one use of the reduction value.
362365
if (!IsAPhi && !IsASelect && !isMinMaxRecurrenceKind(Kind) &&
366+
!isSelectCmpRecurrenceKind(Kind) &&
363367
hasMultipleUsesOf(Cur, VisitedInsts, 1))
364368
return false;
365369

366370
// All inputs to a PHI node must be a reduction value.
367371
if (IsAPhi && Cur != Phi && !areAllUsesIn(Cur, VisitedInsts))
368372
return false;
369373

370-
if (isIntMinMaxRecurrenceKind(Kind) &&
374+
if ((isIntMinMaxRecurrenceKind(Kind) || Kind == RecurKind::SelectICmp) &&
371375
(isa<ICmpInst>(Cur) || isa<SelectInst>(Cur)))
372376
++NumCmpSelectPatternInst;
373-
if (isFPMinMaxRecurrenceKind(Kind) &&
377+
if ((isFPMinMaxRecurrenceKind(Kind) || Kind == RecurKind::SelectFCmp) &&
374378
(isa<FCmpInst>(Cur) || isa<SelectInst>(Cur)))
375379
++NumCmpSelectPatternInst;
376380

@@ -423,8 +427,9 @@ bool RecurrenceDescriptor::AddReductionVar(PHINode *Phi, RecurKind Kind,
423427
((!isa<FCmpInst>(UI) && !isa<ICmpInst>(UI) &&
424428
!isa<SelectInst>(UI)) ||
425429
(!isConditionalRdxPattern(Kind, UI).isRecurrence() &&
426-
!isMinMaxPattern(UI, Kind, IgnoredVal)
427-
.isRecurrence())))
430+
!isSelectCmpPattern(TheLoop, Phi, UI, IgnoredVal)
431+
.isRecurrence() &&
432+
!isMinMaxPattern(UI, Kind, IgnoredVal).isRecurrence())))
428433
return false;
429434

430435
// Remember that we completed the cycle.
@@ -442,6 +447,9 @@ bool RecurrenceDescriptor::AddReductionVar(PHINode *Phi, RecurKind Kind,
442447
NumCmpSelectPatternInst != 0)
443448
return false;
444449

450+
if (isSelectCmpRecurrenceKind(Kind) && NumCmpSelectPatternInst != 1)
451+
return false;
452+
445453
if (!FoundStartPHI || !FoundReduxOp || !ExitInstruction)
446454
return false;
447455

@@ -508,6 +516,63 @@ bool RecurrenceDescriptor::AddReductionVar(PHINode *Phi, RecurKind Kind,
508516
return true;
509517
}
510518

519+
// We are looking for loops that do something like this:
520+
// int r = 0;
521+
// for (int i = 0; i < n; i++) {
522+
// if (src[i] > 3)
523+
// r = 3;
524+
// }
525+
// where the reduction value (r) only has two states, in this example 0 or 3.
526+
// The generated LLVM IR for this type of loop will be like this:
527+
// for.body:
528+
// %r = phi i32 [ %spec.select, %for.body ], [ 0, %entry ]
529+
// ...
530+
// %cmp = icmp sgt i32 %5, 3
531+
// %spec.select = select i1 %cmp, i32 3, i32 %r
532+
// ...
533+
// In general we can support vectorization of loops where 'r' flips between
534+
// any two non-constants, provided they are loop invariant. The only thing
535+
// we actually care about at the end of the loop is whether or not any lane
536+
// in the selected vector is different from the start value. The final
537+
// across-vector reduction after the loop simply involves choosing the start
538+
// value if nothing changed (0 in the example above) or the other selected
539+
// value (3 in the example above).
540+
RecurrenceDescriptor::InstDesc
541+
RecurrenceDescriptor::isSelectCmpPattern(Loop *Loop, PHINode *OrigPhi,
542+
Instruction *I, InstDesc &Prev) {
543+
// We must handle the select(cmp(),x,y) as a single instruction. Advance to
544+
// the select.
545+
CmpInst::Predicate Pred;
546+
if (match(I, m_OneUse(m_Cmp(Pred, m_Value(), m_Value())))) {
547+
if (auto *Select = dyn_cast<SelectInst>(*I->user_begin()))
548+
return InstDesc(Select, Prev.getRecKind());
549+
}
550+
551+
// Only match select with single use cmp condition.
552+
if (!match(I, m_Select(m_OneUse(m_Cmp(Pred, m_Value(), m_Value())), m_Value(),
553+
m_Value())))
554+
return InstDesc(false, I);
555+
556+
SelectInst *SI = cast<SelectInst>(I);
557+
Value *NonPhi = nullptr;
558+
559+
if (OrigPhi == dyn_cast<PHINode>(SI->getTrueValue()))
560+
NonPhi = SI->getFalseValue();
561+
else if (OrigPhi == dyn_cast<PHINode>(SI->getFalseValue()))
562+
NonPhi = SI->getTrueValue();
563+
else
564+
return InstDesc(false, I);
565+
566+
// We are looking for selects of the form:
567+
// select(cmp(), phi, loop_invariant) or
568+
// select(cmp(), loop_invariant, phi)
569+
if (!Loop->isLoopInvariant(NonPhi))
570+
return InstDesc(false, I);
571+
572+
return InstDesc(I, isa<ICmpInst>(I->getOperand(0)) ? RecurKind::SelectICmp
573+
: RecurKind::SelectFCmp);
574+
}
575+
511576
RecurrenceDescriptor::InstDesc
512577
RecurrenceDescriptor::isMinMaxPattern(Instruction *I, RecurKind Kind,
513578
const InstDesc &Prev) {
@@ -602,7 +667,8 @@ RecurrenceDescriptor::isConditionalRdxPattern(RecurKind Kind, Instruction *I) {
602667
}
603668

604669
RecurrenceDescriptor::InstDesc
605-
RecurrenceDescriptor::isRecurrenceInstr(Instruction *I, RecurKind Kind,
670+
RecurrenceDescriptor::isRecurrenceInstr(Loop *L, PHINode *OrigPhi,
671+
Instruction *I, RecurKind Kind,
606672
InstDesc &Prev, FastMathFlags FuncFMF) {
607673
assert(Prev.getRecKind() == RecurKind::None || Prev.getRecKind() == Kind);
608674
switch (I->getOpcode()) {
@@ -636,6 +702,8 @@ RecurrenceDescriptor::isRecurrenceInstr(Instruction *I, RecurKind Kind,
636702
case Instruction::FCmp:
637703
case Instruction::ICmp:
638704
case Instruction::Call:
705+
if (isSelectCmpRecurrenceKind(Kind))
706+
return isSelectCmpPattern(L, OrigPhi, I, Prev);
639707
if (isIntMinMaxRecurrenceKind(Kind) ||
640708
(((FuncFMF.noNaNs() && FuncFMF.noSignedZeros()) ||
641709
(isa<FPMathOperator>(I) && I->hasNoNaNs() &&
@@ -664,7 +732,6 @@ bool RecurrenceDescriptor::isReductionPHI(PHINode *Phi, Loop *TheLoop,
664732
RecurrenceDescriptor &RedDes,
665733
DemandedBits *DB, AssumptionCache *AC,
666734
DominatorTree *DT) {
667-
668735
BasicBlock *Header = TheLoop->getHeader();
669736
Function &F = *Header->getParent();
670737
FastMathFlags FMF;
@@ -709,6 +776,12 @@ bool RecurrenceDescriptor::isReductionPHI(PHINode *Phi, Loop *TheLoop,
709776
LLVM_DEBUG(dbgs() << "Found a UMIN reduction PHI." << *Phi << "\n");
710777
return true;
711778
}
779+
if (AddReductionVar(Phi, RecurKind::SelectICmp, TheLoop, FMF, RedDes, DB, AC,
780+
DT)) {
781+
LLVM_DEBUG(dbgs() << "Found an integer conditional select reduction PHI."
782+
<< *Phi << "\n");
783+
return true;
784+
}
712785
if (AddReductionVar(Phi, RecurKind::FMul, TheLoop, FMF, RedDes, DB, AC, DT)) {
713786
LLVM_DEBUG(dbgs() << "Found an FMult reduction PHI." << *Phi << "\n");
714787
return true;
@@ -725,6 +798,12 @@ bool RecurrenceDescriptor::isReductionPHI(PHINode *Phi, Loop *TheLoop,
725798
LLVM_DEBUG(dbgs() << "Found a float MIN reduction PHI." << *Phi << "\n");
726799
return true;
727800
}
801+
if (AddReductionVar(Phi, RecurKind::SelectFCmp, TheLoop, FMF, RedDes, DB, AC,
802+
DT)) {
803+
LLVM_DEBUG(dbgs() << "Found a float conditional select reduction PHI."
804+
<< " PHI." << *Phi << "\n");
805+
return true;
806+
}
728807
// Not a reduction of known type.
729808
return false;
730809
}
@@ -831,8 +910,8 @@ bool RecurrenceDescriptor::isFirstOrderRecurrence(
831910

832911
/// This function returns the identity element (or neutral element) for
833912
/// the operation K.
834-
Constant *RecurrenceDescriptor::getRecurrenceIdentity(RecurKind K, Type *Tp,
835-
FastMathFlags FMF) {
913+
Value *RecurrenceDescriptor::getRecurrenceIdentity(RecurKind K, Type *Tp,
914+
FastMathFlags FMF) {
836915
switch (K) {
837916
case RecurKind::Xor:
838917
case RecurKind::Add:
@@ -872,6 +951,10 @@ Constant *RecurrenceDescriptor::getRecurrenceIdentity(RecurKind K, Type *Tp,
872951
return ConstantFP::getInfinity(Tp, true);
873952
case RecurKind::FMax:
874953
return ConstantFP::getInfinity(Tp, false);
954+
case RecurKind::SelectICmp:
955+
case RecurKind::SelectFCmp:
956+
return getRecurrenceStartValue();
957+
break;
875958
default:
876959
llvm_unreachable("Unknown recurrence kind");
877960
}
@@ -897,9 +980,11 @@ unsigned RecurrenceDescriptor::getOpcode(RecurKind Kind) {
897980
case RecurKind::SMin:
898981
case RecurKind::UMax:
899982
case RecurKind::UMin:
983+
case RecurKind::SelectICmp:
900984
return Instruction::ICmp;
901985
case RecurKind::FMax:
902986
case RecurKind::FMin:
987+
case RecurKind::SelectFCmp:
903988
return Instruction::FCmp;
904989
default:
905990
llvm_unreachable("Unknown recurrence operation");

llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp

+2
Original file line numberDiff line numberDiff line change
@@ -1981,6 +1981,8 @@ bool AArch64TTIImpl::isLegalToVectorizeReduction(
19811981
case RecurKind::UMax:
19821982
case RecurKind::FMin:
19831983
case RecurKind::FMax:
1984+
case RecurKind::SelectICmp:
1985+
case RecurKind::SelectFCmp:
19841986
return true;
19851987
default:
19861988
return false;

0 commit comments

Comments
 (0)