8000 Update channel shuffle to return alias instead of self as-is (#99745) · pytorch/pytorch@5ee5afb · GitHub
[go: up one dir, main page]

Skip to content

Commit 5ee5afb

Browse files
soulitzerpytorchmergebot
authored andcommitted
Update channel shuffle to return alias instead of self as-is (#99745)
Partially addresses #99655 Pull Request resolved: #99745 Approved by: https://github.com/albanD
1 parent ab0a821 commit 5ee5afb

File tree

2 files changed

+4
-4
lines changed

2 files changed

+4
-4
lines changed

aten/src/ATen/native/ChanelShuffle.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -47,12 +47,12 @@ Tensor channel_shuffle(const Tensor& self, int64_t groups) {
4747
#if defined(C10_MOBILE) && defined(USE_XNNPACK)
4848
if (self.is_contiguous(MemoryFormat::ChannelsLast) &&
4949
xnnpack::use_channel_shuffle(self, groups)) {
50-
auto output = self.numel() == 0 ? self : xnnpack::channel_shuffle(self, groups);
50+
auto output = self.numel() == 0 ? self.alias() : xnnpack::channel_shuffle(self, groups);
5151
return output;
5252
}
5353
#endif
5454

55-
auto output = self.numel() == 0 ? self : at::native_channel_shuffle(self, groups);
55+
auto output = self.numel() == 0 ? self.alias() : at::native_channel_shuffle(self, groups);
5656
return namedinference::propagate_names_if_nonempty(
5757
output,
5858
self.has_names() ? self.names() : at::ArrayRef<Dimname>{});

test/test_nn.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6414,8 +6414,8 @@ def test_channel_shuffle(self):
64146414
self.assertEqual(y, y_ref)
64156415

64166416

6417-
def test_channel_shuffle_return_self(self):
6418-
# gh-76616: nn.ChannelShuffle will return self with an empty input tensor
6417+
def test_channel_shuffle_return_alias_of_self(self):
6418+
# gh-76616: nn.ChannelShuffle will return alias of self with an empty input tensor
64196419
groups = 3
64206420
input_tensor = torch.rand([0, 9, 4, 4])
64216421
output = torch.nn.ChannelShuffle(groups)(input_tensor)

0 commit comments

Comments
 (0)
0