diff options
Diffstat (limited to 'tools/usb_test.py')
-rw-r--r-- | tools/usb_test.py | 188 |
1 files changed, 188 insertions, 0 deletions
diff --git a/tools/usb_test.py b/tools/usb_test.py new file mode 100644 index 0000000..0f61593 --- /dev/null +++ b/tools/usb_test.py @@ -0,0 +1,188 @@ +#!/usr/bin/env python + +import time +from pprint import pprint +from enum import Enum +from functools import cache +from dataclasses import dataclass, fields, astuple +import struct +import binascii + +import numpy as np +import click +import serial +from cobs import cobs + +class CobsSerial: + def __init__(self, port, timeout): + self.ser = serial.Serial(port, timeout=timeout) + self.ser.flushOutput() + self.ser.flushInput() + self.ser.write(bytes([0])) # synchronize + self.ser.flushOutput() + + def write_packet(self, data): + self.ser.write(cobs.encode(data)) + self.ser.write(bytes([0])) + self.ser.flushOutput() + + def read_packet(self): + data = b'' + while (b := self.ser.read(1)): + if b[0] == 0: + break + data += b + + if data: + return parse_packet(cobs.decode(data)) + else: + return None + + def command(self, command, args=b''): + self.write_packet(bytes([command.value]) + args) + return self.read_packet() + + +class SerializableEnum(Enum): + def __int__(self): + return self.value + +class PacketType(SerializableEnum): + USBP_GET_STATUS = 0 + USBP_GET_MEASUREMENTS = 1 + USBP_SET_MOTOR = 2 + +class ErrorCode(Enum): + ERR_SUCCESS = 0 + ERR_TIMEOUT = 1 + ERR_PHYSICAL_LAYER = 2 + ERR_FRAMING = 3 + ERR_PROTOCOL = 4 + ERR_DMA = 5 + ERR_BUSY = 6 + ERR_BUFFER_OVERFLOW = 7 + ERR_RX_OVERRUN = 8 + ERR_TX_OVERRUN = 9 + +class BoardConfig(Enum): + BCFG_UNCONFIGURED = 0 + BCFG_DISPLAY = 1 + BCFG_MOTOR = 2 + BCFG_MEAS = 3 + +class Serialized: + @classmethod + def deserialize(kls, data): + fields = struct.unpack(kls._struct_format(), data) + mapped = [cast(val) for cast, val in zip(kls._struct_casts(), fields)] + return kls(*mapped) + + def serialize(self): + mapped = [uncast(val) for uncast, val in zip(self._struct_uncasts(), astuple(self))] + return struct.pack(self._struct_format(), *mapped) + + @classmethod + @cache + def _struct_format(kls): + return kls._parse_fields()[0] + + @classmethod + @cache + def _struct_casts(kls): + return kls._parse_fields()[1] + + @classmethod + @cache + def _struct_uncasts(kls): + return kls._parse_fields()[2] + + @classmethod + def _parse_fields(kls): + fmt = '<' + casts = [] + uncasts = [] + for field in fields(kls): + if isinstance(field.type, tuple): + struct_type, cast, uncast, *_ = *field.type, int + else: + struct_type, cast, uncast = field.type, int, int + fmt += struct_type + casts.append(cast) + uncasts.append(uncast) + return fmt, casts, uncasts + +def timestamp(value): + return float(value) / 1e6 + +@dataclass +class StatusPacket(Serialized): + packet_type: ('B', PacketType) + sys_time_us: ('Q', timestamp) + has_lcd: ('B', bool) + has_adc: ('B', bool) + board_config: ('B', BoardConfig) + bus_addr: 'B' + last_uart_error: ('B', ErrorCode) + last_uart_error_timestamp: ('Q', timestamp) + last_uart_rx: ('Q', timestamp) + last_uart_tx: ('Q', timestamp) + last_bus_error: ('B', ErrorCode) + last_bus_error_timestamp: ('Q', timestamp) + +@dataclass +class MotorPacket(Serialized): + packet_type: ('B', PacketType) + speed_rpm: 'i' + +def parse_packet(data): + packet_type = PacketType(data[0]) + if packet_type == PacketType.USBP_GET_STATUS: + return StatusPacket.deserialize(data) + if packet_type == PacketType.USBP_GET_MEASUREMENTS: + return MeasurementPacket.deserialize(data) + else: + raise ValueError(f'Unsupported packet type {packet_type}') + +@dataclass +class MeasurementPacket(Serialized): + packet_type: ('B', PacketType) + num_channels: 'B' + _num_samples_a: 'I' + _num_samples_b: 'I' + _measurements_raw: ('240s', bytes) + + @property + def measurements(self): + return np.frombuffer(self._measurements_raw, np.dtype(np.int32).newbyteorder('<')).reshape([2, 2, -1]) + + @property + def num_samples(self): + return [self._num_samples_a, self._num_samples_b] + +@click.group() +def cli(): + pass + +@cli.command() +@click.argument('port') +@click.option('--timeout', type=float, default=1) +def probe(port, timeout): + ser = CobsSerial(port, timeout) + pprint(ser.command(PacketType.USBP_GET_STATUS)) + while True: + time.sleep(0.01) + packet = ser.command(PacketType.USBP_GET_MEASUREMENTS) + for i in range(packet.num_samples[1]): + print(packet.measurements[1,1,i], packet.num_samples[1]) + +@cli.command() +@click.argument('port') +@click.argument('speed_rpm', type=int, default=0) +@click.option('--timeout', type=float, default=1) +def motor(port, speed_rpm, timeout): + ser = CobsSerial(port, timeout) + packet = MotorPacket(PacketType.USBP_SET_MOTOR, speed_rpm) + ser.write_packet(packet.serialize()) + +if __name__ == '__main__': + cli() |