8000 Common check for sample weight invariance with removed samples (#16507) · scikit-learn/scikit-learn@77279d6 · GitHub
[go: up one dir, main page]

Skip to content

Commit 77279d6

Browse files
rthNicolasHug
andauthored
Common check for sample weight invariance with removed samples (#16507)
Co-Authored-By: Nicolas Hug <contact@nicolas-hug.com>
1 parent c36c104 commit 77279d6

File tree

10 files changed

+173
-14
lines changed

10 files changed

+173
-14
lines changed

sklearn/calibration.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -233,6 +233,14 @@ class that has the highest probability, and can thus be different
233233
check_is_fitted(self)
234234
return self.classes_[np.argmax(self.predict_proba(X), axis=1)]
235235

236+
def _more_tags(self):
237+
return {
238+
'_xfail_test': {
239+
'check_sample_weights_invariance(kind=zeros)':
240+
'zero sample_weight is not equivalent to removing samples',
241+
}
242+
}
243+
236244

237245
class _CalibratedClassifier:
238246
"""Probability calibration with isotonic regression or sigmoid.

sklearn/cluster/_kmeans.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1209,6 +1209,14 @@ def score(self, X, y=None, sample_weight=None):
12091209
return -_labels_inertia(X, sample_weight, x_squared_norms,
12101210
self.cluster_centers_)[1]
12111211

1212+
def _more_tags(self):
1213+
return {
1214+
'_xfail_test': {
1215+
'check_sample_weights_invariance(kind=zeros)':
1216+
'zero sample_weight is not equivalent to removing samples',
1217+
}
1218+
}
1219+
12121220

12131221
def _mini_batch_step(X, sample_weight, x_squared_norms, centers, weight_sums,
12141222
old_center_buffer, compute_squared_diff,
@@ -1865,3 +1873,11 @@ def predict(self, X, sample_weight=None):
18651873

18661874
X = self._check_test_data(X)
18671875
return self._labels_inertia_minibatch(X, sample_weight)[0]
1876+
1877+
def _more_tags(self):
1878+
return {
1879+
'_xfail_test': {
1880+
'check_sample_weights_invariance(kind=zeros)':
1881+
'zero sample_weight is not equivalent to removing samples',
1882+
}
1883+
}

sklearn/ensemble/_iforest.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -476,6 +476,14 @@ def _compute_score_samples(self, X, subsample_features):
476476
)
477477
return scores
478478

479+
def _more_tags(self):
480+
return {
481+
'_xfail_test': {
482+
'check_sample_weights_invariance(kind=zeros)':
483+
'zero sample_weight is not equivalent to removing samples',
484+
}
485+
}
486+
479487

480488
def _average_path_length(n_samples_leaf):
481489
"""

sklearn/linear_model/_logistic.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2083,3 +2083,11 @@ def score(self, X, y, sample_weight=None):
20832083
scoring = get_scorer(scoring)
20842084

20852085
return scoring(self, X, y, sample_weight=sample_weight)
2086+
2087+
def _more_tags(self):
2088+
return {
2089+
'_xfail_test': {
2090+
'check_sample_weights_invariance(kind=zeros)':
2091+
'zero sample_weight is not equivalent to removing samples',
2092+
}
2093+
}

sklearn/linear_model/_ransac.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -502,3 +502,11 @@ def score(self, X, y):
502502
check_is_fitted(self)
503503

504504
return self.estimator_.score(X, y)
505+
506+
def _more_tags(self):
507+
return {
508+
'_xfail_test': {
509+
'check_sample_weights_invariance(kind=zeros)':
510+
'zero sample_weight is not equivalent to removing samples',
511+
}
512+
}

sklearn/linear_model/_ridge.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1909,3 +1909,11 @@ def fit(self, X, y, sample_weight=None):
19091909
@property
19101910
def classes_(self):
19111911
return self._label_binarizer.classes_
1912+
1913+
def _more_tags(self):
1914+
return {
1915+
'_xfail_test': {
1916+
'check_sample_weights_invariance(kind=zeros)':
1917+
'zero sample_weight is not equivalent to removing samples',
1918+
}
1919+
}

sklearn/linear_model/_stochastic_gradient.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1094,6 +1094,14 @@ def predict_log_proba(self):
10941094
def _predict_log_proba(self, X):
10951095
return np.log(self.predict_proba(X))
10961096

1097+
def _more_tags(self):
1098+
return {
1099+
'_xfail_test': {
1100+
'check_sample_weights_invariance(kind=zeros)':
1101+
'zero sample_weight is not equivalent to removing samples',
1102+
}
1103+
}
1104+
10971105

10981106
class BaseSGDRegressor(RegressorMixin, BaseSGD):
10991107

@@ -1575,3 +1583,11 @@ def __init__(self, loss="squared_loss", *, penalty="l2", alpha=0.0001,
15751583
validation_fraction=validation_fraction,
15761584
n_iter_no_change=n_iter_no_change, warm_start=warm_start,
15771585
average=average)
1586+
1587+
def _more_tags(self):
1588+
return {
1589+
'_xfail_test': {
1590+
'check_sample_weights_invariance(kind=zeros)':
1591+
'zero sample_weight is not equivalent to removing samples',
1592+
}
1593+
}

sklearn/neighbors/_kde.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -274,3 +274,11 @@ def sample(self, n_samples=1, random_state=None):
274274
correction = (gammainc(0.5 * dim, 0.5 * s_sq) ** (1. / dim)
275275
* self.bandwidth / np.sqrt(s_sq))
276276
return data[i] + X * correction[:, np.newaxis]
277+
278+
def _more_tags(self):
279+
return {
280+
'_xfail_test': {
281+
'check_sample_weights_invariance(kind=zeros)':
282+
'sample_weight must have positive values',
283+
}
284+
}

sklearn/svm/_classes.py

Lines changed: 51 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -244,6 +244,14 @@ def fit(self, X, y, sample_weight=None):
244244

245245
return self
246246

247+
def _more_tags(self):
248+
return {
249+
'_xfail_test': {
250+
'check_sample_weights_invariance(kind=zeros)':
251+
'zero sample_weight is not equivalent to removing samples',
252+
}
253+
}
254+
247255

248256
class LinearSVR(RegressorMixin, LinearModel):
249257
"""Linear Support Vector Regression.
@@ -424,6 +432,14 @@ def fit(self, X, y, sample_weight=None):
424432

425433
return self
426434

435+
def _more_tags(self):
436+
return {
437+
'_xfail_test': {
438+
'check_sample_weights_invariance(kind=zeros)':
439+
'zero sample_weight is not equivalent to removing samples',
440+
}
441+
}
442+
427443

428444
class SVC(BaseSVC):
429445
"""C-Support Vector Classification.
@@ -650,6 +666,14 @@ def __init__(self, *, C=1.0, kernel='rbf', degree=3, gamma='scale',
650666
break_ties=break_ties,
651667
random_state=random_state)
652668

669+
def _more_tags(self):
670+
return {
671+
'_xfail_test': {
672+
'check_sample_weights_invariance(kind=zeros)':
673+
'zero sample_weight is not equivalent to removing samples',
674+
}
675+
}
676+
653677

654678
class NuSVC(BaseSVC):
655679
"""Nu-Support Vector Classification.
@@ -866,7 +890,9 @@ def _more_tags(self):
866890
'_xfail_checks': {
867891
'check_methods_subset_invariance':
868892
'fails for the decision_function method',
869-
'check_class_weight_classifiers': 'class_weight is ignored.'
893+
'check_class_weight_classifiers': 'class_weight is ignored.',
894+
'check_sample_weights_invariance(kind=zeros)':
895+
'zero sample_weight is not equivalent to removing samples',
870896
}
871897
}
872898

@@ -1027,6 +1053,14 @@ def probA_(self):
10271053
def probB_(self):
10281054
return self._probB
10291055

1056+
def _more_tags(self):
1057+
return {
1058+
'_xfail_test': {
1059+
'check_sample_weights_invariance(kind=zeros)':
1060+
'zero sample_weight is not equivalent to removing samples',
1061+
}
1062+
}
1063+
10301064

10311065
class NuSVR(RegressorMixin, BaseLibSVM):
10321066
"""Nu Support Vector Regression.
@@ -1157,6 +1191,14 @@ def __init__(self, *, nu=0.5, C=1.0, kernel='rbf', degree=3,
11571191
probability=False, cache_size=cache_size, class_weight=None,
11581192
verbose=verbose, max_iter=max_iter, random_state=None)
11591193

1194+
def _more_tags(self):
1195+
return {
1196+
'_xfail_test': {
1197+
'check_sample_weights_invariance(kind=zeros)':
1198+
'zero sample_weight is not equivalent to removing samples',
1199+
}
1200+
}
1201+
11601202

11611203
class OneClassSVM(OutlierMixin, BaseLibSVM):
11621204
"""Unsupervised Outlier Detection.
@@ -1371,3 +1413,11 @@ def probA_(self):
13711413
@property
13721414
def probB_(self):
13731415
return self._probB
1416+
1417+
def _more_tags(self):
1418+
return {
1419+
'_xfail_test': {
1420+
'check_sample_weights_invariance(kind=zeros)':
1421+
'zero sample_weight is not equivalent to removing samples',
1422+
}
1423+
}

sklearn/utils/estimator_checks.py

Lines changed: 42 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,8 @@ def _yield_checks(name, estimator):
6868
yield check_sample_weights_not_an_array
6969
yield check_sample_weights_list
7070
yield check_sample_weights_shape
71-
yield check_sample_weights_invariance
71+
yield partial(check_sample_weights_invariance, kind='ones')
72+
yield partial(check_sample_weights_invariance, kind='zeros')
7273
yield check_estimators_fit_returns_self
7374
yield partial(check_estimators_fit_returns_self, readonly_memmap=True)
7475

@@ -488,6 +489,7 @@ def check_estimator(Estimator, generate_only=False):
488489
warnings.warn(msg, FutureWarning)
489490

490491
checks_generator = _generate_class_checks(Estimator)
492+
estimator = _construct_instance(Estimator)
491493
else:
492494
# got an instance
493495
estimator = Estimator
@@ -497,12 +499,19 @@ def check_estimator(Estimator, generate_only=False):
497499
if generate_only:
498500
return checks_generator
499501

502+
xfail_checks = _safe_tags(estimator, '_xfail_test')
503+
500504
for estimator, check in checks_generator:
505+
check_name = _set_check_estimator_ids(check)
506+
if xfail_checks and check_name in xfail_checks:
507+
# skip tests marked as a known failure and raise a warning
508+
msg = xfail_checks[check_name]
509+
warnings.warn(f'Skipping {check_name}: {msg}', SkipTestWarning)
510+
continue
501511
try:
502512
check(estimator)
503513
except SkipTest as exception:
504-
# the only SkipTest thrown currently results from not
505-
# being able to import pandas.
514+
# raise warning for tests that are are skipped
506515
warnings.warn(str(exception), SkipTestWarning)
507516

508517

@@ -861,7 +870,7 @@ def check_sample_weights_shape(name, estimator_orig):
861870

862871

863872
@ignore_warnings(category=FutureWarning)
864-
def check_sample_weights_invariance(name, estimator_orig):
873+
def check_sample_weights_invariance(name, estimator_orig, kind="ones"):
865874
# check that the estimators yield same results for
866875
# unit weights and no weights
867876
if (has_fit_parameter(estimator_orig, "sample_weight") and
@@ -877,25 +886,45 @@ def check_sample_weights_invariance(name, estimator_orig):
877886
X = np.array([[1, 3], [1, 3], [1, 3], [1, 3],
878887
[2, 1], [2, 1], [2, 1], [2, 1],
879888
[3, 3], [3, 3], [3, 3], [3, 3],
880-
[4, 1], [4, 1], [4, 1], [4, 1]], dtype=np.dtype('float'))
889+
[4, 1], [4, 1], [4, 1], [4, 1]], dtype=np.float64)
881890
y = np.array([1, 1, 1, 1, 2, 2, 2, 2,
882-
1, 1, 1, 1, 2, 2, 2, 2], dtype=np.dtype('int'))
891+
1, 1, 1, 1, 2, 2, 2, 2], dtype=np.int)
892+
893+
if kind == 'ones':
894+
X2 = X
895+
y2 = y
896+
sw2 = np.ones(shape=len(y))
897+
err_msg = (f"For {name} sample_weight=None is not equivalent to "
898+
f"sample_weight=ones")
899+
elif kind == 'zeros':
900+
# Construct a dataset that is very different to (X, y) if weights
901+
# are disregarded, but identical to (X, y) given weights.
902+
X2 = np.vstack([X, X + 1])
903+
y2 = np.hstack([y, 3 - y])
904+
sw2 = np.ones(shape=len(y) * 2)
905+
sw2[len(y):] = 0
906+
X2, y2, sw2 = shuffle(X2, y2, sw2, random_state=0)
907+
908+
err_msg = (f"For {name} sample_weight is not equivalent "
909+
f"to removing samples")
910+
else:
911+
raise ValueError
912+
883913
y = _enforce_estimator_tags_y(estimator1, y)
914+
y2 = _enforce_estimator_tags_y(estimator2, y2)
884915

885-
estimator1.fit(X, y=y, sample_weight=np.ones(shape=len(y)))
886-
estimator2.fit(X, y=y, sample_weight=None)
916+
estimator1.fit(X, y=y, sample_weight=None)
917+
estimator2.fit(X2, y=y2, sample_weight=sw2)
887918

888-
for method in ["predict", "transform"]:
919+
for method in ["predict", "predict_proba",
920+
"decision_function", "transform"]:
889921
if hasattr(estimator_orig, method):
890922
X_pred1 = getattr(estimator1, method)(X)
891923
X_pred2 = getattr(estimator2, method)(X)
892924
if sparse.issparse(X_pred1):
893925
X_pred1 = X_pred1.toarray()
894926
X_pred2 = X_pred2.toarray()
895-
assert_allclose(X_pred1, X_pred2,
896-
err_msg="For %s sample_weight=None is not"
897-
" equivalent to sample_weight=ones"
898-
% name)
927+
assert_allclose(X_pred1, X_pred2, err_msg=err_msg)
899928

900929

901930
@ignore_warnings(category=(FutureWarning, UserWarning))

0 commit comments

Comments
 (0)
0