From aa9113f51dbeb6e159159d415ebcfef1a9cb7345 Mon Sep 17 00:00:00 2001 From: sdb9696 Date: Wed, 14 Feb 2024 14:00:47 +0000 Subject: [PATCH 1/2] Ensure connections are closed when cli is finished --- kasa/cli.py | 10 +++++++++- kasa/device_factory.py | 20 ++++++++++++++------ 2 files changed, 23 insertions(+), 7 deletions(-) diff --git a/kasa/cli.py b/kasa/cli.py index 0893d5b05..86a0c15a6 100755 --- a/kasa/cli.py +++ b/kasa/cli.py @@ -5,6 +5,7 @@ import logging import re import sys +from contextlib import asynccontextmanager from functools import singledispatch, wraps from pprint import pformat as pf from typing import Any, Dict, cast @@ -365,7 +366,14 @@ def _nop_echo(*args, **kwargs): if ctx.invoked_subcommand not in SKIP_UPDATE_COMMANDS and not device_family: await dev.update() - ctx.obj = dev + @asynccontextmanager + async def async_wrapped_device(device: Device): + try: + yield device + finally: + await device.disconnect() + + ctx.obj = await ctx.with_async_resource(async_wrapped_device(dev)) if ctx.invoked_subcommand is None: return await ctx.invoke(state) diff --git a/kasa/device_factory.py b/kasa/device_factory.py index 28a5e3b2b..3550539c7 100755 --- a/kasa/device_factory.py +++ b/kasa/device_factory.py @@ -49,6 +49,20 @@ async def connect(*, host: Optional[str] = None, config: DeviceConfig) -> "Devic if host: config = DeviceConfig(host=host) + if (protocol := get_protocol(config=config)) is None: + raise UnsupportedDeviceException( + f"Unsupported device for {config.host}: " + + f"{config.connection_type.device_family.value}" + ) + + try: + return await _connect(config, protocol) + except: + await protocol.close() + raise + + +async def _connect(config: DeviceConfig, protocol: BaseProtocol) -> "Device": debug_enabled = _LOGGER.isEnabledFor(logging.DEBUG) if debug_enabled: start_time = time.perf_counter() @@ -63,12 +77,6 @@ def _perf_log(has_params, perf_type): ) start_time = time.perf_counter() - if (protocol := get_protocol(config=config)) is None: - raise UnsupportedDeviceException( - f"Unsupported device for {config.host}: " - + f"{config.connection_type.device_family.value}" - ) - device_class: Optional[Type[Device]] device: Optional[Device] = None From 51ff9ca17e85d5ff63be5d9be01d4b3c4446f0af Mon Sep 17 00:00:00 2001 From: sdb9696 Date: Wed, 14 Feb 2024 16:35:41 +0000 Subject: [PATCH 2/2] Test for close calls on error and success --- kasa/tests/test_device_factory.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/kasa/tests/test_device_factory.py b/kasa/tests/test_device_factory.py index 67ab39d50..7369a9874 100644 --- a/kasa/tests/test_device_factory.py +++ b/kasa/tests/test_device_factory.py @@ -53,7 +53,7 @@ async def test_connect( host=host, credentials=Credentials("foor", "bar"), connection_type=ctype ) protocol_class = get_protocol(config).__class__ - + close_mock = mocker.patch.object(protocol_class, "close") dev = await connect( config=config, ) @@ -61,8 +61,9 @@ async def test_connect( assert isinstance(dev.protocol, protocol_class) assert dev.config == config - + assert close_mock.call_count == 0 await dev.disconnect() + assert close_mock.call_count == 1 @pytest.mark.parametrize("custom_port", [123, None]) @@ -116,8 +117,12 @@ async def test_connect_query_fails(all_fixture_data: dict, mocker): config = DeviceConfig( host=host, credentials=Credentials("foor", "bar"), connection_type=ctype ) + protocol_class = get_protocol(config).__class__ + close_mock = mocker.patch.object(protocol_class, "close") + assert close_mock.call_count == 0 with pytest.raises(SmartDeviceException): await connect(config=config) + assert close_mock.call_count == 1 async def test_connect_http_client(all_fixture_data, mocker):