@@ -701,15 +701,12 @@ def check_estimator_sparse_data(name, estimator_orig):
701
701
X [X < .8 ] = 0
702
702
X = _pairwise_estimator_convert_X (X , estimator_orig )
703
703
X_csr = sparse.csr_matrix (X )
704
- tags = estimator_orig ._get_tags ()
705
- if tags ['binary_only' ]:
706
- y = (2 * rng .rand (40 )).astype (int )
707
- else :
708
- y = (4 * rng .rand (40 )).astype (int )
704
+ y = (4 * rng .rand (40 )).astype (int )
709
705
# catch deprecation warnings
710
706
with ignore_warnings (category = FutureWarning ):
711
707
estimator = clone (estimator_orig )
712
708
y = _enforce_estimator_tags_y (estimator , y )
709
+ tags = estimator_orig ._get_tags ()
713
710
for matrix_format , X in _generate_sparse_matrix (X_csr ):
714
711
# catch deprecation warnings
715
712
with ignore_warnings (category = FutureWarning ):
@@ -807,10 +804,7 @@ def check_sample_weights_list(name, estimator_orig):
807
804
n_samples = 30
808
805
X = _pairwise_estimator_convert_X (rnd .uniform (size = (n_samples , 3 )),
809
806
estimator_orig )
810
- if estimator ._get_tags ()['binary_only' ]:
811
- y = np .arange (n_samples ) % 2
812
- else :
813
- y = np .arange (n_samples ) % 3
807
+ y = np .arange (n_samples ) % 3
814
808
y = _enforce_estimator_tags_y (estimator , y )
815
809
sample_weight = [3 ] * n_samples
816
810
# Test that estimators don't raise any exception
@@ -901,10 +895,7 @@ def check_dtype_object(name, estimator_orig):
901
895
X = _pairwise_estimator_convert_X (rng .rand (40 , 10 ), estimator_orig )
902
896
X = X .astype (object )
903
897
tags = estimator_orig ._get_tags ()
904
- if tags ['binary_only' ]:
905
- y = (X [:, 0 ] * 2 ).astype (int )
906
- else :
907
- y = (X [:, 0 ] * 4 ).astype (int )
898
+ y = (X [:, 0 ] * 4 ).astype (int )
908
899
estimator = clone (estimator_orig )
909
900
y = _enforce_estimator_tags_y (estimator , y )
910
901
@@ -998,8 +989,6 @@ def check_dont_overwrite_parameters(name, estimator_orig):
998
989
X = 3 * rnd .uniform (size = (20 , 3 ))
999
990
X = _pairwise_estimator_convert_X (X , estimator_orig )
1000
991
y = X [:, 0 ].astype (int )
1001
- if estimator ._get_tags ()['binary_only' ]:
1002
- y [y == 2 ] = 1
1003
992
y = _enforce_estimator_tags_y (estimator , y )
1004
993
1005
994
if hasattr (estimator , "n_components" ):
@@ -1050,8 +1039,6 @@ def check_fit2d_predict1d(name, estimator_orig):
1050
1039
X = _pairwise_estimator_convert_X (X , estimator_orig )
1051
1040
y = X [:, 0 ].astype (int )
1052
1041
tags = estimator_orig ._get_tags ()
1053
- if tags ['binary_only' ]:
1054
- y [y == 2 ] = 1
1055
1042
estimator = clone (estimator_orig )
1056
1043
y = _enforce_estimator_tags_y (estimator , y )
1057
1044
@@ -1100,8 +1087,6 @@ def check_methods_subset_invariance(name, estimator_orig):
1100
1087
X = 3 * rnd .uniform (size = (20 , 3 ))
1101
1088
X = _pairwise_estimator_convert_X (X , estimator_orig )
1102
1089
y = X [:, 0 ].astype (int )
1103
- if estimator_orig ._get_tags ()['binary_only' ]:
1104
- y [y == 2 ] = 1
1105
1090
estimator = clone (estimator_orig )
1106
1091
y = _enforce_estimator_tags_y (estimator , y )
1107
1092
@@ -1373,10 +1358,7 @@ def check_fit_score_takes_y(name, estimator_orig):
1373
1358
n_samples = 30
1374
1359
X = rnd .uniform (size = (n_samples , 3 ))
1375
1360
X = _pairwise_estimator_convert_X (X , estimator_orig )
1376
- if estimator_orig ._get_tags ()['binary_only' ]:
1377
- y = np .arange (n_samples ) % 2
1378
- else :
1379
- y = np .arange (n_samples ) % 3
1361
+ y = np .arange (n_samples ) % 3
1380
1362
estimator = clone (estimator_orig )
1381
1363
y = _enforce_estimator_tags_y (estimator , y )
1382
1364
set_random_state (estimator )
@@ -1406,8 +1388,6 @@ def check_estimators_dtypes(name, estimator_orig):
1406
1388
X_train_int_64 = X_train_32 .astype (np .int64 )
1407
1389
X_train_int_32 = X_train_32 .astype (np .int32 )
1408
1390
y = X_train_int_64 [:, 0 ]
1409
- if estimator_orig ._get_tags ()['binary_only' ]:
1410
- y [y == 2 ] = 1
1411
1391
y = _enforce_estimator_tags_y (estimator_orig , y )
1412
1392
1413
1393
methods = ["predict" , "transform" , "decision_function" , "predict_proba" ]
@@ -1581,6 +1561,7 @@ def check_estimators_partial_fit_n_features(name, estimator_orig):
1581
1561
estimator = clone (estimator_orig )
1582
1562
X , y = make_blobs (n_samples = 50 , random_state = 1 )
1583
1563
X -= X .min ()
1564
+ y = _enforce_estimator_tags_y (estimator_orig , y )
1584
1565
1585
1566
try :
1586
1567
if is_classifier (estimator ):
@@ -2047,11 +2028,7 @@ def check_classifiers_multilabel_representation_invariance(name,
2047
2028
def check_estimators_fit_returns_self (name , estimator_orig ,
2048
2029
readonly_memmap = False ):
2049
2030
"""Check if self is returned when calling fit"""
2050
- if estimator_orig ._get_tags ()['binary_only' ]:
2051
- n_centers = 2
2052
- else :
2053
- n_centers = 3
2054
- X , y = make_blobs (random_state = 0 , n_samples = 21 , centers = n_centers )
2031
+ X , y = make_blobs (random_state = 0 , n_samples = 21 )
2055
2032
# some want non-negative input
2056
2033
X -= X .min ()
2057
2034
X = _pairwise_estimator_convert_X (X , estimator_orig )
@@ -2093,10 +2070,7 @@ def check_supervised_y_2d(name, estimator_orig):
2093
2070
X = _pairwise_estimator_convert_X (
2094
2071
rnd .uniform (size = (n_samples , 3 )), estimator_orig
2095
2072
)
2096
- if tags ['binary_only' ]:
2097
- y = np .arange (n_samples ) % 2
2098
- else :
2099
- y = np .arange (n_samples ) % 3
2073
+ y = np .arange (n_samples ) % 3
2100
2074
y = _enforce_estimator_tags_y (estimator_orig , y )
2101
2075
estimator = clone (estimator_orig )
2102
2076
set_random_state (estimator )
@@ -2414,11 +2388,7 @@ def check_class_weight_balanced_linear_classifier(name, Classifier):
2414
2388
2415
2389
@ignore_warnings (category = FutureWarning )
2416
2390
def check_estimators_overwrite_params (name , estimator_orig ):
2417
- if estimator_orig ._get_tags ()['binary_only' ]:
2418
- n_centers = 2
2419
- else :
2420
- n_centers = 3
2421
- X , y = make_blobs (random_state = 0 , n_samples = 21 , centers = n_centers )
2391
+ X , y = make_blobs (random_state = 0 , n_samples = 21 )
2422
2392
# some want non-negative input
2423
2393
X -= X .min ()
2424
2394
X = _pairwise_estimator_convert_X (X , estimator_orig , kernel = rbf_kernel )
@@ -2489,7 +2459,8 @@ def check_no_attributes_set_in_init(name, estimator_orig):
2489
2459
def check_sparsify_coefficients (name , estimator_orig ):
2490
2460
X = np .array ([[- 2 , - 1 ], [- 1 , - 1 ], [- 1 , - 2 ], [1 , 1 ], [1 , 2 ], [2 , 1 ],
2491
2461
[- 1 , - 2 ], [2 , 2 ], [- 2 , - 2 ]])
2492
- y = [1 , 1 , 1 , 2 , 2 , 2 , 3 , 3 , 3 ]
2462
+ y = np .array ([1 , 1 , 1 , 2 , 2 , 2 , 3 , 3 , 3 ])
2463
+ y = _enforce_estimator_tags_y (estimator_orig , y )
2493
2464
est = clone (estimator_orig )
2494
2465
2495
2466
est .fit (X , y )
@@ -2513,7 +2484,7 @@ def check_classifier_data_not_an_array(name, estimator_orig):
2513
2484
X = np .array ([[3 , 0 ], [0 , 1 ], [0 , 2 ], [1 , 1 ], [1 , 2 ], [2 , 1 ],
2514
2485
[0 , 3 ], [1 , 0 ], [2 , 0 ], [4 , 4 ], [2 , 3 ], [3 , 2 ]])
2515
2486
X = _pairwise_estimator_convert_X (X , estimator_orig )
2516
- y = [1 , 1 , 1 , 2 , 2 , 2 , 1 , 1 , 1 , 2 , 2 , 2 ]
2487
+ y = np . array ( [1 , 1 , 1 , 2 , 2 , 2 , 1 , 1 , 1 , 2 , 2 , 2 ])
2517
2488
y = _enforce_estimator_tags_y (estimator_orig , y )
2518
2489
for obj_type in ["NotAnArray" , "PandasDataframe" ]:
2519
2490
check_estimators_data_not_an_array (name , estimator_orig , X , y ,
@@ -2649,6 +2620,9 @@ def _enforce_estimator_tags_y(estimator, y):
2649
2620
# Create strictly positive y. The minimal increment above 0 is 1, as
2650
2621
# y could be of integer dtype.
2651
2622
y += 1 + abs (y .min ())
2623
+ # Estimators with a `binary_only` tag only accept up to two unique y values
2624
+ if estimator ._get_tags ()["binary_only" ] and y .size > 0 :
2625
+ y = np .where (y == y .flat [0 ], y , y .flat [0 ] + 1 )
2652
2626
# Estimators in mono_output_task_error raise ValueError if y is of 1-D
2653
2627
# Convert into a 2-D y for those estimators.
2654
2628
if estimator ._get_tags ()["multioutput_only" ]:
0 commit comments