8000 Merge pull request #5342 from youknowone/getstate · RustPython/RustPython@adbadfc · GitHub
[go: up one dir, main page]

Skip to content

Commit adbadfc

Browse files
authored
Merge pull request #5342 from youknowone/getstate
object.__getstate__
2 parents 866e7cf + 1333688 commit adbadfc

File tree

14 files changed

+195
-34
lines changed

14 files changed

+195
-34
lines changed

Lib/copyreg.py

Lines changed: 16 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -25,16 +25,16 @@ def constructor(object):
2525

2626
# Example: provide pickling support for complex numbers.
2727

28-
try:
29-
complex
30-
except NameError:
31-
pass
32-
else:
28+
def pickle_complex(c):
29+
return complex, (c.real, c.imag)
3330

34-
def pickle_complex(c):
35-
return complex, (c.real, c.imag)
31+
pickle(complex, pickle_complex, complex)
3632

37-
pickle(complex, pickle_complex, complex)
33+
def pickle_union(obj):
34+
import functools, operator
35+
return functools.reduce, (operator.or_, obj.__args__)
36+
37+
pickle(type(int | str), pickle_union)
3838

3939
# Support for pickling new-style objects
4040

@@ -48,6 +48,7 @@ def _reconstructor(cls, base, state):
4848
return obj
4949

5050
_HEAPTYPE = 1<<9
51+
_new_type = type(int.__new__)
5152

5253
# Python code for object.__reduce_ex__ for protocols 0 and 1
5354

@@ -57,6 +58,9 @@ def _reduce_ex(self, proto):
5758
for base in cls.__mro__:
5859
if hasattr(base, '__flags__') and not base.__flags__ & _HEAPTYPE:
5960
break
61+
new = base.__new__
62+
if isinstance(new, _new_type) and new.__self__ is base:
63+
break
6064
else:
6165
base = object # not really reachable
6266
if base is object:
@@ -79,6 +83,10 @@ def _reduce_ex(self, proto):
7983
except AttributeError:
8084
dict = None
8185
else:
86+
if (type(self).__getstate__ is object.__getstate__ and
87+
getattr(self, "__slots__", None)):
88+
raise TypeError("a class that defines __slots__ without "
89+
"defining __getstate__ cannot be pickled")
8290
dict = getstate()
8391
if dict:
8492
return _reconstructor, args, dict

Lib/test/test_descr.py

-2Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5652,8 +5652,7 @@ def __repr__(self):
56525652
objcopy2 = deepcopy(objcopy)
56535653
self._assert_is_copy(obj, objcopy2)
56545654

5655-
# TODO: RUSTPYTHON
5656-
@unittest.expectedFailure
5655+
@unittest.skip("TODO: RUSTPYTHON")
56575656
def test_issue24097(self):
56585657
# Slot name is freed inside __getattr__ and is later used.
56595658
class S(str): # Not interned

Lib/test/test_ordered_dict.py

Lines changed: 0 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -292,8 +292,6 @@ def test_equality(self):
292292
# different length implied inequality
293293
self.assertNotEqual(od1, OrderedDict(pairs[:-1]))
294294

295-
# TODO: RUSTPYTHON
296-
@unittest.expectedFailure
297295
def test_copying(self):
298296
OrderedDict = self.OrderedDict
299297
# Check that ordered dicts are copyable, deepcopyable, picklable,
@@ -337,8 +335,6 @@ def check(dup):
337335
check(update_test)
338336
check(OrderedDict(od))
339337

340-
@unittest.expectedFailure
341-
# TODO: RUSTPYTHON
342338
def test_yaml_linkage(self):
343339
OrderedDict = self.OrderedDict
344340
# Verify that __reduce__ is setup in a way that supports PyYAML's dump() feature.
@@ -349,8 +345,6 @@ def test_yaml_linkage(self):
349345
# '!!python/object/apply:__main__.OrderedDict\n- - [a, 1]\n - [b, 2]\n'
350346
self.assertTrue(all(type(pair)==list for pair in od.__reduce__()[1]))
351347

352-
# TODO: RUSTPYTHON
353-
@unittest.expectedFailure
354348
def test_reduce_not_too_fat(self):
355349
OrderedDict = self.OrderedDict
356350
# do not save instance dictionary if not needed
@@ -362,8 +356,6 @@ def test_reduce_not_too_fat(self):
362356
self.assertEqual(od.__dict__['x'], 10)
363357
self.assertEqual(od.__reduce__()[2], {'x': 10})
364358

