#include <unistd.h>
#include <errno.h>

#include <libopencm3/stm32/gpio.h>

#include "output.h"
#include "jtaglib.h"

#include "sr_global.h"
#include "mspdebug_wrapper.h"

/*
 * Size of the staging buffer used by flash_and_reset(). Must stay even: the
 * same buffer is also accessed as BLOCK_SIZE/2 16-bit words.
 */
#define BLOCK_SIZE 512 /* bytes */


/* Forward (tentative) declaration; the initialized definition, which wires in
 * the callback table, is at the end of this file. */
static struct jtdev sr_jtdev;


/*
 * Program an image into the target's main flash over JTAG, then reset it.
 *
 * img_start:  first target address to program; the same address is passed
 *             to read_block() to locate the image data.
 * img_len:    number of image bytes to program.
 * read_block: callback that fills `out` with up to `len` image bytes starting
 *             at `addr` and returns the byte count, or a negative errno value.
 *
 * The whole main flash is erased first (JTAG_ERASE_MAIN). An odd-length chunk
 * is padded with a zero byte so that only whole 16-bit words are written. A
 * short read leaves the remainder of that chunk unprogrammed.
 *
 * Returns 0 on success, -EPIPE if a JTAG operation failed, -EINVAL if the
 * target's JTAG ID is neither 0x89 nor 0x91, -EIO if read_block() reports
 * more bytes than were requested, or the negative value read_block() returned.
 */
int flash_and_reset(size_t img_start, size_t img_len, ssize_t (*read_block)(int addr, size_t len, uint8_t *out))
{
	union {
		uint8_t bytes[BLOCK_SIZE];
		uint16_t words[BLOCK_SIZE/2];
	} block;

	/*
	 * Nothing in this file clears the sticky `failed` flag. Without this,
	 * one failed attempt would make every later call return -EPIPE.
	 * NOTE(review): assumes jtag_init() does not reset it itself.
	 */
	sr_jtdev.failed = 0;

	/* Initialize JTAG connection */
	unsigned int jtag_id = jtag_init(&sr_jtdev);
	if (sr_jtdev.failed)
		return -EPIPE;

	/* Only targets reporting one of these JTAG IDs are accepted. */
	if (jtag_id != 0x89 && jtag_id != 0x91)
		return -EINVAL;

	/* Clear flash */
	jtag_erase_flash(&sr_jtdev, JTAG_ERASE_MAIN, 0);
	if (sr_jtdev.failed)
		return -EPIPE;

	/*
	 * Write flash in BLOCK_SIZE chunks. Iterating over an offset (rather than
	 * comparing against img_start + img_len) cannot overflow. The final
	 * request is clamped so read_block() is never asked for bytes past the
	 * end of the image.
	 */
	for (size_t off = 0; off < img_len; off += BLOCK_SIZE) {
		size_t addr = img_start + off;
		size_t want = img_len - off;

		if (want > BLOCK_SIZE)
			want = BLOCK_SIZE;

		ssize_t nin = read_block(addr, want, block.bytes);
		if (nin < 0)
			return nin;
		/* Protects `block`: the padding store below writes bytes[nin]. */
		if ((size_t)nin > want)
			return -EIO;

		if (nin & 1) { /* pad unaligned; odd nin < BLOCK_SIZE, so in bounds */
			block.bytes[nin] = 0;
			nin++;
		}

		/* Convert to little-endian */
		for (ssize_t i = 0; i < nin/2; i++)
			block.words[i] = htole(block.words[i]);

		jtag_write_flash(&sr_jtdev, addr, nin/2, block.words);
		if (sr_jtdev.failed)
			return -EPIPE;
	}

	/* FIXME: verify the written flash (e.g. read it back with jtag_read_mem()) before resetting. */

	/* Execute power on reset */
	jtag_execute_puc(&sr_jtdev);
	if (sr_jtdev.failed)
		return -EPIPE;

	return 0;
}

/* mspdebug HAL shim */

/*
 * mspdebug HAL shim: error-message sink called by jtaglib.
 * The message and its arguments are discarded; 0 is always returned.
 */
int printc_err(const char *fmt, ...)
{
	UNUSED(fmt);
	return 0;
}


/* jtdev hook: target power-on. Intentionally a no-op on this board. */
static void sr_jtdev_power_on(struct jtdev *p)
{
	UNUSED(p);
}

/* jtdev hook: connect to the target. Intentionally a no-op on this board. */
static void sr_jtdev_connect(struct jtdev *p)
{
	UNUSED(p);
}

/*
 * Indices into the gpios[] pin table, one per JTAG line that this driver
 * drives or samples. SR_NUM_GPIOS is the number of entries.
 */
enum sr_gpio_types {
	SR_GPIO_TCK = 0,	/* JTAG clock */
	SR_GPIO_TMS = 1,	/* JTAG mode select */
	SR_GPIO_TDI = 2,	/* JTAG data in */
	SR_GPIO_RST = 3,	/* target RST line */
	SR_GPIO_TST = 4,	/* target TEST line */
	SR_GPIO_TDO = 5,	/* JTAG data out (sampled) */
	SR_NUM_GPIOS = 6
};

