From 62cf252d3187486a0a854d520c1ab02d4f078edf Mon Sep 17 00:00:00 2001 From: CPython Developers <> Date: Mon, 29 Jul 2024 20:24:24 +0900 Subject: [PATCH 1/5] Update dataclasses from CPython3.12 --- Lib/dataclasses.py | 321 +++++--- .../__init__.py} | 735 +++++++++++++++++- .../test_dataclasses/dataclass_module_1.py | 32 + .../dataclass_module_1_str.py | 32 + .../test_dataclasses/dataclass_module_2.py | 32 + .../dataclass_module_2_str.py | 32 + .../test_dataclasses/dataclass_textanno.py | 12 + 7 files changed, 1071 insertions(+), 125 deletions(-) rename Lib/test/{test_dataclasses.py => test_dataclasses/__init__.py} (84%) create mode 100644 Lib/test/test_dataclasses/dataclass_module_1.py create mode 100644 Lib/test/test_dataclasses/dataclass_module_1_str.py create mode 100644 Lib/test/test_dataclasses/dataclass_module_2.py create mode 100644 Lib/test/test_dataclasses/dataclass_module_2_str.py create mode 100644 Lib/test/test_dataclasses/dataclass_textanno.py diff --git a/Lib/dataclasses.py b/Lib/dataclasses.py index e1687a117d..c32a30fb29 100644 --- a/Lib/dataclasses.py +++ b/Lib/dataclasses.py @@ -4,8 +4,8 @@ import types import inspect import keyword -import builtins import functools +import itertools import abc import _thread from types import FunctionType, GenericAlias @@ -222,6 +222,49 @@ def __repr__(self): # https://bugs.python.org/issue33453 for details. _MODULE_IDENTIFIER_RE = re.compile(r'^(?:\s*(\w+)\s*\.)?\s*(\w+)') +# Atomic immutable types which don't require any recursive handling and for which deepcopy +# returns the same object. We can provide a fast-path for these types in asdict and astuple. +_ATOMIC_TYPES = frozenset({ + # Common JSON Serializable types + types.NoneType, + bool, + int, + float, + str, + # Other common types + complex, + bytes, + # Other types that are also unaffected by deepcopy + types.EllipsisType, + types.NotImplementedType, + types.CodeType, + types.BuiltinFunctionType, + types.FunctionType, + type, + range, + property, +}) + +# This function's logic is copied from "recursive_repr" function in +# reprlib module to avoid dependency. +def _recursive_repr(user_function): + # Decorator to make a repr function return "..." for a recursive + # call. + repr_running = set() + + @functools.wraps(user_function) + def wrapper(self): + key = id(self), _thread.get_ident() + if key in repr_running: + return '...' + repr_running.add(key) + try: + result = user_function(self) + finally: + repr_running.discard(key) + return result + return wrapper + class InitVar: __slots__ = ('type', ) @@ -229,7 +272,7 @@ def __init__(self, type): self.type = type def __repr__(self): - if isinstance(self.type, type) and not isinstance(self.type, GenericAlias): + if isinstance(self.type, type): type_name = self.type.__name__ else: # typing objects, e.g. List[int] @@ -279,6 +322,7 @@ def __init__(self, default, default_factory, init, repr, hash, compare, self.kw_only = kw_only self._field_type = None + @_recursive_repr def __repr__(self): return ('Field(' f'name={self.name!r},' @@ -297,7 +341,7 @@ def __repr__(self): # This is used to support the PEP 487 __set_name__ protocol in the # case where we're using a field that contains a descriptor as a # default value. For details on __set_name__, see - # https://www.python.org/dev/peps/pep-0487/#implementation-details. + # https://peps.python.org/pep-0487/#implementation-details. # # Note that in _process_class, this Field object is overwritten # with the default value, so the end result is a descriptor that @@ -319,15 +363,25 @@ class _DataclassParams: 'order', 'unsafe_hash', 'frozen', + 'match_args', + 'kw_only', + 'slots', + 'weakref_slot', ) - def __init__(self, init, repr, eq, order, unsafe_hash, frozen): + def __init__(self, + init, repr, eq, order, unsafe_hash, frozen, + match_args, kw_only, slots, weakref_slot): self.init = init self.repr = repr self.eq = eq self.order = order self.unsafe_hash = unsafe_hash self.frozen = frozen + self.match_args = match_args + self.kw_only = kw_only + self.slots = slots + self.weakref_slot = weakref_slot def __repr__(self): return ('_DataclassParams(' @@ -336,7 +390,11 @@ def __repr__(self): f'eq={self.eq!r},' f'order={self.order!r},' f'unsafe_hash={self.unsafe_hash!r},' - f'frozen={self.frozen!r}' + f'frozen={self.frozen!r},' + f'match_args={self.match_args!r},' + f'kw_only={self.kw_only!r},' + f'slots={self.slots!r},' + f'weakref_slot={self.weakref_slot!r}' ')') @@ -388,27 +446,6 @@ def _tuple_str(obj_name, fields): return f'({",".join([f"{obj_name}.{f.name}" for f in fields])},)' -# This function's logic is copied from "recursive_repr" function in -# reprlib module to avoid dependency. -def _recursive_repr(user_function): - # Decorator to make a repr function return "..." for a recursive - # call. - repr_running = set() - - @functools.wraps(user_function) - def wrapper(self): - key = id(self), _thread.get_ident() - if key in repr_running: - return '...' - repr_running.add(key) - try: - result = user_function(self) - finally: - repr_running.discard(key) - return result - return wrapper - - def _create_fn(name, args, body, *, globals=None, locals=None, return_type=MISSING): # Note that we may mutate locals. Callers beware! @@ -418,14 +455,18 @@ def _create_fn(name, args, body, *, globals=None, locals=None, locals = {} return_annotation = '' if return_type is not MISSING: - locals['_return_type'] = return_type - return_annotation = '->_return_type' + locals['__dataclass_return_type__'] = return_type + return_annotation = '->__dataclass_return_type__' args = ','.join(args) body = '\n'.join(f' {b}' for b in body) # Compute the text of the entire function. txt = f' def {name}({args}){return_annotation}:\n{body}' + # Free variables in exec are resolved in the global namespace. + # The global namespace we have is user-provided, so we can't modify it for + # our purposes. So we put the things we need into locals and introduce a + # scope to allow the function we're creating to close over them. local_vars = ', '.join(locals.keys()) txt = f"def __create_fn__({local_vars}):\n{txt}\n return {name}" ns = {} @@ -449,14 +490,14 @@ def _field_init(f, frozen, globals, self_name, slots): # Return the text of the line in the body of __init__ that will # initialize this field. - default_name = f'_dflt_{f.name}' + default_name = f'__dataclass_dflt_{f.name}__' if f.default_factory is not MISSING: if f.init: # This field has a default factory. If a parameter is # given, use it. If not, call the factory. globals[default_name] = f.default_factory value = (f'{default_name}() ' - f'if {f.name} is _HAS_DEFAULT_FACTORY ' + f'if {f.name} is __dataclass_HAS_DEFAULT_FACTORY__ ' f'else {f.name}') else: # This is a field that's not in the __init__ params, but @@ -517,11 +558,11 @@ def _init_param(f): elif f.default is not MISSING: # There's a default, this will be the name that's used to look # it up. - default = f'=_dflt_{f.name}' + default = f'=__dataclass_dflt_{f.name}__' elif f.default_factory is not MISSING: # There's a factory function. Set a marker. - default = '=_HAS_DEFAULT_FACTORY' - return f'{f.name}:_type_{f.name}{default}' + default = '=__dataclass_HAS_DEFAULT_FACTORY__' + return f'{f.name}:__dataclass_type_{f.name}__{default}' def _init_fn(fields, std_fields, kw_only_fields, frozen, has_post_init, @@ -544,10 +585,9 @@ def _init_fn(fields, std_fields, kw_only_fields, frozen, has_post_init, raise TypeError(f'non-default argument {f.name!r} ' 'follows default argument') - locals = {f'_type_{f.name}': f.type for f in fields} + locals = {f'__dataclass_type_{f.name}__': f.type for f in fields} locals.update({ - 'MISSING': MISSING, - '_HAS_DEFAULT_FACTORY': _HAS_DEFAULT_FACTORY, + '__dataclass_HAS_DEFAULT_FACTORY__': _HAS_DEFAULT_FACTORY, '__dataclass_builtins_object__': object, }) @@ -598,21 +638,19 @@ def _repr_fn(fields, globals): def _frozen_get_del_attr(cls, fields, globals): locals = {'cls': cls, 'FrozenInstanceError': FrozenInstanceError} + condition = 'type(self) is cls' if fields: - fields_str = '(' + ','.join(repr(f.name) for f in fields) + ',)' - else: - # Special case for the zero-length tuple. - fields_str = '()' + condition += ' or name in {' + ', '.join(repr(f.name) for f in fields) + '}' return (_create_fn('__setattr__', ('self', 'name', 'value'), - (f'if type(self) is cls or name in {fields_str}:', + (f'if {condition}:', ' raise FrozenInstanceError(f"cannot assign to field {name!r}")', f'super(cls, self).__setattr__(name, value)'), locals=locals, globals=globals), _create_fn('__delattr__', ('self', 'name'), - (f'if type(self) is cls or name in {fields_str}:', + (f'if {condition}:', ' raise FrozenInstanceError(f"cannot delete field {name!r}")', f'super(cls, self).__delattr__(name)'), locals=locals, @@ -807,8 +845,10 @@ def _get_field(cls, a_name, a_type, default_kw_only): raise TypeError(f'field {f.name} is a ClassVar but specifies ' 'kw_only') - # For real fields, disallow mutable defaults for known types. - if f._field_type is _FIELD and isinstance(f.default, (list, dict, set)): + # For real fields, disallow mutable defaults. Use unhashable as a proxy + # indicator for mutability. Read the __hash__ attribute from the class, + # not the instance. + if f._field_type is _FIELD and f.default.__class__.__hash__ is None: raise ValueError(f'mutable default {type(f.default)} for field ' f'{f.name} is not allowed: use default_factory') @@ -879,7 +919,7 @@ def _hash_exception(cls, fields, globals): def _process_class(cls, init, repr, eq, order, unsafe_hash, frozen, - match_args, kw_only, slots): + match_args, kw_only, slots, weakref_slot): # Now that dicts retain insertion order, there's no reason to use # an ordered dict. I am leveraging that ordering here, because # derived class fields overwrite base class fields, but the order @@ -897,7 +937,9 @@ def _process_class(cls, init, repr, eq, order, unsafe_hash, frozen, globals = {} setattr(cls, _PARAMS, _DataclassParams(init, repr, eq, order, - unsafe_hash, frozen)) + unsafe_hash, frozen, + match_args, kw_only, + slots, weakref_slot)) # Find our base classes in reverse MRO order, and exclude # ourselves. In reversed order so that more derived classes @@ -916,10 +958,7 @@ def _process_class(cls, init, repr, eq, order, unsafe_hash, frozen, if getattr(b, _PARAMS).frozen: any_frozen_base = True - # Annotations that are defined in this class (not in base - # classes). If __annotations__ isn't present, then this class - # adds no new annotations. We use this to compute fields that are - # added by this class. + # Annotations defined specifically in this class (not in base classes). # # Fields are found from cls_annotations, which is guaranteed to be # ordered. Default values are from class attributes, if a field @@ -928,7 +967,7 @@ def _process_class(cls, init, repr, eq, order, unsafe_hash, frozen, # actual default value. Pseudo-fields ClassVars and InitVars are # included, despite the fact that they're not real fields. That's # dealt with later. - cls_annotations = cls.__dict__.get('__annotations__', {}) + cls_annotations = inspect.get_annotations(cls) # Now find fields in our class. While doing so, validate some # things, and set the default values (as class attributes) where @@ -1089,16 +1128,24 @@ def _process_class(cls, init, repr, eq, order, unsafe_hash, frozen, if not getattr(cls, '__doc__'): # Create a class doc-string. - cls.__doc__ = (cls.__name__ + - str(inspect.signature(cls)).replace(' -> None', '')) + try: + # In some cases fetching a signature is not possible. + # But, we surely should not fail in this case. + text_sig = str(inspect.signature(cls)).replace(' -> None', '') + except (TypeError, ValueError): + text_sig = '' + cls.__doc__ = (cls.__name__ + text_sig) if match_args: # I could probably compute this once _set_new_attribute(cls, '__match_args__', tuple(f.name for f in std_init_fields)) + # It's an error to specify weakref_slot if slots is False. + if weakref_slot and not slots: + raise TypeError('weakref_slot is True but slots is False') if slots: - cls = _add_slots(cls, frozen) + cls = _add_slots(cls, frozen, weakref_slot) abc.update_abstractmethods(cls) @@ -1106,7 +1153,7 @@ def _process_class(cls, init, repr, eq, order, unsafe_hash, frozen, # _dataclass_getstate and _dataclass_setstate are needed for pickling frozen -# classes with slots. These could be slighly more performant if we generated +# classes with slots. These could be slightly more performant if we generated # the code instead of iterating over fields. But that can be a project for # another day, if performance becomes an issue. def _dataclass_getstate(self): @@ -1119,7 +1166,30 @@ def _dataclass_setstate(self, state): object.__setattr__(self, field.name, value) -def _add_slots(cls, is_frozen): +def _get_slots(cls): + match cls.__dict__.get('__slots__'): + # `__dictoffset__` and `__weakrefoffset__` can tell us whether + # the base type has dict/weakref slots, in a way that works correctly + # for both Python classes and C extension types. Extension types + # don't use `__slots__` for slot creation + case None: + slots = [] + if getattr(cls, '__weakrefoffset__', -1) != 0: + slots.append('__weakref__') + if getattr(cls, '__dictrefoffset__', -1) != 0: + slots.append('__dict__') + yield from slots + case str(slot): + yield slot + # Slots may be any iterable, but we cannot handle an iterator + # because it will already be (partially) consumed. + case iterable if not hasattr(iterable, '__next__'): + yield from iterable + case _: + raise TypeError(f"Slots of '{cls.__name__}' cannot be determined") + + +def _add_slots(cls, is_frozen, weakref_slot): # Need to create a new class, since we can't set __slots__ # after a class has been created. @@ -1130,7 +1200,23 @@ def _add_slots(cls, is_frozen): # Create a new dict for our new class. cls_dict = dict(cls.__dict__) field_names = tuple(f.name for f in fields(cls)) - cls_dict['__slots__'] = field_names + # Make sure slots don't overlap with those in base classes. + inherited_slots = set( + itertools.chain.from_iterable(map(_get_slots, cls.__mro__[1:-1])) + ) + # The slots for our class. Remove slots from our base classes. Add + # '__weakref__' if weakref_slot was given, unless it is already present. + cls_dict["__slots__"] = tuple( + itertools.filterfalse( + inherited_slots.__contains__, + itertools.chain( + # gh-93521: '__weakref__' also needs to be filtered out if + # already present in inherited_slots + field_names, ('__weakref__',) if weakref_slot else () + ) + ), + ) + for field_name in field_names: # Remove our attributes, if present. They'll still be # available in _MARKER. @@ -1139,6 +1225,9 @@ def _add_slots(cls, is_frozen): # Remove __dict__ itself. cls_dict.pop('__dict__', None) + # Clear existing `__weakref__` descriptor, it belongs to a previous type: + cls_dict.pop('__weakref__', None) # gh-102069 + # And finally create the class. qualname = getattr(cls, '__qualname__', None) cls = type(cls)(cls.__name__, cls.__bases__, cls_dict) @@ -1147,33 +1236,35 @@ def _add_slots(cls, is_frozen): if is_frozen: # Need this for pickling frozen classes with slots. - cls.__getstate__ = _dataclass_getstate - cls.__setstate__ = _dataclass_setstate + if '__getstate__' not in cls_dict: + cls.__getstate__ = _dataclass_getstate + if '__setstate__' not in cls_dict: + cls.__setstate__ = _dataclass_setstate return cls def dataclass(cls=None, /, *, init=True, repr=True, eq=True, order=False, unsafe_hash=False, frozen=False, match_args=True, - kw_only=False, slots=False): - """Returns the same class as was passed in, with dunder methods - added based on the fields defined in the class. + kw_only=False, slots=False, weakref_slot=False): + """Add dunder methods based on the fields defined in the class. Examines PEP 526 __annotations__ to determine fields. - If init is true, an __init__() method is added to the class. If - repr is true, a __repr__() method is added. If order is true, rich + If init is true, an __init__() method is added to the class. If repr + is true, a __repr__() method is added. If order is true, rich comparison dunder methods are added. If unsafe_hash is true, a - __hash__() method function is added. If frozen is true, fields may - not be assigned to after instance creation. If match_args is true, - the __match_args__ tuple is added. If kw_only is true, then by - default all fields are keyword-only. If slots is true, an - __slots__ attribute is added. + __hash__() method is added. If frozen is true, fields may not be + assigned to after instance creation. If match_args is true, the + __match_args__ tuple is added. If kw_only is true, then by default + all fields are keyword-only. If slots is true, a new class with a + __slots__ attribute is returned. """ def wrap(cls): return _process_class(cls, init, repr, eq, order, unsafe_hash, - frozen, match_args, kw_only, slots) + frozen, match_args, kw_only, slots, + weakref_slot) # See if we're being called as @dataclass or @dataclass(). if cls is None: @@ -1195,7 +1286,7 @@ def fields(class_or_instance): try: fields = getattr(class_or_instance, _FIELDS) except AttributeError: - raise TypeError('must be called with a dataclass type or instance') + raise TypeError('must be called with a dataclass type or instance') from None # Exclude pseudo-fields. Note that fields is sorted by insertion # order, so the order of the tuple is as the fields were defined. @@ -1210,7 +1301,7 @@ def _is_dataclass_instance(obj): def is_dataclass(obj): """Returns True if obj is a dataclass or an instance of a dataclass.""" - cls = obj if isinstance(obj, type) and not isinstance(obj, GenericAlias) else type(obj) + cls = obj if isinstance(obj, type) else type(obj) return hasattr(cls, _FIELDS) @@ -1218,7 +1309,7 @@ def asdict(obj, *, dict_factory=dict): """Return the fields of a dataclass instance as a new dictionary mapping field names to field values. - Example usage: + Example usage:: @dataclass class C: @@ -1231,7 +1322,7 @@ class C: If given, 'dict_factory' will be used instead of built-in dict. The function applies recursively to field values that are dataclass instances. This will also look into built-in containers: - tuples, lists, and dicts. + tuples, lists, and dicts. Other objects are copied with 'copy.deepcopy()'. """ if not _is_dataclass_instance(obj): raise TypeError("asdict() should be called on dataclass instances") @@ -1239,12 +1330,21 @@ class C: def _asdict_inner(obj, dict_factory): - if _is_dataclass_instance(obj): - result = [] - for f in fields(obj): - value = _asdict_inner(getattr(obj, f.name), dict_factory) - result.append((f.name, value)) - return dict_factory(result) + if type(obj) in _ATOMIC_TYPES: + return obj + elif _is_dataclass_instance(obj): + # fast path for the common case + if dict_factory is dict: + return { + f.name: _asdict_inner(getattr(obj, f.name), dict) + for f in fields(obj) + } + else: + result = [] + for f in fields(obj): + value = _asdict_inner(getattr(obj, f.name), dict_factory) + result.append((f.name, value)) + return dict_factory(result) elif isinstance(obj, tuple) and hasattr(obj, '_fields'): # obj is a namedtuple. Recurse into it, but the returned # object is another namedtuple of the same type. This is @@ -1272,6 +1372,13 @@ def _asdict_inner(obj, dict_factory): # above). return type(obj)(_asdict_inner(v, dict_factory) for v in obj) elif isinstance(obj, dict): + if hasattr(type(obj), 'default_factory'): + # obj is a defaultdict, which has a different constructor from + # dict as it requires the default_factory as its first arg. + result = type(obj)(getattr(obj, 'default_factory')) + for k, v in obj.items(): + result[_asdict_inner(k, dict_factory)] = _asdict_inner(v, dict_factory) + return result return type(obj)((_asdict_inner(k, dict_factory), _asdict_inner(v, dict_factory)) for k, v in obj.items()) @@ -1289,13 +1396,13 @@ class C: x: int y: int - c = C(1, 2) - assert astuple(c) == (1, 2) + c = C(1, 2) + assert astuple(c) == (1, 2) If given, 'tuple_factory' will be used instead of built-in tuple. The function applies recursively to field values that are dataclass instances. This will also look into built-in containers: - tuples, lists, and dicts. + tuples, lists, and dicts. Other objects are copied with 'copy.deepcopy()'. """ if not _is_dataclass_instance(obj): @@ -1304,7 +1411,9 @@ class C: def _astuple_inner(obj, tuple_factory): - if _is_dataclass_instance(obj): + if type(obj) in _ATOMIC_TYPES: + return obj + elif _is_dataclass_instance(obj): result = [] for f in fields(obj): value = _astuple_inner(getattr(obj, f.name), tuple_factory) @@ -1324,7 +1433,15 @@ def _astuple_inner(obj, tuple_factory): # above). return type(obj)(_astuple_inner(v, tuple_factory) for v in obj) elif isinstance(obj, dict): - return type(obj)((_astuple_inner(k, tuple_factory), _astuple_inner(v, tuple_factory)) + obj_type = type(obj) + if hasattr(obj_type, 'default_factory'): + # obj is a defaultdict, which has a different constructor from + # dict as it requires the default_factory as its first arg. + result = obj_type(getattr(obj, 'default_factory')) + for k, v in obj.items(): + result[_astuple_inner(k, tuple_factory)] = _astuple_inner(v, tuple_factory) + return result + return obj_type((_astuple_inner(k, tuple_factory), _astuple_inner(v, tuple_factory)) for k, v in obj.items()) else: return copy.deepcopy(obj) @@ -1332,17 +1449,18 @@ def _astuple_inner(obj, tuple_factory): def make_dataclass(cls_name, fields, *, bases=(), namespace=None, init=True, repr=True, eq=True, order=False, unsafe_hash=False, - frozen=False, match_args=True, kw_only=False, slots=False): + frozen=False, match_args=True, kw_only=False, slots=False, + weakref_slot=False, module=None): """Return a new dynamically created dataclass. The dataclass name will be 'cls_name'. 'fields' is an iterable of either (name), (name, type) or (name, type, Field) objects. If type is omitted, use the string 'typing.Any'. Field objects are created by - the equivalent of calling 'field(name, type [, Field-info])'. + the equivalent of calling 'field(name, type [, Field-info])'.:: C = make_dataclass('C', ['x', ('y', int), ('z', int, field(init=False))], bases=(Base,)) - is equivalent to: + is equivalent to:: @dataclass class C(Base): @@ -1352,8 +1470,11 @@ class C(Base): For the bases and namespace parameters, see the builtin type() function. - The parameters init, repr, eq, order, unsafe_hash, and frozen are passed to - dataclass(). + The parameters init, repr, eq, order, unsafe_hash, frozen, match_args, kw_only, + slots, and weakref_slot are passed to dataclass(). + + If module parameter is defined, the '__module__' attribute of the dataclass is + set to that value. """ if namespace is None: @@ -1396,16 +1517,30 @@ def exec_body_callback(ns): # of generic dataclasses. cls = types.new_class(cls_name, bases, {}, exec_body_callback) + # For pickling to work, the __module__ variable needs to be set to the frame + # where the dataclass is created. + if module is None: + try: + module = sys._getframemodulename(1) or '__main__' + except AttributeError: + try: + module = sys._getframe(1).f_globals.get('__name__', '__main__') + except (AttributeError, ValueError): + pass + if module is not None: + cls.__module__ = module + # Apply the normal decorator. return dataclass(cls, init=init, repr=repr, eq=eq, order=order, unsafe_hash=unsafe_hash, frozen=frozen, - match_args=match_args, kw_only=kw_only, slots=slots) + match_args=match_args, kw_only=kw_only, slots=slots, + weakref_slot=weakref_slot) def replace(obj, /, **changes): """Return a new object replacing specified fields with new values. - This is especially useful for frozen classes. Example usage: + This is especially useful for frozen classes. Example usage:: @dataclass(frozen=True) class C: @@ -1415,7 +1550,7 @@ class C: c = C(1, 2) c1 = replace(c, x=3) assert c1.x == 3 and c1.y == 2 - """ + """ # We're going to mutate 'changes', but that's okay because it's a # new dict, even if called with 'replace(obj, **my_changes)'. diff --git a/Lib/test/test_dataclasses.py b/Lib/test/test_dataclasses/__init__.py similarity index 84% rename from Lib/test/test_dataclasses.py rename to Lib/test/test_dataclasses/__init__.py index 62ae5622a8..e15b34570e 100644 --- a/Lib/test/test_dataclasses.py +++ b/Lib/test/test_dataclasses/__init__.py @@ -5,20 +5,25 @@ from dataclasses import * import abc +import io import pickle import inspect import builtins import types +import weakref +import traceback import unittest from unittest.mock import Mock -from typing import ClassVar, Any, List, Union, Tuple, Dict, Generic, TypeVar, Optional, Protocol +from typing import ClassVar, Any, List, Union, Tuple, Dict, Generic, TypeVar, Optional, Protocol, DefaultDict from typing import get_type_hints -from collections import deque, OrderedDict, namedtuple +from collections import deque, OrderedDict, namedtuple, defaultdict from functools import total_ordering import typing # Needed for the string "typing.ClassVar[int]" to work as an annotation. import dataclasses # Needed for the string "dataclasses.InitVar[int]" to work as an annotation. +from test import support + # Just any custom exception we can catch. class CustomError(Exception): pass @@ -67,6 +72,50 @@ def test_field_repr(self): self.assertEqual(repr_output, expected_output) + def test_field_recursive_repr(self): + rec_field = field() + rec_field.type = rec_field + rec_field.name = "id" + repr_output = repr(rec_field) + + self.assertIn(",type=...,", repr_output) + + def test_recursive_annotation(self): + class C: + pass + + @dataclass + class D: + C: C = field() + + self.assertIn(",type=...,", repr(D.__dataclass_fields__["C"])) + + def test_dataclass_params_repr(self): + # Even though this is testing an internal implementation detail, + # it's testing a feature we want to make sure is correctly implemented + # for the sake of dataclasses itself + @dataclass(slots=True, frozen=True) + class Some: pass + + repr_output = repr(Some.__dataclass_params__) + expected_output = "_DataclassParams(init=True,repr=True," \ + "eq=True,order=False,unsafe_hash=False,frozen=True," \ + "match_args=True,kw_only=False," \ + "slots=True,weakref_slot=False)" + self.assertEqual(repr_output, expected_output) + + def test_dataclass_params_signature(self): + # Even though this is testing an internal implementation detail, + # it's testing a feature we want to make sure is correctly implemented + # for the sake of dataclasses itself + @dataclass + class Some: pass + + for param in inspect.signature(dataclass).parameters: + if param == 'cls': + continue + self.assertTrue(hasattr(Some.__dataclass_params__, param), msg=param) + def test_named_init_params(self): @dataclass class C: @@ -230,6 +279,31 @@ class C: c = C('foo') self.assertEqual(c.object, 'foo') + def test_field_named_BUILTINS_frozen(self): + # gh-96151 + @dataclass(frozen=True) + class C: + BUILTINS: int + c = C(5) + self.assertEqual(c.BUILTINS, 5) + + def test_field_with_special_single_underscore_names(self): + # gh-98886 + + @dataclass + class X: + x: int = field(default_factory=lambda: 111) + _dflt_x: int = field(default_factory=lambda: 222) + + X() + + @dataclass + class Y: + y: int = field(default_factory=lambda: 111) + _HAS_DEFAULT_FACTORY: int = 222 + + assert Y(y=222).y == 222 + def test_field_named_like_builtin(self): # Attribute names can shadow built-in names # since code generation is used. @@ -501,6 +575,32 @@ class C: self.assertNotEqual(C(3), C(4, 10)) self.assertNotEqual(C(3, 10), C(4, 10)) + def test_no_unhashable_default(self): + # See bpo-44674. + class Unhashable: + __hash__ = None + + unhashable_re = 'mutable default .* for field a is not allowed' + with self.assertRaisesRegex(ValueError, unhashable_re): + @dataclass + class A: + a: dict = {} + + with self.assertRaisesRegex(ValueError, unhashable_re): + @dataclass + class A: + a: Any = Unhashable() + + # Make sure that the machinery looking for hashability is using the + # class's __hash__, not the instance's __hash__. + with self.assertRaisesRegex(ValueError, unhashable_re): + unhashable = Unhashable() + # This shouldn't make the variable hashable. + unhashable.__hash__ = lambda: 0 + @dataclass + class A: + a: Any = unhashable + def test_hash_field_rules(self): # Test all 6 cases of: # hash=True/False/None @@ -659,8 +759,8 @@ class Point: class Subclass(typ): pass with self.assertRaisesRegex(ValueError, - f"mutable default .*Subclass'>" - ' for field z is not allowed' + "mutable default .*Subclass'>" + " for field z is not allowed" ): @dataclass class Point: @@ -990,6 +1090,65 @@ def __post_init__(cls): self.assertEqual((c.x, c.y), (3, 4)) self.assertTrue(C.flag) + def test_post_init_not_auto_added(self): + # See bpo-46757, which had proposed always adding __post_init__. As + # Raymond Hettinger pointed out, that would be a breaking change. So, + # add a test to make sure that the current behavior doesn't change. + + @dataclass + class A0: + pass + + @dataclass + class B0: + b_called: bool = False + def __post_init__(self): + self.b_called = True + + @dataclass + class C0(A0, B0): + c_called: bool = False + def __post_init__(self): + super().__post_init__() + self.c_called = True + + # Since A0 has no __post_init__, and one wasn't automatically added + # (because that's the rule: it's never added by @dataclass, it's only + # the class author that can add it), then B0.__post_init__ is called. + # Verify that. + c = C0() + self.assertTrue(c.b_called) + self.assertTrue(c.c_called) + + ###################################### + # Now, the same thing, except A1 defines __post_init__. + @dataclass + class A1: + def __post_init__(self): + pass + + @dataclass + class B1: + b_called: bool = False + def __post_init__(self): + self.b_called = True + + @dataclass + class C1(A1, B1): + c_called: bool = False + def __post_init__(self): + super().__post_init__() + self.c_called = True + + # This time, B1.__post_init__ isn't being called. This mimics what + # would happen if A1.__post_init__ had been automatically added, + # instead of manually added as we see here. This test isn't really + # needed, but I'm including it just to demonstrate the changed + # behavior when A1 does define __post_init__. + c = C1() + self.assertFalse(c.b_called) + self.assertTrue(c.c_called) + def test_class_var(self): # Make sure ClassVars are ignored in __init__, __repr__, etc. @dataclass @@ -1158,6 +1317,29 @@ def __post_init__(self, init_base, init_derived): c = C(10, 11, 50, 51) self.assertEqual(vars(c), {'x': 21, 'y': 101}) + def test_init_var_name_shadowing(self): + # Because dataclasses rely exclusively on `__annotations__` for + # handling InitVar and `__annotations__` preserves shadowed definitions, + # you can actually shadow an InitVar with a method or property. + # + # This only works when there is no default value; `dataclasses` uses the + # actual name (which will be bound to the shadowing method) for default + # values. + @dataclass + class C: + shadowed: InitVar[int] + _shadowed: int = field(init=False) + + def __post_init__(self, shadowed): + self._shadowed = shadowed * 2 + + @property + def shadowed(self): + return self._shadowed * 3 + + c = C(5) + self.assertEqual(c.shadowed, 30) + def test_default_factory(self): # Test a factory that returns a new list. @dataclass @@ -1365,6 +1547,24 @@ class A(types.GenericAlias): self.assertTrue(is_dataclass(type(a))) self.assertTrue(is_dataclass(a)) + def test_is_dataclass_inheritance(self): + @dataclass + class X: + y: int + + class Z(X): + pass + + self.assertTrue(is_dataclass(X), "X should be a dataclass") + self.assertTrue( + is_dataclass(Z), + "Z should be a dataclass because it inherits from X", + ) + z_instance = Z(y=5) + self.assertTrue( + is_dataclass(z_instance), + "z_instance should be a dataclass because it is an instance of Z", + ) def test_helper_fields_with_class_instance(self): # Check that we can call fields() on either a class or instance, @@ -1388,6 +1588,16 @@ class C: pass with self.assertRaisesRegex(TypeError, 'dataclass type or instance'): fields(C()) + def test_clean_traceback_from_fields_exception(self): + stdout = io.StringIO() + try: + fields(object) + except TypeError as exc: + traceback.print_exception(exc, file=stdout) + printed_traceback = stdout.getvalue() + self.assertNotIn("AttributeError", printed_traceback) + self.assertNotIn("__dataclass_fields__", printed_traceback) + def test_helper_asdict(self): # Basic tests for asdict(), it should return a new dictionary. @dataclass @@ -1565,6 +1775,21 @@ class C: self.assertIsNot(d['f'], t) self.assertEqual(d['f'].my_a(), 6) + def test_helper_asdict_defaultdict(self): + # Ensure asdict() does not throw exceptions when a + # defaultdict is a member of a dataclass + @dataclass + class C: + mp: DefaultDict[str, List] + + dd = defaultdict(list) + dd["x"].append(12) + c = C(mp=dd) + d = asdict(c) + + self.assertEqual(d, {"mp": {"x": [12]}}) + self.assertTrue(d["mp"] is not c.mp) # make sure defaultdict is copied + def test_helper_astuple(self): # Basic tests for astuple(), it should return a new tuple. @dataclass @@ -1692,6 +1917,21 @@ class C: t = astuple(c, tuple_factory=list) self.assertEqual(t, ['outer', T(1, ['inner', T(11, 12, 13)], 2)]) + def test_helper_astuple_defaultdict(self): + # Ensure astuple() does not throw exceptions when a + # defaultdict is a member of a dataclass + @dataclass + class C: + mp: DefaultDict[str, List] + + dd = defaultdict(list) + dd["x"].append(12) + c = C(mp=dd) + t = astuple(c) + + self.assertEqual(t, ({"x": [12]},)) + self.assertTrue(t[0] is not dd) # make sure defaultdict is copied + def test_dynamic_class_creation(self): cls_dict = {'__annotations__': {'x': int, 'y': int}, } @@ -2019,6 +2259,7 @@ def assertDocStrEqual(self, a, b): # whitespace stripped. self.assertEqual(a.replace(' ', ''), b.replace(' ', '')) + @support.requires_docstrings def test_existing_docstring_not_overridden(self): @dataclass class C: @@ -2086,8 +2327,6 @@ class C: self.assertDocStrEqual(C.__doc__, "C(x:List[int]=)") - # TODO: RUSTPYTHON - @unittest.expectedFailure def test_docstring_deque_field(self): @dataclass class C: @@ -2095,8 +2334,6 @@ class C: self.assertDocStrEqual(C.__doc__, "C(x:collections.deque)") - # TODO: RUSTPYTHON - @unittest.expectedFailure def test_docstring_deque_field_with_default_factory(self): @dataclass class C: @@ -2104,13 +2341,25 @@ class C: self.assertDocStrEqual(C.__doc__, "C(x:collections.deque=)") + def test_docstring_with_no_signature(self): + # See https://github.com/python/cpython/issues/103449 + class Meta(type): + __call__ = dict + class Base(metaclass=Meta): + pass + + @dataclass + class C(Base): + pass + + self.assertDocStrEqual(C.__doc__, "C") + class TestInit(unittest.TestCase): def test_base_has_init(self): class B: def __init__(self): self.z = 100 - pass # Make sure that declaring this class doesn't raise an error. # The issue is that we can't override __init__ in our class, @@ -2133,12 +2382,12 @@ class C(B): self.assertEqual(c.z, 100) def test_no_init(self): - dataclass(init=False) + @dataclass(init=False) class C: i: int = 0 self.assertEqual(C().i, 0) - dataclass(init=False) + @dataclass(init=False) class C: i: int = 2 def __init__(self): @@ -2617,6 +2866,19 @@ class C: c.i = 5 self.assertEqual(c.i, 10) + def test_frozen_empty(self): + @dataclass(frozen=True) + class C: + pass + + c = C() + self.assertFalse(hasattr(c, 'i')) + with self.assertRaises(FrozenInstanceError): + c.i = 5 + self.assertFalse(hasattr(c, 'i')) + with self.assertRaises(FrozenInstanceError): + del c.i + def test_inherit(self): @dataclass(frozen=True) class C: @@ -2740,6 +3002,37 @@ class S(D): self.assertEqual(s.y, 10) self.assertEqual(s.cached, True) + with self.assertRaises(FrozenInstanceError): + del s.x + self.assertEqual(s.x, 3) + with self.assertRaises(FrozenInstanceError): + del s.y + self.assertEqual(s.y, 10) + del s.cached + self.assertFalse(hasattr(s, 'cached')) + with self.assertRaises(AttributeError) as cm: + del s.cached + self.assertNotIsInstance(cm.exception, FrozenInstanceError) + + def test_non_frozen_normal_derived_from_empty_frozen(self): + @dataclass(frozen=True) + class D: + pass + + class S(D): + pass + + s = S() + self.assertFalse(hasattr(s, 'x')) + s.x = 5 + self.assertEqual(s.x, 5) + + del s.x + self.assertFalse(hasattr(s, 'x')) + with self.assertRaises(AttributeError) as cm: + del s.x + self.assertNotIsInstance(cm.exception, FrozenInstanceError) + def test_overwriting_frozen(self): # frozen uses __setattr__ and __delattr__. with self.assertRaisesRegex(TypeError, @@ -2780,8 +3073,6 @@ class C: class TestSlots(unittest.TestCase): - # TODO: RUSTPYTHON - @unittest.expectedFailure def test_simple(self): @dataclass class C: @@ -2823,8 +3114,6 @@ class Derived(Base): # We can add a new field to the derived instance. d.z = 10 - # TODO: RUSTPYTHON - @unittest.expectedFailure def test_generated_slots(self): @dataclass(slots=True) class C: @@ -2849,23 +3138,58 @@ class C: x: int def test_generated_slots_value(self): - @dataclass(slots=True) - class Base: - x: int - self.assertEqual(Base.__slots__, ('x',)) + class Root: + __slots__ = {'x'} + + class Root2(Root): + __slots__ = {'k': '...', 'j': ''} + + class Root3(Root2): + __slots__ = ['h'] + + class Root4(Root3): + __slots__ = 'aa' @dataclass(slots=True) - class Delivered(Base): + class Base(Root4): y: int + j: str + h: str - self.assertEqual(Delivered.__slots__, ('x', 'y')) + self.assertEqual(Base.__slots__, ('y', )) + + @dataclass(slots=True) + class Derived(Base): + aa: float + x: str + z: int + k: str + h: str + + self.assertEqual(Derived.__slots__, ('z', )) @dataclass - class AnotherDelivered(Base): + class AnotherDerived(Base): z: int - self.assertTrue('__slots__' not in AnotherDelivered.__dict__) + self.assertNotIn('__slots__', AnotherDerived.__dict__) + + def test_cant_inherit_from_iterator_slots(self): + + class Root: + __slots__ = iter(['a']) + + class Root2(Root): + __slots__ = ('b', ) + + with self.assertRaisesRegex( + TypeError, + "^Slots of 'Root' cannot be determined" + ): + @dataclass(slots=True) + class C(Root2): + x: int def test_returns_new_class(self): class A: @@ -2904,6 +3228,74 @@ def test_frozen_pickle(self): self.assertIsNot(obj, p) self.assertEqual(obj, p) + @dataclass(frozen=True, slots=True) + class FrozenSlotsGetStateClass: + foo: str + bar: int + + getstate_called: bool = field(default=False, compare=False) + + def __getstate__(self): + object.__setattr__(self, 'getstate_called', True) + return [self.foo, self.bar] + + @dataclass(frozen=True, slots=True) + class FrozenSlotsSetStateClass: + foo: str + bar: int + + setstate_called: bool = field(default=False, compare=False) + + def __setstate__(self, state): + object.__setattr__(self, 'setstate_called', True) + object.__setattr__(self, 'foo', state[0]) + object.__setattr__(self, 'bar', state[1]) + + @dataclass(frozen=True, slots=True) + class FrozenSlotsAllStateClass: + foo: str + bar: int + + getstate_called: bool = field(default=False, compare=False) + setstate_called: bool = field(default=False, compare=False) + + def __getstate__(self): + object.__setattr__(self, 'getstate_called', True) + return [self.foo, self.bar] + + def __setstate__(self, state): + object.__setattr__(self, 'setstate_called', True) + object.__setattr__(self, 'foo', state[0]) + object.__setattr__(self, 'bar', state[1]) + + def test_frozen_slots_pickle_custom_state(self): + for proto in range(pickle.HIGHEST_PROTOCOL + 1): + with self.subTest(proto=proto): + obj = self.FrozenSlotsGetStateClass('a', 1) + dumped = pickle.dumps(obj, protocol=proto) + + self.assertTrue(obj.getstate_called) + self.assertEqual(obj, pickle.loads(dumped)) + + for proto in range(pickle.HIGHEST_PROTOCOL + 1): + with self.subTest(proto=proto): + obj = self.FrozenSlotsSetStateClass('a', 1) + obj2 = pickle.loads(pickle.dumps(obj, protocol=proto)) + + self.assertTrue(obj2.setstate_called) + self.assertEqual(obj, obj2) + + for proto in range(pickle.HIGHEST_PROTOCOL + 1): + with self.subTest(proto=proto): + obj = self.FrozenSlotsAllStateClass('a', 1) + dumped = pickle.dumps(obj, protocol=proto) + + self.assertTrue(obj.getstate_called) + + obj2 = pickle.loads(dumped) + self.assertTrue(obj2.setstate_called) + self.assertEqual(obj, obj2) + def test_slots_with_default_no_init(self): # Originally reported in bpo-44649. @dataclass(slots=True) @@ -2926,6 +3318,249 @@ class A: self.assertEqual(obj.a, 'a') self.assertEqual(obj.b, 'b') + def test_slots_no_weakref(self): + @dataclass(slots=True) + class A: + # No weakref. + pass + + self.assertNotIn("__weakref__", A.__slots__) + a = A() + with self.assertRaisesRegex(TypeError, + "cannot create weak reference"): + weakref.ref(a) + with self.assertRaises(AttributeError): + a.__weakref__ + + def test_slots_weakref(self): + @dataclass(slots=True, weakref_slot=True) + class A: + a: int + + self.assertIn("__weakref__", A.__slots__) + a = A(1) + a_ref = weakref.ref(a) + + self.assertIs(a.__weakref__, a_ref) + + def test_slots_weakref_base_str(self): + class Base: + __slots__ = '__weakref__' + + @dataclass(slots=True) + class A(Base): + a: int + + # __weakref__ is in the base class, not A. But an A is still weakref-able. + self.assertIn("__weakref__", Base.__slots__) + self.assertNotIn("__weakref__", A.__slots__) + a = A(1) + weakref.ref(a) + + def test_slots_weakref_base_tuple(self): + # Same as test_slots_weakref_base, but use a tuple instead of a string + # in the base class. + class Base: + __slots__ = ('__weakref__',) + + @dataclass(slots=True) + class A(Base): + a: int + + # __weakref__ is in the base class, not A. But an A is still + # weakref-able. + self.assertIn("__weakref__", Base.__slots__) + self.assertNotIn("__weakref__", A.__slots__) + a = A(1) + weakref.ref(a) + + def test_weakref_slot_without_slot(self): + with self.assertRaisesRegex(TypeError, + "weakref_slot is True but slots is False"): + @dataclass(weakref_slot=True) + class A: + a: int + + def test_weakref_slot_make_dataclass(self): + A = make_dataclass('A', [('a', int),], slots=True, weakref_slot=True) + self.assertIn("__weakref__", A.__slots__) + a = A(1) + weakref.ref(a) + + # And make sure if raises if slots=True is not given. + with self.assertRaisesRegex(TypeError, + "weakref_slot is True but slots is False"): + B = make_dataclass('B', [('a', int),], weakref_slot=True) + + def test_weakref_slot_subclass_weakref_slot(self): + @dataclass(slots=True, weakref_slot=True) + class Base: + field: int + + # A *can* also specify weakref_slot=True if it wants to (gh-93521) + @dataclass(slots=True, weakref_slot=True) + class A(Base): + ... + + # __weakref__ is in the base class, not A. But an instance of A + # is still weakref-able. + self.assertIn("__weakref__", Base.__slots__) + self.assertNotIn("__weakref__", A.__slots__) + a = A(1) + a_ref = weakref.ref(a) + self.assertIs(a.__weakref__, a_ref) + + def test_weakref_slot_subclass_no_weakref_slot(self): + @dataclass(slots=True, weakref_slot=True) + class Base: + field: int + + @dataclass(slots=True) + class A(Base): + ... + + # __weakref__ is in the base class, not A. Even though A doesn't + # specify weakref_slot, it should still be weakref-able. + self.assertIn("__weakref__", Base.__slots__) + self.assertNotIn("__weakref__", A.__slots__) + a = A(1) + a_ref = weakref.ref(a) + self.assertIs(a.__weakref__, a_ref) + + def test_weakref_slot_normal_base_weakref_slot(self): + class Base: + __slots__ = ('__weakref__',) + + @dataclass(slots=True, weakref_slot=True) + class A(Base): + field: int + + # __weakref__ is in the base class, not A. But an instance of + # A is still weakref-able. + self.assertIn("__weakref__", Base.__slots__) + self.assertNotIn("__weakref__", A.__slots__) + a = A(1) + a_ref = weakref.ref(a) + self.assertIs(a.__weakref__, a_ref) + + + def test_dataclass_derived_weakref_slot(self): + class A: + pass + + @dataclass(slots=True, weakref_slot=True) + class B(A): + pass + + self.assertEqual(B.__slots__, ()) + B() + + def test_dataclass_derived_generic(self): + T = typing.TypeVar('T') + + @dataclass(slots=True, weakref_slot=True) + class A(typing.Generic[T]): + pass + self.assertEqual(A.__slots__, ('__weakref__',)) + self.assertTrue(A.__weakref__) + A() + + @dataclass(slots=True, weakref_slot=True) + class B[T2]: + pass + self.assertEqual(B.__slots__, ('__weakref__',)) + self.assertTrue(B.__weakref__) + B() + + def test_dataclass_derived_generic_from_base(self): + T = typing.TypeVar('T') + + class RawBase: ... + + @dataclass(slots=True, weakref_slot=True) + class C1(typing.Generic[T], RawBase): + pass + self.assertEqual(C1.__slots__, ()) + self.assertTrue(C1.__weakref__) + C1() + @dataclass(slots=True, weakref_slot=True) + class C2(RawBase, typing.Generic[T]): + pass + self.assertEqual(C2.__slots__, ()) + self.assertTrue(C2.__weakref__) + C2() + + @dataclass(slots=True, weakref_slot=True) + class D[T2](RawBase): + pass + self.assertEqual(D.__slots__, ()) + self.assertTrue(D.__weakref__) + D() + + def test_dataclass_derived_generic_from_slotted_base(self): + T = typing.TypeVar('T') + + class WithSlots: + __slots__ = ('a', 'b') + + @dataclass(slots=True, weakref_slot=True) + class E1(WithSlots, Generic[T]): + pass + self.assertEqual(E1.__slots__, ('__weakref__',)) + self.assertTrue(E1.__weakref__) + E1() + @dataclass(slots=True, weakref_slot=True) + class E2(Generic[T], WithSlots): + pass + self.assertEqual(E2.__slots__, ('__weakref__',)) + self.assertTrue(E2.__weakref__) + E2() + + @dataclass(slots=True, weakref_slot=True) + class F[T2](WithSlots): + pass + self.assertEqual(F.__slots__, ('__weakref__',)) + self.assertTrue(F.__weakref__) + F() + + def test_dataclass_derived_generic_from_slotted_base(self): + T = typing.TypeVar('T') + + class WithWeakrefSlot: + __slots__ = ('__weakref__',) + + @dataclass(slots=True, weakref_slot=True) + class G1(WithWeakrefSlot, Generic[T]): + pass + self.assertEqual(G1.__slots__, ()) + self.assertTrue(G1.__weakref__) + G1() + @dataclass(slots=True, weakref_slot=True) + class G2(Generic[T], WithWeakrefSlot): + pass + self.assertEqual(G2.__slots__, ()) + self.assertTrue(G2.__weakref__) + G2() + + @dataclass(slots=True, weakref_slot=True) + class H[T2](WithWeakrefSlot): + pass + self.assertEqual(H.__slots__, ()) + self.assertTrue(H.__weakref__) + H() + + def test_dataclass_slot_dict(self): + class WithDictSlot: + __slots__ = ('__dict__',) + + @dataclass(slots=True) + class A(WithDictSlot): ... + + self.assertEqual(A.__slots__, ()) + self.assertEqual(A().__dict__, {}) + A() + + class TestDescriptors(unittest.TestCase): def test_set_name(self): # See bpo-33141. @@ -3210,10 +3845,10 @@ class C: self.assertEqual(C(10).x, 10) def test_classvar_module_level_import(self): - from test import dataclass_module_1 - from test import dataclass_module_1_str - from test import dataclass_module_2 - from test import dataclass_module_2_str + from test.test_dataclasses import dataclass_module_1 + from test.test_dataclasses import dataclass_module_1_str + from test.test_dataclasses import dataclass_module_2 + from test.test_dataclasses import dataclass_module_2_str for m in (dataclass_module_1, dataclass_module_1_str, dataclass_module_2, dataclass_module_2_str, @@ -3251,7 +3886,7 @@ def test_classvar_module_level_import(self): self.assertNotIn('not_iv4', c.__dict__) def test_text_annotations(self): - from test import dataclass_textanno + from test.test_dataclasses import dataclass_textanno self.assertEqual( get_type_hints(dataclass_textanno.Bar), @@ -3262,6 +3897,15 @@ def test_text_annotations(self): 'return': type(None)}) +ByMakeDataClass = make_dataclass('ByMakeDataClass', [('x', int)]) +ManualModuleMakeDataClass = make_dataclass('ManualModuleMakeDataClass', + [('x', int)], + module=__name__) +WrongNameMakeDataclass = make_dataclass('Wrong', [('x', int)]) +WrongModuleMakeDataclass = make_dataclass('WrongModuleMakeDataclass', + [('x', int)], + module='custom') + class TestMakeDataclass(unittest.TestCase): def test_simple(self): C = make_dataclass('C', @@ -3371,6 +4015,36 @@ def test_no_types(self): 'y': int, 'z': 'typing.Any'}) + def test_module_attr(self): + self.assertEqual(ByMakeDataClass.__module__, __name__) + self.assertEqual(ByMakeDataClass(1).__module__, __name__) + self.assertEqual(WrongModuleMakeDataclass.__module__, "custom") + Nested = make_dataclass('Nested', []) + self.assertEqual(Nested.__module__, __name__) + self.assertEqual(Nested().__module__, __name__) + + def test_pickle_support(self): + for klass in [ByMakeDataClass, ManualModuleMakeDataClass]: + for proto in range(pickle.HIGHEST_PROTOCOL + 1): + with self.subTest(proto=proto): + self.assertEqual( + pickle.loads(pickle.dumps(klass, proto)), + klass, + ) + self.assertEqual( + pickle.loads(pickle.dumps(klass(1), proto)), + klass(1), + ) + + def test_cannot_be_pickled(self): + for klass in [WrongNameMakeDataclass, WrongModuleMakeDataclass]: + for proto in range(pickle.HIGHEST_PROTOCOL + 1): + with self.subTest(proto=proto): + with self.assertRaises(pickle.PickleError): + pickle.dumps(klass, proto) + with self.assertRaises(pickle.PickleError): + pickle.dumps(klass(1), proto) + def test_invalid_type_specification(self): for bad_field in [(), (1, 2, 3, 4), @@ -3674,7 +4348,6 @@ class Date(Ordered): self.assertFalse(inspect.isabstract(Date)) self.assertGreater(Date(2020,12,25), Date(2020,8,31)) - @unittest.expectedFailure def test_maintain_abc(self): class A(abc.ABC): @abc.abstractmethod @@ -3688,7 +4361,7 @@ class Date(A): day: 'int' self.assertTrue(inspect.isabstract(Date)) - msg = 'class Date with abstract method foo' + msg = "class Date without an implementation for abstract method 'foo'" self.assertRaisesRegex(TypeError, msg, Date) @@ -3840,8 +4513,6 @@ class C: b: int = field(kw_only=True) self.assertEqual(C(42, b=10).__match_args__, ('a',)) - # TODO: RUSTPYTHON - @unittest.expectedFailure def test_KW_ONLY(self): @dataclass class A: diff --git a/Lib/test/test_dataclasses/dataclass_module_1.py b/Lib/test/test_dataclasses/dataclass_module_1.py new file mode 100644 index 0000000000..87a33f8191 --- /dev/null +++ b/Lib/test/test_dataclasses/dataclass_module_1.py @@ -0,0 +1,32 @@ +#from __future__ import annotations +USING_STRINGS = False + +# dataclass_module_1.py and dataclass_module_1_str.py are identical +# except only the latter uses string annotations. + +import dataclasses +import typing + +T_CV2 = typing.ClassVar[int] +T_CV3 = typing.ClassVar + +T_IV2 = dataclasses.InitVar[int] +T_IV3 = dataclasses.InitVar + +@dataclasses.dataclass +class CV: + T_CV4 = typing.ClassVar + cv0: typing.ClassVar[int] = 20 + cv1: typing.ClassVar = 30 + cv2: T_CV2 + cv3: T_CV3 + not_cv4: T_CV4 # When using string annotations, this field is not recognized as a ClassVar. + +@dataclasses.dataclass +class IV: + T_IV4 = dataclasses.InitVar + iv0: dataclasses.InitVar[int] + iv1: dataclasses.InitVar + iv2: T_IV2 + iv3: T_IV3 + not_iv4: T_IV4 # When using string annotations, this field is not recognized as an InitVar. diff --git a/Lib/test/test_dataclasses/dataclass_module_1_str.py b/Lib/test/test_dataclasses/dataclass_module_1_str.py new file mode 100644 index 0000000000..6de490b7ad --- /dev/null +++ b/Lib/test/test_dataclasses/dataclass_module_1_str.py @@ -0,0 +1,32 @@ +from __future__ import annotations +USING_STRINGS = True + +# dataclass_module_1.py and dataclass_module_1_str.py are identical +# except only the latter uses string annotations. + +import dataclasses +import typing + +T_CV2 = typing.ClassVar[int] +T_CV3 = typing.ClassVar + +T_IV2 = dataclasses.InitVar[int] +T_IV3 = dataclasses.InitVar + +@dataclasses.dataclass +class CV: + T_CV4 = typing.ClassVar + cv0: typing.ClassVar[int] = 20 + cv1: typing.ClassVar = 30 + cv2: T_CV2 + cv3: T_CV3 + not_cv4: T_CV4 # When using string annotations, this field is not recognized as a ClassVar. + +@dataclasses.dataclass +class IV: + T_IV4 = dataclasses.InitVar + iv0: dataclasses.InitVar[int] + iv1: dataclasses.InitVar + iv2: T_IV2 + iv3: T_IV3 + not_iv4: T_IV4 # When using string annotations, this field is not recognized as an InitVar. diff --git a/Lib/test/test_dataclasses/dataclass_module_2.py b/Lib/test/test_dataclasses/dataclass_module_2.py new file mode 100644 index 0000000000..68fb733e29 --- /dev/null +++ b/Lib/test/test_dataclasses/dataclass_module_2.py @@ -0,0 +1,32 @@ +#from __future__ import annotations +USING_STRINGS = False + +# dataclass_module_2.py and dataclass_module_2_str.py are identical +# except only the latter uses string annotations. + +from dataclasses import dataclass, InitVar +from typing import ClassVar + +T_CV2 = ClassVar[int] +T_CV3 = ClassVar + +T_IV2 = InitVar[int] +T_IV3 = InitVar + +@dataclass +class CV: + T_CV4 = ClassVar + cv0: ClassVar[int] = 20 + cv1: ClassVar = 30 + cv2: T_CV2 + cv3: T_CV3 + not_cv4: T_CV4 # When using string annotations, this field is not recognized as a ClassVar. + +@dataclass +class IV: + T_IV4 = InitVar + iv0: InitVar[int] + iv1: InitVar + iv2: T_IV2 + iv3: T_IV3 + not_iv4: T_IV4 # When using string annotations, this field is not recognized as an InitVar. diff --git a/Lib/test/test_dataclasses/dataclass_module_2_str.py b/Lib/test/test_dataclasses/dataclass_module_2_str.py new file mode 100644 index 0000000000..b363d17c17 --- /dev/null +++ b/Lib/test/test_dataclasses/dataclass_module_2_str.py @@ -0,0 +1,32 @@ +from __future__ import annotations +USING_STRINGS = True + +# dataclass_module_2.py and dataclass_module_2_str.py are identical +# except only the latter uses string annotations. + +from dataclasses import dataclass, InitVar +from typing import ClassVar + +T_CV2 = ClassVar[int] +T_CV3 = ClassVar + +T_IV2 = InitVar[int] +T_IV3 = InitVar + +@dataclass +class CV: + T_CV4 = ClassVar + cv0: ClassVar[int] = 20 + cv1: ClassVar = 30 + cv2: T_CV2 + cv3: T_CV3 + not_cv4: T_CV4 # When using string annotations, this field is not recognized as a ClassVar. + +@dataclass +class IV: + T_IV4 = InitVar + iv0: InitVar[int] + iv1: InitVar + iv2: T_IV2 + iv3: T_IV3 + not_iv4: T_IV4 # When using string annotations, this field is not recognized as an InitVar. diff --git a/Lib/test/test_dataclasses/dataclass_textanno.py b/Lib/test/test_dataclasses/dataclass_textanno.py new file mode 100644 index 0000000000..3eb6c943d4 --- /dev/null +++ b/Lib/test/test_dataclasses/dataclass_textanno.py @@ -0,0 +1,12 @@ +from __future__ import annotations + +import dataclasses + + +class Foo: + pass + + +@dataclasses.dataclass +class Bar: + foo: Foo From f3253fc1c0ebd449ee54074e87c951c1cd23f530 Mon Sep 17 00:00:00 2001 From: Jeong YunWon Date: Mon, 29 Jul 2024 20:38:17 +0900 Subject: [PATCH 2/5] mark failing tests --- Lib/dataclasses.py | 20 ++++++++++-------- Lib/test/test_dataclasses/__init__.py | 30 +++++++++++++++++++++++++++ 2 files changed, 41 insertions(+), 9 deletions(-) diff --git a/Lib/dataclasses.py b/Lib/dataclasses.py index c32a30fb29..95f5486b1d 100644 --- a/Lib/dataclasses.py +++ b/Lib/dataclasses.py @@ -1167,26 +1167,28 @@ def _dataclass_setstate(self, state): def _get_slots(cls): - match cls.__dict__.get('__slots__'): + __slots__ = cls.__dict__.get('__slots__') # `__dictoffset__` and `__weakrefoffset__` can tell us whether # the base type has dict/weakref slots, in a way that works correctly # for both Python classes and C extension types. Extension types # don't use `__slots__` for slot creation - case None: + if __slots__ is None: slots = [] if getattr(cls, '__weakrefoffset__', -1) != 0: slots.append('__weakref__') if getattr(cls, '__dictrefoffset__', -1) != 0: slots.append('__dict__') yield from slots - case str(slot): + elif type(__slots__) is str: + slot = __slots__ yield slot - # Slots may be any iterable, but we cannot handle an iterator - # because it will already be (partially) consumed. - case iterable if not hasattr(iterable, '__next__'): - yield from iterable - case _: - raise TypeError(f"Slots of '{cls.__name__}' cannot be determined") + # Slots may be any iterable, but we cannot handle an iterator + # because it will already be (partially) consumed. + elif not hasattr(__slots__, '__next__'): + iterable = __slots__ + yield from iterable + else: + raise TypeError(f"Slots of '{cls.__name__}' cannot be determined") def _add_slots(cls, is_frozen, weakref_slot): diff --git a/Lib/test/test_dataclasses/__init__.py b/Lib/test/test_dataclasses/__init__.py index e15b34570e..5b0245401f 100644 --- a/Lib/test/test_dataclasses/__init__.py +++ b/Lib/test/test_dataclasses/__init__.py @@ -1279,6 +1279,8 @@ def __post_init__(self, init_param): c = C(init_param=10) self.assertEqual(c.x, 20) + # TODO: RUSTPYTHON + @unittest.expectedFailure def test_init_var_preserve_type(self): self.assertEqual(InitVar[int].type, int) @@ -1537,6 +1539,8 @@ class B: with self.assertRaisesRegex(TypeError, 'should be called on dataclass instances'): replace(obj, x=0) + # TODO: RUSTPYTHON + @unittest.expectedFailure def test_is_dataclass_genericalias(self): @dataclass class A(types.GenericAlias): @@ -1775,6 +1779,8 @@ class C: self.assertIsNot(d['f'], t) self.assertEqual(d['f'].my_a(), 6) + # TODO: RUSTPYTHON + @unittest.expectedFailure def test_helper_asdict_defaultdict(self): # Ensure asdict() does not throw exceptions when a # defaultdict is a member of a dataclass @@ -1917,6 +1923,8 @@ class C: t = astuple(c, tuple_factory=list) self.assertEqual(t, ['outer', T(1, ['inner', T(11, 12, 13)], 2)]) + # TODO: RUSTPYTHON + @unittest.expectedFailure def test_helper_astuple_defaultdict(self): # Ensure astuple() does not throw exceptions when a # defaultdict is a member of a dataclass @@ -2327,6 +2335,8 @@ class C: self.assertDocStrEqual(C.__doc__, "C(x:List[int]=)") + # TODO: RUSTPYTHON + @unittest.expectedFailure def test_docstring_deque_field(self): @dataclass class C: @@ -2334,6 +2344,8 @@ class C: self.assertDocStrEqual(C.__doc__, "C(x:collections.deque)") + # TODO: RUSTPYTHON + @unittest.expectedFailure def test_docstring_deque_field_with_default_factory(self): @dataclass class C: @@ -3073,6 +3085,8 @@ class C: class TestSlots(unittest.TestCase): + # TODO: RUSTPYTHON + @unittest.expectedFailure def test_simple(self): @dataclass class C: @@ -3114,6 +3128,8 @@ class Derived(Base): # We can add a new field to the derived instance. d.z = 10 + # TODO: RUSTPYTHON + @unittest.expectedFailure def test_generated_slots(self): @dataclass(slots=True) class C: @@ -3318,6 +3334,8 @@ class A: self.assertEqual(obj.a, 'a') self.assertEqual(obj.b, 'b') + # TODO: RUSTPYTHON + @unittest.expectedFailure def test_slots_no_weakref(self): @dataclass(slots=True) class A: @@ -3332,6 +3350,8 @@ class A: with self.assertRaises(AttributeError): a.__weakref__ + # TODO: RUSTPYTHON + @unittest.expectedFailure def test_slots_weakref(self): @dataclass(slots=True, weakref_slot=True) class A: @@ -3392,6 +3412,8 @@ def test_weakref_slot_make_dataclass(self): "weakref_slot is True but slots is False"): B = make_dataclass('B', [('a', int),], weakref_slot=True) + # TODO: RUSTPYTHON + @unittest.expectedFailure def test_weakref_slot_subclass_weakref_slot(self): @dataclass(slots=True, weakref_slot=True) class Base: @@ -3410,6 +3432,8 @@ class A(Base): a_ref = weakref.ref(a) self.assertIs(a.__weakref__, a_ref) + # TODO: RUSTPYTHON + @unittest.expectedFailure def test_weakref_slot_subclass_no_weakref_slot(self): @dataclass(slots=True, weakref_slot=True) class Base: @@ -3427,6 +3451,8 @@ class A(Base): a_ref = weakref.ref(a) self.assertIs(a.__weakref__, a_ref) + # TODO: RUSTPYTHON + @unittest.expectedFailure def test_weakref_slot_normal_base_weakref_slot(self): class Base: __slots__ = ('__weakref__',) @@ -3472,6 +3498,8 @@ class B[T2]: self.assertTrue(B.__weakref__) B() + # TODO: RUSTPYTHON + @unittest.expectedFailure def test_dataclass_derived_generic_from_base(self): T = typing.TypeVar('T') @@ -4513,6 +4541,8 @@ class C: b: int = field(kw_only=True) self.assertEqual(C(42, b=10).__match_args__, ('a',)) + # TODO: RUSTPYTHON + @unittest.expectedFailure def test_KW_ONLY(self): @dataclass class A: From 19be12094b81d4b50ad001f83e30d4ad637f17db Mon Sep 17 00:00:00 2001 From: CPython Developers <> Date: Sun, 28 Jul 2024 22:19:57 +0900 Subject: [PATCH 3/5] Update inspect from CPython 3.12 --- Lib/inspect.py | 1224 +++++++++++++++++++++++++++++------------------- 1 file changed, 732 insertions(+), 492 deletions(-) diff --git a/Lib/inspect.py b/Lib/inspect.py index a84f3346b3..4ff8791282 100644 --- a/Lib/inspect.py +++ b/Lib/inspect.py @@ -24,6 +24,8 @@ stack(), trace() - get info about frames on the stack or in a traceback signature() - get a Signature object for the callable + + get_annotations() - safely compute an object's annotations """ # This module is in the public domain. No warranties. @@ -31,7 +33,116 @@ __author__ = ('Ka-Ping Yee ', 'Yury Selivanov ') +__all__ = [ + "AGEN_CLOSED", + "AGEN_CREATED", + "AGEN_RUNNING", + "AGEN_SUSPENDED", + "ArgInfo", + "Arguments", + "Attribute", + "BlockFinder", + "BoundArguments", + "BufferFlags", + "CORO_CLOSED", + "CORO_CREATED", + "CORO_RUNNING", + "CORO_SUSPENDED", + "CO_ASYNC_GENERATOR", + "CO_COROUTINE", + "CO_GENERATOR", + "CO_ITERABLE_COROUTINE", + "CO_NESTED", + "CO_NEWLOCALS", + "CO_NOFREE", + "CO_OPTIMIZED", + "CO_VARARGS", + "CO_VARKEYWORDS", + "ClassFoundException", + "ClosureVars", + "EndOfBlock", + "FrameInfo", + "FullArgSpec", + "GEN_CLOSED", + "GEN_CREATED", + "GEN_RUNNING", + "GEN_SUSPENDED", + "Parameter", + "Signature", + "TPFLAGS_IS_ABSTRACT", + "Traceback", + "classify_class_attrs", + "cleandoc", + "currentframe", + "findsource", + "formatannotation", + "formatannotationrelativeto", + "formatargvalues", + "get_annotations", + "getabsfile", + "getargs", + "getargvalues", + "getasyncgenlocals", + "getasyncgenstate", + "getattr_static", + "getblock", + "getcallargs", + "getclasstree", + "getclosurevars", + "getcomments", + "getcoroutinelocals", + "getcoroutinestate", + "getdoc", + "getfile", + "getframeinfo", + "getfullargspec", + "getgeneratorlocals", + "getgeneratorstate", + "getinnerframes", + "getlineno", + "getmembers", + "getmembers_static", + "getmodule", + "getmodulename", + "getmro", + "getouterframes", + "getsource", + "getsourcefile", + "getsourcelines", + "indentsize", + "isabstract", + "isasyncgen", + "isasyncgenfunction", + "isawaitable", + "isbuiltin", + "isclass", + "iscode", + "iscoroutine", + "iscoroutinefunction", + "isdatadescriptor", + "isframe", + "isfunction", + "isgenerator", + "isgeneratorfunction", + "isgetsetdescriptor", + "ismemberdescriptor", + "ismethod", + "ismethoddescriptor", + "ismethodwrapper", + "ismodule", + "isroutine", + "istraceback", + "markcoroutinefunction", + "signature", + "stack", + "trace", + "unwrap", + "walktree", +] + + import abc +import ast import dis import collections.abc import enum @@ -44,9 +155,9 @@ import tokenize import token import types -import warnings import functools import builtins +from keyword import iskeyword from operator import attrgetter from collections import namedtuple, OrderedDict @@ -55,36 +166,138 @@ mod_dict = globals() for k, v in dis.COMPILER_FLAG_NAMES.items(): mod_dict["CO_" + v] = k +del k, v, mod_dict # See Include/object.h TPFLAGS_IS_ABSTRACT = 1 << 20 + +def get_annotations(obj, *, globals=None, locals=None, eval_str=False): + """Compute the annotations dict for an object. + + obj may be a callable, class, or module. + Passing in an object of any other type raises TypeError. + + Returns a dict. get_annotations() returns a new dict every time + it's called; calling it twice on the same object will return two + different but equivalent dicts. + + This function handles several details for you: + + * If eval_str is true, values of type str will + be un-stringized using eval(). This is intended + for use with stringized annotations + ("from __future__ import annotations"). + * If obj doesn't have an annotations dict, returns an + empty dict. (Functions and methods always have an + annotations dict; classes, modules, and other types of + callables may not.) + * Ignores inherited annotations on classes. If a class + doesn't have its own annotations dict, returns an empty dict. + * All accesses to object members and dict values are done + using getattr() and dict.get() for safety. + * Always, always, always returns a freshly-created dict. + + eval_str controls whether or not values of type str are replaced + with the result of calling eval() on those values: + + * If eval_str is true, eval() is called on values of type str. + * If eval_str is false (the default), values of type str are unchanged. + + globals and locals are passed in to eval(); see the documentation + for eval() for more information. If either globals or locals is + None, this function may replace that value with a context-specific + default, contingent on type(obj): + + * If obj is a module, globals defaults to obj.__dict__. + * If obj is a class, globals defaults to + sys.modules[obj.__module__].__dict__ and locals + defaults to the obj class namespace. + * If obj is a callable, globals defaults to obj.__globals__, + although if obj is a wrapped function (using + functools.update_wrapper()) it is first unwrapped. + """ + if isinstance(obj, type): + # class + obj_dict = getattr(obj, '__dict__', None) + if obj_dict and hasattr(obj_dict, 'get'): + ann = obj_dict.get('__annotations__', None) + if isinstance(ann, types.GetSetDescriptorType): + ann = None + else: + ann = None + + obj_globals = None + module_name = getattr(obj, '__module__', None) + if module_name: + module = sys.modules.get(module_name, None) + if module: + obj_globals = getattr(module, '__dict__', None) + obj_locals = dict(vars(obj)) + unwrap = obj + elif isinstance(obj, types.ModuleType): + # module + ann = getattr(obj, '__annotations__', None) + obj_globals = getattr(obj, '__dict__') + obj_locals = None + unwrap = None + elif callable(obj): + # this includes types.Function, types.BuiltinFunctionType, + # types.BuiltinMethodType, functools.partial, functools.singledispatch, + # "class funclike" from Lib/test/test_inspect... on and on it goes. + ann = getattr(obj, '__annotations__', None) + obj_globals = getattr(obj, '__globals__', None) + obj_locals = None + unwrap = obj + else: + raise TypeError(f"{obj!r} is not a module, class, or callable.") + + if ann is None: + return {} + + if not isinstance(ann, dict): + raise ValueError(f"{obj!r}.__annotations__ is neither a dict nor None") + + if not ann: + return {} + + if not eval_str: + return dict(ann) + + if unwrap is not None: + while True: + if hasattr(unwrap, '__wrapped__'): + unwrap = unwrap.__wrapped__ + continue + if isinstance(unwrap, functools.partial): + unwrap = unwrap.func + continue + break + if hasattr(unwrap, "__globals__"): + obj_globals = unwrap.__globals__ + + if globals is None: + globals = obj_globals + if locals is None: + locals = obj_locals + + return_value = {key: + value if not isinstance(value, str) else eval(value, globals, locals) + for key, value in ann.items() } + return return_value + + # ----------------------------------------------------------- type-checking def ismodule(object): - """Return true if the object is a module. - - Module objects provide these attributes: - __cached__ pathname to byte compiled file - __doc__ documentation string - __file__ filename (missing for built-in modules)""" + """Return true if the object is a module.""" return isinstance(object, types.ModuleType) def isclass(object): - """Return true if the object is a class. - - Class objects provide these attributes: - __doc__ documentation string - __module__ name of module in which this class was defined""" + """Return true if the object is a class.""" return isinstance(object, type) def ismethod(object): - """Return true if the object is an instance method. - - Instance method objects provide these attributes: - __doc__ documentation string - __name__ name with which this method was defined - __func__ function object containing implementation of method - __self__ instance to which this method is bound""" + """Return true if the object is an instance method.""" return isinstance(object, types.MethodType) def ismethoddescriptor(object): @@ -175,7 +388,7 @@ def _has_code_flag(f, flag): while ismethod(f): f = f.__func__ f = functools._unwrap_partial(f) - if not isfunction(f): + if not (isfunction(f) or _signature_is_functionlike(f)): return False return bool(f.__code__.co_flags & flag) @@ -186,12 +399,31 @@ def isgeneratorfunction(obj): See help(isfunction) for a list of attributes.""" return _has_code_flag(obj, CO_GENERATOR) +# A marker for markcoroutinefunction and iscoroutinefunction. +_is_coroutine_marker = object() + +def _has_coroutine_mark(f): + while ismethod(f): + f = f.__func__ + f = functools._unwrap_partial(f) + return getattr(f, "_is_coroutine_marker", None) is _is_coroutine_marker + +def markcoroutinefunction(func): + """ + Decorator to ensure callable is recognised as a coroutine function. + """ + if hasattr(func, '__func__'): + func = func.__func__ + func._is_coroutine_marker = _is_coroutine_marker + return func + def iscoroutinefunction(obj): """Return true if the object is a coroutine function. - Coroutine functions are defined with "async def" syntax. + Coroutine functions are normally defined with "async def" syntax, but may + be marked via markcoroutinefunction. """ - return _has_code_flag(obj, CO_COROUTINE) + return _has_code_flag(obj, CO_COROUTINE) or _has_coroutine_mark(obj) def isasyncgenfunction(obj): """Return true if the object is an asynchronous generator function. @@ -276,7 +508,7 @@ def iscode(object): co_kwonlyargcount number of keyword only arguments (not including ** arg) co_lnotab encoded mapping of line numbers to bytecode indices co_name name with which this code object was defined - co_names tuple of names of local variables + co_names tuple of names other than arguments and function locals co_nlocals number of local variables co_stacksize virtual machine stack space required co_varnames tuple of names of arguments and local variables""" @@ -291,12 +523,17 @@ def isbuiltin(object): __self__ instance to which a method is bound, or None""" return isinstance(object, types.BuiltinFunctionType) +def ismethodwrapper(object): + """Return true if the object is a method wrapper.""" + return isinstance(object, types.MethodWrapperType) + def isroutine(object): """Return true if the object is any kind of function or method.""" return (isbuiltin(object) or isfunction(object) or ismethod(object) - or ismethoddescriptor(object)) + or ismethoddescriptor(object) + or ismethodwrapper(object)) def isabstract(object): """Return true if the object is an abstract base class (ABC).""" @@ -325,32 +562,30 @@ def isabstract(object): return True return False -def getmembers(object, predicate=None): - """Return all members of an object as (name, value) pairs sorted by name. - Optionally, only return members that satisfy a given predicate.""" - if isclass(object): - mro = (object,) + getmro(object) - else: - mro = () +def _getmembers(object, predicate, getter): results = [] processed = set() names = dir(object) - # :dd any DynamicClassAttributes to the list of names if object is a class; - # this may result in duplicate entries if, for example, a virtual - # attribute with the same name as a DynamicClassAttribute exists - try: - for base in object.__bases__: - for k, v in base.__dict__.items(): - if isinstance(v, types.DynamicClassAttribute): - names.append(k) - except AttributeError: - pass + if isclass(object): + mro = getmro(object) + # add any DynamicClassAttributes to the list of names if object is a class; + # this may result in duplicate entries if, for example, a virtual + # attribute with the same name as a DynamicClassAttribute exists + try: + for base in object.__bases__: + for k, v in base.__dict__.items(): + if isinstance(v, types.DynamicClassAttribute): + names.append(k) + except AttributeError: + pass + else: + mro = () for key in names: # First try to get the value via getattr. Some descriptors don't # like calling their __get__ (see bug #1785), so fall back to # looking in the __dict__. try: - value = getattr(object, key) + value = getter(object, key) # handle the duplicate key if key in processed: raise AttributeError @@ -369,6 +604,25 @@ def getmembers(object, predicate=None): results.sort(key=lambda pair: pair[0]) return results +def getmembers(object, predicate=None): + """Return all members of an object as (name, value) pairs sorted by name. + Optionally, only return members that satisfy a given predicate.""" + return _getmembers(object, predicate, getattr) + +def getmembers_static(object, predicate=None): + """Return all members of an object as (name, value) pairs sorted by name + without triggering dynamic lookup via the descriptor protocol, + __getattr__ or __getattribute__. Optionally, only return members that + satisfy a given predicate. + + Note: this function may not be able to retrieve all members + that getmembers can fetch (like dynamically created attributes) + and may find members that getmembers can't (like descriptors + that raise AttributeError). It can also return descriptor objects + instead of instance members in some cases. + """ + return _getmembers(object, predicate, getattr_static) + Attribute = namedtuple('Attribute', 'name kind defining_class object') def classify_class_attrs(cls): @@ -409,7 +663,7 @@ def classify_class_attrs(cls): # attribute with the same name as a DynamicClassAttribute exists. for base in mro: for k, v in base.__dict__.items(): - if isinstance(v, types.DynamicClassAttribute): + if isinstance(v, types.DynamicClassAttribute) and v.fget is not None: names.append(k) result = [] processed = set() @@ -432,7 +686,7 @@ def classify_class_attrs(cls): if name == '__dict__': raise Exception("__dict__ is special, don't want the proxy") get_obj = getattr(cls, name) - except Exception as exc: + except Exception: pass else: homecls = getattr(get_obj, "__objclass__", homecls) @@ -509,18 +763,14 @@ def unwrap(func, *, stop=None): :exc:`ValueError` is raised if a cycle is encountered. """ - if stop is None: - def _is_wrapper(f): - return hasattr(f, '__wrapped__') - else: - def _is_wrapper(f): - return hasattr(f, '__wrapped__') and not stop(f) f = func # remember the original func for error reporting # Memoise by id to tolerate non-hashable objects, but store objects to # ensure they aren't destroyed, which would allow their IDs to be reused. memo = {id(f): f} recursion_limit = sys.getrecursionlimit() - while _is_wrapper(func): + while not isinstance(func, type) and hasattr(func, '__wrapped__'): + if stop is not None and stop(func): + break func = func.__wrapped__ id_func = id(func) if (id_func in memo) or (len(memo) >= recursion_limit): @@ -665,6 +915,8 @@ def getfile(object): module = sys.modules.get(object.__module__) if getattr(module, '__file__', None): return module.__file__ + if object.__module__ == '__main__': + raise OSError('source code not available') raise TypeError('{!r} is a built-in class'.format(object)) if ismethod(object): object = object.__func__ @@ -705,13 +957,16 @@ def getsourcefile(object): elif any(filename.endswith(s) for s in importlib.machinery.EXTENSION_SUFFIXES): return None + # return a filename found in the linecache even if it doesn't exist on disk + if filename in linecache.cache: + return filename if os.path.exists(filename): return filename # only return a non-existent filename if the module has a PEP 302 loader - if getattr(getmodule(object, filename), '__loader__', None) is not None: + module = getmodule(object, filename) + if getattr(module, '__loader__', None) is not None: return filename - # or it is in the linecache - if filename in linecache.cache: + elif getattr(getattr(module, "__spec__", None), "loader", None) is not None: return filename def getabsfile(object, _filename=None): @@ -738,13 +993,13 @@ def getmodule(object, _filename=None): # Try the cache again with the absolute file name try: file = getabsfile(object, _filename) - except TypeError: + except (TypeError, FileNotFoundError): return None if file in modulesbyfile: return sys.modules.get(modulesbyfile[file]) # Update the filename to module name cache and check yet again # Copy sys.modules in order to cope with changes while iterating - for modname, module in list(sys.modules.items()): + for modname, module in sys.modules.copy().items(): if ismodule(module) and hasattr(module, '__file__'): f = module.__file__ if f == _filesbymodname.get(modname, None): @@ -772,6 +1027,42 @@ def getmodule(object, _filename=None): if builtinobject is object: return builtin + +class ClassFoundException(Exception): + pass + + +class _ClassFinder(ast.NodeVisitor): + + def __init__(self, qualname): + self.stack = [] + self.qualname = qualname + + def visit_FunctionDef(self, node): + self.stack.append(node.name) + self.stack.append('') + self.generic_visit(node) + self.stack.pop() + self.stack.pop() + + visit_AsyncFunctionDef = visit_FunctionDef + + def visit_ClassDef(self, node): + self.stack.append(node.name) + if self.qualname == '.'.join(self.stack): + # Return the decorator for the class if present + if node.decorator_list: + line_number = node.decorator_list[0].lineno + else: + line_number = node.lineno + + # decrement by one since lines starts with indexing by zero + line_number -= 1 + raise ClassFoundException(line_number) + self.generic_visit(node) + self.stack.pop() + + def findsource(object): """Return the entire source file and starting line number for an object. @@ -804,25 +1095,15 @@ def findsource(object): return lines, 0 if isclass(object): - name = object.__name__ - pat = re.compile(r'^(\s*)class\s*' + name + r'\b') - # make some effort to find the best matching class definition: - # use the one with the least indentation, which is the one - # that's most probably not inside a function definition. - candidates = [] - for i in range(len(lines)): - match = pat.match(lines[i]) - if match: - # if it's at toplevel, it's already the best one - if lines[i][0] == 'c': - return lines, i - # else add whitespace to candidate list - candidates.append((match.group(1), i)) - if candidates: - # this will sort by whitespace, and by line number, - # less whitespace first - candidates.sort() - return lines, candidates[0][1] + qualname = object.__qualname__ + source = ''.join(lines) + tree = ast.parse(source) + class_finder = _ClassFinder(qualname) + try: + class_finder.visit(tree) + except ClassFoundException as e: + line_number = e.args[0] + return lines, line_number else: raise OSError('could not find class definition') @@ -840,7 +1121,12 @@ def findsource(object): lnum = object.co_firstlineno - 1 pat = re.compile(r'^(\s*def\s)|(\s*async\s+def\s)|(.*(? 0: - if pat.match(lines[lnum]): break + try: + line = lines[lnum] + except IndexError: + raise OSError('lineno is out of bounds') + if pat.match(line): + break lnum = lnum - 1 return lines, lnum raise OSError('could not find code object') @@ -900,8 +1186,8 @@ def __init__(self): self.started = False self.passline = False self.indecorator = False - self.decoratorhasargs = False self.last = 1 + self.body_col0 = None def tokeneater(self, type, token, srowcol, erowcol, line): if not self.started and not self.indecorator: @@ -914,13 +1200,6 @@ def tokeneater(self, type, token, srowcol, erowcol, line): self.islambda = True self.started = True self.passline = True # skip to the end of the line - elif token == "(": - if self.indecorator: - self.decoratorhasargs = True - elif token == ")": - if self.indecorator: - self.indecorator = False - self.decoratorhasargs = False elif type == tokenize.NEWLINE: self.passline = False # stop skipping when a NEWLINE is seen self.last = srowcol[0] @@ -928,11 +1207,13 @@ def tokeneater(self, type, token, srowcol, erowcol, line): raise EndOfBlock # hitting a NEWLINE when in a decorator without args # ends the decorator - if self.indecorator and not self.decoratorhasargs: + if self.indecorator: self.indecorator = False elif self.passline: pass elif type == tokenize.INDENT: + if self.body_col0 is None and self.started: + self.body_col0 = erowcol[1] self.indent = self.indent + 1 self.passline = True elif type == tokenize.DEDENT: @@ -942,6 +1223,10 @@ def tokeneater(self, type, token, srowcol, erowcol, line): # not e.g. for "if: else:" or "try: finally:" blocks) if self.indent <= 0: raise EndOfBlock + elif type == tokenize.COMMENT: + if self.body_col0 is not None and srowcol[1] >= self.body_col0: + # Include comments if indented at least as much as the block + self.last = srowcol[0] elif self.indent == 0 and type not in (tokenize.COMMENT, tokenize.NL): # any other token on the same indentation level end the previous # block as well, except the pseudo-tokens COMMENT and NL. @@ -956,6 +1241,14 @@ def getblock(lines): blockfinder.tokeneater(*_token) except (EndOfBlock, IndentationError): pass + except SyntaxError as e: + if "unmatched" not in e.msg: + raise e from None + _, *_token_info = _token + try: + blockfinder.tokeneater(tokenize.NEWLINE, *_token_info) + except (EndOfBlock, IndentationError): + pass return lines[:blockfinder.last] def getsourcelines(object): @@ -1043,7 +1336,6 @@ def getargs(co): nkwargs = co.co_kwonlyargcount args = list(names[:nargs]) kwonlyargs = list(names[nargs:nargs+nkwargs]) - step = 0 nargs += nkwargs varargs = None @@ -1055,37 +1347,6 @@ def getargs(co): varkw = co.co_varnames[nargs] return Arguments(args + kwonlyargs, varargs, varkw) -ArgSpec = namedtuple('ArgSpec', 'args varargs keywords defaults') - -def getargspec(func): - """Get the names and default values of a function's parameters. - - A tuple of four things is returned: (args, varargs, keywords, defaults). - 'args' is a list of the argument names, including keyword-only argument names. - 'varargs' and 'keywords' are the names of the * and ** parameters or None. - 'defaults' is an n-tuple of the default values of the last n parameters. - - This function is deprecated, as it does not support annotations or - keyword-only parameters and will raise ValueError if either is present - on the supplied callable. - - For a more structured introspection API, use inspect.signature() instead. - - Alternatively, use getfullargspec() for an API with a similar namedtuple - based interface, but full support for annotations and keyword-only - parameters. - - Deprecated since Python 3.5, use `inspect.getfullargspec()`. - """ - warnings.warn("inspect.getargspec() is deprecated since Python 3.0, " - "use inspect.signature() or inspect.getfullargspec()", - DeprecationWarning, stacklevel=2) - args, varargs, varkw, defaults, kwonlyargs, kwonlydefaults, ann = \ - getfullargspec(func) - if kwonlyargs or ann: - raise ValueError("Function has keyword-only parameters or annotations" - ", use inspect.signature() API which can support them") - return ArgSpec(args, varargs, varkw, defaults) FullArgSpec = namedtuple('FullArgSpec', 'args, varargs, varkw, defaults, kwonlyargs, kwonlydefaults, annotations') @@ -1126,7 +1387,8 @@ def getfullargspec(func): sig = _signature_from_callable(func, follow_wrapper_chains=False, skip_bound_arg=False, - sigcls=Signature) + sigcls=Signature, + eval_str=False) except Exception as ex: # Most of the times 'signature' will raise ValueError. # But, it can also raise AttributeError, and, maybe something @@ -1139,7 +1401,6 @@ def getfullargspec(func): varkw = None posonlyargs = [] kwonlyargs = [] - defaults = () annotations = {} defaults = () kwdefaults = {} @@ -1197,7 +1458,12 @@ def getargvalues(frame): def formatannotation(annotation, base_module=None): if getattr(annotation, '__module__', None) == 'typing': - return repr(annotation).replace('typing.', '') + def repl(match): + text = match.group() + return text.removeprefix('typing.') + return re.sub(r'[\w\.]+', repl, repr(annotation)) + if isinstance(annotation, types.GenericAlias): + return str(annotation) if isinstance(annotation, type): if annotation.__module__ in ('builtins', base_module): return annotation.__qualname__ @@ -1210,63 +1476,6 @@ def _formatannotation(annotation): return formatannotation(annotation, module) return _formatannotation -def formatargspec(args, varargs=None, varkw=None, defaults=None, - kwonlyargs=(), kwonlydefaults={}, annotations={}, - formatarg=str, - formatvarargs=lambda name: '*' + name, - formatvarkw=lambda name: '**' + name, - formatvalue=lambda value: '=' + repr(value), - formatreturns=lambda text: ' -> ' + text, - formatannotation=formatannotation): - """Format an argument spec from the values returned by getfullargspec. - - The first seven arguments are (args, varargs, varkw, defaults, - kwonlyargs, kwonlydefaults, annotations). The other five arguments - are the corresponding optional formatting functions that are called to - turn names and values into strings. The last argument is an optional - function to format the sequence of arguments. - - Deprecated since Python 3.5: use the `signature` function and `Signature` - objects. - """ - - from warnings import warn - - warn("`formatargspec` is deprecated since Python 3.5. Use `signature` and " - "the `Signature` object directly", - DeprecationWarning, - stacklevel=2) - - def formatargandannotation(arg): - result = formatarg(arg) - if arg in annotations: - result += ': ' + formatannotation(annotations[arg]) - return result - specs = [] - if defaults: - firstdefault = len(args) - len(defaults) - for i, arg in enumerate(args): - spec = formatargandannotation(arg) - if defaults and i >= firstdefault: - spec = spec + formatvalue(defaults[i - firstdefault]) - specs.append(spec) - if varargs is not None: - specs.append(formatvarargs(formatargandannotation(varargs))) - else: - if kwonlyargs: - specs.append('*') - if kwonlyargs: - for kwonlyarg in kwonlyargs: - spec = formatargandannotation(kwonlyarg) - if kwonlydefaults and kwonlyarg in kwonlydefaults: - spec += formatvalue(kwonlydefaults[kwonlyarg]) - specs.append(spec) - if varkw is not None: - specs.append(formatvarkw(formatargandannotation(varkw))) - result = '(' + ', '.join(specs) + ')' - if 'return' in annotations: - result += formatreturns(formatannotation(annotations['return'])) - return result def formatargvalues(args, varargs, varkw, locals, formatarg=str, @@ -1443,7 +1652,30 @@ def getclosurevars(func): # -------------------------------------------------- stack frame extraction -Traceback = namedtuple('Traceback', 'filename lineno function code_context index') +_Traceback = namedtuple('_Traceback', 'filename lineno function code_context index') + +class Traceback(_Traceback): + def __new__(cls, filename, lineno, function, code_context, index, *, positions=None): + instance = super().__new__(cls, filename, lineno, function, code_context, index) + instance.positions = positions + return instance + + def __repr__(self): + return ('Traceback(filename={!r}, lineno={!r}, function={!r}, ' + 'code_context={!r}, index={!r}, positions={!r})'.format( + self.filename, self.lineno, self.function, self.code_context, + self.index, self.positions)) + +def _get_code_position_from_tb(tb): + code, instruction_index = tb.tb_frame.f_code, tb.tb_lasti + return _get_code_position(code, instruction_index) + +def _get_code_position(code, instruction_index): + if instruction_index < 0: + return (None, None, None, None) + positions_gen = code.co_positions() + # The nth entry in code.co_positions() corresponds to instruction (2*n)th since Python 3.10+ + return next(itertools.islice(positions_gen, instruction_index // 2, None)) def getframeinfo(frame, context=1): """Get information about a frame or traceback object. @@ -1454,10 +1686,20 @@ def getframeinfo(frame, context=1): The optional second argument specifies the number of lines of context to return, which are centered around the current line.""" if istraceback(frame): + positions = _get_code_position_from_tb(frame) lineno = frame.tb_lineno frame = frame.tb_frame else: lineno = frame.f_lineno + positions = _get_code_position(frame.f_code, frame.f_lasti) + + if positions[0] is None: + frame, *positions = (frame, lineno, *positions[1:]) + else: + frame, *positions = (frame, *positions) + + lineno = positions[0] + if not isframe(frame): raise TypeError('{!r} is not a frame or traceback object'.format(frame)) @@ -1475,14 +1717,26 @@ def getframeinfo(frame, context=1): else: lines = index = None - return Traceback(filename, lineno, frame.f_code.co_name, lines, index) + return Traceback(filename, lineno, frame.f_code.co_name, lines, + index, positions=dis.Positions(*positions)) def getlineno(frame): """Get the line number from a frame object, allowing for optimization.""" # FrameType.f_lineno is now a descriptor that grovels co_lnotab return frame.f_lineno -FrameInfo = namedtuple('FrameInfo', ('frame',) + Traceback._fields) +_FrameInfo = namedtuple('_FrameInfo', ('frame',) + Traceback._fields) +class FrameInfo(_FrameInfo): + def __new__(cls, frame, filename, lineno, function, code_context, index, *, positions=None): + instance = super().__new__(cls, frame, filename, lineno, function, code_context, index) + instance.positions = positions + return instance + + def __repr__(self): + return ('FrameInfo(frame={!r}, filename={!r}, lineno={!r}, function={!r}, ' + 'code_context={!r}, index={!r}, positions={!r})'.format( + self.frame, self.filename, self.lineno, self.function, + self.code_context, self.index, self.positions)) def getouterframes(frame, context=1): """Get a list of records for a frame and all higher (calling) frames. @@ -1491,8 +1745,9 @@ def getouterframes(frame, context=1): name, a list of lines of context, and index within the context.""" framelist = [] while frame: - frameinfo = (frame,) + getframeinfo(frame, context) - framelist.append(FrameInfo(*frameinfo)) + traceback_info = getframeinfo(frame, context) + frameinfo = (frame,) + traceback_info + framelist.append(FrameInfo(*frameinfo, positions=traceback_info.positions)) frame = frame.f_back return framelist @@ -1503,8 +1758,9 @@ def getinnerframes(tb, context=1): name, a list of lines of context, and index within the context.""" framelist = [] while tb: - frameinfo = (tb.tb_frame,) + getframeinfo(tb, context) - framelist.append(FrameInfo(*frameinfo)) + traceback_info = getframeinfo(tb, context) + frameinfo = (tb.tb_frame,) + traceback_info + framelist.append(FrameInfo(*frameinfo, positions=traceback_info.positions)) tb = tb.tb_next return framelist @@ -1518,15 +1774,17 @@ def stack(context=1): def trace(context=1): """Return a list of records for the stack below the current exception.""" - return getinnerframes(sys.exc_info()[2], context) + exc = sys.exception() + tb = None if exc is None else exc.__traceback__ + return getinnerframes(tb, context) # ------------------------------------------------ static version of getattr _sentinel = object() +_static_getmro = type.__dict__['__mro__'].__get__ +_get_dunder_dict_of_class = type.__dict__["__dict__"].__get__ -def _static_getmro(klass): - return type.__dict__['__mro__'].__get__(klass) def _check_instance(obj, attr): instance_dict = {} @@ -1539,34 +1797,25 @@ def _check_instance(obj, attr): def _check_class(klass, attr): for entry in _static_getmro(klass): - if _shadowed_dict(type(entry)) is _sentinel: - try: - return entry.__dict__[attr] - except KeyError: - pass + if _shadowed_dict(type(entry)) is _sentinel and attr in entry.__dict__: + return entry.__dict__[attr] return _sentinel -def _is_type(obj): - try: - _static_getmro(obj) - except TypeError: - return False - return True - -def _shadowed_dict(klass): - dict_attr = type.__dict__["__dict__"] - for entry in _static_getmro(klass): - try: - class_dict = dict_attr.__get__(entry)["__dict__"] - except KeyError: - pass - else: +@functools.lru_cache() +def _shadowed_dict_from_mro_tuple(mro): + for entry in mro: + dunder_dict = _get_dunder_dict_of_class(entry) + if '__dict__' in dunder_dict: + class_dict = dunder_dict['__dict__'] if not (type(class_dict) is types.GetSetDescriptorType and class_dict.__name__ == "__dict__" and class_dict.__objclass__ is entry): return class_dict return _sentinel +def _shadowed_dict(klass): + return _shadowed_dict_from_mro_tuple(_static_getmro(klass)) + def getattr_static(obj, attr, default=_sentinel): """Retrieve attributes without triggering dynamic lookup via the descriptor protocol, __getattr__ or __getattribute__. @@ -1579,8 +1828,10 @@ def getattr_static(obj, attr, default=_sentinel): documentation for details. """ instance_result = _sentinel - if not _is_type(obj): - klass = type(obj) + + objtype = type(obj) + if type not in _static_getmro(objtype): + klass = objtype dict_attr = _shadowed_dict(klass) if (dict_attr is _sentinel or type(dict_attr) is types.MemberDescriptorType): @@ -1591,8 +1842,10 @@ def getattr_static(obj, attr, default=_sentinel): klass_result = _check_class(klass, attr) if instance_result is not _sentinel and klass_result is not _sentinel: - if (_check_class(type(klass_result), '__get__') is not _sentinel and - _check_class(type(klass_result), '__set__') is not _sentinel): + if _check_class(type(klass_result), "__get__") is not _sentinel and ( + _check_class(type(klass_result), "__set__") is not _sentinel + or _check_class(type(klass_result), "__delete__") is not _sentinel + ): return klass_result if instance_result is not _sentinel: @@ -1603,11 +1856,11 @@ def getattr_static(obj, attr, default=_sentinel): if obj is klass: # for types we check the metaclass too for entry in _static_getmro(type(klass)): - if _shadowed_dict(type(entry)) is _sentinel: - try: - return entry.__dict__[attr] - except KeyError: - pass + if ( + _shadowed_dict(type(entry)) is _sentinel + and attr in entry.__dict__ + ): + return entry.__dict__[attr] if default is not _sentinel: return default raise AttributeError(attr) @@ -1631,11 +1884,11 @@ def getgeneratorstate(generator): """ if generator.gi_running: return GEN_RUNNING + if generator.gi_suspended: + return GEN_SUSPENDED if generator.gi_frame is None: return GEN_CLOSED - if generator.gi_frame.f_lasti == -1: - return GEN_CREATED - return GEN_SUSPENDED + return GEN_CREATED def getgeneratorlocals(generator): @@ -1673,11 +1926,11 @@ def getcoroutinestate(coroutine): """ if coroutine.cr_running: return CORO_RUNNING + if coroutine.cr_suspended: + return CORO_SUSPENDED if coroutine.cr_frame is None: return CORO_CLOSED - if coroutine.cr_frame.f_lasti == -1: - return CORO_CREATED - return CORO_SUSPENDED + return CORO_CREATED def getcoroutinelocals(coroutine): @@ -1693,18 +1946,58 @@ def getcoroutinelocals(coroutine): return {} +# ----------------------------------- asynchronous generator introspection + +AGEN_CREATED = 'AGEN_CREATED' +AGEN_RUNNING = 'AGEN_RUNNING' +AGEN_SUSPENDED = 'AGEN_SUSPENDED' +AGEN_CLOSED = 'AGEN_CLOSED' + + +def getasyncgenstate(agen): + """Get current state of an asynchronous generator object. + + Possible states are: + AGEN_CREATED: Waiting to start execution. + AGEN_RUNNING: Currently being executed by the interpreter. + AGEN_SUSPENDED: Currently suspended at a yield expression. + AGEN_CLOSED: Execution has completed. + """ + if agen.ag_running: + return AGEN_RUNNING + if agen.ag_suspended: + return AGEN_SUSPENDED + if agen.ag_frame is None: + return AGEN_CLOSED + return AGEN_CREATED + + +def getasyncgenlocals(agen): + """ + Get the mapping of asynchronous generator local variables to their current + values. + + A dict is returned, with the keys the local variable names and values the + bound values.""" + + if not isasyncgen(agen): + raise TypeError(f"{agen!r} is not a Python async generator") + + frame = getattr(agen, "ag_frame", None) + if frame is not None: + return agen.ag_frame.f_locals + else: + return {} + + ############################################################################### ### Function Signature Object (PEP 362) ############################################################################### -_WrapperDescriptor = type(type.__call__) -_MethodWrapper = type(all.__call__) -_ClassMethodWrapper = type(int.__dict__['from_bytes']) - -_NonUserDefinedCallables = (_WrapperDescriptor, - _MethodWrapper, - _ClassMethodWrapper, +_NonUserDefinedCallables = (types.WrapperDescriptorType, + types.MethodWrapperType, + types.ClassMethodDescriptorType, types.BuiltinFunctionType) @@ -1713,15 +2006,17 @@ def _signature_get_user_defined_method(cls, method_name): named ``method_name`` and returns it only if it is a pure python function. """ - try: - meth = getattr(cls, method_name) - except AttributeError: - return + if method_name == '__new__': + meth = getattr(cls, method_name, None) else: - if not isinstance(meth, _NonUserDefinedCallables): - # Once '__signature__' will be added to 'C'-level - # callables, this check won't be necessary - return meth + meth = getattr_static(cls, method_name, None) + if meth is None or isinstance(meth, _NonUserDefinedCallables): + # Once '__signature__' will be added to 'C'-level + # callables, this check won't be necessary + return None + if method_name != '__new__': + meth = _descriptor_get(meth, cls) + return meth def _signature_get_partial(wrapped_sig, partial, extra_args=()): @@ -1860,30 +2155,7 @@ def _signature_is_functionlike(obj): isinstance(name, str) and (defaults is None or isinstance(defaults, tuple)) and (kwdefaults is None or isinstance(kwdefaults, dict)) and - isinstance(annotations, dict)) - - -def _signature_get_bound_param(spec): - """ Private helper to get first parameter name from a - __text_signature__ of a builtin method, which should - be in the following format: '($param1, ...)'. - Assumptions are that the first argument won't have - a default value or an annotation. - """ - - assert spec.startswith('($') - - pos = spec.find(',') - if pos == -1: - pos = spec.find(')') - - cpos = spec.find(':') - assert cpos == -1 or cpos > pos - - cpos = spec.find('=') - assert cpos == -1 or cpos > pos - - return spec[2:pos] + (isinstance(annotations, (dict)) or annotations is None) ) def _signature_strip_non_python_syntax(signature): @@ -1891,26 +2163,21 @@ def _signature_strip_non_python_syntax(signature): Private helper function. Takes a signature in Argument Clinic's extended signature format. - Returns a tuple of three things: - * that signature re-rendered in standard Python syntax, + Returns a tuple of two things: + * that signature re-rendered in standard Python syntax, and * the index of the "self" parameter (generally 0), or None if - the function does not have a "self" parameter, and - * the index of the last "positional only" parameter, - or None if the signature has no positional-only parameters. + the function does not have a "self" parameter. """ if not signature: - return signature, None, None + return signature, None self_parameter = None - last_positional_only = None - lines = [l.encode('ascii') for l in signature.split('\n')] + lines = [l.encode('ascii') for l in signature.split('\n') if l] generator = iter(lines).__next__ token_stream = tokenize.tokenize(generator) - delayed_comma = False - skip_next_comma = False text = [] add = text.append @@ -1927,49 +2194,27 @@ def _signature_strip_non_python_syntax(signature): if type == OP: if string == ',': - if skip_next_comma: - skip_next_comma = False - else: - assert not delayed_comma - delayed_comma = True - current_parameter += 1 - continue + current_parameter += 1 - if string == '/': - assert not skip_next_comma - assert last_positional_only is None - skip_next_comma = True - last_positional_only = current_parameter - 1 - continue - - if (type == ERRORTOKEN) and (string == '$'): + if (type == OP) and (string == '$'): assert self_parameter is None self_parameter = current_parameter continue - if delayed_comma: - delayed_comma = False - if not ((type == OP) and (string == ')')): - add(', ') add(string) if (string == ','): add(' ') - clean_signature = ''.join(text) - return clean_signature, self_parameter, last_positional_only + clean_signature = ''.join(text).strip().replace("\n", "") + return clean_signature, self_parameter def _signature_fromstr(cls, obj, s, skip_bound_arg=True): """Private helper to parse content of '__text_signature__' and return a Signature based on it. """ - # Lazy import ast because it's relatively heavy and - # it's not used for other than this function. - import ast - Parameter = cls._parameter_cls - clean_signature, self_parameter, last_positional_only = \ - _signature_strip_non_python_syntax(s) + clean_signature, self_parameter = _signature_strip_non_python_syntax(s) program = "def foo" + clean_signature + ": pass" @@ -1985,7 +2230,6 @@ def _signature_fromstr(cls, obj, s, skip_bound_arg=True): parameters = [] empty = Parameter.empty - invalid = object() module = None module_dict = {} @@ -2009,11 +2253,11 @@ def wrap_value(s): try: value = eval(s, sys_module_dict) except NameError: - raise RuntimeError() + raise ValueError if isinstance(value, (str, int, float, bytes, bool, type(None))): return ast.Constant(value) - raise RuntimeError() + raise ValueError class RewriteSymbolics(ast.NodeTransformer): def visit_Attribute(self, node): @@ -2023,7 +2267,7 @@ def visit_Attribute(self, node): a.append(n.attr) n = n.value if not isinstance(n, ast.Name): - raise RuntimeError() + raise ValueError a.append(n.id) value = ".".join(reversed(a)) return wrap_value(value) @@ -2033,33 +2277,43 @@ def visit_Name(self, node): raise ValueError() return wrap_value(node.id) + def visit_BinOp(self, node): + # Support constant folding of a couple simple binary operations + # commonly used to define default values in text signatures + left = self.visit(node.left) + right = self.visit(node.right) + if not isinstance(left, ast.Constant) or not isinstance(right, ast.Constant): + raise ValueError + if isinstance(node.op, ast.Add): + return ast.Constant(left.value + right.value) + elif isinstance(node.op, ast.Sub): + return ast.Constant(left.value - right.value) + elif isinstance(node.op, ast.BitOr): + return ast.Constant(left.value | right.value) + raise ValueError + def p(name_node, default_node, default=empty): name = parse_name(name_node) - if name is invalid: - return None if default_node and default_node is not _empty: try: default_node = RewriteSymbolics().visit(default_node) - o = ast.literal_eval(default_node) + default = ast.literal_eval(default_node) except ValueError: - o = invalid - if o is invalid: - return None - default = o if o is not invalid else default + raise ValueError("{!r} builtin has invalid signature".format(obj)) from None parameters.append(Parameter(name, kind, default=default, annotation=empty)) # non-keyword-only parameters - args = reversed(f.args.args) - defaults = reversed(f.args.defaults) - iter = itertools.zip_longest(args, defaults, fillvalue=None) - if last_positional_only is not None: - kind = Parameter.POSITIONAL_ONLY - else: - kind = Parameter.POSITIONAL_OR_KEYWORD - for i, (name, default) in enumerate(reversed(list(iter))): + total_non_kw_args = len(f.args.posonlyargs) + len(f.args.args) + required_non_kw_args = total_non_kw_args - len(f.args.defaults) + defaults = itertools.chain(itertools.repeat(None, required_non_kw_args), f.args.defaults) + + kind = Parameter.POSITIONAL_ONLY + for (name, default) in zip(f.args.posonlyargs, defaults): + p(name, default) + + kind = Parameter.POSITIONAL_OR_KEYWORD + for (name, default) in zip(f.args.args, defaults): p(name, default) - if i == last_positional_only: - kind = Parameter.POSITIONAL_OR_KEYWORD # *args if f.args.vararg: @@ -2112,7 +2366,8 @@ def _signature_from_builtin(cls, func, skip_bound_arg=True): return _signature_fromstr(cls, func, s, skip_bound_arg) -def _signature_from_function(cls, func, skip_bound_arg=True): +def _signature_from_function(cls, func, skip_bound_arg=True, + globals=None, locals=None, eval_str=False): """Private helper: constructs Signature for the given python function.""" is_duck_function = False @@ -2138,7 +2393,7 @@ def _signature_from_function(cls, func, skip_bound_arg=True): positional = arg_names[:pos_count] keyword_only_count = func_code.co_kwonlyargcount keyword_only = arg_names[pos_count:pos_count + keyword_only_count] - annotations = func.__annotations__ + annotations = get_annotations(func, globals=globals, locals=locals, eval_str=eval_str) defaults = func.__defaults__ kwdefaults = func.__kwdefaults__ @@ -2206,26 +2461,42 @@ def _signature_from_function(cls, func, skip_bound_arg=True): __validate_parameters__=is_duck_function) +def _descriptor_get(descriptor, obj): + if isclass(descriptor): + return descriptor + get = getattr(type(descriptor), '__get__', _sentinel) + if get is _sentinel: + return descriptor + return get(descriptor, obj, type(obj)) + + def _signature_from_callable(obj, *, follow_wrapper_chains=True, skip_bound_arg=True, + globals=None, + locals=None, + eval_str=False, sigcls): """Private helper function to get signature for arbitrary callable objects. """ + _get_signature_of = functools.partial(_signature_from_callable, + follow_wrapper_chains=follow_wrapper_chains, + skip_bound_arg=skip_bound_arg, + globals=globals, + locals=locals, + sigcls=sigcls, + eval_str=eval_str) + if not callable(obj): raise TypeError('{!r} is not a callable object'.format(obj)) if isinstance(obj, types.MethodType): # In this case we skip the first parameter of the underlying # function (usually `self` or `cls`). - sig = _signature_from_callable( - obj.__func__, - follow_wrapper_chains=follow_wrapper_chains, - skip_bound_arg=skip_bound_arg, - sigcls=sigcls) + sig = _get_signature_of(obj.__func__) if skip_bound_arg: return _signature_bound_method(sig) @@ -2234,16 +2505,15 @@ def _signature_from_callable(obj, *, # Was this function wrapped by a decorator? if follow_wrapper_chains: - obj = unwrap(obj, stop=(lambda f: hasattr(f, "__signature__"))) + # Unwrap until we find an explicit signature or a MethodType (which will be + # handled explicitly below). + obj = unwrap(obj, stop=(lambda f: hasattr(f, "__signature__") + or isinstance(f, types.MethodType))) if isinstance(obj, types.MethodType): # If the unwrapped object is a *method*, we might want to # skip its first parameter (self). # See test_signature_wrapped_bound_method for details. - return _signature_from_callable( - obj, - follow_wrapper_chains=follow_wrapper_chains, - skip_bound_arg=skip_bound_arg, - sigcls=sigcls) + return _get_signature_of(obj) try: sig = obj.__signature__ @@ -2251,10 +2521,18 @@ def _signature_from_callable(obj, *, pass else: if sig is not None: + # since __text_signature__ is not writable on classes, __signature__ + # may contain text (or be a callable that returns text); + # if so, convert it + o_sig = sig + if not isinstance(sig, (Signature, str)) and callable(sig): + sig = sig() + if isinstance(sig, str): + sig = _signature_fromstr(sigcls, obj, sig) if not isinstance(sig, Signature): raise TypeError( 'unexpected object {!r} in __signature__ ' - 'attribute'.format(sig)) + 'attribute'.format(o_sig)) return sig try: @@ -2270,11 +2548,7 @@ def _signature_from_callable(obj, *, # (usually `self`, or `cls`) will not be passed # automatically (as for boundmethods) - wrapped_sig = _signature_from_callable( - partialmethod.func, - follow_wrapper_chains=follow_wrapper_chains, - skip_bound_arg=skip_bound_arg, - sigcls=sigcls) + wrapped_sig = _get_signature_of(partialmethod.func) sig = _signature_get_partial(wrapped_sig, partialmethod, (None,)) first_wrapped_param = tuple(wrapped_sig.parameters.values())[0] @@ -2293,21 +2567,17 @@ def _signature_from_callable(obj, *, # If it's a pure Python function, or an object that is duck type # of a Python function (Cython functions, for instance), then: return _signature_from_function(sigcls, obj, - skip_bound_arg=skip_bound_arg) + skip_bound_arg=skip_bound_arg, + globals=globals, locals=locals, eval_str=eval_str) if _signature_is_builtin(obj): return _signature_from_builtin(sigcls, obj, skip_bound_arg=skip_bound_arg) if isinstance(obj, functools.partial): - wrapped_sig = _signature_from_callable( - obj.func, - follow_wrapper_chains=follow_wrapper_chains, - skip_bound_arg=skip_bound_arg, - sigcls=sigcls) + wrapped_sig = _get_signature_of(obj.func) return _signature_get_partial(wrapped_sig, obj) - sig = None if isinstance(obj, type): # obj is a class or a metaclass @@ -2315,95 +2585,65 @@ def _signature_from_callable(obj, *, # in its metaclass call = _signature_get_user_defined_method(type(obj), '__call__') if call is not None: - sig = _signature_from_callable( - call, - follow_wrapper_chains=follow_wrapper_chains, - skip_bound_arg=skip_bound_arg, - sigcls=sigcls) - else: - # Now we check if the 'obj' class has a '__new__' method - new = _signature_get_user_defined_method(obj, '__new__') - if new is not None: - sig = _signature_from_callable( - new, - follow_wrapper_chains=follow_wrapper_chains, - skip_bound_arg=skip_bound_arg, - sigcls=sigcls) + return _get_signature_of(call) + + new = _signature_get_user_defined_method(obj, '__new__') + init = _signature_get_user_defined_method(obj, '__init__') + + # Go through the MRO and see if any class has user-defined + # pure Python __new__ or __init__ method + for base in obj.__mro__: + # Now we check if the 'obj' class has an own '__new__' method + if new is not None and '__new__' in base.__dict__: + sig = _get_signature_of(new) + if skip_bound_arg: + sig = _signature_bound_method(sig) + return sig + # or an own '__init__' method + elif init is not None and '__init__' in base.__dict__: + return _get_signature_of(init) + + # At this point we know, that `obj` is a class, with no user- + # defined '__init__', '__new__', or class-level '__call__' + + for base in obj.__mro__[:-1]: + # Since '__text_signature__' is implemented as a + # descriptor that extracts text signature from the + # class docstring, if 'obj' is derived from a builtin + # class, its own '__text_signature__' may be 'None'. + # Therefore, we go through the MRO (except the last + # class in there, which is 'object') to find the first + # class with non-empty text signature. + try: + text_sig = base.__text_signature__ + except AttributeError: + pass else: - # Finally, we should have at least __init__ implemented - init = _signature_get_user_defined_method(obj, '__init__') - if init is not None: - sig = _signature_from_callable( - init, - follow_wrapper_chains=follow_wrapper_chains, - skip_bound_arg=skip_bound_arg, - sigcls=sigcls) - - if sig is None: - # At this point we know, that `obj` is a class, with no user- - # defined '__init__', '__new__', or class-level '__call__' - - for base in obj.__mro__[:-1]: - # Since '__text_signature__' is implemented as a - # descriptor that extracts text signature from the - # class docstring, if 'obj' is derived from a builtin - # class, its own '__text_signature__' may be 'None'. - # Therefore, we go through the MRO (except the last - # class in there, which is 'object') to find the first - # class with non-empty text signature. - try: - text_sig = base.__text_signature__ - except AttributeError: - pass - else: - if text_sig: - # If 'obj' class has a __text_signature__ attribute: - # return a signature based on it - return _signature_fromstr(sigcls, obj, text_sig) - - # No '__text_signature__' was found for the 'obj' class. - # Last option is to check if its '__init__' is - # object.__init__ or type.__init__. - if type not in obj.__mro__: - # We have a class (not metaclass), but no user-defined - # __init__ or __new__ for it - if (obj.__init__ is object.__init__ and - obj.__new__ is object.__new__): - # Return a signature of 'object' builtin. - return sigcls.from_callable(object) - else: - raise ValueError( - 'no signature found for builtin type {!r}'.format(obj)) + if text_sig: + # If 'base' class has a __text_signature__ attribute: + # return a signature based on it + return _signature_fromstr(sigcls, base, text_sig) + + # No '__text_signature__' was found for the 'obj' class. + # Last option is to check if its '__init__' is + # object.__init__ or type.__init__. + if type not in obj.__mro__: + # We have a class (not metaclass), but no user-defined + # __init__ or __new__ for it + if (obj.__init__ is object.__init__ and + obj.__new__ is object.__new__): + # Return a signature of 'object' builtin. + return sigcls.from_callable(object) + else: + raise ValueError( + 'no signature found for builtin type {!r}'.format(obj)) - elif not isinstance(obj, _NonUserDefinedCallables): + else: # An object with __call__ - # We also check that the 'obj' is not an instance of - # _WrapperDescriptor or _MethodWrapper to avoid - # infinite recursion (and even potential segfault) - call = _signature_get_user_defined_method(type(obj), '__call__') + call = getattr_static(type(obj), '__call__', None) if call is not None: - try: - sig = _signature_from_callable( - call, - follow_wrapper_chains=follow_wrapper_chains, - skip_bound_arg=skip_bound_arg, - sigcls=sigcls) - except ValueError as ex: - msg = 'no signature found for {!r}'.format(obj) - raise ValueError(msg) from ex - - if sig is not None: - # For classes and objects we skip the first parameter of their - # __call__, __new__, or __init__ methods - if skip_bound_arg: - return _signature_bound_method(sig) - else: - return sig - - if isinstance(obj, types.BuiltinFunctionType): - # Raise a nicer error message for builtins - msg = 'no signature found for builtin function {!r}'.format(obj) - raise ValueError(msg) + call = _descriptor_get(call, obj) + return _get_signature_of(call) raise ValueError('callable {!r} is not supported by signature'.format(obj)) @@ -2417,18 +2657,21 @@ class _empty: class _ParameterKind(enum.IntEnum): - POSITIONAL_ONLY = 0 - POSITIONAL_OR_KEYWORD = 1 - VAR_POSITIONAL = 2 - KEYWORD_ONLY = 3 - VAR_KEYWORD = 4 + POSITIONAL_ONLY = 'positional-only' + POSITIONAL_OR_KEYWORD = 'positional or keyword' + VAR_POSITIONAL = 'variadic positional' + KEYWORD_ONLY = 'keyword-only' + VAR_KEYWORD = 'variadic keyword' + + def __new__(cls, description): + value = len(cls.__members__) + member = int.__new__(cls, value) + member._value_ = value + member.description = description + return member def __str__(self): - return self._name_ - - @property - def description(self): - return _PARAM_NAME_MAPPING[self] + return self.name _POSITIONAL_ONLY = _ParameterKind.POSITIONAL_ONLY _POSITIONAL_OR_KEYWORD = _ParameterKind.POSITIONAL_OR_KEYWORD @@ -2436,14 +2679,6 @@ def description(self): _KEYWORD_ONLY = _ParameterKind.KEYWORD_ONLY _VAR_KEYWORD = _ParameterKind.VAR_KEYWORD -_PARAM_NAME_MAPPING = { - _POSITIONAL_ONLY: 'positional-only', - _POSITIONAL_OR_KEYWORD: 'positional or keyword', - _VAR_POSITIONAL: 'variadic positional', - _KEYWORD_ONLY: 'keyword-only', - _VAR_KEYWORD: 'variadic keyword' -} - class Parameter: """Represents a parameter in a function signature. @@ -2512,7 +2747,10 @@ def __init__(self, name, kind, *, default=_empty, annotation=_empty): self._kind = _POSITIONAL_ONLY name = 'implicit{}'.format(name[1:]) - if not name.isidentifier(): + # It's possible for C functions to have a positional-only parameter + # where the name is a keyword, so for compatibility we'll allow it. + is_keyword = iskeyword(name) and self._kind is not _POSITIONAL_ONLY + if is_keyword or not name.isidentifier(): raise ValueError('{!r} is not a valid parameter name'.format(name)) self._name = name @@ -2587,7 +2825,7 @@ def __repr__(self): return '<{} "{}">'.format(self.__class__.__name__, self) def __hash__(self): - return hash((self.name, self.kind, self.annotation, self.default)) + return hash((self._name, self._kind, self._annotation, self._default)) def __eq__(self, other): if self is other: @@ -2606,7 +2844,7 @@ class BoundArguments: Has the following public attributes: - * arguments : OrderedDict + * arguments : dict An ordered mutable mapping of parameters' names to arguments' values. Does not contain arguments' default values. * signature : Signature @@ -2706,7 +2944,7 @@ def apply_defaults(self): # Signature.bind_partial(). continue new_arguments.append((name, val)) - self.arguments = OrderedDict(new_arguments) + self.arguments = dict(new_arguments) def __eq__(self, other): if self is other: @@ -2772,9 +3010,9 @@ def __init__(self, parameters=None, *, return_annotation=_empty, if __validate_parameters__: params = OrderedDict() top_kind = _POSITIONAL_ONLY - kind_defaults = False + seen_default = False - for idx, param in enumerate(parameters): + for param in parameters: kind = param.kind name = param.name @@ -2787,21 +3025,19 @@ def __init__(self, parameters=None, *, return_annotation=_empty, kind.description) raise ValueError(msg) elif kind > top_kind: - kind_defaults = False top_kind = kind if kind in (_POSITIONAL_ONLY, _POSITIONAL_OR_KEYWORD): if param.default is _empty: - if kind_defaults: + if seen_default: # No default for this parameter, but the - # previous parameter of the same kind had - # a default + # previous parameter of had a default msg = 'non-default argument follows default ' \ 'argument' raise ValueError(msg) else: # There is a default for this parameter. - kind_defaults = True + seen_default = True if name in params: msg = 'duplicate parameter name: {!r}'.format(name) @@ -2809,41 +3045,18 @@ def __init__(self, parameters=None, *, return_annotation=_empty, params[name] = param else: - params = OrderedDict(((param.name, param) - for param in parameters)) + params = OrderedDict((param.name, param) for param in parameters) self._parameters = types.MappingProxyType(params) self._return_annotation = return_annotation @classmethod - def from_function(cls, func): - """Constructs Signature for the given python function. - - Deprecated since Python 3.5, use `Signature.from_callable()`. - """ - - warnings.warn("inspect.Signature.from_function() is deprecated since " - "Python 3.5, use Signature.from_callable()", - DeprecationWarning, stacklevel=2) - return _signature_from_function(cls, func) - - @classmethod - def from_builtin(cls, func): - """Constructs Signature for the given builtin function. - - Deprecated since Python 3.5, use `Signature.from_callable()`. - """ - - warnings.warn("inspect.Signature.from_builtin() is deprecated since " - "Python 3.5, use Signature.from_callable()", - DeprecationWarning, stacklevel=2) - return _signature_from_builtin(cls, func) - - @classmethod - def from_callable(cls, obj, *, follow_wrapped=True): + def from_callable(cls, obj, *, + follow_wrapped=True, globals=None, locals=None, eval_str=False): """Constructs Signature for the given callable object.""" return _signature_from_callable(obj, sigcls=cls, - follow_wrapper_chains=follow_wrapped) + follow_wrapper_chains=follow_wrapped, + globals=globals, locals=locals, eval_str=eval_str) @property def parameters(self): @@ -2892,7 +3105,7 @@ def __eq__(self, other): def _bind(self, args, kwargs, *, partial=False): """Private method. Don't use directly.""" - arguments = OrderedDict() + arguments = {} parameters = iter(self.parameters.values()) parameters_ex = () @@ -2938,8 +3151,12 @@ def _bind(self, args, kwargs, *, partial=False): parameters_ex = (param,) break else: - msg = 'missing a required argument: {arg!r}' - msg = msg.format(arg=param.name) + if param.kind == _KEYWORD_ONLY: + argtype = ' keyword-only' + else: + argtype = '' + msg = 'missing a required{argtype} argument: {arg!r}' + msg = msg.format(arg=param.name, argtype=argtype) raise TypeError(msg) from None else: # We have a positional argument to process @@ -3091,9 +3308,32 @@ def __str__(self): return rendered -def signature(obj, *, follow_wrapped=True): +def signature(obj, *, follow_wrapped=True, globals=None, locals=None, eval_str=False): """Get a signature object for the passed callable.""" - return Signature.from_callable(obj, follow_wrapped=follow_wrapped) + return Signature.from_callable(obj, follow_wrapped=follow_wrapped, + globals=globals, locals=locals, eval_str=eval_str) + + +class BufferFlags(enum.IntFlag): + SIMPLE = 0x0 + WRITABLE = 0x1 + FORMAT = 0x4 + ND = 0x8 + STRIDES = 0x10 | ND + C_CONTIGUOUS = 0x20 | STRIDES + F_CONTIGUOUS = 0x40 | STRIDES + ANY_CONTIGUOUS = 0x80 | STRIDES + INDIRECT = 0x100 | STRIDES + CONTIG = ND | WRITABLE + CONTIG_RO = ND + STRIDED = STRIDES | WRITABLE + STRIDED_RO = STRIDES + RECORDS = STRIDES | WRITABLE | FORMAT + RECORDS_RO = STRIDES | FORMAT + FULL = INDIRECT | WRITABLE | FORMAT + FULL_RO = INDIRECT | FORMAT + READ = 0x100 + WRITE = 0x200 def _main(): From 316dc33d900447045237d6b92a37914c22d3276c Mon Sep 17 00:00:00 2001 From: Jeong YunWon Date: Tue, 30 Jul 2024 01:43:47 +0900 Subject: [PATCH 4/5] make inspect work without co_positions --- Lib/inspect.py | 22 ++++++++++++++-------- 1 file changed, 14 insertions(+), 8 deletions(-) diff --git a/Lib/inspect.py b/Lib/inspect.py index 4ff8791282..4ea2627ba6 100644 --- a/Lib/inspect.py +++ b/Lib/inspect.py @@ -1686,19 +1686,22 @@ def getframeinfo(frame, context=1): The optional second argument specifies the number of lines of context to return, which are centered around the current line.""" if istraceback(frame): - positions = _get_code_position_from_tb(frame) + # XXX: RUSTPYTHON: we don't support _get_code_position yet + # positions = _get_code_position_from_tb(frame) lineno = frame.tb_lineno frame = frame.tb_frame else: lineno = frame.f_lineno - positions = _get_code_position(frame.f_code, frame.f_lasti) + # XXX: RUSTPYTHON: we don't support _get_code_position yet + # positions = _get_code_position(frame.f_code, frame.f_lasti) - if positions[0] is None: - frame, *positions = (frame, lineno, *positions[1:]) - else: - frame, *positions = (frame, *positions) + # XXX: RUSTPYTHON: we don't support _get_code_position yet + # if positions[0] is None: + # frame, *positions = (frame, lineno, *positions[1:]) + # else: + # frame, *positions = (frame, *positions) - lineno = positions[0] + # lineno = positions[0] if not isframe(frame): raise TypeError('{!r} is not a frame or traceback object'.format(frame)) @@ -1718,7 +1721,10 @@ def getframeinfo(frame, context=1): lines = index = None return Traceback(filename, lineno, frame.f_code.co_name, lines, - index, positions=dis.Positions(*positions)) + index, + # XXX: RUSTPYTHON: we don't support _get_code_position yet + # positions=dis.Positions(*positions) + ) def getlineno(frame): """Get the line number from a frame object, allowing for optimization.""" From b1235fe02d6d1b40faacfe6ca575f417dae2b8f6 Mon Sep 17 00:00:00 2001 From: Jeong YunWon Date: Mon, 29 Jul 2024 20:41:37 +0900 Subject: [PATCH 5/5] mark failing tests by inspect update --- Lib/test/test_enum.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/Lib/test/test_enum.py b/Lib/test/test_enum.py index bff85a7ec2..32a3c1dee0 100644 --- a/Lib/test/test_enum.py +++ b/Lib/test/test_enum.py @@ -4882,8 +4882,6 @@ def test_inspect_classify_class_attrs(self): if failed: self.fail("result does not equal expected, see print above") - # TODO: RUSTPYTHON - @unittest.expectedFailure def test_inspect_signatures(self): from inspect import signature, Signature, Parameter self.assertEqual(