Skip to content

Conversation

@iambrj
Copy link
Member

@iambrj iambrj commented Jan 2, 2024

No description provided.

@llvmbot
Copy link
Member

llvmbot commented Jan 2, 2024

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-presburger

Author: Bharathi Ramana Joshi (iambrj)

Changes

Full diff: https://github.com/llvm/llvm-project/pull/76736.diff

3 Files Affected:

  • (modified) mlir/include/mlir/Analysis/Presburger/IntegerRelation.h (+5)
  • (modified) mlir/lib/Analysis/Presburger/IntegerRelation.cpp (+31)
  • (modified) mlir/unittests/Analysis/Presburger/IntegerRelationTest.cpp (+83)
diff --git a/mlir/include/mlir/Analysis/Presburger/IntegerRelation.h b/mlir/include/mlir/Analysis/Presburger/IntegerRelation.h
index 4c6b810f92e95a..cd957280eb740d 100644
--- a/mlir/include/mlir/Analysis/Presburger/IntegerRelation.h
+++ b/mlir/include/mlir/Analysis/Presburger/IntegerRelation.h
@@ -577,6 +577,11 @@ class IntegerRelation {
     convertVarKind(kind, varStart, varLimit, VarKind::Local);
   }
 
+  /// Merge and align symbol variables of `this` and `other` with respect to
+  /// identifiers. After this operation the symbol variables of both relations
+  /// have the same identifiers in the same order.
+  void mergeAndAlignSymbols(IntegerRelation &other);
+
   /// Adds additional local vars to the sets such that they both have the union
   /// of the local vars in each set, without changing the set of points that
   /// lie in `this` and `other`.
diff --git a/mlir/lib/Analysis/Presburger/IntegerRelation.cpp b/mlir/lib/Analysis/Presburger/IntegerRelation.cpp
index 0109384f1689dd..af16321e69a4cc 100644
--- a/mlir/lib/Analysis/Presburger/IntegerRelation.cpp
+++ b/mlir/lib/Analysis/Presburger/IntegerRelation.cpp
@@ -1288,6 +1288,37 @@ void IntegerRelation::eliminateRedundantLocalVar(unsigned posA, unsigned posB) {
   removeVar(posB);
 }
 
+void IntegerRelation::mergeAndAlignSymbols(IntegerRelation &other) {
+  assert(space.isUsingIds() && other.space.isUsingIds() &&
+         "Both relations need to have identifers to merge & align");
+
+  // First merge & align identifiers into `other` from `this`.
+  unsigned i = 0;
+  for (const Identifier identifier : space.getIds(VarKind::Symbol)) {
+    // If the identifier exists in `other`, then align it; otherwise insert it
+    // assuming it is a new identifier. Search in `other` starting at position
+    // `i` since the left of `i` is aligned.
+    auto *findBegin = other.space.getIds(VarKind::Symbol).begin() + i;
+    auto *findEnd = other.space.getIds(VarKind::Symbol).end();
+    auto *itr = std::find(findBegin, findEnd, identifier);
+    if (itr != findEnd) {
+      other.swapVar(other.getVarKindOffset(VarKind::Symbol) + i,
+                    other.getVarKindOffset(VarKind::Symbol) + i +
+                        std::distance(findBegin, itr));
+    } else {
+      other.insertVar(VarKind::Symbol, i);
+      other.space.getId(VarKind::Symbol, i) = identifier;
+    }
+    ++i;
+  }
+
+  // Finally add identifiers that are in `other`, but not in `this` to `this`.
+  for (unsigned e = other.getNumVarKind(VarKind::Symbol); i < e; ++i) {
+    insertVar(VarKind::Symbol, i);
+    space.getId(VarKind::Symbol, i) = other.space.getId(VarKind::Symbol, i);
+  }
+}
+
 /// Adds additional local ids to the sets such that they both have the union
 /// of the local ids in each set, without changing the set of points that
 /// lie in `this` and `other`.
diff --git a/mlir/unittests/Analysis/Presburger/IntegerRelationTest.cpp b/mlir/unittests/Analysis/Presburger/IntegerRelationTest.cpp
index f390296da648d2..998218416d7e57 100644
--- a/mlir/unittests/Analysis/Presburger/IntegerRelationTest.cpp
+++ b/mlir/unittests/Analysis/Presburger/IntegerRelationTest.cpp
@@ -207,3 +207,86 @@ TEST(IntegerRelationTest, swapVar) {
   EXPECT_TRUE(swappedSpace.getId(VarKind::Symbol, 1)
                   .isEqual(space.getId(VarKind::Domain, 1)));
 }
