10000 Polymorphic inference: support for parameter specifications and lambdas by ilevkivskyi · Pull Request #15837 · python/mypy · GitHub
[go: up one dir, main page]

Skip to content

Polymorphic inference: support for parameter specifications and lambdas #15837

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 15 commits into from
Aug 15, 2023
Prev Previous commit
Next Next commit
Fix and support Concatenate
  • Loading branch information
Ivan Levkivskyi committed Aug 9, 2023
commit d4c91462072d3525310eef77a651a11dc0846c42
1 change: 0 additions & 1 deletion mypy/applytype.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,6 @@ def apply_generic_arguments(
if param_spec is not None:
nt = id_to_type.get(param_spec.id)
if nt is not None:
nt = get_proper_type(nt)
if isinstance(nt, Parameters):
callable = callable.expand_param_spec(nt)

Expand Down
119 changes: 73 additions & 46 deletions mypy/constraints.py
< 8000 tr data-hunk="b688565943c43e866c574669f6faa47606e6b3f48c95fe86c13b9da7f8c874ed" class="show-top-border">
Original file line number Diff line number Diff line change
Expand Up @@ -576,18 +576,22 @@ def visit_unpack_type(self, template: UnpackType) -> list[Constraint]:
raise RuntimeError("Mypy bug: unpack should be handled at a higher level.")

def visit_parameters(self, template: Parameters) -> list[Constraint]:
# constraining Any against C[P] turns into infer_against_any([P], Any)
# ... which seems like the only case this can happen. Better to fail loudly.
# Constraining Any against C[P] turns into infer_against_any([P], Any)
# ... which seems like the only case this can happen. Better to fail loudly otherwise.
if isinstance(self.actual, AnyType):
return self.infer_against_any(template.arg_types, self.actual)
if type_state.infer_polymorphic and isinstance(self.actual, Parameters):
# For polymorphic inference we need to be able to infer secondary constraints
# in situations like [x: T] <: P <: [x: int].
res = []
if len(template.arg_types) == len(self.actual.arg_types):
# TODO: this may assume positional arguments
for tt, at, k in zip(
template.arg_types, self.actual.arg_types, self.actual.arg_kinds
for tt, at, tk, ak in zip(
template.arg_types,
self.actual.arg_types,
template.arg_kinds,
self.actual.arg_kinds,
):
if k in (ARG_STAR, ARG_STAR2):
if tk == ARG_STAR and ak != ARG_STAR or tk == ARG_STAR2 and ak != ARG_STAR2:
continue
res.extend(infer_constraints(tt, at, self.direction))
return res
Expand Down Expand Up @@ -696,7 +700,6 @@ def visit_instance(self, template: Instance) -> list[Constraint]:
# N.B: We use zip instead of indexing because the lengths might have
# mismatches during daemon reprocessing.
for tvar, mapped_arg, instance_arg in zip(tvars, mapped_args, instance_args):
# TODO(PEP612): More ParamSpec work (or is Parameters the only thing accepted)
if isinstance(tvar, TypeVarType):
# The constraints for generic type parameters depend on variance.
# Include constraints from both directions if invariant.
Expand All @@ -707,21 +710,27 @@ def visit_instance(self, template: Instance) -> list[Constraint]:
infer_constraints(mapped_arg, instance_arg, neg_op(self.direction))
)
elif isinstance(tvar, ParamSpecType) and isinstance(mapped_arg, ParamSpecType):
suffix = get_proper_type(instance_arg)

if isinstance(suffix, Parameters):
# No such thing as variance for ParamSpecs, consider them covariant
# TODO: is there a case I am missing?
prefix = mapped_arg.prefix
if isinstance(instance_arg, Parameters):
# No such thing as variance for ParamSpecs, consider them invariant
# TODO: constraints between prefixes
prefix = mapped_arg.prefix
suffix = suffix.copy_modified(
suffix.arg_types[len(prefix.arg_types) :],
suffix.arg_kinds[len(prefix.arg_kinds) :],
suffix.arg_names[len(prefix.arg_names) :],
suffix: Type = instance_arg.copy_modified(
instance_arg.arg_types[len(prefix.arg_types) :],
instance_arg.arg_kinds[len(prefix.arg_kinds) :],
instance_arg.arg_names[len(prefix.arg_names) :],
)
res.append(Constraint(mapped_arg, SUBTYPE_OF, suffix))
res.append(Constraint(mapped_arg, SUPERTYPE_OF, suffix))
elif isinstance(instance_arg, ParamSpecType):
suffix = instance_arg.copy_modified(
prefix=Parameters(
instance_arg.prefix.arg_types[len(prefix.arg_types) :],
instance_arg.prefix.arg_kinds[len(prefix.arg_kinds) :],
instance_arg.prefix.arg_names[len(prefix.arg_names) :],
)
)
res.append(Constraint(mapped_arg, self.direction, suffix))
elif isinstance(suffix, ParamSpecType):
res.append(Constraint(mapped_arg, self.direction, suffix))
res.append(Constraint(mapped_arg, SUBTYPE_OF, suffix))
res.append(Constraint(mapped_arg, SUPERTYPE_OF, suffix))
else:
# This case should have been handled above.
assert not isinstance(tvar, TypeVarTupleType)
Expand Down Expand Up @@ -772,22 +781,27 @@ def visit_instance(self, template: Instance) -> list[Constraint]:
elif isinstance(tvar, ParamSpecType) and isinstance(
template_arg, ParamSpecType
):
suffix = get_proper_type(mapped_arg)

if isinstance(suffix, Parameters):
# No such thing as variance for ParamSpecs, consider them covariant
# TODO: is there a case I am missing?
prefix = template_arg.prefix
if isinstance(mapped_arg, Parameters):
# No such thing as variance for ParamSpecs, consider them invariant
# TODO: constraints between prefixes
prefix = template_arg.prefix

suffix = suffix.copy_modified(
suffix.arg_types[len(prefix.arg_types) :],
suffix.arg_kinds[len(prefix.arg_kinds) :],
suffix.arg_names[len(prefix.arg_names) :],
suffix = mapped_arg.copy_modified(
mapped_arg.arg_types[len(prefix.arg_types) :],
mapped_arg.arg_kinds[len(prefix.arg_kinds) :],
mapped_arg.arg_names[len(prefix.arg_names) :],
)
res.append(Constraint(template_arg, self.direction, suffix))
elif isinstance(suffix, ParamSpecType):
res.append(Constraint(template_arg, self.direction, suffix))
res.append(Constraint(template_arg, SUBTYPE_OF, suffix))
res.append(Constraint(template_arg, SUPERTYPE_OF, suffix))
elif isinstance(mapped_arg, ParamSpecType):
suffix = mapped_arg.copy_modified(
prefix=Parameters(
mapped_arg.prefix.arg_types[len(prefix.arg_types) :],
mapped_arg.prefix.arg_kinds[len(prefix.arg_kinds) :],
mapped_arg.prefix.arg_names[len(prefix.arg_names) :],
)
)
res.append(Constraint(template_arg, SUBTYPE_OF, suffix))
res.append(Constraint(template_arg, SUPERTYPE_OF, suffix))
else:
# This case should have been handled above.
assert not isinstance(tvar, TypeVarTupleType)
Expand Down Expand Up @@ -926,7 +940,8 @@ def visit_callable_type(self, template: CallableType) -> list[Constraint]:
# We can't infer constraints from arguments if the template is Callable[..., T]
# (with literal '...').
if not template.is_ellipsis_args:
if find_unpack_in_list(template.arg_types) is not None:
unpack_present = find_unpack_in_list(template.arg_types)
if unpack_present is not None:
(
unpack_constraints,
cactual_args_t,
Expand All @@ -942,17 +957,25 @@ def visit_callable_type(self, template: CallableType) -> list[Constraint]:
template_args = template.arg_types
cactual_args = cactual.arg_types
# The lengths should match, but don't crash (it will error elsewhere).
for t, a in zip(template_args, cactual_args):
if isinstance(a, ParamSpecType) and not isinstance(t, ParamSpecType):
for t, a, tk, ak in zip(
template_args, cactual_args, template.arg_kinds, cactual.arg_kinds
):
# Unpack may have shifted indices.
if not unpack_present:
# This avoids bogus constraints like T <: P.args
# TODO: figure out a more principled way to skip arg_kind mismatch
# (see also a similar to do item in corresponding branch below)
if (
tk == ARG_STAR
and ak != ARG_STAR
or tk == ARG_STAR2
and ak != ARG_STAR2
):
continue
if isinstance(a, ParamSpecType):
# TODO: can we infer something useful for *T vs P?
continue
# Negate direction due to function argument type contravariance.
res.extend(infer_constraints(t, a, neg_op(self.direction)))
else:
# sometimes, it appears we try to get constraints between two paramspec callables?

# TODO: check the prefixes match
prefix = param_spec.prefix
prefix_len = len(prefix.arg_types)
Expand Down Expand Up @@ -985,19 +1008,23 @@ def visit_callable_type(self, template: CallableType) -> list[Constraint]:
)
)
else:
if not param_spec.prefix.arg_types or cactual_ps.prefix.arg_types:
# TODO: figure out a more general logic to reject shorter prefix in actual.
# This may be actually fixed by a more general to do item above.
if len(param_spec.prefix.arg_types) <= len(cactual_ps.prefix.arg_types):
cactual_ps = cactual_ps.copy_modified(
prefix=Parameters(
arg_types=cactual_ps.prefix.arg_types[prefix_len:],
arg_kinds=cactual_ps.prefix.arg_kinds[prefix_len:],
arg_names=cactual_ps.prefix.arg_names[prefix_len:],
)
)
res.append(Constraint(param_spec, neg_op(self.direction), cactual_ps))

# compare prefixes
# Compare prefixes as well
cactual_prefix = cactual.copy_modified(
arg_types=cactual.arg_types[:prefix_len],
arg_kinds=cactual.arg_kinds[:prefix_len],
arg_names=cactual.arg_names[:prefix_len],
)

# TODO: this may assume positional arguments
for t, a, k in zip(
prefix.arg_types, cactual_prefix.arg_types, cactual_prefix.arg_kinds
):
Expand Down
100 changes: 39 additions & 61 deletions mypy/expandtype.py
10000
Original file line number Diff line number Diff line change
Expand Up @@ -231,44 +231,27 @@ def visit_type_var(self, t: TypeVarType) -> Type:
return repl

def visit_param_spec(self, t: ParamSpecType) -> Type:
# set prefix to something empty so we don't duplicate it
repl = get_proper_type(
self.variables.get(t.id, t.copy_modified(prefix=Parameters([], [], [])))
)
if isinstance(repl, Instance):
# TODO: what does prefix mean in this case?
# TODO: why does this case even happen? Instances aren't plural.
return repl
elif isinstance(repl, (ParamSpecType, Parameters)):
if isinstance(repl, ParamSpecType):
return repl.copy_modified(
flavor=t.flavor,
prefix=t.prefix.copy_modified(
arg_types=t.prefix.arg_types + repl.prefix.arg_types,
arg_kinds=t.prefix.arg_kinds + repl.prefix.arg_kinds,
arg_names=t.prefix.arg_names + repl.prefix.arg_names,
),
)
else:
# if the paramspec is *P.args or **P.kwargs:
if t.flavor != ParamSpecFlavor.BARE:
assert isinstance(repl, CallableType), "Should not be able to get here."
# Is this always the right thing to do?
param_spec = repl.param_spec()
if param_spec:
return param_spec.with_flavor(t.flavor)
else:
return repl
else:
return Parameters(
t.prefix.arg_types + repl.arg_types,
t.prefix.arg_kinds + repl.arg_kinds,
t.prefix.arg_names + repl.arg_names,
variables=[*t.prefix.variables, *repl.variables],
)

# Set prefix to something empty, so we don't duplicate it below.
repl = self.variables.get(t.id, t.copy_modified(prefix=Parameters([], [], [])))
if isinstance(repl, ParamSpecType):
return repl.copy_modified(
flavor=t.flavor,
prefix=t.prefix.copy_modified(
arg_types=self.expand_types(t.prefix.arg_types + repl.prefix.arg_types),
arg_kinds=t.prefix.arg_kinds + repl.prefix.arg_kinds,
arg_names=t.prefix.arg_names + repl.prefix.arg_names,
),
)
elif isinstance(repl, Parameters):
assert t.flavor == ParamSpecFlavor.BARE
return Parameters(
self.expand_types(t.prefix.arg_types + repl.arg_types),
t.prefix.arg_kinds + repl.arg_kinds,
t.prefix.arg_names + repl.arg_names,
variables=[*t.prefix.variables, *repl.variables],
)
else:
# TODO: should this branch be removed? better not to fail silently
# TODO: replace this with "assert False"
return repl

def visit_type_var_tuple(self, t: TypeVarTupleType) -> Type:
Expand Down Expand Up @@ -387,7 +370,7 @@ def interpolate_args_for_unpack(
def visit_callable_type(self, t: CallableType) -> CallableType:
param_spec = t.param_spec()
if param_spec is not None:
repl = get_proper_type(self.variables.get(param_spec.id))
repl = self.variables.get(param_spec.id)
# If a ParamSpec in a callable type is substituted with a
# callable type, we can't use normal substitution logic,
# since ParamSpec is actually split into two components
Expand All @@ -396,34 +379,29 @@ def visit_callable_type(self, t: CallableType) -> CallableType:
# kinds and names in the replacement. The return type in
# the replacement is ignored.
if isinstance(repl, Parameters):
# Substitute *args: P.args, **kwargs: P.kwargs
prefix = param_spec.prefix
# we need to expand the types in the prefix, so might as well
# not get them in the first place
t = t.expand_param_spec(repl, no_prefix=True)
# We need to expand both the types in the prefix and the ParamSpec itself
t = t.expand_param_spec(repl)
return t.copy_modified(
arg_types=self.expand_types(prefix.arg_types) + t.arg_types,
arg_kinds=prefix.arg_kinds + t.arg_kinds,
arg_names=prefix.arg_names + t.arg_names,
arg_types=self.expand_types(t.arg_types),
arg_kinds=t.arg_kinds,
arg_names=t.arg_names,
ret_type=t.ret_type.accept(self),
type_guard=(t.type_guard.accept(self) if t.type_guard is not None else None),
)
# TODO: Conceptually, the "len(t.arg_types) == 2" should not be here. However, this
# errors without it. Either figure out how to eliminate this or place an
# explanation for why this is necessary.
elif isinstance(repl, ParamSpecType) and len(t.arg_types) == 2:
# We're substituting one paramspec for another; this can mean that the prefix
# changes. (e.g. sub Concatenate[int, P] for Q)
elif isinstance(repl, ParamSpecType):
# We're substituting one ParamSpec for another; this can mean that the prefix
# changes, e.g. substitute Concatenate[int, P] in place of Q.
prefix = repl.prefix
old_prefix = param_spec.prefix

# Check assumptions. I'm not sure what order to place new prefix vs old prefix:
assert not old_prefix.arg_types or not prefix.arg_types

t = t.copy_modified(
arg_types=prefix.arg_types + old_prefix.arg_types + t.arg_types,
arg_kinds=prefix.arg_kinds + old_prefix.arg_kinds + t.arg_kinds,
arg_names=prefix.arg_names + old_prefix.arg_names + t.arg_names,
clean_repl = repl.copy_modified(prefix=Parameters([], [], []))
return t.copy_modified(
arg_types=self.expand_types(t.arg_types[:-2] + prefix.arg_types)
+ [
clean_repl.with_flavor(ParamSpecFlavor.ARGS),
clean_repl.with_flavor(ParamSpecFlavor.KWARGS),
],
arg_kinds=t.arg_kinds[:-2] + prefix.arg_kinds + t.arg_kinds[-2:],
arg_names=t.arg_names[:-2] + prefix.arg_names + t.arg_names[-2:],
Comment on lines +402 to +403
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do you have a reason for this ordering? I recall not being able to come up with an ordering (and now that I look at my own code here I'm not sure what I was even thinking, using both param_spec and repl...)

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What effectively happens is we have Concatenate[X, P], where P = Concatenate[Y, Q], and for positional arguments, nested concatenates simply flatten. I didn't think much about non-positional arguments, but IIUC for them order is not that important and some other parts of the code already work well only for positional arguments.

ret_type=t.ret_type.accept(self),
)

var_arg = t.var_arg()
Expand Down
4 changes: 3 additions & 1 deletion mypy/solve.py
Original file line number Diff line number Diff line change
Expand Up @@ -336,7 +336,9 @@ def is_trivial_bound(tp: ProperType) -> bool:


def normalize_constraints(
constraints: list[Constraint], vars: list[TypeVarId]
# TODO: delete this function?
constraints: list[Constraint],
vars: list[TypeVarId],
) -> list[Constraint]:
"""Normalize list of constraints (to simplify life for the non-linear solver).

Expand Down
Loading
0