@@ -132,7 +132,10 @@ THPPyInterpreterFrame* THPPyInterpreterFrame_New(_PyInterpreterFrame* frame) {
132
132
#else
133
133
#define THP_EVAL_API_FRAME_OBJECT PyFrameObject
134
134
135
- #define THP_PyFrame_FastToLocalsWithError PyFrame_FastToLocalsWithError
135
+ static int
136
+ THP_PyFrame_FastToLocalsWithError (THP_EVAL_API_FRAME_OBJECT * frame , int * free_vars_copied ) {
137
+ return PyFrame_FastToLocalsWithError (frame );
138
+ }
136
139
#endif
137
140
138
141
PyObject * guard_error_hook = NULL ;
@@ -161,7 +164,8 @@ static PyObject* _custom_eval_frame(
161
164
PyThreadState * tstate ,
162
165
THP_EVAL_API_FRAME_OBJECT * frame ,
163
166
int throw_flag ,
164
- PyObject * callback );
167
+ PyObject * callback ,
168
+ int * should_clear_frame );
165
169
static PyObject * (* previous_eval_frame )(PyThreadState * tstate ,
166
170
THP_EVAL_API_FRAME_OBJECT * frame , int throw_flag ) = NULL ;
167
171
@@ -283,7 +287,8 @@ inline static PyObject* eval_custom_code_impl(
283
287
PyThreadState * tstate ,
284
288
THP_EVAL_API_FRAME_OBJECT * frame ,
285
289
PyCodeObject * code ,
286
- int throw_flag ) {
290
+ int throw_flag ,
291
+ int free_vars_copied ) {
287
292
288
293
DEBUG_NULL_CHECK (tstate );
289
294
DEBUG_NULL_CHECK (frame );
@@ -333,6 +338,13 @@ inline static PyObject* eval_custom_code_impl(
333
338
}
334
339
#endif
335
340
341
+ // for 3.11+, if free_vars_copied is true, we do not need to
342
+ // run the first COPY_FREE_VARS since THP_PyFrame_FastToLocalsWithError
343
+ // already did the equivalent action.
344
+ if (free_vars_copied && _Py_OPCODE (_PyCode_CODE (shadow -> f_code )[0 ]) == COPY_FREE_VARS ) {
345
+ shadow -> prev_instr = _PyCode_CODE (shadow -> f_code );
346
+ }
347
+
336
348
#else
337
349
338
350
THP_EVAL_API_FRAME_OBJECT * shadow = PyFrame_New (tstate , code , frame -> f_globals , NULL );
@@ -428,14 +440,16 @@ inline static PyObject* eval_custom_code_impl(
428
440
fastlocals_new [j ] = fastlocals_old [i ];
429
441
}
430
442
443
+ // NOTE: if you want to evaluate frame instead of shadow in 3.12+,
444
+ // you need to clear_old_frame_if_python_312_plus the shadow frame BEFORE
445
+ // calling eval_frame_default (i.e. here) and comment out the
446
+ // clear_old_frame_if_python_312_plus call on the original frame.
447
+
431
448
PyObject * result = eval_frame_default (tstate , shadow , throw_flag );
432
449
433
450
#if IS_PYTHON_3_12_PLUS
434
451
435
- // In 3.12, the frame evaluation function is responsible for
436
- // clearing and popping the frame, so we manually do that on the
437
- // old frame.
438
- clear_old_frame_if_python_312_plus (tstate , frame );
452
+ // frame is cleared by caller
439
453
Py_DECREF (func );
440
454
441
455
#elif IS_PYTHON_3_11_PLUS
@@ -460,13 +474,15 @@ inline static PyObject* eval_custom_code(
460
474
PyThreadState * tstate ,
461
475
THP_EVAL_API_FRAME_OBJECT * frame ,
462
476
PyCodeObject * code ,
463
- int throw_flag ) {
477
+ int throw_flag ,
478
+ int free_vars_copied ) {
464
479
_PytorchRecordFunctionState * rf = _pytorch_record_function_enter ("Torch-Compiled Region" );
465
480
PyObject * result = eval_custom_code_impl (
466
481
tstate ,
467
482
frame ,
468
483
code ,
469
- throw_flag
484
+ throw_flag ,
485
+ free_vars_copied
470
486
);
471
487
_pytorch_record_function_exit (rf );
472
488
return result ;
@@ -487,18 +503,25 @@ static PyObject* _custom_eval_frame_shim(
487
503
return eval_frame_default (tstate , frame , throw_flag );
488
504
}
489
505
490
- return _custom_eval_frame (tstate , frame , throw_flag , callback );
506
+ int should_clear_frame = 0 ;
507
+ PyObject * result = _custom_eval_frame (tstate , frame , throw_flag , callback , & should_clear_frame );
508
+ if (should_clear_frame ) {
509
+ clear_old_frame_if_python_312_plus (tstate , frame );
510
+ }
511
+ return result ;
491
512
}
492
513
493
- // NOTE: In 3.12+, any return NULL; statements must be preceded by
494
- // clear_old_frame_if_python_312_plus(tstate, frame); since the eval frame function
495
- // is now responsible for clearing/popping the frame.
496
- // eval_frame_default/eval_custom_code will clear/pop the frame.
514
+ // NOTE: In 3.12+, the frame evaluation function (callee) is responsible for clearing/popping
515
+ // the frame, meaning that unless we default evaluate the original frame,
516
+ // we are responsible for clearing it - via clear_old_frame_if_python_312_plus.
517
+ // The should_clear_frame flag is used to indicate whether the frame should be
518
+ // cleared by _custom_eval_frame's caller.
497
519
static PyObject * _custom_eval_frame (
498
520
PyThreadState * tstate ,
499
521
THP_EVAL_API_FRAME_OBJECT * frame ,
500
522
int throw_flag ,
501
- PyObject * callback ) {
523
+ PyObject * callback ,
524
+ int * should_clear_frame ) {
502
525
#if IS_PYTHON_3_11_PLUS
503
526
DEBUG_TRACE (
504
527
"begin %s %s %i %i" ,
@@ -552,9 +575,10 @@ static PyObject* _custom_eval_frame(
552
575
}
553
576
554
577
// TODO(jansel): investigate directly using the "fast" representation
555
- if (THP_PyFrame_FastToLocalsWithError (frame ) < 0 ) {
578
+ int free_vars_copied = 0 ;
579
+ if (THP_PyFrame_FastToLocalsWithError (frame , & free_vars_copied ) < 0 ) {
556
580
DEBUG_TRACE ("error %s" , get_frame_name (frame ));
557
- clear_old_frame_if_python_312_plus ( tstate , frame ) ;
581
+ * should_clear_frame = 1 ;
558
582
return NULL ;
559
583
}
560
584
@@ -570,7 +594,7 @@ static PyObject* _custom_eval_frame(
570
594
571
595
if (maybe_cached_code == NULL ) {
572
596
// guard eval failed, keep propagating
573
- clear_old_frame_if_python_312_plus ( tstate , frame ) ;
597
+ * should_clear_frame = 1 ;
574
598
return NULL ;
575
599
} else if (maybe_cached_code == Py_None ) {
576
600
DEBUG_TRACE ("cache miss %s" , get_frame_name (frame ));
@@ -579,7 +603,8 @@ static PyObject* _custom_eval_frame(
579
603
PyCodeObject * cached_code = (PyCodeObject * )maybe_cached_code ;
580
604
// used cached version
581
605
DEBUG_TRACE ("cache hit %s" , get_frame_name (frame ));
582
- return eval_custom_code (tstate , frame , cached_code , throw_flag );
606
+ * should_clear_frame = 1 ;
607
+ return eval_custom_code (tstate , frame , cached_code , throw_flag , free_vars_copied );
583
608
}
584
609
DEBUG_CHECK (PyDict_CheckExact (frame -> f_locals ));
585
610
DEBUG_CHECK (PyDict_CheckExact (frame -> f_globals ));
@@ -595,15 +620,16 @@ static PyObject* _custom_eval_frame(
595
620
_pytorch_record_function_exit (rf );
596
621
if (maybe_cached_code == NULL ) {
597
622
// Python error
598
- clear_old_frame_if_python_312_plus ( tstate , frame ) ;
623
+ * should_clear_frame = 1 ;
599
624
return NULL ;
600
625
} else if (maybe_cached_code != Py_None ) {
601
626
PyCodeObject * cached_code = (PyCodeObject * )maybe_cached_code ;
602
627
// used cached version
603
628
DEBUG_TRACE ("cache hit %s" , get_frame_name (frame ));
604
629
// Re-enable custom behavior
605
630
eval_frame_callback_set (callback );
606
- return eval_custom_code (tstate , frame , cached_code , throw_flag );
631
+ * should_clear_frame = 1 ;
632
+ return eval_custom_code (tstate , frame , cached_code , throw_flag , free_vars_copied );
607
633
}
608
634
// cache miss
609
635
CacheEntry * cache_entry = extract_cache_entry (extra );
@@ -618,7 +644,7 @@ static PyObject* _custom_eval_frame(
618
644
// cascading failure from internal exceptions. The upshot is if
619
645
// Dynamo barfs, that's it for Dynamo, even if you catch the exception
620
646
// inside the torch.compile block we won't try to Dynamo anything else.
621
- clear_old_frame_if_python_312_plus ( tstate , frame ) ;
647
+ * should_clear_frame = 1 ;
622
648
return NULL ;
623
649
} else if (result != Py_None ) {
624
650
DEBUG_TRACE ("create cache %s" , get_frame_name (frame ));
@@ -636,7 +662,8 @@ static PyObject* _custom_eval_frame(
636
662
// will be cleaned up when set_extra_state is called.
637
663
// Re-enable custom behavior
638
664
eval_frame_callback_set (callback );
639
- return eval_custom_code (tstate , frame , CacheEntry_get_code (new_cache_entry ), throw_flag );
665
+ * should_clear_frame = 1 ;
666
+ return eval_custom_code (tstate , frame , CacheEntry_get_code (new_cache_entry ), throw_flag , free_vars_copied );
640
667
} else {
641
668
DEBUG_TRACE ("create skip %s" , get_frame_name (frame ));
642
669
Py_DECREF (result );
0 commit comments