diff --git a/CHANGELOG.md b/CHANGELOG.md index 437c756dd..18e0289c0 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,12 +1,21 @@ # Changelog +## [0.4.0.dev5](https://github.com/python-kasa/python-kasa/tree/0.4.0.dev5) (2021-09-24) + +[Full Changelog](https://github.com/python-kasa/python-kasa/compare/0.4.0.dev4...0.4.0.dev5) + +**Merged pull requests:** + +- Add KL130 fixture, initial lightstrip tests [\#214](https://github.com/python-kasa/python-kasa/pull/214) ([rytilahti](https://github.com/rytilahti)) +- Keep connection open and lock to prevent duplicate requests [\#213](https://github.com/python-kasa/python-kasa/pull/213) ([bdraco](https://github.com/bdraco)) +- Cleanup discovery & add tests [\#212](https://github.com/python-kasa/python-kasa/pull/212) ([rytilahti](https://github.com/rytilahti)) + ## [0.4.0.dev4](https://github.com/python-kasa/python-kasa/tree/0.4.0.dev4) (2021-09-23) [Full Changelog](https://github.com/python-kasa/python-kasa/compare/0.4.0.dev3...0.4.0.dev4) **Implemented enhancements:** -- HS300 Children plugs have emeter [\#64](https://github.com/python-kasa/python-kasa/issues/64) - Improve emeterstatus API, move into own module [\#205](https://github.com/python-kasa/python-kasa/pull/205) ([rytilahti](https://github.com/rytilahti)) - Avoid temp array during encrypt and decrypt [\#204](https://github.com/python-kasa/python-kasa/pull/204) ([bdraco](https://github.com/bdraco)) - Add emeter support for strip sockets [\#203](https://github.com/python-kasa/python-kasa/pull/203) ([bdraco](https://github.com/bdraco)) @@ -19,6 +28,7 @@ **Fixed bugs:** - KL430: Throw error for Device specific information [\#189](https://github.com/python-kasa/python-kasa/issues/189) +- HS300 Children plugs have emeter [\#64](https://github.com/python-kasa/python-kasa/issues/64) - dump\_devinfo: handle latitude/longitude keys properly [\#175](https://github.com/python-kasa/python-kasa/pull/175) ([rytilahti](https://github.com/rytilahti)) **Closed issues:** @@ -34,6 +44,7 @@ **Merged pull requests:** +- Release 0.4.0.dev4 [\#210](https://github.com/python-kasa/python-kasa/pull/210) ([rytilahti](https://github.com/rytilahti)) - More CI fixes [\#208](https://github.com/python-kasa/python-kasa/pull/208) ([rytilahti](https://github.com/rytilahti)) - Fix CI dep installation [\#207](https://github.com/python-kasa/python-kasa/pull/207) ([rytilahti](https://github.com/rytilahti)) - Use github actions instead of azure pipelines [\#206](https://github.com/python-kasa/python-kasa/pull/206) ([rytilahti](https://github.com/rytilahti)) diff --git a/devtools/dump_devinfo.py b/devtools/dump_devinfo.py index 9d30f9674..1108e7fb4 100644 --- a/devtools/dump_devinfo.py +++ b/devtools/dump_devinfo.py @@ -78,16 +78,17 @@ def cli(host, debug): ), ] - protocol = TPLinkSmartHomeProtocol() - successes = [] for test_call in items: + + async def _run_query(): + protocol = TPLinkSmartHomeProtocol(host) + return await protocol.query({test_call.module: {test_call.method: None}}) + try: click.echo(f"Testing {test_call}..", nl=False) - info = asyncio.run( - protocol.query(host, {test_call.module: {test_call.method: None}}) - ) + info = asyncio.run(_run_query()) resp = info[test_call.module] except Exception as ex: click.echo(click.style(f"FAIL {ex}", fg="red")) @@ -107,8 +108,12 @@ def cli(host, debug): final = default_to_regular(final) + async def _run_final_query(): + protocol = TPLinkSmartHomeProtocol(host) + return await protocol.query(final_query) + try: - final = asyncio.run(protocol.query(host, final_query)) + final = asyncio.run(_run_final_query()) except Exception as ex: click.echo( click.style( diff --git a/kasa/cli.py b/kasa/cli.py index 626eadc2b..209fcf965 100755 --- a/kasa/cli.py +++ b/kasa/cli.py @@ -143,13 +143,11 @@ async def discover(ctx, timeout, discover_only, dump_raw): """Discover devices in the network.""" target = ctx.parent.params["target"] click.echo(f"Discovering devices on {target} for {timeout} seconds") - found_devs = await Discover.discover( - target=target, timeout=timeout, return_raw=dump_raw - ) + found_devs = await Discover.discover(target=target, timeout=timeout) if not discover_only: for ip, dev in found_devs.items(): if dump_raw: - click.echo(dev) + click.echo(dev.sys_info) continue ctx.obj = dev await ctx.invoke(state) diff --git a/kasa/discover.py b/kasa/discover.py index b12b79264..a408c2de9 100755 --- a/kasa/discover.py +++ b/kasa/discover.py @@ -3,7 +3,7 @@ import json import logging import socket -from typing import Awaitable, Callable, Dict, Mapping, Optional, Type, Union, cast +from typing import Awaitable, Callable, Dict, Optional, Type, cast from kasa.protocol import TPLinkSmartHomeProtocol from kasa.smartbulb import SmartBulb @@ -17,6 +17,7 @@ OnDiscoveredCallable = Callable[[SmartDevice], Awaitable[None]] +DeviceDict = Dict[str, SmartDevice] class _DiscoverProtocol(asyncio.DatagramProtocol): @@ -25,8 +26,7 @@ class _DiscoverProtocol(asyncio.DatagramProtocol): This is internal class, use :func:`Discover.discover`: instead. """ - discovered_devices: Dict[str, SmartDevice] - discovered_devices_raw: Dict[str, Dict] + discovered_devices: DeviceDict def __init__( self, @@ -40,10 +40,8 @@ def __init__( self.discovery_packets = discovery_packets self.interface = interface self.on_discovered = on_discovered - self.protocol = TPLinkSmartHomeProtocol() self.target = (target, Discover.DISCOVERY_PORT) self.discovered_devices = {} - self.discovered_devices_raw = {} def connection_made(self, transport) -> None: """Set socket options for broadcasting.""" @@ -62,7 +60,7 @@ def do_discover(self) -> None: """Send number of discovery datagrams.""" req = json.dumps(Discover.DISCOVERY_QUERY) _LOGGER.debug("[DISCOVERY] %s >> %s", self.target, Discover.DISCOVERY_QUERY) - encrypted_req = self.protocol.encrypt(req) + encrypted_req = TPLinkSmartHomeProtocol.encrypt(req) for i in range(self.discovery_packets): self.transport.sendto(encrypted_req[4:], self.target) # type: ignore @@ -72,7 +70,7 @@ def datagram_received(self, data, addr) -> None: if ip in self.discovered_devices: return - info = json.loads(self.protocol.decrypt(data)) + info = json.loads(TPLinkSmartHomeProtocol.decrypt(data)) _LOGGER.debug("[DISCOVERY] %s << %s", ip, info) device_class = Discover._get_device_class(info) @@ -80,13 +78,9 @@ def datagram_received(self, data, addr) -> None: device.update_from_discover_info(info) self.discovered_devices[ip] = device - self.discovered_devices_raw[ip] = info - if device_class is not None: - if self.on_discovered is not None: - asyncio.ensure_future(self.on_discovered(device)) - else: - _LOGGER.error("Received invalid response: %s", info) + if self.on_discovered is not None: + asyncio.ensure_future(self.on_discovered(device)) def error_received(self, ex): """Handle asyncio.Protocol errors.""" @@ -144,9 +138,8 @@ async def discover( on_discovered=None, timeout=5, discovery_packets=3, - return_raw=False, interface=None, - ) -> Mapping[str, Union[SmartDevice, Dict]]: + ) -> DeviceDict: """Discover supported devices. Sends discovery message to 255.255.255.255:9999 in order @@ -154,17 +147,17 @@ async def discover( and waits for given timeout for answers from devices. If you have multiple interfaces, you can use target parameter to specify the network for discovery. - If given, `on_discovered` coroutine will get passed with the :class:`SmartDevice`-derived object as parameter. + If given, `on_discovered` coroutine will get awaited with a :class:`SmartDevice`-derived object as parameter. - The results of the discovery are returned either as a list of :class:`SmartDevice`-derived objects - or as raw response dictionaries objects (if `return_raw` is True). + The results of the discovery are returned as a dict of :class:`SmartDevice`-derived objects keyed with IP addresses. + The devices are already initialized and all but emeter-related properties can be accessed directly. :param target: The target address where to send the broadcast discovery queries if multi-homing (e.g. 192.168.xxx.255). :param on_discovered: coroutine to execute on discovery :param timeout: How long to wait for responses, defaults to 5 - :param discovery_packets: Number of discovery packets are broadcasted. - :param return_raw: True to return JSON objects instead of Devices. - :return: + :param discovery_packets: Number of discovery packets to broadcast + :param interface: Bind to specific interface + :return: dictionary with discovered devices """ loop = asyncio.get_event_loop() transport, protocol = await loop.create_datagram_endpoint( @@ -186,9 +179,6 @@ async def discover( _LOGGER.debug("Discovered %s devices", len(protocol.discovered_devices)) - if return_raw: - return protocol.discovered_devices_raw - return protocol.discovered_devices @staticmethod @@ -199,17 +189,15 @@ async def discover_single(host: str) -> SmartDevice: :rtype: SmartDevice :return: Object for querying/controlling found device. """ - protocol = TPLinkSmartHomeProtocol() + protocol = TPLinkSmartHomeProtocol(host) - info = await protocol.query(host, Discover.DISCOVERY_QUERY) + info = await protocol.query(Discover.DISCOVERY_QUERY) device_class = Discover._get_device_class(info) - if device_class is not None: - dev = device_class(host) - await dev.update() - return dev + dev = device_class(host) + await dev.update() - raise SmartDeviceException("Unable to discover device, received: %s" % info) + return dev @staticmethod def _get_device_class(info: dict) -> Type[SmartDevice]: @@ -237,17 +225,4 @@ def _get_device_class(info: dict) -> Type[SmartDevice]: return SmartBulb - raise SmartDeviceException("Unknown device type: %s", type_) - - -if __name__ == "__main__": - logging.basicConfig(level=logging.INFO) - loop = asyncio.get_event_loop() - - async def _on_device(dev): - await dev.update() - _LOGGER.info("Got device: %s", dev) - - devices = loop.run_until_complete(Discover.discover(on_discovered=_on_device)) - for ip, dev in devices.items(): - print(f"[{ip}] {dev}") + raise SmartDeviceException("Unknown device type: %s" % type_) diff --git a/kasa/protocol.py b/kasa/protocol.py index bbf13b995..b54029c66 100755 --- a/kasa/protocol.py +++ b/kasa/protocol.py @@ -10,11 +10,12 @@ http://www.apache.org/licenses/LICENSE-2.0 """ import asyncio +import contextlib import json import logging import struct from pprint import pformat as pf -from typing import Dict, Union +from typing import Dict, Optional, Union from .exceptions import SmartDeviceException @@ -28,8 +29,26 @@ class TPLinkSmartHomeProtocol: DEFAULT_PORT = 9999 DEFAULT_TIMEOUT = 5 - @staticmethod - async def query(host: str, request: Union[str, Dict], retry_count: int = 3) -> Dict: + BLOCK_SIZE = 4 + + def __init__(self, host: str) -> None: + """Create a protocol object.""" + self.host = host + self.reader: Optional[asyncio.StreamReader] = None + self.writer: Optional[asyncio.StreamWriter] = None + self.query_lock: Optional[asyncio.Lock] = None + self.loop: Optional[asyncio.AbstractEventLoop] = None + + def _detect_event_loop_change(self) -> None: + """Check if this object has been reused betwen event loops.""" + loop = asyncio.get_running_loop() + if not self.loop: + self.loop = loop + elif self.loop != loop: + _LOGGER.warning("Detected protocol reuse between different event loop") + self._reset() + + async def query(self, request: Union[str, Dict], retry_count: int = 3) -> Dict: """Request information from a TP-Link SmartHome Device. :param str host: host name or ip address of the device @@ -38,57 +57,106 @@ async def query(host: str, request: Union[str, Dict], retry_count: int = 3) -> D :param retry_count: how many retries to do in case of failure :return: response dict """ + self._detect_event_loop_change() + + if not self.query_lock: + self.query_lock = asyncio.Lock() + if isinstance(request, dict): request = json.dumps(request) + assert isinstance(request, str) timeout = TPLinkSmartHomeProtocol.DEFAULT_TIMEOUT - writer = None + + async with self.query_lock: + return await self._query(request, retry_count, timeout) + + async def _connect(self, timeout: int) -> bool: + """Try to connect or reconnect to the device.""" + if self.writer: + return True + + with contextlib.suppress(Exception): + self.reader = self.writer = None + task = asyncio.open_connection( + self.host, TPLinkSmartHomeProtocol.DEFAULT_PORT + ) + self.reader, self.writer = await asyncio.wait_for(task, timeout=timeout) + return True + + return False + + async def _execute_query(self, request: str) -> Dict: + """Execute a query on the device and wait for the response.""" + assert self.writer is not None + assert self.reader is not None + + _LOGGER.debug("> (%i) %s", len(request), request) + self.writer.write(TPLinkSmartHomeProtocol.encrypt(request)) + await self.writer.drain() + + packed_block_size = await self.reader.readexactly(self.BLOCK_SIZE) + length = struct.unpack(">I", packed_block_size)[0] + + buffer = await self.reader.readexactly(length) + response = TPLinkSmartHomeProtocol.decrypt(buffer) + json_payload = json.loads(response) + _LOGGER.debug("< (%i) %s", len(response), pf(json_payload)) + return json_payload + + async def close(self): + """Close the connection.""" + writer = self.writer + self._reset() + if writer: + writer.close() + with contextlib.suppress(Exception): + await writer.wait_closed() + + def _reset(self): + """Clear any varibles that should not survive between loops.""" + self.writer = None + self.reader = None + self.query_lock = None + self.loop = None + + async def _query(self, request: str, retry_count: int, timeout: int) -> Dict: + """Try to query a device.""" for retry in range(retry_count + 1): + if not await self._connect(timeout): + await self.close() + if retry >= retry_count: + _LOGGER.debug("Giving up after %s retries", retry) + raise SmartDeviceException( + f"Unable to connect to the device: {self.host}" + ) + continue + try: - task = asyncio.open_connection( - host, TPLinkSmartHomeProtocol.DEFAULT_PORT + assert self.reader is not None + assert self.writer is not None + return await asyncio.wait_for( + self._execute_query(request), timeout=timeout ) - reader, writer = await asyncio.wait_for(task, timeout=timeout) - _LOGGER.debug("> (%i) %s", len(request), request) - writer.write(TPLinkSmartHomeProtocol.encrypt(request)) - await writer.drain() - - buffer = bytes() - # Some devices send responses with a length header of 0 and - # terminate with a zero size chunk. Others send the length and - # will hang if we attempt to read more data. - length = -1 - while True: - chunk = await reader.read(4096) - if length == -1: - length = struct.unpack(">I", chunk[0:4])[0] - buffer += chunk - if (length > 0 and len(buffer) >= length + 4) or not chunk: - break - - response = TPLinkSmartHomeProtocol.decrypt(buffer[4:]) - json_payload = json.loads(response) - _LOGGER.debug("< (%i) %s", len(response), pf(json_payload)) - - return json_payload - except Exception as ex: + await self.close() if retry >= retry_count: _LOGGER.debug("Giving up after %s retries", retry) raise SmartDeviceException( - "Unable to query the device: %s" % ex + f"Unable to query the device: {ex}" ) from ex _LOGGER.debug("Unable to query the device, retrying: %s", ex) - finally: - if writer: - writer.close() - await writer.wait_closed() - # make mypy happy, this should never be reached.. + await self.close() raise SmartDeviceException("Query reached somehow to unreachable") + def __del__(self): + if self.writer and self.loop and self.loop.is_running(): + self.writer.close() + self._reset() + @staticmethod def _xor_payload(unencrypted): key = TPLinkSmartHomeProtocol.INITIALIZATION_VECTOR diff --git a/kasa/smartdevice.py b/kasa/smartdevice.py index 11c7d1c96..fabf26b32 100755 --- a/kasa/smartdevice.py +++ b/kasa/smartdevice.py @@ -194,7 +194,7 @@ def __init__(self, host: str) -> None: """ self.host = host - self.protocol = TPLinkSmartHomeProtocol() + self.protocol = TPLinkSmartHomeProtocol(host) self.emeter_type = "emeter" _LOGGER.debug("Initializing %s of type %s", self.host, type(self)) self._device_type = DeviceType.Unknown @@ -234,7 +234,7 @@ async def _query_helper( request = self._create_request(target, cmd, arg, child_ids) try: - response = await self.protocol.query(host=self.host, request=request) + response = await self.protocol.query(request=request) except Exception as ex: raise SmartDeviceException(f"Communication error on {target}:{cmd}") from ex @@ -272,7 +272,7 @@ async def get_sys_info(self) -> Dict[str, Any]: """Retrieve system information.""" return await self._query_helper("system", "get_sysinfo") - async def update(self): + async def update(self, update_children: bool = True): """Query the device to update the data. Needed for properties that are decorated with `requires_update`. @@ -285,7 +285,7 @@ async def update(self): # See #105, #120, #161 if self._last_update is None: _LOGGER.debug("Performing the initial update to obtain sysinfo") - self._last_update = await self.protocol.query(self.host, req) + self._last_update = await self.protocol.query(req) self._sys_info = self._last_update["system"]["get_sysinfo"] # If the device has no emeter, we are done for the initial update # Otherwise we will follow the regular code path to also query @@ -299,7 +299,7 @@ async def update(self): ) req.update(self._create_emeter_request()) - self._last_update = await self.protocol.query(self.host, req) + self._last_update = await self.protocol.query(req) self._sys_info = self._last_update["system"]["get_sysinfo"] def update_from_discover_info(self, info): @@ -383,8 +383,8 @@ def location(self) -> Dict: loc["latitude"] = sys_info["latitude"] loc["longitude"] = sys_info["longitude"] elif "latitude_i" in sys_info and "longitude_i" in sys_info: - loc["latitude"] = sys_info["latitude_i"] - loc["longitude"] = sys_info["longitude_i"] + loc["latitude"] = sys_info["latitude_i"] / 10000 + loc["longitude"] = sys_info["longitude_i"] / 10000 else: _LOGGER.warning("Unsupported device location.") diff --git a/kasa/smartstrip.py b/kasa/smartstrip.py index c1235920d..71373a7a9 100755 --- a/kasa/smartstrip.py +++ b/kasa/smartstrip.py @@ -87,12 +87,12 @@ def is_on(self) -> bool: """Return if any of the outlets are on.""" return any(plug.is_on for plug in self.children) - async def update(self): + async def update(self, update_children: bool = True): """Update some of the attributes. Needed for methods that are decorated with `requires_update`. """ - await super().update() + await super().update(update_children) # Initialize the child devices during the first update. if not self.children: @@ -103,7 +103,7 @@ async def update(self): SmartStripPlug(self.host, parent=self, child_id=child["id"]) ) - if self.has_emeter: + if update_children and self.has_emeter: for plug in self.children: await plug.update() @@ -243,13 +243,13 @@ def __init__(self, host: str, parent: "SmartStrip", child_id: str) -> None: self._sys_info = parent._sys_info self._device_type = DeviceType.StripSocket - async def update(self): + async def update(self, update_children: bool = True): """Query the device to update the data. Needed for properties that are decorated with `requires_update`. """ self._last_update = await self.parent.protocol.query( - self.host, self._create_emeter_request() + self._create_emeter_request() ) def _create_request( diff --git a/kasa/tests/conftest.py b/kasa/tests/conftest.py index 46cb86216..a7ab5d13a 100644 --- a/kasa/tests/conftest.py +++ b/kasa/tests/conftest.py @@ -4,6 +4,7 @@ import os from os.path import basename from pathlib import Path, PurePath +from typing import Dict from unittest.mock import MagicMock import pytest # type: ignore # see https://github.com/pytest-dev/pytest/issues/3342 @@ -39,6 +40,8 @@ ALL_DEVICES = BULBS.union(PLUGS).union(STRIPS).union(DIMMERS) +IP_MODEL_CACHE: Dict[str, str] = {} + def filter_model(desc, filter): filtered = list() @@ -53,8 +56,6 @@ def filter_model(desc, filter): def parametrize(desc, devices, ids=None): - # if ids is None: - # ids = ["on", "off"] return pytest.mark.parametrize( "dev", filter_model(desc, devices), indirect=True, ids=ids ) @@ -63,32 +64,11 @@ def parametrize(desc, devices, ids=None): has_emeter = parametrize("has emeter", WITH_EMETER) no_emeter = parametrize("no emeter", ALL_DEVICES - WITH_EMETER) - -def name_for_filename(x): - from os.path import basename - - return basename(x) - - -bulb = parametrize("bulbs", BULBS, ids=name_for_filename) -plug = parametrize("plugs", PLUGS, ids=name_for_filename) -strip = parametrize("strips", STRIPS, ids=name_for_filename) -dimmer = parametrize("dimmers", DIMMERS, ids=name_for_filename) -lightstrip = parametrize("lightstrips", LIGHT_STRIPS, ids=name_for_filename) - -# This ensures that every single file inside fixtures/ is being placed in some category -categorized_fixtures = set( - dimmer.args[1] + strip.args[1] + plug.args[1] + bulb.args[1] + lightstrip.args[1] -) -diff = set(SUPPORTED_DEVICES) - set(categorized_fixtures) -if diff: - for file in diff: - print( - "No category for file %s, add to the corresponding set (BULBS, PLUGS, ..)" - % file - ) - raise Exception("Missing category for %s" % diff) - +bulb = parametrize("bulbs", BULBS, ids=basename) +plug = parametrize("plugs", PLUGS, ids=basename) +strip = parametrize("strips", STRIPS, ids=basename) +dimmer = parametrize("dimmers", DIMMERS, ids=basename) +lightstrip = parametrize("lightstrips", LIGHT_STRIPS, ids=basename) # bulb types dimmable = parametrize("dimmable", DIMMABLE) @@ -98,6 +78,28 @@ def name_for_filename(x): color_bulb = parametrize("color bulbs", COLOR_BULBS) non_color_bulb = parametrize("non-color bulbs", BULBS - COLOR_BULBS) + +def check_categories(): + """Check that every fixture file is categorized.""" + categorized_fixtures = set( + dimmer.args[1] + + strip.args[1] + + plug.args[1] + + bulb.args[1] + + lightstrip.args[1] + ) + diff = set(SUPPORTED_DEVICES) - set(categorized_fixtures) + if diff: + for file in diff: + print( + "No category for file %s, add to the corresponding set (BULBS, PLUGS, ..)" + % file + ) + raise Exception("Missing category for %s" % diff) + + +check_categories() + # Parametrize tests to run with device both on and off turn_on = pytest.mark.parametrize("turn_on", [True, False]) @@ -138,23 +140,39 @@ def device_for_file(model): raise Exception("Unable to find type for %s", model) -def get_device_for_file(file): +async def _update_and_close(d): + await d.update() + await d.protocol.close() + return d + + +async def _discover_update_and_close(ip): + d = await Discover.discover_single(ip) + return await _update_and_close(d) + + +async def get_device_for_file(file): # if the wanted file is not an absolute path, prepend the fixtures directory p = Path(file) if not p.is_absolute(): p = Path(__file__).parent / "fixtures" / file - with open(p) as f: - sysinfo = json.load(f) - model = basename(file) - p = device_for_file(model)(host="127.0.0.123") - p.protocol = FakeTransportProtocol(sysinfo) - asyncio.run(p.update()) - return p + def load_file(): + with open(p) as f: + return json.load(f) + loop = asyncio.get_running_loop() + sysinfo = await loop.run_in_executor(None, load_file) -@pytest.fixture(params=SUPPORTED_DEVICES, scope="session") -def dev(request): + model = basename(file) + d = device_for_file(model)(host="127.0.0.123") + d.protocol = FakeTransportProtocol(sysinfo) + await _update_and_close(d) + return d + + +@pytest.fixture(params=SUPPORTED_DEVICES) +async def dev(request): """Device fixture. Provides a device (given --ip) or parametrized fixture for the supported devices. @@ -164,14 +182,28 @@ def dev(request): ip = request.config.getoption("--ip") if ip: - d = asyncio.run(Discover.discover_single(ip)) - asyncio.run(d.update()) - if d.model in file: - return d - else: + model = IP_MODEL_CACHE.get(ip) + d = None + if not model: + d = await _discover_update_and_close(ip) + IP_MODEL_CACHE[ip] = model = d.model + if model not in file: pytest.skip(f"skipping file {file}") + return d if d else await _discover_update_and_close(ip) + + return await get_device_for_file(file) + - return get_device_for_file(file) +@pytest.fixture(params=SUPPORTED_DEVICES, scope="session") +def discovery_data(request): + """Return raw discovery file contents as JSON. Used for discovery tests.""" + file = request.param + p = Path(file) + if not p.is_absolute(): + p = Path(__file__).parent / "fixtures" / file + + with open(p) as f: + return json.load(f) def pytest_addoption(parser): diff --git a/kasa/tests/fixtures/KL130(EU)_1.0_1.8.8.json b/kasa/tests/fixtures/KL130(EU)_1.0_1.8.8.json new file mode 100644 index 000000000..b57a01d28 --- /dev/null +++ b/kasa/tests/fixtures/KL130(EU)_1.0_1.8.8.json @@ -0,0 +1,85 @@ +{ + "smartlife.iot.common.emeter": { + "get_realtime": { + "err_code": 0, + "power_mw": 1300 + } + }, + "smartlife.iot.smartbulb.lightingservice": { + "get_light_state": { + "brightness": 5, + "color_temp": 2700, + "err_code": 0, + "hue": 1, + "mode": "normal", + "on_off": 1, + "saturation": 1 + } + }, + "system": { + "get_sysinfo": { + "active_mode": "schedule", + "alias": "bedroom", + "ctrl_protocols": { + "name": "Linkie", + "version": "1.0" + }, + "description": "Smart Wi-Fi LED Bulb with Color Changing", + "dev_state": "normal", + "deviceId": "0000000000000000000000000000000000000000", + "disco_ver": "1.0", + "err_code": 0, + "heapsize": 332316, + "hwId": "00000000000000000000000000000000", + "hw_ver": "1.0", + "is_color": 1, + "is_dimmable": 1, + "is_factory": false, + "is_variable_color_temp": 1, + "light_state": { + "brightness": 5, + "color_temp": 2700, + "hue": 1, + "mode": "normal", + "on_off": 1, + "saturation": 1 + }, + "mic_mac": "000000000000", + "mic_type": "IOT.SMARTBULB", + "model": "KL130(EU)", + "oemId": "00000000000000000000000000000000", + "preferred_state": [ + { + "brightness": 10, + "color_temp": 2500, + "hue": 0, + "index": 0, + "saturation": 0 + }, + { + "brightness": 100, + "color_temp": 0, + "hue": 299, + "index": 1, + "saturation": 95 + }, + { + "brightness": 100, + "color_temp": 0, + "hue": 120, + "index": 2, + "saturation": 75 + }, + { + "brightness": 100, + "color_temp": 0, + "hue": 240, + "index": 3, + "saturation": 75 + } + ], + "rssi": -62, + "sw_ver": "1.8.8 Build 190613 Rel.123436" + } + } +} diff --git a/kasa/tests/newfakes.py b/kasa/tests/newfakes.py index a37bb4147..a4764b660 100644 --- a/kasa/tests/newfakes.py +++ b/kasa/tests/newfakes.py @@ -83,9 +83,19 @@ def lb_dev_state(x): "icon_hash": str, "led_off": check_int_bool, "latitude": Any(All(float, Range(min=-90, max=90)), 0, None), - "latitude_i": Any(All(float, Range(min=-90, max=90)), 0, None), + "latitude_i": Any( + All(int, Range(min=-900000, max=900000)), + All(float, Range(min=-900000, max=900000)), + 0, + None, + ), "longitude": Any(All(float, Range(min=-180, max=180)), 0, None), - "longitude_i": Any(All(float, Range(min=-180, max=180)), 0, None), + "longitude_i": Any( + All(int, Range(min=-18000000, max=18000000)), + All(float, Range(min=-18000000, max=18000000)), + 0, + None, + ), "mac": check_mac, "model": str, "oemId": str, @@ -117,17 +127,17 @@ def lb_dev_state(x): { "brightness": All(int, Range(min=0, max=100)), "color_temp": int, - "hue": All(int, Range(min=0, max=255)), + "hue": All(int, Range(min=0, max=360)), "mode": str, "on_off": check_int_bool, - "saturation": All(int, Range(min=0, max=255)), + "saturation": All(int, Range(min=0, max=100)), "dft_on_state": Optional( { "brightness": All(int, Range(min=0, max=100)), "color_temp": All(int, Range(min=0, max=9000)), - "hue": All(int, Range(min=0, max=255)), + "hue": All(int, Range(min=0, max=360)), "mode": str, - "saturation": All(int, Range(min=0, max=255)), + "saturation": All(int, Range(min=0, max=100)), } ), "err_code": int, @@ -276,6 +286,8 @@ def success(res): class FakeTransportProtocol(TPLinkSmartHomeProtocol): def __init__(self, info): self.discovery_data = info + self.writer = None + self.reader = None proto = FakeTransportProtocol.baseproto for target in info: @@ -426,7 +438,7 @@ def light_state(self, x, *args): }, } - async def query(self, host, request, port=9999): + async def query(self, request, port=9999): proto = self.proto # collect child ids from context diff --git a/kasa/tests/test_bulb.py b/kasa/tests/test_bulb.py index 28fcd4cb7..ea8a28cb8 100644 --- a/kasa/tests/test_bulb.py +++ b/kasa/tests/test_bulb.py @@ -60,7 +60,7 @@ async def test_hsv(dev, turn_on): assert dev.is_color hue, saturation, brightness = dev.hsv - assert 0 <= hue <= 255 + assert 0 <= hue <= 360 assert 0 <= saturation <= 100 assert 0 <= brightness <= 100 diff --git a/kasa/tests/test_discovery.py b/kasa/tests/test_discovery.py index 1356892e9..c933cb124 100644 --- a/kasa/tests/test_discovery.py +++ b/kasa/tests/test_discovery.py @@ -1,7 +1,10 @@ # type: ignore +import sys + import pytest # type: ignore # https://github.com/pytest-dev/pytest/issues/3342 -from kasa import DeviceType, Discover, SmartDevice, SmartDeviceException +from kasa import DeviceType, Discover, SmartDevice, SmartDeviceException, protocol +from kasa.discover import _DiscoverProtocol from .conftest import bulb, dimmer, lightstrip, plug, pytestmark, strip @@ -47,3 +50,58 @@ async def test_type_unknown(): invalid_info = {"system": {"get_sysinfo": {"type": "nosuchtype"}}} with pytest.raises(SmartDeviceException): Discover._get_device_class(invalid_info) + + +@pytest.mark.skipif(sys.version_info < (3, 8), reason="3.8 is first one with asyncmock") +async def test_discover_single(discovery_data: dict, mocker): + """Make sure that discover_single returns an initialized SmartDevice instance.""" + mocker.patch("kasa.TPLinkSmartHomeProtocol.query", return_value=discovery_data) + x = await Discover.discover_single("127.0.0.1") + assert issubclass(x.__class__, SmartDevice) + assert x._sys_info is not None + + +INVALIDS = [ + ("No 'system' or 'get_sysinfo' in response", {"no": "data"}), + ( + "Unable to find the device type field", + {"system": {"get_sysinfo": {"missing_type": 1}}}, + ), + ("Unknown device type: foo", {"system": {"get_sysinfo": {"type": "foo"}}}), +] + + +@pytest.mark.skipif(sys.version_info < (3, 8), reason="3.8 is first one with asyncmock") +@pytest.mark.parametrize("msg, data", INVALIDS) +async def test_discover_invalid_info(msg, data, mocker): + """Make sure that invalid discovery information raises an exception.""" + mocker.patch("kasa.TPLinkSmartHomeProtocol.query", return_value=data) + with pytest.raises(SmartDeviceException, match=msg): + await Discover.discover_single("127.0.0.1") + + +async def test_discover_send(mocker): + """Test discovery parameters.""" + proto = _DiscoverProtocol() + assert proto.discovery_packets == 3 + assert proto.target == ("255.255.255.255", 9999) + sendto = mocker.patch.object(proto, "transport") + proto.do_discover() + assert sendto.sendto.call_count == proto.discovery_packets + + +async def test_discover_datagram_received(mocker, discovery_data): + """Verify that datagram received fills discovered_devices.""" + proto = _DiscoverProtocol() + mocker.patch("json.loads", return_value=discovery_data) + mocker.patch.object(protocol.TPLinkSmartHomeProtocol, "encrypt") + mocker.patch.object(protocol.TPLinkSmartHomeProtocol, "decrypt") + + addr = "127.0.0.1" + proto.datagram_received("", (addr, 1234)) + + # Check that device in discovered_devices is initialized correctly + assert len(proto.discovered_devices) == 1 + dev = proto.discovered_devices[addr] + assert issubclass(dev.__class__, SmartDevice) + assert dev.host == addr diff --git a/kasa/tests/test_lightstrip.py b/kasa/tests/test_lightstrip.py new file mode 100644 index 000000000..7a8d8726a --- /dev/null +++ b/kasa/tests/test_lightstrip.py @@ -0,0 +1,17 @@ +from kasa import DeviceType, SmartLightStrip + +from .conftest import lightstrip, pytestmark + + +@lightstrip +async def test_lightstrip_length(dev: SmartLightStrip): + assert dev.is_light_strip + assert dev.device_type == DeviceType.LightStrip + assert dev.length == dev.sys_info["length"] + + +@lightstrip +async def test_lightstrip_effect(dev: SmartLightStrip): + assert isinstance(dev.effect, dict) + for k in ["brightness", "custom", "enable", "id", "name"]: + assert k in dev.effect diff --git a/kasa/tests/test_protocol.py b/kasa/tests/test_protocol.py index 51c01d49d..bc0da1833 100644 --- a/kasa/tests/test_protocol.py +++ b/kasa/tests/test_protocol.py @@ -1,4 +1,6 @@ import json +import struct +import sys import pytest @@ -21,11 +23,47 @@ def aio_mock_writer(_, __): conn = mocker.patch("asyncio.open_connection", side_effect=aio_mock_writer) with pytest.raises(SmartDeviceException): - await TPLinkSmartHomeProtocol.query("127.0.0.1", {}, retry_count=retry_count) + await TPLinkSmartHomeProtocol("127.0.0.1").query({}, retry_count=retry_count) assert conn.call_count == retry_count + 1 +@pytest.mark.skipif(sys.version_info < (3, 8), reason="3.8 is first one with asyncmock") +@pytest.mark.parametrize("retry_count", [1, 3, 5]) +async def test_protocol_reconnect(mocker, retry_count): + remaining = retry_count + encrypted = TPLinkSmartHomeProtocol.encrypt('{"great":"success"}')[ + TPLinkSmartHomeProtocol.BLOCK_SIZE : + ] + + def _fail_one_less_than_retry_count(*_): + nonlocal remaining + remaining -= 1 + if remaining: + raise Exception("Simulated write failure") + + async def _mock_read(byte_count): + nonlocal encrypted + if byte_count == TPLinkSmartHomeProtocol.BLOCK_SIZE: + return struct.pack(">I", len(encrypted)) + if byte_count == len(encrypted): + return encrypted + + raise ValueError(f"No mock for {byte_count}") + + def aio_mock_writer(_, __): + reader = mocker.patch("asyncio.StreamReader") + writer = mocker.patch("asyncio.StreamWriter") + mocker.patch.object(writer, "write", _fail_one_less_than_retry_count) + mocker.patch.object(reader, "readexactly", _mock_read) + return reader, writer + + protocol = TPLinkSmartHomeProtocol("127.0.0.1") + mocker.patch("asyncio.open_connection", side_effect=aio_mock_writer) + response = await protocol.query({}, retry_count=retry_count) + assert response == {"great": "success"} + + def test_encrypt(): d = json.dumps({"foo": 1, "bar": 2}) encrypted = TPLinkSmartHomeProtocol.encrypt(d) diff --git a/kasa/tests/test_readme_examples.py b/kasa/tests/test_readme_examples.py index 27455dd84..a64c824c1 100644 --- a/kasa/tests/test_readme_examples.py +++ b/kasa/tests/test_readme_examples.py @@ -1,3 +1,4 @@ +import asyncio import sys import pytest @@ -8,7 +9,7 @@ def test_bulb_examples(mocker): """Use KL130 (bulb with all features) to test the doctests.""" - p = get_device_for_file("KL130(US)_1.0.json") + p = asyncio.run(get_device_for_file("KL130(US)_1.0.json")) mocker.patch("kasa.smartbulb.SmartBulb", return_value=p) mocker.patch("kasa.smartbulb.SmartBulb.update") res = xdoctest.doctest_module("kasa.smartbulb", "all") @@ -17,7 +18,7 @@ def test_bulb_examples(mocker): def test_smartdevice_examples(mocker): """Use HS110 for emeter examples.""" - p = get_device_for_file("HS110(EU)_1.0_real.json") + p = asyncio.run(get_device_for_file("HS110(EU)_1.0_real.json")) mocker.patch("kasa.smartdevice.SmartDevice", return_value=p) mocker.patch("kasa.smartdevice.SmartDevice.update") res = xdoctest.doctest_module("kasa.smartdevice", "all") @@ -26,7 +27,7 @@ def test_smartdevice_examples(mocker): def test_plug_examples(mocker): """Test plug examples.""" - p = get_device_for_file("HS110(EU)_1.0_real.json") + p = asyncio.run(get_device_for_file("HS110(EU)_1.0_real.json")) mocker.patch("kasa.smartplug.SmartPlug", return_value=p) mocker.patch("kasa.smartplug.SmartPlug.update") res = xdoctest.doctest_module("kasa.smartplug", "all") @@ -35,7 +36,7 @@ def test_plug_examples(mocker): def test_strip_examples(mocker): """Test strip examples.""" - p = get_device_for_file("KP303(UK)_1.0.json") + p = asyncio.run(get_device_for_file("KP303(UK)_1.0.json")) mocker.patch("kasa.smartstrip.SmartStrip", return_value=p) mocker.patch("kasa.smartstrip.SmartStrip.update") res = xdoctest.doctest_module("kasa.smartstrip", "all") @@ -44,7 +45,7 @@ def test_strip_examples(mocker): def test_dimmer_examples(mocker): """Test dimmer examples.""" - p = get_device_for_file("HS220(US)_1.0_real.json") + p = asyncio.run(get_device_for_file("HS220(US)_1.0_real.json")) mocker.patch("kasa.smartdimmer.SmartDimmer", return_value=p) mocker.patch("kasa.smartdimmer.SmartDimmer.update") res = xdoctest.doctest_module("kasa.smartdimmer", "all") @@ -53,7 +54,7 @@ def test_dimmer_examples(mocker): def test_lightstrip_examples(mocker): """Test lightstrip examples.""" - p = get_device_for_file("KL430(US)_1.0.json") + p = asyncio.run(get_device_for_file("KL430(US)_1.0.json")) mocker.patch("kasa.smartlightstrip.SmartLightStrip", return_value=p) mocker.patch("kasa.smartlightstrip.SmartLightStrip.update") res = xdoctest.doctest_module("kasa.smartlightstrip", "all") @@ -65,7 +66,7 @@ def test_lightstrip_examples(mocker): ) def test_discovery_examples(mocker): """Test discovery examples.""" - p = get_device_for_file("KP303(UK)_1.0.json") + p = asyncio.run(get_device_for_file("KP303(UK)_1.0.json")) # This succeeds on python 3.8 but fails on 3.7 # ValueError: a coroutine was expected, got ["]