8000 object.__getstate__ by youknowone · Pull Request #5342 · RustPython/RustPython · GitHub
[go: up one dir, main page]

Skip to content

object.__getstate__ #5342

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
Jun 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 16 additions & 8 deletions Lib/copyreg.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,16 +25,16 @@ def constructor(object):

# Example: provide pickling support for complex numbers.

try:
complex
except NameError:
pass
else:
def pickle_complex(c):
return complex, (c.real, c.imag)

def pickle_complex(c):
return complex, (c.real, c.imag)
pickle(complex, pickle_complex, complex)

pickle(complex, pickle_complex, complex)
def pickle_union(obj):
import functools, operator
return functools.reduce, (operator.or_, obj.__args__)

pickle(type(int | str), pickle_union)

# Support for pickling new-style objects

Expand All @@ -48,6 +48,7 @@ def _reconstructor(cls, base, state):
return obj

_HEAPTYPE = 1<<9
_new_type = type(int.__new__)

# Python code for object.__reduce_ex__ for protocols 0 and 1

Expand All @@ -57,6 +58,9 @@ def _reduce_ex(self, proto):
for base in cls.__mro__:
if hasattr(base, '__flags__') and not base.__flags__ & _HEAPTYPE:
break
new = base.__new__
if isinstance(new, _new_type) and new.__self__ is base:
break
else:
base = object # not really reachable
if base is object:
Expand All @@ -79,6 +83,10 @@ def _reduce_ex(self, proto):
except AttributeError:
dict = None
else:
if (type(self).__getstate__ is object.__getstate__ and
getattr(self, "__slots__", None)):
raise TypeError("a class that defines __slots__ without "
"defining __getstate__ cannot be pickled")
dict = getstate()
if dict:
return _reconstructor, args, dict
Expand Down
3 changes: 1 addition & 2 deletions Lib/test/test_descr.py
Original file line number Diff line number Diff line change
Expand Up @@ -5652,8 +5652,7 @@ def __repr__(self):
objcopy2 = deepcopy(objcopy)
self._assert_is_copy(obj, objcopy2)

# TODO: RUSTPYTHON
@unittest.expectedFailure
@unittest.skip("TODO: RUSTPYTHON")
def test_issue24097(self):
# Slot name is freed inside __getattr__ and is later used.
class S(str): # Not interned
Expand Down
12 changes: 0 additions & 12 deletions Lib/test/test_ordered_dict.py
Original file line number Diff line number Diff line change
Expand Up @@ -292,8 +292,6 @@ def test_equality(self):
# different length implied inequality
self.assertNotEqual(od1, OrderedDict(pairs[:-1]))

# TODO: RUSTPYTHON
@unittest.expectedFailure
def test_copying(self):
OrderedDict = self.OrderedDict
# Check that ordered dicts are copyable, deepcopyable, picklable,
Expand Down Expand Up @@ -337,8 +335,6 @@ def check(dup):
check(update_test)
check(OrderedDict(od))

@unittest.expectedFailure
# TODO: RUSTPYTHON
def test_yaml_linkage(self):
OrderedDict = self.OrderedDict
# Verify that __reduce__ is setup in a way that supports PyYAML's dump() feature.
Expand All @@ -349,8 +345,6 @@ def test_yaml_linkage(self):
# '!!python/object/apply:__main__.OrderedDict\n- - [a, 1]\n - [b, 2]\n'
self.assertTrue(all(type(pair)==list for pair in od.__reduce__()[1]))

# TODO: RUSTPYTHON
@unittest.expectedFailure
def test_reduce_not_too_fat(self):
OrderedDict = self.OrderedDict
# do not save instance dictionary if not needed
Expand All @@ -362,8 +356,6 @@ def test_reduce_not_too_fat(self):
self.assertEqual(od.__dict__['x'], 10)
self.assertEqual(od.__reduce__()[2], {'x': 10})

# TODO: RUSTPYTHON
@unittest.expectedFailure
def test_pickle_recursive(self):
OrderedDict = self.OrderedDict
od = OrderedDict()
Expand Down Expand Up @@ -888,17 +880,13 @@ class CPythonOrderedDictSubclassTests(CPythonOrderedDictTests):
class OrderedDict(c_coll.OrderedDict):
pass

# TODO: RUSTPYTHON
@unittest.expectedFailure
class PurePythonOrderedDictWithSlotsCopyingTests(unittest.TestCase):

module = py_coll
class OrderedDict(py_coll.OrderedDict):
__slots__ = ('x', 'y')
test_copying = OrderedDictTests.test_copying

# TODO: RUSTPYTHON
@unittest.expectedFailure
@unittest.skipUnless(c_coll, 'requires the C version of the collections module')
class CPythonOrderedDictWithSlotsCopyingTests(unittest.TestCase):

Expand Down
4 changes: 0 additions & 4 deletions Lib/test/test_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -869,8 +869,6 @@ def eq(actual, expected, typed=True):
eq(x[NT], int | NT | bytes)
eq(x[S], int | S | bytes)

