8000 Fix a bunch of `random` tests (#5533) · RustPython/RustPython@2721f2d · GitHub
[go: up one dir, main page]

Skip to content

Commit 2721f2d

Browse files
authored
Fix a bunch of random tests (#5533)
1 parent b55a55a commit 2721f2d

File tree

4 files changed

+63
-98
lines changed

4 files changed

+63
-98
lines changed

.gitattributes

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,3 +4,4 @@ Cargo.lock linguist-generated -merge
44
vm/src/stdlib/ast/gen.rs linguist-generated -merge
55
Lib/*.py text working-tree-encoding=UTF-8 eol=LF
66
**/*.rs text working-tree-encoding=UTF-8 eol=LF
7+
*.pck binary

Lib/test/test_random.py

Lines changed: 0 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -22,8 +22,6 @@ def randomlist(self, n):
2222
"""Helper function to make a list of random numbers"""
2323
return [self.gen.random() for i in range(n)]
2424

25-
# TODO: RUSTPYTHON AttributeError: 'super' object has no attribute 'getstate'
26-
@unittest.expectedFailure
2725
def test_autoseed(self):
2826
self.gen.seed()
2927
state1 = self.gen.getstate()
@@ -32,8 +30,6 @@ def test_autoseed(self):
3230
state2 = self.gen.getstate()
3331
self.assertNotEqual(state1, state2)
3432

35-
# TODO: RUSTPYTHON AttributeError: 'super' object has no attribute 'getstate'
36-
@unittest.expectedFailure
3733
def test_saverestore(self):
3834
N = 1000
3935
self.gen.seed()
@@ -60,7 +56,6 @@ def __hash__(self):
6056
self.assertRaises(TypeError, self.gen.seed, 1, 2, 3, 4)
6157
self.assertRaises(TypeError, type(self.gen), [])
6258

63-
@unittest.skip("TODO: RUSTPYTHON, TypeError: Expected type 'bytes', not 'bytearray'")
6459
def test_seed_no_mutate_bug_44018(self):
6560
a = bytearray(b'1234')
6661
self.gen.seed(a)
@@ -386,8 +381,6 @@ def test_getrandbits(self):
386381
self.assertRaises(ValueError, self.gen.getrandbits, -1)
387382
self.assertRaises(TypeError, self.gen.getrandbits, 10.1)
388383

389-
# TODO: RUSTPYTHON AttributeError: 'super' object has no attribute 'getstate'
390-
@unittest.expectedFailure
391384
def test_pickling(self):
392385
for proto in range(pickle.HIGHEST_PROTOCOL + 1):
393386
state = pickle.dumps(self.gen, proto)
@@ -396,8 +389,6 @@ def test_pickling(self):
396389
restoredseq = [newgen.random() for i in range(10)]
397390
self.assertEqual(origseq, restoredseq)
398391

399-
# TODO: RUSTPYTHON AttributeError: 'super' object has no attribute 'getstate'
400-
@unittest.expectedFailure
401392
def test_bug_1727780(self):
402393
# verify that version-2-pickles can be loaded
403394
# fine, whether they are created on 32-bit or 64-bit
@@ -600,11 +591,6 @@ def test_bug_42008(self):
600591
class MersenneTwister_TestBasicOps(TestBasicOps, unittest.TestCase):
601592
gen = random.Random()
602593

603-
# TODO: RUSTPYTHON, TypeError: Expected type 'bytes', not 'bytearray'
604-
@unittest.expectedFailure
605-
def test_seed_no_mutate_bug_44018(self): # TODO: RUSTPYTHON, remove when this passes
606-
super().test_seed_no_mutate_bug_44018() # TODO: RUSTPYTHON, remove when this passes
607-
608594
def test_guaranteed_stable(self):
609595
# These sequences are guaranteed to stay the same across versions of python
610596
self.gen.seed(3456147, version=1)
@@ -675,8 +661,6 @@ def test_bug_31482(self):
675661
def test_setstate_first_arg(self):
676662
self.assertRaises(ValueError, self.gen.setstate, (1, None, None))
677663

678-
# TODO: RUSTPYTHON AttributeError: 'super' object has no attribute 'getstate'
679-
@unittest.expectedFailure
680664
def test_setstate_middle_arg(self):
681665
start_state = self.gen.getstate()
682666
# Wrong type, s/b tuple
@@ -1282,15 +1266,6 @@ def test_betavariate_return_zero(self, gammavariate_mock):
12821266

12831267

12841268
class TestRandomSubclassing(unittest.TestCase):
1285-
# TODO: RUSTPYTHON Unexpected keyword argument newarg
1286-
@unittest.expectedFailure
1287-
def test_random_subclass_with_kwargs(self):
1288-
# SF bug #1486663 -- this used to erroneously raise a TypeError
1289-
class Subclass(random.Random):
1290-
def __init__(self, newarg=None):
1291-
random.Random.__init__(self)
1292-
Subclass(newarg=1)
1293-
12941269
def test_subclasses_overriding_methods(self):
12951270
# Subclasses with an overridden random, but only the original
12961271
# getrandbits method should not rely on getrandbits in for randrange,

stdlib/src/random.rs

Lines changed: 61 additions & 69 deletions
F42D
Original file line numberDiff line numberDiff line change
@@ -6,73 +6,37 @@ pub(crate) use _random::make_module;
66
mod _random {
77
use crate::common::lock::PyMutex;
88
use crate::vm::{
9-
builtins::{PyInt, PyTypeRef},
9+
builtins::{PyInt, PyTupleRef},
10+
convert::ToPyException,
1011
function::OptionalOption,
11-
types::Constructor,
12-
PyObjectRef, PyPayload, PyResult, VirtualMachine,
12+
types::{Constructor, Initializer},
13+
PyObjectRef, PyPayload, PyRef, PyResult, VirtualMachine,
1314
};
15+
use itertools::Itertools;
1416
use malachite_bigint::{BigInt, BigUint, Sign};
17+
use mt19937::MT19937;
1518
use num_traits::{Signed, Zero};
16-
use rand::{rngs::StdRng, RngCore, SeedableRng};
17-
18-
#[derive(Debug)]
19-
enum PyRng {
20-
Std(Box<StdRng>),
21-
MT(Box<mt19937::MT19937>),
22-
}
23-
24-
impl Default for PyRng {
25-
fn default() -> Self {
26-
PyRng::Std(Box::new(StdRng::from_os_rng()))
27-
}
28-
}
29-
30-
impl RngCore for PyRng {
31-
fn next_u32(&mut self) -> u32 {
32-
match self {
33-
Self::Std(s) => s.next_u32(),
34-
Self::MT(m) => m.next_u32(),
35-
}
36-
}
37-
fn next_u64(&mut self) -> u64 {
38-
match self {
39-
Self::Std(s) => s.next_u64(),
40-
Self::MT(m) => m.next_u64(),
41-
}
42-
}
43-
fn fill_bytes(&mut self, dest: &mut [u8]) {
44-
match self {
45-
Self::Std(s) => s.fill_bytes(dest),
46-
Self::MT(m) => m.fill_bytes(dest),
47-
}
48-
}
49-
}
19+
use rand::{RngCore, SeedableRng};
20+
use rustpython_vm::types::DefaultConstructor;
5021

5122
#[pyattr]
5223
#[pyclass(name = "Random")]
53-
#[derive(Debug, PyPayload)]
24+
#[derive(Debug, PyPayload, Default)]
5425
struct PyRandom {
55-
rng: PyMutex<PyRng>,
26+
rng: PyMutex<MT19937>,
5627
}
5728

58-
impl Constructor for PyRandom {
59-
type Args = OptionalOption<PyObjectRef>;
29+
impl DefaultConstructor for PyRandom {}
6030

61-
fn py_new(
62-
cls: PyTypeRef,
63-
// TODO: use x as the seed.
64-
_x: Self::Args,
65-
vm: &VirtualMachine,
66-
) -> PyResult {
67-
PyRandom {
68-
rng: PyMutex::default(),
69-
}
70-
.into_ref_with_type(vm, cls)
71-
.map(Into::into)
31+
impl Initializer for PyRandom {
32+
type Args = OptionalOption;
33+
34+
fn init(zelf: PyRef<Self>, x: Self::Args, vm: &VirtualMachine) -> PyResult<()> {
35+
zelf.seed(x, vm)
7236
}
7337
}
7438

75-
#[pyclass(flags(BASETYPE), with(Constructor))]
39+
#[pyclass(flags(BASETYPE), with(Constructor, Initializer))]
7640
impl PyRandom {
7741
#[pymethod]
7842
fn random(&self) -> f64 {
@@ -82,9 +46,8 @@ mod _random {
8246

8347
#[pymethod]
8448
fn seed(&self, n: OptionalOption<PyObjectRef>, vm: &VirtualMachine) -> PyResult<()> {
85-
let new_rng = n
86-
.flatten()
87-
.map(|n| {
49+
*self.rng.lock() = match n.flatten() {
50+
Some(n) => {
8851
// Fallback to using hash if object isn't Int-like.
8952
let (_, mut key) = match n.downcast::<PyInt>() {
9053
Ok(n) => n.as_bigint().abs(),
@@ -95,27 +58,21 @@ mod _random {
9558
key.reverse();
9659
}
9760
let key = if key.is_empty() { &[0] } else { key.as_slice() };
98-
Ok(PyRng::MT(Box::new(mt19937::MT19937::new_with_slice_seed(
99-
key,
100-
))))
101-
})
102-
.transpose()?
103-
.unwrap_or_default();
104-
105-
*self.rng.lock() = new_rng;
61+
MT19937::new_with_slice_seed(key)
62+
}
63+
None => MT19937::try_from_os_rng()
64+
.map_err(|e| std::io::Error::from(e).to_pyexception(vm))?,
65+
};
10666
Ok(())
10767
}
10868

10969
#[pymethod]
11070
fn getrandbits(&self, k: isize, vm: &VirtualMachine) -> PyResult<BigInt> {
11171
match k {
112-
k if k < 0 => {
113-
Err(vm.new_value_error("number of bits must be non-negative".to_owned()))
114-
}
72+
..0 => Err(vm.new_value_error("number of bits must be non-negative".to_owned())),
11573
0 => Ok(BigInt::zero()),
116-
_ => {
74+
mut k => {
11775
let mut rng = self.rng.lock();
118-
let mut k = k;
11976
let mut gen_u32 = |k| {
12077
let r = rng.next_u32();
12178
if k < 32 {
@@ -145,5 +102,40 @@ mod _random {
145102
}
146103
}
147104
}
105+
106+
#[pymethod]
107+
fn getstate(&self, vm: &VirtualMachine) -> PyTupleRef {
108+
let rng = self.rng.lock();
109+
vm.new_tuple(
110+
rng.get_state()
111+
.iter()
112+
.copied()
113+
.chain([rng.get_index() as u32])
114+
.map(|i| vm.ctx.new_int(i).into())
115+
.collect::<Vec<PyObjectRef>>(),
116+
)
117+
}
118+
119+
#[pymethod]
120+
fn setstate(&self, state: PyTupleRef, vm: &VirtualMachine) -> PyResult<()> {
121+
let state: &[_; mt19937::N + 1] = state
122+
.as_slice()
123+
.try_into()
124+
.map_err(|_| vm.new_value_error("state vector is the wrong size".to_owned()))?;
125+
let (index, state) = state.split_last().unwrap();
126+
let index: usize = index.try_to_value(vm)?;
127+
if index > mt19937::N {
128+
return Err(vm.new_value_error("invalid state".to_owned()));
129+
}
130+
let state: [u32; mt19937::N] = state
131+
.iter()
132+
.map(|i| i.try_to_value(vm))
133+
.process_results(|it| it.collect_array())?
134+
.unwrap();
135+
let mut rng = self.rng.lock();
136+
rng.set_state(&state);
137+
rng.set_index(index);
138+
Ok(())
139+
}
148140
}
149141
}

vm/src/stdlib/os.rs

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -978,10 +978,7 @@ pub(super) mod _os {
978978
return Err(vm.new_value_error("negative argument not allowed".to_owned()));
979979
}
980980
let mut buf = vec![0u8; size as usize];
981-
getrandom::fill(&mut buf).map_err(|e| match e.raw_os_error() {
982-
Some(errno) => io::Error::from_raw_os_error(errno).into_pyexception(vm),
983-
None => vm.new_os_error("Getting random failed".to_owned()),
984-
})?;
981+
getrandom::fill(&mut buf).map_err(|e| io::Error::from(e).into_pyexception(vm))?;
985982
Ok(buf)
986983
}
987984

0 commit comments

Comments
 (0)
0