Adds DLPack support #57110
Changes from 1 commit: ExternalStream
@@ -1062,19 +1062,32 @@ def __dlpack__(self, stream=None):
         of the consumer library.

         Args:
-            stream (integer or None): Object that represents a stream and provides
-                a `synchronize` method. Optional.
+            stream (integer or None): A Python integer representing a pointer
+                to a stream. `stream` is provided by the consumer to the producer
+                to instruct the producer to ensure that operations can safely be
+                performed on the array. The pointer must be a positive integer or
+                -1. If stream is -1, the value may be used by the consumer to
+                signal "producer must not perform any synchronization". Optional.
         """
-        if isinstance(stream, torch.cuda.Stream) or hasattr(stream, 'synchronize'):
-            stream.synchronize()
-        elif stream is not None and type(stream) is int:
-            # currently it is not possible in PyTorch to create a stream
-            # from a given pointer
-            raise TypeError('Can\'t create a stream from an integer in PyTorch')
+        if stream is not None and type(stream) is not int:
+            raise TypeError('stream must be ``int`` or ``None``')
+        elif stream is not None and stream != -1:
+            if self.device.type in ('cuda', 'rocm'):
+                stream = torch.cuda.streams.ExternalStream(stream)
+                # Only synchronize on different streams
+                if stream != torch.cuda.current_stream():
+                    event = torch.cuda.Event()
+                    event.record(stream)
+                    torch.cuda.current_stream().wait_event(event)
Review comment: Should be the other way around: the consumer's stream waits for the producer's (PyTorch's) stream.
Suggested change
Reply: thanks!
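The body of the suggested change is not visible above; read from the comment, with `stream` being the consumer's ExternalStream as in the diff, the fix would presumably swap the record/wait targets:

    # Record the event on the producer's (PyTorch's) current stream and
    # have the consumer's stream wait on it, instead of the reverse.
    event = torch.cuda.Event()
    event.record(torch.cuda.current_stream())
    stream.wait_event(event)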
         return torch.utils.dlpack.to_dlpack(self)
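To make the stream protocol concrete, here is a minimal consumer-side sketch. It assumes the `Tensor.__dlpack__` method this diff adds and the existing `cuda_stream` pointer attribute of `torch.cuda.Stream`; the tensor and stream names are illustrative:

    import torch

    t = torch.arange(4, device='cuda')   # producer-side tensor
    consumer = torch.cuda.Stream()       # the consumer library's stream

    # Pass the raw cudaStream_t pointer; per the diff above, PyTorch wraps
    # it in an ExternalStream and synchronizes before exporting.
    capsule = t.__dlpack__(stream=consumer.cuda_stream)

    # stream=-1 signals "producer must not perform any synchronization".
    capsule_nosync = t.__dlpack__(stream=-1)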
-    def __dlpack_device__(self) -> Tuple[Int, Int]:
+    def __dlpack_device__(self) -> Tuple[int, int]:
         # TODO(ecastill)
         # Add support for the following devices
         # CPU = 1 CPU_PINNED = 3 OPENCL = 4 VULKAN = 7
         # METAL = 8 VPI = 9
         dlpack_ids = {'cpu': 1, 'cuda': 2, 'rocm': 10}
         idx = self.device.index if self.device.index is not None else 0
Review comment: This still has TODO's. I think it would be nice if this returned …

Review comment: This is a bit out-of-scope, but if we were to support these other devices, how would the stream support work?

Review comment: I think what @rgommers meant is to change the return type of this function:

    def __dlpack_device__(self) -> Tuple[enum.IntEnum, int]:
         # TODO(ecastill) detect HIP or CUDA
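For context on that suggestion, a method-body sketch with the IntEnum return type the reviewers describe. The device codes come from the TODO comment above and match the dlpack.h header of this era; the enum itself is illustrative, not API added by this commit:

    import enum
    from typing import Tuple

    class DLDeviceType(enum.IntEnum):
        # DLPack device codes, matching the values listed in the TODO
        kDLCPU = 1
        kDLGPU = 2
        kDLCPUPinned = 3
        kDLOpenCL = 4
        kDLVulkan = 7
        kDLMetal = 8
        kDLVPI = 9
        kDLROCM = 10

    def __dlpack_device__(self) -> Tuple[enum.IntEnum, int]:
        dlpack_ids = {'cpu': DLDeviceType.kDLCPU,
                      'cuda': DLDeviceType.kDLGPU,
                      'rocm': DLDeviceType.kDLROCM}
        idx = self.device.index if self.device.index is not None else 0
        return (dlpack_ids[self.device.type], idx)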