10000 Ensure connections are closed when cli is finished (#752) · python-kasa/python-kasa@45f251e · GitHub
[go: up one dir, main page]

Skip to content

Commit 45f251e

Browse files
authored
Ensure connections are closed when cli is finished (#752)
* Ensure connections are closed when cli is finished * Test for close calls on error and success
1 parent 5d81e9f commit 45f251e

File tree

3 files changed

+30
-9
lines changed

3 files changed

+30
-9
lines changed

kasa/cli.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import logging
66
import re
77
import sys
8+
from contextlib import asynccontextmanager
89
from functools import singledispatch, wraps
910
from pprint import pformat as pf
1011
from typing import Any, Dict, cast
@@ -365,7 +366,14 @@ def _nop_echo(*args, **kwargs):
365366
if ctx.invoked_subcommand not in SKIP_UPDATE_COMMANDS and not device_family:
366367
await dev.update()
367368

368-
ctx.obj = dev
369+
@asynccontextmanager
370+
async def async_wrapped_device(device: Device):
371+
try:
372+
yield device
373+
finally:
374+
await device.disconnect()
375+
376+
ctx.obj = await ctx.with_async_resource(async_wrapped_device(dev))
369377

370378
if ctx.invoked_subcommand is None:
371379
return await ctx.invoke(state)

kasa/device_factory.py

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,20 @@ async def connect(*, host: Optional[str] = None, config: DeviceConfig) -> "Devic
4949
if host:
5050
config = DeviceConfig(host=host)
5151

52+
if (protocol := get_protocol(config=config)) is None:
53+
raise UnsupportedDeviceException(
54+
f"Unsupported device for {config.host}: "
55+
+ f"{config.connection_type.device_family.value}"
56+
)
57+
58+
try:
59+
return await _connect(config, protocol)
60+
except:
61+
await protocol.close()
62+
raise
63+
64+
65+
async def _connect(config: DeviceConfig, protocol: BaseProtocol) -> "Device":
5266
debug_enabled = _LOGGER.isEnabledFor(logging.DEBUG)
5367
if debug_enabled:
5468
start_time = time.perf_counter()
@@ -63,12 +77,6 @@ def _perf_log(has_params, perf_type):
6377
)
6478
start_time = time.perf_counter()
6579

66-
if (protocol := get_protocol(config=config)) is None:
67-
raise UnsupportedDeviceException(
68-
f"Unsupported device for {config.host}: "
69-
+ f"{config.connection_type.device_family.value}"
70-
)
71-
7280
device_class: Optional[Type[Device]]
7381
device: Optional[Device] = None
7482

kasa/tests/test_device_factory.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -53,16 +53,17 @@ async def test_connect(
5353
host=host, credentials=Credentials("foor", "bar"), connection_type=ctype
5454
)
5555
protocol_class = get_protocol(config).__class__
56-
56+
close_mock = mocker.patch.object(protocol_class, "close")
5757
dev = await connect(
5858
config=config,
5959
)
6060
assert isinstance(dev, device_class)
6161
assert isinstance(dev.protocol, protocol_class)
6262

6363
assert dev.config == config
64-
64+
assert close_mock.call_count == 0
6565
await dev.disconnect()
66+
assert close_mock.call_count == 1
6667

6768

6869
@pytest.mark.parametrize("custom_port", [123, None])
@@ -116,8 +117,12 @@ async def test_connect_query_fails(all_fixture_data: dict, mocker):
116117
config = DeviceConfig(
117118
host=host, credentials=Credentials("foor", "bar"), connection_type=ctype
118119
)
120+
protocol_class = get_protocol(config).__class__
121+
close_mock = mocker.patch.object(protocol_class, "close")
122+
assert close_mock.call_count == 0
119123
with pytest.raises(SmartDeviceException):
120124
await connect(config=config)
125+
assert close_mock.call_count == 1
121126

122127

123128
async def test_connect_http_client(all_fixture_data, mocker):

0 commit comments

Comments
 (0)
0