@@ -467,39 +467,31 @@ struct ConvParams {
467
467
// always use cudnn_depthwise for channels_last format
468
468
return true ;
469
469
}
470
- if (detail::getCUDAHooks ().supportsDepthwiseConvolutionWithCuDNN ()) {
471
- long cudnn_version = detail::getCUDAHooks ().versionCuDNN ();
472
- if (cudnn_version >= 8200 ) {
473
- bool kernel_cond = (use_cudnn (input, weight) &&
474
- input.scalar_type () == kHalf && // only for FP16
475
- weight.scalar_type () == kHalf &&
476
- is_depthwise (input, weight) &&
477
- input.ndimension () == 4 && // TODO: 5-D contiguous depthwise is not supported yet, need benchmarks
478
- !is_dilated () && // no dilation supported
479
- (stride[0 ] == stride[1 ] || at::symint::size<T>(input, 2 ) == 1 ) && // square or 1d
480
- at::symint::size<T>(input, 1 ) >= 32 ); // min 32 channels supported)
481
- if (kernel_cond) {
482
- return check_cudnn_depthwise_workload_with_filter<T>(input, stride[1 ], weight);
483
- }
470
+ // native kernel doesn't support 64-bit non-splittable case
471
+ if (cudnn_enabled && needs_64bit_indexing_no_split (input, weight)) {
472
+ static long cudnn_version = detail::getCUDAHooks ().versionCuDNN ();
473
+ if (!(cudnn_version >= 90300 && at::native::cudnnv8_enabled_check_debug ())) {
474
+ TORCH_WARN_ONCE (" cuDNN cannot be used for large non-batch-splittable convolutions"
475
+ " if the V8 API is not enabled or before cuDNN version 9.3+."
476
+ " Upgrade cuDNN or enable the V8 API to use cuDNN for 64-bit depthwise convolutions.");
477
+ return false ;
478
+ } else {
479
+ return true ;
484
480
}
485
- // keep (7600 <= cudnn < 8200) code unchanged
486
- bool kernel_cond = (cudnn_version >= 7600 &&
487
- use_cudnn (input, weight) &&
481
+ }
482
+ if ( detail::getCUDAHooks (). supportsDepthwiseConvolutionWithCuDNN ()) {
483
+ bool kernel_cond = ( use_cudnn (input, weight) &&
488
484
input.scalar_type () == kHalf && // only for FP16
489
485
weight.scalar_type () == kHalf &&
490
486
is_depthwise (input, weight) &&
491
487
input.ndimension () == 4 && // TODO: 5-D contiguous depthwise is not supported yet, need benchmarks
492
- at::symint::size<T>(weight, 2 ) == at::symint::size<T>(weight, 3 ) && // only square kernels
493
- at::symint::size<T>(input, 2 ) >= 7 && // min width/height 7
494
488
!is_dilated () && // no dilation supported
495
- stride[0 ] == stride[1 ] && // equal strides
496
- ((at::symint::size<T>(weight, 3 ) == 3 ) || (at::symint::size<T>(weight, 3 ) == 1 )) &&
489
+ (stride[0 ] == stride[1 ] || at::symint::size<T>(input, 2 ) == 1 ) && // square or 1d
497
490
at::symint::size<T>(input, 1 ) >= 32 ); // min 32 channels supported)
498
491
if (kernel_cond) {
499
- return check_cudnn_depthwise_workload<T>(input, stride[0 ]);
500
- } else {
501
- return false ;
492
+ return check_cudnn_depthwise_workload_with_filter<T>(input, stride[1 ], weight);
502
493
}
494
+ return false ;
503
495
} else {
504
496
return false ;
505
497
}
0 commit comments