Add pinned memory support to sparse COO/CSR/CSC/BSR/BSC tensors by pearu · Pull Request #129645 · pytorch/pytorch · GitHub
Closed · wants to merge 14 commits
Update on "Add pinned memory support to sparse COO/CSR/CSC/BSR/BSC te…
…nsors"

As in the title:

To register the indices/values of a sparse XYZ tensor with CUDA (that is, to pin their memory), the following methods are supported (see the sketch after the list):
- `sparse_xyz_tensor(indices, values, pin_memory=True)`
- `sparse_xyz_tensor(indices, values).pin_memory()`
- `sparse_xyz_tensor(indices.pin_memory(), values.pin_memory())`
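
A minimal sketch of the three spellings for the COO layout (assumes a CUDA-enabled build; the other layouts follow the same pattern with their respective constructors):

```python
import torch

indices = torch.tensor([[0, 1], [1, 0]])
values = torch.tensor([1.0, 2.0])

# 1. Pin at construction time.
a = torch.sparse_coo_tensor(indices, values, (2, 2), pin_memory=True)

# 2. Pin an already constructed sparse tensor.
b = torch.sparse_coo_tensor(indices, values, (2, 2)).pin_memory()

# 3. Construct from indices/values that are already pinned.
c = torch.sparse_coo_tensor(indices.pin_memory(), values.pin_memory(), (2, 2))

assert a.is_pinned() and b.is_pinned() and c.is_pinned()
```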

Fixes #115330




cc alexsamardzic nikitaved cpuhrsch amjames bhosmer jcaip

[ghstack-poisoned]
pearu committed Jun 29, 2024
commit cac97c80304b3b4ae6ae27f3a7df2089db0f5f16
19 changes: 18 additions & 1 deletion aten/src/ATen/native/sparse/SparseCsrTensor.cpp
@@ -292,9 +292,23 @@ static void _validate_sparse_compressed_tensor_args_worker(const Tensor& compres
compressed_indices.get_device() == plain_indices.get_device(),
"device of ", compressed_indices_name, " (=",
compressed_indices.device(),
") must match device of ", plain_indices_name," (=",
") must match device of ", plain_indices_name, " (=",
plain_indices.device(),
")");
TORCH_CHECK(
compressed_indices.is_pinned() == values.is_pinned(),
"memory pinning of ", compressed_indices_name, " (=",
compressed_indices.is_pinned(),
") must match memory pinning of values (=",
values.is_pinned(),
")");
TORCH_CHECK(
compressed_indices.is_pinned() == plain_indices.is_pinned(),
"memory pinning of ", compressed_indices_name, " (=",
compressed_indices.is_pinned(),
") must match memory pinning of ", plain_indices_name, " (=",
plain_indices.is_pinned(),
")");

// Autograd Invariants
//
@@ -1239,6 +1253,9 @@ bool is_pinned_sparse_compressed(const Tensor& self, std::optional<Device> devic

Tensor _pin_memory_sparse_compressed(const Tensor& self, std::optional<Device> device) {
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(!device.has_value() || device->is_cuda());
// pinning of sparse tensor is equivalent to cloning indices and
// values that will not change the sparse tensor invariants. Hence,
// we can skip checking the sparse tensor invariants for efficiency.
CheckSparseTensorInvariants _(false);
TensorOptions options = self.options().pinned_memory(true);
auto impl = get_sparse_csr_impl(self);
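The new `TORCH_CHECK`s above reject compressed/plain indices and values whose pinning states disagree. A hypothetical illustration at the Python level, assuming the checks surface through `torch._validate_sparse_csr_tensor_args` (which dispatches to this worker):

```python
import torch

crow_indices = torch.tensor([0, 1, 2]).pin_memory()  # pinned
col_indices = torch.tensor([1, 0]).pin_memory()       # pinned
values = torch.tensor([1.0, 2.0])                      # not pinned

try:
    torch._validate_sparse_csr_tensor_args(crow_indices, col_indices, values, (2, 2))
except RuntimeError as e:
    # e.g. "memory pinning of crow_indices (=1) must match memory pinning of values (=0)"
    print(e)
```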
13 changes: 13 additions & 0 deletions aten/src/ATen/native/sparse/SparseTensor.cpp
@@ -6,6 +6,7 @@
#include <ATen/InitialTensorOptions.h>
#include <ATen/Layout.h>
#include <ATen/Parallel.h>
#include <ATen/SparseCsrTensorUtils.h>
#include <ATen/SparseTensorImpl.h>
#include <ATen/native/SparseTensorUtils.h>
#include <ATen/native/sparse/SparseStubs.h>
@@ -384,6 +385,14 @@ void _validate_sparse_coo_tensor_args(
"), but got ",
size.size());

TORCH_CHECK(
indices.is_pinned() == values.is_pinned(),
"memory pinning of indices (=",
indices.is_pinned(),
") must match memory pinning of values (=",
values.is_pinned(),
")");

// Check to make sure all indices are within the boundaries of `size`
if (indices.numel() > 0) {
Tensor min_indices =
@@ -858,6 +867,10 @@ bool is_pinned_sparse_coo(const Tensor& self, std::optional<Device> device) {

Tensor _pin_memory_sparse_coo(const Tensor& self, std::optional<Device> device) {
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(!device.has_value() || device->is_cuda());
// pinning of sparse tensor is equivalent to cloning indices and
// values that will not change the sparse tensor invariants. Hence,
// we can skip checking the sparse tensor invariants for efficiency.
at::sparse_csr::CheckSparseTensorInvariants _(false);
TensorOptions options = self.options().pinned_memory(true);
return at::_sparse_coo_tensor_with_dims_and_tensors(
self.sparse_dim(),
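The `CheckSparseTensorInvariants _(false)` guard used in both `_pin_memory_*` functions means pinning skips re-validating invariants of the cloned indices/values. A sketch of the user-visible behavior this comment describes, assuming a CUDA-enabled build and the `torch.sparse.check_sparse_tensor_invariants` context manager:

```python
import torch

t = torch.sparse_coo_tensor([[0, 1], [1, 0]], [1.0, 2.0], (2, 2))

# Even with global invariant checking enabled, pinning only clones the
# indices/values buffers into pinned memory; the (unchanged) invariants of
# `t` are not re-validated on this path.
with torch.sparse.check_sparse_tensor_invariants():
    p = t.pin_memory()

assert p.is_pinned() and not t.is_pinned()
```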
57 changes: 56 additions & 1 deletion test/test_sparse.py
@@ -5386,22 +5386,41 @@ def test_method_pin_memory(self, device, layout):
enable_batch=False, # TODO: remove after gh-104868 is resolved
):
t = t_.pin_memory(pin_memory_device)
self.assertTrue(t.is_pinned(pin_memory_device))

# registering a non-pinned tensor with CUDA memory is a
# clone operation
self.assertFalse(t_.is_pinned(pin_memory_device))

# registering already pinned tensor with CUDA memory is an
# identity operation:
t2 = t.pin_memory(pin_memory_device)
self.assertTrue(t2 is t)

if layout is torch.sparse_coo:
self.assertTrue(t._indices().is_pinned(pin_memory_device))
self.assertTrue(t._values().is_pinned(pin_memory_device))
self.assertFalse(t_._indices().is_pinned(pin_memory_device))
self.assertFalse(t_._values().is_pinned(pin_memory_device))
elif layout in {torch.sparse_csr, torch.sparse_bsr}:
self.assertTrue(t.crow_indices().is_pinned(pin_memory_device))
self.assertTrue(t.col_indices().is_pinned(pin_memory_device))
self.assertTrue(t.values().is_pinned(pin_memory_device))
self.assertFalse(t_.crow_indices().is_pinned(pin_memory_device))
self.assertFalse(t_.col_indices().is_pinned(pin_memory_device))
self.assertFalse(t_.values().is_pinned(pin_memory_device))
elif layout in {torch.sparse_csc, torch.sparse_bsc}:
self.assertTrue(t.ccol_indices().is_pinned(pin_memory_device))
self.assertTrue(t.row_indices().is_pinned(pin_memory_device))
self.assertTrue(t.values().is_pinned(pin_memory_device))
self.assertFalse(t_.ccol_indices().is_pinned(pin_memory_device))
self.assertFalse(t_.row_indices().is_pinned(pin_memory_device))
self.assertFalse(t_.values().is_pinned(pin_memory_device))
elif layout is torch.strided:
pass
else:
assert 0 # unreachable
self.assertTrue(t.is_pinned(pin_memory_device))


@unittest.skipIf(not torch.cuda.is_available(), 'requires cuda')
@onlyCPU
@@ -5434,6 +5453,42 @@ def test_constructor_pinned_memory(self, device, layout):
assert 0 # unreachable
self.assertTrue(t.is_pinned(pin_memory_device))

@unittest.skipIf(not torch.cuda.is_available(), 'requires cuda')
@onlyCPU
@all_sparse_layouts('layout', include_strided=False)
def test_constructor_mismatched_pinned_memory(self, device, layout):
"""Test the failure to construct sparse tensor from indices and values
that have different pinning states.
"""
def generic_constructor(*args, **kwargs):
if layout in {torch.sparse_csr, torch.sparse_csc, torch.sparse_bsr, torch.sparse_bsc}:
kwargs.update(layout=layout)
return torch.sparse_compressed_tensor(*args, **kwargs)
elif layout is torch.sparse_coo:
return torch.sparse_coo_tensor(*args, **kwargs)
else:
raise NotImplementedError(layout)

for args, kwargs in self.generate_simple_inputs(
layout, device=device, dtype=torch.float64,
enable_zero_sized=False, # pinning zero-sized tensors is a no-op
enable_batch=False, # TODO: remove after gh-104868 is resolved
output_tensor=False):

# indices are pinned, values is a non-pinned tensor
args1 = (args[0].pin_memory(), *args[1:])

# indices are non-pinned, values is a pinned tensor
args2 = (*args[:-1], args[-1].pin_memory())

with self.assertRaisesRegex(
RuntimeError, r"memory pinning of \w*indices \(=1\) must match memory pinning of values \(=0\)"):
generic_constructor(*args1, **kwargs)

with self.assertRaisesRegex(
RuntimeError, r"memory pinning of \w*indices \(=0\) must match memory pinning of values \(=1\)"):
generic_constructor(*args2, **kwargs)


# e.g., TestSparseUnaryUfuncsCPU and TestSparseUnaryUfuncsCUDA
instantiate_device_type_tests(TestSparseUnaryUfuncs, globals(), except_for='meta')