[dynamo] fix 3.11+ refleak (#124238) · pytorch/pytorch@812bae0 · GitHub
[go: up one dir, main page]

Skip to content

Commit 812bae0

Browse files
williamwen42 authored and pytorchmergebot committed
[dynamo] fix 3.11+ refleak (#124238)
Fixes #119607 for 3.11+. In 3.11+, `_PyFrame_FastToLocalsWithError` could implicitly run `COPY_FREE_VARS` on the original frame, leading to double increfs, since the dynamo shadow frame can rerun `COPY_FREE_VARS`. So the solution is to skip the first `COPY_FREE_VARS` instruction in the shadow frame if it was already executed in the original frame. Also move the location for clearing the original frame in 3.12 to handle error cases more thoroughly. Pull Request resolved: #124238 Approved by: https://github.com/jansel
1 parent 7c94652 commit 812bae0

File tree

5 files changed

+67
-45
lines changed

5 files changed

+67
-45
lines changed

test/dynamo/test_misc.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,6 @@
4444
same,
4545
skipIfNotPy311,
4646
unsupported,
47-
xfailIfPy311,
4847
)
4948
from torch._dynamo.utils import CompileProfiler, counters, ifdynstaticdefault
5049
from torch._inductor.utils import run_and_get_code
@@ -10100,7 +10099,6 @@ def forward(self, out):
1010010099
lambda mod: mod.fc,
1010110100
)
1010210101

10103-
@xfailIfPy311
1010410102
def test_sequential_module_free(self):
1010510103
self._test_compile_model_free(
1010610104
lambda: (
@@ -10113,14 +10111,12 @@ def test_sequential_module_free(self):
1011310111
lambda mod: mod[0],
1011410112
)
1011510113

10116-
@xfailIfPy311
1011710114
def test_linear_module_free(self):
1011810115
self._test_compile_model_free(
1011910116
lambda: (torch.nn.Linear(100, 100), torch.randn(100, 100)),
1012010117
lambda mod: mod,
1012110118
)
1012210119

10123-
@xfailIfPy311
1012410120
def test_outside_linear_module_free(self):
1012510121
# Compared to test_linear_module_free, the linear
1012610122
# layer is not the code object that is directly compiled.

torch/_dynamo/testing.py

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -341,12 +341,6 @@ def skipIfNotPy311(fn):
341341
return unittest.skip(fn)
342342

343343

344-
def xfailIfPy311(fn):
345-
if sys.version_info >= (3, 11):
346-
return unittest.expectedFailure(fn)
347-
return fn
348-
349-
350344
# Controls tests generated in test/inductor/test_torchinductor_dynamic_shapes.py
351345
# and test/dynamo/test_dynamic_shapes.py
352346
def expectedFailureDynamic(fn):

torch/csrc/dynamo/cpython_defs.c

Lines changed: 14 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -68,8 +68,10 @@ THP_PyFrame_OpAlreadyRan(_PyInterpreterFrame *frame, int opcode, int oparg)
6868

6969
// https://github.com/python/cpython/blob/0325a8a8cdba6c091bcbbb3c995f3bf1d1217012/Objects/frameobject.c#L1136
7070
// Initialize frame free variables if needed
71+
// free_vars_copied argument added in order to let caller know that the COPY_FREE_VARS
72+
// codepath occurred.
7173
static void
72-
frame_init_get_vars(_PyInterpreterFrame *frame)
74+
frame_init_get_vars(_PyInterpreterFrame *frame, int *free_vars_copied)
7375
{
7476
// COPY_FREE_VARS has no quickened forms, so no need to use _PyOpcode_Deopt
7577
// here:
@@ -91,6 +93,8 @@ frame_init_get_vars(_PyInterpreterFrame *frame)
9193
}
9294
// COPY_FREE_VARS doesn't have inline CACHEs, either:
9395
frame->prev_instr = _PyCode_CODE(frame->f_code);
96+
97+
*free_vars_copied = 1;
9498
}
9599

96100
// https://github.com/python/cpython/blob/0325a8a8cdba6c091bcbbb3c995f3bf1d1217012/Objects/frameobject.c#L1162
@@ -146,7 +150,7 @@ frame_get_var(_PyInterpreterFrame *frame, PyCodeObject *co, int i,
146150

147151
// https://github.com/python/cpython/blob/0325a8a8cdba6c091bcbbb3c995f3bf1d1217012/Objects/frameobject.c#L1213
148152
static PyObject *
149-
THP_PyFrame_GetLocals(_PyInterpreterFrame *frame, int include_hidden)
153+
THP_PyFrame_GetLocals(_PyInterpreterFrame *frame, int include_hidden, int *free_vars_copied)
150154
{
151155
/* Merge fast locals into f->f_locals */
152156
PyObject *locals = frame->f_locals;
@@ -169,7 +173,7 @@ THP_PyFrame_GetLocals(_PyInterpreterFrame *frame, int include_hidden)
169173
}
170174
}
171175

