8000 Enable XPU distributed test for PT2.8 by daisyden · Pull Request #149916 · pytorch/pytorch · GitHub
[go: up one dir, main page]

Skip to content

Enable XPU distributed test for PT2.8 #149916

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 53 commits into from
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
53 commits
Select commit Hold shift + click to select a range
d0d8271
make skipXPU work
daisyden May 11, 2024
c791db9
enabled torch-xpu ops in op_db
daisyden May 13, 2024
f5cbd50
clean up code
daisyden May 13, 2024
4d94417
Revert "clean up code"
daisyden May 13, 2024
6844101
Revert "enabled torch-xpu ops in op_db"
daisyden May 13, 2024
5051e3c
Revert "make skipXPU work"
daisyden May 13, 2024
e2aa92a
merge common code update from https://github.com/Chao1Han/pytorch/pul…
chunhuanMeng Mar 19, 2025
9e83095
Merge branch 'main' of https://github.com/daisyden/pytorch into distr…
chunhuanMeng Mar 20, 2025
06dd2aa
merge common code update from https://github.com/Chao1Han/pytorch/pul…
daisyden Mar 20, 2025
a4a732b
Add XPU support for distributed
daisyden Mar 20, 2025
6e3f6b8
Merge branch 'distributed_2.8' of https://github.com/daisyden/pytorch…
daisyden Mar 20, 2025
5f47367
Merge remote-tracking branch 'upstream/main' into distributed_2.8
daisyden Mar 21, 2025
345d7e6
ported fsdp and _composable/fsdp cases
daisyden Mar 21, 2025
4a5a522
Support XPU device for DDP test cases
PenghuiCheng Mar 24, 2025
20a4456
Support XPU device for pipeline cases
PenghuiCheng Mar 24, 2025
a90a603
ported fsdp tests
daisyden Mar 24, 2025
5b1aff7
Merge branch 'distributed_2.8' of https://github.com/daisyden/pytorch…
daisyden Mar 24, 2025
44d55b9
fixed backend mapping error for register_backend function
PenghuiCheng Mar 26, 2025
7dade1f
Update distributed UT cases
PenghuiCheng Apr 1, 2025
580aaee
remove fsdp_kwargs in test_fsdp_memory.py to align with cuda, added r…
daisyden Apr 1, 2025
c0f5713
Merge branch 'upstream_main4' into distributed_2.8
daisyden Apr 1, 2025
6dedbe3
Add test_dynamo_distributed cases
PenghuiCheng Apr 1, 2025
20d074c
Merge remote-tracking branch 'upstream/distributed_2.8' into distribu…
PenghuiCheng Apr 1, 2025
124ff16
update test_tp_random_state.py
PenghuiCheng Apr 2, 2025
0bea112
Merge from main branch
PenghuiCheng Apr 7, 2025
7409ade
support xccl in with_comms
daisyden Apr 8, 2025
636cbff
Merge branch 'distributed_2.8' of https://github.com/daisyden/pytorch…
daisyden Apr 8, 2025
0d5a86b
Merge branch 'upstream_main3' into distributed_2.8
daisyden Apr 8, 2025
3826e30
Enabled UT in test/distributed/tensor
PenghuiCheng Apr 9, 2025
cb711b7
Merge branch 'distributed_2.8' of https://github.com/daisyden/pytorch…
PenghuiCheng Apr 9, 2025
d6cd1b3
refine fsdp2 test case for xpu
daisyden Apr 9, 2025
624be3a
Merge branch 'distributed_2.8' of https://github.com/daisyden/pytorch…
daisyden Apr 9, 2025
8d8c5fe
fix some issues in test case, cuda specific code, world_size 8, etc.
daisyden Apr 11, 2025
1cf7887
merge from main branch
PenghuiCheng Apr 16, 2025
41475ac
Merge remote-tracking branch 'upstream/distributed_2.8' into distribu…
PenghuiCheng Apr 16, 2025
0628c76
Change world size in test_device_mesh.py
PenghuiCheng Apr 18, 2025
b0d935d
Merge remote-tracking branch 'origin/distributed_2.8' into distribute…
PenghuiCheng Apr 18, 2025
58eb87e
Merge remote-tracking branch 'upstream/main' into distributed_2.8
PenghuiCheng Apr 22, 2025
e558eaa
Enabled some UT cases of distributed
PenghuiCheng Apr 24, 2025
83ac56e
enable UT case in _shard and _tool folder
PenghuiCheng Apr 29, 2025
0e7a7b6
Fixed hard code error for world_size 8
PenghuiCheng May 5, 2025
a2b2fc6
merge from main branch
PenghuiCheng May 7, 2025
8de00b9
fix regex
daisyden May 8, 2025
91f5d10
Merge branch 'distributed_2.8' of https://github.com/daisyden/pytorch…
daisyden May 8, 2025
39e6c02
Fixed UT errors for cuda hard code
PenghuiCheng May 15, 2025
06d6c3e
Merge remote-tracking branch 'origin/distributed_2.8' into distribute…
PenghuiCheng May 15, 2025
a059005
Merge from upstream main branch
PenghuiCheng May 16, 2025
31ddfc0
Fixed XPU UT error for CUDA hard code
PenghuiCheng May 21, 2025
9a6df8a
Merge remote-tracking branch 'upstream0523/main' into distributed_2.8
daisyden May 23, 2025
50cb9e9
fix fsdp2 issue after rebase, fix #1618 dynamo issue
daisyden May 26, 2025
08559b9
remove duplicated device_type
daisyden May 27, 2025
f6a8c6a
fix rebase issue of test_fully_shard_overlap.py
daisyden May 27, 2025
dc0befa
merge from main branch
PenghuiCheng May 30, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
fix some issues in test case, cuda specific code, world_size 8, etc.
  • Loading branch information
