1
1
import warnings
2
2
import numpy as np
3
- from numpy .testing import (TestCase , assert_ , assert_raises , assert_equal ,
4
- assert_array_equal , run_module_suite )
3
+ from numpy .testing import (TestCase , assert_ , assert_raises , assert_array_equal ,
4
+ assert_equal , run_module_suite )
5
5
from numpy .core import (array , arange , atleast_1d , atleast_2d , atleast_3d ,
6
6
vstack , hstack , newaxis , concatenate )
7
7
@@ -40,6 +40,7 @@ def test_r1array(self):
40
40
assert_ (atleast_1d (3.0 ).shape == (1 ,))
41
41
assert_ (atleast_1d ([[2 ,3 ],[4 ,5 ]]).shape == (2 ,2 ))
42
42
43
+
43
44
class TestAtleast2d (TestCase ):
44
45
def test_0D_array (self ):
45
46
a = array (1 ); b = array (2 );
@@ -100,6 +101,7 @@ def test_3D_array(self):
100
101
desired = [a ,b ]
101
102
assert_array_equal (res ,desired )
102
103
104
+
103
105
class TestHstack (TestCase ):
104
106
def test_0D_array (self ):
105
107
a = array (1 ); b = array (2 );
@@ -119,6 +121,7 @@ def test_2D_array(self):
119
121
desired = array ([[1 ,1 ],[2 ,2 ]])
120
122
assert_array_equal (res ,desired )
121
123
124
+
122
125
class TestVstack (TestCase ):
123
126
def test_0D_array (self ):
124
127
a = array (1 ); b = array (2 );
@@ -159,5 +162,71 @@ def test_concatenate_axis_None():
159
162
'0' , '1' , '2' , 'x' ])
160
163
assert_array_equal (r ,d )
161
164
165
+
166
+ def test_concatenate ():
167
+ # Test concatenate function
168
+ # No arrays raise ValueError
169
+ assert_raises (ValueError , concatenate , ())
170
+ # Scalars cannot be concatenated
171
+ assert_raises (ValueError , concatenate , (0 ,))
172
+ assert_raises (ValueError , concatenate , (array (0 ),))
173
+ # One sequence returns unmodified (but as array)
174
+ r4 = list (range (4 ))
175
+ assert_array_equal (concatenate ((r4 ,)), r4 )
176
+ # Any sequence
177
+ assert_array_equal (concatenate ((tuple (r4 ),)), r4 )
178
+ assert_array_equal (concatenate ((array (r4 ),)), r4 )
179
+ # 1D default concatenation
180
+ r3 = list (range (3 ))
181
+ assert_array_equal (concatenate ((r4 , r3 )), r4 + r3 )
182
+ # Mixed sequence types
183
+ assert_array_equal (concatenate ((tuple (r4 ), r3 )), r4 + r3 )
184
+ assert_array_equal (concatenate ((array (r4 ), r3 )), r4 + r3 )
185
+ # Explicit axis specification
186
+ assert_array_equal (concatenate ((r4 , r3 ), 0 ), r4 + r3 )
187
+ # Including negative
188
+ assert_array_equal (concatenate ((r4 , r3 ), - 1 ), r4 + r3 )
189
+ # 2D
190
+ a23 = array ([[10 , 11 , 12 ], [13 , 14 , 15 ]])
191
+ a13 = array ([[0 , 1 , 2 ]])
192
+ res = array ([[10 , 11 , 12 ], [13 , 14 , 15 ], [0 , 1 , 2 ]])
193
+ assert_array_equal (concatenate ((a23 , a13 )), res )
194
+ assert_array_equal (concatenate ((a23 , a13 ), 0 ), res )
195
+ assert_array_equal (concatenate ((a23 .T , a13 .T ), 1 ), res .T )
196
+ assert_array_equal (concatenate ((a23 .T , a13 .T ), - 1 ), res .T )
197
+ # Arrays much match shape
198
+ assert_raises (ValueError , concatenate , (a23 .T , a13 .T ), 0 )
199
+ # 3D
200
+ res = arange (2 * 3 * 7 ).reshape ((2 , 3 , 7 ))
201
+ a0 = res [..., :4 ]
202
+ a1 = res [..., 4 :6 ]
203
+ a2 = res [..., 6 :]
204
+ assert_array_equal (concatenate ((a0 , a1 , a2 ), 2 ), res )
205
+ assert_array_equal (concatenate ((a0 , a1 , a2 ), - 1 ), res )
206
+ assert_array_equal (concatenate ((a0 .T , a1 .T , a2 .T ), 0 ), res .T )
207
+
208
+
209
+ def test_concatenate_sloppy0 ():
210
+ # Versions of numpy < 1.7.0 ignored axis argument value for 1D arrays. We
211
+ # allow this for now, but in due course we will raise an error
212
+ r4 = list (range (4 ))
213
+ r3 = list (range (3 ))
214
+ assert_array_equal (concatenate ((r4 , r3 ), 0 ), r4 + r3 )
215
+ warnings .simplefilter ('ignore' , DeprecationWarning )
216
+ try :
217
+ assert_array_equal (concatenate ((r4 , r3 ), - 10 ), r4 + r3 )
218
+ assert_array_equal (concatenate ((r4 , r3 ), 10 ), r4 + r3 )
219
+ finally :
220
+ warnings .filters .pop (0 )
221
+ # Confurm DepractionWarning raised
222
+ warnings .simplefilter ('always' , DeprecationWarning )
223
+ warnings .simplefilter ('error' , DeprecationWarning )
224
+ try :
225
+ assert_raises (DeprecationWarning , concatenate , (r4 , r3 ), 10 )
226
+ finally :
227
+ warnings .filters .pop (0 )
228
+ warnings .filters .pop (0 )
229
+
230
+
162
231
if __name__ == "__main__" :
163
232
run_module_suite ()
0 commit comments