@@ -6,73 +6,37 @@ pub(crate) use _random::make_module;
6
6
mod _random {
7
7
use crate :: common:: lock:: PyMutex ;
8
8
use crate :: vm:: {
9
- builtins:: { PyInt , PyTypeRef } ,
9
+ builtins:: { PyInt , PyTupleRef } ,
10
+ convert:: ToPyException ,
10
11
function:: OptionalOption ,
11
- types:: Constructor ,
12
- PyObjectRef , PyPayload , PyResult , VirtualMachine ,
12
+ types:: { Constructor , Initializer } ,
13
+ PyObjectRef , PyPayload , PyRef , PyResult , VirtualMachine ,
13
14
} ;
15
+ use itertools:: Itertools ;
14
16
use malachite_bigint:: { BigInt , BigUint , Sign } ;
17
+ use mt19937:: MT19937 ;
15
18
use num_traits:: { Signed , Zero } ;
16
- use rand:: { rngs:: StdRng , RngCore , SeedableRng } ;
17
-
18
- #[ derive( Debug ) ]
19
- enum PyRng {
20
- Std ( Box < StdRng > ) ,
21
- MT ( Box < mt19937:: MT19937 > ) ,
22
- }
23
-
24
- impl Default for PyRng {
25
- fn default ( ) -> Self {
26
- PyRng :: Std ( Box :: new ( StdRng :: from_os_rng ( ) ) )
27
- }
28
- }
29
-
30
- impl RngCore for PyRng {
31
- fn next_u32 ( & mut self ) -> u32 {
32
- match self {
33
- Self :: Std ( s) => s. next_u32 ( ) ,
34
- Self :: MT ( m) => m. next_u32 ( ) ,
35
- }
36
- }
37
- fn next_u64 ( & mut self ) -> u64 {
38
- match self {
39
- Self :: Std ( s) => s. next_u64 ( ) ,
40
- Self :: MT ( m) => m. next_u64 ( ) ,
41
- }
42
- }
43
- fn fill_bytes ( & mut self , dest : & mut [ u8 ] ) {
44
- match self {
45
- Self :: Std ( s) => s. fill_bytes ( dest) ,
46
- Self :: MT ( m) => m. fill_bytes ( dest) ,
47
- }
48
- }
49
- }
19
+ use rand:: { RngCore , SeedableRng } ;
20
+ use rustpython_vm:: types:: DefaultConstructor ;
50
21
51
22
#[ pyattr]
52
23
#[ pyclass( name = "Random" ) ]
53
- #[ derive( Debug , PyPayload ) ]
24
+ #[ derive( Debug , PyPayload , Default ) ]
54
25
struct PyRandom {
55
- rng : PyMutex < PyRng > ,
26
+ rng : PyMutex < MT19937 > ,
56
27
}
57
28
58
- impl Constructor for PyRandom {
59
- type Args = OptionalOption < PyObjectRef > ;
29
+ impl DefaultConstructor for PyRandom { }
60
30
61
- fn py_new (
62
- cls : PyTypeRef ,
63
- // TODO: use x as the seed.
64
- _x : Self :: Args ,
65
- vm : & VirtualMachine ,
66
- ) -> PyResult {
67
- PyRandom {
68
- rng : PyMutex :: default ( ) ,
69
- }
70
- . into_ref_with_type ( vm, cls)
71
- . map ( Into :: into)
31
+ impl Initializer for PyRandom {
32
+ type Args = OptionalOption ;
F42D
33
+
34
+ fn init ( zelf : PyRef < Self > , x : Self :: Args , vm : & VirtualMachine ) -> PyResult < ( ) > {
35
+ zelf. seed ( x, vm)
72
36
}
73
37
}
74
38
75
- #[ pyclass( flags( BASETYPE ) , with( Constructor ) ) ]
39
+ #[ pyclass( flags( BASETYPE ) , with( Constructor , Initializer ) ) ]
76
40
impl PyRandom {
77
41
#[ pymethod]
78
42
fn random ( & self ) -> f64 {
@@ -82,9 +46,8 @@ mod _random {
82
46
83
47
#[ pymethod]
84
48
fn seed ( & self , n : OptionalOption < PyObjectRef > , vm : & VirtualMachine ) -> PyResult < ( ) > {
85
- let new_rng = n
86
- . flatten ( )
87
- . map ( |n| {
49
+ * self . rng . lock ( ) = match n. flatten ( ) {
50
+ Some ( n) => {
88
51
// Fallback to using hash if object isn't Int-like.
89
52
let ( _, mut key) = match n. downcast :: < PyInt > ( ) {
90
53
Ok ( n) => n. as_bigint ( ) . abs ( ) ,
@@ -95,27 +58,21 @@ mod _random {
95
58
key. reverse ( ) ;
96
59
}
97
60
let key = if key. is_empty ( ) { & [ 0 ] } else { key. as_slice ( ) } ;
98
- Ok ( PyRng :: MT ( Box :: new ( mt19937:: MT19937 :: new_with_slice_seed (
99
- key,
100
- ) ) ) )
101
- } )
102
- . transpose ( ) ?
103
- . unwrap_or_default ( ) ;
104
-
105
- * self . rng . lock ( ) = new_rng;
61
+ MT19937 :: new_with_slice_seed ( key)
62
+ }
63
+ None => MT19937 :: try_from_os_rng ( )
64
+ . map_err ( |e| std:: io:: Error :: from ( e) . to_pyexception ( vm) ) ?,
65
+ } ;
106
66
Ok ( ( ) )
107
67
}
108
68
109
69
#[ pymethod]
110
70
fn getrandbits ( & self , k : isize , vm : & VirtualMachine ) -> PyResult < BigInt > {
111
71
match k {
112
- k if k < 0 => {
113
- Err ( vm. new_value_error ( "number of bits must be non-negative" . to_owned ( ) ) )
114
- }
72
+ ..0 => Err ( vm. new_value_error ( "number of bits must be non-negative" . to_owned ( ) ) ) ,
115
73
0 => Ok ( BigInt :: zero ( ) ) ,
116
- _ => {
74
+ mut k => {
117
75
let mut rng = self . rng . lock ( ) ;
118
- let mut k = k;
119
76
let mut gen_u32 = |k| {
120
77
let r = rng. next_u32 ( ) ;
121
78
if k < 32 {
@@ -145,5 +102,40 @@ mod _random {
145
102
}
146
103
}
147
104
}
105
+
106
+ #[ pymethod]
107
+ fn getstate ( & self , vm : & VirtualMachine ) -> PyTupleRef {
108
+ let rng = self . rng . lock ( ) ;
109
+ vm. new_tuple (
110
+ rng. get_state ( )
111
+ . iter ( )
112
+ . copied ( )
113
+ . chain ( [ rng. get_index ( ) as u32 ] )
114
+ . map ( |i| vm. ctx . new_int ( i) . into ( ) )
115
+ . collect :: < Vec < PyObjectRef > > ( ) ,
116
+ )
117
+ }
118
+
119
+ #[ pymethod]
120
+ fn setstate ( & self , state : PyTupleRef , vm : & VirtualMachine ) -> PyResult < ( ) > {
121
+ let state: & [ _ ; mt19937:: N + 1 ] = state
122
+ . as_slice ( )
123
+ . try_into ( )
124
+ . map_err ( |_| vm. new_value_error ( "state vector is the wrong size" . to_owned ( ) ) ) ?;
125
+ let ( index, state) = state. split_last ( ) . unwrap ( ) ;
126
+ let index: usize = index. try_to_value ( vm) ?;
127
+ if index > mt19937:: N {
128
+ return Err ( vm. new_value_error ( "invalid state" . to_owned ( ) ) ) ;
129
+ }
130
+ let state: [ u32 ; mt19937:: N ] = state
131
+ . iter ( )
132
+ . map ( |i| i. try_to_value ( vm) )
133
+ . process_results ( |it| it. collect_array ( ) ) ?
134
+ . unwrap ( ) ;
135
+ let mut rng = self . rng . lock ( ) ;
136
+ rng. set_state ( & state) ;
137
+ rng. set_index ( index) ;
138
+ Ok ( ( ) )
139
+ }
148
140
}
149
141
}
0 commit comments