author     jaseg <git-bigdata-wsl-arch@jaseg.de>  2020-04-17 17:59:08 +0200
committer  jaseg <git-bigdata-wsl-arch@jaseg.de>  2020-04-17 17:59:08 +0200
commit     87ae7dfcb35d9a55950eecc2116d13d73b2b9ece (patch)
tree       090946c28eb7b9d3028844bf4e0dcc01ddb29664 /controller/fw/tools/dsss_demod_test_runner.py
parent     e505627adad7510673f983cd158016342aa1bdfc (diff)
fw simulator: WIP
Diffstat (limited to 'controller/fw/tools/dsss_demod_test_runner.py')
-rw-r--r--  controller/fw/tools/dsss_demod_test_runner.py  200
1 file changed, 184 insertions(+), 16 deletions(-)
diff --git a/controller/fw/tools/dsss_demod_test_runner.py b/controller/fw/tools/dsss_demod_test_runner.py
index 4e93d7b..e31a686 100644
--- a/controller/fw/tools/dsss_demod_test_runner.py
+++ b/controller/fw/tools/dsss_demod_test_runner.py
@@ -4,36 +4,204 @@ import os
from os import path
import subprocess
import json
+from collections import namedtuple, defaultdict
+from tqdm import tqdm
+import uuid
+import multiprocessing
+import sqlite3
+import time
+from urllib.parse import urlparse
+import functools
+import tempfile
+import itertools
import numpy as np
np.set_printoptions(linewidth=240)
+from dsss_demod_test_waveform_gen import load_noise_meas_params, load_noise_synth_params,\
+        mains_noise_measured, mains_noise_synthetic, modulate as dsss_modulate
+
+
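+# Build one demodulator test binary for the given gold code length, threshold
+# factor, decimation and symbol count. Each build goes into a fresh uuid-named
+# directory under cachedir; the build id is returned so it can be cached.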
+def build_test_binary(nbits, thf, decimation, symbols, cachedir):
+    build_id = str(uuid.uuid4())
+    builddir = path.join(cachedir, build_id)
+    os.mkdir(builddir)
+
+    cwd = path.join(path.dirname(__file__), '..')
+
+    env = os.environ.copy()
+    env['BUILDDIR'] = path.abspath(builddir)
+    env['DSSS_GOLD_CODE_NBITS'] = str(nbits)
+    env['DSSS_DECIMATION'] = str(decimation)
+    env['DSSS_THRESHOLD_FACTOR'] = str(thf)
+    env['DSSS_WAVELET_WIDTH'] = str(0.73 * decimation)
+    env['DSSS_WAVELET_LUT_SIZE'] = str(10 * decimation)
+    env['TRANSMISSION_SYMBOLS'] = str(symbols)
+
+    with open(path.join(builddir, 'make_stdout.txt'), 'w') as stdout,\
+            open(path.join(builddir, 'make_stderr.txt'), 'w') as stderr:
+        subprocess.run(['make', 'clean', os.path.abspath(path.join(builddir, 'tools/dsss_demod_test'))],
+                       env=env, cwd=cwd, check=True, stdout=stdout, stderr=stderr)
+
+    return build_id
+
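+# Resolve a noise spec of the form "meas://file" or "synth://file" to the
+# matching generator function and its parameters. The lru_cache avoids
+# re-loading the parameter file for every test run.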
+@functools.lru_cache()
+def load_noise_gen(url):
+    schema, refpath = url.split('://')
+    if not path.isabs(refpath):
+        refpath = path.abspath(path.join(path.dirname(__file__), refpath))
+
+    if schema == 'meas':
+        return mains_noise_measured, load_noise_meas_params(refpath)
+    elif schema == 'synth':
+        return mains_noise_synthetic, load_noise_synth_params(refpath)
+    else:
+        raise ValueError('Invalid schema', schema)
+
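+# Compare the decoded symbol sequence against the transmitted test data while
+# allowing for a small alignment shift in either direction, and return the
+# lowest symbol error rate found over all tried shifts.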
+def sequence_matcher(test_data, decoded, max_shift=3):
+    match_result = []
+    for shift in range(-max_shift, max_shift):
+        failures = -shift if shift < 0 else 0 # we're skipping the first $shift symbols
+        a = test_data if shift > 0 else test_data[-shift:]
+        b = decoded if shift < 0 else decoded[shift:]
+        for i, (ref, found) in enumerate(itertools.zip_longest(a, b)):
+            if ref is None: # end of signal
+                break
+            if ref != found:
+                failures += 1
+        match_result.append(failures)
+    failures = min(match_result)
+    return failures/len(test_data)
+
+ResultParams = namedtuple('ResultParams', ['nbits', 'thf', 'decimation', 'symbols', 'seed', 'amplitude', 'background'])
+
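+# Generate one noisy test waveform per amplitude step (DSSS-modulated test data
+# on top of the selected mains noise), feed it to the matching demodulator
+# binary for every threshold factor, and collect the resulting symbol error
+# rates.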
+def run_test(seed, amplitude_spec, background, nbits, decimation, symbols, thfs, lookup_binary, cachedir):
+    noise_gen, noise_params = load_noise_gen(background)
+
+    test_data = np.random.RandomState(seed=seed).randint(0, 2 * (2**nbits), symbols)
+
+    signal = np.repeat(dsss_modulate(test_data, nbits) * 2.0 - 1, decimation)
+    # We're re-using the seed here. This is not a problem.
+    noise = noise_gen(seed, len(signal), *noise_params)
+
+    amplitudes = amplitude_spec[0] * 10 ** np.linspace(0, amplitude_spec[1], amplitude_spec[2])
+    output = []
+    for amp in amplitudes:
+        with tempfile.NamedTemporaryFile(dir=cachedir) as f:
+            waveform = signal*amp + noise
+            f.write(waveform.astype('float').tobytes())
+            f.flush()
+
+            for thf in thfs:
+                cmdline = [lookup_binary(nbits, thf, decimation, symbols), f.name]
+                proc = subprocess.Popen(cmdline, stdout=subprocess.PIPE, text=True)
+                stdout, _stderr = proc.communicate()
+                if proc.returncode != 0:
+                    raise SystemError(f'Subprocess signalled error: {proc.returncode=}')
+
+                lines = stdout.splitlines()
+                matched = [ l.partition('[')[2].partition(']')[0]
+                            for l in lines if l.strip().startswith('data sequence received:') ]
+                matched = [ [ int(elem) for elem in l.split(',') ] for l in matched ]
+
+                ser = min(sequence_matcher(test_data, match) for match in matched) if matched else None
+                rpars = ResultParams(nbits, thf, decimation, symbols, seed, amp, background)
+                output.append((rpars, ser))
+                print(f'ran {rpars} {ser=} {" ".join(cmdline)}')
+    return output
+
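+# Memoising parallel job runner: parameter sets already present in the given
+# sqlite table are yielded from cache, everything else is dispatched to a
+# multiprocessing pool and written back (as JSON plus a timestamp) on completion.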
+def parallel_generator(db, table, columns, builder, param_list, desc, context={}, params_mapper=lambda *args: args):
+    with multiprocessing.Pool(multiprocessing.cpu_count()) as pool:
+        with db as conn:
+            jobs = []
+            for params in param_list:
+                found_res = conn.execute(
+                    f'SELECT result FROM {table} WHERE ({",".join(columns)}) = ({",".join("?"*len(columns))})',
+                    params_mapper(*params)).fetchone()
+
+                if found_res:
+                    yield params, json.loads(*found_res)
+
+                else:
+                    jobs.append((params, pool.apply_async(builder, params, context)))
+
+        pool.close()
+        print('Using', len(param_list) - len(jobs), 'cached jobs', flush=True)
+        with tqdm(total=len(jobs), desc=desc) as tq:
+            for params, res in jobs:
+                tq.update(1)
+                result = res.get()
+                with db as conn:
+                    conn.execute(f'INSERT INTO {table} VALUES ({"?,"*len(params)}?,?)',
+                                 (*params_mapper(*params), json.dumps(result), timestamp()))
+                yield params, result
+        pool.join()
if __name__ == '__main__':
    import argparse
    parser = argparse.ArgumentParser()
-    parser.add_argument(metavar='test_data_directory', dest='dir', help='Directory with test data .bin files')
-    default_binary = path.abspath(path.join(path.dirname(__file__), '../build/tools/dsss_demod_test'))
-    parser.add_argument(metavar='test_binary', dest='binary', nargs='?', default=default_binary)
-    parser.add_argument('-d', '--dump', help='Write raw measurements to JSON file')
+    parser.add_argument('-d', '--dump', help='Write results to JSON file')
+    parser.add_argument('-c', '--cachedir', default='dsss_test_cache', help='Directory to store build output and data in')
    args = parser.parse_args()
-    bin_files = [ path.join(args.dir, d) for d in os.listdir(args.dir) if d.lower().endswith('.bin') ]
+    DecoderParams = namedtuple('DecoderParams', ['nbits', 'thf', 'decimation', 'symbols'])
+    dec_paramses = [ DecoderParams(nbits=nbits, thf=thf, decimation=decimation, symbols=20)
+                     for nbits in [5, 6]
+                     for thf in [4.5, 4.0, 5.0]
+                     for decimation in [10, 5, 22] ]
+# dec_paramses = [ DecoderParams(nbits=nbits, thf=thf, decimation=decimation, symbols=100)
+#                  for nbits in [5, 6, 7, 8]
+#                  for thf in [1.5, 2.0, 2.5, 3.0, 3.5, 4.0, 4.5, 5.0, 5.5, 6.0, 6.5, 7.0, 7.5, 8.0, 8.5, 9.0, 9.5, 10.0]
+#                  for decimation in [1, 2, 3, 4, 5, 6, 9, 10, 11, 12, 16, 22, 30, 40, 50] ]
+
+    build_cache_dir = path.join(args.cachedir, 'builds')
+    data_cache_dir = path.join(args.cachedir, 'data')
+    os.makedirs(build_cache_dir, exist_ok=True)
+    os.makedirs(data_cache_dir, exist_ok=True)
+
+    build_db = sqlite3.connect(path.join(args.cachedir, 'build_db.sqlite3'))
+    build_db.execute('CREATE TABLE IF NOT EXISTS builds (nbits, thf, decimation, symbols, result, timestamp)')
+    timestamp = lambda: int(time.time()*1000)
+
+    builds = dict(parallel_generator(build_db, table='builds', columns=['nbits', 'thf', 'decimation', 'symbols'],
+                                     builder=build_test_binary, param_list=dec_paramses, desc='Building decoders',
+                                     context=dict(cachedir=build_cache_dir)))
+    print('Done building decoders.')
+
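+    # Waveform generator parameters: 30 repetitions (seeds) for each noise
+    # background (measured trace vs. synthesized noise). The amplitude_spec
+    # tuple is (base amplitude, decades, steps) for the logarithmic amplitude
+    # sweep in run_test.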
+    GeneratorParams = namedtuple('GeneratorParams', ['seed', 'amplitude_spec', 'background'])
+    gen_params = [ GeneratorParams(rep, (5e-3, 1, 5), background)
+                   #GeneratorParams(rep, (0.05e-3, 3.5, 50), background)
+                   for rep in range(30)
+                   for background in ['meas://fmeas_export_ocxo_2day.bin', 'synth://grid_freq_psd_spl_108pt.json'] ]
+
+    data_db = sqlite3.connect(path.join(args.cachedir, 'data_db.sqlite3'))
+    data_db.execute('CREATE TABLE IF NOT EXISTS waveforms'
+                    '(seed, amplitude_spec, background, nbits, decimation, symbols, thresholds, result, timestamp)')
+
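+    # Group decoder parameters so that all threshold factors sharing one
+    # (nbits, decimation, symbols) combination are evaluated on the same
+    # generated waveform instead of regenerating it per threshold.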
+    dec_param_groups = defaultdict(lambda: [])
+    for nbits, thf, decimation, symbols in dec_paramses:
+        dec_param_groups[(nbits, decimation, symbols)].append(thf)
+    waveform_params = [ (*gp, *dp, thfs) for gp in gen_params for dp, thfs in dec_param_groups.items() ]
+    print(f'Generated {len(waveform_params)} parameter sets')
-    savedata = {}
-    for p in bin_files:
-        output = subprocess.check_output([args.binary, p], stderr=subprocess.DEVNULL)
-        measurements = np.array([ float(value) for _offset, value in [ line.split() for line in output.splitlines() ] ])
-        savedata[p] = list(measurements)
+    def lookup_binary(*params):
+        return path.join(build_cache_dir, builds[tuple(params)], 'tools/dsss_demod_test')
-        # Cut off first and last sample for mean and RMS calculations as these show boundary effects.
-        measurements = measurements[1:-1]
-        mean = np.mean(measurements)
-        rms = np.sqrt(np.mean(np.square(measurements - mean)))
+    def params_mapper(seed, amplitude_spec, background, nbits, decimation, symbols, thresholds):
+        amplitude_spec = ','.join(str(x) for x in amplitude_spec)
+        thresholds = ','.join(str(x) for x in thresholds)
+        return seed, amplitude_spec, background, nbits, decimation, symbols, thresholds
-        print(f'{path.basename(p):<60}: mean={mean:<8.4f}Hz rms={rms*1000:.3f}mHz')
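+    # Run all waveform/decoder parameter combinations (cached results are
+    # reused) and flatten the per-amplitude, per-threshold results into a
+    # single list for the optional JSON dump.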
+    results = []
+    for _params, chunk in parallel_generator(data_db, 'waveforms',
+            ['seed', 'amplitude_spec', 'background', 'nbits', 'decimation', 'symbols', 'thresholds'],
+            params_mapper=params_mapper,
+            builder=run_test,
+            param_list=waveform_params, desc='Generating waveforms',
+            context=dict(cachedir=data_cache_dir, lookup_binary=lookup_binary)):
+        results += chunk
    if args.dump:
        with open(args.dump, 'w') as f:
-            json.dump(savedata, f)
+            json.dump(results, f)