8000 correct flags configuration · pytorch/pytorch@68c559c · GitHub
[go: up one dir, main page]

Skip to content

Commit 68c559c

Browse files
committed
correct flags configuration
1 parent 8e9cc87 commit 68c559c

File tree

1 file changed

+10
-12
lines changed

1 file changed

+10
-12
lines changed

torch/utils/cpp_extension.py

+10-12
Original file line numberDiff line numberDiff line change
@@ -281,29 +281,27 @@ def _join_sycl_home(*paths) -> str:
281281
'-D__HIP_NO_HALF_CONVERSIONS__=1',
282282
]
283283

284-
_COMMON_SYCL_FLAGS = [
285-
'-fsycl',
286-
]
287-
288-
def _get_sycl_arch_flags():
284+
def _get_sycl_arch_list():
285+
if 'TORCH_XPU_ARCH_LIST' in os.environ:
286+
return os.environ.get('TORCH_XPU_ARCH_LIST')
289287
arch_list = torch.xpu.get_arch_list()
290288
# Dropping dg2* archs since they lack hardware support for fp64 and require
291289
# special consideration from the user. If needed these platforms can
292290
# be requested thru TORCH_XPU_ARCH_LIST environment variable.
293291
arch_list = [x for x in arch_list if not x.startswith('dg2')]
294-
if len(arch_list) == 0:
295-
return []
296-
else:
297-
return ['-fsycl-targets=spir64_gen,spir64',
298-
'-flink-huge-device-code',
299-
f'-Xs "-device {",".join(arch_list)}"']
292+
return ','.join(arch_list)
293+
294+
_COMMON_SYCL_FLAGS = [
295+
'-fsycl',
296+
'-fsycl-targets=spir64_gen,spir64' if _get_sycl_arch_list() != '' else '',
297+
]
300298

301299
_SYCL_DLINK_FLAGS = [
302300
*_COMMON_SYCL_FLAGS,
303301
'-fsycl-link',
304302
'--offload-compress',
303+
f'-Xs "-device {_get_sycl_arch_list()}"' if _get_sycl_arch_list() != '' else '',
305304
]
306- 4D20
_SYCL_DLINK_FLAGS += _get_sycl_arch_flags()
307305

308306
JIT_EXTENSION_VERSIONER = ExtensionVersioner()
309307

0 commit comments

Comments
 (0)
0