8000 Enable XPU distributed test for PT2.8 by daisyden · Pull Request #149916 · pytorch/pytorch · GitHub
[go: up one dir, main page]

Skip to content

Enable XPU distributed test for PT2.8 #149916

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 53 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
53 commits
Select commit Hold shift + click to select a range
d0d8271
make skipXPU work
daisyden May 11, 2024
c791db9
enabled torch-xpu ops in op_db
daisyden May 13, 2024
f5cbd50
clean up code
daisyden May 13, 2024
4d94417
Revert "clean up code"
daisyden May 13, 2024
6844101
Revert "enabled torch-xpu ops in op_db"
daisyden May 13, 2024
5051e3c
Revert "make skipXPU work"
daisyden May 13, 2024
e2aa92a
merge common code update from https://github.com/Chao1Han/pytorch/pul…
chunhuanMeng Mar 19, 2025
9e83095
Merge branch 'main' of https://github.com/daisyden/pytorch into distr…
chunhuanMeng Mar 20, 2025
06dd2aa
merge common code update from https://github.com/Chao1Han/pytorch/pul…
daisyden Mar 20, 2025
a4a732b
Add XPU support for distributed
daisyden Mar 20, 2025
6e3f6b8
Merge branch 'distributed_2.8' of https://github.com/daisyden/pytorch…
daisyden Mar 20, 2025
5f47367
Merge remote-tracking branch 'upstream/main' into distributed_2.8
daisyden Mar 21, 2025
345d7e6
ported fsdp and _composable/fsdp cases
daisyden Mar 21, 2025
4a5a522
Support XPU device for DDP test cases
PenghuiCheng Mar 24, 2025
20a4456
Support XPU device for pipeline cases
PenghuiCheng Mar 24, 2025
a90a603
ported fsdp tests
daisyden Mar 24, 2025
5b1aff7
Merge branch 'distributed_2.8' of https://github.com/daisyden/pytorch…
daisyden Mar 24, 2025
44d55b9
fixed backend mapping error for register_backend function
PenghuiCheng Mar 26, 2025
7dade1f
Update distributed UT cases
PenghuiCheng Apr 1, 2025
580aaee
remove fsdp_kwargs in test_fsdp_memory.py to align with cuda, added r…
daisyden Apr 1, 2025
c0f5713
Merge branch 'upstream_main4' into distributed_2.8
daisyden Apr 1, 2025
6dedbe3
Add test_dynamo_distributed cases
PenghuiCheng Apr 1, 2025
20d074c
Merge remote-tracking branch 'upstream/distributed_2.8' into distribu…
PenghuiCheng Apr 1, 2025
124ff16
update test_tp_random_state.py
PenghuiCheng Apr 2, 2025
0bea112
Merge from main branch
PenghuiCheng Apr 7, 2025
7409ade
support xccl in with_comms
daisyden Apr 8, 2025
636cbff
Merge branch 'distributed_2.8' of https://github.com/daisyden/pytorch…
daisyden Apr 8, 2025
0d5a86b
Merge branch 'upstream_main3' into distributed_2.8
daisyden Apr 8, 2025
3826e30
Enabled UT in test/distributed/tensor
PenghuiCheng Apr 9, 2025
cb711b7
Merge branch 'distributed_2.8' of https://github.com/daisyden/pytorch…
PenghuiCheng Apr 9, 2025
d6cd1b3
refine fsdp2 test case for xpu
daisyden Apr 9, 2025
624be3a
Merge branch 'distributed_2.8' of https://github.com/daisyden/pytorch…
daisyden Apr 9, 2025
8d8c5fe
fix some issues in test case, cuda specific code, world_size 8, etc.
daisyden Apr 11, 2025
1cf7887
merge from main branch
PenghuiCheng Apr 16, 2025
41475ac
Merge remote-tracking branch 'upstream/distributed_2.8' into distribu…
PenghuiCheng Apr 16, 2025
0628c76
Change world size in test_device_mesh.py
PenghuiCheng Apr 18, 2025
b0d935d
Merge remote-tracking branch 'origin/distributed_2.8' into distribute…
PenghuiCheng Apr 18, 2025
58eb87e
Merge remote-tracking branch 'upstream/main' into distributed_2.8
PenghuiCheng Apr 22, 2025
e558eaa
Enabled some UT cases of distributed
PenghuiCheng Apr 24, 2025
83ac56e
enable UT case in _shard and _tool folder
PenghuiCheng Apr 29, 2025
0e7a7b6
Fixed hard code error for world_size 8
PenghuiCheng May 5, 2025
a2b2fc6
merge from main branch
PenghuiCheng May 7, 2025
8de00b9
fix regex
daisyden May 8, 2025
91f5d10
Merge branch 'distributed_2.8' of https://github.com/daisyden/pytorch…
daisyden May 8, 2025
39e6c02
Fixed UT errors for cuda hard code
PenghuiCheng May 15, 2025
06d6c3e
Merge remote-tracking branch 'origin/distributed_2.8' into distribute…
PenghuiCheng May 15, 2025
a059005
Merge from upstream main branch
PenghuiCheng May 16, 2025
31ddfc0
Fixed XPU UT error for CUDA hard code
PenghuiCheng May 21, 2025
9a6df8a
Merge remote-tracking branch 'upstream0523/main' into distributed_2.8
daisyden May 23, 2025
50cb9e9
fix fsdp2 issue after rebase, fix #1618 dynamo issue
daisyden May 26, 2025
08559b9
remove duplicated device_type
daisyden May 27, 2025
f6a8c6a
fix rebase issue of test_fully_shard_overlap.py
daisyden May 27, 2025
dc0befa
merge from main branch
PenghuiCheng May 30, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,13 @@
get_devtype,
MLP,
)
from torch.testing._internal.common_utils import run_tests
from torch.testing._internal.common_utils import run_tests, TEST_XPU
from torch.testing._internal.distributed._tensor.common_dtensor import (
ModelArgs,
Transformer,
)

