282 lines
9.0 KiB
Python
282 lines
9.0 KiB
Python
import os
|
|
from pathlib import Path
|
|
import socket
|
|
import struct
|
|
from typing import Optional
|
|
|
|
from PyQt6 import uic
|
|
from PyQt6.QtCore import QObject, QThread, QTimer, pyqtSignal, pyqtSlot
|
|
from PyQt6.QtGui import QKeyEvent
|
|
from PyQt6.QtWidgets import QMainWindow
|
|
|
|
from src.bot import Bot
|
|
from src.command import ApplySnapshotCommand, CarControl, Command, ControlCommand, RecordingCommand, ResetCommand
|
|
from src.record_file import RecordFile
|
|
from src.recorder_ui import Ui_Recorder
|
|
from src.snapshot import Snapshot
|
|
|
|
|
|
class RecorderClient(QObject):
    """TCP client that exchanges length-prefixed messages with the server.

    Every message, in both directions, is framed as a 4-byte big-endian
    unsigned length followed by that many payload bytes. Incoming payloads
    are decoded with ``Snapshot.unpack`` and re-emitted via ``data_received``;
    outgoing commands are serialized with ``command.pack()``.

    ``start`` creates the polling ``QTimer``, so it should run in the thread
    that owns this object.
    """

    # Maximum number of bytes read per recv() call.
    DATA_CHUNK_SIZE = 4096
    # Seconds a single outgoing frame may block before the send is abandoned.
    SEND_TIMEOUT = 5.0

    # Emitted once per fully received and decoded message.
    data_received: pyqtSignal = pyqtSignal(Snapshot)

    def __init__(self, host: str, port: int) -> None:
        super().__init__()
        self.host: str = host
        self.port: int = port
        self.socket: socket.socket = socket.socket(
            socket.AF_INET, socket.SOCK_STREAM)
        # Polling timer; created in start() and cleared by shutdown().
        self.timer: Optional[QTimer] = None
        self.connected: bool = False
        # Bytes received but not yet consumed as a complete frame.
        self.buffer: bytes = b""

    @pyqtSlot()
    def start(self):
        """Connect to the server and start polling the socket every 50 ms.

        A failed connection is reported and leaves the client disconnected
        instead of letting the exception escape the slot (PyQt turns
        unhandled exceptions in slots into a fatal error).
        """
        try:
            self.socket.connect((self.host, self.port))
        except OSError as e:
            print(f"Could not connect to {self.host}:{self.port}: {e}")
            self.socket.close()
            return

        self.socket.setblocking(False)
        self.connected = True

        self.timer = QTimer(self)
        self.timer.timeout.connect(self.poll_socket)
        self.timer.start(50)
        print("Connected to server")

    def poll_socket(self):
        """Drain all currently readable data and dispatch complete messages.

        Called by the polling timer. Shuts the client down when the server
        closes the connection or any socket/decoding error occurs.
        """
        if not self.connected:
            return

        try:
            while True:
                chunk: bytes = self.socket.recv(self.DATA_CHUNK_SIZE)
                if not chunk:
                    # recv() returning b"" means the peer closed the
                    # connection; stop polling a dead socket.
                    print("Server closed the connection")
                    self.shutdown()
                    return
                self.buffer += chunk
                self._dispatch_buffered_messages()
        except BlockingIOError:
            # Nothing more to read right now; try again on the next tick.
            pass
        except Exception as e:
            print(f"Socket error: {e}")
            self.shutdown()

    def _dispatch_buffered_messages(self):
        """Consume every complete length-prefixed frame held in the buffer."""
        buffer = self.buffer
        offset = 0
        try:
            while len(buffer) - offset >= 4:
                (msg_len,) = struct.unpack_from(">I", buffer, offset)
                msg_end = offset + 4 + msg_len
                if len(buffer) < msg_end:
                    # Partial frame: wait for more data.
                    break
                message = buffer[offset + 4:msg_end]
                offset = msg_end
                self.on_message(message)
        finally:
            # Trim once per call instead of re-slicing after every message.
            self.buffer = buffer[offset:]

    def on_message(self, message: bytes):
        """Decode one frame payload as a Snapshot and emit it."""
        snapshot: Snapshot = Snapshot.unpack(message)
        self.data_received.emit(snapshot)

    @pyqtSlot(object)
    def send_command(self, command):
        """Serialize ``command`` and send it as one length-prefixed frame.

        Any failure while packing or sending shuts the client down.
        """
        if not self.connected:
            print("Not connected")
            return

        try:
            payload: bytes = command.pack()
            frame = struct.pack(">I", len(payload)) + payload
            # The socket is non-blocking for polling, and sendall() on a
            # non-blocking socket can fail part-way through a frame when the
            # send buffer is full, desynchronising the stream. Send with a
            # bounded timeout, then restore non-blocking mode.
            self.socket.settimeout(self.SEND_TIMEOUT)
            try:
                self.socket.sendall(frame)
            finally:
                self.socket.setblocking(False)
        except Exception as e:
            print(f"An exception occured: {e}")
            self.shutdown()

    @pyqtSlot()
    def shutdown(self):
        """Stop polling, mark the client disconnected and close the socket."""
        print("Shutting down client")
        if self.timer is not None:
            self.timer.stop()
            self.timer = None
        self.connected = False
        self.socket.close()
