@@ -26,8 +26,13 @@ namespace {
26
26
// See https://github.com/pytorch/pytorch/issues/60306
27
27
// TODO: clean this up when https://github.com/pytorch/pytorch/issues/60306 is
28
28
// improved
29
- void record_stream_any_impl (Variable& var, c10::Stream& stream) {
29
+ void record_stream_any_impl (Variable& var, const c10::Stream& stream) {
30
30
// NOLINTNEXTLINE(bugprone-unchecked-optional-access)
31
+
32
+ if (stream.device_index () != var.device ().index ()) {
33
+ return ;
34
+ }
35
+
31
36
const auto guard = c10::impl::VirtualGuardImpl (device_of (var).value ().type ());
32
37
33
38
if (C10_UNLIKELY (at::isBatchedTensor (var))) {
@@ -126,99 +131,160 @@ static void accumulate(
126
131
}
127
132
}
128
133
134
+ // Note: [Stream sync contract when dealing with multi-deviced-ness]
135
+ //
136
+ // An operator can deal with multiple devices, e.g. if it does a device
137
+ // transfer. However, for the purpose of stream synchronization, the engine
138
+ // is only aware of a single canonical device/stream for each autograd Node.
139
+ //
140
+ // For proper synchronization, the Node author should make sure of the
141
+ // following:
142
+ //
143
+ // 1) A node consuming a gradient should wait on the canonical stream before
144
+ // using it.
145
+ // 2) A node producing a gradient should have it ready on the canonical
146
+ // stream during node execution.
147
+ //
148
+
149
+ // Note: [Autograd Producer-Consumer Stream Syncs]
150
+ //
151
+ // The producer-consumer stream syncs are partially handled in this method
152
+ // and partially handled in the engine prior to the consumer's execution.
153
+ // The logic here is mainly responsible for handling the synchronization needed
154
+ // for accumulation and recording the event that the consumer should wait on
155
+ // later. The corresponding wait and record_stream happen in the engine.
156
+ //
157
+ // First producer
158
+ // ==============
159
+ // There are several things we need to do upon seeing the first producer:
160
+ // 1) Determine the accumulation stream (which may or may not be used):
161
+ // case A) var's device matches consumer node's canonical device
162
+ // (The producer node's canonical device may or may not match)
163
+ // -> accumulator stream = consumer stream
164
+ // case B) var's device matches producer node's canonical device
165
+ // and does not match consumer node's canonical device
166
+ // -> accumulator stream = producer stream
167
+ // case C) var device matches neither
168
+ // -> accumulator stream = var device's current stream
169
+ // See Note [Stream sync contract when dealing with
170
+ // multi-deviced-ness]
171
+ // 2) Because we are the first producer, there's no accumulation necessary.
172
+ // Just move var into the buffer.
173
+ // 3) Update the ready_events and streams for the current position.
174
+ // ready_events are events you need to wait for to ensure the corresponding
175
+ // buffers are ready. The events are updated as we accumulate into the
176
+ // buffer.
177
+ //
178
+ // Nth producer
179
+ // ============
180
+ // 1) Synchronize for accumulation. Accumulation operates on both the new
181
+ // incoming gradient and the existing gradient in the buffer.
182
+ // (i) wait stream and (ii) record stream to make sure both are ready to be
183
+ // used on the accumulation stream.
184
+ // 2) Accumulate on the accumulation stream
185
+ // 3) Update the ready event and stream for the current position.
186
+ //
129
187
void InputBuffer::add (
130
188
size_t pos,
131
189
Variable&& var,
132
- const std::optional<c10::Stream>& opt_producer_stream ,
133
- const std::optional<c10::Stream>& opt_consumer_stream ) {
190
+ const std::optional<c10::Stream>& opt_producer_stream_ ,
191
+ const std::optional<c10::Stream>& opt_consumer_stream_ ) {
134
192
TORCH_INTERNAL_ASSERT (pos < buffer.size ());
193
+
135
194
if (!var.defined ()) {
136
195
return ;
137
196
}
138
-
139
- // Switches to accumulate device
140
- // The device (and stream) chosen for accumulation is:
141
- // (1) var is not a CUDA/privateuse1 variable. Accumulation happens on var's
142
- // device. (2) var is a CUDA/privateuse1 variable and it, the consumer, and
143
- // the producer share the same device:
144
- // (2a) Uses the consumer's stream as the accumulation stream
145
- // (2b) Syncs the accumulation stream with the producer's stream (if
146
- // different) (2c) Accumulates.
147
- // (3) var is a CUDA/MTIA/privateuse1 variable and it shares a device with
148
- // the consumer but not the producer:
149
- // (3a) Uses the consumer's stream as the accumulation stream
150
- // (3b) Syncs the accumulation stream with the consumer device's default
151
- // stream (3c) Accumulates.
152
- // (4) var is a CUDA/MTIA/privateuse1 variable and it shares a device with
153
- // the producer but not the consumer:
154
- // (4a) Uses the producer device's default stream as the accumulation
155
- // stream (4b) Syncs the accumulation stream with the producer's
156
- // stream (4c) Accumulates.
157
- // (5) var is a CUDA/MTIA/privateuse1 variable and it does not share a device
158
- // with the consumer or producer.
159
- // Accumulation happens on the var device's default stream.
160
-
161
- auto const device = device_of (var);
162
- TORCH_INTERNAL_ASSERT (device.has_value ());
163
- std::optional<c10::Stream> opt_accumulate_stream = std::nullopt;
164
- const auto device_type = device->type ();
165
- if (device->is_cuda () || device->is_mtia () || device->is_privateuseone ()) {
166
- const auto on_producer =
167
- opt_producer_stream && device == opt_producer_stream->device ();
168
- const auto on_consumer =
169
- opt_consumer_stream && device == opt_consumer_stream->device ();
170
-
171
- if (on_producer && on_consumer) {
172
- // (2a)
173
- opt_accumulate_stream = opt_consumer_stream;
174
- if (opt_accumulate_stream != opt_producer_stream) {
175
- // (2b)
176
- auto event = c10::Event{device_type};
177
- event.record (*opt_producer_stream);
178
- opt_accumulate_stream->wait (event);
179
- record_stream_any_impl (var, *opt_accumulate_stream);
180
- }
197
+ const auto device = var.device ();
198
+ const auto device_type = device.type ();
199
+ // TODO: Use at::accelerator::isAccelerator(device.type()) instead
200
+ bool is_accelerator =
201
+ device.is_cuda () || device.is_mtia () || device.is_privateuseone ();
202
+ //
203
+ // Non-accelerator case
204
+ //
205
+ if (!is_accelerator) {
206
+ if (!buffer[pos].defined ()) {
207
+ buffer[pos] = std::move (var);
181
208
} else {
182
- std::optional<c10::Stream> opt_sync_stream = std::nullopt;
183
- const auto guard = c10::impl::VirtualGuardImpl{device_type};
184
- if (on_consumer && !on_producer) {
185
- // (3a)
186
- opt_accumulate_stream = opt_consumer_stream;
187
- opt_sync_stream = guard.getDefaultStream (opt_consumer_stream->device ());
188
- } else if (on_producer && !on_consumer) {
189
- // (4a)
190
- opt_accumulate_stream =
191
- guard.getDefaultStream (opt_producer_stream->device ());
192
- opt_sync_stream = opt_producer_stream;
193
- } else {
194
- // (5)
195
- opt_accumulate_stream = guard.getDefaultStream (*device);
196
- }
197
- if (opt_sync_stream && (opt_accumulate_stream != opt_sync_stream)) {
198
- // (3b), (4b)
199
- c10::OptionalDeviceGuard device_guard{opt_sync_stream->device ()};
200
- auto event = c10::Event{device_type};
201
- event.record (*opt_sync_stream);
202
- opt_accumulate_stream->wait (event);
203
- const auto guard = c10::impl::VirtualGuardImpl (device_type);
204
- record_stream_any_impl (var, *opt_accumulate_stream);
205
- }
209
+ c10::OptionalDeviceGuard device_guard{device};
210
+ accumulate (buffer, pos, std::move (var));
206
211
}
212
+ return ;
207
213
}
214
+ // Handle the case where var is on an accelerator but producer node has no
215
+ // canonical stream, e.g. this can happen if forward is DtoH
216
+ const std::optional<c10::Stream>& opt_producer_stream =
217
+ (opt_producer_stream_.has_value ()
218
+ ? opt_producer_stream_
219
+ : std::optional<c10::Stream>(
220
+ at::accelerator::getCurrentStream (device.index ())));
208
221
209
- auto & old_var = buffer[pos];
210
- if (!old_var.defined ()) {
222
+ // opt_consumer_stream is always non-null when is_accelerator is true
223
+ // when InputBuffer is used in the engine. InputBuffer is also called
224
+ // elsewhere however! (e.g. other engine implementations)
225
+ const std::optional<c10::Stream>& opt_consumer_stream =
226
+ (opt_consumer_stream_.has_value ()
227
+ ? opt_consumer_stream_
228
+ : std::optional<c10::Stream>(
229
+ at::accelerator::getCurrentStream (device.index ())));
230
+
231
+ TORCH_INTERNAL_ASSERT (opt_consumer_stream && opt_producer_stream);
232
+
233
+ // See Note: [Autograd Producer-Consumer Stream Syncs]
234
+ if (!opt_accum_streams[pos].has_value ()) {
235
+ // [ First producer ]
236
+ TORCH_INTERNAL_ASSERT (!buffer[pos].defined ());
237
+ // 1)
238
+ if (opt_consumer_stream->device () == device) {
239
+ // Case A
240
+ opt_accum_streams[pos] = opt_consumer_stream;
241
+ if (*opt_consumer_stream != *opt_producer_stream) {
242
+ // We will end up doing record_stream on the accumulation stream
243
+ // (which is the consumer stream) later, but we also need to do
244
+ // it here in case we don't end up accumulating.
245
+ record_stream_any_impl (var, *opt_consumer_stream);
246
+ }
247
+ } else if (opt_producer_stream->device () == device) {
248
+ // Case B
249
+ opt_accum_streams[pos] = opt_producer_stream;
250
+ } else {
251
+ // Case C
252
+ opt_accum_streams[pos] =
253
+ at::accelerator::getCurrentStream (device.index ());
254
+ }
255
+ // 2)
211
256
buffer[pos] = std::move (var);
257
+ // 3)
258
+ auto event = c10::Event{device_type};
259
+ event.record (*opt_producer_stream);
260
+ ready_events[pos] = std::move (event);
261
+ ready_streams[pos] = opt_producer_stream;
212
262
} else {
213
- if (opt_accumulate_stream) {
214
- c10::OptionalStreamGuard stream_guard{opt_accumulate_stream};
215
- accumulate (buffer, pos, std::move (var));
216
- } else {
217
- // (1) non-CUDA/privateuse1 variable
218
- // Accumulation happens on variable's device
219
- c10::OptionalDeviceGuard device_guard{device};
220
- accumulate (buffer, pos, std::move (var));
263
+ // [ Nth producer ]
264
+ auto accum_stream = opt_accum_streams[pos];
265
+ auto & ready_event = ready_events[pos];
266
+ auto & ready_stream = ready_streams[pos];
267
+ TORCH_INTERNAL_ASSERT (accum_stream && ready_event && ready_stream);
268
+ // 1)
269
+ if (*accum_stream != *opt_producer_stream) {
270
+ auto event = c10::Event{device_type};
271
+ event.record (*opt_producer_stream);
272
+ accum_stream->wait (event);
273
+ record_stream_any_impl (var, *accum_stream);
274
+ }
275
+ if (*accum_stream != *ready_stream) {
276
+ accum_stream->wait (*ready_event);
277
+ // This is redundant for case A, but needed for case C
278
+ record_stream_any_impl (buffer[pos], *accum_stream);
221
279
}
280
+ // 2)
281
+ c10::OptionalStreamGuard stream_guard{accum_stream};
282
+ accumulate (buffer, pos, std::move (var));
283
+ // 3)
284
+ auto event = c10::Event{device_type};
285
+ event.record (*accum_stream);
286
+ ready_events[pos] = std::move (event);
287
+ ready_streams[pos] = accum_stream;
222
288
}
223
289
}
224
290
0 commit comments