@@ -594,25 +594,29 @@ Tensor& mkldnn_convolution_pointwise_binary_(
   return other_t;
 }
 
-static inline std::vector<int64_t> padding_r(
-    IntArrayRef padding, IntArrayRef output_padding)
-{
-  // ConvTranpose padding adjustment
-  //
-  // PyTorch uses padding/output_padding:
-  //   osize = (isize - 1) * stride - 2 * padding + dilation * (kernel_size - 1) + output_padding + 1
-  //
-  // MKLDNN uses padding_l/padding_r:
-  //   osize = (isize - 1) * stride - padding_l - padding_r + dilation * (kernel_size - 1) + 1
-  //
-  // So: padding_l = padding, padding_r = padding - output_padding
-  //
-  auto dim = padding.size();
-  std::vector<int64_t> pad_r(dim);
-  for (const auto d : c10::irange(dim)) {
-    pad_r[d] = padding[d] - output_padding[d];
+std::vector<int64_t> _original_deconv_weight_size(
+    const Tensor& weight_t,
+    int64_t groups) {
+  TORCH_CHECK(weight_t.is_mkldnn() || weight_t.is_meta(), "expects weight_t to be mkldnn or meta tensor");
+  // The size of weight_t is the prepacked size:
+  //   Groups > 1:  [g*o, i/g, ...]
+  //   Groups == 1: [o, i, ...]
+  // Returns the original weight size in [i, o, ...].
+  auto dim = weight_t.sizes().size();
+  TORCH_CHECK(dim > 2);
+
+  std::vector<int64_t> weight_IOHW_sizes(dim);
+  if (groups > 1) {
+    weight_IOHW_sizes[0] = weight_t.sizes()[1] * groups;
+    weight_IOHW_sizes[1] = weight_t.sizes()[0] / groups;
+  } else {
+    weight_IOHW_sizes[0] = weight_t.sizes()[1];
+    weight_IOHW_sizes[1] = weight_t.sizes()[0];
   }
-  return pad_r;
+  for (const auto d : c10::irange(2, dim)) {
+    weight_IOHW_sizes[d] = weight_t.sizes()[d];
+  }
+  return weight_IOHW_sizes;
 }
 
 
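Note: the `_original_deconv_weight_size` helper added above inverts oneDNN's prepacked deconv weight layout back to PyTorch's `[i, o, ...]` order. A minimal standalone sketch of the same index arithmetic, using plain `std::vector` instead of `Tensor` (the function name and example shapes here are illustrative only, not part of the patch):

```cpp
#include <cstdint>
#include <vector>

// Prepacked deconv weight: [g*o, i/g, ...] when groups > 1, else [o, i, ...].
// Recover the original PyTorch layout [i, o, ...].
std::vector<int64_t> original_deconv_weight_size_sketch(
    const std::vector<int64_t>& prepacked, int64_t groups) {
  std::vector<int64_t> iohw(prepacked.size());
  iohw[0] = groups > 1 ? prepacked[1] * groups : prepacked[1]; // i
  iohw[1] = groups > 1 ? prepacked[0] / groups : prepacked[0]; // o
  for (size_t d = 2; d < prepacked.size(); ++d) {
    iohw[d] = prepacked[d]; // spatial dims pass through unchanged
  }
  return iohw;
}

// Example: prepacked {8, 2, 3, 3} with groups == 2 recovers {4, 4, 3, 3}.
```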
@@ -625,6 +629,7 @@ Tensor _mkldnn_convolution_transpose(
     IntArrayRef stride,
     IntArrayRef dilation,
     int64_t groups,
+    bool use_channels_last,
     c10::string_view attr = "none",
     torch::List<c10::optional<at::Scalar>> scalars =
         torch::List<c10::optional<at::Scalar>>(),
@@ -644,22 +649,33 @@ Tensor _mkldnn_convolution_transpose(
     TORCH_CHECK(mkldnn_bf16_device_check(),
         "mkldnn_convolution_transpose: bf16 path needs the cpu support avx512bw, avx512vl and avx512dq");
   }
-  bool is_channels_last = input_t.suggest_memory_format() == at::MemoryFormat::ChannelsLast;
 
-  auto output_sizes = conv_input_size(input_t.sizes(), weight_t.sizes(), padding, output_padding, stride, dilation, groups);
-  auto output = at::empty({0}, input_t.options());
+  std::vector<int64_t> weight_IOHW_sizes = weight_t.is_mkldnn() ? _original_deconv_weight_size(weight_t, groups) : weight_t.sizes().vec();
+
+  auto memory_format =
+      mkldnn_convolution_memory_format(input_t.ndimension(), use_channels_last);
+
+  auto input = input_t.is_mkldnn() ? input_t : input_t.contiguous(memory_format);
+  auto weight = weight_t.is_mkldnn() ? weight_t : weight_t.contiguous(memory_format);
 
-  const ideep::tensor x = itensor_from_tensor(input_t);
-  ideep::tensor w = itensor_from_tensor(weight_t);
-  // mkldnn transposed convolution has weight in logical order of OIHW or OIDHW,
-  // while PyTorch has IOHW or IODHW, `._tranpose()` switches strides (no memory copy).
-  w.transpose_(0, 1);
+  auto output_sizes = conv_input_size(input.sizes(), weight_IOHW_sizes, padding, output_padding, stride, dilation, groups);
+  auto output = at::empty({0}, input.options());
+
+  const ideep::tensor x = itensor_from_tensor(input);
+
+  ideep::tensor w = itensor_from_tensor(weight);
+  if (!weight.is_mkldnn()) {
+    // mkldnn transposed convolution has weight in the logical order of OIHW or OIDHW,
+    // while PyTorch has IOHW or IODHW; `.transpose_(0, 1)` switches strides (no memory copy).
+    w.transpose_(0, 1);
+  }
 
   ideep::tensor y;
-  if (is_channels_last) {
-    output.resize_(output_sizes, input_t.suggest_memory_format());
+  if (use_channels_last) {
+    output.resize_(output_sizes, memory_format);
     y = itensor_from_tensor(output);
   }
+
   if (bias.defined()) {
     const ideep::tensor b = itensor_from_tensor(bias);
     ideep::convolution_transpose_forward::compute(
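Note: the `w.transpose_(0, 1)` above relies on a stride swap rather than a data copy, which is why it is skipped for already-prepacked mkldnn weights. The same view semantics can be seen with plain ATen tensors (a hedged analogue of the idea, not the ideep call itself):

```cpp
#include <ATen/ATen.h>

void transpose_is_a_view() {
  // PyTorch deconv weight is IOHW; swapping dims 0 and 1 yields OIHW.
  at::Tensor w = at::randn({4, 8, 3, 3});
  at::Tensor w_oihw = w.transpose(0, 1);
  // Same storage, permuted sizes/strides: no memory was copied.
  TORCH_INTERNAL_ASSERT(w_oihw.data_ptr() == w.data_ptr());
  TORCH_INTERNAL_ASSERT(w_oihw.sizes() == at::IntArrayRef({8, 4, 3, 3}));
}
```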
@@ -687,10 +703,10 @@ Tensor _mkldnn_convolution_transpose(
         groups,
         op_attr);
   }
-  if (input_t.is_mkldnn()) {
-    return MKLDNNTensor(y, input_t.options());
-  } else if (!is_channels_last) {
-    return mkldnn_to_dense(MKLDNNTensor(y, input_t.options()));
+  if (input.is_mkldnn()) {
+    return MKLDNNTensor(y, input.options());
+  } else if (!use_channels_last) {
+    return mkldnn_to_dense(MKLDNNTensor(y, input.options()));
   } else {
     TORCH_INTERNAL_ASSERT(y.get_desc().is_nhwc());
     return output;
@@ -710,6 +726,8 @@ Tensor mkldnn_convolution_transpose_pointwise(
     torch::List<c10::optional<at::Scalar>> scalars,
     c10::optional<c10::string_view> algorithm) {
   c10::impl::ExcludeDispatchKeyGuard edkg(c10::autograd_dispatch_keyset);
+  bool use_channels_last =
+      weight_t.is_mkldnn() || mkldnn_conv_use_channels_last(input_t, weight_t);
   return _mkldnn_convolution_transpose(
       input_t,
       weight_t,
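Note: `use_channels_last` is now decided once at this entry point: prepacked (mkldnn) weights force the channels-last path, otherwise `mkldnn_conv_use_channels_last` inspects the tensors. Roughly, that decision amounts to the following simplified sketch (an approximation for illustration; the real helper also handles 3-D inputs and other details):

```cpp
#include <ATen/ATen.h>

// Simplified approximation of the channels-last heuristic: prefer NHWC when
// either operand already suggests it.
bool conv_use_channels_last_sketch(
    const at::Tensor& input, const at::Tensor& weight) {
  return input.suggest_memory_format() == at::MemoryFormat::ChannelsLast ||
         weight.suggest_memory_format() == at::MemoryFormat::ChannelsLast;
}
```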
@@ -719,12 +737,32 @@ Tensor mkldnn_convolution_transpose_pointwise(
       stride,
       dilation,
       groups,
+      use_channels_last,
       attr,
       scalars,
       algorithm
   );
 }
 
+Tensor mkldnn_convolution_transpose_pointwise_meta(
+    const Tensor& input_t,
+    const Tensor& weight_t,
+    const c10::optional<Tensor>& bias_opt,
+    IntArrayRef padding,
+    IntArrayRef output_padding,
+    IntArrayRef stride,
+    IntArrayRef dilation,
+    int64_t groups,
+    c10::string_view attr,
+    torch::List<c10::optional<at::Scalar>> scalars,
+    c10::optional<c10::string_view> algorithm) {
+
+  std::vector<int64_t> weight_IOHW_sizes = _original_deconv_weight_size(weight_t, groups);
+  auto output_sizes = conv_input_size(input_t.sizes(), weight_IOHW_sizes, padding, output_padding, stride, dilation, groups);
+
+  auto output = at::empty(output_sizes, input_t.options());
+  return output;
+}
 
 Tensor mkldnn_convolution_backward_input(
     IntArrayRef input_size,
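Note: the meta kernel above only materializes the output shape; `_original_deconv_weight_size` plus `conv_input_size` give the transposed-conv output size without running oneDNN. Per dimension this is the formula quoted in the removed `padding_r` comment; a worked sketch of that arithmetic (hypothetical helper, numbers chosen for illustration):

```cpp
#include <cstdint>

// osize = (isize - 1) * stride - 2 * padding
//         + dilation * (kernel_size - 1) + output_padding + 1
int64_t deconv_output_size_1d(
    int64_t isize, int64_t kernel_size, int64_t stride,
    int64_t padding, int64_t output_padding, int64_t dilation) {
  return (isize - 1) * stride - 2 * padding +
      dilation * (kernel_size - 1) + output_padding + 1;
}

// Example: isize=5, kernel=3, stride=2, padding=1, output_padding=1,
// dilation=1 gives (5-1)*2 - 2 + 2 + 1 + 1 = 10.
```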
@@ -871,7 +909,16 @@ TORCH_LIBRARY_IMPL(mkldnn, MkldnnCPU, m) {
   m.impl(
       TORCH_SELECTIVE_NAME("mkldnn::_convolution_pointwise_.binary"),
       TORCH_FN(mkldnn_convolution_pointwise_binary_));
+  m.impl(
+      TORCH_SELECTIVE_NAME("mkldnn::_convolution_transpose_pointwise"),
+      TORCH_FN(mkldnn_convolution_transpose_pointwise));
+}
+
+TORCH_LIBRARY_IMPL(mkldnn, Meta, m) {
+  m.impl(
+      TORCH_SELECTIVE_NAME("mkldnn::_convolution_transpose_pointwise"),
+      TORCH_FN(mkldnn_convolution_transpose_pointwise_meta));
 }
 
 }} // namespace at::native
 
-#endif
+#endif
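Note: the split registration above is the standard pattern: the MkldnnCPU key gets the computing kernel, the Meta key gets the shape-only kernel, and the dispatcher picks one by tensor device. A self-contained toy analogue using the same `torch/library.h` API (the op `myns::my_op` is hypothetical, not part of this patch):

```cpp
#include <torch/library.h>
#include <ATen/ATen.h>

TORCH_LIBRARY(myns, m) {
  m.def("my_op(Tensor x) -> Tensor");
}

TORCH_LIBRARY_IMPL(myns, CPU, m) {
  // Real kernel: runs on CPU tensors and touches data.
  m.impl("my_op", [](const at::Tensor& x) { return x + 1; });
}

TORCH_LIBRARY_IMPL(myns, Meta, m) {
  // Meta kernel: propagates shape/dtype only; no data is allocated or read.
  m.impl("my_op", [](const at::Tensor& x) { return at::empty_like(x); });
}
```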