[MRG] Fit params support for permutation_test_score and validation_curve · thomasjpfan/scikit-learn@8091faf · GitHub
[go: up one dir, main page]

Skip to content

Commit 8091faf

Browse files
authored
[MRG] Fit params support for permutation_test_score and validation_curve (scikit-learn#18527)
1 parent 1f217f3 commit 8091faf

File tree

3 files changed

+70
-10
lines changed

3 files changed

+70
-10
lines changed

doc/whats_new/v0.24.rst

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -450,6 +450,12 @@ Changelog
450450
samples between the train and test set on each fold.
451451
:pr:`13204` by :user:`Kyle Kosic <kykosic>`.
452452

453+
- |Enhancement| :func:`model_selection.permutation_test_score` and
454+
:func:`model_selection.validation_curve` now accept fit_params
455+
to pass additional estimator parameters.
456+
:pr:`18527` by :user:`Gaurav Dhingra <gxyd>`,
457+
:user:`Julien Jerphanion <jjerphan>` and :user:`Amanda Dsouza <amy12xx>`.
458+
453459
- |Enhancement| :func:`model_selection.cross_val_score`,
454460
:func:`model_selection.cross_validate`,
455461
:class:`model_selection.GridSearchCV`, and

sklearn/model_selection/_validation.py

Lines changed: 27 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1048,8 +1048,8 @@ def _check_is_permutation(indices, n_samples):
10481048
@_deprecate_positional_args
10491049
def permutation_test_score(estimator, X, y, *, groups=None, cv=None,
10501050
n_permutations=100, n_jobs=None, random_state=0,
1051-
verbose=0, scoring=None):
1052-
"""Evaluates the significance of a cross-validated score using permutations
1051+
verbose=0, scoring=None, fit_params=None):
1052+
"""Evaluate the significance of a cross-validated score with permutations
10531053
10541054
Permutes targets to generate 'randomized data' and compute the empirical
10551055
p-value against the null hypothesis that features and targets are
@@ -1129,6 +1129,11 @@ def permutation_test_score(estimator, X, y, *, groups=None, cv=None,
11291129
verbose : int, default=0
11301130
The verbosity level.
11311131
1132+
fit_params : dict, default=None
1133+
Parameters to pass to the fit method of the estimator.
1134+
1135+
.. versionadded:: 0.24
1136+
11321137
Returns
11331138
-------
11341139
score : float
@@ -1165,24 +1170,29 @@ def permutation_test_score(estimator, X, y, *, groups=None, cv=None,
11651170

11661171
# We clone the estimator to make sure that all the folds are
11671172
# independent, and that it is pickle-able.
1168-
score = _permutation_test_score(clone(estimator), X, y, groups, cv, scorer)
1173+
score = _permutation_test_score(clone(estimator), X, y, groups, cv, scorer,
1174+
fit_params=fit_params)
11691175
permutation_scores = Parallel(n_jobs=n_jobs, verbose=verbose)(
11701176
delayed(_permutation_test_score)(
11711177
clone(estimator), X, _shuffle(y, groups, random_state),
1172-
groups, cv, scorer)
1178+
groups, cv, scorer, fit_params=fit_params)
11731179
for _ in range(n_permutations))
11741180
permutation_scores = np.array(permutation_scores)
11751181
pvalue = (np.sum(permutation_scores >= score) + 1.0) / (n_permutations + 1)
11761182
return score, permutation_scores, pvalue
11771183

11781184

1179-
def _permutation_test_score(estimator, X, y, groups, cv, scorer):
1185+
def _permutation_test_score(estimator, X, y, groups, cv, scorer,
1186+
fit_params):
11801187
"""Auxiliary function for permutation_test_score"""
1188+
# Adjust length of sample weights
1189+
fit_params = fit_params if fit_params is not None else {}
11811190
avg_score = []
11821191
for train, test in cv.split(X, y, groups):
11831192
X_train, y_train = _safe_split(estimator, X, y, train)
11841193
X_test, y_test = _safe_split(estimator, X, y, test, train)
1185-
estimator.fit(X_train, y_train)
1194+
fit_params = _check_fit_params(X, fit_params, train)
1195+
estimator.fit(X_train, y_train, **fit_params)
11861196
avg_score.append(scorer(estimator, X_test, y_test))
11871197
return np.mean(avg_score)
11881198

@@ -1204,7 +1214,8 @@ def learning_curve(estimator, X, y, *, groups=None,
12041214
train_sizes=np.linspace(0.1, 1.0, 5), cv=None,
12051215
scoring=None, exploit_incremental_learning=False,
12061216
n_jobs=None, pre_dispatch="all", verbose=0, shuffle=False,
1207-
random_state=None, error_score=np.nan, return_times=False):
1217+
random_state=None, error_score=np.nan,
1218+
return_times=False):
12081219
"""Learning curve.
12091220
12101221
Determines cross-validated training and test scores for different training
@@ -1501,7 +1512,7 @@ def _incremental_fit_estimator(estimator, X, y, classes, train, test,
15011512
@_deprecate_positional_args
15021513
def validation_curve(estimator, X, y, *, param_name, param_range, groups=None,
15031514
cv=None, scoring=None, n_jobs=None, pre_dispatch="all",
1504-
verbose=0, error_score=np.nan):
1515+
verbose=0, error_score=np.nan, fit_params=None):
15051516
"""Validation curve.
15061517
15071518
Determine training and test scores for varying parameter values.
@@ -1577,6 +1588,11 @@ def validation_curve(estimator, X, y, *, param_name, param_range, groups=None,
15771588
verbose : int, default=0
15781589
Controls the verbosity: the higher, the more messages.
15791590
1591+
fit_params : dict, default=None
1592+
Parameters to pass to the fit method of the estimator.
1593+
1594+
.. versionadded:: 0.24
1595+
15801596
error_score : 'raise' or numeric, default=np.nan
15811597
Value to assign to the score if an error occurs in estimator fitting.
15821598
If set to 'raise', the error is raised.
@@ -1606,8 +1622,9 @@ def validation_curve(estimator, X, y, *, param_name, param_range, groups=None,
16061622
verbose=verbose)
16071623
results = parallel(delayed(_fit_and_score)(
16081624
clone(estimator), X, y, scorer, train, test, verbose,
1609-
parameters={param_name: v}, fit_params=None, return_train_score=True,
1610-
error_score=error_score)
1625+
parameters={param_name: v}, fit_params=fit_params,
1626+
return_train_score=True, error_score=error_score)
1627+
16111628
# NOTE do not change order of iteration to allow one time cv splitters
16121629
for train, test in cv.split(X, y, groups) for v in param_range)
16131630
n_params = len(param_range)

sklearn/model_selection/tests/test_validation.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -754,6 +754,23 @@ def test_permutation_test_score_allow_nans():
754754
permutation_test_score(p, X, y)
755755

756756

757+
def test_permutation_test_score_fit_params():
758+
X = np.arange(100).reshape(10, 10)
759+
y = np.array([0] * 5 + [1] * 5)
760+
clf = CheckingClassifier(expected_fit_params=['sample_weight'])
761+
762+
err_msg = r"Expected fit parameter\(s\) \['sample_weight'\] not seen."
763+
with pytest.raises(AssertionError, match=err_msg):
764+
permutation_test_score(clf, X, y)
765+
766+
err_msg = "Fit parameter sample_weight has length 1; expected"
767+
with pytest.raises(AssertionError, match=err_msg):
768+
permutation_test_score(clf, X, y,
769+
fit_params={'sample_weight': np.ones(1)})
770+
permutation_test_score(clf, X, y,
771+
fit_params={'sample_weight': np.ones(10)})
772+
773+
757774
def test_cross_val_score_allow_nans():
758775
# Check that cross_val_score allows input data with NaNs
759776
X = np.arange(200, dtype=np.float64).reshape(10, -1)
@@ -1298,6 +1315,26 @@ def test_validation_curve_cv_splits_consistency():
12981315
assert_array_almost_equal(np.array(scores3), np.array(scores1))
12991316

13001317

1318+
def test_validation_curve_fit_params():
1319+
X = np.arange(100).reshape(10, 10)
1320+
y = np.array([0] * 5 + [1] * 5)
1321+
clf = CheckingClassifier(expected_fit_params=['sample_weight'])
1322+
1323+
err_msg = r"Expected fit parameter\(s\) \['sample_weight'\] not seen."
1324+
with pytest.raises(AssertionError, match=err_msg):
1325+
validation_curve(clf, X, y, param_name='foo_param',
1326+
param_range=[1, 2, 3], error_score='raise')
1327+
1328+
err_msg = "Fit parameter sample_weight has length 1; expected"
1329+
with pytest.raises(AssertionError, match=err_msg):
1330+
validation_curve(clf, X, y, param_name='foo_param',
1331+
param_range=[1, 2, 3], error_score='raise',
1332+
fit_params={'sample_weight': np.ones(1)})
1333+
validation_curve(clf, X, y, param_name='foo_param',
1334+
param_range=[1, 2, 3], error_score='raise',
1335+
fit_params={'sample_weight': np.ones(10)})
1336+
1337+
13011338
def test_check_is_permutation():
13021339
rng = np.random.RandomState(0)
13031340
p = np.arange(100)

0 commit comments

Comments
 (0)
0