8000 unify device_interface · pytorch/pytorch@b7bbb96 · GitHub
[go: up one dir, main page]

Skip to content

Commit b7bbb96

Browse files
committed
unify device_interface
remove duplicated wa add format ignore flag
1 parent fc93c31 commit b7bbb96

File tree

3 files changed

+7
-17
lines changed

3 files changed

+7
-17
lines changed

torch/_inductor/autotune_process.py

Lines changed: 5 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -547,20 +547,11 @@ def do_bench(
547547
if len(device_idx_set) == 1:
548548
device_idx = next(iter(device_idx_set))
549549
else:
550-
if torch.cuda.is_available():
551-
device_idx = torch.cuda.current_device()
552-
elif torch.xpu.is_available():
553-
device_idx = torch.xpu.current_device()
554-
555-
556-
if torch.cuda.is_available():
557-
with torch.cuda.device(device_idx):
558-
out = benchmarker.benchmark_gpu(fn)
559-
torch.cuda.synchronize() # shake out any CUDA errors
560-
elif torch.xpu.is_available():
561-
with torch.xpu.device(device_idx):
562-
out = benchmarker.benchmark_gpu(fn)
563-
torch.xpu.synchronize() # shake out any XPU errors
550+
device_idx = device_interface.current_device()
551+
552+
with device_interface.device(device_idx): # type: ignore[attr-defined]
553+
out = benchmarker.benchmark_gpu(fn)
554+
device_interface.synchronize() # shake out any GPU errors
564555

565556
return out
566557

torch/_ops.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -256,7 +256,7 @@ def resolve_key(op: OperatorBase, k: DispatchKey): # type: ignore[valid-type]
256256
DispatchKey.BackendSelect,
257257
DispatchKey.AutocastCPU, # type: ignore[attr-defined]
258258
DispatchKey.AutocastCUDA, # type: ignore[attr-defined]
259-
DispatchKey.AutocastXPU,
259+
DispatchKey.AutocastXPU, # type: ignore[attr-defined]
260260
]
261261

262262

torch/testing/_internal/common_device_type.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1339,8 +1339,7 @@ def dep_fn(self, *args, **kwargs):
13391339
if inductor and torch._inductor.config.cpp_wrapper and _device != "cpu":
13401340
size_bytes *= 2
13411341

1342-
# TODO: Memory availability checks for Intel GPU
1343-
if _device != "xpu" and not _has_sufficient_memory(_device, size_bytes):
1342+
if not _has_sufficient_memory(_device, size_bytes):
13441343
raise unittest.SkipTest(f"Insufficient {_device} memory")
13451344

13461345
return fn(self, *args, **kwargs)

0 commit comments

Comments
 (0)
0