Implement some methods almost compatible with Scikit-learn private methods. #952
optuna/integration/sklearn.py
Outdated
)


def _num_samples(x):
Original implementation is here.
How closely should I follow the original implementation? (Is this version too simple?)
It basically seems good to me to keep the code simple.
This implementation removes the handling of exceptional cases, so we may have bug reports about them.
So, please add a link to the original implementation (including the commit id) as a comment.
Also, I'm not familiar with Dask dataframes, but I think we need the following check for Dask users.
https://github.com/scikit-learn/scikit-learn/blob/master/sklearn/utils/validation.py#L155-L158
Thank you for the suggestion. 🙏
I added the specific check for a dask dataframe in f8f21a5.
fit_params_validated: Dict = {}
for key, value in fit_params.items():
    if (
        not _is_arraylike(value) or
Original implementation is here.
Currently, scikit-learn does not accept non-iterable inputs and this line is for keeping backward compatibility.
scikit-learn/scikit-learn#15805
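As an illustration of the pass-through rule discussed here, the following is a minimal sketch and not the code from this PR. It assumes plain Python sequences rather than NumPy arrays, and the names is_arraylike and check_fit_params are hypothetical stand-ins for the private helpers mentioned in the thread. Values that are not array-like, or whose length differs from the number of samples, are passed through unchanged; everything else is indexed.

```python
from typing import Any, Dict, List, Sequence


def is_arraylike(x: Any) -> bool:
    # Hypothetical stand-in for _is_arraylike: a simple attribute test.
    return hasattr(x, "__len__") or hasattr(x, "shape") or hasattr(x, "__array__")


def check_fit_params(
    X: Sequence[Any], fit_params: Dict[str, Any], indices: List[int]
) -> Dict[str, Any]:
    # Assumption: every array-like value here supports len() and integer indexing.
    validated: Dict[str, Any] = {}
    for key, value in fit_params.items():
        if not is_arraylike(value) or len(value) != len(X):
            # Non-indexable or length-mismatched values pass through unchanged
            # (kept for backward compatibility).
            validated[key] = value
        else:
            # Per-sample values are restricted to the requested rows.
            validated[key] = [value[i] for i in indices]
    return validated


X = [[0.0], [1.0], [2.0], [3.0]]
params = {"sample_weight": [0.1, 0.2, 0.3, 0.4], "verbose": True, "classes": [0, 1]}
subset = check_fit_params(X, params, [0, 2])
```

In this sketch only sample_weight has one entry per sample, so it is the only value that gets sliced; verbose and classes are returned as they were given.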
Please leave a link to the original implementation, including the commit id, in a comment for future development.
Thank you for your PR. I confirmed that examples/optuna_search_cv_simple.py worked successfully with this implementation using scikit-learn v0.20.4, 0.21.3, and 0.22.1.
@@ -18,7 +17,6 @@
from sklearn.utils import check_random_state
from sklearn.utils.metaestimators import _safe_split
_safe_split is also a private method of sklearn. I think we can work on it in a new PR because it is not related to #881.
optuna/integration/sklearn.py
Outdated
):
    fit_params_validated[key] = value
else:
    fit_params_validated[key] = value
The original code here applies _make_indexable to value. Are there any drawbacks to omitting it?
Sorry, you're right. I missed it.
I added _make_indexable in 2be88c7.
Co-Authored-By: Toshihiko Yanase <toshihiko.yanase@gmail.com>
optuna/integration/sklearn.py
Outdated
# NOTE For dask dataframes
# https://github.com/scikit-learn/scikit-learn/blob/ \
# 8caa93889f85254fc3ca84caa0a24a1640eebdd1/sklearn/utils/validation.py#L155-L158
if hasattr(x, 'shape') and x.shape is not None:
mypy throws the following errors. (https://app.circleci.com/jobs/github/optuna/optuna/29090)
#!/bin/bash -eo pipefail
. venv/bin/activate
mypy --disallow-untyped-defs --ignore-missing-imports .
optuna/integration/sklearn.py:134: error: Item "List[Any]" of "Union[List[Any], Any, Any]" has no attribute "shape"
optuna/integration/sklearn.py:135: error: Item "List[Any]" of "Union[List[Any], Any, Any]" has no attribute "shape"
optuna/integration/sklearn.py:136: error: Item "List[Any]" of "Union[List[Any], Any, Any]" has no attribute "shape"
optuna/integration/sklearn.py:136: error: Incompatible return value type (got "Integral", expected "int")
Found 4 errors in 1 file (checked 121 source files)
Exited with code exit status 1
ArrayLikeType is defined here and I don't understand why these errors occur. 🤕
(ArrayLikeType = Union[List, np.ndarray, pd.Series])
I think mypy cannot infer the type based on hasattr.
How about using getattr? It is suggested here.
Example:
x_shape = getattr(x, 'shape', None)
if x_shape is not None:
    if isinstance(x_shape[0], Integral):
        return int(x_shape[0])
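For reference, the getattr suggestion can be fleshed out into a complete function. This is a minimal sketch under assumptions, not the code merged in this PR: the name num_samples and the two test classes are hypothetical, and the fallback to len(x) is assumed from the Dask note quoted earlier in the thread. Reading the attribute once through getattr avoids accessing x.shape on a Union member that lacks it, and wrapping the result in int() addresses the Integral versus int return-type complaint.

```python
from numbers import Integral
from typing import Any


def num_samples(x: Any) -> int:
    # Read the attribute once; None means "no usable shape".
    x_shape = getattr(x, "shape", None)
    if x_shape is not None:
        # Some lazy containers may expose a non-numeric first dimension,
        # so only trust shape[0] when it is an actual integer.
        if isinstance(x_shape[0], Integral):
            return int(x_shape[0])
    # Fall back to len() for lists and for objects with a lazy shape.
    return len(x)


class FakeArray:
    # Hypothetical array-like with a concrete shape and no __len__.
    shape = (3, 2)


class LazyFrame:
    # Hypothetical container whose first dimension is not yet computed.
    shape = (object(), 2)

    def __len__(self) -> int:
        return 5


counts = (num_samples([1, 2]), num_samples(FakeArray()), num_samples(LazyFrame()))
```

The sketch does not handle zero-dimensional shapes or objects that support neither shape nor len(); those cases would need the extra checks that the original implementation carries.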
It works. Thank you!
I investigated the mypy error and I think I can find a workaround.
optuna/integration/sklearn.py
Outdated
# NOTE Original implementation:
# https://github.com/scikit-learn/scikit-learn/blob/ \
# 8caa93889f85254fc3ca84caa0a24a1640eebdd1/sklearn/utils/validation.py#L217-L234
# It removed the check if an input is scipy sparse matrix
Sorry, but I couldn't get the point of this comment.
When I compared the original code with this code, I understood that this version omits the conversion from a scipy sparse matrix to CSR. Could you mention that explicitly and add the reason?
Co-Authored-By: Toshihiko Yanase <toshihiko.yanase@gmail.com>
Codecov Report
@@ Coverage Diff @@
## master #952 +/- ##
==========================================
+ Coverage 90.15% 90.17% +0.01%
==========================================
Files 112 114 +2
Lines 9306 9548 +242
==========================================
+ Hits 8390 8610 +220
- Misses 916 938 +22
@toshihikoyanase Could you please take a look? 🙇 For old fashion type hints, I'd like to create a follow-up PR
LGTM. Thank you for your update.
For old fashion type hints, I'd like to create a follow-up PR
since I think it would be better to update all type hints in sklearn.py and test_sklearn.py at the same time. (but it should be done in another PR)
That makes sense. Let's update the type hints in a new PR.
@Y-oHr-N This PR will introduce some private methods of scikit-learn to
- from sklearn.utils import safe_indexing as sklearn_safe_indexing
+ if sklearn_version >= "0.22":
+     from sklearn.utils import _safe_indexing as sklearn_safe_indexing
+ else:
+     from sklearn.utils import safe_indexing as sklearn_safe_indexing
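One caveat about the version check in the diff above: if sklearn_version is a plain string, the comparison is lexicographic, which can give the wrong answer once a version component has more digits than the threshold. The sketch below is one possible way to compare numerically using only the standard library. The helper names major_minor and has_private_safe_indexing are hypothetical and are not part of Optuna or scikit-learn.

```python
import re
from typing import Tuple


def major_minor(version: str) -> Tuple[int, int]:
    # Extract the leading "MAJOR.MINOR" digits and ignore any suffix
    # such as ".1", ".dev0" or "rc1".
    match = re.match(r"(\d+)\.(\d+)", version)
    if match is None:
        raise ValueError("Unrecognized version string: {!r}".format(version))
    return int(match.group(1)), int(match.group(2))


def has_private_safe_indexing(version: str) -> bool:
    # Assumption taken from the diff above: _safe_indexing exists from 0.22 onward.
    return major_minor(version) >= (0, 22)


# A made-up version number, used only to show where string comparison diverges.
string_result = "0.100.0" >= "0.22"
tuple_result = has_private_safe_indexing("0.100.0")
```

For the versions tested in this thread (0.20.4, 0.21.3, and 0.22.1) both approaches agree, so this only matters as a guard against future version strings.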
@Y-oHr-N Thank you for pointing it out! It is deprecated and will be removed in 0.24. IMO, we can work on it in a new PR because we still have some time.
https://github.com/scikit-learn/scikit-learn/blob/0.22.X/sklearn/utils/__init__.py#L292-L294
@deprecated("safe_indexing is deprecated in version "
            "0.22 and will be removed in version 0.24.")
def safe_indexing(X, indices, axis=0):
@toshihikoyanase, You're right. There are no other comments. LGTM.
Thanks, LGTM!
Not entirely sure about the PR labeling but I added one for now. Let me just modify the title to match our release note format.
This PR is a follow-up to #881.
I implemented some methods to reduce the dependency on scikit-learn private methods.
The methods are basically defined to be compatible with scikit-learn's, but some points are different.
(The name of this branch should be sklearn-privates... :innocent:)