#!/usr/bin/env python3
"""Parameter-sweep test harness for the DSSS demodulator.

Builds one decoder binary per parameter combination, feeds it synthetic test
waveforms (DSSS signal plus measured or synthesized mains noise) at a range of
amplitudes, and records the resulting symbol error rates. Builds and results
are cached in SQLite databases keyed on the sweep parameters.
"""

import os
from os import path
import subprocess
import json
from collections import namedtuple, defaultdict
from tqdm import tqdm
import uuid
import multiprocessing
import sqlite3
import time
import functools
import tempfile
import itertools

import numpy as np
np.set_printoptions(linewidth=240)

from dsss_demod_test_waveform_gen import load_noise_meas_params, load_noise_synth_params,\
        mains_noise_measured, mains_noise_synthetic, modulate as dsss_modulate


def timestamp():
    """Milliseconds since the epoch, used to timestamp cache rows."""
    return int(time.time() * 1000)


def build_test_binary(nbits, thf, decimation, symbols, cachedir):
    build_id = str(uuid.uuid4())
    builddir = path.join(cachedir, build_id)
    os.mkdir(builddir)

    cwd = path.join(path.dirname(__file__), '..')

    env = os.environ.copy()
    env['BUILDDIR'] = path.abspath(builddir)
    env['DSSS_GOLD_CODE_NBITS'] = str(nbits)
    env['DSSS_DECIMATION'] = str(decimation)
    env['DSSS_THRESHOLD_FACTOR'] = str(thf)
    env['DSSS_WAVELET_WIDTH'] = str(0.73 * decimation)
    env['DSSS_WAVELET_LUT_SIZE'] = str(10 * decimation)
    env['TRANSMISSION_SYMBOLS'] = str(symbols)

    with open(path.join(builddir, 'make_stdout.txt'), 'w') as stdout,\
         open(path.join(builddir, 'make_stderr.txt'), 'w') as stderr:
        subprocess.run(['make', 'clean', path.abspath(path.join(builddir, 'tools/dsss_demod_test'))],
                       env=env, cwd=cwd, check=True, stdout=stdout, stderr=stderr)

    return build_id


@functools.lru_cache()
def load_noise_gen(url):
    schema, refpath = url.split('://')
    if not path.isabs(refpath):
        refpath = path.abspath(path.join(path.dirname(__file__), refpath))

    if schema == 'meas':
        return mains_noise_measured, load_noise_meas_params(refpath)
    elif schema == 'synth':
        return mains_noise_synthetic, load_noise_synth_params(refpath)
    else:
        raise ValueError('Invalid schema', schema)


def sequence_matcher(test_data, decoded, max_shift=3):
    """Return the symbol error rate of decoded against test_data, minimized
    over alignment shifts of up to max_shift symbols in either direction."""
    match_result = []
    for shift in range(-max_shift, max_shift + 1):
        failures = -shift if shift < 0 else 0  # symbols skipped by the shift count as errors
        a = test_data if shift > 0 else test_data[-shift:]
        b = decoded if shift < 0 else decoded[shift:]
        for ref, found in itertools.zip_longest(a, b):
            if ref is None:  # end of reference sequence
                break
            if ref != found:
                failures += 1
        match_result.append(failures)
    return min(match_result) / len(test_data)


ResultParams = namedtuple('ResultParams',
        ['nbits', 'thf', 'decimation', 'symbols', 'seed', 'amplitude', 'background'])


def run_test(seed, amplitude_spec, background, nbits, decimation, symbols, thfs, lookup_binary, cachedir):
    noise_gen, noise_params = load_noise_gen(background)

    test_data = np.random.RandomState(seed=seed).randint(0, 2 * (2**nbits), symbols)
    signal = np.repeat(dsss_modulate(test_data, nbits) * 2.0 - 1, decimation)

    # We're re-using the seed here. This is not a problem.
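    # (A sketch of the reasoning, not verified against the noise generator
    # internals: the data bits above are drawn from a dedicated RandomState,
    # so passing the same seed to noise_gen() below only makes the noise
    # realization reproducible per run; it is not expected to make the noise
    # track the data sequence.)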
    noise = noise_gen(seed, len(signal), *noise_params)

    # DEBUG
    # Map lsb to sign to match test program
    # test_data = (test_data>>1) * (2*(test_data&1) - 1)

    # Log-spaced amplitude sweep: amplitude_spec is (start, decades, number of points).
    amplitudes = amplitude_spec[0] * 10 ** np.linspace(0, amplitude_spec[1], amplitude_spec[2])

    output = []
    for amp in amplitudes:
        with tempfile.NamedTemporaryFile(dir=cachedir) as f:
            waveform = signal*amp + noise
            f.write(waveform.astype('float32').tobytes())
            f.flush()

            for thf in thfs:
                cmdline = [lookup_binary(nbits, thf, decimation, symbols), f.name]
                proc = subprocess.Popen(cmdline, stdout=subprocess.PIPE, text=True)
                stdout, _stderr = proc.communicate()
                if proc.returncode != 0:
                    raise RuntimeError(f'Subprocess signalled error: {proc.returncode=}')

                # Extract the bracketed lists from decoder output lines of the
                # form "data sequence received: [1, 2, 3]".
                lines = stdout.splitlines()
                matched = [ l.partition('[')[2].partition(']')[0]
                        for l in lines if l.strip().startswith('data sequence received:') ]
                matched = [ [ int(elem) for elem in l.split(',') ] for l in matched ]
                ser = min(sequence_matcher(test_data, match) for match in matched) if matched else None

                rpars = ResultParams(nbits, thf, decimation, symbols, seed, amp, background)
                output.append((rpars, ser))
    return output


def parallel_generator(db, table, columns, builder, param_list, desc, context=None,
        params_mapper=lambda *args: args, disable_cache=False):
    context = {} if context is None else context
    with multiprocessing.Pool(multiprocessing.cpu_count()) as pool:
        with db as conn:
            jobs = []
            for params in param_list:
                found_res = conn.execute(
                        f'SELECT result FROM {table} WHERE ({",".join(columns)}) = ({",".join("?"*len(columns))})',
                        params_mapper(*params)).fetchone()
                if found_res and not disable_cache:
                    yield params, json.loads(found_res[0])
                else:
                    jobs.append((params, pool.apply_async(builder, params, context)))
        pool.close()

        print('Using', len(param_list) - len(jobs), 'cached results', flush=True)
        with tqdm(total=len(jobs), desc=desc) as tq:
            for params, res in jobs:
                tq.update(1)
                result = res.get()
                # Note: with disable_cache, recomputed results are inserted as
                # additional rows next to any stale ones.
                with db as conn:
                    conn.execute(f'INSERT INTO {table} VALUES ({"?,"*len(params)}?,?)',
                            (*params_mapper(*params), json.dumps(result), timestamp()))
                yield params, result
        pool.join()


if __name__ == '__main__':
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument('-d', '--dump', help='Write results to JSON file')
    parser.add_argument('-c', '--cachedir', default='dsss_test_cache', help='Directory to store build output and data in')
    parser.add_argument('-n', '--no-cache', action='store_true', help='Disable result cache')
    args = parser.parse_args()

    DecoderParams = namedtuple('DecoderParams', ['nbits', 'thf', 'decimation', 'symbols'])
    # dec_paramses = [ DecoderParams(nbits=nbits, thf=thf, decimation=decimation, symbols=20)
    #        for nbits in [5, 6]
    #        for thf in [4.5, 4.0, 5.0]
    #        for decimation in [10, 5, 22] ]
    dec_paramses = [ DecoderParams(nbits=nbits, thf=thf, decimation=decimation, symbols=100)
            for nbits in [5, 6]
            for thf in [3.0, 3.5, 4.0, 4.5, 5.0, 5.5, 6.0, 6.5, 7.0, 7.5, 8.0, 8.5, 9.0, 9.5, 10.0]
            for decimation in [1, 2, 3, 4, 5, 6, 9, 10, 11, 12, 16, 22, 30, 40, 50] ]
    # dec_paramses = [ DecoderParams(nbits=nbits, thf=thf, decimation=decimation, symbols=100)
    #        for nbits in [5, 6, 7, 8]
    #        for thf in [1.5, 2.0, 2.5, 3.0, 3.5, 4.0, 4.5, 5.0, 5.5, 6.0, 6.5, 7.0, 7.5, 8.0, 8.5, 9.0, 9.5, 10.0]
    #        for decimation in [1, 2, 3, 4, 5, 6, 9, 10, 11, 12, 16, 22, 30, 40, 50] ]

    build_cache_dir = path.join(args.cachedir, 'builds')
    data_cache_dir = path.join(args.cachedir, 'data')
    os.makedirs(build_cache_dir, exist_ok=True)
    os.makedirs(data_cache_dir, exist_ok=True)

    build_db = sqlite3.connect(path.join(args.cachedir, 'build_db.sqlite3'))
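    # Cache design note: each sweep parameter is its own column, so
    # parallel_generator() can look up a cached row with a row-value
    # comparison (SQLite 3.15+) built dynamically from the column list,
    # equivalent to:
    #
    #   SELECT result FROM builds
    #       WHERE (nbits, thf, decimation, symbols) = (?, ?, ?, ?)
    #
    # Results are stored as JSON blobs next to a millisecond timestamp.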
    build_db.execute('CREATE TABLE IF NOT EXISTS builds (nbits, thf, decimation, symbols, result, timestamp)')

    builds = dict(parallel_generator(build_db,
            table='builds',
            columns=['nbits', 'thf', 'decimation', 'symbols'],
            builder=build_test_binary,
            param_list=dec_paramses,
            desc='Building decoders',
            context=dict(cachedir=build_cache_dir)))
    print('Done building decoders.')

    GeneratorParams = namedtuple('GeneratorParams', ['seed', 'amplitude_spec', 'background'])
    gen_params = [ GeneratorParams(rep, (5e-3, 1, 5), background)
            #GeneratorParams(rep, (0.05e-3, 3.5, 50), background)
            for rep in range(50)
            for background in ['meas://fmeas_export_ocxo_2day.bin', 'synth://grid_freq_psd_spl_108pt.json'] ]
    # gen_params = [ GeneratorParams(rep, (5e-3, 1, 5), background)
    #        for rep in range(1)
    #        for background in ['meas://fmeas_export_ocxo_2day.bin'] ]

    data_db = sqlite3.connect(path.join(args.cachedir, 'data_db.sqlite3'))
    data_db.execute('CREATE TABLE IF NOT EXISTS waveforms'
            '(seed, amplitude_spec, background, nbits, decimation, symbols, thresholds, result, timestamp)')

    # Group decoder parameters by (nbits, decimation, symbols): all threshold
    # factors in one group can share the same test waveform.
    dec_param_groups = defaultdict(list)
    for nbits, thf, decimation, symbols in dec_paramses:
        dec_param_groups[(nbits, decimation, symbols)].append(thf)
    waveform_params = [ (*gp, *dp, thfs)
            for gp in gen_params
            for dp, thfs in dec_param_groups.items() ]
    print(f'Generated {len(waveform_params)} parameter sets')

    def lookup_binary(*params):
        return path.join(build_cache_dir, builds[tuple(params)], 'tools/dsss_demod_test')

    def params_mapper(seed, amplitude_spec, background, nbits, decimation, symbols, thresholds):
        # Flatten tuple-valued parameters into strings for the SQLite cache.
        amplitude_spec = ','.join(str(x) for x in amplitude_spec)
        thresholds = ','.join(str(x) for x in thresholds)
        return seed, amplitude_spec, background, nbits, decimation, symbols, thresholds

    results = []
    for _params, chunk in parallel_generator(data_db, 'waveforms',
            ['seed', 'amplitude_spec', 'background', 'nbits', 'decimation', 'symbols', 'thresholds'],
            params_mapper=params_mapper,
            builder=run_test,
            param_list=waveform_params,
            desc='Generating waveforms',
            context=dict(cachedir=data_cache_dir, lookup_binary=lookup_binary),
            disable_cache=args.no_cache):
        results += chunk

    if args.dump:
        with open(args.dump, 'w') as f:
            json.dump(results, f)
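
# Usage sketch (hypothetical invocation; file names depend on your checkout):
#
#   ./dsss_demod_test.py --cachedir dsss_test_cache --dump dsss_test_results.json
#
# The dump is a JSON list of [result_params, ser] pairs: result_params are the
# ResultParams fields serialized as a list, and ser is the best symbol error
# rate over all decoded candidate sequences, or null if nothing was decoded.
# A minimal loader could look like:
#
#   import json
#   with open('dsss_test_results.json') as f:
#       results = [(ResultParams(*rp), ser) for rp, ser in json.load(f)]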