#!/usr/bin/env python
# coding: utf-8

# # ROCOF test waveform library
# 
# This is a re-implementation of the rate-of-change-of-frequency (ROCOF) test waveforms described in https://zenodo.org/record/3559798
# 
# **This file is exported as a python module and loaded from other notebooks here, so please make sure to re-export when changing it.**

# In[ ]:


import math
import itertools

import numpy as np
from scipy import signal
from matplotlib import pyplot as plt


# In[ ]:


# Select the interactive notebook plotting backend when running under IPython.
# `get_ipython` only exists inside an IPython/Jupyter session; importing this
# module from a plain Python process must not fail just because the magic is
# unavailable, so the call is best-effort.
try:
    get_ipython().run_line_magic('matplotlib', 'notebook')
except NameError:
    pass


# In[ ]:


def sample_waveform(generator, duration:"s"=10, sampling_rate:"sp/s"=10000, frequency:"Hz"=50):
    """Sample a six-channel test waveform produced by *generator*.

    Builds the per-sample fundamental phase ``omega_t`` (shape ``(samples, 6)``,
    channel ``k`` offset by ``k * 60`` degrees) and its sine, then hands both to
    ``generator`` together with the sampling parameters.

    Parameters
    ----------
    generator : callable
        Called as ``generator(omega_t, fundamental, sampling_rate=...,
        duration=..., frequency=...)``. Its result must have at least two
        axes, presumably ``(samples, channels)``.
    duration : float
        Waveform length in seconds.
    sampling_rate : float
        Samples per second.
    frequency : float
        Nominal fundamental frequency in Hz.

    Returns
    -------
    numpy.ndarray
        The generator output with its first two axes swapped.
    """
    samples = int(duration*sampling_rate)
    channel_phases = np.linspace(0, 2*np.pi, 6, endpoint=False)
    # Sample k lies at exactly t = k / sampling_rate. A linspace that includes
    # the end point would stretch the step by samples/(samples-1) and so shift
    # the effective fundamental frequency slightly away from `frequency`.
    t = np.arange(samples) / sampling_rate
    omega_t = channel_phases + 2*np.pi*frequency*t[:, np.newaxis]
    fundamental = np.sin(omega_t)
    return generator(omega_t, fundamental, sampling_rate=sampling_rate, duration=duration, frequency=frequency).swapaxes(0, 1)


# In[ ]:


def gen_harmonics(amplitudes, phases=()):
    """Build a generator that adds integer harmonics to the fundamental.

    Parameters
    ----------
    amplitudes : iterable of float
        Amplitudes of harmonics 2, 3, 4, ... in that order.
    phases : iterable of float, optional
        Phase offsets (radians) aligned with *amplitudes*. The sequence may be
        shorter than *amplitudes*; missing or ``None`` entries mean 0.

    Returns
    -------
    callable
        ``gen(omega_t, fundamental, **_)`` returning
        ``fundamental + sum_i a_i * sin(p_i + i * omega_t)``.

    Raises
    ------
    ValueError
        If more phases than amplitudes are supplied.
    """
    # Materialise once so one-shot iterables (e.g. generators) are not
    # exhausted by the first call of the returned generator.
    amplitudes = list(amplitudes)
    phases = [0 if p is None else p for p in phases]
    if len(phases) > len(amplitudes):
        raise ValueError("gen_harmonics: more phases than amplitudes")
    phases += [0] * (len(amplitudes) - len(phases))

    def gen(omega_t, fundamental, **_):
        # Harmonic `order` also scales each channel's phase offset, which is
        # how harmonics rotate in a multi-phase system.
        return fundamental + np.sum([
            a*np.sin(p + order*omega_t)
            for order, (a, p) in enumerate(zip(amplitudes, phases), start=2)
        ], axis=0)
    return gen

