[MPS] Fix backward computation for channels_last tensors in MPS backend by NripeshN · Pull Request #143136 · pytorch/pytorch · GitHub

Closed
wants to merge 9 commits
Simplify gradOutputTensor initialization in mps_convolution_backward_input by removing scalar type and shape parameters
NripeshN committed Dec 13, 2024
commit d284da24b3ced65fb33a01cb6fb1ea387fded057
2 changes: 1 addition & 1 deletion aten/src/ATen/native/mps/operations/Convolution.mm
@@ -434,7 +434,7 @@ static Tensor mps_convolution_backward_input(IntArrayRef input_size,
       getTensorsStringKey({grad_output_t, weight_t});
     }
     auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) {
-      auto gradOutputTensor = mpsGraphRankedPlaceHolder(mpsGraph, getMPSScalarType(grad_output_t.scalar_type()), mps::getMPSShape(grad_output_t));
+      auto gradOutputTensor = mpsGraphRankedPlaceHolder(mpsGraph, grad_output_t);
       auto weightTensor = mpsGraphRankedPlaceHolder(mpsGraph, weight_t);

       MPSGraphTensor* gradInputTensor;