8000 [dynamo] fix 3.11+ refleak · pytorch/pytorch@d6dabc5 · GitHub
[go: up one dir, main page]

Skip to content

Commit d6dabc5

Browse files
committed
[dynamo] fix 3.11+ refleak
ghstack-source-id: 1d9fa5d Pull Request resolved: #124238
1 parent 1d5efb1 commit d6dabc5

File tree

5 files changed

+57
-41
lines changed

5 files changed

+57
-41
lines changed

test/dynamo/test_misc.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,6 @@
4444
same,
4545
skipIfNotPy311,
4646
unsupported,
47-
xfailIfPy311,
4847
)
4948
from torch._dynamo.utils import CompileProfiler, counters, ifdynstaticdefault
5049
from torch._inductor.utils import run_and_get_code
@@ -10040,7 +10039,6 @@ def forward(self, out):
1004010039
lambda mod: mod.fc,
1004110040
)
1004210041

10043-
@xfailIfPy311
1004410042
def test_sequential_module_free(self):
1004510043
self._test_compile_model_free(
1004610044
lambda: (
@@ -10053,14 +10051,12 @@ def test_sequential_module_free(self):
1005310051
lambda mod: mod[0],
1005410052
)
1005510053

10056-
@xfailIfPy311
1005710054
def test_linear_module_free(self):
1005810055
self._test_compile_model_free(
1005910056
lambda: (torch.nn.Linear(100, 100), torch.randn(100, 100)),
1006010057
lambda mod: mod,
1006110058
)
1006210059

10063-
@xfailIfPy311
1006410060
def test_outside_linear_module_free(self):
1006510061
# Compared to test_linear_module_free, the linear
1006610062
# layer is not the code object that is directly compiled.

torch/_dynamo/testing.py

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -332,12 +332,6 @@ def skipIfNotPy311(fn):
332332
return unittest.skip(fn)
333333

334334

335-
def xfailIfPy311(fn):
336-
if sys.version_info >= (3, 11):
337-
return unittest.expectedFailure(fn)
338-
return fn
339-
340-
341335
# Controls tests generated in test/inductor/test_torchinductor_dynamic_shapes.py
342336
# and test/dynamo/test_dynamic_shapes.py
343337
def expectedFailureDynamic(fn):

torch/csrc/dynamo/cpython_defs.c

Lines changed: 14 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -68,8 +68,10 @@ THP_PyFrame_OpAlreadyRan(_PyInterpreterFrame *frame, int opcode, int oparg)
6868

6969
// https://github.com/python/cpython/blob/0325a8a8cdba6c091bcbbb3c995f3bf1d1217012/Objects/frameobject.c#L1136
7070
// Initialize frame free variables if needed
71+
// free_vars_copied argument added in order to let caller know that the COPY_FREE_VARS
72+
// codepath occurred.
7173
static void
72-
frame_init_get_vars(_PyInterpreterFrame *frame)
74+
frame_init_get_vars(_PyInterpreterFrame *frame, int *free_vars_copied)
7375
{
7476
// COPY_FREE_VARS has no quickened forms, so no need to use _PyOpcode_Deopt
7577
// here:
@@ -91,6 +93,8 @@ frame_init_get_vars(_PyInterpreterFrame *frame)
9193
}
9294
// COPY_FREE_VARS doesn't have inline CACHEs, either:
9395
frame->prev_instr = _PyCode_CODE(frame->f_code);
96+
97+
*free_vars_copied = 1;
9498
}
9599

96100
// https://github.com/python/cpython/blob/0325a8a8cdba6c091bcbbb3c995f3bf1d1217012/Objects/frameobject.c#L1162
@@ -146,7 +150,7 @@ frame_get_var(_PyInterpreterFrame *frame, PyCodeObject *co, int i,
146150

147151
// https://github.com/python/cpython/blob/0325a8a8cdba6c091bcbbb3c995f3bf1d1217012/Objects/frameobject.c#L1213
148152
static PyObject *
149-
THP_PyFrame_GetLocals(_PyInterpreterFrame *frame, int include_hidden)
153+
THP_PyFrame_GetLocals(_PyInterpreterFrame *frame, int include_hidden, int *free_vars_copied)
150154
{
151155
/* Merge fast locals into f->f_locals */
152156
PyObject *locals = frame->f_locals;
@@ -169,7 +173,7 @@ THP_PyFrame_GetLocals(_PyInterpreterFrame *frame, int include_hidden)
169173
}
170174
}
171175

