8000 Refactor onto a 'join_colormaps' function · matplotlib/matplotlib@688ff01 · GitHub
[go: up one dir, main page]

Skip to content

Commit 688ff01

Browse files
committed
Refactor onto a 'join_colormaps' function
1 parent 79bfcfb commit 688ff01

File tree

1 file changed

+56
-10
lines changed

1 file changed

+56
-10
lines changed

lib/matplotlib/colors.py

Lines changed: 56 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,60 @@ def __delitem__(self, key, value):
9090
_colors_full_map = _ColorMapping(_colors_full_map)
9191

9292

93+
def join_colormaps(cmaps, fractions=None, name=None, N=None):
94+
"""
95+
Join a series of colormaps into one.
96+
97+
Parameters
98+
----------
99+
cmaps : a sequence of colormaps to be joined (length M)
100+
fractions : a sequence of floats or ints (length M)
101+
The fraction of the new colormap that each colormap should
102+
occupy. These are normalized so they sum to 1. By default, the
103+
fractions are the ``N`` attribute of each cmap.
104+
name : str, optional
105+
The name for the joined colormap. This defaults to
106+
``cmap[0].name + '+' + cmap[1].name + '+' ...``
107+
N : int
108+
The number of entries in the color map. This defaults to the
109+
sum of the ``N`` attributes of the cmaps.
110+
111+
Returns
112+
-------
113+
ListedColormap
114+
The joined colormap.
115+
116+
Examples
117+
--------
118+
import matplotlib.pyplat as plt
119+
cmap1 = plt.get_cmap('viridis', 128)
120+
cmap2 = plt.get_cmap('plasma_r', 64)
121+
cmap3 = plt.get_cmap('jet', 64)
122+
123+
joined_cmap = join_colormaps((cmap1, cmap2, cmap3))
124+
125+
See Also
126+
--------
127+
128+
:meth:`Colorbar.join` and :meth:`Colorbar.__add__` : a method
129+
implementation of this functionality
130+
"""
131+
if N is None:
132+
N = np.sum([cm.N for cm in cmaps])
133+
if fractions is None:
134+
fractions = [cm.N for cm in cmaps]
135+
fractions = np.array(fractions) / np.sum(fractions, dtype='float')
136+
if name is None:
137+
name = ""
138+
for cm in cmaps:
139+
name += cm.name + '+'
140+
name.rstrip('+')
141+
maps = [cm(np.linspace(0, 1, int(N * frac)))
142+
for cm, frac in zip(cmaps, fractions)]
143+
# N is set by len of the vstack'd array:
144+
return ListedColormap(np.vstack(maps), name, )
145+
146+
93147
def get_named_colors_mapping():
94148
"""Return the global mapping of names to named colors.
95149
"""
@@ -631,18 +685,10 @@ def join(self, other, frac_self=None, name=None, N=None):
631685
632686
joined_cmap = cmap1 + cmap2
633687
"""
634-
if N is None:
635-
N = self.N + other.N
636688
if frac_self is None:
637689
frac_self = self.N / (other.N + self.N)
638-
if name is None:
639-
name = '{}+{}'.format(self.name, other.name)
640-
if not (0 < frac_self and frac_self < 1):
641-
raise ValueError("frac_self must be in the interval (0.0, 1.0)")
642-
map0 = self(np.linspace(0, 1, int(N * frac_self)))
643-
map1 = other(np.linspace(0, 1, int(N * (1 - frac_self))))
644-
# N is set by len of the vstack'd array:
645-
return ListedColormap(np.vstack((map0, map1)), name, )
690+
fractions = [frac_self, 1 - frac_self]
691+
return join_colormaps([self, other], fractions, name, N)
646692

647693
__add__ = join
648694

0 commit comments

Comments
 (0)
0