#!/usr/bin/env python3
"""Parameter sweep driver for the DSSS demodulator test binary.

The script builds one ``dsss_demod_test`` binary per decoder parameter set,
feeds each binary synthesized waveforms (modulated test data plus mains
frequency noise) at several amplitudes, and records for every run the fraction
of symbols that were decoded incorrectly.  Build ids and simulation results
are cached in sqlite databases inside the cache directory.
"""

import os
import sys
from os import path
import subprocess
import json
from collections import namedtuple, defaultdict
from tqdm import tqdm
import uuid
import multiprocessing
import shutil
import sqlite3
import time
from urllib.parse import urlparse
import functools
import tempfile
import itertools

import numpy as np
np.set_printoptions(linewidth=240)

from dsss_demod_test_waveform_gen import (
    load_noise_meas_params,
    load_noise_synth_params,
    mains_noise_measured,
    mains_noise_synthetic,
    modulate as dsss_modulate,
)


def timestamp():
    """Return the current wall-clock time in integer milliseconds."""
    return int(time.time()*1000)


def build_test_binary(nbits, thf, decimation, symbols, cachedir):
    """Compile a ``dsss_demod_test`` binary for one decoder parameter set.

    A fresh, uuid-named build directory is created below ``cachedir`` and
    ``make clean <builddir>/tools/dsss_demod_test`` is run from the parent
    directory of this script.  The parameters are handed to make through
    environment variables.  The output of make is written to
    ``make_stdout.txt`` and ``make_stderr.txt`` inside the build directory.

    Returns:
        The name of the build directory (a uuid4 string) relative to
        ``cachedir``.

    Raises:
        subprocess.CalledProcessError: if make exits with a non-zero status.
    """
    build_id = str(uuid.uuid4())
    builddir = path.join(cachedir, build_id)
    os.mkdir(builddir)

    cwd = path.join(path.dirname(__file__), '..')

    env = os.environ.copy()
    env['BUILDDIR'] = path.abspath(builddir)
    env['DSSS_GOLD_CODE_NBITS'] = str(nbits)
    env['DSSS_DECIMATION'] = str(decimation)
    env['DSSS_THRESHOLD_FACTOR'] = str(thf)
    env['DSSS_WAVELET_WIDTH'] = str(0.73 * decimation)
    env['DSSS_WAVELET_LUT_SIZE'] = str(10 * decimation)
    env['TRANSMISSION_SYMBOLS'] = str(symbols)

    target = path.abspath(path.join(builddir, 'tools/dsss_demod_test'))
    with open(path.join(builddir, 'make_stdout.txt'), 'w') as stdout, \
            open(path.join(builddir, 'make_stderr.txt'), 'w') as stderr:
        subprocess.run(['make', 'clean', target],
                       env=env, cwd=cwd, check=True, stdout=stdout, stderr=stderr)

    return build_id


@functools.lru_cache()
def load_noise_gen(url):
    """Resolve a noise source URL into ``(generator, generator_params)``.

    ``url`` has the form ``<schema>://<path>`` where schema is ``meas`` or
    ``synth``.  Relative paths are resolved against the directory containing
    this script.  Results are memoized per URL.

    Raises:
        ValueError: if the schema is unknown or the URL contains no ``://``.
    """
    # maxsplit=1 so that a path which itself contains '://' stays intact.
    schema, refpath = url.split('://', 1)
    if not path.isabs(refpath):
        refpath = path.abspath(path.join(path.dirname(__file__), refpath))

    if schema == 'meas':
        return mains_noise_measured, load_noise_meas_params(refpath)
    elif schema == 'synth':
        return mains_noise_synthetic, load_noise_synth_params(refpath)
    else:
        raise ValueError('Invalid schema', schema)


