mirror of
https://github.com/onyx-and-iris/obsws-python.git
synced 2025-01-18 03:20:47 +00:00
only check for host+port values in init.
only pass auth token if auth enabled add context manager methods to reqclient. added logging
This commit is contained in:
parent
69b0d4137a
commit
27fd86efa5
@ -1,6 +1,7 @@
|
||||
import base64
|
||||
import hashlib
|
||||
import json
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from random import randint
|
||||
|
||||
@ -11,9 +12,11 @@ except ModuleNotFoundError:
|
||||
|
||||
import websocket
|
||||
|
||||
from .error import OBSSDKError
|
||||
|
||||
|
||||
class ObsClient:
|
||||
DELAY = 0.001
|
||||
logger = logging.getLogger("baseclient.obsclient")
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
defaultkwargs = {
|
||||
@ -23,7 +26,7 @@ class ObsClient:
|
||||
kwargs = defaultkwargs | kwargs
|
||||
for attr, val in kwargs.items():
|
||||
setattr(self, attr, val)
|
||||
if not (self.host and self.port and self.password):
|
||||
if not (self.host and self.port):
|
||||
conn = self._conn_from_toml()
|
||||
self.host = conn["host"]
|
||||
self.port = conn["port"]
|
||||
@ -40,50 +43,51 @@ class ObsClient:
|
||||
return conn["connection"]
|
||||
|
||||
def authenticate(self):
|
||||
secret = base64.b64encode(
|
||||
hashlib.sha256(
|
||||
(
|
||||
self.password + self.server_hello["d"]["authentication"]["salt"]
|
||||
).encode()
|
||||
).digest()
|
||||
)
|
||||
|
||||
auth = base64.b64encode(
|
||||
hashlib.sha256(
|
||||
(
|
||||
secret.decode()
|
||||
+ self.server_hello["d"]["authentication"]["challenge"]
|
||||
).encode()
|
||||
).digest()
|
||||
).decode()
|
||||
|
||||
payload = {
|
||||
"op": 1,
|
||||
"d": {
|
||||
"rpcVersion": 1,
|
||||
"authentication": auth,
|
||||
"eventSubscriptions": self.subs,
|
||||
},
|
||||
}
|
||||
|
||||
if "authentication" in self.server_hello["d"]:
|
||||
secret = base64.b64encode(
|
||||
hashlib.sha256(
|
||||
(
|
||||
self.password + self.server_hello["d"]["authentication"]["salt"]
|
||||
).encode()
|
||||
).digest()
|
||||
)
|
||||
|
||||
auth = base64.b64encode(
|
||||
hashlib.sha256(
|
||||
(
|
||||
secret.decode()
|
||||
+ self.server_hello["d"]["authentication"]["challenge"]
|
||||
).encode()
|
||||
).digest()
|
||||
).decode()
|
||||
|
||||
payload["d"]["authentication"] = auth
|
||||
|
||||
self.ws.send(json.dumps(payload))
|
||||
return self.ws.recv()
|
||||
try:
|
||||
response = json.loads(self.ws.recv())
|
||||
return response["op"] == 2
|
||||
except json.decoder.JSONDecodeError:
|
||||
raise OBSSDKError("failed to identify client with the server")
|
||||
|
||||
def req(self, req_type, req_data=None):
|
||||
id = randint(1, 1000)
|
||||
self.logger.debug(f"Sending request with response id {id}")
|
||||
payload = {
|
||||
"op": 6,
|
||||
"d": {"requestType": req_type, "requestId": id},
|
||||
}
|
||||
if req_data:
|
||||
payload = {
|
||||
"op": 6,
|
||||
"d": {
|
||||
"requestType": req_type,
|
||||
"requestId": randint(1, 1000),
|
||||
"requestData": req_data,
|
||||
},
|
||||
}
|
||||
else:
|
||||
payload = {
|
||||
"op": 6,
|
||||
"d": {"requestType": req_type, "requestId": randint(1, 1000)},
|
||||
}
|
||||
payload["d"]["requestData"] = req_data
|
||||
self.ws.send(json.dumps(payload))
|
||||
response = json.loads(self.ws.recv())
|
||||
self.logger.debug(f"Reponse received {response}")
|
||||
return response["d"]
|
||||
|
@ -31,7 +31,7 @@ class Callback:
|
||||
for fn in iterator:
|
||||
if fn not in self._callbacks:
|
||||
self._callbacks.append(fn)
|
||||
except TypeError as e:
|
||||
except TypeError:
|
||||
if fns not in self._callbacks:
|
||||
self._callbacks.append(fns)
|
||||
|
||||
@ -43,7 +43,7 @@ class Callback:
|
||||
for fn in iterator:
|
||||
if fn in self._callbacks:
|
||||
self._callbacks.remove(fn)
|
||||
except TypeError as e:
|
||||
except TypeError:
|
||||
if fns in self._callbacks:
|
||||
self._callbacks.remove(fns)
|
||||
|
||||
|
@ -1,4 +1,5 @@
|
||||
import json
|
||||
import logging
|
||||
import time
|
||||
from enum import IntEnum
|
||||
from threading import Thread
|
||||
@ -20,6 +21,7 @@ Subs = IntEnum(
|
||||
|
||||
|
||||
class EventClient:
|
||||
logger = logging.getLogger("events.eventclient")
|
||||
DELAY = 0.001
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
@ -40,7 +42,8 @@ class EventClient:
|
||||
}
|
||||
kwargs = defaultkwargs | kwargs
|
||||
self.base_client = ObsClient(**kwargs)
|
||||
self.base_client.authenticate()
|
||||
if self.base_client.authenticate():
|
||||
self.logger.info("Successfully identified client with the server")
|
||||
self.callback = Callback()
|
||||
self.subscribe()
|
||||
|
||||
@ -56,12 +59,13 @@ class EventClient:
|
||||
"""
|
||||
self.running = True
|
||||
while self.running:
|
||||
self.data = json.loads(self.base_client.ws.recv())
|
||||
event, data = (
|
||||
self.data["d"].get("eventType"),
|
||||
self.data["d"].get("eventData"),
|
||||
event = json.loads(self.base_client.ws.recv())
|
||||
self.logger.debug(f"Event received {event}")
|
||||
type_, data = (
|
||||
event["d"].get("eventType"),
|
||||
event["d"].get("eventData"),
|
||||
)
|
||||
self.callback.trigger(event, data if data else {})
|
||||
self.callback.trigger(type_, data if data else {})
|
||||
time.sleep(self.DELAY)
|
||||
|
||||
def unsubscribe(self):
|
||||
|
@ -1,3 +1,5 @@
|
||||
import logging
|
||||
|
||||
from .baseclient import ObsClient
|
||||
from .error import OBSSDKError
|
||||
from .util import as_dataclass
|
||||
@ -10,9 +12,18 @@ https://github.com/obsproject/obs-websocket/blob/master/docs/generated/protocol.
|
||||
|
||||
|
||||
class ReqClient:
|
||||
logger = logging.getLogger("reqs.reqclient")
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
self.base_client = ObsClient(**kwargs)
|
||||
self.base_client.authenticate()
|
||||
if self.base_client.authenticate():
|
||||
self.logger.info("Successfully identified client with the server")
|
||||
|
||||
def __enter__(self):
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc_value, exc_traceback):
|
||||
self.base_client.ws.close()
|
||||
|
||||
def send(self, param, data=None):
|
||||
response = self.base_client.req(param, data)
|
||||
@ -486,7 +497,7 @@ class ReqClient:
|
||||
|
||||
|
||||
"""
|
||||
return self.send("GetSceneList")
|
||||
return self.send("GetGroupList")
|
||||
|
||||
def get_current_program_scene(self):
|
||||
"""
|
||||
|
Loading…
Reference in New Issue
Block a user