8000 Merge pull request #7366 from gkBCCN/bug-fix-6542-reloaded · numpy/numpy@8ccfede · GitHub
[go: up one dir, main page]

Skip to content

Commit 8ccfede

Browse files
committed
Merge pull request #7366 from gkBCCN/bug-fix-6542-reloaded
TST: fix #6542, add tests to check non-iterable argument raises in hstack and related functions.
2 parents 25ac6b1 + 4aac3ae commit 8ccfede

File tree

2 files changed

+37
-1
lines changed

2 files changed

+37
-1
lines changed

numpy/core/tests/test_shape_base.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -120,6 +120,9 @@ def test_3D_array(self):
120120

121121

122122
class TestHstack(TestCase):
123+
def test_non_iterable(self):
124+
assert_raises(TypeError, hstack, 1)
125+
123126
def test_0D_array(self):
124127
a = array(1)
125128
b = array(2)
@@ -143,6 +146,9 @@ def test_2D_array(self):
143146

144147

145148
class TestVstack(TestCase):
149+
def test_non_iterable(self):
150+
assert_raises(TypeError, vstack, 1)
151+
146152
def test_0D_array(self):
147153
a = array(1)
148154
b = array(2)
@@ -265,6 +271,9 @@ def test_concatenate(self):
265271

266272

267273
def test_stack():
274+
# non-iterable input
275+
assert_raises(TypeError, stack, 1)
276+
268277
# 0d input
269278
for input_ in [(1, 2, 3),
270279
[np.int32(1), np.int32(2), np.int32(3)],

numpy/lib/tests/test_shape_base.py

Lines changed: 28 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import numpy as np
44
from numpy.lib.shape_base import (
55
apply_along_axis, apply_over_axes, array_split, split, hsplit, dsplit,
6-
vsplit, dstack, kron, tile
6+
vsplit, dstack, column_stack, kron, tile
77
)
88
from numpy.testing import (
99
run_module_suite, TestCase, assert_, assert_equal, assert_array_equal,
@@ -175,8 +175,15 @@ def test_unequal_split(self):
175175
a = np.arange(10)
176176
assert_raises(ValueError, split, a, 3)
177177

178+
class TestColumnStack(TestCase):
179+
def test_non_iterable(self):
180+
assert_raises(TypeError, column_stack, 1)
181+
178182

179183
class TestDstack(TestCase):
184+
def test_non_iterable(self):
185+
assert_raises(TypeError, dstack, 1)
186+
180187
def test_0D_array(self):
181188
a = np.array(1)
182189
b = np.array(2)
@@ -212,6 +219,9 @@ class TestHsplit(TestCase):
212219
"""Only testing for integer splits.
213220
214221
"""
222+
def test_non_iterable(self):
223+
assert_raises(ValueError, hsplit, 1, 1)
224+
215225
def test_0D_array(self):
216226
a = np.array(1)
217227
try:
@@ -238,6 +248,13 @@ class TestVsplit(TestCase):
238248
"""Only testing for integer splits.
239249
240250
"""
251+
def test_non_iterable(self):
252+
assert_raises(ValueError, vsplit, 1, 1)
253+
254+
def test_0D_array(self):
255+
a = np.array(1)
256+
assert_raises(ValueError, vsplit, a, 2)
257+
241258
def test_1D_array(self):
242259
a = np.array([1, 2, 3, 4])
243260
try:
@@ -256,6 +273,16 @@ def test_2D_array(self):
256273

257274
class TestDsplit(TestCase):
258275
# Only testing for integer splits.
276+
def test_non_iterable(self):
277+
assert_raises(ValueError, dsplit, 1, 1)
278+
279+
def test_0D_array(self):
280+
a = np.array(1)
281+
assert_raises(ValueError, dsplit, a, 2)
282+
283+
def test_1D_array(self):
284+
a = np.array([1, 2, 3, 4])
285+
assert_raises(ValueError, dsplit, a, 2)
259286

260287
def test_2D_array(self):
261288
a = np.array([[1, 2, 3, 4],

0 commit comments

Comments
 (0)
0