8000 Merge pull request #440 from matthew-brett/crazy-axis-concat-warning · numpy/numpy@6a847ef · GitHub
[go: up one dir, main page]

Skip to content

Commit 6a847ef

Browse files
committed
Merge pull request #440 from matthew-brett/crazy-axis-concat-warning
BUG: allow any axis for np.concatenate for 1D
2 parents c8010d0 + 69afd27 commit 6a847ef

File tree

2 files changed

+81
-2
lines changed

2 files changed

+81
-2
lines changed

numpy/core/src/multiarray/multiarraymodule.c

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -337,6 +337,16 @@ PyArray_ConcatenateArrays(int narrays, PyArrayObject **arrays, int axis)
337337
if (axis < 0) {
338338
axis += ndim;
339339
}
340+
341+
if (ndim == 1 & axis != 0) {
342+
char msg[] = "axis != 0 for ndim == 1; this will raise an error in "
343+
"future versions of numpy";
344+
if (DEPRECATE(msg) < 0) {
345+
return NULL;
346+
}
347+
axis = 0;
348+
}
349+
340350
if (axis < 0 || axis >= ndim) {
341351
PyErr_Format(PyExc_IndexError,
342352
"axis %d out of bounds [0, %d)", orig_axis, ndim);

numpy/core/tests/test_shape_base.py

Lines changed: 71 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import warnings
22
import numpy as np
3-
from numpy.testing import (TestCase, assert_, assert_raises, assert_equal,
4-
assert_array_equal, run_module_suite)
3+
from numpy.testing import (TestCase, assert_, assert_raises, assert_array_equal,
4+
assert_equal, run_module_suite)
55
from numpy.core import (array, arange, atleast_1d, atleast_2d, atleast_3d,
66
vstack, hstack, newaxis, concatenate)
77

@@ -40,6 +40,7 @@ def test_r1array(self):
4040
assert_(atleast_1d(3.0).shape == (1,))
4141
assert_(atleast_1d([[2,3],[4,5]]).shape == (2,2))
4242

43+
4344
class TestAtleast2d(TestCase):
4445
def test_0D_array(self):
4546
a = array(1); b = array(2);
@@ -100,6 +101,7 @@ def test_3D_array(self):
100101
desired = [a,b]
101102
assert_array_equal(res,desired)
102103

104+
103105
class TestHstack(TestCase):
104106
def test_0D_array(self):
105107
a = array(1); b = array(2);
@@ -119,6 +121,7 @@ def test_2D_array(self):
119121
desired = array([[1,1],[2,2]])
120122
assert_array_equal(res,desired)
121123

124+
122125
class TestVstack(TestCase):
123126
def test_0D_array(self):
124127
a = array(1); b = array(2);
@@ -159,5 +162,71 @@ def test_concatenate_axis_None():
159162
'0', '1', '2', 'x'])
160163
assert_array_equal(r,d)
161164

165+
166+
def test_concatenate():
167+
# Test concatenate function
168+
# No arrays raise ValueError
169+
assert_raises(ValueError, concatenate, ())
170+
# Scalars cannot be concatenated
171+
assert_raises(ValueError, concatenate, (0,))
172+
assert_raises(ValueError, concatenate, (array(0),))
173+
# One sequence returns unmodified (but as array)
174+
r4 = list(range(4))
175+
assert_array_equal(concatenate((r4,)), r4)
176+
# Any sequence
177+
assert_array_equal(concatenate((tuple(r4),)), r4)
178+
assert_array_equal(concatenate((array(r4),)), r4)
179+
# 1D default concatenation
180+
r3 = list(range(3))
181+
assert_array_equal(concatenate((r4, r3)), r4 + r3)
182+
# Mixed sequence types
183+
assert_array_equal(concatenate((tuple(r4), r3)), r4 + r3)
184+
assert_array_equal(concatenate((array(r4), r3)), r4 + r3)
185+
# Explicit axis specification
186+
assert_array_equal(concatenate((r4, r3), 0), r4 + r3)
187+
# Including negative
188+
assert_array_equal(concatenate((r4, r3), -1), r4 + r3)
189+
# 2D
190+
a23 = array([[10, 11, 12], [13, 14, 15]])
191+
a13 = array([[0, 1, 2]])
192+
res = array([[10, 11, 12], [13, 14, 15], [0, 1, 2]])
193+
assert_array_equal(concatenate((a23, a13)), res)
194+
assert_array_equal(concatenate((a23, a13), 0), res)
195+
assert_array_equal(concatenate((a23.T, a13.T), 1), res.T)
196+
assert_array_equal(concatenate((a23.T, a13.T), -1), res.T)
197+
# Arrays much match shape
198+
assert_raises(ValueError, concatenate, (a23.T, a13.T), 0)
199+
# 3D
200+
res = arange(2 * 3 * 7).reshape((2, 3, 7))
201+
a0 = res[..., :4]
202+
a1 = res[..., 4:6]
203+
a2 = res[..., 6:]
204+
assert_array_equal(concatenate((a0, a1, a2), 2), res)
205+
assert_array_equal(concatenate((a0, a1, a2), -1), res)
206+
assert_array_equal(concatenate((a0.T, a1.T, a2.T), 0), res.T)
207+
208+
209+
def test_concatenate_sloppy0():
210+
# Versions of numpy < 1.7.0 ignored axis argument value for 1D arrays. We
211+
# allow this for now, but in due course we will raise an error
212+
r4 = list(range(4))
213+
r3 = list(range(3))
214+
assert_array_equal(concatenate((r4, r3), 0), r4 + r3)
215+
warnings.simplefilter('ignore', DeprecationWarning)
216+
try:
217+
assert_array_equal(concatenate((r4, r3), -10), r4 + r3)
218+
assert_array_equal(concatenate((r4, r3), 10), r4 + r3)
219+
finally:
220+
warnings.filters.pop(0)
221+
# Confurm DepractionWarning raised
222+
warnings.simplefilter('always', DeprecationWarning)
223+
warnings.simplefilter('error', DeprecationWarning)
224+
try:
225+
assert_raises(DeprecationWarning, concatenate, (r4, r3), 10)
226+
finally:
227+
warnings.filters.pop(0)
228+
warnings.filters.pop(0)
229+
230+
162231
if __name__ == "__main__":
163232
run_module_suite()

0 commit comments

Comments
 (0)
0