how can cv folds can be more than number of groups in cross_validate? · Issue #13972 · scikit-learn/scikit-learn · GitHub


how can cv folds can be more than number of groups in cross_validate? #13972


Closed
omarcr opened this issue May 29, 2019 · 10 comments

Comments

@omarcr
omarcr commented May 29, 2019

In the following code:

```python
from sklearn import datasets
from sklearn.model_selection import cross_validate
from sklearn.svm import LinearSVC
import numpy as np

digits = datasets.load_digits()

X = digits.data[:500]
y = digits.target[:500]

# np.unique returns the unique values first, then their counts
unique, counts = np.unique(y, return_counts=True)

# Two equal-sized groups: first half of the samples is group 0, second half group 1
g1 = np.repeat(0, len(y) // 2)
g2 = np.repeat(1, len(y) // 2)

g = np.concatenate((g1, g2))

linear = LinearSVC()

cv_results = cross_validate(linear, X, y, cv=5)
print(sorted(cv_results.keys()))
print(cv_results)

print('cv_results')
print(cv_results['test_score'])

print('cross-validate')
# Only 2 groups are defined, yet cv=5 still produces 5 folds
scores = cross_validate(linear, X, y, cv=5,
                        scoring='f1_weighted',
                        return_train_score=True, groups=g)
print(scores)
```

Shouldn't it report that the number of groups should be equal to the number of CV folds?

@jnothman
Member
jnothman commented May 29, 2019 via email

@omarcr
Author
omarcr commented May 29, 2019

You mean something like cv=GroupKFold(n_splits=3), or what type of object?
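As a sketch of passing a group-aware splitter as cv for the snippet in the original post: the group vector g there has only two distinct labels, so this example assumes GroupKFold(n_splits=2) (asking for 3 splits with only 2 groups would raise a ValueError). The loop checks that no group ends up on both the train and test side of a split.

```python
import numpy as np
from sklearn import datasets
from sklearn.model_selection import GroupKFold, cross_validate
from sklearn.svm import LinearSVC

digits = datasets.load_digits()
X, y = digits.data[:500], digits.target[:500]

# Two groups, as in the original snippet: first half 0, second half 1
g = np.repeat([0, 1], len(y) // 2)

cv = GroupKFold(n_splits=2)

# Whole groups are held out, so train and test never share a group
for train_idx, test_idx in cv.split(X, y, groups=g):
    assert set(g[train_idx]).isdisjoint(g[test_idx])

scores = cross_validate(LinearSVC(), X, y, cv=cv, groups=g,
                        scoring="f1_weighted", return_train_score=True)
print(len(scores["test_score"]))  # 2: one score per group-based split
```

Here the number of scores follows the splitter object, not an integer cv.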

@jnothman
Member
jnothman commented May 29, 2019 via email

@omarcr
Author
omarcr commented May 29, 2019

I think this should be fixed in the library too: an exception should be raised if groups are passed but no splitter object is given as the cv argument.

@amueller
Member

@omarcr I'm not sure that's easy to do in a very consistent way, and I'm also not sure it's desirable. We could raise an error any time groups is passed to a cv object that doesn't use it but that would be a pretty big change in behavior and potentially break many people's code.
If you change your CV object from a grouped one to a not-grouped one I feel like you shouldn't need to pass different data/groups to make the code work.
One option I could see is that if cv is a number we also validate groups. In a sense it's reasonable to expect the groups to be used if you pass cv=5. If you pass cv=KFold(n_splits=5) I think it's reasonable to ignore groups. But it's also a bit strange if these two cases have different behavior.
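A minimal sketch of the "validate groups when cv is a number" option, under the assumption that validating means rejecting the combination outright. check_int_cv_against_groups is a hypothetical helper written for illustration; it is not part of scikit-learn. Splitter objects pass through unchecked, so KFold(n_splits=5) can still ignore groups.

```python
import numbers

import numpy as np


def check_int_cv_against_groups(cv, groups):
    """Hypothetical helper, not a scikit-learn API.

    Raises if groups are passed together with an integer cv, since an
    integer cv resolves to KFold or StratifiedKFold, neither of which
    consults groups. Splitter objects are let through unchecked.
    """
    if groups is None or not isinstance(cv, numbers.Integral):
        return None
    n_groups = len(np.unique(groups))
    raise ValueError(
        f"groups were passed with cv={cv}, but an integer cv ignores groups "
        f"({n_groups} distinct groups found). Pass a group-aware splitter "
        "such as GroupKFold instead."
    )


g = np.repeat([0, 1], 250)  # two groups, as in the original snippet
try:
    check_int_cv_against_groups(5, g)
except ValueError as exc:
    print(exc)
```

Raising on every integer-cv-plus-groups call is the strictest reading; a softer variant could only warn, which would avoid breaking existing code.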

@omarcr
Author
omarcr commented May 29, 2019

@amueller The way I understand the method is the following:

If, for example, I have 3 groups of labels, then there should be 3 folds. After cross-validation, each of the 3 models built will have an independent group for testing.

If, for example, I don't declare the groups, then the cross_validate function will just arrange the data by index to construct the 3 folds.

Therefore, if I pass cv=KFold(n_splits=5), I need to declare the group labels of the data samples so that cross_validate knows how to select the data for the folds.

Is this the way the method is implemented?

@amueller
Member

Groups are only used if you explicitly pass a cross-validation strategy that uses groups, such as GroupKFold or LeaveOneGroupOut.
Whatever you pass to cv completely determines the behavior: some values of cv require groups to be specified, and all others ignore them.
If you pass cv=5 for classification it will use stratified 5-fold cross-validation (so not actually sort by index, though we might be using a stable stratification strategy?).
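How an integer cv is resolved can be inspected with sklearn.model_selection.check_cv, and LeaveOneGroupOut gives the "one fold per group" behavior described earlier in the thread. The toy y and groups arrays below are made up for illustration.

```python
import numpy as np
from sklearn.model_selection import LeaveOneGroupOut, StratifiedKFold, check_cv

y = np.array([0, 1] * 6)          # toy binary labels, 12 samples
groups = np.repeat([0, 1, 2], 4)  # three groups of four samples each

# With classifier=True and a binary/multiclass y, an integer cv
# resolves to StratifiedKFold, which never looks at groups
cv = check_cv(5, y, classifier=True)
print(type(cv).__name__)  # StratifiedKFold

# LeaveOneGroupOut yields exactly one split per distinct group
logo = LeaveOneGroupOut()
print(logo.get_n_splits(groups=groups))  # 3
for train_idx, test_idx in logo.split(np.zeros((len(y), 1)), y, groups):
    print(np.unique(groups[test_idx]))  # a single group per test fold
```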

@jnothman
Copy link
Member
jnothman commented May 29, 2019 via email

@omarcr
Author
omarcr commented May 30, 2019

Yes, I agree. GroupKFold(n_splits=n) already has a routine checking that the number of splits is not greater than the number of groups: ValueError: Cannot have number of splits n_splits=4 greater than the number of groups: 3. Maybe that can be reused here.
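The existing GroupKFold check can be reproduced with a toy set of three groups and n_splits=4; the arrays are illustrative, and the exact error text may differ between releases.

```python
import numpy as np
from sklearn.model_selection import GroupKFold

X = np.zeros((6, 1))
groups = np.array([0, 0, 1, 1, 2, 2])  # three distinct groups

# split() is a generator, so the check fires when it is iterated,
# not when GroupKFold(n_splits=4) is constructed
try:
    list(GroupKFold(n_splits=4).split(X, groups=groups))
except ValueError as exc:
    print(exc)

# With n_splits equal to the number of groups, each group is held out once
print(len(list(GroupKFold(n_splits=3).split(X, groups=groups))))  # 3
```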

@lucyleeow
Member

I think this can be closed. With #28210 you now get a UserWarning when groups are passed to a splitter that doesn't support them, and you get an error if a group splitter is asked for more folds than there are groups.
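A quick way to see whether an installed release emits the warning mentioned in that comment is to pass groups to KFold directly and record warnings. This is a sketch: on releases without the change from #28210, the recorded list should simply be empty, and the split count is the same either way.

```python
import warnings

import numpy as np
from sklearn.model_selection import KFold

X = np.zeros((10, 1))
groups = np.repeat([0, 1], 5)

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    # KFold ignores groups; newer releases are expected to warn here
    splits = list(KFold(n_splits=5).split(X, groups=groups))

print(len(splits))  # 5
print([str(w.message) for w in caught if issubclass(w.category, UserWarning)])
```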

5 participants