8000 Fixes based on feedback on review · matplotlib/matplotlib@69f9977 · GitHub
[go: up one dir, main page]

Skip to content

Commit 69f9977

Browse files
committed
Fixes based on feedback on review
1 parent e688f03 commit 69f9977

File tree

3 files changed

+47
-49
lines changed

3 files changed

+47
-49
lines changed

lib/matplotlib/colors.py

Lines changed: 32 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -728,14 +728,16 @@ def __call__(self, X, alpha=None, bytes=False):
728728
bytes : bool
729729
If False (default), the returned RGBA values will be floats in the
730730
interval ``[0, 1]`` otherwise they will be `numpy.uint8`\s in the
731-
interval ``[0, 255]``
731+
interval ``[0, 255]``.
732732
733733
Returns
734734
-------
735735
Tuple of RGBA values if X is scalar, otherwise an array of
736736
RGBA values with a shape of ``X.shape + (4, )``.
737737
"""
738738
rgba, mask = self._get_rgba_and_mask(X, alpha=alpha, bytes=bytes)
739+
if not np.iterable(X):
740+
rgba = tuple(rgba)
739741
return rgba
740742

741743
def _get_rgba_and_mask(self, X, alpha=None, bytes=False):
@@ -758,9 +760,9 @@ def _get_rgba_and_mask(self, X, alpha=None, bytes=False):
758760
759761
Returns
760762
-------
761-
(colors, mask), where color is a tuple of RGBA values if X is scalar,
762-
otherwise an array of RGBA values with a shape of ``X.shape + (4, )``,
763-
and mask is a boolean array.
763+
colors : array of RGBA values with a shape of ``X.shape + (4, )``.
764+
mask : boolean array with True where the input is ``np.nan`` or
765+
masked.
764766
"""
765767
if not self._isinit:
766768
self._init()
@@ -805,8 +807,6 @@ def _get_rgba_and_mask(self, X, alpha=None, bytes=False):
805807
if (lut[-1] == 0).all():
806808
rgba[mask_bad] = (0, 0, 0, 0)
807809

808-
if not np.iterable(X):
809-
rgba = tuple(rgba)
810810
return rgba, mask_bad
811811

812812
def __copy__(self):
@@ -1291,9 +1291,7 @@ def __init__(self, colormaps, combination_mode, name='multivariate colormap'):
12911291
" Colormap or valid strings.")
12921292

12931293
self._colormaps = colormaps
1294-
if combination_mode not in ['sRGB_add', 'sRGB_sub']:
1295-
raise ValueError("Combination_mode must be 'sRGB_add' or 'sRGB_sub',"
1296-
f" {combination_mode!r} is not allowed.")
1294+
_api.check_in_list(['sRGB_add', 'sRGB_sub'], combination_mode=combination_mode)
12971295
self._combination_mode = combination_mode
12981296
self.n_variates = len(colormaps)
12991297
self._rgba_bad = (0.0, 0.0, 0.0, 0.0) # If bad, don't paint anything.
@@ -1332,10 +1330,8 @@ def __call__(self, X, alpha=None, bytes=False, clip=True):
13321330
f'For the selected colormap the data must have a first dimension '
13331331
f'{len(self)}, not {len(X)}')
13341332
rgba, mask_bad = self[0]._get_rgba_and_mask(X[0], bytes=False)
1335-
rgba = np.asarray(rgba)
13361333
for c, xx in zip(self[1:], X[1:]):
13371334
sub_rgba, sub_mask_bad = c._get_rgba_and_mask(xx, bytes=False)
1338-
sub_rgba = np.asarray(sub_rgba)
13391335
rgba[..., :3] += sub_rgba[..., :3] # add colors
13401336
rgba[..., 3] *= sub_rgba[..., 3] # multiply alpha
13411337
mask_bad |= sub_mask_bad
@@ -1419,11 +1415,10 @@ def resampled(self, lutshape):
14191415
14201416
< F438 span class=pl-s> Parameters
14211417
----------
1422-
lutshape : tuple of ints or None
1423-
The tuple must be of length matching the number of variates,
1424-
and each entry is either an int or None.
1425-
If an int, the corresponding colorbar is resampled.
1426-
If None, the corresponding colorbar is not resampled.
1418+
lutshape : tuple of (`int`, `None`)
1419+
The tuple must have a length matching the number of variates.
1420+
For each element in the tuple, if `int`, the corresponding colorbar
1421+
is resampled, if `None`, the corresponding colorbar is not resampled.
14271422
14281423
Returns
14291424
-------
@@ -1440,22 +1435,24 @@ def resampled(self, lutshape):
14401435

