|
| 1 | +use std::{ |
| 2 | + collections::{BTreeMap, BinaryHeap}, |
| 3 | + ops::Add, |
| 4 | +}; |
| 5 | + |
| 6 | +use num_traits::Zero; |
| 7 | + |
| 8 | +type Graph<V, E> = BTreeMap<V, BTreeMap<V, E>>; |
| 9 | + |
| 10 | +#[derive(Clone, Debug, Eq, PartialEq)] |
| 11 | +struct Candidate<V, E> { |
| 12 | + estimated_weight: E, |
| 13 | + real_weight: E, |
| 14 | + state: V, |
| 15 | +} |
| 16 | + |
| 17 | +impl<V: Ord + Copy, E: Ord + Copy> PartialOrd for Candidate<V, E> { |
| 18 | + fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> { |
| 19 | + // Note the inverted order; we want nodes with lesser weight to have |
| 20 | + // higher priority |
| 21 | + other.estimated_weight.partial_cmp(&self.estimated_weight) |
| 22 | + } |
| 23 | +} |
| 24 | + |
| 25 | +impl<V: Ord + Copy, E: Ord + Copy> Ord for Candidate<V, E> { |
| 26 | + fn cmp(&self, other: &Self) -> std::cmp::Ordering { |
| 27 | + // Note the inverted order; we want nodes with lesser weight to have |
| 28 | + // higher priority |
| 29 | + other.estimated_weight.cmp(&self.estimated_weight) |
| 30 | + } |
| 31 | +} |
| 32 | + |
| 33 | +pub fn astar<V: Ord + Copy, E: Ord + Copy + Add<Output = E> + Zero>( |
| 34 | + graph: &Graph<V, E>, |
| 35 | + start: V, |
| 36 | + target: V, |
| 37 | + heuristic: impl Fn(V) -> E, |
| 38 | +) -> Option<(E, Vec<V>)> { |
| 39 | + // traversal front |
| 40 | + let mut queue = BinaryHeap::new(); |
| 41 | + // maps each node to its predecessor in the final path |
| 42 | + let mut previous = BTreeMap::new(); |
| 43 | + // weights[v] is the accumulated weight from start to v |
| 44 | + let mut weights = BTreeMap::new(); |
| 45 | + // initialize traversal |
| 46 | + weights.insert(start, E::zero()); |
| 47 | + queue.push(Candidate { |
| 48 | + estimated_weight: heuristic(start), |
| 49 | + real_weight: E::zero(), |
| 50 | + state: start, |
| 51 | + }); |
| 52 | + while let Some(Candidate { |
| 53 | + estimated_weight: _, |
| 54 | + real_weight, |
| 55 | + state: current, |
| 56 | + }) = queue.pop() |
| 57 | + { |
| 58 | + if current == target { |
| 59 | + break; |
| 60 | + } |
| 61 | + for (&next, &weight) in &graph[¤t] { |
| 62 | + let real_weight = real_weight + weight; |
| 63 | + if weights |
| 64 | + .get(&next) |
| 65 | + .map(|&weight| real_weight < weight) |
| 66 | + .unwrap_or(true) |
| 67 | + { |
| 68 | + // current allows us to reach next with lower weight (or at all) |
| 69 | + // add next to the front |
| 70 | + let estimated_weight = real_weight + heuristic(next); |
| 71 | + weights.insert(next, real_weight); |
| 72 | + queue.push(Candidate { |
| 73 | + estimated_weight, |
| 74 | + real_weight, |
| 75 | + state: next, |
| 76 | + }); |
| 77 | + previous.insert(next, current); |
| 78 | + } |
| 79 | + } |
| 80 | + } |
| 81 | + let weight = if let Some(&weight) = weights.get(&target) { |
| 82 | + weight |
| 83 | + } else { |
| 84 | + // we did not reach target from start |
| 85 | + return None; |
| 86 | + }; |
| 87 | + // build path in reverse |
| 88 | + let mut current = target; |
| 89 | + let mut path = vec![current]; |
| 90 | + while current != start { |
| 91 | + let prev = previous |
| 92 | + .get(¤t) |
| 93 | + .copied() |
| 94 | + .expect("We reached the target, but are unable to reconsistute the path"); |
| 95 | + current = prev; |
| 96 | + path.push(current); |
| 97 | + } |
| 98 | + path.reverse(); |
| 99 | + Some((weight, path)) |
| 100 | +} |
| 101 | + |
| 102 | +#[cfg(test)] |
| 103 | +mod tests { |
| 104 | + use super::{astar, Graph}; |
| 105 | + use num_traits::Zero; |
| 106 | + use std::collections::BTreeMap; |
| 107 | + |
| 108 | + // the null heuristic make A* equivalent to Dijkstra |
| 109 | + fn null_heuristic<V, E: Zero>(_v: V) -> E { |
| 110 | + E::zero() |
| 111 | + } |
| 112 | + |
| 113 | + fn add_edge<V: Ord + Copy, E: Ord>(graph: &mut Graph<V, E>, v1: V, v2: V, c: E) { |
| 114 | + graph.entry(v1).or_insert_with(BTreeMap::new).insert(v2, c); |
| 115 | + graph.entry(v2).or_insert_with(BTreeMap::new); |
| 116 | + } |
| 117 | + |
| 118 | + #[test] |
| 119 | + fn single_vertex() { |
| 120 | + let mut graph: Graph<usize, usize> = BTreeMap::new(); |
| 121 | + graph.insert(0, BTreeMap::new()); |
| 122 | + |
| 123 | + assert_eq!(astar(&graph, 0, 0, null_heuristic), Some((0, vec![0]))); |
| 124 | + assert_eq!(astar(&graph, 0, 1, null_heuristic), None); |
| 125 | + } |
| 126 | + |
| 127 | + #[test] |
| 128 | + fn single_edge() { |
| 129 | + let mut graph = BTreeMap::new(); |
| 130 | + add_edge(&mut graph, 0, 1, 2); |
| 131 | + |
| 132 | + assert_eq!(astar(&graph, 0, 1, null_heuristic), Some((2, vec![0, 1]))); |
| 133 | + assert_eq!(astar(&graph, 1, 0, null_heuristic), None); |
| 134 | + } |
| 135 | + |
| 136 | + #[test] |
| 137 | + fn graph_1() { |
| 138 | + let mut graph = BTreeMap::new(); |
| 139 | + add_edge(&mut graph, 'a', 'c', 12); |
| 140 | + add_edge(&mut graph, 'a', 'd', 60); |
| 141 | + add_edge(&mut graph, 'b', 'a', 10); |
| 142 | + add_edge(&mut graph, 'c', 'b', 20); |
| 143 | + add_edge(&mut graph, 'c', 'd', 32); |
| 144 | + add_edge(&mut graph, 'e', 'a', 7); |
| 145 | + |
| 146 | + // from a |
| 147 | + assert_eq!( |
| 148 | + astar(&graph, 'a', 'a', null_heuristic), |
| 149 | + Some((0, vec!['a'])) |
| 150 | + ); |
| 151 | + assert_eq!( |
| 152 | + astar(&graph, 'a', 'b', null_heuristic), |
| 153 | + Some((32, vec!['a', 'c', 'b'])) |
| 154 | + ); |
| 155 | + assert_eq!( |
| 156 | + astar(&graph, 'a', 'c', null_heuristic), |
| 157 | + Some((12, vec!['a', 'c'])) |
| 158 | + ); |
| 159 | + assert_eq!( |
| 160 | + astar(&graph, 'a', 'd', null_heuristic), |
| 161 | + Some((12 + 32, vec!['a', 'c', 'd'])) |
| 162 | + ); |
| 163 | + assert_eq!(astar(&graph, 'a', 'e', null_heuristic), None); |
| 164 | + |
| 165 | + // from b |
| 166 | + assert_eq!( |
| 167 | + astar(&graph, 'b', 'a', null_heuristic), |
| 168 | + Some((10, vec!['b', 'a'])) |
| 169 | + ); |
| 170 | + assert_eq!( |
| 171 | + astar(&graph, 'b', 'b', null_heuristic), |
| 172 | + Some((0, vec!['b'])) |
| 173 | + ); |
| 174 | + assert_eq!( |
| 175 | + astar(&graph, 'b', 'c', null_heuristic), |
| 176 | + Some((10 + 12, vec!['b', 'a', 'c'])) |
| 177 | + ); |
| 178 | + assert_eq!( |
| 179 | + astar(&graph, 'b', 'd', null_heuristic), |
| 180 | + Some((10 + 12 + 32, vec!['b', 'a', 'c', 'd'])) |
| 181 | + ); |
| 182 | + assert_eq!(astar(&graph, 'b', 'e', null_heuristic), None); |
| 183 | + |
| 184 | + // from c |
| 185 | + assert_eq!( |
| 186 | + astar(&graph, 'c', 'a', null_heuristic), |
| 187 | + Some((20 + 10, vec!['c', 'b', 'a'])) |
| 188 | + ); |
| 189 | + assert_eq!( |
| 190 | + astar(&graph, 'c', 'b', null_heuristic), |
| 191 | + Some((20, vec!['c', 'b'])) |
| 192 | + ); |
| 193 | + assert_eq!( |
| 194 | + astar(&graph, 'c', 'c', null_heuristic), |
| 195 | + Some((0, vec!['c'])) |
| 196 | + ); |
| 197 | + assert_eq!( |
| 198 | + astar(&graph, 'c', 'd', null_heuristic), |
| 199 | + Some((32, vec!['c', 'd'])) |
| 200 | + ); |
| 201 | + assert_eq!(astar(&graph, 'c', 'e', null_heuristic), None); |
| 202 | + |
| 203 | + // from d |
| 204 | + assert_eq!(astar(&graph, 'd', 'a', null_heuristic), None); |
| 205 | + assert_eq!(astar(&graph, 'd', 'b', null_heuristic), None); |
| 206 | + assert_eq!(astar(&graph, 'd', 'c', null_heuristic), None); |
| 207 | + assert_eq!( |
| 208 | + astar(&graph, 'd', 'd', null_heuristic), |
| 209 | + Some((0, vec!['d'])) |
| 210 | + ); |
| 211 | + assert_eq!(astar(&graph, 'd', 'e', null_heuristic), None); |
| 212 | + |
| 213 | + // from e |
| 214 | + assert_eq!( |
| 215 | + astar(&graph, 'e', 'a', null_heuristic), |
| 216 | + Some((7, vec!['e', 'a'])) |
| 217 | + ); |
| 218 | + assert_eq!( |
| 219 | + astar(&graph, 'e', 'b', null_heuristic), |
| 220 | + Some((7 + 12 + 20, vec!['e', 'a', 'c', 'b'])) |
| 221 | + ); |
| 222 | + assert_eq!( |
| 223 | + astar(&graph, 'e', 'c', null_heuristic), |
| 224 | + Some((7 + 12, vec!['e', 'a', 'c'])) |
| 225 | + ); |
| 226 | + assert_eq!( |
| 227 | + astar(&graph, 'e', 'd', null_heuristic), |
| 228 | + Some((7 + 12 + 32, vec!['e', 'a', 'c', 'd'])) |
| 229 | + ); |
| 230 | + assert_eq!( |
| 231 | + astar(&graph, 'e', 'e', null_heuristic), |
| 232 | + Some((0, vec!['e'])) |
| 233 | + ); |
| 234 | + } |
| 235 | + |
| 236 | + #[test] |
| 237 | + fn test_heuristic() { |
| 238 | + // make a grid |
| 239 | + let mut graph = BTreeMap::new(); |
| 240 | + let rows = 100; |
| 241 | + let cols = 100; |
| 242 | + for row in 0..rows { |
| 243 | + for col in 0..cols { |
| 244 | + add_edge(&mut graph, (row, col), (row + 1, col), 1); |
| 245 | + add_edge(&mut graph, (row, col), (row, col + 1), 1); |
| 246 | + add_edge(&mut graph, (row, col), (row + 1, col + 1), 1); |
| 247 | + add_edge(&mut graph, (row + 1, col), (row, col), 1); |
| 248 | + add_edge(&mut graph, (row + 1, col + 1), (row, col), 1); |
| 249 | + } |
| 250 | + } |
| 251 | + |
| 252 | + // Dijkstra would explore most of the 101 × 101 nodes |
| 253 | + // the heuristic should allow exploring only about 200 nodes |
| 254 | + let now = std::time::Instant::now(); |
| 255 | + let res = astar(&graph, (0, 0), (100, 90), |(i, j)| 100 - i + 90 - j); |
| 256 | + assert!(now.elapsed() < std::time::Duration::from_millis(10)); |
| 257 | + |
| 258 | + let (weight, path) = res.unwrap(); |
| 259 | + assert_eq!(weight, 100); |
| 260 | + assert_eq!(path.len(), 101); |
| 261 | + } |
| 262 | +} |
0 commit comments