8000 [dynamo] fix 3.11+ refleak · pytorch/pytorch@6ac1867 · GitHub
[go: up one dir, main page]

Skip to content

Commit 6ac1867

Browse files
committed
[dynamo] fix 3.11+ refleak
ghstack-source-id: c6fb318 Pull Request resolved: #124238
1 parent 4efdf9a commit 6ac1867

File tree

5 files changed

+67
-45
lines changed

5 files changed

+67
-45
lines changed

test/dynamo/test_misc.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,6 @@
4444
same,
4545
skipIfNotPy311,
4646
unsupported,
47-
xfailIfPy311,
4847
)
4948
from torch._dynamo.utils import CompileProfiler, counters, ifdynstaticdefault
5049
from torch._inductor.utils import run_and_get_code
@@ -10100,7 +10099,6 @@ def forward(self, out):
1010010099
lambda mod: mod.fc,
1010110100
)
1010210101

10103-
@xfailIfPy311
1010410102
def test_sequential_module_free(self):
1010510103
self._test_compile_model_free(
1010610104
lambda: (
@@ -10113,14 +10111,12 @@ def test_sequential_module_free(self):
1011310111
lambda mod: mod[0],
1011410112
)
1011510113

10116-
@xfailIfPy311
1011710114
def test_linear_module_free(self):
1011810115
self._test_compile_model_free(
1011910116
lambda: (torch.nn.Linear(100, 100), torch.randn(100, 100)),
1012010117
lambda mod: mod,
1012110118
)
1012210119

10123-
@xfailIfPy311
1012410120
def test_outside_linear_module_free(self):
1012510121
# Compared to test_linear_module_free, the linear
1012610122
# layer is not the code object that is directly compiled.

torch/_dynamo/testing.py

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -341,12 +341,6 @@ def skipIfNotPy311(fn):
341341
return unittest.skip(fn)
342342

343343

344-
def xfailIfPy311(fn):
345-
if sys.version_info >= (3, 11):
346-
return unittest.expectedFailure(fn)
347-
return fn
348-
349-
350344
# Controls tests generated in test/inductor/test_torchinductor_dynamic_shapes.py
351345
# and test/dynamo/test_dynamic_shapes.py
352346
def expectedFailureDynamic(fn):

torch/csrc/dynamo/cpython_defs.c

Lines changed: 14 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -68,8 +68,10 @@ THP_PyFrame_OpAlreadyRan(_PyInterpreterFrame *frame, int opcode, int oparg)
6868

6969
// https://github.com/python/cpython/blob/0325a8a8cdba6c091bcbbb3c995f3bf1d1217012/Objects/frameobject.c#L1136
7070
// Initialize frame free variables if needed
71+
// free_vars_copied argument added in order to let caller know that the COPY_FREE_VARS
72+
// codepath occurred.
7173
static void
72-
frame_init_get_vars(_PyInterpreterFrame *frame)
74+
frame_init_get_vars(_PyInterpreterFrame *frame, int *free_vars_copied)
7375
{
7476
// COPY_FREE_VARS has no quickened forms, so no need to use _PyOpcode_Deopt
7577
// here:
@@ -91,6 +93,8 @@ frame_init_get_vars(_PyInterpreterFrame *frame)
9193
}
9294
// COPY_FREE_VARS doesn't have inline CACHEs, either:
9395
frame->prev_instr = _PyCode_CODE(frame->f_code);
96+
97+
*free_vars_copied = 1;
9498
}
9599

96100
// https://github.com/python/cpython/blob/0325a8a8cdba6c091bcbbb3c995f3bf1d1217012/Objects/frameobject.c#L1162
@@ -146,7 +150,7 @@ frame_get_var(_PyInterpreterFrame *frame, PyCodeObject *co, int i,
146150

147151
// https://github.com/python/cpython/blob/0325a8a8cdba6c091bcbbb3c995f3bf1d1217012/Objects/frameobject.c#L1213
148152
static PyObject *
149-
THP_PyFrame_GetLocals(_PyInterpreterFrame *frame, int include_hidden)
153+
THP_PyFrame_GetLocals(_PyInterpreterFrame *frame, int include_hidden, int *free_vars_copied)
150154
{
151155
/* Merge fast locals into f->f_locals */
152156
PyObject *locals = frame->f_locals;
@@ -169,7 +173,7 @@ THP_PyFrame_GetLocals(_PyInterpreterFrame *frame, int include_hidden)
169173
}
170174
}
171175