365-
# TODO: RUSTPYTHON
366-
@unittest.expectedFailure
367359
def test_pickle_recursive(self):
368360
OrderedDict = self.OrderedDict
369361
od = OrderedDict()
@@ -888,17 +880,13 @@ class CPythonOrderedDictSubclassTests(CPythonOrderedDictTests):
888880
class OrderedDict(c_coll.OrderedDict):
889881
pass
890882

891-
# TODO: RUSTPYTHON
892-
@unittest.expectedFailure
893883
class PurePythonOrderedDictWithSlotsCopyingTests(unittest.TestCase):
894884

895885
module = py_coll
896886
class OrderedDict(py_coll.OrderedDict):
897887
__slots__ = ('x', 'y')
898888
test_copying = OrderedDictTests.test_copying
899889

900-
# TODO: RUSTPYTHON
901-
@unittest.expectedFailure
902890
@unittest.skipUnless(c_coll, 'requires the C version of the collections module')
903891
class CPythonOrderedDictWithSlotsCopyingTests(unittest.TestCase):
904892

Lib/test/test_types.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -869,8 +869,6 @@ def eq(actual, expected, typed=True):
869869
eq(x[NT], int | NT | bytes)
870870
eq(x[S], int | S | bytes)
871871

872-
# TODO: RUSTPYTHON
873-
@unittest.expectedFailure
874872
def test_union_pickle(self):
875873
orig = list[T] | int
876874
for proto in range(pickle.HIGHEST_PROTOCOL + 1):
@@ -880,8 +878,6 @@ def test_union_pickle(self):
880878
self.assertEqual(loaded.__args__, orig.__args__)
881879
self.assertEqual(loaded.__parameters__, orig.__parameters__)
882880

883-
# TODO: RUSTPYTHON
884-
@unittest.expectedFailure
885881
def test_union_copy(self):
886882
orig = list[T] | int
887883
for copied in (copy.copy(orig), copy.deepcopy(orig)):

Lib/test/test_xml_dom_minicompat.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -82,8 +82,6 @@ def test_nodelist___radd__(self):
8282
node_list = [1, 2] + NodeList([3, 4])
8383
self.assertEqual(node_list, NodeList([1, 2, 3, 4]))
8484

85-
# TODO: RUSTPYTHON
86-
@unittest.expectedFailure
8785
def test_nodelist_pickle_roundtrip(self):
8886
# Test pickling and unpickling of a NodeList.
8987

derive-impl/src/pyclass.rs

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -752,6 +752,7 @@ where
752752
let item_meta = MethodItemMeta::from_attr(ident.clone(), &item_attr)?;
753753

754754
let py_name = item_meta.method_name()?;
755+
let raw = item_meta.raw()?;
755756
let sig_doc = text_signature(func.sig(), &py_name);
756757

