diff --git a/obsstudio_sdk/callback.py b/obsstudio_sdk/callback.py index 8cc4a61..a106c3a 100644 --- a/obsstudio_sdk/callback.py +++ b/obsstudio_sdk/callback.py @@ -42,13 +42,17 @@ class Callback: if fns not in self._callbacks: self._callbacks.append(fns) - def deregister(self, callback): + def deregister(self, fns: Union[Iterable, Callable]): """deregisters a callback from _callbacks""" try: - self._callbacks.remove(callback) - except ValueError: - print(f"Failed to remove: {callback}") + iterator = iter(fns) + for fn in iterator: + if fn in self._callbacks: + self._callbacks.remove(fn) + except TypeError as e: + if fns in self._callbacks: + self._callbacks.remove(fns) def clear(self): """clears the _callbacks list""" diff --git a/obsstudio_sdk/events.py b/obsstudio_sdk/events.py index 9379891..8d62ab4 100644 --- a/obsstudio_sdk/events.py +++ b/obsstudio_sdk/events.py @@ -41,8 +41,9 @@ class EventClient(object): self.base_client = ObsClient(**kwargs) self.base_client.authenticate() self.callback = Callback() + self.subscribe() - self.running = True + def subscribe(self): worker = Thread(target=self.trigger, daemon=True) worker.start() @@ -52,6 +53,7 @@ class EventClient(object): Triggers a callback on event received. """ + self.running = True while self.running: self.data = json.loads(self.base_client.ws.recv()) event, data = (self.data["d"].get("eventType"), self.data["d"]) diff --git a/tests/test_callback.py b/tests/test_callback.py new file mode 100644 index 0000000..45d9329 --- /dev/null +++ b/tests/test_callback.py @@ -0,0 +1,59 @@ +import pytest +from obsstudio_sdk.callback import Callback + + +class TestCallbacks: + __test__ = True + + @classmethod + def setup_class(cls): + cls.callback = Callback() + + @pytest.fixture(autouse=True) + def wraps_tests(self): + yield + self.callback.clear() + + def test_register_callback(self): + def on_callback_method(): + pass + + self.callback.register(on_callback_method) + assert self.callback.get() == ["CallbackMethod"] + + def test_register_callbacks(self): + def on_callback_method_one(): + pass + + def on_callback_method_two(): + pass + + self.callback.register((on_callback_method_one, on_callback_method_two)) + assert self.callback.get() == ["CallbackMethodOne", "CallbackMethodTwo"] + + def test_deregister_callback(self): + def on_callback_method_one(): + pass + + def on_callback_method_two(): + pass + + self.callback.register((on_callback_method_one, on_callback_method_two)) + self.callback.deregister(on_callback_method_one) + assert self.callback.get() == ["CallbackMethodTwo"] + + def test_deregister_callbacks(self): + def on_callback_method_one(): + pass + + def on_callback_method_two(): + pass + + def on_callback_method_three(): + pass + + self.callback.register( + (on_callback_method_one, on_callback_method_two, on_callback_method_three) + ) + self.callback.deregister((on_callback_method_two, on_callback_method_three)) + assert self.callback.get() == ["CallbackMethodOne"] diff --git a/tests/test_request.py b/tests/test_request.py index d68feb9..70efc3e 100644 --- a/tests/test_request.py +++ b/tests/test_request.py @@ -1,5 +1,3 @@ -import time - import pytest from tests import req_cl