8000 Merge pull request #5342 from youknowone/getstate · RustPython/RustPython@adbadfc · GitHub
[go: up one dir, main page]

Skip to content

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Appearance settings

Commit adbadfc

Browse files
authored
Merge pull request #5342 from youknowone/getstate
object.__getstate__
2 parents 866e7cf + 1333688 commit adbadfc

File tree

14 files changed

+195
-34
lines changed

14 files changed

+195
-34
lines changed

Lib/copyreg.py

Lines changed: 16 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -25,16 +25,16 @@ def constructor(object):
2525

2626
# Example: provide pickling support for complex numbers.
2727

28-
try:
29-
complex
30-
except NameError:
31-
pass
32-
else:
28+
def pickle_complex(c):
29+
return complex, (c.real, c.imag)
3330

34-
def pickle_complex(c):
35-
return complex, (c.real, c.imag)
31+
pickle(complex, pickle_complex, complex)
3632

37-
pickle(complex, pickle_complex, complex)
33+
def pickle_union(obj):
34+
import functools, operator
35+
return functools.reduce, (operator.or_, obj.__args__)
36+
37+
pickle(type(int | str), pickle_union)
3838

3939
# Support for pickling new-style objects
4040

@@ -48,6 +48,7 @@ def _reconstructor(cls, base, state):
4848
return obj
4949

5050
_HEAPTYPE = 1<<9
51+
_new_type = type(int.__new__)
5152

5253
# Python code for object.__reduce_ex__ for protocols 0 and 1
5354

@@ -57,6 +58,9 @@ def _reduce_ex(self, proto):
5758
for base in cls.__mro__:
5859
if hasattr(base, '__flags__') and not base.__flags__ & _HEAPTYPE:
5960
break
61+
new = base.__new__
62+
if isinstance(new, _new_type) and new.__self__ is base:
63+
break
6064
else:
6165
base = object # not really reachable
6266
if base is object:
@@ -79,6 +83,10 @@ def _reduce_ex(self, proto):
7983
except AttributeError:
8084
dict = None
8185
else:
86+
if (type(self).__getstate__ is object.__getstate__ and
87+
getattr(self, "__slots__", None)):
88+
raise TypeError("a class that defines __slots__ without "
89+
"defining __getstate__ cannot be pickled")
8290
dict = getstate()
8391
if dict:
8492
return _reconstructor, args, dict

Lib/test/test_descr.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5652,8 +5652,7 @@ def __repr__(self):
56525652
objcopy2 = deepcopy(objcopy)
56535653
self._assert_is_copy(obj, objcopy2)
56545654

5655-
# TODO: RUSTPYTHON
5656-
@unittest.expectedFailure
5655+
@unittest.skip("TODO: RUSTPYTHON")
56575656
def test_issue24097(self):
56585657
# Slot name is freed inside __getattr__ and is later used.
56595658
class S(str): # Not interned

Lib/test/test_ordered_dict.py

Lines changed: 0 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -292,8 +292,6 @@ def test_equality(self):
292292
# different length implied inequality
293293
self.assertNotEqual(od1, OrderedDict(pairs[:-1]))
294294

295-
# TODO: RUSTPYTHON
296-
@unittest.expectedFailure
297295
def test_copying(self):
298296
OrderedDict = self.OrderedDict
299297
# Check that ordered dicts are copyable, deepcopyable, picklable,
@@ -337,8 +335,6 @@ def check(dup):
337335
check(update_test)
338336
check(OrderedDict(od))
339337

340-
@unittest.expectedFailure
341-
# TODO: RUSTPYTHON
342338
def test_yaml_linkage(self):
343339
OrderedDict = self.OrderedDict
344340
# Verify that __reduce__ is setup in a way that supports PyYAML's dump() feature.
@@ -349,8 +345,6 @@ def test_yaml_linkage(self):
349345
# '!!python/object/apply:__main__.OrderedDict\n- - [a, 1]\n - [b, 2]\n'
350346
self.assertTrue(all(type(pair)==list for pair in od.__reduce__()[1]))
351347

352-
# TODO: RUSTPYTHON
353-
@unittest.expectedFailure
354348
def test_reduce_not_too_fat(self):
355349
OrderedDict = self.OrderedDict
356350
# do not save instance dictionary if not needed
@@ -362,8 +356,6 @@ def test_reduce_not_too_fat(self):
362356
self.assertEqual(od.__dict__['x'], 10)
363357
self.assertEqual(od.__reduce__()[2], {'x': 10})
364358

