addressed comments · mne-tools/mne-python@0bf0626 · GitHub
[go: up one dir, main page]

Skip to content

Commit 0bf0626

Browse files
committed
addressed comments
1 parent c74064d commit 0bf0626

File tree

2 files changed

+18
-22
lines changed

2 files changed

+18
-22
lines changed

mne/decoding/search_light.py

Lines changed: 15 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -34,23 +34,11 @@ def __repr__(self):
3434
return repr_str + '>'
3535

3636
def __init__(self, base_estimator, scoring=None, n_jobs=1):
37-
from sklearn.metrics import make_scorer, get_scorer
38-
3937
_check_estimator(base_estimator)
4038
self.base_estimator = base_estimator
4139
self.n_jobs = n_jobs
4240
self.scoring = scoring
4341

44-
# If scoring is None (default), the predictions are internally
45-
# generated by estimator.score(). Else, we must first get the
46-
# predictions based on the scorer.
47-
if not isinstance(self.scoring, str):
48-
self.scoring = (make_scorer(self.scoring) if self.scoring is
49-
not None else self.scoring)
50-
51-
elif self.scoring is not None:
52-
self.scoring = get_scorer(self.scoring)
53-
5442
if not isinstance(self.n_jobs, int):
5543
raise ValueError('n_jobs must be int, got %s' % n_jobs)
5644

@@ -239,18 +227,29 @@ def score(self, X, y):
239227
score : array, shape (n_samples, n_estimators)
240228
Score for each estimator / data slice couple.
241229
"""
230+
from sklearn.metrics import make_scorer, get_scorer
242231
self._check_Xy(X)
243232
if X.shape[-1] != len(self.estimators_):
244233
raise ValueError('The number of estimators does not match '
245234
'X.shape[-1]')
246235

236+
# If scoring is None (default), the predictions are internally
237+
# generated by estimator.score(). Else, we must first get the
238+
# predictions based on the scorer.
239+
if not isinstance(self.scoring, str):
240+
scoring_ = (make_scorer(self.scoring) if self.scoring is
241+
not None else self.scoring)
242+
243+
elif self.scoring is not None:
244+
scoring_ = get_scorer(self.scoring)
245+
247246
# For predictions/transforms the parallelization is across the data and
248247
# not across the estimators to avoid memory load.
249248
parallel, p_func, n_jobs = parallel_func(_sl_score, self.n_jobs)
250249
n_jobs = min(n_jobs, X.shape[-1])
251250
X_splits = np.array_split(X, n_jobs, axis=-1)
252251
est_splits = np.array_split(self.estimators_, n_jobs)
253-
score = parallel(p_func(est, self.scoring, X, y)
252+
score = parallel(p_func(est, scoring_, X, y)
254253
for (est, x) in zip(est_splits, X_splits))
255254

256255
if n_jobs > 1:
@@ -400,18 +399,13 @@ class GeneralizationLight(SearchLight):
400399
----------
401400
base_estimator : object
402401
The base estimator to iteratively fit on a subset of the dataset.
402+
scoring : callable, string, defaults to None
403+
Score function (or loss function) with signature
404+
score_func(y, y_pred, **kwargs).
403405
n_jobs : int, optional (default=1)
404406
The number of jobs to run in parallel for both `fit` and `predict`.
405407
If -1, then the number of jobs is set to the number of cores.
406408
"""
407-
def __init__(self, base_estimator, n_jobs=1):
408-
_check_estimator(base_estimator)
409-
self.base_estimator = base_estimator
410-
self.n_jobs = n_jobs
411-
412-
if not isinstance(self.n_jobs, int):
413-
raise ValueError('n_jobs must be int, got %s' % n_jobs)
414-
415409
def __repr__(self):
416410
repr_str = super(GeneralizationLight, self).__repr__()
417411
if hasattr(self, 'estimators_'):

mne/decoding/tests/test_search_light.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,9 @@ def test_searchlight():
7272
sl2.fit(X, y)
7373
assert_array_equal(score1, sl2.score(X, y))
7474

75-
assert_raises(ValueError, SearchLight, LogisticRegression(), scoring='foo')
75+
sl = SearchLight(LogisticRegression(), scoring='foo')
76+
sl.fit(X, y)
77+
assert_raises(ValueError, sl.score, X, y)
7678

7779
sl = SearchLight(LogisticRegression())
7880
assert_equal(sl.scoring, None)

0 commit comments

Comments
 (0)
0