8000 Fix typing issues of StringSetFlag (#107) · lablup/backend.ai-common@ec96e50 · GitHub
[go: up one dir, main page]

Skip to content
This repository was archived by the owner on Jul 16, 2024. It is now read-only.

Commit ec96e50

Browse files
authored
Fix typing issues of StringSetFlag (#107)
* Selectively ignore override errors as this is a custom class * Use a separate interface declaration file (.pyi) by refactoring out StringSetFlag class from the utils module to overcome limitation of mypy semantic analysis on "__ror__ = __or__" in subclasses of Enum/Flag classes.
1 parent 9c4ee94 commit ec96e50

File tree

5 files changed

+84
-56
lines changed

5 files changed

+84
-56
lines changed

changes/107.fix.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Fix typing issues of `StringSetFlag` by refactoring it using a separate interface definition file
Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
from __future__ import annotations
2+
3+
import enum
4+
5+
__all__ = (
6+
'StringSetFlag',
7+
)
8+
9+
10+
class StringSetFlag(enum.Flag):
11+
12+
def __eq__(self, other):
13+
return self.value == other
14+
15+
def __hash__(self):
16+
return hash(self.value)
17+
18+
def __or__(self, other):
19+
if isinstance(other, type(self)):
20+
other = other.value
21+
if not isinstance(other, (set, frozenset)):
22+
other = set((other,))
23+
return set((self.value,)) | other
24+
25+
__ror__ = __or__
26+
27+
def __and__(self, other):
28+
if isinstance(other, (set, frozenset)):
29+
return self.value in other
30+
if isinstance(other, str):
31+
return self.value == other
32+
raise TypeError
33+
34+
__rand__ = __and__
35+
36+
def __xor__(self, other):
37+
if isinstance(other, (set, frozenset)):
38+
return set((self.value,)) ^ other
39+
if isinstance(other, str):
40+
if other == self.value:
41+
return set()
42+
else:
43+
return other
44+
raise TypeError
45+
46+
def __rxor__(self, other):
47+
if isinstance(other, (set, frozenset)):
48+
return other ^ set((self.value,))
49+
if isinstance(other, str):
50+
if other == self.value:
51+
return set()
52+
else:
53+
return other
54+
raise TypeError
55+
56+
def __str__(self):
57+
return self.value
Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
import enum
2+
3+
4+
class StringSetFlag(enum.Flag):
5+
def __eq__(self, other: object) -> bool: ...
6+
def __hash__(self) -> int: ...
7+
def __or__( # type: ignore[override]
8+
self,
9+
other: StringSetFlag | str | set[str] | frozenset[str],
10+
) -> set[str]: ...
11+
def __and__( # type: ignore[override]
12+
self,
13+
other: StringSetFlag | str | set[str] | frozenset[str],
14+
) -> bool: ...
15+
def __xor__( # type: ignore[override]
16+
self,
17+
other: StringSetFlag | str | set[str] | frozenset[str],
18+
) -> set[str]: ...
19+
def __ror__(self, other: StringSetFlag | str | set[str] | frozenset[str]) -> set[str]: ...
20+
def __rand__(self, other: StringSetFlag | str | set[str] | frozenset[str]) -> bool: ...
21+
def __rxor__(self, other: StringSetFlag | str | set[str] | frozenset[str]) -> set[str]: ...
22+
def __str__(self) -> str: ...

src/ai/backend/common/utils.py

Lines changed: 1 addition & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33
import base64
44
from collections import OrderedDict
55
from datetime import timedelta
6-
import enum
76
from itertools import chain
87
import numbers
98
import random
@@ -34,6 +33,7 @@
3433
current_loop,
3534
run_through,
3635
)
36+
from .enum_extension import StringSetFlag # for legacy imports # noqa
3737
from .files import AsyncFileWriter # for legacy imports # noqa
3838
from .networking import ( # for legacy imports # noqa
3939
curl,
@@ -198,56 +198,6 @@ def str_to_timedelta(tstr: str) -> timedelta:
198198
return timedelta(**params) # type: ignore
199199

200200

201-
class StringSetFlag(enum.Flag):
202-
203-
def __eq__(self, other):
204-
return self.value == other
205-
206-
def __hash__(self):
207-
return hash(self.value)
208-
209-
def __or__(self, other):
210-
if isinstance(other, type(self)):
211-
other = other.value
212-
if not isinstance(other, (set, frozenset)):
213-
other = set((other,))
214-
return set((self.value,)) | other
215-
216-
__ror__ = __or__
217-
218-
def __and__(self, other):
219-
if isinstance(other, (set, frozenset)):
220-
return self.value in other
221-
if isinstance(other, str):
222-
return self.value == other
223-
raise TypeError
224-
225-
__rand__ = __and__
226-
227-
def __xor__(self, other):
228-
if isinstance(other, (set, frozenset)):
229-
return set((self.value,)) ^ other
230-
if isinstance(other, str):
231-
if other == self.value:
232-
return set()
233-
else:
234-
return other
235-
raise TypeError
236-
237-
def __rxor__(self, other):
238-
if isinstance(other, (set, frozenset)):
239-
return other ^ set((self.value,))
240-
if isinstance(other, str):
241-
if other == self.value:
242-
return set()
243-
else:
244-
return other
245-
raise TypeError
246-
247-
def __str__(self):
248-
return self.value
249-
250-
251201
class FstabEntry:
252202
"""
253203
Entry class represents a non-comment line on the `fstab` file.

tests/test_utils.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -12,14 +12,14 @@
1212
import pytest
1313

1414
from ai.backend.common.asyncio import AsyncBarrier, run_through
15+
from ai.backend.common.enum_extension import StringSetFlag
1516
from ai.backend.common.files import AsyncFileWriter
1617
from ai.backend.common.networking import curl
1718
from ai.backend.common.utils import (
1819
odict, dict2kvlist, nmget,
1920
generate_uuid, get_random_seq,
2021
readable_size_to_bytes,
2122
str_to_timedelta,
22-
StringSetFlag,
2323
)
2424
from ai.backend.common.testutils import (
2525
mock_corofunc, mock_awaitable, AsyncContextManagerMock,
@@ -156,9 +156,7 @@ async def test_curl_returns_default_value_if_not_success(mocker) -> None:
156156

157157
def test_string_set_flag() -> None:
158158

159-
# FIXME: Remove "type: ignore" when mypy gets released with
160-
# python/mypy#11579.
161-
class MyFlags(StringSetFlag): # type: ignore
159+
class MyFlags(StringSetFlag):
162160
A = 'a'
163161
B = 'b'
164162

@@ -182,7 +180,7 @@ class MyFlags(StringSetFlag): # type: ignore
182180
assert {'b'} == MyFlags.A ^ {'a', 'b'}
183181
assert {'a', 'b', 'c'} == MyFlags.A ^ {'b', 'c'}
184182
with pytest.raises(TypeError):
185-
123 & MyFlags.A
183+
123 & MyFlags.A # type: ignore[operator]
186184

187185
assert {'a', 'c'} & MyFlags.A
188186
assert not {'a', 'c'} & MyFlags.B

0 commit comments

Comments
 (0)
0