8000 Cleanup arrow_demo. · matplotlib/matplotlib@b488fce · GitHub
[go: up one dir, main page]

Skip to content

Commit b488fce

Browse files
committed
Cleanup arrow_demo.
- Skip unnecessary numbered_bases_to_rates, lettered_bases_to_rates. - Don't use pyplot, but pass axes to make_arrow_plot. - letter colors only depend on the start point, not the end point. - Don't hard-code coordinates of arrows, but compute them from the letter coordinates. - Only include one dataset. - Directly show all three possibilities (length, width, alpha). - head_starts_at_zero=True resulted in incorrect alignment e.g. when calling `arrow_demo.py realistic length` previously; one should use the default value of head_starts_at_zero=False.
1 parent 9c530bc commit b488fce

File tree

1 file changed

+72
-196
lines changed

1 file changed

+72
-196
lines changed

examples/text_labels_and_annotations/arrow_demo.py

Lines changed: 72 additions & 196 deletions
Original file line numberDiff line numberDiff line change
@@ -5,93 +5,57 @@
55
66
Arrow drawing example for the new fancy_arrow facilities.
77
8-
Code contributed by: Rob Knight <rob@spot.colorado.edu>
9-
10-
usage:
11-
12-
python arrow_demo.py realistic|full|sample|extreme
8+
Original author: Rob Knight <rob@spot.colorado.edu>
9+
"""
1310

11+
import itertools
1412

15-
"""
1613
import matplotlib.pyplot as plt
1714
import numpy as np
1815

19-
rates_to_bases = {'r1': 'AT', 'r2': 'TA', 'r3': 'GA', 'r4': 'AG', 'r5': 'CA',
20-
'r6': 'AC', 'r7': 'GT', 'r8': 'TG', 'r9': 'CT', 'r10': 'TC',
21-
'r11': 'GC', 'r12': 'CG'}
22-
numbered_bases_to_rates = {v: k for k, v in rates_to_bases.items()}
23-
lettered_bases_to_rates = {v: 'r' + v for k, v in rates_to_bases.items()}
2416

25-
26-
def make_arrow_plot(data, size=4, display='length', shape='right',
27-
max_arrow_width=0.03, arrow_sep=0.02, alpha=0.5,
28-
normalize_data=False, ec=None, labelcolor=None,
29-
head_starts_at_zero=True,
30-
rate_labels=lettered_bases_to_rates,
31-
**kwargs):
17+
def make_arrow_graph(ax, data, size=4, display='length', shape='right',
18+
max_arrow_width=0.03, arrow_sep=0.02, alpha=0.5,
19+
normalize_data=False, ec=None, labelcolor=None,
20+
**kwargs):
3221
"""
3322
Makes an arrow plot.
3423
3524
Parameters
3625
----------
26+
ax
27+
The axes where the graph is drawn.
3728
data
3829
Dict with probabilities for the bases and pair transitions.
3930
size
40-
Size of the graph in inches.
31+
Size of the plot, in inches.
4132
display : {'length', 'width', 'alpha'}
4233
The arrow property to change.
4334
shape : {'full', 'left', 'right'}
4435
For full or half arrows.
4536
max_arrow_width : float
46-
Maximum width of an arrow, data coordinates.
37+
Maximum width of an arrow, in data coordinates.
4738
arrow_sep : float
48-
Separation between arrows in a pair, data coordinates.
39+
Separation between arrows in a pair, in data coordinates.
4940
alpha : float
5041
Maximum opacity of arrows.
5142
**kwargs
52-
Can be anything allowed by a Arrow object, e.g. *linewidth* or
53-
*edgecolor*.
43+
`.FancyArrow` properties, e.g. *linewidth* or *edgecolor*.
5444
"""
5545

