10000 Merge pull request #22560 from oscargus/pandasconversion · matplotlib/matplotlib@709fba8 · GitHub
[go: up one dir, main page]

Skip to content

Commit 709fba8

Browse files
authored
Merge pull request #22560 from oscargus/pandasconversion
Improve pandas/xarray/... conversion
2 parents a7b7260 + 56af810 commit 709fba8

File tree

8 files changed

+61
-21
lines changed

8 files changed

+61
-21
lines changed

lib/matplotlib/axes/_axes.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7944,8 +7944,8 @@ def violinplot(self, dataset, positions=None, vert=True, widths=0.5,
79447944
"""
79457945

79467946
def _kde_method(X, coords):
7947-
if hasattr(X, 'values'): # support pandas.Series
7948-
X = X.values
7947+
# Unpack in case of e.g. Pandas or xarray object
7948+
X = cbook._unpack_to_numpy(X)
79497949
# fallback gracefully if the vector contains only one value
79507950
if np.all(X[0] == X):
79517951
return (X[0] == coords).astype(float)

lib/matplotlib/cbook/__init__.py

Lines changed: 21 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1331,9 +1331,8 @@ def _to_unmasked_float_array(x):
13311331

13321332
def _check_1d(x):
13331333
"""Convert scalars to 1D arrays; pass-through arrays as is."""
1334-
if hasattr(x, 'to_numpy'):
1335-
# if we are given an object that creates a numpy, we should use it...
1336-
x = x.to_numpy()
1334+
# Unpack in case of e.g. Pandas or xarray object
1335+
x = _unpack_to_numpy(x)
13371336
if not hasattr(x, 'shape') or len(x.shape) < 1:
13381337
return np.atleast_1d(x)
13391338
else:
@@ -1352,15 +1351,8 @@ def _reshape_2D(X, name):
13521351
*name* is used to generate the error message for invalid inputs.
13531352
"""
13541353

1355-
# unpack if we have a values or to_numpy method.
1356-
try:
1357-
X = X.to_numpy()
1358-
except AttributeError:
1359-
try:
1360-
if isinstance(X.values, np.ndarray):
1361-
X = X.values
1362-
except AttributeError:
1363-
pass
1354+
# Unpack in case of e.g. Pandas or xarray object
1355+
X = _unpack_to_numpy(X)
13641356

13651357
# Iterate over columns for ndarrays.
13661358
if isinstance(X, np.ndarray):
@@ -2251,3 +2243,20 @@ def _picklable_class_constructor(mixin_class, fmt, attr_name, base_class):
22512243
factory = _make_class_factory(mixin_class, fmt, attr_name)
22522244
cls = factory(base_class)
22532245
return cls.__new__(cls)
2246+
2247+
2248+
def _unpack_to_numpy(x):
2249+
"""Internal helper to extract data from e.g. pandas and xarray objects."""
2250+
if isinstance(x, np.ndarray):
2251+
# If numpy, return directly
2252+
return x
2253+
if hasattr(x, 'to_numpy'):
2254+
# Assume that any function to_numpy() do actually return a numpy array
2255+
return x.to_numpy()
2256+
if hasattr(x, 'values'):
2257+
xtmp = x.values
2258+
# For example a dict has a 'values' attribute, but it is not a property
2259+
# so in this case we do not want to return a function
2260+
if isinstance(xtmp, np.ndarray):
2261+
return xtmp
2262+
return x

lib/matplotlib/dates.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -434,9 +434,8 @@ def date2num(d):
434434
The Gregorian calendar is assumed; this is not universal practice.
435435
For details see the module docstring.
436436
"""
437-
if hasattr(d, "values"):
438-
# this unpacks pandas series or dataframes...
439-
d = d.values
437+
# Unpack in case of e.g. Pandas or xarray object
438+
d = cbook._unpack_to_numpy(d)
440439

441440
# make an iterable, but save state to unpack later:
442441
iterable = np.iterable(d)

lib/matplotlib/testing/conftest.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -125,3 +125,10 @@ def pd():
125125
except ImportError:
126126
pass
127127
return pd
128+
129+
130+
@pytest.fixture
131+
def xr():
132+
"""Fixture to import xarray."""
133+
xr = pytest.importorskip('xarray')
134+
return xr

lib/matplotlib/tests/conftest.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
11
from matplotlib.testing.conftest import (mpl_test_settings,
22
pytest_configure, pytest_unconfigure,
3-
pd)
3+
pd, xr)

lib/matplotlib/tests/test_cbook.py

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -701,14 +701,37 @@ def test_reshape2d_pandas(pd):
701701
for x, xnew in zip(X.T, Xnew):
702702
np.testing.assert_array_equal(x, xnew)
703703

704+
705+
def test_reshape2d_xarray(xr):
706+
# separate to allow the rest of the tests to run if no xarray...
704707
X = np.arange(30).reshape(10, 3)
705-
x = pd.DataFrame(X, columns=["a", "b", "c"])
708+
x = xr.DataArray(X, dims=["x", "y"])
706709
Xnew = cbook._reshape_2D(x, 'x')
707710
# Need to check each row because _reshape_2D returns a list of arrays:
708711
for x, xnew in zip(X.T, Xnew):
709712
np.testing.assert_array_equal(x, xnew)
710713

711714

715+
def test_index_of_pandas(pd):
716+
# separate to allow the rest of the tests to run if no pandas...
717+
X = np.arange(30).reshape(10, 3)
718+
x = pd.DataFrame(X, columns=["a", "b", "c"])
719+
Idx, Xnew = cbook.index_of(x)
720+
np.testing.assert_array_equal(X, Xnew)
721+
IdxRef = np.arange(10)
722+
np.testing.assert_array_equal(Idx, IdxRef)
723+
724+
725+
def test_index_of_xarray(xr):
726+
# separate to allow the rest of the tests to run if no xarray...
727+
X = np.arange(30).reshape(10, 3)
728+
x = xr.DataArray(X, dims=["x", "y"])
729+
Idx, Xnew = cbook.index_of(x)
730+
np.testing.assert_array_equal(X, Xnew)
731+
IdxRef = np.arange(10)
732+
np.testing.assert_array_equal(Idx, IdxRef)
733+
734+
712735
def test_contiguous_regions():
713736
a, b, c = 3, 4, 5
714737
# Starts and ends with True

lib/matplotlib/units.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -180,8 +180,9 @@ class Registry(dict):
180180

181181
def get_converter(self, x):
182182
"""Get the converter interface instance for *x*, or None."""
183-
if hasattr(x, "values"):
184-
x = x.values # Unpack pandas Series and DataFrames.
183+
# Unpack in case of e.g. Pandas or xarray object
184+
x = cbook._unpack_to_numpy(x)
185+
185186
if isinstance(x, np.ndarray):
186187
# In case x in a masked array, access the underlying data (only its
187188
# type matters). If x is a regular ndarray, getdata() just returns

requirements/testing/extra.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,3 +7,4 @@ pandas!=0.25.0
77
pikepdf
88
pytz
99
pywin32; sys.platform == 'win32'
10+
xarray

0 commit comments

Comments
 (0)
0