From 2f97ef4c1d2fe2016ff8530d6f4e791dfee4dff7 Mon Sep 17 00:00:00 2001 From: Steven B <51370195+sdb9696@users.noreply.github.com> Date: Fri, 3 Jan 2025 11:12:11 +0000 Subject: [PATCH 1/2] Handle smartcam partial list responses --- kasa/protocols/smartcamprotocol.py | 13 ++++++++----- kasa/protocols/smartprotocol.py | 21 +++++++++++++++------ 2 files changed, 23 insertions(+), 11 deletions(-) diff --git a/kasa/protocols/smartcamprotocol.py b/kasa/protocols/smartcamprotocol.py index 324f80563..b0e8b1ec4 100644 --- a/kasa/protocols/smartcamprotocol.py +++ b/kasa/protocols/smartcamprotocol.py @@ -49,10 +49,11 @@ class SingleRequest: class SmartCamProtocol(SmartProtocol): """Class for SmartCam Protocol.""" - async def _handle_response_lists( - self, response_result: dict[str, Any], method: str, retry_count: int - ) -> None: - pass + def _get_list_request(self, method: str, start_index: int) -> dict: + if method in {"getChildDeviceList", "getChildDeviceComponentList"}: + return {method: {"childControl": {"start_index": start_index}}} + + return {method: {"start_index": start_index}} def _handle_response_error_code( self, resp_dict: dict, method: str, raise_on_error: bool = True @@ -147,7 +148,9 @@ async def _execute_query( if len(request) == 1 and method in {"get", "set", "do", "multipleRequest"}: single_request = self._get_smart_camera_single_request(request) else: - return await self._execute_multiple_query(request, retry_count) + return await self._execute_multiple_query( + request, retry_count, iterate_list_pages + ) else: single_request = self._make_smart_camera_single_request(request) diff --git a/kasa/protocols/smartprotocol.py b/kasa/protocols/smartprotocol.py index 7f02b45e7..1673ed2e8 100644 --- a/kasa/protocols/smartprotocol.py +++ b/kasa/protocols/smartprotocol.py @@ -180,7 +180,9 @@ async def _query(self, request: str | dict, retry_count: int = 3) -> dict: # make mypy happy, this should never be reached.. raise KasaException("Query reached somehow to unreachable") - async def _execute_multiple_query(self, requests: dict, retry_count: int) -> dict: + async def _execute_multiple_query( + self, requests: dict, retry_count: int, iterate_list_pages: bool + ) -> dict: debug_enabled = _LOGGER.isEnabledFor(logging.DEBUG) multi_result: dict[str, Any] = {} smart_method = "multipleRequest" @@ -275,9 +277,10 @@ async def _execute_multiple_query(self, requests: dict, retry_count: int) -> dic response, method, raise_on_error=raise_on_error ) result = response.get("result", None) - await self._handle_response_lists( - result, method, retry_count=retry_count - ) + if iterate_list_pages and result: + await self._handle_response_lists( + result, method, retry_count=retry_count + ) multi_result[method] = result # Multi requests don't continue after errors so requery any missing. @@ -303,7 +306,9 @@ async def _execute_query( smart_method = next(iter(request)) smart_params = request[smart_method] else: - return await self._execute_multiple_query(request, retry_count) + return await self._execute_multiple_query( + request, retry_count, iterate_list_pages + ) else: smart_method = request smart_params = None @@ -334,6 +339,9 @@ async def _execute_query( ) return {smart_method: result} + def _get_list_request(self, method: str, start_index: int) -> dict: + return {method: {"start_index": start_index}} + async def _handle_response_lists( self, response_result: dict[str, Any], method: str, retry_count: int ) -> None: @@ -355,8 +363,9 @@ async def _handle_response_lists( ) ) while (list_length := len(response_result[response_list_name])) < list_sum: + request = self._get_list_request(method, list_length) response = await self._execute_query( - {method: {"start_index": list_length}}, + request, retry_count=retry_count, iterate_list_pages=False, ) From 949d558100b9306adb96c0aa1f0bed8b1420ebdd Mon Sep 17 00:00:00 2001 From: Steven B <51370195+sdb9696@users.noreply.github.com> Date: Sat, 4 Jan 2025 11:44:27 +0000 Subject: [PATCH 2/2] Add tests and make obtaining module name generic --- kasa/protocols/smartcamprotocol.py | 14 +++++---- kasa/protocols/smartprotocol.py | 17 +++++++---- tests/fakeprotocol_smartcam.py | 19 ++++++++++--- tests/protocols/test_smartprotocol.py | 41 +++++++++++++++++++++++++++ 4 files changed, 76 insertions(+), 15 deletions(-) diff --git a/kasa/protocols/smartcamprotocol.py b/kasa/protocols/smartcamprotocol.py index b0e8b1ec4..a1d6ae9c8 100644 --- a/kasa/protocols/smartcamprotocol.py +++ b/kasa/protocols/smartcamprotocol.py @@ -5,7 +5,7 @@ import logging from dataclasses import dataclass from pprint import pformat as pf -from typing import Any +from typing import Any, cast from ..exceptions import ( AuthenticationError, @@ -49,11 +49,13 @@ class SingleRequest: class SmartCamProtocol(SmartProtocol): """Class for SmartCam Protocol.""" - def _get_list_request(self, method: str, start_index: int) -> dict: - if method in {"getChildDeviceList", "getChildDeviceComponentList"}: - return {method: {"childControl": {"start_index": start_index}}} - - return {method: {"start_index": start_index}} + def _get_list_request( + self, method: str, params: dict | None, start_index: int + ) -> dict: + # All smartcam requests have params + params = cast(dict, params) + module_name = next(iter(params)) + return {method: {module_name: {"start_index": start_index}}} def _handle_response_error_code( self, resp_dict: dict, method: str, raise_on_error: bool = True diff --git a/kasa/protocols/smartprotocol.py b/kasa/protocols/smartprotocol.py index 1673ed2e8..28a20641e 100644 --- a/kasa/protocols/smartprotocol.py +++ b/kasa/protocols/smartprotocol.py @@ -277,9 +277,10 @@ async def _execute_multiple_query( response, method, raise_on_error=raise_on_error ) result = response.get("result", None) + request_params = rp if (rp := requests.get(method)) else None if iterate_list_pages and result: await self._handle_response_lists( - result, method, retry_count=retry_count + result, method, request_params, retry_count=retry_count ) multi_result[method] = result @@ -335,15 +336,21 @@ async def _execute_query( result = response_data.get("result") if iterate_list_pages and result: await self._handle_response_lists( - result, smart_method, retry_count=retry_count + result, smart_method, smart_params, retry_count=retry_count ) return {smart_method: result} - def _get_list_request(self, method: str, start_index: int) -> dict: + def _get_list_request( + self, method: str, params: dict | None, start_index: int + ) -> dict: return {method: {"start_index": start_index}} async def _handle_response_lists( - self, response_result: dict[str, Any], method: str, retry_count: int + self, + response_result: dict[str, Any], + method: str, + params: dict | None, + retry_count: int, ) -> None: if ( response_result is None @@ -363,7 +370,7 @@ async def _handle_response_lists( ) ) while (list_length := len(response_result[response_list_name])) < list_sum: - request = self._get_list_request(method, list_length) + request = self._get_list_request(method, params, list_length) response = await self._execute_query( request, retry_count=retry_count, diff --git a/tests/fakeprotocol_smartcam.py b/tests/fakeprotocol_smartcam.py index 381a0a89c..eee014e8f 100644 --- a/tests/fakeprotocol_smartcam.py +++ b/tests/fakeprotocol_smartcam.py @@ -33,6 +33,7 @@ def __init__( *, list_return_size=10, is_child=False, + get_child_fixtures=True, verbatim=False, components_not_included=False, ): @@ -52,9 +53,12 @@ def __init__( self.verbatim = verbatim if not is_child: self.info = copy.deepcopy(info) - self.child_protocols = FakeSmartTransport._get_child_protocols( - self.info, self.fixture_name, "getChildDeviceList" - ) + # We don't need to get the child fixtures if testing things like + # lists + if get_child_fixtures: + self.child_protocols = FakeSmartTransport._get_child_protocols( + self.info, self.fixture_name, "getChildDeviceList" + ) else: self.info = info # self.child_protocols = self._get_child_protocols() @@ -229,9 +233,16 @@ async def _send_request(self, request_dict: dict): list_key = next( iter([key for key in result if isinstance(result[key], list)]) ) + assert isinstance(params, dict) + module_name = next(iter(params)) + start_index = ( start_index - if (params and (start_index := params.get("start_index"))) + if ( + params + and module_name + and (start_index := params[module_name].get("start_index")) + ) else 0 ) diff --git a/tests/protocols/test_smartprotocol.py b/tests/protocols/test_smartprotocol.py index 7961df68d..514926353 100644 --- a/tests/protocols/test_smartprotocol.py +++ b/tests/protocols/test_smartprotocol.py @@ -10,6 +10,7 @@ KasaException, SmartErrorCode, ) +from kasa.protocols.smartcamprotocol import SmartCamProtocol from kasa.protocols.smartprotocol import SmartProtocol, _ChildProtocolWrapper from kasa.smart import SmartDevice @@ -373,6 +374,46 @@ async def test_smart_protocol_lists_multiple_request(mocker, list_sum, batch_siz assert resp == response +@pytest.mark.parametrize("list_sum", [5, 10, 30]) +@pytest.mark.parametrize("batch_size", [1, 2, 3, 50]) +async def test_smartcam_protocol_list_request(mocker, list_sum, batch_size): + """Test smartcam protocol list handling for lists.""" + child_list = [{"foo": i} for i in range(list_sum)] + + response = { + "getChildDeviceList": { + "child_device_list": child_list, + "start_index": 0, + "sum": list_sum, + }, + "getChildDeviceComponentList": { + "child_component_list": child_list, + "start_index": 0, + "sum": list_sum, + }, + } + request = { + "getChildDeviceList": {"childControl": {"start_index": 0}}, + "getChildDeviceComponentList": {"childControl": {"start_index": 0}}, + } + + ft = FakeSmartCamTransport( + response, + "foobar", + list_return_size=batch_size, + components_not_included=True, + get_child_fixtures=False, + ) + protocol = SmartCamProtocol(transport=ft) + query_spy = mocker.spy(protocol, "_execute_query") + resp = await protocol.query(request) + expected_count = 1 + 2 * ( + int(list_sum / batch_size) + (0 if list_sum % batch_size else -1) + ) + assert query_spy.call_count == expected_count + assert resp == response + + async def test_incomplete_list(mocker, caplog): """Test for handling incomplete lists returned from queries.""" info = {