@@ -299,9 +299,17 @@ bool RecurrenceDescriptor::AddReductionVar(PHINode *Phi, RecurrenceKind Kind,
299
299
return false ;
300
300
}
301
301
302
+ bool IsASelect = isa<SelectInst>(Cur);
303
+
304
+ // A conditional reduction operation must only have 2 or less uses in
305
+ // VisitedInsts.
306
+ if (IsASelect && (Kind == RK_FloatAdd || Kind == RK_FloatMult) &&
307
+ hasMultipleUsesOf (Cur, VisitedInsts, 2 ))
308
+ return false ;
309
+
302
310
// A reduction operation must only have one use of the reduction value.
303
- if (!IsAPhi && Kind != RK_IntegerMinMax && Kind != RK_FloatMinMax &&
304
- hasMultipleUsesOf (Cur, VisitedInsts))
311
+ if (!IsAPhi && !IsASelect && Kind != RK_IntegerMinMax &&
312
+ Kind != RK_FloatMinMax && hasMultipleUsesOf (Cur, VisitedInsts, 1 ))
305
313
return false ;
306
314
307
315
// All inputs to a PHI node must be a reduction value.
@@ -362,7 +370,8 @@ bool RecurrenceDescriptor::AddReductionVar(PHINode *Phi, RecurrenceKind Kind,
362
370
} else if (!isa<PHINode>(UI) &&
363
371
((!isa<FCmpInst>(UI) && !isa<ICmpInst>(UI) &&
364
372
!isa<SelectInst>(UI)) ||
365
- !isMinMaxSelectCmpPattern (UI, IgnoredVal).isRecurrence ()))
373
+ (!isConditionalRdxPattern (Kind, UI).isRecurrence () &&
374
+ !isMinMaxSelectCmpPattern (UI, IgnoredVal).isRecurrence ())))
366
375
return false ;
367
376
368
377
// Remember that we completed the cycle.
@@ -491,6 +500,52 @@ RecurrenceDescriptor::isMinMaxSelectCmpPattern(Instruction *I, InstDesc &Prev) {
491
500
return InstDesc (false , I);
492
501
}
493
502
503
+ // / Returns true if the select instruction has users in the compare-and-add
504
+ // / reduction pattern below. The select instruction argument is the last one
505
+ // / in the sequence.
506
+ // /
507
+ // / %sum.1 = phi ...
508
+ // / ...
509
+ // / %cmp = fcmp pred %0, %CFP
510
+ // / %add = fadd %0, %sum.1
511
+ // / %sum.2 = select %cmp, %add, %sum.1
512
+ RecurrenceDescriptor::InstDesc
513
+ RecurrenceDescriptor::isConditionalRdxPattern (
514
+ RecurrenceKind Kind, Instruction *I) {
515
+ SelectInst *SI = dyn_cast<SelectInst>(I);
516
+ if (!SI)
517
+ return InstDesc (false , I);
518
+
519
+ CmpInst *CI = dyn_cast<CmpInst>(SI->getCondition ());
520
+ // Only handle single use cases for now.
521
+ if (!CI || !CI->hasOneUse ())
522
+ return InstDesc (false , I);
523
+
524
+ Value *TrueVal = SI->getTrueValue ();
525
+ Value *FalseVal = SI->getFalseValue ();
526
+ // Handle only when either of operands of select instruction is a PHI
527
+ // node for now.
528
+ if ((isa<PHINode>(*TrueVal) && isa<PHINode>(*FalseVal)) ||
529
+ (!isa<PHINode>(*TrueVal) && !isa<PHINode>(*FalseVal)))
530
+ return InstDesc (false , I);
531
+
532
+ Instruction *I1 =
533
+ isa<PHINode>(*TrueVal) ? dyn_cast<Instruction>(FalseVal)
534
+ : dyn_cast<Instruction>(TrueVal);
535
+ if (!I1 || !I1->isBinaryOp ())
536
+ return InstDesc (false , I);
537
+
538
+ Value *Op1, *Op2;
539
+ if (m_FAdd (m_Value (Op1), m_Value (Op2)).match (I1) ||
540
+ m_FSub (m_Value (Op1), m_Value (Op2)).match (I1))
541
+ return InstDesc (Kind == RK_FloatAdd, SI);
542
+
543
+ if (m_FMul (m_Value (Op1), m_Value (Op2)).match (I1))
544
+ return InstDesc (Kind == RK_FloatMult, SI);
545
+
546
+ return InstDesc (false , I);
547
+ }
548
+
494
549
RecurrenceDescriptor::InstDesc
495
550
RecurrenceDescriptor::isRecurrenceInstr (Instruction *I, RecurrenceKind Kind,
496
551
InstDesc &Prev, bool HasFunNoNaNAttr) {
@@ -520,9 +575,12 @@ RecurrenceDescriptor::isRecurrenceInstr(Instruction *I, RecurrenceKind Kind,
520
575
case Instruction::FSub:
521
576
case Instruction::FAdd:
522
577
return InstDesc (Kind == RK_FloatAdd, I, UAI);
578
+ case Instruction::Select:
579
+ if (Kind == RK_FloatAdd || Kind == RK_FloatMult)
580
+ return isConditionalRdxPattern (Kind, I);
581
+ LLVM_FALLTHROUGH;
523
582
case Instruction::FCmp:
524
583
case Instruction::ICmp:
525
- case Instruction::Select:
526
584
if (Kind != RK_IntegerMinMax &&
527
585
(!HasFunNoNaNAttr || Kind != RK_FloatMinMax))
528
586
return InstDesc (false , I);
@@ -531,13 +589,14 @@ RecurrenceDescriptor::isRecurrenceInstr(Instruction *I, RecurrenceKind Kind,
531
589
}
532
590
533
591
bool RecurrenceDescriptor::hasMultipleUsesOf (
534
- Instruction *I, SmallPtrSetImpl<Instruction *> &Insts) {
592
+ Instruction *I, SmallPtrSetImpl<Instruction *> &Insts,
593
+ unsigned MaxNumUses) {
535
594
unsigned NumUses = 0 ;
536
595
for (User::op_iterator Use = I->op_begin (), E = I->op_end (); Use != E;
537
596
++Use) {
538
597
if (Insts.count (dyn_cast<Instruction>(*Use)))
539
598
++NumUses;
540
- if (NumUses > 1 )
599
+ if (NumUses > MaxNumUses )
541
600
return true ;
542
601
}
543
602
0 commit comments