🚀 The feature, motivation and pitch
`torch.nn.functional.conv_transpose3d` is not currently supported on the MPS backend (Apple Silicon):
```python
import torch
import torch.nn.functional as F

# Create a random input tensor on MPS
device = torch.device("mps")
input_tensor = torch.randn(1, 1, 8, 8, 8, device=device)

# Define a random weight tensor
weight = torch.randn(1, 1, 3, 3, 3, device=device)

# Perform the 3D transposed convolution
output = F.conv_transpose3d(input_tensor, weight)
```
```
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
RuntimeError: ConvTranspose 3D is not supported on MPS
>>> print(torch.__version__)
2.3.1
```
This limits, for example, the medical image segmentation framework nnUNet to running on `torch.device("cpu")`.
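Until the operation is implemented natively, one possible stopgap is to run only this op on the CPU and copy the result back to MPS, so the rest of a model can stay on the GPU. The sketch below is an illustration under stated assumptions, not a PyTorch API: `conv_transpose3d_with_cpu_fallback` is a hypothetical helper, and it assumes the extra host/device copies on every call are an acceptable cost.

```python
import torch
import torch.nn.functional as F


def conv_transpose3d_with_cpu_fallback(input, weight, bias=None, **kwargs):
    """Hypothetical helper (not part of PyTorch).

    Runs F.conv_transpose3d on the CPU when the input lives on MPS, then moves
    the result back to the original device. Keyword arguments such as stride,
    padding, output_padding, groups and dilation are forwarded unchanged.
    """
    if input.device.type != "mps":
        # Non-MPS devices: call the native implementation directly.
        return F.conv_transpose3d(input, weight, bias, **kwargs)
    # Copy all tensors to the CPU, where conv_transpose3d is implemented.
    out = F.conv_transpose3d(
        input.cpu(),
        weight.cpu(),
        None if bias is None else bias.cpu(),
        **kwargs,
    )
    # .cpu() and .to() are differentiable, so autograd still flows through.
    return out.to(input.device)


if __name__ == "__main__":
    # Use MPS when available; otherwise the helper just takes the CPU path.
    device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
    input_tensor = torch.randn(1, 1, 8, 8, 8, device=device)
    weight = torch.randn(1, 1, 3, 3, 3, device=device)
    output = conv_transpose3d_with_cpu_fallback(input_tensor, weight)
    print(output.shape, output.device)
```

The trade-off is a round trip through host memory for each call, so this only makes sense when the op is a small fraction of the model's work; a native MPS kernel would avoid the copies entirely.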
Alternatives
No response
Additional context
No response