MPS Sparse Support · Issue #129842 · pytorch/pytorch · GitHub

MPS Sparse Support #129842

@rootjalex

Description

🚀 The feature, motivation and pitch

I would like to be able to use torch.sparse on the MPS (Metal) backend, but constructing even a small CSR tensor on the MPS device currently fails:

import torch

assert torch.backends.mps.is_available()
mps_device = torch.device("mps")

crow_indices = torch.tensor([0, 2, 4], dtype=torch.int32).to(mps_device)  # row offsets (nrows + 1 entries)
col_indices = torch.tensor([0, 1, 0, 1], dtype=torch.int32).to(mps_device)  # column index of each nonzero
values = torch.tensor([1, 2, 3, 4], dtype=torch.float32).to(mps_device)  # the nonzero values
csr = torch.sparse_csr_tensor(crow_indices, col_indices, values, dtype=torch.float32)

with the following error:

% python3 file.py
Traceback (most recent call last):
  File "file.py", line 9, in <module>
    csr = torch.sparse_csr_tensor(crow_indices, col_indices, values, dtype=torch.float32)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
NotImplementedError: Could not run 'new_compressed_tensor' from the 'mps:0' device.)
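For reference, the three arrays in the reproduction describe a small matrix in CSR (compressed sparse row) form: the nonzeros of row i occupy positions crow_indices[i] up to (but not including) crow_indices[i + 1] in col_indices and values. The sketch below decodes them in plain Python to show which dense matrix the failing call is meant to build. `csr_to_dense` is a hypothetical helper written for illustration, not a torch API and not how PyTorch implements the layout; when `ncols` is omitted it infers the smallest column count that fits the indices, which is meant to mirror how torch.sparse_csr_tensor infers a size when none is given.

```python
def csr_to_dense(crow_indices, col_indices, values, ncols=None):
    """Expand CSR arrays into a dense list-of-lists matrix (hypothetical helper)."""
    nrows = len(crow_indices) - 1  # one offset per row, plus a final end offset
    if ncols is None:
        # Smallest width that can hold every column index (0 if there are no nonzeros).
        ncols = max(col_indices, default=-1) + 1
    dense = [[0.0] * ncols for _ in range(nrows)]
    for row in range(nrows):
        # This row's nonzeros live in the half-open slice [start, end).
        start, end = crow_indices[row], crow_indices[row + 1]
        for k in range(start, end):
            dense[row][col_indices[k]] = values[k]
    return dense


# The same arrays as in the reproduction, as plain Python lists.
crow_indices = [0, 2, 4]
col_indices = [0, 1, 0, 1]
values = [1.0, 2.0, 3.0, 4.0]

dense = csr_to_dense(crow_indices, col_indices, values)
print(dense)  # [[1.0, 2.0], [3.0, 4.0]]
```

So the failing call is only trying to build the fully dense 2×2 matrix [[1, 2], [3, 4]], which makes it a minimal reproduction: the error is raised when the CSR tensor is constructed on the MPS device, before any sparse arithmetic runs.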

Alternatives

No response

Additional context

No response

cc @alexsamardzic @nikitaved @pearu @cpuhrsch @amjames @bhosmer @jcaip @kulinseth @albanD @malfet @DenisVieriu97 @jhavukainen

Metadata

Assignees

No one assigned

Labels

enhancement: Not as big of a feature, but technically not a bug. Should be easy to fix
module: mps: Related to Apple Metal Performance Shaders framework
module: sparse: Related to torch.sparse
triaged: This issue has been looked at by a team member, and triaged and prioritized into an appropriate module

Type

No type

Projects

No projects

Milestone

No milestone

Relationships

None yet

Development

No branches or pull requests