#ifndef SIMULATION
#include <stm32f407xx.h>
#endif

#include <unistd.h>
#include <math.h>

#include <arm_math.h>
#include <levmarq.h>

#include "freq_meas.h"
#include "sr_global.h"
#include "simulation.h"


/* FFT window lookup table defined in generated/fmeas_fft_window.c */
extern const float * const fmeas_fft_window_table;

/* Jury-rig some declarations for these fixed-size init functions, since the ARM headers only export an
 * over-generalized variant with a variable bin count. */
extern arm_status arm_rfft_32_fast_init_f32(arm_rfft_fast_instance_f32 * S);
extern arm_status arm_rfft_64_fast_init_f32(arm_rfft_fast_instance_f32 * S);
extern arm_status arm_rfft_128_fast_init_f32(arm_rfft_fast_instance_f32 * S);
extern arm_status arm_rfft_256_fast_init_f32(arm_rfft_fast_instance_f32 * S);
extern arm_status arm_rfft_512_fast_init_f32(arm_rfft_fast_instance_f32 * S);
extern arm_status arm_rfft_1024_fast_init_f32(arm_rfft_fast_instance_f32 * S);
extern arm_status arm_rfft_2048_fast_init_f32(arm_rfft_fast_instance_f32 * S);
extern arm_status arm_rfft_4096_fast_init_f32(arm_rfft_fast_instance_f32 * S);

/* Paste three tokens into a single identifier. */
#define CONCAT(A, B, C) A ## B ## C
/* Build the name of the fixed-size RFFT init function declared above, e.g. arm_rfft_init_name(256) yields
 * arm_rfft_256_fast_init_f32. Despite its name, `nbits` is the FFT length as it appears in those function names.
 * The extra macro level lets a macro argument such as FMEAS_FFT_LEN be expanded before the tokens are pasted, so
 * that argument must expand to a single plain token (a bare number), not an expression. */
#define arm_rfft_init_name(nbits) CONCAT(arm_rfft_, nbits, _fast_init_f32)

/* Forward declarations of the Gaussian model function and its parameter gradient. Both are handed to levmarq() as
 * callbacks by adc_buf_measure_freq() and are defined at the end of this file. */
void func_gauss_grad(float *out, float *params, int x, void *userdata);
float func_gauss(float *params, int x, void *userdata);

/*
 * Estimate the dominant frequency contained in one buffer of raw ADC samples.
 *
 * Processing steps:
 *   1. Scale each sample by FMEAS_ADC_MAX, subtract 0.5 and multiply by the FFT window table.
 *   2. Run a real FFT of length FMEAS_FFT_LEN.
 *   3. Compute the magnitudes of the bins covering FMEAS_FFT_WINDOW_MIN_F_HZ .. FMEAS_FFT_WINDOW_MAX_F_HZ.
 *   4. Fit a Gaussian to those magnitudes using Levenberg-Marquardt, seeded at the strongest bin.
 *   5. Convert the fitted peak position back to Hz and check it for plausibility.
 *
 * adc_buf: FMEAS_FFT_LEN raw ADC samples. Only read.
 * out:     receives the frequency in Hz on success and NAN on every failure path.
 *
 * Returns 0 on success, the arm_status of the FFT initialisation if that fails, and -1 if the analysis window does
 * not fit into the FFT output, the fit fails, or the fitted result is implausible (including NaN).
 *
 * On non-SIMULATION builds bit 12 of GPIOA is set while the FFT and the fit are running -- presumably as a timing
 * marker for external measurement equipment.
 */
