10000 add simple message tests via DLPack GPU support · mpi4py/mpi4py@0780e63 · GitHub
[go: up one dir, main page]

Skip to content

Commit 0780e63

Browse files
committed
add simple message tests via DLPack GPU support
1 parent 79f64d5 commit 0780e63

File tree

2 files changed

+46
-2
lines changed

2 files changed

+46
-2
lines changed

test/arrayimpl.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -424,7 +424,7 @@ def __dlpack_device__(self):
424424
def __dlpack__(self, stream=None):
425425
cupy.cuda.get_current_stream().synchronize()
426426
if self.has_dlpack:
427-
return self.array.__dlpack__(stream)
427+
return self.array.__dlpack__(stream=-1)
428428
else:
429429
return self.array.toDlpack()
430430

test/test_msgspec.py

Lines changed: 45 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,41 @@ def __dlpack__(self, stream=None):
9191
capsule = dlpack.make_py_capsule(managed)
9292
return capsule
9393

94+
95+
if cupy is not None:
96+
97+
class DLPackGPUBuf(BaseBuf):
98+
99+
has_dlpack = None
100+
dev_type = None
101+
102+
def __init__(self, typecode, initializer):
103+
self._buf = cupy.array(initializer, dtype=typecode)
104+
self.has_dlpack = hasattr(self._buf, '__dlpack_device__')
105+
# TODO(leofang): test CUDA managed memory?
106+
if cupy.cuda.runtime.is_hip:
107+
self.dev_type = dlpack.DLDeviceType.kDLROCM
108+
else:
109+
self.dev_type = dlpack.DLDeviceType.kDLCUDA
110+
111+
def __del__(self):
112+
if not pypy and sys.getrefcount(self._buf) > 2:
113+
raise RuntimeError('dlpack: possible reference leak')
114+
115+
def __dlpack_device__(self):
116+
if self.has_dlpack:
117+
return self._buf.__dlpack_device__()
118+
else:
119+
return (self.dev_type, self._buf.device.id)
120+
121+
def __dlpack__(self, stream=None):
122+
cupy.cuda.get_current_stream().synchronize()
123+
if self.has_dlpack:
124+
return self._buf.__dlpack__(stream=-1)
125+
else:
126+
return self._buf.toDlpack()
127+
128+
94129
# ---
95130

96131
class CAIBuf(BaseBuf):
@@ -426,12 +461,21 @@ def testNotContiguous(self):
426461
@unittest.skipIf(array is None, 'array')
427462
@unittest.skipIf(dlpack is None, 'dlpack')
428463
class TestMessageSimpleDLPackCPUBuf(unittest.TestCase,
429-
BaseTestMessageSimpleArray):
464+
BaseTestMessageSimpleArray):
430465

431466
def array(self, typecode, initializer):
432467
return DLPackCPUBuf(typecode, initializer)
433468

434469

470+
@unittest.skipIf(cupy is None, 'cupy')
471+
@unittest.skipIf(dlpack is None, 'dlpack')
472+
class TestMessageSimpleDLPackGPUBuf(unittest.TestCase,
473+
BaseTestMessageSimpleArray):
474+
475+
def array(self, typecode, initializer):
476+
return DLPackGPUBuf(typecode, initializer)
477+
478+
435479
@unittest.skipIf(array is None, 'array')
436480
class TestMessageSimpleCAIBuf(unittest.TestCase,
437481
BaseTestMessageSimpleArray):

0 commit comments

Comments
 (0)
0