@@ -132,7 +132,10 @@ THPPyInterpreterFrame* THPPyInterpreterFrame_New(_PyInterpreterFrame* frame) {
132
132
#else
133
133
#define THP_EVAL_API_FRAME_OBJECT PyFrameObject
134
134
135
- #define THP_PyFrame_FastToLocalsWithError PyFrame_FastToLocalsWithError
135
+ static int
136
+ THP_PyFrame_FastToLocalsWithError (THP_EVAL_API_FRAME_OBJECT * frame , int * free_vars_copied ) {
137
+ return PyFrame_FastToLocalsWithError (frame );
138
+ }
136
139
#endif
137
140
138
141
PyObject * guard_error_hook = NULL ;
@@ -161,7 +164,8 @@ static PyObject* _custom_eval_frame(
161
164
PyThreadState * tstate ,
162
165
THP_EVAL_API_FRAME_OBJECT * frame ,
163
166
int throw_flag ,
164
- PyObject * callback );
167
+ PyObject * callback ,
168
+ int * should_clear_frame );
165
169
static PyObject * (* previous_eval_frame )(PyThreadState * tstate ,
166
170
THP_EVAL_API_FRAME_OBJECT * frame , int throw_flag ) = NULL ;
167
171
@@ -283,7 +287,8 @@ inline static PyObject* eval_custom_code_impl(
283
287
PyThreadState * tstate ,
284
288
THP_EVAL_API_FRAME_OBJECT * frame ,
285
289
PyCodeObject * code ,
286
- int throw_flag ) {
290
+ int throw_flag ,
291
+ int free_vars_copied ) {
287
292
288
293
DEBUG_NULL_CHECK (tstate );
289
294
DEBUG_NULL_CHECK (frame );
@@ -333,6 +338,13 @@ inline static PyObject* eval_custom_code_impl(
333
338
}
334
339
#endif
335
340
341
+ // for 3.11+, if free_vars_copied is true, we do not need to
342
+ // run the first COPY_FREE_VARS since THP_PyFrame_FastToLocalsWithError
343
+ // already did the equivalent action.
344
+ if (free_vars_copied ) {
345
+ shadow -> prev_instr = _PyCode_CODE (shadow -> f_code );
346
+ }
347
+
336
348
#else
337
349
338
350
THP_EVAL_API_FRAME_OBJECT * shadow = PyFrame_New (tstate , code , frame -> f_globals , NULL );
@@ -432,10 +444,7 @@ inline static PyObject* eval_custom_code_impl(
432
444
433
445
#if IS_PYTHON_3_12_PLUS
434
446
435
- // In 3.12, the frame evaluation function is responsible for
436
- // clearing and popping the frame, so we manually do that on the
437
- // old frame.
438
- clear_old_frame_if_python_312_plus (tstate , frame );
447
+ // frame is cleared by caller
439
448
Py_DECREF (func );
440
449
441
450
#elif IS_PYTHON_3_11_PLUS
@@ -460,13 +469,15 @@ inline static PyObject* eval_custom_code(
460
469
PyThreadState * tstate ,
461
470
THP_EVAL_API_FRAME_OBJECT * frame ,
462
471
PyCodeObject * code ,
463
- int throw_flag ) {
472
+ int throw_flag ,
473
+ int free_vars_copied ) {
464
474
_PytorchRecordFunctionState * rf = _pytorch_record_function_enter ("Torch-Compiled Region" );
465
475
PyObject * result = eval_custom_code_impl (
466
476
tstate ,
467
477
frame ,
468
478
code ,
469
- throw_flag
479
+ throw_flag ,
480
+ free_vars_copied
470
481
);
471
482
_pytorch_record_function_exit (rf );
472
483
return result ;
@@ -487,7 +498,12 @@ static PyObject* _custom_eval_frame_shim(
487
498
return eval_frame_default (tstate , frame , throw_flag );
488
499
}
489
500
490
- return _custom_eval_frame (tstate , frame , throw_flag , callback );
501
+ int should_clear_frame = 0 ;
502
+ PyObject * result = _custom_eval_frame (tstate , frame , throw_flag , callback , & should_clear_frame );
503
+ if (should_clear_frame ) {
504
+ clear_old_frame_if_python_312_plus (tstate , frame );
505
+ }
506
+ return result ;
491
507
}
492
508
493
509
// NOTE: In 3.12+, any return NULL; statements must be preceded by
@@ -498,7 +514,8 @@ static PyObject* _custom_eval_frame(
498
514
PyThreadState * tstate ,
499
515
THP_EVAL_API_FRAME_OBJECT * frame ,
500
516
int throw_flag ,
501
- PyObject * callback ) {
517
+ PyObject * callback ,
518
+ int * should_clear_frame ) {
502
519
#if IS_PYTHON_3_11_PLUS
503
520
DEBUG_TRACE (
504
521
"begin %s %s %i %i" ,
@@ -552,9 +569,10 @@ static PyObject* _custom_eval_frame(
552
569
}
553
570
554
571
// TODO(jansel): investigate directly using the "fast" representation
555
- if (THP_PyFrame_FastToLocalsWithError (frame ) < 0 ) {
572
+ int free_vars_copied = 0 ;
573
+ if (THP_PyFrame_FastToLocalsWithError (frame , & free_vars_copied ) < 0 ) {
556
574
DEBUG_TRACE ("error %s" , get_frame_name (frame ));
557
- clear_old_frame_if_python_312_plus ( tstate , frame ) ;
575
+ * should_clear_frame = 1 ;
558
576
return NULL ;
559
577
}
560
578
@@ -570,7 +588,7 @@ static PyObject* _custom_eval_frame(
570
588
571
589
if (maybe_cached_code == NULL ) {
572
590
// guard eval failed, keep propagating
573
- clear_old_frame_if_python_312_plus ( tstate , frame ) ;
591
+ * should_clear_frame = 1 ;
574
592
return NULL ;
575
593
} else if (maybe_cached_code == Py_None ) {
576
594
DEBUG_TRACE ("cache miss %s" , get_frame_name (frame ));
@@ -579,7 +597,8 @@ static PyObject* _custom_eval_frame(
579
597
PyCodeObject * cached_code = (PyCodeObject * )maybe_cached_code ;
580
598
// used cached version
581
599
DEBUG_TRACE ("cache hit %s" , get_frame_name (frame ));
582
- return eval_custom_code (tstate , frame , cached_code , throw_flag );
600
+ * should_clear_frame = 1 ;
601
+ return eval_custom_code (tstate , frame , cached_code , throw_flag , free_vars_copied );
583
602
}
584
603
DEBUG_CHECK (PyDict_CheckExact (frame -> f_locals ));
585
604
DEBUG_CHECK (PyDict_CheckExact (frame -> f_globals ));
@@ -595,15 +614,16 @@ static PyObject* _custom_eval_frame(
595
614
_pytorch_record_function_exit (rf );
596
615
if (maybe_cached_code == NULL ) {
597
616
// Python error
598
- clear_old_frame_if_python_312_plus ( tstate , frame ) ;
617
+ * should_clear_frame = 1 ;
599
618
return NULL ;
600
619
} else if (maybe_cached_code != Py_None ) {
601
620
PyCodeObject * cached_code = (PyCodeObject * )maybe_cached_code ;
602
621
// used cached version
603
622
DEBUG_TRACE ("cache hit %s" , get_frame_name (frame ));
604
623
// Re-enable custom behavior
605
624
eval_frame_callback_set (callback );
606
- return eval_custom_code (tstate , frame , cached_code , throw_flag );
625
+ * should_clear_frame = 1 ;
626
+ return eval_custom_code (tstate , frame , cached_code , throw_flag , free_vars_copied );
607
627
}
608
628
// cache miss
609
629
CacheEntry * cache_entry = extract_cache_entry (extra );
@@ -618,7 +638,7 @@ static PyObject* _custom_eval_frame(
618
638
// cascading failure from internal exceptions. The upshot is if
619
639
// Dynamo barfs, that's it for Dynamo, even if you catch the exception
620
640
// inside the torch.compile block we won't try to Dynamo anything else.
621
- clear_old_frame_if_python_312_plus ( tstate , frame ) ;
641
+ * should_clear_frame = 1 ;
622
642
return NULL ;
623
643
} else if (result != Py_None ) {
624
644
DEBUG_TRACE ("create cache %s" , get_frame_name (frame ));
@@ -636,7 +656,8 @@ static PyObject* _custom_eval_frame(
636
656
// will be cleaned up when set_extra_state is called.
637
657
// Re-enable custom behavior
638
658
eval_frame_callback_set (callback );
639
- return eval_custom_code (tstate , frame , CacheEntry_get_code (new_cache_entry ), throw_flag );
659
+ * should_clear_frame = 1 ;
660
+ return eval_custom_code (tstate , frame , CacheEntry_get_code (new_cache_entry ), throw_flag , free_vars_copied );
640
661
} else {
641
662
DEBUG_TRACE ("create skip %s" , get_frame_name (frame ));
642
663
Py_DECREF (result );
0 commit comments