Fuse MatMul and Add into Gemm (#1542) · onnx/onnx@f85221f · GitHub

Commit f85221f

vloncar authored and houseroad committed
Fuse MatMul and Add into Gemm (#1542)
* Fuse MatMul and Add into Gemm
* Fix signed & unsigned comparison
* Fix whitespace in optimizer_test
* Typecheck fix
* Add MatMul symbol
* Remove unnecessary print statement
* Remove unnecessary check
* Additional tests for fuse_matmul_add_bias_into_gemm
* Compare graphs instead of nodes with asserts
* Additional shape checks and test fixes
* Minor style fix
* Simpler check of shape compatibility
* Reintroduce more strict shape checks
1 parent 022230e commit f85221f
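
For orientation, here is a minimal sketch of what the new pass does from the Python side. It mirrors the test added in this commit and assumes the onnx.optimizer.optimize API that this optimizer pass plugs into; graph and tensor names are illustrative only.

from onnx import helper, optimizer, TensorProto

# Build a tiny graph: A = MatMul(X, Y) + B, with illustrative shapes.
matmul = helper.make_node("MatMul", ["X", "Y"], ["Z"])
add = helper.make_node("Add", ["Z", "B"], ["A"])
graph = helper.make_graph(
    [matmul, add], "matmul_add_example",
    [helper.make_tensor_value_info("X", TensorProto.FLOAT, (32, 10)),
     helper.make_tensor_value_info("Y", TensorProto.FLOAT, (10, 16)),
     helper.make_tensor_value_info("B", TensorProto.FLOAT, (16,))],
    [helper.make_tensor_value_info("A", TensorProto.FLOAT, (32, 16))])
model = helper.make_model(graph)

# Run only the new pass; the MatMul/Add pair should collapse into one Gemm.
optimized = optimizer.optimize(model, ["fuse_matmul_add_bias_into_gemm"])
assert [n.op_type for n in optimized.graph.node] == ["Gemm"]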

File tree: 5 files changed (+223, -2 lines)

onnx/common/interned_strings.h

Lines changed: 1 addition & 0 deletions

@@ -40,6 +40,7 @@ namespace ONNX_NAMESPACE {
 _(Squeeze) \
 _(Undefined) \
 _(FusionGroup) \
+_(MatMul) \
 _(Gemm) \
 _(Tile) \
 _(SubConstant) \

onnx/examples/optimize_onnx.ipynb

Lines changed: 1 addition & 0 deletions

@@ -63,6 +63,7 @@
 "\tfuse_consecutive_reduce_unsqueeze\n",
 "\tfuse_consecutive_squeezes\n",
 "\tfuse_consecutive_transposes\n",
+"\tfuse_matmul_add_bias_into_gemm\n",
 "\tfuse_pad_into_conv\n",
 "\tfuse_transpose_into_gemm\n",
 "\tlift_lexical_references\n",

onnx/optimizer/pass_registry.h

Lines changed: 4 additions & 2 deletions

@@ -15,13 +15,14 @@
 #include "onnx/optimizer/passes/eliminate_unused_initializer.h"
 #include "onnx/optimizer/passes/extract_constant_to_initializer.h"
 #include "onnx/optimizer/passes/fuse_add_bias_into_conv.h"
-#include "onnx/optimizer/passes/fuse_pad_into_conv.h"
 #include "onnx/optimizer/passes/fuse_bn_into_conv.h"
 #include "onnx/optimizer/passes/fuse_consecutive_concats.h"
 #include "onnx/optimizer/passes/fuse_consecutive_log_softmax.h"
 #include "onnx/optimizer/passes/fuse_consecutive_reduce_unsqueeze.h"
 #include "onnx/optimizer/passes/fuse_consecutive_squeezes.h"
 #include "onnx/optimizer/passes/fuse_consecutive_transposes.h"
+#include "onnx/optimizer/passes/fuse_matmul_add_bias_into_gemm.h"
+#include "onnx/optimizer/passes/fuse_pad_into_conv.h"
 #include "onnx/optimizer/passes/fuse_transpose_into_gemm.h"
 #include "onnx/optimizer/passes/lift_lexical_references.h"
 #include "onnx/optimizer/passes/nop.h"
@@ -50,13 +51,14 @@ struct GlobalPassRegistry {
     registerPass<EliminateUnusedInitializer>();
     registerPass<ExtractConstantToInitializer>();
     registerPass<FuseAddBiasIntoConv>();
-    registerPass<FusePadIntoConv>();
     registerPass<FuseBNIntoConv>();
     registerPass<FuseConsecutiveConcats>();
     registerPass<FuseConsecutiveLogSoftmax>();
     registerPass<FuseConsecutiveReduceUnsqueeze>();
     registerPass<FuseConsecutiveSqueezes>();
     registerPass<FuseConsecutiveTransposes>();
+    registerPass<FuseMatMulAddBiasIntoGemm>();
+    registerPass<FusePadIntoConv>();
     registerPass<FuseTransposeIntoGemm>();
     registerPass<LiftLexicalReferences>();
     registerPass<SplitInit>();
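
Registering the pass here is what exposes it by name to the Python optimizer API. A quick sanity check, assuming the onnx.optimizer.get_available_passes helper that lists the registered pass names, could look like this:

from onnx import optimizer

passes = optimizer.get_available_passes()
# After this commit, the new fusion pass should appear in the list.
assert "fuse_matmul_add_bias_into_gemm" in passes
print("\n".join(sorted(passes)))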
onnx/optimizer/passes/fuse_matmul_add_bias_into_gemm.h

Lines changed: 107 additions & 0 deletions

@@ -0,0 +1,107 @@
+// ATTENTION: The code in this file is highly EXPERIMENTAL.
+// Adventurous users should note that the APIs will probably change.
+
+#pragma once
+
+// Before:
+//   Z = MatMul(X, Y)
+//   A = Z + Bias
+// After:
+//   A = Gemm(X, Y, Bias)
+//
+// The pass can handle the case when:
+//   case 1: Bias is a 1D tensor and Bias.dim[0] == Z.dim[1]
+//   case 2: Bias is a 2D tensor and Bias.dim[0] == Z.dim[0] or 1,
+//           and Bias.dim[1] == Z.dim[1]
+
+#include <numeric>
+
+#include "onnx/common/assertions.h"
+#include "onnx/optimizer/pass.h"
+
+namespace ONNX_NAMESPACE {
+namespace optimization {
+
+struct FuseMatMulAddBiasIntoGemm final : public PredicateBasedPass {
+  explicit FuseMatMulAddBiasIntoGemm()
+      : PredicateBasedPass(
+            PassType::Fuse,
+            PassEfficiency::Complete,
+            PassOptimizationType::Compute) {}
+  std::string getPassName() const override {
+    return "fuse_matmul_add_bias_into_gemm";
+  }
+  bool patternMatchPredicate(Node* node) override {
+    return node->kind() == kAdd &&
+        node->inputs()[0]->node()->kind() == kMatMul;
+  }
+  bool runTransform(Node* n, Graph& graph, NodeDestroyType& destroy_current)
+      override {
+    // Due to the current broadcasting constraint, MatMul has to be the first
+    // operand of the Add.
+    destroy_current = NodeDestroyType::DestroyZero;
+    auto orig_matmul = n->inputs()[0];
+    auto orig_bias = n->inputs()[1];
+    // Check that the bias is a Constant node or one of the graph's initializers.
+    if (orig_bias->node()->kind() != kConstant &&
+        orig_bias->node()->kind() != kParam) {
+      return false;
+    }
+    // Check that the MatMul output is only used by the Add.
+    if (orig_matmul->uses().size() > 1) {
+      return false;
+    }
+    auto x_shape = orig_matmul->node()->inputs()[0]->sizes();
+    auto y_shape = orig_matmul->node()->inputs()[1]->sizes();
+    int64_t z_N = -1;
+    int64_t z_M = -1;
+    // Try to get feature N from x_shape.
+    if (static_cast<int64_t>(x_shape.size()) == 2 && x_shape[0].is_int) {
+      z_N = x_shape[0].dim;
+    } else {
+      return false;
+    }
+    // Try to get feature M from y_shape.
+    if (static_cast<int64_t>(y_shape.size()) == 2 && y_shape[1].is_int) {
+      z_M = y_shape[1].dim;
+    } else {
+      return false;
+    }
+    // Check that bias_shape is compatible.
+    auto bias_shape = orig_bias->sizes();
+    auto bias_dim = static_cast<int64_t>(bias_shape.size());
+    int64_t bias_N = -1;
+    int64_t bias_M = -1;
+    if (bias_dim == 1 && bias_shape[0].is_int) {
+      bias_N = 1;
+      bias_M = bias_shape[0].dim;
+    } else if (bias_dim == 2 && bias_shape[0].is_int && bias_shape[1].is_int) {
+      bias_N = bias_shape[0].dim;
+      bias_M = bias_shape[1].dim;
+    } else {
+      return false;
+    }
+    if ((bias_N != z_N && bias_N != 1) || bias_M != z_M) {
+      return false;
+    }
+    // Proceed to fuse MatMul and Add into Gemm.
+    Node* gemm = graph.create(kGemm,
+        orig_matmul->node()->inputs(),
+        n->outputs().size());
+    gemm->addInput(n->inputs()[1]);
+    for (int i = 0; i < static_cast<int64_t>(gemm->outputs().size()); ++i) {
+      gemm->outputs()[i]->copyMetadata(n->outputs()[i]);
+    }
+    gemm->f_(kalpha, 1.0);
+    gemm->f_(kbeta, 1.0);
+    gemm->i_(ktransA, 0);
+    gemm->i_(ktransB, 0);
+    gemm->insertBefore(orig_matmul->node());
+    n->replaceAllUsesWith(gemm);
+    destroy_current = NodeDestroyType::DestroyTwo;
+    return true;
+  }
+};
+
+} // namespace optimization
+} // namespace ONNX_NAMESPACE
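
As a sanity check of the rewrite above, here is a small NumPy sketch (illustrative shapes, not part of the commit) showing that Gemm with alpha = beta = 1 and transA = transB = 0 computes the same result as the MatMul/Add pair for both bias shapes the pass accepts:

import numpy as np

rng = np.random.RandomState(0)
X = rng.randn(32, 10).astype(np.float32)  # (N, K)
Y = rng.randn(10, 16).astype(np.float32)  # (K, M)

def gemm(a, b, c, alpha=1.0, beta=1.0):
    # Gemm as emitted by the pass: no transposes, unit scaling.
    return alpha * (a @ b) + beta * c

# Case 1: 1D bias with Bias.dim[0] == Z.dim[1]
b1 = rng.randn(16).astype(np.float32)
assert np.allclose(X @ Y + b1, gemm(X, Y, b1))

# Case 2: 2D bias with Bias.dim[0] in {Z.dim[0], 1} and Bias.dim[1] == Z.dim[1]
b2 = rng.randn(1, 16).astype(np.float32)
assert np.allclose(X @ Y + b2, gemm(X, Y, b2))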

onnx/test/optimizer_test.py

Lines changed: 110 additions & 0 deletions

@@ -653,6 +653,116 @@ def test_fuse_add_bias_into_conv_squeeze_4d_bias_no_fuse(self):  # type: () -> None
         assert optimized_model.graph.node[0].op_type == 'Conv'
         assert optimized_model.graph.node[1].op_type == 'Add'
 
+    def test_fuse_matmul_add_bias_into_gemm(self):  # type: () -> None
+        matmul = helper.make_node("MatMul", ["X", "Y"], ["Z"])
+        add = helper.make_node("Add", ["Z", "B"], ["A"])
+        graph = helper.make_graph(
+            [matmul, add],
+            "test",
+            [helper.make_tensor_value_info("X", TensorProto.FLOAT, (32, 10)),
+             helper.make_tensor_value_info("Y", TensorProto.FLOAT, (10, 16)),
+             helper.make_tensor_value_info("B", TensorProto.FLOAT, (16,))],
+            [helper.make_tensor_value_info("A", TensorProto.FLOAT, (32, 16))]
+        )
+        optimized_model = self._optimized(graph, ["fuse_matmul_add_bias_into_gemm"])
+
+        assert len(list(optimized_model.graph.node)) == 1
+        assert optimized_model.graph.node[0].op_type == "Gemm"
+
+    def test_fuse_matmul_add_bias_into_gemm_2d_bias(self):  # type: () -> None
+        matmul = helper.make_node("MatMul", ["X", "Y"], ["Z"])
+        add = helper.make_node("Add", ["Z", "B"], ["A"])
+        graph = helper.make_graph(
+            [matmul, add],
+            "test",
+            [helper.make_tensor_value_info("X", TensorProto.FLOAT, (32, 10)),
+             helper.make_tensor_value_info("Y", TensorProto.FLOAT, (10, 16)),
+             helper.make_tensor_value_info("B", TensorProto.FLOAT, (1, 16))],
+            [helper.make_tensor_value_info("A", TensorProto.FLOAT, (32, 16))]
+        )
+        optimized_model = self._optimized(graph, ["fuse_matmul_add_bias_into_gemm"])
+
+        assert len(list(optimized_model.graph.node)) == 1
+        assert optimized_model.graph.node[0].op_type == "Gemm"
+
+    def test_fuse_matmul_add_bias_into_gemm_2d_bias_same_shape(self):  # type: () -> None
+        matmul = helper.make_node("MatMul", ["X", "Y"], ["Z"])
+        add = helper.make_node("Add", ["Z", "B"], ["A"])
+        graph = helper.make_graph(
+            [matmul, add],
+            "test",
+            [helper.make_tensor_value_info("X", TensorProto.FLOAT, (32, 10)),
+             helper.make_tensor_value_info("Y", TensorProto.FLOAT, (10, 16)),
+             helper.make_tensor_value_info("B", TensorProto.FLOAT, (32, 16))],
+            [helper.make_tensor_value_info("A", TensorProto.FLOAT, (32, 16))]
+        )
+        optimized_model = self._optimized(graph, ["fuse_matmul_add_bias_into_gemm"])
+
+        assert len(list(optimized_model.graph.node)) == 1
+        assert optimized_model.graph.node[0].op_type == "Gemm"
+
+    def test_fuse_matmul_add_bias_into_gemm_2d_bias_bcast_no_fuse(self):  # type: () -> None
+        matmul = helper.make_node("MatMul", ["X", "Y"], ["Z"])
+        add = helper.make_node("Add", ["Z", "B"], ["A"])
+        graph = helper.make_graph(
+            [matmul, add],
+            "test",
+            [helper.make_tensor_value_info("X", TensorProto.FLOAT, (1, 10)),
+             helper.make_tensor_value_info("Y", TensorProto.FLOAT, (10, 16)),
+             helper.make_tensor_value_info("B", TensorProto.FLOAT, (16, 16))],
+            [helper.make_tensor_value_info("A", TensorProto.FLOAT, (16, 16))]
+        )
+        optimized_model = self._optimized(graph, ["fuse_matmul_add_bias_into_gemm"])
+
+        assert optimized_model.graph == graph
+
+    def test_fuse_matmul_add_bias_into_gemm_3d_matmul_no_fuse(self):  # type: () -> None
+        matmul = helper.make_node("MatMul", ["X", "Y"], ["Z"])
+        add = helper.make_node("Add", ["Z", "B"], ["A"])
+        graph = helper.make_graph(
+            [matmul, add],
+            "test",
+            [helper.make_tensor_value_info("X", TensorProto.FLOAT, (2, 3, 4)),
+             helper.make_tensor_value_info("Y", TensorProto.FLOAT, (2, 4, 3)),
+             helper.make_tensor_value_info("B", TensorProto.FLOAT, (3, 3))],
+            [helper.make_tensor_value_info("A", TensorProto.FLOAT, (2, 3, 3))]
+        )
+        optimized_model = self._optimized(graph, ["fuse_matmul_add_bias_into_gemm"])
+
+        assert optimized_model.graph == graph
+
+    def test_fuse_matmul_add_bias_into_gemm_3d_bias_no_fuse(self):  # type: () -> None
+        matmul = helper.make_node("MatMul", ["X", "Y"], ["Z"])
+        add = helper.make_node("Add", ["Z", "B"], ["A"])
+        graph = helper.make_graph(
+            [matmul, add],
+            "test",
+            [helper.make_tensor_value_info("X", TensorProto.FLOAT, (32, 10)),
+             helper.make_tensor_value_info("Y", TensorProto.FLOAT, (10, 16)),
+             helper.make_tensor_value_info("B", TensorProto.FLOAT, (4, 1, 16))],
+            [helper.make_tensor_value_info("A", TensorProto.FLOAT, (32, 16))]
+        )
+        optimized_model = self._optimized(graph, ["fuse_matmul_add_bias_into_gemm"])
+
+        assert optimized_model.graph == graph
+
+    def test_fuse_matmul_add_bias_into_gemm_multiple_use_no_fuse(self):  # type: () -> None
+        matmul = helper.make_node("MatMul", ["X", "Y"], ["Z"])
+        identity = helper.make_node("Identity", ["Z"], ["A1"])
+        add = helper.make_node("Add", ["Z", "B"], ["A2"])
+        graph = helper.make_graph(
+            [matmul, add, identity],
+            "test",
+            [helper.make_tensor_value_info("X", TensorProto.FLOAT, (32, 10)),
+             helper.make_tensor_value_info("Y", TensorProto.FLOAT, (10, 16)),
+             helper.make_tensor_value_info("B", TensorProto.FLOAT, (1, 16))],
+            [helper.make_tensor_value_info("A1", TensorProto.FLOAT, (32, 16)),
+             helper.make_tensor_value_info("A2", TensorProto.FLOAT, (32, 16))]
+        )
+        optimized_model = self._optimized(graph, ["fuse_matmul_add_bias_into_gemm"])
+
+        assert optimized_model.graph == graph
+
     def test_fuse_pad_into_conv(self):  # type: () -> None
         pad = helper.make_node(
             "Pad",

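For reference, the shape condition these fuse/no-fuse tests exercise can be summarized in a small helper. This is a hypothetical Python mirror of the C++ checks above (it ignores the unknown-dimension and constant-bias requirements) and is not code from the commit:

def can_fuse(x_shape, y_shape, bias_shape):
    # Mirrors the C++ checks: X and Y must be rank-2, and the bias must
    # broadcast onto Z = MatMul(X, Y) without expansion along the last axis.
    if len(x_shape) != 2 or len(y_shape) != 2:
        return False
    z_n, z_m = x_shape[0], y_shape[1]
    if len(bias_shape) == 1:
        bias_n, bias_m = 1, bias_shape[0]
    elif len(bias_shape) == 2:
        bias_n, bias_m = bias_shape
    else:
        return False
    return bias_n in (1, z_n) and bias_m == z_m

assert can_fuse((32, 10), (10, 16), (16,))           # 1D bias: fused
assert can_fuse((32, 10), (10, 16), (1, 16))         # 2D bias: fused
assert can_fuse((32, 10), (10, 16), (32, 16))        # same shape: fused
assert not can_fuse((1, 10), (10, 16), (16, 16))     # bcast case: no fuse
assert not can_fuse((2, 3, 4), (2, 4, 3), (3, 3))    # 3D MatMul: no fuse
assert not can_fuse((32, 10), (10, 16), (4, 1, 16))  # 3D bias: no fuse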