8000 Keep connection open and lock to prevent duplicate requests (#213) · python-kasa/python-kasa@e31cc66 · GitHub
[go: up one dir, main page]

Skip to content

Commit e31cc66

Browse files
bdracorytilahti
andauthored
Keep connection open and lock to prevent duplicate requests (#213)
* Keep connection open and lock to prevent duplicate requests * option to not update children * tweaks * typing * tweaks * run tests in the same event loop * memorize model * Update kasa/protocol.py Co-authored-by: Teemu R. <tpr@iki.fi> * Update kasa/protocol.py Co-authored-by: Teemu R. <tpr@iki.fi> * Update kasa/protocol.py Co-authored-by: Teemu R. <tpr@iki.fi> * Update kasa/protocol.py Co-authored-by: Teemu R. <tpr@iki.fi> * dry * tweaks * warn when the event loop gets switched out from under us * raise on unable to connect multiple times * fix patch target * tweaks * isrot * reconnect test * prune * fix mocking * fix mocking * fix test under python 3.7 * fix test under python 3.7 * less patching * isort * use mocker to patch * disable on old python since mocking doesnt work * avoid disconnect/reconnect cycles * isort * Fix hue validation * Fix latitude_i/longitude_i units Co-authored-by: Teemu R. <tpr@iki.fi>
1 parent f1b28e7 commit e31cc66

File tree

11 files changed

+238
-93
lines changed

11 files changed

+238
-93
lines changed

devtools/dump_devinfo.py

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -78,16 +78,17 @@ def cli(host, debug):
7878
),
7979
]
8080

81-
protocol = TPLinkSmartHomeProtocol()
82-
8381
successes = []
8482

8583
for test_call in items:
84+
85+
async def _run_query():
86+
protocol = TPLinkSmartHomeProtocol(host)
87+
return await protocol.query({test_call.module: {test_call.method: None}})
88+
8689
try:
8790
click.echo(f"Testing {test_call}..", nl=False)
88-
info = asyncio.run(
89-
protocol.query(host, {test_call.module: {test_call.method: None}})
90-
)
91+
info = asyncio.run(_run_query())
9192
resp = info[test_call.module]
9293
except Exception as ex:
9394
click.echo(click.style(f"FAIL {ex}", fg="red"))
@@ -107,8 +108,12 @@ def cli(host, debug):
107108
108109
final = default_to_regular(final)
109110

