8000 Put modules back on children for wall switches by sdb9696 · Pull Request #881 · python-kasa/python-kasa · GitHub
[go: up one dir, main page]

Skip to content

Put modules back on children for wall switches #881

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
Apr 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

8000
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion kasa/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
IotStrip,
IotWallSwitch,
)
from kasa.iot.modules import Usage
from kasa.smart import SmartBulb, SmartDevice

try:
Expand Down Expand Up @@ -829,7 +830,7 @@
Daily and monthly data provided in CSV format.
"""
echo("[bold]== Usage ==[/bold]")
usage = dev.modules["usage"]
usage = cast(Usage, dev.modules["usage"])

Check warning on line 833 in kasa/cli.py

View check run for this annotation

Codecov / codecov/patch

kasa/cli.py#L833

Added line #L833 was not covered by tests

if erase:
echo("Erasing usage statistics..")
Expand Down
7 changes: 6 additions & 1 deletion kasa/device.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from .exceptions import KasaException
from .feature import Feature
from .iotprotocol import IotProtocol
from .module import Module
from .protocol import BaseProtocol
from .xortransport import XorTransport

Expand Down Expand Up @@ -72,7 +73,6 @@ def __init__(
self._last_update: Any = None
self._discovery_info: dict[str, Any] | None = None

self.modules: dict[str, Any] = {}
self._features: dict[str, Feature] = {}
self._parent: Device | None = None
self._children: Mapping[str, Device] = {}
Expand Down Expand Up @@ -111,6 +111,11 @@ async def disconnect(self):
"""Disconnect and close any underlying connection resources."""
await self.protocol.close()

@property
@abstractmethod
def modules(self) -> Mapping[str, Module]:
"""Return the device modules."""

@property
@abstractmethod
def is_on(self) -> bool:
Expand Down
46 changes: 28 additions & 18 deletions kasa/iot/iotdevice.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
import inspect
import logging
from datetime import datetime, timedelta
from typing import Any, Mapping, Sequence
from typing import Any, Mapping, Sequence, cast

from ..device import Device, WifiNetwork
from ..deviceconfig import DeviceConfig
Expand All @@ -28,7 +28,7 @@
from ..feature import Feature
from ..protocol import BaseProtocol
from .iotmodule import IotModule
from .modules import Emeter
from .modules import Emeter, Time

_LOGGER = logging.getLogger(__name__)

Expand Down Expand Up @@ -189,12 +189,18 @@
self._supported_modules: dict[str, IotModule] | None = None
self._legacy_features: set[str] = set()
self._children: Mapping[str, IotDevice] = {}
self._modules: dict[str, IotModule] = {}

@property
def children(self) -> Sequence[IotDevice]:
"""Return list of children."""
return list(self._children.values())

@property
def modules(self) -> dict[str, IotModule]:
"""Return the device modules."""
return self._modules

def add_module(self, name: str, module: IotModule):
"""Register a module."""
if name in self.modules:
Expand Down Expand Up @@ -420,31 +426,31 @@
"""Set the device name (alias)."""
return await self._query_helper("system", "set_dev_alias", {"alias": alias})

@property # type: ignore
@property
@requires_update
def time(self) -> datetime:
"""Return current time from the device."""
return self.modules["time"].time
return cast(Time, self.modules["time"]).time

@property # type: ignore
@property
@requires_update
def timezone(self) -> dict:
"""Return the current timezone."""
return self.modules["time"].timezone
return cast(Time, self.modules["time"]).timezone

async def get_time(self) -> datetime | None:
"""Return current time from the device, if available."""
_LOGGER.warning(
"Use `time` property instead, this call will be removed in the future."
)
return await self.modules["time"].get_time()
return await cast(Time, self.modules["time"]).get_time()

async def get_timezone(self) -> dict:
"""Return timezone information."""
_LOGGER.warning(
"Use `timezone` property instead, this call will be removed in the future."
)
return await self.modules["time"].get_timezone()
return await cast(Time, self.modules["time"]).get_timezone()

@property # type: ignore
@requires_update
Expand Down Expand Up @@ -520,31 +526,31 @@
"""
return await self._query_helper("system", "set_mac_addr", {"mac": mac})

@property # type: ignore
@property
@requires_update
def emeter_realtime(self) -> EmeterStatus:
"""Return current energy readings."""
self._verify_emeter()
return EmeterStatus(self.modules["emeter"].realtime)
return EmeterStatus(cast(Emeter, self.modules["emeter"]).realtime)

