Fix DLPack stream logic. by ysiraichi · Pull Request #150217 · pytorch/pytorch · GitHub

Fix DLPack stream logic. #150217


Open · wants to merge 9 commits into base: gh/ysiraichi/85/base
26 changes: 26 additions & 0 deletions test/test_dlpack.py
@@ -8,6 +8,7 @@
     onlyCPU,
     onlyCUDA,
     onlyNativeDeviceTypes,
+    skipCUDAIfNotRocm,
     skipCUDAIfRocm,
     skipMeta,
 )
@@ -242,6 +243,31 @@ def test_dlpack_tensor_invalid_stream(self, device, dtype):
         x = make_tensor((5,), dtype=dtype, device=device)
         x.__dlpack__(stream=object())
 
+    @skipMeta
+    @onlyCUDA
+    @skipCUDAIfRocm
+    def test_dlpack_cuda_per_thread_stream(self, device):
+        # Test whether we raise an error if we are trying to use per-thread default
+        # stream, which is currently not supported by PyTorch.
+        x = make_tensor((5,), dtype=torch.float32, device=device)
+        with self.assertRaisesRegex(
+            BufferError, "per-thread default stream is not supported"
+        ):
+            torch.from_dlpack(x.__dlpack__(stream=2))
+
+    @skipMeta
+    @onlyCUDA
+    @skipCUDAIfNotRocm
+    def test_dlpack_invalid_streams(self, device):
+        # Test that we correctly raise errors on unsupported ROCm streams.
+        def test(x, stream):
+            with self.assertRaisesRegex(BufferError, r"unsupported stream \d for ROCm"):
+                torch.from_dlpack(x.__dlpack__(stream=stream))
+
+        x = make_tensor((5,), dtype=torch.float32, device=device)
+        test(x, stream=1)
+        test(x, stream=2)
+
     # TODO: add interchange tests once NumPy 1.22 (dlpack support) is required
     @skipMeta
     def test_dlpack_export_requires_grad(self):
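
For context on these tests: the stream argument passed to __dlpack__ is not a raw stream pointer but the DLPack/Array API protocol's integer encoding, where CUDA reserves 1 for the legacy default stream and 2 for the per-thread default stream, and ROCm reserves 0 for the default stream. A minimal sketch of the behavior the tests exercise, assuming a CUDA (non-ROCm) build with a GPU available:

    import torch

    # Minimal sketch of a DLPack round trip using the protocol's reserved
    # stream numbers; assumes a CUDA (non-ROCm) build with a GPU available.
    if torch.cuda.is_available() and torch.version.hip is None:
        x = torch.randn(5, device="cuda")

        # stream=1 selects the legacy default stream on CUDA, which the
        # producer maps to torch.cuda.default_stream().
        y = torch.from_dlpack(x.__dlpack__(stream=1))
        assert torch.equal(x, y)

        # stream=2 (the per-thread default stream) is rejected, as
        # test_dlpack_cuda_per_thread_stream above expects.
        try:
            x.__dlpack__(stream=2)
        except BufferError as e:
            print(e)  # per-thread default stream is not supported.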
28 changes: 21 additions & 7 deletions torch/_tensor.py
@@ -1717,23 +1717,37 @@ def __dlpack__(self, stream=None, max_version=None):
             # Stream pointers in CUDA/ROCm are uniquely numbered and can
             # be retrieved from their integer value.
             raise TypeError("stream must be ``int`` or ``none``")
-        elif stream is not None and stream != -1:
+        elif stream != -1:
             if self.device.type == "cuda":
                 # NB: This logic handles the special case values for default
                 # streams and must be kept in sync with from_dlpack in
                 # torch/utils/dlpack.py

[Collaborator] No update to dlpack.py ? :D
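
Regarding the sync note above: the consumer-side counterpart lives in torch/utils/dlpack.py. As a paraphrased sketch only (the helper name below is invented, not PyTorch API), the consumer encodes its current stream with the same reserved integers before calling the producer's __dlpack__:

    import torch

    def consumer_stream_ptr() -> int:
        # Illustrative helper, not PyTorch API: encode the consumer's
        # current stream using the DLPack protocol's reserved integers.
        stream = torch.cuda.current_stream()
        if torch.version.hip is None:
            # CUDA: 1 is reserved for the legacy default stream.
            return 1 if stream.cuda_stream == 0 else stream.cuda_stream
        # ROCm: 0 is reserved for the default stream.
        return 0 if stream.cuda_stream == 0 else stream.cuda_stream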
-                if stream == 1 and torch.version.hip is None:
-                    stream = torch.cuda.default_stream()
-                elif stream == 0 and torch.version.hip is not None:
+                is_rocm = torch.version.hip is not None
+                is_cuda = not is_rocm
+
+                if (
+                    stream is None
+                    or (is_rocm and stream == 0)
+                    or (is_cuda and stream == 1)
+                ):
                     stream = torch.cuda.default_stream()
                 else:
+                    if is_cuda and stream == 2:
+                        raise BufferError("per-thread default stream is not supported.")
+
+                    assert is_cuda or (is_rocm and stream not in (1, 2)), (
+                        f"unsupported stream {stream} for ROCm."
+                    )

[Collaborator] Shouldn't this be a BufferError like above instead of AssertionError?
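
A sketch of the reviewer's suggestion, reusing is_rocm and stream from the surrounding hunk: raising BufferError explicitly keeps the check alive under python -O and matches the exception type the new test_dlpack_invalid_streams already expects:

    # Illustrative drop-in for the assert above: BufferError survives
    # python -O and matches assertRaisesRegex(BufferError, ...) in the test.
    if is_rocm and stream in (1, 2):
        raise BufferError(f"unsupported stream {stream} for ROCm.")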

                     stream = torch.cuda.ExternalStream(stream)
+
                 # Only synchronize on different streams
-                sync_stream = torch.cuda.current_stream()
-                if stream != sync_stream:
+                current_stream = torch.cuda.current_stream()
[Collaborator] Do we care if self.device.index != torch.cuda.current_device() ?
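
One way to address this question, as an illustrative sketch only (not part of this PR): run the stream queries and the event handoff under the tensor's own device, so the synchronization is correct even when self.device.index differs from the caller's current device:

    # Illustrative only: pin stream queries to the tensor's device in case
    # self.device.index != torch.cuda.current_device().
    with torch.cuda.device(self.device.index):
        current_stream = torch.cuda.current_stream()
        if stream != current_stream:
            event = torch.cuda.Event()
            event.record(current_stream)
            stream.wait_event(event)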

+                if stream != current_stream:
                     event = torch.cuda.Event()
-                    event.record(sync_stream)
+                    event.record(current_stream)
                     stream.wait_event(event)
 
         if self.device.type == "xla":
             import torch_xla
             import torch_xla.utils.dlpack as xla_dlpack
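
The event logic above is the standard CUDA producer/consumer handoff: record an event on the stream that produced the data, then make the destination stream wait on it, so the consumer never runs ahead of pending writes. A standalone sketch of the same pattern, assuming a CUDA device is available:

    import torch

    # Standalone sketch of the stream handoff implemented in __dlpack__;
    # assumes a CUDA device is available.
    if torch.cuda.is_available():
        producer = torch.cuda.current_stream()
        consumer = torch.cuda.Stream()

        x = torch.randn(1 << 20, device="cuda")
        x += 1  # pending work enqueued on the producer stream

        # Order the consumer stream after the producer's pending writes.
        event = torch.cuda.Event()
        event.record(producer)
        consumer.wait_event(event)

        with torch.cuda.stream(consumer):
            y = x * 2  # safe: runs only after the producer's writes complete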