10000 BaggingClassifer/BaggingRegressor tests for sparse input · scikit-learn/scikit-learn@f38db56 · GitHub
[go: up one dir, main page]

Skip to content

Commit f38db56

Browse files
committed
BaggingClassifer/BaggingRegressor tests for sparse input
1 parent 6982977 commit f38db56

File tree

1 file changed

+43
-26
lines changed

1 file changed

+43
-26
lines changed

sklearn/ensemble/tests/test_bagging.py

Lines changed: 43 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@
2828
from sklearn.datasets import load_boston, load_iris
2929
from sklearn.utils import check_random_state
3030

31-
from scipy.sparse import csc_matrix, csr_matrix
31+
from scipy.sparse import csc_matrix
3232

3333
rng = check_random_state(0)
3434

@@ -81,19 +81,27 @@ def test_sparse_classification():
8181
"bootstrap": [True, False],
8282
"bootstrap_features": [True, False]})
8383

84-
for base_estimator in [DummyClassifier(),
85-
Perceptron(),
86-
KNeighborsClassifier(),
87-
SVC()]:
88-
for params in grid:
89-
for sparse_format in [csc_matrix, csr_matrix]:
90-
X_train_sparse = sparse_format(X_train)
91-
X_test_sparse = sparse_format(X_test)
92-
BaggingClassifier(
93-
base_estimator=base_estimator,
94-
random_state=rng,
95-
**params
96-
).fit(X_train_sparse, y_train).predict(X_test_sparse) 10000
84+
base_estimator = SVC()
85+
for params in grid:
86+
sparse_format = csc_matrix
87+
X_train_sparse = sparse_format(X_train)
88+
X_test_sparse = sparse_format(X_test)
89+
90+
# Trained on sparse format
91+
sparse_results = BaggingClassifier(
92+
base_estimator=base_estimator,
93+
random_state=check_random_state(1),
94+
**params
95+
).fit(X_train_sparse, y_train).predict(X_test_sparse)
96+
97+
# Trained on dense format
98+
dense_results = BaggingClassifier(
99+
base_estimator=base_estimator,
100+
random_state=check_random_state(1),
101+
**params
102+
).fit(X_train, y_train).predict(X_test)
103+
104+
assert_array_equal(sparse_results, dense_results)
97105

98106

99107
def test_regression():
@@ -129,18 +137,27 @@ def test_sparse_regression():
129137
"bootstrap": [True, False],
130138
"bootstrap_features": [True, False]})
131139

132-
for base_estimator in [DummyRegressor(),
133-
KNeighborsRegressor(),
134-
SVR()]:
135-
for params in grid:
136-
for sparse_format in [csc_matrix, csr_matrix]:
137-
X_train_sparse = sparse_format(X_train)
138-
X_test_sparse = sparse_format(X_test)
139-
BaggingRegressor(
140-
base_estimator=base_estimator,
141-
random_state=rng,
142-
**params
143-
).fit(X_train_sparse, y_train).predict(X_test_sparse)
140+
base_estimator = SVR()
141+
for params in grid:
142+
sparse_format = csc_matrix
143+
X_train_sparse = sparse_format(X_train)
144+
X_test_sparse = sparse_format(X_test)
145+
146+
# Trained on sparse format
147+
sparse_results = BaggingRegressor(
148+
base_estimator=base_estimator,
149+
random_state=check_random_state(1),
150+
**params
151+
).fit(X_train_sparse, y_train).predict(X_test_sparse)
152+
153+
# Trained on dense format
154+
dense_results = BaggingRegressor(
155+
base_estimator=base_estimator,
156+
random_state=check_random_state(1),
157+
**params
158+
).fit(X_train, y_train).predict(X_test)
159+
160+
assert_array_equal(sparse_results, dense_results)
144161

145162

146163
def test_bootstrap_samples():

0 commit comments

Comments
 (0)
0