8000 Merge pull request #429 from 87/fix_insert · numpy/numpy@aed8fc5 · GitHub
[go: up one dir, main page]

Skip to content

Commit aed8fc5

Browse files
committed
Merge pull request #429 from 87/fix_insert
Fix for issues #392 and #378
2 parents b6a1acd + 1688b29 commit aed8fc5

File tree

2 files changed

+17
-6
lines changed

2 files changed

+17
-6
lines changed

numpy/lib/function_base.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3591,19 +3591,21 @@ def insert(arr, obj, values, axis=None):
35913591
slobj = [slice(None)]*ndim
35923592
N = arr.shape[axis]
35933593
newshape = list(arr.shape)
3594-
if isinstance(obj, (int, long, integer)):
35953594

3595+
if isinstance(obj, (int, long, integer)):
35963596
if (obj < 0): obj += N
35973597
if obj < 0 or obj > N:
35983598
raise ValueError(
35993599
"index (%d) out of range (0<=index<=%d) "\
36003600
"in dimension %d" % (obj, N, axis))
3601-
3602-
if isinstance(values, (int, long, integer)):
3603-
obj = [obj]
3601+
if isscalar(values):
3602+
obj = [obj]
36043603
else:
3605-
obj = [obj] * len(values)
3606-
3604+
values = asarray(values)
3605+
if ndim > values.ndim:
3606+
obj = [obj]
3607+
else:
3608+
obj = [obj] * len(values)
36073609

36083610
elif isinstance(obj, slice):
36093611
# turn it into a range object

numpy/lib/tests/test_function_base.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -147,6 +147,15 @@ def test_basic(self):
147147
assert_equal(insert(a, [1, 1, 1], [1, 2, 3]), [1, 1, 2, 3, 2, 3])
148148
assert_equal(insert(a, 1,[1,2,3]), [1, 1, 2, 3, 2, 3])
149149
assert_equal(insert(a,[1,2,3],9),[1,9,2,9,3,9])
150+
b = np.array([0, 1], dtype=np.float64)
151+
assert_equal(insert(b, 0, b[0]), [0., 0., 1.])
152+
def test_multidim(self):
153+
a = [[1, 1, 1]]
154+
r = [[2, 2, 2],
155+
[1, 1, 1]]
156+
assert_equal(insert(a, 0, [2, 2, 2], axis=0), r)
157+
assert_equal(insert(a, 0, 2, axis=0), r)
158+
assert_equal(insert(a, 2, 2, axis=1), [[1, 1, 2, 1]])
150159

151160
class TestAmax(TestCase):
152161
def test_basic(self):

0 commit comments

Comments
 (0)
0