E60C BUG: make set_state and get_state threadsafe · numpy/numpy@5295888 · GitHub
[go: up one dir, main page]

Skip to content

Commit 5295888

Browse files
sturlamoldenjuliantaylor
authored andcommitted
BUG: make set_state and get_state threadsafe
1 parent 14a3dca commit 5295888

File tree

1 file changed

+11
-7
lines changed

1 file changed

+11
-7
lines changed

numpy/random/mtrand/mtrand.pyx

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -685,10 +685,13 @@ cdef class RandomState:
685685
"""
686686
cdef ndarray state "arrayObject_state"
687687
state = <ndarray>np.empty(624, np.uint)
688-
memcpy(<void*>PyArray_DATA(state), <void*>(self.internal_state.key), 624*sizeof(long))
688+
with self.lock:
689+
memcpy(<void*>PyArray_DATA(state), <void*>(self.internal_state.key), 624*sizeof(long))
690+
has_gauss = self.internal_state.has_gauss
691+
gauss = self.internal_state.gauss
692+
pos = self.internal_state.pos
689693
state = <ndarray>np.asarray(state, np.uint32)
690-
return ('MT19937', state, self.internal_state.pos,
691-
self.internal_state.has_gauss, self.internal_state.gauss)
694+
return ('MT19937', state, pos, has_gauss, gauss)
692695

693696
def set_state(self, state):
694697
"""
@@ -755,10 +758,11 @@ cdef class RandomState:
755758
obj = <ndarray>PyArray_ContiguousFromObject(key, NPY_LONG, 1, 1)
756759
if PyArray_DIM(obj, 0) != 624:
757760
raise ValueError("state must be 624 longs")
758-
memcpy(<void*>(self.internal_state.key), <void*>PyArray_DATA(obj), 624*sizeof(long))
759-
self.internal_state.pos = pos
760-
self.internal_state.has_gauss = has_gauss
761-
self.internal_state.gauss = cached_gaussian
761+
with self.lock:
762+
memcpy(<void*>(self.internal_state.key), <void*>PyArray_DATA(obj), 624*sizeof(long))
763+
self.internal_state.pos = pos
764+
self.internal_state.has_gauss = has_gauss
765+
self.internal_state.gauss = cached_gaussian
762766

763767
# Pickling support:
764768
def __getstate__(self):

0 commit comments

Comments
 (0)
0