8000 Added backporting rules for TensorArrayV3 and friends · jbenjos/tensorflow@331f272 · GitHub
[go: up one dir, main page]

8000
Skip to content

Commit 331f272

Browse files
petewardentensorflower-gardener
authored andcommitted
Added backporting rules for TensorArrayV3 and friends
Change: 150645199
1 parent b25d1c7 commit 331f272

File tree

2 files changed

+164
-0
lines changed

2 files changed

+164
-0
lines changed

tensorflow/tools/graph_transforms/backports.cc

Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,5 +61,75 @@ Status BackportConcatV2Transform(const GraphDef& input_graph_def,
6161

6262
REGISTER_GRAPH_TRANSFORM("backport_concatv2", BackportConcatV2Transform);
6363

64+
// Switch any TensorArrayV3 nodes to the v2 version, removing the second output.
65+
Status BackportTensorArrayV3Transform(const GraphDef& input_graph_def,
66+
const TransformFuncContext& context,
67+
GraphDef* output_graph_def) {
68+
std::map<string, string> inputs_to_rename;
69+
GraphDef replaced_graph_def;
70+
TF_RETURN_IF_ERROR(ReplaceMatchingOpTypes(
71+
input_graph_def, {"TensorArrayV3|TensorArrayGradV3"},
72+
[&inputs_to_rename](const NodeMatch& match,
73+
const std::set<string>& input_nodes,
74+
const std::set<string>& output_nodes,
75+
std::vector<NodeDef>* new_nodes) {
76+
const NodeDef& tensor_array_v3_node = match.node;
77+
78+
// All we need to do here is rename the op type, since the attributes
79+
// remain the same.
80+
NodeDef tensor_array_v2_node = tensor_array_v3_node;
81+
if (tensor_array_v3_node.op() == "TensorArrayV3") {
82+
tensor_array_v2_node.set_op("TensorArrayV2");
83+
} else {
84+
tensor_array_v2_node.set_op("TensorArrayGradV2");
85+
}
86+
87+
// The v3 version has a second 'flow' output that's not present in v2,
88+
// so substitute a dummy constant instead in any places that use it.
89+
NodeDef replacement_flow_node;
90+
replacement_flow_node.set_op("Const");
91+
replacement_flow_node.set_name(tensor_array_v3_node.name() +
92+
"/replacement_flow_node");
93+
Tensor replacement_flow_tensor(DT_FLOAT, {});
94+
// I'm picking an arbitrary value for the gradient flow here, for lack
95+
// of a better alternative.
96+
replacement_flow_tensor.flat<float>()(0) = 1.0f;
97+
SetNodeTensorAttr<float>("value", replacement_flow_tensor,
98+
&replacement_flow_node);
99+
inputs_to_rename[tensor_array_v3_node.name() + ":1"] =
100+
replacement_flow_node.name();
101+
102+
new_nodes->push_back(tensor_array_v2_node);
103+
new_nodes->push_back(replacement_flow_node);
104+
return Status::OK();
105+
},
106+
{true}, &replaced_graph_def));
107+
// Update the graph so that any nodes that referred to removed inputs now
108+
// pull from the substitute constants we've added.
109+
GraphDef renamed_graph_def;
110+
TF_RETURN_IF_ERROR(RenameNodeInputs(replaced_graph_def, inputs_to_rename,
111+
std::unordered_set<string>(),
112+
&renamed_graph_def));
113+
TF_RETURN_IF_ERROR(ReplaceMatchingOpTypes(
114+
renamed_graph_def,
115+
{"TensorArrayWriteV3|TensorArrayReadV3|TensorArrayGatherV3|"
116+
"TensorArrayScatterV3|TensorArrayConcatV3|TensorArraySplitV3|"
117+
"TensorArraySizeV3|TensorArrayCloseV3"},
118+
[](const NodeMatch& match, const std::set<string>& input_nodes,
119+
const std::set<string>& output_nodes,
120+
std::vector<NodeDef>* new_nodes) {
121+
const NodeDef& v3_node = match.node;
122+
NodeDef v2_node = v3_node;
123+
v2_node.set_op(v3_node.op().substr(0, v3_node.op().size() - 1) + "2");
124+
new_nodes->push_back(v2_node);
125+
return Status::OK();
126+
},
127+
{true}, output_graph_def));
128+
return Status::OK();
129+
}
130+
131+
REGISTER_GRAPH_TRANSFORM("backport_tensor_array_v3",
132+
BackportTensorArrayV3Transform);
133+
64134
} // namespace graph_transforms
65135
} // namespace tensorflow

tensorflow/tools/graph_transforms/backports_test.cc