def sequence_matcher(test_data, decoded, max_shift=3):
    """Return the fraction of ``test_data`` symbols that ``decoded`` gets wrong.

    The two sequences are compared at every alignment ``shift`` in
    ``range(-max_shift, max_shift)`` and the alignment with the fewest
    failures wins.  For a negative shift the first ``-shift`` reference
    symbols are skipped and counted as failures; for a positive shift the
    first ``shift`` decoded symbols are dropped.  Reference symbols that have
    no decoded counterpart count as failures, surplus decoded symbols are
    ignored.

    NOTE(review): the shift range is half-open, so ``+max_shift`` itself is
    never tried.  Left unchanged to keep results comparable with cached runs.
    """
    match_result = []
    for shift in range(-max_shift, max_shift):
        failures = -shift if shift < 0 else 0  # we're skipping the first $shift symbols
        a = test_data if shift > 0 else test_data[-shift:]
        b = decoded if shift < 0 else decoded[shift:]
        for ref, found in itertools.zip_longest(a, b):
            if ref is None:  # end of signal
                break
            if ref != found:
                failures += 1
        match_result.append(failures)
    return min(match_result)/len(test_data)


ResultParams = namedtuple('ResultParams', ['nbits', 'thf', 'decimation', 'symbols', 'seed', 'amplitude', 'background'])


def run_test(seed, amplitude_spec, background, nbits, decimation, symbols, thfs, lookup_binary, cachedir):
    """Run the demodulator on one generated signal for all amplitudes and thresholds.

    Args:
        seed: seed for both the random test data and the noise generator.
        amplitude_spec: ``(base, decades, count)``; the amplitudes tested are
            ``base * 10 ** linspace(0, decades, count)``.
        background: noise source URL understood by ``load_noise_gen``.
        nbits, decimation, symbols: decoder/signal parameters.
        thfs: iterable of threshold factors to test for every amplitude.
        lookup_binary: callable ``(nbits, thf, decimation, symbols)`` returning
            the path of the matching test binary.
        cachedir: directory in which temporary waveform files are created.

    Returns:
        A list of ``(ResultParams, ser)`` tuples.  ``ser`` is the best
        ``sequence_matcher`` score over all sequences the binary reported, or
        ``None`` if it reported none or timed out.

    Raises:
        subprocess.CalledProcessError: if the test binary exits non-zero.
    """
    noise_gen, noise_params = load_noise_gen(background)

    test_data = np.random.RandomState(seed=seed).randint(0, 2 * (2**nbits), symbols)

    signal = np.repeat(dsss_modulate(test_data, nbits) * 2.0 - 1, decimation)
    # We're re-using the seed here. This is not a problem.
    noise = noise_gen(seed, len(signal), *noise_params)

    amplitudes = amplitude_spec[0] * 10 ** np.linspace(0, amplitude_spec[1], amplitude_spec[2])
    # DEBUG
    my_pid = multiprocessing.current_process().pid
    wql = len(amplitudes) * len(thfs)
    print(f'[{my_pid}] starting, got workqueue of length {wql}')
    i = 0
    # Map lsb to sign to match test program
    # test_data = (test_data>>1) * (2*(test_data&1) - 1)
    # END DEBUG

    output = []
    for amp in amplitudes:
        with tempfile.NamedTemporaryFile(dir=cachedir) as f:
            waveform = signal*amp + noise
            f.write(waveform.astype('float32').tobytes())
            f.flush()

            for thf in thfs:
                rpars = ResultParams(nbits, thf, decimation, symbols, seed, amp, background)
                cmdline = [lookup_binary(nbits, thf, decimation, symbols), f.name]
                try:
                    proc = subprocess.run(cmdline, stdout=subprocess.PIPE, encoding='utf-8', check=True,
                                          timeout=300)
                except subprocess.TimeoutExpired:
                    output.append((rpars, None))
                    # DEBUG
                    # Preserve the waveform only for runs that timed out; the
                    # temporary file itself disappears when the with-block ends.
                    fcopy = f'/tmp/test-{path.basename(f.name)}'
                    shutil.copy(f.name, fcopy)
                    print(f'[{my_pid}] ran {i}/{wql}: Timeout!\n {rpars}\n {" ".join(cmdline)}\n {fcopy}', flush=True)
                    i += 1
                    # END DEBUG
                    continue

                # Each matching line carries the decoded symbols as a
                # comma-separated list between square brackets.
                bracketed = [line.partition('[')[2].partition(']')[0]
                             for line in proc.stdout.splitlines()
                             if line.strip().startswith('data sequence received:')]
                # Skip empty elements so that an empty "[]" does not raise.
                matched = [[int(elem) for elem in seq.split(',') if elem.strip()]
                           for seq in bracketed]

                ser = min(sequence_matcher(test_data, match) for match in matched) if matched else None
                output.append((rpars, ser))
                # DEBUG
                i += 1
                # END DEBUG
    print(f'[{my_pid}] finished.')
    return output


