Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions mlir/include/mlir/Analysis/Presburger/IntegerRelation.h
Original file line number Diff line number Diff line change
Expand Up @@ -577,6 +577,11 @@ class IntegerRelation {
convertVarKind(kind, varStart, varLimit, VarKind::Local);
}

/// Merge and align symbol variables of `this` and `other` with respect to
/// identifiers. After this operation the symbol variables of both relations
/// have the same identifiers in the same order.
void mergeAndAlignSymbols(IntegerRelation &other);

/// Adds additional local vars to the sets such that they both have the union
/// of the local vars in each set, without changing the set of points that
/// lie in `this` and `other`.
Expand Down
34 changes: 34 additions & 0 deletions mlir/lib/Analysis/Presburger/IntegerRelation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1288,6 +1288,40 @@ void IntegerRelation::eliminateRedundantLocalVar(unsigned posA, unsigned posB) {
removeVar(posB);
}

/// mergeAndAlignSymbols's implementation can be broken down into two steps:
/// 1. Merge and align identifiers into `other` from `this. If an identifier
/// from `this` exists in `other` then we align it. Otherwise, we assume it is a
/// new identifier and insert it into `other` in the same position as `this`.
/// 2. Add identifiers that are in `other` but not `this to `this`.
void IntegerRelation::mergeAndAlignSymbols(IntegerRelation &other) {
assert(space.isUsingIds() && other.space.isUsingIds() &&
"both relations need to have identifers to merge and align");

unsigned i = 0;
for (const Identifier identifier : space.getIds(VarKind::Symbol)) {
// Search in `other` starting at position `i` since the left of `i` is
// aligned.
const Identifier *findBegin =
other.space.getIds(VarKind::Symbol).begin() + i;
const Identifier *findEnd = other.space.getIds(VarKind::Symbol).end();
const Identifier *itr = std::find(findBegin, findEnd, identifier);
if (itr != findEnd) {
other.swapVar(other.getVarKindOffset(VarKind::Symbol) + i,
other.getVarKindOffset(VarKind::Symbol) + i +
std::distance(findBegin, itr));
} else {
other.insertVar(VarKind::Symbol, i);
other.space.getId(VarKind::Symbol, i) = identifier;
}
++i;
}

for (unsigned e = other.getNumVarKind(VarKind::Symbol); i < e; ++i) {
insertVar(VarKind::Symbol, i);
space.getId(VarKind::Symbol, i) = other.space.getId(VarKind::Symbol, i);
}
}

/// Adds additional local ids to the sets such that they both have the union
/// of the local ids in each set, without changing the set of points that
/// lie in `this` and `other`.
Expand Down
244 changes: 244 additions & 0 deletions mlir/unittests/Analysis/Presburger/IntegerRelationTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -207,3 +207,247 @@ TEST(IntegerRelationTest, swapVar) {
EXPECT_TRUE(swappedSpace.getId(VarKind::Symbol, 1)
.isEqual(space.getId(VarKind::Domain, 1)));
}