111+
async def _run_final_query():
112+
protocol = TPLinkSmartHomeProtocol(host)
113+
return await protocol.query(final_query)
114+
110115
try:
111-
final = asyncio.run(protocol.query(host, final_query))
116+
final = asyncio.run(_run_final_query())
112117
except Exception as ex:
113118
click.echo(
114119
click.style(

kasa/discover.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,6 @@ def __init__(
4040
self.discovery_packets = discovery_packets
4141
self.interface = interface
4242
self.on_discovered = on_discovered
43-
self.protocol = TPLinkSmartHomeProtocol()
4443
self.target = (target, Discover.DISCOVERY_PORT)
4544
self.discovered_devices = {}
4645

@@ -61,7 +60,7 @@ def do_discover(self) -> None:
6160
"""Send number of discovery datagrams."""
6261
req = json.dumps(Discover.DISCOVERY_QUERY)
6362
_LOGGER.debug("[DISCOVERY] %s >> %s", self.target, Discover.DISCOVERY_QUERY)
64-
encrypted_req = self.protocol.encrypt(req)
63+
encrypted_req = TPLinkSmartHomeProtocol.encrypt(req)
6564
for i in range(self.discovery_packets):
6665
self.transport.sendto(encrypted_req[4:], self.target) # type: ignore
6766

@@ -71,7 +70,7 @@ def datagram_received(self, data, addr) -> None:
7170
if ip in self.discovered_devices:
7271
return
7372

74-
info = json.loads(self.protocol.decrypt(data))
73+
info = json.loads(TPLinkSmartHomeProtocol.decrypt(data))
7574
_LOGGER.debug("[DISCOVERY] %s << %s", ip, info)
7675

7776
device_class = Discover._get_device_class(info)
@@ -190,9 +189,9 @@ async def discover_single(host: str) -> SmartDevice:
190189
:rtype: SmartDevice
191190
:return: Object for querying/controlling found device.
192191
"""
193-
protocol = TPLinkSmartHomeProtocol()
192+
protocol = TPLinkSmartHomeProtocol(host)
194193

195-
info = await protocol.query(host, Discover.DISCOVERY_QUERY)
194+
info = await protocol.query(Discover.DISCOVERY_QUERY)
196195

197196
device_class = Discover._get_device_class(info)
198197
dev = device_class(host)

kasa/protocol.py

Lines changed: 104 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -10,11 +10,12 @@
1010
http://www.apache.org/licenses/LICENSE-2.0
1111
"""
1212
import asyncio
13+
import contextlib
1314
import json
1415
import logging
1516
import struct
1617
from pprint import pformat as pf
17-
from typing import Dict, Union
18+
from typing import Dict, Optional, Union
1819

1920
from .exceptions import SmartDeviceException
2021

@@ -28,8 +29,26 @@ class TPLinkSmartHomeProtocol:
2829
DEFAULT_PORT = 9999
2930
DEFAULT_TIMEOUT = 5
3031

31-
@staticmethod
32-
async def query(host: str, request: Union[str, Dict], retry_count: int = 3) -> Dict:
32+
BLOCK_SIZE = 4
33+
34+
def __init__(self, host: str) -> None:
35+
"""Create a protocol object."""
36+
self.host = host
37+
self.reader: Optional[asyncio.StreamReader] = None
38+
self.writer: Optional[asyncio.StreamWriter] = None
39+
self.query_lock: Optional[asyncio.Lock] = None
40+
self.loop: Optional[asyncio.AbstractEventLoop] = None
41+
42+
def _detect_event_loop_change(self) -> None:
43+
"""Check if this object has been reused betwen event loops."""
44+
loop = asyncio.get_running_loop()
45+
if not self.loop:
46+
self.loop = loop
47+
elif self.loop != loop:
48+
_LOGGER.warning("Detected protocol reuse between different event loop")
49+
self._reset()
50+
51+
async def query(self, request: Union[str, Dict], retry_count: int = 3) -> Dict:
3352
"""Request information from a TP-Link SmartHome Device.
3453
3554
:param str host: host name or ip address of the device
@@ -38,57 +57,106 @@ async def query(host: str, request: Union[str, Dict], retry_count: int = 3) -> D
3857
:param retry_count: how many retries to do in case of failure
3958
:return: response dict
4059
"""
60+
self._detect_event_loop_change()
61+
62+
if not self.query_lock:
63+
self.query_lock = asyncio.Lock()
64+
4165
if isinstance(request, dict):
4266
request = json.dumps(request)
67+
assert isinstance(request, str)
4368

4469
timeout = TPLinkSmartHomeProtocol.DEFAULT_TIMEOUT
45-
writer = None
70+
71+
async with self.query_lock:
72+
return await self._query(request, retry_count, timeout)
73+
74+
async def _connect(self, timeout: int) -> bool:
75+
"""Try to connect or reconnect to the device."""
76+
if self.writer:
77+
return True
78+
79+
with contextlib.suppress(Exception):
80+
self.reader = self.writer = None
81+
task = asyncio.open_connection(
82+
self.host, TPLinkSmartHomeProtocol.DEFAULT_PORT
83+
)
84+
self.reader, self.writer = await asyncio.wait_for(task, timeout=timeout)
85+
return True
86+
87+
return False
88+
89+
async def _execute_query(self, request: str) -> Dict:
90+
"""Execute a query on the device and wait for the response."""
91+
assert self.writer is not None
92+
assert self.reader is not None
93+
94+
_LOGGER.debug("> (%i) %s", len(request), request)
95+
self.writer.write(TPLinkSmartHomeProtocol.encrypt(request))
96+
await self.writer.drain()
97+
98+
packed_block_size = await self.reader.readexactly(self.BLOCK_SIZE)
99+
length = struct.unpack(">I", packed_block_size)[0]
100+
101+
buffer = await self.reader.readexactly(length)
102+
response = TPLinkSmartHomeProtocol.decrypt(buffer)
103+
json_payload = json.loads(response)
104+
_LOGGER.debug("< (%i) %s", len(response), pf(json_payload))
105+
return json_payload
106+
107+
async def close(self):
108+
"""Close the connection."""
109+
writer = self.writer
110+
self._reset()
111+
if writer:
112+
writer.close()
113+
with contextlib.suppress(Exception):
114+
await writer.wait_closed()
115+
116+
def _reset(self):
117+
"""Clear any varibles that should not survive between loops."""
118+
self.writer = None
119+
self.reader = None
120+
self.query_lock = None
121+
self.loop = None
122+
123+
async def _query(self, request: str, retry_count: int, timeout: int) -> Dict:
124+
"""Try to query a device."""
46125
for retry in range(retry_count + 1):
126+
if not await self._connect(timeout):
127+
await self.close()
128+
if retry >= retry_count:
129+
_LOGGER.debug("Giving up after %s retries", retry)
130+
raise SmartDeviceException(
131+
f"Unable to connect to the device: {self.host}"
132+
)
133+
continue
134+
47135
try:
48-
task = asyncio.open_connection(
49-
host, TPLinkSmartHomeProtocol.DEFAULT_PORT
136+
assert self.reader is not None
137+
assert self.writer is not None
138+
return await asyncio.wait_for(
139+
self._execute_query(request), timeout=timeout
50140
)
51-
reader, writer = await asyncio.wait_for(task, timeout=timeout)
52-
_LOGGER.debug("> (%i) %s", len(request), request)
53-
writer.write(TPLinkSmartHomeProtocol.encrypt(request))
54-
await writer.drain()
55-
56-
buffer = bytes()
57-
# Some devices send responses with a length header of 0 and
58-
# terminate with a zero size chunk. Others send the length and
59-
# will hang if we attempt to read more data.
60-
length = -1
61-
while True:
62-
chunk = await reader.read(4096)
63-
if length == -1:
64-
length = struct.unpack(">I", chunk[0:4])[0]
65-
buffer += chunk
66-
if (length > 0 and len(buffer) >= length + 4) or not chunk:
67-
break
68-
69-
response = TPLinkSmartHomeProtocol.decrypt(buffer[4:])
70-
json_payload = json.loads(response)
71-
_LOGGER.debug("< (%i) %s", len(response), pf(json_payload))
72-
73-
return json_payload
74-
75141
except Exception as ex:
142+
await self.close()
76143
if retry >= retry_count:
77144
_LOGGER.debug("Giving up after %s retries", retry)
78145
raise SmartDeviceException(
79-
"Unable to query the device: %s" % ex
14 8000 6+
f"Unable to query the device: {ex}"
80147
) from ex
81148

82149
_LOGGER.debug("Unable to query the device, retrying: %s", ex)
83150

84-
finally:
85-
if writer:
86-
writer.close()
87-
await writer.wait_closed()
88-
89151
# make mypy happy, this should never be reached..
152+
await self.close()
90153
raise SmartDeviceException("Query reached somehow to unreachable")
91154

155+
def __del__(self):
156+
if self.writer and self.loop and self.loop.is_running():
157+
self.writer.close()
158+
self._reset()
159+
92160
@staticmethod
93161
def _xor_payload(unencrypted):
94162
key = TPLinkSmartHomeProtocol.INITIALIZATION_VECTOR

kasa/smartdevice.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -194,7 +194,7 @@ def __init__(self, host: str) -> None:
194194
"""
195195
self.host = host
196196

197-
self.protocol = TPLinkSmartHomeProtocol()
197+
self.protocol = TPLinkSmartHomeProtocol(host)
198198
self.emeter_type = "emeter"
199199
_LOGGER.debug("Initializing %s of type %s", self.host, type(self))
200200
self._device_type = DeviceType.Unknown
@@ -234,7 +234,7 @@ async def _query_helper(
234234
request = self._create_request(target, cmd, arg, child_ids)
235235

236236
try:
237-
response = await self.protocol.query(host=self.host, request=request)
237+
response = await self.protocol.query(request=request)
238238
except Exception as ex:
239239
raise SmartDeviceException(f"Communication error on {target}:{cmd}") from ex
240240

@@ -272,7 +272,7 @@ async def get_sys_info(self) -> Dict[str, Any]:
272272
"""Retrieve system information."""
273273
return await self._query_helper("system", "get_sysinfo")
274274

275-
async def update(self):
275+
async def update(self, update_children: bool = True):
276276
"""Query the device to update the data.
277277
278278
Needed for properties that are decorated with `requires_update`.
@@ -285,7 +285,7 @@ async def update(self):
285285
# See #105, #120, #161
286286
if self._last_update is None:
287287
_LOGGER.debug("Performing the initial update to obtain sysinfo")
288-
self._last_update = await self.protocol.query(self.host, req)
288+
self._last_update = await self.protocol.query(req)
289289
self._sys_info = self._last_update["system"]["get_sysinfo"]
290290
# If the device has no emeter, we are done for the initial update
291291
# Otherwise we will follow the regular code path to also query
@@ -299,7 +299,7 @@ async def update(self):
299299
)
300300
req.update(self._create_emeter_request())
301301

302-
self._last_update = await self.protocol.query(self.host, req)
302+
self._last_update = await self.protocol.query(req)
303303
self._sys_info = self._last_update["system"]["get_sysinfo"]
304304

305305
def update_from_discover_info(self, info):
@@ -383,8 +383,8 @@ def location(self) -> Dict:
383383
loc["latitude"] = sys_info["latitude"]
384384
loc["longitude"] = sys_info["longitude"]
385385
elif "latitude_i" in sys_info and "longitude_i" in sys_info:
386-
loc["latitude"] = sys_info["latitude_i"]
387-
loc["longitude"] = sys_info["longitude_i"]
386+
loc["latitude"] = sys_info["latitude_i"] / 10000
387+
loc["longitude"] = sys_info["longitude_i"] / 10000
388388
else:
389389
_LOGGER.warning("Unsupported device location.")
390390

kasa/smartstrip.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -87,12 +87,12 @@ def is_on(self) -> bool:
8787
"""Return if any of the outlets are on."""
8888
return any(plug.is_on for plug in self.children)
8989

90-
async def update(self):
90+
async def update(self, update_children: bool = True):
9191
"""Update some of the attributes.
9292
9393
Needed for methods that are decorated with `requires_update`.
9494
"""
95-
await super().update()
95+
await super().update(update_children)
9696

9797
# Initialize the child devices during the first update.
9898
if not self.children:
@@ -103,7 +103,7 @@ async def update(self):
103103
SmartStripPlug(self.host, parent=self, child_id=child["id"])
104104
)
105105

106-
if self.has_emeter:
106+
if update_children and self.has_emeter:
107107
for plug in self.children:
108108
await plug.update()
109109

@@ -243,13 +243,13 @@ def __init__(self, host: str, parent: "SmartStrip", child_id: str) -> None:
243243
self._sys_info = parent._sys_info
244244
self._device_type = DeviceType.StripSocket
245245

246-
async def update(self):
246+
async def update(self, update_children: bool = True):
247247
"""Query the device to update the data.
248248
249249
Needed for properties that are decorated with `requires_update`.
250250
"""
251251
self._last_update = await self.parent.protocol.query(
252-
self.host, self._create_emeter_request()
252+
self._create_emeter_request()
253253
)
254254

255255
def _create_request(

0 commit comments

Comments
 (0)
0