Compare commits
	
		
			4 Commits
		
	
	
		
			8542ee81e7
			...
			8b7927a3c5
		
	
	| Author | SHA1 | Date | |
|---|---|---|---|
| 8b7927a3c5 | |||
| 62de92e7a2 | |||
| 8ad97785b8 | |||
| db112ada4c | 
| @@ -8,6 +8,7 @@ from typing import Type | ||||
|  | ||||
| class CommandType(IntEnum): | ||||
|     CAR_CONTROL = 0 | ||||
|     RECORDING = 1 | ||||
|  | ||||
|  | ||||
| class CarControl(IntEnum): | ||||
| @@ -64,3 +65,20 @@ class ControlCommand(Command): | ||||
|         active: bool = (value & 1) == 1 | ||||
|         control: int = value >> 1 | ||||
|         return ControlCommand(CarControl(control), active) | ||||
|  | ||||
|  | ||||
| class RecordingCommand(Command): | ||||
|     TYPE = CommandType.RECORDING | ||||
|     __match_args__ = ("state",) | ||||
|  | ||||
|     def __init__(self, state: bool) -> None: | ||||
|         super().__init__() | ||||
|         self.state: bool = state | ||||
|  | ||||
|     def get_payload(self) -> bytes: | ||||
|         return struct.pack(">B", self.state) | ||||
|  | ||||
|     @classmethod | ||||
|     def from_payload(cls, payload: bytes) -> Command: | ||||
|         state: bool = struct.unpack(">B", payload)[0] | ||||
|         return RecordingCommand(state) | ||||
|   | ||||
							
								
								
									
										49
									
								
								src/record_file.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										49
									
								
								src/record_file.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,49 @@ | ||||
| from pathlib import Path | ||||
| import struct | ||||
| import time | ||||
| from typing import Literal | ||||
|  | ||||
| from src.snapshot import Snapshot | ||||
|  | ||||
|  | ||||
| class RecordFile: | ||||
|     VERSION = 1 | ||||
|  | ||||
|     def __init__(self, path: str | Path, mode: Literal["w", "r"]) -> None: | ||||
|         self.path: str | Path = path | ||||
|         self.mode: Literal["w", "r"] = mode | ||||
|         self.file = open(self.path, self.mode + "b") | ||||
|  | ||||
|     def __enter__(self): | ||||
|         return self | ||||
|  | ||||
|     def __exit__(self, type, value, traceback): | ||||
|         self.file.close() | ||||
|  | ||||
|     def write_header(self, n_snapshots: int): | ||||
|         data: bytes = struct.pack( | ||||
|             ">IId", self.VERSION, n_snapshots, time.time()) | ||||
|         self.file.write(data) | ||||
|  | ||||
|     def write_snapshots(self, snapshots: list[Snapshot]): | ||||
|         self.write_header(len(snapshots)) | ||||
|         for snapshot in snapshots: | ||||
|             data: bytes = snapshot.pack() | ||||
|             self.file.write(struct.pack(">I", len(data)) + data) | ||||
|  | ||||
|     def read_snapshots(self) -> list[Snapshot]: | ||||
|         version: int = struct.unpack(">I", self.file.read(4))[0] | ||||
|         if version != self.VERSION: | ||||
|             raise ValueError( | ||||
|                 f"Cannot parse record file with format version {version} (current version: {self.VERSION})") | ||||
|  | ||||
|         n_snapshots: int | ||||
|         timestamp: float | ||||
|         n_snapshots, timestamp = struct.unpack(">Id", self.file.read(12)) | ||||
|         snapshots: list[Snapshot] = [] | ||||
|  | ||||
|         for _ in range(n_snapshots): | ||||
|             size: int = struct.unpack(">I", self.file.read(4))[0] | ||||
|             snapshots.append(Snapshot.unpack(self.file.read(size))) | ||||
|  | ||||
|         return snapshots | ||||
							
								
								
									
										115
									
								
								src/recorder.py
									
									
									
									
									
								
							
							
						
						
									
										115
									
								
								src/recorder.py
									
									
									
									
									
								
							| @@ -1,11 +1,16 @@ | ||||
