[ca][dtensor] run real PG dtensor tests under CA (#152689) · pytorch/pytorch@b297e01 · GitHub


Commit b297e01

xmfan authored and pytorchmergebot committed
[ca][dtensor] run real PG dtensor tests under CA (#152689)
Pull Request resolved: #152689
Approved by: https://github.com/bdhirsh
ghstack dependencies: #153300
1 parent 4863e5c commit b297e01

File tree: 1 file changed, +24 -7 lines


test/distributed/tensor/test_dtensor_compile.py

Lines changed: 24 additions & 7 deletions
@@ -1,6 +1,7 @@
 # Copyright (c) Meta Platforms, Inc. and affiliates
 # Owner(s): ["oncall: distributed"]
 
+import contextlib
 import copy
 import functools
 import unittest
@@ -879,9 +880,17 @@ class TestDTensorCompileE2E(DTensorTestBase):
     def world_size(self):
         return 4
 
+    # multiprocess relies on pickling the source code
+    # so compiled autograd tests can't dynamically wrap this class
+    def _bwd_ctx(self, use_ca):
+        if not use_ca:
+            return contextlib.nullcontext()
+        return torch._dynamo.compiled_autograd._enable(torch.compile)
+
     @with_comms
     @parametrize("is_seq_parallel", [True, False])
-    def test_tp_compile_fullgraph(self, is_seq_parallel):
+    @parametrize("use_ca", [True, False])
+    def test_tp_compile_fullgraph(self, is_seq_parallel, use_ca):
         mesh = DeviceMesh(self.device_type, torch.arange(self.world_size))
 
         model = SimpleModel(self.device_type)
@@ -935,13 +944,15 @@ def test_tp_compile_fullgraph(self, is_seq_parallel):
         cnt = torch._dynamo.testing.CompileCounterWithBackend("aot_eager")
         compiled_mod = torch.compile(model, backend=cnt, fullgraph=True)
         compiled_out = compiled_mod(inp)
-        compiled_out.sum().backward()
+        with self._bwd_ctx(use_ca):
+            compiled_out.sum().backward()
         self.assertEqual(compiled_out, out)
         self.assertEqual(cnt.frame_count, 1)
 
     @with_comms
     @skip_if_lt_x_gpu(4)
-    def test_2d_fsdp_tp_compile(self):
+    @parametrize("use_ca", [True, False])
+    def test_2d_fsdp_tp_compile(self, use_ca):
         data_parallel_size = 2
         model = SimpleModel(self.device_type)
         model_copy = copy.deepcopy(model)
@@ -984,13 +995,16 @@ def test_2d_fsdp_tp_compile(self):
         cnt = torch._dynamo.testing.CompileCounterWithBackend("aot_eager")
         compiled_2d = torch.compile(fsdp_2d, backend=cnt)
         compiled_output = compiled_2d(inp)
+        with self._bwd_ctx(use_ca):
+            compiled_output.sum().backward()
 
         self.assertEqual(out, compiled_output)
         self.assertEqual(cnt.frame_count, 1)
 
     @with_comms
     @skip_if_lt_x_gpu(4)
-    def test_2d_fsdp_tp_ac_compile(self):
+    @parametrize("use_ca", [True, False])
+    def test_2d_fsdp_tp_ac_compile(self, use_ca):
         dp_degree = 2
         tp_degree = self.world_size // dp_degree
         model = SimpleModel(self.device_type)
@@ -1033,15 +1047,17 @@ def test_2d_fsdp_tp_ac_compile(self):
 
         # backward pass
         out.sum().backward()
-        compiled_output.sum().backward()
+        with self._bwd_ctx(use_ca):
+            compiled_output.sum().backward()
 
         # compare the gradients:
         for n, p in zip(fsdp_2d.parameters(), compiled_2d.parameters()):
             self.assertEqual(n.grad, p.grad)
 
     @with_comms
     @skip_if_lt_x_gpu(4)
-    def test_compile_dtensor_redistribute_backward(self):
+    @parametrize("use_ca", [True, False])
+    def test_compile_dtensor_redistribute_backward(self, use_ca):
         mesh = DeviceMesh(device_type="cuda", mesh=torch.arange(self.world_size))
 
         def fn(x, y):
@@ -1065,7 +1081,8 @@ def fn(x, y):
 
         # Now run and assert the backward + gradients
         ref.sum().backward()
-        res.sum().backward()
+        with self._bwd_ctx(use_ca):
+            res.sum().backward()
 
         self.assertEqual(x_ref.grad, x.grad)
         self.assertEqual(y_ref.grad, y.grad)
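
For reference, the toggle this commit adds can be illustrated outside the multiprocess distributed harness. The sketch below is not part of the commit: it mirrors the `_bwd_ctx` helper from the diff, wrapping only the backward call in `torch._dynamo.compiled_autograd._enable(torch.compile)` (an internal, version-dependent API taken from the diff) when `use_ca` is true, and in `contextlib.nullcontext()` otherwise. The `TinyModel` module and `run_backward` helper are hypothetical names used only for this illustration.

import contextlib

import torch
import torch.nn as nn


class TinyModel(nn.Module):  # hypothetical stand-in for the test suite's SimpleModel
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(8, 8)

    def forward(self, x):
        return self.proj(x).relu()


def _bwd_ctx(use_ca):
    # Same toggle as the helper added in the commit: a no-op context when
    # compiled autograd is off, otherwise compile the backward graph as well.
    if not use_ca:
        return contextlib.nullcontext()
    return torch._dynamo.compiled_autograd._enable(torch.compile)


def run_backward(model, inp, use_ca):
    # Compile the forward, then run the backward under the chosen context.
    model.zero_grad(set_to_none=True)
    compiled = torch.compile(model, backend="aot_eager", fullgraph=True)
    out = compiled(inp)
    with _bwd_ctx(use_ca):
        out.sum().backward()
    return model.proj.weight.grad.clone()


if __name__ == "__main__":
    torch.manual_seed(0)
    model = TinyModel()
    inp = torch.randn(4, 8)
    eager_grad = run_backward(model, inp, use_ca=False)
    ca_grad = run_backward(model, inp, use_ca=True)
    # Gradients should match whether or not compiled autograd ran the backward,
    # which is what the parametrized tests above assert for the DTensor cases.
    print(torch.allclose(eager_grad, ca_grad))

Keeping `_bwd_ctx` as a plain method on the test class, rather than dynamically wrapping the class, matters in the real suite because, as the added comment notes, the multiprocess harness relies on pickling the source code.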
