10000 Merge pull request #973 from cgohlke/patch-9 · sgowda/matplotlib@11567bb · GitHub
[go: up one dir, main page]

Skip to content
< 8000 div class="d-none">

Commit 11567bb

Browse files
committed
Merge pull request matplotlib#973 from cgohlke/patch-9
Fix sankey.py pep8 and py3 compatibility
2 parents 7a6897e + 829e253 commit 11567bb

File tree

3 files changed

+334
-317
lines changed

3 files changed

+334
-317
lines changed

examples/api/sankey_demo_links.py

Lines changed: 13 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,31 +1,35 @@
11
"""Demonstrate/test the Sankey class by producing a long chain of connections.
22
"""
3-
import numpy as np
4-
import matplotlib.pyplot as plt
53

6-
from matplotlib.sankey import Sankey
74
from itertools import cycle
85

6+
import matplotlib.pyplot as plt
7+
from matplotlib.sankey import Sankey
8+
99
links_per_side = 6
10+
11+
1012
def side(sankey, n=1):
11-
"""Generate a side chain.
12-
"""
13+
"""Generate a side chain."""
1314
prior = len(sankey.diagrams)
1415
colors = cycle(['orange', 'b', 'g', 'r', 'c', 'm', 'y'])
1516
for i in range(0, 2*n, 2):
1617
sankey.add(flows=[1, -1], orientations=[-1, -1],
17-
patchlabel=str(prior+i), facecolor=colors.next(),
18+
patchlabel=str(prior+i), facecolor=next(colors),
1819
prior=prior+i-1, connect=(1, 0), alpha=0.5)
1920
sankey.add(flows=[1, -1], orientations=[1, 1],
20-
patchlabel=str(prior+i+1), facecolor=colors.next(),
21+
patchlabel=str(prior+i+1), facecolor=next(colors),
2122
prior=prior+i, connect=(1, 0), alpha=0.5)
23+
24+
2225
def corner(sankey):
23-
"""Generate a corner link.
24-
"""
26+
"""Generate a corner link."""
2527
prior = len(sankey.diagrams)
2628
sankey.add(flows=[1, -1], orientations=[0, 1],
2729
patchlabel=str(prior), facecolor='k',
2830
prior=prior-1, connect=(1, 0), alpha=0.5)
31+
32+
2933
fig = plt.figure()
3034
ax = fig.add_subplot(1, 1, 1, xticks=[], yticks=[],
3135
title="Why would you want to do this?\n(But you could.)")

examples/api/sankey_demo_old.py

Lines changed: 99 additions & 101 deletions
Original file line numberDiff line numberDiff line change
@@ -7,149 +7,149 @@
77

88
import numpy as np
99