56-
plt.xlim(-0.5, 1.5)
57-
plt.ylim(-0.5, 1.5)
58-
plt.gcf().set_size_inches(size, size)
59-
plt.xticks([])
60-
plt.yticks([])
46+
ax.set(xlim=(-0.5, 1.5), ylim=(-0.5, 1.5), xticks=[], yticks=[])
6147
max_text_size = size * 12
6248
min_text_size = size
6349
label_text_size = size * 2.5
64-
text_params = {'ha': 'center', 'va': 'center', 'family': 'sans-serif',
65-
'fontweight': 'bold'}
66-
r2 = np.sqrt(2)
67-
68-
deltas = {
69-
'AT': (1, 0),
70-
'TA': (-1, 0),
71-
'GA': (0, 1),
72-
'AG': (0, -1),
73-
'CA': (-1 / r2, 1 / r2),
74-
'AC': (1 / r2, -1 / r2),
75-
'GT': (1 / r2, 1 / r2),
76-
'TG': (-1 / r2, -1 / r2),
77-
'CT': (0, 1),
78-
'TC': (0, -1),
79-
'GC': (1, 0),
80-
'CG': (-1, 0)}
8150

82-
colors = {
83-
'AT': 'r',
84-
'TA': 'k',
85-
'GA': 'g',
86-
'AG': 'r',
87-
'CA': 'b',
88-
'AC': 'r',
89-
'GT': 'g',
90-
'TG': 'k',
91-
'CT': 'b',
92-
'TC': 'k',
93-
'GC': 'g',
94-
'CG': 'b'}
51+
bases = 'ATGC'
52+
coords = {
53+
'A': np.array([0, 1]),
54+
'T': np.array([1, 1]),
55+
'G': np.array([0, 0]),
56+
'C': np.array([1, 0]),
57+
}
58+
colors = {'A': 'r', 'T': 'k', 'G': 'g', 'C': 'b'}
9559

