10000 Cleanup arrow_demo. · matplotlib/matplotlib@44a229b · GitHub
[go: up one dir, main page]

Skip to content

Commit 44a229b

Browse files
committed
Cleanup arrow_demo.
- Skip unnecessary numbered_bases_to_rates, lettered_bases_to_rates. - Don't use pyplot, but pass axes to make_arrow_plot. - letter colors only depend on the start point, not the end point. - Don't hard-code coordinates of arrows, but compute them from the letter coordinates. - Only include one dataset. - Directly show all three possibilities (length, width, alpha). - head_starts_at_zero=True resulted in incorrect alignment e.g. when calling `arrow_demo.py realistic length` previously; one should use the default value of head_starts_at_zero=False.
1 parent 9c530bc commit 44a229b

File tree

1 file changed

+80
-227
lines changed

1 file changed

+80
-227
lines changed

examples/text_labels_and_annotations/arrow_demo.py

Lines changed: 80 additions & 227 deletions
Original file line numberDiff line numberDiff line change
@@ -5,151 +5,83 @@
55
66
Arrow drawing example for the new fancy_arrow facilities.
77
8-
Code contributed by: Rob Knight <rob@spot.colorado.edu>
9-
10-
usage:
11-
12-
python arrow_demo.py realistic|full|sample|extreme
8+
Original author: Rob Knight <rob@spot.colorado.edu>
9+
"""
1310

11+
import itertools
1412

15-
"""
1613
import matplotlib.pyplot as plt
1714
import numpy as np
1815

19-
rates_to_bases = {'r1': 'AT', 'r2': 'TA', 'r3': 'GA', 'r4': 'AG', 'r5': 'CA',
20-
'r6': 'AC', 'r7': 'GT', 'r8': 'TG', 'r9': 'CT', 'r10': 'TC',
21-
'r11': 'GC', 'r12': 'CG'}
22-
numbered_bases_to_rates = {v: k for k, v in rates_to_bases.items()}
23-
lettered_bases_to_rates = {v: 'r' + v for k, v in rates_to_bases.items()}
24-
2516

26-
def make_arrow_plot(data, size=4, display='length', shape='right',
27-
max_arrow_width=0.03, arrow_sep=0.02, alpha=0.5,
28-
normalize_data=False, ec=None, labelcolor=None,
29-
head_starts_at_zero=True,
30-
rate_labels=lettered_bases_to_rates,
31-
**kwargs):
17+
def make_arrow_graph(ax, data, size=4, display='length', shape='right',
18+
max_arrow_width=0.03, arrow_sep=0.02, alpha=0.5,
19+
normalize_data=False, ec=None, labelcolor=None,
20+
**kwargs):
3221
"""
3322
Makes an arrow plot.
3423
3524
Parameters
3625
----------
26+
ax
27+
The axes where the graph is drawn.
3728
data
3829
Dict with probabilities for the bases and pair transitions.
3930
size
40-
Size of the graph in inches.
31+
Size of the plot, in inches.
4132
display : {'length', 'width', 'alpha'}
4233
The arrow property to change.
4334
shape : {'full', 'left', 'right'}
4435
For full or half arrows.
4536
max_arrow_width : float
46-
Maximum width of an arrow, data coordinates.
37+
Maximum width of an arrow, in data coordinates.
4738
arrow_sep : float
48-
Separation between arrows in a pair, data coordinates.
39+
Separation between arrows in a pair, in data coordinates.
4940
alpha : float
5041
Maximum opacity of arrows.
5142
**kwargs
52-
Can be anything allowed by a Arrow object, e.g. *linewidth* or
53-
*edgecolor*.
43+
`.FancyArrow` properties, e.g. *linewidth* or *edgecolor*.
5444
"""
5545

