8000 TST: Fix #6542: Add test for non-iterable input... · gkBCCN/numpy@a3e12c9 · GitHub
[go: up one dir, main page]

Skip to content

Commit a3e12c9

Browse files
committed
TST: Fix numpy#6542: Add test for non-iterable input...
...for hsplit, vsplit, dsplit, dstack
1 parent 8e05d78 commit a3e12c9

File tree

1 file changed

+40
-1
lines changed

1 file changed

+40
-1
lines changed

numpy/lib/tests/test_shape_base.py

Lines changed: 40 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import numpy as np
44
from numpy.lib.shape_base import (
55
apply_along_axis, apply_over_axes, array_split, split, hsplit, dsplit,
6-
vsplit, dstack, kron, tile
6+
vsplit, dstack, column_stack, kron, tile
77
)
88
from numpy.testing import (
99
run_module_suite, TestCase, assert_, assert_equal, assert_array_equal,
@@ -175,8 +175,15 @@ def test_unequal_split(self):
175175
a = np.arange(10)
176176
assert_raises(ValueError, split, a, 3)
177177

178+
class TestColumnStack(TestCase):
179+
def test_non_iterable(self):
180+
assert_raises(TypeError, column_stack, 1)
181+
178182

179183
class TestDstack(TestCase):
184+
def test_non_iterable(self):
185+
assert_raises(TypeError, dstack, 1)
186+
180187
def test_0D_array(self):
181188
a = np.array(1)
182189
b = np.array(2)
@@ -212,6 +219,9 @@ class TestHsplit(TestCase):
212219
"""Only testing for integer splits.
213220
214221
"""
222+
def test_non_iterable(self):
223+
assert_raises(ValueError, hsplit, 1,1)
224+
215225
def test_0D_array(self):
216226
a = np.array(1)
217227
try:
@@ -238,6 +248,17 @@ class TestVsplit(TestCase):
238248
"""Only testing for integer splits.
239249
240250
"""
251+
def test_non_iterable(self):
252+
assert_raises(ValueError, vsplit, 1,1)
253+
254+
def test_0D_array(self):
255+
a = np.array(1)
256+
try:
257+
vsplit(a, 2)
258+
assert_(0)
259+
except ValueError:
260+
pass
261+
241262
def test_1D_array(self):
242263
a = np.array([1, 2, 3, 4])
243264
try:
@@ -256,6 +277,24 @@ def test_2D_array(self):
256277

257278
class TestDsplit(TestCase):
258279
# Only testing for integer splits.
280+
def test_non_iterable(self):
281+
assert_raises(ValueError, dsplit, 1,1)
282+
283+
def test_0D_array(self):
284+
a = np.array(1)
285+
try:
286+
dsplit(a, 2)
287+
assert_(0)
288+
except ValueError:
289+
pass
290+
291+
def test_1D_array(self):
292+
a = np.array([1, 2, 3, 4])
293+
try:
294+
dsplit(a, 2)
295+
assert_(0)
296+
except ValueError:
297+
pass
259298

260299
def test_2D_array(self):
261300
a = np.array([[1, 2, 3, 4],

0 commit comments

Comments
 (0)
0