| import os | ||||
| from pathlib import Path | ||||
| import socket | ||||
| import struct | ||||
| from typing import Optional | ||||
|  | ||||
| from PyQt6 import uic | ||||
| from PyQt6.QtCore import QObject, Qt, QThread, QTimer, pyqtSignal, pyqtSlot | ||||
| from PyQt6.QtCore import QObject, QThread, QTimer, pyqtSignal, pyqtSlot | ||||
| from PyQt6.QtGui import QKeyEvent | ||||
| from PyQt6.QtWidgets import QMainWindow | ||||
|  | ||||
| from src.command import CarControl, Command, ControlCommand | ||||
| from src.command import CarControl, Command, ControlCommand, RecordingCommand | ||||
| from src.record_file import RecordFile | ||||
| from src.recorder_ui import Ui_Recorder | ||||
| from src.snapshot import Snapshot | ||||
|  | ||||
| @@ -18,9 +23,9 @@ class RecorderClient(QObject): | ||||
|         super().__init__() | ||||
|         self.host: str = host | ||||
|         self.port: int = port | ||||
|         self.socket: socket.socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) | ||||
|         self.timer: QTimer = QTimer(self) | ||||
|         self.timer.timeout.connect(self.poll_socket) | ||||
|         self.socket: socket.socket = socket.socket( | ||||
|             socket.AF_INET, socket.SOCK_STREAM) | ||||
|         self.timer: Optional[QTimer] = None | ||||
|         self.connected: bool = False | ||||
|  | ||||
|     @pyqtSlot() | ||||
| @@ -28,8 +33,10 @@ class RecorderClient(QObject): | ||||
|         self.socket.connect((self.host, self.port)) | ||||
|         self.socket.setblocking(False) | ||||
|         self.connected = True | ||||
|         self.timer = QTimer(self) | ||||
|         self.timer.timeout.connect(self.poll_socket) | ||||
|         self.timer.start(50) | ||||
|         print(f"Connected to server") | ||||
|         print("Connected to server") | ||||
|  | ||||
|     def poll_socket(self): | ||||
|         buffer: bytes = b"" | ||||
| @@ -78,15 +85,37 @@ class RecorderClient(QObject): | ||||
|     @pyqtSlot() | ||||
|     def shutdown(self): | ||||
|         print("Shutting down client") | ||||
|         self.timer.stop() | ||||
|         if self.timer is not None: | ||||
|             self.timer.stop() | ||||
|             self.timer = None | ||||
|         self.connected = False | ||||
|         self.socket.close() | ||||
|  | ||||
|  | ||||
| class ThreadedSaver(QThread): | ||||
|     def __init__(self, path: str | Path, snapshots: list[Snapshot]): | ||||
|         super().__init__() | ||||
|         self.path: str | Path = path | ||||
|         self.snapshots: list[Snapshot] = snapshots | ||||
|  | ||||
|     def run(self): | ||||
|         with RecordFile(self.path, "w") as f: | ||||
|             f.write_snapshots(self.snapshots) | ||||
|  | ||||
|  | ||||
| class RecorderWindow(Ui_Recorder, QMainWindow): | ||||
|     close_signal: pyqtSignal = pyqtSignal() | ||||
|     send_signal: pyqtSignal = pyqtSignal(object) | ||||
|  | ||||
|     SAVE_DIR: Path = Path(__file__).parent.parent / "records" | ||||
|  | ||||
|     COMMAND_DIRECTIONS: dict[str, CarControl] = { | ||||
|         "w": CarControl.FORWARD, | ||||
|         "s": CarControl.BACKWARD, | ||||
|         "d": CarControl.RIGHT, | ||||
|         "a": CarControl.LEFT, | ||||
|     } | ||||
|  | ||||
|     def __init__(self, host: str, port: int) -> None: | ||||
|         super().__init__() | ||||
|  | ||||
| @@ -102,13 +131,6 @@ class RecorderWindow(Ui_Recorder, QMainWindow): | ||||
|  | ||||
|         uic.load_ui.loadUi("src/recorder.ui", self) | ||||
|  | ||||
|         self.command_directions = { | ||||
|             "w": CarControl.FORWARD, | ||||
|             "s": CarControl.BACKWARD, | ||||
|             "d": CarControl.RIGHT, | ||||
|             "a": CarControl.LEFT, | ||||
|         } | ||||
|  | ||||
|         self.forwardButton.pressed.connect( | ||||
|             lambda: self.on_car_controlled(CarControl.FORWARD, True) | ||||
|         ) | ||||
| @@ -146,16 +168,39 @@ class RecorderWindow(Ui_Recorder, QMainWindow): | ||||
|  | ||||
|         self.saveRecordButton.clicked.connect(self.save_record) | ||||
|  | ||||
|         self.saving_worker: Optional[ThreadedSaver] = None | ||||
|         self.recording = False | ||||
|  | ||||
|         self.recorded_data = [] | ||||
|         self.snapshots: list[Snapshot] = [] | ||||
|         self.client_thread.start() | ||||
|  | ||||
|     def on_car_controlled(self, control: CarControl, active: bool): | ||||
|         self.send_command(ControlCommand(control, active)) | ||||
|  | ||||
|     def keyPressEvent(self, event):  # type: ignore | ||||
|         if event.isAutoRepeat(): | ||||
|             return | ||||
|  | ||||
|         if isinstance(event, QKeyEvent): | ||||
|             key_text = event.text() | ||||
|             ctrl: Optional[CarControl] = self.COMMAND_DIRECTIONS.get(key_text) | ||||
|             if ctrl is not None: | ||||
|                 self.on_car_controlled(ctrl, True) | ||||
|  | ||||
|     def keyReleaseEvent(self, event):  # type: ignore | ||||
|         if event.isAutoRepeat(): | ||||
|             return | ||||
|         if isinstance(event, QKeyEvent): | ||||
|             key_text = event.text() | ||||
|             ctrl: Optional[CarControl] = self.COMMAND_DIRECTIONS.get(key_text) | ||||
|             if ctrl is not None: | ||||
|                 self.on_car_controlled(ctrl, False) | ||||
|  | ||||
|     def toggle_record(self): | ||||
|         pass | ||||
|         self.recording = not self.recording | ||||
|         self.recordDataButton.setText( | ||||
|             "Recording..." if self.recording else "Record") | ||||
|         self.send_command(RecordingCommand(self.recording)) | ||||
|  | ||||
|     def rollback(self): | ||||
|         pass | ||||
| @@ -167,12 +212,44 @@ class RecorderWindow(Ui_Recorder, QMainWindow): | ||||
|         ) | ||||
|  | ||||
|     def save_record(self): | ||||
|         pass | ||||
|         if self.saving_worker is not None: | ||||
|             print("Already saving !") | ||||
|             return | ||||
|  | ||||
|         if len(self.snapshots) == 0: | ||||
|             print("No data to save !") | ||||
|             return | ||||
|  | ||||
|         if self.recording: | ||||
|             self.toggle_record() | ||||
|  | ||||
|         self.saveRecordButton.setText("Saving ...") | ||||
|  | ||||
|         self.SAVE_DIR.mkdir(exist_ok=True) | ||||
|  | ||||
|         record_name: str = "record_%d.rec" | ||||
|         fid = 0 | ||||
|         while os.path.exists(self.SAVE_DIR / (record_name % fid)): | ||||
|             fid += 1 | ||||
|  | ||||
|         self.saving_worker = ThreadedSaver( | ||||
|             self.SAVE_DIR / (record_name % fid), self.snapshots) | ||||
|         self.snapshots = [] | ||||
|         self.nbrSnapshotSaved.setText("0") | ||||
|         self.saving_worker.finished.connect(self.on_record_save_done) | ||||
|         self.saving_worker.start() | ||||
|  | ||||
|     def on_record_save_done(self): | ||||
|         if self.saving_worker is None: | ||||
|             return | ||||
|         print("Recorded data saved to", self.saving_worker.path) | ||||
|         self.saving_worker = None | ||||
|         self.saveRecordButton.setText("Save") | ||||
|  | ||||
|     @pyqtSlot(Snapshot) | ||||
|     def on_snapshot_received(self, snapshot: Snapshot): | ||||
|         self.recorded_data.append(snapshot) | ||||
|         self.nbrSnapshotSaved.setText(str(len(self.recorded_data))) | ||||
|         self.snapshots.append(snapshot) | ||||
|         self.nbrSnapshotSaved.setText(str(len(self.snapshots))) | ||||
|  | ||||
|     def shutdown(self): | ||||
|         self.close_signal.emit() | ||||
|   | ||||
| @@ -6,7 +6,9 @@ import struct | ||||
| import threading | ||||
| from typing import TYPE_CHECKING, Optional | ||||
|  | ||||
| from src.command import CarControl, Command, ControlCommand | ||||
| from src.command import CarControl, Command, ControlCommand, RecordingCommand | ||||
| from src.snapshot import Snapshot | ||||
| from src.utils import RepeatTimer | ||||
|  | ||||
| if TYPE_CHECKING: | ||||
|     from src.car import Car | ||||
| @@ -23,10 +25,13 @@ class RemoteController: | ||||
|         CarControl.RIGHT: "right", | ||||
|     } | ||||
|  | ||||
|     SNAPSHOT_INTERVAL = 0.1 | ||||
|  | ||||
|     def __init__(self, car: Car, port: int = DEFAULT_PORT) -> None: | ||||
|         self.car: Car = car | ||||
|         self.port: int = port | ||||
|         self.server: socket.socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) | ||||
|         self.server: socket.socket = socket.socket( | ||||
|             socket.AF_INET, socket.SOCK_STREAM) | ||||
|         self.server_thread: threading.Thread = threading.Thread( | ||||
|             target=self.wait_for_connections, daemon=True | ||||
|         ) | ||||
| @@ -34,6 +39,10 @@ class RemoteController: | ||||
|         self.queue: queue.Queue[Command] = queue.Queue() | ||||
|         self.client_thread: Optional[threading.Thread] = None | ||||
|         self.client: Optional[socket.socket] = None | ||||
|         self.snapshot_timer: RepeatTimer = RepeatTimer( | ||||
|             interval=self.SNAPSHOT_INTERVAL, function=self.take_snapshot) | ||||
|         self.snapshot_timer.start() | ||||
|         self.recording: bool = False | ||||
|  | ||||
|     @property | ||||
|     def is_connected(self) -> bool: | ||||
| @@ -56,6 +65,7 @@ class RemoteController: | ||||
|         if self.client: | ||||
|             self.client.close() | ||||
|         self.server.close() | ||||
|         self.snapshot_timer.cancel() | ||||
|         self.running = False | ||||
|  | ||||
|     def on_client_connected(self, conn: socket.socket): | ||||
| @@ -107,6 +117,18 @@ class RemoteController: | ||||
|         match command: | ||||
|             case ControlCommand(control, active): | ||||
|                 self.set_control(control, active) | ||||
|             case RecordingCommand(state): | ||||
|                 self.recording = state | ||||
|  | ||||
|     def set_control(self, control: CarControl, active: bool): | ||||
|         setattr(self.car, self.CONTROL_ATTRIBUTES[control], active) | ||||
|  | ||||
|     def take_snapshot(self): | ||||
|         if self.client is None: | ||||
|             return | ||||
|         if not self.recording: | ||||
|             return | ||||
|  | ||||
|         snapshot: Snapshot = Snapshot.from_car(self.car) | ||||
|         payload: bytes = snapshot.pack() | ||||
|         self.client.sendall(struct.pack(">I", len(payload)) + payload) | ||||
|   | ||||
| @@ -2,12 +2,15 @@ from __future__ import annotations | ||||
|  | ||||
| import struct | ||||
| from dataclasses import dataclass, field | ||||
| from typing import Optional | ||||
| from typing import TYPE_CHECKING, Optional | ||||
|  | ||||
| import numpy as np | ||||
|  | ||||
| from src.vec import Vec | ||||
|  | ||||
| if TYPE_CHECKING: | ||||
|     from src.car import Car | ||||
|  | ||||
|  | ||||
| def iter_unpack(format, data): | ||||
|     nbr_bytes = struct.calcsize(format) | ||||
| @@ -20,7 +23,8 @@ class Snapshot: | ||||
|     position: Vec = field(default_factory=Vec) | ||||
|     direction: Vec = field(default_factory=Vec) | ||||
|     speed: float = 0 | ||||
|     raycast_distances: list[float] | tuple[float, ...] = field(default_factory=list) | ||||
|     raycast_distances: list[float] | tuple[float, ...] = field( | ||||
|         default_factory=list) | ||||
|     image: Optional[np.ndarray] = None | ||||
|  | ||||
|     def pack(self): | ||||
| @@ -36,10 +40,12 @@ class Snapshot: | ||||
|         ) | ||||
|  | ||||
|         nbr_raycasts: int = len(self.raycast_distances) | ||||
|         data += struct.pack(f">B{nbr_raycasts}f", nbr_raycasts, *self.raycast_distances) | ||||
|         data += struct.pack(f">B{nbr_raycasts}f", | ||||
|                             nbr_raycasts, *self.raycast_distances) | ||||
|  | ||||
|         if self.image is not None: | ||||
|             data += struct.pack(">II", self.image.shape[0], self.image.shape[1]) | ||||
|             data += struct.pack(">II", | ||||
|                                 self.image.shape[0], self.image.shape[1]) | ||||
|             data += self.image.tobytes() | ||||
|         else: | ||||
|             data += struct.pack(">II", 0, 0) | ||||
| @@ -72,3 +78,19 @@ class Snapshot: | ||||
|             raycast_distances=raycast_distances, | ||||
|             image=image, | ||||
|         ) | ||||
|  | ||||
|     @staticmethod | ||||
|     def from_car(car: Car) -> Snapshot: | ||||
|         return Snapshot( | ||||
|             controls=( | ||||
|                 car.forward, | ||||
|                 car.backward, | ||||
|                 car.left, | ||||
|                 car.right | ||||
|             ), | ||||
|             position=car.pos.copy(), | ||||
|             direction=car.direction.copy(), | ||||
|             speed=car.speed, | ||||
|             raycast_distances=car.rays.copy(), | ||||
|             image=None | ||||
|         ) | ||||
|   | ||||
							
								
								
									
										10
									
								
								src/utils.py
									
									
									
									
									
								
							
							
						
						
									
										10
									
								
								src/utils.py
									
									
									
									
									
								
							| @@ -1,10 +1,12 @@ | ||||
