10000 Implement estimator__sklearn_tags__ method · KhiopsML/khiops-python@cb319a2 · GitHub
[go: up one dir, main page]

Skip to content

Commit cb319a2

Browse files
Implement estimator__sklearn_tags__ method
This is an update to the estimator API which generates an error starting with scikit-learn 1.7. We keep `_more_tags` for backward compatibility and implement `__sklearn_tags__` with it. For details see scikit-learn/scikit-learn#28910
1 parent 34c6714 commit cb319a2

File tree

1 file changed

+19
-3
lines changed

1 file changed

+19
-3
lines changed

khiops/sklearn/estimators.py

Lines changed: 19 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -268,8 +268,24 @@ def __init__(
268268
# Make sklearn get_params happy
269269
self.internal_sort = internal_sort
270270

271+
def __sklearn_tags__(self):
272+
# Error if BaseEstimator does not support __sklearn_tags__
273+
if not hasattr(BaseEstimator, "__sklearn_tags__"):
274+
raise AttributeError("__sklearn_tags__ API unsupported.")
275+
276+
# Set the tags from _more_tags
277+
tags = super().__sklearn_tags__()
278+
for tag, tag_value in self._more_tags():
279+
if tag == "allow_nan":
280+
tags.input_tags.allow_nan = tag_value
281+
elif tag == "requires_y":
282+
tags.target_tag.required = tag_value
283+
elif tag == "preserves_dtype":
284+
tags.transformer_tags.preserves_dtype = tag_value
285+
return tags
286+
271287
def _more_tags(self):
272-
return {"allow_nan": True, "accept_large_sparse": False}
288+
return {"allow_nan": True}
273289

274290
def _undefine_estimator_attributes(self):
275291
"""Undefines all sklearn estimator attributes (ie. pass to "not fit" state)
@@ -1400,7 +1416,7 @@ def __init__(
14001416
)
14011417

14021418
def _more_tags(self):
1403-
return {"require_y": True}
1419+
return {"requires_y": True}
14041420

14051421
def _fit_check_dataset(self, ds):
14061422
super()._fit_check_dataset(ds)
@@ -2783,7 +2799,7 @@ def __init__(
27832799
self.keep_initial_variables = keep_initial_variables
27842800
self._khiops_model_prefix = "R_"
27852801

2786-
def more_tags(self):
2802+
def _more_tags(self):
27872803
return {"preserves_dtype": []}
27882804

27892805
def _categorical_transform_method(self):

0 commit comments

Comments
 (0)
0