#include "global.h"
#include "serial.h"
#include "cobs.h"

#include <string.h>
#include <stdarg.h>
#include <stdlib.h>

/* Transmit ring buffer and its bookkeeping. It is accessed from the interrupt handlers in this file as well as from
 * usart_send_packet_nonblocking, hence volatile. */
volatile struct dma_tx_buf usart_tx_buf;

/* Diagnostic counters.
 * tx_overruns: outgoing packets rejected because the transmit queue had no room for them.
 * rx_overruns: USART receiver overrun (ORE) events seen by USART1_IRQHandler. */
static uint32_t tx_overruns=0, rx_overruns=0;
/* rx_framing_errors: the COBS decoder returned a negative result.
 * rx_protocol_errors: a frame decoded, but its decoder result was not 2 or its packet type was unknown. */
static uint32_t rx_framing_errors=0, rx_protocol_errors=0;

/* State of the incremental COBS decoder used on the receive path; (re-)initialized by usart_dma_reset(). */
static struct cobs_decode_state cobs_state;

/* Output buffer handed to the COBS decoder for received frames. */
static volatile uint8_t rx_buf[32];


/* Forward declarations of helpers that are defined further down in this file. */
static void usart_schedule_dma(void);
static int usart_putc_nonblocking(uint8_t c);


void usart_dma_reset() {
    usart_tx_buf.xfr_start = -1;
    usart_tx_buf.xfr_end = 0;
    usart_tx_buf.wr_pos = 0;
    usart_tx_buf.wr_idx = 0;
    usart_tx_buf.xfr_next = 0;
    usart_tx_buf.wraparound = false;
    usart_tx_buf.ack = true;

    for (size_t i=0; i<ARRAY_LEN(usart_tx_buf.packet_end); i++)
        usart_tx_buf.packet_end[i] = -1;

    cobs_decode_incremental_initialize(&cobs_state);
}

/* One-time bring-up of the serial link.
 *
 * Resets the software state, points DMA1 channel 2 at the USART1 transmit data register, sets up DMA1 channel 3 with
 * the CRC data register as its memory address, configures USART1 (interrupts, baud rate, signal inversion, TX DMA
 * requests), enables the DMA and USART interrupts in the NVIC and finally enables the USART. Neither DMA channel is
 * started here; channel 2 is started on demand by usart_schedule_dma(). */
