From 27fd86efa5e2edc845aa09f7224fc7d658ddbc1b Mon Sep 17 00:00:00 2001 From: onyx-and-iris <75868496+onyx-and-iris@users.noreply.github.com> Date: Mon, 24 Oct 2022 22:42:16 +0100 Subject: [PATCH] only check for host+port values in init. only pass auth token if auth enabled add context manager methods to reqclient. added logging --- obsws_python/baseclient.py | 72 ++++++++++++++++++++------------------ obsws_python/callback.py | 4 +-- obsws_python/events.py | 16 +++++---- obsws_python/reqs.py | 15 ++++++-- 4 files changed, 63 insertions(+), 44 deletions(-) diff --git a/obsws_python/baseclient.py b/obsws_python/baseclient.py index 8361993..42e00b7 100644 --- a/obsws_python/baseclient.py +++ b/obsws_python/baseclient.py @@ -1,6 +1,7 @@ import base64 import hashlib import json +import logging from pathlib import Path from random import randint @@ -11,9 +12,11 @@ except ModuleNotFoundError: import websocket +from .error import OBSSDKError + class ObsClient: - DELAY = 0.001 + logger = logging.getLogger("baseclient.obsclient") def __init__(self, **kwargs): defaultkwargs = { @@ -23,7 +26,7 @@ class ObsClient: kwargs = defaultkwargs | kwargs for attr, val in kwargs.items(): setattr(self, attr, val) - if not (self.host and self.port and self.password): + if not (self.host and self.port): conn = self._conn_from_toml() self.host = conn["host"] self.port = conn["port"] @@ -40,50 +43,51 @@ class ObsClient: return conn["connection"] def authenticate(self): - secret = base64.b64encode( - hashlib.sha256( - ( - self.password + self.server_hello["d"]["authentication"]["salt"] - ).encode() - ).digest() - ) - - auth = base64.b64encode( - hashlib.sha256( - ( - secret.decode() - + self.server_hello["d"]["authentication"]["challenge"] - ).encode() - ).digest() - ).decode() - payload = { "op": 1, "d": { "rpcVersion": 1, - "authentication": auth, "eventSubscriptions": self.subs, }, } + if "authentication" in self.server_hello["d"]: + secret = base64.b64encode( + hashlib.sha256( + ( + self.password + self.server_hello["d"]["authentication"]["salt"] + ).encode() + ).digest() + ) + + auth = base64.b64encode( + hashlib.sha256( + ( + secret.decode() + + self.server_hello["d"]["authentication"]["challenge"] + ).encode() + ).digest() + ).decode() + + payload["d"]["authentication"] = auth + self.ws.send(json.dumps(payload)) - return self.ws.recv() + try: + response = json.loads(self.ws.recv()) + return response["op"] == 2 + except json.decoder.JSONDecodeError: + raise OBSSDKError("failed to identify client with the server") def req(self, req_type, req_data=None): + id = randint(1, 1000) + self.logger.debug(f"Sending request with response id {id}") + payload = { + "op": 6, + "d": {"requestType": req_type, "requestId": id}, + } if req_data: - payload = { - "op": 6, - "d": { - "requestType": req_type, - "requestId": randint(1, 1000), - "requestData": req_data, - }, - } - else: - payload = { - "op": 6, - "d": {"requestType": req_type, "requestId": randint(1, 1000)}, - } + payload["d"]["requestData"] = req_data self.ws.send(json.dumps(payload)) response = json.loads(self.ws.recv()) + self.logger.debug(f"Reponse received {response}") return response["d"] diff --git a/obsws_python/callback.py b/obsws_python/callback.py index 28e090d..20e46fe 100644 --- a/obsws_python/callback.py +++ b/obsws_python/callback.py @@ -31,7 +31,7 @@ class Callback: for fn in iterator: if fn not in self._callbacks: self._callbacks.append(fn) - except TypeError as e: + except TypeError: if fns not in self._callbacks: self._callbacks.append(fns) @@ -43,7 +43,7 @@ class Callback: for fn in iterator: if fn in self._callbacks: self._callbacks.remove(fn) - except TypeError as e: + except TypeError: if fns in self._callbacks: self._callbacks.remove(fns) diff --git a/obsws_python/events.py b/obsws_python/events.py index 56c4b63..c4d0c8c 100644 --- a/obsws_python/events.py +++ b/obsws_python/events.py @@ -1,4 +1,5 @@ import json +import logging import time from enum import IntEnum from threading import Thread @@ -20,6 +21,7 @@ Subs = IntEnum( class EventClient: + logger = logging.getLogger("events.eventclient") DELAY = 0.001 def __init__(self, **kwargs): @@ -40,7 +42,8 @@ class EventClient: } kwargs = defaultkwargs | kwargs self.base_client = ObsClient(**kwargs) - self.base_client.authenticate() + if self.base_client.authenticate(): + self.logger.info("Successfully identified client with the server") self.callback = Callback() self.subscribe() @@ -56,12 +59,13 @@ class EventClient: """ self.running = True while self.running: - self.data = json.loads(self.base_client.ws.recv()) - event, data = ( - self.data["d"].get("eventType"), - self.data["d"].get("eventData"), + event = json.loads(self.base_client.ws.recv()) + self.logger.debug(f"Event received {event}") + type_, data = ( + event["d"].get("eventType"), + event["d"].get("eventData"), ) - self.callback.trigger(event, data if data else {}) + self.callback.trigger(type_, data if data else {}) time.sleep(self.DELAY) def unsubscribe(self): diff --git a/obsws_python/reqs.py b/obsws_python/reqs.py index 95654c8..3c98e50 100644 --- a/obsws_python/reqs.py +++ b/obsws_python/reqs.py @@ -1,3 +1,5 @@ +import logging + from .baseclient import ObsClient from .error import OBSSDKError from .util import as_dataclass @@ -10,9 +12,18 @@ https://github.com/obsproject/obs-websocket/blob/master/docs/generated/protocol. class ReqClient: + logger = logging.getLogger("reqs.reqclient") + def __init__(self, **kwargs): self.base_client = ObsClient(**kwargs) - self.base_client.authenticate() + if self.base_client.authenticate(): + self.logger.info("Successfully identified client with the server") + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_value, exc_traceback): + self.base_client.ws.close() def send(self, param, data=None): response = self.base_client.req(param, data) @@ -486,7 +497,7 @@ class ReqClient: """ - return self.send("GetSceneList") + return self.send("GetGroupList") def get_current_program_scene(self): """