8000 Allow passing an aiohttp client session during discover try_connect_a… · python-kasa/python-kasa@88b7951 · GitHub
[go: up one dir, main page]

Skip to content

Commit 88b7951

Browse files
authored
Allow passing an aiohttp client session during discover try_connect_all (#1198)
1 parent 7eb8d45 commit 88b7951

File tree

2 files changed

+11
-1
lines changed

2 files changed

+11
-1
lines changed

kasa/discover.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,8 @@
9393
from pprint import pformat as pf
9494
from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Type, cast
9595

96+
from aiohttp import ClientSession
97+
9698
# When support for cpython older than 3.11 is dropped
9799
# async_timeout can be replaced with asyncio.timeout
98100
from async_timeout import timeout as asyncio_timeout
@@ -533,6 +535,7 @@ async def try_connect_all(
533535
port: int | None = None,
534536
timeout: int | None = None,
535537
credentials: Credentials | None = None,
538+
http_client: ClientSession | None = None,
536539
) -> Device | None:
537540
"""Try to connect directly to a device with all possible parameters.
538541
@@ -544,6 +547,7 @@ async def try_connect_all(
544547
:param port: Optionally set a different port for legacy devices using port 9999
545548
:param timeout: Timeout in seconds device for devices queries
546549
:param credentials: Credentials for devices that require authentication.
550+
:param http_client: Optional client session for devices that use http.
547551
username and password are ignored if provided.
548552
"""
549553
from .device_factory import _connect
@@ -570,6 +574,8 @@ async def try_connect_all(
570574
timeout=timeout,
571575
port_override=port,
572576
credentials=credentials,
577+
http_client=http_client,
578+
uses_http=encrypt is not Device.EncryptionType.Xor,
573579
)
574580
)
575581
and (protocol := get_protocol(config))

kasa/tests/test_discovery.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -697,9 +697,13 @@ async def _update(self, *args, **kwargs):
697697
mocker.patch("kasa.SmartProtocol.query", new=_query)
698698
mocker.patch.object(dev_class, "update", new=_update)
699699

700-
dev = await Discover.try_connect_all(discovery_mock.ip)
700+
session = aiohttp.ClientSession()
701+
dev = await Discover.try_connect_all(discovery_mock.ip, http_client=session)
701702

702703
assert dev
703704
assert isinstance(dev, dev_class)
704705
assert isinstance(dev.protocol, protocol_class)
705706
assert isinstance(dev.protocol._transport, transport_class)
707+
assert dev.config.uses_http is (transport_class != XorTransport)
708+
if transport_class != XorTransport:
709+
assert dev.protocol._transport._http_client.client == session

0 commit comments

Comments
 (0)
0