-
-
Notifications
You must be signed in to change notification settings - Fork 25.9k
[WIP] Common test for sample weight #5461
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -11,14 +11,23 @@ | |
import warnings | ||
import sys | ||
import pkgutil | ||
import numpy as np | ||
|
||
from sklearn import datasets | ||
from sklearn.base import is_classifier, is_regressor | ||
from sklearn.cross_validation import train_test_split | ||
from sklearn.externals.six import PY3 | ||
from sklearn.externals.six.moves import zip | ||
from sklearn.externals.funcsigs import signature | ||
from sklearn.utils import check_random_state | ||
from sklearn.utils.testing import assert_false, clean_warning_registry | ||
from sklearn.utils.testing import all_estimators | ||
from sklearn.utils.testing import assert_greater | ||
from sklearn.utils.testing import assert_in | ||
from sklearn.utils.testing import ignore_warnings | ||
|
||
from numpy.testing import assert_array_almost_equal | ||
|
||
import sklearn | ||
from sklearn.cluster.bicluster import BiclusterMixin | ||
from sklearn.decomposition import ProjectedGradientNMF | ||
|
@@ -219,3 +228,64 @@ def test_get_params_invariance(): | |
yield check_get_params_invariance, name, Estimator | ||
else: | ||
yield check_get_params_invariance, name, Estimator | ||
yield check_transformer_n_iter, name, estimator | ||
|
||
|
||
def test_sample_weight_consistency(random_state=42):
    """Check that fitting with integer sample weights is equivalent to
    fitting on a dataset where each sample is repeated ``weight`` times.

    Yields one check per estimator whose ``fit`` accepts a
    ``sample_weight`` argument: coefficients (when exposed via ``coef_``)
    and held-out predictions must match between the weighted fit and the
    fit on the augmented (sample-repeated) dataset.
    """
    estimators = all_estimators()

    n_samples, n_features = 20, 5
    rng = check_random_state(random_state)

    # Integer weights in [1, 3] so a weighted sample can be emulated
    # exactly by repeating it `weight` times.
    sample_weight = rng.randint(1, 4, (n_samples,))

    X_clf, y_clf = datasets.make_classification(
        n_samples=n_samples, n_features=n_features,
        random_state=random_state)
    X_reg, y_reg = datasets.make_regression(
        n_samples=n_samples, n_features=n_features,
        n_informative=2, random_state=random_state)

    def aug(data, sample_weight):
        # Raise every sample to a multiplicity equal to its sample_weight.
        aug_data = []
        for samples, weight in zip(zip(*data), sample_weight):
            aug_data.extend([samples] * weight)
        # Rebuild one array per original input (e.g. X and y); return a
        # list (not a lazy map object) so the result is PY3-safe and
        # reusable.
        return [np.array(block) for block in zip(*aug_data)]

    # Pass an explicit list (PY3: range is lazy) and fix random_state so
    # the split -- and therefore the whole test -- is deterministic.
    train, test = train_test_split(list(range(n_samples)),
                                   random_state=random_state)

    for name, Estimator in estimators:
        # Skip estimators whose fit() does not expose sample_weight.
        if 'sample_weight' not in signature(Estimator.fit).parameters:
            continue
        if is_classifier(Estimator):
            X, y = X_clf, y_clf
        elif is_regressor(Estimator):
            X, y = X_reg, y_reg
        else:
            print("%s is neither classifier nor regressor" % name)
            continue

        try:
            estimator_sw = Estimator().fit(X[train], y[train],
                                           sample_weight=sample_weight[train])
            X_aug_train, y_aug_train = aug((X[train], y[train]),
                                           sample_weight[train])
            estimator_aug = Estimator().fit(X_aug_train, y_aug_train)
        except ValueError:
            # Some estimators advertise the argument but reject it at fit
            # time (e.g. LogisticRegression with the liblinear solver).
            continue

        # If the estimator exposes coefficients, they must agree.
        if hasattr(estimator_sw, 'coef_'):
            yield (assert_array_almost_equal,
                   estimator_sw.coef_, estimator_aug.coef_)

        # Predictions on held-out samples must agree as well.
        pred_sw = estimator_sw.predict(X[test])
        pred_aug = estimator_aug.predict(X[test])

        yield assert_array_almost_equal, pred_sw, pred_aug
|
Reviewer note: instead of inspecting `signature(Estimator.fit).parameters` directly, you can use `has_fit_parameter` from `sklearn.utils.validation` to check whether an estimator's `fit` accepts `sample_weight`.