8000 [WIP] Common test for sample weight by eickenberg · Pull Request #5461 · scikit-learn/scikit-learn · GitHub
[go: up one dir, main page]

Skip to content

[WIP] Common test for sample weight #5461

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
70 changes: 70 additions & 0 deletions sklearn/tests/test_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,14 +11,23 @@
import warnings
import sys
import pkgutil
import numpy as np

from sklearn import datasets
from sklearn.base import is_classifier, is_regressor
from sklearn.cross_validation import train_test_split
from sklearn.externals.six import PY3
from sklearn.externals.six.moves import zip
from sklearn.externals.funcsigs import signature
from sklearn.utils import check_random_state
from sklearn.utils.testing import assert_false, clean_warning_registry
from sklearn.utils.testing import all_estimators
from sklearn.utils.testing import assert_greater
from sklearn.utils.testing import assert_in
from sklearn.utils.testing import ignore_warnings

from numpy.testing import assert_array_almost_equal

import sklearn
from sklearn.cluster.bicluster import BiclusterMixin
from sklearn.decomposition import ProjectedGradientNMF
Expand Down Expand Up @@ -219,3 +228,64 @@ def test_get_params_invariance():
yield check_get_params_invariance, name, Estimator
else:
yield check_get_params_invariance, name, Estimator
yield check_transformer_n_iter, name, estimator


def test_sample_weight_consistency(random_state=42):
    """Common test: integer sample weights must act like sample repetition.

    For every estimator whose ``fit`` accepts a ``sample_weight`` argument,
    fit once with integer weights and once on an "augmented" dataset where
    each sample is physically repeated ``weight`` times. The two fits must
    agree: on ``coef_`` when the estimator exposes it, and on the test-set
    predictions. Yields nose-style check tuples
    ``(assert_array_almost_equal, a, b)``.
    """
    estimators = all_estimators()

    n_samples, n_features = 20, 5
    rng = check_random_state(random_state)

    # Small positive integer weights, so "weight" and "multiplicity"
    # are directly comparable in the augmented dataset.
    sample_weight = rng.randint(1, 4, (n_samples,))

    X_clf, y_clf = datasets.make_classification(
        n_samples=n_samples, n_features=n_features,
        random_state=random_state)
    X_reg, y_reg = datasets.make_regression(
        n_samples=n_samples, n_features=n_features,
        n_informative=2, random_state=random_state)

    def aug(data, sample_weight):
        # Raise all samples to the multiplicity given by the corresponding
        # sample weight: sample i appears sample_weight[i] times.
        aug_data = []
        for samples, weight in zip(zip(*data), sample_weight):
            for _ in range(weight):
                aug_data.append(samples)
        aug_data = map(np.array, zip(*aug_data))
        return aug_data

    train, test = train_test_split(range(n_samples))

    for name, Estimator in estimators:
        # Only estimators that advertise sample_weight support in fit.
        # NOTE(review): utils.validation.has_fit_parameter would be the
        # canonical way to perform this check.
        if 'sample_weight' not in signature(Estimator.fit).parameters.keys():
            continue
        if is_classifier(Estimator):
            X, y = X_clf, y_clf
        elif is_regressor(Estimator):
            X, y = X_reg, y_reg
        else:
            print("%s is neither classifier nor regressor" % name)
            continue

        try:
            # NOTE(review): random_state should be enforced on the
            # estimator here (e.g. via the set_random_state helper) so
            # both fits share the same randomness.
            estimator_sw = Estimator().fit(X[train], y[train],
                                           sample_weight=sample_weight[train])
            X_aug_train, y_aug_train = aug((X[train], y[train]),
                                           sample_weight[train])
            estimator_aug = Estimator().fit(X_aug_train, y_aug_train)
        except ValueError:
            # LogisticRegression liblinear (standard solver)
            # does not support sample weights, but the argument is there
            continue

        # if estimator has `coef_` attribute, then compare the two
        if hasattr(estimator_sw, 'coef_'):
            yield (assert_array_almost_equal,
                   estimator_sw.coef_, estimator_aug.coef_)

        pred_sw = estimator_sw.predict(X[test])
        pred_aug = estimator_aug.predict(X[test])

        yield assert_array_almost_equal, pred_sw, pred_aug

0