def test_harmonics():
    """Harmonic-distortion test waveform (harmonics 2 through 13)."""
    # Amplitudes of harmonics 2, 3, ..., 13, in order.
    harmonic_amplitudes = [
        0.02, 0.05, 0.01, 0.06, 0.005, 0.05,
        0.005, 0.015, 0.005, 0.035, 0.005, 0.003,
    ]
    return gen_harmonics(harmonic_amplitudes)


# In[ ]:


def gen_interharmonic(amplitudes, ih=(), ih_phase=()):
    """Build a generator that adds interharmonics to the fundamental.

    Parameters
    ----------
    amplitudes : iterable of float
        Amplitude of each interharmonic component.
    ih : iterable of float
        Frequency of each component as a (non-integer) multiple of the
        fundamental; one entry per amplitude is required.
    ih_phase : iterable of float, optional
        Phase offsets (radians). The sequence may be shorter than *amplitudes*;
        missing or ``None`` entries mean 0.

    Returns
    -------
    callable
        ``gen(omega_t, fundamental, **_)`` returning
        ``fundamental + sum_k a_k * sin(omega_t * m_k + p_k)``.

    Raises
    ------
    ValueError
        If ``ih`` does not have one entry per amplitude, or more phases than
        amplitudes are given.
    """
    # Materialise once so one-shot iterables survive repeated calls.
    amplitudes = list(amplitudes)
    multiples = list(ih)
    offsets = [0 if p is None else p for p in ih_phase]
    if len(multiples) != len(amplitudes):
        raise ValueError("gen_interharmonic: need exactly one ih multiple per amplitude")
    if len(offsets) > len(amplitudes):
        raise ValueError("gen_interharmonic: more phases than amplitudes")
    offsets += [0] * (len(amplitudes) - len(offsets))

    def gen(omega_t, fundamental, **_):
        return fundamental + np.sum([
            a*np.sin(omega_t * m + p)
            for a, m, p in zip(amplitudes, multiples, offsets)
        ], axis=0)
    return gen

def test_interharmonics():
    """Single interharmonic at 15.01401x the fundamental, amplitude 0.1, phase pi."""
    amplitude, multiple, phase = 0.1, 15.01401, np.pi
    return gen_interharmonic([amplitude], [multiple], [phase])


# In[ ]:


def gen_noise(amplitude=0.2, fmax:'Hz'=4.9e3, fmin:'Hz'=100, filter_order=6):
    """Build a generator that adds band-limited Gaussian noise to the fundamental.

    Parameters
    ----------
    amplitude : float
        Standard deviation of the white noise *before* band-pass filtering.
    fmax, fmin : float
        Pass band edges in Hz. ``fmax`` is clamped just below Nyquist.
    filter_order : int
        Butterworth order passed to ``scipy.signal.butter``.

    Returns
    -------
    callable
        ``gen(omega_t, fundamental, sampling_rate, **_)``. Noise is drawn per
        sample and channel and filtered along axis 0.
    """
    def gen(omega_t, fundamental, sampling_rate, **_):
        noise = np.random.normal(0, amplitude, fundamental.shape)
        # butter() requires the band edge to be strictly below fs/2.
        upper = min(fmax, sampling_rate//2-1)
        # Use second-order sections: the [b, a] polynomial form of a high-order
        # band-pass is numerically fragile, especially with fmin << fs.
        sos = signal.butter(filter_order,
                            [fmin, upper],
                            btype='bandpass',
                            fs=sampling_rate,
                            output='sos')
        return fundamental + signal.sosfilt(sos, noise, axis=0)
    return gen

def test_noise():
    """Band-limited noise test using gen_noise's default parameters."""
    noise_generator = gen_noise()
    return noise_generator

def test_noise_loud():
    """Louder, wider-band noise: amplitude 0.5 with the lower band edge at 10 Hz."""
    noise_generator = gen_noise(fmin=10, amplitude=0.5)
    return noise_generator


# In[406]:


def gen_steps(size_amplitude=0.1, size_phase=0.1*np.pi, steps_per_sec=1):
    """Build a generator of random amplitude and/or phase steps.

    ``n = int(steps_per_sec * duration)`` random step instants are drawn. Each
    interval between consecutive instants gets its own per-channel amplitude
    ~ N(1, size_amplitude) and phase offset ~ N(0, size_phase) (radians).
    Before the first and from the last instant on, the amplitude is 1 and the
    phase offset is 0.

    Returns
    -------
    callable
        ``gen(omega_t, fundamental, duration, **_)`` returning
        ``amplitude * sin(omega_t + phase)``. ``fundamental`` is ignored and
        ``omega_t`` is not modified.
    """
    def gen(omega_t, fundamental, duration, **_):
        n = int(steps_per_sec * duration)
        # Sort so consecutive (start, end) pairs describe ordered, non-empty,
        # non-overlapping segments; unsorted pairs would mostly be empty.
        indices = np.sort(np.random.randint(0, len(omega_t), n))
        level_shape = (n,) + np.shape(omega_t)[1:]
        amplitudes = np.random.normal(1, size_amplitude, level_shape)
        phases = np.random.normal(0, size_phase, level_shape)
        amplitude = np.ones(np.shape(omega_t))
        phase = np.zeros(np.shape(omega_t))
        for start, end, a, p in zip(indices, indices[1:], amplitudes, phases):
            phase[start:end] = p
            amplitude[start:end] = a
        # Add the offsets out of place so the caller's omega_t is untouched.
        return amplitude*np.sin(omega_t + phase)
    return gen

def test_amplitude_steps():
    """Random amplitude steps (levels ~ N(1, 0.4)) with no phase steps."""
    return gen_steps(size_phase=0, size_amplitude=0.4)

def test_phase_steps():
    """Random phase steps (offsets ~ N(0, 0.1) rad) with amplitude held at 1."""
    return gen_steps(size_phase=0.1, size_amplitude=0)

def test_amplitude_and_phase_steps():
    """Simultaneous amplitude (sd 0.2) and phase (sd 0.07 rad) steps."""
    return gen_steps(size_phase=0.07, size_amplitude=0.2)


# In[418]:


def step_gen(shape, stdev, duration, steps_per_sec=1.0, mean=0.0):
    """Return a random piecewise-constant offset signal.

    ``n = int(steps_per_sec * duration)`` step instants are drawn uniformly
    over the samples. Each interval between consecutive instants is filled
    with its own per-channel level ~ N(mean, stdev). Outside
    [first instant, last instant) the signal is 0.

    Parameters
    ----------
    shape : (int, int)
        ``(samples, channels)`` of the result.
    stdev, mean : float
        Distribution of the step levels.
    duration : float
        Signal length in seconds (only used to derive the step count).
    steps_per_sec : float
        Average number of step instants per second.

    Returns
    -------
    numpy.ndarray
        Array of the given shape.
    """
    samples, channels = shape
    n = int(steps_per_sec * duration)
    # Sort so consecutive pairs form ordered, non-empty segments.
    indices = np.sort(np.random.randint(0, samples, n))
    levels = np.random.normal(mean, stdev, (n, channels))
    out = np.zeros(shape)
    for start, end, level in zip(indices, indices[1:], levels):
        out[start:end] = level
    return out

def gen_chirp(fmin, fmax, period, dwell_time=1.0, amplitude=None, phase_steps=None):
    """Build a generator for a repeating triangular frequency sweep.

    The instantaneous frequency holds ``fmin`` for ``dwell_time`` seconds. It
    then ramps linearly from ``fmin`` to ``fmax`` over ``period`` seconds and
    back to ``fmin`` over another ``period`` seconds, repeating. The phase is
    computed in closed form, so it is continuous at every turning point for any
    choice of parameters.

    Parameters
    ----------
    fmin, fmax : float
        Sweep end frequencies in Hz.
    period : float
        Duration of one ramp (half a triangle) in seconds; must be positive.
    dwell_time : float
        Time at ``fmin`` before the first ramp, in seconds.
    amplitude : float or None
        ``None``: return the swept sinusoid alone. Otherwise return
        ``fundamental + amplitude * sweep``.
    phase_steps : (stdev, steps_per_sec) or None
        If given, random piecewise-constant phase offsets from ``step_gen``
        are added to the sweep phase.

    Returns
    -------
    callable
        ``gen(omega_t, fundamental, sampling_rate, duration, **_)``. ``omega_t``
        is ignored, the sample count is ``len(fundamental)``, and the output
        has six channels offset by 60 degrees.

    Raises
    ------
    ValueError
        If ``period`` is not positive.
    """
    if period <= 0:
        raise ValueError("gen_chirp: period must be positive")

    rate = (fmax - fmin) / period               # sweep rate in Hz/s
    leg_phase = np.pi * period * (fmin + fmax)  # phase gained over one ramp
    cycle_phase = 2 * leg_phase                 # phase gained over one up+down cycle

    def gen(omega_t, fundamental, sampling_rate, duration, **_):
        samples = len(fundamental)
        channel_phases = np.linspace(0, 2*np.pi, 6, endpoint=False)
        # Sample k is taken at t = k / sampling_rate (no duplicated samples).
        t = np.arange(samples) / sampling_rate

        # Position within the repeating cycle that starts after the dwell.
        tau = np.maximum(t - dwell_time, 0.0)
        cycles, u = np.divmod(tau, 2*period)
        v = u - period
        within_cycle = np.where(
            u < period,
            2*np.pi*(fmin*u + rate/2*u**2),              # ramp fmin -> fmax
            leg_phase + 2*np.pi*(fmax*v - rate/2*v**2),  # ramp fmax -> fmin
        )
        sweep = 2*np.pi*fmin*dwell_time + cycles*cycle_phase + within_cycle
        phase = np.where(t < dwell_time, 2*np.pi*fmin*t, sweep)

        data = channel_phases + phase[:, np.newaxis]

        if phase_steps:
            step_amplitude, steps_per_sec = phase_steps
            data = data + step_gen(data.shape, step_amplitude, duration, steps_per_sec)

        if amplitude is None:
            return np.sin(data)
        return fundamental + amplitude*np.sin(data)
    return gen

def test_close_interharmonics_and_flicker():
    """Interharmonic sweeping between 90 and 150 Hz (10 s ramps, 1 s dwell), added at amplitude 0.1."""
    return gen_chirp(fmin=90.0, fmax=150.0, period=10, dwell_time=1, amplitude=0.1)

def test_off_frequency():
    """Fundamental swept between 48 and 52 Hz (10 s ramps, 1 s dwell)."""
    # A faster variant with 0.25 s ramps was previously tried here:
    # gen_chirp(48.0, 52.0, 0.25, 1)
    return gen_chirp(fmin=48.0, fmax=52.0, period=10, dwell_time=1)

def test_sweep_phase_steps():
    """48-52 Hz sweep (10 s ramps, 1 s dwell) with random phase steps (sd 0.1, 1 per second)."""
    # A faster variant with 0.25 s ramps was previously tried here:
    # gen_chirp(48.0, 52.0, 0.25, 1, phase_steps=(0.1, 1))
    step_settings = (0.1, 1)
    return gen_chirp(fmin=48.0, fmax=52.0, period=10, dwell_time=1, phase_steps=step_settings)


# In[ ]:


# Every test-waveform factory defined above, in definition order.
all_tests = [
    test_harmonics,
    test_interharmonics,
    test_noise,
    test_noise_loud,
    test_amplitude_steps,
    test_phase_steps,
    test_amplitude_and_phase_steps,
    test_close_interharmonics_and_flicker,
    test_off_frequency,
    test_sweep_phase_steps,
]