8000 TST Extend tests for `scipy.sparse/*array` in `sklearn/linear_model/tests/test_sparse_coordinate_descent` by Charlie-XIAO · Pull Request #27237 · scikit-learn/scikit-learn · GitHub
[go: up one dir, main page]

Skip to content

TST Extend tests for scipy.sparse/*array in sklearn/linear_model/tests/test_sparse_coordinate_descent #27237

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
78 changes: 46 additions & 32 deletions sklearn/linear_model/tests/test_sparse_coordinate_descent.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
create_memmap_backed_data,
ignore_warnings,
)
from sklearn.utils.fixes import COO_CONTAINERS, CSC_CONTAINERS, LIL_CONTAINERS


def test_sparse_coef():
Expand All @@ -23,9 +24,10 @@ def test_sparse_coef():
assert clf.sparse_coef_.toarray().tolist()[0] == clf.coef_


def test_lasso_zero():
@pytest.mark.parametrize("csc_container", CSC_CONTAINERS)
def test_lasso_zero(csc_container):
# Check that the sparse lasso can handle zero data without crashing
X = sp.csc_matrix((3, 1))
X = csc_container((3, 1))
y = [0, 0, 0]
T = np.array([[1], [2], [3]])
clf = Lasso().fit(X, y)
Expand All @@ -36,11 +38,12 @@ def test_lasso_zero():


@pytest.mark.parametrize("with_sample_weight", [True, False])
def test_enet_toy_list_input(with_sample_weight):
@pytest.mark.parametrize("csc_container", CSC_CONTAINERS)
def test_enet_toy_list_input(with_sample_weight, csc_container):
# Test ElasticNet for various values of alpha and l1_ratio with list X

X = np.array([[-1], [0], [1]])
X = sp.csc_matrix(X)
X = csc_container(X)
Y = [-1, 0, 1] # just a straight line
T = np.array([[2], [3], [4]]) # test sample
if with_sample_weight:
Expand Down Expand Up @@ -73,18 +76,19 @@ def test_enet_toy_list_input(with_sample_weight):
assert_almost_equal(clf.dual_gap_, 0)


def test_enet_toy_explicit_sparse_input():
< 10000 span class='blob-code-inner blob-code-marker ' data-code-marker="+">@pytest.mark.parametrize("lil_container", LIL_CONTAINERS)
def test_enet_toy_explicit_sparse_input(lil_container):
# Test ElasticNet for various values of alpha and l1_ratio with sparse X
f = ignore_warnings
# training samples
X = sp.lil_matrix((3, 1))
X = lil_container((3, 1))
X[0, 0] = -1
# X[1, 0] = 0
X[2, 0] = 1
Y = [-1, 0, 1] # just a straight line (the identity function)

# test samples
T = sp.lil_matrix((3, 1))
T = lil_container((3, 1))
T[0, 0] = 2
T[1, 0] = 3
T[2, 0] = 4
Expand Down Expand Up @@ -113,6 +117,7 @@ def test_enet_toy_explicit_sparse_input():


def make_sparse_data(
sparse_container,
n_samples=100,
n_features=100,
n_informative=10,
Expand All @@ -137,17 +142,24 @@ def make_sparse_data(

# generate training ground truth labels
y = np.dot(X, w)
X = sp.csc_matrix(X)
X = sparse_container(X)
if n_targets == 1:
y = np.ravel(y)
return X, y


def _test_sparse_enet_not_as_toy_dataset(alpha, fit_intercept, positive):
@pytest.mark.parametrize("csc_container", CSC_CONTAINERS)
@pytest.mark.parametrize(
"alpha, fit_intercept, positive",
[(0.1, False, False), (0.1, True, False), (1e-3, False, True), (1e-3, True, True)],
)
def test_sparse_enet_not_as_toy_dataset(csc_container, alpha, fit_intercept, positive):
n_samples, n_features, max_iter = 100, 100, 1000
n_informative = 10

X, y = make_sparse_data(n_samples, n_features, n_informative, positive=positive)
X, y = make_sparse_data(
csc_container, n_samples, n_features, n_informative, positive=positive
)

X_train, X_test = X[n_samples // 2 :], X[: n_samples // 2]
y_train, y_test = y[n_samples // 2 :], y[: n_samples // 2]
Expand Down Expand Up @@ -188,18 +200,14 @@ def _test_sparse_enet_not_as_toy_dataset(alpha, fit_intercept, positive):
assert np.sum(s_clf.coef_ != 0.0) < 2 * n_informative


def test_sparse_enet_not_as_toy_dataset():
_test_sparse_enet_not_as_toy_dataset(alpha=0.1, fit_intercept=False, positive=False)
_test_sparse_enet_not_as_toy_dataset(alpha=0.1, fit_intercept=True, positive=False)
_test_sparse_enet_not_as_toy_dataset(alpha=1e-3, fit_intercept=False, positive=True)
_test_sparse_enet_not_as_toy_dataset(alpha=1e-3, fit_intercept=True, positive=True)


def test_sparse_lasso_not_as_toy_dataset():
@pytest.mark.parametrize("csc_container", CSC_CONTAINERS)
def test_sparse_lasso_not_as_toy_dataset(csc_container):
n_samples = 100
max_iter = 1000
n_informative = 10
X, y = make_sparse_data(n_samples=n_samples, n_informative=n_informative)
X, y = make_sparse_data(
csc_container, n_samples=n_samples, n_informative=n_informative
)

X_train, X_test = X[n_samples // 2 :], X[: n_samples // 2]
y_train, y_test = y[n_samples // 2 :], y[: n_samples // 2]
Expand All @@ -219,9 +227,10 @@ def test_sparse_lasso_not_as_toy_dataset():
assert np.sum(s_clf.coef_ != 0.0) == n_informative


def test_enet_multitarget():
@pytest.mark.parametrize("csc_container", CSC_CONTAINERS)
def test_enet_multitarget(csc_container):
n_targets = 3
X, y = make_sparse_data(n_targets=n_targets)
X, y = make_sparse_data(csc_container, n_targets=n_targets)

estimator = ElasticNet(alpha=0.01, precompute=False)
# XXX: There is a bug when precompute is not False!
Expand All @@ -239,8 +248,9 @@ def test_enet_multitarget():
assert_array_almost_equal(dual_gap[k], estimator.dual_gap_)


def test_path_parameters():
X, y = make_sparse_data()
@pytest.mark.parametrize("csc_container", CSC_CONTAINERS)
def test_path_parameters(csc_container):
X, y = make_sparse_data(csc_container)
max_iter = 50
n_alphas = 10
clf = ElasticNetCV(
Expand All @@ -263,8 +273,9 @@ def test_path_parameters():
@pytest.mark.parametrize("fit_intercept", [False, True])
@pytest.mark.parametrize("n_samples, n_features", [(24, 6), (6, 24)])
@pytest.mark.parametrize("with_sample_weight", [True, False])
@pytest.mark.parametrize("csc_container", CSC_CONTAINERS)
def test_sparse_dense_equality(
Model, fit_intercept, n_samples, n_features, with_sample_weight
Model, fit_intercept, n_samples, n_features, with_sample_weight, csc_container
):
X, y = make_regression(
n_samples=n_samples,
Expand All @@ -279,7 +290,7 @@ def test_sparse_dense_equality(
sw = np.abs(np.random.RandomState(42).normal(scale=10, size=y.shape))
else:
sw = None
Xs = sp.csc_matrix(X)
Xs = csc_container(X)
params = {"fit_intercept": fit_intercept}
reg_dense = Model(**params).fit(X, y, sample_weight=sw)
reg_sparse = Model(**params).fit(Xs, y, sample_weight=sw)
Expand All @@ -292,8 +303,9 @@ def test_sparse_dense_equality(
assert_allclose(reg_sparse.coef_, reg_dense.coef_)


def test_same_output_sparse_dense_lasso_and_enet_cv():
X, y = make_sparse_data(n_samples=40, n_features=10)
@pytest.mark.parametrize("csc_container", CSC_CONTAINERS)
def test_same_output_sparse_dense_lasso_and_enet_cv(csc_container):
X, y = make_sparse_data(csc_container, n_samples=40, n_features=10)
clfs = ElasticNetCV(max_iter=100)
clfs.fit(X, y)
clfd = ElasticNetCV(max_iter=100)
Expand All @@ -313,7 +325,8 @@ def test_same_output_sparse_dense_lasso_and_enet_cv():
assert_array_almost_equal(clfs.alphas_, clfd.alphas_)


def test_same_multiple_output_sparse_dense():
@pytest.mark.parametrize("coo_container", COO_CONTAINERS)
def test_same_multiple_output_sparse_dense(coo_container):
l = ElasticNet()
X = [
[0, 1, 2, 3, 4],
Expand All @@ -332,20 +345,21 @@ def test_same_multiple_output_sparse_dense():
predict_dense = l.predict(sample)

l_sp = ElasticNet()
X_sp = sp.coo_matrix(X)
X_sp = coo_container(X)
l_sp.fit(X_sp, y)
sample_sparse = sp.coo_matrix(sample)
sample_sparse = coo_container(sample)
predict_sparse = l_sp.predict(sample_sparse)

assert_array_almost_equal(predict_sparse, predict_dense)


def test_sparse_enet_coordinate_descent():
@pytest.mark.parametrize("csc_container", CSC_CONTAINERS)
def test_sparse_enet_coordinate_descent(csc_container):
"""Test that a warning is issued if model does not converge"""
clf = Lasso(max_iter=2)
n_samples = 5
n_features = 2
X = sp.csc_matrix((n_samples, n_features)) * 1e50
X = csc_container((n_samples, n_features)) * 1e50
y = np.ones(n_samples)
warning_message = (
"Objective did not converge. You might want "
Expand Down
0