8000 BUG: make seed, randint and shuffle threadsafe · numpy/numpy@fee4bcb · GitHub
[go: up one dir, main page]

Skip to content

Commit fee4bcb

Browse files
sturlamoldenjuliantaylor
authored andcommitted
BUG: make seed, randint and shuffle threadsafe
1 parent 5295888 commit fee4bcb

File tree

1 file changed

+22
-16
lines changed

1 file changed

+22
-16
lines changed

numpy/random/mtrand/mtrand.pyx

Lines changed: 22 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -607,8 +607,8 @@ cdef class RandomState:
607607
def __init__(self, seed=None):
608608
self.internal_state = <rk_state*>PyMem_Malloc(sizeof(rk_state))
609609

610-
self.seed(seed)
611610
self.lock = Lock()
611+
self.seed(seed)
612612

613613
def __dealloc__(self):
614614
if self.internal_state != NULL:
@@ -639,19 +639,22 @@ cdef class RandomState:
639639
cdef ndarray obj "arrayObject_obj"
640640
try:
641641
if seed is None:
642-
errcode = rk_randomseed(self.internal_state)
642+
with self.lock:
643+
errcode = rk_randomseed(self.internal_state)
643644
else:
644645
idx = operator.index(seed)
645646
if idx > int(2**32 - 1) or idx < 0:
646647
raise ValueError("Seed must be between 0 and 4294967295")
647-
rk_seed(idx, self.internal_state)
648+
with self.lock:
649+
rk_seed(idx, self.internal_state)
648650
except TypeError:
649651
obj = np.asarray(seed).astype(np.int64, casting='safe')
650652
if ((obj > int(2**32 - 1)) | (obj < 0)).any():
651653
raise ValueError("Seed must be between 0 and 4294967295")
652654
obj = obj.astype('L', casting='unsafe')
653-
init_by_array(self.internal_state, <unsigned long *>PyArray_DATA(obj),
654-
PyArray_DIM(obj, 0))
655+
with self.lock:
656+
init_by_array(self.internal_state, <unsigned long *>PyArray_DATA(obj),
657+
PyArray_DIM(obj, 0))
655658

656659
def get_state(self):
657660
"""
@@ -936,7 +939,8 @@ cdef class RandomState:
936939

937940
diff = <unsigned long>hi - <unsigned long>lo - 1UL
938941
if size is None:
939-
rv = lo + <long>rk_interval(diff, self. internal_state)
942+
with self.lock:
943+
rv = lo + <long>rk_interval(diff, self. internal_state)
940944
return rv
941945
else:
942946
array = <ndarray>np.empty(size, int)
@@ -4581,20 +4585,22 @@ cdef class RandomState:
45814585
# each row. So we can't just use ordinary assignment to swap the
45824586
# rows; we need a bounce buffer.
45834587
buf = np.empty_like(x[0])
4584-
while i > 0:
4585-
j = rk_interval(i, self.internal_state)
4586-
buf[...] = x[j]
4587-
x[j] = x[i]
4588-
x[i] = buf
4589-
i = i - 1
4588+
with self.lock:
4589+
while i > 0:
4590+
j = rk_interval(i, self.internal_state)
4591+
buf[...] = x[j]
4592+
x[j] = x[i]
4593+
x[i] = buf
4594+
i = i - 1
45904595
else:
45914596
# For single-dimensional arrays, lists, and any other Python
45924597
# sequence types, indexing returns a real object that's
45934598
# independent of the array contents, so we can just swap directly.
4594-
while i > 0:
4595-
j = rk_interval(i, self.internal_state)
4596-
x[i], x[j] = x[j], x[i]
4597-
i = i - 1
4599+
with self.lock:
4600+
while i > 0:
4601+
j = rk_interval(i, self.internal_state)
4602+
x[i], x[j] = x[j], x[i]
4603+
i = i - 1
45984604

45994605
def permutation(self, object x):
46004606
"""

0 commit comments

Comments
 (0)
0