Lines changed: 94 additions & 0 deletions
A935
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,9 @@ namespace graph_transforms {
3232
Status BackportConcatV2Transform(const GraphDef& input_graph_def,
3333
const TransformFuncContext& context,
3434
GraphDef* output_graph_def);
35+
Status BackportTensorArrayV3Transform(const GraphDef& input_graph_def,
36+
const TransformFuncContext& context,
37+
GraphDef* output_graph_def);
3538

3639
class BackportConcatV2Test : public ::testing::Test {
3740
protected:
@@ -101,5 +104,96 @@ class BackportConcatV2Test : public ::testing::Test {
101104

102105
TEST_F(BackportConcatV2Test, TestBackportConcatV2) { TestBackportConcatV2(); }
103106

107+
TEST(BackportTensorArrayV3Test, TestBackportTensorArrayV3) {
108+
GraphDef graph_def;
109+
110+
NodeDef* size_node = graph_def.add_node();
111+
size_node->set_name("size_node");
112+
size_node->set_op("Const");
113+
Tensor size_tensor(DT_INT32, {});
114+
size_tensor.flat<int32>()(0) = 1;
115+
SetNodeTensorAttr<float>("value", size_tensor, size_node);
116+
117+
NodeDef* tensor_array_node = graph_def.add_node();
118+
tensor_array_node->set_name("tensor_array_node");
119+
tensor_array_node->set_op("TensorArrayV3");
120+
tensor_array_node->add_input("size_node");
121+
SetNodeAttr("dtype", DT_FLOAT, tensor_array_node);
122+
SetNodeAttr("element_shape", TensorShape({1, 2}), tensor_array_node);
123+
SetNodeAttr("dynamic_size", false, tensor_array_node);
124+
SetNodeAttr("clear_after_read", true, tensor_array_node);
125+
SetNodeAttr("tensor_array_name", "some_name", tensor_array_node);
126+
127+
NodeDef* handle_output_node = graph_def.add_node();
128+
handle_output_node->set_name("handle_output_node");
129+
handle_output_node->set_op("Identity");
130+
handle_output_node->add_input("tensor_array_node:0");
131+
132+
NodeDef* flow_output_node = graph_def.add_node();
133+
flow_output_node->set_name("flow_output_node");
134+
flow_output_node->set_op("Identity");
135+
flow_output_node->add_input("tensor_array_node:1");
136+
137+
NodeDef* tensor_array_grad_node = graph_def.add_node();
138+
tensor_array_grad_node->set_name("tensor_array_grad_node");
139+
tensor_array_grad_node->set_op("TensorArrayGradV3");
140+
tensor_array_grad_node->add_input("tensor_array_node:0");
141+
tensor_array_grad_node->add_input("tensor_array_node:1");
142+
SetNodeAttr("source", "foo", tensor_array_grad_node);
143+
144+
NodeDef* grad_handle_output_node = graph_def.add_node();
145+
grad_handle_output_node->set_name("grad_handle_output_node");
146+
grad_handle_output_node->set_op("Identity");
147+
grad_handle_output_node->add_input("tensor_array_grad_node:0");
148+
149+
NodeDef* grad_flow_output_node = graph_def.add_node();
150+
grad_flow_output_node->set_name("grad_flow_output_node");
151+
grad_flow_output_node->set_op("Identity");
152+
grad_flow_output_node->add_input("tensor_array_grad_node:1");
153+
154+
GraphDef result;
155+
TransformFuncContext context;
156+
context.input_names = {};
157+
context.output_names = {"handle_output_node", "grad_handle_output_node"};
158+
TF_ASSERT_OK(BackportTensorArrayV3Transform(graph_def, context, &result));
159+
160+
std::map<string, const NodeDef*> node_lookup;
161+
MapNamesToNodes(result, &node_lookup);
162+
ASSERT_EQ(1, node_lookup.count("tensor_array_node"));
163+
EXPECT_EQ("TensorArrayV2", node_lookup.at("tensor_array_node")->op());
164+
EXPECT_EQ("TensorArrayGradV2",
165+
node_lookup.at("tensor_array_grad_node")->op());
166+
167+
for (const NodeDef& node : result.node()) {
168+
for (const string& input : node.input()) {
169+
EXPECT_NE("tensor_array_node:1", input);
170+
}
171+
}
172+
}
173+
174+
TEST(BackportTensorArrayV3Test, TestBackportTensorArrayV3Subtypes) {
175+
const std::vector<string> v3_ops = {
176+
"TensorArrayWriteV3", "TensorArrayReadV3", "TensorArrayGatherV3",
177+
"TensorArrayScatterV3", "TensorArrayConcatV3", "TensorArraySplitV3",
178+
"TensorArraySizeV3", "TensorArrayCloseV3"};
179+
for (const string& v3_op : v3_ops) {
180+
GraphDef graph_def;
181+
NodeDef* v3_node = graph_def.add_node();
182+
v3_node->set_name("v3_node");
183+
v3_node->set_op(v3_op);
184+
185+
GraphDef result;
186+
TransformFuncContext context;
187+
context.input_names = {};
188+
context.output_names = {""};
189+
TF_ASSERT_OK(BackportTensorArrayV3Transform(graph_def, context, &result));
190+
191+
std::map<string, const NodeDef*> node_lookup;
192+
MapNamesToNodes(result, &node_lookup);
193+
ASSERT_EQ(1, node_lookup.count("v3_node"));
194+
EXPECT_TRUE(StringPiece(node_lookup.at("v3_node")->op()).ends_with("V2"));
195+
}
196+
}
197+
104198
} // namespace graph_transforms
105199
} // namespace tensorflow

0 commit comments

Comments
 (0)
0