#!/usr/bin/env python
"""Command-line tool that talks to a board over a COBS-framed serial link."""
import time
from pprint import pprint
from enum import Enum
from functools import cache
from dataclasses import dataclass, fields, astuple
import struct
import binascii

import numpy as np
import click
import serial
from cobs import cobs


class CobsSerial:
    """Send and receive COBS-framed packets over a serial port.

    On the wire each frame is the COBS encoding of the payload followed by a
    single zero byte, which acts as the frame delimiter.

    Usable as a context manager; the port is closed on exit.
    """

    def __init__(self, port, timeout):
        """Open ``port`` with the given pySerial read timeout (seconds)."""
        self.ser = serial.Serial(port, timeout=timeout)
        # Discard stale data from any previous session in both directions.
        # (flushOutput/flushInput are deprecated aliases of these calls.)
        self.ser.reset_output_buffer()
        self.ser.reset_input_buffer()
        # Synchronize: send a lone delimiter. Presumably this terminates any
        # partial frame the device is holding -- TODO confirm with firmware.
        self.ser.write(bytes([0]))
        # flush() waits for the byte to go out; the original flushOutput()
        # call would have *discarded* it instead.
        self.ser.flush()

    def close(self):
        """Close the underlying serial port."""
        self.ser.close()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc, tb):
        self.close()

    def write_packet(self, data):
        """COBS-encode ``data``, append the zero delimiter and transmit it."""
        self.ser.write(cobs.encode(data) + bytes([0]))
        # Block until the frame has been written instead of dropping it.
        self.ser.flush()

    def read_packet(self):
        """Read one frame and return it parsed by :func:`parse_packet`.

        Returns ``None`` if the read times out before a delimiter arrives, or
        if the frame is empty (e.g. a bare synchronization delimiter).
        """
        frame = bytearray()
        while True:
            byte = self.ser.read(1)
            if not byte:
                # Timed out mid-frame: a truncated frame cannot be decoded
                # reliably, so report "no packet" rather than garbage.
                return None
            if byte[0] == 0:
                break
            frame += byte
        if not frame:
            return None
        return parse_packet(cobs.decode(bytes(frame)))

    def command(self, command, args=b''):
        """Send a one-byte ``command`` id plus ``args`` and return the parsed reply.

        Returns ``None`` when no complete reply frame arrives before timeout.
        """
        self.write_packet(bytes([command.value]) + args)
        return self.read_packet()


class SerializableEnum(Enum):
    """Enum whose members convert to their wire value via ``int()``.

    ``struct.pack`` cannot take an Enum member directly; :class:`Serialized`
    uses ``int`` as the default uncast, which relies on ``__int__`` here.
    """

    def __int__(self):
        return self.value


class PacketType(SerializableEnum):
    USBP_GET_STATUS = 0
    USBP_GET_MEASUREMENTS = 1
    USBP_SET_MOTOR = 2


# Derives from SerializableEnum so that StatusPacket.serialize() can pack it.
class ErrorCode(SerializableEnum):
    ERR_SUCCESS = 0
    ERR_TIMEOUT = 1
    ERR_PHYSICAL_LAYER = 2
    ERR_FRAMING = 3
    ERR_PROTOCOL = 4
    ERR_DMA = 5
    ERR_BUSY = 6
    ERR_BUFFER_OVERFLOW = 7
    ERR_RX_OVERRUN = 8
    ERR_TX_OVERRUN = 9


# Derives from SerializableEnum so that StatusPacket.serialize() can pack it.
class BoardConfig(SerializableEnum):
    BCFG_UNCONFIGURED = 0
    BCFG_DISPLAY = 1
    BCFG_MOTOR = 2
    BCFG_MEAS = 3


class Serialized:
    """Mixin for dataclasses that map 1:1 onto a little-endian ``struct`` layout.

    Each dataclass field is annotated with either a bare ``struct`` format
    code (e.g. ``'B'``) or a tuple ``(format, cast[, uncast])``:

    * ``cast`` turns the raw unpacked value into the stored Python value
      (default ``int``);
    * ``uncast`` turns the stored value back into something ``struct.pack``
      accepts (default ``int``).

    Annotations must be real objects, not strings, so this module must not
    use ``from __future__ import annotations``.
    """

    @classmethod
    def deserialize(kls, data):
        """Build an instance from a buffer of exactly the struct's size.

        Raises:
            struct.error: if ``len(data)`` does not match the layout.
        """
        raw = struct.unpack(kls._struct_format(), data)
        return kls(*(cast(value) for cast, value in zip(kls._struct_casts(), raw)))

    def serialize(self):
        """Pack this instance into its wire representation."""
        raw = (uncast(value) for uncast, value in zip(self._struct_uncasts(), astuple(self)))
        return struct.pack(self._struct_format(), *raw)

    @classmethod
    def _struct_format(kls):
        return kls._parse_fields()[0]

    @classmethod
    def _struct_casts(kls):
        return kls._parse_fields()[1]

    @classmethod
    def _struct_uncasts(kls):
        return kls._parse_fields()[2]

    @classmethod
    @cache
    def _parse_fields(kls):
        """Return ``(format, casts, uncasts)`` for ``kls``; computed once per class."""
        fmt = '<'
        casts = []
        uncasts = []
        for field in fields(kls):
            spec = field.type if isinstance(field.type, tuple) else (field.type,)
            # Pad missing cast/uncast with int; extra tuple items are ignored.
            struct_type, cast, uncast = (*spec, int, int)[:3]
            fmt += struct_type
            casts.append(cast)
            uncasts.append(uncast)
        # Tuples, because the cached result is shared by every caller.
        return fmt, tuple(casts), tuple(uncasts)


