8000 Creation of class _FuncInfo · matplotlib/matplotlib@509ed9f · GitHub
[go: up one dir, main page]

Skip to content

Commit 509ed9f

Browse files
committed
Creation of class _FuncInfo
1 parent 0ad3710 commit 509ed9f

File tree

2 files changed

+172
-62
lines changed

2 files changed

+172
-62
lines changed

lib/matplotlib/cbook.py

Lines changed: 141 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -2663,74 +2663,154 @@ def __exit__(self, exc_type, exc_value, traceback):
26632663
pass
26642664

26652665

2666+
class _FuncInfo(object):
2667+
"""
2668+
Class used to store a function
2669+
2670+
Each object has:
2671+
* The direct function (direct)
2672+
* The inverse function (inverse)
2673+
* A boolean indicating whether the function
2674+
is bounded in the interval 0-1 (bounded_0_1)
2675+
2676+
"""
2677+
def __init__(self, direct, inverse, bounded_0_1):
2678+
self.direct = direct
2679+
self.inverse = inverse
2680+
self.bounded_0_1 = bounded_0_1
2681+
2682+
def copy(self):
2683+
return _FuncInfo(self.direct,
2684+
self.inverse,
2685+
self.bounded_0_1)
2686+
2687+
26662688
class _StringFuncParser(object):
2667-
# Each element has:
2668-
# -The direct function,
2669-
# -The inverse function,
2670-
# -A boolean indicating whether the function
2671-
# is bounded in the interval 0-1
2672-
2673-
funcs = {'linear': (lambda x: x, lambda x: x, True),
2674-
'quadratic': (lambda x: x**2, lambda x: x**(1. / 2), True),
2675-
'cubic': (lambda x: x**3, lambda x: x**(1. / 3), True),
2676-
'sqrt': (lambda x: x**(1. / 2), lambda x: x**2, True),
2677-
'cbrt': (lambda x: x**(1. / 3), lambda x: x**3, True),
2678-
'log10': (lambda x: np.log10(x), lambda x: (10**(x)), False),
2679-
'log': (lambda x: np.log(x), lambda x: (np.exp(x)), False),
2680-
'power{a}': (lambda x, a: x**a,
2681-
lambda x, a: x**(1. / a), True),
2682-
'root{a}': (lambda x, a: x**(1. / a),
2683-
lambda x, a: x**a, True),
2684-
'log10(x+{a})': (lambda x, a: np.log10(x + a),
2685-
lambda x, a: 10**x - a, True),
2686-
'log(x+{a})': (lambda x, a: np.log(x + a),
2687-
lambda x, a: np.exp(x) - a, True)}
2689+
"""
2690+
A class used to convert predefined strings into
2691+
_FuncInfo objects, or to directly obtain _FuncInfo
2692+
properties.
2693+
2694+
"""
2695+
2696+
_funcs = {}
2697+
_funcs['linear'] = _FuncInfo(lambda x: x,
2698+
lambda x: x,
2699+
True)
2700+
_funcs['quadratic'] = _FuncInfo(lambda x: x**2,
2701+
lambda x: x**(1. / 2),
2702+
True)
2703+
_funcs['cubic'] = _FuncInfo(lambda x: x**3,
2704+
lambda x: x**(1. / 3),
2705+
True)
2706+
_funcs['sqrt'] = _FuncInfo(lambda x: x**(1. / 2),
2707+
lambda x: x**2,
2708+
True)
2709+
_funcs['cbrt'] = _FuncInfo(lambda x: x**(1. / 3),
2710+
lambda x: x**3,
2711+
True)
2712+
_funcs['log10'] = _FuncInfo(lambda x: np.log10(x),
2713+
lambda x: (10**(x)),
2714+
False)
2715+
_funcs['log'] = _FuncInfo(lambda x: np.log(x),
2716+
lambda x: (np.exp(x)),
2717+
False)
2718+
_funcs['x**{p}'] = _FuncInfo(lambda x, p: x**p[0],
2719+
lambda x, p: x**(1. / p[0]),
2720+
True)
2721+
_funcs['root{p}(x)'] = _FuncInfo(lambda x, p: x**(1. / p[0]),
2722+
lambda x, p: x**p,
2723+
True)
2724+
_funcs['log10(x+{p})'] = _FuncInfo(lambda x, p: np.log10(x + p[0]),
2725+
lambda x, p: 10**x - p[0],
2726+
True)
2727+
_funcs['log(x+{p})'] = _FuncInfo(lambda x, p: np.log(x + p[0]),
2728+
lambda x, p: np.exp(x) - p[0],
2729+
True)
2730+
_funcs['log{p}(x+{p})'] = _FuncInfo(lambda x, p: (np.log(x + p[1]) /
2731+
np.log(p[0])),
2732+
lambda x, p: p[0]**(x) - p[1],
2733+
True)
26882734

