8000 Cleanup arrow_demo. by anntzer · Pull Request #20390 · matplotlib/matplotlib · GitHub
[go: up one dir, main page]

Skip to content

Cleanup arrow_demo. #20390

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jun 11, 2021
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
310 changes: 81 additions & 229 deletions examples/text_labels_and_annotations/arrow_demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,153 +3,84 @@
Arrow Demo
==========

Arrow drawing example for the new fancy_arrow facilities.

Code contributed by: Rob Knight <rob@spot.colorado.edu>

usage:

python arrow_demo.py realistic|full|sample|extreme
Three ways of drawing arrows to encode arrow "strength" (e.g., transition
probabilities in a Markov model) using arrow length, width, or alpha (opacity).
"""

import itertools

"""
import matplotlib.pyplot as plt
import numpy as np

rates_to_bases = {'r1': 'AT', 'r2': 'TA', 'r3': 'GA', 'r4': 'AG', 'r5': 'CA',
'r6': 'AC', 'r7': 'GT', 'r8': 'TG', 'r9': 'CT', 'r10': 'TC',
'r11': 'GC', 'r12': 'CG'}
numbered_bases_to_rates = {v: k for k, v in rates_to_bases.items()}
lettered_bases_to_rates = {v: 'r' + v for k, v in rates_to_bases.items()}


def make_arrow_plot(data, size=4, display='length', shape='right',
max_arrow_width=0.03, arrow_sep=0.02, alpha=0.5,
normalize_data=False, ec=None, labelcolor=None,
head_starts_at_zero=True,
rate_labels=lettered_bases_to_rates,
**kwargs):
def make_arrow_graph(ax, data, size=4, display='length', shape='right',
max_arrow_width=0.03, arrow_sep=0.02, alpha=0.5,
normalize_data=False, ec=None, labelcolor=None,
**kwargs):
"""
Makes an arrow plot.

Parameters
----------
ax
The axes where the graph is drawn.
data
Dict with probabilities for the bases and pair transitions.
size
Size of the graph in inches.
Size of the plot, in inches.
display : {'length', 'width', 'alpha'}
The arrow property to change.
shape : {'full', 'left', 'right'}
For full or half arrows.
max_arrow_width : float
Maximum width of an arrow, data coordinates.
Maximum width of an arrow, in data coordinates.
arrow_sep : float
Separation between arrows in a pair, data coordinates.
Separation between arrows in a pair, in data coordinates.
alpha : float
Maximum opacity of arrows.
**kwargs
Can be anything allowed by a Arrow object, e.g. *linewidth* or
*edgecolor*.
`.FancyArrow` properties, e.g. *linewidth* or *edgecolor*.
"""

plt.xlim(-0.5, 1.5)
plt.ylim(-0.5, 1.5)
plt.gcf().set_size_inches(size, size)
plt.xticks([])
plt.yticks([])
ax.set(xlim=(-0.5, 1.5), ylim=(-0.5, 1.5), xticks=[], yticks=[])
ax.text(.01, .01, f'flux encoded as arrow {display}',
transform=ax.transAxes)
max_text_size = size * 12
min_text_size = size
label_text_size = size * 2.5
text_params = {'ha': 'center', 'va': 'center', 'family': 'sans-serif',
'fontweight': 'bold'}
r2 = np.sqrt(2)

deltas = {
'AT': (1, 0),
'TA': (-1, 0),
'GA': (0, 1),
'AG': (0, -1),
'CA': (-1 / r2, 1 / r2),
'AC': (1 / r2, -1 / r2),
'GT': (1 / r2, 1 / r2),
'TG': (-1 / r2, -1 / r2),
'CT': (0, 1),
'TC': (0, -1),
'GC': (1, 0),
'CG': (-1, 0)}

colors = {
'AT': 'r',
'TA': 'k',
'GA': 'g',
'AG': 'r',
'CA': 'b',
'AC': 'r',
'GT': 'g',
'TG': 'k',
'CT': 'b',
'TC': 'k',
'GC': 'g',
'CG': 'b'}

label_positions = {
'AT': 'center',
'TA': 'center',
'GA': 'center',
'AG': 'center',
'CA': 'left',
'AC': 'left',
'GT': 'left',
'TG': 'left',
'CT': 'center',
'TC': 'center',
'GC': 'center',
'CG': 'center'}

def do_fontsize(k):
return float(np.clip(max_text_size * np.sqrt(data[k]),
min_text_size, max_text_size))

