From 1eadb0ef363fc48a34d243d74de279e31179e1cb Mon Sep 17 00:00:00 2001 From: Yukio Siraichi Date: Fri, 28 Mar 2025 16:37:11 -0300 Subject: [PATCH 1/3] Update [ghstack-poisoned] --- test/test_dlpack.py | 24 ++++++++++++++++++++++++ torch/_tensor.py | 28 +++++++++++++++++++++------- 2 files changed, 45 insertions(+), 7 deletions(-) diff --git a/test/test_dlpack.py b/test/test_dlpack.py index c74e0f1a23268a..e74561c5a035bf 100644 --- a/test/test_dlpack.py +++ b/test/test_dlpack.py @@ -8,6 +8,7 @@ onlyCPU, onlyCUDA, onlyNativeDeviceTypes, + skipCUDAIfNotRocm, skipCUDAIfRocm, skipMeta, ) @@ -237,6 +238,29 @@ def test_dlpack_tensor_invalid_stream(self, device, dtype): x = make_tensor((5,), dtype=dtype, device=device) x.__dlpack__(stream=object()) + @skipMeta + @onlyCUDA + @skipCUDAIfRocm + def test_dlpack_cuda_per_thread_stream(self, device): + # Test whether we raise an error if we are trying to use per-thread default + # stream, which is currently not supported by PyTorch. + x = make_tensor((5,), dtype=torch.float32, device=device) + with self.assertRaisesRegex(BufferError, "per-thread default stream is not supported"): + torch.from_dlpack(x.__dlpack__(stream=2)) + + @skipMeta + @onlyCUDA + @skipCUDAIfNotRocm + def test_dlpack_invalid_streams(self, device): + # Test that we correctly raise errors on unsupported ROCm streams. 
+ def test(x, stream): + with self.assertRaisesRegex(BufferError, r"unsupported stream \d for ROCm"): + torch.from_dlpack(x.__dlpack__(stream=stream)) + + x = make_tensor((5,), dtype=torch.float32, device=device) + test(x, stream=1) + test(x, stream=2) + # TODO: add interchange tests once NumPy 1.22 (dlpack support) is required @skipMeta def test_dlpack_export_requires_grad(self): diff --git a/torch/_tensor.py b/torch/_tensor.py index 5e795405b5ed52..bb1b58bebe130d 100644 --- a/torch/_tensor.py +++ b/torch/_tensor.py @@ -1704,23 +1704,37 @@ def __dlpack__(self, stream=None, max_version=None): # Stream pointers in CUDA/ROCm are uniquely numbered and can # be retrieved from their integer value. raise TypeError("stream must be ``int`` or ``none``") - elif stream is not None and stream != -1: + elif stream != -1: if self.device.type == "cuda": # NB: This logic handles the special case values for default # streams and must be kept in sync with from_dlpack in # torch/utils/dlpack.py - if stream == 1 and torch.version.hip is None: - stream = torch.cuda.default_stream() - elif stream == 0 and torch.version.hip is not None: + is_rocm = torch.version.hip is not None + is_cuda = not is_rocm + + if ( + stream is None + or (is_rocm and stream == 0) + or (is_cuda and stream == 1) + ): stream = torch.cuda.default_stream() else: + if is_cuda and stream == 2: + raise BufferError("per-thread default stream is not supported.") + + assert is_cuda or (is_rocm and stream not in (1, 2)), ( + f"unsupported stream {stream} for ROCm." 
+ ) + stream = torch.cuda.ExternalStream(stream) + # Only synchronize on different streams - sync_stream = torch.cuda.current_stream() - if stream != sync_stream: + current_stream = torch.cuda.current_stream() + if stream != current_stream: event = torch.cuda.Event() - event.record(sync_stream) + event.record(current_stream) stream.wait_event(event) + if self.device.type == "xla": import torch_xla import torch_xla.utils.dlpack as xla_dlpack From cf810ee108a8333a2c7771d4f43a1e8760d42c16 Mon Sep 17 00:00:00 2001 From: Yukio Siraichi Date: Fri, 28 Mar 2025 16:48:52 -0300 Subject: [PATCH 2/3] Update [ghstack-poisoned] --- test/test_dlpack.py | 4 +++- torch/_tensor.py | 6 +++--- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/test/test_dlpack.py b/test/test_dlpack.py index e74561c5a035bf..41ad42e9907a1e 100644 --- a/test/test_dlpack.py +++ b/test/test_dlpack.py @@ -245,7 +245,9 @@ def test_dlpack_cuda_per_thread_stream(self, device): # Test whether we raise an error if we are trying to use per-thread default # stream, which is currently not supported by PyTorch. x = make_tensor((5,), dtype=torch.float32, device=device) - with self.assertRaisesRegex(BufferError, "per-thread default stream is not supported"): + with self.assertRaisesRegex( + BufferError, "per-thread default stream is not supported" + ): torch.from_dlpack(x.__dlpack__(stream=2)) @skipMeta diff --git a/torch/_tensor.py b/torch/_tensor.py index bb1b58bebe130d..97f8b118a3befc 100644 --- a/torch/_tensor.py +++ b/torch/_tensor.py @@ -1722,9 +1722,9 @@ def __dlpack__(self, stream=None, max_version=None): if is_cuda and stream == 2: raise BufferError("per-thread default stream is not supported.") - assert is_cuda or (is_rocm and stream not in (1, 2)), ( - f"unsupported stream {stream} for ROCm." - ) + assert is_cuda or ( + is_rocm and stream not in (1, 2) + ), f"unsupported stream {stream} for ROCm." 
stream = torch.cuda.ExternalStream(stream) From b2de390daa5f0112096bbce3e8898ff6cdf434c2 Mon Sep 17 00:00:00 2001 From: Yukio Siraichi Date: Fri, 4 Apr 2025 15:08:38 -0300 Subject: [PATCH 3/3] Update [ghstack-poisoned] --- torch/_tensor.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/torch/_tensor.py b/torch/_tensor.py index d304b775e31002..293017af3adca0 100644 --- a/torch/_tensor.py +++ b/torch/_tensor.py @@ -1734,9 +1734,9 @@ def __dlpack__(self, stream=None, max_version=None): if is_cuda and stream == 2: raise BufferError("per-thread default stream is not supported.") - assert is_cuda or ( - is_rocm and stream not in (1, 2) - ), f"unsupported stream {stream} for ROCm." + if is_rocm and stream in (1, 2): + raise BufferError( + f"unsupported stream {stream} for ROCm.") stream = torch.cuda.ExternalStream(stream)