8000 Test normalize function in data.py · scikit-learn/scikit-learn@08e1db4 · GitHub
[go: up one dir, main page]

Skip to content

Commit 08e1db4

Browse files
committed
Test normalize function in data.py
1 parent e34bbc1 commit 08e1db4

File tree

1 file changed

+20
-0
lines changed

1 file changed

+20
-0
lines changed

sklearn/preprocessing/tests/test_data.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1269,6 +1269,26 @@ def test_normalize():
12691269
assert_raises(ValueError, normalize, [[0]], axis=2)
12701270
assert_raises(ValueError, normalize, [[0]], norm='l3')
12711271

1272+
rs = np.random.RandomState(0)
1273+
X_dense = rs.randn(10, 5)
1274+
X_sparse = sparse.csr_matrix(X_dense)
1275+
ones = np.ones((10))
1276+
for X in (X_dense, X_sparse):
1277+
for dtype in (np.float32, np.float64):
1278+
for norm in ('l1', 'l2'):
1279+
X = X.astype(dtype)
1280+
X_norm = normalize(X, norm=norm)
1281+
assert_equal(X_norm.dtype, dtype)
1282+
1283+
X_norm = toarray(X_norm)
1284+
if norm == 'l1':
1285+
row_sums = np.abs(X_norm).sum(axis=1)
1286+
else:
1287+
X_norm_squared = X_norm**2
1288+
row_sums = X_norm_squared.sum(axis=1)
1289+
1290+
assert_array_almost_equal(row_sums, ones)
1291+
12721292

12731293
def test_binarizer():
12741294
X_ = np.array([[1, 0, 5], [2, 3, -1]])

0 commit comments

Comments
 (0)
0