8000 Update DiscoveryResult to use Mashumaro instead of pydantic (#1231) · python-kasa/python-kasa@254a9af · GitHub
[go: up one dir, main page]

Skip to content

Commit 254a9af

Browse files
authored
Update DiscoveryResult to use Mashumaro instead of pydantic (#1231)
Mashumaro is faster and doesn't come with all versioning problems that pydantic does. A basic perf test deserializing all of our discovery results fixtures shows mashumaro as being about 6 times faster deserializing dicts than pydantic. It's much faster parsing from a json string but that's likely because it uses orjson under the hood although that's not really our use case at the moment. ``` PYDANTIC - ms ================= json dict ----------------- 4.7665 1.3268 3.1548 1.5922 3.1130 1.8039 4.2834 2.7606 2.0669 1.3757 2.0163 1.6377 3.1667 1.3561 4.1296 2.7297 2.0132 1.3471 4.0648 1.4105 MASHUMARO - ms ================= json dict ----------------- 0.5977 0.5543 0.5336 0.2983 0.3955 0.2549 0.6516 0.2742 0.5386 0.2706 0.6678 0.2580 0.4120 0.2511 0.3836 0.2472 0.4020 0.2465 0.4268 0.2487 ```
1 parent 9d5e07b commit 254a9af

File tree

9 files changed

+81
-41
lines changed

9 files changed

+81
-41
lines changed

devtools/dump_devinfo.py

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -319,7 +319,7 @@ async def cli(
319319
click.echo("Host and discovery info given, trying connect on %s." % host)
320320

321321
di = json.loads(discovery_info)
322-
dr = DiscoveryResult(**di)
322+
dr = DiscoveryResult.from_dict(di)
323323
connection_type = DeviceConnectionParameters.from_values(
324324
dr.device_type,
325325
dr.mgt_encrypt_schm.encrypt_type,
@@ -336,7 +336,7 @@ async def cli(
336336
basedir,
337337
autosave,
338338
device.protocol,
339-
discovery_info=dr.get_dict(),
339+
discovery_info=dr.to_dict(),
340340
batch_size=batch_size,
341341
)
342342
elif device_family and encrypt_type:
@@ -443,7 +443,7 @@ async def get_legacy_fixture(protocol, *, discovery_info):
443443
if discovery_info and not discovery_info.get("system"):
444444
# Need to recreate a DiscoverResult here because we don't want the aliases
445445
# in the fixture, we want the actual field names as returned by the device.
446-
dr = DiscoveryResult(**protocol._discovery_info)
446+
dr = DiscoveryResult.from_dict(protocol._discovery_info)
447447
final["discovery_result"] = dr.dict(
448448
by_alias=False, exclude_unset=True, exclude_none=True, exclude_defaults=True
449449
)
@@ -960,10 +960,8 @@ async def get_smart_fixtures(
960960
# Need to recreate a DiscoverResult here because we don't want the aliases
961961
# in the fixture, we want the actual field names as returned by the device.
962962
if discovery_info:
963-
dr = DiscoveryResult(**discovery_info) # type: ignore
964-
final["discovery_result"] = dr.dict(
965-
by_alias=False, exclude_unset=True, exclude_none=True, exclude_defaults=True
966-
)
963+
dr = DiscoveryResult.from_dict(discovery_info) # type: ignore
964+
final["discovery_result"] = dr.to_dict()
967965

968966
click.echo("Got %s successes" % len(successes))
969967
click.echo(click.style("## device info file ##", bold=True))

kasa/cli/discover.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -207,7 +207,7 @@ def _echo_discovery_info(discovery_info) -> None:
207207
return
208208

209209
try:
210-
dr = DiscoveryResult(**discovery_info)
210+
dr = DiscoveryResult.from_dict(discovery_info)
211211
except ValidationError:
212212
_echo_dictionary(discovery_info)
213213
return

kasa/discover.py

Lines changed: 45 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,7 @@
9090
import socket
9191
import struct
9292
from asyncio.transports import DatagramTransport
93+
from dataclasses import dataclass, field
9394
from pprint import pformat as pf
9495
from typing import (
9596
TYPE_CHECKING,
@@ -108,7 +109,8 @@
108109
# When support for cpython older than 3.11 is dropped
109110
# async_timeout can be replaced with asyncio.timeout
110111
from async_timeout import timeout as asyncio_timeout
111-
from pydantic.v1 import BaseModel, ValidationError
112+
from mashumaro import field_options
113+
from mashumaro.config import BaseConfig
112114

113115
from kasa import Device
114116
from kasa.credentials import Credentials
@@ -130,6 +132,7 @@
130132
from kasa.experimental import Experimental
131133
from kasa.iot.iotdevice import IotDevice
132134
from kasa.iotprotocol import REDACTORS as IOT_REDACTORS
135+
from kasa.json import DataClassJSONMixin
133136
from kasa.json import dumps as json_dumps
134137
from kasa.json import loads as json_loads
135138
from kasa.protocol import mask_mac, redact_data
@@ -647,7 +650,7 @@ async def try_connect_all(
647650
def _get_device_class(info: dict) -> type[Device]:
648651
"""Find SmartDevice subclass for device described by passed data."""
649652
if "result" in info:
650-
discovery_result = DiscoveryResult(**info["result"])
653+
discovery_result = DiscoveryResult.from_dict(info["result"])
651654
https = discovery_result.mgt_encrypt_schm.is_support_https
652655
dev_class = get_device_class_from_family(
653656
discovery_result.device_type, https=https
@@ -721,12 +724,8 @@ def _get_device_instance(
721724
f"Unable to read response from device: {config.host}: {ex}"
722725
) from ex
723726
try:
724-
discovery_result = DiscoveryResult(**info["result"])
725-
if (
726-
encrypt_info := discovery_result.encrypt_info
727-
) and encrypt_info.sym_schm == "AES":
728-
Discover._decrypt_discovery_data(discovery_result)
729-
except ValidationError as ex:
727+
discovery_result = DiscoveryResult.from_dict(info["result"])
728+
except Exception as ex:
730729
if debug_enabled:
731730
data = (
732731
redact_data(info, NEW_DISCOVERY_REDACTORS)
@@ -742,6 +741,16 @@ def _get_device_instance(
742741
f"Unable to parse discovery from device: {config.host}: {ex}",
743742
host=config.host,
744743
) from ex
744+
# Decrypt the data
745+
if (
746+
encrypt_info := discovery_result.encrypt_info
747+
) and encrypt_info.sym_schm == "AES":
748+
try:
749+
Discover._decrypt_discovery_data(discovery_result)
750+
except Exception:
751+
_LOGGER.exception(
752+
"Unable to decrypt discovery data %s: %s", config.host, data
753+
)
745754

746755
type_ = discovery_result.device_type
747756
encrypt_schm = discovery_result.mgt_encrypt_schm
@@ -754,7 +763,7 @@ def _get_device_instance(
754763
raise UnsupportedDeviceError(
755764
f"Unsupported device {config.host} of type {type_} "
756765
+ "with no encryption type",
757-
discovery_result=discovery_result.get_dict(),
766+
discovery_result=discovery_result.to_dict(),
758767
host=config.host,
759768
)
760769
config.connection_type = DeviceConnectionParameters.from_values(
@@ -767,7 +776,7 @@ def _get_device_instance(
767776
raise UnsupportedDeviceError(
768777
f"Unsupported device {config.host} of type {type_} "
769778
+ f"with encrypt_type {discovery_result.mgt_encrypt_schm.encrypt_type}",
770-
discovery_result=discovery_result.get_dict(),
779+
discovery_result=discovery_result.to_dict(),
771780
host=config.host,
772781
) from ex
773782
if (
@@ -778,7 +787,7 @@ def _get_device_instance(
778787
_LOGGER.warning("Got unsupported device type: %s", type_)
779788
raise UnsupportedDeviceError(
780789
f"Unsupported device {config.host} of type {type_}: {info}",
781-
discovery_result=discovery_result.get_dict(),
790+
discovery_result=discovery_result.to_dict(),
782791
host=config.host,
783792
)
784793
if (protocol := get_protocol(config)) is None:
@@ -788,7 +797,7 @@ def _get_device_instance(
788797
raise UnsupportedDeviceError(
789798
f"Unsupported encryption scheme {config.host} of "
790799
+ f"type {config.connection_type.to_dict()}: {info}",
791-
discovery_result=discovery_result.get_dict(),
800+
discovery_result=discovery_result.to_dict(),
792801
host=config.host,
793802
)
794803

@@ -801,42 +810,59 @@ def _get_device_instance(
801810
_LOGGER.debug("[DISCOVERY] %s << %s", config.host, pf(data))
802811
device = device_class(config.host, protocol=protocol)
803812

804-
di = discovery_result.get_dict()
813+
di = discovery_result.to_dict()
805814
di["model"], _, _ = discovery_result.device_model.partition("(")
806815
device.update_from_discover_info(di)
807816
return device
808817

809818

810-
class EncryptionScheme(BaseModel):
819+
class _DiscoveryBaseMixin(DataClassJSONMixin):
820+
"""Base class for serialization mixin."""
821+
822+
class Config(BaseConfig):
823+
"""Serialization config."""
824+
825+
omit_none = True
826+
omit_default = True
827+
serialize_by_alias = True
828+
829+
830+
@dataclass
831+
class EncryptionScheme(_DiscoveryBaseMixin):
811832
"""Base model for encryption scheme of discovery result."""
812833

813834
is_support_https: bool
814-
encrypt_type: Optional[str] # noqa: UP007
835+
encrypt_type: Optional[str] = None # noqa: UP007
815836
http_port: Optional[int] = None # noqa: UP007
816837
lv: Optional[int] = None # noqa: UP007
817838

818839

819-
class EncryptionInfo(BaseModel):
840+
@dataclass
841+
class EncryptionInfo(_DiscoveryBaseMixin):
820842
"""Base model for encryption info of discovery result."""
821843

822844
sym_schm: str
823845
key: str
824846
data: str
825847

826848

827-
class DiscoveryResult(BaseModel):
849+
@dataclass
850+
class DiscoveryResult(_DiscoveryBaseMixin):
828851
"""Base model for discovery result."""
829852

830853
device_type: str
831854
device_model: str
832-
device_name: Optional[str] # noqa: UP007
855+
device_id: str
833856
ip: str
834857
mac: str
835858
mgt_encrypt_schm: EncryptionScheme
859+
device_name: Optional[str] = None # noqa: UP007
836860
encrypt_info: Optional[EncryptionInfo] = None # noqa: UP007
837861
encrypt_type: Optional[list[str]] = None # noqa: UP007
838862
decrypted_data: Optional[dict] = None # noqa: UP007
839-
device_id: str
863+
is_reset_wifi: Optional[bool] = field( # noqa: UP007
864+
metadata=field_options(alias="isResetWiFi"), default=None
865+
)
840866

841867
firmware_version: Optional[str] = None # noqa: UP007
842868
hardware_version: Optional[str] = None # noqa: UP007
@@ -845,12 +871,3 @@ class DiscoveryResult(BaseModel):
845871
is_support_iot_cloud: Optional[bool] = None # noqa: UP007
846872
obd_src: Optional[str] = None # noqa: UP007
847873
factory_default: Optional[bool] = None # noqa: UP007
848-
849-
def get_dict(self) -> dict:
850-
"""Return a dict for this discovery result.
851-
852-
containing only the values actually set and with aliases as field names.
853-
"""
854-
return self.dict(
855-
by_alias=False, exclude_unset=True, exclude_none=True, exclude_defaults=True
856-
)

kasa/json.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,3 +21,13 @@ def dumps(obj: Any, *, default: Callable | None = None) -> str:
2121
return json.dumps(obj, separators=(",", ":"))
2222

2323
loads = json.loads
24+
25+
26+
try:
27+
from mashumaro.mixins.orjson import DataClassORJSONMixin
28+
29+
DataClassJSONMixin = DataClassORJSONMixin
30+
except ImportError:
31+
from mashumaro.mixins.json import DataClassJSONMixin as JSONMixin
32+
33+
DataClassJSONMixin = JSONMixin # type: ignore[assignment, misc]

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ dependencies = [
1414
"aiohttp>=3",
1515
"typing-extensions>=4.12.2,<5.0",
1616
"tzdata>=2024.2 ; platform_system == 'Windows'",
17+
"mashumaro>=3.14",
1718
]
1819

1920
classifiers = [

tests/test_cli.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -616,7 +616,7 @@ async def _state(dev: Device):
616616

617617
mocker.patch("kasa.cli.device.state", new=_state)
618618

619-
dr = DiscoveryResult(**discovery_mock.discovery_data["result"])
619+
dr = DiscoveryResult.from_dict(discovery_mock.discovery_data["result"])
620620
res = await runner.invoke(
621621
cli,
622622
[

tests/test_device_factory.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@
4343
def _get_connection_type_device_class(discovery_info):
4444
if "result" in discovery_info:
4545
device_class = Discover._get_device_class(discovery_info)
46-
dr = DiscoveryResult(**discovery_info["result"])
46+
dr = DiscoveryResult.from_dict(discovery_info["result"])
4747

4848
connection_type = DeviceConnectionParameters.from_values(
4949
dr.device_type, dr.mgt_encrypt_schm.encrypt_type

tests/test_discovery.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -391,8 +391,8 @@ async def test_device_update_from_new_discovery_info(discovery_mock):
391391
discovery_data = discovery_mock.discovery_data
392392
device_class = Discover._get_device_class(discovery_data)
393393
device = device_class("127.0.0.1")
394-
discover_info = DiscoveryResult(**discovery_data["result"])
395-
discover_dump = discover_info.get_dict()
394+
discover_info = DiscoveryResult.from_dict(discovery_data["result"])
395+
discover_dump = discover_info.to_dict()
396396
model, _, _ = discover_dump["device_model"].partition("(")
397397
discover_dump["model"] = model
398398
device.update_from_discover_info(discover_dump)
@@ -652,7 +652,7 @@ async def test_discovery_decryption():
652652
"sym_schm": "AES",
653653
}
654654
info = {**UNSUPPORTED["result"], "encrypt_info": encrypt_info}
655-
dr = DiscoveryResult(**info)
655+
dr = DiscoveryResult.from_dict(info)
656656
Discover._decrypt_discovery_data(dr)
657657
assert dr.decrypted_data == data_dict
658658

uv.lock

Lines changed: 14 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)
0