Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .cspell/custom-dictionary-workspace.txt
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ Codespaces
collapsable
compareform
configform
connack
Consolas
coro
cprofile
Expand Down Expand Up @@ -100,6 +101,7 @@ dstart
dwindow
Eddi
elif
emqx
emszzzz
enctype
endfor
Expand Down
95 changes: 72 additions & 23 deletions apps/predbat/gateway.py
Original file line number Diff line number Diff line change
Expand Up @@ -429,6 +429,11 @@ async def _mqtt_loop(self):
if self.api_stop:
break

# If the broker rejected our credentials (e.g. the MQTT JWT expired),
# refresh the token before reconnecting — otherwise we would retry the
# same rejected token forever and lose all gateway control.
await self._maybe_refresh_on_auth_error(e)

self.log(f"Info: GatewayMQTT: Reconnecting in {backoff}s")
await asyncio.sleep(backoff)
backoff = min(backoff * 2, max_backoff)
Expand Down Expand Up @@ -1306,11 +1311,11 @@ async def final(self):
self.log("Info: GatewayMQTT: Finalized")

async def _check_token_refresh(self):
"""Check if the MQTT JWT token needs refreshing and refresh if needed.
"""Proactively refresh the MQTT JWT when it nears expiry (housekeeping path).

Uses the oauth-refresh edge function (same pattern as OAuthMixin) to
obtain a new access token before the current one expires. The refresh
token is held server-side in instance secrets.
Runs every cycle from run(); only triggers a refresh once the token is
within token_needs_refresh() of expiry. The reconnect loop additionally
forces a refresh on broker auth-failure via _maybe_refresh_on_auth_error().
"""
if not HAS_AIOHTTP:
return
Expand All @@ -1326,14 +1331,69 @@ async def _check_token_refresh(self):
self.log("Warn: GatewayMQTT: could not extract JWT expiry, token refresh disabled")
return

if self.mqtt_token_expires_at and self.mqtt_token_expires_at > 0 and not self.token_needs_refresh(self.mqtt_token_expires_at):
if self.mqtt_token_expires_at == -1:
return

if self.mqtt_token_expires_at == -1:
if self.mqtt_token_expires_at and self.mqtt_token_expires_at > 0 and not self.token_needs_refresh(self.mqtt_token_expires_at):
return

await self._do_token_refresh()

@staticmethod
def _is_auth_failure(error):
"""True if a broker error means our MQTT credentials were rejected.

Matches MQTT CONNACK auth reason codes 134 (bad user name or password) and
135 (not authorized) plus their text. An expired MQTT JWT shows up as code
134, so the reconnect loop uses this to decide it must refresh the token
rather than retry the same rejected one forever.
"""
text = str(error).lower()
return "bad user name or password" in text or "not authorized" in text or "not authorised" in text or "unauthorized" in text or "code:134" in text or "code:135" in text
Comment thread
springfall2008 marked this conversation as resolved.

async def _maybe_refresh_on_auth_error(self, error):
"""Force an MQTT token refresh when the broker rejected authentication.

Returns True if a refresh was attempted and succeeded. Non-auth errors
(network drops, "Disconnected during message iteration") are ignored so we
don't hammer the refresh endpoint for problems a new token cannot fix.
"""
if not self._is_auth_failure(error):
return False
self.log("Info: GatewayMQTT: Broker rejected auth — refreshing MQTT token before reconnect")
return await self._do_token_refresh()

def _apply_refresh_response(self, data):
"""Apply an oauth-refresh JSON reply to the in-memory token. Returns success."""
if not data.get("success"):
self.log(f"Warn: GatewayMQTT: Token refresh failed: {data.get('error', 'unknown')}")
return False

self.mqtt_token = data["access_token"]
if data.get("expires_at"):
try:
if isinstance(data["expires_at"], (int, float)):
self.mqtt_token_expires_at = float(data["expires_at"])
else:
dt = datetime.datetime.fromisoformat(data["expires_at"].replace("Z", "+00:00"))
self.mqtt_token_expires_at = dt.timestamp()
except (ValueError, AttributeError):
self.mqtt_token_expires_at = 0
self.log("Info: GatewayMQTT: MQTT token refreshed")
return True

async def _do_token_refresh(self):
"""Refresh the MQTT JWT via the oauth-refresh edge function, unconditionally.