def timestamp(value):
    """Convert a raw microsecond count to seconds as a float."""
    return float(value) / 1e6


def timestamp_us(seconds):
    """Inverse of :func:`timestamp`: convert seconds back to an integer microsecond count."""
    return round(seconds * 1e6)


@dataclass
class StatusPacket(Serialized):
    """Reply to ``USBP_GET_STATUS``."""

    packet_type: ('B', PacketType)
    sys_time_us: ('Q', timestamp, timestamp_us)
    has_lcd: ('B', bool)
    has_adc: ('B', bool)
    board_config: ('B', BoardConfig)
    bus_addr: 'B'
    last_uart_error: ('B', ErrorCode)
    last_uart_error_timestamp: ('Q', timestamp, timestamp_us)
    last_uart_rx: ('Q', timestamp, timestamp_us)
    last_uart_tx: ('Q', timestamp, timestamp_us)
    last_bus_error: ('B', ErrorCode)
    last_bus_error_timestamp: ('Q', timestamp, timestamp_us)


@dataclass
class MotorPacket(Serialized):
    """Command payload for ``USBP_SET_MOTOR`` (signed 32-bit speed)."""

    packet_type: ('B', PacketType)
    speed_rpm: 'i'


@dataclass
class MeasurementPacket(Serialized):
    """Reply to ``USBP_GET_MEASUREMENTS``."""

    packet_type: ('B', PacketType)
    num_channels: 'B'
    _num_samples_a: 'I'
    _num_samples_b: 'I'
    # bytes needs an explicit uncast; int(bytes) cannot pack it back.
    _measurements_raw: ('240s', bytes, bytes)

    @property
    def measurements(self):
        """Return the raw buffer as a little-endian int32 array of shape (2, 2, 15).

        The 240 bytes hold 60 int32 values. Axis meanings are not defined
        here; presumably axis 0 selects buffer a/b (matching ``num_samples``)
        and the last axis is the sample index -- TODO confirm with firmware.
        """
        return np.frombuffer(self._measurements_raw, dtype='<i4').reshape([2, 2, -1])

    @property
    def num_samples(self):
        """Sample counts ``[a, b]`` as reported by the device."""
        return [self._num_samples_a, self._num_samples_b]


# Reply decoders, keyed by the packet-type byte at the start of each payload.
_PACKET_CLASSES = {
    PacketType.USBP_GET_STATUS: StatusPacket,
    PacketType.USBP_GET_MEASUREMENTS: MeasurementPacket,
}


def parse_packet(data):
    """Decode a (COBS-decoded) payload into the matching packet dataclass.

    Raises:
        ValueError: if ``data`` is empty, its first byte is not a valid
            ``PacketType``, or that type has no reply decoder.
        struct.error: if the payload length does not match the packet layout.
    """
    if not data:
        raise ValueError('Empty packet')
    packet_type = PacketType(data[0])
    try:
        packet_class = _PACKET_CLASSES[packet_type]
    except KeyError:
        raise ValueError(f'Unsupported packet type {packet_type}') from None
    return packet_class.deserialize(data)


@click.group()
def cli():
    pass


@cli.command()
@click.argument('port')
@click.option('--timeout', type=float, default=1)
def probe(port, timeout):
    """Print the status packet, then poll and print measurements forever."""
    with CobsSerial(port, timeout) as ser:
        pprint(ser.command(PacketType.USBP_GET_STATUS))
        while True:
            time.sleep(0.01)
            packet = ser.command(PacketType.USBP_GET_MEASUREMENTS)
            if packet is None:
                # No complete reply before timeout; keep polling instead of crashing.
                click.echo('Timed out waiting for measurements', err=True)
                continue
            count = packet.num_samples[1]
            # Slicing stops at the end of the fixed-size buffer even if the
            # reported count is larger.
            for value in packet.measurements[1, 1, :count]:
                print(value, count)


@cli.command()
@click.argument('port')
@click.argument('speed_rpm', type=int, default=0)
@click.option('--timeout', type=float, default=1)
def motor(port, speed_rpm, timeout):
    """Send a USBP_SET_MOTOR command; no reply is read."""
    with CobsSerial(port, timeout) as ser:
        packet = MotorPacket(PacketType.USBP_SET_MOTOR, speed_rpm)
        ser.write_packet(packet.serialize())


if __name__ == '__main__':
    cli()