|
2 | 2 |
|
3 | 3 | import numpy as np
|
4 | 4 | from numpy.testing import assert_array_almost_equal
|
| 5 | +import pytest |
5 | 6 | import matplotlib.pyplot as plt
|
6 | 7 | from matplotlib.testing.decorators import image_comparison
|
7 | 8 | import matplotlib.transforms as mtransforms
|
@@ -114,3 +115,49 @@ def test_streamplot_limits():
|
114 | 115 | # datalim.
|
115 | 116 | assert_array_almost_equal(ax.dataLim.bounds, (20, 30, 15, 6),
|
116 | 117 | decimal=1)
|
| 118 | + |
| 119 | + |
| 120 | +def test_streamplot_grid(): |
| 121 | + u = np.ones((2, 2)) |
| 122 | + v = np.zeros((2, 2)) |
| 123 | + |
| 124 | + # Test for same rows and columns |
| 125 | + x = np.array([[10, 20], [10, 30]]) |
| 126 | + y = np.array([[10, 10], [20, 20]]) |
| 127 | + |
| 128 | + with pytest.raises(ValueError, match="The rows of 'x' must be equal"): |
| 129 | + plt.streamplot(x, y, u, v) |
| 130 | + |
| 131 | + x = np.array([[10, 20], [10, 20]]) |
| 132 | + y = np.array([[10, 10], [20, 30]]) |
| 133 | + |
| 134 | + with pytest.raises(ValueError, match="The columns of 'y' must be equal"): |
| 135 | + plt.streamplot(x, y, u, v) |
| 136 | + |
| 137 | + x = np.array([[10, 20], [10, 20]]) |
| 138 | + y = np.array([[10, 10], [20, 20]]) |
| 139 | + plt.streamplot(x, y, u, v) |
| 140 | + |
| 141 | + # Test for maximum dimensions |
| 142 | + x = np.array([0, 10]) |
| 143 | + y = np.array([[[0, 10]]]) |
| 144 | + |
| 145 | + with pytest.raises(ValueError, match="'y' can have at maximum " |
| 146 | + "2 dimensions"): |
| 147 | + plt.streamplot(x, y, u, v) |
| 148 | + |
| 149 | + # Test for equal spacing |
| 150 | + u = np.ones((3, 3)) |
| 151 | + v = np.zeros((3, 3)) |
| 152 | + x = np.array([0, 10, 20]) |
| 153 | + y = np.array([0, 10, 30]) |
| 154 | + |
| 155 | + with pytest.raises(ValueError, match="'y' values must be equally spaced"): |
| 156 | + plt.streamplot(x, y, u, v) |
| 157 | + |
| 158 | + # Test for strictly increasing |
| 159 | + x = np.array([0, 20, 40]) |
| 160 | + y = np.array([0, 20, 10]) |
| 161 | + |
| 162 | + with pytest.raises(ValueError, match="'y' must be strictly increasing"): |
| 163 | + plt.streamplot(x, y, u, v) |
0 commit comments