@@ -75,8 +75,10 @@ def incoming_edge_map(self):
75
75
76
76
def get_incoming_edges (downstream_node , incoming_edge_map ):
77
77
edges = []
78
- for downstream_label , upstream_info in [(i [0 ], i [1 ]) for i in incoming_edge_map .items ()]:
78
+ for downstream_label , upstream_info in incoming_edge_map .items ():
79
+ # `upstream_info` may contain the upstream_selector. [:2] trims it away
79
80
upstream_node , upstream_label = upstream_info [:2 ]
81
+ # Take into account the stream selector if it's present (i.e. len(upstream_info) >= 3)
80
82
upstream_selector = None if len (upstream_info ) < 3 else upstream_info [2 ]
81
83
edges += [DagEdge (downstream_node , downstream_label , upstream_node , upstream_label , upstream_selector )]
82
84
return edges
@@ -86,7 +88,9 @@ def get_outgoing_edges(upstream_node, outgoing_edge_map):
86
88
edges = []
87
89
for upstream_label , downstream_infos in list (outgoing_edge_map .items ()):
88
90
for downstream_info in downstream_infos :
91
+ # `downstream_info` may contain the downstream_selector. [:2] trims it away
89
92
downstream_node , downstream_label = downstream_info [:2 ]
93
+ # Take into account the stream selector if it's present
90
94
downstream_selector = None if len (downstream_info ) < 3 else downstream_info [2 ]
91
95
edges += [DagEdge (downstream_node , downstream_label , upstream_node , upstream_label , downstream_selector )]
92
96
return edges
@@ -99,8 +103,10 @@ class KwargReprNode(DagNode):
99
103
@property
100
104
def __upstream_hashes (self ):
101
105
hashes = []
102
- # This is needed to allow extra stuff in the incoming_edge_map's value tuples
103
- for downstream_label , (upstream_node , upstream_label ) in [(i [0 ], i [1 ][:2 ]) for i in self .incoming_edge_map .items ()]:
106
+ for downstream_label , upstream_info in self .incoming_edge_map .items ():
107
+ # `upstream_info` may contain the upstream_selector. [:2] trims it away
108
+ upstream_node , upstream_label = upstream_info [:2 ]
109
+ # The stream selector is discarded when calculating the hash: the stream "as a whole" is still the same
104
110
hashes += [hash (x ) for x in [downstream_label , upstream_node , upstream_label ]]
105
111
return hashes
106
112
0 commit comments