void usart_dma_init() {
    usart_dma_reset();

    /* Configure DMA 1 Channel 2 to handle uart transmission */
    DMA1_Channel2->CPAR = (uint32_t)&(USART1->TDR);
    DMA1_Channel2->CCR = (0<<DMA_CCR_PL_Pos)
        | DMA_CCR_DIR /* memory -> peripheral, i.e. ring buffer -> TDR */
        | (0<<DMA_CCR_MSIZE_Pos) /* 8 bit */
        | (0<<DMA_CCR_PSIZE_Pos) /* 8 bit */
        | DMA_CCR_MINC /* step through the buffer, the peripheral address stays fixed */
        | DMA_CCR_TCIE; /* Enable transfer complete interrupt. */

    /* Channel 3 gets the CRC data register as its memory-side address.
     * NOTE(review): nothing in this part of the file sets CPAR/CNDTR or enables channel 3 -- presumably it is meant
     * for feeding data into the CRC unit; confirm where (or whether) it is used. */
    DMA1_Channel3->CMAR = (uint32_t)&(CRC->DR);
    DMA1_Channel3->CCR = (1<<DMA_CCR_PL_Pos)
        | (0<<DMA_CCR_MSIZE_Pos) /* 8 bit */
        | (0<<DMA_CCR_PSIZE_Pos) /* 8 bit */
        | DMA_CCR_PINC
        | DMA_CCR_TCIE; /* Enable transfer complete interrupt. */

    /* Shared interrupt of DMA channels 2 and 3, raised on transfer completion (TCIE above) and handled by
     * DMA1_Channel2_3_IRQHandler, which continues or finishes the running USART transmission.
     * NOTE(review): CMSIS NVIC_SetPriority() normally expects the plain priority level and shifts it into the
     * implemented priority bits itself. If that is the case for this target, the pre-shifted values passed here and
     * for USART1_IRQn below may not yield the intended levels -- confirm against the CMSIS header in use. */
    NVIC_EnableIRQ(DMA1_Channel2_3_IRQn);
    NVIC_SetPriority(DMA1_Channel2_3_IRQn, 2<<5);

    USART1->CR1 = /* 8-bit -> M1, M0 clear */
        /* OVER8 clear. Use default 16x oversampling */
        /* CMIF clear */
          USART_CR1_MME /* mute mode enable (per bit name) */
        /* WAKE clear */
        /* PCE, PS clear */
        | USART_CR1_RXNEIE /* Enable receive interrupt */
        /* other interrupts clear */
        | USART_CR1_TE
        | USART_CR1_RE;
    /* Baud rate divider. The values below assume a 48MHz USART clock with 16x oversampling, i.e.
     * baud rate = 48MHz / BRR. 250kBd is the active setting; the alternatives are kept for reference. */
    //USART1->BRR = 417; /* 115.2kBd */
    
    //USART1->BRR = 48; /* 1MBd */
    //USART1->BRR = 96; /* 500kBd */
    USART1->BRR = 192; /* 250kBd */
    //USART1->BRR = 208; /* 230400 */

    /* Invert the signal polarity of both the TX and the RX line (per bit names). */
    USART1->CR2 = USART_CR2_TXINV | USART_CR2_RXINV;

    USART1->CR3 |= USART_CR3_DMAT; /* TX DMA enable */

    /* Enable receive interrupt */
    NVIC_EnableIRQ(USART1_IRQn);
    NVIC_SetPriority(USART1_IRQn, 1<<5);

    /* And... go! */
    USART1->CR1 |= USART_CR1_UE;
}

/* USART1 interrupt handler: receive path of the link.
 *
 * An overrun is acknowledged and counted, and nothing else is done in that invocation. Otherwise, if a byte is
 * available, it is pushed through the incremental COBS decoder. Once the decoder reports a complete frame, the frame
 * is interpreted as a control packet:
 *   CTRL_PKT_RESET  resets the transmit state via usart_dma_reset()
 *   CTRL_PKT_ACK    marks the last packet as acknowledged and, if the TX DMA channel is idle, schedules the next one
 * Decoder errors, frames with a decoder result other than 2, and unknown packet types are only counted. */
void USART1_IRQHandler() {
    const uint32_t status = USART1->ISR;

    if (status & USART_ISR_ORE) {
        /* Receiver overrun: clear the flag, account for it, and do nothing else this time around */
        USART1->ICR = USART_ICR_ORECF;
        rx_overruns++;
        return;
    }

    if (!(status & USART_ISR_RXNE))
        return; /* no received byte pending */

    const uint8_t byte = USART1->RDR;
    const int decoded = cobs_decode_incremental(&cobs_state, (char *)rx_buf, sizeof(rx_buf), byte);

    if (decoded == 0) {
        /* frame not finished yet, wait for more bytes */

    } else if (decoded < 0) {
        rx_framing_errors++;

    } else if (decoded != 2) {
        /* a frame was completed, but 2 is the only decoder result accepted for a control packet */
        rx_protocol_errors++;

    } else {
        volatile struct ctrl_pkt *ctrl = (volatile struct ctrl_pkt *)rx_buf;

        switch (ctrl->type) {
            case CTRL_PKT_ACK:
                usart_tx_buf.ack = true;
                /* only kick the scheduler when no transfer is currently running */
                if (!(DMA1_Channel2->CCR & DMA_CCR_EN))
                    usart_schedule_dma();
                break;

            case CTRL_PKT_RESET:
                usart_dma_reset();
                break;

            default:
                rx_protocol_errors++;
                break;
        }
    }
}


