10000 Add retries to protocol queries (#65) · dave-vsdevs/python-kasa@9dc0cba · GitHub
[go: up one dir, main page]

Skip to content

Commit 9dc0cba

Browse files
authored
Add retries to protocol queries (python-kasa#65)
* Add retries to query(), defaults to 3 + add tests * Catch also json decoding errors for retries * add missing exceptions file, fix old protocol tests
1 parent 644a10a commit 9dc0cba

File tree

6 files changed

+156
-100
lines changed

6 files changed

+156
-100
lines changed

kasa/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,9 +13,10 @@
1313
"""
1414
from importlib_metadata import version # type: ignore
1515
from kasa.discover import Discover
16+
from kasa.exceptions import SmartDeviceException
1617
from kasa.protocol import TPLinkSmartHomeProtocol
1718
from kasa.smartbulb import SmartBulb
18-
from kasa.smartdevice import DeviceType, EmeterStatus, SmartDevice, SmartDeviceException
19+
from kasa.smartdevice import DeviceType, EmeterStatus, SmartDevice
1920
from kasa.smartdimmer import SmartDimmer
2021
from kasa.smartplug import SmartPlug
2122
from kasa.smartstrip import SmartStrip

kasa/exceptions.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
"""python-kasa exceptions."""
2+
3+
4+
class SmartDeviceException(Exception):
5+
"""Base exception for device errors."""

kasa/protocol.py

Lines changed: 49 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@
1616
from pprint import pformat as pf
1717
from typing import Dict, Union
1818

19+
from .exceptions import SmartDeviceException
20+
1921
_LOGGER = logging.getLogger(__name__)
2022

2123

@@ -27,48 +29,65 @@ class TPLinkSmartHomeProtocol:
2729
DEFAULT_TIMEOUT = 5
2830

2931
@staticmethod
30-
async def query(host: str, request: Union[str, Dict]) -> Dict:
32+
async def query(host: str, request: Union[str, Dict], retry_count: int = 3) -> Dict:
3133
"""Request information from a TP-Link SmartHome Device.
3234
3335
:param str host: host name or ip address of the device
3436
:param request: command to send to the device (can be either dict or
3537
json string)
38+
:param retry_count: how many retries to do in case of failure
3639
:return: response dict
3740
"""
3841
if isinstance(request, dict):
3942
request = json.dumps(request)
4043

4144
timeout = TPLinkSmartHomeProtocol.DEFAULT_TIMEOUT
4245
writer = None
43-
try:
44-
task = asyncio.open_connection(host, TPLinkSmartHomeProtocol.DEFAULT_PORT)
45-
reader, writer = await asyncio.wait_for(task, timeout=timeout)
46-
_LOGGER.debug("> (%i) %s", len(request), request)
47-
writer.write(TPLinkSmartHomeProtocol.encrypt(request))
48-
await writer.drain()
49-
50-
buffer = bytes()
51-
# Some devices send responses with a length header of 0 and
52-
# terminate with a zero size chunk. Others send the length and
53-
# will hang if we attempt to read more data.
54-
length = -1
55-
while True:
56-
chunk = await reader.read(4096)
57-
if length == -1:
58-
length = struct.unpack(">I", chunk[0:4])[0]
59-
buffer += chunk
60-
if (length > 0 and len(buffer) >= length + 4) or not chunk:
61-
break
62-
finally:
63-
if writer:
64-
writer.close()
65-
await writer.wait_closed()
66-
67-
response = TPLinkSmartHomeProtocol.decrypt(buffer[4:])
68-
json_payload = json.loads(response)
69-
_LOGGER.debug("< (%i) %s", len(response), pf(json_payload))
70-
71-
return json_payload
46+
for retry in range(retry_count + 1):
47+
try:
48+
task = asyncio.open_connection(
49+
host, TPLinkSmartHomeProtocol.DEFAULT_PORT
50+
)
51+
reader, writer = await asyncio.wait_for(task, timeout=timeout)
52+
_LOGGER.debug("> (%i) %s", len(request), request)
53+
writer.write(TPLinkSmartHomeProtocol.encrypt(request))
54+
await writer.drain()
55+
56+
buffer = bytes()
57+
# Some devices send responses with a length header of 0 and
58+
# terminate with a zero size chunk. Others send the length and
59+
# will hang if we attempt to read more data.
60+
length = -1
61+
while True:
62+
chunk = await reader.read(4096)
63+
if length == -1:
64+
length = struct.unpack(">I", chunk[0:4])[0]
65+
buffer += chunk
66+
if (length > 0 and len(buffer) >= length + 4) or not chunk:
67+
break
68+
69+
response = TPLinkSmartHomeProtocol.decrypt(buffer[4:])
70+
json_payload = json.loads(response)
71+
_LOGGER.debug("< (%i) %s", len(response), pf(json_payload))
72+
73+
return json_payload
74+
75+
except Exception as ex:
76+
if retry >= retry_count:
77+
_LOGGER.debug("Giving up after %s retries", retry)
78+
raise SmartDeviceException(
79+
"Unable to query the device: %s" % ex
80+
) from ex
81+
82+
_LOGGER.debug("Unable to query the device, retrying: %s", ex)
83+
84+
finally:
85+
if writer:
86+
writer.close()
87+
await writer.wait_closed()
88+
89+
# make mypy happy, this should never be reached..
90+
raise SmartDeviceException("Query reached somehow to unreachable")
7291

7392
@staticmethod
7493
def encrypt(request: str) -> bytes:

kasa/smartdevice.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,8 @@
1919
from enum import Enum
2020
from typing import Any, Dict, List, Optional
2121

22-
from kasa.protocol import TPLinkSmartHomeProtocol
22+
from .exceptions import SmartDeviceException
23+
from .protocol import TPLinkSmartHomeProtocol
2324

2425
_LOGGER = logging.getLogger(__name__)
2526

@@ -47,10 +48,6 @@ class WifiNetwork:
4748
rssi: Optional[int] = None
4849

4950

50-
class SmartDeviceException(Exception):
51-
"""Base exception for device errors."""
52-
53-
5451
class EmeterStatus(dict):
5552
"""Container for converting different representations of emeter data.
5653

kasa/tests/conftest.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import json
44
import os
55
from os.path import basename
6+
from unittest.mock import MagicMock
67

78
import pytest # type: ignore # see https://github.com/pytest-dev/pytest/issues/3342
89

@@ -151,3 +152,14 @@ def pytest_collection_modifyitems(config, items):
151152
return
152153
else:
153154
print("Running against ip %s" % config.getoption("--ip"))
155+
156+
157+
# allow mocks to be awaited
158+
# https://stackoverflow.com/questions/51394411/python-object-magicmock-cant-be-used-in-await-expression/51399767#51399767
159+
160+
161+
async def async_magic():
162+
pass
163+
164+
165+
MagicMock.__await__ = lambda x: async_magic().__await__()

kasa/tests/test_protocol.py

Lines changed: 86 additions & 64 deletions
E377
Original file line numberDiff line numberDiff line change
@@ -1,73 +1,95 @@
11
import json
2-
from unittest import TestCase
32

3+
import pytest
4+
5+
from ..exceptions import SmartDeviceException
46
from ..protocol import TPLinkSmartHomeProtocol
57

68

7-
class TestTPLinkSmartHomeProtocol(TestCase):
8-
def test_encrypt(self):
9-
d = json.dumps({"foo": 1, "bar": 2})
10-
encrypted = TPLinkSmartHomeProtocol.encrypt(d)
11-
# encrypt adds a 4 byte header
12-
encrypted = encrypted[4:]
13-
self.assertEqual(d, TPLinkSmartHomeProtocol.decrypt(encrypted))
14-
15-
def test_encrypt_unicode(self):
16-
d = "{'snowman': '\u2603'}"
17-
18-
e = bytes(
19-
[
20-
208,
21-
247,
22-
132,
23-
234,
24-
133,
25-
242,
26-
159,
27-
254,
28-
144,
29-
183,
30-
141,
31-
173,
32-
138,
33-
104,
34-
240,
35-
115,
36-
84,
37-
41,
38-
]
39-
)
9+
@pytest.mark.parametrize("retry_count", [1, 3, 5])
10+
async def test_protocol_retries(mocker, retry_count):
11+
def aio_mock_writer(_, __):
12+
reader = mocker.patch("asyncio.StreamReader")
13+
writer = mocker.patch("asyncio.StreamWriter")
4014

41-
encrypted = TPLinkSmartHomeProtocol.encrypt(d)
42-
# encrypt adds a 4 byte header
43-
encrypted = encrypted[4:]
44-
45-
self.assertEqual(e, encrypted)
46-
47-
def test_decrypt_unicode(self):
48-
e = bytes(
49-
[
50-
208,
51-
247,
52-
132,
53-
234,
54-
133,
55-
242,
56-
159,
57-
254,
58-
144,
59-
183,
60-
141,
61-
173,
62-
138,
63-
104,
64-
240,
65-
115,
66-
84,
67-
41,
68-
]
15+
mocker.patch(
16+
"asyncio.StreamWriter.write", side_effect=Exception("dummy exception")
6917
)
7018

71-
d = "{'snowman': '\u2603'}"
19+
return reader, writer
20+
21+
conn = mocker.patch("asyncio.open_connection", side_effect=aio_mock_writer)
22+
with pytest.raises(SmartDeviceException):
23+
await TPLinkSmartHomeProtocol.query("127.0.0.1", {}, retry_count=retry_count)
24+
25+
assert conn.call_count == retry_count + 1
26+
27+
28+
def test_encrypt():
29+
d = json.dumps({"foo": 1, "bar": 2})
30+
encrypted = TPLinkSmartHomeProtocol.encrypt(d)
31+
# encrypt adds a 4 byte header
32+
encrypted = encrypted[4:]
33+
assert d == TPLinkSmartHomeProtocol.decrypt(encrypted)
34+
35+
36+
def test_encrypt_unicode():
37+
d = "{'snowman': '\u2603'}"
38+
39+
e = bytes(
40+
[
41+
208,
42+
247,
43+
132,
44+
234,
45+
133,
46+
242,
47+
159,
48+
254,
49+
144,
50+
183,
51+
141,
52+
173,
53+
138,
54+
104,
55+
240,
56+
115,
57+
84,
58+
41,
59+
]
60+
)
61+
62+
encrypted = TPLinkSmartHomeProtocol.encrypt(d)
63+
# encrypt adds a 4 byte header
64+
encrypted = encrypted[4:]
65+
66+
assert e == encrypted
67+
68+
69+
def test_decrypt_unicode():
70+
e = bytes(
71+
[
72+
208,
73+
247,
74+
132,
75+
234,
76+
133,
77+
242,
78+
159,
79+
254,
80+
144,
81+
183,
82+
141,
83+
173,
84+
138,
85+
104,
86+
240,
87+
115,
88+
84,
89+
41,
90+
]
91+
)
92+
93+
d = "{'snowman': '\u2603'}"
7294

73-
self.assertEqual(d, TPLinkSmartHomeProtocol.decrypt(e))
95+
assert d == TPLinkSmartHomeProtocol.decrypt(e)

0 commit comments

Comments
 (0)
0