10+
1011
def sankey(ax,
1112
outputs=[100.], outlabels=None,
1213
inputs=[100.], inlabels='',
1314
dx=40, dy=10, outangle=45, w=3, inangle=30, offset=2, **kwargs):
1415
"""Draw a Sankey diagram.
1516
16-
outputs: array of outputs, should sum up to 100%
17-
outlabels: output labels (same length as outputs),
18-
or None (use default labels) or '' (no labels)
19-
inputs and inlabels: similar for inputs
20-
dx: horizontal elongation
21-
dy: vertical elongation
22-
outangle: output arrow angle [deg]
23-
w: output arrow shoulder
24-
inangle: input dip angle
25-
offset: text offset
26-
**kwargs: propagated to Patch (e.g. fill=False)
27-
28-
Return (patch,[intexts,outtexts])."""
29-
17+
outputs: array of outputs, should sum up to 100%
18+
outlabels: output labels (same length as outputs),
19+
or None (use default labels) or '' (no labels)
20+
inputs and inlabels: similar for inputs
21+
dx: horizontal elongation
22+
dy: vertical elongation
23+
outangle: output arrow angle [deg]
24+
w: output arrow shoulder
25+
inangle: input dip angle
26+
offset: text offset
27+
**kwargs: propagated to Patch (e.g. fill=False)
28+
29+
Return (patch,[intexts,outtexts]).
30+
"""
3031
import matplotlib.patches as mpatches
3132
from matplotlib.path import Path
3233

3334
outs = np.absolute(outputs)
3435
outsigns = np.sign(outputs)
35-
outsigns[-1] = 0 # Last output
36+
outsigns[-1] = 0 # Last output
3637

3738
ins = np.absolute(inputs)
3839
insigns = np.sign(inputs)
39-
insigns[0] = 0 # First input
40+
insigns[0] = 0 # First input
4041

41-
assert sum(outs)==100, "Outputs don't sum up to 100%"
42-
assert sum(ins)==100, "Inputs don't sum up to 100%"
42+
assert sum(outs) == 100, "Outputs don't sum up to 100%"
43+
assert sum(ins) == 100, "Inputs don't sum up to 100%"
4344

4445
def add_output(path, loss, sign=1):
45-
h = (loss/2+w)*np.tan(outangle/180.*np.pi) # Arrow tip height
46-
move,(x,y) = path[-1] # Use last point as reference
47-
if sign==0: # Final loss (horizontal)
48-
path.extend([(Path.LINETO,[x+dx,y]),
49-
(Path.LINETO,[x+dx,y+w]),
50-
(Path.LINETO,[x+dx+h,y-loss/2]), # Tip
51-
(Path.LINETO,[x+dx,y-loss-w]),
52-
(Path.LINETO,[x+dx,y-loss])])
53-
outtips.append((sign,path[-3][1]))
54-
else: # Intermediate loss (vertical)
55-
path.extend([(Path.CURVE4,[x+dx/2,y]),
56-
(Path.CURVE4,[x+dx,y]),
57-
(Path.CURVE4,[x+dx,y+sign*dy]),
58-
(Path.LINETO,[x+dx-w,y+sign*dy]),
59-
(Path.LINETO,[x+dx+loss/2,y+sign*(dy+h)]), # Tip
60-
(Path.LINETO,[x+dx+loss+w,y+sign*dy]),
61-
(Path.LINETO,[x+dx+loss,y+sign*dy]),
62-
(Path.CURVE3,[x+dx+loss,y-sign*loss]),
63-
(Path.CURVE3,[x+dx/2+loss,y-sign*loss])])
64-
outtips.append((sign,path[-5][1]))
46+
h = (loss/2 + w)*np.tan(outangle/180. * np.pi) # Arrow tip height
47+
move, (x, y) = path[-1] # Use last point as reference
48+
if sign == 0: # Final loss (horizontal)
49+
path.extend([(Path.LINETO, [x+dx, y]),
50+
(Path.LINETO, [x+dx, y+w]),
51+
(Path.LINETO, [x+dx+h, y-loss/2]), # Tip
52+
(Path.LINETO, [x+dx, y-loss-w]),
53+
(Path.LINETO, [x+dx, y-loss])])
54+
outtips.append((sign, path[-3][1]))
55+
else: # Intermediate loss (vertical)
56+
path.extend([(Path.CURVE4, [x+dx/2, y]),
57+
(Path.CURVE4, [x+dx, y]),
58+
(Path.CURVE4, [x+dx, y+sign*dy]),
59+
(Path.LINETO, [x+dx-w, y+sign*dy]),
60+
(Path.LINETO, [x+dx+loss/2, y+sign*(dy+h)]), # Tip
61+
(Path.LINETO, [x+dx+loss+w, y+sign*dy]),
62+
(Path.LINETO, [x+dx+loss, y+sign*dy]),
63+
(Path.CURVE3, [x+dx+loss, y-sign*loss]),
64+
(Path.CURVE3, [x+dx/2+loss, y-sign*loss])])
65+
outtips.append((sign, path[-5][1]))
6566

6667
def add_input(path, gain, sign=1):
67-
h = (gain/2)*np.tan(inangle/180.*np.pi) # Dip depth
68-
move,(x,y) = path[-1] # Use last point as reference
69-
if sign==0: # First gain (horizontal)
70-
path.extend([(Path.LINETO,[x-dx,y]),
71-
(Path.LINETO,[x-dx+h,y+gain/2]), # Dip
72-
(Path.LINETO,[x-dx,y+gain])])
73-
xd,yd = path[-2][1] # Dip position
74-
indips.append((sign,[xd-h,yd]))
75-
else: # Intermediate gain (vertical)
76-
path.extend([(Path.CURVE4,[x-dx/2,y]),
77-
(Path.CURVE4,[x-dx,y]),
78-
(Path.CURVE4,[x-dx,y+sign*dy]),
79-
(Path.LINETO,[x-dx-gain/2,y+sign*(dy-h)]), # Dip
80-
(Path.LINETO,[x-dx-gain,y+sign*dy]),
81-
(Path.CURVE3,[x-dx-gain,y-sign*gain]),
82-
(Path.CURVE3,[x-dx/2-gain,y-sign*gain])])
83-
xd,yd = path[-4][1] # Dip position
84-
indips.append((sign,[xd,yd+sign*h]))
85-
86-
outtips = [] # Output arrow tip dir. and positions
87-
urpath = [(Path.MOVETO,[0,100])] # 1st point of upper right path
88-
lrpath = [(Path.LINETO,[0,0])] # 1st point of lower right path
89-
for loss,sign in zip(outs,outsigns):
68+
h = (gain/2)*np.tan(inangle/180. * np.pi) # Dip depth
69+
move, (x, y) = path[-1] # Use last point as reference
70+
if sign == 0: # First gain (horizontal)
71+
path.extend([(Path.LINETO, [x-dx, y]),
72+
(Path.LINETO, [x-dx+h, y+gain/2]), # Dip
73+
(Path.LINETO, [x-dx, y+gain])])
74+
xd, yd = path[-2][1] # Dip position
75+
indips.append((sign, [xd-h, yd]))
76+
else: # Intermediate gain (vertical)
77+
path.extend([(Path.CURVE4, [x-dx/2, y]),
78< CE7 code class="diff-text syntax-highlighted-line addition">+
(Path.CURVE4, [x-dx, y]),
79+
(Path.CURVE4, [x-dx, y+sign*dy]),
80+
(Path.LINETO, [x-dx-gain/2, y+sign*(dy-h)]), # Dip
81+
(Path.LINETO, [x-dx-gain, y+sign*dy]),
82+
(Path.CURVE3, [x-dx-gain, y-sign*gain]),
83+
(Path.CURVE3, [x-dx/2-gain, y-sign*gain])])
84+
xd, yd = path[-4][1] # Dip position
85+
indips.append((sign, [xd, yd+sign*h]))
86+
87+
outtips = [] # Output arrow tip dir. and positions
88+
urpath = [(Path.MOVETO, [0, 100])] # 1st point of upper right path
89+
lrpath = [(Path.LINETO, [0, 0])] # 1st point of lower right path
90+
for loss, sign in zip(outs, outsigns):
9091
add_output(sign>=0 and urpath or lrpath, loss, sign=sign)
9192

92-
indips = [] # Input arrow tip dir. and positions
93-
llpath = [(Path.LINETO,[0,0])] # 1st point of lower left path
94-
ulpath = [(Path.MOVETO,[0,100])] # 1st point of upper left path
95-
for gain,sign in zip(ins,insigns)[::-1]:
93+
indips = [] # Input arrow tip dir. and positions
94+
llpath = [(Path.LINETO, [0, 0])] # 1st point of lower left path
95+
ulpath = [(Path.MOVETO, [0, 100])] # 1st point of upper left path
96+
for gain, sign in reversed(list(zip(ins, insigns))):
9697
add_input(sign<=0 and llpath or ulpath, gain, sign=sign)
9798

9899
def revert(path):
99100
"""A path is not just revertable by path[::-1] because of Bezier
100-
curves."""
101+
curves."""
101102
rpath = []
102103
nextmove = Path.LINETO
103-
for move,pos in path[::-1]:
104-
rpath.append((nextmove,pos))
104+
for move, pos in path[::-1]:
105+
rpath.append((nextmove, pos))
105106
nextmove = move
106107
return rpath
107108