26892735
def __init__(self, str_func):
2690-
self.str_func = str_func
2736+
"""
2737+
Parameters
2738+
----------
2739+
str_func : string
2740+
String to be parsed.
26912741
2692-
def is_string(self):
2693-
return not hasattr(self.str_func, '__call__')
2742+
"""
2743+
try: # For python 2.7 and python 3+ compatibility
2744+
is_str = isinstance(str_func, basestring)
2745+
except NameError:
2746+
is_str = isinstance(str_func, str)
2747+
2748+
if not is_str:
2749+
raise ValueError("The argument passed is not a string.")
2750+
self._str_func = str_func
2751+
self._key, self._params = self._get_key_params()
2752+
self._func = self.get_func()
26942753

26952754
def get_func(self):
2696-
return self._get_element(0)
2755+
"""
2756+
Returns the _FuncInfo object, replacing the relevant parameters if
2757+
necessary in the lambda functions.
2758+
2759+
"""
2760+
2761+
func = self._funcs[self._key].copy()
2762+
if len(self._params) > 0:
2763+
m = func.direct
2764+
func.direct = (lambda x, m=m: m(x, self._params))
2765+
m = func.inverse
2766+
func.inverse = (lambda x, m=m: m(x, self._params))
2767+
return func
2768+
2769+
def get_directfunc(self):
2770+
"""
2771+
Returns the callable for the direct function.
2772+
2773+
"""
2774+
return self._func.direct
26972775

26982776
def get_invfunc(self):
2699-
return self._get_element(1)
2777+
"""
2778+
Returns the callable for the inverse function.
2779+
2780+
"""
2781+
return self._func.inverse
27002782

27012783
def is_bounded_0_1(self):
2702-
return self._get_element(2)
2784+
"""
2785+
Returns a boolean indicating if the function is bounded
2786+
in the [0-1 interval].
27032787
2704-
def _get_element(self, ind):
2705-
if not self.is_string():
2706-
raise ValueError("The argument passed is not a string.")
2788+
"""
2789+
return self._func.bounded_0_1
27072790

2708-
str_func = six.text_type(self.str_func)
2709-
# Checking if it comes with a parameter
2710-
param = None
2791+
def _get_key_params(self):
2792+
str_func = six.text_type(self._str_func)
2793+
# Checking if it comes with parameters
27112794
regex = '\{(.*?)\}'
2712-
search = re.search(regex, str_func)
2713-
if search is not None:
2714-
parstring = search.group(1)
2795+
params = re.findall(regex, str_func)
27152796

2716-
try:
2717-
param = float(parstring)
2718-
except:
2719-
raise ValueError("'a' in parametric function strings must be "
2720-
"replaced by a number that is not "
2721-
"zero, e.g. 'log10(x+{0.1})'.")
2722-
if param == 0:
2723-
raise ValueError("'a' in parametric function strings must be "
2724-
"replaced by a number that is not "
2725-
"zero.")
2726-
str_func = re.sub(regex, '{a}', str_func)
2797+
if len(params) > 0:
2798+
for i in range(len(params)):
2799+
try:
2800+
params[i] = float(params[i])
2801+
except:
2802+
raise ValueError("'p' in parametric function strings must"
2803+
" be replaced by a number that is not "
2804+
"zero, e.g. 'log10(x+{0.1})'.")
2805+
2806+
if params[i] == 0:
2807+
raise ValueError("'p' in parametric function strings must"
2808+
" be replaced by a number that is not "
2809+
"zero.")
2810+
str_func = re.sub(regex, '{p}', str_func)
27272811

