from machine import Pin, ADC, Timer, WDT, reset
import network
import time
import ujson
from uhashlib import sha256
import binascii
import math
import urequests
import ntptime


######################################################## CONFIG ########################################################
# Wifi settings:
# WIFI_ESSID = 'Your wifi ESSID (name)'
# WIFI_PSK = 'Your wifi key'
#
# Notification settings:
# NOTIFICATION_URL = 'https://automation.jaseg.de/notify/klingel'
# NOTIFICATION_SECRET = b'Your notification proxy secret for this endpoint'
#
# NOTIFICATION_COOLDOWN = 60 # how long to wait after sending a notification before sending the next, in seconds
# HEARTBEAT_INTERVAL = 60 # seconds
#
# Detection settings
# MEAN_LEN = 8 # Number of 1024-sample (1024 ms) windows averaged to determine the DC offset
# RMS_THRESHOLD = 1000  # RMS detection threshold over one 1024-sample window, in ADC counts
######################################################### END ##########################################################


def wifi_connect():
    """Join the configured wifi network unless the station is already connected.

    Activates the station interface, starts a connection attempt with
    WIFI_ESSID / WIFI_PSK and then polls the link state up to 20 times,
    sleeping 0.5 s after each unsuccessful poll. The outcome is printed;
    failure to connect does not raise.
    """
    sta = network.WLAN(network.STA_IF)
    if sta.isconnected():
        return  # nothing to do

    print('Connecting to wifi... ')
    sta.active(True)
    sta.connect(WIFI_ESSID, WIFI_PSK)

    polls_left = 20
    while polls_left > 0:
        if sta.isconnected():
            print('Wifi connected. IP config: ', sta.ifconfig())
            return
        time.sleep(0.5)
        polls_left -= 1

    print("Couldn't connect to wifi.")


# State shared between the sampling timer callback (see start_sampling) and the main loop.
buf = [0] * 1024  # Sample window; the timer callback stores one ADC reading per slot
capture = None  # Copy of the most recently completed window; stays None until a window has completed
mean = 0  # Average of the most recent per-window means (at most MEAN_LEN of them), used as DC offset
rms = 0  # RMS of the last completed window around `mean`; loop() sets it back to 0 after reading it
sample_tim = Timer(-1)  # Timer that drives the sampling callback; armed by start_sampling()

def start_sampling():
    """Set up the ADC on pin 34 and arm the periodic sampling timer.

    The timer fires with a period of 1 ms. Every tick stores one ADC reading
    in the module-level ``buf``. Whenever the window is full, the callback
    refreshes the module-level ``mean``, ``rms`` and ``capture`` values that
    the main loop consumes, then starts filling the window from the beginning.
    """
    global sample_tim
    write_idx = 0       # next slot of buf to fill
    window_total = 0    # running sum of the readings in the current window
    window_means = []   # per-window means, newest first

    adc = ADC(Pin(34))
    adc.atten(ADC.ATTN_11DB)

    def sample_cb(tim):
        global buf, mean, rms, capture
        nonlocal write_idx, window_total, window_means

        reading = adc.read()
        buf[write_idx] = reading
        window_total += reading
        write_idx += 1

        if write_idx != len(buf):
            return  # window not complete yet

        n_samples = len(buf)
        write_idx = 0

        # Prepend this window's mean and keep only the newest entries
        window_means = [window_total/n_samples] + window_means[:MEAN_LEN-1]
        window_total = 0
        mean = sum(window_means)/len(window_means)

        # RMS of the window around the smoothed mean (i.e. with DC offset removed)
        mean_square = sum( (sample-mean)**2 for sample in buf )/n_samples
        rms = math.sqrt(mean_square)
        capture = list(buf) # Snapshot, since buf keeps being overwritten

    sample_tim.init(period=1, mode=Timer.PERIODIC, callback=sample_cb) # period in ms

