FIX Check and correct the input_tags.sparse flag by antoinebaker · Pull Request #30187 · scikit-learn/scikit-learn · GitHub

FIX Check and correct the input_tags.sparse flag #30187

Merged
merged 35 commits into from
Jan 2, 2025
Commits
a896f2e
add input_tags.sparse and test
antoinebaker Oct 31, 2024
b0e605d
fix LinearRegression tag
antoinebaker Nov 4, 2024
6c72527
changelog
antoinebaker Nov 4, 2024
22b5a6b
Merge branch 'main' into fix_input_sparse_tag
antoinebaker Nov 4, 2024
c306593
changelog
antoinebaker Nov 4, 2024
6aadc95
fix column transformer tag
antoinebaker Nov 14, 2024
7979fa9
change error message
antoinebaker Nov 15, 2024
6d7c2b1
changelog
antoinebaker Nov 15, 2024
f73913f
Merge remote-tracking branch 'upstream/main' into fix_input_sparse_tag
antoinebaker Nov 15, 2024
7005797
fix passthrough sparse tag
antoinebaker Nov 15, 2024
6962aa9
fix SelfTrainingClassifier
antoinebaker Nov 15, 2024
2668fea
Apply suggestions from code review
antoinebaker Nov 15, 2024
705c414
add suggestions
antoinebaker Nov 15, 2024
5178539
fix multitask
antoinebaker Nov 15, 2024
e5a4458
black formatting
antoinebaker Nov 15, 2024
64ddb62
Merge branch 'main' into fix_input_sparse_tag
antoinebaker Nov 18, 2024
d6f277f
add meta test
antoinebaker Nov 18, 2024
0bcc765
add feature union
antoinebaker Nov 18, 2024
d862de7
check function transformer
antoinebaker Nov 18, 2024
21ec585
catch invalid transformers list
antoinebaker Nov 20, 2024
8447756
Merge branch 'main' into fix_input_sparse_tag
glemaitre Nov 25, 2024
c653834
add todo
antoinebaker Nov 26, 2024
d9f4de3
remove outer function
antoinebaker Nov 26, 2024
a9fb7d7
change pipeline tag
antoinebaker Nov 26, 2024
e85f94a
tag RobustScaler
antoinebaker Nov 29, 2024
8f2f3db
tag RANSAC
antoinebaker Nov 29, 2024
3ac38e1
multi_output in LinearModelCV
antoinebaker Nov 29, 2024
42d815e
Merge remote-tracking branch 'upstream/main' into fix_input_sparse_tag
antoinebaker Dec 9, 2024
425b473
raise from exception
antoinebaker Dec 9, 2024
17ccb72
no cover
antoinebaker Dec 11, 2024
8429e8f
test raise inappropriate error
antoinebaker Dec 11, 2024
87e5211
Merge branch 'main' into fix_input_sparse_tag
antoinebaker Dec 11, 2024
97d700d
Apply suggestions from code review
antoinebaker Dec 12, 2024
66c3bd9
suggestions from code review
antoinebaker Dec 12, 2024
b54a408
Merge remote-tracking branch 'upstream/main' into pr/antoinebaker/30187
jeremiedbb Jan 2, 2025
2 changes: 2 additions & 0 deletions doc/whats_new/upcoming_changes/changed-models/30187.fix.rst
@@ -0,0 +1,2 @@
+- The `tags.input_tags.sparse` flag was corrected for a majority of estimators.
+  By :user:`Antoine Baker <antoinebaker>`
@@ -0,0 +1,4 @@
+- :func:`utils.estimator_checks.check_estimator_sparse_tag` ensures that
+  the estimator tag `input_tags.sparse` is consistent with its `fit`
+  method (accepting sparse input `X` or raising the appropriate error).
+  By :user:`Antoine Baker <antoinebaker>`
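The consistency this changelog entry describes can be illustrated with a small, library-free sketch. Everything in it is a stand-in: `InputTags`, `Tags`, `FakeSparse`, `DenseOnly`, `WronglyTagged` and `sparse_tag_is_consistent` are hypothetical names, not scikit-learn's implementation, and the real `check_estimator_sparse_tag` may verify more than this (for example the error that is raised). The sketch only shows the idea that the declared tag and the behaviour of `fit` on sparse input should agree.

```python
from dataclasses import dataclass, field


@dataclass
class InputTags:
    # Stand-in for the real input tags; only the flag discussed here.
    sparse: bool = False


@dataclass
class Tags:
    input_tags: InputTags = field(default_factory=InputTags)


class FakeSparse:
    """Hypothetical placeholder for a SciPy sparse container."""


class DenseOnly:
    def __sklearn_tags__(self):
        return Tags()  # sparse defaults to False

    def fit(self, X, y=None):
        if isinstance(X, FakeSparse):
            raise TypeError("Sparse data was passed, but dense data is required.")
        return self


class WronglyTagged(DenseOnly):
    def __sklearn_tags__(self):
        tags = super().__sklearn_tags__()
        tags.input_tags.sparse = True  # claims support that fit() does not have
        return tags


def sparse_tag_is_consistent(estimator):
    """Return True when the declared tag matches what fit() does."""
    claims_sparse = estimator.__sklearn_tags__().input_tags.sparse
    try:
        estimator.fit(FakeSparse())
        accepted = True
    except TypeError:
        accepted = False
    return claims_sparse == accepted
```

In this sketch `DenseOnly` passes the check because it declares no sparse support and rejects sparse input, while `WronglyTagged` fails it because the tag and `fit` disagree.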
11 changes: 6 additions & 5 deletions sklearn/calibration.py
@@ -28,11 +28,7 @@
 from .model_selection import LeaveOneOut, check_cv, cross_val_predict
 from .preprocessing import LabelEncoder, label_binarize
 from .svm import LinearSVC
-from .utils import (
-    _safe_indexing,
-    column_or_1d,
-    indexable,
-)
+from .utils import _safe_indexing, column_or_1d, get_tags, indexable
 from .utils._param_validation import (
     HasMethods,
     Hidden,
@@ -554,6 +550,11 @@ def get_metadata_routing(self):
         )
         return router

+    def __sklearn_tags__(self):
+        tags = super().__sklearn_tags__()
+        tags.input_tags.sparse = get_tags(self._get_estimator()).input_tags.sparse
+        return tags
+

 def _fit_classifier_calibrator_pair(
     estimator,
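The calibration change above makes a wrapping estimator report the sparse tag of the estimator it wraps. A minimal sketch of that delegation pattern follows, using stand-in classes only: `InputTags`, `Tags`, `DenseOnly`, `SparseFriendly` and `Wrapper` are hypothetical and do not reproduce scikit-learn's `get_tags` machinery.

```python
from dataclasses import dataclass, field


@dataclass
class InputTags:
    sparse: bool = False


@dataclass
class Tags:
    input_tags: InputTags = field(default_factory=InputTags)


class DenseOnly:
    def __sklearn_tags__(self):
        return Tags()


class SparseFriendly(DenseOnly):
    def __sklearn_tags__(self):
        tags = super().__sklearn_tags__()
        tags.input_tags.sparse = True
        return tags


class Wrapper(DenseOnly):
    """Hypothetical meta-estimator that forwards data to `estimator`."""

    def __init__(self, estimator):
        self.estimator = estimator

    def __sklearn_tags__(self):
        tags = super().__sklearn_tags__()
        # The wrapper can only accept sparse input if the wrapped estimator can.
        inner = self.estimator.__sklearn_tags__()
        tags.input_tags.sparse = inner.input_tags.sparse
        return tags
```

The same shape appears in several other files of this diff (bagging, feature selection, target transformation), each time reading the inner estimator's tag rather than hard-coding a value.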
1 change: 1 addition & 0 deletions sklearn/cluster/_affinity_propagation.py
@@ -483,6 +483,7 @@ def __init__(
     def __sklearn_tags__(self):
         tags = super().__sklearn_tags__()
         tags.input_tags.pairwise = self.affinity == "precomputed"
+        tags.input_tags.sparse = self.affinity != "precomputed"
         return tags

     @_fit_context(prefer_skip_nested_validation=True)
5 changes: 5 additions & 0 deletions sklearn/cluster/_bicluster.py
@@ -193,6 +193,11 @@ def _k_means(self, data, n_clusters):
         labels = model.labels_
         return centroid, labels

+    def __sklearn_tags__(self):
+        tags = super().__sklearn_tags__()
+        tags.input_tags.sparse = True
+        return tags
+

 class SpectralCoclustering(BaseSpectral):
     """Spectral Co-Clustering algorithm (Dhillon, 2001).
1 change: 1 addition & 0 deletions sklearn/cluster/_birch.py
@@ -742,4 +742,5 @@ def _global_clustering(self, X=None):
     def __sklearn_tags__(self):
         tags = super().__sklearn_tags__()
         tags.transformer_tags.preserves_dtype = ["float64", "float32"]
+        tags.input_tags.sparse = True
         return tags
1 change: 1 addition & 0 deletions sklearn/cluster/_bisect_k_means.py
@@ -538,5 +538,6 @@ def _predict_recursive(self, X, sample_weight, cluster_node):

     def __sklearn_tags__(self):
         tags = super().__sklearn_tags__()
+        tags.input_tags.sparse = True
         tags.transformer_tags.preserves_dtype = ["float64", "float32"]
         return tags
1 change: 1 addition & 0 deletions sklearn/cluster/_dbscan.py
@@ -473,4 +473,5 @@ def fit_predict(self, X, y=None, sample_weight=None):
     def __sklearn_tags__(self):
         tags = super().__sklearn_tags__()
         tags.input_tags.pairwise = self.metric == "precomputed"
+        tags.input_tags.sparse = True
         return tags
1 change: 1 addition & 0 deletions sklearn/cluster/_hdbscan/hdbscan.py
@@ -999,5 +999,6 @@ def dbscan_clustering(self, cut_distance, min_cluster_size=5):

     def __sklearn_tags__(self):
         tags = super().__sklearn_tags__()
+        tags.input_tags.sparse = True
         tags.input_tags.allow_nan = self.metric != "precomputed"
         return tags
5 changes: 5 additions & 0 deletions sklearn/cluster/_kmeans.py
@@ -1177,6 +1177,11 @@ def score(self, X, y=None, sample_weight=None):
         )
         return -scores

+    def __sklearn_tags__(self):
+        tags = super().__sklearn_tags__()
+        tags.input_tags.sparse = True
+        return tags
+

 class KMeans(_BaseKMeans):
     """K-Means clustering.
1 change: 1 addition & 0 deletions sklearn/cluster/_spectral.py
@@ -794,6 +794,7 @@ def fit_predict(self, X, y=None):

     def __sklearn_tags__(self):
         tags = super().__sklearn_tags__()
+        tags.input_tags.sparse = True
         tags.input_tags.pairwise = self.affinity in [
             "precomputed",
             "precomputed_nearest_neighbors",
16 changes: 16 additions & 0 deletions sklearn/compose/_column_transformer.py
@@ -29,6 +29,7 @@
     _get_output_config,
     _safe_set_output,
 )
+from ..utils._tags import get_tags
 from ..utils.metadata_routing import (
     MetadataRouter,
     MethodMapping,
@@ -1315,6 +1316,21 @@

         return router

+    def __sklearn_tags__(self):
+        tags = super().__sklearn_tags__()
+        try:
+            tags.input_tags.sparse = all(
+                get_tags(trans).input_tags.sparse
+                for name, trans, _ in self.transformers
+                if trans not in {"passthrough", "drop"}
+            )
+        except Exception:
+            # If `transformers` does not comply with our API (list of tuples)
+            # then it will fail. In this case, we assume that `sparse` is False
+            # but the parameter validation will raise an error during `fit`.
+            pass  # pragma: no cover
+        return tags
+

 def _check_X(X):
     """Use check_array only when necessary, e.g. on lists and other non-array-likes."""
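The column transformer hunk above combines the tags of several sub-estimators: the composite accepts sparse input only if every real transformer does, string placeholders are skipped, and a malformed `transformers` list falls back to `False`. The following is a library-free sketch of that aggregation. The helper `combined_sparse_tag` and the stand-in classes are hypothetical names introduced for illustration, not scikit-learn code.

```python
from dataclasses import dataclass, field


@dataclass
class InputTags:
    sparse: bool = False


@dataclass
class Tags:
    input_tags: InputTags = field(default_factory=InputTags)


class DenseOnly:
    def __sklearn_tags__(self):
        return Tags()


class SparseFriendly(DenseOnly):
    def __sklearn_tags__(self):
        tags = super().__sklearn_tags__()
        tags.input_tags.sparse = True
        return tags


def combined_sparse_tag(transformers):
    """Aggregate the sparse tag over (name, transformer, columns) triples."""
    try:
        return all(
            trans.__sklearn_tags__().input_tags.sparse
            for _name, trans, _columns in transformers
            # String placeholders carry no tags of their own, so skip them.
            if trans not in ("passthrough", "drop")
        )
    except Exception:
        # Malformed input (not a list of 3-tuples): report False here and
        # leave the detailed error to validation at fit time.
        return False
```

One consequence of using `all` is that a list containing only `"passthrough"` or `"drop"` entries yields `True` in this sketch, since `all` over an empty sequence is `True`.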
1 change: 1 addition & 0 deletions sklearn/compose/_target.py
@@ -348,6 +348,7 @@ def __sklearn_tags__(self):
         regressor = self._get_regressor()
         tags = super().__sklearn_tags__()
         tags.regressor_tags.poor_score = True
+        tags.input_tags.sparse = get_tags(regressor).input_tags.sparse
         tags.target_tags.multi_output = get_tags(regressor).target_tags.multi_output
         return tags

6 changes: 6 additions & 0 deletions sklearn/decomposition/_incremental_pca.py
@@ -418,3 +418,9 @@ def transform(self, X):
             return np.vstack(output)
         else:
             return super().transform(X)
+
+    def __sklearn_tags__(self):
+        tags = super().__sklearn_tags__()
+        # Beware that fit accepts sparse data but partial_fit doesn't
+        tags.input_tags.sparse = True
+        return tags
1 change: 1 addition & 0 deletions sklearn/decomposition/_kernel_pca.py
@@ -566,6 +566,7 @@ def inverse_transform(self, X):

     def __sklearn_tags__(self):
         tags = super().__sklearn_tags__()
+        tags.input_tags.sparse = True
         tags.transformer_tags.preserves_dtype = ["float64", "float32"]
         tags.input_tags.pairwise = self.kernel == "precomputed"
         return tags
1 change: 1 addition & 0 deletions sklearn/decomposition/_lda.py
@@ -549,6 +549,7 @@ def _em_step(self, X, total_samples, batch_update, parallel=None):
     def __sklearn_tags__(self):
         tags = super().__sklearn_tags__()
         tags.input_tags.positive_only = True
+        tags.input_tags.sparse = True
         tags.transformer_tags.preserves_dtype = ["float32", "float64"]
         return tags

1 change: 1 addition & 0 deletions sklearn/decomposition/_nmf.py
@@ -1331,6 +1331,7 @@ def _n_features_out(self):
     def __sklearn_tags__(self):
         tags = super().__sklearn_tags__()
         tags.input_tags.positive_only = True
+        tags.input_tags.sparse = True
         tags.transformer_tags.preserves_dtype = ["float64", "float32"]
         return tags

5 changes: 5 additions & 0 deletions sklearn/decomposition/_pca.py
@@ -851,4 +851,9 @@ def __sklearn_tags__(self):
         tags = super().__sklearn_tags__()
         tags.transformer_tags.preserves_dtype = ["float64", "float32"]
         tags.array_api_support = True
+        tags.input_tags.sparse = self.svd_solver in (
+            "auto",
+            "arpack",
+            "covariance_eigh",
+        )
         return tags
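In the PCA hunk above the tag is computed from a constructor parameter instead of being a constant, so it reflects the current value of `svd_solver` each time the tags are requested. A small stand-in sketch of that pattern follows. `SolverDependent` and the tag dataclasses are hypothetical; only the three solver names in the tuple come from the diff.

```python
from dataclasses import dataclass, field


@dataclass
class InputTags:
    sparse: bool = False


@dataclass
class Tags:
    input_tags: InputTags = field(default_factory=InputTags)


class SolverDependent:
    """Hypothetical estimator whose sparse support depends on a parameter."""

    # Solver names taken from the PCA hunk in this diff.
    SPARSE_SOLVERS = ("auto", "arpack", "covariance_eigh")

    def __init__(self, svd_solver="auto"):
        self.svd_solver = svd_solver

    def __sklearn_tags__(self):
        tags = Tags()
        # Evaluated on every call, so changing the parameter changes the tag.
        tags.input_tags.sparse = self.svd_solver in self.SPARSE_SOLVERS
        return tags
```

Because the tag is recomputed on each call, setting `svd_solver` to a value outside the tuple after construction flips the tag to `False` without any further bookkeeping.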
1 change: 1 addition & 0 deletions sklearn/decomposition/_truncated_svd.py
@@ -312,6 +312,7 @@ def inverse_transform(self, X):

     def __sklearn_tags__(self):
         tags = super().__sklearn_tags__()
+        tags.input_tags.sparse = True
         tags.transformer_tags.preserves_dtype = ["float64", "float32"]
         return tags

2 changes: 2 additions & 0 deletions sklearn/dummy.py
@@ -423,6 +423,7 @@ def predict_log_proba(self, X):

     def __sklearn_tags__(self):
         tags = super().__sklearn_tags__()
+        tags.input_tags.sparse = True
         tags.classifier_tags.poor_score = True
         tags.no_validation = True
         return tags
@@ -662,6 +663,7 @@ def predict(self, X, return_std=False):

     def __sklearn_tags__(self):
         tags = super().__sklearn_tags__()
+        tags.input_tags.sparse = True
         tags.regressor_tags.poor_score = True
         tags.no_validation = True
         return tags
1 change: 1 addition & 0 deletions sklearn/ensemble/_bagging.py
@@ -627,6 +627,7 @@ def _get_estimator(self):

     def __sklearn_tags__(self):
         tags = super().__sklearn_tags__()
+        tags.input_tags.sparse = get_tags(self._get_estimator()).input_tags.sparse
         tags.input_tags.allow_nan = get_tags(self._get_estimator()).input_tags.allow_nan
         return tags

13 changes: 8 additions & 5 deletions sklearn/ensemble/_base.py
@@ -288,14 +288,17 @@ def get_params(self, deep=True):
     def __sklearn_tags__(self):
         tags = super().__sklearn_tags__()
         try:
-            allow_nan = all(
+            tags.input_tags.allow_nan = all(
                 get_tags(est[1]).input_tags.allow_nan if est[1] != "drop" else True
                 for est in self.estimators
             )
+            tags.input_tags.sparse = all(
+                get_tags(est[1]).input_tags.sparse if est[1] != "drop" else True
+                for est in self.estimators
+            )
         except Exception:
             # If `estimators` does not comply with our API (list of tuples) then it will
-            # fail. In this case, we assume that `allow_nan` is False but the parameter
-            # validation will raise an error during `fit`.
-            allow_nan = False
-        tags.input_tags.allow_nan = allow_nan
+            # fail. In this case, we assume that `allow_nan` and `sparse` are False but
+            # the parameter validation will raise an error during `fit`.
+            pass  # pragma: no cover
         return tags
11 changes: 11 additions & 0 deletions sklearn/ensemble/_forest.py
@@ -1002,6 +1002,7 @@ def predict_log_proba(self, X):
     def __sklearn_tags__(self):
         tags = super().__sklearn_tags__()
         tags.classifier_tags.multi_label = True
+        tags.input_tags.sparse = True
         return tags


@@ -1165,6 +1166,11 @@ def _compute_partial_dependence_recursion(self, grid, target_features):

         return averaged_predictions

+    def __sklearn_tags__(self):
+        tags = super().__sklearn_tags__()
+        tags.input_tags.sparse = True
+        return tags
+

 class RandomForestClassifier(ForestClassifier):
     """
@@ -2987,3 +2993,8 @@ def transform(self, X):
         """
         check_is_fitted(self)
         return self.one_hot_encoder_.transform(self.apply(X))
+
+    def __sklearn_tags__(self):
+        tags = super().__sklearn_tags__()
+        tags.input_tags.sparse = True
+        return tags
5 changes: 5 additions & 0 deletions sklearn/ensemble/_gb.py
@@ -1117,6 +1117,11 @@ def apply(self, X):

         return leaves

+    def __sklearn_tags__(self):
+        tags = super().__sklearn_tags__()
+        tags.input_tags.sparse = True
+        return tags
+

 class GradientBoostingClassifier(ClassifierMixin, BaseGradientBoosting):
     """Gradient Boosting for classification.
5 changes: 5 additions & 0 deletions sklearn/ensemble/_weight_boosting.py
@@ -312,6 +312,11 @@ def feature_importances_(self):
                 "feature_importances_ attribute"
             ) from e

+    def __sklearn_tags__(self):
+        tags = super().__sklearn_tags__()
+        tags.input_tags.sparse = True
+        return tags
+

 def _samme_proba(estimator, n_classes, X):
     """Calculate algorithm 4, step 2, equation c) of Zhu et al [1].
1 change: 1 addition & 0 deletions sklearn/feature_selection/_from_model.py
@@ -501,5 +501,6 @@ def get_metadata_routing(self):

     def __sklearn_tags__(self):
         tags = super().__sklearn_tags__()
+        tags.input_tags.sparse = get_tags(self.estimator).input_tags.sparse
         tags.input_tags.allow_nan = get_tags(self.estimator).input_tags.allow_nan
         return tags
1 change: 1 addition & 0 deletions sklearn/feature_selection/_rfe.py
@@ -521,6 +521,7 @@ def __sklearn_tags__(self):
         if tags.regressor_tags is not None:
             tags.regressor_tags.poor_score = True
         tags.target_tags.required = True
+        tags.input_tags.sparse = sub_estimator_tags.input_tags.sparse
         tags.input_tags.allow_nan = sub_estimator_tags.input_tags.allow_nan
         return tags

1 change: 1 addition & 0 deletions sklearn/feature_selection/_sequential.py
@@ -329,6 +329,7 @@ def _get_support_mask(self):
     def __sklearn_tags__(self):
         tags = super().__sklearn_tags__()
         tags.input_tags.allow_nan = get_tags(self.estimator).input_tags.allow_nan
+        tags.input_tags.sparse = get_tags(self.estimator).input_tags.sparse
         return tags

     def get_metadata_routing(self):
1 change: 1 addition & 0 deletions sklearn/feature_selection/_univariate_selection.py
@@ -581,6 +581,7 @@ def _check_params(self, X, y):
     def __sklearn_tags__(self):
         tags = super().__sklearn_tags__()
         tags.target_tags.required = True
+        tags.input_tags.sparse = True
         return tags


1 change: 1 addition & 0 deletions sklearn/feature_selection/_variance_threshold.py
@@ -137,4 +137,5 @@ def _get_support_mask(self):
     def __sklearn_tags__(self):
         tags = super().__sklearn_tags__()
         tags.input_tags.allow_nan = True
+        tags.input_tags.sparse = True
         return tags
2 changes: 2 additions & 0 deletions sklearn/impute/_base.py
@@ -739,6 +739,7 @@ def inverse_transform(self, X):

     def __sklearn_tags__(self):
         tags = super().__sklearn_tags__()
+        tags.input_tags.sparse = True
         tags.input_tags.allow_nan = is_pandas_na(self.missing_values) or is_scalar_nan(
             self.missing_values
         )
@@ -1130,5 +1131,6 @@ def __sklearn_tags__(self):
         tags = super().__sklearn_tags__()
         tags.input_tags.allow_nan = True
         tags.input_tags.string = True
+        tags.input_tags.sparse = True
         tags.transformer_tags.preserves_dtype = []
         return tags
8 changes: 8 additions & 0 deletions sklearn/kernel_approximation.py
@@ -235,6 +235,11 @@ def transform(self, X):

         return data_sketch

+    def __sklearn_tags__(self):
+        tags = super().__sklearn_tags__()
+        tags.input_tags.sparse = True
+        return tags
+

 class RBFSampler(ClassNamePrefixFeaturesOutMixin, TransformerMixin, BaseEstimator):
     """Approximate a RBF kernel feature map using random Fourier features.
@@ -404,6 +409,7 @@ def transform(self, X):

     def __sklearn_tags__(self):
         tags = super().__sklearn_tags__()
+        tags.input_tags.sparse = True
         tags.transformer_tags.preserves_dtype = ["float64", "float32"]
         return tags

@@ -826,6 +832,7 @@ def __sklearn_tags__(self):
         tags = super().__sklearn_tags__()
         tags.requires_fit = False
         tags.input_tags.positive_only = True
+        tags.input_tags.sparse = True
         return tags


@@ -1094,5 +1101,6 @@ def _get_kernel_params(self):

     def __sklearn_tags__(self):
         tags = super().__sklearn_tags__()
+        tags.input_tags.sparse = True
         tags.transformer_tags.preserves_dtype = ["float64", "float32"]
         return tags
1 change: 1 addition & 0 deletions sklearn/kernel_ridge.py
@@ -169,6 +169,7 @@ def _get_kernel(self, X, Y=None):

     def __sklearn_tags__(self):
         tags = super().__sklearn_tags__()
+        tags.input_tags.sparse = True
         tags.input_tags.pairwise = self.kernel == "precomputed"
         return tags

5 changes: 5 additions & 0 deletions sklearn/linear_model/_base.py
@@ -687,6 +687,11 @@ def rmatvec(b):
         self._set_intercept(X_offset, y_offset, X_scale)
         return self

+    def __sklearn_tags__(self):
+        tags = super().__sklearn_tags__()
+        tags.input_tags.sparse = not self.positive
+        return tags
+

 def _check_precomputed_gram_matrix(
     X, precompute, X_offset, X_scale, rtol=None, atol=1e-5

Review comment on the line `tags.input_tags.sparse = not self.positive`:

My team changed to scikit-learn v1.6.1 this week. We had v1.5.1 before. Our code crashes in this exact line with the error "Unexpected <class 'AttributeError'>. 'LinearRegression' object has no attribute 'positive'".

We cannot deploy in production because of this. I am desperate enough to come here to ask for help. I do not understand why it would complain that the attribute does not exist given that we were using v1.5.1 before and the attribute has existed for 4 years now.

Reply from a member:

@ItsIronOxide feel free to open a new issue with a reproducer, and ping me. Happy to have a look and help out.
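The thread above does not establish what caused the reported `AttributeError`, so the following is only a general illustration, not a diagnosis. It sketches two ordinary Python situations in which a tag method that reads `self.positive` can fail even though the class defines that parameter: a subclass whose `__init__` never sets it, and an instance rebuilt from saved state that lacks it. All names (`TinyRegressor`, `CustomRegressor`, `make_stale_instance`, `sparse_tag_or_error`) are hypothetical stand-ins.

```python
class TinyRegressor:
    """Hypothetical stand-in with a `positive` constructor parameter."""

    def __init__(self, positive=False):
        self.positive = positive

    def sparse_tag(self):
        # Mirrors the shape of the line under review: reads an instance attribute.
        return not self.positive


class CustomRegressor(TinyRegressor):
    def __init__(self, alpha=1.0):
        # super().__init__() is not called, so `positive` is never set.
        self.alpha = alpha


def make_stale_instance():
    """Rebuild an instance from state that has no `positive` entry."""
    obj = TinyRegressor.__new__(TinyRegressor)  # bypasses __init__
    obj.__dict__.update({"fit_intercept": True})
    return obj


def sparse_tag_or_error(estimator):
    """Return the tag value, or a short message if the attribute is missing."""
    try:
        return estimator.sparse_tag()
    except AttributeError as exc:
        return f"AttributeError: {exc}"
```

Whether either situation applies to the report is unknown; a reproducer, as requested in the reply, would be needed to tell.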