feat: add snapshot recording

This commit is contained in:
2025-10-24 17:22:29 +02:00
parent 8542ee81e7
commit db112ada4c
6 changed files with 89 additions and 12 deletions

View File

@@ -8,6 +8,7 @@ from typing import Type
class CommandType(IntEnum):
CAR_CONTROL = 0
RECORDING = 1
class CarControl(IntEnum):
@@ -64,3 +65,20 @@ class ControlCommand(Command):
active: bool = (value & 1) == 1
control: int = value >> 1
return ControlCommand(CarControl(control), active)
class RecordingCommand(Command):
TYPE = CommandType.RECORDING
__match_args__ = ("state",)
def __init__(self, state: bool) -> None:
super().__init__()
self.state: bool = state
def get_payload(self) -> bytes:
return struct.pack(">B", self.state)
@classmethod
def from_payload(cls, payload: bytes) -> Command:
state: bool = struct.unpack(">B", payload)[0]
return RecordingCommand(state)

View File

@@ -2,10 +2,10 @@ import socket
import struct
from PyQt6 import uic
from PyQt6.QtCore import QObject, Qt, QThread, QTimer, pyqtSignal, pyqtSlot
from PyQt6.QtCore import QObject, QThread, QTimer, pyqtSignal, pyqtSlot
from PyQt6.QtWidgets import QMainWindow
from src.command import CarControl, Command, ControlCommand
from src.command import CarControl, Command, ControlCommand, RecordingCommand
from src.recorder_ui import Ui_Recorder
from src.snapshot import Snapshot
@@ -18,7 +18,8 @@ class RecorderClient(QObject):
super().__init__()
self.host: str = host
self.port: int = port
self.socket: socket.socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
self.socket: socket.socket = socket.socket(
socket.AF_INET, socket.SOCK_STREAM)
self.timer: QTimer = QTimer(self)
self.timer.timeout.connect(self.poll_socket)
self.connected: bool = False
@@ -29,7 +30,7 @@ class RecorderClient(QObject):
self.socket.setblocking(False)
self.connected = True
self.timer.start(50)
print(f"Connected to server")
print("Connected to server")
def poll_socket(self):
buffer: bytes = b""
@@ -155,7 +156,10 @@ class RecorderWindow(Ui_Recorder, QMainWindow):
self.send_command(ControlCommand(control, active))
def toggle_record(self):
pass
self.recording = not self.recording
self.recordDataButton.setText(
"Recording..." if self.recording else "Record")
self.send_command(RecordingCommand(self.recording))
def rollback(self):
pass

View File

@@ -6,7 +6,9 @@ import struct
import threading
from typing import TYPE_CHECKING, Optional
from src.command import CarControl, Command, ControlCommand
from src.command import CarControl, Command, ControlCommand, RecordingCommand
from src.snapshot import Snapshot
from src.utils import RepeatTimer
if TYPE_CHECKING:
from src.car import Car
@@ -23,10 +25,13 @@ class RemoteController:
CarControl.RIGHT: "right",
}
SNAPSHOT_INTERVAL = 0.1
def __init__(self, car: Car, port: int = DEFAULT_PORT) -> None:
self.car: Car = car
self.port: int = port
self.server: socket.socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
self.server: socket.socket = socket.socket(
socket.AF_INET, socket.SOCK_STREAM)
self.server_thread: threading.Thread = threading.Thread(
target=self.wait_for_connections, daemon=True
)
@@ -34,6 +39,10 @@ class RemoteController:
self.queue: queue.Queue[Command] = queue.Queue()
self.client_thread: Optional[threading.Thread] = None
self.client: Optional[socket.socket] = None
self.snapshot_timer: RepeatTimer = RepeatTimer(
interval=self.SNAPSHOT_INTERVAL, function=self.take_snapshot)
self.snapshot_timer.start()
self.recording: bool = False
@property
def is_connected(self) -> bool:
@@ -56,6 +65,7 @@ class RemoteController:
if self.client:
self.client.close()
self.server.close()
self.snapshot_timer.cancel()
self.running = False
def on_client_connected(self, conn: socket.socket):
@@ -107,6 +117,18 @@ class RemoteController:
match command:
case ControlCommand(control, active):
self.set_control(control, active)
case RecordingCommand(state):
self.recording = state
def set_control(self, control: CarControl, active: bool):
setattr(self.car, self.CONTROL_ATTRIBUTES[control], active)
def take_snapshot(self):
if self.client is None:
return
if not self.recording:
return
snapshot: Snapshot = Snapshot.from_car(self.car)
payload: bytes = snapshot.pack()
self.client.sendall(struct.pack(">I", len(payload)) + payload)

View File

@@ -2,12 +2,15 @@ from __future__ import annotations
import struct
from dataclasses import dataclass, field
from typing import Optional
from typing import TYPE_CHECKING, Optional
import numpy as np
from src.vec import Vec
if TYPE_CHECKING:
from src.car import Car
def iter_unpack(format, data):
nbr_bytes = struct.calcsize(format)
@@ -20,7 +23,8 @@ class Snapshot:
position: Vec = field(default_factory=Vec)
direction: Vec = field(default_factory=Vec)
speed: float = 0
raycast_distances: list[float] | tuple[float, ...] = field(default_factory=list)
raycast_distances: list[float] | tuple[float, ...] = field(
default_factory=list)
image: Optional[np.ndarray] = None
def pack(self):
@@ -36,10 +40,12 @@ class Snapshot:
)
nbr_raycasts: int = len(self.raycast_distances)
data += struct.pack(f">B{nbr_raycasts}f", nbr_raycasts, *self.raycast_distances)
data += struct.pack(f">B{nbr_raycasts}f",
nbr_raycasts, *self.raycast_distances)
if self.image is not None:
data += struct.pack(">II", self.image.shape[0], self.image.shape[1])
data += struct.pack(">II",
self.image.shape[0], self.image.shape[1])
data += self.image.tobytes()
else:
data += struct.pack(">II", 0, 0)
@@ -72,3 +78,19 @@ class Snapshot:
raycast_distances=raycast_distances,
image=image,
)
@staticmethod
def from_car(car: Car) -> Snapshot:
return Snapshot(
controls=(
car.forward,
car.backward,
car.left,
car.right
),
position=car.pos.copy(),
direction=car.direction.copy(),
speed=car.speed,
raycast_distances=car.rays.copy(),
image=None
)

View File

@@ -1,10 +1,12 @@
import os
from pathlib import Path
from threading import Timer
from typing import Optional
from src.vec import Vec
ROOT = Path(os.path.abspath(os.path.join(os.path.dirname(__file__), os.pardir)))
ROOT = Path(os.path.abspath(os.path.join(
os.path.dirname(__file__), os.pardir)))
def orientation(a: Vec, b: Vec, c: Vec) -> float:
@@ -59,3 +61,9 @@ def get_segments_intersection(a1: Vec, a2: Vec, b1: Vec, b2: Vec) -> Optional[Ve
if intersection.within(a1, a2) and intersection.within(b1, b2):
return intersection
return None
class RepeatTimer(Timer):
def run(self):
while not self.finished.wait(self.interval):
self.function(*self.args, **self.kwargs)

View File

@@ -8,6 +8,9 @@ class Vec:
self.x: float = x
self.y: float = y
def copy(self) -> Vec:
return Vec(self.x, self.y)
def __add__(self, other: float | Vec) -> Vec:
if isinstance(other, Vec):
return Vec(self.x + other.x, self.y + other.y)