LINUX_CPU_TEST_RUNNER = "linux.2xlarge"
# contains 1 gpu
LINUX_CUDA_TEST_RUNNER = "linux.4xlarge.nvidia.gpu"
# contains 4 gpus
LINUX_ROCM_TEST_RUNNER = "linux.rocm.gpu"

# Every Linux runner type a workflow may name as its test runner.
LINUX_RUNNERS = {
    LINUX_CPU_TEST_RUNNER,
    LINUX_CUDA_TEST_RUNNER,
    LINUX_ROCM_TEST_RUNNER,
}

# Maps a workflow's test runner type to the runner type used for its
# distributed GPU jobs. CIWorkflow.__post_init__ looks the workflow's
# test_runner_type up here and falls back to "linux.8xlarge.nvidia.gpu"
# when the key is absent.
# NOTE(review): the ROCm runner maps to itself, presumably because it already
# exposes several GPUs -- confirm with the runner owners.
LINUX_DISTRIBUTED_GPU_RUNNERS = {
    LINUX_CUDA_TEST_RUNNER: "linux.8xlarge.nvidia.gpu",
    LINUX_ROCM_TEST_RUNNER: LINUX_ROCM_TEST_RUNNER,
}

# Maps a workflow's test runner type to the runner type used for its
# multi-GPU jobs. CIWorkflow.__post_init__ looks the workflow's
# test_runner_type up here and falls back to "linux.16xlarge.nvidia.gpu"
# when the key is absent.
LINUX_MULTIGPU_RUNNERS = {
    LINUX_CUDA_TEST_RUNNER: "linux.16xlarge.nvidia.gpu",
    LINUX_ROCM_TEST_RUNNER: LINUX_ROCM_TEST_RUNNER,
}
36
49
37
50
MACOS_TEST_RUNNER_10_15 = "macos-10.15"
46
59
WINDOWS_CUDA_TEST_RUNNER ,
47
60
LINUX_CUDA_TEST_RUNNER ,
48
61
}
# Runner types treated as ROCm runners. CIWorkflow.assert_valid requires any
# workflow whose test_runner_type is in this set to carry LABEL_CIFLOW_ROCM.
ROCM_RUNNERS = {
    LINUX_ROCM_TEST_RUNNER,
}
49
65
CPU_RUNNERS = {
50
66
WINDOWS_CPU_TEST_RUNNER ,
51
67
LINUX_CPU_TEST_RUNNER ,
# ciflow label strings; workflows list these in CIFlowConfig.labels.
LABEL_CIFLOW_BAZEL = "ciflow/bazel"
LABEL_CIFLOW_CPU = "ciflow/cpu"
LABEL_CIFLOW_CUDA = "ciflow/cuda"
LABEL_CIFLOW_ROCM = "ciflow/rocm"
LABEL_CIFLOW_DOCS = "ciflow/docs"
LABEL_CIFLOW_DEFAULT = "ciflow/default"
LABEL_CIFLOW_LIBTORCH = "ciflow/libtorch"
@@ -164,6 +181,8 @@ class CIWorkflow:
164
181
165
182
# Optional fields
166
183
test_runner_type : str = ''
184
+ multigpu_runner_type : str = ''
185
+ distributed_gpu_runner_type : str = ''
167
186
ciflow_config : CIFlowConfig = field (default_factory = CIFlowConfig )
168
187
cuda_version : str = ''
169
188
docker_image_base : str = ''
@@ -205,6 +224,9 @@ def __post_init__(self) -> None:
205
224
if self .fx2trt_test :
206
225
self .enable_fx2trt_test = 1
207
226
227
+ self .multigpu_runner_type = LINUX_MULTIGPU_RUNNERS .get (self .test_runner_type , "linux.16xlarge.nvidia.gpu" )
228
+ self .distributed_gpu_runner_type = LINUX_DISTRIBUTED_GPU_RUNNERS .get (self .test_runner_type , "linux.8xlarge.nvidia.gpu" )
229
+
208
230
# If num_test_shards_on_pull_request is not user-defined, default to num_test_shards unless we are
209
231
# only running smoke tests on the pull request.
210
232
if self .num_test_shards_on_pull_request == - 1 :
@@ -235,6 +257,8 @@ def assert_valid(self) -> None:
235
257
10000
assert self .test_runner_type != ''
236
258
if self .test_runner_type in CUDA_RUNNERS :
237
259
assert LABEL_CIFLOW_CUDA in self .ciflow_config .labels
260
+ if self .test_runner_type in ROCM_RUNNERS :
261
+ assert LABEL_CIFLOW_ROCM in self .ciflow_config .labels
238
262
if self .test_runner_type in CPU_RUNNERS and not self .exclude_test :
239
263
assert LABEL_CIFLOW_CPU in self .ciflow_config .labels
240
264
if self .is_scheduled :
@@ -576,6 +600,16 @@ def generate_workflow_file(self, workflow_template: jinja2.Template) -> None:
576
600
labels = set ([LABEL_CIFLOW_DEFAULT , LABEL_CIFLOW_LINUX , LABEL_CIFLOW_CPU ]),
577
601
),
578
602
),
603
+ CIWorkflow (
604
+ arch = "linux" ,
605
+ build_environment = "linux-bionic-rocm4.5-py3.7" ,
606
+ docker_image_base = f"{ DOCKER_REGISTRY } /pytorch/pytorch-linux-bionic-rocm4.5-py3.7" ,
607
+ test_runner_type = LINUX_ROCM_TEST_RUNNER ,
608
+ num_test_shards = 2 ,
609
+ ciflow_config = CIFlowConfig (
610
+ labels = set ([LABEL_CIFLOW_LINUX , LABEL_CIFLOW_ROCM ]),
611
+ ),
612
+ ),
579
613
CIWorkflow (
580
614
arch = "linux" ,
581
615
build_environment = "libtorch-linux-xenial-cuda11.3-py3.7-gcc7" ,
@@ -836,7 +870,7 @@ def generate_workflow_file(self, workflow_template: jinja2.Template) -> None:
836
870
package_type = "manywheel" ,
837
871
build_configs = generate_binary_build_matrix .generate_wheels_matrix (),
838
872
ciflow_config = CIFlowConfig (
839
- labels = {LABEL_CIFLOW_BINARIES , LABEL_CIFLOW_BINARIES_WHEEL },
873
+ labels = {LABEL_CIFLOW_DEFAULT , LABEL_CIFLOW_BINARIES , LABEL_CIFLOW_BINARIES_WHEEL },
840
874
isolated_workflow = True ,
841
875
),
842
876
),
@@ -845,7 +879,7 @@ def generate_workflow_file(self, workflow_template: jinja2.Template) -> None:
845
879
package_type = "conda" ,
846
880
build_configs = generate_binary_build_matrix .generate_conda_matrix (),
847
881
ciflow_config = CIFlowConfig (
848
- labels = {LABEL_CIFLOW_BINARIES , LABEL_CIFLOW_BINARIES_CONDA },
882
+ labels = {LABEL_CIFLOW_DEFAULT , LABEL_CIFLOW_BINARIES , LABEL_CIFLOW_BINARIES_CONDA },
849
883
isolated_workflow = True ,
850
884
),
851
885
),
@@ -857,7 +891,7 @@ def generate_workflow_file(self, workflow_template: jinja2.Template) -> None:
857
891
generate_binary_build_matrix .CXX11_ABI
858
892
),
859
893
ciflow_config = CIFlowConfig (
860
- labels = {LABEL_CIFLOW_BINARIES , LABEL_CIFLOW_BINARIES_LIBTORCH },
894
+ labels = {LABEL_CIFLOW_DEFAULT , LABEL_CIFLOW_BINARIES , LABEL_CIFLOW_BINARIES_LIBTORCH },
861
895
isolated_workflow = True ,
862
896
),
863
897
),
@@ -869,7 +903,7 @@ def generate_workflow_file(self, workflow_template: jinja2.Template) -> None:
869
903
generate_binary_build_matrix .PRE_CXX11_ABI
870
904
),
871
905
ciflow_config = CIFlowConfig (
872
- labels = {LABEL_CIFLOW_BINARIES , LABEL_CIFLOW_BINARIES_LIBTORCH },
906
+ labels = {LABEL_CIFLOW_DEFAULT , LABEL_CIFLOW_BINARIES , LABEL_CIFLOW_BINARIES_LIBTORCH },
873
907
isolated_workflow = True ,
874
908
),
875
909
),
0 commit comments