8000 MAINT Param validation: remove n_estimators checks from validate_esti… · scikit-learn/scikit-learn@4114161 · GitHub
[go: up one dir, main page]

Skip to content

Commit 4114161

Browse files
naoise-hMicky774jeremiedbb
authored
MAINT Param validation: remove n_estimators checks from validate_estimator (#24224)
Co-authored-by: Meekail Zain <34613774+Micky774@users.noreply.github.com> Co-authored-by: jeremie du boisberranger <jeremiedbb@yahoo.fr>
1 parent 3cd0760 commit 4114161

File tree

2 files changed

+1
-38
lines changed

2 files changed

+1
-38
lines changed

sklearn/ensemble/_base.py

Lines changed: 1 addition & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44
# License: BSD 3 clause
55

66
from abc import ABCMeta, abstractmethod
7-
import numbers
87
from typing import List
98

109
import numpy as np
@@ -128,24 +127,10 @@ def __init__(self, base_estimator, *, n_estimators=10, estimator_params=tuple())
128127
# self.estimators_ needs to be filled by the derived classes in fit.
129128

130129
def _validate_estimator(self, default=None):
131-
"""Check the estimator and the n_estimator attribute.
130+
"""Check the base estimator.
132131
133132
Sets the base_estimator_` attributes.
134133
"""
135-
if not isinstance(self.n_estimators, numbers.Integral):
136-
raise ValueError(
137-
"n_estimators must be an integer, got {0}.".format(
138-
type(self.n_estimators)
139-
)
140-
)
141-
142-
if self.n_estimators <= 0:
143-
raise ValueError(
144-
"n_estimators must be greater than zero, got {0}.".format(
145-
self.n_estimators
146-
)
147-
)
148-
149134
if self.base_estimator is not None:
150135
self.base_estimator_ = self.base_estimator
151136
else:

sklearn/ensemble/tests/test_base.py

Lines changed: 0 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@
66
# License: BSD 3 clause
77

88
import numpy as np
9-
import pytest
109

1110
from sklearn.datasets import load_iris
1211
from sklearn.ensemble import BaggingClassifier
@@ -49,27 +48,6 @@ def test_base():
4948
np_int_ensemble.fit(iris.data, iris.target)
5049

5150

52-
def test_base_zero_n_estimators():
53-
# Check that instantiating a BaseEnsemble with n_estimators<=0 raises
54-
# a ValueError.
55-
ensemble = BaggingClassifier(base_estimator=Perceptron(), n_estimators=0)
56-
iris = load_iris()
57-
with pytest.raises(ValueError):
58-
ensemble.fit(iris.data, iris.target)
59-
60-
61-
def test_base_not_int_n_estimators():
62-
# Check that instantiating a BaseEnsemble with a string as n_estimators
63-
# raises a ValueError demanding n_estimators to be supplied as an integer.
64-
string_ensemble = BaggingClassifier(base_estimator=Perceptron(), n_estimators="3")
65-
iris = load_iris()
66-
with pytest.raises(ValueError):
67-
string_ensemble.fit(iris.data, iris.target)
68-
float_ensemble = BaggingClassifier(base_estimator=Perceptron(), n_estimators=3.0)
69-
with pytest.raises(ValueError):
70-
float_ensemble.fit(iris.data, iris.target)
71-
72-
7351
def test_set_random_states():
7452
# Linear Discriminant Analysis doesn't have random state: smoke test
7553
_set_random_states(LinearDiscriminantAnalysis(), random_state=17)

0 commit comments

Comments
 (0)
0