TEST(IntegerRelationTest, mergeAndAlignSymbols) {
IntegerRelation rel =
parseRelationFromSet("(x, y, z, a, b, c)[N, Q] : (a - x - y == 0, "
"x >= 0, N - b >= 0, y >= 0, Q - y >= 0)",
3);
IntegerRelation otherRel = parseRelationFromSet(
"(x, y, z, a, b)[N, M, P] : (z - x - y == 0, x >= 0, N - x "
">= 0, y >= 0, M - y >= 0, 2 * P - 3 * a + 2 * b == 0)",
3);
PresburgerSpace space = PresburgerSpace::getRelationSpace(3, 3, 2, 0);
space.resetIds();

PresburgerSpace otherSpace = PresburgerSpace::getRelationSpace(3, 2, 3, 0);
otherSpace.resetIds();

// Attach identifiers.
int identifiers[7] = {0, 1, 2, 3, 4, 5, 6};
int otherIdentifiers[8] = {10, 11, 12, 13, 14, 15, 16, 17};

space.getId(VarKind::Domain, 0) = Identifier(&identifiers[0]);
space.getId(VarKind::Domain, 1) = Identifier(&identifiers[1]);
// Note the common identifier.
space.getId(VarKind::Domain, 2) = Identifier(&otherIdentifiers[2]);
space.getId(VarKind::Range, 0) = Identifier(&identifiers[2]);
space.getId(VarKind::Range, 1) = Identifier(&identifiers[3]);
space.getId(VarKind::Range, 2) = Identifier(&identifiers[4]);
space.getId(VarKind::Symbol, 0) = Identifier(&identifiers[5]);
space.getId(VarKind::Symbol, 1) = Identifier(&identifiers[6]);

otherSpace.getId(VarKind::Domain, 0) = Identifier(&otherIdentifiers[0]);
otherSpace.getId(VarKind::Domain, 1) = Identifier(&otherIdentifiers[1]);
otherSpace.getId(VarKind::Domain, 2) = Identifier(&otherIdentifiers[2]);
otherSpace.getId(VarKind::Range, 0) = Identifier(&otherIdentifiers[3]);
otherSpace.getId(VarKind::Range, 1) = Identifier(&otherIdentifiers[4]);
// Note the common identifier.
otherSpace.getId(VarKind::Symbol, 0) = Identifier(&identifiers[6]);
otherSpace.getId(VarKind::Symbol, 1) = Identifier(&otherIdentifiers[5]);
otherSpace.getId(VarKind::Symbol, 2) = Identifier(&otherIdentifiers[7]);

rel.setSpace(space);
otherRel.setSpace(otherSpace);
rel.mergeAndAlignSymbols(otherRel);

space = rel.getSpace();
otherSpace = otherRel.getSpace();

// Check if merge and align is successful.
// Check symbol var identifiers.
EXPECT_EQ(4u, space.getNumSymbolVars());
EXPECT_EQ(4u, otherSpace.getNumSymbolVars());
EXPECT_EQ(space.getId(VarKind::Symbol, 0), Identifier(&identifiers[5]));
EXPECT_EQ(space.getId(VarKind::Symbol, 1), Identifier(&identifiers[6]));
EXPECT_EQ(space.getId(VarKind::Symbol, 2), Identifier(&otherIdentifiers[5]));
EXPECT_EQ(space.getId(VarKind::Symbol, 3), Identifier(&otherIdentifiers[7]));
EXPECT_EQ(otherSpace.getId(VarKind::Symbol, 0), Identifier(&identifiers[5]));
EXPECT_EQ(otherSpace.getId(VarKind::Symbol, 1), Identifier(&identifiers[6]));
EXPECT_EQ(otherSpace.getId(VarKind::Symbol, 2),
Identifier(&otherIdentifiers[5]));
EXPECT_EQ(otherSpace.getId(VarKind::Symbol, 3),
Identifier(&otherIdentifiers[7]));
// Check that domain and range var identifiers are not affected.
EXPECT_EQ(3u, space.getNumDomainVars());
EXPECT_EQ(3u, space.getNumRangeVars());
EXPECT_EQ(space.getId(VarKind::Domain, 0), Identifier(&identifiers[0]));
EXPECT_EQ(space.getId(VarKind::Domain, 1), Identifier(&identifiers[1]));
EXPECT_EQ(space.getId(VarKind::Domain, 2), Identifier(&otherIdentifiers[2]));
EXPECT_EQ(space.getId(VarKind::Range, 0), Identifier(&identifiers[2]));
EXPECT_EQ(space.getId(VarKind::Range, 1), Identifier(&identifiers[3]));
EXPECT_EQ(space.getId(VarKind::Range, 2), Identifier(&identifiers[4]));
EXPECT_EQ(3u, otherSpace.getNumDomainVars());
EXPECT_EQ(2u, otherSpace.getNumRangeVars());
EXPECT_EQ(otherSpace.getId(VarKind::Domain, 0),
Identifier(&otherIdentifiers[0]));
EXPECT_EQ(otherSpace.getId(VarKind::Domain, 1),
Identifier(&otherIdentifiers[1]));
EXPECT_EQ(otherSpace.getId(VarKind::Domain, 2),
Identifier(&otherIdentifiers[2]));
EXPECT_EQ(otherSpace.getId(VarKind::Range, 0),
Identifier(&otherIdentifiers[3]));
EXPECT_EQ(otherSpace.getId(VarKind::Range, 1),
Identifier(&otherIdentifiers[4]));
}

