8000 Do not regenerate aes key pair (#1114) · msz-coder/python-kasa@fcf8f07 · GitHub
[go: up one dir, main page]

Skip to content

Commit fcf8f07

Browse files
authored
Do not regenerate aes key pair (python-kasa#1114)
And read it from `device_config` if provided. This is required as key generation can eat up cpu when a device is not fully available and the library is retrying.
1 parent 2a89e58 commit fcf8f07

File tree

3 files changed

+46
-4
lines changed

3 files changed

+46
-4
lines changed

kasa/aestransport.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -106,6 +106,9 @@ def __init__(
106106
self._session_cookie: dict[str, str] | None = None
107107

108108
self._key_pair: KeyPair | None = None
109+
if config.aes_keys:
110+
aes_keys = config.aes_keys
111+
self._key_pair = KeyPair(aes_keys["private"], aes_keys["public"])
109112
self._app_url = URL(f"http://{self._host}:{self._port}/app")
110113
self._token_url: URL | None = None
111114

@@ -271,7 +274,14 @@ async def _generate_key_pair_payload(self) -> AsyncGenerator:
271274
can be made to the device.
272275
"""
273276
_LOGGER.debug("Generating keypair")
274-
self._key_pair = KeyPair.create_key_pair()
277+
if not self._key_pair:
278+
kp = KeyPair.create_key_pair()
279+
self._config.aes_keys = {
280+
"private": kp.get_private_key(),
281+
"public": kp.get_public_key(),
282+
}
283+
self._key_pair = kp
284+
275285
pub_key = (
276286
"-----BEGIN PUBLIC KEY-----\n"
277287
+ self._key_pair.get_public_key() # type: ignore[union-attr]
@@ -286,7 +296,6 @@ async def perform_handshake(self) -> None:
286296
"""Perform the handshake."""
287297
_LOGGER.debug("Will perform handshaking...")
288298

289-
self._key_pair = None
290299
self._token_url = None
291300
self._session_expire_at = None
292301
self._session_cookie = None

kasa/deviceconfig.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@
3434
import logging
3535
from dataclasses import asdict, dataclass, field, fields, is_dataclass
3636
from enum import Enum
37-
from typing import TYPE_CHECKING, Dict, Optional, Union
37+
from typing import TYPE_CHECKING, Dict, Optional, TypedDict, Union
3838

3939
from .credentials import Credentials
4040
from .exceptions import KasaException
@@ -45,6 +45,13 @@
4545
_LOGGER = logging.getLogger(__name__)
4646

4747

48+
class KeyPairDict(TypedDict):
49+
"""Class to represent a public/private key pair."""
50+
51+
private: str
52+
public: str
53+
54+
4855
class DeviceEncryptionType(Enum):
4956
"""Encrypt type enum."""
5057

@@ -182,7 +189,7 @@ class DeviceConfig:
182189
#: The batch size for protoools supporting multiple request batches.
183190
connection_type: DeviceConnectionParameters = field(
184191
default_factory=lambda: DeviceConnectionParameters(
185-
DeviceFamily.IotSmartPlugSwitch, DeviceEncryptionType.Xor, 1
192+
DeviceFamily.IotSmartPlugSwitch, DeviceEncryptionType.Xor
186193
)
187194
)
188195
#: True if the device uses http. Consumers should retrieve rather than set this
@@ -193,6 +200,8 @@ class DeviceConfig:
193200
#: Set a custom http_client for the device to use.
194201
http_client: Optional["ClientSession"] = field(default=None, compare=False)
195202

203+
aes_keys: Optional[KeyPairDict] = None
204+
196205
def __post_init__(self):
197206
if self.connection_type is None:
198207
self.connection_type = DeviceConnectionParameters(

kasa/tests/test_aestransport.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,29 @@ async def test_handshake(
8080
assert transport._state is TransportState.LOGIN_REQUIRED
8181

8282

83+
async def test_handshake_with_keys(mocker):
84+
host = "127.0.0.1"
85+
mock_aes_device = MockAesDevice(host)
86+
mocker.patch.object(aiohttp.ClientSession, "post", side_effect=mock_aes_device.post)
87+
88+
test_keys = {
89+
"private": "MIICdQIBADANBgkqhkiG9w0BAQEFAASCAl8wggJbAgEAAoGBAMo/JQpXIbP2M3bLOKyfEVCURFCxHIXv4HDME8J58AL4BwGDXf0oQycgj9nV+T/MzgEd/4iVysYuYfLuIEKXADP7Lby6AfA/dbcinZZ7bLUNMNa7TaylIvVKtSfR0LV8AmG0jdQYkr4cTzLAEd+AEs/wG3nMQNEcoQRVY+svLPDjAgMBAAECgYBCsDOch0KbvrEVmMklUoY5Fcq4+M249HIDf6d8VwznTbWxsAmL8nzCKCCG6eF4QiYjhCrAdPQaCS1PF2oXywbLhngid/9W9gz4CKKDJChs1X8KvLi+TLg1jgJUXvq9yVNh1CB+lS2ho4gdDDCbVmiVOZR5TDfEf0xeJ+Zz3zlUEQJBAPkhuNdc3yRue8huFZbrWwikURQPYBxLOYfVTDsfV9mZGSkGoWS1FPDsxrqSXugTmcTRuw+lrXKDabJ72kqywA8CQQDP0oaGh5r7F12Xzcwb7X9JkTvyr+rO8YgVtKNBaNVOPabAzysNwOlvH/sNCVQcRj8rn5LNXitgLx6T+Q5uqa3tAkA7J0elUzbkhps7ju/vYri9x448zh3K+g2R9BJio2GPmCuCM0HVEK4FOqNBH4oLXsQPGKFq6LLTUuKg74l4XRL/AkBHBO6r8pNn0yhMxCtIL/UbsuIFoVBgv/F9WWmg5K5gOnlN0n4oCRC8xPUKE3IG54qW4cVNIS05hWCxuJ7R+nJRAkByt/+kX1nQxis2wIXj90fztXG3oSmoVaieYxaXPxlWvX3/Q5kslFF5UsGy9gcK0v2PXhqjTbhud3/X0Er6YP4v",
90+
"public": "MIGfMA0GCSqGSIb3DQEBAQUAA4GNADCBiQKBgQDKPyUKVyGz9jN2yzisnxFQlERQsRyF7+BwzBPCefAC+AcBg139KEMnII/Z1fk/zM4BHf+IlcrGLmHy7iBClwAz+y28ugHwP3W3Ip2We2y1DTDWu02spSL1SrUn0dC1fAJhtI3UGJK+HE8ywBHfgBLP8Bt5zEDRHKEEVWPrLyzw4wIDAQAB",
91+
}
92+
transport = AesTransport(
93+
config=DeviceConfig(
94+
host, credentials=Credentials("foo", "bar"), aes_keys=test_keys
95+
)
96+
)
97+
98+
assert transport._encryption_session is None
99+
assert transport._state is TransportState.HANDSHAKE_REQUIRED
100+
101+
await transport.perform_handshake()
102+
assert transport._key_pair.get_private_key() == test_keys["private"]
103+
assert transport._key_pair.get_public_key() == test_keys["public"]
104+
105+
83106
@status_parameters
84107
async def test_login(mocker, status_code, error_code, inner_error_code, expectation):
85108
host = "127.0.0.1"
@@ -97,6 +120,7 @@ async def test_login(mocker, status_code, error_code, inner_error_code, expectat
97120
with expectation:
98121
await transport.perform_login()
99122
assert mock_aes_device.token in str(transport._token_url)
123+
assert transport._config.aes_keys == transport._key_pair
100124

101125

102126
@pytest.mark.parametrize(

0 commit comments

Comments
 (0)
0