From 96364acbdf60b841b0f65df011bc42449b6b0b5d Mon Sep 17 00:00:00 2001 From: Jeong YunWon Date: Mon, 25 Aug 2025 22:07:42 +0900 Subject: [PATCH] PyTypeFlags::{SEQUENCE,MAPPING} --- vm/src/builtins/dict.rs | 2 +- vm/src/builtins/list.rs | 2 +- vm/src/builtins/memory.rs | 25 ++++++++------- vm/src/builtins/range.rs | 21 +++++++------ vm/src/builtins/tuple.rs | 2 +- vm/src/builtins/type.rs | 64 +++++++++++++++++++++++++++++++++++++++ vm/src/types/slot.rs | 2 ++ 7 files changed, 95 insertions(+), 23 deletions(-) diff --git a/vm/src/builtins/dict.rs b/vm/src/builtins/dict.rs index 6ec89aa408..2f1c91323c 100644 --- a/vm/src/builtins/dict.rs +++ b/vm/src/builtins/dict.rs @@ -176,7 +176,7 @@ impl PyDict { AsMapping, Representable ), - flags(BASETYPE) + flags(BASETYPE, MAPPING) )] impl PyDict { #[pyclassmethod] diff --git a/vm/src/builtins/list.rs b/vm/src/builtins/list.rs index 9a7b589418..14c8341e44 100644 --- a/vm/src/builtins/list.rs +++ b/vm/src/builtins/list.rs @@ -109,7 +109,7 @@ pub type PyListRef = PyRef; AsSequence, Representable ), - flags(BASETYPE) + flags(BASETYPE, SEQUENCE) )] impl PyList { #[pymethod] diff --git a/vm/src/builtins/memory.rs b/vm/src/builtins/memory.rs index ddbae0dd66..7ec03bf971 100644 --- a/vm/src/builtins/memory.rs +++ b/vm/src/builtins/memory.rs @@ -538,17 +538,20 @@ impl Py { } } -#[pyclass(with( - Py, - Hashable, - Comparable, - AsBuffer, - AsMapping, - AsSequence, - Constructor, - Iterable, - Representable -))] +#[pyclass( + with( + Py, + Hashable, + Comparable, + AsBuffer, + AsMapping, + AsSequence, + Constructor, + Iterable, + Representable + ), + flags(SEQUENCE) +)] impl PyMemoryView { // TODO: Uncomment when Python adds __class_getitem__ to memoryview // #[pyclassmethod] diff --git a/vm/src/builtins/range.rs b/vm/src/builtins/range.rs index 7ce40c24bb..ef55feed60 100644 --- a/vm/src/builtins/range.rs +++ b/vm/src/builtins/range.rs @@ -174,15 +174,18 @@ pub fn init(context: &Context) { PyRangeIterator::extend_class(context, context.types.range_iterator_type); } -#[pyclass(with( - Py, - AsMapping, - AsSequence, - Hashable, - Comparable, - Iterable, - Representable -))] +#[pyclass( + with( + Py, + AsMapping, + AsSequence, + Hashable, + Comparable, + Iterable, + Representable + ), + flags(SEQUENCE) +)] impl PyRange { fn new(cls: PyTypeRef, stop: ArgIndex, vm: &VirtualMachine) -> PyResult> { Self { diff --git a/vm/src/builtins/tuple.rs b/vm/src/builtins/tuple.rs index 037ac8da51..abd7bf71de 100644 --- a/vm/src/builtins/tuple.rs +++ b/vm/src/builtins/tuple.rs @@ -244,7 +244,7 @@ impl PyTuple> { } #[pyclass( - flags(BASETYPE), + flags(BASETYPE, SEQUENCE), with( AsMapping, AsSequence, diff --git a/vm/src/builtins/type.rs b/vm/src/builtins/type.rs index 6a93751a93..01ac2c1d6f 100644 --- a/vm/src/builtins/type.rs +++ b/vm/src/builtins/type.rs @@ -31,6 +31,7 @@ use crate::{ }; use indexmap::{IndexMap, map::Entry}; use itertools::Itertools; +use num_traits::ToPrimitive; use std::{borrow::Borrow, collections::HashSet, ops::Deref, pin::Pin, ptr::NonNull}; #[pyclass(module = false, name = "type", traverse = "manual")] @@ -231,6 +232,58 @@ impl PyType { linearise_mro(mros) } + /// Inherit SEQUENCE and MAPPING flags from base class (CPython: inherit_patma_flags) + fn inherit_patma_flags(slots: &mut PyTypeSlots, base: &PyRef) { + const COLLECTION_FLAGS: PyTypeFlags = PyTypeFlags::from_bits_truncate( + PyTypeFlags::SEQUENCE.bits() | PyTypeFlags::MAPPING.bits(), + ); + if !slots.flags.intersects(COLLECTION_FLAGS) { + slots.flags |= base.slots.flags & COLLECTION_FLAGS; + } + } + + /// Check for __abc_tpflags__ and set the appropriate flags + /// This checks in attrs and all base classes for __abc_tpflags__ + fn check_abc_tpflags( + slots: &mut PyTypeSlots, + attrs: &PyAttributes, + bases: &[PyRef], + ctx: &Context, + ) { + const COLLECTION_FLAGS: PyTypeFlags = PyTypeFlags::from_bits_truncate( + PyTypeFlags::SEQUENCE.bits() | PyTypeFlags::MAPPING.bits(), + ); + + // Don't override if flags are already set + if slots.flags.intersects(COLLECTION_FLAGS) { + return; + } + + // First check in our own attributes + let abc_tpflags_name = ctx.intern_str("__abc_tpflags__"); + if let Some(abc_tpflags_obj) = attrs.get(abc_tpflags_name) { + if let Some(int_obj) = abc_tpflags_obj.downcast_ref::() { + let flags_val = int_obj.as_bigint().to_i64().unwrap_or(0); + let abc_flags = PyTypeFlags::from_bits_truncate(flags_val as u64); + slots.flags |= abc_flags & COLLECTION_FLAGS; + return; + } + } + + // Then check in base classes + for base in bases { + if let Some(abc_tpflags_obj) = base.find_name_in_mro(abc_tpflags_name) { + if let Some(int_obj) = abc_tpflags_obj.downcast_ref::() + { + let flags_val = int_obj.as_bigint().to_i64().unwrap_or(0); + let abc_flags = PyTypeFlags::from_bits_truncate(flags_val as u64); + slots.flags |= abc_flags & COLLECTION_FLAGS; + return; + } + } + } + } + #[allow(clippy::too_many_arguments)] fn new_heap_inner( base: PyRef, @@ -246,6 +299,13 @@ impl PyType { if base.slots.flags.has_feature(PyTypeFlags::HAS_DICT) { slots.flags |= PyTypeFlags::HAS_DICT } + + // Inherit SEQUENCE and MAPPING flags from base class + Self::inherit_patma_flags(&mut slots, &base); + + // Check for __abc_tpflags__ from ABCMeta (for collections.abc.Sequence, Mapping, etc.) + Self::check_abc_tpflags(&mut slots, &attrs, &bases, ctx); + if slots.basicsize == 0 { slots.basicsize = base.slots.basicsize; } @@ -297,6 +357,10 @@ impl PyType { if base.slots.flags.has_feature(PyTypeFlags::HAS_DICT) { slots.flags |= PyTypeFlags::HAS_DICT } + + // Inherit SEQUENCE and MAPPING flags from base class + Self::inherit_patma_flags(&mut slots, &base); + if slots.basicsize == 0 { slots.basicsize = base.slots.basicsize; } diff --git a/vm/src/types/slot.rs b/vm/src/types/slot.rs index 54ad667b8c..182f902bc7 100644 --- a/vm/src/types/slot.rs +++ b/vm/src/types/slot.rs @@ -123,6 +123,8 @@ bitflags! { #[non_exhaustive] pub struct PyTypeFlags: u64 { const MANAGED_DICT = 1 << 4; + const SEQUENCE = 1 << 5; + const MAPPING = 1 << 6; const IMMUTABLETYPE = 1 << 8; const HEAPTYPE = 1 << 9; const BASETYPE = 1 << 10;