From 61accdf0876ddd612d64600eba91bf42ad3d0d55 Mon Sep 17 00:00:00 2001
From: jaseg <git-bigdata-wsl-arch@jaseg.de>
Date: Mon, 17 Feb 2020 14:16:05 +0000
Subject: MLE multibit decoding works

---
 lab-windows/dsss_experiments.ipynb | 448 ++++++++++++++++++++++++++-----------
 1 file changed, 313 insertions(+), 135 deletions(-)

(limited to 'lab-windows')

diff --git a/lab-windows/dsss_experiments.ipynb b/lab-windows/dsss_experiments.ipynb
index dbc06e4..be41b59 100644
--- a/lab-windows/dsss_experiments.ipynb
+++ b/lab-windows/dsss_experiments.ipynb
@@ -72,7 +72,7 @@
     {
      "data": {
       "application/vnd.jupyter.widget-view+json": {
-       "model_id": "a4f8421f05544016854b22a49dbc3698",
+       "model_id": "01394154cc52483e9d4483f1178d94c3",
        "version_major": 2,
        "version_minor": 0
       },
@@ -93,7 +93,7 @@
     {
      "data": {
       "text/plain": [
-       "<matplotlib.image.AxesImage at 0x7f8fd0fdaac0>"
+       "<matplotlib.image.AxesImage at 0x7fcfd7369250>"
       ]
      },
      "execution_count": 5,
@@ -112,34 +112,105 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "def modulate(data, nbits=5, code=29):\n",
+    "def modulate(data, nbits=5):\n",
     "    # 0, 1 -> -1, 1\n",
-    "    mask = gold(nbits)[code]*2 - 1\n",
+    "    mask = np.array(gold(nbits))*2 - 1\n",
     "    \n",
-    "    # same here\n",
-    "    data_centered = (data*2 - 1)\n",
-    "    return (mask[:, np.newaxis] @ data_centered[np.newaxis, :] + 1).T.flatten() //2"
+    "    sel = mask[data>>1]\n",
+    "    data_lsb_centered = ((data&1)*2 - 1)\n",
+    "\n",
+    "    return (np.multiply(sel, np.tile(data_lsb_centered, (2**nbits-1, 1)).T).flatten() + 1) // 2"
    ]
   },
   {
    "cell_type": "code",
    "execution_count": 7,
    "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "(31,) (31,)\n"
+     ]
+    },
+    {
+     "data": {
+      "text/plain": [
+       "array([-1, -1, -1, -1, -1,  1,  1,  1, -1, -1,  1, -1, -1, -1,  1, -1,  1, -1,  1,  1,  1,  1, -1,  1,  1, -1,  1, -1, -1,  1,  1,  1,  1,  1,  1,  1, -1, -1, -1,  1,  1, -1,  1,  1,  1, -1,  1, -1,  1, -1, -1, -1, -1,  1, -1, -1,  1, -1,\n",
+       "        1,  1, -1, -1, -1, -1, -1, -1, -1,  1,  1, -1,  1,  1, -1, -1,  1,  1,  1,  1, -1,  1, -1, -1,  1, -1,  1, -1,  1,  1,  1, -1, -1, -1,  1,  1,  1,  1,  1,  1, -1, -1,  1, -1, -1,  1,  1, -1, -1, -1, -1,  1, -1,  1,  1, -1,  1, -1,\n",
+       "        1, -1, -1, -1,  1,  1,  1, -1,  1,  1,  1,  1,  1,  1,  1, -1, -1, -1, -1,  1, -1, -1,  1, -1, -1, -1, -1, -1,  1, -1, -1, -1,  1, -1,  1,  1,  1, -1,  1, -1, -1, -1, -1, -1, -1, -1,  1,  1,  1,  1, -1,  1,  1, -1,  1,  1,  1,  1,\n",
+       "        1, -1,  1,  1,  1, -1,  1, -1, -1, -1,  1, -1,  1,  1,  1,  1, -1,  1, -1,  1, -1,  1, -1, -1, -1, -1,  1,  1,  1,  1, -1,  1, -1,  1,  1,  1,  1, -1, -1,  1,  1,  1, -1, -1, -1, -1, -1,  1, -1,  1, -1,  1, -1,  1,  1,  1,  1, -1,\n",
+       "       -1, -1, -1,  1, -1,  1, -1, -1, -1, -1,  1,  1, -1, -1, -1,  1,  1,  1,  1, -1, -1, -1,  1,  1,  1,  1,  1, -1, -1, -1, -1, -1, -1,  1,  1, -1,  1, -1, -1,  1,  1,  1, -1,  1, -1, -1, -1, -1, -1, -1,  1,  1,  1, -1, -1, -1, -1, -1,\n",
+       "        1,  1,  1,  1,  1,  1, -1, -1,  1, -1,  1,  1, -1, -1, -1,  1, -1,  1,  1,  1,  1,  1, -1, -1,  1,  1,  1, -1,  1, -1,  1, -1, -1,  1,  1,  1, -1, -1, -1,  1, -1,  1, -1,  1, -1,  1, -1, -1,  1, -1, -1, -1, -1,  1,  1, -1, -1, -1,\n",
+       "        1, -1,  1, -1,  1,  1, -1, -1, -1,  1,  1,  1, -1,  1, -1,  1, -1,  1, -1,  1,  1, -1,  1,  1,  1, -1, -1,  1, -1,  1, -1, -1, -1, -1,  1, -1,  1, -1, -1,  1,  1,  1,  1, -1,  1,  1, -1, -1, -1,  1,  1,  1,  1, -1, -1, -1,  1,  1,\n",
+       "       -1,  1, -1,  1,  1,  1,  1, -1,  1, -1,  1,  1, -1, -1, -1, -1,  1, -1, -1,  1,  1,  1, -1, -1, -1, -1,  1,  1, -1, -1,  1, -1, -1, -1, -1,  1, -1, -1,  1,  1, -1,  1, -1, -1, -1, -1, -1,  1,  1,  1,  1, -1, -1, -1, -1,  1,  1, -1,\n",
+       "       -1,  1,  1, -1,  1,  1,  1,  1, -1,  1,  1, -1, -1,  1, -1,  1,  1,  1,  1,  1, -1, -1, -1, -1,  1,  1,  1,  1, -1, -1,  1,  1])"
+      ]
+     },
+     "execution_count": 7,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "data = np.array(list(range(16)))\n",
+    "\n",
+    "mask = np.array(gold(5))*2 - 1\n",
+    "    \n",
+    "sel = mask[data>>1]\n",
+    "data_lsb_centered = ((data&1)*2 - 1)\n",
+    "mask.shape, data.shape, sel.shape\n",
+    "\n",
+    "#fig, ax = plt.subplots()\n",
+    "#ax.plot(\n",
+    "np.multiply(sel, np.tile(data_lsb_centered, (2**5-1, 1)).T).flatten()"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 8,
+   "metadata": {},
    "outputs": [],
    "source": [
-    "def correlate(sequence, nbits=5, code=29, decimation=1, mask_filter=lambda x: x):\n",
-    "    # 0, 1 -> -1, 1\n",
-    "    mask = mask_filter(np.repeat(gold(nbits)[code]*2 -1, decimation))\n",
-    "    # center\n",
+    "def correlate(sequence, nbits=5, decimation=1, mask_filter=lambda x: x):\n",
+    "    mask = np.tile(np.array(gold(nbits))[:,:,np.newaxis]*2 - 1, (1, 1, decimation)).reshape((2**nbits + 1, (2**nbits-1) * decimation))\n",
+    "\n",
     "    sequence -= np.mean(sequence)\n",
-    "    return np.correlate(sequence, mask, mode='full')"
+    "    \n",
+    "    return np.array([np.correlate(sequence, row, mode='full') for row in mask])"
    ]
   },
   {
    "cell_type": "code",
-   "execution_count": 8,
+   "execution_count": 9,
    "metadata": {},
    "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "(31,) (31,)\n",
+      "(31,) (31,)\n",
+      "shapes (1240,) (1240,)\n",
+      "(31,) (31,)\n",
+      "mask (33, 310)\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "182fd5ac86e74ad299a67e5f1d0b2b2b",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
     {
      "name": "stdout",
      "output_type": "stream",
@@ -147,10 +218,46 @@
       "(31,) (31,)\n"
      ]
     },
