8000 Generate AES KeyPair lazily (#687) · python-kasa/python-kasa@e233e37 · GitHub
[go: up one dir, main page]

Skip to content

Commit e233e37

Browse files
authored
Generate AES KeyPair lazily (#687)
* Generate AES KeyPair lazily * Fix coverage * Update post-review * Fix pragma * Make json dumps consistent between python and orjson * Add comment * Add comments re json parameter in HttpClient
1 parent 718983c commit e233e37

File tree

4 files changed

+57
-23
lines changed

4 files changed

+57
-23
lines changed

kasa/aestransport.py

Lines changed: 34 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
import hashlib
99
import logging
1010
import time
11-
from typing import Dict, Optional, cast
11+
from typing import TYPE_CHECKING, AsyncGenerator, Dict, Optional, cast
1212

1313
from cryptography.hazmat.primitives import padding, serialization
1414
from cryptography.hazmat.primitives.asymmetric import padding as asymmetric_padding
@@ -55,6 +55,8 @@ class AesTransport(BaseTransport):
5555
"requestByApp": "true",
5656
"Accept": "application/json",
5757
}
58+
CONTENT_LENGTH = "Content-Length"
59+
KEY_PAIR_CONTENT_LENGTH = 314
5860

5961
def __init__(
6062
self,
@@ -86,6 +88,8 @@ def __init__(
8688

8789
self._login_token = None
8890

91+
self._key_pair: Optional[KeyPair] = None
92+
8993
_LOGGER.debug("Created AES transport for %s", self._host)
90< 8000 /td>94

9195
@property
@@ -204,34 +208,44 @@ async def try_login(self, login_params):
204208
self._handle_response_error_code(resp_dict, "Error logging in")
205209
self._login_token = resp_dict["result"]["token"]
206210

207-
async def perform_handshake(self):
208-
"""Perform the handshake."""
209-
_LOGGER.debug("Will perform handshaking...")
210-
_LOGGER.debug("Generating keypair")
211-
212-
self._handshake_done = False
213-
self._session_expire_at = None
214-
self._session_cookie = None
215-
216-
url = f"http://{self._host}/app"
217-
key_pair = KeyPair.create_key_pair()
211+
async def _generate_key_pair_payload(self) -> AsyncGenerator:
212+
"""Generate the request body and return an ascyn_generator.
218213
214+
This prevents the key pair being generated unless a connection
215+
can be made to the device.
216+
"""
217+
_LOGGER.debug("Generating keypair")
218+
self._key_pair = KeyPair.create_key_pair()
219219
pub_key = (
220220
"-----BEGIN PUBLIC KEY-----\n"
221-
+ key_pair.get_public_key()
221+
+ self._key_pair.get_public_key() # type: ignore[union-attr]
222222
+ "\n-----END PUBLIC KEY-----\n"
223223
)
224224
handshake_params = {"key": pub_key}
225225
_LOGGER.debug(f"Handshake params: {handshake_params}")
226-
227226
request_body = {"method": "handshake", "params": handshake_params}
228-
229227
_LOGGER.debug(f"Request {request_body}")
228+
yield json_dumps(request_body).encode()
230229

230+
async def perform_handshake(self):
231+
"""Perform the handshake."""
232+
_LOGGER.debug("Will perform handshaking...")
233+
234+
self._key_pair = None
235+
self._handshake_done = False
236+
self._session_expire_at = None
237+
self._session_cookie = None
238+
239+
url = f"http://{self._host}/app"
240+
# Device needs the content length or it will response with 500
241+
headers = {
242+
**self.COMMON_HEADERS,
243+
self.CONTENT_LENGTH: str(self.KEY_PAIR_CONTENT_LENGTH),
244+
}
231245
status_code, resp_dict = await self._http_client.post(
232246
url,
233-
json=request_body,
234-
headers=self.COMMON_HEADERS,
247+
json=self._generate_key_pair_payload(),
248+
headers=headers,
235249
cookies_dict=self._session_cookie,
236250
)
237251

@@ -259,8 +273,10 @@ async def perform_handshake(self):
259273
self._session_cookie = {self.SESSION_COOKIE_NAME: cookie}
260274

261275
self._session_expire_at = time.time() + 86400
276+
if TYPE_CHECKING:
277+
assert self._key_pair is not None # pragma: no cover
262278
self._encryption_session = AesEncyptionSession.create_from_keypair(
263-
handshake_key, key_pair
279+
handshake_key, self._key_pair
264280
)
265281

266282
self._handshake_done = True

kasa/httpclient.py

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -41,14 +41,25 @@ async def post(
4141
*,
4242
params: Optional[Dict[str, Any]] = None,
4343
data: Optional[bytes] = None,
44-
json: Optional[Dict] = None,
44+
json: Optional[Union[Dict, Any]] = None,
4545
headers: Optional[Dict[str, str]] = None,
4646
cookies_dict: Optional[Dict[str, str]] = None,
4747
) -> Tuple[int, Optional[Union[Dict, bytes]]]:
48-
"""Send an http post request to the device."""
48+
"""Send an http post request to the device.
49+
50+
If the request is provided via the json parameter json will be returned.
51+
"""
4952
response_data = None
5053
self._last_url = url
5154
self.client.cookie_jar.clear()
55+
return_json = bool(json)
56+
# If json is not a dict send as data.
57+
# This allows the json parameter to be used to pass other
58+
# types of data such as async_generator and still have json
59+
# returned.
60+
if json and not isinstance(json, Dict):
61+
data = json
62+
json = None
5263
try:
5364
resp = await self.client.post(
5465
url,
@@ -62,7 +73,7 @@ async def post(
6273
async with resp:
6374
if resp.status == 200:
6475
response_data = await resp.read()
65-
if json:
76+
if return_json:
6677
response_data = json_loads(response_data.decode())
6778

6879
except (aiohttp.ServerDisconnectedError, aiohttp.ClientOSError) as ex:

kasa/json.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,5 +11,9 @@ def dumps(obj, *, default=None):
1111
except ImportError:
1212
import json
1313

14-
dumps = json.dumps
14+
def dumps(obj, *, default=None):
15+
"""Dump JSON."""
16+
# Separators specified for consistency with orjson
17+
return json.dumps(obj, separators=(",", ":"))
18+
1519
loads = json.loads

kasa/tests/test_aestransport.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -225,7 +225,10 @@ def inner_error_code(self):
225225
else:
226226
return self._inner_error_code
227227

228-
async def post(self, url, params=None, json=None, *_, **__):
228+
async def post(self, url, params=None, json=None, data=None, *_, **__):
229+
if data:
230+
async for item in data:
231+
json = json_loads(item.decode())
229232
return await self._post(url, json)
230233

231234
async def _post(self, url, json):

0 commit comments

Comments
 (0)
0