Review changes · pytorch/pytorch@2059638 · GitHub
[go: up one dir, main page]

Skip to content

Commit 2059638

Browse files
author
Emilio Castillo
committed
Review changes
1 parent 8d60bbe commit 2059638

File tree

3 files changed

+39
-27
lines changed

3 files changed

+39
-27
lines changed

test/test_torch.py

Lines changed: 21 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -7096,35 +7096,38 @@ def compare_strides(s1, s2, div):
70967096

70977097
@skipMeta
70987098
@onlyOnCPUAndCUDA
7099-
@dtypes(*torch.testing.get_all_dtypes(include_bool=False))
7099+
@dtypes(*get_all_dtypes(include_bool=False))
71007100
def test_dlpack_capsule_conversion(self, device, dtype):
71017101
# DLpack does not explicitly support bool (xref dmlc/dlpack#75)
7102-
x = make_tensor((5,), device, dtype, low=-9, high=9)
7102+
x = make_tensor((5,), device, dtype)
71037103
z = from_dlpack(to_dlpack(x))
71047104
self.assertEqual(z, x)
71057105

71067106
@skipMeta
71077107
@onlyOnCPUAndCUDA
7108-
@dtypes(*torch.testing.get_all_dtypes(include_bool=False))
7108+
@dtypes(*get_all_dtypes(include_bool=False))
71097109
def test_dlpack_protocol_conversion(self, device, dtype):
7110-
x = make_tensor((5,), device, dtype, low=-9, high=9)
7110+
x = make_tensor((5,), device, dtype)
71117111
z = from_dlpack(x)
71127112
self.assertEqual(z, x)
71137113

71147114
@skipMeta
71157115
@onlyOnCPUAndCUDA
71167116
def test_dlpack_shared_storage(self, device):
7117-
x = make_tensor((5,), device, torch.float64, low=-9, high=9)
7117+
x = make_tensor((5,), device, torch.float64)
71187118
z = from_dlpack(to_dlpack(x))
71197119
z[0] = z[0] + 20.0
71207120
self.assertEqual(z, x)
71217121

71227122
@skipMeta
71237123
@onlyCUDA
7124-
@dtypes(*torch.testing.get_all_dtypes(include_bool=False))
7124+
@dtypes(*get_all_dtypes(include_bool=False))
71257125
def test_dlpack_conversion_with_streams(self, device, dtype):
71267126
# Create a stream where the tensor will reside
7127-
x = make_tensor((5,), device, dtype, low=-9, high=9)
7127+
stream = torch.cuda.Stream()
7128+
with torch.cuda.stream(stream):
7129+
# Do an operation in the actual stream
7130+
x = make_tensor((5,), device, dtype) + 1
71287131
# DLPack protocol helps establish a correct stream order
71297132
# (hence data dependency) at the exchange boundary.
71307133
# DLPack manages this synchronization for us, so we don't need to
@@ -7137,28 +7140,28 @@ def test_dlpack_conversion_with_streams(self, device, dtype):
71377140

71387141
@skipMeta
71397142
@onlyCUDA
7140-
@dtypes(*torch.testing.get_all_dtypes(include_bool=False))
7143+
@dtypes(*get_all_dtypes(include_bool=False))
71417144
def test_dlpack_conversion_with_diff_streams(self, device, dtype):
71427145
from torch._C import _from_dlpack
7143-
x = make_tensor((5,), device, dtype, low=-9, high=9)
71447146
stream_a = torch.cuda.Stream()
71457147
stream_b = torch.cuda.Stream()
71467148
# DLPack protocol helps establish a correct stream order
71477149
# (hence data dependency) at the exchange boundary.
71487150
# the `tensor.__dlpack__` method will insert a synchronization event
71497151
# in the current stream to make sure that it was correctly populated.
71507152
with torch.cuda.stream(stream_a):
7153+
x = make_tensor((5,), device, dtype) + 1
71517154
z = _from_dlpack(x.__dlpack__(stream_b.cuda_stream))
71527155
stream_a.synchronize()
71537156
stream_b.synchronize()
71547157
self.assertEqual(z, x)
71557158

71567159
@skipMeta
71577160
@onlyOnCPUAndCUDA
7158-
@dtypes(*torch.testing.get_all_dtypes(include_bool=False))
7161+
@dtypes(*get_all_dtypes(include_bool=False))
71597162
def test_dlpack_tensor_invalid_stream(self, device, dtype):
71607163
with self.assertRaises(TypeError):
7161-
x = make_tensor((5,), device, dtype, low=-9, high=9)
7164+
x = make_tensor((5,), device, dtype)
71627165
x.__dlpack__(stream=object())
71637166

