Fix DLPack stream logic. by ysiraichi · Pull Request #150217 · pytorch/pytorch · GitHub

Fix DLPack stream logic. #150217


Closed
ysiraichi wants to merge 15 commits
26 changes: 26 additions & 0 deletions test/test_dlpack.py
@@ -8,6 +8,7 @@
     onlyCPU,
     onlyCUDA,
     onlyNativeDeviceTypes,
+    skipCUDAIfNotRocm,
     skipCUDAIfRocm,
     skipMeta,
 )
@@ -242,6 +243,31 @@ def test_dlpack_tensor_invalid_stream(self, device, dtype):
             x = make_tensor((5,), dtype=dtype, device=device)
             x.__dlpack__(stream=object())

+    @skipMeta
+    @onlyCUDA
+    @skipCUDAIfRocm
+    def test_dlpack_cuda_per_thread_stream(self, device):
+        # Test whether we raise an error if we are trying to use per-thread default
+        # stream, which is currently not supported by PyTorch.
+        x = make_tensor((5,), dtype=torch.float32, device=device)
+        with self.assertRaisesRegex(
+            BufferError, "per-thread default stream is not supported"
+        ):
+            torch.from_dlpack(x.__dlpack__(stream=2))

+    @skipMeta
+    @onlyCUDA
+    @skipCUDAIfNotRocm
+    def test_dlpack_invalid_streams(self, device):
+        # Test that we correctly raise errors on unsupported ROCm streams.
+        def test(x, stream):
+            with self.assertRaisesRegex(BufferError, r"unsupported stream \d for ROCm"):
+                torch.from_dlpack(x.__dlpack__(stream=stream))
+
+        x = make_tensor((5,), dtype=torch.float32, device=device)
+        test(x, stream=1)
+        test(x, stream=2)
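
For reference, the integer values exercised by these two tests follow the DLPack / Python array API stream convention: stream=None asks the producer to assume the legacy default stream, on CUDA 1 means the legacy default stream and 2 the per-thread default stream, and on ROCm 0 means the default stream while 1 and 2 are left unassigned. A minimal consumer-side sketch, illustrative only and not part of this diff (assumes a CUDA build):

import torch

x = torch.randn(5, device="cuda")
# stream=1 asks the producer to order its work against the legacy default stream.
capsule = x.__dlpack__(stream=1)
y = torch.from_dlpack(capsule)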

     # TODO: add interchange tests once NumPy 1.22 (dlpack support) is required
     @skipMeta
     def test_dlpack_export_requires_grad(self):
28 changes: 21 additions & 7 deletions torch/_tensor.py
@@ -1717,23 +1717,37 @@ def __dlpack__(self, stream=None, max_version=None):
             # Stream pointers in CUDA/ROCm are uniquely numbered and can
             # be retrieved from their integer value.
             raise TypeError("stream must be ``int`` or ``none``")
-        elif stream is not None and stream != -1:
+        elif stream != -1:
             if self.device.type == "cuda":
                 # NB: This logic handles the special case values for default
                 # streams and must be kept in sync with from_dlpack in
                 # torch/utils/dlpack.py

Collaborator:
No update to dlpack.py ? :D

Collaborator Author:
No need. If stream is None, we still need to synchronize, assuming the legacy default stream.

-                if stream == 1 and torch.version.hip is None:
-                    stream = torch.cuda.default_stream()
-                elif stream == 0 and torch.version.hip is not None:
+                is_rocm = torch.version.hip is not None
+                is_cuda = not is_rocm
+
+                if (
+                    stream is None
+                    or (is_rocm and stream == 0)
+                    or (is_cuda and stream == 1)
+                ):
                     stream = torch.cuda.default_stream()
                 else:
+                    if is_cuda and stream == 2:
+                        raise BufferError("per-thread default stream is not supported.")
+
+                    assert is_cuda or (is_rocm and stream not in (1, 2)), (
+                        f"unsupported stream {stream} for ROCm."
+                    )

Collaborator:
Shouldn't this be a BufferError like above instead of AssertionError?

Collaborator Author:
I don't think so. The reason being that this assertion checks something the standard explicitly states as "unsupported" or "disallowed", i.e. something the consumer should know about. Moreover, the standard also says that:

    Other errors are raised when export fails for other reasons (e.g., incorrect arguments passed or out of memory).

                     stream = torch.cuda.ExternalStream(stream)
+
                 # Only synchronize on different streams
-                sync_stream = torch.cuda.current_stream()
-                if stream != sync_stream:
+                current_stream = torch.cuda.current_stream()
Collaborator:
Do we care if self.device.index != torch.cuda.current_device() ?

Collaborator Author:
Good point. I think we should. I will add a check for that.

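A hypothetical sketch of the device-aware lookup discussed above, not part of this diff; the helper name is illustrative only:

import torch

def _current_stream_for(tensor: torch.Tensor) -> torch.cuda.Stream:
    # Query the current stream on the tensor's own device, so the comparison and
    # the recorded event refer to the exporting tensor's device rather than to
    # whichever device happens to be current.
    return torch.cuda.current_stream(tensor.device.index)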

+                if stream != current_stream:
                     event = torch.cuda.Event()
-                    event.record(sync_stream)
+                    event.record(current_stream)
                     stream.wait_event(event)

         if self.device.type == "xla":
             import torch_xla
             import torch_xla.utils.dlpack as xla_dlpack
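To illustrate the synchronization contract the revised logic implements, here is a minimal end-to-end sketch, not part of the PR; the stream names are illustrative and it assumes a CUDA build with at least one GPU. Passing the consumer stream's pointer makes __dlpack__ record an event on the producer's current stream and have the consumer stream wait on it, so the consumer can read the exported data on its own stream:

import torch

producer_stream = torch.cuda.Stream()
consumer_stream = torch.cuda.Stream()

with torch.cuda.stream(producer_stream):
    x = torch.randn(1024, device="cuda") * 2  # producer work enqueued on producer_stream
    # A raw stream pointer (not one of the special values) is wrapped in an
    # ExternalStream; __dlpack__ records an event on the current (producer)
    # stream and makes the consumer stream wait on it.
    capsule = x.__dlpack__(stream=consumer_stream.cuda_stream)

y = torch.from_dlpack(capsule)
with torch.cuda.stream(consumer_stream):
    z = y + 1  # safe: consumer_stream already waits on the producer's work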