device_type = torch.accelerator.current_accelerator().type

device_type = torch.device(get_devtype())

Expand Down
3 changes: 2 additions & 1 deletion test/distributed/_composable/fsdp/test_fully_shard_comm.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@
patch_reshard,
patch_unshard,
)
from torch.testing._internal.common_utils import run_tests
from torch.testing._internal.common_utils import run_tests, TEST_XPU
from torch.testing._internal.distributed._tensor.common_dtensor import (
ModelArgs,
Transformer,
Expand Down Expand Up @@ -221,6 +221,7 @@ def test_reduce_scatter_fp32(self):
reduce_scatter_dtype=torch.float32,
)


@skip_if_lt_x_gpu(1)
def test_reduce_scatter_fp16(self):
param_sizes = self._get_param_sizes()
Expand Down
3 changes: 3 additions & 0 deletions test/distributed/_composable/fsdp/test_fully_shard_compile.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,9 @@ class TestFullyShardCompile(FSDPTest):
def skipTestForOldSm(self):
# Assumption: This test class is only run on GPU. See `HAS_GPU` check at
# the top of the class.
# XPU is not applicable in this function
if device_type == 'xpu':
return
device = torch.device(
device_type.type,
self.rank % torch.get_device_module(device_type).device_count(),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,10 @@
get_devtype,
MLP,
)
from torch.testing._internal.common_utils import run_tests
from torch.testing._internal.common_utils import run_tests, TEST_XPU
from torch.testing._internal.two_tensor import TwoTensor

device_type = torch.accelerator.current_accelerator().type

device_type = torch.device(get_devtype())

Expand Down Expand Up @@ -259,7 +260,7 @@ class TestFullyShardAllGatherExtensionsMultiThread(
):
@property
def world_size(self) -> int:
return 8
return min(8, torch.accelerator.device_count())

@property
def device(self) -> torch.device:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
)
from torch.testing._internal.common_utils import run_tests

device_type = torch.accelerator.current_accelerator().type

device_type = torch.device(get_devtype())

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from torch.testing._internal.common_fsdp import FSDPTest, get_devtype, MLP
from torch.testing._internal.common_utils import run_tests

device_type = torch.accelerator.current_accelerator().type

device_type = torch.device(get_devtype())

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,6 @@
)
sys.exit(0)


class C(nn.Module):
def __init__(self, dim: int) -> None:
super().__init__()
Expand Down
2 changes: 1 addition & 1 deletion test/distributed/_composable/fsdp/test_fully_shard_init.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,6 @@
TransformerBlock,
)


