MPS ConvTranspose 3D Support · Issue #130256 · pytorch/pytorch · GitHub
MPS ConvTranspose 3D Support #130256
@valosekj

Description

🚀 The feature, motivation and pitch

torch.nn.functional.conv_transpose3d is not currently supported on MPS (Apple Silicon):

import torch
import torch.nn.functional as F

# Create a random input tensor on MPS
device = torch.device("mps")
input_tensor = torch.randn(1, 1, 8, 8, 8, device=device)

# Define a random weight tensor
weight = torch.randn(1, 1, 3, 3, 3, device=device)

# Perform the 3D transposed convolution
output = F.conv_transpose3d(input_tensor, weight)

On PyTorch 2.3.1 this raises:

Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
RuntimeError: ConvTranspose 3D is not supported on MPS
>>> print(torch.__version__)
2.3.1
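Until native support lands, the simplest case can be worked around: stride 1, no padding, groups=1, dilation=1. In that case a transposed convolution equals an ordinary 3D convolution over the input padded by k - 1 on each side. The weight's two channel axes must be swapped and the kernel flipped spatially. The sketch below assumes that F.conv3d itself runs on your MPS build. The helper name conv_transpose3d_via_conv3d is hypothetical and not a PyTorch API. Strided or padded variants would also need zero insertion and cropping, which this sketch does not attempt.

```python
import torch
import torch.nn.functional as F


def conv_transpose3d_via_conv3d(input, weight, bias=None):
    """Hypothetical helper: stride-1 conv_transpose3d expressed with conv3d.

    Only covers stride=1, padding=0, output_padding=0, groups=1, dilation=1.
    `weight` uses the conv_transpose3d layout (C_in, C_out, kD, kH, kW).
    """
    kd, kh, kw = weight.shape[2:]
    # Swap the channel axes to get conv3d's (C_out, C_in, ...) layout and
    # flip the kernel spatially, because conv3d computes a cross-correlation.
    conv_weight = weight.transpose(0, 1).flip((2, 3, 4))
    # "Full" padding of k - 1 per side grows each spatial dim by k - 1,
    # which matches the transposed convolution's output size.
    return F.conv3d(input, conv_weight, bias, padding=(kd - 1, kh - 1, kw - 1))


if __name__ == "__main__":
    # Assumption: conv3d is available on MPS; otherwise this runs on the CPU.
    device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
    x = torch.randn(1, 1, 8, 8, 8, device=device)
    w = torch.randn(1, 1, 3, 3, 3, device=device)
    print(conv_transpose3d_via_conv3d(x, w).shape)
```

Compared with copying tensors to the CPU, this keeps the computation on the GPU, but it only covers the stride-1 case.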

This limits, for example, the medical image segmentation framework nnUNet to torch.device("cpu") on Apple Silicon.
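Another stopgap is to route only this op through the CPU while the rest of a model stays on MPS. This assumes conv_transpose3d is the only unsupported op involved. The helper conv_transpose3d_with_fallback below is hypothetical and not part of PyTorch. It costs two device copies per call, so it will likely be slower than a native kernel. Gradients still flow, because device copies are differentiable.

```python
import torch
import torch.nn.functional as F


def conv_transpose3d_with_fallback(input, weight, bias=None, **kwargs):
    """Hypothetical helper: F.conv_transpose3d with a CPU detour for MPS.

    MPS tensors are copied to the CPU, convolved there, and the result is
    copied back to the input's device. Tensors on any other device go
    straight to F.conv_transpose3d. Extra keyword arguments (stride,
    padding, output_padding, groups, dilation) are passed through unchanged.
    """
    if input.device.type != "mps":
        return F.conv_transpose3d(input, weight, bias, **kwargs)
    cpu_bias = bias.cpu() if bias is not None else None
    out = F.conv_transpose3d(input.cpu(), weight.cpu(), cpu_bias, **kwargs)
    return out.to(input.device)


if __name__ == "__main__":
    # Use MPS when available; the same code also runs unchanged on the CPU.
    device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
    x = torch.randn(1, 1, 8, 8, 8, device=device)
    w = torch.randn(1, 1, 3, 3, 3, device=device)
    y = conv_transpose3d_with_fallback(x, w)
    print(y.shape, y.device)
```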

Alternatives

No response

Additional context

No response

cc @kulinseth @albanD @malfet @DenisVieriu97 @jhavukainen

Metadata

Assignees

No one assigned

    Labels

    enhancement, module: convolution, module: mps, triaged

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests
