Add state to distributed composable API #87838
Conversation
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/87838
Note: Links to docs will display an error until the docs builds have been completed.
❌ 1 Failure, 6 Pending as of commit 835f540. This comment was automatically generated by Dr. CI and updates every 15 minutes.
    api.state(module).dummy_state = 8
    return inp

# FIXME: circular reference looks a bit weird. Shall we make .state a
Any thoughts on this?
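For reference, here is a minimal sketch of the state machinery this diff appears to build, reconstructed only from the fragments visible here; the storage in module.__dict__, the SimpleNamespace state object, and the _state_for helper are illustrative assumptions, not the PR's exact implementation.

# Illustrative sketch only -- reconstructed from the diff fragments, not the PR's exact code.
from types import SimpleNamespace
from typing import Callable, Dict

import torch.nn as nn


class _StateKey:
    """Sentinel type; an instance is the key under which composable-API state lives on a module."""


state_key = _StateKey()


def get_all_state(module: nn.Module) -> Dict[Callable, SimpleNamespace]:
    # One state entry per composable API, keyed by the wrapped function (assumed storage location).
    all_state = module.__dict__.setdefault(state_key, {})
    assert isinstance(all_state, dict), "Distributed composable API states corrupted"
    return all_state


def _state_for(func: Callable, module: nn.Module) -> SimpleNamespace:
    # Roughly what ``api.state(module)`` resolves to: the state that ``func`` installed on ``module``.
    return get_all_state(module).setdefault(func, SimpleNamespace())

With something along these lines, the test above can write api.state(module).dummy_state = 8 and read the value back later without touching the module's parameters or buffers.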
    pass


state_key = _StateKey()
How about STATE_KEY, like a constant?
    assert isinstance(d, dict), "Distributed composable API states corrupted"
    return d

def wrapper(module: nn.Module, *args, **kwargs) -> Optional[nn.Module]:
Do we want to make module -> *module, like in the design doc?
Good point. Will update. Is it OK to update that in the follow-up PR?
Sure, not a blocker for me
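For illustration, the module -> *module change discussed above might look roughly like this; it is a sketch of the proposed follow-up, assuming the wrapper simply loops over the given modules, and is not code from this PR.

# Sketch of the discussed follow-up only -- not part of this PR.
from typing import Optional

import torch.nn as nn


def wrapper(*modules: nn.Module, **kwargs) -> Optional[nn.Module]:
    # Accept one or more modules, as suggested by the design doc.
    for module in modules:
        all_state = get_all_state(module)  # same per-module state lookup as in the current hunk
        # ... install ``func``-specific state for each module here ...
    return modules[0] if modules else None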
def wrapper(module: nn.Module, *args, **kwargs) -> Optional[nn.Module]:
    # install states specific to the wrapped ``func``
    all_state: Dict[Callable, dict] = get_all_state(module)
    assert func not in all_state, (
An open question we can resolve later: how do we make sure some APIs are mutually exclusive (for example, shard_embedding/replicate) while others are not (for example, checkpoint/fsdp)?
One option could be letting the contract add info to the state; the contract can then help check 1) whether the same API is called twice, and 2) whether there are conflicting APIs.
But we might need a way to allow new APIs to declare their conflicts.
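As a rough sketch of that suggestion (the conflict registry and the _check_compatibility helper below are hypothetical, not something this PR adds), the contract could record every applied API in the per-module state and consult a declared-conflicts table before installing a new one:

# Hypothetical sketch of the suggestion above -- not code in this PR.
from typing import Callable, Dict, Set

import torch.nn as nn


# Each composable API could declare the APIs it cannot be combined with,
# e.g. _CONFLICTS[replicate] = {shard_embedding}.
_CONFLICTS: Dict[Callable, Set[Callable]] = {}


def _check_compatibility(func: Callable, module: nn.Module) -> None:
    all_state = get_all_state(module)  # per-module registry of already-applied APIs
    # 1) the same API must not be applied to the same module twice
    assert func not in all_state, f"{func.__name__} has already been applied to this module"
    # 2) previously applied APIs must not be declared as conflicting with the new one
    for applied in all_state:
        assert applied not in _CONFLICTS.get(func, set()), (
            f"{func.__name__} conflicts with {applied.__name__}"
        )

The wrapper could run such a check right before installing the new state, and new APIs would extend the registry to declare their own conflicts.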
@pytorchbot merge -g
Merge started: Your change will be merged once all checks on your PR pass since you used the green (-g) flag (ETA: 0-4 hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
Merge failed. Reason: 2 additional jobs have failed; the first few of them are: trunk, trunk / macos-12-py3-arm64-mps / Run MPS tests. Details for Dev Infra team: raised by workflow job.
@pytorchbot merge -f "test failure is irrelevant"
Merge started: Your change will be merged immediately since you used the force (-f) flag, bypassing any CI checks (ETA: 1-5 minutes). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
Pull Request resolved: pytorch#87838 Approved by: https://github.com/yhcharles