8000 Update smart request parameter handling (#1061) · python-kasa/python-kasa@58afeb2 · GitHub
[go: up one dir, main page]

8000
Skip to content

Commit 58afeb2

Browse files
authored
Update smart request parameter handling (#1061)
Changes to the smart request handling: - Do not send params if null - Drop the requestId parameter - get_preset_rules doesn't send parameters for preset component version less than 3 - get_led_info no longer sends the wrong parameters - get_on_off_gradually_info no longer sends an empty {} parameter
1 parent 06ff598 commit 58afeb2

File tree

5 files changed

+19
-91
lines changed

5 files changed

+19
-91
lines changed

kasa/smart/modules/led.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ class Led(SmartModule, LedInterface):
1616

1717
def query(self) -> dict:
1818
"""Query to execute during the update cycle."""
19-
return {self.QUERY_GETTER_NAME: {"led_rule": None}}
19+
return {self.QUERY_GETTER_NAME: None}
2020

2121
@property
2222
def mode(self):

kasa/smart/modules/lightpreset.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -153,6 +153,9 @@ def query(self) -> dict:
153153
"""Query to execute during the update cycle."""
154154
if self._state_in_sysinfo: # Child lights can have states in the child info
155155
return {}
156+
if self.supported_version < 3:
157+
return {self.QUERY_GETTER_NAME: None}
158+
156159
return {self.QUERY_GETTER_NAME: {"start_index": 0}}
157160

158161
async def _check_supported(self):

kasa/smart/modules/lighttransition.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -234,7 +234,7 @@ def query(self) -> dict:
234234
if self._state_in_sysinfo:
235235
return {}
236236
else:
237-
return {self.QUERY_GETTER_NAME: {}}
237+
return {self.QUERY_GETTER_NAME: None}
238238

239239
async def _check_supported(self):
240240
"""Additional check to see if the module is supported by the device."""

kasa/smartprotocol.py

Lines changed: 8 additions & 85 deletions
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,6 @@ def __init__(
6666
"""Create a protocol object."""
6767
super().__init__(transport=transport)
6868
self._terminal_uuid: str = base64.b64encode(md5(uuid.uuid4().bytes)).decode()
69-
self._request_id_generator = SnowflakeId(1, 1)
7069
self._query_lock = asyncio.Lock()
7170
self._multi_request_batch_size = (
7271
self._transport._config.batch_size or self.DEFAULT_MULTI_REQUEST_BATCH_SIZE
@@ -77,11 +76,11 @@ def get_smart_request(self, method, params=None) -> str:
7776
"""Get a request message as a string."""
7877
request = {
7978
"method": method,
80-
"params": params,
81-
"requestID": self._request_id_generator.generate_id(),
8279
"request_time_milis": round(time.time() * 1000),
8380
"terminal_uuid": self._terminal_uuid,
8481
}
82+
if params:
83+
request["params"] = params
8584
return json_dumps(request)
8685

8786
async def query(self, request: str | dict, retry_count: int = 3) -> dict:
@@ -157,8 +156,10 @@ async def _execute_multiple_query(self, requests: dict, retry_count: int) -> dic
157156
debug_enabled = _LOGGER.isEnabledFor(logging.DEBUG)
158157
multi_result: dict[str, Any] = {}
159158
smart_method = "multipleRequest"
159+
160160
multi_requests = [
161-
{"method": method, "params": params} for method, params in requests.items()
161+
{"method": method, "params": params} if params else {"method": method}
162+
for method, params in requests.items()
162163
]
163164

164165
end = len(multi_requests)
@@ -168,7 +169,7 @@ async def _execute_multiple_query(self, requests: dict, retry_count: int) -> dic
168169
# If step is 1 do not send request batches
169170
for request in multi_requests:
170171
method = request["method"]
171-
req = self.get_smart_request(method, request["params"])
172+
req = self.get_smart_request(method, request.get("params"))
172173
resp = await self._transport.send(req)
173174
self._handle_response_error_code(resp, method, raise_on_error=False)
174175
multi_result[method] = resp["result"]
@@ -347,86 +348,6 @@ async def close(self) -> None:
347348
await self._transport.close()
348349

349350

350-
class SnowflakeId:
351-
"""Class for generating snowflake ids."""
352-
353-
EPOCH = 1420041600000 # Custom epoch (in milliseconds)
354-
WORKER_ID_BITS = 5
355-
DATA_CENTER_ID_BITS = 5
356-
SEQUENCE_BITS = 12
357-
358-
MAX_WORKER_ID = (1 << WORKER_ID_BITS) - 1
359-
MAX_DATA_CENTER_ID = (1 << DATA_CENTER_ID_BITS) - 1
360-
361-
SEQUENCE_MASK = (1 << SEQUENCE_BITS) - 1
362-
363-
def __init__(self, worker_id, data_center_id):
364-
if worker_id > SnowflakeId.MAX_WORKER_ID or worker_id < 0:
365-
raise ValueError(
366-
"Worker ID can't be greater than "
367-
+ str(SnowflakeId.MAX_WORKER_ID)
368-
+ " or less than 0"
369-
)
370-
if data_center_id > SnowflakeId.MAX_DATA_CENTER_ID or data_center_id < 0:
371-
raise ValueError(
372-
"Data center ID can't be greater than "
373-
+ str(SnowflakeId.MAX_DATA_CENTER_ID)
374-
+ " or less than 0"
375-
)
376-
377-
self.worker_id = worker_id
378-
self.data_center_id = data_center_id
379-
self.sequence = 0
380-
self.last_timestamp = -1
381-
382-
def generate_id(self):
383-
"""Generate a snowflake id."""
384-
timestamp = self._current_millis()
385-
386-
if timestamp < self.last_timestamp:
387-
raise ValueError("Clock moved backwards. Refusing to generate ID.")
388-
389-
if timestamp == self.last_timestamp:
390-
# Within the same millisecond, increment the sequence number
391-
self.sequence = (self.sequence + 1) & SnowflakeId.SEQUENCE_MASK
392-
if self.sequence == 0:
393-
# Sequence exceeds its bit range, wait until the next millisecond
394-
timestamp = self._wait_next_millis(self.last_timestamp)
395-
else:
396-
# New millisecond, reset the sequence number
397-
self.sequence = 0
398-
399-
# Update the last timestamp
400-
self.last_timestamp = timestamp
401-
402-
# Generate and return the final ID
403-
return (
404-
(
405-
(timestamp - SnowflakeId.EPOCH)
406-
<< (
407-
SnowflakeId.WORKER_ID_BITS
408-
+ SnowflakeId.SEQUENCE_BITS
409-
+ SnowflakeId.DATA_CENTER_ID_BITS
410-
)
411-
)
412-
| (
413-
self.data_center_id
414-
<< (SnowflakeId.SEQUENCE_BITS + SnowflakeId.WORKER_ID_BITS)
415-
)
416-
| (self.worker_id << SnowflakeId.SEQUENCE_BITS)
417-
| self.sequence
418-
)
419-
420-
def _current_millis(self):
421-
return round(time.monotonic() * 1000)
422-
423-
def _wait_next_millis(self, last_timestamp):
424-
timestamp = self._current_millis()
425-
while timestamp <= last_timestamp:
426-
timestamp = self._current_millis()
427-
return timestamp
428-
429-
430351
class _ChildProtocolWrapper(SmartProtocol):
431352
"""Protocol wrapper for controlling child devices.
432353
@@ -456,6 +377,8 @@ def _get_method_and_params_for_request(self, request):
456377
smart_method = "multipleRequest"
457378
requests = [
458379
{"method": method, "params": params}
380+
if params
381+
else {"method": method}
459382
for method, params in request.items()
460383
]
461384
smart_params = {"requests": requests}

kasa/tests/fakeprotocol_smart.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -119,8 +119,9 @@ def credentials_hash(self):
119119
async def send(self, request: str):
120120
request_dict = json_loads(request)
121121
method = request_dict["method"]
122-
params = request_dict["params"]
122+
123123
if method == "multipleRequest":
124+
params = request_dict["params"]
124125
responses = []
125126
for request in params["requests"]:
126127
response = self._send_request(request) # type: ignore[arg-type]
@@ -308,12 +309,13 @@ def _edit_preset_rules(self, info, params):
308309

309310
def _send_request(self, request_dict: dict):
310311
method = request_dict["method"]
311-
params = request_dict["params"]
312312

313313
info = self.info
314314
if method == "control_child":
315-
return self._handle_control_child(params)
316-
elif method == "component_nego" or method[:4] == "get_":
315+
return self._handle_control_child(request_dict["params"])
316+
317+
params = request_dict.get("params")
318+
if method == "component_nego" or method[:4] == "get_":
317319
if method in info:
318320
result = copy.deepcopy(info[method])
319321
if "start_index" in result and "sum" in result:

0 commit comments

Comments
 (0)
0