Test __array_function__ not called in non-estimator API · Issue #15865 · scikit-learn/scikit-learn · GitHub

Test __array_function__ not called in non-estimator API #15865

Open
jnothman opened this issue Dec 11, 2019 · 6 comments

@jnothman
Member

See #14702, where this was fixed in the common estimator tests, but not in functions such as cross_validate and permutation_importance.

@shivamgargsya
Contributor

@jnothman can I work on this?

@jnothman
Member Author

Sure, give it a shot

@alexshacked
Contributor
alexshacked commented Aug 28, 2020

Hi @jnothman.
I tested cross_validate and permutation_importance using:

import numpy as np


class _NotAnArray:
    """An object that is convertible to an array.

    Parameters
    ----------
    data : array-like
        The data.
    """

    def __init__(self, data):
        self.data = np.asarray(data)

    def __array__(self, dtype=None):
        # Explicit conversion (np.asarray / np.array) goes through here.
        return self.data

    def __array_function__(self, func, types, args, kwargs):
        # NumPy dispatches overridable functions here when this object is
        # passed directly; only may_share_memory is tolerated.
        if func.__name__ == "may_share_memory":
            return True
        raise TypeError("Don't want to call array_function {}!".format(
            func.__name__))

They both look OK. cross_validate doesn't reach __array_function__ at all, and the permutation_importance flow reaches __array_function__ only from may_share_memory, which is allowed.
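For readers unfamiliar with the protocol, here is a minimal NumPy-only sketch of why the wrapper works as a tripwire. The class name GuardedArray and the helper triggers_guard are hypothetical stand-ins written for this illustration (they are not scikit-learn utilities), and the extra copy argument on __array__ is an assumption added so the sketch also runs quietly on NumPy 2.

```python
import numpy as np


class GuardedArray:
    """Hypothetical stand-in for _NotAnArray, for illustration only."""

    def __init__(self, data):
        self.data = np.asarray(data)

    def __array__(self, dtype=None, copy=None):
        # Explicit conversion path used by np.asarray / np.array.
        arr = self.data if dtype is None else self.data.astype(dtype)
        return arr.copy() if copy else arr

    def __array_function__(self, func, types, args, kwargs):
        # NumPy routes overridable functions here when an argument
        # defines this hook; only may_share_memory is tolerated.
        if func.__name__ == "may_share_memory":
            return True
        raise TypeError(f"Don't want to call array_function {func.__name__}!")


def triggers_guard(func, *args):
    """Return True if calling func(*args) trips the guard's TypeError."""
    try:
        func(*args)
    except TypeError:
        return True
    return False


wrapped = GuardedArray([[1, 2], [3, 4]])

# Converting first is the "safe" path: only __array__ is used.
converted = np.asarray(wrapped)

# Calling a NumPy function directly on the wrapper trips the guard,
# while the same call on the converted ndarray does not.
print(triggers_guard(np.sum, wrapped))
print(triggers_guard(np.sum, converted))
```

Code that converts its inputs with np.asarray (or check_array) before doing any NumPy work never hits the guard, which is the property the issue asks to test.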

I only did basic tests just to get a feel of the lay of the land:

import numpy as np
from sklearn.inspection import permutation_importance
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import GridSearchCV, cross_validate

# Wrap the data so that any disallowed NumPy call on it raises TypeError.
X = np.array([[1, 1], [1, 2], [1, 3], [1, 4],
              [2, 1], [2, 2], [2, 3], [2, 4],
              [3, 1], [3, 2], [3, 3], [3, 4]])
X = _NotAnArray(X)
y = _NotAnArray([1, 1, 1, 1, 2, 2, 2, 2, 1, 1, 2, 2])
estimator = LogisticRegression()

# Nested grid search inside cross_validate, run across two workers.
grid = GridSearchCV(estimator, param_grid={'C': [1, 10]})
cross_validate(grid, X, y, n_jobs=2)

# Fit once, then compute permutation importances on the wrapped data.
estimator.fit(X, y)
rng = np.random.RandomState(42)
result = permutation_importance(estimator, X, y, n_repeats=5,
                                random_state=rng, n_jobs=1)

What kind of deliverable did you have in mind? Unit tests on the functions in model_selection and inspection using _NotAnArray?
Maybe something similar to check_estimator in estimator_checks.py? BTW, I saw _NotAnArray is removed in 0.24. Is there another util used to validate that "__array_function__" is not called?

@shivamgargsya are you OK with me looking into this? If you are already working on it I can find something else :-)

@alexshacked
Contributor
alexshacked commented Aug 29, 2020

After looking through the sklearn code base, I believe there is no shared testing infrastructure for cross_validate or permutation_importance comparable to check_estimator for estimators.
So the tests that "__array_function__" is not called will have to go:
into model_selection/tests/test_validation.py for cross_validate
and
into inspection/tests/test_permutation_importance.py for permutation_importance.

I will open a PR that will have 2 stages:
a. Validate "__array_function__" is not called for cross_validate and permutation_importance
b. Identify more functions that need to be validated and add regression tests for them.

@jnothman
Member Author
jnothman commented Aug 29, 2020 via email

alexshacked added a commit to alexshacked/scikit-learn that referenced this issue Aug 29, 2020
alexshacked added a commit to alexshacked/scikit-learn that referenced this issue Aug 29, 2020
@alexshacked
Contributor

Opened [WIP] PR #18292
