diff --git a/.coverage_full b/.coverage_full
new file mode 100644
index 0000000..6aa8830
Binary files /dev/null and b/.coverage_full differ
diff --git a/.secrets.baseline b/.secrets.baseline
index 1beb0ff..9805f36 100644
--- a/.secrets.baseline
+++ b/.secrets.baseline
@@ -268,28 +268,28 @@
"filename": "src/api/hive_auth_async.py",
"hashed_secret": "5dc786e32e3a0a4611daaf397721c6ef64cd71b0",
"is_verified": false,
- "line_number": 48
+ "line_number": 49
},
{
"type": "Secret Keyword",
"filename": "src/api/hive_auth_async.py",
"hashed_secret": "ac9f290e69cee683ba3c63461f1f3fa02765032a",
"is_verified": false,
- "line_number": 49
+ "line_number": 50
},
{
"type": "Secret Keyword",
"filename": "src/api/hive_auth_async.py",
"hashed_secret": "351b174ccf89601f6f4bd3f3970a4aba7d17c98e",
"is_verified": false,
- "line_number": 52
+ "line_number": 53
},
{
"type": "Secret Keyword",
"filename": "src/api/hive_auth_async.py",
"hashed_secret": "576956b5291ac38d04ef5f82cc974286a857f0b2",
"is_verified": false,
- "line_number": 109
+ "line_number": 112
}
],
"src/api/srp_crypto.py": [
@@ -298,112 +298,112 @@
"filename": "src/api/srp_crypto.py",
"hashed_secret": "3e619ee0820ecf213c2f38c634e416b53defe3b0",
"is_verified": false,
- "line_number": 11
+ "line_number": 10
},
{
"type": "Hex High Entropy String",
"filename": "src/api/srp_crypto.py",
"hashed_secret": "b8e0d506d969f09a9af89ce89fd9759b72c63262",
"is_verified": false,
- "line_number": 12
+ "line_number": 11
},
{
"type": "Hex High Entropy String",
"filename": "src/api/srp_crypto.py",
"hashed_secret": "e97a751edc71e9afbe0c0f63ec94873392833f9f",
"is_verified": false,
- "line_number": 13
+ "line_number": 12
},
{
"type": "Hex High Entropy String",
"filename": "src/api/srp_crypto.py",
"hashed_secret": "92488c021dd524a2f4e116666b3645308fa0e35c",
"is_verified": false,
- "line_number": 14
+ "line_number": 13
},
{
"type": "Hex High Entropy String",
"filename": "src/api/srp_crypto.py",
"hashed_secret": "d4571e2f026f458aecd2950b0eb6aec190276177",
"is_verified": false,
- "line_number": 15
+ "line_number": 14
},
{
"type": "Hex High Entropy String",
"filename": "src/api/srp_crypto.py",
"hashed_secret": "8109d3c2f659f13cb61fc9e71eed574efe8c8fd8",
"is_verified": false,
- "line_number": 16
+ "line_number": 15
},
{
"type": "Hex High Entropy String",
"filename": "src/api/srp_crypto.py",
"hashed_secret": "08cac7461d7b624b88c53ee47da09cbbb84ea290",
"is_verified": false,
- "line_number": 17
+ "line_number": 16
},
{
"type": "Hex High Entropy String",
"filename": "src/api/srp_crypto.py",
"hashed_secret": "95523fea7e6136c6148299dcc3077debfa2976b3",
"is_verified": false,
- "line_number": 18
+ "line_number": 17
},
{
"type": "Hex High Entropy String",
"filename": "src/api/srp_crypto.py",
"hashed_secret": "c978fb77621e86f5e9077653fe5345ac1616b466",
"is_verified": false,
- "line_number": 19
+ "line_number": 18
},
{
"type": "Hex High Entropy String",
"filename": "src/api/srp_crypto.py",
"hashed_secret": "fc02990268ecf8a35a4912d60dab3754e5f43846",
"is_verified": false,
- "line_number": 20
+ "line_number": 19
},
{
"type": "Hex High Entropy String",
"filename": "src/api/srp_crypto.py",
"hashed_secret": "2c2c0ca491a73e95c8965b6641731057b65f6462",
"is_verified": false,
- "line_number": 21
+ "line_number": 20
},
{
"type": "Hex High Entropy String",
"filename": "src/api/srp_crypto.py",
"hashed_secret": "672b25c6be065170206f3fc6346ebb8e84cbb9d3",
"is_verified": false,
- "line_number": 22
+ "line_number": 21
},
{
"type": "Hex High Entropy String",
"filename": "src/api/srp_crypto.py",
"hashed_secret": "99d02e268ea3ee849fb6e359c6c1b019e4d07efd",
"is_verified": false,
- "line_number": 23
+ "line_number": 22
},
{
"type": "Hex High Entropy String",
"filename": "src/api/srp_crypto.py",
"hashed_secret": "e677fc4cb09d99e1e0d30af31f2e209e541e380e",
"is_verified": false,
- "line_number": 24
+ "line_number": 23
},
{
"type": "Hex High Entropy String",
"filename": "src/api/srp_crypto.py",
"hashed_secret": "05b69b06f40cae0c910a15b1ac75b1f7a847eccb",
"is_verified": false,
- "line_number": 25
+ "line_number": 24
},
{
"type": "Hex High Entropy String",
"filename": "src/api/srp_crypto.py",
"hashed_secret": "c7f914bac2d66eb3f8ae3888fa47bf1ada6caaf5",
"is_verified": false,
- "line_number": 26
+ "line_number": 25
}
],
"tests/unit/test_device_registration.py": [
@@ -419,21 +419,21 @@
"filename": "tests/unit/test_device_registration.py",
"hashed_secret": "d8bce9746547bb7743e5933fbf0fc4f2d2cbcad3",
"is_verified": false,
- "line_number": 710
+ "line_number": 659
},
{
"type": "Secret Keyword",
"filename": "tests/unit/test_device_registration.py",
"hashed_secret": "e4f50034475acff058e17b35679f8ef1e54f86c5",
"is_verified": false,
- "line_number": 783
+ "line_number": 727
},
{
"type": "Secret Keyword",
"filename": "tests/unit/test_device_registration.py",
"hashed_secret": "6ab013c213c685b1f1b1a452796bf22afbd44699",
"is_verified": false,
- "line_number": 794
+ "line_number": 737
}
],
"tests/unit/test_hive_auth.py": [
@@ -486,14 +486,14 @@
"filename": "tests/unit/test_hive_auth_async.py",
"hashed_secret": "5c5a15a8b0b3e154d77746945e563ba40100681b",
"is_verified": false,
- "line_number": 150
+ "line_number": 173
},
{
"type": "Secret Keyword",
"filename": "tests/unit/test_hive_auth_async.py",
"hashed_secret": "d8bce9746547bb7743e5933fbf0fc4f2d2cbcad3",
"is_verified": false,
- "line_number": 206
+ "line_number": 272
}
],
"tests/unit/test_hive_auth_async_extended.py": [
@@ -502,49 +502,49 @@
"filename": "tests/unit/test_hive_auth_async_extended.py",
"hashed_secret": "5c5a15a8b0b3e154d77746945e563ba40100681b",
"is_verified": false,
- "line_number": 259
+ "line_number": 260
},
{
"type": "Secret Keyword",
"filename": "tests/unit/test_hive_auth_async_extended.py",
"hashed_secret": "d8bce9746547bb7743e5933fbf0fc4f2d2cbcad3",
"is_verified": false,
- "line_number": 340
+ "line_number": 341
},
{
"type": "Secret Keyword",
"filename": "tests/unit/test_hive_auth_async_extended.py",
"hashed_secret": "76f6b6f16cb41692b330fc806029e8a31e20b69b",
"is_verified": false,
- "line_number": 815
+ "line_number": 813
},
{
"type": "Secret Keyword",
"filename": "tests/unit/test_hive_auth_async_extended.py",
"hashed_secret": "b3ed2cf313e7546085c3c50622143ff31e467d23",
"is_verified": false,
- "line_number": 834
+ "line_number": 832
},
{
"type": "Secret Keyword",
"filename": "tests/unit/test_hive_auth_async_extended.py",
"hashed_secret": "7476b69b5005e05d536361f960a9d18b736dfbfc",
"is_verified": false,
- "line_number": 848
+ "line_number": 846
},
{
"type": "Secret Keyword",
"filename": "tests/unit/test_hive_auth_async_extended.py",
"hashed_secret": "ff9f30d9ba5a4ec386edddeacc27f74ef412085e",
"is_verified": false,
- "line_number": 855
+ "line_number": 853
},
{
"type": "Secret Keyword",
"filename": "tests/unit/test_hive_auth_async_extended.py",
"hashed_secret": "a8ad0732120b9dfed5b99fd6a2aca4fc8ba48d80",
"is_verified": false,
- "line_number": 893
+ "line_number": 890
}
],
"tests/unit/test_hive_helper_extended.py": [
@@ -553,21 +553,21 @@
"filename": "tests/unit/test_hive_helper_extended.py",
"hashed_secret": "701b389b848a2b1cfab867093101d8d5ac56addd",
"is_verified": false,
- "line_number": 134
+ "line_number": 102
},
{
"type": "Secret Keyword",
"filename": "tests/unit/test_hive_helper_extended.py",
"hashed_secret": "18960546905b75c869e7de63961dc185f9a0a7c9",
"is_verified": false,
- "line_number": 141
+ "line_number": 109
},
{
"type": "Secret Keyword",
"filename": "tests/unit/test_hive_helper_extended.py",
"hashed_secret": "fbf52ca8a72d8ecd77235d3b3e5d014e19ffbff2",
"is_verified": false,
- "line_number": 143
+ "line_number": 111
}
],
"tests/unit/test_session_discovery_extended.py": [
@@ -580,5 +580,5 @@
}
]
},
- "generated_at": "2026-05-17T16:44:49Z"
+ "generated_at": "2026-06-16T18:53:56Z"
}
diff --git a/src/__init__.py b/src/__init__.py
index bf6f4a0..13762b9 100644
--- a/src/__init__.py
+++ b/src/__init__.py
@@ -12,7 +12,10 @@
from .helper.const import SMS_REQUIRED
from .helper.hive_exceptions import (
HiveApiError,
+ HiveAuthCredentialError,
HiveAuthError,
+ HiveConfigurationError,
+ HiveError,
HiveFailedToRefreshTokens,
HiveInvalid2FACode,
HiveInvalidDeviceAuthentication,
diff --git a/src/api/device_registration.py b/src/api/device_registration.py
index 5013ed8..316eb3c 100644
--- a/src/api/device_registration.py
+++ b/src/api/device_registration.py
@@ -57,7 +57,7 @@ class DeviceRegistrationMixin:
large_a_value: int
client_secret: str | None
- async def generate_hash_device(self, device_group_key, device_key):
+ def generate_hash_device(self, device_group_key, device_key):
"""Generate device hash key."""
# source: https://github.com/amazon-archives/amazon-cognito-identity-js/blob/6b87f1a30a998072b4d98facb49dcaf8780d15b0/src/AuthenticationHelper.js#L137 # pylint: disable=line-too-long
@@ -81,7 +81,7 @@ async def generate_hash_device(self, device_group_key, device_key):
self.device_password = device_password
return device_secret_verifier_config
- async def get_device_authentication_key( # pylint: disable=too-many-positional-arguments
+ def get_device_authentication_key( # pylint: disable=too-many-positional-arguments
self, device_group_key, device_key, device_password, server_b_value, salt
):
"""Get device authentication key."""
@@ -120,7 +120,7 @@ async def process_device_challenge(self, challenge_parameters):
"%a %b %d %H:%M:%S UTC %Y"
),
)
- hkdf = await self.get_device_authentication_key(
+ hkdf = self.get_device_authentication_key(
self.device_group_key,
self.device_key,
self.device_password,
@@ -169,7 +169,7 @@ async def confirm_device(self, device_name: str | None = None):
result = None
try:
- device_secret_verifier_config = await self.generate_hash_device(
+ device_secret_verifier_config = self.generate_hash_device(
self.device_group_key, self.device_key
)
result = await self.loop.run_in_executor(
@@ -183,14 +183,12 @@ async def confirm_device(self, device_name: str | None = None):
),
)
except botocore.exceptions.ClientError as err:
- if err.__class__.__name__ in (
- "NotAuthorizedException",
- "CodeMismatchException",
- ):
+ code = (err.response or {}).get("Error", {}).get("Code", "")
+ if code == "CodeMismatchException":
raise HiveInvalid2FACode from err
+ raise HiveApiError from err
except botocore.exceptions.EndpointConnectionError as err:
- if err.__class__.__name__ == "EndpointConnectionError":
- raise HiveApiError from err
+ raise HiveApiError from err
return result
@@ -210,8 +208,7 @@ async def update_device_status(self):
),
)
except botocore.exceptions.EndpointConnectionError as err:
- if err.__class__.__name__ == "EndpointConnectionError":
- raise HiveApiError from err
+ raise HiveApiError from err
return result
@@ -335,10 +332,8 @@ async def forget_device(self, access_token, device_key):
),
)
except botocore.exceptions.ClientError as err:
- if err.__class__.__name__ == "NotAuthorizedException":
- raise HiveInvalid2FACode from err
+ raise HiveApiError from err
except botocore.exceptions.EndpointConnectionError as err:
- if err.__class__.__name__ == "ResourceNotFoundException":
- raise HiveApiError from err
+ raise HiveApiError from err
return result
diff --git a/src/api/hive_api.py b/src/api/hive_api.py
index 6801d3f..c4ae70c 100644
--- a/src/api/hive_api.py
+++ b/src/api/hive_api.py
@@ -2,15 +2,22 @@
import json
import logging
+import re
import requests
-import urllib3
from pyquery import PyQuery
-urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)
-
_LOGGER = logging.getLogger(__name__)
+_NO_RESPONSE = "No response to Hive API request"
+_ERROR_RESPONSE = "Error making API call"
+
+# requests exceptions all subclass OSError; response.json() raises a
+# json.JSONDecodeError subclass.
+_REQUEST_ERRORS = (OSError, RuntimeError, json.JSONDecodeError)
+
+_SSO_ASSIGNMENT = re.compile(r'window\.(\w+)\s*=\s*"([^"]*)"')
+
class HiveApi:
"""Hive API Code."""
@@ -33,8 +40,8 @@ def __init__(self, hive_session=None, token=None):
}
self.timeout = 5
self.json_return = {
- "original": "No response to Hive API request",
- "parsed": "No response to Hive API request",
+ "original": _NO_RESPONSE,
+ "parsed": _NO_RESPONSE,
}
self.session = hive_session
self.token = token
@@ -74,37 +81,24 @@ def request(self, http_method, url, jsc=None):
_LOGGER.error("Request failed: %s", e)
raise
- def refresh_tokens(self, tokens=None):
- """Get new session tokens - DEPRECATED NOW BY AWS TOKEN MANAGEMENT."""
- _LOGGER.debug("refresh_tokens - Attempting token refresh (deprecated method)")
- if tokens is None:
- tokens = {}
- url = self.urls["refresh"]
- if self.session is not None:
- tokens = self.session.tokens.token_data
- jsc = (
- "{"
- + ",".join(
- ('"' + str(i) + '": "' + str(t) + '" ' for i, t in tokens.items())
- )
- + "}"
- )
+ def _call_endpoint(self, http_method, url, jsc=None):
+ """Call an endpoint and return a fresh result dict for this call."""
+ json_return = {
+ "original": _NO_RESPONSE,
+ "parsed": _NO_RESPONSE,
+ }
try:
- info = self.request("POST", url, jsc)
- data = json.loads(info.text)
- if "token" in data and self.session:
- _LOGGER.debug(
- "refresh_tokens - Token refresh successful, updating session"
- )
- self.session.update_tokens(data)
- self.urls.update({"base": data["platform"]["endpoint"]})
- self.json_return.update({"original": info.status_code})
- self.json_return.update({"parsed": info.json()})
- except (OSError, RuntimeError, ZeroDivisionError, json.JSONDecodeError) as e:
- _LOGGER.error("Token refresh failed: %s", str(e))
- self.error()
+ response = self.request(http_method, url, jsc)
+ if response is not None:
+ json_return["original"] = response.status_code
+ json_return["parsed"] = response.json()
+ else:
+ _LOGGER.error("No response from Hive API call to %s", url)
+ except _REQUEST_ERRORS as e:
+ _LOGGER.error("Hive API call to %s failed: %s", url, e)
+ json_return = self.error()
- return self.json_return
+ return json_return
def get_login_info(self):
"""Get login properties to make the login request."""
@@ -113,33 +107,21 @@ def get_login_info(self):
)
url = self.urls["properties"]
try:
- data = requests.get(url=url, verify=False, timeout=self.timeout)
+ data = requests.get(url=url, timeout=self.timeout)
_LOGGER.debug(
"get_login_info - Login info response status: %s", data.status_code
)
- html = PyQuery(data.content)
- json_data = json.loads(
- '{"'
- + (html("script:first").text())
- .replace(",", ', "')
- .replace("=", '":')
- .replace("window.", "")
- + "}"
- )
-
- login_data = {}
- login_data.update({"UPID": json_data["HiveSSOPoolId"]})
- login_data.update({"CLIID": json_data["HiveSSOPublicCognitoClientId"]})
- login_data.update({"REGION": json_data["HiveSSOPoolId"]})
+ script_text = PyQuery(data.content)("script:first").text()
+ sso_values = dict(_SSO_ASSIGNMENT.findall(script_text))
+
+ login_data = {
+ "UPID": sso_values["HiveSSOPoolId"],
+ "CLIID": sso_values["HiveSSOPublicCognitoClientId"],
+ "REGION": sso_values["HiveSSOPoolId"],
+ }
_LOGGER.debug("get_login_info - Login info extracted successfully")
return login_data
- except (
- OSError,
- RuntimeError,
- ZeroDivisionError,
- json.JSONDecodeError,
- KeyError,
- ) as e:
+ except (OSError, RuntimeError, KeyError) as e:
_LOGGER.error("Failed to get login info: %s", str(e))
self.error()
return None
@@ -147,59 +129,23 @@ def get_login_info(self):
def get_all(self):
"""Build and query all endpoint."""
_LOGGER.debug("get_all - Fetching all devices/products/actions from Hive API")
- json_return = {}
url = self.urls["base"] + self.urls["all"]
- try:
- info = self.request("GET", url)
- if info is not None:
- json_return.update({"original": info.status_code})
- json_return.update({"parsed": info.json()})
- _LOGGER.debug(
- "get_all - All data fetch successful, status: %s", info.status_code
- )
- else:
- _LOGGER.error("Failed to get response from all endpoint")
- except (OSError, RuntimeError, ZeroDivisionError, json.JSONDecodeError) as e:
- _LOGGER.error("Failed to fetch all data: %s", str(e))
- self.error()
-
- return json_return
+ return self._call_endpoint("GET", url)
def get_devices(self):
"""Call the get devices endpoint."""
url = self.urls["base"] + self.urls["devices"]
- try:
- response = self.request("GET", url)
- self.json_return.update({"original": response.status_code})
- self.json_return.update({"parsed": response.json()})
- except (OSError, RuntimeError, ZeroDivisionError):
- self.error()
-
- return self.json_return
+ return self._call_endpoint("GET", url)
def get_products(self):
"""Call the get products endpoint."""
url = self.urls["base"] + self.urls["products"]
- try:
- response = self.request("GET", url)
- self.json_return.update({"original": response.status_code})
- self.json_return.update({"parsed": response.json()})
- except (OSError, RuntimeError, ZeroDivisionError):
- self.error()
-
- return self.json_return
+ return self._call_endpoint("GET", url)
def get_actions(self):
"""Call the get actions endpoint."""
url = self.urls["base"] + self.urls["actions"]
- try:
- response = self.request("GET", url)
- self.json_return.update({"original": response.status_code})
- self.json_return.update({"parsed": response.json()})
- except (OSError, RuntimeError, ZeroDivisionError):
- self.error()
-
- return self.json_return
+ return self._call_endpoint("GET", url)
def motion_sensor(self, sensor, fromepoch, toepoch):
"""Call a way to get motion sensor info."""
@@ -215,27 +161,13 @@ def motion_sensor(self, sensor, fromepoch, toepoch):
+ "&to="
+ str(toepoch)
)
- try:
- response = self.request("GET", url)
- self.json_return.update({"original": response.status_code})
- self.json_return.update({"parsed": response.json()})
- except (OSError, RuntimeError, ZeroDivisionError):
- self.error()
-
- return self.json_return
+ return self._call_endpoint("GET", url)
def get_weather(self, weather_url):
"""Call endpoint to get local weather from Hive API."""
t_url = self.urls["weather"] + weather_url
url = t_url.replace(" ", "%20")
- try:
- response = self.request("GET", url)
- self.json_return.update({"original": response.status_code})
- self.json_return.update({"parsed": response.json()})
- except (OSError, RuntimeError, ZeroDivisionError, ConnectionError):
- self.error()
-
- return self.json_return
+ return self._call_endpoint("GET", url)
def set_state(self, n_type, n_id, **kwargs):
"""Set the state of a Device."""
@@ -245,58 +177,27 @@ def set_state(self, n_type, n_id, **kwargs):
n_type,
kwargs,
)
- jsc = (
- "{"
- + ",".join(
- ('"' + str(i) + '": "' + str(t) + '" ' for i, t in kwargs.items())
- )
- + "}"
- )
-
+ jsc = json.dumps(kwargs)
url = self.urls["base"] + self.urls["nodes"].format(n_type, n_id)
-
- try:
- response = self.request("POST", url, jsc)
- if response is not None:
- self.json_return.update({"original": response.status_code})
- self.json_return.update({"parsed": response.json()})
- _LOGGER.debug(
- "set_state - State set successfully for %s, status: %s",
- n_id,
- response.status_code,
- )
- else:
- _LOGGER.error("Failed to set state for %s - no response", n_id)
- except (
- OSError,
- RuntimeError,
- ZeroDivisionError,
- ConnectionError,
- json.JSONDecodeError,
- ) as e:
- _LOGGER.error("Failed to set state for %s: %s", n_id, str(e))
- self.error()
-
- return self.json_return
+ return self._call_endpoint("POST", url, jsc)
def set_action(self, n_id, data):
"""Set the state of a Action."""
jsc = data
url = self.urls["base"] + self.urls["actions"] + "/" + n_id
- try:
- response = self.request("POST", url, jsc)
- self.json_return.update({"original": response.status_code})
- self.json_return.update({"parsed": response.json()})
- except (OSError, RuntimeError, ZeroDivisionError, ConnectionError):
- self.error()
-
- return self.json_return
+ return self._call_endpoint("POST", url, jsc)
def error(self):
"""An error has occurred interacting with the Hive API."""
_LOGGER.error("API error occurred - returning error response")
- self.json_return.update({"original": "Error making API call"})
- self.json_return.update({"parsed": "Error making API call"})
+ error_return = {
+ "original": _ERROR_RESPONSE,
+ "parsed": _ERROR_RESPONSE,
+ }
+ # Kept in sync for backwards compatibility with callers that read
+ # the last error state off the instance.
+ self.json_return.update(error_return)
+ return error_return
class UnknownConfig(Exception):
diff --git a/src/api/hive_async_api.py b/src/api/hive_async_api.py
index bc91789..e68d833 100644
--- a/src/api/hive_async_api.py
+++ b/src/api/hive_async_api.py
@@ -5,17 +5,20 @@
import logging
import time
-import requests
-import urllib3
-from aiohttp import ClientResponse, ClientSession, ClientTimeout, web_exceptions
-from pyquery import PyQuery
-
-from ..helper.const import HTTP_FORBIDDEN, HTTP_OK, HTTP_UNAUTHORIZED
+from aiohttp import (
+ ClientError,
+ ClientResponse,
+ ClientSession,
+ ClientTimeout,
+ web_exceptions,
+)
+
+from ..helper.const import HTTP_FORBIDDEN, HTTP_UNAUTHORIZED
from ..helper.hive_exceptions import FileInUse, HiveApiError, HiveAuthError, NoApiToken
_LOGGER = logging.getLogger(__name__)
-urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)
+_REQUEST_ERRORS = (ClientError, OSError, RuntimeError, json.JSONDecodeError)
class HiveApiAsync:
@@ -43,7 +46,17 @@ def __init__(self, hive_session=None, websession: ClientSession | None = None):
"parsed": "No response to Hive API request",
}
self.session = hive_session
- self.websession = ClientSession() if websession is None else websession
+ self.websession = websession
+
+ def _get_websession(self) -> ClientSession:
+ """Return the shared ClientSession, creating it on first use.
+
+ Created lazily so the session is constructed inside a running
+ event loop rather than in the synchronous constructor.
+ """
+ if self.websession is None:
+ self.websession = ClientSession()
+ return self.websession
async def request(self, method: str, url: str, **kwargs) -> ClientResponse:
"""Make a request."""
@@ -72,7 +85,7 @@ async def request(self, method: str, url: str, **kwargs) -> ClientResponse:
timeout = ClientTimeout(total=self.timeout)
req_start = time.monotonic()
- async with self.websession.request(
+ async with self._get_websession().request(
method, url, headers=headers, data=data, timeout=timeout
) as resp:
resp_body = await resp.text()
@@ -97,125 +110,50 @@ async def request(self, method: str, url: str, **kwargs) -> ClientResponse:
raise HiveAuthError(
f"Token expired or forbidden calling {url} — HTTP {resp.status}"
)
- if url is not None and resp.status is not None:
- _LOGGER.error(
- "Something has gone wrong calling %s - HTTP status is - %s — response: %s",
- url,
- resp.status,
- resp_body[:200],
- )
-
- raise HiveApiError
-
- def get_login_info(self):
- """Get login properties to make the login request."""
- url = "https://sso.hivehome.com/"
-
- data = requests.get(url=url, verify=False, timeout=self.timeout)
- html = PyQuery(data.content)
- json_data = json.loads(
- '{"'
- + (html("script:first").text())
- .replace(",", ', "')
- .replace("=", '":')
- .replace("window.", "")
- + "}"
- )
-
- login_data = {}
- login_data.update({"UPID": json_data["HiveSSOPoolId"]})
- login_data.update({"CLIID": json_data["HiveSSOPublicCognitoClientId"]})
- login_data.update({"REGION": json_data["HiveSSOPoolId"]})
- return login_data
-
- async def refresh_tokens(self):
- """Refresh tokens - DEPRECATED NOW BY AWS TOKEN MANAGEMENT."""
- url = self.urls["refresh"]
- if self.session is not None:
- tokens = self.session.tokens.token_data
- jsc = (
- "{"
- + ",".join(
- ('"' + str(i) + '": "' + str(t) + '" ' for i, t in tokens.items())
- )
- + "}"
+ _LOGGER.error(
+ "Something has gone wrong calling %s - HTTP status is - %s — response: %s",
+ url,
+ resp.status,
+ resp_body[:200],
)
- try:
- await self.request("post", url, data=jsc)
-
- if self.json_return["original"] == HTTP_OK:
- info = self.json_return["parsed"]
- if "token" in info:
- await self.session.update_tokens(info)
- # pylint: disable-next=invalid-sequence-index
- self.base_url = info["platform"]["endpoint"]
- return True
- except (ConnectionError, OSError, RuntimeError, ZeroDivisionError):
- await self.error()
-
- return self.json_return
+ raise HiveApiError
- async def get_all(self):
- """Build and query all endpoint."""
- json_return = {}
- url = self.urls["all"]
+ async def _call_endpoint(self, method: str, url: str, data=None) -> dict:
+ """Call an endpoint and return {"original": status, "parsed": json}."""
+ json_return: dict = {}
try:
- resp = await self.request("get", url)
+ resp = await self.request(method, url, data=data)
json_return.update({"original": resp.status})
json_return.update({"parsed": await resp.json(content_type=None)})
except asyncio.TimeoutError:
- _LOGGER.warning("Hive API request timed out fetching all nodes.")
+ _LOGGER.warning("Hive API request timed out calling %s", url)
raise
- except (OSError, RuntimeError, ZeroDivisionError):
+ except _REQUEST_ERRORS:
await self.error()
return json_return
+ async def get_all(self):
+ """Build and query all endpoint."""
+ return await self._call_endpoint("get", self.urls["all"])
+
async def get_devices(self):
"""Call the get devices endpoint."""
- json_return = {}
- url = self.urls["devices"]
- try:
- resp = await self.request("get", url)
- json_return.update({"original": resp.status})
- json_return.update({"parsed": await resp.json(content_type=None)})
- except (OSError, RuntimeError, ZeroDivisionError):
- await self.error()
-
- return json_return
+ return await self._call_endpoint("get", self.urls["devices"])
async def get_products(self):
"""Call the get products endpoint."""
- json_return = {}
- url = self.urls["products"]
- try:
- resp = await self.request("get", url)
- json_return.update({"original": resp.status})
- json_return.update({"parsed": await resp.json(content_type=None)})
- except (OSError, RuntimeError, ZeroDivisionError):
- await self.error()
-
- return json_return
+ return await self._call_endpoint("get", self.urls["products"])
async def get_actions(self):
"""Call the get actions endpoint."""
- json_return = {}
- url = self.urls["actions"]
- try:
- resp = await self.request("get", url)
- json_return.update({"original": resp.status})
- json_return.update({"parsed": await resp.json(content_type=None)})
- except (OSError, RuntimeError, ZeroDivisionError):
- await self.error()
-
- return json_return
+ return await self._call_endpoint("get", self.urls["actions"])
async def motion_sensor(self, sensor, fromepoch, toepoch):
"""Call a way to get motion sensor info."""
- json_return = {}
url = (
- self.urls["base"]
- + self.urls["products"]
+ self.base_url
+ + "/products"
+ "/"
+ sensor["type"]
+ "/"
@@ -225,68 +163,34 @@ async def motion_sensor(self, sensor, fromepoch, toepoch):
+ "&to="
+ str(toepoch)
)
- try:
- resp = await self.request("get", url)
- json_return.update({"original": resp.status})
- json_return.update({"parsed": await resp.json(content_type=None)})
- except (OSError, RuntimeError, ZeroDivisionError):
- await self.error()
-
- return json_return
+ return await self._call_endpoint("get", url)
async def get_weather(self, weather_url):
"""Call endpoint to get local weather from Hive API."""
- json_return = {}
t_url = self.urls["weather"] + weather_url
url = t_url.replace(" ", "%20")
- try:
- resp = await self.request("get", url)
- json_return.update({"original": resp.status})
- json_return.update({"parsed": await resp.json(content_type=None)})
- except (OSError, RuntimeError, ZeroDivisionError, ConnectionError):
- await self.error()
-
- return json_return
+ return await self._call_endpoint("get", url)
async def set_state(self, n_type, n_id, **kwargs):
"""Set the state of a Device."""
_LOGGER.debug("set_state - Setting state for %s/%s: %s", n_type, n_id, kwargs)
- json_return = {}
- jsc = (
- "{"
- + ",".join(
- ('"' + str(i) + '": "' + str(t) + '" ' for i, t in kwargs.items())
- )
- + "}"
- )
-
+ jsc = json.dumps(kwargs)
url = self.urls["nodes"].format(n_type, n_id)
try:
await self.is_file_being_used()
- resp = await self.request("post", url, data=jsc)
- json_return["original"] = resp.status
- json_return["parsed"] = await resp.json(content_type=None)
- except (FileInUse, OSError, RuntimeError, ConnectionError) as e:
- if e.__class__.__name__ == "FileInUse":
- return {"original": "file"}
- await self.error()
-
- return json_return
+ except FileInUse:
+ return {"original": "file"}
+ return await self._call_endpoint("post", url, data=jsc)
async def set_action(self, n_id, data):
"""Set the state of a Action."""
_LOGGER.debug("Setting action %s", n_id)
- jsc = data
url = self.urls["actions"] + "/" + n_id
try:
await self.is_file_being_used()
- await self.request("put", url, data=jsc)
- except (FileInUse, OSError, RuntimeError, ConnectionError) as e:
- if e.__class__.__name__ == "FileInUse":
- return {"original": "file"}
- await self.error()
-
- return self.json_return
+ except FileInUse:
+ return {"original": "file"}
+ return await self._call_endpoint("put", url, data=data)
async def error(self):
"""An error has occurred interacting with the Hive API."""
diff --git a/src/api/hive_auth_async.py b/src/api/hive_auth_async.py
index 4056690..6866a84 100644
--- a/src/api/hive_auth_async.py
+++ b/src/api/hive_auth_async.py
@@ -23,6 +23,7 @@
HiveInvalidPassword,
HiveInvalidUsername,
HiveRefreshTokenExpired,
+ HiveUnknownConfiguration,
)
from .device_registration import DeviceRegistrationMixin
from .hive_api import HiveApi
@@ -58,17 +59,10 @@ def __init__( # pylint: disable=too-many-positional-arguments # noqa: PLR0913
device_group_key: str | None = None,
device_key: str | None = None,
device_password: str | None = None,
- pool_region: str | None = None,
client_secret: str | None = None,
):
"""Initialise async auth."""
- if pool_region is not None:
- raise ValueError(
- "pool_region and client should not both be specified "
- "(region should be passed to the boto3 client instead)"
- )
-
- self.loop: asyncio.AbstractEventLoop = asyncio.get_event_loop()
+ self.loop: asyncio.AbstractEventLoop | None = None
self.username = username
self.password = password
self.device_group_key: str | None = device_group_key
@@ -95,10 +89,19 @@ def __init__( # pylint: disable=too-many-positional-arguments # noqa: PLR0913
async def async_init(self):
"""Initialise async variables."""
+ self.loop = asyncio.get_running_loop()
self.data = await self.loop.run_in_executor(None, self.api.get_login_info)
+ if not self.data:
+ raise HiveUnknownConfiguration("SSO login page returned no data")
self._pool_id = self.data.get("UPID")
self._client_id = self.data.get("CLIID")
- self._region = self.data.get("REGION").split("_")[0]
+ region_raw = self.data.get("REGION")
+ if not self._pool_id or not region_raw:
+ raise HiveUnknownConfiguration(
+ "SSO login page did not return required pool/region data"
+ )
+ self._region = region_raw.split("_")[0]
+ # Cognito USER_SRP_AUTH does not use IAM credentials — boto3 requires non-None values.
self.client = await self.loop.run_in_executor(
None,
functools.partial(
@@ -128,6 +131,25 @@ def generate_random_small_a(self):
random_long_int = get_random(128)
return random_long_int % self.big_n
+ def _new_srp_ephemeral(self) -> None:
+ """Generate a fresh SRP client ephemeral (a, A) pair.
+
+ SRP ephemerals must not be reused across handshakes, so this is
+ called at the start of every authentication attempt.
+ """
+ self.small_a_value = self.generate_random_small_a()
+ self.large_a_value = self.calculate_a()
+
+ def _store_auth_result(self, result: dict) -> None:
+ """Store tokens and any new device keys from an AuthenticationResult."""
+ auth_result = result["AuthenticationResult"]
+ self.access_token = auth_result["AccessToken"]
+ self.token_created = datetime.datetime.now()
+ if "NewDeviceMetadata" in auth_result:
+ self.device_group_key = auth_result["NewDeviceMetadata"]["DeviceGroupKey"]
+ self.device_key = auth_result["NewDeviceMetadata"]["DeviceKey"]
+ _LOGGER.debug("Device keys stored successfully.")
+
def calculate_a(self):
"""
Calculate the client's public value A.
@@ -156,6 +178,8 @@ def get_password_authentication_key(self, username, password, server_b_value, sa
u_value = calculate_u(self.large_a_value, server_b_value)
if u_value == 0:
raise ValueError("U cannot be zero.")
+ if not self._pool_id or "_" not in self._pool_id:
+ raise HiveUnknownConfiguration(f"Invalid pool ID format: {self._pool_id!r}")
pool_id = self._pool_id.split("_")[1]
username_password = f"{pool_id}{username}:{password}"
username_password_hash = hash_sha256(username_password.encode("utf-8"))
@@ -255,7 +279,7 @@ async def process_challenge(self, challenge_parameters):
return response
- async def login(self): # noqa: PLR0912
+ async def login(self): # noqa: PLR0912, PLR0915 # pylint: disable=too-many-statements
"""Login into a Hive account - handles initial SRP auth only."""
if self.use_file:
_LOGGER.debug("login - Using file-based authentication.")
@@ -264,6 +288,7 @@ async def login(self): # noqa: PLR0912
if self.client is None:
await self.async_init()
+ self._new_srp_ephemeral()
auth_params = await self.get_auth_params()
response = None
result = None
@@ -279,15 +304,22 @@ async def login(self): # noqa: PLR0912
),
)
except botocore.exceptions.ClientError as err:
- if err.__class__.__name__ == "UserNotFoundException":
+ code = (err.response or {}).get("Error", {}).get("Code", "")
+ if code == "UserNotFoundException":
_LOGGER.error("Cognito auth failed: user not found.")
raise HiveInvalidUsername from err
+ _LOGGER.error("Cognito auth failed: %s", code)
+ raise HiveApiError from err
except botocore.exceptions.EndpointConnectionError as err:
- if err.__class__.__name__ == "EndpointConnectionError":
- _LOGGER.error("Cognito auth failed: cannot reach endpoint.")
- raise HiveApiError from err
+ _LOGGER.error("Cognito auth failed: cannot reach endpoint.")
+ raise HiveApiError from err
+
+ if "AuthenticationResult" in response:
+ _LOGGER.debug("login - Authenticated directly without a challenge.")
+ self._store_auth_result(response)
+ return response
- if response["ChallengeName"] == self.PASSWORD_VERIFIER_CHALLENGE:
+ if response.get("ChallengeName") == self.PASSWORD_VERIFIER_CHALLENGE:
_LOGGER.debug("login - Processing PASSWORD_VERIFIER challenge.")
challenge_response = await self.process_challenge(
response["ChallengeParameters"]
@@ -303,38 +335,29 @@ async def login(self): # noqa: PLR0912
),
)
except botocore.exceptions.ClientError as err:
- if err.__class__.__name__ == "NotAuthorizedException":
+ code = (err.response or {}).get("Error", {}).get("Code", "")
+ if code == "NotAuthorizedException":
_LOGGER.error("Cognito auth challenge failed: not authorised.")
raise HiveInvalidPassword from err
- if err.__class__.__name__ == "ResourceNotFoundException":
+ if code == "ResourceNotFoundException":
_LOGGER.error(
"Cognito auth challenge failed: device resource not found."
)
raise HiveInvalidDeviceAuthentication from err
+ _LOGGER.error("Cognito auth challenge failed: %s", code)
+ raise HiveApiError from err
except botocore.exceptions.EndpointConnectionError as err:
- if err.__class__.__name__ == "EndpointConnectionError":
- _LOGGER.error(
- "Cognito auth challenge failed: cannot reach endpoint."
- )
- raise HiveApiError from err
+ _LOGGER.error("Cognito auth challenge failed: cannot reach endpoint.")
+ raise HiveApiError from err
_LOGGER.debug("login - SRP auth challenge completed successfully.")
if "AuthenticationResult" in result:
- self.access_token = result["AuthenticationResult"]["AccessToken"]
- self.token_created = datetime.datetime.now()
- if "NewDeviceMetadata" in result["AuthenticationResult"]:
- self.device_group_key = result["AuthenticationResult"][
- "NewDeviceMetadata"
- ]["DeviceGroupKey"]
- self.device_key = result["AuthenticationResult"][
- "NewDeviceMetadata"
- ]["DeviceKey"]
- _LOGGER.debug("login - Device keys stored successfully.")
+ self._store_auth_result(result)
return result
- challenge_name = response["ChallengeName"]
+ challenge_name = response.get("ChallengeName")
_LOGGER.error("Unsupported Cognito challenge: %s", challenge_name)
raise NotImplementedError(f"The {challenge_name} challenge is not supported")
@@ -349,6 +372,7 @@ async def device_login(self):
if self.client is None:
await self.async_init()
+ self._new_srp_ephemeral()
auth_params = await self.get_auth_params(is_device_login=True)
_LOGGER.debug("device_login - Processing DEVICE_SRP_AUTH challenge.")
@@ -385,10 +409,8 @@ async def device_login(self):
raise HiveInvalidDeviceAuthentication from err
raise
except botocore.exceptions.EndpointConnectionError as err:
- if err.__class__.__name__ == "EndpointConnectionError":
- _LOGGER.error("Device login failed: cannot reach endpoint.")
- raise HiveApiError from err
- raise HiveInvalidDeviceAuthentication from err
+ _LOGGER.error("Device login failed: cannot reach endpoint.")
+ raise HiveApiError from err
_LOGGER.debug("device_login - Device authentication completed successfully.")
return result
@@ -413,26 +435,18 @@ async def sms_2fa(self, entered_code, challenge_parameters):
},
),
)
- self.access_token = result["AuthenticationResult"]["AccessToken"]
- self.token_created = datetime.datetime.now()
- if "NewDeviceMetadata" in result["AuthenticationResult"]:
- self.device_group_key = result["AuthenticationResult"][
- "NewDeviceMetadata"
- ]["DeviceGroupKey"]
- self.device_key = result["AuthenticationResult"]["NewDeviceMetadata"][
- "DeviceKey"
- ]
+ if result and "AuthenticationResult" in result:
+ self._store_auth_result(result)
except botocore.exceptions.ClientError as err:
- if err.__class__.__name__ in (
- "NotAuthorizedException",
- "CodeMismatchException",
- ):
+ code = (err.response or {}).get("Error", {}).get("Code", "")
+ if code in ("NotAuthorizedException", "CodeMismatchException"):
_LOGGER.error("2FA code rejected by Cognito.")
raise HiveInvalid2FACode from err
+ _LOGGER.error("2FA failed: %s", code)
+ raise HiveApiError from err
except botocore.exceptions.EndpointConnectionError as err:
- if err.__class__.__name__ == "EndpointConnectionError":
- _LOGGER.error("2FA failed: cannot reach Cognito endpoint.")
- raise HiveApiError from err
+ _LOGGER.error("2FA failed: cannot reach Cognito endpoint.")
+ raise HiveApiError from err
_LOGGER.debug("sms_2fa - 2FA authentication completed successfully.")
return result
@@ -478,11 +492,10 @@ async def refresh_token(self, token):
)
raise HiveFailedToRefreshTokens from err
except botocore.exceptions.EndpointConnectionError as err:
- if err.__class__.__name__ == "EndpointConnectionError":
- _LOGGER.error(
- "refresh_token - Token refresh failed: cannot reach Cognito endpoint."
- )
- raise HiveApiError from err
+ _LOGGER.error(
+ "refresh_token - Token refresh failed: cannot reach Cognito endpoint."
+ )
+ raise HiveApiError from err
_LOGGER.debug("refresh_token - Cognito token refresh completed successfully.")
return result
diff --git a/src/api/srp_crypto.py b/src/api/srp_crypto.py
index faf141b..74cea33 100644
--- a/src/api/srp_crypto.py
+++ b/src/api/srp_crypto.py
@@ -1,7 +1,6 @@
"""Pure SRP/HKDF crypto helpers for AWS Cognito authentication."""
import binascii
-import concurrent.futures
import hashlib
import hmac
import os
@@ -28,7 +27,6 @@
# https://github.com/aws/amazon-cognito-identity-js/blob/master/src/AuthenticationHelper.js#L49
G_HEX = "2"
INFO_BITS = bytearray("Caldera Derived Key", "utf-8")
-POOL = concurrent.futures.ThreadPoolExecutor()
def hex_to_long(hex_string):
diff --git a/src/devices/boost.py b/src/devices/boost.py
index 2e3f1b7..b1e22e9 100644
--- a/src/devices/boost.py
+++ b/src/devices/boost.py
@@ -29,7 +29,7 @@ async def get_boost_status(self, device: Device):
data = self.session.data.products[device.hive_id]
return HIVETOHA["Boost"].get(data["state"].get("boost", False), "ON")
except KeyError as e:
- _LOGGER.error(e)
+ _LOGGER.error("get_boost_status - KeyError for %s: %s", device.ha_name, e)
return None
async def get_boost_time(self, device: Device):
@@ -43,5 +43,5 @@ async def get_boost_time(self, device: Device):
data = self.session.data.products[device.hive_id]
return data["state"]["boost"]
except KeyError as e:
- _LOGGER.error(e)
+ _LOGGER.error("get_boost_time - KeyError for %s: %s", device.ha_name, e)
return None
diff --git a/src/devices/color.py b/src/devices/color.py
index ade4cb2..8a63a9d 100644
--- a/src/devices/color.py
+++ b/src/devices/color.py
@@ -2,7 +2,6 @@
from __future__ import annotations
-import colorsys
import logging
from typing import Any
@@ -33,7 +32,7 @@ async def get_min_color_temp(self, device: Device):
data = self.session.data.products[device.hive_id]
state = data["props"]["colourTemperature"]["max"]
return round((1 / state) * 1000000)
- except KeyError as e:
+ except (KeyError, ZeroDivisionError) as e:
_LOGGER.error(e)
return None
@@ -50,7 +49,7 @@ async def get_max_color_temp(self, device: Device):
data = self.session.data.products[device.hive_id]
state = data["props"]["colourTemperature"]["min"]
return round((1 / state) * 1000000)
- except KeyError as e:
+ except (KeyError, ZeroDivisionError) as e:
_LOGGER.error(e)
return None
@@ -67,27 +66,23 @@ async def get_color_temp(self, device: Device):
data = self.session.data.products[device.hive_id]
state = data["state"]["colourTemperature"]
return round((1 / state) * 1000000)
- except KeyError as e:
+ except (KeyError, ZeroDivisionError) as e:
_LOGGER.error(e)
return None
async def get_color(self, device: Device):
- """Get light current colour as an RGB tuple.
+ """Get light current colour as an HS tuple for HA hs_color.
Args:
device (Device): Device to query.
Returns:
- tuple | None: ``(r, g, b)`` each in 0–255.
+ tuple | None: ``(hue_degrees, saturation_percent)`` where hue is
+ 0–360 and saturation is 0–100, or None on error.
"""
try:
data = self.session.data.products[device.hive_id]
- hsv = [
- data["state"]["hue"] / 360,
- data["state"]["saturation"] / 100,
- data["state"]["value"] / 100,
- ]
- return tuple(int(i * 255) for i in colorsys.hsv_to_rgb(*hsv))
+ return (data["state"]["hue"], data["state"]["saturation"])
except KeyError as e:
_LOGGER.error(e)
return None
diff --git a/src/devices/heating.py b/src/devices/heating.py
index 2b81e1e..de458d7 100644
--- a/src/devices/heating.py
+++ b/src/devices/heating.py
@@ -48,6 +48,36 @@ async def get_max_temperature(self, device: Device):
return self._get_product_state(device, "props", "maxHeat")
return 32
+ def _track_minmax(self, hive_id: str, temperature: float) -> None:
+ """Record today's and since-restart min/max temperatures for a device."""
+ today = str(datetime.date(datetime.now()))
+ min_max = self.session.data.minMax.get(hive_id)
+
+ if min_max is None:
+ self.session.data.minMax[hive_id] = {
+ "TodayMin": temperature,
+ "TodayMax": temperature,
+ "TodayDate": today,
+ "RestartMin": temperature,
+ "RestartMax": temperature,
+ }
+ return
+
+ if min_max["TodayDate"] == today:
+ min_max["TodayMin"] = min(min_max["TodayMin"], temperature)
+ min_max["TodayMax"] = max(min_max["TodayMax"], temperature)
+ else:
+ min_max.update(
+ {
+ "TodayMin": temperature,
+ "TodayMax": temperature,
+ "TodayDate": today,
+ }
+ )
+
+ min_max["RestartMin"] = min(min_max["RestartMin"], temperature)
+ min_max["RestartMax"] = max(min_max["RestartMax"], temperature)
+
async def get_current_temperature(self, device: Device):
"""Get heating current temperature.
@@ -75,41 +105,7 @@ async def get_current_temperature(self, device: Device):
)
return None
- if device.hive_id in self.session.data.minMax:
- if self.session.data.minMax[device.hive_id]["TodayDate"] == str(
- datetime.date(datetime.now())
- ):
- self.session.data.minMax[device.hive_id]["TodayMin"] = min(
- self.session.data.minMax[device.hive_id]["TodayMin"], state
- )
-
- self.session.data.minMax[device.hive_id]["TodayMax"] = max(
- self.session.data.minMax[device.hive_id]["TodayMax"], state
- )
- else:
- data = {
- "TodayMin": state,
- "TodayMax": state,
- "TodayDate": str(datetime.date(datetime.now())),
- }
- self.session.data.minMax[device.hive_id].update(data)
-
- self.session.data.minMax[device.hive_id]["RestartMin"] = min(
- self.session.data.minMax[device.hive_id]["RestartMin"], state
- )
-
- self.session.data.minMax[device.hive_id]["RestartMax"] = max(
- self.session.data.minMax[device.hive_id]["RestartMax"], state
- )
- else:
- data = {
- "TodayMin": state,
- "TodayMax": state,
- "TodayDate": str(datetime.date(datetime.now())),
- "RestartMin": state,
- "RestartMax": state,
- }
- self.session.data.minMax[device.hive_id] = data
+ self._track_minmax(device.hive_id, state)
final = round(state, 1)
except KeyError as e:
@@ -170,15 +166,16 @@ async def get_mode(self, device: Device):
"""
state = None
final = None
+ device_name = device.ha_name
try:
data = self.session.data.products[device.hive_id]
state = data["state"]["mode"]
if state == "BOOST":
- state = data["props"]["previous"]["mode"]
+ state = self._get_product_state(device, "props", "previous", "mode")
final = HIVETOHA[self.heating_type].get(state, state)
except KeyError as e:
- _LOGGER.error(e)
+ _LOGGER.error("get_mode - KeyError getting mode for %s: %s", device_name, e)
return final
@@ -284,7 +281,18 @@ async def set_boost_on(self, device: Device, mins: str, temp: float):
"""
min_temp = await self.get_min_temperature(device)
max_temp = await self.get_max_temperature(device)
- if not (int(mins) > 0 and min_temp <= int(temp) <= max_temp):
+ try:
+ mins_value = int(mins)
+ temp_value = float(temp)
+ except (ValueError, TypeError):
+ _LOGGER.warning(
+ "set_boost_on - Invalid boost inputs for %s: mins=%r temp=%r",
+ device.ha_name,
+ mins,
+ temp,
+ )
+ return None
+ if not (mins_value > 0 and min_temp <= temp_value <= max_temp):
return None
_LOGGER.debug(
"set_boost_on - Setting heating boost ON for %s: %s mins at %s degrees.",
@@ -317,6 +325,12 @@ async def set_boost_off(self, device: Device):
"set_boost_off - Setting heating boost OFF for %s.", device.ha_name
)
prev_mode = self._get_product_state(device, "props", "previous", "mode")
+ if prev_mode is None:
+ _LOGGER.warning(
+ "set_boost_off - Cannot determine previous mode for %s, skipping.",
+ device.ha_name,
+ )
+ return False
kwargs = {"mode": prev_mode}
if prev_mode in ("MANUAL", "OFF"):
kwargs["target"] = (
diff --git a/src/devices/hotwater.py b/src/devices/hotwater.py
index b6985b2..d4774c8 100644
--- a/src/devices/hotwater.py
+++ b/src/devices/hotwater.py
@@ -33,15 +33,16 @@ async def get_mode(self, device: Device):
"""
state = None
final = None
+ device_name = device.ha_name
try:
data = self.session.data.products[device.hive_id]
state = data["state"]["mode"]
if state == "BOOST":
- state = data["props"]["previous"]["mode"]
+ state = self._get_product_state(device, "props", "previous", "mode")
final = HIVETOHA[self.hotwater_type].get(state, state)
except KeyError as e:
- _LOGGER.error(e)
+ _LOGGER.error("get_mode - KeyError getting mode for %s: %s", device_name, e)
return final
@@ -77,7 +78,8 @@ async def get_state(self, device: Device):
snan = self.session.helper.get_schedule_nnl(
data["state"]["schedule"]
)
- state = snan["now"]["value"]["status"]
+ if snan and "now" in snan:
+ state = snan["now"]["value"]["status"]
final = HIVETOHA[self.hotwater_type].get(state, state)
except KeyError as e:
@@ -141,6 +143,12 @@ async def set_boost_off(self, device: Device):
"set_boost_off - Setting hot water boost OFF for %s.", device.ha_name
)
prev_mode = self._get_product_state(device, "props", "previous", "mode")
+ if prev_mode is None:
+ _LOGGER.warning(
+ "set_boost_off - Cannot determine previous mode for %s, skipping.",
+ device.ha_name,
+ )
+ return False
return await self._execute_state_change(device, mode=prev_mode)
diff --git a/src/devices/light.py b/src/devices/light.py
index b33637e..d48dc91 100644
--- a/src/devices/light.py
+++ b/src/devices/light.py
@@ -62,10 +62,10 @@ async def get_brightness(self, device: Device):
try:
data = self.session.data.products[device.hive_id]
state = data["state"]["brightness"]
- final = (state / 100) * 255
- except KeyError as e:
+ final = int((state / 100) * 255)
+ except (KeyError, TypeError) as e:
_LOGGER.error(
- "KeyError getting light brightness for %s: %s", device_name, str(e)
+ "Error getting light brightness for %s: %s", device_name, str(e)
)
return final
diff --git a/src/devices/sensor.py b/src/devices/sensor.py
index 9f0a431..bed1ebb 100644
--- a/src/devices/sensor.py
+++ b/src/devices/sensor.py
@@ -4,12 +4,36 @@
from typing import Any
from ..helper.compat_aliases import SensorCompatMixin
-from ..helper.const import HIVE_TYPES, HIVETOHA, sensor_commands
+from ..helper.const import HIVE_TYPES, HIVETOHA
from ..helper.device_handler_base import BaseDeviceHandler
from ..helper.hivedataclasses import Device
_LOGGER = logging.getLogger(__name__)
+sensor_commands = {
+ "SMOKE_CO": lambda s, d: s.session.hub.get_smoke_status(d),
+ "DOG_BARK": lambda s, d: s.session.hub.get_dog_bark_status(d),
+ "GLASS_BREAK": lambda s, d: s.session.hub.get_glass_break_status(d),
+ "Current_Temperature": lambda s, d: s.session.heating.get_current_temperature(d),
+ "Heating_Current_Temperature": lambda s, d: (
+ s.session.heating.get_current_temperature(d)
+ ),
+ "Heating_Target_Temperature": lambda s, d: s.session.heating.get_target_temperature(
+ d
+ ),
+ "Heating_State": lambda s, d: s.session.heating.get_state(d),
+ "Heating_Mode": lambda s, d: s.session.heating.get_mode(d),
+ "Heating_Boost": lambda s, d: s.session.heating.get_boost_status(d),
+ "Hotwater_State": lambda s, d: s.session.hotwater.get_state(d),
+ "Hotwater_Mode": lambda s, d: s.session.hotwater.get_mode(d),
+ "Hotwater_Boost": lambda s, d: s.session.hotwater.get_boost(d),
+ "Battery": lambda s, d: s.session.attr.get_battery(d.device_id),
+ "Mode": lambda s, d: s.session.attr.get_mode(d.hive_id),
+ "Availability": lambda s, d: s.online(d),
+ "Connectivity": lambda s, d: s.online(d),
+ "Power": lambda s, d: s.session.switch.get_power_usage(d),
+}
+
class HiveSensor(BaseDeviceHandler):
"""Hive Sensor Code."""
@@ -133,7 +157,7 @@ async def get_sensor(self, device: Device):
device.device_data = props
device.parent_device = data.get("parent", None)
elif device.hive_type in HIVE_TYPES["Sensor"]:
- data = self.session.data.devices.get(device.hive_id, {})
+ data = self.session.data.devices.get(device.device_id, {})
device.status = {"state": await self.get_state(device)}
props = data.get("props") or {}
props["online"] = online
diff --git a/src/helper/compat_aliases.py b/src/helper/compat_aliases.py
index b1fe79e..8cdfc0e 100644
--- a/src/helper/compat_aliases.py
+++ b/src/helper/compat_aliases.py
@@ -7,6 +7,7 @@
from __future__ import annotations
+from datetime import timedelta
from typing import Any
from .hivedataclasses import Device
@@ -138,6 +139,7 @@ async def updateData(self, device: Device): # pylint: disable=invalid-name
"""Backwards-compatible alias for update_data."""
return await self.update_data(device) # type: ignore[attr-defined]
- async def updateInterval(self, new_interval: int): # pylint: disable=invalid-name,unused-argument
+ async def updateInterval(self, new_interval: int): # pylint: disable=invalid-name
"""Backwards-compatible alias for Home Assistant Scan Interval."""
+ self.config.scan_interval = timedelta(seconds=new_interval) # type: ignore[attr-defined]
return True
diff --git a/src/helper/const.py b/src/helper/const.py
index 7c27601..59d1c7d 100644
--- a/src/helper/const.py
+++ b/src/helper/const.py
@@ -60,29 +60,6 @@
"Sensor": ["motionsensor", "contactsensor"],
"Switch": ["activeplug"],
}
-sensor_commands = {
- "SMOKE_CO": lambda s, d: s.session.hub.get_smoke_status(d),
- "DOG_BARK": lambda s, d: s.session.hub.get_dog_bark_status(d),
- "GLASS_BREAK": lambda s, d: s.session.hub.get_glass_break_status(d),
- "Current_Temperature": lambda s, d: s.session.heating.get_current_temperature(d),
- "Heating_Current_Temperature": lambda s, d: (
- s.session.heating.get_current_temperature(d)
- ),
- "Heating_Target_Temperature": lambda s, d: s.session.heating.get_target_temperature(
- d
- ),
- "Heating_State": lambda s, d: s.session.heating.get_state(d),
- "Heating_Mode": lambda s, d: s.session.heating.get_mode(d),
- "Heating_Boost": lambda s, d: s.session.heating.get_boost_status(d),
- "Hotwater_State": lambda s, d: s.session.hotwater.get_state(d),
- "Hotwater_Mode": lambda s, d: s.session.hotwater.get_mode(d),
- "Hotwater_Boost": lambda s, d: s.session.hotwater.get_boost(d),
- "Battery": lambda s, d: s.session.attr.get_battery(d.device_id),
- "Mode": lambda s, d: s.session.attr.get_mode(d.hive_id),
- "Availability": lambda s, d: s.online(d),
- "Connectivity": lambda s, d: s.online(d),
- "Power": lambda s, d: s.session.switch.get_power_usage(d),
-}
PRODUCTS = {
"sense": [
diff --git a/src/helper/device_attributes.py b/src/helper/device_attributes.py
index b703c21..d218a36 100644
--- a/src/helper/device_attributes.py
+++ b/src/helper/device_attributes.py
@@ -3,8 +3,6 @@
import logging
from typing import Any
-from .const import HIVETOHA
-
_LOGGER = logging.getLogger(__name__)
@@ -18,7 +16,6 @@ def __init__(self, session: Any = None):
session (object, optional): Session to interact with hive account. Defaults to None.
"""
self.session = session
- self.type = "Attribute"
async def state_attributes(self, n_id: str, _type: str):
"""Get HA State Attributes.
@@ -70,17 +67,12 @@ async def get_mode(self, n_id: str):
Returns:
str: The mode of the device.
"""
- state = None
- final = None
-
try:
data = self.session.data.products[n_id]
- state = data["state"]["mode"]
- final = HIVETOHA[self.type].get(state, state)
+ return data["state"]["mode"]
except KeyError as e:
_LOGGER.error(e)
-
- return final
+ return None
async def get_battery(self, n_id: str):
"""Get device battery level.
@@ -98,7 +90,6 @@ async def get_battery(self, n_id: str):
data = self.session.data.devices[n_id]
state = data["props"]["battery"]
final = state
- await self.session.helper.error_check(n_id, self.type, state)
except KeyError as e:
_LOGGER.error(e)
diff --git a/src/helper/hive_exceptions.py b/src/helper/hive_exceptions.py
index 4881d12..2a69441 100644
--- a/src/helper/hive_exceptions.py
+++ b/src/helper/hive_exceptions.py
@@ -19,14 +19,22 @@ class NoApiToken(Exception):
"""
-class HiveApiError(Exception):
- """Api error.
+class HiveError(Exception):
+ """Common base class for all Hive-specific exceptions.
Args:
Exception (object): Exception object to invoke
"""
+class HiveApiError(HiveError):
+ """Api error.
+
+ Args:
+ HiveError (object): Parent Hive error class
+ """
+
+
class HiveAuthError(HiveApiError):
"""Auth error (401/403) — token may be expired or invalid.
@@ -35,65 +43,81 @@ class HiveAuthError(HiveApiError):
"""
-class HiveRefreshTokenExpired(Exception):
+class HiveRefreshTokenExpired(HiveApiError):
"""Refresh token expired.
Args:
- Exception (object): Exception object to invoke
+ HiveApiError (object): Parent API error class
"""
-class HiveReauthRequired(Exception):
- """Re-Authentication is required.
+class HiveFailedToRefreshTokens(HiveApiError):
+ """Raise invalid refresh tokens.
Args:
- Exception (object): Exception object to invoke
+ HiveApiError (object): Parent API error class
+ """
+
+
+class HiveConfigurationError(HiveError):
+ """Base class for configuration-related errors.
+
+ Args:
+ HiveError (object): Parent Hive error class
"""
-class HiveUnknownConfiguration(Exception):
+class HiveUnknownConfiguration(HiveConfigurationError):
"""Unknown Hive Configuration.
Args:
- Exception (object): Exception object to invoke
+ HiveConfigurationError (object): Parent configuration error class
"""
-class HiveInvalidUsername(Exception):
- """Raise invalid Username.
+class HiveInvalidDeviceAuthentication(HiveConfigurationError):
+ """Raise invalid device authentication.
Args:
- Exception (object): Exception object to invoke
+ HiveConfigurationError (object): Parent configuration error class
"""
-class HiveInvalidPassword(Exception):
- """Raise invalid password.
+class HiveAuthCredentialError(HiveError):
+ """Base class for authentication credential errors.
Args:
- Exception (object): Exception object to invoke
+ HiveError (object): Parent Hive error class
"""
-class HiveInvalid2FACode(Exception):
- """Raise invalid 2FA code.
+class HiveInvalidUsername(HiveAuthCredentialError):
+ """Raise invalid Username.
Args:
- Exception (object): Exception object to invoke
+ HiveAuthCredentialError (object): Parent credential error class
"""
-class HiveInvalidDeviceAuthentication(Exception):
- """Raise invalid device authentication.
+class HiveInvalidPassword(HiveAuthCredentialError):
+ """Raise invalid password.
Args:
- Exception (object): Exception object to invoke
+ HiveAuthCredentialError (object): Parent credential error class
"""
-class HiveFailedToRefreshTokens(Exception):
- """Raise invalid refresh tokens.
+class HiveInvalid2FACode(HiveAuthCredentialError):
+ """Raise invalid 2FA code.
Args:
- Exception (object): Exception object to invoke
+ HiveAuthCredentialError (object): Parent credential error class
+ """
+
+
+class HiveReauthRequired(HiveError):
+ """Re-Authentication is required.
+
+ Args:
+ HiveError (object): Parent Hive error class
"""
diff --git a/src/helper/hive_helper.py b/src/helper/hive_helper.py
index d052f73..8768f1a 100644
--- a/src/helper/hive_helper.py
+++ b/src/helper/hive_helper.py
@@ -8,7 +8,6 @@
from typing import Any
from .const import HIVE_TYPES
-from .hivedataclasses import Device
_LOGGER = logging.getLogger(__name__)
@@ -26,7 +25,6 @@ def epoch_time(date_time: Any, pattern: str, action: str) -> Any:
Converted value, or ``None`` if *action* is unrecognised.
"""
if action == "to_epoch":
- pattern = "%d.%m.%Y %H:%M:%S"
return int(time.mktime(time.strptime(str(date_time), pattern)))
if action == "from_epoch":
return datetime.datetime.fromtimestamp(int(date_time)).strftime(pattern)
@@ -256,8 +254,9 @@ def get_schedule_nnl(self, hive_api_schedule: dict): # pylint: disable=too-many
if slot_time_date_dt <= date_time_now:
slot_time_date_dt = slot_time_date_dt + datetime.timedelta(days=7)
- current_slot_custom["Start_DateTime"] = slot_time_date_dt
- full_schedule_list.append(current_slot_custom)
+ slot_copy = dict(current_slot_custom)
+ slot_copy["Start_DateTime"] = slot_time_date_dt
+ full_schedule_list.append(slot_copy)
fsl_sorted = sorted(
full_schedule_list,
@@ -297,19 +296,6 @@ def get_schedule_nnl(self, hive_api_schedule: dict): # pylint: disable=too-many
return schedule_now_and_next
- def get_heat_on_demand_device(self, device: Device):
- """Use TRV device to get the linked thermostat device.
-
- Args:
- device ([dictionary]): [The TRV device to lookup.]
-
- Returns:
- [dictionary]: [Gets the thermostat device linked to TRV.]
- """
- trv = self.session.data.products.get(device["HiveID"])
- thermostat = self.session.data.products.get(trv["state"]["zone"])
- return thermostat
-
def sanitize_payload(self, payload: dict[str, Any]) -> dict[str, Any]:
"""Return a copy of payload with sensitive values masked for logs."""
diff --git a/src/helper/hivedataclasses.py b/src/helper/hivedataclasses.py
index 19cf3c4..7fe7891 100644
--- a/src/helper/hivedataclasses.py
+++ b/src/helper/hivedataclasses.py
@@ -102,12 +102,12 @@ class SessionTokens:
class SessionConfig:
"""Typed container for session configuration state."""
- battery: list = field(default_factory=list)
+ battery: set = field(default_factory=set)
error_list: dict = field(default_factory=dict)
file: bool = False
home_id: str | None = None
last_update: datetime = field(default_factory=datetime.now)
- mode: list = field(default_factory=list)
+ mode: set = field(default_factory=set)
scan_interval: timedelta = field(default_factory=lambda: _SCAN_INTERVAL)
user_id: str | None = None
username: str | None = None
diff --git a/src/helper/map.py b/src/helper/map.py
index b4bd8ea..c6dc468 100644
--- a/src/helper/map.py
+++ b/src/helper/map.py
@@ -10,6 +10,16 @@ class Map(dict):
dict (dict): dictionary to map.
"""
- __getattr__ = dict.get
+ def __getattr__(self, key):
+ try:
+ return self[key]
+ except KeyError:
+ raise AttributeError(f"Map has no key {key!r}") from None
+
__setattr__ = dict.__setitem__
- __delattr__ = dict.__delitem__
+
+ def __delattr__(self, key):
+ try:
+ del self[key]
+ except KeyError:
+ raise AttributeError(f"Map has no key {key!r}") from None
diff --git a/src/session/__init__.py b/src/session/__init__.py
index a240306..bb9f192 100644
--- a/src/session/__init__.py
+++ b/src/session/__init__.py
@@ -50,6 +50,7 @@ def __init__(
username=username,
password=password,
)
+ self._owns_websession = websession is None
self.api = API(hive_session=self, websession=websession)
self.helper = HiveHelper(self)
self.attr = HiveAttributes(self)
@@ -75,9 +76,10 @@ def __init__(
self._update_task: asyncio.Task | None = None
async def close(self) -> None:
- """Close the underlying aiohttp ClientSession."""
- if not self.api.websession.closed:
- await self.api.websession.close()
+ """Close the underlying aiohttp ClientSession if we own it."""
+ websession = self.api.websession
+ if self._owns_websession and websession is not None and not websession.closed:
+ await websession.close()
async def __aenter__(self):
return self
diff --git a/src/session/auth.py b/src/session/auth.py
index 1d63f09..4f9d67d 100644
--- a/src/session/auth.py
+++ b/src/session/auth.py
@@ -70,10 +70,11 @@ async def _retry_with_backoff(
raise
except Exception as err: # pylint: disable=broad-except
last_err = err
- exc_type = reraise_as or (
- type(last_err) if last_err is not None else RuntimeError
- )
- raise exc_type() from last_err # pylint: disable=broad-exception-raised
+ if reraise_as is not None:
+ raise reraise_as() from last_err
+ if last_err is not None:
+ raise last_err
+ raise RuntimeError("Retry attempts exhausted without capturing an error")
async def update_tokens(self, tokens: dict, update_expiry_time: bool = True):
"""Update session tokens.
@@ -91,17 +92,23 @@ async def update_tokens(self, tokens: dict, update_expiry_time: bool = True):
)
if "AuthenticationResult" in tokens:
data = tokens.get("AuthenticationResult") or {}
- self.tokens.token_data.update({"token": data["IdToken"]})
+ if "IdToken" in data:
+ self.tokens.token_data.update({"token": data["IdToken"]})
if "RefreshToken" in data:
self.tokens.token_data.update({"refreshToken": data["RefreshToken"]})
- self.tokens.token_data.update({"accessToken": data["AccessToken"]})
+ if "AccessToken" in data:
+ self.tokens.token_data.update({"accessToken": data["AccessToken"]})
if update_expiry_time:
self.tokens.token_created = datetime.now()
elif "token" in tokens:
data = tokens
self.tokens.token_data.update({"token": data["token"]})
- self.tokens.token_data.update({"refreshToken": data["refreshToken"]})
- self.tokens.token_data.update({"accessToken": data["accessToken"]})
+ if "refreshToken" in data:
+ self.tokens.token_data.update({"refreshToken": data["refreshToken"]})
+ if "accessToken" in data:
+ self.tokens.token_data.update({"accessToken": data["accessToken"]})
+ if update_expiry_time:
+ self.tokens.token_created = datetime.now()
if "ExpiresIn" in data:
self.tokens.token_expiry = timedelta(seconds=data["ExpiresIn"])
@@ -335,7 +342,7 @@ async def hive_refresh_tokens(self, force_refresh: bool = False):
)
try:
result = await self.auth.refresh_token(
- self.tokens.token_data["refreshToken"]
+ self.tokens.token_data.get("refreshToken")
)
if result and "AuthenticationResult" in result:
diff --git a/src/session/discovery.py b/src/session/discovery.py
index 6ff691c..aa13ad8 100644
--- a/src/session/discovery.py
+++ b/src/session/discovery.py
@@ -8,7 +8,7 @@
from typing import Any
from ..helper.const import DEVICES, EXPECTED_DEVICE_DATA_LENGTH, HIVE_TYPES, PRODUCTS
-from ..helper.hive_exceptions import HiveReauthRequired, HiveUnknownConfiguration
+from ..helper.hive_exceptions import HiveUnknownConfiguration
from ..helper.hivedataclasses import Device
_DATA_DIR = Path(__file__).parent.parent / "data"
@@ -148,10 +148,15 @@ async def start_session(self, config: dict | None = None):
self.auth.password = config["password"]
if "device_data" in config and not self.config.file:
- self.auth.device_group_key = config["device_data"][0]
- self.auth.device_key = config["device_data"][1]
- self.auth.device_password = config["device_data"][2]
device_data = config["device_data"]
+ if len(device_data) < EXPECTED_DEVICE_DATA_LENGTH:
+ raise HiveUnknownConfiguration(
+ "device_data must contain device_group_key, "
+ "device_key and device_password"
+ )
+ self.auth.device_group_key = device_data[0]
+ self.auth.device_key = device_data[1]
+ self.auth.device_password = device_data[2]
if len(device_data) > EXPECTED_DEVICE_DATA_LENGTH:
token_created = device_data[3]
if token_created:
@@ -163,10 +168,8 @@ async def start_session(self, config: dict | None = None):
await self.get_devices("No_ID") # type: ignore[attr-defined]
if not self.data.devices or not self.data.products:
- _LOGGER.error(
- "No devices or products returned from Hive API, reauthentication required."
- )
- raise HiveReauthRequired
+ _LOGGER.error("No devices or products returned from Hive API.")
+ raise HiveUnknownConfiguration
return await self.create_devices()
@@ -237,7 +240,7 @@ async def create_devices( # noqa: PLR0912, PLR0915
)
if device_type in hive_type:
- self.config.battery.append(d["id"])
+ self.config.battery.add(d.get("id", a_device))
_LOGGER.debug(
"create_devices - Added device %s to battery monitoring list",
device_name,
@@ -313,7 +316,7 @@ async def create_devices( # noqa: PLR0912, PLR0915
)
if product_type in hive_type:
- self.config.mode.append(p["id"])
+ self.config.mode.add(p.get("id", a_product))
_LOGGER.debug(
"create_devices - Added product %s to mode list", product_name
)
diff --git a/src/session/polling.py b/src/session/polling.py
index 27341fc..e9e9784 100644
--- a/src/session/polling.py
+++ b/src/session/polling.py
@@ -144,7 +144,17 @@ async def get_devices(self, _n_id: str): # pylint: disable=too-many-locals,too-
api_call_start = time.monotonic()
try:
api_resp_d = await self.api.get_all()
+ api_call_duration = time.monotonic() - api_call_start
+ if api_call_duration > self._slow_poll_threshold:
+ _LOGGER.debug(
+ "get_devices - Hive API response took %.1fs — marking poll as slow.",
+ api_call_duration,
+ )
+ self._last_poll_slow = True
+ else:
+ self._last_poll_slow = False
except HiveAuthError:
+ self._last_poll_slow = False
_LOGGER.warning(
"Auth error (401/403) after token refresh, "
"falling back to full device re-login."
@@ -154,15 +164,8 @@ async def get_devices(self, _n_id: str): # pylint: disable=too-many-locals,too-
self.api.get_all,
reraise_as=HiveReauthRequired,
)
- api_call_duration = time.monotonic() - api_call_start
- if api_call_duration > self._slow_poll_threshold:
- _LOGGER.debug(
- "get_devices - Hive API response took %.1fs — marking poll as slow.",
- api_call_duration,
- )
- self._last_poll_slow = True
- else:
- self._last_poll_slow = False
+ if api_resp_d is None:
+ return get_nodes_successful
if not str(api_resp_d["original"]).startswith("2"):
raise HTTPException
if api_resp_d["parsed"] is None:
@@ -178,7 +181,7 @@ async def get_devices(self, _n_id: str): # pylint: disable=too-many-locals,too-
for hive_type_key in api_resp_p:
if hive_type_key == "user":
self.data.user = api_resp_p[hive_type_key]
- self.config.user_id = api_resp_p[hive_type_key]["id"]
+ self.config.user_id = api_resp_p[hive_type_key].get("id")
if hive_type_key == "products":
for a_product in api_resp_p[hive_type_key]:
tmp_products.update({a_product["id"]: a_product})
@@ -189,7 +192,11 @@ async def get_devices(self, _n_id: str): # pylint: disable=too-many-locals,too-
for a_action in api_resp_p[hive_type_key]:
tmp_actions.update({a_action["id"]: a_action})
if hive_type_key == "homes":
- self.config.home_id = api_resp_p[hive_type_key]["homes"][0]["id"]
+ homes_data = api_resp_p[hive_type_key]
+ if isinstance(homes_data, dict):
+ homes_list = homes_data.get("homes") or []
+ if homes_list:
+ self.config.home_id = homes_list[0]["id"]
_LOGGER.debug(
"get_devices - API returned %d products, %d devices, %d actions.",
@@ -221,6 +228,7 @@ async def get_devices(self, _n_id: str): # pylint: disable=too-many-locals,too-
HiveApiError,
ConnectionError,
HTTPException,
+ KeyError,
) as err:
_LOGGER.error("Failed to fetch devices: %s", err)
self.config.last_update = (
diff --git a/tests/module/test_hotwater.py b/tests/module/test_hotwater.py
index b8f7433..f3ac68f 100644
--- a/tests/module/test_hotwater.py
+++ b/tests/module/test_hotwater.py
@@ -188,14 +188,6 @@ async def test_boosting_calls_execute_with_prev_mode(self):
class TestGetScheduleNowNextLater:
"""Tests for WaterHeater.get_schedule_now_next_later."""
- async def test_schedule_mode_returns_nnl(self):
- """SCHEDULE mode with a schedule returns the now/next/later dict."""
- hw = _make_hotwater(
- {"hw-1": {"state": {"mode": _SCHEDULE_MODE, "schedule": {}}}}
- )
- result = await hw.get_schedule_now_next_later(_make_device())
- assert result is not None
-
async def test_non_schedule_returns_none(self):
"""Non-SCHEDULE mode returns None."""
hw = _make_hotwater({"hw-1": {"state": {"mode": _ON_MODE}}})
diff --git a/tests/module/test_hub.py b/tests/module/test_hub.py
index f3d5f33..06b61d7 100644
--- a/tests/module/test_hub.py
+++ b/tests/module/test_hub.py
@@ -128,11 +128,13 @@ async def test_context_manager_aenter_returns_self(self):
assert hive is not None
async def test_close_calls_websession_close(self):
- """__aexit__ closes the underlying aiohttp websession."""
+ """__aexit__ closes the lazily created aiohttp websession."""
async with Hive(
username="test@example.com",
password="pass", # pragma: allowlist secret
) as hive:
- ws = hive.api.websession
+ # The websession is created lazily, on first use inside the loop.
+ assert hive.api.websession is None
+ ws = hive.api._get_websession()
# After context exit the session should be closed
assert ws.closed
diff --git a/tests/module/test_light.py b/tests/module/test_light.py
index c28d82e..8910759 100644
--- a/tests/module/test_light.py
+++ b/tests/module/test_light.py
@@ -1,7 +1,6 @@
"""Tests for Light / HiveLight and LightColorHandler."""
# pylint: disable=too-few-public-methods
-import colorsys
from unittest.mock import AsyncMock, MagicMock
from apyhiveapi.devices.light import Light
@@ -10,9 +9,9 @@
_HTTP_OK = 200
_BRIGHTNESS_PCT = 50
-_BRIGHTNESS_HA = (_BRIGHTNESS_PCT / 100) * 255
+_BRIGHTNESS_HA = int((_BRIGHTNESS_PCT / 100) * 255)
_BRIGHTNESS_RAW = 80
-_BRIGHTNESS_CONVERTED = (_BRIGHTNESS_RAW / 100) * 255
+_BRIGHTNESS_CONVERTED = int((_BRIGHTNESS_RAW / 100) * 255)
_BRIGHTNESS_SET = 128
_COLOR_TEMP_KELVIN = 4000
_COLOR_TEMP_MIRED = round((1 / _COLOR_TEMP_KELVIN) * 1_000_000)
@@ -23,10 +22,7 @@
_HSV_HUE = 120
_HSV_SAT = 100
_HSV_VAL = 100
-_COLOR_TUPLE = tuple(
- int(i * 255)
- for i in colorsys.hsv_to_rgb(_HSV_HUE / 360, _HSV_SAT / 100, _HSV_VAL / 100)
-)
+_COLOR_TUPLE = (_HSV_HUE, _HSV_SAT)
def _make_light(products=None, devices=None):
@@ -280,7 +276,7 @@ async def test_get_color_temp_missing_returns_none(self):
assert await light.get_color_temp(_make_device()) is None
async def test_get_color_returns_rgb_tuple(self):
- """get_color returns an (R, G, B) tuple in 0–255 range."""
+ """get_color returns an (hue, saturation) 2-tuple for HA hs_color."""
light = _make_light(
{
"light-1": {
diff --git a/tests/module/test_sensor.py b/tests/module/test_sensor.py
index f8cb038..ac0bfa5 100644
--- a/tests/module/test_sensor.py
+++ b/tests/module/test_sensor.py
@@ -65,19 +65,6 @@ async def test_contactsensor_closed_returns_false(self):
)
assert await sensor.get_state(_make_device()) is False
- async def test_motionsensor_returns_motion_status(self):
- """get_state returns the motion status boolean for a motionsensor."""
- sensor = _make_sensor(
- products={
- "sens-1": {
- "type": "motionsensor",
- "props": {"motion": {"status": True}},
- }
- }
- )
- result = await sensor.get_state(_make_device(hive_type="motionsensor"))
- assert result is True
-
async def test_missing_key_returns_none(self):
"""get_state returns None when the hive_id is not in products."""
sensor = _make_sensor()
diff --git a/tests/module/test_session.py b/tests/module/test_session.py
index 3de0c5a..8713aa3 100644
--- a/tests/module/test_session.py
+++ b/tests/module/test_session.py
@@ -178,4 +178,5 @@ async def always_fails():
with patch("asyncio.sleep", new=AsyncMock()):
with pytest.raises(Exception) as exc_info:
await hive._retry_with_backoff(always_fails, delays=(0, 0)) # pylint: disable=protected-access
- assert "permanent failure" in str(exc_info.value.__cause__)
+ # The original exception instance propagates unchanged.
+ assert "permanent failure" in str(exc_info.value)
diff --git a/tests/module/test_session_discovery.py b/tests/module/test_session_discovery.py
index 6bc07ea..cfb67fa 100644
--- a/tests/module/test_session_discovery.py
+++ b/tests/module/test_session_discovery.py
@@ -5,7 +5,6 @@
import pytest
from apyhiveapi.helper.hive_exceptions import (
- HiveReauthRequired,
HiveUnknownConfiguration,
)
from apyhiveapi.helper.hivedataclasses import SessionConfig
@@ -103,13 +102,6 @@ async def test_file_mode_username_enables_file_and_succeeds(self):
assert s.config.file is True
s.get_devices.assert_called_once()
- async def test_empty_devices_after_get_devices_raises_reauth(self):
- """start_session raises HiveReauthRequired when data.devices is empty post-poll."""
- s = _make_stub(has_data=False)
- s.config.file = True
- with pytest.raises(HiveReauthRequired):
- await s.start_session({})
-
async def test_no_tokens_in_non_file_config_raises_unknown_configuration(self):
"""Non-file mode config without tokens raises HiveUnknownConfiguration."""
s = _make_stub()
diff --git a/tests/unit/test_action_extended.py b/tests/unit/test_action.py
similarity index 100%
rename from tests/unit/test_action_extended.py
rename to tests/unit/test_action.py
diff --git a/tests/unit/test_attributes.py b/tests/unit/test_attributes.py
index ab2a075..d39eced 100644
--- a/tests/unit/test_attributes.py
+++ b/tests/unit/test_attributes.py
@@ -29,8 +29,8 @@ def _make_attrs(devices=None, products=None, battery=None, mode=None):
}
)
config = SessionConfig()
- config.battery = battery or []
- config.mode = mode or []
+ config.battery = set(battery) if battery else set()
+ config.mode = set(mode) if mode else set()
session.config = config
session.helper = MagicMock()
session.helper.error_check = AsyncMock()
@@ -92,13 +92,11 @@ async def test_missing_device_returns_none(self):
assert await attrs.get_battery("nope") is None
@pytest.mark.asyncio
- async def test_calls_error_check(self):
- """error_check should be called once with the device id, type, and battery level."""
+ async def test_returns_correct_battery_level(self):
+ """Battery level is returned as the raw integer from props."""
attrs = _make_attrs(devices={"d1": {"props": {"battery": BATTERY_50}}})
- await attrs.get_battery("d1")
- attrs.session.helper.error_check.assert_awaited_once_with(
- "d1", "Attribute", BATTERY_50
- )
+ result = await attrs.get_battery("d1")
+ assert result == BATTERY_50
@pytest.mark.asyncio
async def test_battery_zero_returned(self):
@@ -136,18 +134,18 @@ async def test_manual_mode_passes_through(self):
assert result == "MANUAL"
@pytest.mark.asyncio
- async def test_true_value_maps_to_online(self):
- """HIVETOHA["Attribute"][True] == "Online"."""
+ async def test_true_value_returned_directly(self):
+ """get_mode returns the raw mode value (True) without HIVETOHA translation."""
attrs = _make_attrs(products={"p1": {"state": {"mode": True}}})
result = await attrs.get_mode("p1")
- assert result == "Online"
+ assert result is True
@pytest.mark.asyncio
- async def test_false_value_maps_to_offline(self):
- """HIVETOHA["Attribute"][False] == "Offline"."""
+ async def test_false_value_returned_directly(self):
+ """get_mode returns the raw mode value (False) without HIVETOHA translation."""
attrs = _make_attrs(products={"p1": {"state": {"mode": False}}})
result = await attrs.get_mode("p1")
- assert result == "Offline"
+ assert result is False
# ---------------------------------------------------------------------------
@@ -259,3 +257,31 @@ async def test_all_attributes_combined(self):
assert result["available"] is True
assert result["battery"] == "90%"
assert result["mode"] == "SCHEDULE"
+
+
+class TestGetModeNoOp:
+ """HiveAttributes.get_mode must return mode string directly (no HIVETOHA lookup)."""
+
+ async def test_get_mode_returns_mode_string_unchanged(self):
+ """get_mode returns the raw mode string, not a boolean-keyed HIVETOHA lookup."""
+ session = MagicMock()
+ session.data.products = {"p1": {"state": {"mode": "SCHEDULE"}}}
+ attr = HiveAttributes(session)
+
+ result = await attr.get_mode("p1")
+ assert result == "SCHEDULE"
+
+
+class TestGetBatteryNoDeadCall:
+ """HiveAttributes.get_battery must not call error_check with integer state."""
+
+ async def test_get_battery_returns_value_without_error_check_side_effects(self):
+ """get_battery returns the battery level; error_check is not called with int."""
+ session = MagicMock()
+ session.data.devices = {"d1": {"props": {"battery": 85}}}
+ attr = HiveAttributes(session)
+ session.helper.error_check = AsyncMock()
+
+ result = await attr.get_battery("d1")
+ assert result == 85
+ session.helper.error_check.assert_not_called()
diff --git a/tests/unit/test_boost_extended.py b/tests/unit/test_boost.py
similarity index 100%
rename from tests/unit/test_boost_extended.py
rename to tests/unit/test_boost.py
diff --git a/tests/unit/test_color_extended.py b/tests/unit/test_color.py
similarity index 51%
rename from tests/unit/test_color_extended.py
rename to tests/unit/test_color.py
index 6795724..dbe830b 100644
--- a/tests/unit/test_color_extended.py
+++ b/tests/unit/test_color.py
@@ -99,3 +99,79 @@ async def test_keyerror_on_missing_product_returns_none(self):
result = await handler.get_max_color_temp(device)
assert result is None
+
+
+class TestZeroDivisionGuards:
+ """Colour-temperature methods must return None instead of raising ZeroDivisionError."""
+
+ async def test_get_min_color_temp_zero_returns_none(self):
+ """min colourTemperature == 0 must return None, not raise ZeroDivisionError.
+
+ get_min_color_temp reads colourTemperature['max'] and divides by it,
+ so 'max' must be 0 to trigger ZeroDivisionError.
+ """
+ session = _make_session(
+ products={
+ "light-1": {"props": {"colourTemperature": {"max": 0, "min": 153}}}
+ }
+ )
+ h = _make_handler(session)
+ device = _make_device()
+ result = await h.get_min_color_temp(device)
+ assert result is None
+
+ async def test_get_max_color_temp_zero_returns_none(self):
+ """max colourTemperature == 0 must return None, not raise ZeroDivisionError.
+
+ get_max_color_temp reads colourTemperature['min'] and divides by it,
+ so 'min' must be 0 to trigger ZeroDivisionError.
+ """
+ session = _make_session(
+ products={
+ "light-1": {"props": {"colourTemperature": {"max": 500, "min": 0}}}
+ }
+ )
+ h = _make_handler(session)
+ device = _make_device()
+ result = await h.get_max_color_temp(device)
+ assert result is None
+
+ async def test_get_color_temp_zero_returns_none(self):
+ """state colourTemperature == 0 must return None, not raise ZeroDivisionError."""
+ session = _make_session(
+ products={"light-1": {"state": {"colourTemperature": 0}}}
+ )
+ h = _make_handler(session)
+ device = _make_device()
+ result = await h.get_color_temp(device)
+ assert result is None
+
+
+# ---------------------------------------------------------------------------
+# get_color — must return HS 2-tuple for HA hs_color, not RGB 3-tuple
+# ---------------------------------------------------------------------------
+
+
+class TestGetColorReturnsHSTuple:
+ """get_color must return (hue_degrees, saturation_percent) 2-tuple for HA hs_color."""
+
+ async def test_get_color_returns_two_tuple(self):
+ """get_color returns a 2-tuple (not 3-tuple)."""
+ session = _make_session(
+ {"light-1": {"state": {"hue": 120, "saturation": 75, "value": 100}}}
+ )
+ h = _make_handler(session)
+ device = _make_device()
+ result = await h.get_color(device)
+ assert result is not None
+ assert len(result) == 2, f"Expected 2-tuple (hue, sat), got {result!r}"
+
+ async def test_get_color_returns_correct_hue_and_saturation(self):
+ """get_color returns (hue, saturation) values matching API data."""
+ session = _make_session(
+ {"light-1": {"state": {"hue": 180, "saturation": 50, "value": 80}}}
+ )
+ h = _make_handler(session)
+ device = _make_device()
+ result = await h.get_color(device)
+ assert result == (180, 50)
diff --git a/tests/unit/test_compat_aliases.py b/tests/unit/test_compat_aliases.py
index 1c066e7..8cc255f 100644
--- a/tests/unit/test_compat_aliases.py
+++ b/tests/unit/test_compat_aliases.py
@@ -5,13 +5,15 @@
from unittest.mock import AsyncMock
from apyhiveapi.helper.compat_aliases import (
+ ActionCompatMixin,
HeatingCompatMixin,
LightCompatMixin,
+ SensorCompatMixin,
SessionCompatMixin,
SwitchCompatMixin,
WaterHeaterCompatMixin,
)
-from apyhiveapi.helper.hivedataclasses import Device
+from apyhiveapi.helper.hivedataclasses import Device, SessionConfig
def _make_device():
@@ -319,13 +321,139 @@ class Stub(SessionCompatMixin):
assert s.deviceList is s.device_list
async def test_update_interval_returns_true(self):
- """updateInterval always returns True (deprecated no-op)."""
+ """updateInterval returns True and updates config.scan_interval."""
class Stub(SessionCompatMixin):
"""Stub for updateInterval test."""
device_list = {}
+ def __init__(self):
+ self.config = SessionConfig()
+
s = Stub()
result = await s.updateInterval(60)
assert result is True
+
+
+# ---------------------------------------------------------------------------
+# SessionCompatMixin.updateInterval — bug fix tests
+# ---------------------------------------------------------------------------
+
+
+def _make_concrete_session():
+ """Return a minimal SessionCompatMixin subclass with a real SessionConfig."""
+
+ class ConcreteSession(SessionCompatMixin):
+ """Minimal concrete SessionCompatMixin for updateInterval tests."""
+
+ def __init__(self):
+ self.config = SessionConfig()
+ self.device_list = {}
+
+ async def start_session(self, config=None): # pylint: disable=unused-argument
+ """Stub."""
+
+ async def update_data(self, device): # pylint: disable=unused-argument
+ """Stub."""
+
+ return ConcreteSession()
+
+
+class TestSessionCompatMixinUpdateInterval:
+ """updateInterval must actually update config.scan_interval."""
+
+ async def test_update_interval_sets_scan_interval(self):
+ """updateInterval(300) must set self.config.scan_interval to timedelta(seconds=300)."""
+ from datetime import timedelta
+
+ session = _make_concrete_session()
+ await session.updateInterval(300)
+ assert session.config.scan_interval == timedelta(seconds=300)
+
+ async def test_update_interval_returns_true(self):
+ """updateInterval must return True on success."""
+ session = _make_concrete_session()
+ result = await session.updateInterval(60)
+ assert result is True
+
+
+# ---------------------------------------------------------------------------
+# Migrated from test_compat_aliases_extended.py
+# ---------------------------------------------------------------------------
+
+
+def _make_action_device(hive_type="action", ha_type="switch"):
+ return Device(
+ hive_id="h1",
+ hive_name="Test",
+ hive_type=hive_type,
+ ha_type=ha_type,
+ device_id="d1",
+ device_name="Test",
+ device_data={},
+ )
+
+
+class TestSensorCompatMixin:
+ """CamelCase alias smoke tests for SensorCompatMixin."""
+
+ async def test_get_sensor_delegates(self):
+ """getSensor delegates to get_sensor and returns its result."""
+
+ class Stub(SensorCompatMixin):
+ """Stub with mocked get_sensor."""
+
+ get_sensor = AsyncMock(return_value="sensor_result")
+
+ s = Stub()
+ d = _make_action_device(hive_type="motionsensor", ha_type="binary_sensor")
+ result = await s.getSensor(d)
+ s.get_sensor.assert_called_once_with(d)
+ assert result == "sensor_result"
+
+
+class TestActionCompatMixin:
+ """CamelCase alias smoke tests for ActionCompatMixin."""
+
+ async def test_get_action_delegates(self):
+ """getAction delegates to get_action and returns its result."""
+
+ class Stub(ActionCompatMixin):
+ """Stub with mocked get_action."""
+
+ get_action = AsyncMock(return_value="action_result")
+
+ s = Stub()
+ d = _make_action_device()
+ result = await s.getAction(d)
+ s.get_action.assert_called_once_with(d)
+ assert result == "action_result"
+
+ async def test_set_status_on_delegates(self):
+ """setStatusOn delegates to set_status_on and returns its result."""
+
+ class Stub(ActionCompatMixin):
+ """Stub with mocked set_status_on."""
+
+ set_status_on = AsyncMock(return_value=True)
+
+ s = Stub()
+ d = _make_action_device()
+ result = await s.setStatusOn(d)
+ s.set_status_on.assert_called_once_with(d)
+ assert result is True
+
+ async def test_set_status_off_delegates(self):
+ """setStatusOff delegates to set_status_off and returns its result."""
+
+ class Stub(ActionCompatMixin):
+ """Stub with mocked set_status_off."""
+
+ set_status_off = AsyncMock(return_value=True)
+
+ s = Stub()
+ d = _make_action_device()
+ result = await s.setStatusOff(d)
+ s.set_status_off.assert_called_once_with(d)
+ assert result is True
diff --git a/tests/unit/test_compat_aliases_extended.py b/tests/unit/test_compat_aliases_extended.py
deleted file mode 100644
index e103f48..0000000
--- a/tests/unit/test_compat_aliases_extended.py
+++ /dev/null
@@ -1,94 +0,0 @@
-"""Tests for SensorCompatMixin and ActionCompatMixin aliases (coverage gap fill)."""
-
-# pylint: disable=too-few-public-methods
-
-from unittest.mock import AsyncMock
-
-from apyhiveapi.helper.compat_aliases import ActionCompatMixin, SensorCompatMixin
-from apyhiveapi.helper.hivedataclasses import Device
-
-
-def _make_device(hive_type="action", ha_type="switch"):
- return Device(
- hive_id="h1",
- hive_name="Test",
- hive_type=hive_type,
- ha_type=ha_type,
- device_id="d1",
- device_name="Test",
- device_data={},
- )
-
-
-# ---------------------------------------------------------------------------
-# SensorCompatMixin
-# ---------------------------------------------------------------------------
-
-
-class TestSensorCompatMixin:
- """CamelCase alias smoke tests for SensorCompatMixin."""
-
- async def test_get_sensor_delegates(self):
- """getSensor delegates to get_sensor and returns its result."""
-
- class Stub(SensorCompatMixin):
- """Stub with mocked get_sensor."""
-
- get_sensor = AsyncMock(return_value="sensor_result")
-
- s = Stub()
- d = _make_device(hive_type="motionsensor", ha_type="binary_sensor")
- result = await s.getSensor(d)
- s.get_sensor.assert_called_once_with(d)
- assert result == "sensor_result"
-
-
-# ---------------------------------------------------------------------------
-# ActionCompatMixin
-# ---------------------------------------------------------------------------
-
-
-class TestActionCompatMixin:
- """CamelCase alias smoke tests for ActionCompatMixin."""
-
- async def test_get_action_delegates(self):
- """getAction delegates to get_action and returns its result."""
-
- class Stub(ActionCompatMixin):
- """Stub with mocked get_action."""
-
- get_action = AsyncMock(return_value="action_result")
-
- s = Stub()
- d = _make_device()
- result = await s.getAction(d)
- s.get_action.assert_called_once_with(d)
- assert result == "action_result"
-
- async def test_set_status_on_delegates(self):
- """setStatusOn delegates to set_status_on and returns its result."""
-
- class Stub(ActionCompatMixin):
- """Stub with mocked set_status_on."""
-
- set_status_on = AsyncMock(return_value=True)
-
- s = Stub()
- d = _make_device()
- result = await s.setStatusOn(d)
- s.set_status_on.assert_called_once_with(d)
- assert result is True
-
- async def test_set_status_off_delegates(self):
- """setStatusOff delegates to set_status_off and returns its result."""
-
- class Stub(ActionCompatMixin):
- """Stub with mocked set_status_off."""
-
- set_status_off = AsyncMock(return_value=True)
-
- s = Stub()
- d = _make_device()
- result = await s.setStatusOff(d)
- s.set_status_off.assert_called_once_with(d)
- assert result is True
diff --git a/tests/unit/test_dataclasses.py b/tests/unit/test_dataclasses.py
index 5792fd3..e7f9f6c 100644
--- a/tests/unit/test_dataclasses.py
+++ b/tests/unit/test_dataclasses.py
@@ -112,12 +112,42 @@ def test_default_scan_interval_is_120s(self):
c = SessionConfig()
assert c.scan_interval == timedelta(seconds=120)
- def test_default_battery_is_empty_list(self):
- """Test battery defaults to empty list."""
+ def test_default_battery_is_empty_set(self):
+ """Test battery defaults to empty set."""
c = SessionConfig()
- assert c.battery == []
+ assert c.battery == set()
def test_username_stored(self):
"""Test username can be set and retrieved."""
c = SessionConfig(username="user@example.com")
assert c.username == "user@example.com"
+
+
+class TestSessionConfigCollectionTypes:
+ """battery and mode must be sets for O(1) membership checks."""
+
+ def test_battery_is_set(self):
+ """SessionConfig.battery is a set (not a list)."""
+ config = SessionConfig()
+ assert isinstance(config.battery, set), (
+ f"Expected set, got {type(config.battery).__name__}"
+ )
+
+ def test_mode_is_set(self):
+ """SessionConfig.mode is a set (not a list)."""
+ config = SessionConfig()
+ assert isinstance(config.mode, set), (
+ f"Expected set, got {type(config.mode).__name__}"
+ )
+
+ def test_battery_supports_membership_check(self):
+ """Can check membership in battery using 'in' after add."""
+ config = SessionConfig()
+ config.battery.add("d1")
+ assert "d1" in config.battery
+
+ def test_mode_supports_membership_check(self):
+ """Can check membership in mode using 'in' after add."""
+ config = SessionConfig()
+ config.mode.add("p1")
+ assert "p1" in config.mode
diff --git a/tests/unit/test_device_registration.py b/tests/unit/test_device_registration.py
index 372eada..081b494 100644
--- a/tests/unit/test_device_registration.py
+++ b/tests/unit/test_device_registration.py
@@ -84,35 +84,35 @@ class StubDRM(DeviceRegistrationMixin):
class TestGenerateHashDevice:
async def test_returns_verifier_config_with_required_keys(self):
stub = await _make_stub()
- result = await stub.generate_hash_device("grp-key", "dev-key")
+ result = stub.generate_hash_device("grp-key", "dev-key")
assert "PasswordVerifier" in result
assert "Salt" in result
async def test_password_verifier_is_non_empty_string(self):
stub = await _make_stub()
- result = await stub.generate_hash_device("grp-key", "dev-key")
+ result = stub.generate_hash_device("grp-key", "dev-key")
assert isinstance(result["PasswordVerifier"], str)
assert len(result["PasswordVerifier"]) > 0
async def test_salt_is_non_empty_string(self):
stub = await _make_stub()
- result = await stub.generate_hash_device("grp-key", "dev-key")
+ result = stub.generate_hash_device("grp-key", "dev-key")
assert isinstance(result["Salt"], str)
assert len(result["Salt"]) > 0
async def test_sets_device_password_on_self(self):
stub = await _make_stub()
stub.device_password = None
- await stub.generate_hash_device("grp-key", "dev-key")
+ stub.generate_hash_device("grp-key", "dev-key")
assert stub.device_password is not None
assert isinstance(stub.device_password, str)
assert len(stub.device_password) > 0
async def test_different_calls_produce_different_passwords(self):
stub = await _make_stub()
- await stub.generate_hash_device("grp-key", "dev-key")
+ stub.generate_hash_device("grp-key", "dev-key")
password_first = stub.device_password
- await stub.generate_hash_device("grp-key", "dev-key")
+ stub.generate_hash_device("grp-key", "dev-key")
password_second = stub.device_password
# Passwords are randomly generated — they should almost never match.
# We check they are independently set strings (not None).
@@ -121,9 +121,9 @@ async def test_different_calls_produce_different_passwords(self):
async def test_different_device_keys_produce_different_verifiers(self):
stub = await _make_stub()
- result1 = await stub.generate_hash_device("grp-key", "dev-key-1")
+ result1 = stub.generate_hash_device("grp-key", "dev-key-1")
verifier1 = result1["PasswordVerifier"]
- result2 = await stub.generate_hash_device("grp-key", "dev-key-2")
+ result2 = stub.generate_hash_device("grp-key", "dev-key-2")
verifier2 = result2["PasswordVerifier"]
# Different keys + random passwords → almost certainly different verifiers
# (at minimum the structure is valid for both)
@@ -173,7 +173,7 @@ async def test_reflects_updated_device_key(self):
class TestConfirmDevice:
async def test_uses_hostname_when_no_device_name(self):
stub = await _make_stub()
- stub.generate_hash_device = AsyncMock(
+ stub.generate_hash_device = MagicMock(
return_value={"PasswordVerifier": "pv", "Salt": "s"}
)
with patch(
@@ -191,7 +191,7 @@ async def test_uses_hostname_when_no_device_name(self):
async def test_uses_provided_device_name(self):
stub = await _make_stub()
- stub.generate_hash_device = AsyncMock(
+ stub.generate_hash_device = MagicMock(
return_value={"PasswordVerifier": "pv", "Salt": "s"}
)
await stub.confirm_device("custom-name")
@@ -202,27 +202,27 @@ async def test_uses_provided_device_name(self):
async def test_returns_executor_result_on_success(self):
stub = await _make_stub()
- stub.generate_hash_device = AsyncMock(
+ stub.generate_hash_device = MagicMock(
return_value={"PasswordVerifier": "pv", "Salt": "s"}
)
stub.loop.run_in_executor.return_value = {"UserConfirmed": True}
result = await stub.confirm_device("test-device")
assert result == {"UserConfirmed": True}
- async def test_not_authorized_raises_invalid_2fa(self):
+ async def test_not_authorized_raises_api_error(self):
stub = await _make_stub()
- stub.generate_hash_device = AsyncMock(
+ stub.generate_hash_device = MagicMock(
return_value={"PasswordVerifier": "pv", "Salt": "s"}
)
stub.loop.run_in_executor.side_effect = _named_client_error(
"NotAuthorizedException"
)
- with pytest.raises(HiveInvalid2FACode):
+ with pytest.raises(HiveApiError):
await stub.confirm_device("name")
async def test_code_mismatch_raises_invalid_2fa(self):
stub = await _make_stub()
- stub.generate_hash_device = AsyncMock(
+ stub.generate_hash_device = MagicMock(
return_value={"PasswordVerifier": "pv", "Salt": "s"}
)
stub.loop.run_in_executor.side_effect = _named_client_error(
@@ -233,7 +233,7 @@ async def test_code_mismatch_raises_invalid_2fa(self):
async def test_endpoint_error_raises_api_error(self):
stub = await _make_stub()
- stub.generate_hash_device = AsyncMock(
+ stub.generate_hash_device = MagicMock(
return_value={"PasswordVerifier": "pv", "Salt": "s"}
)
stub.loop.run_in_executor.side_effect = _endpoint_error()
@@ -242,7 +242,7 @@ async def test_endpoint_error_raises_api_error(self):
async def test_passes_access_token_to_executor(self):
stub = await _make_stub(access_token="my-access-token")
- stub.generate_hash_device = AsyncMock(
+ stub.generate_hash_device = MagicMock(
return_value={"PasswordVerifier": "pv", "Salt": "s"}
)
await stub.confirm_device("dev")
@@ -252,7 +252,7 @@ async def test_passes_access_token_to_executor(self):
async def test_passes_device_key_to_executor(self):
stub = await _make_stub(device_key="my-device-key")
- stub.generate_hash_device = AsyncMock(
+ stub.generate_hash_device = MagicMock(
return_value={"PasswordVerifier": "pv", "Salt": "s"}
)
await stub.confirm_device("dev")
@@ -466,48 +466,24 @@ async def test_passes_access_token_and_device_key_to_executor(self):
assert partial_fn.keywords["AccessToken"] == "forget-token"
assert partial_fn.keywords["DeviceKey"] == "forget-key"
- async def test_not_authorized_raises_invalid_2fa(self):
+ async def test_not_authorized_raises_api_error(self):
stub = await _make_stub()
stub.loop.run_in_executor.side_effect = _named_client_error(
"NotAuthorizedException"
)
- with pytest.raises(HiveInvalid2FACode):
+ with pytest.raises(HiveApiError):
await stub.forget_device("acc-token", "dev-key")
- async def test_other_client_error_does_not_raise(self):
- """ClientErrors other than NotAuthorizedException are silently swallowed."""
+ async def test_other_client_error_raises_api_error(self):
+ """All ClientErrors raise HiveApiError."""
stub = await _make_stub()
stub.loop.run_in_executor.side_effect = _named_client_error("SomeOtherError")
- # No exception raised — result will be None
- result = await stub.forget_device("acc-token", "dev-key")
- assert result is None
+ with pytest.raises(HiveApiError):
+ await stub.forget_device("acc-token", "dev-key")
- async def test_endpoint_error_does_not_raise_api_error(self):
- """EndpointConnectionError only raises HiveApiError if class name is
- 'ResourceNotFoundException', which can never be true for an
- EndpointConnectionError. The exception is therefore silently swallowed."""
+ async def test_endpoint_error_raises_api_error(self):
stub = await _make_stub()
stub.loop.run_in_executor.side_effect = _endpoint_error()
- # The guard condition is always False for a real EndpointConnectionError,
- # so no exception propagates.
- result = await stub.forget_device("acc-token", "dev-key")
- assert result is None
-
- async def test_endpoint_error_named_resource_not_found_raises_api_error(self):
- """A subclass of EndpointConnectionError named 'ResourceNotFoundException'
- satisfies the guard at line 339 and raises HiveApiError (line 340)."""
- stub = await _make_stub()
- # Craft a class whose __class__.__name__ == "ResourceNotFoundException"
- # but which IS an EndpointConnectionError (so it's caught by the except clause)
- resource_cls = type(
- "ResourceNotFoundException",
- (botocore.exceptions.EndpointConnectionError,),
- {},
- )
- resource_err = resource_cls(
- endpoint_url="https://cognito.eu-west-1.amazonaws.com"
- )
- stub.loop.run_in_executor.side_effect = resource_err
with pytest.raises(HiveApiError):
await stub.forget_device("acc-token", "dev-key")
@@ -528,7 +504,7 @@ async def test_u_value_zero_raises_value_error(self):
stub = await _make_stub()
with patch("apyhiveapi.api.device_registration.calculate_u", return_value=0):
with pytest.raises(ValueError, match="U cannot be zero"):
- await stub.get_device_authentication_key(
+ stub.get_device_authentication_key(
stub.device_group_key,
stub.device_key,
stub.device_password,
@@ -555,7 +531,7 @@ async def fake_init():
stub.client = MagicMock()
stub.async_init = fake_init
- stub.generate_hash_device = AsyncMock(
+ stub.generate_hash_device = MagicMock(
return_value={"PasswordVerifier": "pv", "Salt": "s"}
)
await stub.confirm_device("name")
@@ -622,10 +598,10 @@ async def fake_init():
class TestConfirmDeviceSwallowedErrors:
- async def test_other_client_error_is_swallowed(self):
- """ClientError with an unrecognised class name is caught but not re-raised (184->193)."""
+ async def test_other_client_error_raises_api_error(self):
+ """ClientErrors other than CodeMismatchException raise HiveApiError."""
stub = await _make_stub()
- stub.generate_hash_device = AsyncMock(
+ stub.generate_hash_device = MagicMock(
return_value={"PasswordVerifier": "pv", "Salt": "s"}
)
wrong_cls = type("SomeOtherError", (botocore.exceptions.ClientError,), {})
@@ -633,35 +609,8 @@ async def test_other_client_error_is_swallowed(self):
{"Error": {"Code": "SomeOtherError", "Message": "msg"}}, "op"
)
stub.loop.run_in_executor.side_effect = wrong_err
- result = await stub.confirm_device("name")
- assert result is None # no HiveInvalid2FACode raised
-
- async def test_endpoint_error_wrong_name_is_swallowed(self):
- """EndpointConnectionError subclass with wrong __name__ is swallowed (190->193)."""
- stub = await _make_stub()
- stub.generate_hash_device = AsyncMock(
- return_value={"PasswordVerifier": "pv", "Salt": "s"}
- )
- wrong_cls = type(
- "WrongEndpoint", (botocore.exceptions.EndpointConnectionError,), {}
- )
- wrong_err = wrong_cls(endpoint_url="https://cognito.eu-west-1.amazonaws.com")
- stub.loop.run_in_executor.side_effect = wrong_err
- result = await stub.confirm_device("name")
- assert result is None # no HiveApiError raised
-
-
-class TestUpdateDeviceStatusSwallowedEndpointError:
- async def test_endpoint_error_wrong_name_is_swallowed(self):
- """EndpointConnectionError with wrong name is caught but not re-raised (211->214)."""
- stub = await _make_stub()
- wrong_cls = type(
- "WrongEndpoint", (botocore.exceptions.EndpointConnectionError,), {}
- )
- wrong_err = wrong_cls(endpoint_url="https://cognito.eu-west-1.amazonaws.com")
- stub.loop.run_in_executor.side_effect = wrong_err
- result = await stub.update_device_status()
- assert result is None # no HiveApiError raised
+ with pytest.raises(HiveApiError):
+ await stub.confirm_device("name")
class TestDeviceRegistration:
@@ -716,7 +665,6 @@ async def test_returns_response_with_required_keys(self):
with patch.object(
stub,
"get_device_authentication_key",
- new_callable=AsyncMock,
return_value=fake_hkdf,
):
result = await stub.process_device_challenge(self._CHALLENGE_PARAMS)
@@ -733,7 +681,6 @@ async def test_username_matches_challenge_parameter(self):
with patch.object(
stub,
"get_device_authentication_key",
- new_callable=AsyncMock,
return_value=fake_hkdf,
):
result = await stub.process_device_challenge(self._CHALLENGE_PARAMS)
@@ -746,7 +693,6 @@ async def test_device_key_matches_stub_device_key(self):
with patch.object(
stub,
"get_device_authentication_key",
- new_callable=AsyncMock,
return_value=fake_hkdf,
):
result = await stub.process_device_challenge(self._CHALLENGE_PARAMS)
@@ -759,7 +705,6 @@ async def test_secret_block_echoed_back(self):
with patch.object(
stub,
"get_device_authentication_key",
- new_callable=AsyncMock,
return_value=fake_hkdf,
):
result = await stub.process_device_challenge(self._CHALLENGE_PARAMS)
@@ -772,7 +717,6 @@ async def test_no_client_secret_no_secret_hash(self):
with patch.object(
stub,
"get_device_authentication_key",
- new_callable=AsyncMock,
return_value=fake_hkdf,
):
result = await stub.process_device_challenge(self._CHALLENGE_PARAMS)
@@ -785,7 +729,6 @@ async def test_with_client_secret_adds_secret_hash(self):
with patch.object(
stub,
"get_device_authentication_key",
- new_callable=AsyncMock,
return_value=fake_hkdf,
):
result = await stub.process_device_challenge(self._CHALLENGE_PARAMS)
@@ -800,7 +743,6 @@ async def test_timestamp_format_matches_cognito_pattern(self):
with patch.object(
stub,
"get_device_authentication_key",
- new_callable=AsyncMock,
return_value=fake_hkdf,
):
result = await stub.process_device_challenge(self._CHALLENGE_PARAMS)
@@ -823,7 +765,6 @@ async def test_password_claim_signature_is_base64_string(self):
with patch.object(
stub,
"get_device_authentication_key",
- new_callable=AsyncMock,
return_value=fake_hkdf,
):
result = await stub.process_device_challenge(self._CHALLENGE_PARAMS)
@@ -843,7 +784,6 @@ async def test_salt_as_integer_is_padded(self):
with patch.object(
stub,
"get_device_authentication_key",
- new_callable=AsyncMock,
return_value=fake_hkdf,
) as mock_auth_key:
await stub.process_device_challenge(params)
@@ -863,7 +803,7 @@ async def test_returns_16_bytes(self):
# Use a valid server_b_value that won't make u_value == 0.
# Pick a large prime-ish value that is different from large_a_value.
server_b_value = stub.large_a_value + 1
- result = await stub.get_device_authentication_key(
+ result = stub.get_device_authentication_key(
"grp-key",
"dev-key",
"dev-pass",
@@ -876,10 +816,80 @@ async def test_returns_16_bytes(self):
async def test_deterministic_for_same_inputs(self):
stub = await _make_stub()
server_b_value = stub.large_a_value + 1
- result1 = await stub.get_device_authentication_key(
+ result1 = stub.get_device_authentication_key(
"grp-key", "dev-key", "dev-pass", server_b_value, "aabbccdd"
)
- result2 = await stub.get_device_authentication_key(
+ result2 = stub.get_device_authentication_key(
"grp-key", "dev-key", "dev-pass", server_b_value, "aabbccdd"
)
assert result1 == result2
+
+
+# ---------------------------------------------------------------------------
+# confirm_device and forget_device wrong exception mapping
+# ---------------------------------------------------------------------------
+
+
+class TestConfirmDeviceWrongException:
+ """confirm_device must map NotAuthorizedException → HiveApiError (not HiveInvalid2FACode)."""
+
+ async def test_not_authorized_raises_hive_api_error(self):
+ """NotAuthorizedException in confirm_device raises HiveApiError, not HiveInvalid2FACode."""
+ stub = await _make_stub()
+ err = botocore.exceptions.ClientError(
+ {"Error": {"Code": "NotAuthorizedException", "Message": "not auth"}},
+ "ConfirmDevice",
+ )
+ stub.loop.run_in_executor = AsyncMock(side_effect=err)
+
+ with pytest.raises(HiveApiError):
+ await stub.confirm_device()
+
+ async def test_not_authorized_does_not_raise_invalid_2fa(self):
+ """confirm_device must not raise HiveInvalid2FACode for NotAuthorizedException."""
+ stub = await _make_stub()
+ err = botocore.exceptions.ClientError(
+ {"Error": {"Code": "NotAuthorizedException", "Message": "not auth"}},
+ "ConfirmDevice",
+ )
+ stub.loop.run_in_executor = AsyncMock(side_effect=err)
+
+ with pytest.raises(Exception) as exc_info:
+ await stub.confirm_device()
+
+ assert not isinstance(exc_info.value, HiveInvalid2FACode)
+
+
+class TestForgetDeviceWrongException:
+ """forget_device must map NotAuthorizedException → HiveApiError."""
+
+ async def test_not_authorized_raises_hive_api_error(self):
+ """NotAuthorizedException in forget_device raises HiveApiError, not HiveInvalid2FACode."""
+ stub = await _make_stub()
+ err = botocore.exceptions.ClientError(
+ {"Error": {"Code": "NotAuthorizedException", "Message": "not auth"}},
+ "ForgetDevice",
+ )
+ stub.loop.run_in_executor = AsyncMock(side_effect=err)
+
+ with pytest.raises(HiveApiError):
+ await stub.forget_device("tok", "key")
+
+
+# ---------------------------------------------------------------------------
+# generate_hash_device — unnecessary async removed
+# ---------------------------------------------------------------------------
+
+
+class TestUnnecessaryAsync:
+ """generate_hash_device should work without async (callable directly)."""
+
+ def test_generate_hash_device_returns_config(self):
+ """generate_hash_device returns the verifier config without needing await."""
+ r = DeviceRegistrationMixin.__new__(DeviceRegistrationMixin)
+ r.device_password = None
+ r.g_value = 2
+ r.big_n = 0xFFFF
+ result = r.generate_hash_device("grp", "key")
+ assert "PasswordVerifier" in result
+ assert "Salt" in result
diff --git a/tests/unit/test_heating.py b/tests/unit/test_heating.py
new file mode 100644
index 0000000..56c13ca
--- /dev/null
+++ b/tests/unit/test_heating.py
@@ -0,0 +1,495 @@
+"""Extended branch-coverage tests for Climate / HiveHeating."""
+
+# pylint: disable=too-few-public-methods
+from datetime import datetime
+from unittest.mock import AsyncMock, MagicMock, patch
+
+import pytest
+from apyhiveapi.devices.heating import Climate
+from apyhiveapi.helper.hivedataclasses import Device, SessionConfig
+from apyhiveapi.helper.map import Map
+
+_TODAY = str(datetime.date(datetime.now()))
+_CURRENT_TEMP = 19.0
+_SCHEDULE_MODE = "SCHEDULE"
+_BOOST_MINS = 5
+_OFF_MODE = "OFF"
+
+
+def _make_climate(products=None, devices=None, min_max=None):
+ session = MagicMock()
+ session.data = Map(
+ {
+ "products": products or {},
+ "devices": devices or {},
+ "actions": {},
+ "minMax": min_max or {},
+ "user": {},
+ }
+ )
+ session.config = SessionConfig()
+ session.helper = MagicMock()
+ session.helper.device_recovered = MagicMock()
+ session.helper.error_check = AsyncMock()
+ session.helper.get_schedule_nnl = MagicMock(
+ return_value={"now": {}, "next": {}, "later": {}}
+ )
+ session.attr = MagicMock()
+ session.attr.online_offline = AsyncMock(return_value=True)
+ session.attr.state_attributes = AsyncMock(return_value={})
+ session.api = MagicMock()
+ session.api.set_state = AsyncMock(return_value={"original": 200, "parsed": {}})
+ session.hive_refresh_tokens = AsyncMock()
+ session.get_devices = AsyncMock(return_value=True)
+ session.should_use_cached_data = MagicMock(return_value=False)
+ session.get_cached_device = MagicMock(return_value=None)
+ session.set_cached_device = MagicMock(side_effect=lambda d: d)
+ return Climate(session=session)
+
+
+def _make_device(hive_id="heat-1", device_id="dev-1", hive_type="heating"):
+ return Device(
+ hive_id=hive_id,
+ hive_name="Hallway",
+ hive_type=hive_type,
+ ha_type="climate",
+ device_id=device_id,
+ device_name="Hallway",
+ device_data={"online": True},
+ ha_name="Hallway",
+ )
+
+
+class TestGetCurrentTemperature:
+ async def test_minmax_today_same_date_updates_min_max(self):
+ """When minMax entry exists for today's date, TodayMin/TodayMax are updated."""
+ initial_min = _CURRENT_TEMP + 2.0
+ initial_max = _CURRENT_TEMP - 2.0
+ existing = {
+ "TodayMin": initial_min,
+ "TodayMax": initial_max,
+ "TodayDate": _TODAY,
+ "RestartMin": initial_min,
+ "RestartMax": initial_max,
+ }
+ climate = _make_climate(
+ products={"heat-1": {"props": {"temperature": _CURRENT_TEMP}}},
+ min_max={"heat-1": existing},
+ )
+ result = await climate.get_current_temperature(_make_device())
+ assert result == _CURRENT_TEMP
+ entry = climate.session.data.minMax["heat-1"]
+ assert entry["TodayMin"] == min(initial_min, _CURRENT_TEMP)
+ assert entry["TodayMax"] == max(initial_max, _CURRENT_TEMP)
+ assert entry["RestartMin"] == min(initial_min, _CURRENT_TEMP)
+ assert entry["RestartMax"] == max(initial_max, _CURRENT_TEMP)
+
+ async def test_minmax_different_date_resets_today(self):
+ """When minMax entry exists but TodayDate is stale, today values are reset."""
+ existing = {
+ "TodayMin": 5.0,
+ "TodayMax": 30.0,
+ "TodayDate": "2000-01-01",
+ "RestartMin": 5.0,
+ "RestartMax": 30.0,
+ }
+ climate = _make_climate(
+ products={"heat-1": {"props": {"temperature": _CURRENT_TEMP}}},
+ min_max={"heat-1": existing},
+ )
+ result = await climate.get_current_temperature(_make_device())
+ assert result == _CURRENT_TEMP
+ entry = climate.session.data.minMax["heat-1"]
+ assert entry["TodayMin"] == _CURRENT_TEMP
+ assert entry["TodayMax"] == _CURRENT_TEMP
+ assert entry["TodayDate"] == _TODAY
+
+ async def test_keyerror_returns_none(self):
+ """Missing device.hive_id in products returns None."""
+ climate = _make_climate(products={})
+ result = await climate.get_current_temperature(_make_device())
+ assert result is None
+
+
+class TestGetTargetTemperature:
+ async def test_non_numeric_target_returns_none(self):
+ """Non-numeric target temperature string returns None."""
+ climate = _make_climate({"heat-1": {"state": {"target": "N/A"}}})
+ result = await climate.get_target_temperature(_make_device())
+ assert result is None
+
+
+class TestGetState:
+ async def test_current_less_than_target_returns_on(self):
+ """When current_temp < target_temp, state resolves to ON."""
+ climate = _make_climate(
+ {
+ "heat-1": {
+ "props": {"temperature": 19.0},
+ "state": {"target": 21.0},
+ }
+ }
+ )
+ result = await climate.get_state(_make_device())
+ assert result == "ON"
+
+ async def test_current_ge_target_returns_off(self):
+ """When current_temp >= target_temp, state resolves to OFF."""
+ climate = _make_climate(
+ {
+ "heat-1": {
+ "props": {"temperature": 21.0},
+ "state": {"target": 19.0},
+ }
+ }
+ )
+ result = await climate.get_state(_make_device())
+ assert result == "OFF"
+
+ async def test_none_temps_returns_none(self):
+ """When temperatures cannot be read, get_state returns None."""
+ climate = _make_climate(products={})
+ result = await climate.get_state(_make_device())
+ assert result is None
+
+ async def test_key_error_in_get_current_temperature_is_caught(self):
+ """KeyError from get_current_temperature is caught; get_state returns None."""
+ climate = _make_climate(
+ {"heat-1": {"state": {"mode": "MANUAL", "target": 20.0}, "props": {}}}
+ )
+ with patch.object(
+ climate, "get_current_temperature", new_callable=AsyncMock
+ ) as mock_t:
+ mock_t.side_effect = KeyError("missing_key")
+ result = await climate.get_state(_make_device())
+ assert result is None
+
+ async def test_type_error_in_get_target_temperature_is_caught(self):
+ """TypeError from get_target_temperature is caught; get_state returns None."""
+ climate = _make_climate(
+ {"heat-1": {"state": {"mode": "MANUAL", "target": 20.0}, "props": {}}}
+ )
+ with patch.object(
+ climate, "get_current_temperature", new_callable=AsyncMock
+ ) as mock_cur:
+ mock_cur.return_value = 19.0
+ with patch.object(
+ climate, "get_target_temperature", new_callable=AsyncMock
+ ) as mock_tgt:
+ mock_tgt.side_effect = TypeError("bad type")
+ result = await climate.get_state(_make_device())
+ assert result is None
+
+
+class TestGetCurrentOperation:
+ async def test_returns_working_state(self):
+ """get_current_operation returns the 'working' value from props."""
+ climate = _make_climate({"heat-1": {"props": {"working": True}, "state": {}}})
+ result = await climate.get_current_operation(_make_device())
+ assert result is True
+
+
+class TestSetBoostOnValidation:
+ """set_boost_on input validation must reject bad values with None."""
+
+ async def test_non_numeric_mins_returns_none(self):
+ climate = _make_climate(products={"heat-1": {"type": "heating"}})
+ result = await climate.set_boost_on(_make_device(), "abc", 20)
+ assert result is None
+
+ async def test_non_numeric_temp_returns_none(self):
+ climate = _make_climate(products={"heat-1": {"type": "heating"}})
+ result = await climate.set_boost_on(_make_device(), "30", "hot")
+ assert result is None
+
+ async def test_temp_fraction_above_max_returns_none(self):
+ """32.9 exceeds the 32 maximum and must not be truncated past validation."""
+ climate = _make_climate(products={"heat-1": {"type": "heating"}})
+ result = await climate.set_boost_on(_make_device(), "30", 32.9)
+ assert result is None
+
+ async def test_valid_boost_executes_state_change(self):
+ climate = _make_climate(products={"heat-1": {"type": "heating"}})
+ result = await climate.set_boost_on(_make_device(), "30", 22)
+ assert result is True
+
+
+class TestSetBoostOff:
+ async def test_not_in_products_returns_false(self):
+ """Device hive_id not present in products returns False."""
+ climate = _make_climate(products={})
+ result = await climate.set_boost_off(_make_device())
+ assert result is False
+
+ async def test_previous_off_mode_restored(self):
+ """Previous mode OFF sets mode=OFF and target falls back to 7."""
+ climate = _make_climate(
+ {
+ "heat-1": {
+ "type": "heating",
+ "state": {"boost": _BOOST_MINS},
+ "props": {
+ "previous": {
+ "mode": _OFF_MODE,
+ "target": None,
+ }
+ },
+ }
+ }
+ )
+ result = await climate.set_boost_off(_make_device())
+ assert result is True
+ _, kwargs = climate.session.api.set_state.call_args
+ assert kwargs.get("mode") == _OFF_MODE
+ assert kwargs.get("target") == 7
+
+
+class TestGetClimate:
+ async def test_device_data_not_dict_gets_reset(self):
+ """Non-dict device_data is replaced with an empty dict before use."""
+ climate = _make_climate(
+ products={"heat-1": {"props": {}, "state": {}}},
+ devices={"dev-1": {"props": {}, "parent": None}},
+ )
+ d = _make_device()
+ d.device_data = None
+ await climate.get_climate(d)
+ assert isinstance(d.device_data, dict)
+
+ async def test_offline_device_error_check_called(self):
+ """Offline device triggers error_check and status defaults to None values."""
+ climate = _make_climate(
+ products={"heat-1": {}},
+ devices={"dev-1": {}},
+ )
+ climate.session.attr.online_offline = AsyncMock(return_value=False)
+ d = _make_device()
+ result = await climate.get_climate(d)
+ climate.session.helper.error_check.assert_called_once()
+ assert result.status["current_temperature"] is None
+
+ async def test_cache_hit_returns_cached(self):
+ """When cached data is available and poll is slow, returns cached device."""
+ climate = _make_climate()
+ climate.session.should_use_cached_data = MagicMock(return_value=True)
+ cached_device = _make_device()
+ cached_device.status = {"current_temperature": 20.0}
+ climate.session.get_cached_device = MagicMock(return_value=cached_device)
+ d = _make_device()
+ result = await climate.get_climate(d)
+ assert result is cached_device
+ climate.session.attr.online_offline.assert_not_called()
+
+
+class TestGetScheduleNowNextLater:
+ async def test_offline_returns_none(self):
+ """Offline device returns None regardless of mode."""
+ climate = _make_climate(
+ {"heat-1": {"state": {"mode": _SCHEDULE_MODE, "schedule": {}}}}
+ )
+ climate.session.attr.online_offline = AsyncMock(return_value=False)
+ result = await climate.get_schedule_now_next_later(_make_device())
+ assert result is None
+
+
+# ---------------------------------------------------------------------------
+# get_mode — BOOST path with missing props.previous must not log error
+# ---------------------------------------------------------------------------
+
+
+class TestGetModeBoostMissingPrevious:
+ """get_mode BOOST path must use safe access, not bare dict that logs a spurious error."""
+
+ async def test_boost_missing_previous_returns_none_without_error_log(self):
+ """When mode=BOOST and props has no previous, get_mode returns None without error log."""
+ climate = _make_climate({"heat-1": {"state": {"mode": "BOOST"}, "props": {}}})
+ d = _make_device()
+ with patch("apyhiveapi.devices.heating._LOGGER") as mock_log:
+ result = await climate.get_mode(d)
+ assert result is None
+ mock_log.error.assert_not_called()
+
+ async def test_boost_with_previous_mode_returns_mapped_value(self):
+ """When mode=BOOST and props.previous.mode exists, returns the mapped HA value."""
+ from apyhiveapi.helper.const import HIVETOHA
+
+ climate = _make_climate(
+ {
+ "heat-1": {
+ "state": {"mode": "BOOST"},
+ "props": {"previous": {"mode": "MANUAL"}},
+ }
+ }
+ )
+ d = _make_device()
+ result = await climate.get_mode(d)
+ expected = HIVETOHA["Heating"].get("MANUAL", "MANUAL")
+ assert result == expected
+
+
+# ---------------------------------------------------------------------------
+# set_boost_off — must return False when prev_mode is None (not send mode=None to API)
+# ---------------------------------------------------------------------------
+
+
+class TestSetBoostOffNullPrevMode:
+ """set_boost_off returns False (not an API call) when prev_mode is None."""
+
+ async def test_set_boost_off_returns_false_when_prev_mode_missing(self):
+ """set_boost_off returns False and skips _execute_state_change when prev mode absent."""
+ from apyhiveapi.devices.heating import HiveHeating
+
+ class StubHeating(HiveHeating):
+ """Concrete stub for testing."""
+
+ h = StubHeating()
+ h.session = MagicMock()
+ h.session.data.products = {
+ "h1": {
+ "state": {"mode": "BOOST"},
+ "props": {},
+ }
+ }
+ h._execute_state_change = AsyncMock(return_value=True)
+ h.get_boost_status = AsyncMock(return_value="ON")
+
+ d = Device(
+ hive_id="h1",
+ hive_name="T",
+ hive_type="heating",
+ ha_type="climate",
+ device_id="d1",
+ device_name="T",
+ device_data={"online": True},
+ ha_name="Heating",
+ )
+ result = await h.set_boost_off(d)
+ assert result is False
+ h._execute_state_change.assert_not_called()
+
+
+# ===========================================================================
+# Migrated from test_remaining_branches.py
+# ===========================================================================
+
+
+class TestHeatingGetStateKeyError:
+ """Lines 206-207: KeyError/TypeError branch in get_state."""
+
+ async def test_get_state_key_error_returns_none(self):
+ """Missing product entry causes get_current_temperature to return None,
+ leaving final as None without raising."""
+ # products dict is empty — device.hive_id not found → both temp helpers
+ # return None → the if branch is skipped → final stays None
+ climate = _make_climate(products={})
+ d = _make_device()
+ result = await climate.get_state(d)
+ assert result is None
+
+
+class TestHeatingGetHeatOnDemand:
+ """Line 231: get_heat_on_demand happy path."""
+
+ async def test_get_heat_on_demand_returns_value(self):
+ """Returns the nested autoBoost.active value from products."""
+ climate = _make_climate({"heat-1": {"props": {"autoBoost": {"active": True}}}})
+ result = await climate.get_heat_on_demand(_make_device())
+ assert result is True
+
+ async def test_get_heat_on_demand_returns_none_when_missing(self):
+ """Returns None when the nested path does not exist."""
+ climate = _make_climate({"heat-1": {"props": {}}})
+ result = await climate.get_heat_on_demand(_make_device())
+ assert result is None
+
+
+class TestHeatingSetHeatOnDemand:
+ """Lines 337-342: set_heat_on_demand calls _execute_state_change with autoBoost kwarg."""
+
+ async def test_set_heat_on_demand_enabled(self):
+ """set_heat_on_demand passes autoBoost='ENABLED' to the API."""
+ climate = _make_climate({"heat-1": {"type": "heating"}})
+ result = await climate.set_heat_on_demand(_make_device(), "ENABLED")
+ assert result is True
+ climate.session.api.set_state.assert_called_once()
+ _, kwargs = climate.session.api.set_state.call_args
+ assert kwargs.get("autoBoost") == "ENABLED"
+
+ async def test_set_heat_on_demand_disabled(self):
+ """set_heat_on_demand passes autoBoost='DISABLED' to the API."""
+ climate = _make_climate({"heat-1": {"type": "heating"}})
+ result = await climate.set_heat_on_demand(_make_device(), "DISABLED")
+ assert result is True
+ _, kwargs = climate.session.api.set_state.call_args
+ assert kwargs.get("autoBoost") == "DISABLED"
+
+
+class TestHeatingGetScheduleNNLKeyError:
+ """Lines 438-439: KeyError in get_schedule_now_next_later."""
+
+ async def test_missing_schedule_key_returns_none(self):
+ """Product with state but no 'schedule' key causes KeyError → returns None."""
+ climate = _make_climate(
+ {"heat-1": {"state": {"mode": "SCHEDULE"}}}
+ # no 'schedule' key inside state
+ )
+ # Override get_mode to return SCHEDULE directly so the if-branch is entered
+ climate.session.helper.get_schedule_nnl.side_effect = KeyError("schedule")
+ # get_mode will read data["state"]["mode"] == "SCHEDULE" → enters the try block
+ # data["state"]["schedule"] raises KeyError → caught, returns None
+ result = await climate.get_schedule_now_next_later(_make_device())
+ assert result is None
+
+ async def test_schedule_key_error_caught_not_raised(self):
+ """A KeyError inside the try block does not propagate to the caller."""
+ climate = _make_climate({"heat-1": {"state": {"mode": "SCHEDULE"}}})
+ # Accessing data["state"]["schedule"] will raise KeyError (key absent)
+ try:
+ result = await climate.get_schedule_now_next_later(_make_device())
+ except KeyError:
+ pytest.fail(
+ "KeyError should have been caught inside get_schedule_now_next_later"
+ )
+ assert result is None
+
+
+class TestHeatingSetBoostOffScheduleMode:
+ """Lines 321->325: prev_mode not in ('MANUAL','OFF') — target kwarg not added."""
+
+ async def test_schedule_mode_no_target_kwarg(self):
+ """SCHEDULE as previous mode does not add a target kwarg."""
+ climate = _make_climate(
+ {
+ "heat-1": {
+ "type": "heating",
+ "state": {"boost": 5},
+ "props": {"previous": {"mode": "SCHEDULE"}},
+ }
+ }
+ )
+ result = await climate.set_boost_off(_make_device())
+ assert result is True
+ _, kwargs = climate.session.api.set_state.call_args
+ assert "target" not in kwargs
+ assert kwargs.get("mode") == "SCHEDULE"
+
+
+class TestHeatingGetClimateCacheMiss:
+ """Lines 371->377: cache enabled but cached device is None → normal execution."""
+
+ async def test_cached_none_falls_through_to_normal_path(self):
+ """should_use_cached_data=True but get_cached_device=None → normal update."""
+ climate = _make_climate(
+ {
+ "heat-1": {
+ "state": {"mode": "MANUAL", "target": 20.0},
+ "props": {"temperature": 19.0},
+ }
+ },
+ devices={"dev-1": {"state": {}, "props": {}}},
+ )
+ climate.session.should_use_cached_data.return_value = True
+ climate.session.get_cached_device.return_value = None
+ result = await climate.get_climate(_make_device())
+ assert result is not None
+ climate.session.attr.online_offline.assert_called_once()
diff --git a/tests/unit/test_heating_extended.py b/tests/unit/test_heating_extended.py
deleted file mode 100644
index e6223dd..0000000
--- a/tests/unit/test_heating_extended.py
+++ /dev/null
@@ -1,238 +0,0 @@
-"""Extended branch-coverage tests for Climate / HiveHeating."""
-
-# pylint: disable=too-few-public-methods
-from datetime import datetime
-from unittest.mock import AsyncMock, MagicMock
-
-from apyhiveapi.devices.heating import Climate
-from apyhiveapi.helper.hivedataclasses import Device, SessionConfig
-from apyhiveapi.helper.map import Map
-
-_TODAY = str(datetime.date(datetime.now()))
-_CURRENT_TEMP = 19.0
-_SCHEDULE_MODE = "SCHEDULE"
-_BOOST_MINS = 5
-_OFF_MODE = "OFF"
-
-
-def _make_climate(products=None, devices=None, min_max=None):
- session = MagicMock()
- session.data = Map(
- {
- "products": products or {},
- "devices": devices or {},
- "actions": {},
- "minMax": min_max or {},
- "user": {},
- }
- )
- session.config = SessionConfig()
- session.helper = MagicMock()
- session.helper.device_recovered = MagicMock()
- session.helper.error_check = AsyncMock()
- session.helper.get_schedule_nnl = MagicMock(
- return_value={"now": {}, "next": {}, "later": {}}
- )
- session.attr = MagicMock()
- session.attr.online_offline = AsyncMock(return_value=True)
- session.attr.state_attributes = AsyncMock(return_value={})
- session.api = MagicMock()
- session.api.set_state = AsyncMock(return_value={"original": 200, "parsed": {}})
- session.hive_refresh_tokens = AsyncMock()
- session.get_devices = AsyncMock(return_value=True)
- session.should_use_cached_data = MagicMock(return_value=False)
- session.get_cached_device = MagicMock(return_value=None)
- session.set_cached_device = MagicMock(side_effect=lambda d: d)
- return Climate(session=session)
-
-
-def _make_device(hive_id="heat-1", device_id="dev-1", hive_type="heating"):
- return Device(
- hive_id=hive_id,
- hive_name="Hallway",
- hive_type=hive_type,
- ha_type="climate",
- device_id=device_id,
- device_name="Hallway",
- device_data={"online": True},
- ha_name="Hallway",
- )
-
-
-class TestGetCurrentTemperature:
- async def test_minmax_today_same_date_updates_min_max(self):
- """When minMax entry exists for today's date, TodayMin/TodayMax are updated."""
- initial_min = _CURRENT_TEMP + 2.0
- initial_max = _CURRENT_TEMP - 2.0
- existing = {
- "TodayMin": initial_min,
- "TodayMax": initial_max,
- "TodayDate": _TODAY,
- "RestartMin": initial_min,
- "RestartMax": initial_max,
- }
- climate = _make_climate(
- products={"heat-1": {"props": {"temperature": _CURRENT_TEMP}}},
- min_max={"heat-1": existing},
- )
- result = await climate.get_current_temperature(_make_device())
- assert result == _CURRENT_TEMP
- entry = climate.session.data.minMax["heat-1"]
- assert entry["TodayMin"] == min(initial_min, _CURRENT_TEMP)
- assert entry["TodayMax"] == max(initial_max, _CURRENT_TEMP)
- assert entry["RestartMin"] == min(initial_min, _CURRENT_TEMP)
- assert entry["RestartMax"] == max(initial_max, _CURRENT_TEMP)
-
- async def test_minmax_different_date_resets_today(self):
- """When minMax entry exists but TodayDate is stale, today values are reset."""
- existing = {
- "TodayMin": 5.0,
- "TodayMax": 30.0,
- "TodayDate": "2000-01-01",
- "RestartMin": 5.0,
- "RestartMax": 30.0,
- }
- climate = _make_climate(
- products={"heat-1": {"props": {"temperature": _CURRENT_TEMP}}},
- min_max={"heat-1": existing},
- )
- result = await climate.get_current_temperature(_make_device())
- assert result == _CURRENT_TEMP
- entry = climate.session.data.minMax["heat-1"]
- assert entry["TodayMin"] == _CURRENT_TEMP
- assert entry["TodayMax"] == _CURRENT_TEMP
- assert entry["TodayDate"] == _TODAY
-
- async def test_keyerror_returns_none(self):
- """Missing device.hive_id in products returns None."""
- climate = _make_climate(products={})
- result = await climate.get_current_temperature(_make_device())
- assert result is None
-
-
-class TestGetTargetTemperature:
- async def test_non_numeric_target_returns_none(self):
- """Non-numeric target temperature string returns None."""
- climate = _make_climate({"heat-1": {"state": {"target": "N/A"}}})
- result = await climate.get_target_temperature(_make_device())
- assert result is None
-
-
-class TestGetState:
- async def test_current_less_than_target_returns_on(self):
- """When current_temp < target_temp, state resolves to ON."""
- climate = _make_climate(
- {
- "heat-1": {
- "props": {"temperature": 19.0},
- "state": {"target": 21.0},
- }
- }
- )
- result = await climate.get_state(_make_device())
- assert result == "ON"
-
- async def test_current_ge_target_returns_off(self):
- """When current_temp >= target_temp, state resolves to OFF."""
- climate = _make_climate(
- {
- "heat-1": {
- "props": {"temperature": 21.0},
- "state": {"target": 19.0},
- }
- }
- )
- result = await climate.get_state(_make_device())
- assert result == "OFF"
-
- async def test_none_temps_returns_none(self):
- """When temperatures cannot be read, get_state returns None."""
- climate = _make_climate(products={})
- result = await climate.get_state(_make_device())
- assert result is None
-
-
-class TestGetCurrentOperation:
- async def test_returns_working_state(self):
- """get_current_operation returns the 'working' value from props."""
- climate = _make_climate({"heat-1": {"props": {"working": True}, "state": {}}})
- result = await climate.get_current_operation(_make_device())
- assert result is True
-
-
-class TestSetBoostOff:
- async def test_not_in_products_returns_false(self):
- """Device hive_id not present in products returns False."""
- climate = _make_climate(products={})
- result = await climate.set_boost_off(_make_device())
- assert result is False
-
- async def test_previous_off_mode_restored(self):
- """Previous mode OFF sets mode=OFF and target falls back to 7."""
- climate = _make_climate(
- {
- "heat-1": {
- "type": "heating",
- "state": {"boost": _BOOST_MINS},
- "props": {
- "previous": {
- "mode": _OFF_MODE,
- "target": None,
- }
- },
- }
- }
- )
- result = await climate.set_boost_off(_make_device())
- assert result is True
- _, kwargs = climate.session.api.set_state.call_args
- assert kwargs.get("mode") == _OFF_MODE
- assert kwargs.get("target") == 7
-
-
-class TestGetClimate:
- async def test_device_data_not_dict_gets_reset(self):
- """Non-dict device_data is replaced with an empty dict before use."""
- climate = _make_climate(
- products={"heat-1": {"props": {}, "state": {}}},
- devices={"dev-1": {"props": {}, "parent": None}},
- )
- d = _make_device()
- d.device_data = None
- await climate.get_climate(d)
- assert isinstance(d.device_data, dict)
-
- async def test_offline_device_error_check_called(self):
- """Offline device triggers error_check and status defaults to None values."""
- climate = _make_climate(
- products={"heat-1": {}},
- devices={"dev-1": {}},
- )
- climate.session.attr.online_offline = AsyncMock(return_value=False)
- d = _make_device()
- result = await climate.get_climate(d)
- climate.session.helper.error_check.assert_called_once()
- assert result.status["current_temperature"] is None
-
- async def test_cache_hit_returns_cached(self):
- """When cached data is available and poll is slow, returns cached device."""
- climate = _make_climate()
- climate.session.should_use_cached_data = MagicMock(return_value=True)
- cached_device = _make_device()
- cached_device.status = {"current_temperature": 20.0}
- climate.session.get_cached_device = MagicMock(return_value=cached_device)
- d = _make_device()
- result = await climate.get_climate(d)
- assert result is cached_device
- climate.session.attr.online_offline.assert_not_called()
-
-
-class TestGetScheduleNowNextLater:
- async def test_offline_returns_none(self):
- """Offline device returns None regardless of mode."""
- climate = _make_climate(
- {"heat-1": {"state": {"mode": _SCHEDULE_MODE, "schedule": {}}}}
- )
- climate.session.attr.online_offline = AsyncMock(return_value=False)
- result = await climate.get_schedule_now_next_later(_make_device())
- assert result is None
diff --git a/tests/unit/test_helpers.py b/tests/unit/test_helpers.py
index 1f84c3e..8ee6b13 100644
--- a/tests/unit/test_helpers.py
+++ b/tests/unit/test_helpers.py
@@ -38,11 +38,7 @@ class TestEpochTime:
"""Tests for the top-level epoch_time() helper function."""
def test_to_epoch_returns_int(self):
- """to_epoch converts a date string to an integer Unix timestamp.
-
- Note: epoch_time ignores the *pattern* argument for "to_epoch" —
- it always applies "%d.%m.%Y %H:%M:%S" internally.
- """
+ """to_epoch converts a date string to an integer Unix timestamp."""
result = epoch_time("01.01.2024 12:00:00", "%d.%m.%Y %H:%M:%S", "to_epoch")
assert isinstance(result, int)
@@ -198,6 +194,31 @@ def test_non_string_value_under_sensitive_key_passes_through(self):
result = helper.sanitize_payload({"token": _non_string_int})
assert result["token"] == _non_string_int
+ def test_dict_under_sensitive_key_is_recursively_masked(self):
+ """A dict value under a sensitive key has its own values masked."""
+ helper, _ = _make_helper()
+ result = helper.sanitize_payload({"token": {"inner_key": "secret_value"}})
+ assert isinstance(result["token"], dict)
+ assert "inner_key" in result["token"]
+ assert result["token"]["inner_key"] != "secret_value"
+
+ def test_list_value_under_sensitive_key_masks_each_element(self):
+ """A list under a sensitive key has each element masked individually."""
+ helper, _ = _make_helper()
+ payload = {"token": ["short", "averylongtoken123"]}
+ result = helper.sanitize_payload(payload)
+ assert result["token"] == ["***", "aver...n123"]
+
+ def test_none_under_sensitive_key_passes_through(self):
+ """None under a sensitive key is returned unchanged."""
+ helper, _ = _make_helper()
+ assert helper.sanitize_payload({"token": None})["token"] is None
+
+ def test_bool_under_sensitive_key_passes_through(self):
+ """A bool under a sensitive key is returned unchanged."""
+ helper, _ = _make_helper()
+ assert helper.sanitize_payload({"token": True})["token"] is True
+
# ---------------------------------------------------------------------------
# HiveHelper.device_recovered
@@ -472,3 +493,183 @@ def test_heating_matches_by_zone(self):
}
result = helper.get_device_data(product)
assert result["id"] == "thermo-1"
+
+
+# ===========================================================================
+# Migrated from test_remaining_branches.py
+# ===========================================================================
+
+
+class TestHiveHelperZoneMismatch:
+ """hive_helper.py 163->160: loop continues when zones don't match."""
+
+ def test_zone_mismatch_keeps_product_as_device(self):
+ """When a Thermo device's zone doesn't match the product's zone,
+ the loop arc 163->160 is taken and device stays as the product."""
+ helper, _ = _make_helper(
+ devices={
+ "thermo-1": {
+ "type": "thermostatui",
+ "props": {"zone": "zone-B"},
+ }
+ }
+ )
+
+ product = {
+ "type": "heating",
+ "id": "prod-1",
+ "props": {"zone": "zone-A"}, # different zone from thermo-1
+ }
+
+ result = helper.get_device_data(product)
+ # The zone mismatch means device was never re-assigned; returns the product
+ assert result is product
+
+ def test_trv_without_zone_does_not_log_warning(self, caplog):
+ """TRV devices that omit 'zone' from props are silently skipped (no warning)."""
+ import logging
+
+ helper, _ = _make_helper(
+ devices={
+ "trv-1": {
+ "type": "trv",
+ "props": {"online": True}, # no 'zone' key — current API behaviour
+ }
+ }
+ )
+
+ product = {
+ "type": "heating",
+ "id": "prod-1",
+ "props": {"zone": "zone-A"},
+ }
+
+ with caplog.at_level(logging.WARNING, logger="apyhiveapi.helper.hive_helper"):
+ result = helper.get_device_data(product)
+
+ assert result is product
+ assert not caplog.records, (
+ f"Unexpected warnings: {[r.getMessage() for r in caplog.records]}"
+ )
+
+
+class TestHiveHelperSanitizeListNode:
+ """hive_helper.py line 359: list value under a non-sensitive key calls _walk(list)."""
+
+ def test_list_under_non_sensitive_key_is_walked(self):
+ """A list value under a non-sensitive key hits the isinstance(node, list) branch."""
+ helper, _ = _make_helper()
+ result = helper.sanitize_payload({"devices": ["device-a", "device-b"]})
+ # 'devices' is not a sensitive key → _walk called for the list
+ # _walk for a list returns [_walk(item) for item in node]
+ # Each string item: _walk(str) → str (falls through to return node)
+ assert result == {"devices": ["device-a", "device-b"]}
+
+ def test_list_containing_dicts_is_walked_recursively(self):
+ """A list of dicts under a non-sensitive key is recursively processed."""
+ helper, _ = _make_helper()
+ result = helper.sanitize_payload(
+ {
+ "items": [
+ {"token": "abc", "name": "device1"},
+ {"token": "xyz", "name": "device2"},
+ ]
+ }
+ )
+ # 'items' is not sensitive → _walk called for the list
+ # Each dict in the list is processed by _walk
+ # 'token' IS sensitive → masked in each sub-dict
+ assert result["items"][0]["name"] == "device1"
+ assert result["items"][0]["token"] != "abc"
+ assert result["items"][1]["name"] == "device2"
+ assert result["items"][1]["token"] != "xyz"
+
+
+# ---------------------------------------------------------------------------
+# Migrated from test_hive_helper_extended.py
+# ---------------------------------------------------------------------------
+
+
+class TestGetDeviceFromIdBranch:
+ """Covers the branch where no cache entry matches the requested ID."""
+
+ def test_returns_false_when_no_match_in_cache(self):
+ """When entity_cache has entries but none match n_id, returns False."""
+ other_device = Device(
+ hive_id="other-hive-id",
+ hive_name="Other",
+ hive_type="heating",
+ ha_type="climate",
+ device_id="other-device-id",
+ device_name="Other",
+ device_data={},
+ )
+ helper, _ = _make_helper(entity_cache={"other-key": other_device})
+ result = helper.get_device_from_id("nonexistent-id")
+ assert result is False
+
+ def test_returns_false_when_cache_is_empty(self):
+ """When entity_cache is empty, returns False without entering the loop."""
+ helper, _ = _make_helper(entity_cache={})
+ assert helper.get_device_from_id("any-id") is False
+
+
+class TestEpochTimePattern:
+ """epoch_time to_epoch must honour the pattern argument."""
+
+ def test_to_epoch_uses_caller_pattern(self):
+ """Passing a custom pattern must parse the date string with that pattern."""
+ result = epoch_time("2024-06-15", "%Y-%m-%d", "to_epoch")
+ assert isinstance(result, int), "Expected int epoch timestamp"
+ assert result > 0
+
+ def test_to_epoch_standard_hive_format_still_works(self):
+ """The standard Hive date+time format must still parse correctly."""
+ result = epoch_time("15.06.2024 12:00:00", "%d.%m.%Y %H:%M:%S", "to_epoch")
+ assert isinstance(result, int)
+ assert result > 0
+
+
+def _sample_schedule():
+ """Minimal 7-day schedule with 3 slots on every day."""
+ days = [
+ "monday",
+ "tuesday",
+ "wednesday",
+ "thursday",
+ "friday",
+ "saturday",
+ "sunday",
+ ]
+ schedule = {}
+ for d in days:
+ schedule[d] = [
+ {"start": 0, "value": {"status": "ON"}},
+ {"start": 480, "value": {"status": "OFF"}},
+ {"start": 1200, "value": {"status": "ON"}},
+ ]
+ return schedule
+
+
+class TestGetScheduleNnlMutation:
+ """get_schedule_nnl must not mutate the input schedule dicts."""
+
+ def test_second_call_returns_same_result_as_first_call(self):
+ """Calling get_schedule_nnl twice on the same schedule dict gives consistent results."""
+ h, _ = _make_helper()
+ schedule = _sample_schedule()
+ result1 = h.get_schedule_nnl(schedule)
+ result2 = h.get_schedule_nnl(schedule)
+ assert result1.get("now", {}).get("value") == result2.get("now", {}).get(
+ "value"
+ ), "Second call returned different 'now' value — schedule was mutated in-place"
+
+ def test_input_schedule_slots_not_modified(self):
+ """Slot dicts in the input schedule must not gain 'Start_DateTime' after the call."""
+ h, _ = _make_helper()
+ schedule = _sample_schedule()
+ monday_slot_before = dict(schedule["monday"][0])
+ h.get_schedule_nnl(schedule)
+ assert schedule["monday"][0] == monday_slot_before, (
+ "get_schedule_nnl mutated the original slot dict"
+ )
diff --git a/tests/unit/test_hive_api.py b/tests/unit/test_hive_api.py
index 8b6a182..6e272bd 100644
--- a/tests/unit/test_hive_api.py
+++ b/tests/unit/test_hive_api.py
@@ -152,6 +152,27 @@ def test_request_passes_timeout(self):
class TestGetLoginInfo:
+ def test_tls_verification_is_not_disabled(self):
+ """The SSO bootstrap request must not pass verify=False."""
+ api = _make_api()
+ html_content = (
+ b""
+ )
+ mock_resp = MagicMock()
+ mock_resp.content = html_content
+ mock_resp.status_code = 200
+
+ with patch(
+ "apyhiveapi.api.hive_api.requests.get", return_value=mock_resp
+ ) as mock_get:
+ api.get_login_info()
+
+ _, call_kwargs = mock_get.call_args
+ assert call_kwargs.get("verify", True) is not False
+
def test_successful_parse_returns_login_data(self):
"""Parses HiveSSOPoolId and HiveSSOPublicCognitoClientId from the SSO page."""
api = _make_api()
@@ -213,119 +234,6 @@ def test_key_error_calls_error_and_returns_none(self):
assert result is None
-# ---------------------------------------------------------------------------
-# Tests: HiveApi.refresh_tokens
-# ---------------------------------------------------------------------------
-
-
-class TestRefreshTokens:
- def test_successful_with_token_key_updates_session(self):
- """When the response contains 'token', session.update_tokens is called."""
- api = _make_api()
- refresh_data = {
- "token": "new-token",
- "platform": {"endpoint": "https://new.endpoint.com"},
- }
- mock_resp = _make_mock_response(
- 200, json_data=refresh_data, text=json.dumps(refresh_data)
- )
-
- with patch.object(api, "request", return_value=mock_resp):
- result = api.refresh_tokens()
-
- api.session.update_tokens.assert_called_once_with(refresh_data)
- assert result["original"] == 200
-
- def test_no_token_in_response_no_session_update(self):
- """When response lacks 'token' key, update_tokens is not called."""
- api = _make_api()
- response_data = {"other_key": "value"}
- mock_resp = _make_mock_response(
- 200, json_data=response_data, text=json.dumps(response_data)
- )
-
- with patch.object(api, "request", return_value=mock_resp):
- api.refresh_tokens()
-
- api.session.update_tokens.assert_not_called()
-
- def test_none_tokens_defaults_to_empty_dict(self):
- """Calling refresh_tokens() without arguments uses session.token_data."""
- api = _make_api()
- response_data = {"other": "val"}
- mock_resp = _make_mock_response(
- 200, json_data=response_data, text=json.dumps(response_data)
- )
-
- with patch.object(api, "request", return_value=mock_resp) as mock_req:
- api.refresh_tokens()
- # Should have been called (session provides the tokens dict)
- mock_req.assert_called_once()
-
- def test_os_error_calls_error(self):
- api = _make_api()
- with patch.object(api, "request", side_effect=OSError("connection failed")):
- api.refresh_tokens()
-
- assert api.json_return["original"] == "Error making API call"
-
- def test_runtime_error_calls_error(self):
- api = _make_api()
- with patch.object(api, "request", side_effect=RuntimeError("fail")):
- api.refresh_tokens()
-
- assert api.json_return["original"] == "Error making API call"
-
- def test_json_decode_error_calls_error(self):
- """Bad JSON in response text triggers error()."""
- api = _make_api()
- mock_resp = MagicMock()
- mock_resp.status_code = 200
- mock_resp.text = "not-json"
-
- with patch.object(api, "request", return_value=mock_resp):
- api.refresh_tokens()
-
- assert api.json_return["original"] == "Error making API call"
-
- def test_explicit_tokens_arg_skips_none_branch(self):
- """Passing a non-None tokens arg covers the 80->82 False branch."""
- api = _make_api()
- explicit_tokens = {"key": "val"}
- response_data = {"other": "x"}
- mock_resp = _make_mock_response(200, json_data=response_data)
-
- with patch.object(api, "request", return_value=mock_resp):
- api.refresh_tokens(tokens=explicit_tokens)
- # Session is not None so session tokens overwrite, but no crash
- api.session.update_tokens.assert_not_called()
-
- def test_session_none_skips_token_overwrite(self):
- """When session is None the 83->85 False branch is taken (no token overwrite)."""
- api = _make_api_no_session(token="standalone-token")
- response_data = {"other": "x"}
- mock_resp = _make_mock_response(200, json_data=response_data)
-
- with patch.object(api, "request", return_value=mock_resp):
- api.refresh_tokens(tokens={"key": "val"})
-
- def test_urls_base_updated_on_token_refresh(self):
- """After a successful refresh the base URL is updated from the response."""
- api = _make_api()
- refresh_data = {
- "token": "new-tok",
- "platform": {"endpoint": "https://new-platform.com/1.0"},
- }
- mock_resp = _make_mock_response(
- 200, json_data=refresh_data, text=json.dumps(refresh_data)
- )
-
- with patch.object(api, "request", return_value=mock_resp):
- api.refresh_tokens()
-
- assert api.urls["base"] == "https://new-platform.com/1.0"
-
-
# ---------------------------------------------------------------------------
# Tests: HiveApi.get_all
# ---------------------------------------------------------------------------
@@ -343,14 +251,13 @@ def test_successful_returns_original_and_parsed(self):
assert result["original"] == 200
assert result["parsed"] == payload
- def test_none_response_logs_error_and_returns_empty(self):
+ def test_none_response_logs_error_and_returns_no_response_marker(self):
"""When request returns None the method should not crash."""
api = _make_api()
with patch.object(api, "request", return_value=None):
result = api.get_all()
- # No keys populated — dict remains empty
- assert "original" not in result
+ assert result["original"] == "No response to Hive API request"
def test_os_error_calls_error_method(self):
api = _make_api()
@@ -584,8 +491,7 @@ def test_none_response_logs_error_no_crash(self):
with patch.object(api, "request", return_value=None):
result = api.set_state("heating", "node-1", mode="MANUAL")
- # json_return stays at default (unchanged from init defaults)
- assert result is api.json_return
+ assert result["original"] == "No response to Hive API request"
def test_os_error_calls_error(self):
api = _make_api()
@@ -626,6 +532,28 @@ def test_kwargs_serialised_into_jsc(self):
assert "target" in jsc_arg
assert "21" in jsc_arg
+ def test_payload_is_valid_json_with_native_types(self):
+ """The payload must round-trip through json.loads with types intact."""
+ api = _make_api()
+ mock_resp = _make_mock_response(200, json_data={})
+
+ with patch.object(api, "request", return_value=mock_resp) as mock_req:
+ api.set_state("heating", "n1", mode="SCHEDULE", target=21.5)
+
+ jsc_arg = mock_req.call_args[0][2]
+ assert json.loads(jsc_arg) == {"mode": "SCHEDULE", "target": 21.5}
+
+ def test_payload_with_quotes_is_valid_json(self):
+ """Values containing quotes must not break or inject into the JSON."""
+ api = _make_api()
+ mock_resp = _make_mock_response(200, json_data={})
+
+ with patch.object(api, "request", return_value=mock_resp) as mock_req:
+ api.set_state("heating", "n1", name='say "hi", "extra": "injected')
+
+ jsc_arg = mock_req.call_args[0][2]
+ assert json.loads(jsc_arg) == {"name": 'say "hi", "extra": "injected'}
+
# ---------------------------------------------------------------------------
# Tests: HiveApi.set_action
@@ -687,6 +615,52 @@ def test_runtime_error_calls_error(self):
assert api.json_return["original"] == "Error making API call"
+# ---------------------------------------------------------------------------
+# Tests: result isolation between calls
+# ---------------------------------------------------------------------------
+
+
+class TestResultIsolation:
+ def test_results_are_independent_between_calls(self):
+ """A later call must not mutate the dict returned by an earlier call."""
+ api = _make_api()
+ resp_devices = _make_mock_response(200, json_data=[{"id": "dev1"}])
+ resp_products = _make_mock_response(200, json_data=[{"id": "prod1"}])
+
+ with patch.object(api, "request", side_effect=[resp_devices, resp_products]):
+ devices = api.get_devices()
+ products = api.get_products()
+
+ assert devices is not products
+ assert devices["parsed"] == [{"id": "dev1"}]
+ assert products["parsed"] == [{"id": "prod1"}]
+
+ def test_error_call_does_not_corrupt_previous_result(self):
+ """An error in a later call must not overwrite an earlier result."""
+ api = _make_api()
+ resp_devices = _make_mock_response(200, json_data=[{"id": "dev1"}])
+
+ with patch.object(api, "request", side_effect=[resp_devices, OSError("down")]):
+ devices = api.get_devices()
+ failed = api.get_products()
+
+ assert devices["original"] == 200
+ assert devices["parsed"] == [{"id": "dev1"}]
+ assert failed["original"] == "Error making API call"
+
+ def test_set_state_none_response_does_not_return_stale_data(self):
+ """set_state with no response must not surface a previous call's payload."""
+ api = _make_api()
+ resp_devices = _make_mock_response(200, json_data=[{"id": "dev1"}])
+
+ with patch.object(api, "request", side_effect=[resp_devices, None]):
+ api.get_devices()
+ result = api.set_state("heating", "n1", mode="MANUAL")
+
+ assert result["parsed"] != [{"id": "dev1"}]
+ assert result["original"] == "No response to Hive API request"
+
+
# ---------------------------------------------------------------------------
# Tests: HiveApi.error
# ---------------------------------------------------------------------------
diff --git a/tests/unit/test_hive_async_api.py b/tests/unit/test_hive_async_api.py
index 86c1cd7..92d1027 100644
--- a/tests/unit/test_hive_async_api.py
+++ b/tests/unit/test_hive_async_api.py
@@ -1,8 +1,9 @@
"""Unit tests for HiveApiAsync."""
import asyncio
-from unittest.mock import AsyncMock, MagicMock
+from unittest.mock import AsyncMock, MagicMock, patch
+import aiohttp
import pytest
from aiohttp import web_exceptions
from apyhiveapi.api.hive_async_api import HiveApiAsync
@@ -65,50 +66,42 @@ def _make_api_no_token(_url_contains_sso=False):
class TestHiveApiAsyncRequest:
- @pytest.mark.asyncio
async def test_successful_200_returns_response(self):
api = _make_api(status=200, json_data={"ok": True})
resp = await api.request("get", "https://beekeeper.hivehome.com/1.0/nodes/all")
assert resp.status == 200
- @pytest.mark.asyncio
async def test_201_also_succeeds(self):
api = _make_api(status=201)
resp = await api.request("post", "https://beekeeper.hivehome.com/1.0/nodes/x/y")
assert resp.status == 201
- @pytest.mark.asyncio
async def test_sso_url_without_token_does_not_raise(self):
api = _make_api_no_token()
# Should not raise NoApiToken because "sso" is in the URL
resp = await api.request("get", "https://sso.hivehome.com/")
assert resp.status == 200
- @pytest.mark.asyncio
async def test_non_sso_without_token_raises_no_api_token(self):
api = _make_api_no_token()
with pytest.raises(NoApiToken):
await api.request("get", "https://beekeeper.hivehome.com/1.0/nodes/all")
- @pytest.mark.asyncio
async def test_401_raises_hive_auth_error(self):
api = _make_api(status=401)
with pytest.raises(HiveAuthError):
await api.request("get", "https://beekeeper.hivehome.com/1.0/nodes/all")
- @pytest.mark.asyncio
async def test_403_raises_hive_auth_error(self):
api = _make_api(status=403)
with pytest.raises(HiveAuthError):
await api.request("get", "https://beekeeper.hivehome.com/1.0/nodes/all")
- @pytest.mark.asyncio
async def test_500_raises_hive_api_error(self):
api = _make_api(status=500)
with pytest.raises(HiveApiError):
await api.request("get", "https://beekeeper.hivehome.com/1.0/nodes/all")
- @pytest.mark.asyncio
async def test_404_raises_hive_api_error(self):
api = _make_api(status=404)
with pytest.raises(HiveApiError):
@@ -121,7 +114,6 @@ async def test_404_raises_hive_api_error(self):
class TestGetAll:
- @pytest.mark.asyncio
async def test_successful_get_all_returns_parsed_json(self):
payload = {"products": [], "devices": []}
api = _make_api(status=200, json_data=payload)
@@ -129,21 +121,18 @@ async def test_successful_get_all_returns_parsed_json(self):
assert result["original"] == 200
assert result["parsed"] == payload
- @pytest.mark.asyncio
async def test_timeout_error_propagates(self):
api = _make_api(status=200)
api.websession.request.side_effect = asyncio.TimeoutError
with pytest.raises(asyncio.TimeoutError):
await api.get_all()
- @pytest.mark.asyncio
async def test_os_error_calls_error_method(self):
api = _make_api(status=200)
api.websession.request.side_effect = OSError("network down")
with pytest.raises(web_exceptions.HTTPError):
await api.get_all()
- @pytest.mark.asyncio
async def test_runtime_error_calls_error_method(self):
api = _make_api(status=200)
api.websession.request.side_effect = RuntimeError("boom")
@@ -157,7 +146,6 @@ async def test_runtime_error_calls_error_method(self):
class TestGetEndpoints:
- @pytest.mark.asyncio
async def test_get_devices_returns_parsed_json(self):
payload = [{"id": "dev1"}]
api = _make_api(status=200, json_data=payload)
@@ -165,7 +153,6 @@ async def test_get_devices_returns_parsed_json(self):
assert result["original"] == 200
assert result["parsed"] == payload
- @pytest.mark.asyncio
async def test_get_products_returns_parsed_json(self):
payload = [{"id": "prod1"}]
api = _make_api(status=200, json_data=payload)
@@ -173,7 +160,6 @@ async def test_get_products_returns_parsed_json(self):
assert result["original"] == 200
assert result["parsed"] == payload
- @pytest.mark.asyncio
async def test_get_actions_returns_parsed_json(self):
payload = [{"id": "act1"}]
api = _make_api(status=200, json_data=payload)
@@ -181,21 +167,18 @@ async def test_get_actions_returns_parsed_json(self):
assert result["original"] == 200
assert result["parsed"] == payload
- @pytest.mark.asyncio
async def test_get_devices_os_error_raises_http_error(self):
api = _make_api(status=200)
api.websession.request.side_effect = OSError
with pytest.raises(web_exceptions.HTTPError):
await api.get_devices()
- @pytest.mark.asyncio
async def test_get_products_os_error_raises_http_error(self):
api = _make_api(status=200)
api.websession.request.side_effect = OSError
with pytest.raises(web_exceptions.HTTPError):
await api.get_products()
- @pytest.mark.asyncio
async def test_get_actions_os_error_raises_http_error(self):
api = _make_api(status=200)
api.websession.request.side_effect = OSError
@@ -209,13 +192,11 @@ async def test_get_actions_os_error_raises_http_error(self):
class TestSetState:
- @pytest.mark.asyncio
async def test_file_in_use_returns_file_response(self):
api = _make_api(status=200, file_mode=True)
result = await api.set_state("heating", "node-1", mode="MANUAL")
assert result == {"original": "file"}
- @pytest.mark.asyncio
async def test_successful_set_state(self):
payload = {"id": "node-1", "mode": "MANUAL"}
api = _make_api(status=200, json_data=payload)
@@ -223,14 +204,12 @@ async def test_successful_set_state(self):
assert result["original"] == 200
assert result["parsed"] == payload
- @pytest.mark.asyncio
async def test_os_error_calls_error_method(self):
api = _make_api(status=200)
api.websession.request.side_effect = OSError("fail")
with pytest.raises(web_exceptions.HTTPError):
await api.set_state("heating", "node-1", mode="MANUAL")
- @pytest.mark.asyncio
async def test_runtime_error_calls_error_method(self):
api = _make_api(status=200)
api.websession.request.side_effect = RuntimeError("fail")
@@ -244,19 +223,24 @@ async def test_runtime_error_calls_error_method(self):
class TestSetAction:
- @pytest.mark.asyncio
async def test_file_in_use_returns_file_response(self):
api = _make_api(status=200, file_mode=True)
result = await api.set_action("action-1", '{"status": "on"}')
assert result == {"original": "file"}
- @pytest.mark.asyncio
- async def test_successful_set_action_returns_json_return(self):
- api = _make_api(status=200)
+ async def test_successful_set_action_returns_status_200(self):
+ payload = {"id": "action-1", "status": "on"}
+ api = _make_api(status=200, json_data=payload)
result = await api.set_action("action-1", '{"status": "on"}')
- assert result == api.json_return
+ assert result["original"] == 200
+ assert result["parsed"] == payload
+
+ async def test_runtime_error_calls_error_method(self):
+ api = _make_api(status=200)
+ api.websession.request.side_effect = RuntimeError("fail")
+ with pytest.raises(web_exceptions.HTTPError):
+ await api.set_action("action-1", "{}")
- @pytest.mark.asyncio
async def test_os_error_calls_error_method(self):
api = _make_api(status=200)
api.websession.request.side_effect = OSError
@@ -264,13 +248,45 @@ async def test_os_error_calls_error_method(self):
await api.set_action("action-1", "{}")
+# ---------------------------------------------------------------------------
+# Tests: HiveApiAsync.motion_sensor
+# ---------------------------------------------------------------------------
+
+
+class TestMotionSensor:
+ async def test_url_does_not_double_base_url(self):
+ payload = [{"timestamp": 12345}]
+ api = _make_api(status=200, json_data=payload)
+ captured = {}
+ original_request = api.request
+
+ async def capture_request(method, url, **kwargs):
+ captured["url"] = url
+ return await original_request(method, url, **kwargs)
+
+ api.request = capture_request
+ sensor = {"type": "motionsensor", "id": "ms-001"}
+ await api.motion_sensor(sensor, 1000000, 2000000)
+ url = captured["url"]
+ assert url.startswith(api.base_url + "/products/")
+ assert "motionsensor/ms-001" in url
+ assert url.count("https://beekeeper") == 1
+
+ async def test_motion_sensor_returns_parsed_json(self):
+ payload = [{"timestamp": 12345}]
+ api = _make_api(status=200, json_data=payload)
+ sensor = {"type": "motionsensor", "id": "ms-001"}
+ result = await api.motion_sensor(sensor, 1000000, 2000000)
+ assert result["original"] == 200
+ assert result["parsed"] == payload
+
+
# ---------------------------------------------------------------------------
# Tests: HiveApiAsync.error
# ---------------------------------------------------------------------------
class TestError:
- @pytest.mark.asyncio
async def test_error_raises_http_error(self):
api = _make_api()
with pytest.raises(web_exceptions.HTTPError):
@@ -283,13 +299,11 @@ async def test_error_raises_http_error(self):
class TestIsFileBeingUsed:
- @pytest.mark.asyncio
async def test_file_mode_raises_file_in_use(self):
api = _make_api(file_mode=True)
with pytest.raises(FileInUse):
await api.is_file_being_used()
- @pytest.mark.asyncio
async def test_not_file_mode_does_not_raise(self):
api = _make_api(file_mode=False)
await api.is_file_being_used() # Should not raise
@@ -301,14 +315,26 @@ async def test_not_file_mode_does_not_raise(self):
class TestInit:
- async def test_default_websession_created_when_none_passed(self):
+ def test_no_websession_created_at_init(self):
+ """No ClientSession is created in the sync constructor."""
session = MagicMock()
session.tokens = MagicMock()
session.tokens.token_data = {"token": "tok"}
session.config = MagicMock()
api = HiveApiAsync(hive_session=session)
- assert api.websession is not None
- await api.websession.close()
+ assert api.websession is None
+
+ def test_websession_created_lazily_and_cached(self):
+ """The first _get_websession() call creates and caches a ClientSession."""
+ session = MagicMock()
+ api = HiveApiAsync(hive_session=session)
+ with patch("apyhiveapi.api.hive_async_api.ClientSession") as mock_session_cls:
+ first = api._get_websession()
+ second = api._get_websession()
+ mock_session_cls.assert_called_once()
+ assert first is mock_session_cls.return_value
+ assert second is first
+ assert api.websession is first
def test_custom_websession_is_used(self):
session = MagicMock()
@@ -323,3 +349,191 @@ def test_base_url_is_set(self):
def test_default_timeout(self):
api = _make_api()
assert api.timeout == 5
+
+
+# ---------------------------------------------------------------------------
+# Migrated from test_hive_async_api_extended.py
+# ---------------------------------------------------------------------------
+
+
+class TestRequestNonAuthErrorBranch:
+ """Cover lines 100-108: url/status not None branch leading to HiveApiError."""
+
+ async def test_404_logs_and_raises_hive_api_error(self):
+ """A 404 falls through to the url/status branch and raises HiveApiError."""
+ api = _make_api(status=404)
+ with pytest.raises(HiveApiError):
+ await api.request("get", "https://beekeeper.hivehome.com/1.0/nodes/all")
+
+ async def test_503_logs_and_raises_hive_api_error(self):
+ """A 503 falls through to the url/status branch and raises HiveApiError."""
+ api = _make_api(status=503)
+ with pytest.raises(HiveApiError):
+ await api.request("get", "https://beekeeper.hivehome.com/1.0/nodes/all")
+
+ async def test_422_logs_and_raises_hive_api_error(self):
+ """A 422 also falls through (not 401/403) and raises HiveApiError."""
+ api = _make_api(status=422)
+ with pytest.raises(HiveApiError):
+ await api.request("get", "https://beekeeper.hivehome.com/1.0/devices")
+
+
+class TestMotionSensorBranches:
+ """Cover lines 215-235: motion_sensor() success and error paths."""
+
+ async def test_success_returns_status_and_parsed(self):
+ """Successful call returns status and parsed JSON."""
+ payload = [{"event": "motion", "timestamp": 1234567890}]
+ api = _make_api(status=200, json_data=payload)
+ api.urls["base"] = ""
+ sensor = {"type": "motionsensor", "id": "sensor-001"}
+ result = await api.motion_sensor(sensor, fromepoch=1000000, toepoch=2000000)
+ assert result["original"] == 200
+ assert result["parsed"] == payload
+
+ async def test_url_is_built_correctly(self):
+ """Verifies the URL is assembled with correct sensor type and id."""
+ api = _make_api(status=200, json_data=[])
+ api.urls["base"] = "https://beekeeper-uk.hivehome.com/1.0"
+ sensor = {"type": "contactsensor", "id": "abc-123"}
+ captured_url = []
+ original_request = api.request
+
+ async def capture_request(method, url, **kwargs):
+ captured_url.append(url)
+ return await original_request(method, url, **kwargs)
+
+ with patch.object(api, "request", side_effect=capture_request):
+ await api.motion_sensor(sensor, fromepoch=100, toepoch=200)
+ assert len(captured_url) == 1
+ assert "contactsensor" in captured_url[0]
+ assert "abc-123" in captured_url[0]
+ assert "from=100" in captured_url[0]
+ assert "to=200" in captured_url[0]
+
+ async def test_os_error_raises_http_error(self):
+ """OSError inside the try block causes error() → HTTPError."""
+ api = _make_api(status=200)
+ api.urls["base"] = ""
+ sensor = {"type": "motionsensor", "id": "sensor-001"}
+ api.websession.request.side_effect = OSError("fail")
+ with pytest.raises(web_exceptions.HTTPError):
+ await api.motion_sensor(sensor, fromepoch=1000, toepoch=2000)
+
+ async def test_runtime_error_raises_http_error(self):
+ """RuntimeError inside the try block causes error() → HTTPError."""
+ api = _make_api(status=200)
+ api.urls["base"] = ""
+ sensor = {"type": "motionsensor", "id": "sensor-002"}
+ api.websession.request.side_effect = RuntimeError("unexpected")
+ with pytest.raises(web_exceptions.HTTPError):
+ await api.motion_sensor(sensor, fromepoch=1000, toepoch=2000)
+
+ async def test_client_error_raises_http_error(self):
+ """aiohttp.ClientError inside the try block causes error() → HTTPError."""
+ api = _make_api(status=200)
+ api.urls["base"] = ""
+ sensor = {"type": "motionsensor", "id": "sensor-003"}
+ api.websession.request.side_effect = aiohttp.ClientError()
+ with pytest.raises(web_exceptions.HTTPError):
+ await api.motion_sensor(sensor, fromepoch=1000, toepoch=2000)
+
+
+class TestGetWeather:
+ """Cover lines 239-249: get_weather() success, space encoding, and error paths."""
+
+ async def test_success_returns_status_and_parsed(self):
+ """Successful call returns status and parsed weather JSON."""
+ payload = {"temperature": {"value": 15, "unit": "C"}}
+ api = _make_api(status=200, json_data=payload)
+ result = await api.get_weather("?lat=51.5&lon=-0.1")
+ assert result["original"] == 200
+ assert result["parsed"] == payload
+
+ async def test_space_in_weather_url_is_encoded(self):
+ """Spaces in the weather_url are replaced with %20."""
+ api = _make_api(status=200, json_data={})
+ captured_url = []
+ original_request = api.request
+
+ async def capture_request(method, url, **kwargs):
+ captured_url.append(url)
+ return await original_request(method, url, **kwargs)
+
+ with patch.object(api, "request", side_effect=capture_request):
+ await api.get_weather("?postcode=SW1A 2AA")
+ assert len(captured_url) == 1
+ assert " " not in captured_url[0]
+ assert "%20" in captured_url[0]
+
+ async def test_url_is_prefixed_with_weather_base(self):
+ """The weather base URL is prepended to the given weather_url."""
+ api = _make_api(status=200, json_data={})
+ captured_url = []
+ original_request = api.request
+
+ async def capture_request(method, url, **kwargs):
+ captured_url.append(url)
+ return await original_request(method, url, **kwargs)
+
+ with patch.object(api, "request", side_effect=capture_request):
+ await api.get_weather("?lat=51.5")
+ assert captured_url[0].startswith("https://weather.prod.bgchprod.info/weather")
+
+ async def test_os_error_raises_http_error(self):
+ """OSError inside the try block causes error() → HTTPError."""
+ api = _make_api(status=200)
+ api.websession.request.side_effect = OSError("network fail")
+ with pytest.raises(web_exceptions.HTTPError):
+ await api.get_weather("?lat=51.5")
+
+ async def test_runtime_error_raises_http_error(self):
+ """RuntimeError inside the try block causes error() → HTTPError."""
+ api = _make_api(status=200)
+ api.websession.request.side_effect = RuntimeError("unexpected")
+ with pytest.raises(web_exceptions.HTTPError):
+ await api.get_weather("?lat=51.5")
+
+ async def test_client_error_raises_http_error(self):
+ """aiohttp.ClientError inside the try block causes error() → HTTPError."""
+ api = _make_api(status=200)
+ api.websession.request.side_effect = aiohttp.ClientError()
+ with pytest.raises(web_exceptions.HTTPError):
+ await api.get_weather("?lat=51.5")
+
+ async def test_connection_error_raises_http_error(self):
+ """ConnectionError inside the try block causes error() → HTTPError."""
+ api = _make_api(status=200)
+ api.websession.request.side_effect = ConnectionError("disconnected")
+ with pytest.raises(web_exceptions.HTTPError):
+ await api.get_weather("?lat=51.5")
+
+
+class TestSetStateJsonEncoding:
+ """set_state must produce valid JSON even when kwarg values contain special characters."""
+
+ async def test_set_state_escapes_quotes_in_value(self):
+ """A value containing double-quotes must produce valid, parseable JSON."""
+ import json # noqa: PLC0415
+
+ session = MagicMock()
+ session.tokens.token_data = {"token": "tok"}
+ session.config.file = False
+ api = HiveApiAsync(hive_session=session)
+ api.urls = {"nodes": "https://beekeeper.hivehome.com/1.0/nodes/{}/{}"}
+
+ captured = {}
+
+ async def fake_request(_method, _url, **kwargs):
+ captured["data"] = kwargs.get("data")
+ resp = MagicMock()
+ resp.status = 200
+ resp.json = AsyncMock(return_value={})
+ return resp
+
+ with patch.object(api, "request", side_effect=fake_request):
+ with patch.object(api, "is_file_being_used", new=AsyncMock()):
+ await api.set_state("heating", "node-1", mode='MANUAL"injected')
+
+ parsed = json.loads(captured["data"])
+ assert parsed["mode"] == 'MANUAL"injected'
diff --git a/tests/unit/test_hive_async_api_extended.py b/tests/unit/test_hive_async_api_extended.py
deleted file mode 100644
index 5a9848c..0000000
--- a/tests/unit/test_hive_async_api_extended.py
+++ /dev/null
@@ -1,427 +0,0 @@
-"""Extended unit tests for HiveApiAsync — covers previously uncovered lines."""
-
-from unittest.mock import AsyncMock, MagicMock, patch
-
-import pytest
-from aiohttp import web_exceptions
-from apyhiveapi.api.hive_async_api import HiveApiAsync
-from apyhiveapi.helper.hive_exceptions import HiveApiError
-
-# ---------------------------------------------------------------------------
-# Shared helpers (same pattern as test_hive_async_api.py)
-# ---------------------------------------------------------------------------
-
-
-def _make_mock_response(status=200, json_data=None):
- resp = MagicMock()
- resp.status = status
- resp.text = AsyncMock(return_value="")
- resp.json = AsyncMock(return_value=json_data or {"data": "test"})
- resp.__aenter__ = AsyncMock(return_value=resp)
- resp.__aexit__ = AsyncMock(return_value=False)
- return resp
-
-
-def _make_api(status=200, json_data=None, token="test-token", file_mode=False):
- resp = _make_mock_response(status=status, json_data=json_data)
- websession = MagicMock()
- websession.request.return_value = resp
- websession.closed = False
- websession.close = AsyncMock()
- session = MagicMock()
- session.tokens = MagicMock()
- session.tokens.token_data = {"token": token}
- session.config = MagicMock()
- session.config.file = file_mode
- return HiveApiAsync(hive_session=session, websession=websession)
-
-
-# ---------------------------------------------------------------------------
-# Tests: request() branch — url is not None and status is not None (non-auth error)
-# ---------------------------------------------------------------------------
-
-
-class TestRequestNonAuthErrorBranch:
- """Cover lines 100-108: url/status not None branch leading to HiveApiError."""
-
- async def test_404_logs_and_raises_hive_api_error(self):
- """A 404 falls through to the url/status branch and raises HiveApiError."""
- api = _make_api(status=404)
- with pytest.raises(HiveApiError):
- await api.request("get", "https://beekeeper.hivehome.com/1.0/nodes/all")
-
- async def test_503_logs_and_raises_hive_api_error(self):
- """A 503 falls through to the url/status branch and raises HiveApiError."""
- api = _make_api(status=503)
- with pytest.raises(HiveApiError):
- await api.request("get", "https://beekeeper.hivehome.com/1.0/nodes/all")
-
- async def test_422_logs_and_raises_hive_api_error(self):
- """A 422 also falls through (not 401/403) and raises HiveApiError."""
- api = _make_api(status=422)
- with pytest.raises(HiveApiError):
- await api.request("get", "https://beekeeper.hivehome.com/1.0/devices")
-
-
-# ---------------------------------------------------------------------------
-# Tests: get_login_info() — sync method (lines 110-129)
-# ---------------------------------------------------------------------------
-
-
-class TestGetLoginInfo:
- """Cover lines 112-129: get_login_info() parses HTML and returns login dict."""
-
- def test_returns_upid_cliid_region(self):
- """Successful fetch returns correct keys from parsed HTML."""
- html_content = (
- b""
- )
- mock_response = MagicMock()
- mock_response.content = html_content
-
- api = _make_api()
- with patch(
- "apyhiveapi.api.hive_async_api.requests.get", return_value=mock_response
- ):
- result = api.get_login_info()
-
- assert result["UPID"] == "eu-west-1_abc123"
- assert result["CLIID"] == "client-xyz"
- # REGION is set to HiveSSOPoolId value
- assert result["REGION"] == "eu-west-1_abc123"
-
- def test_makes_request_to_sso_url(self):
- """Verifies requests.get is called with the SSO URL."""
- html_content = (
- b""
- )
- mock_response = MagicMock()
- mock_response.content = html_content
-
- api = _make_api()
- with patch(
- "apyhiveapi.api.hive_async_api.requests.get", return_value=mock_response
- ) as mock_get:
- api.get_login_info()
-
- mock_get.assert_called_once_with(
- url="https://sso.hivehome.com/", verify=False, timeout=api.timeout
- )
-
- def test_uses_first_script_tag(self):
- """PyQuery selects the first script — extra scripts are ignored."""
- html_content = (
- b""
- b''
- )
- mock_response = MagicMock()
- mock_response.content = html_content
-
- api = _make_api()
- with patch(
- "apyhiveapi.api.hive_async_api.requests.get", return_value=mock_response
- ):
- result = api.get_login_info()
-
- assert result["UPID"] == "eu-west-1_first"
-
-
-# ---------------------------------------------------------------------------
-# Tests: refresh_tokens() — lines 131-156
-# ---------------------------------------------------------------------------
-
-
-class TestRefreshTokens:
- """Cover lines 133-156: refresh_tokens() success, no-token, and error paths."""
-
- async def test_successful_request_with_non_ok_json_return_returns_json_return(self):
- """When request succeeds but json_return["original"] != HTTP_OK, returns json_return."""
- api = _make_api(status=200)
- # request() will succeed (200) but json_return is not updated by refresh_tokens
- # so json_return["original"] stays as the default string, not HTTP_OK (200)
- result = await api.refresh_tokens()
- # Returns self.json_return (the default dict)
- assert result == api.json_return
-
- async def test_session_tokens_read_before_request(self):
- """tokens are read from session.tokens.token_data before constructing the request."""
- api = _make_api(status=200, token="my-session-token")
- api.session.tokens.token_data = {
- "token": "my-session-token",
- "refreshToken": "r-tok",
- }
- result = await api.refresh_tokens()
- # No exception raised — tokens were read without error
- assert result is not None
-
- async def test_connection_error_raises_http_error(self):
- """ConnectionError inside the try block causes error() → HTTPError."""
- api = _make_api(status=200)
- api.websession.request.side_effect = ConnectionError("connection refused")
- with pytest.raises(web_exceptions.HTTPError):
- await api.refresh_tokens()
-
- async def test_os_error_raises_http_error(self):
- """OSError inside the try block causes error() → HTTPError."""
- api = _make_api(status=200)
- api.websession.request.side_effect = OSError("network error")
- with pytest.raises(web_exceptions.HTTPError):
- await api.refresh_tokens()
-
- async def test_runtime_error_raises_http_error(self):
- """RuntimeError inside the try block causes error() → HTTPError."""
- api = _make_api(status=200)
- api.websession.request.side_effect = RuntimeError("bad state")
- with pytest.raises(web_exceptions.HTTPError):
- await api.refresh_tokens()
-
- async def test_zero_division_raises_http_error(self):
- """ZeroDivisionError inside the try block causes error() → HTTPError."""
- api = _make_api(status=200)
- api.websession.request.side_effect = ZeroDivisionError("division by zero")
- with pytest.raises(web_exceptions.HTTPError):
- await api.refresh_tokens()
-
- async def test_json_return_true_when_ok_status_in_json_return(self):
- """When json_return["original"] equals HTTP_OK (200) and token is present,
- update_tokens is called and base_url is updated, returning True."""
- api = _make_api(status=200)
- # Manually set json_return to simulate a successful response
- api.json_return = {
- "original": 200,
- "parsed": {
- "token": "new-token",
- "platform": {"endpoint": "https://new.endpoint"},
- },
- }
- api.session.update_tokens = AsyncMock()
-
- # Patch request to be a no-op (doesn't modify json_return)
- with patch.object(api, "request", new_callable=AsyncMock) as mock_req:
- mock_req.return_value = MagicMock()
- result = await api.refresh_tokens()
-
- assert result is True
- api.session.update_tokens.assert_called_once_with(api.json_return["parsed"])
- assert api.base_url == "https://new.endpoint"
-
- async def test_json_return_true_without_token_in_parsed(self):
- """When json_return["original"] == HTTP_OK but no 'token' in parsed,
- update_tokens is NOT called and returns True."""
- api = _make_api(status=200)
- api.json_return = {
- "original": 200,
- "parsed": {"other_key": "value"},
- }
- api.session.update_tokens = AsyncMock()
-
- with patch.object(api, "request", new_callable=AsyncMock) as mock_req:
- mock_req.return_value = MagicMock()
- result = await api.refresh_tokens()
-
- assert result is True
- api.session.update_tokens.assert_not_called()
-
-
-# ---------------------------------------------------------------------------
-# Tests: motion_sensor() — lines 213-235
-# ---------------------------------------------------------------------------
-
-
-class TestMotionSensor:
- """Cover lines 215-235: motion_sensor() success and error paths."""
-
- async def test_success_returns_status_and_parsed(self):
- """Successful call returns status and parsed JSON."""
- payload = [{"event": "motion", "timestamp": 1234567890}]
- api = _make_api(status=200, json_data=payload)
- # motion_sensor uses urls["base"] which doesn't exist in HiveApiAsync;
- # add it so the URL can be constructed
- api.urls["base"] = ""
- sensor = {"type": "motionsensor", "id": "sensor-001"}
-
- result = await api.motion_sensor(sensor, fromepoch=1000000, toepoch=2000000)
-
- assert result["original"] == 200
- assert result["parsed"] == payload
-
- async def test_url_is_built_correctly(self):
- """Verifies the URL is assembled with correct sensor type and id."""
- api = _make_api(status=200, json_data=[])
- api.urls["base"] = "https://beekeeper-uk.hivehome.com/1.0"
- sensor = {"type": "contactsensor", "id": "abc-123"}
-
- captured_url = []
- original_request = api.request
-
- async def capture_request(method, url, **kwargs):
- captured_url.append(url)
- return await original_request(method, url, **kwargs)
-
- with patch.object(api, "request", side_effect=capture_request):
- await api.motion_sensor(sensor, fromepoch=100, toepoch=200)
-
- assert len(captured_url) == 1
- assert "contactsensor" in captured_url[0]
- assert "abc-123" in captured_url[0]
- assert "from=100" in captured_url[0]
- assert "to=200" in captured_url[0]
-
- async def test_os_error_raises_http_error(self):
- """OSError inside the try block causes error() → HTTPError."""
- api = _make_api(status=200)
- api.urls["base"] = ""
- sensor = {"type": "motionsensor", "id": "sensor-001"}
- api.websession.request.side_effect = OSError("fail")
- with pytest.raises(web_exceptions.HTTPError):
- await api.motion_sensor(sensor, fromepoch=1000, toepoch=2000)
-
- async def test_runtime_error_raises_http_error(self):
- """RuntimeError inside the try block causes error() → HTTPError."""
- api = _make_api(status=200)
- api.urls["base"] = ""
- sensor = {"type": "motionsensor", "id": "sensor-002"}
- api.websession.request.side_effect = RuntimeError("unexpected")
- with pytest.raises(web_exceptions.HTTPError):
- await api.motion_sensor(sensor, fromepoch=1000, toepoch=2000)
-
- async def test_zero_division_raises_http_error(self):
- """ZeroDivisionError inside the try block causes error() → HTTPError."""
- api = _make_api(status=200)
- api.urls["base"] = ""
- sensor = {"type": "motionsensor", "id": "sensor-003"}
- api.websession.request.side_effect = ZeroDivisionError()
- with pytest.raises(web_exceptions.HTTPError):
- await api.motion_sensor(sensor, fromepoch=1000, toepoch=2000)
-
-
-# ---------------------------------------------------------------------------
-# Tests: get_weather() — lines 237-249
-# ---------------------------------------------------------------------------
-
-
-class TestGetWeather:
- """Cover lines 239-249: get_weather() success, space encoding, and error paths."""
-
- async def test_success_returns_status_and_parsed(self):
- """Successful call returns status and parsed weather JSON."""
- payload = {"temperature": {"value": 15, "unit": "C"}}
- api = _make_api(status=200, json_data=payload)
-
- result = await api.get_weather("?lat=51.5&lon=-0.1")
-
- assert result["original"] == 200
- assert result["parsed"] == payload
-
- async def test_space_in_weather_url_is_encoded(self):
- """Spaces in the weather_url are replaced with %20."""
- api = _make_api(status=200, json_data={})
-
- captured_url = []
- original_request = api.request
-
- async def capture_request(method, url, **kwargs):
- captured_url.append(url)
- return await original_request(method, url, **kwargs)
-
- with patch.object(api, "request", side_effect=capture_request):
- await api.get_weather("?postcode=SW1A 2AA")
-
- assert len(captured_url) == 1
- assert " " not in captured_url[0]
- assert "%20" in captured_url[0]
-
- async def test_url_is_prefixed_with_weather_base(self):
- """The weather base URL is prepended to the given weather_url."""
- api = _make_api(status=200, json_data={})
-
- captured_url = []
- original_request = api.request
-
- async def capture_request(method, url, **kwargs):
- captured_url.append(url)
- return await original_request(method, url, **kwargs)
-
- with patch.object(api, "request", side_effect=capture_request):
- await api.get_weather("?lat=51.5")
-
- assert captured_url[0].startswith("https://weather.prod.bgchprod.info/weather")
-
- async def test_os_error_raises_http_error(self):
- """OSError inside the try block causes error() → HTTPError."""
- api = _make_api(status=200)
- api.websession.request.side_effect = OSError("network fail")
- with pytest.raises(web_exceptions.HTTPError):
- await api.get_weather("?lat=51.5")
-
- async def test_runtime_error_raises_http_error(self):
- """RuntimeError inside the try block causes error() → HTTPError."""
- api = _make_api(status=200)
- api.websession.request.side_effect = RuntimeError("unexpected")
- with pytest.raises(web_exceptions.HTTPError):
- await api.get_weather("?lat=51.5")
-
- async def test_zero_division_raises_http_error(self):
- """ZeroDivisionError inside the try block causes error() → HTTPError."""
- api = _make_api(status=200)
- api.websession.request.side_effect = ZeroDivisionError()
- with pytest.raises(web_exceptions.HTTPError):
- await api.get_weather("?lat=51.5")
-
- async def test_connection_error_raises_http_error(self):
- """ConnectionError inside the try block causes error() → HTTPError."""
- api = _make_api(status=200)
- api.websession.request.side_effect = ConnectionError("disconnected")
- with pytest.raises(web_exceptions.HTTPError):
- await api.get_weather("?lat=51.5")
-
-
-# ---------------------------------------------------------------------------
-# Tests: request() — url=None and resp.status=None skips the logging branch
-# ---------------------------------------------------------------------------
-
-
-class TestRequestUrlOrStatusNone:
- """Lines 100->108: when url is None or resp.status is None, skip log → raise directly."""
-
- async def test_none_status_skips_log_and_raises_hive_api_error(self):
- """resp.status=None causes branch 100->108 (skips the log lines) then raises."""
- api = _make_api(status=200)
- # Replace the websession response with one having status=None
- bad_resp = _make_mock_response(status=None)
- bad_resp.text = AsyncMock(return_value="")
- api.websession.request.return_value = bad_resp
- with pytest.raises(HiveApiError):
- await api.request("get", None)
-
-
-# ---------------------------------------------------------------------------
-# Tests: refresh_tokens() — session=None (134->136)
-# ---------------------------------------------------------------------------
-
-
-class TestRefreshTokensSessionNone:
- """Line 134->136: when self.session is None, skip token_data read (line 135)."""
-
- async def test_session_none_skips_token_data_read(self):
- """When session is None, tokens is not set from session → jsc uses undefined."""
- ws = MagicMock()
- ws.request.return_value = _make_mock_response(status=200)
- ws.closed = False
- ws.close = AsyncMock()
- api = HiveApiAsync(hive_session=None, websession=ws)
- # tokens is not defined before jsc, so this will raise NameError or UnboundLocalError;
- # what we need is that line 134's False branch (134->136) is traversed.
- try:
- await api.refresh_tokens()
- except (NameError, UnboundLocalError, AttributeError):
- pass # expected — tokens was never defined since session is None
diff --git a/tests/unit/test_hive_auth_async.py b/tests/unit/test_hive_auth_async.py
index fb541f8..9b03a7f 100644
--- a/tests/unit/test_hive_auth_async.py
+++ b/tests/unit/test_hive_auth_async.py
@@ -14,6 +14,7 @@
HiveInvalidPassword,
HiveInvalidUsername,
HiveRefreshTokenExpired,
+ HiveUnknownConfiguration,
)
# ---------------------------------------------------------------------------
@@ -73,18 +74,40 @@ async def _make_auth(
return auth
+_LOGIN_INFO = {
+ "UPID": "eu-west-1_TestPool",
+ "CLIID": "test-client-id",
+ "REGION": "eu-west-1_TestPool",
+}
+
+
# ---------------------------------------------------------------------------
# Tests: __init__
# ---------------------------------------------------------------------------
class TestHiveAuthAsyncInit:
- def test_pool_region_raises_value_error(self):
+ def test_pool_region_no_longer_accepted(self):
from apyhiveapi.api.hive_auth_async import HiveAuthAsync
- with pytest.raises(ValueError, match="pool_region"):
+ with pytest.raises(TypeError):
HiveAuthAsync(username="u", password="p", pool_region="eu-west-1")
+ async def test_async_init_sets_running_loop(self):
+ from apyhiveapi.api.hive_auth_async import HiveAuthAsync
+
+ auth = HiveAuthAsync(username="u@test.com", password="pass")
+ assert auth.loop is None # not set until async_init
+ mock_data = {
+ "UPID": "eu-west-1_Test",
+ "CLIID": "client-id",
+ "REGION": "eu-west-1_Test",
+ }
+ with patch.object(auth.api, "get_login_info", return_value=mock_data):
+ with patch("boto3.client", return_value=MagicMock()):
+ await auth.async_init()
+ assert auth.loop is not None
+
async def test_file_flag_set_for_magic_username(self):
from apyhiveapi.api.hive_auth_async import HiveAuthAsync
@@ -185,6 +208,49 @@ async def test_user_not_found_raises_invalid_username(self):
with pytest.raises(HiveInvalidUsername):
await auth.login()
+ @pytest.mark.asyncio
+ async def test_direct_authentication_without_challenge_stores_token(self):
+ """A response with AuthenticationResult and no ChallengeName must not crash."""
+ auth = await _make_auth()
+ auth_result = {"AuthenticationResult": {"AccessToken": "direct-tok"}}
+ auth.loop.run_in_executor.return_value = auth_result
+
+ result = await auth.login()
+
+ assert result is auth_result
+ assert auth.access_token == "direct-tok"
+
+ @pytest.mark.asyncio
+ async def test_login_regenerates_srp_ephemeral_each_attempt(self):
+ """Each login() must use a fresh SRP (a, A) ephemeral pair."""
+ auth = await _make_auth()
+ auth.loop.run_in_executor.return_value = {
+ "AuthenticationResult": {"AccessToken": "tok"}
+ }
+
+ initial_a = auth.large_a_value
+ await auth.login()
+ second_a = auth.large_a_value
+ await auth.login()
+ third_a = auth.large_a_value
+
+ assert len({initial_a, second_a, third_a}) == 3
+
+ @pytest.mark.asyncio
+ async def test_device_login_regenerates_srp_ephemeral(self):
+ """device_login() must use a fresh SRP (a, A) ephemeral pair."""
+ auth = await _make_auth(device_key="dk-1", device_group_key="grp-1")
+ auth.device_password = "dev-pass" # pragma: allowlist secret
+ auth.loop.run_in_executor.return_value = {"ChallengeParameters": {}}
+
+ initial_a = auth.large_a_value
+ with patch.object(
+ auth, "process_device_challenge", new=AsyncMock(return_value={})
+ ):
+ await auth.device_login()
+
+ assert auth.large_a_value != initial_a
+
@pytest.mark.asyncio
async def test_endpoint_error_on_initiate_raises_api_error(self):
auth = await _make_auth()
@@ -444,6 +510,14 @@ async def test_new_device_metadata_in_sms_stores_keys(self):
assert auth.device_group_key == "sms-grp"
assert auth.device_key == "sms-dev"
+ @pytest.mark.asyncio
+ async def test_no_authentication_result_key_does_not_raise(self):
+ auth = await _make_auth()
+ auth.loop.run_in_executor.return_value = {"ChallengeName": "SMS_MFA"}
+ result = await auth.sms_2fa("123456", {"Session": "sess-1"})
+ assert auth.access_token is None
+ assert result == {"ChallengeName": "SMS_MFA"}
+
# ---------------------------------------------------------------------------
# Tests: refresh_token
@@ -516,3 +590,840 @@ async def test_endpoint_error_raises_api_error(self):
auth.loop.run_in_executor.side_effect = _endpoint_error()
with pytest.raises(HiveApiError):
await auth.refresh_token("tok")
+
+
+# ---------------------------------------------------------------------------
+# Migrated from test_hive_auth_async_extended.py
+# ---------------------------------------------------------------------------
+
+
+class TestAsyncInit:
+ """Cover lines 98-112: async_init() sets pool_id, client_id, region and boto3 client."""
+
+ async def test_async_init_sets_pool_id_and_client_id(self):
+ """async_init reads login info and sets internal auth fields."""
+ from apyhiveapi.api.hive_auth_async import HiveAuthAsync
+
+ auth = HiveAuthAsync(username="user@test.com", password="pass")
+ auth.client = None
+
+ mock_boto_client = MagicMock()
+ mock_loop = MagicMock()
+ mock_loop.run_in_executor = AsyncMock(
+ side_effect=[_LOGIN_INFO, mock_boto_client]
+ )
+
+ with patch("asyncio.get_running_loop", return_value=mock_loop):
+ await auth.async_init()
+
+ assert auth._pool_id == "eu-west-1_TestPool"
+ assert auth._client_id == "test-client-id"
+ assert auth._region == "eu-west-1"
+ assert auth.client is mock_boto_client
+
+ async def test_async_init_splits_region_correctly(self):
+ """Region is extracted as the part before the underscore in UPID/REGION."""
+ from apyhiveapi.api.hive_auth_async import HiveAuthAsync
+
+ auth = HiveAuthAsync(username="user@test.com", password="pass")
+ auth.client = None
+
+ login_info = {
+ "UPID": "ap-southeast-2_XyzPool",
+ "CLIID": "ap-client",
+ "REGION": "ap-southeast-2_XyzPool",
+ }
+ mock_boto_client = MagicMock()
+ mock_loop = MagicMock()
+ mock_loop.run_in_executor = AsyncMock(
+ side_effect=[login_info, mock_boto_client]
+ )
+
+ with patch("asyncio.get_running_loop", return_value=mock_loop):
+ await auth.async_init()
+
+ assert auth._region == "ap-southeast-2"
+
+
+class TestCalculateA:
+ """Cover line 141: safety check when big_a % big_n == 0."""
+
+ async def test_safety_check_raises_when_a_is_zero_mod_n(self):
+ """If pow(g, a, n) == 0 mod n (i.e., equals big_n or 0), ValueError is raised."""
+ auth = await _make_auth()
+ with patch("builtins.pow", return_value=auth.big_n):
+ with pytest.raises(ValueError, match="Safety check for A failed"):
+ auth.calculate_a()
+
+ async def test_safety_check_passes_normally(self):
+ """Under normal random inputs, calculate_a does not raise and returns positive int."""
+ auth = await _make_auth()
+ result = auth.calculate_a()
+ assert result > 0
+
+
+class TestGetPasswordAuthenticationKey:
+ """Cover lines 155-172: get_password_authentication_key() computes HKDF."""
+
+ async def test_returns_bytes(self):
+ """With valid SRP inputs, the method returns a bytes-like value."""
+ auth = await _make_auth()
+ from apyhiveapi.api.srp_crypto import get_random
+
+ server_b_value = hex(get_random(128))[2:]
+ salt = hex(get_random(16))[2:]
+
+ with patch("apyhiveapi.api.hive_auth_async.calculate_u", return_value=99999):
+ result = auth.get_password_authentication_key(
+ "testuser", "testpass", server_b_value, salt
+ )
+
+ assert isinstance(result, (bytes, bytearray))
+
+ async def test_u_value_zero_raises_value_error(self):
+ """If calculate_u returns 0, ValueError is raised."""
+ auth = await _make_auth()
+ from apyhiveapi.api.srp_crypto import get_random
+
+ server_b_value = hex(get_random(128))[2:]
+ salt = hex(get_random(16))[2:]
+
+ with patch("apyhiveapi.api.hive_auth_async.calculate_u", return_value=0):
+ with pytest.raises(ValueError, match="U cannot be zero"):
+ auth.get_password_authentication_key(
+ "testuser", "testpass", server_b_value, salt
+ )
+
+ async def test_accepts_integer_server_b(self):
+ """server_b_value can be passed as an integer (handled by _to_int)."""
+ auth = await _make_auth()
+ from apyhiveapi.api.srp_crypto import get_random
+
+ server_b_int = get_random(128)
+ salt = hex(get_random(16))[2:]
+
+ with patch("apyhiveapi.api.hive_auth_async.calculate_u", return_value=12345):
+ result = auth.get_password_authentication_key(
+ "testuser", "testpass", server_b_int, salt
+ )
+
+ assert isinstance(result, (bytes, bytearray))
+
+
+class TestProcessChallenge:
+ """Cover lines 205-254: process_challenge() builds the SRP response."""
+
+ def _make_challenge_params(self, salt_as_int=False):
+ import base64
+
+ salt = "aabbccddee"
+ if salt_as_int:
+ salt = int("aabbccddee", 16)
+ return {
+ "USER_ID_FOR_SRP": "challenge-user@test.com",
+ "SALT": salt,
+ "SRP_B": "ff" * 32,
+ "SECRET_BLOCK": base64.b64encode(b"secret-block-bytes").decode(),
+ }
+
+ async def test_returns_required_keys(self):
+ """Basic challenge response includes mandatory SRP keys."""
+ auth = await _make_auth()
+ fake_hkdf = b"\x00" * 32
+ auth.loop.run_in_executor = AsyncMock(return_value=fake_hkdf)
+ params = self._make_challenge_params()
+ result = await auth.process_challenge(params)
+ assert "TIMESTAMP" in result
+ assert "USERNAME" in result
+ assert "PASSWORD_CLAIM_SECRET_BLOCK" in result
+ assert "PASSWORD_CLAIM_SIGNATURE" in result
+
+ async def test_sets_user_id_from_challenge(self):
+ """process_challenge stores USER_ID_FOR_SRP as self.user_id."""
+ auth = await _make_auth()
+ fake_hkdf = b"\x00" * 32
+ auth.loop.run_in_executor = AsyncMock(return_value=fake_hkdf)
+ params = self._make_challenge_params()
+ await auth.process_challenge(params)
+ assert auth.user_id == "challenge-user@test.com"
+
+ async def test_with_client_secret_adds_secret_hash(self):
+ """When client_secret is set, SECRET_HASH is added to the response."""
+ auth = await _make_auth(client_secret="my-secret")
+ fake_hkdf = b"\x00" * 32
+ auth.loop.run_in_executor = AsyncMock(return_value=fake_hkdf)
+ params = self._make_challenge_params()
+ result = await auth.process_challenge(params)
+ assert "SECRET_HASH" in result
+
+ async def test_without_client_secret_no_secret_hash(self):
+ """When client_secret is None, SECRET_HASH is absent from the response."""
+ auth = await _make_auth(client_secret=None)
+ fake_hkdf = b"\x00" * 32
+ auth.loop.run_in_executor = AsyncMock(return_value=fake_hkdf)
+ params = self._make_challenge_params()
+ result = await auth.process_challenge(params)
+ assert "SECRET_HASH" not in result
+
+ async def test_with_device_key_adds_device_key(self):
+ """When device_key is set, DEVICE_KEY is added to the response."""
+ auth = await _make_auth(device_key="dk-challenge")
+ fake_hkdf = b"\x00" * 32
+ auth.loop.run_in_executor = AsyncMock(return_value=fake_hkdf)
+ params = self._make_challenge_params()
+ result = await auth.process_challenge(params)
+ assert result["DEVICE_KEY"] == "dk-challenge"
+
+ async def test_without_device_key_no_device_key_in_response(self):
+ """When device_key is None, DEVICE_KEY is absent from the response."""
+ auth = await _make_auth(device_key=None)
+ fake_hkdf = b"\x00" * 32
+ auth.loop.run_in_executor = AsyncMock(return_value=fake_hkdf)
+ params = self._make_challenge_params()
+ result = await auth.process_challenge(params)
+ assert "DEVICE_KEY" not in result
+
+ async def test_salt_as_integer_triggers_pad_hex(self):
+ """When SALT is an integer (not str), pad_hex is applied before use."""
+ auth = await _make_auth()
+ fake_hkdf = b"\x00" * 32
+ auth.loop.run_in_executor = AsyncMock(return_value=fake_hkdf)
+ params = self._make_challenge_params(salt_as_int=True)
+ result = await auth.process_challenge(params)
+ assert "TIMESTAMP" in result
+
+
+class TestLoginClientNone:
+ """Cover line 263: when client is None, async_init() is awaited."""
+
+ async def test_login_calls_async_init_when_client_is_none(self):
+ """If client is None before login, async_init is called before SRP flow."""
+ auth = await _make_auth()
+ auth.client = None
+ auth.use_file = False
+
+ auth_result = {"AuthenticationResult": {"AccessToken": "post-init-token"}}
+ challenge_response = {
+ "ChallengeName": "PASSWORD_VERIFIER",
+ "ChallengeParameters": {
+ "USER_ID_FOR_SRP": "user@test.com",
+ "SALT": "aabbccdd",
+ "SRP_B": "ccddee",
+ "SECRET_BLOCK": "YWJj",
+ },
+ }
+
+ async def fake_async_init():
+ auth.client = MagicMock()
+ auth._client_id = "test-client-id"
+ auth._pool_id = "eu-west-1_TestPool"
+ auth._region = "eu-west-1"
+
+ with patch.object(auth, "async_init", side_effect=fake_async_init) as mock_init:
+ with patch.object(
+ auth, "process_challenge", new_callable=AsyncMock
+ ) as mock_ch:
+ mock_ch.return_value = {
+ "TIMESTAMP": "Mon Jan 01 00:00:00 UTC 2024",
+ "USERNAME": "user",
+ }
+ auth.loop.run_in_executor = AsyncMock(
+ side_effect=[challenge_response, auth_result]
+ )
+ result = await auth.login()
+
+ mock_init.assert_called_once()
+ assert "AuthenticationResult" in result
+
+
+class TestLoginUnsupportedChallenge:
+ """Cover lines 335-337: non-PASSWORD_VERIFIER challenge raises NotImplementedError."""
+
+ async def test_new_password_required_raises_not_implemented(self):
+ """NEW_PASSWORD_REQUIRED challenge is not supported and raises NotImplementedError."""
+ auth = await _make_auth()
+ auth.loop.run_in_executor = AsyncMock(
+ return_value={
+ "ChallengeName": "NEW_PASSWORD_REQUIRED",
+ "ChallengeParameters": {},
+ }
+ )
+ with pytest.raises(NotImplementedError, match="NEW_PASSWORD_REQUIRED"):
+ await auth.login()
+
+ async def test_custom_challenge_raises_not_implemented(self):
+ """Any unknown challenge name raises NotImplementedError."""
+ auth = await _make_auth()
+ auth.loop.run_in_executor = AsyncMock(
+ return_value={
+ "ChallengeName": "UNKNOWN_CHALLENGE_TYPE",
+ "ChallengeParameters": {},
+ }
+ )
+ with pytest.raises(NotImplementedError, match="UNKNOWN_CHALLENGE_TYPE"):
+ await auth.login()
+
+
+class TestLoginResourceNotFound:
+ """Cover lines 307-311: ResourceNotFoundException in respond_to_auth_challenge."""
+
+ async def test_resource_not_found_raises_invalid_device_authentication(self):
+ """ResourceNotFoundException during challenge → HiveInvalidDeviceAuthentication."""
+ auth = await _make_auth()
+ challenge_response = {
+ "ChallengeName": "PASSWORD_VERIFIER",
+ "ChallengeParameters": {
+ "USER_ID_FOR_SRP": "user@test.com",
+ "SALT": "aabbccdd",
+ "SRP_B": "ccddee",
+ "SECRET_BLOCK": "YWJj",
+ },
+ }
+ resource_err = _named_client_error("ResourceNotFoundException")
+ with patch.object(auth, "process_challenge", new_callable=AsyncMock) as mock_ch:
+ mock_ch.return_value = {
+ "TIMESTAMP": "Mon Jan 01 00:00:00 UTC 2024",
+ "USERNAME": "user",
+ }
+ auth.loop.run_in_executor = AsyncMock(
+ side_effect=[challenge_response, resource_err]
+ )
+ with pytest.raises(HiveInvalidDeviceAuthentication):
+ await auth.login()
+
+
+class TestLoginEndpointErrorOnChallenge:
+ """Cover lines 312-317: EndpointConnectionError during respond_to_auth_challenge."""
+
+ async def test_endpoint_error_on_challenge_raises_api_error(self):
+ """EndpointConnectionError during SRP challenge response → HiveApiError."""
+ auth = await _make_auth()
+ challenge_response = {
+ "ChallengeName": "PASSWORD_VERIFIER",
+ "ChallengeParameters": {
+ "USER_ID_FOR_SRP": "user@test.com",
+ "SALT": "aabbccdd",
+ "SRP_B": "ccddee",
+ "SECRET_BLOCK": "YWJj",
+ },
+ }
+ with patch.object(auth, "process_challenge", new_callable=AsyncMock) as mock_ch:
+ mock_ch.return_value = {
+ "TIMESTAMP": "Mon Jan 01 00:00:00 UTC 2024",
+ "USERNAME": "user",
+ }
+ auth.loop.run_in_executor = AsyncMock(
+ side_effect=[challenge_response, _endpoint_error()]
+ )
+ with pytest.raises(HiveApiError):
+ await auth.login()
+
+
+class TestLoginResultHandling:
+ """Cover lines 321-333: AuthenticationResult presence/absence in login result."""
+
+ async def test_result_without_authentication_result_does_not_store_token(self):
+ """If result lacks 'AuthenticationResult', access_token is not set."""
+ auth = await _make_auth()
+ challenge_response = {
+ "ChallengeName": "PASSWORD_VERIFIER",
+ "ChallengeParameters": {
+ "USER_ID_FOR_SRP": "user@test.com",
+ "SALT": "aabbccdd",
+ "SRP_B": "ccddee",
+ "SECRET_BLOCK": "YWJj",
+ },
+ }
+ sms_challenge_result = {
+ "ChallengeName": "SMS_MFA",
+ "Session": "session-tok",
+ "ChallengeParameters": {},
+ }
+ with patch.object(auth, "process_challenge", new_callable=AsyncMock) as mock_ch:
+ mock_ch.return_value = {
+ "TIMESTAMP": "Mon Jan 01 00:00:00 UTC 2024",
+ "USERNAME": "user",
+ }
+ auth.loop.run_in_executor = AsyncMock(
+ side_effect=[challenge_response, sms_challenge_result]
+ )
+ result = await auth.login()
+
+ assert auth.access_token is None
+ assert result is sms_challenge_result
+
+ async def test_result_with_authentication_result_but_no_new_device_metadata(self):
+ """AuthenticationResult without NewDeviceMetadata sets access_token only."""
+ auth = await _make_auth()
+ challenge_response = {
+ "ChallengeName": "PASSWORD_VERIFIER",
+ "ChallengeParameters": {
+ "USER_ID_FOR_SRP": "user@test.com",
+ "SALT": "aabbccdd",
+ "SRP_B": "ccddee",
+ "SECRET_BLOCK": "YWJj",
+ },
+ }
+ auth_result = {"AuthenticationResult": {"AccessToken": "my-access-token"}}
+ with patch.object(auth, "process_challenge", new_callable=AsyncMock) as mock_ch:
+ mock_ch.return_value = {
+ "TIMESTAMP": "Mon Jan 01 00:00:00 UTC 2024",
+ "USERNAME": "user",
+ }
+ auth.loop.run_in_executor = AsyncMock(
+ side_effect=[challenge_response, auth_result]
+ )
+ await auth.login()
+
+ assert auth.access_token == "my-access-token"
+ assert auth.device_group_key is None
+ assert auth.device_key is None
+
+
+class TestDeviceLoginClientNone:
+ """Cover device_login's async_init call when client is None."""
+
+ async def test_device_login_calls_async_init_when_client_is_none(self):
+ """If client is None, async_init is called before proceeding."""
+ auth = await _make_auth(device_key="dk-1")
+ auth.client = None
+
+ async def fake_async_init():
+ auth.client = MagicMock()
+ auth._client_id = "test-client-id"
+
+ with patch.object(auth, "async_init", side_effect=fake_async_init) as mock_init:
+ auth.loop.run_in_executor = AsyncMock(
+ side_effect=_named_client_error("ResourceNotFoundException")
+ )
+ with pytest.raises(HiveInvalidDeviceAuthentication):
+ await auth.device_login()
+
+ mock_init.assert_called_once()
+
+
+class TestSms2faNoNewDeviceMetadata:
+ """Cover line 441: sms_2fa when NewDeviceMetadata is absent in result."""
+
+ async def test_no_new_device_metadata_does_not_set_device_keys(self):
+ """When NewDeviceMetadata is absent, device_group_key and device_key stay None."""
+ auth = await _make_auth()
+ original_group_key = auth.device_group_key
+ original_device_key = auth.device_key
+
+ sms_result = {
+ "AuthenticationResult": {
+ "AccessToken": "sms-access-token",
+ }
+ }
+ auth.loop.run_in_executor = AsyncMock(return_value=sms_result)
+
+ result = await auth.sms_2fa("654321", {"Session": "sess-abc"})
+
+ assert auth.access_token == "sms-access-token"
+ assert auth.device_group_key == original_group_key
+ assert auth.device_key == original_device_key
+ assert result is sms_result
+
+
+class TestSms2faCodeMismatch:
+ """Cover lines 424-429: CodeMismatchException raises HiveInvalid2FACode."""
+
+ async def test_code_mismatch_raises_invalid_2fa_code(self):
+ """CodeMismatchException in sms_2fa raises HiveInvalid2FACode."""
+ auth = await _make_auth()
+ auth.loop.run_in_executor = AsyncMock(
+ side_effect=_named_client_error("CodeMismatchException")
+ )
+ with pytest.raises(HiveInvalid2FACode):
+ await auth.sms_2fa("000000", {"Session": "sess-1"})
+
+ async def test_not_authorized_raises_invalid_2fa_code(self):
+ """NotAuthorizedException in sms_2fa raises HiveInvalid2FACode."""
+ auth = await _make_auth()
+ auth.loop.run_in_executor = AsyncMock(
+ side_effect=_named_client_error("NotAuthorizedException")
+ )
+ with pytest.raises(HiveInvalid2FACode):
+ await auth.sms_2fa("111111", {"Session": "sess-2"})
+
+
+class TestRefreshTokenClientNone:
+ """Cover line 440-441: refresh_token calls async_init when client is None."""
+
+ async def test_refresh_token_calls_async_init_when_client_is_none(self):
+ """If client is None, async_init is awaited before refreshing."""
+ auth = await _make_auth()
+ auth.client = None
+
+ result_payload = {"AuthenticationResult": {"AccessToken": "refreshed-tok"}}
+
+ async def fake_async_init():
+ auth.client = MagicMock()
+
+ with patch.object(auth, "async_init", side_effect=fake_async_init) as mock_init:
+ auth.loop.run_in_executor = AsyncMock(return_value=result_payload)
+ result = await auth.refresh_token("some-refresh-token")
+
+ mock_init.assert_called_once()
+ assert result is result_payload
+
+
+class TestRefreshTokenResultPath:
+ """Cover lines 479-485: refresh_token when result is returned normally."""
+
+ async def test_returns_result_directly(self):
+ """refresh_token returns the result from Cognito directly."""
+ auth = await _make_auth()
+ result_payload = {"AuthenticationResult": {"AccessToken": "tok-xyz"}}
+ auth.loop.run_in_executor = AsyncMock(return_value=result_payload)
+ result = await auth.refresh_token("refresh-tok-abc")
+ assert result is result_payload
+
+ async def test_with_device_key_includes_device_key_param(self):
+ """When device_key is set, DEVICE_KEY is included in auth_params."""
+ auth = await _make_auth(device_key="dk-refresh-001")
+ result_payload = {"AuthenticationResult": {"AccessToken": "tok-dk"}}
+ auth.loop.run_in_executor = AsyncMock(return_value=result_payload)
+ result = await auth.refresh_token("refresh-tok-dk")
+ assert result is result_payload
+
+ async def test_without_device_key_sends_only_refresh_token(self):
+ """When device_key is None, auth_params has only REFRESH_TOKEN."""
+ auth = await _make_auth(device_key=None)
+ result_payload = {"AuthenticationResult": {"AccessToken": "tok-no-dk"}}
+ auth.loop.run_in_executor = AsyncMock(return_value=result_payload)
+ result = await auth.refresh_token("refresh-tok-no-dk")
+ assert result is result_payload
+
+
+class TestLoginInitiateAuthSwallowedClientError:
+ """Arc 280->288: ClientError caught but class name is not UserNotFoundException."""
+
+ async def test_other_client_error_in_initiate_auth_falls_through(self):
+ """Non-UserNotFoundException ClientError raises HiveApiError."""
+ auth = await _make_auth()
+
+ wrong_cls = type("SomeOtherError", (botocore.exceptions.ClientError,), {})
+ wrong_err = wrong_cls(
+ {"Error": {"Code": "SomeOtherError", "Message": "msg"}}, "op"
+ )
+ auth.loop.run_in_executor = AsyncMock(side_effect=wrong_err)
+
+ with pytest.raises(HiveApiError):
+ await auth.login()
+
+
+class TestLoginInitiateAuthSwallowedEndpointError:
+ """EndpointConnectionError in initiate_auth always raises HiveApiError."""
+
+ async def test_wrong_name_endpoint_error_in_initiate_auth_raises_api_error(self):
+ """Any EndpointConnectionError subclass in initiate_auth raises HiveApiError."""
+ auth = await _make_auth()
+
+ wrong_cls = type(
+ "WrongEndpoint", (botocore.exceptions.EndpointConnectionError,), {}
+ )
+ wrong_err = wrong_cls(endpoint_url="https://cognito.eu-west-1.amazonaws.com")
+ auth.loop.run_in_executor = AsyncMock(side_effect=wrong_err)
+
+ with pytest.raises(HiveApiError):
+ await auth.login()
+
+
+class TestLoginChallengeSwallowedClientError:
+ """Arc 307->319: ClientError caught in challenge response with name not matching."""
+
+ async def test_other_client_error_in_challenge_falls_through(self):
+ """ClientError that is neither NotAuthorized nor ResourceNotFound raises HiveApiError."""
+ auth = await _make_auth()
+
+ challenge_response = {
+ "ChallengeName": "PASSWORD_VERIFIER",
+ "ChallengeParameters": {
+ "USER_ID_FOR_SRP": "user@test.com",
+ "SALT": "aabbccdd",
+ "SRP_B": "ccddee",
+ "SECRET_BLOCK": "YWJj",
+ },
+ }
+
+ wrong_cls = type("ThirdPartyError", (botocore.exceptions.ClientError,), {})
+ wrong_err = wrong_cls(
+ {"Error": {"Code": "ThirdPartyError", "Message": "msg"}}, "op"
+ )
+
+ with patch.object(auth, "process_challenge", new_callable=AsyncMock) as mock_ch:
+ mock_ch.return_value = {"TIMESTAMP": "...", "USERNAME": "user"}
+ auth.loop.run_in_executor = AsyncMock(
+ side_effect=[challenge_response, wrong_err]
+ )
+ with pytest.raises(HiveApiError):
+ await auth.login()
+
+
+class TestLoginChallengeSwallowedEndpointError:
+ """EndpointConnectionError in respond_to_auth_challenge always raises HiveApiError."""
+
+ async def test_wrong_name_endpoint_error_in_challenge_raises_api_error(self):
+ """Any EndpointConnectionError subclass in SRP challenge raises HiveApiError."""
+ auth = await _make_auth()
+
+ challenge_response = {
+ "ChallengeName": "PASSWORD_VERIFIER",
+ "ChallengeParameters": {
+ "USER_ID_FOR_SRP": "user@test.com",
+ "SALT": "aabbccdd",
+ "SRP_B": "ccddee",
+ "SECRET_BLOCK": "YWJj",
+ },
+ }
+
+ wrong_cls = type(
+ "WrongEndpoint", (botocore.exceptions.EndpointConnectionError,), {}
+ )
+ wrong_err = wrong_cls(endpoint_url="https://cognito.eu-west-1.amazonaws.com")
+
+ with patch.object(auth, "process_challenge", new_callable=AsyncMock) as mock_ch:
+ mock_ch.return_value = {"TIMESTAMP": "...", "USERNAME": "user"}
+ auth.loop.run_in_executor = AsyncMock(
+ side_effect=[challenge_response, wrong_err]
+ )
+ with pytest.raises(HiveApiError):
+ await auth.login()
+
+
+class TestDeviceLoginSuccessPath:
+ """Lines 364-367, 391: device_login processes device challenge and returns result."""
+
+ async def test_successful_device_login_returns_auth_result(self):
+ """Full device_login success: process_device_challenge called, result returned."""
+ auth = await _make_auth(device_key="dk-abc", device_group_key="grp-abc")
+ auth.device_password = "dev-pass" # pragma: allowlist secret
+
+ initial_result = {
+ "ChallengeParameters": {
+ "USERNAME": "user@test.com",
+ "SALT": "aabbccdd",
+ "SRP_B": "ccddee",
+ "SECRET_BLOCK": "YWJj",
+ }
+ }
+ final_result = {"AuthenticationResult": {"AccessToken": "device-access-token"}}
+
+ with patch.object(
+ auth, "process_device_challenge", new_callable=AsyncMock
+ ) as mock_pdc:
+ mock_pdc.return_value = {
+ "TIMESTAMP": "Mon Jan 01 00:00:00 UTC 2024",
+ "USERNAME": "user@test.com",
+ "PASSWORD_CLAIM_SECRET_BLOCK": "YWJj", # pragma: allowlist secret
+ "PASSWORD_CLAIM_SIGNATURE": "sig", # pragma: allowlist secret
+ "DEVICE_KEY": "dk-abc",
+ }
+ auth.loop.run_in_executor = AsyncMock(
+ side_effect=[initial_result, final_result]
+ )
+ result = await auth.device_login()
+
+ mock_pdc.assert_called_once_with(initial_result["ChallengeParameters"])
+ assert result is final_result
+
+ async def test_device_login_calls_second_respond_to_auth_challenge(self):
+ """Lines 367-375: second respond_to_auth_challenge is called with device challenge."""
+ auth = await _make_auth(device_key="dk-xyz", device_group_key="grp-xyz")
+ auth.device_password = "dev-pass-xyz" # pragma: allowlist secret
+
+ initial_result = {
+ "ChallengeParameters": {
+ "USERNAME": "user@test.com",
+ "SALT": "11223344",
+ "SRP_B": "55667788",
+ "SECRET_BLOCK": "dGVzdA==", # pragma: allowlist secret
+ }
+ }
+ final_result = {"AuthenticationResult": {"AccessToken": "tok-xyz"}}
+
+ challenge_resp = {
+ "TIMESTAMP": "Mon Jan 01 00:00:00 UTC 2024",
+ "USERNAME": "user@test.com",
+ "PASSWORD_CLAIM_SECRET_BLOCK": "dGVzdA==", # pragma: allowlist secret
+ "PASSWORD_CLAIM_SIGNATURE": "sig", # pragma: allowlist secret
+ "DEVICE_KEY": "dk-xyz",
+ }
+
+ with patch.object(
+ auth, "process_device_challenge", new_callable=AsyncMock
+ ) as mock_pdc:
+ mock_pdc.return_value = challenge_resp
+ auth.loop.run_in_executor = AsyncMock(
+ side_effect=[initial_result, final_result]
+ )
+ result = await auth.device_login()
+
+ assert auth.loop.run_in_executor.call_count == 2
+ assert result["AuthenticationResult"]["AccessToken"] == "tok-xyz"
+
+
+class TestDeviceLoginEndpointWrongName:
+ """Any EndpointConnectionError in device_login always raises HiveApiError."""
+
+ async def test_wrong_name_endpoint_error_raises_api_error(self):
+ """Any EndpointConnectionError subclass in device_login raises HiveApiError."""
+ auth = await _make_auth(device_key="dk-err", device_group_key="grp-err")
+ auth.device_password = "dev-pass-err" # pragma: allowlist secret
+
+ wrong_cls = type(
+ "WrongEndpoint", (botocore.exceptions.EndpointConnectionError,), {}
+ )
+ wrong_err = wrong_cls(endpoint_url="https://cognito.eu-west-1.amazonaws.com")
+ auth.loop.run_in_executor = AsyncMock(side_effect=wrong_err)
+
+ with pytest.raises(HiveApiError):
+ await auth.device_login()
+
+
+class TestSms2faUnrecognisedClientError:
+ """An unrecognised ClientError in sms_2fa must surface as HiveApiError."""
+
+ async def test_other_client_error_raises_hive_api_error(self):
+ """A ClientError that is not a 2FA rejection must not be swallowed."""
+ auth = await _make_auth()
+
+ auth.loop.run_in_executor = AsyncMock(
+ side_effect=_named_client_error("LimitExceededException")
+ )
+
+ with pytest.raises(HiveApiError):
+ await auth.sms_2fa("123456", {"Session": "sess-xyz"})
+
+
+class TestSms2faSwallowedEndpointError:
+ """Any EndpointConnectionError in sms_2fa raises HiveApiError."""
+
+ async def test_wrong_name_endpoint_error_raises_api_error(self):
+ """Any EndpointConnectionError subclass in sms_2fa raises HiveApiError."""
+ auth = await _make_auth()
+
+ wrong_cls = type(
+ "WrongEndpoint", (botocore.exceptions.EndpointConnectionError,), {}
+ )
+ wrong_err = wrong_cls(endpoint_url="https://cognito.eu-west-1.amazonaws.com")
+ auth.loop.run_in_executor = AsyncMock(side_effect=wrong_err)
+
+ with pytest.raises(HiveApiError):
+ await auth.sms_2fa("654321", {"Session": "sess-abc"})
+
+
+class TestRefreshTokenSwallowedEndpointError:
+ """Any EndpointConnectionError in refresh_token raises HiveApiError."""
+
+ async def test_wrong_name_endpoint_error_raises_api_error(self):
+ """Any EndpointConnectionError subclass in refresh_token raises HiveApiError."""
+ auth = await _make_auth()
+
+ wrong_cls = type(
+ "WrongEndpoint", (botocore.exceptions.EndpointConnectionError,), {}
+ )
+ wrong_err = wrong_cls(endpoint_url="https://cognito.eu-west-1.amazonaws.com")
+ auth.loop.run_in_executor = AsyncMock(side_effect=wrong_err)
+
+ with pytest.raises(HiveApiError):
+ await auth.refresh_token("some-refresh-token")
+
+
+class TestAsyncInitMissingKeys:
+ """async_init must raise HiveUnknownConfiguration when login info keys are absent."""
+
+ async def test_async_init_missing_region_raises_configuration_error(self):
+ """If REGION is absent from login info, raise HiveUnknownConfiguration."""
+ from apyhiveapi.api.hive_auth_async import HiveAuthAsync
+
+ auth = HiveAuthAsync(username="user@test.com", password="pass")
+ bad_login_info = {"UPID": "eu-west-1_TestPool", "CLIID": "test-client-id"}
+ mock_loop = MagicMock()
+ mock_loop.run_in_executor = AsyncMock(side_effect=[bad_login_info])
+ with patch("asyncio.get_running_loop", return_value=mock_loop):
+ with pytest.raises(HiveUnknownConfiguration):
+ await auth.async_init()
+
+ async def test_async_init_missing_upid_raises_configuration_error(self):
+ """If UPID is absent from login info, raise HiveUnknownConfiguration."""
+ from apyhiveapi.api.hive_auth_async import HiveAuthAsync
+
+ auth = HiveAuthAsync(username="user@test.com", password="pass")
+ bad_login_info = {"CLIID": "test-client-id", "REGION": "eu-west-1_TestPool"}
+ mock_loop = MagicMock()
+ mock_loop.run_in_executor = AsyncMock(side_effect=[bad_login_info])
+ with patch("asyncio.get_running_loop", return_value=mock_loop):
+ with pytest.raises(HiveUnknownConfiguration):
+ await auth.async_init()
+
+
+class TestGetPasswordAuthKeyNonePoolId:
+ """get_password_authentication_key must not crash with AttributeError when _pool_id is None."""
+
+ async def test_none_pool_id_raises_configuration_error(self):
+ """If _pool_id is None, raise HiveUnknownConfiguration (not AttributeError)."""
+ auth = await _make_auth()
+ auth._pool_id = None
+ with pytest.raises(HiveUnknownConfiguration):
+ auth.get_password_authentication_key("user", "pass", "DEADBEEF", "ABCDEF")
+
+
+class TestLoginUnhandledClientError:
+ """Unhandled ClientError codes in login() must raise HiveApiError, not crash."""
+
+ async def test_initiate_auth_unhandled_error_raises_hive_api_error(self):
+ """Non-UserNotFoundException ClientError from initiate_auth raises HiveApiError."""
+ auth = await _make_auth()
+ err = botocore.exceptions.ClientError(
+ {"Error": {"Code": "TooManyRequestsException", "Message": "too many"}},
+ "InitiateAuth",
+ )
+ auth.loop.run_in_executor = AsyncMock(side_effect=err)
+ with pytest.raises(HiveApiError):
+ await auth.login()
+
+ async def test_respond_to_challenge_unhandled_error_raises_hive_api_error(self):
+ """Non-NotAuthorized/ResourceNotFound ClientError raises HiveApiError."""
+ auth = await _make_auth()
+ challenge_response = {
+ "ChallengeName": "PASSWORD_VERIFIER",
+ "ChallengeParameters": {
+ "USER_ID_FOR_SRP": "user@test.com",
+ "SALT": "aabbccdd",
+ "SRP_B": "ccddee",
+ "SECRET_BLOCK": "YWJj",
+ },
+ }
+ respond_err = botocore.exceptions.ClientError(
+ {"Error": {"Code": "InternalErrorException", "Message": "internal"}},
+ "RespondToAuthChallenge",
+ )
+ auth.loop.run_in_executor = AsyncMock(
+ side_effect=[challenge_response, respond_err]
+ )
+ auth.process_challenge = AsyncMock(
+ return_value={"TIMESTAMP": "t", "USERNAME": "u"}
+ )
+ with pytest.raises(HiveApiError):
+ await auth.login()
+
+
+class TestAsyncInitNoneGuard:
+ """async_init must guard against get_login_info returning None."""
+
+ async def test_async_init_raises_when_login_info_is_none(self):
+ """async_init raises HiveUnknownConfiguration when get_login_info returns None."""
+ from apyhiveapi.api.hive_auth_async import HiveAuthAsync
+
+ auth = HiveAuthAsync(username="user@test.com", password="pass")
+ auth.client = None
+
+ mock_loop = MagicMock()
+ mock_loop.run_in_executor = AsyncMock(return_value=None)
+
+ with patch("asyncio.get_running_loop", return_value=mock_loop):
+ with pytest.raises(HiveUnknownConfiguration):
+ await auth.async_init()
diff --git a/tests/unit/test_hive_auth_async_extended.py b/tests/unit/test_hive_auth_async_extended.py
deleted file mode 100644
index 54d408d..0000000
--- a/tests/unit/test_hive_auth_async_extended.py
+++ /dev/null
@@ -1,969 +0,0 @@
-"""Extended unit tests for HiveAuthAsync — covers previously uncovered paths."""
-
-from __future__ import annotations
-
-from unittest.mock import AsyncMock, MagicMock, patch
-
-import botocore.exceptions
-import pytest
-from apyhiveapi.helper.hive_exceptions import (
- HiveApiError,
- HiveInvalid2FACode,
- HiveInvalidDeviceAuthentication,
-)
-
-# ---------------------------------------------------------------------------
-# Exception factories (same pattern as test_hive_auth_async.py)
-# ---------------------------------------------------------------------------
-
-
-def _named_client_error(
- code: str, message: str = ""
-) -> botocore.exceptions.ClientError:
- """Return a ClientError whose __class__.__name__ matches ``code``."""
- cls = type(code, (botocore.exceptions.ClientError,), {})
- return cls(
- {"Error": {"Code": code, "Message": message}},
- "operation",
- )
-
-
-def _endpoint_error() -> botocore.exceptions.EndpointConnectionError:
- return botocore.exceptions.EndpointConnectionError(
- endpoint_url="https://cognito.eu-west-1.amazonaws.com"
- )
-
-
-# ---------------------------------------------------------------------------
-# Shared factory
-# ---------------------------------------------------------------------------
-
-_LOGIN_INFO = {
- "UPID": "eu-west-1_TestPool",
- "CLIID": "test-client-id",
- "REGION": "eu-west-1_TestPool",
-}
-
-
-async def _make_auth(
- username: str = "user@test.com",
- password: str = "testpass",
- device_key: str | None = None,
- device_group_key: str | None = None,
- device_password: str | None = None,
- client_secret: str | None = None,
-):
- from apyhiveapi.api.hive_auth_async import HiveAuthAsync
-
- auth = HiveAuthAsync(
- username=username,
- password=password,
- device_key=device_key,
- device_group_key=device_group_key,
- device_password=device_password,
- client_secret=client_secret,
- )
- # Bypass async_init — inject mocked internals directly.
- auth.client = MagicMock()
- auth._client_id = "test-client-id"
- auth._pool_id = "eu-west-1_TestPool"
- auth._region = "eu-west-1"
- auth.loop = MagicMock()
- auth.loop.run_in_executor = AsyncMock()
- return auth
-
-
-# ---------------------------------------------------------------------------
-# Tests: async_init() — lines 96-112
-# ---------------------------------------------------------------------------
-
-
-class TestAsyncInit:
- """Cover lines 98-112: async_init() sets pool_id, client_id, region and
- boto3 client."""
-
- async def test_async_init_sets_pool_id_and_client_id(self):
- """async_init reads login info and sets internal auth fields."""
- from apyhiveapi.api.hive_auth_async import HiveAuthAsync
-
- auth = HiveAuthAsync(username="user@test.com", password="pass")
- auth.client = None # trigger async_init flow
-
- mock_boto_client = MagicMock()
-
- auth.loop = MagicMock()
- auth.loop.run_in_executor = AsyncMock(
- side_effect=[_LOGIN_INFO, mock_boto_client]
- )
-
- await auth.async_init()
-
- assert auth._pool_id == "eu-west-1_TestPool"
- assert auth._client_id == "test-client-id"
- assert auth._region == "eu-west-1"
- assert auth.client is mock_boto_client
-
- async def test_async_init_splits_region_correctly(self):
- """Region is extracted as the part before the underscore in UPID/REGION."""
- from apyhiveapi.api.hive_auth_async import HiveAuthAsync
-
- auth = HiveAuthAsync(username="user@test.com", password="pass")
- auth.client = None
-
- login_info = {
- "UPID": "ap-southeast-2_XyzPool",
- "CLIID": "ap-client",
- "REGION": "ap-southeast-2_XyzPool",
- }
- mock_boto_client = MagicMock()
- auth.loop = MagicMock()
- auth.loop.run_in_executor = AsyncMock(
- side_effect=[login_info, mock_boto_client]
- )
-
- await auth.async_init()
-
- assert auth._region == "ap-southeast-2"
-
-
-# ---------------------------------------------------------------------------
-# Tests: calculate_a() safety check — line 140-141
-# ---------------------------------------------------------------------------
-
-
-class TestCalculateA:
- """Cover line 141: safety check when big_a % big_n == 0."""
-
- async def test_safety_check_raises_when_a_is_zero_mod_n(self):
- """If pow(g, a, n) == 0 mod n (i.e., equals big_n or 0), ValueError is raised."""
- auth = await _make_auth()
- # Force pow to return auth.big_n so that big_a % big_n == 0
- with patch("builtins.pow", return_value=auth.big_n):
- with pytest.raises(ValueError, match="Safety check for A failed"):
- auth.calculate_a()
-
- async def test_safety_check_passes_normally(self):
- """Under normal random inputs, calculate_a does not raise and returns positive int."""
- auth = await _make_auth()
- # calculate_a was already called during __init__; calling it again should also work
- result = auth.calculate_a()
- assert result > 0
-
-
-# ---------------------------------------------------------------------------
-# Tests: get_password_authentication_key() — lines 155-172
-# ---------------------------------------------------------------------------
-
-
-class TestGetPasswordAuthenticationKey:
- """Cover lines 155-172: get_password_authentication_key() computes HKDF."""
-
- async def test_returns_bytes(self):
- """With valid SRP inputs, the method returns a bytes-like value."""
- auth = await _make_auth()
- from apyhiveapi.api.srp_crypto import get_random
-
- # Pick a server_b that won't produce u_value == 0 by using a known large value
- server_b_value = hex(get_random(128))[2:]
- salt = hex(get_random(16))[2:]
-
- # Patch calculate_u to return a known non-zero value to avoid flakiness
- with patch("apyhiveapi.api.hive_auth_async.calculate_u", return_value=99999):
- result = auth.get_password_authentication_key(
- "testuser", "testpass", server_b_value, salt
- )
-
- assert isinstance(result, (bytes, bytearray))
-
- async def test_u_value_zero_raises_value_error(self):
- """If calculate_u returns 0, ValueError is raised."""
- auth = await _make_auth()
- from apyhiveapi.api.srp_crypto import get_random
-
- server_b_value = hex(get_random(128))[2:]
- salt = hex(get_random(16))[2:]
-
- with patch("apyhiveapi.api.hive_auth_async.calculate_u", return_value=0):
- with pytest.raises(ValueError, match="U cannot be zero"):
- auth.get_password_authentication_key(
- "testuser", "testpass", server_b_value, salt
- )
-
- async def test_accepts_integer_server_b(self):
- """server_b_value can be passed as an integer (handled by _to_int)."""
- auth = await _make_auth()
- from apyhiveapi.api.srp_crypto import get_random
-
- server_b_int = get_random(128)
- salt = hex(get_random(16))[2:]
-
- with patch("apyhiveapi.api.hive_auth_async.calculate_u", return_value=12345):
- result = auth.get_password_authentication_key(
- "testuser", "testpass", server_b_int, salt
- )
-
- assert isinstance(result, (bytes, bytearray))
-
-
-# ---------------------------------------------------------------------------
-# Tests: process_challenge() — lines 203-254
-# ---------------------------------------------------------------------------
-
-
-class TestProcessChallenge:
- """Cover lines 205-254: process_challenge() builds the SRP response."""
-
- def _make_challenge_params(self, salt_as_int=False):
- """Return a minimal valid challenge_parameters dict."""
- import base64
-
- salt = "aabbccddee"
- if salt_as_int:
- salt = int("aabbccddee", 16)
- return {
- "USER_ID_FOR_SRP": "challenge-user@test.com",
- "SALT": salt,
- "SRP_B": "ff" * 32, # arbitrary hex
- "SECRET_BLOCK": base64.b64encode(b"secret-block-bytes").decode(),
- }
-
- async def test_returns_required_keys(self):
- """Basic challenge response includes mandatory SRP keys."""
- auth = await _make_auth()
-
- fake_hkdf = b"\x00" * 32
- auth.loop.run_in_executor = AsyncMock(return_value=fake_hkdf)
-
- params = self._make_challenge_params()
- result = await auth.process_challenge(params)
-
- assert "TIMESTAMP" in result
- assert "USERNAME" in result
- assert "PASSWORD_CLAIM_SECRET_BLOCK" in result
- assert "PASSWORD_CLAIM_SIGNATURE" in result
-
- async def test_sets_user_id_from_challenge(self):
- """process_challenge stores USER_ID_FOR_SRP as self.user_id."""
- auth = await _make_auth()
-
- fake_hkdf = b"\x00" * 32
- auth.loop.run_in_executor = AsyncMock(return_value=fake_hkdf)
-
- params = self._make_challenge_params()
- await auth.process_challenge(params)
-
- assert auth.user_id == "challenge-user@test.com"
-
- async def test_with_client_secret_adds_secret_hash(self):
- """When client_secret is set, SECRET_HASH is added to the response."""
- auth = await _make_auth(client_secret="my-secret")
-
- fake_hkdf = b"\x00" * 32
- auth.loop.run_in_executor = AsyncMock(return_value=fake_hkdf)
-
- params = self._make_challenge_params()
- result = await auth.process_challenge(params)
-
- assert "SECRET_HASH" in result
-
- async def test_without_client_secret_no_secret_hash(self):
- """When client_secret is None, SECRET_HASH is absent from the response."""
- auth = await _make_auth(client_secret=None)
-
- fake_hkdf = b"\x00" * 32
- auth.loop.run_in_executor = AsyncMock(return_value=fake_hkdf)
-
- params = self._make_challenge_params()
- result = await auth.process_challenge(params)
-
- assert "SECRET_HASH" not in result
-
- async def test_with_device_key_adds_device_key(self):
- """When device_key is set, DEVICE_KEY is added to the response."""
- auth = await _make_auth(device_key="dk-challenge")
-
- fake_hkdf = b"\x00" * 32
- auth.loop.run_in_executor = AsyncMock(return_value=fake_hkdf)
-
- params = self._make_challenge_params()
- result = await auth.process_challenge(params)
-
- assert result["DEVICE_KEY"] == "dk-challenge"
-
- async def test_without_device_key_no_device_key_in_response(self):
- """When device_key is None, DEVICE_KEY is absent from the response."""
- auth = await _make_auth(device_key=None)
-
- fake_hkdf = b"\x00" * 32
- auth.loop.run_in_executor = AsyncMock(return_value=fake_hkdf)
-
- params = self._make_challenge_params()
- result = await auth.process_challenge(params)
-
- assert "DEVICE_KEY" not in result
-
- async def test_salt_as_integer_triggers_pad_hex(self):
- """When SALT is an integer (not str), pad_hex is applied before use."""
- auth = await _make_auth()
-
- fake_hkdf = b"\x00" * 32
- auth.loop.run_in_executor = AsyncMock(return_value=fake_hkdf)
-
- params = self._make_challenge_params(salt_as_int=True)
- # Should not raise; int-type SALT is handled by the isinstance check
- result = await auth.process_challenge(params)
-
- assert "TIMESTAMP" in result
-
-
-# ---------------------------------------------------------------------------
-# Tests: login() — client is None triggers async_init (line 262-263)
-# ---------------------------------------------------------------------------
-
-
-class TestLoginClientNone:
- """Cover line 263: when client is None, async_init() is awaited."""
-
- async def test_login_calls_async_init_when_client_is_none(self):
- """If client is None before login, async_init is called before SRP flow."""
- auth = await _make_auth()
- auth.client = None # reset to trigger the branch
- auth.use_file = False # ensure we go through the client-None path
-
- auth_result = {"AuthenticationResult": {"AccessToken": "post-init-token"}}
- challenge_response = {
- "ChallengeName": "PASSWORD_VERIFIER",
- "ChallengeParameters": {
- "USER_ID_FOR_SRP": "user@test.com",
- "SALT": "aabbccdd",
- "SRP_B": "ccddee",
- "SECRET_BLOCK": "YWJj",
- },
- }
-
- async def fake_async_init():
- auth.client = MagicMock()
- auth._client_id = "test-client-id"
- auth._pool_id = "eu-west-1_TestPool"
- auth._region = "eu-west-1"
-
- with patch.object(auth, "async_init", side_effect=fake_async_init) as mock_init:
- with patch.object(
- auth, "process_challenge", new_callable=AsyncMock
- ) as mock_ch:
- mock_ch.return_value = {
- "TIMESTAMP": "Mon Jan 01 00:00:00 UTC 2024",
- "USERNAME": "user",
- }
- # After async_init, run_in_executor is called for initiate_auth then
- # respond_to_auth_challenge
- auth.loop.run_in_executor = AsyncMock(
- side_effect=[challenge_response, auth_result]
- )
- result = await auth.login()
-
- mock_init.assert_called_once()
- assert "AuthenticationResult" in result
-
-
-# ---------------------------------------------------------------------------
-# Tests: login() — unsupported challenge name (lines 335-337)
-# ---------------------------------------------------------------------------
-
-
-class TestLoginUnsupportedChallenge:
- """Cover lines 335-337: non-PASSWORD_VERIFIER challenge raises NotImplementedError."""
-
- async def test_new_password_required_raises_not_implemented(self):
- """NEW_PASSWORD_REQUIRED challenge is not supported and raises NotImplementedError."""
- auth = await _make_auth()
- auth.loop.run_in_executor = AsyncMock(
- return_value={
- "ChallengeName": "NEW_PASSWORD_REQUIRED",
- "ChallengeParameters": {},
- }
- )
- with pytest.raises(NotImplementedError, match="NEW_PASSWORD_REQUIRED"):
- await auth.login()
-
- async def test_custom_challenge_raises_not_implemented(self):
- """Any unknown challenge name raises NotImplementedError."""
- auth = await _make_auth()
- auth.loop.run_in_executor = AsyncMock(
- return_value={
- "ChallengeName": "UNKNOWN_CHALLENGE_TYPE",
- "ChallengeParameters": {},
- }
- )
- with pytest.raises(NotImplementedError, match="UNKNOWN_CHALLENGE_TYPE"):
- await auth.login()
-
-
-# ---------------------------------------------------------------------------
-# Tests: login() — respond_to_auth_challenge ResourceNotFoundException (line 307-311)
-# ---------------------------------------------------------------------------
-
-
-class TestLoginResourceNotFound:
- """Cover lines 307-311: ResourceNotFoundException in respond_to_auth_challenge."""
-
- async def test_resource_not_found_raises_invalid_device_authentication(self):
- """ResourceNotFoundException during challenge → HiveInvalidDeviceAuthentication."""
- auth = await _make_auth()
- challenge_response = {
- "ChallengeName": "PASSWORD_VERIFIER",
- "ChallengeParameters": {
- "USER_ID_FOR_SRP": "user@test.com",
- "SALT": "aabbccdd",
- "SRP_B": "ccddee",
- "SECRET_BLOCK": "YWJj",
- },
- }
- resource_err = _named_client_error("ResourceNotFoundException")
- with patch.object(auth, "process_challenge", new_callable=AsyncMock) as mock_ch:
- mock_ch.return_value = {
- "TIMESTAMP": "Mon Jan 01 00:00:00 UTC 2024",
- "USERNAME": "user",
- }
- auth.loop.run_in_executor = AsyncMock(
- side_effect=[challenge_response, resource_err]
- )
- with pytest.raises(HiveInvalidDeviceAuthentication):
- await auth.login()
-
-
-# ---------------------------------------------------------------------------
-# Tests: login() — EndpointConnectionError in respond_to_auth_challenge (lines 312-317)
-# ---------------------------------------------------------------------------
-
-
-class TestLoginEndpointErrorOnChallenge:
- """Cover lines 312-317: EndpointConnectionError during respond_to_auth_challenge."""
-
- async def test_endpoint_error_on_challenge_raises_api_error(self):
- """EndpointConnectionError during SRP challenge response → HiveApiError."""
- auth = await _make_auth()
- challenge_response = {
- "ChallengeName": "PASSWORD_VERIFIER",
- "ChallengeParameters": {
- "USER_ID_FOR_SRP": "user@test.com",
- "SALT": "aabbccdd",
- "SRP_B": "ccddee",
- "SECRET_BLOCK": "YWJj",
- },
- }
- with patch.object(auth, "process_challenge", new_callable=AsyncMock) as mock_ch:
- mock_ch.return_value = {
- "TIMESTAMP": "Mon Jan 01 00:00:00 UTC 2024",
- "USERNAME": "user",
- }
- auth.loop.run_in_executor = AsyncMock(
- side_effect=[challenge_response, _endpoint_error()]
- )
- with pytest.raises(HiveApiError):
- await auth.login()
-
-
-# ---------------------------------------------------------------------------
-# Tests: login() — result without AuthenticationResult (lines 321-333)
-# ---------------------------------------------------------------------------
-
-
-class TestLoginResultHandling:
- """Cover lines 321-333: AuthenticationResult presence/absence in login result."""
-
- async def test_result_without_authentication_result_does_not_store_token(self):
- """If result lacks 'AuthenticationResult', access_token is not set."""
- auth = await _make_auth()
- # First call → PASSWORD_VERIFIER challenge
- # Second call → result without AuthenticationResult (e.g., SMS_MFA)
- challenge_response = {
- "ChallengeName": "PASSWORD_VERIFIER",
- "ChallengeParameters": {
- "USER_ID_FOR_SRP": "user@test.com",
- "SALT": "aabbccdd",
- "SRP_B": "ccddee",
- "SECRET_BLOCK": "YWJj",
- },
- }
- sms_challenge_result = {
- "ChallengeName": "SMS_MFA",
- "Session": "session-tok",
- "ChallengeParameters": {},
- }
- with patch.object(auth, "process_challenge", new_callable=AsyncMock) as mock_ch:
- mock_ch.return_value = {
- "TIMESTAMP": "Mon Jan 01 00:00:00 UTC 2024",
- "USERNAME": "user",
- }
- auth.loop.run_in_executor = AsyncMock(
- side_effect=[challenge_response, sms_challenge_result]
- )
- result = await auth.login()
-
- # access_token was never set
- assert auth.access_token is None
- assert result is sms_challenge_result
-
- async def test_result_with_authentication_result_but_no_new_device_metadata(self):
- """AuthenticationResult without NewDeviceMetadata sets access_token only."""
- auth = await _make_auth()
- challenge_response = {
- "ChallengeName": "PASSWORD_VERIFIER",
- "ChallengeParameters": {
- "USER_ID_FOR_SRP": "user@test.com",
- "SALT": "aabbccdd",
- "SRP_B": "ccddee",
- "SECRET_BLOCK": "YWJj",
- },
- }
- auth_result = {"AuthenticationResult": {"AccessToken": "my-access-token"}}
- with patch.object(auth, "process_challenge", new_callable=AsyncMock) as mock_ch:
- mock_ch.return_value = {
- "TIMESTAMP": "Mon Jan 01 00:00:00 UTC 2024",
- "USERNAME": "user",
- }
- auth.loop.run_in_executor = AsyncMock(
- side_effect=[challenge_response, auth_result]
- )
- await auth.login()
-
- assert auth.access_token == "my-access-token"
- assert auth.device_group_key is None
- assert auth.device_key is None
-
-
-# ---------------------------------------------------------------------------
-# Tests: device_login() — client is None (line 347-348)
-# ---------------------------------------------------------------------------
-
-
-class TestDeviceLoginClientNone:
- """Cover device_login's async_init call when client is None."""
-
- async def test_device_login_calls_async_init_when_client_is_none(self):
- """If client is None, async_init is called before proceeding."""
- auth = await _make_auth(device_key="dk-1")
- auth.client = None
-
- async def fake_async_init():
- auth.client = MagicMock()
- auth._client_id = "test-client-id"
-
- with patch.object(auth, "async_init", side_effect=fake_async_init) as mock_init:
- auth.loop.run_in_executor = AsyncMock(
- side_effect=_named_client_error("ResourceNotFoundException")
- )
- with pytest.raises(HiveInvalidDeviceAuthentication):
- await auth.device_login()
-
- mock_init.assert_called_once()
-
-
-# ---------------------------------------------------------------------------
-# Tests: sms_2fa() — NewDeviceMetadata absent (line 441)
-# ---------------------------------------------------------------------------
-
-
-class TestSms2faNoNewDeviceMetadata:
- """Cover line 441: sms_2fa when NewDeviceMetadata is absent in result."""
-
- async def test_no_new_device_metadata_does_not_set_device_keys(self):
- """When NewDeviceMetadata is absent, device_group_key and device_key stay None."""
- auth = await _make_auth()
- original_group_key = auth.device_group_key # None
- original_device_key = auth.device_key # None
-
- sms_result = {
- "AuthenticationResult": {
- "AccessToken": "sms-access-token",
- # No "NewDeviceMetadata" key
- }
- }
- auth.loop.run_in_executor = AsyncMock(return_value=sms_result)
-
- result = await auth.sms_2fa("654321", {"Session": "sess-abc"})
-
- assert auth.access_token == "sms-access-token"
- assert auth.device_group_key == original_group_key # unchanged
- assert auth.device_key == original_device_key # unchanged
- assert result is sms_result
-
-
-# ---------------------------------------------------------------------------
-# Tests: sms_2fa() — CodeMismatchException path (lines 424-429)
-# ---------------------------------------------------------------------------
-
-
-class TestSms2faCodeMismatch:
- """Cover lines 424-429: CodeMismatchException raises HiveInvalid2FACode."""
-
- async def test_code_mismatch_raises_invalid_2fa_code(self):
- """CodeMismatchException in sms_2fa raises HiveInvalid2FACode."""
- auth = await _make_auth()
- auth.loop.run_in_executor = AsyncMock(
- side_effect=_named_client_error("CodeMismatchException")
- )
- with pytest.raises(HiveInvalid2FACode):
- await auth.sms_2fa("000000", {"Session": "sess-1"})
-
- async def test_not_authorized_raises_invalid_2fa_code(self):
- """NotAuthorizedException in sms_2fa raises HiveInvalid2FACode."""
- auth = await _make_auth()
- auth.loop.run_in_executor = AsyncMock(
- side_effect=_named_client_error("NotAuthorizedException")
- )
- with pytest.raises(HiveInvalid2FACode):
- await auth.sms_2fa("111111", {"Session": "sess-2"})
-
-
-# ---------------------------------------------------------------------------
-# Tests: refresh_token() — client is None (line 440-441)
-# ---------------------------------------------------------------------------
-
-
-class TestRefreshTokenClientNone:
- """Cover line 440-441: refresh_token calls async_init when client is None."""
-
- async def test_refresh_token_calls_async_init_when_client_is_none(self):
- """If client is None, async_init is awaited before refreshing."""
- auth = await _make_auth()
- auth.client = None
-
- result_payload = {"AuthenticationResult": {"AccessToken": "refreshed-tok"}}
-
- async def fake_async_init():
- auth.client = MagicMock()
-
- with patch.object(auth, "async_init", side_effect=fake_async_init) as mock_init:
- auth.loop.run_in_executor = AsyncMock(return_value=result_payload)
- result = await auth.refresh_token("some-refresh-token")
-
- mock_init.assert_called_once()
- assert result is result_payload
-
-
-# ---------------------------------------------------------------------------
-# Tests: refresh_token() — result path when no AuthenticationResult (lines 479-485)
-# ---------------------------------------------------------------------------
-
-
-class TestRefreshTokenResultPath:
- """Cover lines 479-485: refresh_token when result is returned normally."""
-
- async def test_returns_result_directly(self):
- """refresh_token returns the result from Cognito directly."""
- auth = await _make_auth()
- result_payload = {"AuthenticationResult": {"AccessToken": "tok-xyz"}}
- auth.loop.run_in_executor = AsyncMock(return_value=result_payload)
-
- result = await auth.refresh_token("refresh-tok-abc")
-
- assert result is result_payload
-
- async def test_with_device_key_includes_device_key_param(self):
- """When device_key is set, DEVICE_KEY is included in auth_params."""
- auth = await _make_auth(device_key="dk-refresh-001")
- result_payload = {"AuthenticationResult": {"AccessToken": "tok-dk"}}
- auth.loop.run_in_executor = AsyncMock(return_value=result_payload)
-
- result = await auth.refresh_token("refresh-tok-dk")
-
- assert result is result_payload
-
- async def test_without_device_key_sends_only_refresh_token(self):
- """When device_key is None, auth_params has only REFRESH_TOKEN."""
- auth = await _make_auth(device_key=None)
- result_payload = {"AuthenticationResult": {"AccessToken": "tok-no-dk"}}
- auth.loop.run_in_executor = AsyncMock(return_value=result_payload)
-
- result = await auth.refresh_token("refresh-tok-no-dk")
-
- assert result is result_payload
-
-
-# ---------------------------------------------------------------------------
-# Tests: login() — swallowed ClientError in initiate_auth (line 280->288)
-# ---------------------------------------------------------------------------
-
-
-class TestLoginInitiateAuthSwallowedClientError:
- """Arc 280->288: ClientError caught but class name is not UserNotFoundException."""
-
- async def test_other_client_error_in_initiate_auth_falls_through(self):
- """Non-UserNotFoundException ClientError is swallowed; response stays None → TypeError."""
- auth = await _make_auth()
-
- wrong_cls = type("SomeOtherError", (botocore.exceptions.ClientError,), {})
- wrong_err = wrong_cls(
- {"Error": {"Code": "SomeOtherError", "Message": "msg"}}, "op"
- )
- auth.loop.run_in_executor = AsyncMock(side_effect=wrong_err)
-
- # Exception is swallowed; line 288 `response["ChallengeName"]` raises TypeError
- # because response is None
- with pytest.raises((TypeError, KeyError)):
- await auth.login()
-
-
-# ---------------------------------------------------------------------------
-# Tests: login() — swallowed EndpointConnectionError in initiate_auth (line 284->288)
-# ---------------------------------------------------------------------------
-
-
-class TestLoginInitiateAuthSwallowedEndpointError:
- """Arc 284->288: EndpointConnectionError caught but class name is wrong."""
-
- async def test_wrong_name_endpoint_error_in_initiate_auth_falls_through(self):
- """EndpointConnectionError with wrong name is swallowed; response stays None."""
- auth = await _make_auth()
-
- wrong_cls = type(
- "WrongEndpoint", (botocore.exceptions.EndpointConnectionError,), {}
- )
- wrong_err = wrong_cls(endpoint_url="https://cognito.eu-west-1.amazonaws.com")
- auth.loop.run_in_executor = AsyncMock(side_effect=wrong_err)
-
- with pytest.raises((TypeError, KeyError)):
- await auth.login()
-
-
-# ---------------------------------------------------------------------------
-# Tests: login() — swallowed ClientError in respond_to_auth_challenge (307->319)
-# ---------------------------------------------------------------------------
-
-
-class TestLoginChallengeSwallowedClientError:
- """Arc 307->319: ClientError caught in challenge response with name not matching."""
-
- async def test_other_client_error_in_challenge_falls_through(self):
- """ClientError that is neither NotAuthorized nor ResourceNotFound is swallowed."""
- auth = await _make_auth()
-
- challenge_response = {
- "ChallengeName": "PASSWORD_VERIFIER",
- "ChallengeParameters": {
- "USER_ID_FOR_SRP": "user@test.com",
- "SALT": "aabbccdd",
- "SRP_B": "ccddee",
- "SECRET_BLOCK": "YWJj",
- },
- }
-
- wrong_cls = type("ThirdPartyError", (botocore.exceptions.ClientError,), {})
- wrong_err = wrong_cls(
- {"Error": {"Code": "ThirdPartyError", "Message": "msg"}}, "op"
- )
-
- with patch.object(auth, "process_challenge", new_callable=AsyncMock) as mock_ch:
- mock_ch.return_value = {"TIMESTAMP": "...", "USERNAME": "user"}
- auth.loop.run_in_executor = AsyncMock(
- side_effect=[challenge_response, wrong_err]
- )
- # Exception is swallowed; result stays None → TypeError on line 321
- with pytest.raises((TypeError, AttributeError)):
- await auth.login()
-
-
-# ---------------------------------------------------------------------------
-# Tests: login() — swallowed EndpointConnectionError in challenge (313->319)
-# ---------------------------------------------------------------------------
-
-
-class TestLoginChallengeSwallowedEndpointError:
- """Arc 313->319: EndpointConnectionError caught with wrong class name in challenge."""
-
- async def test_wrong_name_endpoint_error_in_challenge_falls_through(self):
- """EndpointConnectionError with wrong name is swallowed; result stays None."""
- auth = await _make_auth()
-
- challenge_response = {
- "ChallengeName": "PASSWORD_VERIFIER",
- "ChallengeParameters": {
- "USER_ID_FOR_SRP": "user@test.com",
- "SALT": "aabbccdd",
- "SRP_B": "ccddee",
- "SECRET_BLOCK": "YWJj",
- },
- }
-
- wrong_cls = type(
- "WrongEndpoint", (botocore.exceptions.EndpointConnectionError,), {}
- )
- wrong_err = wrong_cls(endpoint_url="https://cognito.eu-west-1.amazonaws.com")
-
- with patch.object(auth, "process_challenge", new_callable=AsyncMock) as mock_ch:
- mock_ch.return_value = {"TIMESTAMP": "...", "USERNAME": "user"}
- auth.loop.run_in_executor = AsyncMock(
- side_effect=[challenge_response, wrong_err]
- )
- with pytest.raises((TypeError, AttributeError)):
- await auth.login()
-
-
-# ---------------------------------------------------------------------------
-# Tests: device_login() — success path through process_device_challenge (lines 364-367, 391)
-# ---------------------------------------------------------------------------
-
-
-class TestDeviceLoginSuccessPath:
- """Lines 364-367, 391: device_login processes device challenge and returns result."""
-
- async def test_successful_device_login_returns_auth_result(self):
- """Full device_login success: process_device_challenge called, result returned."""
- auth = await _make_auth(device_key="dk-abc", device_group_key="grp-abc")
- auth.device_password = "dev-pass"
-
- initial_result = {
- "ChallengeParameters": {
- "USERNAME": "user@test.com",
- "SALT": "aabbccdd",
- "SRP_B": "ccddee",
- "SECRET_BLOCK": "YWJj",
- }
- }
- final_result = {"AuthenticationResult": {"AccessToken": "device-access-token"}}
-
- with patch.object(
- auth, "process_device_challenge", new_callable=AsyncMock
- ) as mock_pdc:
- mock_pdc.return_value = {
- "TIMESTAMP": "Mon Jan 01 00:00:00 UTC 2024",
- "USERNAME": "user@test.com",
- "PASSWORD_CLAIM_SECRET_BLOCK": "YWJj",
- "PASSWORD_CLAIM_SIGNATURE": "sig",
- "DEVICE_KEY": "dk-abc",
- }
- auth.loop.run_in_executor = AsyncMock(
- side_effect=[initial_result, final_result]
- )
- result = await auth.device_login()
-
- mock_pdc.assert_called_once_with(initial_result["ChallengeParameters"])
- assert result is final_result
-
- async def test_device_login_calls_second_respond_to_auth_challenge(self):
- """Lines 367-375: second respond_to_auth_challenge is called with device challenge."""
- auth = await _make_auth(device_key="dk-xyz", device_group_key="grp-xyz")
- auth.device_password = "dev-pass-xyz"
-
- initial_result = {
- "ChallengeParameters": {
- "USERNAME": "user@test.com",
- "SALT": "11223344",
- "SRP_B": "55667788",
- "SECRET_BLOCK": "dGVzdA==",
- }
- }
- final_result = {"AuthenticationResult": {"AccessToken": "tok-xyz"}}
-
- challenge_resp = {
- "TIMESTAMP": "Mon Jan 01 00:00:00 UTC 2024",
- "USERNAME": "user@test.com",
- "PASSWORD_CLAIM_SECRET_BLOCK": "dGVzdA==",
- "PASSWORD_CLAIM_SIGNATURE": "sig",
- "DEVICE_KEY": "dk-xyz",
- }
-
- with patch.object(
- auth, "process_device_challenge", new_callable=AsyncMock
- ) as mock_pdc:
- mock_pdc.return_value = challenge_resp
- auth.loop.run_in_executor = AsyncMock(
- side_effect=[initial_result, final_result]
- )
- result = await auth.device_login()
-
- assert auth.loop.run_in_executor.call_count == 2
- assert result["AuthenticationResult"]["AccessToken"] == "tok-xyz"
-
-
-# ---------------------------------------------------------------------------
-# Tests: device_login() — wrong-name EndpointConnectionError (line 389)
-# ---------------------------------------------------------------------------
-
-
-class TestDeviceLoginEndpointWrongName:
- """Line 389: EndpointConnectionError with wrong __class__.__name__ raises
- HiveInvalidDeviceAuthentication instead of HiveApiError."""
-
- async def test_wrong_name_endpoint_error_raises_invalid_device_auth(self):
- """A subclass of EndpointConnectionError with a different name hits line 389."""
- auth = await _make_auth(device_key="dk-err", device_group_key="grp-err")
- auth.device_password = "dev-pass-err"
-
- wrong_cls = type(
- "WrongEndpoint", (botocore.exceptions.EndpointConnectionError,), {}
- )
- wrong_err = wrong_cls(endpoint_url="https://cognito.eu-west-1.amazonaws.com")
- auth.loop.run_in_executor = AsyncMock(side_effect=wrong_err)
-
- with pytest.raises(HiveInvalidDeviceAuthentication):
- await auth.device_login()
-
-
-# ---------------------------------------------------------------------------
-# Tests: sms_2fa() — swallowed ClientError (arc 424->435)
-# ---------------------------------------------------------------------------
-
-
-class TestSms2faSwallowedClientError:
- """Arc 424->435: ClientError caught in sms_2fa with unrecognised class name."""
-
- async def test_other_client_error_is_swallowed_returns_none(self):
- """Non-matching ClientError is swallowed; result stays None (returned)."""
- auth = await _make_auth()
-
- wrong_cls = type("OtherError", (botocore.exceptions.ClientError,), {})
- wrong_err = wrong_cls({"Error": {"Code": "OtherError", "Message": "msg"}}, "op")
- auth.loop.run_in_executor = AsyncMock(side_effect=wrong_err)
-
- result = await auth.sms_2fa("123456", {"Session": "sess-xyz"})
- assert (
- result is None
- ) # sms_2fa initialises result=None; swallowed → returns None
-
-
-# ---------------------------------------------------------------------------
-# Tests: sms_2fa() — swallowed EndpointConnectionError (arc 431->435)
-# ---------------------------------------------------------------------------
-
-
-class TestSms2faSwallowedEndpointError:
- """Arc 431->435: EndpointConnectionError caught with wrong class name in sms_2fa."""
-
- async def test_wrong_name_endpoint_error_is_swallowed(self):
- """EndpointConnectionError subclass with wrong name is swallowed; returns None."""
- auth = await _make_auth()
-
- wrong_cls = type(
- "WrongEndpoint", (botocore.exceptions.EndpointConnectionError,), {}
- )
- wrong_err = wrong_cls(endpoint_url="https://cognito.eu-west-1.amazonaws.com")
- auth.loop.run_in_executor = AsyncMock(side_effect=wrong_err)
-
- result = await auth.sms_2fa("654321", {"Session": "sess-abc"})
- assert result is None
-
-
-# ---------------------------------------------------------------------------
-# Tests: refresh_token() — swallowed EndpointConnectionError (arc 479->485)
-# ---------------------------------------------------------------------------
-
-
-class TestRefreshTokenSwallowedEndpointError:
- """Arc 479->485: EndpointConnectionError caught with wrong class name in refresh_token."""
-
- async def test_wrong_name_endpoint_error_is_swallowed_returns_none(self):
- """EndpointConnectionError subclass with wrong name is swallowed; result=None returned."""
- auth = await _make_auth()
-
- wrong_cls = type(
- "WrongEndpoint", (botocore.exceptions.EndpointConnectionError,), {}
- )
- wrong_err = wrong_cls(endpoint_url="https://cognito.eu-west-1.amazonaws.com")
- auth.loop.run_in_executor = AsyncMock(side_effect=wrong_err)
-
- # result initialised to None; exception swallowed; line 485 reached; returns None
- result = await auth.refresh_token("some-refresh-token")
- assert result is None
diff --git a/tests/unit/test_hive_exceptions.py b/tests/unit/test_hive_exceptions.py
new file mode 100644
index 0000000..33ffb5c
--- /dev/null
+++ b/tests/unit/test_hive_exceptions.py
@@ -0,0 +1,93 @@
+"""Unit tests for the hive_exceptions hierarchy."""
+
+import pytest
+from apyhiveapi.helper.hive_exceptions import (
+ FileInUse,
+ HiveApiError,
+ HiveAuthCredentialError,
+ HiveAuthError,
+ HiveConfigurationError,
+ HiveError,
+ HiveFailedToRefreshTokens,
+ HiveInvalid2FACode,
+ HiveInvalidDeviceAuthentication,
+ HiveInvalidPassword,
+ HiveInvalidUsername,
+ HiveReauthRequired,
+ HiveRefreshTokenExpired,
+ HiveUnknownConfiguration,
+ NoApiToken,
+)
+
+
+class TestHiveErrorBase:
+ def test_hive_api_error_is_hive_error(self):
+ assert issubclass(HiveApiError, HiveError)
+
+ def test_hive_auth_error_is_hive_api_error(self):
+ assert issubclass(HiveAuthError, HiveApiError)
+
+ def test_hive_auth_error_is_hive_error(self):
+ assert issubclass(HiveAuthError, HiveError)
+
+ def test_hive_refresh_token_expired_is_hive_api_error(self):
+ assert issubclass(HiveRefreshTokenExpired, HiveApiError)
+
+ def test_hive_failed_to_refresh_is_hive_api_error(self):
+ assert issubclass(HiveFailedToRefreshTokens, HiveApiError)
+
+ def test_hive_reauth_required_is_hive_error(self):
+ assert issubclass(HiveReauthRequired, HiveError)
+
+
+class TestCredentialErrors:
+ def test_invalid_username_is_hive_auth_credential_error(self):
+ assert issubclass(HiveInvalidUsername, HiveAuthCredentialError)
+
+ def test_invalid_password_is_hive_auth_credential_error(self):
+ assert issubclass(HiveInvalidPassword, HiveAuthCredentialError)
+
+ def test_invalid_2fa_is_hive_auth_credential_error(self):
+ assert issubclass(HiveInvalid2FACode, HiveAuthCredentialError)
+
+ def test_auth_credential_error_is_hive_error(self):
+ assert issubclass(HiveAuthCredentialError, HiveError)
+
+
+class TestConfigurationErrors:
+ def test_unknown_config_is_hive_configuration_error(self):
+ assert issubclass(HiveUnknownConfiguration, HiveConfigurationError)
+
+ def test_invalid_device_auth_is_hive_configuration_error(self):
+ assert issubclass(HiveInvalidDeviceAuthentication, HiveConfigurationError)
+
+ def test_configuration_error_is_hive_error(self):
+ assert issubclass(HiveConfigurationError, HiveError)
+
+
+class TestStandaloneExceptions:
+ def test_file_in_use_is_not_hive_error(self):
+ assert not issubclass(FileInUse, HiveError)
+
+ def test_no_api_token_is_not_hive_error(self):
+ assert not issubclass(NoApiToken, HiveError)
+
+ def test_file_in_use_is_exception(self):
+ assert issubclass(FileInUse, Exception)
+
+ def test_no_api_token_is_exception(self):
+ assert issubclass(NoApiToken, Exception)
+
+
+class TestInstantiable:
+ def test_hive_error_is_raiseable(self):
+ with pytest.raises(HiveError):
+ raise HiveError("test")
+
+ def test_hive_api_error_caught_as_hive_error(self):
+ with pytest.raises(HiveError):
+ raise HiveApiError("test")
+
+ def test_invalid_username_caught_as_hive_error(self):
+ with pytest.raises(HiveError):
+ raise HiveInvalidUsername("test")
diff --git a/tests/unit/test_hive_helper_extended.py b/tests/unit/test_hive_helper_extended.py
deleted file mode 100644
index 8c13ec3..0000000
--- a/tests/unit/test_hive_helper_extended.py
+++ /dev/null
@@ -1,143 +0,0 @@
-"""Tests for HiveHelper covering previously uncovered lines/branches."""
-
-# pylint: disable=protected-access
-
-from unittest.mock import MagicMock
-
-from apyhiveapi.helper.hive_helper import HiveHelper
-from apyhiveapi.helper.map import Map
-
-
-def _make_helper(entity_cache=None, products=None):
- """Build a HiveHelper with a minimally mocked session."""
- session = MagicMock()
- session.entity_cache = entity_cache if entity_cache is not None else {}
- session.data = Map(
- {
- "products": products or {},
- "devices": {},
- "actions": {},
- "user": {},
- "minMax": {},
- }
- )
- return HiveHelper(session)
-
-
-# ---------------------------------------------------------------------------
-# get_device_from_id — branch 133->122 (no match, loop continues then exits)
-# ---------------------------------------------------------------------------
-
-
-class TestGetDeviceFromIdBranch:
- """Covers the branch where no cache entry matches the requested ID."""
-
- def test_returns_false_when_no_match_in_cache(self):
- """When entity_cache has entries but none match n_id, returns False.
-
- This exercises the branch where the 'if n_id in (hive_id, device_id)'
- condition is False for every item (133->122 loop-continue then exit).
- """
- from apyhiveapi.helper.hivedataclasses import Device
-
- other_device = Device(
- hive_id="other-hive-id",
- hive_name="Other",
- hive_type="heating",
- ha_type="climate",
- device_id="other-device-id",
- device_name="Other",
- device_data={},
- )
- helper = _make_helper(entity_cache={"other-key": other_device})
- result = helper.get_device_from_id("nonexistent-id")
- assert result is False
-
- def test_returns_false_when_cache_is_empty(self):
- """When entity_cache is empty, returns False without entering the loop."""
- helper = _make_helper(entity_cache={})
- assert helper.get_device_from_id("any-id") is False
-
-
-# ---------------------------------------------------------------------------
-# get_heat_on_demand_device — lines 315-317
-# ---------------------------------------------------------------------------
-
-
-class TestGetHeatOnDemandDevice:
- """Covers HiveHelper.get_heat_on_demand_device (lines 315-317)."""
-
- def test_returns_linked_thermostat(self):
- """Looks up TRV by HiveID, then fetches linked thermostat by zone."""
- trv_id = "trv-001"
- thermostat_id = "zone-001"
-
- trv_data = {"state": {"zone": thermostat_id}, "type": "trvcontrol"}
- thermostat_data = {"id": thermostat_id, "type": "heating"}
-
- products = {
- trv_id: trv_data,
- thermostat_id: thermostat_data,
- }
- helper = _make_helper(products=products)
-
- # Device accessed with dict-style key "HiveID" as used inside the method
- device = MagicMock()
- device.__getitem__ = MagicMock(
- side_effect=lambda k: trv_id if k == "HiveID" else None
- )
-
- result = helper.get_heat_on_demand_device(device)
- assert result == thermostat_data
-
-
-# ---------------------------------------------------------------------------
-# sanitize_payload — list masking (line 329) and non-str/dict/list fallthrough
-# ---------------------------------------------------------------------------
-
-
-class TestSanitizePayload:
- """Covers _mask branches for list values and non-string scalar fallthrough."""
-
- def test_list_value_under_sensitive_key_is_masked(self):
- """A list value under a sensitive key has each element masked."""
- helper = _make_helper()
- payload = {"token": ["short", "averylongtoken123"]}
- result = helper.sanitize_payload(payload)
- # "short" (<=8 chars) → "***", "averylongtoken123" (>8 chars) → "aver...n123"
- assert result["token"] == ["***", "aver...n123"]
-
- def test_non_string_non_dict_non_list_under_sensitive_key_passes_through(self):
- """An int/bool/None value under a sensitive key is returned as-is."""
- helper = _make_helper()
- payload = {"token": 42}
- result = helper.sanitize_payload(payload)
- assert result["token"] == 42
-
- def test_none_under_sensitive_key_passes_through(self):
- """None under a sensitive key is returned unchanged."""
- helper = _make_helper()
- payload = {"token": None}
- result = helper.sanitize_payload(payload)
- assert result["token"] is None
-
- def test_bool_under_sensitive_key_passes_through(self):
- """A bool under a sensitive key is returned unchanged (not a str/dict/list)."""
- helper = _make_helper()
- payload = {"token": True}
- result = helper.sanitize_payload(payload)
- assert result["token"] is True
-
- def test_short_string_masked_as_stars(self):
- """A string of 8 characters or fewer is masked as '***'."""
- helper = _make_helper()
- payload = {"password": "abc12345"} # exactly 8 chars
- result = helper.sanitize_payload(payload)
- assert result["password"] == "***"
-
- def test_long_string_partially_masked(self):
- """A string longer than 8 characters is partially masked."""
- helper = _make_helper()
- payload = {"password": "supersecretpassword"}
- result = helper.sanitize_payload(payload)
- assert result["password"] == "supe...word"
diff --git a/tests/unit/test_hotwater.py b/tests/unit/test_hotwater.py
new file mode 100644
index 0000000..5da4e33
--- /dev/null
+++ b/tests/unit/test_hotwater.py
@@ -0,0 +1,328 @@
+"""Extended branch-coverage tests for WaterHeater / HiveHotwater."""
+
+# pylint: disable=too-few-public-methods
+from unittest.mock import AsyncMock, MagicMock
+
+from apyhiveapi.devices.hotwater import WaterHeater
+from apyhiveapi.helper.hivedataclasses import Device, SessionConfig
+from apyhiveapi.helper.map import Map
+
+_SCHEDULE_MODE = "SCHEDULE"
+_ON_MODE = "ON"
+_OFF_MODE = "OFF"
+_BOOST_MODE = "BOOST"
+
+
+def _make_hotwater(products=None, devices=None):
+ session = MagicMock()
+ session.data = Map(
+ {
+ "products": products or {},
+ "devices": devices or {},
+ "actions": {},
+ "minMax": {},
+ "user": {},
+ }
+ )
+ session.config = SessionConfig()
+ session.helper = MagicMock()
+ session.helper.device_recovered = MagicMock()
+ session.helper.error_check = AsyncMock()
+ session.helper.get_schedule_nnl = MagicMock(
+ return_value={"now": {"value": {"status": _ON_MODE}}, "next": {}, "later": {}}
+ )
+ session.attr = MagicMock()
+ session.attr.online_offline = AsyncMock(return_value=True)
+ session.attr.state_attributes = AsyncMock(return_value={})
+ session.api = MagicMock()
+ session.api.set_state = AsyncMock(return_value={"original": 200, "parsed": {}})
+ session.hive_refresh_tokens = AsyncMock()
+ session.get_devices = AsyncMock(return_value=True)
+ session.should_use_cached_data = MagicMock(return_value=False)
+ session.get_cached_device = MagicMock(return_value=None)
+ session.set_cached_device = MagicMock(side_effect=lambda d: d)
+ return WaterHeater(session=session)
+
+
+def _make_device(hive_id="hw-1", device_id="dev-1"):
+ return Device(
+ hive_id=hive_id,
+ hive_name="Hot Water",
+ hive_type="hotwater",
+ ha_type="water_heater",
+ device_id=device_id,
+ device_name="Hot Water",
+ device_data={"online": True},
+ ha_name="Hot Water",
+ )
+
+
+class TestGetMode:
+ async def test_boost_mode_reads_previous(self):
+ """BOOST state resolves to the previous mode stored in props."""
+ hw = _make_hotwater(
+ {
+ "hw-1": {
+ "state": {"mode": _BOOST_MODE},
+ "props": {"previous": {"mode": _ON_MODE}},
+ }
+ }
+ )
+ result = await hw.get_mode(_make_device())
+ assert result == _ON_MODE
+
+
+class TestGetState:
+ async def test_schedule_mode_boost_off_reads_schedule(self):
+ """SCHEDULE mode with boost OFF reads state from schedule nnl."""
+ hw = _make_hotwater(
+ {
+ "hw-1": {
+ "state": {
+ "mode": _SCHEDULE_MODE,
+ "status": _OFF_MODE,
+ "boost": False,
+ "schedule": {},
+ }
+ }
+ }
+ )
+ result = await hw.get_state(_make_device())
+ assert result is not None
+
+ async def test_non_schedule_state_mapped(self):
+ """Direct ON mode/status maps through HIVETOHA without schedule lookup."""
+ hw = _make_hotwater(
+ {
+ "hw-1": {
+ "state": {
+ "mode": _ON_MODE,
+ "status": _ON_MODE,
+ "schedule": {},
+ }
+ }
+ }
+ )
+ result = await hw.get_state(_make_device())
+ assert result is not None
+
+
+class TestGetWaterHeater:
+ async def test_cache_hit_returns_cached(self):
+ """Cached device is returned immediately when poll is slow/busy."""
+ hw = _make_hotwater()
+ hw.session.should_use_cached_data = MagicMock(return_value=True)
+ cached_device = _make_device()
+ cached_device.status = {"current_operation": _SCHEDULE_MODE}
+ hw.session.get_cached_device = MagicMock(return_value=cached_device)
+ d = _make_device()
+ result = await hw.get_water_heater(d)
+ assert result is cached_device
+ hw.session.attr.online_offline.assert_not_called()
+
+ async def test_device_data_not_dict_gets_reset(self):
+ """Non-dict device_data is replaced with an empty dict before use."""
+ hw = _make_hotwater(
+ products={"hw-1": {"state": {"mode": _SCHEDULE_MODE}}},
+ devices={"dev-1": {"props": {}, "parent": None}},
+ )
+ d = _make_device()
+ d.device_data = None
+ await hw.get_water_heater(d)
+ assert isinstance(d.device_data, dict)
+
+ async def test_offline_device_calls_error_check(self):
+ """Offline device triggers error_check and status defaults to None."""
+ hw = _make_hotwater(
+ products={"hw-1": {}},
+ devices={"dev-1": {}},
+ )
+ hw.session.attr.online_offline = AsyncMock(return_value=False)
+ d = _make_device()
+ result = await hw.get_water_heater(d)
+ hw.session.helper.error_check.assert_called_once()
+ assert result.status["current_operation"] is None
+
+
+class TestGetScheduleNowNextLater:
+ async def test_schedule_mode_returns_nnl(self):
+ """SCHEDULE mode with schedule data returns now/next/later dict."""
+ hw = _make_hotwater(
+ {"hw-1": {"state": {"mode": _SCHEDULE_MODE, "schedule": {"data": []}}}}
+ )
+ result = await hw.get_schedule_now_next_later(_make_device())
+ assert result is not None
+ assert "now" in result
+
+ async def test_non_schedule_mode_returns_none(self):
+ """Non-SCHEDULE mode returns None."""
+ hw = _make_hotwater({"hw-1": {"state": {"mode": _ON_MODE}}})
+ result = await hw.get_schedule_now_next_later(_make_device())
+ assert result is None
+
+
+# ---------------------------------------------------------------------------
+# get_mode — BOOST path with missing props.previous must not log error
+# ---------------------------------------------------------------------------
+
+
+class TestHotwaterGetModeBoostMissingPrevious:
+ """get_mode BOOST path must use safe access, not bare dict that logs a spurious error."""
+
+ async def test_boost_missing_previous_returns_off_without_error_log( # noqa: E501
+ self,
+ ):
+ """When mode=BOOST and props has no previous, get_mode returns 'OFF' without error log.
+
+ The hotwater HIVETOHA mapping maps None→'OFF', so missing previous
+ resolves safely instead of throwing a swallowed KeyError.
+ """
+ from unittest.mock import patch
+
+ hw = _make_hotwater({"hw-1": {"state": {"mode": "BOOST"}, "props": {}}})
+ d = _make_device()
+ with patch("apyhiveapi.devices.hotwater._LOGGER") as mock_log:
+ result = await hw.get_mode(d)
+ assert result == "OFF"
+ mock_log.error.assert_not_called()
+
+
+# ---------------------------------------------------------------------------
+# set_boost_off — must return False when prev_mode is None (not send mode=None)
+# ---------------------------------------------------------------------------
+
+
+class TestHotwaterSetBoostOffNullPrevMode:
+ """HiveHotwater.set_boost_off returns False when prev_mode is None."""
+
+ async def test_set_boost_off_returns_false_when_prev_mode_missing(self):
+ """set_boost_off returns False and does not call API when prev mode is absent."""
+ from apyhiveapi.devices.hotwater import HiveHotwater
+
+ class StubHotwater(HiveHotwater):
+ """Concrete stub for testing."""
+
+ h = StubHotwater()
+ h.session = MagicMock()
+ h.session.data.products = {"h1": {"state": {"mode": "BOOST"}, "props": {}}}
+ h._execute_state_change = AsyncMock(return_value=True)
+ h.get_boost_status = AsyncMock(return_value="ON")
+
+ d = Device(
+ hive_id="h1",
+ hive_name="T",
+ hive_type="hotwater",
+ ha_type="water_heater",
+ device_id="d1",
+ device_name="T",
+ device_data={"online": True},
+ ha_name="Hotwater",
+ )
+ result = await h.set_boost_off(d)
+ assert result is False
+ h._execute_state_change.assert_not_called()
+
+
+class TestHotwaterGetStateScheduleGuard:
+ """get_state handles empty schedule NNL without KeyError."""
+
+ async def test_get_state_schedule_guard_empty_nnl(self):
+ """get_state does not crash when get_schedule_nnl returns empty dict."""
+ from apyhiveapi.devices.hotwater import HiveHotwater
+
+ class StubHotwater(HiveHotwater):
+ pass
+
+ h = StubHotwater()
+ h.session = MagicMock()
+ h.session.data.products = {
+ "h1": {
+ "state": {"status": "ON", "mode": "SCHEDULE", "schedule": {}},
+ "props": {},
+ }
+ }
+ h.session.helper.get_schedule_nnl = MagicMock(return_value={})
+ h.get_mode = AsyncMock(return_value="SCHEDULE")
+ h.get_boost_status = AsyncMock(return_value="OFF")
+
+ d = Device(
+ hive_id="h1",
+ hive_name="T",
+ hive_type="hotwater",
+ ha_type="water_heater",
+ device_id="d1",
+ device_name="T",
+ device_data={"online": True},
+ ha_name="Hotwater",
+ )
+ result = await h.get_state(d)
+ assert result is None or isinstance(result, str)
+
+
+# ===========================================================================
+# Migrated from test_remaining_branches.py
+# ===========================================================================
+
+
+class TestHotwaterGetModeKeyError:
+ """Lines 43-44: KeyError in get_mode."""
+
+ async def test_get_mode_missing_state_returns_none(self):
+ """Product with no 'state' key causes KeyError → final stays None."""
+ hw = _make_hotwater({"hw-1": {}})
+ result = await hw.get_mode(_make_device())
+ assert result is None
+
+
+class TestHotwaterGetStateKeyError:
+ """Lines 83-84: KeyError in get_state."""
+
+ async def test_get_state_missing_status_key_returns_none(self):
+ """Product 'state' dict missing 'status' key triggers KeyError → None."""
+ hw = _make_hotwater({"hw-1": {"state": {"mode": "MANUAL"}}})
+ # 'status' key is absent from state → KeyError on data["state"]["status"]
+ result = await hw.get_state(_make_device())
+ assert result is None
+
+ async def test_get_state_missing_schedule_in_schedule_mode_returns_none(self):
+ """SCHEDULE mode with no 'schedule' key in state causes KeyError → None."""
+ hw = _make_hotwater(
+ {
+ "hw-1": {
+ "state": {
+ "mode": "SCHEDULE",
+ "status": "ON",
+ "boost": False,
+ # no 'schedule' key
+ }
+ }
+ }
+ )
+ result = await hw.get_state(_make_device())
+ assert result is None
+
+
+class TestHotwaterScheduleNNLNone:
+ """Lines 225->227: get_schedule_now_next_later returns None when schedule is absent."""
+
+ async def test_schedule_none_when_no_schedule_in_state(self):
+ """SCHEDULE mode product without 'schedule' key → _get_product_state returns None → None."""
+ hw = _make_hotwater({"hw-1": {"state": {"mode": "SCHEDULE"}}})
+ # _get_product_state(device, "state", "schedule") → None (key absent)
+ result = await hw.get_schedule_now_next_later(_make_device())
+ assert result is None
+
+
+class TestHotwaterGetWaterHeaterCacheMiss:
+ """Lines 173->180: cache enabled but cached is None → continues with network call."""
+
+ async def test_cached_none_falls_through(self):
+ hw = _make_hotwater(
+ {"hw-1": {"state": {"mode": "ON"}, "props": {}}},
+ devices={"dev-1": {"state": {}, "props": {}}},
+ )
+ hw.session.should_use_cached_data.return_value = True
+ hw.session.get_cached_device.return_value = None
+ result = await hw.get_water_heater(_make_device())
+ assert result is not None
+ hw.session.attr.online_offline.assert_called_once()
diff --git a/tests/unit/test_hotwater_extended.py b/tests/unit/test_hotwater_extended.py
deleted file mode 100644
index 45891f7..0000000
--- a/tests/unit/test_hotwater_extended.py
+++ /dev/null
@@ -1,162 +0,0 @@
-"""Extended branch-coverage tests for WaterHeater / HiveHotwater."""
-
-# pylint: disable=too-few-public-methods
-from unittest.mock import AsyncMock, MagicMock
-
-from apyhiveapi.devices.hotwater import WaterHeater
-from apyhiveapi.helper.hivedataclasses import Device, SessionConfig
-from apyhiveapi.helper.map import Map
-
-_SCHEDULE_MODE = "SCHEDULE"
-_ON_MODE = "ON"
-_OFF_MODE = "OFF"
-_BOOST_MODE = "BOOST"
-_BOOST_MINS = 30
-
-
-def _make_hotwater(products=None, devices=None):
- session = MagicMock()
- session.data = Map(
- {
- "products": products or {},
- "devices": devices or {},
- "actions": {},
- "minMax": {},
- "user": {},
- }
- )
- session.config = SessionConfig()
- session.helper = MagicMock()
- session.helper.device_recovered = MagicMock()
- session.helper.error_check = AsyncMock()
- session.helper.get_schedule_nnl = MagicMock(
- return_value={"now": {"value": {"status": _ON_MODE}}, "next": {}, "later": {}}
- )
- session.attr = MagicMock()
- session.attr.online_offline = AsyncMock(return_value=True)
- session.attr.state_attributes = AsyncMock(return_value={})
- session.api = MagicMock()
- session.api.set_state = AsyncMock(return_value={"original": 200, "parsed": {}})
- session.hive_refresh_tokens = AsyncMock()
- session.get_devices = AsyncMock(return_value=True)
- session.should_use_cached_data = MagicMock(return_value=False)
- session.get_cached_device = MagicMock(return_value=None)
- session.set_cached_device = MagicMock(side_effect=lambda d: d)
- return WaterHeater(session=session)
-
-
-def _make_device(hive_id="hw-1", device_id="dev-1"):
- return Device(
- hive_id=hive_id,
- hive_name="Hot Water",
- hive_type="hotwater",
- ha_type="water_heater",
- device_id=device_id,
- device_name="Hot Water",
- device_data={"online": True},
- ha_name="Hot Water",
- )
-
-
-class TestGetMode:
- async def test_boost_mode_reads_previous(self):
- """BOOST state resolves to the previous mode stored in props."""
- hw = _make_hotwater(
- {
- "hw-1": {
- "state": {"mode": _BOOST_MODE},
- "props": {"previous": {"mode": _ON_MODE}},
- }
- }
- )
- result = await hw.get_mode(_make_device())
- assert result == _ON_MODE
-
-
-class TestGetState:
- async def test_schedule_mode_boost_off_reads_schedule(self):
- """SCHEDULE mode with boost OFF reads state from schedule nnl."""
- hw = _make_hotwater(
- {
- "hw-1": {
- "state": {
- "mode": _SCHEDULE_MODE,
- "status": _OFF_MODE,
- "boost": False,
- "schedule": {},
- }
- }
- }
- )
- result = await hw.get_state(_make_device())
- assert result is not None
-
- async def test_non_schedule_state_mapped(self):
- """Direct ON mode/status maps through HIVETOHA without schedule lookup."""
- hw = _make_hotwater(
- {
- "hw-1": {
- "state": {
- "mode": _ON_MODE,
- "status": _ON_MODE,
- "schedule": {},
- }
- }
- }
- )
- result = await hw.get_state(_make_device())
- assert result is not None
-
-
-class TestGetWaterHeater:
- async def test_cache_hit_returns_cached(self):
- """Cached device is returned immediately when poll is slow/busy."""
- hw = _make_hotwater()
- hw.session.should_use_cached_data = MagicMock(return_value=True)
- cached_device = _make_device()
- cached_device.status = {"current_operation": _SCHEDULE_MODE}
- hw.session.get_cached_device = MagicMock(return_value=cached_device)
- d = _make_device()
- result = await hw.get_water_heater(d)
- assert result is cached_device
- hw.session.attr.online_offline.assert_not_called()
-
- async def test_device_data_not_dict_gets_reset(self):
- """Non-dict device_data is replaced with an empty dict before use."""
- hw = _make_hotwater(
- products={"hw-1": {"state": {"mode": _SCHEDULE_MODE}}},
- devices={"dev-1": {"props": {}, "parent": None}},
- )
- d = _make_device()
- d.device_data = None
- await hw.get_water_heater(d)
- assert isinstance(d.device_data, dict)
-
- async def test_offline_device_calls_error_check(self):
- """Offline device triggers error_check and status defaults to None."""
- hw = _make_hotwater(
- products={"hw-1": {}},
- devices={"dev-1": {}},
- )
- hw.session.attr.online_offline = AsyncMock(return_value=False)
- d = _make_device()
- result = await hw.get_water_heater(d)
- hw.session.helper.error_check.assert_called_once()
- assert result.status["current_operation"] is None
-
-
-class TestGetScheduleNowNextLater:
- async def test_schedule_mode_returns_nnl(self):
- """SCHEDULE mode with schedule data returns now/next/later dict."""
- hw = _make_hotwater(
- {"hw-1": {"state": {"mode": _SCHEDULE_MODE, "schedule": {"data": []}}}}
- )
- result = await hw.get_schedule_now_next_later(_make_device())
- assert result is not None
- assert "now" in result
-
- async def test_non_schedule_mode_returns_none(self):
- """Non-SCHEDULE mode returns None."""
- hw = _make_hotwater({"hw-1": {"state": {"mode": _ON_MODE}}})
- result = await hw.get_schedule_now_next_later(_make_device())
- assert result is None
diff --git a/tests/unit/test_light_extended.py b/tests/unit/test_light.py
similarity index 75%
rename from tests/unit/test_light_extended.py
rename to tests/unit/test_light.py
index d1904e7..a6ff385 100644
--- a/tests/unit/test_light_extended.py
+++ b/tests/unit/test_light.py
@@ -233,3 +233,79 @@ async def test_turn_on_with_color_calls_set_color(self):
call_kwargs = session.api.set_state.call_args.kwargs
assert call_kwargs.get("colourMode") == "COLOUR"
assert call_kwargs.get("hue") == str(color[0])
+
+
+# ---------------------------------------------------------------------------
+# get_brightness — must return int, not float
+# ---------------------------------------------------------------------------
+
+
+class TestGetBrightnessReturnsInt:
+ """get_brightness must return int, not float."""
+
+ async def test_get_brightness_returns_int(self):
+ """Brightness value is returned as int (not float) for HA compatibility."""
+ from apyhiveapi.devices.light import HiveLight
+
+ class StubLight(HiveLight):
+ """Concrete stub for testing."""
+
+ h = StubLight()
+ h.session = MagicMock()
+ h.session.data.products = {"h1": {"state": {"brightness": 50}}}
+ d = Device(
+ hive_id="h1",
+ hive_name="L",
+ hive_type="warmwhitelight",
+ ha_type="light",
+ device_id="d1",
+ device_name="L",
+ device_data={"online": True},
+ ha_name="Light",
+ )
+ result = await h.get_brightness(d)
+ assert isinstance(result, int), (
+ f"Expected int, got {type(result).__name__}: {result!r}"
+ )
+ assert result == 127
+
+
+class TestGetBrightnessNullValue:
+ """A null brightness in the API payload must not raise TypeError."""
+
+ async def test_null_brightness_returns_none(self):
+ session = _make_session(
+ products={
+ "light-1": {
+ "state": {"status": "ON", "brightness": None},
+ "props": {},
+ }
+ },
+ )
+ light = Light(session=session)
+ result = await light.get_brightness(_make_device())
+ assert result is None
+
+
+# ===========================================================================
+# Migrated from test_remaining_branches.py
+# ===========================================================================
+
+
+class TestLightGetLightCacheMiss:
+ """Lines 141->147: cache enabled but cached is None → normal execution."""
+
+ async def test_cached_none_falls_through(self):
+ session = _make_session(
+ products={
+ "light-1": {"state": {"status": "ON", "brightness": 100}, "props": {}}
+ },
+ devices={"dev-1": {"state": {}, "props": {}}},
+ )
+ light = Light(session=session)
+ d = _make_device()
+ session.should_use_cached_data.return_value = True
+ session.get_cached_device.return_value = None
+ result = await light.get_light(d)
+ assert result is not None
+ session.attr.online_offline.assert_called_once()
diff --git a/tests/unit/test_map.py b/tests/unit/test_map.py
index f6209f7..ee9ab44 100644
--- a/tests/unit/test_map.py
+++ b/tests/unit/test_map.py
@@ -1,5 +1,6 @@
"""Unit tests for Map — dot-notation dict wrapper."""
+import pytest
from apyhiveapi.helper.map import Map
@@ -15,10 +16,18 @@ def test_dict_read():
assert m["key"] == "value"
-def test_missing_key_returns_none_not_keyerror():
- """Test that missing keys return None instead of raising KeyError."""
+def test_missing_key_raises_attribute_error():
+ """Missing attribute access raises AttributeError."""
m = Map({})
- assert m.missing is None
+ with pytest.raises(AttributeError):
+ _ = m.missing
+
+
+def test_missing_bracket_key_raises_key_error():
+ """Missing bracket access raises KeyError (standard dict behaviour)."""
+ m = Map({})
+ with pytest.raises(KeyError):
+ _ = m["missing"]
def test_nested_access():
@@ -32,3 +41,13 @@ def test_attribute_write():
m = Map({})
m.foo = "bar"
assert m["foo"] == "bar"
+
+
+class TestMapDelAttr:
+ """Map.__delattr__ must raise AttributeError (not KeyError) for missing keys."""
+
+ def test_delattr_missing_key_raises_attribute_error(self):
+ """del m.missing raises AttributeError, not KeyError."""
+ m = Map({"a": 1})
+ with pytest.raises(AttributeError):
+ del m.nonexistent
diff --git a/tests/unit/test_polling.py b/tests/unit/test_polling.py
index e518941..77062a0 100644
--- a/tests/unit/test_polling.py
+++ b/tests/unit/test_polling.py
@@ -3,6 +3,7 @@
# pylint: disable=protected-access,attribute-defined-outside-init,too-few-public-methods
import asyncio
+from unittest.mock import AsyncMock, MagicMock
from apyhiveapi.helper.hivedataclasses import Device
from apyhiveapi.session.polling import PollingMixin
@@ -191,8 +192,6 @@ class TestPollDevices:
async def test_poll_devices_delegates_to_get_devices(self):
"""_poll_devices calls get_devices('No_ID') and returns its result."""
- from unittest.mock import AsyncMock
-
p = _make_polling()
p.get_devices = AsyncMock(return_value=True)
result = await p._poll_devices()
@@ -201,9 +200,176 @@ async def test_poll_devices_delegates_to_get_devices(self):
async def test_poll_devices_propagates_false(self):
"""_poll_devices returns False when get_devices returns False."""
- from unittest.mock import AsyncMock
-
p = _make_polling()
p.get_devices = AsyncMock(return_value=False)
result = await p._poll_devices()
assert result is False
+
+
+# ---------------------------------------------------------------------------
+# TestGetDevicesSlowPoll
+# ---------------------------------------------------------------------------
+
+
+class TestGetDevicesSlowPoll:
+ async def test_auth_error_sets_last_poll_slow_false(self):
+ from apyhiveapi.helper.hive_exceptions import HiveAuthError
+
+ p = _make_polling()
+ p.api = MagicMock()
+ p.api.get_all = AsyncMock(side_effect=HiveAuthError())
+ p.config = MagicMock()
+ p.config.file = False
+ p.tokens = MagicMock()
+ p._last_poll_slow = True # pre-set to True to confirm it gets cleared
+
+ retry_result = {
+ "original": 200,
+ "parsed": {"products": [], "devices": [], "actions": []},
+ }
+
+ async def fake_retry_login():
+ pass
+
+ async def fake_retry_with_backoff(_fn, **_kwargs):
+ return retry_result
+
+ p._retry_login = fake_retry_login
+ p._retry_with_backoff = fake_retry_with_backoff
+ p.hive_refresh_tokens = AsyncMock()
+ p.data = MagicMock()
+ p.data.products = {}
+ p.data.devices = {}
+ p.data.actions = {}
+ p.config.last_update = MagicMock()
+ p.config.scan_interval = MagicMock()
+
+ await p.get_devices("No_ID")
+ assert p._last_poll_slow is False
+
+ async def test_tokens_none_returns_false_without_crash(self):
+ """get_devices returns False (no crash) when tokens=None and file=False."""
+ from apyhiveapi.helper.map import Map
+
+ p = _make_polling()
+ p.config = MagicMock()
+ p.config.file = False
+ p.tokens = None # triggers the "neither branch" path
+ p.data = Map({"products": {}, "devices": {}, "actions": {}, "user": {}})
+
+ result = await p.get_devices("No_ID")
+ assert result is False
+
+ async def test_slow_api_call_sets_last_poll_slow_true(self):
+ p = _make_polling()
+ p._slow_poll_threshold = 0 # any call will be "slow"
+ p.api = MagicMock()
+
+ slow_result = {
+ "original": 200,
+ "parsed": {"products": [], "devices": [], "actions": []},
+ }
+
+ async def slow_get_all():
+ return slow_result
+
+ p.api.get_all = slow_get_all
+ p.config = MagicMock()
+ p.config.file = False
+ p.tokens = MagicMock()
+ p.hive_refresh_tokens = AsyncMock()
+ p.data = MagicMock()
+ p.data.products = {}
+ p.data.devices = {}
+ p.data.actions = {}
+ p.config.last_update = MagicMock()
+ p.config.scan_interval = MagicMock()
+
+ await p.get_devices("No_ID")
+ assert p._last_poll_slow is True
+
+
+# ---------------------------------------------------------------------------
+# TestGetDevicesNoneGuard — api.get_all() returning None must not crash
+# ---------------------------------------------------------------------------
+
+
+class TestGetDevicesNoneGuard:
+ """api_resp_d must be guarded before dict access when api.get_all() returns None."""
+
+ async def test_api_returns_none_does_not_crash(self):
+ """get_devices returns False without crashing when api.get_all() returns None."""
+ from apyhiveapi.helper.map import Map
+
+ p = _make_polling()
+ p.config = MagicMock()
+ p.config.file = False
+ p.tokens = MagicMock()
+ p.api = MagicMock()
+ p.api.get_all = AsyncMock(return_value=None)
+ p.hive_refresh_tokens = AsyncMock()
+ p.data = Map({"products": {}, "devices": {}, "actions": {}, "user": {}})
+
+ result = await p.get_devices("No_ID")
+ assert result is False
+
+
+# ---------------------------------------------------------------------------
+# TestGetDevicesHomesKey — homes list null/empty must not crash
+# ---------------------------------------------------------------------------
+
+
+class TestGetDevicesHomesKey:
+ """homes key in API response must not crash when homes list is None or empty."""
+
+ async def _run_get_devices_with_parsed(self, parsed):
+ from apyhiveapi.helper.map import Map
+
+ p = _make_polling()
+ p.config = MagicMock()
+ p.config.file = False
+ p.tokens = MagicMock()
+ p.api = MagicMock()
+ p.api.get_all = AsyncMock(return_value={"original": "200", "parsed": parsed})
+ p.hive_refresh_tokens = AsyncMock()
+ p.data = Map({"products": {}, "devices": {}, "actions": {}, "user": {}})
+ return p, await p.get_devices("No_ID")
+
+ async def test_homes_null_does_not_crash(self):
+ """No crash when API returns homes.homes = None."""
+ parsed = {
+ "products": [],
+ "devices": [],
+ "actions": [],
+ "homes": {"homes": None},
+ }
+ _p, result = await self._run_get_devices_with_parsed(parsed)
+ assert isinstance(result, bool)
+
+ async def test_homes_empty_list_does_not_crash(self):
+ """No crash when API returns homes.homes = []."""
+ parsed = {"products": [], "devices": [], "actions": [], "homes": {"homes": []}}
+ _p, result = await self._run_get_devices_with_parsed(parsed)
+ assert isinstance(result, bool)
+
+ async def test_valid_homes_list_sets_home_id(self):
+ """home_id is set correctly from a valid homes list."""
+ parsed = {
+ "products": [],
+ "devices": [],
+ "actions": [],
+ "homes": {"homes": [{"id": "home-abc"}]},
+ }
+ p, _ = await self._run_get_devices_with_parsed(parsed)
+ assert p.config.home_id == "home-abc"
+
+ async def test_homes_data_not_dict_does_not_crash(self):
+ """No crash when homes_data is a non-dict (list); home_id stays unset."""
+ parsed = {
+ "products": [],
+ "devices": [],
+ "actions": [],
+ "homes": [{"id": "home-list"}],
+ }
+ _p, result = await self._run_get_devices_with_parsed(parsed)
+ assert isinstance(result, bool)
diff --git a/tests/unit/test_remaining_branches.py b/tests/unit/test_remaining_branches.py
deleted file mode 100644
index cb5b8fb..0000000
--- a/tests/unit/test_remaining_branches.py
+++ /dev/null
@@ -1,1330 +0,0 @@
-"""Branch-coverage tests for several source modules.
-
-Covers missing lines in:
- - src/devices/heating.py
- - src/devices/hotwater.py
- - src/devices/light.py
- - src/devices/sensor.py
- - src/session/auth.py
- - src/session/discovery.py
-"""
-
-# pylint: disable=too-few-public-methods,protected-access,attribute-defined-outside-init
-
-import asyncio
-from datetime import datetime, timedelta
-from unittest.mock import AsyncMock, MagicMock, patch
-
-import pytest
-from apyhiveapi.devices.heating import Climate
-from apyhiveapi.devices.hotwater import WaterHeater
-from apyhiveapi.devices.light import Light
-from apyhiveapi.devices.sensor import Sensor
-from apyhiveapi.helper.hive_exceptions import HiveApiError
-from apyhiveapi.helper.hive_helper import HiveHelper
-from apyhiveapi.helper.hivedataclasses import (
- Device,
- EntityConfig,
- SessionConfig,
- SessionTokens,
-)
-from apyhiveapi.helper.map import Map
-from apyhiveapi.session.auth import SessionAuthMixin
-from apyhiveapi.session.discovery import DiscoveryMixin
-
-# ---------------------------------------------------------------------------
-# Shared helpers — heating
-# ---------------------------------------------------------------------------
-
-
-def _make_climate(products=None, devices=None, min_max=None):
- session = MagicMock()
- session.data = Map(
- {
- "products": products or {},
- "devices": devices or {},
- "actions": {},
- "minMax": min_max or {},
- "user": {},
- }
- )
- session.config = SessionConfig()
- session.helper = MagicMock()
- session.helper.device_recovered = MagicMock()
- session.helper.error_check = AsyncMock()
- session.helper.get_schedule_nnl = MagicMock(
- return_value={"now": {}, "next": {}, "later": {}}
- )
- session.attr = MagicMock()
- session.attr.online_offline = AsyncMock(return_value=True)
- session.attr.state_attributes = AsyncMock(return_value={})
- session.api = MagicMock()
- session.api.set_state = AsyncMock(return_value={"original": 200, "parsed": {}})
- session.hive_refresh_tokens = AsyncMock()
- session.get_devices = AsyncMock(return_value=True)
- session.should_use_cached_data = MagicMock(return_value=False)
- session.get_cached_device = MagicMock(return_value=None)
- session.set_cached_device = MagicMock(side_effect=lambda d: d)
- return Climate(session=session)
-
-
-def _make_device(hive_id="heat-1", device_id="dev-1", hive_type="heating"):
- return Device(
- hive_id=hive_id,
- hive_name="Hallway",
- hive_type=hive_type,
- ha_type="climate",
- device_id=device_id,
- device_name="Hallway",
- device_data={"online": True},
- ha_name="Hallway",
- )
-
-
-# ---------------------------------------------------------------------------
-# Shared helpers — hotwater
-# ---------------------------------------------------------------------------
-
-
-def _make_hotwater(products=None, devices=None):
- session = MagicMock()
- session.data = Map(
- {
- "products": products or {},
- "devices": devices or {},
- "actions": {},
- "minMax": {},
- "user": {},
- }
- )
- session.config = SessionConfig()
- session.helper = MagicMock()
- session.helper.device_recovered = MagicMock()
- session.helper.get_schedule_nnl = MagicMock(
- return_value={"now": {}, "next": {}, "later": {}}
- )
- session.attr = MagicMock()
- session.attr.online_offline = AsyncMock(return_value=True)
- session.attr.state_attributes = AsyncMock(return_value={})
- session.api = MagicMock()
- session.api.set_state = AsyncMock(return_value={"original": 200, "parsed": {}})
- session.hive_refresh_tokens = AsyncMock()
- session.get_devices = AsyncMock(return_value=True)
- session.should_use_cached_data = MagicMock(return_value=False)
- session.get_cached_device = MagicMock(return_value=None)
- session.set_cached_device = MagicMock(side_effect=lambda d: d)
- return WaterHeater(session=session)
-
-
-def _make_hw_device(hive_id="hw-1", device_id="dev-1"):
- return Device(
- hive_id=hive_id,
- hive_name="Hot Water",
- hive_type="hotwater",
- ha_type="water_heater",
- device_id=device_id,
- device_name="Hot Water",
- device_data={"online": True},
- ha_name="Hot Water",
- )
-
-
-# ---------------------------------------------------------------------------
-# Shared helpers — sensor
-# ---------------------------------------------------------------------------
-
-
-def _make_sensor(products=None, devices=None):
- session = MagicMock()
- session.data = Map(
- {
- "products": products or {},
- "devices": devices or {},
- "actions": {},
- "minMax": {},
- "user": {},
- }
- )
- session.config = SessionConfig()
- session.helper = MagicMock()
- session.helper.device_recovered = MagicMock()
- session.attr = MagicMock()
- session.attr.online_offline = AsyncMock(return_value=True)
- session.attr.state_attributes = AsyncMock(return_value={})
- session.should_use_cached_data = MagicMock(return_value=False)
- session.get_cached_device = MagicMock(return_value=None)
- session.set_cached_device = MagicMock(side_effect=lambda d: d)
- return Sensor(session=session)
-
-
-def _make_sensor_device(hive_id="sens-1", device_id="dev-1", hive_type="contactsensor"):
- return Device(
- hive_id=hive_id,
- hive_name="Door",
- hive_type=hive_type,
- ha_type="binary_sensor",
- device_id=device_id,
- device_name="Door",
- device_data={"online": True},
- ha_name="Door",
- )
-
-
-# ---------------------------------------------------------------------------
-# Shared helpers — auth
-# ---------------------------------------------------------------------------
-
-
-def _make_auth_stub():
- class StubAuth(SessionAuthMixin):
- """Concrete subclass used only for testing."""
-
- s = StubAuth()
- s.auth = MagicMock()
- s.auth.DEVICE_VERIFIER_CHALLENGE = "DEVICE_SRP_AUTH"
- s.auth.SMS_MFA_CHALLENGE = "SMS_MFA"
- s.auth.login = AsyncMock()
- s.auth.device_login = AsyncMock()
- s.auth.sms_2fa = AsyncMock()
- s.auth.refresh_token = AsyncMock()
- s.tokens = SessionTokens()
- s.tokens.token_data = {"refreshToken": "rt", "token": "", "accessToken": ""}
- s.config = SessionConfig()
- s.helper = MagicMock()
- s.helper.sanitize_payload = MagicMock(return_value={})
- s._refresh_threshold = 0.90
- s._refresh_lock = asyncio.Lock()
- return s
-
-
-# ---------------------------------------------------------------------------
-# Shared helpers — discovery
-# ---------------------------------------------------------------------------
-
-
-def _make_discovery_stub(products=None, devices=None, actions=None):
- class StubDiscovery(DiscoveryMixin):
- """Concrete subclass used only for testing."""
-
- s = StubDiscovery()
- s.config = SessionConfig()
- s.data = Map(
- {
- "products": products or {},
- "devices": devices or {},
- "actions": actions or {},
- "user": {"temperatureUnit": "C"},
- "minMax": {},
- }
- )
- s.helper = MagicMock()
- s.helper.get_device_data = MagicMock(
- return_value={
- "id": "dev-1",
- "state": {"name": "Test Device"},
- "props": {"online": True},
- }
- )
- s.hub_id = None
- s.device_list = {
- "parent": [],
- "binary_sensor": [],
- "climate": [],
- "light": [],
- "sensor": [],
- "switch": [],
- "water_heater": [],
- }
- return s
-
-
-# ===========================================================================
-# 1. src/devices/heating.py
-# ===========================================================================
-
-
-class TestHeatingGetStateKeyError:
- """Lines 206-207: KeyError/TypeError branch in get_state."""
-
- async def test_get_state_key_error_returns_none(self):
- """Missing product entry causes get_current_temperature to return None,
- leaving final as None without raising."""
- # products dict is empty — device.hive_id not found → both temp helpers
- # return None → the if branch is skipped → final stays None
- climate = _make_climate(products={})
- d = _make_device()
- result = await climate.get_state(d)
- assert result is None
-
-
-class TestHeatingGetHeatOnDemand:
- """Line 231: get_heat_on_demand happy path."""
-
- async def test_get_heat_on_demand_returns_value(self):
- """Returns the nested autoBoost.active value from products."""
- climate = _make_climate({"heat-1": {"props": {"autoBoost": {"active": True}}}})
- result = await climate.get_heat_on_demand(_make_device())
- assert result is True
-
- async def test_get_heat_on_demand_returns_none_when_missing(self):
- """Returns None when the nested path does not exist."""
- climate = _make_climate({"heat-1": {"props": {}}})
- result = await climate.get_heat_on_demand(_make_device())
- assert result is None
-
-
-class TestHeatingSetBoostOffOffMode:
- """Lines 321->325: set_boost_off when previous mode is 'OFF' with a real target."""
-
- async def test_set_boost_off_off_mode_with_target_restores_target(self):
- """Previous mode OFF with a real target value restores that target."""
- climate = _make_climate(
- {
- "heat-1": {
- "type": "heating",
- "state": {"boost": 5},
- "props": {"previous": {"mode": "OFF", "target": 18.0}},
- }
- }
- )
- result = await climate.set_boost_off(_make_device())
- assert result is True
- _, kwargs = climate.session.api.set_state.call_args
- assert kwargs.get("mode") == "OFF"
- assert kwargs.get("target") == 18.0
-
-
-class TestHeatingSetHeatOnDemand:
- """Lines 337-342: set_heat_on_demand calls _execute_state_change with autoBoost kwarg."""
-
- async def test_set_heat_on_demand_enabled(self):
- """set_heat_on_demand passes autoBoost='ENABLED' to the API."""
- climate = _make_climate({"heat-1": {"type": "heating"}})
- result = await climate.set_heat_on_demand(_make_device(), "ENABLED")
- assert result is True
- climate.session.api.set_state.assert_called_once()
- _, kwargs = climate.session.api.set_state.call_args
- assert kwargs.get("autoBoost") == "ENABLED"
-
- async def test_set_heat_on_demand_disabled(self):
- """set_heat_on_demand passes autoBoost='DISABLED' to the API."""
- climate = _make_climate({"heat-1": {"type": "heating"}})
- result = await climate.set_heat_on_demand(_make_device(), "DISABLED")
- assert result is True
- _, kwargs = climate.session.api.set_state.call_args
- assert kwargs.get("autoBoost") == "DISABLED"
-
-
-class TestHeatingGetClimateCacheHit:
- """Lines 371->377: get_climate returns cached device when cache is available."""
-
- async def test_get_climate_returns_cached_when_available(self):
- """Cache hit short-circuits all I/O and returns the cached dict."""
- climate = _make_climate({"heat-1": {"type": "heating"}})
- cached = {"current_temperature": 20.0}
- climate.session.should_use_cached_data.return_value = True
- climate.session.get_cached_device.return_value = cached
- result = await climate.get_climate(_make_device())
- assert result == cached
- # No API calls should have been made
- climate.session.attr.online_offline.assert_not_called()
-
-
-class TestHeatingGetScheduleNNLKeyError:
- """Lines 438-439: KeyError in get_schedule_now_next_later."""
-
- async def test_missing_schedule_key_returns_none(self):
- """Product with state but no 'schedule' key causes KeyError → returns None."""
- climate = _make_climate(
- {"heat-1": {"state": {"mode": "SCHEDULE"}}}
- # no 'schedule' key inside state
- )
- # Override get_mode to return SCHEDULE directly so the if-branch is entered
- climate.session.helper.get_schedule_nnl.side_effect = KeyError("schedule")
- # get_mode will read data["state"]["mode"] == "SCHEDULE" → enters the try block
- # data["state"]["schedule"] raises KeyError → caught, returns None
- result = await climate.get_schedule_now_next_later(_make_device())
- assert result is None
-
- async def test_schedule_key_error_caught_not_raised(self):
- """A KeyError inside the try block does not propagate to the caller."""
- climate = _make_climate({"heat-1": {"state": {"mode": "SCHEDULE"}}})
- # Accessing data["state"]["schedule"] will raise KeyError (key absent)
- try:
- result = await climate.get_schedule_now_next_later(_make_device())
- except KeyError:
- pytest.fail(
- "KeyError should have been caught inside get_schedule_now_next_later"
- )
- assert result is None
-
-
-# ===========================================================================
-# 2. src/devices/hotwater.py
-# ===========================================================================
-
-
-class TestHotwaterGetModeKeyError:
- """Lines 43-44: KeyError in get_mode."""
-
- async def test_get_mode_missing_state_returns_none(self):
- """Product with no 'state' key causes KeyError → final stays None."""
- hw = _make_hotwater({"hw-1": {}})
- result = await hw.get_mode(_make_hw_device())
- assert result is None
-
-
-class TestHotwaterGetStateKeyError:
- """Lines 83-84: KeyError in get_state."""
-
- async def test_get_state_missing_status_key_returns_none(self):
- """Product 'state' dict missing 'status' key triggers KeyError → None."""
- hw = _make_hotwater({"hw-1": {"state": {"mode": "MANUAL"}}})
- # 'status' key is absent from state → KeyError on data["state"]["status"]
- result = await hw.get_state(_make_hw_device())
- assert result is None
-
- async def test_get_state_missing_schedule_in_schedule_mode_returns_none(self):
- """SCHEDULE mode with no 'schedule' key in state causes KeyError → None."""
- hw = _make_hotwater(
- {
- "hw-1": {
- "state": {
- "mode": "SCHEDULE",
- "status": "ON",
- "boost": False,
- # no 'schedule' key
- }
- }
- }
- )
- result = await hw.get_state(_make_hw_device())
- assert result is None
-
-
-class TestHotwaterGetWaterHeaterCacheHit:
- """Lines 173->180: get_water_heater returns cached when cache is available."""
-
- async def test_get_water_heater_returns_cached(self):
- """Cache hit short-circuits all I/O and returns the cached value."""
- hw = _make_hotwater()
- cached = {"current_operation": "ON"}
- hw.session.should_use_cached_data.return_value = True
- hw.session.get_cached_device.return_value = cached
- result = await hw.get_water_heater(_make_hw_device())
- assert result == cached
- hw.session.attr.online_offline.assert_not_called()
-
-
-class TestHotwaterScheduleNNLNone:
- """Lines 225->227: get_schedule_now_next_later returns None when schedule is absent."""
-
- async def test_schedule_none_when_no_schedule_in_state(self):
- """SCHEDULE mode product without 'schedule' key → _get_product_state returns None → None."""
- hw = _make_hotwater({"hw-1": {"state": {"mode": "SCHEDULE"}}})
- # _get_product_state(device, "state", "schedule") → None (key absent)
- result = await hw.get_schedule_now_next_later(_make_hw_device())
- assert result is None
-
- async def test_non_schedule_mode_returns_none(self):
- """Non-SCHEDULE mode skips the schedule lookup and returns None directly."""
- hw = _make_hotwater({"hw-1": {"state": {"mode": "MANUAL"}}})
- result = await hw.get_schedule_now_next_later(_make_hw_device())
- assert result is None
-
- async def test_schedule_present_returns_nnl(self):
- """When schedule data exists, get_schedule_nnl result is returned."""
- schedule_data = {"foo": "bar"}
- hw = _make_hotwater(
- {"hw-1": {"state": {"mode": "SCHEDULE", "schedule": schedule_data}}}
- )
- expected = {"now": {}, "next": {}, "later": {}}
- hw.session.helper.get_schedule_nnl.return_value = expected
- result = await hw.get_schedule_now_next_later(_make_hw_device())
- assert result == expected
- hw.session.helper.get_schedule_nnl.assert_called_once_with(schedule_data)
-
-
-# ===========================================================================
-# 3. src/devices/sensor.py
-# ===========================================================================
-
-
-class TestSensorGetStateKeyError:
- """Lines 37->42: KeyError in HiveSensor.get_state."""
-
- async def test_get_state_missing_type_key_returns_none(self):
- """Product with no 'type' key causes KeyError → final stays None."""
- sensor = _make_sensor({"sens-1": {}})
- d = _make_sensor_device()
- result = await sensor.get_state(d)
- assert result is None
-
- async def test_get_state_missing_props_key_returns_none(self):
- """contactsensor product without 'props' causes KeyError → None."""
- sensor = _make_sensor({"sens-1": {"type": "contactsensor"}})
- d = _make_sensor_device()
- result = await sensor.get_state(d)
- assert result is None
-
-
-class TestSensorGetSensorCacheHit:
- """Lines 92->98: get_sensor returns cached device when cache is available."""
-
- async def test_get_sensor_returns_cached(self):
- """Cache hit short-circuits all I/O and returns the cached value."""
- sensor = _make_sensor()
- cached = {"state": True}
- sensor.session.should_use_cached_data.return_value = True
- sensor.session.get_cached_device.return_value = cached
- result = await sensor.get_sensor(_make_sensor_device())
- assert result == cached
- sensor.session.attr.online_offline.assert_not_called()
-
-
-class TestSensorGetSensorProductsFallback:
- """Lines 119->122: when device_id not in devices, fall back to products."""
-
- async def test_uses_products_when_device_id_absent_from_devices(self):
- """device_id not in session.data.devices → hive_id looked up in products."""
- sensor = _make_sensor(
- products={"sens-1": {"type": "contactsensor", "props": {"status": "OPEN"}}},
- devices={},
- )
- d = _make_sensor_device(
- hive_id="sens-1", device_id="unknown-dev", hive_type="contactsensor"
- )
- # Sensor with hive_type in sensor_commands path will be followed;
- # the important thing is the products-fallback path is entered without error.
- result = await sensor.get_sensor(d)
- # Result should be the device (set_cached_device returns the device itself)
- assert result is not None
-
- async def test_products_fallback_data_used_for_device_data(self):
- """Props from the products entry propagate to device.device_data."""
- sensor = _make_sensor(
- products={
- "sens-1": {
- "type": "contactsensor",
- "props": {"status": "OPEN", "online": True},
- }
- },
- devices={},
- )
- d = _make_sensor_device(
- hive_id="sens-1", device_id="unknown-dev", hive_type="contactsensor"
- )
- await sensor.get_sensor(d)
- # get_state uses self.session.data.products[device.hive_id] directly
- # so we just verify it ran without KeyError
-
-
-class TestSensorGetSensorHiveTypesSensorPath:
- """Lines 135->146: elif device.hive_type in HIVE_TYPES['Sensor'] path."""
-
- async def test_contactsensor_in_hive_types_sensor_takes_else_branch(self):
- """contactsensor is in HIVE_TYPES['Sensor'] and not in sensor_commands key set,
- so the elif branch is taken."""
- from apyhiveapi.helper.const import HIVE_TYPES, sensor_commands
-
- # 'contactsensor' is in HIVE_TYPES['Sensor'] and NOT a key in sensor_commands
- assert "contactsensor" in HIVE_TYPES["Sensor"]
- assert "contactsensor" not in sensor_commands
-
- sensor = _make_sensor(
- products={"sens-1": {"type": "contactsensor", "props": {"status": "OPEN"}}},
- devices={"dev-1": {"props": {"online": True}, "type": "contactsensor"}},
- )
- d = _make_sensor_device(
- hive_id="sens-1", device_id="dev-1", hive_type="contactsensor"
- )
- d.device_data = {"online": True}
- await sensor.get_sensor(d)
- # The elif branch sets device.status with 'state' key
- assert d.status is not None
- assert "state" in d.status
-
- async def test_motionsensor_in_hive_types_sensor_sets_status(self):
- """motionsensor is in HIVE_TYPES['Sensor'] and not in sensor_commands key set."""
- from apyhiveapi.helper.const import HIVE_TYPES, sensor_commands
-
- assert "motionsensor" in HIVE_TYPES["Sensor"]
- assert "motionsensor" not in sensor_commands
-
- sensor = _make_sensor(
- products={
- "sens-1": {
- "type": "motionsensor",
- "props": {"motion": {"status": True}},
- }
- },
- devices={"dev-1": {"props": {"online": True}, "type": "motionsensor"}},
- )
- d = _make_sensor_device(
- hive_id="sens-1", device_id="dev-1", hive_type="motionsensor"
- )
- d.device_data = {"online": True}
- await sensor.get_sensor(d)
- assert d.status is not None
- assert "state" in d.status
-
-
-# ===========================================================================
-# 4. src/session/auth.py
-# ===========================================================================
-
-
-class TestRetryWithBackoffNonZeroDelay:
- """Line 66: asyncio.sleep called when delay > 0."""
-
- async def test_non_zero_delay_is_awaited_but_succeeds(self):
- """A non-zero delay entry causes asyncio.sleep to be called; factory still runs."""
- s = _make_auth_stub()
- calls = []
-
- async def factory():
- calls.append(1)
- return "ok"
-
- with patch(
- "apyhiveapi.session.auth.asyncio.sleep", new_callable=AsyncMock
- ) as mock_sleep:
- result = await s._retry_with_backoff(factory, delays=(5,))
- assert result == "ok"
- mock_sleep.assert_called_once_with(5)
- assert len(calls) == 1
-
- async def test_zero_delay_does_not_call_sleep(self):
- """A zero delay skips asyncio.sleep."""
- s = _make_auth_stub()
-
- async def factory():
- return "done"
-
- with patch(
- "apyhiveapi.session.auth.asyncio.sleep", new_callable=AsyncMock
- ) as mock_sleep:
- result = await s._retry_with_backoff(factory, delays=(0,))
- assert result == "done"
- mock_sleep.assert_not_called()
-
-
-class TestUpdateTokensFlatDictWithExpiresIn:
- """Lines 100->106: flat token dict with ExpiresIn sets token_expiry."""
-
- async def test_flat_dict_with_expires_in_sets_token_expiry(self):
- """Flat token dict containing ExpiresIn updates tokens.token_expiry."""
- s = _make_auth_stub()
- flat = {
- "token": "t",
- "refreshToken": "r",
- "accessToken": "a",
- "ExpiresIn": 1800,
- }
- await s.update_tokens(flat)
- assert s.tokens.token_expiry == timedelta(seconds=1800)
-
- async def test_flat_dict_tokens_are_stored(self):
- """All token values from flat dict are written to token_data."""
- s = _make_auth_stub()
- flat = {"token": "my-id", "refreshToken": "my-rt", "accessToken": "my-at"}
- await s.update_tokens(flat)
- assert s.tokens.token_data["token"] == "my-id"
- assert s.tokens.token_data["refreshToken"] == "my-rt"
- assert s.tokens.token_data["accessToken"] == "my-at"
-
-
-class TestLoginApiError:
- """Lines 160-162: HiveApiError in login() is logged and re-raised."""
-
- async def test_login_api_error_reraises(self):
- """HiveApiError raised by auth.login propagates unchanged to the caller."""
- s = _make_auth_stub()
- s.auth.login.side_effect = HiveApiError()
- with pytest.raises(HiveApiError):
- await s.login()
-
-
-class TestHiveRefreshTokensNoAuthResult:
- """Lines 341->373: refresh returns a result but without AuthenticationResult."""
-
- async def test_result_without_auth_result_does_not_update_tokens(self):
- """When refresh_token returns a dict with no AuthenticationResult, tokens stay unchanged."""
- s = _make_auth_stub()
- s.tokens.token_created = datetime.now() - timedelta(hours=2)
- s.tokens.token_expiry = timedelta(hours=1)
- # Return something truthy but without AuthenticationResult
- s.auth.refresh_token.return_value = {"SomeOtherKey": "value"}
- result = await s.hive_refresh_tokens()
- # Tokens must not have been updated
- assert s.tokens.token_data["token"] == ""
- assert s.tokens.token_data["accessToken"] == ""
- # result is what refresh_token returned
- assert result == {"SomeOtherKey": "value"}
-
- async def test_none_refresh_result_does_not_update_tokens(self):
- """When refresh_token returns None, tokens are left unchanged."""
- s = _make_auth_stub()
- s.tokens.token_created = datetime.now() - timedelta(hours=2)
- s.tokens.token_expiry = timedelta(hours=1)
- s.auth.refresh_token.return_value = None
- await s.hive_refresh_tokens()
- assert s.tokens.token_data["token"] == ""
-
-
-# ===========================================================================
-# 5. src/session/discovery.py
-# ===========================================================================
-
-
-class TestCreateDevicesEntityConfigKwargs:
- """Lines 224->226, 226->228, 228->230: entity_config kwarg population in DEVICES loop."""
-
- async def test_entity_config_with_all_fields_populates_kwargs(self):
- """EntityConfig with ha_name, hive_type, and category all set → all kwargs passed."""
- s = _make_discovery_stub(
- devices={
- "dev-1": {
- "id": "dev-1",
- "type": "hub",
- "state": {"name": "My Hub"},
- "props": {},
- }
- }
- )
- entity_cfg = EntityConfig(
- entity_type="binary_sensor",
- ha_name="Hub Status",
- hive_type="Connectivity",
- category="diagnostic",
- )
- with patch("apyhiveapi.session.discovery.DEVICES", {"hub": [entity_cfg]}):
- result = await s.create_devices()
- assert len(result["binary_sensor"]) == 1
- created = result["binary_sensor"][0]
- assert created.hive_type == "Connectivity"
- assert created.category == "diagnostic"
-
- async def test_entity_config_empty_fields_does_not_add_to_kwargs(self):
- """EntityConfig with empty ha_name and hive_type does not inject those keys."""
- s = _make_discovery_stub(
- devices={
- "dev-1": {
- "id": "dev-1",
- "type": "hub",
- "state": {"name": "My Hub"},
- "props": {},
- }
- }
- )
- entity_cfg = EntityConfig(
- entity_type="binary_sensor",
- ha_name="", # falsy — should not be added to kwargs
- hive_type="", # falsy — should not be added to kwargs
- category=None, # None — should not be added to kwargs
- )
- with patch("apyhiveapi.session.discovery.DEVICES", {"hub": [entity_cfg]}):
- result = await s.create_devices()
- # Should still process without error
- assert isinstance(result, dict)
-
-
-class TestCreateDevicesDeviceAddListError:
- """Lines 232-233: KeyError/TypeError from add_list in DEVICES loop is caught."""
-
- async def test_add_list_keyerror_is_caught_not_raised(self):
- """KeyError from add_list during device processing is logged, not propagated."""
- s = _make_discovery_stub(
- devices={
- "dev-1": {
- "id": "dev-1",
- "type": "hub",
- "state": {"name": "My Hub"},
- "props": {},
- }
- }
- )
- entity_cfg = EntityConfig(
- entity_type="binary_sensor",
- ha_name="Hub Status",
- hive_type="Connectivity",
- category="diagnostic",
- )
- with patch("apyhiveapi.session.discovery.DEVICES", {"hub": [entity_cfg]}):
- with patch.object(s, "add_list", side_effect=KeyError("bad key")):
- # Should complete without raising
- result = await s.create_devices()
- assert isinstance(result, dict)
-
- async def test_add_list_typeerror_is_caught_not_raised(self):
- """TypeError from add_list during device processing is caught."""
- s = _make_discovery_stub(
- devices={
- "dev-1": {
- "id": "dev-1",
- "type": "hub",
- "state": {"name": "My Hub"},
- "props": {},
- }
- }
- )
- entity_cfg = EntityConfig(
- entity_type="binary_sensor",
- ha_name="",
- hive_type="",
- category=None,
- )
- with patch("apyhiveapi.session.discovery.DEVICES", {"hub": [entity_cfg]}):
- with patch.object(s, "add_list", side_effect=TypeError("bad type")):
- result = await s.create_devices()
- assert isinstance(result, dict)
-
-
-class TestCreateDevicesActionAddListError:
- """Lines 258-259: KeyError/TypeError from add_list in actions loop is caught."""
-
- async def test_action_add_list_keyerror_is_caught(self):
- """KeyError from add_list when processing an action is logged, not propagated."""
- s = _make_discovery_stub(
- actions={"act-1": {"id": "act-1", "name": "Good Night"}}
- )
- with patch.object(s, "add_list", side_effect=KeyError("missing")):
- result = await s.create_devices()
- assert isinstance(result, dict)
-
- async def test_action_add_list_typeerror_is_caught(self):
- """TypeError from add_list when processing an action is caught."""
- s = _make_discovery_stub(actions={"act-1": {"id": "act-1", "name": "Wake Up"}})
- with patch.object(s, "add_list", side_effect=TypeError("type error")):
- result = await s.create_devices()
- assert isinstance(result, dict)
-
-
-class TestCreateDevicesProductTemperatureUnit:
- """Line 305: entity_config.temperature_unit is used when set and entity_type != 'climate'."""
-
- async def test_entity_config_temperature_unit_passed_to_add_list(self):
- """EntityConfig with temperature_unit set propagates that value as a kwarg."""
- s = _make_discovery_stub(
- products={
- "prod-1": {
- "id": "prod-1",
- "type": "heating",
- "state": {"name": "Heating"},
- "props": {},
- }
- }
- )
- # A non-climate entity with temperature_unit set triggers line 305
- entity_cfg = EntityConfig(
- entity_type="sensor",
- ha_name="Temp Sensor",
- hive_type="Current_Temperature",
- category="diagnostic",
- temperature_unit="F",
- )
- captured_kwargs = {}
-
- original_add_list = s.add_list
-
- def capturing_add_list(entity_type, data, **kwargs):
- captured_kwargs.update(kwargs)
- return original_add_list(entity_type, data, **kwargs)
-
- with patch("apyhiveapi.session.discovery.PRODUCTS", {"heating": [entity_cfg]}):
- with patch.object(s, "add_list", side_effect=capturing_add_list):
- await s.create_devices()
-
- assert captured_kwargs.get("temperature_unit") == "F"
-
-
-class TestCreateDevicesProductAddListAttributeError:
- """Lines 308-309: NameError/AttributeError from add_list in products loop is caught."""
-
- async def test_product_add_list_attribute_error_is_caught(self):
- """AttributeError from add_list when processing a product is caught."""
- s = _make_discovery_stub(
- products={
- "prod-1": {
- "id": "prod-1",
- "type": "heating",
- "state": {"name": "Heating"},
- "props": {},
- }
- }
- )
- entity_cfg = EntityConfig(
- entity_type="climate",
- ha_name="",
- hive_type="",
- category=None,
- )
- with patch("apyhiveapi.session.discovery.PRODUCTS", {"heating": [entity_cfg]}):
- with patch.object(s, "add_list", side_effect=AttributeError("attr error")):
- result = await s.create_devices()
- assert isinstance(result, dict)
-
- async def test_product_add_list_name_error_is_caught(self):
- """NameError from add_list when processing a product is caught."""
- s = _make_discovery_stub(
- products={
- "prod-1": {
- "id": "prod-1",
- "type": "heating",
- "state": {"name": "Heating"},
- "props": {},
- }
- }
- )
- entity_cfg = EntityConfig(
- entity_type="climate",
- ha_name="",
- hive_type="",
- category=None,
- )
- with patch("apyhiveapi.session.discovery.PRODUCTS", {"heating": [entity_cfg]}):
- with patch.object(s, "add_list", side_effect=NameError("name error")):
- result = await s.create_devices()
- assert isinstance(result, dict)
-
-
-# ===========================================================================
-# Additional False-branch tests: cache-miss paths and elif False paths
-# ===========================================================================
-
-
-def _make_light_session(products=None, devices=None):
- session = MagicMock()
- session.data = Map(
- {
- "products": products or {},
- "devices": devices or {},
- "actions": {},
- "minMax": {},
- "user": {},
- }
- )
- session.config = SessionConfig()
- session.helper = MagicMock()
- session.helper.device_recovered = MagicMock()
- session.helper.error_check = AsyncMock()
- session.attr = MagicMock()
- session.attr.online_offline = AsyncMock(return_value=True)
- session.attr.state_attributes = AsyncMock(return_value={})
- session.api = MagicMock()
- session.api.set_state = AsyncMock(return_value={"original": 200, "parsed": {}})
- session.hive_refresh_tokens = AsyncMock()
- session.get_devices = AsyncMock(return_value=True)
- session.should_use_cached_data = MagicMock(return_value=False)
- session.get_cached_device = MagicMock(return_value=None)
- session.set_cached_device = MagicMock(side_effect=lambda d: d)
- return session
-
-
-def _make_light_device(
- hive_id="light-1", device_id="dev-1", hive_type="warmwhitelight"
-):
- return Device(
- hive_id=hive_id,
- hive_name="Bulb",
- hive_type=hive_type,
- ha_type="light",
- device_id=device_id,
- device_name="Bulb",
- device_data={"online": True},
- ha_name="Bulb",
- )
-
-
-# ---------------------------------------------------------------------------
-# heating.py: 321->325 — set_boost_off with non-MANUAL/OFF previous mode
-# ---------------------------------------------------------------------------
-
-
-class TestHeatingSetBoostOffScheduleMode:
- """Lines 321->325: prev_mode not in ('MANUAL','OFF') — target kwarg not added."""
-
- async def test_schedule_mode_no_target_kwarg(self):
- """SCHEDULE as previous mode does not add a target kwarg."""
- climate = _make_climate(
- {
- "heat-1": {
- "type": "heating",
- "state": {"boost": 5},
- "props": {"previous": {"mode": "SCHEDULE"}},
- }
- }
- )
- result = await climate.set_boost_off(_make_device())
- assert result is True
- _, kwargs = climate.session.api.set_state.call_args
- assert "target" not in kwargs
- assert kwargs.get("mode") == "SCHEDULE"
-
-
-# ---------------------------------------------------------------------------
-# heating.py: 371->377 — get_climate: should_use_cached=True but cached is None
-# ---------------------------------------------------------------------------
-
-
-class TestHeatingGetClimateCacheMiss:
- """Lines 371->377: cache enabled but cached device is None → normal execution."""
-
- async def test_cached_none_falls_through_to_normal_path(self):
- """should_use_cached_data=True but get_cached_device=None → normal update."""
- climate = _make_climate(
- {
- "heat-1": {
- "state": {"mode": "MANUAL", "target": 20.0},
- "props": {"temperature": 19.0},
- }
- },
- devices={"dev-1": {"state": {}, "props": {}}},
- )
- climate.session.should_use_cached_data.return_value = True
- climate.session.get_cached_device.return_value = None
- result = await climate.get_climate(_make_device())
- assert result is not None
- climate.session.attr.online_offline.assert_called_once()
-
-
-# ---------------------------------------------------------------------------
-# hotwater.py: 173->180 — same pattern
-# ---------------------------------------------------------------------------
-
-
-class TestHotwaterGetWaterHeaterCacheMiss:
- """Lines 173->180: cache enabled but cached is None → continues with network call."""
-
- async def test_cached_none_falls_through(self):
- hw = _make_hotwater(
- {"hw-1": {"state": {"mode": "ON"}, "props": {}}},
- devices={"dev-1": {"state": {}, "props": {}}},
- )
- hw.session.should_use_cached_data.return_value = True
- hw.session.get_cached_device.return_value = None
- result = await hw.get_water_heater(_make_hw_device())
- assert result is not None
- hw.session.attr.online_offline.assert_called_once()
-
-
-# ---------------------------------------------------------------------------
-# light.py: 141->147 — same pattern
-# ---------------------------------------------------------------------------
-
-
-class TestLightGetLightCacheMiss:
- """Lines 141->147: cache enabled but cached is None → normal execution."""
-
- async def test_cached_none_falls_through(self):
- session = _make_light_session(
- products={
- "light-1": {"state": {"status": "ON", "brightness": 100}, "props": {}}
- },
- devices={"dev-1": {"state": {}, "props": {}}},
- )
- light = Light(session=session)
- d = _make_light_device()
- session.should_use_cached_data.return_value = True
- session.get_cached_device.return_value = None
- result = await light.get_light(d)
- assert result is not None
- session.attr.online_offline.assert_called_once()
-
-
-# ---------------------------------------------------------------------------
-# sensor.py: 37->42 — get_state: type neither contactsensor nor motionsensor
-# ---------------------------------------------------------------------------
-
-
-class TestSensorGetStateUnknownType:
- """Lines 37->42: data['type'] is neither contactsensor nor motionsensor."""
-
- async def test_unknown_type_returns_none(self):
- """Product with type 'hub' skips both if/elif → final stays None."""
- sensor = _make_sensor({"sens-1": {"type": "hub", "props": {}}})
- d = _make_sensor_device()
- result = await sensor.get_state(d)
- assert result is None
-
-
-# ---------------------------------------------------------------------------
-# sensor.py: 92->98 — get_sensor: cache enabled but cached is None
-# ---------------------------------------------------------------------------
-
-
-class TestSensorGetSensorCacheMiss:
- """Lines 92->98: should_use_cached_data=True but cached is None."""
-
- async def test_cached_none_falls_through(self):
- sensor = _make_sensor(
- products={"sens-1": {"type": "contactsensor", "props": {"status": "OPEN"}}},
- devices={"dev-1": {"props": {"online": True}, "type": "contactsensor"}},
- )
- sensor.session.should_use_cached_data.return_value = True
- sensor.session.get_cached_device.return_value = None
- d = _make_sensor_device()
- result = await sensor.get_sensor(d)
- assert result is not None
- sensor.session.attr.online_offline.assert_called_once()
-
-
-# ---------------------------------------------------------------------------
-# sensor.py: 119->122 — get_sensor: neither device_id nor hive_id found
-# ---------------------------------------------------------------------------
-
-
-class TestSensorGetSensorNoDataFallthrough:
- """Lines 119->122: device_id not in devices AND hive_id not in products."""
-
- async def test_neither_match_continues_with_empty_data(self):
- """data stays empty dict when neither lookup succeeds."""
- sensor = _make_sensor(products={}, devices={})
- d = _make_sensor_device(
- hive_id="unknown-hive", device_id="unknown-dev", hive_type="contactsensor"
- )
- result = await sensor.get_sensor(d)
- # Should not raise; result will be the device (set_cached_device returns it)
- assert result is not None
-
-
-# ---------------------------------------------------------------------------
-# sensor.py: 135->146 — get_sensor: hive_type not in sensor_commands or HIVE_TYPES["Sensor"]
-# ---------------------------------------------------------------------------
-
-
-class TestSensorGetSensorUnknownHiveType:
- """Lines 135->146: hive_type not in sensor_commands and not in HIVE_TYPES['Sensor']."""
-
- async def test_hive_type_not_in_either_dict_skips_both_branches(self):
- """activeplug is neither in sensor_commands nor HIVE_TYPES['Sensor']."""
- sensor = _make_sensor(
- devices={"dev-1": {"props": {"online": True}, "type": "activeplug"}}
- )
- d = _make_sensor_device(
- hive_id="dev-1", device_id="dev-1", hive_type="activeplug"
- )
- d.device_data = {"online": True}
- result = await sensor.get_sensor(d)
- # Neither branch sets device.status; device returned as-is via set_cached_device
- assert result is not None
-
-
-# ===========================================================================
-# Additional branches: session/auth.py, hive_helper.py, heating.py
-# ===========================================================================
-
-
-class TestUpdateTokensUnknownKey:
- """session/auth.py 100->106: tokens dict has neither AuthenticationResult nor token."""
-
- async def test_unknown_key_does_not_raise_and_does_not_update_tokens(self):
- """When neither expected key is present, data stays {}, ExpiresIn check skips."""
- s = _make_auth_stub()
- original_token = s.tokens.token_data["token"]
- # Pass a dict that is neither the AuthResult form nor the flat-token form
- await s.update_tokens({"some_other_key": "some_value"})
- # Tokens must be unchanged
- assert s.tokens.token_data["token"] == original_token
-
- async def test_unknown_key_does_not_set_token_expiry(self):
- """ExpiresIn check at line 106 skips when data is {} (no match in either branch)."""
- s = _make_auth_stub()
- original_expiry = s.tokens.token_expiry
- await s.update_tokens({"random_key": "random_value"})
- assert s.tokens.token_expiry == original_expiry
-
-
-class TestHiveHelperZoneMismatch:
- """hive_helper.py 163->160: loop continues when zones don't match."""
-
- def test_zone_mismatch_keeps_product_as_device(self):
- """When a Thermo device's zone doesn't match the product's zone,
- the loop arc 163->160 is taken and device stays as the product."""
- helper = HiveHelper(session=MagicMock())
- helper.session.data = Map(
- {
- "devices": {
- "thermo-1": {
- "type": "thermostatui",
- "props": {"zone": "zone-B"},
- }
- },
- "products": {},
- "actions": {},
- "user": {},
- "minMax": {},
- }
- )
-
- product = {
- "type": "heating",
- "id": "prod-1",
- "props": {"zone": "zone-A"}, # different zone from thermo-1
- }
-
- result = helper.get_device_data(product)
- # The zone mismatch means device was never re-assigned; returns the product
- assert result is product
-
- def test_trv_without_zone_does_not_log_warning(self, caplog):
- """TRV devices that omit 'zone' from props are silently skipped (no warning)."""
- import logging
-
- helper = HiveHelper(session=MagicMock())
- helper.session.data = Map(
- {
- "devices": {
- "trv-1": {
- "type": "trv",
- "props": {
- "online": True
- }, # no 'zone' key — current API behaviour
- }
- },
- "products": {},
- "actions": {},
- "user": {},
- "minMax": {},
- }
- )
-
- product = {
- "type": "heating",
- "id": "prod-1",
- "props": {"zone": "zone-A"},
- }
-
- with caplog.at_level(logging.WARNING, logger="apyhiveapi.helper.hive_helper"):
- result = helper.get_device_data(product)
-
- assert result is product
- assert not caplog.records, (
- f"Unexpected warnings: {[r.getMessage() for r in caplog.records]}"
- )
-
- def test_zone_match_replaces_device_with_thermostat(self):
- """Matching zones cause device to be replaced with the thermostat entry."""
- helper = HiveHelper(session=MagicMock())
- thermo_data = {
- "type": "thermostatui",
- "props": {"zone": "zone-X"},
- }
- helper.session.data = Map(
- {
- "devices": {"thermo-1": thermo_data},
- "products": {},
- "actions": {},
- "user": {},
- "minMax": {},
- }
- )
-
- product = {
- "type": "heating",
- "id": "prod-1",
- "props": {"zone": "zone-X"}, # matching zone
- }
-
- result = helper.get_device_data(product)
- assert result is thermo_data
-
-
-class TestHiveHelperSanitizeDictValue:
- """hive_helper.py line 328: dict value under a sensitive key calls _mask(dict)."""
-
- def test_dict_under_sensitive_key_is_recursively_masked(self):
- """A dict value under 'token' key hits the isinstance(value, dict) branch."""
- helper = HiveHelper()
- result = helper.sanitize_payload({"token": {"inner_key": "secret_value"}})
- # 'token' is sensitive → _mask is called with the nested dict
- # _mask for a dict returns {k: _mask(v) for k, v in value.items()}
- # _mask("secret_value") → "sec...lue" (long enough) or "***"
- assert "token" in result
- assert isinstance(result["token"], dict)
- assert "inner_key" in result["token"]
- # The inner value should be masked (not the original)
- assert result["token"]["inner_key"] != "secret_value"
-
- def test_nested_dict_keys_preserved_after_masking(self):
- """Keys inside a sensitive dict are preserved, values are masked."""
- helper = HiveHelper()
- result = helper.sanitize_payload(
- {
- "authenticationresult": {
- "AccessToken": "long-secret-token-value",
- "ExpiresIn": 3600,
- }
- }
- )
- inner = result["authenticationresult"]
- assert "AccessToken" in inner
- assert "ExpiresIn" in inner
- # ExpiresIn is an int, _mask returns it as-is
- assert inner["ExpiresIn"] == 3600
-
-
-class TestHiveHelperSanitizeListNode:
- """hive_helper.py line 359: list value under a non-sensitive key calls _walk(list)."""
-
- def test_list_under_non_sensitive_key_is_walked(self):
- """A list value under a non-sensitive key hits the isinstance(node, list) branch."""
- helper = HiveHelper()
- result = helper.sanitize_payload({"devices": ["device-a", "device-b"]})
- # 'devices' is not a sensitive key → _walk called for the list
- # _walk for a list returns [_walk(item) for item in node]
- # Each string item: _walk(str) → str (falls through to return node)
- assert result == {"devices": ["device-a", "device-b"]}
-
- def test_list_containing_dicts_is_walked_recursively(self):
- """A list of dicts under a non-sensitive key is recursively processed."""
- helper = HiveHelper()
- result = helper.sanitize_payload(
- {
- "items": [
- {"token": "abc", "name": "device1"},
- {"token": "xyz", "name": "device2"},
- ]
- }
- )
- # 'items' is not sensitive → _walk called for the list
- # Each dict in the list is processed by _walk
- # 'token' IS sensitive → masked in each sub-dict
- assert result["items"][0]["name"] == "device1"
- assert result["items"][0]["token"] != "abc"
- assert result["items"][1]["name"] == "device2"
- assert result["items"][1]["token"] != "xyz"
-
-
-class TestHeatingGetStateExceptionCaught:
- """heating.py lines 206-207: except (KeyError, TypeError) handler is reached."""
-
- async def test_key_error_in_get_current_temperature_is_caught(self):
- """KeyError raised by get_current_temperature is caught, final stays None."""
- climate = _make_climate(
- {"heat-1": {"state": {"mode": "MANUAL", "target": 20.0}, "props": {}}}
- )
- d = _make_device()
- with patch.object(
- climate, "get_current_temperature", new_callable=AsyncMock
- ) as mock_t:
- mock_t.side_effect = KeyError("missing_key")
- result = await climate.get_state(d)
- assert result is None
-
- async def test_type_error_in_get_target_temperature_is_caught(self):
- """TypeError raised by get_target_temperature is caught, final stays None."""
- climate = _make_climate(
- {"heat-1": {"state": {"mode": "MANUAL", "target": 20.0}, "props": {}}}
- )
- d = _make_device()
- with patch.object(
- climate, "get_current_temperature", new_callable=AsyncMock
- ) as mock_cur:
- mock_cur.return_value = 19.0
- with patch.object(
- climate, "get_target_temperature", new_callable=AsyncMock
- ) as mock_tgt:
- mock_tgt.side_effect = TypeError("bad type")
- result = await climate.get_state(d)
- assert result is None
diff --git a/tests/unit/test_sensor_extended.py b/tests/unit/test_sensor.py
similarity index 52%
rename from tests/unit/test_sensor_extended.py
rename to tests/unit/test_sensor.py
index 68122f3..3bad1b1 100644
--- a/tests/unit/test_sensor_extended.py
+++ b/tests/unit/test_sensor.py
@@ -2,7 +2,7 @@
# pylint: disable=protected-access
-from unittest.mock import AsyncMock, MagicMock
+from unittest.mock import AsyncMock, MagicMock, patch
from apyhiveapi.devices.sensor import Sensor
from apyhiveapi.helper.hivedataclasses import Device, SessionConfig
@@ -146,6 +146,39 @@ async def test_contact_sensor_in_hive_types_sets_status(self):
assert "state" in result.status
session.attr.state_attributes.assert_awaited_once()
+ async def test_contact_sensor_uses_device_id_not_hive_id_for_props(self):
+ """HIVE_TYPES['Sensor'] branch must look up data.devices by device_id.
+
+ Before the fix, line 160 used hive_id; data was always {} so
+ device.parent_device was always None even when the device existed.
+ """
+ hive_id = "prod-abc"
+ device_id = "dev-xyz" # deliberately different from hive_id
+
+ products = {} # contactsensor is NOT in products
+ devices = {
+ device_id: {
+ "props": {"online": True, "signal": -70},
+ "parent": "hub-parent-id",
+ }
+ }
+ session = _make_session(products=products, devices=devices)
+ session.attr.online_offline = AsyncMock(return_value=True)
+
+ device = _make_device(
+ hive_id=hive_id, device_id=device_id, hive_type="contactsensor"
+ )
+ device.device_data = {"online": True}
+
+ sensor = Sensor(session)
+ with patch.object(sensor, "get_state", new=AsyncMock(return_value="CLOSED")):
+ result = await sensor.get_sensor(device)
+
+ assert result is not None
+ assert device.parent_device == "hub-parent-id", (
+ "parent_device must come from data.devices[device_id], not hive_id lookup"
+ )
+
class TestGetState:
"""Tests for HiveSensor.get_state covering the motionsensor branch (lines 37-42)."""
@@ -170,3 +203,90 @@ async def test_motionsensor_returns_motion_status(self):
state = await sensor.get_state(device)
assert state is True
+
+
+# ===========================================================================
+# Migrated from test_remaining_branches.py
+# ===========================================================================
+
+
+class TestSensorGetStateKeyError:
+ """Lines 37->42: KeyError in HiveSensor.get_state."""
+
+ async def test_get_state_missing_type_key_returns_none(self):
+ """Product with no 'type' key causes KeyError → final stays None."""
+ session = _make_session({"sens-1": {}})
+ sensor = Sensor(session=session)
+ d = _make_device(hive_id="sens-1", device_id="dev-1", hive_type="contactsensor")
+ result = await sensor.get_state(d)
+ assert result is None
+
+ async def test_get_state_missing_props_key_returns_none(self):
+ """contactsensor product without 'props' causes KeyError → None."""
+ session = _make_session({"sens-1": {"type": "contactsensor"}})
+ sensor = Sensor(session=session)
+ d = _make_device(hive_id="sens-1", device_id="dev-1", hive_type="contactsensor")
+ result = await sensor.get_state(d)
+ assert result is None
+
+
+class TestSensorGetStateUnknownType:
+ """Lines 37->42: data['type'] is neither contactsensor nor motionsensor."""
+
+ async def test_unknown_type_returns_none(self):
+ """Product with type 'hub' skips both if/elif → final stays None."""
+ session = _make_session({"sens-1": {"type": "hub", "props": {}}})
+ sensor = Sensor(session=session)
+ d = _make_device(hive_id="sens-1", device_id="dev-1", hive_type="contactsensor")
+ result = await sensor.get_state(d)
+ assert result is None
+
+
+class TestSensorGetSensorCacheMiss:
+ """Lines 92->98: should_use_cached_data=True but cached is None."""
+
+ async def test_cached_none_falls_through(self):
+ session = _make_session(
+ products={"sens-1": {"type": "contactsensor", "props": {"status": "OPEN"}}},
+ devices={"dev-1": {"props": {"online": True}, "type": "contactsensor"}},
+ )
+ session.should_use_cached_data.return_value = True
+ session.get_cached_device.return_value = None
+ sensor = Sensor(session=session)
+ d = _make_device(hive_id="sens-1", device_id="dev-1", hive_type="contactsensor")
+ result = await sensor.get_sensor(d)
+ assert result is not None
+ session.attr.online_offline.assert_called_once()
+
+
+class TestSensorGetSensorNoDataFallthrough:
+ """Lines 119->122: device_id not in devices AND hive_id not in products."""
+
+ async def test_neither_match_continues_with_empty_data(self):
+ """data stays empty dict when neither lookup succeeds."""
+ session = _make_session(products={}, devices={})
+ sensor = Sensor(session=session)
+ d = _make_device(
+ hive_id="unknown-hive",
+ device_id="unknown-dev",
+ hive_type="contactsensor",
+ )
+ result = await sensor.get_sensor(d)
+ # Should not raise; result will be the device (set_cached_device returns it)
+ assert result is not None
+
+
+class TestSensorGetSensorUnknownHiveType:
+ """Lines 135->146: hive_type not in sensor_commands and not in HIVE_TYPES['Sensor']."""
+
+ async def test_hive_type_not_in_either_dict_skips_both_branches(self):
+ """activeplug is neither in sensor_commands nor HIVE_TYPES['Sensor']."""
+ session = _make_session(
+ devices={"dev-1": {"props": {"online": True}, "type": "activeplug"}}
+ )
+ sensor = Sensor(session=session)
+ d = _make_device(hive_id="dev-1", device_id="dev-1", hive_type="activeplug")
+ d.device_data = {"online": True}
+ result = await sensor.get_sensor(d)
+ # Neither branch sets device.status; device returned as-is via set_cached_device
+ assert result is not None
diff --git a/tests/unit/test_session_auth_extended.py b/tests/unit/test_session_auth.py
similarity index 52%
rename from tests/unit/test_session_auth_extended.py
rename to tests/unit/test_session_auth.py
index e688096..87e9967 100644
--- a/tests/unit/test_session_auth_extended.py
+++ b/tests/unit/test_session_auth.py
@@ -93,6 +93,84 @@ async def test_auth_result_with_update_expiry_true_sets_token_created(self):
await s.update_tokens(AUTH_RESULT, update_expiry_time=True)
assert s.tokens.token_created > before
+ async def test_auth_result_missing_id_token_does_not_raise(self):
+ """AuthenticationResult with only AccessToken (e.g. file mode) must not crash."""
+ s = _make_stub()
+ payload = {"AuthenticationResult": {"AccessToken": "only-access"}}
+ await s.update_tokens(payload)
+ assert s.tokens.token_data["accessToken"] == "only-access"
+ # IdToken absent — the session token must not have been overwritten
+ assert s.tokens.token_data["token"] == ""
+
+ async def test_flat_dict_missing_access_token_does_not_raise(self):
+ """A flat token dict without accessToken must not crash."""
+ s = _make_stub()
+ flat = {"token": "t-only"}
+ await s.update_tokens(flat)
+ assert s.tokens.token_data["token"] == "t-only"
+ assert s.tokens.token_data["accessToken"] == ""
+
+ async def test_auth_result_missing_access_token_does_not_raise(self):
+ """AuthenticationResult with only IdToken must not crash."""
+ s = _make_stub()
+ payload = {"AuthenticationResult": {"IdToken": "only-id"}}
+ await s.update_tokens(payload)
+ assert s.tokens.token_data["token"] == "only-id"
+ assert s.tokens.token_data["accessToken"] == ""
+
+
+# ---------------------------------------------------------------------------
+# _retry_with_backoff — re-raise semantics
+# ---------------------------------------------------------------------------
+
+
+class _NeedsArgsError(Exception):
+ """Exception type that cannot be constructed without arguments."""
+
+ def __init__(self, first, second):
+ super().__init__(f"{first}/{second}")
+
+
+class TestRetryWithBackoffReraise:
+ """Without reraise_as, the original exception instance must propagate."""
+
+ async def test_original_exception_instance_propagates(self):
+ """The last caught error is re-raised as-is, not re-instantiated."""
+ s = _make_stub()
+ original = _NeedsArgsError("a", "b")
+
+ async def _always_fail():
+ raise original
+
+ with pytest.raises(_NeedsArgsError) as excinfo:
+ await s._retry_with_backoff(_always_fail, delays=(0,))
+
+ assert excinfo.value is original
+
+ async def test_empty_delays_raises_runtime_error(self):
+ """With no attempts configured the defensive fallback raises RuntimeError."""
+ s = _make_stub()
+
+ async def _never_called():
+ raise AssertionError("should not run")
+
+ with pytest.raises(RuntimeError, match="exhausted"):
+ await s._retry_with_backoff(_never_called, delays=())
+
+ async def test_reraise_as_still_translates_exception_type(self):
+ """When reraise_as is given the error is translated with chaining."""
+ s = _make_stub()
+
+ async def _always_fail():
+ raise ValueError("boom")
+
+ with pytest.raises(HiveReauthRequired) as excinfo:
+ await s._retry_with_backoff(
+ _always_fail, delays=(0,), reraise_as=HiveReauthRequired
+ )
+
+ assert isinstance(excinfo.value.__cause__, ValueError)
+
# ---------------------------------------------------------------------------
# _handle_device_login_challenge — extra branch
@@ -290,3 +368,193 @@ async def test_force_refresh_enters_lock_even_when_token_is_fresh(self):
s.auth.refresh_token.return_value = AUTH_RESULT
await s.hive_refresh_tokens(force_refresh=True)
s.auth.refresh_token.assert_called_once()
+
+
+# ---------------------------------------------------------------------------
+# update_tokens — elif "token" branch missing token_created and bare refreshToken
+# ---------------------------------------------------------------------------
+
+
+class TestUpdateTokensTokenBranch:
+ """update_tokens must set token_created and guard missing refreshToken in elif branch."""
+
+ async def test_token_branch_sets_token_created(self):
+ """elif 'token' branch must update token_created (was missing, stayed datetime.min)."""
+ s = _make_stub()
+ await s.update_tokens(
+ {"token": "id-tok", "refreshToken": "ref-tok", "accessToken": "acc-tok"}
+ )
+ assert s.tokens.token_created > datetime.min
+
+ async def test_token_branch_updates_token_data(self):
+ """elif 'token' branch stores all three token values."""
+ s = _make_stub()
+ await s.update_tokens(
+ {"token": "id", "refreshToken": "ref", "accessToken": "acc"}
+ )
+ assert s.tokens.token_data["token"] == "id"
+ assert s.tokens.token_data["refreshToken"] == "ref"
+ assert s.tokens.token_data["accessToken"] == "acc"
+
+ async def test_token_branch_missing_refresh_token_does_not_crash(self):
+ """elif 'token' branch without refreshToken key must not raise KeyError."""
+ s = _make_stub()
+ await s.update_tokens({"token": "id", "accessToken": "acc"})
+ assert s.tokens.token_data["token"] == "id"
+
+ async def test_token_branch_update_expiry_false_does_not_update_token_created(self):
+ """update_expiry_time=False skips the token_created assignment in elif 'token' branch."""
+
+ s = _make_stub()
+ original_created = s.tokens.token_created
+ await s.update_tokens(
+ {"token": "id", "accessToken": "acc"},
+ update_expiry_time=False,
+ )
+ assert s.tokens.token_created == original_created
+
+
+# ---------------------------------------------------------------------------
+# hive_refresh_tokens — bare refreshToken access raises KeyError when missing
+# ---------------------------------------------------------------------------
+
+
+class TestHiveRefreshTokensMissingRefreshToken:
+ """hive_refresh_tokens must not crash when token_data has no refreshToken."""
+
+ async def test_missing_refresh_token_does_not_raise_key_error(self):
+ """hive_refresh_tokens without refreshToken in token_data must not crash."""
+ s = _make_stub()
+ s.tokens.token_data = {"token": "id", "accessToken": "acc"}
+ s.tokens.token_created = datetime.now() - timedelta(hours=2)
+ s.tokens.token_expiry = timedelta(hours=1)
+ s.auth.refresh_token.return_value = None
+ result = await s.hive_refresh_tokens()
+ assert result is None
+
+
+# ===========================================================================
+# Migrated from test_remaining_branches.py
+# ===========================================================================
+
+
+class TestRetryWithBackoffNonZeroDelay:
+ """Line 66: asyncio.sleep called when delay > 0."""
+
+ async def test_non_zero_delay_is_awaited_but_succeeds(self):
+ """A non-zero delay entry causes asyncio.sleep to be called; factory still runs."""
+ from unittest.mock import patch
+
+ s = _make_stub()
+ calls = []
+
+ async def factory():
+ calls.append(1)
+ return "ok"
+
+ with patch(
+ "apyhiveapi.session.auth.asyncio.sleep", new_callable=AsyncMock
+ ) as mock_sleep:
+ result = await s._retry_with_backoff(factory, delays=(5,))
+ assert result == "ok"
+ mock_sleep.assert_called_once_with(5)
+ assert len(calls) == 1
+
+ async def test_zero_delay_does_not_call_sleep(self):
+ """A zero delay skips asyncio.sleep."""
+ from unittest.mock import patch
+
+ s = _make_stub()
+
+ async def factory():
+ return "done"
+
+ with patch(
+ "apyhiveapi.session.auth.asyncio.sleep", new_callable=AsyncMock
+ ) as mock_sleep:
+ result = await s._retry_with_backoff(factory, delays=(0,))
+ assert result == "done"
+ mock_sleep.assert_not_called()
+
+
+class TestUpdateTokensFlatDictWithExpiresIn:
+ """Lines 100->106: flat token dict with ExpiresIn sets token_expiry."""
+
+ async def test_flat_dict_with_expires_in_sets_token_expiry(self):
+ """Flat token dict containing ExpiresIn updates tokens.token_expiry."""
+ s = _make_stub()
+ flat = {
+ "token": "t",
+ "refreshToken": "r",
+ "accessToken": "a",
+ "ExpiresIn": 1800,
+ }
+ await s.update_tokens(flat)
+ assert s.tokens.token_expiry == timedelta(seconds=1800)
+
+ async def test_flat_dict_tokens_are_stored(self):
+ """All token values from flat dict are written to token_data."""
+ s = _make_stub()
+ flat = {"token": "my-id", "refreshToken": "my-rt", "accessToken": "my-at"}
+ await s.update_tokens(flat)
+ assert s.tokens.token_data["token"] == "my-id"
+ assert s.tokens.token_data["refreshToken"] == "my-rt"
+ assert s.tokens.token_data["accessToken"] == "my-at"
+
+
+class TestLoginApiError:
+ """Lines 160-162: HiveApiError in login() is logged and re-raised."""
+
+ async def test_login_api_error_reraises(self):
+ """HiveApiError raised by auth.login propagates unchanged to the caller."""
+ s = _make_stub()
+ s.auth.login.side_effect = HiveApiError()
+ with pytest.raises(HiveApiError):
+ await s.login()
+
+
+class TestHiveRefreshTokensNoAuthResult:
+ """Lines 341->373: refresh returns a result but without AuthenticationResult."""
+
+ async def test_result_without_auth_result_does_not_update_tokens(self):
+ """When refresh_token returns a dict with no AuthenticationResult, tokens stay unchanged."""
+ s = _make_stub()
+ s.tokens.token_created = datetime.now() - timedelta(hours=2)
+ s.tokens.token_expiry = timedelta(hours=1)
+ # Return something truthy but without AuthenticationResult
+ s.auth.refresh_token.return_value = {"SomeOtherKey": "value"}
+ result = await s.hive_refresh_tokens()
+ # Tokens must not have been updated
+ assert s.tokens.token_data["token"] == ""
+ assert s.tokens.token_data["accessToken"] == ""
+ # result is what refresh_token returned
+ assert result == {"SomeOtherKey": "value"}
+
+ async def test_none_refresh_result_does_not_update_tokens(self):
+ """When refresh_token returns None, tokens are left unchanged."""
+ s = _make_stub()
+ s.tokens.token_created = datetime.now() - timedelta(hours=2)
+ s.tokens.token_expiry = timedelta(hours=1)
+ s.auth.refresh_token.return_value = None
+ await s.hive_refresh_tokens()
+ assert s.tokens.token_data["token"] == ""
+
+
+class TestUpdateTokensUnknownKey:
+ """session/auth.py 100->106: tokens dict has neither AuthenticationResult nor token."""
+
+ async def test_unknown_key_does_not_raise_and_does_not_update_tokens(self):
+ """When neither expected key is present, data stays {}, ExpiresIn check skips."""
+ s = _make_stub()
+ original_token = s.tokens.token_data["token"]
+ # Pass a dict that is neither the AuthResult form nor the flat-token form
+ await s.update_tokens({"some_other_key": "some_value"})
+ # Tokens must be unchanged
+ assert s.tokens.token_data["token"] == original_token
+
+ async def test_unknown_key_does_not_set_token_expiry(self):
+ """ExpiresIn check at line 106 skips when data is {} (no match in either branch)."""
+ s = _make_stub()
+ original_expiry = s.tokens.token_expiry
+ await s.update_tokens({"random_key": "random_value"})
+ assert s.tokens.token_expiry == original_expiry
diff --git a/tests/unit/test_session_close.py b/tests/unit/test_session_close.py
index 49edabb..b66df73 100644
--- a/tests/unit/test_session_close.py
+++ b/tests/unit/test_session_close.py
@@ -9,30 +9,62 @@ class TestHiveSessionClose:
"""Branch coverage for HiveSession.close() (line 79)."""
async def test_close_calls_websession_close_when_not_already_closed(self):
- """close() calls websession.close() when websession is open (closed=False).
-
- Covers the True branch of 'if not self.api.websession.closed'.
- """
+ """close() calls websession.close() when websession is open (closed=False)."""
session = object.__new__(HiveSession)
session.api = MagicMock()
session.api.websession.closed = False
session.api.websession.close = AsyncMock()
+ session._owns_websession = True
await session.close()
session.api.websession.close.assert_called_once()
- async def test_close_skips_websession_close_when_already_closed(self):
- """close() does NOT call websession.close() when websession is already closed.
+ async def test_close_with_no_websession_does_not_raise(self):
+ """close() is a no-op when the lazy websession was never created."""
+ session = object.__new__(HiveSession)
+ session.api = MagicMock()
+ session.api.websession = None
+ session._owns_websession = True
- Covers branch 79->exit: the 'if not closed' condition is False, so the
- body is skipped entirely.
- """
+ await session.close()
+
+ async def test_close_skips_websession_close_when_already_closed(self):
+ """close() does NOT call websession.close() when websession is already closed."""
session = object.__new__(HiveSession)
session.api = MagicMock()
session.api.websession.closed = True
session.api.websession.close = AsyncMock()
+ session._owns_websession = True
await session.close()
session.api.websession.close.assert_not_called()
+
+
+class TestHiveSessionCloseOwnership:
+ """close() must not close a caller-provided websession."""
+
+ async def test_close_does_not_close_caller_provided_websession(self):
+ """When _owns_websession=False, close() must not close the websession."""
+ session = object.__new__(HiveSession)
+ session.api = MagicMock()
+ session.api.websession.closed = False
+ session.api.websession.close = AsyncMock()
+ session._owns_websession = False
+
+ await session.close()
+
+ session.api.websession.close.assert_not_called()
+
+ async def test_close_closes_owned_websession(self):
+ """When _owns_websession=True, close() closes the websession as normal."""
+ session = object.__new__(HiveSession)
+ session.api = MagicMock()
+ session.api.websession.closed = False
+ session.api.websession.close = AsyncMock()
+ session._owns_websession = True
+
+ await session.close()
+
+ session.api.websession.close.assert_called_once()
diff --git a/tests/unit/test_session_discovery.py b/tests/unit/test_session_discovery.py
new file mode 100644
index 0000000..b889fb7
--- /dev/null
+++ b/tests/unit/test_session_discovery.py
@@ -0,0 +1,602 @@
+"""Extended branch-coverage tests for DiscoveryMixin.start_session and create_devices."""
+
+# pylint: disable=attribute-defined-outside-init,too-few-public-methods,protected-access
+from datetime import datetime
+from unittest.mock import AsyncMock, MagicMock, patch
+
+import pytest
+from apyhiveapi.helper.hive_exceptions import (
+ HiveReauthRequired,
+ HiveUnknownConfiguration,
+)
+from apyhiveapi.helper.hivedataclasses import EntityConfig, SessionConfig, SessionTokens
+from apyhiveapi.helper.map import Map
+from apyhiveapi.session.discovery import DiscoveryMixin
+
+_POPULATED_PRODUCTS = {
+ "prod-1": {"id": "prod-1", "type": "heating", "state": {"name": "Hall"}}
+}
+_POPULATED_DEVICES = {"dev-1": {"id": "dev-1", "type": "hub", "state": {"name": "Hub"}}}
+
+
+def _make_stub(*, has_data=True):
+ """Return a DiscoveryMixin stub wired for start_session tests (create_devices mocked)."""
+
+ class StubDiscovery(DiscoveryMixin):
+ """Concrete subclass used only for testing."""
+
+ s = StubDiscovery()
+ s.config = SessionConfig()
+ s.data = Map(
+ {
+ "products": _POPULATED_PRODUCTS if has_data else {},
+ "devices": _POPULATED_DEVICES if has_data else {},
+ "actions": {},
+ "minMax": {},
+ "user": {},
+ }
+ )
+ s.helper = MagicMock()
+ s.helper.sanitize_payload = MagicMock(return_value={})
+ s.auth = MagicMock()
+ s.tokens = SessionTokens()
+ s.hub_id = None
+ s.device_list = {
+ "parent": [],
+ "binary_sensor": [],
+ "climate": [],
+ "light": [],
+ "sensor": [],
+ "switch": [],
+ "water_heater": [],
+ }
+ s.get_devices = AsyncMock(return_value=True)
+ s.update_tokens = AsyncMock()
+ s.create_devices = AsyncMock(return_value=s.device_list)
+ return s
+
+
+def _make_create_stub():
+ """Return a DiscoveryMixin stub for testing create_devices directly (not mocked)."""
+
+ class StubDiscovery(DiscoveryMixin):
+ """Concrete subclass used only for testing."""
+
+ s = StubDiscovery()
+ s.config = SessionConfig()
+ s.data = Map(
+ {
+ "products": {},
+ "devices": {},
+ "actions": {},
+ "minMax": {},
+ "user": {"temperatureUnit": "C"},
+ }
+ )
+ s.helper = MagicMock()
+ s.helper.get_device_data = MagicMock(
+ return_value={
+ "id": "dev-1",
+ "state": {"name": "Test Device"},
+ "props": {"online": True},
+ }
+ )
+ s.hub_id = None
+ s.device_list = {
+ "parent": [],
+ "binary_sensor": [],
+ "climate": [],
+ "light": [],
+ "sensor": [],
+ "switch": [],
+ "water_heater": [],
+ }
+ return s
+
+
+# ---------------------------------------------------------------------------
+# start_session — config branches
+# ---------------------------------------------------------------------------
+
+
+class TestStartSessionExtended:
+ """Tests for start_session config-processing branches."""
+
+ async def test_with_tokens_config_calls_update_tokens(self):
+ """Passing 'tokens' in non-file config calls update_tokens(tokens, False)."""
+ s = _make_stub()
+ s.config.file = False
+ tokens = {"token": "t", "accessToken": "a", "refreshToken": "r"}
+ await s.start_session({"tokens": tokens})
+ s.update_tokens.assert_called_once_with(tokens, False)
+
+ async def test_with_username_config_sets_auth_username(self):
+ """Passing 'username' alongside 'tokens' in non-file config sets auth.username."""
+ s = _make_stub()
+ s.config.file = False
+ tokens = {"token": "t", "accessToken": "a", "refreshToken": "r"}
+ await s.start_session({"tokens": tokens, "username": "user@test.com"})
+ assert s.auth.username == "user@test.com"
+
+ async def test_with_password_config_sets_auth_password(self):
+ """Passing 'password' alongside 'tokens' in non-file config sets auth.password."""
+ s = _make_stub()
+ s.config.file = False
+ tokens = {"token": "t", "accessToken": "a", "refreshToken": "r"}
+ await s.start_session(
+ {"tokens": tokens, "password": "secret"} # pragma: allowlist secret
+ )
+ assert s.auth.password == "secret" # pragma: allowlist secret
+
+ async def test_with_device_data_3_items_sets_auth_keys(self):
+ """3-item device_data sets device_group_key, device_key, device_password on auth."""
+ s = _make_stub()
+ s.config.file = False
+ await s.start_session(
+ {
+ "tokens": {},
+ "device_data": ["grp-key", "dev-key", "dev-pass"],
+ }
+ )
+ assert s.auth.device_group_key == "grp-key"
+ assert s.auth.device_key == "dev-key"
+ assert s.auth.device_password == "dev-pass" # pragma: allowlist secret
+
+ async def test_with_device_data_too_short_raises_unknown_configuration(self):
+ """device_data with fewer than 3 items raises HiveUnknownConfiguration."""
+ s = _make_stub()
+ s.config.file = False
+ with pytest.raises(HiveUnknownConfiguration):
+ await s.start_session(
+ {
+ "tokens": {},
+ "device_data": ["grp-key", "dev-key"],
+ }
+ )
+
+ async def test_with_device_data_4_items_sets_token_created(self):
+ """4-item device_data with a token_created timestamp sets tokens.token_created."""
+ s = _make_stub()
+ s.config.file = False
+ created_ts = datetime(2024, 1, 15, 10, 30, 0)
+ await s.start_session(
+ {
+ "tokens": {},
+ "device_data": ["grp-key", "dev-key", "dev-pass", created_ts],
+ }
+ )
+ assert s.tokens.token_created == created_ts
+
+ async def test_with_device_data_4_items_none_token_created_not_set(self):
+ """4-item device_data where token_created is None — does not overwrite token_created."""
+ s = _make_stub()
+ s.config.file = False
+ original_created = s.tokens.token_created
+ await s.start_session(
+ {
+ "tokens": {},
+ "device_data": ["grp-key", "dev-key", "dev-pass", None],
+ }
+ )
+ assert s.tokens.token_created == original_created
+
+ async def test_no_tokens_and_not_file_raises_unknown_configuration(self):
+ """Non-file config without 'tokens' raises HiveUnknownConfiguration."""
+ s = _make_stub()
+ s.config.file = False
+ with pytest.raises(HiveUnknownConfiguration):
+ await s.start_session({"username": "user@test.com"})
+
+ async def test_empty_devices_after_get_devices_raises_unknown_configuration(self):
+ """start_session raises HiveUnknownConfiguration when data.devices is empty post-poll."""
+ s = _make_stub(has_data=False)
+ s.config.file = True
+ with pytest.raises(HiveUnknownConfiguration):
+ await s.start_session({})
+
+ async def test_none_config_defaults_to_empty_dict(self):
+ """start_session(None) is treated as start_session({}) — set file mode separately."""
+ s = _make_stub()
+ s.config.file = True
+ # Should not raise; equivalent to passing {}
+ result = await s.start_session(None)
+ assert result is s.device_list
+
+ async def test_file_mode_username_skips_token_branch(self):
+ """'use@file.com' activates file mode so 'tokens' branch is skipped."""
+ s = _make_stub()
+ s.config.file = False
+ # Even if tokens is present, file mode skips the update_tokens call
+ await s.start_session({"username": "use@file.com", "tokens": {}})
+ s.update_tokens.assert_not_called()
+
+
+# ---------------------------------------------------------------------------
+# create_devices — device processing
+# ---------------------------------------------------------------------------
+
+
+class TestCreateDevicesExtended:
+ """Tests for create_devices branches not covered by the main test files."""
+
+ async def test_no_hub_device_hub_id_stays_none(self):
+ """Devices list with no 'hub' type leaves hub_id as None (else branch of for-loop)."""
+ s = _make_create_stub()
+ s.data["devices"] = {
+ "trv-1": {"id": "trv-1", "type": "trv", "state": {"name": "TRV"}}
+ }
+ s.data["products"] = {}
+ await s.create_devices()
+ assert s.hub_id is None
+
+ async def test_hub_device_sets_hub_id(self):
+ """Devices list with a 'hub' type sets hub_id to that device's ID."""
+ s = _make_create_stub()
+ s.data["devices"] = {
+ "hub-42": {"id": "hub-42", "type": "hub", "state": {"name": "My Hub"}}
+ }
+ await s.create_devices()
+ assert s.hub_id == "hub-42"
+
+ async def test_product_with_error_key_is_skipped(self):
+ """Products with an 'error' key are silently skipped."""
+ s = _make_create_stub()
+ s.data["products"] = {
+ "bad": {"id": "bad", "type": "heating", "error": "device not found"}
+ }
+ result = await s.create_devices()
+ assert result["climate"] == []
+
+ async def test_non_heating_group_product_skipped(self):
+ """isGroup=True products of non-heating type are not added to any list."""
+ s = _make_create_stub()
+ s.data["products"] = {
+ "grp-1": {
+ "id": "grp-1",
+ "type": "activeplug",
+ "isGroup": True,
+ "state": {"name": "Plug Group"},
+ }
+ }
+ result = await s.create_devices()
+ assert result["switch"] == []
+
+ async def test_heating_group_product_not_skipped(self):
+ """isGroup=True products of heating type are processed and added."""
+ s = _make_create_stub()
+ s.data["products"] = {
+ "h-grp": {
+ "id": "h-grp",
+ "type": "heating",
+ "isGroup": True,
+ "state": {"name": "Heating Zone"},
+ }
+ }
+ result = await s.create_devices()
+ assert len(result["climate"]) == 1
+
+ async def test_multiple_devices_all_processed(self):
+ """Multiple devices in the device list are all processed."""
+ s = _make_create_stub()
+ s.data["devices"] = {
+ "hub-1": {"id": "hub-1", "type": "hub", "state": {"name": "Hub"}},
+ "trv-1": {"id": "trv-1", "type": "trv", "state": {"name": "TRV"}},
+ }
+ s.data["products"] = {}
+ await s.create_devices()
+ # Hub is found; hub_id is set to the hub device
+ assert s.hub_id == "hub-1"
+
+ async def test_action_processed_as_switch(self):
+ """Actions in data.actions are added to device_list['switch']."""
+ s = _make_create_stub()
+ s.data["actions"] = {
+ "act-1": {"id": "act-1", "name": "Good Night", "type": "action"}
+ }
+ result = await s.create_devices()
+ assert len(result["switch"]) == 1
+ assert result["switch"][0].hive_type == "action"
+
+ async def test_returns_device_list_dict(self):
+ """create_devices always returns a dict with the expected HA entity keys."""
+ s = _make_create_stub()
+ result = await s.create_devices()
+ for key in (
+ "parent",
+ "binary_sensor",
+ "climate",
+ "light",
+ "sensor",
+ "switch",
+ "water_heater",
+ ):
+ assert key in result
+
+ async def test_product_with_error_and_valid_both_present_only_valid_added(self):
+ """Only products without 'error' are added when both types coexist."""
+ s = _make_create_stub()
+ s.data["products"] = {
+ "bad": {"id": "bad", "type": "heating", "error": "broken"},
+ "good": {"id": "good", "type": "heating", "state": {"name": "Hall"}},
+ }
+ result = await s.create_devices()
+ assert len(result["climate"]) == 1
+ assert result["climate"][0].hive_id == "good"
+
+
+# ---------------------------------------------------------------------------
+# start_session raises wrong exception for empty device data
+# ---------------------------------------------------------------------------
+
+
+class TestStartSessionWrongException:
+ """start_session must raise HiveUnknownConfiguration (not HiveReauthRequired) for empty data."""
+
+ async def test_empty_devices_raises_unknown_configuration(self):
+ """start_session raises HiveUnknownConfiguration when API returns no devices."""
+ s = _make_stub(has_data=False)
+ s.get_devices = AsyncMock()
+
+ with pytest.raises(HiveUnknownConfiguration):
+ await s.start_session({})
+
+ async def test_does_not_raise_reauth_for_empty_data(self):
+ """start_session must NOT raise HiveReauthRequired when device data is empty."""
+ s = _make_stub(has_data=False)
+ s.get_devices = AsyncMock()
+
+ with pytest.raises(Exception) as exc_info:
+ await s.start_session({})
+
+ assert not isinstance(exc_info.value, HiveReauthRequired), (
+ "HiveReauthRequired must not be raised for empty device list"
+ )
+
+
+# ---------------------------------------------------------------------------
+# create_devices — bare d["id"] and p["id"] crash when id key is absent
+# ---------------------------------------------------------------------------
+
+
+class TestBareIdAccess:
+ """create_devices must use .get('id', fallback) instead of bare ['id'] access."""
+
+ async def test_battery_device_without_id_does_not_crash(self):
+ """Device with no 'id' key in battery-type must not raise KeyError."""
+ s = _make_create_stub()
+ s.data["devices"] = {
+ "trv-key": {
+ "type": "trv",
+ "state": {"name": "TRV"},
+ "props": {},
+ }
+ }
+ s.config.battery = set()
+ try:
+ await s.create_devices()
+ except KeyError as err:
+ pytest.fail(f"KeyError raised for missing 'id' in device: {err}")
+
+ async def test_mode_product_without_id_does_not_crash(self):
+ """Product with no 'id' key in mode-type must not raise KeyError."""
+ s = _make_create_stub()
+ s.data["products"] = {
+ "heating-key": {
+ "type": "heating",
+ "state": {"name": "Hall"},
+ }
+ }
+ s.config.mode = set()
+ try:
+ await s.create_devices()
+ except KeyError as err:
+ pytest.fail(f"KeyError raised for missing 'id' in product: {err}")
+
+
+# ===========================================================================
+# Migrated from test_remaining_branches.py
+# ===========================================================================
+
+
+class TestCreateDevicesEntityConfigKwargs:
+ """Lines 224->226, 226->228, 228->230: entity_config kwarg population in DEVICES loop."""
+
+ async def test_entity_config_with_all_fields_populates_kwargs(self):
+ """EntityConfig with ha_name, hive_type, and category all set → all kwargs passed."""
+ s = _make_create_stub()
+ s.data["devices"] = {
+ "dev-1": {
+ "id": "dev-1",
+ "type": "hub",
+ "state": {"name": "My Hub"},
+ "props": {},
+ }
+ }
+ entity_cfg = EntityConfig(
+ entity_type="binary_sensor",
+ ha_name="Hub Status",
+ hive_type="Connectivity",
+ category="diagnostic",
+ )
+ with patch("apyhiveapi.session.discovery.DEVICES", {"hub": [entity_cfg]}):
+ result = await s.create_devices()
+ assert len(result["binary_sensor"]) == 1
+ created = result["binary_sensor"][0]
+ assert created.hive_type == "Connectivity"
+ assert created.category == "diagnostic"
+
+ async def test_entity_config_empty_fields_does_not_add_to_kwargs(self):
+ """EntityConfig with empty ha_name and hive_type does not inject those keys."""
+ s = _make_create_stub()
+ s.data["devices"] = {
+ "dev-1": {
+ "id": "dev-1",
+ "type": "hub",
+ "state": {"name": "My Hub"},
+ "props": {},
+ }
+ }
+ entity_cfg = EntityConfig(
+ entity_type="binary_sensor",
+ ha_name="", # falsy — should not be added to kwargs
+ hive_type="", # falsy — should not be added to kwargs
+ category=None, # None — should not be added to kwargs
+ )
+ with patch("apyhiveapi.session.discovery.DEVICES", {"hub": [entity_cfg]}):
+ result = await s.create_devices()
+ # Should still process without error
+ assert isinstance(result, dict)
+
+
+class TestCreateDevicesDeviceAddListError:
+ """Lines 232-233: KeyError/TypeError from add_list in DEVICES loop is caught."""
+
+ async def test_add_list_keyerror_is_caught_not_raised(self):
+ """KeyError from add_list during device processing is logged, not propagated."""
+ s = _make_create_stub()
+ s.data["devices"] = {
+ "dev-1": {
+ "id": "dev-1",
+ "type": "hub",
+ "state": {"name": "My Hub"},
+ "props": {},
+ }
+ }
+ entity_cfg = EntityConfig(
+ entity_type="binary_sensor",
+ ha_name="Hub Status",
+ hive_type="Connectivity",
+ category="diagnostic",
+ )
+ with patch("apyhiveapi.session.discovery.DEVICES", {"hub": [entity_cfg]}):
+ with patch.object(s, "add_list", side_effect=KeyError("bad key")):
+ # Should complete without raising
+ result = await s.create_devices()
+ assert isinstance(result, dict)
+
+ async def test_add_list_typeerror_is_caught_not_raised(self):
+ """TypeError from add_list during device processing is caught."""
+ s = _make_create_stub()
+ s.data["devices"] = {
+ "dev-1": {
+ "id": "dev-1",
+ "type": "hub",
+ "state": {"name": "My Hub"},
+ "props": {},
+ }
+ }
+ entity_cfg = EntityConfig(
+ entity_type="binary_sensor",
+ ha_name="",
+ hive_type="",
+ category=None,
+ )
+ with patch("apyhiveapi.session.discovery.DEVICES", {"hub": [entity_cfg]}):
+ with patch.object(s, "add_list", side_effect=TypeError("bad type")):
+ result = await s.create_devices()
+ assert isinstance(result, dict)
+
+
+class TestCreateDevicesActionAddListError:
+ """Lines 258-259: KeyError/TypeError from add_list in actions loop is caught."""
+
+ async def test_action_add_list_keyerror_is_caught(self):
+ """KeyError from add_list when processing an action is logged, not propagated."""
+ s = _make_create_stub()
+ s.data["actions"] = {"act-1": {"id": "act-1", "name": "Good Night"}}
+ with patch.object(s, "add_list", side_effect=KeyError("missing")):
+ result = await s.create_devices()
+ assert isinstance(result, dict)
+
+ async def test_action_add_list_typeerror_is_caught(self):
+ """TypeError from add_list when processing an action is caught."""
+ s = _make_create_stub()
+ s.data["actions"] = {"act-1": {"id": "act-1", "name": "Wake Up"}}
+ with patch.object(s, "add_list", side_effect=TypeError("type error")):
+ result = await s.create_devices()
+ assert isinstance(result, dict)
+
+
+class TestCreateDevicesProductTemperatureUnit:
+ """Line 305: entity_config.temperature_unit is used when set and entity_type != 'climate'."""
+
+ async def test_entity_config_temperature_unit_passed_to_add_list(self):
+ """EntityConfig with temperature_unit set propagates that value as a kwarg."""
+ s = _make_create_stub()
+ s.data["products"] = {
+ "prod-1": {
+ "id": "prod-1",
+ "type": "heating",
+ "state": {"name": "Heating"},
+ "props": {},
+ }
+ }
+ # A non-climate entity with temperature_unit set triggers line 305
+ entity_cfg = EntityConfig(
+ entity_type="sensor",
+ ha_name="Temp Sensor",
+ hive_type="Current_Temperature",
+ category="diagnostic",
+ temperature_unit="F",
+ )
+ captured_kwargs = {}
+
+ original_add_list = s.add_list
+
+ def capturing_add_list(entity_type, data, **kwargs):
+ captured_kwargs.update(kwargs)
+ return original_add_list(entity_type, data, **kwargs)
+
+ with patch("apyhiveapi.session.discovery.PRODUCTS", {"heating": [entity_cfg]}):
+ with patch.object(s, "add_list", side_effect=capturing_add_list):
+ await s.create_devices()
+
+ assert captured_kwargs.get("temperature_unit") == "F"
+
+
+class TestCreateDevicesProductAddListAttributeError:
+ """Lines 308-309: NameError/AttributeError from add_list in products loop is caught."""
+
+ async def test_product_add_list_attribute_error_is_caught(self):
+ """AttributeError from add_list when processing a product is caught."""
+ s = _make_create_stub()
+ s.data["products"] = {
+ "prod-1": {
+ "id": "prod-1",
+ "type": "heating",
+ "state": {"name": "Heating"},
+ "props": {},
+ }
+ }
+ entity_cfg = EntityConfig(
+ entity_type="climate",
+ ha_name="",
+ hive_type="",
+ category=None,
+ )
+ with patch("apyhiveapi.session.discovery.PRODUCTS", {"heating": [entity_cfg]}):
+ with patch.object(s, "add_list", side_effect=AttributeError("attr error")):
+ result = await s.create_devices()
+ assert isinstance(result, dict)
+
+ async def test_product_add_list_name_error_is_caught(self):
+ """NameError from add_list when processing a product is caught."""
+ s = _make_create_stub()
+ s.data["products"] = {
+ "prod-1": {
+ "id": "prod-1",
+ "type": "heating",
+ "state": {"name": "Heating"},
+ "props": {},
+ }
+ }
+ entity_cfg = EntityConfig(
+ entity_type="climate",
+ ha_name="",
+ hive_type="",
+ category=None,
+ )
+ with patch("apyhiveapi.session.discovery.PRODUCTS", {"heating": [entity_cfg]}):
+ with patch.object(s, "add_list", side_effect=NameError("name error")):
+ result = await s.create_devices()
+ assert isinstance(result, dict)
diff --git a/tests/unit/test_session_discovery_extended.py b/tests/unit/test_session_discovery_extended.py
deleted file mode 100644
index 8bc4764..0000000
--- a/tests/unit/test_session_discovery_extended.py
+++ /dev/null
@@ -1,312 +0,0 @@
-"""Extended branch-coverage tests for DiscoveryMixin.start_session and create_devices."""
-
-# pylint: disable=attribute-defined-outside-init,too-few-public-methods,protected-access
-from datetime import datetime
-from unittest.mock import AsyncMock, MagicMock
-
-import pytest
-from apyhiveapi.helper.hive_exceptions import (
- HiveReauthRequired,
- HiveUnknownConfiguration,
-)
-from apyhiveapi.helper.hivedataclasses import SessionConfig, SessionTokens
-from apyhiveapi.helper.map import Map
-from apyhiveapi.session.discovery import DiscoveryMixin
-
-_POPULATED_PRODUCTS = {
- "prod-1": {"id": "prod-1", "type": "heating", "state": {"name": "Hall"}}
-}
-_POPULATED_DEVICES = {"dev-1": {"id": "dev-1", "type": "hub", "state": {"name": "Hub"}}}
-
-
-def _make_stub(*, has_data=True):
- """Return a DiscoveryMixin stub wired for start_session tests (create_devices mocked)."""
-
- class StubDiscovery(DiscoveryMixin):
- """Concrete subclass used only for testing."""
-
- s = StubDiscovery()
- s.config = SessionConfig()
- s.data = Map(
- {
- "products": _POPULATED_PRODUCTS if has_data else {},
- "devices": _POPULATED_DEVICES if has_data else {},
- "actions": {},
- "minMax": {},
- "user": {},
- }
- )
- s.helper = MagicMock()
- s.helper.sanitize_payload = MagicMock(return_value={})
- s.auth = MagicMock()
- s.tokens = SessionTokens()
- s.hub_id = None
- s.device_list = {
- "parent": [],
- "binary_sensor": [],
- "climate": [],
- "light": [],
- "sensor": [],
- "switch": [],
- "water_heater": [],
- }
- s.get_devices = AsyncMock(return_value=True)
- s.update_tokens = AsyncMock()
- s.create_devices = AsyncMock(return_value=s.device_list)
- return s
-
-
-def _make_create_stub():
- """Return a DiscoveryMixin stub for testing create_devices directly (not mocked)."""
-
- class StubDiscovery(DiscoveryMixin):
- """Concrete subclass used only for testing."""
-
- s = StubDiscovery()
- s.config = SessionConfig()
- s.data = Map(
- {
- "products": {},
- "devices": {},
- "actions": {},
- "minMax": {},
- "user": {"temperatureUnit": "C"},
- }
- )
- s.helper = MagicMock()
- s.helper.get_device_data = MagicMock(
- return_value={
- "id": "dev-1",
- "state": {"name": "Test Device"},
- "props": {"online": True},
- }
- )
- s.hub_id = None
- s.device_list = {
- "parent": [],
- "binary_sensor": [],
- "climate": [],
- "light": [],
- "sensor": [],
- "switch": [],
- "water_heater": [],
- }
- return s
-
-
-# ---------------------------------------------------------------------------
-# start_session — config branches
-# ---------------------------------------------------------------------------
-
-
-class TestStartSessionExtended:
- """Tests for start_session config-processing branches."""
-
- async def test_with_tokens_config_calls_update_tokens(self):
- """Passing 'tokens' in non-file config calls update_tokens(tokens, False)."""
- s = _make_stub()
- s.config.file = False
- tokens = {"token": "t", "accessToken": "a", "refreshToken": "r"}
- await s.start_session({"tokens": tokens})
- s.update_tokens.assert_called_once_with(tokens, False)
-
- async def test_with_username_config_sets_auth_username(self):
- """Passing 'username' alongside 'tokens' in non-file config sets auth.username."""
- s = _make_stub()
- s.config.file = False
- tokens = {"token": "t", "accessToken": "a", "refreshToken": "r"}
- await s.start_session({"tokens": tokens, "username": "user@test.com"})
- assert s.auth.username == "user@test.com"
-
- async def test_with_password_config_sets_auth_password(self):
- """Passing 'password' alongside 'tokens' in non-file config sets auth.password."""
- s = _make_stub()
- s.config.file = False
- tokens = {"token": "t", "accessToken": "a", "refreshToken": "r"}
- await s.start_session(
- {"tokens": tokens, "password": "secret"} # pragma: allowlist secret
- )
- assert s.auth.password == "secret" # pragma: allowlist secret
-
- async def test_with_device_data_3_items_sets_auth_keys(self):
- """3-item device_data sets device_group_key, device_key, device_password on auth."""
- s = _make_stub()
- s.config.file = False
- await s.start_session(
- {
- "tokens": {},
- "device_data": ["grp-key", "dev-key", "dev-pass"],
- }
- )
- assert s.auth.device_group_key == "grp-key"
- assert s.auth.device_key == "dev-key"
- assert s.auth.device_password == "dev-pass"
-
- async def test_with_device_data_4_items_sets_token_created(self):
- """4-item device_data with a token_created timestamp sets tokens.token_created."""
- s = _make_stub()
- s.config.file = False
- created_ts = datetime(2024, 1, 15, 10, 30, 0)
- await s.start_session(
- {
- "tokens": {},
- "device_data": ["grp-key", "dev-key", "dev-pass", created_ts],
- }
- )
- assert s.tokens.token_created == created_ts
-
- async def test_with_device_data_4_items_none_token_created_not_set(self):
- """4-item device_data where token_created is None — does not overwrite token_created."""
- s = _make_stub()
- s.config.file = False
- original_created = s.tokens.token_created
- await s.start_session(
- {
- "tokens": {},
- "device_data": ["grp-key", "dev-key", "dev-pass", None],
- }
- )
- assert s.tokens.token_created == original_created
-
- async def test_no_tokens_and_not_file_raises_unknown_configuration(self):
- """Non-file config without 'tokens' raises HiveUnknownConfiguration."""
- s = _make_stub()
- s.config.file = False
- with pytest.raises(HiveUnknownConfiguration):
- await s.start_session({"username": "user@test.com"})
-
- async def test_empty_devices_after_get_devices_raises_reauth(self):
- """start_session raises HiveReauthRequired when data.devices is empty post-poll."""
- s = _make_stub(has_data=False)
- s.config.file = True
- with pytest.raises(HiveReauthRequired):
- await s.start_session({})
-
- async def test_none_config_defaults_to_empty_dict(self):
- """start_session(None) is treated as start_session({}) — set file mode separately."""
- s = _make_stub()
- s.config.file = True
- # Should not raise; equivalent to passing {}
- result = await s.start_session(None)
- assert result is s.device_list
-
- async def test_file_mode_username_skips_token_branch(self):
- """'use@file.com' activates file mode so 'tokens' branch is skipped."""
- s = _make_stub()
- s.config.file = False
- # Even if tokens is present, file mode skips the update_tokens call
- await s.start_session({"username": "use@file.com", "tokens": {}})
- s.update_tokens.assert_not_called()
-
-
-# ---------------------------------------------------------------------------
-# create_devices — device processing
-# ---------------------------------------------------------------------------
-
-
-class TestCreateDevicesExtended:
- """Tests for create_devices branches not covered by the main test files."""
-
- async def test_no_hub_device_hub_id_stays_none(self):
- """Devices list with no 'hub' type leaves hub_id as None (else branch of for-loop)."""
- s = _make_create_stub()
- s.data["devices"] = {
- "trv-1": {"id": "trv-1", "type": "trv", "state": {"name": "TRV"}}
- }
- s.data["products"] = {}
- await s.create_devices()
- assert s.hub_id is None
-
- async def test_hub_device_sets_hub_id(self):
- """Devices list with a 'hub' type sets hub_id to that device's ID."""
- s = _make_create_stub()
- s.data["devices"] = {
- "hub-42": {"id": "hub-42", "type": "hub", "state": {"name": "My Hub"}}
- }
- await s.create_devices()
- assert s.hub_id == "hub-42"
-
- async def test_product_with_error_key_is_skipped(self):
- """Products with an 'error' key are silently skipped."""
- s = _make_create_stub()
- s.data["products"] = {
- "bad": {"id": "bad", "type": "heating", "error": "device not found"}
- }
- result = await s.create_devices()
- assert result["climate"] == []
-
- async def test_non_heating_group_product_skipped(self):
- """isGroup=True products of non-heating type are not added to any list."""
- s = _make_create_stub()
- s.data["products"] = {
- "grp-1": {
- "id": "grp-1",
- "type": "activeplug",
- "isGroup": True,
- "state": {"name": "Plug Group"},
- }
- }
- result = await s.create_devices()
- assert result["switch"] == []
-
- async def test_heating_group_product_not_skipped(self):
- """isGroup=True products of heating type are processed and added."""
- s = _make_create_stub()
- s.data["products"] = {
- "h-grp": {
- "id": "h-grp",
- "type": "heating",
- "isGroup": True,
- "state": {"name": "Heating Zone"},
- }
- }
- result = await s.create_devices()
- assert len(result["climate"]) == 1
-
- async def test_multiple_devices_all_processed(self):
- """Multiple devices in the device list are all processed."""
- s = _make_create_stub()
- s.data["devices"] = {
- "hub-1": {"id": "hub-1", "type": "hub", "state": {"name": "Hub"}},
- "trv-1": {"id": "trv-1", "type": "trv", "state": {"name": "TRV"}},
- }
- s.data["products"] = {}
- await s.create_devices()
- # Hub is found; hub_id is set to the hub device
- assert s.hub_id == "hub-1"
-
- async def test_action_processed_as_switch(self):
- """Actions in data.actions are added to device_list['switch']."""
- s = _make_create_stub()
- s.data["actions"] = {
- "act-1": {"id": "act-1", "name": "Good Night", "type": "action"}
- }
- result = await s.create_devices()
- assert len(result["switch"]) == 1
- assert result["switch"][0].hive_type == "action"
-
- async def test_returns_device_list_dict(self):
- """create_devices always returns a dict with the expected HA entity keys."""
- s = _make_create_stub()
- result = await s.create_devices()
- for key in (
- "parent",
- "binary_sensor",
- "climate",
- "light",
- "sensor",
- "switch",
- "water_heater",
- ):
- assert key in result
-
- async def test_product_with_error_and_valid_both_present_only_valid_added(self):
- """Only products without 'error' are added when both types coexist."""
- s = _make_create_stub()
- s.data["products"] = {
- "bad": {"id": "bad", "type": "heating", "error": "broken"},
- "good": {"id": "good", "type": "heating", "state": {"name": "Hall"}},
- }
- result = await s.create_devices()
- assert len(result["climate"]) == 1
- assert result["climate"][0].hive_id == "good"
diff --git a/tests/unit/test_session_get_devices.py b/tests/unit/test_session_get_devices.py
index 9042215..b284b1e 100644
--- a/tests/unit/test_session_get_devices.py
+++ b/tests/unit/test_session_get_devices.py
@@ -94,6 +94,39 @@ async def test_file_mode_does_not_call_api(self):
p.hive_refresh_tokens.assert_not_called()
+# ---------------------------------------------------------------------------
+# get_devices — malformed API responses
+# ---------------------------------------------------------------------------
+
+
+class TestGetDevicesMalformedResponse:
+ """Hostile/partial response shapes must not escape as raw KeyError."""
+
+ async def test_user_without_id_still_succeeds(self):
+ """A user object with no 'id' key must not crash the poll."""
+ p = _make_stub()
+ p.tokens = MagicMock()
+ p.api.get_all = AsyncMock(
+ return_value={"original": 200, "parsed": {"user": {"name": "x"}}}
+ )
+ result = await p.get_devices("No_ID")
+ assert result is True
+ assert p.config.user_id is None
+
+ async def test_product_without_id_returns_false(self):
+ """A product entry with no 'id' key fails the poll gracefully."""
+ p = _make_stub()
+ p.tokens = MagicMock()
+ p.api.get_all = AsyncMock(
+ return_value={
+ "original": 200,
+ "parsed": {"products": [{"type": "heating"}]},
+ }
+ )
+ result = await p.get_devices("No_ID")
+ assert result is False
+
+
# ---------------------------------------------------------------------------
# get_devices — tokens path
# ---------------------------------------------------------------------------