def parallel_generator(db, table, columns, builder, param_list, desc, context=None, params_mapper=lambda *args: args,
                       disable_cache=False):
    """Yield ``(params, result)`` for every entry of ``param_list``, using a cache.

    For each parameter set the sqlite table ``table`` is searched for a row
    whose ``columns`` equal ``params_mapper(*params)``.  Cached results are
    yielded first; all other parameter sets are submitted to a process pool
    as ``builder(*params, **context)``, and their results are inserted into
    the table (as JSON, together with a timestamp) and yielded in submission
    order.

    Args:
        db: sqlite3 connection; used as a context manager for transactions.
        table: table name. The INSERT assumes its columns are the mapped
            parameters followed by a result and a timestamp column.
        columns: names of the parameter columns used for the cache lookup.
        builder: picklable callable executed in the worker processes.
        param_list: sequence of parameter tuples.
        desc: label of the progress bar.
        context: extra keyword arguments passed to every ``builder`` call.
        params_mapper: maps ``*params`` to values that sqlite can store.
        disable_cache: if true, ignore cached rows and recompute everything.
    """
    if context is None:
        context = {}

    lookup_query = f'SELECT result FROM {table} WHERE ({",".join(columns)}) = ({",".join("?"*len(columns))})'

    with multiprocessing.Pool(multiprocessing.cpu_count()) as pool:
        with db as conn:
            jobs = []
            for params in param_list:
                found_res = None
                if not disable_cache:
                    found_res = conn.execute(lookup_query, params_mapper(*params)).fetchone()

                if found_res:
                    yield params, json.loads(*found_res)
                else:
                    jobs.append((params, pool.apply_async(builder, params, context)))

        pool.close()
        print('Using', len(param_list) - len(jobs), 'cached jobs', flush=True)
        with tqdm(total=len(jobs), desc=desc) as tq:
            for i, (params, res) in enumerate(jobs):
                # DEBUG
                print('Got result', i, params, res)
                # END DEBUG
                result = res.get()
                with db as conn:
                    conn.execute(f'INSERT INTO {table} VALUES ({"?,"*len(params)}?,?)',
                                 (*params_mapper(*params), json.dumps(result), timestamp()))
                # Advance the bar only once the job has actually completed.
                tq.update(1)
                yield params, result
        pool.join()


# NOTE: the driver below intentionally stays at module level.  lookup_binary
# is passed to worker processes and therefore has to be a module-level
# function so that it can be pickled by reference.
if __name__ == '__main__':
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument('-d', '--dump', help='Write results to JSON file')
    parser.add_argument('-c', '--cachedir', default='dsss_test_cache',
                        help='Directory to store build output and data in')
    parser.add_argument('-n', '--no-cache', action='store_true', help='Disable result cache')
    parser.add_argument('-b', '--batches', type=int, default=1,
                        help='Number of batches to split the computation into')
    parser.add_argument('-i', '--index', type=int, default=0, help='Batch index to compute')
    parser.add_argument('-p', '--prepare', action='store_true', help='Prepare mode: compile runners, then exit.')
    args = parser.parse_args()

    # Fail early: a step of 0 would only crash after all decoders were built.
    if args.batches < 1:
        parser.error('--batches must be at least 1')

    DecoderParams = namedtuple('DecoderParams', ['nbits', 'thf', 'decimation', 'symbols'])
#    dec_paramses = [ DecoderParams(nbits=nbits, thf=thf, decimation=decimation, symbols=20)
#            for nbits in [5, 6]
#            for thf in [4.5, 4.0, 5.0]
#            for decimation in [10, 5, 22] ]
    dec_paramses = [DecoderParams(nbits=nbits, thf=thf, decimation=decimation, symbols=100)
                    for nbits in [5, 6]
                    for thf in [3.0, 3.5, 4.0, 4.5, 5.0, 5.5, 6.0, 6.5, 7.0, 7.5, 8.0, 8.5, 9.0, 9.5, 10.0]
                    for decimation in [1, 2, 3, 4, 5, 6, 9, 10, 11, 12, 16, 22, 30, 40, 50]]
