@@ -719,15 +719,12 @@ def check_estimator_sparse_data(name, estimator_orig):
719
719
X [X < .8 ] = 0
720
720
X = _pairwise_estimator_convert_X (X , estimator_orig )
721
721
X_csr = sparse .csr_matrix (X )
722
- tags = estimator_orig ._get_tags ()
723
- if tags ['binary_only' ]:
724
- y = (2 * rng .rand (40 )).astype (np .int )
725
- else :
726
- y = (4 * rng .rand (40 )).astype (np .int )
722
+ y = (4 * rng .rand (40 )).astype (int )
727
723
# catch deprecation warnings
728
724
with ignore_warnings (category = FutureWarning ):
729
725
estimator = clone (estimator_orig )
730
726
y = _enforce_estimator_tags_y (estimator , y )
727
+ tags = estimator_orig ._get_tags ()
731
728
for matrix_format , X in _generate_sparse_matrix (X_csr ):
732
729
# catch deprecation warnings
733
730
with ignore_warnings (category = FutureWarning ):
@@ -825,10 +822,7 @@ def check_sample_weights_list(name, estimator_orig):
825
822
n_samples = 30
826
823
X = _pairwise_estimator_convert_X (rnd .uniform (size = (n_samples , 3 )),
827
824
estimator_orig )
828
- if estimator ._get_tags ()['binary_only' ]:
829
- y = np .arange (n_samples ) % 2
830
- else :
831
- y = np .arange (n_samples ) % 3
825
+ y = np .arange (n_samples ) % 3
832
826
y = _enforce_estimator_tags_y (estimator , y )
833
827
sample_weight = [3 ] * n_samples
834
828
# Test that estimators don't raise any exception
@@ -905,10 +899,7 @@ def check_dtype_object(name, estimator_orig):
905
899
X = _pairwise_estimator_convert_X (rng .rand (40 , 10 ), estimator_orig )
906
900
X = X .astype (object )
907
901
tags = estimator_orig ._get_tags ()
908
- if tags ['binary_only' ]:
909
- y = (X [:, 0 ] * 2 ).astype (np .int )
910
- else :
911
- y = (X [:, 0 ] * 4 ).astype (np .int )
902
+ y = (X [:, 0 ] * 4 ).astype (int )
912
903
estimator = clone (estimator_orig )
913
904
y = _enforce_estimator_tags_y (estimator , y )
914
905
@@ -1007,9 +998,7 @@ def check_dont_overwrite_parameters(name, estimator_orig):
1007
998
rnd = np .random .RandomState (0 )
1008
999
<
57AE
div class="diff-text-inner"> X = 3 * rnd .uniform (size = (20 , 3 ))
1009
1000
X = _pairwise_estimator_convert_X (X , estimator_orig )
1010
- y = X [:, 0 ].astype (np .int )
1011
- if estimator ._get_tags ()['binary_only' ]:
1012
- y [y == 2 ] = 1
1001
+ y = X [:, 0 ].astype (int )
1013
1002
y = _enforce_estimator_tags_y (estimator , y )
1014
1003
1015
1004
if hasattr (estimator , "n_components" ):
@@ -1060,8 +1049,6 @@ def check_fit2d_predict1d(name, estimator_orig):
1060
1049
X = _pairwise_estimator_convert_X (X , estimator_orig )
1061
1050
y = X [:, 0 ].astype (np .int )
1062
1051
tags = estimator_orig ._get_tags ()
1063
- if tags ['binary_only' ]:
1064
- y [y == 2 ] = 1
1065
1052
estimator = clone (estimator_orig )
1066
1053
y = _enforce_estimator_tags_y (estimator , y )
1067
1054
@@ -1109,9 +1096,7 @@ def check_methods_subset_invariance(name, estimator_orig):
1109
1096
rnd = np .random .RandomState (0 )
1110
1097
X = 3 * rnd .uniform (size = (20 , 3 ))
1111
1098
X = _pairwise_estimator_convert_X (X , estimator_orig )
1112
- y = X [:, 0 ].astype (np .int )
1113
- if estimator_orig ._get_tags ()['binary_only' ]:
1114
- y [y == 2 ] = 1
1099
+ y = X [:, 0 ].astype (int )
1115
1100
estimator = clone (estimator_orig )
1116
1101
y = _enforce_estimator_tags_y (estimator , y )
1117
1102
@@ -1383,10 +1368,7 @@ def check_fit_score_takes_y(name, estimator_orig):
1383
1368
n_samples = 30
1384
1369
X = rnd .uniform (size = (n_samples , 3 ))
1385
1370
X = _pairwise_estimator_convert_X (X , estimator_orig )
1386
- if estimator_orig ._get_tags ()['binary_only' ]:
1387
- y = np .arange (n_samples ) % 2
1388
- else :
1389
- y = np .arange (n_samples ) % 3
1371
+ y = np .arange (n_samples ) % 3
1390
1372
estimator = clone (estimator_orig )
1391
1373
y = _enforce_estimator_tags_y (estimator , y )
1392
1374
set_random_state (estimator )
@@ -1416,8 +1398,6 @@ def check_estimators_dtypes(name, estimator_orig):
1416
1398
X_train_int_64 = X_train_32 .astype (np .int64 )
1417
1399
X_train_int_32 = X_train_32 .astype (np .int32 )
1418
1400
y = X_train_int_64 [:, 0 ]
1419
- if estimator_orig ._get_tags ()['binary_only' ]:
1420
- y [y == 2 ] = 1
1421
1401
y = _enforce_estimator_tags_y (estimator_orig , y )
1422
1402
1423
1403
methods = ["predict" , "transform" , "decision_function" , "predict_proba" ]
@@ -1596,6 +1576,7 @@ def check_estimators_partial_fit_n_features(name, estimator_orig):
1596
1576
estimator = clone (estimator_orig )
1597
1577
X , y = make_blobs (n_samples = 50 , random_state = 1 )
1598
1578
X -= X .min ()
1579
+ y = _enforce_estimator_tags_y (estimator_orig , y )
1599
1580
1600
1581
try :
1601
1582
if is_classifier (estimator ):
@@ -2062,11 +2043,7 @@ def check_classifiers_multilabel_representation_invariance(name,
2062
2043
def check_estimators_fit_returns_self (name , estimator_orig ,
2063
2044
readonly_memmap = False ):
2064
2045
"""Check if self is returned when calling fit"""
2065
- if estimator_orig ._get_tags ()['binary_only' ]:
2066
- n_centers = 2
2067
- else :
2068
- n_centers = 3
2069
- X , y = make_blobs (random_state = 0 , n_samples = 21 , centers = n_centers )
2046
+ X , y = make_blobs (random_state = 0 , n_samples = 21 )
2070
2047
# some want non-negative input
2071
2048
X -= X .min ()
2072
2049
X = _pairwise_estimator_convert_X (X , estimator_orig )
@@ -2108,10 +2085,7 @@ def check_supervised_y_2d(name, estimator_orig):
2108
2085
X = _pairwise_estimator_convert_X (
2109
2086
rnd .uniform (size = (n_samples , 3 )), estimator_orig
2110
2087
)
2111
- if tags ['binary_only' ]:
2112
- y = np .arange (n_samples ) % 2
2113
- else :
2114
- y = np .arange (n_samples ) % 3
2088
+ y = np .arange (n_samples ) % 3
2115
2089
y = _enforce_estimator_tags_y (estimator_orig , y )
2116
2090
estimator = clone (estimator_orig )
2117
2091
set_random_state (estimator )
@@ -2436,11 +2410,7 @@ def check_class_weight_balanced_linear_classifier(name, Classifier):
2436
2410
2437
2411
@ignore_warnings (category = FutureWarning )
2438
2412
def check_estimators_overwrite_params (name , estimator_orig ):
2439
- if estimator_orig ._get_tags ()['binary_only' ]:
2440
- n_centers = 2
2441
- else :
2442
- n_centers = 3
2443
- X , y = make_blobs (random_state = 0 , n_samples = 21 , centers = n_centers )
2413
+ X , y = make_blobs (random_state = 0 , n_samples = 21 )
2444
2414
# some want non-negative input
2445
2415
X -= X .min ()
2446
2416
X = _pairwise_estimator_convert_X (X , estimator_orig , kernel = rbf_kernel )
@@ -2511,7 +2481,8 @@ def check_no_attributes_set_in_init(name, estimator_orig):
2511
2481
def check_sparsify_coefficients (name , estimator_orig ):
2512
2482
X = np .array ([[- 2 , - 1 ], [- 1 , - 1 ], [- 1 , - 2 ], [1 , 1 ], [1 , 2 ], [2 , 1 ],
2513
2483
[- 1 , - 2 ], [2 , 2 ], [- 2 , - 2 ]])
2514
- y = [1 , 1 , 1 , 2 , 2 , 2 , 3 , 3 , 3 ]
2484
+ y = np .array ([1 , 1 , 1 , 2 , 2 , 2 , 3 , 3 , 3 ])
2485
+ y = _enforce_estimator_tags_y (estimator_orig , y )
2515
2486
est = clone (estimator_orig )
2516
2487
2517
2488
est .fit (X , y )
@@ -2535,7 +2506,7 @@ def check_classifier_data_not_an_array(name, estimator_orig):
2535
2506
X = np .array ([[3 , 0 ], [0 , 1 ], [0 , 2 ], [1 , 1 ], [1 , 2 ], [2 , 1 ],
2536
2507
[0 , 3 ], [1 , 0 ], [2 , 0 ], [4 , 4 ], [2 , 3 ], [3 , 2 ]])
2537
2508
X = _pairwise_estimator_convert_X (X , estimator_orig )
2538
- y = [1 , 1 , 1 , 2 , 2 , 2 , 1 , 1 , 1 , 2 , 2 , 2 ]
2509
+ y = np . array ( [1 , 1 , 1 , 2 , 2 , 2 , 1 , 1 , 1 , 2 , 2 , 2 ])
2539
2510
y = _enforce_estimator_tags_y (estimator_orig , y )
2540
2511
for obj_type in ["NotAnArray" , "PandasDataframe" ]:
2541
2512
check_estimators_data_not_an_array (name , estimator_orig , X , y ,
@@ -2682,6 +2653,9 @@ def _enforce_estimator_tags_y(estimator, y):
2682
2653
# Create strictly positive y. The minimal increment above 0 is 1, as
2683
2654
# y could be of integer dtype.
2684
2655
y += 1 + abs (y .min ())
2656
+ # Estimators with a `binary_only` tag only accept up to two unique y values
2657
+ if estimator ._get_tags ()["binary_only" ] and y .size > 0 :
2658
+ y = np .where (y == y .flat [0 ], y , y .flat [0 ] + 1 )
2685
2659
# Estimators in mono_output_task_error raise ValueError if y is of 1-D
2686
2660
# Convert into a 2-D y for those estimators.
2687
2661
if estimator ._get_tags ()["multioutput_only" ]:
0 commit comments