Shared by the proactive near-expiry check and the auth-failure reconnect
path. The refresh token is held server-side in instance secrets. Returns
True if a new access token was obtained and applied.
"""
if not HAS_AIOHTTP:
return False

if self._refresh_in_progress:
return
return False

self._refresh_in_progress = True
try:
Expand All @@ -1343,7 +1403,7 @@ async def _check_token_refresh(self):

if not supabase_url or not supabase_key or not instance_id:
self.log("Warn: GatewayMQTT: Token refresh skipped — missing env vars or instance_id")
return
return False

url = f"{supabase_url}/functions/v1/oauth-refresh"
headers = {
Expand All @@ -1362,29 +1422,18 @@ async def _check_token_refresh(self):
async with session.post(url, headers=headers, json=payload) as response:
if response.status != 200:
self.log(f"Warn: GatewayMQTT: Token refresh HTTP {response.status}")
return
return False

data = await response.json()

if data.get("success"):
self.mqtt_token = data["access_token"]
if data.get("expires_at"):
try:
if isinstance(data["expires_at"], (int, float)):
self.mqtt_token_expires_at = float(data["expires_at"])
else:
dt = datetime.datetime.fromisoformat(data["expires_at"].replace("Z", "+00:00"))
self.mqtt_token_expires_at = dt.timestamp()
except (ValueError, AttributeError):
self.mqtt_token_expires_at = 0
self.log("Info: GatewayMQTT: MQTT token refreshed")
else:
self.log(f"Warn: GatewayMQTT: Token refresh failed: {data.get('error', 'unknown')}")
return self._apply_refresh_response(data)

except (aiohttp.ClientError, asyncio.TimeoutError) as e:
self.log(f"Warn: GatewayMQTT: Token refresh network error: {e}")
return False
except Exception as e:
self.log(f"Warn: GatewayMQTT: Token refresh error: {e}")
return False
finally:
self._refresh_in_progress = False

Expand Down
5 changes: 5 additions & 0 deletions apps/predbat/tests/test_gateway.py
Original file line number Diff line number Diff line change
Expand Up @@ -2872,6 +2872,8 @@ async def fake_mqtt_loop():

def run_gateway_tests(my_predbat=None):
"""Run all GatewayMQTT tests. Returns True on failure, False on success."""
from tests.test_gateway_token_refresh import TestIsAuthFailure, TestApplyRefreshResponse, TestMaybeRefreshOnAuthError

test_classes = [
TestProtobufDecode,
TestPlanSerialization,
Expand All @@ -2889,6 +2891,9 @@ def run_gateway_tests(my_predbat=None):
TestIanaToPosixTz,
TestCheckInverterResets,
TestRunStartupWait,
TestIsAuthFailure,
TestApplyRefreshResponse,
TestMaybeRefreshOnAuthError,
]
for cls in test_classes:
instance = cls()
Expand Down
117 changes: 117 additions & 0 deletions apps/predbat/tests/test_gateway_token_refresh.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
"""Tests for GatewayMQTT MQTT-token refresh on broker auth-failure.

