E5E8 [Distributed][CI] Rework continuous TestCase by kwen2501 · Pull Request #153653 · pytorch/pytorch · GitHub
[go: up one dir, main page]

Skip to content
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Update
[ghstack-poisoned]
  • Loading branch information
kwen2501 committed May 24, 2025
commit 95f6d5dd8cda47fceae2ad499fff08a7a2cae1a3
7 changes: 6 additions & 1 deletion test/distributed/pipelining/test_stage.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
instantiate_parametrized_tests,
parametrize,
run_tests,
skip_but_pass_in_sandcastle,
skip_but_pass_in_sandcastle_if,
)
from torch.utils._pytree import tree_map_only
Expand Down Expand Up @@ -68,6 +69,10 @@ def backend_str(cls) -> str:
# Testing with NCCL backend
return "nccl"

@classmethod
def device_type(cls) -> str:
return device_type

@property
def device(self) -> torch.device:
return torch.device(device_type, self.rank)
Expand Down Expand Up @@ -350,7 +355,7 @@ def init_pg(self):
)

@requires_nccl()
@skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs")
@skip_but_pass_in_sandcastle("Flaky in CI")
def test_shape_prop_mismatch(self):
"""Tests shape prop errors are raised"""
self.init_pg()
Expand Down
5 changes: 4 additions & 1 deletion torch/testing/_internal/common_distributed.py
Original file line number Diff line number Diff line change
Expand Up @@ -1529,7 +1529,10 @@ def backend_str(cls) -> Optional[str]:
# Please override if you intend to test on specific device type
@classmethod
def device_type(cls) -> str:
return torch.accelerator.current_accelerator().type
curr_device = torch.accelerator.current_accelerator()
if curr_device is None:
return "cpu"
return curr_device.type

@classmethod
def opts(cls, high_priority_stream=False):
Expand Down
Loading
0