8000 Adds DLPack support by emcastillo · Pull Request #57110 · pytorch/pytorch · GitHub
Closed
wants to merge 14 commits
Changes from 1 commit
Use ExternalStream
Emilio Castillo committed Sep 8, 2021
commit 56af6f329ada96307528a0ccf12c6766d796edf9
29 changes: 21 additions & 8 deletions torch/_tensor.py
@@ -1062,19 +1062,32 @@ def __dlpack__(self, stream=None):
of the consumer library.

Args:
stream (integer or None): Object that represents a stream and provides
a `synchronize` method. Optional.
stream (integer or None): A Python integer representing a pointer
to a stream. `stream` is provided by the consumer to the producer
to instruct the producer to ensure that operations can safely be
performed on the array. The pointer must be a positive integer or
-1. If stream is -1, the value may be used by the consumer to
signal "producer must not perform any synchronization". Optional.
"""
if isinstance(stream, torch.cuda.Stream) or hasattr(stream, 'synchronize'):
stream.synchronize()
elif stream is not None and type(stream) is int:
if stream is not None and type(stream) is not int:
# it is currently not possible in PyTorch to create a stream
# from a given pointer
raise TypeError('Can\'t create a stream from an integer in PyTorch')

raise TypeError('stream must be ``int`` or ``none``')
elif stream is not None and stream != -1:
if self.device.type in ('cuda', 'rocm'):
stream = torch.cuda.streams.ExternalStream(stream)
# Only synchronize on different streams
if stream != torch.cuda.current_stream():
event = torch.cuda.Event()
event.record(stream)
torch.cuda.current_stream().wait_event(event)
Contributor
Should be the other way around: The consumer's stream waits for the producer (PyTorch)'s stream:

Suggested change
event.record(stream)
torch.cuda.current_stream().wait_event(event)
event.record(torch.cuda.current_stream())
stream.wait_event(event)

Collaborator Author
thanks!
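The ordering the reviewer describes can be illustrated with a plain-Python mock (no CUDA required): the producer (PyTorch) records an event on its current stream, and the consumer's stream then waits on that event. `MockEvent` and `MockStream` are illustrative stand-ins, not torch API.

```python
class MockEvent:
    """Stand-in for torch.cuda.Event: remembers which stream recorded it."""
    def __init__(self):
        self.recorded_on = None

    def record(self, stream):
        self.recorded_on = stream.name


class MockStream:
    """Stand-in for a CUDA stream: logs the events it waits on."""
    def __init__(self, name):
        self.name = name
        self.waited_on = []

    def wait_event(self, event):
        self.waited_on.append(event.recorded_on)


producer = MockStream("pytorch_current")  # PyTorch's current stream
consumer = MockStream("consumer_ext")     # stream passed in via __dlpack__

event = MockEvent()
event.record(producer)       # mirrors event.record(torch.cuda.current_stream())
consumer.wait_event(event)   # mirrors stream.wait_event(event)
```

With the original (pre-suggestion) ordering, the producer would instead wait on an event recorded by the consumer, which does not guarantee the producer's pending work is visible to the consumer.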

return torch.utils.dlpack.to_dlpack(self)
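As a minimal sketch of the stream-argument contract spelled out in the docstring above (None, -1, or a positive integer pointer), the validation can be factored into a pure function; `validate_dlpack_stream` is a hypothetical helper for illustration, not part of the PR.

```python
def validate_dlpack_stream(stream):
    """Classify the ``stream`` argument per the __dlpack__ docstring.

    Hypothetical helper: None means no consumer stream was given, -1 means
    the consumer opts out of synchronization, and any other value must be
    a positive integer pointer to the consumer's stream.
    """
    if stream is None:
        return "no-sync-requested"
    if type(stream) is not int:
        raise TypeError("stream must be ``int`` or ``None``")
    if stream == -1:
        return "consumer-opts-out"
    if stream <= 0:
        raise ValueError("stream pointer must be a positive integer or -1")
    return "sync-with-consumer-stream"
```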

def __dlpack_device__(self) -> Tuple[Int, Int]:
def __dlpack_device__(self) -> Tuple[int, int]:
# TODO(ecastill)
# Add support for the following devices
# CPU = 1 CPU_PINNED = 3 OPENCL = 4 VULKAN = 7
# METAL = 8 VPI = 9
dlpack_ids = {'cpu': 1, 'cuda': 2, 'rocm': 10}
idx = self.device.index if self.device.index is not None else 0
Collaborator
This still has TODOs. I think it would be nice if this returned Tuple[enum.IntEnum, int] as in the spec: https://data-apis.org/array-api/latest/API_specification/array_object.html#dlpack-device-self

Collaborator Author
This is a bit out-of-scope but if we were to support these other devices, how would the stream support work?
Should it be ignored in environments where a stream does not make any sense?

Contributor
I think what @rgommers meant is to change the return type of this function:

def __dlpack_device__(self) -> Tuple[enum.IntEnum, int]:

This is a bit out-of-scope but if we were to support these other devices, how would the stream support work?
Should it be ignored in environments where a stream does not make any sense?

For __dlpack_device__ whether a device has the concept of stream/queue doesn't matter. For __dlpack__ stream can be Any:
https://data-apis.org/array-api/latest/API_specification/array_object.html#dlpack-self-stream-none
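The reviewer's `Tuple[enum.IntEnum, int]` suggestion could look like the sketch below. The enum values follow the device IDs quoted in the diff's comments (the `kDL*` names are my assumption, per dlpack.h), and `dlpack_device` is a hypothetical stand-in for the body of `__dlpack_device__`.

```python
import enum


class DLDeviceType(enum.IntEnum):
    # Values as listed in the diff's TODO comments; names assumed from dlpack.h
    kDLCPU = 1
    kDLGPU = 2          # CUDA
    kDLCPUPinned = 3
    kDLOpenCL = 4
    kDLVulkan = 7
    kDLMetal = 8
    kDLVPI = 9
    kDLROCM = 10


def dlpack_device(device_type, index):
    """Hypothetical mirror of __dlpack_device__ returning (IntEnum, int)."""
    ids = {'cpu': DLDeviceType.kDLCPU,
           'cuda': DLDeviceType.kDLGPU,
           'rocm': DLDeviceType.kDLROCM}
    # Default to index 0 when the torch device carries no explicit index
    return (ids[device_type], index if index is not None else 0)
```

Because `IntEnum` members compare equal to plain ints, this stays backward compatible with consumers expecting `Tuple[int, int]`.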

# TODO(ecastill) detect HIP or CUDA
15 changes: 9 additions & 6 deletions torch/utils/dlpack.py
@@ -16,12 +16,12 @@


def from_dlpack(ext_tensor) -> torch.Tensor:
"""from_dlpack(dlpack) -> Tensor
"""from_dlpack(ext_tensor) -> Tensor

Decodes a DLPack to a tensor.

Args:
dlpack: a PyCapsule object with the dltensor
ext_tensor: a PyCapsule object with the dltensor

The tensor will share the memory with the object represented
in the dlpack.
@@ -32,14 +32,17 @@ def from_dlpack(ext_tensor) -> torch.Tensor:
The tensor from an external library that will be converted
to a PyTorch one.
"""
if hasattr(dlpack, '__dlpack__'):
device = dlpack.__dlpack_device__()
if hasattr(ext_tensor, '__dlpack__'):
device = ext_tensor.__dlpack_device__()
# device is either CUDA or ROCm, we need to pass the current
# stream
if device[0] in (2, 10):
stream = torch.cuda.stream.current_stream('cuda:{}'.format(device[1]))
# Should we pass an id, or can the producer accept a stream object?
stream = torch.cuda.current_stream('cuda:{}'.format(device[1]))
# cuda_stream is the pointer to the stream and it is a public
# attribute, but it is not documented
dlpack = ext_tensor.__dlpack__(stream=stream.cuda_stream)
else:
dlpack = ext_tensor.__dlpack__()
else:
# Old versions just call the converter
dlpack = ext_tensor
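The dispatch logic in this hunk (new `__dlpack__` protocol for CUDA/ROCm vs. the legacy capsule fallback) can be exercised with a toy producer; `ToyProducer` and `dispatch` are illustrative names, and the string `"capsule"` stands in for a real PyCapsule.

```python
class ToyProducer:
    """Minimal object implementing the DLPack protocol for illustration."""
    def __dlpack_device__(self):
        return (1, 0)      # (kDLCPU, device index)

    def __dlpack__(self, stream=None):
        return "capsule"   # stands in for a real PyCapsule


def dispatch(ext_tensor):
    """Hypothetical mirror of from_dlpack's protocol dispatch."""
    if hasattr(ext_tensor, '__dlpack__'):
        device = ext_tensor.__dlpack_device__()
        if device[0] in (2, 10):
            # CUDA or ROCm: the real code passes the current stream's
            # cuda_stream pointer here; 0 is a placeholder
            return ext_tensor.__dlpack__(stream=0)
        return ext_tensor.__dlpack__()
    # Old versions: the argument is already a capsule
    return ext_tensor
```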