Add error handling for missing attributes in NodeDef. · jbenjos/tensorflow@0f9f2e0 · GitHub

Commit 0f9f2e0

Yao Zhang authored and tensorflower-gardener committed
Add error handling for missing attributes in NodeDef.
Change: 150663008
1 parent 5c21d55 commit 0f9f2e0

1 file changed: 73 additions, 35 deletions

tensorflow/core/grappler/optimizers/layout_optimizer.cc (+73, −35)

@@ -134,16 +134,17 @@ class NodeProcessor {
   NodeProcessor(GraphDef* graph, NodeDef* node, NodeMap* node_map)
       : graph_(graph), node_(node), node_map_(node_map) {}
   virtual ~NodeProcessor() {}
-  virtual void ConvertNode() {
+  virtual Status ConvertNode() {
     if (ShouldProcess()) {
       UpdateAttrDataFormat();
       UpdateAttrKSize();
       UpdateAttrStrides();
       UpdateAttrShape();
-      AddLayoutTransposeToInputs();
-      AddLayoutTransposeToOutputs();
-      CustomizedProcessing();
+      TF_RETURN_IF_ERROR(AddLayoutTransposeToInputs());
+      TF_RETURN_IF_ERROR(AddLayoutTransposeToOutputs());
+      TF_RETURN_IF_ERROR(CustomizedProcessing());
     }
+    return Status::OK();
   }

  protected:
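This first hunk carries the core of the change: ConvertNode() now returns a Status, and each fallible step is wrapped in TF_RETURN_IF_ERROR so the first failure propagates to the caller instead of crashing later on a missing attribute. Below is a minimal standalone sketch of that early-return pattern; the Status type and RETURN_IF_ERROR macro are simplified stand-ins for tensorflow::Status and TF_RETURN_IF_ERROR, not the real definitions.

#include <iostream>
#include <string>

// Simplified stand-in for tensorflow::Status.
struct Status {
  bool ok_ = true;
  std::string message;
  static Status OK() { return Status{}; }
  bool ok() const { return ok_; }
};

// Simplified stand-in for TF_RETURN_IF_ERROR: evaluate an expression that
// yields a Status and return early if it is not OK.
#define RETURN_IF_ERROR(expr)   \
  do {                          \
    Status _s = (expr);         \
    if (!_s.ok()) return _s;    \
  } while (0)

// Hypothetical steps standing in for the real member functions.
Status AddLayoutTransposeToInputsSketch() { return Status::OK(); }
Status CustomizedProcessingSketch() {
  return Status{false, "Missing attribute T"};
}

Status ConvertNodeSketch() {
  RETURN_IF_ERROR(AddLayoutTransposeToInputsSketch());  // ok, continues
  RETURN_IF_ERROR(CustomizedProcessingSketch());        // fails, returns here
  return Status::OK();
}

int main() {
  Status s = ConvertNodeSketch();
  std::cout << (s.ok() ? "OK" : s.message) << "\n";  // prints the error
}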
@@ -173,6 +174,14 @@ class NodeProcessor {
     return !outputs.empty();
   }

+  Status HasAttribute(const NodeDef& node, const string& attr) const {
+    if (node.attr().find(attr) == node.attr().end()) {
+      return Status(error::INVALID_ARGUMENT,
+                    strings::StrCat("Missing attribute ", attr));
+    }
+    return Status::OK();
+  }
+
   virtual bool ShouldProcess() const {
     return IsNHWC() && IsDimsFour(node_) && HasOutputs();
   }
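HasAttribute is the guard the rest of the commit leans on: NodeDef keeps its attributes in a protobuf map, and calling attr().at(name) for a key that is absent aborts the process, so presence is now checked up front and reported as INVALID_ARGUMENT. A sketch of the same presence check against a plain std::map follows; FakeNodeDef and the simplified Status are stand-ins, not TensorFlow types.

#include <iostream>
#include <map>
#include <string>

struct Status {  // simplified stand-in for tensorflow::Status
  bool ok_ = true;
  std::string message;
  static Status OK() { return Status{}; }
  bool ok() const { return ok_; }
};

struct FakeNodeDef {
  std::map<std::string, int> attr;  // the real NodeDef maps name -> AttrValue
};

// Mirrors the committed helper: report a missing key instead of letting a
// later attr().at(name) lookup abort the process.
Status HasAttribute(const FakeNodeDef& node, const std::string& attr) {
  if (node.attr.find(attr) == node.attr.end()) {
    return Status{false, "Missing attribute " + attr};
  }
  return Status::OK();
}

int main() {
  FakeNodeDef node;
  node.attr["T"] = 1;
  std::cout << HasAttribute(node, "T").ok() << "\n";                // 1
  std::cout << HasAttribute(node, "_output_shapes").message << "\n";
}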
@@ -218,8 +227,9 @@ class NodeProcessor {
     }
   }

-  void UpdateAttrValue(const string& name) {
+  Status UpdateAttrValue(const string& name) {
     NodeDef* node = node_map_->GetNode(name);
+    TF_RETURN_IF_ERROR(HasAttribute(*node, "value"));
     Tensor tensor;
     auto success =
         tensor.FromProto(node->mutable_attr()->at({"value"}).tensor());
@@ -232,6 +242,7 @@ class NodeProcessor {
     tensor.flat<int>()(1) = c;
     tensor.AsProtoTensorContent(
         node->mutable_attr()->at({"value"}).mutable_tensor());
+    return Status::OK();
   }

   virtual std::vector<int> GetInputPos() const {
@@ -270,13 +281,15 @@ class NodeProcessor {
     node->mutable_attr()->insert({"_output_shapes", attr_output_shape});
   }

-  virtual void AddLayoutTransposeToInputs() {
+  virtual Status AddLayoutTransposeToInputs() {
     std::vector<int> input_pos = GetInputPos();
     for (const auto& pos : input_pos) {
       string node_name_NHWCToNCHW = strings::StrCat(
           kTransposeNHWCToNCHW, "-", node_->name(), "-", node_->input(pos));
       auto input_node = node_map_->GetNode(node_->input(pos));
       int output_pos = NodePosition(node_->input(pos));
+      TF_RETURN_IF_ERROR(HasAttribute(*node_, "T"));
+      TF_RETURN_IF_ERROR(HasAttribute(*input_node, "_output_shapes"));
       AddNodeTranspose(
           node_name_NHWCToNCHW, node_->input(pos), node_->attr().at("T").type(),
           input_node->attr().at("_output_shapes").list().shape(output_pos),
@@ -286,9 +299,10 @@ class NodeProcessor {
       node_map_->AddOutput(node_name_NHWCToNCHW, node_->name());
       *node_->mutable_input(pos) = node_name_NHWCToNCHW;
     }
+    return Status::OK();
   }

-  virtual void AddLayoutTransposeToOutputs() {
+  virtual Status AddLayoutTransposeToOutputs() {
     auto outputs = node_map_->GetOutputs(node_->name());
     for (const auto& output : outputs) {
       string node_name_NCHWToNHWC = strings::StrCat(
@@ -299,6 +313,8 @@ class NodeProcessor {
             return input.compare(node_->name()) == 0;
           });
       int output_pos = NodePosition(*it);
+      TF_RETURN_IF_ERROR(HasAttribute(*node_, "T"));
+      TF_RETURN_IF_ERROR(HasAttribute(*node_, "_output_shapes"));
       AddNodeTranspose(
           node_name_NCHWToNHWC, node_->name(), node_->attr().at("T").type(),
           node_->attr().at("_output_shapes").list().shape(output_pos), false);
@@ -307,9 +323,10 @@ class NodeProcessor {
           node_name_NCHWToNHWC);
       node_map_->AddOutput(node_name_NCHWToNHWC, output->name());
     }
+    return Status::OK();
   }

-  virtual void CustomizedProcessing() {}
+  virtual Status CustomizedProcessing() { return Status::OK(); }

   GraphDef* graph_;
   NodeDef* node_;
@@ -336,7 +353,9 @@ class AvgPoolGradProcessor : public NodeProcessor {
     std::vector<int> input_pos = {1};
     return input_pos;
   }
-  void CustomizedProcessing() override { UpdateAttrValue(node_->input(0)); }
+  Status CustomizedProcessing() override {
+    return UpdateAttrValue(node_->input(0));
+  }
 };

 class BiasAddGradProcessor : public NodeProcessor {
@@ -355,7 +374,7 @@ class BiasAddGradProcessor : public NodeProcessor {
     return false;
   }

-  void AddLayoutTransposeToOutputs() override {}
+  Status AddLayoutTransposeToOutputs() override { return Status::OK(); }
 };

 class Conv2DBackpropFilterProcessor : public NodeProcessor {
@@ -370,7 +389,7 @@ class Conv2DBackpropFilterProcessor : public NodeProcessor {
     return input_pos;
   }

-  void AddLayoutTransposeToOutputs() override {}
+  Status AddLayoutTransposeToOutputs() override { return Status::OK(); }
   // No need to update output shape, as it is always of shape
   // [filter_height, filter_width, in_channels, out_channels], regardless of
   // whether NCHW or NHWC is used.
@@ -388,7 +407,9 @@ class Conv2DBackpropInputProcessor : public NodeProcessor {
     std::vector<int> input_pos = {2};
     return input_pos;
   }
-  void CustomizedProcessing() override { UpdateAttrValue(node_->input(0)); }
+  Status CustomizedProcessing() override {
+    return UpdateAttrValue(node_->input(0));
+  }
 };

 class FusedBatchNormGradProcessor : public NodeProcessor {
@@ -537,19 +558,17 @@ class BinaryOpProcessor : public AgnosticNodeProcessor {
     node->mutable_attr()->insert({"T", attr_type_params});
   }

-  void CustomizedProcessing() override {
+  Status CustomizedProcessing() override {
     if (is_4d_with_vector_) {
       string suffix = strings::StrCat("-", node_->name(), "-", node_->input(1));
       string reshape_node_name = strings::StrCat(kReshapeNHWCToNCHW, suffix);
       string shape_const_node_name = strings::StrCat(kReshapeConst, suffix);
-      int vector_size = node_map_->GetNode(node_->input(1))
-                            ->attr()
-                            .at("_output_shapes")
-                            .list()
-                            .shape(0)
-                            .dim(0)
-                            .size();
+      auto input_node = node_map_->GetNode(node_->input(1));
+      TF_RETURN_IF_ERROR(HasAttribute(*input_node, "_output_shapes"));
+      int vector_size =
+          input_node->attr().at("_output_shapes").list().shape(0).dim(0).size();
       AddNodeShapeConst(shape_const_node_name, vector_size);
+      TF_RETURN_IF_ERROR(HasAttribute(*node_, "T"));
       AddNodeReshape(reshape_node_name, node_->input(1), shape_const_node_name,
                      node_->attr().at("T").type());
       node_map_->AddOutput(shape_const_node_name, reshape_node_name);
@@ -558,6 +577,7 @@ class BinaryOpProcessor : public AgnosticNodeProcessor {
       node_map_->AddOutput(reshape_node_name, node_->name());
       *node_->mutable_input(1) = reshape_node_name;
     }
+    return Status::OK();
   }

  private:
@@ -591,9 +611,10 @@ class ConcatProcessor : public AgnosticNodeProcessor {
     return input_pos;
   }

-  void CustomizedProcessing() override {
+  Status CustomizedProcessing() override {
     node_map_->AddOutput(kConcatConst, node_->name());
     *node_->mutable_input(axis_node_pos_) = kConcatConst;
+    return Status::OK();
   }

   bool IsAlongDimC() const {
@@ -627,18 +648,20 @@ class SliceProcessorGatherBased : public AgnosticNodeProcessor {
       : AgnosticNodeProcessor(graph, node, node_map) {}

  protected:
-  void CustomizedProcessing() override {
+  Status CustomizedProcessing() override {
     // Skip the first input, which is the data to be sliced.
     for (int i = 1; i < node_->input_size(); i++) {
       string node_name_NHWCToNCHW =
           strings::StrCat(kPermVecNHWCToNCHW, "-", node_->name(), "-input", i);
+      TF_RETURN_IF_ERROR(HasAttribute(*node_, "Index"));
       AddNodePermVec(node_name_NHWCToNCHW, node_->input(i),
                      node_->attr().at("Index").type(), true);
       node_map_->UpdateOutput(node_->input(i), node_->name(),
                               node_name_NHWCToNCHW);
       node_map_->AddOutput(node_name_NHWCToNCHW, node_->name());
       *node_->mutable_input(i) = node_name_NHWCToNCHW;
     }
+    return Status::OK();
   }

  private:
@@ -671,7 +694,7 @@ class SliceProcessor : public AgnosticNodeProcessor {
       : AgnosticNodeProcessor(graph, node, node_map) {}

  protected:
-  void CustomizedProcessing() override {
+  Status CustomizedProcessing() override {
     auto maybe_concatoffset_node =
         node_map_->GetNode(NodeName(node_->input(1)));
     if (maybe_concatoffset_node->op() == "ConcatOffset") {
@@ -680,12 +703,14 @@ class SliceProcessor : public AgnosticNodeProcessor {
       // NHWC format is being used. As multiple Slice nodes may share the same
       // ConcatOffset node, the NHWC to NCHW conversion may have already
       // been performed when processing other Slice nodes.
+      TF_RETURN_IF_ERROR(HasAttribute(*axis_node, "value"));
       if (axis_node->attr().at("value").tensor().int_val(0) == 3) {
         for (int i = 1; i < maybe_concatoffset_node->input_size(); i++) {
           auto shape_node =
               node_map_->GetNode(maybe_concatoffset_node->input(i));
           AttrValue attr_tensor;
           Tensor tensor;
+          TF_RETURN_IF_ERROR(HasAttribute(*shape_node, "value"));
           CHECK(tensor.FromProto(shape_node->attr().at({"value"}).tensor()));
           int h = tensor.flat<int>()(1);
           int w = tensor.flat<int>()(2);
@@ -702,6 +727,7 @@ class SliceProcessor : public AgnosticNodeProcessor {
                           1);
       }
     }
+    return Status::OK();
   }
 };

@@ -716,7 +742,7 @@ class SqueezeProcessor : public AgnosticNodeProcessor {
            IsInputConvertible() && IsAlongDimHW();
   }

-  void AddLayoutTransposeToOutputs() override {}
+  Status AddLayoutTransposeToOutputs() override { return Status::OK(); }

   bool IsInputConvertible() const {
     auto input = node_map_->GetNode(node_->input(0));
@@ -745,10 +771,12 @@ class SqueezeProcessor : public AgnosticNodeProcessor {
     return false;
   }

-  void CustomizedProcessing() override {
+  Status CustomizedProcessing() override {
+    TF_RETURN_IF_ERROR(HasAttribute(*node_, "squeeze_dims"));
     auto list = node_->mutable_attr()->at("squeeze_dims").mutable_list();
     list->set_i(0, 2);
     list->set_i(1, 3);
+    return Status::OK();
   }
 };

@@ -765,11 +793,12 @@ class SumProcessor : public AgnosticNodeProcessor {
            IsAlongDimNHW();
   }

-  void AddLayoutTransposeToOutputs() override {}
+  Status AddLayoutTransposeToOutputs() override { return Status::OK(); }

-  void CustomizedProcessing() override {
+  Status CustomizedProcessing() override {
     node_map_->AddOutput(kReductionConst, node_->name());
     *node_->mutable_input(1) = kReductionConst;
+    return Status::OK();
   }

  private:
@@ -798,12 +827,15 @@ class SumProcessor : public AgnosticNodeProcessor {
 class DataLayoutOptimizer {
  public:
   explicit DataLayoutOptimizer(GraphDef* graph)
-      : graph_(graph), node_map_(graph_) {
+      : graph_(graph), node_map_(graph_) {}
+
+  Status Optimize() {
     LOG(INFO) << "Number of nodes for original graph: " << graph_->node_size();
-    Expand();
+    TF_RETURN_IF_ERROR(Expand());
     LOG(INFO) << "Number of nodes after Expand: " << graph_->node_size();
-    Collapse();
+    TF_RETURN_IF_ERROR(Collapse());
     LOG(INFO) << "Number of nodes after Collapse: " << graph_->node_size();
+    return Status::OK();
   }

  private:
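The constructor no longer runs the optimization: fallible work moves out of DataLayoutOptimizer's constructor into a separate Status-returning Optimize() method, since a constructor has no return value through which an error could be reported. A minimal standalone sketch of that split, with a simplified Status stand-in and a placeholder class body, not the real one:

#include <iostream>
#include <string>

struct Status {  // simplified stand-in for tensorflow::Status
  bool ok_ = true;
  std::string message;
  static Status OK() { return Status{}; }
  bool ok() const { return ok_; }
};

// Construction only stores state; anything that can fail lives in a
// Status-returning method, mirroring the ctor/Optimize() split above.
class Optimizer {
 public:
  explicit Optimizer(int* graph) : graph_(graph) {}

  Status Optimize() {
    if (graph_ == nullptr) return Status{false, "no graph"};
    *graph_ += 1;  // placeholder for the real Expand()/Collapse() passes
    return Status::OK();
  }

 private:
  int* graph_;
};

int main() {
  int graph = 0;
  Optimizer opt(&graph);      // cannot fail
  Status s = opt.Optimize();  // failure is reported here instead
  std::cout << (s.ok() ? "OK" : s.message) << "\n";
}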
@@ -860,7 +892,7 @@ class DataLayoutOptimizer {
   }

   // Expand all nodes that are in NHWC but support NCHW or are layout agnostic.
-  void Expand() {
+  Status Expand() {
     int node_size_original = graph_->node_size();
     // This is the first pass where we expand the nodes which support NCHW.
     std::set<string> ops_format_supported = GetOpsFormatSupported();
@@ -890,7 +922,7 @@ class DataLayoutOptimizer {
       } else {
         node_processor.reset(new NodeProcessor(graph_, node, &node_map_));
       }
-      node_processor->ConvertNode();
+      TF_RETURN_IF_ERROR(node_processor->ConvertNode());
     }
   }

@@ -934,15 +966,16 @@ class DataLayoutOptimizer {
           node_processor.reset(
              new AgnosticNodeProcessor(graph_, node, &node_map_));
         }
-        node_processor->ConvertNode();
+        TF_RETURN_IF_ERROR(node_processor->ConvertNode());
       }
     }
   }
+    return Status::OK();
   }

   // Remove all node pairs, where a NCHW-to-NHWC node is followed by
   // a NHWC-to-NCHW node.
-  void Collapse() {
+  Status Collapse() {
     std::unordered_set<string> nodes_removable;
     for (int i = 0; i < graph_->node_size(); i++) {
       auto node = graph_->mutable_node(i);
@@ -974,6 +1007,7 @@ class DataLayoutOptimizer {
             return nodes_removable.find(node.name()) != nodes_removable.end();
           }),
       graph_->mutable_node()->end());
+    return Status::OK();
   }

   GraphDef* graph_;
@@ -984,7 +1018,11 @@ Status LayoutOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
                                  GraphDef* output) {
   *output = item.graph;
   DataLayoutOptimizer layout_optimizer(output);
-  return Status::OK();
+  auto status = layout_optimizer.Optimize();
+  if (!status.ok()) {
+    *output = item.graph;
+  }
+  return status;
 }

 void LayoutOptimizer::Feedback(Cluster* cluster, const GrapplerItem& item,
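The top-level entry point now rolls back on failure: output starts as a copy of item.graph, and if the optimizer returns a non-OK Status the partially rewritten graph is overwritten with that original copy, so callers never observe a half-converted graph. A standalone sketch of the copy-and-restore pattern, again using simplified stand-ins rather than the real GraphDef and Status types:

#include <iostream>
#include <string>
#include <vector>

struct Status {  // simplified stand-in for tensorflow::Status
  bool ok_ = true;
  std::string message;
  static Status OK() { return Status{}; }
  bool ok() const { return ok_; }
};

using Graph = std::vector<std::string>;  // stand-in for GraphDef

// Hypothetical pass that mutates the graph, then fails partway through.
Status RewriteInPlace(Graph* graph) {
  graph->push_back("Transpose-NHWCToNCHW");  // partial mutation
  return Status{false, "Missing attribute _output_shapes"};
}

Status Optimize(const Graph& input, Graph* output) {
  *output = input;  // work on a copy of the input graph
  Status status = RewriteInPlace(output);
  if (!status.ok()) {
    *output = input;  // roll back: discard the half-rewritten graph
  }
  return status;
}

int main() {
  Graph input = {"Conv2D"};
  Graph output;
  Status s = Optimize(input, &output);
  // The failed rewrite leaves output identical to the input graph.
  std::cout << s.message << "; output size = " << output.size() << "\n";
}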
