From addcc261b141892aef8b1573f6928011480911cc Mon Sep 17 00:00:00 2001 From: sdb9696 Date: Tue, 23 Jan 2024 12:06:58 +0000 Subject: [PATCH 1/7] Generate AES KeyPair lazily --- kasa/aestransport.py | 49 +++++++++++++++++++++------------ kasa/httpclient.py | 7 ++++- kasa/tests/test_aestransport.py | 5 +++- 3 files changed, 41 insertions(+), 20 deletions(-) diff --git a/kasa/aestransport.py b/kasa/aestransport.py index 65b0045df..3e36f9ba8 100644 --- a/kasa/aestransport.py +++ b/kasa/aestransport.py @@ -8,7 +8,7 @@ import hashlib import logging import time -from typing import Dict, Optional, cast +from typing import TYPE_CHECKING, AsyncGenerator, Dict, Optional, cast from cryptography.hazmat.primitives import padding, serialization from cryptography.hazmat.primitives.asymmetric import padding as asymmetric_padding @@ -86,6 +86,8 @@ def __init__( self._login_token = None + self._key_pair: Optional[KeyPair] = None + _LOGGER.debug("Created AES transport for %s", self._host) @property @@ -185,34 +187,43 @@ async def perform_login(self): self._handle_response_error_code(resp_dict, "Error logging in") self._login_token = resp_dict["result"]["token"] - async def perform_handshake(self): - """Perform the handshake.""" - _LOGGER.debug("Will perform handshaking...") - _LOGGER.debug("Generating keypair") - - self._handshake_done = False - self._session_expire_at = None - self._session_cookie = None - - url = f"http://{self._host}/app" - key_pair = KeyPair.create_key_pair() + async def _generate_request_body(self) -> AsyncGenerator: + """Generate the request body and return an ascyn_generator. + This prevents the key pair being generated unless a connection + can be made to the device. + """ + if self._key_pair: + return + self._key_pair = KeyPair.create_key_pair() pub_key = ( "-----BEGIN PUBLIC KEY-----\n" - + key_pair.get_public_key() + + self._key_pair.get_public_key() # type: ignore[union-attr] + "\n-----END PUBLIC KEY-----\n" ) handshake_params = {"key": pub_key} _LOGGER.debug(f"Handshake params: {handshake_params}") - request_body = {"method": "handshake", "params": handshake_params} - _LOGGER.debug(f"Request {request_body}") + yield json_dumps(request_body).encode() + + async def perform_handshake(self): + """Perform the handshake.""" + _LOGGER.debug("Will perform handshaking...") + _LOGGER.debug("Generating keypair") + self._key_pair = None + self._handshake_done = False + self._session_expire_at = None + self._session_cookie = None + + url = f"http://{self._host}/app" + # Device needs the content length or it will response with 500 + headers = {**self.COMMON_HEADERS, "Content-Length": "318"} status_code, resp_dict = await self._http_client.post( url, - json=request_body, - headers=self.COMMON_HEADERS, + json=self._generate_request_body(), + headers=headers, cookies_dict=self._session_cookie, ) @@ -240,8 +251,10 @@ async def perform_handshake(self): self._session_cookie = {self.SESSION_COOKIE_NAME: cookie} self._session_expire_at = time.time() + 86400 + if TYPE_CHECKING: + assert self._key_pair is not None self._encryption_session = AesEncyptionSession.create_from_keypair( - handshake_key, key_pair + handshake_key, self._key_pair ) self._handshake_done = True diff --git a/kasa/httpclient.py b/kasa/httpclient.py index a4bd84a33..008bc091c 100644 --- a/kasa/httpclient.py +++ b/kasa/httpclient.py @@ -49,6 +49,11 @@ async def post( response_data = None self._last_url = url self.client.cookie_jar.clear() + return_json = bool(json) + # If json is not a dict send as data + if json and not isinstance(json, Dict): + data = json + json = None try: resp = await self.client.post( url, @@ -62,7 +67,7 @@ async def post( async with resp: if resp.status == 200: response_data = await resp.read() - if json: + if return_json: response_data = json_loads(response_data.decode()) except (aiohttp.ServerDisconnectedError, aiohttp.ClientOSError) as ex: diff --git a/kasa/tests/test_aestransport.py b/kasa/tests/test_aestransport.py index 774aaf943..0174c6371 100644 --- a/kasa/tests/test_aestransport.py +++ b/kasa/tests/test_aestransport.py @@ -169,7 +169,10 @@ def __init__(self, host, status_code=200, error_code=0, inner_error_code=0): self.inner_error_code = inner_error_code self.http_client = HttpClient(DeviceConfig(self.host)) - async def post(self, url, params=None, json=None, *_, **__): + async def post(self, url, params=None, json=None, data=None, *_, **__): + if data: + async for item in data: + json = json_loads(item.decode()) return await self._post(url, json) async def _post(self, url, json): From a763b2b7939ce1fc94b64d6d7e69281cefc5c8fc Mon Sep 17 00:00:00 2001 From: sdb9696 Date: Tue, 23 Jan 2024 12:18:18 +0000 Subject: [PATCH 2/7] Fix coverage --- kasa/aestransport.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/kasa/aestransport.py b/kasa/aestransport.py index 3e36f9ba8..cbc438a6a 100644 --- a/kasa/aestransport.py +++ b/kasa/aestransport.py @@ -193,8 +193,6 @@ async def _generate_request_body(self) -> AsyncGenerator: This prevents the key pair being generated unless a connection can be made to the device. """ - if self._key_pair: - return self._key_pair = KeyPair.create_key_pair() pub_key = ( "-----BEGIN PUBLIC KEY-----\n" @@ -251,7 +249,7 @@ async def perform_handshake(self): self._session_cookie = {self.SESSION_COOKIE_NAME: cookie} self._session_expire_at = time.time() + 86400 - if TYPE_CHECKING: + if TYPE_CHECKING: # pragma: no cover assert self._key_pair is not None self._encryption_session = AesEncyptionSession.create_from_keypair( handshake_key, self._key_pair From ade611bc9e59e116b00eaa19be42c440021c374a Mon Sep 17 00:00:00 2001 From: sdb9696 Date: Tue, 23 Jan 2024 13:17:34 +0000 Subject: [PATCH 3/7] Update post-review --- kasa/aestransport.py | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/kasa/aestransport.py b/kasa/aestransport.py index cbc438a6a..ed87f963c 100644 --- a/kasa/aestransport.py +++ b/kasa/aestransport.py @@ -55,6 +55,8 @@ class AesTransport(BaseTransport): "requestByApp": "true", "Accept": "application/json", } + CONTENT_LENGTH = "Content-Length" + KEY_PAIR_CONTENT_LENGTH = 318 def __init__( self, @@ -187,12 +189,13 @@ async def perform_login(self): self._handle_response_error_code(resp_dict, "Error logging in") self._login_token = resp_dict["result"]["token"] - async def _generate_request_body(self) -> AsyncGenerator: + async def _generate_key_pair_payload(self) -> AsyncGenerator: """Generate the request body and return an ascyn_generator. This prevents the key pair being generated unless a connection can be made to the device. """ + _LOGGER.debug("Generating keypair") self._key_pair = KeyPair.create_key_pair() pub_key = ( "-----BEGIN PUBLIC KEY-----\n" @@ -208,7 +211,6 @@ async def _generate_request_body(self) -> AsyncGenerator: async def perform_handshake(self): """Perform the handshake.""" _LOGGER.debug("Will perform handshaking...") - _LOGGER.debug("Generating keypair") self._key_pair = None self._handshake_done = False @@ -217,10 +219,13 @@ async def perform_handshake(self): url = f"http://{self._host}/app" # Device needs the content length or it will response with 500 - headers = {**self.COMMON_HEADERS, "Content-Length": "318"} + headers = { + **self.COMMON_HEADERS, + self.CONTENT_LENGTH: str(self.KEY_PAIR_CONTENT_LENGTH), + } status_code, resp_dict = await self._http_client.post( url, - json=self._generate_request_body(), + json=self._generate_key_pair_payload(), headers=headers, cookies_dict=self._session_cookie, ) From ce6e832f239b120374480530db8c748ec2104d44 Mon Sep 17 00:00:00 2001 From: sdb9696 Date: Tue, 23 Jan 2024 13:20:56 +0000 Subject: [PATCH 4/7] Fix pragma --- kasa/aestransport.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/kasa/aestransport.py b/kasa/aestransport.py index ed87f963c..e221e3b70 100644 --- a/kasa/aestransport.py +++ b/kasa/aestransport.py @@ -254,8 +254,8 @@ async def perform_handshake(self): self._session_cookie = {self.SESSION_COOKIE_NAME: cookie} self._session_expire_at = time.time() + 86400 - if TYPE_CHECKING: # pragma: no cover - assert self._key_pair is not None + if TYPE_CHECKING: + assert self._key_pair is not None # pragma: no cover self._encryption_session = AesEncyptionSession.create_from_keypair( handshake_key, self._key_pair ) From 8293ae319cd13a21ffe4b1cbfc4294ec454245a5 Mon Sep 17 00:00:00 2001 From: sdb9696 Date: Tue, 23 Jan 2024 14:35:06 +0000 Subject: [PATCH 5/7] Make json dumps consistent between python and orjson --- kasa/aestransport.py | 2 +- kasa/json.py | 5 ++++- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/kasa/aestransport.py b/kasa/aestransport.py index e221e3b70..b35713b6d 100644 --- a/kasa/aestransport.py +++ b/kasa/aestransport.py @@ -56,7 +56,7 @@ class AesTransport(BaseTransport): "Accept": "application/json", } CONTENT_LENGTH = "Content-Length" - KEY_PAIR_CONTENT_LENGTH = 318 + KEY_PAIR_CONTENT_LENGTH = 314 def __init__( self, diff --git a/kasa/json.py b/kasa/json.py index 4acc865f5..73bf0e60e 100755 --- a/kasa/json.py +++ b/kasa/json.py @@ -11,5 +11,8 @@ def dumps(obj, *, default=None): except ImportError: import json - dumps = json.dumps + def dumps(obj, *, default=None): + """Dump JSON.""" + return json.dumps(obj, separators=(",", ":")) + loads = json.loads From 8db6635603cbd268cd7b8cce47417df235940700 Mon Sep 17 00:00:00 2001 From: sdb9696 Date: Tue, 23 Jan 2024 14:49:37 +0000 Subject: [PATCH 6/7] Add comment --- kasa/json.py | 1 + 1 file changed, 1 insertion(+) diff --git a/kasa/json.py b/kasa/json.py index 73bf0e60e..aed8cd56d 100755 --- a/kasa/json.py +++ b/kasa/json.py @@ -13,6 +13,7 @@ def dumps(obj, *, default=None): def dumps(obj, *, default=None): """Dump JSON.""" + # Separators specified for consistency with orjson return json.dumps(obj, separators=(",", ":")) loads = json.loads From 946b9f75aa2a7c7e381d5556be91ea2556203ba2 Mon Sep 17 00:00:00 2001 From: sdb9696 Date: Tue, 23 Jan 2024 14:55:13 +0000 Subject: [PATCH 7/7] Add comments re json parameter in HttpClient --- kasa/httpclient.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/kasa/httpclient.py b/kasa/httpclient.py index 008bc091c..28a19e8bd 100644 --- a/kasa/httpclient.py +++ b/kasa/httpclient.py @@ -41,16 +41,22 @@ async def post( *, params: Optional[Dict[str, Any]] = None, data: Optional[bytes] = None, - json: Optional[Dict] = None, + json: Optional[Union[Dict, Any]] = None, headers: Optional[Dict[str, str]] = None, cookies_dict: Optional[Dict[str, str]] = None, ) -> Tuple[int, Optional[Union[Dict, bytes]]]: - """Send an http post request to the device.""" + """Send an http post request to the device. + + If the request is provided via the json parameter json will be returned. + """ response_data = None self._last_url = url self.client.cookie_jar.clear() return_json = bool(json) - # If json is not a dict send as data + # If json is not a dict send as data. + # This allows the json parameter to be used to pass other + # types of data such as async_generator and still have json + # returned. if json and not isinstance(json, Dict): data = json json = None