diff --git a/crates/vm/src/builtins/descriptor.rs b/crates/vm/src/builtins/descriptor.rs index c251ac45ee..fcf40d0082 100644 --- a/crates/vm/src/builtins/descriptor.rs +++ b/crates/vm/src/builtins/descriptor.rs @@ -4,10 +4,11 @@ use crate::{ builtins::{PyTypeRef, builtin_func::PyNativeMethod, type_}, class::PyClassImpl, common::hash::PyHash, + convert::ToPyResult, function::{FuncArgs, PyMethodDef, PyMethodFlags, PySetterValue}, types::{ - Callable, Comparable, GetDescriptor, HashFunc, Hashable, InitFunc, PyComparisonOp, - Representable, StringifyFunc, + Callable, Comparable, GetDescriptor, HashFunc, Hashable, InitFunc, IterFunc, IterNextFunc, + PyComparisonOp, Representable, StringifyFunc, }, }; use rustpython_common::lock::PyRwLock; @@ -399,6 +400,8 @@ pub enum SlotFunc { Init(InitFunc), Hash(HashFunc), Repr(StringifyFunc), + Iter(IterFunc), + IterNext(IterNextFunc), } impl std::fmt::Debug for SlotFunc { @@ -407,6 +410,8 @@ impl std::fmt::Debug for SlotFunc { SlotFunc::Init(_) => write!(f, "SlotFunc::Init(...)"), SlotFunc::Hash(_) => write!(f, "SlotFunc::Hash(...)"), SlotFunc::Repr(_) => write!(f, "SlotFunc::Repr(...)"), + SlotFunc::Iter(_) => write!(f, "SlotFunc::Iter(...)"), + SlotFunc::IterNext(_) => write!(f, "SlotFunc::IterNext(...)"), } } } @@ -437,6 +442,22 @@ impl SlotFunc { let s = func(&obj, vm)?; Ok(s.into()) } + SlotFunc::Iter(func) => { + if !args.args.is_empty() || !args.kwargs.is_empty() { + return Err( + vm.new_type_error("__iter__() takes no arguments (1 given)".to_owned()) + ); + } + func(obj, vm) + } + SlotFunc::IterNext(func) => { + if !args.args.is_empty() || !args.kwargs.is_empty() { + return Err( + vm.new_type_error("__next__() takes no arguments (1 given)".to_owned()) + ); + } + func(&obj, vm).to_pyresult(vm) + } } } } diff --git a/crates/vm/src/class.rs b/crates/vm/src/class.rs index 236967dd36..1addf00497 100644 --- a/crates/vm/src/class.rs +++ b/crates/vm/src/class.rs @@ -139,52 +139,50 @@ pub trait PyClassImpl: PyClassDef { } } - // Add __init__ slot wrapper if slot exists and not already in dict - if let Some(init_func) = class.slots.init.load() { - let init_name = identifier!(ctx, __init__); - if !class.attributes.read().contains_key(init_name) { - let wrapper = PySlotWrapper { - typ: class, - name: ctx.intern_str("__init__"), - wrapped: SlotFunc::Init(init_func), - doc: Some("Initialize self. See help(type(self)) for accurate signature."), - }; - class.set_attr(init_name, wrapper.into_ref(ctx).into()); - } - } - - // Add __hash__ slot wrapper if slot exists and not already in dict - // Note: hash_not_implemented is handled separately (sets __hash__ = None) - if let Some(hash_func) = class.slots.hash.load() - && hash_func as usize != hash_not_implemented as usize - { - let hash_name = identifier!(ctx, __hash__); - if !class.attributes.read().contains_key(hash_name) { - let wrapper = PySlotWrapper { - typ: class, - name: ctx.intern_str("__hash__"), - wrapped: SlotFunc::Hash(hash_func), - doc: Some("Return hash(self)."), - }; - class.set_attr(hash_name, wrapper.into_ref(ctx).into()); - } - } - - if class.slots.hash.load().map_or(0, |h| h as usize) == hash_not_implemented as usize { - class.set_attr(ctx.names.__hash__, ctx.none.clone().into()); + // Add slot wrappers for slots that exist and are not already in dict + // This mirrors CPython's add_operators() in typeobject.c + macro_rules! add_slot_wrapper { + ($slot:ident, $name:ident, $variant:ident, $doc:expr) => { + if let Some(func) = class.slots.$slot.load() { + let attr_name = identifier!(ctx, $name); + if !class.attributes.read().contains_key(attr_name) { + let wrapper = PySlotWrapper { + typ: class, + name: ctx.intern_str(stringify!($name)), + wrapped: SlotFunc::$variant(func), + doc: Some($doc), + }; + class.set_attr(attr_name, wrapper.into_ref(ctx).into()); + } + } + }; } - // Add __repr__ slot wrapper if slot exists and not already in dict - if let Some(repr_func) = class.slots.repr.load() { - let repr_name = identifier!(ctx, __repr__); - if !class.attributes.read().contains_key(repr_name) { - let wrapper = PySlotWrapper { - typ: class, - name: ctx.intern_str("__repr__"), - wrapped: SlotFunc::Repr(repr_func), - doc: Some("Return repr(self)."), - }; - class.set_attr(repr_name, wrapper.into_ref(ctx).into()); + add_slot_wrapper!( + init, + __init__, + Init, + "Initialize self. See help(type(self)) for accurate signature." + ); + add_slot_wrapper!(repr, __repr__, Repr, "Return repr(self)."); + add_slot_wrapper!(iter, __iter__, Iter, "Implement iter(self)."); + add_slot_wrapper!(iternext, __next__, IterNext, "Implement next(self)."); + + // __hash__ needs special handling: hash_not_implemented sets __hash__ = None + if let Some(hash_func) = class.slots.hash.load() { + if hash_func as usize == hash_not_implemented as usize { + class.set_attr(ctx.names.__hash__, ctx.none.clone().into()); + } else { + let hash_name = identifier!(ctx, __hash__); + if !class.attributes.read().contains_key(hash_name) { + let wrapper = PySlotWrapper { + typ: class, + name: ctx.intern_str("__hash__"), + wrapped: SlotFunc::Hash(hash_func), + doc: Some("Return hash(self)."), + }; + class.set_attr(hash_name, wrapper.into_ref(ctx).into()); + } } } diff --git a/crates/vm/src/types/slot.rs b/crates/vm/src/types/slot.rs index 6e98da173a..26c059067e 100644 --- a/crates/vm/src/types/slot.rs +++ b/crates/vm/src/types/slot.rs @@ -6,7 +6,7 @@ use crate::{ builtins::{PyInt, PyStr, PyStrInterned, PyStrRef, PyType, PyTypeRef, type_::PointerSlot}, bytecode::ComparisonOperator, common::hash::PyHash, - convert::{ToPyObject, ToPyResult}, + convert::ToPyObject, function::{ Either, FromArgs, FuncArgs, OptionalArg, PyComparisonValue, PyMethodDef, PySetterValue, }, @@ -1435,10 +1435,7 @@ pub trait Iterable: PyPayload { Self::iter(zelf, vm) } - #[pymethod] - fn __iter__(zelf: PyObjectRef, vm: &VirtualMachine) -> PyResult { - Self::slot_iter(zelf, vm) - } + // __iter__ is exposed via SlotFunc::Iter wrapper in extend_class() fn iter(zelf: PyRef, vm: &VirtualMachine) -> PyResult; @@ -1458,11 +1455,7 @@ pub trait IterNext: PyPayload + Iterable { fn next(zelf: &Py, vm: &VirtualMachine) -> PyResult; - #[inline] - #[pymethod] - fn __next__(zelf: PyObjectRef, vm: &VirtualMachine) -> PyResult { - Self::slot_iternext(&zelf, vm).to_pyresult(vm) - } + // __next__ is exposed via SlotFunc::IterNext wrapper in extend_class() } pub trait SelfIter: PyPayload {} @@ -1477,9 +1470,7 @@ where unreachable!("slot must be overridden for {}", repr.as_str()); } - fn __iter__(zelf: PyObjectRef, vm: &VirtualMachine) -> PyResult { - self_iter(zelf, vm) - } + // __iter__ is exposed via SlotFunc::Iter wrapper in extend_class() #[cold] fn iter(_zelf: PyRef, _vm: &VirtualMachine) -> PyResult {