8000 [MPS] Convert `channels_last_3d` to `contiguous` for input tensor in … · pytorch/pytorch@90f19fe · GitHub
[go: up one dir, main page]

Skip to content

Commit 90f19fe

Browse files
hvaarapytorchmergebot
authored andcommitted
[MPS] Convert channels_last_3d to contiguous for input tensor in nn.Conv3d (#141780)
When the input tensor to Conv3d is in the channels_last_3d memory format the Conv3d op will generate incorrect output (see example image in #141471). This PR checks if the op is 3d, and then attempts to convert the input tensor to contiguous. Added a regression test that verifies the output by running the same op on the CPU. I'm unsure if Conv3d supports the channels last memory format after #128393. If it does, we should consider updating the logic to utilize this as it would be more efficient. Perhaps @DenisVieriu97 knows or has more context? Fixes #141471 Pull Request resolved: #141780 Approved by: https://github.com/malfet
1 parent 5deca07 commit 90f19fe

File tree

2 files changed

+15
-2
lines changed

2 files changed

+15
-2
lines changed

aten/src/ATen/native/mps/operations/Convolution.mm

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -127,13 +127,13 @@ static Tensor _mps_convolution_impl(const Tensor& input_t_,
127127
const bool is_macOS_13_2_or_newer = is_macos_13_or_newer(MacOSVersion::MACOS_VER_13_2_PLUS);
128128
const bool is_macOS_15_0_or_newer = is_macos_13_or_newer(MacOSVersion::MACOS_VER_15_0_PLUS);
129129
Tensor input_t = input_t_;
130-
if (!is_macOS_15_0_or_newer) {
130+
bool is3DConv = input_t.dim() == 5;
131+
if (!is_macOS_15_0_or_newer || is3DConv) {
131132
input_t = input_t.contiguous();
132133
}
133134

134135
TORCH_CHECK(((input_t.dim() < 5) || is_macOS_13_2_or_newer),
135136
"Conv3D is only supported on MPS for MacOS_13_2 or newer");
136-
bool is3DConv = input_t.dim() == 5;
137137

138138
TORCH_CHECK(isFloatingType(input_t.scalar_type()), "Convolution is supported only for Floating types");
139139

test/test_mps.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9058,6 +9058,19 @@ def test_conv3d_backward_collision(self):
90589058
# This used to crash with MPSNDArrayConvolutionA14.mm:4352: failed assertion
90599059
y2.sum().backward()
90609060

9061+
# Regression test for https://github.com/pytorch/pytorch/issues/141471
9062+
def test_conv3d_channels_last_3d(self):
9063+
m_cpu = nn.Conv3d(16, 33, (3, 5, 2), stride=(2, 1, 1), padding=(4, 2, 0), device="cpu")
9064+
m_mps = copy.deepcopy(m_cpu).to("mps")
9065+
9066+
x_cpu = torch.randn(20, 16, 10, 50, 100, device="cpu").to(memory_format=torch.channels_last_3d)
9067+
x_mps = x_cpu.detach().clone().to("mps")
9068+
9069+
res_cpu = m_cpu(x_cpu)
9070+
res_mps = m_mps(x_mps)
9071+
9072+
self.assertEqual(res_cpu, res_mps)
9073+
90619074
def test_gemm_permute_transpose(self):
90629075
batch_size = 32
90639076
n = 20

0 commit comments

Comments
 (0)
0