8000 Merge pull request #22876 from meeseeksmachine/auto-backport-of-pr-22… · matplotlib/matplotlib@3420565 · GitHub
[go: up one dir, main page]

Skip to content

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Appearance settings

Commit 3420565

Browse files
authored
Merge pull request #22876 from meeseeksmachine/auto-backport-of-pr-22560-on-v3.5.x
Backport PR #22560 on branch v3.5.x (Improve pandas/xarray/... conversion)
2 parents 0fe45ab + 1e23977 commit 3420565

File tree

7 files changed

+60
-21
lines changed

7 files changed

+60
-21
lines changed

lib/matplotlib/axes/_axes.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7933,8 +7933,8 @@ def violinplot(self, dataset, positions=None, vert=True, widths=0.5,
79337933
"""
79347934

79357935
def _kde_method(X, coords):
7936-
if hasattr(X, 'values'): # support pandas.Series
7937-
X = X.values
7936+
# Unpack in case of e.g. Pandas or xarray object
7937+
X = cbook._unpack_to_numpy(X)
79387938
# fallback gracefully if the vector contains only one value
79397939
if np.all(X[0] == X):
79407940
return (X[0] == coords).astype(float)

lib/matplotlib/cbook/__init__.py

Lines changed: 21 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1300,9 +1300,8 @@ def _to_unmasked_float_array(x):
13001300

13011301
def _check_1d(x):
13021302
"""Convert scalars to 1D arrays; pass-through arrays as is."""
1303-
if hasattr(x, 'to_numpy'):
1304-
# if we are given an object that creates a numpy, we should use it...
1305-
x = x.to_numpy()
1303+
# Unpack in case of e.g. Pandas or xarray object
1304+
x = _unpack_to_numpy(x)
13061305
if not hasattr(x, 'shape') or len(x.shape) < 1:
13071306
return np.atleast_1d(x)
13081307
else:
@@ -1321,15 +1320,8 @@ def _reshape_2D(X, name):
13211320
*name* is used to generate the error message for invalid inputs.
13221321
"""
13231322

1324-
# unpack if we have a values or to_numpy method.
1325-
try:
1326-
X = X.to_numpy()
1327-
except AttributeError:
1328-
try:
1329-
if isinstance(X.values, np.ndarray):
1330-
X = X.values
1331-
except AttributeError:
1332-
pass
1323+
# Unpack in case of e.g. Pandas or xarray object
1324+
X = _unpack_to_numpy(X)
13331325

13341326
# Iterate over columns for ndarrays.
13351327
if isinstance(X, np.ndarray):
@@ -2275,3 +2267,20 @@ def _picklable_class_constructor(mixin_class, fmt, attr_name, base_class):
22752267
factory = _make_class_factory(mixin_class, fmt, attr_name)
22762268
cls = factory(base_class)
22772269
return cls.__new__(cls)
2270+
2271+
2272+
def _unpack_to_numpy(x):
2273+
"""Internal helper to extract data from e.g. pandas and xarray objects."""
2274+
if isinstance(x, np.ndarray):
2275+
# If numpy, return directly
2276+
return x
2277+
if hasattr(x, 'to_numpy'):
2278+
# Assume that any function to_numpy() do actually return a numpy array
2279+
return x.to_numpy()
2280+
if hasattr(x, 'values'):
2281+
xtmp = x.values
2282+
# For example a dict has a 'values' attribute, but it is not a property
2283+
# so in this case we do not want to return a function
2284+
if isinstance(xtmp, np.ndarray):
2285+
return xtmp
2286+
return x

lib/matplotlib/dates.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -423,9 +423,8 @@ def date2num(d):
423423
The Gregorian calendar is assumed; this is not universal practice.
424424
For details see the module docstring.
425425
"""
426-
if hasattr(d, "values"):
427-
# this unpacks pandas series or dataframes...
428-
d = d.values
426+
# Unpack in case of e.g. Pandas or xarray object
427+
d = cbook._unpack_to_numpy(d)
429428

430429
# make an iterable, but save state to unpack later:
431430
iterable = np.iterable(d)

lib/matplotlib/testing/conftest.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -125,3 +125,10 @@ def pd():
125125
except ImportError:
126126
pass
127127
return pd
128+
129+
130+
@pytest.fixture
131+
def xr():
132+
"""Fixture to import xarray."""
133+
xr = pytest.importorskip('xarray')
134+
return xr

lib/matplotlib/tests/conftest.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
11
from matplotlib.testing.conftest import (mpl_test_settings,
22
pytest_configure, pytest_unconfigure,
3-
pd)
3+
pd, xr)

lib/matplotlib/tests/test_cbook.py

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -668,14 +668,37 @@ def test_reshape2d_pandas(pd):
668668
for x, xnew in zip(X.T, Xnew):
669669
np.testing.assert_array_equal(x, xnew)
670670

671+
672+
def test_reshape2d_xarray(xr):
673+
# separate to allow the rest of the tests to run if no xarray...
671674
X = np.arange(30).reshape(10, 3)
672-
x = pd.DataFrame(X, columns=["a", "b", "c"])
675+
x = xr.DataArray(X, dims=["x", "y"])
673676
Xnew = cbook._reshape_2D(x, 'x')
674677
# Need to check each row because _reshape_2D returns a list of arrays:
675678
for x, xnew in zip(X.T, Xnew):
676679
np.testing.assert_array_equal(x, xnew)
677680

678681

682+
def test_index_of_pandas(pd):
683+
# separate to allow the rest of the tests to run if no pandas...
684+
X = np.arange(30).reshape(10, 3)
685+
x = pd.DataFrame(X, columns=["a", "b", "c"])
686+
Idx, Xnew = cbook.index_of(x)
687+
np.testing.assert_array_equal(X, Xnew)
688+
IdxRef = np.arange(10)
689+
np.testing.assert_array_equal(Idx, IdxRef)
690+
691+
692+
def test_index_of_xarray(xr):
693+
# separate to allow the rest of the tests to run if no xarray...
694+
X = np.arange(30).reshape(10, 3)
695+
x = xr.DataArray(X, dims=["x", "y"])
696+
Idx, Xnew = cbook.index_of(x)
697+
np.testing.assert_array_equal(X, Xnew)
698+
IdxRef = np.arange(10)
699+
np.testing.assert_array_equal(Idx, IdxRef)
700+
701+
679702
def test_contiguous_regions():
680703
a, b, c = 3, 4, 5
681704
# Starts and ends with True

lib/matplotlib/units.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -180,8 +180,9 @@ class Registry(dict):
180180

181181
def get_converter(self, x):
182182
"""Get the converter interface instance for *x*, or None."""
183-
if hasattr(x, "values"):
184-
x = x.values # Unpack pandas Series and DataFrames.
183+
# Unpack in case of e.g. Pandas or xarray object
184+
x = cbook._unpack_to_numpy(x)
185+
185186
if isinstance(x, np.ndarray):
186187
# In case x in a masked array, access the underlying data (only its
187188
# type matters). If x is a regular ndarray, getdata() just returns

0 commit comments

Comments
 (0)
0