14411436
def with_extremes(self, *, bad=None, under=None, over=None):
14421437
"""
1443-
Return a copy of the MultivarColormap, for which the colors for masked (*bad*)
1444-
values has been set and, low (*under*) and high (*over*) out-of-range values,
1445-
been set in the component colormaps. Note that *under* and *over* colors
1446-
are subject to the mixing rules determined by the *combination_mode*.
1438+
Return a copy of the `MultivarColormap` with modified out-of-range attributes.
1439+
1440+
The *bad* keyword modifies the copied `MultivarColormap` while *under* and
1441+
*over* modifies the attributes of the copied component colormaps.
1442+
Note that *under* and *over* colors are subject to the mixing rules determined
1443+
by the *combination_mode*.
14471444
14481445
Parameters
14491446
----------
14501447
bad : None or :mpltype:`color`
14511448
If Matplotlib color, the bad value is set accordingly in the copy
14521449
1453-
under : None or tuple of length matching the length of the MultivarColormap
1450+
under : None or tuple of :mpltype:`color`
14541451
If tuple, the `under` value of each component is s 10000 et with the values
14551452
from the tuple.
14561453
1457-
over : None or tuple of length matching the length of the MultivarColormap
1458-
If tuple, the `under` value of each component is set with the values
1454+
over : None or tuple of :mpltype:`color`
1455+
If tuple, the `over` value of each component is set with the values
14591456
from the tuple.
14601457
14611458
Returns
@@ -1528,7 +1525,6 @@ def __init__(self, N=256, M=256, shape='square', origin=(0, 0),
15281525
The number of RGB quantization levels along the first axis.
15291526
M : int
15301527
The number of RGB quantization levels along the second axis.
1531-
If None, M = N
15321528
shape: {'square', 'circle', 'ignore', 'circleignore'}
15331529
15341530
- 'square' each variate is clipped to [0,1] independently
@@ -1551,11 +1547,8 @@ def __init__(self, N=256, M=256, shape='square', origin=(0, 0),
15511547
self.name = name
15521548
self.N = int(N) # ensure that N is always int
15531549
self.M = int(M)
1554-
if shape in ['square', 'circle', 'ignore', 'circleignore']:
1555-
self._shape = shape
1556-
else:
1557-
raise ValueError("The shape must be a valid string, "
1558-
"'square', 'circle', 'ignore', or 'circleignore'")
1550+
_api.check_in_list(['square', 'circle', 'ignore', 'circleignore'], shape=shape)
1551+
self._shape = shape
15591552
self._rgba_bad = (0.0, 0.0, 0.0, 0.0) # If bad, don't paint anything.
15601553
self._rgba_outside = (1.0, 0.0, 1.0, 1.0)
15611554
self._isinit = False
@@ -1738,8 +1731,8 @@ def resampled(self, lutshape, transposed=False):
17381731
The tuple must be of length 2, and each entry is either an int or None.
17391732
17401733
- If an int, the corresponding axis is resampled.
1741-
- If -1, the axis is inverted
17421734
- If negative the corresponding axis is resampled in reverse
1735+
- If -1, the axis is inverted
17431736
- If 1 or None, the corresponding axis is not resampled.
17441737
17451738
transposed : bool
@@ -1815,9 +1808,10 @@ def transposed(self):
18151808

18161809
def with_extremes(self, *, bad=None, outside=None, shape=None, origin=None):
18171810
"""
1818-
Return a copy of the BivarColormap, for which the colors for masked (*bad*)
1819-
valuesand if shape = 'ignore' or 'circleignore', out-of-range *outside* values,
1820-
have been set accordingly.
1811+
Return a copy of the `BivarColormap` with modified attributes.
1812+
1813+
Note that the *outside* color is only relevantif `shape` = 'ignore'
1814+
or 'circleignore'.
18211815
18221816
Parameters
18231817
----------
@@ -1855,11 +1849,9 @@ def with_extremes(self, *, bad=None, outside=None, shape=None, origin=None):
18551849
if outside is not None:
18561850
new_cm._rgba_outside = to_rgba(outside)
18571851
if shape is not None:
1858-
if shape in ['square', 'circle', 'ignore', 'circleignore']:
1859-
new_cm._shape = shape
1860-
else:
1861-
raise ValueError("The shape must be a valid string, "
1862-
"'square', 'circle', 'ignore', or 'circleignore'")
1852+
_api.check_in_list(['square', 'circle', 'ignore', 'circleignore'],
1853+
shape=shape)
1854+
new_cm._shape = shape
18631855
if origin is not None:
18641856
new_cm._origin = (float(origin[0]), float(origin[1]))
18651857

@@ -2014,7 +2006,7 @@ def color_block(color):
20142006
'</div>'
20152007
'<div style="float: right;">'
20162008
f'bad {color_block(self.get_bad())}'
2017-
'</div>')
2009+
'</div></div>')
20182010

20192011
def copy(self):
20202012
"""Return a copy of the colormap."""
@@ -2052,6 +2044,7 @@ class SegmentedBivarColormap(BivarColormap):
20522044

20532045
def __init__(self, patch, N=256, shape='square', origin=(0, 0),
20542046
name='segmented bivariate colormap'):
2047+
_api.check_shape((None, None, 3), patch=patch)
20552048
self.patch = patch
20562049
super().__init__(N, N, shape, origin, name=name)
20572050

lib/matplotlib/colors.pyi

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -140,9 +140,8 @@ class ListedColormap(Colormap):
140140

141141
class MultivarColormap:
142142
name: str
143-
colormaps: list[Colormap]
144143
n_variates: int
145-
def __init__(self, colormaps: list[Colormap], combination_mode: str, name: str = ...) -> None: ...
144+
def __init__(self, colormaps: list[Colormap], combination_mode: Literal['sRGB_add', 'sRGB_sub'], name: str = ...) -> None: ...
146145
@overload
147146
def __call__(
148147
self, X: Sequence[Sequence[float]] | np.ndarray, alpha: ArrayLike | None = ..., bytes: bool = ..., clip: bool = ...
@@ -157,6 +156,7 @@ class MultivarColormap:
157156
) -> tuple[float, float, float, float] | np.ndarray: ...
158157
def copy(self) -> MultivarColormap: ...
159158
def __copy__(self) -> MultivarColormap: ...
159+
def __eq__(self, other: Any) -> bool: ...
160160
def __getitem__(self, item: int) -> Colormap: ...
161161
def __iter__(self) -> Iterator[Colormap]: ...
162162
def __len__(self) -> int: ...
@@ -179,7 +179,9 @@ class BivarColormap:
179179
N: int
180180
M: int
181181
n_variates: int
182-
def __init__(self, N: int = ..., M: int | None = ..., shape: str = ..., origin: Sequence[float] = ..., name: str = ...
182+
def __init__(
183+
self, N: int = ..., M: int | None = ..., shape: Literal['square', 'circle', 'ignore', 'circleignore'] = ...,
184+
origin: Sequence[float] = ..., name: str = ...
183185
) -> None: ...
184186
@overload
185187
def __call__(
@@ -199,12 +201,12 @@ class BivarColormap:
199201
def shape(self) -> str: ...
200202
@property
201203
def origin(self) -> tuple[float, float]: ...
204+
def copy(self) -> BivarColormap: ...
202205
def __copy__(self) -> BivarColormap: ...
203206
def __getitem__(self, item: int) -> Colormap: ...
204207
def __eq__(self, other: Any) -> bool: ...
205208
def get_bad(self) -> np.ndarray: ...
206209
def get_outside(self) -> np.ndarray: ...
207-
def copy(self) -> BivarColormap: ...
208210
def resampled(self, lutshape: Sequence[int | None], transposed: bool = ...) -> BivarColormap: ...
209211
def transposed(self) -> BivarColormap: ...
210212
def reversed(self, axis_0: bool = ..., axis_1: bool = ...) -> BivarColormap: ...
@@ -221,11 +223,14 @@ class BivarColormap:
221223

222224
class SegmentedBivarColormap(BivarColormap):
223225
def __init__(
224-
self, patch: np.ndarray, N: int = ..., shape: str = ..., origin: Sequence[float] = ..., name: str = ...
226+
self, patch: np.ndarray, N: int = ..., shape: Literal['square', 'circle', 'ignore', 'circleignore'] = ...,
227+
origin: Sequence[float] = ..., name: str = ...
225228
) -> None: ...
226229

227230
class BivarColormapFromImage(BivarColormap):
228-
def __init__(self, lut: np.ndarray, shape: str = ..., origin: Sequence[float] = ..., name: str = ...
231+
def __init__(
232+
self, lut: np.ndarray, shape: Literal['square', 'circle', 'ignore', 'circleignore'] = ...,
233+
origin: Sequence[float] = ..., name: str = ...
229234
) -> None: ...
230235

231236
class Normalize:

lib/matplotlib/tests/test_multivariate_colormaps.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313

1414
@image_comparison(["bivariate_cmap_shapes.png"])
1515
def test_bivariate_cmap_shapes():
16-
x_0 = np.repeat(np.linspace(-0.1, 1.1, 10, dtype='float32'), (10, 1))
16+
x_0 = np.repeat(np.linspace(-0.1, 1.1, 10, dtype='float32')[None, :], 10, axis=0)
1717
x_1 = x_0.T
1818

1919
fig, axes = plt.subplots(1, 4, figsize=(10, 2))
@@ -193,7 +193,7 @@ def test_multivar_cmap_call():
193193

194194
def test_multivar_bad_mode():
195195
cmap = mpl.multivar_colormaps['2VarSubA']
196-
with pytest.raises(ValueError, match="Combination_mode must be 'sRGB_add' or"):
196+
with pytest.raises(ValueError, match="is not a valid value for"):
197197
cmap = mpl.colors.MultivarColormap(cmap[:], 'bad')
198198

199199

@@ -359,11 +359,11 @@ def test_bivar_cmap_bad_shape():
359359
cmap = mpl.bivar_colormaps['BiCone']
360360
_ = cmap.lut
361361
with pytest.raises(ValueError,
362-
match="shape must be a valid string"):
362+
match="is not a valid value for shape"):
363363
cmap.with_extremes(shape='bad_shape')
364364

365365
with pytest.raises(ValueError,
366-
match="shape must be a valid string"):
366+
match="is not a valid value for shape"):
367367
mpl.colors.BivarColormapFromImage(np.ones((3, 3, 4)),
368368
shape='bad_shape')
369369

0 commit comments

Comments
 (0)
0