Regression cover for the 2026-06-16 fleet incident: the gateway MQTT JWT has a
24h TTL; when it expired, _mqtt_loop reconnected forever with the same rejected
token (EMQX CONNACK code 134 "Bad user name or password") and never refreshed,
so PredBat lost all gateway control until the pod/config was rebuilt. The fix
makes the reconnect loop force a token refresh when the broker rejects auth.
"""
Comment thread
springfall2008 marked this conversation as resolved.
import asyncio
from unittest.mock import MagicMock, AsyncMock


def _bare_gateway():
"""A GatewayMQTT instance with no __init__ side effects, log mocked."""
from gateway import GatewayMQTT

gw = GatewayMQTT.__new__(GatewayMQTT)
gw.log = MagicMock()
return gw


class TestIsAuthFailure:
"""_is_auth_failure() distinguishes broker auth rejections from other drops."""

def test_code_134_bad_credentials_is_auth_failure(self):
from gateway import GatewayMQTT

assert GatewayMQTT._is_auth_failure("[code:134] Bad user name or password") is True

def test_code_135_not_authorized_is_auth_failure(self):
from gateway import GatewayMQTT

assert GatewayMQTT._is_auth_failure("[code:135] Not authorized") is True

def test_not_authorised_british_spelling_is_auth_failure(self):
from gateway import GatewayMQTT

assert GatewayMQTT._is_auth_failure("Not authorised") is True

def test_accepts_exception_object_not_just_string(self):
from gateway import GatewayMQTT

err = Exception("[code:134] Bad user name or password")
assert GatewayMQTT._is_auth_failure(err) is True

def test_message_iteration_drop_is_not_auth_failure(self):
from gateway import GatewayMQTT

assert GatewayMQTT._is_auth_failure("Disconnected during message iteration") is False

def test_network_refused_is_not_auth_failure(self):
from gateway import GatewayMQTT

assert GatewayMQTT._is_auth_failure("[Errno 111] Connection refused") is False

def test_empty_is_not_auth_failure(self):
from gateway import GatewayMQTT

assert GatewayMQTT._is_auth_failure("") is False


class TestApplyRefreshResponse:
"""_apply_refresh_response() updates the in-memory token from an oauth-refresh reply."""

def test_success_updates_token_and_epoch_expiry(self):
gw = _bare_gateway()
gw.mqtt_token = "old.jwt.token"
gw.mqtt_token_expires_at = 1.0

ok = gw._apply_refresh_response({"success": True, "access_token": "new.jwt.token", "expires_at": 1781700000})

assert ok is True
assert gw.mqtt_token == "new.jwt.token"
assert gw.mqtt_token_expires_at == 1781700000.0

def test_success_parses_iso_expiry(self):
gw = _bare_gateway()
gw.mqtt_token = "old.jwt.token"
gw.mqtt_token_expires_at = 1.0

ok = gw._apply_refresh_response({"success": True, "access_token": "new.jwt.token", "expires_at": "2026-06-17T00:00:00Z"})

assert ok is True
assert gw.mqtt_token == "new.jwt.token"
assert gw.mqtt_token_expires_at > 0

def test_failure_leaves_token_unchanged(self):
gw = _bare_gateway()
gw.mqtt_token = "old.jwt.token"
gw.mqtt_token_expires_at = 1.0

ok = gw._apply_refresh_response({"success": False, "error": "needs_reauth"})

assert ok is False
assert gw.mqtt_token == "old.jwt.token"


class TestMaybeRefreshOnAuthError:
"""_maybe_refresh_on_auth_error() forces a refresh only for auth failures."""

def test_auth_failure_triggers_refresh(self):
gw = _bare_gateway()
gw._do_token_refresh = AsyncMock(return_value=True)

result = asyncio.run(gw._maybe_refresh_on_auth_error("[code:134] Bad user name or password"))

gw._do_token_refresh.assert_awaited_once()
assert result is True

def test_non_auth_failure_does_not_refresh(self):
gw = _bare_gateway()
gw._do_token_refresh = AsyncMock(return_value=True)

result = asyncio.run(gw._maybe_refresh_on_auth_error("Disconnected during message iteration"))

gw._do_token_refresh.assert_not_awaited()
assert result is False
Loading