Update · pytorch/pytorch@95f6d5d · GitHub
[go: up one dir, main page]

Skip to content

Commit 95f6d5d

Browse files
committed
Update
[ghstack-poisoned]
1 parent 4b820a8 commit 95f6d5d

File tree

2 files changed

+10
-2
lines changed

2 files changed

+10
-2
lines changed

test/distributed/pipelining/test_stage.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
instantiate_parametrized_tests,
2525
parametrize,
2626
run_tests,
27+
skip_but_pass_in_sandcastle,
2728
skip_but_pass_in_sandcastle_if,
2829
)
2930
from torch.utils._pytree import tree_map_only
@@ -68,6 +69,10 @@ def backend_str(cls) -> str:
6869
# Testing with NCCL backend
6970
return "nccl"
7071

72+
@classmethod
73+
def device_type(cls) -> str:
74+
return device_type
75+
7176
@property
7277
def device(self) -> torch.device:
7378
return torch.device(device_type, self.rank)
@@ -350,7 +355,7 @@ def init_pg(self):
350355
)
351356

352357
@requires_nccl()
353-
@skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs")
358+
@skip_but_pass_in_sandcastle("Flaky in CI")
354359
def test_shape_prop_mismatch(self):
355360
"""Tests shape prop errors are raised"""
356361
self.init_pg()

torch/testing/_internal/common_distributed.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1529,7 +1529,10 @@ def backend_str(cls) -> Optional[str]:
15291529
# Please override if you intend to test on specific device type
15301530
@classmethod
15311531
def device_type(cls) -> str:
1532-
return torch.accelerator.current_accelerator().type
1532+
curr_device = torch.accelerator.current_accelerator()
1533+
if curr_device is None:
1534+
return "cpu"
1535+
return curr_device.type
15331536

15341537
@classmethod
15351538
def opts(cls, high_priority_stream=False):

0 commit comments

Comments (0)