Add `torch.cuda.streams.ExternalStream` by emcastillo · Pull Request #57781 · pytorch/pytorch · GitHub

Add torch.cuda.streams.ExternalStream #57781


Closed
wants to merge 10 commits

Conversation

emcastillo
Copy link
Collaborator
@emcastillo emcastillo commented May 6, 2021

This is required in #57110 (comment)

We need to provide means to synchronize on externally allocated streams for dlpack support in python array data api.

cc @mruberry @rgommers @leofang @asi1024 @kmaehashi

@facebook-github-bot
Copy link
Contributor
facebook-github-bot commented May 6, 2021

💊 CI failures summary and remediations

As of commit 9ce8e84 (more details on the Dr. CI page):


  • 1/1 failures possibly* introduced in this PR
    • 1/1 non-scanned failure(s)

This comment was automatically generated by Dr. CI.

@emcastillo emcastillo force-pushed the external-stream branch 2 times, most recently from 1058f07 to 7bf6a9c Compare May 6, 2021 23:09
@leofang
Copy link
Contributor
leofang commented May 6, 2021

cc: @gmarkall @pentschev @jakirkham @kkraus14 for vis

@codecov
Copy link
codecov bot commented May 7, 2021

Codecov Report

Merging #57781 (c431a8c) into master (ccd7141) will increase coverage by 0.07%.
The diff coverage is 78.69%.

❗ Current head c431a8c differs from pull request most recent head 2dbe52d. Consider uploading reports for the commit 2dbe52d to get more accurate results

@@            Coverage Diff             @@
##           master   #57781      +/-   ##
==========================================
+ Coverage   76.44%   76.52%   +0.07%     
==========================================
  Files        1990     2008      +18     
  Lines      199690   201099    +1409     
==========================================
+ Hits       152651   153885    +1234     
- Misses      47039    47214     +175     

@ngimel ngimel requested review from mruberry and ezyang May 7, 2021 22:52
@ngimel ngimel added the triaged label (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module) May 7, 2021
@ngimel
Copy link
Collaborator
ngimel commented May 7, 2021

See previous attempt #39567

@ezyang
Copy link
Contributor
ezyang commented May 10, 2021

This is basically the same thing as the previous PR which @dzhulgakov convinced me that it would be OK to take as a stopgap. So I guess I'm OK with this too.

I'm a little concerned about the ROCm failure, though I swear I've seen this event failure in some other PRs too.

@facebook-github-bot
Copy link
Contributor

@ezyang has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

@ezyang
Copy link
Contributor
ezyang commented May 10, 2021

giving others a chance to review

@ezyang
Copy link
Contributor
ezyang commented May 18, 2021

According to @ngimel the only blocker is making sure you don't silently overflow the 32 external streams you get (just need to add a little error checking there)

@emcastillo
Copy link
Collaborator Author

@ezyang thanks for the follow-up!

don't silently overflow the 32 external streams you get (just need to add a little error checking there)

I believe this is not as easy as it might seem.
Once a stream is assigned, we currently have no way to release it and mark it as available again in the pool; instead, the pool silently overwrites older streams:

import torch
streams = []
stream_ids = set()
for i in range(128):
    streams.append(torch.cuda.Stream())
    if streams[-1].cuda_stream in stream_ids:
        raise ValueError(f'Stream {i} reallocated')
    else:
        stream_ids.add(streams[-1].cuda_stream)
Running this yields:

Traceback (most recent call last):
  File "stream_stress.py", line 8, in <module>
    raise ValueError(f'Stream {i} reallocated')
ValueError: Stream 32 reallocated

There is no destructor in CUDAStream that allows marking a stream as used or unused, so I believe the only cleaner way is to hold the externally allocated streams in a hash map (mapping to their LeakyStreamInternals) instead of a "pool" for external streams, and to subclass CUDAStream so that the subclass fills the map when a stream is requested and removes the entry when its objects are deleted. What do you think?

@dzhulgakov
Copy link
Collaborator

Have we considered doing pointer bit packing tricks? :) Stream/CudaStream are currently 1 byte for device type, 1 byte for device index and 4 bytes for stream id: https://github.com/pytorch/pytorch/blob/master/c10/core/Stream.h#L120

We don't seem to pass around the raw stream id much at all. So let's say we extend StreamId to 6 bytes. On x86-64, pointers are only 48 bits long, so we can easily pack the cudaStream_t pointer into the cuda_stream_id. Of course we need to distinguish the stream type, but that's easy: currently CUDA streams allocate ids with the following scheme: https://github.com/pytorch/pytorch/blob/master/c10/cuda/CUDAStream.cpp#L76 - they are always small (7 bits), so those values aren't valid pointers either. This way we can just assume that any StreamId greater than, say, 256 is reinterpreted as a cudaStream_t (which is defined as a pointer, by the way).

I know it sounds crazy, but it's a small change from today's code and doesn't bring any need for bookkeeping or static limitations.

@mruberry
Copy link
Collaborator

Have we considered doing pointer bit packing tricks? :) Stream/CudaStream are currently 1 byte for device type, 1 byte for device index and 4 bytes for stream id: https://github.com/pytorch/pytorch/blob/master/c10/core/Stream.h#L120

We don't seem to pass around the raw stream id much at all. So let's say we extend StreamId to 6 bytes. On x86-64, pointers are only 48 bits long, so we can easily pack the cudaStream_t pointer into the cuda_stream_id. Of course we need to distinguish the stream type, but that's easy: currently CUDA streams allocate ids with the following scheme: https://github.com/pytorch/pytorch/blob/master/c10/cuda/CUDAStream.cpp#L76 - they are always small (7 bits), so those values aren't valid pointers either. This way we can just assume that any StreamId greater than, say, 256 is reinterpreted as a cudaStream_t (which is defined as a pointer, by the way).

I know it sounds crazy, but it's a small change from today's code and doesn't bring any need for bookkeeping or static limitations.

To be sure I understand this proposal: you're suggesting that instead of subclassing CUDAStream with something like CUDAStreamExternal, CUDAStream itself simply understands whether it is a "native" or "external" stream, and packs, unpacks, and destroys itself appropriately?

For example, the unpack method would change to interrogate the id, and then infer whether the CUDAStream was a "native" or "external" stream, and configure itself properly. Similarly the CUDAStream class would get a destructor that would check if the stream was "external" and, if so, destroy the stream?

That makes sense to me if I'm understanding you correctly, @dzhulgakov. We should be very careful to document this.
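The id-discrimination scheme under discussion can be sketched in plain Python. The helper names below are illustrative assumptions, not PyTorch's actual implementation; the 256 threshold and 48-bit pointer width come from the comments above.

```python
# Sketch of the proposed scheme: small ids index the internal stream pool,
# while any id at or above the threshold is a raw cudaStream_t pointer.
# Names and structure are hypothetical; the real code lives in c10/cuda.

POINTER_BITS = 48          # x86-64 user-space pointers fit in 48 bits
NATIVE_ID_LIMIT = 256      # pool stream ids are small (7 bits), so any
                           # StreamId >= 256 cannot collide with them

def pack_stream_id(value, is_external):
    """Pack either a small pool index or a raw stream pointer into a StreamId."""
    if is_external:
        assert value >= NATIVE_ID_LIMIT, "pointer collides with pool id range"
        return value & ((1 << POINTER_BITS) - 1)
    assert value < NATIVE_ID_LIMIT
    return value

def unpack_stream_id(stream_id):
    """Return (is_external, value) by looking at the magnitude of the id."""
    if stream_id >= NATIVE_ID_LIMIT:
        return True, stream_id   # reinterpret as a cudaStream_t pointer
    return False, stream_id      # index into the internal stream pool

# Pool ids round-trip unchanged; a pointer-sized id is recognized as external.
assert unpack_stream_id(pack_stream_id(5, False)) == (False, 5)
assert unpack_stream_id(pack_stream_id(0x7F0000ABCDE0, True)) == (True, 0x7F0000ABCDE0)
```

The key property is that no bookkeeping is needed: the magnitude of the id alone tells a reader whether it is a pool index or an external pointer.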

@ezyang
Copy link
Contributor
ezyang commented May 20, 2021

Bit packing is very reasonable and might just moot our entire conversation here. However, just to clarify the earlier point:

@emcastillo Sorry, I wasn't clear, I was suggesting something very very dumb: just hard error at the point we overflow. Don't bother trying to find unused stream ids that can be reused. That should be simple to implement, and blow up very loudly when someone tries to use too many external streams.
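The "blow up loudly" behavior suggested here can be sketched with a hypothetical bounded pool; the class and names are illustrative, and PyTorch's real pool logic differs.

```python
# Instead of silently wrapping around and reusing slot 0 after the pool
# fills up, raise a hard error on the first over-capacity acquisition.

class BoundedStreamPool:
    def __init__(self, capacity=32):
        self.capacity = capacity
        self.streams = []

    def acquire(self, stream):
        if len(self.streams) >= self.capacity:
            raise RuntimeError(
                f"exceeded the {self.capacity} available external streams")
        self.streams.append(stream)
        return len(self.streams) - 1

pool = BoundedStreamPool(capacity=32)
for i in range(32):
    pool.acquire(object())   # placeholder for a real stream handle

try:
    pool.acquire(object())   # the 33rd acquisition fails loudly
except RuntimeError as e:
    print(e)
```

This trades the silent-reuse bug shown in the earlier reproduction script for an immediate, obvious failure.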

@dzhulgakov
Copy link
Collaborator

Similarly the CUDAStream class would get a destructor that would check if the stream was "external" and, if so, destroy the stream?

We can go even simpler and say that the ownership of the external stream stays with the caller. So we don't even need the destructor

@mruberry
Copy link
Collaborator

Similarly the CUDAStream class would get a destructor that would check if the stream was "external" and, if so, destroy the stream?

We can go even simpler and say that the ownership of the external stream stays with the caller. So we don't even need the destructor

That sounds good. Is that the DLPack ownership model for streams? That the producer manages the stream lifetime independent of the consumer?

@emcastillo
Copy link
Collaborator Author

I think it is reasonable to assume that the stream life-cycle is the responsibility of the caller.
Thanks @dzhulgakov! I will try to implement the bit-packing trick :D

@leofang
Copy link
Contributor
leofang commented May 22, 2021

In the DLPack protocol, the consumer hands its stream to the producer, and the producer decides what to do (to sync immediately or just establish the stream order).
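The handshake described here can be illustrated with mock objects. These classes are toy stand-ins for the protocol shape only, not the real torch/cupy API; the stream value is a made-up pointer.

```python
# DLPack stream exchange, mocked: the consumer hands *its* stream to the
# producer's __dlpack__, and the producer decides how to order work on it.

class Producer:
    def __dlpack__(self, stream=None):
        # A real producer would record an event on its own stream and make
        # `stream` wait on it (or synchronize outright), then export a capsule.
        self.synchronized_with = stream
        return "capsule"

class Consumer:
    def __init__(self, stream_ptr):
        self.stream_ptr = stream_ptr

    def from_dlpack(self, producer):
        # Hand our stream to the producer so it can establish stream order.
        return producer.__dlpack__(stream=self.stream_ptr)

producer = Producer()
consumer = Consumer(stream_ptr=0x7F00DEAD0000)
capsule = consumer.from_dlpack(producer)
assert producer.synchronized_with == consumer.stream_ptr
```

This is why the consumer side needs a way to wrap an externally allocated stream: the stream it passes in may not have come from its own pool.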

@ezyang
Copy link
Contributor
ezyang commented May 27, 2021

I'm slowly grinding through the bit patterns

Emilio Castillo added 3 commits May 27, 2021 04:51
Fix assert and tests

Fix mask and checks

fix

fix
@ezyang
Copy link
Contributor
ezyang commented May 27, 2021

I verified that the bit twiddling correctly works in the sign extended case using Crux (thanks @atomb) using the following program:

#include <stdint.h>
#include <crucible.h>

uint64_t pack(int8_t device_index, int8_t device_type, int64_t id) {
  uint64_t bits = (uint64_t)((uint8_t)(device_type))
          << 56 |
      (uint64_t)((uint8_t)(device_index)) << 48 |
      // Remove the sign extension part of the 64-bit address because
      // the id might be used to hold a pointer.
      ((uint64_t)(id) & ((1ULL << 48) - 1));
  return bits;
}

int main() {
  int8_t device_index = crucible_int8_t("device_index");
  int8_t device_type = crucible_int8_t("device_type");
  int64_t id = crucible_int64_t("id");

  for (int i = 47; i < 64; i++) {
    assuming(((uint64_t)id) >> i & 1 == 1);
  }

  uint64_t bits = pack(device_index, device_type, id);

  // Re-extend the sign of stream_id
  uint64_t mask = (1ULL << 47);
  int64_t out_id =
      ((int64_t)(bits & 0xFFFFFFFFFFFFull) ^ mask) - mask;
  bits >>= 48;
  int8_t out_device_index = (int8_t)(bits & 0xFFFFull);
  bits >>= 8;
  int8_t out_device_type = (int8_t)(bits);

  check(device_index == out_device_index);
  check(device_type == out_device_type);
  check(id == out_id);
}

Crux seems to work pretty well and is able to detect bugs when I slightly modify the program to introduce bugs (e.g., forgetting to write 1ULL instead of 1, or forgetting to cast to the same size unsigned integer type before widening).
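For readers who want to replay the round trip without Crux, the pack/unpack logic above transliterates directly into plain Python. Python ints are unbounded, so the masks make the 64-bit truncation and sign re-extension explicit; the concrete values below are arbitrary test inputs.

```python
# Same bit layout as the C program: 1 byte device type, 1 byte device index,
# 48 bits of (possibly sign-extended) stream id.

MASK48 = (1 << 48) - 1

def pack(device_index, device_type, stream_id):
    return ((device_type & 0xFF) << 56 |
            (device_index & 0xFF) << 48 |
            (stream_id & MASK48))       # drop the sign-extension bits

def unpack(bits):
    sign = 1 << 47
    stream_id = ((bits & MASK48) ^ sign) - sign   # re-extend bit 47
    device_index = (bits >> 48) & 0xFF
    device_type = (bits >> 56) & 0xFF
    return device_index, device_type, stream_id

# A sign-extended, pointer-like id (high bits all ones) survives the round trip,
# and so does a small non-negative pool id.
neg_id = -0x1000
assert unpack(pack(3, 1, neg_id)) == (3, 1, neg_id)
assert unpack(pack(0, 0, 42)) == (0, 0, 42)
```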

AT_ASSERT(ptr);
return ptr->stream;
int64_t stream_id = unwrap().id();
// the stream_ids managed from the pool have only 8 bits
Copy link
Contributor

this comment out of date now

Copy link
Contributor
@ezyang ezyang left a comment


this is so great, thank you so much for writing this

@facebook-github-bot
Copy link
Contributor

@ezyang has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

@ezyang
Copy link
Contributor
ezyang commented May 27, 2021

need to skip tests on rocm

ezyang added 2 commits May 27, 2021 10:43
Signed-off-by: Edward Z. Yang <ezyang@fb.com>
Signed-off-by: Edward Z. Yang <ezyang@fb.com>
@facebook-github-bot
Copy link
Contributor

@ezyang has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

@ezyang
Copy link
Contributor
ezyang commented May 27, 2021

Needs internal updates (related to int to int64)

@ezyang
Copy link
Contributor
ezyang commented Jun 2, 2021

needs more internal updates (related to CUDA)

@facebook-github-bot
Copy link
Contributor

@ezyang merged this pull request in d7ef9b7.

@facebook-github-bot
Copy link
Contributor

This pull request has been reverted by 689a5ed.

@emcastillo
Copy link
Collaborator Author

what happened? 😥

@ngimel
Copy link
Collaborator
ngimel commented Jun 5, 2021

Windows doesn't like ctypes https://app.circleci.com/pipelines/github/pytorch/pytorch/331129/workflows/3ecc5f8a-e782-4718-beca-1aea7a7433ce/jobs/13909012


@ezyang
Copy link
Contributor
ezyang commented Jun 6, 2021

@emcastillo If you could push a new copy of the PR that would be a help. My guess is the use of ctypes to access CUDA library functionality is what is not working. We can instead stick all the functions we need in torch/csrc/cuda/shared/cudart.cpp and that should be good enough to get the test going.

@emcastillo
Copy link
Collaborator Author

@emcastillo If you could push a new copy of the PR that would be a help. My guess is the use of ctypes to access CUDA library functionality is what is not working. We can instead stick all the functions we need in torch/csrc/cuda/shared/cudart.cpp and that should be good enough to get the test going.

Sure! Let me work on it.

deniskokarev pushed a commit to deniskokarev/pytorch that referenced this pull request Jun 9, 2021
Summary:
This is required in pytorch#57110 (comment)

We need to provide means to synchronize on externally allocated streams for dlpack support in python array data api.

cc mruberry rgommers leofang asi1024 kmaehashi

Pull Request resolved: pytorch#57781

Reviewed By: mrshenli

Differential Revision: D28326365

Pulled By: ezyang

fbshipit-source-id: b67858c8033949951b49a3d319f649884dfd0a91
facebook-github-bot pushed a commit that referenced this pull request Jun 14, 2021
Summary:
Previous is #57781

We add now two CUDA bindings to avoid using ctypes to fix a windows issue.
However, we use ctypes to allocate the stream and create its pointer
(we can do this with a 0-dim tensor too if it feels better).

CC. ezyang rgommers ngimel mruberry

Pull Request resolved: #59527

Reviewed By: albanD

Differential Revision: D29053062

Pulled By: ezyang

fbshipit-source-id: 661e7e58de98b1bdb7a0871808cd41d91fe8f13f
Labels
cla signed · Merged · open source · Reverted · triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)
8 participants