365-
# TODO: RUSTPYTHON
366-
@unittest.expectedFailure
367359
def test_pickle_recursive(self):
368360
OrderedDict = self.OrderedDict
369361
od = OrderedDict()
@@ -888,17 +880,13 @@ class CPythonOrderedDictSubclassTests(CPythonOrderedDictTests):
888880
class OrderedDict(c_coll.OrderedDict):
889881
pass
890882

891-
# TODO: RUSTPYTHON
892-
@unittest.expectedFailure
893883
class PurePythonOrderedDictWithSlotsCopyingTests(unittest.TestCase):
894884

895885
module = py_coll
896886
class OrderedDict(py_coll.OrderedDict):
897887
__slots__ = ('x', 'y')
898888
test_copying = OrderedDictTests.test_copying
899889

900-
# TODO: RUSTPYTHON
901-
@unittest.expectedFailure
902890
@unittest.skipUnless(c_coll, 'requires the C version of the collections module')
903891
class CPythonOrderedDictWithSlotsCopyingTests(unittest.TestCase):
904892

Lib/test/test_types.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -869,8 +869,6 @@ def eq(actual, expected, typed=True):
869869
eq(x[NT], int | NT | bytes)
870870
eq(x[S], int | S | bytes)
871871

872-
# TODO: RUSTPYTHON
873-
@unittest.expectedFailure
874872
def test_union_pickle(self):
875873
orig = list[T] | int
876874
for proto in range(pickle.HIGHEST_PROTOCOL + 1):
@@ -880,8 +878,6 @@ def test_union_pickle(self):
880878
self.assertEqual(loaded.__args__, orig.__args__)
881879
self.assertEqual(loaded.__parameters__, orig.__parameters__)
882880

883-
# TODO: RUSTPYTHON
884-
@unittest.expectedFailure
885881
def test_union_copy(self):
886882
orig = list[T] | int
887883
for copied in (copy.copy(orig), copy.deepcopy(orig)):

Lib/test/test_xml_dom_minicompat.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -82,8 +82,6 @@ def test_nodelist___radd__(self):
8282
node_list = [1, 2] + NodeList([3, 4])
8383
self.assertEqual(node_list, NodeList([1, 2, 3, 4]))
8484

85-
# TODO: RUSTPYTHON
86-
@unittest.expectedFailure
8785
def test_nodelist_pickle_roundtrip(self):
8886
# Test pickling and unpickling of a NodeList.
8987

derive-impl/src/pyclass.rs

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -752,6 +752,7 @@ where
752752
let item_meta = MethodItemMeta::from_attr(ident.clone(), &item_attr)?;
753753

754754
let py_name = item_meta.method_name()?;
755+
let raw = item_meta.raw()?;
755756
let sig_doc = text_signature(func.sig(), &py_name);
756757

