File tree 1 file changed +10
-12
lines changed
1 file changed +10
-12
lines changed Original file line number Diff line number Diff line change @@ -281,29 +281,27 @@ def _join_sycl_home(*paths) -> str:
281
281
'-D__HIP_NO_HALF_CONVERSIONS__=1' ,
282
282
]
283
283
284
- _COMMON_SYCL_FLAGS = [
285
- '-fsycl' ,
286
- ]
287
-
288
- def _get_sycl_arch_flags ():
284
+ def _get_sycl_arch_list ():
285
+ if 'TORCH_XPU_ARCH_LIST' in os .environ :
286
+ return os .environ .get ('TORCH_XPU_ARCH_LIST' )
289
287
arch_list = torch .xpu .get_arch_list ()
290
288
# Dropping dg2* archs since they lack hardware support for fp64 and require
291
289
# special consideration from the user. If needed these platforms can
292
290
# be requested thru TORCH_XPU_ARCH_LIST environment variable.
293
291
arch_list = [x for x in arch_list if not x .startswith ('dg2' )]
294
- if len (arch_list ) == 0 :
295
- return []
296
- else :
297
- return [ '-fsycl-targets=spir64_gen,spir64 ' ,
298
- '-flink-huge-device-code ' ,
299
- f'-Xs "-device { "," . join ( arch_list ) } "' ]
292
+ return ',' . join (arch_list )
293
+
294
+ _COMMON_SYCL_FLAGS = [
295
+ '-fsycl' ,
296
+ '-fsycl-targets=spir64_gen,spir64' if _get_sycl_arch_list () != '' else ' ' ,
297
+ ]
300
298
301
299
_SYCL_DLINK_FLAGS = [
302
300
* _COMMON_SYCL_FLAGS ,
303
301
'-fsycl-link' ,
304
302
'--offload-compress' ,
303
+ f'-Xs "-device { _get_sycl_arch_list ()} "' if _get_sycl_arch_list () != '' else '' ,
305
304
]
306
-
4D20
_SYCL_DLINK_FLAGS += _get_sycl_arch_flags ()
307
305
308
306
JIT_EXTENSION_VERSIONER = ExtensionVersioner ()
309
307
You can’t perform that action at this time.
0 commit comments