[dtensor] add src_data_rank to distribute_tensor API by wanchaol · Pull Request #143883 · pytorch/pytorch

Closed · wanchaol wants to merge 4 commits

Conversation

Collaborator
@wanchaol commented Dec 26, 2024

Stack from ghstack (oldest at bottom):

As titled, this PR adds a kwarg src_data_rank to the distribute_tensor
API, allowing the user to specify a particular rank as the source of the
full tensor data. Previously we always used group_rank=0 as the source of
truth for single-device semantics. The new option:

  • gives advanced users the flexibility to choose the source data rank
  • allows the user to pass None explicitly, which skips the
    communication (scatter/broadcast) for cases that do not care about
    single-device semantics (e.g. loading from a checkpoint); see the
    sketch below

cc @H-Huang @awgu @kwen2501 @fegin @fduwjj @wz337 @wconstab @d4l3k @c-p-i-o
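
A minimal sketch of the new kwarg, assuming the public torch.distributed.tensor module; the mesh shape and tensor sizes here are illustrative, not from the PR:

```python
import torch
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import Shard, distribute_tensor

# illustrative 1-D mesh over 4 GPUs; adjust to your world size
mesh = init_device_mesh("cuda", (4,))
full = torch.randn(8, 8)

# default (src_data_rank=0): rank 0's copy is scattered/broadcast,
# preserving single-device semantics
dt = distribute_tensor(full, mesh, [Shard(0)])

# choose a different rank as the source of the full tensor data
dt = distribute_tensor(full, mesh, [Shard(0)], src_data_rank=1)

# None: skip the scatter/broadcast entirely, e.g. when a checkpoint
# load will overwrite the sharded data anyway
dt = distribute_tensor(full, mesh, [Shard(0)], src_data_rank=None)
```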

pytorch-bot bot commented Dec 26, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/143883

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit cd83f97 with merge base d88a8c4:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

pytorch-bot bot added the oncall: distributed (distributed oncall triage queue) label Dec 26, 2024
wanchaol added a commit that referenced this pull request Dec 27, 2024
ghstack-source-id: 55510d9
Pull Request resolved: #143883
wanchaol added the release notes: distributed (dtensor) and ciflow/trunk labels Dec 28, 2024
Contributor
@tianyu-l left a comment

Welcome back!

Had a comment on the semantics of distribute_tensor with shard placements when src_data_rank is None.
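
For context, a hypothetical illustration of the case being discussed, continuing the sketch above; the only behavior taken from the PR text is that src_data_rank=None issues no scatter/broadcast:

```python
# with a shard placement and src_data_rank=None, no collective runs,
# so presumably each rank's local shard comes from its own copy of the
# input; single-device semantics hold only if all ranks passed
# identical data
dt = distribute_tensor(full, mesh, [Shard(0)], src_data_rank=None)
local = dt.to_local()  # this rank's chunk of its own `full`
```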

Contributor
@XilunWu left a comment

LGTM! Thanks for supporting 0-collective tensor sharding.

```diff
@@ -41,15 +45,21 @@ def world_size(self) -> int:
         return 4

     @with_comms
-    def test_distribute_tensor(self):
+    def test_distribute_tensor_rank(self):
```
Contributor

it would be good if we also test uneven sharding.

Collaborator Author

I think this feature is somewhat orthogonal to even vs. uneven sharding; I'll try to add that in a follow-up.

Contributor
@tianyu-l left a comment

LGTM!

pytorchmergebot pushed a commit that referenced this pull request Jan 2, 2025
As titled, this PR propagates src_data_rank in the TP API, so that
module-level APIs can leverage the flexibility to choose the source data
rank and avoid the communication when it is not needed.
Pull Request resolved: #144005
Approved by: https://github.com/tianyu-l
ghstack dependencies: #143883
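
A sketch of how that follow-up could be used, assuming parallelize_module forwards the src_data_rank kwarg as #144005 describes; the module and mesh shapes are illustrative:

```python
import torch.nn as nn
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor.parallel import ColwiseParallel, parallelize_module

mesh = init_device_mesh("cuda", (4,))
model = nn.Linear(16, 16)

# src_data_rank=None skips the weight scatter/broadcast, e.g. when a
# checkpoint load will immediately overwrite the sharded parameters
tp_model = parallelize_module(model, mesh, ColwiseParallel(), src_data_rank=None)
```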
github-actions bot deleted the gh/wanchaol/361/head branch February 2, 2025 02:05
Labels
ciflow/trunk · Merged · oncall: distributed · release notes: distributed (dtensor)