/* Pick the next chunk of the transmit ring buffer and start a DMA transfer of it to the USART.
 *
 * Three cases, checked in this order:
 *   1. wraparound pending: send the second part [0, xfr_end) of a transfer that crossed the end of the buffer.
 *   2. "ack" set: start the next queued packet, or return without touching the DMA if none is queued.
 *   3. otherwise: retransmit the range recorded in xfr_start/xfr_end, preceded by a 0x00 byte.
 *
 * A single DMA transfer never crosses the end of buf->data; a range that wraps is split into two transfers, the second
 * one being triggered from the DMA transfer-complete interrupt via the wraparound flag.
 *
 * All callers in this file only call this while DMA1 channel 2 is disabled. */
void usart_schedule_dma() {
    volatile struct dma_tx_buf *buf = &usart_tx_buf;

    ssize_t xfr_start, xfr_end, xfr_len;
    if (buf->wraparound) {
        /* Second half of a wrapped transfer: everything from the start of the buffer up to xfr_end.
         * NOTE(review): this stores 0 into buf->xfr_start below, so a later retransmit of this packet would begin at
         * index 0 instead of at the packet's original start, and usart_putc_nonblocking() guards index 0 rather than
         * the original start -- confirm that this is intended. */
        buf->wraparound = false;
        xfr_start = 0;
        xfr_len = buf->xfr_end;
        xfr_end = buf->xfr_end;

    } else if (buf->ack) {
        if (buf->packet_end[buf->xfr_next] == -1)
            return; /* Nothing to transmit */

        /* From here on we are waiting for the receiver to acknowledge this packet */
        buf->ack = false;

        /* The new packet starts where the previous transfer ended; consume its slot and advance to the next one */
        xfr_start = buf->xfr_end;
        xfr_end = buf->packet_end[buf->xfr_next];
        buf->packet_end[buf->xfr_next] = -1;
        buf->xfr_next = (buf->xfr_next + 1) % ARRAY_LEN(buf->packet_end);

        if (xfr_end > xfr_start) { /* no wraparound */
            xfr_len = xfr_end - xfr_start;

        } else { /* wraparound */
            /* xfr_end == 0 means the packet ends exactly at the end of the buffer, so there is no second part */
            if (xfr_end != 0)
                buf->wraparound = true;
            xfr_len = sizeof(buf->data) - xfr_start;
        }

    } else {
        /* retransmit */
        /* First, send a zero to delimit any garbage from the following good packet */
        /* NOTE(review): TDR is written directly, without checking that the transmitter is ready -- confirm this is
         * safe at every call site. */
        USART1->TDR = 0x00;

        xfr_start = buf->xfr_start;
        xfr_end = buf->xfr_end;

        if (xfr_end > xfr_start) { /* no wraparound */
            xfr_len = xfr_end - xfr_start;

        } else { /* wraparound */
            if (xfr_end != 0)
                buf->wraparound = true;
            xfr_len = sizeof(buf->data) - xfr_start;
        }

        /* Flag the retransmission on the error LED (leds is defined elsewhere; unit of the value not visible here) */
        leds.error = 250;
    }

    /* Remember the range being sent; used for retransmission and as the write limit in usart_putc_nonblocking() */
    buf->xfr_start = xfr_start;
    buf->xfr_end = xfr_end;

    /* initiate transmission of new buffer */
    DMA1_Channel2->CMAR = (uint32_t)(buf->data + xfr_start);
    DMA1_Channel2->CNDTR = xfr_len;
    DMA1_Channel2->CCR |= DMA_CCR_EN;
}

/* Append a single byte to the transmit ring buffer without blocking.
 *
 * Returns 0 after storing the byte and advancing the write position (wrapping at the end of the buffer), or -EBUSY,
 * leaving the buffer untouched, when the write position equals the recorded transfer start (xfr_start). */
int usart_putc_nonblocking(uint8_t c) {
    volatile struct dma_tx_buf *ring = &usart_tx_buf;

    if (ring->wr_pos != ring->xfr_start) {
        ring->data[ring->wr_pos] = c;
        ring->wr_pos = (ring->wr_pos + 1) % sizeof(ring->data);
        return 0;
    }

    /* The writer has caught up with data that may still be needed for (re)transmission */
    return -EBUSY;
}


