feat: add snapshot recording

This commit is contained in:
2025-10-24 17:22:29 +02:00
parent 8542ee81e7
commit db112ada4c
6 changed files with 89 additions and 12 deletions

View File

@@ -8,6 +8,7 @@ from typing import Type
class CommandType(IntEnum): class CommandType(IntEnum):
CAR_CONTROL = 0 CAR_CONTROL = 0
RECORDING = 1
class CarControl(IntEnum): class CarControl(IntEnum):
@@ -64,3 +65,20 @@ class ControlCommand(Command):
active: bool = (value & 1) == 1 active: bool = (value & 1) == 1
control: int = value >> 1 control: int = value >> 1
return ControlCommand(CarControl(control), active) return ControlCommand(CarControl(control), active)
class RecordingCommand(Command):
TYPE = CommandType.RECORDING
__match_args__ = ("state",)
def __init__(self, state: bool) -> None:
super().__init__()
self.state: bool = state
def get_payload(self) -> bytes:
return struct.pack(">B", self.state)
@classmethod
def from_payload(cls, payload: bytes) -> Command:
state: bool = struct.unpack(">B", payload)[0]
return RecordingCommand(state)

View File

@@ -2,10 +2,10 @@ import socket
import struct import struct
from PyQt6 import uic from PyQt6 import uic
from PyQt6.QtCore import QObject, Qt, QThread, QTimer, pyqtSignal, pyqtSlot from PyQt6.QtCore import QObject, QThread, QTimer, pyqtSignal, pyqtSlot
from PyQt6.QtWidgets import QMainWindow from PyQt6.QtWidgets import QMainWindow
from src.command import CarControl, Command, ControlCommand from src.command import CarControl, Command, ControlCommand, RecordingCommand
from src.recorder_ui import Ui_Recorder from src.recorder_ui import Ui_Recorder
from src.snapshot import Snapshot from src.snapshot import Snapshot
@@ -18,7 +18,8 @@ class RecorderClient(QObject):
super().__init__() super().__init__()
self.host: str = host self.host: str = host
self.port: int = port self.port: int = port
self.socket: socket.socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) self.socket: socket.socket = socket.socket(
socket.AF_INET, socket.SOCK_STREAM)
self.timer: QTimer = QTimer(self) self.timer: QTimer = QTimer(self)
self.timer.timeout.connect(self.poll_socket) self.timer.timeout.connect(self.poll_socket)
self.connected: bool = False self.connected: bool = False
@@ -29,7 +30,7 @@ class RecorderClient(QObject):
self.socket.setblocking(False) self.socket.setblocking(False)
self.connected = True self.connected = True
self.timer.start(50) self.timer.start(50)
print(f"Connected to server") print("Connected to server")
def poll_socket(self): def poll_socket(self):
buffer: bytes = b"" buffer: bytes = b""
@@ -155,7 +156,10 @@ class RecorderWindow(Ui_Recorder, QMainWindow):
self.send_command(ControlCommand(control, active)) self.send_command(ControlCommand(control, active))
def toggle_record(self): def toggle_record(self):
pass self.recording = not self.recording
self.recordDataButton.setText(
"Recording..." if self.recording else "Record")
self.send_command(RecordingCommand(self.recording))
def rollback(self): def rollback(self):
pass pass

View File

@@ -6,7 +6,9 @@ import struct
import threading import threading
from typing import TYPE_CHECKING, Optional from typing import TYPE_CHECKING, Optional
from src.command import CarControl, Command, ControlCommand from src.command import CarControl, Command, ControlCommand, RecordingCommand
from src.snapshot import Snapshot
from src.utils import RepeatTimer
if TYPE_CHECKING: if TYPE_CHECKING:
from src.car import Car from src.car import Car
@@ -23,10 +25,13 @@ class RemoteController:
CarControl.RIGHT: "right", CarControl.RIGHT: "right",
} }
SNAPSHOT_INTERVAL = 0.1
def __init__(self, car: Car, port: int = DEFAULT_PORT) -> None: def __init__(self, car: Car, port: int = DEFAULT_PORT) -> None:
self.car: Car = car self.car: Car = car
self.port: int = port self.port: int = port
self.server: socket.socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) self.server: socket.socket = socket.socket(
socket.AF_INET, socket.SOCK_STREAM)
self.server_thread: threading.Thread = threading.Thread( self.server_thread: threading.Thread = threading.Thread(
target=self.wait_for_connections, daemon=True target=self.wait_for_connections, daemon=True
) )
@@ -34,6 +39,10 @@ class RemoteController:
self.queue: queue.Queue[Command] = queue.Queue() self.queue: queue.Queue[Command] = queue.Queue()
self.client_thread: Optional[threading.Thread] = None self.client_thread: Optional[threading.Thread] = None
self.client: Optional[socket.socket] = None self.client: Optional[socket.socket] = None
self.snapshot_timer: RepeatTimer = RepeatTimer(
interval=self.SNAPSHOT_INTERVAL, function=self.take_snapshot)
self.snapshot_timer.start()
self.recording: bool = False
@property @property
def is_connected(self) -> bool: def is_connected(self) -> bool:
@@ -56,6 +65,7 @@ class RemoteController:
if self.client: if self.client:
self.client.close() self.client.close()
self.server.close() self.server.close()
self.snapshot_timer.cancel()
self.running = False self.running = False
def on_client_connected(self, conn: socket.socket): def on_client_connected(self, conn: socket.socket):
@@ -107,6 +117,18 @@ class RemoteController:
match command: match command:
case ControlCommand(control, active): case ControlCommand(control, active):
self.set_control(control, active) self.set_control(control, active)
case RecordingCommand(state):
self.recording = state
def set_control(self, control: CarControl, active: bool): def set_control(self, control: CarControl, active: bool):
setattr(self.car, self.CONTROL_ATTRIBUTES[control], active) setattr(self.car, self.CONTROL_ATTRIBUTES[control], active)
def take_snapshot(self):
if self.client is None:
return
if not self.recording:
return
snapshot: Snapshot = Snapshot.from_car(self.car)
payload: bytes = snapshot.pack()
self.client.sendall(struct.pack(">I", len(payload)) + payload)

