@@ -91,6 +91,41 @@ def __dlpack__(self, stream=None):
91
91
capsule = dlpack .make_py_capsule (managed )
92
92
return capsule
93
93
94
+
95
+ if cupy is not None :
96
+
97
+ class DLPackGPUBuf (BaseBuf ):
98
+
99
+ has_dlpack = None
100
+ dev_type = None
101
+
102
+ def __init__ (self , typecode , initializer ):
103
+ self ._buf = cupy .array (initializer , dtype = typecode )
104
+ self .has_dlpack = hasattr (self ._buf , '__dlpack_device__' )
105
+ # TODO(leofang): test CUDA managed memory?
106
+ if cupy .cuda .runtime .is_hip :
107
+ self .dev_type = dlpack .DLDeviceType .kDLROCM
108
+ else :
109
+ self .dev_type = dlpack .DLDeviceType .kDLCUDA
110
+
111
+ def __del__ (self ):
112
+ if not pypy and sys .getrefcount (self ._buf ) > 2 :
113
+ raise RuntimeError ('dlpack: possible reference leak' )
114
+
115
+ def __dlpack_device__ (self ):
116
+ if self .has_dlpack :
117
+ return self ._buf .__dlpack_device__ ()
118
+ else :
119
+ return (self .dev_type , self ._buf .device .id )
120
+
121
+ def __dlpack__ (self , stream = None ):
122
+ cupy .cuda .get_current_stream ().synchronize ()
123
+ if self .has_dlpack :
124
+ return self ._buf .__dlpack__ (stream = - 1 )
125
+ else :
126
+ return self ._buf .toDlpack ()
127
+
128
+
94
129
# ---
95
130
96
131
class CAIBuf (BaseBuf ):
@@ -426,12 +461,21 @@ def testNotContiguous(self):
426
461
@unittest .skipIf (array is None , 'array' )
427
462
@unittest .skipIf (dlpack is None , 'dlpack' )
428
463
class TestMessageSimpleDLPackCPUBuf (unittest .TestCase ,
429
- BaseTestMessageSimpleArray ):
464
+ BaseTestMessageSimpleArray ):
430
465
431
466
def array (self , typecode , initializer ):
432
467
return DLPackCPUBuf (typecode , initializer )
433
468
434
469
470
+ @unittest .skipIf (cupy is None , 'cupy' )
471
+ @unittest .skipIf (dlpack is None , 'dlpack' )
472
+ class TestMessageSimpleDLPackGPUBuf (unittest .TestCase ,
473
+ BaseTestMessageSimpleArray ):
474
+
475
+ def array (self , typecode , initializer ):
476
+ return DLPackGPUBuf (typecode , initializer )
477
+
478
+
435
479
@unittest .skipIf (array is None , 'array' )
436
480
class TestMessageSimpleCAIBuf (unittest .TestCase ,
437
481
BaseTestMessageSimpleArray ):
0 commit comments