1
- """Implementation of the TP-Link TAPO Protocol.
1
+ """Implementation of the TP-Link AES Protocol.
2
2
3
+ Based on the work of https://github.com/petretiandrea/plugp100
4
+ under compatible GNU GPL3 license.
3
5
"""
4
6
5
7
import asyncio
@@ -35,16 +37,16 @@ def _md5(payload: bytes) -> bytes:
35
37
36
38
37
39
def _sha1 (payload : bytes ) -> str :
38
- sha1_algo = hashlib .sha1 ()
40
+ sha1_algo = hashlib .sha1 () # noqa: S324
39
41
sha1_algo .update (payload )
40
42
return sha1_algo .hexdigest ()
41
43
42
44
43
- class TapoAESProtocol (TPLinkProtocol ):
44
- """Implementation of the KLAP encryption protocol.
45
+ class TPLinkAes (TPLinkProtocol ):
46
+ """Implementation of the AES encryption protocol.
45
47
46
- KLAP is the name used in device discovery for TP-Link's new encryption
47
- protocol, used by newer firmware versions.
48
+ AES is the name used in device discovery for TP-Link's TAPO encryption
49
+ protocol, sometimes used by newer firmware versions on kasa devices .
48
50
"""
49
51
50
52
DEFAULT_PORT = 80
@@ -95,9 +97,10 @@ def __init__(
95
97
self .login_version = login_version
96
98
self .login_token = None
97
99
98
- _LOGGER .debug ("Created KLAP object for %s" , self .host )
100
+ _LOGGER .debug ("Created AES object for %s" , self .host )
99
101
100
102
def hash_credentials (self , credentials , try_login_version2 ):
103
+ """Hash the credentials."""
101
104
if try_login_version2 :
102
105
un = base64 .b64encode (
103
106
_sha1 (credentials .username .encode ()).encode ()
@@ -135,6 +138,7 @@ async def client_post(self, url, params=None, data=None, json=None, headers=None
135
138
return resp .status_code , response_data
136
139
137
140
async def send_secure_passthrough (self , request ):
141
+ """Send encrypted message as passthrough."""
138
142
url = f"http://{ self .host } /app"
139
143
if self .login_token :
140
144
url += f"?token={ self .login_token } "
@@ -159,7 +163,8 @@ async def send_secure_passthrough(self, request):
159
163
else :
160
164
raise AuthenticationException ("Could not complete send" )
161
165
162
- def get_tapo_request (self , method , params = None ):
166
+ def get_aes_request (self , method , params = None ):
167
+ """Get a request message."""
163
168
request = {
164
169
"method" : method ,
165
170
"params" : params ,
@@ -170,18 +175,20 @@ def get_tapo_request(self, method, params=None):
170
175
return request
171
176
172
177
async def perform_login (self , login_v2 ):
178
+ """Login to the device."""
173
179
self .login_token = None
174
- url = f"http:// { self . host } /app"
180
+
175
181
un , pw = self .hash_credentials (self .credentials , login_v2 )
176
182
params = {"password" : pw , "username" : un }
177
- request = self .get_tapo_request ("login_device" , params )
183
+ request = self .get_aes_request ("login_device" , params )
178
184
try :
179
185
result = await self .send_secure_passthrough (request )
180
186
except SmartDeviceException as ex :
181
187
raise AuthenticationException (ex ) from ex
182
188
self .login_token = result ["token" ]
183
189
184
190
async def perform_handshake (self ):
191
+ """Perform the handshake."""
185
192
_LOGGER .debug ("Will perform handshaking..." )
186
193
_LOGGER .debug ("Generating keypair" )
187
194
@@ -192,7 +199,11 @@ async def perform_handshake(self):
192
199
url = f"http://{ self .host } /app"
193
200
key_pair = KeyPair .create_key_pair ()
194
201
195
- pub_key = f"-----BEGIN PUBLIC KEY-----\n { key_pair .get_public_key ()} \n -----END PUBLIC KEY-----\n "
202
+ pub_key = (
203
+ "-----BEGIN PUBLIC KEY-----\n "
204
+ + key_pair .get_public_key ()
205
+ + "\n -----END PUBLIC KEY-----\n "
206
+ )
196
207
handshake_params = {"key" : pub_key }
197
208
_LOGGER .debug (f"Handshake params: { handshake_params } " )
198
209
@@ -255,7 +266,7 @@ async def query(self, request: Union[str, Dict], retry_count: int = 3) -> Dict:
255
266
async with self .query_lock :
256
267
return await self ._query (request , retry_count )
257
268
258
- async def _query (self , request : str , retry_count : int = 3 ) -> Dict :
269
+ async def <
F438
span class="pl-en">_query(self , request : Union [ str , Dict ] , retry_count : int = 3 ) -> Dict :
259
270
for retry in range (retry_count + 1 ):
260
271
try :
261
272
return await self ._execute_query (request , retry )
@@ -292,7 +303,7 @@ async def _query(self, request: str, retry_count: int = 3) -> Dict:
292
303
# make mypy happy, this should never be reached..
293
304
raise SmartDeviceException ("Query reached somehow to unreachable" )
294
305
295
- async def _execute_query (self , request : str , retry_count : int ) -> Dict :
306
+ async def _execute_query (self , request : Union [ str , Dict ] , retry_count : int ) -> Dict :
296
307
if not self .http_client :
297
308
self .http_client = httpx .AsyncClient ()
298
309
@@ -305,14 +316,14 @@ async def _execute_query(self, request: str, retry_count: int) -> Dict:
305
316
await self .perform_login (True )
306
317
307
318
if isinstance (request , dict ):
308
- tapo_method = next (iter (request ))
309
- tapo_params = request [tapo_method ]
319
+ aes_method = next (iter (request ))
320
+ aes_params = request [aes_method ]
310
321
else :
311
- tapo_method = request
312
- tapo_params = None
322
+ aes_method = request
323
+ aes_params = None
313
324
314
- tapo_request = self .get_tapo_request ( tapo_method , tapo_params )
315
- response_data = await self .send_secure_passthrough (tapo_request )
325
+ aes_request = self .get_aes_request ( aes_method , aes_params )
326
+ response_data = await self .send_secure_passthrough (aes_request )
316
327
317
328
_LOGGER .debug (
318
329
"%s << %s" ,
@@ -331,13 +342,18 @@ async def close(self) -> None:
331
342
332
343
333
344
class AesEncyptionSession :
345
+ """Class for an AES encryption session."""
346
+
334
347
@staticmethod
335
348
def create_from_keypair (handshake_key : str , keypair ):
336
- handshake_key : bytes = base64 .b64decode (handshake_key .encode ("UTF-8" ))
349
+ """Create the encryption session."""
350
+ handshake_key_bytes : bytes = base64 .b64decode (handshake_key .encode ("UTF-8" ))
337
351
private_key_data = base64 .b64decode (keypair .get_private_key ().encode ("UTF-8" ))
338
352
339
353
private_key = serialization .load_der_private_key (private_key_data , None , None )
340
- key_and_iv = private_key .decrypt (handshake_key , asymmetric_padding .PKCS1v15 ())
354
+ key_and_iv = private_key .decrypt (
355
+ handshake_key_bytes , asymmetric_padding .PKCS1v15 ()
356
+ )
341
357
if key_and_iv is None :
342
358
raise ValueError ("Decryption failed!" )
343
359
@@ -348,13 +364,15 @@ def __init__(self, key, iv):
348
364
self .padding_strategy = padding .PKCS7 (algorithms .AES .block_size )
349
365
350
366
def encrypt (self , data ) -> bytes :
367
+ """Encrypt the message."""
351
368
encryptor = self .cipher .encryptor ()
352
369
padder = self .padding_strategy .padder ()
353
370
padded_data = padder .update (data ) + padder .finalize ()
354
371
encrypted = encryptor .update (padded_data ) + encryptor .finalize ()
355
372
return base64 .b64encode (encrypted )
356
373
357
374
def decrypt (self , data ) -> str :
375
+ """Decrypt the message."""
358
376
decryptor = self .cipher .decryptor ()
359
377
unpadder = self .padding_strategy .unpadder ()
360
378
decrypted = decryptor .update (base64 .b64decode (data )) + decryptor .finalize ()
@@ -363,8 +381,11 @@ def decrypt(self, data) -> str:
363
381
364
382
365
383
class KeyPair :
384
+ """Class for generating key pairs."""
385
+
366
386
@staticmethod
367
387
def create_key_pair (key_size : int = 1024 ):
388
+ """Create a key pair."""
368
389
private_key = rsa .generate_private_key (public_exponent = 65537 , key_size = key_size )
369
390
public_key = private_key .public_key ()
370
391
@@ -388,13 +409,17 @@ def __init__(self, private_key: str, public_key: str):
388
409
self .public_key = public_key
389
410
390
411
def get_private_key (self ) -> str :
412
+ """Get the private key."""
391
413
return self .private_key
392
414
393
415
def get_public_key (self ) -> str :
416
+ """Get the public key."""
394
417
return self .public_key
395
418
396
419
397
420
class SnowflakeId :
421
+ """Class for generating snowflake ids."""
422
+
398
423
EPOCH = 1420041600000 # Custom epoch (in milliseconds)
399
424
WORKER_ID_BITS = 5
400
425
DATA_CENTER_ID_BITS = 5
@@ -408,11 +433,15 @@ class SnowflakeId:
408
433
def __init__ (self , worker_id , data_center_id ):
409
434
if worker_id > SnowflakeId .MAX_WORKER_ID or worker_id < 0 :
410
435
raise ValueError (
411
- f"Worker ID can't be greater than { SnowflakeId .MAX_WORKER_ID } or less than 0"
436
+ "Worker ID can't be greater than "
437
+ + str (SnowflakeId .MAX_WORKER_ID )
438
+ + " or less than 0"
412
439
)
413
440
if data_center_id > SnowflakeId .MAX_DATA_CENTER_ID or data_center_id < 0 :
414
441
raise ValueError (
415
- f"Data center ID can't be greater than { SnowflakeId .MAX_DATA_CENTER_ID } or less than 0"
442
+ "Data center ID can't be greater than "
443
+ + str (SnowflakeId .MAX_DATA_CENTER_ID )
444
+ + " or less than 0"
416
445
)
417
446
418
447
self .worker_id = worker_id
@@ -421,6 +450,7 @@ def __init__(self, worker_id, data_center_id):
421
450
self .last_timestamp = - 1
422
451
423
452
def generate_id (self ):
453
+ """Generate a snowflake id."""
424
454
timestamp = self ._current_millis ()
425
455
426
456
if timestamp < self .last_timestamp :
0 commit comments