Adds DLPack support by emcastillo · Pull Request #57110 · pytorch/pytorch · GitHub

Adds DLPack support #57110


Closed
wants to merge 14 commits into from
Changes from 1 commit
fixes
Emilio Castillo committed Sep 8, 2021
commit a73fb64ae4775e148f6d43d85e1d9e9cd4cd2bf9
41 changes: 8 additions & 33 deletions test/test_torch.py
@@ -7097,11 +7097,9 @@ def compare_strides(s1, s2, div):
     @skipMeta
     @dtypes(*torch.testing.get_all_dtypes())
     def test_dlpack_capsule_conversion(self, device, dtype):
-        # DLpack does not explicitly support bool
+        # DLpack does not explicitly support bool (xref dmlc/dlpack#75)
         # It does it through uint8 type
-        if dtype is torch.bool:
-            return
-        if 'xla' in device:
+        if dtype is torch.bool or 'xla' in device:
             return
         x = make_tensor((5,), device, dtype, low=-9, high=9)
         z = from_dlpack(to_dlpack(x))
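
The bool caveat noted in the comment above can be worked around by round-tripping through uint8. A minimal sketch of that workaround (illustrative only, not part of this PR):

    import torch
    from torch.utils.dlpack import from_dlpack, to_dlpack

    b = torch.tensor([True, False, True])
    # DLPack (at the time) had no bool dtype, so export the tensor as
    # uint8 and reinterpret it as bool on the consumer side.
    z = from_dlpack(to_dlpack(b.to(torch.uint8))).to(torch.bool)
    assert torch.equal(z, b)
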
@@ -7110,11 +7108,7 @@ def test_dlpack_protocol_conversion(self, device, dtype):
     @skipMeta
     @dtypes(*torch.testing.get_all_dtypes())
     def test_dlpack_protocol_conversion(self, device, dtype):
Collaborator:

We should add a test that compares the DLPack semantics with our .numpy() and from_numpy() semantics. If/when NumPy implements the protocol, we could really validate that the behavior is the same.

@emcastillo (Collaborator, Author) · Jul 7, 2021:

Sorry, I don't understand what should be compared here.

@emcastillo (Collaborator, Author):

Now I think I understand it: a test that checks that tensors from DLPack can't be resized, or that tensors with gradients can't be exported, in the same sense as NumPy. Is this correct?

Collaborator:

Yes: tensors with gradients or the conjugate bit set, I think, can't be exported.

For importing, we should verify that the underlying memory is shared by writing to it on both CPU and CUDA.

NumPy arrays can also be non-writable, which we check for on import. I'm not sure what (if any) special properties DLPack capsules have that PyTorch can't emulate.
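
A minimal sketch of the import-side memory-sharing check discussed above; the helper name and structure are ours, not part of this PR:

    import torch
    from torch.utils.dlpack import from_dlpack, to_dlpack

    def check_dlpack_shares_memory(device):
        # Import a tensor through a DLPack capsule and write through it;
        # the write must be visible in the exporter if memory is shared.
        x = torch.arange(5, device=device, dtype=torch.float32)
        z = from_dlpack(to_dlpack(x))
        z[0] = 42.0
        assert x[0].item() == 42.0

    check_dlpack_shares_memory('cpu')
    if torch.cuda.is_available():
        check_dlpack_shares_memory('cuda')
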

         # DLpack does not explicitly support bool
         # It does it through uint8 type
-        if dtype is torch.bool:
-            return
-        if 'xla' in device:
+        if dtype is torch.bool or 'xla' in device:
             return
         x = make_tensor((5,), device, dtype, low=-9, high=9)
         z = from_dlpack(x)
@@ -7123,57 +7117,38 @@ def test_dlpack_protocol_conversion(self, device, dtype):
     @skipMeta
     @dtypes(*torch.testing.get_all_dtypes())
     def test_dlpack_conversion_with_streams(self, device, dtype):
         # DLpack does not explicitly support bool
         # It does it through uint8 type
-        if dtype is torch.bool:
-            return
-        if 'xla' in device:
+        if dtype is torch.bool or 'xla' in device:
             return
         # Create a stream where the tensor will reside
         if device == 'cuda':
             x = make_tensor((5,), device, dtype, low=-9, high=9)
             stream = torch.cuda.Stream()
             with torch.cuda.stream(stream):
                 assert stream.query()
                 z = from_dlpack(x)
-                assert not stream.query()
+                assert stream.query()
             stream.synchronize()
             self.assertEqual(z, x)

     @skipMeta
     @dtypes(*torch.testing.get_all_dtypes())
     def test_dlpack_conversion_with_diff_streams(self, device, dtype):
         # DLpack does not explicitly support bool
         # It does it through uint8 type
-        if dtype is torch.bool:
-            return
-        if 'xla' in device:
+        if dtype is torch.bool or 'xla' in device:
             return
         if device == 'cuda':
             from torch._C import _from_dlpack
             x = make_tensor((5,), device, dtype, low=-9, high=9)
             stream_a = torch.cuda.Stream()
             stream_b = torch.cuda.Stream()
             with torch.cuda.stream(stream_a):
                 assert stream_a.query()
                 assert stream_b.query()
                 z = _from_dlpack(x.__dlpack__(stream_b.cuda_stream))
                 # sync in stream a forces the stream b work to be completed
                 assert not stream_a.query()
                 assert not stream_b.query()
                 stream_a.synchronize()
                 assert stream_a.query()
                 assert stream_b.query()
             stream_b.synchronize()
             self.assertEqual(z, x)

     @skipMeta
     @dtypes(*torch.testing.get_all_dtypes())
     def test_dlpack_tensor_invalid_stream(self, device, dtype):
         # DLpack does not explicitly support bool
         # It does it through uint8 type
-        if dtype is torch.bool:
-            return
-        if 'xla' in device:
+        if dtype is torch.bool or 'xla' in device:
             return
         with self.assertRaises(TypeError):
             x = make_tensor((5,), device, dtype, low=-9, high=9)
2 changes: 1 addition & 1 deletion torch/_tensor.py
@@ -1083,7 +1083,7 @@ def __dlpack__(self, stream=None):
             # Only synchronize on different streams
             if stream != torch.cuda.current_stream:
                 event = torch.cuda.Event()
-                event.record(stream)
+                event.record(torch.cuda.current_stream())
                 torch.cuda.current_stream().wait_event(event)
         return torch.to_dlpack(self)
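
For context, the usual CUDA pattern for handing a tensor from a producer stream to a consumer stream is to record an event on the producer's current stream (where the tensor's pending work was enqueued) and have the consumer wait on that event; recording the event elsewhere, as the removed line did, does not capture the producer's pending work. A hedged sketch of that general pattern (names and structure are ours, not the PR's):

    import torch

    def hand_off(x, consumer_stream):
        # Capture all work currently enqueued on the producer's stream...
        producer_stream = torch.cuda.current_stream()
        event = torch.cuda.Event()
        event.record(producer_stream)
        # ...and make the consumer's stream wait for it before touching x.
        consumer_stream.wait_event(event)
        return x

    if torch.cuda.is_available():
        x = torch.randn(5, device='cuda') * 2  # enqueued on the default stream
        s = torch.cuda.Stream()
        y = hand_off(x, s)
        with torch.cuda.stream(s):
            y.add_(1)  # ordered after the producer's multiply
        torch.cuda.synchronize()
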
