@@ -653,6 +653,116 @@ def test_fuse_add_bias_into_conv_squeeze_4d_bias_no_fuse(self): # type: () -> N
         assert optimized_model.graph.node[0].op_type == 'Conv'
         assert optimized_model.graph.node[1].op_type == 'Add'
 
+    def test_fuse_matmul_add_bias_into_gemm(self):  # type: () -> None
+        matmul = helper.make_node("MatMul", ["X", "Y"], ["Z"])
+        add = helper.make_node("Add", ["Z", "B"], ["A"])
+        graph = helper.make_graph(
+            [matmul, add],
+            "test",
+            [helper.make_tensor_value_info("X", TensorProto.FLOAT, (32, 10)),
+             helper.make_tensor_value_info("Y", TensorProto.FLOAT, (10, 16)),
+             helper.make_tensor_value_info("B", TensorProto.FLOAT, (16,))],
+            [helper.make_tensor_value_info("A", TensorProto.FLOAT, (32, 16))]
+        )
+        optimized_model = self._optimized(graph, ["fuse_matmul_add_bias_into_gemm"])
+
+        assert len(list(optimized_model.graph.node)) == 1
+        assert optimized_model.graph.node[0].op_type == "Gemm"
+
+    def test_fuse_matmul_add_bias_into_gemm_2d_bias(self):  # type: () -> None
+        matmul = helper.make_node("MatMul", ["X", "Y"], ["Z"])
+        add = helper.make_node("Add", ["Z", "B"], ["A"])
+        graph = helper.make_graph(
+            [matmul, add],
+            "test",
+            [helper.make_tensor_value_info("X", TensorProto.FLOAT, (32, 10)),
+             helper.make_tensor_value_info("Y", TensorProto.FLOAT, (10, 16)),
+             helper.make_tensor_value_info("B", TensorProto.FLOAT, (1, 16))],
+            [helper.make_tensor_value_info("A", TensorProto.FLOAT, (32, 16))]
+        )
+        optimized_model = self._optimized(graph, ["fuse_matmul_add_bias_into_gemm"])
+
+        assert len(list(optimized_model.graph.node)) == 1
+        assert optimized_model.graph.node[0].op_type == "Gemm"
+
+    def test_fuse_matmul_add_bias_into_gemm_2d_bias_same_shape(self):  # type: () -> None
+        matmul = helper.make_node("MatMul", ["X", "Y"], ["Z"])
+        add = helper.make_node("Add", ["Z", "B"], ["A"])
+        graph = helper.make_graph(
+            [matmul, add],
+            "test",
+            [helper.make_tensor_value_info("X", TensorProto.FLOAT, (32, 10)),
+             helper.make_tensor_value_info("Y", TensorProto.FLOAT, (10, 16)),
+             helper.make_tensor_value_info("B", TensorProto.FLOAT, (32, 16))],
+            [helper.make_tensor_value_info("A", TensorProto.FLOAT, (32, 16))]
+        )
+        optimized_model = self._optimized(graph, ["fuse_matmul_add_bias_into_gemm"])
+
+        assert len(list(optimized_model.graph.node)) == 1
+        assert optimized_model.graph.node[0].op_type == "Gemm"
+
+    def test_fuse_matmul_add_bias_into_gemm_2d_bias_bcast_no_fuse(self):  # type: () -> None
+        matmul = helper.make_node("MatMul", ["X", "Y"], ["Z"])
+        add = helper.make_node("Add", ["Z", "B"], ["A"])
+        graph = helper.make_graph(
+            [matmul, add],
+            "test",
+            [helper.make_tensor_value_info("X", TensorProto.FLOAT, (1, 10)),
+             helper.make_tensor_value_info("Y", TensorProto.FLOAT, (10, 16)),
+             helper.make_tensor_value_info("B", TensorProto.FLOAT, (16, 16))],
+            [helper.make_tensor_value_info("A", TensorProto.FLOAT, (16, 16))]
+        )
+        optimized_model = self._optimized(graph, ["fuse_matmul_add_bias_into_gemm"])
+
+        assert optimized_model.graph == graph
+
+    def test_fuse_matmul_add_bias_into_gemm_3d_matmul_no_fuse(self):  # type: () -> None
+        matmul = helper.make_node("MatMul", ["X", "Y"], ["Z"])
+        add = helper.make_node("Add", ["Z", "B"], ["A"])
+        graph = helper.make_graph(
+            [matmul, add],
+            "test",
+            [helper.make_tensor_value_info("X", TensorProto.FLOAT, (2, 3, 4)),
+             helper.make_tensor_value_info("Y", TensorProto.FLOAT, (2, 4, 3)),
+             helper.make_tensor_value_info("B", TensorProto.FLOAT, (3, 3))],
+            [helper.make_tensor_value_info("A", TensorProto.FLOAT, (2, 3, 3))]
+        )
+        optimized_model = self._optimized(graph, ["fuse_matmul_add_bias_into_gemm"])
+
+        assert optimized_model.graph == graph
+
+    def test_fuse_matmul_add_bias_into_gemm_3d_bias_no_fuse(self):  # type: () -> None
+        matmul = helper.make_node("MatMul", ["X", "Y"], ["Z"])
+        add = helper.make_node("Add", ["Z", "B"], ["A"])
+        graph = helper.make_graph(
+            [matmul, add],
+            "test",
+            [helper.make_tensor_value_info("X", TensorProto.FLOAT, (32, 10)),
+             helper.make_tensor_value_info("Y", TensorProto.FLOAT, (10, 16)),
+             helper.make_tensor_value_info("B", TensorProto.FLOAT, (4, 1, 16))],
+            [helper.make_tensor_value_info("A", TensorProto.FLOAT, (32, 16))]
+        )
+        optimized_model = self._optimized(graph, ["fuse_matmul_add_bias_into_gemm"])
+
+        assert optimized_model.graph == graph
+
+    def test_fuse_matmul_add_bias_into_gemm_multiple_use_no_fuse(self):  # type: () -> None
+        matmul = helper.make_node("MatMul", ["X", "Y"], ["Z"])
+        identity = helper.make_node("Identity", ["Z"], ["A1"])
+        add = helper.make_node("Add", ["Z", "B"], ["A2"])
+        graph = helper.make_graph(
+            [matmul, add, identity],
+            "test",
+            [helper.make_tensor_value_info("X", TensorProto.FLOAT, (32, 10)),
+             helper.make_tensor_value_info("Y", TensorProto.FLOAT, (10, 16)),
+             helper.make_tensor_value_info("B", TensorProto.FLOAT, (1, 16))],
+            [helper.make_tensor_value_info("A1", TensorProto.FLOAT, (32, 16)),
+             helper.make_tensor_value_info("A2", TensorProto.FLOAT, (32, 16))]
+        )
+        optimized_model = self._optimized(graph, ["fuse_matmul_add_bias_into_gemm"])
+
+        assert optimized_model.graph == graph
+
     def test_fuse_pad_into_conv(self):  # type: () -> None
         pad = helper.make_node(
             "Pad",
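For context, a minimal standalone sketch of the pattern these tests exercise, assuming the `_optimized` test helper wraps the legacy `onnx.optimizer.optimize(model, passes)` entry point (this API moved to the separate `onnxoptimizer` package in later releases); the graph name below is illustrative:

```python
from onnx import helper, optimizer, TensorProto

# Build the same MatMul -> Add pattern the fusion pass targets.
matmul = helper.make_node("MatMul", ["X", "Y"], ["Z"])
add = helper.make_node("Add", ["Z", "B"], ["A"])
graph = helper.make_graph(
    [matmul, add],
    "fuse_example",  # illustrative graph name
    [helper.make_tensor_value_info("X", TensorProto.FLOAT, (32, 10)),
     helper.make_tensor_value_info("Y", TensorProto.FLOAT, (10, 16)),
     helper.make_tensor_value_info("B", TensorProto.FLOAT, (16,))],
    [helper.make_tensor_value_info("A", TensorProto.FLOAT, (32, 16))],
)
model = helper.make_model(graph)

# Run only the pass under test; the two nodes should collapse into one Gemm.
optimized = optimizer.optimize(model, ["fuse_matmul_add_bias_into_gemm"])
assert [n.op_type for n in optimized.graph.node] == ["Gemm"]
```

The no-fuse cases follow from Gemm's constraints: its inputs must be 2-D and its bias must be unidirectionally broadcastable to the MatMul output (so broadcasting the MatMul result up to the bias shape, a 3-D bias, or a 3-D MatMul cannot fuse), and fusing would be unsafe when the intermediate `Z` has additional consumers.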