8000 Improve smartprotocol error handling and retries · python-kasa/python-kasa@1a57e63 · GitHub
[go: up one dir, main page]

Skip to content

Commit 1a57e63

Browse files
committed
Improve smartprotocol error handling and retries
1 parent 35a4521 commit 1a57e63

File tree

8 files changed

+218
-53
lines changed

8 files changed

+218
-53
lines changed

kasa/aestransport.py

Lines changed: 64 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -17,10 +17,21 @@
1717
from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
1818

1919
from .credentials import Credentials
20-
from .exceptions import AuthenticationException, SmartDeviceException
20+
from .exceptions import (
21+
AuthenticationException,
22+
RetryableException,
23+
SmartDeviceException,
24+
TimeoutException,
25+
)
2126
from .json import dumps as json_dumps
2227
from .json import loads as json_loads
2328
from .protocol import BaseTransport
29+
from .smartprotocolerrors import (
30+
AUTHENTICATION_ERRORS,
31+
RETRYABLE_ERRORS,
32+
TIMEOUT_ERRORS,
33+
ErrorCode,
34+
)
2435

2536
_LOGGER = logging.getLogger(__name__)
2637

@@ -110,6 +121,19 @@ async def client_post(self, url, params=None, data=None, json=None, headers=None
110121

111122
return resp.status_code, response_data
112123

124+
def _handle_response_error_code(self, resp_dict: dict, msg: str):
125+
if (error_code := ErrorCode(resp_dict.get("error_code"))) != ErrorCode.SUCCESS:
126+
msg = f"{msg}: {self.host}: {error_code.name}({error_code.value})"
127+
if error_code in TIMEOUT_ERRORS:
128+
raise TimeoutException(msg)
129+
if error_code in RETRYABLE_ERRORS:
130+
raise RetryableException(msg)
131+
if error_code in AUTHENTICATION_ERRORS:
132+
self._handshake_done = False
133+
self._login_token = None
134+
raise AuthenticationException(msg)
135+
raise SmartDeviceException(msg)
136+
113137
async def send_secure_passthrough(self, request: str):
114138
"""Send encrypted message as passthrough."""
115139
url = f"http://{self.host}/app"
@@ -123,17 +147,22 @@ async def send_secure_passthrough(self, request: str):
123147
}
124148
status_code, resp_dict = await self.client_post(url, json=passthrough_request)
125149
# _LOGGER.debug(f"secure_passthrough response is {status_code}: {resp_dict}")
126-
if status_code == 200 and resp_dict["error_code"] == 0:
127-
response = self._encryption_session.decrypt( # type: ignore
128-
resp_dict["result"]["response"].encode()
150+
151+
if status_code != 200:
152+
raise SmartDeviceException(
153+
f"Unable to send message for {self.host}, status code is {status_code}"
129154
)
130-
_LOGGER.debug(f"decrypted secure_passthrough response is {response}")
131-
resp_dict = json_loads(response)
132-
return resp_dict
133-
else:
134-
self._handshake_done = False
135-
self._login_token = None
136-
raise AuthenticationException("Could not complete send")
155+
156+
self._handle_response_error_code(
157+
resp_dict, "Error sending secure_passthrough message"
158+
)
159+
160+
response = self._encryption_session.decrypt( # type: ignore
161+
resp_dict["result"]["response"].encode()
162+
)
163+
_LOGGER.debug(f"decrypted secure_passthrough response is {response}")
164+
resp_dict = json_loads(response)
165+
return resp_dict
137166

138167
async def perform_login(self, login_request: Union[str, dict], *, login_v2: bool):
139168
"""Login to the device."""
@@ -207,29 +236,33 @@ async def perform_handshake(self):
207236

208237
_LOGGER.debug(f"Device responded with: {resp_dict}")
209238

210-
if status_code == 200 and resp_dict["error_code"] == 0:
211-
_LOGGER.debug("Decoding handshake key...")
212-
handshake_key = resp_dict["result"]["key"]
213-
214-
self._session_cookie = self._http_client.cookies.get( # type: ignore
215-
self.SESSION_COOKIE_NAME
239+
if status_code != 200:
240+
raise SmartDeviceException(
241+
f"Unable to complete handshake for {self.host}, "
242+
+ "response status code is {status_code}"
216243
)
217-
if not self._session_cookie:
218-
self._session_cookie = self._http_client.cookies.get( # type: ignore
219-
"SESSIONID"
220-
)
221244

222-
self._session_expire_at = time.time() + 86400
223-
self._encryption_session = AesEncyptionSession.create_from_keypair(
224-
handshake_key, key_pair
245+
self._handle_response_error_code(resp_dict, "Unable to complete handshake")
246+
247+
_LOGGER.debug("Decoding handshake key...")
248+
handshake_key = resp_dict["result"]["key"]
249+
250+
self._session_cookie = self._http_client.cookies.get( # type: ignore
251+
self.SESSION_COOKIE_NAME
252+
)
253+
if not self._session_cookie:
254+
self._session_cookie = self._http_client.cookies.get( # type: ignore
255+
"SESSIONID"
225256
)
226257

227-
self._handshake_done = True
258+
self._session_expire_at = time.time() + 86400
259+
self._encryption_session = AesEncyptionSession.create_from_keypair(
260+
handshake_key, key_pair
261+
)
228262

229-
_LOGGER.debug("Handshake with %s complete", self.host)
263+
self._handshake_done = True
230264

231-
else:
232-
raise AuthenticationException("Could not complete handshake")
265+
_LOGGER.debug("Handshake with %s complete", self.host)
233266

234267
def _handshake_session_expired(self):
235268
"""Return true if session has expired."""
@@ -247,19 +280,14 @@ async def send(self, request: str):
247280
if self.needs_login:
248281
raise SmartDeviceException("Login must be complete before trying to send")
249282

250-
resp_dict = await self.send_secure_passthrough(request)
251-
if resp_dict["error_code"] != 0:
252-
self._handshake_done = False
253-
self._login_token = None
254-
raise SmartDeviceException(
255-
f"Could not complete send, response was {resp_dict}",
256-
)
257-
return resp_dict
283+
return await self.send_secure_passthrough(request)
258284

259285
async def close(self) -> None:
260286
"""Close the protocol."""
261287
client = self._http_client
262288
self._http_client = None
289+
self._handshake_done = False
290+
self._login_token = None
263291
if client:
264292
await client.aclose()
265293

kasa/exceptions.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,3 +11,11 @@ class UnsupportedDeviceException(SmartDeviceException):
1111

1212
class AuthenticationException(SmartDeviceException):
1313
"""Base exception for device authentication errors."""
14+
15+
16+
class RetryableException(SmartDeviceException):
17+
"""Retryable exception for device errors."""
18+
19+
20+
class TimeoutException(SmartDeviceException):
21+
"""Timeout exception for device errors."""

kasa/klaptransport.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -377,6 +377,7 @@ async def close(self) -> None:
377377
"""Close the transport."""
378378
client = self._http_client
379379
self._http_client = None
380+
self._handshake_done = False
380381
if client:
381382
await client.aclose()
382383

kasa/smartprotocol.py

Lines changed: 55 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,20 @@
1616

1717
from .aestransport import AesTransport
1818
from .credentials import Credentials
19-
from .exceptions import AuthenticationException, SmartDeviceException
19+
from .exceptions import (
20+
AuthenticationException,
21+
RetryableException,
22+
SmartDeviceException,
23+
TimeoutException,
24+
)
2025
from .json import dumps as json_dumps
2126
from .protocol import BaseTransport, TPLinkProtocol, md5
27+
from .smartprotocolerrors import (
28+
AUTHENTICATION_ERRORS,
29+
RETRYABLE_ERRORS,
30+
TIMEOUT_ERRORS,
31+
ErrorCode,
32+
)
2233

2334
_LOGGER = logging.getLogger(__name__)
2435
logging.getLogger("httpx").propagate = False
@@ -64,6 +75,22 @@ async def query(self, request: Union[str, Dict], retry_count: int = 3) -> Dict:
6475
"""Query the device retrying for retry_count on failure."""
6576
async with self._query_lock:
6677
resp_dict = await self._query(request, retry_count)
78+
79+
if (
80+
error_code := ErrorCode(resp_dict.get("error_code"))
81+
) != ErrorCode.SUCCESS:
82+
msg = (
83+
f"Error querying device: {self.host}: "
84+
+ f"{error_code.name}({error_code.value})"
85+
)
86+
if error_code in TIMEOUT_ERRORS:
87+
raise TimeoutException(msg)
88+
if error_code in RETRYABLE_ERRORS:
89+
raise RetryableException(msg)
90+
if error_code in AUTHENTICATION_ERRORS:
91+
raise AuthenticationException(msg)
92+
raise SmartDeviceException(msg)
93+
6794
if "result" in resp_dict:
6895
return resp_dict["result"]
6996
return {}
@@ -86,20 +113,41 @@ async def _query(self, request: Union[str, Dict], retry_count: int = 3) -> Dict:
86113
f"Unable to connect to the device: {self.host}: {cex}"
87114
) from cex
88115
except TimeoutError as tex:
89-
await self.close()
90-
raise SmartDeviceException(
91-
f"Unable to connect to the device, timed out: {self.host}: {tex}"
92-
) from tex
116+
if retry >= retry_count:
117+
await self.close()
118+
raise SmartDeviceException(
119+
"Unable to connect to the device, "
120+
+ f"timed out: {self.host}: {tex}"
121+
) from tex
122+
await asyncio.sleep(2)
123+
continue
93124
except AuthenticationException as auex:
125+
await self.close()
94126
_LOGGER.debug("Unable to authenticate with %s, not retrying", self.host)
95127
raise auex
128+
except RetryableException as ex:
129+
if retry >= retry_count:
130+
await self.close()
131+
_LOGGER.debug("Giving up on %s after %s retries", self.host, retry)
132+
raise ex
133+
continue
134+
except TimeoutException as ex:
135+
if retry >= retry_count:
136+
await self.close()
137+
_LOGGER.debug("Giving up on %s after %s retries", self.host, retry)
138+
raise ex
139+
await asyncio.sleep(2)
140+
continue
96141
except Exception as ex:
97-
await self.close()
98142
if retry >= retry_count:
143+
await self.close()
99144
_LOGGER.debug("Giving up on %s after %s retries", self.host, retry)
100145
raise SmartDeviceException(
101-
f"Unable to connect to the device: {self.host}: {ex}"
146+
f"Unable to query the device {self.host}:{self.port}: {ex}"
102147
) from ex
148+
_LOGGER.debug(
149+
"Unable to query the device %s, retrying: %s", self.host, ex
150+
)
103151
continue
104152

105153
# make mypy happy, this should never be reached..

kasa/smartprotocolerrors.py

Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,86 @@
1+
"""Module for defining SMART protocol errors."""
2+
from enum import Enum
3+
4+
5+
class ErrorCode(Enum):
6+
"""Enum for SMART Error Codes."""
7+
8+
SUCCESS = 0
9+
10+
UNKNOWN_ERROR_CODE = 999999
11+ 10000
ERROR_CODE_NONE = 999998
12+
ERROR_CODE_INVALID = 999997
13+
14+
ACCOUNT_ERROR = -2101
15+
AES_DECODE_FAIL_ERROR = -1005
16+
ANTITHEFT_CONFLICT_ERROR = -2002
17+
ANTITHEFT_ERROR = -2001
18+
ANTITHEFT_SAVE_ERROR = -2003
19+
CLOUD_FAILED_ERROR = -1007
20+
CMD_COMMAND_CANCEL_ERROR = 1001
21+
COMMON_FAILED_ERROR = -1
22+
COUNTDOWN_CONFLICT_ERROR = -1902
23+
COUNTDOWN_ERROR = -1901
24+
COUNTDOWN_SAVE_ERROR = -1903
25+
DEVICE_ERROR = -1301
26+
DEVICE_NEXT_EVENT_ERROR = -1302
27+
DST_ERROR = -2301
28+
DST_SAVE_ERROR = -2302
29+
FIRMWARE_ERROR = -1401
30+
FIRMWARE_VER_ERROR_ERROR = -1402
31+
HAND_SHAKE_FAILED_ERROR = 1100
32+
HTTP_TRANSPORT_FAILED_ERROR = 1112
33+
JSON_DECODE_FAIL_ERROR = -1003
34+
JSON_ENCODE_FAIL_ERROR = -1004
35+
INVALID_PUBLIC_KEY_ERROR = -1010 # Unverified
36+
LOGIN_ERROR = -1501
37+
LOGIN_FAILED_ERROR = 1111
38+
MULTI_REQUEST_FAILED_ERROR = 1200
39+
NULL_TRANSPORT_ERROR = 1000
40+
PARAMS_ERROR = -1008
41+
QUICK_SETUP_ERROR = -1201
42+
REQUEST_LEN_ERROR_ERROR = -1006
43+
SCHEDULE_CONFLICT_ERROR = -1803
44+
SCHEDULE_ERROR = -1801
45+
SCHEDULE_FULL_ERROR = -1802
46+
SCHEDULE_INDEX_ERROR = -1805
47+
SCHEDULE_SAVE_ERROR = -1804
48+
SESSION_PARAM_ERROR = -1101
49+
SESSION_TIMEOUT_ERROR = 9999
50+
STAT_ERROR = -2201
51+
STAT_SAVE_ERROR = -2202
52+
TIME_ERROR = -1601
53+
TIME_SAVE_ERROR = -1603
54+
TIME_SYS_ERROR = -1602
55+
TRANSPORT_NOT_AVAILABLE_ERROR = 1002
56+
UNKNOWN_METHOD_ERROR = -1002
57+
UNSPECIFIC_ERROR = -1001
58+
WIRELESS_ERROR = -1701
59+
WIRELESS_UNSUPPORTED_ERROR = -1702
60+
61+
@classmethod
62+
def _missing_(cls, value):
63+
if value is None:
64+
return cls.ERROR_CODE_NONE
65+
if isinstance(value, int):
66+
return cls.UNKNOWN_ERROR_CODE
67+
else:
68+
return cls.ERROR_CODE_INVALID
69+
70+
71+
RETRYABLE_ERRORS = [
72+
ErrorCode.TRANSPORT_NOT_AVAILABLE_ERROR,
73+
ErrorCode.HTTP_TRANSPORT_FAILED_ERROR,
74+
ErrorCode.UNSPECIFIC_ERROR,
75+
]
76+
77+
AUTHENTICATION_ERRORS = [
78+
ErrorCode.LOGIN_ERROR,
79+
ErrorCode.LOGIN_FAILED_ERROR,
80+
ErrorCode.AES_DECODE_FAIL_ERROR,
81+
ErrorCode.HAND_SHAKE_FAILED_ERROR,
82+
]
83+
84+
TIMEOUT_ERRORS = [
85+
ErrorCode.SESSION_TIMEOUT_ERROR,
86+
]

kasa/tapo/tapobulb.py

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -166,13 +166,7 @@ async def set_hsv(
166166
if value is not None:
167167
request_payload["brightness"] = value
168168

169-
return await self.protocol.query(
170-
{
171-
"set_device_info": {
172-
**request_payload
173-
}
174-
}
175-
)
169+
return await self.protocol.query({"set_device_info": {**request_payload}})
176170

177171
async def set_color_temp(
178172
self, temp: int, *, brightness=None, transition: Optional[int] = None

kasa/tests/newfakes.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -315,11 +315,11 @@ async def send(self, request: str):
315315
method = request_dict["method"]
316316
params = request_dict["params"]
317317
if method == "component_nego" or method[:4] == "get_":
318-
return {"result": self.info[method]}
318+
return {"result": self.info[method], "error_code": 0}
319319
elif method[:4] == "set_":
320320
target_method = f"get_{method[4:]}"
321321
self.info[target_method].update(params)
322-
return {"result": ""}
322+
return {"error_code": 0}
323323

324324
async def close(self) -> None:
325325
pass

kasa/tests/test_klapprotocol.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,7 @@ async def test_protocol_retry_recoverable_error(
8686
async def test_protocol_reconnect(mocker, retry_count, protocol_class, transport_class):
8787
host = "127.0.0.1"
8888
remaining = retry_count
89-
mock_response = {"result": {"great": "success"}}
89+
mock_response = {"result": {"great": "success"}, "error_code": 0}
9090

9191
def _fail_one_less_than_retry_count(*_, **__):
9292
nonlocal remaining

0 commit comments

Comments
 (0)
0