@@ -1,7 +1,6 @@
 # Copyright (c) Meta Platforms, Inc. and affiliates
 # Owner(s): ["oncall: distributed"]
 
-import contextlib
 import copy
 import functools
 import unittest
@@ -880,17 +879,9 @@ class TestDTensorCompileE2E(DTensorTestBase):
     def world_size(self):
         return 4
 
-    # multiprocess relies on pickling the source code
-    # so compiled autograd tests can't dynamically wrap this class
-    def _bwd_ctx(self, use_ca):
-        if not use_ca:
-            return contextlib.nullcontext()
-        return torch._dynamo.compiled_autograd._enable(torch.compile)
-
     @with_comms
     @parametrize("is_seq_parallel", [True, False])
-    @parametrize("use_ca", [True, False])
-    def test_tp_compile_fullgraph(self, is_seq_parallel, use_ca):
+    def test_tp_compile_fullgraph(self, is_seq_parallel):
         mesh = DeviceMesh(self.device_type, torch.arange(self.world_size))
 
         model = SimpleModel(self.device_type)
@@ -944,15 +935,13 @@ def test_tp_compile_fullgraph(self, is_seq_parallel, use_ca):
         cnt = torch._dynamo.testing.CompileCounterWithBackend("aot_eager")
         compiled_mod = torch.compile(model, backend=cnt, fullgraph=True)
         compiled_out = compiled_mod(inp)
-        with self._bwd_ctx(use_ca):
-            compiled_out.sum().backward()
+        compiled_out.sum().backward()
         self.assertEqual(compiled_out, out)
         self.assertEqual(cnt.frame_count, 1)
 
     @with_comms
     @skip_if_lt_x_gpu(4)
-    @parametrize("use_ca", [True, False])
-    def test_2d_fsdp_tp_compile(self, use_ca):
+    def test_2d_fsdp_tp_compile(self):
         data_parallel_size = 2
         model = SimpleModel(self.device_type)
         model_copy = copy.deepcopy(model)
@@ -995,16 +984,13 @@ def test_2d_fsdp_tp_compile(self, use_ca):
         cnt = torch._dynamo.testing.CompileCounterWithBackend("aot_eager")
         compiled_2d = torch.compile(fsdp_2d, backend=cnt)
         compiled_output = compiled_2d(inp)
-        with self._bwd_ctx(use_ca):
-            compiled_output.sum().backward()
 
         self.assertEqual(out, compiled_output)
         self.assertEqual(cnt.frame_count, 1)
 
     @with_comms
     @skip_if_lt_x_gpu(4)
-    @parametrize("use_ca", [True, False])
-    def test_2d_fsdp_tp_ac_compile(self, use_ca):
+    def test_2d_fsdp_tp_ac_compile(self):
         dp_degree = 2
         tp_degree = self.world_size // dp_degree
         model = SimpleModel(self.device_type)
@@ -1047,17 +1033,15 @@ def test_2d_fsdp_tp_ac_compile(self, use_ca):
 
         # backward pass
         out.sum().backward()
-        with self._bwd_ctx(use_ca):
-            compiled_output.sum().backward()
+        compiled_output.sum().backward()
 
         # compare the gradients:
         for n, p in zip(fsdp_2d.parameters(), compiled_2d.parameters()):
             self.assertEqual(n.grad, p.grad)
 
     @with_comms
     @skip_if_lt_x_gpu(4)
-    @parametrize("use_ca", [True, False])
-    def test_compile_dtensor_redistribute_backward(self, use_ca):
+    def test_compile_dtensor_redistribute_backward(self):
         mesh = DeviceMesh(device_type="cuda", mesh=torch.arange(self.world_size))
 
         def fn(x, y):
@@ -1081,8 +1065,7 @@ def fn(x, y):
 
         # Now run and assert the backward + gradients
         ref.sum().backward()
-        with self._bwd_ctx(use_ca):
-            res.sum().backward()
+        res.sum().backward()
 
         self.assertEqual(x_ref.grad, x.grad)
         self.assertEqual(y_ref.grad, y.grad)