8000 #198 Merge pull request from astropenguin/astropenguin/issue197 · astropenguin/xarray-dataclasses@d57fa41 · GitHub
[go: up one dir, main page]

Skip to content

Commit d57fa41

Browse files
authored
#198 Merge pull request from astropenguin/astropenguin/issue197
Add typing module (v2)
2 parents fbfaa5b + 07d71da commit d57fa41

File tree

4 files changed

+97
-1
lines changed

4 files changed

+97
-1
lines changed

.github/workflows/tests.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,4 +34,5 @@ jobs:
3434
- name: Test code's execution (pytest)
3535
run: pytest -v tests
3636
- name: Test docs' building (Sphinx)
37+
if: ${{ contains('3.10, 3.11', matrix.python) }}
3738
run: docs/build

tests/test_core_typing.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
# standard library
2+
from sys import version_info
3+
from typing import Any, Union
4+
5+
6+
# dependencies
7+
from pytest import mark
8+
from xarray_dataclasses.core.typing import is_union
9+
10+
11+
if version_info.minor >= 10:
12+
data_is_union = [
13+
(Any, False),
14+
(Union[int, str], True),
15+
(int | str, True), # type: ignore
16+
]
17+
else:
18+
data_is_union = [
19+
(Any, False),
20+
(Union[int, str], True),
21+
]
22+
23+
24+
@mark.parametrize("tp, expected", data_is_union)
25+
def test_get_tags(tp: Any, expected: bool) -> None:
26+
assert is_union(tp) == expected

xarray_dataclasses/core/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
1-
__all__ = ["tagging"]
1+
__all__ = ["tagging", "typing"]
22

33

44
from . import tagging
5+
from . import typing

xarray_dataclasses/core/typing.py

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
__all__ = [
2+
"DataClass",
3+
"DataClassOf",
4+
"PAny",
5+
"TAny",
6+
"TDataArray",
7+
"TDataset",
8+
"TXarray",
9+
"Xarray",
10+
"is_union",
11+
]
12+
13+
14+
# standard library
15+
import types
16+
from dataclasses import Field
17+
from typing import Any, Callable, ClassVar, Dict, Protocol, TypeVar, Union
18+
19+
20+
# dependencies
21+
from xarray import DataArray, Dataset
22+
from typing_extensions import ParamSpec, get_origin
23+
24+
25+
Xarray = Union[DataArray, Dataset]
26+
"""Type hint for any xarray object."""
27+
28+
PAny = ParamSpec("PAny")
29+
"""Parameter specification variable for any function."""
30+
31+
TAny = TypeVar("TAny")
32+
"""Type variable for any class."""
33+
34+
TDataArray = TypeVar("TDataArray", bound=DataArray)
35+
"""Type variable for xarray DataArray."""
36+
37+
TDataset = TypeVar("TDataset", bound=Dataset)
38+
"""Type variable for xarray Dataset."""
39+
40+
TXarray = TypeVar("TXarray", bound=Xarray)
41+
"""Type variable for any class of xarray object."""
42+
43+
44+
class DataClass(Protocol[PAny]):
45+
"""Protocol for any dataclass object."""
46+
47+
__dataclass_fields__: ClassVar[Dict[str, "Field[Any]"]]
48+
49+
def __init__(self, *args: PAny.args, **kwargs: PAny.kwargs) -> None:
50+
...
51+
52+
53+
class DataClassOf(Protocol[TXarray, PAny]):
54+
"""Protocol for any dataclass object with a factory."""
55+
56+
__dataclass_fields__: ClassVar[Dict[str, "Field[Any]"]]
57+
__xarray_factory__: Callable[..., TXarray]
58+
59+
def __init__(self, *args: PAny.args, **kwargs: PAny.kwargs) -> None:
60+
...
61+
62+
63+
def is_union(tp: Any) -> bool:
64+
"""Check if a type hint is a union of types."""
65+
if UnionType := getattr(types, "UnionType", None):
66+
return get_origin(tp) is Union or isinstance(tp, UnionType)
67+
else:
68+
return get_origin(tp) is Union

0 commit comments

Comments
 (0)
0