8000 Merge pull request #9842 from bashtage/protect-empty-init · numpy/numpy@c0b2aba · GitHub
[go: up one dir, main page]

Skip to content

Commit c0b2aba

Browse files
authored
Merge pull request #9842 from bashtage/protect-empty-init
BUG: Prevent invalid array shapes in seed
2 parents 47a8fbb + ad8a4c7 commit c0b2aba

File tree

3 files changed

+23
-5
lines changed

3 files changed

+23
-5
lines changed

doc/release/1.14.0-notes.rst

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -378,3 +378,9 @@ display the sign. This new behavior can be disabled to mostly reproduce numpy
378378
-----------------------------------------------------------------
379379
These options could previously be controlled using ``np.set_printoptions``, but
380380
now can be changed on a per-call basis as arguments to ``np.array2string``.
381+
382+
Seeding ``RandomState`` using an array requires a 1-d array
383+
-----------------------------------------------------------
384+
``RandomState`` previously would accept empty arrays or arrays with 2 or more
385+
dimensions, which resulted in either a failure to seed (empty arrays) or for
386+
some of the passed values to be ignored when setting the seed.

numpy/random/mtrand/mtrand.pyx

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -659,7 +659,7 @@ cdef class RandomState:
659659
660660
Parameters
661661
----------
662-
seed : int or array_like, optional
662+
seed : int or 1-d array_like, optional
663663
Seed for `RandomState`.
664664
Must be convertible to 32 bit unsigned integers.
665665
@@ -676,14 +676,19 @@ cdef class RandomState:
676676
errcode = rk_randomseed(self.internal_state)
677677
else:
678678
idx = operator.index(seed)
679-
if idx > int(2**32 - 1) or idx < 0:
679+
if (idx >= 2**32) or (idx < 0):
680680
raise ValueError("Seed must be between 0 and 2**32 - 1")
681681
with self.lock:
682682
rk_seed(idx, self.internal_state)
683683
except TypeError:
684-
obj = np.asarray(seed).astype(np.int64, casting='safe')
685-
if ((obj > int(2**32 - 1)) | (obj < 0)).any():
686-
raise ValueError("Seed must be between 0 and 2**32 - 1")
684+
obj = np.asarray(seed)
685+
if obj.size == 0:
686+
raise ValueError("Seed must be non-empty")
687+
obj = obj.astype(np.int64, casting='safe')
688+
if obj.ndim != 1:
689+
raise ValueError("Seed array must be 1-d")
690+
if ((obj >= 2**32) | (obj < 0)).any():
691+
raise ValueError("Seed values must be between 0 and 2**32 - 1")
687692
obj = obj.astype('L', casting='unsafe')
688693
with self.lock:
689694
init_by_array(self.internal_state, <unsigned long *>PyArray_DATA(obj),

numpy/random/tests/test_random.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,13 @@ def test_invalid_array(self):
4242
assert_raises(ValueError, np.random.RandomState, [1, 2, 4294967296])
4343
assert_raises(ValueError, np.random.RandomState, [1, -2, 4294967296])
4444

45+
def test_invalid_array_shape(self):
46+
# gh-9832
47+
assert_raises(ValueError, np.random.RandomState, np.array([], dtype=np.int64))
48+
assert_raises(ValueError, np.random.RandomState, [[1, 2, 3]])
49+
assert_raises(ValueError, np.random.RandomState, [[1, 2, 3],
50+
[4, 5, 6]])
51+
4552

4653
class TestBinomial(object):
4754
def test_n_zero(self):

0 commit comments

Comments
 (0)
0