8000 Merge pull request #5556 from tacaswell/fix_pandas_indexing · matplotlib/matplotlib@aeafb64 · GitHub
[go: up one dir, main page]

Skip to content

Commit aeafb64

Browse files
committed
Merge pull request #5556 from tacaswell/fix_pandas_indexing
FIX: pandas indexing error
2 parents 781000c + bdaaf59 commit aeafb64

File tree

6 files changed

+98
-48
lines changed

6 files changed

+98
-48
lines changed

lib/matplotlib/axes/_axes.py

Lines changed: 55 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -2331,11 +2331,11 @@ def broken_barh(self, xranges, yrange, **kwargs):
23312331
"""
23322332
# process the unit information
23332333
if len(xranges):
2334-
xdata = six.next(iter(xranges))
2334+
xdata = cbook.safe_first_element(xranges)
23352335
else:
23362336
xdata = None
23372337
if len(yrange):
2338-
ydata = six.next(iter(yrange))
2338+
ydata = cbook.safe_first_element(yrange)
23392339
else:
23402340
ydata = None
23412341
self._process_unit_info(xdata=xdata,
@@ -3016,7 +3016,7 @@ def xywhere(xs, ys, mask):
30163016

30173017
if ecolor is None:
30183018
if l0 is None and 'color' in self._get_lines._prop_keys:
3019-
ecolor = six.next(self._get_lines.prop_cycler)['color']
3019+
ecolor = next(self._get_lines.prop_cycler)['color']
30203020
else:
30213021
ecolor = l0.get_color()
30223022

@@ -5875,6 +5875,41 @@ def hist(self, x, bins=10, range=None, normed=False, weights=None,
58755875
.. plot:: mpl_examples/statistics/histogram_demo_features.py
58765876
58775877
"""
5878+
def _normalize_input(inp, ename='input'):
5879+
"""Normalize 1 or 2d input into list of np.ndarray or
5880+
a single 2D np.ndarray.
5881+
5882+
Parameters
5883+
----------
5884+
inp : iterable
5885+
ename : str, optional
5886+
Name to use in ValueError if `inp` can not be normalized
5887+
5888+
"""
5889+
if (isinstance(x, np.ndarray) or
5890+
not iterable(cbook.safe_first_element(inp))):
5891+
# TODO: support masked arrays;
5892+
inp = np.asarray(inp)
5893+
if inp.ndim == 2:
5894+
# 2-D input with columns as datasets; switch to rows
5895+
inp = inp.T
5896+
elif inp.ndim == 1:
5897+
# new view, single row
5898+
inp = inp.reshape(1, inp.shape[0])
5899+
else:
5900+
raise ValueError(
5901+
"{ename} must be 1D or 2D".format(ename=ename))
5902+
if inp.shape[1] < inp.shape[0]:
5903+
warnings.warn(
5904+
'2D hist input should be nsamples x nvariables;\n '
5905+
'this looks transposed '
5906+
'(shape is %d x %d)' % inp.shape[::-1])
5907+
else:
5908+
# multiple hist with data of different length
5909+
inp = [np.asarray(xi) for xi in inp]
5910+
5911+
return inp
5912+
58785913
if not self._hold:
58795914
self.cla()
58805915

