8000 TST non-regression test for #6147, roc_auc on memmap data · scikit-learn/scikit-learn@613f1ad · GitHub
[go: up one dir, main page]

Skip to content

Commit 613f1ad

Browse files
committed
TST non-regression test for #6147, roc_auc on memmap data
1 parent 3933ff4 commit 613f1ad

File tree

1 file changed

+64
-12
lines changed

1 file changed

+64
-12
lines changed

sklearn/metrics/tests/test_score_objects.py

Lines changed: 64 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,8 @@
11
import pickle
2+
import tempfile
3+
import shutil
4+
import os
5+
import numbers
26

37
import numpy as np
48

@@ -30,6 +34,7 @@
3034
from sklearn.cross_validation import train_test_split, cross_val_score
3135
from sklearn.grid_search import GridSearchCV
3236
from sklearn.multiclass import OneVsRestClassifier
37+
from sklearn.externals import joblib
3338

3439

3540
REGRESSION_SCORERS = ['r2', 'mean_absolute_error', 'mean_squared_error',
@@ -46,6 +51,46 @@
4651
MULTILABEL_ONLY_SCORERS = ['precision_samples', 'recall_samples', 'f1_samples']
4752

4853

54+
def _make_estimators(X_train, y_train, y_ml_train):
55+
# Make estimators that make sense to test various scoring methods
56+
sensible_regr = DummyRegressor(strategy='median')
57+
sensible_regr.fit(X_train, y_train)
58+
sensible_clf = DecisionTreeClassifier(random_state=0)
59+
sensible_clf.fit(X_train, y_train)
60+
sensible_ml_clf = DecisionTreeClassifier(random_state=0)
61+
sensible_ml_clf.fit(X_train, y_ml_train)
62+
return dict(
63+
[(name, sensible_regr) for name in REGRESSION_SCORERS] +
64+
[(name, sensible_clf) for name in CLF_SCORERS] +
65+
[(name, sensible_ml_clf) for name in MULTILABEL_ONLY_SCORERS]
66+
)
67+
68+
69+
X_mm, y_mm, y_ml_mm = None, None, None
70+
ESTIMATORS = None
71+
TEMP_FOLDER = None
72+
73+
74+
def setup_module():
75+
# Create some memory mapped data
76+
global X_mm, y_mm, y_ml_mm, TEMP_FOLDER, ESTIMATORS
77+
TEMP_FOLDER = tempfile.mkdtemp(prefix='sklearn_test_score_objects_')
78+
X, y = make_classification(n_samples=30, n_features=5, random_state=0)
79+
_, y_ml = make_multilabel_classification(n_samples=X.shape[0],
80+
random_state=0)
81+
filename = os.path.join(TEMP_FOLDER, 'test_data.pkl')
82+
joblib.dump((X, y, y_ml), filename)
83+
X_mm, y_mm, y_ml_mm = joblib.load(filename, mmap_mode='r')
84+
ESTIMATORS = _make_estimators(X_mm, y_mm, y_ml_mm)
85+
86+
87+
def teardown_module():
88+
global X_mm, y_mm, y_ml_mm, TEMP_FOLDER, ESTIMATORS
89+
# GC closes the mmap file descriptors
90+
X_mm, y_mm, y_ml_mm, ESTIMATORS = None, None, None, None
91+
shutil.rmtree(TEMP_FOLDER)
92+
93+
4994
class EstimatorWithoutFit(object):
5095
"""Dummy estimator to test check_scoring"""
5196
pass
@@ -318,18 +363,7 @@ def test_scorer_sample_weight():
318363
sample_weight[:10] = 0
319364

320365
# get sensible estimators for each metric
321-
sensible_regr = DummyRegressor(strategy='median')
322-
sensible_regr.fit(X_train, y_train)
323-
sensible_clf = DecisionTreeClassifier(random_state=0)
324-
sensible_clf.fit(X_train, y_train)
325-
sensible_ml_clf = DecisionTreeClassifier(random_state=0)
326-
sensible_ml_clf.fit(X_train, y_ml_train)
327-
estimator = dict([(name, sensible_regr)
328-
for name in REGRESSION_SCORERS] +
329-
[(name, sensible_clf)
330-
for name in CLF_SCORERS] +
331-
[(name, sensible_ml_clf)
332-
for name in MULTILABEL_ONLY_SCORERS])
366+
estimator = _make_estimators(X_train, y_train, y_ml_train)
333367

334368
for name, scorer in SCORERS.items():
335369
if name in MULTILABEL_ONLY_SCORERS:
@@ -355,3 +389,21 @@ def test_scorer_sample_weight():
355389
assert_true("sample_weight" in str(e),
356390
"scorer {0} raises unhelpful exception when called "
357391
"with sample weights: {1}".format(name, str(e)))
392+
393+
394+
@ignore_warnings # UndefinedMetricWarning for P / R scores
395+
def check_scorer_memmap(scorer_name):
396+
scorer, estimator = SCORERS[scorer_name], ESTIMATORS[scorer_name]
397+
if scorer_name in MULTILABEL_ONLY_SCORERS:
398+
score = scorer(estimator, X_mm, y_ml_mm)
399+
else:
400+
score = scorer(estimator, X_mm, y_mm)
401+
assert isinstance(score, numbers.Number), scorer_name
402+
403+
404+
def test_scorer_memmap_input():
405+
# Non-regression test for #6147: some score functions would
406+
# return singleton memmap when computed on memmap data instead of scalar
407+
# float values.
408+
for name in SCORERS.keys():
409+
yield check_scorer_memmap, name

0 commit comments

Comments
 (0)
0