Add `torch.cuda.streams.ExternalStream` (#57781) · pytorch/pytorch@d7ef9b7 · GitHub

Commit d7ef9b7

Emilio Castillo authored and facebook-github-bot committed
Add torch.cuda.streams.ExternalStream (#57781)
Summary:
This is required in #57110 (comment). We need to provide a means to synchronize on externally allocated streams for DLPack support in the Python array data API.

cc mruberry rgommers leofang asi1024 kmaehashi

Pull Request resolved: #57781

Reviewed By: mrshenli

Differential Revision: D28326365

Pulled By: ezyang

fbshipit-source-id: b67858c8033949951b49a3d319f649884dfd0a91
1 parent c769300 commit d7ef9b7
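
An illustrative usage sketch of the new API (not part of the diff), following the ctypes approach used by the test added below; it assumes Linux, a visible GPU, and that the CUDA runtime symbols are already loaded into the process:

import ctypes
import torch

torch.cuda.init()                 # make sure the CUDA runtime is loaded
lib = ctypes.CDLL(None)           # look up symbols already present in the process
raw = ctypes.c_void_p()
assert lib.cudaStreamCreate(ctypes.byref(raw)) == 0

ext = torch.cuda.streams.ExternalStream(raw.value)
assert ext.cuda_stream == raw.value   # PyTorch stores the pointer but does not own it

with torch.cuda.stream(ext):          # work submitted here runs on the external stream
    torch.ones(4, device="cuda").sum()

ext.synchronize()
lib.cudaStreamDestroy(ctypes.c_void_p(raw.value))  # the caller keeps ownership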

File tree

13 files changed, +175 −34 lines changed

aten/src/ATen/hip/impl/HIPStreamMasqueradingAsCUDA.h

Lines changed: 5 additions & 0 deletions
@@ -90,6 +90,11 @@ inline getStreamFromPoolMasqueradingAsCUDA(const bool isHighPriority = false, De
   return HIPStreamMasqueradingAsCUDA(getStreamFromPool(isHighPriority, device));
 }
 
+inline HIPStreamMasqueradingAsCUDA
+getStreamFromExternalMasqueradingAsCUDA(hipStream_t ext_stream, DeviceIndex device) {
+  return HIPStreamMasqueradingAsCUDA(getStreamFromExternal(ext_stream, device));
+}
+
 inline HIPStreamMasqueradingAsCUDA getDefaultHIPStreamMasqueradingAsCUDA(DeviceIndex device_index = -1) {
   return HIPStreamMasqueradingAsCUDA(getDefaultHIPStream(device_index));
 }

c10/core/Stream.h

Lines changed: 26 additions & 9 deletions
@@ -12,7 +12,7 @@ namespace c10 {
 /// numbering system which is not visible to the user. HOWEVER, we
 /// guarantee that StreamId 0 is always a valid stream, and corresponds
 /// to some sort of "default" stream.
-using StreamId = int32_t;
+using StreamId = int64_t;
 
 // NB: I decided not to call the above StreamIndex to avoid confusion with
 // DeviceIndex. This way, you access device index with index(), and stream id
@@ -119,21 +119,38 @@ class Stream final {
   // that the bitmasking code below is updated accordingly!
   static_assert(sizeof(DeviceType) == 1, "DeviceType is not 8-bit");
   static_assert(sizeof(DeviceIndex) == 1, "DeviceIndex is not 8-bit");
-  static_assert(sizeof(StreamId) == 4, "DeviceIndex is not 32-bit");
+  static_assert(sizeof(StreamId) == 8, "StreamId is not 64-bit");
   // Concat these together into a 64-bit integer
   // See Note [Hazard when concatenating signed integers]
   uint64_t bits = static_cast<uint64_t>(static_cast<uint8_t>(device_type()))
-      << 48 |
-      static_cast<uint64_t>(static_cast<uint8_t>(device_index())) << 32 |
-      static_cast<uint64_t>(static_cast<uint32_t>(id()));
+      << 56 |
+      static_cast<uint64_t>(static_cast<uint8_t>(device_index())) << 48 |
+      // Remove the sign extension part of the 64-bit address because
+      // the id might be used to hold a pointer.
+      (static_cast<uint64_t>(id()) & ((1ull << 48) - 1));
+  TORCH_INTERNAL_ASSERT(
+      static_cast<DeviceIndex>((bits >> 48) & 0xFFull) == device_index(),
+      "DeviceIndex is not correctly packed");
+  TORCH_INTERNAL_ASSERT(
+      static_cast<DeviceType>((bits >> 56)) == device_type(),
+      "DeviceType is not correctly packed");
+  // Re-extend the sign of stream_id for checking
+  uint64_t mask = (1ull << 47);
+  TORCH_INTERNAL_ASSERT(
+      static_cast<StreamId>(((bits & 0xFFFFFFFFFFFFull) ^ mask) - mask) ==
+          id(),
+      "StreamId is not correctly packed");
   return bits;
 }
 
 static Stream unpack(uint64_t bits) {
-  const auto stream_id = static_cast<StreamId>(bits & 0xFFFFFFFFull);
-  bits >>= 32;
-  const auto device_index = static_cast<DeviceIndex>(bits & 0xFFFFull);
-  bits >>= 16;
+  // Re-extend the sign of stream_id
+  uint64_t mask = (1ull << 47);
+  const auto stream_id =
+      (static_cast<StreamId>(bits & 0xFFFFFFFFFFFFull) ^ mask) - mask;
+  bits >>= 48;
+  const auto device_index = static_cast<DeviceIndex>(bits & 0xFFull);
+  bits >>= 8;
   const auto device_type = static_cast<DeviceType>(bits);
   TORCH_CHECK(isValidDeviceType(device_type));
   // Unfortunately, we can't check if the StreamId is valid here; it
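
To make the new packing concrete, a small illustrative Python sketch (not part of the diff) of the same bit layout, including the sign re-extension needed when the id carries a pointer value:

# Python ints stand in for the fixed-width C++ integers in Stream::pack()/unpack().
def pack(device_type: int, device_index: int, stream_id: int) -> int:
    # device_type in bits 56-63, device_index in bits 48-55,
    # low 48 bits hold the (possibly sign-extended, pointer-valued) stream id
    return (device_type << 56) | (device_index << 48) | (stream_id & ((1 << 48) - 1))

def unpack(bits: int):
    mask = 1 << 47
    stream_id = ((bits & ((1 << 48) - 1)) ^ mask) - mask  # re-extend the sign
    device_index = (bits >> 48) & 0xFF
    device_type = bits >> 56
    return device_type, device_index, stream_id

# Round-trips even when the id is a sign-extended pointer-like value:
assert unpack(pack(1, 0, -0x1000)) == (1, 0, -0x1000)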

c10/cuda/CUDAStream.cpp

Lines changed: 53 additions & 16 deletions
@@ -10,6 +10,7 @@
 #include <mutex>
 #include <vector>
 
+#include <iostream>
 namespace c10 {
 namespace cuda {
 
@@ -41,6 +42,7 @@ static DeviceIndex num_gpus = -1;
 static constexpr int kStreamsPerPoolBits = 5;
 static constexpr int kStreamsPerPool = 1 << kStreamsPerPoolBits;
 static constexpr unsigned int kDefaultFlags = cudaStreamNonBlocking;
+static constexpr int kStreamTypeBits = 3;
 
 // Note: lower numbers are higher priorities, zero is default priority
 static int kHighPriority = -1;
@@ -73,13 +75,13 @@ static std::array<LeakyStreamInternals, kStreamsPerPool>
 // ~~~~~~~~~~~~~~~~~~~~~~~~~~
 // How do we assign stream IDs?
 //
-// -- 25 bits --  -- 2 bits --  -- 5 bits -----
-// zeros          StreamIdType  stream id index
+// -- 57 bits --  -- 5 bits -----  -- 3 bits --
+// zeros          stream id index  StreamIdType
 //
 // Where StreamIdType:
-// 00 = default stream
-// 01 = low priority stream
-// 10 = high priority stream
+// 000 = default stream or externally allocated if id[63:3] != 0
+// 001 = low priority stream
+// 010 = high priority stream
 //
 // This is not really for efficiency; it's just easier to write the code
 // to extract the index if we do this with bitmasks :)
@@ -95,11 +97,16 @@ static std::array<LeakyStreamInternals, kStreamsPerPool>
 // could work around this with something like
 // https://stackoverflow.com/questions/13150449/efficient-unsigned-to-signed-cast-avoiding-implementation-defined-behavior
 // but it seems a bit overkill for this.
-
+//
+// Also, externally managed stream pointers (cudaStream_t) can be directly
+// stored in the Id field, so in that case we need to check the stream
+// alignment. The IdType uses an additional bit to match the 64-bit address
+// alignment, making it easy to identify an external stream when its value
+// satisfies (X & 7) > 0.
 enum class StreamIdType : uint8_t {
   DEFAULT = 0x0,
   LOW = 0x1,
   HIGH = 0x2,
+  EXT = 0x3,
 };
 
 std::ostream& operator<<(std::ostream& stream, StreamIdType s) {
@@ -113,28 +120,39 @@ std::ostream& operator<<(std::ostream& stream, StreamIdType s) {
     case StreamIdType::HIGH:
       stream << "HIGH";
       break;
+    case StreamIdType::EXT:
+      stream << "EXT";
+      break;
     default:
       stream << static_cast<uint8_t>(s);
       break;
   }
   return stream;
 }
 
-// StreamId is 32-bit, so we can just rely on regular promotion rules.
+// StreamId is 64-bit, so we can just rely on regular promotion rules.
 // We rely on streamIdIndex and streamIdType being non-negative;
 // see Note [Hazard when concatenating signed integers]
 
 static inline StreamIdType streamIdType(StreamId s) {
-  return static_cast<StreamIdType>(s >> kStreamsPerPoolBits);
+  int mask_for_type = (1 << kStreamTypeBits) - 1;
+  if (s && ((s & mask_for_type) == 0)) {
+    // Externally allocated streams have their id set to the cudaStream_t
+    // pointer, so the bits corresponding to the type will be 0 and would
+    // otherwise collide with the default stream.
+    return StreamIdType::EXT;
+  }
+  return static_cast<StreamIdType>(s & mask_for_type);
 }
 
 static inline size_t streamIdIndex(StreamId s) {
-  return static_cast<size_t>(s & ((1 << kStreamsPerPoolBits) - 1));
+  return static_cast<size_t>(
+      (s >> kStreamTypeBits) & ((1 << kStreamsPerPoolBits) - 1));
 }
 
 StreamId makeStreamId(StreamIdType st, size_t si) {
-  return (static_cast<StreamId>(st) << kStreamsPerPoolBits) |
-      static_cast<StreamId>(si);
+  return (static_cast<StreamId>(si) << kStreamTypeBits) |
+      static_cast<StreamId>(st);
 }
 
 template <typename T, typename A>
@@ -251,7 +269,7 @@ static void initCUDAStreamsOnce() {
 
 // Helper to verify the GPU index is valid
 static inline void check_gpu(DeviceIndex device_index) {
-  AT_ASSERT(device_index >= 0 && device_index < num_gpus);
+  TORCH_INTERNAL_ASSERT(device_index >= 0 && device_index < num_gpus);
 }
 
 // Helper to determine the index of the stream to return
@@ -305,9 +323,16 @@ CUDAStream CUDAStream_fromInternals(const LeakyStreamInternals* ptr) {
 } // anonymous namespace
 
 cudaStream_t CUDAStream::stream() const {
-  auto ptr = CUDAStream_internals(*this);
-  AT_ASSERT(ptr);
-  return ptr->stream;
+  int64_t stream_id = unwrap().id();
+  if (streamIdType(stream_id) == StreamIdType::EXT) {
+    // In this case this is an externally allocated stream;
+    // we don't need to manage its life cycle.
+    return reinterpret_cast<cudaStream_t>(stream_id);
+  } else {
+    auto ptr = CUDAStream_internals(*this);
+    TORCH_INTERNAL_ASSERT(ptr);
+    return ptr->stream;
+  }
 }
 
 // Returns a stream from the requested pool
@@ -334,6 +359,18 @@ CUDAStream getStreamFromPool(
   return CUDAStream_fromInternals(&low_priority_streams[device_index][idx]);
 }
 
+CUDAStream getStreamFromExternal(
+    cudaStream_t ext_stream,
+    DeviceIndex device_index) {
+  return CUDAStream(
+      CUDAStream::UNCHECKED,
+      // The stream pointer will be the actual id
+      Stream(
+          Stream::UNSAFE,
+          c10::Device(DeviceType::CUDA, device_index),
+          reinterpret_cast<int64_t>(ext_stream)));
+}
+
 CUDAStream getDefaultCUDAStream(DeviceIndex device_index) {
   initCUDAStreamsOnce();
   if (device_index == -1) {
@@ -354,7 +391,7 @@ CUDAStream getCurrentCUDAStream(DeviceIndex device_index) {
 void setCurrentCUDAStream(CUDAStream stream) {
   initCUDAStreamsOnce();
   auto ptr = CUDAStream_internals(stream);
-  AT_ASSERT(ptr);
+  TORCH_INTERNAL_ASSERT(ptr);
   current_streams[ptr->device_index] = ptr;
 }
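
The stream-id layout above, as an illustrative Python sketch (not part of the diff): the type occupies the low 3 bits, the pool index the next 5, and a nonzero id whose low 3 bits are all zero is taken to be an 8-byte-aligned external cudaStream_t pointer:

kStreamTypeBits = 3
kStreamsPerPoolBits = 5
DEFAULT, LOW, HIGH, EXT = 0x0, 0x1, 0x2, 0x3

def make_stream_id(stream_type: int, index: int) -> int:
    return (index << kStreamTypeBits) | stream_type

def stream_id_type(s: int) -> int:
    mask = (1 << kStreamTypeBits) - 1
    # A nonzero id with all type bits clear is an aligned external pointer.
    if s and (s & mask) == 0:
        return EXT
    return s & mask

def stream_id_index(s: int) -> int:
    return (s >> kStreamTypeBits) & ((1 << kStreamsPerPoolBits) - 1)

assert stream_id_type(make_stream_id(HIGH, 7)) == HIGH
assert stream_id_index(make_stream_id(HIGH, 7)) == 7
assert stream_id_type(0x7F00_DEAD_B000) == EXT  # an aligned external pointer value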

c10/cuda/CUDAStream.h

Lines changed: 10 additions & 0 deletions
@@ -195,6 +195,16 @@ class C10_CUDA_API CUDAStream {
 TORCH_API CUDAStream
 getStreamFromPool(const bool isHighPriority = false, DeviceIndex device = -1);
 
+/**
+ * Get a CUDAStream from an externally allocated one.
+ *
+ * This is mainly for interoperability with different libraries where we
+ * want to operate on a non-torch allocated stream for data exchange or
+ * similar purposes.
+ */
+TORCH_API CUDAStream
+getStreamFromExternal(cudaStream_t ext_stream, DeviceIndex device_index);
+
 /**
 * Get the default CUDA stream, for the passed CUDA device, or for the
 * current device if no device index is passed. The default stream is
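
A hedged interop sketch of what this enables at the Python level (not part of the diff; assumes CuPy and a CUDA GPU are available): both libraries queue work on the same externally allocated stream.

import cupy
import torch

cupy_stream = cupy.cuda.Stream()                          # CuPy owns and frees this stream
ext = torch.cuda.streams.ExternalStream(cupy_stream.ptr)  # wrap it without taking ownership

with cupy_stream:
    cupy.arange(1 << 20).sum()                            # CuPy kernel on the shared stream
with torch.cuda.stream(ext):
    torch.ones(1 << 20, device="cuda").sum()              # torch kernels ordered after it

ext.synchronize()  # keep `cupy_stream` alive for as long as `ext` is in use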

caffe2/contrib/opencl/context.h

Lines changed: 1 addition & 1 deletion
@@ -64,7 +64,7 @@ class OpenCLContext final {
     CopyBytes<SrcContext, DstContext>(n * meta.itemsize(), src, dst);
   }
 
-  void SwitchToDevice(int a, ...) {
+  void SwitchToDevice(int64_t a, ...) {
     auto& ctx = GetSingleton();
     CAFFE_ENFORCE(a < ctx.devices.size());
     ctx.device = ctx.devices[a];

caffe2/core/context.h

Lines changed: 1 addition & 1 deletion
@@ -63,7 +63,7 @@ class TORCH_API CPUContext final : public BaseContext {
 
   ~CPUContext() noexcept override {}
 
-  inline void SwitchToDevice(int /*stream_id*/) override {}
+  inline void SwitchToDevice(int64_t /*stream_id*/) override {}
 
   using BaseContext::SwitchToDevice;
 

caffe2/core/context_base.h

Lines changed: 1 addition & 1 deletion
@@ -42,7 +42,7 @@ class TORCH_API BaseContext {
   /* Sorry for the naming, will get rid of this in future diff */
   virtual DeviceType device_type() const = 0;
 
-  virtual void SwitchToDevice(int /*stream_id*/) = 0;
+  virtual void SwitchToDevice(int64_t /*stream_id*/) = 0;
 
   inline void SwitchToDevice() {
     SwitchToDevice(0);

caffe2/ideep/utils/ideep_context.h

Lines changed: 1 addition & 1 deletion
@@ -23,7 +23,7 @@ class IDEEPContext final : public BaseContext {
 
   ~IDEEPContext() noexcept override {}
 
-  inline void SwitchToDevice(int /*stream_id*/) {}
+  inline void SwitchToDevice(int64_t /*stream_id*/) {}
   using BaseContext::SwitchToDevice;
 
   inline void WaitEvent(const Event& ev) {

test/test_cuda.py

Lines changed: 33 additions & 0 deletions
@@ -1,6 +1,8 @@
 from itertools import repeat, chain, product
 from typing import NamedTuple
 import collections
+import contextlib
+import ctypes
 import gc
 import io
 import os
@@ -1314,6 +1316,37 @@ def test_record_stream_on_shifted_view(self):
 
         self.assertNotEqual(try_realloc.data_ptr(), data_ptr)
 
+    @contextlib.contextmanager
+    def _get_external_stream(self, device):
+        lib = ctypes.cdll.LoadLibrary(None)
+        p = ctypes.c_void_p()
+        with device:
+            try:
+                out = lib.cudaStreamCreate(ctypes.byref(p))
+                yield p.value
+            finally:
+                out = lib.cudaStreamDestroy(ctypes.c_ulonglong(p.value))
+
+    @skipIfRocm
+    @unittest.skipIf(IS_SANDCASTLE or IS_REMOTE_GPU, "Does not work on Sandcastle")
+    def test_external_streams(self):
+        device = torch.cuda.device(0)
+        with self._get_external_stream(device) as stream_v:
+            ext_stream = torch.cuda.streams.ExternalStream(stream_v)
+            self.assertEqual(stream_v, ext_stream.cuda_stream)
+            self.assertEqual(ext_stream.device.index, device.idx)
+
+    @skipIfRocm
+    @unittest.skipIf(not TEST_MULTIGPU, "detected only one GPU")
+    @unittest.skipIf(IS_SANDCASTLE or IS_REMOTE_GPU, "Does not work on Sandcastle")
+    def test_external_streams_multi_device(self):
+        device = torch.cuda.device(1)
+        with self._get_external_stream(device) as stream_v:
+            ext_stream = torch.cuda.streams.ExternalStream(
+                stream_v, device=device)
+            self.assertEqual(stream_v, ext_stream.cuda_stream)
+            self.assertEqual(ext_stream.device.index, device.idx)
+
     def test_noncontiguous_pinned_memory(self):
         # See issue #3266
         x = torch.arange(0, 10).view((2, 5))

torch/_C/__init__.pyi.in

Lines changed: 1 addition & 1 deletion
@@ -840,7 +840,7 @@ class _CudaStreamBase:
     cuda_stream: _int
     priority: _int
 
-    def __new__(self, priority: _int = 0, _cdata: _int = 0) -> _CudaStreamBase: ...
+    def __new__(self, priority: _int = 0, _cdata: _int = 0, stream_ptr: _int = 0) -> _CudaStreamBase: ...
     def query(self) -> _bool: ...
     def synchronize(self) -> None: ...
     def priority_range(self) -> Tuple[_int, _int]: ...

torch/csrc/cuda/Stream.cpp

Lines changed: 11 additions & 4 deletions
@@ -22,11 +22,12 @@ static PyObject * THCPStream_pynew(
 
   int priority = 0;
   uint64_t cdata = 0;
+  uint64_t stream_ptr = 0;
 
   // NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays)
-  static char *kwlist[] = {"priority", "_cdata", nullptr};
+  static char *kwlist[] = {"priority", "_cdata", "stream_ptr", nullptr};
   if (!PyArg_ParseTupleAndKeywords(
-      args, kwargs, "|iK", kwlist, &priority, &cdata)) {
+      args, kwargs, "|iKK", kwlist, &priority, &cdata, &stream_ptr)) {
     return nullptr;
   }
 
@@ -35,11 +36,17 @@ static PyObject * THCPStream_pynew(
     return nullptr;
   }
 
+  if (stream_ptr) {
+    TORCH_CHECK(priority == 0, "Priority was explicitly set for a external stream")
+  }
+
   at::cuda::CUDAStream stream =
     cdata ?
     at::cuda::CUDAStream::unpack(cdata) :
-    at::cuda::getStreamFromPool(
-      /* isHighPriority */ priority < 0 ? true : false);
+    stream_ptr ?
+    at::cuda::getStreamFromExternal(reinterpret_cast<cudaStream_t>(stream_ptr), current_device) :
+    at::cuda::getStreamFromPool(
+      /* isHighPriority */ priority < 0 ? true : false);
 
   THCPStream* self = (THCPStream *)ptr.get();
   self->cdata = stream.pack();
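
A short sketch of the argument handling added here (not part of the diff): stream_ptr and a non-default priority are mutually exclusive, so the TORCH_CHECK above fires if both are passed.

import torch

src = torch.cuda.Stream()                                  # any pool stream to point at
ext = torch.cuda.streams.ExternalStream(src.cuda_stream)   # fine: priority defaults to 0
try:
    torch.cuda.streams.ExternalStream(src.cuda_stream, priority=-1)
except RuntimeError as err:
    print(err)  # priority cannot be set for an externally allocated stream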

torch/cuda/streams.py

Lines changed: 23 additions & 0 deletions
@@ -112,6 +112,29 @@ def __repr__(self):
                 .format(self.device, self.cuda_stream))
 
 
+class ExternalStream(Stream):
+    r"""Wrapper around an externally allocated CUDA stream.
+
+    This class is used to wrap streams allocated in other libraries in order
+    to facilitate data exchange and multi-library interactions.
+
+    .. note:: This class doesn't manage the stream life-cycle; it is the user's
+        responsibility to keep the referenced stream alive while this class is
+        being used.
+
+    Args:
+        stream_ptr(int): Integer representation of the `cudaStream_t` value
+            allocated externally.
+        device(torch.device or int, optional): the device where the stream
+            was originally allocated. If the device is specified incorrectly,
+            subsequent launches using this stream may fail.
+    """
+
+    def __new__(cls, stream_ptr, device=None, **kwargs):
+        with torch.cuda.device(device):
+            return super(Stream, cls).__new__(cls, stream_ptr=stream_ptr, **kwargs)
+
+
 class Event(torch._C._CudaEventBase):
     r"""Wrapper around a CUDA event.
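
One more hedged sketch of the device argument (not part of the diff; assumes a second GPU is visible), mirroring the multi-device test above:

import torch

with torch.cuda.device(1):
    src = torch.cuda.Stream()        # the underlying cudaStream_t lives on device 1

# Tell ExternalStream which device the pointer belongs to; per the docstring,
# a wrong device can make later launches on this stream fail.
ext = torch.cuda.streams.ExternalStream(src.cuda_stream, device=torch.device("cuda", 1))
assert ext.device.index == 1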

Comments (0)