Description
I am updating some code in imbalanced-learn to use the estimator tags. I was able to add a new entry to _DEFAULT_TAGS and reuse the implementation of _safe_tags.
I have the following use case:
_DEFAULT_TAGS = {'sample_indices': False}

class BaseClass:
    ...
    def _more_tags(self):
        return {'sample_indices': True}

class SpecialClass(BaseClass):
    ...
    def _more_tags(self):
        tags = super()._more_tags()
        tags['sample_indices'] = False
        return tags
In my case, all estimators inheriting from BaseClass should have the sample_indices tag set to True, except for one class where I would like to overwrite the tag. Here, I call super() in _more_tags, which is the trick that I think could solve the following issue.
Because we are overwriting sample_indices, _get_tags fails due to what is currently considered an inconsistent update of the tags dictionary:
---------------------------------------------------------------------------
TypeError Traceback (most recent call last)
<ipython-input-31-5558b1aebfa8> in <module>
----> 1 _safe_tags(xx, 'sample_indices')
~/Documents/packages/scikit-learn/sklearn/utils/estimator_checks.py in _safe_tags(estimator, key)
68 if hasattr(estimator, "_get_tags"):
69 if key is not None:
---> 70 return estimator._get_tags().get(key, _DEFAULT_TAGS[key])
71 tags = estimator._get_tags()
72 return {key: tags.get(key, _DEFAULT_TAGS[key])
~/Documents/packages/scikit-learn/sklearn/base.py in _get_tags(self)
320 if hasattr(self, '_more_tags'):
321 more_tags = self._more_tags()
--> 322 collected_tags = _update_if_consistent(collected_tags, more_tags)
323 tags = _DEFAULT_TAGS.copy()
324 tags.update(collected_tags)
~/Documents/packages/scikit-learn/sklearn/base.py in _update_if_consistent(dict1, dict2)
132 if dict1[key] != dict2[key]:
133 raise TypeError("Inconsistent values for tag {}: {} != {}".format(
--> 134 key, dict1[key], dict2[key]
135 ))
136 dict1.update(dict2)
TypeError: Inconsistent values for tag sample_indices: True != False
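The failure can be reproduced without scikit-learn. The sketch below is a simplified stand-in, not the library's code: _update_if_consistent follows the unpatched version shown in the diff in this issue, while collect_tags is a hypothetical helper that assumes _get_tags first merges the _more_tags of every parent class in the MRO and then the instance's own, which the traceback only partly shows.

```python
import inspect


def _update_if_consistent(dict1, dict2):
    # Consistency rule from the traceback: keys present in both must agree.
    common_keys = set(dict1.keys()).intersection(dict2.keys())
    for key in common_keys:
        if dict1[key] != dict2[key]:
            raise TypeError("Inconsistent values for tag {}: {} != {}".format(
                key, dict1[key], dict2[key]))
    dict1.update(dict2)
    return dict1


class BaseClass:
    def _more_tags(self):
        return {'sample_indices': True}


class SpecialClass(BaseClass):
    def _more_tags(self):
        tags = super()._more_tags()
        tags['sample_indices'] = False
        return tags


def collect_tags(estimator):
    # Hypothetical stand-in for _get_tags (the MRO walk is an assumption).
    klass = type(estimator)
    collected_tags = {}
    for base_class in inspect.getmro(klass):
        if hasattr(base_class, '_more_tags') and base_class is not klass:
            collected_tags = _update_if_consistent(
                collected_tags, base_class._more_tags(estimator))
    if hasattr(estimator, '_more_tags'):
        collected_tags = _update_if_consistent(
            collected_tags, estimator._more_tags())
    return collected_tags


try:
    collect_tags(SpecialClass())
    error_message = None
except TypeError as exc:
    # BaseClass contributed True, SpecialClass then contributes False.
    error_message = str(exc)
```

In this sketch the parent's value is already in collected_tags by the time the child's own _more_tags is merged, so any override is seen as a conflict, whether or not it went through super().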
Without the call to super(), I find this rule quite meaningful and raising the error a good thing. However, I think that we could lift this rule in the case where we call super() in self._more_tags. By calling super(), one should be aware of overwriting the base-class default tag.
Thus, by introspecting self._more_tags and checking that super() is called (e.g. using 'super' in inspect.getclosurevars(self._more_tags).builtins), we could still allow the tags to be updated.
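As a rough illustration of that introspection, the sketch below wraps the check in a hypothetical helper, references_super (the name is mine, not part of scikit-learn). It relies only on inspect.getclosurevars from the standard library, which accepts bound methods and reports the builtin names a function refers to.

```python
import inspect


class BaseClass:
    def _more_tags(self):
        return {'sample_indices': True}


class SpecialClass(BaseClass):
    def _more_tags(self):
        tags = super()._more_tags()
        tags['sample_indices'] = False
        return tags


def references_super(method):
    # Hypothetical helper: True if the method body refers to the builtin
    # name `super`. It does not prove that super() is actually executed.
    return 'super' in inspect.getclosurevars(method).builtins


base_uses_super = references_super(BaseClass()._more_tags)
special_uses_super = references_super(SpecialClass()._more_tags)
```

The check is a heuristic: it looks at names referenced by the method itself, so an explicit call such as BaseClass._more_tags(self) would not be detected.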
@amueller @rth @jnothman Do you think that this use case and this solution make sense? I still have the option of adding a _more_tags to each class, but that means a lot of duplicated code.
So the changes would be something like the following:
-def _update_if_consistent(dict1, dict2):
+def _update_if_consistent(dict1, dict2, force=False):
     common_keys = set(dict1.keys()).intersection(dict2.keys())
-    for key in common_keys:
-        if dict1[key] != dict2[key]:
-            raise TypeError("Inconsistent values for tag {}: {} != {}".format(
-                key, dict1[key], dict2[key]
-            ))
+    if not force:
+        for key in common_keys:
+            if dict1[key] != dict2[key]:
+                raise TypeError("Inconsistent values for tag {}: {} != {}"
+                                .format(key, dict1[key], dict2[key]))
     dict1.update(dict2)
     return dict1
@@ -319,7 +319,10 @@ class BaseEstimator:
                                                        more_tags)
         if hasattr(self, '_more_tags'):
             more_tags = self._more_tags()
-            collected_tags = _update_if_consistent(collected_tags, more_tags)
+            force = 'super' in inspect.getclosurevars(self._more_tags).builtins
+            collected_tags = _update_if_consistent(
+                collected_tags, more_tags, force=force
+            )
         tags = _DEFAULT_TAGS.copy()
         tags.update(collected_tags)
         return tags
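To see the intended behavior end to end, here is a self-contained sketch of the proposal. It is not the scikit-learn implementation: _get_tags is a simplified stand-in placed directly on BaseClass, the MRO walk is an assumption about the part of the method the diff does not show, and ConflictingClass is a hypothetical class added only to show that overrides without super() are still rejected.

```python
import inspect

_DEFAULT_TAGS = {'sample_indices': False}


def _update_if_consistent(dict1, dict2, force=False):
    common_keys = set(dict1.keys()).intersection(dict2.keys())
    if not force:
        for key in common_keys:
            if dict1[key] != dict2[key]:
                raise TypeError("Inconsistent values for tag {}: {} != {}"
                                .format(key, dict1[key], dict2[key]))
    dict1.update(dict2)
    return dict1


class BaseClass:
    def _more_tags(self):
        return {'sample_indices': True}

    def _get_tags(self):
        # Simplified stand-in; the parent-class loop is an assumption.
        klass = type(self)
        collected_tags = {}
        for base_class in inspect.getmro(klass):
            if hasattr(base_class, '_more_tags') and base_class is not klass:
                collected_tags = _update_if_consistent(
                    collected_tags, base_class._more_tags(self))
        # Proposed change: skip the check when _more_tags refers to super.
        force = 'super' in inspect.getclosurevars(self._more_tags).builtins
        collected_tags = _update_if_consistent(
            collected_tags, self._more_tags(), force=force)
        tags = _DEFAULT_TAGS.copy()
        tags.update(collected_tags)
        return tags


class SpecialClass(BaseClass):
    def _more_tags(self):
        tags = super()._more_tags()
        tags['sample_indices'] = False
        return tags


class ConflictingClass(BaseClass):
    # Hypothetical: overrides the tag without going through super().
    def _more_tags(self):
        return {'sample_indices': False}


special_tags = SpecialClass()._get_tags()
base_tags = BaseClass()._get_tags()
```

In this sketch, as in the diff, force is applied only to the instance's own _more_tags. Under the assumed MRO walk, a subclass of SpecialClass that does not define _more_tags itself would still raise inside the parent-class loop, so that case may need separate handling.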