172-
frame_init_get_vars(frame);
176+
frame_init_get_vars(frame, free_vars_copied);
173177

174178
PyCodeObject *co = frame->f_code;
175179
for (int i = 0; i < co->co_nlocalsplus; i++) {
@@ -234,9 +238,9 @@ THP_PyFrame_GetLocals(_PyInterpreterFrame *frame, int include_hidden)
234238

235239
// https://github.com/python/cpython/blob/0325a8a8cdba6c091bcbbb3c995f3bf1d1217012/Objects/frameobject.c#L1301
236240
int
237-
THP_PyFrame_FastToLocalsWithError(_PyInterpreterFrame *frame)
241+
THP_PyFrame_FastToLocalsWithError(_PyInterpreterFrame *frame, int *free_vars_copied)
238242
{
239-
PyObject *locals = THP_PyFrame_GetLocals(frame, 0);
243+
PyObject *locals = THP_PyFrame_GetLocals(frame, 0, free_vars_copied);
240244
if (locals == NULL) {
241245
return -1;
242246
}
@@ -247,8 +251,10 @@ THP_PyFrame_FastToLocalsWithError(_PyInterpreterFrame *frame)
247251
#else
248252

249253
// https://github.com/python/cpython/blob/a7715ccfba5b86ab09f86ec56ac3755c93b46b48/Objects/frameobject.c#L1182
254+
// free_vars_copied argument added in order to let caller know that the COPY_FREE_VARS
255+
// codepath occurred.
250256
int
251-
THP_PyFrame_FastToLocalsWithError(_PyInterpreterFrame *frame) {
257+
THP_PyFrame_FastToLocalsWithError(_PyInterpreterFrame *frame, int *free_vars_copied) {
252258
/* Merge fast locals into f->f_locals */
253259
PyObject *locals = NULL;
254260
PyObject **fast = NULL;
@@ -267,20 +273,17 @@ THP_PyFrame_FastToLocalsWithError(_PyInterpreterFrame *frame) {
267273
if (lasti < 0 && _Py_OPCODE(_PyCode_CODE(co)[0]) == COPY_FREE_VARS) {
268274
/* Free vars have not been initialized -- Do that */
269275
PyCodeObject *co = frame->f_code;
270-
#if IS_PYTHON_3_12_PLUS
271-
PyObject *closure = ((PyFunctionObject *)frame->f_funcobj)->func_closure;
272-
int offset = co->co_nlocals + co->co_ncellvars;
273-
#else
274276
PyObject *closure = frame->f_func->func_closure;
275277
int offset = co->co_nlocals + co->co_nplaincellvars;
276-
#endif
277278
for (int i = 0; i < co->co_nfreevars; ++i) {
278279
PyObject *o = PyTuple_GET_ITEM(closure, i);
279280
Py_INCREF(o);
280281
frame->localsplus[offset + i] = o;
281282
}
282283
// COPY_FREE_VARS doesn't have inline CACHEs, either:
283284
frame->prev_instr = _PyCode_CODE(frame->f_code);
285+
286+
*free_vars_copied = 1;
284287
}
285288
for (int i = 0; i < co->co_nlocalsplus; i++) {
286289
_PyLocals_Kind kind = _PyLocals_GetKind(co->co_localspluskinds, i);

torch/csrc/dynamo/cpython_defs.h

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,9 @@
1010

1111
#include <internal/pycore_frame.h>
1212

13-
int THP_PyFrame_FastToLocalsWithError(_PyInterpreterFrame* frame);
13+
int THP_PyFrame_FastToLocalsWithError(
14+
_PyInterpreterFrame* frame,
15+
int* free_vars_copied);
1416

1517
PyFunctionObject* _PyFunction_CopyWithNewCode(
1618
PyFunctionObject* o,

torch/csrc/dynamo/eval_frame.c

Lines changed: 50 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -132,7 +132,10 @@ THPPyInterpreterFrame* THPPyInterpreterFrame_New(_PyInterpreterFrame* frame) {
132132
#else
133133
#define THP_EVAL_API_FRAME_OBJECT PyFrameObject
134134

135-
#define THP_PyFrame_FastToLocalsWithError PyFrame_FastToLocalsWithError
135+
static int
136+
THP_PyFrame_FastToLocalsWithError(THP_EVAL_API_FRAME_OBJECT *frame, int *free_vars_copied) {
137+
return PyFrame_FastToLocalsWithError(frame);
138+
}
136139
#endif
137140

138141
PyObject* guard_error_hook = NULL;
@@ -161,7 +164,8 @@ static PyObject* _custom_eval_frame(
161164
PyThreadState* tstate,
162165
THP_EVAL_API_FRAME_OBJECT* frame,
163166
int throw_flag,
164-
PyObject* callback);
167+
PyObject* callback,
168+
int* should_clear_frame);
165169
static PyObject *(*previous_eval_frame)(PyThreadState *tstate,
166170
THP_EVAL_API_FRAME_OBJECT* frame, int throw_flag) = NULL;
167171

@@ -283,7 +287,8 @@ inline static PyObject* eval_custom_code_impl(
283287
PyThreadState* tstate,
284288
THP_EVAL_API_FRAME_OBJECT* frame,
285289
PyCodeObject* code,
286-
int throw_flag) {
290+
int throw_flag,
291+
int free_vars_copied) {
287292

288293
DEBUG_NULL_CHECK(tstate);
289294
DEBUG_NULL_CHECK(frame);
@@ -333,6 +338,13 @@ inline static PyObject* eval_custom_code_impl(
333338
}
334339
#endif
335340

341+
// for 3.11+, if free_vars_copied is true, we do not need to
342+
// run the first COPY_FREE_VARS since THP_PyFrame_FastToLocalsWithError
343+
// already did the equivalent action.
344+
if (free_vars_copied && _Py_OPCODE(_PyCode_CODE(shadow->f_code)[0]) == COPY_FREE_VARS) {
345+
shadow->prev_instr = _PyCode_CODE(shadow->f_code);
346+
}
347+
336348
#else
337349

338350
THP_EVAL_API_FRAME_OBJECT* shadow = PyFrame_New(tstate, code, frame->f_globals, NULL);
@@ -428,14 +440,16 @@ inline static PyObject* eval_custom_code_impl(
428440
fastlocals_new[j] = fastlocals_old[i];
429441
}
430442

443+
// NOTE: if you want to evaluate frame instead of shadow in 3.12+,
444+
// you need to clear_old_frame_if_python_312_plus the shadow frame BEFORE
445+
// calling eval_frame_default (i.e. here) and comment out the
446+
// clear_old_frame_if_python_312_plus call on the original frame.
447+
431448
PyObject* result = eval_frame_default(tstate, shadow, throw_flag);
432449

433450
#if IS_PYTHON_3_12_PLUS
434451

435-
// In 3.12, the frame evaluation function is responsible for
436-
// clearing and popping the frame, so we manually do that on the
437-
// old frame.
438-
clear_old_frame_if_python_312_plus(tstate, frame);
452+
// frame is cleared by caller
439453
Py_DECREF(func);
440454

441455
#elif IS_PYTHON_3_11_PLUS
@@ -460,13 +474,15 @@ inline static PyObject* eval_custom_code(
460474
PyThreadState* tstate,
461475
THP_EVAL_API_FRAME_OBJECT* frame,
462476
PyCodeObject* code,
463-
int throw_flag) {
477+
int throw_flag,
478+
int free_vars_copied) {
464479
_PytorchRecordFunctionState* rf = _pytorch_record_function_enter("Torch-Compiled Region");
465480
PyObject* result = eval_custom_code_impl(
466481
tstate,
467482
frame,
468483
code,
469-
throw_flag
484+
throw_flag,
485+
free_vars_copied
470486
);
471487
_pytorch_record_function_exit(rf);
472488
return result;
@@ -487,18 +503,25 @@ static PyObject* _custom_eval_frame_shim(
487503
return eval_frame_default(tstate, frame, throw_flag);
488504
}
489505

490-
return _custom_eval_frame(tstate, frame, throw_flag, callback);
506+
int should_clear_frame = 0;
507+
PyObject* result = _custom_eval_frame(tstate, frame, throw_flag, callback, &should_clear_frame);
508+
if (should_clear_frame) {
509+
clear_old_frame_if_python_312_plus(tstate, frame);
510+
}
511+
return result;
491512
}
492513

493-
// NOTE: In 3.12+, any return NULL; statements must be preceded by
494-
// clear_old_frame_if_python_312_plus(tstate, frame); since the eval frame function
495-
// is now responsible for clearing/popping the frame.
496-
// eval_frame_default/eval_custom_code will clear/pop the frame.
514+
// NOTE: In 3.12+, the frame evaluation function (callee) is responsible for clearing/popping
515+
// the frame, meaning that unless we default evaluate the original frame,
516+
// we are responsible for clearing it - via clear_old_frame_if_python_312_plus.
517+
// The should_clear_frame flag is used to indicate whether the frame should be
518+
// cleared by _custom_eval_frame's caller.
497519
static PyObject* _custom_eval_frame(
498520
PyThreadState* tstate,
499521
THP_EVAL_API_FRAME_OBJECT* frame,
500522
int throw_flag,
501-
PyObject* callback) {
523+
PyObject* callback,
524+
int* should_clear_frame) {
502525
#if IS_PYTHON_3_11_PLUS
503526
DEBUG_TRACE(
504527
"begin %s %s %i %i",
@@ -552,9 +575,10 @@ static PyObject* _custom_eval_frame(
552575
}
553576

554577
// TODO(jansel): investigate directly using the "fast" representation
555-
if (THP_PyFrame_FastToLocalsWithError(frame) < 0) {
578+
int free_vars_copied = 0;
579+
if (THP_PyFrame_FastToLocalsWithError(frame, &free_vars_copied) < 0) {
556580
DEBUG_TRACE("error %s", get_frame_name(frame));
557-
clear_old_frame_if_python_312_plus(tstate, frame);
581+
*should_clear_frame = 1;
558582
return NULL;
559583
}
560584

@@ -570,7 +594,7 @@ static PyObject* _custom_eval_frame(
570594

571595
if (maybe_cached_code == NULL) {
572596
// guard eval failed, keep propagating
573-
clear_old_frame_if_python_312_plus(tstate, frame);
597+
*should_clear_frame = 1;
574598
return NULL;
575599
} else if (maybe_cached_code == Py_None) {
576600
DEBUG_TRACE("cache miss %s", get_frame_name(frame));
@@ -579,7 +603,8 @@ static PyObject* _custom_eval_frame(
579603
PyCodeObject* cached_code = (PyCodeObject*)maybe_cached_code;
580604
// used cached version
581605
DEBUG_TRACE("cache hit %s", get_frame_name(frame));
582-
return eval_custom_code(tstate, frame, cached_code, throw_flag);
606+
*should_clear_frame = 1;
607+
return eval_custom_code(tstate, frame, cached_code, throw_flag, free_vars_copied);
583608
}
584609
DEBUG_CHECK(PyDict_CheckExact(frame->f_locals));
585610
DEBUG_CHECK(PyDict_CheckExact(frame->f_globals));
@@ -595,15 +620,16 @@ static PyObject* _custom_eval_frame(
595620
_pytorch_record_function_exit(rf);
596621
if (maybe_cached_code == NULL) {
597622
// Python error
598-
clear_old_frame_if_python_312_plus(tstate, frame);
623+
*should_clear_frame = 1;
599624
return NULL;
600625
} else if (maybe_cached_code != Py_None) {
601626
PyCodeObject* cached_code = (PyCodeObject*)maybe_cached_code;
602627
// used cached version
603628
DEBUG_TRACE("cache hit %s", get_frame_name(frame));
604629
// Re-enable custom behavior
605630
eval_frame_callback_set(callback);
606-
return eval_custom_code(tstate, frame, cached_code, throw_flag);
631+
*should_clear_frame = 1;
632+
return eval_custom_code(tstate, frame, cached_code, throw_flag, free_vars_copied);
607633
}
608634
// cache miss
609635
CacheEntry* cache_entry = extract_cache_entry(extra);
@@ -618,7 +644,7 @@ static PyObject* _custom_eval_frame(
618644
// cascading failure from internal exceptions. The upshot is if
619645
// Dynamo barfs, that's it for Dynamo, even if you catch the exception
620646
// inside the torch.compile block we won't try to Dynamo anything else.
621-
clear_old_frame_if_python_312_plus(tstate, frame);
647+
*should_clear_frame = 1;
622648
return NULL;
623649
} else if (result != Py_None) {
624650
DEBUG_TRACE("create cache %s", get_frame_name(frame));
@@ -636,7 +662,8 @@ static PyObject* _custom_eval_frame(
636662
// will be cleaned up when set_extra_state is called.
637663
// Re-enable custom behavior
638664
eval_frame_callback_set(callback);
639-
return eval_custom_code(tstate, frame, CacheEntry_get_code(new_cache_entry), throw_flag);
665+
*should_clear_frame = 1;
666+
return eval_custom_code(tstate, frame, CacheEntry_get_code(new_cache_entry), throw_flag, free_vars_copied);
640667
} else {
641668
DEBUG_TRACE("create skip %s", get_frame_name(frame));
642669
Py_DECREF(result);

0 commit comments

Comments
 (0)
0