MPS backend silently ignores dimension mismatch · Issue #153378 · pytorch/pytorch
@iandanforth

Description

🐛 Describe the bug

When training ResNet18 on CIFAR10, a model whose output layer has too few logits (here `num_classes=5`, while CIFAR10 labels range from 0 to 9) should make the loss raise an error. The CPU backend raises as expected, but the MPS backend runs to completion without any error.

```python
import torch

from torch import (
    nn,
    optim
)
from torchvision import (
    datasets,
    transforms,
    models
)

def main():
    # Get data
    train_ds = datasets.CIFAR10(
        root="./data",
        train=True,
        download=True,
        transform=transforms.ToTensor()
    )

    train_loader = torch.utils.data.DataLoader(
        train_ds,
        batch_size=128,
        shuffle=True
    )

    # device = torch.device("cpu")  # Throws error correctly
    device = torch.device("mps")  # Runs to completion
    # CIFAR10 has 10 classes, so num_classes=5 is too small for this dataset
    model = models.resnet18(num_classes=5).to(device)

    # Loss and optimizer
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)

    batch_count = len(train_loader)
    for epoch in range(2):
        running_loss = 0.0
        count = 0
        for batch_x, batch_y in train_loader:
            batch_x, batch_y = batch_x.to(device), batch_y.to(device)

            optimizer.zero_grad()
            out = model(batch_x)
            loss = criterion(out, batch_y)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            count += 1
            print(f"{count}/{batch_count}")

        print(f"Epoch {epoch}: {running_loss/count}")


if __name__ == "__main__":
    main()
```
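A smaller reproduction that skips the CIFAR10 download may help isolate the problem to the loss call itself. This is a sketch under assumptions: the helper name `cross_entropy_raises` is hypothetical, and it assumes the missing check lives in `cross_entropy` (which the CPU traceback below points to) rather than elsewhere in the training loop. It catches both `IndexError` and `RuntimeError` in case a backend reports the failure differently. On CPU the call should raise; if the bug reproduces in isolation, the MPS call would return `False`.

```python
import torch
import torch.nn.functional as F


def cross_entropy_raises(device: str) -> bool:
    """Return True if cross_entropy rejects an out-of-range target on `device` (hypothetical helper)."""
    logits = torch.randn(4, 5, device=device)           # 5 classes: valid targets are 0..4
    target = torch.tensor([0, 1, 7, 2], device=device)  # 7 is out of range for 5 classes
    try:
        loss = F.cross_entropy(logits, target)
        loss.item()  # copy to host so any deferred device-side error surfaces here
    except (IndexError, RuntimeError):
        return True
    return False


if __name__ == "__main__":
    print("cpu raises:", cross_entropy_raises("cpu"))
    if torch.backends.mps.is_available():
        print("mps raises:", cross_entropy_raises("mps"))
```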

Expected Behavior

```
  File "/opt/miniconda3/envs/elicit/lib/python3.11/site-packages/torch/nn/functional.py", line 3494, in cross_entropy
    return torch._C._nn.cross_entropy_loss(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
IndexError: Target 7 is out of bounds.
```

Or a similar error as soon as the loss encounters the first target label >= 5, since a 5-class model only accepts labels 0 to 4.
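Until MPS performs this check itself, one possible workaround (an assumption on my part, not an official fix) is to validate labels explicitly before calling the loss. The sketch below uses a hypothetical helper, `checked_cross_entropy`, and assumes integer class-index targets; it does not handle the class-probability (float) target form that `F.cross_entropy` also accepts. It skips entries equal to `ignore_index` (default `-100`), since `F.cross_entropy` skips those too, and reuses the CPU error message.

```python
import torch
import torch.nn.functional as F


def checked_cross_entropy(logits: torch.Tensor, target: torch.Tensor,
                          ignore_index: int = -100) -> torch.Tensor:
    """F.cross_entropy with an explicit class-index range check (hypothetical helper).

    Assumes `target` holds integer class indices and `logits` has classes on dim 1.
    """
    num_classes = logits.shape[1]
    # Entries equal to ignore_index are skipped by F.cross_entropy, so skip them here too.
    labels = target[target != ignore_index]
    bad = labels[(labels < 0) | (labels >= num_classes)]
    if bad.numel() > 0:
        # Mirror the CPU error message; .item() copies the offending label to the host.
        raise IndexError(f"Target {bad[0].item()} is out of bounds.")
    return F.cross_entropy(logits, target, ignore_index=ignore_index)
```

In the training loop from the report, `loss = checked_cross_entropy(out, batch_y)` would replace `criterion(out, batch_y)`. Boolean-mask indexing has to learn how many elements matched, which likely forces a host/device synchronization on MPS, so expect some per-batch overhead.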

Versions

Collecting environment information...
PyTorch version: 2.7.0
Is debug build: False
CUDA used to build PyTorch: None
ROCM used to build PyTorch: N/A

OS: macOS 15.1 (arm64)
GCC version: Could not collect
Clang version: 16.0.0 (clang-1600.0.26.4)
CMake version: version 3.31.5
Libc version: N/A

Python version: 3.11.11 (main, Dec 11 2024, 10:25:04) [Clang 14.0.6 ] (64-bit runtime)
Python platform: macOS-15.1-arm64-arm-64bit
Is CUDA available: False
CUDA runtime version: No CUDA
CUDA_MODULE_LOADING set to: N/A
GPU models and configuration: No CUDA
Nvidia driver version: No CUDA
cuDNN version: No CUDA
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Apple M4 Pro

Versions of relevant libraries:
[pip3] numpy==2.2.5
[pip3] torch==2.7.0
[pip3] torchvision==0.22.0
[conda] numpy 2.2.5 pypi_0 pypi
[conda] torch 2.7.0 pypi_0 pypi
[conda] torchvision 0.22.0 pypi_0 pypi

cc @kulinseth @albanD @malfet @DenisVieriu97 @jhavukainen

Metadata

Assignees: no one assigned
Labels: module: error checking, module: mps, needs design, triaged
Type: none
Projects: none
Milestone: none
Relationships: none
Development: no branches or pull requests
