8000 Merge pull request #27562 from QuLogic/no-alpha-copy · matplotlib/matplotlib@8542398 · GitHub
[go: up one dir, main page]

Skip to content

Commit 8542398

Browse files
authored
Merge pull request #27562 from QuLogic/no-alpha-copy
Avoid an extra copy/resample if imshow input has no alpha
2 parents 7f8b9b3 + 3d6a349 commit 8542398

File tree

2 files changed

+34
-4
lines changed

2 files changed

+34
-4
lines changed

lib/matplotlib/image.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -555,11 +555,15 @@ def _make_image(self, A, in_bbox, out_bbox, clip_bbox, magnification=1.0,
555555
if A.ndim == 2: # _interpolation_stage == 'rgba'
556556
self.norm.autoscale_None(A)
557557
A = self.to_rgba(A)
558-
if A.shape[2] == 3:
559-
A = _rgb_to_rgba(A)
560558
alpha = self._get_scalar_alpha()
561-
output_alpha = _resample( # resample alpha channel
562-
self, A[..., 3], out_shape, t, alpha=alpha)
559+
if A.shape[2] == 3:
560+
# No need to resample alpha or make a full array; NumPy will expand
561+
# this out and cast to uint8 if necessary when it's assigned to the
562+
# alpha channel below.
563+
output_alpha = (255 * alpha) if A.dtype == np.uint8 else alpha
564+
else:
565+
output_alpha = _resample( # resample alpha channel
566+
self, A[..., 3], out_shape, t, alpha=alpha)
563567
output = _resample( # resample rgb channels
564568
self, _rgb_to_rgba(A[..., :3]), out_shape, t, alpha=alpha)
565569
output[..., 3] = output_alpha # recombine rgb and alpha

lib/matplotlib/tests/test_image.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -268,6 +268,32 @@ def test_image_alpha():
268268
ax3.imshow(Z, alpha=0.5, interpolation='nearest')
269269

270270

271+
@mpl.style.context('mpl20')
272+
@check_figures_equal(extensions=['png'])
273+
def test_imshow_alpha(fig_test, fig_ref):
274+
np.random.seed(19680801)
275+
276+
rgbf = np.random.rand(6, 6, 3)
277+
rgbu = np.uint8(rgbf * 255)
278+
((ax0, ax1), (ax2, ax3)) = fig_test.subplots(2, 2)
279+
ax0.imshow(rgbf, alpha=0.5)
280+
ax1.imshow(rgbf, alpha=0.75)
281+
ax2.imshow(rgbu, alpha=0.5)
282+
ax3.imshow(rgbu, alpha=0.75)
283+
284+
rgbaf = np.concatenate((rgbf, np.ones((6, 6, 1))), axis=2)
285+
rgbau = np.concatenate((rgbu, np.full((6, 6, 1), 255, np.uint8)), axis=2)
286+
((ax0, ax1), (ax2, ax3)) = fig_ref.subplots(2, 2)
287+
rgbaf[:, :, 3] = 0.5
288+
ax0.imshow(rgbaf)
289+
rgbaf[:, :, 3] = 0.75
290+
ax1.imshow(rgbaf)
291+
rgbau[:, :, 3] = 127
292+
ax2.imshow(rgbau)
293+
rgbau[:, :, 3] = 191
294+
ax3.imshow(rgbau)
295+
296+
271297
def test_cursor_data():
272298
from matplotlib.backend_bases import MouseEvent
273299

0 commit comments

Comments
 (0)
0