24
24
25
25
CPU_AARCH64_ARCH = ["cpu-aarch64" ]
26
26
27
# Extra pip requirements attached to Linux CUDA wheels, keyed by CUDA version.
# Each value is a single string of requirement specifiers separated by " | ";
# every specifier carries an environment marker (after ";") restricting it to
# Linux x86_64. get_nccl_wheel_version() parses the NCCL pin out of this text.
PYTORCH_EXTRA_INSTALL_REQUIREMENTS = {
    "11.8": (
        "nvidia-cuda-nvrtc-cu11==11.8.89; platform_system == 'Linux' and platform_machine == 'x86_64' | "  # noqa: B950
        "nvidia-cuda-runtime-cu11==11.8.89; platform_system == 'Linux' and platform_machine == 'x86_64' | "
        "nvidia-cuda-cupti-cu11==11.8.87; platform_system == 'Linux' and platform_machine == 'x86_64' | "
        "nvidia-cudnn-cu11==8.7.0.84; platform_system == 'Linux' and platform_machine == 'x86_64' | "
        "nvidia-cublas-cu11==11.11.3.6; platform_system == 'Linux' and platform_machine == 'x86_64' | "
        "nvidia-cufft-cu11==10.9.0.58; platform_system == 'Linux' and platform_machine == 'x86_64' | "
        "nvidia-curand-cu11==10.3.0.86; platform_system == 'Linux' and platform_machine == 'x86_64' | "
        "nvidia-cusolver-cu11==11.4.1.48; platform_system == 'Linux' and platform_machine == 'x86_64' | "
        "nvidia-cusparse-cu11==11.7.5.86; platform_system == 'Linux' and platform_machine == 'x86_64' | "
        "nvidia-nccl-cu11==2.19.3; platform_system == 'Linux' and platform_machine == 'x86_64' | "
        "nvidia-nvtx-cu11==11.8.86; platform_system == 'Linux' and platform_machine == 'x86_64'"
    ),
    "12.1": (
        "nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | "  # noqa: B950
        "nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | "
        "nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | "
        "nvidia-cudnn-cu12==8.9.2.26; platform_system == 'Linux' and platform_machine == 'x86_64' | "
        "nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | "
        "nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | "
        "nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | "
        "nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | "
        "nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | "
        "nvidia-nccl-cu12==2.19.3; platform_system == 'Linux' and platform_machine == 'x86_64' | "
        "nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64'"
    ),
}
40
55
41
56
42
57
def get_nccl_submodule_version () -> str :
@@ -65,15 +80,17 @@ def get_nccl_submodule_version() -> str:
65
80
return f"{ d ['NCCL_MAJOR' ]} .{ d ['NCCL_MINOR' ]} .{ d ['NCCL_PATCH' ]} "
66
81
67
82
68
def get_nccl_wheel_version(arch_version: str) -> str:
    """Return the NCCL version pinned in the extra wheel requirements.

    Args:
        arch_version: Key into ``PYTORCH_EXTRA_INSTALL_REQUIREMENTS``
            (for example ``"11.8"`` or ``"12.1"``).

    Returns:
        The text following ``==`` in the ``nvidia-nccl-cu*`` requirement.

    Raises:
        KeyError: If ``arch_version`` has no entry in the requirements table.
        RuntimeError: If the entry contains no ``nvidia-nccl-cu*`` requirement.
    """
    import re

    # Splitting on both "|" (requirement separator) and ";" (marker separator)
    # leaves each "name==version" pin as its own stripped token.
    requirements = map(
        str.strip, re.split("[;|]", PYTORCH_EXTRA_INSTALL_REQUIREMENTS[arch_version])
    )
    nccl_pins = [x for x in requirements if x.startswith("nvidia-nccl-cu")]
    if not nccl_pins:
        raise RuntimeError(
            f"No nvidia-nccl-cu* requirement found for arch version {arch_version}"
        )
    return nccl_pins[0].split("==")[1]
73
90
74
91
75
- def validate_nccl_dep_consistency () -> None :
76
- wheel_ver = get_nccl_wheel_version ()
92
+ def validate_nccl_dep_consistency (arch_version : str ) -> None :
93
+ wheel_ver = get_nccl_wheel_version (arch_version )
77
94
submodule_ver = get_nccl_submodule_version ()
78
95
if wheel_ver != submodule_ver :
79
96
raise RuntimeError (
@@ -298,7 +315,7 @@ def generate_wheels_matrix(
298
315
)
299
316
300
317
# 11.8 and 12.1 linux wheels require PYTORCH_EXTRA_INSTALL_REQUIREMENTS to install
301
- if arch_version == "12.1" and os == "linux" :
318
+ if arch_version in [ "12.1" , "11.8" ] and os == "linux" :
302
319
ret .append (
303
320
{
304
321
"python_version" : python_version ,
@@ -310,7 +327,7 @@ def generate_wheels_matrix(
310
327
"devtoolset" : "" ,
311
328
"container_image" : WHEEL_CONTAINER_IMAGES [arch_version ],
312
329
"package_type" : package_type ,
313
- "pytorch_extra_install_requirements" : PYTORCH_EXTRA_INSTALL_REQUIREMENTS ,
330
+ "pytorch_extra_install_requirements" : PYTORCH_EXTRA_INSTALL_REQUIREMENTS [ arch_version ], # fmt: skip
314
331
"build_name" : f"{ package_type } -py{ python_version } -{ gpu_arch_type } { gpu_arch_version } " .replace ( # noqa: B950
315
332
"." , "_"
316
333
),
@@ -333,12 +350,13 @@ def generate_wheels_matrix(
333
350
"build_name" : f"{ package_type } -py{ python_version } -{ gpu_arch_type } { gpu_arch_version } " .replace (
334
351
"." , "_"
335
352
),
336
- "pytorch_extra_install_requirements" : PYTORCH_EXTRA_INSTALL_REQUIREMENTS
337
- if os != "linux"
338
- else "" ,
353
+ "pytorch_extra_install_requirements" :
354
+ PYTORCH_EXTRA_INSTALL_REQUIREMENTS [ "12.1" ] # fmt: skip
355
+ if os != "linux" else "" ,
339
356
}
340
357
)
341
358
return ret
342
359
343
360
344
# Check every CUDA version that has extra wheel requirements, so a newly added
# entry is validated against the NCCL submodule version without editing this list.
for _cuda_version in PYTORCH_EXTRA_INSTALL_REQUIREMENTS:
    validate_nccl_dep_consistency(_cuda_version)
0 commit comments