Skip to content

Commit baa5a29

Browse files
authored
Implement A* search algorithm (TheAlgorithms#470)
1 parent ad7ec45 commit baa5a29

File tree

3 files changed

+283
-19
lines changed

3 files changed

+283
-19
lines changed

src/graph/astar.rs

Lines changed: 262 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,262 @@
1+
use std::{
2+
collections::{BTreeMap, BinaryHeap},
3+
ops::Add,
4+
};
5+
6+
use num_traits::Zero;
7+
8+
type Graph<V, E> = BTreeMap<V, BTreeMap<V, E>>;
9+
10+
#[derive(Clone, Debug, Eq, PartialEq)]
11+
struct Candidate<V, E> {
12+
estimated_weight: E,
13+
real_weight: E,
14+
state: V,
15+
}
16+
17+
impl<V: Ord + Copy, E: Ord + Copy> PartialOrd for Candidate<V, E> {
18+
fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
19+
// Note the inverted order; we want nodes with lesser weight to have
20+
// higher priority
21+
other.estimated_weight.partial_cmp(&self.estimated_weight)
22+
}
23+
}
24+
25+
impl<V: Ord + Copy, E: Ord + Copy> Ord for Candidate<V, E> {
26+
fn cmp(&self, other: &Self) -> std::cmp::Ordering {
27+
// Note the inverted order; we want nodes with lesser weight to have
28+
// higher priority
29+
other.estimated_weight.cmp(&self.estimated_weight)
30+
}
31+
}
32+
33+
pub fn astar<V: Ord + Copy, E: Ord + Copy + Add<Output = E> + Zero>(
34+
graph: &Graph<V, E>,
35+
start: V,
36+
target: V,
37+
heuristic: impl Fn(V) -> E,
38+
) -> Option<(E, Vec<V>)> {
39+
// traversal front
40+
let mut queue = BinaryHeap::new();
41+
// maps each node to its predecessor in the final path
42+
let mut previous = BTreeMap::new();
43+
// weights[v] is the accumulated weight from start to v
44+
let mut weights = BTreeMap::new();
45+
// initialize traversal
46+
weights.insert(start, E::zero());
47+
queue.push(Candidate {
48+
estimated_weight: heuristic(start),
49+
real_weight: E::zero(),
50+
state: start,
51+
});
52+
while let Some(Candidate {
53+
estimated_weight: _,
54+
real_weight,
55+
state: current,
56+
}) = queue.pop()
57+
{
58+
if current == target {
59+
break;
60+
}
61+
for (&next, &weight) in &graph[&current] {
62+
let real_weight = real_weight + weight;
63+
if weights
64+
.get(&next)
65+
.map(|&weight| real_weight < weight)
66+
.unwrap_or(true)
67+
{
68+
// current allows us to reach next with lower weight (or at all)
69+
// add next to the front
70+
let estimated_weight = real_weight + heuristic(next);
71+
weights.insert(next, real_weight);
72+
queue.push(Candidate {
73+
estimated_weight,
74+
real_weight,
75+
state: next,
76+
});
77+
previous.insert(next, current);
78+
}
79+
}
80+
}
81+
let weight = if let Some(&weight) = weights.get(&target) {
82+
weight
83+
} else {
84+
// we did not reach target from start
85+
return None;
86+
};
87+
// build path in reverse
88+
let mut current = target;
89+
let mut path = vec![current];
90+
while current != start {
91+
let prev = previous
92+
.get(&current)
93+
.copied()
94+
.expect("We reached the target, but are unable to reconsistute the path");
95+
current = prev;
96+
path.push(current);
97+
}
98+
path.reverse();
99+
Some((weight, path))
100+
}
101+
102+
#[cfg(test)]
103+
mod tests {
104+
use super::{astar, Graph};
105+
use num_traits::Zero;
106+
use std::collections::BTreeMap;
107+
108+
// the null heuristic make A* equivalent to Dijkstra
109+
fn null_heuristic<V, E: Zero>(_v: V) -> E {
110+
E::zero()
111+
}
112+
113+
fn add_edge<V: Ord + Copy, E: Ord>(graph: &mut Graph<V, E>, v1: V, v2: V, c: E) {
114+
graph.entry(v1).or_insert_with(BTreeMap::new).insert(v2, c);
115+
graph.entry(v2).or_insert_with(BTreeMap::new);
116+
}
117+
118+
#[test]
119+
fn single_vertex() {
120+
let mut graph: Graph<usize, usize> = BTreeMap::new();
121+
graph.insert(0, BTreeMap::new());
122+
123+
assert_eq!(astar(&graph, 0, 0, null_heuristic), Some((0, vec![0])));
124+
assert_eq!(astar(&graph, 0, 1, null_heuristic), None);
125+
}
126+
127+
#[test]
128+
fn single_edge() {
129+
let mut graph = BTreeMap::new();
130+
add_edge(&mut graph, 0, 1, 2);
131+
132+
assert_eq!(astar(&graph, 0, 1, null_heuristic), Some((2, vec![0, 1])));
133+
assert_eq!(astar(&graph, 1, 0, null_heuristic), None);
134+
}
135+
136+
#[test]
137+
fn graph_1() {
138+
let mut graph = BTreeMap::new();
139+
add_edge(&mut graph, 'a', 'c', 12);
140+
add_edge(&mut graph, 'a', 'd', 60);
141+
add_edge(&mut graph, 'b', 'a', 10);
142+
add_edge(&mut graph, 'c', 'b', 20);
143+
add_edge(&mut graph, 'c', 'd', 32);
144+
add_edge(&mut graph, 'e', 'a', 7);
145+
146+
// from a
147+
assert_eq!(
148+
astar(&graph, 'a', 'a', null_heuristic),
149+
Some((0, vec!['a']))
150+
);
151+
assert_eq!(
152+
astar(&graph, 'a', 'b', null_heuristic),
153+
Some((32, vec!['a', 'c', 'b']))
154+
);
155+
assert_eq!(
156+
astar(&graph, 'a', 'c', null_heuristic),
157+
Some((12, vec!['a', 'c']))
158+
);
159+
assert_eq!(
160+
astar(&graph, 'a', 'd', null_heuristic),
161+
Some((12 + 32, vec!['a', 'c', 'd']))
162+
);
163+
assert_eq!(astar(&graph, 'a', 'e', null_heuristic), None);
164+
165+
// from b
166+
assert_eq!(
167+
astar(&graph, 'b', 'a', null_heuristic),
168+
Some((10, vec!['b', 'a']))
169+
);
170+
assert_eq!(
171+
astar(&graph, 'b', 'b', null_heuristic),
172+
Some((0, vec!['b']))
173+
);
174+
assert_eq!(
175+
astar(&graph, 'b', 'c', null_heuristic),
176+
Some((10 + 12, vec!['b', 'a', 'c']))
177+
);
178+
assert_eq!(
179+
astar(&graph, 'b', 'd', null_heuristic),
180+
Some((10 + 12 + 32, vec!['b', 'a', 'c', 'd']))
181+
);
182+
assert_eq!(astar(&graph, 'b', 'e', null_heuristic), None);
183+
184+
// from c
185+
assert_eq!(
186+
astar(&graph, 'c', 'a', null_heuristic),
187+
Some((20 + 10, vec!['c', 'b', 'a']))
188+
);
189+
assert_eq!(
190+
astar(&graph, 'c', 'b', null_heuristic),
191+
Some((20, vec!['c', 'b']))
192+
);
193+
assert_eq!(
194+
astar(&graph, 'c', 'c', null_heuristic),
195+
Some((0, vec!['c']))
196+
);
197+
assert_eq!(
198+
astar(&graph, 'c', 'd', null_heuristic),
199+
Some((32, vec!['c', 'd']))
200+
);
201+
assert_eq!(astar(&graph, 'c', 'e', null_heuristic), None);
202+
203+
// from d
204+
assert_eq!(astar(&graph, 'd', 'a', null_heuristic), None);
205+
assert_eq!(astar(&graph, 'd', 'b', null_heuristic), None);
206+
assert_eq!(astar(&graph, 'd', 'c', null_heuristic), None);
207+
assert_eq!(
208+
astar(&graph, 'd', 'd', null_heuristic),
209+
Some((0, vec!['d']))
210+
);
211+
assert_eq!(astar(&graph, 'd', 'e', null_heuristic), None);
212+
213+
// from e
214+
assert_eq!(
215+
astar(&graph, 'e', 'a', null_heuristic),
216+
Some((7, vec!['e', 'a']))
217+
);
218+
assert_eq!(
219+
astar(&graph, 'e', 'b', null_heuristic),
220+
Some((7 + 12 + 20, vec!['e', 'a', 'c', 'b']))
221+
);
222+
assert_eq!(
223+
astar(&graph, 'e', 'c', null_heuristic),
224+
Some((7 + 12, vec!['e', 'a', 'c']))
225+
);
226+
assert_eq!(
227+
astar(&graph, 'e', 'd', null_heuristic),
228+
Some((7 + 12 + 32, vec!['e', 'a', 'c', 'd']))
229+
);
230+
assert_eq!(
231+
astar(&graph, 'e', 'e', null_heuristic),
232+
Some((0, vec!['e']))
233+
);
234+
}
235+
236+
#[test]
237+
fn test_heuristic() {
238+
// make a grid
239+
let mut graph = BTreeMap::new();
240+
let rows = 100;
241+
let cols = 100;
242+
for row in 0..rows {
243+
for col in 0..cols {
244+
add_edge(&mut graph, (row, col), (row + 1, col), 1);
245+
add_edge(&mut graph, (row, col), (row, col + 1), 1);
246+
add_edge(&mut graph, (row, col), (row + 1, col + 1), 1);
247+
add_edge(&mut graph, (row + 1, col), (row, col), 1);
248+
add_edge(&mut graph, (row + 1, col + 1), (row, col), 1);
249+
}
250+
}
251+
252+
// Dijkstra would explore most of the 101 × 101 nodes
253+
// the heuristic should allow exploring only about 200 nodes
254+
let now = std::time::Instant::now();
255+
let res = astar(&graph, (0, 0), (100, 90), |(i, j)| 100 - i + 90 - j);
256+
assert!(now.elapsed() < std::time::Duration::from_millis(10));
257+
258+
let (weight, path) = res.unwrap();
259+
assert_eq!(weight, 100);
260+
assert_eq!(path.len(), 101);
261+
}
262+
}

