Files
rally-racer/src/recorder.py

282 lines
9.0 KiB
Python

import os
from pathlib import Path
import socket
import struct
from typing import Optional
from PyQt6 import uic
from PyQt6.QtCore import QObject, QThread, QTimer, pyqtSignal, pyqtSlot
from PyQt6.QtGui import QKeyEvent
from PyQt6.QtWidgets import QMainWindow
from src.bot import Bot
from src.command import ApplySnapshotCommand, CarControl, Command, ControlCommand, RecordingCommand, ResetCommand
from src.record_file import RecordFile
from src.recorder_ui import Ui_Recorder
from src.snapshot import Snapshot
class RecorderClient(QObject):
    """TCP client that exchanges length-prefixed messages with the game server.

    Intended to be moved onto its own ``QThread``: ``start`` connects, switches
    the socket to non-blocking mode and polls it every 50 ms. Each complete
    incoming frame is decoded with ``Snapshot.unpack`` and emitted through
    ``data_received``. Outgoing commands use the same framing: a 4-byte
    big-endian payload length followed by ``command.pack()``.
    """

    DATA_CHUNK_SIZE = 65536
    # Frame header: unsigned 32-bit big-endian payload length.
    _HEADER = struct.Struct(">I")

    data_received: pyqtSignal = pyqtSignal(Snapshot)

    def __init__(self, host: str, port: int) -> None:
        super().__init__()
        self.host: str = host
        self.port: int = port
        self.socket: socket.socket = socket.socket(
            socket.AF_INET, socket.SOCK_STREAM)
        self.timer: Optional[QTimer] = None
        self.connected: bool = False
        # Bytes received but not yet consumed as a complete frame.
        self.buffer: bytearray = bytearray()

    @pyqtSlot()
    def start(self):
        """Connect to the server and start the 50 ms polling timer.

        A failed connection is reported and the client is shut down rather
        than letting the ``OSError`` escape the slot (PyQt treats exceptions
        propagating back into Qt as fatal by default).
        """
        try:
            self.socket.connect((self.host, self.port))
        except OSError as e:
            print(f"Could not connect to {self.host}:{self.port}: {e}")
            self.shutdown()
            return
        self.socket.setblocking(False)
        self.connected = True
        # Created here rather than in __init__ so the timer is owned by the
        # thread that runs start() (the client's worker thread when used via
        # moveToThread + QThread.started).
        self.timer = QTimer(self)
        self.timer.timeout.connect(self.poll_socket)
        self.timer.start(50)
        print("Connected to server")

    def poll_socket(self):
        """Drain all currently readable data and dispatch complete frames.

        Reads until the non-blocking socket would block. An empty read means
        the server closed the connection, so the client shuts down. Any other
        error also shuts the client down.
        """
        if not self.connected:
            return
        try:
            while True:
                chunk: bytes = self.socket.recv(self.DATA_CHUNK_SIZE)
                if not chunk:
                    # recv() only returns b"" once the peer has closed the
                    # connection; continuing to poll would spin on a dead socket.
                    print("Server closed the connection")
                    self.shutdown()
                    return
                self.buffer += chunk
                self._dispatch_complete_frames()
        except BlockingIOError:
            # Nothing more to read right now; wait for the next timer tick.
            pass
        except Exception as e:
            print(f"Socket error: {e}")
            self.shutdown()

    def _dispatch_complete_frames(self) -> None:
        """Remove every complete length-prefixed frame from the buffer and handle it."""
        header_size = self._HEADER.size
        while len(self.buffer) >= header_size:
            (msg_len,) = self._HEADER.unpack_from(self.buffer)
            msg_end = header_size + msg_len
            if len(self.buffer) < msg_end:
                break  # payload not fully received yet
            message = bytes(self.buffer[header_size:msg_end])
            # Trim the consumed frame in place instead of building a new buffer.
            del self.buffer[:msg_end]
            self.on_message(message)

    def on_message(self, message: bytes):
        """Decode one frame payload into a Snapshot and emit ``data_received``."""
        snapshot: Snapshot = Snapshot.unpack(message)
        self.data_received.emit(snapshot)

    @pyqtSlot(object)
    def send_command(self, command):
        """Send ``command.pack()`` to the server as a length-prefixed frame.

        Any send error shuts the client down.
        """
        if not self.connected:
            print("Not connected")
            return
        try:
            payload: bytes = command.pack()
            self.socket.sendall(self._HEADER.pack(len(payload)) + payload)
        except Exception as e:
            print(f"An exception occured: {e}")
            self.shutdown()

    @pyqtSlot()
    def shutdown(self):
        """Stop the polling timer, mark the client disconnected and close the socket."""
        print("Shutting down client")
        if self.timer is not None:
            self.timer.stop()
            self.timer = None
        self.connected = False
        self.socket.close()