108109
# Concatenate subpathes in correct order
109110
path = urpath + revert(lrpath) + llpath + revert(ulpath)
110111

111-
codes,verts = zip(*path)
112+
codes, verts = zip(*path)
112113
verts = np.array(verts)
113114

114115
# Path patch
115-
path = Path(verts,codes)
116+
path = Path(verts, codes)
116117
patch = mpatches.PathPatch(path, **kwargs)
117118
ax.add_patch(patch)
118119

119-
if False: # DEBUG
120+
if False: # DEBUG
120121
print("urpath", urpath)
121122
print("lrpath", revert(lrpath))
122123
print("llpath", llpath)
123124
print("ulpath", revert(ulpath))
124-
125-
xs,ys = zip(*verts)
126-
ax.plot(xs,ys,'go-')
125+
xs, ys = zip(*verts)
126+
ax.plot(xs, ys, 'go-')
127127

128128
# Labels
129129

130-
def set_labels(labels,values):
130+
def set_labels(labels, values):
131131
"""Set or check labels according to values."""
132-
if labels=='': # No labels
132+
if labels == '': # No labels
133133
return labels
134-
elif labels is None: # Default labels
135-
return [ '%2d%%' % val for val in values ]
134+
elif labels is None: # Default labels
135+
return ['%2d%%' % val for val in values]
136136
else:
137-
assert len(labels)==len(values)
137+
assert len(labels) == len(values)
138138
return labels
139139

