Enable XPU distributed test for PT2.8 by daisyden · Pull Request #149916 · pytorch/pytorch · GitHub
Enable XPU distributed test for PT2.8 #149916


Closed
wants to merge 53 commits into from
Changes from 1 commit
Commits (53)
d0d8271
make skipXPU work
daisyden May 11, 2024
c791db9
enabled torch-xpu ops in op_db
daisyden May 13, 2024
f5cbd50
clean up code
daisyden May 13, 2024
4d94417
Revert "clean up code"
daisyden May 13, 2024
6844101
Revert "enabled torch-xpu ops in op_db"
daisyden May 13, 2024
5051e3c
Revert "make skipXPU work"
daisyden May 13, 2024
e2aa92a
merge common code update from https://github.com/Chao1Han/pytorch/pul…
chunhuanMeng Mar 19, 2025
9e83095
Merge branch 'main' of https://github.com/daisyden/pytorch into distr…
chunhuanMeng Mar 20, 2025
06dd2aa
merge common code update from https://github.com/Chao1Han/pytorch/pul…
daisyden Mar 20, 2025
a4a732b
Add XPU support for distributed
daisyden Mar 20, 2025
6e3f6b8
Merge branch 'distributed_2.8' of https://github.com/daisyden/pytorch…
daisyden Mar 20, 2025
5f47367
Merge remote-tracking branch 'upstream/main' into distributed_2.8
daisyden Mar 21, 2025
345d7e6
ported fsdp and _composable/fsdp cases
daisyden Mar 21, 2025
4a5a522
Support XPU device for DDP test cases
PenghuiCheng Mar 24, 2025
20a4456
Support XPU device for pipeline cases
PenghuiCheng Mar 24, 2025
a90a603
ported fsdp tests
daisyden Mar 24, 2025
5b1aff7
Merge branch 'distributed_2.8' of https://github.com/daisyden/pytorch…
daisyden Mar 24, 2025
44d55b9
fixed backend mapping error for register_backend function
PenghuiCheng Mar 26, 2025
7dade1f
Update distributed UT cases
PenghuiCheng Apr 1, 2025
580aaee
remove fsdp_kwargs in test_fsdp_memory.py to align with cuda, added r…
daisyden Apr 1, 2025
c0f5713
Merge branch 'upstream_main4' into distributed_2.8
daisyden Apr 1, 2025
6dedbe3
Add test_dynamo_distributed cases
PenghuiCheng Apr 1, 2025
20d074c
Merge remote-tracking branch 'upstream/distributed_2.8' into distribu…
PenghuiCheng Apr 1, 2025
124ff16
update test_tp_random_state.py
PenghuiCheng Apr 2, 2025
0bea112
Merge from main branch
PenghuiCheng Apr 7, 2025
7409ade
support xccl in with_comms
daisyden Apr 8, 2025
636cbff
Merge branch 'distributed_2.8' of https://github.com/daisyden/pytorch…
daisyden Apr 8, 2025
0d5a86b
Merge branch 'upstream_main3' into distributed_2.8
daisyden Apr 8, 2025
3826e30
Enabled UT in test/distributed/tensor
PenghuiCheng Apr 9, 2025
cb711b7
Merge branch 'distributed_2.8' of https://github.com/daisyden/pytorch…
PenghuiCheng Apr 9, 2025
d6cd1b3
refine fsdp2 test case for xpu
daisyden Apr 9, 2025
624be3a
Merge branch 'distributed_2.8' of https://github.com/daisyden/pytorch…
daisyden Apr 9, 2025
8d8c5fe
fix some issues in test case, cuda specific code, world_size 8, etc.
daisyden Apr 11, 2025
1cf7887
merge from main branch
PenghuiCheng Apr 16, 2025
41475ac
Merge remote-tracking branch 'upstream/distributed_2.8' into distribu…
PenghuiCheng Apr 16, 2025
0628c76
Change world size in test_device_mesh.py
PenghuiCheng Apr 18, 2025
b0d935d
Merge remote-tracking branch 'origin/distributed_2.8' into distribute…
PenghuiCheng Apr 18, 2025
58eb87e
Merge remote-tracking branch 'upstream/main' into distributed_2.8
PenghuiCheng Apr 22, 2025
e558eaa
Enabled some UT cases of distributed
PenghuiCheng Apr 24, 2025
83ac56e
enable UT case in _shard and _tool folder
PenghuiCheng Apr 29, 2025
0e7a7b6
Fixed hard code error for world_size 8
PenghuiCheng May 5, 2025
a2b2fc6
merge from main branch
PenghuiCheng May 7, 2025
8de00b9
fix regex
daisyden May 8, 2025
91f5d10
Merge branch 'distributed_2.8' of https://github.com/daisyden/pytorch…
daisyden May 8, 2025
39e6c02
Fixed UT errors for cuda hard code
PenghuiCheng May 15, 2025
06d6c3e
Merge remote-tracking branch 'origin/distributed_2.8' into distribute…
PenghuiCheng May 15, 2025
a059005
Merge from upstream main branch
PenghuiCheng May 16, 2025
31ddfc0
Fixed XPU UT error for CUDA hard code
PenghuiCheng May 21, 2025
9a6df8a
Merge remote-tracking branch 'upstream0523/main' into distributed_2.8
daisyden May 23, 2025
50cb9e9
fix fsdp2 issue after rebase, fix #1618 dynamo issue
daisyden May 26, 2025
08559b9
remove duplicated device_type
daisyden May 27, 2025
f6a8c6a
fix rebase issue of test_fully_shard_overlap.py
daisyden May 27, 2025
ported fsdp tests
daisyden committed Mar 24, 2025
commit a90a6038ff356090be39306780a0eea1f298e621
4 changes: 2 additions & 2 deletions test/distributed/fsdp/test_fsdp_comm_hooks.py
@@ -35,7 +35,7 @@

# bfloat16 is only supported by CUDA 11+ or XPU
BFLOAT16_AVAILABLE = ( torch.cuda.is_available() or torch.xpu.is_available() ) and (
torch.version.cuda is not None or torch.version.hip is not None
torch.version.cuda is not None or torch.version.hip is not None or torch.version.xpu is not None
)


@@ -402,7 +402,7 @@ def test_fp16_hook(
state, hook, sharding_strategy, torch.float16, has_wrapping
)

@requires_nccl()
@requires_nccl_or('xccl')
@requires_nccl_version((2, 10), "Need NCCL 2.10+ for BF16_COMPRESS")
@skip_but_pass_in_sandcastle_if(
not BFLOAT16_AVAILABLE,
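
Note: the two hunks above widen the bf16 availability check to XPU builds and swap the backend decorator from requires_nccl() to requires_nccl_or('xccl') so the test can run under either NCCL or XCCL. For reference only, here is a minimal standalone sketch of the same availability pattern — not code from the PR — assuming a build that exposes torch.version.xpu (XPU-enabled wheels); getattr guards builds that do not define it.

import torch

def bf16_collective_available() -> bool:
    # Sketch of the device-agnostic bf16 check used above.
    has_accelerator = torch.cuda.is_available() or torch.xpu.is_available()
    has_backend = (
        getattr(torch.version, "cuda", None) is not None
        or getattr(torch.version, "hip", None) is not None
        or getattr(torch.version, "xpu", None) is not None
    )
    return has_accelerator and has_backend
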
87 changes: 48 additions & 39 deletions test/distributed/fsdp/test_fsdp_misc.py
@@ -61,6 +61,7 @@
sys.exit(0)

device_type = torch.accelerator.current_accelerator().type

class MyModel(nn.Module):
def __init__(self) -> None:
super().__init__()
@@ -93,7 +94,7 @@ def test_fsdp_device_id(self, use_index):
without specifying a device ID (i.e. ``torch.device("cuda")``) warns
"""
dev_id = (
torch.cuda.current_device()
torch.accelerator.current_device_index()
if use_index
else torch.device(device_type, torch.accelerator.current_device_index())
)
@@ -197,7 +198,7 @@ def forward(self, x, y):

seed = self.rank + 20231010
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.get_device_module(device_type).manual_seed(seed)

losses = []
grads = []
@@ -207,7 +208,7 @@ def forward(self, x, y):
for model, opt in ((fsdp_model, fsdp_opt), (ddp_model, ddp_opt)):
seed = self.rank + i
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.get_device_module(device_type).manual_seed(seed)
loss = model(x, y).sum()
losses.append(loss)
loss.backward()
@@ -237,7 +238,7 @@ def forward(self, x, y):
for model, opt in ((fsdp_model, fsdp_opt), (ddp_model, ddp_opt)):
seed = self.rank + i
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.get_device_module(device_type).manual_seed(seed)
loss = model(x, y).sum()
losses.append(loss)
loss.backward()
@@ -272,7 +273,7 @@ def forward(self, x, y):
return out1

fsdp = FSDP(
MyModel().cuda(),
MyModel().to(device=device_type),
sharding_strategy=sharding_strategy,
auto_wrap_policy=always_wrap_policy,
)
@@ -336,7 +337,7 @@ def _check_equal(local, fsdp):
torch.testing.assert_close(p1, p2)

fsdp_ctor = functools.partial(FSDP, sharding_strategy=sharding_strategy)
m = MyModule().cuda()
m = MyModule().to(device=device_type)
m_local = deepcopy(m)
local_m = m_local
prev_params = [p.clone() for p in m_local.parameters()]
@@ -385,7 +386,7 @@ def _check_equal(local, fsdp):
@skip_if_lt_x_gpu(2)
def test_fsdp_optim_overlap_no_use_orig_params_error(self):
fsdp_overlap = FSDP(
MyModel().cuda(),
MyModel().to(device=device_type),
auto_wrap_policy=always_wrap_policy,
use_orig_params=False,
)
@@ -409,16 +410,16 @@ def test_fsdp_optimizer_overlap(self):
torch.manual_seed(0)
for cpu_offload in [True, False]:
offload = CPUOffload(offload_params=cpu_offload)
model = MyModel().cuda()
model = MyModel().to(device=device_type)
model_overlap = deepcopy(model)
fsdp = FSDP(
model.cuda(),
model.to(device=device_type),
auto_wrap_policy=always_wrap_policy,
use_orig_params=True,
cpu_offload=offload,
)
fsdp_overlap = FSDP(
model_overlap.cuda(),
model_overlap.to(device=device_type),
auto_wrap_policy=always_wrap_policy,
use_orig_params=True,
cpu_offload=offload,
@@ -546,7 +547,7 @@ def test_fsdp_cpu_init_stays_on_cpu(self):
"""Tests that passing a CPU module to FSDP preserves that the wrapped
module is on CPU after FSDP initialization, albeit after logging a
warning, and that FSDP moves CPU input to GPU before the forward."""
torch.cuda.set_device(self.rank)
torch.accelerator.set_device_index(self.rank)
regex = "passed-in `module` is on CPU"
context = self.assertWarnsRegex(
expected_warning=UserWarning, expected_regex=regex
@@ -561,7 +562,7 @@ def test_fsdp_cpu_init_stays_on_cpu(self):
devices = {p.device for p in fsdp_model.parameters()}
self.assertEqual(1, len(devices))
self.assertEqual(torch.device("cpu"), devices.pop())
fsdp_model = fsdp_model.cuda()
fsdp_model = fsdp_model.to(device=device_type)
# Ensure fwd + backward can be performed after moving to CUDA.
# CPU input also tests that input is correctly moved to appropriate
# CUDA device.
@@ -606,7 +607,7 @@ def init_nested_wrapped_module():
nested_wrapped_module,
self.process_group,
auto_wrap_policy=ModuleWrapPolicy({nn.Linear}),
device_id=torch.cuda.current_device(),
device_id=torch.accelerator.current_device_index(),
sync_module_states=True,
)
# Each rank's buffers should be 0s since rank 0 is the source, and they
@@ -683,7 +684,7 @@ def _test_device_id_auto_wrap(self, use_callable: bool):
auto_wrap_policy = ModuleWrapPolicy(module_classes)
fsdp_kwargs = {
"auto_wrap_policy": auto_wrap_policy,
"device_id": torch.cuda.current_device(),
"device_id": torch.accelerator.current_device(),
}
fsdp_model = TransformerWithSharedParams.init(
self.process_group,
@@ -729,7 +730,7 @@ def forward(self, x):
model,
auto_wrap_policy=auto_wrap_policy,
cpu_offload=CPUOffload(offload_params=True),
device_id=torch.cuda.current_device(),
device_id=torch.accelerator.current_device_index(),
use_orig_params=use_orig_params,
)
cpu_device = torch.device("cpu")
@@ -742,12 +743,20 @@ def test_module_device_mismatches_device_id(self):
module that does not match the GPU device ID raises an error."""
# TODO: override FSDP MT Thread _run to set this instead of here for
# every test.
torch.cuda.set_device(self.rank)
context = (
self.assertRaisesRegex(ValueError, f"cuda:{self.rank} vs cuda:0")
if self.rank != 0
else nullcontext()
)
torch.accelerator.set_device_index(self.rank)

if device_type == "xpu":
context = (
self.assertRaisesRegex(ValueError, f"xpu:{self.rank} vs xpu:0")
if self.rank != 0
else nullcontext()
)
else:
context = (
self.assertRaisesRegex(ValueError, f"cuda:{self.rank} vs cuda:0")
if self.rank != 0
else nullcontext()
)
with context:
NestedWrappedModule.init(
self.process_group,
@@ -764,26 +773,26 @@ def test_cpu_gpu_module(self):
"""Tests a CPU + GPU module supported if device_id is passed
in, errors if device_id is not.
"""
torch.cuda.set_device(self.rank)
torch.accelerator.set_device_index(self.rank)

class CPUGPUModule(nn.Module):
def __init__(self) -> None:
super().__init__()
self.a = nn.Linear(1, 1).cuda()
self.a = nn.Linear(1, 1).to(device=device_type)
self.b = nn.Linear(1, 1)

cpu_gpu = CPUGPUModule()
fsdp = FSDP(cpu_gpu, device_id=torch.cuda.current_device())
fsdp = FSDP(cpu_gpu, device_id=torch.accelerator.current_device_index())
for param in fsdp.parameters():
self.assertEqual(param.device, torch.device(torch.cuda.current_device()))
self.assertEqual(param.device, torch.device(torch.accelerator.current_device_index()))

# without device_id, we hit an error
with self.assertRaisesRegex(RuntimeError, "please pass in device_id"):
FSDP(CPUGPUModule())

@skip_if_lt_x_gpu(2)
def test_fsdp_ignored_module_meta(self):
torch.cuda.set_device(self.rank)
torch.accelerator.set_device_index(self.rank)

class CPUGPUModule(nn.Module):
def __init__(self) -> None:
@@ -802,11 +811,11 @@ def __init__(self) -> None:
m = CPUGPUModule()
m = FSDP(
m,
device_id=torch.cuda.current_device(),
device_id=torch.accelerator.current_device_index(),
ignored_modules=[m.a],
use_orig_params=True,
param_init_fn=lambda m: m.to_empty(
device=torch.cuda.current_device(), recurse=False
device=torch.accelerator.current_device_index(), recurse=False
),
)
self.assertEqual(meta_device, next(m.a.parameters()).device)
@@ -854,20 +863,20 @@ def test_no_params(self):
"""
# TODO: override FSDP MT Thread _run to set this instead of here for
# every test.
torch.cuda.set_device(self.rank)
torch.accelerator.set_device_index(self.rank)
# Test CPU
no_params = nn.ReLU()
FSDP(no_params)
# Test CUDA
no_params = nn.ReLU().cuda()
no_params = nn.ReLU().to(device=device_type)
FSDP(no_params)
# Test CPU + device_id
no_params = nn.ReLU()
FSDP(no_params, device_id=torch.cuda.current_device())
FSDP(no_params, device_id=torch.accelerator.current_device_index())
# For modules with no params, wrong device_id will raise error about
# inconsistency between compute_device and device_id, since compute_device
# is computed as torch.cuda.current_device when there are no params.
no_params = nn.ReLU().cuda()
no_params = nn.ReLU().to(device=device_type)
context = (
(
self.assertRaisesRegex(
@@ -892,11 +901,11 @@ def __init__(self, rank):
super().__init__()
# Seed via rank to make model different across ranks
torch.manual_seed(rank)
torch.cuda.manual_seed(rank)
torch.get_device_module(device_type).manual_seed(rank)
self.lin = nn.Linear(10, 10, bias=False)
self.buffer = nn.Buffer(torch.ones(1) * rank)

m = MyModel(self.rank).cuda()
m = MyModel(self.rank).to(device=device_type)
_assert_module_states(
m, process_group=self.process_group, assert_fn=self.assertNotEqual
)
@@ -913,7 +922,7 @@ def __init__(self, rank):
m, process_group=self.process_group, assert_fn=self.assertNotEqual
)
# Passing sync_module_states into FSDP makes model the same during init.
fsdp = FSDP(m, device_id=torch.cuda.current_device(), sync_module_states=True)
fsdp = FSDP(m, device_id=torch.accelerator.current_device_index(), sync_module_states=True)
with fsdp.summon_full_params(fsdp):
_assert_module_states(
fsdp, process_group=self.process_group, assert_fn=self.assertEqual
@@ -1000,7 +1009,7 @@ def test_world_size_1_sharding_strategy_warning(self):
# warning
with warnings.catch_warnings(record=True) as w:
warnings.simplefilter("always") # trigger all warnings
FSDP(nn.Linear(3, 3).cuda(), sharding_strategy=ShardingStrategy.NO_SHARD)
FSDP(nn.Linear(3, 3).to(device=device_type), sharding_strategy=ShardingStrategy.NO_SHARD)
for warning in w:
self.assertTrue(
warning.category != UserWarning
@@ -1014,16 +1023,16 @@ def test_world_size_1_sharding_strategy_warning(self):
warning_prefix + " " + str(ShardingStrategy.FULL_SHARD) + warning_suffix
)
with self.assertWarnsRegex(UserWarning, expected_regex_full_shard):
FSDP(nn.Linear(3, 3).cuda(), sharding_strategy=ShardingStrategy.FULL_SHARD)
FSDP(nn.Linear(3, 3).to(device=device_type), sharding_strategy=ShardingStrategy.FULL_SHARD)
with self.assertWarnsRegex(UserWarning, expected_regex_full_shard):
FSDP(nn.Linear(3, 3).cuda())
FSDP(nn.Linear(3, 3).to(device=device_type))
# - Pass `SHARD_GRAD_OP`
expected_regex_shard_grad_op = (
warning_prefix + " " + str(ShardingStrategy.SHARD_GRAD_OP) + warning_suffix
)
with self.assertWarnsRegex(UserWarning, expected_regex_shard_grad_op):
FSDP(
nn.Linear(3, 3).cuda(), sharding_strategy=ShardingStrategy.SHARD_GRAD_OP
nn.Linear(3, 3).to(device=device_type), sharding_strategy=ShardingStrategy.SHARD_GRAD_OP
)

@skip_if_lt_x_gpu(1)
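
Note: the test_fsdp_misc.py changes above all follow one substitution pattern — CUDA-specific calls are replaced by the device-neutral torch.accelerator / torch.get_device_module APIs keyed off device_type. A minimal illustrative sketch of that mapping (not the test code itself), assuming PyTorch 2.6+ with at least one accelerator visible:

import torch
import torch.nn as nn

device_type = torch.accelerator.current_accelerator().type    # e.g. "cuda" or "xpu"
torch.accelerator.set_device_index(0)                          # replaces torch.cuda.set_device(0)
dev_index = torch.accelerator.current_device_index()           # replaces torch.cuda.current_device()
torch.get_device_module(device_type).manual_seed(1234)         # replaces torch.cuda.manual_seed(1234)
model = nn.Linear(3, 3).to(device=device_type)                 # replaces nn.Linear(3, 3).cuda()
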
12 changes: 9 additions & 3 deletions test/distributed/fsdp/test_fsdp_overlap.py
@@ -9,7 +9,7 @@
import torch
import torch.nn as nn
from torch import distributed as dist
from torch.cuda import Event

from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.testing._internal.common_device_type import instantiate_device_type_tests
from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
@@ -34,6 +34,10 @@
sys.exit(0)

device_type = torch.accelerator.current_accelerator().type
if device_type == "xpu":
from torch.xpu import Event
else:
from torch.cuda import Event

class Layer(nn.Module):
def __init__(self, compute_cycles, has_params: bool):
@@ -51,7 +55,8 @@ def forward(self, x):
# Record the fake forward compute time.
self.e1.record()
if self.sleep_cycles > 0:
torch.get_device_module(device_type)._sleep(self.sleep_cycles)
if torch.cuda.is_available():
torch.cuda._sleep(self.sleep_cycles)
if self.optional_param is not None:
x = x + self.optional_param # force the param to be part of the graph
self.e2.record()
@@ -138,7 +143,8 @@ def run(compute_cycles, all_gather_cycles):
def _delayed_all_gather(*args, **kwargs):
nonlocal all_gather_called
all_gather_called = True
torch.get_device_module(device_type)._sleep(all_gather_cycles)
if torch.cuda.is_available():
torch.cuda._sleep(all_gather_cycles)
assert orig_all_gather
return orig_all_gather(*args, **kwargs)

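
Note: the Event import switch above is needed because the timing events live in per-backend modules, and the delay stays behind torch.cuda.is_available() because torch.cuda._sleep is a private CUDA-only busy-wait helper with no XPU counterpart in this diff. A rough usage sketch of the resulting event-based timing — an assumption-laden example, not the test code, and it presumes the selected backend's Event supports record/elapsed_time:

import torch

device_type = torch.accelerator.current_accelerator().type
Event = torch.xpu.Event if device_type == "xpu" else torch.cuda.Event

start, end = Event(enable_timing=True), Event(enable_timing=True)
start.record()
if torch.cuda.is_available():
    torch.cuda._sleep(1_000_000)   # private CUDA-only helper; skipped on XPU
end.record()
torch.accelerator.synchronize()
print(start.elapsed_time(end))     # elapsed milliseconds, where supported
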
2 changes: 1 addition & 1 deletion test/distributed/fsdp/test_fsdp_pure_fp16.py
@@ -152,6 +152,6 @@ def _test_fp16_dtypes(


devices = ("cuda", "hpu", "xpu")
instantiate_device_type_tests(TestPureFP16, globals(), only_for=devices)
instantiate_device_type_tests(TestPureFP16, globals(), only_for=devices, allow_xpu=True)
if __name__ == "__main__":
run_tests()
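
Note: in recent PyTorch, instantiate_device_type_tests only generates XPU variants when allow_xpu=True is passed; listing "xpu" in only_for alone is not enough, which is why the line above adds the flag. A minimal usage sketch with a hypothetical test class (not the file above):

import torch
from torch.testing._internal.common_device_type import instantiate_device_type_tests
from torch.testing._internal.common_utils import run_tests, TestCase

class TestDeviceSmoke(TestCase):
    def test_add(self, device):
        x = torch.ones(2, device=device)
        self.assertEqual((x + x).sum().item(), 4.0)

devices = ("cuda", "hpu", "xpu")
# allow_xpu=True lets the XPU variants be generated alongside the CUDA/HPU ones.
instantiate_device_type_tests(TestDeviceSmoke, globals(), only_for=devices, allow_xpu=True)

if __name__ == "__main__":
    run_tests()
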