plt.text(0, 1, '$A_3$', color='r', size=do_fontsize('A'), **text_params)
plt.text(1, 1, '$T_3$', color='k', size=do_fontsize('T'), **text_params)
plt.text(0, 0, '$G_3$', color='g', size=do_fontsize('G'), **text_params)
plt.text(1, 0, '$C_3$', color='b', size=do_fontsize('C'), **text_params)
bases = 'ATGC'
coords = {
'A': np.array([0, 1]),
'T': np.array([1, 1]),
'G': np.array([0, 0]),
'C': np.array([1, 0]),
}
colors = {'A': 'r', 'T': 'k', 'G': 'g', 'C': 'b'}

for base in bases:
fontsize = np.clip(max_text_size * data[base]**(1/2),
min_text_size, max_text_size)
ax.text(*coords[base], f'${base}_3$',
color=colors[base], size=fontsize,
horizontalalignment='center', verticalalignment='center',
weight='bold')

arrow_h_offset = 0.25 # data coordinates, empirically determined
max_arrow_length = 1 - 2 * arrow_h_offset
max_head_width = 2.5 * max_arrow_width
max_head_length = 2 * max_arrow_width
arrow_params = {'length_includes_head': True, 'shape': shape,
'head_starts_at_zero': head_starts_at_zero}
sf = 0.6 # max arrow size represents this in data coords

d = (r2 / 2 + arrow_h_offset - 0.5) / r2 # distance for diags
r2v = arrow_sep / r2 # offset for diags

# tuple of x, y for start position
positions = {
'AT': (arrow_h_offset, 1 + arrow_sep),
'TA': (1 - arrow_h_offset, 1 - arrow_sep),
'GA': (-arrow_sep, arrow_h_offset),
'AG': (arrow_sep, 1 - arrow_h_offset),
'CA': (1 - d - r2v, d - r2v),
'AC': (d + r2v, 1 - d + r2v),
'GT': (d - r2v, d + r2v),
'TG': (1 - d + r2v, 1 - d - r2v),
'CT': (1 - arrow_sep, arrow_h_offset),
'TC': (1 + arrow_sep, 1 - arrow_h_offset),
'GC': (arrow_h_offset, arrow_sep),
'CG': (1 - arrow_h_offset, -arrow_sep)}

if normalize_data:
# find maximum value for rates, i.e. where keys are 2 chars long
max_val = max((v for k, v in data.items() if len(k) == 2), default=0)
# divide rates by max val, multiply by arrow scale factor
for k, v in data.items():
data[k] = v / max_val * sf

