@@ -1774,7 +1774,7 @@ const SCEV *ScalarEvolution::getZeroExtendExprImpl(const SCEV *Op, Type *Ty,
17741774 {
17751775 const SCEV *LHS;
17761776 const SCEV *RHS;
1777- if (matchURem (Op, LHS, RHS))
1777+ if (match (Op, m_scev_URem(m_SCEV( LHS), m_SCEV( RHS), *this) ))
17781778 return getURemExpr(getZeroExtendExpr(LHS, Ty, Depth + 1),
17791779 getZeroExtendExpr(RHS, Ty, Depth + 1));
17801780 }
@@ -2699,17 +2699,12 @@ const SCEV *ScalarEvolution::getAddExpr(SmallVectorImpl<const SCEV *> &Ops,
26992699 }
27002700
27012701 // Canonicalize (-1 * urem X, Y) + X --> (Y * X/Y)
2702- if (Ops.size() == 2) {
2703- const SCEVMulExpr *Mul = dyn_cast<SCEVMulExpr>(Ops[0]);
2704- if (Mul && Mul->getNumOperands() == 2 &&
2705- Mul->getOperand(0)->isAllOnesValue()) {
2706- const SCEV *X;
2707- const SCEV *Y;
2708- if (matchURem(Mul->getOperand(1), X, Y) && X == Ops[1]) {
2709- return getMulExpr(Y, getUDivExpr(X, Y));
2710- }
2711- }
2712- }
2702+ const SCEV *Y;
2703+ if (Ops.size() == 2 &&
2704+ match(Ops[0],
2705+ m_scev_Mul(m_scev_AllOnes(),
2706+ m_scev_URem(m_scev_Specific(Ops[1]), m_SCEV(Y), *this))))
2707+ return getMulExpr(Y, getUDivExpr(Ops[1], Y));
27132708
27142709 // Skip past any other cast SCEVs.
27152710 while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < scAddExpr)
@@ -15410,65 +15405,6 @@ void PredicatedScalarEvolution::print(raw_ostream &OS, unsigned Depth) const {
1541015405 }
1541115406}
1541215407
15413- // Match the mathematical pattern A - (A / B) * B, where A and B can be
15414- // arbitrary expressions. Also match zext (trunc A to iB) to iY, which is used
15415- // for URem with constant power-of-2 second operands.
15416- // It's not always easy, as A and B can be folded (imagine A is X / 2, and B is
15417- // 4, A / B becomes X / 8).
15418- bool ScalarEvolution::matchURem(const SCEV *Expr, const SCEV *&LHS,
15419- const SCEV *&RHS) {
15420- if (Expr->getType()->isPointerTy())
15421- return false;
15422-
15423- // Try to match 'zext (trunc A to iB) to iY', which is used
15424- // for URem with constant power-of-2 second operands. Make sure the size of
15425- // the operand A matches the size of the whole expressions.
15426- if (match(Expr, m_scev_ZExt(m_scev_Trunc(m_SCEV(LHS))))) {
15427- Type *TruncTy = cast<SCEVZeroExtendExpr>(Expr)->getOperand()->getType();
15428- // Bail out if the type of the LHS is larger than the type of the
15429- // expression for now.
15430- if (getTypeSizeInBits(LHS->getType()) > getTypeSizeInBits(Expr->getType()))
15431- return false;
15432- if (LHS->getType() != Expr->getType())
15433- LHS = getZeroExtendExpr(LHS, Expr->getType());
15434- RHS = getConstant(APInt(getTypeSizeInBits(Expr->getType()), 1)
15435- << getTypeSizeInBits(TruncTy));
15436- return true;
15437- }
15438- const auto *Add = dyn_cast<SCEVAddExpr>(Expr);
15439- if (Add == nullptr || Add->getNumOperands() != 2)
15440- return false;
15441-
15442- const SCEV *A = Add->getOperand(1);
15443- const auto *Mul = dyn_cast<SCEVMulExpr>(Add->getOperand(0));
15444-
15445- if (Mul == nullptr)
15446- return false;
15447-
15448- const auto MatchURemWithDivisor = [&](const SCEV *B) {
15449- // (SomeExpr + (-(SomeExpr / B) * B)).
15450- if (Expr == getURemExpr(A, B)) {
15451- LHS = A;
15452- RHS = B;
15453- return true;
15454- }
15455- return false;
15456- };
15457-
15458- // (SomeExpr + (-1 * (SomeExpr / B) * B)).
15459- if (Mul->getNumOperands() == 3 && isa<SCEVConstant>(Mul->getOperand(0)))
15460- return MatchURemWithDivisor(Mul->getOperand(1)) ||
15461- MatchURemWithDivisor(Mul->getOperand(2));
15462-
15463- // (SomeExpr + ((-SomeExpr / B) * B)) or (SomeExpr + ((SomeExpr / B) * -B)).
15464- if (Mul->getNumOperands() == 2)
15465- return MatchURemWithDivisor(Mul->getOperand(1)) ||
15466- MatchURemWithDivisor(Mul->getOperand(0)) ||
15467- MatchURemWithDivisor(getNegativeSCEV(Mul->getOperand(1))) ||
15468- MatchURemWithDivisor(getNegativeSCEV(Mul->getOperand(0)));
15469- return false;
15470- }
15471-
1547215408ScalarEvolution::LoopGuards
1547315409ScalarEvolution::LoopGuards::collect(const Loop *L, ScalarEvolution &SE) {
1547415410 BasicBlock *Header = L->getHeader();
@@ -15689,20 +15625,18 @@ void ScalarEvolution::LoopGuards::collectFromBlock(
1568915625 if (Predicate == CmpInst::ICMP_EQ && match(RHS, m_scev_Zero())) {
1569015626 // If LHS is A % B, i.e. A % B == 0, rewrite A to (A /u B) * B to
1569115627 // explicitly express that.
15692- const SCEV *URemLHS = nullptr;
15628+ const SCEVUnknown *URemLHS = nullptr;
1569315629 const SCEV *URemRHS = nullptr;
15694- if (SE.matchURem(LHS, URemLHS, URemRHS)) {
15695- if (const SCEVUnknown *LHSUnknown = dyn_cast<SCEVUnknown>(URemLHS)) {
15696- auto I = RewriteMap.find(LHSUnknown);
15697- const SCEV *RewrittenLHS =
15698- I != RewriteMap.end() ? I->second : LHSUnknown;
15699- RewrittenLHS = ApplyDivisibiltyOnMinMaxExpr(RewrittenLHS, URemRHS);
15700- const auto *Multiple =
15701- SE.getMulExpr(SE.getUDivExpr(RewrittenLHS, URemRHS), URemRHS);
15702- RewriteMap[LHSUnknown] = Multiple;
15703- ExprsToRewrite.push_back(LHSUnknown);
15704- return;
15705- }
15630+ if (match(LHS,
15631+ m_scev_URem(m_SCEVUnknown(URemLHS), m_SCEV(URemRHS), SE))) {
15632+ auto I = RewriteMap.find(URemLHS);
15633+ const SCEV *RewrittenLHS = I != RewriteMap.end() ? I->second : URemLHS;
15634+ RewrittenLHS = ApplyDivisibiltyOnMinMaxExpr(RewrittenLHS, URemRHS);
15635+ const auto *Multiple =
15636+ SE.getMulExpr(SE.getUDivExpr(RewrittenLHS, URemRHS), URemRHS);
15637+ RewriteMap[URemLHS] = Multiple;
15638+ ExprsToRewrite.push_back(URemLHS);
15639+ return;
1570615640 }
1570715641 }
1570815642
0 commit comments