8000 Update names and restructure tapo devices · python-kasa/python-kasa@f5e320d · GitHub
[go: up one dir, main page]

Skip to content

Commit f5e320d

Browse files
committed
Update names and restructure tapo devices
1 parent 1cc089d commit f5e320d

File tree

7 files changed

+290
-224
lines changed

7 files changed

+290
-224
lines changed

kasa/tapoprotocol.py renamed to kasa/aesprotocol.py

Lines changed: 53 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
1-
"""Implementation of the TP-Link TAPO Protocol.
1+
"""Implementation of the TP-Link AES Protocol.
22
3+
Based on the work of https://github.com/petretiandrea/plugp100
4+
under compatible GNU GPL3 license.
35
"""
46

57
import asyncio
@@ -35,16 +37,16 @@ def _md5(payload: bytes) -> bytes:
3537

3638

3739
def _sha1(payload: bytes) -> str:
38-
sha1_algo = hashlib.sha1()
40+
sha1_algo = hashlib.sha1() # noqa: S324
3941
sha1_algo.update(payload)
4042
return sha1_algo.hexdigest()
4143

4244

43-
class TapoAESProtocol(TPLinkProtocol):
44-
"""Implementation of the KLAP encryption protocol.
45+
class TPLinkAes(TPLinkProtocol):
46+
"""Implementation of the AES encryption protocol.
4547
46-
KLAP is the name used in device discovery for TP-Link's new encryption
47-
protocol, used by newer firmware versions.
48+
AES is the name used in device discovery for TP-Link's TAPO encryption
49+
protocol, sometimes used by newer firmware versions on kasa devices.
4850
"""
4951

