
pytorch/pytorch · Commit 2327c9e

Revert "[ca][dtensor] run real PG dtensor tests under CA (#152689)"
This reverts commit b297e01. Reverted #152689 on behalf of https://github.com/malfet due to Looks like it breaks rocm, see https://hud.pytorch.org/hud/pytorch/pytorch/fa8543454ab5d3deda1d15c1f8d24e9ebe14f340/1?per_page=50&name_filter=slow%20%2F%20linux-jammy-rocm&mergeEphemeralLF=true ([comment](#153300 (comment)))
1 parent db26aea · commit 2327c9e

File tree

1 file changed (+7, -24 lines)

test/distributed/tensor/test_dtensor_compile.py

Lines changed: 7 additions & 24 deletions
@@ -1,7 +1,6 @@
 # Copyright (c) Meta Platforms, Inc. and affiliates
 # Owner(s): ["oncall: distributed"]
 
-import contextlib
 import copy
 import functools
 import unittest
@@ -880,17 +879,9 @@ class TestDTensorCompileE2E(DTensorTestBase):
     def world_size(self):
         return 4
 
-    # multiprocess relies on pickling the source code
-    # so compiled autograd tests can't dynamically wrap this class
-    def _bwd_ctx(self, use_ca):
-        if not use_ca:
-            return contextlib.nullcontext()
-        return torch._dynamo.compiled_autograd._enable(torch.compile)
-
     @with_comms
     @parametrize("is_seq_parallel", [True, False])
-    @parametrize("use_ca", [True, False])
-    def test_tp_compile_fullgraph(self, is_seq_parallel, use_ca):
+    def test_tp_compile_fullgraph(self, is_seq_parallel):
         mesh = DeviceMesh(self.device_type, torch.arange(self.world_size))
 
         model = SimpleModel(self.device_type)
@@ -944,15 +935,13 @@ def test_tp_compile_fullgraph(self, is_seq_parallel, use_ca):
         cnt = torch._dynamo.testing.CompileCounterWithBackend("aot_eager")
         compiled_mod = torch.compile(model, backend=cnt, fullgraph=True)
         compiled_out = compiled_mod(inp)
-        with self._bwd_ctx(use_ca):
-            compiled_out.sum().backward()
+        compiled_out.sum().backward()
         self.assertEqual(compiled_out, out)
         self.assertEqual(cnt.frame_count, 1)
 
     @with_comms
     @skip_if_lt_x_gpu(4)
-    @parametrize("use_ca", [True, False])
-    def test_2d_fsdp_tp_compile(self, use_ca):
+    def test_2d_fsdp_tp_compile(self):
         data_parallel_size = 2
         model = SimpleModel(self.device_type)
         model_copy = copy.deepcopy(model)
@@ -995,16 +984,13 @@ def test_2d_fsdp_tp_compile(self, use_ca):
         cnt = torch._dynamo.testing.CompileCounterWithBackend("aot_eager")
         compiled_2d = torch.compile(fsdp_2d, backend=cnt)
         compiled_output = compiled_2d(inp)
-        with self._bwd_ctx(use_ca):
-            compiled_output.sum().backward()
 
         self.assertEqual(out, compiled_output)
         self.assertEqual(cnt.frame_count, 1)
 
     @with_comms
     @skip_if_lt_x_gpu(4)
-    @parametrize("use_ca", [True, False])
-    def test_2d_fsdp_tp_ac_compile(self, use_ca):
+    def test_2d_fsdp_tp_ac_compile(self):
         dp_degree = 2
         tp_degree = self.world_size // dp_degree
         model = SimpleModel(self.device_type)
@@ -1047,17 +1033,15 @@ def test_2d_fsdp_tp_ac_compile(self, use_ca):
 
         # backward pass
         out.sum().backward()
-        with self._bwd_ctx(use_ca):
-            compiled_output.sum().backward()
+        compiled_output.sum().backward()
 
         # compare the gradients:
         for n, p in zip(fsdp_2d.parameters(), compiled_2d.parameters()):
             self.assertEqual(n.grad, p.grad)
 
     @with_comms
     @skip_if_lt_x_gpu(4)
-    @parametrize("use_ca", [True, False])
-    def test_compile_dtensor_redistribute_backward(self, use_ca):
+    def test_compile_dtensor_redistribute_backward(self):
         mesh = DeviceMesh(device_type="cuda", mesh=torch.arange(self.world_size))
 
         def fn(x, y):
@@ -1081,8 +1065,7 @@ def fn(x, y):
 
         # Now run and assert the backward + gradients
         ref.sum().backward()
-        with self._bwd_ctx(use_ca):
-            res.sum().backward()
+        res.sum().backward()
 
         self.assertEqual(x_ref.grad, x.grad)
         self.assertEqual(y_ref.grad, y.grad)
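
For reference, the `_bwd_ctx` helper removed above gated each test's backward pass so it could run either eagerly or under compiled autograd, depending on the `use_ca` parameter. Below is a minimal, standalone sketch of that pattern; the helper name `bwd_ctx` and the toy model are illustrative, not from the test file, and it assumes a PyTorch build that exposes the private `torch._dynamo.compiled_autograd._enable` API referenced in the deleted lines.

```python
# Sketch of the reverted pattern: optionally run backward under compiled autograd.
# Assumes torch._dynamo.compiled_autograd._enable exists, as in the deleted lines above.
import contextlib

import torch


def bwd_ctx(use_ca: bool):
    # use_ca=False: plain eager backward; use_ca=True: backward runs under
    # compiled autograd, with the autograd graph compiled via torch.compile.
    if not use_ca:
        return contextlib.nullcontext()
    return torch._dynamo.compiled_autograd._enable(torch.compile)


if __name__ == "__main__":
    model = torch.nn.Linear(4, 4)
    out = torch.compile(model)(torch.randn(2, 4))
    with bwd_ctx(use_ca=True):
        out.sum().backward()
```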

0 commit comments
