|
8 | 8 | import hashlib
|
9 | 9 | import logging
|
10 | 10 | import time
|
11 |
| -from typing import Dict, Optional, cast |
| 11 | +from typing import TYPE_CHECKING, AsyncGenerator, Dict, Optional, cast |
12 | 12 |
|
13 | 13 | from cryptography.hazmat.primitives import padding, serialization
|
14 | 14 | from cryptography.hazmat.primitives.asymmetric import padding as asymmetric_padding
|
@@ -55,6 +55,8 @@ class AesTransport(BaseTransport):
|
55 | 55 | "requestByApp": "true",
|
56 | 56 | "Accept": "application/json",
|
57 | 57 | }
|
| 58 | + CONTENT_LENGTH = "Content-Length" |
| 59 | + KEY_PAIR_CONTENT_LENGTH = 314 |
58 | 60 |
|
59 | 61 | def __init__(
|
60 | 62 | self,
|
@@ -86,6 +88,8 @@ def __init__(
|
86 | 88 |
|
87 | 89 | self._login_token = None
|
88 | 90 |
|
| 91 | + self._key_pair: Optional[KeyPair] = None |
| 92 | + |
89 | 93 | _LOGGER.debug("Created AES transport for %s", self._host)
|
90 <
8000
/td> | 94 |
|
91 | 95 | @property
|
@@ -204,34 +208,44 @@ async def try_login(self, login_params):
|
204 | 208 | self._handle_response_error_code(resp_dict, "Error logging in")
|
205 | 209 | self._login_token = resp_dict["result"]["token"]
|
206 | 210 |
|
207 |
| - async def perform_handshake(self): |
208 |
| - """Perform the handshake.""" |
209 |
| - _LOGGER.debug("Will perform handshaking...") |
210 |
| - _LOGGER.debug("Generating keypair") |
211 |
| - |
212 |
| - self._handshake_done = False |
213 |
| - self._session_expire_at = None |
214 |
| - self._session_cookie = None |
215 |
| - |
216 |
| - url = f"http://{self._host}/app" |
217 |
| - key_pair = KeyPair.create_key_pair() |
| 211 | + async def _generate_key_pair_payload(self) -> AsyncGenerator: |
| 212 | + """Generate the request body and return an ascyn_generator. |
218 | 213 |
|
| 214 | + This prevents the key pair being generated unless a connection |
| 215 | + can be made to the device. |
| 216 | + """ |
| 217 | + _LOGGER.debug("Generating keypair") |
| 218 | + self._key_pair = KeyPair.create_key_pair() |
219 | 219 | pub_key = (
|
220 | 220 | "-----BEGIN PUBLIC KEY-----\n"
|
221 |
| - + key_pair.get_public_key() |
| 221 | + + self._key_pair.get_public_key() # type: ignore[union-attr] |
222 | 222 | + "\n-----END PUBLIC KEY-----\n"
|
223 | 223 | )
|
224 | 224 | handshake_params = {"key": pub_key}
|
225 | 225 | _LOGGER.debug(f"Handshake params: {handshake_params}")
|
226 |
| - |
227 | 226 | request_body = {"method": "handshake", "params": handshake_params}
|
228 |
| - |
229 | 227 | _LOGGER.debug(f"Request {request_body}")
|
| 228 | + yield json_dumps(request_body).encode() |
230 | 229 |
|
| 230 | + async def perform_handshake(self): |
| 231 | + """Perform the handshake.""" |
| 232 | + _LOGGER.debug("Will perform handshaking...") |
| 233 | + |
| 234 | + self._key_pair = None |
| 235 | + self._handshake_done = False |
| 236 | + self._session_expire_at = None |
| 237 | + self._session_cookie = None |
| 238 | + |
| 239 | + url = f"http://{self._host}/app" |
| 240 | + # Device needs the content length or it will response with 500 |
| 241 | + headers = { |
| 242 | + **self.COMMON_HEADERS, |
| 243 | + self.CONTENT_LENGTH: str(self.KEY_PAIR_CONTENT_LENGTH), |
| 244 | + } |
231 | 245 | status_code, resp_dict = await self._http_client.post(
|
232 | 246 | url,
|
233 |
| - json=request_body, |
234 |
| - headers=self.COMMON_HEADERS, |
| 247 | + json=self._generate_key_pair_payload(), |
| 248 | + headers=headers, |
235 | 249 | cookies_dict=self._session_cookie,
|
236 | 250 | )
|
237 | 251 |
|
@@ -259,8 +273,10 @@ async def perform_handshake(self):
|
259 | 273 | self._session_cookie = {self.SESSION_COOKIE_NAME: cookie}
|
260 | 274 |
|
261 | 275 | self._session_expire_at = time.time() + 86400
|
| 276 | + if TYPE_CHECKING: |
| 277 | + assert self._key_pair is not None # pragma: no cover |
262 | 278 | self._encryption_session = AesEncyptionSession.create_from_keypair(
|
263 |
| - handshake_key, key_pair |
| 279 | + handshake_key, self._key_pair |
264 | 280 | )
|
265 | 281 |
|
266 | 282 | self._handshake_done = True
|
|
0 commit comments