FIX `param_distribution` param of `HalvingRandomSearchCV` accepts list of dicts by StefanieSenger · Pull Request #26893 · scikit-learn/scikit-learn · GitHub

FIX param_distribution param of HalvingRandomSearchCV accepts list of dicts #26893


Merged · 8 commits · Aug 7, 2023


StefanieSenger
Contributor

What does this implement/fix? Explain your changes.

Closes #26885
HalvingRandomSearchCV now accepts a list of dicts for its param_distributions parameter; the documentation is updated accordingly.

I also tried to implement a test, using test_random_search_cv_results as a template, as you suggested @glemaitre, but I ran into several problems that I could not resolve.

The template calls two helper functions (check_cv_results_array_types and check_cv_results_keys) that check for and compare the occurrence of params. But those might not always be present (for example, 'param_degree' is only a key in cv_results for the poly kernel, not for rbf). The halving searches' cv_results also has two additional keys compared to the other searches these helpers are used for: "iter" and "n_resources".

I cannot see a way to use the asserts at the end of the template test, because HalvingGridSearchCV will mask part of the candidates as part of the process. So I assume checking for this is not going to work.

I determined the value n_proportion = 6 by looking at cv_results[key].shape, which is surely the wrong way around.

I have commented the test out and hope for your advice. At the moment the test fails with a KeyError (param_degree) in one of the util functions.

I could write a much simpler test that only captures passing param_distributions as a list of dicts.
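
For context, scikit-learn's documentation for randomized search describes a list of dicts as being sampled in two steps: first one dict is picked uniformly at random, then each parameter is drawn from that dict. The sketch below mimics that behaviour in plain Python. It is not the library's code: sample_candidate is a hypothetical helper, and the continuous distributions are replaced by short lists so that only the standard library is needed.

```python
import random


def sample_candidate(param_distributions, rng):
    """Hypothetical helper: draw one candidate from a dict or a list of dicts."""
    if isinstance(param_distributions, dict):
        param_distributions = [param_distributions]
    # Step 1: pick one of the dicts uniformly at random.
    chosen = rng.choice(param_distributions)
    # Step 2: draw every parameter from the chosen dict only.
    candidate = {}
    for name, values in sorted(chosen.items()):
        if hasattr(values, "rvs"):
            # scipy-style frozen distribution (left unseeded in this sketch)
            candidate[name] = values.rvs()
        else:
            candidate[name] = rng.choice(list(values))
    return candidate


# Illustrative search space: lists stand in for continuous distributions.
params = [
    {"kernel": ["rbf"], "C": [0.1, 1.0, 10.0], "gamma": [0.01, 0.1]},
    {"kernel": ["poly"], "degree": [2, 3]},
]

rng = random.Random(0)
candidates = [sample_candidate(params, rng) for _ in range(5)]
for candidate in candidates:
    print(candidate)
```

Because each candidate comes from exactly one dict, a "poly" candidate never carries C or gamma, and an "rbf" candidate never carries degree. That is why some param_* columns can be missing a value for a given candidate.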

github-actions bot commented Jul 24, 2023

✔️ Linting Passed

All linting checks passed. Your pull request is in excellent shape! ☀️

Generated for commit: 5afd7d0.

@StefanieSenger StefanieSenger changed the title FIX param_distribution param of HalvingRandomSearchCV accepts lists of dicts FIX param_distribution param of HalvingRandomSearchCV accepts list of dicts Jul 25, 2023
@jeremiedbb
Member

I think the failures mainly come from a dataset that is too small and from too few candidates. Here's a slightly modified version:

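from scipy.stats import expon

from sklearn.datasets import make_classification
from sklearn.experimental import enable_halving_search_cv  # noqa: F401
from sklearn.model_selection import HalvingRandomSearchCV
# The two check_* helpers live in scikit-learn's own test suite.
from sklearn.model_selection.tests.test_search import (
    check_cv_results_array_types,
    check_cv_results_keys,
)
from sklearn.svm import SVC


def test_halving_random_search_cv_results():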
    X, y = make_classification(n_samples=150, n_features=4, random_state=42)

    params = [
        {"kernel": ["rbf"], "C": expon(scale=10), "gamma": expon(scale=0.1)},
        {"kernel": ["poly"], "degree": [2, 3]},
    ]
    param_keys = ("param_C", "param_degree", "param_gamma", "param_kernel")
    score_keys = (
        "mean_test_score",
        "mean_train_score",
        "rank_test_score",
        "split0_test_score",
        "split1_test_score",
        "split2_test_score",
        "split0_train_score",
        "split1_train_score",
        "split2_train_score",
        "std_test_score",
        "std_train_score",
        "mean_fit_time",
        "std_fit_time",
        "mean_score_time",
        "std_score_time",
    )
    extra_keys = ("n_resources", "iter")

    search = HalvingRandomSearchCV(
        SVC(),
        cv=3,
        param_distributions=params,
        return_train_score=True,
        random_state=0,
    )
    search.fit(X, y)
    n_candidates = sum(search.n_candidates_)

    cv_results = search.cv_results_
    # Check results structure
    check_cv_results_keys(
        cv_results, param_keys, score_keys, n_candidates, extra_keys
    )
    check_cv_results_array_types(search, param_keys, score_keys)

    assert all(
        (
            cv_results["param_C"].mask[i]
            and cv_results["param_gamma"].mask[i]
            and not cv_results["param_degree"].mask[i]
        )
        for i in range(n_candidates)
        if cv_results["param_kernel"][i] == "poly"
    )
    assert all(
        (
            not cv_results["param_C"].mask[i]
            and not cv_results["param_gamma"].mask[i]
            and cv_results["param_degree"].mask[i]
        )
        for i in range(n_candidates)
        if cv_results["param_kernel"][i] == "rbf"
    )
  • More samples, for two reasons:
    • We're testing something that does cross-validation, so the dataset will be split, and we need both classes in the train and test sets as often as possible.
    • Successive halving uses n_samples as the resource and relies on it to determine the number of candidates for the first round. It therefore starts with fewer samples, which brings us back to the previous point. Also, a dataset that is too small can lead to only a single type of kernel being selected for the first round, for instance.
  • The total number of candidates is not something we set directly here, so we can't easily know it in advance. Better to just sum the n_candidates_ attribute.
  • Set random_state to ensure reproducible results. I tested with 100 different seeds and the test always passed.
  • I'd rather pass iter and n_resources as extra_keys and modify check_cv_results_keys instead:
    def check_cv_results_keys(cv_results, param_keys, score_keys, n_cand, extra_keys=()):
        # Test the search.cv_results_ contains all the required results
        all_keys = param_keys + score_keys + extra_keys
        assert_array_equal(
            sorted(cv_results.keys()), sorted(all_keys + ("params",))
        )
        assert all(cv_results[key].shape == (n_cand,) for key in all_keys)
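
To make the point about dataset size concrete, here is a rough sketch of how the first-round candidate count follows from n_samples. It assumes the defaults as I understand them from the scikit-learn documentation (factor=3, min_resources="smallest", n_candidates="exhaust", no aggressive elimination, a classifier with n_samples as the resource). halving_schedule is a hypothetical helper, not part of scikit-learn.

```python
from math import ceil


def halving_schedule(n_samples, n_splits, n_classes, factor=3):
    """Hypothetical helper: list of (n_candidates, n_resources) per iteration."""
    # "smallest" heuristic for classifiers: n_splits * 2 * n_classes samples.
    min_resources = n_splits * 2 * n_classes
    # "exhaust": sample as many candidates as the resource budget allows.
    n_candidates = n_samples // min_resources
    if n_candidates < 1:
        raise ValueError("not enough samples for a single candidate")

    # Number of rounds: 1 + floor(log_factor(n_candidates)), in integers.
    n_iterations = 1
    while factor**n_iterations <= n_candidates:
        n_iterations += 1

    schedule = []
    for iteration in range(n_iterations):
        n_resources = min(min_resources * factor**iteration, n_samples)
        schedule.append((n_candidates, n_resources))
        # Only the best 1/factor of the candidates survive each round.
        n_candidates = ceil(n_candidates / factor)
    return schedule


for n_samples in (30, 150):
    print(n_samples, halving_schedule(n_samples, n_splits=3, n_classes=2))
```

Under these assumptions, 150 samples give three rounds with 12, 4 and 2 candidates (18 rows in cv_results_), whereas 30 samples give a single round with only 2 candidates. With that few draws, both could easily come from the same dict.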

Member
@jeremiedbb jeremiedbb left a comment


Please also add a changelog entry for 1.3.1 in v1.3.rst

@StefanieSenger StefanieSenger marked this pull request as ready for review July 26, 2023 18:52
@StefanieSenger
Contributor Author
StefanieSenger commented Jul 26, 2023

Thanks for reviewing, @jeremiedbb, and for your help and explanations. I have made the changes according to your suggestions and mostly understood your reasoning.

I still haven't understood the asserts for the masked arrays though (all the candidates appeared as masked_array for both kernels, when I checked), and I will talk with @adrinjalali about it tomorrow.
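
On the masked arrays: every param_* column in cv_results_ is a NumPy masked array, so the type alone says nothing. What the asserts inspect is the per-entry mask, which is True where a candidate's dict does not contain that parameter. A small sketch with made-up candidates (the values are purely illustrative, and the construction only imitates how such columns could be built):

```python
import numpy as np

# Illustrative candidates, as they would appear in cv_results_["params"].
candidates = [
    {"kernel": "rbf", "C": 1.5, "gamma": 0.05},
    {"kernel": "poly", "degree": 2},
    {"kernel": "poly", "degree": 3},
]
names = sorted({name for params in candidates for name in params})

columns = {}
for name in names:
    # Start fully masked, then unmask entries where the parameter exists.
    column = np.ma.MaskedArray(np.empty(len(candidates), dtype=object), mask=True)
    for i, params in enumerate(candidates):
        if name in params:
            column[i] = params[name]  # assigning a value clears the mask at i
    columns[f"param_{name}"] = column

for key, column in columns.items():
    print(key, column.mask.tolist())
```

In this toy example param_degree is masked for the rbf row and unmasked for the two poly rows, while param_C and param_gamma are the other way around. That is the pattern the two assert all(...) blocks check.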

@StefanieSenger
Contributor Author

After reviewing this together with @adrinjalali, I have also modified the two not fully functioning tests I had indicated in the issue (#26885). Please have a look. :)

@StefanieSenger
Contributor Author

@glemaitre

@glemaitre glemaitre self-requested a review August 1, 2023 11:42
Member
@glemaitre glemaitre left a comment


LGTM on my side.

@glemaitre
Member

And you would need to solve the conflict in the changelog.

Member
@betatim betatim left a comment


Thanks for the fix! LGTM (looks good to me)

@StefanieSenger
Contributor Author

I've finished the last few things. Thanks everyone for your support. :)

@adrinjalali adrinjalali merged commit 3725ac1 into scikit-learn:main Aug 7, 2023
TamaraAtanasoska pushed a commit to TamaraAtanasoska/scikit-learn that referenced this pull request Aug 21, 2023
…t of dicts (scikit-learn#26893)

Co-authored-by: Guillaume Lemaitre <g.lemaitre58@gmail.com>
@StefanieSenger StefanieSenger deleted the HalvingRandomSearchCV branch August 23, 2023 11:16
glemaitre added a commit to glemaitre/scikit-learn that referenced this pull request Sep 18, 2023
…t of dicts (scikit-learn#26893)

Co-authored-by: Guillaume Lemaitre <g.lemaitre58@gmail.com>
jeremiedbb pushed a commit that referenced this pull request Sep 20, 2023
…t of dicts (#26893)

Co-authored-by: Guillaume Lemaitre <g.lemaitre58@gmail.com>
REDVM pushed a commit to REDVM/scikit-learn that referenced this pull request Nov 16, 2023
…t of dicts (scikit-learn#26893)

Co-authored-by: Guillaume Lemaitre <g.lemaitre58@gmail.com>
Successfully merging this pull request may close these issues.

HalvingRandomSearchCV does not support param_distribution as a list