8000 Handle KeyboardInterrupts in the cli better (#1391) · python-kasa/python-kasa@296af31 · GitHub
[go: up one dir, main page]

Skip to content

Commit 296af31

Browse files
authored
Handle KeyboardInterrupts in the cli better (#1391)
Addresses an issue with how `asyncclick` deals with `KeyboardInterrupt` errors. Instead of the `click.main` receiving `KeyboardInterrupt` it receives `CancelledError` because it's a task running inside the loop. Also ensures that discovery catches the `CancelledError` and closes the http clients.
1 parent fe88b52 commit 296af31

File tree

2 files changed

+18
-1
lines changed

2 files changed

+18
-1
lines changed

kasa/cli/common.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,14 @@
22

33
from __future__ import annotations
44

5+
import asyncio
56
import json
67
import re
78
import sys
89
from collections.abc import Callable
910
from contextlib import contextmanager
1011
from functools import singledispatch, update_wrapper, wraps
12+
from gettext import gettext
1113
from typing import TYPE_CHECKING, Any, Final
1214

1315
import asyncclick as click
@@ -238,4 +240,19 @@ async def invoke(self, ctx):
238240
except Exception as exc:
239241
_handle_exception(self._debug, exc)
240242

243+
def __call__(self, *args, **kwargs):
244+
"""Run the coroutine in the event loop and print any exceptions.
245+
246+
python click catches KeyboardInterrupt in main, raises Abort()
247+
and does sys.exit. asyncclick doesn't properly handle a coroutine
248+
receiving CancelledError on a KeyboardInterrupt, so we catch the
249+
KeyboardInterrupt here once asyncio.run has re-raised it. This
250+
avoids large stacktraces when a user presses Ctrl-C.
251+
"""
252+
try:
253+
asyncio.run(self.main(*args, **kwargs))
254+
except KeyboardInterrupt:
255+
click.echo(gettext("\nAborted!"), file=sys.stderr)
256+
sys.exit(1)
257+
241258
return _CommandCls

kasa/discover.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -498,7 +498,7 @@ async def discover(
498498
try:
499499
_LOGGER.debug("Waiting %s seconds for responses...", discovery_timeout)
500500
await protocol.wait_for_discovery_to_complete()
501-
except KasaException as ex:
501+
except (KasaException, asyncio.CancelledError) as ex:
502502
for device in protocol.discovered_devices.values():
503503
await device.protocol.close()
504504
raise ex

0 commit comments

Comments
 (0)
0