8000 [mypyc] Simplify generated code for native attribute get by JukkaL · Pull Request #11978 · python/mypy · GitHub
[go: up one dir, main page]

Skip to content

[mypyc] Simplify generated code for native attribute get #11978

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 8 commits into from
Mar 26, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
73 changes: 62 additions & 11 deletions mypyc/codegen/emitfunc.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
"""Code generation for native function bodies."""

from typing import Union, Optional
from typing import List, Union, Optional
from typing_extensions import Final

from mypyc.common import (
REG_PREFIX, NATIVE_PREFIX, STATIC_PREFIX, TYPE_PREFIX, MODULE_PREFIX,
)
from mypyc.codegen.emit import Emitter
from mypyc.ir.ops import (
OpVisitor, Goto, Branch, Return, Assign, Integer, LoadErrorValue, GetAttr, SetAttr,
Op, OpVisitor, Goto, Branch, Return, Assign, Integer, LoadErrorValue, GetAttr, SetAttr,
LoadStatic, InitStatic, TupleGet, TupleSet, Call, IncRef, DecRef, Box, Cast, Unbox,
BasicBlock, Value, MethodCall, Unreachable, NAMESPACE_STATIC, NAMESPACE_TYPE, NAMESPACE_MODULE,
RaiseStandardError, CallC, LoadGlobal, Truncate, IntOp, LoadMem, GetElementPtr,
Expand Down Expand Up @@ -88,8 +88,13 @@ def generate_native_function(fn: FuncIR,
next_block = blocks[i + 1]
body.emit_label(block)
visitor.next_block = next_block
for op in block.ops:
op.accept(visitor)

ops = block.ops
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This could probably use a comment of what we are keeping track of here now that the loop is a lot more complex

visitor.ops = ops
visitor.op_index = 0
while visitor.op_index < len(ops):
ops[visitor.op_index].accept(visitor)
visitor.op_index += 1

body.emit_line('}')

Expand All @@ -110,7 +115,12 @@ def __init__(self,
self.module_name = module_name
self.literals = emitter.context.literals
self.rare = False
# Next basic block to be processed after the current one (if any), set by caller
self.next_block: Optional[BasicBlock] = None
# Ops in the basic block currently being processed, set by caller
self.ops: List[Op] = []
# Current index within ops; visit methods can increment this to skip/merge ops
self.op_index = 0

def temp_name(self) -> str:
return self.emitter.temp_name()
Expand Down Expand Up @@ -293,16 +303,44 @@ def visit_get_attr(self, op: GetAttr) -> None:
attr_expr = self.get_attr_expr(obj, op, decl_cl)
self.emitter.emit_line('{} = {};'.format(dest, attr_expr))
self.emitter.emit_undefined_attr_check(
attr_rtype, attr_expr, '==', unlikely=True
attr_rtype, dest, '==', unlikely=True
)
exc_class = 'PyExc_AttributeError'
self.emitter.emit_line(
'PyErr_SetString({}, "attribute {} of {} undefined");'.format(
exc_class, repr(op.attr), repr(cl.name)))
merged_branch = None
branch = self.next_branch()
if branch is not None:
if (branch.value is op
and branch.op == Branch.IS_ERROR
and branch.traceback_entry is not None
and not branch.negated):
# Generate code for the following branch here to avoid
# redundant branches in the generate code.
self.emit_attribute_error(branch, cl.name, op.attr)
self.emit_line('goto %s;' % self.label(branch.true))
merged_branch = branch
self.emitter.emit_line('}')
if not merged_branch:
self.emitter.emit_line(
'PyErr_SetString({}, "attribute {} of {} undefined");'.format(
exc_class, repr(op.attr), repr(cl.name)))

if attr_rtype.is_refcounted:
self.emitter.emit_line('} else {')
self.emitter.emit_inc_ref(attr_expr, attr_rtype)
self.emitter.emit_line('}')
if not merged_branch:
self.emitter.emit_line('} else {')
self.emitter.emit_inc_ref(dest, attr_rtype)
if merged_branch:
if merged_branch.false is not self.next_block:
self.emit_line('goto %s;' % self.label(merged_branch.false))
self.op_index += 1
else:
self.emitter.emit_line('}')

def next_branch(self) -> Optional[Branch]:
if self.op_index + 1 < len(self.ops):
next_op = self.ops[self.op_index + 1]
if isinstance(next_op, Branch):
return next_op
return None

def visit_set_attr(self, op: SetAttr) -> None:
dest = self.reg(op)
Expand Down Expand Up @@ -603,6 +641,19 @@ def emit_traceback(self, op: Branch) -> None:
if DEBUG_ERRORS:
self.emit_line('assert(PyErr_Occurred() != NULL && "failure w/o err!");')

def emit_attribute_error(self, op: Branch, class_name: str, attr: str) -> None:
assert op.traceback_entry is not None
globals_static = self.emitter.static_name('globals', self.module_name)
self.emit_line('CPy_AttributeError("%s", "%s", "%s", "%s", %d, %s);' % (
self.source_path.replace("\\", "\\\\"),
op.traceback_entry[0],
class_name,
attr,
op.traceback_entry[1],
globals_static))
if DEBUG_ERRORS:
self.emit_line('assert(PyErr_Occurred() != NULL && "failure w/o err!");')

def emit_signed_int_cast(self, type: RType) -> str:
if is_tagged(type):
return '(Py_ssize_t)'
Expand Down
2 changes: 2 additions & 0 deletions mypyc/lib-rt/CPy.h
Original file line number Diff line number Diff line change
Expand Up @@ -498,6 +498,8 @@ void _CPy_GetExcInfo(PyObject **p_type, PyObject **p_value, PyObject **p_traceba
void CPyError_OutOfMemory(void);
void CPy_TypeError(const char *expected, PyObject *value);
void CPy_AddTraceback(const char *filename, const char *funcname, int line, PyObject *globals);
void CPy_AttributeError(const char *filename, const char *funcname, const char *classname,
const char *attrname, int line, PyObject *globals);


// Misc operations
Expand Down
8 changes: 8 additions & 0 deletions mypyc/lib-rt/exc_ops.c
Original file line number Diff line number Diff line change
Expand Up @@ -225,3 +225,11 @@ void CPy_AddTraceback(const char *filename, const char *funcname, int line, PyOb
error:
_PyErr_ChainExceptions(exc, val, tb);
}

void CPy_AttributeError(const char *filename, const char *funcname, const char *classname,
const char *attrname, int line, PyObject *globals) {
char buf[500];
snprintf(buf, sizeof(buf), "attribute '%.200s' of '%.200s' undefined", classname, attrname);
PyErr_SetString(PyExc_AttributeError, buf);
CPy_AddTraceback(filename, funcname, line, globals);
}
31 changes: 27 additions & 4 deletions mypyc/test/test_emitfunc.py
Original file line number Diff line number Diff line change
Expand Up @@ -281,22 +281,39 @@ def test_get_attr(self) -> None:
self.assert_emit(
GetAttr(self.r, 'y', 1),
"""cpy_r_r0 = ((mod___AObject *)cpy_r_r)->_y;
if (unlikely(((mod___AObject *)cpy_r_r)->_y == CPY_INT_TAG)) {
if (unlikely(cpy_r_r0 == CPY_INT_TAG)) {
PyErr_SetString(PyExc_AttributeError, "attribute 'y' of 'A' undefined");
} else {
CPyTagged_INCREF(((mod___AObject *)cpy_r_r)->_y);
CPyTagged_INCREF(cpy_r_r0);
}
""")

def test_get_attr_non_refcounted(self) -> None:
self.assert_emit(
GetAttr(self.r, 'x', 1),
"""cpy_r_r0 = ((mod___AObject *)cpy_r_r)->_x;
if (unlikely(((mod___AObject *)cpy_r_r)->_x == 2)) {
if (unlikely(cpy_r_r0 == 2)) {
PyErr_SetString(PyExc_AttributeError, "attribute 'x' of 'A' undefined");
}
""")

def test_get_attr_merged(self) -> None:
op = GetAttr(self.r, 'y', 1)
branch = Branch(op, BasicBlock(8), BasicBlock(9), Branch.IS_ERROR)
branch.traceback_entry = ('foobar', 123)
self.assert_emit(
op,
"""\
cpy_r_r0 = ((mod___AObject *)cpy_r_r)->_y;
if (unlikely(cpy_r_r0 == CPY_INT_TAG)) {
CPy_AttributeError("prog.py", "foobar", "A", "y", 123, CPyStatic_prog___globals);
goto CPyL8;
}
CPyTagged_INCREF(cpy_r_r0);
goto CPyL9;
""",
next_branch=branch)

def test_set_attr(self) -> None:
self.assert_emit(
SetAttr(self.r, 'y', self.m, 1),
Expand Down Expand Up @@ -428,7 +445,8 @@ def assert_emit(self,
expected: str,
next_block: Optional[BasicBlock] = None,
*,
rare: bool = False) -> None:
rare: bool = False,
next_branch: Optional[Branch] = None) -> None:
block = BasicBlock(0)
block.ops.append(op)
value_names = generate_names_for_ir(self.registers, [block])
Expand All @@ -440,6 +458,11 @@ def assert_emit(self,
visitor = FunctionEmitterVisitor(emitter, declarations, 'prog.py', 'prog')
visitor.next_block = next_block
visitor.rare = rare
if next_branch:
visitor.ops = [op, next_branch]
else:
visitor.ops = [op]
visitor.op_index = 0

op.accept(visitor)
frags = declarations.fragments + emitter.fragments
Expand Down
0