5
5
6
6
Arrow drawing example for the new fancy_arrow facilities.
7
7
8
- Code contributed by: Rob Knight <rob@spot.colorado.edu>
9
-
10
- usage:
11
-
12
- python arrow_demo.py realistic|full|sample|extreme
8
+ Original author: Rob Knight <rob@spot.colorado.edu>
9
+ """
13
10
11
+ import itertools
14
12
15
- """
16
13
import matplotlib .pyplot as plt
17
14
import numpy as np
18
15
19
- rates_to_bases = {'r1' : 'AT' , 'r2' : 'TA' , 'r3' : 'GA' , 'r4' : 'AG' , 'r5' : 'CA' ,
20
- 'r6' : 'AC' , 'r7' : 'GT' , 'r8' : 'TG' , 'r9' : 'CT' , 'r10' : 'TC' ,
21
- 'r11' : 'GC' , 'r12' : 'CG' }
22
- numbered_bases_to_rates = {v : k for k , v in rates_to_bases .items ()}
23
- lettered_bases_to_rates = {v : 'r' + v for k , v in rates_to_bases .items ()}
24
-
25
16
26
- def make_arrow_plot (data , size = 4 , display = 'length' , shape = 'right' ,
27
- max_arrow_width = 0.03 , arrow_sep = 0.02 , alpha = 0.5 ,
28
- normalize_data = False , ec = None , labelcolor = None ,
29
- head_starts_at_zero = True ,
30
- rate_labels = lettered_bases_to_rates ,
31
- ** kwargs ):
17
+ def make_arrow_graph (ax , data , size = 4 , display = 'length' , shape = 'right' ,
18
+ max_arrow_width = 0.03 , arrow_sep = 0.02 , alpha = 0.5 ,
19
+ normalize_data = False , ec = None , labelcolor = None ,
20
+ ** kwargs ):
32
21
"""
33
22
Makes an arrow plot.
34
23
35
24
Parameters
36
25
----------
26
+ ax
27
+ The axes where the graph is drawn.
37
28
data
38
29
Dict with probabilities for the bases and pair transitions.
39
30
size
40
- Size of the graph in inches.
31
+ Size of the plot, in inches.
41
32
display : {'length', 'width', 'alpha'}
42
33
The arrow property to change.
43
34
shape : {'full', 'left', 'right'}
44
35
For full or half arrows.
45
36
max_arrow_width : float
46
- Maximum width of an arrow, data coordinates.
37
+ Maximum width of an arrow, in data coordinates.
47
38
arrow_sep : float
48
- Separation between arrows in a pair, data coordinates.
39
+ Separation between arrows in a pair, in data coordinates.
49
40
alpha : float
50
41
Maximum opacity of arrows.
51
42
**kwargs
52
- Can be anything allowed by a Arrow object, e.g. *linewidth* or
53
- *edgecolor*.
43
+ `.FancyArrow` properties, e.g. *linewidth* or *edgecolor*.
54
44
"""
55
45
56
- plt .xlim (- 0.5 , 1.5 )
57
- plt .ylim (- 0.5 , 1.5 )
58
- plt .gcf ().set_size_inches (size , size )
59
- plt .xticks ([])
60
- plt .yticks ([])
46
+ ax .set (xlim = (- 0.5 , 1.5 ), ylim = (- 0.5 , 1.5 ), xticks = [], yticks = [])
47
+ ax .text (.01 , .01 , f'flux encoded as arrow { display } ' ,
48
+ transform = ax .transAxes )
61
49
max_text_size = size * 12
62
50
min_text_size = size
63
51
label_text_size = size * 2.5
64
- text_params = {'ha' : 'center' , 'va' : 'center' , 'family' : 'sans-serif' ,
65
- 'fontweight' : 'bold' }
66
- r2 = np .sqrt (2 )
67
-
68
- deltas = {
69
- 'AT' : (1 , 0 ),
70
- 'TA' : (- 1 , 0 ),
71
- 'GA' : (0 , 1 ),
72
- 'AG' : (0 , - 1 ),
73
- 'CA' : (- 1 / r2 , 1 / r2 ),
74
- 'AC' : (1 / r2 , - 1 / r2 ),
75
- 'GT' : (1 / r2 , 1 / r2 ),
76
- 'TG' : (- 1 / r2 , - 1 / r2 ),
77
- 'CT' : (0 , 1 ),
78
- 'TC' : (0 , - 1 ),
79
- 'GC' : (1 , 0 ),
80
- 'CG' : (- 1 , 0 )}
81
52
82
- colors = {
83
- 'AT' : 'r' ,
84
- 'TA' : 'k' ,
85
- 'GA' : 'g' ,
86
- 'AG' : 'r' ,
87
- 'CA' : 'b' ,
88
- 'AC' : 'r' ,
89
- 'GT' : 'g' ,
90
- 'TG' : 'k' ,
91
- 'CT' : 'b' ,
92
- 'TC' : 'k' ,
93
- 'GC' : 'g' ,
94
- 'CG' : 'b' }
95
-
96
- label_positions = {
97
- 'AT' : 'center' ,
98
- 'TA' : 'center' ,
99
- 'GA' : 'center' ,
100
- 'AG' : 'center' ,
101
- 'CA' : 'left' ,
102
- 'AC' : 'left' ,
103
- 'GT' : 'left' ,
104
- 'TG' : 'left' ,
105
- 'CT' : 'center' ,
106
- 'TC' : 'center' ,
107
- 'GC' : 'center' ,
108
- 'CG' : 'center' }
109
-
110
- def do_fontsize (k ):
111
- return float (np .clip (max_text_size * np .sqrt (data [k ]),
112
- min_text_size , max_text_size ))
113
-
114
- plt .text (0 , 1 , '$A_3$' , color = 'r' , size = do_fontsize ('A' ), ** text_params )
115
- plt .text (1 , 1 , '$T_3$' , color = 'k' , size = do_fontsize ('T' ), ** text_params )
116
- plt .text (0 , 0 , '$G_3$' , color = 'g' , size = do_fontsize ('G' ), ** text_params )
117
- plt .text (1 , 0 , '$C_3$' , color = 'b' , size = do_fontsize ('C' ), ** text_params )
53
+ bases = 'ATGC'
54
+ coords = {
55
+ 'A' : np .array ([0 , 1 ]),
56
+ 'T' : np .array ([1 , 1 ]),
57
+ 'G' : np .array ([0 , 0 ]),
58
+ 'C' : np .array ([1 , 0 ]),
59
+ }
60
+ colors = {'A' : 'r' , 'T' : 'k' , 'G' : 'g' , 'C' : 'b' }
61
+
62
+ for base in bases :
63
+ fontsize = np .clip (max_text_size * data [base ]** (1 / 2 ),
64
+ min_text_size , max_text_size )
65
+ ax .text (* coords [base ], f'${ base } _3$' ,
66
+ color = colors [base ], size = fontsize ,
67
+ horizontalalignment = 'center' , verticalalignment =
10000
'center' ,
68
+ weight = 'bold' )
118
69
119
70
arrow_h_offset = 0.25 # data coordinates, empirically determined
120
71
max_arrow_length = 1 - 2 * arrow_h_offset
121
72
max_head_width = 2.5 * max_arrow_width
122
73
max_head_length = 2 * max_arrow_width
123
- arrow_params = {'length_includes_head' : True , 'shape' : shape ,
124
- 'head_starts_at_zero' : head_starts_at_zero }
125
74
sf = 0.6 # max arrow size represents this in data coords
126
75
127
- d = (r2 / 2 + arrow_h_offset - 0.5 ) / r2 # distance for diags
128
- r2v = arrow_sep / r2 # offset for diags
129
-
130
- # tuple of x, y for start position
131
- positions = {
132
- 'AT' : (arrow_h_offset , 1 + arrow_sep ),
133
- 'TA' : (1 - arrow_h_offset , 1 - arrow_sep ),
134
- 'GA' : (- arrow_sep , arrow_h_offset ),
135
- 'AG' : (arrow_sep , 1 - arrow_h_offset ),
136
- 'CA' : (1 - d - r2v , d - r2v ),
137
- 'AC' : (d + r2v , 1 - d + r2v ),
138
- 'GT' : (d - r2v , d + r2v ),
139
- 'TG' : (1 - d + r2v , 1 - d - r2v ),
140
- 'CT' : (1 - arrow_sep , arrow_h_offset ),
141
- 'TC' : (1 + arrow_sep , 1 - arrow_h_offset ),
142
- 'GC' : (arrow_h_offset , arrow_sep ),
143
- 'CG' : (1 - arrow_h_offset , - arrow_sep )}
144
-
145
76
if normalize_data :
146
77
# find maximum value for rates, i.e. where keys are 2 chars long
147
78
max_val = max ((v for k , v in data .items () if len (k ) == 2 ), default = 0 )
148
79
# divide rates by max val, multiply by arrow scale factor
149
80
for k , v in data .items ():
150
81
data [k ] = v / max_val * sf
151
82
152
- def draw_arrow (pair , alpha = alpha , ec = ec , labelcolor = labelcolor ):
83
+ # iterate over strings 'AT', 'TA', 'AG', 'GA', etc.
84
+ for pair in map ('' .join , itertools .permutations (bases , 2 )):
153
85
# set the length of the arrow
154
86
if display == 'length' :
155
87
length = (max_head_length
@@ -159,7 +91,6 @@ def draw_arrow(pair, alpha=alpha, ec=ec, labelcolor=labelcolor):
159
91
# set the transparency of the arrow
160
92
if display == 'alpha' :
161
93
alpha = min (data [pair ] / sf , alpha )
162
-
163
94
# set the width of the arrow
164
95
if display == 'width' :
165
96
scale = data [pair ] / sf
@@ -171,137 +102,59 @@ def draw_arrow(pair, alpha=alpha, ec=ec, labelcolor=labelcolor):
171
102
head_width = max_head_width
172
103
head_length = max_head_length
173
104
174
- fc = colors [pair ]
175
- ec = ec or fc
176
-
177
- x_scale , y_scale = deltas [pair ]
178
- x_pos , y_pos = positions [pair ]
179
- plt .arrow (x_pos , y_pos , x_scale * length , y_scale * length ,
180
- fc = fc , ec = ec , alpha = alpha , width = width ,
181
- head_width = head_width , head_length = head_length ,
182
- ** arrow_params )
183
-
184
- # figure out coordinates for text
105
+ fc = colors [pair [0 ]]
106
+
107
+ cp0 = coords [pair [0 ]]
108
+ cp1 = coords [pair [1 ]]
109
+ # unit vector in arrow direction
110
+ delta = cos , sin = (cp1 - cp0 ) / np .hypot (* (cp1 - cp0 ))
111
+ x_pos , y_pos = (
112
+ (cp0 + cp1 ) / 2 # midpoint
113
+ - delta * length / 2 # half the arrow length
114
+ + np .array ([- sin , cos ]) * arrow_sep # shift outwards by arrow_sep
115
+ )
116
+ ax .arrow (
117
+ x_pos , y_pos , cos * length , sin * length ,
118
+ fc = fc , ec = ec or fc , alpha = alpha , width = width ,
119
+ head_width = head_width , head_length = head_length , shape = shape ,
120
+ length_includes_head = True ,
121
+ )
122
+
123
+ # figure out coordinates for text:
185
124
# if drawing relative to base: x and y are same as for arrow
186
125
# dx and dy are one arrow width left and up
187
- # need to rotate based on direction of arrow, use x_scale and y_scale
188
- # as sin x and cos x?
189
- sx , cx = y_scale , x_scale
190
-
191
- where = label_positions [pair ]
192
- if where == 'left' :
193
- orig_position = 3 * np .array ([[max_arrow_width , max_arrow_width ]])
194
- elif where == 'absolute' :
195
- orig_position = np .array ([[max_arrow_length / 2.0 ,
196
- 3 * max_arrow_width ]])
197
- elif where == 'right' :
198
- orig_position = np .array ([[length - 3 * max_arrow_width ,
199
- 3 * max_arrow_width ]])
200
- elif where == 'center' :
201
- orig_position = np .array ([[length / 2.0 , 3 * max_arrow_width ]])
202
- else :
203
- raise ValueError ("Got unknown position parameter %s" % where )
204
-
205
- M = np .array ([[cx , sx ], [- sx , cx ]])
206
- coords = np .dot (orig_position , M ) + [[x_pos , y_pos ]]
207
- x , y = np .ravel (coords )
208
- orig_label = rate_labels [pair ]
209
- label = r'$%s_{_{\mathrm{%s}}}$' % (orig_label [0 ], orig_label [1 :])
210
-
211
- plt .text (x , y , label , size = label_text_size , ha = 'center' , va = 'center' ,
212
- color = labelcolor or fc )
213
-
214
- for p in sorted (positions ):
215
- draw_arrow (p )
216
-
217
-
218
- # test data
219
- all_on_max = dict ([(i , 1 ) for i in 'TCAG' ] +
220
- [(i + j , 0.6 ) for i in 'TCAG' for j in 'TCAG' ])
221
-
222
- realistic_data = {
223
- 'A' : 0.4 ,
224
- 'T' : 0.3 ,
225
- 'G' : 0.5 ,
226
- 'C' : 0.2 ,
227
- 'AT' : 0.4 ,
228
- 'AC' : 0.3 ,
229
- 'AG' : 0.2 ,
230
- 'TA' : 0.2 ,
231
- 'TC' : 0.3 ,
232
- 'TG' : 0.4 ,
233
- 'CT' : 0.2 ,
234
- 'CG' : 0.3 ,
235
- 'CA' : 0.2 ,
236
- 'GA' : 0.1 ,
237
- 'GT' : 0.4 ,
238
- 'GC' : 0.1 }
239
-
240
- extreme_data = {
241
- 'A' : 0.75 ,
242
- 'T' : 0.10 ,
243
- 'G' : 0.10 ,
244
- 'C' : 0.05 ,
245
- 'AT' : 0.6 ,
246
- 'AC' : 0.3 ,
247
- 'AG' : 0.1 ,
248
- 'TA' : 0.02 ,
249
- 'TC' : 0.3 ,
250
- 'TG' : 0.01 ,
251
- 'CT' : 0.2 ,
252
- 'CG' : 0.5 ,
253
- 'CA' : 0.2 ,
254
- 'GA' : 0.1 ,
255
- 'GT' : 0.4 ,
256
- 'GC' : 0.2 }
257
-
258
- sample_data = {
259
- 'A' : 0.2137 ,
260
- 'T' : 0.3541 ,
261
- 'G' : 0.1946 ,
262
- 'C' : 0.2376 ,
263
- 'AT' : 0.0228 ,
264
- 'AC' : 0.0684 ,
265
- 'AG' : 0.2056 ,
266
- 'TA' : 0.0315 ,
267
- 'TC' : 0.0629 ,
268
- 'TG' : 0.0315 ,
269
- 'CT' : 0.1355 ,
270
- 'CG' : 0.0401 ,
271
- 'CA' : 0.0703 ,
272
- 'GA' : 0.1824 ,
273
- 'GT' : 0.0387 ,
274
- 'GC' : 0.1106 }
126
+ orig_positions = {
127
+ 'base' : [3 * max_arrow_width , 3 * max_arrow_width ],
128
+ 'center' : [length / 2 , 3 * max_arrow_width ],
129
+ 'tip' : [length - 3 * max_arrow_width , 3 * max_arrow_width ],
130
+ }
131
+ # for diagonal arrows, put the label at the arrow base
132
+ # for vertical or horizontal arrows, center the label
133
+ where = 'base' if (cp0 != cp1 ).all () else 'center'
134
+ # rotate based on direction of arrow (cos, sin)
135
+ M = [[cos , - sin ], [sin , cos ]]
136
+ x , y = np .dot (M , orig_positions [where ]) + [x_pos , y_pos ]
137
+ label = r'$r_{_{\mathrm{%s}}}$' % (pair ,)
138
+ ax .text (x , y , label , size = label_text_size , ha = 'center' , va = 'center' ,
139
+ color = labelcolor or fc )
275
140
276
141
277
142
if __name__ == '__main__' :
278
- from sys import argv
279
- d = None
280
- if len (argv ) > 1 :
281
- if argv [1 ] == 'full' :
282
- d = all_on_max
283
- scaled = False
284
- elif argv [1 ] == 'extreme' :
285
- d = extreme_data
286
- scaled = False
287
- elif argv [1 ] == 'realistic' :
288
- d = realistic_data
289
- scaled = False
290
- elif argv [1 ] == 'sample' :
291
- d = sample_data
292
- scaled = True
293
- if d is None :
294
- d = all_on_max
295
- scaled = False
296
- if len (argv ) > 2 :
297
- display = argv [2 ]
298
- else :
299
- display = 'length'
143
+ data = { # test data
144
+ 'A' : 0.4 , 'T' : 0.3 , 'G' : 0.6 , 'C' : 0.2 ,
145
+ 'AT' : 0.4 , 'AC' : 0.3 , 'AG' : 0.2 ,
146
+ 'TA' : 0.2 , 'TC' : 0.3 , 'TG' : 0.4 ,
147
+ 'CT' : 0.2 , 'CG' : 0.3 , 'CA' : 0.2 ,
148
+ 'GA' : 0.1 , 'GT' : 0.4 , 'GC' : 0.1 ,
149
+ }
300
150
301
151
size = 4
302
- plt .figure (figsize = (size , size ))
152
+ fig = plt .figure (figsize = (3 * size , size ), constrained_layout = True )
153
+ axs = fig .subplot_mosaic ([["length" , "width" , "alpha" ]])
303
154
304
- make_arrow_plot (d , display = display , linewidth = 0.001 , edgecolor = None ,
305
- normalize_data = scaled , head_starts_at_zero = True , size = size )
155
+ for display , ax in axs .items ():
156
+ make_arrow_graph (
157
+ ax , data , display = display , linewidth = 0.001 , edgecolor = None ,
158
+ normalize_data = True , size = size )
306
159
307
160
plt .show ()
0 commit comments