+
+TEST(IntegerRelationTest, mergeAndAlignSymbols) {
+  IntegerRelation rel =
+      parseRelationFromSet("(x, y, z, a, b, c)[N, Q] : (a - x - y == 0, "
+                           "x >= 0, N - b >= 0, y >= 0, Q - y >= 0)",
+                           3);
+  IntegerRelation otherRel = parseRelationFromSet(
+      "(x, y, z, a, b)[N, M, P] : (z - x - y == 0, x >= 0, N - x "
+      ">= 0, y >= 0, M - y >= 0, 2 * P - 3 * a + 2 * b == 0)",
+      3);
+  PresburgerSpace space = PresburgerSpace::getRelationSpace(3, 3, 2, 0);
+  space.resetIds();
+
+  PresburgerSpace otherSpace = PresburgerSpace::getRelationSpace(3, 2, 3, 0);
+  otherSpace.resetIds();
+
+  // Attach identifiers.
+  int identifiers[7] = {0, 1, 2, 3, 4, 5, 6};
+  int otherIdentifiers[8] = {10, 11, 12, 13, 14, 15, 16, 17};
+
+  space.getId(VarKind::Domain, 0) = Identifier(&identifiers[0]);
+  space.getId(VarKind::Domain, 1) = Identifier(&identifiers[1]);
+  // Note the common identifier.
+  space.getId(VarKind::Domain, 2) = Identifier(&otherIdentifiers[2]);
+  space.getId(VarKind::Range, 0) = Identifier(&identifiers[2]);
+  space.getId(VarKind::Range, 1) = Identifier(&identifiers[3]);
+  space.getId(VarKind::Range, 2) = Identifier(&identifiers[4]);
+  space.getId(VarKind::Symbol, 0) = Identifier(&identifiers[5]);
+  space.getId(VarKind::Symbol, 1) = Identifier(&identifiers[6]);
+
+  otherSpace.getId(VarKind::Domain, 0) = Identifier(&otherIdentifiers[0]);
+  otherSpace.getId(VarKind::Domain, 1) = Identifier(&otherIdentifiers[1]);
+  otherSpace.getId(VarKind::Domain, 2) = Identifier(&otherIdentifiers[2]);
+  otherSpace.getId(VarKind::Range, 0) = Identifier(&otherIdentifiers[3]);
+  otherSpace.getId(VarKind::Range, 1) = Identifier(&otherIdentifiers[4]);
+  // Note the common identifier.
+  otherSpace.getId(VarKind::Symbol, 0) = Identifier(&identifiers[6]);
+  otherSpace.getId(VarKind::Symbol, 1) = Identifier(&otherIdentifiers[5]);
+  otherSpace.getId(VarKind::Symbol, 2) = Identifier(&otherIdentifiers[7]);
+
+  rel.setSpace(space);
+  otherRel.setSpace(otherSpace);
+  rel.mergeAndAlignSymbols(otherRel);
+
+  space = rel.getSpace();
+  otherSpace = otherRel.getSpace();
+
+  // Check if merge & align is successful.
+  // Check symbol var identifiers.
+  EXPECT_EQ(4u, space.getNumSymbolVars());
+  EXPECT_EQ(4u, otherSpace.getNumSymbolVars());
+  EXPECT_EQ(space.getId(VarKind::Symbol, 0), Identifier(&identifiers[5]));
+  EXPECT_EQ(space.getId(VarKind::Symbol, 1), Identifier(&identifiers[6]));
+  EXPECT_EQ(space.getId(VarKind::Symbol, 2), Identifier(&otherIdentifiers[5]));
+  EXPECT_EQ(space.getId(VarKind::Symbol, 3), Identifier(&otherIdentifiers[7]));
+  EXPECT_EQ(otherSpace.getId(VarKind::Symbol, 0), Identifier(&identifiers[5]));
+  EXPECT_EQ(otherSpace.getId(VarKind::Symbol, 1), Identifier(&identifiers[6]));
+  EXPECT_EQ(otherSpace.getId(VarKind::Symbol, 2),
+            Identifier(&otherIdentifiers[5]));
+  EXPECT_EQ(otherSpace.getId(VarKind::Symbol, 3),
+            Identifier(&otherIdentifiers[7]));
+  // Check that domain and range var identifiers are not affected.
+  EXPECT_EQ(3u, space.getNumDomainVars());
+  EXPECT_EQ(3u, space.getNumRangeVars());
+  EXPECT_EQ(space.getId(VarKind::Domain, 0), Identifier(&identifiers[0]));
+  EXPECT_EQ(space.getId(VarKind::Domain, 1), Identifier(&identifiers[1]));
+  EXPECT_EQ(space.getId(VarKind::Domain, 2), Identifier(&otherIdentifiers[2]));
+  EXPECT_EQ(space.getId(VarKind::Range, 0), Identifier(&identifiers[2]));
+  EXPECT_EQ(space.getId(VarKind::Range, 1), Identifier(&identifiers[3]));
+  EXPECT_EQ(space.getId(VarKind::Range, 2), Identifier(&identifiers[4]));
+  EXPECT_EQ(3u, otherSpace.getNumDomainVars());
+  EXPECT_EQ(2u, otherSpace.getNumRangeVars());
+  EXPECT_EQ(otherSpace.getId(VarKind::Domain, 0),
+            Identifier(&otherIdentifiers[0]));
+  EXPECT_EQ(otherSpace.getId(VarKind::Domain, 1),
+            Identifier(&otherIdentifiers[1]));
+  EXPECT_EQ(otherSpace.getId(VarKind::Domain, 2),
+            Identifier(&otherIdentifiers[2]));
+  EXPECT_EQ(otherSpace.getId(VarKind::Range, 0),
+            Identifier(&otherIdentifiers[3]));
+  EXPECT_EQ(otherSpace.getId(VarKind::Range, 1),
+            Identifier(&otherIdentifiers[4]));
+}

@iambrj iambrj force-pushed the IRmergeSymbols branch 2 times, most recently from f5c792f to 8d338e2 Compare January 3, 2024 15:56
@iambrj iambrj requested a review from Groverkss January 3, 2024 15:56
@iambrj iambrj requested a review from Groverkss January 3, 2024 16:52
@Superty Superty requested review from Superty and removed request for Superty January 3, 2024 16:53
Copy link
Member

@Groverkss Groverkss left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just some documentation requests. Once that is done, we can land this.

@iambrj iambrj force-pushed the IRmergeSymbols branch 2 times, most recently from 2d9219f to eb5c837 Compare January 5, 2024 15:19
Copy link
Member

@Groverkss Groverkss left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@iambrj iambrj merged commit 3eb9fd8 into llvm:main Jan 7, 2024
justinfargnoli pushed a commit to justinfargnoli/llvm-project that referenced this pull request Jan 28, 2024
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants