File tree Expand file tree Collapse file tree 3 files changed +43
-1
lines changed Expand file tree Collapse file tree 3 files changed +43
-1
lines changed Original file line number Diff line number Diff line change @@ -192,6 +192,47 @@ def fn(x, y):
192192 ],
193193 )
194194
195+ def test_profiler_enabled (self ):
196+ def fn (x ):
197+ x = torch .sin (x )
198+ if torch .autograd ._profiler_enabled ():
199+ return torch .cos (x )
200+ else :
201+ return torch .sigmoid (x )
202+
203+ opt_fn = torch .compile (fn , backend = "eager" , fullgraph = True )
204+ x = torch .randn (4 )
205+
206+ ref = fn (x )
207+ res = opt_fn (x )
208+ self .assertEqual (ref , res )
209+
210+ with torch .autograd .profiler .profile ():
211+ ref = fn (x )
212+ res = opt_fn (x )
213+ self .assertEqual (ref , res )
214+
215+ def test_profiler_record_function_ignore (self ):
216+ def fn (x ):
217+ x = torch .sin (x )
218+ if torch .autograd ._profiler_enabled ():
219+ with torch .autograd .profiler .record_function ("dummy" ):
220+ return torch .cos (x )
221+ else :
222+ return torch .sigmoid (x )
223+
224+ opt_fn = torch .compile (fn , backend = "eager" , fullgraph = True )
225+ x = torch .randn (4 )
226+
227+ ref = fn (x )
228+ res = opt_fn (x )
229+ self .assertEqual (ref , res )
230+
231+ with torch .autograd .profiler .profile ():
232+ ref = fn (x )
233+ res = opt_fn (x )
234+ self .assertEqual (ref , res )
235+
195236
196237if __name__ == "__main__" :
197238 from torch ._dynamo .test_case import run_tests
Original file line number Diff line number Diff line change 176176 "torch.compiler.is_compiling" : TorchInGraphFunctionVariable ,
177177 "torch.compiler.is_dynamo_compiling" : TorchInGraphFunctionVariable ,
178178 "torch.compiler.is_exporting" : TorchInGraphFunctionVariable ,
179- "torch.autograd._profiler_enabled" : SkipFunctionVariable ,
180179 "torch._C._to_dlpack" : SkipFunctionVariable ,
181180 "torch.to_dlpack" : SkipFunctionVariable ,
182181 # We graph break on RNG state setters or getters like
24342433 "torch.atleast_3d" ,
24352434 "torch.autograd._calculate_shape" ,
24362435 "torch.autograd._is_checkpoint_valid" ,
2436+ "torch.autograd._profiler_enabled" ,
24372437 "torch.autograd._make_grads" ,
24382438 "torch.autograd._register_py_tensor_class_for_device" ,
24392439 "torch.autograd._tensor_or_tensors_to_tuple" ,
Original file line number Diff line number Diff line change 142142 torch .cuda .is_initialized ,
143143 torch .xpu .current_device ,
144144 torch .xpu .is_initialized ,
145+ torch .autograd ._profiler_enabled ,
145146]
146147
147148constant_fold_functions = [
You can’t perform that action at this time.
0 commit comments