def uhmac(key, data):
    """Compute HMAC-SHA256 of ``data`` under ``key``.

    key:  bytes-like secret. Keys longer than the SHA-256 block size (64
          bytes) are first replaced by their SHA-256 digest, shorter keys are
          zero-padded to the block size.
    data: bytes-like message to authenticate.

    Returns the raw 32-byte digest.
    """
    blocksize = 64  # SHA-256 block size in bytes

    # Work on a private bytes copy so a bytearray passed by the caller is never modified in place
    key = bytes(key)
    if len(key) > blocksize:
        # Over-long keys are hashed down first; padding with a negative length would fail otherwise
        key = sha256(key).digest()
    key += bytes(blocksize - len(key))

    inner = sha256(bytes( b ^ 0x36 for b in key ))
    inner.update(data)

    outer = sha256(bytes( b ^ 0x5C for b in key ))
    outer.update(inner.digest())
    return outer.digest()

def unix_time():
    """Return the current time as whole seconds since the unix epoch (1970-01-01)."""
    # ESP32 counts from 2000-01-01, unix from 1970-01-01; this is the gap between the two in seconds
    epoch_offset = 946684800
    return epoch_offset + int(time.time())

def usign(secret, scope, payload=None, seq=None):
    """Build a signed JSON notification message.

    secret:  HMAC key (bytes).
    scope:   value stored under the message's 'scope' key.
    payload: value stored under the message's 'd' key.
    seq:     optional sequence number; stored under 'seq' only when not None.

    The inner message (with the current unix time under 'time') is serialized
    to JSON and authenticated with HMAC-SHA256 over its encoded bytes.

    Returns a JSON string with the serialized inner message under 'payload'
    and the hex-encoded HMAC under 'auth'.
    """
    message = {'time': unix_time(), 'scope': scope, 'd': payload}
    if seq is not None:
        message['seq'] = seq

    message_json = ujson.dumps(message)
    digest = uhmac(secret, message_json.encode())
    auth = binascii.hexlify(digest).decode()

    # Hand str (not bytes) values to the serializer: encoding bytes objects as JSON strings is not
    # something every JSON implementation supports, while str values always serialize.
    return ujson.dumps({'payload': message_json, 'auth': auth})

def notify(scope, **kwargs):
    """Send a signed notification to NOTIFICATION_URL.

    scope:  notification scope passed on to usign().
    kwargs: keyword arguments, passed to usign() as the message payload.

    Makes sure wifi is connected first. Exceptions raised while signing or
    posting propagate to the caller.
    """
    wifi_connect()
    data = usign(NOTIFICATION_SECRET, scope, kwargs)
    print(unix_time(), 'Notifying', NOTIFICATION_URL)
    response = urequests.post(NOTIFICATION_URL, data=data, headers={'Content-Type': 'application/json'})
    # The response keeps its socket open until closed; release it so repeated notifications do not leak sockets
    response.close()


