Skip to content

Commit 7f04ee1

Browse files
authored
[SCEV] Move URem matching to ScalarEvolutionPatternMatch.h (#163170)
Move URem matching to ScalarEvolutionPatternMatch.h so it can be re-used together with other matchers. Depends on #163169 PR: #163170
1 parent e10e2f7 commit 7f04ee1

File tree

5 files changed

+98
-96
lines changed

5 files changed

+98
-96
lines changed

llvm/include/llvm/Analysis/ScalarEvolution.h

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2316,10 +2316,6 @@ class ScalarEvolution {
23162316
/// an add rec on said loop.
23172317
void getUsedLoops(const SCEV *S, SmallPtrSetImpl<const Loop *> &LoopsUsed);
23182318

2319-
/// Try to match the pattern generated by getURemExpr(A, B). If successful,
2320-
/// Assign A and B to LHS and RHS, respectively.
2321-
LLVM_ABI bool matchURem(const SCEV *Expr, const SCEV *&LHS, const SCEV *&RHS);
2322-
23232319
/// Look for a SCEV expression with type `SCEVType` and operands `Ops` in
23242320
/// `UniqueSCEVs`. Return if found, else nullptr.
23252321
SCEV *findExistingSCEVInCache(SCEVTypes SCEVType, ArrayRef<const SCEV *> Ops);

llvm/include/llvm/Analysis/ScalarEvolutionPatternMatch.h

Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -252,6 +252,80 @@ m_scev_UDiv(const Op0_t &Op0, const Op1_t &Op1) {
252252
return m_scev_Binary<SCEVUDivExpr>(Op0, Op1);
253253
}
254254

255+
/// Match unsigned remainder pattern.
256+
/// Matches patterns generated by getURemExpr.
257+
template <typename Op0_t, typename Op1_t> struct SCEVURem_match {
258+
Op0_t Op0;
259+
Op1_t Op1;
260+
ScalarEvolution &SE;
261+
262+
SCEVURem_match(Op0_t Op0, Op1_t Op1, ScalarEvolution &SE)
263+
: Op0(Op0), Op1(Op1), SE(SE) {}
264+
265+
bool match(const SCEV *Expr) const {
266+
if (Expr->getType()->isPointerTy())
267+
return false;
268+
269+
// Try to match 'zext (trunc A to iB) to iY', which is used
270+
// for URem with constant power-of-2 second operands. Make sure the size of
271+
// the operand A matches the size of the whole expressions.
272+
const SCEV *LHS;
273+
if (SCEVPatternMatch::match(Expr, m_scev_ZExt(m_scev_Trunc(m_SCEV(LHS))))) {
274+
Type *TruncTy = cast<SCEVZeroExtendExpr>(Expr)->getOperand()->getType();
275+
// Bail out if the type of the LHS is larger than the type of the
276+
// expression for now.
277+
if (SE.getTypeSizeInBits(LHS->getType()) >
278+
SE.getTypeSizeInBits(Expr->getType()))
279+
return false;
280+
if (LHS->getType() != Expr->getType())
281+
LHS = SE.getZeroExtendExpr(LHS, Expr->getType());
282+
const SCEV *RHS =
283+
SE.getConstant(APInt(SE.getTypeSizeInBits(Expr->getType()), 1)
284+
<< SE.getTypeSizeInBits(TruncTy));
285+
return Op0.match(LHS) && Op1.match(RHS);
286+
}
287+
const auto *Add = dyn_cast<SCEVAddExpr>(Expr);
288+
if (Add == nullptr || Add->getNumOperands() != 2)
289+
return false;
290+
291+
const SCEV *A = Add->getOperand(1);
292+
const auto *Mul = dyn_cast<SCEVMulExpr>(Add->getOperand(0));
293+
294+
if (Mul == nullptr)
295+
return false;
296+
297+
const auto MatchURemWithDivisor = [&](const SCEV *B) {
298+
// (SomeExpr + (-(SomeExpr / B) * B)).
299+
if (Expr == SE.getURemExpr(A, B))
300+
return Op0.match(A) && Op1.match(B);
301+
return false;
302+
};
303+
304+
// (SomeExpr + (-1 * (SomeExpr / B) * B)).
305+
if (Mul->getNumOperands() == 3 && isa<SCEVConstant>(Mul->getOperand(0)))
306+
return MatchURemWithDivisor(Mul->getOperand(1)) ||
307+
MatchURemWithDivisor(Mul->getOperand(2));
308+
309+
// (SomeExpr + ((-SomeExpr / B) * B)) or (SomeExpr + ((SomeExpr / B) * -B)).
310+
if (Mul->getNumOperands() == 2)
311+
return MatchURemWithDivisor(Mul->getOperand(1)) ||
312+
MatchURemWithDivisor(Mul->getOperand(0)) ||
313+
MatchURemWithDivisor(SE.getNegativeSCEV(Mul->getOperand(1))) ||
314+
MatchURemWithDivisor(SE.getNegativeSCEV(Mul->getOperand(0)));
315+
return false;
316+
}
317+
};
318+
319+
/// Match the mathematical pattern A - (A / B) * B, where A and B can be
320+
/// arbitrary expressions. Also match zext (trunc A to iB) to iY, which is used
321+
/// for URem with constant power-of-2 second operands. It's not always easy, as
322+
/// A and B can be folded (imagine A is X / 2, and B is 4, A / B becomes X / 8).
323+
template <typename Op0_t, typename Op1_t>
324+
inline SCEVURem_match<Op0_t, Op1_t> m_scev_URem(Op0_t LHS, Op1_t RHS,
325+
ScalarEvolution &SE) {
326+
return SCEVURem_match<Op0_t, Op1_t>(LHS, RHS, SE);
327+
}
328+
255329
inline class_match<const Loop> m_Loop() { return class_match<const Loop>(); }
256330

257331
/// Match an affine SCEVAddRecExpr.

llvm/lib/Analysis/ScalarEvolution.cpp

Lines changed: 18 additions & 84 deletions
Original file line numberDiff line numberDiff line change
@@ -1774,7 +1774,7 @@ const SCEV *ScalarEvolution::getZeroExtendExprImpl(const SCEV *Op, Type *Ty,
17741774
{
17751775
const SCEV *LHS;
17761776
const SCEV *RHS;
1777-
if (matchURem(Op, LHS, RHS))
1777+
if (match(Op, m_scev_URem(m_SCEV(LHS), m_SCEV(RHS), *this)))
17781778
return getURemExpr(getZeroExtendExpr(LHS, Ty, Depth + 1),
17791779
getZeroExtendExpr(RHS, Ty, Depth + 1));
17801780
}
@@ -2699,17 +2699,12 @@ const SCEV *ScalarEvolution::getAddExpr(SmallVectorImpl<const SCEV *> &Ops,
26992699
}
27002700

27012701
// Canonicalize (-1 * urem X, Y) + X --> (Y * X/Y)
2702-
if (Ops.size() == 2) {
2703-
const SCEVMulExpr *Mul = dyn_cast<SCEVMulExpr>(Ops[0]);
2704-
if (Mul && Mul->getNumOperands() == 2 &&
2705-
Mul->getOperand(0)->isAllOnesValue()) {
2706-
const SCEV *X;
2707-
const SCEV *Y;
2708-
if (matchURem(Mul->getOperand(1), X, Y) && X == Ops[1]) {
2709-
return getMulExpr(Y, getUDivExpr(X, Y));
2710-
}
2711-
}
2712-
}
2702+
const SCEV *Y;
2703+
if (Ops.size() == 2 &&
2704+
match(Ops[0],
2705+
m_scev_Mul(m_scev_AllOnes(),
2706+
m_scev_URem(m_scev_Specific(Ops[1]), m_SCEV(Y), *this))))
2707+
return getMulExpr(Y, getUDivExpr(Ops[1], Y));
27132708

27142709
// Skip past any other cast SCEVs.
27152710
while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < scAddExpr)
@@ -15410,65 +15405,6 @@ void PredicatedScalarEvolution::print(raw_ostream &OS, unsigned Depth) const {
1541015405
}
1541115406
}
1541215407

15413-
// Match the mathematical pattern A - (A / B) * B, where A and B can be
15414-
// arbitrary expressions. Also match zext (trunc A to iB) to iY, which is used
15415-
// for URem with constant power-of-2 second operands.
15416-
// It's not always easy, as A and B can be folded (imagine A is X / 2, and B is
15417-
// 4, A / B becomes X / 8).
15418-
bool ScalarEvolution::matchURem(const SCEV *Expr, const SCEV *&LHS,
15419-
const SCEV *&RHS) {
15420-
if (Expr->getType()->isPointerTy())
15421-
return false;
15422-
15423-
// Try to match 'zext (trunc A to iB) to iY', which is used
15424-
// for URem with constant power-of-2 second operands. Make sure the size of
15425-
// the operand A matches the size of the whole expressions.
15426-
if (match(Expr, m_scev_ZExt(m_scev_Trunc(m_SCEV(LHS))))) {
15427-
Type *TruncTy = cast<SCEVZeroExtendExpr>(Expr)->getOperand()->getType();
15428-
// Bail out if the type of the LHS is larger than the type of the
15429-
// expression for now.
15430-
if (getTypeSizeInBits(LHS->getType()) > getTypeSizeInBits(Expr->getType()))
15431-
return false;
15432-
if (LHS->getType() != Expr->getType())
15433-
LHS = getZeroExtendExpr(LHS, Expr->getType());
15434-
RHS = getConstant(APInt(getTypeSizeInBits(Expr->getType()), 1)
15435-
<< getTypeSizeInBits(TruncTy));
15436-
return true;
15437-
}
15438-
const auto *Add = dyn_cast<SCEVAddExpr>(Expr);
15439-
if (Add == nullptr || Add->getNumOperands() != 2)
15440-
return false;
15441-
15442-
const SCEV *A = Add->getOperand(1);
15443-
const auto *Mul = dyn_cast<SCEVMulExpr>(Add->getOperand(0));
15444-
15445-
if (Mul == nullptr)
15446-
return false;
15447-
15448-
const auto MatchURemWithDivisor = [&](const SCEV *B) {
15449-
// (SomeExpr + (-(SomeExpr / B) * B)).
15450-
if (Expr == getURemExpr(A, B)) {
15451-
LHS = A;
15452-
RHS = B;
15453-
return true;
15454-
}
15455-
return false;
15456-
};
15457-
15458-
// (SomeExpr + (-1 * (SomeExpr / B) * B)).
15459-
if (Mul->getNumOperands() == 3 && isa<SCEVConstant>(Mul->getOperand(0)))
15460-
return MatchURemWithDivisor(Mul->getOperand(1)) ||
15461-
MatchURemWithDivisor(Mul->getOperand(2));
15462-
15463-
// (SomeExpr + ((-SomeExpr / B) * B)) or (SomeExpr + ((SomeExpr / B) * -B)).
15464-
if (Mul->getNumOperands() == 2)
15465-
return MatchURemWithDivisor(Mul->getOperand(1)) ||
15466-
MatchURemWithDivisor(Mul->getOperand(0)) ||
15467-
MatchURemWithDivisor(getNegativeSCEV(Mul->getOperand(1))) ||
15468-
MatchURemWithDivisor(getNegativeSCEV(Mul->getOperand(0)));
15469-
return false;
15470-
}
15471-
1547215408
ScalarEvolution::LoopGuards
1547315409
ScalarEvolution::LoopGuards::collect(const Loop *L, ScalarEvolution &SE) {
1547415410
BasicBlock *Header = L->getHeader();
@@ -15689,20 +15625,18 @@ void ScalarEvolution::LoopGuards::collectFromBlock(
1568915625
if (Predicate == CmpInst::ICMP_EQ && match(RHS, m_scev_Zero())) {
1569015626
// If LHS is A % B, i.e. A % B == 0, rewrite A to (A /u B) * B to
1569115627
// explicitly express that.
15692-
const SCEV *URemLHS = nullptr;
15628+
const SCEVUnknown *URemLHS = nullptr;
1569315629
const SCEV *URemRHS = nullptr;
15694-
if (SE.matchURem(LHS, URemLHS, URemRHS)) {
15695-
if (const SCEVUnknown *LHSUnknown = dyn_cast<SCEVUnknown>(URemLHS)) {
15696-
auto I = RewriteMap.find(LHSUnknown);
15697-
const SCEV *RewrittenLHS =
15698-
I != RewriteMap.end() ? I->second : LHSUnknown;
15699-
RewrittenLHS = ApplyDivisibiltyOnMinMaxExpr(RewrittenLHS, URemRHS);
15700-
const auto *Multiple =
15701-
SE.getMulExpr(SE.getUDivExpr(RewrittenLHS, URemRHS), URemRHS);
15702-
RewriteMap[LHSUnknown] = Multiple;
15703-
ExprsToRewrite.push_back(LHSUnknown);
15704-
return;
15705-
}
15630+
if (match(LHS,
15631+
m_scev_URem(m_SCEVUnknown(URemLHS), m_SCEV(URemRHS), SE))) {
15632+
auto I = RewriteMap.find(URemLHS);
15633+
const SCEV *RewrittenLHS = I != RewriteMap.end() ? I->second : URemLHS;
15634+
RewrittenLHS = ApplyDivisibiltyOnMinMaxExpr(RewrittenLHS, URemRHS);
15635+
const auto *Multiple =
15636+
SE.getMulExpr(SE.getUDivExpr(RewrittenLHS, URemRHS), URemRHS);
15637+
RewriteMap[URemLHS] = Multiple;
15638+
ExprsToRewrite.push_back(URemLHS);
15639+
return;
1570615640
}
1570715641
}
1570815642

llvm/lib/Transforms/Utils/ScalarEvolutionExpander.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -526,7 +526,7 @@ Value *SCEVExpander::visitAddExpr(const SCEVAddExpr *S) {
526526
// Recognize the canonical representation of an unsimplifed urem.
527527
const SCEV *URemLHS = nullptr;
528528
const SCEV *URemRHS = nullptr;
529-
if (SE.matchURem(S, URemLHS, URemRHS)) {
529+
if (match(S, m_scev_URem(m_SCEV(URemLHS), m_SCEV(URemRHS), SE))) {
530530
Value *LHS = expand(URemLHS);
531531
Value *RHS = expand(URemRHS);
532532
return InsertBinop(Instruction::URem, LHS, RHS, SCEV::FlagAnyWrap,

llvm/unittests/Analysis/ScalarEvolutionTest.cpp

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
#include "llvm/Analysis/LoopInfo.h"
1212
#include "llvm/Analysis/ScalarEvolutionExpressions.h"
1313
#include "llvm/Analysis/ScalarEvolutionNormalization.h"
14+
#include "llvm/Analysis/ScalarEvolutionPatternMatch.h"
1415
#include "llvm/Analysis/TargetLibraryInfo.h"
1516
#include "llvm/AsmParser/Parser.h"
1617
#include "llvm/IR/Constants.h"
@@ -26,6 +27,8 @@
2627

2728
namespace llvm {
2829

30+
using namespace SCEVPatternMatch;
31+
2932
// We use this fixture to ensure that we clean up ScalarEvolution before
3033
// deleting the PassManager.
3134
class ScalarEvolutionsTest : public testing::Test {
@@ -64,11 +67,6 @@ static std::optional<APInt> computeConstantDifference(ScalarEvolution &SE,
6467
return SE.computeConstantDifference(LHS, RHS);
6568
}
6669

67-
static bool matchURem(ScalarEvolution &SE, const SCEV *Expr, const SCEV *&LHS,
68-
const SCEV *&RHS) {
69-
return SE.matchURem(Expr, LHS, RHS);
70-
}
71-
7270
static bool isImpliedCond(
7371
ScalarEvolution &SE, ICmpInst::Predicate Pred, const SCEV *LHS,
7472
const SCEV *RHS, ICmpInst::Predicate FoundPred, const SCEV *FoundLHS,
@@ -1524,7 +1522,7 @@ TEST_F(ScalarEvolutionsTest, MatchURem) {
15241522
auto *URemI = getInstructionByName(F, N);
15251523
auto *S = SE.getSCEV(URemI);
15261524
const SCEV *LHS, *RHS;
1527-
EXPECT_TRUE(matchURem(SE, S, LHS, RHS));
1525+
EXPECT_TRUE(match(S, m_scev_URem(m_SCEV(LHS), m_SCEV(RHS), SE)));
15281526
EXPECT_EQ(LHS, SE.getSCEV(URemI->getOperand(0)));
15291527
EXPECT_EQ(RHS, SE.getSCEV(URemI->getOperand(1)));
15301528
EXPECT_EQ(LHS->getType(), S->getType());
@@ -1537,7 +1535,7 @@ TEST_F(ScalarEvolutionsTest, MatchURem) {
15371535
auto *URem1 = getInstructionByName(F, "rem4");
15381536
auto *S = SE.getSCEV(Ext);
15391537
const SCEV *LHS, *RHS;
1540-
EXPECT_TRUE(matchURem(SE, S, LHS, RHS));
1538+
EXPECT_TRUE(match(S, m_scev_URem(m_SCEV(LHS), m_SCEV(RHS), SE)));
15411539
EXPECT_NE(LHS, SE.getSCEV(URem1->getOperand(0)));
15421540
// RHS and URem1->getOperand(1) have different widths, so compare the
15431541
// integer values.

0 commit comments

Comments
 (0)