From fd45f428ea9242e012fe329e80e131ee3d6b38b9 Mon Sep 17 00:00:00 2001 From: Jelle Zijlstra Date: Tue, 23 Jul 2024 21:22:24 -0700 Subject: [PATCH 1/2] gh-119180: Improvements to ForwardRef.evaluate Noticed some issues while writing documentation for this method. --- Lib/annotationlib.py | 27 +++++++++++++--------- Lib/test/test_annotationlib.py | 41 ++++++++++++++++++++++++++++++++++ 2 files changed, 57 insertions(+), 11 deletions(-) diff --git a/Lib/annotationlib.py b/Lib/annotationlib.py index b4036ffb189c2d..20a4d46552a563 100644 --- a/Lib/annotationlib.py +++ b/Lib/annotationlib.py @@ -74,7 +74,7 @@ def __init_subclass__(cls, /, *args, **kwds): def evaluate(self, *, globals=None, locals=None, type_params=None, owner=None): """Evaluate the forward reference and return the value. - If the forward reference is not evaluatable, raise an exception. + If the forward reference cannot be evaluated, raise an exception. """ if self.__forward_evaluated__: return self.__forward_value__ @@ -89,12 +89,10 @@ def evaluate(self, *, globals=None, locals=None, type_params=None, owner=None): return value if owner is None: owner = self.__owner__ - if type_params is None and owner is None: - raise TypeError("Either 'type_params' or 'owner' must be provided") - if self.__forward_module__ is not None: + if globals is None and self.__forward_module__ is not None: globals = getattr( - sys.modules.get(self.__forward_module__, None), "__dict__", globals + sys.modules.get(self.__forward_module__, None), "__dict__", None ) if globals is None: globals = self.__globals__ @@ -112,14 +110,14 @@ def evaluate(self, *, globals=None, locals=None, type_params=None, owner=None): if locals is None: locals = {} - if isinstance(self.__owner__, type): - locals.update(vars(self.__owner__)) + if isinstance(owner, type): + locals.update(vars(owner)) - if type_params is None and self.__owner__ is not None: + if type_params is None and owner is not None: # "Inject" type parameters into the local namespace # (unless they are shadowed by assignments *in* the local namespace), # as a way of emulating annotation scopes when calling `eval()` - type_params = getattr(self.__owner__, "__type_params__", None) + type_params = getattr(owner, "__type_params__", None) # type parameters require some special handling, # as they exist in their own scope @@ -129,7 +127,14 @@ def evaluate(self, *, globals=None, locals=None, type_params=None, owner=None): # but should in turn be overridden by names in the class scope # (which here are called `globalns`!) if type_params is not None: - globals, locals = dict(globals), dict(locals) + if globals is None: + globals = {} + else: + globals = dict(globals) + if locals is None: + locals = {} + else: + locals = dict(locals) for param in type_params: param_name = param.__name__ if not self.__forward_is_class__ or param_name not in globals: @@ -413,7 +418,7 @@ def __missing__(self, key): return fwdref -def call_annotate_function(annotate, format, owner=None): +def call_annotate_function(annotate, format, *, owner=None): """Call an __annotate__ function. __annotate__ functions are normally generated by the compiler to defer the evaluation of annotations. They can be called with any of the format arguments in the Format enum, but diff --git a/Lib/test/test_annotationlib.py b/Lib/test/test_annotationlib.py index e68d63c91d1a73..d7370bc93f9c0b 100644 --- a/Lib/test/test_annotationlib.py +++ b/Lib/test/test_annotationlib.py @@ -4,6 +4,7 @@ import functools import pickle import unittest +from annotationlib import Format, ForwardRef from typing import Unpack from test.test_inspect import inspect_stock_annotations @@ -248,6 +249,46 @@ def test_special_attrs(self): with self.assertRaises(TypeError): pickle.dumps(fr, proto) + def test_evaluate_with_type_params(self): + class Gen[T]: + alias = int + + with self.assertRaises(NameError): + ForwardRef("T").evaluate() + with self.assertRaises(NameError): + ForwardRef("T").evaluate(type_params=()) + with self.assertRaises(NameError): + ForwardRef("T").evaluate(owner=int) + + T, = Gen.__type_params__ + self.assertIs(ForwardRef("T").evaluate(type_params=Gen.__type_params__), T) + self.assertIs(ForwardRef("T").evaluate(owner=Gen), T) + + with self.assertRaises(NameError): + ForwardRef("alias").evaluate(type_params=Gen.__type_params__) + self.assertIs(ForwardRef("alias").evaluate(owner=Gen), int) + # If you pass custom locals, we don't look at the owner's locals + with self.assertRaises(NameError): + ForwardRef("alias").evaluate(owner=Gen, locals={}) + # But if the name exists in the locals, it works + self.assertIs( + ForwardRef("alias").evaluate(owner=Gen, locals={"alias": str}), str + ) + + def test_fwdref_with_module(self): + self.assertIs(ForwardRef("Format", module=annotationlib).evaluate(), Format) + + with self.assertRaises(NameError): + # If globals are passed explicitly, we don't look at the module dict + ForwardRef("Format", module=annotationlib).evaluate(globals={}) + + def test_fwdref_value_is_cached(self): + fr = ForwardRef("hello") + with self.assertRaises(NameError): + fr.evaluate() + self.assertIs(fr.evaluate(globals={"hello": str}), str) + self.assertIs(fr.evaluate(), str) + class TestGetAnnotations(unittest.TestCase): def test_builtin_type(self): From 749ca9d43c41863eac90f2e07639f9161b1990b1 Mon Sep 17 00:00:00 2001 From: Jelle Zijlstra Date: Tue, 23 Jul 2024 22:35:54 -0700 Subject: [PATCH 2/2] Fix test --- Lib/typing.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/Lib/typing.py b/Lib/typing.py index 626053d8166160..080816f56d977c 100644 --- a/Lib/typing.py +++ b/Lib/typing.py @@ -474,6 +474,10 @@ def _eval_type(t, globalns, localns, type_params=_sentinel, *, recursive_guard=f _deprecation_warning_for_no_type_params_passed("typing._eval_type") type_params = () if isinstance(t, ForwardRef): + # If the forward_ref has __forward_module__ set, evaluate() infers the globals + # from the module, and it will probably pick better than the globals we have here. + if t.__forward_module__ is not None: + globalns = None return evaluate_forward_ref(t, globals=globalns, locals=localns, type_params=type_params, owner=owner, _recursive_guard=recursive_guard, format=format)