# Copyright (c) Meta Platforms, Inc. and affiliates
# Owner(s): ["oncall: distributed"]

+import contextlib
import copy
import functools
import unittest
@@ -879,9 +880,17 @@ class TestDTensorCompileE2E(DTensorTestBase):
    def world_size(self):
        return 4

+    # multiprocess relies on pickling the source code
+    # so compiled autograd tests can't dynamically wrap this class
+    def _bwd_ctx(self, use_ca):
+        if not use_ca:
+            return contextlib.nullcontext()
+        return torch._dynamo.compiled_autograd._enable(torch.compile)
+
    @with_comms
    @parametrize("is_seq_parallel", [True, False])
-    def test_tp_compile_fullgraph(self, is_seq_parallel):
+    @parametrize("use_ca", [True, False])
+    def test_tp_compile_fullgraph(self, is_seq_parallel, use_ca):
        mesh = DeviceMesh(self.device_type, torch.arange(self.world_size))

        model = SimpleModel(self.device_type)
@@ -935,13 +944,15 @@ def test_tp_compile_fullgraph(self, is_seq_parallel):
        cnt = torch._dynamo.testing.CompileCounterWithBackend("aot_eager")
        compiled_mod = torch.compile(model, backend=cnt, fullgraph=True)
        compiled_out = compiled_mod(inp)
-        compiled_out.sum().backward()
+        with self._bwd_ctx(use_ca):
+            compiled_out.sum().backward()
        self.assertEqual(compiled_out, out)
        self.assertEqual(cnt.frame_count, 1)

    @with_comms
    @skip_if_lt_x_gpu(4)
-    def test_2d_fsdp_tp_compile(self):
+    @parametrize("use_ca", [True, False])
+    def test_2d_fsdp_tp_compile(self, use_ca):
        data_parallel_size = 2
        model = SimpleModel(self.device_type)
        model_copy = copy.deepcopy(model)
@@ -984,13 +995,16 @@ def test_2d_fsdp_tp_compile(self):
        cnt = torch._dynamo.testing.CompileCounterWithBackend("aot_eager")
        compiled_2d = torch.compile(fsdp_2d, backend=cnt)
        compiled_output = compiled_2d(inp)
+        with self._bwd_ctx(use_ca):
+            compiled_output.sum().backward()

        self.assertEqual(out, compiled_output)
        self.assertEqual(cnt.frame_count, 1)

    @with_comms
    @skip_if_lt_x_gpu(4)
-    def test_2d_fsdp_tp_ac_compile(self):
+    @parametrize("use_ca", [True, False])
+    def test_2d_fsdp_tp_ac_compile(self, use_ca):
        dp_degree = 2
        tp_degree = self.world_size // dp_degree
        model = SimpleModel(self.device_type)
@@ -1033,15 +1047,17 @@ def test_2d_fsdp_tp_ac_compile(self):

        # backward pass
        out.sum().backward()
-        compiled_output.sum().backward()
+        with self._bwd_ctx(use_ca):
+            compiled_output.sum().backward()

        # compare the gradients:
        for n, p in zip(fsdp_2d.parameters(), compiled_2d.parameters()):
            self.assertEqual(n.grad, p.grad)

    @with_comms
    @skip_if_lt_x_gpu(4)
-    def test_compile_dtensor_redistribute_backward(self):
+    @parametrize("use_ca", [True, False])
+    def test_compile_dtensor_redistribute_backward(self, use_ca):
        mesh = DeviceMesh(device_type="cuda", mesh=torch.arange(self.world_size))

        def fn(x, y):
@@ -1065,7 +1081,8 @@ def fn(x, y):

        # Now run and assert the backward + gradients
        ref.sum().backward()
-        res.sum().backward()
+        with self._bwd_ctx(use_ca):
+            res.sum().backward()

        self.assertEqual(x_ref.grad, x.grad)
        self.assertEqual(y_ref.grad, y.grad)
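For reference, the pattern these tests exercise is to enable compiled autograd only around the backward call, which is what the new _bwd_ctx helper does. Below is a minimal standalone sketch of that pattern; the module, input, and the bwd_ctx function name are illustrative assumptions, while torch._dynamo.compiled_autograd._enable(torch.compile) is the same private API the diff itself uses:

import contextlib

import torch
import torch.nn as nn


def bwd_ctx(use_ca: bool):
    # use_ca=False: run the backward eagerly.
    # use_ca=True: let compiled autograd capture and compile the backward
    # graph (private API, mirroring the _bwd_ctx helper added above).
    if not use_ca:
        return contextlib.nullcontext()
    return torch._dynamo.compiled_autograd._enable(torch.compile)


model = nn.Linear(8, 8)          # stand-in for SimpleModel
inp = torch.randn(4, 8)
compiled_model = torch.compile(model, backend="aot_eager", fullgraph=True)
out = compiled_model(inp)
with bwd_ctx(use_ca=True):
    out.sum().backward()

Parametrizing over use_ca keeps a single test body while covering both the eager and compiled-autograd backward paths.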