+    {
+     "data": {
+      "text/plain": [
+       "<matplotlib.image.AxesImage at 0x7fcfd6c90070>"
+      ]
+     },
+     "execution_count": 9,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "nbits = 5\n",
+    "decimation = 10\n",
+    "\n",
+    "foo = np.repeat(modulate(np.array(list(range(4))), nbits).astype(float), decimation)\n",
+    "bar = np.repeat(modulate(np.array(list(range(4))), nbits) * 2.0 - 1, decimation) * 1e-3\n",
+    "print('shapes', foo.shape, bar.shape)\n",
+    "\n",
+    "mask = np.tile(np.array(gold(nbits))[:,:,np.newaxis]*2 - 1, (1, 1, decimation)).reshape((2**nbits + 1, (2**nbits-1) * decimation))\n",
+    "print('mask', mask.shape)\n",
+    "\n",
+    "fig, (ax1, ax2) = plt.subplots(2, figsize=(16, 5))\n",
+    "fig.tight_layout()\n",
+    "corr_m = np.array([np.correlate(foo, row, mode='full') for row in mask])\n",
+    "#corr_m = np.array([row for row in mask])\n",
+    "ax1.matshow(corr_m, aspect='auto')\n",
+    "#ax.matshow(foo.reshape(32, 31)[::2,:])\n",
+    "ax2.matshow(correlate(bar, decimation=decimation), aspect='auto')"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 10,
+   "metadata": {},
+   "outputs": [
     {
      "data": {
       "application/vnd.jupyter.widget-view+json": {
-       "model_id": "c99513c0fb7f4b138367186127e379cf",
+       "model_id": "4cb2661eebb84478b06d285166ec13bc",
        "version_major": 2,
        "version_minor": 0
       },
@@ -165,29 +272,43 @@
      "name": "stdout",
      "output_type": "stream",
      "text": [
+      "(31,) (31,)\n",
       "(31,) (31,)\n"
      ]
     },
     {
      "data": {
       "text/plain": [
-       "[<matplotlib.lines.Line2D at 0x7f8fce86df40>]"
+       "<matplotlib.image.AxesImage at 0x7fcfd6c6cfa0>"
       ]
      },
-     "execution_count": 8,
+     "execution_count": 10,
      "metadata": {},
      "output_type": "execute_result"
     }
    ],
    "source": [
-    "foo = modulate(np.array([0, 1, 0, 0, 1, 1, 1, 0])).astype(float)\n",
-    "fig, ax = plt.subplots()\n",
-    "ax.plot(correlate(foo))"
+    "decimation = 10\n",
+    "\n",
+    "fig, (ax1, ax2) = plt.subplots(2, figsize=(12, 5))\n",
+    "fig.tight_layout()\n",
+    "\n",
+    "#mask = np.tile(np.array(gold(nbits))[:,:,np.newaxis]*2 - 1, (1, 1, decimation)).reshape((2**nbits + 1, (2**nbits-1) * decimation))\n",
+    "#mask_stretched = np.tile(np.array(gold(nbits))[:,:,np.newaxis]*2 - 1, (1, 1, 1)).reshape((2**nbits + 1, (2**nbits-1) * 1))\n",
+    "\n",
+    "#ax1.matshow(mask)\n",
+    "#ax2.matshow(mask_stretched, aspect='auto')\n",
+    "\n",
+    "foo = np.repeat(modulate(np.array(list(range(4)))).astype(float), 1).reshape((4, 31))\n",
+    "foo_stretched = np.repeat(modulate(np.array(list(range(4)))).astype(float), 10).reshape(4, 310)\n",
+    "\n",
+    "ax1.matshow(foo)\n",
+    "ax2.matshow(foo_stretched, aspect='auto')"
    ]
   },
   {
    "cell_type": "code",
-   "execution_count": 9,
+   "execution_count": 11,
    "metadata": {},
    "outputs": [
     {
@@ -200,7 +321,7 @@
     {
      "data": {
       "application/vnd.jupyter.widget-view+json": {
-       "model_id": "93658d824ced42e5b1107501398234b4",
+       "model_id": "a2e2f747193b478bbfa792a0995ad4ed",
        "version_major": 2,
        "version_minor": 0
       },
@@ -222,10 +343,10 @@
     {
      "data": {
       "text/plain": [
-       "(2.0, 0.944245383185962)"
+       "(2.0, 1.0234353995297893)"
       ]
      },
-     "execution_count": 9,
+     "execution_count": 11,
      "metadata": {},
      "output_type": "execute_result"
     }
@@ -269,7 +390,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 10,
+   "execution_count": 12,
    "metadata": {},
    "outputs": [
     {
@@ -289,15 +410,13 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 76,
+   "execution_count": 96,
    "metadata": {},
    "outputs": [
     {
      "name": "stdout",
      "output_type": "stream",
      "text": [
-      "(31,) (31,)\n",
-      "(31,) (31,)\n",
       "(31,) (31,)\n",
       "(31,) (31,)\n"
      ]
@@ -306,14 +425,14 @@
      "name": "stderr",
      "output_type": "stream",
      "text": [
-      "<ipython-input-76-c15a6a1f5988>:27: RuntimeWarning: More than 20 figures have been opened. Figures created through the pyplot interface (`matplotlib.pyplot.figure`) are retained until explicitly closed and may consume too much memory. (To control this warning, see the rcParam `figure.max_open_warning`).\n",
-      "  fig, ((ax1, ax3, ax5), (ax2, ax4, ax6)) = plt.subplots(2, 3, figsize=(16, 9))\n"
+      "<ipython-input-96-b3aae757ccad>:33: RuntimeWarning: More than 20 figures have been opened. Figures created through the pyplot interface (`matplotlib.pyplot.figure`) are retained until explicitly closed and may consume too much memory. (To control this warning, see the rcParam `figure.max_open_warning`).\n",
+      "  fig, ((ax1, ax3), (ax2, ax4)) = plt.subplots(2, 2, figsize=(16, 9))\n"
      ]
     },
     {
      "data": {
       "application/vnd.jupyter.widget-view+json": {
-       "model_id": "c24adc84c7294295a47a58aa7d5914f9",
+       "model_id": "246c19a3c1424e7eb0b675ff060ea5b3",
        "version_major": 2,
        "version_minor": 0
       },
@@ -327,24 +446,30 @@
     {
      "data": {
       "text/plain": [
-       "(0.001, 0.010294564)"
+       "(0.002, 0.013899708)"
       ]
      },
-     "execution_count": 76,
+     "execution_count": 96,
      "metadata": {},
      "output_type": "execute_result"
     }
    ],
    "source": [
     "decimation = 10\n",
-    "signal_amplitude = 1.0e-3\n",
+    "signal_amplitude = 2.0e-3\n",
     "nbits = 5\n",
     "\n",
     "#test_data = np.random.randint(0, 2, 100)\n",
-    "test_data = np.array([0, 1, 0, 0, 1, 1, 1, 0])\n",
+    "#test_data = np.array([0, 1, 0, 0, 1, 1, 1, 0])\n",
+    "#test_data = np.random.RandomState(seed=0).randint(0, 2 * (2**nbits), 64)\n",
+    "#test_data = np.random.RandomState(seed=0).randint(0, 8, 64)\n",
+    "#test_data = np.array(list(range(8)) * 8)\n",
+    "#test_data = np.array([0, 1] * 32)\n",
+    "test_data = np.array(list(range(64)))\n",
     "\n",
     "foo = np.repeat(modulate(test_data, nbits) * 2.0 - 1, decimation) * signal_amplitude\n",
     "noise = np.resize(mains_noise, len(foo))\n",
+    "#noise = 0\n",
     "\n",
     "sosh = sig.butter(3, 0.01, btype='highpass', output='sos', fs=decimation)\n",
     "sosl = sig.butter(3, 0.8, btype='lowpass', output='sos', fs=decimation)\n",
@@ -352,9 +477,9 @@
     "filtered = sig.sosfilt(sosh, foo + noise)\n",
     "\n",
     "cor1 = correlate(foo + noise, nbits=nbits, decimation=decimation)\n",
-    "cor2 = correlate(filtered, nbits=nbits, decimation=decimation)\n",
+    "#cor2 = correlate(filtered, nbits=nbits, decimation=decimation)\n",
     "\n",
-    "cor2_pe = correlate(filtered, nbits=nbits, decimation=decimation, mask_filter=lambda mask: sig.sosfilt(sosh, sig.sosfiltfilt(sosl, mask)))\n",
+    "#cor2_pe = correlate(filtered, nbits=nbits, decimation=decimation, mask_filter=lambda mask: sig.sosfilt(sosh, sig.sosfiltfilt(sosl, mask)))\n",
     "\n",
     "sosn = sig.butter(12, 4, btype='highpass', output='sos', fs=decimation)\n",
     "#cor1_flt = sig.sosfilt(sosn, cor1)\n",
@@ -362,7 +487,7 @@
     "#cor1_flt = cor1[1:] - cor1[:-1]\n",
     "#cor2_flt = cor2[1:] - cor2[:-1]\n",
     "\n",
-    "fig, ((ax1, ax3, ax5), (ax2, ax4, ax6)) = plt.subplots(2, 3, figsize=(16, 9))\n",
+    "fig, ((ax1, ax3), (ax2, ax4)) = plt.subplots(2, 2, figsize=(16, 9))\n",
     "fig.tight_layout()\n",
     "\n",
     "ax1.plot(foo + noise)\n",
@@ -373,13 +498,14 @@
     "ax2.plot(foo)\n",
     "ax2.set_title('filtered')\n",
     "\n",
-    "ax3.plot(cor1)\n",
+    "ax3.plot(cor1.T)\n",
     "ax3.set_title('corr raw')\n",
     "ax3.grid()\n",
     "\n",
-    "ax4.plot(cor2)\n",
-    "ax4.set_title('corr filtered')\n",
-    "ax4.grid()\n",
+    "#ax4.plot(cor2[:4].T)\n",
+    "#ax4.set_title('corr filtered')\n",
+    "#ax4.grid()\n",
+    "ax4.matshow(cor1, aspect='auto')\n",
     "\n",
     "#ax5.plot(cor1_flt)\n",
     "#ax5.set_title('corr raw (highpass)')\n",
@@ -389,9 +515,9 @@
     "#ax6.set_title('corr filtered (highpass)')\n",
     "#ax6.grid()\n",
     "\n",
-    "ax6.plot(cor2_pe)\n",
-    "ax6.set_title('corr filtered w/ mask preemphasis')\n",
-    "ax6.grid()\n",
+    "#ax6.plot(cor2_pe[:4].T)\n",
+    "#ax6.set_title('corr filtered w/ mask preemphasis')\n",
+    "#ax6.grid()\n",
     "\n",
     "rms = lambda x: np.sqrt(np.mean(np.square(x)))\n",
     "rms(foo), rms(noise)"
@@ -399,13 +525,13 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 21,
+   "execution_count": 14,
    "metadata": {},
    "outputs": [
     {
      "data": {
       "application/vnd.jupyter.widget-view+json": {
-       "model_id": "1c8e27744cd0482782fa0d65ed550ba6",
+       "model_id": "c08b2a1dbdef429eb22b598bd3dc0146",
        "version_major": 2,
        "version_minor": 0
       },
@@ -426,10 +552,10 @@
     {
      "data": {
       "text/plain": [
-       "[<matplotlib.lines.Line2D at 0x7f8fc2cab1f0>]"
+       "[<matplotlib.lines.Line2D at 0x7fcfd1635700>]"
       ]
      },
-     "execution_count": 21,
+     "execution_count": 14,
      "metadata": {},
      "output_type": "execute_result"
     }
@@ -449,13 +575,13 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 22,
+   "execution_count": 15,
    "metadata": {},
    "outputs": [
     {
      "data": {
       "application/vnd.jupyter.widget-view+json": {
-       "model_id": "483d7acbb9604c9c93533d342d6068ad",
+       "model_id": "20117316e02548a99386c39e45e71ef1",
        "version_major": 2,
        "version_minor": 0
       },
@@ -474,14 +600,15 @@
      ]
     },
     {
-     "data": {
-      "text/plain": [
-       "[<matplotlib.lines.Line2D at 0x7f8fc2679310>]"
-      ]
-     },
-     "execution_count": 22,
-     "metadata": {},
-     "output_type": "execute_result"
+     "ename": "NameError",
+     "evalue": "name 'cor2_pe' is not defined",
+     "output_type": "error",
+     "traceback": [
+      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
+      "\u001b[0;31mNameError\u001b[0m                                 Traceback (most recent call last)",
+      "\u001b[0;32m<ipython-input-15-f158dfc14cca>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m      8\u001b[0m \u001b[0msosh\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0msig\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbutter\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m3\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m0.1\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mbtype\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m'highpass'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0moutput\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m'sos'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfs\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mdecimation\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m      9\u001b[0m \u001b[0msosl\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0msig\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbutter\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m3\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m0.8\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mbtype\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m'lowpass'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0moutput\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m'sos'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfs\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mdecimation\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 10\u001b[0;31m \u001b[0mcor2_pe_flt\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0msig\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msosfilt\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0msosh\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcor2_pe\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     11\u001b[0m \u001b[0mcor2_pe_flt2\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0msig\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msosfilt\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0msosh\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0msig\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msosfiltfilt\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0msosl\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcor2_pe\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     12\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
+      "\u001b[0;31mNameError\u001b[0m: name 'cor2_pe' is not defined"
+     ]
     }
    ],
    "source": [
@@ -506,21 +633,21 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 77,
+   "execution_count": 57,
    "metadata": {},
    "outputs": [
     {
      "name": "stderr",
      "output_type": "stream",
      "text": [
-      "<ipython-input-77-478546893e6f>:1: RuntimeWarning: More than 20 figures have been opened. Figures created through the pyplot interface (`matplotlib.pyplot.figure`) are retained until explicitly closed and may consume too much memory. (To control this warning, see the rcParam `figure.max_open_warning`).\n",
+      "<ipython-input-57-3e7dc7c98d30>:1: RuntimeWarning: More than 20 figures have been opened. Figures created through the pyplot interface (`matplotlib.pyplot.figure`) are retained until explicitly closed and may consume too much memory. (To control this warning, see the rcParam `figure.max_open_warning`).\n",
       "  fig, ax = plt.subplots()\n"
      ]
     },
     {
      "data": {
       "application/vnd.jupyter.widget-view+json": {
-       "model_id": "298c99e458364426829979eeaf01af6e",
+       "model_id": "25fe274dea83415491fd7f86d38188d7",
        "version_major": 2,
        "version_minor": 0
       },
@@ -534,38 +661,38 @@
     {
      "data": {
       "text/plain": [
-       "[<matplotlib.lines.Line2D at 0x7f8f9968fa00>]"
+       "[<matplotlib.lines.Line2D at 0x7fcf84cb4a60>]"
       ]
      },
-     "execution_count": 77,
+     "execution_count": 57,
      "metadata": {},
      "output_type": "execute_result"
     }
    ],
    "source": [
     "fig, ax = plt.subplots()\n",
-    "nonlinear_distance = lambda x: 100**(2*np.abs(0.5-x%1)) / (np.abs(x)+7)**2\n",
+    "nonlinear_distance = lambda x: 100**(2*np.abs(0.5-x%1)) / (np.abs(x)+3)**2\n",
     "x = np.linspace(-1.5, 5.5, 10000)\n",
     "ax.plot(x, nonlinear_distance(x))"
    ]
   },
   {
    "cell_type": "code",
-   "execution_count": 78,
+   "execution_count": 97,
    "metadata": {},
    "outputs": [
     {
      "name": "stderr",
      "output_type": "stream",
      "text": [
-      "<ipython-input-78-ad8374b3e684>:9: RuntimeWarning: More than 20 figures have been opened. Figures created through the pyplot interface (`matplotlib.pyplot.figure`) are retained until explicitly closed and may consume too much memory. (To control this warning, see the rcParam `figure.max_open_warning`).\n",
-      "  fig, (ax1, ax2, ax3) = plt.subplots(3, figsize=(12, 12))\n"
+      "<ipython-input-97-2d2c2f814215>:11: RuntimeWarning: More than 20 figures have been opened. Figures created through the pyplot interface (`matplotlib.pyplot.figure`) are retained until explicitly closed and may consume too much memory. (To control this warning, see the rcParam `figure.max_open_warning`).\n",
+      "  fig, (ax1, ax3) = plt.subplots(2, figsize=(12, 5))\n"
      ]
     },
     {
      "data": {
       "application/vnd.jupyter.widget-view+json": {
-       "model_id": "126c195977aa4639a8026e0a857a8ab8",
+       "model_id": "b083c661b5b441d6b7fc45201faa0576",
        "version_major": 2,
        "version_minor": 0
       },
@@ -577,127 +704,178 @@
      "output_type": "display_data"
     },
     {
-     "data": {
-      "application/vnd.jupyter.widget-view+json": {
-       "model_id": "5a40fd3c41814b5f99e9f452c8923db4",
-       "version_major": 2,
-       "version_minor": 0
-      },
-      "text/plain": [
-       "interactive(children=(FloatSlider(value=10.0, description='w', max=30.0, min=-10.0), Output()), _dom_classes=(…"
-      ]
-     },
-     "metadata": {},
-     "output_type": "display_data"
-    },
-    {
-     "data": {
-      "text/plain": [
-       "[<matplotlib.lines.Line2D at 0x7f8fa59086d0>]"
-      ]
-     },
-     "execution_count": 78,
-     "metadata": {},
-     "output_type": "execute_result"
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "cor_an (33, 20149)\n",
+      "cwt_res (33, 20149)\n",
+      "th (33, 20149)\n",
+      "[((33,), (33,)), ((33,), (33,)), ((33,), (33,)), ((33,), (33,)), ((33,), (33,))]\n",
+      "peaks: 180\n",
+      "avg_peak 1.058897833824206\n",
+      "skipped 3 symbols at 5889.0\n",
+      "skipped 2 symbols at 8369.0\n",
+      "skipped 2 symbols at 14568.5\n",
+      "skipped 2 symbols at 16739.0\n",
+      "decoding [ref|dec]:\n",
+      " -1| -1 ✔      1|  1 ✔      2|  2 ✔      3|  3 ✔      4|  4 ✔      5|  5 ✔      6|  6 ✔      7|  7 ✔    \n",
+      "  8|  8 ✔      9|  9 ✔     10| 10 ✔     11| 11 ✔     12| 12 ✔     13| 13 ✔     14| 14 ✔     15| 15 ✔    \n",
+      " 16| -1 ✘     17| -1 ✘     18| 18 ✔     19| 19 ✔     20| 20 ✔     21| 21 ✔     22| 22 ✔     23| 23 ✔    \n",
+      " 24| 24 ✔     25| -1 ✘     26| 26 ✔     27| 27 ✔     28| 28 ✔     29| 29 ✔     30| 30 ✔     31| 31 ✔    \n",
+      " 32| 32 ✔     33| 33 ✔     34| 34 ✔     35| 35 ✔     36| 36 ✔     37| 37 ✔     38| 38 ✔     39| 39 ✔    \n",
+      " 40| 40 ✔     41| 41 ✔     42| 42 ✔     43| 43 ✔     44| 44 ✔     45| -1 ✘     46| 46 ✔     47| 47 ✔    \n",
+      " 48| 48 ✔     49| 49 ✔     50| 50 ✔     51| 51 ✔     52| -1 ✘     53| 53 ✔     54| 54 ✔     55| 55 ✔    \n",
+      " 56| 56 ✔     57| 57 ✔     58| 58 ✔     59| 59 ✔     60| 60 ✔     61| 61 ✔     62| 62 ✔     63| 56 ✘    \n",
+      "Symbol error rate r=0.09375\n"
+     ]
     }
    ],
    "source": [
-    "threshold_factor = 2.0\n",
+    "threshold_factor = 4.0\n",
     "power_avg_width = 1024\n",
+    "max_lookahead = 6.5\n",
     "\n",
     "bit_period = (2**nbits) * decimation\n",
     "peak_group_threshold = 0.1 * bit_period\n",
     "\n",
     "cor_an = cor1\n",
     "\n",
-    "fig, (ax1, ax2, ax3) = plt.subplots(3, figsize=(12, 12))\n",
+    "#fig, (ax1, ax2, ax3) = plt.subplots(3, figsize=(12, 12))\n",
+    "fig, (ax1, ax3) = plt.subplots(2, figsize=(12, 5))\n",
     "fig.tight_layout()\n",
     "\n",
-    "ax1.matshow(sig.cwt(cor_an, sig.ricker, np.arange(1, 31)), aspect='auto')\n",
+    "#ax1.matshow(sig.cwt(cor_an, sig.ricker, np.arange(1, 31)), aspect='auto')\n",
     "\n",
-    "for i in np.linspace(1, 10, 19):\n",
-    "    offx = 5*i\n",
-    "    ax2.plot(sig.cwt(cor_an, sig.ricker, [i]).flatten() + offx, color='red')\n",
-    "\n",
-    "    ax2.text(-50, offx, f'{i:.1f}',\n",
-    "        horizontalalignment='right',\n",
-    "        verticalalignment='center',\n",
-    "        color='black')\n",
-    "ax2.grid()\n",
+    "#for i in np.linspace(1, 10, 19):\n",
+    "#    offx = 5*i\n",
+    "#    ax2.plot(sig.cwt(cor_an, sig.ricker, [i]).flatten() + offx, color='red')\n",
+    "#\n",
+    "#    ax2.text(-50, offx, f'{i:.1f}',\n",
+    "#        horizontalalignment='right',\n",
+    "#        verticalalignment='center',\n",
+    "#        color='black')\n",
+    "#ax2.grid()\n",
     "\n",
     "ax3.grid()\n",
+    "print('cor_an', cor_an.shape)\n",
+    "\n",
+    "cwt_res = np.array([ sig.cwt(row, sig.ricker, [0.73 * decimation]).flatten() for row in cor_an ])\n",
+    "ax3.plot(cwt_res.T)\n",
+    "#def update(w = 1.0 * decimation):\n",
+    "#    line.set_ydata(sig.cwt(cor_an, sig.ricker, [w]).flatten())\n",
+    "#    fig.canvas.draw_idle()\n",
+    "#ipywidgets.interact(update)\n",
     "\n",
-    "cwt_res = sig.cwt(cor_an, sig.ricker, [7.3]).flatten()\n",
-    "line, = ax3.plot(cwt_res)\n",
-    "def update(w=10.0):\n",
-    "    line.set_ydata(sig.cwt(cor_an, sig.ricker, [w]).flatten())\n",
-    "    fig.canvas.draw_idle()\n",
-    "ipywidgets.interact(update)\n",
+    "print('cwt_res', cwt_res.shape)\n",
+    "th = np.array([ np.convolve(np.abs(row), np.ones((power_avg_width,))/power_avg_width, mode='same') for row in cwt_res ])\n",
+    "ax1.plot(th.T)\n",
+    "print('th', th.shape)\n",
     "\n",
+    "def compare_th(elem):\n",
+    "    idx, (th, val) = elem\n",
+    "    #print('compare_th:', th.shape, val.shape)\n",
+    "    return np.any(np.abs(val) > th*threshold_factor)\n",
     "\n",
-    "th = np.convolve(np.abs(cwt_res), np.ones((power_avg_width,))/power_avg_width, mode='same')\n",
-    "peaks = [ list(group) for val, group in itertools.groupby(enumerate(zip(th, cwt_res)), lambda elem: abs(elem[1][1]) > elem[1][0]*threshold_factor) if val ]\n",
+    "print([ (a.shape, b.shape) for a, b in zip(th.T, cwt_res.T) ][:5])\n",
+    "\n",
+    "peaks = [ list(group) for val, group in itertools.groupby(enumerate(zip(th.T, cwt_res.T)), compare_th) if val ]\n",
+    "print('peaks:', len(peaks))\n",
     "peak_group = []\n",
     "for group in peaks:\n",
     "    pos = np.mean([idx for idx, _val in group])\n",
-    "    pol = np.mean([val for _idx, (_th, val) in group])\n",
+    "    pol = np.mean([max(val.min(), val.max(), key=abs) for _idx, (_th, val) in group])\n",
+    "    pol_idx = np.argmax(np.bincount([ np.argmax(np.abs(val)) for _idx, (_th, val) in group ]))\n",
+    "    #print(f'group', pos, pol, pol_idx)\n",
+    "    #for pol, (_idx, (_th, val)) in zip([max(val.min(), val.max(), key=abs) for _idx, (_th, val) in group], group):\n",
+    "    #    print('    ', pol, val)\n",
+    "    ax3.axvline(pos, color='cyan', alpha=0.3)\n",
     "    \n",
     "    if not peak_group or pos - peak_group[-1][1] > peak_group_threshold:\n",
     "        if peak_group:\n",
     "            peak_pos = peak_group[-1][3]\n",
-    "            ax3.axvline(peak_pos, color='red', alpha=0.3)\n",
+    "            ax3.axvline(peak_pos, color='red', alpha=0.6)\n",
     "            #ax3.text(peak_pos-20, 2.0, f'{0 if pol < 0 else 1}', horizontalalignment='right', verticalalignment='center', color='black')\n",
     "            \n",
-    "        peak_group.append((pos, pos, pol, pos))\n",
+    "        peak_group.append((pos, pos, pol, pos, pol_idx))\n",
     "        #ax3.axvline(pos, color='cyan', alpha=0.5)\n",
     "        \n",
     "    else:\n",
-    "        group_start, last_pos, last_pol, peak_pos = peak_group[-1]\n",
+    "        group_start, last_pos, last_pol, peak_pos, last_pol_idx = peak_group[-1]\n",
     "    \n",
     "        if abs(pol) > abs(last_pol):\n",
     "            #ax3.axvline(pos, color='magenta', alpha=0.5)\n",
-    "            peak_group[-1] = (group_start, pos, pol, pos)\n",
+    "            peak_group[-1] = (group_start, pos, pol, pos, pol_idx)\n",
     "        else:\n",
     "            #ax3.axvline(pos, color='blue', alpha=0.5)\n",
-    "            peak_group[-1] = (group_start, pos, last_pol, peak_pos)\n",
+    "            peak_group[-1] = (group_start, pos, last_pol, peak_pos, last_pol_idx)\n",
+    "\n",
+    "avg_peak = np.mean(np.abs(np.array([last_pol for _1, _2, last_pol, _3, _4 in peak_group])))\n",
+    "print('avg_peak', avg_peak)\n",
     "\n",
+    "noprint = lambda *args, **kwargs: None\n",
     "def mle_decode(peak_groups, print=print):\n",
-    "    peak_groups = [ (pos, pol) for _1, _2, pol, pos in peak_groups ]\n",
-    "    candidates = [ [(pos, pol)] for pos, pol in peak_groups if pos < bit_period*1.5 ]\n",
+    "    peak_groups = [ (pos, pol, idx) for _1, _2, pol, pos, idx in peak_groups ]\n",
+    "    candidates = [ (0, [(pos, pol, idx)]) for pos, pol, idx in peak_groups if pos < bit_period*2.5 ]\n",
     "    \n",
     "    while candidates:\n",
     "        chain_candidates = []\n",
-    "        for chain in candidates:\n",
-    "            pos, ampl = chain[-1]\n",
-    "            score_fun = lambda pos, npos, npol: abs(npol)/2 + nonlinear_distance((npos-pos)/bit_period)\n",
-    "            next_candidates = sorted([ (score_fun(pos, npos, npol), npos, npol) for npos, npol in peak_groups if pos < npos < pos + bit_period*3.5 ], reverse=True)\n",
+    "        for chain_score, chain in candidates:\n",
+    "            pos, ampl, _idx = chain[-1]\n",
+    "            score_fun = lambda pos, npos, npol: abs(npol)/avg_peak + nonlinear_distance((npos-pos)/bit_period)\n",
+    "            next_candidates = sorted([ (score_fun(pos, npos, npol), npos, npol, nidx) for npos, npol, nidx in peak_groups if pos < npos < pos + bit_period*max_lookahead ], reverse=True)\n",
     "            \n",
     "            print(f'    candidates for {pos}, {ampl}:')\n",
-    "            for score, npos, npol in next_candidates:\n",
-    "                print(f'        {score:.4f} {npos:.2f} {npol:.2f}')\n",
+    "            for score, npos, npol, nidx in next_candidates:\n",
+    "                print(f'        {score:.4f} {npos:.2f} {npol:.2f} {nidx:.2f}')\n",
     "            \n",
-    "            if len(cor_an) - pos < 1.5*bit_period or not next_candidates:\n",
-    "                score = sum(score_fun(opos, npos, npol) for (opos, _opol), (npos, npol) in zip(chain[:-1], chain[1:])) / (len(chain)-1)\n",
+    "            nch, cor_len = cor_an.shape\n",
+    "            if cor_len - pos < 1.5*bit_period or not next_candidates:\n",
+    "                score = sum(score_fun(opos, npos, npol) for (opos, _opol, _oidx), (npos, npol, _nidx) in zip(chain[:-1], chain[1:])) / len(chain)\n",
     "                yield score, chain\n",
     "            \n",
     "            else:\n",
-    "                for score, npos, npol in next_candidates[:3]:\n",
+    "                print('extending')\n",
+    "                for score, npos, npol, nidx in next_candidates[:3]:\n",
     "                    if score > 0.5:\n",
-    "                        chain_candidates.append((score, chain + [(npos, npol)]))\n",
+    "                        new_chain_score = chain_score * 0.9 + score * 0.1\n",
+    "                        chain_candidates.append((new_chain_score, chain + [(npos, npol, nidx)]))\n",
     "        print('chain candidates:')\n",
     "        for score, chain in sorted(chain_candidates, reverse=True):\n",
-    "            print('    ', [(score, [(f'{pos:.2f}', f'{pol:.2f}') for pos, pol in chain])])\n",
-    "        candidates = [ chain for _score, chain in sorted(chain_candidates, reverse=True)[:10] ]\n",
+    "            print('    ', [(score, [(f'{pos:.2f}', f'{pol:.2f}') for pos, pol, _idx in chain])])\n",
+    "        candidates = [ (chain_score, chain) for chain_score, chain in sorted(chain_candidates, reverse=True)[:10] ]\n",
     "\n",
-    "res = sorted(mle_decode(peak_group, print=lambda *args, **kwargs: None), reverse=True)\n",
+    "res = sorted(mle_decode(peak_group, print=noprint), reverse=True)\n",
     "#for i, (score, chain) in enumerate(res):\n",
     "#    print(f'Chain {i}@{score:.4f}: {chain}')\n",
     "(_score, chain), *_ = res\n",
-    "for pos, pol in chain:\n",
-    "    ax3.axvline(pos, color='blue', alpha=0.5)\n",
-    "    ax3.text(pos-20, 0.0, f'{0 if pol < 0 else 1}', horizontalalignment='right', verticalalignment='center', color='black')\n",
     "\n",
-    "ax3.plot(th)"
+    "def viz(chain):\n",
+    "    last_pos = None\n",
+    "    for pos, pol, nidx in chain:\n",
+    "        if last_pos:\n",
+    "            delta = int(round((pos - last_pos) / bit_period))\n",
+    "            if delta > 1:\n",
+    "                print(f'skipped {delta} symbols at {pos}')\n",
+    "            for i in range(delta-1):\n",
+    "                yield None\n",
+    "        ax3.axvline(pos, color='blue', alpha=0.5)\n",
+    "        decoded = nidx*2 + (0 if pol < 0 else 1)\n",
+    "        yield decoded\n",
+    "        ax3.text(pos-20, 0.0, f'{decoded}', horizontalalignment='right', verticalalignment='center', color='black')\n",
+    "\n",
+    "        last_pos = pos\n",
+    "\n",
+    "decoded = list(viz(chain))\n",
+    "print('decoding [ref|dec]:')\n",
+    "failures = 0\n",
+    "for i, (ref, found) in enumerate(itertools.zip_longest(test_data, decoded)):\n",
+    "    print(f'{ref or -1:>3d}|{found or -1:>3d} {\"✔\" if ref==found else \"✘\"}', end='    ')\n",
+    "    if ref != found:\n",
+    "        failures += 1\n",
+    "    if i%8 == 7:\n",
+    "        print()\n",
+    "print(f'Symbol error rate r={failures/len(test_data)}')\n",
+    "#ax3.plot(th)"
    ]
   },
   {
-- 
cgit