1
1
import math
2
+
2
3
import numpy as np
3
4
4
5
import matplotlib .units as units
5
6
import matplotlib .ticker as ticker
6
7
from matplotlib .axes import Axes
7
8
from matplotlib .cbook import iterable
8
9
10
+
9
11
class ProxyDelegate (object ):
10
12
def __init__ (self , fn_name , proxy_type ):
11
13
self .proxy_type = proxy_type
12
14
self .fn_name = fn_name
15
+
13
16
def __get__ (self , obj , objtype = None ):
14
17
return self .proxy_type (self .fn_name , obj )
15
18
19
+
16
20
class TaggedValueMeta (type ):
17
21
def __init__ (cls , name , bases , dict ):
18
22
for fn_name in cls ._proxies .keys ():
19
23
try :
20
24
dummy = getattr (cls , fn_name )
21
25
except AttributeError :
22
- setattr (cls , fn_name , ProxyDelegate (fn_name , cls ._proxies [fn_name ]))
26
+ setattr (cls , fn_name ,
27
+ ProxyDelegate (fn_name , cls ._proxies [fn_name ]))
28
+
23
29
24
30
class PassThroughProxy (object ):
25
31
def __init__ (self , fn_name , obj ):
26
32
self .fn_name = fn_name
27
33
self .target = obj .proxy_target
34
+
28
35
def __call__ (self , * args ):
29
- #print 'passthrough', self.target, self.fn_name
30
36
fn = getattr (self .target , self .fn_name )
31
37
ret = fn (* args )
32
38
return ret
33
39
40
+
34
41
class ConvertArgsProxy (PassThroughProxy ):
35
42
def __init__ (self , fn_name , obj ):
36
43
PassThroughProxy .__init__ (self , fn_name , obj )
37
44
self .unit = obj .unit
45
+
38
46
def __call__ (self , * args ):
39
47
converted_args = []
40
48
for a in args :
@@ -45,16 +53,19 @@ def __call__(self, *args):
45
53
converted_args = tuple ([c .get_value () for c in converted_args ])
46
54
return PassThroughProxy .__call__ (self , * converted_args )
47
55
56
+
48
57
class ConvertReturnProxy (PassThroughProxy ):
49
58
def __init__ (self , fn_name , obj ):
50
59
PassThroughProxy .__init__ (self , fn_name , obj )
51
60
self .unit = obj .unit
61
+
52
62
def __call__ (self , * args ):
53
63
ret = PassThroughProxy .__call__ (self , * args )
54
64
if (type (ret ) == type (NotImplemented )):
55
65
return NotImplemented
56
66
return TaggedValue (ret , self .unit )
57
67
68
+
58
69
class ConvertAllProxy (PassThroughProxy ):
59
70
def __init__ (self , fn_name , obj ):
60
71
PassThroughProxy .__init__ (self , fn_name , obj )
@@ -91,17 +102,17 @@ def __call__(self, *args):
91
102
return NotImplemented
92
103
return TaggedValue (ret , ret_unit )
93
104
94
- class TaggedValue (object ):
95
105
96
- __metaclass__ = TaggedValueMeta
97
- _proxies = {'__add__' :ConvertAllProxy ,
98
- '__sub__' :ConvertAllProxy ,
99
- '__mul__' :ConvertAllProxy ,
100
- '__rmul__' :ConvertAllProxy ,
101
- '__cmp__' :ConvertAllProxy ,
102
- '__lt__' :ConvertAllProxy ,
103
- '__gt__' :ConvertAllProxy ,
104
- '__len__' :PassThroughProxy }
106
+ class _TaggedValue (object ):
107
+
108
+ _proxies = {'__add__' : ConvertAllProxy ,
109
+ '__sub__' : ConvertAllProxy ,
110
+ '__mul__' : ConvertAllProxy ,
111
+ '__rmul__' : ConvertAllProxy ,
112
+ '__cmp__' : ConvertAllProxy ,
113
+ '__lt__' : ConvertAllProxy ,
114
+ '__gt__' : ConvertAllProxy ,
115
+ '__len__' : PassThroughProxy }
105
116
106
117
def __new__ (cls , value , unit ):
107
118
# generate a new subclass for value
@@ -120,13 +131,9 @@ def __new__(cls, value, unit):
120
131
121
132
def __init__ (self , value , unit ):
122
133
self .value = value
123
- self .unit = unit
134
+ self .unit = unit
124
135
self .proxy_target = self .value
125
136
126
- def get_compressed_copy (self , mask ):
127
- compressed_value = np .ma .masked_array (self .value , mask = mask ).compressed ()
128
- return TaggedValue (compressed_value , self .unit )
129
-
130
137
def __getattribute__ (self , name ):
131
138
if (name .startswith ('__' )):
132
139
return object .__getattribute__ (self , name )
@@ -135,7 +142,7 @@ def __getattribute__(self, name):
135
142
return getattr (variable , name )
136
143
return object .__getattribute__ (self , name )
137
144
138
- def __array__ (self , t = None , context = None ):
145
+ def __array__ (self , t = None , context = None ):
139
146
if t is not None :
140
147
return np .asarray (self .value ).astype (t )
141
148
else :
@@ -158,6 +165,7 @@ class IteratorProxy(object):
158
165
def __init__ (self , iter , unit ):
159
166
self .iter = iter
160
167
self .unit = unit
168
+
161
169
def __next__ (self ):
162
170
value = next (self .iter )
163
171
return TaggedValue (value , self .unit )
@@ -169,7 +177,6 @@ def get_compressed_copy(self, mask):
169
177
return TaggedValue (new_value , self .unit )
170
178
171
179
def convert_to (self , unit ):
172
- #print 'convert to', unit, self.unit
173
180
if (unit == self .unit or not unit ):
174
181
return self
175
182
new_value = self .unit .convert_value_to (self .value , unit )
@@ -182,14 +189,17 @@ def get_unit(self):
182
189
return self .unit
183
190
184
191
192
+ TaggedValue = TaggedValueMeta ('TaggedValue' , (_TaggedValue , ), {})
193
+
194
+
185
195
class BasicUnit (object ):
186
196
def __init__ (self , name , fullname = None ):
187
197
self .name = name
188
- if fullname is None : fullname = name
198
+ if fullname is None :
199
+ fullname = name
189
200
self .fullname = fullname
190
201
self .conversions = dict ()
191
202
192
-
193
203
def __repr__ (self ):
194
204
return 'BasicUnit(%s)' % self .name
195
205
@@ -201,11 +211,11 @@ def __call__(self, value):
201
211
202
212
def __mul__ (self , rhs ):
203
213
value = rhs
204
- unit = self
214
+ unit = self
205
215
if hasattr (rhs , 'get_unit' ):
206
216
value = rhs .get_value ()
207
- unit = rhs .get_unit ()
208
- unit = unit_resolver ('__mul__' , (self , unit ))
217
+ unit = rhs .get_unit ()
218
+ unit = unit_resolver ('__mul__' , (self , unit ))
209
219
if (unit == NotImplemented ):
210
220
return NotImplemented
211
221
return TaggedValue (value , unit )
@@ -235,44 +245,43 @@ def get_conversion_fn(self, unit):
235
245
return self .conversions [unit ]
236
246
237
247
def convert_value_to (self , value , unit ):
238
- #print 'convert value to: value ="%s", unit="%s"'%(value, type(unit)), self.conversions
239
248
conversion_fn = self .conversions [unit ]
240
249
ret = conversion_fn (value )
241
250
return ret
242
251
243
-
244
252
def get_unit (self ):
245
253
return self
246
254
255
+
247
256
class UnitResolver (object ):
248
257
def addition_rule (self , units ):
249
258
for unit_1 , unit_2 in zip (units [:- 1 ], units [1 :]):
250
259
if (unit_1 != unit_2 ):
251
260
return NotImplemented
252
261
return units [0 ]
262
+
253
263
def multiplication_rule (self , units ):
254
264
non_null = [u for u in units if u ]
255
265
if (len (non_null ) > 1 ):
256
266
return NotImplemented
257
267
return non_null [0 ]
258
268
259
269
op_dict = {
260
- '__mul__' :multiplication_rule ,
261
- '__rmul__' :multiplication_rule ,
262
- '__add__' :addition_rule ,
263
- '__radd__' :addition_rule ,
264
- '__sub__' :addition_rule ,
265
- '__rsub__' :addition_rule ,
266
- }
270
+ '__mul__' : multiplication_rule ,
271
+ '__rmul__' : multiplication_rule ,
272
+ '__add__' : addition_rule ,
273
+ '__radd__' : addition_rule ,
274
+ '__sub__' : addition_rule ,
275
+ '__rsub__' : addition_rule }
267
276
268
277
def __call__ (self , operation , units ):
269
278
if (operation not in self .op_dict ):
270
279
return NotImplemented
271
280
272
281
return self .op_dict [operation ](self , units )
273
282
274
- unit_resolver = UnitResolver ()
275
283
284
+ unit_resolver = UnitResolver ()
276
285
277
286
cm = BasicUnit ('cm' , 'centimeters' )
278
287
inch = BasicUnit ('inch' , 'inches' )
@@ -288,11 +297,12 @@ def __call__(self, operation, units):
288
297
hertz = BasicUnit ('Hz' , 'Hertz' )
289
298
minutes = BasicUnit ('min' , 'minutes' )
290
299
291
- secs .add_conversion_fn (hertz , lambda x :1. / x )
300
+ secs .add_conversion_fn (hertz , lambda x : 1. / x )
292
301
secs .add_conversion_factor (minutes , 1 / 60.0 )
293
302
303
+
294
304
# radians formatting
295
- def rad_fn (x ,pos = None ):
305
+ def rad_fn (x , pos = None ):
296
306
n = int ((x / np .pi ) * 2.0 + 0.25 )
297
307
if n == 0 :
298
308
return '0'
@@ -335,7 +345,6 @@ def axisinfo(unit, axis):
335
345
def convert (val , unit , axis ):
336
346
if units .ConversionInterface .is_numlike (val ):
337
347
return val
338
- #print 'convert checking iterable'
339
348
if iterable (val ):
340
349
return [thisval .convert_to (unit ).get_value () for thisval in val ]
341
350
else :
@@ -350,15 +359,12 @@ def default_units(x, axis):
350
359
return x .unit
351
360
352
361
353
-
354
- def cos ( x ):
355
- if ( iterable (x ) ):
356
- result = []
357
- for val in x :
358
- result .append ( math .cos ( val .convert_to ( radians ).get_value () ) )
359
- return result
362
+ def cos (x ):
363
+ if iterable (x ):
364
+ return [math .cos (val .convert_to (radians ).get_value ()) for val in x ]
360
365
else :
361
- return math .cos ( x .convert_to ( radians ).get_value () )
366
+ return math .cos (x .convert_to (radians ).get_value ())
367
+
362
368
363
369
basicConverter = BasicUnitConverter ()
364
370
units .registry [BasicUnit ] = basicConverter
0 commit comments