8000 Merge pull request #5388 from sturlamolden/mtrand-bugfix-threadsafe · numpy/numpy@a581765 · GitHub
[go: up one dir, main page]

Skip to content

Commit a581765

Browse files
committed
Merge pull request #5388 from sturlamolden/mtrand-bugfix-threadsafe
BUG: Make RandomState.set_state and RandomState.get_state threadsafe
2 parents 243ab56 + 5295888 commit a581765

File tree

1 file changed

+11
-7
lines changed

1 file changed

+11
-7
lines changed

numpy/random/mtrand/mtrand.pyx

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -691,10 +691,13 @@ cdef class RandomState:
691691
"""
692692
cdef ndarray state "arrayObject_state"
693693
state = <ndarray>np.empty(624, np.uint)
694-
memcpy(<void*>PyArray_DATA(state), <void*>(self.internal_state.key), 624*sizeof(long))
694+
with self.lock:
695+
memcpy(<void*>PyArray_DATA(state), <void*>(self.internal_state.key), 624*sizeof(long))
696+
has_gauss = self.internal_state.has_gauss
697+
gauss = self.internal_state.gauss
698+
pos = self.internal_state.pos
695699
state = <ndarray>np.asarray(state, np.uint32)
696-
return ('MT19937', state, self.internal_state.pos,
697-
self.internal_state.has_gauss, self.internal_state.gauss)
700+
return ('MT19937', state, pos, has_gauss, gauss)
698701

699702
def set_state(self, state):
700703
"""
@@ -761,10 +764,11 @@ cdef class RandomState:
761764
obj = <ndarray>PyArray_ContiguousFromObject(key, NPY_LONG, 1, 1)
762765
if PyArray_DIM(obj, 0) != 624:
763766
raise ValueError("state must be 624 longs")
764-
memcpy(<void*>(self.internal_state.key), <void*>PyArray_DATA(obj), 624*sizeof(long))
765-
self.internal_state.pos = pos
766-
self.internal_state.has_gauss = has_gauss
767-
self.internal_state.gauss = cached_gaussian
767+
with self.lock:
768+
memcpy(<void*>(self.internal_state.key), <void*>PyArray_DATA(obj), 624*sizeof(long))
769+
self.internal_state.pos = pos
770+
self.internal_state.has_gauss = has_gauss
771+
self.internal_state.gauss = cached_gaussian
768772

769773
# Pickling support:
770774
def __getstate__(self):

0 commit comments

Comments
 (0)
0