#    dec_paramses = [ DecoderParams(nbits=nbits, thf=thf, decimation=decimation, symbols=100)
#            for nbits in [5, 6, 7, 8]
#            for thf in [1.5, 2.0, 2.5, 3.0, 3.5, 4.0, 4.5, 5.0, 5.5, 6.0, 6.5, 7.0, 7.5, 8.0, 8.5, 9.0, 9.5, 10.0]
#            for decimation in [1, 2, 3, 4, 5, 6, 9, 10, 11, 12, 16, 22, 30, 40, 50] ]

    build_cache_dir = path.join(args.cachedir, 'builds')
    data_cache_dir = path.join(args.cachedir, 'data')
    os.makedirs(build_cache_dir, exist_ok=True)
    os.makedirs(data_cache_dir, exist_ok=True)

    build_db = sqlite3.connect(path.join(args.cachedir, 'build_db.sqlite3'))
    build_db.execute('CREATE TABLE IF NOT EXISTS builds (nbits, thf, decimation, symbols, result, timestamp)')
    builds = dict(parallel_generator(build_db, table='builds', columns=['nbits', 'thf', 'decimation', 'symbols'],
                                     builder=build_test_binary, param_list=dec_paramses, desc='Building decoders',
                                     context=dict(cachedir=build_cache_dir)))
    print('Done building decoders.')
    if args.prepare:
        sys.exit(0)

    GeneratorParams = namedtuple('GeneratorParams', ['seed', 'amplitude_spec', 'background'])
    gen_params = [GeneratorParams(rep, (5e-3, 1, 5), background)
                  #GeneratorParams(rep, (0.05e-3, 3.5, 50), background)
                  for rep in range(50)
                  for background in ['meas://fmeas_export_ocxo_2day.bin', 'synth://grid_freq_psd_spl_108pt.json']]
#    gen_params = [ GeneratorParams(rep, (5e-3, 1, 5), background)
#            for rep in range(1)
#            for background in ['meas://fmeas_export_ocxo_2day.bin'] ]

    data_db = sqlite3.connect(path.join(args.cachedir, 'data_db.sqlite3'))
    data_db.execute('CREATE TABLE IF NOT EXISTS waveforms'
                    '(seed, amplitude_spec, background, nbits, decimation, symbols, thresholds, result, timestamp)')

    # All threshold factors that share (nbits, decimation, symbols) are run
    # against the same waveform, so group them into one work item.
    dec_param_groups = defaultdict(list)
    for nbits, thf, decimation, symbols in dec_paramses:
        dec_param_groups[(nbits, decimation, symbols)].append(thf)
    waveform_params = [(*gp, *dp, thfs) for gp in gen_params for dp, thfs in dec_param_groups.items()]
    print(f'Generated {len(waveform_params)} parameter sets')

    # Separate out our batch
    waveform_params = waveform_params[args.index::args.batches]

    def lookup_binary(*params):
        """Return the path of the test binary built for ``params``."""
        return path.join(build_cache_dir, builds[tuple(params)], 'tools/dsss_demod_test')

    def params_mapper(seed, amplitude_spec, background, nbits, decimation, symbols, thresholds):
        """Flatten the sequence-valued parameters into strings for sqlite."""
        amplitude_spec = ','.join(str(x) for x in amplitude_spec)
        thresholds = ','.join(str(x) for x in thresholds)
        return seed, amplitude_spec, background, nbits, decimation, symbols, thresholds

    results = []
    for _params, chunk in parallel_generator(
            data_db, 'waveforms',
            ['seed', 'amplitude_spec', 'background', 'nbits', 'decimation', 'symbols', 'thresholds'],
            params_mapper=params_mapper,
            builder=run_test,
            param_list=waveform_params,
            desc='Simulating demodulation',
            context=dict(cachedir=data_cache_dir, lookup_binary=lookup_binary),
            disable_cache=args.no_cache):
        results += chunk

    if args.dump:
        with open(args.dump, 'w') as f:
            json.dump(results, f)