140-
def put_labels(labels,positions,output=True):
140+
def put_labels(labels, positions, output=True):
141141
"""Put labels to positions."""
142142
texts = []
143143
lbls = output and labels or labels[::-1]
144-
for i,label in enumerate(lbls):
145-
s,(x,y) = positions[i] # Label direction and position
146-
if s==0:
147-
t = ax.text(x+offset,y,label,
144+
for i, label in enumerate(lbls):
145+
s, (x, y) = positions[i] # Label direction and position
146+
if s == 0:
147+
t = ax.text(x+offset, y, label,
148148
ha=output and 'left' or 'right', va='center')
149-
elif s>0:
150-
t = ax.text(x,y+offset,label, ha='center', va='bottom')
149+
elif s > 0:
150+
t = ax.text(x, y+offset, label, ha='center', va='bottom')
151151
else:
152-
t = ax.text(x,y-offset,label, ha='center', va='top')
152+
t = ax.text(x, y-offset, label, ha='center', va='top')
153153
texts.append(t)
154154
return texts
155155

@@ -160,32 +160,30 @@ def put_labels(labels,positions,output=True):
160160
intexts = put_labels(inlabels, indips, output=False)
161161

162162
# Axes management
163-
ax.set_xlim(verts[:,0].min()-dx, verts[:,0].max()+dx)
164-
ax.set_ylim(verts[:,1].min()-dy, verts[:,1].max()+dy)
163+
ax.set_xlim(verts[:, 0].min()-dx, verts[:, 0].max()+dx)
164+
ax.set_ylim(verts[:, 1].min()-dy, verts[:, 1].max()+dy)
165165
ax.set_aspect('equal', adjustable='datalim')
166166

167-
return patch,[intexts,outtexts]
167+
return patch, [intexts, outtexts]
168+
168169

169170
if __name__=='__main__':
170171

171172
import matplotlib.pyplot as plt
172173

173-
outputs = [10.,-20.,5.,15.,-10.,40.]
174-
outlabels = ['First','Second','Third','Fourth','Fifth','Hurray!']
175-
outlabels = [ s+'\n%d%%' % abs(l) for l,s in zip(outputs,outlabels) ]
174+
outputs = [10., -20., 5., 15., -10., 40.]
175+
outlabels = ['First', 'Second', 'Third', 'Fourth', 'Fifth', 'Hurray!']
176+
outlabels = [s+'\n%d%%' % abs(l) for l, s in zip(outputs, outlabels)]
176177

177-
inputs = [60.,-25.,15.]
178+
inputs = [60., -25., 15.]
178179

179180
fig = plt.figure()
180-
ax = fig.add_subplot(1,1,1, xticks=[],yticks=[],
181-
title="Sankey diagram"
182-
)
181+
ax = fig.add_subplot(1, 1, 1, xticks=[], yticks=[], title="Sankey diagram")
183182

184-
patch,(intexts,outtexts) = sankey(ax, outputs=outputs, outlabels=outlabels,
185-
inputs=inputs, inlabels=None,
186-
fc='g', alpha=0.2)
183+
patch, (intexts, outtexts) = sankey(ax, outputs=outputs,
184+
outlabels=outlabels, inputs=inputs,
185+
inlabels=None, fc='g', alpha=0.2)
187186
outtexts[1].set_color('r')
188187
outtexts[-1].set_fontweight('bold')
189188

190189
plt.show()
191-

0 commit comments

Comments
 (0)
0