8000 Scatter: make "c" argument handling more consistent. · matplotlib/matplotlib@5a5c2f6 · GitHub
[go: up one dir, main page]

Skip to content

Commit 5a5c2f6

Browse files
committed
Scatter: make "c" argument handling more consistent.
Closes #12735.
1 parent bbd10ba commit 5a5c2f6

File tree

2 files changed

+36
-30
lines changed

lib/matplotlib/axes/_axes.py

Lines changed: 15 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -4133,7 +4133,7 @@ def dopatch(xs, ys, **kwargs):
41334133
medians=medians, fliers=fliers, means=means)
41344134

41354135
@staticmethod
4136-
def _parse_scatter_color_args(c, edgecolors, kwargs, xshape, yshape,
4136+
def _parse_scatter_color_args(c, edgecolors, kwargs, xsize,
41374137
get_next_color_func):
41384138
"""
41394139
Helper function to process color related arguments of `.Axes.scatter`.
@@ -4163,8 +4163,8 @@ def _parse_scatter_color_args(c, edgecolors, kwargs, xshape, yshape,
41634163
Additional kwargs. If these keys exist, we pop and process them:
41644164
'facecolors', 'facecolor', 'edgecolor', 'color'
41654165
Note: The dict is modified by this function.
4166-
xshape, yshape : tuple of int
4167-
The shape of the x and y arrays passed to `.Axes.scatter`.
4166+
xsize : int
4167+
The size of the x and y arrays passed to `.Axes.scatter`.
41684168
get_next_color_func : callable
41694169
A callable that returns a color. This color is used as facecolor
41704170
if no other color is provided.
@@ -4187,9 +4187,6 @@ def _parse_scatter_color_args(c, edgecolors, kwargs, xshape, yshape,
41874187
The edgecolor specification.
41884188
41894189
"""
4190-
xsize = functools.reduce(operator.mul, xshape, 1)
4191-
ysize = functools.reduce(operator.mul, yshape, 1)
4192-
41934190
facecolors = kwargs.pop('facecolors', None)
41944191
facecolors = kwargs.pop('facecolor', facecolors)
41954192
edgecolors = kwargs.pop('edgecolor', edgecolors)
@@ -4241,9 +4238,9 @@ def _parse_scatter_color_args(c, edgecolors, kwargs, xshape, yshape,
42414238
else:
42424239
try: # First, does 'c' look suitable for value-mapping?
42434240
c_array = np.asanyarray(c, dtype=float)
4244-
n_elem = c_array.shape[0]
4245-
if c_array.shape in [xshape, yshape]:
4246-
c = np.ma.ravel(c_array)
4241+
n_elem = c_array.size
4242+
if n_elem == xsize:
4243+
c = c_array.ravel()
42474244
else:
42484245
if c_array.shape in ((3,), (4,)):
42494246
_log.warning(
@@ -4263,18 +4260,17 @@ def _parse_scatter_color_args(c, edgecolors, kwargs, xshape, yshape,
42634260
try: # Then is 'c' acceptable as PathCollection facecolors?
42644261
colors = mcolors.to_rgba_array(c)
42654262
n_elem = colors.shape[0]
4266-
if colors.shape[0] not in (0, 1, xsize, ysize):
4263+
if colors.shape[0] not in (0, 1, xsize):
42674264 8000
# NB: remember that a single color is also acceptable.
42684265
# Besides *colors* will be an empty array if c == 'none'.
42694266
valid_shape = False
42704267
raise ValueError
4271-
except ValueError:
4268+
except (ValueError, TypeError):
42724269
if not valid_shape: # but at least one conversion succeeded.
42734270
raise ValueError(
42744271
"'c' argument has {nc} elements, which is not "
4275-
"acceptable for use with 'x' with size {xs}, "
4276-
"'y' with size {ys}."
4277-
.format(nc=n_elem, xs=xsize, ys=ysize)
4272+
"acceptable for use with 'x' and 'y' with size {xs}."
4273+
.format(nc=n_elem, xs=xsize)
42784274
)
42794275
else:
42804276
# Both the mapping *and* the RGBA conversion failed: pretty
@@ -4301,7 +4297,7 @@ def scatter(self, x, y, s=None, c=None, marker=None, cmap=None, norm=None,
43014297
43024298
Parameters
43034299
----------
4304-
x, y : array_like, shape (n, )
4300+
x, y : scalar or array_like, shape (n, )
43054301
The data positions.
43064302
43074303
s : scalar or array_like, shape (n, ), optional
@@ -4313,8 +4309,8 @@ def scatter(self, x, y, s=None, c=None, marker=None, cmap=None, norm=None,
43134309
43144310
- A single color format string.
43154311
- A sequence of color specifications of length n.
4316-
- A sequence of n numbers to be mapped to colors using *cmap* and
4317-
*norm*.
4312+
- A scalar or sequence of n numbers to be mapped to colors using
4313+
*cmap* and *norm*.
43184314
- A 2-D array in which the rows are RGB or RGBA.
43194315
43204316
Note that *c* should not be a single numeric RGB or RGBA sequence
@@ -4403,7 +4399,7 @@ def scatter(self, x, y, s=None, c=None, marker=None, cmap=None, norm=None,
44034399
plotted.
44044400
44054401
* Fundamentally, scatter works with 1-D arrays; *x*, *y*, *s*, and *c*
4406-
may be input as 2-D arrays, but within scatter they will be
4402+
may be input as N-D arrays, but within scatter they will be
44074403
flattened. The exception is *c*, which will be flattened only if its
44084404
size matches the size of *x* and *y*.
44094405
@@ -4416,7 +4412,6 @@ def scatter(self, x, y, s=None, c=None, marker=None, cmap=None, norm=None,
44164412

44174413
# np.ma.ravel yields an ndarray, not a masked array,
44184414
# unless its argument is a masked array.
4419-
xshape, yshape = np.shape(x), np.shape(y)
44204415
x = np.ma.ravel(x)
44214416
y = np.ma.ravel(y)
44224417
if x.size != y.size:
@@ -4429,7 +4424,7 @@ def scatter(self, x, y, s=None, c=None, marker=None, cmap=None, norm=None,
44294424

44304425
c, colors, edgecolors = \
44314426
self._parse_scatter_color_args(
4432-
c, edgecolors, kwargs, xshape, yshape,
4427+
c, edgecolors, kwargs, x.size,
44334428
get_next_color_func=self._get_patches_for_fill.get_next_color)
44344429

44354430
if plotnonfinite and colors is None:

lib/matplotlib/tests/test_axes.py

Lines changed: 21 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1890,6 +1890,21 @@ def test_scatter_no_invalid_color(self, fig_test, fig_ref):
18901890
ax = fig_ref.subplots()
18911891
ax.scatter([0, 2], [0, 2], c=[1, 2], s=[1, 3], cmap=cmap)
18921892

1893+
@check_figures_equal(extensions=["png"])
1894+
def test_scatter_single_point(self, fig_test, fig_ref):
1895+
ax = fig_test.subplots()
1896+
ax.scatter(1, 1, c=1)
1897+
ax = fig_ref.subplots()
1898+
ax.scatter([1], [1], c=[1])
1899+
1900+
@check_figures_equal(extensions=["png"])
1901+
def test_scatter_different_shapes(self, fig_test, fig_ref):
1902+
x = np.arange(10)
1903+
ax = fig_test.subplots()
1904+
ax.scatter(x, x.reshape(2, 5), c=x.reshape(5, 2))
1905+
ax = fig_ref.subplots()
1906+
ax.scatter(x.reshape(5, 2), x, c=x.reshape(2, 5))
1907+
18931908
# Parameters for *test_scatter_c*. NB: assuming that the
18941909
# scatter plot will have 4 elements. The tuple scheme is:
18951910
# (*c* parameter case, exception regexp key or None if no exception)
@@ -1946,7 +1961,7 @@ def get_next_color():
19461961

19471962
from matplotlib.axes import Axes
19481963

1949-
xshape = yshape = (4,)
1964+
xsize = 4
19501965

19511966
# Additional checking of *c* (introduced in #11383).
19521967
REGEXP = {
@@ -1956,21 +1971,18 @@ def get_next_color():
19561971

19571972
if re_key is None:
19581973
Axes._parse_scatter_color_args(
1959-
c=c_case, edgecolors="black", kwargs={},
1960-
xshape=xshape, yshape=yshape,
1974+
c=c_case, edgecolors="black", kwargs={}, xsize=xsize,
19611975
get_next_color_func=get_next_color)
19621976
else:
19631977
with pytest.raises(ValueError, match=REGEXP[re_key]):
19641978
Axes._parse_scatter_color_args(
1965-
c=c_case, edgecolors="black", kwargs={},
1966-
xshape=xshape, yshape=yshape,
1979+
c=c_case, edgecolors="black", kwargs={}, xsize=xsize,
19671980
get_next_color_func=get_next_color)
19681981

19691982

1970-
def _params(c=None, xshape=(2,), yshape=(2,), **kwargs):
1983+
def _params(c=None, xsize=2, **kwargs):
19711984
edgecolors = kwargs.pop('edgecolors', None)
1972-
return (c, edgecolors, kwargs if kwargs is not None else {},
1973-
xshape, yshape)
1985+
return (c, edgecolors, kwargs if kwargs is not None else {}, xsize)
19741986
_result = namedtuple('_result', 'c, colors')
19751987

19761988

@@ -2022,8 +2034,7 @@ def get_next_color():
20222034
c = kwargs.pop('c', None)
20232035
edgecolors = kwargs.pop('edgecolors', None)
20242036
_, _, result_edgecolors = \
2025-
Axes._parse_scatter_color_args(c, edgecolors, kwargs,
2026-
xshape=(2,), yshape=(2,),
2037+
Axes._parse_scatter_color_args(c, edgecolors, kwargs, xsize=2,
20272038
get_next_color_func=get_next_color)
20282039
assert result_edgecolors == expected_edgecolors
20292040

0 commit comments

Comments
 (0)
0