1
1
from __future__ import unicode_literals
2
2
3
- from .dag import topo_sort
3
+ from .dag import get_outgoing_edges , topo_sort
4
4
from functools import reduce
5
5
from past .builtins import basestring
6
6
import copy
@@ -53,16 +53,31 @@ def _get_input_args(input_node):
53
53
return args
54
54
55
55
56
- def _get_filter_spec (i , node , stream_name_map ):
57
- stream_name = _get_stream_name ('v{}' .format (i ))
58
- stream_name_map [node ] = stream_name
59
- inputs = [stream_name_map [edge .upstream_node ] for edge in node .incoming_edges ]
60
- filter_spec = '{}{}{}' .format ('' .join (inputs ), node ._get_filter (), stream_name )
56
+ def _get_filter_spec (node , outgoing_edge_map , stream_name_map ):
57
+ incoming_edges = node .incoming_edges
58
+ outgoing_edges = get_outgoing_edges (node , outgoing_edge_map )
59
+ inputs = [stream_name_map [edge .upstream_node , edge .upstream_label ] for edge in incoming_edges ]
60
+ outputs = [stream_name_map [edge .upstream_node , edge .upstream_label ] for edge in outgoing_edges ]
61
+ filter_spec = '{}{}{}' .format ('' .join (inputs ), node ._get_filter (), '' .join (outputs ))
61
62
return filter_spec
62
63
63
64
64
- def _get_filter_arg (filter_nodes , stream_name_map ):
65
- filter_specs = [_get_filter_spec (i , node , stream_name_map ) for i , node in enumerate (filter_nodes )]
65
+ def _allocate_filter_stream_names (filter_nodes , outgoing_edge_maps , stream_name_map ):
66
+ stream_count = 0
67
+ for upstream_node in filter_nodes :
68
+ outgoing_edge_map = outgoing_edge_maps [upstream_node ]
69
+ for upstream_label , downstreams in outgoing_edge_map .items ():
70
+ if len (downstreams ) > 1 :
71
+ # TODO: automatically insert `splits` ahead of time via graph transformation.
72
+ raise ValueError ('Encountered {} with multiple outgoing edges with same upstream label {!r}; a '
73
+ '`split` filter is probably required' .format (upstream_node , upstream_label ))
74
+ stream_name_map [upstream_node , upstream_label ] = _get_stream_name ('s{}' .format (stream_count ))
75
+ stream_count += 1
76
+
77
+
78
+ def _get_filter_arg (filter_nodes , outgoing_edge_maps , stream_name_map ):
79
+ _allocate_filter_stream_names (filter_nodes , outgoing_edge_maps , stream_name_map )
80
+ filter_specs = [_get_filter_spec (node , outgoing_edge_maps [node ], stream_name_map ) for node in filter_nodes ]
66
81
return ';' .join (filter_specs )
67
82
68
83
@@ -78,7 +93,8 @@ def _get_output_args(node, stream_name_map):
78
93
raise ValueError ('Unsupported output node: {}' .format (node ))
79
94
args = []
80
95
assert len (node .incoming_edges ) == 1
81
- stream_name = stream_name_map [node .incoming_edges [0 ].upstream_node ]
96
+ edge = node .incoming_edges [0 ]
97
+ stream_name = stream_name_map [edge .upstream_node , edge .upstream_label ]
82
98
if stream_name != '[0]' :
83
99
args += ['-map' , stream_name ]
84
100
kwargs = copy .copy (node .kwargs )
@@ -104,8 +120,8 @@ def get_args(stream):
104
120
isinstance (node , GlobalNode )]
105
121
global_nodes = [node for node in sorted_nodes if isinstance (node , GlobalNode )]
106
122
filter_nodes = [node for node in sorted_nodes if node not in (input_nodes + output_nodes + global_nodes )]
107
- stream_name_map = {node : _get_stream_name (i ) for i , node in enumerate (input_nodes )}
108
- filter_arg = _get_filter_arg (filter_nodes , stream_name_map )
123
+ stream_name_map = {( node , None ) : _get_stream_name (i ) for i , node in enumerate (input_nodes )}
124
+ filter_arg = _get_filter_arg (filter_nodes , outgoing_edge_maps , stream_name_map )
109
125
args += reduce (operator .add , [_get_input_args (node ) for node in input_nodes ])
110
126
if filter_arg :
111
127
args += ['-filter_complex' , filter_arg ]
0 commit comments