# TODO: RUSTPYTHON
@unittest.expectedFailure
6D47 def test_union_pickle(self):
orig = list[T] | int
for proto in range(pickle.HIGHEST_PROTOCOL + 1):
Expand All @@ -880,8 +878,6 @@ def test_union_pickle(self):
self.assertEqual(loaded.__args__, orig.__args__)
self.assertEqual(loaded.__parameters__, orig.__parameters__)

# TODO: RUSTPYTHON
@unittest.expectedFailure
def test_union_copy(self):
orig = list[T] | int
for copied in (copy.copy(orig), copy.deepcopy(orig)):
Expand Down
2 changes: 0 additions & 2 deletions Lib/test/test_xml_dom_minicompat.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,8 +82,6 @@ def test_nodelist___radd__(self):
node_list = [1, 2] + NodeList([3, 4])
self.assertEqual(node_list, NodeList([1, 2, 3, 4]))

# TODO: RUSTPYTHON
@unittest.expectedFailure
def test_nodelist_pickle_roundtrip(self):
# Test pickling and unpickling of a NodeList.

Expand Down
15 changes: 13 additions & 2 deletions derive-impl/src/pyclass.rs
Original file line number Diff line number Diff line change
Expand Up @@ -752,6 +752,7 @@ where
let item_meta = MethodItemMeta::from_attr(ident.clone(), &item_attr)?;

let py_name = item_meta.method_name()?;
let raw = item_meta.raw()?;
let sig_doc = text_signature(func.sig(), &py_name);