class ThreadedSaver(QThread):
    """Write a list of snapshots to a record file on a background QThread."""

    def __init__(self, path: str | Path, snapshots: list[Snapshot]):
        super().__init__()
        self.path: str | Path = path
        self.snapshots: list[Snapshot] = snapshots
        # Exception raised while writing; stays None if the save succeeded
        # (or has not finished yet).
        self.error: Optional[Exception] = None

    def run(self):
        """Write all snapshots to ``path`` through ``RecordFile`` (worker thread)."""
        try:
            with RecordFile(self.path, "w") as f:
                f.write_snapshots(self.snapshots)
        except Exception as e:
            # Keep and report the failure instead of letting it escape run():
            # PyQt treats exceptions propagating back into Qt as fatal by default.
            self.error = e
            print(f"Failed to save record to {self.path}: {e}")
class RecorderWindow(Ui_Recorder, QMainWindow):
    """Main recorder window: drive the car, record snapshots and save them.

    Network I/O is delegated to a ``RecorderClient`` running on its own
    ``QThread``. Commands reach it through ``send_signal`` and it is stopped
    through ``close_signal``. Incoming snapshots arrive via
    ``RecorderClient.data_received`` and are accumulated in ``self.snapshots``.
    """

    close_signal: pyqtSignal = pyqtSignal()
    send_signal: pyqtSignal = pyqtSignal(object)

    SAVE_DIR: Path = Path(__file__).parent.parent / "records"
    # Resolved relative to this module (like SAVE_DIR) so the window can be
    # opened from any working directory.
    UI_FILE: Path = Path(__file__).parent / "recorder.ui"
    # Keyboard bindings; key text is lower-cased before lookup
    # (see _control_for_key).
    COMMAND_DIRECTIONS: dict[str, CarControl] = {
        "w": CarControl.FORWARD,
        "s": CarControl.BACKWARD,
        "d": CarControl.RIGHT,
        "a": CarControl.LEFT,
    }

    def __init__(self, host: str, port: int) -> None:
        super().__init__()
        self.host: str = host
        self.port: int = port

        # Network client lives on its own thread; all interaction goes
        # through signals.
        self.client_thread: QThread = QThread()
        self.client: RecorderClient = RecorderClient(self.host, self.port)
        self.client.data_received.connect(self.on_snapshot_received)
        self.client.moveToThread(self.client_thread)
        self.client_thread.started.connect(self.client.start)
        self.close_signal.connect(self.client.shutdown)
        self.send_signal.connect(self.client.send_command)

        uic.loadUi(str(self.UI_FILE), self)

        # Each direction button applies its control while held down.
        # ``c=control`` binds the current value (avoids late-binding closures).
        for button, control in (
            (self.forwardButton, CarControl.FORWARD),
            (self.backwardButton, CarControl.BACKWARD),
            (self.rightButton, CarControl.RIGHT),
            (self.leftButton, CarControl.LEFT),
        ):
            button.pressed.connect(
                lambda c=control: self.on_car_controlled(c, True))
            button.released.connect(
                lambda c=control: self.on_car_controlled(c, False))

        self.recordDataButton.clicked.connect(self.toggle_record)
        self.resetButton.clicked.connect(self.rollback)

        self.bot: Optional[Bot] = None
        self.autopiloting = False
        self.autopilotButton.clicked.connect(self.toggle_autopilot)
        # Enabled once a bot is registered (see register_bot).
        self.autopilotButton.setDisabled(True)

        self.saveRecordButton.clicked.connect(self.save_record)
        self.saving_worker: Optional[ThreadedSaver] = None

        self.recording = False
        self.snapshots: list[Snapshot] = []

        # Started last so no snapshot can arrive before the state above exists.
        self.client_thread.start()

    def on_car_controlled(self, control: CarControl, active: bool):
        """Send a ``ControlCommand`` for ``control`` with the given ``active`` flag."""
        self.send_command(ControlCommand(control, active))

    def _control_for_key(self, event) -> Optional[CarControl]:
        """Return the control bound to a non-auto-repeat key event, or None."""
        if event.isAutoRepeat() or not isinstance(event, QKeyEvent):
            return None
        # Lower-case the text so Shift / Caps Lock do not change the mapping.
        # Otherwise a key pressed as "w" but released as "W" would never send
        # the release command and the control would stay active.
        return self.COMMAND_DIRECTIONS.get(event.text().lower())

    def keyPressEvent(self, event):  # type: ignore
        """Activate the control bound to the pressed key, if any."""
        ctrl = self._control_for_key(event)
        if ctrl is not None:
            self.on_car_controlled(ctrl, True)

    def keyReleaseEvent(self, event):  # type: ignore
        """Deactivate the control bound to the released key, if any."""
        ctrl = self._control_for_key(event)
        if ctrl is not None:
            self.on_car_controlled(ctrl, False)

    def toggle_record(self):
        """Flip the recording flag, update the button label and notify the server."""
        self.recording = not self.recording
        self.recordDataButton.setText(
            "Recording..." if self.recording else "Record")
        self.send_command(RecordingCommand(self.recording))

    def rollback(self):
        """Forget the last N snapshots and rewind the game to the newest survivor.

        N comes from ``forgetSnapshotNumber``. It is clamped so that at least
        one snapshot survives when any exist. The server then receives an
        ``ApplySnapshotCommand`` for the newest remaining snapshot, or a
        ``ResetCommand`` if none are left. Recording is switched off afterwards.
        """
        rollback_by: int = self.forgetSnapshotNumber.value()
        rollback_by = max(0, min(rollback_by, len(self.snapshots) - 1))
        # Slice with an explicit end index: ``[:-rollback_by]`` with 0 is
        # ``[:0]`` and would wipe every snapshot when nothing should be dropped.
        self.snapshots = self.snapshots[:len(self.snapshots) - rollback_by]
        self.nbrSnapshotSaved.setText(str(len(self.snapshots)))
        if not self.snapshots:
            self.send_command(ResetCommand())
        else:
            self.send_command(ApplySnapshotCommand(self.snapshots[-1]))
        if self.recording:
            self.toggle_record()

    def toggle_autopilot(self):
        """Flip autopilot; while on, incoming snapshots are forwarded to the bot."""
        self.autopiloting = not self.autopiloting
        self.autopilotButton.setText(
            "AutoPilot:\n" + ("ON" if self.autopiloting else "OFF")
        )

    def _next_record_path(self) -> Path:
        """Create SAVE_DIR if needed and return the first unused ``record_<n>.rec`` path."""
        self.SAVE_DIR.mkdir(parents=True, exist_ok=True)
        fid = 0
        while (path := self.SAVE_DIR / f"record_{fid}.rec").exists():
            fid += 1
        return path

    def save_record(self):
        """Hand the accumulated snapshots to a ThreadedSaver and clear them.

        Does nothing (besides printing) if a save is already running or
        there is nothing to save. Recording is switched off before saving.
        """
        if self.saving_worker is not None:
            print("Already saving !")
            return
        if not self.snapshots:
            print("No data to save !")
            return
        if self.recording:
            self.toggle_record()
        self.saveRecordButton.setText("Saving ...")
        # The worker keeps the old list object; rebinding self.snapshots
        # below means new snapshots never mutate the list being written.
        self.saving_worker = ThreadedSaver(
            self._next_record_path(), self.snapshots)
        self.snapshots = []
        self.nbrSnapshotSaved.setText("0")
        self.saving_worker.finished.connect(self.on_record_save_done)
        self.saving_worker.start()

    def on_record_save_done(self):
        """Report the finished save and re-enable saving."""
        if self.saving_worker is None:
            return
        print("Recorded data saved to", self.saving_worker.path)
        self.saving_worker = None
        self.saveRecordButton.setText("Save")

    @pyqtSlot(Snapshot)
    def on_snapshot_received(self, snapshot: Snapshot):
        """Store an incoming snapshot and, if autopilot is on, pass it to the bot."""
        self.snapshots.append(snapshot)
        self.nbrSnapshotSaved.setText(str(len(self.snapshots)))
        if self.autopiloting and self.bot is not None:
            self.bot.on_snapshot_received(snapshot)

    def shutdown(self):
        """Ask the network client to shut down (via ``close_signal``)."""
        self.close_signal.emit()

    def send_command(self, command: Command):
        """Forward ``command`` to the network client (via ``send_signal``)."""
        self.send_signal.emit(command)

    def register_bot(self, bot: Bot):
        """Attach ``bot`` for autopilot use and enable the autopilot button."""
        self.bot = bot
        self.autopilotButton.setDisabled(False)