device_type = torch.device(get_devtype())


Expand All @@ -62,6 +61,7 @@ def test_move_states_to_device_tensor(self):
for tensor in itertools.chain(model.parameters(), model.buffers()):
self.assertEqual(tensor.device, torch.device("cpu"))
fully_shard(model)

accelerator_device = torch.device(
device_type.type, torch.get_device_module(device_type).current_device()
)
Expand Down
10000
6 changes: 4 additions & 2 deletions test/distributed/_composable/fsdp/test_fully_shard_logging.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,13 @@
import torch.distributed as dist
from torch._dynamo.test_case import run_tests
from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
from torch.testing._internal.inductor_utils import HAS_CUDA
from torch.testing._internal.inductor_utils import HAS_CUDA, HAS_XPU
from torch.testing._internal.logging_utils import LoggingTestCase
import torch

device_type = torch.accelerator.current_accelerator().type

requires_cuda = unittest.skipUnless(HAS_CUDA, "requires cuda")
requires_gpu = unittest.skipUnless(HAS_CUDA or HAS_XPU, "requires cuda or xpu")
requires_distributed = functools.partial(
unittest.skipIf, not dist.is_available(), "requires distributed"
)
Expand Down
7 changes: 4 additions & 3 deletions test/distributed/_composable/fsdp/test_fully_shard_memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from torch.distributed.fsdp import CPUOffloadPolicy, fully_shard, OffloadPolicy
from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
from torch.testing._internal.common_fsdp import FSDPTest, get_devtype
from torch.testing._internal.common_utils import run_tests, TEST_CUDA, TEST_HPU
from torch.testing._internal.common_utils import run_tests, TEST_CUDA, TEST_HPU, TEST_XPU
from torch.testing._internal.distributed._tensor.common_dtensor import (
ModelArgs,
Transformer,
Expand Down Expand Up @@ -236,14 +236,15 @@ def test_fully_shard_del_memory(self):

def _get_peak_active_memory_mb(self) -> int:
mem_stats = torch.get_device_module(device_type).memory_stats()
if TEST_CUDA:

if TEST_CUDA or TEST_XPU:
return round(mem_stats["active_bytes.all.peak"] / 1e6)
if TEST_HPU:
return round(mem_stats["MaxInUse"] / 1e6)

def _get_curr_active_memory_mb(self) -> int:
mem_stats = torch.get_device_module(device_type).memory_stats()
if TEST_CUDA:
if TEST_CUDA or TEST_XPU:
return round(mem_stats["active_bytes.all.current"] / 1e6)
if TEST_HPU:
return round(mem_stats["InUse"] / 1e6)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from torch.distributed.tensor import Shard
from torch.testing._internal.common_distributed import (
requires_nccl_version,
requires_nccl_version_or,
SaveForwardInputsModel,
skip_if_lt_x_gpu,
)
Expand All @@ -32,6 +33,7 @@

device_type = torch.device(get_devtype())

device_type = torch.accelerator.current_accelerator().type

class TestFullyShardMixedPrecisionTraining(FSDPTest):
@property
Expand Down Expand Up @@ -87,7 +89,7 @@ def _get_use_shard_placement_fn_vals_for_bf16_reduce(self):

@skipIfRocm # regressed in ROCm 6.4, but ROCm 6.5 fixes it
@skip_if_lt_x_gpu(2)
@requires_nccl_version((2, 10), "Need NCCL 2.10+ for bf16 collectives")
@requires_nccl_version_or((2, 10), "Need NCCL 2.10+ for bf16 collectives", backends=['xccl',])
def test_compute_dtype(self):
use_shard_placement_fn_vals = (
self._get_use_shard_placement_fn_vals_for_bf16_reduce()
Expand Down Expand Up @@ -167,7 +169,7 @@ def assert_fn(output: torch.Tensor):

@skipIfRocm # regressed in ROCm 6.4, but ROCm 6.5 fixes it
@skip_if_lt_x_gpu(2)
@requires_nccl_version((2, 10), "Need NCCL 2.10+ for bf16 collectives")
@requires_nccl_version_or((2, 10), "Need NCCL 2.10+ for bf16 collectives", backends=['xccl',])
def test_reduce_dtype(self):
self.run_subtests(
{
Expand Down Expand Up @@ -500,7 +502,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
)

@skip_if_lt_x_gpu(1)
@requires_nccl_version((2, 10), "Need NCCL 2.10+ for bf16 collectives")
@requires_nccl_version_or((2, 10), "Need NCCL 2.10+ for bf16 collectives", backends=['xccl',])
def test_norm_modules_bf16(self):
mp_policy = MixedPrecisionPolicy(param_dtype=torch.bfloat16)
self._test_norm_modules(mp_policy)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,8 @@
patch_all_gather,
patch_reduce_scatter,
)
from torch.testing._internal.common_utils import get_cycles_per_ms, run_tests, TEST_HPU

from torch.testing._internal.common_utils import get_cycles_per_ms, run_tests, TEST_HPU, TEST_XPU


device_type = torch.device(get_devtype())
Expand Down Expand Up @@ -45,6 +46,7 @@ def world_size(self) -> int:

@skip_if_lt_x_gpu(2)
@unittest.skipIf(TEST_HPU, "Sleep is not supported on HPU")
@unittest.skipIf(TEST_XPU, "Sleep is not supported on XPU")
def test_fully_shard_training_overlap(self):
torch.manual_seed(42)

Expand All @@ -66,6 +68,7 @@ def test_fully_shard_training_overlap(self):
def delay_collective():
# Share a stream so that all-gather and reduce-scatter block each
# other like in `ProcessGroupNCCL`

comm_stream.wait_stream(
torch.get_device_module(device_type).current_stream()
)
Expand Down Expand Up @@ -158,6 +161,7 @@ def fwd_bwd():

@skip_if_lt_x_gpu(2)
@unittest.skipIf(TEST_HPU, "Sleep is not supported on HPU")
@unittest.skipIf(TEST_XPU, "Sleep is not supported on XPU")
def test_fully_shard_post_optim_event_overlap(self):
torch.manual_seed(42)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
TransformerBlock,
)

device_type = torch.accelerator.current_accelerator().type

device_type = torch.device(get_devtype())

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
from torch.testing._internal.common_utils import (
get_cycles_per_ms,
run_tests,
TEST_XPU,
TEST_HPU,
wrapSwapTensorsTest,
)
Expand All @@ -50,7 +51,6 @@
TransformerBlock,
)


