10000 FIX changed_only=True with kwargs parameters (#17205) · scikit-learn/scikit-learn@1d40392 · GitHub
[go: up one dir, main page]

Skip to content

Commit 1d40392

Browse files
NicolasHugadrinjalali
authored andcommitted
FIX changed_only=True with kwargs parameters (#17205)
1 parent d3f5254 commit 1d40392

File tree

3 files changed

+49
-4
lines changed

3 files changed

+49
-4
lines changed

doc/whats_new/v0.23.rst

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,16 @@ Changelog
1616
......................
1717

1818
- |Fix| Fixed a bug in :class:`cluster.KMeans` where the sample weights
19-
provided by the user was modified in place. :pr:`17204` by
19+
provided by the user were modified in place. :pr:`17204` by
2020
:user:`Jeremie du Boisberranger <jeremiedbb>`.
2121

22+
Miscellaneous
23+
.............
24+
25+
- |Fix| Fixed a bug in the `repr` of third-party estimators that use a
26+
`**kwargs` parameter in their constructor, when `changed_only` is True
27+
which is now the default. :pr:`17205` by `Nicolas Hug`_.
28+
2229
.. _changes_0_23:
2330

2431
Version 0.23.0

sklearn/utils/_pprint.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -94,9 +94,11 @@ def _changed_params(estimator):
9494
estimator.__init__)
9595
init_params = signature(init_func).parameters
9696
init_params = {name: param.default for name, param in init_params.items()}
97+
9798
for k, v in params.items():
98-
if (repr(v) != repr(init_params[k]) and
99-
not (is_scalar_nan(init_params[k]) and is_scalar_nan(v))):
99+
if (k not in init_params or ( # happens if k is part of a **kwargs
100+
repr(v) != repr(init_params[k]) and
101+
not (is_scalar_nan(init_params[k]) and is_scalar_nan(v)))):
100102
filtered_params[k] = v
101103
return filtered_params
102104

sklearn/utils/tests/test_pprint.py

Lines changed: 37 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from sklearn.pipeline import make_pipeline
99
from sklearn.base import BaseEstimator, TransformerMixin
1010
from sklearn.feature_selection import SelectKBest, chi2
11-
from sklearn import set_config
11+
from sklearn import set_config, config_context
1212

1313

1414
# Ignore flake8 (lots of line too long issues)
@@ -538,3 +538,39 @@ def test_builtin_prettyprinter():
538538
# Used to be a bug
539539

540540
PrettyPrinter().pprint(LogisticRegression())
541+
542+
543+
def test_kwargs_in_init():
544+
# Make sure the changed_only=True mode is OK when an argument is passed as
545+
# kwargs.
546+
# Non-regression test for
547+
# https://github.com/scikit-learn/scikit-learn/issues/17206
548+
549+
class WithKWargs(BaseEstimator):
550+
# Estimator with a kwargs argument. These need to hack around
551+
# set_params and get_params. Here we mimic what LightGBM does.
552+
def __init__(self, a='willchange', b='unchanged', **kwargs):
553+
self.a = a
554+
self.b = b
555+
self._other_params = {}
556+
self.set_params(**kwargs)
557+
558+
def get_params(self, deep=True):
559+
params = super().get_params(deep=deep)
560+
params.update(self._other_params)
561+
return params
562+
563+
def set_params(self, **params):
564+
for key, value in params.items():
565+
setattr(self, key, value)
566+
self._other_params[key] = value
567+
return self
568+
569+
est = WithKWargs(a='something', c='abcd', d=None)
570+
571+
expected = "WithKWargs(a='something', c='abcd', d=None)"
572+
assert expected == est.__repr__()
573+
574+
with config_context(print_changed_only=False):
575+
expected = "WithKWargs(a='something', b='unchanged', c='abcd', d=None)"
576+
assert expected == est.__repr__()

0 commit comments

Comments
 (0)
0