|
|
|
|
|
|
class ThreadedSaver(QThread):
    """Write a list of snapshots to a record file on a background thread.

    A failure while writing is caught, printed and stored in ``error``
    rather than escaping ``run`` (PyQt treats unhandled exceptions in Python
    reimplementations of Qt virtuals as fatal).
    """

    def __init__(self, path: str | Path, snapshots: list[Snapshot]):
        super().__init__()
        self.path: str | Path = path
        self.snapshots: list[Snapshot] = snapshots
        # Exception raised by the last run(), or None if it succeeded/has not run.
        self.error: Optional[Exception] = None

    def run(self):
        """Write all snapshots to ``self.path`` using RecordFile in "w" mode."""
        try:
            with RecordFile(self.path, "w") as f:
                f.write_snapshots(self.snapshots)
        except Exception as e:
            self.error = e
            print(f"Failed to save record to {self.path}: {e}")
|
|
|
|
|
|
class RecorderWindow(Ui_Recorder, QMainWindow):
    """Main recorder window: drives the car, collects snapshots and saves them.

    Networking runs in a RecorderClient moved to its own QThread. The window
    talks to it only through ``send_signal`` (outgoing commands) and
    ``close_signal`` (shutdown), and receives snapshots through the client's
    ``data_received`` signal.
    """

    # Asks the client to shut down.
    close_signal: pyqtSignal = pyqtSignal()
    # Carries a Command object to the client for sending.
    send_signal: pyqtSignal = pyqtSignal(object)

    # Directory where record files are written (two levels above this file).
    SAVE_DIR: Path = Path(__file__).parent.parent / "records"

    # Keyboard bindings for driving the car (matched case-insensitively).
    COMMAND_DIRECTIONS: dict[str, CarControl] = {
        "w": CarControl.FORWARD,
        "s": CarControl.BACKWARD,
        "d": CarControl.RIGHT,
        "a": CarControl.LEFT,
    }

    def __init__(self, host: str, port: int) -> None:
        super().__init__()

        self.host: str = host
        self.port: int = port
        self.client_thread: QThread = QThread()
        self.client: RecorderClient = RecorderClient(self.host, self.port)
        self.client.data_received.connect(self.on_snapshot_received)
        self.client.moveToThread(self.client_thread)
        self.client_thread.started.connect(self.client.start)
        self.close_signal.connect(self.client.shutdown)
        self.send_signal.connect(self.client.send_command)

        # NOTE(review): this path is relative to the working directory,
        # unlike SAVE_DIR which is resolved from this file's location.
        uic.load_ui.loadUi("src/recorder.ui", self)

        # Each direction button drives while held and stops on release.
        # `control=control` binds the value now (avoids late-binding closures).
        for button, control in (
            (self.forwardButton, CarControl.FORWARD),
            (self.backwardButton, CarControl.BACKWARD),
            (self.rightButton, CarControl.RIGHT),
            (self.leftButton, CarControl.LEFT),
        ):
            button.pressed.connect(
                lambda control=control: self.on_car_controlled(control, True))
            button.released.connect(
                lambda control=control: self.on_car_controlled(control, False))

        self.recordDataButton.clicked.connect(self.toggle_record)
        self.resetButton.clicked.connect(self.rollback)

        self.bot: Optional[Bot] = None
        self.autopiloting = False

        self.autopilotButton.clicked.connect(self.toggle_autopilot)
        # Autopilot is only available once a bot has been registered.
        self.autopilotButton.setDisabled(True)

        self.saveRecordButton.clicked.connect(self.save_record)

        self.saving_worker: Optional[ThreadedSaver] = None
        self.recording = False

        self.snapshots: list[Snapshot] = []
        self.client_thread.start()

    def on_car_controlled(self, control: CarControl, active: bool):
        """Send a ControlCommand switching ``control`` on or off."""
        self.send_command(ControlCommand(control, active))

    def _control_for_key(self, event) -> Optional[CarControl]:
        """Return the car control bound to a key event, or None.

        The key text is lower-cased so that a key released while Shift or
        Caps Lock is active still maps to the control it pressed.
        """
        if not isinstance(event, QKeyEvent):
            return None
        return self.COMMAND_DIRECTIONS.get(event.text().lower())

    def keyPressEvent(self, event):  # type: ignore
        """Start driving in the direction bound to the pressed key."""
        if event.isAutoRepeat():
            return
        ctrl = self._control_for_key(event)
        if ctrl is not None:
            self.on_car_controlled(ctrl, True)

    def keyReleaseEvent(self, event):  # type: ignore
        """Stop driving in the direction bound to the released key."""
        if event.isAutoRepeat():
            return
        ctrl = self._control_for_key(event)
        if ctrl is not None:
            self.on_car_controlled(ctrl, False)

    def toggle_record(self):
        """Flip the recording state, update the button and notify the server."""
        self.recording = not self.recording
        self.recordDataButton.setText(
            "Recording..." if self.recording else "Record")
        self.send_command(RecordingCommand(self.recording))

    def rollback(self):
        """Forget the last N snapshots and restore the car to the newest kept one.

        N comes from ``forgetSnapshotNumber`` and is clamped so at least one
        snapshot is kept when any exist. With no snapshots left a ResetCommand
        is sent instead. Recording is stopped afterwards.
        """
        rollback_by: int = self.forgetSnapshotNumber.value()
        rollback_by = max(0, min(rollback_by, len(self.snapshots) - 1))

        # Slice by explicit end index: `snapshots[:-0]` would be `[]` and
        # silently discard everything when rollback_by is 0.
        self.snapshots = self.snapshots[:len(self.snapshots) - rollback_by]
        self.nbrSnapshotSaved.setText(str(len(self.snapshots)))

        if not self.snapshots:
            self.send_command(ResetCommand())
        else:
            self.send_command(ApplySnapshotCommand(self.snapshots[-1]))

        if self.recording:
            self.toggle_record()

    def toggle_autopilot(self):
        """Flip autopilot mode and update the button label."""
        self.autopiloting = not self.autopiloting
        self.autopilotButton.setText(
            "AutoPilot:\n" + ("ON" if self.autopiloting else "OFF")
        )

    def save_record(self):
        """Save collected snapshots to the first free ``record_<n>.rec`` file.

        Recording is stopped first, and the write runs on a ThreadedSaver.
        The in-memory snapshot list is cleared immediately. Does nothing if
        a save is already running or there is nothing to save.
        """
        if self.saving_worker is not None:
            print("Already saving !")
            return

        if not self.snapshots:
            print("No data to save !")
            return

        if self.recording:
            self.toggle_record()

        self.saveRecordButton.setText("Saving ...")

        self.SAVE_DIR.mkdir(parents=True, exist_ok=True)

        record_name: str = "record_%d.rec"
        fid = 0
        while (self.SAVE_DIR / (record_name % fid)).exists():
            fid += 1

        self.saving_worker = ThreadedSaver(
            self.SAVE_DIR / (record_name % fid), self.snapshots)
        self.snapshots = []
        self.nbrSnapshotSaved.setText("0")
        self.saving_worker.finished.connect(self.on_record_save_done)
        self.saving_worker.start()

    def on_record_save_done(self):
        """Release the finished saver and re-enable saving."""
        worker = self.saving_worker
        if worker is None:
            return
        # `finished` is emitted just before the thread actually exits; wait so
        # the QThread is not destroyed while still running once it is dropped.
        worker.wait()
        # Don't claim success if the saver recorded a failure.
        if getattr(worker, "error", None) is None:
            print("Recorded data saved to", worker.path)
        self.saving_worker = None
        self.saveRecordButton.setText("Save")

    @pyqtSlot(Snapshot)
    def on_snapshot_received(self, snapshot: Snapshot):
        """Store an incoming snapshot and forward it to the bot on autopilot."""
        self.snapshots.append(snapshot)
        self.nbrSnapshotSaved.setText(str(len(self.snapshots)))

        if self.autopiloting and self.bot is not None:
            self.bot.on_snapshot_received(snapshot)

    def shutdown(self):
        """Ask the network client to shut down."""
        self.close_signal.emit()

    def send_command(self, command: Command):
        """Forward ``command`` to the network client for sending."""
        self.send_signal.emit(command)

    def register_bot(self, bot: Bot):
        """Attach the bot used in autopilot mode and enable the autopilot button."""
        self.bot = bot
        self.autopilotButton.setDisabled(False)
|