5052
DEFAULT_PORT = 80
@@ -95,9 +97,10 @@ def __init__(
9597
self.login_version = login_version
9698
self.login_token = None
9799

98-
_LOGGER.debug("Created KLAP object for %s", self.host)
100+
_LOGGER.debug("Created AES object for %s", self.host)
99101

100102
def hash_credentials(self, credentials, try_login_version2):
103+
"""Hash the credentials."""
101104
if try_login_version2:
102105
un = base64.b64encode(
103106
_sha1(credentials.username.encode()).encode()
@@ -135,6 +138,7 @@ async def client_post(self, url, params=None, data=None, json=None, headers=None
135138
return resp.status_code, response_data
136139

137140
async def send_secure_passthrough(self, request):
141+
"""Send encrypted message as passthrough."""
138142
url = f"http://{self.host}/app"
139143
if self.login_token:
140144
url += f"?token={self.login_token}"
@@ -159,7 +163,8 @@ async def send_secure_passthrough(self, request):
159163
else:
160164
raise AuthenticationException("Could not complete send")
161165

162-
def get_tapo_request(self, method, params=None):
166+
def get_aes_request(self, method, params=None):
167+
"""Get a request message."""
163168
request = {
164169
"method": method,
165170
"params": params,
@@ -170,18 +175,20 @@ def get_tapo_request(self, method, params=None):
170175
return request
171176

172177
async def perform_login(self, login_v2):
178+
"""Login to the device."""
173179
self.login_token = None
174-
url = f"http://{self.host}/app"
180+
175181
un, pw = self.hash_credentials(self.credentials, login_v2)
176182
params = {"password": pw, "username": un}
177-
request = self.get_tapo_request("login_device", params)
183+
request = self.get_aes_request("login_device", params)
178184
try:
179185
result = await self.send_secure_passthrough(request)
180186
except SmartDeviceException as ex:
181187
raise AuthenticationException(ex) from ex
182188
self.login_token = result["token"]
183189

184190
async def perform_handshake(self):
191+
"""Perform the handshake."""
185192
_LOGGER.debug("Will perform handshaking...")
186193
_LOGGER.debug("Generating keypair")
187194

@@ -192,7 +199,11 @@ async def perform_handshake(self):
192199
url = f"http://{self.host}/app"
193200
key_pair = KeyPair.create_key_pair()
194201

195-
pub_key = f"-----BEGIN PUBLIC KEY-----\n{key_pair.get_public_key()}\n-----END PUBLIC KEY-----\n"
202+
pub_key = (
203+
"-----BEGIN PUBLIC KEY-----\n"
204+
+ key_pair.get_public_key()
205+
+ "\n-----END PUBLIC KEY-----\n"
206+
)
196207
handshake_params = {"key": pub_key}
197208
_LOGGER.debug(f"Handshake params: {handshake_params}")
198209

@@ -255,7 +266,7 @@ async def query(self, request: Union[str, Dict], retry_count: int = 3) -> Dict:
255266
async with self.query_lock:
256267
return await self._query(request, retry_count)
257268

258-
async def _query(self, request: str, retry_count: int = 3) -> Dict:
269+
async def < F438 span class="pl-en">_query(self, request: Union[str, Dict], retry_count: int = 3) -> Dict:
259270
for retry in range(retry_count + 1):
260271
try:
261272
return await self._execute_query(request, retry)
@@ -292,7 +303,7 @@ async def _query(self, request: str, retry_count: int = 3) -> Dict:
292303
# make mypy happy, this should never be reached..
293304
raise SmartDeviceException("Query reached somehow to unreachable")
294305

295-
async def _execute_query(self, request: str, retry_count: int) -> Dict:
306+
async def _execute_query(self, request: Union[str, Dict], retry_count: int) -> Dict:
296307
if not self.http_client:
297308
self.http_client = httpx.AsyncClient()
298309

@@ -305,14 +316,14 @@ async def _execute_query(self, request: str, retry_count: int) -> Dict:
305316
await self.perform_login(True)
306317

307318
if isinstance(request, dict):
308-
tapo_method = next(iter(request))
309-
tapo_params = request[tapo_method]
319+
aes_method = next(iter(request))
320+
aes_params = request[aes_method]
310321
else:
311-
tapo_method = request
312-
tapo_params = None
322+
aes_method = request
323+
aes_params = None
313324

314-
tapo_request = self.get_tapo_request(tapo_method, tapo_params)
315-
response_data = await self.send_secure_passthrough(tapo_request)
325+
aes_request = self.get_aes_request(aes_method, aes_params)
326+
response_data = await self.send_secure_passthrough(aes_request)
316327

317328
_LOGGER.debug(
318329
"%s << %s",
@@ -331,13 +342,18 @@ async def close(self) -> None:
331342

332343

333344
class AesEncyptionSession:
345+
"""Class for an AES encryption session."""
346+
334347
@staticmethod
335348
def create_from_keypair(handshake_key: str, keypair):
336-
handshake_key: bytes = base64.b64decode(handshake_key.encode("UTF-8"))
349+
"""Create the encryption session."""
350+
handshake_key_bytes: bytes = base64.b64decode(handshake_key.encode("UTF-8"))
337351
private_key_data = base64.b64decode(keypair.get_private_key().encode("UTF-8"))
338352

339353
private_key = serialization.load_der_private_key(private_key_data, None, None)
340-
key_and_iv = private_key.decrypt(handshake_key, asymmetric_padding.PKCS1v15())
354+
key_and_iv = private_key.decrypt(
355+
handshake_key_bytes, asymmetric_padding.PKCS1v15()
356+
)
341357
if key_and_iv is None:
342358
raise ValueError("Decryption failed!")
343359

@@ -348,13 +364,15 @@ def __init__(self, key, iv):
348364
self.padding_strategy = padding.PKCS7(algorithms.AES.block_size)
349365

350366
def encrypt(self, data) -> bytes:
367+
"""Encrypt the message."""
351368
encryptor = self.cipher.encryptor()
352369
padder = self.padding_strategy.padder()
353370
padded_data = padder.update(data) + padder.finalize()
354371
encrypted = encryptor.update(padded_data) + encryptor.finalize()
355372
return base64.b64encode(encrypted)
356373

357374
def decrypt(self, data) -> str:
375+
"""Decrypt the message."""
358376
decryptor = self.cipher.decryptor()
359377
unpadder = self.padding_strategy.unpadder()
360378
decrypted = decryptor.update(base64.b64decode(data)) + decryptor.finalize()
@@ -363,8 +381,11 @@ def decrypt(self, data) -> str:
363381

364382

365383
class KeyPair:
384+
"""Class for generating key pairs."""
385+
366386
@staticmethod
367387
def create_key_pair(key_size: int = 1024):
388+
"""Create a key pair."""
368389
private_key = rsa.generate_private_key(public_exponent=65537, key_size=key_size)
369390
public_key = private_key.public_key()
370391

@@ -388,13 +409,17 @@ def __init__(self, private_key: str, public_key: str):
388409
self.public_key = public_key
389410

390411
def get_private_key(self) -> str:
412+
"""Get the private key."""
391413
return self.private_key
392414

393415
def get_public_key(self) -> str:
416+
"""Get the public key."""
394417
return self.public_key
395418

396419

397420
class SnowflakeId:
421+
"""Class for generating snowflake ids."""
422+
398423
EPOCH = 1420041600000 # Custom epoch (in milliseconds)
399424
WORKER_ID_BITS = 5
400425
DATA_CENTER_ID_BITS = 5
@@ -408,11 +433,15 @@ class SnowflakeId:
408433
def __init__(self, worker_id, data_center_id):
409434
if worker_id > SnowflakeId.MAX_WORKER_ID or worker_id < 0:
410435
raise ValueError(
411-
f"Worker ID can't be greater than {SnowflakeId.MAX_WORKER_ID} or less than 0"
436+
"Worker ID can't be greater than "
437+
+ str(SnowflakeId.MAX_WORKER_ID)
438+
+ " or less than 0"
412439
)
413440
if data_center_id > SnowflakeId.MAX_DATA_CENTER_ID or data_center_id < 0:
414441
raise ValueError(
415-
f"Data center ID can't be greater than {SnowflakeId.MAX_DATA_CENTER_ID} or less than 0"
442+
"Data center ID can't be greater than "
443+
+ str(SnowflakeId.MAX_DATA_CENTER_ID)
444+
+ " or less than 0"
416445
)
417446

418447
self.worker_id = worker_id
@@ -421,6 +450,7 @@ def __init__(self, worker_id, data_center_id):
421450
self.last_timestamp = -1
422451

423452
def generate_id(self):
453+
"""Generate a snowflake id."""
424454
timestamp = self._current_millis()
425455

426456
if timestamp < self.last_timestamp:

kasa/device_factory.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
from .smartlightstrip import SmartLightStrip
1414
from .smartplug import SmartPlug
1515
from .smartstrip import SmartStrip
16-
from .tapoplug import TapoPlug
16+
from .tapo.tapoplug import TapoPlug
1717

1818
DEVICE_TYPE_TO_CLASS = {
1919
DeviceType.Plug: SmartPlug,

kasa/discover.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
except ImportError:
1616
from pydantic import BaseModel, Field
1717

18+
from kasa.aesprotocol import TPLinkAes
1819
from kasa.credentials import Credentials
1920
from kasa.exceptions import UnsupportedDeviceException
2021
from kasa.json import dumps as json_dumps
@@ -23,8 +24,7 @@
2324
from kasa.protocol import TPLinkSmartHomeProtocol
2425
from kasa.smartdevice import SmartDevice, SmartDeviceException
2526
from kasa.smartplug import SmartPlug
26-
from kasa.tapoplug import TapoPlug
27-
from kasa.tapoprotocol import TapoAESProtocol
27+
from kasa.tapo.tapoplug import TapoPlug
2828

2929
from .device_factory import get_device_class_from_info
3030

@@ -383,7 +383,7 @@ def _get_device_instance(
383383
if discovery_result.mgt_encrypt_schm.encrypt_type in ("KLAP", "AES"):
384384
type_ = discovery_result.device_type
385385
device_class = None
386-
supported_device_types = {
386+
supported_device_types: dict[str, Type[SmartDevice]] = {
387387
"SMART.TAPOPLUG": TapoPlug,
388388
"SMART.KASAPLUG": TapoPlug,
389389
"IOT.SMARTPLUGSWITCH": SmartPlug,
@@ -399,7 +399,7 @@ def _get_device_instance(
399399
if discovery_result.mgt_encrypt_schm.encrypt_type == "KLAP":
400400
device.protocol = TPLinkKlap(ip, credentials=credentials)
401401
else:
402-
device.protocol = TapoAESProtocol(
402+
device.protocol = TPLinkAes(
403403
ip,
404404
credentials=credentials,
405405
login_version=discovery_result.mgt_encrypt_schm.lv,

0 commit comments

Comments
 (0)
0