8000 Test normalize function in data.py · scikit-learn/scikit-learn@16829dc · GitHub
[go: up one dir, main page]

Skip to content

Commit 16829dc

Browse files
committed
Test normalize function in data.py
1 parent 27ac56b commit 16829dc

File tree

1 file changed

+33
-0
lines changed

1 file changed

+33
-0
lines changed

sklearn/preprocessing/tests/test_data.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1270,6 +1270,39 @@ def test_normalize():
12701270
assert_raises(ValueError, normalize, [[0]], norm='l3')
12711271

12721272

1273+
def test_normalize_l1():
1274+
rs = np.random.RandomState(0)
1275+
X_dense = rs.rand(10, 5)
1276+
X_sparse = sparse.csr_matrix(X_dense)
1277+
ones = np.ones((10))
1278+
for X in (X_dense, X_sparse):
1279+
for dtype in (np.float32, np.float64):
1280+
X = X.astype(dtype)
1281+
X_norm = normalize(X, norm='l1')
1282+
assert_equal(X_norm.dtype, dtype)
1283+
1284+
X_norm = toarray(X_norm)
1285+
row_sums = np.abs(X_norm).sum(axis=1)
1286+
assert_array_almost_equal(row_sums, ones)
1287+
1288+
1289+
def test_normalize_l2():
1290+
rs = np.random.RandomState(0)
1291+
X_dense = rs.rand(10, 5)
1292+
X_sparse = sparse.csr_matrix(X_dense)
1293+
ones = np.ones((10))
1294+
for X in (X_dense, X_sparse):
1295+
for dtype in (np.float32, np.float64):
1296+
X = X.astype(dtype)
1297+
X_norm = normalize(X, norm='l2')
1298+
assert_equal(X_norm.dtype, dtype)
1299+
1300+
X_norm = toarray(X_norm)
1301+
X_norm_squared = X_norm**2
1302+
row_sums = X_norm_squared.sum(axis=1)
1303+
assert_array_almost_equal(row_sums, ones)
1304+
1305+
12731306
def test_binarizer():
12741307
X_ = np.array([[1, 0, 5], [2, 3, -1]])
12751308

0 commit comments

Comments
 (0)
0