diff options
Diffstat (limited to 'hexnoise.py')
-rwxr-xr-x | hexnoise.py | 88 |
1 files changed, 63 insertions, 25 deletions
diff --git a/hexnoise.py b/hexnoise.py index 31d7695..df71d93 100755 --- a/hexnoise.py +++ b/hexnoise.py @@ -38,8 +38,8 @@ def send_packet(ser, data, width=16): def receive_packet(ser, width=16): packet = ser.read_until(b'\0') data = cobs.decode(packet[:-1]) - #print(f'\033[93mReceived {len(data)} bytes\033[0m') - #hexdump(print, data, width) + print(f'\033[93mReceived {len(data)} bytes\033[0m') + hexdump(print, data, width) return data if __name__ == '__main__': @@ -54,6 +54,23 @@ if __name__ == '__main__': ser = serial.Serial(args.serial, args.baudrate) + import uinput + ALL_KEYS = [ v for k, v in uinput.ev.__dict__.items() if k.startswith('KEY_') ] + MODIFIERS = [ + uinput.ev.KEY_LEFTCTRL, + uinput.ev.KEY_LEFTSHIFT, + uinput.ev.KEY_LEFTALT, + uinput.ev.KEY_LEFTMETA, + uinput.ev.KEY_RIGHTCTRL, + uinput.ev.KEY_RIGHTSHIFT, + uinput.ev.KEY_RIGHTALT, + uinput.ev.KEY_RIGHTMETA, + ] + map_modifiers = lambda x: [ mod for i, mod in enumerate(MODIFIERS) if x & (1<<i) ] + import keymap + map_regular = { v: getattr(uinput.ev, k) for k, v in keymap.__dict__.items() if k.startswith('KEY_') } + map_regulars = lambda keycodes: [ map_regular[kc] for kc in keycodes if kc != 0 and kc in map_regular ] + from noise.connection import NoiseConnection, Keypair from noise.exceptions import NoiseInvalidMessage @@ -81,29 +98,50 @@ if __name__ == '__main__': print('Handshake finished, handshake hash:') hexdump(print, proto.get_handshake_hash(), args.width) - def noise_rx(received): + old_kcs = set() + def noise_rx(received, ui): + global old_kcs + data = proto.decrypt(received) - #print('Decrypted data:') - #hexdump(print, data, args.width) + print('Decrypted data:') + hexdump(print, data, args.width) - while True: - try: - received = receive_packet(ser, args.width) + rtype, rlen, *report = data + if rtype != 1 or rlen != 8: + return + + modbyte, _reserved, *keycodes = report + keys = map_modifiers(modbyte) + map_regulars(keycodes) + print('Emitting:', keys) + keyset = set(keys) + + for key in keyset - old_kcs: + ui.emit(key, 1, syn=False) + for key in old_kcs - keyset: + ui.emit(key, 0, syn=False) + ui.syn() + + old_kcs = keyset + + with uinput.Device(ALL_KEYS) as ui: + while True: try: - noise_rx(received) - except NoiseInvalidMessage as e: - orig_n = proto.noise_protocol.cipher_state_decrypt.n - print('Invalid noise message', e) - for n in [orig_n+1, orig_n+2, orig_n+3]: - try: - proto.noise_protocol.cipher_state_decrypt.n = n - noise_rx(received) - print(f' Recovered. n={n}') - break - except NoiseInvalidMessage as e: - pass - else: - print(' Unrecoverable.') - proto.noise_protocol.cipher_state_decrypt.n = orig_n - except Exception as e: - print('Invalid framing:', e) + received = receive_packet(ser, args.width) + try: + noise_rx(received, ui) + except NoiseInvalidMessage as e: + orig_n = proto.noise_protocol.cipher_state_decrypt.n + print('Invalid noise message', e) + for n in [orig_n+1, orig_n+2, orig_n+3]: + try: + proto.noise_protocol.cipher_state_decrypt.n = n + noise_rx(received, ui) + print(f' Recovered. n={n}') + break + except NoiseInvalidMessage as e: + pass + else: + print(' Unrecoverable.') + proto.noise_protocol.cipher_state_decrypt.n = orig_n + except Exception as e: + print('Invalid framing:', e) |