8000 flatten_transform: add tests and validation · genvalen/scikit-learn@ae4698c · GitHub
[go: up one dir, main page]

Skip to content

Commit ae4698c

Browse files
committed
flatten_transform: add tests and validation
1 parent d67a9fc commit ae4698c

File tree

3 files changed

+175
-1
lines changed

3 files changed

+175
-1
lines changed

sklearn/ensemble/_voting.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,16 @@ def fit(self, X, y, sample_weight=None):
8080
% (len(self.weights), len(self.estimators))
8181
)
8282

83+
# if self.n_jobs is not None:
84+
# if isinstance(self.n_jobs, numbers.Integral):
85+
# if self.n_jobs < 0:
86+
# check_scalar(
87+
# self.n_jobs,
88+
# name="n_jobs",
89+
# target_type=numbers.Integral,
90+
# max_val=-1,
91+
# )
92+
8393
self.estimators_ = Parallel(n_jobs=self.n_jobs)(
8494
delayed(_fit_single_estimator)(
8595
clone(clf),
@@ -321,6 +331,12 @@ def fit(self, X, y, sample_weight=None):
321331
"Multilabel and m 8000 ulti-output classification is not supported."
322332
)
323333

334+
check_scalar(
335+
self.flatten_transform,
336+
name="flatten_transform",
337+
target_type=(numbers.Integral, np.bool_),
338+
)
339+
324340
if self.voting not in ("soft", "hard"):
325341
raise ValueError(
326342
"Voting must be 'soft' or 'hard'; got (voting=%r)" % self.voting

sklearn/ensemble/tests/test_gradient_boosting.py

Lines changed: 148 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,7 @@ def test_classification_toy(loss):
8181
@pytest.mark.parametrize(
8282
"params, err_msg",
8383
[
84+
<<<<<<< Updated upstream
8485
({"n_estimators": 0}, "n_estimators must be greater than 0"),
8586
({"n_estimators": -1}, "n_estimators must be greater than 0"),
8687
({"learning_rate": 0}, "learning_rate must be greater than 0"),
@@ -104,6 +105,153 @@ def test_classification_toy(loss):
104105
({"max_features": 100}, r"max_features must be in \(0, n_features\]"),
105106
({"max_features": -0.1}, r"max_features must be in \(0, n_features\]"),
106107
({"n_iter_no_change": "invalid"}, "n_iter_no_change should either be"),
108+
=======
109+
({"learning_rate": 0}, ValueError, "learning_rate == 0, must be > 0."),
110+
({"learning_rate": -1.0}, ValueError, "learning_rate == -1.0, must be > 0."),
111+
({"n_estimators": 0}, ValueError, "n_estimators == 0, must be >= 1."),
112+
({"n_estimators": -1}, ValueError, "n_estimators == -1, must be >= 1."),
113+
(
114+
{"n_estimators": 1.5},
115+
TypeError,
116+
"n_estimators must be an instance of <class 'numbers.Integral'>,",
117+
),
118+
({"loss": "foobar"}, ValueError, "Loss 'foobar' not supported"),
119+
# ({"min_samples_split": 1}, ValueError, "min_samples_split == 1, must be >= 2"),
120+
# (
121+
# {"min_samples_split": 900},
122+
# ValueError,
123+
# "min_samples_split == 900, must be <=",
124+
# ),
125+
# (
126+
# {"min_samples_split": 0.0},
127+
# ValueError,
128+
# "min_samples_split == 0.0, must be > 0.0",
129+
# ),
130+
# (
131+
# {"min_samples_split": 1.1},
132+
# ValueError,
133+
# "min_samples_split == 1.1, must be <= 1.0",
134+
# ),
135+
# (
136+
# {"min_samples_split": "foo"},
137+
# TypeError,
138+
# "min_samples_split must be an instance of <class 'numbers.Real'>",
139+
# ),
140+
# ({"min_samples_leaf": 0}, ValueError, "min_samples_leaf == 0, must be >= 1"),
141+
# ({"min_samples_leaf": 900}, ValueError, "min_samples_leaf == 900, must be <="),
142+
# ({"min_samples_leaf": 0.0}, ValueError, "min_samples_leaf == 0.0, must be > 0"),
143+
# (
144+
# {"min_samples_leaf": 0.6},
145+
# ValueError,
146+
# "min_samples_leaf == 0.6, must be <= 0.5",
147+
# ),
148+
# (
149+
# {"min_samples_leaf": "foo"},
150+
# TypeError,
151+
# "min_samples_leaf must be an instance of <class 'numbers.Real'>",
152+
# ),
153+
# (
154+
# {"min_weight_fraction_leaf": -1},
155+
# ValueError,
156+
# "min_weight_fraction_leaf == -1, must be >= 0.0",
157+
# ),
158+
# (
159+
# {"min_weight_fraction_leaf": 0.6},
160+
# ValueError,
161+
# "min_weight_fraction_leaf == 0.6, must be <= 0.5",
162+
# ),
163+
# (
164+
# {"min_weight_fraction_leaf": "foo"},
165+
# TypeError,
166+
# "min_weight_fraction_leaf must be an instance of <class 'numbers.Real'>",
167+
# ),
168+
# ({"max_depth": -1}, ValueError, "max_depth == -1, must be >= 1"),
169+
# (
170+
# {"max_depth": 1.1},
171+
# TypeError,
172+
# "max_depth must be an instance of <class 'numbers.Integral'>",
173+
# ),
174+
# (
175+
# {"min_impurity_decrease": -1},
176+
# ValueError,
177+
# "min_impurity_decrease == -1, must be >= 0.0",
178+
# ),
179+
# (
180+
# {"min_impurity_decrease": "foo"},
181+
# TypeError,
182+
# "min_impurity_decrease must be an instance of <class 'numbers.Real'>",
183+
# ),
184+
({"subsample": 0.0}, ValueError, "subsample == 0.0, must be > 0."),
185+
({"subsample": 1.1}, ValueError, "subsample == 1.1, must be <= 1."),
186+
({"subsample": -0.1}, ValueError, "subsample == -0.1, must be > 0."),
187+
(
188+
{"subsample": "1"},
189+
TypeError,
190+
"subsample must be an instance of <class 'numbers.Real'>,",
191+
),
192+
193+
({"init": {}}, ValueError, "The init parameter must be an estimator or 'zero'"),
194+
({"max_features": 0}, ValueError, "max_features == 0, must be >= 1"),
195+
({"max_features": 1000}, ValueError, "max_features == 1000, must be <="),
196+
({"max_features": 0.0}, ValueError, "max_features == 0.0, must be > 0.0"),
197+
({"max_features": 1.1}, ValueError, "max_features == 1.1, must be <= 1.0"),
198+
({"max_features": "foobar"}, ValueError, "Invalid value for max_features."),
199+
# ({"ccp_alpha": -1.0}, ValueError, "ccp_alpha == -1.0, must be >= 0.0"),
200+
# (
201+
# {"ccp_alpha": "foo"},
202+
# TypeError,
203+
# "ccp_alpha must be an instance of <class 'numbers.Real'>",
204+
# ),
205+
({"verbose": -1}, ValueError, "verbose == -1, must be >= 0"),
206+
(
207+
{"verbose": "foo"},
208+
TypeError,
209+
"verbose must be an instance of",
210+
),
211+
# ({"max_leaf_nodes": 0}, ValueError, "max_leaf_nodes == 0, must be >= 2"),
212+
# (
213+
# {"max_leaf_nodes": 1.5},
214+
# TypeError,
215+
# "max_leaf_nodes must be an instance of <class 'numbers.Integral'>",
216+
# ),
217+
({"warm_start": "foo"}, TypeError, "warm_start must be an instance of"),
218+
(
219+
{"validation_fraction": 0.0},
220+
ValueError,
221+
"validation_fraction == 0.0, must be > 0.0",
222+
),
223+
(
224+
{"validation_fraction": 1.0},
225+
ValueError,
226+
"validation_fraction == 1.0, must be < 1.0",
227+
),
228+
(
229+
{"validation_fraction": "foo"},
230+
TypeError,
231+
"validation_fraction must be an instance of <class 'numbers.Real'>",
232+
),
233+
234+
({"n_iter_no_change": -1}, ValueError, "n_iter_no_change == -1, must be >= 1"),
235+
({"n_iter_no_change": 0}, ValueError, "n_iter_no_change == 0, must be >= 1"),
236+
(
237+
{"n_iter_no_change": 1.5},
238+
TypeError,
239+
"n_iter_no_change must be an instance of <class 'numbers.Integral'>,",
240+
),
241+
(
242+
{"n_iter_no_change": "invalid"},
243+
TypeError,
244+
"n_iter_no_change must be an instance of <class 'numbers.Integral'>,",
245+
),
246+
({"tol": 0.0}, ValueError, "tol == 0.0, must be > 0.0"),
247+
(
248+
{"tol": "foo"},
249+
TypeError,
250+
"tol must be an instance of <class 'numbers.Real'>,",
251+
),
252+
253+
254+
>>>>>>> Stashed changes
107255
],
108256
# Avoid long error messages in test names:
109257
# https://github.com/scikit-learn/scikit-learn/issues/21362

sklearn/ensemble/tests/test_voting.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,16 @@
3434
X_r, y_r = datasets.load_diabetes(return_X_y=True)
3535

3636

37+
def test_error():
38+
# Test that proper excetions are raise given invalid input
39+
voter = VotingClassifier(
40+
estimators=[("lr", LogisticRegression())], flatten_transform="foo"
41+
)
42+
err_msg = "flatten_transform must be an instance of"
43+
with pytest.raises(TypeError, match=err_msg):
44+
voter.fit(X_r, y_r)
45+
46+
3747
@pytest.mark.parametrize(
3848
"X, y, voter, learner",
3949
[
@@ -51,7 +61,7 @@
5161
def test_voting_estimators_param_validation(
5262
X, y, voter, learner, params, err_type, err_msg
5363
):
54-
# Test scalar parameters that are invalid
64+
# Test that proper excetions are raise given invalid input
5565
params.update(learner)
5666
est = voter(**params)
5767
with pytest.raises(err_type, match=err_msg):

0 commit comments

Comments
 (0)
0