8000 Merge pull request #7000 from gfyoung/ndarray_arg_enforce · numpy/numpy@54224f4 · GitHub
[go: up one dir, main page]

Skip to content
< 8000 /react-partial>

Commit 54224f4

Browse files
committed
Merge pull request #7000 from gfyoung/ndarray_arg_enforce
DOC, MAINT: Enforce np.ndarray arg for np.put and np.place
2 parents 4d87d90 + 02bcbd7 commit 54224f4

File tree

5 files changed

+33
-6
lines changed

5 files changed

+33
-6
lines changed

numpy/core/fromnumeric.py

8000 Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -445,7 +445,13 @@ def put(a, ind, v, mode='raise'):
445445
array([ 0, 1, 2, 3, -5])
446446
447447
"""
448-
return a.put(ind, v, mode)
448+
try:
449+
put = a.put
450+
except AttributeError:
451+
raise TypeError("argument 1 must be numpy.ndarray, "
452+
"not {name}".format(name=type(a).__name__))
453+
454+
return put(ind, v, mode)
449455

450456

451457
def swapaxes(a, axis1, axis2):

numpy/core/tests/test_fromnumeric.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
from __future__ import division, absolute_import, print_function
2+
3+
from numpy import put
4+
from numpy.testing import TestCase, assert_raises
5+
6+
7+
class TestPut(TestCase):
8+
9+
def test_bad_array(self):
10+
# We want to raise a TypeError in the
11+
# case that a non-ndarray object is passed
12+
# in since `np.put` modifies in place and
13+
# hence would do nothing to a non-ndarray
14+
v = 5
15+
indx = [0, 2]
16+
bad_array = [1, 2, 3]
17+
assert_raises(TypeError, put, bad_array, indx, v)

numpy/lib/function_base.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1779,7 +1779,7 @@ def place(arr, mask, vals):
17791779
17801780
Parameters
17811781
----------
1782-
arr : array_like
1782+
arr : ndarray
17831783
Array to put data into.
17841784
mask : array_like
17851785
Boolean mask array. Must have the same size as `a`.
@@ -1801,6 +1801,10 @@ def place(arr, mask, vals):
18011801
[44, 55, 44]])
18021802
18031803
"""
1804+
if not isinstance(arr, np.ndarray):
1805+
raise TypeError("argument 1 must be numpy.ndarray, "
1806+
"not {name}".format(name=type(arr).__name__))
1807+
18041808
return _insert(arr, mask, vals)
18051809

18061810

numpy/lib/tests/test_function_base.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -674,6 +674,10 @@ def test_basic(self):
674674
assert_array_equal(b, [3, 2, 2, 3, 3])
675675

676676
def test_place(self):
677+
# Make sure that non-np.ndarray objects
678+
# raise an error instead of doing nothing
679+
assert_raises(TypeError, place, [1, 2, 3], [True, False], [0, 1])
680+
677681
a = np.array([1, 4, 3, 2, 5, 8, 7])
678682
place(a, [0, 1, 0, 1, 0, 1, 0], [2, 4, 6])
679683
assert_array_equal(a, [1, 2, 3, 4, 5, 6, 7])

numpy/lib/tests/test_regression.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -85,10 +85,6 @@ def test_poly_eq(self, level=rlevel):
8585
assert_(x != y)
8686
assert_(x == x)
8787

88-
def test_mem_insert(self, level=rlevel):
89-
# Ticket #572
90-
np.lib.place(1, 1, 1)
91-
9288
def test_polyfit_build(self):
9389
# Ticket #628
9490
ref = [-1.06123820e-06, 5.70886914e-04, -1.13822012e-01,

0 commit comments

Comments
 (0)
0