src/graph/dijkstra.rs

Lines changed: 19 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -11,37 +11,37 @@ type Graph<V, E> = BTreeMap<V, BTreeMap<V, E>>;
1111
// since the start has no predecessor but is reachable, map[start] will be None
1212
pub fn dijkstra<V: Ord + Copy, E: Ord + Copy + Add<Output = E>>(
1313
graph: &Graph<V, E>,
14-
start: &V,
14+
start: V,
1515
) -> BTreeMap<V, Option<(V, E)>> {
1616
let mut ans = BTreeMap::new();
1717
let mut prio = BinaryHeap::new();
1818

1919
// start is the special case that doesn't have a predecessor
20-
ans.insert(*start, None);
20+
ans.insert(start, None);
2121

22-
for (new, weight) in &graph[start] {
23-
ans.insert(*new, Some((*start, *weight)));
24-
prio.push(Reverse((*weight, new, start)));
22+
for (new, weight) in &graph[&start] {
23+
ans.insert(*new, Some((start, *weight)));
24+
prio.push(Reverse((*weight, *new, start)));
2525
}
2626

2727
while let Some(Reverse((dist_new, new, prev))) = prio.pop() {
28-
match ans[new] {
28+
match ans[&new] {
2929
// what we popped is what is in ans, we'll compute it
30-
Some((p, d)) if p == *prev && d == dist_new => {}
30+
Some((p, d)) if p == prev && d == dist_new => {}
3131
// otherwise it's not interesting
3232
_ => continue,
3333
}
3434

35-
for (next, weight) in &graph[new] {
35+
for (next, weight) in &graph[&new] {
3636
match ans.get(next) {
3737
// if ans[next] is a lower dist than the alternative one, we do nothing
3838
Some(Some((_, dist_next))) if dist_new + *weight >= *dist_next => {}
3939
// if ans[next] is None then next is start and so the distance won't be changed, it won't be added again in prio
4040
Some(None) => {}
4141
// the new path is shorter, either new was not in ans or it was farther
4242
_ => {
43-
ans.insert(*next, Some((*new, *weight + dist_new)));
44-
prio.push(Reverse((*weight + dist_new, next, new)));
43+
ans.insert(*next, Some((new, *weight + dist_new)));
44+
prio.push(Reverse((*weight + dist_new, *next, new)));
4545
}
4646
}
4747
}
@@ -68,7 +68,7 @@ mod tests {
6868
let mut dists = BTreeMap::new();
6969
dists.insert(0, None);
7070

71-
assert_eq!(dijkstra(&graph, &0), dists);
71+
assert_eq!(dijkstra(&graph, 0), dists);
7272
}
7373

7474
#[test]
@@ -80,12 +80,12 @@ mod tests {
8080
dists_0.insert(0, None);
8181
dists_0.insert(1, Some((0, 2)));
8282

83-
assert_eq!(dijkstra(&graph, &0), dists_0);
83+
assert_eq!(dijkstra(&graph, 0), dists_0);
8484

8585
let mut dists_1 = BTreeMap::new();
8686
dists_1.insert(1, None);
8787

88-
assert_eq!(dijkstra(&graph, &1), dists_1);
88+
assert_eq!(dijkstra(&graph, 1), dists_1);
8989
}
9090

9191
#[test]
@@ -109,7 +109,7 @@ mod tests {
109109
}
110110
}
111111

112-
assert_eq!(dijkstra(&graph, &1), dists);
112+
assert_eq!(dijkstra(&graph, 1), dists);
113113
}
114114

115115
#[test]
@@ -127,32 +127,32 @@ mod tests {
127127
dists_a.insert('c', Some(('a', 12)));
128128
dists_a.insert('d', Some(('c', 44)));
129129
dists_a.insert('b', Some(('c', 32)));
130-
assert_eq!(dijkstra(&graph, &'a'), dists_a);
130+
assert_eq!(dijkstra(&graph, 'a'), dists_a);
131131

132132
let mut dists_b = BTreeMap::new();
133133
dists_b.insert('b', None);
134134
dists_b.insert('a', Some(('b', 10)));
135135
dists_b.insert('c', Some(('a', 22)));
136136
dists_b.insert('d', Some(('c', 54)));
137-
assert_eq!(dijkstra(&graph, &'b'), dists_b);
137+
assert_eq!(dijkstra(&graph, 'b'), dists_b);
138138

139139
let mut dists_c = BTreeMap::new();
140140
dists_c.insert('c', None);
141141
dists_c.insert('b', Some(('c', 20)));
142142
dists_c.insert('d', Some(('c', 32)));
143143
dists_c.insert('a', Some(('b', 30)));
144-
assert_eq!(dijkstra(&graph, &'c'), dists_c);
144+
assert_eq!(dijkstra(&graph, 'c'), dists_c);
145145

146146
let mut dists_d = BTreeMap::new();
147147
dists_d.insert('d', None);
148-
assert_eq!(dijkstra(&graph, &'d'), dists_d);
148+
assert_eq!(dijkstra(&graph, 'd'), dists_d);
149149

150150
let mut dists_e = BTreeMap::new();
151151
dists_e.insert('e', None);
152152
dists_e.insert('a', Some(('e', 7)));
153153
dists_e.insert('c', Some(('a', 19)));
154154
dists_e.insert('d', Some(('c', 51)));
155155
dists_e.insert('b', Some(('c', 39)));
156-
assert_eq!(dijkstra(&graph, &'e'), dists_e);
156+
assert_eq!(dijkstra(&graph, 'e'), dists_e);
157157
}
158158
}

src/graph/mod.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
mod astar;
12
mod bellman_ford;
23
mod bipartite_matching;
34
mod breadth_first_search;
@@ -17,6 +18,7 @@ mod prufer_code;
1718
mod strongly_connected_components;
1819
mod topological_sort;
1920
mod two_satisfiability;
21+
pub use self::astar::astar;
2022
pub use self::bellman_ford::bellman_ford;
2123
pub use self::bipartite_matching::BipartiteMatching;
2224
pub use self::breadth_first_search::breadth_first_search;

0 commit comments

Comments
 (0)