140 lines
4.8 KiB
Python
140 lines
4.8 KiB
Python
from __future__ import annotations

import heapq
from math import inf, sqrt
from typing import Iterator, Optional

from src.graph.edge import Edge
from src.graph.node import Node
|
|
|
|
|
|
class Graph:
    """Undirected graph of positioned nodes, stored as index-addressed lists.

    Edges refer to their endpoints by node index (``edge.start`` /
    ``edge.end``) and are traversed in both directions. The mutating methods
    keep ``nodes[i].index == i`` and ``edges[j].index == j`` in sync.
    """

    def __init__(self):
        self.edges: list[Edge] = []
        self.nodes: list[Node] = []

    @staticmethod
    def _length_between(n1: Node, n2: Node) -> float:
        """Return the Euclidean distance between two nodes in the x/z plane."""
        return sqrt((n1.x - n2.x) ** 2 + (n1.z - n2.z) ** 2)

    @staticmethod
    def _reindex(items: list) -> None:
        """Set ``item.index`` to each item's current position in *items*."""
        for position, item in enumerate(items):
            item.index = position

    def add_node(self, x: int, z: int, name: str = "") -> None:
        """Append a node at ``(x, z)``; its index is its position in ``nodes``."""
        self.nodes.append(Node(x, z, len(self.nodes), name))

    def add_edge(self, start_index: int, end_index: int, auto_length: bool = True) -> None:
        """Append an edge between two node indices.

        Args:
            start_index: Index of one endpoint in ``nodes``.
            end_index: Index of the other endpoint.
            auto_length: When true, the edge length is the Euclidean distance
                between the endpoints. When false it is left at 0, as ``load``
                does before calling :meth:`recompute_lengths`.
        """
        length = 0
        if auto_length:
            length = self._length_between(self.nodes[start_index], self.nodes[end_index])
        self.edges.append(Edge(start_index, end_index, length, len(self.edges)))

    def delete_edge(self, edge: Edge) -> None:
        """Remove *edge* and renumber the remaining edges.

        Raises:
            ValueError: If *edge* is not in the graph.
        """
        self.edges.remove(edge)
        # Positional renumbering is O(E) and, unlike list.index, cannot pick
        # the wrong one of two equal-comparing edges.
        self._reindex(self.edges)

    def delete_node(self, node: Node) -> None:
        """Remove *node* and every edge touching it, then renumber everything.

        Raises:
            ValueError: If *node* is not in the graph; nothing is modified then.
        """
        removed = node.index
        # Remove the node first so an unknown node fails before any edge is touched.
        self.nodes.remove(node)

        surviving = []
        for edge in self.edges:
            if removed in (edge.start, edge.end):
                continue
            # Node positions above the removed one shift down by one.
            if edge.start > removed:
                edge.start -= 1
            if edge.end > removed:
                edge.end -= 1
            surviving.append(edge)
        # Slice-assign so the existing list object (and references to it) is kept.
        self.edges[:] = surviving

        self._reindex(self.edges)
        self._reindex(self.nodes)

    def recompute_lengths(self) -> None:
        """Set every edge's length to the Euclidean distance between its endpoints."""
        for edge in self.edges:
            edge.length = self._length_between(self.nodes[edge.start], self.nodes[edge.end])

    def number_of_nodes(self) -> int:
        """Return how many nodes the graph holds."""
        return len(self.nodes)

    def get_edge(self, node_1: int, node_2: int) -> int:
        """Return the index of the first edge joining the two nodes in either direction, or -1."""
        wanted = ((node_1, node_2), (node_2, node_1))
        for position, edge in enumerate(self.edges):
            if (edge.start, edge.end) in wanted:
                return position
        return -1

    def get_edge_nodes(self, edge: Edge) -> tuple[Node, Node]:
        """Return the ``(start, end)`` node objects of *edge*."""
        return self.nodes[edge.start], self.nodes[edge.end]

    def get_edge_center(self, edge_index: int) -> tuple[float, float]:
        """Return the ``(x, z)`` midpoint of the edge at *edge_index*."""
        edge = self.edges[edge_index]
        start_n = self.nodes[edge.start]
        end_n = self.nodes[edge.end]
        return (start_n.x + end_n.x) / 2, (start_n.z + end_n.z) / 2

    def edges_adjacent_to(self, node_i: int) -> Iterator[Edge]:
        """Lazily yield, in edge order, the edges that have *node_i* as an endpoint."""
        return (edge for edge in self.edges if edge.start == node_i or edge.end == node_i)

    def edge_exists(self, node_1: int, node_2: int) -> bool:
        """Return whether any edge joins the two nodes (in either direction)."""
        return self.get_edge(node_1, node_2) != -1

    def dijkstra(self, source_index: int, target_index: int) -> Optional[list[int]]:
        """Return a shortest path between two nodes as a list of node indices.

        Edges are traversed in both directions and weighted by ``edge.length``.
        Lengths are assumed non-negative, as Dijkstra's algorithm requires.

        Returns:
            ``[source_index, ..., target_index]`` for a reachable target
            (``[source_index]`` when the two coincide), ``[]`` when the target
            is unreachable, or ``None`` when either index is out of range.

        Among equally short paths, the first one found is kept. Nodes are
        settled in (distance, index) order and edges are relaxed in edge order
        with a strict improvement test.
        """
        n = len(self.nodes)
        if not (0 <= source_index < n and 0 <= target_index < n):
            return None

        # Adjacency lists are built once, in edge order, so each settled node
        # only looks at its own edges instead of rescanning the whole list.
        neighbours: list[list[tuple[int, float]]] = [[] for _ in range(n)]
        for edge in self.edges:
            neighbours[edge.start].append((edge.end, edge.length))
            if edge.end != edge.start:
                neighbours[edge.end].append((edge.start, edge.length))

        distance = [inf] * n
        distance[source_index] = 0
        previous = [-1] * n
        settled = [False] * n
        # Heap entries are (distance, node); ties pop the lowest node index first.
        frontier = [(0, source_index)]

        while frontier:
            dist_here, current = heapq.heappop(frontier)
            if settled[current]:
                continue  # stale entry superseded by a shorter distance
            if current == target_index:
                break
            settled[current] = True
            for neighbour, length in neighbours[current]:
                candidate = dist_here + length
                if not settled[neighbour] and distance[neighbour] > candidate:
                    distance[neighbour] = candidate
                    previous[neighbour] = current
                    heapq.heappush(frontier, (candidate, neighbour))

        if distance[target_index] == inf:
            return []

        # Walk the predecessor chain back to the source instead of copying a
        # full path list on every relaxation.
        path = [target_index]
        while path[-1] != source_index:
            path.append(previous[path[-1]])
        path.reverse()
        return path

    def save(self, path: str) -> None:
        """Write the graph to *path* in the plain-text format read by :meth:`load`.

        The file holds one ``n <x> <z> <name>`` line per node, a blank line,
        then one ``e <start> <end>`` line per edge. Edge lengths are not
        written; ``load`` recomputes them. Node names must not contain line
        breaks, or the file cannot be read back.
        """
        with open(path, "w", encoding="utf-8") as f:
            for node in self.nodes:
                f.write(f"n {node.x} {node.z} {node.name}\n")
            f.write("\n")
            for edge in self.edges:
                f.write(f"e {edge.start} {edge.end}\n")

    @staticmethod
    def _parse_node_fields(values: str) -> tuple[int, int, str]:
        """Parse the ``<x> <z> <name>`` payload of a node line.

        The name may contain spaces. A missing name (for example when trailing
        whitespace was stripped from ``"n 1 2 "``) becomes ``""``.
        """
        x, z, *rest = values.split(" ", 2)
        return int(x), int(z), rest[0] if rest else ""

    @staticmethod
    def load(path: str) -> Graph:
        """Build a graph from a file written by :meth:`save`.

        Blank lines are skipped. Lines of the form ``<tag> ...`` whose tag is
        neither ``n`` nor ``e`` are ignored.

        Raises:
            OSError: If *path* cannot be opened.
            ValueError: On a malformed node or edge line.
            IndexError: If an edge refers to a node index that does not exist.
        """
        graph = Graph()
        with open(path, "r", encoding="utf-8") as f:
            lines = f.read().splitlines()

        for line in lines:
            if not line.strip():
                continue

            entry_type, values = line.split(" ", 1)
            if entry_type == "n":
                x, z, name = Graph._parse_node_fields(values)
                graph.add_node(x, z, name)
            elif entry_type == "e":
                start, end = values.split()
                graph.add_edge(int(start), int(end), False)

        # Edges were added with length 0; fill in real lengths in one pass.
        graph.recompute_lengths()
        return graph
|