10
10
http://www.apache.org/licenses/LICENSE-2.0
11
11
"""
12
12
import asyncio
13
+ import contextlib
13
14
import json
14
15
import logging
15
16
import struct
16
17
from pprint import pformat as pf
17
- from typing import Dict , Union
18
+ from typing import Dict , Optional , Union
18
19
19
20
from .exceptions import SmartDeviceException
20
21
@@ -28,8 +29,26 @@ class TPLinkSmartHomeProtocol:
28
29
DEFAULT_PORT = 9999
29
30
DEFAULT_TIMEOUT = 5
30
31
31
- @staticmethod
32
- async def query (host : str , request : Union [str , Dict ], retry_count : int = 3 ) -> Dict :
32
+ BLOCK_SIZE = 4
33
+
34
+ def __init__ (self , host : str ) -> None :
35
+ """Create a protocol object."""
36
+ self .host = host
37
+ self .reader : Optional [asyncio .StreamReader ] = None
38
+ self .writer : Optional [asyncio .StreamWriter ] = None
39
+ self .query_lock : Optional [asyncio .Lock ] = None
40
+ self .loop : Optional [asyncio .AbstractEventLoop ] = None
41
+
42
+ def _detect_event_loop_change (self ) -> None :
43
+ """Check if this object has been reused betwen event loops."""
44
+ loop = asyncio .get_running_loop ()
45
+ if not self .loop :
46
+ self .loop = loop
47
+ elif self .loop != loop :
48
+ _LOGGER .warning ("Detected protocol reuse between different event loop" )
49
+ self ._reset ()
50
+
51
+ async def query (self , request : Union [str , Dict ], retry_count : int = 3 ) -> Dict :
33
52
"""Request information from a TP-Link SmartHome Device.
34
53
35
54
:param str host: host name or ip address of the device
@@ -38,57 +57,106 @@ async def query(host: str, request: Union[str, Dict], retry_count: int = 3) -> D
38
57
:param retry_count: how many retries to do in case of failure
39
58
:return: response dict
40
59
"""
60
+ self ._detect_event_loop_change ()
61
+
62
+ if not self .query_lock :
63
+ self .query_lock = asyncio .Lock ()
64
+
41
65
if isinstance (request , dict ):
42
66
request = json .dumps (request )
67
+ assert isinstance (request , str )
43
68
44
69
timeout = TPLinkSmartHomeProtocol .DEFAULT_TIMEOUT
45
- writer = None
70
+
71
+ async with self .query_lock :
72
+ return await self ._query (request , retry_count , timeout )
73
+
74
+ async def _connect (self , timeout : int ) -> bool :
75
+ """Try to connect or reconnect to the device."""
76
+ if self .writer :
77
+ return True
78
+
79
+ with contextlib .suppress (Exception ):
80
+ self .reader = self .writer = None
81
+ task = asyncio .open_connection (
82
+ self .host , TPLinkSmartHomeProtocol .DEFAULT_PORT
83
+ )
84
+ self .reader , self .writer = await asyncio .wait_for (task , timeout = timeout )
85
+ return True
86
+
87
+ return False
88
+
89
+ async def _execute_query (self , request : str ) -> Dict :
90
+ """Execute a query on the device and wait for the response."""
91
+ assert self .writer is not None
92
+ assert self .reader is not None
93
+
94
+ _LOGGER .debug ("> (%i) %s" , len (request ), request )
95
+ self .writer .write (TPLinkSmartHomeProtocol .encrypt (request ))
96
+ await self .writer .drain ()
97
+
98
+ packed_block_size = await self .reader .readexactly (self .BLOCK_SIZE )
99
+ length = struct .unpack (">I" , packed_block_size )[0 ]
100
+
101
+ buffer = await self .reader .readexactly (length )
102
+ response = TPLinkSmartHomeProtocol .decrypt (buffer )
103
+ json_payload = json .loads (response )
104
+ _LOGGER .debug ("< (%i) %s" , len (response ), pf (json_payload ))
105
+ return json_payload
106
+
107
+ async def close (self ):
108
+ """Close the connection."""
109
+ writer = self .writer
110
+ self ._reset ()
111
+ if writer :
112
+ writer .close ()
113
+ with contextlib .suppress (Exception ):
114
+ await writer .wait_closed ()
115
+
116
+ def _reset (self ):
117
+ """Clear any varibles that should not survive between loops."""
118
+ self .writer = None
119
+ self .reader = None
120
+ self .query_lock = None
121
+ self .loop = None
122
+
123
+ async def _query (self , request : str , retry_count : int , timeout : int ) -> Dict :
124
+ """Try to query a device."""
46
125
for retry in range (retry_count + 1 ):
126
+ if not await self ._connect (timeout ):
127
+ await self .close ()
128
+ if retry >= retry_count :
129
+ _LOGGER .debug ("Giving up after %s retries" , retry )
130
+ raise SmartDeviceException (
131
+ f"Unable to connect to the device: { self .host } "
132
+ )
133
+ continue
134
+
47
135
try :
48
- task = asyncio .open_connection (
49
- host , TPLinkSmartHomeProtocol .DEFAULT_PORT
136
+ assert self .reader is not None
137
+ assert self .writer is not None
138
+ return await asyncio .wait_for (
139
+ self ._execute_query (request ), timeout = timeout
50
140
)
51
- reader , writer = await asyncio .wait_for (task , timeout = timeout )
52
- _LOGGER .debug ("> (%i) %s" , len (request ), request )
53
- writer .write (TPLinkSmartHomeProtocol .encrypt (request ))
54
- await writer .drain ()
55
-
56
- buffer = bytes ()
57
- # Some devices send responses with a length header of 0 and
58
- # terminate with a zero size chunk. Others send the length and
59
- # will hang if we attempt to read more data.
60
- length = - 1
61
- while True :
62
- chunk = await reader .read (4096 )
63
- if length == - 1 :
64
- length = struct .unpack (">I" , chunk [0 :4 ])[0 ]
65
- buffer += chunk
66
- if (length > 0 and len (buffer ) >= length + 4 ) or not chunk :
67
- break
68
-
69
- response = TPLinkSmartHomeProtocol .decrypt (buffer [4 :])
70
- json_payload = json .loads (response )
71
- _LOGGER .debug ("< (%i) %s" , len (response ), pf (json_payload ))
72
-
73
- return json_payload
74
-
75
141
except Exception as ex :
142
+ await self .close ()
76
143
if retry >= retry_count :
77
144
_LOGGER .debug ("Giving up after %s retries" , retry )
78
145
raise SmartDeviceException (
79
- "Unable to query the device: %s" % ex
14
8000
6
+ f "Unable to query the device: { ex } "
80
147
) from ex
81
148
82
149
_LOGGER .debug ("Unable to query the device, retrying: %s" , ex )
83
150
84
- finally :
85
- if writer :
86
- writer .close ()
87
- await writer .wait_closed ()
88
-
89
151
# make mypy happy, this should never be reached..
152
+ await self .close ()
90
153
raise SmartDeviceException ("Query reached somehow to unreachable" )
91
154
155
+ def __del__ (self ):
156
+ if self .writer and self .loop and self .loop .is_running ():
157
+ self .writer .close ()
158
+ self ._reset ()
159
+
92
160
@staticmethod
93
161
def _xor_payload (unencrypted ):
94
162
key = TPLinkSmartHomeProtocol .INITIALIZATION_VECTOR
0 commit comments