-
Notifications
You must be signed in to change notification settings - Fork 24.8k
Adds DLPack support #57110
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Adds DLPack support #57110
Changes from 1 commit
a72a579
56af6f3
aa107ef
6aaacd6
86dfedd
68d35ef
70bc04b
df70f26
a73fb64
7134d76
27b9639
b827087
8d60bbe
2059638
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
- Loading branch information
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -7097,11 +7097,9 @@ def compare_strides(s1, s2, div): | |
@skipMeta | ||
@dtypes(*torch.testing.get_all_dtypes()) | ||
def test_dlpack_capsule_conversion(self, device, dtype): | ||
# DLpack does not explicitly support bool | ||
# DLpack does not explicitly support bool (xref dmlc/dlpack#75) | ||
# It does it through uint8 type | ||
if dtype is torch.bool: | ||
return | ||
if 'xla' in device: | ||
if dtype is torch.bool or 'xla' in device: | ||
return | ||
x = make_tensor((5,), device, dtype, low=-9, high=9) | ||
z = from_dlpack(to_dlpack(x)) | ||
|
@@ -7110,11 +7108,7 @@ def test_dlpack_capsule_conversion(self, device, dtype): | |
@skipMeta | ||
@dtypes(*torch.testing.get_all_dtypes()) | ||
def test_dlpack_protocol_conversion(self, device, dtype): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We should add a test that compares the DLPack semantics with our .numpy() and from_numpy() semantics. If/when NumPy implements the protocol we could really validate that the behavior is the same There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Now, I think I understood it, There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yes: tensors with gradients or the conjugate bit set, I think, can't be exported. For importing we should verify that the underlying memory is shared by writing to it on both CPU and CUDA. NumPy arrays can also be non-writable, which we check for on import. I'm not sure what (if any) special properties DLPack capsules have that PyTorch can't emulate. |
||
# DLpack does not explicitly support bool | ||
# It does it through uint8 type | ||
if dtype is torch.bool: | ||
return | ||
if 'xla' in device: | ||
if dtype is torch.bool or 'xla' in device: | ||
return | ||
x = make_tensor((5,), device, dtype, low=-9, high=9) | ||
z = from_dlpack(x) | ||
10000
|
@@ -7123,57 +7117,38 @@ def test_dlpack_protocol_conversion(self, device, dtype): | |
@skipMeta | ||
@dtypes(*torch.testing.get_all_dtypes()) | ||
def test_dlpack_conversion_with_streams(self, device, dtype): | ||
# DLpack does not explicitly support bool | ||
# It does it through uint8 type | ||
if dtype is torch.bool: | ||
return | ||
if 'xla' in device: | ||
if dtype is torch.bool or 'xla' in device: | ||
return | ||
# Create a stream where the tensor will reside | ||
if device == 'cuda': | ||
x = make_tensor((5,), device, dtype, low=-9, high=9) | ||
stream = torch.cuda.Stream() | ||
with torch.cuda.stream(stream): | ||
assert stream.query() | ||
z = from_dlpack(x) | ||
assert not stream.query() | ||
assert stream.query() | ||
stream.synchronize() | ||
self.assertEqual(z, x) | ||
|
||
@skipMeta | ||
@dtypes(*torch.testing.get_all_dtypes()) | ||
def test_dlpack_conversion_with_diff_streams(self, device, dtype): | ||
# DLpack does not explicitly support bool | ||
# It does it through uint8 type | ||
if dtype is torch.bool: | ||
return | ||
if 'xla' in device: | ||
if dtype is torch.bool or 'xla' in device: | ||
return | ||
if device == 'cuda': | ||
from torch._C import _from_dlpack | ||
x = make_tensor((5,), device, dtype, low=-9, high=9) | ||
stream_a = torch.cuda.Stream() | ||
stream_b = torch.cuda.Stream() | ||
with torch.cuda.stream(stream_a): | ||
assert stream_a.query() | ||
assert stream_b.query() | ||
z = _from_dlpack(x.__dlpack__(stream_b.cuda_stream)) | ||
# sync in stream a forces the stream b work to be completed | ||
assert not stream_a.query() | ||
assert not stream_b.query() | ||
stream_a.synchronize() | ||
assert stream_a.query() | ||
assert stream_b.query() | ||
stream_b.synchronize() | ||
self.assertEqual(z, x) | ||
|
||
@skipMeta | ||
@dtypes(*torch.testing.get_all_dtypes()) | ||
def test_dlpack_tensor_invalid_stream(self, device, dtype): | ||
# DLpack does not explicitly support bool | ||
# It does it through uint8 type | ||
if dtype is torch.bool: | ||
return | ||
if 'xla' in device: | ||
if dtype is torch.bool or 'xla' in device: | ||
return | ||
with self.assertRaises(TypeError): | ||
x = make_tensor((5,), device, dtype, low=-9, high=9) | ||
|
Uh oh!
There was an error while loading. Please reload this page.