@@ -1048,8 +1048,8 @@ def _check_is_permutation(indices, n_samples):
1048
1048
@_deprecate_positional_args
1049
1049
def permutation_test_score (estimator , X , y , * , groups = None , cv = None ,
1050
1050
n_permutations = 100 , n_jobs = None , random_state = 0 ,
1051
- verbose = 0 , scoring = None ):
1052
- """Evaluates the significance of a cross-validated score using permutations
1051
+ verbose = 0 , scoring = None , fit_params = None ):
1052
+ """Evaluate the significance of a cross-validated score with permutations
1053
1053
1054
1054
Permutes targets to generate 'randomized data' and compute the empirical
1055
1055
p-value against the null hypothesis that features and targets are
@@ -1129,6 +1129,11 @@ def permutation_test_score(estimator, X, y, *, groups=None, cv=None,
1129
1129
verbose : int, default=0
1130
1130
The verbosity level.
1131
1131
1132
+ fit_params : dict, default=None
1133
+ Parameters to pass to the fit method of the estimator.
1134
+
1135
+ .. versionadded:: 0.24
1136
+
1132
1137
Returns
1133
1138
-------
1134
1139
score : float
@@ -1165,24 +1170,29 @@ def permutation_test_score(estimator, X, y, *, groups=None, cv=None,
1165
1170
1166
1171
# We clone the estimator to make sure that all the folds are
1167
1172
# independent, and that it is pickle-able.
1168
- score = _permutation_test_score (clone (estimator ), X , y , groups , cv , scorer )
1173
+ score = _permutation_test_score (clone (estimator ), X , y , groups , cv , scorer ,
1174
+ fit_params = fit_params )
1169
1175
permutation_scores = Parallel (n_jobs = n_jobs , verbose = verbose )(
1170
1176
delayed (_permutation_test_score )(
1171
1177
clone (estimator ), X , _shuffle (y , groups , random_state ),
1172
- groups , cv , scorer )
1178
+ groups , cv , scorer , fit_params = fit_params )
1173
1179
for _ in range (n_permutations ))
1174
1180
permutation_scores = np .array (permutation_scores )
1175
1181
pvalue = (np .sum (permutation_scores >= score ) + 1.0 ) / (n_permutations + 1 )
1176
1182
return score , permutation_scores , pvalue
1177
1183
1178
1184
1179
def _permutation_test_score(estimator, X, y, groups, cv, scorer,
                            fit_params):
    """Auxiliary function for permutation_test_score.

    Fits ``estimator`` on each CV train split (forwarding ``fit_params``)
    and returns the mean test-set score across splits.

    Parameters
    ----------
    estimator : estimator object
        A (cloned) estimator implementing ``fit``.
    X : array-like
        Data to fit.
    y : array-like
        Target values.
    groups : array-like or None
        Group labels forwarded to ``cv.split``.
    cv : cross-validation generator
        Provides ``split(X, y, groups)`` yielding (train, test) indices.
    scorer : callable
        ``scorer(estimator, X_test, y_test)`` returning a float.
    fit_params : dict or None
        Parameters to pass to ``estimator.fit``; indexable entries
        (e.g. ``sample_weight``) are sliced per fold.

    Returns
    -------
    float
        Mean score over the CV test folds.
    """
    # Normalize None to an empty dict so **-expansion below always works.
    fit_params = fit_params if fit_params is not None else {}
    avg_score = []
    for train, test in cv.split(X, y, groups):
        X_train, y_train = _safe_split(estimator, X, y, train)
        X_test, y_test = _safe_split(estimator, X, y, test, train)
        # Slice indexable fit params to the current train fold. Bind the
        # result to a *new* name: rebinding ``fit_params`` here would make
        # every fold after the first re-slice an already-sliced dict with
        # that fold's absolute indices, yielding wrong sample weights or
        # an IndexError.
        fold_fit_params = _check_fit_params(X, fit_params, train)
        estimator.fit(X_train, y_train, **fold_fit_params)
        avg_score.append(scorer(estimator, X_test, y_test))
    return np.mean(avg_score)
1188
1198
@@ -1204,7 +1214,8 @@ def learning_curve(estimator, X, y, *, groups=None,
1204
1214
train_sizes = np .linspace (0.1 , 1.0 , 5 ), cv = None ,
1205
1215
scoring = None , exploit_incremental_learning = False ,
1206
1216
n_jobs = None , pre_dispatch = "all" , verbose = 0 , shuffle = False ,
1207
- random_state = None , error_score = np .nan , return_times = False ):
1217
+ random_state = None , error_score = np .nan ,
1218
+ return_times = False ):
1208
1219
"""Learning curve.
1209
1220
1210
1221
Determines cross-validated training and test scores for different training
@@ -1501,7 +1512,7 @@ def _incremental_fit_estimator(estimator, X, y, classes, train, test,
1501
1512
@_deprecate_positional_args
1502
1513
def validation_curve (estimator , X , y , * , param_name , param_range , groups = None ,
1503
1514
cv = None , scoring = None , n_jobs = None , pre_dispatch = "all" ,
1504
- verbose = 0 , error_score = np .nan ):
1515
+ verbose = 0 , error_score = np .nan , fit_params = None ):
1505
1516
"""Validation curve.
1506
1517
1507
1518
Determine training and test scores for varying parameter values.
@@ -1577,6 +1588,11 @@ def validation_curve(estimator, X, y, *, param_name, param_range, groups=None,
1577
1588
verbose : int, default=0
1578
1589
Controls the verbosity: the higher, the more messages.
1579
1590
1591
+ fit_params : dict, default=None
1592
+ Parameters to pass to the fit method of the estimator.
1593
+
1594
+ .. versionadded:: 0.24
1595
+
1580
1596
error_score : 'raise' or numeric, default=np.nan
1581
1597
Value to assign to the score if an error occurs in estimator fitting.
1582
1598
If set to 'raise', the error is raised.
@@ -1606,8 +1622,9 @@ def validation_curve(estimator, X, y, *, param_name, param_range, groups=None,
1606
1622
verbose = verbose )
1607
1623
results = parallel (delayed (_fit_and_score )(
1608
1624
clone (estimator ), X , y , scorer , train , test , verbose ,
1609
- parameters = {param_name : v }, fit_params = None , return_train_score = True ,
1610
- error_score = error_score )
1625
+ parameters = {param_name : v }, fit_params = fit_params ,
1626
+ return_train_score = True , error_score = error_score )
1627
+
1611
1628
# NOTE do not change order of iteration to allow one time cv splitters
1612
1629
for train , test in cv .split (X , y , groups ) for v in param_range )
1613
1630
n_params = len (param_range )
0 commit comments