8000 [MRG+1] Fix iforest average path length by albertcthomas · Pull Request #13251 · scikit-learn/scikit-learn · GitHub
[go: up one dir, main page]

Skip to content

[MRG+1] Fix iforest average path length #13251

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 13 commits into from
Feb 26, 2019

Conversation

albertcthomas
Copy link
Contributor
@albertcthomas albertcthomas commented Feb 25, 2019

Reference Issues/PRs

Taking over #12085. Besides fixing the average path length for IsolationForest this PR also improves the checks for the predicted number of outliers in the common tests.
Closes #12085, Fixes #11839

Only modifications in the tests were required.

cc @joshuakennethjones

joshuakennethjones and others added 8 commits February 25, 2019 13:44
Fix Issue scikit-learn#11839 : sklearn.ensemble.IsolationForest._average_path_length returns incorrect values for input < 3.
Changed existing test to reflect correct values now produced by _average_path_length(), and added checks to ensure non-regression on all "base case" values in {0,1,2}.
Made recommended enhancements to comments, and change assert_almost_equal to assert_equal where constants should be returned.
Change assert_equal to assert ... == to adhere to latest conventions, and change test to properly deal with anomaly score ties in critical regions if 'decision_function' method is supported by the estimator in question, or default to the old behavior if not.
Refactoring and adding more tests to try and get coverage to an acceptable level.
assert_almost_equal(_average_path_length(1), 1., decimal=10)
assert _average_path_length(0) == 0.
assert _average_path_length(1) == 0.
assert _average_path_length(2) == 1.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can we also test that _average_path_length is increasing for more values? I guess _average_path_length(2) < _average_path_length(3) would be enough

@albertcthomas
Copy link
Contributor Author

ping @agramfort for a review when you have time

Copy link
Member
@glemaitre glemaitre left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you add an entry inside in what's new

assert_almost_equal(_average_path_length(5), result_one, decimal=10)
assert_almost_equal(_average_path_length(999), result_two, decimal=10)
assert_array_almost_equal(_average_path_length(np.array([1, 5, 999])),
[1., result_one, result_two], decimal=10)
assert_array_almost_equal(_average_path_length(np.array([1, 2, 5, 999])),
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since we are changing this line, could we use assert_allclose

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes. What's the difference? assert_allclose does not check the shapes are the same?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

all_close use rtol atol instead of decimal. It is just recommended by numpy for consistency:
https://docs.scipy.org/doc/numpy/reference/generated/numpy.testing.assert_array_almost_equal.html

@@ -1548,16 +1568,16 @@ def check_outliers_train(name, estimator_orig, readonly_memmap=True):

decision = estimator.decision_function(X)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

for func in ['decision_function', 'score_samples']:
    output = getattr(estimator, func)(X)
    assert output.dtype == np.dtype('float')
    assert output.shape == (n_samples,)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I now use a for loop but a bit different to your suggestion as we need the outputs for other checks.

Copy link
Contributor
@ngoix ngoix left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

appart from my small comment on _average_path_length monotonic testing and guillaume formatting comments LGTM

glemaitre and others added 4 commits February 26, 2019 12:24
Co-Authored-By: albertcthomas <albertthomas88@gmail.com>
@albertcthomas
Copy link
Contributor Author

Thanks for the reviews @glemaitre and @ngoix

@agramfort
Copy link
Member

+1 for MRG

@agramfort agramfort changed the title [MRG] Fix iforest average path length [MRG+1] Fix iforest average path length Feb 26, 2019
@agramfort
Copy link
Member

ok to merge when green @ngoix and @glemaitre ?

@glemaitre
Copy link
Member

Merging. Azure pipeline is green.

@glemaitre glemaitre merged commit bcdeadd into scikit-learn:master Feb 26, 2019
@albertcthomas
Copy link
Contributor Author

Thanks for the reviews @ngoix, @glemaitre and @agramfort. Thanks for most of the work @joshuakennethjones and sorry for delaying the original PR.

@joshuakennethjones
Copy link

Thanks for pushing this one across the finish line @albertcthomas! Glad to be able to help out a little bit -- I appreciate the efforts of everyone involved in maintaining and improving what is obviously a very useful package.

xhluca pushed a commit to xhluca/scikit-learn that referenced this pull request Apr 28, 2019
xhluca pushed a commit to xhluca/scikit-learn that referenced this pull request Apr 28, 2019
xhluca pushed a commit to xhluca/scikit-learn that referenced this pull request Apr 28, 2019
koenvandevelde pushed a commit to koenvandevelde/scikit-learn that referenced this pull request Jul 12, 2019
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

sklearn.ensemble.IsolationForest._average_path_length returns incorrect values for input < 3.
5 participants
0