diff --git a/lib/matplotlib/scale.py b/lib/matplotlib/scale.py index 449368d8206c..f0182347f7e8 100644 --- a/lib/matplotlib/scale.py +++ b/lib/matplotlib/scale.py @@ -485,18 +485,14 @@ def __init__(self, base, linthresh, linscale): self._log_base = np.log(base) def transform_non_affine(self, a): - sign = np.sign(a) - masked = ma.masked_inside(a, - -self.linthresh, - self.linthresh, - copy=False) - log = sign * self.linthresh * ( - self._linscale_adj + - ma.log(np.abs(masked) / self.linthresh) / self._log_base) - if masked.mask.any(): - return ma.where(masked.mask, a * self._linscale_adj, log) - else: - return log + abs_a = np.abs(a) + with np.errstate(divide="ignore", invalid="ignore"): + out = np.sign(a) * self.linthresh * ( + self._linscale_adj + + np.log(abs_a / self.linthresh) / self._log_base) + inside = abs_a <= self.linthresh + out[inside] = a[inside] * self._linscale_adj + return out def inverted(self): return InvertedSymmetricalLogTransform(self.base, self.linthresh, @@ -519,16 +515,14 @@ def __init__(self, base, linthresh, linscale): self._linscale_adj = (linscale / (1.0 - self.base ** -1)) def transform_non_affine(self, a): - sign = np.sign(a) - masked = ma.masked_inside(a, -self.invlinthresh, - self.invlinthresh, copy=False) - exp = sign * self.linthresh * ( - ma.power(self.base, (sign * (masked / self.linthresh)) - - self._linscale_adj)) - if masked.mask.any(): - return ma.where(masked.mask, a / self._linscale_adj, exp) - else: - return exp + abs_a = np.abs(a) + with np.errstate(divide="ignore", invalid="ignore"): + out = np.sign(a) * self.linthresh * ( + np.power(self.base, + abs_a / self.linthresh - self._linscale_adj)) + inside = abs_a <= self.invlinthresh + out[inside] = a[inside] / self._linscale_adj + return out def inverted(self): return SymmetricalLogTransform(self.base, diff --git a/lib/matplotlib/tests/test_scale.py b/lib/matplotlib/tests/test_scale.py index c0ca6230ed33..3eee976a3e7f 100644 --- a/lib/matplotlib/tests/test_scale.py +++ b/lib/matplotlib/tests/test_scale.py @@ -1,9 +1,11 @@ from matplotlib.cbook import MatplotlibDeprecationWarning import matplotlib.pyplot as plt -from matplotlib.scale import Log10Transform, InvertedLog10Transform +from matplotlib.scale import (Log10Transform, InvertedLog10Transform, + SymmetricalLogTransform) from matplotlib.testing.decorators import check_figures_equal, image_comparison import numpy as np +from numpy.testing import assert_allclose import io import platform import pytest @@ -22,6 +24,33 @@ def test_log_scales(fig_test, fig_ref): ax_ref.plot(xlim, [24.1, 24.1], 'b') +def test_symlog_mask_nan(): + # Use a transform round-trip to verify that the forward and inverse + # transforms work, and that they respect nans and/or masking. + slt = SymmetricalLogTransform(10, 2, 1) + slti = slt.inverted() + + x = np.arange(-1.5, 5, 0.5) + out = slti.transform_non_affine(slt.transform_non_affine(x)) + assert_allclose(out, x) + assert type(out) == type(x) + + x[4] = np.nan + out = slti.transform_non_affine(slt.transform_non_affine(x)) + assert_allclose(out, x) + assert type(out) == type(x) + + x = np.ma.array(x) + out = slti.transform_non_affine(slt.transform_non_affine(x)) + assert_allclose(out, x) + assert type(out) == type(x) + + x[3] = np.ma.masked + out = slti.transform_non_affine(slt.transform_non_affine(x)) + assert_allclose(out, x) + assert type(out) == type(x) + + @image_comparison(['logit_scales.png'], remove_text=True) def test_logit_scales(): fig, ax = plt.subplots()