diff --git a/utils/dijkstra.py b/utils/dijkstra.py index 846027b..8f8ba8d 100644 --- a/utils/dijkstra.py +++ b/utils/dijkstra.py @@ -1,86 +1,88 @@ from math import inf +from typing import Iterator, Optional -class Edge: - - length = 0 - - def __init__(self, start, end, length): - self.length = length - self.start = start - self.end = end class Node: + def __init__(self, x: int, y: int): + self.x: int = x + self.y: int = y + + +class Edge: + def __init__(self, start: int, end: int, length: float): + self.length: float = length + self.start: int = start + self.end: int = end - def __init__(self, xpos, ypos): - self.xpos = xpos - self.ypos = ypos class Graph: - def __init__(self): - self.edges = list() - self.nodes = list() + self.edges: list[Edge] = [] + self.nodes: list[Node] = [] - def add_node(self, xpos, ypos): - self.nodes.append(Node(xpos, ypos)) - - def add_edge(self, start_index, end_index, length): + def add_node(self, x: int, y: int) -> None: + self.nodes.append(Node(x, y)) + + def add_edge(self, start_index: int, end_index: int, length: float) -> None: self.edges.append(Edge(start_index, end_index, length)) + def edges_adjacent_to(self, node_i: int) -> Iterator[Edge]: + return filter(lambda e: e.start == node_i or e.end == node_i, self.edges) -def Dijkstra(graph, source_index, target_index): + def dijkstra(self, source_index: int, target_index: int) -> Optional[list[int]]: + n = len(self.nodes) - n = len(graph.nodes) + if source_index < 0 or source_index >= n: + return None - if (target_index >= n): - return None + if target_index < 0 or target_index >= n: + return None - unvisited = list(range(n)) + unvisited = list(range(n)) - distances_from_start = [inf] * n - distances_from_start[source_index] = 0 + distances_from_start = [inf] * n + distances_from_start[source_index] = 0 - node_sequences = [list() for i in range(n)] - node_sequences[source_index] = [source_index] + node_sequences = [[] for _ in range(n)] + node_sequences[source_index] = [source_index] - while(True): - try: - current_index = min(unvisited, key = lambda i: distances_from_start[i]) - except ValueError: - break - - if current_index == target_index: - break + while True: + current_index = min(unvisited, key=lambda i: distances_from_start[i]) - unvisited.remove(current_index) + if current_index == target_index: + break - for edge in filter( - lambda e: e.start == current_index or e.end == current_index, - graph.edges): - - start = current_index - end = edge.end if edge.start == current_index else edge.start + unvisited.remove(current_index) - if end in unvisited and distances_from_start[end] > distances_from_start[start] + edge.length: - distances_from_start[end] = distances_from_start[start] + edge.length - node_sequences[end] = node_sequences[start].copy() - node_sequences[end].append(end) + for edge in self.edges_adjacent_to(current_index): + start = current_index + end = edge.end if edge.start == current_index else edge.start - return node_sequences[target_index] + if end in unvisited and distances_from_start[end] > distances_from_start[start] + edge.length: + distances_from_start[end] = distances_from_start[start] + edge.length + node_sequences[end] = node_sequences[start].copy() + node_sequences[end].append(end) + + return node_sequences[target_index] -graph = Graph() +def main() -> None: + graph = Graph() -graph.add_node(1, 2) -graph.add_node(4, 7) -graph.add_node(3,1) -graph.add_node(-2,0) -graph.add_node(0,0) + graph.add_node(1, 2) + graph.add_node(4, 7) + graph.add_node(3, 1) + graph.add_node(-2, 0) + graph.add_node(0, 0) -graph.add_edge(0, 1, 1) -graph.add_edge(1, 2, 2) -graph.add_edge(2, 3, 3) -graph.add_edge(3, 0, 1) -graph.add_edge(1, 3, 3) + graph.add_edge(0, 1, 1) + graph.add_edge(1, 2, 2) + graph.add_edge(2, 3, 3) + graph.add_edge(3, 0, 1) + graph.add_edge(1, 3, 3) -print(Dijkstra(graph, 0, 5)) \ No newline at end of file + print(graph.dijkstra(0, 3)) + + +if __name__ == "__main__": + main()