Skip to content

Commit 3eb9fd8

Browse files
authored
[MLIR][Presburger] Implement IntegerRelation::mergeAndAlignSymbols (llvm#76736)
1 parent 2835be8 commit 3eb9fd8

File tree

3 files changed

+283
-0
lines changed

3 files changed

+283
-0
lines changed

mlir/include/mlir/Analysis/Presburger/IntegerRelation.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -577,6 +577,11 @@ class IntegerRelation {
577577
convertVarKind(kind, varStart, varLimit, VarKind::Local);
578578
}
579579

580+
/// Merge and align symbol variables of `this` and `other` with respect to
581+
/// identifiers. After this operation the symbol variables of both relations
582+
/// have the same identifiers in the same order.
583+
void mergeAndAlignSymbols(IntegerRelation &other);
584+
580585
/// Adds additional local vars to the sets such that they both have the union
581586
/// of the local vars in each set, without changing the set of points that
582587
/// lie in `this` and `other`.

mlir/lib/Analysis/Presburger/IntegerRelation.cpp

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1288,6 +1288,40 @@ void IntegerRelation::eliminateRedundantLocalVar(unsigned posA, unsigned posB) {
12881288
removeVar(posB);
12891289
}
12901290

1291+
/// mergeAndAlignSymbols's implementation can be broken down into two steps:
1292+
/// 1. Merge and align identifiers into `other` from `this. If an identifier
1293+
/// from `this` exists in `other` then we align it. Otherwise, we assume it is a
1294+
/// new identifier and insert it into `other` in the same position as `this`.
1295+
/// 2. Add identifiers that are in `other` but not `this to `this`.
1296+
void IntegerRelation::mergeAndAlignSymbols(IntegerRelation &other) {
1297+
assert(space.isUsingIds() && other.space.isUsingIds() &&
1298+
"both relations need to have identifers to merge and align");
1299+
1300+
unsigned i = 0;
1301+
for (const Identifier identifier : space.getIds(VarKind::Symbol)) {
1302+
// Search in `other` starting at position `i` since the left of `i` is
1303+
// aligned.
1304+
const Identifier *findBegin =
1305+
other.space.getIds(VarKind::Symbol).begin() + i;
1306+
const Identifier *findEnd = other.space.getIds(VarKind::Symbol).end();
1307+
const Identifier *itr = std::find(findBegin, findEnd, identifier);
1308+
if (itr != findEnd) {
1309+
other.swapVar(other.getVarKindOffset(VarKind::Symbol) + i,
1310+
other.getVarKindOffset(VarKind::Symbol) + i +
1311+
std::distance(findBegin, itr));
1312+
} else {
1313+
other.insertVar(VarKind::Symbol, i);
1314+
other.space.getId(VarKind::Symbol, i) = identifier;
1315+
}
1316+
++i;
1317+
}
1318+
1319+
for (unsigned e = other.getNumVarKind(VarKind::Symbol); i < e; ++i) {
1320+
insertVar(VarKind::Symbol, i);
1321+
space.getId(VarKind::Symbol, i) = other.space.getId(VarKind::Symbol, i);
1322+
}
1323+
}
1324+
12911325
/// Adds additional local ids to the sets such that they both have the union
12921326
/// of the local ids in each set, without changing the set of points that
12931327
/// lie in `this` and `other`.

mlir/unittests/Analysis/Presburger/IntegerRelationTest.cpp

Lines changed: 244 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -207,3 +207,247 @@ TEST(IntegerRelationTest, swapVar) {
207207
EXPECT_TRUE(swappedSpace.getId(VarKind::Symbol, 1)
208208
.isEqual(space.getId(VarKind::Domain, 1)));
209209
}
210+
211+
TEST(IntegerRelationTest, mergeAndAlignSymbols) {
212+
IntegerRelation rel =
213+
parseRelationFromSet("(x, y, z, a, b, c)[N, Q] : (a - x - y == 0, "
214+
"x >= 0, N - b >= 0, y >= 0, Q - y >= 0)",
215+
3);
216+
IntegerRelation otherRel = parseRelationFromSet(
217+
"(x, y, z, a, b)[N, M, P] : (z - x - y == 0, x >= 0, N - x "
218+
">= 0, y >= 0, M - y >= 0, 2 * P - 3 * a + 2 * b == 0)",
219+
3);
220+
PresburgerSpace space = PresburgerSpace::getRelationSpace(3, 3, 2, 0);
221+
space.resetIds();
222+
223+
PresburgerSpace otherSpace = PresburgerSpace::getRelationSpace(3, 2, 3, 0);
224+
otherSpace.resetIds();
225+
226+
// Attach identifiers.
227+
int identifiers[7] = {0, 1, 2, 3, 4, 5, 6};
228+
int otherIdentifiers[8] = {10, 11, 12, 13, 14, 15, 16, 17};
229+
230+
space.getId(VarKind::Domain, 0) = Identifier(&identifiers[0]);
231+
space.getId(VarKind::Domain, 1) = Identifier(&identifiers[1]);
232+
// Note the common identifier.
233+
space.getId(VarKind::Domain, 2) = Identifier(&otherIdentifiers[2]);
234+
space.getId(VarKind::Range, 0) = Identifier(&identifiers[2]);
235+
space.getId(VarKind::Range, 1) = Identifier(&identifiers[3]);
236+
space.getId(VarKind::Range, 2) = Identifier(&identifiers[4]);
237+
space.getId(VarKind::Symbol, 0) = Identifier(&identifiers[5]);
238+
space.getId(VarKind::Symbol, 1) = Identifier(&identifiers[6]);
239+
240+
otherSpace.getId(VarKind::Domain, 0) = Identifier(&otherIdentifiers[0]);
241+
otherSpace.getId(VarKind::Domain, 1) = Identifier(&otherIdentifiers[1]);
242+
otherSpace.getId(VarKind::Domain, 2) = Identifier(&otherIdentifiers[2]);
243+
otherSpace.getId(VarKind::Range, 0) = Identifier(&otherIdentifiers[3]);
244+
otherSpace.getId(VarKind::Range, 1) = Identifier(&otherIdentifiers[4]);
245+
// Note the common identifier.
246+
otherSpace.getId(VarKind::Symbol, 0) = Identifier(&identifiers[6]);
247+
otherSpace.getId(VarKind::Symbol, 1) = Identifier(&otherIdentifiers[5]);
248+
otherSpace.getId(VarKind::Symbol, 2) = Identifier(&otherIdentifiers[7]);
249+
250+
rel.setSpace(space);
251+
otherRel.setSpace(otherSpace);
252+
rel.mergeAndAlignSymbols(otherRel);
253+
254+
space = rel.getSpace();
255+
otherSpace = otherRel.getSpace();
256+
257+
// Check if merge and align is successful.
258+
// Check symbol var identifiers.
259+
EXPECT_EQ(4u, space.getNumSymbolVars());
260+
EXPECT_EQ(4u, otherSpace.getNumSymbolVars());
261+
EXPECT_EQ(space.getId(VarKind::Symbol, 0), Identifier(&identifiers[5]));
262+
EXPECT_EQ(space.getId(VarKind::Symbol, 1), Identifier(&identifiers[6]));
263+
EXPECT_EQ(space.getId(VarKind::Symbol, 2), Identifier(&otherIdentifiers[5]));
264+
EXPECT_EQ(space.getId(VarKind::Symbol, 3), Identifier(&otherIdentifiers[7]));
265+
EXPECT_EQ(otherSpace.getId(VarKind::Symbol, 0), Identifier(&identifiers[5]));
266+
EXPECT_EQ(otherSpace.getId(VarKind::Symbol, 1), Identifier(&identifiers[6]));
267+
EXPECT_EQ(otherSpace.getId(VarKind::Symbol, 2),
268+
Identifier(&otherIdentifiers[5]));
269+
EXPECT_EQ(otherSpace.getId(VarKind::Symbol, 3),
270+
Identifier(&otherIdentifiers[7]));
271+
// Check that domain and range var identifiers are not affected.
272+
EXPECT_EQ(3u, space.getNumDomainVars());
273+
EXPECT_EQ(3u, space.getNumRangeVars());
274+
EXPECT_EQ(space.getId(VarKind::Domain, 0), Identifier(&identifiers[0]));
275+
EXPECT_EQ(space.getId(VarKind::Domain, 1), Identifier(&identifiers[1]));
276+
EXPECT_EQ(space.getId(VarKind::Domain, 2), Identifier(&otherIdentifiers[2]));
277+
EXPECT_EQ(space.getId(VarKind::Range, 0), Identifier(&identifiers[2]));
278+
EXPECT_EQ(space.getId(VarKind::Range, 1), Identifier(&identifiers[3]));
279+
EXPECT_EQ(space.getId(VarKind::Range, 2), Identifier(&identifiers[4]));
280+
EXPECT_EQ(3u, otherSpace.getNumDomainVars());
281+
EXPECT_EQ(2u, otherSpace.getNumRangeVars());
282+
EXPECT_EQ(otherSpace.getId(VarKind::Domain, 0),
283+
Identifier(&otherIdentifiers[0]));
284+
EXPECT_EQ(otherSpace.getId(VarKind::Domain, 1),
285+
Identifier(&otherIdentifiers[1]));
286+
EXPECT_EQ(otherSpace.getId(VarKind::Domain, 2),
287+
Identifier(&otherIdentifiers[2]));
288+
EXPECT_EQ(otherSpace.getId(VarKind::Range, 0),
289+
Identifier(&otherIdentifiers[3]));
290+
EXPECT_EQ(otherSpace.getId(VarKind::Range, 1),
291+
Identifier(&otherIdentifiers[4]));
292+
}
293+
294+
// Check that mergeAndAlignSymbols unions symbol variables when they are
295+
// disjoint.
296+
TEST(IntegerRelationTest, mergeAndAlignDisjointSymbols) {
297+
IntegerRelation rel = parseRelationFromSet(
298+
"(x, y, z)[A, B, C, D] : (x + A - C - y + D - z >= 0)", 2);
299+
IntegerRelation otherRel = parseRelationFromSet(
300+
"(u, v, a, b)[E, F, G, H] : (E - u + v == 0, v - G - H >= 0)", 2);
301+
PresburgerSpace space = PresburgerSpace::getRelationSpace(2, 1, 4, 0);
302+
space.resetIds();
303+
304+
PresburgerSpace otherSpace = PresburgerSpace::getRelationSpace(2, 2, 4, 0);
305+
otherSpace.resetIds();
306+
307+
// Attach identifiers.
308+
int identifiers[7] = {'x', 'y', 'z', 'A', 'B', 'C', 'D'};
309+
int otherIdentifiers[8] = {'u', 'v', 'a', 'b', 'E', 'F', 'G', 'H'};
310+
311+
space.getId(VarKind::Domain, 0) = Identifier(&identifiers[0]);
312+
space.getId(VarKind::Domain, 1) = Identifier(&identifiers[1]);
313+
space.getId(VarKind::Range, 0) = Identifier(&identifiers[2]);
314+
space.getId(VarKind::Symbol, 0) = Identifier(&identifiers[3]);
315+
space.getId(VarKind::Symbol, 1) = Identifier(&identifiers[4]);
316+
space.getId(VarKind::Symbol, 2) = Identifier(&identifiers[5]);
317+
space.getId(VarKind::Symbol, 3) = Identifier(&identifiers[6]);
318+
319+
otherSpace.getId(VarKind::Domain, 0) = Identifier(&otherIdentifiers[0]);
320+
otherSpace.getId(VarKind::Domain, 1) = Identifier(&otherIdentifiers[1]);
321+
otherSpace.getId(VarKind::Range, 0) = Identifier(&otherIdentifiers[2]);
322+
otherSpace.getId(VarKind::Range, 1) = Identifier(&otherIdentifiers[3]);
323+
otherSpace.getId(VarKind::Symbol, 0) = Identifier(&otherIdentifiers[4]);
324+
otherSpace.getId(VarKind::Symbol, 1) = Identifier(&otherIdentifiers[5]);
325+
otherSpace.getId(VarKind::Symbol, 2) = Identifier(&otherIdentifiers[6]);
326+
otherSpace.getId(VarKind::Symbol, 3) = Identifier(&otherIdentifiers[7]);
327+
328+
rel.setSpace(space);
329+
otherRel.setSpace(otherSpace);
330+
rel.mergeAndAlignSymbols(otherRel);
331+
332+
space = rel.getSpace();
333+
otherSpace = otherRel.getSpace();
334+
335+
// Check if merge and align is successful.
336+
// Check symbol var identifiers.
337+
EXPECT_EQ(8u, space.getNumSymbolVars());
338+
EXPECT_EQ(8u, otherSpace.getNumSymbolVars());
339+
EXPECT_EQ(space.getId(VarKind::Symbol, 0), Identifier(&identifiers[3]));
340+
EXPECT_EQ(space.getId(VarKind::Symbol, 1), Identifier(&identifiers[4]));
341+
EXPECT_EQ(space.getId(VarKind::Symbol, 2), Identifier(&identifiers[5]));
342+
EXPECT_EQ(space.getId(VarKind::Symbol, 3), Identifier(&identifiers[6]));
343+
EXPECT_EQ(space.getId(VarKind::Symbol, 4), Identifier(&otherIdentifiers[4]));
344+
EXPECT_EQ(space.getId(VarKind::Symbol, 5), Identifier(&otherIdentifiers[5]));
345+
EXPECT_EQ(space.getId(VarKind::Symbol, 6), Identifier(&otherIdentifiers[6]));
346+
EXPECT_EQ(space.getId(VarKind::Symbol, 7), Identifier(&otherIdentifiers[7]));
347+
EXPECT_EQ(otherSpace.getId(VarKind::Symbol, 0), Identifier(&identifiers[3]));
348+
EXPECT_EQ(otherSpace.getId(VarKind::Symbol, 1), Identifier(&identifiers[4]));
349+
EXPECT_EQ(otherSpace.getId(VarKind::Symbol, 2), Identifier(&identifiers[5]));
350+
EXPECT_EQ(otherSpace.getId(VarKind::Symbol, 3), Identifier(&identifiers[6]));
351+
EXPECT_EQ(otherSpace.getId(VarKind::Symbol, 4),
352+
Identifier(&otherIdentifiers[4]));
353+
EXPECT_EQ(otherSpace.getId(VarKind::Symbol, 5),
354+
Identifier(&otherIdentifiers[5]));
355+
EXPECT_EQ(otherSpace.getId(VarKind::Symbol, 6),
356+
Identifier(&otherIdentifiers[6]));
357+
EXPECT_EQ(otherSpace.getId(VarKind::Symbol, 7),
358+
Identifier(&otherIdentifiers[7]));
359+
// Check that domain and range var identifiers are not affected.
360+
EXPECT_EQ(2u, space.getNumDomainVars());
361+
EXPECT_EQ(1u, space.getNumRangeVars());
362+
EXPECT_EQ(space.getId(VarKind::Domain, 0), Identifier(&identifiers[0]));
363+
EXPECT_EQ(space.getId(VarKind::Domain, 1), Identifier(&identifiers[1]));
364+
EXPECT_EQ(space.getId(VarKind::Range, 0), Identifier(&identifiers[2]));
365+
EXPECT_EQ(2u, otherSpace.getNumDomainVars());
366+
EXPECT_EQ(2u, otherSpace.getNumRangeVars());
367+
EXPECT_EQ(otherSpace.getId(VarKind::Domain, 0),
368+
Identifier(&otherIdentifiers[0]));
369+
EXPECT_EQ(otherSpace.getId(VarKind::Domain, 1),
370+
Identifier(&otherIdentifiers[1]));
371+
EXPECT_EQ(otherSpace.getId(VarKind::Range, 0),
372+
Identifier(&otherIdentifiers[2]));
373+
EXPECT_EQ(otherSpace.getId(VarKind::Range, 1),
374+
Identifier(&otherIdentifiers[3]));
375+
}
376+
377+
// Check that mergeAndAlignSymbols is correct when a suffix of identifiers is
378+
// shared; i.e. identifiers are [A, B, C, D] and [E, F, C, D].
379+
TEST(IntegerRelationTest, mergeAndAlignCommonSuffixSymbols) {
380+
IntegerRelation rel = parseRelationFromSet(
381+
"(x, y, z)[A, B, C, D] : (x + A - C - y + D - z >= 0)", 2);
382+
IntegerRelation otherRel = parseRelationFromSet(
383+
"(u, v, a, b)[E, F, C, D] : (E - u + v == 0, v - C - D >= 0)", 2);
384+
PresburgerSpace space = PresburgerSpace::getRelationSpace(2, 1, 4, 0);
385+
space.resetIds();
386+
387+
PresburgerSpace otherSpace = PresburgerSpace::getRelationSpace(2, 2, 4, 0);
388+
otherSpace.resetIds();
389+
390+
// Attach identifiers.
391+
int identifiers[7] = {'x', 'y', 'z', 'A', 'B', 'C', 'D'};
392+
int otherIdentifiers[6] = {'u', 'v', 'a', 'b', 'E', 'F'};
393+
394+
space.getId(VarKind::Domain, 0) = Identifier(&identifiers[0]);
395+
space.getId(VarKind::Domain, 1) = Identifier(&identifiers[1]);
396+
space.getId(VarKind::Range, 0) = Identifier(&identifiers[2]);
397+
space.getId(VarKind::Symbol, 0) = Identifier(&identifiers[3]);
398+
space.getId(VarKind::Symbol, 1) = Identifier(&identifiers[4]);
399+
space.getId(VarKind::Symbol, 2) = Identifier(&identifiers[5]);
400+
space.getId(VarKind::Symbol, 3) = Identifier(&identifiers[6]);
401+
402+
otherSpace.getId(VarKind::Domain, 0) = Identifier(&otherIdentifiers[0]);
403+
otherSpace.getId(VarKind::Domain, 1) = Identifier(&otherIdentifiers[1]);
404+
otherSpace.getId(VarKind::Range, 0) = Identifier(&otherIdentifiers[2]);
405+
otherSpace.getId(VarKind::Range, 1) = Identifier(&otherIdentifiers[3]);
406+
otherSpace.getId(VarKind::Symbol, 0) = Identifier(&otherIdentifiers[4]);
407+
otherSpace.getId(VarKind::Symbol, 1) = Identifier(&otherIdentifiers[5]);
408+
// Note common identifiers
409+
otherSpace.getId(VarKind::Symbol, 2) = Identifier(&identifiers[5]);
410+
otherSpace.getId(VarKind::Symbol, 3) = Identifier(&identifiers[6]);
411+
412+
rel.setSpace(space);
413+
otherRel.setSpace(otherSpace);
414+
rel.mergeAndAlignSymbols(otherRel);
415+
416+
space = rel.getSpace();
417+
otherSpace = otherRel.getSpace();
418+
419+
// Check if merge and align is successful.
420+
// Check symbol var identifiers.
421+
EXPECT_EQ(6u, space.getNumSymbolVars());
422+
EXPECT_EQ(6u, otherSpace.getNumSymbolVars());
423+
EXPECT_EQ(space.getId(VarKind::Symbol, 0), Identifier(&identifiers[3]));
424+
EXPECT_EQ(space.getId(VarKind::Symbol, 1), Identifier(&identifiers[4]));
425+
EXPECT_EQ(space.getId(VarKind::Symbol, 2), Identifier(&identifiers[5]));
426+
EXPECT_EQ(space.getId(VarKind::Symbol, 3), Identifier(&identifiers[6]));
427+
EXPECT_EQ(space.getId(VarKind::Symbol, 4), Identifier(&otherIdentifiers[4]));
428+
EXPECT_EQ(space.getId(VarKind::Symbol, 5), Identifier(&otherIdentifiers[5]));
429+
EXPECT_EQ(otherSpace.getId(VarKind::Symbol, 0), Identifier(&identifiers[3]));
430+
EXPECT_EQ(otherSpace.getId(VarKind::Symbol, 1), Identifier(&identifiers[4]));
431+
EXPECT_EQ(otherSpace.getId(VarKind::Symbol, 2), Identifier(&identifiers[5]));
432+
EXPECT_EQ(otherSpace.getId(VarKind::Symbol, 3), Identifier(&identifiers[6]));
433+
EXPECT_EQ(otherSpace.getId(VarKind::Symbol, 4),
434+
Identifier(&otherIdentifiers[4]));
435+
EXPECT_EQ(otherSpace.getId(VarKind::Symbol, 5),
436+
Identifier(&otherIdentifiers[5]));
437+
// Check that domain and range var identifiers are not affected.
438+
EXPECT_EQ(2u, space.getNumDomainVars());
439+
EXPECT_EQ(1u, space.getNumRangeVars());
440+
EXPECT_EQ(space.getId(VarKind::Domain, 0), Identifier(&identifiers[0]));
441+
EXPECT_EQ(space.getId(VarKind::Domain, 1), Identifier(&identifiers[1]));
442+
EXPECT_EQ(space.getId(VarKind::Range, 0), Identifier(&identifiers[2]));
443+
EXPECT_EQ(2u, otherSpace.getNumDomainVars());
444+
EXPECT_EQ(2u, otherSpace.getNumRangeVars());
445+
EXPECT_EQ(otherSpace.getId(VarKind::Domain, 0),
446+
Identifier(&otherIdentifiers[0]));
447+
EXPECT_EQ(otherSpace.getId(VarKind::Domain, 1),
448+
Identifier(&otherIdentifiers[1]));
449+
EXPECT_EQ(otherSpace.getId(VarKind::Range, 0),
450+
Identifier(&otherIdentifiers[2]));
451+
EXPECT_EQ(otherSpace.getId(VarKind::Range, 1),
452+
Identifier(&otherIdentifiers[3]));
453+
}

0 commit comments

Comments
 (0)