8000 Support TypedDicts with missing keys (total=False) by JukkaL · Pull Request #3558 · python/mypy · GitHub
[go: up one dir, main page]

Skip to content

Support TypedDicts with missing keys (total=False) #3558

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 20 commits into from
Jun 23, 2017
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
Basic support for TypedDicts with missing keys (total=False)
Only the functional syntax is supported.
  • Loading branch information
JukkaL committed Jun 15, 2017
commit ccfc4adbbb795ffe84c64ede92ea941d33e6f8c7
9 changes: 6 additions & 3 deletions mypy/checkexpr.py
Original file line number Diff line number Diff line change
Expand Up @@ -292,8 +292,9 @@ def check_typeddict_call_with_dict(self, callee: TypedDictType,
def check_typeddict_call_with_kwargs(self, callee: TypedDictType,
kwargs: 'OrderedDict[str, Expression]',
context: Context) -> Type:
if callee.items.keys() != kwargs.keys():
callee_item_names = callee.items.keys()
if not (callee.required_keys <= kwargs.keys() <= callee.items.keys()):
callee_item_names = [key for key in callee.items.keys()
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Naming these two variables expected_xxx and actual_xxx (matching the error message call below) would help in understanding what they mean.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A good idea -- done.

if key in callee.required_keys or key in kwargs.keys()]
kwargs_item_names = kwargs.keys()

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This blank line irks me.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Removed.

self.msg.typeddict_instantiated_with_unexpected_items(
Expand All @@ -316,7 +317,7 @@ def check_typeddict_call_with_kwargs(self, callee: TypedDictType,
mapping_value_type = join.join_type_list(list(items.values()))
fallback = self.chk.named_generic_type('typing.Mapping',
[self.chk.str_type(), mapping_value_type])
return TypedDictType(items, fallback)
return TypedDictType(items, set(callee.required_keys), fallback)

# Types and methods that can be used to infer partial types.
item_args = {'builtins.list': ['append'],
Expand Down Expand Up @@ -1656,6 +1657,8 @@ def visit_typeddict_index_expr(self, td_type: TypedDictType, index: Expression)
if item_type is None:
self.msg.typeddict_item_name_not_found(td_type, item_name, index)
return AnyType()
if item_name not in td_type.required_keys:
self.msg.typeddict_item_may_be_undefined(item_name, index)
return item_type

def visit_enum_index_expr(self, enum_type: TypeInfo, index: Expression,
Expand Down
2 changes: 1 addition & 1 deletion mypy/join.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,7 +232,7 @@ def visit_typeddict_type(self, t: TypedDictType) -> Type:
])
mapping_value_type = join_type_list(list(items.values()))
fallback = self.s.create_anonymous_fallback(value_type=mapping_value_type)
return TypedDictType(items, fallback)
return TypedDictType(items, set(items.keys()), fallback) # XXX required
elif isinstance(self.s, Instance):
return join_instances(self.s, t.fallback)
else:
Expand Down
2 changes: 1 addition & 1 deletion mypy/meet.py
Original file line number Diff line number Diff line change
Expand Up @@ -266,7 +266,7 @@ def visit_typeddict_type(self, t: TypedDictType) -> Type:
items = OrderedDict(item_list)
mapping_value_type = join_type_list(list(items.values()))
fallback = self.s.create_anonymous_fallback(value_type=mapping_value_type)
return TypedDictType(items, fallback)
return TypedDictType(items, set(items.keys()), fallback) # XXX required
else:
return self.default(self.s)

Expand Down
7 changes: 7 additions & 0 deletions mypy/messages.py
8000
Original file line number Diff line number Diff line change
Expand Up @@ -901,6 +901,13 @@ def typeddict_item_name_not_found(self,
self.fail('\'{}\' is not a valid TypedDict key; expected one of {}'.format(
item_name, format_item_name_list(typ.items.keys())), context)

def typeddict_item_may_be_undefined(self,
item_name: str,
context: Context,
) -> None:
self.fail("TypedDict key '{}' may be undefined".format(item_name), context)
self.note("Consider using get() instead", context)

def type_arguments_not_allowed(self, context: Context) -> None:
self.fail('Parameterized generics cannot be used with class or instance checks', context)

Expand Down
72 changes: 48 additions & 24 deletions mypy/semanal.py
Original file line number Diff line number Diff line change
Expand Up @@ -2302,46 +2302,64 @@ def check_typeddict(self, node: Expression, var_name: str = None) -> Optional[Ty
fullname = callee.fullname
if fullname != 'mypy_extensions.TypedDict':
return None
items, types, ok = self.parse_typeddict_args(call, fullname)
items, types, total, ok = self.parse_typeddict_args(call, fullname)
if not ok:
# Error. Construct dummy return value.
return self.build_typeddict_typeinfo('TypedDict', [], [])
name = cast(StrExpr, call.args[0]).value
if name != var_name or self.is_func_scope():
# Give it a unique name derived from the line number.
name += '@' + str(call.line)
info = self.build_typeddict_typeinfo(name, items, types)
# Store it as a global just in case it would remain anonymous.
# (Or in the nearest class if there is one.)
stnode = SymbolTableNode(GDEF, info, self.cur_mod_id)
if self.type:
self.type.names[name] = stnode
info = self.build_typeddict_typeinfo('TypedDict', [], [])
else:
self.globals[name] = stnode
name = cast(StrExpr, call.args[0]).value
if name != var_name or self.is_func_scope():
# Give it a unique name derived from the line number.
name += '@' + str(call.line)
info = self.build_typeddict_typeinfo(name, items, types, total)
# Store it as a global just in case it would remain anonymous.
# (Or in the nearest class if there is one.)
stnode = SymbolTableNode(GDEF, info, self.cur_mod_id)
if self.type:
self.type.names[name] = stnode
else:
self.globals[name] = stnode
call.analyzed = TypedDictExpr(info)
call.analyzed.set_line(call.line, call.column)
return info

def parse_typeddict_args(self, call: CallExpr,
fullname: str) -> Tuple[List[str], List[Type], bool]:
fullname: str) -> Tuple[List[str], List[Type], bool, bool]:
# TODO: Share code with check_argument_count in checkexpr.py?
args = call.args
if len(args) < 2:
return self.fail_typeddict_arg("Too few arguments for TypedDict()", call)
if len(args) > 2:
if len(args) > 3:
return self.fail_typeddict_arg("Too many arguments for TypedDict()", call)
# TODO: Support keyword arguments
if call.arg_kinds != [ARG_POS, ARG_POS]:
if call.arg_kinds not in ([ARG_POS, ARG_POS], [ARG_POS, ARG_POS, ARG_NAMED]):
return self.fail_typeddict_arg("Unexpected arguments to TypedDict()", call)
if len(args) == 3 and call.arg_names[2] != 'total':
return self.fail_typeddict_arg(
'Unexpected keyword argument "{}" for "TypedDict"'.format(call.arg_names[2]), call)
if not isinstance(args[0], (StrExpr, BytesExpr, UnicodeExpr)):
return self.fail_typeddict_arg(
"TypedDict() expects a string literal as the first argument", call)
if not isinstance(args[1], DictExpr):
return self.fail_typeddict_arg(
"TypedDict() expects a dictionary literal as the second argument", call)
total = True
if len(args) == 3:
total = self.parse_bool(call.args[2])
if total is None:
return self.fail_typeddict_arg(
'TypedDict() "total" argument must be True or False', call)
dictexpr = args[1]
items, types, ok = self.parse_typeddict_fields_with_types(dictexpr.items, call)
return items, types, ok
return items, types, total, ok

def parse_bool(self, expr: Expression) -> Optional[bool]:
if isinstance(expr, NameExpr):
if expr.fullname == 'builtins.True':
return True
if expr.fullname == 'builtins.False':
return False
return None

def parse_typeddict_fields_with_types(self, dict_items: List[Tuple[Expression, Expression]],
context: Context) -> Tuple[List[str], List[Type], bool]:
Expand All @@ -2351,29 +2369,35 @@ def parse_typeddict_fields_with_types(self, dict_items: List[Tuple[Expression, E
if isinstance(field_name_expr, (StrExpr, BytesExpr, UnicodeExpr)):
items.append(field_name_expr.value)
else:
return self.fail_typeddict_arg("Invalid TypedDict() field name", field_name_expr)
self.fail_typeddict_arg("Invalid TypedDict() field name", field_name_expr)
return [], [], False
try:
type = expr_to_unanalyzed_type(field_type_expr)
except TypeTranslationError:
return self.fail_typeddict_arg('Invalid field type', field_type_expr)
self.fail_typeddict_arg('Invalid field type', field_type_expr)
return [], [], False
types.append(self.anal_type(type))
return items, types, True

def fail_typeddict_arg(self, message: str,
context: Context) -> Tuple[List[str], List[Type], bool]:
context: Context) -> Tuple[List[str], List[Type], bool, bool]:
self.fail(message, context)
return [], [], False
return [], [], True, False

def build_typeddict_typeinfo(self, name: str, items: List[str],
types: List[Type]) -> TypeInfo:
types: List[Type], total: bool = True) -> TypeInfo:
mapping_value_type = join.join_type_list(types)
fallback = (self.named_type_or_none('typing.Mapping',
[self.str_type(), mapping_value_type])
or self.object_type())

info = self.basic_new_typeinfo(name, fallback)
info.typeddict_type = TypedDictType(OrderedDict(zip(items, types)), fallback)

if total:
required_keys = set(items)
else:
required_keys = set()
info.typeddict_type = TypedDictType(OrderedDict(zip(items, types)), required_keys,
fallback)
return info

def check_classvar(self, s: AssignmentStmt) -> None:
Expand Down
2 changes: 1 addition & 1 deletion mypy/typeanal.py
Original file line number Diff line number Diff line change
Expand Up @@ -368,7 +368,7 @@ def visit_typeddict_type(self, t: TypedDictType) -> Type:
(item_name, self.anal_type(item_type))
for (item_name, item_type) in t.items.items()
])
return TypedDictType(items, t.fallback)
return TypedDictType(items, set(t.required_keys), t.fallback)

def visit_star_type(self, t: StarType) -> Type:
return StarType(self.anal_type(t.type), t.line)
Expand Down
26 changes: 19 additions & 7 deletions mypy/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -921,12 +921,14 @@ class TypedDictType(Type):
whose TypeInfo has a typeddict_type that is anonymous.
"""

items = None # type: OrderedDict[str, Type] # (item_name, item_type)
items = None # type: OrderedDict[str, Type] # item_name -> item_type
required_keys = None # type: Set[str]
fallback = None # type: Instance

def __init__(self, items: 'OrderedDict[str, Type]', fallback: Instance,
line: int = -1, column: int = -1) -> None:
def __init__(self, items: 'OrderedDict[str, Type]', required_keys: Set[str],
fallback: Instance, line: int = -1, column: int = -1) -> None:
self.items = items
self.required_keys = required_keys
self.fallback = fallback
self.can_be_true = len(self.items) > 0
self.can_be_false = len(self.items) == 0
Expand All @@ -938,6 +940,7 @@ def accept(self, visitor: 'TypeVisitor[T]') -> T:
def serialize(self) -> JsonDict:
return {'.class': 'TypedDictType',
'items': [[n, t.serialize()] for (n, t) in self.items.items()],
'required_keys': sorted(self.required_keys),
'fallback': self.fallback.serialize(),
}

Expand All @@ -946,6 +949,7 @@ def deserialize(cls, data: JsonDict) -> 'TypedDictType':
assert data['.class'] == 'TypedDictType'
return TypedDictType(OrderedDict([(n, deserialize_type(t))
for (n, t) in data['items']]),
set(data['required_keys']),
Instance.deserialize(data['fallback']))

def as_anonymous(self) -> 'TypedDictType':
Expand All @@ -955,14 +959,15 @@ def as_anonymous(self) -> 'TypedDictType':
return self.fallback.type.typeddict_type.as_anonymous()

def copy_modified(self, *, fallback: Instance = None,
item_types: List[Type] = None) -> 'TypedDictType':
item_types: List[Type] = None,
required_keys: Set[str] = None) -> 'TypedDictType':
if fallback is None:
fallback = self.fallback
if item_types is None:
items = self.items
else:
items = OrderedDict(zip(self.items, item_types))
return TypedDictType(items, fallback, self.line, self.column)
return TypedDictType(items, self.required_keys, fallback, self.line, self.column)

def create_anonymous_fallback(self, *, value_type: Type) -> Instance:
anonymous = self.as_anonymous()
Expand Down Expand Up @@ -1371,6 +1376,7 @@ def visit_typeddict_type(self, t: TypedDictType) -> Type:
for (item_name, item_type) in t.items.items()
])
return TypedDictType(items,
t.required_keys,
# TODO: This appears to be unsafe.
cast(Any, t.fallback.accept(self)),
t.line, t.column)
Expand Down Expand Up @@ -1516,11 +1522,17 @@ def visit_tuple_type(self, t: TupleType) -> str:

def visit_typeddict_type(self, t: TypedDictType) -> str:
s = self.keywords_str(t.items.items())
if t.required_keys == set(t.items):
keys_str = ''
elif t.required_keys == set():
keys_str = ', _total=False'
else:
keys_str = ', _required_keys=[{}]'.format(', '.join(sorted(t.required_keys)))
if t.fallback and t.fallback.type:
if s == '':
return 'TypedDict(_fallback={})'.format(t.fallback.accept(self))
return 'TypedDict(_fallback={}{})'.format(t.fallback.accept(self), keys_str)
else:
return 'TypedDict({}, _fallback={})'.format(s, t.fallback.accept(self))
return 'TypedDict({}, _fallback={}{})'.format(s, t.fallback.accept(self), keys_str)
return 'TypedDict({})'.format(s)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's a bit disturbing that the repr of a TypedDict is so different from the actual syntax used to create one, but let's deal with that some other time.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Created #3590 to track this.


def visit_star_type(self, t: StarType) -> str:
Expand Down
80 changes: 80 additions & 0 deletions test-data/unit/check-typeddict.test
Original file line number Diff line number Diff line change
Expand Up @@ -813,3 +813,83 @@ p = TaggedPoint(type='2d', x=42, y=1337)
p.get('x', 1 + 'y') # E: Unsupported operand types for + ("int" and "str")
[builtins fixtures/dict.pyi]
[typing fixtures/typing-full.pyi]


-- Totality (the "total" keyword argument)

[case testTypedDictWithTotalTrue]
from mypy_extensions import TypedDict
D = TypedDict('D', {'x': int}, total=True)
d: D
reveal_type(d) # E: Revealed type is 'TypedDict(x=builtins.int, _fallback=__main__.D)'
[builtins fixtures/dict.pyi]

[case testTypedDictWithInvalidTotalArgument]
from mypy_extensions import TypedDict
A = TypedDict('A', {'x': int}, total=0) # E: TypedDict() "total" argument must be True or False
B = TypedDict('B', {'x': int}, total=bool) # E: TypedDict() "total" argument must be True or False
C = TypedDict('C', {'x': int}, x=False) # E: Unexpected keyword argument "x" for "TypedDict"
D = TypedDict('D', {'x': int}, False) # E: Unexpected arguments to TypedDict()
[builtins fixtures/dict.pyi]

[case testTypedDictWithTotalFalse]
from mypy_extensions import TypedDict
D = TypedDict('D', {'x': int}, total=False)
d: D
reveal_type(d) # E: Revealed type is 'TypedDict(x=builtins.int, _fallback=__main__.D, _total=False)'
[builtins fixtures/dict.pyi]

[case testTypedDictIndexingWithNonRequiredKey]
from mypy_extensions import TypedDict
D = TypedDict('D', {'x': int, 'y': str}, total=False)
d: D
v = d['x'] # E: TypedDict key 'x' may be undefined \
# N: Consider using get() instead
reveal_type(v) # E: Revealed type is 'builtins.int'
w = d['y'] # E: TypedDict key 'y' may be undefined \
# N: Consider using get() instead
reveal_type(w) # E: Revealed type is 'builtins.str'
reveal_type(d.get('x')) # E: Revealed type is 'builtins.int'
reveal_type(d.get('y')) # E: Revealed type is 'builtins.str'
[builtins fixtures/dict.pyi]
[typing fixtures/typing-full.pyi]


-- Create Type (Errors)

[case testCannotCreateTypedDictTypeWithTooFewArguments]
from mypy_extensions import TypedDict
Point = TypedDict('Point') # E: Too few arguments for TypedDict()
[builtins fixtures/dict.pyi]

[case testCannotCreateTypedDictTypeWithTooManyArguments]
from mypy_extensions import TypedDict
Point = TypedDict('Point', {'x': int, 'y': int}, dict) # E: Unexpected arguments to TypedDict()
[builtins fixtures/dict.pyi]

[case testCannotCreateTypedDictTypeWithInvalidName]
from mypy_extensions import TypedDict
Point = TypedDict(dict, {'x': int, 'y': int}) # E: TypedDict() expects a string literal as the first argument
[builtins fixtures/dict.pyi]

[case testCannotCreateTypedDictTypeWithInvalidItems]
from mypy_extensions import TypedDict
Point = TypedDict('Point', {'x'}) # E: TypedDict() expects a dictionary literal as the second argument
[builtins fixtures/dict.pyi]

-- NOTE: The following code works at runtime but is not yet supported by mypy.
-- Keyword arguments may potentially be supported in the future.
[case testCannotCreateTypedDictTypeWithNonpositionalArgs]
from mypy_extensions import TypedDict
Point = TypedDict(typename='Point', fields={'x': int, 'y': int}) # E: Unexpected arguments to TypedDict()
[builtins fixtures/dict.pyi]

[case testCannotCreateTypedDictTypeWithInvalidItemName]
from mypy_extensions import TypedDict
Point = TypedDict('Point', {int: int, int: int}) # E: Invalid TypedDict() field name
[builtins fixtures/dict.pyi]

[case testCannotCreateTypedDictTypeWithInvalidItemType]
from mypy_extensions import TypedDict
Point = TypedDict('Point', {'x': 1, 'y': 1}) # E: Invalid field type
[builtins fixtures/dict.pyi]
2 changes: 1 addition & 1 deletion test-data/unit/lib-stub/mypy_extensions.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,6 @@ def VarArg(type: _T = ...) -> _T: ...
def KwArg(type: _T = ...) -> _T: ...


def TypedDict(typename: str, fields: Dict[str, Type[_T]]) -> Type[dict]: ...
def TypedDict(typename: str, fields: Dict[str, Type[_T]], *, total: Any = ...) -> Type[dict]: ...

class NoReturn: pass
Loading
0