summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rwxr-xr-xmpv-test.py73
-rw-r--r--mpv.py19
2 files changed, 83 insertions, 9 deletions
diff --git a/mpv-test.py b/mpv-test.py
index 7cdec89..d829dba 100755
--- a/mpv-test.py
+++ b/mpv-test.py
@@ -243,6 +243,79 @@ class ObservePropertyTest(MpvTestCase):
mock.call('loop', 'inf')],
any_order=True)
+class KeyBindingTest(MpvTestCase):
+ def test_register_direct_cmd(self):
+ self.m.register_key_binding('a', 'playlist-clear')
+ self.assertEqual(self.m._key_binding_handlers, {})
+ self.m.register_key_binding('Ctrl+Shift+a', 'playlist-clear')
+ self.m.unregister_key_binding('a')
+ self.m.unregister_key_binding('Ctrl+Shift+a')
+
+ def test_register_direct_fun(self):
+ b = mpv.MPV._binding_name
+
+ def reg_test_fun(state, name):
+ pass
+
+ self.m.register_key_binding('a', reg_test_fun)
+ self.assertIn(b('a'), self.m._key_binding_handlers)
+ self.assertEqual(self.m._key_binding_handlers[b('a')], reg_test_fun)
+
+ self.m.unregister_key_binding('a')
+ self.assertNotIn(b('a'), self.m._key_binding_handlers)
+
+ def test_register_direct_bound_method(self):
+ b = mpv.MPV._binding_name
+
+ class RegTestCls:
+ def method(self, state, name):
+ pass
+ instance = RegTestCls()
+
+ self.m.register_key_binding('a', instance.method)
+ self.assertIn(b('a'), self.m._key_binding_handlers)
+ self.assertEqual(self.m._key_binding_handlers[b('a')], instance.method)
+
+ self.m.unregister_key_binding('a')
+ self.assertNotIn(b('a'), self.m._key_binding_handlers)
+
+ def test_register_decorator_fun(self):
+ b = mpv.MPV._binding_name
+
+ @self.m.key_binding('a')
+ def reg_test_fun(state, name):
+ pass
+ self.assertEqual(reg_test_fun.mpv_key_bindings, ['a'])
+ self.assertIn(b('a'), self.m._key_binding_handlers)
+ self.assertEqual(self.m._key_binding_handlers[b('a')], reg_test_fun)
+
+ reg_test_fun.unregister_mpv_key_bindings()
+ self.assertNotIn(b('a'), self.m._key_binding_handlers)
+
+ def test_register_decorator_fun_chaining(self):
+ b = mpv.MPV._binding_name
+
+ @self.m.key_binding('a')
+ @self.m.key_binding('b')
+ def reg_test_fun(state, name):
+ pass
+
+ @self.m.key_binding('c')
+ def reg_test_fun_2_stay_intact(state, name):
+ pass
+
+ self.assertEqual(reg_test_fun.mpv_key_bindings, ['b', 'a'])
+ self.assertIn(b('a'), self.m._key_binding_handlers)
+ self.assertIn(b('b'), self.m._key_binding_handlers)
+ self.assertIn(b('c'), self.m._key_binding_handlers)
+ self.assertEqual(self.m._key_binding_handlers[b('a')], reg_test_fun)
+ self.assertEqual(self.m._key_binding_handlers[b('b')], reg_test_fun)
+
+ reg_test_fun.unregister_mpv_key_bindings()
+ self.assertNotIn(b('a'), self.m._key_binding_handlers)
+ self.assertNotIn(b('b'), self.m._key_binding_handlers)
+ self.assertIn(b('c'), self.m._key_binding_handlers)
+
class TestLifecycle(unittest.TestCase):
def test_create_destroy(self):
thread_names = lambda: [ t.name for t in threading.enumerate() ]
diff --git a/mpv.py b/mpv.py
index d8624f3..c53deb8 100644
--- a/mpv.py
+++ b/mpv.py
@@ -837,8 +837,8 @@ class MPV(object):
You can also call the ```unregister_mpv_messages``` function attribute set on the handler function when it is
registered. """
- if isinstance(target, str):
- del self._message_handlers[target]
+ if isinstance(target_or_handler, str):
+ del self._message_handlers[target_or_handler]
else:
for key, val in self._message_handlers.items():
if val == target_or_handler:
@@ -945,10 +945,16 @@ class MPV(object):
this is completely fine--but, if you are about to pass untrusted input into this parameter, better double-check
whether this is secure in your case. """
- def wrapper(fun):
+ def register(fun):
+ fun.mpv_key_bindings = getattr(fun, 'mpv_key_bindings', []) + [keydef]
+ def unregister_all():
+ for keydef in fun.mpv_key_bindings:
+ self.unregister_key_binding(keydef)
+ fun.unregister_mpv_key_bindings = unregister_all
+
self.register_key_binding(keydef, fun, mode)
return fun
- return wrapper
+ return register
def register_key_binding(self, keydef, callback_or_cmd, mode='force'):
""" Register a key binding. This takes an mpv keydef and either a string containing a mpv
@@ -959,11 +965,6 @@ class MPV(object):
'symbolic name (as printed by --input-keylist')
binding_name = MPV._binding_name(keydef)
if callable(callback_or_cmd):
- callback_or_cmd.mpv_key_bindings = getattr(callback_or_cmd, 'mpv_key_bindings', []) + [keydef]
- def unregister_all():
- for keydef in callback_or_cmd.mpv_key_bindings:
- self.unregister_key_binding(keydef)
- callback_or_cmd.unregister_mpv_key_bindings = unregister_all
self._key_binding_handlers[binding_name] = callback_or_cmd
self.register_message_handler('key-binding', self._handle_key_binding_message)
self.command('define-section',