757758
let doc = args.attrs.doc().map(|doc| format_doc(&sig_doc, &doc));
@@ -760,6 +761,7 @@ where
760761
cfgs: args.cfgs.to_vec(),
761762
ident: ident.to_owned(),
762763
doc,
764+
raw,
763765
attr_name: self.inner.attr_name,
764766
});
765767
Ok(())
@@ -954,6 +956,7 @@ struct MethodNurseryItem {
954956
py_name: String,
955957
cfgs: Vec<Attribute>,
956958
ident: Ident,
959+
raw: bool,
957960
doc: Option<String>,
958961
attr_name: AttrName,
959962
}
@@ -1005,9 +1008,14 @@ impl ToTokens for MethodNursery {
10051008
// } else {
10061009
// quote_spanned! { ident.span() => #py_name }
10071010
// };
1011+
let method_new = if item.raw {
1012+
quote!(new_raw_const)
1013+
} else {
1014+
quote!(new_const)
1015+
};
10081016
inner_tokens.extend(quote! [
10091017
#(#cfgs)*
1010-
rustpython_vm::function::PyMethodDef::new_const(
1018+
rustpython_vm::function::PyMethodDef::#method_new(
10111019
#py_name,
10121020
Self::#ident,
10131021
#flags,
@@ -1203,7 +1211,7 @@ impl ToTokens for MemberNursery {
12031211
struct MethodItemMeta(ItemMetaInner);
12041212

12051213
impl ItemMeta for MethodItemMeta {
1206-
const ALLOWED_NAMES: &'static [&'static str] = &["name", "magic"];
1214+
const ALLOWED_NAMES: &'static [&'static str] = &["name", "magic", "raw"];
12071215

12081216
fn from_inner(inner: ItemMetaInner) -> Self {
12091217
Self(inner)
@@ -1214,6 +1222,9 @@ impl ItemMeta for MethodItemMeta {
12141222
}
12151223

12161224
impl MethodItemMeta {
1225+
fn raw(&self) -> Result<bool> {
1226+
self.inner()._bool("raw")
1227+
}
12171228
fn method_name(&self) -> Result<String> {
12181229
let inner = self.inner();
12191230
let name = inner._optional_str("name")?;

vm/src/builtins/builtin_func.rs

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,14 @@ impl PyNativeFunction {
4949
)
5050
}
5151

52+
// PyCFunction_GET_SELF
53+
pub fn get_self(&self) -> Option<&PyObjectRef> {
54+
if self.value.flags.contains(PyMethodFlags::STATIC) {
55+
return None;
56+
}
57+
self.zelf.as_ref()
58+
}
59+
5260
pub fn as_func(&self) -> &'static dyn PyNativeFn {
5361
self.value.func
5462
}

vm/src/builtins/list.rs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -174,8 +174,9 @@ impl PyList {
174174
Self::new_ref(self.borrow_vec().to_vec(), &vm.ctx)
175175
}
176176

177+
#[allow(clippy::len_without_is_empty)]
177178
#[pymethod(magic)]
178-
fn len(&self) -> usize {
179+
pub fn len(&self) -> usize {
179180
self.borrow_vec().len()
180181
}
181182

vm/src/builtins/object.rs

Lines changed: 130 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ use crate::common::hash::PyHash;
33
use crate::types::PyTypeFlags;
44
use crate::{
55
class::PyClassImpl,
6+
convert::ToPyResult,
67
function::{Either, FuncArgs, PyArithmeticValue, PyComparisonValue, PySetterValue},
78
types::{Constructor, PyComparisonOp},
89
AsObject, Context, Py, PyObject, PyObjectRef, PyPayload, PyResult, VirtualMachine,
@@ -73,8 +74,137 @@ impl Constructor for PyBaseObject {
7374
}
7475
}
7576

77+
// TODO: implement _PyType_GetSlotNames properly
78+
fn type_slot_names(typ: &Py<PyType>, vm: &VirtualMachine) -> PyResult<Option<super::PyListRef>> {
79+
// let attributes = typ.attributes.read();
80+
// if let Some(slot_names) = attributes.get(identifier!(vm.ctx, __slotnames__)) {
81+
// return match_class!(match slot_names.clone() {
82+
// l @ super::PyList => Ok(Some(l)),
83+
// _n @ super::PyNone => Ok(None),
84+
// _ => Err(vm.new_type_error(format!(
85+
// "{:.200}.__slotnames__ should be a list or None, not {:.200}",
86+
// typ.name(),
87+
// slot_names.class().name()
88+
// ))),
89+
// });
90+
// }
91+
92+
let copyreg = vm.import("copyreg", 0)?;
93+
let copyreg_slotnames = copyreg.get_attr("_slotnames", vm)?;
94+
let slot_names = copyreg_slotnames.call((typ.to_owned(),), vm)?;
95+
let result = match_class!(match slot_names {
96+
l @ super::PyList => Some(l),
97+
_n @ super::PyNone => None,
98+
_ =>
99+
return Err(
100+
vm.new_type_error("copyreg._slotnames didn't return a list or None".to_owned())
101+
),
102+
});
103+
Ok(result)
104+
}
105+
106+
// object_getstate_default in CPython
107+
fn object_getstate_default(obj: &PyObject, required: bool, vm: &VirtualMachine) -> PyResult {
108+
// TODO: itemsize
109+
// if required && obj.class().slots.itemsize > 0 {
110+
// return vm.new_type_error(format!(
111+
// "cannot pickle {:.200} objects",
112+
// obj.class().name()
113+
// ));
114+
// }
115+
116+
let state = if obj.dict().map_or(true, |d| d.is_empty()) {
117+
vm.ctx.none()
118+
} else {
119+
// let state = object_get_dict(obj.clone(), obj.ctx()).unwrap();
120+
let Some(state) = obj.dict() else {
121+
return Ok(vm.ctx.none());
122+
};
123+
state.into()
124+
};
125+
126+
let slot_names = t D306 ype_slot_names(obj.class(), vm)
127+
.map_err(|_| vm.new_type_error("cannot pickle object".to_owned()))?;
128+
129+
if required {
130+
let mut basicsize = obj.class().slots.basicsize;
131+
// if obj.class().slots.dictoffset > 0
132+
// && !obj.class().slots.flags.has_feature(PyTypeFlags::MANAGED_DICT)
133+
// {
134+
// basicsize += std::mem::size_of::<PyObjectRef>();
135+
// }
136+
// if obj.class().slots.weaklistoffset > 0 {
137+
// basicsize += std::mem::size_of::<PyObjectRef>();
138+
// }
139+
if let Some(ref slot_names) = slot_names {
140+
basicsize += std::mem::size_of::<PyObjectRef>() * slot_names.len();
141+
}
142+
if obj.class().slots.basicsize > basicsize {
143+
return Err(
144+
vm.new_type_error(format!("cannot pickle {:.200} object", obj.class().name()))
145+
);
146+
}
147+
}
148+
149+
if let Some(slot_names) = slot_names {
150+
let slot_names_len = slot_names.len();
151+
if slot_names_len > 0 {
152+
let slots = vm.ctx.new_dict();
153+
for i in 0..slot_names_len {
154+
let borrowed_names = slot_names.borrow_vec();
155+
let name = borrowed_names[i].downcast_ref::<PyStr>().unwrap();
156+
let Ok(value) = obj.get_attr(name, vm) else {
157+
continue;
158+
};
159+
slots.set_item(name.as_str(), value, vm).unwrap();
160+
}
161+
162+
if slots.len() > 0 {
163+
return (state, slots).to_pyresult(vm);
164+
}
165+
}
166+
}
167+
168+
Ok(state)
169+
}
170+
171+
// object_getstate in CPython
172+
// fn object_getstate(
173+
// obj: &PyObject,
174+
// required: bool,
175+
// vm: &VirtualMachine,
176+
// ) -> PyResult {
177+
// let getstate = obj.get_attr(identifier!(vm, __getstate__), vm)?;
178+
// if vm.is_none(&getstate) {
179+
// return Ok(None);
180+
// }
181+
182+
// let getstate = match getstate.downcast_exact::<PyNativeFunction>(vm) {
183+
// Ok(getstate)
184+
// if getstate
185+
// .get_self()
186+
// .map_or(false, |self_obj| self_obj.is(obj))
187+
// && std::ptr::addr_eq(
188+
// getstate.as_func() as *const _,
189+
// &PyBaseObject::__getstate__ as &dyn crate::function::PyNativeFn as *const _,
190+
// ) =>
191+
// {
192+
// return object_getstate_default(obj, required, vm);
193+
// }
194+
// Ok(getstate) => getstate.into_pyref().into(),
195+
// Err(getstate) => getstate,
196+
// };
197+
// getstate.call((), vm)
198+
// }
199+
76200
#[pyclass(with(Constructor), flags(BASETYPE))]
77201
impl PyBaseObject {
202+
#[pymethod(raw)]
203+
fn __getstate__(vm: &VirtualMachine, args: FuncArgs) -> PyResult {
204+
let (zelf,): (PyObjectRef,) = args.bind(vm)?;
205+
object_getstate_default(&zelf, false, vm)
206+
}
207+
78208
#[pyslot]
79209
fn slot_richcompare(
80210
zelf: &PyObject,

vm/src/function/builtin.rs

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,11 @@ pub const fn static_func<Kind, F: IntoPyNativeFn<Kind>>(f: F) -> &'static dyn Py
8181
zst_ref_out_of_thin_air(into_func(f))
8282
}
8383

84+
#[inline(always)]
85+
pub const fn static_raw_func<F: PyNativeFn>(f: F) -> &'static dyn PyNativeFn {
86+
zst_ref_out_of_thin_air(f)
87+
}
88+
8489
// TODO: once higher-rank trait bounds are stabilized, remove the `Kind` type
8590
// parameter and impl for F where F: for<T, R, VM> PyNativeFnInternal<T, R, VM>
8691
impl<F, T, R, VM> IntoPyNativeFn<(T, R, VM)> for F

0 commit comments

Comments
 (0)
0