Skip to content

Commit 0c128e5

Browse files
committed
Sema: Optimize ConstraintGraph::computeConnectedComponents()
Instead of starting a depth-first search from each type variable and marking all type variables that haven't been marked yet, we can implement this as a union-find. We can also store the temporary state directly inside the TypeVariableType::Implementation, instead of creating large DenseMaps whose keys range over all type variables.
1 parent 45346c8 commit 0c128e5

File tree

2 files changed

+108
-93
lines changed

2 files changed

+108
-93
lines changed

include/swift/Sema/ConstraintSystem.h

+45-7
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@
3939
#include "swift/Sema/SolutionResult.h"
4040
#include "swift/Sema/SyntacticElementTarget.h"
4141
#include "llvm/ADT/MapVector.h"
42+
#include "llvm/ADT/PointerIntPair.h"
4243
#include "llvm/ADT/PointerUnion.h"
4344
#include "llvm/ADT/STLExtras.h"
4445
#include "llvm/ADT/SetOperations.h"
@@ -339,17 +340,17 @@ class TypeVariableType::Implementation {
339340
/// The corresponding node in the constraint graph.
340341
constraints::ConstraintGraphNode *GraphNode = nullptr;
341342

343+
/// Temporary state for ConstraintGraph::computeConnectedComponents(),
344+
/// stored inline for performance.
345+
llvm::PointerIntPair<TypeVariableType *, 1, unsigned> Component;
346+
342347
friend class constraints::SolverTrail;
343348

344349
public:
345350
/// Retrieve the type variable associated with this implementation.
346-
TypeVariableType *getTypeVariable() {
347-
return reinterpret_cast<TypeVariableType *>(this) - 1;
348-
}
349-
350-
/// Retrieve the type variable associated with this implementation.
351-
const TypeVariableType *getTypeVariable() const {
352-
return reinterpret_cast<const TypeVariableType *>(this) - 1;
351+
TypeVariableType *getTypeVariable() const {
352+
return reinterpret_cast<TypeVariableType *>(
353+
const_cast<Implementation *>(this)) - 1;
353354
}
354355

355356
explicit Implementation(constraints::ConstraintLocator *locator,
@@ -415,6 +416,12 @@ class TypeVariableType::Implementation {
415416
return ParentOrFixed.get<TypeVariableType *>() != getTypeVariable();
416417
}
417418

419+
/// Low-level accessor; use getRepresentative() or getFixedType() instead.
420+
llvm::PointerUnion<TypeVariableType *, TypeBase *>
421+
getRepresentativeOrFixed() const {
422+
return ParentOrFixed;
423+
}
424+
418425
/// Record the current type-variable binding.
419426
void recordBinding(constraints::SolverTrail &trail) {
420427
trail.recordChange(constraints::SolverTrail::Change::UpdatedTypeVariable(
@@ -632,6 +639,37 @@ class TypeVariableType::Implementation {
632639
impl.getTypeVariable()->Bits.TypeVariableType.Options |= TVO_CanBindToHole;
633640
}
634641

642+
void setComponent(TypeVariableType *parent) {
643+
Component.setPointerAndInt(parent, /*valid=*/false);
644+
}
645+
646+
TypeVariableType *getComponent() const {
647+
auto *rep = getTypeVariable();
648+
while (rep != rep->getImpl().Component.getPointer())
649+
rep = rep->getImpl().Component.getPointer();
650+
651+
// Path compression
652+
if (rep != getTypeVariable()) {
653+
const_cast<TypeVariableType::Implementation *>(this)
654+
->Component.setPointer(rep);
655+
}
656+
657+
return rep;
658+
}
659+
660+
bool isValidComponent() const {
661+
ASSERT(Component.getPointer() == getTypeVariable());
662+
return Component.getInt();
663+
}
664+
665+
bool markValidComponent() {
666+
if (Component.getInt())
667+
return false;
668+
ASSERT(Component.getPointer() == getTypeVariable());
669+
Component.setInt(1);
670+
return true;
671+
}
672+
635673
/// Print the type variable to the given output stream.
636674
void print(llvm::raw_ostream &OS);
637675

lib/Sema/ConstraintGraph.cpp

+63-86
Original file line numberDiff line numberDiff line change
@@ -721,19 +721,9 @@ namespace {
721721
ConstraintGraph &cg;
722722
ArrayRef<TypeVariableType *> typeVars;
723723

724-
/// A mapping from each type variable to its representative in a union-find
725-
/// data structure, excluding entries where the type variable is its own
726-
/// representative.
727-
mutable llvm::SmallDenseMap<TypeVariableType *, TypeVariableType *>
728-
representatives;
729-
730-
// Figure out which components have unbound type variables and/or
731-
// constraints. These are the only components we want to report.
732-
llvm::SmallDenseSet<TypeVariableType *> validComponents;
733-
734-
/// The complete set of constraints that were visited while computing
735-
/// connected components.
736-
llvm::SmallPtrSet<Constraint *, 8> visitedConstraints;
724+
/// The number of connected components discovered so far. Decremented when
725+
/// we merge equivalence classes.
726+
unsigned validComponentCount = 0;
737727

738728
/// Describes the one-way incoming and outcoming adjacencies of
739729
/// a component within the directed graph of one-way constraints.
@@ -776,10 +766,9 @@ namespace {
776766
// The final return value.
777767
SmallVector<Component, 1> flatComponents;
778768

779-
780769
// We don't actually need to partition the graph into components if
781770
// there are fewer than 2.
782-
if (validComponents.size() < 2 && cg.getOrphanedConstraints().empty())
771+
if (validComponentCount < 2 && cg.getOrphanedConstraints().empty())
783772
return flatComponents;
784773

785774
// Mapping from representatives to components.
@@ -790,8 +779,8 @@ namespace {
790779
for (auto typeVar : typeVars) {
791780
// Find the representative. If we aren't creating a type variable
792781
// for this component, skip it.
793-
auto rep = findRepresentative(typeVar);
794-
if (validComponents.count(rep) == 0)
782+
auto rep = typeVar->getImpl().getComponent();
783+
if (!rep->getImpl().isValidComponent())
795784
continue;
796785

797786
auto pair = components.insert({rep, Component(components.size())});
@@ -829,7 +818,7 @@ namespace {
829818
typeVar = constraintTypeVars.front();
830819
}
831820

832-
auto rep = findRepresentative(typeVar);
821+
auto rep = typeVar->getImpl().getComponent();
833822
getComponent(rep).addConstraint(&constraint);
834823
}
835824

@@ -857,7 +846,7 @@ namespace {
857846
auto &oneWayComponent = knownOneWayComponent->second;
858847
auto &component = getComponent(typeVar);
859848
for (auto inAdj : oneWayComponent.inAdjacencies) {
860-
if (validComponents.count(inAdj) == 0)
849+
if (!inAdj->getImpl().isValidComponent())
861850
continue;
862851

863852
component.recordDependency(getComponent(inAdj));
@@ -899,44 +888,34 @@ namespace {
899888
return flatComponents;
900889
}
901890

902-
/// Find the representative for the given type variable within the set
903-
/// of representatives in a union-find data structure.
904-
TypeVariableType *findRepresentative(TypeVariableType *typeVar) const {
905-
// If we don't have a record of this type variable, it is it's own
906-
// representative.
907-
auto known = representatives.find(typeVar);
908-
if (known == representatives.end() || known->second == typeVar)
909-
return typeVar;
910-
911-
// Find the representative of the parent.
912-
auto parent = known->second;
913-
auto rep = findRepresentative(parent);
914-
representatives[typeVar] = rep;
915-
916-
return rep;
917-
}
918-
919891
private:
920892
/// Perform the union of two type variables in a union-find data structure
921893
/// used for connected components.
922894
///
923895
/// \returns true if the two components were separate and have now been
924896
/// joined, \c false if they were already in the same set.
925897
bool unionSets(TypeVariableType *typeVar1, TypeVariableType *typeVar2) {
926-
auto rep1 = findRepresentative(typeVar1);
927-
auto rep2 = findRepresentative(typeVar2);
898+
auto rep1 = typeVar1->getImpl().getComponent();
899+
auto rep2 = typeVar2->getImpl().getComponent();
928900
if (rep1 == rep2)
929901
return false;
930902

931903
// Reparent the type variable with the higher ID. The actual choice doesn't
932904
// matter, but this makes debugging easier.
933-
if (rep1->getID() < rep2->getID()) {
934-
validComponents.erase(rep2);
935-
representatives[rep2] = rep1;
936-
} else {
937-
validComponents.erase(rep1);
938-
representatives[rep1] = rep2;
905+
if (rep1->getID() > rep2->getID())
906+
std::swap(rep1, rep2);
907+
908+
if (rep2->getImpl().isValidComponent()) {
909+
// If both are valid components, decrement the valid component counter
910+
// by one. Otherwise, propagate the valid component flag.
911+
if (!rep1->getImpl().markValidComponent()) {
912+
ASSERT(validComponentCount > 0);
913+
--validComponentCount;
914+
}
939915
}
916+
917+
rep2->getImpl().setComponent(rep1);
918+
940919
return true;
941920
}
942921

@@ -949,51 +928,49 @@ namespace {
949928

950929
auto &cs = cg.getConstraintSystem();
951930

952-
// Perform a depth-first search from each type variable to identify
953-
// what component it is in.
954931
for (auto typeVar : typeVars) {
955-
// If we've already assigned a representative to this type variable,
956-
// we're done.
957-
if (representatives.count(typeVar) > 0)
958-
continue;
959-
960-
// Perform a depth-first search to mark those type variables that are
961-
// in the same component as this type variable.
962-
depthFirstSearch(
963-
cg, typeVar,
964-
[&](TypeVariableType *found) {
965-
// If we have already seen this node, we're done.
966-
auto inserted = representatives.insert({found, typeVar});
967-
assert((inserted.second || inserted.first->second == typeVar) &&
968-
"Wrong component?");
969-
970-
if (inserted.second)
971-
if (!cs.getFixedType(found))
972-
validComponents.insert(typeVar);
973-
974-
return inserted.second;
975-
},
976-
[&](Constraint *constraint) {
977-
// Record and skip one-way constraints.
978-
if (constraint->isOneWayConstraint()) {
979-
oneWayConstraints.push_back(constraint);
980-
return false;
981-
}
932+
auto &impl = typeVar->getImpl();
933+
if (auto *rep = impl.getRepresentativeOrFixed().dyn_cast<TypeVariableType *>()) {
934+
impl.setComponent(rep);
935+
if (typeVar == rep) {
936+
if (impl.markValidComponent())
937+
++validComponentCount;
938+
}
939+
} else {
940+
impl.setComponent(typeVar);
941+
}
942+
}
982943

983-
return true;
984-
},
985-
visitedConstraints);
944+
for (auto typeVar : typeVars) {
945+
auto &impl = typeVar->getImpl();
946+
if (auto fixedType = impl.getRepresentativeOrFixed().dyn_cast<TypeBase *>()) {
947+
auto &node = cg[typeVar];
948+
for (auto otherTypeVar : node.getReferencedVars()) {
949+
unionSets(typeVar, otherTypeVar);
950+
}
951+
}
986952
}
987953

988954
for (auto &constraint : cs.getConstraints()) {
989-
if (constraint.getKind() == ConstraintKind::Disjunction ||
990-
constraint.getKind() == ConstraintKind::Conjunction) {
991-
for (auto typeVar : constraint.getTypeVariables()) {
992-
auto rep = findRepresentative(typeVar);
993-
if (validComponents.insert(rep).second)
994-
ASSERT(cs.getFixedType(typeVar));
995-
}
955+
if (constraint.isOneWayConstraint()) {
956+
oneWayConstraints.push_back(&constraint);
957+
auto *typeVar = constraint.getFirstType()->castTo<TypeVariableType>();
958+
typeVar = typeVar->getImpl().getComponent();
959+
if (typeVar->getImpl().markValidComponent())
960+
++validComponentCount;
961+
continue;
996962
}
963+
964+
auto typeVars = constraint.getTypeVariables();
965+
if (typeVars.empty())
966+
continue;
967+
968+
auto *firstTypeVar = typeVars[0]->getImpl().getComponent();
969+
if (firstTypeVar->getImpl().markValidComponent())
970+
++validComponentCount;
971+
972+
for (auto *otherTypeVar : typeVars.slice(1))
973+
unionSets(firstTypeVar, otherTypeVar);
997974
}
998975

999976
return oneWayConstraints;
@@ -1016,7 +993,7 @@ namespace {
1016993
SmallPtrSet<TypeVariableType *, 2> typeVars;
1017994
type->getTypeVariables(typeVars);
1018995
for (auto typeVar : typeVars) {
1019-
auto rep = findRepresentative(typeVar);
996+
auto rep = typeVar->getImpl().getComponent();
1020997
insertIfUnique(results, rep);
1021998
}
1022999

@@ -1226,7 +1203,7 @@ namespace {
12261203
rep,
12271204
[&](TypeVariableType *typeVar) -> ArrayRef<TypeVariableType *> {
12281205
// Traverse the outgoing adjacencies for the subcomponent
1229-
assert(typeVar == findRepresentative(typeVar));
1206+
assert(typeVar == typeVar->getImpl().getComponent());
12301207
auto oneWayComponent = oneWayDigraph.find(typeVar);
12311208
if (oneWayComponent == oneWayDigraph.end()) {
12321209
return { };
@@ -1240,7 +1217,7 @@ namespace {
12401217
[&](TypeVariableType *typeVar) {
12411218
// Record this type variable, if it's one of the representative
12421219
// type variables.
1243-
if (validComponents.count(typeVar) > 0)
1220+
if (typeVar->getImpl().isValidComponent())
12441221
orderedReps.push_back(typeVar);
12451222
});
12461223
}

0 commit comments

Comments
 (0)