// Check that mergeAndAlignSymbols unions symbol variables when they are
// disjoint.
TEST(IntegerRelationTest, mergeAndAlignDisjointSymbols) {
IntegerRelation rel = parseRelationFromSet(
"(x, y, z)[A, B, C, D] : (x + A - C - y + D - z >= 0)", 2);
IntegerRelation otherRel = parseRelationFromSet(
"(u, v, a, b)[E, F, G, H] : (E - u + v == 0, v - G - H >= 0)", 2);
PresburgerSpace space = PresburgerSpace::getRelationSpace(2, 1, 4, 0);
space.resetIds();

PresburgerSpace otherSpace = PresburgerSpace::getRelationSpace(2, 2, 4, 0);
otherSpace.resetIds();

// Attach identifiers.
int identifiers[7] = {'x', 'y', 'z', 'A', 'B', 'C', 'D'};
int otherIdentifiers[8] = {'u', 'v', 'a', 'b', 'E', 'F', 'G', 'H'};

space.getId(VarKind::Domain, 0) = Identifier(&identifiers[0]);
space.getId(VarKind::Domain, 1) = Identifier(&identifiers[1]);
space.getId(VarKind::Range, 0) = Identifier(&identifiers[2]);
space.getId(VarKind::Symbol, 0) = Identifier(&identifiers[3]);
space.getId(VarKind::Symbol, 1) = Identifier(&identifiers[4]);
space.getId(VarKind::Symbol, 2) = Identifier(&identifiers[5]);
space.getId(VarKind::Symbol, 3) = Identifier(&identifiers[6]);

otherSpace.getId(VarKind::Domain, 0) = Identifier(&otherIdentifiers[0]);
otherSpace.getId(VarKind::Domain, 1) = Identifier(&otherIdentifiers[1]);
otherSpace.getId(VarKind::Range, 0) = Identifier(&otherIdentifiers[2]);
otherSpace.getId(VarKind::Range, 1) = Identifier(&otherIdentifiers[3]);
otherSpace.getId(VarKind::Symbol, 0) = Identifier(&otherIdentifiers[4]);
otherSpace.getId(VarKind::Symbol, 1) = Identifier(&otherIdentifiers[5]);
otherSpace.getId(VarKind::Symbol, 2) = Identifier(&otherIdentifiers[6]);
otherSpace.getId(VarKind::Symbol, 3) = Identifier(&otherIdentifiers[7]);

rel.setSpace(space);
otherRel.setSpace(otherSpace);
rel.mergeAndAlignSymbols(otherRel);

space = rel.getSpace();
otherSpace = otherRel.getSpace();

// Check if merge and align is successful.
// Check symbol var identifiers.
EXPECT_EQ(8u, space.getNumSymbolVars());
EXPECT_EQ(8u, otherSpace.getNumSymbolVars());
EXPECT_EQ(space.getId(VarKind::Symbol, 0), Identifier(&identifiers[3]));
EXPECT_EQ(space.getId(VarKind::Symbol, 1), Identifier(&identifiers[4]));
EXPECT_EQ(space.getId(VarKind::Symbol, 2), Identifier(&identifiers[5]));
EXPECT_EQ(space.getId(VarKind::Symbol, 3), Identifier(&identifiers[6]));
EXPECT_EQ(space.getId(VarKind::Symbol, 4), Identifier(&otherIdentifiers[4]));
EXPECT_EQ(space.getId(VarKind::Symbol, 5), Identifier(&otherIdentifiers[5]));
EXPECT_EQ(space.getId(VarKind::Symbol, 6), Identifier(&otherIdentifiers[6]));
EXPECT_EQ(space.getId(VarKind::Symbol, 7), Identifier(&otherIdentifiers[7]));
EXPECT_EQ(otherSpace.getId(VarKind::Symbol, 0), Identifier(&identifiers[3]));
EXPECT_EQ(otherSpace.getId(VarKind::Symbol, 1), Identifier(&identifiers[4]));
EXPECT_EQ(otherSpace.getId(VarKind::Symbol, 2), Identifier(&identifiers[5]));
EXPECT_EQ(otherSpace.getId(VarKind::Symbol, 3), Identifier(&identifiers[6]));
EXPECT_EQ(otherSpace.getId(VarKind::Symbol, 4),
Identifier(&otherIdentifiers[4]));
EXPECT_EQ(otherSpace.getId(VarKind::Symbol, 5),
Identifier(&otherIdentifiers[5]));
EXPECT_EQ(otherSpace.getId(VarKind::Symbol, 6),
Identifier(&otherIdentifiers[6]));
EXPECT_EQ(otherSpace.getId(VarKind::Symbol, 7),
Identifier(&otherIdentifiers[7]));
// Check that domain and range var identifiers are not affected.
EXPECT_EQ(2u, space.getNumDomainVars());
EXPECT_EQ(1u, space.getNumRangeVars());
EXPECT_EQ(space.getId(VarKind::Domain, 0), Identifier(&identifiers[0]));
EXPECT_EQ(space.getId(VarKind::Domain, 1), Identifier(&identifiers[1]));
EXPECT_EQ(space.getId(VarKind::Range, 0), Identifier(&identifiers[2]));
EXPECT_EQ(2u, otherSpace.getNumDomainVars());
EXPECT_EQ(2u, otherSpace.getNumRangeVars());
EXPECT_EQ(otherSpace.getId(VarKind::Domain, 0),
Identifier(&otherIdentifiers[0]));
EXPECT_EQ(otherSpace.getId(VarKind::Domain, 1),
Identifier(&otherIdentifiers[1]));
EXPECT_EQ(otherSpace.getId(VarKind::Range, 0),
Identifier(&otherIdentifiers[2]));
EXPECT_EQ(otherSpace.getId(VarKind::Range, 1),
Identifier(&otherIdentifiers[3]));
}

