@@ -294,7 +294,24 @@ void copy_device_to_device(TensorIterator& iter,
   AT_CUDA_CHECK(cudaGetLastError());
 }
 
-static bool copy_requires_temporaries(TensorIterator& iter, bool p2p_enabled) {
+inline std::tuple<size_t, size_t, size_t, size_t> getCopyParameters(const TensorIteratorBase& iter) {
+  size_t element_size = iter.tensor(0).element_size();
+  if (iter.ndim() == 1) {
+    size_t width_in_bytes = element_size;
+    size_t src_pitch = iter.strides(1)[0];
+    size_t dst_pitch = iter.strides(0)[0];
+    size_t height = iter.shape()[0];
+    return std::make_tuple(width_in_bytes, src_pitch, dst_pitch, height);
+  } else {
+    size_t width_in_bytes = iter.shape()[0] * element_size;
+    size_t src_pitch = iter.strides(1)[1];
+    size_t dst_pitch = iter.strides(0)[1];
+    size_t height = iter.shape()[1];
+    return std::make_tuple(width_in_bytes, src_pitch, dst_pitch, height);
+  }
+}
+
+static bool copy_requires_temporaries(TensorIterator& iter, bool p2p_enabled, bool non_blocking) {
   Device dst_device = iter.device(0);
   Device src_device = iter.device(1);
 
@@ -303,19 +320,37 @@ static bool copy_requires_temporaries(TensorIterator& iter, bool p2p_enabled) {
     TORCH_INTERNAL_ASSERT(dst_device.is_cuda() && src_device.is_cuda());
     return false;
   }
+  // We require a temporary when the src tensor has a conj/neg flag set, but
+  // the destination does not, and the copy is async. Otherwise, the
+  // conj_physical applied to dst may execute before the copy has completed.
+  bool different_conj_neg = iter.tensor(0).is_conj() != iter.tensor(1).is_conj() || iter.tensor(0).is_neg() != iter.tensor(1).is_neg();
+  if (!different_conj_neg || !non_blocking){
+
+    bool same_dtype = iter.dtype(0) == iter.dtype(1);
+    if (same_dtype && iter.is_contiguous()) {
+      // Contiguous same-dtype copies can always use cudaMemcpyAsync if we don't have a conjugate flag to handle first
+      return false;
+    } else if (dst_device.is_cuda() && src_device.is_cuda()) {
+      // Copies between GPUs can use the copy kernel if P2P is supported
+      return !p2p_enabled;
+    }
 
-  bool same_dtype = iter.dtype(0) == iter.dtype(1);
-  if (same_dtype && iter.is_contiguous()) {
-    // Contiguous same-dtype copies can always use cudaMemcpyAsync
-    return false;
-  } else if (dst_device.is_cuda() && src_device.is_cuda()) {
-    // Copies between GPUs can use the copy kernel if P2P is supported
-    return !p2p_enabled;
-  } else {
-    // The remaining cases require temporaries. For example, this includes
-    // non-contiguous copies between CPU and GPU.
-    return true;
+    // for cross-device copies we can use memcpy2d if conditions are satisfied
+    if (dst_device.is_cuda() != src_device.is_cuda() && same_dtype && iter.ndim() <= 2) {
+      // TensorIterator reorders strides so that the first one is the smallest
+
+      if (iter.ndim() == 1 || iter.has_contiguous_first_dim()) {
+        auto [width_in_bytes, src_pitch, dst_pitch, height] = getCopyParameters(iter);
+        if (src_pitch >= width_in_bytes && dst_pitch >= width_in_bytes) {
+          return false; // No need for temporaries
+        }
+      }
+    }
   }
+
+  // The remaining cases require temporaries. For example, this includes
+  // non-contiguous copies between CPU and GPU.
+  return true;
 }
 
 static bool maybe_enable_p2p_access(Device dst_device, Device src_device) {
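For reference, a minimal sketch (not part of this diff) of how the `(width_in_bytes, src_pitch, dst_pitch, height)` tuple from `getCopyParameters` maps onto `cudaMemcpy2DAsync`; the helper name and argument plumbing here are hypothetical, but the pitch constraint is the documented CUDA requirement that motivates the `src_pitch >= width_in_bytes && dst_pitch >= width_in_bytes` check above.

```cpp
#include <cuda_runtime.h>

// Hypothetical wrapper: copy `height` rows of `width_in_bytes` bytes each,
// where consecutive rows are `src_pitch`/`dst_pitch` bytes apart.
cudaError_t strided_copy_2d(
    void* dst, size_t dst_pitch,
    const void* src, size_t src_pitch,
    size_t width_in_bytes, size_t height,
    cudaMemcpyKind kind, cudaStream_t stream) {
  // cudaMemcpy2DAsync requires both pitches to be at least the row width;
  // copy_requires_temporaries verifies this before choosing the memcpy2d
  // path instead of materializing contiguous temporaries.
  if (src_pitch < width_in_bytes || dst_pitch < width_in_bytes) {
    return cudaErrorInvalidValue;
  }
  return cudaMemcpy2DAsync(
      dst, dst_pitch, src, src_pitch, width_in_bytes, height, kind, stream);
}
```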
@@ -333,8 +368,7 @@ static void copy_kernel_cuda(TensorIterator& iter, bool non_blocking) {
 
   // Enable p2p access between devices. (No-op if it involves the CPU)
   bool p2p_enabled = maybe_enable_p2p_access(dst_device, src_device);
-
-  if (copy_requires_temporaries(iter, p2p_enabled)) {
+  if (copy_requires_temporaries(iter, p2p_enabled, non_blocking)) {
     // NB: this involves recursive calls to copy. Be careful that those copies
     // don't require temporaries or you will cause an infinite recursion!
     auto& dst = iter.tensor(0);
@@ -355,16 +389,31 @@ static void copy_kernel_cuda(TensorIterator& iter, bool non_blocking) {
       src_contig = iter.tensor(1).expand_as(dst).contiguous();
     }
 
-    // propagate the correct conjugate bit
-    dst_contig._set_conj(dst.is_conj());
-    src_contig._set_conj(iter.tensor(1).is_conj());
+    if (non_blocking){
+      const bool need_conj = dst.is_conj() != src_contig.is_conj();
+      const bool need_neg = dst.is_neg() != src_contig.is_neg();
+      // Doing these inplace may lead to modification on the src if none of
+      // the above materialized a clone.
+      if (need_conj){
+        src_contig = src_contig.conj_physical();
+      }
+      if (need_neg){
+        src_contig = src_contig.neg();
+      }
+      src_contig._set_conj(dst.is_conj());
+      src_contig._set_neg(dst.is_neg());
+    }
+
 
+    // propagate the correct conjugate bit to dst
+    dst_contig._set_conj(dst.is_conj());
     dst_contig._set_neg(dst.is_neg());
-    src_contig._set_neg(iter.tensor(1).is_neg());
 
     // perform a same-dtype copy on contiguous tensors
     TORCH_INTERNAL_ASSERT(dst_contig.sizes().equals(src_contig.sizes()));
     TORCH_INTERNAL_ASSERT(dst_contig.scalar_type() == src_contig.scalar_type());
+    TORCH_INTERNAL_ASSERT(dst_contig.is_conj() == src_contig.is_conj());
+    TORCH_INTERNAL_ASSERT(dst_contig.is_neg() == src_contig.is_neg());
     dst_contig.copy_(src_contig, non_blocking);
 
     // if necessary, copy back into dst
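A small sketch (not part of this diff) of why the conj/neg bits are resolved on the *source* before a non-blocking copy rather than fixed up on the destination afterwards; the helper name and the pinned-memory choice are hypothetical illustration only.

```cpp
#include <ATen/ATen.h>

// Hypothetical helper: move a CUDA tensor to CPU without blocking, even if
// its conjugate or negative bit is lazily set.
at::Tensor copy_to_cpu_nonblocking(const at::Tensor& src) {
  // Materialize the lazy flags into real data before launching the copy.
  at::Tensor src_materialized = src;
  if (src_materialized.is_conj()) {
    src_materialized = src_materialized.conj_physical();
  }
  if (src_materialized.is_neg()) {
    src_materialized = src_materialized.resolve_neg();
  }
  auto dst = at::empty(
      src_materialized.sizes(),
      src_materialized.options().device(at::kCPU).pinned_memory(true));
  dst.copy_(src_materialized, /*non_blocking=*/true);
  // Calling dst.conj_physical_() here instead would race with the
  // still-running async transfer; materializing on the source avoids that.
  return dst;
}
```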
@@ -424,11 +473,13 @@ static void copy_kernel_cuda(TensorIterator& iter, bool non_blocking) {
     at::cuda::memcpy_and_sync(dst, src, nbytes, kind, stream);
   }
 
-  if (iter.tensor(0).is_conj() != iter.tensor(1).is_conj()) {
-    iter.tensor(0).conj_physical_();
-  }
-  if (iter.tensor(0).is_neg() != iter.tensor(1).is_neg()) {
-    iter.tensor(0).neg_();
+  if (!non_blocking){
+    if (iter.tensor(0).is_conj() != iter.tensor(1).is_conj()) {
+      iter.tensor(0).conj_physical_();
+    }
+    if (iter.tensor(0).is_neg() != iter.tensor(1).is_neg()) {
+      iter.tensor(0).neg_();
+    }
   }
 }
 