TST Relax init parameter checks for immutable objects (#17936) · simonamaggio/scikit-learn@86101d7 · GitHub
[go: up one dir, main page]

Skip to content
Sign in

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Appearance settings

Commit 86101d7

Browse files
authored
TST Relax init parameter checks for immutable objects (scikit-learn#17936)
1 parent 5af6561 commit 86101d7

File tree

2 files changed

+62
-9
lines changed

2 files changed

+62
-9
lines changed

sklearn/utils/estimator_checks.py

+33-9
Original file line number | Diff line number | Diff line change
@@ -2661,26 +2661,50 @@ def param_filter(p):
26612661
assert init_param.default != init_param.empty, (
26622662
"parameter %s for %s has no default value"
26632663
% (init_param.name, type(estimator).__name__))
2664-
if type(init_param.default) is type:
2665-
assert init_param.default in [np.float64, np.int64]
2666-
else:
2667-
assert (type(init_param.default) in
2668-
[str, int, float, bool, tuple, type(None),
2669-
np.float64, types.FunctionType, joblib.Memory])
2664+
allowed_types = {
2665+
str,
2666+
int,
2667+
float,
2668+
bool,
2669+
tuple,
2670+
type(None),
2671+
type,
2672+
types.FunctionType,
2673+
joblib.Memory,
2674+
}
2675+
# Any numpy numeric such as np.int32.
2676+
allowed_types.update(np.core.numerictypes.allTypes.values())
2677+
assert type(init_param.default) in allowed_types, (
2678+
f"Parameter '{init_param.name}' of estimator "
2679+
f"'{Estimator.__name__}' is of type "
2680+
f"{type(init_param.default).__name__} which is not "
2681+
f"allowed. All init parameters have to be immutable to "
2682+
f"make cloning possible. Therefore we restrict the set of "
2683+
f"legal types to "
2684+
f"{set(type.__name__ for type in allowed_types)}."
2685+
)
26702686
if init_param.name not in params.keys():
26712687
# deprecated parameter, not in get_params
2672-
assert init_param.default is None
2688+
assert init_param.default is None, (
2689+
f"Estimator parameter '{init_param.name}' of estimator "
2690+
f"'{Estimator.__name__}' is not returned by get_params. "
2691+
f"If it is deprecated, set its default value to None."
2692+
)
26732693
continue
26742694

26752695
param_value = params[init_param.name]
26762696
if isinstance(param_value, np.ndarray):
26772697
assert_array_equal(param_value, init_param.default)
26782698
else:
2699+
failure_text = (
2700+
f"Parameter {init_param.name} was mutated on init. All "
2701+
f"parameters must be stored unchanged."
2702+
)
26792703
if is_scalar_nan(param_value):
26802704
# Allow default parameters to be set to np.nan
2681-
assert param_value is init_param.default, init_param.name
2705+
assert param_value is init_param.default, failure_text
26822706
else:
2683-
assert param_value == init_param.default, init_param.name
2707+
assert param_value == init_param.default, failure_text
26842708

26852709

26862710
def _enforce_estimator_tags_y(estimator, y):

sklearn/utils/tests/test_estimator_checks.py

+29
Original file line number | Diff line number | Diff line change
@@ -113,6 +113,26 @@ def fit(self, X, y=None):
113113
return self
114114

115115

116+
class HasMutableParameters(BaseEstimator):
117+
def __init__(self, p=object()):
118+
self.p = p
119+
120+
def fit(self, X, y=None):
121+
X, y = self._validate_data(X, y)
122+
return self
123+
124+
125+
class HasImmutableParameters(BaseEstimator):
126+
# Note that `object` here is the class itself, not an instance, and is thus immutable.
127+
def __init__(self, p=42, q=np.int32(42), r=object):
128+
self.p = p
129+
self.q = q
130+
self.r = r
131+
132+
def fit(self, X, y=None):
133+
X, y = self._validate_data(X, y)
134+
return self
135+
116136
class ModifiesValueInsteadOfRaisingError(BaseEstimator):
117137
def __init__(self, p=0):
118138
self.p = p
@@ -381,6 +401,15 @@ def test_check_estimator():
381401
assert_raises_regex(TypeError, msg, check_estimator, object)
382402
msg = "object has no attribute '_get_tags'"
383403
assert_raises_regex(AttributeError, msg, check_estimator, object())
404+
msg = (
405+
"Parameter 'p' of estimator 'HasMutableParameters' is of type "
406+
"object which is not allowed"
407+
)
408+
# check that the "default_constructible" test checks for mutable parameters
409+
check_estimator(HasImmutableParameters()) # should pass
410+
assert_raises_regex(
411+
AssertionError, msg, check_estimator, HasMutableParameters()
412+
)
384413
# check that values returned by get_params match set_params
385414
msg = "get_params result does not match what was passed to set_params"
386415
assert_raises_regex(AssertionError, msg, check_estimator,

0 commit comments

Comments
 (0)
0