172-
frame_init_get_vars(frame);
176+
frame_init_get_vars(frame, free_vars_copied);
173177

174178
PyCodeObject *co = frame->f_code;
175179
for (int i = 0; i < co->co_nlocalsplus; i++) {
@@ -234,9 +238,9 @@ THP_PyFrame_GetLocals(_PyInterpreterFrame *frame, int include_hidden)
234238

235239
// https://github.com/python/cpython/blob/0325a8a8cdba6c091bcbbb3c995f3bf1d1217012/Objects/frameobject.c#L1301
236240
int
237-
THP_PyFrame_FastToLocalsWithError(_PyInterpreterFrame *frame)
241+
THP_PyFrame_FastToLocalsWithError(_PyInterpreterFrame *frame, int *free_vars_copied)
238242
{
239-
PyObject *locals = THP_PyFrame_GetLocals(frame, 0);
243+
PyObject *locals = THP_PyFrame_GetLocals(frame, 0, free_vars_copied);
240244
if (locals == NULL) {
241245
return -1;
242246
}
@@ -247,8 +251,10 @@ THP_PyFrame_FastToLocalsWithError(_PyInterpreterFrame *frame)
247251
#else
248252

249253
// https://github.com/python/cpython/blob/a7715ccfba5b86ab09f86ec56ac3755c93b46b48/Objects/frameobject.c#L1182
254+
// free_vars_copied argument added in order to let caller know that the COPY_FREE_VARS
255+
// codepath occurred.
250256
int
251-
THP_PyFrame_FastToLocalsWithError(_PyInterpreterFrame *frame) {
257+
THP_PyFrame_FastToLocalsWithError(_PyInterpreterFrame *frame, int *free_vars_copied) {
252258
/* Merge fast locals into f->f_locals */
253259
PyObject *locals = NULL;
254260
PyObject **fast = NULL;
@@ -267,20 +273,17 @@ THP_PyFrame_FastToLocalsWithError(_PyInterpreterFrame *frame) {
267273
if (lasti < 0 && _Py_OPCODE(_PyCode_CODE(co)[0]) == COPY_FREE_VARS) {
268274
/* Free vars have not been initialized -- Do that */
269275
PyCodeObject *co = frame->f_code;
270-
#if IS_PYTHON_3_12_PLUS
271-
PyObject *closure = ((PyFunctionObject *)frame->f_funcobj)->func_closure;
272-
int offset = co->co_nlocals + co->co_ncellvars;
273-
#else
274276
PyObject *closure = frame->f_func->func_closure;
275277
int offset = co->co_nlocals + co->co_nplaincellvars;
276-
#endif
277278
for (int i = 0; i < co->co_nfreevars; ++i) {
278279
PyObject *o = PyTuple_GET_ITEM(closure, i);
279280
Py_INCREF(o);
280281
frame->localsplus[offset + i] = o;
281282
}
282283
// COPY_FREE_VARS doesn't have inline CACHEs, either:
283284
frame->prev_instr = _PyCode_CODE(frame->f_code);
285+
286+
*free_vars_copied = 1;
284287
}
285288
for (int i = 0; i < co->co_nlocalsplus; i++) {
286289
_PyLocals_Kind kind = _PyLocals_GetKind(co->co_localspluskinds, i);

torch/csrc/dynamo/cpython_defs.h

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,9 @@
1010

1111
#include <internal/pycore_frame.h>
1212

13-
int THP_PyFrame_FastToLocalsWithError(_PyInterpreterFrame* frame);
13+
int THP_PyFrame_FastToLocalsWithError(
14+
_PyInterpreterFrame* frame,
15+
int* free_vars_copied);
1416

1517
PyFunctionObject* _PyFunction_CopyWithNewCode(
1618
PyFunctionObject* o,

torch/csrc/dynamo/eval_frame.c

Lines changed: 50 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -132,7 +132,10 @@ THPPyInterpreterFrame* THPPyInterpreterFrame_New(_PyInterpreterFrame* frame) {
132132
#else
133133
#define THP_EVAL_API_FRAME_OBJECT PyFrameObject
134134

135-
#define THP_PyFrame_FastToLocalsWithError PyFrame_FastToLocalsWithError
135+
static int
136+
THP_PyFrame_FastToLocalsWithError(THP_EVAL_API_FRAME_OBJECT *frame, int *free_vars_copied) {
137+
return PyFrame_FastToLocalsWithError(frame);
138+
}
136139
#endif
137140

138141
PyObject* guard_error_hook = NULL;
@@ -161,7 +164,8 @@ static PyObject* _custom_eval_frame(
161164
PyThreadState* tstate,
162165
THP_EVAL_API_FRAME_OBJECT* frame,
163166
int throw_flag,
164-
PyObject* callback);
167+
PyObject* callback,
168+
int* should_clear_frame);
165169
static PyObject *(*previous_eval_frame)(PyThreadState *tstate,
166170
THP_EVAL_API_FRAME_OBJECT* frame, int throw_flag) = NULL;
167171

@@ -283,7 +287,8 @@ inline static PyObject* eval_custom_code_impl(
283287
PyThreadState* tstate,
284288
THP_EVAL_API_FRAME_OBJECT* frame,
285289
PyCodeObject* code,
286-
int throw_flag) {
290+
int throw_flag,
291+
int free_vars_copied) {
287292

288293
DEBUG_NULL_CHECK(tstate);
289294
DEBUG_NULL_CHECK(frame);
@@ -333,6 +338,13 @@ inline static PyObject* eval_custom_code_impl(
333338
}
334339
#endif
335340

341+
// for 3.11+, if free_vars_copied is true, we do not need to
342+
// run the first COPY_FREE_VARS since THP_PyFrame_FastToLocalsWithError
343+
// already did the equivalent action.
344+
if (free_vars_copied && _Py_OPCODE(_PyCode_CODE(shadow->f_code)[0]) == COPY_FREE_VARS) {
345+
shadow->prev_instr = _PyCode_CODE(shadow->f_code);
346+
}
347+
336348
#else
337349

338350
THP_EVAL_API_FRAME_OBJECT* shadow = PyFrame_New(tstate, code, frame->f_globals, NULL);
@@ -428,14 +440,16 @@ inline static PyObject* eval_custom_code_impl(
428440
fastlocals_new[j] = fastlocals_old[i];
429441
}
430442

443+
// NOTE: if you want to evaluate frame instead of shadow in 3.12+,
444+
// you need to clear_old_frame_if_python_312_plus the shadow frame BEFORE
445+
// calling eval_frame_default (i.e. here) and comment out the
446+
// clear_old_frame_if_python_312_plus call on the original frame.
447+
431448
PyObject* result = eval_frame_default(tstate, shadow, throw_flag);
432449

433450
#if IS_PYTHON_3_12_PLUS
434451

435-
// In 3.12, the frame evaluation function is responsible for
436-
// clearing and popping the frame, so we manually do that on the
437-
// old frame.
438-
clear_old_frame_if_python_312_plus(tstate, frame);
452+
// frame is cleared by caller
439453
Py_DECREF(func);
440454

441455
#elif IS_PYTHON_3_11_PLUS
@@ -460,13 +474,15 @@ inline static PyObject* eval_custom_code(
460474
PyThreadState* tstate,
461475
THP_EVAL_API_FRAME_OBJECT* frame,
462476
PyCodeObject* code,
463-
int throw_flag) {
477+
int throw_flag,
478+
int free_vars_copied) {
464479
_PytorchRecordFunctionState* rf = _pytorch_record_function_enter("Torch-Compiled Region");
465480
PyObject* result = eval_custom_code_impl(
466481
tstate,
467482
frame,
468483
code,
469-
throw_flag
484+
throw_flag,
485+
free_vars_copied
470486
);
471487
_pytorch_record_function_exit(rf);
472488
return result;
@@ -487,18 +503,25 @@ static PyObject* _custom_eval_frame_shim(
487503
return eval_frame_default(tstate, frame, throw_flag);
488504
}
489505

490-
return _custom_eval_frame(tstate, frame, throw_flag, callback);
506+
int should_clear_frame = 0;
507+
PyObject* result = _custom_eval_frame(tstate, frame, throw_flag, callback, &should_clear_frame);
508+
if (should_clear_frame) {
509+
clear_old_frame_if_python_312_plus(tstate, frame);
510+
}
511+
return result;
491512
}
492513

493-
// NOTE: In 3.12+, any return NULL; statements must be preceded by
494-
// clear_old_frame_if_python_312_plus(tstate, frame); since the eval frame function
495-
// is now responsible for clearing/popping the frame.
496-
// eval_frame_default/eval_custom_code will clear/pop the frame.
514+
// NOTE: In 3.12+, the frame evaluation function (callee) is responsible for clearing/popping
515+
// the frame, meaning that unless we default evaluate the original frame,
516+
// we are responsible for clearing it - via clear_old_frame_if_python_312_plus.
517+
// The should_clear_frame flag is used to indicate whether the frame should be
518+
// cleared by _custom_eval_frame's caller.
497519
static PyObject* _custom_eval_frame(
498520
PyThreadState* tstate,
499521
THP_EVAL_API_FRAME_OBJECT* frame,
500522
int throw_flag,
501-
PyObject* callback) {
523+
PyObject* callback,
524+
int* should_clear_frame) {
502525
#if IS_PYTHON_3_11_PLUS
503526
DEBUG_TRACE(
504527
"begin %s %s %i %i",
@@ -552,9 +575,10 @@ static PyObject* _custom_eval_frame(
552575
}
553576

554577
// TODO(jansel): investigate directly using the "fast" representation
555-
if (THP_PyFrame_FastToLocalsWithError(frame) < 0) {
578+
int free_vars_copied = 0;
579+
if (THP_PyFrame_FastToLocalsWithError(frame, &free_vars_copied) < 0) {
556580
DEBUG_TRACE("error %s", get_frame_name(frame));
557-
clear_old_frame_if_python_312_plus(tstate, frame);
581+
*should_clear_frame = 1;
558582
return NULL;
559583
}
560584

@@ -570,7 +594,7 @@ static PyObject* _custom_eval_frame(
570594

571595
if (maybe_cached_code == NULL) {
572596
// guard eval failed, keep propagating
573-
clear_old_frame_if_python_312_plus(tstate, frame);
597+
*should_clear_frame = 1;
574598
return NULL;
575599
} else if (maybe_cached_code == Py_None) {
576600
DEBUG_TRACE("cache miss %s", get_frame_name(frame));
@@ -579,7 +603,8 @@ static PyObject* _custom_eval_frame(
579603
PyCodeObject* cached_code = (PyCodeObject*)maybe_cached_code;
580604
// used cached version
581605
DEBUG_TRACE("cache hit %s", get_frame_name(frame));
582-
return eval_custom_code(tstate, frame, cached_code, throw_flag);
606+
*should_clear_frame = 1;
607+
return eval_custom_code(tstate, frame, cached_code, throw_flag, free_vars_copied);
583608
}
584609
DEBUG_CHECK(PyDict_CheckExact(frame->f_locals));
585610
DEBUG_CHECK(PyDict_CheckExact(frame->f_globals));
@@ -595,15 +620,16 @@ static PyObject* _custom_eval_frame(
595620
_pytorch_record_function_exit(rf);
596621
if (maybe_cached_code == NULL) {
597622
// Python error
598-
clear_old_frame_if_python_312_plus(tstate, frame);
623+
*should_clear_frame = 1;
599624
return NULL;
600625
} else if (maybe_cached_code != Py_None) {
601626
PyCodeObject* cached_code = (PyCodeObject*)maybe_cached_code;
602627
// used cached version
603628
DEBUG_TRACE("cache hit %s", get_frame_name(frame));
604629
// Re-enable custom behavior
605630
eval_frame_callback_set(callback);
606-
return eval_custom_code(tstate, frame, cached_code, throw_flag);
631+
*should_clear_frame = 1;
632+
return eval_custom_code(tstate, frame, cached_code, throw_flag, free_vars_copied);
607633
}
608634
// cache miss
609635
CacheEntry* cache_entry = extract_cache_entry(extra);
@@ -618,7 +644,7 @@ static PyObject* _custom_eval_frame(
618644
// cascading failure from internal exceptions. The upshot is if
619645
// Dynamo barfs, that's it for Dynamo, even if you catch the exception
620646
// inside the torch.compile block we won't try to Dynamo anything else.
621-
clear_old_frame_if_python_312_plus(tstate, frame);
647+
*should_clear_frame = 1;
622648
return NULL;
623649
} else if (result != Py_None) {
624650
DEBUG_TRACE("create cache %s", get_frame_name(frame));
@@ -636,7 +662,8 @@ static PyObject* _custom_eval_frame(
636662
// will be cleaned up when set_extra_state is called.
637663
// Re-enable custom behavior
638664
eval_frame_callback_set(callback);
639-
return eval_custom_code(tstate, frame, CacheEntry_get_code(new_cache_entry), throw_flag);
665+
*should_clear_frame = 1;
666+
return eval_custom_code(tstate, frame, CacheEntry_get_code(new_cache_entry), throw_flag, free_vars_copied);
640667
} else {
641668
DEBUG_TRACE("create skip %s", get_frame_name(frame));
642669
Py_DECREF(result);

0 commit comments

Comments
 (0)
0