ENH Adds n_feature_in_ checking to cluster #18727
LGTM, just a nitpick.
While finishing these up, I am growing more concerned with validating twice. Should all functions that call
Or we set the context manager
I am not sure how that would work. Sometimes only the caller knows that the check has already been done. We would need to dig a hole in the API of predict / transform and similar methods to pass this flag, unless we use a context manager as discussed here: #18691 (comment)
Edit: I saw this reply to your comment after writing my own...
with config_context(assume_finite=True):
    return pairwise_distances_argmin(X, self.cluster_centers_)
As a quick benchmark:
from sklearn.cluster import AffinityPropagation
from sklearn.datasets import make_classification
X, _ = make_classification(n_features=10_000, n_samples=5_000, random_state=42)
aff_prop = AffinityPropagation(random_state=42)
aff_prop.fit(X)
# this PR
%timeit aff_prop.predict(X)
# 182 ms ± 2.23 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
# master
%timeit aff_prop.predict(X)
# 254 ms ± 5.4 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
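The pattern in the diff above (validate once at the entry point, then wrap the inner call in a context manager so the nested check is skipped) can be sketched in plain Python. This is a toy illustration, not scikit-learn's implementation: `toy_config`, `toy_stats`, `toy_config_context`, `toy_check_finite`, `toy_argmin_distance`, and `toy_predict` are all hypothetical names, and the finiteness scan stands in for the real input validation.

```python
import math
from contextlib import contextmanager

# Hypothetical global configuration and a counter of full scans performed.
toy_config = {"assume_finite": False}
toy_stats = {"scans": 0}


@contextmanager
def toy_config_context(assume_finite):
    """Temporarily override the flag and restore the old value on exit."""
    previous = toy_config["assume_finite"]
    toy_config["assume_finite"] = assume_finite
    try:
        yield
    finally:
        toy_config["assume_finite"] = previous


def toy_check_finite(rows):
    """Scan every value unless the caller declared the data already checked."""
    if toy_config["assume_finite"]:
        return rows
    toy_stats["scans"] += 1
    for row in rows:
        for value in row:
            if not math.isfinite(value):
                raise ValueError("Input contains NaN or infinity.")
    return rows


def toy_argmin_distance(rows, centers):
    """Inner helper that validates its own input, as a public function would."""
    rows = toy_check_finite(rows)
    return [
        min(range(len(centers)), key=lambda k: math.dist(row, centers[k]))
        for row in rows
    ]


def toy_predict(rows, centers):
    """Validate once at the entry point, then skip the nested check."""
    rows = toy_check_finite(rows)
    with toy_config_context(assume_finite=True):
        return toy_argmin_distance(rows, centers)
```

With this arrangement the caller does not need an extra flag in the signature of the inner function: the only place that knows the data was already checked is `toy_predict`, and it communicates that through the context manager.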
The failure of the recently merged
@@ -38,10 +37,7 @@ def transform(self, X):
        """
        check_is_fitted(self)

        X = check_array(X)
        if len(self.labels_) != X.shape[1]:
We're assuming that the invariant len(self.labels_) == X_train.shape[1] is enforced elsewhere, right? I guess this was only a workaround for not having n_features_in_ anyway.
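To make the point about n_features_in_ concrete, here is a minimal sketch, assuming a hypothetical `ToyEstimator` that keeps one label per feature. The feature count is recorded at fit time and compared at transform time, so a separate `len(self.labels_)` comparison adds nothing. The `_validate` helper and its `reset` argument are stand-ins chosen to resemble the `reset=False` call in this PR; this is not scikit-learn's code, and the error wording is illustrative.

```python
class ToyEstimator:
    """Hypothetical estimator; only the feature-count bookkeeping is shown."""

    def _validate(self, rows, reset):
        n_features = len(rows[0])
        if reset:
            # fit path: remember the width of the training data
            self.n_features_in_ = n_features
        elif n_features != self.n_features_in_:
            # transform/predict path: reject inputs of a different width
            raise ValueError(
                f"X has {n_features} features, but ToyEstimator is expecting "
                f"{self.n_features_in_} features as input."
            )
        return rows

    def fit(self, rows):
        rows = self._validate(rows, reset=True)
        # one label per feature, so len(labels_) == n_features_in_ by construction
        self.labels_ = [j % 2 for j in range(self.n_features_in_)]
        return self

    def transform(self, rows):
        rows = self._validate(rows, reset=False)
        groups = sorted(set(self.labels_))
        # pool (sum) the columns that share a label
        return [
            [
                sum(v for v, lab in zip(row, self.labels_) if lab == g)
                for g in groups
            ]
            for row in rows
        ]
```

Because `labels_` is built from `n_features_in_` inside `fit`, the invariant asked about above holds by construction, and the single check in `_validate` covers the mismatch case.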
            f"Incorrect number of features. Got {n_features} features, "
            f"expected {expected_n_features}.")

        X = self._validate_data(X, accept_sparse='csr', reset=False,
Do we still need the function now? Fine with me either way.
Continues #18514