[Distributed][CI] Rework continuous TestCase by kwen2501 · Pull Request #153653 · pytorch/pytorch


Closed
kwen2501 wants to merge 8 commits

Conversation

@kwen2501 (Contributor) commented May 15, 2025

Stack from ghstack (oldest at bottom):

  1. Reworked MultiProcContinousTest to spawn processes during setUpClass instead of in main, so that we can support multiple TestCase classes in one file.

  2. The child processes now run an infinite loop, monitoring test IDs passed from the main process via a task queue. Reciprocally, the child processes inform the main process of a test's completion via a completion queue. (A sketch of this pattern follows the list below.)

  3. Added a test template.
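
For readers unfamiliar with the pattern, here is a minimal, hypothetical sketch of the task-queue / completion-queue design described above. It is not the actual PyTorch implementation; names like worker_loop, SENTINEL, and ContinuousTestSketch are illustrative assumptions.

```python
import multiprocessing as mp
import unittest

SENTINEL = None  # tells a worker to leave its loop


def worker_loop(rank, task_queue, completion_queue):
    # Child process: execute test IDs from the task queue until the sentinel.
    while True:
        test_id = task_queue.get()
        if test_id is SENTINEL:
            break
        # The real class would look up and run the test method named by
        # `test_id`; this sketch only acknowledges receipt.
        completion_queue.put((rank, test_id, "ok"))


class ContinuousTestSketch(unittest.TestCase):
    world_size = 2

    @classmethod
    def setUpClass(cls):
        # Spawn once per class rather than once per file, so several
        # TestCase classes can coexist in one test file.
        ctx = mp.get_context("spawn")
        cls.task_queues = [ctx.Queue() for _ in range(cls.world_size)]
        cls.completion_queue = ctx.Queue()
        cls.procs = [
            ctx.Process(
                target=worker_loop,
                args=(r, cls.task_queues[r], cls.completion_queue),
            )
            for r in range(cls.world_size)
        ]
        for p in cls.procs:
            p.start()

    @classmethod
    def tearDownClass(cls):
        for q in cls.task_queues:
            q.put(SENTINEL)
        for p in cls.procs:
            p.join()
```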

cc @H-Huang @awgu @wanchaol @fegin @fduwjj @wz337 @wconstab @d4l3k

pytorch-bot commented May 15, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/153653

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 95f6d5d with merge base fa85434:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@pytorch-bot added labels: oncall: distributed (Add this issue/PR to distributed oncall triage queue), release notes: distributed (c10d) (release notes category) — May 15, 2025

@kwen2501 changed the title from "[SymmMem] Speed up tests" to "[Distributed][CI] Rework continuous TestCase" — May 16, 2025

@kwen2501 requested review from fduwjj and d4l3k and removed the request for fduwjj — May 16, 2025 01:49
@d4l3k (Member) left a comment:

LGTM

@fegin (Contributor) left a comment:

LGTM, but the test failures are real, due to Python 3.10 syntax (PyTorch still supports 3.9).

```diff
 # Rendezvous file
 rdvz_file: Optional[str] = None
 # timeout configured per class
 timeout: timedelta = timedelta(seconds=120)

 @classmethod
 @abc.abstractmethod
-def backend_str(cls) -> str:
+def backend_str(cls) -> str | None:
```
@fegin (Contributor) commented on this diff:

PyTorch still supports Python 3.9, which doesn't accept the `X | Y` union syntax (it was introduced in Python 3.10).
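
For reference, a Python 3.9-compatible spelling of the same annotation uses typing.Optional. A minimal sketch follows; the Base class name is illustrative, not the actual PyTorch class:

```python
import abc
from typing import Optional

class Base(abc.ABC):
    @classmethod
    @abc.abstractmethod
    def backend_str(cls) -> Optional[str]:
        # Optional[str] is equivalent to `str | None` but parses on Python 3.9.
        ...
```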

@fduwjj (Contributor) left a comment:

unblock

```diff
 # number of test processes
-world_size: int = 2
+world_size: int = -2  # unset state
```
Contributor comment on the diff:

why are we changing this to -2?

@kwen2501 (Author) replied:

It will be determined at run time by the device count.
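
A minimal sketch of how such a sentinel default might be resolved at run time; the helper name _resolve_world_size is a hypothetical illustration, not the actual code:

```python
import torch

def _resolve_world_size(configured: int) -> int:
    # -2 means "unset": fall back to the number of visible accelerators.
    if configured == -2:
        return torch.cuda.device_count() if torch.cuda.is_available() else 2
    return configured
```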


```python
@classmethod
def _run_test_given_id(cls, test_id: str, **kwargs) -> None:
    # self.id() == e.g. '__main__.TestDistributed.TestAdditive.test_get_rank'
```
Contributor comment on the diff:

nit: remove?

@kwen2501 (Author) replied:

No, this comment is an example of what a test ID looks like.
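
For context, a dotted test ID like the one in that comment can be resolved back into a runnable test with the standard unittest loader. This snippet is illustrative only and is not the code used in the PR:

```python
import unittest

def run_test_by_id(test_id: str) -> unittest.TestResult:
    # e.g. test_id = '__main__.TestDistributed.TestAdditive.test_get_rank'
    suite = unittest.defaultTestLoader.loadTestsFromName(test_id)
    return unittest.TextTestRunner(verbosity=0).run(suite)
```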

kwen2501 added 2 commits May 17, 2025 13:16
@kwen2501 (Author) commented:

@pytorchbot merge

@pytorch-bot added the ciflow/trunk (Trigger trunk jobs on your pull request) label — May 19, 2025
@pytorchmergebot (Collaborator)

Merge started. Your change will be merged once all checks pass (ETA 0-4 hours).

@pytorchmergebot (Collaborator)

Starting merge as part of PR stack under #153677

@pytorchmergebot (Collaborator)

Merge failed. Reason: 1 job failed: trunk / linux-jammy-rocm-py3.10 / test (default, 1, 2, linux.rocm.gpu.2)

@kwen2501 (Author) commented:

@pytorchbot merge -f "the RoCM node prep timeout does not seem related"

@pytorchmergebot (Collaborator)

Merge started. Your change will be merged immediately since you used the force (-f) flag, bypassing any CI checks (ETA: 1-5 minutes). Please use -f only as a last resort; prefer -i/--ignore-current, which continues the merge while ignoring current failures and lets pending tests finish and report signal before the merge.

@kwen2501 (Author) commented:

@pytorchbot revert -m "More fixes needed" -c=nosignal

@pytorchmergebot (Collaborator)

@pytorchbot successfully started a revert job.

pytorchmergebot added a commit that referenced this pull request May 19, 2025
This reverts commit 0d5c628.

Reverted #153653 on behalf of https://github.com/kwen2501 due to "More fixes needed".
@pytorchmergebot (Collaborator)

@kwen2501 your PR has been successfully reverted.

@pytorchmergebot added the Reverted and ci-no-td (Do not run TD on this PR) labels — May 19, 2025
kwen2501 added 3 commits May 23, 2025 16:28
@kwen2501 (Author) commented:

@pytorchbot merge

@pytorchmergebot (Collaborator)

Merge started. Your change will be merged once all checks pass (ETA 0-4 hours).

pytorchmergebot pushed a commit that referenced this pull request May 26, 2025
Use `MultiProcContinousTest` to avoid re-creating the ProcessGroup in each test instance.

Pull Request resolved: #153677
Approved by: https://github.com/fegin, https://github.com/Skylion007, https://github.com/ngimel
ghstack dependencies: #153653
pytorchmergebot pushed a commit that referenced this pull request Jun 3, 2025
Fix #154373, #154391, #154408, #154443, #154481

Because MultiProcContinousTest [now executes the tests with 8 GPUs instead of 2](#153653), our PP tests comparing gradients have become flakier due to the longer pipeline. The gradients are still close, but we need to relax the tolerance (see the illustrative snippet below).
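
A hedged illustration of the kind of relaxed-tolerance comparison described here; the tensors and tolerance values are made up for the example and are not the ones used in the actual test suite:

```python
import torch

grad_ref = torch.randn(4, 8)
grad_pp = grad_ref + 1e-4 * torch.randn(4, 8)  # small pipeline-induced drift

# Looser rtol/atol than the defaults, so near-identical gradients still pass.
torch.testing.assert_close(grad_pp, grad_ref, rtol=1e-3, atol=1e-3)
```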

Pull Request resolved: #154856
Approved by: https://github.com/Skylion007
iupaikov-amd pushed a commit to ROCm/pytorch that referenced this pull request Jun 4, 2025, and angelayi pushed a commit to angelayi/pytorch that referenced this pull request Jun 5, 2025 (both carrying the same #154856 change described above).
pytorchmergebot pushed a commit that referenced this pull request Jun 6, 2025
A 2D AllToAllv shuffle is illustrated below:
(`world_size` = 2, `ne` = 2, where `ne` is number of experts per rank)
```
        Source: |       Rank 0      |       Rank 1      |
                | c0 | c1 | c2 | c3 | d0 | d1 | d2 | d3 |

        Dest  : |       Rank 0      |       Rank 1      |
                | c0 | d0 | c1 | d1 | c2 | d2 | c3 | d3 |
```
where each `c_i` / `d_i` is a slice of the `input` tensor targeting expert `i`, with its length given by the input splits (in `in_out_splits[0]`).

That is, the 2D AllToAllv shuffle achieves a transpose from rank-major order at input to expert-major order at output.
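
The permutation can be illustrated on the toy example above with plain Python (chunk labels stand in for tensor slices; this sketch is for intuition only, not the CUDA implementation):

```python
npes, ne = 2, 2  # world size, experts per rank
source = {0: ["c0", "c1", "c2", "c3"], 1: ["d0", "d1", "d2", "d3"]}

dest = {r: [] for r in range(npes)}
for expert in range(npes * ne):          # global expert index
    dst_rank = expert // ne              # the rank hosting this expert
    for src_rank in range(npes):         # gather that expert's chunk from every rank
        dest[dst_rank].append(source[src_rank][expert])

print(dest)  # {0: ['c0', 'd0', 'c1', 'd1'], 1: ['c2', 'd2', 'c3', 'd3']}
```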

Pull Request resolved: #155058
Approved by: https://github.com/ngimel
ghstack dependencies: #153653, #153677
pytorchmergebot pushed a commit that referenced this pull request Jun 6, 2025
Downstream consumer of the 2D all-to-all-v is often a grouped GEMM. Today such GEMMs often have an alignment requirement on the chunk sizes within the grouped sequence, where each chunk carries the tokens headed for one expert. For example, `torch._group_mm` requires an alignment of 8.

This PR adds that alignment capability: when the user passes in a `major_align` argument, no extra padding step is needed.

The key to supporting that is making the output offsets aligned to this value. (Output offsets are returned to users in the 3rd row of `in_out_splits`, on device. The 2nd row, output splits, is unaffected by this alignment value, i.e. it reflects the true number of tokens for each expert.)

The algorithm is as follows.

![502413288_678786854922438_530852083153996358_n](https://github.com/user-attachments/assets/557624a3-150e-4ab6-ba8b-1dbaa5ac01ac)

In the detailed implementation, we use a warp scan to calculate the prefix sum on the "block" illustrated above. As a result, the "block" size, i.e. `npes`, is currently limited to the warp size of 32.
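
To make the offset-alignment idea concrete, here is a hedged host-side sketch of what aligning output offsets to `major_align` could look like; the helper name and the sequential prefix sum are illustrative, whereas the actual kernel computes this with a warp scan on device:

```python
def aligned_output_offsets(output_splits, major_align):
    # Each chunk's start offset is rounded up to a multiple of `major_align`;
    # the splits themselves (true token counts) are left untouched.
    offsets, cursor = [], 0
    for split in output_splits:
        cursor = (cursor + major_align - 1) // major_align * major_align
        offsets.append(cursor)
        cursor += split
    return offsets

print(aligned_output_offsets([3, 5, 2], 8))  # [0, 8, 16]
```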

Pull Request resolved: #155172
Approved by: https://github.com/ngimel
ghstack dependencies: #153653, #153677, #155058
@github-actions github-actions bot deleted the gh/kwen2501/153/head branch June 27, 2025 02:19
Labels: ci-no-td (Do not run TD on this PR), ciflow/trunk (Trigger trunk jobs on your pull request), Merged, oncall: distributed (Add this issue/PR to distributed oncall triage queue), release notes: distributed (c10d) (release notes category), Reverted