[Re-open 90266] [inductor] weight prepack for _convolution_transpose_… · pytorch/pytorch@bd4a5b4 · GitHub
[go: up one dir, main page]

Skip to content

Commit bd4a5b4

Browse files
chunyuan-w authored and pytorchmergebot committed
[Re-open 90266] [inductor] weight prepack for _convolution_transpose_pointwise (#91955)
Re-open #90266 since earlier pr on that stack got reverted. Depend on internal ideep upgrade. [Update]: internal ideep upgrade issue is resolved in #92239. Pull Request resolved: #91955 Approved by: https://github.com/jgong5, https://github.com/desertfire
1 parent cc49f5a commit bd4a5b4

File tree

7 files changed

+315
-53
lines changed

7 files changed

+315
-53
lines changed

aten/src/ATen/native/mkldnn/Conv.cpp

Lines changed: 80 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -594,25 +594,29 @@ Tensor& mkldnn_convolution_pointwise_binary_(
594594
return other_t;
595595
}
596596

597-
static inline std::vector<int64_t> padding_r(
598-
IntArrayRef padding, IntArrayRef output_padding)
599-
{
600-
// ConvTranpose padding adjustment
601-
//
602-
// PyTorch uses padding/output_padding:
603-
// osize = (isize - 1) * stride - 2 * padding + dilation * (kernel_size - 1) + output_padding + 1
604-
//
605-
// MKLDNN uses padding_l/padding_r:
606-
// osize = (isize - 1) * stride - padding_l - padding_r + dilation * (kernel_size - 1) + 1
607-
//
608-
// So: padding_l = padding, padding_r = padding - output_padding
609-
//
610-
auto dim = padding.size();
611-
std::vector<int64_t> pad_r(dim);
612-
for (const auto d : c10::irange(dim)) {
613-
pad_r[d] = padding[d] - output_padding[d];
597+
std::vector<int64_t> _original_deconv_weight_size(
598+
const Tensor& weight_t,
599+
int64_t groups) {
600+
TORCH_CHECK(weight_t.is_mkldnn() || weight_t.is_meta(), "expects weight_t to be mkldnn or meta tensor");
601+
// The size of weight_t is the prepacked size.
602+
// Groups > 1: [g*o, i/g, ...]
603+
// Groups == 1: [o, i, ...]
604+
// Returns original weight size in [i, o, ...]
605+
auto dim = weight_t.sizes().size();
606+
TORCH_CHECK(dim > 2);
607+
608+
std::vector<int64_t> weight_IOHW_sizes(dim);
609+
if (groups > 1) {
610+
weight_IOHW_sizes[0] = weight_t.sizes()[1] * groups;
611+
weight_IOHW_sizes[1] = weight_t.sizes()[0] / groups;
612+
} else {
613+
weight_IOHW_sizes[0] = weight_t.sizes()[1];
614+
weight_IOHW_sizes[1] = weight_t.sizes()[0];
614615
}
615-
return pad_r;
616+
for (const auto d : c10::irange(2, dim)) {
617+
weight_IOHW_sizes[d] = weight_t.sizes()[d];
618+
}
619+
return weight_IOHW_sizes;
616620
}
617621

618622

@@ -625,6 +629,7 @@ Tensor _mkldnn_convolution_transpose(
625629
IntArrayRef stride,
626630
IntArrayRef dilation,
627631
int64_t groups,
632+
bool use_channels_last,
628633
c10::string_view attr = "none",
629634
torch::List<c10::optional<at::Scalar>> scalars =
630635
torch::List<c10::optional<at::Scalar>>(),
@@ -644,22 +649,33 @@ Tensor _mkldnn_convolution_transpose(
644649
TORCH_CHECK(mkldnn_bf16_device_check(),
645650
"mkldnn_convolution_transpose: bf16 path needs the cpu support avx512bw, avx512vl and avx512dq");
646651
}
647-
bool is_channels_last = input_t.suggest_memory_format() == at::MemoryFormat::ChannelsLast;
648652

649-
auto output_sizes = conv_input_size(input_t.sizes(), weight_t.sizes(), padding, output_padding, stride, dilation, groups);
650-
auto output = at::empty({0}, input_t.options());
653+
std::vector<int64_t> weight_IOHW_sizes = weight_t.is_mkldnn() ? _original_deconv_weight_size(weight_t, groups) : weight_t.sizes().vec();
654+
655+
auto memory_format =
656+
mkldnn_convolution_memory_format(input_t.ndimension(), use_channels_last);
657+
658+
auto input = input_t.is_mkldnn() ? input_t : input_t.contiguous(memory_format);
659+
auto weight = weight_t.is_mkldnn() ? weight_t : weight_t.contiguous(memory_format);
651660

652-
const ideep::tensor x = itensor_from_tensor(input_t);
653-
ideep::tensor w = itensor_from_tensor(weight_t);
654-
// mkldnn transposed convolution has weight in logical order of OIHW or OIDHW,
655-
// while PyTorch has IOHW or IODHW, `._tranpose()` switches strides (no memory copy).
656-
w.transpose_(0, 1);
661+
auto output_sizes = conv_input_size(input.sizes(), weight_IOHW_sizes, padding, output_padding, stride, dilation, groups);
662+
auto output = at::empty({0}, input.options());
663+
664+
const ideep::tensor x = itensor_from_tensor(input);
665+
666+
ideep::tensor w = itensor_from_tensor(weight);
667+
if (!weight.is_mkldnn()) {
668+
// mkldnn transposed convolution has weight in logical order of OIHW or OIDHW,
669+
// while PyTorch has IOHW or IODHW, `._tranpose()` switches strides (no memory copy).
670+
w.transpose_(0, 1);
671+
}
657672

658673
ideep::tensor y;
659-
if (is_channels_last) {
660-
output.resize_(output_sizes, input_t.suggest_memory_format());
674+
if (use_channels_last) {
675+
output.resize_(output_sizes, memory_format);
661676
y = itensor_from_tensor(output);
662677
}
678+
663679
if (bias.defined()) {
664680
const ideep::tensor b = itensor_from_tensor(bias);
665681
ideep::convolution_transpose_forward::compute(
@@ -687,10 +703,10 @@ Tensor _mkldnn_convolution_transpose(
687703
groups,
688704
op_attr);
689705
}
690-
if (input_t.is_mkldnn()) {
691-
return MKLDNNTensor(y, input_t.options());
692-
} else if (!is_channels_last) {
693-
return mkldnn_to_dense(MKLDNNTensor(y, input_t.options()));
706+
if (input.is_mkldnn()) {
707+
return MKLDNNTensor(y, input.options());
708+
} else if (!use_channels_last) {
709+
return mkldnn_to_dense(MKLDNNTensor(y, input.options()));
694710
} else {
695711
TORCH_INTERNAL_ASSERT(y.get_desc().is_nhwc());
696712
return output;
@@ -710,6 +726,8 @@ Tensor mkldnn_convolution_transpose_pointwise(
710726
torch::List<c10::optional<at::Scalar>> scalars,
711727
c10::optional<c10::string_view> algorithm) {
712728
c10::impl::ExcludeDispatchKeyGuard edkg(c10::autograd_dispatch_keyset);
729+
bool use_channels_last =
730+
weight_t.is_mkldnn() || mkldnn_conv_use_channels_last(input_t, weight_t);
713731
return _mkldnn_convolution_transpose(
714732
input_t,
715733
weight_t,
@@ -719,12 +737,32 @@ Tensor mkldnn_convolution_transpose_pointwise(
719737
stride,
720738
dilation,
721739
groups,
740+
use_channels_last,
722741
attr,
723742
scalars,
724743
algorithm
725744
);
726745
}
727746

747+
Tensor mkldnn_convolution_transpose_pointwise_meta(
748+
const Tensor& input_t,
749+
const Tensor& weight_t,
750+
const c10::optional<Tensor>& bias_opt,
751+
IntArrayRef padding,
752+
IntArrayRef output_padding,
753+
IntArrayRef stride,
754+
IntArrayRef dilation,
755+
int64_t groups,
756+
c10::string_view attr,
757+
torch::List<c10::optional<at::Scalar>> scalars,
758+
c10::optional<c10::string_view> algorithm) {
759+
760+
std::vector<int64_t> weight_IOHW_sizes = _original_deconv_weight_size(weight_t, groups);
761+
auto output_sizes = conv_input_size(input_t.sizes(), weight_IOHW_sizes, padding, output_padding, stride, dilation, groups);
762+
763+
auto output = at::empty(output_sizes, input_t.options());
764+
return output;
765+
}
728766

729767
Tensor mkldnn_convolution_backward_input(
730768
IntArrayRef input_size,
@@ -871,7 +909,16 @@ TORCH_LIBRARY_IMPL(mkldnn, MkldnnCPU, m) {
871909
m.impl(
872910
TORCH_SELECTIVE_NAME("mkldnn::_convolution_pointwise_.binary"),
873911
TORCH_FN(mkldnn_convolution_pointwise_binary_));
912+
m.impl(
913+
TORCH_SELECTIVE_NAME("mkldnn::_convolution_transpose_pointwise"),
914+
TORCH_FN(mkldnn_convolution_transpose_pointwise));
915+
}
916+
917+
TORCH_LIBRARY_IMPL(mkldnn, Meta, m) {
918+
m.impl(
919+
TORCH_SELECTIVE_NAME("mkldnn::_convolution_transpose_pointwise"),
920+
TORCH_FN(mkldnn_convolution_transpose_pointwise_meta));
874921
}
875922
}} // namespace at::native
876923

877-
#endif
924+
#endif

aten/src/ATen/native/mkldnn/MKLDNNConversions.cpp

Lines changed: 99 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -168,6 +168,105 @@ Tensor mkldnn_reorder_conv3d_weight(
168168
return new_with_itensor_mkldnn(std::move(result), optTypeMetaToScalarType(self.options().dtype_opt()), self.options().device_opt());
169169
}
170170

171+
172+
ideep::tensor::desc get_conv_transpose_expected_weights_desc(
173+
const ideep::tensor::dims& weights_dims,
174+
ideep::tensor::data_type w_dtype,
175+
const ideep::tensor::dims& strides,
176+
const ideep::tensor::dims& padding_l,
177+
const ideep::tensor::dims& padding_r,
178+
const ideep::tensor::dims& dilates,
179+
int groups,
180+
bool channels_last,
181+
ideep::algorithm aalgorithm,
182+
ideep::data_type x_dtype,
183+
const ideep::dims& src_dims) {
184+
if (channels_last) {
185+
return ideep::convolution_transpose_forward::expected_weights_desc<true>(
186+
weights_dims,
187+
w_dtype,
188+
strides,
189+
padding_l,
190+
padding_r,
191+
dilates,
192+
groups,
193+
aalgorithm,
194+
ideep::prop_kind::forward,
195+
src_dims);
196+
} else {
197+
return ideep::convolution_transpose_forward::expected_weights_desc<false>(
198+
weights_dims,
199+
w_dtype,
200+
strides,
201+
padding_l,
202+
padding_r,
203+
dilates,
204+
groups,
205+
aalgorithm,
206+
ideep::prop_kind::forward,
207+
src_dims);
208+
}
209+
}
210+
211+
212+
Tensor mkldnn_reorder_conv_transpose2d_weight(
213+
const Tensor& self,
214+
IntArrayRef padding,
215+
IntArrayRef output_padding,
216+
IntArrayRef stride,
217+
IntArrayRef dilation,
218+
int64_t groups,
219+
c10::OptionalArrayRef<int64_t> input_size) {
220+
c10::impl::ExcludeDispatchKeyGuard edkg(c10::autograd_dispatch_keyset);
221+
if (self.scalar_type() == ScalarType::BFloat16) {
222+
TORCH_CHECK(mkldnn_bf16_device_check(),
223+
"mkldnn_reorder_conv2d_weight: bf16 path needs the cpu support avx512bw, avx512vl and avx512dq");
224+
}
225+
226+
ideep::tensor w = itensor_from_tensor(self);
227+
228+
ideep::dims src_dims = ideep::dims();
229+
bool is_channels_last = false;
230+
if (input_size.has_value()) {
231+
src_dims = input_size.value().vec();
232+
// if has input size, we always use channels last.
233+
is_channels_last = true;
234+
}
235+
236+
auto expected_desc = get_conv_transpose_expected_weights_desc(
237+
w.get_dims(),
238+
w.get_data_type(),
239+
stride.vec(),
240+
padding.vec(),
241+
padding_r(padding, output_padding),
242+
dilation.vec(),
243+
groups,
244+
is_channels_last,
245+
ideep::algorithm::deconvolution_direct,
246+
w.get_data_type(),
247+
src_dims);
248+
249+
if (groups > 1) {
250+
expected_desc = expected_desc.transpose(1, 2);
251+
} else {
252+
expected_desc = expected_desc.transpose(0, 1);
253+
}
254+
255+
ideep::tensor result;
256+
result.init(expected_desc);
257+
w.transpose_(0, 1);
258+
result.feed_from(w, /*is_deconv_weights*/true);
259+
260+
return new_with_itensor_mkldnn(std::move(result), optTypeMetaToScalarType(self.options().dtype_opt()),
261+
self.options().device_opt());
262+
}
263+
264+
TORCH_LIBRARY_IMPL(mkldnn, MkldnnCPU, m) {
265+
m.impl(
266+
TORCH_SELECTIVE_NAME("mkldnn::_reorder_convolution_transpose_weight"),
267+
TORCH_FN(mkldnn_reorder_conv_transpose2d_weight));
268+
}
269+
171270
#else
172271

173272
Tensor mkldnn_to_dense(const Tensor& mkldnn_tensor, c10::optional<ScalarType> dtype) {

aten/src/ATen/native/mkldnn/RegisterMkldnnOpContextClass.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,8 @@ TORCH_LIBRARY(mkldnn, m) {
4444
"mkldnn::_convolution_pointwise_.binary(Tensor X, Tensor(a!) other, Tensor W, Tensor? B, int[] padding, int[] stride, int[] dilation, int groups, str binary_attr, Scalar? alpha, str? unary_attr, Scalar?[] unary_scalars, str? unary_algorithm) -> Tensor(a!) Y"));
4545
m.def(TORCH_SELECTIVE_SCHEMA(
4646
"mkldnn::_convolution_transpose_pointwise(Tensor X, Tensor W, Tensor? B, int[] padding, int[] output_padding, int[] stride, int[] dilation, int groups, str attr, Scalar?[] scalars, str? algorithm) -> Tensor Y"));
47+
m.def(TORCH_SELECTIVE_SCHEMA(
48+
"mkldnn::_reorder_convolution_transpose_weight(Tensor self, int[2] padding=0, int[2] output_padding=0, int[2] stride=1, int[2] dilation=1, int groups=1, int[]? input_size=None) -> Tensor Y"));
4749
}
4850

4951
TORCH_LIBRARY(mkldnn_prepacked, m) {

aten/src/ATen/native/mkldnn/Utils.h

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,27 @@ void check_mkldnn_binary_fusion_inputs(
3333
const Tensor& weight,
3434
const Tensor& bias);
3535

36+
static inline std::vector<int64_t> padding_r(
37+
IntArrayRef padding, IntArrayRef output_padding)
38+
{
39+
// ConvTranpose padding adjustment
40+
//
41+
// PyTorch uses padding/output_padding:
42+
// osize = (isize - 1) * stride - 2 * padding + dilation * (kernel_size - 1) + output_padding + 1
43+
//
44+
// MKLDNN uses padding_l/padding_r:
45+
// osize = (isize - 1) * stride - padding_l - padding_r + dilation * (kernel_size - 1) + 1
46+
//
47+
// So: padding_l = padding, padding_r = padding - output_padding
48+
//
49+
auto dim = padding.size();
50+
std::vector<int64_t> pad_r(dim);
51+
for (const auto d : c10::irange(dim)) {
52+
pad_r[d] = padding[d] - output_padding[d];
53+
}
54+
return pad_r;
55+
}
56+
3657
#if AT_MKLDNN_ENABLED()
3758

3859
using AttrFunction = std::function<ideep::attr_t(

test/test_mkldnn_fusion.py

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -350,8 +350,8 @@ def forward(self, x):
350350
for pointwise_name, pointwise_info in self._unary_list().items():
351351
for dim in [2]:
352352
channels_last = torch.channels_last if dim == 2 else torch.channels_last_3d
353-
options = itertools.product([True, False], [1, 2], [1, 4], [torch.contiguous_format, channels_last])
354-
for bias, dilation, groups, memory_format in options:
353+
options = itertools.product([True, False], [1, 2], [1, 4], [torch.contiguous_format, channels_last], [False, True])
354+
for bias, dilation, groups, memory_format, prepack_weight in options:
355355
oC = 32 * groups
356356
iC = 3 * groups
357357
x_shape = (1, iC) + input_shapes[dim]
@@ -363,6 +363,21 @@ def forward(self, x):
363363
attr = pointwise_info.attr
364364
scalars = pointwise_info.scalars
365365
algorithm = pointwise_info.algorithm
366+
367+
if prepack_weight:
368+
packed_weight = torch.ops.mkldnn._reorder_convolution_transpose_weight(
369+
mod.conv_transpose.weight.to_mkldnn(),
370+
mod.conv_transpose.padding,
371+
mod.conv_transpose.output_padding,
372+
mod.conv_transpose.stride,
373+
mod.conv_transpose.dilation,
374+
mod.conv_transpose.groups,
375+
x.size())
376+
mod.conv_transpose.weight = torch.nn.Parameter(
377+
packed_weight,
378+
requires_grad=mod.conv_transpose.weight.requires_grad,
379+
)
380+
366381
fused = torch.ops.mkldnn._convolution_transpose_pointwise(
367382
x,
368383
mod.conv_transpose.weight,

0 commit comments

Comments
 (0)
0