@@ -289,7 +289,7 @@ inline std::tuple<size_t, size_t, size_t, size_t> getCopyParameters(const Tensor
   }
 }
 
-static bool copy_requires_temporaries(TensorIterator& iter, bool p2p_enabled) {
+static bool copy_requires_temporaries(TensorIterator& iter, bool p2p_enabled, bool non_blocking) {
   Device dst_device = iter.device(0);
   Device src_device = iter.device(1);
 
@@ -300,8 +300,9 @@ static bool copy_requires_temporaries(TensorIterator& iter, bool p2p_enabled) {
   }
 
   bool same_dtype = iter.dtype(0) == iter.dtype(1);
-  if (same_dtype && iter.is_contiguous()) {
-    // Contiguous same-dtype copies can always use cudaMemcpyAsync
+  bool same_conj_neg = iter.tensor(0).is_conj() == iter.tensor(1).is_conj() && iter.tensor(0).is_neg() == iter.tensor(1).is_neg();
+  if (same_dtype && iter.is_contiguous() && same_conj_neg) {
+    // Contiguous same-dtype copies can always use cudaMemcpyAsync when there is no conjugate or negative bit to resolve first
     return false;
   } else if (dst_device.is_cuda() && src_device.is_cuda()) {
     // Copies between GPUs can use the copy kernel if P2P is supported
@@ -310,12 +311,15 @@ static bool copy_requires_temporaries(TensorIterator& iter, bool p2p_enabled) {
 
   // for cross-device copies we can use memcpy2d if conditions are satisfied
   if (dst_device.is_cuda() != src_device.is_cuda() && same_dtype && iter.ndim() <= 2) {
-    // TensorIterator reorders strides so that the first one is the smallest
-
-    if (iter.ndim() == 1 || iter.has_contiguous_first_dim()) {
-      auto [width_in_bytes, src_pitch, dst_pitch, height] = getCopyParameters(iter);
-      if (src_pitch >= width_in_bytes && dst_pitch >= width_in_bytes) {
-        return false; // No need for temporaries
+    // We can do an async copy if there is no conjugate bit to resolve
+    if (same_conj_neg || !non_blocking) {
+      // TensorIterator reorders strides so that the first one is the smallest
+
+      if (iter.ndim() == 1 || iter.has_contiguous_first_dim()) {
+        auto [width_in_bytes, src_pitch, dst_pitch, height] = getCopyParameters(iter);
+        if (src_pitch >= width_in_bytes && dst_pitch >= width_in_bytes) {
+          return false; // No need for temporaries
+        }
       }
     }
   }
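To make the intent of the new check concrete, here is a minimal sketch (not part of the patch, written against the public ATen C++ API; the function name is made up) of the case it guards against: a conjugated complex source whose conj bit has not been materialized cannot be handed straight to cudaMemcpyAsync, so a non-blocking cross-device copy has to take the temporary path.

// Illustrative only; assumes a CUDA build and the ATen C++ frontend.
#include <ATen/ATen.h>

void conj_copy_example() {
  // Conjugation is a lazily-applied bit on the tensor, not materialized data.
  auto src = at::randn({8}, at::device(at::kCUDA).dtype(at::kComplexFloat)).conj();
  auto dst = at::empty({8}, at::dtype(at::kComplexFloat));  // CPU destination

  // With the change above, copy_requires_temporaries() reports true here when
  // non_blocking is set, because the conjugate bit must be resolved before the
  // raw memcpy2d/cudaMemcpyAsync path can be used.
  dst.copy_(src, /*non_blocking=*/true);
}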
@@ -340,8 +344,8 @@ static void copy_kernel_cuda(TensorIterator& iter, bool non_blocking) {
 
   // Enable p2p access between devices. (No-op if it involves the CPU)
   bool p2p_enabled = maybe_enable_p2p_access(dst_device, src_device);
-
-  if (copy_requires_temporaries(iter, p2p_enabled)) {
+  bool temp_needed = copy_requires_temporaries(iter, p2p_enabled, non_blocking);
+  if (temp_needed) {
     // NB: this involves recursive calls to copy. Be careful that those copies
     // don't require temporaries or you will cause an infinite recursion!
     auto& dst = iter.tensor(0);
@@ -355,19 +359,17 @@ static void copy_kernel_cuda(TensorIterator& iter, bool non_blocking) {
     auto conversion_device = non_blocking ? kCUDA : kCPU;
     if (iter.device_type(1) == conversion_device) {
       dst_contig = dst.is_contiguous() ? dst : at::empty_like(dst, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
-      src_contig = iter.tensor(1).to(iter.dtype(0)).expand_as(dst).contiguous();
+      src_contig = iter.tensor(1).to(iter.dtype(0)).expand_as(dst).contiguous().resolve_conj();
     } else {
       bool same_type = iter.dtype(0) == iter.dtype(1);
       dst_contig = (dst.is_contiguous() && same_type) ? dst : at::empty_like(dst, iter.dtype(1), LEGACY_CONTIGUOUS_MEMORY_FORMAT);
-      src_contig = iter.tensor(1).expand_as(dst).contiguous();
+      src_contig = iter.tensor(1).expand_as(dst).contiguous().resolve_conj();
     }
 
-    // propagate the correct conjugate bit
+    // propagate the correct conjugate bit to dst; src is resolved above
     dst_contig._set_conj(dst.is_conj());
-    src_contig._set_conj(iter.tensor(1).is_conj());
 
     dst_contig._set_neg(dst.is_neg());
-    src_contig._set_neg(iter.tensor(1).is_neg());
 
     // perform a same-dtype copy on contiguous tensors
     TORCH_INTERNAL_ASSERT(dst_contig.sizes().equals(src_contig.sizes()));
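As an aside on the resolve_conj() calls added above: a hedged sketch, assuming resolve_conj() materializes the lazily-applied conjugate bit into real data, which is why src_contig no longer needs _set_conj after the change; the function name below is illustrative only.

// Illustrative sketch, not part of the patch.
#include <ATen/ATen.h>

void resolve_conj_example() {
  auto t = at::randn({4}, at::dtype(at::kComplexFloat)).conj();
  TORCH_INTERNAL_ASSERT(t.is_conj());               // conj is still only a flag here
  auto materialized = t.resolve_conj();
  TORCH_INTERNAL_ASSERT(!materialized.is_conj());   // data is now physically conjugated
}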