27282812
try:
2729-
output = self.funcs[str_func][ind]
2730-
if param is not None:
2731-
output = (lambda x, output=output: output(x, param))
2732-
2733-
return output
2813+
func = self._funcs[str_func]
27342814
except KeyError:
27352815
raise ValueError("%s: invalid function. The only strings "
27362816
"recognized as functions are %s." %
@@ -2739,3 +2819,12 @@ def _get_element(self, ind):
27392819
raise ValueError("Invalid function. The only strings recognized "
27402820
"as functions are %s." %
27412821
(self.funcs.keys()))
2822+
if len(params) > 0:
2823+
func.direct(0.5, params)
2824+
try:
2825+
func.direct(0.5, params)
2826+
except:
2827+
raise ValueError("Invalid parameters set for '%s'." %
2828+
(str_func))
2829+
2830+
return str_func, params

lib/matplotlib/tests/test_cbook.py

Lines changed: 31 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -522,8 +522,8 @@ def test_flatiter():
522522
class TestFuncParser(object):
523523
x_test = np.linspace(0.01, 0.5, 3)
524524
validstrings = ['linear', 'quadratic', 'cubic', 'sqrt', 'cbrt',
525-
'log', 'log10', 'power{1.5}', 'root{2.5}',
526-
'log(x+{0.5})', 'log10(x+{0.1})']
525+
'log', 'log10', 'x**{1.5}', 'root{2.5}(x)',
526+
'log(x+{0.5})', 'log10(x+{0.1})', 'log{2}(x+{0.1})']
527527
results = [(lambda x: x),
528528
(lambda x: x**2),
529529
(lambda x: x**3),
@@ -534,19 +534,40 @@ class TestFuncParser(object):
534534
(lambda x: x**1.5),
535535
(lambda x: x**(1 / 2.5)),
536536
(lambda x: np.log(x + 0.5)),
537-
(lambda x: np.log10(x + 0.1))]
537+
(lambda x: np.log10(x + 0.1)),
538+
(lambda x: np.log2(x + 0.1))]
538539

539-
@pytest.mark.parametrize("string", validstrings, ids=validstrings)
540-
def test_inverse(self, string):
541-
func_parser = cbook._StringFuncParser(string)
542-
f = func_parser.get_func()
543-
finv = func_parser.get_invfunc()
544-
assert_array_almost_equal(finv(f(self.x_test)), self.x_test)
540+
bounded_list = [True, True, True, True, True,
541+
False, False, True, True,
542+
True, True, True]
545543

546544
@pytest.mark.parametrize("string, func",
547545
zip(validstrings, results),
548546
ids=validstrings)
549547
def test_values(self, string, func):
550548
func_parser = cbook._StringFuncParser(string)
551-
f = func_parser.get_func()
549+
f = func_parser.get_directfunc()
552550
assert_array_almost_equal(f(self.x_test), func(self.x_test))
551+
552+
@pytest.mark.parametrize("string", validstrings, ids=validstrings)
553+
def test_inverse(self, string):
554+
func_parser = cbook._StringFuncParser(string)
555+
f = func_parser.get_func()
556+
fdir = f.direct
557+
finv = f.inverse
558+
assert_array_almost_equal(finv(fdir(self.x_test)), self.x_test)
559+
560+
@pytest.mark.parametrize("string", validstrings, ids=validstrings)
561+
def test_get_invfunc(self, string):
562+
func_parser = cbook._StringFuncParser(string)
563+
finv1 = func_parser.get_invfunc()
564+
finv2 = func_parser.get_func().inverse
565+
assert_array_almost_equal(finv1(self.x_test), finv2(self.x_test))
566+
567+
@pytest.mark.parametrize("string, bounded",
568+
zip(validstrings, bounded_list),
569+
ids=validstrings)
570+
def test_bounded(self, string, bounded):
571+
func_parser = cbook._StringFuncParser(string)
572+
b = func_parser.is_bounded_0_1()
573+
assert_array_equal(b, bounded)

0 commit comments

Comments
 (0)
0