| import os | ||||
| from pathlib import Path | ||||
| from threading import Timer | ||||
| from typing import Optional | ||||
|  | ||||
| from src.vec import Vec | ||||
|  | ||||
| ROOT = Path(os.path.abspath(os.path.join(os.path.dirname(__file__), os.pardir))) | ||||
| ROOT = Path(os.path.abspath(os.path.join( | ||||
|     os.path.dirname(__file__), os.pardir))) | ||||
|  | ||||
|  | ||||
| def orientation(a: Vec, b: Vec, c: Vec) -> float: | ||||
| @@ -59,3 +61,9 @@ def get_segments_intersection(a1: Vec, a2: Vec, b1: Vec, b2: Vec) -> Optional[Ve | ||||
|     if intersection.within(a1, a2) and intersection.within(b1, b2): | ||||
|         return intersection | ||||
|     return None | ||||
|  | ||||
|  | ||||
| class RepeatTimer(Timer): | ||||
|     def run(self): | ||||
|         while not self.finished.wait(self.interval): | ||||
|             self.function(*self.args, **self.kwargs) | ||||
|   | ||||
| @@ -8,6 +8,9 @@ class Vec: | ||||
|         self.x: float = x | ||||
|         self.y: float = y | ||||
|  | ||||
|     def copy(self) -> Vec: | ||||
|         return Vec(self.x, self.y) | ||||
|  | ||||
|     def __add__(self, other: float | Vec) -> Vec: | ||||
|         if isinstance(other, Vec): | ||||
|             return Vec(self.x + other.x, self.y + other.y) | ||||
|   | ||||
		Reference in New Issue
	
	Block a user