async def get_emeter_realtime(self) -> EmeterStatus:
"""Retrieve current energy readings."""
self._verify_emeter()
return EmeterStatus(await self.modules["emeter"].get_realtime())
return EmeterStatus(await cast(Emeter, self.modules["emeter"]).get_realtime())

@property # type: ignore
@property
@requires_update
def emeter_today(self) -> float | None:
"""Return today's energy consumption in kWh."""
self._verify_emeter()
return self.modules["emeter"].emeter_today
return cast(Emeter, self.modules["emeter"]).emeter_today

@property # type: ignore
@property
@requires_update
def emeter_this_month(self) -> float | None:
"""Return this month's energy consumption in kWh."""
self._verify_emeter()
return self.modules["emeter"].emeter_this_month
return cast(Emeter, self.modules["emeter"]).emeter_this_month

async def get_emeter_daily(
self, year: int | None = None, month: int | None = None, kwh: bool = True
Expand All @@ -558,7 +564,9 @@
:return: mapping of day of month to value
"""
self._verify_emeter()
return await self.modules["emeter"].get_daystat(year=year, month=month, kwh=kwh)
return await cast(Emeter, self.modules["emeter"]).get_daystat(
year=year, month=month, kwh=kwh
)

@requires_update
async def get_emeter_monthly(
Expand All @@ -571,13 +579,15 @@
:return: dict: mapping of month to value
"""
self._verify_emeter()
return await self.modules["emeter"].get_monthstat(year=year, kwh=kwh)
return await cast(Emeter, self.modules["emeter"]).get_monthstat(
year=year, kwh=kwh
)

@requires_update
async def erase_emeter_stats(self) -> dict:
"""Erase energy meter statistics."""
self._verify_emeter()
return await self.modules["emeter"].erase_stats()
return await cast(Emeter, self.modules["emeter"]).erase_stats()

Check warning on line 590 in kasa/iot/iotdevice.py

View check run for this annotation

Codecov / codecov/patch

kasa/iot/iotdevice.py#L590

Added line #L590 was not covered by tests

@requires_update
async def current_consumption(self) -> float:
Expand Down
2 changes: 1 addition & 1 deletion kasa/iot/iotstrip.py
Original file line number Diff line number Diff line change
Expand Up @@ -253,7 +253,7 @@ def __init__(self, host: str, parent: IotStrip, child_id: str) -> None:
self._last_update = parent._last_update
self._set_sys_info(parent.sys_info)
self._device_type = DeviceType.StripSocket
self.modules = {}
self._modules = {}
self.protocol = parent.protocol # Must use the same connection as the parent
self.add_module("time", Time(self, "time"))

Expand Down
5 changes: 4 additions & 1 deletion kasa/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,14 @@

import logging
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING

from .device import Device
from .exceptions import KasaException
from .feature import Feature

if TYPE_CHECKING:
from .device import Device

_LOGGER = logging.getLogger(__name__)


Expand Down
41 changes: 26 additions & 15 deletions kasa/smart/smartdevice.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,8 @@ def __init__(
self._components_raw: dict[str, Any] | None = None
self._components: dict[str, int] = {}
self._state_information: dict[str, Any] = {}
self.modules: dict[str, SmartModule] = {}
self._modules: dict[str, SmartModule] = {}
self._exposes_child_modules = False
self._parent: SmartDevice | None = None
self._children: Mapping[str, SmartDevice] = {}
self._last_update = {}
Expand Down Expand Up @@ -84,11 +85,13 @@ async def _initialize_children(self):
@property
def children(self) -> Sequence[SmartDevice]:
"""Return list of children."""
# Wall switches with children report all modules on the parent only
if self.device_type == DeviceType.WallSwitch:
return []
return list(self._children.values())

@property
def modules(self) -> dict[str, SmartModule]:
"""Return the device modules."""
return self._modules

def _try_get_response(self, responses: dict, request: str, default=None) -> dict:
response = responses.get(request)
if isinstance(response, SmartErrorCode):
Expand Down Expand Up @@ -148,7 +151,7 @@ async def update(self, update_children: bool = True):
req: dict[str, Any] = {}

# TODO: this could be optimized by constructing the query only once
for module in self.modules.values():
for module in self._modules.values():
req.update(module.query())

self._last_update = resp = await self.protocol.query(req)
Expand All @@ -174,19 +177,24 @@ async def _initialize_modules(self):
# Some wall switches (like ks240) are internally presented as having child
# devices which report the child's components on the parent's sysinfo, even
# when they need to be accessed through the children.
# The logic below ensures that such devices report all but whitelisted, the
# child modules at the parent level to create an illusion of a single device.
# The logic below ensures that such devices add all but whitelisted, only on
# the child device.
skip_parent_only_modules = False
child_modules_to_skip = {}
if self._parent and self._parent.device_type == DeviceType.WallSwitch:
modules = self._parent.modules
skip_parent_only_modules = True
else:
modules = self.modules
skip_parent_only_modules = False
elif self._children and self.device_type == DeviceType.WallSwitch:
# _initialize_modules is called on the parent after the children
self._exposes_child_modules = True
for child in self._children.values():
child_modules_to_skip.update(**child.modules)

for mod in SmartModule.REGISTERED_MODULES.values():
_LOGGER.debug("%s requires %s", mod, mod.REQUIRED_COMPONENT)

if skip_parent_only_modules and mod in WALL_SWITCH_PARENT_ONLY_MODULES:
if (
skip_parent_only_modules and mod in WALL_SWITCH_PARENT_ONLY_MODULES
) or mod.__name__ in child_modules_to_skip:
continue
if mod.REQUIRED_COMPONENT in self._components:
_LOGGER.debug(
Expand All @@ -195,8 +203,11 @@ async def _initialize_modules(self):
mod.__name__,
)
module = mod(self, mod.REQUIRED_COMPONENT)
if module.name not in modules and await module._check_supported():
modules[module.name] = module
if await module._check_supported():
self._modules[module.name] = module

if self._exposes_child_modules:
self._modules.update(**child_modules_to_skip)

async def _initialize_features(self):
"""Initialize device features."""
Expand Down Expand Up @@ -278,7 +289,7 @@ async def _initialize_features(self):
)
)

