8000 Improve pandas and xarray conversion · matplotlib/matplotlib@1e27b8a · GitHub
[go: up one dir, main page]

Skip to content

Commit 1e27b8a

Browse files
Oscar Gustafssonoscargus
Oscar Gustafsson
authored andcommitted
Improve pandas and xarray conversion
1 parent 2e95730 commit 1e27b8a

File tree

9 files changed

+40
-18
lines changed

9 files changed

+40
-18
lines changed

environment.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,3 +56,4 @@ dependencies:
5656
- pytest-xdist
5757
- tornado
5858
- pytz
59+
- xarray

lib/matplotlib/axes/_axes.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7907,8 +7907,8 @@ def violinplot(self, dataset, positions=None, vert=True, widths=0.5,
79077907
"""
79087908

79097909
def _kde_method(X, coords):
7910-
if hasattr(X, 'values'): # support pandas.Series
7911-
X = X.values
7910+
# Unpack in case of e.g. Pandas or xarray object
7911+
X = cbook._unpack_to_numpy(X)
79127912
# fallback gracefully if the vector contains only one value
79137913
if np.all(X[0] == X):
79147914
return (X[0] == coords).astype(float)

lib/matplotlib/cbook/__init__.py

Lines changed: 19 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1377,15 +1377,8 @@ def _reshape_2D(X, name):
13771377
*name* is used to generate the error message for invalid inputs.
13781378
"""
13791379

1380-
# unpack if we have a values or to_numpy method.
1381-
try:
1382-
X = X.to_numpy()
1383-
except AttributeError:
1384-
try:
1385-
if isinstance(X.values, np.ndarray):
1386-
X = X.values
1387-
except AttributeError:
1388-
pass
1380+
# Unpack in case of e.g. Pandas or xarray object
1381+
X = _unpack_to_numpy(X)
13891382

13901383
# Iterate over columns for ndarrays.
13911384
if isinstance(X, np.ndarray):
@@ -2276,3 +2269,20 @@ def _picklable_class_constructor(mixin_class, fmt, attr_name, base_class):
22762269
factory = _make_class_factory(mixin_class, fmt, attr_name)
22772270
cls = factory(base_class)
22782271
return cls.__new__(cls)
2272+
2273+
2274+
def _unpack_to_numpy(x):
2275+
"""Internal helper to extract data from e.g. pandas and xarray objects."""
2276+
if isinstance(x, np.ndarray):
2277+
# If numpy, return directly
2278+
return x
2279+
if hasattr(x, 'to_numpy'):
2280+
# Assume that any function to_numpy() do actually return a numpy array
2281+
return x.to_numpy()
2282+
if hasattr(x, 'values'):
2283+
xtmp = x.values
2284+
# For example a dict has a 'values' attribute, but it is not a property
2285+
# so in this case we do not want to return a function
2286+
if isinstance(xtmp, np.ndarray):
2287+
return xtmp
2288+
return x

lib/matplotlib/dates.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -437,9 +437,8 @@ def date2num(d):
437437
The Gregorian calendar is assumed; this is not universal practice.
438438
For details see the module docstring.
439439
"""
440-
if hasattr(d, "values"):
441-
# this unpacks pandas series or dataframes...
442-
d = d.values
440+
# Unpack in case of e.g. Pandas or xarray object
441+
d = cbook._unpack_to_numpy(d)
443442

444443
# make an iterable, but save state to unpack later:
445444
iterable = np.iterable(d)

lib/matplotlib/testing/conftest.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -125,3 +125,10 @@ def pd():
125125
except ImportError:
126126
pass
127127
return pd
128+
129+
130+
@pytest.fixture
131+
def xr():
132+
"""Fixture to import xarray."""
133+
xr = pytest.importorskip('xarray')
134+
return xr

lib/matplotlib/tests/conftest.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
11
from matplotlib.testing.conftest import (mpl_test_settings,
22
pytest_configure, pytest_unconfigure,
3-
pd)
3+
pd, xr)

lib/matplotlib/tests/test_cbook.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -681,8 +681,11 @@ def test_reshape2d_pandas(pd):
681681
for x, xnew in zip(X.T, Xnew):
682682
np.testing.assert_array_equal(x, xnew)
683683

684+
685+
def test_reshape2d_xarray(xr):
686+
# separate to allow the rest of the tests to run if no pandas...
684687
X = np.arange(30).reshape(10, 3)
685-
x = pd.DataFrame(X, columns=["a", "b", "c"])
688+
x = xr.DataArray(X, dims=["x", "y"])
686689
Xnew = cbook._reshape_2D(x, 'x')
687690
# Need to check each row because _reshape_2D returns a list of arrays:
688691
for x, xnew in zip(X.T, Xnew):

lib/matplotlib/units.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -180,8 +180,9 @@ class Registry(dict):
180180

181181
def get_converter(self, x):
182182
"""Get the converter interface instance for *x*, or None."""
183-
if hasattr(x, "values"):
184-
x = x.values # Unpack pandas Series and DataFrames.
183+
# Unpack in case of e.g. Pandas or xarray object
184+
x = cbook._unpack_to_numpy(x)
185+
185186
if isinstance(x, np.ndarray):
186187
# In case x in a masked array, access the underlying data (only its
187188
# type matters). If x is a regular ndarray, getdata() just returns

requirements/testing/extra.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,3 +7,4 @@ pandas!=0.25.0
77
pikepdf
88
pytz
99
pywin32; sys.platform == 'win32'
10+
xarray

0 commit comments

Comments
 (0)
0