Skip to content

Fix mypy errors at bidirectional_a_star #4556

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 23 additions & 19 deletions graphs/bidirectional_a_star.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
from math import sqrt

# 1 for manhattan, 0 for euclidean
from typing import Optional

HEURISTIC = 0

grid = [
Expand All @@ -22,6 +24,8 @@

delta = [[-1, 0], [0, -1], [1, 0], [0, 1]] # up, left, down, right

TPosition = tuple[int, int]


class Node:
"""
Expand All @@ -39,7 +43,15 @@ class Node:
True
"""

def __init__(self, pos_x, pos_y, goal_x, goal_y, g_cost, parent):
def __init__(
self,
pos_x: int,
pos_y: int,
goal_x: int,
goal_y: int,
g_cost: int,
parent: Optional[Node],
) -> None:
self.pos_x = pos_x
self.pos_y = pos_y
self.pos = (pos_y, pos_x)
Expand All @@ -61,7 +73,7 @@ def calculate_heuristic(self) -> float:
else:
return sqrt(dy ** 2 + dx ** 2)

def __lt__(self, other) -> bool:
def __lt__(self, other: Node) -> bool:
return self.f_cost < other.f_cost


Expand All @@ -81,23 +93,22 @@ class AStar:
(4, 3), (4, 4), (5, 4), (5, 5), (6, 5), (6, 6)]
"""

def __init__(self, start, goal):
def __init__(self, start: TPosition, goal: TPosition):
self.start = Node(start[1], start[0], goal[1], goal[0], 0, None)
self.target = Node(goal[1], goal[0], goal[1], goal[0], 99999, None)

self.open_nodes = [self.start]
self.closed_nodes = []
self.closed_nodes: list[Node] = []

self.reached = False

def search(self) -> list[tuple[int]]:
def search(self) -> list[TPosition]:
while self.open_nodes:
# Open Nodes are sorted using __lt__
self.open_nodes.sort()
current_node = self.open_nodes.pop(0)

if current_node.pos == self.target.pos:
self.reached = True
return self.retrace_path(current_node)

self.closed_nodes.append(current_node)
Expand All @@ -118,8 +129,7 @@ def search(self) -> list[tuple[int]]:
else:
self.open_nodes.append(better_node)

if not (self.reached):
return [(self.start.pos)]
return [self.start.pos]

def get_successors(self, parent: Node) -> list[Node]:
"""
Expand Down Expand Up @@ -147,7 +157,7 @@ def get_successors(self, parent: Node) -> list[Node]:
)
return successors

def retrace_path(self, node: Node) -> list[tuple[int]]:
def retrace_path(self, node: Optional[Node]) -> list[TPosition]:
"""
Retrace the path from parents to parents until start node
"""
Expand All @@ -173,20 +183,19 @@ class BidirectionalAStar:
(2, 5), (3, 5), (4, 5), (5, 5), (5, 6), (6, 6)]
"""

def __init__(self, start, goal):
def __init__(self, start: TPosition, goal: TPosition) -> None:
self.fwd_astar = AStar(start, goal)
self.bwd_astar = AStar(goal, start)
self.reached = False

def search(self) -> list[tuple[int]]:
def search(self) -> list[TPosition]:
while self.fwd_astar.open_nodes or self.bwd_astar.open_nodes:
self.fwd_astar.open_nodes.sort()
self.bwd_astar.open_nodes.sort()
current_fwd_node = self.fwd_astar.open_nodes.pop(0)
current_bwd_node = self.bwd_astar.open_nodes.pop(0)

if current_bwd_node.pos == current_fwd_node.pos:
self.reached = True
return self.retrace_bidirectional_path(
current_fwd_node, current_bwd_node
)
Expand Down Expand Up @@ -220,12 +229,11 @@ def search(self) -> list[tuple[int]]:
else:
astar.open_nodes.append(better_node)

if not self.reached:
return [self.fwd_astar.start.pos]
return [self.fwd_astar.start.pos]

def retrace_bidirectional_path(
self, fwd_node: Node, bwd_node: Node
) -> list[tuple[int]]:
) -> list[TPosition]:
fwd_path = self.fwd_astar.retrace_path(fwd_node)
bwd_path = self.bwd_astar.retrace_path(bwd_node)
bwd_path.pop()
Expand All @@ -236,9 +244,6 @@ def retrace_bidirectional_path(

if __name__ == "__main__":
# all coordinates are given in format [y,x]
import doctest

doctest.testmod()
init = (0, 0)
goal = (len(grid) - 1, len(grid[0]) - 1)
for elem in grid:
Expand All @@ -252,6 +257,5 @@ def retrace_bidirectional_path(

bd_start_time = time.time()
bidir_astar = BidirectionalAStar(init, goal)
path = bidir_astar.search()
bd_end_time = time.time() - bd_start_time
print(f"BidirectionalAStar execution time = {bd_end_time:f} seconds")