def classify(trace, mean):
    """Classify a captured sample window by its zero-crossing pattern.

    trace: sequence of samples; only the second half is analysed.
    mean:  DC offset subtracted from every sample before looking at its sign.

    The lengths of the runs between sign changes are averaged over sliding
    windows, and each window average is sorted into a "low" bin (0.05..0.25)
    or a "high" bin (0.30..0.50), both exclusive. The fraction of windows in
    each bin decides the result.

    Returns 'downstairs', 'upstairs' or 'none'. A trace without any sign
    change (including an empty one) yields 'none'.
    """
    window = 32  # number of runs summed per window average (also the divisor for short tail windows)
    step = 8     # distance between the starts of consecutive windows

    runs = []
    cur_sign, cur_run = 1, 0
    for x in trace[len(trace)//2:]:
        x -= mean
        crossed_up = x > 0 and cur_sign < 0
        crossed_down = x < 0 and cur_sign >= 0
        if crossed_up or crossed_down:
            # Sign change: close the current run and start a new one
            runs.append(cur_run)
            cur_sign = 1 if crossed_up else -1
            cur_run = 0
        else:
            cur_run += 1

    # Windows near the end are shorter but are still divided by the full window length
    run_means = [sum(runs[start:start+window])/window for start in range(0, len(runs), step)]

    total = len(run_means)
    if total == 0:
        # No sign changes at all: nothing to classify (and avoids dividing by zero below)
        return 'none'

    bin_low = sum(1 for e in run_means if 0.05 < e < 0.25) / total
    bin_high = sum(1 for e in run_means if 0.30 < e < 0.50) / total

    if 0.3 < bin_low < 0.7 and 0.3 < bin_high < 0.7:
        return 'downstairs'
    elif bin_low < 0.3 and bin_high > 0.7:
        return 'upstairs'
    else:
        return 'none'
        

def format_exc(limit=None, chain=True):
    """Return a string describing the exception currently being handled.

    The result is the concatenated repr() of every item returned by
    sys.exc_info().

    limit, chain: ignored; presumably kept so the signature mirrors
    traceback.format_exc().
    """
    # sys is not imported at module level, so bring it into scope here
    import sys
    # NOTE(review): confirm that the target firmware provides sys.exc_info()
    return "".join(repr(i) for i in sys.exc_info())

def loop():
    """Main loop: detect rings, send heartbeats, resync the clock and feed the watchdog.

    Runs forever. Each iteration (roughly every 0.1 s):
      * if the last window's RMS exceeds RMS_THRESHOLD and the notification
        cooldown has passed, classifies the capture, waits for the following
        window to complete and sends a notification containing both captures;
      * sends a heartbeat every HEARTBEAT_INTERVAL seconds;
      * resyncs the clock via NTP once per day;
      * feeds the 60 s watchdog when the iteration finished without error.

    Errors are reported with an 'error' notification and counted; the count is
    decremented once per 300 s, and the device is reset when it reaches 5.
    """
    global rms, capture
    last_notification, last_heartbeat = 0, 0
    n_exc, last_exc_clear = 0, unix_time()
    last_ntp_sync = unix_time()

    wdt = WDT(timeout=60000)

    while True:
        try:
            now = unix_time()
            if (now - last_notification) > NOTIFICATION_COOLDOWN and rms > RMS_THRESHOLD:
                # Remember the value that triggered detection; rms itself is cleared below
                trigger_rms = rms
                old_capture = capture
                classification = classify(capture, mean)

                # Clear rms and wait until the sampling callback publishes the next window
                rms = 0
                while rms == 0:
                    time.sleep(0.1)
                rms = 0

                if classification in ('downstairs', 'upstairs'):
                    notify('default', classification=classification, rms=trigger_rms,
                           capture=[old_capture, capture])
                else:
                    notify('info', info_msg='Unclassified capture', rms=trigger_rms,
                           capture=[old_capture, capture])
                last_notification = now

            if (now - last_heartbeat) > HEARTBEAT_INTERVAL:
                notify('heartbeat')
                last_heartbeat = now

            if (now - last_ntp_sync) > 3600 * 24:
                wifi_connect()
                ntptime.settime()
                last_ntp_sync = now

            if (now - last_exc_clear) > 300:
                if n_exc > 0:
                    n_exc -= 1
                last_exc_clear = now

            # Only reached when the iteration completed without an exception
            wdt.feed()

        except Exception:
            n_exc += 1
            try:
                # Best effort: reporting may fail too (e.g. network down) and must not escape the
                # loop, otherwise the error counter and the reset below would never take effect
                notify('error', e=format_exc())
            except Exception:
                print('Error notification failed')
            if n_exc >= 5:
                reset()

        finally:
            time.sleep(0.1)

# Boot sequence: join wifi, set the clock via NTP, start ADC sampling, announce the boot,
# then hand over to the main loop.
# NOTE(review): the calls before loop() run without the watchdog (it is created inside loop())
# and outside loop()'s exception handling, so an exception raised here is not caught anywhere
# in this file -- confirm that this is acceptable for unattended operation.
wifi_connect()
ntptime.settime()
start_sampling()
notify('boot')
loop()