8000 FIX parametrizing estimator checks allows custom estimators that impl… · charlesjhill/scikit-learn@1210b06 · GitHub
[go: up one dir, main page]

Skip to content

Commit 1210b06

Browse files
randolf-scholzadrinjalalijeremiedbb
authored
FIX parametrizing estimator checks allows custom estimators that implements __call__ (scikit-learn#28860)
Co-authored-by: Adrin Jalali <adrin.jalali@gmail.com> Co-authored-by: jeremiedbb <jeremiedbb@yahoo.fr>
1 parent 0c72e54 commit 1210b06

File tree

2 files changed

+16
-6
lines changed

2 files changed

+16
-6
lines changed

sklearn/tests/test_common.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
import pytest
2121

2222
import sklearn
23+
from sklearn.base import BaseEstimator
2324
from sklearn.cluster import (
2425
OPTICS,
2526
AffinityPropagation,
@@ -103,6 +104,16 @@ def _sample_func(x, y=1):
103104
pass
104105

105106

107+
class CallableEstimator(BaseEstimator):
108+
"""Dummy development stub for an estimator.
109+
110+
This is to make sure a callable estimator passes common tests.
111+
"""
112+
113+
def __call__(self):
114+
pass # pragma: nocover
115+
116+
106117
@pytest.mark.parametrize(
107118
"val, expected",
108119
[
@@ -122,6 +133,7 @@ def _sample_func(x, y=1):
122133
"solver='newton-cg',warm_start=True)"
123134
),
124135
),
136+
(CallableEstimator(), "CallableEstimator()"),
125137
],
126138
)
127139
def test_get_check_estimator_ids(val, expected):

sklearn/utils/estimator_checks.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
from contextlib import nullcontext
1010
from copy import deepcopy
1111
from functools import partial, wraps
12-
from inspect import signature
12+
from inspect import isfunction, signature
1313
from numbers import Integral, Real
1414

1515
import joblib
@@ -405,13 +405,11 @@ def _get_check_estimator_ids(obj):
405405
--------
406406
check_estimator
407407
"""
408-
if callable(obj):
409-
if not isinstance(obj, partial):
410-
return obj.__name__
411-
408+
if isfunction(obj):
409+
return obj.__name__
410+
if isinstance(obj, partial):
412411
if not obj.keywords:
413412
return obj.func.__name__
414-
415413
kwstring = ",".join(["{}={}".format(k, v) for k, v in obj.keywords.items()])
416414
return "{}({})".format(obj.func.__name__, kwstring)
417415
if hasattr(obj, "get_params"):

0 commit comments

Comments
 (0)
0