sparse tensor support mps initial changes · pytorch/pytorch@b75c8aa

Commit b75c8aa

sparse tensor support mps initial changes
1 parent 029e2b0 commit b75c8aa

File tree

9 files changed: +97 -18 lines changed

aten/src/ATen/native/native_functions.yaml
c10/core/Backend.h
c10/core/DispatchKey.cpp
c10/core/Layout.h
test/test_mps.py
torch/csrc/inductor/aoti_torch/generated/c_shim_mps.h
torch/csrc/utils/tensor_new.cpp
torch/csrc/utils/tensor_types.cpp
torchgen/model.py

aten/src/ATen/native/native_functions.yaml

Lines changed: 17 additions & 17 deletions

@@ -7283,26 +7283,26 @@
 
 - func: _sparse_coo_tensor_with_dims(int sparse_dim, int dense_dim, int[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=False) -> Tensor
   dispatch:
-    SparseCPU, SparseCUDA, SparseMeta, Meta: new_with_dims_sparse
+    SparseCPU, SparseCUDA, SparseMeta, SparseMPS, Meta: new_with_dims_sparse
   autogen: _sparse_coo_tensor_with_dims.out
 
 - func: _sparse_coo_tensor_with_dims_and_tensors(int sparse_dim, int dense_dim, SymInt[] size, Tensor indices, Tensor values, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=False, bool? is_coalesced=None) -> Tensor
   dispatch:
-    SparseCPU, SparseCUDA, SparseMeta, Meta: new_with_dims_and_tensor_sparse_symint
+    SparseCPU, SparseCUDA, SparseMeta, SparseMPS, Meta: new_with_dims_and_tensor_sparse_symint
   autogen: _sparse_coo_tensor_with_dims_and_tensors.out
 
 - func: sparse_resize_(Tensor(a!) self, int[] size, int sparse_dim, int dense_dim) -> Tensor(a!)
   use_const_ref_for_mutable_tensors: True
   variants: method
   dispatch:
-    SparseCPU, SparseCUDA, SparseMeta: sparse_resize_
+    SparseCPU, SparseCUDA, SparseMPS, SparseMeta: sparse_resize_
   autogen: sparse_resize, sparse_resize.out
 
 - func: sparse_resize_and_clear_(Tensor(a!) self, int[] size, int sparse_dim, int dense_dim) -> Tensor(a!)
   use_const_ref_for_mutable_tensors: True
   variants: method
   dispatch:
-    SparseCPU, SparseCUDA, SparseMeta: sparse_resize_and_clear_
+    SparseCPU, SparseCUDA, SparseMPS, SparseMeta: sparse_resize_and_clear_
   autogen: sparse_resize_and_clear, sparse_resize_and_clear.out
 
 - func: sparse_mask(Tensor self, Tensor mask) -> Tensor

@@ -7338,8 +7338,8 @@
 - func: sparse_dim(Tensor self) -> int
   variants: method
   dispatch:
-    SparseCPU, SparseCUDA, SparseMeta: sparse_dim_sparse
-    SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: sparse_dim_sparse_csr
+    SparseCPU, SparseCUDA, SparseMPS, SparseMeta: sparse_dim_sparse
+    SparseCsrCPU, SparseCsrCUDA, SparseCsrMPS, SparseCsrMeta: sparse_dim_sparse_csr
   CompositeExplicitAutograd: sparse_dim_default
   device_check: NoCheck
   device_guard: False

@@ -7372,8 +7372,8 @@
 - func: _nnz(Tensor self) -> int
   variants: method
   dispatch:
-    SparseCPU, SparseCUDA, SparseMeta: _nnz_sparse
-    SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: _nnz_sparse_csr
+    SparseCPU, SparseCUDA, SparseMPS, SparseMeta: _nnz_sparse
+    SparseCsrCPU, SparseCsrCUDA, SparseCsrMPS, SparseCsrMeta: _nnz_sparse_csr
   device_check: NoCheck
   device_guard: False

@@ -7394,22 +7394,22 @@
 - func: is_coalesced(Tensor self) -> bool
   variants: method
   dispatch:
-    SparseCPU, SparseCUDA, SparseMeta: is_coalesced_sparse
+    SparseCPU, SparseCUDA, SparseMPS, SparseMeta: is_coalesced_sparse
   CompositeExplicitAutograd: is_coalesced_default
   device_check: NoCheck
   device_guard: False
 
 - func: _indices(Tensor(a) self) -> Tensor(a)
   variants: method
   dispatch:
-    SparseCPU, SparseCUDA, SparseMeta: _indices_sparse
+    SparseCPU, SparseCUDA, SparseMPS, SparseMeta: _indices_sparse
   device_check: NoCheck
   device_guard: False
 
 - func: _values(Tensor(a) self) -> Tensor(a)
   variants: method
   dispatch:
-    SparseCPU, SparseCUDA, SparseMeta: _values_sparse
+    SparseCPU, SparseCUDA, SparseMPS, SparseMeta: _values_sparse
   device_check: NoCheck
   device_guard: False

@@ -7419,7 +7419,7 @@
 - func: _coalesced_(Tensor(a!) self, bool coalesced) -> Tensor(a!)
   variants: method
   dispatch:
-    SparseCPU, SparseCUDA, SparseMeta: _coalesced_sparse_
+    SparseCPU, SparseCUDA, SparseMPS, SparseMeta: _coalesced_sparse_
   device_check: NoCheck
   device_guard: False
   autogen: _coalesced, _coalesced.out

@@ -7508,9 +7508,9 @@
 - func: _to_sparse.sparse_dim(Tensor self, int sparse_dim) -> Tensor
   variants: method
   dispatch:
-    CPU, CUDA: dense_to_sparse
-    SparseCPU, SparseCUDA: sparse_coo_to_sparse
-    SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: sparse_compressed_to_sparse
+    CPU, CUDA, MPS: dense_to_sparse
+    SparseCPU, SparseCUDA, SparseMPS: sparse_coo_to_sparse
+    SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta, SparseCsrMPS: sparse_compressed_to_sparse
   autogen: _to_sparse.sparse_dim_out
 
 - func: to_sparse(Tensor self, *, Layout? layout=None, int[2]? blocksize=None, int? dense_dim=None) -> Tensor

@@ -7520,8 +7520,8 @@
 - func: _to_sparse(Tensor self, *, Layout? layout=None, int[2]? blocksize=None, int? dense_dim=None) -> Tensor
   variants: method
   dispatch:
-    CPU, CUDA: dense_to_sparse
-    SparseCPU, SparseCUDA: sparse_coo_to_sparse
+    CPU, CUDA, MPS: dense_to_sparse
+    SparseCPU, SparseCUDA, SparseMPS: sparse_coo_to_sparse
   SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: sparse_compressed_to_sparse
   autogen: _to_sparse.out
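Taken together, these dispatch rows route the sparse COO factories, the metadata accessors, and dense-to-sparse conversion to the existing kernels when the tensors involved live on the MPS device. A minimal usage sketch of the operators registered above, assuming an MPS-enabled build that includes this change and a machine where torch.backends.mps.is_available() returns True:

import torch

# Sketch only: assumes a build containing the SparseMPS dispatch entries above.
if torch.backends.mps.is_available():
    indices = torch.tensor([[0, 1], [2, 0]], dtype=torch.int64, device="mps")
    values = torch.tensor([1.0, 2.0], device="mps")

    # Routed via SparseMPS to new_with_dims_and_tensor_sparse_symint.
    s = torch.sparse_coo_tensor(indices, values, (2, 3), device="mps")
    print(s._nnz(), s.sparse_dim(), s.is_coalesced())

    # dense_to_sparse is now also dispatched for plain MPS tensors.
    d = torch.tensor([[1.0, 0.0, 0.0], [0.0, 0.0, 2.0]], device="mps")
    print(d.to_sparse()._indices().cpu())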

c10/core/Backend.h

Lines changed: 17 additions & 0 deletions

@@ -38,6 +38,8 @@ enum class Backend {
   SparseCUDA,
   SparseCsrCPU,
   SparseCsrCUDA,
+  SparseCsrMPS,
+  SparseMPS,
   SparseHIP,
   SparseVE,
   SparseXPU,

@@ -94,6 +96,10 @@ inline Backend dispatchKeyToBackend(DispatchKey t) {
     return Backend::SparseCPU;
   } else if (t == DispatchKey::SparseCUDA) {
     return Backend::SparseCUDA;
+  } else if (t == DispatchKey::SparseMPS) {
+    return Backend::SparseMPS;
+  } else if (t == DispatchKey::SparseCsrMPS) {
+    return Backend::SparseCsrMPS;
   } else if (t == DispatchKey::SparseHIP) {
     return Backend::SparseHIP;
   } else if (t == DispatchKey::SparseVE) {

@@ -172,6 +178,10 @@ inline DispatchKey backendToDispatchKey(Backend b) {
       return DispatchKey::SparseCPU;
     case Backend::SparseCUDA:
       return DispatchKey::SparseCUDA;
+    case Backend::SparseMPS:
+      return DispatchKey::SparseMPS;
+    case Backend::SparseCsrMPS:
+      return DispatchKey::SparseCsrMPS;
     case Backend::SparseHIP:
       return DispatchKey::SparseHIP;
     case Backend::SparseVE:

@@ -227,6 +237,8 @@ inline DeviceType backendToDeviceType(Backend b) {
       return DeviceType::CPU;
     case Backend::CUDA:
     case Backend::SparseCUDA:
+    case Backend::SparseMPS:
+    case Backend::SparseCsrMPS:
     case Backend::QuantizedCUDA:
     case Backend::SparseCsrCUDA:
       return DeviceType::CUDA;

@@ -309,6 +321,10 @@ inline const char* toString(Backend b) {
       return "SparseCPU";
     case Backend::SparseCUDA:
       return "SparseCUDA";
+    case Backend::SparseMPS:
+      return "SparseMPS";
+    case Backend::SparseCsrMPS:
+      return "SparseCsrMPS";
     case Backend::SparseHIP:
       return "SparseHIP";
     case Backend::SparseVE:

@@ -361,6 +377,7 @@ inline bool isSparse(Backend b) {
     case Backend::SparseXPU:
     case Backend::SparseCPU:
     case Backend::SparseCUDA:
+    case Backend::SparseMPS:
     case Backend::SparseHIP:
     case Backend::SparseVE:
     case Backend::SparsePrivateUse1:

c10/core/DispatchKey.cpp

Lines changed: 2 additions & 0 deletions

@@ -354,6 +354,8 @@ c10::DispatchKey parseDispatchKey(const std::string& k) {
 
       {"SparseCPU", c10::DispatchKey::SparseCPU},
       {"SparseCUDA", c10::DispatchKey::SparseCUDA},
+      {"SparseMPS", c10::DispatchKey::SparseMPS},
+      {"SparseCsrMPS", c10::DispatchKey::SparseCsrMPS},
       {"SparseHIP", c10::DispatchKey::SparseHIP},
       {"SparseXPU", c10::DispatchKey::SparseXPU},
       {"SparseVE", c10::DispatchKey::SparseVE},

c10/core/Layout.h

Lines changed: 3 additions & 1 deletion

@@ -32,6 +32,8 @@ inline Layout layout_from_backend(Backend backend) {
   switch (backend) {
     case Backend::SparseCPU:
     case Backend::SparseCUDA:
+    case Backend::SparseMPS:
+    case Backend::SparseCsrMPS:
     case Backend::SparseHIP:
     case Backend::SparseVE:
     case Backend::SparseXPU:

@@ -46,7 +48,7 @@ inline Layout layout_from_backend(Backend backend) {
     case Backend::SparseCsrXPU:
       TORCH_CHECK(
           false,
-          "Cannot map Backend SparseCsr(CPU|CUDA|HIP|VE|XPU) to a unique layout.");
+          "Cannot map Backend SparseCsr(CPU|CUDA|HIP|VE|XPU|MPS) to a unique layout.");
     default:
       return Layout::Strided;
   }
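The Backend, DispatchKey, and Layout additions above are the c10 plumbing that lets a sparse tensor whose storage lives on MPS report consistent metadata from Python. A small sketch of the properties this mapping backs, under the same assumptions as before (MPS-enabled build containing this change):

import torch

# Sketch of Python-visible metadata backed by the Backend/DispatchKey/Layout mapping.
if torch.backends.mps.is_available():
    s = torch.zeros(2, 3, device="mps").to_sparse()
    print(s.layout)     # torch.sparse_coo -- layout_from_backend(Backend::SparseMPS)
    print(s.is_sparse)  # True -- isSparse(Backend::SparseMPS)
    print(s.device)     # device holding the sparse tensor's indices and values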

test/test_mps.py

Lines changed: 52 additions & 0 deletions

@@ -12522,6 +12522,58 @@ def test_metal_capture(self):
                         f"Capture file {capture_dirname} contains only metadata, i.e. {capture_listdir}")
 
 
+
+class TestSparseMPS(TestCaseMPS):
+    def _get_basic_sparse_coo(self, device="mps"):
+        indices = torch.tensor([[0, 1], [2, 0]], dtype=torch.int64, device=device)
+        values = torch.tensor([1, 2], dtype=torch.float32, device=device)
+        size = (2, 3)
+        return torch.sparse_coo_tensor(indices, values, size, device=device)
+
+    def test_sparse_coo_tensor_with_dims(self):
+        indices = torch.zeros((2, 0), dtype=torch.int64, device="mps")
+        values = torch.tensor([], dtype=torch.float32, device="mps")
+        size = (2, 3)
+        t = torch.sparse_coo_tensor(indices, values, size, device="mps")
+        self.assertEqual(t.device.type, "mps")
+        self.assertEqual(t.layout, torch.sparse_coo)
+
+    def test_sparse_coo_tensor_with_dims_and_tensors(self):
+        indices = torch.tensor([[0, 1], [2, 0]], device="mps")
+        values = torch.tensor([1., 2.], device="mps")
+        size = (2, 3)
+        t = torch.sparse_coo_tensor(indices, values, size, device="mps")
+        self.assertEqual(t.device.type, "mps")
+        self.assertEqual(t.layout, torch.sparse_coo)
+        self.assertEqual(t._indices().cpu(), indices.cpu())
+        self.assertEqual(t._values().cpu(), values.cpu())
+
+    def test_nnz(self):
+        t = self._get_basic_sparse_coo()
+        self.assertEqual(t._nnz(), 2)
+
+    def test_sparse_dim(self):
+        t = self._get_basic_sparse_coo()
+        self.assertEqual(t.sparse_dim(), 2)
+
+    def test_to_sparse(self):
+        # Compare the MPS conversion against a CPU reference.
+        t = torch.tensor([[[1., 0], [2., 3.]], [[4., 0], [5., 6.]]], device="mps")
+        x = t.to_sparse()
+        t_cpu = torch.tensor([[[1., 0], [2., 3.]], [[4., 0], [5., 6.]]], device="cpu")
+        x_cpu = t_cpu.to_sparse()
+        self.assertEqual(x.cpu(), x_cpu)
+
+    def test_resize(self):
+        indices = torch.tensor([[0, 1], [2, 0]])
+        values = torch.tensor([3.0, 4.0])
+        size = torch.Size([2, 3])
+        sparse = torch.sparse_coo_tensor(indices, values, size, device="mps")
+        sparse_cpu = torch.sparse_coo_tensor(indices, values, size, device="cpu")
+        sparse = sparse.sparse_resize_(torch.Size([4, 5]), sparse_dim=2, dense_dim=0)
+        sparse_cpu = sparse_cpu.sparse_resize_(torch.Size([4, 5]), sparse_dim=2, dense_dim=0)
+        self.assertEqual(sparse, sparse_cpu)
+
+
 # TODO: Actually instantiate that test for the "mps" device to better reflect what it is doing.
 # This requires mps to be properly registered in the device generic test framework which is not the
 # case right now. We can probably use `allow_mps` introduced in https://github.com/pytorch/pytorch/pull/87342
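On MPS-capable hardware the new test class can presumably be run in isolation with something like python test/test_mps.py -k TestSparseMPS (or an equivalent pytest invocation); the exact command is an assumption, not part of this commit.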

torch/csrc/inductor/aoti_torch/generated/c_shim_mps.h

Lines changed: 1 addition & 0 deletions

@@ -21,6 +21,7 @@ AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps__fused_moving_avg_obs_fq_helper_
 AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps__histogramdd_from_bin_cts(AtenTensorHandle self, const int64_t* bins, int64_t bins_len_, const double** range, int64_t range_len_, AtenTensorHandle* weight, int32_t density, AtenTensorHandle* ret0);
 AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps__scaled_dot_product_fused_attention_overrideable(AtenTensorHandle query, AtenTensorHandle key, AtenTensorHandle value, AtenTensorHandle* attn_bias, double dropout_p, int32_t is_causal, int32_t return_debug_mask, double* scale, AtenTensorHandle* ret0, AtenTensorHandle* ret1, AtenTensorHandle* ret2, AtenTensorHandle* ret3, int64_t* ret4, int64_t* ret5, AtenTensorHandle* ret6, AtenTensorHandle* ret7, AtenTensorHandle* ret8);
 AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps__scaled_dot_product_fused_attention_overrideable_backward(AtenTensorHandle grad_out, AtenTensorHandle query, AtenTensorHandle key, AtenTensorHandle value, AtenTensorHandle attn_bias, const int32_t* grad_input_mask, int64_t grad_input_mask_len_, AtenTensorHandle out, AtenTensorHandle logsumexp, AtenTensorHandle cum_seq_q, AtenTensorHandle cum_seq_k, int64_t max_q, int64_t max_k, double dropout_p, int32_t is_causal, AtenTensorHandle philox_seed, AtenTensorHandle philox_offset, double* scale, AtenTensorHandle* ret0, AtenTensorHandle* ret1, AtenTensorHandle* ret2, AtenTensorHandle* ret3);
+AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps__to_sparse(AtenTensorHandle self, int32_t* layout, const int64_t** blocksize, int64_t blocksize_len_, int64_t* dense_dim, AtenTensorHandle* ret0);
 AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps__trilinear(AtenTensorHandle i1, AtenTensorHandle i2, AtenTensorHandle i3, const int64_t* expand1, int64_t expand1_len_, const int64_t* expand2, int64_t expand2_len_, const int64_t* expand3, int64_t expand3_len_, const int64_t* sumdim, int64_t sumdim_len_, int64_t unroll_dim, AtenTensorHandle* ret0);
 AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps__weight_int4pack_mm(AtenTensorHandle self, AtenTensorHandle mat2, int64_t qGroupSize, AtenTensorHandle qScaleAndZeros, AtenTensorHandle* ret0);
 AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps__weight_int8pack_mm(AtenTensorHandle self, AtenTensorHandle mat2, AtenTensorHandle scales, AtenTensorHandle* ret0);

torch/csrc/utils/tensor_new.cpp

Lines changed: 1 addition & 0 deletions

@@ -556,6 +556,7 @@ void check_base_legacy_new(
           c10::DispatchKey::SparseCUDA,
           c10::DispatchKey::SparseHIP,
           c10::DispatchKey::SparseXPU,
+          c10::DispatchKey::SparseMPS,
           c10::DispatchKey::SparsePrivateUse1,
       });
   TORCH_CHECK(

torch/csrc/utils/tensor_types.cpp

Lines changed: 2 additions & 0 deletions

@@ -39,6 +39,8 @@ const char* backend_to_string(const at::Backend& backend) {
       return "torch.cuda.sparse";
     case at::Backend::SparseXPU:
       return "torch.xpu.sparse";
+    case at::Backend::SparseMPS:
+      return "torch.mps.sparse";
     case at::Backend::QuantizedCPU:
       return "torch.quantized";
     case at::Backend::HPU:
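backend_to_string feeds the legacy type() strings, so with this entry a sparse MPS tensor is expected to report a "torch.mps.sparse" prefix. A hedged sketch (the exact suffix depends on dtype; assumes an MPS-enabled build containing this change):

import torch

# Hedged sketch: backend_to_string(Backend::SparseMPS) returns "torch.mps.sparse",
# so the legacy type() string should carry that prefix.
if torch.backends.mps.is_available():
    s = torch.sparse_coo_tensor(
        torch.tensor([[0], [1]], device="mps"),
        torch.tensor([1.0], device="mps"),
        (2, 2),
    )
    print(s.type())  # expected: "torch.mps.sparse.FloatTensor"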

torchgen/model.py

Lines changed: 2 additions & 0 deletions

@@ -288,6 +288,8 @@ def codegen_per_backend_entries() -> str:
     DispatchKey.SparseCsrXPU,
     DispatchKey.SparseCUDA,
     DispatchKey.SparseCsrCUDA,
+    DispatchKey.SparseMPS,
+    DispatchKey.SparseCsrMPS,
     DispatchKey.QuantizedCPU,
     DispatchKey.QuantizedCUDA,
     DispatchKey.CompositeImplicitAutograd,