56-
plt.xlim(-0.5, 1.5)
57-
plt.ylim(-0.5, 1.5)
58-
plt.gcf().set_size_inches(size, size)
59-
plt.xticks([])
60-
plt.yticks([])
46+
ax.set(xlim=(-0.5, 1.5), ylim=(-0.5, 1.5), xticks=[], yticks=[])
47+
ax.text(.01, .01, f'flux encoded as arrow {display}',
48+
transform=ax.transAxes)
6149
max_text_size = size * 12
6250
min_text_size = size
6351
label_text_size = size * 2.5
64-
text_params = {'ha': 'center', 'va': 'center', 'family': 'sans-serif',
65-
'fontweight': 'bold'}
66-
r2 = np.sqrt(2)
67-
68-
deltas = {
69-
'AT': (1, 0),
70-
'TA': (-1, 0),
71-
'GA': (0, 1),
72-
'AG': (0, -1),
73-
'CA': (-1 / r2, 1 / r2),
74-
'AC': (1 / r2, -1 / r2),
75-
'GT': (1 / r2, 1 / r2),
76-
'TG': (-1 / r2, -1 / r2),
77-
'CT': (0, 1),
78-
'TC': (0, -1),
79-
'GC': (1, 0),
80-
'CG': (-1, 0)}
8152

82-
colors = {
83-
'AT': 'r',
84-
'TA': 'k',
85-
'GA': 'g',
86-
'AG': 'r',
87-
'CA': 'b',
88-
'AC': 'r',
89-
'GT': 'g',
90-
'TG': 'k',
91-
'CT': 'b',
92-
'TC': 'k',
93-
'GC': 'g',
94-
'CG': 'b'}
95-
96-
label_positions = {
97-
'AT': 'center',
98-
'TA': 'center',
99-
'GA': 'center',
100-
'AG': 'center',
101-
'CA': 'left',
102-
'AC': 'left',
103-
'GT': 'left',
104-
'TG': 'left',
105-
'CT': 'center',
106-
'TC': 'center',
107-
'GC': 'center',
108-
'CG': 'center'}
109-
110-
def do_fontsize(k):
111-
return float(np.clip(max_text_size * np.sqrt(data[k]),
112-
min_text_size, max_text_size))
113-
114-
plt.text(0, 1, '$A_3$', color='r', size=do_fontsize('A'), **text_params)
115-
plt.text(1, 1, '$T_3$', color='k', size=do_fontsize('T'), **text_params)
116-
plt.text(0, 0, '$G_3$', color='g', size=do_fontsize('G'), **text_params)
117-
plt.text(1, 0, '$C_3$', color='b', size=do_fontsize('C'), **text_params)
53+
bases = 'ATGC'
54+
coords = {
55+
'A': np.array([0, 1]),
56+
'T': np.array([1, 1]),
57+
'G': np.array([0, 0]),
58+
'C': np.array([1, 0]),
59+
}
60+
colors = {'A': 'r', 'T': 'k', 'G': 'g', 'C': 'b'}
61+
62+
for base in bases:
63+
fontsize = np.clip(max_text_size * data[base]**(1/2),
64+
min_text_size, max_text_size)
65+
ax.text(*coords[base], f'${base}_3$',
66+
color=colors[base], size=fontsize,
67+
horizontalalignment='center', verticalalignment= 10000 'center',
68+
weight='bold')
11869

11970
arrow_h_offset = 0.25 # data coordinates, empirically determined
12071
max_arrow_length = 1 - 2 * arrow_h_offset
12172
max_head_width = 2.5 * max_arrow_width
12273
max_head_length = 2 * max_arrow_width
123-
arrow_params = {'length_includes_head': True, 'shape': shape,
124-
'head_starts_at_zero': head_starts_at_zero}
12574
sf = 0.6 # max arrow size represents this in data coords
12675

127-
d = (r2 / 2 + arrow_h_offset - 0.5) / r2 # distance for diags
128-
r2v = arrow_sep / r2 # offset for diags
129-
130-
# tuple of x, y for start position
131-
positions = {
132-
'AT': (arrow_h_offset, 1 + arrow_sep),
133-
'TA': (1 - arrow_h_offset, 1 - arrow_sep),
134-
'GA': (-arrow_sep, arrow_h_offset),
135-
'AG': (arrow_sep, 1 - arrow_h_offset),
136-
'CA': (1 - d - r2v, d - r2v),
137-
'AC': (d + r2v, 1 - d + r2v),
138-
'GT': (d - r2v, d + r2v),
139-
'TG': (1 - d + r2v, 1 - d - r2v),
140-
'CT': (1 - arrow_sep, arrow_h_offset),
141-
'TC': (1 + arrow_sep, 1 - arrow_h_offset),
142-
'GC': (arrow_h_offset, arrow_sep),
143-
'CG': (1 - arrow_h_offset, -arrow_sep)}
144-
14576
if normalize_data:
14677
# find maximum value for rates, i.e. where keys are 2 chars long
14778
max_val = max((v for k, v in data.items() if len(k) == 2), default=0)
14879
# divide rates by max val, multiply by arrow scale factor
14980
for k, v in data.items():
15081
data[k] = v / max_val * sf
15182

