diff --git a/.coverage_full b/.coverage_full new file mode 100644 index 0000000..6aa8830 Binary files /dev/null and b/.coverage_full differ diff --git a/.secrets.baseline b/.secrets.baseline index 1beb0ff..9805f36 100644 --- a/.secrets.baseline +++ b/.secrets.baseline @@ -268,28 +268,28 @@ "filename": "src/api/hive_auth_async.py", "hashed_secret": "5dc786e32e3a0a4611daaf397721c6ef64cd71b0", "is_verified": false, - "line_number": 48 + "line_number": 49 }, { "type": "Secret Keyword", "filename": "src/api/hive_auth_async.py", "hashed_secret": "ac9f290e69cee683ba3c63461f1f3fa02765032a", "is_verified": false, - "line_number": 49 + "line_number": 50 }, { "type": "Secret Keyword", "filename": "src/api/hive_auth_async.py", "hashed_secret": "351b174ccf89601f6f4bd3f3970a4aba7d17c98e", "is_verified": false, - "line_number": 52 + "line_number": 53 }, { "type": "Secret Keyword", "filename": "src/api/hive_auth_async.py", "hashed_secret": "576956b5291ac38d04ef5f82cc974286a857f0b2", "is_verified": false, - "line_number": 109 + "line_number": 112 } ], "src/api/srp_crypto.py": [ @@ -298,112 +298,112 @@ "filename": "src/api/srp_crypto.py", "hashed_secret": "3e619ee0820ecf213c2f38c634e416b53defe3b0", "is_verified": false, - "line_number": 11 + "line_number": 10 }, { "type": "Hex High Entropy String", "filename": "src/api/srp_crypto.py", "hashed_secret": "b8e0d506d969f09a9af89ce89fd9759b72c63262", "is_verified": false, - "line_number": 12 + "line_number": 11 }, { "type": "Hex High Entropy String", "filename": "src/api/srp_crypto.py", "hashed_secret": "e97a751edc71e9afbe0c0f63ec94873392833f9f", "is_verified": false, - "line_number": 13 + "line_number": 12 }, { "type": "Hex High Entropy String", "filename": "src/api/srp_crypto.py", "hashed_secret": "92488c021dd524a2f4e116666b3645308fa0e35c", "is_verified": false, - "line_number": 14 + "line_number": 13 }, { "type": "Hex High Entropy String", "filename": "src/api/srp_crypto.py", "hashed_secret": "d4571e2f026f458aecd2950b0eb6aec190276177", "is_verified": false, - "line_number": 15 + "line_number": 14 }, { "type": "Hex High Entropy String", "filename": "src/api/srp_crypto.py", "hashed_secret": "8109d3c2f659f13cb61fc9e71eed574efe8c8fd8", "is_verified": false, - "line_number": 16 + "line_number": 15 }, { "type": "Hex High Entropy String", "filename": "src/api/srp_crypto.py", "hashed_secret": "08cac7461d7b624b88c53ee47da09cbbb84ea290", "is_verified": false, - "line_number": 17 + "line_number": 16 }, { "type": "Hex High Entropy String", "filename": "src/api/srp_crypto.py", "hashed_secret": "95523fea7e6136c6148299dcc3077debfa2976b3", "is_verified": false, - "line_number": 18 + "line_number": 17 }, { "type": "Hex High Entropy String", "filename": "src/api/srp_crypto.py", "hashed_secret": "c978fb77621e86f5e9077653fe5345ac1616b466", "is_verified": false, - "line_number": 19 + "line_number": 18 }, { "type": "Hex High Entropy String", "filename": "src/api/srp_crypto.py", "hashed_secret": "fc02990268ecf8a35a4912d60dab3754e5f43846", "is_verified": false, - "line_number": 20 + "line_number": 19 }, { "type": "Hex High Entropy String", "filename": "src/api/srp_crypto.py", "hashed_secret": "2c2c0ca491a73e95c8965b6641731057b65f6462", "is_verified": false, - "line_number": 21 + "line_number": 20 }, { "type": "Hex High Entropy String", "filename": "src/api/srp_crypto.py", "hashed_secret": "672b25c6be065170206f3fc6346ebb8e84cbb9d3", "is_verified": false, - "line_number": 22 + "line_number": 21 }, { "type": "Hex High Entropy String", "filename": "src/api/srp_crypto.py", "hashed_secret": "99d02e268ea3ee849fb6e359c6c1b019e4d07efd", "is_verified": false, - "line_number": 23 + "line_number": 22 }, { "type": "Hex High Entropy String", "filename": "src/api/srp_crypto.py", "hashed_secret": "e677fc4cb09d99e1e0d30af31f2e209e541e380e", "is_verified": false, - "line_number": 24 + "line_number": 23 }, { "type": "Hex High Entropy String", "filename": "src/api/srp_crypto.py", "hashed_secret": "05b69b06f40cae0c910a15b1ac75b1f7a847eccb", "is_verified": false, - "line_number": 25 + "line_number": 24 }, { "type": "Hex High Entropy String", "filename": "src/api/srp_crypto.py", "hashed_secret": "c7f914bac2d66eb3f8ae3888fa47bf1ada6caaf5", "is_verified": false, - "line_number": 26 + "line_number": 25 } ], "tests/unit/test_device_registration.py": [ @@ -419,21 +419,21 @@ "filename": "tests/unit/test_device_registration.py", "hashed_secret": "d8bce9746547bb7743e5933fbf0fc4f2d2cbcad3", "is_verified": false, - "line_number": 710 + "line_number": 659 }, { "type": "Secret Keyword", "filename": "tests/unit/test_device_registration.py", "hashed_secret": "e4f50034475acff058e17b35679f8ef1e54f86c5", "is_verified": false, - "line_number": 783 + "line_number": 727 }, { "type": "Secret Keyword", "filename": "tests/unit/test_device_registration.py", "hashed_secret": "6ab013c213c685b1f1b1a452796bf22afbd44699", "is_verified": false, - "line_number": 794 + "line_number": 737 } ], "tests/unit/test_hive_auth.py": [ @@ -486,14 +486,14 @@ "filename": "tests/unit/test_hive_auth_async.py", "hashed_secret": "5c5a15a8b0b3e154d77746945e563ba40100681b", "is_verified": false, - "line_number": 150 + "line_number": 173 }, { "type": "Secret Keyword", "filename": "tests/unit/test_hive_auth_async.py", "hashed_secret": "d8bce9746547bb7743e5933fbf0fc4f2d2cbcad3", "is_verified": false, - "line_number": 206 + "line_number": 272 } ], "tests/unit/test_hive_auth_async_extended.py": [ @@ -502,49 +502,49 @@ "filename": "tests/unit/test_hive_auth_async_extended.py", "hashed_secret": "5c5a15a8b0b3e154d77746945e563ba40100681b", "is_verified": false, - "line_number": 259 + "line_number": 260 }, { "type": "Secret Keyword", "filename": "tests/unit/test_hive_auth_async_extended.py", "hashed_secret": "d8bce9746547bb7743e5933fbf0fc4f2d2cbcad3", "is_verified": false, - "line_number": 340 + "line_number": 341 }, { "type": "Secret Keyword", "filename": "tests/unit/test_hive_auth_async_extended.py", "hashed_secret": "76f6b6f16cb41692b330fc806029e8a31e20b69b", "is_verified": false, - "line_number": 815 + "line_number": 813 }, { "type": "Secret Keyword", "filename": "tests/unit/test_hive_auth_async_extended.py", "hashed_secret": "b3ed2cf313e7546085c3c50622143ff31e467d23", "is_verified": false, - "line_number": 834 + "line_number": 832 }, { "type": "Secret Keyword", "filename": "tests/unit/test_hive_auth_async_extended.py", "hashed_secret": "7476b69b5005e05d536361f960a9d18b736dfbfc", "is_verified": false, - "line_number": 848 + "line_number": 846 }, { "type": "Secret Keyword", "filename": "tests/unit/test_hive_auth_async_extended.py", "hashed_secret": "ff9f30d9ba5a4ec386edddeacc27f74ef412085e", "is_verified": false, - "line_number": 855 + "line_number": 853 }, { "type": "Secret Keyword", "filename": "tests/unit/test_hive_auth_async_extended.py", "hashed_secret": "a8ad0732120b9dfed5b99fd6a2aca4fc8ba48d80", "is_verified": false, - "line_number": 893 + "line_number": 890 } ], "tests/unit/test_hive_helper_extended.py": [ @@ -553,21 +553,21 @@ "filename": "tests/unit/test_hive_helper_extended.py", "hashed_secret": "701b389b848a2b1cfab867093101d8d5ac56addd", "is_verified": false, - "line_number": 134 + "line_number": 102 }, { "type": "Secret Keyword", "filename": "tests/unit/test_hive_helper_extended.py", "hashed_secret": "18960546905b75c869e7de63961dc185f9a0a7c9", "is_verified": false, - "line_number": 141 + "line_number": 109 }, { "type": "Secret Keyword", "filename": "tests/unit/test_hive_helper_extended.py", "hashed_secret": "fbf52ca8a72d8ecd77235d3b3e5d014e19ffbff2", "is_verified": false, - "line_number": 143 + "line_number": 111 } ], "tests/unit/test_session_discovery_extended.py": [ @@ -580,5 +580,5 @@ } ] }, - "generated_at": "2026-05-17T16:44:49Z" + "generated_at": "2026-06-16T18:53:56Z" } diff --git a/src/__init__.py b/src/__init__.py index bf6f4a0..13762b9 100644 --- a/src/__init__.py +++ b/src/__init__.py @@ -12,7 +12,10 @@ from .helper.const import SMS_REQUIRED from .helper.hive_exceptions import ( HiveApiError, + HiveAuthCredentialError, HiveAuthError, + HiveConfigurationError, + HiveError, HiveFailedToRefreshTokens, HiveInvalid2FACode, HiveInvalidDeviceAuthentication, diff --git a/src/api/device_registration.py b/src/api/device_registration.py index 5013ed8..316eb3c 100644 --- a/src/api/device_registration.py +++ b/src/api/device_registration.py @@ -57,7 +57,7 @@ class DeviceRegistrationMixin: large_a_value: int client_secret: str | None - async def generate_hash_device(self, device_group_key, device_key): + def generate_hash_device(self, device_group_key, device_key): """Generate device hash key.""" # source: https://github.com/amazon-archives/amazon-cognito-identity-js/blob/6b87f1a30a998072b4d98facb49dcaf8780d15b0/src/AuthenticationHelper.js#L137 # pylint: disable=line-too-long @@ -81,7 +81,7 @@ async def generate_hash_device(self, device_group_key, device_key): self.device_password = device_password return device_secret_verifier_config - async def get_device_authentication_key( # pylint: disable=too-many-positional-arguments + def get_device_authentication_key( # pylint: disable=too-many-positional-arguments self, device_group_key, device_key, device_password, server_b_value, salt ): """Get device authentication key.""" @@ -120,7 +120,7 @@ async def process_device_challenge(self, challenge_parameters): "%a %b %d %H:%M:%S UTC %Y" ), ) - hkdf = await self.get_device_authentication_key( + hkdf = self.get_device_authentication_key( self.device_group_key, self.device_key, self.device_password, @@ -169,7 +169,7 @@ async def confirm_device(self, device_name: str | None = None): result = None try: - device_secret_verifier_config = await self.generate_hash_device( + device_secret_verifier_config = self.generate_hash_device( self.device_group_key, self.device_key ) result = await self.loop.run_in_executor( @@ -183,14 +183,12 @@ async def confirm_device(self, device_name: str | None = None): ), ) except botocore.exceptions.ClientError as err: - if err.__class__.__name__ in ( - "NotAuthorizedException", - "CodeMismatchException", - ): + code = (err.response or {}).get("Error", {}).get("Code", "") + if code == "CodeMismatchException": raise HiveInvalid2FACode from err + raise HiveApiError from err except botocore.exceptions.EndpointConnectionError as err: - if err.__class__.__name__ == "EndpointConnectionError": - raise HiveApiError from err + raise HiveApiError from err return result @@ -210,8 +208,7 @@ async def update_device_status(self): ), ) except botocore.exceptions.EndpointConnectionError as err: - if err.__class__.__name__ == "EndpointConnectionError": - raise HiveApiError from err + raise HiveApiError from err return result @@ -335,10 +332,8 @@ async def forget_device(self, access_token, device_key): ), ) except botocore.exceptions.ClientError as err: - if err.__class__.__name__ == "NotAuthorizedException": - raise HiveInvalid2FACode from err + raise HiveApiError from err except botocore.exceptions.EndpointConnectionError as err: - if err.__class__.__name__ == "ResourceNotFoundException": - raise HiveApiError from err + raise HiveApiError from err return result diff --git a/src/api/hive_api.py b/src/api/hive_api.py index 6801d3f..c4ae70c 100644 --- a/src/api/hive_api.py +++ b/src/api/hive_api.py @@ -2,15 +2,22 @@ import json import logging +import re import requests -import urllib3 from pyquery import PyQuery -urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning) - _LOGGER = logging.getLogger(__name__) +_NO_RESPONSE = "No response to Hive API request" +_ERROR_RESPONSE = "Error making API call" + +# requests exceptions all subclass OSError; response.json() raises a +# json.JSONDecodeError subclass. +_REQUEST_ERRORS = (OSError, RuntimeError, json.JSONDecodeError) + +_SSO_ASSIGNMENT = re.compile(r'window\.(\w+)\s*=\s*"([^"]*)"') + class HiveApi: """Hive API Code.""" @@ -33,8 +40,8 @@ def __init__(self, hive_session=None, token=None): } self.timeout = 5 self.json_return = { - "original": "No response to Hive API request", - "parsed": "No response to Hive API request", + "original": _NO_RESPONSE, + "parsed": _NO_RESPONSE, } self.session = hive_session self.token = token @@ -74,37 +81,24 @@ def request(self, http_method, url, jsc=None): _LOGGER.error("Request failed: %s", e) raise - def refresh_tokens(self, tokens=None): - """Get new session tokens - DEPRECATED NOW BY AWS TOKEN MANAGEMENT.""" - _LOGGER.debug("refresh_tokens - Attempting token refresh (deprecated method)") - if tokens is None: - tokens = {} - url = self.urls["refresh"] - if self.session is not None: - tokens = self.session.tokens.token_data - jsc = ( - "{" - + ",".join( - ('"' + str(i) + '": "' + str(t) + '" ' for i, t in tokens.items()) - ) - + "}" - ) + def _call_endpoint(self, http_method, url, jsc=None): + """Call an endpoint and return a fresh result dict for this call.""" + json_return = { + "original": _NO_RESPONSE, + "parsed": _NO_RESPONSE, + } try: - info = self.request("POST", url, jsc) - data = json.loads(info.text) - if "token" in data and self.session: - _LOGGER.debug( - "refresh_tokens - Token refresh successful, updating session" - ) - self.session.update_tokens(data) - self.urls.update({"base": data["platform"]["endpoint"]}) - self.json_return.update({"original": info.status_code}) - self.json_return.update({"parsed": info.json()}) - except (OSError, RuntimeError, ZeroDivisionError, json.JSONDecodeError) as e: - _LOGGER.error("Token refresh failed: %s", str(e)) - self.error() + response = self.request(http_method, url, jsc) + if response is not None: + json_return["original"] = response.status_code + json_return["parsed"] = response.json() + else: + _LOGGER.error("No response from Hive API call to %s", url) + except _REQUEST_ERRORS as e: + _LOGGER.error("Hive API call to %s failed: %s", url, e) + json_return = self.error() - return self.json_return + return json_return def get_login_info(self): """Get login properties to make the login request.""" @@ -113,33 +107,21 @@ def get_login_info(self): ) url = self.urls["properties"] try: - data = requests.get(url=url, verify=False, timeout=self.timeout) + data = requests.get(url=url, timeout=self.timeout) _LOGGER.debug( "get_login_info - Login info response status: %s", data.status_code ) - html = PyQuery(data.content) - json_data = json.loads( - '{"' - + (html("script:first").text()) - .replace(",", ', "') - .replace("=", '":') - .replace("window.", "") - + "}" - ) - - login_data = {} - login_data.update({"UPID": json_data["HiveSSOPoolId"]}) - login_data.update({"CLIID": json_data["HiveSSOPublicCognitoClientId"]}) - login_data.update({"REGION": json_data["HiveSSOPoolId"]}) + script_text = PyQuery(data.content)("script:first").text() + sso_values = dict(_SSO_ASSIGNMENT.findall(script_text)) + + login_data = { + "UPID": sso_values["HiveSSOPoolId"], + "CLIID": sso_values["HiveSSOPublicCognitoClientId"], + "REGION": sso_values["HiveSSOPoolId"], + } _LOGGER.debug("get_login_info - Login info extracted successfully") return login_data - except ( - OSError, - RuntimeError, - ZeroDivisionError, - json.JSONDecodeError, - KeyError, - ) as e: + except (OSError, RuntimeError, KeyError) as e: _LOGGER.error("Failed to get login info: %s", str(e)) self.error() return None @@ -147,59 +129,23 @@ def get_login_info(self): def get_all(self): """Build and query all endpoint.""" _LOGGER.debug("get_all - Fetching all devices/products/actions from Hive API") - json_return = {} url = self.urls["base"] + self.urls["all"] - try: - info = self.request("GET", url) - if info is not None: - json_return.update({"original": info.status_code}) - json_return.update({"parsed": info.json()}) - _LOGGER.debug( - "get_all - All data fetch successful, status: %s", info.status_code - ) - else: - _LOGGER.error("Failed to get response from all endpoint") - except (OSError, RuntimeError, ZeroDivisionError, json.JSONDecodeError) as e: - _LOGGER.error("Failed to fetch all data: %s", str(e)) - self.error() - - return json_return + return self._call_endpoint("GET", url) def get_devices(self): """Call the get devices endpoint.""" url = self.urls["base"] + self.urls["devices"] - try: - response = self.request("GET", url) - self.json_return.update({"original": response.status_code}) - self.json_return.update({"parsed": response.json()}) - except (OSError, RuntimeError, ZeroDivisionError): - self.error() - - return self.json_return + return self._call_endpoint("GET", url) def get_products(self): """Call the get products endpoint.""" url = self.urls["base"] + self.urls["products"] - try: - response = self.request("GET", url) - self.json_return.update({"original": response.status_code}) - self.json_return.update({"parsed": response.json()}) - except (OSError, RuntimeError, ZeroDivisionError): - self.error() - - return self.json_return + return self._call_endpoint("GET", url) def get_actions(self): """Call the get actions endpoint.""" url = self.urls["base"] + self.urls["actions"] - try: - response = self.request("GET", url) - self.json_return.update({"original": response.status_code}) - self.json_return.update({"parsed": response.json()}) - except (OSError, RuntimeError, ZeroDivisionError): - self.error() - - return self.json_return + return self._call_endpoint("GET", url) def motion_sensor(self, sensor, fromepoch, toepoch): """Call a way to get motion sensor info.""" @@ -215,27 +161,13 @@ def motion_sensor(self, sensor, fromepoch, toepoch): + "&to=" + str(toepoch) ) - try: - response = self.request("GET", url) - self.json_return.update({"original": response.status_code}) - self.json_return.update({"parsed": response.json()}) - except (OSError, RuntimeError, ZeroDivisionError): - self.error() - - return self.json_return + return self._call_endpoint("GET", url) def get_weather(self, weather_url): """Call endpoint to get local weather from Hive API.""" t_url = self.urls["weather"] + weather_url url = t_url.replace(" ", "%20") - try: - response = self.request("GET", url) - self.json_return.update({"original": response.status_code}) - self.json_return.update({"parsed": response.json()}) - except (OSError, RuntimeError, ZeroDivisionError, ConnectionError): - self.error() - - return self.json_return + return self._call_endpoint("GET", url) def set_state(self, n_type, n_id, **kwargs): """Set the state of a Device.""" @@ -245,58 +177,27 @@ def set_state(self, n_type, n_id, **kwargs): n_type, kwargs, ) - jsc = ( - "{" - + ",".join( - ('"' + str(i) + '": "' + str(t) + '" ' for i, t in kwargs.items()) - ) - + "}" - ) - + jsc = json.dumps(kwargs) url = self.urls["base"] + self.urls["nodes"].format(n_type, n_id) - - try: - response = self.request("POST", url, jsc) - if response is not None: - self.json_return.update({"original": response.status_code}) - self.json_return.update({"parsed": response.json()}) - _LOGGER.debug( - "set_state - State set successfully for %s, status: %s", - n_id, - response.status_code, - ) - else: - _LOGGER.error("Failed to set state for %s - no response", n_id) - except ( - OSError, - RuntimeError, - ZeroDivisionError, - ConnectionError, - json.JSONDecodeError, - ) as e: - _LOGGER.error("Failed to set state for %s: %s", n_id, str(e)) - self.error() - - return self.json_return + return self._call_endpoint("POST", url, jsc) def set_action(self, n_id, data): """Set the state of a Action.""" jsc = data url = self.urls["base"] + self.urls["actions"] + "/" + n_id - try: - response = self.request("POST", url, jsc) - self.json_return.update({"original": response.status_code}) - self.json_return.update({"parsed": response.json()}) - except (OSError, RuntimeError, ZeroDivisionError, ConnectionError): - self.error() - - return self.json_return + return self._call_endpoint("POST", url, jsc) def error(self): """An error has occurred interacting with the Hive API.""" _LOGGER.error("API error occurred - returning error response") - self.json_return.update({"original": "Error making API call"}) - self.json_return.update({"parsed": "Error making API call"}) + error_return = { + "original": _ERROR_RESPONSE, + "parsed": _ERROR_RESPONSE, + } + # Kept in sync for backwards compatibility with callers that read + # the last error state off the instance. + self.json_return.update(error_return) + return error_return class UnknownConfig(Exception): diff --git a/src/api/hive_async_api.py b/src/api/hive_async_api.py index bc91789..e68d833 100644 --- a/src/api/hive_async_api.py +++ b/src/api/hive_async_api.py @@ -5,17 +5,20 @@ import logging import time -import requests -import urllib3 -from aiohttp import ClientResponse, ClientSession, ClientTimeout, web_exceptions -from pyquery import PyQuery - -from ..helper.const import HTTP_FORBIDDEN, HTTP_OK, HTTP_UNAUTHORIZED +from aiohttp import ( + ClientError, + ClientResponse, + ClientSession, + ClientTimeout, + web_exceptions, +) + +from ..helper.const import HTTP_FORBIDDEN, HTTP_UNAUTHORIZED from ..helper.hive_exceptions import FileInUse, HiveApiError, HiveAuthError, NoApiToken _LOGGER = logging.getLogger(__name__) -urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning) +_REQUEST_ERRORS = (ClientError, OSError, RuntimeError, json.JSONDecodeError) class HiveApiAsync: @@ -43,7 +46,17 @@ def __init__(self, hive_session=None, websession: ClientSession | None = None): "parsed": "No response to Hive API request", } self.session = hive_session - self.websession = ClientSession() if websession is None else websession + self.websession = websession + + def _get_websession(self) -> ClientSession: + """Return the shared ClientSession, creating it on first use. + + Created lazily so the session is constructed inside a running + event loop rather than in the synchronous constructor. + """ + if self.websession is None: + self.websession = ClientSession() + return self.websession async def request(self, method: str, url: str, **kwargs) -> ClientResponse: """Make a request.""" @@ -72,7 +85,7 @@ async def request(self, method: str, url: str, **kwargs) -> ClientResponse: timeout = ClientTimeout(total=self.timeout) req_start = time.monotonic() - async with self.websession.request( + async with self._get_websession().request( method, url, headers=headers, data=data, timeout=timeout ) as resp: resp_body = await resp.text() @@ -97,125 +110,50 @@ async def request(self, method: str, url: str, **kwargs) -> ClientResponse: raise HiveAuthError( f"Token expired or forbidden calling {url} — HTTP {resp.status}" ) - if url is not None and resp.status is not None: - _LOGGER.error( - "Something has gone wrong calling %s - HTTP status is - %s — response: %s", - url, - resp.status, - resp_body[:200], - ) - - raise HiveApiError - - def get_login_info(self): - """Get login properties to make the login request.""" - url = "https://sso.hivehome.com/" - - data = requests.get(url=url, verify=False, timeout=self.timeout) - html = PyQuery(data.content) - json_data = json.loads( - '{"' - + (html("script:first").text()) - .replace(",", ', "') - .replace("=", '":') - .replace("window.", "") - + "}" - ) - - login_data = {} - login_data.update({"UPID": json_data["HiveSSOPoolId"]}) - login_data.update({"CLIID": json_data["HiveSSOPublicCognitoClientId"]}) - login_data.update({"REGION": json_data["HiveSSOPoolId"]}) - return login_data - - async def refresh_tokens(self): - """Refresh tokens - DEPRECATED NOW BY AWS TOKEN MANAGEMENT.""" - url = self.urls["refresh"] - if self.session is not None: - tokens = self.session.tokens.token_data - jsc = ( - "{" - + ",".join( - ('"' + str(i) + '": "' + str(t) + '" ' for i, t in tokens.items()) - ) - + "}" + _LOGGER.error( + "Something has gone wrong calling %s - HTTP status is - %s — response: %s", + url, + resp.status, + resp_body[:200], ) - try: - await self.request("post", url, data=jsc) - - if self.json_return["original"] == HTTP_OK: - info = self.json_return["parsed"] - if "token" in info: - await self.session.update_tokens(info) - # pylint: disable-next=invalid-sequence-index - self.base_url = info["platform"]["endpoint"] - return True - except (ConnectionError, OSError, RuntimeError, ZeroDivisionError): - await self.error() - - return self.json_return + raise HiveApiError - async def get_all(self): - """Build and query all endpoint.""" - json_return = {} - url = self.urls["all"] + async def _call_endpoint(self, method: str, url: str, data=None) -> dict: + """Call an endpoint and return {"original": status, "parsed": json}.""" + json_return: dict = {} try: - resp = await self.request("get", url) + resp = await self.request(method, url, data=data) json_return.update({"original": resp.status}) json_return.update({"parsed": await resp.json(content_type=None)}) except asyncio.TimeoutError: - _LOGGER.warning("Hive API request timed out fetching all nodes.") + _LOGGER.warning("Hive API request timed out calling %s", url) raise - except (OSError, RuntimeError, ZeroDivisionError): + except _REQUEST_ERRORS: await self.error() return json_return + async def get_all(self): + """Build and query all endpoint.""" + return await self._call_endpoint("get", self.urls["all"]) + async def get_devices(self): """Call the get devices endpoint.""" - json_return = {} - url = self.urls["devices"] - try: - resp = await self.request("get", url) - json_return.update({"original": resp.status}) - json_return.update({"parsed": await resp.json(content_type=None)}) - except (OSError, RuntimeError, ZeroDivisionError): - await self.error() - - return json_return + return await self._call_endpoint("get", self.urls["devices"]) async def get_products(self): """Call the get products endpoint.""" - json_return = {} - url = self.urls["products"] - try: - resp = await self.request("get", url) - json_return.update({"original": resp.status}) - json_return.update({"parsed": await resp.json(content_type=None)}) - except (OSError, RuntimeError, ZeroDivisionError): - await self.error() - - return json_return + return await self._call_endpoint("get", self.urls["products"]) async def get_actions(self): """Call the get actions endpoint.""" - json_return = {} - url = self.urls["actions"] - try: - resp = await self.request("get", url) - json_return.update({"original": resp.status}) - json_return.update({"parsed": await resp.json(content_type=None)}) - except (OSError, RuntimeError, ZeroDivisionError): - await self.error() - - return json_return + return await self._call_endpoint("get", self.urls["actions"]) async def motion_sensor(self, sensor, fromepoch, toepoch): """Call a way to get motion sensor info.""" - json_return = {} url = ( - self.urls["base"] - + self.urls["products"] + self.base_url + + "/products" + "/" + sensor["type"] + "/" @@ -225,68 +163,34 @@ async def motion_sensor(self, sensor, fromepoch, toepoch): + "&to=" + str(toepoch) ) - try: - resp = await self.request("get", url) - json_return.update({"original": resp.status}) - json_return.update({"parsed": await resp.json(content_type=None)}) - except (OSError, RuntimeError, ZeroDivisionError): - await self.error() - - return json_return + return await self._call_endpoint("get", url) async def get_weather(self, weather_url): """Call endpoint to get local weather from Hive API.""" - json_return = {} t_url = self.urls["weather"] + weather_url url = t_url.replace(" ", "%20") - try: - resp = await self.request("get", url) - json_return.update({"original": resp.status}) - json_return.update({"parsed": await resp.json(content_type=None)}) - except (OSError, RuntimeError, ZeroDivisionError, ConnectionError): - await self.error() - - return json_return + return await self._call_endpoint("get", url) async def set_state(self, n_type, n_id, **kwargs): """Set the state of a Device.""" _LOGGER.debug("set_state - Setting state for %s/%s: %s", n_type, n_id, kwargs) - json_return = {} - jsc = ( - "{" - + ",".join( - ('"' + str(i) + '": "' + str(t) + '" ' for i, t in kwargs.items()) - ) - + "}" - ) - + jsc = json.dumps(kwargs) url = self.urls["nodes"].format(n_type, n_id) try: await self.is_file_being_used() - resp = await self.request("post", url, data=jsc) - json_return["original"] = resp.status - json_return["parsed"] = await resp.json(content_type=None) - except (FileInUse, OSError, RuntimeError, ConnectionError) as e: - if e.__class__.__name__ == "FileInUse": - return {"original": "file"} - await self.error() - - return json_return + except FileInUse: + return {"original": "file"} + return await self._call_endpoint("post", url, data=jsc) async def set_action(self, n_id, data): """Set the state of a Action.""" _LOGGER.debug("Setting action %s", n_id) - jsc = data url = self.urls["actions"] + "/" + n_id try: await self.is_file_being_used() - await self.request("put", url, data=jsc) - except (FileInUse, OSError, RuntimeError, ConnectionError) as e: - if e.__class__.__name__ == "FileInUse": - return {"original": "file"} - await self.error() - - return self.json_return + except FileInUse: + return {"original": "file"} + return await self._call_endpoint("put", url, data=data) async def error(self): """An error has occurred interacting with the Hive API.""" diff --git a/src/api/hive_auth_async.py b/src/api/hive_auth_async.py index 4056690..6866a84 100644 --- a/src/api/hive_auth_async.py +++ b/src/api/hive_auth_async.py @@ -23,6 +23,7 @@ HiveInvalidPassword, HiveInvalidUsername, HiveRefreshTokenExpired, + HiveUnknownConfiguration, ) from .device_registration import DeviceRegistrationMixin from .hive_api import HiveApi @@ -58,17 +59,10 @@ def __init__( # pylint: disable=too-many-positional-arguments # noqa: PLR0913 device_group_key: str | None = None, device_key: str | None = None, device_password: str | None = None, - pool_region: str | None = None, client_secret: str | None = None, ): """Initialise async auth.""" - if pool_region is not None: - raise ValueError( - "pool_region and client should not both be specified " - "(region should be passed to the boto3 client instead)" - ) - - self.loop: asyncio.AbstractEventLoop = asyncio.get_event_loop() + self.loop: asyncio.AbstractEventLoop | None = None self.username = username self.password = password self.device_group_key: str | None = device_group_key @@ -95,10 +89,19 @@ def __init__( # pylint: disable=too-many-positional-arguments # noqa: PLR0913 async def async_init(self): """Initialise async variables.""" + self.loop = asyncio.get_running_loop() self.data = await self.loop.run_in_executor(None, self.api.get_login_info) + if not self.data: + raise HiveUnknownConfiguration("SSO login page returned no data") self._pool_id = self.data.get("UPID") self._client_id = self.data.get("CLIID") - self._region = self.data.get("REGION").split("_")[0] + region_raw = self.data.get("REGION") + if not self._pool_id or not region_raw: + raise HiveUnknownConfiguration( + "SSO login page did not return required pool/region data" + ) + self._region = region_raw.split("_")[0] + # Cognito USER_SRP_AUTH does not use IAM credentials — boto3 requires non-None values. self.client = await self.loop.run_in_executor( None, functools.partial( @@ -128,6 +131,25 @@ def generate_random_small_a(self): random_long_int = get_random(128) return random_long_int % self.big_n + def _new_srp_ephemeral(self) -> None: + """Generate a fresh SRP client ephemeral (a, A) pair. + + SRP ephemerals must not be reused across handshakes, so this is + called at the start of every authentication attempt. + """ + self.small_a_value = self.generate_random_small_a() + self.large_a_value = self.calculate_a() + + def _store_auth_result(self, result: dict) -> None: + """Store tokens and any new device keys from an AuthenticationResult.""" + auth_result = result["AuthenticationResult"] + self.access_token = auth_result["AccessToken"] + self.token_created = datetime.datetime.now() + if "NewDeviceMetadata" in auth_result: + self.device_group_key = auth_result["NewDeviceMetadata"]["DeviceGroupKey"] + self.device_key = auth_result["NewDeviceMetadata"]["DeviceKey"] + _LOGGER.debug("Device keys stored successfully.") + def calculate_a(self): """ Calculate the client's public value A. @@ -156,6 +178,8 @@ def get_password_authentication_key(self, username, password, server_b_value, sa u_value = calculate_u(self.large_a_value, server_b_value) if u_value == 0: raise ValueError("U cannot be zero.") + if not self._pool_id or "_" not in self._pool_id: + raise HiveUnknownConfiguration(f"Invalid pool ID format: {self._pool_id!r}") pool_id = self._pool_id.split("_")[1] username_password = f"{pool_id}{username}:{password}" username_password_hash = hash_sha256(username_password.encode("utf-8")) @@ -255,7 +279,7 @@ async def process_challenge(self, challenge_parameters): return response - async def login(self): # noqa: PLR0912 + async def login(self): # noqa: PLR0912, PLR0915 # pylint: disable=too-many-statements """Login into a Hive account - handles initial SRP auth only.""" if self.use_file: _LOGGER.debug("login - Using file-based authentication.") @@ -264,6 +288,7 @@ async def login(self): # noqa: PLR0912 if self.client is None: await self.async_init() + self._new_srp_ephemeral() auth_params = await self.get_auth_params() response = None result = None @@ -279,15 +304,22 @@ async def login(self): # noqa: PLR0912 ), ) except botocore.exceptions.ClientError as err: - if err.__class__.__name__ == "UserNotFoundException": + code = (err.response or {}).get("Error", {}).get("Code", "") + if code == "UserNotFoundException": _LOGGER.error("Cognito auth failed: user not found.") raise HiveInvalidUsername from err + _LOGGER.error("Cognito auth failed: %s", code) + raise HiveApiError from err except botocore.exceptions.EndpointConnectionError as err: - if err.__class__.__name__ == "EndpointConnectionError": - _LOGGER.error("Cognito auth failed: cannot reach endpoint.") - raise HiveApiError from err + _LOGGER.error("Cognito auth failed: cannot reach endpoint.") + raise HiveApiError from err + + if "AuthenticationResult" in response: + _LOGGER.debug("login - Authenticated directly without a challenge.") + self._store_auth_result(response) + return response - if response["ChallengeName"] == self.PASSWORD_VERIFIER_CHALLENGE: + if response.get("ChallengeName") == self.PASSWORD_VERIFIER_CHALLENGE: _LOGGER.debug("login - Processing PASSWORD_VERIFIER challenge.") challenge_response = await self.process_challenge( response["ChallengeParameters"] @@ -303,38 +335,29 @@ async def login(self): # noqa: PLR0912 ), ) except botocore.exceptions.ClientError as err: - if err.__class__.__name__ == "NotAuthorizedException": + code = (err.response or {}).get("Error", {}).get("Code", "") + if code == "NotAuthorizedException": _LOGGER.error("Cognito auth challenge failed: not authorised.") raise HiveInvalidPassword from err - if err.__class__.__name__ == "ResourceNotFoundException": + if code == "ResourceNotFoundException": _LOGGER.error( "Cognito auth challenge failed: device resource not found." ) raise HiveInvalidDeviceAuthentication from err + _LOGGER.error("Cognito auth challenge failed: %s", code) + raise HiveApiError from err except botocore.exceptions.EndpointConnectionError as err: - if err.__class__.__name__ == "EndpointConnectionError": - _LOGGER.error( - "Cognito auth challenge failed: cannot reach endpoint." - ) - raise HiveApiError from err + _LOGGER.error("Cognito auth challenge failed: cannot reach endpoint.") + raise HiveApiError from err _LOGGER.debug("login - SRP auth challenge completed successfully.") if "AuthenticationResult" in result: - self.access_token = result["AuthenticationResult"]["AccessToken"] - self.token_created = datetime.datetime.now() - if "NewDeviceMetadata" in result["AuthenticationResult"]: - self.device_group_key = result["AuthenticationResult"][ - "NewDeviceMetadata" - ]["DeviceGroupKey"] - self.device_key = result["AuthenticationResult"][ - "NewDeviceMetadata" - ]["DeviceKey"] - _LOGGER.debug("login - Device keys stored successfully.") + self._store_auth_result(result) return result - challenge_name = response["ChallengeName"] + challenge_name = response.get("ChallengeName") _LOGGER.error("Unsupported Cognito challenge: %s", challenge_name) raise NotImplementedError(f"The {challenge_name} challenge is not supported") @@ -349,6 +372,7 @@ async def device_login(self): if self.client is None: await self.async_init() + self._new_srp_ephemeral() auth_params = await self.get_auth_params(is_device_login=True) _LOGGER.debug("device_login - Processing DEVICE_SRP_AUTH challenge.") @@ -385,10 +409,8 @@ async def device_login(self): raise HiveInvalidDeviceAuthentication from err raise except botocore.exceptions.EndpointConnectionError as err: - if err.__class__.__name__ == "EndpointConnectionError": - _LOGGER.error("Device login failed: cannot reach endpoint.") - raise HiveApiError from err - raise HiveInvalidDeviceAuthentication from err + _LOGGER.error("Device login failed: cannot reach endpoint.") + raise HiveApiError from err _LOGGER.debug("device_login - Device authentication completed successfully.") return result @@ -413,26 +435,18 @@ async def sms_2fa(self, entered_code, challenge_parameters): }, ), ) - self.access_token = result["AuthenticationResult"]["AccessToken"] - self.token_created = datetime.datetime.now() - if "NewDeviceMetadata" in result["AuthenticationResult"]: - self.device_group_key = result["AuthenticationResult"][ - "NewDeviceMetadata" - ]["DeviceGroupKey"] - self.device_key = result["AuthenticationResult"]["NewDeviceMetadata"][ - "DeviceKey" - ] + if result and "AuthenticationResult" in result: + self._store_auth_result(result) except botocore.exceptions.ClientError as err: - if err.__class__.__name__ in ( - "NotAuthorizedException", - "CodeMismatchException", - ): + code = (err.response or {}).get("Error", {}).get("Code", "") + if code in ("NotAuthorizedException", "CodeMismatchException"): _LOGGER.error("2FA code rejected by Cognito.") raise HiveInvalid2FACode from err + _LOGGER.error("2FA failed: %s", code) + raise HiveApiError from err except botocore.exceptions.EndpointConnectionError as err: - if err.__class__.__name__ == "EndpointConnectionError": - _LOGGER.error("2FA failed: cannot reach Cognito endpoint.") - raise HiveApiError from err + _LOGGER.error("2FA failed: cannot reach Cognito endpoint.") + raise HiveApiError from err _LOGGER.debug("sms_2fa - 2FA authentication completed successfully.") return result @@ -478,11 +492,10 @@ async def refresh_token(self, token): ) raise HiveFailedToRefreshTokens from err except botocore.exceptions.EndpointConnectionError as err: - if err.__class__.__name__ == "EndpointConnectionError": - _LOGGER.error( - "refresh_token - Token refresh failed: cannot reach Cognito endpoint." - ) - raise HiveApiError from err + _LOGGER.error( + "refresh_token - Token refresh failed: cannot reach Cognito endpoint." + ) + raise HiveApiError from err _LOGGER.debug("refresh_token - Cognito token refresh completed successfully.") return result diff --git a/src/api/srp_crypto.py b/src/api/srp_crypto.py index faf141b..74cea33 100644 --- a/src/api/srp_crypto.py +++ b/src/api/srp_crypto.py @@ -1,7 +1,6 @@ """Pure SRP/HKDF crypto helpers for AWS Cognito authentication.""" import binascii -import concurrent.futures import hashlib import hmac import os @@ -28,7 +27,6 @@ # https://github.com/aws/amazon-cognito-identity-js/blob/master/src/AuthenticationHelper.js#L49 G_HEX = "2" INFO_BITS = bytearray("Caldera Derived Key", "utf-8") -POOL = concurrent.futures.ThreadPoolExecutor() def hex_to_long(hex_string): diff --git a/src/devices/boost.py b/src/devices/boost.py index 2e3f1b7..b1e22e9 100644 --- a/src/devices/boost.py +++ b/src/devices/boost.py @@ -29,7 +29,7 @@ async def get_boost_status(self, device: Device): data = self.session.data.products[device.hive_id] return HIVETOHA["Boost"].get(data["state"].get("boost", False), "ON") except KeyError as e: - _LOGGER.error(e) + _LOGGER.error("get_boost_status - KeyError for %s: %s", device.ha_name, e) return None async def get_boost_time(self, device: Device): @@ -43,5 +43,5 @@ async def get_boost_time(self, device: Device): data = self.session.data.products[device.hive_id] return data["state"]["boost"] except KeyError as e: - _LOGGER.error(e) + _LOGGER.error("get_boost_time - KeyError for %s: %s", device.ha_name, e) return None diff --git a/src/devices/color.py b/src/devices/color.py index ade4cb2..8a63a9d 100644 --- a/src/devices/color.py +++ b/src/devices/color.py @@ -2,7 +2,6 @@ from __future__ import annotations -import colorsys import logging from typing import Any @@ -33,7 +32,7 @@ async def get_min_color_temp(self, device: Device): data = self.session.data.products[device.hive_id] state = data["props"]["colourTemperature"]["max"] return round((1 / state) * 1000000) - except KeyError as e: + except (KeyError, ZeroDivisionError) as e: _LOGGER.error(e) return None @@ -50,7 +49,7 @@ async def get_max_color_temp(self, device: Device): data = self.session.data.products[device.hive_id] state = data["props"]["colourTemperature"]["min"] return round((1 / state) * 1000000) - except KeyError as e: + except (KeyError, ZeroDivisionError) as e: _LOGGER.error(e) return None @@ -67,27 +66,23 @@ async def get_color_temp(self, device: Device): data = self.session.data.products[device.hive_id] state = data["state"]["colourTemperature"] return round((1 / state) * 1000000) - except KeyError as e: + except (KeyError, ZeroDivisionError) as e: _LOGGER.error(e) return None async def get_color(self, device: Device): - """Get light current colour as an RGB tuple. + """Get light current colour as an HS tuple for HA hs_color. Args: device (Device): Device to query. Returns: - tuple | None: ``(r, g, b)`` each in 0–255. + tuple | None: ``(hue_degrees, saturation_percent)`` where hue is + 0–360 and saturation is 0–100, or None on error. """ try: data = self.session.data.products[device.hive_id] - hsv = [ - data["state"]["hue"] / 360, - data["state"]["saturation"] / 100, - data["state"]["value"] / 100, - ] - return tuple(int(i * 255) for i in colorsys.hsv_to_rgb(*hsv)) + return (data["state"]["hue"], data["state"]["saturation"]) except KeyError as e: _LOGGER.error(e) return None diff --git a/src/devices/heating.py b/src/devices/heating.py index 2b81e1e..de458d7 100644 --- a/src/devices/heating.py +++ b/src/devices/heating.py @@ -48,6 +48,36 @@ async def get_max_temperature(self, device: Device): return self._get_product_state(device, "props", "maxHeat") return 32 + def _track_minmax(self, hive_id: str, temperature: float) -> None: + """Record today's and since-restart min/max temperatures for a device.""" + today = str(datetime.date(datetime.now())) + min_max = self.session.data.minMax.get(hive_id) + + if min_max is None: + self.session.data.minMax[hive_id] = { + "TodayMin": temperature, + "TodayMax": temperature, + "TodayDate": today, + "RestartMin": temperature, + "RestartMax": temperature, + } + return + + if min_max["TodayDate"] == today: + min_max["TodayMin"] = min(min_max["TodayMin"], temperature) + min_max["TodayMax"] = max(min_max["TodayMax"], temperature) + else: + min_max.update( + { + "TodayMin": temperature, + "TodayMax": temperature, + "TodayDate": today, + } + ) + + min_max["RestartMin"] = min(min_max["RestartMin"], temperature) + min_max["RestartMax"] = max(min_max["RestartMax"], temperature) + async def get_current_temperature(self, device: Device): """Get heating current temperature. @@ -75,41 +105,7 @@ async def get_current_temperature(self, device: Device): ) return None - if device.hive_id in self.session.data.minMax: - if self.session.data.minMax[device.hive_id]["TodayDate"] == str( - datetime.date(datetime.now()) - ): - self.session.data.minMax[device.hive_id]["TodayMin"] = min( - self.session.data.minMax[device.hive_id]["TodayMin"], state - ) - - self.session.data.minMax[device.hive_id]["TodayMax"] = max( - self.session.data.minMax[device.hive_id]["TodayMax"], state - ) - else: - data = { - "TodayMin": state, - "TodayMax": state, - "TodayDate": str(datetime.date(datetime.now())), - } - self.session.data.minMax[device.hive_id].update(data) - - self.session.data.minMax[device.hive_id]["RestartMin"] = min( - self.session.data.minMax[device.hive_id]["RestartMin"], state - ) - - self.session.data.minMax[device.hive_id]["RestartMax"] = max( - self.session.data.minMax[device.hive_id]["RestartMax"], state - ) - else: - data = { - "TodayMin": state, - "TodayMax": state, - "TodayDate": str(datetime.date(datetime.now())), - "RestartMin": state, - "RestartMax": state, - } - self.session.data.minMax[device.hive_id] = data + self._track_minmax(device.hive_id, state) final = round(state, 1) except KeyError as e: @@ -170,15 +166,16 @@ async def get_mode(self, device: Device): """ state = None final = None + device_name = device.ha_name try: data = self.session.data.products[device.hive_id] state = data["state"]["mode"] if state == "BOOST": - state = data["props"]["previous"]["mode"] + state = self._get_product_state(device, "props", "previous", "mode") final = HIVETOHA[self.heating_type].get(state, state) except KeyError as e: - _LOGGER.error(e) + _LOGGER.error("get_mode - KeyError getting mode for %s: %s", device_name, e) return final @@ -284,7 +281,18 @@ async def set_boost_on(self, device: Device, mins: str, temp: float): """ min_temp = await self.get_min_temperature(device) max_temp = await self.get_max_temperature(device) - if not (int(mins) > 0 and min_temp <= int(temp) <= max_temp): + try: + mins_value = int(mins) + temp_value = float(temp) + except (ValueError, TypeError): + _LOGGER.warning( + "set_boost_on - Invalid boost inputs for %s: mins=%r temp=%r", + device.ha_name, + mins, + temp, + ) + return None + if not (mins_value > 0 and min_temp <= temp_value <= max_temp): return None _LOGGER.debug( "set_boost_on - Setting heating boost ON for %s: %s mins at %s degrees.", @@ -317,6 +325,12 @@ async def set_boost_off(self, device: Device): "set_boost_off - Setting heating boost OFF for %s.", device.ha_name ) prev_mode = self._get_product_state(device, "props", "previous", "mode") + if prev_mode is None: + _LOGGER.warning( + "set_boost_off - Cannot determine previous mode for %s, skipping.", + device.ha_name, + ) + return False kwargs = {"mode": prev_mode} if prev_mode in ("MANUAL", "OFF"): kwargs["target"] = ( diff --git a/src/devices/hotwater.py b/src/devices/hotwater.py index b6985b2..d4774c8 100644 --- a/src/devices/hotwater.py +++ b/src/devices/hotwater.py @@ -33,15 +33,16 @@ async def get_mode(self, device: Device): """ state = None final = None + device_name = device.ha_name try: data = self.session.data.products[device.hive_id] state = data["state"]["mode"] if state == "BOOST": - state = data["props"]["previous"]["mode"] + state = self._get_product_state(device, "props", "previous", "mode") final = HIVETOHA[self.hotwater_type].get(state, state) except KeyError as e: - _LOGGER.error(e) + _LOGGER.error("get_mode - KeyError getting mode for %s: %s", device_name, e) return final @@ -77,7 +78,8 @@ async def get_state(self, device: Device): snan = self.session.helper.get_schedule_nnl( data["state"]["schedule"] ) - state = snan["now"]["value"]["status"] + if snan and "now" in snan: + state = snan["now"]["value"]["status"] final = HIVETOHA[self.hotwater_type].get(state, state) except KeyError as e: @@ -141,6 +143,12 @@ async def set_boost_off(self, device: Device): "set_boost_off - Setting hot water boost OFF for %s.", device.ha_name ) prev_mode = self._get_product_state(device, "props", "previous", "mode") + if prev_mode is None: + _LOGGER.warning( + "set_boost_off - Cannot determine previous mode for %s, skipping.", + device.ha_name, + ) + return False return await self._execute_state_change(device, mode=prev_mode) diff --git a/src/devices/light.py b/src/devices/light.py index b33637e..d48dc91 100644 --- a/src/devices/light.py +++ b/src/devices/light.py @@ -62,10 +62,10 @@ async def get_brightness(self, device: Device): try: data = self.session.data.products[device.hive_id] state = data["state"]["brightness"] - final = (state / 100) * 255 - except KeyError as e: + final = int((state / 100) * 255) + except (KeyError, TypeError) as e: _LOGGER.error( - "KeyError getting light brightness for %s: %s", device_name, str(e) + "Error getting light brightness for %s: %s", device_name, str(e) ) return final diff --git a/src/devices/sensor.py b/src/devices/sensor.py index 9f0a431..bed1ebb 100644 --- a/src/devices/sensor.py +++ b/src/devices/sensor.py @@ -4,12 +4,36 @@ from typing import Any from ..helper.compat_aliases import SensorCompatMixin -from ..helper.const import HIVE_TYPES, HIVETOHA, sensor_commands +from ..helper.const import HIVE_TYPES, HIVETOHA from ..helper.device_handler_base import BaseDeviceHandler from ..helper.hivedataclasses import Device _LOGGER = logging.getLogger(__name__) +sensor_commands = { + "SMOKE_CO": lambda s, d: s.session.hub.get_smoke_status(d), + "DOG_BARK": lambda s, d: s.session.hub.get_dog_bark_status(d), + "GLASS_BREAK": lambda s, d: s.session.hub.get_glass_break_status(d), + "Current_Temperature": lambda s, d: s.session.heating.get_current_temperature(d), + "Heating_Current_Temperature": lambda s, d: ( + s.session.heating.get_current_temperature(d) + ), + "Heating_Target_Temperature": lambda s, d: s.session.heating.get_target_temperature( + d + ), + "Heating_State": lambda s, d: s.session.heating.get_state(d), + "Heating_Mode": lambda s, d: s.session.heating.get_mode(d), + "Heating_Boost": lambda s, d: s.session.heating.get_boost_status(d), + "Hotwater_State": lambda s, d: s.session.hotwater.get_state(d), + "Hotwater_Mode": lambda s, d: s.session.hotwater.get_mode(d), + "Hotwater_Boost": lambda s, d: s.session.hotwater.get_boost(d), + "Battery": lambda s, d: s.session.attr.get_battery(d.device_id), + "Mode": lambda s, d: s.session.attr.get_mode(d.hive_id), + "Availability": lambda s, d: s.online(d), + "Connectivity": lambda s, d: s.online(d), + "Power": lambda s, d: s.session.switch.get_power_usage(d), +} + class HiveSensor(BaseDeviceHandler): """Hive Sensor Code.""" @@ -133,7 +157,7 @@ async def get_sensor(self, device: Device): device.device_data = props device.parent_device = data.get("parent", None) elif device.hive_type in HIVE_TYPES["Sensor"]: - data = self.session.data.devices.get(device.hive_id, {}) + data = self.session.data.devices.get(device.device_id, {}) device.status = {"state": await self.get_state(device)} props = data.get("props") or {} props["online"] = online diff --git a/src/helper/compat_aliases.py b/src/helper/compat_aliases.py index b1fe79e..8cdfc0e 100644 --- a/src/helper/compat_aliases.py +++ b/src/helper/compat_aliases.py @@ -7,6 +7,7 @@ from __future__ import annotations +from datetime import timedelta from typing import Any from .hivedataclasses import Device @@ -138,6 +139,7 @@ async def updateData(self, device: Device): # pylint: disable=invalid-name """Backwards-compatible alias for update_data.""" return await self.update_data(device) # type: ignore[attr-defined] - async def updateInterval(self, new_interval: int): # pylint: disable=invalid-name,unused-argument + async def updateInterval(self, new_interval: int): # pylint: disable=invalid-name """Backwards-compatible alias for Home Assistant Scan Interval.""" + self.config.scan_interval = timedelta(seconds=new_interval) # type: ignore[attr-defined] return True diff --git a/src/helper/const.py b/src/helper/const.py index 7c27601..59d1c7d 100644 --- a/src/helper/const.py +++ b/src/helper/const.py @@ -60,29 +60,6 @@ "Sensor": ["motionsensor", "contactsensor"], "Switch": ["activeplug"], } -sensor_commands = { - "SMOKE_CO": lambda s, d: s.session.hub.get_smoke_status(d), - "DOG_BARK": lambda s, d: s.session.hub.get_dog_bark_status(d), - "GLASS_BREAK": lambda s, d: s.session.hub.get_glass_break_status(d), - "Current_Temperature": lambda s, d: s.session.heating.get_current_temperature(d), - "Heating_Current_Temperature": lambda s, d: ( - s.session.heating.get_current_temperature(d) - ), - "Heating_Target_Temperature": lambda s, d: s.session.heating.get_target_temperature( - d - ), - "Heating_State": lambda s, d: s.session.heating.get_state(d), - "Heating_Mode": lambda s, d: s.session.heating.get_mode(d), - "Heating_Boost": lambda s, d: s.session.heating.get_boost_status(d), - "Hotwater_State": lambda s, d: s.session.hotwater.get_state(d), - "Hotwater_Mode": lambda s, d: s.session.hotwater.get_mode(d), - "Hotwater_Boost": lambda s, d: s.session.hotwater.get_boost(d), - "Battery": lambda s, d: s.session.attr.get_battery(d.device_id), - "Mode": lambda s, d: s.session.attr.get_mode(d.hive_id), - "Availability": lambda s, d: s.online(d), - "Connectivity": lambda s, d: s.online(d), - "Power": lambda s, d: s.session.switch.get_power_usage(d), -} PRODUCTS = { "sense": [ diff --git a/src/helper/device_attributes.py b/src/helper/device_attributes.py index b703c21..d218a36 100644 --- a/src/helper/device_attributes.py +++ b/src/helper/device_attributes.py @@ -3,8 +3,6 @@ import logging from typing import Any -from .const import HIVETOHA - _LOGGER = logging.getLogger(__name__) @@ -18,7 +16,6 @@ def __init__(self, session: Any = None): session (object, optional): Session to interact with hive account. Defaults to None. """ self.session = session - self.type = "Attribute" async def state_attributes(self, n_id: str, _type: str): """Get HA State Attributes. @@ -70,17 +67,12 @@ async def get_mode(self, n_id: str): Returns: str: The mode of the device. """ - state = None - final = None - try: data = self.session.data.products[n_id] - state = data["state"]["mode"] - final = HIVETOHA[self.type].get(state, state) + return data["state"]["mode"] except KeyError as e: _LOGGER.error(e) - - return final + return None async def get_battery(self, n_id: str): """Get device battery level. @@ -98,7 +90,6 @@ async def get_battery(self, n_id: str): data = self.session.data.devices[n_id] state = data["props"]["battery"] final = state - await self.session.helper.error_check(n_id, self.type, state) except KeyError as e: _LOGGER.error(e) diff --git a/src/helper/hive_exceptions.py b/src/helper/hive_exceptions.py index 4881d12..2a69441 100644 --- a/src/helper/hive_exceptions.py +++ b/src/helper/hive_exceptions.py @@ -19,14 +19,22 @@ class NoApiToken(Exception): """ -class HiveApiError(Exception): - """Api error. +class HiveError(Exception): + """Common base class for all Hive-specific exceptions. Args: Exception (object): Exception object to invoke """ +class HiveApiError(HiveError): + """Api error. + + Args: + HiveError (object): Parent Hive error class + """ + + class HiveAuthError(HiveApiError): """Auth error (401/403) — token may be expired or invalid. @@ -35,65 +43,81 @@ class HiveAuthError(HiveApiError): """ -class HiveRefreshTokenExpired(Exception): +class HiveRefreshTokenExpired(HiveApiError): """Refresh token expired. Args: - Exception (object): Exception object to invoke + HiveApiError (object): Parent API error class """ -class HiveReauthRequired(Exception): - """Re-Authentication is required. +class HiveFailedToRefreshTokens(HiveApiError): + """Raise invalid refresh tokens. Args: - Exception (object): Exception object to invoke + HiveApiError (object): Parent API error class + """ + + +class HiveConfigurationError(HiveError): + """Base class for configuration-related errors. + + Args: + HiveError (object): Parent Hive error class """ -class HiveUnknownConfiguration(Exception): +class HiveUnknownConfiguration(HiveConfigurationError): """Unknown Hive Configuration. Args: - Exception (object): Exception object to invoke + HiveConfigurationError (object): Parent configuration error class """ -class HiveInvalidUsername(Exception): - """Raise invalid Username. +class HiveInvalidDeviceAuthentication(HiveConfigurationError): + """Raise invalid device authentication. Args: - Exception (object): Exception object to invoke + HiveConfigurationError (object): Parent configuration error class """ -class HiveInvalidPassword(Exception): - """Raise invalid password. +class HiveAuthCredentialError(HiveError): + """Base class for authentication credential errors. Args: - Exception (object): Exception object to invoke + HiveError (object): Parent Hive error class """ -class HiveInvalid2FACode(Exception): - """Raise invalid 2FA code. +class HiveInvalidUsername(HiveAuthCredentialError): + """Raise invalid Username. Args: - Exception (object): Exception object to invoke + HiveAuthCredentialError (object): Parent credential error class """ -class HiveInvalidDeviceAuthentication(Exception): - """Raise invalid device authentication. +class HiveInvalidPassword(HiveAuthCredentialError): + """Raise invalid password. Args: - Exception (object): Exception object to invoke + HiveAuthCredentialError (object): Parent credential error class """ -class HiveFailedToRefreshTokens(Exception): - """Raise invalid refresh tokens. +class HiveInvalid2FACode(HiveAuthCredentialError): + """Raise invalid 2FA code. Args: - Exception (object): Exception object to invoke + HiveAuthCredentialError (object): Parent credential error class + """ + + +class HiveReauthRequired(HiveError): + """Re-Authentication is required. + + Args: + HiveError (object): Parent Hive error class """ diff --git a/src/helper/hive_helper.py b/src/helper/hive_helper.py index d052f73..8768f1a 100644 --- a/src/helper/hive_helper.py +++ b/src/helper/hive_helper.py @@ -8,7 +8,6 @@ from typing import Any from .const import HIVE_TYPES -from .hivedataclasses import Device _LOGGER = logging.getLogger(__name__) @@ -26,7 +25,6 @@ def epoch_time(date_time: Any, pattern: str, action: str) -> Any: Converted value, or ``None`` if *action* is unrecognised. """ if action == "to_epoch": - pattern = "%d.%m.%Y %H:%M:%S" return int(time.mktime(time.strptime(str(date_time), pattern))) if action == "from_epoch": return datetime.datetime.fromtimestamp(int(date_time)).strftime(pattern) @@ -256,8 +254,9 @@ def get_schedule_nnl(self, hive_api_schedule: dict): # pylint: disable=too-many if slot_time_date_dt <= date_time_now: slot_time_date_dt = slot_time_date_dt + datetime.timedelta(days=7) - current_slot_custom["Start_DateTime"] = slot_time_date_dt - full_schedule_list.append(current_slot_custom) + slot_copy = dict(current_slot_custom) + slot_copy["Start_DateTime"] = slot_time_date_dt + full_schedule_list.append(slot_copy) fsl_sorted = sorted( full_schedule_list, @@ -297,19 +296,6 @@ def get_schedule_nnl(self, hive_api_schedule: dict): # pylint: disable=too-many return schedule_now_and_next - def get_heat_on_demand_device(self, device: Device): - """Use TRV device to get the linked thermostat device. - - Args: - device ([dictionary]): [The TRV device to lookup.] - - Returns: - [dictionary]: [Gets the thermostat device linked to TRV.] - """ - trv = self.session.data.products.get(device["HiveID"]) - thermostat = self.session.data.products.get(trv["state"]["zone"]) - return thermostat - def sanitize_payload(self, payload: dict[str, Any]) -> dict[str, Any]: """Return a copy of payload with sensitive values masked for logs.""" diff --git a/src/helper/hivedataclasses.py b/src/helper/hivedataclasses.py index 19cf3c4..7fe7891 100644 --- a/src/helper/hivedataclasses.py +++ b/src/helper/hivedataclasses.py @@ -102,12 +102,12 @@ class SessionTokens: class SessionConfig: """Typed container for session configuration state.""" - battery: list = field(default_factory=list) + battery: set = field(default_factory=set) error_list: dict = field(default_factory=dict) file: bool = False home_id: str | None = None last_update: datetime = field(default_factory=datetime.now) - mode: list = field(default_factory=list) + mode: set = field(default_factory=set) scan_interval: timedelta = field(default_factory=lambda: _SCAN_INTERVAL) user_id: str | None = None username: str | None = None diff --git a/src/helper/map.py b/src/helper/map.py index b4bd8ea..c6dc468 100644 --- a/src/helper/map.py +++ b/src/helper/map.py @@ -10,6 +10,16 @@ class Map(dict): dict (dict): dictionary to map. """ - __getattr__ = dict.get + def __getattr__(self, key): + try: + return self[key] + except KeyError: + raise AttributeError(f"Map has no key {key!r}") from None + __setattr__ = dict.__setitem__ - __delattr__ = dict.__delitem__ + + def __delattr__(self, key): + try: + del self[key] + except KeyError: + raise AttributeError(f"Map has no key {key!r}") from None diff --git a/src/session/__init__.py b/src/session/__init__.py index a240306..bb9f192 100644 --- a/src/session/__init__.py +++ b/src/session/__init__.py @@ -50,6 +50,7 @@ def __init__( username=username, password=password, ) + self._owns_websession = websession is None self.api = API(hive_session=self, websession=websession) self.helper = HiveHelper(self) self.attr = HiveAttributes(self) @@ -75,9 +76,10 @@ def __init__( self._update_task: asyncio.Task | None = None async def close(self) -> None: - """Close the underlying aiohttp ClientSession.""" - if not self.api.websession.closed: - await self.api.websession.close() + """Close the underlying aiohttp ClientSession if we own it.""" + websession = self.api.websession + if self._owns_websession and websession is not None and not websession.closed: + await websession.close() async def __aenter__(self): return self diff --git a/src/session/auth.py b/src/session/auth.py index 1d63f09..4f9d67d 100644 --- a/src/session/auth.py +++ b/src/session/auth.py @@ -70,10 +70,11 @@ async def _retry_with_backoff( raise except Exception as err: # pylint: disable=broad-except last_err = err - exc_type = reraise_as or ( - type(last_err) if last_err is not None else RuntimeError - ) - raise exc_type() from last_err # pylint: disable=broad-exception-raised + if reraise_as is not None: + raise reraise_as() from last_err + if last_err is not None: + raise last_err + raise RuntimeError("Retry attempts exhausted without capturing an error") async def update_tokens(self, tokens: dict, update_expiry_time: bool = True): """Update session tokens. @@ -91,17 +92,23 @@ async def update_tokens(self, tokens: dict, update_expiry_time: bool = True): ) if "AuthenticationResult" in tokens: data = tokens.get("AuthenticationResult") or {} - self.tokens.token_data.update({"token": data["IdToken"]}) + if "IdToken" in data: + self.tokens.token_data.update({"token": data["IdToken"]}) if "RefreshToken" in data: self.tokens.token_data.update({"refreshToken": data["RefreshToken"]}) - self.tokens.token_data.update({"accessToken": data["AccessToken"]}) + if "AccessToken" in data: + self.tokens.token_data.update({"accessToken": data["AccessToken"]}) if update_expiry_time: self.tokens.token_created = datetime.now() elif "token" in tokens: data = tokens self.tokens.token_data.update({"token": data["token"]}) - self.tokens.token_data.update({"refreshToken": data["refreshToken"]}) - self.tokens.token_data.update({"accessToken": data["accessToken"]}) + if "refreshToken" in data: + self.tokens.token_data.update({"refreshToken": data["refreshToken"]}) + if "accessToken" in data: + self.tokens.token_data.update({"accessToken": data["accessToken"]}) + if update_expiry_time: + self.tokens.token_created = datetime.now() if "ExpiresIn" in data: self.tokens.token_expiry = timedelta(seconds=data["ExpiresIn"]) @@ -335,7 +342,7 @@ async def hive_refresh_tokens(self, force_refresh: bool = False): ) try: result = await self.auth.refresh_token( - self.tokens.token_data["refreshToken"] + self.tokens.token_data.get("refreshToken") ) if result and "AuthenticationResult" in result: diff --git a/src/session/discovery.py b/src/session/discovery.py index 6ff691c..aa13ad8 100644 --- a/src/session/discovery.py +++ b/src/session/discovery.py @@ -8,7 +8,7 @@ from typing import Any from ..helper.const import DEVICES, EXPECTED_DEVICE_DATA_LENGTH, HIVE_TYPES, PRODUCTS -from ..helper.hive_exceptions import HiveReauthRequired, HiveUnknownConfiguration +from ..helper.hive_exceptions import HiveUnknownConfiguration from ..helper.hivedataclasses import Device _DATA_DIR = Path(__file__).parent.parent / "data" @@ -148,10 +148,15 @@ async def start_session(self, config: dict | None = None): self.auth.password = config["password"] if "device_data" in config and not self.config.file: - self.auth.device_group_key = config["device_data"][0] - self.auth.device_key = config["device_data"][1] - self.auth.device_password = config["device_data"][2] device_data = config["device_data"] + if len(device_data) < EXPECTED_DEVICE_DATA_LENGTH: + raise HiveUnknownConfiguration( + "device_data must contain device_group_key, " + "device_key and device_password" + ) + self.auth.device_group_key = device_data[0] + self.auth.device_key = device_data[1] + self.auth.device_password = device_data[2] if len(device_data) > EXPECTED_DEVICE_DATA_LENGTH: token_created = device_data[3] if token_created: @@ -163,10 +168,8 @@ async def start_session(self, config: dict | None = None): await self.get_devices("No_ID") # type: ignore[attr-defined] if not self.data.devices or not self.data.products: - _LOGGER.error( - "No devices or products returned from Hive API, reauthentication required." - ) - raise HiveReauthRequired + _LOGGER.error("No devices or products returned from Hive API.") + raise HiveUnknownConfiguration return await self.create_devices() @@ -237,7 +240,7 @@ async def create_devices( # noqa: PLR0912, PLR0915 ) if device_type in hive_type: - self.config.battery.append(d["id"]) + self.config.battery.add(d.get("id", a_device)) _LOGGER.debug( "create_devices - Added device %s to battery monitoring list", device_name, @@ -313,7 +316,7 @@ async def create_devices( # noqa: PLR0912, PLR0915 ) if product_type in hive_type: - self.config.mode.append(p["id"]) + self.config.mode.add(p.get("id", a_product)) _LOGGER.debug( "create_devices - Added product %s to mode list", product_name ) diff --git a/src/session/polling.py b/src/session/polling.py index 27341fc..e9e9784 100644 --- a/src/session/polling.py +++ b/src/session/polling.py @@ -144,7 +144,17 @@ async def get_devices(self, _n_id: str): # pylint: disable=too-many-locals,too- api_call_start = time.monotonic() try: api_resp_d = await self.api.get_all() + api_call_duration = time.monotonic() - api_call_start + if api_call_duration > self._slow_poll_threshold: + _LOGGER.debug( + "get_devices - Hive API response took %.1fs — marking poll as slow.", + api_call_duration, + ) + self._last_poll_slow = True + else: + self._last_poll_slow = False except HiveAuthError: + self._last_poll_slow = False _LOGGER.warning( "Auth error (401/403) after token refresh, " "falling back to full device re-login." @@ -154,15 +164,8 @@ async def get_devices(self, _n_id: str): # pylint: disable=too-many-locals,too- self.api.get_all, reraise_as=HiveReauthRequired, ) - api_call_duration = time.monotonic() - api_call_start - if api_call_duration > self._slow_poll_threshold: - _LOGGER.debug( - "get_devices - Hive API response took %.1fs — marking poll as slow.", - api_call_duration, - ) - self._last_poll_slow = True - else: - self._last_poll_slow = False + if api_resp_d is None: + return get_nodes_successful if not str(api_resp_d["original"]).startswith("2"): raise HTTPException if api_resp_d["parsed"] is None: @@ -178,7 +181,7 @@ async def get_devices(self, _n_id: str): # pylint: disable=too-many-locals,too- for hive_type_key in api_resp_p: if hive_type_key == "user": self.data.user = api_resp_p[hive_type_key] - self.config.user_id = api_resp_p[hive_type_key]["id"] + self.config.user_id = api_resp_p[hive_type_key].get("id") if hive_type_key == "products": for a_product in api_resp_p[hive_type_key]: tmp_products.update({a_product["id"]: a_product}) @@ -189,7 +192,11 @@ async def get_devices(self, _n_id: str): # pylint: disable=too-many-locals,too- for a_action in api_resp_p[hive_type_key]: tmp_actions.update({a_action["id"]: a_action}) if hive_type_key == "homes": - self.config.home_id = api_resp_p[hive_type_key]["homes"][0]["id"] + homes_data = api_resp_p[hive_type_key] + if isinstance(homes_data, dict): + homes_list = homes_data.get("homes") or [] + if homes_list: + self.config.home_id = homes_list[0]["id"] _LOGGER.debug( "get_devices - API returned %d products, %d devices, %d actions.", @@ -221,6 +228,7 @@ async def get_devices(self, _n_id: str): # pylint: disable=too-many-locals,too- HiveApiError, ConnectionError, HTTPException, + KeyError, ) as err: _LOGGER.error("Failed to fetch devices: %s", err) self.config.last_update = ( diff --git a/tests/module/test_hotwater.py b/tests/module/test_hotwater.py index b8f7433..f3ac68f 100644 --- a/tests/module/test_hotwater.py +++ b/tests/module/test_hotwater.py @@ -188,14 +188,6 @@ async def test_boosting_calls_execute_with_prev_mode(self): class TestGetScheduleNowNextLater: """Tests for WaterHeater.get_schedule_now_next_later.""" - async def test_schedule_mode_returns_nnl(self): - """SCHEDULE mode with a schedule returns the now/next/later dict.""" - hw = _make_hotwater( - {"hw-1": {"state": {"mode": _SCHEDULE_MODE, "schedule": {}}}} - ) - result = await hw.get_schedule_now_next_later(_make_device()) - assert result is not None - async def test_non_schedule_returns_none(self): """Non-SCHEDULE mode returns None.""" hw = _make_hotwater({"hw-1": {"state": {"mode": _ON_MODE}}}) diff --git a/tests/module/test_hub.py b/tests/module/test_hub.py index f3d5f33..06b61d7 100644 --- a/tests/module/test_hub.py +++ b/tests/module/test_hub.py @@ -128,11 +128,13 @@ async def test_context_manager_aenter_returns_self(self): assert hive is not None async def test_close_calls_websession_close(self): - """__aexit__ closes the underlying aiohttp websession.""" + """__aexit__ closes the lazily created aiohttp websession.""" async with Hive( username="test@example.com", password="pass", # pragma: allowlist secret ) as hive: - ws = hive.api.websession + # The websession is created lazily, on first use inside the loop. + assert hive.api.websession is None + ws = hive.api._get_websession() # After context exit the session should be closed assert ws.closed diff --git a/tests/module/test_light.py b/tests/module/test_light.py index c28d82e..8910759 100644 --- a/tests/module/test_light.py +++ b/tests/module/test_light.py @@ -1,7 +1,6 @@ """Tests for Light / HiveLight and LightColorHandler.""" # pylint: disable=too-few-public-methods -import colorsys from unittest.mock import AsyncMock, MagicMock from apyhiveapi.devices.light import Light @@ -10,9 +9,9 @@ _HTTP_OK = 200 _BRIGHTNESS_PCT = 50 -_BRIGHTNESS_HA = (_BRIGHTNESS_PCT / 100) * 255 +_BRIGHTNESS_HA = int((_BRIGHTNESS_PCT / 100) * 255) _BRIGHTNESS_RAW = 80 -_BRIGHTNESS_CONVERTED = (_BRIGHTNESS_RAW / 100) * 255 +_BRIGHTNESS_CONVERTED = int((_BRIGHTNESS_RAW / 100) * 255) _BRIGHTNESS_SET = 128 _COLOR_TEMP_KELVIN = 4000 _COLOR_TEMP_MIRED = round((1 / _COLOR_TEMP_KELVIN) * 1_000_000) @@ -23,10 +22,7 @@ _HSV_HUE = 120 _HSV_SAT = 100 _HSV_VAL = 100 -_COLOR_TUPLE = tuple( - int(i * 255) - for i in colorsys.hsv_to_rgb(_HSV_HUE / 360, _HSV_SAT / 100, _HSV_VAL / 100) -) +_COLOR_TUPLE = (_HSV_HUE, _HSV_SAT) def _make_light(products=None, devices=None): @@ -280,7 +276,7 @@ async def test_get_color_temp_missing_returns_none(self): assert await light.get_color_temp(_make_device()) is None async def test_get_color_returns_rgb_tuple(self): - """get_color returns an (R, G, B) tuple in 0–255 range.""" + """get_color returns an (hue, saturation) 2-tuple for HA hs_color.""" light = _make_light( { "light-1": { diff --git a/tests/module/test_sensor.py b/tests/module/test_sensor.py index f8cb038..ac0bfa5 100644 --- a/tests/module/test_sensor.py +++ b/tests/module/test_sensor.py @@ -65,19 +65,6 @@ async def test_contactsensor_closed_returns_false(self): ) assert await sensor.get_state(_make_device()) is False - async def test_motionsensor_returns_motion_status(self): - """get_state returns the motion status boolean for a motionsensor.""" - sensor = _make_sensor( - products={ - "sens-1": { - "type": "motionsensor", - "props": {"motion": {"status": True}}, - } - } - ) - result = await sensor.get_state(_make_device(hive_type="motionsensor")) - assert result is True - async def test_missing_key_returns_none(self): """get_state returns None when the hive_id is not in products.""" sensor = _make_sensor() diff --git a/tests/module/test_session.py b/tests/module/test_session.py index 3de0c5a..8713aa3 100644 --- a/tests/module/test_session.py +++ b/tests/module/test_session.py @@ -178,4 +178,5 @@ async def always_fails(): with patch("asyncio.sleep", new=AsyncMock()): with pytest.raises(Exception) as exc_info: await hive._retry_with_backoff(always_fails, delays=(0, 0)) # pylint: disable=protected-access - assert "permanent failure" in str(exc_info.value.__cause__) + # The original exception instance propagates unchanged. + assert "permanent failure" in str(exc_info.value) diff --git a/tests/module/test_session_discovery.py b/tests/module/test_session_discovery.py index 6bc07ea..cfb67fa 100644 --- a/tests/module/test_session_discovery.py +++ b/tests/module/test_session_discovery.py @@ -5,7 +5,6 @@ import pytest from apyhiveapi.helper.hive_exceptions import ( - HiveReauthRequired, HiveUnknownConfiguration, ) from apyhiveapi.helper.hivedataclasses import SessionConfig @@ -103,13 +102,6 @@ async def test_file_mode_username_enables_file_and_succeeds(self): assert s.config.file is True s.get_devices.assert_called_once() - async def test_empty_devices_after_get_devices_raises_reauth(self): - """start_session raises HiveReauthRequired when data.devices is empty post-poll.""" - s = _make_stub(has_data=False) - s.config.file = True - with pytest.raises(HiveReauthRequired): - await s.start_session({}) - async def test_no_tokens_in_non_file_config_raises_unknown_configuration(self): """Non-file mode config without tokens raises HiveUnknownConfiguration.""" s = _make_stub() diff --git a/tests/unit/test_action_extended.py b/tests/unit/test_action.py similarity index 100% rename from tests/unit/test_action_extended.py rename to tests/unit/test_action.py diff --git a/tests/unit/test_attributes.py b/tests/unit/test_attributes.py index ab2a075..d39eced 100644 --- a/tests/unit/test_attributes.py +++ b/tests/unit/test_attributes.py @@ -29,8 +29,8 @@ def _make_attrs(devices=None, products=None, battery=None, mode=None): } ) config = SessionConfig() - config.battery = battery or [] - config.mode = mode or [] + config.battery = set(battery) if battery else set() + config.mode = set(mode) if mode else set() session.config = config session.helper = MagicMock() session.helper.error_check = AsyncMock() @@ -92,13 +92,11 @@ async def test_missing_device_returns_none(self): assert await attrs.get_battery("nope") is None @pytest.mark.asyncio - async def test_calls_error_check(self): - """error_check should be called once with the device id, type, and battery level.""" + async def test_returns_correct_battery_level(self): + """Battery level is returned as the raw integer from props.""" attrs = _make_attrs(devices={"d1": {"props": {"battery": BATTERY_50}}}) - await attrs.get_battery("d1") - attrs.session.helper.error_check.assert_awaited_once_with( - "d1", "Attribute", BATTERY_50 - ) + result = await attrs.get_battery("d1") + assert result == BATTERY_50 @pytest.mark.asyncio async def test_battery_zero_returned(self): @@ -136,18 +134,18 @@ async def test_manual_mode_passes_through(self): assert result == "MANUAL" @pytest.mark.asyncio - async def test_true_value_maps_to_online(self): - """HIVETOHA["Attribute"][True] == "Online".""" + async def test_true_value_returned_directly(self): + """get_mode returns the raw mode value (True) without HIVETOHA translation.""" attrs = _make_attrs(products={"p1": {"state": {"mode": True}}}) result = await attrs.get_mode("p1") - assert result == "Online" + assert result is True @pytest.mark.asyncio - async def test_false_value_maps_to_offline(self): - """HIVETOHA["Attribute"][False] == "Offline".""" + async def test_false_value_returned_directly(self): + """get_mode returns the raw mode value (False) without HIVETOHA translation.""" attrs = _make_attrs(products={"p1": {"state": {"mode": False}}}) result = await attrs.get_mode("p1") - assert result == "Offline" + assert result is False # --------------------------------------------------------------------------- @@ -259,3 +257,31 @@ async def test_all_attributes_combined(self): assert result["available"] is True assert result["battery"] == "90%" assert result["mode"] == "SCHEDULE" + + +class TestGetModeNoOp: + """HiveAttributes.get_mode must return mode string directly (no HIVETOHA lookup).""" + + async def test_get_mode_returns_mode_string_unchanged(self): + """get_mode returns the raw mode string, not a boolean-keyed HIVETOHA lookup.""" + session = MagicMock() + session.data.products = {"p1": {"state": {"mode": "SCHEDULE"}}} + attr = HiveAttributes(session) + + result = await attr.get_mode("p1") + assert result == "SCHEDULE" + + +class TestGetBatteryNoDeadCall: + """HiveAttributes.get_battery must not call error_check with integer state.""" + + async def test_get_battery_returns_value_without_error_check_side_effects(self): + """get_battery returns the battery level; error_check is not called with int.""" + session = MagicMock() + session.data.devices = {"d1": {"props": {"battery": 85}}} + attr = HiveAttributes(session) + session.helper.error_check = AsyncMock() + + result = await attr.get_battery("d1") + assert result == 85 + session.helper.error_check.assert_not_called() diff --git a/tests/unit/test_boost_extended.py b/tests/unit/test_boost.py similarity index 100% rename from tests/unit/test_boost_extended.py rename to tests/unit/test_boost.py diff --git a/tests/unit/test_color_extended.py b/tests/unit/test_color.py similarity index 51% rename from tests/unit/test_color_extended.py rename to tests/unit/test_color.py index 6795724..dbe830b 100644 --- a/tests/unit/test_color_extended.py +++ b/tests/unit/test_color.py @@ -99,3 +99,79 @@ async def test_keyerror_on_missing_product_returns_none(self): result = await handler.get_max_color_temp(device) assert result is None + + +class TestZeroDivisionGuards: + """Colour-temperature methods must return None instead of raising ZeroDivisionError.""" + + async def test_get_min_color_temp_zero_returns_none(self): + """min colourTemperature == 0 must return None, not raise ZeroDivisionError. + + get_min_color_temp reads colourTemperature['max'] and divides by it, + so 'max' must be 0 to trigger ZeroDivisionError. + """ + session = _make_session( + products={ + "light-1": {"props": {"colourTemperature": {"max": 0, "min": 153}}} + } + ) + h = _make_handler(session) + device = _make_device() + result = await h.get_min_color_temp(device) + assert result is None + + async def test_get_max_color_temp_zero_returns_none(self): + """max colourTemperature == 0 must return None, not raise ZeroDivisionError. + + get_max_color_temp reads colourTemperature['min'] and divides by it, + so 'min' must be 0 to trigger ZeroDivisionError. + """ + session = _make_session( + products={ + "light-1": {"props": {"colourTemperature": {"max": 500, "min": 0}}} + } + ) + h = _make_handler(session) + device = _make_device() + result = await h.get_max_color_temp(device) + assert result is None + + async def test_get_color_temp_zero_returns_none(self): + """state colourTemperature == 0 must return None, not raise ZeroDivisionError.""" + session = _make_session( + products={"light-1": {"state": {"colourTemperature": 0}}} + ) + h = _make_handler(session) + device = _make_device() + result = await h.get_color_temp(device) + assert result is None + + +# --------------------------------------------------------------------------- +# get_color — must return HS 2-tuple for HA hs_color, not RGB 3-tuple +# --------------------------------------------------------------------------- + + +class TestGetColorReturnsHSTuple: + """get_color must return (hue_degrees, saturation_percent) 2-tuple for HA hs_color.""" + + async def test_get_color_returns_two_tuple(self): + """get_color returns a 2-tuple (not 3-tuple).""" + session = _make_session( + {"light-1": {"state": {"hue": 120, "saturation": 75, "value": 100}}} + ) + h = _make_handler(session) + device = _make_device() + result = await h.get_color(device) + assert result is not None + assert len(result) == 2, f"Expected 2-tuple (hue, sat), got {result!r}" + + async def test_get_color_returns_correct_hue_and_saturation(self): + """get_color returns (hue, saturation) values matching API data.""" + session = _make_session( + {"light-1": {"state": {"hue": 180, "saturation": 50, "value": 80}}} + ) + h = _make_handler(session) + device = _make_device() + result = await h.get_color(device) + assert result == (180, 50) diff --git a/tests/unit/test_compat_aliases.py b/tests/unit/test_compat_aliases.py index 1c066e7..8cc255f 100644 --- a/tests/unit/test_compat_aliases.py +++ b/tests/unit/test_compat_aliases.py @@ -5,13 +5,15 @@ from unittest.mock import AsyncMock from apyhiveapi.helper.compat_aliases import ( + ActionCompatMixin, HeatingCompatMixin, LightCompatMixin, + SensorCompatMixin, SessionCompatMixin, SwitchCompatMixin, WaterHeaterCompatMixin, ) -from apyhiveapi.helper.hivedataclasses import Device +from apyhiveapi.helper.hivedataclasses import Device, SessionConfig def _make_device(): @@ -319,13 +321,139 @@ class Stub(SessionCompatMixin): assert s.deviceList is s.device_list async def test_update_interval_returns_true(self): - """updateInterval always returns True (deprecated no-op).""" + """updateInterval returns True and updates config.scan_interval.""" class Stub(SessionCompatMixin): """Stub for updateInterval test.""" device_list = {} + def __init__(self): + self.config = SessionConfig() + s = Stub() result = await s.updateInterval(60) assert result is True + + +# --------------------------------------------------------------------------- +# SessionCompatMixin.updateInterval — bug fix tests +# --------------------------------------------------------------------------- + + +def _make_concrete_session(): + """Return a minimal SessionCompatMixin subclass with a real SessionConfig.""" + + class ConcreteSession(SessionCompatMixin): + """Minimal concrete SessionCompatMixin for updateInterval tests.""" + + def __init__(self): + self.config = SessionConfig() + self.device_list = {} + + async def start_session(self, config=None): # pylint: disable=unused-argument + """Stub.""" + + async def update_data(self, device): # pylint: disable=unused-argument + """Stub.""" + + return ConcreteSession() + + +class TestSessionCompatMixinUpdateInterval: + """updateInterval must actually update config.scan_interval.""" + + async def test_update_interval_sets_scan_interval(self): + """updateInterval(300) must set self.config.scan_interval to timedelta(seconds=300).""" + from datetime import timedelta + + session = _make_concrete_session() + await session.updateInterval(300) + assert session.config.scan_interval == timedelta(seconds=300) + + async def test_update_interval_returns_true(self): + """updateInterval must return True on success.""" + session = _make_concrete_session() + result = await session.updateInterval(60) + assert result is True + + +# --------------------------------------------------------------------------- +# Migrated from test_compat_aliases_extended.py +# --------------------------------------------------------------------------- + + +def _make_action_device(hive_type="action", ha_type="switch"): + return Device( + hive_id="h1", + hive_name="Test", + hive_type=hive_type, + ha_type=ha_type, + device_id="d1", + device_name="Test", + device_data={}, + ) + + +class TestSensorCompatMixin: + """CamelCase alias smoke tests for SensorCompatMixin.""" + + async def test_get_sensor_delegates(self): + """getSensor delegates to get_sensor and returns its result.""" + + class Stub(SensorCompatMixin): + """Stub with mocked get_sensor.""" + + get_sensor = AsyncMock(return_value="sensor_result") + + s = Stub() + d = _make_action_device(hive_type="motionsensor", ha_type="binary_sensor") + result = await s.getSensor(d) + s.get_sensor.assert_called_once_with(d) + assert result == "sensor_result" + + +class TestActionCompatMixin: + """CamelCase alias smoke tests for ActionCompatMixin.""" + + async def test_get_action_delegates(self): + """getAction delegates to get_action and returns its result.""" + + class Stub(ActionCompatMixin): + """Stub with mocked get_action.""" + + get_action = AsyncMock(return_value="action_result") + + s = Stub() + d = _make_action_device() + result = await s.getAction(d) + s.get_action.assert_called_once_with(d) + assert result == "action_result" + + async def test_set_status_on_delegates(self): + """setStatusOn delegates to set_status_on and returns its result.""" + + class Stub(ActionCompatMixin): + """Stub with mocked set_status_on.""" + + set_status_on = AsyncMock(return_value=True) + + s = Stub() + d = _make_action_device() + result = await s.setStatusOn(d) + s.set_status_on.assert_called_once_with(d) + assert result is True + + async def test_set_status_off_delegates(self): + """setStatusOff delegates to set_status_off and returns its result.""" + + class Stub(ActionCompatMixin): + """Stub with mocked set_status_off.""" + + set_status_off = AsyncMock(return_value=True) + + s = Stub() + d = _make_action_device() + result = await s.setStatusOff(d) + s.set_status_off.assert_called_once_with(d) + assert result is True diff --git a/tests/unit/test_compat_aliases_extended.py b/tests/unit/test_compat_aliases_extended.py deleted file mode 100644 index e103f48..0000000 --- a/tests/unit/test_compat_aliases_extended.py +++ /dev/null @@ -1,94 +0,0 @@ -"""Tests for SensorCompatMixin and ActionCompatMixin aliases (coverage gap fill).""" - -# pylint: disable=too-few-public-methods - -from unittest.mock import AsyncMock - -from apyhiveapi.helper.compat_aliases import ActionCompatMixin, SensorCompatMixin -from apyhiveapi.helper.hivedataclasses import Device - - -def _make_device(hive_type="action", ha_type="switch"): - return Device( - hive_id="h1", - hive_name="Test", - hive_type=hive_type, - ha_type=ha_type, - device_id="d1", - device_name="Test", - device_data={}, - ) - - -# --------------------------------------------------------------------------- -# SensorCompatMixin -# --------------------------------------------------------------------------- - - -class TestSensorCompatMixin: - """CamelCase alias smoke tests for SensorCompatMixin.""" - - async def test_get_sensor_delegates(self): - """getSensor delegates to get_sensor and returns its result.""" - - class Stub(SensorCompatMixin): - """Stub with mocked get_sensor.""" - - get_sensor = AsyncMock(return_value="sensor_result") - - s = Stub() - d = _make_device(hive_type="motionsensor", ha_type="binary_sensor") - result = await s.getSensor(d) - s.get_sensor.assert_called_once_with(d) - assert result == "sensor_result" - - -# --------------------------------------------------------------------------- -# ActionCompatMixin -# --------------------------------------------------------------------------- - - -class TestActionCompatMixin: - """CamelCase alias smoke tests for ActionCompatMixin.""" - - async def test_get_action_delegates(self): - """getAction delegates to get_action and returns its result.""" - - class Stub(ActionCompatMixin): - """Stub with mocked get_action.""" - - get_action = AsyncMock(return_value="action_result") - - s = Stub() - d = _make_device() - result = await s.getAction(d) - s.get_action.assert_called_once_with(d) - assert result == "action_result" - - async def test_set_status_on_delegates(self): - """setStatusOn delegates to set_status_on and returns its result.""" - - class Stub(ActionCompatMixin): - """Stub with mocked set_status_on.""" - - set_status_on = AsyncMock(return_value=True) - - s = Stub() - d = _make_device() - result = await s.setStatusOn(d) - s.set_status_on.assert_called_once_with(d) - assert result is True - - async def test_set_status_off_delegates(self): - """setStatusOff delegates to set_status_off and returns its result.""" - - class Stub(ActionCompatMixin): - """Stub with mocked set_status_off.""" - - set_status_off = AsyncMock(return_value=True) - - s = Stub() - d = _make_device() - result = await s.setStatusOff(d) - s.set_status_off.assert_called_once_with(d) - assert result is True diff --git a/tests/unit/test_dataclasses.py b/tests/unit/test_dataclasses.py index 5792fd3..e7f9f6c 100644 --- a/tests/unit/test_dataclasses.py +++ b/tests/unit/test_dataclasses.py @@ -112,12 +112,42 @@ def test_default_scan_interval_is_120s(self): c = SessionConfig() assert c.scan_interval == timedelta(seconds=120) - def test_default_battery_is_empty_list(self): - """Test battery defaults to empty list.""" + def test_default_battery_is_empty_set(self): + """Test battery defaults to empty set.""" c = SessionConfig() - assert c.battery == [] + assert c.battery == set() def test_username_stored(self): """Test username can be set and retrieved.""" c = SessionConfig(username="user@example.com") assert c.username == "user@example.com" + + +class TestSessionConfigCollectionTypes: + """battery and mode must be sets for O(1) membership checks.""" + + def test_battery_is_set(self): + """SessionConfig.battery is a set (not a list).""" + config = SessionConfig() + assert isinstance(config.battery, set), ( + f"Expected set, got {type(config.battery).__name__}" + ) + + def test_mode_is_set(self): + """SessionConfig.mode is a set (not a list).""" + config = SessionConfig() + assert isinstance(config.mode, set), ( + f"Expected set, got {type(config.mode).__name__}" + ) + + def test_battery_supports_membership_check(self): + """Can check membership in battery using 'in' after add.""" + config = SessionConfig() + config.battery.add("d1") + assert "d1" in config.battery + + def test_mode_supports_membership_check(self): + """Can check membership in mode using 'in' after add.""" + config = SessionConfig() + config.mode.add("p1") + assert "p1" in config.mode diff --git a/tests/unit/test_device_registration.py b/tests/unit/test_device_registration.py index 372eada..081b494 100644 --- a/tests/unit/test_device_registration.py +++ b/tests/unit/test_device_registration.py @@ -84,35 +84,35 @@ class StubDRM(DeviceRegistrationMixin): class TestGenerateHashDevice: async def test_returns_verifier_config_with_required_keys(self): stub = await _make_stub() - result = await stub.generate_hash_device("grp-key", "dev-key") + result = stub.generate_hash_device("grp-key", "dev-key") assert "PasswordVerifier" in result assert "Salt" in result async def test_password_verifier_is_non_empty_string(self): stub = await _make_stub() - result = await stub.generate_hash_device("grp-key", "dev-key") + result = stub.generate_hash_device("grp-key", "dev-key") assert isinstance(result["PasswordVerifier"], str) assert len(result["PasswordVerifier"]) > 0 async def test_salt_is_non_empty_string(self): stub = await _make_stub() - result = await stub.generate_hash_device("grp-key", "dev-key") + result = stub.generate_hash_device("grp-key", "dev-key") assert isinstance(result["Salt"], str) assert len(result["Salt"]) > 0 async def test_sets_device_password_on_self(self): stub = await _make_stub() stub.device_password = None - await stub.generate_hash_device("grp-key", "dev-key") + stub.generate_hash_device("grp-key", "dev-key") assert stub.device_password is not None assert isinstance(stub.device_password, str) assert len(stub.device_password) > 0 async def test_different_calls_produce_different_passwords(self): stub = await _make_stub() - await stub.generate_hash_device("grp-key", "dev-key") + stub.generate_hash_device("grp-key", "dev-key") password_first = stub.device_password - await stub.generate_hash_device("grp-key", "dev-key") + stub.generate_hash_device("grp-key", "dev-key") password_second = stub.device_password # Passwords are randomly generated — they should almost never match. # We check they are independently set strings (not None). @@ -121,9 +121,9 @@ async def test_different_calls_produce_different_passwords(self): async def test_different_device_keys_produce_different_verifiers(self): stub = await _make_stub() - result1 = await stub.generate_hash_device("grp-key", "dev-key-1") + result1 = stub.generate_hash_device("grp-key", "dev-key-1") verifier1 = result1["PasswordVerifier"] - result2 = await stub.generate_hash_device("grp-key", "dev-key-2") + result2 = stub.generate_hash_device("grp-key", "dev-key-2") verifier2 = result2["PasswordVerifier"] # Different keys + random passwords → almost certainly different verifiers # (at minimum the structure is valid for both) @@ -173,7 +173,7 @@ async def test_reflects_updated_device_key(self): class TestConfirmDevice: async def test_uses_hostname_when_no_device_name(self): stub = await _make_stub() - stub.generate_hash_device = AsyncMock( + stub.generate_hash_device = MagicMock( return_value={"PasswordVerifier": "pv", "Salt": "s"} ) with patch( @@ -191,7 +191,7 @@ async def test_uses_hostname_when_no_device_name(self): async def test_uses_provided_device_name(self): stub = await _make_stub() - stub.generate_hash_device = AsyncMock( + stub.generate_hash_device = MagicMock( return_value={"PasswordVerifier": "pv", "Salt": "s"} ) await stub.confirm_device("custom-name") @@ -202,27 +202,27 @@ async def test_uses_provided_device_name(self): async def test_returns_executor_result_on_success(self): stub = await _make_stub() - stub.generate_hash_device = AsyncMock( + stub.generate_hash_device = MagicMock( return_value={"PasswordVerifier": "pv", "Salt": "s"} ) stub.loop.run_in_executor.return_value = {"UserConfirmed": True} result = await stub.confirm_device("test-device") assert result == {"UserConfirmed": True} - async def test_not_authorized_raises_invalid_2fa(self): + async def test_not_authorized_raises_api_error(self): stub = await _make_stub() - stub.generate_hash_device = AsyncMock( + stub.generate_hash_device = MagicMock( return_value={"PasswordVerifier": "pv", "Salt": "s"} ) stub.loop.run_in_executor.side_effect = _named_client_error( "NotAuthorizedException" ) - with pytest.raises(HiveInvalid2FACode): + with pytest.raises(HiveApiError): await stub.confirm_device("name") async def test_code_mismatch_raises_invalid_2fa(self): stub = await _make_stub() - stub.generate_hash_device = AsyncMock( + stub.generate_hash_device = MagicMock( return_value={"PasswordVerifier": "pv", "Salt": "s"} ) stub.loop.run_in_executor.side_effect = _named_client_error( @@ -233,7 +233,7 @@ async def test_code_mismatch_raises_invalid_2fa(self): async def test_endpoint_error_raises_api_error(self): stub = await _make_stub() - stub.generate_hash_device = AsyncMock( + stub.generate_hash_device = MagicMock( return_value={"PasswordVerifier": "pv", "Salt": "s"} ) stub.loop.run_in_executor.side_effect = _endpoint_error() @@ -242,7 +242,7 @@ async def test_endpoint_error_raises_api_error(self): async def test_passes_access_token_to_executor(self): stub = await _make_stub(access_token="my-access-token") - stub.generate_hash_device = AsyncMock( + stub.generate_hash_device = MagicMock( return_value={"PasswordVerifier": "pv", "Salt": "s"} ) await stub.confirm_device("dev") @@ -252,7 +252,7 @@ async def test_passes_access_token_to_executor(self): async def test_passes_device_key_to_executor(self): stub = await _make_stub(device_key="my-device-key") - stub.generate_hash_device = AsyncMock( + stub.generate_hash_device = MagicMock( return_value={"PasswordVerifier": "pv", "Salt": "s"} ) await stub.confirm_device("dev") @@ -466,48 +466,24 @@ async def test_passes_access_token_and_device_key_to_executor(self): assert partial_fn.keywords["AccessToken"] == "forget-token" assert partial_fn.keywords["DeviceKey"] == "forget-key" - async def test_not_authorized_raises_invalid_2fa(self): + async def test_not_authorized_raises_api_error(self): stub = await _make_stub() stub.loop.run_in_executor.side_effect = _named_client_error( "NotAuthorizedException" ) - with pytest.raises(HiveInvalid2FACode): + with pytest.raises(HiveApiError): await stub.forget_device("acc-token", "dev-key") - async def test_other_client_error_does_not_raise(self): - """ClientErrors other than NotAuthorizedException are silently swallowed.""" + async def test_other_client_error_raises_api_error(self): + """All ClientErrors raise HiveApiError.""" stub = await _make_stub() stub.loop.run_in_executor.side_effect = _named_client_error("SomeOtherError") - # No exception raised — result will be None - result = await stub.forget_device("acc-token", "dev-key") - assert result is None + with pytest.raises(HiveApiError): + await stub.forget_device("acc-token", "dev-key") - async def test_endpoint_error_does_not_raise_api_error(self): - """EndpointConnectionError only raises HiveApiError if class name is - 'ResourceNotFoundException', which can never be true for an - EndpointConnectionError. The exception is therefore silently swallowed.""" + async def test_endpoint_error_raises_api_error(self): stub = await _make_stub() stub.loop.run_in_executor.side_effect = _endpoint_error() - # The guard condition is always False for a real EndpointConnectionError, - # so no exception propagates. - result = await stub.forget_device("acc-token", "dev-key") - assert result is None - - async def test_endpoint_error_named_resource_not_found_raises_api_error(self): - """A subclass of EndpointConnectionError named 'ResourceNotFoundException' - satisfies the guard at line 339 and raises HiveApiError (line 340).""" - stub = await _make_stub() - # Craft a class whose __class__.__name__ == "ResourceNotFoundException" - # but which IS an EndpointConnectionError (so it's caught by the except clause) - resource_cls = type( - "ResourceNotFoundException", - (botocore.exceptions.EndpointConnectionError,), - {}, - ) - resource_err = resource_cls( - endpoint_url="https://cognito.eu-west-1.amazonaws.com" - ) - stub.loop.run_in_executor.side_effect = resource_err with pytest.raises(HiveApiError): await stub.forget_device("acc-token", "dev-key") @@ -528,7 +504,7 @@ async def test_u_value_zero_raises_value_error(self): stub = await _make_stub() with patch("apyhiveapi.api.device_registration.calculate_u", return_value=0): with pytest.raises(ValueError, match="U cannot be zero"): - await stub.get_device_authentication_key( + stub.get_device_authentication_key( stub.device_group_key, stub.device_key, stub.device_password, @@ -555,7 +531,7 @@ async def fake_init(): stub.client = MagicMock() stub.async_init = fake_init - stub.generate_hash_device = AsyncMock( + stub.generate_hash_device = MagicMock( return_value={"PasswordVerifier": "pv", "Salt": "s"} ) await stub.confirm_device("name") @@ -622,10 +598,10 @@ async def fake_init(): class TestConfirmDeviceSwallowedErrors: - async def test_other_client_error_is_swallowed(self): - """ClientError with an unrecognised class name is caught but not re-raised (184->193).""" + async def test_other_client_error_raises_api_error(self): + """ClientErrors other than CodeMismatchException raise HiveApiError.""" stub = await _make_stub() - stub.generate_hash_device = AsyncMock( + stub.generate_hash_device = MagicMock( return_value={"PasswordVerifier": "pv", "Salt": "s"} ) wrong_cls = type("SomeOtherError", (botocore.exceptions.ClientError,), {}) @@ -633,35 +609,8 @@ async def test_other_client_error_is_swallowed(self): {"Error": {"Code": "SomeOtherError", "Message": "msg"}}, "op" ) stub.loop.run_in_executor.side_effect = wrong_err - result = await stub.confirm_device("name") - assert result is None # no HiveInvalid2FACode raised - - async def test_endpoint_error_wrong_name_is_swallowed(self): - """EndpointConnectionError subclass with wrong __name__ is swallowed (190->193).""" - stub = await _make_stub() - stub.generate_hash_device = AsyncMock( - return_value={"PasswordVerifier": "pv", "Salt": "s"} - ) - wrong_cls = type( - "WrongEndpoint", (botocore.exceptions.EndpointConnectionError,), {} - ) - wrong_err = wrong_cls(endpoint_url="https://cognito.eu-west-1.amazonaws.com") - stub.loop.run_in_executor.side_effect = wrong_err - result = await stub.confirm_device("name") - assert result is None # no HiveApiError raised - - -class TestUpdateDeviceStatusSwallowedEndpointError: - async def test_endpoint_error_wrong_name_is_swallowed(self): - """EndpointConnectionError with wrong name is caught but not re-raised (211->214).""" - stub = await _make_stub() - wrong_cls = type( - "WrongEndpoint", (botocore.exceptions.EndpointConnectionError,), {} - ) - wrong_err = wrong_cls(endpoint_url="https://cognito.eu-west-1.amazonaws.com") - stub.loop.run_in_executor.side_effect = wrong_err - result = await stub.update_device_status() - assert result is None # no HiveApiError raised + with pytest.raises(HiveApiError): + await stub.confirm_device("name") class TestDeviceRegistration: @@ -716,7 +665,6 @@ async def test_returns_response_with_required_keys(self): with patch.object( stub, "get_device_authentication_key", - new_callable=AsyncMock, return_value=fake_hkdf, ): result = await stub.process_device_challenge(self._CHALLENGE_PARAMS) @@ -733,7 +681,6 @@ async def test_username_matches_challenge_parameter(self): with patch.object( stub, "get_device_authentication_key", - new_callable=AsyncMock, return_value=fake_hkdf, ): result = await stub.process_device_challenge(self._CHALLENGE_PARAMS) @@ -746,7 +693,6 @@ async def test_device_key_matches_stub_device_key(self): with patch.object( stub, "get_device_authentication_key", - new_callable=AsyncMock, return_value=fake_hkdf, ): result = await stub.process_device_challenge(self._CHALLENGE_PARAMS) @@ -759,7 +705,6 @@ async def test_secret_block_echoed_back(self): with patch.object( stub, "get_device_authentication_key", - new_callable=AsyncMock, return_value=fake_hkdf, ): result = await stub.process_device_challenge(self._CHALLENGE_PARAMS) @@ -772,7 +717,6 @@ async def test_no_client_secret_no_secret_hash(self): with patch.object( stub, "get_device_authentication_key", - new_callable=AsyncMock, return_value=fake_hkdf, ): result = await stub.process_device_challenge(self._CHALLENGE_PARAMS) @@ -785,7 +729,6 @@ async def test_with_client_secret_adds_secret_hash(self): with patch.object( stub, "get_device_authentication_key", - new_callable=AsyncMock, return_value=fake_hkdf, ): result = await stub.process_device_challenge(self._CHALLENGE_PARAMS) @@ -800,7 +743,6 @@ async def test_timestamp_format_matches_cognito_pattern(self): with patch.object( stub, "get_device_authentication_key", - new_callable=AsyncMock, return_value=fake_hkdf, ): result = await stub.process_device_challenge(self._CHALLENGE_PARAMS) @@ -823,7 +765,6 @@ async def test_password_claim_signature_is_base64_string(self): with patch.object( stub, "get_device_authentication_key", - new_callable=AsyncMock, return_value=fake_hkdf, ): result = await stub.process_device_challenge(self._CHALLENGE_PARAMS) @@ -843,7 +784,6 @@ async def test_salt_as_integer_is_padded(self): with patch.object( stub, "get_device_authentication_key", - new_callable=AsyncMock, return_value=fake_hkdf, ) as mock_auth_key: await stub.process_device_challenge(params) @@ -863,7 +803,7 @@ async def test_returns_16_bytes(self): # Use a valid server_b_value that won't make u_value == 0. # Pick a large prime-ish value that is different from large_a_value. server_b_value = stub.large_a_value + 1 - result = await stub.get_device_authentication_key( + result = stub.get_device_authentication_key( "grp-key", "dev-key", "dev-pass", @@ -876,10 +816,80 @@ async def test_returns_16_bytes(self): async def test_deterministic_for_same_inputs(self): stub = await _make_stub() server_b_value = stub.large_a_value + 1 - result1 = await stub.get_device_authentication_key( + result1 = stub.get_device_authentication_key( "grp-key", "dev-key", "dev-pass", server_b_value, "aabbccdd" ) - result2 = await stub.get_device_authentication_key( + result2 = stub.get_device_authentication_key( "grp-key", "dev-key", "dev-pass", server_b_value, "aabbccdd" ) assert result1 == result2 + + +# --------------------------------------------------------------------------- +# confirm_device and forget_device wrong exception mapping +# --------------------------------------------------------------------------- + + +class TestConfirmDeviceWrongException: + """confirm_device must map NotAuthorizedException → HiveApiError (not HiveInvalid2FACode).""" + + async def test_not_authorized_raises_hive_api_error(self): + """NotAuthorizedException in confirm_device raises HiveApiError, not HiveInvalid2FACode.""" + stub = await _make_stub() + err = botocore.exceptions.ClientError( + {"Error": {"Code": "NotAuthorizedException", "Message": "not auth"}}, + "ConfirmDevice", + ) + stub.loop.run_in_executor = AsyncMock(side_effect=err) + + with pytest.raises(HiveApiError): + await stub.confirm_device() + + async def test_not_authorized_does_not_raise_invalid_2fa(self): + """confirm_device must not raise HiveInvalid2FACode for NotAuthorizedException.""" + stub = await _make_stub() + err = botocore.exceptions.ClientError( + {"Error": {"Code": "NotAuthorizedException", "Message": "not auth"}}, + "ConfirmDevice", + ) + stub.loop.run_in_executor = AsyncMock(side_effect=err) + + with pytest.raises(Exception) as exc_info: + await stub.confirm_device() + + assert not isinstance(exc_info.value, HiveInvalid2FACode) + + +class TestForgetDeviceWrongException: + """forget_device must map NotAuthorizedException → HiveApiError.""" + + async def test_not_authorized_raises_hive_api_error(self): + """NotAuthorizedException in forget_device raises HiveApiError, not HiveInvalid2FACode.""" + stub = await _make_stub() + err = botocore.exceptions.ClientError( + {"Error": {"Code": "NotAuthorizedException", "Message": "not auth"}}, + "ForgetDevice", + ) + stub.loop.run_in_executor = AsyncMock(side_effect=err) + + with pytest.raises(HiveApiError): + await stub.forget_device("tok", "key") + + +# --------------------------------------------------------------------------- +# generate_hash_device — unnecessary async removed +# --------------------------------------------------------------------------- + + +class TestUnnecessaryAsync: + """generate_hash_device should work without async (callable directly).""" + + def test_generate_hash_device_returns_config(self): + """generate_hash_device returns the verifier config without needing await.""" + r = DeviceRegistrationMixin.__new__(DeviceRegistrationMixin) + r.device_password = None + r.g_value = 2 + r.big_n = 0xFFFF + result = r.generate_hash_device("grp", "key") + assert "PasswordVerifier" in result + assert "Salt" in result diff --git a/tests/unit/test_heating.py b/tests/unit/test_heating.py new file mode 100644 index 0000000..56c13ca --- /dev/null +++ b/tests/unit/test_heating.py @@ -0,0 +1,495 @@ +"""Extended branch-coverage tests for Climate / HiveHeating.""" + +# pylint: disable=too-few-public-methods +from datetime import datetime +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest +from apyhiveapi.devices.heating import Climate +from apyhiveapi.helper.hivedataclasses import Device, SessionConfig +from apyhiveapi.helper.map import Map + +_TODAY = str(datetime.date(datetime.now())) +_CURRENT_TEMP = 19.0 +_SCHEDULE_MODE = "SCHEDULE" +_BOOST_MINS = 5 +_OFF_MODE = "OFF" + + +def _make_climate(products=None, devices=None, min_max=None): + session = MagicMock() + session.data = Map( + { + "products": products or {}, + "devices": devices or {}, + "actions": {}, + "minMax": min_max or {}, + "user": {}, + } + ) + session.config = SessionConfig() + session.helper = MagicMock() + session.helper.device_recovered = MagicMock() + session.helper.error_check = AsyncMock() + session.helper.get_schedule_nnl = MagicMock( + return_value={"now": {}, "next": {}, "later": {}} + ) + session.attr = MagicMock() + session.attr.online_offline = AsyncMock(return_value=True) + session.attr.state_attributes = AsyncMock(return_value={}) + session.api = MagicMock() + session.api.set_state = AsyncMock(return_value={"original": 200, "parsed": {}}) + session.hive_refresh_tokens = AsyncMock() + session.get_devices = AsyncMock(return_value=True) + session.should_use_cached_data = MagicMock(return_value=False) + session.get_cached_device = MagicMock(return_value=None) + session.set_cached_device = MagicMock(side_effect=lambda d: d) + return Climate(session=session) + + +def _make_device(hive_id="heat-1", device_id="dev-1", hive_type="heating"): + return Device( + hive_id=hive_id, + hive_name="Hallway", + hive_type=hive_type, + ha_type="climate", + device_id=device_id, + device_name="Hallway", + device_data={"online": True}, + ha_name="Hallway", + ) + + +class TestGetCurrentTemperature: + async def test_minmax_today_same_date_updates_min_max(self): + """When minMax entry exists for today's date, TodayMin/TodayMax are updated.""" + initial_min = _CURRENT_TEMP + 2.0 + initial_max = _CURRENT_TEMP - 2.0 + existing = { + "TodayMin": initial_min, + "TodayMax": initial_max, + "TodayDate": _TODAY, + "RestartMin": initial_min, + "RestartMax": initial_max, + } + climate = _make_climate( + products={"heat-1": {"props": {"temperature": _CURRENT_TEMP}}}, + min_max={"heat-1": existing}, + ) + result = await climate.get_current_temperature(_make_device()) + assert result == _CURRENT_TEMP + entry = climate.session.data.minMax["heat-1"] + assert entry["TodayMin"] == min(initial_min, _CURRENT_TEMP) + assert entry["TodayMax"] == max(initial_max, _CURRENT_TEMP) + assert entry["RestartMin"] == min(initial_min, _CURRENT_TEMP) + assert entry["RestartMax"] == max(initial_max, _CURRENT_TEMP) + + async def test_minmax_different_date_resets_today(self): + """When minMax entry exists but TodayDate is stale, today values are reset.""" + existing = { + "TodayMin": 5.0, + "TodayMax": 30.0, + "TodayDate": "2000-01-01", + "RestartMin": 5.0, + "RestartMax": 30.0, + } + climate = _make_climate( + products={"heat-1": {"props": {"temperature": _CURRENT_TEMP}}}, + min_max={"heat-1": existing}, + ) + result = await climate.get_current_temperature(_make_device()) + assert result == _CURRENT_TEMP + entry = climate.session.data.minMax["heat-1"] + assert entry["TodayMin"] == _CURRENT_TEMP + assert entry["TodayMax"] == _CURRENT_TEMP + assert entry["TodayDate"] == _TODAY + + async def test_keyerror_returns_none(self): + """Missing device.hive_id in products returns None.""" + climate = _make_climate(products={}) + result = await climate.get_current_temperature(_make_device()) + assert result is None + + +class TestGetTargetTemperature: + async def test_non_numeric_target_returns_none(self): + """Non-numeric target temperature string returns None.""" + climate = _make_climate({"heat-1": {"state": {"target": "N/A"}}}) + result = await climate.get_target_temperature(_make_device()) + assert result is None + + +class TestGetState: + async def test_current_less_than_target_returns_on(self): + """When current_temp < target_temp, state resolves to ON.""" + climate = _make_climate( + { + "heat-1": { + "props": {"temperature": 19.0}, + "state": {"target": 21.0}, + } + } + ) + result = await climate.get_state(_make_device()) + assert result == "ON" + + async def test_current_ge_target_returns_off(self): + """When current_temp >= target_temp, state resolves to OFF.""" + climate = _make_climate( + { + "heat-1": { + "props": {"temperature": 21.0}, + "state": {"target": 19.0}, + } + } + ) + result = await climate.get_state(_make_device()) + assert result == "OFF" + + async def test_none_temps_returns_none(self): + """When temperatures cannot be read, get_state returns None.""" + climate = _make_climate(products={}) + result = await climate.get_state(_make_device()) + assert result is None + + async def test_key_error_in_get_current_temperature_is_caught(self): + """KeyError from get_current_temperature is caught; get_state returns None.""" + climate = _make_climate( + {"heat-1": {"state": {"mode": "MANUAL", "target": 20.0}, "props": {}}} + ) + with patch.object( + climate, "get_current_temperature", new_callable=AsyncMock + ) as mock_t: + mock_t.side_effect = KeyError("missing_key") + result = await climate.get_state(_make_device()) + assert result is None + + async def test_type_error_in_get_target_temperature_is_caught(self): + """TypeError from get_target_temperature is caught; get_state returns None.""" + climate = _make_climate( + {"heat-1": {"state": {"mode": "MANUAL", "target": 20.0}, "props": {}}} + ) + with patch.object( + climate, "get_current_temperature", new_callable=AsyncMock + ) as mock_cur: + mock_cur.return_value = 19.0 + with patch.object( + climate, "get_target_temperature", new_callable=AsyncMock + ) as mock_tgt: + mock_tgt.side_effect = TypeError("bad type") + result = await climate.get_state(_make_device()) + assert result is None + + +class TestGetCurrentOperation: + async def test_returns_working_state(self): + """get_current_operation returns the 'working' value from props.""" + climate = _make_climate({"heat-1": {"props": {"working": True}, "state": {}}}) + result = await climate.get_current_operation(_make_device()) + assert result is True + + +class TestSetBoostOnValidation: + """set_boost_on input validation must reject bad values with None.""" + + async def test_non_numeric_mins_returns_none(self): + climate = _make_climate(products={"heat-1": {"type": "heating"}}) + result = await climate.set_boost_on(_make_device(), "abc", 20) + assert result is None + + async def test_non_numeric_temp_returns_none(self): + climate = _make_climate(products={"heat-1": {"type": "heating"}}) + result = await climate.set_boost_on(_make_device(), "30", "hot") + assert result is None + + async def test_temp_fraction_above_max_returns_none(self): + """32.9 exceeds the 32 maximum and must not be truncated past validation.""" + climate = _make_climate(products={"heat-1": {"type": "heating"}}) + result = await climate.set_boost_on(_make_device(), "30", 32.9) + assert result is None + + async def test_valid_boost_executes_state_change(self): + climate = _make_climate(products={"heat-1": {"type": "heating"}}) + result = await climate.set_boost_on(_make_device(), "30", 22) + assert result is True + + +class TestSetBoostOff: + async def test_not_in_products_returns_false(self): + """Device hive_id not present in products returns False.""" + climate = _make_climate(products={}) + result = await climate.set_boost_off(_make_device()) + assert result is False + + async def test_previous_off_mode_restored(self): + """Previous mode OFF sets mode=OFF and target falls back to 7.""" + climate = _make_climate( + { + "heat-1": { + "type": "heating", + "state": {"boost": _BOOST_MINS}, + "props": { + "previous": { + "mode": _OFF_MODE, + "target": None, + } + }, + } + } + ) + result = await climate.set_boost_off(_make_device()) + assert result is True + _, kwargs = climate.session.api.set_state.call_args + assert kwargs.get("mode") == _OFF_MODE + assert kwargs.get("target") == 7 + + +class TestGetClimate: + async def test_device_data_not_dict_gets_reset(self): + """Non-dict device_data is replaced with an empty dict before use.""" + climate = _make_climate( + products={"heat-1": {"props": {}, "state": {}}}, + devices={"dev-1": {"props": {}, "parent": None}}, + ) + d = _make_device() + d.device_data = None + await climate.get_climate(d) + assert isinstance(d.device_data, dict) + + async def test_offline_device_error_check_called(self): + """Offline device triggers error_check and status defaults to None values.""" + climate = _make_climate( + products={"heat-1": {}}, + devices={"dev-1": {}}, + ) + climate.session.attr.online_offline = AsyncMock(return_value=False) + d = _make_device() + result = await climate.get_climate(d) + climate.session.helper.error_check.assert_called_once() + assert result.status["current_temperature"] is None + + async def test_cache_hit_returns_cached(self): + """When cached data is available and poll is slow, returns cached device.""" + climate = _make_climate() + climate.session.should_use_cached_data = MagicMock(return_value=True) + cached_device = _make_device() + cached_device.status = {"current_temperature": 20.0} + climate.session.get_cached_device = MagicMock(return_value=cached_device) + d = _make_device() + result = await climate.get_climate(d) + assert result is cached_device + climate.session.attr.online_offline.assert_not_called() + + +class TestGetScheduleNowNextLater: + async def test_offline_returns_none(self): + """Offline device returns None regardless of mode.""" + climate = _make_climate( + {"heat-1": {"state": {"mode": _SCHEDULE_MODE, "schedule": {}}}} + ) + climate.session.attr.online_offline = AsyncMock(return_value=False) + result = await climate.get_schedule_now_next_later(_make_device()) + assert result is None + + +# --------------------------------------------------------------------------- +# get_mode — BOOST path with missing props.previous must not log error +# --------------------------------------------------------------------------- + + +class TestGetModeBoostMissingPrevious: + """get_mode BOOST path must use safe access, not bare dict that logs a spurious error.""" + + async def test_boost_missing_previous_returns_none_without_error_log(self): + """When mode=BOOST and props has no previous, get_mode returns None without error log.""" + climate = _make_climate({"heat-1": {"state": {"mode": "BOOST"}, "props": {}}}) + d = _make_device() + with patch("apyhiveapi.devices.heating._LOGGER") as mock_log: + result = await climate.get_mode(d) + assert result is None + mock_log.error.assert_not_called() + + async def test_boost_with_previous_mode_returns_mapped_value(self): + """When mode=BOOST and props.previous.mode exists, returns the mapped HA value.""" + from apyhiveapi.helper.const import HIVETOHA + + climate = _make_climate( + { + "heat-1": { + "state": {"mode": "BOOST"}, + "props": {"previous": {"mode": "MANUAL"}}, + } + } + ) + d = _make_device() + result = await climate.get_mode(d) + expected = HIVETOHA["Heating"].get("MANUAL", "MANUAL") + assert result == expected + + +# --------------------------------------------------------------------------- +# set_boost_off — must return False when prev_mode is None (not send mode=None to API) +# --------------------------------------------------------------------------- + + +class TestSetBoostOffNullPrevMode: + """set_boost_off returns False (not an API call) when prev_mode is None.""" + + async def test_set_boost_off_returns_false_when_prev_mode_missing(self): + """set_boost_off returns False and skips _execute_state_change when prev mode absent.""" + from apyhiveapi.devices.heating import HiveHeating + + class StubHeating(HiveHeating): + """Concrete stub for testing.""" + + h = StubHeating() + h.session = MagicMock() + h.session.data.products = { + "h1": { + "state": {"mode": "BOOST"}, + "props": {}, + } + } + h._execute_state_change = AsyncMock(return_value=True) + h.get_boost_status = AsyncMock(return_value="ON") + + d = Device( + hive_id="h1", + hive_name="T", + hive_type="heating", + ha_type="climate", + device_id="d1", + device_name="T", + device_data={"online": True}, + ha_name="Heating", + ) + result = await h.set_boost_off(d) + assert result is False + h._execute_state_change.assert_not_called() + + +# =========================================================================== +# Migrated from test_remaining_branches.py +# =========================================================================== + + +class TestHeatingGetStateKeyError: + """Lines 206-207: KeyError/TypeError branch in get_state.""" + + async def test_get_state_key_error_returns_none(self): + """Missing product entry causes get_current_temperature to return None, + leaving final as None without raising.""" + # products dict is empty — device.hive_id not found → both temp helpers + # return None → the if branch is skipped → final stays None + climate = _make_climate(products={}) + d = _make_device() + result = await climate.get_state(d) + assert result is None + + +class TestHeatingGetHeatOnDemand: + """Line 231: get_heat_on_demand happy path.""" + + async def test_get_heat_on_demand_returns_value(self): + """Returns the nested autoBoost.active value from products.""" + climate = _make_climate({"heat-1": {"props": {"autoBoost": {"active": True}}}}) + result = await climate.get_heat_on_demand(_make_device()) + assert result is True + + async def test_get_heat_on_demand_returns_none_when_missing(self): + """Returns None when the nested path does not exist.""" + climate = _make_climate({"heat-1": {"props": {}}}) + result = await climate.get_heat_on_demand(_make_device()) + assert result is None + + +class TestHeatingSetHeatOnDemand: + """Lines 337-342: set_heat_on_demand calls _execute_state_change with autoBoost kwarg.""" + + async def test_set_heat_on_demand_enabled(self): + """set_heat_on_demand passes autoBoost='ENABLED' to the API.""" + climate = _make_climate({"heat-1": {"type": "heating"}}) + result = await climate.set_heat_on_demand(_make_device(), "ENABLED") + assert result is True + climate.session.api.set_state.assert_called_once() + _, kwargs = climate.session.api.set_state.call_args + assert kwargs.get("autoBoost") == "ENABLED" + + async def test_set_heat_on_demand_disabled(self): + """set_heat_on_demand passes autoBoost='DISABLED' to the API.""" + climate = _make_climate({"heat-1": {"type": "heating"}}) + result = await climate.set_heat_on_demand(_make_device(), "DISABLED") + assert result is True + _, kwargs = climate.session.api.set_state.call_args + assert kwargs.get("autoBoost") == "DISABLED" + + +class TestHeatingGetScheduleNNLKeyError: + """Lines 438-439: KeyError in get_schedule_now_next_later.""" + + async def test_missing_schedule_key_returns_none(self): + """Product with state but no 'schedule' key causes KeyError → returns None.""" + climate = _make_climate( + {"heat-1": {"state": {"mode": "SCHEDULE"}}} + # no 'schedule' key inside state + ) + # Override get_mode to return SCHEDULE directly so the if-branch is entered + climate.session.helper.get_schedule_nnl.side_effect = KeyError("schedule") + # get_mode will read data["state"]["mode"] == "SCHEDULE" → enters the try block + # data["state"]["schedule"] raises KeyError → caught, returns None + result = await climate.get_schedule_now_next_later(_make_device()) + assert result is None + + async def test_schedule_key_error_caught_not_raised(self): + """A KeyError inside the try block does not propagate to the caller.""" + climate = _make_climate({"heat-1": {"state": {"mode": "SCHEDULE"}}}) + # Accessing data["state"]["schedule"] will raise KeyError (key absent) + try: + result = await climate.get_schedule_now_next_later(_make_device()) + except KeyError: + pytest.fail( + "KeyError should have been caught inside get_schedule_now_next_later" + ) + assert result is None + + +class TestHeatingSetBoostOffScheduleMode: + """Lines 321->325: prev_mode not in ('MANUAL','OFF') — target kwarg not added.""" + + async def test_schedule_mode_no_target_kwarg(self): + """SCHEDULE as previous mode does not add a target kwarg.""" + climate = _make_climate( + { + "heat-1": { + "type": "heating", + "state": {"boost": 5}, + "props": {"previous": {"mode": "SCHEDULE"}}, + } + } + ) + result = await climate.set_boost_off(_make_device()) + assert result is True + _, kwargs = climate.session.api.set_state.call_args + assert "target" not in kwargs + assert kwargs.get("mode") == "SCHEDULE" + + +class TestHeatingGetClimateCacheMiss: + """Lines 371->377: cache enabled but cached device is None → normal execution.""" + + async def test_cached_none_falls_through_to_normal_path(self): + """should_use_cached_data=True but get_cached_device=None → normal update.""" + climate = _make_climate( + { + "heat-1": { + "state": {"mode": "MANUAL", "target": 20.0}, + "props": {"temperature": 19.0}, + } + }, + devices={"dev-1": {"state": {}, "props": {}}}, + ) + climate.session.should_use_cached_data.return_value = True + climate.session.get_cached_device.return_value = None + result = await climate.get_climate(_make_device()) + assert result is not None + climate.session.attr.online_offline.assert_called_once() diff --git a/tests/unit/test_heating_extended.py b/tests/unit/test_heating_extended.py deleted file mode 100644 index e6223dd..0000000 --- a/tests/unit/test_heating_extended.py +++ /dev/null @@ -1,238 +0,0 @@ -"""Extended branch-coverage tests for Climate / HiveHeating.""" - -# pylint: disable=too-few-public-methods -from datetime import datetime -from unittest.mock import AsyncMock, MagicMock - -from apyhiveapi.devices.heating import Climate -from apyhiveapi.helper.hivedataclasses import Device, SessionConfig -from apyhiveapi.helper.map import Map - -_TODAY = str(datetime.date(datetime.now())) -_CURRENT_TEMP = 19.0 -_SCHEDULE_MODE = "SCHEDULE" -_BOOST_MINS = 5 -_OFF_MODE = "OFF" - - -def _make_climate(products=None, devices=None, min_max=None): - session = MagicMock() - session.data = Map( - { - "products": products or {}, - "devices": devices or {}, - "actions": {}, - "minMax": min_max or {}, - "user": {}, - } - ) - session.config = SessionConfig() - session.helper = MagicMock() - session.helper.device_recovered = MagicMock() - session.helper.error_check = AsyncMock() - session.helper.get_schedule_nnl = MagicMock( - return_value={"now": {}, "next": {}, "later": {}} - ) - session.attr = MagicMock() - session.attr.online_offline = AsyncMock(return_value=True) - session.attr.state_attributes = AsyncMock(return_value={}) - session.api = MagicMock() - session.api.set_state = AsyncMock(return_value={"original": 200, "parsed": {}}) - session.hive_refresh_tokens = AsyncMock() - session.get_devices = AsyncMock(return_value=True) - session.should_use_cached_data = MagicMock(return_value=False) - session.get_cached_device = MagicMock(return_value=None) - session.set_cached_device = MagicMock(side_effect=lambda d: d) - return Climate(session=session) - - -def _make_device(hive_id="heat-1", device_id="dev-1", hive_type="heating"): - return Device( - hive_id=hive_id, - hive_name="Hallway", - hive_type=hive_type, - ha_type="climate", - device_id=device_id, - device_name="Hallway", - device_data={"online": True}, - ha_name="Hallway", - ) - - -class TestGetCurrentTemperature: - async def test_minmax_today_same_date_updates_min_max(self): - """When minMax entry exists for today's date, TodayMin/TodayMax are updated.""" - initial_min = _CURRENT_TEMP + 2.0 - initial_max = _CURRENT_TEMP - 2.0 - existing = { - "TodayMin": initial_min, - "TodayMax": initial_max, - "TodayDate": _TODAY, - "RestartMin": initial_min, - "RestartMax": initial_max, - } - climate = _make_climate( - products={"heat-1": {"props": {"temperature": _CURRENT_TEMP}}}, - min_max={"heat-1": existing}, - ) - result = await climate.get_current_temperature(_make_device()) - assert result == _CURRENT_TEMP - entry = climate.session.data.minMax["heat-1"] - assert entry["TodayMin"] == min(initial_min, _CURRENT_TEMP) - assert entry["TodayMax"] == max(initial_max, _CURRENT_TEMP) - assert entry["RestartMin"] == min(initial_min, _CURRENT_TEMP) - assert entry["RestartMax"] == max(initial_max, _CURRENT_TEMP) - - async def test_minmax_different_date_resets_today(self): - """When minMax entry exists but TodayDate is stale, today values are reset.""" - existing = { - "TodayMin": 5.0, - "TodayMax": 30.0, - "TodayDate": "2000-01-01", - "RestartMin": 5.0, - "RestartMax": 30.0, - } - climate = _make_climate( - products={"heat-1": {"props": {"temperature": _CURRENT_TEMP}}}, - min_max={"heat-1": existing}, - ) - result = await climate.get_current_temperature(_make_device()) - assert result == _CURRENT_TEMP - entry = climate.session.data.minMax["heat-1"] - assert entry["TodayMin"] == _CURRENT_TEMP - assert entry["TodayMax"] == _CURRENT_TEMP - assert entry["TodayDate"] == _TODAY - - async def test_keyerror_returns_none(self): - """Missing device.hive_id in products returns None.""" - climate = _make_climate(products={}) - result = await climate.get_current_temperature(_make_device()) - assert result is None - - -class TestGetTargetTemperature: - async def test_non_numeric_target_returns_none(self): - """Non-numeric target temperature string returns None.""" - climate = _make_climate({"heat-1": {"state": {"target": "N/A"}}}) - result = await climate.get_target_temperature(_make_device()) - assert result is None - - -class TestGetState: - async def test_current_less_than_target_returns_on(self): - """When current_temp < target_temp, state resolves to ON.""" - climate = _make_climate( - { - "heat-1": { - "props": {"temperature": 19.0}, - "state": {"target": 21.0}, - } - } - ) - result = await climate.get_state(_make_device()) - assert result == "ON" - - async def test_current_ge_target_returns_off(self): - """When current_temp >= target_temp, state resolves to OFF.""" - climate = _make_climate( - { - "heat-1": { - "props": {"temperature": 21.0}, - "state": {"target": 19.0}, - } - } - ) - result = await climate.get_state(_make_device()) - assert result == "OFF" - - async def test_none_temps_returns_none(self): - """When temperatures cannot be read, get_state returns None.""" - climate = _make_climate(products={}) - result = await climate.get_state(_make_device()) - assert result is None - - -class TestGetCurrentOperation: - async def test_returns_working_state(self): - """get_current_operation returns the 'working' value from props.""" - climate = _make_climate({"heat-1": {"props": {"working": True}, "state": {}}}) - result = await climate.get_current_operation(_make_device()) - assert result is True - - -class TestSetBoostOff: - async def test_not_in_products_returns_false(self): - """Device hive_id not present in products returns False.""" - climate = _make_climate(products={}) - result = await climate.set_boost_off(_make_device()) - assert result is False - - async def test_previous_off_mode_restored(self): - """Previous mode OFF sets mode=OFF and target falls back to 7.""" - climate = _make_climate( - { - "heat-1": { - "type": "heating", - "state": {"boost": _BOOST_MINS}, - "props": { - "previous": { - "mode": _OFF_MODE, - "target": None, - } - }, - } - } - ) - result = await climate.set_boost_off(_make_device()) - assert result is True - _, kwargs = climate.session.api.set_state.call_args - assert kwargs.get("mode") == _OFF_MODE - assert kwargs.get("target") == 7 - - -class TestGetClimate: - async def test_device_data_not_dict_gets_reset(self): - """Non-dict device_data is replaced with an empty dict before use.""" - climate = _make_climate( - products={"heat-1": {"props": {}, "state": {}}}, - devices={"dev-1": {"props": {}, "parent": None}}, - ) - d = _make_device() - d.device_data = None - await climate.get_climate(d) - assert isinstance(d.device_data, dict) - - async def test_offline_device_error_check_called(self): - """Offline device triggers error_check and status defaults to None values.""" - climate = _make_climate( - products={"heat-1": {}}, - devices={"dev-1": {}}, - ) - climate.session.attr.online_offline = AsyncMock(return_value=False) - d = _make_device() - result = await climate.get_climate(d) - climate.session.helper.error_check.assert_called_once() - assert result.status["current_temperature"] is None - - async def test_cache_hit_returns_cached(self): - """When cached data is available and poll is slow, returns cached device.""" - climate = _make_climate() - climate.session.should_use_cached_data = MagicMock(return_value=True) - cached_device = _make_device() - cached_device.status = {"current_temperature": 20.0} - climate.session.get_cached_device = MagicMock(return_value=cached_device) - d = _make_device() - result = await climate.get_climate(d) - assert result is cached_device - climate.session.attr.online_offline.assert_not_called() - - -class TestGetScheduleNowNextLater: - async def test_offline_returns_none(self): - """Offline device returns None regardless of mode.""" - climate = _make_climate( - {"heat-1": {"state": {"mode": _SCHEDULE_MODE, "schedule": {}}}} - ) - climate.session.attr.online_offline = AsyncMock(return_value=False) - result = await climate.get_schedule_now_next_later(_make_device()) - assert result is None diff --git a/tests/unit/test_helpers.py b/tests/unit/test_helpers.py index 1f84c3e..8ee6b13 100644 --- a/tests/unit/test_helpers.py +++ b/tests/unit/test_helpers.py @@ -38,11 +38,7 @@ class TestEpochTime: """Tests for the top-level epoch_time() helper function.""" def test_to_epoch_returns_int(self): - """to_epoch converts a date string to an integer Unix timestamp. - - Note: epoch_time ignores the *pattern* argument for "to_epoch" — - it always applies "%d.%m.%Y %H:%M:%S" internally. - """ + """to_epoch converts a date string to an integer Unix timestamp.""" result = epoch_time("01.01.2024 12:00:00", "%d.%m.%Y %H:%M:%S", "to_epoch") assert isinstance(result, int) @@ -198,6 +194,31 @@ def test_non_string_value_under_sensitive_key_passes_through(self): result = helper.sanitize_payload({"token": _non_string_int}) assert result["token"] == _non_string_int + def test_dict_under_sensitive_key_is_recursively_masked(self): + """A dict value under a sensitive key has its own values masked.""" + helper, _ = _make_helper() + result = helper.sanitize_payload({"token": {"inner_key": "secret_value"}}) + assert isinstance(result["token"], dict) + assert "inner_key" in result["token"] + assert result["token"]["inner_key"] != "secret_value" + + def test_list_value_under_sensitive_key_masks_each_element(self): + """A list under a sensitive key has each element masked individually.""" + helper, _ = _make_helper() + payload = {"token": ["short", "averylongtoken123"]} + result = helper.sanitize_payload(payload) + assert result["token"] == ["***", "aver...n123"] + + def test_none_under_sensitive_key_passes_through(self): + """None under a sensitive key is returned unchanged.""" + helper, _ = _make_helper() + assert helper.sanitize_payload({"token": None})["token"] is None + + def test_bool_under_sensitive_key_passes_through(self): + """A bool under a sensitive key is returned unchanged.""" + helper, _ = _make_helper() + assert helper.sanitize_payload({"token": True})["token"] is True + # --------------------------------------------------------------------------- # HiveHelper.device_recovered @@ -472,3 +493,183 @@ def test_heating_matches_by_zone(self): } result = helper.get_device_data(product) assert result["id"] == "thermo-1" + + +# =========================================================================== +# Migrated from test_remaining_branches.py +# =========================================================================== + + +class TestHiveHelperZoneMismatch: + """hive_helper.py 163->160: loop continues when zones don't match.""" + + def test_zone_mismatch_keeps_product_as_device(self): + """When a Thermo device's zone doesn't match the product's zone, + the loop arc 163->160 is taken and device stays as the product.""" + helper, _ = _make_helper( + devices={ + "thermo-1": { + "type": "thermostatui", + "props": {"zone": "zone-B"}, + } + } + ) + + product = { + "type": "heating", + "id": "prod-1", + "props": {"zone": "zone-A"}, # different zone from thermo-1 + } + + result = helper.get_device_data(product) + # The zone mismatch means device was never re-assigned; returns the product + assert result is product + + def test_trv_without_zone_does_not_log_warning(self, caplog): + """TRV devices that omit 'zone' from props are silently skipped (no warning).""" + import logging + + helper, _ = _make_helper( + devices={ + "trv-1": { + "type": "trv", + "props": {"online": True}, # no 'zone' key — current API behaviour + } + } + ) + + product = { + "type": "heating", + "id": "prod-1", + "props": {"zone": "zone-A"}, + } + + with caplog.at_level(logging.WARNING, logger="apyhiveapi.helper.hive_helper"): + result = helper.get_device_data(product) + + assert result is product + assert not caplog.records, ( + f"Unexpected warnings: {[r.getMessage() for r in caplog.records]}" + ) + + +class TestHiveHelperSanitizeListNode: + """hive_helper.py line 359: list value under a non-sensitive key calls _walk(list).""" + + def test_list_under_non_sensitive_key_is_walked(self): + """A list value under a non-sensitive key hits the isinstance(node, list) branch.""" + helper, _ = _make_helper() + result = helper.sanitize_payload({"devices": ["device-a", "device-b"]}) + # 'devices' is not a sensitive key → _walk called for the list + # _walk for a list returns [_walk(item) for item in node] + # Each string item: _walk(str) → str (falls through to return node) + assert result == {"devices": ["device-a", "device-b"]} + + def test_list_containing_dicts_is_walked_recursively(self): + """A list of dicts under a non-sensitive key is recursively processed.""" + helper, _ = _make_helper() + result = helper.sanitize_payload( + { + "items": [ + {"token": "abc", "name": "device1"}, + {"token": "xyz", "name": "device2"}, + ] + } + ) + # 'items' is not sensitive → _walk called for the list + # Each dict in the list is processed by _walk + # 'token' IS sensitive → masked in each sub-dict + assert result["items"][0]["name"] == "device1" + assert result["items"][0]["token"] != "abc" + assert result["items"][1]["name"] == "device2" + assert result["items"][1]["token"] != "xyz" + + +# --------------------------------------------------------------------------- +# Migrated from test_hive_helper_extended.py +# --------------------------------------------------------------------------- + + +class TestGetDeviceFromIdBranch: + """Covers the branch where no cache entry matches the requested ID.""" + + def test_returns_false_when_no_match_in_cache(self): + """When entity_cache has entries but none match n_id, returns False.""" + other_device = Device( + hive_id="other-hive-id", + hive_name="Other", + hive_type="heating", + ha_type="climate", + device_id="other-device-id", + device_name="Other", + device_data={}, + ) + helper, _ = _make_helper(entity_cache={"other-key": other_device}) + result = helper.get_device_from_id("nonexistent-id") + assert result is False + + def test_returns_false_when_cache_is_empty(self): + """When entity_cache is empty, returns False without entering the loop.""" + helper, _ = _make_helper(entity_cache={}) + assert helper.get_device_from_id("any-id") is False + + +class TestEpochTimePattern: + """epoch_time to_epoch must honour the pattern argument.""" + + def test_to_epoch_uses_caller_pattern(self): + """Passing a custom pattern must parse the date string with that pattern.""" + result = epoch_time("2024-06-15", "%Y-%m-%d", "to_epoch") + assert isinstance(result, int), "Expected int epoch timestamp" + assert result > 0 + + def test_to_epoch_standard_hive_format_still_works(self): + """The standard Hive date+time format must still parse correctly.""" + result = epoch_time("15.06.2024 12:00:00", "%d.%m.%Y %H:%M:%S", "to_epoch") + assert isinstance(result, int) + assert result > 0 + + +def _sample_schedule(): + """Minimal 7-day schedule with 3 slots on every day.""" + days = [ + "monday", + "tuesday", + "wednesday", + "thursday", + "friday", + "saturday", + "sunday", + ] + schedule = {} + for d in days: + schedule[d] = [ + {"start": 0, "value": {"status": "ON"}}, + {"start": 480, "value": {"status": "OFF"}}, + {"start": 1200, "value": {"status": "ON"}}, + ] + return schedule + + +class TestGetScheduleNnlMutation: + """get_schedule_nnl must not mutate the input schedule dicts.""" + + def test_second_call_returns_same_result_as_first_call(self): + """Calling get_schedule_nnl twice on the same schedule dict gives consistent results.""" + h, _ = _make_helper() + schedule = _sample_schedule() + result1 = h.get_schedule_nnl(schedule) + result2 = h.get_schedule_nnl(schedule) + assert result1.get("now", {}).get("value") == result2.get("now", {}).get( + "value" + ), "Second call returned different 'now' value — schedule was mutated in-place" + + def test_input_schedule_slots_not_modified(self): + """Slot dicts in the input schedule must not gain 'Start_DateTime' after the call.""" + h, _ = _make_helper() + schedule = _sample_schedule() + monday_slot_before = dict(schedule["monday"][0]) + h.get_schedule_nnl(schedule) + assert schedule["monday"][0] == monday_slot_before, ( + "get_schedule_nnl mutated the original slot dict" + ) diff --git a/tests/unit/test_hive_api.py b/tests/unit/test_hive_api.py index 8b6a182..6e272bd 100644 --- a/tests/unit/test_hive_api.py +++ b/tests/unit/test_hive_api.py @@ -152,6 +152,27 @@ def test_request_passes_timeout(self): class TestGetLoginInfo: + def test_tls_verification_is_not_disabled(self): + """The SSO bootstrap request must not pass verify=False.""" + api = _make_api() + html_content = ( + b"" + ) + mock_resp = MagicMock() + mock_resp.content = html_content + mock_resp.status_code = 200 + + with patch( + "apyhiveapi.api.hive_api.requests.get", return_value=mock_resp + ) as mock_get: + api.get_login_info() + + _, call_kwargs = mock_get.call_args + assert call_kwargs.get("verify", True) is not False + def test_successful_parse_returns_login_data(self): """Parses HiveSSOPoolId and HiveSSOPublicCognitoClientId from the SSO page.""" api = _make_api() @@ -213,119 +234,6 @@ def test_key_error_calls_error_and_returns_none(self): assert result is None -# --------------------------------------------------------------------------- -# Tests: HiveApi.refresh_tokens -# --------------------------------------------------------------------------- - - -class TestRefreshTokens: - def test_successful_with_token_key_updates_session(self): - """When the response contains 'token', session.update_tokens is called.""" - api = _make_api() - refresh_data = { - "token": "new-token", - "platform": {"endpoint": "https://new.endpoint.com"}, - } - mock_resp = _make_mock_response( - 200, json_data=refresh_data, text=json.dumps(refresh_data) - ) - - with patch.object(api, "request", return_value=mock_resp): - result = api.refresh_tokens() - - api.session.update_tokens.assert_called_once_with(refresh_data) - assert result["original"] == 200 - - def test_no_token_in_response_no_session_update(self): - """When response lacks 'token' key, update_tokens is not called.""" - api = _make_api() - response_data = {"other_key": "value"} - mock_resp = _make_mock_response( - 200, json_data=response_data, text=json.dumps(response_data) - ) - - with patch.object(api, "request", return_value=mock_resp): - api.refresh_tokens() - - api.session.update_tokens.assert_not_called() - - def test_none_tokens_defaults_to_empty_dict(self): - """Calling refresh_tokens() without arguments uses session.token_data.""" - api = _make_api() - response_data = {"other": "val"} - mock_resp = _make_mock_response( - 200, json_data=response_data, text=json.dumps(response_data) - ) - - with patch.object(api, "request", return_value=mock_resp) as mock_req: - api.refresh_tokens() - # Should have been called (session provides the tokens dict) - mock_req.assert_called_once() - - def test_os_error_calls_error(self): - api = _make_api() - with patch.object(api, "request", side_effect=OSError("connection failed")): - api.refresh_tokens() - - assert api.json_return["original"] == "Error making API call" - - def test_runtime_error_calls_error(self): - api = _make_api() - with patch.object(api, "request", side_effect=RuntimeError("fail")): - api.refresh_tokens() - - assert api.json_return["original"] == "Error making API call" - - def test_json_decode_error_calls_error(self): - """Bad JSON in response text triggers error().""" - api = _make_api() - mock_resp = MagicMock() - mock_resp.status_code = 200 - mock_resp.text = "not-json" - - with patch.object(api, "request", return_value=mock_resp): - api.refresh_tokens() - - assert api.json_return["original"] == "Error making API call" - - def test_explicit_tokens_arg_skips_none_branch(self): - """Passing a non-None tokens arg covers the 80->82 False branch.""" - api = _make_api() - explicit_tokens = {"key": "val"} - response_data = {"other": "x"} - mock_resp = _make_mock_response(200, json_data=response_data) - - with patch.object(api, "request", return_value=mock_resp): - api.refresh_tokens(tokens=explicit_tokens) - # Session is not None so session tokens overwrite, but no crash - api.session.update_tokens.assert_not_called() - - def test_session_none_skips_token_overwrite(self): - """When session is None the 83->85 False branch is taken (no token overwrite).""" - api = _make_api_no_session(token="standalone-token") - response_data = {"other": "x"} - mock_resp = _make_mock_response(200, json_data=response_data) - - with patch.object(api, "request", return_value=mock_resp): - api.refresh_tokens(tokens={"key": "val"}) - - def test_urls_base_updated_on_token_refresh(self): - """After a successful refresh the base URL is updated from the response.""" - api = _make_api() - refresh_data = { - "token": "new-tok", - "platform": {"endpoint": "https://new-platform.com/1.0"}, - } - mock_resp = _make_mock_response( - 200, json_data=refresh_data, text=json.dumps(refresh_data) - ) - - with patch.object(api, "request", return_value=mock_resp): - api.refresh_tokens() - - assert api.urls["base"] == "https://new-platform.com/1.0" - - # --------------------------------------------------------------------------- # Tests: HiveApi.get_all # --------------------------------------------------------------------------- @@ -343,14 +251,13 @@ def test_successful_returns_original_and_parsed(self): assert result["original"] == 200 assert result["parsed"] == payload - def test_none_response_logs_error_and_returns_empty(self): + def test_none_response_logs_error_and_returns_no_response_marker(self): """When request returns None the method should not crash.""" api = _make_api() with patch.object(api, "request", return_value=None): result = api.get_all() - # No keys populated — dict remains empty - assert "original" not in result + assert result["original"] == "No response to Hive API request" def test_os_error_calls_error_method(self): api = _make_api() @@ -584,8 +491,7 @@ def test_none_response_logs_error_no_crash(self): with patch.object(api, "request", return_value=None): result = api.set_state("heating", "node-1", mode="MANUAL") - # json_return stays at default (unchanged from init defaults) - assert result is api.json_return + assert result["original"] == "No response to Hive API request" def test_os_error_calls_error(self): api = _make_api() @@ -626,6 +532,28 @@ def test_kwargs_serialised_into_jsc(self): assert "target" in jsc_arg assert "21" in jsc_arg + def test_payload_is_valid_json_with_native_types(self): + """The payload must round-trip through json.loads with types intact.""" + api = _make_api() + mock_resp = _make_mock_response(200, json_data={}) + + with patch.object(api, "request", return_value=mock_resp) as mock_req: + api.set_state("heating", "n1", mode="SCHEDULE", target=21.5) + + jsc_arg = mock_req.call_args[0][2] + assert json.loads(jsc_arg) == {"mode": "SCHEDULE", "target": 21.5} + + def test_payload_with_quotes_is_valid_json(self): + """Values containing quotes must not break or inject into the JSON.""" + api = _make_api() + mock_resp = _make_mock_response(200, json_data={}) + + with patch.object(api, "request", return_value=mock_resp) as mock_req: + api.set_state("heating", "n1", name='say "hi", "extra": "injected') + + jsc_arg = mock_req.call_args[0][2] + assert json.loads(jsc_arg) == {"name": 'say "hi", "extra": "injected'} + # --------------------------------------------------------------------------- # Tests: HiveApi.set_action @@ -687,6 +615,52 @@ def test_runtime_error_calls_error(self): assert api.json_return["original"] == "Error making API call" +# --------------------------------------------------------------------------- +# Tests: result isolation between calls +# --------------------------------------------------------------------------- + + +class TestResultIsolation: + def test_results_are_independent_between_calls(self): + """A later call must not mutate the dict returned by an earlier call.""" + api = _make_api() + resp_devices = _make_mock_response(200, json_data=[{"id": "dev1"}]) + resp_products = _make_mock_response(200, json_data=[{"id": "prod1"}]) + + with patch.object(api, "request", side_effect=[resp_devices, resp_products]): + devices = api.get_devices() + products = api.get_products() + + assert devices is not products + assert devices["parsed"] == [{"id": "dev1"}] + assert products["parsed"] == [{"id": "prod1"}] + + def test_error_call_does_not_corrupt_previous_result(self): + """An error in a later call must not overwrite an earlier result.""" + api = _make_api() + resp_devices = _make_mock_response(200, json_data=[{"id": "dev1"}]) + + with patch.object(api, "request", side_effect=[resp_devices, OSError("down")]): + devices = api.get_devices() + failed = api.get_products() + + assert devices["original"] == 200 + assert devices["parsed"] == [{"id": "dev1"}] + assert failed["original"] == "Error making API call" + + def test_set_state_none_response_does_not_return_stale_data(self): + """set_state with no response must not surface a previous call's payload.""" + api = _make_api() + resp_devices = _make_mock_response(200, json_data=[{"id": "dev1"}]) + + with patch.object(api, "request", side_effect=[resp_devices, None]): + api.get_devices() + result = api.set_state("heating", "n1", mode="MANUAL") + + assert result["parsed"] != [{"id": "dev1"}] + assert result["original"] == "No response to Hive API request" + + # --------------------------------------------------------------------------- # Tests: HiveApi.error # --------------------------------------------------------------------------- diff --git a/tests/unit/test_hive_async_api.py b/tests/unit/test_hive_async_api.py index 86c1cd7..92d1027 100644 --- a/tests/unit/test_hive_async_api.py +++ b/tests/unit/test_hive_async_api.py @@ -1,8 +1,9 @@ """Unit tests for HiveApiAsync.""" import asyncio -from unittest.mock import AsyncMock, MagicMock +from unittest.mock import AsyncMock, MagicMock, patch +import aiohttp import pytest from aiohttp import web_exceptions from apyhiveapi.api.hive_async_api import HiveApiAsync @@ -65,50 +66,42 @@ def _make_api_no_token(_url_contains_sso=False): class TestHiveApiAsyncRequest: - @pytest.mark.asyncio async def test_successful_200_returns_response(self): api = _make_api(status=200, json_data={"ok": True}) resp = await api.request("get", "https://beekeeper.hivehome.com/1.0/nodes/all") assert resp.status == 200 - @pytest.mark.asyncio async def test_201_also_succeeds(self): api = _make_api(status=201) resp = await api.request("post", "https://beekeeper.hivehome.com/1.0/nodes/x/y") assert resp.status == 201 - @pytest.mark.asyncio async def test_sso_url_without_token_does_not_raise(self): api = _make_api_no_token() # Should not raise NoApiToken because "sso" is in the URL resp = await api.request("get", "https://sso.hivehome.com/") assert resp.status == 200 - @pytest.mark.asyncio async def test_non_sso_without_token_raises_no_api_token(self): api = _make_api_no_token() with pytest.raises(NoApiToken): await api.request("get", "https://beekeeper.hivehome.com/1.0/nodes/all") - @pytest.mark.asyncio async def test_401_raises_hive_auth_error(self): api = _make_api(status=401) with pytest.raises(HiveAuthError): await api.request("get", "https://beekeeper.hivehome.com/1.0/nodes/all") - @pytest.mark.asyncio async def test_403_raises_hive_auth_error(self): api = _make_api(status=403) with pytest.raises(HiveAuthError): await api.request("get", "https://beekeeper.hivehome.com/1.0/nodes/all") - @pytest.mark.asyncio async def test_500_raises_hive_api_error(self): api = _make_api(status=500) with pytest.raises(HiveApiError): await api.request("get", "https://beekeeper.hivehome.com/1.0/nodes/all") - @pytest.mark.asyncio async def test_404_raises_hive_api_error(self): api = _make_api(status=404) with pytest.raises(HiveApiError): @@ -121,7 +114,6 @@ async def test_404_raises_hive_api_error(self): class TestGetAll: - @pytest.mark.asyncio async def test_successful_get_all_returns_parsed_json(self): payload = {"products": [], "devices": []} api = _make_api(status=200, json_data=payload) @@ -129,21 +121,18 @@ async def test_successful_get_all_returns_parsed_json(self): assert result["original"] == 200 assert result["parsed"] == payload - @pytest.mark.asyncio async def test_timeout_error_propagates(self): api = _make_api(status=200) api.websession.request.side_effect = asyncio.TimeoutError with pytest.raises(asyncio.TimeoutError): await api.get_all() - @pytest.mark.asyncio async def test_os_error_calls_error_method(self): api = _make_api(status=200) api.websession.request.side_effect = OSError("network down") with pytest.raises(web_exceptions.HTTPError): await api.get_all() - @pytest.mark.asyncio async def test_runtime_error_calls_error_method(self): api = _make_api(status=200) api.websession.request.side_effect = RuntimeError("boom") @@ -157,7 +146,6 @@ async def test_runtime_error_calls_error_method(self): class TestGetEndpoints: - @pytest.mark.asyncio async def test_get_devices_returns_parsed_json(self): payload = [{"id": "dev1"}] api = _make_api(status=200, json_data=payload) @@ -165,7 +153,6 @@ async def test_get_devices_returns_parsed_json(self): assert result["original"] == 200 assert result["parsed"] == payload - @pytest.mark.asyncio async def test_get_products_returns_parsed_json(self): payload = [{"id": "prod1"}] api = _make_api(status=200, json_data=payload) @@ -173,7 +160,6 @@ async def test_get_products_returns_parsed_json(self): assert result["original"] == 200 assert result["parsed"] == payload - @pytest.mark.asyncio async def test_get_actions_returns_parsed_json(self): payload = [{"id": "act1"}] api = _make_api(status=200, json_data=payload) @@ -181,21 +167,18 @@ async def test_get_actions_returns_parsed_json(self): assert result["original"] == 200 assert result["parsed"] == payload - @pytest.mark.asyncio async def test_get_devices_os_error_raises_http_error(self): api = _make_api(status=200) api.websession.request.side_effect = OSError with pytest.raises(web_exceptions.HTTPError): await api.get_devices() - @pytest.mark.asyncio async def test_get_products_os_error_raises_http_error(self): api = _make_api(status=200) api.websession.request.side_effect = OSError with pytest.raises(web_exceptions.HTTPError): await api.get_products() - @pytest.mark.asyncio async def test_get_actions_os_error_raises_http_error(self): api = _make_api(status=200) api.websession.request.side_effect = OSError @@ -209,13 +192,11 @@ async def test_get_actions_os_error_raises_http_error(self): class TestSetState: - @pytest.mark.asyncio async def test_file_in_use_returns_file_response(self): api = _make_api(status=200, file_mode=True) result = await api.set_state("heating", "node-1", mode="MANUAL") assert result == {"original": "file"} - @pytest.mark.asyncio async def test_successful_set_state(self): payload = {"id": "node-1", "mode": "MANUAL"} api = _make_api(status=200, json_data=payload) @@ -223,14 +204,12 @@ async def test_successful_set_state(self): assert result["original"] == 200 assert result["parsed"] == payload - @pytest.mark.asyncio async def test_os_error_calls_error_method(self): api = _make_api(status=200) api.websession.request.side_effect = OSError("fail") with pytest.raises(web_exceptions.HTTPError): await api.set_state("heating", "node-1", mode="MANUAL") - @pytest.mark.asyncio async def test_runtime_error_calls_error_method(self): api = _make_api(status=200) api.websession.request.side_effect = RuntimeError("fail") @@ -244,19 +223,24 @@ async def test_runtime_error_calls_error_method(self): class TestSetAction: - @pytest.mark.asyncio async def test_file_in_use_returns_file_response(self): api = _make_api(status=200, file_mode=True) result = await api.set_action("action-1", '{"status": "on"}') assert result == {"original": "file"} - @pytest.mark.asyncio - async def test_successful_set_action_returns_json_return(self): - api = _make_api(status=200) + async def test_successful_set_action_returns_status_200(self): + payload = {"id": "action-1", "status": "on"} + api = _make_api(status=200, json_data=payload) result = await api.set_action("action-1", '{"status": "on"}') - assert result == api.json_return + assert result["original"] == 200 + assert result["parsed"] == payload + + async def test_runtime_error_calls_error_method(self): + api = _make_api(status=200) + api.websession.request.side_effect = RuntimeError("fail") + with pytest.raises(web_exceptions.HTTPError): + await api.set_action("action-1", "{}") - @pytest.mark.asyncio async def test_os_error_calls_error_method(self): api = _make_api(status=200) api.websession.request.side_effect = OSError @@ -264,13 +248,45 @@ async def test_os_error_calls_error_method(self): await api.set_action("action-1", "{}") +# --------------------------------------------------------------------------- +# Tests: HiveApiAsync.motion_sensor +# --------------------------------------------------------------------------- + + +class TestMotionSensor: + async def test_url_does_not_double_base_url(self): + payload = [{"timestamp": 12345}] + api = _make_api(status=200, json_data=payload) + captured = {} + original_request = api.request + + async def capture_request(method, url, **kwargs): + captured["url"] = url + return await original_request(method, url, **kwargs) + + api.request = capture_request + sensor = {"type": "motionsensor", "id": "ms-001"} + await api.motion_sensor(sensor, 1000000, 2000000) + url = captured["url"] + assert url.startswith(api.base_url + "/products/") + assert "motionsensor/ms-001" in url + assert url.count("https://beekeeper") == 1 + + async def test_motion_sensor_returns_parsed_json(self): + payload = [{"timestamp": 12345}] + api = _make_api(status=200, json_data=payload) + sensor = {"type": "motionsensor", "id": "ms-001"} + result = await api.motion_sensor(sensor, 1000000, 2000000) + assert result["original"] == 200 + assert result["parsed"] == payload + + # --------------------------------------------------------------------------- # Tests: HiveApiAsync.error # --------------------------------------------------------------------------- class TestError: - @pytest.mark.asyncio async def test_error_raises_http_error(self): api = _make_api() with pytest.raises(web_exceptions.HTTPError): @@ -283,13 +299,11 @@ async def test_error_raises_http_error(self): class TestIsFileBeingUsed: - @pytest.mark.asyncio async def test_file_mode_raises_file_in_use(self): api = _make_api(file_mode=True) with pytest.raises(FileInUse): await api.is_file_being_used() - @pytest.mark.asyncio async def test_not_file_mode_does_not_raise(self): api = _make_api(file_mode=False) await api.is_file_being_used() # Should not raise @@ -301,14 +315,26 @@ async def test_not_file_mode_does_not_raise(self): class TestInit: - async def test_default_websession_created_when_none_passed(self): + def test_no_websession_created_at_init(self): + """No ClientSession is created in the sync constructor.""" session = MagicMock() session.tokens = MagicMock() session.tokens.token_data = {"token": "tok"} session.config = MagicMock() api = HiveApiAsync(hive_session=session) - assert api.websession is not None - await api.websession.close() + assert api.websession is None + + def test_websession_created_lazily_and_cached(self): + """The first _get_websession() call creates and caches a ClientSession.""" + session = MagicMock() + api = HiveApiAsync(hive_session=session) + with patch("apyhiveapi.api.hive_async_api.ClientSession") as mock_session_cls: + first = api._get_websession() + second = api._get_websession() + mock_session_cls.assert_called_once() + assert first is mock_session_cls.return_value + assert second is first + assert api.websession is first def test_custom_websession_is_used(self): session = MagicMock() @@ -323,3 +349,191 @@ def test_base_url_is_set(self): def test_default_timeout(self): api = _make_api() assert api.timeout == 5 + + +# --------------------------------------------------------------------------- +# Migrated from test_hive_async_api_extended.py +# --------------------------------------------------------------------------- + + +class TestRequestNonAuthErrorBranch: + """Cover lines 100-108: url/status not None branch leading to HiveApiError.""" + + async def test_404_logs_and_raises_hive_api_error(self): + """A 404 falls through to the url/status branch and raises HiveApiError.""" + api = _make_api(status=404) + with pytest.raises(HiveApiError): + await api.request("get", "https://beekeeper.hivehome.com/1.0/nodes/all") + + async def test_503_logs_and_raises_hive_api_error(self): + """A 503 falls through to the url/status branch and raises HiveApiError.""" + api = _make_api(status=503) + with pytest.raises(HiveApiError): + await api.request("get", "https://beekeeper.hivehome.com/1.0/nodes/all") + + async def test_422_logs_and_raises_hive_api_error(self): + """A 422 also falls through (not 401/403) and raises HiveApiError.""" + api = _make_api(status=422) + with pytest.raises(HiveApiError): + await api.request("get", "https://beekeeper.hivehome.com/1.0/devices") + + +class TestMotionSensorBranches: + """Cover lines 215-235: motion_sensor() success and error paths.""" + + async def test_success_returns_status_and_parsed(self): + """Successful call returns status and parsed JSON.""" + payload = [{"event": "motion", "timestamp": 1234567890}] + api = _make_api(status=200, json_data=payload) + api.urls["base"] = "" + sensor = {"type": "motionsensor", "id": "sensor-001"} + result = await api.motion_sensor(sensor, fromepoch=1000000, toepoch=2000000) + assert result["original"] == 200 + assert result["parsed"] == payload + + async def test_url_is_built_correctly(self): + """Verifies the URL is assembled with correct sensor type and id.""" + api = _make_api(status=200, json_data=[]) + api.urls["base"] = "https://beekeeper-uk.hivehome.com/1.0" + sensor = {"type": "contactsensor", "id": "abc-123"} + captured_url = [] + original_request = api.request + + async def capture_request(method, url, **kwargs): + captured_url.append(url) + return await original_request(method, url, **kwargs) + + with patch.object(api, "request", side_effect=capture_request): + await api.motion_sensor(sensor, fromepoch=100, toepoch=200) + assert len(captured_url) == 1 + assert "contactsensor" in captured_url[0] + assert "abc-123" in captured_url[0] + assert "from=100" in captured_url[0] + assert "to=200" in captured_url[0] + + async def test_os_error_raises_http_error(self): + """OSError inside the try block causes error() → HTTPError.""" + api = _make_api(status=200) + api.urls["base"] = "" + sensor = {"type": "motionsensor", "id": "sensor-001"} + api.websession.request.side_effect = OSError("fail") + with pytest.raises(web_exceptions.HTTPError): + await api.motion_sensor(sensor, fromepoch=1000, toepoch=2000) + + async def test_runtime_error_raises_http_error(self): + """RuntimeError inside the try block causes error() → HTTPError.""" + api = _make_api(status=200) + api.urls["base"] = "" + sensor = {"type": "motionsensor", "id": "sensor-002"} + api.websession.request.side_effect = RuntimeError("unexpected") + with pytest.raises(web_exceptions.HTTPError): + await api.motion_sensor(sensor, fromepoch=1000, toepoch=2000) + + async def test_client_error_raises_http_error(self): + """aiohttp.ClientError inside the try block causes error() → HTTPError.""" + api = _make_api(status=200) + api.urls["base"] = "" + sensor = {"type": "motionsensor", "id": "sensor-003"} + api.websession.request.side_effect = aiohttp.ClientError() + with pytest.raises(web_exceptions.HTTPError): + await api.motion_sensor(sensor, fromepoch=1000, toepoch=2000) + + +class TestGetWeather: + """Cover lines 239-249: get_weather() success, space encoding, and error paths.""" + + async def test_success_returns_status_and_parsed(self): + """Successful call returns status and parsed weather JSON.""" + payload = {"temperature": {"value": 15, "unit": "C"}} + api = _make_api(status=200, json_data=payload) + result = await api.get_weather("?lat=51.5&lon=-0.1") + assert result["original"] == 200 + assert result["parsed"] == payload + + async def test_space_in_weather_url_is_encoded(self): + """Spaces in the weather_url are replaced with %20.""" + api = _make_api(status=200, json_data={}) + captured_url = [] + original_request = api.request + + async def capture_request(method, url, **kwargs): + captured_url.append(url) + return await original_request(method, url, **kwargs) + + with patch.object(api, "request", side_effect=capture_request): + await api.get_weather("?postcode=SW1A 2AA") + assert len(captured_url) == 1 + assert " " not in captured_url[0] + assert "%20" in captured_url[0] + + async def test_url_is_prefixed_with_weather_base(self): + """The weather base URL is prepended to the given weather_url.""" + api = _make_api(status=200, json_data={}) + captured_url = [] + original_request = api.request + + async def capture_request(method, url, **kwargs): + captured_url.append(url) + return await original_request(method, url, **kwargs) + + with patch.object(api, "request", side_effect=capture_request): + await api.get_weather("?lat=51.5") + assert captured_url[0].startswith("https://weather.prod.bgchprod.info/weather") + + async def test_os_error_raises_http_error(self): + """OSError inside the try block causes error() → HTTPError.""" + api = _make_api(status=200) + api.websession.request.side_effect = OSError("network fail") + with pytest.raises(web_exceptions.HTTPError): + await api.get_weather("?lat=51.5") + + async def test_runtime_error_raises_http_error(self): + """RuntimeError inside the try block causes error() → HTTPError.""" + api = _make_api(status=200) + api.websession.request.side_effect = RuntimeError("unexpected") + with pytest.raises(web_exceptions.HTTPError): + await api.get_weather("?lat=51.5") + + async def test_client_error_raises_http_error(self): + """aiohttp.ClientError inside the try block causes error() → HTTPError.""" + api = _make_api(status=200) + api.websession.request.side_effect = aiohttp.ClientError() + with pytest.raises(web_exceptions.HTTPError): + await api.get_weather("?lat=51.5") + + async def test_connection_error_raises_http_error(self): + """ConnectionError inside the try block causes error() → HTTPError.""" + api = _make_api(status=200) + api.websession.request.side_effect = ConnectionError("disconnected") + with pytest.raises(web_exceptions.HTTPError): + await api.get_weather("?lat=51.5") + + +class TestSetStateJsonEncoding: + """set_state must produce valid JSON even when kwarg values contain special characters.""" + + async def test_set_state_escapes_quotes_in_value(self): + """A value containing double-quotes must produce valid, parseable JSON.""" + import json # noqa: PLC0415 + + session = MagicMock() + session.tokens.token_data = {"token": "tok"} + session.config.file = False + api = HiveApiAsync(hive_session=session) + api.urls = {"nodes": "https://beekeeper.hivehome.com/1.0/nodes/{}/{}"} + + captured = {} + + async def fake_request(_method, _url, **kwargs): + captured["data"] = kwargs.get("data") + resp = MagicMock() + resp.status = 200 + resp.json = AsyncMock(return_value={}) + return resp + + with patch.object(api, "request", side_effect=fake_request): + with patch.object(api, "is_file_being_used", new=AsyncMock()): + await api.set_state("heating", "node-1", mode='MANUAL"injected') + + parsed = json.loads(captured["data"]) + assert parsed["mode"] == 'MANUAL"injected' diff --git a/tests/unit/test_hive_async_api_extended.py b/tests/unit/test_hive_async_api_extended.py deleted file mode 100644 index 5a9848c..0000000 --- a/tests/unit/test_hive_async_api_extended.py +++ /dev/null @@ -1,427 +0,0 @@ -"""Extended unit tests for HiveApiAsync — covers previously uncovered lines.""" - -from unittest.mock import AsyncMock, MagicMock, patch - -import pytest -from aiohttp import web_exceptions -from apyhiveapi.api.hive_async_api import HiveApiAsync -from apyhiveapi.helper.hive_exceptions import HiveApiError - -# --------------------------------------------------------------------------- -# Shared helpers (same pattern as test_hive_async_api.py) -# --------------------------------------------------------------------------- - - -def _make_mock_response(status=200, json_data=None): - resp = MagicMock() - resp.status = status - resp.text = AsyncMock(return_value="") - resp.json = AsyncMock(return_value=json_data or {"data": "test"}) - resp.__aenter__ = AsyncMock(return_value=resp) - resp.__aexit__ = AsyncMock(return_value=False) - return resp - - -def _make_api(status=200, json_data=None, token="test-token", file_mode=False): - resp = _make_mock_response(status=status, json_data=json_data) - websession = MagicMock() - websession.request.return_value = resp - websession.closed = False - websession.close = AsyncMock() - session = MagicMock() - session.tokens = MagicMock() - session.tokens.token_data = {"token": token} - session.config = MagicMock() - session.config.file = file_mode - return HiveApiAsync(hive_session=session, websession=websession) - - -# --------------------------------------------------------------------------- -# Tests: request() branch — url is not None and status is not None (non-auth error) -# --------------------------------------------------------------------------- - - -class TestRequestNonAuthErrorBranch: - """Cover lines 100-108: url/status not None branch leading to HiveApiError.""" - - async def test_404_logs_and_raises_hive_api_error(self): - """A 404 falls through to the url/status branch and raises HiveApiError.""" - api = _make_api(status=404) - with pytest.raises(HiveApiError): - await api.request("get", "https://beekeeper.hivehome.com/1.0/nodes/all") - - async def test_503_logs_and_raises_hive_api_error(self): - """A 503 falls through to the url/status branch and raises HiveApiError.""" - api = _make_api(status=503) - with pytest.raises(HiveApiError): - await api.request("get", "https://beekeeper.hivehome.com/1.0/nodes/all") - - async def test_422_logs_and_raises_hive_api_error(self): - """A 422 also falls through (not 401/403) and raises HiveApiError.""" - api = _make_api(status=422) - with pytest.raises(HiveApiError): - await api.request("get", "https://beekeeper.hivehome.com/1.0/devices") - - -# --------------------------------------------------------------------------- -# Tests: get_login_info() — sync method (lines 110-129) -# --------------------------------------------------------------------------- - - -class TestGetLoginInfo: - """Cover lines 112-129: get_login_info() parses HTML and returns login dict.""" - - def test_returns_upid_cliid_region(self): - """Successful fetch returns correct keys from parsed HTML.""" - html_content = ( - b"" - ) - mock_response = MagicMock() - mock_response.content = html_content - - api = _make_api() - with patch( - "apyhiveapi.api.hive_async_api.requests.get", return_value=mock_response - ): - result = api.get_login_info() - - assert result["UPID"] == "eu-west-1_abc123" - assert result["CLIID"] == "client-xyz" - # REGION is set to HiveSSOPoolId value - assert result["REGION"] == "eu-west-1_abc123" - - def test_makes_request_to_sso_url(self): - """Verifies requests.get is called with the SSO URL.""" - html_content = ( - b"" - ) - mock_response = MagicMock() - mock_response.content = html_content - - api = _make_api() - with patch( - "apyhiveapi.api.hive_async_api.requests.get", return_value=mock_response - ) as mock_get: - api.get_login_info() - - mock_get.assert_called_once_with( - url="https://sso.hivehome.com/", verify=False, timeout=api.timeout - ) - - def test_uses_first_script_tag(self): - """PyQuery selects the first script — extra scripts are ignored.""" - html_content = ( - b"" - b'' - ) - mock_response = MagicMock() - mock_response.content = html_content - - api = _make_api() - with patch( - "apyhiveapi.api.hive_async_api.requests.get", return_value=mock_response - ): - result = api.get_login_info() - - assert result["UPID"] == "eu-west-1_first" - - -# --------------------------------------------------------------------------- -# Tests: refresh_tokens() — lines 131-156 -# --------------------------------------------------------------------------- - - -class TestRefreshTokens: - """Cover lines 133-156: refresh_tokens() success, no-token, and error paths.""" - - async def test_successful_request_with_non_ok_json_return_returns_json_return(self): - """When request succeeds but json_return["original"] != HTTP_OK, returns json_return.""" - api = _make_api(status=200) - # request() will succeed (200) but json_return is not updated by refresh_tokens - # so json_return["original"] stays as the default string, not HTTP_OK (200) - result = await api.refresh_tokens() - # Returns self.json_return (the default dict) - assert result == api.json_return - - async def test_session_tokens_read_before_request(self): - """tokens are read from session.tokens.token_data before constructing the request.""" - api = _make_api(status=200, token="my-session-token") - api.session.tokens.token_data = { - "token": "my-session-token", - "refreshToken": "r-tok", - } - result = await api.refresh_tokens() - # No exception raised — tokens were read without error - assert result is not None - - async def test_connection_error_raises_http_error(self): - """ConnectionError inside the try block causes error() → HTTPError.""" - api = _make_api(status=200) - api.websession.request.side_effect = ConnectionError("connection refused") - with pytest.raises(web_exceptions.HTTPError): - await api.refresh_tokens() - - async def test_os_error_raises_http_error(self): - """OSError inside the try block causes error() → HTTPError.""" - api = _make_api(status=200) - api.websession.request.side_effect = OSError("network error") - with pytest.raises(web_exceptions.HTTPError): - await api.refresh_tokens() - - async def test_runtime_error_raises_http_error(self): - """RuntimeError inside the try block causes error() → HTTPError.""" - api = _make_api(status=200) - api.websession.request.side_effect = RuntimeError("bad state") - with pytest.raises(web_exceptions.HTTPError): - await api.refresh_tokens() - - async def test_zero_division_raises_http_error(self): - """ZeroDivisionError inside the try block causes error() → HTTPError.""" - api = _make_api(status=200) - api.websession.request.side_effect = ZeroDivisionError("division by zero") - with pytest.raises(web_exceptions.HTTPError): - await api.refresh_tokens() - - async def test_json_return_true_when_ok_status_in_json_return(self): - """When json_return["original"] equals HTTP_OK (200) and token is present, - update_tokens is called and base_url is updated, returning True.""" - api = _make_api(status=200) - # Manually set json_return to simulate a successful response - api.json_return = { - "original": 200, - "parsed": { - "token": "new-token", - "platform": {"endpoint": "https://new.endpoint"}, - }, - } - api.session.update_tokens = AsyncMock() - - # Patch request to be a no-op (doesn't modify json_return) - with patch.object(api, "request", new_callable=AsyncMock) as mock_req: - mock_req.return_value = MagicMock() - result = await api.refresh_tokens() - - assert result is True - api.session.update_tokens.assert_called_once_with(api.json_return["parsed"]) - assert api.base_url == "https://new.endpoint" - - async def test_json_return_true_without_token_in_parsed(self): - """When json_return["original"] == HTTP_OK but no 'token' in parsed, - update_tokens is NOT called and returns True.""" - api = _make_api(status=200) - api.json_return = { - "original": 200, - "parsed": {"other_key": "value"}, - } - api.session.update_tokens = AsyncMock() - - with patch.object(api, "request", new_callable=AsyncMock) as mock_req: - mock_req.return_value = MagicMock() - result = await api.refresh_tokens() - - assert result is True - api.session.update_tokens.assert_not_called() - - -# --------------------------------------------------------------------------- -# Tests: motion_sensor() — lines 213-235 -# --------------------------------------------------------------------------- - - -class TestMotionSensor: - """Cover lines 215-235: motion_sensor() success and error paths.""" - - async def test_success_returns_status_and_parsed(self): - """Successful call returns status and parsed JSON.""" - payload = [{"event": "motion", "timestamp": 1234567890}] - api = _make_api(status=200, json_data=payload) - # motion_sensor uses urls["base"] which doesn't exist in HiveApiAsync; - # add it so the URL can be constructed - api.urls["base"] = "" - sensor = {"type": "motionsensor", "id": "sensor-001"} - - result = await api.motion_sensor(sensor, fromepoch=1000000, toepoch=2000000) - - assert result["original"] == 200 - assert result["parsed"] == payload - - async def test_url_is_built_correctly(self): - """Verifies the URL is assembled with correct sensor type and id.""" - api = _make_api(status=200, json_data=[]) - api.urls["base"] = "https://beekeeper-uk.hivehome.com/1.0" - sensor = {"type": "contactsensor", "id": "abc-123"} - - captured_url = [] - original_request = api.request - - async def capture_request(method, url, **kwargs): - captured_url.append(url) - return await original_request(method, url, **kwargs) - - with patch.object(api, "request", side_effect=capture_request): - await api.motion_sensor(sensor, fromepoch=100, toepoch=200) - - assert len(captured_url) == 1 - assert "contactsensor" in captured_url[0] - assert "abc-123" in captured_url[0] - assert "from=100" in captured_url[0] - assert "to=200" in captured_url[0] - - async def test_os_error_raises_http_error(self): - """OSError inside the try block causes error() → HTTPError.""" - api = _make_api(status=200) - api.urls["base"] = "" - sensor = {"type": "motionsensor", "id": "sensor-001"} - api.websession.request.side_effect = OSError("fail") - with pytest.raises(web_exceptions.HTTPError): - await api.motion_sensor(sensor, fromepoch=1000, toepoch=2000) - - async def test_runtime_error_raises_http_error(self): - """RuntimeError inside the try block causes error() → HTTPError.""" - api = _make_api(status=200) - api.urls["base"] = "" - sensor = {"type": "motionsensor", "id": "sensor-002"} - api.websession.request.side_effect = RuntimeError("unexpected") - with pytest.raises(web_exceptions.HTTPError): - await api.motion_sensor(sensor, fromepoch=1000, toepoch=2000) - - async def test_zero_division_raises_http_error(self): - """ZeroDivisionError inside the try block causes error() → HTTPError.""" - api = _make_api(status=200) - api.urls["base"] = "" - sensor = {"type": "motionsensor", "id": "sensor-003"} - api.websession.request.side_effect = ZeroDivisionError() - with pytest.raises(web_exceptions.HTTPError): - await api.motion_sensor(sensor, fromepoch=1000, toepoch=2000) - - -# --------------------------------------------------------------------------- -# Tests: get_weather() — lines 237-249 -# --------------------------------------------------------------------------- - - -class TestGetWeather: - """Cover lines 239-249: get_weather() success, space encoding, and error paths.""" - - async def test_success_returns_status_and_parsed(self): - """Successful call returns status and parsed weather JSON.""" - payload = {"temperature": {"value": 15, "unit": "C"}} - api = _make_api(status=200, json_data=payload) - - result = await api.get_weather("?lat=51.5&lon=-0.1") - - assert result["original"] == 200 - assert result["parsed"] == payload - - async def test_space_in_weather_url_is_encoded(self): - """Spaces in the weather_url are replaced with %20.""" - api = _make_api(status=200, json_data={}) - - captured_url = [] - original_request = api.request - - async def capture_request(method, url, **kwargs): - captured_url.append(url) - return await original_request(method, url, **kwargs) - - with patch.object(api, "request", side_effect=capture_request): - await api.get_weather("?postcode=SW1A 2AA") - - assert len(captured_url) == 1 - assert " " not in captured_url[0] - assert "%20" in captured_url[0] - - async def test_url_is_prefixed_with_weather_base(self): - """The weather base URL is prepended to the given weather_url.""" - api = _make_api(status=200, json_data={}) - - captured_url = [] - original_request = api.request - - async def capture_request(method, url, **kwargs): - captured_url.append(url) - return await original_request(method, url, **kwargs) - - with patch.object(api, "request", side_effect=capture_request): - await api.get_weather("?lat=51.5") - - assert captured_url[0].startswith("https://weather.prod.bgchprod.info/weather") - - async def test_os_error_raises_http_error(self): - """OSError inside the try block causes error() → HTTPError.""" - api = _make_api(status=200) - api.websession.request.side_effect = OSError("network fail") - with pytest.raises(web_exceptions.HTTPError): - await api.get_weather("?lat=51.5") - - async def test_runtime_error_raises_http_error(self): - """RuntimeError inside the try block causes error() → HTTPError.""" - api = _make_api(status=200) - api.websession.request.side_effect = RuntimeError("unexpected") - with pytest.raises(web_exceptions.HTTPError): - await api.get_weather("?lat=51.5") - - async def test_zero_division_raises_http_error(self): - """ZeroDivisionError inside the try block causes error() → HTTPError.""" - api = _make_api(status=200) - api.websession.request.side_effect = ZeroDivisionError() - with pytest.raises(web_exceptions.HTTPError): - await api.get_weather("?lat=51.5") - - async def test_connection_error_raises_http_error(self): - """ConnectionError inside the try block causes error() → HTTPError.""" - api = _make_api(status=200) - api.websession.request.side_effect = ConnectionError("disconnected") - with pytest.raises(web_exceptions.HTTPError): - await api.get_weather("?lat=51.5") - - -# --------------------------------------------------------------------------- -# Tests: request() — url=None and resp.status=None skips the logging branch -# --------------------------------------------------------------------------- - - -class TestRequestUrlOrStatusNone: - """Lines 100->108: when url is None or resp.status is None, skip log → raise directly.""" - - async def test_none_status_skips_log_and_raises_hive_api_error(self): - """resp.status=None causes branch 100->108 (skips the log lines) then raises.""" - api = _make_api(status=200) - # Replace the websession response with one having status=None - bad_resp = _make_mock_response(status=None) - bad_resp.text = AsyncMock(return_value="") - api.websession.request.return_value = bad_resp - with pytest.raises(HiveApiError): - await api.request("get", None) - - -# --------------------------------------------------------------------------- -# Tests: refresh_tokens() — session=None (134->136) -# --------------------------------------------------------------------------- - - -class TestRefreshTokensSessionNone: - """Line 134->136: when self.session is None, skip token_data read (line 135).""" - - async def test_session_none_skips_token_data_read(self): - """When session is None, tokens is not set from session → jsc uses undefined.""" - ws = MagicMock() - ws.request.return_value = _make_mock_response(status=200) - ws.closed = False - ws.close = AsyncMock() - api = HiveApiAsync(hive_session=None, websession=ws) - # tokens is not defined before jsc, so this will raise NameError or UnboundLocalError; - # what we need is that line 134's False branch (134->136) is traversed. - try: - await api.refresh_tokens() - except (NameError, UnboundLocalError, AttributeError): - pass # expected — tokens was never defined since session is None diff --git a/tests/unit/test_hive_auth_async.py b/tests/unit/test_hive_auth_async.py index fb541f8..9b03a7f 100644 --- a/tests/unit/test_hive_auth_async.py +++ b/tests/unit/test_hive_auth_async.py @@ -14,6 +14,7 @@ HiveInvalidPassword, HiveInvalidUsername, HiveRefreshTokenExpired, + HiveUnknownConfiguration, ) # --------------------------------------------------------------------------- @@ -73,18 +74,40 @@ async def _make_auth( return auth +_LOGIN_INFO = { + "UPID": "eu-west-1_TestPool", + "CLIID": "test-client-id", + "REGION": "eu-west-1_TestPool", +} + + # --------------------------------------------------------------------------- # Tests: __init__ # --------------------------------------------------------------------------- class TestHiveAuthAsyncInit: - def test_pool_region_raises_value_error(self): + def test_pool_region_no_longer_accepted(self): from apyhiveapi.api.hive_auth_async import HiveAuthAsync - with pytest.raises(ValueError, match="pool_region"): + with pytest.raises(TypeError): HiveAuthAsync(username="u", password="p", pool_region="eu-west-1") + async def test_async_init_sets_running_loop(self): + from apyhiveapi.api.hive_auth_async import HiveAuthAsync + + auth = HiveAuthAsync(username="u@test.com", password="pass") + assert auth.loop is None # not set until async_init + mock_data = { + "UPID": "eu-west-1_Test", + "CLIID": "client-id", + "REGION": "eu-west-1_Test", + } + with patch.object(auth.api, "get_login_info", return_value=mock_data): + with patch("boto3.client", return_value=MagicMock()): + await auth.async_init() + assert auth.loop is not None + async def test_file_flag_set_for_magic_username(self): from apyhiveapi.api.hive_auth_async import HiveAuthAsync @@ -185,6 +208,49 @@ async def test_user_not_found_raises_invalid_username(self): with pytest.raises(HiveInvalidUsername): await auth.login() + @pytest.mark.asyncio + async def test_direct_authentication_without_challenge_stores_token(self): + """A response with AuthenticationResult and no ChallengeName must not crash.""" + auth = await _make_auth() + auth_result = {"AuthenticationResult": {"AccessToken": "direct-tok"}} + auth.loop.run_in_executor.return_value = auth_result + + result = await auth.login() + + assert result is auth_result + assert auth.access_token == "direct-tok" + + @pytest.mark.asyncio + async def test_login_regenerates_srp_ephemeral_each_attempt(self): + """Each login() must use a fresh SRP (a, A) ephemeral pair.""" + auth = await _make_auth() + auth.loop.run_in_executor.return_value = { + "AuthenticationResult": {"AccessToken": "tok"} + } + + initial_a = auth.large_a_value + await auth.login() + second_a = auth.large_a_value + await auth.login() + third_a = auth.large_a_value + + assert len({initial_a, second_a, third_a}) == 3 + + @pytest.mark.asyncio + async def test_device_login_regenerates_srp_ephemeral(self): + """device_login() must use a fresh SRP (a, A) ephemeral pair.""" + auth = await _make_auth(device_key="dk-1", device_group_key="grp-1") + auth.device_password = "dev-pass" # pragma: allowlist secret + auth.loop.run_in_executor.return_value = {"ChallengeParameters": {}} + + initial_a = auth.large_a_value + with patch.object( + auth, "process_device_challenge", new=AsyncMock(return_value={}) + ): + await auth.device_login() + + assert auth.large_a_value != initial_a + @pytest.mark.asyncio async def test_endpoint_error_on_initiate_raises_api_error(self): auth = await _make_auth() @@ -444,6 +510,14 @@ async def test_new_device_metadata_in_sms_stores_keys(self): assert auth.device_group_key == "sms-grp" assert auth.device_key == "sms-dev" + @pytest.mark.asyncio + async def test_no_authentication_result_key_does_not_raise(self): + auth = await _make_auth() + auth.loop.run_in_executor.return_value = {"ChallengeName": "SMS_MFA"} + result = await auth.sms_2fa("123456", {"Session": "sess-1"}) + assert auth.access_token is None + assert result == {"ChallengeName": "SMS_MFA"} + # --------------------------------------------------------------------------- # Tests: refresh_token @@ -516,3 +590,840 @@ async def test_endpoint_error_raises_api_error(self): auth.loop.run_in_executor.side_effect = _endpoint_error() with pytest.raises(HiveApiError): await auth.refresh_token("tok") + + +# --------------------------------------------------------------------------- +# Migrated from test_hive_auth_async_extended.py +# --------------------------------------------------------------------------- + + +class TestAsyncInit: + """Cover lines 98-112: async_init() sets pool_id, client_id, region and boto3 client.""" + + async def test_async_init_sets_pool_id_and_client_id(self): + """async_init reads login info and sets internal auth fields.""" + from apyhiveapi.api.hive_auth_async import HiveAuthAsync + + auth = HiveAuthAsync(username="user@test.com", password="pass") + auth.client = None + + mock_boto_client = MagicMock() + mock_loop = MagicMock() + mock_loop.run_in_executor = AsyncMock( + side_effect=[_LOGIN_INFO, mock_boto_client] + ) + + with patch("asyncio.get_running_loop", return_value=mock_loop): + await auth.async_init() + + assert auth._pool_id == "eu-west-1_TestPool" + assert auth._client_id == "test-client-id" + assert auth._region == "eu-west-1" + assert auth.client is mock_boto_client + + async def test_async_init_splits_region_correctly(self): + """Region is extracted as the part before the underscore in UPID/REGION.""" + from apyhiveapi.api.hive_auth_async import HiveAuthAsync + + auth = HiveAuthAsync(username="user@test.com", password="pass") + auth.client = None + + login_info = { + "UPID": "ap-southeast-2_XyzPool", + "CLIID": "ap-client", + "REGION": "ap-southeast-2_XyzPool", + } + mock_boto_client = MagicMock() + mock_loop = MagicMock() + mock_loop.run_in_executor = AsyncMock( + side_effect=[login_info, mock_boto_client] + ) + + with patch("asyncio.get_running_loop", return_value=mock_loop): + await auth.async_init() + + assert auth._region == "ap-southeast-2" + + +class TestCalculateA: + """Cover line 141: safety check when big_a % big_n == 0.""" + + async def test_safety_check_raises_when_a_is_zero_mod_n(self): + """If pow(g, a, n) == 0 mod n (i.e., equals big_n or 0), ValueError is raised.""" + auth = await _make_auth() + with patch("builtins.pow", return_value=auth.big_n): + with pytest.raises(ValueError, match="Safety check for A failed"): + auth.calculate_a() + + async def test_safety_check_passes_normally(self): + """Under normal random inputs, calculate_a does not raise and returns positive int.""" + auth = await _make_auth() + result = auth.calculate_a() + assert result > 0 + + +class TestGetPasswordAuthenticationKey: + """Cover lines 155-172: get_password_authentication_key() computes HKDF.""" + + async def test_returns_bytes(self): + """With valid SRP inputs, the method returns a bytes-like value.""" + auth = await _make_auth() + from apyhiveapi.api.srp_crypto import get_random + + server_b_value = hex(get_random(128))[2:] + salt = hex(get_random(16))[2:] + + with patch("apyhiveapi.api.hive_auth_async.calculate_u", return_value=99999): + result = auth.get_password_authentication_key( + "testuser", "testpass", server_b_value, salt + ) + + assert isinstance(result, (bytes, bytearray)) + + async def test_u_value_zero_raises_value_error(self): + """If calculate_u returns 0, ValueError is raised.""" + auth = await _make_auth() + from apyhiveapi.api.srp_crypto import get_random + + server_b_value = hex(get_random(128))[2:] + salt = hex(get_random(16))[2:] + + with patch("apyhiveapi.api.hive_auth_async.calculate_u", return_value=0): + with pytest.raises(ValueError, match="U cannot be zero"): + auth.get_password_authentication_key( + "testuser", "testpass", server_b_value, salt + ) + + async def test_accepts_integer_server_b(self): + """server_b_value can be passed as an integer (handled by _to_int).""" + auth = await _make_auth() + from apyhiveapi.api.srp_crypto import get_random + + server_b_int = get_random(128) + salt = hex(get_random(16))[2:] + + with patch("apyhiveapi.api.hive_auth_async.calculate_u", return_value=12345): + result = auth.get_password_authentication_key( + "testuser", "testpass", server_b_int, salt + ) + + assert isinstance(result, (bytes, bytearray)) + + +class TestProcessChallenge: + """Cover lines 205-254: process_challenge() builds the SRP response.""" + + def _make_challenge_params(self, salt_as_int=False): + import base64 + + salt = "aabbccddee" + if salt_as_int: + salt = int("aabbccddee", 16) + return { + "USER_ID_FOR_SRP": "challenge-user@test.com", + "SALT": salt, + "SRP_B": "ff" * 32, + "SECRET_BLOCK": base64.b64encode(b"secret-block-bytes").decode(), + } + + async def test_returns_required_keys(self): + """Basic challenge response includes mandatory SRP keys.""" + auth = await _make_auth() + fake_hkdf = b"\x00" * 32 + auth.loop.run_in_executor = AsyncMock(return_value=fake_hkdf) + params = self._make_challenge_params() + result = await auth.process_challenge(params) + assert "TIMESTAMP" in result + assert "USERNAME" in result + assert "PASSWORD_CLAIM_SECRET_BLOCK" in result + assert "PASSWORD_CLAIM_SIGNATURE" in result + + async def test_sets_user_id_from_challenge(self): + """process_challenge stores USER_ID_FOR_SRP as self.user_id.""" + auth = await _make_auth() + fake_hkdf = b"\x00" * 32 + auth.loop.run_in_executor = AsyncMock(return_value=fake_hkdf) + params = self._make_challenge_params() + await auth.process_challenge(params) + assert auth.user_id == "challenge-user@test.com" + + async def test_with_client_secret_adds_secret_hash(self): + """When client_secret is set, SECRET_HASH is added to the response.""" + auth = await _make_auth(client_secret="my-secret") + fake_hkdf = b"\x00" * 32 + auth.loop.run_in_executor = AsyncMock(return_value=fake_hkdf) + params = self._make_challenge_params() + result = await auth.process_challenge(params) + assert "SECRET_HASH" in result + + async def test_without_client_secret_no_secret_hash(self): + """When client_secret is None, SECRET_HASH is absent from the response.""" + auth = await _make_auth(client_secret=None) + fake_hkdf = b"\x00" * 32 + auth.loop.run_in_executor = AsyncMock(return_value=fake_hkdf) + params = self._make_challenge_params() + result = await auth.process_challenge(params) + assert "SECRET_HASH" not in result + + async def test_with_device_key_adds_device_key(self): + """When device_key is set, DEVICE_KEY is added to the response.""" + auth = await _make_auth(device_key="dk-challenge") + fake_hkdf = b"\x00" * 32 + auth.loop.run_in_executor = AsyncMock(return_value=fake_hkdf) + params = self._make_challenge_params() + result = await auth.process_challenge(params) + assert result["DEVICE_KEY"] == "dk-challenge" + + async def test_without_device_key_no_device_key_in_response(self): + """When device_key is None, DEVICE_KEY is absent from the response.""" + auth = await _make_auth(device_key=None) + fake_hkdf = b"\x00" * 32 + auth.loop.run_in_executor = AsyncMock(return_value=fake_hkdf) + params = self._make_challenge_params() + result = await auth.process_challenge(params) + assert "DEVICE_KEY" not in result + + async def test_salt_as_integer_triggers_pad_hex(self): + """When SALT is an integer (not str), pad_hex is applied before use.""" + auth = await _make_auth() + fake_hkdf = b"\x00" * 32 + auth.loop.run_in_executor = AsyncMock(return_value=fake_hkdf) + params = self._make_challenge_params(salt_as_int=True) + result = await auth.process_challenge(params) + assert "TIMESTAMP" in result + + +class TestLoginClientNone: + """Cover line 263: when client is None, async_init() is awaited.""" + + async def test_login_calls_async_init_when_client_is_none(self): + """If client is None before login, async_init is called before SRP flow.""" + auth = await _make_auth() + auth.client = None + auth.use_file = False + + auth_result = {"AuthenticationResult": {"AccessToken": "post-init-token"}} + challenge_response = { + "ChallengeName": "PASSWORD_VERIFIER", + "ChallengeParameters": { + "USER_ID_FOR_SRP": "user@test.com", + "SALT": "aabbccdd", + "SRP_B": "ccddee", + "SECRET_BLOCK": "YWJj", + }, + } + + async def fake_async_init(): + auth.client = MagicMock() + auth._client_id = "test-client-id" + auth._pool_id = "eu-west-1_TestPool" + auth._region = "eu-west-1" + + with patch.object(auth, "async_init", side_effect=fake_async_init) as mock_init: + with patch.object( + auth, "process_challenge", new_callable=AsyncMock + ) as mock_ch: + mock_ch.return_value = { + "TIMESTAMP": "Mon Jan 01 00:00:00 UTC 2024", + "USERNAME": "user", + } + auth.loop.run_in_executor = AsyncMock( + side_effect=[challenge_response, auth_result] + ) + result = await auth.login() + + mock_init.assert_called_once() + assert "AuthenticationResult" in result + + +class TestLoginUnsupportedChallenge: + """Cover lines 335-337: non-PASSWORD_VERIFIER challenge raises NotImplementedError.""" + + async def test_new_password_required_raises_not_implemented(self): + """NEW_PASSWORD_REQUIRED challenge is not supported and raises NotImplementedError.""" + auth = await _make_auth() + auth.loop.run_in_executor = AsyncMock( + return_value={ + "ChallengeName": "NEW_PASSWORD_REQUIRED", + "ChallengeParameters": {}, + } + ) + with pytest.raises(NotImplementedError, match="NEW_PASSWORD_REQUIRED"): + await auth.login() + + async def test_custom_challenge_raises_not_implemented(self): + """Any unknown challenge name raises NotImplementedError.""" + auth = await _make_auth() + auth.loop.run_in_executor = AsyncMock( + return_value={ + "ChallengeName": "UNKNOWN_CHALLENGE_TYPE", + "ChallengeParameters": {}, + } + ) + with pytest.raises(NotImplementedError, match="UNKNOWN_CHALLENGE_TYPE"): + await auth.login() + + +class TestLoginResourceNotFound: + """Cover lines 307-311: ResourceNotFoundException in respond_to_auth_challenge.""" + + async def test_resource_not_found_raises_invalid_device_authentication(self): + """ResourceNotFoundException during challenge → HiveInvalidDeviceAuthentication.""" + auth = await _make_auth() + challenge_response = { + "ChallengeName": "PASSWORD_VERIFIER", + "ChallengeParameters": { + "USER_ID_FOR_SRP": "user@test.com", + "SALT": "aabbccdd", + "SRP_B": "ccddee", + "SECRET_BLOCK": "YWJj", + }, + } + resource_err = _named_client_error("ResourceNotFoundException") + with patch.object(auth, "process_challenge", new_callable=AsyncMock) as mock_ch: + mock_ch.return_value = { + "TIMESTAMP": "Mon Jan 01 00:00:00 UTC 2024", + "USERNAME": "user", + } + auth.loop.run_in_executor = AsyncMock( + side_effect=[challenge_response, resource_err] + ) + with pytest.raises(HiveInvalidDeviceAuthentication): + await auth.login() + + +class TestLoginEndpointErrorOnChallenge: + """Cover lines 312-317: EndpointConnectionError during respond_to_auth_challenge.""" + + async def test_endpoint_error_on_challenge_raises_api_error(self): + """EndpointConnectionError during SRP challenge response → HiveApiError.""" + auth = await _make_auth() + challenge_response = { + "ChallengeName": "PASSWORD_VERIFIER", + "ChallengeParameters": { + "USER_ID_FOR_SRP": "user@test.com", + "SALT": "aabbccdd", + "SRP_B": "ccddee", + "SECRET_BLOCK": "YWJj", + }, + } + with patch.object(auth, "process_challenge", new_callable=AsyncMock) as mock_ch: + mock_ch.return_value = { + "TIMESTAMP": "Mon Jan 01 00:00:00 UTC 2024", + "USERNAME": "user", + } + auth.loop.run_in_executor = AsyncMock( + side_effect=[challenge_response, _endpoint_error()] + ) + with pytest.raises(HiveApiError): + await auth.login() + + +class TestLoginResultHandling: + """Cover lines 321-333: AuthenticationResult presence/absence in login result.""" + + async def test_result_without_authentication_result_does_not_store_token(self): + """If result lacks 'AuthenticationResult', access_token is not set.""" + auth = await _make_auth() + challenge_response = { + "ChallengeName": "PASSWORD_VERIFIER", + "ChallengeParameters": { + "USER_ID_FOR_SRP": "user@test.com", + "SALT": "aabbccdd", + "SRP_B": "ccddee", + "SECRET_BLOCK": "YWJj", + }, + } + sms_challenge_result = { + "ChallengeName": "SMS_MFA", + "Session": "session-tok", + "ChallengeParameters": {}, + } + with patch.object(auth, "process_challenge", new_callable=AsyncMock) as mock_ch: + mock_ch.return_value = { + "TIMESTAMP": "Mon Jan 01 00:00:00 UTC 2024", + "USERNAME": "user", + } + auth.loop.run_in_executor = AsyncMock( + side_effect=[challenge_response, sms_challenge_result] + ) + result = await auth.login() + + assert auth.access_token is None + assert result is sms_challenge_result + + async def test_result_with_authentication_result_but_no_new_device_metadata(self): + """AuthenticationResult without NewDeviceMetadata sets access_token only.""" + auth = await _make_auth() + challenge_response = { + "ChallengeName": "PASSWORD_VERIFIER", + "ChallengeParameters": { + "USER_ID_FOR_SRP": "user@test.com", + "SALT": "aabbccdd", + "SRP_B": "ccddee", + "SECRET_BLOCK": "YWJj", + }, + } + auth_result = {"AuthenticationResult": {"AccessToken": "my-access-token"}} + with patch.object(auth, "process_challenge", new_callable=AsyncMock) as mock_ch: + mock_ch.return_value = { + "TIMESTAMP": "Mon Jan 01 00:00:00 UTC 2024", + "USERNAME": "user", + } + auth.loop.run_in_executor = AsyncMock( + side_effect=[challenge_response, auth_result] + ) + await auth.login() + + assert auth.access_token == "my-access-token" + assert auth.device_group_key is None + assert auth.device_key is None + + +class TestDeviceLoginClientNone: + """Cover device_login's async_init call when client is None.""" + + async def test_device_login_calls_async_init_when_client_is_none(self): + """If client is None, async_init is called before proceeding.""" + auth = await _make_auth(device_key="dk-1") + auth.client = None + + async def fake_async_init(): + auth.client = MagicMock() + auth._client_id = "test-client-id" + + with patch.object(auth, "async_init", side_effect=fake_async_init) as mock_init: + auth.loop.run_in_executor = AsyncMock( + side_effect=_named_client_error("ResourceNotFoundException") + ) + with pytest.raises(HiveInvalidDeviceAuthentication): + await auth.device_login() + + mock_init.assert_called_once() + + +class TestSms2faNoNewDeviceMetadata: + """Cover line 441: sms_2fa when NewDeviceMetadata is absent in result.""" + + async def test_no_new_device_metadata_does_not_set_device_keys(self): + """When NewDeviceMetadata is absent, device_group_key and device_key stay None.""" + auth = await _make_auth() + original_group_key = auth.device_group_key + original_device_key = auth.device_key + + sms_result = { + "AuthenticationResult": { + "AccessToken": "sms-access-token", + } + } + auth.loop.run_in_executor = AsyncMock(return_value=sms_result) + + result = await auth.sms_2fa("654321", {"Session": "sess-abc"}) + + assert auth.access_token == "sms-access-token" + assert auth.device_group_key == original_group_key + assert auth.device_key == original_device_key + assert result is sms_result + + +class TestSms2faCodeMismatch: + """Cover lines 424-429: CodeMismatchException raises HiveInvalid2FACode.""" + + async def test_code_mismatch_raises_invalid_2fa_code(self): + """CodeMismatchException in sms_2fa raises HiveInvalid2FACode.""" + auth = await _make_auth() + auth.loop.run_in_executor = AsyncMock( + side_effect=_named_client_error("CodeMismatchException") + ) + with pytest.raises(HiveInvalid2FACode): + await auth.sms_2fa("000000", {"Session": "sess-1"}) + + async def test_not_authorized_raises_invalid_2fa_code(self): + """NotAuthorizedException in sms_2fa raises HiveInvalid2FACode.""" + auth = await _make_auth() + auth.loop.run_in_executor = AsyncMock( + side_effect=_named_client_error("NotAuthorizedException") + ) + with pytest.raises(HiveInvalid2FACode): + await auth.sms_2fa("111111", {"Session": "sess-2"}) + + +class TestRefreshTokenClientNone: + """Cover line 440-441: refresh_token calls async_init when client is None.""" + + async def test_refresh_token_calls_async_init_when_client_is_none(self): + """If client is None, async_init is awaited before refreshing.""" + auth = await _make_auth() + auth.client = None + + result_payload = {"AuthenticationResult": {"AccessToken": "refreshed-tok"}} + + async def fake_async_init(): + auth.client = MagicMock() + + with patch.object(auth, "async_init", side_effect=fake_async_init) as mock_init: + auth.loop.run_in_executor = AsyncMock(return_value=result_payload) + result = await auth.refresh_token("some-refresh-token") + + mock_init.assert_called_once() + assert result is result_payload + + +class TestRefreshTokenResultPath: + """Cover lines 479-485: refresh_token when result is returned normally.""" + + async def test_returns_result_directly(self): + """refresh_token returns the result from Cognito directly.""" + auth = await _make_auth() + result_payload = {"AuthenticationResult": {"AccessToken": "tok-xyz"}} + auth.loop.run_in_executor = AsyncMock(return_value=result_payload) + result = await auth.refresh_token("refresh-tok-abc") + assert result is result_payload + + async def test_with_device_key_includes_device_key_param(self): + """When device_key is set, DEVICE_KEY is included in auth_params.""" + auth = await _make_auth(device_key="dk-refresh-001") + result_payload = {"AuthenticationResult": {"AccessToken": "tok-dk"}} + auth.loop.run_in_executor = AsyncMock(return_value=result_payload) + result = await auth.refresh_token("refresh-tok-dk") + assert result is result_payload + + async def test_without_device_key_sends_only_refresh_token(self): + """When device_key is None, auth_params has only REFRESH_TOKEN.""" + auth = await _make_auth(device_key=None) + result_payload = {"AuthenticationResult": {"AccessToken": "tok-no-dk"}} + auth.loop.run_in_executor = AsyncMock(return_value=result_payload) + result = await auth.refresh_token("refresh-tok-no-dk") + assert result is result_payload + + +class TestLoginInitiateAuthSwallowedClientError: + """Arc 280->288: ClientError caught but class name is not UserNotFoundException.""" + + async def test_other_client_error_in_initiate_auth_falls_through(self): + """Non-UserNotFoundException ClientError raises HiveApiError.""" + auth = await _make_auth() + + wrong_cls = type("SomeOtherError", (botocore.exceptions.ClientError,), {}) + wrong_err = wrong_cls( + {"Error": {"Code": "SomeOtherError", "Message": "msg"}}, "op" + ) + auth.loop.run_in_executor = AsyncMock(side_effect=wrong_err) + + with pytest.raises(HiveApiError): + await auth.login() + + +class TestLoginInitiateAuthSwallowedEndpointError: + """EndpointConnectionError in initiate_auth always raises HiveApiError.""" + + async def test_wrong_name_endpoint_error_in_initiate_auth_raises_api_error(self): + """Any EndpointConnectionError subclass in initiate_auth raises HiveApiError.""" + auth = await _make_auth() + + wrong_cls = type( + "WrongEndpoint", (botocore.exceptions.EndpointConnectionError,), {} + ) + wrong_err = wrong_cls(endpoint_url="https://cognito.eu-west-1.amazonaws.com") + auth.loop.run_in_executor = AsyncMock(side_effect=wrong_err) + + with pytest.raises(HiveApiError): + await auth.login() + + +class TestLoginChallengeSwallowedClientError: + """Arc 307->319: ClientError caught in challenge response with name not matching.""" + + async def test_other_client_error_in_challenge_falls_through(self): + """ClientError that is neither NotAuthorized nor ResourceNotFound raises HiveApiError.""" + auth = await _make_auth() + + challenge_response = { + "ChallengeName": "PASSWORD_VERIFIER", + "ChallengeParameters": { + "USER_ID_FOR_SRP": "user@test.com", + "SALT": "aabbccdd", + "SRP_B": "ccddee", + "SECRET_BLOCK": "YWJj", + }, + } + + wrong_cls = type("ThirdPartyError", (botocore.exceptions.ClientError,), {}) + wrong_err = wrong_cls( + {"Error": {"Code": "ThirdPartyError", "Message": "msg"}}, "op" + ) + + with patch.object(auth, "process_challenge", new_callable=AsyncMock) as mock_ch: + mock_ch.return_value = {"TIMESTAMP": "...", "USERNAME": "user"} + auth.loop.run_in_executor = AsyncMock( + side_effect=[challenge_response, wrong_err] + ) + with pytest.raises(HiveApiError): + await auth.login() + + +class TestLoginChallengeSwallowedEndpointError: + """EndpointConnectionError in respond_to_auth_challenge always raises HiveApiError.""" + + async def test_wrong_name_endpoint_error_in_challenge_raises_api_error(self): + """Any EndpointConnectionError subclass in SRP challenge raises HiveApiError.""" + auth = await _make_auth() + + challenge_response = { + "ChallengeName": "PASSWORD_VERIFIER", + "ChallengeParameters": { + "USER_ID_FOR_SRP": "user@test.com", + "SALT": "aabbccdd", + "SRP_B": "ccddee", + "SECRET_BLOCK": "YWJj", + }, + } + + wrong_cls = type( + "WrongEndpoint", (botocore.exceptions.EndpointConnectionError,), {} + ) + wrong_err = wrong_cls(endpoint_url="https://cognito.eu-west-1.amazonaws.com") + + with patch.object(auth, "process_challenge", new_callable=AsyncMock) as mock_ch: + mock_ch.return_value = {"TIMESTAMP": "...", "USERNAME": "user"} + auth.loop.run_in_executor = AsyncMock( + side_effect=[challenge_response, wrong_err] + ) + with pytest.raises(HiveApiError): + await auth.login() + + +class TestDeviceLoginSuccessPath: + """Lines 364-367, 391: device_login processes device challenge and returns result.""" + + async def test_successful_device_login_returns_auth_result(self): + """Full device_login success: process_device_challenge called, result returned.""" + auth = await _make_auth(device_key="dk-abc", device_group_key="grp-abc") + auth.device_password = "dev-pass" # pragma: allowlist secret + + initial_result = { + "ChallengeParameters": { + "USERNAME": "user@test.com", + "SALT": "aabbccdd", + "SRP_B": "ccddee", + "SECRET_BLOCK": "YWJj", + } + } + final_result = {"AuthenticationResult": {"AccessToken": "device-access-token"}} + + with patch.object( + auth, "process_device_challenge", new_callable=AsyncMock + ) as mock_pdc: + mock_pdc.return_value = { + "TIMESTAMP": "Mon Jan 01 00:00:00 UTC 2024", + "USERNAME": "user@test.com", + "PASSWORD_CLAIM_SECRET_BLOCK": "YWJj", # pragma: allowlist secret + "PASSWORD_CLAIM_SIGNATURE": "sig", # pragma: allowlist secret + "DEVICE_KEY": "dk-abc", + } + auth.loop.run_in_executor = AsyncMock( + side_effect=[initial_result, final_result] + ) + result = await auth.device_login() + + mock_pdc.assert_called_once_with(initial_result["ChallengeParameters"]) + assert result is final_result + + async def test_device_login_calls_second_respond_to_auth_challenge(self): + """Lines 367-375: second respond_to_auth_challenge is called with device challenge.""" + auth = await _make_auth(device_key="dk-xyz", device_group_key="grp-xyz") + auth.device_password = "dev-pass-xyz" # pragma: allowlist secret + + initial_result = { + "ChallengeParameters": { + "USERNAME": "user@test.com", + "SALT": "11223344", + "SRP_B": "55667788", + "SECRET_BLOCK": "dGVzdA==", # pragma: allowlist secret + } + } + final_result = {"AuthenticationResult": {"AccessToken": "tok-xyz"}} + + challenge_resp = { + "TIMESTAMP": "Mon Jan 01 00:00:00 UTC 2024", + "USERNAME": "user@test.com", + "PASSWORD_CLAIM_SECRET_BLOCK": "dGVzdA==", # pragma: allowlist secret + "PASSWORD_CLAIM_SIGNATURE": "sig", # pragma: allowlist secret + "DEVICE_KEY": "dk-xyz", + } + + with patch.object( + auth, "process_device_challenge", new_callable=AsyncMock + ) as mock_pdc: + mock_pdc.return_value = challenge_resp + auth.loop.run_in_executor = AsyncMock( + side_effect=[initial_result, final_result] + ) + result = await auth.device_login() + + assert auth.loop.run_in_executor.call_count == 2 + assert result["AuthenticationResult"]["AccessToken"] == "tok-xyz" + + +class TestDeviceLoginEndpointWrongName: + """Any EndpointConnectionError in device_login always raises HiveApiError.""" + + async def test_wrong_name_endpoint_error_raises_api_error(self): + """Any EndpointConnectionError subclass in device_login raises HiveApiError.""" + auth = await _make_auth(device_key="dk-err", device_group_key="grp-err") + auth.device_password = "dev-pass-err" # pragma: allowlist secret + + wrong_cls = type( + "WrongEndpoint", (botocore.exceptions.EndpointConnectionError,), {} + ) + wrong_err = wrong_cls(endpoint_url="https://cognito.eu-west-1.amazonaws.com") + auth.loop.run_in_executor = AsyncMock(side_effect=wrong_err) + + with pytest.raises(HiveApiError): + await auth.device_login() + + +class TestSms2faUnrecognisedClientError: + """An unrecognised ClientError in sms_2fa must surface as HiveApiError.""" + + async def test_other_client_error_raises_hive_api_error(self): + """A ClientError that is not a 2FA rejection must not be swallowed.""" + auth = await _make_auth() + + auth.loop.run_in_executor = AsyncMock( + side_effect=_named_client_error("LimitExceededException") + ) + + with pytest.raises(HiveApiError): + await auth.sms_2fa("123456", {"Session": "sess-xyz"}) + + +class TestSms2faSwallowedEndpointError: + """Any EndpointConnectionError in sms_2fa raises HiveApiError.""" + + async def test_wrong_name_endpoint_error_raises_api_error(self): + """Any EndpointConnectionError subclass in sms_2fa raises HiveApiError.""" + auth = await _make_auth() + + wrong_cls = type( + "WrongEndpoint", (botocore.exceptions.EndpointConnectionError,), {} + ) + wrong_err = wrong_cls(endpoint_url="https://cognito.eu-west-1.amazonaws.com") + auth.loop.run_in_executor = AsyncMock(side_effect=wrong_err) + + with pytest.raises(HiveApiError): + await auth.sms_2fa("654321", {"Session": "sess-abc"}) + + +class TestRefreshTokenSwallowedEndpointError: + """Any EndpointConnectionError in refresh_token raises HiveApiError.""" + + async def test_wrong_name_endpoint_error_raises_api_error(self): + """Any EndpointConnectionError subclass in refresh_token raises HiveApiError.""" + auth = await _make_auth() + + wrong_cls = type( + "WrongEndpoint", (botocore.exceptions.EndpointConnectionError,), {} + ) + wrong_err = wrong_cls(endpoint_url="https://cognito.eu-west-1.amazonaws.com") + auth.loop.run_in_executor = AsyncMock(side_effect=wrong_err) + + with pytest.raises(HiveApiError): + await auth.refresh_token("some-refresh-token") + + +class TestAsyncInitMissingKeys: + """async_init must raise HiveUnknownConfiguration when login info keys are absent.""" + + async def test_async_init_missing_region_raises_configuration_error(self): + """If REGION is absent from login info, raise HiveUnknownConfiguration.""" + from apyhiveapi.api.hive_auth_async import HiveAuthAsync + + auth = HiveAuthAsync(username="user@test.com", password="pass") + bad_login_info = {"UPID": "eu-west-1_TestPool", "CLIID": "test-client-id"} + mock_loop = MagicMock() + mock_loop.run_in_executor = AsyncMock(side_effect=[bad_login_info]) + with patch("asyncio.get_running_loop", return_value=mock_loop): + with pytest.raises(HiveUnknownConfiguration): + await auth.async_init() + + async def test_async_init_missing_upid_raises_configuration_error(self): + """If UPID is absent from login info, raise HiveUnknownConfiguration.""" + from apyhiveapi.api.hive_auth_async import HiveAuthAsync + + auth = HiveAuthAsync(username="user@test.com", password="pass") + bad_login_info = {"CLIID": "test-client-id", "REGION": "eu-west-1_TestPool"} + mock_loop = MagicMock() + mock_loop.run_in_executor = AsyncMock(side_effect=[bad_login_info]) + with patch("asyncio.get_running_loop", return_value=mock_loop): + with pytest.raises(HiveUnknownConfiguration): + await auth.async_init() + + +class TestGetPasswordAuthKeyNonePoolId: + """get_password_authentication_key must not crash with AttributeError when _pool_id is None.""" + + async def test_none_pool_id_raises_configuration_error(self): + """If _pool_id is None, raise HiveUnknownConfiguration (not AttributeError).""" + auth = await _make_auth() + auth._pool_id = None + with pytest.raises(HiveUnknownConfiguration): + auth.get_password_authentication_key("user", "pass", "DEADBEEF", "ABCDEF") + + +class TestLoginUnhandledClientError: + """Unhandled ClientError codes in login() must raise HiveApiError, not crash.""" + + async def test_initiate_auth_unhandled_error_raises_hive_api_error(self): + """Non-UserNotFoundException ClientError from initiate_auth raises HiveApiError.""" + auth = await _make_auth() + err = botocore.exceptions.ClientError( + {"Error": {"Code": "TooManyRequestsException", "Message": "too many"}}, + "InitiateAuth", + ) + auth.loop.run_in_executor = AsyncMock(side_effect=err) + with pytest.raises(HiveApiError): + await auth.login() + + async def test_respond_to_challenge_unhandled_error_raises_hive_api_error(self): + """Non-NotAuthorized/ResourceNotFound ClientError raises HiveApiError.""" + auth = await _make_auth() + challenge_response = { + "ChallengeName": "PASSWORD_VERIFIER", + "ChallengeParameters": { + "USER_ID_FOR_SRP": "user@test.com", + "SALT": "aabbccdd", + "SRP_B": "ccddee", + "SECRET_BLOCK": "YWJj", + }, + } + respond_err = botocore.exceptions.ClientError( + {"Error": {"Code": "InternalErrorException", "Message": "internal"}}, + "RespondToAuthChallenge", + ) + auth.loop.run_in_executor = AsyncMock( + side_effect=[challenge_response, respond_err] + ) + auth.process_challenge = AsyncMock( + return_value={"TIMESTAMP": "t", "USERNAME": "u"} + ) + with pytest.raises(HiveApiError): + await auth.login() + + +class TestAsyncInitNoneGuard: + """async_init must guard against get_login_info returning None.""" + + async def test_async_init_raises_when_login_info_is_none(self): + """async_init raises HiveUnknownConfiguration when get_login_info returns None.""" + from apyhiveapi.api.hive_auth_async import HiveAuthAsync + + auth = HiveAuthAsync(username="user@test.com", password="pass") + auth.client = None + + mock_loop = MagicMock() + mock_loop.run_in_executor = AsyncMock(return_value=None) + + with patch("asyncio.get_running_loop", return_value=mock_loop): + with pytest.raises(HiveUnknownConfiguration): + await auth.async_init() diff --git a/tests/unit/test_hive_auth_async_extended.py b/tests/unit/test_hive_auth_async_extended.py deleted file mode 100644 index 54d408d..0000000 --- a/tests/unit/test_hive_auth_async_extended.py +++ /dev/null @@ -1,969 +0,0 @@ -"""Extended unit tests for HiveAuthAsync — covers previously uncovered paths.""" - -from __future__ import annotations - -from unittest.mock import AsyncMock, MagicMock, patch - -import botocore.exceptions -import pytest -from apyhiveapi.helper.hive_exceptions import ( - HiveApiError, - HiveInvalid2FACode, - HiveInvalidDeviceAuthentication, -) - -# --------------------------------------------------------------------------- -# Exception factories (same pattern as test_hive_auth_async.py) -# --------------------------------------------------------------------------- - - -def _named_client_error( - code: str, message: str = "" -) -> botocore.exceptions.ClientError: - """Return a ClientError whose __class__.__name__ matches ``code``.""" - cls = type(code, (botocore.exceptions.ClientError,), {}) - return cls( - {"Error": {"Code": code, "Message": message}}, - "operation", - ) - - -def _endpoint_error() -> botocore.exceptions.EndpointConnectionError: - return botocore.exceptions.EndpointConnectionError( - endpoint_url="https://cognito.eu-west-1.amazonaws.com" - ) - - -# --------------------------------------------------------------------------- -# Shared factory -# --------------------------------------------------------------------------- - -_LOGIN_INFO = { - "UPID": "eu-west-1_TestPool", - "CLIID": "test-client-id", - "REGION": "eu-west-1_TestPool", -} - - -async def _make_auth( - username: str = "user@test.com", - password: str = "testpass", - device_key: str | None = None, - device_group_key: str | None = None, - device_password: str | None = None, - client_secret: str | None = None, -): - from apyhiveapi.api.hive_auth_async import HiveAuthAsync - - auth = HiveAuthAsync( - username=username, - password=password, - device_key=device_key, - device_group_key=device_group_key, - device_password=device_password, - client_secret=client_secret, - ) - # Bypass async_init — inject mocked internals directly. - auth.client = MagicMock() - auth._client_id = "test-client-id" - auth._pool_id = "eu-west-1_TestPool" - auth._region = "eu-west-1" - auth.loop = MagicMock() - auth.loop.run_in_executor = AsyncMock() - return auth - - -# --------------------------------------------------------------------------- -# Tests: async_init() — lines 96-112 -# --------------------------------------------------------------------------- - - -class TestAsyncInit: - """Cover lines 98-112: async_init() sets pool_id, client_id, region and - boto3 client.""" - - async def test_async_init_sets_pool_id_and_client_id(self): - """async_init reads login info and sets internal auth fields.""" - from apyhiveapi.api.hive_auth_async import HiveAuthAsync - - auth = HiveAuthAsync(username="user@test.com", password="pass") - auth.client = None # trigger async_init flow - - mock_boto_client = MagicMock() - - auth.loop = MagicMock() - auth.loop.run_in_executor = AsyncMock( - side_effect=[_LOGIN_INFO, mock_boto_client] - ) - - await auth.async_init() - - assert auth._pool_id == "eu-west-1_TestPool" - assert auth._client_id == "test-client-id" - assert auth._region == "eu-west-1" - assert auth.client is mock_boto_client - - async def test_async_init_splits_region_correctly(self): - """Region is extracted as the part before the underscore in UPID/REGION.""" - from apyhiveapi.api.hive_auth_async import HiveAuthAsync - - auth = HiveAuthAsync(username="user@test.com", password="pass") - auth.client = None - - login_info = { - "UPID": "ap-southeast-2_XyzPool", - "CLIID": "ap-client", - "REGION": "ap-southeast-2_XyzPool", - } - mock_boto_client = MagicMock() - auth.loop = MagicMock() - auth.loop.run_in_executor = AsyncMock( - side_effect=[login_info, mock_boto_client] - ) - - await auth.async_init() - - assert auth._region == "ap-southeast-2" - - -# --------------------------------------------------------------------------- -# Tests: calculate_a() safety check — line 140-141 -# --------------------------------------------------------------------------- - - -class TestCalculateA: - """Cover line 141: safety check when big_a % big_n == 0.""" - - async def test_safety_check_raises_when_a_is_zero_mod_n(self): - """If pow(g, a, n) == 0 mod n (i.e., equals big_n or 0), ValueError is raised.""" - auth = await _make_auth() - # Force pow to return auth.big_n so that big_a % big_n == 0 - with patch("builtins.pow", return_value=auth.big_n): - with pytest.raises(ValueError, match="Safety check for A failed"): - auth.calculate_a() - - async def test_safety_check_passes_normally(self): - """Under normal random inputs, calculate_a does not raise and returns positive int.""" - auth = await _make_auth() - # calculate_a was already called during __init__; calling it again should also work - result = auth.calculate_a() - assert result > 0 - - -# --------------------------------------------------------------------------- -# Tests: get_password_authentication_key() — lines 155-172 -# --------------------------------------------------------------------------- - - -class TestGetPasswordAuthenticationKey: - """Cover lines 155-172: get_password_authentication_key() computes HKDF.""" - - async def test_returns_bytes(self): - """With valid SRP inputs, the method returns a bytes-like value.""" - auth = await _make_auth() - from apyhiveapi.api.srp_crypto import get_random - - # Pick a server_b that won't produce u_value == 0 by using a known large value - server_b_value = hex(get_random(128))[2:] - salt = hex(get_random(16))[2:] - - # Patch calculate_u to return a known non-zero value to avoid flakiness - with patch("apyhiveapi.api.hive_auth_async.calculate_u", return_value=99999): - result = auth.get_password_authentication_key( - "testuser", "testpass", server_b_value, salt - ) - - assert isinstance(result, (bytes, bytearray)) - - async def test_u_value_zero_raises_value_error(self): - """If calculate_u returns 0, ValueError is raised.""" - auth = await _make_auth() - from apyhiveapi.api.srp_crypto import get_random - - server_b_value = hex(get_random(128))[2:] - salt = hex(get_random(16))[2:] - - with patch("apyhiveapi.api.hive_auth_async.calculate_u", return_value=0): - with pytest.raises(ValueError, match="U cannot be zero"): - auth.get_password_authentication_key( - "testuser", "testpass", server_b_value, salt - ) - - async def test_accepts_integer_server_b(self): - """server_b_value can be passed as an integer (handled by _to_int).""" - auth = await _make_auth() - from apyhiveapi.api.srp_crypto import get_random - - server_b_int = get_random(128) - salt = hex(get_random(16))[2:] - - with patch("apyhiveapi.api.hive_auth_async.calculate_u", return_value=12345): - result = auth.get_password_authentication_key( - "testuser", "testpass", server_b_int, salt - ) - - assert isinstance(result, (bytes, bytearray)) - - -# --------------------------------------------------------------------------- -# Tests: process_challenge() — lines 203-254 -# --------------------------------------------------------------------------- - - -class TestProcessChallenge: - """Cover lines 205-254: process_challenge() builds the SRP response.""" - - def _make_challenge_params(self, salt_as_int=False): - """Return a minimal valid challenge_parameters dict.""" - import base64 - - salt = "aabbccddee" - if salt_as_int: - salt = int("aabbccddee", 16) - return { - "USER_ID_FOR_SRP": "challenge-user@test.com", - "SALT": salt, - "SRP_B": "ff" * 32, # arbitrary hex - "SECRET_BLOCK": base64.b64encode(b"secret-block-bytes").decode(), - } - - async def test_returns_required_keys(self): - """Basic challenge response includes mandatory SRP keys.""" - auth = await _make_auth() - - fake_hkdf = b"\x00" * 32 - auth.loop.run_in_executor = AsyncMock(return_value=fake_hkdf) - - params = self._make_challenge_params() - result = await auth.process_challenge(params) - - assert "TIMESTAMP" in result - assert "USERNAME" in result - assert "PASSWORD_CLAIM_SECRET_BLOCK" in result - assert "PASSWORD_CLAIM_SIGNATURE" in result - - async def test_sets_user_id_from_challenge(self): - """process_challenge stores USER_ID_FOR_SRP as self.user_id.""" - auth = await _make_auth() - - fake_hkdf = b"\x00" * 32 - auth.loop.run_in_executor = AsyncMock(return_value=fake_hkdf) - - params = self._make_challenge_params() - await auth.process_challenge(params) - - assert auth.user_id == "challenge-user@test.com" - - async def test_with_client_secret_adds_secret_hash(self): - """When client_secret is set, SECRET_HASH is added to the response.""" - auth = await _make_auth(client_secret="my-secret") - - fake_hkdf = b"\x00" * 32 - auth.loop.run_in_executor = AsyncMock(return_value=fake_hkdf) - - params = self._make_challenge_params() - result = await auth.process_challenge(params) - - assert "SECRET_HASH" in result - - async def test_without_client_secret_no_secret_hash(self): - """When client_secret is None, SECRET_HASH is absent from the response.""" - auth = await _make_auth(client_secret=None) - - fake_hkdf = b"\x00" * 32 - auth.loop.run_in_executor = AsyncMock(return_value=fake_hkdf) - - params = self._make_challenge_params() - result = await auth.process_challenge(params) - - assert "SECRET_HASH" not in result - - async def test_with_device_key_adds_device_key(self): - """When device_key is set, DEVICE_KEY is added to the response.""" - auth = await _make_auth(device_key="dk-challenge") - - fake_hkdf = b"\x00" * 32 - auth.loop.run_in_executor = AsyncMock(return_value=fake_hkdf) - - params = self._make_challenge_params() - result = await auth.process_challenge(params) - - assert result["DEVICE_KEY"] == "dk-challenge" - - async def test_without_device_key_no_device_key_in_response(self): - """When device_key is None, DEVICE_KEY is absent from the response.""" - auth = await _make_auth(device_key=None) - - fake_hkdf = b"\x00" * 32 - auth.loop.run_in_executor = AsyncMock(return_value=fake_hkdf) - - params = self._make_challenge_params() - result = await auth.process_challenge(params) - - assert "DEVICE_KEY" not in result - - async def test_salt_as_integer_triggers_pad_hex(self): - """When SALT is an integer (not str), pad_hex is applied before use.""" - auth = await _make_auth() - - fake_hkdf = b"\x00" * 32 - auth.loop.run_in_executor = AsyncMock(return_value=fake_hkdf) - - params = self._make_challenge_params(salt_as_int=True) - # Should not raise; int-type SALT is handled by the isinstance check - result = await auth.process_challenge(params) - - assert "TIMESTAMP" in result - - -# --------------------------------------------------------------------------- -# Tests: login() — client is None triggers async_init (line 262-263) -# --------------------------------------------------------------------------- - - -class TestLoginClientNone: - """Cover line 263: when client is None, async_init() is awaited.""" - - async def test_login_calls_async_init_when_client_is_none(self): - """If client is None before login, async_init is called before SRP flow.""" - auth = await _make_auth() - auth.client = None # reset to trigger the branch - auth.use_file = False # ensure we go through the client-None path - - auth_result = {"AuthenticationResult": {"AccessToken": "post-init-token"}} - challenge_response = { - "ChallengeName": "PASSWORD_VERIFIER", - "ChallengeParameters": { - "USER_ID_FOR_SRP": "user@test.com", - "SALT": "aabbccdd", - "SRP_B": "ccddee", - "SECRET_BLOCK": "YWJj", - }, - } - - async def fake_async_init(): - auth.client = MagicMock() - auth._client_id = "test-client-id" - auth._pool_id = "eu-west-1_TestPool" - auth._region = "eu-west-1" - - with patch.object(auth, "async_init", side_effect=fake_async_init) as mock_init: - with patch.object( - auth, "process_challenge", new_callable=AsyncMock - ) as mock_ch: - mock_ch.return_value = { - "TIMESTAMP": "Mon Jan 01 00:00:00 UTC 2024", - "USERNAME": "user", - } - # After async_init, run_in_executor is called for initiate_auth then - # respond_to_auth_challenge - auth.loop.run_in_executor = AsyncMock( - side_effect=[challenge_response, auth_result] - ) - result = await auth.login() - - mock_init.assert_called_once() - assert "AuthenticationResult" in result - - -# --------------------------------------------------------------------------- -# Tests: login() — unsupported challenge name (lines 335-337) -# --------------------------------------------------------------------------- - - -class TestLoginUnsupportedChallenge: - """Cover lines 335-337: non-PASSWORD_VERIFIER challenge raises NotImplementedError.""" - - async def test_new_password_required_raises_not_implemented(self): - """NEW_PASSWORD_REQUIRED challenge is not supported and raises NotImplementedError.""" - auth = await _make_auth() - auth.loop.run_in_executor = AsyncMock( - return_value={ - "ChallengeName": "NEW_PASSWORD_REQUIRED", - "ChallengeParameters": {}, - } - ) - with pytest.raises(NotImplementedError, match="NEW_PASSWORD_REQUIRED"): - await auth.login() - - async def test_custom_challenge_raises_not_implemented(self): - """Any unknown challenge name raises NotImplementedError.""" - auth = await _make_auth() - auth.loop.run_in_executor = AsyncMock( - return_value={ - "ChallengeName": "UNKNOWN_CHALLENGE_TYPE", - "ChallengeParameters": {}, - } - ) - with pytest.raises(NotImplementedError, match="UNKNOWN_CHALLENGE_TYPE"): - await auth.login() - - -# --------------------------------------------------------------------------- -# Tests: login() — respond_to_auth_challenge ResourceNotFoundException (line 307-311) -# --------------------------------------------------------------------------- - - -class TestLoginResourceNotFound: - """Cover lines 307-311: ResourceNotFoundException in respond_to_auth_challenge.""" - - async def test_resource_not_found_raises_invalid_device_authentication(self): - """ResourceNotFoundException during challenge → HiveInvalidDeviceAuthentication.""" - auth = await _make_auth() - challenge_response = { - "ChallengeName": "PASSWORD_VERIFIER", - "ChallengeParameters": { - "USER_ID_FOR_SRP": "user@test.com", - "SALT": "aabbccdd", - "SRP_B": "ccddee", - "SECRET_BLOCK": "YWJj", - }, - } - resource_err = _named_client_error("ResourceNotFoundException") - with patch.object(auth, "process_challenge", new_callable=AsyncMock) as mock_ch: - mock_ch.return_value = { - "TIMESTAMP": "Mon Jan 01 00:00:00 UTC 2024", - "USERNAME": "user", - } - auth.loop.run_in_executor = AsyncMock( - side_effect=[challenge_response, resource_err] - ) - with pytest.raises(HiveInvalidDeviceAuthentication): - await auth.login() - - -# --------------------------------------------------------------------------- -# Tests: login() — EndpointConnectionError in respond_to_auth_challenge (lines 312-317) -# --------------------------------------------------------------------------- - - -class TestLoginEndpointErrorOnChallenge: - """Cover lines 312-317: EndpointConnectionError during respond_to_auth_challenge.""" - - async def test_endpoint_error_on_challenge_raises_api_error(self): - """EndpointConnectionError during SRP challenge response → HiveApiError.""" - auth = await _make_auth() - challenge_response = { - "ChallengeName": "PASSWORD_VERIFIER", - "ChallengeParameters": { - "USER_ID_FOR_SRP": "user@test.com", - "SALT": "aabbccdd", - "SRP_B": "ccddee", - "SECRET_BLOCK": "YWJj", - }, - } - with patch.object(auth, "process_challenge", new_callable=AsyncMock) as mock_ch: - mock_ch.return_value = { - "TIMESTAMP": "Mon Jan 01 00:00:00 UTC 2024", - "USERNAME": "user", - } - auth.loop.run_in_executor = AsyncMock( - side_effect=[challenge_response, _endpoint_error()] - ) - with pytest.raises(HiveApiError): - await auth.login() - - -# --------------------------------------------------------------------------- -# Tests: login() — result without AuthenticationResult (lines 321-333) -# --------------------------------------------------------------------------- - - -class TestLoginResultHandling: - """Cover lines 321-333: AuthenticationResult presence/absence in login result.""" - - async def test_result_without_authentication_result_does_not_store_token(self): - """If result lacks 'AuthenticationResult', access_token is not set.""" - auth = await _make_auth() - # First call → PASSWORD_VERIFIER challenge - # Second call → result without AuthenticationResult (e.g., SMS_MFA) - challenge_response = { - "ChallengeName": "PASSWORD_VERIFIER", - "ChallengeParameters": { - "USER_ID_FOR_SRP": "user@test.com", - "SALT": "aabbccdd", - "SRP_B": "ccddee", - "SECRET_BLOCK": "YWJj", - }, - } - sms_challenge_result = { - "ChallengeName": "SMS_MFA", - "Session": "session-tok", - "ChallengeParameters": {}, - } - with patch.object(auth, "process_challenge", new_callable=AsyncMock) as mock_ch: - mock_ch.return_value = { - "TIMESTAMP": "Mon Jan 01 00:00:00 UTC 2024", - "USERNAME": "user", - } - auth.loop.run_in_executor = AsyncMock( - side_effect=[challenge_response, sms_challenge_result] - ) - result = await auth.login() - - # access_token was never set - assert auth.access_token is None - assert result is sms_challenge_result - - async def test_result_with_authentication_result_but_no_new_device_metadata(self): - """AuthenticationResult without NewDeviceMetadata sets access_token only.""" - auth = await _make_auth() - challenge_response = { - "ChallengeName": "PASSWORD_VERIFIER", - "ChallengeParameters": { - "USER_ID_FOR_SRP": "user@test.com", - "SALT": "aabbccdd", - "SRP_B": "ccddee", - "SECRET_BLOCK": "YWJj", - }, - } - auth_result = {"AuthenticationResult": {"AccessToken": "my-access-token"}} - with patch.object(auth, "process_challenge", new_callable=AsyncMock) as mock_ch: - mock_ch.return_value = { - "TIMESTAMP": "Mon Jan 01 00:00:00 UTC 2024", - "USERNAME": "user", - } - auth.loop.run_in_executor = AsyncMock( - side_effect=[challenge_response, auth_result] - ) - await auth.login() - - assert auth.access_token == "my-access-token" - assert auth.device_group_key is None - assert auth.device_key is None - - -# --------------------------------------------------------------------------- -# Tests: device_login() — client is None (line 347-348) -# --------------------------------------------------------------------------- - - -class TestDeviceLoginClientNone: - """Cover device_login's async_init call when client is None.""" - - async def test_device_login_calls_async_init_when_client_is_none(self): - """If client is None, async_init is called before proceeding.""" - auth = await _make_auth(device_key="dk-1") - auth.client = None - - async def fake_async_init(): - auth.client = MagicMock() - auth._client_id = "test-client-id" - - with patch.object(auth, "async_init", side_effect=fake_async_init) as mock_init: - auth.loop.run_in_executor = AsyncMock( - side_effect=_named_client_error("ResourceNotFoundException") - ) - with pytest.raises(HiveInvalidDeviceAuthentication): - await auth.device_login() - - mock_init.assert_called_once() - - -# --------------------------------------------------------------------------- -# Tests: sms_2fa() — NewDeviceMetadata absent (line 441) -# --------------------------------------------------------------------------- - - -class TestSms2faNoNewDeviceMetadata: - """Cover line 441: sms_2fa when NewDeviceMetadata is absent in result.""" - - async def test_no_new_device_metadata_does_not_set_device_keys(self): - """When NewDeviceMetadata is absent, device_group_key and device_key stay None.""" - auth = await _make_auth() - original_group_key = auth.device_group_key # None - original_device_key = auth.device_key # None - - sms_result = { - "AuthenticationResult": { - "AccessToken": "sms-access-token", - # No "NewDeviceMetadata" key - } - } - auth.loop.run_in_executor = AsyncMock(return_value=sms_result) - - result = await auth.sms_2fa("654321", {"Session": "sess-abc"}) - - assert auth.access_token == "sms-access-token" - assert auth.device_group_key == original_group_key # unchanged - assert auth.device_key == original_device_key # unchanged - assert result is sms_result - - -# --------------------------------------------------------------------------- -# Tests: sms_2fa() — CodeMismatchException path (lines 424-429) -# --------------------------------------------------------------------------- - - -class TestSms2faCodeMismatch: - """Cover lines 424-429: CodeMismatchException raises HiveInvalid2FACode.""" - - async def test_code_mismatch_raises_invalid_2fa_code(self): - """CodeMismatchException in sms_2fa raises HiveInvalid2FACode.""" - auth = await _make_auth() - auth.loop.run_in_executor = AsyncMock( - side_effect=_named_client_error("CodeMismatchException") - ) - with pytest.raises(HiveInvalid2FACode): - await auth.sms_2fa("000000", {"Session": "sess-1"}) - - async def test_not_authorized_raises_invalid_2fa_code(self): - """NotAuthorizedException in sms_2fa raises HiveInvalid2FACode.""" - auth = await _make_auth() - auth.loop.run_in_executor = AsyncMock( - side_effect=_named_client_error("NotAuthorizedException") - ) - with pytest.raises(HiveInvalid2FACode): - await auth.sms_2fa("111111", {"Session": "sess-2"}) - - -# --------------------------------------------------------------------------- -# Tests: refresh_token() — client is None (line 440-441) -# --------------------------------------------------------------------------- - - -class TestRefreshTokenClientNone: - """Cover line 440-441: refresh_token calls async_init when client is None.""" - - async def test_refresh_token_calls_async_init_when_client_is_none(self): - """If client is None, async_init is awaited before refreshing.""" - auth = await _make_auth() - auth.client = None - - result_payload = {"AuthenticationResult": {"AccessToken": "refreshed-tok"}} - - async def fake_async_init(): - auth.client = MagicMock() - - with patch.object(auth, "async_init", side_effect=fake_async_init) as mock_init: - auth.loop.run_in_executor = AsyncMock(return_value=result_payload) - result = await auth.refresh_token("some-refresh-token") - - mock_init.assert_called_once() - assert result is result_payload - - -# --------------------------------------------------------------------------- -# Tests: refresh_token() — result path when no AuthenticationResult (lines 479-485) -# --------------------------------------------------------------------------- - - -class TestRefreshTokenResultPath: - """Cover lines 479-485: refresh_token when result is returned normally.""" - - async def test_returns_result_directly(self): - """refresh_token returns the result from Cognito directly.""" - auth = await _make_auth() - result_payload = {"AuthenticationResult": {"AccessToken": "tok-xyz"}} - auth.loop.run_in_executor = AsyncMock(return_value=result_payload) - - result = await auth.refresh_token("refresh-tok-abc") - - assert result is result_payload - - async def test_with_device_key_includes_device_key_param(self): - """When device_key is set, DEVICE_KEY is included in auth_params.""" - auth = await _make_auth(device_key="dk-refresh-001") - result_payload = {"AuthenticationResult": {"AccessToken": "tok-dk"}} - auth.loop.run_in_executor = AsyncMock(return_value=result_payload) - - result = await auth.refresh_token("refresh-tok-dk") - - assert result is result_payload - - async def test_without_device_key_sends_only_refresh_token(self): - """When device_key is None, auth_params has only REFRESH_TOKEN.""" - auth = await _make_auth(device_key=None) - result_payload = {"AuthenticationResult": {"AccessToken": "tok-no-dk"}} - auth.loop.run_in_executor = AsyncMock(return_value=result_payload) - - result = await auth.refresh_token("refresh-tok-no-dk") - - assert result is result_payload - - -# --------------------------------------------------------------------------- -# Tests: login() — swallowed ClientError in initiate_auth (line 280->288) -# --------------------------------------------------------------------------- - - -class TestLoginInitiateAuthSwallowedClientError: - """Arc 280->288: ClientError caught but class name is not UserNotFoundException.""" - - async def test_other_client_error_in_initiate_auth_falls_through(self): - """Non-UserNotFoundException ClientError is swallowed; response stays None → TypeError.""" - auth = await _make_auth() - - wrong_cls = type("SomeOtherError", (botocore.exceptions.ClientError,), {}) - wrong_err = wrong_cls( - {"Error": {"Code": "SomeOtherError", "Message": "msg"}}, "op" - ) - auth.loop.run_in_executor = AsyncMock(side_effect=wrong_err) - - # Exception is swallowed; line 288 `response["ChallengeName"]` raises TypeError - # because response is None - with pytest.raises((TypeError, KeyError)): - await auth.login() - - -# --------------------------------------------------------------------------- -# Tests: login() — swallowed EndpointConnectionError in initiate_auth (line 284->288) -# --------------------------------------------------------------------------- - - -class TestLoginInitiateAuthSwallowedEndpointError: - """Arc 284->288: EndpointConnectionError caught but class name is wrong.""" - - async def test_wrong_name_endpoint_error_in_initiate_auth_falls_through(self): - """EndpointConnectionError with wrong name is swallowed; response stays None.""" - auth = await _make_auth() - - wrong_cls = type( - "WrongEndpoint", (botocore.exceptions.EndpointConnectionError,), {} - ) - wrong_err = wrong_cls(endpoint_url="https://cognito.eu-west-1.amazonaws.com") - auth.loop.run_in_executor = AsyncMock(side_effect=wrong_err) - - with pytest.raises((TypeError, KeyError)): - await auth.login() - - -# --------------------------------------------------------------------------- -# Tests: login() — swallowed ClientError in respond_to_auth_challenge (307->319) -# --------------------------------------------------------------------------- - - -class TestLoginChallengeSwallowedClientError: - """Arc 307->319: ClientError caught in challenge response with name not matching.""" - - async def test_other_client_error_in_challenge_falls_through(self): - """ClientError that is neither NotAuthorized nor ResourceNotFound is swallowed.""" - auth = await _make_auth() - - challenge_response = { - "ChallengeName": "PASSWORD_VERIFIER", - "ChallengeParameters": { - "USER_ID_FOR_SRP": "user@test.com", - "SALT": "aabbccdd", - "SRP_B": "ccddee", - "SECRET_BLOCK": "YWJj", - }, - } - - wrong_cls = type("ThirdPartyError", (botocore.exceptions.ClientError,), {}) - wrong_err = wrong_cls( - {"Error": {"Code": "ThirdPartyError", "Message": "msg"}}, "op" - ) - - with patch.object(auth, "process_challenge", new_callable=AsyncMock) as mock_ch: - mock_ch.return_value = {"TIMESTAMP": "...", "USERNAME": "user"} - auth.loop.run_in_executor = AsyncMock( - side_effect=[challenge_response, wrong_err] - ) - # Exception is swallowed; result stays None → TypeError on line 321 - with pytest.raises((TypeError, AttributeError)): - await auth.login() - - -# --------------------------------------------------------------------------- -# Tests: login() — swallowed EndpointConnectionError in challenge (313->319) -# --------------------------------------------------------------------------- - - -class TestLoginChallengeSwallowedEndpointError: - """Arc 313->319: EndpointConnectionError caught with wrong class name in challenge.""" - - async def test_wrong_name_endpoint_error_in_challenge_falls_through(self): - """EndpointConnectionError with wrong name is swallowed; result stays None.""" - auth = await _make_auth() - - challenge_response = { - "ChallengeName": "PASSWORD_VERIFIER", - "ChallengeParameters": { - "USER_ID_FOR_SRP": "user@test.com", - "SALT": "aabbccdd", - "SRP_B": "ccddee", - "SECRET_BLOCK": "YWJj", - }, - } - - wrong_cls = type( - "WrongEndpoint", (botocore.exceptions.EndpointConnectionError,), {} - ) - wrong_err = wrong_cls(endpoint_url="https://cognito.eu-west-1.amazonaws.com") - - with patch.object(auth, "process_challenge", new_callable=AsyncMock) as mock_ch: - mock_ch.return_value = {"TIMESTAMP": "...", "USERNAME": "user"} - auth.loop.run_in_executor = AsyncMock( - side_effect=[challenge_response, wrong_err] - ) - with pytest.raises((TypeError, AttributeError)): - await auth.login() - - -# --------------------------------------------------------------------------- -# Tests: device_login() — success path through process_device_challenge (lines 364-367, 391) -# --------------------------------------------------------------------------- - - -class TestDeviceLoginSuccessPath: - """Lines 364-367, 391: device_login processes device challenge and returns result.""" - - async def test_successful_device_login_returns_auth_result(self): - """Full device_login success: process_device_challenge called, result returned.""" - auth = await _make_auth(device_key="dk-abc", device_group_key="grp-abc") - auth.device_password = "dev-pass" - - initial_result = { - "ChallengeParameters": { - "USERNAME": "user@test.com", - "SALT": "aabbccdd", - "SRP_B": "ccddee", - "SECRET_BLOCK": "YWJj", - } - } - final_result = {"AuthenticationResult": {"AccessToken": "device-access-token"}} - - with patch.object( - auth, "process_device_challenge", new_callable=AsyncMock - ) as mock_pdc: - mock_pdc.return_value = { - "TIMESTAMP": "Mon Jan 01 00:00:00 UTC 2024", - "USERNAME": "user@test.com", - "PASSWORD_CLAIM_SECRET_BLOCK": "YWJj", - "PASSWORD_CLAIM_SIGNATURE": "sig", - "DEVICE_KEY": "dk-abc", - } - auth.loop.run_in_executor = AsyncMock( - side_effect=[initial_result, final_result] - ) - result = await auth.device_login() - - mock_pdc.assert_called_once_with(initial_result["ChallengeParameters"]) - assert result is final_result - - async def test_device_login_calls_second_respond_to_auth_challenge(self): - """Lines 367-375: second respond_to_auth_challenge is called with device challenge.""" - auth = await _make_auth(device_key="dk-xyz", device_group_key="grp-xyz") - auth.device_password = "dev-pass-xyz" - - initial_result = { - "ChallengeParameters": { - "USERNAME": "user@test.com", - "SALT": "11223344", - "SRP_B": "55667788", - "SECRET_BLOCK": "dGVzdA==", - } - } - final_result = {"AuthenticationResult": {"AccessToken": "tok-xyz"}} - - challenge_resp = { - "TIMESTAMP": "Mon Jan 01 00:00:00 UTC 2024", - "USERNAME": "user@test.com", - "PASSWORD_CLAIM_SECRET_BLOCK": "dGVzdA==", - "PASSWORD_CLAIM_SIGNATURE": "sig", - "DEVICE_KEY": "dk-xyz", - } - - with patch.object( - auth, "process_device_challenge", new_callable=AsyncMock - ) as mock_pdc: - mock_pdc.return_value = challenge_resp - auth.loop.run_in_executor = AsyncMock( - side_effect=[initial_result, final_result] - ) - result = await auth.device_login() - - assert auth.loop.run_in_executor.call_count == 2 - assert result["AuthenticationResult"]["AccessToken"] == "tok-xyz" - - -# --------------------------------------------------------------------------- -# Tests: device_login() — wrong-name EndpointConnectionError (line 389) -# --------------------------------------------------------------------------- - - -class TestDeviceLoginEndpointWrongName: - """Line 389: EndpointConnectionError with wrong __class__.__name__ raises - HiveInvalidDeviceAuthentication instead of HiveApiError.""" - - async def test_wrong_name_endpoint_error_raises_invalid_device_auth(self): - """A subclass of EndpointConnectionError with a different name hits line 389.""" - auth = await _make_auth(device_key="dk-err", device_group_key="grp-err") - auth.device_password = "dev-pass-err" - - wrong_cls = type( - "WrongEndpoint", (botocore.exceptions.EndpointConnectionError,), {} - ) - wrong_err = wrong_cls(endpoint_url="https://cognito.eu-west-1.amazonaws.com") - auth.loop.run_in_executor = AsyncMock(side_effect=wrong_err) - - with pytest.raises(HiveInvalidDeviceAuthentication): - await auth.device_login() - - -# --------------------------------------------------------------------------- -# Tests: sms_2fa() — swallowed ClientError (arc 424->435) -# --------------------------------------------------------------------------- - - -class TestSms2faSwallowedClientError: - """Arc 424->435: ClientError caught in sms_2fa with unrecognised class name.""" - - async def test_other_client_error_is_swallowed_returns_none(self): - """Non-matching ClientError is swallowed; result stays None (returned).""" - auth = await _make_auth() - - wrong_cls = type("OtherError", (botocore.exceptions.ClientError,), {}) - wrong_err = wrong_cls({"Error": {"Code": "OtherError", "Message": "msg"}}, "op") - auth.loop.run_in_executor = AsyncMock(side_effect=wrong_err) - - result = await auth.sms_2fa("123456", {"Session": "sess-xyz"}) - assert ( - result is None - ) # sms_2fa initialises result=None; swallowed → returns None - - -# --------------------------------------------------------------------------- -# Tests: sms_2fa() — swallowed EndpointConnectionError (arc 431->435) -# --------------------------------------------------------------------------- - - -class TestSms2faSwallowedEndpointError: - """Arc 431->435: EndpointConnectionError caught with wrong class name in sms_2fa.""" - - async def test_wrong_name_endpoint_error_is_swallowed(self): - """EndpointConnectionError subclass with wrong name is swallowed; returns None.""" - auth = await _make_auth() - - wrong_cls = type( - "WrongEndpoint", (botocore.exceptions.EndpointConnectionError,), {} - ) - wrong_err = wrong_cls(endpoint_url="https://cognito.eu-west-1.amazonaws.com") - auth.loop.run_in_executor = AsyncMock(side_effect=wrong_err) - - result = await auth.sms_2fa("654321", {"Session": "sess-abc"}) - assert result is None - - -# --------------------------------------------------------------------------- -# Tests: refresh_token() — swallowed EndpointConnectionError (arc 479->485) -# --------------------------------------------------------------------------- - - -class TestRefreshTokenSwallowedEndpointError: - """Arc 479->485: EndpointConnectionError caught with wrong class name in refresh_token.""" - - async def test_wrong_name_endpoint_error_is_swallowed_returns_none(self): - """EndpointConnectionError subclass with wrong name is swallowed; result=None returned.""" - auth = await _make_auth() - - wrong_cls = type( - "WrongEndpoint", (botocore.exceptions.EndpointConnectionError,), {} - ) - wrong_err = wrong_cls(endpoint_url="https://cognito.eu-west-1.amazonaws.com") - auth.loop.run_in_executor = AsyncMock(side_effect=wrong_err) - - # result initialised to None; exception swallowed; line 485 reached; returns None - result = await auth.refresh_token("some-refresh-token") - assert result is None diff --git a/tests/unit/test_hive_exceptions.py b/tests/unit/test_hive_exceptions.py new file mode 100644 index 0000000..33ffb5c --- /dev/null +++ b/tests/unit/test_hive_exceptions.py @@ -0,0 +1,93 @@ +"""Unit tests for the hive_exceptions hierarchy.""" + +import pytest +from apyhiveapi.helper.hive_exceptions import ( + FileInUse, + HiveApiError, + HiveAuthCredentialError, + HiveAuthError, + HiveConfigurationError, + HiveError, + HiveFailedToRefreshTokens, + HiveInvalid2FACode, + HiveInvalidDeviceAuthentication, + HiveInvalidPassword, + HiveInvalidUsername, + HiveReauthRequired, + HiveRefreshTokenExpired, + HiveUnknownConfiguration, + NoApiToken, +) + + +class TestHiveErrorBase: + def test_hive_api_error_is_hive_error(self): + assert issubclass(HiveApiError, HiveError) + + def test_hive_auth_error_is_hive_api_error(self): + assert issubclass(HiveAuthError, HiveApiError) + + def test_hive_auth_error_is_hive_error(self): + assert issubclass(HiveAuthError, HiveError) + + def test_hive_refresh_token_expired_is_hive_api_error(self): + assert issubclass(HiveRefreshTokenExpired, HiveApiError) + + def test_hive_failed_to_refresh_is_hive_api_error(self): + assert issubclass(HiveFailedToRefreshTokens, HiveApiError) + + def test_hive_reauth_required_is_hive_error(self): + assert issubclass(HiveReauthRequired, HiveError) + + +class TestCredentialErrors: + def test_invalid_username_is_hive_auth_credential_error(self): + assert issubclass(HiveInvalidUsername, HiveAuthCredentialError) + + def test_invalid_password_is_hive_auth_credential_error(self): + assert issubclass(HiveInvalidPassword, HiveAuthCredentialError) + + def test_invalid_2fa_is_hive_auth_credential_error(self): + assert issubclass(HiveInvalid2FACode, HiveAuthCredentialError) + + def test_auth_credential_error_is_hive_error(self): + assert issubclass(HiveAuthCredentialError, HiveError) + + +class TestConfigurationErrors: + def test_unknown_config_is_hive_configuration_error(self): + assert issubclass(HiveUnknownConfiguration, HiveConfigurationError) + + def test_invalid_device_auth_is_hive_configuration_error(self): + assert issubclass(HiveInvalidDeviceAuthentication, HiveConfigurationError) + + def test_configuration_error_is_hive_error(self): + assert issubclass(HiveConfigurationError, HiveError) + + +class TestStandaloneExceptions: + def test_file_in_use_is_not_hive_error(self): + assert not issubclass(FileInUse, HiveError) + + def test_no_api_token_is_not_hive_error(self): + assert not issubclass(NoApiToken, HiveError) + + def test_file_in_use_is_exception(self): + assert issubclass(FileInUse, Exception) + + def test_no_api_token_is_exception(self): + assert issubclass(NoApiToken, Exception) + + +class TestInstantiable: + def test_hive_error_is_raiseable(self): + with pytest.raises(HiveError): + raise HiveError("test") + + def test_hive_api_error_caught_as_hive_error(self): + with pytest.raises(HiveError): + raise HiveApiError("test") + + def test_invalid_username_caught_as_hive_error(self): + with pytest.raises(HiveError): + raise HiveInvalidUsername("test") diff --git a/tests/unit/test_hive_helper_extended.py b/tests/unit/test_hive_helper_extended.py deleted file mode 100644 index 8c13ec3..0000000 --- a/tests/unit/test_hive_helper_extended.py +++ /dev/null @@ -1,143 +0,0 @@ -"""Tests for HiveHelper covering previously uncovered lines/branches.""" - -# pylint: disable=protected-access - -from unittest.mock import MagicMock - -from apyhiveapi.helper.hive_helper import HiveHelper -from apyhiveapi.helper.map import Map - - -def _make_helper(entity_cache=None, products=None): - """Build a HiveHelper with a minimally mocked session.""" - session = MagicMock() - session.entity_cache = entity_cache if entity_cache is not None else {} - session.data = Map( - { - "products": products or {}, - "devices": {}, - "actions": {}, - "user": {}, - "minMax": {}, - } - ) - return HiveHelper(session) - - -# --------------------------------------------------------------------------- -# get_device_from_id — branch 133->122 (no match, loop continues then exits) -# --------------------------------------------------------------------------- - - -class TestGetDeviceFromIdBranch: - """Covers the branch where no cache entry matches the requested ID.""" - - def test_returns_false_when_no_match_in_cache(self): - """When entity_cache has entries but none match n_id, returns False. - - This exercises the branch where the 'if n_id in (hive_id, device_id)' - condition is False for every item (133->122 loop-continue then exit). - """ - from apyhiveapi.helper.hivedataclasses import Device - - other_device = Device( - hive_id="other-hive-id", - hive_name="Other", - hive_type="heating", - ha_type="climate", - device_id="other-device-id", - device_name="Other", - device_data={}, - ) - helper = _make_helper(entity_cache={"other-key": other_device}) - result = helper.get_device_from_id("nonexistent-id") - assert result is False - - def test_returns_false_when_cache_is_empty(self): - """When entity_cache is empty, returns False without entering the loop.""" - helper = _make_helper(entity_cache={}) - assert helper.get_device_from_id("any-id") is False - - -# --------------------------------------------------------------------------- -# get_heat_on_demand_device — lines 315-317 -# --------------------------------------------------------------------------- - - -class TestGetHeatOnDemandDevice: - """Covers HiveHelper.get_heat_on_demand_device (lines 315-317).""" - - def test_returns_linked_thermostat(self): - """Looks up TRV by HiveID, then fetches linked thermostat by zone.""" - trv_id = "trv-001" - thermostat_id = "zone-001" - - trv_data = {"state": {"zone": thermostat_id}, "type": "trvcontrol"} - thermostat_data = {"id": thermostat_id, "type": "heating"} - - products = { - trv_id: trv_data, - thermostat_id: thermostat_data, - } - helper = _make_helper(products=products) - - # Device accessed with dict-style key "HiveID" as used inside the method - device = MagicMock() - device.__getitem__ = MagicMock( - side_effect=lambda k: trv_id if k == "HiveID" else None - ) - - result = helper.get_heat_on_demand_device(device) - assert result == thermostat_data - - -# --------------------------------------------------------------------------- -# sanitize_payload — list masking (line 329) and non-str/dict/list fallthrough -# --------------------------------------------------------------------------- - - -class TestSanitizePayload: - """Covers _mask branches for list values and non-string scalar fallthrough.""" - - def test_list_value_under_sensitive_key_is_masked(self): - """A list value under a sensitive key has each element masked.""" - helper = _make_helper() - payload = {"token": ["short", "averylongtoken123"]} - result = helper.sanitize_payload(payload) - # "short" (<=8 chars) → "***", "averylongtoken123" (>8 chars) → "aver...n123" - assert result["token"] == ["***", "aver...n123"] - - def test_non_string_non_dict_non_list_under_sensitive_key_passes_through(self): - """An int/bool/None value under a sensitive key is returned as-is.""" - helper = _make_helper() - payload = {"token": 42} - result = helper.sanitize_payload(payload) - assert result["token"] == 42 - - def test_none_under_sensitive_key_passes_through(self): - """None under a sensitive key is returned unchanged.""" - helper = _make_helper() - payload = {"token": None} - result = helper.sanitize_payload(payload) - assert result["token"] is None - - def test_bool_under_sensitive_key_passes_through(self): - """A bool under a sensitive key is returned unchanged (not a str/dict/list).""" - helper = _make_helper() - payload = {"token": True} - result = helper.sanitize_payload(payload) - assert result["token"] is True - - def test_short_string_masked_as_stars(self): - """A string of 8 characters or fewer is masked as '***'.""" - helper = _make_helper() - payload = {"password": "abc12345"} # exactly 8 chars - result = helper.sanitize_payload(payload) - assert result["password"] == "***" - - def test_long_string_partially_masked(self): - """A string longer than 8 characters is partially masked.""" - helper = _make_helper() - payload = {"password": "supersecretpassword"} - result = helper.sanitize_payload(payload) - assert result["password"] == "supe...word" diff --git a/tests/unit/test_hotwater.py b/tests/unit/test_hotwater.py new file mode 100644 index 0000000..5da4e33 --- /dev/null +++ b/tests/unit/test_hotwater.py @@ -0,0 +1,328 @@ +"""Extended branch-coverage tests for WaterHeater / HiveHotwater.""" + +# pylint: disable=too-few-public-methods +from unittest.mock import AsyncMock, MagicMock + +from apyhiveapi.devices.hotwater import WaterHeater +from apyhiveapi.helper.hivedataclasses import Device, SessionConfig +from apyhiveapi.helper.map import Map + +_SCHEDULE_MODE = "SCHEDULE" +_ON_MODE = "ON" +_OFF_MODE = "OFF" +_BOOST_MODE = "BOOST" + + +def _make_hotwater(products=None, devices=None): + session = MagicMock() + session.data = Map( + { + "products": products or {}, + "devices": devices or {}, + "actions": {}, + "minMax": {}, + "user": {}, + } + ) + session.config = SessionConfig() + session.helper = MagicMock() + session.helper.device_recovered = MagicMock() + session.helper.error_check = AsyncMock() + session.helper.get_schedule_nnl = MagicMock( + return_value={"now": {"value": {"status": _ON_MODE}}, "next": {}, "later": {}} + ) + session.attr = MagicMock() + session.attr.online_offline = AsyncMock(return_value=True) + session.attr.state_attributes = AsyncMock(return_value={}) + session.api = MagicMock() + session.api.set_state = AsyncMock(return_value={"original": 200, "parsed": {}}) + session.hive_refresh_tokens = AsyncMock() + session.get_devices = AsyncMock(return_value=True) + session.should_use_cached_data = MagicMock(return_value=False) + session.get_cached_device = MagicMock(return_value=None) + session.set_cached_device = MagicMock(side_effect=lambda d: d) + return WaterHeater(session=session) + + +def _make_device(hive_id="hw-1", device_id="dev-1"): + return Device( + hive_id=hive_id, + hive_name="Hot Water", + hive_type="hotwater", + ha_type="water_heater", + device_id=device_id, + device_name="Hot Water", + device_data={"online": True}, + ha_name="Hot Water", + ) + + +class TestGetMode: + async def test_boost_mode_reads_previous(self): + """BOOST state resolves to the previous mode stored in props.""" + hw = _make_hotwater( + { + "hw-1": { + "state": {"mode": _BOOST_MODE}, + "props": {"previous": {"mode": _ON_MODE}}, + } + } + ) + result = await hw.get_mode(_make_device()) + assert result == _ON_MODE + + +class TestGetState: + async def test_schedule_mode_boost_off_reads_schedule(self): + """SCHEDULE mode with boost OFF reads state from schedule nnl.""" + hw = _make_hotwater( + { + "hw-1": { + "state": { + "mode": _SCHEDULE_MODE, + "status": _OFF_MODE, + "boost": False, + "schedule": {}, + } + } + } + ) + result = await hw.get_state(_make_device()) + assert result is not None + + async def test_non_schedule_state_mapped(self): + """Direct ON mode/status maps through HIVETOHA without schedule lookup.""" + hw = _make_hotwater( + { + "hw-1": { + "state": { + "mode": _ON_MODE, + "status": _ON_MODE, + "schedule": {}, + } + } + } + ) + result = await hw.get_state(_make_device()) + assert result is not None + + +class TestGetWaterHeater: + async def test_cache_hit_returns_cached(self): + """Cached device is returned immediately when poll is slow/busy.""" + hw = _make_hotwater() + hw.session.should_use_cached_data = MagicMock(return_value=True) + cached_device = _make_device() + cached_device.status = {"current_operation": _SCHEDULE_MODE} + hw.session.get_cached_device = MagicMock(return_value=cached_device) + d = _make_device() + result = await hw.get_water_heater(d) + assert result is cached_device + hw.session.attr.online_offline.assert_not_called() + + async def test_device_data_not_dict_gets_reset(self): + """Non-dict device_data is replaced with an empty dict before use.""" + hw = _make_hotwater( + products={"hw-1": {"state": {"mode": _SCHEDULE_MODE}}}, + devices={"dev-1": {"props": {}, "parent": None}}, + ) + d = _make_device() + d.device_data = None + await hw.get_water_heater(d) + assert isinstance(d.device_data, dict) + + async def test_offline_device_calls_error_check(self): + """Offline device triggers error_check and status defaults to None.""" + hw = _make_hotwater( + products={"hw-1": {}}, + devices={"dev-1": {}}, + ) + hw.session.attr.online_offline = AsyncMock(return_value=False) + d = _make_device() + result = await hw.get_water_heater(d) + hw.session.helper.error_check.assert_called_once() + assert result.status["current_operation"] is None + + +class TestGetScheduleNowNextLater: + async def test_schedule_mode_returns_nnl(self): + """SCHEDULE mode with schedule data returns now/next/later dict.""" + hw = _make_hotwater( + {"hw-1": {"state": {"mode": _SCHEDULE_MODE, "schedule": {"data": []}}}} + ) + result = await hw.get_schedule_now_next_later(_make_device()) + assert result is not None + assert "now" in result + + async def test_non_schedule_mode_returns_none(self): + """Non-SCHEDULE mode returns None.""" + hw = _make_hotwater({"hw-1": {"state": {"mode": _ON_MODE}}}) + result = await hw.get_schedule_now_next_later(_make_device()) + assert result is None + + +# --------------------------------------------------------------------------- +# get_mode — BOOST path with missing props.previous must not log error +# --------------------------------------------------------------------------- + + +class TestHotwaterGetModeBoostMissingPrevious: + """get_mode BOOST path must use safe access, not bare dict that logs a spurious error.""" + + async def test_boost_missing_previous_returns_off_without_error_log( # noqa: E501 + self, + ): + """When mode=BOOST and props has no previous, get_mode returns 'OFF' without error log. + + The hotwater HIVETOHA mapping maps None→'OFF', so missing previous + resolves safely instead of throwing a swallowed KeyError. + """ + from unittest.mock import patch + + hw = _make_hotwater({"hw-1": {"state": {"mode": "BOOST"}, "props": {}}}) + d = _make_device() + with patch("apyhiveapi.devices.hotwater._LOGGER") as mock_log: + result = await hw.get_mode(d) + assert result == "OFF" + mock_log.error.assert_not_called() + + +# --------------------------------------------------------------------------- +# set_boost_off — must return False when prev_mode is None (not send mode=None) +# --------------------------------------------------------------------------- + + +class TestHotwaterSetBoostOffNullPrevMode: + """HiveHotwater.set_boost_off returns False when prev_mode is None.""" + + async def test_set_boost_off_returns_false_when_prev_mode_missing(self): + """set_boost_off returns False and does not call API when prev mode is absent.""" + from apyhiveapi.devices.hotwater import HiveHotwater + + class StubHotwater(HiveHotwater): + """Concrete stub for testing.""" + + h = StubHotwater() + h.session = MagicMock() + h.session.data.products = {"h1": {"state": {"mode": "BOOST"}, "props": {}}} + h._execute_state_change = AsyncMock(return_value=True) + h.get_boost_status = AsyncMock(return_value="ON") + + d = Device( + hive_id="h1", + hive_name="T", + hive_type="hotwater", + ha_type="water_heater", + device_id="d1", + device_name="T", + device_data={"online": True}, + ha_name="Hotwater", + ) + result = await h.set_boost_off(d) + assert result is False + h._execute_state_change.assert_not_called() + + +class TestHotwaterGetStateScheduleGuard: + """get_state handles empty schedule NNL without KeyError.""" + + async def test_get_state_schedule_guard_empty_nnl(self): + """get_state does not crash when get_schedule_nnl returns empty dict.""" + from apyhiveapi.devices.hotwater import HiveHotwater + + class StubHotwater(HiveHotwater): + pass + + h = StubHotwater() + h.session = MagicMock() + h.session.data.products = { + "h1": { + "state": {"status": "ON", "mode": "SCHEDULE", "schedule": {}}, + "props": {}, + } + } + h.session.helper.get_schedule_nnl = MagicMock(return_value={}) + h.get_mode = AsyncMock(return_value="SCHEDULE") + h.get_boost_status = AsyncMock(return_value="OFF") + + d = Device( + hive_id="h1", + hive_name="T", + hive_type="hotwater", + ha_type="water_heater", + device_id="d1", + device_name="T", + device_data={"online": True}, + ha_name="Hotwater", + ) + result = await h.get_state(d) + assert result is None or isinstance(result, str) + + +# =========================================================================== +# Migrated from test_remaining_branches.py +# =========================================================================== + + +class TestHotwaterGetModeKeyError: + """Lines 43-44: KeyError in get_mode.""" + + async def test_get_mode_missing_state_returns_none(self): + """Product with no 'state' key causes KeyError → final stays None.""" + hw = _make_hotwater({"hw-1": {}}) + result = await hw.get_mode(_make_device()) + assert result is None + + +class TestHotwaterGetStateKeyError: + """Lines 83-84: KeyError in get_state.""" + + async def test_get_state_missing_status_key_returns_none(self): + """Product 'state' dict missing 'status' key triggers KeyError → None.""" + hw = _make_hotwater({"hw-1": {"state": {"mode": "MANUAL"}}}) + # 'status' key is absent from state → KeyError on data["state"]["status"] + result = await hw.get_state(_make_device()) + assert result is None + + async def test_get_state_missing_schedule_in_schedule_mode_returns_none(self): + """SCHEDULE mode with no 'schedule' key in state causes KeyError → None.""" + hw = _make_hotwater( + { + "hw-1": { + "state": { + "mode": "SCHEDULE", + "status": "ON", + "boost": False, + # no 'schedule' key + } + } + } + ) + result = await hw.get_state(_make_device()) + assert result is None + + +class TestHotwaterScheduleNNLNone: + """Lines 225->227: get_schedule_now_next_later returns None when schedule is absent.""" + + async def test_schedule_none_when_no_schedule_in_state(self): + """SCHEDULE mode product without 'schedule' key → _get_product_state returns None → None.""" + hw = _make_hotwater({"hw-1": {"state": {"mode": "SCHEDULE"}}}) + # _get_product_state(device, "state", "schedule") → None (key absent) + result = await hw.get_schedule_now_next_later(_make_device()) + assert result is None + + +class TestHotwaterGetWaterHeaterCacheMiss: + """Lines 173->180: cache enabled but cached is None → continues with network call.""" + + async def test_cached_none_falls_through(self): + hw = _make_hotwater( + {"hw-1": {"state": {"mode": "ON"}, "props": {}}}, + devices={"dev-1": {"state": {}, "props": {}}}, + ) + hw.session.should_use_cached_data.return_value = True + hw.session.get_cached_device.return_value = None + result = await hw.get_water_heater(_make_device()) + assert result is not None + hw.session.attr.online_offline.assert_called_once() diff --git a/tests/unit/test_hotwater_extended.py b/tests/unit/test_hotwater_extended.py deleted file mode 100644 index 45891f7..0000000 --- a/tests/unit/test_hotwater_extended.py +++ /dev/null @@ -1,162 +0,0 @@ -"""Extended branch-coverage tests for WaterHeater / HiveHotwater.""" - -# pylint: disable=too-few-public-methods -from unittest.mock import AsyncMock, MagicMock - -from apyhiveapi.devices.hotwater import WaterHeater -from apyhiveapi.helper.hivedataclasses import Device, SessionConfig -from apyhiveapi.helper.map import Map - -_SCHEDULE_MODE = "SCHEDULE" -_ON_MODE = "ON" -_OFF_MODE = "OFF" -_BOOST_MODE = "BOOST" -_BOOST_MINS = 30 - - -def _make_hotwater(products=None, devices=None): - session = MagicMock() - session.data = Map( - { - "products": products or {}, - "devices": devices or {}, - "actions": {}, - "minMax": {}, - "user": {}, - } - ) - session.config = SessionConfig() - session.helper = MagicMock() - session.helper.device_recovered = MagicMock() - session.helper.error_check = AsyncMock() - session.helper.get_schedule_nnl = MagicMock( - return_value={"now": {"value": {"status": _ON_MODE}}, "next": {}, "later": {}} - ) - session.attr = MagicMock() - session.attr.online_offline = AsyncMock(return_value=True) - session.attr.state_attributes = AsyncMock(return_value={}) - session.api = MagicMock() - session.api.set_state = AsyncMock(return_value={"original": 200, "parsed": {}}) - session.hive_refresh_tokens = AsyncMock() - session.get_devices = AsyncMock(return_value=True) - session.should_use_cached_data = MagicMock(return_value=False) - session.get_cached_device = MagicMock(return_value=None) - session.set_cached_device = MagicMock(side_effect=lambda d: d) - return WaterHeater(session=session) - - -def _make_device(hive_id="hw-1", device_id="dev-1"): - return Device( - hive_id=hive_id, - hive_name="Hot Water", - hive_type="hotwater", - ha_type="water_heater", - device_id=device_id, - device_name="Hot Water", - device_data={"online": True}, - ha_name="Hot Water", - ) - - -class TestGetMode: - async def test_boost_mode_reads_previous(self): - """BOOST state resolves to the previous mode stored in props.""" - hw = _make_hotwater( - { - "hw-1": { - "state": {"mode": _BOOST_MODE}, - "props": {"previous": {"mode": _ON_MODE}}, - } - } - ) - result = await hw.get_mode(_make_device()) - assert result == _ON_MODE - - -class TestGetState: - async def test_schedule_mode_boost_off_reads_schedule(self): - """SCHEDULE mode with boost OFF reads state from schedule nnl.""" - hw = _make_hotwater( - { - "hw-1": { - "state": { - "mode": _SCHEDULE_MODE, - "status": _OFF_MODE, - "boost": False, - "schedule": {}, - } - } - } - ) - result = await hw.get_state(_make_device()) - assert result is not None - - async def test_non_schedule_state_mapped(self): - """Direct ON mode/status maps through HIVETOHA without schedule lookup.""" - hw = _make_hotwater( - { - "hw-1": { - "state": { - "mode": _ON_MODE, - "status": _ON_MODE, - "schedule": {}, - } - } - } - ) - result = await hw.get_state(_make_device()) - assert result is not None - - -class TestGetWaterHeater: - async def test_cache_hit_returns_cached(self): - """Cached device is returned immediately when poll is slow/busy.""" - hw = _make_hotwater() - hw.session.should_use_cached_data = MagicMock(return_value=True) - cached_device = _make_device() - cached_device.status = {"current_operation": _SCHEDULE_MODE} - hw.session.get_cached_device = MagicMock(return_value=cached_device) - d = _make_device() - result = await hw.get_water_heater(d) - assert result is cached_device - hw.session.attr.online_offline.assert_not_called() - - async def test_device_data_not_dict_gets_reset(self): - """Non-dict device_data is replaced with an empty dict before use.""" - hw = _make_hotwater( - products={"hw-1": {"state": {"mode": _SCHEDULE_MODE}}}, - devices={"dev-1": {"props": {}, "parent": None}}, - ) - d = _make_device() - d.device_data = None - await hw.get_water_heater(d) - assert isinstance(d.device_data, dict) - - async def test_offline_device_calls_error_check(self): - """Offline device triggers error_check and status defaults to None.""" - hw = _make_hotwater( - products={"hw-1": {}}, - devices={"dev-1": {}}, - ) - hw.session.attr.online_offline = AsyncMock(return_value=False) - d = _make_device() - result = await hw.get_water_heater(d) - hw.session.helper.error_check.assert_called_once() - assert result.status["current_operation"] is None - - -class TestGetScheduleNowNextLater: - async def test_schedule_mode_returns_nnl(self): - """SCHEDULE mode with schedule data returns now/next/later dict.""" - hw = _make_hotwater( - {"hw-1": {"state": {"mode": _SCHEDULE_MODE, "schedule": {"data": []}}}} - ) - result = await hw.get_schedule_now_next_later(_make_device()) - assert result is not None - assert "now" in result - - async def test_non_schedule_mode_returns_none(self): - """Non-SCHEDULE mode returns None.""" - hw = _make_hotwater({"hw-1": {"state": {"mode": _ON_MODE}}}) - result = await hw.get_schedule_now_next_later(_make_device()) - assert result is None diff --git a/tests/unit/test_light_extended.py b/tests/unit/test_light.py similarity index 75% rename from tests/unit/test_light_extended.py rename to tests/unit/test_light.py index d1904e7..a6ff385 100644 --- a/tests/unit/test_light_extended.py +++ b/tests/unit/test_light.py @@ -233,3 +233,79 @@ async def test_turn_on_with_color_calls_set_color(self): call_kwargs = session.api.set_state.call_args.kwargs assert call_kwargs.get("colourMode") == "COLOUR" assert call_kwargs.get("hue") == str(color[0]) + + +# --------------------------------------------------------------------------- +# get_brightness — must return int, not float +# --------------------------------------------------------------------------- + + +class TestGetBrightnessReturnsInt: + """get_brightness must return int, not float.""" + + async def test_get_brightness_returns_int(self): + """Brightness value is returned as int (not float) for HA compatibility.""" + from apyhiveapi.devices.light import HiveLight + + class StubLight(HiveLight): + """Concrete stub for testing.""" + + h = StubLight() + h.session = MagicMock() + h.session.data.products = {"h1": {"state": {"brightness": 50}}} + d = Device( + hive_id="h1", + hive_name="L", + hive_type="warmwhitelight", + ha_type="light", + device_id="d1", + device_name="L", + device_data={"online": True}, + ha_name="Light", + ) + result = await h.get_brightness(d) + assert isinstance(result, int), ( + f"Expected int, got {type(result).__name__}: {result!r}" + ) + assert result == 127 + + +class TestGetBrightnessNullValue: + """A null brightness in the API payload must not raise TypeError.""" + + async def test_null_brightness_returns_none(self): + session = _make_session( + products={ + "light-1": { + "state": {"status": "ON", "brightness": None}, + "props": {}, + } + }, + ) + light = Light(session=session) + result = await light.get_brightness(_make_device()) + assert result is None + + +# =========================================================================== +# Migrated from test_remaining_branches.py +# =========================================================================== + + +class TestLightGetLightCacheMiss: + """Lines 141->147: cache enabled but cached is None → normal execution.""" + + async def test_cached_none_falls_through(self): + session = _make_session( + products={ + "light-1": {"state": {"status": "ON", "brightness": 100}, "props": {}} + }, + devices={"dev-1": {"state": {}, "props": {}}}, + ) + light = Light(session=session) + d = _make_device() + session.should_use_cached_data.return_value = True + session.get_cached_device.return_value = None + result = await light.get_light(d) + assert result is not None + session.attr.online_offline.assert_called_once() diff --git a/tests/unit/test_map.py b/tests/unit/test_map.py index f6209f7..ee9ab44 100644 --- a/tests/unit/test_map.py +++ b/tests/unit/test_map.py @@ -1,5 +1,6 @@ """Unit tests for Map — dot-notation dict wrapper.""" +import pytest from apyhiveapi.helper.map import Map @@ -15,10 +16,18 @@ def test_dict_read(): assert m["key"] == "value" -def test_missing_key_returns_none_not_keyerror(): - """Test that missing keys return None instead of raising KeyError.""" +def test_missing_key_raises_attribute_error(): + """Missing attribute access raises AttributeError.""" m = Map({}) - assert m.missing is None + with pytest.raises(AttributeError): + _ = m.missing + + +def test_missing_bracket_key_raises_key_error(): + """Missing bracket access raises KeyError (standard dict behaviour).""" + m = Map({}) + with pytest.raises(KeyError): + _ = m["missing"] def test_nested_access(): @@ -32,3 +41,13 @@ def test_attribute_write(): m = Map({}) m.foo = "bar" assert m["foo"] == "bar" + + +class TestMapDelAttr: + """Map.__delattr__ must raise AttributeError (not KeyError) for missing keys.""" + + def test_delattr_missing_key_raises_attribute_error(self): + """del m.missing raises AttributeError, not KeyError.""" + m = Map({"a": 1}) + with pytest.raises(AttributeError): + del m.nonexistent diff --git a/tests/unit/test_polling.py b/tests/unit/test_polling.py index e518941..77062a0 100644 --- a/tests/unit/test_polling.py +++ b/tests/unit/test_polling.py @@ -3,6 +3,7 @@ # pylint: disable=protected-access,attribute-defined-outside-init,too-few-public-methods import asyncio +from unittest.mock import AsyncMock, MagicMock from apyhiveapi.helper.hivedataclasses import Device from apyhiveapi.session.polling import PollingMixin @@ -191,8 +192,6 @@ class TestPollDevices: async def test_poll_devices_delegates_to_get_devices(self): """_poll_devices calls get_devices('No_ID') and returns its result.""" - from unittest.mock import AsyncMock - p = _make_polling() p.get_devices = AsyncMock(return_value=True) result = await p._poll_devices() @@ -201,9 +200,176 @@ async def test_poll_devices_delegates_to_get_devices(self): async def test_poll_devices_propagates_false(self): """_poll_devices returns False when get_devices returns False.""" - from unittest.mock import AsyncMock - p = _make_polling() p.get_devices = AsyncMock(return_value=False) result = await p._poll_devices() assert result is False + + +# --------------------------------------------------------------------------- +# TestGetDevicesSlowPoll +# --------------------------------------------------------------------------- + + +class TestGetDevicesSlowPoll: + async def test_auth_error_sets_last_poll_slow_false(self): + from apyhiveapi.helper.hive_exceptions import HiveAuthError + + p = _make_polling() + p.api = MagicMock() + p.api.get_all = AsyncMock(side_effect=HiveAuthError()) + p.config = MagicMock() + p.config.file = False + p.tokens = MagicMock() + p._last_poll_slow = True # pre-set to True to confirm it gets cleared + + retry_result = { + "original": 200, + "parsed": {"products": [], "devices": [], "actions": []}, + } + + async def fake_retry_login(): + pass + + async def fake_retry_with_backoff(_fn, **_kwargs): + return retry_result + + p._retry_login = fake_retry_login + p._retry_with_backoff = fake_retry_with_backoff + p.hive_refresh_tokens = AsyncMock() + p.data = MagicMock() + p.data.products = {} + p.data.devices = {} + p.data.actions = {} + p.config.last_update = MagicMock() + p.config.scan_interval = MagicMock() + + await p.get_devices("No_ID") + assert p._last_poll_slow is False + + async def test_tokens_none_returns_false_without_crash(self): + """get_devices returns False (no crash) when tokens=None and file=False.""" + from apyhiveapi.helper.map import Map + + p = _make_polling() + p.config = MagicMock() + p.config.file = False + p.tokens = None # triggers the "neither branch" path + p.data = Map({"products": {}, "devices": {}, "actions": {}, "user": {}}) + + result = await p.get_devices("No_ID") + assert result is False + + async def test_slow_api_call_sets_last_poll_slow_true(self): + p = _make_polling() + p._slow_poll_threshold = 0 # any call will be "slow" + p.api = MagicMock() + + slow_result = { + "original": 200, + "parsed": {"products": [], "devices": [], "actions": []}, + } + + async def slow_get_all(): + return slow_result + + p.api.get_all = slow_get_all + p.config = MagicMock() + p.config.file = False + p.tokens = MagicMock() + p.hive_refresh_tokens = AsyncMock() + p.data = MagicMock() + p.data.products = {} + p.data.devices = {} + p.data.actions = {} + p.config.last_update = MagicMock() + p.config.scan_interval = MagicMock() + + await p.get_devices("No_ID") + assert p._last_poll_slow is True + + +# --------------------------------------------------------------------------- +# TestGetDevicesNoneGuard — api.get_all() returning None must not crash +# --------------------------------------------------------------------------- + + +class TestGetDevicesNoneGuard: + """api_resp_d must be guarded before dict access when api.get_all() returns None.""" + + async def test_api_returns_none_does_not_crash(self): + """get_devices returns False without crashing when api.get_all() returns None.""" + from apyhiveapi.helper.map import Map + + p = _make_polling() + p.config = MagicMock() + p.config.file = False + p.tokens = MagicMock() + p.api = MagicMock() + p.api.get_all = AsyncMock(return_value=None) + p.hive_refresh_tokens = AsyncMock() + p.data = Map({"products": {}, "devices": {}, "actions": {}, "user": {}}) + + result = await p.get_devices("No_ID") + assert result is False + + +# --------------------------------------------------------------------------- +# TestGetDevicesHomesKey — homes list null/empty must not crash +# --------------------------------------------------------------------------- + + +class TestGetDevicesHomesKey: + """homes key in API response must not crash when homes list is None or empty.""" + + async def _run_get_devices_with_parsed(self, parsed): + from apyhiveapi.helper.map import Map + + p = _make_polling() + p.config = MagicMock() + p.config.file = False + p.tokens = MagicMock() + p.api = MagicMock() + p.api.get_all = AsyncMock(return_value={"original": "200", "parsed": parsed}) + p.hive_refresh_tokens = AsyncMock() + p.data = Map({"products": {}, "devices": {}, "actions": {}, "user": {}}) + return p, await p.get_devices("No_ID") + + async def test_homes_null_does_not_crash(self): + """No crash when API returns homes.homes = None.""" + parsed = { + "products": [], + "devices": [], + "actions": [], + "homes": {"homes": None}, + } + _p, result = await self._run_get_devices_with_parsed(parsed) + assert isinstance(result, bool) + + async def test_homes_empty_list_does_not_crash(self): + """No crash when API returns homes.homes = [].""" + parsed = {"products": [], "devices": [], "actions": [], "homes": {"homes": []}} + _p, result = await self._run_get_devices_with_parsed(parsed) + assert isinstance(result, bool) + + async def test_valid_homes_list_sets_home_id(self): + """home_id is set correctly from a valid homes list.""" + parsed = { + "products": [], + "devices": [], + "actions": [], + "homes": {"homes": [{"id": "home-abc"}]}, + } + p, _ = await self._run_get_devices_with_parsed(parsed) + assert p.config.home_id == "home-abc" + + async def test_homes_data_not_dict_does_not_crash(self): + """No crash when homes_data is a non-dict (list); home_id stays unset.""" + parsed = { + "products": [], + "devices": [], + "actions": [], + "homes": [{"id": "home-list"}], + } + _p, result = await self._run_get_devices_with_parsed(parsed) + assert isinstance(result, bool) diff --git a/tests/unit/test_remaining_branches.py b/tests/unit/test_remaining_branches.py deleted file mode 100644 index cb5b8fb..0000000 --- a/tests/unit/test_remaining_branches.py +++ /dev/null @@ -1,1330 +0,0 @@ -"""Branch-coverage tests for several source modules. - -Covers missing lines in: - - src/devices/heating.py - - src/devices/hotwater.py - - src/devices/light.py - - src/devices/sensor.py - - src/session/auth.py - - src/session/discovery.py -""" - -# pylint: disable=too-few-public-methods,protected-access,attribute-defined-outside-init - -import asyncio -from datetime import datetime, timedelta -from unittest.mock import AsyncMock, MagicMock, patch - -import pytest -from apyhiveapi.devices.heating import Climate -from apyhiveapi.devices.hotwater import WaterHeater -from apyhiveapi.devices.light import Light -from apyhiveapi.devices.sensor import Sensor -from apyhiveapi.helper.hive_exceptions import HiveApiError -from apyhiveapi.helper.hive_helper import HiveHelper -from apyhiveapi.helper.hivedataclasses import ( - Device, - EntityConfig, - SessionConfig, - SessionTokens, -) -from apyhiveapi.helper.map import Map -from apyhiveapi.session.auth import SessionAuthMixin -from apyhiveapi.session.discovery import DiscoveryMixin - -# --------------------------------------------------------------------------- -# Shared helpers — heating -# --------------------------------------------------------------------------- - - -def _make_climate(products=None, devices=None, min_max=None): - session = MagicMock() - session.data = Map( - { - "products": products or {}, - "devices": devices or {}, - "actions": {}, - "minMax": min_max or {}, - "user": {}, - } - ) - session.config = SessionConfig() - session.helper = MagicMock() - session.helper.device_recovered = MagicMock() - session.helper.error_check = AsyncMock() - session.helper.get_schedule_nnl = MagicMock( - return_value={"now": {}, "next": {}, "later": {}} - ) - session.attr = MagicMock() - session.attr.online_offline = AsyncMock(return_value=True) - session.attr.state_attributes = AsyncMock(return_value={}) - session.api = MagicMock() - session.api.set_state = AsyncMock(return_value={"original": 200, "parsed": {}}) - session.hive_refresh_tokens = AsyncMock() - session.get_devices = AsyncMock(return_value=True) - session.should_use_cached_data = MagicMock(return_value=False) - session.get_cached_device = MagicMock(return_value=None) - session.set_cached_device = MagicMock(side_effect=lambda d: d) - return Climate(session=session) - - -def _make_device(hive_id="heat-1", device_id="dev-1", hive_type="heating"): - return Device( - hive_id=hive_id, - hive_name="Hallway", - hive_type=hive_type, - ha_type="climate", - device_id=device_id, - device_name="Hallway", - device_data={"online": True}, - ha_name="Hallway", - ) - - -# --------------------------------------------------------------------------- -# Shared helpers — hotwater -# --------------------------------------------------------------------------- - - -def _make_hotwater(products=None, devices=None): - session = MagicMock() - session.data = Map( - { - "products": products or {}, - "devices": devices or {}, - "actions": {}, - "minMax": {}, - "user": {}, - } - ) - session.config = SessionConfig() - session.helper = MagicMock() - session.helper.device_recovered = MagicMock() - session.helper.get_schedule_nnl = MagicMock( - return_value={"now": {}, "next": {}, "later": {}} - ) - session.attr = MagicMock() - session.attr.online_offline = AsyncMock(return_value=True) - session.attr.state_attributes = AsyncMock(return_value={}) - session.api = MagicMock() - session.api.set_state = AsyncMock(return_value={"original": 200, "parsed": {}}) - session.hive_refresh_tokens = AsyncMock() - session.get_devices = AsyncMock(return_value=True) - session.should_use_cached_data = MagicMock(return_value=False) - session.get_cached_device = MagicMock(return_value=None) - session.set_cached_device = MagicMock(side_effect=lambda d: d) - return WaterHeater(session=session) - - -def _make_hw_device(hive_id="hw-1", device_id="dev-1"): - return Device( - hive_id=hive_id, - hive_name="Hot Water", - hive_type="hotwater", - ha_type="water_heater", - device_id=device_id, - device_name="Hot Water", - device_data={"online": True}, - ha_name="Hot Water", - ) - - -# --------------------------------------------------------------------------- -# Shared helpers — sensor -# --------------------------------------------------------------------------- - - -def _make_sensor(products=None, devices=None): - session = MagicMock() - session.data = Map( - { - "products": products or {}, - "devices": devices or {}, - "actions": {}, - "minMax": {}, - "user": {}, - } - ) - session.config = SessionConfig() - session.helper = MagicMock() - session.helper.device_recovered = MagicMock() - session.attr = MagicMock() - session.attr.online_offline = AsyncMock(return_value=True) - session.attr.state_attributes = AsyncMock(return_value={}) - session.should_use_cached_data = MagicMock(return_value=False) - session.get_cached_device = MagicMock(return_value=None) - session.set_cached_device = MagicMock(side_effect=lambda d: d) - return Sensor(session=session) - - -def _make_sensor_device(hive_id="sens-1", device_id="dev-1", hive_type="contactsensor"): - return Device( - hive_id=hive_id, - hive_name="Door", - hive_type=hive_type, - ha_type="binary_sensor", - device_id=device_id, - device_name="Door", - device_data={"online": True}, - ha_name="Door", - ) - - -# --------------------------------------------------------------------------- -# Shared helpers — auth -# --------------------------------------------------------------------------- - - -def _make_auth_stub(): - class StubAuth(SessionAuthMixin): - """Concrete subclass used only for testing.""" - - s = StubAuth() - s.auth = MagicMock() - s.auth.DEVICE_VERIFIER_CHALLENGE = "DEVICE_SRP_AUTH" - s.auth.SMS_MFA_CHALLENGE = "SMS_MFA" - s.auth.login = AsyncMock() - s.auth.device_login = AsyncMock() - s.auth.sms_2fa = AsyncMock() - s.auth.refresh_token = AsyncMock() - s.tokens = SessionTokens() - s.tokens.token_data = {"refreshToken": "rt", "token": "", "accessToken": ""} - s.config = SessionConfig() - s.helper = MagicMock() - s.helper.sanitize_payload = MagicMock(return_value={}) - s._refresh_threshold = 0.90 - s._refresh_lock = asyncio.Lock() - return s - - -# --------------------------------------------------------------------------- -# Shared helpers — discovery -# --------------------------------------------------------------------------- - - -def _make_discovery_stub(products=None, devices=None, actions=None): - class StubDiscovery(DiscoveryMixin): - """Concrete subclass used only for testing.""" - - s = StubDiscovery() - s.config = SessionConfig() - s.data = Map( - { - "products": products or {}, - "devices": devices or {}, - "actions": actions or {}, - "user": {"temperatureUnit": "C"}, - "minMax": {}, - } - ) - s.helper = MagicMock() - s.helper.get_device_data = MagicMock( - return_value={ - "id": "dev-1", - "state": {"name": "Test Device"}, - "props": {"online": True}, - } - ) - s.hub_id = None - s.device_list = { - "parent": [], - "binary_sensor": [], - "climate": [], - "light": [], - "sensor": [], - "switch": [], - "water_heater": [], - } - return s - - -# =========================================================================== -# 1. src/devices/heating.py -# =========================================================================== - - -class TestHeatingGetStateKeyError: - """Lines 206-207: KeyError/TypeError branch in get_state.""" - - async def test_get_state_key_error_returns_none(self): - """Missing product entry causes get_current_temperature to return None, - leaving final as None without raising.""" - # products dict is empty — device.hive_id not found → both temp helpers - # return None → the if branch is skipped → final stays None - climate = _make_climate(products={}) - d = _make_device() - result = await climate.get_state(d) - assert result is None - - -class TestHeatingGetHeatOnDemand: - """Line 231: get_heat_on_demand happy path.""" - - async def test_get_heat_on_demand_returns_value(self): - """Returns the nested autoBoost.active value from products.""" - climate = _make_climate({"heat-1": {"props": {"autoBoost": {"active": True}}}}) - result = await climate.get_heat_on_demand(_make_device()) - assert result is True - - async def test_get_heat_on_demand_returns_none_when_missing(self): - """Returns None when the nested path does not exist.""" - climate = _make_climate({"heat-1": {"props": {}}}) - result = await climate.get_heat_on_demand(_make_device()) - assert result is None - - -class TestHeatingSetBoostOffOffMode: - """Lines 321->325: set_boost_off when previous mode is 'OFF' with a real target.""" - - async def test_set_boost_off_off_mode_with_target_restores_target(self): - """Previous mode OFF with a real target value restores that target.""" - climate = _make_climate( - { - "heat-1": { - "type": "heating", - "state": {"boost": 5}, - "props": {"previous": {"mode": "OFF", "target": 18.0}}, - } - } - ) - result = await climate.set_boost_off(_make_device()) - assert result is True - _, kwargs = climate.session.api.set_state.call_args - assert kwargs.get("mode") == "OFF" - assert kwargs.get("target") == 18.0 - - -class TestHeatingSetHeatOnDemand: - """Lines 337-342: set_heat_on_demand calls _execute_state_change with autoBoost kwarg.""" - - async def test_set_heat_on_demand_enabled(self): - """set_heat_on_demand passes autoBoost='ENABLED' to the API.""" - climate = _make_climate({"heat-1": {"type": "heating"}}) - result = await climate.set_heat_on_demand(_make_device(), "ENABLED") - assert result is True - climate.session.api.set_state.assert_called_once() - _, kwargs = climate.session.api.set_state.call_args - assert kwargs.get("autoBoost") == "ENABLED" - - async def test_set_heat_on_demand_disabled(self): - """set_heat_on_demand passes autoBoost='DISABLED' to the API.""" - climate = _make_climate({"heat-1": {"type": "heating"}}) - result = await climate.set_heat_on_demand(_make_device(), "DISABLED") - assert result is True - _, kwargs = climate.session.api.set_state.call_args - assert kwargs.get("autoBoost") == "DISABLED" - - -class TestHeatingGetClimateCacheHit: - """Lines 371->377: get_climate returns cached device when cache is available.""" - - async def test_get_climate_returns_cached_when_available(self): - """Cache hit short-circuits all I/O and returns the cached dict.""" - climate = _make_climate({"heat-1": {"type": "heating"}}) - cached = {"current_temperature": 20.0} - climate.session.should_use_cached_data.return_value = True - climate.session.get_cached_device.return_value = cached - result = await climate.get_climate(_make_device()) - assert result == cached - # No API calls should have been made - climate.session.attr.online_offline.assert_not_called() - - -class TestHeatingGetScheduleNNLKeyError: - """Lines 438-439: KeyError in get_schedule_now_next_later.""" - - async def test_missing_schedule_key_returns_none(self): - """Product with state but no 'schedule' key causes KeyError → returns None.""" - climate = _make_climate( - {"heat-1": {"state": {"mode": "SCHEDULE"}}} - # no 'schedule' key inside state - ) - # Override get_mode to return SCHEDULE directly so the if-branch is entered - climate.session.helper.get_schedule_nnl.side_effect = KeyError("schedule") - # get_mode will read data["state"]["mode"] == "SCHEDULE" → enters the try block - # data["state"]["schedule"] raises KeyError → caught, returns None - result = await climate.get_schedule_now_next_later(_make_device()) - assert result is None - - async def test_schedule_key_error_caught_not_raised(self): - """A KeyError inside the try block does not propagate to the caller.""" - climate = _make_climate({"heat-1": {"state": {"mode": "SCHEDULE"}}}) - # Accessing data["state"]["schedule"] will raise KeyError (key absent) - try: - result = await climate.get_schedule_now_next_later(_make_device()) - except KeyError: - pytest.fail( - "KeyError should have been caught inside get_schedule_now_next_later" - ) - assert result is None - - -# =========================================================================== -# 2. src/devices/hotwater.py -# =========================================================================== - - -class TestHotwaterGetModeKeyError: - """Lines 43-44: KeyError in get_mode.""" - - async def test_get_mode_missing_state_returns_none(self): - """Product with no 'state' key causes KeyError → final stays None.""" - hw = _make_hotwater({"hw-1": {}}) - result = await hw.get_mode(_make_hw_device()) - assert result is None - - -class TestHotwaterGetStateKeyError: - """Lines 83-84: KeyError in get_state.""" - - async def test_get_state_missing_status_key_returns_none(self): - """Product 'state' dict missing 'status' key triggers KeyError → None.""" - hw = _make_hotwater({"hw-1": {"state": {"mode": "MANUAL"}}}) - # 'status' key is absent from state → KeyError on data["state"]["status"] - result = await hw.get_state(_make_hw_device()) - assert result is None - - async def test_get_state_missing_schedule_in_schedule_mode_returns_none(self): - """SCHEDULE mode with no 'schedule' key in state causes KeyError → None.""" - hw = _make_hotwater( - { - "hw-1": { - "state": { - "mode": "SCHEDULE", - "status": "ON", - "boost": False, - # no 'schedule' key - } - } - } - ) - result = await hw.get_state(_make_hw_device()) - assert result is None - - -class TestHotwaterGetWaterHeaterCacheHit: - """Lines 173->180: get_water_heater returns cached when cache is available.""" - - async def test_get_water_heater_returns_cached(self): - """Cache hit short-circuits all I/O and returns the cached value.""" - hw = _make_hotwater() - cached = {"current_operation": "ON"} - hw.session.should_use_cached_data.return_value = True - hw.session.get_cached_device.return_value = cached - result = await hw.get_water_heater(_make_hw_device()) - assert result == cached - hw.session.attr.online_offline.assert_not_called() - - -class TestHotwaterScheduleNNLNone: - """Lines 225->227: get_schedule_now_next_later returns None when schedule is absent.""" - - async def test_schedule_none_when_no_schedule_in_state(self): - """SCHEDULE mode product without 'schedule' key → _get_product_state returns None → None.""" - hw = _make_hotwater({"hw-1": {"state": {"mode": "SCHEDULE"}}}) - # _get_product_state(device, "state", "schedule") → None (key absent) - result = await hw.get_schedule_now_next_later(_make_hw_device()) - assert result is None - - async def test_non_schedule_mode_returns_none(self): - """Non-SCHEDULE mode skips the schedule lookup and returns None directly.""" - hw = _make_hotwater({"hw-1": {"state": {"mode": "MANUAL"}}}) - result = await hw.get_schedule_now_next_later(_make_hw_device()) - assert result is None - - async def test_schedule_present_returns_nnl(self): - """When schedule data exists, get_schedule_nnl result is returned.""" - schedule_data = {"foo": "bar"} - hw = _make_hotwater( - {"hw-1": {"state": {"mode": "SCHEDULE", "schedule": schedule_data}}} - ) - expected = {"now": {}, "next": {}, "later": {}} - hw.session.helper.get_schedule_nnl.return_value = expected - result = await hw.get_schedule_now_next_later(_make_hw_device()) - assert result == expected - hw.session.helper.get_schedule_nnl.assert_called_once_with(schedule_data) - - -# =========================================================================== -# 3. src/devices/sensor.py -# =========================================================================== - - -class TestSensorGetStateKeyError: - """Lines 37->42: KeyError in HiveSensor.get_state.""" - - async def test_get_state_missing_type_key_returns_none(self): - """Product with no 'type' key causes KeyError → final stays None.""" - sensor = _make_sensor({"sens-1": {}}) - d = _make_sensor_device() - result = await sensor.get_state(d) - assert result is None - - async def test_get_state_missing_props_key_returns_none(self): - """contactsensor product without 'props' causes KeyError → None.""" - sensor = _make_sensor({"sens-1": {"type": "contactsensor"}}) - d = _make_sensor_device() - result = await sensor.get_state(d) - assert result is None - - -class TestSensorGetSensorCacheHit: - """Lines 92->98: get_sensor returns cached device when cache is available.""" - - async def test_get_sensor_returns_cached(self): - """Cache hit short-circuits all I/O and returns the cached value.""" - sensor = _make_sensor() - cached = {"state": True} - sensor.session.should_use_cached_data.return_value = True - sensor.session.get_cached_device.return_value = cached - result = await sensor.get_sensor(_make_sensor_device()) - assert result == cached - sensor.session.attr.online_offline.assert_not_called() - - -class TestSensorGetSensorProductsFallback: - """Lines 119->122: when device_id not in devices, fall back to products.""" - - async def test_uses_products_when_device_id_absent_from_devices(self): - """device_id not in session.data.devices → hive_id looked up in products.""" - sensor = _make_sensor( - products={"sens-1": {"type": "contactsensor", "props": {"status": "OPEN"}}}, - devices={}, - ) - d = _make_sensor_device( - hive_id="sens-1", device_id="unknown-dev", hive_type="contactsensor" - ) - # Sensor with hive_type in sensor_commands path will be followed; - # the important thing is the products-fallback path is entered without error. - result = await sensor.get_sensor(d) - # Result should be the device (set_cached_device returns the device itself) - assert result is not None - - async def test_products_fallback_data_used_for_device_data(self): - """Props from the products entry propagate to device.device_data.""" - sensor = _make_sensor( - products={ - "sens-1": { - "type": "contactsensor", - "props": {"status": "OPEN", "online": True}, - } - }, - devices={}, - ) - d = _make_sensor_device( - hive_id="sens-1", device_id="unknown-dev", hive_type="contactsensor" - ) - await sensor.get_sensor(d) - # get_state uses self.session.data.products[device.hive_id] directly - # so we just verify it ran without KeyError - - -class TestSensorGetSensorHiveTypesSensorPath: - """Lines 135->146: elif device.hive_type in HIVE_TYPES['Sensor'] path.""" - - async def test_contactsensor_in_hive_types_sensor_takes_else_branch(self): - """contactsensor is in HIVE_TYPES['Sensor'] and not in sensor_commands key set, - so the elif branch is taken.""" - from apyhiveapi.helper.const import HIVE_TYPES, sensor_commands - - # 'contactsensor' is in HIVE_TYPES['Sensor'] and NOT a key in sensor_commands - assert "contactsensor" in HIVE_TYPES["Sensor"] - assert "contactsensor" not in sensor_commands - - sensor = _make_sensor( - products={"sens-1": {"type": "contactsensor", "props": {"status": "OPEN"}}}, - devices={"dev-1": {"props": {"online": True}, "type": "contactsensor"}}, - ) - d = _make_sensor_device( - hive_id="sens-1", device_id="dev-1", hive_type="contactsensor" - ) - d.device_data = {"online": True} - await sensor.get_sensor(d) - # The elif branch sets device.status with 'state' key - assert d.status is not None - assert "state" in d.status - - async def test_motionsensor_in_hive_types_sensor_sets_status(self): - """motionsensor is in HIVE_TYPES['Sensor'] and not in sensor_commands key set.""" - from apyhiveapi.helper.const import HIVE_TYPES, sensor_commands - - assert "motionsensor" in HIVE_TYPES["Sensor"] - assert "motionsensor" not in sensor_commands - - sensor = _make_sensor( - products={ - "sens-1": { - "type": "motionsensor", - "props": {"motion": {"status": True}}, - } - }, - devices={"dev-1": {"props": {"online": True}, "type": "motionsensor"}}, - ) - d = _make_sensor_device( - hive_id="sens-1", device_id="dev-1", hive_type="motionsensor" - ) - d.device_data = {"online": True} - await sensor.get_sensor(d) - assert d.status is not None - assert "state" in d.status - - -# =========================================================================== -# 4. src/session/auth.py -# =========================================================================== - - -class TestRetryWithBackoffNonZeroDelay: - """Line 66: asyncio.sleep called when delay > 0.""" - - async def test_non_zero_delay_is_awaited_but_succeeds(self): - """A non-zero delay entry causes asyncio.sleep to be called; factory still runs.""" - s = _make_auth_stub() - calls = [] - - async def factory(): - calls.append(1) - return "ok" - - with patch( - "apyhiveapi.session.auth.asyncio.sleep", new_callable=AsyncMock - ) as mock_sleep: - result = await s._retry_with_backoff(factory, delays=(5,)) - assert result == "ok" - mock_sleep.assert_called_once_with(5) - assert len(calls) == 1 - - async def test_zero_delay_does_not_call_sleep(self): - """A zero delay skips asyncio.sleep.""" - s = _make_auth_stub() - - async def factory(): - return "done" - - with patch( - "apyhiveapi.session.auth.asyncio.sleep", new_callable=AsyncMock - ) as mock_sleep: - result = await s._retry_with_backoff(factory, delays=(0,)) - assert result == "done" - mock_sleep.assert_not_called() - - -class TestUpdateTokensFlatDictWithExpiresIn: - """Lines 100->106: flat token dict with ExpiresIn sets token_expiry.""" - - async def test_flat_dict_with_expires_in_sets_token_expiry(self): - """Flat token dict containing ExpiresIn updates tokens.token_expiry.""" - s = _make_auth_stub() - flat = { - "token": "t", - "refreshToken": "r", - "accessToken": "a", - "ExpiresIn": 1800, - } - await s.update_tokens(flat) - assert s.tokens.token_expiry == timedelta(seconds=1800) - - async def test_flat_dict_tokens_are_stored(self): - """All token values from flat dict are written to token_data.""" - s = _make_auth_stub() - flat = {"token": "my-id", "refreshToken": "my-rt", "accessToken": "my-at"} - await s.update_tokens(flat) - assert s.tokens.token_data["token"] == "my-id" - assert s.tokens.token_data["refreshToken"] == "my-rt" - assert s.tokens.token_data["accessToken"] == "my-at" - - -class TestLoginApiError: - """Lines 160-162: HiveApiError in login() is logged and re-raised.""" - - async def test_login_api_error_reraises(self): - """HiveApiError raised by auth.login propagates unchanged to the caller.""" - s = _make_auth_stub() - s.auth.login.side_effect = HiveApiError() - with pytest.raises(HiveApiError): - await s.login() - - -class TestHiveRefreshTokensNoAuthResult: - """Lines 341->373: refresh returns a result but without AuthenticationResult.""" - - async def test_result_without_auth_result_does_not_update_tokens(self): - """When refresh_token returns a dict with no AuthenticationResult, tokens stay unchanged.""" - s = _make_auth_stub() - s.tokens.token_created = datetime.now() - timedelta(hours=2) - s.tokens.token_expiry = timedelta(hours=1) - # Return something truthy but without AuthenticationResult - s.auth.refresh_token.return_value = {"SomeOtherKey": "value"} - result = await s.hive_refresh_tokens() - # Tokens must not have been updated - assert s.tokens.token_data["token"] == "" - assert s.tokens.token_data["accessToken"] == "" - # result is what refresh_token returned - assert result == {"SomeOtherKey": "value"} - - async def test_none_refresh_result_does_not_update_tokens(self): - """When refresh_token returns None, tokens are left unchanged.""" - s = _make_auth_stub() - s.tokens.token_created = datetime.now() - timedelta(hours=2) - s.tokens.token_expiry = timedelta(hours=1) - s.auth.refresh_token.return_value = None - await s.hive_refresh_tokens() - assert s.tokens.token_data["token"] == "" - - -# =========================================================================== -# 5. src/session/discovery.py -# =========================================================================== - - -class TestCreateDevicesEntityConfigKwargs: - """Lines 224->226, 226->228, 228->230: entity_config kwarg population in DEVICES loop.""" - - async def test_entity_config_with_all_fields_populates_kwargs(self): - """EntityConfig with ha_name, hive_type, and category all set → all kwargs passed.""" - s = _make_discovery_stub( - devices={ - "dev-1": { - "id": "dev-1", - "type": "hub", - "state": {"name": "My Hub"}, - "props": {}, - } - } - ) - entity_cfg = EntityConfig( - entity_type="binary_sensor", - ha_name="Hub Status", - hive_type="Connectivity", - category="diagnostic", - ) - with patch("apyhiveapi.session.discovery.DEVICES", {"hub": [entity_cfg]}): - result = await s.create_devices() - assert len(result["binary_sensor"]) == 1 - created = result["binary_sensor"][0] - assert created.hive_type == "Connectivity" - assert created.category == "diagnostic" - - async def test_entity_config_empty_fields_does_not_add_to_kwargs(self): - """EntityConfig with empty ha_name and hive_type does not inject those keys.""" - s = _make_discovery_stub( - devices={ - "dev-1": { - "id": "dev-1", - "type": "hub", - "state": {"name": "My Hub"}, - "props": {}, - } - } - ) - entity_cfg = EntityConfig( - entity_type="binary_sensor", - ha_name="", # falsy — should not be added to kwargs - hive_type="", # falsy — should not be added to kwargs - category=None, # None — should not be added to kwargs - ) - with patch("apyhiveapi.session.discovery.DEVICES", {"hub": [entity_cfg]}): - result = await s.create_devices() - # Should still process without error - assert isinstance(result, dict) - - -class TestCreateDevicesDeviceAddListError: - """Lines 232-233: KeyError/TypeError from add_list in DEVICES loop is caught.""" - - async def test_add_list_keyerror_is_caught_not_raised(self): - """KeyError from add_list during device processing is logged, not propagated.""" - s = _make_discovery_stub( - devices={ - "dev-1": { - "id": "dev-1", - "type": "hub", - "state": {"name": "My Hub"}, - "props": {}, - } - } - ) - entity_cfg = EntityConfig( - entity_type="binary_sensor", - ha_name="Hub Status", - hive_type="Connectivity", - category="diagnostic", - ) - with patch("apyhiveapi.session.discovery.DEVICES", {"hub": [entity_cfg]}): - with patch.object(s, "add_list", side_effect=KeyError("bad key")): - # Should complete without raising - result = await s.create_devices() - assert isinstance(result, dict) - - async def test_add_list_typeerror_is_caught_not_raised(self): - """TypeError from add_list during device processing is caught.""" - s = _make_discovery_stub( - devices={ - "dev-1": { - "id": "dev-1", - "type": "hub", - "state": {"name": "My Hub"}, - "props": {}, - } - } - ) - entity_cfg = EntityConfig( - entity_type="binary_sensor", - ha_name="", - hive_type="", - category=None, - ) - with patch("apyhiveapi.session.discovery.DEVICES", {"hub": [entity_cfg]}): - with patch.object(s, "add_list", side_effect=TypeError("bad type")): - result = await s.create_devices() - assert isinstance(result, dict) - - -class TestCreateDevicesActionAddListError: - """Lines 258-259: KeyError/TypeError from add_list in actions loop is caught.""" - - async def test_action_add_list_keyerror_is_caught(self): - """KeyError from add_list when processing an action is logged, not propagated.""" - s = _make_discovery_stub( - actions={"act-1": {"id": "act-1", "name": "Good Night"}} - ) - with patch.object(s, "add_list", side_effect=KeyError("missing")): - result = await s.create_devices() - assert isinstance(result, dict) - - async def test_action_add_list_typeerror_is_caught(self): - """TypeError from add_list when processing an action is caught.""" - s = _make_discovery_stub(actions={"act-1": {"id": "act-1", "name": "Wake Up"}}) - with patch.object(s, "add_list", side_effect=TypeError("type error")): - result = await s.create_devices() - assert isinstance(result, dict) - - -class TestCreateDevicesProductTemperatureUnit: - """Line 305: entity_config.temperature_unit is used when set and entity_type != 'climate'.""" - - async def test_entity_config_temperature_unit_passed_to_add_list(self): - """EntityConfig with temperature_unit set propagates that value as a kwarg.""" - s = _make_discovery_stub( - products={ - "prod-1": { - "id": "prod-1", - "type": "heating", - "state": {"name": "Heating"}, - "props": {}, - } - } - ) - # A non-climate entity with temperature_unit set triggers line 305 - entity_cfg = EntityConfig( - entity_type="sensor", - ha_name="Temp Sensor", - hive_type="Current_Temperature", - category="diagnostic", - temperature_unit="F", - ) - captured_kwargs = {} - - original_add_list = s.add_list - - def capturing_add_list(entity_type, data, **kwargs): - captured_kwargs.update(kwargs) - return original_add_list(entity_type, data, **kwargs) - - with patch("apyhiveapi.session.discovery.PRODUCTS", {"heating": [entity_cfg]}): - with patch.object(s, "add_list", side_effect=capturing_add_list): - await s.create_devices() - - assert captured_kwargs.get("temperature_unit") == "F" - - -class TestCreateDevicesProductAddListAttributeError: - """Lines 308-309: NameError/AttributeError from add_list in products loop is caught.""" - - async def test_product_add_list_attribute_error_is_caught(self): - """AttributeError from add_list when processing a product is caught.""" - s = _make_discovery_stub( - products={ - "prod-1": { - "id": "prod-1", - "type": "heating", - "state": {"name": "Heating"}, - "props": {}, - } - } - ) - entity_cfg = EntityConfig( - entity_type="climate", - ha_name="", - hive_type="", - category=None, - ) - with patch("apyhiveapi.session.discovery.PRODUCTS", {"heating": [entity_cfg]}): - with patch.object(s, "add_list", side_effect=AttributeError("attr error")): - result = await s.create_devices() - assert isinstance(result, dict) - - async def test_product_add_list_name_error_is_caught(self): - """NameError from add_list when processing a product is caught.""" - s = _make_discovery_stub( - products={ - "prod-1": { - "id": "prod-1", - "type": "heating", - "state": {"name": "Heating"}, - "props": {}, - } - } - ) - entity_cfg = EntityConfig( - entity_type="climate", - ha_name="", - hive_type="", - category=None, - ) - with patch("apyhiveapi.session.discovery.PRODUCTS", {"heating": [entity_cfg]}): - with patch.object(s, "add_list", side_effect=NameError("name error")): - result = await s.create_devices() - assert isinstance(result, dict) - - -# =========================================================================== -# Additional False-branch tests: cache-miss paths and elif False paths -# =========================================================================== - - -def _make_light_session(products=None, devices=None): - session = MagicMock() - session.data = Map( - { - "products": products or {}, - "devices": devices or {}, - "actions": {}, - "minMax": {}, - "user": {}, - } - ) - session.config = SessionConfig() - session.helper = MagicMock() - session.helper.device_recovered = MagicMock() - session.helper.error_check = AsyncMock() - session.attr = MagicMock() - session.attr.online_offline = AsyncMock(return_value=True) - session.attr.state_attributes = AsyncMock(return_value={}) - session.api = MagicMock() - session.api.set_state = AsyncMock(return_value={"original": 200, "parsed": {}}) - session.hive_refresh_tokens = AsyncMock() - session.get_devices = AsyncMock(return_value=True) - session.should_use_cached_data = MagicMock(return_value=False) - session.get_cached_device = MagicMock(return_value=None) - session.set_cached_device = MagicMock(side_effect=lambda d: d) - return session - - -def _make_light_device( - hive_id="light-1", device_id="dev-1", hive_type="warmwhitelight" -): - return Device( - hive_id=hive_id, - hive_name="Bulb", - hive_type=hive_type, - ha_type="light", - device_id=device_id, - device_name="Bulb", - device_data={"online": True}, - ha_name="Bulb", - ) - - -# --------------------------------------------------------------------------- -# heating.py: 321->325 — set_boost_off with non-MANUAL/OFF previous mode -# --------------------------------------------------------------------------- - - -class TestHeatingSetBoostOffScheduleMode: - """Lines 321->325: prev_mode not in ('MANUAL','OFF') — target kwarg not added.""" - - async def test_schedule_mode_no_target_kwarg(self): - """SCHEDULE as previous mode does not add a target kwarg.""" - climate = _make_climate( - { - "heat-1": { - "type": "heating", - "state": {"boost": 5}, - "props": {"previous": {"mode": "SCHEDULE"}}, - } - } - ) - result = await climate.set_boost_off(_make_device()) - assert result is True - _, kwargs = climate.session.api.set_state.call_args - assert "target" not in kwargs - assert kwargs.get("mode") == "SCHEDULE" - - -# --------------------------------------------------------------------------- -# heating.py: 371->377 — get_climate: should_use_cached=True but cached is None -# --------------------------------------------------------------------------- - - -class TestHeatingGetClimateCacheMiss: - """Lines 371->377: cache enabled but cached device is None → normal execution.""" - - async def test_cached_none_falls_through_to_normal_path(self): - """should_use_cached_data=True but get_cached_device=None → normal update.""" - climate = _make_climate( - { - "heat-1": { - "state": {"mode": "MANUAL", "target": 20.0}, - "props": {"temperature": 19.0}, - } - }, - devices={"dev-1": {"state": {}, "props": {}}}, - ) - climate.session.should_use_cached_data.return_value = True - climate.session.get_cached_device.return_value = None - result = await climate.get_climate(_make_device()) - assert result is not None - climate.session.attr.online_offline.assert_called_once() - - -# --------------------------------------------------------------------------- -# hotwater.py: 173->180 — same pattern -# --------------------------------------------------------------------------- - - -class TestHotwaterGetWaterHeaterCacheMiss: - """Lines 173->180: cache enabled but cached is None → continues with network call.""" - - async def test_cached_none_falls_through(self): - hw = _make_hotwater( - {"hw-1": {"state": {"mode": "ON"}, "props": {}}}, - devices={"dev-1": {"state": {}, "props": {}}}, - ) - hw.session.should_use_cached_data.return_value = True - hw.session.get_cached_device.return_value = None - result = await hw.get_water_heater(_make_hw_device()) - assert result is not None - hw.session.attr.online_offline.assert_called_once() - - -# --------------------------------------------------------------------------- -# light.py: 141->147 — same pattern -# --------------------------------------------------------------------------- - - -class TestLightGetLightCacheMiss: - """Lines 141->147: cache enabled but cached is None → normal execution.""" - - async def test_cached_none_falls_through(self): - session = _make_light_session( - products={ - "light-1": {"state": {"status": "ON", "brightness": 100}, "props": {}} - }, - devices={"dev-1": {"state": {}, "props": {}}}, - ) - light = Light(session=session) - d = _make_light_device() - session.should_use_cached_data.return_value = True - session.get_cached_device.return_value = None - result = await light.get_light(d) - assert result is not None - session.attr.online_offline.assert_called_once() - - -# --------------------------------------------------------------------------- -# sensor.py: 37->42 — get_state: type neither contactsensor nor motionsensor -# --------------------------------------------------------------------------- - - -class TestSensorGetStateUnknownType: - """Lines 37->42: data['type'] is neither contactsensor nor motionsensor.""" - - async def test_unknown_type_returns_none(self): - """Product with type 'hub' skips both if/elif → final stays None.""" - sensor = _make_sensor({"sens-1": {"type": "hub", "props": {}}}) - d = _make_sensor_device() - result = await sensor.get_state(d) - assert result is None - - -# --------------------------------------------------------------------------- -# sensor.py: 92->98 — get_sensor: cache enabled but cached is None -# --------------------------------------------------------------------------- - - -class TestSensorGetSensorCacheMiss: - """Lines 92->98: should_use_cached_data=True but cached is None.""" - - async def test_cached_none_falls_through(self): - sensor = _make_sensor( - products={"sens-1": {"type": "contactsensor", "props": {"status": "OPEN"}}}, - devices={"dev-1": {"props": {"online": True}, "type": "contactsensor"}}, - ) - sensor.session.should_use_cached_data.return_value = True - sensor.session.get_cached_device.return_value = None - d = _make_sensor_device() - result = await sensor.get_sensor(d) - assert result is not None - sensor.session.attr.online_offline.assert_called_once() - - -# --------------------------------------------------------------------------- -# sensor.py: 119->122 — get_sensor: neither device_id nor hive_id found -# --------------------------------------------------------------------------- - - -class TestSensorGetSensorNoDataFallthrough: - """Lines 119->122: device_id not in devices AND hive_id not in products.""" - - async def test_neither_match_continues_with_empty_data(self): - """data stays empty dict when neither lookup succeeds.""" - sensor = _make_sensor(products={}, devices={}) - d = _make_sensor_device( - hive_id="unknown-hive", device_id="unknown-dev", hive_type="contactsensor" - ) - result = await sensor.get_sensor(d) - # Should not raise; result will be the device (set_cached_device returns it) - assert result is not None - - -# --------------------------------------------------------------------------- -# sensor.py: 135->146 — get_sensor: hive_type not in sensor_commands or HIVE_TYPES["Sensor"] -# --------------------------------------------------------------------------- - - -class TestSensorGetSensorUnknownHiveType: - """Lines 135->146: hive_type not in sensor_commands and not in HIVE_TYPES['Sensor'].""" - - async def test_hive_type_not_in_either_dict_skips_both_branches(self): - """activeplug is neither in sensor_commands nor HIVE_TYPES['Sensor'].""" - sensor = _make_sensor( - devices={"dev-1": {"props": {"online": True}, "type": "activeplug"}} - ) - d = _make_sensor_device( - hive_id="dev-1", device_id="dev-1", hive_type="activeplug" - ) - d.device_data = {"online": True} - result = await sensor.get_sensor(d) - # Neither branch sets device.status; device returned as-is via set_cached_device - assert result is not None - - -# =========================================================================== -# Additional branches: session/auth.py, hive_helper.py, heating.py -# =========================================================================== - - -class TestUpdateTokensUnknownKey: - """session/auth.py 100->106: tokens dict has neither AuthenticationResult nor token.""" - - async def test_unknown_key_does_not_raise_and_does_not_update_tokens(self): - """When neither expected key is present, data stays {}, ExpiresIn check skips.""" - s = _make_auth_stub() - original_token = s.tokens.token_data["token"] - # Pass a dict that is neither the AuthResult form nor the flat-token form - await s.update_tokens({"some_other_key": "some_value"}) - # Tokens must be unchanged - assert s.tokens.token_data["token"] == original_token - - async def test_unknown_key_does_not_set_token_expiry(self): - """ExpiresIn check at line 106 skips when data is {} (no match in either branch).""" - s = _make_auth_stub() - original_expiry = s.tokens.token_expiry - await s.update_tokens({"random_key": "random_value"}) - assert s.tokens.token_expiry == original_expiry - - -class TestHiveHelperZoneMismatch: - """hive_helper.py 163->160: loop continues when zones don't match.""" - - def test_zone_mismatch_keeps_product_as_device(self): - """When a Thermo device's zone doesn't match the product's zone, - the loop arc 163->160 is taken and device stays as the product.""" - helper = HiveHelper(session=MagicMock()) - helper.session.data = Map( - { - "devices": { - "thermo-1": { - "type": "thermostatui", - "props": {"zone": "zone-B"}, - } - }, - "products": {}, - "actions": {}, - "user": {}, - "minMax": {}, - } - ) - - product = { - "type": "heating", - "id": "prod-1", - "props": {"zone": "zone-A"}, # different zone from thermo-1 - } - - result = helper.get_device_data(product) - # The zone mismatch means device was never re-assigned; returns the product - assert result is product - - def test_trv_without_zone_does_not_log_warning(self, caplog): - """TRV devices that omit 'zone' from props are silently skipped (no warning).""" - import logging - - helper = HiveHelper(session=MagicMock()) - helper.session.data = Map( - { - "devices": { - "trv-1": { - "type": "trv", - "props": { - "online": True - }, # no 'zone' key — current API behaviour - } - }, - "products": {}, - "actions": {}, - "user": {}, - "minMax": {}, - } - ) - - product = { - "type": "heating", - "id": "prod-1", - "props": {"zone": "zone-A"}, - } - - with caplog.at_level(logging.WARNING, logger="apyhiveapi.helper.hive_helper"): - result = helper.get_device_data(product) - - assert result is product - assert not caplog.records, ( - f"Unexpected warnings: {[r.getMessage() for r in caplog.records]}" - ) - - def test_zone_match_replaces_device_with_thermostat(self): - """Matching zones cause device to be replaced with the thermostat entry.""" - helper = HiveHelper(session=MagicMock()) - thermo_data = { - "type": "thermostatui", - "props": {"zone": "zone-X"}, - } - helper.session.data = Map( - { - "devices": {"thermo-1": thermo_data}, - "products": {}, - "actions": {}, - "user": {}, - "minMax": {}, - } - ) - - product = { - "type": "heating", - "id": "prod-1", - "props": {"zone": "zone-X"}, # matching zone - } - - result = helper.get_device_data(product) - assert result is thermo_data - - -class TestHiveHelperSanitizeDictValue: - """hive_helper.py line 328: dict value under a sensitive key calls _mask(dict).""" - - def test_dict_under_sensitive_key_is_recursively_masked(self): - """A dict value under 'token' key hits the isinstance(value, dict) branch.""" - helper = HiveHelper() - result = helper.sanitize_payload({"token": {"inner_key": "secret_value"}}) - # 'token' is sensitive → _mask is called with the nested dict - # _mask for a dict returns {k: _mask(v) for k, v in value.items()} - # _mask("secret_value") → "sec...lue" (long enough) or "***" - assert "token" in result - assert isinstance(result["token"], dict) - assert "inner_key" in result["token"] - # The inner value should be masked (not the original) - assert result["token"]["inner_key"] != "secret_value" - - def test_nested_dict_keys_preserved_after_masking(self): - """Keys inside a sensitive dict are preserved, values are masked.""" - helper = HiveHelper() - result = helper.sanitize_payload( - { - "authenticationresult": { - "AccessToken": "long-secret-token-value", - "ExpiresIn": 3600, - } - } - ) - inner = result["authenticationresult"] - assert "AccessToken" in inner - assert "ExpiresIn" in inner - # ExpiresIn is an int, _mask returns it as-is - assert inner["ExpiresIn"] == 3600 - - -class TestHiveHelperSanitizeListNode: - """hive_helper.py line 359: list value under a non-sensitive key calls _walk(list).""" - - def test_list_under_non_sensitive_key_is_walked(self): - """A list value under a non-sensitive key hits the isinstance(node, list) branch.""" - helper = HiveHelper() - result = helper.sanitize_payload({"devices": ["device-a", "device-b"]}) - # 'devices' is not a sensitive key → _walk called for the list - # _walk for a list returns [_walk(item) for item in node] - # Each string item: _walk(str) → str (falls through to return node) - assert result == {"devices": ["device-a", "device-b"]} - - def test_list_containing_dicts_is_walked_recursively(self): - """A list of dicts under a non-sensitive key is recursively processed.""" - helper = HiveHelper() - result = helper.sanitize_payload( - { - "items": [ - {"token": "abc", "name": "device1"}, - {"token": "xyz", "name": "device2"}, - ] - } - ) - # 'items' is not sensitive → _walk called for the list - # Each dict in the list is processed by _walk - # 'token' IS sensitive → masked in each sub-dict - assert result["items"][0]["name"] == "device1" - assert result["items"][0]["token"] != "abc" - assert result["items"][1]["name"] == "device2" - assert result["items"][1]["token"] != "xyz" - - -class TestHeatingGetStateExceptionCaught: - """heating.py lines 206-207: except (KeyError, TypeError) handler is reached.""" - - async def test_key_error_in_get_current_temperature_is_caught(self): - """KeyError raised by get_current_temperature is caught, final stays None.""" - climate = _make_climate( - {"heat-1": {"state": {"mode": "MANUAL", "target": 20.0}, "props": {}}} - ) - d = _make_device() - with patch.object( - climate, "get_current_temperature", new_callable=AsyncMock - ) as mock_t: - mock_t.side_effect = KeyError("missing_key") - result = await climate.get_state(d) - assert result is None - - async def test_type_error_in_get_target_temperature_is_caught(self): - """TypeError raised by get_target_temperature is caught, final stays None.""" - climate = _make_climate( - {"heat-1": {"state": {"mode": "MANUAL", "target": 20.0}, "props": {}}} - ) - d = _make_device() - with patch.object( - climate, "get_current_temperature", new_callable=AsyncMock - ) as mock_cur: - mock_cur.return_value = 19.0 - with patch.object( - climate, "get_target_temperature", new_callable=AsyncMock - ) as mock_tgt: - mock_tgt.side_effect = TypeError("bad type") - result = await climate.get_state(d) - assert result is None diff --git a/tests/unit/test_sensor_extended.py b/tests/unit/test_sensor.py similarity index 52% rename from tests/unit/test_sensor_extended.py rename to tests/unit/test_sensor.py index 68122f3..3bad1b1 100644 --- a/tests/unit/test_sensor_extended.py +++ b/tests/unit/test_sensor.py @@ -2,7 +2,7 @@ # pylint: disable=protected-access -from unittest.mock import AsyncMock, MagicMock +from unittest.mock import AsyncMock, MagicMock, patch from apyhiveapi.devices.sensor import Sensor from apyhiveapi.helper.hivedataclasses import Device, SessionConfig @@ -146,6 +146,39 @@ async def test_contact_sensor_in_hive_types_sets_status(self): assert "state" in result.status session.attr.state_attributes.assert_awaited_once() + async def test_contact_sensor_uses_device_id_not_hive_id_for_props(self): + """HIVE_TYPES['Sensor'] branch must look up data.devices by device_id. + + Before the fix, line 160 used hive_id; data was always {} so + device.parent_device was always None even when the device existed. + """ + hive_id = "prod-abc" + device_id = "dev-xyz" # deliberately different from hive_id + + products = {} # contactsensor is NOT in products + devices = { + device_id: { + "props": {"online": True, "signal": -70}, + "parent": "hub-parent-id", + } + } + session = _make_session(products=products, devices=devices) + session.attr.online_offline = AsyncMock(return_value=True) + + device = _make_device( + hive_id=hive_id, device_id=device_id, hive_type="contactsensor" + ) + device.device_data = {"online": True} + + sensor = Sensor(session) + with patch.object(sensor, "get_state", new=AsyncMock(return_value="CLOSED")): + result = await sensor.get_sensor(device) + + assert result is not None + assert device.parent_device == "hub-parent-id", ( + "parent_device must come from data.devices[device_id], not hive_id lookup" + ) + class TestGetState: """Tests for HiveSensor.get_state covering the motionsensor branch (lines 37-42).""" @@ -170,3 +203,90 @@ async def test_motionsensor_returns_motion_status(self): state = await sensor.get_state(device) assert state is True + + +# =========================================================================== +# Migrated from test_remaining_branches.py +# =========================================================================== + + +class TestSensorGetStateKeyError: + """Lines 37->42: KeyError in HiveSensor.get_state.""" + + async def test_get_state_missing_type_key_returns_none(self): + """Product with no 'type' key causes KeyError → final stays None.""" + session = _make_session({"sens-1": {}}) + sensor = Sensor(session=session) + d = _make_device(hive_id="sens-1", device_id="dev-1", hive_type="contactsensor") + result = await sensor.get_state(d) + assert result is None + + async def test_get_state_missing_props_key_returns_none(self): + """contactsensor product without 'props' causes KeyError → None.""" + session = _make_session({"sens-1": {"type": "contactsensor"}}) + sensor = Sensor(session=session) + d = _make_device(hive_id="sens-1", device_id="dev-1", hive_type="contactsensor") + result = await sensor.get_state(d) + assert result is None + + +class TestSensorGetStateUnknownType: + """Lines 37->42: data['type'] is neither contactsensor nor motionsensor.""" + + async def test_unknown_type_returns_none(self): + """Product with type 'hub' skips both if/elif → final stays None.""" + session = _make_session({"sens-1": {"type": "hub", "props": {}}}) + sensor = Sensor(session=session) + d = _make_device(hive_id="sens-1", device_id="dev-1", hive_type="contactsensor") + result = await sensor.get_state(d) + assert result is None + + +class TestSensorGetSensorCacheMiss: + """Lines 92->98: should_use_cached_data=True but cached is None.""" + + async def test_cached_none_falls_through(self): + session = _make_session( + products={"sens-1": {"type": "contactsensor", "props": {"status": "OPEN"}}}, + devices={"dev-1": {"props": {"online": True}, "type": "contactsensor"}}, + ) + session.should_use_cached_data.return_value = True + session.get_cached_device.return_value = None + sensor = Sensor(session=session) + d = _make_device(hive_id="sens-1", device_id="dev-1", hive_type="contactsensor") + result = await sensor.get_sensor(d) + assert result is not None + session.attr.online_offline.assert_called_once() + + +class TestSensorGetSensorNoDataFallthrough: + """Lines 119->122: device_id not in devices AND hive_id not in products.""" + + async def test_neither_match_continues_with_empty_data(self): + """data stays empty dict when neither lookup succeeds.""" + session = _make_session(products={}, devices={}) + sensor = Sensor(session=session) + d = _make_device( + hive_id="unknown-hive", + device_id="unknown-dev", + hive_type="contactsensor", + ) + result = await sensor.get_sensor(d) + # Should not raise; result will be the device (set_cached_device returns it) + assert result is not None + + +class TestSensorGetSensorUnknownHiveType: + """Lines 135->146: hive_type not in sensor_commands and not in HIVE_TYPES['Sensor'].""" + + async def test_hive_type_not_in_either_dict_skips_both_branches(self): + """activeplug is neither in sensor_commands nor HIVE_TYPES['Sensor'].""" + session = _make_session( + devices={"dev-1": {"props": {"online": True}, "type": "activeplug"}} + ) + sensor = Sensor(session=session) + d = _make_device(hive_id="dev-1", device_id="dev-1", hive_type="activeplug") + d.device_data = {"online": True} + result = await sensor.get_sensor(d) + # Neither branch sets device.status; device returned as-is via set_cached_device + assert result is not None diff --git a/tests/unit/test_session_auth_extended.py b/tests/unit/test_session_auth.py similarity index 52% rename from tests/unit/test_session_auth_extended.py rename to tests/unit/test_session_auth.py index e688096..87e9967 100644 --- a/tests/unit/test_session_auth_extended.py +++ b/tests/unit/test_session_auth.py @@ -93,6 +93,84 @@ async def test_auth_result_with_update_expiry_true_sets_token_created(self): await s.update_tokens(AUTH_RESULT, update_expiry_time=True) assert s.tokens.token_created > before + async def test_auth_result_missing_id_token_does_not_raise(self): + """AuthenticationResult with only AccessToken (e.g. file mode) must not crash.""" + s = _make_stub() + payload = {"AuthenticationResult": {"AccessToken": "only-access"}} + await s.update_tokens(payload) + assert s.tokens.token_data["accessToken"] == "only-access" + # IdToken absent — the session token must not have been overwritten + assert s.tokens.token_data["token"] == "" + + async def test_flat_dict_missing_access_token_does_not_raise(self): + """A flat token dict without accessToken must not crash.""" + s = _make_stub() + flat = {"token": "t-only"} + await s.update_tokens(flat) + assert s.tokens.token_data["token"] == "t-only" + assert s.tokens.token_data["accessToken"] == "" + + async def test_auth_result_missing_access_token_does_not_raise(self): + """AuthenticationResult with only IdToken must not crash.""" + s = _make_stub() + payload = {"AuthenticationResult": {"IdToken": "only-id"}} + await s.update_tokens(payload) + assert s.tokens.token_data["token"] == "only-id" + assert s.tokens.token_data["accessToken"] == "" + + +# --------------------------------------------------------------------------- +# _retry_with_backoff — re-raise semantics +# --------------------------------------------------------------------------- + + +class _NeedsArgsError(Exception): + """Exception type that cannot be constructed without arguments.""" + + def __init__(self, first, second): + super().__init__(f"{first}/{second}") + + +class TestRetryWithBackoffReraise: + """Without reraise_as, the original exception instance must propagate.""" + + async def test_original_exception_instance_propagates(self): + """The last caught error is re-raised as-is, not re-instantiated.""" + s = _make_stub() + original = _NeedsArgsError("a", "b") + + async def _always_fail(): + raise original + + with pytest.raises(_NeedsArgsError) as excinfo: + await s._retry_with_backoff(_always_fail, delays=(0,)) + + assert excinfo.value is original + + async def test_empty_delays_raises_runtime_error(self): + """With no attempts configured the defensive fallback raises RuntimeError.""" + s = _make_stub() + + async def _never_called(): + raise AssertionError("should not run") + + with pytest.raises(RuntimeError, match="exhausted"): + await s._retry_with_backoff(_never_called, delays=()) + + async def test_reraise_as_still_translates_exception_type(self): + """When reraise_as is given the error is translated with chaining.""" + s = _make_stub() + + async def _always_fail(): + raise ValueError("boom") + + with pytest.raises(HiveReauthRequired) as excinfo: + await s._retry_with_backoff( + _always_fail, delays=(0,), reraise_as=HiveReauthRequired + ) + + assert isinstance(excinfo.value.__cause__, ValueError) + # --------------------------------------------------------------------------- # _handle_device_login_challenge — extra branch @@ -290,3 +368,193 @@ async def test_force_refresh_enters_lock_even_when_token_is_fresh(self): s.auth.refresh_token.return_value = AUTH_RESULT await s.hive_refresh_tokens(force_refresh=True) s.auth.refresh_token.assert_called_once() + + +# --------------------------------------------------------------------------- +# update_tokens — elif "token" branch missing token_created and bare refreshToken +# --------------------------------------------------------------------------- + + +class TestUpdateTokensTokenBranch: + """update_tokens must set token_created and guard missing refreshToken in elif branch.""" + + async def test_token_branch_sets_token_created(self): + """elif 'token' branch must update token_created (was missing, stayed datetime.min).""" + s = _make_stub() + await s.update_tokens( + {"token": "id-tok", "refreshToken": "ref-tok", "accessToken": "acc-tok"} + ) + assert s.tokens.token_created > datetime.min + + async def test_token_branch_updates_token_data(self): + """elif 'token' branch stores all three token values.""" + s = _make_stub() + await s.update_tokens( + {"token": "id", "refreshToken": "ref", "accessToken": "acc"} + ) + assert s.tokens.token_data["token"] == "id" + assert s.tokens.token_data["refreshToken"] == "ref" + assert s.tokens.token_data["accessToken"] == "acc" + + async def test_token_branch_missing_refresh_token_does_not_crash(self): + """elif 'token' branch without refreshToken key must not raise KeyError.""" + s = _make_stub() + await s.update_tokens({"token": "id", "accessToken": "acc"}) + assert s.tokens.token_data["token"] == "id" + + async def test_token_branch_update_expiry_false_does_not_update_token_created(self): + """update_expiry_time=False skips the token_created assignment in elif 'token' branch.""" + + s = _make_stub() + original_created = s.tokens.token_created + await s.update_tokens( + {"token": "id", "accessToken": "acc"}, + update_expiry_time=False, + ) + assert s.tokens.token_created == original_created + + +# --------------------------------------------------------------------------- +# hive_refresh_tokens — bare refreshToken access raises KeyError when missing +# --------------------------------------------------------------------------- + + +class TestHiveRefreshTokensMissingRefreshToken: + """hive_refresh_tokens must not crash when token_data has no refreshToken.""" + + async def test_missing_refresh_token_does_not_raise_key_error(self): + """hive_refresh_tokens without refreshToken in token_data must not crash.""" + s = _make_stub() + s.tokens.token_data = {"token": "id", "accessToken": "acc"} + s.tokens.token_created = datetime.now() - timedelta(hours=2) + s.tokens.token_expiry = timedelta(hours=1) + s.auth.refresh_token.return_value = None + result = await s.hive_refresh_tokens() + assert result is None + + +# =========================================================================== +# Migrated from test_remaining_branches.py +# =========================================================================== + + +class TestRetryWithBackoffNonZeroDelay: + """Line 66: asyncio.sleep called when delay > 0.""" + + async def test_non_zero_delay_is_awaited_but_succeeds(self): + """A non-zero delay entry causes asyncio.sleep to be called; factory still runs.""" + from unittest.mock import patch + + s = _make_stub() + calls = [] + + async def factory(): + calls.append(1) + return "ok" + + with patch( + "apyhiveapi.session.auth.asyncio.sleep", new_callable=AsyncMock + ) as mock_sleep: + result = await s._retry_with_backoff(factory, delays=(5,)) + assert result == "ok" + mock_sleep.assert_called_once_with(5) + assert len(calls) == 1 + + async def test_zero_delay_does_not_call_sleep(self): + """A zero delay skips asyncio.sleep.""" + from unittest.mock import patch + + s = _make_stub() + + async def factory(): + return "done" + + with patch( + "apyhiveapi.session.auth.asyncio.sleep", new_callable=AsyncMock + ) as mock_sleep: + result = await s._retry_with_backoff(factory, delays=(0,)) + assert result == "done" + mock_sleep.assert_not_called() + + +class TestUpdateTokensFlatDictWithExpiresIn: + """Lines 100->106: flat token dict with ExpiresIn sets token_expiry.""" + + async def test_flat_dict_with_expires_in_sets_token_expiry(self): + """Flat token dict containing ExpiresIn updates tokens.token_expiry.""" + s = _make_stub() + flat = { + "token": "t", + "refreshToken": "r", + "accessToken": "a", + "ExpiresIn": 1800, + } + await s.update_tokens(flat) + assert s.tokens.token_expiry == timedelta(seconds=1800) + + async def test_flat_dict_tokens_are_stored(self): + """All token values from flat dict are written to token_data.""" + s = _make_stub() + flat = {"token": "my-id", "refreshToken": "my-rt", "accessToken": "my-at"} + await s.update_tokens(flat) + assert s.tokens.token_data["token"] == "my-id" + assert s.tokens.token_data["refreshToken"] == "my-rt" + assert s.tokens.token_data["accessToken"] == "my-at" + + +class TestLoginApiError: + """Lines 160-162: HiveApiError in login() is logged and re-raised.""" + + async def test_login_api_error_reraises(self): + """HiveApiError raised by auth.login propagates unchanged to the caller.""" + s = _make_stub() + s.auth.login.side_effect = HiveApiError() + with pytest.raises(HiveApiError): + await s.login() + + +class TestHiveRefreshTokensNoAuthResult: + """Lines 341->373: refresh returns a result but without AuthenticationResult.""" + + async def test_result_without_auth_result_does_not_update_tokens(self): + """When refresh_token returns a dict with no AuthenticationResult, tokens stay unchanged.""" + s = _make_stub() + s.tokens.token_created = datetime.now() - timedelta(hours=2) + s.tokens.token_expiry = timedelta(hours=1) + # Return something truthy but without AuthenticationResult + s.auth.refresh_token.return_value = {"SomeOtherKey": "value"} + result = await s.hive_refresh_tokens() + # Tokens must not have been updated + assert s.tokens.token_data["token"] == "" + assert s.tokens.token_data["accessToken"] == "" + # result is what refresh_token returned + assert result == {"SomeOtherKey": "value"} + + async def test_none_refresh_result_does_not_update_tokens(self): + """When refresh_token returns None, tokens are left unchanged.""" + s = _make_stub() + s.tokens.token_created = datetime.now() - timedelta(hours=2) + s.tokens.token_expiry = timedelta(hours=1) + s.auth.refresh_token.return_value = None + await s.hive_refresh_tokens() + assert s.tokens.token_data["token"] == "" + + +class TestUpdateTokensUnknownKey: + """session/auth.py 100->106: tokens dict has neither AuthenticationResult nor token.""" + + async def test_unknown_key_does_not_raise_and_does_not_update_tokens(self): + """When neither expected key is present, data stays {}, ExpiresIn check skips.""" + s = _make_stub() + original_token = s.tokens.token_data["token"] + # Pass a dict that is neither the AuthResult form nor the flat-token form + await s.update_tokens({"some_other_key": "some_value"}) + # Tokens must be unchanged + assert s.tokens.token_data["token"] == original_token + + async def test_unknown_key_does_not_set_token_expiry(self): + """ExpiresIn check at line 106 skips when data is {} (no match in either branch).""" + s = _make_stub() + original_expiry = s.tokens.token_expiry + await s.update_tokens({"random_key": "random_value"}) + assert s.tokens.token_expiry == original_expiry diff --git a/tests/unit/test_session_close.py b/tests/unit/test_session_close.py index 49edabb..b66df73 100644 --- a/tests/unit/test_session_close.py +++ b/tests/unit/test_session_close.py @@ -9,30 +9,62 @@ class TestHiveSessionClose: """Branch coverage for HiveSession.close() (line 79).""" async def test_close_calls_websession_close_when_not_already_closed(self): - """close() calls websession.close() when websession is open (closed=False). - - Covers the True branch of 'if not self.api.websession.closed'. - """ + """close() calls websession.close() when websession is open (closed=False).""" session = object.__new__(HiveSession) session.api = MagicMock() session.api.websession.closed = False session.api.websession.close = AsyncMock() + session._owns_websession = True await session.close() session.api.websession.close.assert_called_once() - async def test_close_skips_websession_close_when_already_closed(self): - """close() does NOT call websession.close() when websession is already closed. + async def test_close_with_no_websession_does_not_raise(self): + """close() is a no-op when the lazy websession was never created.""" + session = object.__new__(HiveSession) + session.api = MagicMock() + session.api.websession = None + session._owns_websession = True - Covers branch 79->exit: the 'if not closed' condition is False, so the - body is skipped entirely. - """ + await session.close() + + async def test_close_skips_websession_close_when_already_closed(self): + """close() does NOT call websession.close() when websession is already closed.""" session = object.__new__(HiveSession) session.api = MagicMock() session.api.websession.closed = True session.api.websession.close = AsyncMock() + session._owns_websession = True await session.close() session.api.websession.close.assert_not_called() + + +class TestHiveSessionCloseOwnership: + """close() must not close a caller-provided websession.""" + + async def test_close_does_not_close_caller_provided_websession(self): + """When _owns_websession=False, close() must not close the websession.""" + session = object.__new__(HiveSession) + session.api = MagicMock() + session.api.websession.closed = False + session.api.websession.close = AsyncMock() + session._owns_websession = False + + await session.close() + + session.api.websession.close.assert_not_called() + + async def test_close_closes_owned_websession(self): + """When _owns_websession=True, close() closes the websession as normal.""" + session = object.__new__(HiveSession) + session.api = MagicMock() + session.api.websession.closed = False + session.api.websession.close = AsyncMock() + session._owns_websession = True + + await session.close() + + session.api.websession.close.assert_called_once() diff --git a/tests/unit/test_session_discovery.py b/tests/unit/test_session_discovery.py new file mode 100644 index 0000000..b889fb7 --- /dev/null +++ b/tests/unit/test_session_discovery.py @@ -0,0 +1,602 @@ +"""Extended branch-coverage tests for DiscoveryMixin.start_session and create_devices.""" + +# pylint: disable=attribute-defined-outside-init,too-few-public-methods,protected-access +from datetime import datetime +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest +from apyhiveapi.helper.hive_exceptions import ( + HiveReauthRequired, + HiveUnknownConfiguration, +) +from apyhiveapi.helper.hivedataclasses import EntityConfig, SessionConfig, SessionTokens +from apyhiveapi.helper.map import Map +from apyhiveapi.session.discovery import DiscoveryMixin + +_POPULATED_PRODUCTS = { + "prod-1": {"id": "prod-1", "type": "heating", "state": {"name": "Hall"}} +} +_POPULATED_DEVICES = {"dev-1": {"id": "dev-1", "type": "hub", "state": {"name": "Hub"}}} + + +def _make_stub(*, has_data=True): + """Return a DiscoveryMixin stub wired for start_session tests (create_devices mocked).""" + + class StubDiscovery(DiscoveryMixin): + """Concrete subclass used only for testing.""" + + s = StubDiscovery() + s.config = SessionConfig() + s.data = Map( + { + "products": _POPULATED_PRODUCTS if has_data else {}, + "devices": _POPULATED_DEVICES if has_data else {}, + "actions": {}, + "minMax": {}, + "user": {}, + } + ) + s.helper = MagicMock() + s.helper.sanitize_payload = MagicMock(return_value={}) + s.auth = MagicMock() + s.tokens = SessionTokens() + s.hub_id = None + s.device_list = { + "parent": [], + "binary_sensor": [], + "climate": [], + "light": [], + "sensor": [], + "switch": [], + "water_heater": [], + } + s.get_devices = AsyncMock(return_value=True) + s.update_tokens = AsyncMock() + s.create_devices = AsyncMock(return_value=s.device_list) + return s + + +def _make_create_stub(): + """Return a DiscoveryMixin stub for testing create_devices directly (not mocked).""" + + class StubDiscovery(DiscoveryMixin): + """Concrete subclass used only for testing.""" + + s = StubDiscovery() + s.config = SessionConfig() + s.data = Map( + { + "products": {}, + "devices": {}, + "actions": {}, + "minMax": {}, + "user": {"temperatureUnit": "C"}, + } + ) + s.helper = MagicMock() + s.helper.get_device_data = MagicMock( + return_value={ + "id": "dev-1", + "state": {"name": "Test Device"}, + "props": {"online": True}, + } + ) + s.hub_id = None + s.device_list = { + "parent": [], + "binary_sensor": [], + "climate": [], + "light": [], + "sensor": [], + "switch": [], + "water_heater": [], + } + return s + + +# --------------------------------------------------------------------------- +# start_session — config branches +# --------------------------------------------------------------------------- + + +class TestStartSessionExtended: + """Tests for start_session config-processing branches.""" + + async def test_with_tokens_config_calls_update_tokens(self): + """Passing 'tokens' in non-file config calls update_tokens(tokens, False).""" + s = _make_stub() + s.config.file = False + tokens = {"token": "t", "accessToken": "a", "refreshToken": "r"} + await s.start_session({"tokens": tokens}) + s.update_tokens.assert_called_once_with(tokens, False) + + async def test_with_username_config_sets_auth_username(self): + """Passing 'username' alongside 'tokens' in non-file config sets auth.username.""" + s = _make_stub() + s.config.file = False + tokens = {"token": "t", "accessToken": "a", "refreshToken": "r"} + await s.start_session({"tokens": tokens, "username": "user@test.com"}) + assert s.auth.username == "user@test.com" + + async def test_with_password_config_sets_auth_password(self): + """Passing 'password' alongside 'tokens' in non-file config sets auth.password.""" + s = _make_stub() + s.config.file = False + tokens = {"token": "t", "accessToken": "a", "refreshToken": "r"} + await s.start_session( + {"tokens": tokens, "password": "secret"} # pragma: allowlist secret + ) + assert s.auth.password == "secret" # pragma: allowlist secret + + async def test_with_device_data_3_items_sets_auth_keys(self): + """3-item device_data sets device_group_key, device_key, device_password on auth.""" + s = _make_stub() + s.config.file = False + await s.start_session( + { + "tokens": {}, + "device_data": ["grp-key", "dev-key", "dev-pass"], + } + ) + assert s.auth.device_group_key == "grp-key" + assert s.auth.device_key == "dev-key" + assert s.auth.device_password == "dev-pass" # pragma: allowlist secret + + async def test_with_device_data_too_short_raises_unknown_configuration(self): + """device_data with fewer than 3 items raises HiveUnknownConfiguration.""" + s = _make_stub() + s.config.file = False + with pytest.raises(HiveUnknownConfiguration): + await s.start_session( + { + "tokens": {}, + "device_data": ["grp-key", "dev-key"], + } + ) + + async def test_with_device_data_4_items_sets_token_created(self): + """4-item device_data with a token_created timestamp sets tokens.token_created.""" + s = _make_stub() + s.config.file = False + created_ts = datetime(2024, 1, 15, 10, 30, 0) + await s.start_session( + { + "tokens": {}, + "device_data": ["grp-key", "dev-key", "dev-pass", created_ts], + } + ) + assert s.tokens.token_created == created_ts + + async def test_with_device_data_4_items_none_token_created_not_set(self): + """4-item device_data where token_created is None — does not overwrite token_created.""" + s = _make_stub() + s.config.file = False + original_created = s.tokens.token_created + await s.start_session( + { + "tokens": {}, + "device_data": ["grp-key", "dev-key", "dev-pass", None], + } + ) + assert s.tokens.token_created == original_created + + async def test_no_tokens_and_not_file_raises_unknown_configuration(self): + """Non-file config without 'tokens' raises HiveUnknownConfiguration.""" + s = _make_stub() + s.config.file = False + with pytest.raises(HiveUnknownConfiguration): + await s.start_session({"username": "user@test.com"}) + + async def test_empty_devices_after_get_devices_raises_unknown_configuration(self): + """start_session raises HiveUnknownConfiguration when data.devices is empty post-poll.""" + s = _make_stub(has_data=False) + s.config.file = True + with pytest.raises(HiveUnknownConfiguration): + await s.start_session({}) + + async def test_none_config_defaults_to_empty_dict(self): + """start_session(None) is treated as start_session({}) — set file mode separately.""" + s = _make_stub() + s.config.file = True + # Should not raise; equivalent to passing {} + result = await s.start_session(None) + assert result is s.device_list + + async def test_file_mode_username_skips_token_branch(self): + """'use@file.com' activates file mode so 'tokens' branch is skipped.""" + s = _make_stub() + s.config.file = False + # Even if tokens is present, file mode skips the update_tokens call + await s.start_session({"username": "use@file.com", "tokens": {}}) + s.update_tokens.assert_not_called() + + +# --------------------------------------------------------------------------- +# create_devices — device processing +# --------------------------------------------------------------------------- + + +class TestCreateDevicesExtended: + """Tests for create_devices branches not covered by the main test files.""" + + async def test_no_hub_device_hub_id_stays_none(self): + """Devices list with no 'hub' type leaves hub_id as None (else branch of for-loop).""" + s = _make_create_stub() + s.data["devices"] = { + "trv-1": {"id": "trv-1", "type": "trv", "state": {"name": "TRV"}} + } + s.data["products"] = {} + await s.create_devices() + assert s.hub_id is None + + async def test_hub_device_sets_hub_id(self): + """Devices list with a 'hub' type sets hub_id to that device's ID.""" + s = _make_create_stub() + s.data["devices"] = { + "hub-42": {"id": "hub-42", "type": "hub", "state": {"name": "My Hub"}} + } + await s.create_devices() + assert s.hub_id == "hub-42" + + async def test_product_with_error_key_is_skipped(self): + """Products with an 'error' key are silently skipped.""" + s = _make_create_stub() + s.data["products"] = { + "bad": {"id": "bad", "type": "heating", "error": "device not found"} + } + result = await s.create_devices() + assert result["climate"] == [] + + async def test_non_heating_group_product_skipped(self): + """isGroup=True products of non-heating type are not added to any list.""" + s = _make_create_stub() + s.data["products"] = { + "grp-1": { + "id": "grp-1", + "type": "activeplug", + "isGroup": True, + "state": {"name": "Plug Group"}, + } + } + result = await s.create_devices() + assert result["switch"] == [] + + async def test_heating_group_product_not_skipped(self): + """isGroup=True products of heating type are processed and added.""" + s = _make_create_stub() + s.data["products"] = { + "h-grp": { + "id": "h-grp", + "type": "heating", + "isGroup": True, + "state": {"name": "Heating Zone"}, + } + } + result = await s.create_devices() + assert len(result["climate"]) == 1 + + async def test_multiple_devices_all_processed(self): + """Multiple devices in the device list are all processed.""" + s = _make_create_stub() + s.data["devices"] = { + "hub-1": {"id": "hub-1", "type": "hub", "state": {"name": "Hub"}}, + "trv-1": {"id": "trv-1", "type": "trv", "state": {"name": "TRV"}}, + } + s.data["products"] = {} + await s.create_devices() + # Hub is found; hub_id is set to the hub device + assert s.hub_id == "hub-1" + + async def test_action_processed_as_switch(self): + """Actions in data.actions are added to device_list['switch'].""" + s = _make_create_stub() + s.data["actions"] = { + "act-1": {"id": "act-1", "name": "Good Night", "type": "action"} + } + result = await s.create_devices() + assert len(result["switch"]) == 1 + assert result["switch"][0].hive_type == "action" + + async def test_returns_device_list_dict(self): + """create_devices always returns a dict with the expected HA entity keys.""" + s = _make_create_stub() + result = await s.create_devices() + for key in ( + "parent", + "binary_sensor", + "climate", + "light", + "sensor", + "switch", + "water_heater", + ): + assert key in result + + async def test_product_with_error_and_valid_both_present_only_valid_added(self): + """Only products without 'error' are added when both types coexist.""" + s = _make_create_stub() + s.data["products"] = { + "bad": {"id": "bad", "type": "heating", "error": "broken"}, + "good": {"id": "good", "type": "heating", "state": {"name": "Hall"}}, + } + result = await s.create_devices() + assert len(result["climate"]) == 1 + assert result["climate"][0].hive_id == "good" + + +# --------------------------------------------------------------------------- +# start_session raises wrong exception for empty device data +# --------------------------------------------------------------------------- + + +class TestStartSessionWrongException: + """start_session must raise HiveUnknownConfiguration (not HiveReauthRequired) for empty data.""" + + async def test_empty_devices_raises_unknown_configuration(self): + """start_session raises HiveUnknownConfiguration when API returns no devices.""" + s = _make_stub(has_data=False) + s.get_devices = AsyncMock() + + with pytest.raises(HiveUnknownConfiguration): + await s.start_session({}) + + async def test_does_not_raise_reauth_for_empty_data(self): + """start_session must NOT raise HiveReauthRequired when device data is empty.""" + s = _make_stub(has_data=False) + s.get_devices = AsyncMock() + + with pytest.raises(Exception) as exc_info: + await s.start_session({}) + + assert not isinstance(exc_info.value, HiveReauthRequired), ( + "HiveReauthRequired must not be raised for empty device list" + ) + + +# --------------------------------------------------------------------------- +# create_devices — bare d["id"] and p["id"] crash when id key is absent +# --------------------------------------------------------------------------- + + +class TestBareIdAccess: + """create_devices must use .get('id', fallback) instead of bare ['id'] access.""" + + async def test_battery_device_without_id_does_not_crash(self): + """Device with no 'id' key in battery-type must not raise KeyError.""" + s = _make_create_stub() + s.data["devices"] = { + "trv-key": { + "type": "trv", + "state": {"name": "TRV"}, + "props": {}, + } + } + s.config.battery = set() + try: + await s.create_devices() + except KeyError as err: + pytest.fail(f"KeyError raised for missing 'id' in device: {err}") + + async def test_mode_product_without_id_does_not_crash(self): + """Product with no 'id' key in mode-type must not raise KeyError.""" + s = _make_create_stub() + s.data["products"] = { + "heating-key": { + "type": "heating", + "state": {"name": "Hall"}, + } + } + s.config.mode = set() + try: + await s.create_devices() + except KeyError as err: + pytest.fail(f"KeyError raised for missing 'id' in product: {err}") + + +# =========================================================================== +# Migrated from test_remaining_branches.py +# =========================================================================== + + +class TestCreateDevicesEntityConfigKwargs: + """Lines 224->226, 226->228, 228->230: entity_config kwarg population in DEVICES loop.""" + + async def test_entity_config_with_all_fields_populates_kwargs(self): + """EntityConfig with ha_name, hive_type, and category all set → all kwargs passed.""" + s = _make_create_stub() + s.data["devices"] = { + "dev-1": { + "id": "dev-1", + "type": "hub", + "state": {"name": "My Hub"}, + "props": {}, + } + } + entity_cfg = EntityConfig( + entity_type="binary_sensor", + ha_name="Hub Status", + hive_type="Connectivity", + category="diagnostic", + ) + with patch("apyhiveapi.session.discovery.DEVICES", {"hub": [entity_cfg]}): + result = await s.create_devices() + assert len(result["binary_sensor"]) == 1 + created = result["binary_sensor"][0] + assert created.hive_type == "Connectivity" + assert created.category == "diagnostic" + + async def test_entity_config_empty_fields_does_not_add_to_kwargs(self): + """EntityConfig with empty ha_name and hive_type does not inject those keys.""" + s = _make_create_stub() + s.data["devices"] = { + "dev-1": { + "id": "dev-1", + "type": "hub", + "state": {"name": "My Hub"}, + "props": {}, + } + } + entity_cfg = EntityConfig( + entity_type="binary_sensor", + ha_name="", # falsy — should not be added to kwargs + hive_type="", # falsy — should not be added to kwargs + category=None, # None — should not be added to kwargs + ) + with patch("apyhiveapi.session.discovery.DEVICES", {"hub": [entity_cfg]}): + result = await s.create_devices() + # Should still process without error + assert isinstance(result, dict) + + +class TestCreateDevicesDeviceAddListError: + """Lines 232-233: KeyError/TypeError from add_list in DEVICES loop is caught.""" + + async def test_add_list_keyerror_is_caught_not_raised(self): + """KeyError from add_list during device processing is logged, not propagated.""" + s = _make_create_stub() + s.data["devices"] = { + "dev-1": { + "id": "dev-1", + "type": "hub", + "state": {"name": "My Hub"}, + "props": {}, + } + } + entity_cfg = EntityConfig( + entity_type="binary_sensor", + ha_name="Hub Status", + hive_type="Connectivity", + category="diagnostic", + ) + with patch("apyhiveapi.session.discovery.DEVICES", {"hub": [entity_cfg]}): + with patch.object(s, "add_list", side_effect=KeyError("bad key")): + # Should complete without raising + result = await s.create_devices() + assert isinstance(result, dict) + + async def test_add_list_typeerror_is_caught_not_raised(self): + """TypeError from add_list during device processing is caught.""" + s = _make_create_stub() + s.data["devices"] = { + "dev-1": { + "id": "dev-1", + "type": "hub", + "state": {"name": "My Hub"}, + "props": {}, + } + } + entity_cfg = EntityConfig( + entity_type="binary_sensor", + ha_name="", + hive_type="", + category=None, + ) + with patch("apyhiveapi.session.discovery.DEVICES", {"hub": [entity_cfg]}): + with patch.object(s, "add_list", side_effect=TypeError("bad type")): + result = await s.create_devices() + assert isinstance(result, dict) + + +class TestCreateDevicesActionAddListError: + """Lines 258-259: KeyError/TypeError from add_list in actions loop is caught.""" + + async def test_action_add_list_keyerror_is_caught(self): + """KeyError from add_list when processing an action is logged, not propagated.""" + s = _make_create_stub() + s.data["actions"] = {"act-1": {"id": "act-1", "name": "Good Night"}} + with patch.object(s, "add_list", side_effect=KeyError("missing")): + result = await s.create_devices() + assert isinstance(result, dict) + + async def test_action_add_list_typeerror_is_caught(self): + """TypeError from add_list when processing an action is caught.""" + s = _make_create_stub() + s.data["actions"] = {"act-1": {"id": "act-1", "name": "Wake Up"}} + with patch.object(s, "add_list", side_effect=TypeError("type error")): + result = await s.create_devices() + assert isinstance(result, dict) + + +class TestCreateDevicesProductTemperatureUnit: + """Line 305: entity_config.temperature_unit is used when set and entity_type != 'climate'.""" + + async def test_entity_config_temperature_unit_passed_to_add_list(self): + """EntityConfig with temperature_unit set propagates that value as a kwarg.""" + s = _make_create_stub() + s.data["products"] = { + "prod-1": { + "id": "prod-1", + "type": "heating", + "state": {"name": "Heating"}, + "props": {}, + } + } + # A non-climate entity with temperature_unit set triggers line 305 + entity_cfg = EntityConfig( + entity_type="sensor", + ha_name="Temp Sensor", + hive_type="Current_Temperature", + category="diagnostic", + temperature_unit="F", + ) + captured_kwargs = {} + + original_add_list = s.add_list + + def capturing_add_list(entity_type, data, **kwargs): + captured_kwargs.update(kwargs) + return original_add_list(entity_type, data, **kwargs) + + with patch("apyhiveapi.session.discovery.PRODUCTS", {"heating": [entity_cfg]}): + with patch.object(s, "add_list", side_effect=capturing_add_list): + await s.create_devices() + + assert captured_kwargs.get("temperature_unit") == "F" + + +class TestCreateDevicesProductAddListAttributeError: + """Lines 308-309: NameError/AttributeError from add_list in products loop is caught.""" + + async def test_product_add_list_attribute_error_is_caught(self): + """AttributeError from add_list when processing a product is caught.""" + s = _make_create_stub() + s.data["products"] = { + "prod-1": { + "id": "prod-1", + "type": "heating", + "state": {"name": "Heating"}, + "props": {}, + } + } + entity_cfg = EntityConfig( + entity_type="climate", + ha_name="", + hive_type="", + category=None, + ) + with patch("apyhiveapi.session.discovery.PRODUCTS", {"heating": [entity_cfg]}): + with patch.object(s, "add_list", side_effect=AttributeError("attr error")): + result = await s.create_devices() + assert isinstance(result, dict) + + async def test_product_add_list_name_error_is_caught(self): + """NameError from add_list when processing a product is caught.""" + s = _make_create_stub() + s.data["products"] = { + "prod-1": { + "id": "prod-1", + "type": "heating", + "state": {"name": "Heating"}, + "props": {}, + } + } + entity_cfg = EntityConfig( + entity_type="climate", + ha_name="", + hive_type="", + category=None, + ) + with patch("apyhiveapi.session.discovery.PRODUCTS", {"heating": [entity_cfg]}): + with patch.object(s, "add_list", side_effect=NameError("name error")): + result = await s.create_devices() + assert isinstance(result, dict) diff --git a/tests/unit/test_session_discovery_extended.py b/tests/unit/test_session_discovery_extended.py deleted file mode 100644 index 8bc4764..0000000 --- a/tests/unit/test_session_discovery_extended.py +++ /dev/null @@ -1,312 +0,0 @@ -"""Extended branch-coverage tests for DiscoveryMixin.start_session and create_devices.""" - -# pylint: disable=attribute-defined-outside-init,too-few-public-methods,protected-access -from datetime import datetime -from unittest.mock import AsyncMock, MagicMock - -import pytest -from apyhiveapi.helper.hive_exceptions import ( - HiveReauthRequired, - HiveUnknownConfiguration, -) -from apyhiveapi.helper.hivedataclasses import SessionConfig, SessionTokens -from apyhiveapi.helper.map import Map -from apyhiveapi.session.discovery import DiscoveryMixin - -_POPULATED_PRODUCTS = { - "prod-1": {"id": "prod-1", "type": "heating", "state": {"name": "Hall"}} -} -_POPULATED_DEVICES = {"dev-1": {"id": "dev-1", "type": "hub", "state": {"name": "Hub"}}} - - -def _make_stub(*, has_data=True): - """Return a DiscoveryMixin stub wired for start_session tests (create_devices mocked).""" - - class StubDiscovery(DiscoveryMixin): - """Concrete subclass used only for testing.""" - - s = StubDiscovery() - s.config = SessionConfig() - s.data = Map( - { - "products": _POPULATED_PRODUCTS if has_data else {}, - "devices": _POPULATED_DEVICES if has_data else {}, - "actions": {}, - "minMax": {}, - "user": {}, - } - ) - s.helper = MagicMock() - s.helper.sanitize_payload = MagicMock(return_value={}) - s.auth = MagicMock() - s.tokens = SessionTokens() - s.hub_id = None - s.device_list = { - "parent": [], - "binary_sensor": [], - "climate": [], - "light": [], - "sensor": [], - "switch": [], - "water_heater": [], - } - s.get_devices = AsyncMock(return_value=True) - s.update_tokens = AsyncMock() - s.create_devices = AsyncMock(return_value=s.device_list) - return s - - -def _make_create_stub(): - """Return a DiscoveryMixin stub for testing create_devices directly (not mocked).""" - - class StubDiscovery(DiscoveryMixin): - """Concrete subclass used only for testing.""" - - s = StubDiscovery() - s.config = SessionConfig() - s.data = Map( - { - "products": {}, - "devices": {}, - "actions": {}, - "minMax": {}, - "user": {"temperatureUnit": "C"}, - } - ) - s.helper = MagicMock() - s.helper.get_device_data = MagicMock( - return_value={ - "id": "dev-1", - "state": {"name": "Test Device"}, - "props": {"online": True}, - } - ) - s.hub_id = None - s.device_list = { - "parent": [], - "binary_sensor": [], - "climate": [], - "light": [], - "sensor": [], - "switch": [], - "water_heater": [], - } - return s - - -# --------------------------------------------------------------------------- -# start_session — config branches -# --------------------------------------------------------------------------- - - -class TestStartSessionExtended: - """Tests for start_session config-processing branches.""" - - async def test_with_tokens_config_calls_update_tokens(self): - """Passing 'tokens' in non-file config calls update_tokens(tokens, False).""" - s = _make_stub() - s.config.file = False - tokens = {"token": "t", "accessToken": "a", "refreshToken": "r"} - await s.start_session({"tokens": tokens}) - s.update_tokens.assert_called_once_with(tokens, False) - - async def test_with_username_config_sets_auth_username(self): - """Passing 'username' alongside 'tokens' in non-file config sets auth.username.""" - s = _make_stub() - s.config.file = False - tokens = {"token": "t", "accessToken": "a", "refreshToken": "r"} - await s.start_session({"tokens": tokens, "username": "user@test.com"}) - assert s.auth.username == "user@test.com" - - async def test_with_password_config_sets_auth_password(self): - """Passing 'password' alongside 'tokens' in non-file config sets auth.password.""" - s = _make_stub() - s.config.file = False - tokens = {"token": "t", "accessToken": "a", "refreshToken": "r"} - await s.start_session( - {"tokens": tokens, "password": "secret"} # pragma: allowlist secret - ) - assert s.auth.password == "secret" # pragma: allowlist secret - - async def test_with_device_data_3_items_sets_auth_keys(self): - """3-item device_data sets device_group_key, device_key, device_password on auth.""" - s = _make_stub() - s.config.file = False - await s.start_session( - { - "tokens": {}, - "device_data": ["grp-key", "dev-key", "dev-pass"], - } - ) - assert s.auth.device_group_key == "grp-key" - assert s.auth.device_key == "dev-key" - assert s.auth.device_password == "dev-pass" - - async def test_with_device_data_4_items_sets_token_created(self): - """4-item device_data with a token_created timestamp sets tokens.token_created.""" - s = _make_stub() - s.config.file = False - created_ts = datetime(2024, 1, 15, 10, 30, 0) - await s.start_session( - { - "tokens": {}, - "device_data": ["grp-key", "dev-key", "dev-pass", created_ts], - } - ) - assert s.tokens.token_created == created_ts - - async def test_with_device_data_4_items_none_token_created_not_set(self): - """4-item device_data where token_created is None — does not overwrite token_created.""" - s = _make_stub() - s.config.file = False - original_created = s.tokens.token_created - await s.start_session( - { - "tokens": {}, - "device_data": ["grp-key", "dev-key", "dev-pass", None], - } - ) - assert s.tokens.token_created == original_created - - async def test_no_tokens_and_not_file_raises_unknown_configuration(self): - """Non-file config without 'tokens' raises HiveUnknownConfiguration.""" - s = _make_stub() - s.config.file = False - with pytest.raises(HiveUnknownConfiguration): - await s.start_session({"username": "user@test.com"}) - - async def test_empty_devices_after_get_devices_raises_reauth(self): - """start_session raises HiveReauthRequired when data.devices is empty post-poll.""" - s = _make_stub(has_data=False) - s.config.file = True - with pytest.raises(HiveReauthRequired): - await s.start_session({}) - - async def test_none_config_defaults_to_empty_dict(self): - """start_session(None) is treated as start_session({}) — set file mode separately.""" - s = _make_stub() - s.config.file = True - # Should not raise; equivalent to passing {} - result = await s.start_session(None) - assert result is s.device_list - - async def test_file_mode_username_skips_token_branch(self): - """'use@file.com' activates file mode so 'tokens' branch is skipped.""" - s = _make_stub() - s.config.file = False - # Even if tokens is present, file mode skips the update_tokens call - await s.start_session({"username": "use@file.com", "tokens": {}}) - s.update_tokens.assert_not_called() - - -# --------------------------------------------------------------------------- -# create_devices — device processing -# --------------------------------------------------------------------------- - - -class TestCreateDevicesExtended: - """Tests for create_devices branches not covered by the main test files.""" - - async def test_no_hub_device_hub_id_stays_none(self): - """Devices list with no 'hub' type leaves hub_id as None (else branch of for-loop).""" - s = _make_create_stub() - s.data["devices"] = { - "trv-1": {"id": "trv-1", "type": "trv", "state": {"name": "TRV"}} - } - s.data["products"] = {} - await s.create_devices() - assert s.hub_id is None - - async def test_hub_device_sets_hub_id(self): - """Devices list with a 'hub' type sets hub_id to that device's ID.""" - s = _make_create_stub() - s.data["devices"] = { - "hub-42": {"id": "hub-42", "type": "hub", "state": {"name": "My Hub"}} - } - await s.create_devices() - assert s.hub_id == "hub-42" - - async def test_product_with_error_key_is_skipped(self): - """Products with an 'error' key are silently skipped.""" - s = _make_create_stub() - s.data["products"] = { - "bad": {"id": "bad", "type": "heating", "error": "device not found"} - } - result = await s.create_devices() - assert result["climate"] == [] - - async def test_non_heating_group_product_skipped(self): - """isGroup=True products of non-heating type are not added to any list.""" - s = _make_create_stub() - s.data["products"] = { - "grp-1": { - "id": "grp-1", - "type": "activeplug", - "isGroup": True, - "state": {"name": "Plug Group"}, - } - } - result = await s.create_devices() - assert result["switch"] == [] - - async def test_heating_group_product_not_skipped(self): - """isGroup=True products of heating type are processed and added.""" - s = _make_create_stub() - s.data["products"] = { - "h-grp": { - "id": "h-grp", - "type": "heating", - "isGroup": True, - "state": {"name": "Heating Zone"}, - } - } - result = await s.create_devices() - assert len(result["climate"]) == 1 - - async def test_multiple_devices_all_processed(self): - """Multiple devices in the device list are all processed.""" - s = _make_create_stub() - s.data["devices"] = { - "hub-1": {"id": "hub-1", "type": "hub", "state": {"name": "Hub"}}, - "trv-1": {"id": "trv-1", "type": "trv", "state": {"name": "TRV"}}, - } - s.data["products"] = {} - await s.create_devices() - # Hub is found; hub_id is set to the hub device - assert s.hub_id == "hub-1" - - async def test_action_processed_as_switch(self): - """Actions in data.actions are added to device_list['switch'].""" - s = _make_create_stub() - s.data["actions"] = { - "act-1": {"id": "act-1", "name": "Good Night", "type": "action"} - } - result = await s.create_devices() - assert len(result["switch"]) == 1 - assert result["switch"][0].hive_type == "action" - - async def test_returns_device_list_dict(self): - """create_devices always returns a dict with the expected HA entity keys.""" - s = _make_create_stub() - result = await s.create_devices() - for key in ( - "parent", - "binary_sensor", - "climate", - "light", - "sensor", - "switch", - "water_heater", - ): - assert key in result - - async def test_product_with_error_and_valid_both_present_only_valid_added(self): - """Only products without 'error' are added when both types coexist.""" - s = _make_create_stub() - s.data["products"] = { - "bad": {"id": "bad", "type": "heating", "error": "broken"}, - "good": {"id": "good", "type": "heating", "state": {"name": "Hall"}}, - } - result = await s.create_devices() - assert len(result["climate"]) == 1 - assert result["climate"][0].hive_id == "good" diff --git a/tests/unit/test_session_get_devices.py b/tests/unit/test_session_get_devices.py index 9042215..b284b1e 100644 --- a/tests/unit/test_session_get_devices.py +++ b/tests/unit/test_session_get_devices.py @@ -94,6 +94,39 @@ async def test_file_mode_does_not_call_api(self): p.hive_refresh_tokens.assert_not_called() +# --------------------------------------------------------------------------- +# get_devices — malformed API responses +# --------------------------------------------------------------------------- + + +class TestGetDevicesMalformedResponse: + """Hostile/partial response shapes must not escape as raw KeyError.""" + + async def test_user_without_id_still_succeeds(self): + """A user object with no 'id' key must not crash the poll.""" + p = _make_stub() + p.tokens = MagicMock() + p.api.get_all = AsyncMock( + return_value={"original": 200, "parsed": {"user": {"name": "x"}}} + ) + result = await p.get_devices("No_ID") + assert result is True + assert p.config.user_id is None + + async def test_product_without_id_returns_false(self): + """A product entry with no 'id' key fails the poll gracefully.""" + p = _make_stub() + p.tokens = MagicMock() + p.api.get_all = AsyncMock( + return_value={ + "original": 200, + "parsed": {"products": [{"type": "heating"}]}, + } + ) + result = await p.get_devices("No_ID") + assert result is False + + # --------------------------------------------------------------------------- # get_devices — tokens path # ---------------------------------------------------------------------------