|
| 1 | +package aockt.util |
| 2 | + |
| 3 | +import java.util.PriorityQueue |
| 4 | + |
| 5 | +object Pathfinding { |
| 6 | + |
| 7 | + /** |
| 8 | + * Performs a search. |
| 9 | + * If a [heuristic] is given, it is A*, otherwise, Dijkstra's algorithm. |
| 10 | + * |
| 11 | + * @param start The node or state to begin the search from. |
| 12 | + * @param neighbours A function that returns all possible transitions from a node and their associated cost. |
| 13 | + * The cost _must_ be a non-negative value. |
| 14 | + * @param goalFunction A predicate that determines whether a state is the search destination. |
| 15 | + * Search stops upon reaching the first node that evaluates to `true`. |
| 16 | + * @param heuristic A function that estimates the lower bound cost of reaching a destination from a given node. |
| 17 | + * Must never overestimate, otherwise the search result might not be the most cost-effective. |
| 18 | + * @param onVisit An optional callback invoked on each node visit, useful for debugging. |
| 19 | + * @param maximumCost An optional upper bound, prevents any transitions that would exceed this value. |
| 20 | + * @param trackPath If `true`, keeps track of intermediary nodes to be able to construct a search path. |
| 21 | + * If `false` _(the default)_, only the costs to reach the nodes are computed. |
| 22 | + * |
| 23 | + * @return The search result, or `null` if a suitable destination couldn't be reached. |
| 24 | + */ |
| 25 | + fun <T : Any> search( |
| 26 | + start: T, |
| 27 | + neighbours: (T) -> Iterable<Pair<T, Int>>, |
| 28 | + goalFunction: (T) -> Boolean, |
| 29 | + heuristic: (T) -> Int = { 0 }, |
| 30 | + onVisit: (T) -> Unit = {}, |
| 31 | + maximumCost: Int = Int.MAX_VALUE, |
| 32 | + trackPath: Boolean = false, |
| 33 | + ): SearchResult<T>? { |
| 34 | + require(maximumCost > 0) { "Maximum cost must be positive." } |
| 35 | + |
| 36 | + val previous = mutableMapOf<T, T>() |
| 37 | + val distance = mutableMapOf(start to 0) |
| 38 | + val visited = mutableSetOf<Pair<T, Int>>() |
| 39 | + |
| 40 | + @Suppress("UNUSED_DESTRUCTURED_PARAMETER_ENTRY") |
| 41 | + val queue = PriorityQueue(compareBy<Triple<T, Int, Int>> { (node, costSoFar, priority) -> priority }) |
| 42 | + queue.add(Triple(start, 0, 0)) |
| 43 | + |
| 44 | + if (trackPath) previous[start] = start |
| 45 | + |
| 46 | + while (queue.isNotEmpty()) { |
| 47 | + val (node, costSoFar, _) = queue.poll() |
| 48 | + if (!visited.add(node to costSoFar)) continue |
| 49 | + onVisit(node) |
| 50 | + if (goalFunction(node)) return SearchResult(start, node, distance, previous) |
| 51 | + |
| 52 | + for ((nextNode, nextCost) in neighbours(node)) { |
| 53 | + check(nextCost >= 0) { "Transition cost between nodes cannot be negative." } |
| 54 | + if (maximumCost - nextCost < costSoFar) continue |
| 55 | + |
| 56 | + val totalCost = costSoFar + nextCost |
| 57 | + |
| 58 | + if (totalCost > (distance[nextNode] ?: Int.MAX_VALUE)) continue |
| 59 | + |
| 60 | + distance[nextNode] = totalCost |
| 61 | + if (trackPath) previous[nextNode] = node |
| 62 | + |
| 63 | + val heuristicValue = heuristic(node) |
| 64 | + check(heuristicValue >= 0) { "Heuristic value must be positive." } |
| 65 | + queue.add(Triple(nextNode, totalCost, totalCost + heuristicValue)) |
| 66 | + } |
| 67 | + } |
| 68 | + |
| 69 | + return null |
| 70 | + } |
| 71 | + |
| 72 | + /** |
| 73 | + * The result of a [Pathfinding] search. |
| 74 | + * |
| 75 | + * @property start The node the search started from. |
| 76 | + * @property end The destination node, or the last visited node if an exhaustive flood search was requested. |
| 77 | + * @property cost The cost from [start] to [end], or the maximum cost if an exhaustive flood search was requested. |
| 78 | + * @property path The path from [start] to [end], each node associated with the running cost. |
| 79 | + * @property distance The cost from the [start] to all the visited intermediary nodes. |
| 80 | + * @property previous The previous node in the path of all the visited intermediary nodes. |
| 81 | + * Following it recursively will lead back to the [start] node. |
| 82 | + */ |
| 83 | + class SearchResult<out T> internal constructor( |
| 84 | + val start: T, |
| 85 | + val end: T, |
| 86 | + private val distance: Map<T, Int>, |
| 87 | + private val previous: Map<T, T>, |
| 88 | + ) { |
| 89 | + val cost: Int get() = distance.getValue(end) |
| 90 | + |
| 91 | + val path: List<Pair<T, Int>> by lazy { |
| 92 | + check(previous.isNotEmpty()) { "Cannot generate path as search was performed with `trackPath = false`." } |
| 93 | + buildList { |
| 94 | + var current = end |
| 95 | + while (true) { |
| 96 | + add(current to distance.getValue(current)) |
| 97 | + val previous = previous.getValue(current) |
| 98 | + if (previous == current) break |
| 99 | + current = previous |
| 100 | + } |
| 101 | + }.asReversed() |
| 102 | + } |
| 103 | + } |
| 104 | +} |
0 commit comments