Rewrite autograd producer consumer stream sync logic (#151079) · pytorch/pytorch@a060f3d · GitHub
[go: up one dir, main page]

Skip to content

Commit a060f3d

Browse files
soulitzer authored and pytorchmergebot committed
Rewrite autograd producer consumer stream sync logic (#151079)
Also see previous work #142097. Pull Request resolved: #151079. Approved by: https://github.com/albanD
1 parent 2ce0b66 commit a060f3d

File tree

5 files changed

+182
-98
lines changed

5 files changed

+182
-98
lines changed

test/test_autograd.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13132,14 +13132,12 @@ def call_backward(x):
1313213132
for _ in range(2):
1313313133
test()
1313413134

13135-
# This fails because we currently sync to the default stream
1313613135
# AttributeError: module 'torch.mps' has no attribute 'default_stream'
1313713136
@skipIfMPS
1313813137
@unittest.skipIf(not torch.accelerator.is_available(), "requires accelerator")
1313913138
@unittest.skipIf(
1314013139
torch.accelerator.device_count() < 2, "accelerator count is less than 2"
1314113140
)
13142-
@unittest.expectedFailure
1314313141
def test_consumer_to_single_producer_case_3_correctness_non_default_ambient_stream(
1314413142
self,
1314513143
):
@@ -13313,7 +13311,6 @@ def test():
1331313311
# This test may spuriously fail on non-cuda accelerators (since we won't
1331413312
# be calling sleep)
1331513313
@unittest.skipIf(not TEST_CUDA, "requires CUDA")
13316-
@unittest.expectedFailure
1331713314
def test_side_stream_backward_overlap(self):
1331813315
# In case 2/3, we would designate the consumer as the accumulation
1331913316
# stream and naively, one might have the consumer wait for the producer

test/test_nestedtensor.py

Lines changed: 0 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -8379,16 +8379,6 @@ def f(values, offsets):
83798379
sample_match_fn=lambda device, sample: ("noncontig_holes" in sample.name),
83808380
name="broken_unflatten_backward",
83818381
),
8382-
# -> CPU device conversion backwards is broken
8383-
XFailRule(
8384-
error_type=RuntimeError,
8385-
error_msg="Unknown layout in record_stream_any_impl",
8386-
op_match_fn=lambda device, op: (op.full_name == "to"),
8387-
sample_match_fn=lambda device, sample: (
8388-
sample.kwargs.get("device", None) == "cpu"
8389-
),
8390-
name="broken_to_backward",
8391-
),
83928382
# sum() backward is not implemented for non-full reductions
83938383
XFailRule(
83948384
error_type=NotImplementedError,

torch/csrc/autograd/engine.cpp

Lines changed: 23 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1065,13 +1065,32 @@ void Engine::evaluate_function(
10651065
Node* func,
10661066
InputBuffer& inputs,
10671067
const std::shared_ptr<ReadyQueue>& cpu_ready_queue) {
1068-
// The InputBuffer::adds that supplied incoming grads took pains to
1069-
// ensure they're safe to consume in the context of the present
1070-
// func's stream (if applicable). So we guard onto that stream
1071-
// before working with the grads in any capacity.
1068+
// Locally set the current stream to func's associated stream
10721069
auto opt_parent_stream = (*func).stream();
10731070
c10::OptionalStreamGuard parent_stream_guard{opt_parent_stream};
10741071

1072+
// Ensure that the incoming gradients are ready
1073+
for (size_t pos = 0; pos < inputs.ready_events.size(); ++pos) {
1074+
if (!inputs.buffer[pos].defined()) {
1075+
continue;
1076+
}
1077+
const auto device = inputs.buffer[pos].device();
1078+
// TODO: Use at::accelerator::isAccelerator(device->type()) instead
1079+
bool is_accelerator =
1080+
device.is_cuda() || device.is_mtia() || device.is_privateuseone();
1081+
if (!is_accelerator) {
1082+
continue;
1083+
}
1084+
TORCH_INTERNAL_ASSERT(inputs.ready_events[pos].has_value());
1085+
TORCH_INTERNAL_ASSERT(inputs.ready_streams[pos].has_value());
1086+
TORCH_INTERNAL_ASSERT(opt_parent_stream.has_value());
1087+
// NOLINTNEXTLINE(bugprone-unchecked-optional-access)
1088+
if (opt_parent_stream.value() != inputs.ready_streams[pos].value()) {
1089+
// NOLINTNEXTLINE(bugprone-unchecked-optional-access)
1090+
opt_parent_stream->wait(inputs.ready_events[pos].value());
1091+
}
1092+
}
1093+
10751094
// If exec_info_ is not empty, we have to instrument the execution
10761095
auto& exec_info_ = graph_task->exec_info_;
10771096
if (!exec_info_.empty()) {

torch/csrc/autograd/input_buffer.cpp

Lines changed: 146 additions & 80 deletions
Original file line numberDiff line numberDiff line change
@@ -26,8 +26,13 @@ namespace {
2626
// See https://github.com/pytorch/pytorch/issues/60306
2727
// TODO: clean this up when https://github.com/pytorch/pytorch/issues/60306 is
2828
// improved
29-
void record_stream_any_impl(Variable& var, c10::Stream& stream) {
29+
void record_stream_any_impl(Variable& var, const c10::Stream& stream) {
3030
// NOLINTNEXTLINE(bugprone-unchecked-optional-access)
31+
32+
if (stream.device_index() != var.device().index()) {
33+
return;
34+
}
35+
3136
const auto guard = c10::impl::VirtualGuardImpl(device_of(var).value().type());
3237

3338
if (C10_UNLIKELY(at::isBatchedTensor(var))) {
@@ -126,99 +131,160 @@ static void accumulate(
126131
}
127132
}
128133

134+
// Note: [Stream sync contract when dealing with multi-deviced-ness]
135+
//
136+
// An operator can deal with multiple devices, e.g. if it does a device
137+
// transfer, etc. However, for the purpose of stream synchronization, the engine
138+
// is only aware of single canonical device/stream for each autograd Node.
139+
//
140+
// For the proper synchronization, the Node author should make sure of the
141+
// following:
142+
//
143+
// 1) A node consuming a gradient should wait on the canonical stream before
144+
// using it.
145+
// 2) A node producing a gradient should have it ready on the canonical
146+
// stream during node execution.
147+
//
148+
149+
// Note: [Autograd Producer-Consumer Stream Syncs]
150+
//
151+
// The producer-consumer stream syncs are partially handled in this method
152+
// and partially handled in the engine prior to the consumer's execution.
153+
// The logic here is mainly responsible for handling the synchronization needed
154+
// for accumulation and recording the event that the consumer should wait on
155+
// later. The corresponding wait and record_stream happens in the engine.
156+
//
157+
// First producer
158+
// ==============
159+
// There are several things we need to do upon seeing the first producer:
160+
// 1) Determine the accumulation stream (which may or may not be used):
161+
// case A) var's device matches consumer node's canonical device
162+
// (The producer node's canonical device may or may not match)
163+
// -> accumulator stream = consumer stream
164+
// case B) var's device matches producer node's canonical device
165+
// and does not match consumer node's canonical device
166+
// -> accumulator stream = producer stream
167+
// case C) var device matches neither
168+
// -> accumulator stream = var device's current stream
169+
// See Note [Stream sync contract when dealing with
170+
// multi-deviced-ness]
171+
// 2) Because we are the first producer, there's no accumulation necessary.
172+
// Just move var into the buffer.
173+
// 3) Update the ready_events and streams for the current position.
174+
// ready_events are events you need to wait for to ensure the corresponding
175+
// buffers are ready. The events are updated as we accumulate into the
176+
// buffer.
177+
//
178+
// Nth producer
179+
// ============
180+
// 1) Synchronize for accumulation. Accumulation operates on both the new
181+
// incoming gradient and the existing gradient in the buffer.
182+
// (i) wait stream and (ii) record stream to make sure both are ready to be
183+
// used on the accumulation stream.
184+
// 2) Accumulate on the accumulation stream
185+
// 3) Update the ready event and stream for the current position.
186+
//
129187
void InputBuffer::add(
130188
size_t pos,
131189
Variable&& var,
132-
const std::optional<c10::Stream>& opt_producer_stream,
133-
const std::optional<c10::Stream>& opt_consumer_stream) {
190+
const std::optional<c10::Stream>& opt_producer_stream_,
191+
const std::optional<c10::Stream>& opt_consumer_stream_) {
134192
TORCH_INTERNAL_ASSERT(pos < buffer.size());
193+
135194
if (!var.defined()) {
136195
return;
137196
}
138-
139-
// Switches to accumulate device
140-
// The device (and stream) chosen for accumulation is:
141-
// (1) var is not a CUDA/privateuse1 variable. Accumulation happens on var's
142-
// device. (2) var is a CUDA/privateuse1 variable and it, the consumer, and
143-
// the producer share the same device:
144-
// (2a) Uses the consumer's stream as the accumulation stream
145-
// (2b) Syncs the accumulation stream with the producer's stream (if
146-
// different) (2c) Accumulates.
147-
// (3) var is a CUDA/MTIA/privateuse1 variable and it shares a device with
148-
// the consumer but not the producer:
149-
// (3a) Uses the consumer's stream as the accumulation stream
150-
// (3b) Syncs the accumulation stream with the consumer device's default
151-
// stream (3c) Accumulates.
152-
// (4) var is a CUDA/MTIA/privateuse1 variable and it shares a device with
153-
// the producer but not the consumer:
154-
// (4a) Uses the producer device's default stream as the accumulation
155-
// stream (4b) Syncs the accumulation stream with the producer's
156-
// stream (4c) Accumulates.
157-
// (5) var is a CUDA/MTIA/privateuse1 variable and it does not share a device
158-
// with the consumer or producer.
159-
// Accumulation happens on the var device's default stream.
160-
161-
auto const device = device_of(var);
162-
TORCH_INTERNAL_ASSERT(device.has_value());
163-
std::optional<c10::Stream> opt_accumulate_stream = std::nullopt;
164-
const auto device_type = device->type();
165-
if (device->is_cuda() || device->is_mtia() || device->is_privateuseone()) {
166-
const auto on_producer =
167-
opt_producer_stream && device == opt_producer_stream->device();
168-
const auto on_consumer =
169-
opt_consumer_stream && device == opt_consumer_stream->device();
170-
171-
if (on_producer && on_consumer) {
172-
// (2a)
173-
opt_accumulate_stream = opt_consumer_stream;
174-
if (opt_accumulate_stream != opt_producer_stream) {
175-
// (2b)
176-
auto event = c10::Event{device_type};
177-
event.record(*opt_producer_stream);
178-
opt_accumulate_stream->wait(event);
179-
record_stream_any_impl(var, *opt_accumulate_stream);
180-
}
197+
const auto device = var.device();
198+
const auto device_type = device.type();
199+
// TODO: Use at::accelerator::isAccelerator(device->type()) instead
200+
bool is_accelerator =
201+
device.is_cuda() || device.is_mtia() || device.is_privateuseone();
202+
//
203+
// Non-accelerator case
204+
//
205+
if (!is_accelerator) {
206+
if (!buffer[pos].defined()) {
207+
buffer[pos] = std::move(var);
181208
} else {
182-
std::optional<c10::Stream> opt_sync_stream = std::nullopt;
183-
const auto guard = c10::impl::VirtualGuardImpl{device_type};
184-
if (on_consumer && !on_producer) {
185-
// (3a)
186-
opt_accumulate_stream = opt_consumer_stream;
187-
opt_sync_stream = guard.getDefaultStream(opt_consumer_stream->device());
188-
} else if (on_producer && !on_consumer) {
189-
// (4a)
190-
opt_accumulate_stream =
191-
guard.getDefaultStream(opt_producer_stream->device());
192-
opt_sync_stream = opt_producer_stream;
193-
} else {
194-
// (5)
195-
opt_accumulate_stream = guard.getDefaultStream(*device);
196-
}
197-
if (opt_sync_stream && (opt_accumulate_stream != opt_sync_stream)) {
198-
// (3b), (4b)
199-
c10::OptionalDeviceGuard device_guard{opt_sync_stream->device()};
200-
auto event = c10::Event{device_type};
201-
event.record(*opt_sync_stream);
202-
opt_accumulate_stream->wait(event);
203-
const auto guard = c10::impl::VirtualGuardImpl(device_type);
204-
record_stream_any_impl(var, *opt_accumulate_stream);
205-
}
209+
c10::OptionalDeviceGuard device_guard{device};
210+
accumulate(buffer, pos, std::move(var));
206211
}
212+
return;
207213
}
214+
// Handle the case where var is on an accelerator but producer node has no
215+
// canonical stream, e.g. this can happen if forward is DtoH
216+
const std::optional<c10::Stream>& opt_producer_stream =
217+
(opt_producer_stream_.has_value()
218+
? opt_producer_stream_
219+
: std::optional<c10::Stream>(
220+
at::accelerator::getCurrentStream(device.index())));
208221

209-
auto& old_var = buffer[pos];
210-
if (!old_var.defined()) {
222+
// opt_consumer_stream is always non-null when is_accelerator is true
223+
// when InputBuffer is used in the engine. InputBuffer is also called
224+
// elsewhere however! (e.g. other engine implementations)
225+
const std::optional<c10::Stream>& opt_consumer_stream =
226+
(opt_consumer_stream_.has_value()
227+
? opt_consumer_stream_
228+
: std::optional<c10::Stream>(
229+
at::accelerator::getCurrentStream(device.index())));
230+
231+
TORCH_INTERNAL_ASSERT(opt_consumer_stream && opt_producer_stream);
232+
233+
// See Note: [Autograd Producer-Consumer Stream Syncs]
234+
if (!opt_accum_streams[pos].has_value()) {
235+
// [ First producer ]
236+
TORCH_INTERNAL_ASSERT(!buffer[pos].defined());
237+
// 1)
238+
if (opt_consumer_stream->device() == device) {
239+
// Case A
240+
opt_accum_streams[pos] = opt_consumer_stream;
241+
if (*opt_consumer_stream != *opt_producer_stream) {
242+
// We will end up doing record_stream on the accumulation stream
243+
// (which is the consumer stream) later, but we also need to do
244+
// it here in case we don't end up accumulating.
245+
record_stream_any_impl(var, *opt_consumer_stream);
246+
}
247+
} else if (opt_producer_stream->device() == device) {
248+
// Case B
249+
opt_accum_streams[pos] = opt_producer_stream;
250+
} else {
251+
// Case C
252+
opt_accum_streams[pos] =
253+
at::accelerator::getCurrentStream(device.index());
254+
}
255+
// 2)
211256
buffer[pos] = std::move(var);
257+
// 3)
258+
auto event = c10::Event{device_type};
259+
event.record(*opt_producer_stream);
260+
ready_events[pos] = std::move(event);
261+
ready_streams[pos] = opt_producer_stream;
212262
} else {
213-
if (opt_accumulate_stream) {
214-
c10::OptionalStreamGuard stream_guard{opt_accumulate_stream};
215-
accumulate(buffer, pos, std::move(var));
216-
} else {
217-
// (1) non-CUDA/privateuse1 variable
218-
// Accumulation happens on variable's device
219-
c10::OptionalDeviceGuard device_guard{device};
220-
accumulate(buffer, pos, std::move(var));
263+
// [ Nth producer ]
264+
auto accum_stream = opt_accum_streams[pos];
265+
auto& ready_event = ready_events[pos];
266+
auto& ready_stream = ready_streams[pos];
267+
TORCH_INTERNAL_ASSERT(accum_stream && ready_event && ready_stream);
268+
// 1)
269+
if (*accum_stream != *opt_producer_stream) {
270+
auto event = c10::Event{device_type};
271+
event.record(*opt_producer_stream);
272+
accum_stream->wait(event);
273+
record_stream_any_impl(var, *accum_stream);
274+
}
275+
if (*accum_stream != *ready_stream) {
276+
accum_stream->wait(*ready_event);
277+
// This is redundant for case A, but needed for case C
278+
record_stream_any_impl(buffer[pos], *accum_stream);
221279
}
280+
// 2)
281+
c10::OptionalStreamGuard stream_guard{accum_stream};
282+
accumulate(buffer, pos, std::move(var));
283+
// 3)
284+
auto event = c10::Event{device_type};
285+
event.record(*accum_stream);
286+
ready_events[pos] = std::move(event);
287+
ready_streams[pos] = accum_stream;
222288
}
223289
}
224290

torch/csrc/autograd/input_buffer.h

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,11 @@
1515
namespace torch::autograd {
1616

1717
struct InputBuffer {
18-
explicit InputBuffer(size_t size) : buffer(size) {}
18+
explicit InputBuffer(size_t size)
19+
: buffer(size),
20+
opt_accum_streams(size),
21+
ready_events(size),
22+
ready_streams(size) {}
1923
InputBuffer(const InputBuffer& other) = delete;
2024
InputBuffer(InputBuffer&& other) = default;
2125
explicit InputBuffer(variable_list&& inputs) : buffer(std::move(inputs)) {}
@@ -38,6 +42,14 @@ struct InputBuffer {
3842
static std::vector<Variable> variables(InputBuffer&& g);
3943

4044
std::vector<Variable> buffer;
45+
// The stream used for accumulation when a variable is used multiple times.
46+
std::vector<std::optional<c10::Stream>> opt_accum_streams;
47+
// The events you need to wait for to ensure the corresponding buffers
48+
// are ready. The events are updated as we accumulate into the buffer.
49+
std::vector<std::optional<c10::Event>> ready_events;
50+
// The streams corresponding to the events above. This is only used to
51+
// check if more synchronization is needed or not.
52+
std::vector<std::optional<c10::Stream>> ready_streams;
4153
};
4254

4355
} // namespace torch::autograd

0 commit comments

Comments
 (0)
0