@@ -2663,74 +2663,154 @@ def __exit__(self, exc_type, exc_value, traceback):
2663
2663
pass
2664
2664
2665
2665
2666
+ class _FuncInfo (object ):
2667
+ """
2668
+ Class used to store a function
2669
+
2670
+ Each object has:
2671
+ * The direct function (direct)
2672
+ * The inverse function (inverse)
2673
+ * A boolean indicating whether the function
2674
+ is bounded in the interval 0-1 (bounded_0_1)
2675
+
2676
+ """
2677
+ def __init__ (self , direct , inverse , bounded_0_1 ):
2678
+ self .direct = direct
2679
+ self .inverse = inverse
2680
+ self .bounded_0_1 = bounded_0_1
2681
+
2682
+ def copy (self ):
2683
+ return _FuncInfo (self .direct ,
2684
+ self .inverse ,
2685
+ self .bounded_0_1 )
2686
+
2687
+
2666
2688
class _StringFuncParser (object ):
2667
- # Each element has:
2668
- # -The direct function,
2669
- # -The inverse function,
2670
- # -A boolean indicating whether the function
2671
- # is bounded in the interval 0-1
2672
-
2673
- funcs = {'linear' : (lambda x : x , lambda x : x , True ),
2674
- 'quadratic' : (lambda x : x ** 2 , lambda x : x ** (1. / 2 ), True ),
2675
- 'cubic' : (lambda x : x ** 3 , lambda x : x ** (1. / 3 ), True ),
2676
- 'sqrt' : (lambda x : x ** (1. / 2), lambda x : x ** 2 , True ),
2677
- 'cbrt' : (lambda x : x ** (1. / 3 ), lambda x : x ** 3 , True ),
2678
- 'log10' : (lambda x : np .log10 (x ), lambda x : (10 ** (x )), False ),
2679
- 'log' : (lambda x : np .log (x ), lambda x : (np .exp (x )), False ),
2680
- 'power{a}' : (lambda x , a : x ** a ,
2681
- lambda x , a : x ** (1. / a ), True ),
2682
- 'root{a}' : (lambda x , a : x ** (1. / a ),
2683
- lambda x , a : x ** a , True ),
2684
- 'log10(x+{a})' : (lambda x , a : np .log10 (x + a ),
2685
- lambda x , a : 10 ** x - a , True ),
2686
- 'log(x+{a})' : (lambda x , a : np .log (x + a ),
2687
- lambda x , a : np .exp (x ) - a , True )}
2689
+ """
2690
+ A class used to convert predefined strings into
2691
+ _FuncInfo objects, or to directly obtain _FuncInfo
2692
+ properties.
2693
+
2694
+ """
2695
+
2696
+ _funcs = {}
2697
+ _funcs ['linear' ] = _FuncInfo (lambda x : x ,
2698
+ lambda x : x ,
2699
+ True )
2700
+ _funcs ['quadratic' ] = _FuncInfo (lambda x : x ** 2 ,
2701
+ lambda x : x ** (1. / 2 ),
2702
+ True )
2703
+ _funcs ['cubic' ] = _FuncInfo (lambda x : x ** 3 ,
2704
+ lambda x : x ** (1. / 3 ),
2705
+ True )
2706
+ _funcs ['sqrt' ] = _FuncInfo (lambda x : x ** (1. / 2 ),
2707
+ lambda x : x ** 2 ,
2708
+ True )
2709
+ _funcs ['cbrt' ] = _FuncInfo (lambda x : x ** (1. / 3 ),
2710
+ lambda x : x ** 3 ,
2711
+ True )
2712
+ _funcs ['log10' ] = _FuncInfo (lambda x : np .log10 (x ),
2713
+ lambda x : (10 ** (x )),
2714
+ False )
2715
+ _funcs ['log' ] = _FuncInfo (lambda x : np .log (x ),
2716
+ lambda x : (np .exp (x )),
2717
+ False )
2718
+ _funcs ['x**{p}' ] = _FuncInfo (lambda x , p : x ** p [0 ],
2719
+ lambda x , p : x ** (1. / p [0 ]),
2720
+ True )
2721
+ _funcs ['root{p}(x)' ] = _FuncInfo (lambda x , p : x ** (1. / p [0 ]),
2722
+ lambda x , p : x ** p ,
2723
+ True )
2724
+ _funcs ['log10(x+{p})' ] = _FuncInfo (lambda x , p : np .log10 (x + p [0 ]),
2725
+ lambda x , p : 10 ** x - p [0 ],
2726
+ True )
2727
+ _funcs ['log(x+{p})' ] = _FuncInfo (lambda x , p : np .log (x + p [0 ]),
2728
+ lambda x , p : np .exp (x ) - p [0 ],
2729
+ True )
2730
+ _funcs ['log{p}(x+{p})' ] = _FuncInfo (lambda x , p : (np .log (x + p [1 ]) /
2731
+ np .log (p [0 ])),
2732
+ lambda x , p : p [0 ]** (x ) - p [1 ],
2733
+ True )
2688
2734
2689
2735
def __init__ (self , str_func ):
2690
- self .str_func = str_func
2736
+ """
2737
+ Parameters
2738
+ ----------
2739
+ str_func : string
2740
+ String to be parsed.
2691
2741
2692
- def is_string (self ):
2693
- return not hasattr (self .str_func , '__call__' )
2742
+ """
2743
+ try : # For python 2.7 and python 3+ compatibility
2744
+ is_str = isinstance (str_func , basestring )
2745
+ except NameError :
2746
+ is_str = isinstance (str_func , str )
2747
+
2748
+ if not is_str :
2749
+ raise ValueError ("The argument passed is not a string." )
2750
+ self ._str_func = str_func
2751
+ self ._key , self ._params = self ._get_key_params ()
2752
+ self ._func = self .get_func ()
2694
2753
2695
2754
def get_func (self ):
2696
- return self ._get_element (0 )
2755
+ """
2756
+ Returns the _FuncInfo object, replacing the relevant parameters if
2757
+ necessary in the lambda functions.
2758
+
2759
+ """
2760
+
2761
+ func = self ._funcs [self ._key ].copy ()
2762
+ if len (self ._params ) > 0 :
2763
+ m = func .direct
2764
+ func .direct = (lambda x , m = m : m (x , self ._params ))
2765
+ m = func .inverse
2766
+ func .inverse = (lambda x , m = m : m (x , self ._params ))
2767
+ return func
2768
+
2769
+ def get_directfunc (self ):
2770
+ """
2771
+ Returns the callable for the direct function.
2772
+
2773
+ """
2774
+ return self ._func .direct
2697
2775
2698
2776
def get_invfunc (self ):
2699
- return self ._get_element (1 )
2777
+ """
2778
+ Returns the callable for the inverse function.
2779
+
2780
+ """
2781
+ return self ._func .inverse
2700
2782
2701
2783
def is_bounded_0_1 (self ):
2702
- return self ._get_element (2 )
2784
+ """
2785
+ Returns a boolean indicating if the function is bounded
2786
+ in the [0-1 interval].
2703
2787
2704
- def _get_element (self , ind ):
2705
- if not self .is_string ():
2706
- raise ValueError ("The argument passed is not a string." )
2788
+ """
2789
+ return self ._func .bounded_0_1
2707
2790
2708
- str_func = six . text_type (self . str_func )
2709
- # Checking if it comes with a parameter
2710
- param = None
2791
+ def _get_key_params (self ):
2792
+ str_func = six . text_type ( self . _str_func )
2793
+ # Checking if it comes with parameters
2711
2794
regex = '\{(.*?)\}'
2712
- search = re .search (regex , str_func )
2713
- if search is not None :
2714
- parstring = search .group (1 )
2795
+ params = re .findall (regex , str_func )
2715
2796
2716
- try :
2717
- param = float (parstring )
2718
- except :
2719
- raise ValueError ("'a' in parametric function strings must be "
2720
- "replaced by a number that is not "
2721
- "zero, e.g. 'log10(x+{0.1})'." )
2722
- if param == 0 :
2723
- raise ValueError ("'a' in parametric function strings must be "
2724
- "replaced by a number that is not "
2725
- "zero." )
2726
- str_func = re .sub (regex , '{a}' , str_func )
2797
+ if len (params ) > 0 :
2798
+ for i in range (len (params )):
2799
+ try :
2800
+ params [i ] = float (params [i ])
2801
+ except :
2802
+ raise ValueError ("'p' in parametric function strings must"
2803
+ " be replaced by a number that is not "
2804
+ "zero, e.g. 'log10(x+{0.1})'." )
2805
+
2806
+ if params [i ] == 0 :
2807
+ raise ValueError ("'p' in parametric function strings must"
2808
+ " be replaced by a number that is not "
2809
+ "zero." )
2810
+ str_func = re .sub (regex , '{p}' , str_func )
2727
2811
2728
2812
try :
2729
- output = self .funcs [str_func ][ind ]
2730
- if param is not None :
2731
- output = (lambda x , output = output : output (x , param ))
2732
-
2733
- return output
2813
+ func = self ._funcs [str_func ]
2734
2814
except KeyError :
2735
2815
raise ValueError ("%s: invalid function. The only strings "
2736
2816
"recognized as functions are %s." %
@@ -2739,3 +2819,12 @@ def _get_element(self, ind):
2739
2819
raise ValueError ("Invalid function. The only strings recognized "
2740
2820
"as functions are %s." %
2741
2821
(self .funcs .keys ()))
2822
+ if len (params ) > 0 :
2823
+ func .direct (0.5 , params )
2824
+ try :
2825
+ func .direct (0.5 , params )
2826
+ except :
2827
+ raise ValueError ("Invalid parameters set for '%s'." %
2828
+ (str_func ))
2829
+
2830
+ return str_func , params
0 commit comments