Use scikit-learn check_estimator to ensure compatibility of Nilearn estimators with scikit-learn #4538

Open
Remi-Gau opened this issue Sep 9, 2024 · 8 comments
Labels
Code quality This issue tackles code quality (code refactoring, PEP8...).

Comments

@Remi-Gau
Collaborator
Remi-Gau commented Sep 9, 2024

Raised by #4533.

See https://scikit-learn.org/dev/developers/develop.html#rolling-your-own-estimator

I quickly checked a couple of Nilearn estimators and they seem to fail; I will add code examples when I'm in front of a real computer.

It may be worth discussing whether we want 100% compliance or whether we should have our own version of check_estimator.

@bthirion
Member
bthirion commented Sep 9, 2024

If it's not a hassle, I'd rather keep compatibility with sklearn.

@Remi-Gau
Collaborator Author
Remi-Gau commented Sep 9, 2024

Yup me too.
Will list the errors we get in this issue so we can make a proper decision.

@Remi-Gau
Collaborator Author

One example that loops through all the checks:

```python
from sklearn.utils.estimator_checks import check_estimator

from nilearn.decoding.decoder import Decoder


def test_check_estimator():
    """Check compliance with sklearn estimators."""
    model = Decoder()

    # generate_only=True yields (estimator, check) pairs instead of running
    # them, so every check is attempted even after an earlier one fails.
    for estimator, check in check_estimator(estimator=model, generate_only=True):
        try:
            check(estimator)
        except Exception:
            print(f"FAIL: {check}")
        else:
            print(f"SUCCESS: {check}")
```
Output:

```
SUCCESS: functools.partial(<function check_no_attributes_set_in_init at 0x735a695ef240>, 'Decoder')
FAIL: functools.partial(<function check_estimators_dtypes at 0x735a695eca40>, 'Decoder')
FAIL: functools.partial(<function check_fit_score_takes_y at 0x735a695ec900>, 'Decoder')
FAIL: functools.partial(<function check_estimators_fit_returns_self at 0x735a695ee340>, 'Decoder')
FAIL: functools.partial(<function check_estimators_fit_returns_self at 0x735a695ee340>, 'Decoder', readonly_memmap=True)
FAIL: functools.partial(<function check_complex_data at 0x735a695e3420>, 'Decoder')
FAIL: functools.partial(<function check_dtype_object at 0x735a695e3380>, 'Decoder')
FAIL: functools.partial(<function check_estimators_empty_data_messages at 0x735a695ecc20>, 'Decoder')
FAIL: functools.partial(<function check_pipeline_consistency at 0x735a695ec7c0>, 'Decoder')
FAIL: functools.partial(<function check_estimators_nan_inf at 0x735a695ecd60>, 'Decoder')
FAIL: functools.partial(<function check_estimators_overwrite_params at 0x735a695ef100>, 'Decoder')
FAIL: functools.partial(<function check_estimator_sparse_array at 0x735a695e2b60>, 'Decoder')
FAIL: functools.partial(<function check_estimator_sparse_matrix at 0x735a695e2ac0>, 'Decoder')
FAIL: functools.partial(<function check_estimators_pickle at 0x735a695ecfe0>, 'Decoder')
FAIL: functools.partial(<function check_estimators_pickle at 0x735a695ecfe0>, 'Decoder', readonly_memmap=True)
SUCCESS: functools.partial(<function check_estimator_get_tags_default_keys at 0x735a695f4680>, 'Decoder')
FAIL: functools.partial(<function check_classifier_data_not_an_array at 0x735a695ef4c0>, 'Decoder')
FAIL: functools.partial(<function check_classifiers_one_label at 0x735a695ed760>, 'Decoder')
SUCCESS: functools.partial(<function check_classifiers_one_label_sample_weights at 0x735a695ed8a0>, 'Decoder')
FAIL: functools.partial(<function check_classifiers_classes at 0x735a695ee840>, 'Decoder')
SUCCESS: functools.partial(<function check_estimators_partial_fit_n_features at 0x735a695ed120>, 'Decoder')
FAIL: functools.partial(<function check_classifiers_train at 0x735a695ed9e0>, 'Decoder')
FAIL: functools.partial(<function check_classifiers_train at 0x735a695ed9e0>, 'Decoder', readonly_memmap=True)
FAIL: functools.partial(<function check_classifiers_train at 0x735a695ed9e0>, 'Decoder', readonly_memmap=True, X_dtype='float32')
FAIL: functools.partial(<function check_classifiers_regression_target at 0x735a695eff60>, 'Decoder')
FAIL: functools.partial(<function check_supervised_y_no_nan at 0x735a695e1bc0>, 'Decoder')
FAIL: functools.partial(<function check_supervised_y_2d at 0x735a695ee5c0>, 'Decoder')
FAIL: functools.partial(<function check_estimators_unfitted at 0x735a695ee480>, 'Decoder')
SUCCESS: functools.partial(<function check_non_transformer_estimators_n_iter at 0x735a695efa60>, 'Decoder')
SUCCESS: functools.partial(<function check_decision_proba_consistency at 0x735a695f40e0>, 'Decoder')
SUCCESS: functools.partial(<function check_parameters_default_constructible at 0x735a695ef7e0>, 'Decoder')
FAIL: functools.partial(<function check_methods_sample_order_invariance at 0x735a695e3ba0>, 'Decoder')
FAIL: functools.partial(<function check_methods_subset_invariance at 0x735a695e3a60>, 'Decoder')
FAIL: functools.partial(<function check_fit2d_1sample at 0x735a695e3d80>, 'Decoder')
FAIL: functools.partial(<function check_fit2d_1feature at 0x735a695e3f60>, 'Decoder')
SUCCESS: functools.partial(<function check_get_params_invariance at 0x735a695efce0>, 'Decoder')
SUCCESS: functools.partial(<function check_set_params at 0x735a695efe20>, 'Decoder')
FAIL: functools.partial(<function check_dict_unchanged at 0x735a695e3560>, 'Decoder')
FAIL: functools.partial(<function check_dont_overwrite_parameters at 0x735a695e3740>, 'Decoder')
FAIL: functools.partial(<function check_fit_idempotent at 0x735a695f42c0>, 'Decoder')
FAIL: functools.partial(<function check_fit_check_is_fitted at 0x735a695f4360>, 'Decoder')
FAIL: functools.partial(<function check_n_features_in at 0x735a695f4400>, 'Decoder')
FAIL: functools.partial(<function check_fit1d at 0x735a695ec0e0>, 'Decoder')
FAIL: functools.partial(<function check_fit2d_predict1d at 0x735a695e3880>, 'Decoder')
```
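The `functools.partial` reprs in this output are hard to scan, and `check_estimator(..., generate_only=True)` is deprecated from scikit-learn 1.6 onwards in favour of `estimator_checks_generator`. Below is a minimal sketch of the same loop that should work on either side of that change and collects the results into a `{check name: passed}` dict. The helpers `iter_checks`, `check_name` and `run_checks` are hypothetical names for illustration; they are not part of Nilearn or scikit-learn.

```python
from unittest import SkipTest

from sklearn.utils.estimator_checks import check_estimator

try:  # scikit-learn >= 1.6
    from sklearn.utils.estimator_checks import estimator_checks_generator
except ImportError:  # older releases only offer check_estimator(generate_only=True)
    estimator_checks_generator = None


def iter_checks(estimator):
    """Yield (estimator, check) pairs without running them."""
    if estimator_checks_generator is not None:
        return estimator_checks_generator(estimator)
    return check_estimator(estimator, generate_only=True)


def check_name(check):
    """Turn "functools.partial(<function check_fit1d at 0x...>, 'Decoder')" into "check_fit1d".

    Keyword variants (readonly_memmap=True, X_dtype='float32') are appended
    so that they stay distinct entries in the results.
    """
    func = getattr(check, "func", check)
    keywords = getattr(check, "keywords", None) or {}
    extra = ", ".join(f"{key}={value!r}" for key, value in sorted(keywords.items()))
    return f"{func.__name__}({extra})" if extra else func.__name__


def run_checks(estimator):
    """Run every generated check and return {check name: passed}."""
    results = {}
    for est, check in iter_checks(estimator):
        name = check_name(check)
        try:
            check(est)
        except SkipTest:
            continue  # skipped checks are left out of the results
        except Exception as exc:  # any other error counts as a failure
            results[name] = False
            print(f"FAIL: {name}: {type(exc).__name__}: {exc}")
        else:
            results[name] = True
    return results
```

Usage would be `results = run_checks(Decoder())`, with `Decoder` imported as in the snippet above. Printing the exception type and message next to each failure should make it quicker to see which failures share a root cause.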

From a quick inspection, most of the failing checks seem to fail because scikit-learn's checks call fit() with NumPy arrays, whereas our estimators mostly expect a list of NIfTI or surface images.
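A minimal sketch of that diagnosis, without Nilearn itself: `ImageListClassifier` is a hypothetical stand-in that, like most Nilearn estimators, accepts a list of images and rejects bare NumPy arrays. Running scikit-learn's `check_estimators_dtypes(name, estimator)` directly on it should fail at the very first `fit()` call, because the check only ever builds NumPy arrays. `fails_numpy_check` is also a hypothetical helper.

```python
import numpy as np
from sklearn.base import BaseEstimator, ClassifierMixin
from sklearn.utils.estimator_checks import check_estimators_dtypes


class ImageListClassifier(ClassifierMixin, BaseEstimator):
    """Hypothetical stand-in for a Nilearn estimator (not a Nilearn class).

    It expects a list of images (here: any list or tuple) and rejects
    NumPy arrays, which is the situation described above.
    """

    def fit(self, imgs, y):
        if not isinstance(imgs, (list, tuple)):
            raise TypeError(f"Expected a list of images, got {type(imgs).__name__}.")
        self.classes_ = np.unique(y)
        return self

    def predict(self, imgs):
        # Trivial rule, enough for illustration: always predict the first class.
        return np.full(len(imgs), self.classes_[0])


def fails_numpy_check(estimator):
    """Return the error raised by scikit-learn's dtype check, or None if it passes."""
    try:
        check_estimators_dtypes(type(estimator).__name__, estimator)
    except Exception as exc:
        return exc
    return None
```

`fails_numpy_check(ImageListClassifier())` should return the `TypeError` raised inside `fit`, even though the estimator works fine on the list input it is designed for.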

@Remi-Gau
Collaborator Author

See for example:

https://github.com/scikit-learn/scikit-learn/blob/8a2d5ffa0512873a75da608eb14832253979ec44/sklearn/utils/estimator_checks.py#L1674

I don't think we want to start supporting NumPy arrays, right?
So for some checks we may decide that failing is an accepted departure from scikit-learn's expectations, because we don't want to support that type of input data.

At a minimum, we can run the checks and make sure that those that currently pass keep passing, to avoid regressions.
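A minimal sketch of such a regression guard, assuming the results have already been collected into a `{check name: passed}` dict (for instance by adapting the loop in the snippet above to store results instead of printing them). `find_regressions`, `assert_no_regressions` and `DECODER_KNOWN_PASSING` are hypothetical names; the baseline set is simply copied from the SUCCESS lines of the Decoder output in this thread.

```python
def find_regressions(results, known_passing):
    """Return checks recorded as passing that now fail.

    results: {check name: passed}.
    known_passing: set of check names that passed at some earlier point.
    Checks missing from `results` are ignored, since scikit-learn may add,
    remove or rename checks between releases.
    """
    return sorted(
        name for name in known_passing if name in results and not results[name]
    )


# Hypothetical baseline for Decoder, taken from the SUCCESS lines above.
DECODER_KNOWN_PASSING = {
    "check_no_attributes_set_in_init",
    "check_estimator_get_tags_default_keys",
    "check_classifiers_one_label_sample_weights",
    "check_estimators_partial_fit_n_features",
    "check_non_transformer_estimators_n_iter",
    "check_decision_proba_consistency",
    "check_parameters_default_constructible",
    "check_get_params_invariance",
    "check_set_params",
}


def assert_no_regressions(results, known_passing=DECODER_KNOWN_PASSING):
    """Fail loudly if any previously passing check now fails."""
    regressions = find_regressions(results, known_passing)
    assert not regressions, f"Checks that used to pass now fail: {regressions}"
```

Keeping the baseline as an explicit set in the test file means that fixing a check is a one-line change (add its name), and the diff shows exactly which guarantee was added.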

Then there are checks that currently fail but that we still want to run; for those we may have to write our own adapted versions.
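One possible shape for an adapted check, sketched under assumptions: keep scikit-learn's `check(name, estimator)` convention and add an input factory so each estimator is fed the data type it actually supports. This is a loose counterpart of `check_estimators_fit_returns_self`; `check_fit_returns_self_with_inputs`, `ListOnlyClassifier` and `make_list_inputs` are hypothetical names, and strings stand in for real images.

```python
import numpy as np
from sklearn.base import BaseEstimator, ClassifierMixin, clone


def check_fit_returns_self_with_inputs(name, estimator_orig, make_inputs):
    """Adapted take on scikit-learn's check_estimators_fit_returns_self.

    `make_inputs` (hypothetical) returns (X, y) in whatever form the
    estimator accepts, e.g. a list of NIfTI or surface images, instead of
    the NumPy arrays scikit-learn's own check would build.
    """
    estimator = clone(estimator_orig)
    X, y = make_inputs()
    returned = estimator.fit(X, y)
    assert returned is estimator, f"{name}.fit() should return self"


class ListOnlyClassifier(ClassifierMixin, BaseEstimator):
    """Hypothetical estimator that, like Nilearn ones, refuses NumPy arrays."""

    def fit(self, imgs, y):
        if not isinstance(imgs, list):
            raise TypeError("expected a list of images")
        self.classes_ = np.unique(y)
        return self


def make_list_inputs():
    # Placeholder strings stand in for image objects or file paths.
    return ["img_0", "img_1", "img_2", "img_3"], [0, 1, 0, 1]
```

Binding the factory with `functools.partial(check_fit_returns_self_with_inputs, make_inputs=make_list_inputs)` gives a callable with the same `(name, estimator)` signature as scikit-learn's checks, so adapted and upstream checks could share one test loop.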

@bthirion
Member

Indeed, we don't want to enforce NumPy arrays.
So the long-term solution is to have our own checks.

@Remi-Gau
Collaborator Author

Yeah, I am afraid so.

This may take some time, but since there are several checks, we can probably add them one by one through several PRs.

@Remi-Gau
Collaborator Author

Follow-up on this:

  • identify which checks fail or pass for each of our estimators
  • identify which checks should always fail for each estimator (for example, because it does not fit NumPy arrays)
  • identify which currently failing checks we care about fixing (for example, pipeline compatibility) and possibly fix them
  • see whether we want to (and can) create some of our own checks
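The second and third items could be tracked with a per-estimator table of accepted failures plus a small triage step. A minimal sketch, assuming results are available as a `{check name: passed}` dict: `EXPECTED_FAILURES` and `triage` are hypothetical names, the check names come from the Decoder output earlier in this thread, and the reasons are assumptions based on the "NumPy arrays vs. images" diagnosis, not a verified triage.

```python
# Hypothetical per-estimator list of checks we accept to fail, with a reason.
# Reasons are assumptions, not the result of an actual investigation.
EXPECTED_FAILURES = {
    "Decoder": {
        "check_estimators_dtypes": "assumed: fit() expects images, not NumPy arrays",
        "check_fit1d": "assumed: fit() expects images, not NumPy arrays",
        "check_fit2d_1sample": "assumed: fit() expects images, not NumPy arrays",
        "check_fit2d_1feature": "assumed: fit() expects images, not NumPy arrays",
    },
}


def triage(estimator_name, results, expected_failures=EXPECTED_FAILURES):
    """Sort {check name: passed} results into buckets for follow-up.

    - "expected_failures": failing for a documented reason
    - "unexpected_failures": failing and not documented; candidates to fix
    - "unexpected_passes": documented as failing but now passing; the entry
      can be removed from EXPECTED_FAILURES
    """
    expected = expected_failures.get(estimator_name, {})
    report = {"expected_failures": [], "unexpected_failures": [], "unexpected_passes": []}
    for name, passed in sorted(results.items()):
        if name in expected:
            bucket = "unexpected_passes" if passed else "expected_failures"
            report[bucket].append(name)
        elif not passed:
            report["unexpected_failures"].append(name)
    return report
```

Flagging unexpected passes matters as much as flagging failures: otherwise the table of accepted failures silently goes stale as estimators get fixed.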

@Remi-Gau Remi-Gau added the Code quality This issue tackles code quality (code refactoring, PEP8...). label Sep 19, 2024
@Remi-Gau
Collaborator Author

It may be worth double-checking whether these tests are related to this issue: https://github.com/nilearn/nilearn/blob/main/nilearn/decoding/tests/test_sklearn_compatibility.py

Or whether those tests are still needed.