@@ -5918,58 +5953,34 @@ def hist(self, x, bins=10, range=None, normed=False, weights=None,
59185953
input_empty = len(flat) == 0
59195954

59205955
# Massage 'x' for processing.
5921-
# NOTE: Be sure any changes here is also done below to 'weights'
59225956
if input_empty:
59235957
x = np.array([[]])
5924-
elif isinstance(x, np.ndarray) or not iterable(x[0]):
5925-
# TODO: support masked arrays;
5926-
x = np.asarray(x)
5927-
if x.ndim == 2:
5928-
x = x.T # 2-D input with columns as datasets; switch to rows
5929-
elif x.ndim == 1:
5930-
x = x.re A92E shape(1, x.shape[0]) # new view, single row
5931-
else:
5932-
raise ValueError("x must be 1D or 2D")
5933-
if x.shape[1] < x.shape[0]:
5934-
warnings.warn(
5935-
'2D hist input should be nsamples x nvariables;\n '
5936-
'this looks transposed (shape is %d x %d)' % x.shape[::-1])
59375958
else:
5938-
# multiple hist with data of different length
5939-
x = [np.asarray(xi) for xi in x]
5940-
5959+
x = _normalize_input(x, 'x')
59415960
nx = len(x) # number of datasets
59425961

5962+
# We need to do to 'weights' what was done to 'x'
5963+
if weights is not None:
5964+
w = _normalize_input(weights, 'weights')
5965+
else:
5966+
w = [None]*nx
5967+
5968+
if len(w) != nx:
5969+
raise ValueError('weights should have the same shape as x')
5970+
5971+
for xi, wi in zip(x, w):
5972+
if wi is not None and len(wi) != len(xi):
5973+
raise ValueError(
5974+
'weights should have the same shape as x')
5975+
59435976
if color is None and 'color' in self._get_lines._prop_keys:
5944-
color = [six.next(self._get_lines.prop_cycler)['color']
5977+
color = [next(self._get_lines.prop_cycler)['color']
59455978
for i in xrange(nx)]
59465979
else:
59475980
color = mcolors.colorConverter.to_rgba_array(color)
59485981
if len(color) != nx:
59495982
raise ValueError("color kwarg must have one color per dataset")
59505983

5951-
# We need to do to 'weights' what was done to 'x'
5952-
if weights is not None:
5953-
if isinstance(weights, np.ndarray) or not iterable(weights[0]):
5954-
w = np.array(weights)
5955-
if w.ndim == 2:
5956-
w = w.T
5957-
elif w.ndim == 1:
5958-
w.shape = (1, w.shape[0])
5959-
else:
5960-
raise ValueError("weights must be 1D or 2D")
5961-
else:
5962-
w = [np.asarray(wi) for wi in weights]
5963-
5964-
if len(w) != nx:
5965-
raise ValueError('weights should have the same shape as x')
5966-
for i in xrange(nx):
5967-
if len(w[i]) != len(x[i]):
5968-
raise ValueError(
5969-
'weights should have the same shape as x')
5970-
else:
5971-
w = [None]*nx
5972-
59735984
# Save the datalimits for the same reason:
59745985
_saved_bounds = self.dataLim.bounds
59755986

@@ -5985,7 +5996,7 @@ def hist(self, x, bins=10, range=None, normed=False, weights=None,
59855996
xmax = max(xmax, xi.max())
59865997
bin_range = (xmin, xmax)
59875998

5988-
#hist_kwargs = dict(range=range, normed=bool(normed))
5999+
# hist_kwargs = dict(range=range, normed=bool(normed))
59896000
# We will handle the normed kwarg within mpl until we
59906001
# get to the point of requiring numpy >= 1.5.
59916002
hist_kwargs = dict(range=bin_range)

lib/matplotlib/axes/_base.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@
3232
import matplotlib.image as mimage
3333
from matplotlib.offsetbox import OffsetBox
3434
from matplotlib.artist import allow_rasterization
35-
from matplotlib.cbook import iterable, index_of
35+
3636
from matplotlib.rcsetup import cycler
3737

3838
rcParams = matplotlib.rcParams

lib/matplotlib/backends/backend_pdf.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1846,7 +1846,7 @@ def draw_tex(self, gc, x, y, s, prop, angle, ismath='TeX!', mtext=None):
18461846
fontsize = prop.get_size_in_points()
18471847
dvifile = texmanager.make_dvi(s, fontsize)
18481848
dvi = dviread.Dvi(dvifile, 72)
1849-
page = six.next(iter(dvi))
1849+
page = next(iter(dvi))
18501850
dvi.close()
18511851

18521852
# Gather font information and do some setup for combining

lib/matplotlib/cbook.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
from matplotlib.externals import six
1313
from matplotlib.externals.six.moves import xrange, zip
1414
from itertools import repeat
15+
import collections
1516

1617
import datetime
1718
import errno
@@ -2536,6 +2537,13 @@ def index_of(y):
25362537
return np.arange(y.shape[0], dtype=float), y
25372538

25382539

2540+
def safe_first_element(obj):
2541+
if isinstance(obj, collections.Iterator):
2542+
raise RuntimeError("matplotlib does not support generators "
2543+
"as input")
2544+
return next(iter(obj))
2545+
2546+
25392547
def get_label(y, default_name):
25402548
try:
25412549
return y.name

lib/matplotlib/dates.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1561,8 +1561,8 @@ def default_units(x, axis):
15611561
x = x.ravel()
15621562

15631563
try:
1564-
x = x[0]
1565-
except (TypeError, IndexError):
1564+
x = cbook.safe_first_element(x)
1565+
except (TypeError, StopIteration):
15661566
pass
15671567

15681568
try:

lib/matplotlib/tests/test_axes.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
import io
88

99
from nose.tools import assert_equal, assert_raises, assert_false, assert_true
10+
from nose.plugins.skip import SkipTest
1011

1112
import datetime
1213

@@ -4183,6 +4184,36 @@ def test_broken_barh_empty():
41834184
ax.broken_barh([], (.1, .5))
41844185

41854186

4187+
@cleanup
4188+
def test_pandas_indexing_dates():
4189+
try:
4190+
import pandas as pd
4191+
except ImportError:
4192+
raise SkipTest("Pandas not installed")
4193+
4194+
dates = np.arange('2005-02', '2005-03', dtype='datetime64[D]')
4195+
values = np.sin(np.array(range(len(dates))))
4196+
df = pd.DataFrame({'dates': dates, 'values': values})
4197+
4198+
ax = plt.gca()
4199+
4200+
without_zero_index = df[np.array(df.index) % 2 == 1].copy()
4201+
ax.plot('dates', 'values', data=without_zero_index)
4202+
4203+
4204+
@cleanup
4205+
def test_pandas_indexing_hist():
4206+
try:
4207+
import pandas as pd
4208+
except ImportError:
4209+
raise SkipTest("Pandas not installed")
4210+
4211+
ser_1 = pd.Series(data=[1, 2, 2, 3, 3, 4, 4, 4, 4, 5])
4212+
ser_2 = ser_1.iloc[1:]
4213+
fig, axes = plt.subplots()
4214+
axes.hist(ser_2)
4215+
4216+
41864217
if __name__ == '__main__':
41874218
import nose
41884219
import sys

0 commit comments

Comments
 (0)
0