From 7b85ba8d4fb34e76d34a2d581e89e856aa471cf5 Mon Sep 17 00:00:00 2001 From: jaseg Date: Mon, 21 Dec 2020 16:26:57 +0100 Subject: Move fw into direct subdir --- fw/hid-dials/tools/dsss_demod_test_runner.py | 241 --------------------------- 1 file changed, 241 deletions(-) delete mode 100644 fw/hid-dials/tools/dsss_demod_test_runner.py (limited to 'fw/hid-dials/tools/dsss_demod_test_runner.py') diff --git a/fw/hid-dials/tools/dsss_demod_test_runner.py b/fw/hid-dials/tools/dsss_demod_test_runner.py deleted file mode 100644 index d3c3cfc..0000000 --- a/fw/hid-dials/tools/dsss_demod_test_runner.py +++ /dev/null @@ -1,241 +0,0 @@ -#!/usr/bin/env python3 - -import os -import sys -from os import path -import subprocess -import json -from collections import namedtuple, defaultdict -from tqdm import tqdm -import uuid -import multiprocessing -import sqlite3 -import time -from urllib.parse import urlparse -import tempfile -import itertools - -import numpy as np -np.set_printoptions(linewidth=240) - -from dsss_demod_test_waveform_gen import load_noise_gen, modulate as dsss_modulate - - -def build_test_binary(nbits, thf, decimation, symbols, cachedir): - build_id = str(uuid.uuid4()) - builddir = path.join(cachedir, build_id) - os.mkdir(builddir) - - cwd = path.join(path.dirname(__file__), '..') - - env = os.environ.copy() - env['BUILDDIR'] = path.abspath(builddir) - env['DSSS_GOLD_CODE_NBITS'] = str(nbits) - env['DSSS_DECIMATION'] = str(decimation) - env['DSSS_THRESHOLD_FACTOR'] = str(thf) - env['DSSS_WAVELET_WIDTH'] = str(0.73 * decimation) - env['DSSS_WAVELET_LUT_SIZE'] = str(10 * decimation) - env['TRANSMISSION_SYMBOLS'] = str(symbols) - - with open(path.join(builddir, 'make_stdout.txt'), 'w') as stdout,\ - open(path.join(builddir, 'make_stderr.txt'), 'w') as stderr: - subprocess.run(['make', 'clean', os.path.abspath(path.join(builddir, 'tools/dsss_demod_test'))], - env=env, cwd=cwd, check=True, stdout=stdout, stderr=stderr) - - return build_id - -def sequence_matcher(test_data, decoded, max_shift=3): - match_result = [] - for shift in range(-max_shift, max_shift): - failures = -shift if shift < 0 else 0 # we're skipping the first $shift symbols - a = test_data if shift > 0 else test_data[-shift:] - b = decoded if shift < 0 else decoded[shift:] - for i, (ref, found) in enumerate(itertools.zip_longest(a, b)): - if ref is None: # end of signal - break - if ref != found: - failures += 1 - match_result.append(failures) - failures = min(match_result) - return failures/len(test_data) - -ResultParams = namedtuple('ResultParams', ['nbits', 'thf', 'decimation', 'symbols', 'seed', 'amplitude', 'background']) - -def run_test(seed, amplitude_spec, background, nbits, decimation, symbols, thfs, lookup_binary, cachedir): - noise_gen, noise_params = load_noise_gen(background) - - test_data = np.random.RandomState(seed=seed).randint(0, 2 * (2**nbits), symbols) - - signal = np.repeat(dsss_modulate(test_data, nbits) * 2.0 - 1, decimation) - # We're re-using the seed here. This is not a problem. - noise = noise_gen(seed, len(signal), *noise_params) - amplitudes = amplitude_spec[0] * 10 ** np.linspace(0, amplitude_spec[1], amplitude_spec[2]) - # DEBUG - my_pid = multiprocessing.current_process().pid - wql = len(amplitudes) * len(thfs) - print(f'[{my_pid}] starting, got workqueue of length {wql}') - i = 0 - # Map lsb to sign to match test program - # test_data = (test_data>>1) * (2*(test_data&1) - 1) - # END DEBUG - - output = [] - for amp in amplitudes: - with tempfile.NamedTemporaryFile(dir=cachedir) as f: - waveform = signal*amp + noise - f.write(waveform.astype('float32').tobytes()) - f.flush() - # DEBUG - fcopy = f'/tmp/test-{path.basename(f.name)}' - import shutil - shutil.copy(f.name, fcopy) - # END DEBUG - - for thf in thfs: - rpars = ResultParams(nbits, thf, decimation, symbols, seed, amp, background) - cmdline = [lookup_binary(nbits, thf, decimation, symbols), f.name] - # DEBUG - starttime = time.time() - # END DEBUG - try: - proc = subprocess.run(cmdline, stdout=subprocess.PIPE, encoding='utf-8', check=True, timeout=300) - - lines = proc.stdout.splitlines() - matched = [ l.partition('[')[2].partition(']')[0] - for l in lines if l.strip().startswith('data sequence received:') ] - matched = [ [ int(elem) for elem in l.split(',') ] for l in matched ] - - ser = min(sequence_matcher(test_data, match) for match in matched) if matched else None - output.append((rpars, ser)) - # DEBUG - #print(f'[{my_pid}] ran {i}/{wql}: time={time.time() - starttime}\n {ser=}\n {rpars}\n {" ".join(cmdline)}\n {fcopy}', flush=True) - i += 1 - # END DEBUG - - except subprocess.TimeoutExpired: - output.append((rpars, None)) - # DEBUG - print(f'[{my_pid}] ran {i}/{wql}: Timeout!\n {rpars}\n {" ".join(cmdline)}\n {fcopy}', flush=True) - i += 1 - # END DEBUG - print(f'[{my_pid}] finished.') - return output - -def parallel_generator(db, table, columns, builder, param_list, desc, context={}, params_mapper=lambda *args: args, - disable_cache=False): - with multiprocessing.Pool(multiprocessing.cpu_count()) as pool: - with db as conn: - jobs = [] - for params in param_list: - found_res = conn.execute( - f'SELECT result FROM {table} WHERE ({",".join(columns)}) = ({",".join("?"*len(columns))})', - params_mapper(*params)).fetchone() - - if found_res and not disable_cache: - yield params, json.loads(*found_res) - - else: - jobs.append((params, pool.apply_async(builder, params, context))) - - pool.close() - print('Using', len(param_list) - len(jobs), 'cached jobs', flush=True) - with tqdm(total=len(jobs), desc=desc) as tq: - for i, (params, res) in enumerate(jobs): - # DEBUG - print('Got result', i, params, res) - # END DEBUG - tq.update(1) - result = res.get() - with db as conn: - conn.execute(f'INSERT INTO {table} VALUES ({"?,"*len(params)}?,?)', - (*params_mapper(*params), json.dumps(result), timestamp())) - yield params, result - pool.join() - -if __name__ == '__main__': - import argparse - parser = argparse.ArgumentParser() - parser.add_argument('-d', '--dump', help='Write results to JSON file') - parser.add_argument('-c', '--cachedir', default='dsss_test_cache', help='Directory to store build output and data in') - parser.add_argument('-n', '--no-cache', action='store_true', help='Disable result cache') - parser.add_argument('-b', '--batches', type=int, default=1, help='Number of batches to split the computation into') - parser.add_argument('-i', '--index', type=int, default=0, help='Batch index to compute') - parser.add_argument('-p', '--prepare', action='store_true', help='Prepare mode: compile runners, then exit.') - args = parser.parse_args() - - DecoderParams = namedtuple('DecoderParams', ['nbits', 'thf', 'decimation', 'symbols']) -# dec_paramses = [ DecoderParams(nbits=nbits, thf=thf, decimation=decimation, symbols=20) -# for nbits in [5, 6] -# for thf in [4.5, 4.0, 5.0] -# for decimation in [10, 5, 22] ] - dec_paramses = [ DecoderParams(nbits=nbits, thf=thf, decimation=decimation, symbols=100) - for nbits in [5, 6] - for thf in [3.0, 3.5, 4.0, 4.5, 5.0, 5.5, 6.0, 6.5, 7.0, 7.5, 8.0, 8.5, 9.0, 9.5, 10.0] - for decimation in [1, 2, 3, 4, 5, 6, 9, 10, 11, 12, 16, 22, 30, 40, 50] ] -# dec_paramses = [ DecoderParams(nbits=nbits, thf=thf, decimation=decimation, symbols=100) -# for nbits in [5, 6, 7, 8] -# for thf in [1.5, 2.0, 2.5, 3.0, 3.5, 4.0, 4.5, 5.0, 5.5, 6.0, 6.5, 7.0, 7.5, 8.0, 8.5, 9.0, 9.5, 10.0] -# for decimation in [1, 2, 3, 4, 5, 6, 9, 10, 11, 12, 16, 22, 30, 40, 50] ] - - build_cache_dir = path.join(args.cachedir, 'builds') - data_cache_dir = path.join(args.cachedir, 'data') - os.makedirs(build_cache_dir, exist_ok=True) - os.makedirs(data_cache_dir, exist_ok=True) - - build_db = sqlite3.connect(path.join(args.cachedir, 'build_db.sqlite3')) - build_db.execute('CREATE TABLE IF NOT EXISTS builds (nbits, thf, decimation, symbols, result, timestamp)') - timestamp = lambda: int(time.time()*1000) - - builds = dict(parallel_generator(build_db, table='builds', columns=['nbits', 'thf', 'decimation', 'symbols'], - builder=build_test_binary, param_list=dec_paramses, desc='Building decoders', - context=dict(cachedir=build_cache_dir))) - print('Done building decoders.') - if args.prepare: - sys.exit(0) - - GeneratorParams = namedtuple('GeneratorParams', ['seed', 'amplitude_spec', 'background']) - gen_params = [ GeneratorParams(rep, (5e-3, 1, 5), background) - #GeneratorParams(rep, (0.05e-3, 3.5, 50), background) - for rep in range(50) - for background in ['meas://fmeas_export_ocxo_2day.bin', 'synth://grid_freq_psd_spl_108pt.json'] ] -# gen_params = [ GeneratorParams(rep, (5e-3, 1, 5), background) -# for rep in range(1) -# for background in ['meas://fmeas_export_ocxo_2day.bin'] ] - - data_db = sqlite3.connect(path.join(args.cachedir, 'data_db.sqlite3')) - data_db.execute('CREATE TABLE IF NOT EXISTS waveforms' - '(seed, amplitude_spec, background, nbits, decimation, symbols, thresholds, result, timestamp)') - - 'SELECT FROM waveforms GROUP BY (amplitude_spec, background, nbits, decimation, symbols, thresholds, result)' - - dec_param_groups = defaultdict(lambda: []) - for nbits, thf, decimation, symbols in dec_paramses: - dec_param_groups[(nbits, decimation, symbols)].append(thf) - waveform_params = [ (*gp, *dp, thfs) for gp in gen_params for dp, thfs in dec_param_groups.items() ] - print(f'Generated {len(waveform_params)} parameter sets') - - # Separate out our batch - waveform_params = waveform_params[args.index::args.batches] - - def lookup_binary(*params): - return path.join(build_cache_dir, builds[tuple(params)], 'tools/dsss_demod_test') - - def params_mapper(seed, amplitude_spec, background, nbits, decimation, symbols, thresholds): - amplitude_spec = ','.join(str(x) for x in amplitude_spec) - thresholds = ','.join(str(x) for x in thresholds) - return seed, amplitude_spec, background, nbits, decimation, symbols, thresholds - - results = [] - for _params, chunk in parallel_generator(data_db, 'waveforms', - ['seed', 'amplitude_spec', 'background', 'nbits', 'decimation', 'symbols', 'thresholds'], - params_mapper=params_mapper, - builder=run_test, - param_list=waveform_params, desc='Simulating demodulation', - context=dict(cachedir=data_cache_dir, lookup_binary=lookup_binary), - disable_cache=args.no_cache): - results += chunk - - if args.dump: - with open(args.dump, 'w') as f: - json.dump(results, f) - -- cgit