8000 [dynamo] fix 3.11+ refleak by williamwen42 · Pull Request #124238 · pytorch/pytorch · GitHub
[go: up one dir, main page]

Skip to content

[dynamo] fix 3.11+ refleak #124238

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 6 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 0 additions & 4 deletions test/dynamo/test_misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,6 @@
same,
skipIfNotPy311,
unsupported,
xfailIfPy311,
)
from torch._dynamo.utils import CompileProfiler, counters, ifdynstaticde 8000 fault
from torch._inductor.utils import run_and_get_code
Expand Down Expand Up @@ -10100,7 +10099,6 @@ def forward(self, out):
lambda mod: mod.fc,
)

@xfailIfPy311
def test_sequential_module_free(self):
self._test_compile_model_free(
lambda: (
Expand All @@ -10113,14 +10111,12 @@ def test_sequential_module_free(self):
lambda mod: mod[0],
)

@xfailIfPy311
def test_linear_module_free(self):
self._test_compile_model_free(
lambda: (torch.nn.Linear(100, 100), torch.randn(100, 100)),
lambda mod: mod,
)

@xfailIfPy311
def test_outside_linear_module_free(self):
# Compared to test_linear_module_free, the linear
# layer is not the code object that is directly compiled.
Expand Down
6 changes: 0 additions & 6 deletions torch/_dynamo/testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -341,12 +341,6 @@ def skipIfNotPy311(fn):
return unittest.skip(fn)


def xfailIfPy311(fn):
if sys.version_info >= (3, 11):
return unittest.expectedFailure(fn)
return fn


# Controls tests generated in test/inductor/test_torchinductor_dynamic_shapes.py
# and test/dynamo/test_dynamic_shapes.py
def expectedFailureDynamic(fn):
Expand Down
25 changes: 14 additions & 11 deletions torch/csrc/dynamo/cpython_defs.c
8000
Original file line number Diff line number Diff line change
Expand Up @@ -68,8 +68,10 @@ THP_PyFrame_OpAlreadyRan(_PyInterpreterFrame *frame, int opcode, int oparg)

// https://github.com/python/cpython/blob/0325a8a8cdba6c091bcbbb3c995f3bf1d1217012/Objects/frameobject.c#L1136
// Initialize frame free variables if needed
// free_vars_copied argument added in order to let caller know that the COPY_FREE_VARS
// codepath occurred.
static void
frame_init_get_vars(_PyInterpreterFrame *frame)
frame_init_get_vars(_PyInterpreterFrame *frame, int *free_vars_copied)
{
// COPY_FREE_VARS has no quickened forms, so no need to use _PyOpcode_Deopt
// here:
Expand All @@ -91,6 +93,8 @@ frame_init_get_vars(_PyInterpreterFrame *frame)
}
// COPY_FREE_VARS doesn't have inline CACHEs, either:
frame->prev_instr = _PyCode_CODE(frame->f_code);

*free_vars_copied = 1;
}

// https://github.com/python/cpython/blob/0325a8a8cdba6c091bcbbb3c995f3bf1d1217012/Objects/frameobject.c#L1162
Expand Down Expand Up @@ -146,7 +150,7 @@ frame_get_var(_PyInterpreterFrame *frame, PyCodeObject *co, int i,

// https://github.com/python/cpython/blob/0325a8a8cdba6c091bcbbb3c995f3bf1d1217012/Objects/frameobject.c#L1213
static PyObject *
THP_PyFrame_GetLocals(_PyInterpreterFrame *frame, int include_hidden)
THP_PyFrame_GetLocals(_PyInterpreterFrame *frame, int include_hidden, int *free_vars_copied)
{
/* Merge fast locals into f->f_locals */
PyObject *locals = frame->f_locals;
Expand All @@ -169,7 +173,7 @@ THP_PyFrame_GetLocals(_PyInterpreterFrame *frame, int include_hidden)
}
}

frame_init_get_vars(frame);
frame_init_get_vars(frame, free_vars_copied);

