[MPS] Fix backward computation for channels_last tensors in MPS backend by NripeshN · Pull Request #143136 · pytorch/pytorch

Closed. NripeshN wants to merge 9 commits.
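
The PR body is not preserved in this capture, but the title and the regression test below (test/test_mps.py) pin down the bug: on the MPS backend, convolution backward could produce different gradients depending on whether the incoming grad_output was contiguous or in a channels_last-style layout (issues #140902 and #142344). A minimal reproduction, sketched from that regression test and assuming a PyTorch build with MPS available:

import torch

# Sketch reconstructed from the regression test in this PR; assumes MPS is
# available. The same gradient values are fed back twice: once contiguous,
# once through a transposed (channels_last-style) view.
ic, oc, ks, f = 2, 5, 3, 7
conv = torch.nn.Conv1d(ic, oc, kernel_size=ks, padding=1).to("mps")
inp = torch.rand(1, ic, f, device="mps", requires_grad=True)
out = conv(inp)

grad_in = torch.rand(1, oc, f, device="mps")
grad_in_cl = torch.empty(1, f, oc, device="mps").transpose(1, 2)
grad_in_cl[:] = grad_in  # same values, non-contiguous strides

g = torch.autograd.grad((out,), (inp, conv.weight, conv.bias), (grad_in,), retain_graph=True)
g_cl = torch.autograd.grad((out,), (inp, conv.weight, conv.bias), (grad_in_cl,), retain_graph=True)

# Before this fix the two layouts could disagree on MPS; they must match.
for a, b in zip(g, g_cl):
    assert torch.allclose(a, b)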
aten/src/ATen/native/mps/operations/Convolution.mm (27 changes: 10 additions & 17 deletions)
@@ -417,36 +417,29 @@ static Tensor mps_convolution_backward_input(IntArrayRef input_size,
assert(0 && "Check should have been done earlier\n");
}

- MPSShape* gradOutputShape = getMPSShape(grad_output_t, memory_format);
MPSShape* mps_input_shape = getMPSShape(input_size);
- NSString* ns_shape_key = [[gradOutputShape valueForKey:@"description"] componentsJoinedByString:@","];
string key;
if (is3DConv) {
key = "mps_3d_convolution_backward_input:" + std::to_string(stride[0]) + ":" + std::to_string(stride[1]) + ":" +
":" + std::to_string(stride[2]) + std::to_string(dilation[0]) + ":" + std::to_string(dilation[1]) + ":" +
std::to_string(dilation[2]) + ":" + std::to_string(padding[0]) + ":" + std::to_string(padding[1]) + ":" +
std::to_string(padding[2]) + ":" + std::to_string(groups) + ":" + mem_format_key +
- getTensorsStringKey({grad_output_t, weight_t}) + ":" + string([ns_shape_key UTF8String]);
+ getTensorsStringKey({grad_output_t, weight_t});

} else {
key = "mps_convolution_backward_input:" + std::to_string(stride[0]) + ":" + std::to_string(stride[1]) + ":" +
std::to_string(dilation[0]) + ":" + std::to_string(dilation[1]) + ":" + std::to_string(padding[0]) + ":" +
std::to_string(padding[1]) + ":" + std::to_string(groups) + ":" + mem_format_key +
- getTensorsStringKey({grad_output_t, weight_t}) + ":" + string([ns_shape_key UTF8String]);
+ getTensorsStringKey({grad_output_t, weight_t});
}
auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) {
- MPSGraphTensor* gradOutputTensor =
-     mpsGraphRankedPlaceHolder(mpsGraph, getMPSScalarType(grad_output_t), gradOutputShape);
- MPSGraphTensor* weightTensor = mpsGraphRankedPlaceHolder(mpsGraph, weight_t);
+ auto gradOutputTensor = mpsGraphRankedPlaceHolder(mpsGraph, grad_output_t);
+ auto weightTensor = mpsGraphRankedPlaceHolder(mpsGraph, weight_t);

- MPSGraphTensor* gradOutputTensorTranspose = gradOutputTensor;
- if (is_channels_last) {
-   gradOutputTensorTranspose = mps::convertNHWCtoNCHW(mpsGraph, gradOutputTensorTranspose);
- }
MPSGraphTensor* gradInputTensor;
MPSShape* weightOutputShape = mps::getMPSShape(weight_t);
// Depthwise conv is input feature channels = groups. So I in OIHW has to be 1.
- bool isDepthwiseConv = ((groups > 1 && (weightOutputShape[1].intValue == 1)) && gradOutputShape.count >= 4 &&
+ bool isDepthwiseConv = ((groups > 1 && (weightOutputShape[1].intValue == 1)) && grad_output_t.ndimension() >= 4 &&
weightOutputShape.count >= 4 && !is_channels_last);

if (is3DConv) {
@@ -462,7 +455,7 @@ static Tensor mps_convolution_backward_input(IntArrayRef input_size,
padding[1],
padding[0],
groups);
- gradInputTensor = [mpsGraph convolution3DDataGradientWithIncomingGradientTensor:gradOutputTensorTranspose
+ gradInputTensor = [mpsGraph convolution3DDataGradientWithIncomingGradientTensor:gradOutputTensor
weightsTensor:weightTensor
outputShape:mps_input_shape
forwardConvolutionDescriptor:conv3dDescriptor_
@@ -484,7 +477,7 @@ static Tensor mps_convolution_backward_input(IntArrayRef input_size,
withDimension:-4
name:nil];
gradInputTensor =
-     [mpsGraph depthwiseConvolution3DDataGradientWithIncomingGradientTensor:gradOutputTensorTranspose
+     [mpsGraph depthwiseConvolution3DDataGradientWithIncomingGradientTensor:gradOutputTensor
weightsTensor:weightTransposeTensor
outputShape:mps_input_shape
descriptor:depthWiseConv3dDescriptor_
@@ -501,7 +494,7 @@ static Tensor mps_convolution_backward_input(IntArrayRef input_size,
at::MemoryFormat::Contiguous,
groups);

- gradInputTensor = [mpsGraph convolution2DDataGradientWithIncomingGradientTensor:gradOutputTensorTranspose
+ gradInputTensor = [mpsGraph convolution2DDataGradientWithIncomingGradientTensor:gradOutputTensor
weightsTensor:weightTensor
outputShape:mps_input_shape
forwardConvolutionDescriptor:conv2dDescriptor_
@@ -513,7 +506,7 @@ static Tensor mps_convolution_backward_input(IntArrayRef input_size,
newCachedGraph->gradInputTensor_ = gradInputTensor;
});

- auto gradOutputPlaceholder = Placeholder(cachedGraph->gradOutputTensor_, grad_output_t, gradOutputShape);
+ auto gradOutputPlaceholder = Placeholder(cachedGraph->gradOutputTensor_, grad_output_t);
auto weightsPlaceholder = Placeholder(cachedGraph->weightTensor_, weight_t);
auto outputPlaceholder = Placeholder(cachedGraph->gradInputTensor_, *grad_input);

@@ -758,4 +751,4 @@ static Tensor mps_convolution_transpose_backward_weight(IntArrayRef weight_size,
return std::tuple<Tensor, Tensor>{grad_input, grad_weight};
}

-} // namespace at::native
\ No newline at end of file
+} // namespace at::native
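
The substance of the change above: the old code built the grad_output placeholder from an explicitly NHWC-shaped MPSShape and, whenever is_channels_last was set, inserted a convertNHWCtoNCHW transpose into the graph; the new code builds the placeholder directly from grad_output_t and lets the Placeholder machinery resolve the memory format, so the gradOutputShape string can also drop out of the graph cache key. The layout distinction itself is purely a matter of strides, as this Python illustration (shapes chosen arbitrarily) shows:

import torch

# Same logical tensor, two physical layouts. The old MPS path special-cased
# the second one with an explicit NHWC->NCHW transpose in the graph.
x = torch.randn(1, 5, 4, 4)                      # NCHW, contiguous
x_cl = x.to(memory_format=torch.channels_last)   # NHWC in memory

print(x.stride())            # (80, 16, 4, 1): row-major NCHW
print(x_cl.stride())         # (80, 1, 20, 5): channels vary fastest
print(torch.equal(x, x_cl))  # True: logical values are unchanged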
test/test_mps.py (8 changes: 5 additions & 3 deletions)
@@ -10593,20 +10593,22 @@ def helper(shape, in_channels=1, out_channels=1, kernel_size=3, groups=1):
helper(shape=(1024, 376, 9), in_channels=9, out_channels=9, groups=3)

# Regression test for https://github.com/pytorch/pytorch/issues/140902
+ # And https://github.com/pytorch/pytorch/issues/142344 (adding grad for input)
ic, oc, ks, f = 2, 5, 3, 7
conv = torch.nn.Conv1d(ic, oc, kernel_size=ks, padding=1).to("mps")
- inp = torch.rand(1, ic, f, device="mps")
+ inp = torch.rand(1, ic, f, device="mps", requires_grad=True)
out = conv(inp)
grad_in = torch.rand(1, oc, f, device="mps")
grad_in_cl = torch.empty(1, f, oc, device="mps").transpose(1, 2)
grad_in_cl[:] = grad_in

# It does not matter whether grad_in is contiguous or channels last; results should equal each other
- grad_rc = torch.autograd.grad((out,), (conv.weight, conv.bias), (grad_in,), retain_graph=True)
- grad_rc_cl = torch.autograd.grad((out,), (conv.weight, conv.bias), (grad_in_cl,), retain_graph=True)
+ grad_rc = torch.autograd.grad((out,), (inp, conv.weight, conv.bias), (grad_in,), retain_graph=True)
+ grad_rc_cl = torch.autograd.grad((out,), (inp, conv.weight, conv.bias), (grad_in_cl,), retain_graph=True)

self.assertEqual(grad_rc[0], grad_rc_cl[0])
self.assertEqual(grad_rc[1], grad_rc_cl[1])
+ self.assertEqual(grad_rc[2], grad_rc_cl[2])

def test_conv1d_contiguous(self):
model_cpu = torch.nn.Conv1d(1, 128, 3)
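
Beyond the layout self-consistency assertions above, a natural additional check, not part of this PR, is to compare MPS gradients against the CPU reference; a sketch under the same MPS-availability assumption:

import torch

# Cross-device sanity check (illustrative, not from this PR): conv backward
# on MPS should agree with the CPU reference within float32 tolerance.
conv_cpu = torch.nn.Conv1d(2, 5, kernel_size=3, padding=1)
conv_mps = torch.nn.Conv1d(2, 5, kernel_size=3, padding=1).to("mps")
conv_mps.load_state_dict(conv_cpu.state_dict())  # copy identical weights

inp_cpu = torch.rand(1, 2, 7, requires_grad=True)
inp_mps = inp_cpu.detach().to("mps").requires_grad_()

conv_cpu(inp_cpu).sum().backward()
conv_mps(inp_mps).sum().backward()

assert torch.allclose(inp_cpu.grad, inp_mps.grad.cpu(), atol=1e-5)
assert torch.allclose(conv_cpu.weight.grad, conv_mps.weight.grad.cpu(), atol=1e-5)
assert torch.allclose(conv_cpu.bias.grad, conv_mps.bias.grad.cpu(), atol=1e-5)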