152-
def draw_arrow(pair, alpha=alpha, ec=ec, labelcolor=labelcolor):
83+
# iterate over strings 'AT', 'TA', 'AG', 'GA', etc.
84+
for pair in map(''.join, itertools.permutations(bases, 2)):
15385
# set the length of the arrow
15486
if display == 'length':
15587
length = (max_head_length
@@ -159,7 +91,6 @@ def draw_arrow(pair, alpha=alpha, ec=ec, labelcolor=labelcolor):
15991
# set the transparency of the arrow
16092
if display == 'alpha':
16193
alpha = min(data[pair] / sf, alpha)
162-
16394
# set the width of the arrow
16495
if display == 'width':
16596
scale = data[pair] / sf
@@ -171,137 +102,59 @@ def draw_arrow(pair, alpha=alpha, ec=ec, labelcolor=labelcolor):
171102
head_width = max_head_width
172103
head_length = max_head_length
173104

174-
fc = colors[pair]
175-
ec = ec or fc
176-
177-
x_scale, y_scale = deltas[pair]
178-
x_pos, y_pos = positions[pair]
179-
plt.arrow(x_pos, y_pos, x_scale * length, y_scale * length,
180-
fc=fc, ec=ec, alpha=alpha, width=width,
181-
head_width=head_width, head_length=head_length,
182-
**arrow_params)
183-
184-
# figure out coordinates for text
105+
fc = colors[pair[0]]
106+
107+
cp0 = coords[pair[0]]
108+
cp1 = coords[pair[1]]
109+
# unit vector in arrow direction
110+
delta = cos, sin = (cp1 - cp0) / np.hypot(*(cp1 - cp0))
111+
x_pos, y_pos = (
112+
(cp0 + cp1) / 2 # midpoint
113+
- delta * length / 2 # half the arrow length
114+
+ np.array([-sin, cos]) * arrow_sep # shift outwards by arrow_sep
115+
)
116+
ax.arrow(
117+
x_pos, y_pos, cos * length, sin * length,
118+
fc=fc, ec=ec or fc, alpha=alpha, width=width,
119+
head_width=head_width, head_length=head_length, shape=shape,
120+
length_includes_head=True,
121+
)
122+
123+
# figure out coordinates for text:
185124
# if drawing relative to base: x and y are same as for arrow
186125
# dx and dy are one arrow width left and up
187-
# need to rotate based on direction of arrow, use x_scale and y_scale
188-
# as sin x and cos x?
189-
sx, cx = y_scale, x_scale
190-
191-
where = label_positions[pair]
192-
if where == 'left':
193-
orig_position = 3 * np.array([[max_arrow_width, max_arrow_width]])
194-
elif where == 'absolute':
195-
orig_position = np.array([[max_arrow_length / 2.0,
196-
3 * max_arrow_width]])
197-
elif where == 'right':
198-
orig_position = np.array([[length - 3 * max_arrow_width,
199-
3 * max_arrow_width]])
200-
elif where == 'center':
201-
orig_position = np.array([[length / 2.0, 3 * max_arrow_width]])
202-
else:
203-
raise ValueError("Got unknown position parameter %s" % where)
204-
205-
M = np.array([[cx, sx], [-sx, cx]])
206-
coords = np.dot(orig_position, M) + [[x_pos, y_pos]]
207-
x, y = np.ravel(coords)
208-
orig_label = rate_labels[pair]
209-
label = r'$%s_{_{\mathrm{%s}}}$' % (orig_label[0], orig_label[1:])
210-
211-
plt.text(x, y, label, size=label_text_size, ha='center', va='center',
212-
color=labelcolor or fc)
213-
214-
for p in sorted(positions):
215-
draw_arrow(p)
216-
217-
218-
# test data
219-
all_on_max = dict([(i, 1) for i in 'TCAG'] +
220-
[(i + j, 0.6) for i in 'TCAG' for j in 'TCAG'])
221-
222-
realistic_data = {
223-
'A': 0.4,
224-
'T': 0.3,
225-
'G': 0.5,
226-
'C': 0.2,
227-
'AT': 0.4,
228-
'AC': 0.3,
229-
'AG': 0.2,
230-
'TA': 0.2,
231-
'TC': 0.3,
232-
'TG': 0.4,
233-
'CT': 0.2,
234-
'CG': 0.3,
235-
'CA': 0.2,
236-
'GA': 0.1,
237-
'GT': 0.4,
238-
'GC': 0.1}
239-
240-
extreme_data = {
241-
'A': 0.75,
242-
'T': 0.10,
243-
'G': 0.10,
244-
'C': 0.05,
245-
'AT': 0.6,
246-
'AC': 0.3,
247-
'AG': 0.1,
248-
'TA': 0.02,
249-
'TC': 0.3,
250-
'TG': 0.01,
251-
'CT': 0.2,
252-
'CG': 0.5,
253-
'CA': 0.2,
254-
'GA': 0.1,
255-
'GT': 0.4,
256-
'GC': 0.2}
257-
258-
sample_data = {
259-
'A': 0.2137,
260-
'T': 0.3541,
261-
'G': 0.1946,
262-
'C': 0.2376,
263-
'AT': 0.0228,
264-
'AC': 0.0684,
265-
'AG': 0.2056,
266-
'TA': 0.0315,
267-
'TC': 0.0629,
268-
'TG': 0.0315,
269-
'CT': 0.1355,
270-
'CG': 0.0401,
271-
'CA': 0.0703,
272-
'GA': 0.1824,
273-
'GT': 0.0387,
274-
'GC': 0.1106}
126+
orig_positions = {
127+
'base': [3 * max_arrow_width, 3 * max_arrow_width],
128+
'center': [length / 2, 3 * max_arrow_width],
129+
'tip': [length - 3 * max_arrow_width, 3 * max_arrow_width],
130+
}
131+
# for diagonal arrows, put the label at the arrow base
132+
# for vertical or horizontal arrows, center the label
133+
where = 'base' if (cp0 != cp1).all() else 'center'
134+
# rotate based on direction of arrow (cos, sin)
135+
M = [[cos, -sin], [sin, cos]]
136+
x, y = np.dot(M, orig_positions[where]) + [x_pos, y_pos]
137+
label = r'$r_{_{\mathrm{%s}}}$' % (pair,)
138+
ax.text(x, y, label, size=label_text_size, ha='center', va='center',
139+
color=labelcolor or fc)
275140

276141

277142
if __name__ == '__main__':
278-
from sys import argv
279-
d = None
280-
if len(argv) > 1:
281-
if argv[1] == 'full':
282-
d = all_on_max
283-
scaled = False
284-
elif argv[1] == 'extreme':
285-
d = extreme_data
286-
scaled = False
287-
elif argv[1] == 'realistic':
288-
d = realistic_data
289-
scaled = False
290-
elif argv[1] == 'sample':
291-
d = sample_data
292-
scaled = True
293-
if d is None:
294-
d = all_on_max
295-
scaled = False
296-
if len(argv) > 2:
297-
display = argv[2]
298-
else:
299-
display = 'length'
143+
data = { # test data
144+
'A': 0.4, 'T': 0.3, 'G': 0.6, 'C': 0.2,
145+
'AT': 0.4, 'AC': 0.3, 'AG': 0.2,
146+
'TA': 0.2, 'TC': 0.3, 'TG': 0.4,
147+
'CT': 0.2, 'CG': 0.3, 'CA': 0.2,
148+
'GA': 0.1, 'GT': 0.4, 'GC': 0.1,
149+
}
300150

301151
size = 4
302-
plt.figure(figsize=(size, size))
152+
fig = plt.figure(figsize=(3 * size, size), constrained_layout=True)
153+
axs = fig.subplot_mosaic([["length", "width", "alpha"]])
303154

304-
make_arrow_plot(d, display=display, linewidth=0.001, edgecolor=None,
305-
normalize_data=scaled, head_starts_at_zero=True, size=size)
155+
for display, ax in axs.items():
156+
make_arrow_graph(
157+
ax, data, display=display, linewidth=0.001, edgecolor=None,
158+
normalize_data=True, size=size)
306159

307160
plt.show()

0 commit comments

Comments
 (0)
0