PyCodeObject *co = frame->f_code;
for (int i = 0; i < co->co_nlocalsplus; i++) {
Expand Down Expand Up @@ -234,9 +238,9 @@ THP_PyFrame_GetLocals(_PyInterpreterFrame *frame, int include_hidden)

// https://github.com/python/cpython/blob/0325a8a8cdba6c091bcbbb3c995f3bf1d1217012/Objects/frameobject.c#L1301
int
THP_PyFrame_FastToLocalsWithError(_PyInterpreterFrame *frame)
THP_PyFrame_FastToLocalsWithError(_PyInterpreterFrame *frame, int *free_vars_copied)
{
PyObject *locals = THP_PyFrame_GetLocals(frame, 0);
PyObject *locals = THP_PyFrame_GetLocals(frame, 0, free_vars_copied);
if (locals == NULL) {
return -1;
}
Expand All @@ -247,8 +251,10 @@ THP_PyFrame_FastToLocalsWithError(_PyInterpreterFrame *frame)
#else

// https://github.com/python/cpython/blob/a7715ccfba5b86ab09f86ec56ac3755c93b46b48/Objects/frameobject.c#L1182
// free_vars_copied argument added in order to let caller know that the COPY_FREE_VARS
// codepath occurred.
int
THP_PyFrame_FastToLocalsWithError(_PyInterpreterFrame *frame) {
THP_PyFrame_FastToLocalsWithError(_PyInterpreterFrame *frame, int *free_vars_copied) {
/* Merge fast locals into f->f_locals */
PyObject *locals = NULL;
PyObject **fast = NULL;
Expand All @@ -267,20 +273,17 @@ THP_PyFrame_FastToLocalsWithError(_PyInterpreterFrame *frame) {
if (lasti < 0 && _Py_OPCODE(_PyCode_CODE(co)[0]) == COPY_FREE_VARS) {
/* Free vars have not been initialized -- Do that */
PyCodeObject *co = frame->f_code;
#if IS_PYTHON_3_12_PLUS
PyObject *closure = ((PyFunctionObject *)frame->f_funcobj)->func_closure;
int offset = co->co_nlocals + co->co_ncellvars;
#else
PyObject *closure = frame->f_func->func_closure;
int offset = co->co_nlocals + co->co_nplaincellvars;
#endif
for (int i = 0; i < co->co_nfreevars; ++i) {
PyObject *o = PyTuple_GET_ITEM(closure, i);
Py_INCREF(o);
frame->localsplus[offset + i] = o;
}
// COPY_FREE_VARS doesn't have inline CACHEs, either:
frame->prev_instr = _PyCode_CODE(frame->f_code);

*free_vars_copied = 1;
}
for (int i = 0; i < co->co_nlocalsplus; i++) {
_PyLocals_Kind kind = _PyLocals_GetKind(co->co_localspluskinds, i);
Expand Down
4 changes: 3 additions & 1 deletion torch/csrc/dynamo/cpython_defs.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,9 @@

#include <internal/pycore_frame.h>

int THP_PyFrame_FastToLocalsWithError(_PyInterpreterFrame* frame);
int THP_PyFrame_FastToLocalsWithError(
_PyInterpreterFrame* frame,
int* free_vars_copied);

PyFunctionObject* _PyFunction_CopyWithNewCode(
PyFunctionObject* o,
Expand Down
73 changes: 50 additions & 23 deletions torch/csrc/dynamo/eval_frame.c
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,10 @@ THPPyInterpreterFrame* THPPyInterpreterFrame_New(_PyInterpreterFrame* frame) {
#else
#define THP_EVAL_API_FRAME_OBJECT PyFrameObject

#define THP_PyFrame_FastToLocalsWithError PyFrame_FastToLocalsWithError
static int
THP_PyFrame_FastToLocalsWithError(THP_EVAL_API_FRAME_OBJECT *frame, int *free_vars_copied) {
return PyFrame_FastToLocalsWithError(frame);
}
#endif

PyObject* guard_error_hook = NULL;
Expand Down Expand Up @@ -161,7 +164,8 @@ static PyObject* _custom_eval_frame(
PyThreadState* tstate,
THP_EVAL_API_FRAME_OBJECT* frame,
int throw_flag,
PyObject* callback);
PyObject* callback,
int* should_clear_frame);
static PyObject *(*previous_eval_frame)(PyThreadState *tstate,
THP_EVAL_API_FRAME_OBJECT* frame, int throw_flag) = NULL;

Expand Down Expand Up @@ -283,7 +287,8 @@ inline static PyObject* eval_custom_code_impl(
PyThreadState* tstate,
THP_EVAL_API_FRAME_OBJECT* frame,
PyCodeObject* code,
int throw_flag) {
int throw_flag,
int free_vars_copied) {

DEBUG_NULL_CHECK(tstate);
DEBUG_NULL_CHECK(frame);
Expand Down Expand Up @@ -333,6 +338,13 @@ inline static PyObject* eval_custom_code_impl(
}
#endif

// for 3.11+, if free_vars_copied is true, we do not need to
// run the first COPY_FREE_VARS since THP_PyFrame_FastToLocalsWithError
// already did the equivalent action.
if (free_vars_copied && _Py_OPCODE(_PyCode_CODE(shadow->f_code)[0]) == COPY_FREE_VARS) {
shadow->prev_instr = _PyCode_CODE(shadow->f_code);
}

#else

THP_EVAL_API_FRAME_OBJECT* shadow = PyFrame_New(tstate, code, frame->f_globals, NULL);
Expand Down Expand Up @@ -428,14 +440,16 @@ inline static PyObject* eval_custom_code_impl(
fastlocals_new[j] = fastlocals_old[i];
}

// NOTE: if you want to evaluate frame instead of shadow in 3.12+,
// you need to clear_old_frame_if_python_312_plus the shadow frame BEFORE
// calling eval_frame_default (i.e. here) and comment out the
// clear_old_frame_if_python_312_plus call on the original frame.

PyObject* result = eval_frame_default(tstate, shadow, throw_flag);

#if IS_PYTHON_3_12_PLUS

// In 3.12, the frame evaluation function is responsible for
// clearing and popping the frame, so we manually do that on the
// old frame.
clear_old_frame_if_python_312_plus(tstate, frame);
// frame is cleared by caller
Py_DECREF(func);

#elif IS_PYTHON_3_11_PLUS
Expand All @@ -460,13 +474,15 @@ inline static PyObject* eval_custom_code(
PyThreadState* tstate,
THP_EVAL_API_FRAME_OBJECT* frame,
PyCodeObject* code,
int throw_flag) {
int throw_flag,
int free_vars_copied) {
_PytorchRecordFunctionState* rf = _pytorch_record_function_enter("Torch-Compiled Region");
PyObject* result = eval_custom_code_impl(
tstate,
frame,
code,
throw_flag
throw_flag,
free_vars_copied
);
_pytorch_record_function_exit(rf);
return result;
Expand All @@ -487,18 +503,25 @@ static PyObject* _custom_eval_frame_shim(
return eval_frame_default(tstate, frame, throw_flag);
}

return _custom_eval_frame(tstate, frame, throw_flag, callback);
int should_clear_frame = 0;
PyObject* result = _custom_eval_frame(tstate, frame, throw_flag, callback, &should_clear_frame);
if (should_clear_frame) {
clear_old_frame_if_python_312_plus(tstate, frame);
}
return result;
}

// NOTE: In 3.12+, any return NULL; statements must be preceded by
// clear_old_frame_if_python_312_plus(tstate, frame); since the eval frame function
// is now responsible for clearing/popping the frame.
// eval_frame_default/eval_custom_code will clear/pop the frame.
// NOTE: In 3.12+, the frame evaluation function (callee) is responsible for clearing/popping
// the frame, meaning that unless we default evaluate the original frame,
// we are responsible for clearing it - via clear_old_frame_if_python_312_plus.
// The should_clear_frame flag is used to indicate whether the frame should be
// cleared by _custom_eval_frame's caller.
static PyObject* _custom_eval_frame(
PyThreadState* tstate,
THP_EVAL_API_FRAME_OBJECT* frame,
int throw_flag,
PyObject* callback) {
PyObject* callback,
int* should_clear_frame) {
#if IS_PYTHON_3_11_PLUS
DEBUG_TRACE(
"begin %s %s %i %i",
Expand Down Expand Up @@ -552,9 +575,10 @@ static PyObject* _custom_eval_frame(
}

// TODO(jansel): investigate directly using the "fast" representation
if (THP_PyFrame_FastToLocalsWithError(frame) < 0) {
int free_vars_copied = 0;
if (THP_PyFrame_FastToLocalsWithError(frame, &free_vars_copied) < 0) {
DEBUG_TRACE("error %s", get_frame_name(frame));
clear_old_frame_if_python_312_plus(tstate, frame);
*should_clear_frame = 1;
return NULL;
}

Expand All @@ -570,7 +594,7 @@ static PyObject* _custom_eval_frame(

if (maybe_cached_code == NULL) {
// guard eval failed, keep propagating
clear_old_frame_if_python_312_plus(tstate, frame);
*should_clear_frame = 1;
return NULL;
} else if (maybe_cached_code == Py_None) {
DEBUG_TRACE("cache miss %s", get_frame_name(frame));
Expand All @@ -579,7 +603,8 @@ static PyObject* _custom_eval_frame(
PyCodeObject* cached_code = (PyCodeObject*)maybe_cached_code;
// used cached version
DEBUG_TRACE("cache hit %s", get_frame_name(frame));
return eval_custom_code(tstate, frame, cached_code, throw_flag);
*should_clear_frame = 1;
return eval_custom_code(tstate, frame, cached_code, throw_flag, free_vars_copied);
}
DEBUG_CHECK(PyDict_CheckExact(frame->f_locals));
DEBUG_CHECK(PyDict_CheckExact(frame->f_globals));
Expand All @@ -595,15 +620,16 @@ static PyObject* _custom_eval_frame(
_pytorch_record_function_exit(rf);
if (maybe_cached_code == NULL) {
// Python error
clear_old_frame_if_python_312_plus(tstate, frame);
*should_clear_frame = 1;
return NULL;
} else if (maybe_cached_code != Py_None) {
PyCodeObject* cached_code = (PyCodeObject*)maybe_cached_code;
// used cached version
DEBUG_TRACE("cache hit %s", get_frame_name(frame));
// Re-enable custom behavior
eval_frame_callback_set(callback);
return eval_custom_code(tstate, frame, cached_code, throw_flag);
*should_clear_frame = 1;
return eval_custom_code(tstate, frame, cached_code, throw_flag, free_vars_copied);
}
// cache miss
CacheEntry* cache_entry = extract_cache_entry(extra);
Expand All @@ -618,7 +644,7 @@ static PyObject* _custom_eval_frame(
// cascading failure from internal exceptions. The upshot is if
// Dynamo barfs, that's it for Dynamo, even if you catch the exception
// inside the torch.compile block we won't try to Dynamo anything else.
clear_old_frame_if_python_312_plus(tstate, frame);
*should_clear_frame = 1;
return NULL;
} else if (result != Py_None) {
DEBUG_TRACE("create cache %s", get_frame_name(frame));
Expand All @@ -636,7 +662,8 @@ static PyObject* _custom_eval_frame(
// will be cleaned up when set_extra_state is called.
// Re-enable custom behavior
eval_frame_callback_set(callback);
return eval_custom_code(tstate, frame, CacheEntry_get_code(new_cache_entry), throw_flag);
*should_clear_frame = 1;
return eval_custom_code(tstate, frame, CacheEntry_get_code(new_cache_entry), throw_flag, free_vars_copied);
} else {
DEBUG_TRACE("create skip %s", get_frame_name(frame));
Py_DECREF(result);
Expand Down
Loading
0