757758
let doc = args.attrs.doc().map(|doc| format_doc(&sig_doc, &doc));
@@ -760,6 +761,7 @@ where
760761
cfgs: args.cfgs.to_vec(),
761762
ident: ident.to_owned(),
762763
doc,
764+
raw,
763765
attr_name: self.inner.attr_name,
764766
});
765767
Ok(())
@@ -954,6 +956,7 @@ struct MethodNurseryItem {
954956
py_name: String,
955957
cfgs: Vec<Attribute>,
956958
ident: Ident,
959+
raw: bool,
957960
doc: Option<String>,
958961
attr_name: AttrName,
959962
}
@@ -1005,9 +1008,14 @@ impl ToTokens for MethodNursery {
10051008
// } else {
10061009
// quote_spanned! { ident.span() => #py_name }
10071010
// };
1011+
let method_new = if item.raw {
1012+
quote!(new_raw_const)
1013+
} else {
1014+
quote!(new_const)
1015+
};
10081016
inner_tokens.extend(quote! [
10091017
#(#cfgs)*
1010-
rustpython_vm::function::PyMethodDef::new_const(
1018+
rustpython_vm::function::PyMethodDef::#method_new(
10111019
#py_name,
10121020
Self::#ident,
10131021
#flags,
@@ -1203,7 +1211,7 @@ impl ToTokens for MemberNursery {
12031211
struct MethodItemMeta(ItemMetaInner);
12041212

12051213
impl ItemMeta for MethodItemMeta {
1206-
const ALLOWED_NAMES: &'static [&'static str] = &["name", "magic"];
1214+
const ALLOWED_NAMES: &'static [&'static str] = &["name", "magic", "raw"];
12071215

12081216
fn from_inner(inner: ItemMetaInner) -> Self {
12091217
Self(inner)
@@ -1214,6 +1222,9 @@ impl ItemMeta for MethodItemMeta {
12141222
}
12151223

12161224
impl MethodItemMeta {
1225+
fn raw(&self) -> Result<bool> {
1226+
self.inner()._bool("raw")
1227+
}
12171228
fn method_name(&self) -> Result<String> {
12181229
let inner = self.inner();
12191230
let name = inner._optional_str("name")?;

vm/src/builtins/builtin_func.rs

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,14 @@ impl PyNativeFunction {
4949
)
5050
}
5151

52+
// PyCFunction_GET_SELF
53+
pub fn get_self(&self) -> Option<&PyObjectRef> {
54+
if self.value.flags.contains(PyMethodFlags::STATIC) {
55+
return None;
56+
}
57+
self.zelf.as_ref()
58+
}
59+
5260
pub fn as_func(&self) -> &'static dyn PyNativeFn {
5361
self.value.func
5462< 57AE code class="diff-text syntax-highlighted-line">
}

vm/src/builtins/list.rs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -174,8 +174,9 @@ impl PyList {
174174
Self::new_ref(self.borrow_vec().to_vec(), &vm.ctx)
175175
}
176176

177+
#[allow(clippy::len_without_is_empty)]
177178
#[pymethod(magic)]
178-
fn len(&self) -> usize {
179+
pub fn len(&self) -> usize {
179180
self.borrow_vec().len()
180181
}
181182

vm/src/builtins/object.rs

Lines changed: 130 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ use crate::common::hash::PyHash;
33
use crate::types::PyTypeFlags;
44
use crate::{
55
class::PyClassImpl,
6+
convert::ToPyResult,
67
function::{Either, FuncArgs, PyArithmeticValue, PyComparisonValue, PySetterValue},
78
types::{Constructor, PyComparisonOp},
89
AsObject, Context, Py, PyObject, PyObjectRef, PyPayload, PyResult, VirtualMachine,
@@ -73,8 +74,137 @@ impl Constructor for PyBaseObject {
7374
}
7475
}
7576

77+
// TODO: implement _PyType_GetSlotNames properly
78+
fn type_slot_names(typ: &Py<PyType>, vm: &VirtualMachine) -> PyResult<Option<super::PyListRef>> {
79+
// let attributes = typ.attributes.read();
80+
// if let Some(slot_names) = attributes.get(identifier!(vm.ctx, __slotnames__)) {
81+
// return match_class!(match slot_names.clone() {
82+
// l @ super::PyList => Ok(Some(l)),
83+
// _n @ super::PyNone => Ok(None),
84+
// _ => Err(vm.new_type_error(format!(
85+
// "{:.200}.__slotnames__ should be a list or None, not {:.200}",
86+
// typ.name(),
87+
// slot_names.class().name()
88+
// ))),
89+
// });
90+
// }
91+
92+
let copyreg = vm.import("copyreg", 0)?;
93+
let copyreg_slotnames = copyreg.get_attr("_slotnames", vm)?;
94+
let slot_names = copyreg_slotnames.call((typ.to_owned(),), vm)?;
95+
let result = match_class!(match slot_names {
96+
l @ super::PyList => Some(l),
97+
_n @ super::PyNone => None,
98+
_ =>
99+
return Err(
100+
vm.new_type_error("copyreg._slotnames didn't return a list or None".to_owned())
101+
),
102+
});
103+
Ok(result)
104+
}
105+
106+
// object_getstate_default in CPython
107+
fn object_getstate_default(obj: &PyObject, required: bool, vm: &VirtualMachine) -> PyResult {
108+
// TODO: itemsize
109+
// if required && obj.class().slots.itemsize > 0 {
110+
// return vm.new_type_error(format!(
111+
// "cannot pickle {:.200} objects",
112+
// obj.class().name()
113+
// ));
114+
// }
115+
116+
let state = if obj.dict().map_or(true, |d| d.is_empty()) {
117+
vm.ctx.none()
118+
} else {
119+
// let state = object_get_dict(obj.clone(), obj.ctx()).unwrap();
120+
let Some(state) = obj.dict() else {
121+
return Ok(vm.ctx.none());
122+
};
123+
state.into()
124+
};
125+
126+
let slot_names = type_slot_names(obj.class(), vm)
127+
.map_err(|_| vm.new_type_error("cannot pickle object".to_owned()))?;
128+
129+
if required {
130+
let mut basicsize = obj.class().slots.basicsize;
131+
// if obj.class().slots.dictoffset > 0
132+
// && !obj.class().slots.flags.has_feature(PyTypeFlags::MANAGED_DICT)
133+
// {
134+
// basicsize += std::mem::size_of::<PyObjectRef>();
135+
// }
136+
// if obj.class().slots.weaklistoffset > 0 {
137+
// basicsize += std::mem::size_of::<PyObjectRef>();
138+
// }
139+
if let Some(ref slot_names) = slot_names {
140+
basicsize += std::mem::size_of::<PyObjectRef>() * slot_names.len();
141+
}
142+
if obj.class().slots.basicsize > basicsize {
143+
return Err(
144+
vm.new_type_error(format!("cannot pickle {:.200} object", obj.class().name()))
145+
);
146+
}
147+
}
148+
149+
if let Some(slot_names) = slot_names {
150+
let slot_names_len = slot_names.len();
151+
if slot_names_len > 0 {
152+
let slots = vm.ctx.new_dict();
153+
for i in 0..slot_names_len {
154+
let borrowed_names = slot_names.borrow_vec();
155+
let name = borrowed_names[i].downcast_ref::<PyStr>().unwrap();
156+
let Ok(value) = obj.get_attr(name, vm) else {
157+
continue;
158+
};
159+
slots.set_item(name.as_str(), value, vm).unwrap();
160+
}
161+
162+
if slots.len() > 0 {
163+
return (state, slots).to_pyresult(vm);
164+
}
165+
}
166+
}
167+
168+
Ok(state)
169+
}
170+
171+
// object_getstate in CPython
172+
// fn object_getstate(
173+
// obj: &PyObject,
174+
// required: bool,
175+
// vm: &VirtualMachine,
176+
// ) -> PyResult {
177+
// let getstate = obj.get_attr(identifier!(vm, __getstate__), vm)?;
178+
// if vm.is_none(&getstate) {
179+
// return Ok(None);
180+
// }
181+
182+
// let getstate = match getstate.downcast_exact::<PyNativeFunction>(vm) {
183+
// Ok(getstate)
184+
// if getstate
185+
// .get_self()
186+
// .map_or(false, |self_obj| self_obj.is(obj))
187+
// && std::ptr::addr_eq(
188+
// getstate.as_func() as *const _,
189+
// &PyBaseObject::__getstate__ as &dyn crate::function::PyNativeFn as *const _,
190+
// ) =>
191+
// {
192+
// return object_getstate_default(obj, required, vm);
193+
// }
194+
// Ok(getstate) => getstate.into_pyref().into(),
195+
// Err(getstate) => getstate,
196+
// };
197+
// getstate.call((), vm)
198+
// }
199+
76200
#[pyclass(with(Constructor), flags(BASETYPE))]
77201
impl PyBaseObject {
202+
#[pymethod(raw)]
203+
fn __getstate__(vm: &VirtualMachine, args: FuncArgs) -> PyResult {
204+
let (zelf,): (PyObjectRef,) = args.bind(vm)?;
205+
object_getstate_default(&zelf, false, vm)
206+
}
207+
78208
#[pyslot]
79209
fn slot_richcompare(
80210
zelf: &PyObject,

vm/src/function/builtin.rs

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,11 @@ pub const fn static_func<Kind, F: IntoPyNativeFn<Kind>>(f: F) -> &'static dyn Py
8181
zst_ref_out_of_thin_air(into_func(f))
8282
}
8383

84+
#[inline(always)]
85+
pub const fn static_raw_func<F: PyNativeFn>(f: F) -> &'static dyn PyNativeFn {
86+
zst_ref_out_of_thin_air(f)
87+
}
88+
8489
// TODO: once higher-rank trait bounds are stabilized, remove the `Kind` type
8590
// parameter and impl for F where F: for<T, R, VM> PyNativeFnInternal<T, R, VM>
8691
impl<F, T, R, VM> IntoPyNativeFn<(T, R, VM)> for F

0 commit comments

Comments
 (0)
0