def draw_arrow(pair, alpha=alpha, ec=ec, labelcolor=labelcolor):
# iterate over strings 'AT', 'TA', 'AG', 'GA', etc.
for pair in map(''.join, itertools.permutations(bases, 2)):
# set the length of the arrow
if display == 'length':
length = (max_head_length
Expand All @@ -159,7 +90,6 @@ def draw_arrow(pair, alpha=alpha, ec=ec, labelcolor=labelcolor):
# set the transparency of the arrow
if display == 'alpha':
alpha = min(data[pair] / sf, alpha)

# set the width of the arrow
if display == 'width':
scale = data[pair] / sf
Expand All @@ -171,137 +101,59 @@ def draw_arrow(pair, alpha=alpha, ec=ec, labelcolor=labelcolor):
head_width = max_head_width
head_length = max_head_length

fc = colors[pair]
ec = ec or fc

x_scale, y_scale = deltas[pair]
x_pos, y_pos = positions[pair]
plt.arrow(x_pos, y_pos, x_scale * length, y_scale * length,
fc=fc, ec=ec, alpha=alpha, width=width,
head_width=head_width, head_length=head_length,
**arrow_params)

# figure out coordinates for text
fc = colors[pair[0]]

cp0 = coords[pair[0]]
cp1 = coords[pair[1]]
# unit vector in arrow direction
delta = cos, sin = (cp1 - cp0) / np.hypot(*(cp1 - cp0))
x_pos, y_pos = (
(cp0 + cp1) / 2 # midpoint
- delta * length / 2 # half the arrow length
+ np.array([-sin, cos]) * arrow_sep # shift outwards by arrow_sep
)
ax.arrow(
x_pos, y_pos, cos * length, sin * length,
fc=fc, ec=ec or fc, alpha=alpha, width=width,
head_width=head_width, head_length=head_length, shape=shape,
length_includes_head=True,
)

# figure out coordinates for text:
# if drawing relative to base: x and y are same as for arrow
# dx and dy are one arrow width left and up
# need to rotate based on direction of arrow, use x_scale and y_scale
# as sin x and cos x?
sx, cx = y_scale, x_scale

where = label_positions[pair]
if where == 'left':
orig_position = 3 * np.array([[max_arrow_width, max_arrow_width]])
elif where == 'absolute':
orig_position = np.array([[max_arrow_length / 2.0,
3 * max_arrow_width]])
elif where == 'right':
orig_position = np.array([[length - 3 * max_arrow_width,
3 * max_arrow_width]])
elif where == 'center':
orig_position = np.array([[length / 2.0, 3 * max_arrow_width]])
else:
raise ValueError("Got unknown position parameter %s" % where)

M = np.array([[cx, sx], [-sx, cx]])
coords = np.dot(orig_position, M) + [[x_pos, y_pos]]
x, y = np.ravel(coords)
orig_label = rate_labels[pair]
label = r'$%s_{_{\mathrm{%s}}}$' % (orig_label[0], orig_label[1:])

plt.text(x, y, label, size=label_text_size, ha='center', va='center',
color=labelcolor or fc)

for p in sorted(positions):
draw_arrow(p)


# test data
all_on_max = dict([(i, 1) for i in 'TCAG'] +
[(i + j, 0.6) for i in 'TCAG' for j in 'TCAG'])

realistic_data = {
'A': 0.4,
'T': 0.3,
'G': 0.5,
'C': 0.2,
'AT': 0.4,
'AC': 0.3,
'AG': 0.2,
'TA': 0.2,
'TC': 0.3,
'TG': 0.4,
'CT': 0.2,
'CG': 0.3,
'CA': 0.2,
'GA': 0.1,
'GT': 0.4,
'GC': 0.1}

extreme_data = {
'A': 0.75,
'T': 0.10,
'G': 0.10,
'C': 0.05,
'AT': 0.6,
'AC': 0.3,
'AG': 0.1,
'TA': 0.02,
'TC': 0.3,
'TG': 0.01,
'CT': 0.2,
'CG': 0.5,
'CA': 0.2,
'GA': 0.1,
'GT': 0.4,
'GC': 0.2}

sample_data = {
'A': 0.2137,
'T': 0.3541,
'G': 0.1946,
'C': 0.2376,
'AT': 0.0228,
'AC': 0.0684,
'AG': 0.2056,
'TA': 0.0315,
'TC': 0.0629,
'TG': 0.0315,
'CT': 0.1355,
'CG': 0.0401,
'CA': 0.0703,
'GA': 0.1824,
'GT': 0.0387,
'GC': 0.1106}
orig_positions = {
'base': [3 * max_arrow_width, 3 * max_arrow_width],
'center': [length / 2, 3 * max_arrow_width],
'tip': [length - 3 * max_arrow_width, 3 * max_arrow_width],
}
# for diagonal arrows, put the label at the arrow base
# for vertical or horizontal arrows, center the label
where = 'base' if (cp0 != cp1).all() else 'center'
# rotate based on direction of arrow (cos, sin)
M = [[cos, -sin], [sin, cos]]
x, y = np.dot(M, orig_positions[where]) + [x_pos, y_pos]
label = r'$r_{_{\mathrm{%s}}}$' % (pair,)
ax.text(x, y, label, size=label_text_size, ha='center', va='center',
color=labelcolor or fc)


if __name__ == '__main__':
from sys import argv
d = None
if len(argv) > 1:
if argv[1] == 'full':
d = all_on_max
scaled = False
elif argv[1] == 'extreme':
d = extreme_data
scaled = False
elif argv[1] == 'realistic':
d = realistic_data
scaled = False
elif argv[1] == 'sample':
d = sample_data
scaled = True
if d is None:
d = all_on_max
scaled = False
if len(argv) > 2:
display = argv[2]
else:
display = 'length'
data = { # test data
'A': 0.4, 'T': 0.3, 'G': 0.6, 'C': 0.2,
'AT': 0.4, 'AC': 0.3, 'AG': 0.2,
'TA': 0.2, 'TC': 0.3, 'TG': 0.4,
'CT': 0.2, 'CG': 0.3, 'CA': 0.2,
'GA': 0.1, 'GT': 0.4, 'GC': 0.1,
}

size = 4
plt.figure(figsize=(size, size))
fig = plt.figure(figsize=(3 * size, size), constrained_layout=True)
axs = fig.subplot_mosaic([["length", "width", "alpha"]])
Copy link
Member
@story645 story645 Jun 8, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I find it really hard to learn these concepts from the demo (and I feel like @jklymak may also have opinions here).

I really like that you broke this out into three subplots but is there a way to make it clearer, especially given how much code is in this demo? Either by making the different length, width, alpha values very dramatic, and/or adding the values to the axes label?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think its fine to have this stuff in here. The methods are clickable, and I don't think this is too complicated. OTOH I'd probably just do figsize=(12, 4); we can all see that it'll be 3x wider than tall ;-)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That size is also used later to scale the text size, so I left it in as a variable.


make_arrow_plot(d, display=display, linewidth=0.001, edgecolor=None,
normalize_data=scaled, head_starts_at_zero=True, size=size)
for display, ax in axs.items():
make_arrow_graph(
ax, data, display=display, linewidth=0.001, edgecolor=None,
normalize_data=True, size=size)

plt.show()
0