diff --git a/test/test_dlpack.py b/test/test_dlpack.py
index 36b8dcb7ca686..3437f16fdc740 100644
--- a/test/test_dlpack.py
+++ b/test/test_dlpack.py
@@ -3,11 +3,13 @@
 import torch
 from torch.testing import make_tensor
 from torch.testing._internal.common_device_type import (
+    deviceCountAtLeast,
     dtypes,
     instantiate_device_type_tests,
     onlyCPU,
     onlyCUDA,
     onlyNativeDeviceTypes,
+    skipCUDAIfNotRocm,
     skipCUDAIfRocm,
     skipMeta,
 )
@@ -242,6 +244,62 @@ def test_dlpack_tensor_invalid_stream(self, device, dtype):
         x = make_tensor((5,), dtype=dtype, device=device)
         x.__dlpack__(stream=object())
 
+    @skipMeta
+    @onlyCUDA
+    @skipCUDAIfRocm
+    def test_dlpack_cuda_per_thread_stream(self, device):
+        # Test whether we raise an error if we are trying to use per-thread default
+        # stream, which is currently not supported by PyTorch.
+        x = make_tensor((5,), dtype=torch.float32, device=device)
+        with self.assertRaisesRegex(
+            BufferError, "per-thread default stream is not supported"
+        ):
+            x.__dlpack__(stream=2)
+
+    @skipMeta
+    @onlyCUDA
+    @skipCUDAIfNotRocm
+    def test_dlpack_invalid_rocm_streams(self, device):
+        # Test that we correctly raise errors on unsupported ROCm streams.
+        def test(x, stream):
+            with self.assertRaisesRegex(
+                AssertionError, r"unsupported stream on ROCm: \d"
+            ):
+                x.__dlpack__(stream=stream)
+
+        x = make_tensor((5,), dtype=torch.float32, device=device)
+        test(x, stream=1)
+        test(x, stream=2)
+
+    @skipMeta
+    @onlyCUDA
+    @skipCUDAIfRocm
+    def test_dlpack_invalid_cuda_streams(self, device):
+        x = make_tensor((5,), dtype=torch.float32, device=device)
+        with self.assertRaisesRegex(AssertionError, r"unsupported stream on CUDA: \d"):
+            x.__dlpack__(stream=0)
+
+    @skipMeta
+    def test_dlpack_invalid_cpu_stream(self):
+        x = make_tensor((5,), dtype=torch.float32, device="cpu")
+        with self.assertRaisesRegex(AssertionError, r"stream should be None on cpu."):
+            x.__dlpack__(stream=0)
+
+    @skipMeta
+    @onlyCUDA
+    @deviceCountAtLeast(2)
+    def test_dlpack_tensor_on_different_device(self, devices):
+        dev0, dev1 = devices[:2]
+
+        with torch.device(dev0):
+            x = make_tensor((5,), dtype=torch.float32, device=dev0)
+
+        with self.assertRaisesRegex(
+            BufferError, r"Can't export tensors on a different CUDA device"
+        ):
+            with torch.device(dev1):
+                x.__dlpack__()
+
     # TODO: add interchange tests once NumPy 1.22 (dlpack support) is required
     @skipMeta
     def test_dlpack_export_requires_grad(self):
diff --git a/torch/_tensor.py b/torch/_tensor.py
index 652cd33a03538..3369b6602cfa4 100644
--- a/torch/_tensor.py
+++ b/torch/_tensor.py
@@ -1703,27 +1703,49 @@ def __dlpack__(self, *, stream=None, max_version=None):
                 "Can't export tensors with layout other than torch.strided"
             )
 
+        if (
+            self.device.type == "cuda"
+            and self.device.index != torch.cuda.current_device()
+        ):
+            raise BufferError(
+                "Can't export tensors on a different CUDA device. "
+                f"Expected: {self.device}. "
+                f"Current device: {torch.cuda.current_device()}."
+            )
+
         if stream is not None and type(stream) is not int:
             # Stream pointers in CUDA/ROCm are uniquely numbered and can
             # be retrieved from their integer value.
             raise TypeError("stream must be ``int`` or ``none``")
-        elif stream is not None and stream != -1:
-            if self.device.type == "cuda":
-                # NB: This logic handles the special case values for default
-                # streams and must be kept in sync with from_dlpack in
-                # torch/utils/dlpack.py
-                if stream == 1 and torch.version.hip is None:
-                    stream = torch.cuda.default_stream()
-                elif stream == 0 and torch.version.hip is not None:
-                    stream = torch.cuda.default_stream()
-                else:
-                    stream = torch.cuda.ExternalStream(stream)
-                # Only synchronize on different streams
-                sync_stream = torch.cuda.current_stream()
-                if stream != sync_stream:
-                    event = torch.cuda.Event()
-                    event.record(sync_stream)
-                    stream.wait_event(event)
+        elif self.device.type == "cuda" and stream != -1:
+            # NB: This logic handles the special case values for default
+            # streams and must be kept in sync with from_dlpack in
+            # torch/utils/dlpack.py
+            is_rocm = torch.version.hip is not None
+            is_cuda = not is_rocm
+
+            if stream is None or (is_rocm and stream == 0) or (is_cuda and stream == 1):
+                stream = torch.cuda.default_stream()
+            else:
+                if is_cuda and stream == 2:
+                    raise BufferError("per-thread default stream is not supported.")
+
+                device_str = "CUDA" if is_cuda else "ROCm"
+                assert (is_cuda and stream != 0) or (
+                    is_rocm and stream not in (1, 2)
+                ), f"unsupported stream on {device_str}: {stream}."
+
+                stream = torch.cuda.ExternalStream(stream)
+
+            # Only synchronize on different streams
+            current_stream = torch.cuda.current_stream()
+            if stream != current_stream:
+                event = torch.cuda.Event()
+                event.record(current_stream)
+                stream.wait_event(event)
+        elif self.device.type == "cpu":
+            assert stream is None, "stream should be None on cpu."
+
         if self.device.type == "xla":
             import torch_xla
             import torch_xla.utils.dlpack as xla_dlpack
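Illustrative usage, not part of the patch: a minimal sketch of how a consumer would pass the new stream values to __dlpack__ after this change, assuming a CUDA (non-ROCm) build with at least one GPU; the tensor and stream names are made up for the example.

    import torch

    x = torch.randn(5, device="cuda")
    consumer = torch.cuda.Stream()

    # Passing the consumer's raw stream pointer: the producer records an event on
    # its current stream and makes the consumer stream wait on it before export.
    capsule = x.__dlpack__(stream=consumer.cuda_stream)
    y = torch.from_dlpack(capsule)

    # stream=None and stream=1 both map to the legacy default stream on CUDA builds
    # (on ROCm, stream=0 plays that role instead).
    x.__dlpack__(stream=1)

    # stream=2 (per-thread default stream) now raises BufferError, and exporting
    # from a device other than the current one raises BufferError as well.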