Estimator tag overwriting and update in _get_tags · Issue #14044 · scikit-learn/scikit-learn · GitHub

Estimator tag overwriting and update in _get_tags #14044
Closed
@glemaitre

Description

I am updating some code in imbalanced-learn to use the estimator tags. I was able to add a new entry to _DEFAULT_TAGS and reuse the implementation of _safe_tags.

I have the following use case:

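_DEFAULT_TAGS = {'sample_indices': False}  # library-wide default

class BaseClass:
    ...
    def _more_tags(self):
        # every subclass of BaseClass should report sample_indices=True
        return {'sample_indices': True}

class SpecialClass(BaseClass):
    ...
    def _more_tags(self):
        # start from the parent's tags and overwrite only this entry
        tags = super()._more_tags()
        tags['sample_indices'] = False
        return tags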

For a specific reason, all estimators inheriting from BaseClass should have the sample_indices tag set to True, except one class for which I would like to overwrite the tag. Here, I use super() as the trick that I think could solve the issue described below.

Because sample_indices is overwritten, _get_tags fails due to what is currently considered an inconsistent update of the tags dictionary:

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
<ipython-input-31-5558b1aebfa8> in <module>
----> 1 _safe_tags(xx, 'sample_indices')

~/Documents/packages/scikit-learn/sklearn/utils/estimator_checks.py in _safe_tags(estimator, key)
     68     if hasattr(estimator, "_get_tags"):
     69         if key is not None:
---> 70             return estimator._get_tags().get(key, _DEFAULT_TAGS[key])
     71         tags = estimator._get_tags()
     72         return {key: tags.get(key, _DEFAULT_TAGS[key])

~/Documents/packages/scikit-learn/sklearn/base.py in _get_tags(self)
    320         if hasattr(self, '_more_tags'):
    321             more_tags = self._more_tags()
--> 322             collected_tags = _update_if_consistent(collected_tags, more_tags)
    323         tags = _DEFAULT_TAGS.copy()
    324         tags.update(collected_tags)

~/Documents/packages/scikit-learn/sklearn/base.py in _update_if_consistent(dict1, dict2)
    132         if dict1[key] != dict2[key]:
    133             raise TypeError("Inconsistent values for tag {}: {} != {}".format(
--> 134                 key, dict1[key], dict2[key]
    135             ))
    136     dict1.update(dict2)

TypeError: Inconsistent values for tag sample_indices: True != False
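
The failure can be reproduced without scikit-learn. The sketch below is a simplified reconstruction of the tag collection based on the _get_tags and _update_if_consistent frames above: tags returned by each parent class in the MRO are merged first, then the instance's own _more_tags() result, and any key whose value changes raises TypeError. The loop over parent classes is inferred from the diff context and TagMixin is a hypothetical stand-in for BaseEstimator, so treat this as an approximation rather than a verbatim copy of sklearn/base.py.

```python
import inspect

_DEFAULT_TAGS = {'sample_indices': False}


def _update_if_consistent(dict1, dict2):
    # Refuse to change a tag that was already collected with another value.
    common_keys = set(dict1.keys()).intersection(dict2.keys())
    for key in common_keys:
        if dict1[key] != dict2[key]:
            raise TypeError("Inconsistent values for tag {}: {} != {}".format(
                key, dict1[key], dict2[key]))
    dict1.update(dict2)
    return dict1


class TagMixin:
    # Hypothetical, simplified stand-in for BaseEstimator._get_tags.
    def _get_tags(self):
        collected_tags = {}
        # 1) tags contributed by each parent class found in the MRO
        for base_class in inspect.getmro(self.__class__):
            if (hasattr(base_class, '_more_tags')
                    and base_class is not self.__class__):
                collected_tags = _update_if_consistent(
                    collected_tags, base_class._more_tags(self))
        # 2) tags of the concrete class itself
        if hasattr(self, '_more_tags'):
            collected_tags = _update_if_consistent(
                collected_tags, self._more_tags())
        tags = _DEFAULT_TAGS.copy()
        tags.update(collected_tags)
        return tags


class BaseClass(TagMixin):
    def _more_tags(self):
        return {'sample_indices': True}


class SpecialClass(BaseClass):
    def _more_tags(self):
        tags = super()._more_tags()
        tags['sample_indices'] = False
        return tags


print(BaseClass()._get_tags())  # {'sample_indices': True}
try:
    SpecialClass()._get_tags()
except TypeError as exc:
    print(exc)  # Inconsistent values for tag sample_indices: True != False
```

BaseClass's tags are collected once from the MRO loop, then SpecialClass's own _more_tags returns a different value for the same key, which is exactly the inconsistency the check rejects.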

Without a call to super(), I find this rule meaningful, and raising an error is a good thing. However, I think we could lift the rule when super() is called in self._more_tags: by calling super(), one should be aware of overwriting the base-class tags.

Thus, by introspecting self._more_tags and checking that super() is called (e.g. using 'super' in inspect.getclosurevars(self._more_tags).builtins), we could still allow the tags to be updated.
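
The introspection check can be tried in isolation. inspect.getclosurevars unwraps a bound method and sorts the global names referenced by its code; since super is not defined in the module globals, it is resolved through the builtins mapping. The helper calls_super is a hypothetical name, and the classes mirror the use case above. Note that this heuristic only detects that the name super is referenced somewhere in the method, not that the parent's tags are actually reused.

```python
import inspect


class BaseClass:
    def _more_tags(self):
        return {'sample_indices': True}


class SpecialClass(BaseClass):
    def _more_tags(self):
        tags = super()._more_tags()  # explicit opt-in to overwriting
        tags['sample_indices'] = False
        return tags


def calls_super(method):
    # True if the name 'super' is looked up in builtins by the method's code.
    return 'super' in inspect.getclosurevars(method).builtins


print(calls_super(SpecialClass()._more_tags))  # True
print(calls_super(BaseClass()._more_tags))     # False
```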

@amueller @rth @jnothman Do you think this use case and solution make sense? I still have the option to add a _more_tags method to each class, but that means a lot of duplicated code.

So the changes would be something like the following:

-def _update_if_consistent(dict1, dict2):
+def _update_if_consistent(dict1, dict2, force=False):
     common_keys = set(dict1.keys()).intersection(dict2.keys())
-    for key in common_keys:
-        if dict1[key] != dict2[key]:
-            raise TypeError("Inconsistent values for tag {}: {} != {}".format(
-                key, dict1[key], dict2[key]
-            ))
+    if not force:
+        for key in common_keys:
+            if dict1[key] != dict2[key]:
+                raise TypeError("Inconsistent values for tag {}: {} != {}"
+                                .format(key, dict1[key], dict2[key]))
     dict1.update(dict2)
     return dict1
 
@@ -319,7 +319,10 @@ class BaseEstimator:
                                                        more_tags)
         if hasattr(self, '_more_tags'):
             more_tags = self._more_tags()
-            collected_tags = _update_if_consistent(collected_tags, more_tags)
+            force = 'super' in inspect.getclosurevars(self._more_tags).builtins
+            collected_tags = _update_if_consistent(
+                collected_tags, more_tags, force=force
+            )
         tags = _DEFAULT_TAGS.copy()
         tags.update(collected_tags)
         return tags
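
To see the proposed behaviour end to end, here is a sketch that folds the patch into a simplified, self-contained tag collector. The collector is a reconstruction from the traceback and the diff context rather than the actual sklearn/base.py, TagMixin stands in for BaseEstimator, and SilentClass is a hypothetical subclass that returns a conflicting value without calling super().

```python
import inspect

_DEFAULT_TAGS = {'sample_indices': False}


def _update_if_consistent(dict1, dict2, force=False):
    common_keys = set(dict1.keys()).intersection(dict2.keys())
    # With force=True, later values simply overwrite earlier ones.
    if not force:
        for key in common_keys:
            if dict1[key] != dict2[key]:
                raise TypeError("Inconsistent values for tag {}: {} != {}"
                                .format(key, dict1[key], dict2[key]))
    dict1.update(dict2)
    return dict1


class TagMixin:
    # Hypothetical stand-in for BaseEstimator._get_tags with the patch applied.
    def _get_tags(self):
        collected_tags = {}
        for base_class in inspect.getmro(self.__class__):
            if (hasattr(base_class, '_more_tags')
                    and base_class is not self.__class__):
                collected_tags = _update_if_consistent(
                    collected_tags, base_class._more_tags(self))
        if hasattr(self, '_more_tags'):
            # Overwriting is allowed only if _more_tags references super().
            force = 'super' in inspect.getclosurevars(self._more_tags).builtins
            collected_tags = _update_if_consistent(
                collected_tags, self._more_tags(), force=force)
        tags = _DEFAULT_TAGS.copy()
        tags.update(collected_tags)
        return tags


class BaseClass(TagMixin):
    def _more_tags(self):
        return {'sample_indices': True}


class SpecialClass(BaseClass):
    def _more_tags(self):
        tags = super()._more_tags()
        tags['sample_indices'] = False
        return tags


class SilentClass(BaseClass):
    # Hypothetical: conflicting value without calling super().
    def _more_tags(self):
        return {'sample_indices': False}


print(SpecialClass()._get_tags())  # {'sample_indices': False}
try:
    SilentClass()._get_tags()
except TypeError as exc:
    print(exc)  # Inconsistent values for tag sample_indices: True != False
```

One limitation shows up in this sketch: force only applies to the instance's own _more_tags. A further subclass of SpecialClass that does not override _more_tags would still raise, because the conflict then arises inside the MRO loop, where SpecialClass's and BaseClass's tags are merged without force.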
