aboutsummaryrefslogtreecommitdiff
path: root/driver_fw/tools/usb_test.py
diff options
context:
space:
mode:
Diffstat (limited to 'driver_fw/tools/usb_test.py')
-rw-r--r--driver_fw/tools/usb_test.py188
1 files changed, 188 insertions, 0 deletions
diff --git a/driver_fw/tools/usb_test.py b/driver_fw/tools/usb_test.py
new file mode 100644
index 0000000..0f61593
--- /dev/null
+++ b/driver_fw/tools/usb_test.py
@@ -0,0 +1,188 @@
+#!/usr/bin/env python
+
+import time
+from pprint import pprint
+from enum import Enum
+from functools import cache
+from dataclasses import dataclass, fields, astuple
+import struct
+import binascii
+
+import numpy as np
+import click
+import serial
+from cobs import cobs
+
+class CobsSerial:
+ def __init__(self, port, timeout):
+ self.ser = serial.Serial(port, timeout=timeout)
+ self.ser.flushOutput()
+ self.ser.flushInput()
+ self.ser.write(bytes([0])) # synchronize
+ self.ser.flushOutput()
+
+ def write_packet(self, data):
+ self.ser.write(cobs.encode(data))
+ self.ser.write(bytes([0]))
+ self.ser.flushOutput()
+
+ def read_packet(self):
+ data = b''
+ while (b := self.ser.read(1)):
+ if b[0] == 0:
+ break
+ data += b
+
+ if data:
+ return parse_packet(cobs.decode(data))
+ else:
+ return None
+
+ def command(self, command, args=b''):
+ self.write_packet(bytes([command.value]) + args)
+ return self.read_packet()
+
+
+class SerializableEnum(Enum):
+ def __int__(self):
+ return self.value
+
+class PacketType(SerializableEnum):
+ USBP_GET_STATUS = 0
+ USBP_GET_MEASUREMENTS = 1
+ USBP_SET_MOTOR = 2
+
+class ErrorCode(Enum):
+ ERR_SUCCESS = 0
+ ERR_TIMEOUT = 1
+ ERR_PHYSICAL_LAYER = 2
+ ERR_FRAMING = 3
+ ERR_PROTOCOL = 4
+ ERR_DMA = 5
+ ERR_BUSY = 6
+ ERR_BUFFER_OVERFLOW = 7
+ ERR_RX_OVERRUN = 8
+ ERR_TX_OVERRUN = 9
+
+class BoardConfig(Enum):
+ BCFG_UNCONFIGURED = 0
+ BCFG_DISPLAY = 1
+ BCFG_MOTOR = 2
+ BCFG_MEAS = 3
+
+class Serialized:
+ @classmethod
+ def deserialize(kls, data):
+ fields = struct.unpack(kls._struct_format(), data)
+ mapped = [cast(val) for cast, val in zip(kls._struct_casts(), fields)]
+ return kls(*mapped)
+
+ def serialize(self):
+ mapped = [uncast(val) for uncast, val in zip(self._struct_uncasts(), astuple(self))]
+ return struct.pack(self._struct_format(), *mapped)
+
+ @classmethod
+ @cache
+ def _struct_format(kls):
+ return kls._parse_fields()[0]
+
+ @classmethod
+ @cache
+ def _struct_casts(kls):
+ return kls._parse_fields()[1]
+
+ @classmethod
+ @cache
+ def _struct_uncasts(kls):
+ return kls._parse_fields()[2]
+
+ @classmethod
+ def _parse_fields(kls):
+ fmt = '<'
+ casts = []
+ uncasts = []
+ for field in fields(kls):
+ if isinstance(field.type, tuple):
+ struct_type, cast, uncast, *_ = *field.type, int
+ else:
+ struct_type, cast, uncast = field.type, int, int
+ fmt += struct_type
+ casts.append(cast)
+ uncasts.append(uncast)
+ return fmt, casts, uncasts
+
+def timestamp(value):
+ return float(value) / 1e6
+
+@dataclass
+class StatusPacket(Serialized):
+ packet_type: ('B', PacketType)
+ sys_time_us: ('Q', timestamp)
+ has_lcd: ('B', bool)
+ has_adc: ('B', bool)
+ board_config: ('B', BoardConfig)
+ bus_addr: 'B'
+ last_uart_error: ('B', ErrorCode)
+ last_uart_error_timestamp: ('Q', timestamp)
+ last_uart_rx: ('Q', timestamp)
+ last_uart_tx: ('Q', timestamp)
+ last_bus_error: ('B', ErrorCode)
+ last_bus_error_timestamp: ('Q', timestamp)
+
+@dataclass
+class MotorPacket(Serialized):
+ packet_type: ('B', PacketType)
+ speed_rpm: 'i'
+
+def parse_packet(data):
+ packet_type = PacketType(data[0])
+ if packet_type == PacketType.USBP_GET_STATUS:
+ return StatusPacket.deserialize(data)
+ if packet_type == PacketType.USBP_GET_MEASUREMENTS:
+ return MeasurementPacket.deserialize(data)
+ else:
+ raise ValueError(f'Unsupported packet type {packet_type}')
+
+@dataclass
+class MeasurementPacket(Serialized):
+ packet_type: ('B', PacketType)
+ num_channels: 'B'
+ _num_samples_a: 'I'
+ _num_samples_b: 'I'
+ _measurements_raw: ('240s', bytes)
+
+ @property
+ def measurements(self):
+ return np.frombuffer(self._measurements_raw, np.dtype(np.int32).newbyteorder('<')).reshape([2, 2, -1])
+
+ @property
+ def num_samples(self):
+ return [self._num_samples_a, self._num_samples_b]
+
+@click.group()
+def cli():
+ pass
+
+@cli.command()
+@click.argument('port')
+@click.option('--timeout', type=float, default=1)
+def probe(port, timeout):
+ ser = CobsSerial(port, timeout)
+ pprint(ser.command(PacketType.USBP_GET_STATUS))
+ while True:
+ time.sleep(0.01)
+ packet = ser.command(PacketType.USBP_GET_MEASUREMENTS)
+ for i in range(packet.num_samples[1]):
+ print(packet.measurements[1,1,i], packet.num_samples[1])
+
+@cli.command()
+@click.argument('port')
+@click.argument('speed_rpm', type=int, default=0)
+@click.option('--timeout', type=float, default=1)
+def motor(port, speed_rpm, timeout):
+ ser = CobsSerial(port, timeout)
+ packet = MotorPacket(PacketType.USBP_SET_MOTOR, speed_rpm)
+ ser.write_packet(packet.serialize())
+
+if __name__ == '__main__':
+ cli()