8000 Improve pandas and xarray conversion · matplotlib/matplotlib@7b51044 · GitHub
[go: up one dir, main page]

Skip to content

Commit 7b51044

Browse files
Oscar Gustafssonoscargus
Oscar Gustafsson
authored andcommitted
Improve pandas and xarray conversion
1 parent 0359832 commit 7b51044

File tree

7 files changed

+60
-21
lines changed

7 files changed

+60
-21
lines changed

lib/matplotlib/axes/_axes.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7907,8 +7907,8 @@ def violinplot(self, dataset, positions=None, vert=True, widths=0.5,
79077907
"""
79087908

79097909
def _kde_method(X, coords):
7910-
if hasattr(X, 'values'): # support pandas.Series
7911-
X = X.values
7910+
# Unpack in case of e.g. Pandas or xarray object
7911+
X = cbook._unpack_to_numpy(X)
79127912
# fallback gracefully if the vector contains only one value
79137913
if np.all(X[0] == X):
79147914
return (X[0] == coords).astype(float)

lib/matplotlib/cbook/__init__.py

Lines changed: 21 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1311,9 +1311,8 @@ def _to_unmasked_float_array(x):
13111311

13121312
def _check_1d(x):
13131313
"""Convert scalars to 1D arrays; pass-through arrays as is."""
1314-
if hasattr(x, 'to_numpy'):
1315-
# if we are given an object that creates a numpy, we should use it...
1316-
x = x.to_numpy()
1314+
# Unpack in case of e.g. Pandas or xarray object
1315+
x = _unpack_to_numpy(x)
13171316
if not hasattr(x, 'shape') or len(x.shape) < 1:
13181317
return np.atleast_1d(x)
13191318
else:
@@ -1332,15 +1331,8 @@ def _reshape_2D(X, name):
13321331
*name* is used to generate the error message for invalid inputs.
13331332
"""
13341333

1335-
# unpack if we have a values or to_numpy method.
1336-
try:
1337-
X = X.to_numpy()
1338-
except AttributeError:
1339-
try:
1340-
if isinstance(X.values, np.ndarray):
1341-
X = X.values
1342-
except AttributeError:
1343-
pass
1334+
# Unpack in case of e.g. Pandas or xarray object
1335+
X = _unpack_to_numpy(X)
13441336

13451337
# Iterate over columns for ndarrays.
13461338
if isinstance(X, np.ndarray):
@@ -2231,3 +2223,20 @@ def _picklable_class_constructor(mixin_class, fmt, attr_name, base_class):
22312223
factory = _make_class_factory(mixin_class, fmt, attr_name)
22322224
cls = factory(base_class)
22332225
return cls.__new__(cls)
2226+
2227+
2228+
def _unpack_to_numpy(x):
2229+
"""Internal helper to extract data from e.g. pandas and xarray objects."""
2230+
if isinstance(x, np.ndarray):
2231+
# If numpy, return directly
2232+
return x
2233+
if hasattr(x, 'to_numpy'):
2234+
# Assume that any function to_numpy() do actually return a numpy array
2235+
return x.to_numpy()
2236+
if hasattr(x, 'values'):
2237+
xtmp = x.values
2238+
# For example a dict has a 'values' attribute, but it is not a property
2239+
# so in this case we do not want to return a function
2240+
if isinstance(xtmp, np.ndarray):
2241+
return xtmp
2242+
return x

lib/matplotlib/dates.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -437,9 +437,8 @@ def date2num(d):
437437
The Gregorian calendar is assumed; this is not universal practice.
438438
For details see the module docstring.
439439
"""
440-
if hasattr(d, "values"):
441-
# this unpacks pandas series or dataframes...
442-
d = d.values
440+
# Unpack in case of e.g. Pandas or xarray object
441+
d = cbook._unpack_to_numpy(d)
443442

444443
# make an iterable, but save state to unpack later:
445444
iterable = np.iterable(d)

lib/matplotlib/testing/conftest.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -125,3 +125,10 @@ def pd():
125125
except ImportError:
126126
pass
127127
return pd
128+
129+
130+
@pytest.fixture
131+
def xr():
132+
"""Fixture to import xarray."""
133+
xr = pytest.importorskip('xarray')
134+
return xr

lib/matplotlib/tests/conftest.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
11
from matplotlib.testing.conftest import (mpl_test_settings,
22
pytest_configure, pytest_unconfigure,
3-
pd)
3+
pd, xr)

lib/matplotlib/tests/test_cbook.py

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -680,14 +680,37 @@ def test_reshape2d_pandas(pd):
680680
for x, xnew in zip(X.T, Xnew):
681681
np.testing.assert_array_equal(x, xnew)
682682

683+
684+
def test_reshape2d_xarray(xr):
685+
# separate to allow the rest of the tests to run if no xarray...
683686
X = np.arange(30).reshape(10, 3)
684-
x = pd.DataFrame(X, columns=["a", "b", "c"])
687+
x = xr.DataArray(X, dims=["x", "y"])
685688
Xnew = cbook._reshape_2D(x, 'x')
686689
# Need to check each row because _reshape_2D returns a list of arrays:
687690
for x, xnew in zip(X.T, Xnew):
688691
np.testing.assert_array_equal(x, xnew)
689692

690693

694+
def test_index_of_pandas(pd):
695+
# separate to allow the rest of the tests to run if no pandas...
696+
X = np.arange(30).reshape(10, 3)
697+
x = pd.DataFrame(X, columns=["a", "b", "c"])
698+
Idx, Xnew = cbook.index_of(x)
699+
np.testing.assert_array_equal(X, Xnew)
700+
IdxRef = np.arange(10)
701+
np.testing.assert_array_equal(Idx, IdxRef)
702+
703+
704+
def test_index_of_xarray(xr):
705+
# separate to allow the rest of the tests to run if no xarray...
706+
X = np.arange(30).reshape(10, 3)
707+
x = xr.DataArray(X, dims=["x", "y"])
708+
Idx, Xnew = cbook.index_of(x)
709+
np.testing.assert_array_equal(X, Xnew)
710+
IdxRef = np.arange(10)
711+
np.testing.assert_array_equal(Idx, IdxRef)
712+
713+
691714
def test_contiguous_regions():
692715
a, b, c = 3, 4, 5
693716
# Starts and ends with True

lib/matplotlib/units.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -180,8 +180,9 @@ class Registry(dict):
180180

181181
def get_converter(self, x):
182182
"""Get the converter interface instance for *x*, or None."""
183-
if hasattr(x, "values"):
184-
x = x.values # Unpack pandas Series and DataFrames.
183+
# Unpack in case of e.g. Pandas or xarray object
184+
x = cbook._unpack_to_numpy(x)
185+
185186
if isinstance(x, np.ndarray):
186187
# In case x in a masked array, access the underlying data (only its
187188
# type matters). If x is a regular ndarray, getdata() just returns

0 commit comments

Comments
 (0)
0