7
7
8
8
import numpy as np
9
9
10
+
10
11
def sankey (ax ,
11
12
outputs = [100. ], outlabels = None ,
12
13
inputs = [100. ], inlabels = '' ,
13
14
dx = 40 , dy = 10 , outangle = 45 , w = 3 , inangle = 30 , offset = 2 , ** kwargs ):
14
15
"""Draw a Sankey diagram.
15
16
16
- outputs: array of outputs, should sum up to 100%
17
- outlabels: output labels (same length as outputs),
18
- or None (use default labels) or '' (no labels)
19
- inputs and inlabels: similar for inputs
20
- dx: horizontal elongation
21
- dy: vertical elongation
22
- outangle: output arrow angle [deg]
23
- w: output arrow shoulder
24
- inangle: input dip angle
25
- offset: text offset
26
- **kwargs: propagated to Patch (e.g. fill=False)
27
-
28
- Return (patch,[intexts,outtexts])."""
29
-
17
+ outputs: array of outputs, should sum up to 100%
18
+ outlabels: output labels (same length as outputs),
19
+ or None (use default labels) or '' (no labels)
20
+ inputs and inlabels: similar for inputs
21
+ dx: horizontal elongation
22
+ dy: vertical elongation
23
+ outangle: output arrow angle [deg]
24
+ w: output arrow shoulder
25
+ inangle: input dip angle
26
+ offset: text offset
27
+ **kwargs: propagated to Patch (e.g. fill=False)
28
+
29
+ Return (patch,[intexts,outtexts]).
30
+ """
30
31
import matplotlib .patches as mpatches
31
32
from matplotlib .path import Path
32
33
33
34
outs = np .absolute (outputs )
34
35
outsigns = np .sign (outputs )
35
- outsigns [- 1 ] = 0 # Last output
36
+ outsigns [- 1 ] = 0 # Last output
36
37
37
38
ins = np .absolute (inputs )
38
39
insigns = np .sign (inputs )
39
- insigns [0 ] = 0 # First input
40
+ insigns [0 ] = 0 # First input
40
41
41
- assert sum (outs )== 100 , "Outputs don't sum up to 100%"
42
- assert sum (ins )== 100 , "Inputs don't sum up to 100%"
42
+ assert sum (outs ) == 100 , "Outputs don't sum up to 100%"
43
+ assert sum (ins ) == 100 , "Inputs don't sum up to 100%"
43
44
44
45
def add_output (path , loss , sign = 1 ):
45
- h = (loss / 2 + w )* np .tan (outangle / 180. * np .pi ) # Arrow tip height
46
- move ,(x ,y ) = path [- 1 ] # Use last point as reference
47
- if sign == 0 : # Final loss (horizontal)
48
- path .extend ([(Path .LINETO ,[x + dx ,y ]),
49
- (Path .LINETO ,[x + dx ,y + w ]),
50
- (Path .LINETO ,[x + dx + h ,y - loss / 2 ]), # Tip
51
- (Path .LINETO ,[x + dx ,y - loss - w ]),
52
- (Path .LINETO ,[x + dx ,y - loss ])])
53
- outtips .append ((sign ,path [- 3 ][1 ]))
54
- else : # Intermediate loss (vertical)
55
- path .extend ([(Path .CURVE4 ,[x + dx / 2 ,y ]),
56
- (Path .CURVE4 ,[x + dx ,y ]),
57
- (Path .CURVE4 ,[x + dx ,y + sign * dy ]),
58
- (Path .LINETO ,[x + dx - w ,y + sign * dy ]),
59
- (Path .LINETO ,[x + dx + loss / 2 ,y + sign * (dy + h )]), # Tip
60
- (Path .LINETO ,[x + dx + loss + w ,y + sign * dy ]),
61
- (Path .LINETO ,[x + dx + loss ,y + sign * dy ]),
62
- (Path .CURVE3 ,[x + dx + loss ,y - sign * loss ]),
63
- (Path .CURVE3 ,[x + dx / 2 + loss ,y - sign * loss ])])
64
- outtips .append ((sign ,path [- 5 ][1 ]))
46
+ h = (loss / 2 + w )* np .tan (outangle / 180. * np .pi ) # Arrow tip height
47
+ move , (x , y ) = path [- 1 ] # Use last point as reference
48
+ if sign == 0 : # Final loss (horizontal)
49
+ path .extend ([(Path .LINETO , [x + dx , y ]),
50
+ (Path .LINETO , [x + dx , y + w ]),
51
+ (Path .LINETO , [x + dx + h , y - loss / 2 ]), # Tip
52
+ (Path .LINETO , [x + dx , y - loss - w ]),
53
+ (Path .LINETO , [x + dx , y - loss ])])
54
+ outtips .append ((sign , path [- 3 ][1 ]))
55
+ else : # Intermediate loss (vertical)
56
+ path .extend ([(Path .CURVE4 , [x + dx / 2 , y ]),
57
+ (Path .CURVE4 , [x + dx , y ]),
58
+ (Path .CURVE4 , [x + dx , y + sign * dy ]),
59
+ (Path .LINETO , [x + dx - w , y + sign * dy ]),
60
+ (Path .LINETO , [x + dx + loss / 2 , y + sign * (dy + h )]), # Tip
61
+ (Path .LINETO , [x + dx + loss + w , y + sign * dy ]),
62
+ (Path .LINETO , [x + dx + loss , y + sign * dy ]),
63
+ (Path .CURVE3 , [x + dx + loss , y - sign * loss ]),
64
+ (Path .CURVE3 , [x + dx / 2 + loss , y - sign * loss ])])
65
+ outtips .append ((sign , path [- 5 ][1 ]))
65
66
66
67
def add_input (path , gain , sign = 1 ):
67
- h = (gain / 2 )* np .tan (inangle / 180. * np .pi ) # Dip depth
68
- move ,(x ,y ) = path [- 1 ] # Use last point as reference
69
- if sign == 0 : # First gain (horizontal)
70
- path .extend ([(Path .LINETO ,[x - dx ,y ]),
71
- (Path .LINETO ,[x - dx + h ,y + gain / 2 ]), # Dip
72
- (Path .LINETO ,[x - dx ,y + gain ])])
73
- xd ,yd = path [- 2 ][1 ] # Dip position
74
- indips .append ((sign ,[xd - h ,yd ]))
75
- else : # Intermediate gain (vertical)
76
- path .extend ([(Path .CURVE4 ,[x - dx / 2 ,y ]),
77
- (Path .CURVE4 ,[x - dx ,y ]),
78
- (Path .CURVE4 ,[x - dx ,y + sign * dy ]),
79
- (Path .LINETO ,[x - dx - gain / 2 ,y + sign * (dy - h )]), # Dip
80
- (Path .LINETO ,[x - dx - gain ,y + sign * dy ]),
81
- (Path .CURVE3 ,[x - dx - gain ,y - sign * gain ]),
82
- (Path .CURVE3 ,[x - dx / 2 - gain ,y - sign * gain ])])
83
- xd ,yd = path [- 4 ][1 ] # Dip position
84
- indips .append ((sign ,[xd ,yd + sign * h ]))
85
-
86
- outtips = [] # Output arrow tip dir. and positions
87
- urpath = [(Path .MOVETO ,[0 ,100 ])] # 1st point of upper right path
88
- lrpath = [(Path .LINETO ,[0 ,0 ])] # 1st point of lower right path
89
- for loss ,sign in zip (outs ,outsigns ):
68
+ h = (gain / 2 )* np .tan (inangle / 180. * np .pi ) # Dip depth
69
+ move , (x , y ) = path [- 1 ] # Use last point as reference
70
+ if sign == 0 : # First gain (horizontal)
71
+ path .extend ([(Path .LINETO , [x - dx , y ]),
72
+ (Path .LINETO , [x - dx + h , y + gain / 2 ]), # Dip
73
+ (Path .LINETO , [x - dx , y + gain ])])
74
+ xd , yd = path [- 2 ][1 ] # Dip position
75
+ indips .append ((sign , [xd - h , yd ]))
76
+ else : # Intermediate gain (vertical)
77
+ path .extend ([(Path .CURVE4 , [x - dx / 2 , y ]),
78
<
CE7
code class="diff-text syntax-highlighted-line addition">+ (Path .CURVE4 , [x - dx , y ]),
79
+ (Path .CURVE4 , [x - dx , y + sign * dy ]),
80
+ (Path .LINETO , [x - dx - gain / 2 , y + sign * (dy - h )]), # Dip
81
+ (Path .LINETO , [x - dx - gain , y + sign * dy ]),
82
+ (Path .CURVE3 , [x - dx - gain , y - sign * gain ]),
83
+ (Path .CURVE3 , [x - dx / 2 - gain , y - sign * gain ])])
84
+ xd , yd = path [- 4 ][1 ] # Dip position
85
+ indips .append ((sign , [xd , yd + sign * h ]))
86
+
87
+ outtips = [] # Output arrow tip dir. and positions
88
+ urpath = [(Path .MOVETO , [0 , 100 ])] # 1st point of upper right path
89
+ lrpath = [(Path .LINETO , [0 , 0 ])] # 1st point of lower right path
90
+ for loss , sign in zip (outs , outsigns ):
90
91
add_output (sign >= 0 and urpath or lrpath , loss , sign = sign )
91
92
92
- indips = [] # Input arrow tip dir. and positions
93
- llpath = [(Path .LINETO ,[0 ,0 ])] # 1st point of lower left path
94
- ulpath = [(Path .MOVETO ,[0 ,100 ])] # 1st point of upper left path
95
- for gain ,sign in zip (ins ,insigns )[:: - 1 ] :
93
+ indips = [] # Input arrow tip dir. and positions
94
+ llpath = [(Path .LINETO , [0 , 0 ])] # 1st point of lower left path
95
+ ulpath = [(Path .MOVETO , [0 , 100 ])] # 1st point of upper left path
96
+ for gain , sign in reversed ( list ( zip (ins , insigns ))) :
96
97
add_input (sign <= 0 and llpath or ulpath , gain , sign = sign )
97
98
98
99
def revert (path ):
99
100
"""A path is not just revertable by path[::-1] because of Bezier
100
- curves."""
101
+ curves."""
101
102
rpath = []
102
103
nextmove = Path .LINETO
103
- for move ,pos in path [::- 1 ]:
104
- rpath .append ((nextmove ,pos ))
104
+ for move , pos in path [::- 1 ]:
105
+ rpath .append ((nextmove , pos ))
105
106
nextmove = move
106
107
return rpath
107
108
108
109
# Concatenate subpathes in correct order
109
110
path = urpath + revert (lrpath ) + llpath + revert (ulpath )
110
111
111
- codes ,verts = zip (* path )
112
+ codes , verts = zip (* path )
112
113
verts = np .array (verts )
113
114
114
115
# Path patch
115
- path = Path (verts ,codes )
116
+ path = Path (verts , codes )
116
117
patch = mpatches .PathPatch (path , ** kwargs )
117
118
ax .add_patch (patch )
118
119
119
- if False : # DEBUG
120
+ if False : # DEBUG
120
121
print ("urpath" , urpath )
121
122
print ("lrpath" , revert (lrpath ))
122
123
print ("llpath" , llpath )
123
124
print ("ulpath" , revert (ulpath ))
124
-
125
- xs ,ys = zip (* verts )
126
- ax .plot (xs ,ys ,'go-' )
125
+ xs , ys = zip (* verts )
126
+ ax .plot (xs , ys , 'go-' )
127
127
128
128
# Labels
129
129
130
- def set_labels (labels ,values ):
130
+ def set_labels (labels , values ):
131
131
"""Set or check labels according to values."""
132
- if labels == '' : # No labels
132
+ if labels == '' : # No labels
133
133
return labels
134
- elif labels is None : # Default labels
135
- return [ '%2d%%' % val for val in values ]
134
+ elif labels is None : # Default labels
135
+ return ['%2d%%' % val for val in values ]
136
136
else :
137
- assert len (labels )== len (values )
137
+ assert len (labels ) == len (values )
138
138
return labels
139
139
140
- def put_labels (labels ,positions ,output = True ):
140
+ def put_labels (labels , positions , output = True ):
141
141
"""Put labels to positions."""
142
142
texts = []
143
143
lbls = output and labels or labels [::- 1 ]
144
- for i ,label in enumerate (lbls ):
145
- s ,(x ,y ) = positions [i ] # Label direction and position
146
- if s == 0 :
147
- t = ax .text (x + offset ,y , label ,
144
+ for i , label in enumerate (lbls ):
145
+ s , (x , y ) = positions [i ] # Label direction and position
146
+ if s == 0 :
147
+ t = ax .text (x + offset , y , label ,
148
148
ha = output and 'left' or 'right' , va = 'center' )
149
- elif s > 0 :
150
- t = ax .text (x ,y + offset ,label , ha = 'center' , va = 'bottom' )
149
+ elif s > 0 :
150
+ t = ax .text (x , y + offset , label , ha = 'center' , va = 'bottom' )
151
151
else :
152
- t = ax .text (x ,y - offset ,label , ha = 'center' , va = 'top' )
152
+ t = ax .text (x , y - offset , label , ha = 'center' , va = 'top' )
153
153
texts .append (t )
154
154
return texts
155
155
@@ -160,32 +160,30 @@ def put_labels(labels,positions,output=True):
160
160
intexts = put_labels (inlabels , indips , output = False )
161
161
162
162
# Axes management
163
- ax .set_xlim (verts [:,0 ].min ()- dx , verts [:,0 ].max ()+ dx )
164
- ax .set_ylim (verts [:,1 ].min ()- dy , verts [:,1 ].max ()+ dy )
163
+ ax .set_xlim (verts [:, 0 ].min ()- dx , verts [:, 0 ].max ()+ dx )
164
+ ax .set_ylim (verts [:, 1 ].min ()- dy , verts [:, 1 ].max ()+ dy )
165
165
ax .set_aspect ('equal' , adjustable = 'datalim' )
166
166
167
- return patch ,[intexts ,outtexts ]
167
+ return patch , [intexts , outtexts ]
168
+
168
169
169
170
if __name__ == '__main__' :
170
171
171
172
import matplotlib .pyplot as plt
172
173
173
- outputs = [10. ,- 20. ,5. ,15. ,- 10. ,40. ]
174
- outlabels = ['First' ,'Second' ,'Third' ,'Fourth' ,'Fifth' ,'Hurray!' ]
175
- outlabels = [ s + '\n %d%%' % abs (l ) for l ,s in zip (outputs ,outlabels ) ]
174
+ outputs = [10. , - 20. , 5. , 15. , - 10. , 40. ]
175
+ outlabels = ['First' , 'Second' , 'Third' , 'Fourth' , 'Fifth' , 'Hurray!' ]
176
+ outlabels = [s + '\n %d%%' % abs (l ) for l , s in zip (outputs , outlabels )]
176
177
177
- inputs = [60. ,- 25. ,15. ]
178
+ inputs = [60. , - 25. , 15. ]
178
179
179
180
fig = plt .figure ()
180
- ax = fig .add_subplot (1 ,1 ,1 , xticks = [],yticks = [],
181
- title = "Sankey diagram"
182
- )
181
+ ax = fig .add_subplot (1 , 1 , 1 , xticks = [], yticks = [], title = "Sankey diagram" )
183
182
184
- patch ,(intexts ,outtexts ) = sankey (ax , outputs = outputs , outlabels = outlabels ,
185
- inputs = inputs , inlabels = None ,
186
- fc = 'g' , alpha = 0.2 )
183
+ patch , (intexts , outtexts ) = sankey (ax , outputs = outputs ,
184
+ outlabels = outlabels , inputs = inputs ,
185
+ inlabels = None , fc = 'g' , alpha = 0.2 )
187
186
outtexts [1 ].set_color ('r' )
188
187
outtexts [- 1 ].set_fontweight ('bold' )
189
188
190
189
plt .show ()
191
-
0 commit comments