Add TORCH_CHECK_INDEX in convert_indices_from_coo_to_csr_cpu by Kh4L · Pull Request #138068 · pytorch/pytorch · GitHub

Add TORCH_CHECK_INDEX in convert_indices_from_coo_to_csr_cpu #138068


Open · Kh4L wants to merge 1 commit into pytorch:main

Conversation

@Kh4L commented Oct 16, 2024

The to_sparse_csr CPU implementation, convert_indices_from_coo_to_csr_cpu, doesn't validate the input COO indices, which can lead to illegal memory access.

This PR fixes that by adding TORCH_CHECK_INDEX checks on each index before the data is accessed.
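For illustration, here is a minimal pure-Python sketch of the row-pointer conversion that convert_indices_from_coo_to_csr_cpu performs, with the proposed bounds check added. The function name and structure are simplified assumptions for exposition, not the actual ATen kernel:

```python
def coo_rows_to_csr(row_indices, num_rows):
    # Sketch: convert sorted COO row indices into a CSR crow_indices array.
    # Without the bounds check below, an out-of-range row index would write
    # past the end of `crow` -- the illegal memory access this PR guards against.
    crow = [0] * (num_rows + 1)
    for i in row_indices:
        # The TORCH_CHECK_INDEX-style validation proposed by this PR:
        if not (0 <= i < num_rows):
            raise IndexError(f"row index {i} is out of bounds for {num_rows} rows")
        crow[i + 1] += 1
    # Prefix-sum the per-row counts to get the CSR row pointers.
    for r in range(num_rows):
        crow[r + 1] += crow[r]
    return crow
```

With valid input, coo_rows_to_csr([0, 0, 1, 3], 4) returns the row pointers [0, 2, 3, 3, 4]; an out-of-range index raises instead of corrupting memory.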

# repro
import torch

num_nonzeros = 2048
dense_size = (1024, 256)  # (rows, cols)


row_indices = torch.randint(0, dense_size[0], (num_nonzeros,))
col_indices = torch.randint(0, dense_size[1], (num_nonzeros,))
coo_indices = torch.stack((row_indices, col_indices))

# this should work
sparse_coo = torch.sparse_coo_tensor(
    coo_indices,
    torch.ones(coo_indices.size(1), dtype=torch.bool),
    dense_size,
)
dst_node_indices = sparse_coo.to_sparse_csr()
print(f"Works well {dst_node_indices}")


# altering row_indices, now it's expected to fail

row_indices[-1] = dense_size[0] + 42

coo_indices = torch.stack((row_indices, col_indices))

# this should not work
sparse_coo = torch.sparse_coo_tensor(
    coo_indices,
    torch.ones(coo_indices.size(1), dtype=torch.bool),
    dense_size,
)
dst_node_indices = sparse_coo.to_sparse_csr()
print("Should never reach here")

pytorch-bot bot commented Oct 16, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/138068

Note: Links to docs will display an error until the docs builds have been completed.

❗ 1 Active SEV

There is 1 currently active SEV. If your PR is affected, please view it below:

❌ 41 New Failures, 25 Cancelled Jobs, 6 Pending

As of commit 137ce9d with merge base 7f88bf9:

NEW FAILURES - The following jobs have failed:

CANCELLED JOBS - The following jobs were cancelled. Please retry:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

CLA Not Signed


This PR needs a release notes: label

If your changes are user facing and intended to be a part of release notes, please use a label starting with release notes:.

If not, please add the topic: not user facing label.

To add a label, you can comment to pytorchbot, for example
@pytorchbot label "topic: not user facing"

For more information, see
https://github.com/pytorch/pytorch/wiki/PyTorch-AutoLabel-Bot#why-categorize-for-release-notes-and-how-does-it-work.

@zou3519 zou3519 requested a review from cpuhrsch October 16, 2024 13:28
@zou3519 zou3519 added the triaged label ("This issue has been looked at by a team member, and triaged and prioritized into an appropriate module") Oct 16, 2024
@cpuhrsch (Contributor)

@Kh4L - have you tried using the check_invariants flag of torch.sparse_coo_tensor?
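For reference, the flag mentioned here can be exercised like this (a sketch assuming the documented torch.sparse_coo_tensor signature; the exact error message is not asserted):

```python
import torch

# Out-of-bounds row index (5 >= 4 rows). With check_invariants=True,
# construction should raise up front instead of deferring the problem
# to a later kernel such as to_sparse_csr.
indices = torch.tensor([[5], [0]])
values = torch.tensor([1.0])
try:
    torch.sparse_coo_tensor(indices, values, (4, 4), check_invariants=True)
    print("no error raised")
except RuntimeError as e:
    print("caught:", type(e).__name__)
```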

@Kh4L (Author) commented Oct 17, 2024

@Kh4L - have you tried using the check_invariants flag of torch.sparse_coo_tensor?

Yes, I gave it a try.
However, I still believe it's worth checking the index here: we were encountering various kinds of memory corruption that crashed Python outright, which made the root cause difficult to debug.

I think the cost of adding this index check is small compared to the time it could save developers in the long run.

@cpuhrsch (Contributor)

@Kh4L - I agree, I'd have hoped that check_invariants and the torch.sparse.check_sparse_tensor_invariants context manager already did all the relevant checks needed here. Did you try running under with torch.sparse.check_sparse_tensor_invariants(): as in

with torch.sparse.check_sparse_tensor_invariants():
    run_my_model()

@cpuhrsch (Contributor)

If this doesn't work, I suggest we modify this PR to use https://pytorch.org/docs/main/generated/torch.sparse.check_sparse_tensor_invariants.html#torch.sparse.check_sparse_tensor_invariants.is_enabled as a guard. That way this can be turned on for debugging, but doesn't affect performance by default.
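The guard semantics on the Python side look like this (a sketch using the documented torch.sparse.check_sparse_tensor_invariants API; the C++ side would consult the same flag):

```python
import torch

check = torch.sparse.check_sparse_tensor_invariants

# Invariant checking is disabled by default, so a C++ check guarded by
# is_enabled() would be skipped and cost nothing on the hot path.
print(check.is_enabled())

# Inside the context manager the flag flips on, so the guarded index
# validation would run, turning silent corruption into a clean error.
with check():
    print(check.is_enabled())
```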

@Kh4L (Author) commented Oct 17, 2024

@cpuhrsch I understand, but is it really OK for Python to crash with errors such as malloc(): invalid next size (unsorted) whenever check_sparse_tensor_invariants is not enabled? to_sparse_csr does not seem to be a typical bottleneck.

@Kh4L (Author) commented Oct 17, 2024

@cpuhrsch
the index check is optimized and its cost is negligible; see the benchmark below on a relatively large sparse tensor:

import torch
import time
import numpy as np

num_nonzeros = 2 ** 14
dense_size = (2**16, 2**14)

row_indices = torch.randint(0, dense_size[0], (num_nonzeros,))
col_indices = torch.randint(0, dense_size[1], (num_nonzeros,))
coo_indices = torch.stack((row_indices, col_indices))
sparse_coo = torch.sparse_coo_tensor(
    coo_indices,
    torch.ones(coo_indices.size(1), dtype=bool),
    dense_size,
)

def benchmark_to_sparse_csr():
    return sparse_coo.to_sparse_csr()

num_runs = 100
times = []

for _ in range(num_runs):
    start_time = time.time()
    benchmark_to_sparse_csr()
    end_time = time.time()
    times.append(end_time - start_time)

print(f"Mean execution time: {np.mean(times):.6f} seconds")
print(f"Standard deviation: {np.std(times):.6f} seconds")

before the change (without index check):

Mean execution time: 0.178715 seconds
Standard deviation: 0.004270 seconds

after the change (with index check):

Mean execution time: 0.178128 seconds
Standard deviation: 0.003673 seconds

@cpuhrsch (Contributor)

@Kh4L - when the Tensor is big it's fine, but when it's small the overhead may no longer be acceptable for some applications. By supporting the guard, the user can choose what works best for them.

Labels
open source · triaged ("This issue has been looked at by a team member, and triaged and prioritized into an appropriate module")
4 participants