8000 Use typing aliases for compressors · zarr-developers/zarr-python@5515892 · GitHub
[go: up one dir, main page]

Skip to content

Commit 5515892

Browse files
committed
Use typing aliases for compressors
1 parent 965267d commit 5515892

File tree

4 files changed

+45
-25
lines changed

4 files changed

+45
-25
lines changed

src/zarr/api/asynchronous.py

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,14 @@
99
import numpy.typing as npt
1010
from typing_extensions import deprecated
1111

12-
from zarr.core.array import Array, AsyncArray, create_array, from_array, get_array_metadata
12+
from zarr.core.array import (
13+
Array,
14+
AsyncArray,
15+
CompressorLike,
16+
create_array,
17+
from_array,
18+
get_array_metadata,
19+
)
1320
from zarr.core.array_spec import ArrayConfig, ArrayConfigLike, ArrayConfigParams
1421
from zarr.core.buffer import NDArrayLike
1522
from zarr.core.common import (
@@ -837,9 +844,7 @@ async def create(
837844
*, # Note: this is a change from v2
838845
chunks: ChunkCoords | int | None = None, # TODO: v2 allowed chunks=True
839846
dtype: npt.DTypeLike | None = None,
840-
compressor: dict[str, JSON]
841-
| Literal["default"]
842-
| None = "default", # TODO: default and type change
847+
compressor: CompressorLike = "auto",
843848
fill_value: Any | None = 0, # TODO: need type
844849
order: MemoryOrder | None = None,
845850
store: str | StoreLike | None = None,
@@ -992,7 +997,7 @@ async def create(
992997
dtype = parse_dtype(dtype, zarr_format)
993998
if not filters:
994999
filters = _default_filters(dtype)
995-
if compressor == "default":
1000+
if compressor == "auto":
9961001
compressor = _default_compressor(dtype)
9971002
elif zarr_format == 3 and chunk_shape is None: # type: ignore[redundant-expr]
9981003
if chunks is not None:

src/zarr/api/synchronous.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
import zarr.api.asynchronous as async_api
88
import zarr.core.array
99
from zarr._compat import _deprecate_positional_args
10-
from zarr.core.array import Array, AsyncArray
10+
from zarr.core.array import Array, AsyncArray, CompressorLike
1111
from zarr.core.group import Group
1212
from zarr.core.sync import sync
1313
from zarr.core.sync_group import create_hierarchy
@@ -598,9 +598,7 @@ def create(
598598
*, # Note: this is a change from v2
599599
chunks: ChunkCoords | int | bool | None = None,
600600
dtype: npt.DTypeLike | None = None,
601-
compressor: dict[str, JSON]
602-
| Literal["default"]
603-
| None = "default", # TODO: default and type change
8000 601+
compressor: CompressorLike = "auto",
604602
fill_value: Any | None = 0, # TODO: need type
605603
order: MemoryOrder | None = None,
606604
store: str | StoreLike | None = None,

src/zarr/core/array.py

Lines changed: 26 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,7 @@
101101
T_ArrayMetadata,
102102
)
103103
from zarr.core.metadata.v2 import (
104+
CompressorLikev2,
104105
_default_compressor,
105106
_default_filters,
106107
parse_compressor,
@@ -301,7 +302,7 @@ async def create(
301302
dimension_separator: Literal[".", "/"] | None = None,
302303
order: MemoryOrder | None = None,
303304
filters: list[dict[str, JSON]] | None = None,
304-
compressor: dict[str, JSON] | Literal["default"] | None = "default",
305+
compressor: CompressorLikev2 | Literal["auto"] = "auto",
305306
# runtime
306307
overwrite: bool = False,
307308
data: npt.ArrayLike | None = None,
@@ -392,7 +393,7 @@ async def create(
392393
dimension_separator: Literal[".", "/"] | None = None,
393394
order: MemoryOrder | None = None,
394395
filters: list[dict[str, JSON]] | None = None,
395-
compressor: dict[str, JSON] | Literal["default"] | None = "default",
396+
compressor: CompressorLike = "auto",
396397
# runtime
397398
overwrite: bool = False,
398399
data: npt.ArrayLike | None = None,
@@ -427,7 +428,7 @@ async def create(
427428
dimension_separator: Literal[".", "/"] | None = None,
428429
order: MemoryOrder | None = None,
429430
filters: list[dict[str, JSON]] | None = None,
430-
compressor: dict[str, JSON] | Literal["default"] | None = "default",
431+
compressor: CompressorLike = "auto",
431432
# runtime
432433
overwrite: bool = False,
433434
data: npt.ArrayLike | None = None,
@@ -568,7 +569,7 @@ async def _create(
568569
dimension_separator: Literal[".", "/"] | None = None,
569570
order: MemoryOrder | None = None,
570571
filters: list[dict[str, JSON]] | None = None,
571-
compressor: dict[str, JSON] | Literal["default"] | None = "default",
572+
compressor: CompressorLike = "auto",
572573
# runtime
573574
overwrite: bool = False,
574575
data: npt.ArrayLike | None = None,
@@ -602,7 +603,7 @@ async def _create(
602603
raise ValueError(
603604
"filters cannot be used for arrays with zarr_format 3. Use array-to-array codecs instead."
604605
)
605-
if compressor != "default":
606+
if compressor != "auto":
606607
raise ValueError(
607608
"compressor cannot be used for arrays with zarr_format 3. Use bytes-to-bytes codecs instead."
608609
)
@@ -766,7 +767,7 @@ def _create_metadata_v2(
766767
dimension_separator: Literal[".", "/"] | None = None,
767768
fill_value: float | None = None,
768769
filters: Iterable[dict[str, JSON] | numcodecs.abc.Codec] | None = None,
769-
compressor: dict[str, JSON] | numcodecs.abc.Codec | None = None,
770+
compressor: CompressorLikev2 = None,
770771
attributes: dict[str, JSON] | None = None,
771772
) -> ArrayV2Metadata:
772773
if dimension_separator is None:
@@ -807,7 +808,7 @@ async def _create_v2(
807808
dimension_separator: Literal[".", "/"] | None = None,
808809
fill_value: float | None = None,
809810
filters: Iterable[dict[str, JSON] | numcodecs.abc.Codec] | None = None,
810-
compressor: dict[str, JSON] | numcodecs.abc.Codec | Literal["default"] | None = None,
811+
compressor: CompressorLike = "auto",
811812
attributes: dict[str, JSON] | None = None,
812813
overwrite: bool = False,
813814
) -> AsyncArray[ArrayV2Metadata]:
@@ -819,8 +820,16 @@ async def _create_v2(
819820
else:
820821
await ensure_no_existing_node(store_path, zarr_format=2)
821822

822-
if compressor == "default":
823-
compressor = _default_compressor(dtype)
823+
compressor_parsed: CompressorLikev2
824+
if compressor == "auto":
825+
compressor_parsed = _default_compressor(dtype)
826+
elif isinstance(compressor, BytesBytesCodec):
827+
raise ValueError(
828+
"Cannot use a BytesBytesCodec as a compressor for zarr v2 arrays. "
829+
"Use a numcodecs codec directly instead."
830+
)
831+
else:
832+
compressor_parsed = compressor
824833

825834
metadata = cls._create_metadata_v2(
826835
shape=shape,
@@ -830,7 +839,7 @@ async def _create_v2(
830839
dimension_separator=dimension_separator,
831840
fill_value=fill_value,
832841
filters=filters,
833-
compressor=compressor,
842+
compressor=compressor_parsed,
834843
attributes=attributes,
835844
)
836845

@@ -1752,7 +1761,7 @@ def create(
17521761
dimension_separator: Literal[".", "/"] | None = None,
17531762
order: MemoryOrder | None = None,
17541763
filters: list[dict[str, JSON]] | None = None,
1755-
compressor: dict[str, JSON] | Literal["default"] | None = "default",
1764+
compressor: CompressorLike = "auto",
17561765
# runtime
17571766
overwrite: bool = False,
17581767
config: ArrayConfigLike | None = None,
@@ -1881,7 +1890,7 @@ def _create(
18811890
dimension_separator: Literal[".", "/"] | None = None,
18821891
order: MemoryOrder | None = None,
18831892
filters: list[dict[str, JSON]] | None = None,
1884-
compressor: dict[str, JSON] | Literal["default"] | None = "default",
1893+
compressor: CompressorLike = "auto",
18851894
# runtime
18861895
overwrite: bool = False,
18871896
config: ArrayConfigLike | None = None,
@@ -3788,7 +3797,11 @@ def _get_default_codecs(
37883797
| Literal["auto"]
37893798
| None
37903799
)
3791-
CompressorLike: TypeAlias = dict[str, JSON] | BytesBytesCodec | numcodecs.abc.Codec | None
3800+
# Union of acceptable types for v2 and v3 compressors
3801+
CompressorLike: TypeAlias = (
3802+
dict[str, JSON] | BytesBytesCodec | numcodecs.abc.Codec | Literal["auto"] | None
3803+
)
3804+
37923805
CompressorsLike: TypeAlias = (
37933806
Iterable[dict[str, JSON] | BytesBytesCodec | numcodecs.abc.Codec]
37943807
| dict[str, JSON]

src/zarr/core/metadata/v2.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from collections.abc import Iterable, Sequence
66
from enum import Enum
77
from functools import cached_property
8-
from typing import TYPE_CHECKING, Any, TypedDict, cast
8+
from typing import TYPE_CHECKING, Any, TypeAlias, TypedDict, cast
99

1010
import numcodecs.abc
1111

@@ -43,6 +43,10 @@ class ArrayV2MetadataDict(TypedDict):
4343
attributes: dict[str, JSON]
4444

4545

46+
# Union of acceptable types for v2 compressors
47+
CompressorLikev2: TypeAlias = dict[str, JSON] | numcodecs.abc.Codec | None
48+
49+
4650
@dataclass(frozen=True, kw_only=True)
4751
class ArrayV2Metadata(Metadata):
4852
shape: ChunkCoords
@@ -52,7 +56,7 @@ class ArrayV2Metadata(Metadata):
5256
order: MemoryOrder = "C"
5357
filters: tuple[numcodecs.abc.Codec, ...] | None = None
5458
dimension_separator: Literal[".", "/"] = "."
55-
compressor: numcodecs.abc.Codec | None = None
59+
compressor: CompressorLikev2
5660
attributes: dict[str, JSON] = field(default_factory=dict)
5761
zarr_format: Literal[2] = field(init=False, default=2)
5862

@@ -65,7 +69,7 @@ def __init__(
6569
fill_value: Any,
6670
order: MemoryOrder,
6771
dimension_separator: Literal[".", "/"] = ".",
68-
compressor: numcodecs.abc.Codec | dict[str, JSON] | None = None,
72+
compressor: CompressorLikev2 = None,
6973
filters: Iterable[numcodecs.abc.Codec | dict[str, JSON]] | None = None,
7074
attributes: dict[str, JSON] | None = None,
7175
) -> None:

0 commit comments

Comments
 (0)
0