Skip to content

Commit 83f9d4a

Browse files
authored
Fix Topological Sort (TheAlgorithms#481)
1 parent 50e67b1 commit 83f9d4a

File tree

1 file changed

+105
-44
lines changed

1 file changed

+105
-44
lines changed

src/graph/topological_sort.rs

Lines changed: 105 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -1,62 +1,123 @@
1-
use std::collections::{BTreeMap, VecDeque};
2-
3-
type Graph<V, E> = BTreeMap<V, Vec<(V, E)>>;
4-
5-
/// returns topological sort of the graph using Kahn's algorithm
6-
pub fn topological_sort<V: Ord + Copy, E: Ord>(graph: &Graph<V, E>) -> Vec<V> {
7-
let mut visited = BTreeMap::new();
8-
let mut degree = BTreeMap::new();
9-
for u in graph.keys() {
10-
degree.insert(*u, 0);
11-
for (v, _) in graph.get(u).unwrap() {
12-
let entry = degree.entry(*v).or_insert(0);
13-
*entry += 1;
14-
}
1+
use std::collections::HashMap;
2+
use std::collections::VecDeque;
3+
use std::hash::Hash;
4+
5+
#[derive(Debug, Eq, PartialEq)]
6+
pub enum TopoligicalSortError {
7+
CycleDetected,
8+
}
9+
10+
type TopologicalSortResult<Node> = Result<Vec<Node>, TopoligicalSortError>;
11+
12+
/// Given a directed graph, modeled as a list of edges from source to destination
13+
/// Uses Kahn's algorithm to either:
14+
/// return the topological sort of the graph
15+
/// or detect if there's any cycle
16+
pub fn topological_sort<Node: Hash + Eq + Copy>(
17+
edges: &Vec<(Node, Node)>,
18+
) -> TopologicalSortResult<Node> {
19+
// Preparation:
20+
// Build a map of edges, organised from source to destinations
21+
// Also, count the number of incoming edges by node
22+
let mut edges_by_source: HashMap<Node, Vec<Node>> = HashMap::default();
23+
let mut incoming_edges_count: HashMap<Node, usize> = HashMap::default();
24+
for (source, destination) in edges {
25+
incoming_edges_count.entry(*source).or_insert(0); // if we haven't seen this node yet, mark it as having 0 incoming nodes
26+
edges_by_source // add destination to the list of outgoing edges from source
27+
.entry(*source)
28+
.or_insert_with(Vec::default)
29+
.push(*destination);
30+
// then make destination have one more incoming edge
31+
*incoming_edges_count.entry(*destination).or_insert(0) += 1;
1532
}
16-
let mut queue = VecDeque::new();
17-
for (u, d) in degree.iter() {
18-
if *d == 0 {
19-
queue.push_back(*u);
20-
visited.insert(*u, true);
33+
34+
// Now Kahn's algorithm:
35+
// Add nodes that have no incoming edges to a queue
36+
let mut no_incoming_edges_q = VecDeque::default();
37+
for (node, count) in &incoming_edges_count {
38+
if *count == 0 {
39+
no_incoming_edges_q.push_back(*node);
2140
}
2241
}
23-
let mut ret = Vec::new();
24-
while let Some(u) = queue.pop_front() {
25-
ret.push(u);
26-
if let Some(from_u) = graph.get(&u) {
27-
for (v, _) in from_u {
28-
*degree.get_mut(v).unwrap() -= 1;
29-
if *degree.get(v).unwrap() == 0 {
30-
queue.push_back(*v);
31-
visited.insert(*v, true);
42+
// For each node in this "O-incoming-edge-queue"
43+
let mut sorted = Vec::default();
44+
while let Some(no_incoming_edges) = no_incoming_edges_q.pop_back() {
45+
sorted.push(no_incoming_edges); // since the node has no dependency, it can be safely pushed to the sorted result
46+
incoming_edges_count.remove(&no_incoming_edges);
47+
// For each node having this one as dependency
48+
for neighbour in edges_by_source.get(&no_incoming_edges).unwrap_or(&vec![]) {
49+
if let Some(count) = incoming_edges_count.get_mut(neighbour) {
50+
*count -= 1; // decrement the count of incoming edges for the dependent node
51+
if *count == 0 {
52+
// `node` was the last node `neighbour` was dependent on
53+
incoming_edges_count.remove(neighbour); // let's remove it from the map, so that we can know if we covered the whole graph
54+
no_incoming_edges_q.push_front(*neighbour); // it has no incoming edges anymore => push it to the queue
3255
}
3356
}
3457
}
3558
}
36-
ret
59+
if incoming_edges_count.is_empty() {
60+
// we have visited every node
61+
Ok(sorted)
62+
} else {
63+
// some nodes haven't been visited, meaning there's a cycle in the graph
64+
Err(TopoligicalSortError::CycleDetected)
65+
}
3766
}
3867

3968
#[cfg(test)]
4069
mod tests {
41-
use std::collections::BTreeMap;
70+
use super::topological_sort;
71+
use crate::graph::topological_sort::TopoligicalSortError;
4272

43-
use super::{topological_sort, Graph};
44-
fn add_edge<V: Ord + Copy, E: Ord>(graph: &mut Graph<V, E>, from: V, to: V, weight: E) {
45-
let edges = graph.entry(from).or_insert(Vec::new());
46-
edges.push((to, weight));
73+
fn is_valid_sort<Node: Eq>(sorted: &[Node], graph: &[(Node, Node)]) -> bool {
74+
for (source, dest) in graph {
75+
let source_pos = sorted.iter().position(|node| node == source);
76+
let dest_pos = sorted.iter().position(|node| node == dest);
77+
match (source_pos, dest_pos) {
78+
(Some(src), Some(dst)) if src < dst => {}
79+
_ => {
80+
return false;
81+
}
82+
};
83+
}
84+
true
4785
}
4886

4987
#[test]
5088
fn it_works() {
51-
let mut graph = BTreeMap::new();
52-
add_edge(&mut graph, 1, 2, 1);
53-
add_edge(&mut graph, 1, 3, 1);
54-
add_edge(&mut graph, 2, 3, 1);
55-
add_edge(&mut graph, 3, 4, 1);
56-
add_edge(&mut graph, 4, 5, 1);
57-
add_edge(&mut graph, 5, 6, 1);
58-
add_edge(&mut graph, 6, 7, 1);
59-
60-
assert_eq!(topological_sort(&graph), vec![1, 2, 3, 4, 5, 6, 7]);
89+
let graph = vec![(1, 2), (1, 3), (2, 3), (3, 4), (4, 5), (5, 6), (6, 7)];
90+
let sort = topological_sort(&graph);
91+
assert!(sort.is_ok());
92+
let sort = sort.unwrap();
93+
assert!(is_valid_sort(&sort, &graph));
94+
assert_eq!(sort, vec![1, 2, 3, 4, 5, 6, 7]);
95+
}
96+
97+
#[test]
98+
fn test_wikipedia_example() {
99+
let graph = vec![
100+
(5, 11),
101+
(7, 11),
102+
(7, 8),
103+
(3, 8),
104+
(3, 10),
105+
(11, 2),
106+
(11, 9),
107+
(11, 10),
108+
(8, 9),
109+
];
110+
let sort = topological_sort(&graph);
111+
assert!(sort.is_ok());
112+
let sort = sort.unwrap();
113+
assert!(is_valid_sort(&sort, &graph));
114+
}
115+
116+
#[test]
117+
fn test_cyclic_graph() {
118+
let graph = vec![(1, 2), (2, 3), (3, 4), (4, 5), (4, 2)];
119+
let sort = topological_sort(&graph);
120+
assert!(sort.is_err());
121+
assert_eq!(sort.err().unwrap(), TopoligicalSortError::CycleDetected);
61122
}
62123
}

0 commit comments

Comments
 (0)