8000 Update from_dlpack tests and documentation (#70543) · cyyever/pytorch_private@c6d6645 · GitHub
[go: up one dir, main page]

Skip to content

Commit c6d6645

Browse files
kurtamohlercyyever
authored andcommitted
Update from_dlpack tests and documentation (#70543)
Summary: Part of pytorch/pytorch#58742 Pull Request resolved: pytorch/pytorch#70543 Reviewed By: soulitzer Differential Revision: D34172475 Pulled By: mruberry fbshipit-source-id: d498764b8651a8b7a19181b3421aeebf28a5db2b (cherry picked from commit 05332f164c4317e46e1242fdae483204e0412ef3)
1 parent 76e2fd3 commit c6d6645

File tree

3 files changed

+45
-3
lines changed

3 files changed

+45
-3
lines changed

docs/source/torch.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,7 @@ Creation Ops
5858
as_tensor
5959
as_strided
6060
from_numpy
61+
from_dlpack
6162
frombuffer
6263
zeros
6364
zeros_like

test/test_torch.py

Lines changed: 43 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4604,11 +4604,44 @@ def test_dlpack_conversion_with_streams(self, device, dtype):
46044604
stream.synchronize()
46054605
self.assertEqual(z, x)
46064606

4607+
@skipMeta
4608+
@onlyNativeDeviceTypes
4609+
@dtypes(*get_all_dtypes(include_bool=False))
4610+
def test_from_dlpack(self, device, dtype):
4611+
x = make_tensor((5,), device, dtype)
4612+
y = torch.from_dlpack(x)
4613+
self.assertEqual(x, y)
4614+
4615+
@skipMeta
4616+
@onlyNativeDeviceTypes
4617+
@dtypes(*get_all_dtypes(include_bool=False))
4618+
def test_from_dlpack_noncontinguous(self, device, dtype):
4619+
x = make_tensor((25,), device, dtype).reshape(5, 5)
4620+
4621+
y1 = x[0]
4622+
y1_dl = torch.from_dlpack(y1)
4623+
self.assertEqual(y1, y1_dl)
4624+
4625+
y2 = x[:, 0]
4626+
y2_dl = torch.from_dlpack(y2)
4627+
self.assertEqual(y2, y2_dl)
4628+
4629+
y3 = x[1, :]
4630+
y3_dl = torch.from_dlpack(y3)
4631+
self.assertEqual(y3, y3_dl)
4632+
4633+
y4 = x[1]
4634+
y4_dl = torch.from_dlpack(y4)
4635+
self.assertEqual(y4, y4_dl)
4636+
4637+
y5 = x.t()
4638+
y5_dl = torch.from_dlpack(y5)
4639+
self.assertEqual(y5, y5_dl)
4640+
46074641
@skipMeta
46084642
@onlyCUDA
46094643
@dtypes(*get_all_dtypes(include_bool=False))
46104644
def test_dlpack_conversion_with_diff_streams(self, device, dtype):
4611-
from torch._C import _from_dlpack
46124645
stream_a = torch.cuda.Stream()
46134646
stream_b = torch.cuda.Stream()
46144647
# DLPack protocol helps establish a correct stream order
@@ -4617,11 +4650,19 @@ def test_dlpack_conversion_with_diff_streams(self, device, dtype):
46174650
# in the current stream to make sure that it was correctly populated.
46184651
with torch.cuda.stream(stream_a):
46194652
x = make_tensor((5,), device, dtype) + 1
4620-
z = _from_dlpack(x.__dlpack__(stream_b.cuda_stream))
4653+
z = torch.from_dlpack(x.__dlpack__(stream_b.cuda_stream))
46214654
stream_a.synchronize()
46224655
stream_b.synchronize()
46234656
self.assertEqual(z, x)
46244657

4658+
@skipMeta
4659+
@onlyNativeDeviceTypes
4660+
@dtypes(*get_all_dtypes(include_bool=False))
4661+
def test_from_dlpack_dtype(self, device, dtype):
4662+
x = make_tensor((5,), device, dtype)
4663+
y = torch.from_dlpack(x)
4664+
assert x.dtype == y.dtype
4665+
46254666
@skipMeta
46264667
@onlyCUDA
46274668
def test_dlpack_default_stream(self, device):

torch/_torch_docs.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1041,7 +1041,7 @@ def merge_dicts(*dicts):
10411041
:func:`torch.frombuffer` creates a tensor that always shares memory from objects that
10421042
implement the buffer protocol.
10431043
1044-
:func:`torch.utils.dlpack.from_dlpack` creates a tensor that always shares memory from
1044+
:func:`torch.from_dlpack` creates a tensor that always shares memory from
10451045
DLPack capsules.
10461046
10471047
Args:

0 commit comments

Comments
 (0)
0