8000 Common check for sample weight invariance with removed samples (#17176) · scikit-learn/scikit-learn@8feb045 · GitHub
[go: up one dir, main page]

Skip to content < 8000 span data-view-component="true" class="progress-pjax-loader Progress position-fixed width-full">

Commit 8feb045

Browse files
rthNicolasHug
andauthored
Common check for sample weight invariance with removed samples (#17176)
Co-Authored-By: Nicolas Hug <contact@nicolas-hug.com>
1 parent 8cce5bf commit 8feb045

File tree

10 files changed

+185
-36
lines changed

10 files changed

+185
-36
lines changed

sklearn/calibration.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -284,6 +284,14 @@ class that has the highest probability, and can thus be different
284284
check_is_fitted(self)
285285
return self.classes_[np.argmax(self.predict_proba(X), axis=1)]
286286

287+
def _more_tags(self):
288+
return {
289+
'_xfail_checks': {
290+
'check_sample_weights_invariance(kind=zeros)':
291+
'zero sample_weight is not equivalent to removing samples',
292+
}
293+
}
294+
287295

288296
class _CalibratedClassifier:
289297
"""Probability calibration with isotonic regression or sigmoid.

sklearn/cluster/_kmeans.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1215,6 +1215,14 @@ def score(self, X, y=None, sample_weight=None):
12151215
return -_labels_inertia(X, sample_weight, x_squared_norms,
12161216
self.cluster_centers_)[1]
12171217

1218+
def _more_tags(self):
1219+
return {
1220+
'_xfail_checks': {
1221+
'check_sample_weights_invariance(kind=zeros)':
1222+
'zero sample_weight is not equivalent to removing samples',
1223+
}
1224+
}
1225+
12181226

12191227
def _mini_batch_step(X, sample_weight, x_squared_norms, centers, weight_sums,
12201228
old_center_buffer, compute_squared_diff,
@@ -1871,3 +1879,11 @@ def predict(self, X, sample_weight=None):
18711879

18721880
X = self._check_test_data(X)
18731881
return self._labels_inertia_minibatch(X, sample_weight)[0]
1882+
1883+
def _more_tags(self):
1884+
return {
1885+
'_xfail_checks': {
1886+
'check_sample_weights_invariance(kind=zeros)':
1887+
'zero sample_weight is not equivalent to removing samples',
1888+
}
1889+
}

sklearn/ensemble/_iforest.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
447447
)
448448
return scores
449449

450+
def _more_tags(self):
451+
return {
452+
'_xfail_checks': {
453+
'check_sample_weights_invariance(kind=zeros)':
454+
'zero sample_weight is not equivalent to removing samples',
455+
}
456+
}
457+
450458

451459
def _average_path_length(n_samples_leaf):
452460
"""

sklearn/linear_model/_logistic.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2084,3 +2084,11 @@ def score(self, X, y, sample_weight=None):
20842084
scoring = get_scorer(scoring)
20852085

20862086
return scoring(self, X, y, sample_weight=sample_weight)
2087+
2088+
def _more_tags(self):
2089+
return {
2090+
'_xfail_checks': {
2091+
'check_sample_weights_invariance(kind=zeros)':
2092+
'zero sample_weight is not equivalent to removing samples',
2093+
}
2094+
}

sklearn/linear_model/_ransac.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -502,3 +502,11 @@ def score(self, X, y):
502502
check_is_fitted(self)
503503

504504
return self.estimator_.score(X, y)
505+
506+
def _more_tags(self):
507+
return {
508+
'_xfail_checks': {
509+
'check_sample_weights_invariance(kind=zeros)':
510+
'zero sample_weight is not equivalent to removing samples',
511+
}
512+
}

sklearn/linear_model/_ridge.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1909,3 +1909,11 @@ def fit(self, X, y, sample_weight=None):
19091909
@property
19101910
def classes_(self):
19111911
return self._label_binarizer.classes_
1912+
1913+
def _more_tags(self):
1914+
return {
1915+
'_xfail_checks': {
1916+
'check_sample_weights_invariance(kind=zeros)':
1917+
'zero sample_weight is not equivalent to removing samples',
1918+
}
1919+
}

sklearn/linear_model/_stochastic_gradient.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1095,6 +1095,14 @@ def predict_log_proba(self):
10951095
def _predict_log_proba(self, X):
10961096
return np.log(self.predict_proba(X))
10971097

1098+
def _more_tags(self):
1099+
return {
1100+
'_xfail_checks': {
1101+
'check_sample_weights_invariance(kind=zeros)':
1102+
'zero sample_weight is not equivalent to removing samples',
1103+
}
1104+
}
1105+
10981106

10991107
class BaseSGDRegressor(RegressorMixin, BaseSGD):
11001108

@@ -1576,3 +1584,11 @@ def __init__(self, loss="squared_loss", *, penalty="l2", alpha=0.0001,
15761584
validation_fraction=validation_fraction,
15771585
n_iter_no_change=n_iter_no_change, warm_start=warm_start,
15781586
average=average)
1587+
1588+
def _more_tags(self):
1589+
return {
1590+
'_xfail_checks': {
1591+
'check_sample_weights_invariance(kind=zeros)':
1592+
'zero sample_weight is not equivalent to removing samples',
1593+
}
1594+
}

sklearn/neighbors/_kde.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -274,3 +274,11 @@ def sample(self, n_samples=1, random_state=None):
274274
correction = (gammainc(0.5 * dim, 0.5 * s_sq) ** (1. / dim)
275275
* self.bandwidth / np.sqrt(s_sq))
276276
return data[i] + X * correction[:, np.newaxis]
277+
278+
def _more_tags(self):
279+
return {
280+
'_xfail_checks': {
281+
'check_sample_weights_invariance(kind=zeros)':
282+
'sample_weight must have positive values',
283+
}
284+
}

sklearn/svm/_classes.py

Lines changed: 51 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -244,6 +244,14 @@ def fit(self, X, y, sample_weight=None):
244244

245245
return self
246246 EED3

247+
def _more_tags(self):
248+
return {
249+
'_xfail_checks': {
250+
'check_sample_weights_invariance(kind=zeros)':
251+
'zero sample_weight is not equivalent to removing samples',
252+
}
253+
}
254+
247255

248256
class LinearSVR(RegressorMixin, LinearModel):
249257
"""Linear Support Vector Regression.
@@ -424,6 +432,14 @@ def fit(self, X, y, sample_weight=None):
424432

425433
return self
426434

435+
def _more_tags(self):
436+
return {
437+
'_xfail_checks': {
438+
'check_sample_weights_invariance(kind=zeros)':
439+
'zero sample_weight is not equivalent to removing samples',
440+
}
441+
}
442+
427443

428444
class SVC(BaseSVC):
429445
"""C-Support Vector Classification.
@@ -650,6 +666,14 @@ def __init__(self, *, C=1.0, kernel='rbf', degree=3, gamma='scale',
650666
break_ties=break_ties,
651667
random_state=random_state)
652668

669+
def _more_tags(self):
670+
return {
671+
'_xfail_checks': {
672+
'check_sample_weights_invariance(kind=zeros)':
673+
'zero sample_weight is not equivalent to removing samples',
674+
}
675+
}
676+
653677

654678
class NuSVC(BaseSVC):
655679
"""Nu-Support Vector Classification.
@@ -866,7 +890,9 @@ def _more_tags(self):
866890
'_xfail_checks': {
867891
'check_methods_subset_invariance':
868892
'fails for the decision_function method',
869-
'check_class_weight_classifiers': 'class_weight is ignored.'
893+
'check_class_weight_classifiers': 'class_weight is ignored.',
894+
'check_sample_weights_invariance(kind=zeros)':
895+
'zero sample_weight is not equivalent to removing samples',
870896
}
871897
}
872898

@@ -1027,6 +1053,14 @@ def probA_(self):
10271053
def probB_(self):
10281054
return self._probB
10291055

1056+
def _more_tags(self):
1057+
return {
1058+
'_xfail_checks': {
1059+
'check_sample_weights_invariance(kind=zeros)':
1060+
'zero sample_weight is not equivalent to removing samples',
1061+
}
1062+
}
1063+
10301064

10311065
class NuSVR(RegressorMixin, BaseLibSVM):
10321066
"""Nu Support Vector Regression.
@@ -1157,6 +1191,14 @@ def __init__(self, *, nu=0.5, C=1.0, kernel='rbf', degree=3,
11571191
probability=False, cache_size=cache_size, class_weight=None,
11581192
verbose=verbose, max_iter=max_iter, random_state=None)
11591193

1194+
def _more_tags(self):
1195+
return {
1196+
'_xfail_checks': {
1197+
'check_sample_weights_invariance(kind=zeros)':
1198+
'zero sample_weight is not equivalent to removing samples',
1199+
}
1200+
}
1201+
11601202

11611203
class OneClassSVM(OutlierMixin, BaseLibSVM):
11621204
"""Unsupervised Outlier Detection.
@@ -1371,3 +1413,11 @@ def probA_(self):
13711413
@property
13721414
def probB_(self):
13731415
return self._probB
1416+
1417+
def _more_tags(self):
1418+
return {
1419+
'_xfail_checks': {
1420+
'check_sample_weights_invariance(kind=zeros)':
1421+
'zero sample_weight is not equivalent to removing samples',
1422+
}
1423+
}

sklearn/utils/estimator_checks.py

Lines changed: 54 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,12 @@ def _yield_checks(estimator):
6767
yield check_sample_weights_not_an_array
6868
yield check_sample_weights_list
6969
yield check_sample_weights_shape
70-
yield check_sample_weights_invariance
70+
if (has_fit_parameter(estimator, "sample_weight")
71+
and not (hasattr(estimator, "_pairwise")
72+
and estimator._pairwise)):
73+
# We skip pairwise because the data is not pairwise
74+
yield partial(check_sample_weights_invariance, kind='ones')
75+
yield partial(check_sample_weights_invariance, kind='zeros')
7176
yield check_estimators_fit_returns_self
7277
yield partial(check_estimators_fit_returns_self, readonly_memmap=True)
7378

@@ -836,41 +841,55 @@ def check_sample_weights_shape(name, estimator_orig):
836841

837842

838843
@ignore_warnings(category=FutureWarning)
839-
def check_sample_weights_invariance(name, estimator_orig):
840-
# check that the estimators yield same results for
844+
def check_sample_weights_invariance(name, estimator_orig, kind="ones"):
845+
# For kind="ones" check that the estimators yield same results for
841846
# unit weights and no weights
842-
if (has_fit_parameter(estimator_orig, "sample_weight") and
843-
not (hasattr(estimator_orig, "_pairwise")
844-
and estimator_orig._pairwise)):
845-
# We skip pairwise because the data is not pairwise
846-
847-
estimator1 = clone(estimator_orig)
848-
estimator2 = clone(estimator_orig)
849-
set_random_state(estimator1, random_state=0)
850-
set_random_state(estimator2, random_state=0)
851-
852-
X = np.array([[1, 3], [1, 3], [1, 3], [1, 3],
853-
[2, 1], [2, 1], [2, 1], [2, 1],
854-
[3, 3], [3, 3], [3, 3], [3, 3],
855-
[4, 1], [4, 1], [4, 1], [4, 1]], dtype=np.dtype('float'))
856-
y = np.array([1, 1, 1, 1, 2, 2, 2, 2,
857-
1, 1, 1, 1, 2, 2, 2, 2], dtype=np.dtype('int'))
858-
y = _enforce_estimator_tags_y(estimator1, y)
859-
860-
estimator1.fit(X, y=y, sample_weight=np.ones(shape=len(y)))
861-
estimator2.fit(X, y=y, sample_weight=None)
862-
863-
for method in ["predict", "transform"]:
864-
if hasattr(estimator_orig, method):
865-
X_pred1 = getattr(estimator1, method)(X)
866-
X_pred2 = getattr(estimator2, method)(X)
867-
if sparse.issparse(X_pred1):
868-
X_pred1 = X_pred1.toarray()
869-
X_pred2 = X_pred2.toarray()
870-
assert_allclose(X_pred1, X_pred2,
871-
err_msg="For %s sample_weight=None is not"
872-
" equivalent to sample_weight=ones"
873-
% name)
847+
# For kind="zeros" check that setting sample_weight to 0 is equivalent
848+
# to removing corresponding samples.
849+
estimator1 = clone(estimator_orig)
850+
estimator2 = clone(estimator_orig)
851+
set_random_state(estimator1, random_state=0)
852+
set_random_state(estimator2, random_state=0)
853+
854+
X1 = np.array([[1, 3], [1, 3], [1, 3], [1, 3],
855+
[2, 1], [2, 1], [2, 1], [2, 1],
856+
[3, 3], [3, 3], [3, 3], [3, 3],
857+
[4, 1], [4, 1], [4, 1], [4, 1]], dtype=np.float64)
858+
y1 = np.array([1, 1, 1, 1, 2, 2, 2, 2,
859+
1, 1, 1, 1, 2, 2, 2, 2], dtype=np.int)
860+
861+
if kind == 'ones':
862+
X2 = X1
863+
y2 = y1
864+
sw2 = np.ones(shape=len(y1))
865+
err_msg = (f"For {name} sample_weight=None is not equivalent to "
866+
f"sample_weight=ones")
867+
elif kind == 'zeros':
868+
# Construct a dataset that is very different to (X, y) if weights
869+
# are disregarded, but identical to (X, y) given weights.
870+
X2 = np.vstack([X1, X1 + 1])
871+
y2 = np.hstack([y1, 3 - y1])
872+
sw2 = np.ones(shape=len(y1) * 2)
873+
sw2[len(y1):] = 0
874+
X2, y2, sw2 = shuffle(X2, y2, sw2, random_state=0)
875+
876+
err_msg = (f"For {name}, a zero sample_weight is not equivalent "
877+
f"to removing the sample")
878+
else: # pragma: no cover
879+
raise ValueError
880+
881+
y1 = _enforce_estimator_tags_y(estimator1, y1)
882+
y2 = _enforce_estimator_tags_y(estimator2, y2)
883+
884+
estimator1.fit(X1, y=y1, sample_weight=None)
885+
estimator2.fit(X2, y=y2, sample_weight=sw2)
886+
887+
for method in ["predict", "predict_proba",
888+
"decision_function", "transform"]:
889+
if hasattr(estimator_orig, method):
890+
X_pred1 = getattr(estimator1, method)(X1)
891+
X_pred2 = getattr(estimator2, method)(X1)
892+
assert_allclose_dense_sparse(X_pred1, X_pred2, err_msg=err_msg)
874893

875894

876895
@ignore_warnings(category=(FutureWarning, UserWarning))

0 commit comments

Comments
 (0)
0