From 0e74740956d1e65ec8d0b8e69c46a77dbdf7a31f Mon Sep 17 00:00:00 2001 From: Thomas Munzer Date: Fri, 24 Apr 2026 12:09:14 +0200 Subject: [PATCH 1/3] Improve websocket reliability per Mist best practices --- README.md | 11 +- src/mistapi/websockets/__ws_client.py | 351 ++++++++++++++++++++++++-- src/mistapi/websockets/location.py | 40 +-- src/mistapi/websockets/orgs.py | 24 +- src/mistapi/websockets/session.py | 8 +- src/mistapi/websockets/sites.py | 56 ++-- tests/unit/test_websocket_client.py | 70 ++++- 7 files changed, 465 insertions(+), 95 deletions(-) diff --git a/README.md b/README.md index 7c7f3f3..8b72662 100644 --- a/README.md +++ b/README.md @@ -583,19 +583,22 @@ All channel classes accept the following optional keyword arguments: | Parameter | Type | Default | Description | |-----------|------|---------|-------------| -| `ping_interval` | `int` | `30` | Seconds between automatic ping frames. Set to `0` to disable pings. | -| `ping_timeout` | `int` | `10` | Seconds to wait for a pong response before treating the connection as dead. | +| `ping_interval` | `int` | `60` | Seconds between automatic ping frames. Set to `0` to disable pings. | +| `ping_timeout` | `int` | `45` | Seconds to wait for a pong response before treating the connection as dead. Must be lower than `ping_interval`. | | `auto_reconnect` | `bool` | `False` | Automatically reconnect on transient failures using exponential backoff. | | `max_reconnect_attempts` | `int` | `5` | Maximum number of reconnect attempts before giving up. | | `reconnect_backoff` | `float` | `2.0` | Base backoff delay in seconds. Doubles after each failed attempt (2s, 4s, 8s, ...). Resets on successful reconnection. | | `queue_maxsize` | `int` | `0` | Maximum messages buffered in the internal queue for `receive()`. `0` means unbounded. When set, incoming messages are dropped with a warning when the queue is full, preventing memory growth on high-frequency streams. 
| +| `subscription_watchdog_timeout` | `float` | `10.0` | Maximum time in seconds to wait for all `channel_subscribed` acknowledgements after connect. On timeout, the connection is closed to trigger a clean reconnect. | +| `rate_limit_backoff` | `float` | `30.0` | Minimum reconnect delay in seconds after an HTTP 429 rate-limit response. | +| `throughput_log_interval` | `int` | `100` | Log queue depth and processed counts every N messages. Set to `0` to disable periodic throughput logs. | ```python ws = mistapi.websockets.sites.DeviceStatsEvents( apisession, site_ids=[""], ping_interval=60, # ping every 60 s - ping_timeout=20, # wait up to 20 s for pong + ping_timeout=45, # wait up to 45 s for pong auto_reconnect=True, # reconnect on transient failures ) ws.connect() @@ -609,6 +612,8 @@ ws.connect() | `ws.on_message(cb)` | `cb(data: dict)` | Register callback for incoming messages. Mutually exclusive with `receive()`. | | `ws.on_error(cb)` | `cb(error: Exception)` | Register callback for WebSocket errors | | `ws.on_close(cb)` | `cb(code: int \| None, msg: str \| None)` | Register callback for connection close. Safe to call `connect()` from within. | +| `ws.on_ping(cb)` | `cb(message: str \| bytes \| None)` | Register callback for received ping frames. | +| `ws.on_pong(cb)` | `cb(message: str \| bytes \| None)` | Register callback for received pong frames. | | `ws.connect(run_in_background)` | | Open the connection. `True` (default) runs in a daemon thread; `False` blocks. | | `ws.disconnect(wait, timeout)` | | Close the connection. `wait=True` blocks until the background thread finishes. | | `ws.receive()` | `-> Generator[dict]` | Blocking generator yielding messages. Mutually exclusive with `on_message`. 
| diff --git a/src/mistapi/websockets/__ws_client.py b/src/mistapi/websockets/__ws_client.py index e351c5d..58fabb5 100644 --- a/src/mistapi/websockets/__ws_client.py +++ b/src/mistapi/websockets/__ws_client.py @@ -19,7 +19,7 @@ import ssl import threading from collections.abc import Callable, Generator -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Any import websocket @@ -48,6 +48,10 @@ def filter(self, record: logging.LogRecord) -> bool: from mistapi import APISession +MAX_CHANNELS_PER_CONNECTION = 2000 +HIGH_CHANNEL_COUNT_WARNING = 1500 + + class _MistWebsocket: """ Base class for Mist API WebSocket channels. @@ -64,14 +68,25 @@ def __init__( self, mist_session: "APISession", channels: list[str], - ping_interval: int = 30, - ping_timeout: int = 10, + ping_interval: int = 60, + ping_timeout: int = 45, auto_reconnect: bool = False, max_reconnect_attempts: int = 5, reconnect_backoff: float = 2.0, max_reconnect_backoff: float | None = None, queue_maxsize: int = 0, + subscription_watchdog_timeout: float = 10.0, + rate_limit_backoff: float = 30.0, + throughput_log_interval: int = 100, ) -> None: + if ping_interval < 0: + raise ValueError("ping_interval must be >= 0") + if ping_timeout <= 0: + raise ValueError("ping_timeout must be > 0") + if ping_interval and ping_interval <= ping_timeout: + raise ValueError( + "ping_interval must be greater than ping_timeout when enabled" + ) if max_reconnect_attempts < 0: raise ValueError("max_reconnect_attempts must be >= 0 (0 = unlimited)") if reconnect_backoff <= 0: @@ -80,32 +95,75 @@ def __init__( raise ValueError("max_reconnect_backoff must be > 0") if queue_maxsize < 0: raise ValueError("queue_maxsize must be >= 0") + if subscription_watchdog_timeout <= 0: + raise ValueError("subscription_watchdog_timeout must be > 0") + if rate_limit_backoff <= 0: + raise ValueError("rate_limit_backoff must be > 0") + if throughput_log_interval < 0: + raise ValueError("throughput_log_interval must be >= 0") + + 
deduped_channels = list(dict.fromkeys(channels)) + if len(deduped_channels) != len(channels): + logger.warning( + "Duplicate channels detected; using %d unique channels instead of %d", + len(deduped_channels), + len(channels), + ) + if len(deduped_channels) > MAX_CHANNELS_PER_CONNECTION: + raise ValueError( + f"Too many channels ({len(deduped_channels)}). " + f"Mist supports up to {MAX_CHANNELS_PER_CONNECTION} channels per connection" + ) + if len(deduped_channels) >= HIGH_CHANNEL_COUNT_WARNING: + logger.warning( + "High channel count (%d). Consider spreading subscriptions over multiple " + "WebSocket connections to reduce message backlog risk.", + len(deduped_channels), + ) self._mist_session = mist_session - self._channels = channels + self._channels = deduped_channels + self._expected_channels = set(deduped_channels) self._ping_interval = ping_interval self._ping_timeout = ping_timeout self._auto_reconnect = auto_reconnect self._max_reconnect_attempts = max_reconnect_attempts self._reconnect_backoff = reconnect_backoff self._max_reconnect_backoff = max_reconnect_backoff + self._subscription_watchdog_timeout = subscription_watchdog_timeout + self._rate_limit_backoff = rate_limit_backoff + self._throughput_log_interval = throughput_log_interval self._lock = threading.Lock() + self._subscription_lock = threading.Lock() self._ws: websocket.WebSocketApp | None = None self._thread: threading.Thread | None = None + self._callback_thread: threading.Thread | None = None + self._subscription_watchdog: threading.Timer | None = None self._queue: queue.Queue[dict | None] = queue.Queue(maxsize=queue_maxsize) + self._callback_queue: queue.Queue[dict | None] = queue.Queue( + maxsize=queue_maxsize + ) self._connected = ( threading.Event() ) # tracks whether the WebSocket connection is currently open self._user_disconnect = threading.Event() + self._callback_stop = threading.Event() self._finished = threading.Event() self._finished.set() # not running initially 
self._reconnect_attempts = 0 self._last_close_code: int | None = None self._last_close_msg: str | None = None + self._last_http_status: int | None = None + self._subscribed_channels: set[str] = set() + self._messages_received = 0 + self._messages_dropped = 0 + self._messages_processed = 0 self._on_message_cb: Callable[[dict], None] | None = None self._on_error_cb: Callable[[Exception], None] | None = None self._on_open_cb: Callable[[], None] | None = None self._on_close_cb: Callable[[int | None, str | None], None] | None = None + self._on_ping_cb: Callable[[str | bytes | None], None] | None = None + self._on_pong_cb: Callable[[str | bytes | None], None] | None = None # ------------------------------------------------------------------ # Auth / URL helpers @@ -184,10 +242,198 @@ def on_close(self, callback: Callable[[int | None, str | None], None]) -> None: """Register a callback invoked when the connection closes.""" self._on_close_cb = callback + def on_ping(self, callback: Callable[[str | bytes | None], None]) -> None: + """Register a callback invoked when a ping frame is received.""" + self._on_ping_cb = callback + + def on_pong(self, callback: Callable[[str | bytes | None], None]) -> None: + """Register a callback invoked when a pong frame is received.""" + self._on_pong_cb = callback + + # ------------------------------------------------------------------ + # Internal helpers + + @staticmethod + def _extract_status_code(error: Exception) -> int | None: + status_code = getattr(error, "status_code", None) + if isinstance(status_code, int): + return status_code + response = getattr(error, "response", None) + if response is not None: + response_code = getattr(response, "status_code", None) + if isinstance(response_code, int): + return response_code + return None + + def _drain_queue(self, target_queue: queue.Queue[Any]) -> None: + while not target_queue.empty(): + try: + target_queue.get_nowait() + except queue.Empty: + break + + def _start_callback_worker(self) 
-> None: + if self._callback_thread is not None and self._callback_thread.is_alive(): + return + self._callback_stop.clear() + self._callback_thread = threading.Thread( + target=self._run_callback_worker, daemon=True + ) + self._callback_thread.start() + + def _run_callback_worker(self) -> None: + while True: + if self._callback_stop.is_set(): + break + try: + item = self._callback_queue.get(timeout=1) + except queue.Empty: + if self._finished.is_set() and self._callback_queue.empty(): + break + continue + if item is None: + if self._callback_stop.is_set() or self._finished.is_set(): + break + continue + callback = self._on_message_cb + if callback is None: + continue + try: + callback(item) + except Exception: + logger.exception("on_message callback raised") + self._messages_processed += 1 + if ( + self._throughput_log_interval + and self._messages_processed % self._throughput_log_interval == 0 + ): + logger.info( + "WebSocket callback worker processed %d messages. " + "Callback queue size=%d dropped=%d", + self._messages_processed, + self._callback_queue.qsize(), + self._messages_dropped, + ) + + def _cancel_subscription_watchdog(self) -> None: + with self._lock: + timer = self._subscription_watchdog + self._subscription_watchdog = None + if timer is not None: + timer.cancel() + + def _arm_subscription_watchdog(self, ws: websocket.WebSocketApp) -> None: + if not self._expected_channels: + return + + def _watchdog_expired() -> None: + if self._user_disconnect.is_set(): + return + with self._lock: + current_ws = self._ws + if ws is not current_ws: + return + with self._subscription_lock: + missing = sorted(self._expected_channels - self._subscribed_channels) + if not missing: + return + preview = ", ".join(missing[:5]) + if len(missing) > 5: + preview = f"{preview}, ..." 
+ self._last_close_code = 1008 + self._last_close_msg = ( + f"subscription watchdog timeout: missing {len(missing)} channels" + ) + logger.error( + "Subscription watchdog timeout after %.1fs: received %d/%d subscriptions. " + "Missing: %s", + self._subscription_watchdog_timeout, + len(self._expected_channels) - len(missing), + len(self._expected_channels), + preview, + ) + ws.close() + + timer = threading.Timer(self._subscription_watchdog_timeout, _watchdog_expired) + timer.daemon = True + self._cancel_subscription_watchdog() + with self._lock: + self._subscription_watchdog = timer + timer.start() + + def _process_subscription_event( + self, ws: websocket.WebSocketApp, data: dict + ) -> None: + event = data.get("event") + channel = data.get("channel") + if not isinstance(channel, str): + channel = None + + if event == "channel_subscribed" and channel: + with self._subscription_lock: + self._subscribed_channels.add(channel) + subscribed_count = len(self._subscribed_channels) + expected_count = len(self._expected_channels) + logger.info( + "Channel subscribed (%d/%d): %s", + subscribed_count, + expected_count, + channel, + ) + if channel not in self._expected_channels: + logger.warning( + "Received channel_subscribed for unexpected channel: %s", channel + ) + if subscribed_count >= expected_count: + self._cancel_subscription_watchdog() + logger.info("All requested channels subscribed (%d)", expected_count) + return + + if event == "subscribe_failed": + detail = data.get("detail") + logger.error( + "Subscription failed for channel %s: %s. 
Closing to trigger reconnect.", + channel, + detail, + ) + self._last_close_code = 1008 + self._last_close_msg = f"subscribe_failed channel={channel} detail={detail}" + self._cancel_subscription_watchdog() + ws.close() + + def _enqueue_message(self, message: dict, to_callback_queue: bool) -> None: + target_queue = self._callback_queue if to_callback_queue else self._queue + queue_name = "callback" if to_callback_queue else "receive" + self._messages_received += 1 + try: + target_queue.put_nowait(message) + except queue.Full: + self._messages_dropped += 1 + logger.warning("%s queue full; dropping message", queue_name.capitalize()) + return + if ( + self._throughput_log_interval + and self._messages_received % self._throughput_log_interval == 0 + ): + logger.info( + "WebSocket received %d messages. %s queue size=%d dropped=%d", + self._messages_received, + queue_name.capitalize(), + target_queue.qsize(), + self._messages_dropped, + ) + # ------------------------------------------------------------------ # Internal WebSocketApp handlers def _handle_open(self, ws: websocket.WebSocketApp) -> None: + logger.info( + "WebSocket opened. 
Requesting %d channel subscription(s)", + len(self._channels), + ) + self._last_http_status = None + with self._subscription_lock: + self._subscribed_channels.clear() try: for channel in self._channels: ws.send(json.dumps({"subscribe": channel})) @@ -195,6 +441,8 @@ def _handle_open(self, ws: websocket.WebSocketApp) -> None: logger.error("Subscription send failed: %s", exc) ws.close() return + if self._expected_channels: + self._arm_subscription_watchdog(ws) self._reconnect_attempts = 0 self._last_close_code = None self._last_close_msg = None @@ -212,33 +460,70 @@ def _handle_message(self, ws: websocket.WebSocketApp, message: str | bytes) -> N data = json.loads(message) except (json.JSONDecodeError, TypeError): data = {"raw": message} + + if isinstance(data, dict): + self._process_subscription_event(ws, data) + if self._on_message_cb: - try: - self._on_message_cb(data) - except Exception: - logger.exception("on_message callback raised") - else: - try: - self._queue.put_nowait(data) - except queue.Full: - logger.warning("Receive queue full; dropping message") + self._start_callback_worker() + self._enqueue_message(data, to_callback_queue=True) + return - def _handle_error(self, ws: websocket.WebSocketApp, error: Exception) -> None: + self._enqueue_message(data, to_callback_queue=False) + + def _handle_error(self, _ws: websocket.WebSocketApp, error: Exception) -> None: + status_code = self._extract_status_code(error) + if status_code is not None: + self._last_http_status = status_code + if status_code == 429: + logger.warning( + "WebSocket received HTTP 429 (rate limit). 
" + "Reconnect backoff will be raised to at least %.1fs", + self._rate_limit_backoff, + ) + else: + logger.error("WebSocket error: %s", error) if self._on_error_cb: try: self._on_error_cb(error) except Exception: logger.exception("on_error callback raised") + def _handle_ping( + self, _ws: websocket.WebSocketApp, message: str | bytes | None + ) -> None: + logger.info("WebSocket ping received") + if self._on_ping_cb: + try: + self._on_ping_cb(message) + except Exception: + logger.exception("on_ping callback raised") + + def _handle_pong( + self, _ws: websocket.WebSocketApp, message: str | bytes | None + ) -> None: + logger.info("WebSocket pong received") + if self._on_pong_cb: + try: + self._on_pong_cb(message) + except Exception: + logger.exception("on_pong callback raised") + def _handle_close( self, - ws: websocket.WebSocketApp, + _ws: websocket.WebSocketApp, close_status_code: int | None, close_msg: str | None, ) -> None: self._connected.clear() + self._cancel_subscription_watchdog() self._last_close_code = close_status_code self._last_close_msg = close_msg + logger.info( + "WebSocket closed. 
code=%s message=%s", + close_status_code, + close_msg, + ) # ------------------------------------------------------------------ # Lifecycle @@ -253,6 +538,8 @@ def _create_ws_app(self) -> websocket.WebSocketApp: on_message=self._handle_message, on_error=self._handle_error, on_close=self._handle_close, + on_ping=self._handle_ping, + on_pong=self._handle_pong, ) def connect(self, run_in_background: bool = True) -> None: @@ -270,15 +557,18 @@ def connect(self, run_in_background: bool = True) -> None: raise RuntimeError("Already connected; call disconnect() first") self._finished.clear() self._user_disconnect.clear() + self._callback_stop.clear() self._reconnect_attempts = 0 + self._messages_received = 0 + self._messages_dropped = 0 + self._messages_processed = 0 # Drain stale sentinel from previous connection - while not self._queue.empty(): - try: - self._queue.get_nowait() - except queue.Empty: - break + self._drain_queue(self._queue) + self._drain_queue(self._callback_queue) self._ws = self._create_ws_app() + if self._on_message_cb: + self._start_callback_worker() if run_in_background: self._thread = threading.Thread( target=self._run_forever_safe, daemon=True @@ -294,6 +584,8 @@ def _run_forever_safe(self) -> None: while True: with self._lock: ws = self._ws + if ws is None: + break try: sslopt = self._build_sslopt() ws.run_forever( @@ -322,6 +614,8 @@ def _run_forever_safe(self) -> None: delay = self._reconnect_backoff * (2 ** (self._reconnect_attempts - 1)) if self._max_reconnect_backoff is not None: delay = min(delay, self._max_reconnect_backoff) + if self._last_http_status == 429: + delay = max(delay, self._rate_limit_backoff) if self._max_reconnect_attempts > 0: logger.info( "Reconnecting in %.1fs (attempt %d/%d)", @@ -353,10 +647,16 @@ def _run_forever_safe(self) -> None: pass finally: + self._cancel_subscription_watchdog() + self._callback_stop.set() try: self._queue.put_nowait(None) # sentinel — unblocks receive() except queue.Full: pass # _finished.set() 
below will unblock receive() independently + try: + self._callback_queue.put_nowait(None) # sentinel — unblocks worker + except queue.Full: + pass self._finished.set() # mark as not running — unblocks connect() if self._on_close_cb: try: @@ -376,13 +676,22 @@ def disconnect(self, wait: bool = False, timeout: float | None = None) -> None: when *wait* is True). ``None`` means wait indefinitely. """ self._user_disconnect.set() + self._callback_stop.set() + self._cancel_subscription_watchdog() with self._lock: ws = self._ws if ws: ws.close() + try: + self._callback_queue.put_nowait(None) + except queue.Full: + pass if wait and self._thread is not None: if self._thread is not threading.current_thread(): self._thread.join(timeout=timeout) + if wait and self._callback_thread is not None: + if self._callback_thread is not threading.current_thread(): + self._callback_thread.join(timeout=timeout) def receive(self) -> Generator[dict, None, None]: """ @@ -442,4 +751,4 @@ def __exit__(self, *args) -> None: def ready(self) -> bool: """Returns True if the WebSocket connection is open and ready.""" - return self._ws is not None and self._ws.ready() + return bool(self._ws is not None and self._ws.ready()) diff --git a/src/mistapi/websockets/location.py b/src/mistapi/websockets/location.py index ad2773c..16e9622 100644 --- a/src/mistapi/websockets/location.py +++ b/src/mistapi/websockets/location.py @@ -29,9 +29,9 @@ class BleAssetsEvents(_MistWebsocket): UUID of the site to stream events from. map_ids : list[str] UUIDs of the maps to stream events from. - ping_interval : int, default 30 + ping_interval : int, default 60 Interval in seconds to send WebSocket ping frames (keep-alive). - ping_timeout : int, default 10 + ping_timeout : int, default 45 Time in seconds to wait for a ping response before considering the connection dead. auto_reconnect : bool, default False Automatically reconnect on unexpected disconnections using exponential backoff. 
@@ -77,8 +77,8 @@ def __init__( mist_session: APISession, site_id: str, map_ids: list[str], - ping_interval: int = 30, - ping_timeout: int = 10, + ping_interval: int = 60, + ping_timeout: int = 45, auto_reconnect: bool = False, max_reconnect_attempts: int = 5, reconnect_backoff: float = 2.0, @@ -113,9 +113,9 @@ class ConnectedClientsEvents(_MistWebsocket): UUID of the site to stream events from. map_ids : list[str] UUIDs of the maps to stream events from. - ping_interval : int, default 30 + ping_interval : int, default 60 Interval in seconds to send WebSocket ping frames (keep-alive). - ping_timeout : int, default 10 + ping_timeout : int, default 45 Time in seconds to wait for a ping response before considering the connection dead. auto_reconnect : bool, default False Automatically reconnect on unexpected disconnections using exponential backoff. @@ -161,8 +161,8 @@ def __init__( mist_session: APISession, site_id: str, map_ids: list[str], - ping_interval: int = 30, - ping_timeout: int = 10, + ping_interval: int = 60, + ping_timeout: int = 45, auto_reconnect: bool = False, max_reconnect_attempts: int = 5, reconnect_backoff: float = 2.0, @@ -197,9 +197,9 @@ class SdkClientsEvents(_MistWebsocket): UUID of the site to stream events from. map_ids : list[str] UUIDs of the maps to stream events from. - ping_interval : int, default 30 + ping_interval : int, default 60 Interval in seconds to send WebSocket ping frames (keep-alive). - ping_timeout : int, default 10 + ping_timeout : int, default 45 Time in seconds to wait for a ping response before considering the connection dead. auto_reconnect : bool, default False Automatically reconnect on unexpected disconnections using exponential backoff. 
@@ -245,8 +245,8 @@ def __init__( mist_session: APISession, site_id: str, map_ids: list[str], - ping_interval: int = 30, - ping_timeout: int = 10, + ping_interval: int = 60, + ping_timeout: int = 45, auto_reconnect: bool = False, max_reconnect_attempts: int = 5, reconnect_backoff: float = 2.0, @@ -281,9 +281,9 @@ class UnconnectedClientsEvents(_MistWebsocket): UUID of the site to stream events from. map_ids : list[str] UUIDs of the maps to stream events from. - ping_interval : int, default 30 + ping_interval : int, default 60 Interval in seconds to send WebSocket ping frames (keep-alive). - ping_timeout : int, default 10 + ping_timeout : int, default 45 Time in seconds to wait for a ping response before considering the connection dead. auto_reconnect : bool, default False Automatically reconnect on unexpected disconnections using exponential backoff. @@ -329,8 +329,8 @@ def __init__( mist_session: APISession, site_id: str, map_ids: list[str], - ping_interval: int = 30, - ping_timeout: int = 10, + ping_interval: int = 60, + ping_timeout: int = 45, auto_reconnect: bool = False, max_reconnect_attempts: int = 5, reconnect_backoff: float = 2.0, @@ -367,9 +367,9 @@ class DiscoveredBleAssetsEvents(_MistWebsocket): UUID of the site to stream events from. map_ids : list[str] UUIDs of the maps to stream events from. - ping_interval : int, default 30 + ping_interval : int, default 60 Interval in seconds to send WebSocket ping frames (keep-alive). - ping_timeout : int, default 10 + ping_timeout : int, default 45 Time in seconds to wait for a ping response before considering the connection dead. auto_reconnect : bool, default False Automatically reconnect on unexpected disconnections using exponential backoff. 
@@ -415,8 +415,8 @@ def __init__( mist_session: APISession, site_id: str, map_ids: list[str], - ping_interval: int = 30, - ping_timeout: int = 10, + ping_interval: int = 60, + ping_timeout: int = 45, auto_reconnect: bool = False, max_reconnect_attempts: int = 5, reconnect_backoff: float = 2.0, diff --git a/src/mistapi/websockets/orgs.py b/src/mistapi/websockets/orgs.py index 9a04e7f..546d6b2 100644 --- a/src/mistapi/websockets/orgs.py +++ b/src/mistapi/websockets/orgs.py @@ -27,9 +27,9 @@ class InsightsEvents(_MistWebsocket): Authenticated API session. org_id : str UUID of the organization to stream events from. - ping_interval : int, default 30 + ping_interval : int, default 60 Interval in seconds to send WebSocket ping frames (keep-alive). - ping_timeout : int, default 10 + ping_timeout : int, default 45 Time in seconds to wait for a ping response before considering the connection dead. auto_reconnect : bool, default False Automatically reconnect on unexpected disconnections using exponential backoff. @@ -74,8 +74,8 @@ def __init__( self, mist_session: APISession, org_id: str, - ping_interval: int = 30, - ping_timeout: int = 10, + ping_interval: int = 60, + ping_timeout: int = 45, auto_reconnect: bool = False, max_reconnect_attempts: int = 5, reconnect_backoff: float = 2.0, @@ -107,9 +107,9 @@ class MxEdgesStatsEvents(_MistWebsocket): Authenticated API session. org_id : str UUID of the organization to stream events from. - ping_interval : int, default 30 + ping_interval : int, default 60 Interval in seconds to send WebSocket ping frames (keep-alive). - ping_timeout : int, default 10 + ping_timeout : int, default 45 Time in seconds to wait for a ping response before considering the connection dead. auto_reconnect : bool, default False Automatically reconnect on unexpected disconnections using exponential backoff. 
@@ -154,8 +154,8 @@ def __init__( self, mist_session: APISession, org_id: str, - ping_interval: int = 30, - ping_timeout: int = 10, + ping_interval: int = 60, + ping_timeout: int = 45, auto_reconnect: bool = False, max_reconnect_attempts: int = 5, reconnect_backoff: float = 2.0, @@ -187,9 +187,9 @@ class MxEdgesEvents(_MistWebsocket): Authenticated API session. org_id : str UUID of the org to stream events from. - ping_interval : int, default 30 + ping_interval : int, default 60 Interval in seconds to send WebSocket ping frames (keep-alive). - ping_timeout : int, default 10 + ping_timeout : int, default 45 Time in seconds to wait for a ping response before considering the connection dead. auto_reconnect : bool, default False Automatically reconnect on unexpected disconnections using exponential backoff. @@ -234,8 +234,8 @@ def __init__( self, mist_session: APISession, org_id: str, - ping_interval: int = 30, - ping_timeout: int = 10, + ping_interval: int = 60, + ping_timeout: int = 45, auto_reconnect: bool = False, max_reconnect_attempts: int = 5, reconnect_backoff: float = 2.0, diff --git a/src/mistapi/websockets/session.py b/src/mistapi/websockets/session.py index e7365dc..f15ae93 100644 --- a/src/mistapi/websockets/session.py +++ b/src/mistapi/websockets/session.py @@ -36,9 +36,9 @@ class SessionWithUrl(_MistWebsocket): The session's authentication credentials (API token or cookies) are sent to whatever host is specified in this URL. Only use trusted URLs — never pass user-supplied or untrusted input. - ping_interval : int, default 30 + ping_interval : int, default 60 Interval in seconds to send WebSocket ping frames (keep-alive). - ping_timeout : int, default 10 + ping_timeout : int, default 45 Time in seconds to wait for a ping response before considering the connection dead. auto_reconnect : bool, default False Automatically reconnect on unexpected disconnections using exponential backoff. 
@@ -83,8 +83,8 @@ def __init__( self, mist_session: APISession, url: str, - ping_interval: int = 30, - ping_timeout: int = 10, + ping_interval: int = 60, + ping_timeout: int = 45, auto_reconnect: bool = False, max_reconnect_attempts: int = 5, reconnect_backoff: float = 2.0, diff --git a/src/mistapi/websockets/sites.py b/src/mistapi/websockets/sites.py index 64b2b88..1ba11fb 100644 --- a/src/mistapi/websockets/sites.py +++ b/src/mistapi/websockets/sites.py @@ -27,9 +27,9 @@ class ClientsStatsEvents(_MistWebsocket): Authenticated API session. site_ids : list[str] UUIDs of the sites to stream events from. - ping_interval : int, default 30 + ping_interval : int, default 60 Interval in seconds to send WebSocket ping frames (keep-alive). - ping_timeout : int, default 10 + ping_timeout : int, default 45 Time in seconds to wait for a ping response before considering the connection dead. auto_reconnect : bool, default False Automatically reconnect on unexpected disconnections using exponential backoff. @@ -74,8 +74,8 @@ def __init__( self, mist_session: APISession, site_ids: list[str], - ping_interval: int = 30, - ping_timeout: int = 10, + ping_interval: int = 60, + ping_timeout: int = 45, auto_reconnect: bool = False, max_reconnect_attempts: int = 5, reconnect_backoff: float = 2.0, @@ -116,9 +116,9 @@ class DeviceCmdEvents(_MistWebsocket): UUID of the site to stream events from. device_ids : list[str] UUIDs of the devices to stream events from. - ping_interval : int, default 30 + ping_interval : int, default 60 Interval in seconds to send WebSocket ping frames (keep-alive). - ping_timeout : int, default 10 + ping_timeout : int, default 45 Time in seconds to wait for a ping response before considering the connection dead. auto_reconnect : bool, default False Automatically reconnect on unexpected disconnections using exponential backoff. 
@@ -164,8 +164,8 @@ def __init__( mist_session: APISession, site_id: str, device_ids: list[str], - ping_interval: int = 30, - ping_timeout: int = 10, + ping_interval: int = 60, + ping_timeout: int = 45, auto_reconnect: bool = False, max_reconnect_attempts: int = 5, reconnect_backoff: float = 2.0, @@ -200,9 +200,9 @@ class DeviceStatsEvents(_MistWebsocket): Authenticated API session. site_ids : list[str] UUIDs of the sites to stream events from. - ping_interval : int, default 30 + ping_interval : int, default 60 Interval in seconds to send WebSocket ping frames (keep-alive). - ping_timeout : int, default 10 + ping_timeout : int, default 45 Time in seconds to wait for a ping response before considering the connection dead. auto_reconnect : bool, default False Automatically reconnect on unexpected disconnections using exponential backoff. @@ -247,8 +247,8 @@ def __init__( self, mist_session: APISession, site_ids: list[str], - ping_interval: int = 30, - ping_timeout: int = 10, + ping_interval: int = 60, + ping_timeout: int = 45, auto_reconnect: bool = False, max_reconnect_attempts: int = 5, reconnect_backoff: float = 2.0, @@ -281,9 +281,9 @@ class DeviceEvents(_MistWebsocket): Authenticated API session. site_ids : list[str] UUIDs of the sites to stream events from. - ping_interval : int, default 30 + ping_interval : int, default 60 Interval in seconds to send WebSocket ping frames (keep-alive). - ping_timeout : int, default 10 + ping_timeout : int, default 45 Time in seconds to wait for a ping response before considering the connection dead. auto_reconnect : bool, default False Automatically reconnect on unexpected disconnections using exponential backoff. 
@@ -328,8 +328,8 @@ def __init__( self, mist_session: APISession, site_ids: list[str], - ping_interval: int = 30, - ping_timeout: int = 10, + ping_interval: int = 60, + ping_timeout: int = 45, auto_reconnect: bool = False, max_reconnect_attempts: int = 5, reconnect_backoff: float = 2.0, @@ -362,9 +362,9 @@ class MxEdgesStatsEvents(_MistWebsocket): Authenticated API session. site_ids : list[str] UUIDs of the sites to stream events from. - ping_interval : int, default 30 + ping_interval : int, default 60 Interval in seconds to send WebSocket ping frames (keep-alive). - ping_timeout : int, default 10 + ping_timeout : int, default 45 Time in seconds to wait for a ping response before considering the connection dead. auto_reconnect : bool, default False Automatically reconnect on unexpected disconnections using exponential backoff. @@ -409,8 +409,8 @@ def __init__( self, mist_session: APISession, site_ids: list[str], - ping_interval: int = 30, - ping_timeout: int = 10, + ping_interval: int = 60, + ping_timeout: int = 45, auto_reconnect: bool = False, max_reconnect_attempts: int = 5, reconnect_backoff: float = 2.0, @@ -443,9 +443,9 @@ class MxEdgesEvents(_MistWebsocket): Authenticated API session. site_ids : list[str] UUIDs of the sites to stream events from. - ping_interval : int, default 30 + ping_interval : int, default 60 Interval in seconds to send WebSocket ping frames (keep-alive). - ping_timeout : int, default 10 + ping_timeout : int, default 45 Time in seconds to wait for a ping response before considering the connection dead. auto_reconnect : bool, default False Automatically reconnect on unexpected disconnections using exponential backoff. 
@@ -490,8 +490,8 @@ def __init__( self, mist_session: APISession, site_ids: list[str], - ping_interval: int = 30, - ping_timeout: int = 10, + ping_interval: int = 60, + ping_timeout: int = 45, auto_reconnect: bool = False, max_reconnect_attempts: int = 5, reconnect_backoff: float = 2.0, @@ -524,9 +524,9 @@ class PcapEvents(_MistWebsocket): Authenticated API session. site_id : str UUID of the site to stream events from. - ping_interval : int, default 30 + ping_interval : int, default 60 Interval in seconds to send WebSocket ping frames (keep-alive). - ping_timeout : int, default 10 + ping_timeout : int, default 45 Time in seconds to wait for a ping response before considering the connection dead. auto_reconnect : bool, default False Automatically reconnect on unexpected disconnections using exponential backoff. @@ -571,8 +571,8 @@ def __init__( self, mist_session: APISession, site_id: str, - ping_interval: int = 30, - ping_timeout: int = 10, + ping_interval: int = 60, + ping_timeout: int = 45, auto_reconnect: bool = False, max_reconnect_attempts: int = 5, reconnect_backoff: float = 2.0, diff --git a/tests/unit/test_websocket_client.py b/tests/unit/test_websocket_client.py index b59da14..b684a61 100644 --- a/tests/unit/test_websocket_client.py +++ b/tests/unit/test_websocket_client.py @@ -363,17 +363,43 @@ def test_wraps_invalid_json_in_raw_key(self, ws_client) -> None: def test_calls_on_message_callback_with_parsed_data(self, ws_client) -> None: cb = Mock() - ws_client.on_message(cb) + called = threading.Event() + + def cb_wrapper(data): + cb(data) + called.set() + + ws_client.on_message(cb_wrapper) + ws_client._finished.clear() # keep worker alive for this assertion + ws_client._start_callback_worker() payload = {"type": "event"} ws_client._handle_message(Mock(), json.dumps(payload)) + + assert called.wait(timeout=1), "callback was not invoked by worker" cb.assert_called_once_with(payload) + ws_client._callback_stop.set() + 
ws_client._callback_queue.put_nowait(None) + def test_calls_on_message_callback_with_raw_fallback(self, ws_client) -> None: cb = Mock() - ws_client.on_message(cb) + called = threading.Event() + + def cb_wrapper(data): + cb(data) + called.set() + + ws_client.on_message(cb_wrapper) + ws_client._finished.clear() # keep worker alive for this assertion + ws_client._start_callback_worker() ws_client._handle_message(Mock(), "plain text") + + assert called.wait(timeout=1), "callback was not invoked by worker" cb.assert_called_once_with({"raw": "plain text"}) + ws_client._callback_stop.set() + ws_client._callback_queue.put_nowait(None) + def test_no_error_without_on_message_callback(self, ws_client) -> None: ws_client._handle_message(Mock(), '{"ok": true}') # Should not raise @@ -461,6 +487,8 @@ def test_connect_creates_websocket_app(self, mock_ws_cls, ws_client) -> None: on_message=ws_client._handle_message, on_error=ws_client._handle_error, on_close=ws_client._handle_close, + on_ping=ws_client._handle_ping, + on_pong=ws_client._handle_pong, ) mock_ws_instance.run_forever.assert_called_once() @@ -538,8 +566,8 @@ def test_passes_sslopt_when_verify_false(self, mock_session) -> None: client._ws = mock_ws client._run_forever_safe() mock_ws.run_forever.assert_called_once_with( - ping_interval=30, - ping_timeout=10, + ping_interval=60, + ping_timeout=45, sslopt={"cert_reqs": ssl.CERT_NONE, "check_hostname": False}, ) @@ -687,8 +715,8 @@ class TestInit: """Tests for __init__ defaults.""" def test_default_ping_interval_and_timeout(self, ws_client) -> None: - assert ws_client._ping_interval == 30 - assert ws_client._ping_timeout == 10 + assert ws_client._ping_interval == 60 + assert ws_client._ping_timeout == 45 def test_custom_ping_interval_and_timeout(self, single_channel_client) -> None: assert single_channel_client._ping_interval == 15 @@ -729,6 +757,22 @@ def test_negative_queue_maxsize_raises(self, mock_session) -> None: with pytest.raises(ValueError, match="queue_maxsize 
must be >= 0"): _MistWebsocket(mock_session, channels=["/ch"], queue_maxsize=-1) + def test_ping_interval_must_be_greater_than_ping_timeout( + self, mock_session + ) -> None: + with pytest.raises(ValueError, match="ping_interval must be greater"): + _MistWebsocket( + mock_session, + channels=["/ch"], + ping_interval=10, + ping_timeout=10, + ) + + def test_channel_limit_enforced(self, mock_session) -> None: + channels = [f"/sites/{i}/stats/devices" for i in range(2001)] + with pytest.raises(ValueError, match="Too many channels"): + _MistWebsocket(mock_session, channels=channels) + def test_negative_max_reconnect_backoff_raises(self, mock_session) -> None: with pytest.raises(ValueError, match="max_reconnect_backoff must be > 0"): _MistWebsocket(mock_session, channels=["/ch"], max_reconnect_backoff=-1.0) @@ -1205,12 +1249,24 @@ class TestQueueCallbackBehavior: def test_message_callback_skips_queue(self, ws_client) -> None: cb = Mock() - ws_client.on_message(cb) + called = threading.Event() + + def cb_wrapper(data): + cb(data) + called.set() + + ws_client.on_message(cb_wrapper) + ws_client._finished.clear() # keep worker alive for this assertion + ws_client._start_callback_worker() ws_client._handle_message(Mock(), '{"event": "data"}') + assert called.wait(timeout=1), "callback was not invoked by worker" cb.assert_called_once_with({"event": "data"}) assert ws_client._queue.empty() + ws_client._callback_stop.set() + ws_client._callback_queue.put_nowait(None) + def test_no_callback_uses_queue(self, ws_client) -> None: ws_client._handle_message(Mock(), '{"event": "data"}') assert not ws_client._queue.empty() From 7840cd3b7512679c083c322ef852a8578332682e Mon Sep 17 00:00:00 2001 From: Thomas Munzer Date: Fri, 24 Apr 2026 12:36:20 +0200 Subject: [PATCH 2/3] Address PR feedback for websocket reliability changes --- README.md | 4 +- src/mistapi/websockets/__ws_client.py | 15 ++++--- src/mistapi/websockets/location.py | 30 ++++++++++++++ src/mistapi/websockets/orgs.py | 18 
+++++++++ src/mistapi/websockets/session.py | 6 +++ src/mistapi/websockets/sites.py | 42 +++++++++++++++++++ tests/unit/test_websocket_client.py | 58 ++++++++++++++++++++++++--- 7 files changed, 159 insertions(+), 14 deletions(-) diff --git a/README.md b/README.md index 8b72662..c89dfb0 100644 --- a/README.md +++ b/README.md @@ -584,11 +584,11 @@ All channel classes accept the following optional keyword arguments: | Parameter | Type | Default | Description | |-----------|------|---------|-------------| | `ping_interval` | `int` | `60` | Seconds between automatic ping frames. Set to `0` to disable pings. | -| `ping_timeout` | `int` | `45` | Seconds to wait for a pong response before treating the connection as dead. Must be lower than `ping_interval`. | +| `ping_timeout` | `int` | `45` | Seconds to wait for a pong response before treating the connection as dead. When `ping_interval > 0`, this must be lower than `ping_interval`. | | `auto_reconnect` | `bool` | `False` | Automatically reconnect on transient failures using exponential backoff. | | `max_reconnect_attempts` | `int` | `5` | Maximum number of reconnect attempts before giving up. | | `reconnect_backoff` | `float` | `2.0` | Base backoff delay in seconds. Doubles after each failed attempt (2s, 4s, 8s, ...). Resets on successful reconnection. | -| `queue_maxsize` | `int` | `0` | Maximum messages buffered in the internal queue for `receive()`. `0` means unbounded. When set, incoming messages are dropped with a warning when the queue is full, preventing memory growth on high-frequency streams. | +| `queue_maxsize` | `int` | `0` | Maximum messages buffered in the internal queues used for both `receive()` and callback delivery. `0` means unbounded. When set, incoming messages are dropped with a warning when either queue is full, preventing memory growth on high-frequency streams. 
| | `subscription_watchdog_timeout` | `float` | `10.0` | Maximum time to wait for all `channel_subscribed` acknowledgements after connect. On timeout, the connection is closed to trigger a clean reconnect. | | `rate_limit_backoff` | `float` | `30.0` | Minimum reconnect delay after a 429 rate-limit response. | | `throughput_log_interval` | `int` | `100` | Logs queue depth and processed counts every N messages. Set to `0` to disable periodic throughput logs. | diff --git a/src/mistapi/websockets/__ws_client.py b/src/mistapi/websockets/__ws_client.py index 58fabb5..1556a76 100644 --- a/src/mistapi/websockets/__ws_client.py +++ b/src/mistapi/websockets/__ws_client.py @@ -460,6 +460,8 @@ def _handle_message(self, ws: websocket.WebSocketApp, message: str | bytes) -> N data = json.loads(message) except (json.JSONDecodeError, TypeError): data = {"raw": message} + if not isinstance(data, dict): + data = {"data": data} if isinstance(data, dict): self._process_subscription_event(ws, data) @@ -473,8 +475,7 @@ def _handle_message(self, ws: websocket.WebSocketApp, message: str | bytes) -> N def _handle_error(self, _ws: websocket.WebSocketApp, error: Exception) -> None: status_code = self._extract_status_code(error) - if status_code is not None: - self._last_http_status = status_code + self._last_http_status = status_code if status_code == 429: logger.warning( "WebSocket received HTTP 429 (rate limit). " @@ -517,12 +518,14 @@ def _handle_close( ) -> None: self._connected.clear() self._cancel_subscription_watchdog() - self._last_close_code = close_status_code - self._last_close_msg = close_msg + if close_status_code is not None: + self._last_close_code = close_status_code + if close_msg not in (None, ""): + self._last_close_msg = close_msg logger.info( "WebSocket closed. 
code=%s message=%s", - close_status_code, - close_msg, + self._last_close_code, + self._last_close_msg, ) # ------------------------------------------------------------------ diff --git a/src/mistapi/websockets/location.py b/src/mistapi/websockets/location.py index 16e9622..f9bc0aa 100644 --- a/src/mistapi/websockets/location.py +++ b/src/mistapi/websockets/location.py @@ -84,6 +84,9 @@ def __init__( reconnect_backoff: float = 2.0, max_reconnect_backoff: float | None = None, queue_maxsize: int = 0, + subscription_watchdog_timeout: float = 10.0, + rate_limit_backoff: float = 30.0, + throughput_log_interval: int = 100, ) -> None: channels = [f"/sites/{site_id}/stats/maps/{mid}/assets" for mid in map_ids] super().__init__( @@ -96,6 +99,9 @@ def __init__( reconnect_backoff=reconnect_backoff, max_reconnect_backoff=max_reconnect_backoff, queue_maxsize=queue_maxsize, + subscription_watchdog_timeout=subscription_watchdog_timeout, + rate_limit_backoff=rate_limit_backoff, + throughput_log_interval=throughput_log_interval, ) @@ -168,6 +174,9 @@ def __init__( reconnect_backoff: float = 2.0, max_reconnect_backoff: float | None = None, queue_maxsize: int = 0, + subscription_watchdog_timeout: float = 10.0, + rate_limit_backoff: float = 30.0, + throughput_log_interval: int = 100, ) -> None: channels = [f"/sites/{site_id}/stats/maps/{mid}/clients" for mid in map_ids] super().__init__( @@ -180,6 +189,9 @@ def __init__( reconnect_backoff=reconnect_backoff, max_reconnect_backoff=max_reconnect_backoff, queue_maxsize=queue_maxsize, + subscription_watchdog_timeout=subscription_watchdog_timeout, + rate_limit_backoff=rate_limit_backoff, + throughput_log_interval=throughput_log_interval, ) @@ -252,6 +264,9 @@ def __init__( reconnect_backoff: float = 2.0, max_reconnect_backoff: float | None = None, queue_maxsize: int = 0, + subscription_watchdog_timeout: float = 10.0, + rate_limit_backoff: float = 30.0, + throughput_log_interval: int = 100, ) -> None: channels = 
[f"/sites/{site_id}/stats/maps/{mid}/sdkclients" for mid in map_ids] super().__init__( @@ -264,6 +279,9 @@ def __init__( reconnect_backoff=reconnect_backoff, max_reconnect_backoff=max_reconnect_backoff, queue_maxsize=queue_maxsize, + subscription_watchdog_timeout=subscription_watchdog_timeout, + rate_limit_backoff=rate_limit_backoff, + throughput_log_interval=throughput_log_interval, ) @@ -336,6 +354,9 @@ def __init__( reconnect_backoff: float = 2.0, max_reconnect_backoff: float | None = None, queue_maxsize: int = 0, + subscription_watchdog_timeout: float = 10.0, + rate_limit_backoff: float = 30.0, + throughput_log_interval: int = 100, ) -> None: channels = [ f"/sites/{site_id}/stats/maps/{mid}/unconnected_clients" for mid in map_ids @@ -350,6 +371,9 @@ def __init__( reconnect_backoff=reconnect_backoff, max_reconnect_backoff=max_reconnect_backoff, queue_maxsize=queue_maxsize, + subscription_watchdog_timeout=subscription_watchdog_timeout, + rate_limit_backoff=rate_limit_backoff, + throughput_log_interval=throughput_log_interval, ) @@ -422,6 +446,9 @@ def __init__( reconnect_backoff: float = 2.0, max_reconnect_backoff: float | None = None, queue_maxsize: int = 0, + subscription_watchdog_timeout: float = 10.0, + rate_limit_backoff: float = 30.0, + throughput_log_interval: int = 100, ) -> None: channels = [ f"/sites/{site_id}/stats/maps/{mid}/discovered_assets" for mid in map_ids @@ -436,4 +463,7 @@ def __init__( reconnect_backoff=reconnect_backoff, max_reconnect_backoff=max_reconnect_backoff, queue_maxsize=queue_maxsize, + subscription_watchdog_timeout=subscription_watchdog_timeout, + rate_limit_backoff=rate_limit_backoff, + throughput_log_interval=throughput_log_interval, ) diff --git a/src/mistapi/websockets/orgs.py b/src/mistapi/websockets/orgs.py index 546d6b2..e23a6d6 100644 --- a/src/mistapi/websockets/orgs.py +++ b/src/mistapi/websockets/orgs.py @@ -81,6 +81,9 @@ def __init__( reconnect_backoff: float = 2.0, max_reconnect_backoff: float | None = None, 
queue_maxsize: int = 0, + subscription_watchdog_timeout: float = 10.0, + rate_limit_backoff: float = 30.0, + throughput_log_interval: int = 100, ) -> None: super().__init__( mist_session, @@ -92,6 +95,9 @@ def __init__( reconnect_backoff=reconnect_backoff, max_reconnect_backoff=max_reconnect_backoff, queue_maxsize=queue_maxsize, + subscription_watchdog_timeout=subscription_watchdog_timeout, + rate_limit_backoff=rate_limit_backoff, + throughput_log_interval=throughput_log_interval, ) @@ -161,6 +167,9 @@ def __init__( reconnect_backoff: float = 2.0, max_reconnect_backoff: float | None = None, queue_maxsize: int = 0, + subscription_watchdog_timeout: float = 10.0, + rate_limit_backoff: float = 30.0, + throughput_log_interval: int = 100, ) -> None: super().__init__( mist_session, @@ -172,6 +181,9 @@ def __init__( reconnect_backoff=reconnect_backoff, max_reconnect_backoff=max_reconnect_backoff, queue_maxsize=queue_maxsize, + subscription_watchdog_timeout=subscription_watchdog_timeout, + rate_limit_backoff=rate_limit_backoff, + throughput_log_interval=throughput_log_interval, ) @@ -241,6 +253,9 @@ def __init__( reconnect_backoff: float = 2.0, max_reconnect_backoff: float | None = None, queue_maxsize: int = 0, + subscription_watchdog_timeout: float = 10.0, + rate_limit_backoff: float = 30.0, + throughput_log_interval: int = 100, ) -> None: super().__init__( mist_session, @@ -252,4 +267,7 @@ def __init__( reconnect_backoff=reconnect_backoff, max_reconnect_backoff=max_reconnect_backoff, queue_maxsize=queue_maxsize, + subscription_watchdog_timeout=subscription_watchdog_timeout, + rate_limit_backoff=rate_limit_backoff, + throughput_log_interval=throughput_log_interval, ) diff --git a/src/mistapi/websockets/session.py b/src/mistapi/websockets/session.py index f15ae93..29f7b16 100644 --- a/src/mistapi/websockets/session.py +++ b/src/mistapi/websockets/session.py @@ -90,6 +90,9 @@ def __init__( reconnect_backoff: float = 2.0, max_reconnect_backoff: float | None = None, 
queue_maxsize: int = 0, + subscription_watchdog_timeout: float = 10.0, + rate_limit_backoff: float = 30.0, + throughput_log_interval: int = 100, ) -> None: parsed = urlparse(url) if parsed.scheme.lower() != "wss" or not parsed.netloc: @@ -105,6 +108,9 @@ def __init__( reconnect_backoff=reconnect_backoff, max_reconnect_backoff=max_reconnect_backoff, queue_maxsize=queue_maxsize, + subscription_watchdog_timeout=subscription_watchdog_timeout, + rate_limit_backoff=rate_limit_backoff, + throughput_log_interval=throughput_log_interval, ) def _build_ws_url(self) -> str: diff --git a/src/mistapi/websockets/sites.py b/src/mistapi/websockets/sites.py index 1ba11fb..04d902c 100644 --- a/src/mistapi/websockets/sites.py +++ b/src/mistapi/websockets/sites.py @@ -81,6 +81,9 @@ def __init__( reconnect_backoff: float = 2.0, max_reconnect_backoff: float | None = None, queue_maxsize: int = 0, + subscription_watchdog_timeout: float = 10.0, + rate_limit_backoff: float = 30.0, + throughput_log_interval: int = 100, ) -> None: channels = [f"/sites/{site_id}/stats/clients" for site_id in site_ids] super().__init__( @@ -93,6 +96,9 @@ def __init__( reconnect_backoff=reconnect_backoff, max_reconnect_backoff=max_reconnect_backoff, queue_maxsize=queue_maxsize, + subscription_watchdog_timeout=subscription_watchdog_timeout, + rate_limit_backoff=rate_limit_backoff, + throughput_log_interval=throughput_log_interval, ) @@ -171,6 +177,9 @@ def __init__( reconnect_backoff: float = 2.0, max_reconnect_backoff: float | None = None, queue_maxsize: int = 0, + subscription_watchdog_timeout: float = 10.0, + rate_limit_backoff: float = 30.0, + throughput_log_interval: int = 100, ) -> None: channels = [ f"/sites/{site_id}/devices/{device_id}/cmd" for device_id in device_ids @@ -185,6 +194,9 @@ def __init__( reconnect_backoff=reconnect_backoff, max_reconnect_backoff=max_reconnect_backoff, queue_maxsize=queue_maxsize, + subscription_watchdog_timeout=subscription_watchdog_timeout, + 
rate_limit_backoff=rate_limit_backoff, + throughput_log_interval=throughput_log_interval, ) @@ -254,6 +266,9 @@ def __init__( reconnect_backoff: float = 2.0, max_reconnect_backoff: float | None = None, queue_maxsize: int = 0, + subscription_watchdog_timeout: float = 10.0, + rate_limit_backoff: float = 30.0, + throughput_log_interval: int = 100, ) -> None: channels = [f"/sites/{site_id}/stats/devices" for site_id in site_ids] super().__init__( @@ -266,6 +281,9 @@ def __init__( reconnect_backoff=reconnect_backoff, max_reconnect_backoff=max_reconnect_backoff, queue_maxsize=queue_maxsize, + subscription_watchdog_timeout=subscription_watchdog_timeout, + rate_limit_backoff=rate_limit_backoff, + throughput_log_interval=throughput_log_interval, ) @@ -335,6 +353,9 @@ def __init__( reconnect_backoff: float = 2.0, max_reconnect_backoff: float | None = None, queue_maxsize: int = 0, + subscription_watchdog_timeout: float = 10.0, + rate_limit_backoff: float = 30.0, + throughput_log_interval: int = 100, ) -> None: channels = [f"/sites/{site_id}/devices" for site_id in site_ids] super().__init__( @@ -347,6 +368,9 @@ def __init__( reconnect_backoff=reconnect_backoff, max_reconnect_backoff=max_reconnect_backoff, queue_maxsize=queue_maxsize, + subscription_watchdog_timeout=subscription_watchdog_timeout, + rate_limit_backoff=rate_limit_backoff, + throughput_log_interval=throughput_log_interval, ) @@ -416,6 +440,9 @@ def __init__( reconnect_backoff: float = 2.0, max_reconnect_backoff: float | None = None, queue_maxsize: int = 0, + subscription_watchdog_timeout: float = 10.0, + rate_limit_backoff: float = 30.0, + throughput_log_interval: int = 100, ) -> None: channels = [f"/sites/{site_id}/stats/mxedges" for site_id in site_ids] super().__init__( @@ -428,6 +455,9 @@ def __init__( reconnect_backoff=reconnect_backoff, max_reconnect_backoff=max_reconnect_backoff, queue_maxsize=queue_maxsize, + subscription_watchdog_timeout=subscription_watchdog_timeout, + 
rate_limit_backoff=rate_limit_backoff, + throughput_log_interval=throughput_log_interval, ) @@ -497,6 +527,9 @@ def __init__( reconnect_backoff: float = 2.0, max_reconnect_backoff: float | None = None, queue_maxsize: int = 0, + subscription_watchdog_timeout: float = 10.0, + rate_limit_backoff: float = 30.0, + throughput_log_interval: int = 100, ) -> None: channels = [f"/sites/{site_id}/mxedges" for site_id in site_ids] super().__init__( @@ -509,6 +542,9 @@ def __init__( reconnect_backoff=reconnect_backoff, max_reconnect_backoff=max_reconnect_backoff, queue_maxsize=queue_maxsize, + subscription_watchdog_timeout=subscription_watchdog_timeout, + rate_limit_backoff=rate_limit_backoff, + throughput_log_interval=throughput_log_interval, ) @@ -578,6 +614,9 @@ def __init__( reconnect_backoff: float = 2.0, max_reconnect_backoff: float | None = None, queue_maxsize: int = 0, + subscription_watchdog_timeout: float = 10.0, + rate_limit_backoff: float = 30.0, + throughput_log_interval: int = 100, ) -> None: channels = [f"/sites/{site_id}/pcaps"] super().__init__( @@ -590,4 +629,7 @@ def __init__( reconnect_backoff=reconnect_backoff, max_reconnect_backoff=max_reconnect_backoff, queue_maxsize=queue_maxsize, + subscription_watchdog_timeout=subscription_watchdog_timeout, + rate_limit_backoff=rate_limit_backoff, + throughput_log_interval=throughput_log_interval, ) diff --git a/tests/unit/test_websocket_client.py b/tests/unit/test_websocket_client.py index b684a61..e2eb4e7 100644 --- a/tests/unit/test_websocket_client.py +++ b/tests/unit/test_websocket_client.py @@ -378,8 +378,7 @@ def cb_wrapper(data): assert called.wait(timeout=1), "callback was not invoked by worker" cb.assert_called_once_with(payload) - ws_client._callback_stop.set() - ws_client._callback_queue.put_nowait(None) + ws_client.disconnect(wait=True, timeout=1) def test_calls_on_message_callback_with_raw_fallback(self, ws_client) -> None: cb = Mock() @@ -397,8 +396,7 @@ def cb_wrapper(data): assert 
called.wait(timeout=1), "callback was not invoked by worker" cb.assert_called_once_with({"raw": "plain text"}) - ws_client._callback_stop.set() - ws_client._callback_queue.put_nowait(None) + ws_client.disconnect(wait=True, timeout=1) def test_no_error_without_on_message_callback(self, ws_client) -> None: ws_client._handle_message(Mock(), '{"ok": true}') # Should not raise @@ -829,6 +827,18 @@ def test_inherits_from_mist_websocket(self, mock_session) -> None: ws = DeviceCmdEvents(mock_session, site_id="s1", device_ids=["d1"]) assert isinstance(ws, _MistWebsocket) + def test_supports_reliability_kwargs(self, mock_session) -> None: + ws = DeviceStatsEvents( + mock_session, + site_ids=["s1"], + subscription_watchdog_timeout=3.0, + rate_limit_backoff=12.0, + throughput_log_interval=250, + ) + assert ws._subscription_watchdog_timeout == 3.0 + assert ws._rate_limit_backoff == 12.0 + assert ws._throughput_log_interval == 250 + class TestOrgChannels: """Tests for public org-level WebSocket channel classes.""" @@ -849,6 +859,18 @@ def test_inherits_from_mist_websocket(self, mock_session) -> None: ws = InsightsEvents(mock_session, org_id="o1") assert isinstance(ws, _MistWebsocket) + def test_supports_reliability_kwargs(self, mock_session) -> None: + ws = InsightsEvents( + mock_session, + org_id="o1", + subscription_watchdog_timeout=5.0, + rate_limit_backoff=20.0, + throughput_log_interval=400, + ) + assert ws._subscription_watchdog_timeout == 5.0 + assert ws._rate_limit_backoff == 20.0 + assert ws._throughput_log_interval == 400 + class TestLocationChannels: """Tests for public location-level WebSocket channel classes.""" @@ -880,6 +902,19 @@ def test_inherits_from_mist_websocket(self, mock_session) -> None: ws = BleAssetsEvents(mock_session, site_id="s1", map_ids=["m1"]) assert isinstance(ws, _MistWebsocket) + def test_supports_reliability_kwargs(self, mock_session) -> None: + ws = ConnectedClientsEvents( + mock_session, + site_id="s1", + map_ids=["m1"], + 
subscription_watchdog_timeout=4.0, + rate_limit_backoff=18.0, + throughput_log_interval=150, + ) + assert ws._subscription_watchdog_timeout == 4.0 + assert ws._rate_limit_backoff == 18.0 + assert ws._throughput_log_interval == 150 + class TestSessionChannel: """Tests for the SessionWithUrl WebSocket channel class.""" @@ -893,6 +928,18 @@ def test_inherits_from_mist_websocket(self, mock_session) -> None: ws = SessionWithUrl(mock_session, url="wss://example.com/custom") assert isinstance(ws, _MistWebsocket) + def test_supports_reliability_kwargs(self, mock_session) -> None: + ws = SessionWithUrl( + mock_session, + url="wss://example.com/custom", + subscription_watchdog_timeout=6.0, + rate_limit_backoff=25.0, + throughput_log_interval=300, + ) + assert ws._subscription_watchdog_timeout == 6.0 + assert ws._rate_limit_backoff == 25.0 + assert ws._throughput_log_interval == 300 + # --------------------------------------------------------------------------- # Auto-reconnect @@ -1264,8 +1311,7 @@ def cb_wrapper(data): cb.assert_called_once_with({"event": "data"}) assert ws_client._queue.empty() - ws_client._callback_stop.set() - ws_client._callback_queue.put_nowait(None) + ws_client.disconnect(wait=True, timeout=1) def test_no_callback_uses_queue(self, ws_client) -> None: ws_client._handle_message(Mock(), '{"event": "data"}') From 7ce8a8d7a7d9e68b2a08e3747ff9caee55327587 Mon Sep 17 00:00:00 2001 From: Thomas Munzer Date: Fri, 24 Apr 2026 12:37:46 +0200 Subject: [PATCH 3/3] Tune websocket logging and metrics thread safety --- src/mistapi/websockets/__ws_client.py | 43 +++++++++++++++++++-------- 1 file changed, 31 insertions(+), 12 deletions(-) diff --git a/src/mistapi/websockets/__ws_client.py b/src/mistapi/websockets/__ws_client.py index 1556a76..9b37c00 100644 --- a/src/mistapi/websockets/__ws_client.py +++ b/src/mistapi/websockets/__ws_client.py @@ -135,6 +135,7 @@ def __init__( self._throughput_log_interval = throughput_log_interval self._lock = threading.Lock() 
self._subscription_lock = threading.Lock() + self._metrics_lock = threading.Lock() self._ws: websocket.WebSocketApp | None = None self._thread: threading.Thread | None = None self._callback_thread: threading.Thread | None = None @@ -302,17 +303,20 @@ def _run_callback_worker(self) -> None: callback(item) except Exception: logger.exception("on_message callback raised") - self._messages_processed += 1 + with self._metrics_lock: + self._messages_processed += 1 + messages_processed = self._messages_processed + messages_dropped = self._messages_dropped if ( self._throughput_log_interval - and self._messages_processed % self._throughput_log_interval == 0 + and messages_processed % self._throughput_log_interval == 0 ): logger.info( "WebSocket callback worker processed %d messages. " "Callback queue size=%d dropped=%d", - self._messages_processed, + messages_processed, self._callback_queue.qsize(), - self._messages_dropped, + messages_dropped, ) def _cancel_subscription_watchdog(self) -> None: @@ -374,12 +378,22 @@ def _process_subscription_event( self._subscribed_channels.add(channel) subscribed_count = len(self._subscribed_channels) expected_count = len(self._expected_channels) - logger.info( + logger.debug( "Channel subscribed (%d/%d): %s", subscribed_count, expected_count, channel, ) + if expected_count and ( + subscribed_count == 1 + or subscribed_count % 100 == 0 + or subscribed_count >= expected_count + ): + logger.info( + "Subscription progress: received %d/%d channel acknowledgements", + subscribed_count, + expected_count, + ) if channel not in self._expected_channels: logger.warning( "Received channel_subscribed for unexpected channel: %s", channel @@ -404,23 +418,28 @@ def _process_subscription_event( def _enqueue_message(self, message: dict, to_callback_queue: bool) -> None: target_queue = self._callback_queue if to_callback_queue else self._queue queue_name = "callback" if to_callback_queue else "receive" - self._messages_received += 1 + with 
self._metrics_lock: + self._messages_received += 1 + messages_received = self._messages_received try: target_queue.put_nowait(message) except queue.Full: - self._messages_dropped += 1 + with self._metrics_lock: + self._messages_dropped += 1 logger.warning("%s queue full; dropping message", queue_name.capitalize()) return if ( self._throughput_log_interval - and self._messages_received % self._throughput_log_interval == 0 + and messages_received % self._throughput_log_interval == 0 ): + with self._metrics_lock: + messages_dropped = self._messages_dropped logger.info( "WebSocket received %d messages. %s queue size=%d dropped=%d", - self._messages_received, + messages_received, queue_name.capitalize(), target_queue.qsize(), - self._messages_dropped, + messages_dropped, ) # ------------------------------------------------------------------ @@ -493,7 +512,7 @@ def _handle_error(self, _ws: websocket.WebSocketApp, error: Exception) -> None: def _handle_ping( self, _ws: websocket.WebSocketApp, message: str | bytes | None ) -> None: - logger.info("WebSocket ping received") + logger.debug("WebSocket ping received") if self._on_ping_cb: try: self._on_ping_cb(message) @@ -503,7 +522,7 @@ def _handle_ping( def _handle_pong( self, _ws: websocket.WebSocketApp, message: str | bytes | None ) -> None: - logger.info("WebSocket pong received") + logger.debug("WebSocket pong received") if self._on_pong_cb: try: self._on_pong_cb(message)