[dtensor] add src_data_rank to distribute_tensor API #143883
Conversation
As titled, this PR adds a kwarg src_data_rank to the distribute_tensor API, to allow the user to specify a specific rank as the source of the full tensor data. Previously we defaulted to group_rank=0 as the source of truth for single-device semantics. This new option:
* gives advanced users the flexibility to choose the source data rank
* allows the user to explicitly pass None, which skips the communications (scatter/broadcast) in cases that do not care about single-device semantics (e.g. loading from a checkpoint)
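For context, a minimal usage sketch of the new kwarg (hedged: assumes a 4-GPU mesh for illustration, and that src_data_rank lands as a keyword argument defaulting to group rank 0, as the description above states):

```python
import torch
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import Shard, distribute_tensor

mesh = init_device_mesh("cuda", (4,))
full_tensor = torch.randn(8, 8)

# Default: group rank 0 is the source of truth; its data is
# scattered/broadcast so the result matches single-device semantics.
dt = distribute_tensor(full_tensor, mesh, [Shard(0)])

# Pick a different rank as the source of the full tensor data.
dt = distribute_tensor(full_tensor, mesh, [Shard(0)], src_data_rank=2)

# None skips the scatter/broadcast entirely; each rank shards its own
# local full_tensor (useful e.g. when loading from a checkpoint).
dt = distribute_tensor(full_tensor, mesh, [Shard(0)], src_data_rank=None)
```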
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/143883
Note: Links to docs will display an error until the docs builds have been completed.
✅ No Failures as of commit cd83f97 with merge base d88a8c4.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
Welcome back! Had a comment on the semantics of distribute_tensor with shard placements when src_data_rank is None.
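To make the semantics under discussion concrete, a rough sketch (illustrative only, not the actual implementation; assumes a 4-GPU mesh and even sharding on mesh dim 0):

```python
import torch
from torch.distributed.device_mesh import init_device_mesh

mesh = init_device_mesh("cuda", (4,))
full_tensor = torch.randn(8, 8)

# With src_data_rank=None and a Shard(0) placement, no collective runs:
# each rank slices its *own* full_tensor, so the result only matches
# single-device semantics if every rank's full_tensor already agrees.
local_shard = full_tensor.chunk(mesh.size(), dim=0)[mesh.get_local_rank()]
```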
LGTM! Thanks for supporting 0-collective tensor sharding.
@@ -41,15 +45,21 @@ def world_size(self) -> int:
         return 4

     @with_comms
-    def test_distribute_tensor(self):
+    def test_distribute_tensor_rank(self):
it would be good if we also test uneven sharding.
I think this feature is somewhat orthogonal to even vs. uneven sharding; I'll try to add that as a follow-up later.
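For reference, a hypothetical shape for such a test (editorial sketch of the reviewer's suggestion, not code from the PR; names and sizes are illustrative):

```python
import torch
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import Shard, distribute_tensor

mesh = init_device_mesh("cuda", (4,))

# dim 0 of size 5 cannot split evenly across 4 ranks, so local shard
# shapes differ per rank; the test would check that skipping the
# scatter (src_data_rank=None) still yields the right local shapes.
t = torch.randn(5, 8)
dt = distribute_tensor(t, mesh, [Shard(0)], src_data_rank=None)
```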
LGTM!
As titled, this PR propagates src_data_rank in the TP API, so that module-level APIs can leverage the flexibility to choose the source data rank and avoid the communication when it is not needed.
Pull Request resolved: #144005
Approved by: https://github.com/tianyu-l
ghstack dependencies: #143883
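A sketch of what the follow-up enables (hedged: assumes #144005 exposes src_data_rank on parallelize_module as a keyword argument mirroring distribute_tensor):

```python
import torch.nn as nn
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor.parallel import ColwiseParallel, parallelize_module

mesh = init_device_mesh("cuda", (4,))
model = nn.Linear(16, 16)

# src_data_rank=None: each rank shards the weights it already holds
# (e.g. loaded from a checkpoint) with no scatter/broadcast.
model = parallelize_module(model, mesh, ColwiseParallel(), src_data_rank=None)
```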