172-
frame_init_get_vars(frame);
176+
frame_init_get_vars(frame, free_vars_copied);
173177

174178
PyCodeObject *co = frame->f_code;
175179
for (int i = 0; i < co->co_nlocalsplus; i++) {
@@ -234,9 +238,9 @@ THP_PyFrame_GetLocals(_PyInterpreterFrame *frame, int include_hidden)
234238

235239
// https://github.com/python/cpython/blob/0325a8a8cdba6c091bcbbb3c995f3bf1d1217012/Objects/frameobject.c#L1301
236240
int
237-
THP_PyFrame_FastToLocalsWithError(_PyInterpreterFrame *frame)
241+
THP_PyFrame_FastToLocalsWithError(_PyInterpreterFrame *frame, int *free_vars_copied)
238242
{
239-
PyObject *locals = THP_PyFrame_GetLocals(frame, 0);
243+
PyObject *locals = THP_PyFrame_GetLocals(frame, 0, free_vars_copied);
240244
if (locals == NULL) {
241245
return -1;
242246
}
@@ -247,8 +251,10 @@ THP_PyFrame_FastToLocalsWithError(_PyInterpreterFrame *frame)
247251
#else
248252

249253
// https://github.com/python/cpython/blob/a7715ccfba5b86ab09f86ec56ac3755c93b46b48/Objects/frameobject.c#L1182
254+
// free_vars_copied argument added in order to let caller know that the COPY_FREE_VARS
255+
// codepath occurred.
250256
int
251-
THP_PyFrame_FastToLocalsWithError(_PyInterpreterFrame *frame) {
257+
THP_PyFrame_FastToLocalsWithError(_PyInterpreterFrame *frame, int *free_vars_copied) {
252258
/* Merge fast locals into f->f_locals */
253259
PyObject *locals = NULL;
254260
PyObject **fast = NULL;
@@ -267,20 +273,17 @@ THP_PyFrame_FastToLocalsWithError(_PyInterpreterFrame *frame) {
267273
if (lasti < 0 && _Py_OPCODE(_PyCode_CODE(co)[0]) == COPY_FREE_VARS) {
268274
/* Free vars have not been initialized -- Do that */
269275
PyCodeObject *co = frame->f_code;
270-
#if IS_PYTHON_3_12_PLUS
271-
PyObject *closure = ((PyFunctionObject *)frame->f_funcobj)->func_closure;
272-
int offset = co->co_nlocals + co->co_ncellvars;
273-
#else
274276
PyObject *closure = frame->f_func->func_closure;
275277
int offset = co->co_nlocals + co->co_nplaincellvars;
276-
#endif
277278
for (int i = 0; i < co->co_nfreevars; ++i) {
278279
PyObject *o = PyTuple_GET_ITEM(closure, i);
279280
Py_INCREF(o);
280281
frame->localsplus[offset + i] = o;
281282
}
282283
// COPY_FREE_VARS doesn't have inline CACHEs, either:
283284
frame->prev_instr = _PyCode_CODE(frame->f_code);
285+
286+
*free_vars_copied = 1;
284287
}
285288
for (int i = 0; i < co->co_nlocalsplus; i++) {
286289
_PyLocals_Kind kind = _PyLocals_GetKind(co->co_localspluskinds, i);

torch/csrc/dynamo/cpython_defs.h

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,9 @@
1010

1111
#include <internal/pycore_frame.h>
1212

13-
int THP_PyFrame_FastToLocalsWithError(_PyInterpreterFrame* frame);
13+
int THP_PyFrame_FastToLocalsWithError(
14+
_PyInterpreterFrame* frame,
15+
int* free_vars_copied);
1416

1517
PyFunctionObject* _PyFunction_CopyWithNewCode(
1618
PyFunctionObject* o,

torch/csrc/dynamo/eval_frame.c

Lines changed: 40 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -132,7 +132,10 @@ THPPyInterpreterFrame* THPPyInterpreterFrame_New(_PyInterpreterFrame* frame) {
132132
#else
133133
#define THP_EVAL_API_FRAME_OBJECT PyFrameObject
134134

135-
#define THP_PyFrame_FastToLocalsWithError PyFrame_FastToLocalsWithError
135+
static int
136+
THP_PyFrame_FastToLocalsWithError(THP_EVAL_API_FRAME_OBJECT *frame, int *free_vars_copied) {
137+
return PyFrame_FastToLocalsWithError(frame);
138+
}
136139
#endif
137140

138141
PyObject* guard_error_hook = NULL;
@@ -161,7 +164,8 @@ static PyObject* _custom_eval_frame(
161164
PyThreadState* tstate,
162165
THP_EVAL_API_FRAME_OBJECT* frame,
163166
int throw_flag,
164-
PyObject* callback);
167+
PyObject* callback,
168+
int* should_clear_frame);
165169
static PyObject *(*previous_eval_frame)(PyThreadState *tstate,
166170
THP_EVAL_API_FRAME_OBJECT* frame, int throw_flag) = NULL;
167171

@@ -283,7 +287,8 @@ inline static PyObject* eval_custom_code_impl(
283287
PyThreadState* tstate,
284288
THP_EVAL_API_FRAME_OBJECT* frame,
285289
PyCodeObject* code,
286-
int throw_flag) {
290+
int throw_flag,
291+
int free_vars_copied) {
287292

288293
DEBUG_NULL_CHECK(tstate);
289294
DEBUG_NULL_CHECK(frame);
@@ -333,6 +338,13 @@ inline static PyObject* eval_custom_code_impl(
333338
}
334339
#endif
335340

341+
// for 3.11+, if free_vars_copied is true, we do not need to
342+
// run the first COPY_FREE_VARS since THP_PyFrame_FastToLocalsWithError
343+
// already did the equivalent action.
344+
if (free_vars_copied) {
345+
shadow->prev_instr = _PyCode_CODE(shadow->f_code);
346+
}
347+
336348
#else
337349

338350
THP_EVAL_API_FRAME_OBJECT* shadow = PyFrame_New(tstate, code, frame->f_globals, NULL);
@@ -432,10 +444,7 @@ inline static PyObject* eval_custom_code_impl(
432444

433445
#if IS_PYTHON_3_12_PLUS
434446

435-
// In 3.12, the frame evaluation function is responsible for
436-
// clearing and popping the frame, so we manually do that on the
437-
// old frame.
438-
clear_old_frame_if_python_312_plus(tstate, frame);
447+
// frame is cleared by caller
439448
Py_DECREF(func);
440449

441450
#elif IS_PYTHON_3_11_PLUS
@@ -460,13 +469,15 @@ inline static PyObject* eval_custom_code(
460469
PyThreadState* tstate,
461470
THP_EVAL_API_FRAME_OBJECT* frame,
462471
PyCodeObject* code,
463-
int throw_flag) {
472+
int throw_flag,
473+
int free_vars_copied) {
464474
_PytorchRecordFunctionState* rf = _pytorch_record_function_enter("Torch-Compiled Region");
465475
PyObject* result = eval_custom_code_impl(
466476
tstate,
467477
frame,
468478
code,
469-
throw_flag
479+
throw_flag,
480+
free_vars_copied
470481
);
471482
_pytorch_record_function_exit(rf);
472483
return result;
@@ -487,7 +498,12 @@ static PyObject* _custom_eval_frame_shim(
487498
return eval_frame_default(tstate, frame, throw_flag);
488499
}
489500

490-
return _custom_eval_frame(tstate, frame, throw_flag, callback);
501+
int should_clear_frame = 0;
502+
PyObject* result = _custom_eval_frame(tstate, frame, throw_flag, callback, &should_clear_frame);
503+
if (should_clear_frame) {
504+
clear_old_frame_if_python_312_plus(tstate, frame);
505+
}
506+
return result;
491507
}
492508

493509
// NOTE: In 3.12+, any return NULL; statements must be preceded by
@@ -498,7 +514,8 @@ static PyObject* _custom_eval_frame(
498514
PyThreadState* tstate,
499515
THP_EVAL_API_FRAME_OBJECT* frame,
500516
int throw_flag,
501-
PyObject* callback) {
517+
PyObject* callback,
518+
int* should_clear_frame) {
502519
#if IS_PYTHON_3_11_PLUS
503520
DEBUG_TRACE(
504521
"begin %s %s %i %i",
@@ -552,9 +569,10 @@ static PyObject* _custom_eval_frame(
552569
}
553570

554571
// TODO(jansel): investigate directly using the "fast" representation
555-
if (THP_PyFrame_FastToLocalsWithError(frame) < 0) {
572+
int free_vars_copied = 0;
573+
if (THP_PyFrame_FastToLocalsWithError(frame, &free_vars_copied) < 0) {
556574
DEBUG_TRACE("error %s", get_frame_name(frame));
557-
clear_old_frame_if_python_312_plus(tstate, frame);
575+
*should_clear_frame = 1;
558576
return NULL;
559577
}
560578

@@ -570,7 +588,7 @@ static PyObject* _custom_eval_frame(
570588

571589
if (maybe_cached_code == NULL) {
572590
// guard eval failed, keep propagating
573-
clear_old_frame_if_python_312_plus(tstate, frame);
591+
*should_clear_frame = 1;
574592
return NULL;
575593
} else if (maybe_cached_code == Py_None) {
576594
DEBUG_TRACE("cache miss %s", get_frame_name(frame));
@@ -579,7 +597,8 @@ static PyObject* _custom_eval_frame(
579597
PyCodeObject* cached_code = (PyCodeObject*)maybe_cached_code;
580598
// used cached version
581599
DEBUG_TRACE("cache hit %s", get_frame_name(frame));
582-
return eval_custom_code(tstate, frame, cached_code, throw_flag);
600+
*should_clear_frame = 1;
601+
return eval_custom_code(tstate, frame, cached_code, throw_flag, free_vars_copied);
583602
}
584603
DEBUG_CHECK(PyDict_CheckExact(frame->f_locals));
585604
DEBUG_CHECK(PyDict_CheckExact(frame->f_globals));
@@ -595,15 +614,16 @@ static PyObject* _custom_eval_frame(
595614
_pytorch_record_function_exit(rf);
596615
if (maybe_cached_code == NULL) {
597616
// Python error
598-
clear_old_frame_if_python_312_plus(tstate, frame);
617+
*should_clear_frame = 1;
599618
return NULL;
600619
} else if (maybe_cached_code != Py_None) {
601620
PyCodeObject* cached_code = (PyCodeObject*)maybe_cached_code;
602621
// used cached version
603622
DEBUG_TRACE("cache hit %s", get_frame_name(frame));
604623
// Re-enable custom behavior
605624
eval_frame_callback_set(callback);
606-
return eval_custom_code(tstate, frame, cached_code, throw_flag);
625+
*should_clear_frame = 1;
626+
return eval_custom_code(tstate, frame, cached_code, throw_flag, free_vars_copied);
607627
}
608628
// cache miss
609629
CacheEntry* cache_entry = extract_cache_entry(extra);
@@ -618,7 +638,7 @@ static PyObject* _custom_eval_frame(
618638
// cascading failure from internal exceptions. The upshot is if
619639
// Dynamo barfs, that's it for Dynamo, even if you catch the exception
620640
// inside the torch.compile block we won't try to Dynamo anything else.
621-
clear_old_frame_if_python_312_plus(tstate, frame);
641+
*should_clear_frame = 1;
622642
return NULL;
623643
} else if (result != Py_None) {
624644
DEBUG_TRACE("create cache %s", get_frame_name(frame));
@@ -636,7 +656,8 @@ static PyObject* _custom_eval_frame(
636656
// will be cleaned up when set_extra_state is called.
637657
// Re-enable custom behavior
638658
eval_frame_callback_set(callback);
639-
return eval_custom_code(tstate, frame, CacheEntry_get_code(new_cache_entry), throw_flag);
659+
*should_clear_frame = 1;
660+
return eval_custom_code(tstate, frame, CacheEntry_get_code(new_cache_entry), throw_flag, free_vars_copied);
640661
} else {
641662
DEBUG_TRACE("create skip %s", get_frame_name(frame));
642663
Py_DECREF(result);

0 commit comments

Comments
 (0)
0