@@ -4604,11 +4604,44 @@ def test_dlpack_conversion_with_streams(self, device, dtype):
4604
4604
stream .synchronize ()
4605
4605
self .assertEqual (z , x )
4606
4606
4607
+ @skipMeta
4608
+ @onlyNativeDeviceTypes
4609
+ @dtypes (* get_all_dtypes (include_bool = False ))
4610
+ def test_from_dlpack (self , device , dtype ):
4611
+ x = make_tensor ((5 ,), device , dtype )
4612
+ y = torch .from_dlpack (x )
4613
+ self .assertEqual (x , y )
4614
+
4615
+ @skipMeta
4616
+ @onlyNativeDeviceTypes
4617
+ @dtypes (* get_all_dtypes (include_bool = False ))
4618
+ def test_from_dlpack_noncontinguous (self , device , dtype ):
4619
+ x = make_tensor ((25 ,), device , dtype ).reshape (5 , 5 )
4620
+
4621
+ y1 = x [0 ]
4622
+ y1_dl = torch .from_dlpack (y1 )
4623
+ self .assertEqual (y1 , y1_dl )
4624
+
4625
+ y2 = x [:, 0 ]
4626
+ y2_dl = torch .from_dlpack (y2 )
4627
+ self .assertEqual (y2 , y2_dl )
4628
+
4629
+ y3 = x [1 , :]
4630
+ y3_dl = torch .from_dlpack (y3 )
4631
+ self .assertEqual (y3 , y3_dl )
4632
+
4633
+ y4 = x [1 ]
4634
+ y4_dl = torch .from_dlpack (y4 )
4635
+ self .assertEqual (y4 , y4_dl )
4636
+
4637
+ y5 = x .t ()
4638
+ y5_dl = torch .from_dlpack (y5 )
4639
+ self .assertEqual (y5 , y5_dl )
4640
+
4607
4641
@skipMeta
4608
4642
@onlyCUDA
4609
4643
@dtypes (* get_all_dtypes (include_bool = False ))
4610
4644
def test_dlpack_conversion_with_diff_streams (self , device , dtype ):
4611
- from torch ._C import _from_dlpack
4612
4645
stream_a = torch .cuda .Stream ()
4613
4646
stream_b = torch .cuda .Stream ()
4614
4647
# DLPack protocol helps establish a correct stream order
@@ -4617,11 +4650,19 @@ def test_dlpack_conversion_with_diff_streams(self, device, dtype):
4617
4650
# in the current stream to make sure that it was correctly populated.
4618
4651
with torch .cuda .stream (stream_a ):
4619
4652
x = make_tensor ((5 ,), device , dtype ) + 1
4620
- z = _from_dlpack (x .__dlpack__ (stream_b .cuda_stream ))
4653
+ z = torch . from_dlpack (x .__dlpack__ (stream_b .cuda_stream ))
4621
4654
stream_a .synchronize ()
4622
4655
stream_b .synchronize ()
4623
4656
self .assertEqual (z , x )
4624
4657
4658
+ @skipMeta
4659
+ @onlyNativeDeviceTypes
4660
+ @dtypes (* get_all_dtypes (include_bool = False ))
4661
+ def test_from_dlpack_dtype (self , device , dtype ):
4662
+ x = make_tensor ((5 ,), device , dtype )
4663
+ y = torch .from_dlpack (x )
4664
+ assert x .dtype == y .dtype
4665
+
4625
4666
@skipMeta
4626
4667
@onlyCUDA
4627
4668
def test_dlpack_default_stream (self , device ):
0 commit comments