int adc_buf_measure_freq(uint16_t adc_buf[FMEAS_FFT_LEN], float *out) {
    int rc;
    float in_buf[FMEAS_FFT_LEN];
    float out_buf[FMEAS_FFT_LEN];
    /*
    DEBUG_PRINTN("    [emulated adc buf] ");
    for (size_t i=0; i<FMEAS_FFT_LEN; i++)
        DEBUG_PRINTN("%5d, ", adc_buf[i]);
    DEBUG_PRINTN("\n");
    */
    //DEBUG_PRINT("Applying window function");
    /* Use a float literal for the offset: a double literal would drag the whole expression into double precision,
     * which the single-precision FPU of the target has to emulate in software. */
    for (size_t i=0; i<FMEAS_FFT_LEN; i++) {
        in_buf[i] = ((float)adc_buf[i] / (float)FMEAS_ADC_MAX - 0.5f) * fmeas_fft_window_table[i];
    }

    //DEBUG_PRINT("Running FFT");
    arm_rfft_fast_instance_f32 fft_inst;
    if ((rc = arm_rfft_init_name(FMEAS_FFT_LEN)(&fft_inst)) != ARM_MATH_SUCCESS) {
        *out = NAN;
        return rc;
    }

    /*
    DEBUG_PRINTN("    [input] ");
    for (size_t i=0; i<FMEAS_FFT_LEN; i++)
        DEBUG_PRINTN("%010f, ", in_buf[i]);
    DEBUG_PRINTN("\n");
    */
#ifndef SIMULATION
    GPIOA->BSRR = 1<<12;
#endif
    arm_rfft_fast_f32(&fft_inst, in_buf, out_buf, 0);
#ifndef SIMULATION
    GPIOA->BSRR = 1<<12<<16;
#endif

#define FMEAS_FFT_WINDOW_MIN_F_HZ 30.0f
#define FMEAS_FFT_WINDOW_MAX_F_HZ 70.0f
    const float binsize_hz = (float)FMEAS_ADC_SAMPLING_RATE / FMEAS_FFT_LEN;
    const size_t first_bin = (int)(FMEAS_FFT_WINDOW_MIN_F_HZ / binsize_hz);
    const size_t last_bin = (int)(FMEAS_FFT_WINDOW_MAX_F_HZ / binsize_hz + 0.5f);

    /* out_buf holds FMEAS_FFT_LEN/2 interleaved (real, imag) pairs. Refuse an analysis window that would reach past
     * them instead of reading beyond the end of the buffer. */
    if (last_bin < first_bin || last_bin >= (size_t)FMEAS_FFT_LEN / 2) {
        *out = NAN;
        return -1;
    }
    const size_t nbins = last_bin - first_bin + 1;

    /*
    DEBUG_PRINT("binsize_hz=%f first_bin=%zd last_bin=%zd nbins=%zd", binsize_hz, first_bin, last_bin, nbins);
    DEBUG_PRINTN("    [bins real] ");
    for (size_t i=0; i<FMEAS_FFT_LEN/2; i+=2)
        DEBUG_PRINTN("%010f, ", out_buf[i]);
    DEBUG_PRINTN("\n    [bins imag] ");
    for (size_t i=1; i<FMEAS_FFT_LEN/2; i+=2)
        DEBUG_PRINTN("%010f, ", out_buf[i]);
    DEBUG_PRINT("\n");

    DEBUG_PRINT("Repacking FFT results");
    */
    /* Store the magnitudes of the target bins at the front of the output buffer. Doing this in place is safe because
     * the read index 2*(first_bin + i) is never smaller than the write index i. */
    for (size_t i=0; i<nbins; i++) {
        float real = out_buf[2 * (first_bin + i)];
        float imag = out_buf[2 * (first_bin + i) + 1];
        out_buf[i] = sqrtf(real*real + imag*imag);
    }

    /*
    DEBUG_PRINT("Running Levenberg-Marquardt");
    */
    LMstat lmstat;
    levmarq_init(&lmstat);

    /* Seed the fit with the strongest bin: its magnitude as amplitude and its index as center. */
    float a_max = 0.0f;
    int i_max = 0;
    for (size_t i=0; i<nbins; i++) {
        if (out_buf[i] > a_max) {
            a_max = out_buf[i];
            i_max = (int)i;
        }
    }

    /* Gaussian parameters in the order expected by func_gauss: amplitude, center (in bins), width (in bins). */
    float par[3] = {
        a_max, (float)i_max, 1.0f
    };
    /*
    DEBUG_PRINT("    par_pre={%010f, %010f, %010f}", par[0], par[1], par[2]);
    */

#ifndef SIMULATION
    GPIOA->BSRR = 1<<12;
#endif
    if (levmarq(3, par, nbins, out_buf, NULL, func_gauss, func_gauss_grad, NULL, &lmstat) < 0) {
#ifndef SIMULATION
        GPIOA->BSRR = 1<<12<<16;
#endif
        *out = NAN;
        return -1;
    }
#ifndef SIMULATION
    GPIOA->BSRR = 1<<12<<16;
#endif

    /*
    DEBUG_PRINT("    par_post={%010f, %010f, %010f}", par[0], par[1], par[2]);

    DEBUG_PRINT("done.");
    */
    /* par[1] is the peak position relative to first_bin; convert it to an absolute frequency. */
    const float res = (par[1] + first_bin) * binsize_hz;

    /* Written as a negated acceptance test on purpose: any comparison involving NaN is false, so separate
     * "out of range" comparisons would let a NaN fit result through and report it as a success. */
    if (!(par[1] >= 2 && res >= 5 && res <= 150 && par[0] >= 1)) {
        *out = NAN;
        return -1;
    }

    *out = res;
    return 0;
}

/*
 * Gaussian model evaluated at integer position x:
 *
 *     f(x) = a * exp(-(x - b)^2 / (2 c^2))
 *
 * params:   {a, b, c} = amplitude, center and width of the curve.
 * x:        position at which the model is evaluated.
 * userdata: not used.
 *
 * Returns the model value f(x).
 */
float func_gauss(float *params, int x, void *userdata) {
    UNUSED(userdata);

    const float amplitude = params[0];
    const float center = params[1];
    const float width = params[2];

    const float dist = x-center;
    const float exponent = -dist*dist / (2.0f* width*width);

    return amplitude*expf(exponent);
}

/*
 * Gradient of the Gaussian model f(x) = a * exp(-(x - b)^2 / (2 c^2)) with respect to its parameters.
 *
 * out:      receives the three partial derivatives {df/da, df/db, df/dc}.
 * params:   {a, b, c} = amplitude, center and width of the curve.
 * x:        position at which the gradient is evaluated.
 * userdata: not used.
 */
void func_gauss_grad(float *out, float *params, int x, void *userdata) {
    UNUSED(userdata);

    const float amplitude = params[0];
    const float center = params[1];
    const float width = params[2];

    const float dist = x-center;
    const float gauss = expf(-dist*dist / (2.0f * width*width));

    /* Terms shared by the derivatives with respect to center and width. */
    const float scaled_dist = amplitude*dist;
    const float width_sq = width*width;

    out[0] = gauss;                                           /* df/da */
    out[1] = scaled_dist / width_sq * gauss;                  /* df/db */
    out[2] = scaled_dist*dist / (width_sq*width) * gauss;     /* df/dc */
}