8000 Infer shape from data in Constant nodes (#1667) · onnx/onnx@0ab3c95 · GitHub
[go: up one dir, main page]

Skip to content

Commit 0ab3c95

Browse files
shinhhouseroad
authored andcommitted
Infer shape from data in Constant nodes (#1667)
1 parent 6b34743 commit 0ab3c95

File tree

2 files changed

+27
-3
lines changed

2 files changed

+27
-3
lines changed

onnx/shape_inference/implementation.cc

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -106,9 +106,22 @@ static void InferShapesImpl(
106106
valueTypesByName[vi.name()] = vi.mutable_type();
107107
}
108108

109-
std::unordered_map<std::string, const TensorProto*> initializersByName;
109+
std::unordered_map<std::string, const TensorProto*> inputDataByName;
110110
for (const auto& tp : g->initializer()) {
111-
initializersByName[tp.name()] = &tp;
111+
inputDataByName[tp.name()] = &tp;
112+
}
113+
// Collect data from constant nodes.
114+
for (const auto& n : g->node()) {
115+
if (n.op_type() != "Constant" || n.output().size() != 1) {
116+
continue;
117+
}
118+
for (const auto& attr : n.attribute()) {
119+
if (attr.name() == "value" &&
120+
attr.type() == AttributeProto::TENSOR &&
121+
attr.has_t()) {
122+
inputDataByName[n.output(0)] = &attr.t();
123+
}
124+
}
112125
}
113126

114127
for (auto& n : *g->mutable_node()) {
@@ -122,7 +135,7 @@ static void InferShapesImpl(
122135
const auto schema =
123136
schema_registry->GetSchema(n.op_type(), domain_version, n.domain());
124137
InferenceContextImpl ctx(
125-
n, valueTypesByName, initializersByName, &graphInferenceContext);
138+
n, valueTypesByName, inputDataByName, &graphInferenceContext);
126139
if (!schema) {
127140
if (nullptr == func_registry) {
128141
continue;

onnx/test/shape_inference_test.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -222,6 +222,17 @@ def test_reshape_static_shape_inferred(self): # type: () -> None
222222
initializer=[make_tensor('shape', TensorProto.INT64, (3,), (0, 3, -1))])
223223
self._assert_inferred(graph, [make_tensor_value_info('y', TensorProto.UINT8, (2, 3, 4))])
224224

225+
def test_reshape_static_shape_constant(self): # type: () -> None
226+
graph = self._make_graph(
227+
[('x', TensorProto.UINT8, (2, 4, 3))],
228+
[make_node("Constant", [], ['shape'],
229+
value=make_tensor('shape', TensorProto.INT64, (2,), (3, 8))),
230+
make_node("Reshape", ['x', 'shape'], ['y'])],
231+
[])
232+
self._assert_inferred(graph, [
233+
make_tensor_value_info('shape', TensorProto.INT64, (2,)),
234+
make_tensor_value_info('y', TensorProto.UINT8, (3, 8))])
235+
225236
def test_upsample(self): # type: () -> None
226237
graph = self._make_graph(
227238
[('x', TensorProto.INT32, (2, 4, 3, 5)),

0 commit comments

Comments
 (0)
0