for module in self.modules.values():
for module in self._modules.values():
for feat in module._module_features.values():
self._add_feature(feat)

Expand Down
7 changes: 5 additions & 2 deletions kasa/tests/smart/modules/test_fan.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
from typing import cast

from pytest_mock import MockerFixture

from kasa import SmartDevice
from kasa.smart.modules import FanModule
from kasa.tests.device_fixtures import parametrize

fan = parametrize("has fan", component_filter="fan_control", protocol_filter={"SMART"})
Expand All @@ -9,7 +12,7 @@
@fan
async def test_fan_speed(dev: SmartDevice, mocker: MockerFixture):
"""Test fan speed feature."""
fan = dev.modules.get("FanModule")
fan = cast(FanModule, dev.modules.get("FanModule"))
assert fan

level_feature = fan._module_features["fan_speed_level"]
Expand All @@ -32,7 +35,7 @@ async def test_fan_speed(dev: SmartDevice, mocker: MockerFixture):
@fan
async def test_sleep_mode(dev: SmartDevice, mocker: MockerFixture):
"""Test sleep mode feature."""
fan = dev.modules.get("FanModule")
fan = cast(FanModule, dev.modules.get("FanModule"))
assert fan
sleep_feature = fan._module_features["fan_sleep_mode"]
assert isinstance(sleep_feature.value, bool)
Expand Down
16 changes: 8 additions & 8 deletions kasa/tests/test_smartdevice.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,22 +103,22 @@ async def test_negotiate(dev: SmartDevice, mocker: MockerFixture):
async def test_update_module_queries(dev: SmartDevice, mocker: MockerFixture):
"""Test that the regular update uses queries from all supported modules."""
# We need to have some modules initialized by now
assert dev.modules
assert dev._modules

device_queries: dict[SmartDevice, dict[str, Any]] = {}
for mod in dev.modules.values():
for mod in dev._modules.values():
device_queries.setdefault(mod._device, {}).update(mod.query())

spies = {}
for dev in device_queries:
spies[dev] = mocker.spy(dev.protocol, "query")
for device in device_queries:
spies[device] = mocker.spy(device.protocol, "query")

await dev.update()
for dev in device_queries:
if device_queries[dev]:
spies[dev].assert_called_with(device_queries[dev])
for device in device_queries:
if device_queries[device]:
spies[device].assert_called_with(device_queries[device])
else:
spies[dev].assert_not_called()
spies[device].assert_not_called()


@bulb_smart
Expand Down
Loading
0