// Check that mergeAndAlignSymbols is correct when a suffix of identifiers is
// shared; i.e. identifiers are [A, B, C, D] and [E, F, C, D].
TEST(IntegerRelationTest, mergeAndAlignCommonSuffixSymbols) {
IntegerRelation rel = parseRelationFromSet(
"(x, y, z)[A, B, C, D] : (x + A - C - y + D - z >= 0)", 2);
IntegerRelation otherRel = parseRelationFromSet(
"(u, v, a, b)[E, F, C, D] : (E - u + v == 0, v - C - D >= 0)", 2);
PresburgerSpace space = PresburgerSpace::getRelationSpace(2, 1, 4, 0);
space.resetIds();

PresburgerSpace otherSpace = PresburgerSpace::getRelationSpace(2, 2, 4, 0);
otherSpace.resetIds();

// Attach identifiers.
int identifiers[7] = {'x', 'y', 'z', 'A', 'B', 'C', 'D'};
int otherIdentifiers[6] = {'u', 'v', 'a', 'b', 'E', 'F'};

space.getId(VarKind::Domain, 0) = Identifier(&identifiers[0]);
space.getId(VarKind::Domain, 1) = Identifier(&identifiers[1]);
space.getId(VarKind::Range, 0) = Identifier(&identifiers[2]);
space.getId(VarKind::Symbol, 0) = Identifier(&identifiers[3]);
space.getId(VarKind::Symbol, 1) = Identifier(&identifiers[4]);
space.getId(VarKind::Symbol, 2) = Identifier(&identifiers[5]);
space.getId(VarKind::Symbol, 3) = Identifier(&identifiers[6]);

otherSpace.getId(VarKind::Domain, 0) = Identifier(&otherIdentifiers[0]);
otherSpace.getId(VarKind::Domain, 1) = Identifier(&otherIdentifiers[1]);
otherSpace.getId(VarKind::Range, 0) = Identifier(&otherIdentifiers[2]);
otherSpace.getId(VarKind::Range, 1) = Identifier(&otherIdentifiers[3]);
otherSpace.getId(VarKind::Symbol, 0) = Identifier(&otherIdentifiers[4]);
otherSpace.getId(VarKind::Symbol, 1) = Identifier(&otherIdentifiers[5]);
// Note common identifiers
otherSpace.getId(VarKind::Symbol, 2) = Identifier(&identifiers[5]);
otherSpace.getId(VarKind::Symbol, 3) = Identifier(&identifiers[6]);

rel.setSpace(space);
otherRel.setSpace(otherSpace);
rel.mergeAndAlignSymbols(otherRel);

space = rel.getSpace();
otherSpace = otherRel.getSpace();

// Check if merge and align is successful.
// Check symbol var identifiers.
EXPECT_EQ(6u, space.getNumSymbolVars());
EXPECT_EQ(6u, otherSpace.getNumSymbolVars());
EXPECT_EQ(space.getId(VarKind::Symbol, 0), Identifier(&identifiers[3]));
EXPECT_EQ(space.getId(VarKind::Symbol, 1), Identifier(&identifiers[4]));
EXPECT_EQ(space.getId(VarKind::Symbol, 2), Identifier(&identifiers[5]));
EXPECT_EQ(space.getId(VarKind::Symbol, 3), Identifier(&identifiers[6]));
EXPECT_EQ(space.getId(VarKind::Symbol, 4), Identifier(&otherIdentifiers[4]));
EXPECT_EQ(space.getId(VarKind::Symbol, 5), Identifier(&otherIdentifiers[5]));
EXPECT_EQ(otherSpace.getId(VarKind::Symbol, 0), Identifier(&identifiers[3]));
EXPECT_EQ(otherSpace.getId(VarKind::Symbol, 1), Identifier(&identifiers[4]));
EXPECT_EQ(otherSpace.getId(VarKind::Symbol, 2), Identifier(&identifiers[5]));
EXPECT_EQ(otherSpace.getId(VarKind::Symbol, 3), Identifier(&identifiers[6]));
EXPECT_EQ(otherSpace.getId(VarKind::Symbol, 4),
Identifier(&otherIdentifiers[4]));
EXPECT_EQ(otherSpace.getId(VarKind::Symbol, 5),
Identifier(&otherIdentifiers[5]));
// Check that domain and range var identifiers are not affected.
EXPECT_EQ(2u, space.getNumDomainVars());
EXPECT_EQ(1u, space.getNumRangeVars());
EXPECT_EQ(space.getId(VarKind::Domain, 0), Identifier(&identifiers[0]));
EXPECT_EQ(space.getId(VarKind::Domain, 1), Identifier(&identifiers[1]));
EXPECT_EQ(space.getId(VarKind::Range, 0), Identifier(&identifiers[2]));
EXPECT_EQ(2u, otherSpace.getNumDomainVars());
EXPECT_EQ(2u, otherSpace.getNumRangeVars());
EXPECT_EQ(otherSpace.getId(VarKind::Domain, 0),
Identifier(&otherIdentifiers[0]));
EXPECT_EQ(otherSpace.getId(VarKind::Domain, 1),
Identifier(&otherIdentifiers[1]));
EXPECT_EQ(otherSpace.getId(VarKind::Range, 0),
Identifier(&otherIdentifiers[2]));
EXPECT_EQ(otherSpace.getId(VarKind::Range, 1),
Identifier(&otherIdentifiers[3]));
}