View File

@@ -2,12 +2,15 @@ from __future__ import annotations
import struct import struct
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import Optional from typing import TYPE_CHECKING, Optional
import numpy as np import numpy as np
from src.vec import Vec from src.vec import Vec
if TYPE_CHECKING:
from src.car import Car
def iter_unpack(format, data): def iter_unpack(format, data):
nbr_bytes = struct.calcsize(format) nbr_bytes = struct.calcsize(format)
@@ -20,7 +23,8 @@ class Snapshot:
position: Vec = field(default_factory=Vec) position: Vec = field(default_factory=Vec)
direction: Vec = field(default_factory=Vec) direction: Vec = field(default_factory=Vec)
speed: float = 0 speed: float = 0
raycast_distances: list[float] | tuple[float, ...] = field(default_factory=list) raycast_distances: list[float] | tuple[float, ...] = field(
default_factory=list)
image: Optional[np.ndarray] = None image: Optional[np.ndarray] = None
def pack(self): def pack(self):
@@ -36,10 +40,12 @@ class Snapshot:
) )
nbr_raycasts: int = len(self.raycast_distances) nbr_raycasts: int = len(self.raycast_distances)
data += struct.pack(f">B{nbr_raycasts}f", nbr_raycasts, *self.raycast_distances) data += struct.pack(f">B{nbr_raycasts}f",
nbr_raycasts, *self.raycast_distances)
if self.image is not None: if self.image is not None:
data += struct.pack(">II", self.image.shape[0], self.image.shape[1]) data += struct.pack(">II",
self.image.shape[0], self.image.shape[1])
data += self.image.tobytes() data += self.image.tobytes()
else: else:
data += struct.pack(">II", 0, 0) data += struct.pack(">II", 0, 0)
@@ -72,3 +78,19 @@ class Snapshot:
raycast_distances=raycast_distances, raycast_distances=raycast_distances,
image=image, image=image,
) )
@staticmethod
def from_car(car: Car) -> Snapshot:
return Snapshot(
controls=(
car.forward,
car.backward,
car.left,
car.right
),
position=car.pos.copy(),
direction=car.direction.copy(),
speed=car.speed,
raycast_distances=car.rays.copy(),
image=None
)

View File

@@ -1,10 +1,12 @@
import os import os
from pathlib import Path from pathlib import Path
from threading import Timer
from typing import Optional from typing import Optional
from src.vec import Vec from src.vec import Vec
ROOT = Path(os.path.abspath(os.path.join(os.path.dirname(__file__), os.pardir))) ROOT = Path(os.path.abspath(os.path.join(
os.path.dirname(__file__), os.pardir)))
def orientation(a: Vec, b: Vec, c: Vec) -> float: def orientation(a: Vec, b: Vec, c: Vec) -> float:
@@ -59,3 +61,9 @@ def get_segments_intersection(a1: Vec, a2: Vec, b1: Vec, b2: Vec) -> Optional[Ve
if intersection.within(a1, a2) and intersection.within(b1, b2): if intersection.within(a1, a2) and intersection.within(b1, b2):
return intersection return intersection
return None return None
class RepeatTimer(Timer):
def run(self):
while not self.finished.wait(self.interval):
self.function(*self.args, **self.kwargs)

View File

@@ -8,6 +8,9 @@ class Vec:
self.x: float = x self.x: float = x
self.y: float = y self.y: float = y
def copy(self) -> Vec:
return Vec(self.x, self.y)
def __add__(self, other: float | Vec) -> Vec: def __add__(self, other: float | Vec) -> Vec:
if isinstance(other, Vec): if isinstance(other, Vec):
return Vec(self.x + other.x, self.y + other.y) return Vec(self.x + other.x, self.y + other.y)