Add pinned memory support to sparse COO/CSR/CSC/BSR/BSC tensors by pearu · Pull Request #129645 · pytorch/pytorch

Closed · wants to merge 14 commits

Changes from 1 commit

Update on "Add pinned memory support to sparse COO/CSR/CSC/BSR/BSC te…
…nsors"

As in the title:

To pin the indices/values of a sparse XYZ tensor (i.e. register their memory with CUDA for fast host-to-device transfers), the following methods are supported (see the usage sketch after this list):
- `sparse_xyz_tensor(indices, values, pin_memory=True)`
- `sparse_xyz_tensor(indices, values).pin_memory()`
- `sparse_xyz_tensor(indices.pin_memory(), values.pin_memory())`
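
Illustrative usage sketch (not part of this diff): the three entry points above, with `sparse_csr_tensor` standing in for any supported layout; pinning requires a CUDA-enabled build.

```python
import torch

# Hedged sketch of the three supported ways to get a pinned sparse tensor.
crow_indices = torch.tensor([0, 2, 3])
col_indices = torch.tensor([0, 1, 2])
values = torch.tensor([1.0, 2.0, 3.0])

# (1) request pinning at construction time
a = torch.sparse_csr_tensor(crow_indices, col_indices, values, pin_memory=True)

# (2) pin an existing sparse tensor (clones into pinned memory)
b = torch.sparse_csr_tensor(crow_indices, col_indices, values).pin_memory()

# (3) construct from already-pinned member tensors
c = torch.sparse_csr_tensor(
    crow_indices.pin_memory(), col_indices.pin_memory(), values.pin_memory())

assert a.is_pinned() and b.is_pinned() and c.is_pinned()
```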

Fixes #115330




cc alexsamardzic nikitaved cpuhrsch amjames bhosmer jcaip

[ghstack-poisoned]
pearu committed Aug 1, 2024
commit e91a2b33e9902c0b3122b9006eb268b824469cdd
19 changes: 9 additions & 10 deletions aten/src/ATen/native/native_functions.yaml
@@ -4545,11 +4545,11 @@
 - func: is_pinned(Tensor self, Device? device=None) -> bool
   variants: method
   dispatch:
-    NestedTensorCUDA, CUDA: is_pinned_cuda
-    MPS: is_pinned_mps
-    SparseCsrCUDA: is_pinned_sparse_compressed
-    SparseCUDA: is_pinned_sparse_coo
-    CompositeExplicitAutograd: is_pinned_default
+    # the NestedTensor keys are necessary because NestedTensor has been removed
+    # from the CompositeExplicitAutograd keyset see Note [NestedTensor Not Included in Backend Keys]
+    CompositeExplicitAutograd, NestedTensorCPU: is_pinned
+    SparseCsrCPU: is_pinned_sparse_compressed
+    SparseCPU: is_pinned_sparse_coo
 
 # TODO: add a copy kwarg that guarantees that the tensor is put into fresh
 # pinned memory
@@ -4559,11 +4559,10 @@
 # Unlike pin_memory, this is guaranteed to give a new non-aliasing tensor
 - func: _pin_memory(Tensor self, Device? device=None) -> Tensor
   dispatch:
-    CUDA: _pin_memory_cuda
-    MPS: _pin_memory_mps
-    NestedTensorCUDA, NestedTensorCPU: _pin_memory_nested
-    SparseCUDA: _pin_memory_sparse_coo
-    SparseCsrCUDA, SparseCsrCPU: _pin_memory_sparse_compressed
+    CompositeExplicitAutograd: _pin_memory
+    NestedTensorCPU: _pin_memory_nested
+    SparseCPU: _pin_memory_sparse_coo
+    SparseCsrCPU: _pin_memory_sparse_compressed
   autogen: _pin_memory.out
 
 - func: pinverse(Tensor self, float rcond=1e-15) -> Tensor
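
Net effect of the yaml change: the per-device `is_pinned`/`_pin_memory` kernels (CUDA, MPS, sparse-CUDA) collapse into generic `CompositeExplicitAutograd` entries, with sparse and nested overrides kept only for the CPU keys. A hedged sanity check of the resulting uniform behavior across layouts (assumes an accelerator-enabled build):

```python
import torch

# Hedged sketch: pinning behaves uniformly across sparse layouts on CPU.
if torch.cuda.is_available():
    dense = torch.eye(3, dtype=torch.float64)
    for to_sparse in (dense.to_sparse_coo, dense.to_sparse_csr, dense.to_sparse_csc):
        s = to_sparse()
        assert not s.is_pinned()           # freshly constructed: unpinned
        assert s.pin_memory().is_pinned()  # pin_memory() clones into pinned memory
```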
19 changes: 0 additions & 19 deletions aten/src/ATen/native/sparse/SparseUnaryOps.cpp
@@ -282,23 +282,4 @@ Tensor& nan_to_num_sparse_(
   return nan_to_num_sparse_out(self, nan, posinf, neginf, self);
 }
 
-bool is_pinned_sparse(const Tensor& self, std::optional<c10::Device> device) {
-  if (device.has_value()) {
-    TORCH_WARN_DEPRECATION(
-        "The argument 'device' of Tensor.is_pinned() ",
-        "is deprecated. Please do not pass this argument.")
-  }
-  // Currently, we don't support pin memory for sparse tensor.
-  // so always return false
-  return false;
-}
-
-Tensor _pin_memory_sparse(const Tensor& self, std::optional<c10::Device> device) {
-  // Here, we throw an error rather than return self tensor. This
-  // is because we always return the pinned memory tensor, while
-  // giving unpinned tensor might mislead users.
-  TORCH_CHECK_NOT_IMPLEMENTED(
-      false, "'aten::_pin_memory' is not implemented for sparse tensor.");
-}
-
 } // namespace at::native
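
For context, the removed stubs gave sparse tensors a hard-coded answer: `is_pinned()` always returned `False` and `aten::_pin_memory` raised a not-implemented error. A hedged before/after sketch (assumes a CUDA-enabled build):

```python
import torch

coo = torch.eye(3).to_sparse()

# Before this PR (behavior of the stubs removed above):
#   coo.is_pinned()   -> always False
#   coo.pin_memory()  -> "'aten::_pin_memory' is not implemented for sparse tensor."
# After this PR, pinning a sparse tensor pins its member tensors:
pinned = coo.pin_memory()
assert pinned.is_pinned()
assert pinned._indices().is_pinned() and pinned._values().is_pinned()
```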
64 changes: 31 additions & 33 deletions test/test_sparse.py
@@ -5372,48 +5372,46 @@ def test_constructor_pin_memory(self, device, layout):
     @onlyCPU
     @all_sparse_layouts('layout', include_strided=True)
     def test_method_pin_memory(self, device, layout):
-        """Tests sparse_xyz_tensor(indices, values, pin_memory=False).pin_memory(device)
+        """Tests sparse_xyz_tensor(indices, values, pin_memory=False).pin_memory()
         """
-        # is_pinned() ignores cuda device id, so no point in specifying it here:
-        pin_memory_device = "cuda"
 
         for t_ in self.generate_simple_inputs(
                 layout, device=device, dtype=torch.float64,
                 enable_zero_sized=False,  # pinning zero-sized tensors is a no-op
                 pin_memory=False,  # no pinning
                 enable_batch=False,  # TODO: remove after gh-104868 is resolved
         ):
-            t = t_.pin_memory(pin_memory_device)
-            self.assertTrue(t.is_pinned(pin_memory_device))
+            t = t_.pin_memory()
+            self.assertTrue(t.is_pinned())
 
             # registering a non-pinned tensor with CUDA memory is a
             # clone operation
-            self.assertFalse(t_.is_pinned(pin_memory_device))
+            self.assertFalse(t_.is_pinned())
 
             # registering already pinned tensor with CUDA memory is an
             # identity operation:
-            t2 = t.pin_memory(pin_memory_device)
+            t2 = t.pin_memory()
             self.assertTrue(t2 is t)
 
             if layout is torch.sparse_coo:
-                self.assertTrue(t._indices().is_pinned(pin_memory_device))
-                self.assertTrue(t._values().is_pinned(pin_memory_device))
-                self.assertFalse(t_._indices().is_pinned(pin_memory_device))
-                self.assertFalse(t_._values().is_pinned(pin_memory_device))
+                self.assertTrue(t._indices().is_pinned())
+                self.assertTrue(t._values().is_pinned())
+                self.assertFalse(t_._indices().is_pinned())
+                self.assertFalse(t_._values().is_pinned())
             elif layout in {torch.sparse_csr, torch.sparse_bsr}:
-                self.assertTrue(t.crow_indices().is_pinned(pin_memory_device))
-                self.assertTrue(t.col_indices().is_pinned(pin_memory_device))
-                self.assertTrue(t.values().is_pinned(pin_memory_device))
-                self.assertFalse(t_.crow_indices().is_pinned(pin_memory_device))
-                self.assertFalse(t_.col_indices().is_pinned(pin_memory_device))
-                self.assertFalse(t_.values().is_pinned(pin_memory_device))
+                self.assertTrue(t.crow_indices().is_pinned())
+                self.assertTrue(t.col_indices().is_pinned())
+                self.assertTrue(t.values().is_pinned())
+                self.assertFalse(t_.crow_indices().is_pinned())
+                self.assertFalse(t_.col_indices().is_pinned())
+                self.assertFalse(t_.values().is_pinned())
             elif layout in {torch.sparse_csc, torch.sparse_bsc}:
-                self.assertTrue(t.ccol_indices().is_pinned(pin_memory_device))
-                self.assertTrue(t.row_indices().is_pinned(pin_memory_device))
-                self.assertTrue(t.values().is_pinned(pin_memory_device))
-                self.assertFalse(t_.ccol_indices().is_pinned(pin_memory_device))
-                self.assertFalse(t_.row_indices().is_pinned(pin_memory_device))
-                self.assertFalse(t_.values().is_pinned(pin_memory_device))
+                self.assertTrue(t.ccol_indices().is_pinned())
+                self.assertTrue(t.row_indices().is_pinned())
+                self.assertTrue(t.values().is_pinned())
+                self.assertFalse(t_.ccol_indices().is_pinned())
+                self.assertFalse(t_.row_indices().is_pinned())
+                self.assertFalse(t_.values().is_pinned())
             elif layout is torch.strided:
                 pass
             else:
@@ -5431,25 +5429,25 @@ def test_constructor_pinned_memory(self, device, layout):
                 layout, device=device, dtype=torch.float64,
                 enable_zero_sized=False,  # pinning zero-sized tensors is a no-op
                 pin_memory=None,  # constructor does not specify pin_memory=...
-                pin_memory_device=pin_memory_device,  # indices and values are pinned to the given device
+                members_pin_memory=True,  # indices and values are pinned
                 enable_batch=False,  # TODO: remove after gh-104868 is resolved
         ):
             if layout is torch.sparse_coo:
-                self.assertTrue(t._indices().is_pinned(pin_memory_device))
-                self.assertTrue(t._values().is_pinned(pin_memory_device))
+                self.assertTrue(t._indices().is_pinned())
+                self.assertTrue(t._values().is_pinned())
             elif layout in {torch.sparse_csr, torch.sparse_bsr}:
-                self.assertTrue(t.crow_indices().is_pinned(pin_memory_device))
-                self.assertTrue(t.col_indices().is_pinned(pin_memory_device))
-                self.assertTrue(t.values().is_pinned(pin_memory_device))
+                self.assertTrue(t.crow_indices().is_pinned())
+                self.assertTrue(t.col_indices().is_pinned())
+                self.assertTrue(t.values().is_pinned())
             elif layout in {torch.sparse_csc, torch.sparse_bsc}:
-                self.assertTrue(t.ccol_indices().is_pinned(pin_memory_device))
-                self.assertTrue(t.row_indices().is_pinned(pin_memory_device))
-                self.assertTrue(t.values().is_pinned(pin_memory_device))
+                self.assertTrue(t.ccol_indices().is_pinned())
+                self.assertTrue(t.row_indices().is_pinned())
+                self.assertTrue(t.values().is_pinned())
             elif layout is torch.strided:
                 pass
             else:
                 assert 0  # unreachable
-            self.assertTrue(t.is_pinned(pin_memory_device))
+            self.assertTrue(t.is_pinned())
 
     @unittest.skipIf(not torch.cuda.is_available(), 'requires cuda')
     @onlyCPU
1 change: 0 additions & 1 deletion torch/_subclasses/fake_impls.py
@@ -91,7 +91,6 @@ def is_noncontiguous_supported(device):
     aten._nested_tensor_from_tensor_list.default,
     aten._nested_tensor_from_tensor_list.out,
     aten.pin_memory.default,
-    aten.is_pinned.default,
     aten.to.device,
     aten.to.prim_Device,
     aten._pin_memory.default,
6 changes: 3 additions & 3 deletions torch/testing/_internal/common_utils.py
@@ -3331,7 +3331,7 @@ def generate_simple_inputs(self, layout,
                                dtype=None,
                                index_dtype=None,
                                pin_memory=None,
-                               pin_memory_device=None,
+                               members_pin_memory=None,
                                enable_batch=True,
                                enable_hybrid=True,
                                enable_zero_sized=True,
@@ -3378,8 +3378,8 @@
                 enable_non_contiguous_values=enable_non_contiguous_values,
                 enable_batch_variable_nse=enable_batch_variable_nse,
                 output_tensor=False):
-            if pin_memory_device is not None:
-                args = tuple(a.pin_memory(pin_memory_device) for a in args)
+            if members_pin_memory:
+                args = tuple(a.pin_memory() for a in args)
             if layout is torch.strided:
                 assert len(args) == 1
                 size = kwargs.pop('size', None)  # to ensure that a zero-sized tensor has the desired shape
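
The rename matches the new semantics: the helper no longer pins to a named device; `members_pin_memory=True` simply pins each constructor argument. A hedged sketch of how a test method might drive it (mirrors the test_sparse.py usage above):

```python
# Hedged sketch, inside a test method of the sparse-test mixin:
for t in self.generate_simple_inputs(
        torch.sparse_csr, device='cpu', dtype=torch.float64,
        members_pin_memory=True,  # pin indices/values before construction
        enable_batch=False):      # see gh-104868
    self.assertTrue(t.is_pinned())
```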
You are viewing a condensed version of this merge commit.