8000 Clean up · scikit-learn/scikit-learn@c4d6278 · GitHub
[go: up one dir, main page]

Skip to content

Commit c4d6278

Browse files
Clean up
1 parent b217697 commit c4d6278

File tree

2 files changed

+7
-9
lines changed
Filter options

2 files changed

+7
-9
lines changed

sklearn/cross_validation.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1091,7 +1091,6 @@ def cross_val_score(estimator, X, y=None, scoring=None, cv=None, n_jobs=1,
10911091
scoring=scoring)
10921092
# We clone the estimator to make sure that all the folds are
10931093
# independent, and that it is pickle-able.
1094-
fit_params = fit_params if fit_params is not None else {}
10951094
parallel = Parallel(n_jobs=n_jobs, verbose=verbose,
10961095
pre_dispatch=pre_dispatch)
10971096
scores = parallel(
@@ -1104,15 +1103,15 @@ def cross_val_score(estimator, X, y=None, scoring=None, cv=None, n_jobs=1,
11041103
def _cross_val_score(estimator, X, y, scorer, train, test, verbose,
11051104
fit_params):
11061105
"""Inner loop for cross validation"""
1107-
# TODO replace with grid_search.fit_grid_point()
11081106
n_samples = _num_samples(X)
1107+
fit_params = fit_params if fit_params is not None else {}
11091108
fit_params = dict([(k, np.asarray(v)[train] # TODO why is this necessary?
11101109
if hasattr(v, '__len__') and len(v) == n_samples else v)
11111110
for k, v in fit_params.items()])
11121111

11131112
X_train, y_train = _split(estimator, X, y, train)
11141113
X_test, y_test = _split(estimator, X, y, test, train)
1115-
estimator.fit(X_train, y_train, **fit_params)
1114+
_fit(estimator.fit, X_train, y_train, **fit_params)
11161115
score = _score(estimator, X_test, y_test, scorer)
11171116

11181117
if verbose > 1:

sklearn/grid_search.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -243,17 +243,16 @@ def fit_grid_point(X, y, base_estimator, parameters, train, test, scorer,
243243
X_train, y_train = _split(estimator, X, y, train)
244244
X_test, y_test = _split(estimator, X, y, test, train)
245245
_fit(estimator.fit, X_train, y_train, **fit_params)
246-
this_score = _score(estimator, X_test, y_test, scorer)
246+
score = _score(estimator, X_test, y_test, scorer)
247247

248248
if verbose > 2:
249-
msg += ", score=%f" % this_score
249+
msg += ", score=%f" % score
250250
if verbose > 1:
251-
end_msg = "%s -%s" % (msg,
252-
logger.short_format_time(time.time() -
253-
start_time))
251+
end_msg = "%s -%s" % (msg, logger.short_format_time(time.time() -
252+
start_time))
254253
print("[GridSearchCV] %s %s" % ((64 - len(end_msg)) * '.', end_msg))
255254

256-
return this_score, parameters, _num_samples(X_test)
255+
return score, parameters, _num_samples(X_test)
257256

258257

259258
def _check_param_grid(param_grid):

0 commit comments

Comments
 (0)
0