TST Add a check for dtype preservation on Regressors' predictions #22763
Conversation
How can I hook this check in so that it runs when the default tests are executed? After this PR gets merged, a new issue has to be raised to hook in all regressors that have this property.
Thank you for tackling this, @Diwakar-Gupta.
After this PR gets merged new issue has to be raised to hook all regressors which holds this property.
I think we need to hook it up now so we actually run the code on an estimator that has the property. Hooking it up is very similar to #21355, except that we list the regressors that fail so they can be ignored, and call check_regressor_preserve_dtypes directly.
After #22525, I think BayesianRidge should have the property.
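As a rough illustration of the "list the failing regressors and call the check directly" idea, the sketch below runs one check over a set of estimators and skips the names known to fail. Everything in it is hypothetical: `run_check_with_known_failures`, the stand-in "estimators" (plain callables) and the stand-in check are not part of scikit-learn, and the real hook-up in the test suite may look different.

```python
def run_check_with_known_failures(check, estimators, known_failures=frozenset()):
    """Call check(name, estimator) for each estimator, skipping known failures.

    Hypothetical helper for illustration only; not a scikit-learn function.
    """
    outcomes = {}
    for name, estimator in estimators.items():
        if name in known_failures:
            # Explicitly listed as failing for now, so the check is not run.
            outcomes[name] = "skipped"
            continue
        try:
            check(name, estimator)
        except AssertionError as exc:
            outcomes[name] = f"failed: {exc}"
        else:
            outcomes[name] = "passed"
    return outcomes


def check_type_is_preserved(name, predict):
    # Stand-in check: the output type must match the input type.
    value = 3
    result = predict(value)
    assert type(result) is type(value), f"{name} returned {type(result).__name__}"


# Stand-in "estimators": plain callables instead of real regressors.
estimators = {
    "Preserving": lambda x: x * 2,        # int in, int out
    "Upcasting": lambda x: float(x) * 2,  # int in, float out
}

outcomes = run_check_with_known_failures(
    check_type_is_preserved, estimators, known_failures={"Upcasting"}
)
```

Keeping the failing names in an explicit set means the check still runs on every estimator that already has the property, and each entry can be removed as the corresponding estimator is fixed.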
sklearn/utils/estimator_checks.py
Outdated
X, y = make_regression(random_state=42, n_targets=5, n_samples=100, n_features=3)
X = StandardScaler().fit_transform(X)
See:
scikit-learn/sklearn/utils/estimator_checks.py
Lines 2774 to 2777 in 5afd5e1
@ignore_warnings(category=FutureWarning)
def check_regressors_train(
    name, regressor_orig, readonly_memmap=False, X_dtype=np.float64
):
for a demonstration of generating a regression dataset.
After adding
to class
Thanks for the PR @Diwakar-Gupta. Please find below how we expect this new check to be integrated in the scikit-learn test suite to run automatically for any estimator that has the preserves_dtype tag defined.
regressors check for dtype preservation
LGTM but before merging I would like to make sure that we have at least one estimator which has the relevant tags to make sure that this new test passes on int. @thomasjpfan mentioned previously that it might be the case for
Can you please try if this holds by setting the |
The documentation of the https://scikit-learn.org/stable/developers/develop.html#estimator-tags The source of this page is in: |
Reference Issues/PRs
Fixes #22682
What does this implement/fix? Explain your changes.
The check_regressor_preserve_dtypes function is added in estimator_checks.py. It looks for regressor models with the preserves_dtype tag and ensures that the label dtype is preserved when predict is called.
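To make the intent of such a check concrete, here is a minimal sketch of what "dtype preservation on predictions" could mean. It is not the code added by this PR: `ToyLeastSquares`, `UpcastingRegressor` and `check_predictions_preserve_dtype` are hypothetical names, the data is generated with plain NumPy rather than make_regression, and the actual check in estimator_checks.py may differ.

```python
import numpy as np


class ToyLeastSquares:
    """Hypothetical regressor for illustration only (not a scikit-learn class)."""

    def fit(self, X, y):
        coef = np.linalg.lstsq(X, y, rcond=None)[0]
        # Cast explicitly so the coefficients share the dtype of the input.
        self.coef_ = coef.astype(X.dtype, copy=False)
        return self

    def predict(self, X):
        return X @ self.coef_


class UpcastingRegressor(ToyLeastSquares):
    """Hypothetical regressor that always returns float64 predictions."""

    def predict(self, X):
        return super().predict(X).astype(np.float64)


def check_predictions_preserve_dtype(regressor, dtypes=(np.float32, np.float64)):
    """Fit and predict once per dtype and report whether the dtype is preserved."""
    rng = np.random.RandomState(42)
    X = rng.normal(size=(100, 3))
    y = X @ np.array([1.0, -2.0, 0.5]) + 0.1 * rng.normal(size=100)

    results = {}
    for dtype in dtypes:
        X_cast = X.astype(dtype)
        y_cast = y.astype(dtype)
        y_pred = regressor.fit(X_cast, y_cast).predict(X_cast)
        results[np.dtype(dtype).name] = bool(y_pred.dtype == dtype)
    return results


preserving = check_predictions_preserve_dtype(ToyLeastSquares())
upcasting = check_predictions_preserve_dtype(UpcastingRegressor())
```

In this sketch the first regressor reports True for both dtypes, while the second reports False for float32 because its predictions are always upcast to float64.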