8000 Test row_norms for float32 data · scikit-learn/scikit-learn@ad211a1 · GitHub
[go: up one dir, main page]

Skip to content

Commit ad211a1

Browse files
committed
Test row_norms for float32 data
1 parent 65e91f7 commit ad211a1

File tree

1 file changed

+8
-6
lines changed

1 file changed

+8
-6
lines changed

sklearn/utils/tests/test_extmath.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -148,14 +148,16 @@ def test_norm_squared_norm():
148148

149149
def test_row_norms():
150150
X = np.random.RandomState(42).randn(100, 100)
151-
sq_norm = (X ** 2).sum(axis=1)
151+
for dtype in (np.float32, np.float64):
152+
X = X.astype(dtype)
153+
sq_norm = (X ** 2).sum(axis=1)
152154

153-
assert_array_almost_equal(sq_norm, row_norms(X, squared=True), 5)
154-
assert_array_almost_equal(np.sqrt(sq_norm), row_norms(X))
155+
assert_array_almost_equal(sq_norm, row_norms(X, squared=True), 4)
156+
assert_array_almost_equal(np.sqrt(sq_norm), row_norms(X), 4)
155157

156-
Xcsr = sparse.csr_matrix(X, dtype=np.float32)
157-
assert_array_almost_equal(sq_norm, row_norms(Xcsr, squared=True), 5)
158-
assert_array_almost_equal(np.sqrt(sq_norm), row_norms(Xcsr))
158+
Xcsr = sparse.csr_matrix(X, dtype=dtype)
159+
assert_array_almost_equal(sq_norm, row_norms(Xcsr, squared=True), 4)
160+
assert_array_almost_equal(np.sqrt(sq_norm), row_norms(Xcsr), 4)
159161

160162

161163
def test_randomized_svd_low_rank_with_noise():

0 commit comments

Comments
 (0)
0