daisyden committed Apr 11, 2025
commit 8d8c5fe60570e2adf18c0fc1b1091b273e7d0c49
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ def _test_gradient_scaler(self, has_inf: bool, test_2d: bool):
input = torch.randn((2,), device=device_type)

loss = model(input).sum()
scaler = GradScaler(init_scale=2.0, enabled=True)
scaler = GradScaler(init_scale=2.0, enabled=True, device=device_type)
opt = torch.optim.Adam(model.parameters(), lr=1e-2)
scaler.scale(loss).backward()
inv_scale = scaler._scale.double().reciprocal().float()
Expand Down
6 changes: 3 additions & 3 deletions test/distributed/_composable/fsdp/test_fully_shard_overlap.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,10 +62,10 @@ def delay_collective():
# Share a stream so that all-gather and reduce-scatter block each
# other like in `ProcessGroupNCCL`
comm_stream.wait_stream(torch.accelerator.current_stream())
with torch.cuda.stream(comm_stream):
if device_type == 'cuda':
if device_type == 'cuda':
with torch.cuda.stream(comm_stream):
torch.cuda._sleep(int(comm_sleep_ms * get_cycles_per_ms()))
torch.cuda.current_stream().wait_stream(comm_stream)
torch.accelerator.current_stream().wait_stream(comm_stream)

def delayed_all_gather(*args, **kwargs):
delay_collective()
Expand Down
4 changes: 2 additions & 2 deletions test/distributed/tensor/test_dtensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -613,7 +613,7 @@ def test_shard_tensor_2d(self):
class DTensorMeshTest(DTensorTestBase):
@property
def world_size(self):
return 8
return min(8, torch.accelerator.device_count())

def sub_mesh_assert_equal(self, mesh, exp_in_mesh, exp_out_of_mesh, tensor):
if self.rank in mesh:
Expand Down Expand Up @@ -930,7 +930,7 @@ def test_metadata_consistency_check(self):
class TestDTensorPlacementTypes(DTensorTestBase):
@property
def world_size(self):
return 8
return min(8, torch.accelerator.device_count())

def _create_tensor(self, size):
# Keep everything deterministic.
Expand Down
2 changes: 1 addition & 1 deletion test/distributed/tensor/test_redistribute.py
Original file line number Diff line number Diff line change
Expand Up @@ -449,7 +449,7 @@ def test_shard_dim_alltoall(self):
class MultiDimRedistributeTest(DTensorTestBase):
@property
def world_size(self) -> int:
return 8
return min(8, torch.accelerator.device_count())

@with_comms
def test_multi_dim_mesh(self):
Expand Down
4 changes: 2 additions & 2 deletions test/distributed/test_device_mesh.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,7 +211,7 @@ def test_device_mesh_init_backend(self):
# we call init_backend we should make sure the default pg already created
mesh.get_coordinate()

@unittest.skipif(not torch.accelerator.is_available(), "No accelerator available!")
@unittest.skipIf(not torch.accelerator.is_available(), "No accelerator available!")
def test_fake_pg_device_mesh(self):
fake_store = FakeStore()
init_process_group("fake", store=fake_store, rank=0, world_size=self.world_size)
Expand Down Expand Up @@ -266,7 +266,7 @@ def test_from_group_with_invalid_mesh(self):
groups, self.device_type, invalid_mesh, mesh_dim_names=("dim0", "dim1")
)

@unittest.skipif(not torch.accelerator.is_available(), "No accelerator available!")
@unittest.skipIf(not torch.accelerator.is_available(), "No accelerator available!")
def test_raises_invalid_device_type(self):
with self.assertRaisesRegex(
RuntimeError,
Expand Down
0