let doc = args.attrs.doc().map(|doc| format_doc(&sig_doc, &doc));
Expand All @@ -760,6 +761,7 @@ where
cfgs: args.cfgs.to_vec(),
ident: ident.to_owned(),
doc,
raw,
attr_name: self.inner.attr_name,
});
Ok(())
Expand Down Expand Up @@ -954,6 +956,7 @@ struct MethodNurseryItem {
py_name: String,
cfgs: Vec<Attribute>,
ident: Ident,
raw: bool,
doc: Option<String>,
attr_name: AttrName,
}
Expand Down Expand Up @@ -1005,9 +1008,14 @@ impl ToTokens for MethodNursery {
// } else {
// quote_spanned! { ident.span() => #py_name }
// };
let method_new = if item.raw {
quote!(new_raw_const)
} else {
quote!(new_const)
};
inner_tokens.extend(quote! [
#(#cfgs)*
rustpython_vm::function::PyMethodDef::new_const(
rustpython_vm::function::PyMethodDef::#method_new(
#py_name,
Self::#ident,
#flags,
Expand Down Expand Up @@ -1203,7 +1211,7 @@ impl ToTokens for MemberNursery {
struct MethodItemMeta(ItemMetaInner);

impl ItemMeta for MethodItemMeta {
const ALLOWED_NAMES: &'static [&'static str] = &["name", "magic"];
const ALLOWED_NAMES: &'static [&'static str] = &["name", "magic", "raw"];

fn from_inner(inner: ItemMetaInner) -> Self {
Self(inner)
Expand All @@ -1214,6 +1222,9 @@ impl ItemMeta for MethodItemMeta {
}

impl MethodItemMeta {
fn raw(&self) -> Result<bool> {
self.inner()._bool("raw")
}
fn method_name(&self) -> Result<String> {
let inner = self.inner();
let name = inner._optional_str("name")?;
Expand Down
8 changes: 8 additions & 0 deletions vm/src/builtins/builtin_func.rs
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,14 @@ impl PyNativeFunction {
)
}

// PyCFunction_GET_SELF
pub fn get_self(&self) -> Option<&PyObjectRef> {
if self.value.flags.contains(PyMethodFlags::STATIC) {
return None;
}
self.zelf.as_ref()
}

pub fn as_func(&self) -> &'static dyn PyNativeFn {
self.value.func
}
Expand Down
3 changes: 2 additions & 1 deletion vm/src/builtins/list.rs
Original file line number Diff line number Diff line change
Expand Up @@ -174,8 +174,9 @@ impl PyList {
Self::new_ref(self.borrow_vec().to_vec(), &vm.ctx)
}

#[allow(clippy::len_without_is_empty)]
#[pymethod(magic)]
fn len(&self) -> usize {
pub fn len(&self) -> usize {
self.borrow_vec().len()
}

Expand Down
130 changes: 130 additions & 0 deletions vm/src/builtins/object.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ use crate::common::hash::PyHash;
use crate::types::PyTypeFlags;
use crate::{
class::PyClassImpl,
convert::ToPyResult,
function::{Either, FuncArgs, PyArithmeticValue, PyComparisonValue, PySetterValue},
types::{Constructor, PyComparisonOp},
AsObject, Context, Py, PyObject, PyObjectRef, PyPayload, PyResult, VirtualMachine,
Expand Down Expand Up @@ -73,8 +74,137 @@ impl Constructor for PyBaseObject {
}
}

// TODO: implement _PyType_GetSlotNames properly
fn type_slot_names(typ: &Py<PyType>, vm: &VirtualMachine) -> PyResult<Option<super::PyListRef>> {
// let attributes = typ.attributes.read();
// if let Some(slot_names) = attributes.get(identifier!(vm.ctx, __slotnames__)) {
// return match_class!(match slot_names.clone() {
// l @ super::PyList => Ok(Some(l)),
// _n @ super::PyNone => Ok(None),
// _ => Err(vm.new_type_error(format!(
// "{:.200}.__slotnames__ should be a list or None, not {:.200}",
// typ.name(),
// slot_names.class().name()
// ))),
// });
// }

let copyreg = vm.import("copyreg", 0)?;
let copyreg_slotnames = copyreg.get_attr("_slotnames", vm)?;
let slot_names = copyreg_slotnames.call((typ.to_owned(),), vm)?;
let result = match_class!(match slot_names {
l @ super::PyList => Some(l),
_n @ super::PyNone => None,
_ =>
return Err(
vm.new_type_error("copyreg._slotnames didn't return a list or None".to_owned())
),
});
Ok(result)
}

// object_getstate_default in CPython
fn object_getstate_default(obj: &PyObject, required: bool, vm: &VirtualMachine) -> PyResult {
// TODO: itemsize
// if required && obj.class().slots.itemsize > 0 {
// return vm.new_type_error(format!(
// "cannot pickle {:.200} objects",
// obj.class().name()
// ));
// }

let state = if obj.dict().map_or(true, |d| d.is_empty()) {
vm.ctx.none()
} else {
// let state = object_get_dict(obj.clone(), obj.ctx()).unwrap();
let Some(state) = obj.dict() else {
return Ok(vm.ctx.none());
};
state.into()
};

let slot_names = type_slot_names(obj.class(), vm)
.map_err(|_| vm.new_type_error("cannot pickle object".to_owned()))?;

if required {
let mut basicsize = obj.class().slots.basicsize;
// if obj.class().slots.dictoffset > 0
// && !obj.class().slots.flags.has_feature(PyTypeFlags::MANAGED_DICT)
// {
// basicsize += std::mem::size_of::<PyObjectRef>();
// }
// if obj.class().slots.weaklistoffset > 0 {
// basicsize += std::mem::size_of::<PyObjectRef>();
// }
if let Some(ref slot_names) = slot_names {
basicsize += std::mem::size_of::<PyObjectRef>() * slot_names.len();
}
if obj.class().slots.basicsize > basicsize {
return Err(
vm.new_type_error(format!("cannot pickle {:.200} object", obj.class().name()))
);
}
}

if let Some(slot_names) = slot_names {
let slot_names_len = slot_names.len();
if slot_names_len > 0 {
let slots = vm.ctx.new_dict();
for i in 0..slot_names_len {
let borrowed_names = slot_names.borrow_vec();
let name = borrowed_names[i].downcast_ref::<PyStr>().unwrap();
let Ok(value) = obj.get_attr(name, vm) else {
continue;
};
slots.set_item(name.as_str(), value, vm).unwrap();
}

if slots.len() > 0 {
return (state, slots).to_pyresult(vm);
}
}
}

Ok(state)
}

// object_getstate in CPython
// fn object_getstate(
// obj: &PyObject,
// required: bool,
// vm: &VirtualMachine,
// ) -> PyResult {
// let getstate = obj.get_attr(identifier!(vm, __getstate__), vm)?;
// if vm.is_none(&getstate) {
// return Ok(None);
// }

// let getstate = match getstate.downcast_exact::<PyNativeFunction>(vm) {
// Ok(getstate)
// if getstate
// .get_self()
// .map_or(false, |self_obj| self_obj.is(obj))
// && std::ptr::addr_eq(
// getstate.as_func() as *const _,
// &PyBaseObject::__getstate__ as &dyn crate::function::PyNativeFn as *const _,
// ) =>
// {
// return object_getstate_default(obj, required, vm);
// }
// Ok(getstate) => getstate.into_pyref().into(),
// Err(getstate) => getstate,
// };
// getstate.call((), vm)
// }

#[pyclass(with(Constructor), flags(BASETYPE))]
impl PyBaseObject {
#[pymethod(raw)]
fn __getstate__(vm: &VirtualMachine, args: FuncArgs) -> PyResult {
let (zelf,): (PyObjectRef,) = args.bind(vm)?;
object_getstate_default(&zelf, false, vm)
}

#[pyslot]
fn slot_richcompare(
zelf: &PyObject,
Expand Down
5 changes: 5 additions & 0 deletions vm/src/function/builtin.rs
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,11 @@ pub const fn static_func<Kind, F: IntoPyNativeFn<Kind>>(f: F) -> &'static dyn Py
zst_ref_out_of_thin_air(into_func(f))
}

#[inline(always)]
pub const fn static_raw_func<F: PyNativeFn>(f: F) -> &'static dyn PyNativeFn {
zst_ref_out_of_thin_air(f)
}

// TODO: once higher-rank trait bounds are stabilized, remove the `Kind` type
// parameter and impl for F where F: for<T, R, VM> PyNativeFnInternal<T, R, VM>
impl<F, T, R, VM> IntoPyNativeFn<(T, R, VM)> for F
Expand Down
Loading
Loading
0