diff --git a/examples/text_labels_and_annotations/arrow_demo.py b/examples/text_labels_and_annotations/arrow_demo.py index 9e678249d1a1..a79c5eb27a05 100644 --- a/examples/text_labels_and_annotations/arrow_demo.py +++ b/examples/text_labels_and_annotations/arrow_demo.py @@ -3,145 +3,75 @@ Arrow Demo ========== -Arrow drawing example for the new fancy_arrow facilities. - -Code contributed by: Rob Knight - -usage: - - python arrow_demo.py realistic|full|sample|extreme +Three ways of drawing arrows to encode arrow "strength" (e.g., transition +probabilities in a Markov model) using arrow length, width, or alpha (opacity). +""" +import itertools -""" import matplotlib.pyplot as plt import numpy as np -rates_to_bases = {'r1': 'AT', 'r2': 'TA', 'r3': 'GA', 'r4': 'AG', 'r5': 'CA', - 'r6': 'AC', 'r7': 'GT', 'r8': 'TG', 'r9': 'CT', 'r10': 'TC', - 'r11': 'GC', 'r12': 'CG'} -numbered_bases_to_rates = {v: k for k, v in rates_to_bases.items()} -lettered_bases_to_rates = {v: 'r' + v for k, v in rates_to_bases.items()} - -def make_arrow_plot(data, size=4, display='length', shape='right', - max_arrow_width=0.03, arrow_sep=0.02, alpha=0.5, - normalize_data=False, ec=None, labelcolor=None, - head_starts_at_zero=True, - rate_labels=lettered_bases_to_rates, - **kwargs): +def make_arrow_graph(ax, data, size=4, display='length', shape='right', + max_arrow_width=0.03, arrow_sep=0.02, alpha=0.5, + normalize_data=False, ec=None, labelcolor=None, + **kwargs): """ Makes an arrow plot. Parameters ---------- + ax + The axes where the graph is drawn. data Dict with probabilities for the bases and pair transitions. size - Size of the graph in inches. + Size of the plot, in inches. display : {'length', 'width', 'alpha'} The arrow property to change. shape : {'full', 'left', 'right'} For full or half arrows. max_arrow_width : float - Maximum width of an arrow, data coordinates. + Maximum width of an arrow, in data coordinates. arrow_sep : float - Separation between arrows in a pair, data coordinates. + Separation between arrows in a pair, in data coordinates. alpha : float Maximum opacity of arrows. **kwargs - Can be anything allowed by a Arrow object, e.g. *linewidth* or - *edgecolor*. + `.FancyArrow` properties, e.g. *linewidth* or *edgecolor*. """ - plt.xlim(-0.5, 1.5) - plt.ylim(-0.5, 1.5) - plt.gcf().set_size_inches(size, size) - plt.xticks([]) - plt.yticks([]) + ax.set(xlim=(-0.5, 1.5), ylim=(-0.5, 1.5), xticks=[], yticks=[]) + ax.text(.01, .01, f'flux encoded as arrow {display}', + transform=ax.transAxes) max_text_size = size * 12 min_text_size = size label_text_size = size * 2.5 - text_params = {'ha': 'center', 'va': 'center', 'family': 'sans-serif', - 'fontweight': 'bold'} - r2 = np.sqrt(2) - - deltas = { - 'AT': (1, 0), - 'TA': (-1, 0), - 'GA': (0, 1), - 'AG': (0, -1), - 'CA': (-1 / r2, 1 / r2), - 'AC': (1 / r2, -1 / r2), - 'GT': (1 / r2, 1 / r2), - 'TG': (-1 / r2, -1 / r2), - 'CT': (0, 1), - 'TC': (0, -1), - 'GC': (1, 0), - 'CG': (-1, 0)} - colors = { - 'AT': 'r', - 'TA': 'k', - 'GA': 'g', - 'AG': 'r', - 'CA': 'b', - 'AC': 'r', - 'GT': 'g', - 'TG': 'k', - 'CT': 'b', - 'TC': 'k', - 'GC': 'g', - 'CG': 'b'} - - label_positions = { - 'AT': 'center', - 'TA': 'center', - 'GA': 'center', - 'AG': 'center', - 'CA': 'left', - 'AC': 'left', - 'GT': 'left', - 'TG': 'left', - 'CT': 'center', - 'TC': 'center', - 'GC': 'center', - 'CG': 'center'} - - def do_fontsize(k): - return float(np.clip(max_text_size * np.sqrt(data[k]), - min_text_size, max_text_size)) - - plt.text(0, 1, '$A_3$', color='r', size=do_fontsize('A'), **text_params) - plt.text(1, 1, '$T_3$', color='k', size=do_fontsize('T'), **text_params) - plt.text(0, 0, '$G_3$', color='g', size=do_fontsize('G'), **text_params) - plt.text(1, 0, '$C_3$', color='b', size=do_fontsize('C'), **text_params) + bases = 'ATGC' + coords = { + 'A': np.array([0, 1]), + 'T': np.array([1, 1]), + 'G': np.array([0, 0]), + 'C': np.array([1, 0]), + } + colors = {'A': 'r', 'T': 'k', 'G': 'g', 'C': 'b'} + + for base in bases: + fontsize = np.clip(max_text_size * data[base]**(1/2), + min_text_size, max_text_size) + ax.text(*coords[base], f'${base}_3$', + color=colors[base], size=fontsize, + horizontalalignment='center', verticalalignment='center', + weight='bold') arrow_h_offset = 0.25 # data coordinates, empirically determined max_arrow_length = 1 - 2 * arrow_h_offset max_head_width = 2.5 * max_arrow_width max_head_length = 2 * max_arrow_width - arrow_params = {'length_includes_head': True, 'shape': shape, - 'head_starts_at_zero': head_starts_at_zero} sf = 0.6 # max arrow size represents this in data coords - d = (r2 / 2 + arrow_h_offset - 0.5) / r2 # distance for diags - r2v = arrow_sep / r2 # offset for diags - - # tuple of x, y for start position - positions = { - 'AT': (arrow_h_offset, 1 + arrow_sep), - 'TA': (1 - arrow_h_offset, 1 - arrow_sep), - 'GA': (-arrow_sep, arrow_h_offset), - 'AG': (arrow_sep, 1 - arrow_h_offset), - 'CA': (1 - d - r2v, d - r2v), - 'AC': (d + r2v, 1 - d + r2v), - 'GT': (d - r2v, d + r2v), - 'TG': (1 - d + r2v, 1 - d - r2v), - 'CT': (1 - arrow_sep, arrow_h_offset), - 'TC': (1 + arrow_sep, 1 - arrow_h_offset), - 'GC': (arrow_h_offset, arrow_sep), - 'CG': (1 - arrow_h_offset, -arrow_sep)} - if normalize_data: # find maximum value for rates, i.e. where keys are 2 chars long max_val = max((v for k, v in data.items() if len(k) == 2), default=0) @@ -149,7 +79,8 @@ def do_fontsize(k): for k, v in data.items(): data[k] = v / max_val * sf - def draw_arrow(pair, alpha=alpha, ec=ec, labelcolor=labelcolor): + # iterate over strings 'AT', 'TA', 'AG', 'GA', etc. + for pair in map(''.join, itertools.permutations(bases, 2)): # set the length of the arrow if display == 'length': length = (max_head_length @@ -159,7 +90,6 @@ def draw_arrow(pair, alpha=alpha, ec=ec, labelcolor=labelcolor): # set the transparency of the arrow if display == 'alpha': alpha = min(data[pair] / sf, alpha) - # set the width of the arrow if display == 'width': scale = data[pair] / sf @@ -171,137 +101,59 @@ def draw_arrow(pair, alpha=alpha, ec=ec, labelcolor=labelcolor): head_width = max_head_width head_length = max_head_length - fc = colors[pair] - ec = ec or fc - - x_scale, y_scale = deltas[pair] - x_pos, y_pos = positions[pair] - plt.arrow(x_pos, y_pos, x_scale * length, y_scale * length, - fc=fc, ec=ec, alpha=alpha, width=width, - head_width=head_width, head_length=head_length, - **arrow_params) - - # figure out coordinates for text + fc = colors[pair[0]] + + cp0 = coords[pair[0]] + cp1 = coords[pair[1]] + # unit vector in arrow direction + delta = cos, sin = (cp1 - cp0) / np.hypot(*(cp1 - cp0)) + x_pos, y_pos = ( + (cp0 + cp1) / 2 # midpoint + - delta * length / 2 # half the arrow length + + np.array([-sin, cos]) * arrow_sep # shift outwards by arrow_sep + ) + ax.arrow( + x_pos, y_pos, cos * length, sin * length, + fc=fc, ec=ec or fc, alpha=alpha, width=width, + head_width=head_width, head_length=head_length, shape=shape, + length_includes_head=True, + ) + + # figure out coordinates for text: # if drawing relative to base: x and y are same as for arrow # dx and dy are one arrow width left and up - # need to rotate based on direction of arrow, use x_scale and y_scale - # as sin x and cos x? - sx, cx = y_scale, x_scale - - where = label_positions[pair] - if where == 'left': - orig_position = 3 * np.array([[max_arrow_width, max_arrow_width]]) - elif where == 'absolute': - orig_position = np.array([[max_arrow_length / 2.0, - 3 * max_arrow_width]]) - elif where == 'right': - orig_position = np.array([[length - 3 * max_arrow_width, - 3 * max_arrow_width]]) - elif where == 'center': - orig_position = np.array([[length / 2.0, 3 * max_arrow_width]]) - else: - raise ValueError("Got unknown position parameter %s" % where) - - M = np.array([[cx, sx], [-sx, cx]]) - coords = np.dot(orig_position, M) + [[x_pos, y_pos]] - x, y = np.ravel(coords) - orig_label = rate_labels[pair] - label = r'$%s_{_{\mathrm{%s}}}$' % (orig_label[0], orig_label[1:]) - - plt.text(x, y, label, size=label_text_size, ha='center', va='center', - color=labelcolor or fc) - - for p in sorted(positions): - draw_arrow(p) - - -# test data -all_on_max = dict([(i, 1) for i in 'TCAG'] + - [(i + j, 0.6) for i in 'TCAG' for j in 'TCAG']) - -realistic_data = { - 'A': 0.4, - 'T': 0.3, - 'G': 0.5, - 'C': 0.2, - 'AT': 0.4, - 'AC': 0.3, - 'AG': 0.2, - 'TA': 0.2, - 'TC': 0.3, - 'TG': 0.4, - 'CT': 0.2, - 'CG': 0.3, - 'CA': 0.2, - 'GA': 0.1, - 'GT': 0.4, - 'GC': 0.1} - -extreme_data = { - 'A': 0.75, - 'T': 0.10, - 'G': 0.10, - 'C': 0.05, - 'AT': 0.6, - 'AC': 0.3, - 'AG': 0.1, - 'TA': 0.02, - 'TC': 0.3, - 'TG': 0.01, - 'CT': 0.2, - 'CG': 0.5, - 'CA': 0.2, - 'GA': 0.1, - 'GT': 0.4, - 'GC': 0.2} - -sample_data = { - 'A': 0.2137, - 'T': 0.3541, - 'G': 0.1946, - 'C': 0.2376, - 'AT': 0.0228, - 'AC': 0.0684, - 'AG': 0.2056, - 'TA': 0.0315, - 'TC': 0.0629, - 'TG': 0.0315, - 'CT': 0.1355, - 'CG': 0.0401, - 'CA': 0.0703, - 'GA': 0.1824, - 'GT': 0.0387, - 'GC': 0.1106} + orig_positions = { + 'base': [3 * max_arrow_width, 3 * max_arrow_width], + 'center': [length / 2, 3 * max_arrow_width], + 'tip': [length - 3 * max_arrow_width, 3 * max_arrow_width], + } + # for diagonal arrows, put the label at the arrow base + # for vertical or horizontal arrows, center the label + where = 'base' if (cp0 != cp1).all() else 'center' + # rotate based on direction of arrow (cos, sin) + M = [[cos, -sin], [sin, cos]] + x, y = np.dot(M, orig_positions[where]) + [x_pos, y_pos] + label = r'$r_{_{\mathrm{%s}}}$' % (pair,) + ax.text(x, y, label, size=label_text_size, ha='center', va='center', + color=labelcolor or fc) if __name__ == '__main__': - from sys import argv - d = None - if len(argv) > 1: - if argv[1] == 'full': - d = all_on_max - scaled = False - elif argv[1] == 'extreme': - d = extreme_data - scaled = False - elif argv[1] == 'realistic': - d = realistic_data - scaled = False - elif argv[1] == 'sample': - d = sample_data - scaled = True - if d is None: - d = all_on_max - scaled = False - if len(argv) > 2: - display = argv[2] - else: - display = 'length' + data = { # test data + 'A': 0.4, 'T': 0.3, 'G': 0.6, 'C': 0.2, + 'AT': 0.4, 'AC': 0.3, 'AG': 0.2, + 'TA': 0.2, 'TC': 0.3, 'TG': 0.4, + 'CT': 0.2, 'CG': 0.3, 'CA': 0.2, + 'GA': 0.1, 'GT': 0.4, 'GC': 0.1, + } size = 4 - plt.figure(figsize=(size, size)) + fig = plt.figure(figsize=(3 * size, size), constrained_layout=True) + axs = fig.subplot_mosaic([["length", "width", "alpha"]]) - make_arrow_plot(d, display=display, linewidth=0.001, edgecolor=None, - normalize_data=scaled, head_starts_at_zero=True, size=size) + for display, ax in axs.items(): + make_arrow_graph( + ax, data, display=display, linewidth=0.001, edgecolor=None, + normalize_data=True, size=size) plt.show()