feat: add snapshot recording
This commit is contained in:
@@ -8,6 +8,7 @@ from typing import Type
|
||||
|
||||
class CommandType(IntEnum):
|
||||
CAR_CONTROL = 0
|
||||
RECORDING = 1
|
||||
|
||||
|
||||
class CarControl(IntEnum):
|
||||
@@ -64,3 +65,20 @@ class ControlCommand(Command):
|
||||
active: bool = (value & 1) == 1
|
||||
control: int = value >> 1
|
||||
return ControlCommand(CarControl(control), active)
|
||||
|
||||
|
||||
class RecordingCommand(Command):
|
||||
TYPE = CommandType.RECORDING
|
||||
__match_args__ = ("state",)
|
||||
|
||||
def __init__(self, state: bool) -> None:
|
||||
super().__init__()
|
||||
self.state: bool = state
|
||||
|
||||
def get_payload(self) -> bytes:
|
||||
return struct.pack(">B", self.state)
|
||||
|
||||
@classmethod
|
||||
def from_payload(cls, payload: bytes) -> Command:
|
||||
state: bool = struct.unpack(">B", payload)[0]
|
||||
return RecordingCommand(state)
|
||||
|
||||
@@ -2,10 +2,10 @@ import socket
|
||||
import struct
|
||||
|
||||
from PyQt6 import uic
|
||||
from PyQt6.QtCore import QObject, Qt, QThread, QTimer, pyqtSignal, pyqtSlot
|
||||
from PyQt6.QtCore import QObject, QThread, QTimer, pyqtSignal, pyqtSlot
|
||||
from PyQt6.QtWidgets import QMainWindow
|
||||
|
||||
from src.command import CarControl, Command, ControlCommand
|
||||
from src.command import CarControl, Command, ControlCommand, RecordingCommand
|
||||
from src.recorder_ui import Ui_Recorder
|
||||
from src.snapshot import Snapshot
|
||||
|
||||
@@ -18,7 +18,8 @@ class RecorderClient(QObject):
|
||||
super().__init__()
|
||||
self.host: str = host
|
||||
self.port: int = port
|
||||
self.socket: socket.socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
|
||||
self.socket: socket.socket = socket.socket(
|
||||
socket.AF_INET, socket.SOCK_STREAM)
|
||||
self.timer: QTimer = QTimer(self)
|
||||
self.timer.timeout.connect(self.poll_socket)
|
||||
self.connected: bool = False
|
||||
@@ -29,7 +30,7 @@ class RecorderClient(QObject):
|
||||
self.socket.setblocking(False)
|
||||
self.connected = True
|
||||
self.timer.start(50)
|
||||
print(f"Connected to server")
|
||||
print("Connected to server")
|
||||
|
||||
def poll_socket(self):
|
||||
buffer: bytes = b""
|
||||
@@ -155,7 +156,10 @@ class RecorderWindow(Ui_Recorder, QMainWindow):
|
||||
self.send_command(ControlCommand(control, active))
|
||||
|
||||
def toggle_record(self):
|
||||
pass
|
||||
self.recording = not self.recording
|
||||
self.recordDataButton.setText(
|
||||
"Recording..." if self.recording else "Record")
|
||||
self.send_command(RecordingCommand(self.recording))
|
||||
|
||||
def rollback(self):
|
||||
pass
|
||||
|
||||
@@ -6,7 +6,9 @@ import struct
|
||||
import threading
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
|
||||
from src.command import CarControl, Command, ControlCommand
|
||||
from src.command import CarControl, Command, ControlCommand, RecordingCommand
|
||||
from src.snapshot import Snapshot
|
||||
from src.utils import RepeatTimer
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from src.car import Car
|
||||
@@ -23,10 +25,13 @@ class RemoteController:
|
||||
CarControl.RIGHT: "right",
|
||||
}
|
||||
|
||||
SNAPSHOT_INTERVAL = 0.1
|
||||
|
||||
def __init__(self, car: Car, port: int = DEFAULT_PORT) -> None:
|
||||
self.car: Car = car
|
||||
self.port: int = port
|
||||
self.server: socket.socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
|
||||
self.server: socket.socket = socket.socket(
|
||||
socket.AF_INET, socket.SOCK_STREAM)
|
||||
self.server_thread: threading.Thread = threading.Thread(
|
||||
target=self.wait_for_connections, daemon=True
|
||||
)
|
||||
@@ -34,6 +39,10 @@ class RemoteController:
|
||||
self.queue: queue.Queue[Command] = queue.Queue()
|
||||
self.client_thread: Optional[threading.Thread] = None
|
||||
self.client: Optional[socket.socket] = None
|
||||
self.snapshot_timer: RepeatTimer = RepeatTimer(
|
||||
interval=self.SNAPSHOT_INTERVAL, function=self.take_snapshot)
|
||||
self.snapshot_timer.start()
|
||||
self.recording: bool = False
|
||||
|
||||
@property
|
||||
def is_connected(self) -> bool:
|
||||
@@ -56,6 +65,7 @@ class RemoteController:
|
||||
if self.client:
|
||||
self.client.close()
|
||||
self.server.close()
|
||||
self.snapshot_timer.cancel()
|
||||
self.running = False
|
||||
|
||||
def on_client_connected(self, conn: socket.socket):
|
||||
@@ -107,6 +117,18 @@ class RemoteController:
|
||||
match command:
|
||||
case ControlCommand(control, active):
|
||||
self.set_control(control, active)
|
||||
case RecordingCommand(state):
|
||||
self.recording = state
|
||||
|
||||
def set_control(self, control: CarControl, active: bool):
|
||||
setattr(self.car, self.CONTROL_ATTRIBUTES[control], active)
|
||||
|
||||
def take_snapshot(self):
|
||||
if self.client is None:
|
||||
return
|
||||
if not self.recording:
|
||||
return
|
||||
|
||||
snapshot: Snapshot = Snapshot.from_car(self.car)
|
||||
payload: bytes = snapshot.pack()
|
||||
self.client.sendall(struct.pack(">I", len(payload)) + payload)
|
||||
|
||||
@@ -2,12 +2,15 @@ from __future__ import annotations
|
||||
|
||||
import struct
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Optional
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
|
||||
import numpy as np
|
||||
|
||||
from src.vec import Vec
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from src.car import Car
|
||||
|
||||
|
||||
def iter_unpack(format, data):
|
||||
nbr_bytes = struct.calcsize(format)
|
||||
@@ -20,7 +23,8 @@ class Snapshot:
|
||||
position: Vec = field(default_factory=Vec)
|
||||
direction: Vec = field(default_factory=Vec)
|
||||
speed: float = 0
|
||||
raycast_distances: list[float] | tuple[float, ...] = field(default_factory=list)
|
||||
raycast_distances: list[float] | tuple[float, ...] = field(
|
||||
default_factory=list)
|
||||
image: Optional[np.ndarray] = None
|
||||
|
||||
def pack(self):
|
||||
@@ -36,10 +40,12 @@ class Snapshot:
|
||||
)
|
||||
|
||||
nbr_raycasts: int = len(self.raycast_distances)
|
||||
data += struct.pack(f">B{nbr_raycasts}f", nbr_raycasts, *self.raycast_distances)
|
||||
data += struct.pack(f">B{nbr_raycasts}f",
|
||||
nbr_raycasts, *self.raycast_distances)
|
||||
|
||||
if self.image is not None:
|
||||
data += struct.pack(">II", self.image.shape[0], self.image.shape[1])
|
||||
data += struct.pack(">II",
|
||||
self.image.shape[0], self.image.shape[1])
|
||||
data += self.image.tobytes()
|
||||
else:
|
||||
data += struct.pack(">II", 0, 0)
|
||||
@@ -72,3 +78,19 @@ class Snapshot:
|
||||
raycast_distances=raycast_distances,
|
||||
image=image,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def from_car(car: Car) -> Snapshot:
|
||||
return Snapshot(
|
||||
controls=(
|
||||
car.forward,
|
||||
car.backward,
|
||||
car.left,
|
||||
car.right
|
||||
),
|
||||
position=car.pos.copy(),
|
||||
direction=car.direction.copy(),
|
||||
speed=car.speed,
|
||||
raycast_distances=car.rays.copy(),
|
||||
image=None
|
||||
)
|
||||
|
||||
10
src/utils.py
10
src/utils.py
@@ -1,10 +1,12 @@
|
||||
import os
|
||||
from pathlib import Path
|
||||
from threading import Timer
|
||||
from typing import Optional
|
||||
|
||||
from src.vec import Vec
|
||||
|
||||
ROOT = Path(os.path.abspath(os.path.join(os.path.dirname(__file__), os.pardir)))
|
||||
ROOT = Path(os.path.abspath(os.path.join(
|
||||
os.path.dirname(__file__), os.pardir)))
|
||||
|
||||
|
||||
def orientation(a: Vec, b: Vec, c: Vec) -> float:
|
||||
@@ -59,3 +61,9 @@ def get_segments_intersection(a1: Vec, a2: Vec, b1: Vec, b2: Vec) -> Optional[Ve
|
||||
if intersection.within(a1, a2) and intersection.within(b1, b2):
|
||||
return intersection
|
||||
return None
|
||||
|
||||
|
||||
class RepeatTimer(Timer):
|
||||
def run(self):
|
||||
while not self.finished.wait(self.interval):
|
||||
self.function(*self.args, **self.kwargs)
|
||||
|
||||
@@ -8,6 +8,9 @@ class Vec:
|
||||
self.x: float = x
|
||||
self.y: float = y
|
||||
|
||||
def copy(self) -> Vec:
|
||||
return Vec(self.x, self.y)
|
||||
|
||||
def __add__(self, other: float | Vec) -> Vec:
|
||||
if isinstance(other, Vec):
|
||||
return Vec(self.x + other.x, self.y + other.y)
|
||||
|
||||
Reference in New Issue
Block a user