struct {
	uint32_t port;
	uint16_t num;
} gpios[SR_NUM_GPIOS] = {
	[SR_GPIO_TCK] = {GPIOD, 8},
	[SR_GPIO_TMS] = {GPIOD, 9},
	[SR_GPIO_TDI] = {GPIOD, 10},
	[SR_GPIO_RST] = {GPIOD, 11},
	[SR_GPIO_TST] = {GPIOD, 12},
	[SR_GPIO_TDO] = {GPIOD, 13},
};

/*
 * Drive the line gpios[num] high when `out` is nonzero and low otherwise.
 * `num` is an enum sr_gpio_types index and is not range-checked.
 */
static void sr_gpio_write(int num, int out)
{
	uint32_t port = gpios[num].port;
	uint16_t pin = gpios[num].num;

	if (!out) {
		gpio_clear(port, pin);
		return;
	}
	gpio_set(port, pin);
}

/* jtdev hook: set the TCK line (nonzero `out` = high). */
static void sr_jtdev_tck(struct jtdev *p, int out)
{
	UNUSED(p);
	sr_gpio_write(SR_GPIO_TCK, out);
}

/* jtdev hook: set the TMS line (nonzero `out` = high). */
static void sr_jtdev_tms(struct jtdev *p, int out)
{
	UNUSED(p);
	sr_gpio_write(SR_GPIO_TMS, out);
}

/* jtdev hook: set the TDI line (nonzero `out` = high). */
static void sr_jtdev_tdi(struct jtdev *p, int out)
{
	UNUSED(p);
	sr_gpio_write(SR_GPIO_TDI, out);
}

/* jtdev hook: set the target RST line (nonzero `out` = high). */
static void sr_jtdev_rst(struct jtdev *p, int out)
{
	UNUSED(p);
	sr_gpio_write(SR_GPIO_RST, out);
}

/* jtdev hook: set the target TEST line (nonzero `out` = high). */
static void sr_jtdev_tst(struct jtdev *p, int out)
{
	UNUSED(p);
	sr_gpio_write(SR_GPIO_TST, out);
}

static int sr_jtdev_tdo_get(struct jtdev *p) {
    UNUSED(p);
	return gpio_get(gpios[SR_GPIO_TST].port, gpios[SR_GPIO_TST].num);
}

static void sr_jtdev_tclk(struct jtdev *p, int out) {
	UNUSED(p);
	sr_gpio_write(SR_GPIO_TST, out);
}

/*
 * jtdev hook: read back the current TCLK level (carried on the TDI line).
 * Returns exactly 0 or 1, matching the TDO hook, rather than the raw pin
 * mask that gpio_get() returns.
 */
static int sr_jtdev_tclk_get(struct jtdev *p)
{
	UNUSED(p);
	return gpio_get(gpios[SR_GPIO_TDI].port, gpios[SR_GPIO_TDI].num) ? 1 : 0;
}

/*
 * jtdev hook: emit `count` TCLK pulses on the TDI line. Each pulse drives
 * the line high and then low, with no added delay.
 */
static void sr_jtdev_tclk_strobe(struct jtdev *p, unsigned int count)
{
	const uint32_t port = gpios[SR_GPIO_TDI].port;
	const uint16_t pin = gpios[SR_GPIO_TDI].num;

	UNUSED(p);
	for (unsigned int pulse = 0; pulse < count; pulse++) {
		gpio_set(port, pin);
		gpio_clear(port, pin);
	}
}

/* jtdev hook: green status LED. No LED is driven; the request is ignored. */
static void sr_jtdev_led_green(struct jtdev *p, int out)
{
	UNUSED(out);
	UNUSED(p);
}

/* jtdev hook: red status LED. No LED is driven; the request is ignored. */
static void sr_jtdev_led_red(struct jtdev *p, int out)
{
	UNUSED(out);
	UNUSED(p);
}


/*
 * Callback table installed in sr_jtdev.f for mspdebug's jtaglib.
 * Hooks this firmware does not implement are explicitly NULL.
 */
static struct jtdev_func sr_jtdev_vtable = {
	/* Bring-up hooks (both are no-ops here) */
	.jtdev_power_on    = sr_jtdev_power_on,
	.jtdev_connect     = sr_jtdev_connect,

	/* Individual JTAG / control lines */
	.jtdev_tck         = sr_jtdev_tck,
	.jtdev_tms         = sr_jtdev_tms,
	.jtdev_tdi         = sr_jtdev_tdi,
	.jtdev_rst         = sr_jtdev_rst,
	.jtdev_tst         = sr_jtdev_tst,
	.jtdev_tdo_get     = sr_jtdev_tdo_get,

	/* TCLK handling */
	.jtdev_tclk        = sr_jtdev_tclk,
	.jtdev_tclk_get    = sr_jtdev_tclk_get,
	.jtdev_tclk_strobe = sr_jtdev_tclk_strobe,

	/* Status LEDs (ignored) */
	.jtdev_led_green   = sr_jtdev_led_green,
	.jtdev_led_red     = sr_jtdev_led_red,

	/* Not provided */
	.jtdev_open        = NULL,
	.jtdev_close       = NULL,
	.jtdev_power_off   = NULL,
	.jtdev_release     = NULL,
};

/*
 * The JTAG device instance passed to every jtaglib call in this file.
 * Only the callback table pointer is set; every other member starts at zero.
 */
static struct jtdev sr_jtdev = {
	.f = &sr_jtdev_vtable,
};