From a534dae51c43a03514da3174d4ecef866565311a Mon Sep 17 00:00:00 2001 From: Elazar Gershuni Date: Sun, 2 Oct 2016 13:48:12 +0300 Subject: [PATCH 1/3] tighten types treetransform --- mypy/fastparse.py | 58 ++++++++----- mypy/fastparse2.py | 52 +++++++----- mypy/nodes.py | 4 +- mypy/parse.py | 145 ++++++++++++++++----------------- mypy/treetransform.py | 184 +++++++++++++++++++++++------------------- 5 files changed, 243 insertions(+), 200 deletions(-) diff --git a/mypy/fastparse.py b/mypy/fastparse.py index 195915fa8db9..218db15a25eb 100644 --- a/mypy/fastparse.py +++ b/mypy/fastparse.py @@ -125,8 +125,22 @@ def generic_visit(self, node: ast35.AST) -> None: def visit_NoneType(self, n: Any) -> Optional[Node]: return None - def visit_list(self, l: Sequence[ast35.AST]) -> List[Expression]: - return [self.visit(e) for e in l] + def translate_expr_list(self, l: Sequence[ast35.AST]) -> List[Expression]: + res = [] # type: List[Expression] + for e in l: + exp = self.visit(e) + assert exp is None or isinstance(exp, Expression) + res.append(exp) + return res + + def translate_stmt_list(self, l: Sequence[ast35.AST]) -> List[Statement]: + res = [] # type: List[Statement] + for e in l: + stmt = self.visit(e) + assert stmt is None or isinstance(stmt, Statement) + res.append(stmt) + return res + op_map = { ast35.Add: '+', @@ -176,7 +190,7 @@ def from_comp_operator(self, op: ast35.cmpop) -> str: def as_block(self, stmts: List[ast35.stmt], lineno: int) -> Block: b = None if stmts: - b = Block(self.fix_function_overloads(self.visit_list(stmts))) + b = Block(self.fix_function_overloads(self.translate_stmt_list(stmts))) b.set_line(lineno) return b @@ -225,7 +239,7 @@ def translate_module_id(self, id: str) -> str: return id def visit_Module(self, mod: ast35.Module) -> MypyFile: - body = self.fix_function_overloads(self.visit_list(mod.body)) + body = self.fix_function_overloads(self.translate_stmt_list(mod.body)) return MypyFile(body, self.imports, @@ -269,7 +283,7 @@ def do_func_def(self, n: Union[ast35.FunctionDef, ast35.AsyncFunctionDef], for a in args] else: arg_types = [a if a is not None else AnyType() for - a in TypeConverter(line=n.lineno).visit_list(func_type_ast.argtypes)] + a in TypeConverter(line=n.lineno).translate_expr_list(func_type_ast.argtypes)] return_type = TypeConverter(line=n.lineno).visit(func_type_ast.returns) # add implicit self type @@ -312,7 +326,7 @@ def do_func_def(self, n: Union[ast35.FunctionDef, ast35.AsyncFunctionDef], func_def.is_decorated = True func_def.set_line(n.lineno + len(n.decorator_list)) func_def.body.set_line(func_def.get_line()) - return Decorator(func_def, self.visit_list(n.decorator_list), var) + return Decorator(func_def, self.translate_expr_list(n.decorator_list), var) else: return func_def @@ -382,9 +396,9 @@ def visit_ClassDef(self, n: ast35.ClassDef) -> ClassDef: cdef = ClassDef(n.name, self.as_block(n.body, n.lineno), None, - self.visit_list(n.bases), + self.translate_expr_list(n.bases), metaclass=metaclass) - cdef.decorators = self.visit_list(n.decorator_list) + cdef.decorators = self.translate_expr_list(n.decorator_list) self.class_nesting -= 1 return cdef @@ -397,7 +411,7 @@ def visit_Return(self, n: ast35.Return) -> ReturnStmt: @with_line def visit_Delete(self, n: ast35.Delete) -> DelStmt: if len(n.targets) > 1: - tup = TupleExpr(self.visit_list(n.targets)) + tup = TupleExpr(self.translate_expr_list(n.targets)) tup.set_line(n.lineno) return DelStmt(tup) else: @@ -424,7 +438,7 @@ def visit_Assign(self, n: ast35.Assign) -> AssignmentStmt: rvalue = TempNode(AnyType()) # type: Expression else: rvalue = self.visit(n.value) - lvalues = self.visit_list(n.targets) + lvalues = self.translate_expr_list(n.targets) return AssignmentStmt(lvalues, rvalue, type=typ, new_syntax=new_syntax) @@ -590,7 +604,7 @@ def group(vals: List[Expression]) -> OpExpr: else: return OpExpr(op, vals[0], group(vals[1:])) - return group(self.visit_list(n.values)) + return group(self.translate_expr_list(n.values)) # BinOp(expr left, operator op, expr right) @with_line @@ -640,12 +654,12 @@ def visit_IfExp(self, n: ast35.IfExp) -> ConditionalExpr: # Dict(expr* keys, expr* values) @with_line def visit_Dict(self, n: ast35.Dict) -> DictExpr: - return DictExpr(list(zip(self.visit_list(n.keys), self.visit_list(n.values)))) + return DictExpr(list(zip(self.translate_expr_list(n.keys), self.translate_expr_list(n.values)))) # Set(expr* elts) @with_line def visit_Set(self, n: ast35.Set) -> SetExpr: - return SetExpr(self.visit_list(n.elts)) + return SetExpr(self.translate_expr_list(n.elts)) # ListComp(expr elt, comprehension* generators) @with_line @@ -662,7 +676,7 @@ def visit_SetComp(self, n: ast35.SetComp) -> SetComprehension: def visit_DictComp(self, n: ast35.DictComp) -> DictionaryComprehension: targets = [self.visit(c.target) for c in n.generators] iters = [self.visit(c.iter) for c in n.generators] - ifs_list = [self.visit_list(c.ifs) for c in n.generators] + ifs_list = [self.translate_expr_list(c.ifs) for c in n.generators] return DictionaryComprehension(self.visit(n.key), self.visit(n.value), targets, @@ -674,7 +688,7 @@ def visit_DictComp(self, n: ast35.DictComp) -> DictionaryComprehension: def visit_GeneratorExp(self, n: ast35.GeneratorExp) -> GeneratorExpr: targets = [self.visit(c.target) for c in n.generators] iters = [self.visit(c.iter) for c in n.generators] - ifs_list = [self.visit_list(c.ifs) for c in n.generators] + ifs_list = [self.translate_expr_list(c.ifs) for c in n.generators] return GeneratorExpr(self.visit(n.elt), targets, iters, @@ -700,7 +714,7 @@ def visit_YieldFrom(self, n: ast35.YieldFrom) -> YieldFromExpr: @with_line def visit_Compare(self, n: ast35.Compare) -> ComparisonExpr: operators = [self.from_comp_operator(o) for o in n.ops] - operands = self.visit_list([n.left] + n.comparators) + operands = self.translate_expr_list([n.left] + n.comparators) return ComparisonExpr(operators, operands) # Call(expr func, expr* args, keyword* keywords) @@ -710,7 +724,7 @@ def visit_Call(self, n: ast35.Call) -> CallExpr: def is_star2arg(k: ast35.keyword) -> bool: return k.arg is None - arg_types = self.visit_list( + arg_types = self.translate_expr_list( [a.value if isinstance(a, ast35.Starred) else a for a in n.args] + [k.value for k in n.keywords]) arg_kinds = ([ARG_STAR if isinstance(a, ast35.Starred) else ARG_POS for a in n.args] + @@ -812,7 +826,7 @@ def visit_Slice(self, n: ast35.Slice) -> SliceExpr: # ExtSlice(slice* dims) def visit_ExtSlice(self, n: ast35.ExtSlice) -> TupleExpr: - return TupleExpr(self.visit_list(n.dims)) + return TupleExpr(self.translate_expr_list(n.dims)) # Index(expr value) def visit_Index(self, n: ast35.Index) -> Node: @@ -836,7 +850,7 @@ def generic_visit(self, node: ast35.AST) -> None: def visit_NoneType(self, n: Any) -> Type: return None - def visit_list(self, l: Sequence[ast35.AST]) -> List[Type]: + def translate_expr_list(self, l: Sequence[ast35.AST]) -> List[Type]: return [self.visit(e) for e in l] def visit_Name(self, n: ast35.Name) -> Type: @@ -860,7 +874,7 @@ def visit_Subscript(self, n: ast35.Subscript) -> Type: empty_tuple_index = False if isinstance(n.slice.value, ast35.Tuple): - params = self.visit_list(n.slice.value.elts) + params = self.translate_expr_list(n.slice.value.elts) if len(n.slice.value.elts) == 0: empty_tuple_index = True else: @@ -869,7 +883,7 @@ def visit_Subscript(self, n: ast35.Subscript) -> Type: return UnboundType(value.name, params, line=self.line, empty_tuple_index=empty_tuple_index) def visit_Tuple(self, n: ast35.Tuple) -> Type: - return TupleType(self.visit_list(n.elts), None, implicit=True, line=self.line) + return TupleType(self.translate_expr_list(n.elts), None, implicit=True, line=self.line) # Attribute(expr value, identifier attr, expr_context ctx) def visit_Attribute(self, n: ast35.Attribute) -> Type: @@ -886,7 +900,7 @@ def visit_Ellipsis(self, n: ast35.Ellipsis) -> Type: # List(expr* elts, expr_context ctx) def visit_List(self, n: ast35.List) -> Type: - return TypeList(self.visit_list(n.elts), line=self.line) + return TypeList(self.translate_expr_list(n.elts), line=self.line) class TypeCommentParseError(Exception): diff --git a/mypy/fastparse2.py b/mypy/fastparse2.py index bb0b798cce63..93e7c37d6580 100644 --- a/mypy/fastparse2.py +++ b/mypy/fastparse2.py @@ -142,8 +142,22 @@ def generic_visit(self, node: ast27.AST) -> None: def visit_NoneType(self, n: Any) -> Optional[Node]: return None - def visit_list(self, l: Sequence[ast27.AST]) -> List[Expression]: - return [self.visit(e) for e in l] + def translate_expr_list(self, l: Sequence[ast27.AST]) -> List[Expression]: + res = [] # type: List[Expression] + for e in l: + exp = self.visit(e) + assert isinstance(exp, Expression) + res.append(exp) + return res + + def translate_stmt_list(self, l: Sequence[ast27.AST]) -> List[Statement]: + res = [] # type: List[Statement] + for e in l: + stmt = self.visit(e) + assert isinstance(stmt, Statement) + res.append(stmt) + return res + op_map = { ast27.Add: '+', @@ -192,7 +206,7 @@ def from_comp_operator(self, op: ast27.cmpop) -> str: def as_block(self, stmts: List[ast27.stmt], lineno: int) -> Block: b = None if stmts: - b = Block(self.fix_function_overloads(self.visit_list(stmts))) + b = Block(self.fix_function_overloads(self.translate_stmt_list(stmts))) b.set_line(lineno) return b @@ -241,7 +255,7 @@ def translate_module_id(self, id: str) -> str: return id def visit_Module(self, mod: ast27.Module) -> MypyFile: - body = self.fix_function_overloads(self.visit_list(mod.body)) + body = self.fix_function_overloads(self.translate_stmt_list(mod.body)) return MypyFile(body, self.imports, @@ -275,7 +289,7 @@ def visit_FunctionDef(self, n: ast27.FunctionDef) -> Statement: for a in args] else: arg_types = [a if a is not None else AnyType() for - a in converter.visit_list(func_type_ast.argtypes)] + a in converter.translate_expr_list(func_type_ast.argtypes)] return_type = converter.visit(func_type_ast.returns) # add implicit self type @@ -315,7 +329,7 @@ def visit_FunctionDef(self, n: ast27.FunctionDef) -> Statement: func_def.is_decorated = True func_def.set_line(n.lineno + len(n.decorator_list)) func_def.body.set_line(func_def.get_line()) - return Decorator(func_def, self.visit_list(n.decorator_list), var) + return Decorator(func_def, self.translate_expr_list(n.decorator_list), var) else: return func_def @@ -354,7 +368,7 @@ def get_type(i: int) -> Optional[Type]: return None args = [(convert_arg(arg), get_type(i)) for i, arg in enumerate(n.args)] - defaults = self.visit_list(n.defaults) + defaults = self.translate_expr_list(n.defaults) new_args = [] # type: List[Argument] num_no_defaults = len(args) - len(defaults) @@ -397,9 +411,9 @@ def visit_ClassDef(self, n: ast27.ClassDef) -> ClassDef: cdef = ClassDef(n.name, self.as_block(n.body, n.lineno), None, - self.visit_list(n.bases), + self.translate_expr_list(n.bases), metaclass=None) - cdef.decorators = self.visit_list(n.decorator_list) + cdef.decorators = self.translate_expr_list(n.decorator_list) self.class_nesting -= 1 return cdef @@ -412,7 +426,7 @@ def visit_Return(self, n: ast27.Return) -> ReturnStmt: @with_line def visit_Delete(self, n: ast27.Delete) -> DelStmt: if len(n.targets) > 1: - tup = TupleExpr(self.visit_list(n.targets)) + tup = TupleExpr(self.translate_expr_list(n.targets)) tup.set_line(n.lineno) return DelStmt(tup) else: @@ -425,7 +439,7 @@ def visit_Assign(self, n: ast27.Assign) -> AssignmentStmt: if n.type_comment: typ = parse_type_comment(n.type_comment, n.lineno) - return AssignmentStmt(self.visit_list(n.targets), + return AssignmentStmt(self.translate_expr_list(n.targets), self.visit(n.value), type=typ) @@ -644,7 +658,7 @@ def group(vals: List[Expression]) -> OpExpr: else: return OpExpr(op, vals[0], group(vals[1:])) - return group(self.visit_list(n.values)) + return group(self.translate_expr_list(n.values)) # BinOp(expr left, operator op, expr right) @with_line @@ -694,12 +708,12 @@ def visit_IfExp(self, n: ast27.IfExp) -> ConditionalExpr: # Dict(expr* keys, expr* values) @with_line def visit_Dict(self, n: ast27.Dict) -> DictExpr: - return DictExpr(list(zip(self.visit_list(n.keys), self.visit_list(n.values)))) + return DictExpr(list(zip(self.translate_expr_list(n.keys), self.translate_expr_list(n.values)))) # Set(expr* elts) @with_line def visit_Set(self, n: ast27.Set) -> SetExpr: - return SetExpr(self.visit_list(n.elts)) + return SetExpr(self.translate_expr_list(n.elts)) # ListComp(expr elt, comprehension* generators) @with_line @@ -716,7 +730,7 @@ def visit_SetComp(self, n: ast27.SetComp) -> SetComprehension: def visit_DictComp(self, n: ast27.DictComp) -> DictionaryComprehension: targets = [self.visit(c.target) for c in n.generators] iters = [self.visit(c.iter) for c in n.generators] - ifs_list = [self.visit_list(c.ifs) for c in n.generators] + ifs_list = [self.translate_expr_list(c.ifs) for c in n.generators] return DictionaryComprehension(self.visit(n.key), self.visit(n.value), targets, @@ -728,7 +742,7 @@ def visit_DictComp(self, n: ast27.DictComp) -> DictionaryComprehension: def visit_GeneratorExp(self, n: ast27.GeneratorExp) -> GeneratorExpr: targets = [self.visit(c.target) for c in n.generators] iters = [self.visit(c.iter) for c in n.generators] - ifs_list = [self.visit_list(c.ifs) for c in n.generators] + ifs_list = [self.translate_expr_list(c.ifs) for c in n.generators] return GeneratorExpr(self.visit(n.elt), targets, iters, @@ -743,7 +757,7 @@ def visit_Yield(self, n: ast27.Yield) -> YieldExpr: @with_line def visit_Compare(self, n: ast27.Compare) -> ComparisonExpr: operators = [self.from_comp_operator(o) for o in n.ops] - operands = self.visit_list([n.left] + n.comparators) + operands = self.translate_expr_list([n.left] + n.comparators) return ComparisonExpr(operators, operands) # Call(expr func, expr* args, keyword* keywords) @@ -773,7 +787,7 @@ def visit_Call(self, n: ast27.Call) -> CallExpr: signature.append(None) return CallExpr(self.visit(n.func), - self.visit_list(arg_types), + self.translate_expr_list(arg_types), arg_kinds, cast("List[str]", signature)) @@ -870,7 +884,7 @@ def visit_Slice(self, n: ast27.Slice) -> SliceExpr: # ExtSlice(slice* dims) def visit_ExtSlice(self, n: ast27.ExtSlice) -> TupleExpr: - return TupleExpr(self.visit_list(n.dims)) + return TupleExpr(self.translate_expr_list(n.dims)) # Index(expr value) def visit_Index(self, n: ast27.Index) -> Expression: diff --git a/mypy/nodes.py b/mypy/nodes.py index b8e27926a266..9aaa7c355326 100644 --- a/mypy/nodes.py +++ b/mypy/nodes.py @@ -135,7 +135,9 @@ def accept(self, visitor: NodeVisitor[T]) -> T: # These are placeholders for a future refactoring; see #1783. # For now they serve as (unchecked) documentation of what various # fields of Node subtypes are expected to contain. -Statement = Node +class Statement(Node): + pass + Expression = Node Lvalue = Expression diff --git a/mypy/parse.py b/mypy/parse.py index ef4da013ef88..2c18297749b4 100644 --- a/mypy/parse.py +++ b/mypy/parse.py @@ -6,7 +6,7 @@ import re -from typing import List, Tuple, Any, Set, cast, Union, Optional +from typing import List, Tuple, Set, cast, Union, Optional from mypy import lex from mypy.lex import ( @@ -14,13 +14,12 @@ UnicodeLit, FloatLit, Op, Indent, Keyword, Punct, LexError, ComplexLit, EllipsisToken ) -import mypy.types from mypy.nodes import ( - MypyFile, Import, Node, ImportAll, ImportFrom, FuncDef, OverloadedFuncDef, - ClassDef, Decorator, Block, Var, OperatorAssignmentStmt, + MypyFile, Import, ImportAll, ImportFrom, FuncDef, OverloadedFuncDef, + ClassDef, Decorator, Block, Var, OperatorAssignmentStmt, Statement, ExpressionStmt, AssignmentStmt, ReturnStmt, RaiseStmt, AssertStmt, DelStmt, BreakStmt, ContinueStmt, PassStmt, GlobalDecl, - WhileStmt, ForStmt, IfStmt, TryStmt, WithStmt, + WhileStmt, ForStmt, IfStmt, TryStmt, WithStmt, Expression, TupleExpr, GeneratorExpr, ListComprehension, ListExpr, ConditionalExpr, DictExpr, SetExpr, NameExpr, IntExpr, StrExpr, BytesExpr, UnicodeExpr, FloatExpr, CallExpr, SuperExpr, MemberExpr, IndexExpr, SliceExpr, OpExpr, @@ -34,7 +33,7 @@ from mypy.errors import Errors, CompileError from mypy.types import Type, CallableType, AnyType, UnboundType from mypy.parsetype import ( - parse_type, parse_types, parse_signature, TypeParseError, parse_str_as_signature + parse_type, parse_types, parse_signature, TypeParseError ) from mypy.options import Options @@ -235,7 +234,7 @@ def translate_module_id(self, id: str) -> str: return 'builtins' return id - def parse_import_from(self) -> Node: + def parse_import_from(self) -> ImportBase: self.expect('from') # Build the list of beginning relative tokens. @@ -318,8 +317,8 @@ def parse_qualified_name(self) -> str: # Parsing global definitions - def parse_defs(self) -> List[Node]: - defs = [] # type: List[Node] + def parse_defs(self) -> List[Statement]: + defs = [] # type: List[Statement] while not self.eof(): try: defn, is_simple = self.parse_statement() @@ -340,7 +339,7 @@ def parse_class_def(self) -> ClassDef: metaclass = None try: - base_types = [] # type: List[Node] + base_types = [] # type: List[Expression] try: name_tok = self.expect_type(Name) name = name_tok.string @@ -391,10 +390,10 @@ def parse_class_keywords(self) -> Optional[str]: break return metaclass - def parse_super_type(self) -> Node: + def parse_super_type(self) -> Expression: return self.parse_expression(precedence[',']) - def parse_decorated_function_or_class(self) -> Node: + def parse_decorated_function_or_class(self) -> Union[Decorator, ClassDef]: decorators = [] no_type_checks = False while self.current_str() == '@': @@ -418,7 +417,7 @@ def parse_decorated_function_or_class(self) -> Node: cls.decorators = decorators return cls - def is_no_type_check_decorator(self, expr: Node) -> bool: + def is_no_type_check_decorator(self, expr: Expression) -> bool: if isinstance(expr, NameExpr): return expr.name == 'no_type_check' elif isinstance(expr, MemberExpr): @@ -427,7 +426,7 @@ def is_no_type_check_decorator(self, expr: Node) -> bool: else: return False - def parse_function(self, no_type_checks: bool=False) -> FuncDef: + def parse_function(self, no_type_checks: bool = False) -> FuncDef: def_tok = self.expect('def') is_method = self.is_class_body self.is_class_body = False @@ -754,7 +753,7 @@ def parse_tuple_arg(self, index: int) -> Tuple[Argument, AssignmentStmt, List[st arg_names = self.find_tuple_arg_argument_names(paren_arg) return Argument(var, None, initializer, kind), decompose, arg_names - def verify_tuple_arg(self, paren_arg: Node) -> None: + def verify_tuple_arg(self, paren_arg: Expression) -> None: if isinstance(paren_arg, TupleExpr): if not paren_arg.items: self.fail('Empty tuple not valid as an argument', paren_arg.line, paren_arg.column) @@ -763,7 +762,7 @@ def verify_tuple_arg(self, paren_arg: Node) -> None: elif not isinstance(paren_arg, NameExpr): self.fail('Invalid item in tuple argument', paren_arg.line, paren_arg.column) - def find_tuple_arg_argument_names(self, node: Node) -> List[str]: + def find_tuple_arg_argument_names(self, node: Expression) -> List[str]: result = [] # type: List[str] if isinstance(node, TupleExpr): for item in node.items: @@ -784,7 +783,7 @@ def parse_normal_arg(self, require_named: bool, else: type = self.parse_arg_type(allow_signature) - initializer = None # type: Node + initializer = None # type: Expression if self.current_str() == '=': self.expect('=') initializer = self.parse_expression(precedence[',']) @@ -800,7 +799,7 @@ def parse_normal_arg(self, require_named: bool, return Argument(variable, type, initializer, kind), require_named - def set_type_optional(self, type: Type, initializer: Node) -> None: + def set_type_optional(self, type: Type, initializer: Expression) -> None: if not experiments.STRICT_OPTIONAL: return # Indicate that type should be wrapped in an Optional if arg is initialized to None. @@ -808,7 +807,7 @@ def set_type_optional(self, type: Type, initializer: Node) -> None: if isinstance(type, UnboundType): type.optional = optional - def parse_parameter_annotation(self) -> Node: + def parse_parameter_annotation(self) -> Expression: if self.current_str() == ':': self.skip() return self.parse_expression(precedence[',']) @@ -872,7 +871,7 @@ def parse_block(self, allow_type: bool = False) -> Tuple[Block, Type]: brk = self.expect_break() type = self.parse_type_comment(brk, signature=True) self.expect_indent() - stmt_list = [] # type: List[Node] + stmt_list = [] # type: List[Statement] while (not isinstance(self.current(), Dedent) and not isinstance(self.current(), Eof)): try: @@ -890,7 +889,7 @@ def parse_block(self, allow_type: bool = False) -> Tuple[Block, Type]: node.set_line(colon) return node, type - def try_combine_overloads(self, s: Node, stmt: List[Node]) -> bool: + def try_combine_overloads(self, s: Statement, stmt: List[Statement]) -> bool: if isinstance(s, Decorator) and stmt: fdef = s n = fdef.func.name() @@ -902,8 +901,8 @@ def try_combine_overloads(self, s: Node, stmt: List[Node]) -> bool: return True return False - def parse_statement(self) -> Tuple[Node, bool]: - stmt = None # type: Node + def parse_statement(self) -> Tuple[Statement, bool]: + stmt = None # type: Statement t = self.current() ts = self.current_str() is_simple = True # Is this a non-block statement? @@ -968,7 +967,9 @@ def parse_statement(self) -> Tuple[Node, bool]: stmt.set_line(t) return stmt, is_simple - def parse_expression_or_assignment(self) -> Node: + def parse_expression_or_assignment(self) -> Union[AssignmentStmt, + OperatorAssignmentStmt, + ExpressionStmt]: expr = self.parse_expression(star_expr_allowed=True) if self.current_str() == '=': return self.parse_assignment(expr) @@ -982,7 +983,7 @@ def parse_expression_or_assignment(self) -> Node: # Expression statement. return ExpressionStmt(expr) - def parse_assignment(self, lvalue: Any) -> Node: + def parse_assignment(self, lvalue: Expression) -> AssignmentStmt: """Parse an assignment statement. Assume that lvalue has been parsed already, and the current token is '='. @@ -1132,7 +1133,7 @@ def parse_for_stmt(self) -> ForStmt: node = ForStmt(index, expr, body, else_body) return node - def parse_for_index_variables(self) -> Node: + def parse_for_index_variables(self) -> Expression: # Parse index variables of a 'for' statement. index_items = [] force_tuple = False @@ -1188,12 +1189,12 @@ def parse_if_stmt(self) -> IfStmt: else: return None - def parse_try_stmt(self) -> Node: + def parse_try_stmt(self) -> TryStmt: self.expect('try') body, _ = self.parse_block() is_error = False vars = [] # type: List[NameExpr] - types = [] # type: List[Node] + types = [] # type: List[Optional[Expression]] handlers = [] # type: List[Block] while self.current_str() == 'except': self.expect('except') @@ -1293,9 +1294,9 @@ def parse_exec_stmt(self) -> ExecStmt: # Parsing expressions - def parse_expression(self, prec: int = 0, star_expr_allowed: bool = False) -> Node: + def parse_expression(self, prec: int = 0, star_expr_allowed: bool = False) -> Expression: """Parse a subexpression within a specific precedence context.""" - expr = None # type: Node + expr = None # type: Expression current = self.current() # Remember token for setting the line number. # Parse a "value" expression or unary operator expression and store @@ -1415,18 +1416,18 @@ def parse_expression(self, prec: int = 0, star_expr_allowed: bool = False) -> No return expr - def parse_parentheses(self) -> Node: + def parse_parentheses(self) -> Expression: self.skip() if self.current_str() == ')': # Empty tuple (). - expr = self.parse_empty_tuple_expr() # type: Node + expr = self.parse_empty_tuple_expr() # type: Expression else: # Parenthesised expression. expr = self.parse_expression(0, star_expr_allowed=True) self.expect(')') return expr - def parse_star_expr(self) -> Node: + def parse_star_expr(self) -> StarExpr: star = self.expect('*') expr = self.parse_expression(precedence['*u']) expr = StarExpr(expr) @@ -1439,7 +1440,7 @@ def parse_empty_tuple_expr(self) -> TupleExpr: node = TupleExpr([]) return node - def parse_list_expr(self) -> Node: + def parse_list_expr(self) -> Union[ListExpr, ListComprehension]: """Parse list literal or list comprehension.""" items = [] self.expect('[') @@ -1457,7 +1458,7 @@ def parse_list_expr(self) -> Node: expr = ListExpr(items) return expr - def parse_generator_expr(self, left_expr: Node) -> GeneratorExpr: + def parse_generator_expr(self, left_expr: Expression) -> GeneratorExpr: tok = self.current() indices, sequences, condlists = self.parse_comp_for() @@ -1465,10 +1466,10 @@ def parse_generator_expr(self, left_expr: Node) -> GeneratorExpr: gen.set_line(tok) return gen - def parse_comp_for(self) -> Tuple[List[Node], List[Node], List[List[Node]]]: + def parse_comp_for(self) -> Tuple[List[Expression], List[Expression], List[List[Expression]]]: indices = [] sequences = [] - condlists = [] # type: List[List[Node]] + condlists = [] # type: List[List[Expression]] while self.current_str() == 'for': conds = [] self.expect('for') @@ -1487,7 +1488,7 @@ def parse_comp_for(self) -> Tuple[List[Node], List[Node], List[List[Node]]]: return indices, sequences, condlists - def parse_expression_list(self) -> Node: + def parse_expression_list(self) -> Expression: prec = precedence[''] expr = self.parse_expression(prec) if self.current_str() != ',': @@ -1498,15 +1499,16 @@ def parse_expression_list(self) -> Node: tuple_expr.set_line(t) return tuple_expr - def parse_conditional_expr(self, left_expr: Node) -> ConditionalExpr: + def parse_conditional_expr(self, left_expr: Expression) -> ConditionalExpr: self.expect('if') cond = self.parse_expression(precedence['']) self.expect('else') else_expr = self.parse_expression(precedence['']) return ConditionalExpr(cond, left_expr, else_expr) - def parse_dict_or_set_expr(self) -> Node: - items = [] # type: List[Tuple[Node, Node]] + def parse_dict_or_set_expr(self) -> Union[SetComprehension, SetExpr, + DictionaryComprehension, DictExpr]: + items = [] # type: List[Tuple[Expression, Expression]] self.expect('{') while self.current_str() != '}' and not self.eol(): key = self.parse_expression(precedence['']) @@ -1528,7 +1530,7 @@ def parse_dict_or_set_expr(self) -> Node: node = DictExpr(items) return node - def parse_set_expr(self, first: Node) -> SetExpr: + def parse_set_expr(self, first: Expression) -> SetExpr: items = [first] while self.current_str() != '}' and not self.eol(): self.expect(',') @@ -1539,13 +1541,13 @@ def parse_set_expr(self, first: Node) -> SetExpr: expr = SetExpr(items) return expr - def parse_set_comprehension(self, expr: Node) -> SetComprehension: + def parse_set_comprehension(self, expr: Expression) -> SetComprehension: gen = self.parse_generator_expr(expr) self.expect('}') set_comp = SetComprehension(gen) return set_comp - def parse_dict_comprehension(self, key: Node, value: Node, + def parse_dict_comprehension(self, key: Expression, value: Expression, colon: Token) -> DictionaryComprehension: indices, sequences, condlists = self.parse_comp_for() dic = DictionaryComprehension(key, value, indices, sequences, condlists) @@ -1553,7 +1555,7 @@ def parse_dict_comprehension(self, key: Node, value: Node, self.expect('}') return dic - def parse_tuple_expr(self, expr: Node, + def parse_tuple_expr(self, expr: Expression, prec: int = precedence[',']) -> TupleExpr: items = [expr] while True: @@ -1590,7 +1592,7 @@ def parse_int_expr(self) -> IntExpr: node = IntExpr(value) return node - def parse_str_expr(self) -> Node: + def parse_str_expr(self) -> Union[UnicodeExpr, StrExpr]: # XXX \uxxxx literals token = self.expect_type(StrLit) value = cast(StrLit, token).parsed() @@ -1603,12 +1605,11 @@ def parse_str_expr(self) -> Node: value += token.parsed() is_unicode = True if is_unicode or (self.pyversion[0] == 2 and 'unicode_literals' in self.future_options): - node = UnicodeExpr(value) # type: Node + return UnicodeExpr(value) else: - node = StrExpr(value) - return node + return StrExpr(value) - def parse_bytes_literal(self) -> Node: + def parse_bytes_literal(self) -> Union[BytesExpr, StrExpr]: # XXX \uxxxx literals tok = [self.expect_type(BytesLit)] value = (cast(BytesLit, tok[0])).parsed() @@ -1616,12 +1617,11 @@ def parse_bytes_literal(self) -> Node: t = cast(BytesLit, self.skip()) value += t.parsed() if self.pyversion[0] >= 3: - node = BytesExpr(value) # type: Node + return BytesExpr(value) else: - node = StrExpr(value) - return node + return StrExpr(value) - def parse_unicode_literal(self) -> Node: + def parse_unicode_literal(self) -> Union[StrExpr, UnicodeExpr]: # XXX \uxxxx literals token = self.expect_type(UnicodeLit) value = cast(UnicodeLit, token).parsed() @@ -1630,29 +1630,25 @@ def parse_unicode_literal(self) -> Node: value += token.parsed() if self.pyversion[0] >= 3: # Python 3.3 supports u'...' as an alias of '...'. - node = StrExpr(value) # type: Node + return StrExpr(value) else: - node = UnicodeExpr(value) - return node + return UnicodeExpr(value) def parse_float_expr(self) -> FloatExpr: tok = self.expect_type(FloatLit) - node = FloatExpr(float(tok.string)) - return node + return FloatExpr(float(tok.string)) def parse_complex_expr(self) -> ComplexExpr: tok = self.expect_type(ComplexLit) - node = ComplexExpr(complex(tok.string)) - return node + return ComplexExpr(complex(tok.string)) - def parse_call_expr(self, callee: Any) -> CallExpr: + def parse_call_expr(self, callee: Expression) -> CallExpr: self.expect('(') args, kinds, names = self.parse_arg_expr() self.expect(')') - node = CallExpr(callee, args, kinds, names) - return node + return CallExpr(callee, args, kinds, names) - def parse_arg_expr(self) -> Tuple[List[Node], List[int], List[str]]: + def parse_arg_expr(self) -> Tuple[List[Expression], List[int], List[str]]: """Parse arguments in a call expression (within '(' and ')'). Return a tuple with these items: @@ -1660,7 +1656,7 @@ def parse_arg_expr(self) -> Tuple[List[Node], List[int], List[str]]: argument kinds argument names (for named arguments; None for ordinary args) """ - args = [] # type: List[Node] + args = [] # type: List[Expression] kinds = [] # type: List[int] names = [] # type: List[str] var_arg = False @@ -1698,18 +1694,17 @@ def parse_arg_expr(self) -> Tuple[List[Node], List[int], List[str]]: self.expect(',') return args, kinds, names - def parse_member_expr(self, expr: Any) -> Node: + def parse_member_expr(self, expr: Expression) -> Union[SuperExpr, MemberExpr]: self.expect('.') name = self.expect_type(Name) if (isinstance(expr, CallExpr) and isinstance(expr.callee, NameExpr) and expr.callee.name == 'super'): # super() expression - node = SuperExpr(name.string) # type: Node + return SuperExpr(name.string) else: - node = MemberExpr(expr, name.string) - return node + return MemberExpr(expr, name.string) - def parse_index_expr(self, base: Any) -> IndexExpr: + def parse_index_expr(self, base: Expression) -> IndexExpr: self.expect('[') index = self.parse_slice_item() if self.current_str() == ',': @@ -1726,7 +1721,7 @@ def parse_index_expr(self, base: Any) -> IndexExpr: node = IndexExpr(base, index) return node - def parse_slice_item(self) -> Node: + def parse_slice_item(self) -> Expression: if self.current_str() != ':': if self.current_str() == '...': # Ellipsis is valid here even in Python 2 (but not elsewhere). @@ -1755,7 +1750,7 @@ def parse_slice_item(self) -> Node: item.set_line(colon) return item - def parse_bin_op_expr(self, left: Node, prec: int) -> OpExpr: + def parse_bin_op_expr(self, left: Expression, prec: int) -> OpExpr: op = self.expect_type(Op) op_str = op.string if op_str == '~': @@ -1765,7 +1760,7 @@ def parse_bin_op_expr(self, left: Node, prec: int) -> OpExpr: node = OpExpr(op_str, left, right) return node - def parse_comparison_expr(self, left: Node, prec: int) -> ComparisonExpr: + def parse_comparison_expr(self, left: Expression, prec: int) -> ComparisonExpr: operators_str = [] operands = [left] @@ -1824,7 +1819,7 @@ def parse_lambda_expr(self) -> FuncExpr: return_stmt = ReturnStmt(expr) return_stmt.set_line(lambda_tok) - nodes = [return_stmt] # type: List[Node] + nodes = [return_stmt] # type: List[Statement] # Potentially insert extra assignment statements to the beginning of the # body, used to decompose Python 2 tuple arguments. nodes[:0] = extra_stmts diff --git a/mypy/treetransform.py b/mypy/treetransform.py index 9c090f94588f..cf5d87e35247 100644 --- a/mypy/treetransform.py +++ b/mypy/treetransform.py @@ -16,12 +16,12 @@ UnicodeExpr, FloatExpr, CallExpr, SuperExpr, MemberExpr, IndexExpr, SliceExpr, OpExpr, UnaryExpr, FuncExpr, TypeApplication, PrintStmt, SymbolTable, RefExpr, TypeVarExpr, NewTypeExpr, PromoteExpr, - ComparisonExpr, TempNode, StarExpr, + ComparisonExpr, TempNode, StarExpr, Statement, Expression, YieldFromExpr, NamedTupleExpr, NonlocalDecl, SetComprehension, DictionaryComprehension, ComplexExpr, TypeAliasExpr, EllipsisExpr, YieldExpr, ExecStmt, Argument, BackquoteExpr, AwaitExpr, ) -from mypy.types import Type, FunctionLike, Instance +from mypy.types import Type, FunctionLike from mypy.traverser import TraverserVisitor from mypy.visitor import NodeVisitor @@ -55,9 +55,9 @@ def __init__(self) -> None: # transformed node). self.func_placeholder_map = {} # type: Dict[FuncDef, FuncDef] - def visit_mypy_file(self, node: MypyFile) -> Node: + def visit_mypy_file(self, node: MypyFile) -> MypyFile: # NOTE: The 'names' and 'imports' instance variables will be empty! - new = MypyFile(self.nodes(node.defs), [], node.is_bom, + new = MypyFile(self.statements(node.defs), [], node.is_bom, ignored_lines=set(node.ignored_lines)) new._name = node._name new._fullname = node._fullname @@ -65,13 +65,13 @@ def visit_mypy_file(self, node: MypyFile) -> Node: new.names = SymbolTable() return new - def visit_import(self, node: Import) -> Node: + def visit_import(self, node: Import) -> Import: return Import(node.ids[:]) - def visit_import_from(self, node: ImportFrom) -> Node: + def visit_import_from(self, node: ImportFrom) -> ImportFrom: return ImportFrom(node.id, node.relative, node.names[:]) - def visit_import_all(self, node: ImportAll) -> Node: + def visit_import_all(self, node: ImportAll) -> ImportAll: return ImportAll(node.id, node.relative) def copy_argument(self, argument: Argument) -> Argument: @@ -143,7 +143,7 @@ def visit_func_def(self, node: FuncDef) -> FuncDef: else: return new - def visit_func_expr(self, node: FuncExpr) -> Node: + def visit_func_expr(self, node: FuncExpr) -> FuncExpr: new = FuncExpr([self.copy_argument(arg) for arg in node.arguments], self.block(node.body), cast(FunctionLike, self.optional_type(node.type))) @@ -169,7 +169,7 @@ def duplicate_inits(self, result.append(None) return result - def visit_overloaded_func_def(self, node: OverloadedFuncDef) -> Node: + def visit_overloaded_func_def(self, node: OverloadedFuncDef) -> OverloadedFuncDef: items = [self.visit_decorator(decorator) for decorator in node.items] for newitem, olditem in zip(items, node.items): @@ -180,11 +180,11 @@ def visit_overloaded_func_def(self, node: OverloadedFuncDef) -> Node: new.info = node.info return new - def visit_class_def(self, node: ClassDef) -> Node: + def visit_class_def(self, node: ClassDef) -> ClassDef: new = ClassDef(node.name, self.block(node.defs), node.type_vars, - self.nodes(node.base_type_exprs), + self.expressions(node.base_type_exprs), node.metaclass) new.fullname = node.fullname new.info = node.info @@ -193,20 +193,20 @@ def visit_class_def(self, node: ClassDef) -> Node: new.is_builtinclass = node.is_builtinclass return new - def visit_global_decl(self, node: GlobalDecl) -> Node: + def visit_global_decl(self, node: GlobalDecl) -> GlobalDecl: return GlobalDecl(node.names[:]) - def visit_nonlocal_decl(self, node: NonlocalDecl) -> Node: + def visit_nonlocal_decl(self, node: NonlocalDecl) -> NonlocalDecl: return NonlocalDecl(node.names[:]) def visit_block(self, node: Block) -> Block: - return Block(self.nodes(node.body)) + return Block(self.statements(node.body)) def visit_decorator(self, node: Decorator) -> Decorator: # Note that a Decorator must be transformed to a Decorator. func = self.visit_func_def(node.func) func.line = node.func.line - new = Decorator(func, self.nodes(node.decorators), + new = Decorator(func, self.expressions(node.decorators), self.visit_var(node.var)) new.is_overload = node.is_overload return new @@ -229,111 +229,111 @@ def visit_var(self, node: Var) -> Var: self.var_map[node] = new return new - def visit_expression_stmt(self, node: ExpressionStmt) -> Node: + def visit_expression_stmt(self, node: ExpressionStmt) -> ExpressionStmt: return ExpressionStmt(self.node(node.expr)) - def visit_assignment_stmt(self, node: AssignmentStmt) -> Node: + def visit_assignment_stmt(self, node: AssignmentStmt) -> AssignmentStmt: return self.duplicate_assignment(node) def duplicate_assignment(self, node: AssignmentStmt) -> AssignmentStmt: - new = AssignmentStmt(self.nodes(node.lvalues), + new = AssignmentStmt(self.expressions(node.lvalues), self.node(node.rvalue), self.optional_type(node.type)) new.line = node.line return new def visit_operator_assignment_stmt(self, - node: OperatorAssignmentStmt) -> Node: + node: OperatorAssignmentStmt) -> OperatorAssignmentStmt: return OperatorAssignmentStmt(node.op, self.node(node.lvalue), self.node(node.rvalue)) - def visit_while_stmt(self, node: WhileStmt) -> Node: + def visit_while_stmt(self, node: WhileStmt) -> WhileStmt: return WhileStmt(self.node(node.expr), self.block(node.body), self.optional_block(node.else_body)) - def visit_for_stmt(self, node: ForStmt) -> Node: + def visit_for_stmt(self, node: ForStmt) -> ForStmt: return ForStmt(self.node(node.index), self.node(node.expr), self.block(node.body), self.optional_block(node.else_body)) - def visit_return_stmt(self, node: ReturnStmt) -> Node: + def visit_return_stmt(self, node: ReturnStmt) -> ReturnStmt: return ReturnStmt(self.optional_node(node.expr)) - def visit_assert_stmt(self, node: AssertStmt) -> Node: + def visit_assert_stmt(self, node: AssertStmt) -> AssertStmt: return AssertStmt(self.node(node.expr)) - def visit_del_stmt(self, node: DelStmt) -> Node: + def visit_del_stmt(self, node: DelStmt) -> DelStmt: return DelStmt(self.node(node.expr)) - def visit_if_stmt(self, node: IfStmt) -> Node: - return IfStmt(self.nodes(node.expr), + def visit_if_stmt(self, node: IfStmt) -> IfStmt: + return IfStmt(self.expressions(node.expr), self.blocks(node.body), self.optional_block(node.else_body)) - def visit_break_stmt(self, node: BreakStmt) -> Node: + def visit_break_stmt(self, node: BreakStmt) -> BreakStmt: return BreakStmt() - def visit_continue_stmt(self, node: ContinueStmt) -> Node: + def visit_continue_stmt(self, node: ContinueStmt) -> ContinueStmt: return ContinueStmt() - def visit_pass_stmt(self, node: PassStmt) -> Node: + def visit_pass_stmt(self, node: PassStmt) -> PassStmt: return PassStmt() - def visit_raise_stmt(self, node: RaiseStmt) -> Node: + def visit_raise_stmt(self, node: RaiseStmt) -> RaiseStmt: return RaiseStmt(self.optional_node(node.expr), self.optional_node(node.from_expr)) - def visit_try_stmt(self, node: TryStmt) -> Node: + def visit_try_stmt(self, node: TryStmt) -> TryStmt: return TryStmt(self.block(node.body), self.optional_names(node.vars), - self.optional_nodes(node.types), + self.optional_expressions(node.types), self.blocks(node.handlers), self.optional_block(node.else_body), self.optional_block(node.finally_body)) - def visit_with_stmt(self, node: WithStmt) -> Node: - return WithStmt(self.nodes(node.expr), - self.optional_nodes(node.target), + def visit_with_stmt(self, node: WithStmt) -> WithStmt: + return WithStmt(self.expressions(node.expr), + self.optional_expressions(node.target), self.block(node.body)) - def visit_print_stmt(self, node: PrintStmt) -> Node: - return PrintStmt(self.nodes(node.args), + def visit_print_stmt(self, node: PrintStmt) -> PrintStmt: + return PrintStmt(self.expressions(node.args), node.newline, self.optional_node(node.target)) - def visit_exec_stmt(self, node: ExecStmt) -> Node: + def visit_exec_stmt(self, node: ExecStmt) -> ExecStmt: return ExecStmt(self.node(node.expr), self.optional_node(node.variables1), self.optional_node(node.variables2)) - def visit_star_expr(self, node: StarExpr) -> Node: + def visit_star_expr(self, node: StarExpr) -> StarExpr: return StarExpr(node.expr) - def visit_int_expr(self, node: IntExpr) -> Node: + def visit_int_expr(self, node: IntExpr) -> IntExpr: return IntExpr(node.value) - def visit_str_expr(self, node: StrExpr) -> Node: + def visit_str_expr(self, node: StrExpr) -> StrExpr: return StrExpr(node.value) - def visit_bytes_expr(self, node: BytesExpr) -> Node: + def visit_bytes_expr(self, node: BytesExpr) -> BytesExpr: return BytesExpr(node.value) - def visit_unicode_expr(self, node: UnicodeExpr) -> Node: + def visit_unicode_expr(self, node: UnicodeExpr) -> UnicodeExpr: return UnicodeExpr(node.value) - def visit_float_expr(self, node: FloatExpr) -> Node: + def visit_float_expr(self, node: FloatExpr) -> FloatExpr: return FloatExpr(node.value) - def visit_complex_expr(self, node: ComplexExpr) -> Node: + def visit_complex_expr(self, node: ComplexExpr) -> ComplexExpr: return ComplexExpr(node.value) - def visit_ellipsis(self, node: EllipsisExpr) -> Node: + def visit_ellipsis(self, node: EllipsisExpr) -> EllipsisExpr: return EllipsisExpr() - def visit_name_expr(self, node: NameExpr) -> Node: + def visit_name_expr(self, node: NameExpr) -> NameExpr: return self.duplicate_name(node) def duplicate_name(self, node: NameExpr) -> NameExpr: @@ -343,7 +343,7 @@ def duplicate_name(self, node: NameExpr) -> NameExpr: self.copy_ref(new, node) return new - def visit_member_expr(self, node: MemberExpr) -> Node: + def visit_member_expr(self, node: MemberExpr) -> MemberExpr: member = MemberExpr(self.node(node.expr), node.name) if node.def_var: @@ -363,63 +363,63 @@ def copy_ref(self, new: RefExpr, original: RefExpr) -> None: new.node = target new.is_def = original.is_def - def visit_yield_from_expr(self, node: YieldFromExpr) -> Node: + def visit_yield_from_expr(self, node: YieldFromExpr) -> YieldFromExpr: return YieldFromExpr(self.node(node.expr)) - def visit_yield_expr(self, node: YieldExpr) -> Node: + def visit_yield_expr(self, node: YieldExpr) -> YieldExpr: return YieldExpr(self.node(node.expr)) - def visit_await_expr(self, node: AwaitExpr) -> Node: + def visit_await_expr(self, node: AwaitExpr) -> AwaitExpr: return AwaitExpr(self.node(node.expr)) - def visit_call_expr(self, node: CallExpr) -> Node: + def visit_call_expr(self, node: CallExpr) -> CallExpr: return CallExpr(self.node(node.callee), - self.nodes(node.args), + self.expressions(node.args), node.arg_kinds[:], node.arg_names[:], self.optional_node(node.analyzed)) - def visit_op_expr(self, node: OpExpr) -> Node: + def visit_op_expr(self, node: OpExpr) -> OpExpr: new = OpExpr(node.op, self.node(node.left), self.node(node.right)) new.method_type = self.optional_type(node.method_type) return new - def visit_comparison_expr(self, node: ComparisonExpr) -> Node: - new = ComparisonExpr(node.operators, self.nodes(node.operands)) + def visit_comparison_expr(self, node: ComparisonExpr) -> ComparisonExpr: + new = ComparisonExpr(node.operators, self.expressions(node.operands)) new.method_types = [self.optional_type(t) for t in node.method_types] return new - def visit_cast_expr(self, node: CastExpr) -> Node: + def visit_cast_expr(self, node: CastExpr) -> CastExpr: return CastExpr(self.node(node.expr), self.type(node.type)) - def visit_reveal_type_expr(self, node: RevealTypeExpr) -> Node: + def visit_reveal_type_expr(self, node: RevealTypeExpr) -> RevealTypeExpr: return RevealTypeExpr(self.node(node.expr)) - def visit_super_expr(self, node: SuperExpr) -> Node: + def visit_super_expr(self, node: SuperExpr) -> SuperExpr: new = SuperExpr(node.name) new.info = node.info return new - def visit_unary_expr(self, node: UnaryExpr) -> Node: + def visit_unary_expr(self, node: UnaryExpr) -> UnaryExpr: new = UnaryExpr(node.op, self.node(node.expr)) new.method_type = self.optional_type(node.method_type) return new - def visit_list_expr(self, node: ListExpr) -> Node: - return ListExpr(self.nodes(node.items)) + def visit_list_expr(self, node: ListExpr) -> ListExpr: + return ListExpr(self.expressions(node.items)) - def visit_dict_expr(self, node: DictExpr) -> Node: + def visit_dict_expr(self, node: DictExpr) -> DictExpr: return DictExpr([(self.node(key), self.node(value)) for key, value in node.items]) - def visit_tuple_expr(self, node: TupleExpr) -> Node: - return TupleExpr(self.nodes(node.items)) + def visit_tuple_expr(self, node: TupleExpr) -> TupleExpr: + return TupleExpr(self.expressions(node.items)) - def visit_set_expr(self, node: SetExpr) -> Node: - return SetExpr(self.nodes(node.items)) + def visit_set_expr(self, node: SetExpr) -> SetExpr: + return SetExpr(self.expressions(node.items)) - def visit_index_expr(self, node: IndexExpr) -> Node: + def visit_index_expr(self, node: IndexExpr) -> IndexExpr: new = IndexExpr(self.node(node.base), self.node(node.index)) if node.method_type: new.method_type = self.type(node.method_type) @@ -435,24 +435,24 @@ def visit_type_application(self, node: TypeApplication) -> TypeApplication: return TypeApplication(self.node(node.expr), self.types(node.types)) - def visit_list_comprehension(self, node: ListComprehension) -> Node: + def visit_list_comprehension(self, node: ListComprehension) -> ListComprehension: generator = self.duplicate_generator(node.generator) generator.set_line(node.generator.line) return ListComprehension(generator) - def visit_set_comprehension(self, node: SetComprehension) -> Node: + def visit_set_comprehension(self, node: SetComprehension) -> SetComprehension: generator = self.duplicate_generator(node.generator) generator.set_line(node.generator.line) return SetComprehension(generator) - def visit_dictionary_comprehension(self, node: DictionaryComprehension) -> Node: + def visit_dictionary_comprehension(self, node: DictionaryComprehension) -> DictionaryComprehension: return DictionaryComprehension(self.node(node.key), self.node(node.value), [self.node(index) for index in node.indices], [self.node(s) for s in node.sequences], [[self.node(cond) for cond in conditions] for conditions in node.condlists]) - def visit_generator_expr(self, node: GeneratorExpr) -> Node: + def visit_generator_expr(self, node: GeneratorExpr) -> GeneratorExpr: return self.duplicate_generator(node) def duplicate_generator(self, node: GeneratorExpr) -> GeneratorExpr: @@ -462,20 +462,20 @@ def duplicate_generator(self, node: GeneratorExpr) -> GeneratorExpr: [[self.node(cond) for cond in conditions] for conditions in node.condlists]) - def visit_slice_expr(self, node: SliceExpr) -> Node: + def visit_slice_expr(self, node: SliceExpr) -> SliceExpr: return SliceExpr(self.optional_node(node.begin_index), self.optional_node(node.end_index), self.optional_node(node.stride)) - def visit_conditional_expr(self, node: ConditionalExpr) -> Node: + def visit_conditional_expr(self, node: ConditionalExpr) -> ConditionalExpr: return ConditionalExpr(self.node(node.cond), self.node(node.if_expr), self.node(node.else_expr)) - def visit_backquote_expr(self, node: BackquoteExpr) -> Node: + def visit_backquote_expr(self, node: BackquoteExpr) -> BackquoteExpr: return BackquoteExpr(self.node(node.expr)) - def visit_type_var_expr(self, node: TypeVarExpr) -> Node: + def visit_type_var_expr(self, node: TypeVarExpr) -> TypeVarExpr: return TypeVarExpr(node.name(), node.fullname(), self.types(node.values), self.type(node.upper_bound), variance=node.variance) @@ -488,13 +488,13 @@ def visit_newtype_expr(self, node: NewTypeExpr) -> NewTypeExpr: res.info = node.info return res - def visit_namedtuple_expr(self, node: NamedTupleExpr) -> Node: + def visit_namedtuple_expr(self, node: NamedTupleExpr) -> NamedTupleExpr: return NamedTupleExpr(node.info) - def visit__promote_expr(self, node: PromoteExpr) -> Node: + def visit__promote_expr(self, node: PromoteExpr) -> PromoteExpr: return PromoteExpr(node.type) - def visit_temp_node(self, node: TempNode) -> Node: + def visit_temp_node(self, node: TempNode) -> TempNode: return TempNode(self.type(node.type)) def node(self, node: Node) -> Node: @@ -523,11 +523,29 @@ def optional_block(self, block: Block) -> Block: else: return None - def nodes(self, nodes: List[Node]) -> List[Node]: - return [self.node(node) for node in nodes] + def statements(self, statements: List[Statement]) -> List[Statement]: + res = [] + for node in statements: + stmt = self.node(node) + assert isinstance(stmt, Statement) + res.append(stmt) + return res + + def expressions(self, expressions: List[Expression]) -> List[Expression]: + res = [] + for node in expressions: + expr = self.node(node) + assert isinstance(expr, Expression) + res.append(expr) + return res - def optional_nodes(self, nodes: List[Node]) -> List[Node]: - return [self.optional_node(node) for node in nodes] + def optional_expressions(self, expressions: List[Expression]) -> List[Expression]: + res = [] + for node in expressions: + expr = self.optional_node(node) + assert expr is None or isinstance(expr, Expression) + res.append(expr) + return res def blocks(self, blocks: List[Block]) -> List[Block]: return [self.block(block) for block in blocks] From 71306aaa2154d03fd2411148fa49f7d8f9850e80 Mon Sep 17 00:00:00 2001 From: Elazar Gershuni Date: Sun, 2 Oct 2016 16:05:29 +0300 Subject: [PATCH 2/3] seperate visitor calls by types --- mypy/fastparse.py | 10 +-- mypy/fastparse2.py | 4 +- mypy/test/testtransform.py | 2 +- mypy/treetransform.py | 140 +++++++++++++++++++------------------ 4 files changed, 81 insertions(+), 75 deletions(-) diff --git a/mypy/fastparse.py b/mypy/fastparse.py index 218db15a25eb..b3e9232591b2 100644 --- a/mypy/fastparse.py +++ b/mypy/fastparse.py @@ -141,7 +141,6 @@ def translate_stmt_list(self, l: Sequence[ast35.AST]) -> List[Statement]: res.append(stmt) return res - op_map = { ast35.Add: '+', ast35.Sub: '-', @@ -282,8 +281,10 @@ def do_func_def(self, n: Union[ast35.FunctionDef, ast35.AsyncFunctionDef], arg_types = [a.type_annotation if a.type_annotation is not None else AnyType() for a in args] else: - arg_types = [a if a is not None else AnyType() for - a in TypeConverter(line=n.lineno).translate_expr_list(func_type_ast.argtypes)] + translated_args = (TypeConverter(line=n.lineno) + .translate_expr_list(func_type_ast.argtypes)) + arg_types = [a if a is not None else AnyType() + for a in translated_args] return_type = TypeConverter(line=n.lineno).visit(func_type_ast.returns) # add implicit self type @@ -654,7 +655,8 @@ def visit_IfExp(self, n: ast35.IfExp) -> ConditionalExpr: # Dict(expr* keys, expr* values) @with_line def visit_Dict(self, n: ast35.Dict) -> DictExpr: - return DictExpr(list(zip(self.translate_expr_list(n.keys), self.translate_expr_list(n.values)))) + return DictExpr(list(zip(self.translate_expr_list(n.keys), + self.translate_expr_list(n.values)))) # Set(expr* elts) @with_line diff --git a/mypy/fastparse2.py b/mypy/fastparse2.py index 93e7c37d6580..01a1dffe21a7 100644 --- a/mypy/fastparse2.py +++ b/mypy/fastparse2.py @@ -158,7 +158,6 @@ def translate_stmt_list(self, l: Sequence[ast27.AST]) -> List[Statement]: res.append(stmt) return res - op_map = { ast27.Add: '+', ast27.Sub: '-', @@ -708,7 +707,8 @@ def visit_IfExp(self, n: ast27.IfExp) -> ConditionalExpr: # Dict(expr* keys, expr* values) @with_line def visit_Dict(self, n: ast27.Dict) -> DictExpr: - return DictExpr(list(zip(self.translate_expr_list(n.keys), self.translate_expr_list(n.values)))) + return DictExpr(list(zip(self.translate_expr_list(n.keys), + self.translate_expr_list(n.values)))) # Set(expr* elts) @with_line diff --git a/mypy/test/testtransform.py b/mypy/test/testtransform.py index bceee5c15b7e..d96af94ee729 100644 --- a/mypy/test/testtransform.py +++ b/mypy/test/testtransform.py @@ -67,7 +67,7 @@ def test_transform(testcase): and not os.path.splitext( os.path.basename(f.path))[0].endswith('_')): t = TestTransformVisitor() - f = t.node(f) + f = t.mypyfile(f) a += str(f).split('\n') except CompileError as e: a = e.messages diff --git a/mypy/treetransform.py b/mypy/treetransform.py index cf5d87e35247..1681a40ce2c2 100644 --- a/mypy/treetransform.py +++ b/mypy/treetransform.py @@ -80,12 +80,12 @@ def copy_argument(self, argument: Argument) -> Argument: if argument.initialization_statement: init_lvalue = cast( NameExpr, - self.node(argument.initialization_statement.lvalues[0]), + self.expr(argument.initialization_statement.lvalues[0]), ) init_lvalue.set_line(argument.line) init_stmt = AssignmentStmt( [init_lvalue], - self.node(argument.initialization_statement.rvalue), + self.expr(argument.initialization_statement.rvalue), self.optional_type(argument.initialization_statement.type), ) @@ -230,14 +230,14 @@ def visit_var(self, node: Var) -> Var: return new def visit_expression_stmt(self, node: ExpressionStmt) -> ExpressionStmt: - return ExpressionStmt(self.node(node.expr)) + return ExpressionStmt(self.expr(node.expr)) def visit_assignment_stmt(self, node: AssignmentStmt) -> AssignmentStmt: return self.duplicate_assignment(node) def duplicate_assignment(self, node: AssignmentStmt) -> AssignmentStmt: new = AssignmentStmt(self.expressions(node.lvalues), - self.node(node.rvalue), + self.expr(node.rvalue), self.optional_type(node.type)) new.line = node.line return new @@ -245,28 +245,28 @@ def duplicate_assignment(self, node: AssignmentStmt) -> AssignmentStmt: def visit_operator_assignment_stmt(self, node: OperatorAssignmentStmt) -> OperatorAssignmentStmt: return OperatorAssignmentStmt(node.op, - self.node(node.lvalue), - self.node(node.rvalue)) + self.expr(node.lvalue), + self.expr(node.rvalue)) def visit_while_stmt(self, node: WhileStmt) -> WhileStmt: - return WhileStmt(self.node(node.expr), + return WhileStmt(self.expr(node.expr), self.block(node.body), self.optional_block(node.else_body)) def visit_for_stmt(self, node: ForStmt) -> ForStmt: - return ForStmt(self.node(node.index), - self.node(node.expr), + return ForStmt(self.expr(node.index), + self.expr(node.expr), self.block(node.body), self.optional_block(node.else_body)) def visit_return_stmt(self, node: ReturnStmt) -> ReturnStmt: - return ReturnStmt(self.optional_node(node.expr)) + return ReturnStmt(self.optional_expr(node.expr)) def visit_assert_stmt(self, node: AssertStmt) -> AssertStmt: - return AssertStmt(self.node(node.expr)) + return AssertStmt(self.expr(node.expr)) def visit_del_stmt(self, node: DelStmt) -> DelStmt: - return DelStmt(self.node(node.expr)) + return DelStmt(self.expr(node.expr)) def visit_if_stmt(self, node: IfStmt) -> IfStmt: return IfStmt(self.expressions(node.expr), @@ -283,8 +283,8 @@ def visit_pass_stmt(self, node: PassStmt) -> PassStmt: return PassStmt() def visit_raise_stmt(self, node: RaiseStmt) -> RaiseStmt: - return RaiseStmt(self.optional_node(node.expr), - self.optional_node(node.from_expr)) + return RaiseStmt(self.optional_expr(node.expr), + self.optional_expr(node.from_expr)) def visit_try_stmt(self, node: TryStmt) -> TryStmt: return TryStmt(self.block(node.body), @@ -302,12 +302,12 @@ def visit_with_stmt(self, node: WithStmt) -> WithStmt: def visit_print_stmt(self, node: PrintStmt) -> PrintStmt: return PrintStmt(self.expressions(node.args), node.newline, - self.optional_node(node.target)) + self.optional_expr(node.target)) def visit_exec_stmt(self, node: ExecStmt) -> ExecStmt: - return ExecStmt(self.node(node.expr), - self.optional_node(node.variables1), - self.optional_node(node.variables2)) + return ExecStmt(self.expr(node.expr), + self.optional_expr(node.variables1), + self.optional_expr(node.variables2)) def visit_star_expr(self, node: StarExpr) -> StarExpr: return StarExpr(node.expr) @@ -344,7 +344,7 @@ def duplicate_name(self, node: NameExpr) -> NameExpr: return new def visit_member_expr(self, node: MemberExpr) -> MemberExpr: - member = MemberExpr(self.node(node.expr), + member = MemberExpr(self.expr(node.expr), node.name) if node.def_var: member.def_var = self.visit_var(node.def_var) @@ -364,23 +364,23 @@ def copy_ref(self, new: RefExpr, original: RefExpr) -> None: new.is_def = original.is_def def visit_yield_from_expr(self, node: YieldFromExpr) -> YieldFromExpr: - return YieldFromExpr(self.node(node.expr)) + return YieldFromExpr(self.expr(node.expr)) def visit_yield_expr(self, node: YieldExpr) -> YieldExpr: - return YieldExpr(self.node(node.expr)) + return YieldExpr(self.expr(node.expr)) def visit_await_expr(self, node: AwaitExpr) -> AwaitExpr: - return AwaitExpr(self.node(node.expr)) + return AwaitExpr(self.expr(node.expr)) def visit_call_expr(self, node: CallExpr) -> CallExpr: - return CallExpr(self.node(node.callee), + return CallExpr(self.expr(node.callee), self.expressions(node.args), node.arg_kinds[:], node.arg_names[:], - self.optional_node(node.analyzed)) + self.optional_expr(node.analyzed)) def visit_op_expr(self, node: OpExpr) -> OpExpr: - new = OpExpr(node.op, self.node(node.left), self.node(node.right)) + new = OpExpr(node.op, self.expr(node.left), self.expr(node.right)) new.method_type = self.optional_type(node.method_type) return new @@ -390,11 +390,11 @@ def visit_comparison_expr(self, node: ComparisonExpr) -> ComparisonExpr: return new def visit_cast_expr(self, node: CastExpr) -> CastExpr: - return CastExpr(self.node(node.expr), + return CastExpr(self.expr(node.expr), self.type(node.type)) def visit_reveal_type_expr(self, node: RevealTypeExpr) -> RevealTypeExpr: - return RevealTypeExpr(self.node(node.expr)) + return RevealTypeExpr(self.expr(node.expr)) def visit_super_expr(self, node: SuperExpr) -> SuperExpr: new = SuperExpr(node.name) @@ -402,7 +402,7 @@ def visit_super_expr(self, node: SuperExpr) -> SuperExpr: return new def visit_unary_expr(self, node: UnaryExpr) -> UnaryExpr: - new = UnaryExpr(node.op, self.node(node.expr)) + new = UnaryExpr(node.op, self.expr(node.expr)) new.method_type = self.optional_type(node.method_type) return new @@ -410,7 +410,7 @@ def visit_list_expr(self, node: ListExpr) -> ListExpr: return ListExpr(self.expressions(node.items)) def visit_dict_expr(self, node: DictExpr) -> DictExpr: - return DictExpr([(self.node(key), self.node(value)) + return DictExpr([(self.expr(key), self.expr(value)) for key, value in node.items]) def visit_tuple_expr(self, node: TupleExpr) -> TupleExpr: @@ -420,7 +420,7 @@ def visit_set_expr(self, node: SetExpr) -> SetExpr: return SetExpr(self.expressions(node.items)) def visit_index_expr(self, node: IndexExpr) -> IndexExpr: - new = IndexExpr(self.node(node.base), self.node(node.index)) + new = IndexExpr(self.expr(node.base), self.expr(node.index)) if node.method_type: new.method_type = self.type(node.method_type) if node.analyzed: @@ -432,7 +432,7 @@ def visit_index_expr(self, node: IndexExpr) -> IndexExpr: return new def visit_type_application(self, node: TypeApplication) -> TypeApplication: - return TypeApplication(self.node(node.expr), + return TypeApplication(self.expr(node.expr), self.types(node.types)) def visit_list_comprehension(self, node: ListComprehension) -> ListComprehension: @@ -445,35 +445,36 @@ def visit_set_comprehension(self, node: SetComprehension) -> SetComprehension: generator.set_line(node.generator.line) return SetComprehension(generator) - def visit_dictionary_comprehension(self, node: DictionaryComprehension) -> DictionaryComprehension: - return DictionaryComprehension(self.node(node.key), self.node(node.value), - [self.node(index) for index in node.indices], - [self.node(s) for s in node.sequences], - [[self.node(cond) for cond in conditions] + def visit_dictionary_comprehension(self, node: DictionaryComprehension + ) -> DictionaryComprehension: + return DictionaryComprehension(self.expr(node.key), self.expr(node.value), + [self.expr(index) for index in node.indices], + [self.expr(s) for s in node.sequences], + [[self.expr(cond) for cond in conditions] for conditions in node.condlists]) def visit_generator_expr(self, node: GeneratorExpr) -> GeneratorExpr: return self.duplicate_generator(node) def duplicate_generator(self, node: GeneratorExpr) -> GeneratorExpr: - return GeneratorExpr(self.node(node.left_expr), - [self.node(index) for index in node.indices], - [self.node(s) for s in node.sequences], - [[self.node(cond) for cond in conditions] + return GeneratorExpr(self.expr(node.left_expr), + [self.expr(index) for index in node.indices], + [self.expr(s) for s in node.sequences], + [[self.expr(cond) for cond in conditions] for conditions in node.condlists]) def visit_slice_expr(self, node: SliceExpr) -> SliceExpr: - return SliceExpr(self.optional_node(node.begin_index), - self.optional_node(node.end_index), - self.optional_node(node.stride)) + return SliceExpr(self.optional_expr(node.begin_index), + self.optional_expr(node.end_index), + self.optional_expr(node.stride)) def visit_conditional_expr(self, node: ConditionalExpr) -> ConditionalExpr: - return ConditionalExpr(self.node(node.cond), - self.node(node.if_expr), - self.node(node.else_expr)) + return ConditionalExpr(self.expr(node.cond), + self.expr(node.if_expr), + self.expr(node.else_expr)) def visit_backquote_expr(self, node: BackquoteExpr) -> BackquoteExpr: - return BackquoteExpr(self.node(node.expr)) + return BackquoteExpr(self.expr(node.expr)) def visit_type_var_expr(self, node: TypeVarExpr) -> TypeVarExpr: return TypeVarExpr(node.name(), node.fullname(), @@ -502,13 +503,31 @@ def node(self, node: Node) -> Node: new.set_line(node.line) return new + def mypyfile(self, node: MypyFile) -> MypyFile: + new = node.accept(self) + assert isinstance(new, MypyFile) + new.set_line(node.line) + return new + + def expr(self, expr: Expression) -> Expression: + new = expr.accept(self) + assert isinstance(new, Expression) + new.set_line(expr.line) + return new + + def stmt(self, stmt: Statement) -> Statement: + new = stmt.accept(self) + assert isinstance(new, Statement) + new.set_line(stmt.line) + return new + # Helpers # # All the node helpers also propagate line numbers. - def optional_node(self, node: Node) -> Node: - if node: - return self.node(node) + def optional_expr(self, expr: Expression) -> Expression: + if expr: + return self.expr(expr) else: return None @@ -524,28 +543,13 @@ def optional_block(self, block: Block) -> Block: return None def statements(self, statements: List[Statement]) -> List[Statement]: - res = [] - for node in statements: - stmt = self.node(node) - assert isinstance(stmt, Statement) - res.append(stmt) - return res + return [self.stmt(stmt) for stmt in statements] def expressions(self, expressions: List[Expression]) -> List[Expression]: - res = [] - for node in expressions: - expr = self.node(node) - assert isinstance(expr, Expression) - res.append(expr) - return res + return [self.expr(expr) for expr in expressions] def optional_expressions(self, expressions: List[Expression]) -> List[Expression]: - res = [] - for node in expressions: - expr = self.optional_node(node) - assert expr is None or isinstance(expr, Expression) - res.append(expr) - return res + return [self.optional_expr(expr) for expr in expressions] def blocks(self, blocks: List[Block]) -> List[Block]: return [self.block(block) for block in blocks] From 4743e5ce04b2d8e9a2db7fe0d795504e6f36bc6f Mon Sep 17 00:00:00 2001 From: Elazar Gershuni Date: Sun, 2 Oct 2016 16:49:09 +0300 Subject: [PATCH 3/3] make Expression and Statment separable --- mypy/binder.py | 20 ++++++++++---------- mypy/checker.py | 2 +- mypy/exprtotype.py | 7 ++++--- mypy/nodes.py | 8 +++----- mypy/semanal.py | 4 ++-- mypy/stubgen.py | 16 ++++++++-------- mypy/treetransform.py | 2 +- mypy/typeanal.py | 4 ++-- 8 files changed, 31 insertions(+), 32 deletions(-) diff --git a/mypy/binder.py b/mypy/binder.py index 96e9cb30ada3..ba956ef20764 100644 --- a/mypy/binder.py +++ b/mypy/binder.py @@ -1,8 +1,8 @@ -from typing import (Any, Dict, List, Set, Iterator) +from typing import (Any, Dict, List, Set, Iterator, Union) from contextlib import contextmanager from mypy.types import Type, AnyType, PartialType -from mypy.nodes import (Expression, Var, RefExpr, SymbolTableNode) +from mypy.nodes import (Node, Expression, Var, RefExpr, SymbolTableNode) from mypy.subtypes import is_subtype from mypy.join import join_simple @@ -96,16 +96,16 @@ def _get(self, key: Key, index: int=-1) -> Type: return self.frames[i][key] return None - def push(self, expr: Expression, typ: Type) -> None: - if not expr.literal: + def push(self, node: Node, typ: Type) -> None: + if not node.literal: return - key = expr.literal_hash + key = node.literal_hash if key not in self.declarations: - self.declarations[key] = self.get_declaration(expr) + self.declarations[key] = self.get_declaration(node) self._add_dependencies(key) self._push(key, typ) - def get(self, expr: Expression) -> Type: + def get(self, expr: Union[Expression, Var]) -> Type: return self._get(expr.literal_hash) def cleanse(self, expr: Expression) -> None: @@ -165,9 +165,9 @@ def pop_frame(self, fall_through: int = 0) -> Frame: return result - def get_declaration(self, expr: Expression) -> Type: - if isinstance(expr, (RefExpr, SymbolTableNode)) and isinstance(expr.node, Var): - type = expr.node.type + def get_declaration(self, node: Node) -> Type: + if isinstance(node, (RefExpr, SymbolTableNode)) and isinstance(node.node, Var): + type = node.node.type if isinstance(type, PartialType): return None return type diff --git a/mypy/checker.py b/mypy/checker.py index 4474a9ba8ee5..c30c00dd00ab 100644 --- a/mypy/checker.py +++ b/mypy/checker.py @@ -2218,7 +2218,7 @@ def check_type_equivalency(self, t1: Type, t2: Type, node: Context, if not is_equivalent(t1, t2): self.fail(msg, node) - def store_type(self, node: Expression, typ: Type) -> None: + def store_type(self, node: Node, typ: Type) -> None: """Store the type of a node in the type map.""" self.type_map[node] = typ if typ is not None: diff --git a/mypy/exprtotype.py b/mypy/exprtotype.py index dc95f3d4ba02..764c716b1f96 100644 --- a/mypy/exprtotype.py +++ b/mypy/exprtotype.py @@ -1,7 +1,8 @@ -"""Translate an expression (Node) to a Type value.""" +"""Translate an Expression to a Type value.""" from mypy.nodes import ( - Node, NameExpr, MemberExpr, IndexExpr, TupleExpr, ListExpr, StrExpr, BytesExpr, EllipsisExpr + Expression, NameExpr, MemberExpr, IndexExpr, TupleExpr, + ListExpr, StrExpr, BytesExpr, EllipsisExpr ) from mypy.parsetype import parse_str_as_type, TypeParseError from mypy.types import Type, UnboundType, TypeList, EllipsisType @@ -11,7 +12,7 @@ class TypeTranslationError(Exception): """Exception raised when an expression is not valid as a type.""" -def expr_to_unanalyzed_type(expr: Node) -> Type: +def expr_to_unanalyzed_type(expr: Expression) -> Type: """Translate an expression to the corresponding type. The result is not semantically analyzed. It can be UnboundType or TypeList. diff --git a/mypy/nodes.py b/mypy/nodes.py index 9aaa7c355326..2df25510fc37 100644 --- a/mypy/nodes.py +++ b/mypy/nodes.py @@ -135,9 +135,7 @@ def accept(self, visitor: NodeVisitor[T]) -> T: # These are placeholders for a future refactoring; see #1783. # For now they serve as (unchecked) documentation of what various # fields of Node subtypes are expected to contain. -class Statement(Node): - pass - +Statement = Node Expression = Node Lvalue = Expression @@ -1767,9 +1765,9 @@ def accept(self, visitor: NodeVisitor[T]) -> T: class AwaitExpr(Node): """Await expression (await ...).""" - expr = None # type: Node + expr = None # type: Expression - def __init__(self, expr: Node) -> None: + def __init__(self, expr: Expression) -> None: self.expr = expr def accept(self, visitor: NodeVisitor[T]) -> T: diff --git a/mypy/semanal.py b/mypy/semanal.py index 0f777e4c6361..c886065996f7 100644 --- a/mypy/semanal.py +++ b/mypy/semanal.py @@ -60,7 +60,7 @@ FuncExpr, MDEF, FuncBase, Decorator, SetExpr, TypeVarExpr, NewTypeExpr, StrExpr, BytesExpr, PrintStmt, ConditionalExpr, PromoteExpr, ComparisonExpr, StarExpr, ARG_POS, ARG_NAMED, MroError, type_aliases, - YieldFromExpr, NamedTupleExpr, NonlocalDecl, + YieldFromExpr, NamedTupleExpr, NonlocalDecl, SymbolNode, SetComprehension, DictionaryComprehension, TYPE_ALIAS, TypeAliasExpr, YieldExpr, ExecStmt, Argument, BackquoteExpr, ImportBase, AwaitExpr, IntExpr, FloatExpr, UnicodeExpr, EllipsisExpr, @@ -1300,7 +1300,7 @@ def is_self_member_ref(self, memberexpr: MemberExpr) -> bool: node = memberexpr.expr.node return isinstance(node, Var) and node.is_self - def check_lvalue_validity(self, node: Expression, ctx: Context) -> None: + def check_lvalue_validity(self, node: Union[Expression, SymbolNode], ctx: Context) -> None: if isinstance(node, (TypeInfo, TypeVarExpr)): self.fail('Invalid assignment target', ctx) diff --git a/mypy/stubgen.py b/mypy/stubgen.py index b6c7c8f47dfe..2bf79654fae0 100644 --- a/mypy/stubgen.py +++ b/mypy/stubgen.py @@ -53,7 +53,7 @@ import mypy.traverser from mypy import defaults from mypy.nodes import ( - Node, IntExpr, UnaryExpr, StrExpr, BytesExpr, NameExpr, FloatExpr, MemberExpr, TupleExpr, + Expression, IntExpr, UnaryExpr, StrExpr, BytesExpr, NameExpr, FloatExpr, MemberExpr, TupleExpr, ListExpr, ComparisonExpr, CallExpr, ClassDef, MypyFile, Decorator, AssignmentStmt, IfStmt, ImportAll, ImportFrom, Import, FuncDef, FuncBase, ARG_STAR, ARG_STAR2, ARG_NAMED ) @@ -360,7 +360,7 @@ def visit_assignment_stmt(self, o: AssignmentStmt) -> None: if all(foundl): self._state = VAR - def is_namedtuple(self, expr: Node) -> bool: + def is_namedtuple(self, expr: Expression) -> bool: if not isinstance(expr, CallExpr): return False callee = expr.callee @@ -445,7 +445,7 @@ def visit_import(self, o: Import) -> None: self.add_import_line('import %s as %s\n' % (id, target_name)) self.record_name(target_name) - def get_init(self, lvalue: str, rvalue: Node) -> str: + def get_init(self, lvalue: str, rvalue: Expression) -> str: """Return initializer for a variable. Return None if we've generated one already or if the variable is internal. @@ -504,7 +504,7 @@ def is_private_name(self, name: str) -> bool: '__setstate__', '__slots__')) - def get_str_type_of_node(self, rvalue: Node, + def get_str_type_of_node(self, rvalue: Expression, can_infer_optional: bool = False) -> str: if isinstance(rvalue, IntExpr): return 'int' @@ -543,8 +543,8 @@ def is_recorded_name(self, name: str) -> bool: return self.is_top_level() and name in self._toplevel_names -def find_self_initializers(fdef: FuncBase) -> List[Tuple[str, Node]]: - results = [] # type: List[Tuple[str, Node]] +def find_self_initializers(fdef: FuncBase) -> List[Tuple[str, Expression]]: + results = [] # type: List[Tuple[str, Expression]] class SelfTraverser(mypy.traverser.TraverserVisitor): def visit_assignment_stmt(self, o: AssignmentStmt) -> None: @@ -558,7 +558,7 @@ def visit_assignment_stmt(self, o: AssignmentStmt) -> None: return results -def find_classes(node: Node) -> Set[str]: +def find_classes(node: MypyFile) -> Set[str]: results = set() # type: Set[str] class ClassTraverser(mypy.traverser.TraverserVisitor): @@ -569,7 +569,7 @@ def visit_class_def(self, o: ClassDef) -> None: return results -def get_qualified_name(o: Node) -> str: +def get_qualified_name(o: Expression) -> str: if isinstance(o, NameExpr): return o.name elif isinstance(o, MemberExpr): diff --git a/mypy/treetransform.py b/mypy/treetransform.py index 1681a40ce2c2..100ff7854a8c 100644 --- a/mypy/treetransform.py +++ b/mypy/treetransform.py @@ -188,7 +188,7 @@ def visit_class_def(self, node: ClassDef) -> ClassDef: node.metaclass) new.fullname = node.fullname new.info = node.info - new.decorators = [decorator.accept(self) + new.decorators = [self.expr(decorator) for decorator in node.decorators] new.is_builtinclass = node.is_builtinclass return new diff --git a/mypy/typeanal.py b/mypy/typeanal.py index f299b0b94ba4..931cf7cf5363 100644 --- a/mypy/typeanal.py +++ b/mypy/typeanal.py @@ -9,7 +9,7 @@ ) from mypy.nodes import ( BOUND_TVAR, TYPE_ALIAS, UNBOUND_IMPORTED, - TypeInfo, Context, SymbolTableNode, Var, Node, + TypeInfo, Context, SymbolTableNode, Var, Expression, IndexExpr, RefExpr ) from mypy.sametypes import is_same_type @@ -28,7 +28,7 @@ } -def analyze_type_alias(node: Node, +def analyze_type_alias(node: Expression, lookup_func: Callable[[str, Context], SymbolTableNode], lookup_fqn_func: Callable[[str], SymbolTableNode], fail_func: Callable[[str, Context], None]) -> Type: