diff --git a/lib/matplotlib/axis.py b/lib/matplotlib/axis.py index 97d10ae37225..3400d4b67472 100644 --- a/lib/matplotlib/axis.py +++ b/lib/matplotlib/axis.py @@ -1565,8 +1565,8 @@ def have_units(self): return self.converter is not None or self.units is not None def convert_units(self, x): - # If x is already a number, doesn't need converting - if munits.ConversionInterface.is_numlike(x): + # If x is natively supported by Matplotlib, doesn't need converting + if munits.ConversionInterface.is_natively_supported(x): return x if self.converter is None: diff --git a/lib/matplotlib/tests/test_axes.py b/lib/matplotlib/tests/test_axes.py index 210d470636c2..f08f45b67e9b 100644 --- a/lib/matplotlib/tests/test_axes.py +++ b/lib/matplotlib/tests/test_axes.py @@ -11,6 +11,7 @@ import numpy as np from numpy import ma from cycler import cycler +from decimal import Decimal import pytest import warnings @@ -1477,6 +1478,62 @@ def test_bar_tick_label_multiple_old_alignment(): align='center') +@check_figures_equal(extensions=["png"]) +def test_bar_decimal_center(fig_test, fig_ref): + ax = fig_test.subplots() + x0 = [1.5, 8.4, 5.3, 4.2] + y0 = [1.1, 2.2, 3.3, 4.4] + x = [Decimal(x) for x in x0] + y = [Decimal(y) for y in y0] + # Test image - vertical, align-center bar chart with Decimal() input + ax.bar(x, y, align='center') + # Reference image + ax = fig_ref.subplots() + ax.bar(x0, y0, align='center') + + +@check_figures_equal(extensions=["png"]) +def test_barh_decimal_center(fig_test, fig_ref): + ax = fig_test.subplots() + x0 = [1.5, 8.4, 5.3, 4.2] + y0 = [1.1, 2.2, 3.3, 4.4] + x = [Decimal(x) for x in x0] + y = [Decimal(y) for y in y0] + # Test image - horizontal, align-center bar chart with Decimal() input + ax.barh(x, y, height=[0.5, 0.5, 1, 1], align='center') + # Reference image + ax = fig_ref.subplots() + ax.barh(x0, y0, height=[0.5, 0.5, 1, 1], align='center') + + +@check_figures_equal(extensions=["png"]) +def test_bar_decimal_width(fig_test, fig_ref): + x = [1.5, 8.4, 5.3, 4.2] + y = [1.1, 2.2, 3.3, 4.4] + w0 = [0.7, 1.45, 1, 2] + w = [Decimal(i) for i in w0] + # Test image - vertical bar chart with Decimal() width + ax = fig_test.subplots() + ax.bar(x, y, width=w, align='center') + # Reference image + ax = fig_ref.subplots() + ax.bar(x, y, width=w0, align='center') + + +@check_figures_equal(extensions=["png"]) +def test_barh_decimal_height(fig_test, fig_ref): + x = [1.5, 8.4, 5.3, 4.2] + y = [1.1, 2.2, 3.3, 4.4] + h0 = [0.7, 1.45, 1, 2] + h = [Decimal(i) for i in h0] + # Test image - horizontal bar chart with Decimal() height + ax = fig_test.subplots() + ax.barh(x, y, height=h, align='center') + # Reference image + ax = fig_ref.subplots() + ax.barh(x, y, height=h0, align='center') + + def test_bar_color_none_alpha(): ax = plt.gca() rects = ax.bar([1, 2], [2, 4], alpha=0.3, color='none', edgecolor='r') @@ -1819,6 +1876,21 @@ def test_scatter_2D(self): fig, ax = plt.subplots() ax.scatter(x, y, c=z, s=200, edgecolors='face') + @check_figures_equal(extensions=["png"]) + def test_scatter_decimal(self, fig_test, fig_ref): + x0 = np.array([1.5, 8.4, 5.3, 4.2]) + y0 = np.array([1.1, 2.2, 3.3, 4.4]) + x = np.array([Decimal(i) for i in x0]) + y = np.array([Decimal(i) for i in y0]) + c = ['r', 'y', 'b', 'lime'] + s = [24, 15, 19, 29] + # Test image - scatter plot with Decimal() input + ax = fig_test.subplots() + ax.scatter(x, y, c=c, s=s) + # Reference image + ax = fig_ref.subplots() + ax.scatter(x0, y0, c=c, s=s) + def test_scatter_color(self): # Try to catch cases where 'c' kwarg should have been used. with pytest.raises(ValueError): @@ -5965,6 +6037,18 @@ def test_plot_columns_cycle_deprecation(): plt.plot(np.zeros((2, 2)), np.zeros((2, 3))) +@check_figures_equal(extensions=["png"]) +def test_plot_decimal(fig_test, fig_ref): + x0 = np.arange(-10, 10, 0.3) + y0 = [5.2 * x ** 3 - 2.1 * x ** 2 + 7.34 * x + 4.5 for x in x0] + x = [Decimal(i) for i in x0] + y = [Decimal(i) for i in y0] + # Test image - line plot with Decimal input + fig_test.subplots().plot(x, y) + # Reference image + fig_ref.subplots().plot(x0, y0) + + # pdf and svg tests fail using travis' old versions of gs and inkscape. @check_figures_equal(extensions=["png"]) def test_markerfacecolor_none_alpha(fig_test, fig_ref): diff --git a/lib/matplotlib/units.py b/lib/matplotlib/units.py index 4cbac7c226ab..0f55b286bc21 100644 --- a/lib/matplotlib/units.py +++ b/lib/matplotlib/units.py @@ -45,6 +45,8 @@ def default_units(x, axis): from numbers import Number import numpy as np +from numpy import ma +from decimal import Decimal from matplotlib import cbook @@ -132,6 +134,58 @@ def is_numlike(x): else: return isinstance(x, Number) + @staticmethod + def is_natively_supported(x): + """ + Return whether *x* is of a type that Matplotlib natively supports or + *x* is array of objects of such types. + """ + # Matplotlib natively supports all number types except Decimal + if np.iterable(x): + # Assume lists are homogeneous as other functions in unit system + for thisx in x: + return (isinstance(thisx, Number) and + not isinstance(thisx, Decimal)) + else: + return isinstance(x, Number) and not isinstance(x, Decimal) + + +class DecimalConverter(ConversionInterface): + """ + Converter for decimal.Decimal data to float. + """ + @staticmethod + def convert(value, unit, axis): + """ + Convert Decimals to floats. + + The *unit* and *axis* arguments are not used. + + Parameters + ---------- + value : decimal.Decimal or iterable + Decimal or list of Decimal need to be converted + """ + # If value is a Decimal + if isinstance(value, Decimal): + return np.float(value) + else: + # assume x is a list of Decimal + converter = np.asarray + if isinstance(value, ma.MaskedArray): + converter = ma.asarray + return converter(value, dtype=np.float) + + @staticmethod + def axisinfo(unit, axis): + # Since Decimal is a kind of Number, don't need specific axisinfo. + return AxisInfo() + + @staticmethod + def default_units(x, axis): + # Return None since Decimal is a kind of Number. + return None + class Registry(dict): """Register types with conversion interface.""" @@ -164,3 +218,4 @@ def get_converter(self, x): registry = Registry() +registry[Decimal] = DecimalConverter()