c10d_ops = torch.ops.c10d
funcol = torch.ops.c10d_functional

Expand Down Expand Up @@ -315,6 +315,7 @@ def _shard_placement_fn(param: nn.Parameter) -> Optional[Shard]:

@skip_if_lt_x_gpu(2)
@unittest.skipIf(TEST_HPU, "Sleep kernel not supported for HPU")
@unittest.skipIf(TEST_XPU, "sleep kernel not supported on XPU")
@compiled_fsdp_test(compile_compute_on_module=Transformer)
def test_train_parity_multi_group(self):
"""
Expand All @@ -338,6 +339,7 @@ def test_train_parity_multi_group(self):

@skip_if_lt_x_gpu(2)
@unittest.skipIf(TEST_HPU, "sleep kernel not supported on HPU")
@unittest.skipIf(TEST_XPU, "sleep kernel not supported on XPU")
def test_train_parity_multi_group_cpu_offload_eager(self):
"""
Tests train parity against DDP when using multiple parameter groups for
Expand All @@ -362,6 +364,7 @@ def test_train_parity_multi_group_cpu_offload_eager(self):

@skip_if_lt_x_gpu(2)
@unittest.skipIf(TEST_HPU, "sleep kernel not supported on HPU")
@unittest.skipIf(TEST_XPU, "sleep kernel not supported on XPU")
@compiled_fsdp_test(compile_compute_on_module=Transformer)
def test_train_parity_multi_group_unshard_async_op(self):
"""
Expand Down Expand Up @@ -616,6 +619,7 @@ def test_explicit_prefetching(self):

@skip_if_lt_x_gpu(2)
@unittest.skipIf(TEST_HPU, "Sleep is not supported on HPU")
@unittest.skipIf(TEST_XPU, "Sleep is not supported on XPU")
def test_post_optim_event(self):
torch.manual_seed(42)
model_args = ModelArgs(dropout_p=0.0)
Expand Down
Loading
0