@@ -18,6 +18,7 @@
 from sklearn.utils.testing import assert_raise_message
 from sklearn.utils.testing import assert_equal
 from sklearn.utils.testing import assert_not_equal
+from sklearn.utils.testing import assert_almost_equal
 from sklearn.utils.testing import assert_in
 from sklearn.utils.testing import assert_array_equal
 from sklearn.utils.testing import assert_array_almost_equal
@@ -1524,29 +1525,8 @@ def check_classifiers_train(name, classifier_orig, readonly_memmap=False):
             assert_array_equal(np.argsort(y_log_prob), np.argsort(y_prob))


-def check_outlier_corruption(num_outliers, expected_outliers, decision):
-    # Check for deviation from the precise given contamination level that may
-    # be due to ties in the anomaly scores.
-    if num_outliers < expected_outliers:
-        start = num_outliers
-        end = expected_outliers + 1
-    else:
-        start = expected_outliers
-        end = num_outliers + 1
-
-    # ensure that all values in the 'critical area' are tied,
-    # leading to the observed discrepancy between provided
-    # and actual contamination levels.
-    sorted_decision = np.sort(decision)
-    msg = ('The number of predicted outliers is not equal to the expected '
-           'number of outliers and this difference is not explained by the '
-           'number of ties in the decision_function values')
-    assert len(np.unique(sorted_decision[start:end])) == 1, msg
-
-
 def check_outliers_train(name, estimator_orig, readonly_memmap=True):
-    n_samples = 300
-    X, _ = make_blobs(n_samples=n_samples, random_state=0)
+    X, _ = make_blobs(n_samples=300, random_state=0)
     X = shuffle(X, random_state=7)

     if readonly_memmap:
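The helper removed above encodes a short argument about ties: when the number of predicted outliers differs from the number implied by the contamination level, the discrepancy is acceptable only if every anomaly score in the "critical area" between the two counts is tied. A minimal standalone sketch of that logic, using hypothetical toy scores that are not part of the diff:

    import numpy as np

    # Toy anomaly scores with a tie straddling the requested cut-off
    # (hypothetical values, for illustration only).
    decision = np.array([-2.0, -1.0, -1.0, -1.0, 0.5, 1.0])
    expected_outliers = 2  # count implied by the contamination level
    num_outliers = 4       # count the estimator actually predicted

    # Same check as the removed check_outlier_corruption: every score
    # between the two counts must be identical for the mismatch to be
    # explained by ties.
    start, end = sorted((num_outliers, expected_outliers))
    sorted_decision = np.sort(decision)
    assert len(np.unique(sorted_decision[start:end + 1])) == 1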
@@ -1567,15 +1547,17 @@ def check_outliers_train(name, estimator_orig, readonly_memmap=True):
     assert_array_equal(np.unique(y_pred), np.array([-1, 1]))

     decision = estimator.decision_function(X)
-    scores = estimator.score_samples(X)
-    for output in [decision, scores]:
-        assert output.dtype == np.dtype('float')
-        assert output.shape == (n_samples,)
+    assert decision.dtype == np.dtype('float')
+
+    score = estimator.score_samples(X)
+    assert score.dtype == np.dtype('float')

     # raises error on malformed input for predict
     assert_raises(ValueError, estimator.predict, X.T)

     # decision_function agrees with predict
+    decision = estimator.decision_function(X)
+    assert decision.shape == (n_samples,)
     dec_pred = (decision >= 0).astype(np.int)
     dec_pred[dec_pred == 0] = -1
     assert_array_equal(dec_pred, y_pred)
@@ -1584,7 +1566,9 @@ def check_outliers_train(name, estimator_orig, readonly_memmap=True):
     assert_raises(ValueError, estimator.decision_function, X.T)

     # decision_function is a translation of score_samples
-    y_dec = scores - estimator.offset_
+    y_scores = estimator.score_samples(X)
+    assert y_scores.shape == (n_samples,)
+    y_dec = y_scores - estimator.offset_
     assert_allclose(y_dec, decision)

     # raises error on malformed input for score_samples
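The "translation" asserted in this hunk is the relationship the test relies on for scikit-learn outlier detectors that expose both methods: decision_function(X) equals score_samples(X) minus the fitted offset_ attribute. A small hedged illustration, with IsolationForest chosen only as a convenient concrete estimator:

    import numpy as np
    from sklearn.datasets import make_blobs
    from sklearn.ensemble import IsolationForest

    X, _ = make_blobs(n_samples=300, random_state=0)
    est = IsolationForest(random_state=0).fit(X)

    # decision_function is score_samples shifted by the fitted offset_
    np.testing.assert_allclose(est.decision_function(X),
                               est.score_samples(X) - est.offset_)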
@@ -1597,21 +1581,11 @@ def check_outliers_train(name, estimator_orig, readonly_memmap=True):
         # set to 'auto'. This is true for the training set and cannot thus be
         # checked as follows for estimators with a novelty parameter such as
         # LocalOutlierFactor (tested in check_outliers_fit_predict)
-        expected_outliers = 30
-        contamination = expected_outliers / n_samples
+        contamination = 0.1
         estimator.set_params(contamination=contamination)
         estimator.fit(X)
         y_pred = estimator.predict(X)
-
-        num_outliers = np.sum(y_pred != 1)
-        # num_outliers should be equal to expected_outliers unless
-        # there are ties in the decision_function values. this can
-        # only be tested for estimators with a decision_function
-        # method, i.e. all estimators except LOF which is already
-        # excluded from this if branch.
-        if num_outliers != expected_outliers:
-            decision = estimator.decision_function(X)
-            check_outlier_corruption(num_outliers, expected_outliers, decision)
+        assert_almost_equal(np.mean(y_pred != 1), contamination)

     # raises error when contamination is a scalar and not in [0,1]
     for contamination in [-0.5, 2.3]:
@@ -2382,8 +2356,7 @@ def check_decision_proba_consistency(name, estimator_orig):
 def check_outliers_fit_predict(name, estimator_orig):
     # Check fit_predict for outlier detectors.

-    n_samples = 300
-    X, _ = make_blobs(n_samples=n_samples, random_state=0)
+    X, _ = make_blobs(n_samples=300, random_state=0)
     X = shuffle(X, random_state=7)
     n_samples, n_features = X.shape
     estimator = clone(estimator_orig)
@@ -2405,20 +2378,10 @@ def check_outliers_fit_predict(name, estimator_orig):
     if hasattr(estimator, "contamination"):
         # proportion of outliers equal to contamination parameter when not
         # set to 'auto'
-        expected_outliers = 30
-        contamination = float(expected_outliers)/n_samples
+        contamination = 0.1
         estimator.set_params(contamination=contamination)
         y_pred = estimator.fit_predict(X)
-
-        num_outliers = np.sum(y_pred != 1)
-        # num_outliers should be equal to expected_outliers unless
-        # there are ties in the decision_function values. this can
-        # only be tested for estimators with a decision_function
-        # method
-        if (num_outliers != expected_outliers and
-                hasattr(estimator, 'decision_function')):
-            decision = estimator.decision_function(X)
-            check_outlier_corruption(num_outliers, expected_outliers, decision)
+        assert_almost_equal(np.mean(y_pred != 1), contamination)

         # raises error when contamination is a scalar and not in [0,1]
         for contamination in [-0.5, 2.3]:
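For context, the check restored by the + lines asserts that the fraction of samples predicted as outliers (label -1) is almost equal to the requested contamination level. A rough standalone sketch of the same idea, with IsolationForest assumed purely as a concrete example; note that ties in the anomaly scores can push the observed fraction off the exact target, which is what the removed helper tolerated:

    import numpy as np
    from sklearn.datasets import make_blobs
    from sklearn.ensemble import IsolationForest

    X, _ = make_blobs(n_samples=300, random_state=0)
    contamination = 0.1

    est = IsolationForest(contamination=contamination, random_state=0).fit(X)
    y_pred = est.predict(X)  # +1 for inliers, -1 for outliers

    # Fraction of predicted outliers should be close to the requested level.
    print(np.mean(y_pred != 1), contamination)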