@@ -32,6 +32,9 @@ namespace graph_transforms {
32
32
Status BackportConcatV2Transform (const GraphDef& input_graph_def,
33
33
const TransformFuncContext& context,
34
34
GraphDef* output_graph_def);
35
+ Status BackportTensorArrayV3Transform (const GraphDef& input_graph_def,
36
+ const TransformFuncContext& context,
37
+ GraphDef* output_graph_def);
35
38
36
39
class BackportConcatV2Test : public ::testing::Test {
37
40
protected:
@@ -101,5 +104,96 @@ class BackportConcatV2Test : public ::testing::Test {
101
104
102
105
TEST_F (BackportConcatV2Test, TestBackportConcatV2) { TestBackportConcatV2 (); }
103
106
107
+ TEST (BackportTensorArrayV3Test, TestBackportTensorArrayV3) {
108
+ GraphDef graph_def;
109
+
110
+ NodeDef* size_node = graph_def.add_node ();
111
+ size_node->set_name (" size_node" );
112
+ size_node->set_op (" Const" );
113
+ Tensor size_tensor (DT_INT32, {});
114
+ size_tensor.flat <int32>()(0 ) = 1 ;
115
+ SetNodeTensorAttr<float >(" value" , size_tensor, size_node);
116
+
117
+ NodeDef* tensor_array_node = graph_def.add_node ();
118
+ tensor_array_node->set_name (" tensor_array_node" );
119
+ tensor_array_node->set_op (" TensorArrayV3" );
120
+ tensor_array_node->add_input (" size_node" );
121
+ SetNodeAttr (" dtype" , DT_FLOAT, tensor_array_node);
122
+ SetNodeAttr (" element_shape" , TensorShape ({1 , 2 }), tensor_array_node);
123
+ SetNodeAttr (" dynamic_size" , false , tensor_array_node);
124
+ SetNodeAttr (" clear_after_read" , true , tensor_array_node);
125
+ SetNodeAttr (" tensor_array_name" , " some_name" , tensor_array_node);
126
+
A935
127
+ NodeDef* handle_output_node = graph_def.add_node ();
128
+ handle_output_node->set_name (" handle_output_node" );
129
+ handle_output_node->set_op (" Identity" );
130
+ handle_output_node->add_input (" tensor_array_node:0" );
131
+
132
+ NodeDef* flow_output_node = graph_def.add_node ();
133
+ flow_output_node->set_name (" flow_output_node" );
134
+ flow_output_node->set_op (" Identity" );
135
+ flow_output_node->add_input (" tensor_array_node:1" );
136
+
137
+ NodeDef* tensor_array_grad_node = graph_def.add_node ();
138
+ tensor_array_grad_node->set_name (" tensor_array_grad_node" );
139
+ tensor_array_grad_node->set_op (" TensorArrayGradV3" );
140
+ tensor_array_grad_node->add_input (" tensor_array_node:0" );
141
+ tensor_array_grad_node->add_input (" tensor_array_node:1" );
142
+ SetNodeAttr (" source" , " foo" , tensor_array_grad_node);
143
+
144
+ NodeDef* grad_handle_output_node = graph_def.add_node ();
145
+ grad_handle_output_node->set_name (" grad_handle_output_node" );
146
+ grad_handle_output_node->set_op (" Identity" );
147
+ grad_handle_output_node->add_input (" tensor_array_grad_node:0" );
148
+
149
+ NodeDef* grad_flow_output_node = graph_def.add_node ();
150
+ grad_flow_output_node->set_name (" grad_flow_output_node" );
151
+ grad_flow_output_node->set_op (" Identity" );
152
+ grad_flow_output_node->add_input (" tensor_array_grad_node:1" );
153
+
154
+ GraphDef result;
155
+ TransformFuncContext context;
156
+ context.input_names = {};
157
+ context.output_names = {" handle_output_node" , " grad_handle_output_node" };
158
+ TF_ASSERT_OK (BackportTensorArrayV3Transform (graph_def, context, &result));
159
+
160
+ std::map<string, const NodeDef*> node_lookup;
161
+ MapNamesToNodes (result, &node_lookup);
162
+ ASSERT_EQ (1 , node_lookup.count (" tensor_array_node" ));
163
+ EXPECT_EQ (" TensorArrayV2" , node_lookup.at (" tensor_array_node" )->op ());
164
+ EXPECT_EQ (" TensorArrayGradV2" ,
165
+ node_lookup.at (" tensor_array_grad_node" )->op ());
166
+
167
+ for (const NodeDef& node : result.node ()) {
168
+ for (const string& input : node.input ()) {
169
+ EXPECT_NE (" tensor_array_node:1" , input);
170
+ }
171
+ }
172
+ }
173
+
174
+ TEST (BackportTensorArrayV3Test, TestBackportTensorArrayV3Subtypes) {
175
+ const std::vector<string> v3_ops = {
176
+ " TensorArrayWriteV3" , " TensorArrayReadV3" , " TensorArrayGatherV3" ,
177
+ " TensorArrayScatterV3" , " TensorArrayConcatV3" , " TensorArraySplitV3" ,
178
+ " TensorArraySizeV3" , " TensorArrayCloseV3" };
179
+ for (const string& v3_op : v3_ops) {
180
+ GraphDef graph_def;
181
+ NodeDef* v3_node = graph_def.add_node ();
182
+ v3_node->set_name (" v3_node" );
183
+ v3_node->set_op (v3_op);
184
+
185
+ GraphDef result;
186
+ TransformFuncContext context;
187
+ context.input_names = {};
188
+ context.output_names = {" " };
189
+ TF_ASSERT_OK (BackportTensorArrayV3Transform (graph_def, context, &result));
190
+
191
+ std::map<string, const NodeDef*> node_lookup;
192
+ MapNamesToNodes (result, &node_lookup);
193
+ ASSERT_EQ (1 , node_lookup.count (" v3_node" ));
194
+ EXPECT_TRUE (StringPiece (node_lookup.at (" v3_node" )->op ()).ends_with (" V2" ));
195
+ }
196
+ }
197
+
104
198
} // namespace graph_transforms
105
199
} // namespace tensorflow
0 commit comments