/* DMA1 channel 2/3 interrupt handler: runs when a USART transmit DMA transfer has completed.
 *
 * Clears the channel 2 transfer-complete flag, disables the channel and, if the transfer that just finished was the
 * first part of a range that wraps around the end of the ring buffer, immediately schedules the second part. */
void DMA1_Channel2_3_IRQHandler(void) {
    /* Transfer complete. IFCR is a write-to-clear flag register, so the flag bit is written directly instead of doing
     * a read-modify-write that depends on what a read of that register returns. */
    DMA1->IFCR = DMA_IFCR_CTCIF2;

    /* The channel has to be disabled before usart_schedule_dma() can program the next transfer */
    DMA1_Channel2->CCR &= ~DMA_CCR_EN;
    if (usart_tx_buf.wraparound)
        usart_schedule_dma();
}

/* Adapter giving usart_putc_nonblocking() exactly the callback signature that is passed to cobs_encode_usart(), so
 * that the function is never called through a pointer to an incompatible function type. */
static int usart_putc_char(char c) {
    return usart_putc_nonblocking((uint8_t)c);
}

/* Queue one packet for transmission without blocking.
 *
 * pkt      packet to send. Its pid, _pad and crc32 fields are filled in by this function.
 * pkt_len  packet length in bytes including headers.
 *
 * The CRC is computed over everything from the pid field up to pkt_len, then the whole packet is COBS-encoded into
 * the transmit ring buffer and its end position is recorded in the next packet slot. If no DMA transfer is running,
 * transmission is kicked off immediately.
 *
 * Returns 0 on success, -EBUSY if no packet slot is free or the ring buffer is full, or whatever other non-zero value
 * cobs_encode_usart() reports. On any failure the ring buffer write position is left as it was before the call, so a
 * rejected packet does not leave a partial frame in front of the next one. */
int usart_send_packet_nonblocking(struct ll_pkt *pkt, size_t pkt_len) {

    if (usart_tx_buf.packet_end[usart_tx_buf.wr_idx] != -1) {
        /* The next packet slot is still occupied by a packet that has not been sent yet */
        tx_overruns++;
        return -EBUSY;
    }

    pkt->pid = usart_tx_buf.wr_idx;
    pkt->_pad = usart_tx_buf.xfr_next;

    /* make the value this wonky-ass CRC implementation produces match zlib etc. */
    /* NOTE(review): each byte is stored to CRC->DR as a full-width value; confirm the CRC unit consumes that the way
     * the receiving side expects. */
    CRC->CR = CRC_CR_REV_OUT | (1<<CRC_CR_REV_IN_Pos) | CRC_CR_RESET;
    for (size_t i=offsetof(struct ll_pkt, pid); i<pkt_len; i++)
        CRC->DR = ((uint8_t *)pkt)[i];

    pkt->crc32 = ~CRC->DR;

    /* Remember where this packet starts so that a partially written frame can be discarded again */
    const ssize_t wr_pos_before = usart_tx_buf.wr_pos;

    int rc = cobs_encode_usart(usart_putc_char, (char *)pkt, pkt_len);
    if (rc) {
        /* Encoding stopped part-way (e.g. ring buffer full). Drop the bytes written so far; nothing has been
         * published in packet_end yet, so the DMA side never sees them. */
        usart_tx_buf.wr_pos = wr_pos_before;
        if (rc == -EBUSY)
            tx_overruns++;
        return rc;
    }

    /* Publish the packet: its end position marks the slot as occupied */
    usart_tx_buf.packet_end[usart_tx_buf.wr_idx] = usart_tx_buf.wr_pos;
    usart_tx_buf.wr_idx = (usart_tx_buf.wr_idx + 1) % ARRAY_LEN(usart_tx_buf.packet_end);

    leds.usb = 100;

    /* Start transmitting right away unless a transfer is already in flight */
    if (!(DMA1_Channel2->CCR & DMA_CCR_EN))
        usart_schedule_dma();
    return 0;
}