(fix): `ExtensionArray` + `DataArray` roundtrip (#9520) · pydata/xarray@e649e13 · GitHub
[go: up one dir, main page]

Skip to content

Commit e649e13

Browse files
authored
(fix): ExtensionArray + DataArray roundtrip (#9520)
* (fix): fix extension array + dataarray roundtrip
* (fix): satisfy mypy
* (refactor): move check out of `Variable.values`
* (fix): ensure `mypy` is happy with `values` typing
* (fix): setter with `mypy`
* (fix): remove case of `values`
1 parent 2b800ba commit e649e13

File tree

4 files changed, with 36 additions (+36) and 7 deletions (-7).

properties/test_pandas_roundtrip.py

Lines changed: 9 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -132,3 +132,12 @@ def test_roundtrip_pandas_dataframe_datetime(df) -> None:
132132
roundtripped.columns.name = "cols" # why?
133133
pd.testing.assert_frame_equal(df, roundtripped)
134134
xr.testing.assert_identical(dataset, roundtripped.to_xarray())
135+
136+
137+
def test_roundtrip_1d_pandas_extension_array() -> None:
138+
df = pd.DataFrame({"cat": pd.Categorical(["a", "b", "c"])})
139+
arr = xr.Dataset.from_dataframe(df)["cat"]
140+
roundtripped = arr.to_pandas()
141+
assert (df["cat"] == roundtripped).all()
142+
assert df["cat"].dtype == roundtripped.dtype
143+
xr.testing.assert_identical(arr, roundtripped.to_xarray())

xarray/core/dataarray.py

Lines changed: 14 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -47,6 +47,7 @@
4747
create_coords_with_default_indexes,
4848
)
4949
from xarray.core.dataset import Dataset
50+
from xarray.core.extension_array import PandasExtensionArray
5051
from xarray.core.formatting import format_item
5152
from xarray.core.indexes import (
5253
Index,
@@ -3857,7 +3858,11 @@ def to_pandas(self) -> Self | pd.Series | pd.DataFrame:
38573858
"""
38583859
# TODO: consolidate the info about pandas constructors and the
38593860
# attributes that correspond to their indexes into a separate module?
3860-
constructors = {0: lambda x: x, 1: pd.Series, 2: pd.DataFrame}
3861+
constructors: dict[int, Callable] = {
3862+
0: lambda x: x,
3863+
1: pd.Series,
3864+
2: pd.DataFrame,
3865+
}
38613866
try:
38623867
constructor = constructors[self.ndim]
38633868
except KeyError as err:
@@ -3866,7 +3871,14 @@ def to_pandas(self) -> Self | pd.Series | pd.DataFrame:
38663871
"pandas objects. Requires 2 or fewer dimensions."
38673872
) from err
38683873
indexes = [self.get_index(dim) for dim in self.dims]
3869-
return constructor(self.values, *indexes) # type: ignore[operator]
3874+
if isinstance(self._variable._data, PandasExtensionArray):
3875+
values = self._variable._data.array
3876+
else:
3877+
values = self.values
3878+
pandas_object = constructor(values, *indexes)
3879+
if isinstance(pandas_object, pd.Series):
3880+
pandas_object.name = self.name
3881+
return pandas_object
38703882

38713883
def to_dataframe(
38723884
self, name: Hashable | None = None, dim_order: Sequence[Hashable] | None = None

xarray/core/variable.py

Lines changed: 10 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -287,10 +287,13 @@ def as_compatible_data(
287287
if isinstance(data, DataArray):
288288
return cast("T_DuckArray", data._variable._data)
289289

290-
if isinstance(data, NON_NUMPY_SUPPORTED_ARRAY_TYPES):
290+
def convert_non_numpy_type(data):
291291
data = _possibly_convert_datetime_or_timedelta_index(data)
292292
return cast("T_DuckArray", _maybe_wrap_data(data))
293293

294+
if isinstance(data, NON_NUMPY_SUPPORTED_ARRAY_TYPES):
295+
return convert_non_numpy_type(data)
296+
294297
if isinstance(data, tuple):
295298
data = utils.to_0d_object_array(data)
296299

@@ -303,7 +306,11 @@ def as_compatible_data(
303306

304307
# we don't want nested self-described arrays
305308
if isinstance(data, pd.Series | pd.DataFrame):
306-
data = data.values # type: ignore[assignment]
309+
pandas_data = data.values
310+
if isinstance(pandas_data, NON_NUMPY_SUPPORTED_ARRAY_TYPES):
311+
return convert_non_numpy_type(pandas_data)
312+
else:
313+
data = pandas_data
307314

308315
if isinstance(data, np.ma.MaskedArray):
309316
mask = np.ma.getmaskarray(data)
@@ -540,7 +547,7 @@ def _dask_finalize(self, results, array_func, *args, **kwargs):
540547
return Variable(self._dims, data, attrs=self._attrs, encoding=self._encoding)
541548

542549
@property
543-
def values(self):
550+
def values(self) -> np.ndarray:
544551
"""The variable's data as a numpy.ndarray"""
545552
return _as_array_or_item(self._data)
546553

xarray/tests/test_variable.py

Lines changed: 3 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -2671,11 +2671,12 @@ def test_full_like(self) -> None:
26712671
)
26722672

26732673
expect = orig.copy(deep=True)
2674-
expect.values = [[2.0, 2.0], [2.0, 2.0]]
2674+
# see https://github.com/python/mypy/issues/3004 for why we need to ignore type
2675+
expect.values = [[2.0, 2.0], [2.0, 2.0]] # type: ignore[assignment]
26752676
assert_identical(expect, full_like(orig, 2))
26762677

26772678
# override dtype
2678-
expect.values = [[True, True], [True, True]]
2679+
expect.values = [[True, True], [True, True]] # type: ignore[assignment]
26792680
assert expect.dtype == bool
26802681
assert_identical(expect, full_like(orig, True, dtype=bool))
26812682

0 commit comments

Comments
 (0)
0