@@ -96,6 +96,20 @@ def incoming_edge_map(self):
96
96
DagEdge = namedtuple ('DagEdge' , ['downstream_node' , 'downstream_label' , 'upstream_node' , 'upstream_label' ])
97
97
98
98
99
+ def get_incoming_edges (downstream_node , incoming_edge_map ):
100
+ edges = []
101
+ for downstream_label , (upstream_node , upstream_label ) in incoming_edge_map .items ():
102
+ edges += [DagEdge (downstream_node , downstream_label , upstream_node , upstream_label )]
103
+ return edges
104
+
105
+
106
+ def get_outgoing_edges (upstream_node , outgoing_edge_map ):
107
+ edges = []
108
+ for upstream_label , (downstream_node , downstream_label ) in outgoing_edge_map :
109
+ edges += [DagEdge (downstream_node , downstream_label , upstream_node , upstream_label )]
110
+ return edges
111
+
112
+
99
113
class KwargReprNode (DagNode ):
100
114
"""A DagNode that can be represented as a set of args+kwargs.
101
115
"""
@@ -142,11 +156,7 @@ def __repr__(self):
142
156
143
157
@property
144
158
def incoming_edges (self ):
145
- edges = []
146
- for downstream_label , (upstream_node , upstream_label ) in self .incoming_edge_map .items ():
147
- downstream_node = self
148
- edges += [DagEdge (downstream_node , downstream_label , upstream_node , upstream_label )]
149
- return edges
159
+ return get_incoming_edges (self , self .incoming_edge_map )
150
160
151
161
@property
152
162
def incoming_edge_map (self ):
@@ -157,24 +167,29 @@ def short_repr(self):
157
167
return self .name
158
168
159
169
160
- def topo_sort (start_nodes ):
170
+ def topo_sort (downstream_nodes ):
161
171
marked_nodes = []
162
172
sorted_nodes = []
163
- child_map = {}
164
- def visit (node , child ):
165
- if node in marked_nodes :
173
+ outgoing_edge_maps = {}
174
+
175
+ def visit (upstream_node , upstream_label , downstream_node , downstream_label ):
176
+ if upstream_node in marked_nodes :
166
177
raise RuntimeError ('Graph is not a DAG' )
167
- if child is not None :
168
- if node not in child_map :
169
- child_map [node ] = []
170
- child_map [node ].append (child )
171
- if node not in sorted_nodes :
172
- marked_nodes .append (node )
173
- parents = [edge .upstream_node for edge in node .incoming_edges ]
174
- [visit (parent , node ) for parent in parents ]
175
- marked_nodes .remove (node )
176
- sorted_nodes .append (node )
177
- unmarked_nodes = list (copy .copy (start_nodes ))
178
+
179
+ if downstream_node is not None :
180
+ if upstream_node not in outgoing_edge_maps :
181
+ outgoing_edge_maps [upstream_node ] = {}
182
+ outgoing_edge_maps [upstream_node ][upstream_label ] = (downstream_node , downstream_label )
183
+
184
+ if upstream_node not in sorted_nodes :
185
+ marked_nodes .append (upstream_node )
186
+ for edge in upstream_node .incoming_edges :
187
+ visit (edge .upstream_node , edge .upstream_label , edge .downstream_node , edge .downstream_label )
188
+ marked_nodes .remove (upstream_node )
189
+ sorted_nodes .append (upstream_node )
190
+
191
+ unmarked_nodes = [(node , 0 ) for node in downstream_nodes ]
178
192
while unmarked_nodes :
179
- visit (unmarked_nodes .pop (), None )
180
- return sorted_nodes , child_map
193
+ upstream_node , upstream_label = unmarked_nodes .pop ()
194
+ visit (upstream_node , upstream_label , None , None )
195
+ return sorted_nodes , outgoing_edge_maps
0 commit comments