diff --git a/test/test_dlpack.py b/test/test_dlpack.py
index b43909ff4765f..360413643eab3 100644
--- a/test/test_dlpack.py
+++ b/test/test_dlpack.py
@@ -8,6 +8,7 @@
     onlyCPU,
     onlyCUDA,
     onlyNativeDeviceTypes,
+    skipCUDAIfNotRocm,
     skipCUDAIfRocm,
     skipMeta,
 )
@@ -242,6 +243,31 @@ def test_dlpack_tensor_invalid_stream(self, device, dtype):
             x = make_tensor((5,), dtype=dtype, device=device)
             x.__dlpack__(stream=object())
 
+    @skipMeta
+    @onlyCUDA
+    @skipCUDAIfRocm
+    def test_dlpack_cuda_per_thread_stream(self, device):
+        # Test that we raise an error when trying to use the per-thread default
+        # stream, which is currently not supported by PyTorch.
+        x = make_tensor((5,), dtype=torch.float32, device=device)
+        with self.assertRaisesRegex(
+            BufferError, "per-thread default stream is not supported"
+        ):
+            torch.from_dlpack(x.__dlpack__(stream=2))
+
+    @skipMeta
+    @onlyCUDA
+    @skipCUDAIfNotRocm
+    def test_dlpack_invalid_streams(self, device):
+        # Test that we correctly raise errors on unsupported ROCm streams.
+        def test(x, stream):
+            with self.assertRaisesRegex(BufferError, r"unsupported stream \d for ROCm"):
+                torch.from_dlpack(x.__dlpack__(stream=stream))
+
+        x = make_tensor((5,), dtype=torch.float32, device=device)
+        test(x, stream=1)
+        test(x, stream=2)
+
     # TODO: add interchange tests once NumPy 1.22 (dlpack support) is required
     @skipMeta
     def test_dlpack_export_requires_grad(self):
diff --git a/torch/_tensor.py b/torch/_tensor.py
index 0084c42acd442..c807b3a6a2170 100644
--- a/torch/_tensor.py
+++ b/torch/_tensor.py
@@ -1717,23 +1717,37 @@ def __dlpack__(self, stream=None, max_version=None):
             # Stream pointers in CUDA/ROCm are uniquely numbered and can
             # be retrieved from their integer value.
             raise TypeError("stream must be ``int`` or ``none``")
-        elif stream is not None and stream != -1:
+        elif stream != -1:
             if self.device.type == "cuda":
                 # NB: This logic handles the special case values for default
                 # streams and must be kept in sync with from_dlpack in
                 # torch/utils/dlpack.py
-                if stream == 1 and torch.version.hip is None:
-                    stream = torch.cuda.default_stream()
-                elif stream == 0 and torch.version.hip is not None:
+                is_rocm = torch.version.hip is not None
+                is_cuda = not is_rocm
+
+                if (
+                    stream is None
+                    or (is_rocm and stream == 0)
+                    or (is_cuda and stream == 1)
+                ):
                     stream = torch.cuda.default_stream()
                 else:
+                    if is_cuda and stream == 2:
+                        raise BufferError("per-thread default stream is not supported.")
+
+                    if is_rocm and stream in (1, 2):
+                        # 1 and 2 are CUDA-specific sentinel values (legacy and
+                        # per-thread default streams) with no ROCm equivalent.
+                        raise BufferError(f"unsupported stream {stream} for ROCm.")
+
                     stream = torch.cuda.ExternalStream(stream)
+
                 # Only synchronize on different streams
-                sync_stream = torch.cuda.current_stream()
-                if stream != sync_stream:
+                current_stream = torch.cuda.current_stream()
+                if stream != current_stream:
                     event = torch.cuda.Event()
-                    event.record(sync_stream)
+                    event.record(current_stream)
                     stream.wait_event(event)
+
         if self.device.type == "xla":
            import torch_xla
            import torch_xla.utils.dlpack as xla_dlpack
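
Reviewer note (not part of the diff): a minimal consumer-side sketch of the new stream semantics, assuming a CUDA (non-ROCm) build. `stream=1` maps to the legacy default stream and synchronizes the export; `stream=2` now raises `BufferError`, matching `test_dlpack_cuda_per_thread_stream` above.

```python
import torch

x = torch.randn(5, device="cuda")

# stream=1 on a CUDA build means the legacy default stream: __dlpack__
# records an event on the current stream and makes the default stream
# wait on it (when the two differ) before handing out the capsule.
capsule = x.__dlpack__(stream=1)
y = torch.from_dlpack(capsule)

# stream=2 (per-thread default stream) is rejected on CUDA builds.
try:
    x.__dlpack__(stream=2)
except BufferError as e:
    print(e)  # per-thread default stream is not supported.
```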