71647167
@skipMeta
@@ -7182,6 +7185,13 @@ def test_dlpack_export_is_conj(self):
71827185
with self.assertRaisesRegex(RuntimeError, r"conjugate bit"):
71837186
y.__dlpack__()
71847187

7188+
@skipMeta
7189+
def test_dlpack_export_non_strided(self):
7190+
x = torch.sparse_coo_tensor([[0]], [1], size=(1,))
7191+
y = torch.conj(x)
7192+
with self.assertRaisesRegex(RuntimeError, r"strided"):
7193+
y.__dlpack__()
7194+
71857195
@onlyCUDA
71867196
@unittest.skipIf(PYTORCH_CUDA_MEMCHECK, "is_pinned uses failure to detect pointer property")
71877197
def test_pin_memory_from_constructor(self, device):

torch/_tensor.py

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1064,23 +1064,24 @@ def __dlpack__(self, stream=None):
10641064
stream to this method as part of the specification.
10651065
10661066
Args:
1067-
stream (integer or None): A Python integer representing a pointer
1068-
to a stream (CUDA or ROCm). `stream` is provided by the consumer
1069-
to the producer to instruct the producer to ensure that operations
1070-
can safely be performed on the array.
1071-
The pointer must be a positive integer or
1072-
-1 . If stream is -1 , the value may be used by the consumer to
1073-
signal "producer must not perform any synchronization. Optional.
1067+
stream (integer or None): An optional Python integer representing a
1068+
pointer to a CUDA stream. The current stream is synchronized with
1069+
this stream before the capsule is created, and since the capsule
1070+
shares its storage with the tensor, this makes it safe to access from
1071+
both streams. If None or -1 is passed then no synchronization is performed.
10741072
"""
10751073
if has_torch_function_unary(self):
10761074
return handle_torch_function(Tensor.__dlpack__, (self,), self, stream)
10771075

1078-
# Some semantics that can prevent tensors from being exported are
1079-
# when they require a gradient or they have their conjugate bit set
1076+
# DLPack capsules can't capture all of PyTorch's semantics,
1077+
# so we prohibit exporting tensors that would lose their properties like
1078+
# requires_grad and having the conjugate bit set.
10801079
if self.requires_grad:
10811080
raise RuntimeError('Can\'t export tensors that require gradient, use tensor.detach()')
10821081
if self.is_conj():
10831082
raise RuntimeError('Can\'t export tensors with the conjugate bit set')
1083+
if self.layout != torch.strided:
1084+
raise RuntimeError('Can\'t export tensors with layout other than torch.strided')
10841085

10851086
if stream is not None and type(stream) is not int:
10861087
# Stream pointers in CUDA/ROCm are uniquely numbered and can
@@ -1093,7 +1094,7 @@ def __dlpack__(self, stream=None):
10931094
if stream != torch.cuda.current_stream:
10941095
event = torch.cuda.Event()
10951096
event.record(torch.cuda.current_stream())
1096-
torch.cuda.current_stream().wait_event(event)
1097+
stream.wait_event(event)
10971098
return torch.to_dlpack(self)
10981099

10991100
def __dlpack_device__(self) -> Tuple[enum.IntEnum, int]:

torch/utils/dlpack.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -27,8 +27,8 @@ class DLDeviceType(enum.IntEnum):
2727
Args:
2828
tensor: a tensor to be exported
2929
30-
The dlpack shares the tensors memory.
31-
Note that each dlpack can only be consumed once.
30+
The DLPack shares the tensors memory.
31+
Note that each DLPack can only be consumed once.
3232
""")
3333

3434
# TODO: add a typing.Protocol to be able to tell Mypy that only objects with
@@ -40,13 +40,14 @@ def from_dlpack(ext_tensor: Any) -> torch.Tensor:
4040
by means of the ``__dlpack__`` protocol.
4141
4242
The tensor will share the memory with the object represented
43-
in the dlpack.
43+
in the DLPack.
4444
45-
Note that each dlpack capsule can only be consumed once. Otherwise
46-
memory errors could happen.
45+
.. warning::
46+
Only call from_dlpack once per capsule. Its behavior when used
47+
on the same capsule multiple times is undefined.
4748
4849
Args:
49-
ext_tensor (object with __dlpack__ attribute or dlpack capsule):
50+
ext_tensor (object with __dlpack__ attribute or DLPack capsule):
5051
The tensor or DLPack capsule to convert.
5152
"""
5253
if hasattr(ext_tensor, '__dlpack__'):

0 commit comments

Comments (0)