16
16
17
17
18
18
# NOTE: Also update the CUDA sources in tools/nightly.py when changing this list
19
- CUDA_ARCHES = ["11.8" , "12.4" , "12. 6" , "12.8" ]
19
+ CUDA_ARCHES = ["11.8" , "12.6" , "12.8" ]
20
20
CUDA_ARCHES_FULL_VERSION = {
21
21
"11.8" : "11.8.0" ,
22
- "12.4" : "12.4.1" ,
23
22
"12.6" : "12.6.3" ,
24
23
"12.8" : "12.8.0" ,
25
24
}
26
25
CUDA_ARCHES_CUDNN_VERSION = {
27
26
"11.8" : "9" ,
28
- "12.4" : "9" ,
29
27
"12.6" : "9" ,
30
28
"12.8" : "9" ,
31
29
}
58
56
"nvidia-nccl-cu11==2.21.5; platform_system == 'Linux' and platform_machine == 'x86_64' | "
59
57
"nvidia-nvtx-cu11==11.8.86; platform_system == 'Linux' and platform_machine == 'x86_64'"
60
58
),
61
- "12.4" : (
62
- "nvidia-cuda-nvrtc-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | "
63
- "nvidia-cuda-runtime-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | "
64
- "nvidia-cuda-cupti-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | "
65
- "nvidia-cudnn-cu12==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | "
66
- "nvidia-cublas-cu12==12.4.5.8; platform_system == 'Linux' and platform_machine == 'x86_64' | "
67
- "nvidia-cufft-cu12==11.2.1.3; platform_system == 'Linux' and platform_machine == 'x86_64' | "
8000
div>
68
- "nvidia-curand-cu12==10.3.5.147; platform_system == 'Linux' and platform_machine == 'x86_64' | "
69
- "nvidia-cusolver-cu12==11.6.1.9; platform_system == 'Linux' and platform_machine == 'x86_64' | "
70
- "nvidia-cusparse-cu12==12.3.1.170; platform_system == 'Linux' and platform_machine == 'x86_64' | "
71
- "nvidia-cusparselt-cu12==0.6.2; platform_system == 'Linux' and platform_machine == 'x86_64' | "
72
- "nvidia-nccl-cu12==2.25.1; platform_system == 'Linux' and platform_machine == 'x86_64' | "
73
- "nvidia-nvtx-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | "
74
- "nvidia-nvjitlink-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64'"
75
- ),
76
59
"12.6" : (
77
60
"nvidia-cuda-nvrtc-cu12==12.6.77; platform_system == 'Linux' and platform_machine == 'x86_64' | "
78
61
"nvidia-cuda-runtime-cu12==12.6.77; platform_system == 'Linux' and platform_machine == 'x86_64' | "
@@ -348,7 +331,7 @@ def generate_wheels_matrix(
348
331
continue
349
332
350
333
if use_split_build and (
351
- arch_version not in ["12.6" , "12.4 " , "11.8" , "cpu" ] or os != "linux"
334
+ arch_version not in ["12.6" , "12.8 " , "11.8" , "cpu" ] or os != "linux"
352
335
):
353
336
raise RuntimeError (
354
337
"Split build is only supported on linux with cuda 12*, 11.8, and cpu.\n "
@@ -359,7 +342,7 @@ def generate_wheels_matrix(
359
342
# cuda linux wheels require PYTORCH_EXTRA_INSTALL_REQUIREMENTS to install
360
343
361
344
if (
362
- arch_version in ["12.8" , "12.6" , "12.4" , " 11.8" ]
345
+ arch_version in ["12.8" , "12.6" , "11.8" ]
363
346
and os == "linux"
364
347
or arch_version in CUDA_AARCH64_ARCHES
365
348
):
@@ -388,8 +371,8 @@ def generate_wheels_matrix(
388
371
), # include special case for aarch64 build, remove the -aarch64 postfix
389
372
}
390
373
)
391
- # Special build building to use on Colab. Python 3.11 for 12.4 CUDA
392
- if python_version == "3.11" and arch_version == "12.4 " :
374
+ # Special build building to use on Colab. Python 3.11 for 12.6 CUDA
375
+ if python_version == "3.11" and arch_version == "12.6 " :
393
376
ret .append (
394
377
{
395
378
"python_version" : python_version ,
@@ -432,7 +415,7 @@ def generate_wheels_matrix(
432
415
"pytorch_extra_install_requirements" : (
433
416
PYTORCH_EXTRA_INSTALL_REQUIREMENTS ["xpu" ]
434
417
if gpu_arch_type == "xpu"
435
- else PYTORCH_EXTRA_INSTALL_REQUIREMENTS ["12.4 " ]
418
+ else PYTORCH_EXTRA_INSTALL_REQUIREMENTS ["12.6 " ]
436
419
if os != "linux"
437
420
else ""
438
421
),
@@ -444,5 +427,4 @@ def generate_wheels_matrix(
444
427
445
428
validate_nccl_dep_consistency ("12.8" )
446
429
validate_nccl_dep_consistency ("12.6" )
447
- validate_nccl_dep_consistency ("12.4" )
448
430
validate_nccl_dep_consistency ("11.8" )
0 commit comments