9660
label_positions = {
9761
'AT': 'center',
@@ -107,49 +71,31 @@ def make_arrow_plot(data, size=4, display='length', shape='right',
10771
'GC': 'center',
10872
'CG': 'center'}
10973

110-
def do_fontsize(k):
111-
return float(np.clip(max_text_size * np.sqrt(data[k]),
112-
min_text_size, max_text_size))
74+
def get_fontsize(k):
75+
return np.clip(max_text_size * data[k]**(1/2),
76+
min_text_size, max_text_size)
11377

114-
plt.text(0, 1, '$A_3$', color='r', size=do_fontsize('A'), **text_params)
115-
plt.text(1, 1, '$T_3$', color='k', size=do_fontsize('T'), **text_params)
116-
plt.text(0, 0, '$G_3$', color='g', size=do_fontsize('G'), **text_params)
117-
plt.text(1, 0, '$C_3$', color='b', size=do_fontsize('C'), **text_params)
78+
for base in bases:
79+
ax.text(*coords[base], f'${base}_3$',
80+
color=colors[base], size=get_fontsize(base),
81+
horizontalalignment='center', verticalalignment='center',
82+
weight='bold')
11883

11984
arrow_h_offset = 0.25 # data coordinates, empirically determined
12085
max_arrow_length = 1 - 2 * arrow_h_offset
12186
max_head_width = 2.5 * max_arrow_width
12287
max_head_length = 2 * max_arrow_width
123-
arrow_params = {'length_includes_head': True, 'shape': shape,
124-
'head_starts_at_zero': head_starts_at_zero}
12588
sf = 0.6 # max arrow size represents this in data coords
12689

127-
d = (r2 / 2 + arrow_h_offset - 0.5) / r2 # distance for diags
128-
r2v = arrow_sep / r2 # offset for diags
129-
130-
# tuple of x, y for start position
131-
positions = {
132-
'AT': (arrow_h_offset, 1 + arrow_sep),
133-
'TA': (1 - arrow_h_offset, 1 - arrow_sep),
134-
'GA': (-arrow_sep, arrow_h_offset),
135-
'AG': (arrow_sep, 1 - arrow_h_offset),
136-
'CA': (1 - d - r2v, d - r2v),
137-
'AC': (d + r2v, 1 - d + r2v),
138-
'GT': (d - r2v, d + r2v),
139-
'TG': (1 - d + r2v, 1 - d - r2v),
140-
'CT': (1 - arrow_sep, arrow_h_offset),
141-
'TC': (1 + arrow_sep, 1 - arrow_h_offset),
142-
'GC': (arrow_h_offset, arrow_sep),
143-
'CG': (1 - arrow_h_offset, -arrow_sep)}
144-
14590
if normalize_data:
14691
# find maximum value for rates, i.e. where keys are 2 chars long
14792
max_val = max((v for k, v in data.items() if len(k) == 2), default=0)
14893
# divide rates by max val, multiply by arrow scale factor
14994
for k, v in data.items():
15095
data[k] = v / max_val * sf
15196

152-
def draw_arrow(pair, alpha=alpha, ec=ec, labelcolor=labelcolor):
97+
# iterate over strings 'AT', 'TA', 'AG', 'GA', etc.
98+
for pair in map(''.join, itertools.permutations(bases, 2)):
15399
# set the length of the arrow
154100
if display == 'length':
155101
length = (max_head_length
@@ -159,7 +105,6 @@ def draw_arrow(pair, alpha=alpha, ec=ec, labelcolor=labelcolor):
159105
# set the transparency of the arrow
160106
if display == 'alpha':
161107
alpha = min(data[pair] / sf, alpha)
162-
163108
# set the width of the arrow
164109
if display == 'width':
165110
scale = data[pair] / sf
@@ -171,58 +116,44 @@ def draw_arrow(pair, alpha=alpha, ec=ec, labelcolor=labelcolor):
171116
head_width = max_head_width
172117
head_length = max_head_length
173118

174-
fc = colors[pair]
175-
ec = ec or fc
176-
177-
x_scale, y_scale = deltas[pair]
178-
x_pos, y_pos = positions[pair]
179-
plt.arrow(x_pos, y_pos, x_scale * length, y_scale * length,
180-
fc=fc, ec=ec, alpha=alpha, width=width,
181-
head_width=head_width, head_length=head_length,
182-
**arrow_params)
119+
fc = colors[pair[0]]
120+
121+
delta = coords[pair[1]] - coords[pair[0]]
122+
delta = delta / np.hypot(*delta) # unit vector in arrow direction
123+
cos, sin = delta
124+
x_pos, y_pos = (
125+
(coords[pair[0]] + coords[pair[1]]) / 2 # midpoint
126+
- delta * length / 2 # half the arrow length
127+
+ np.array([-sin, cos]) * arrow_sep # shift outwards by arrow_sep
128+
)
129+
ax.arrow(
130+
x_pos, y_pos, cos * length, sin * length,
131+
fc=fc, ec=ec or fc, alpha=alpha, width=width,
132+
head_width=head_width, head_length=head_length, shape=shape,
133+
length_includes_head=True,
134+
)
183135

184136
# figure out coordinates for text
185137
# if drawing relative to base: x and y are same as for arrow
186138
# dx and dy are one arrow width left and up
187-
# need to rotate based on direction of arrow, use x_scale and y_scale
188-
# as sin x and cos x?
189-
sx, cx = y_scale, x_scale
190-
191-
where = label_positions[pair]
192-
if where == 'left':
193-
orig_position = 3 * np.array([[max_arrow_width, max_arrow_width]])
194-
elif where == 'absolute':
195-
orig_position = np.array([[max_arrow_length / 2.0,
196-
3 * max_arrow_width]])
197-
elif where == 'right':
198-
orig_position = np.array([[length - 3 * max_arrow_width,
199-
3 * max_arrow_width]])
200-
elif where == 'center':
201-
orig_position = np.array([[length / 2.0, 3 * max_arrow_width]])
202-
else:
203-
raise ValueError("Got unknown position parameter %s" % where)
204-
205-
M = np.array([[cx, sx], [-sx, cx]])
206-
coords = np.dot(orig_position, M) + [[x_pos, y_pos]]
207-
x, y = np.ravel(coords)
208-
orig_label = rate_labels[pair]
209-
label = r'$%s_{_{\mathrm{%s}}}$' % (orig_label[0], orig_label[1:])
210-
211-
plt.text(x, y, label, size=label_text_size, ha='center', va='center',
212-
color=labelcolor or fc)
213-
214-
for p in sorted(positions):
215-
draw_arrow(p)
216-
217-
218-
# test data
219-
all_on_max = dict([(i, 1) for i in 'TCAG'] +
220-
[(i + j, 0.6) for i in 'TCAG' for j in 'TCAG'])
221-
222-
realistic_data = {
139+
# need to rotate based on direction of arrow (cos, sin)
140+
orig_position = {
141+
'left': [3 * max_arrow_width, 3 * max_arrow_width],
142+
'absolute': [max_arrow_length / 2, 3 * max_arrow_width],
143+
'right': [length - 3 * max_arrow_width, 3 * max_arrow_width],
144+
'center': [length / 2, 3 * max_arrow_width],
145+
}[label_positions[pair]]
146+
M = [[cos, -sin], [sin, cos]]
147+
x, y = np.dot(M, orig_position) + [x_pos, y_pos]
148+
label = r'$r_{_{\mathrm{%s}}}$' % (pair,)
149+
ax.text(x, y, label, size=label_text_size, ha='center', va='center',
150+
color=labelcolor or fc)
151+
152+
153+
data = { # test data
223154
'A': 0.4,
224155
'T': 0.3,
225-
'G': 0.5,
156+
'G': 0.6,
226157
'C': 0.2,
227158
'AT': 0.4,
228159
'AC': 0.3,
@@ -235,73 +166,18 @@ def draw_arrow(pair, alpha=alpha, ec=ec, labelcolor=labelcolor):
235166
'CA': 0.2,
236167
'GA': 0.1,
237168
'GT': 0.4,
238-
'GC': 0.1}
239-
240-
extreme_data = {
241-
'A': 0.75,
242-
'T': 0.10,
243-
'G': 0.10,
244-
'C': 0.05,
245-
'AT': 0.6,
246-
'AC': 0.3,
247-
'AG': 0.1,
248-
'TA': 0.02,
249-
'TC': 0.3,
250-
'TG': 0.01,
251-
'CT': 0.2,
252-
'CG': 0.5,
253-
'CA': 0.2,
254-
'GA': 0.1,
255-
'GT': 0.4,
256-
'GC': 0.2}
257-
258-
sample_data = {
259-
'A': 0.2137,
260-
'T': 0.3541,
261-
'G': 0.1946,
262-
'C': 0.2376,
263-
'AT': 0.0228,
264-
'AC': 0.0684,
265-
'AG': 0.2056,
266-
'TA': 0.0315,
267-
'TC': 0.0629,
268-
'TG': 0.0315,
269-
'CT': 0.1355,
270-
'CG': 0.0401,
271-
'CA': 0.0703,
272-
'GA': 0.1824,
273-
'GT': 0.0387,
274-
'GC': 0.1106}
169+
'GC': 0.1,
170+
}
275171

276172

277173
if __name__ == '__main__':
278-
from sys import argv
279-
d = None
280-
if len(argv) > 1:
281-
if argv[1] == 'full':
282-
d = all_on_max
283-
scaled = False
284-
elif argv[1] == 'extreme':
285-
d = extreme_data
286-
scaled = False
287-
elif argv[1] == 'realistic':
288-
d = realistic_data
289-
scaled = False
290-
elif argv[1] == 'sample':
291-
d = sample_data
292-
scaled = True
293-
if d is None:
294-
d = all_on_max
295-
scaled = False
296-
if len(argv) > 2:
297-
display = argv[2]
298-
else:
299-
display = 'length'
300-
301174
size = 4
302-
plt.figure(figsize=(size, size))
175+
fig = plt.figure(figsize=(3 * size, size), constrained_layout=True)
176+
axs = fig.subplot_mosaic([["length", "width", "alpha"]])
303177

304-
make_arrow_plot(d, display=display, linewidth=0.001, edgecolor=None,
305-
normalize_data=scaled, head_starts_at_zero=True, size=size)
178+
for display, ax in axs.items():
179+
make_arrow_graph(
180+
ax, data, display=display, linewidth=0.001, edgecolor=None,
181+
normalize_data=True, size=size)
306182

307183
plt.show()

0 commit comments

Comments
 (0)
0