|
19 | 19 | from ._testing import assert_array_almost_equal
|
20 | 20 | from ._testing import assert_allclose
|
21 | 21 | from ._testing import assert_allclose_dense_sparse
|
| 22 | +from ._testing import assert_array_less |
22 | 23 | from ._testing import set_random_state
|
23 | 24 | from ._testing import SkipTest
|
24 | 25 | from ._testing import ignore_warnings
|
@@ -141,6 +142,9 @@ def _yield_classifier_checks(classifier):
|
141 | 142 | yield check_classifiers_regression_target
|
142 | 143 | if tags["multilabel"]:
|
143 | 144 | yield check_classifiers_multilabel_representation_invariance
|
| 145 | + yield check_classifiers_multilabel_output_format_predict |
| 146 | + yield check_classifiers_multilabel_output_format_predict_proba |
| 147 | + yield check_classifiers_multilabel_output_format_decision_function |
144 | 148 | if not tags["no_validation"]:
|
145 | 149 | yield check_supervised_y_no_nan
|
146 | 150 | if not tags["multioutput_only"]:
|
@@ -651,7 +655,7 @@ def _set_checking_parameters(estimator):
|
651 | 655 | estimator.set_params(strategy="stratified")
|
652 | 656 |
|
653 | 657 | # Speed-up by reducing the number of CV or splits for CV estimators
|
654 |
| - loo_cv = ["RidgeCV"] |
| 658 | + loo_cv = ["RidgeCV", "RidgeClassifierCV"] |
655 | 659 | if name not in loo_cv and hasattr(estimator, "cv"):
|
656 | 660 | estimator.set_params(cv=3)
|
657 | 661 | if hasattr(estimator, "n_splits"):
|
@@ -2258,18 +2262,18 @@ def check_outliers_train(name, estimator_orig, readonly_memmap=True):
|
2258 | 2262 | estimator.fit(X)
|
2259 | 2263 |
|
2260 | 2264 |
|
2261 |
| -@ignore_warnings(category=(FutureWarning)) |
| 2265 | +@ignore_warnings(category=FutureWarning) |
2262 | 2266 | def check_classifiers_multilabel_representation_invariance(name, classifier_orig):
|
2263 |
| - |
2264 | 2267 | X, y = make_multilabel_classification(
|
2265 | 2268 | n_samples=100,
|
2266 |
| - n_features=20, |
| 2269 | + n_features=2, |
2267 | 2270 | n_classes=5,
|
2268 | 2271 | n_labels=3,
|
2269 | 2272 | length=50,
|
2270 | 2273 | allow_unlabeled=True,
|
2271 | 2274 | random_state=0,
|
2272 | 2275 | )
|
| 2276 | + X = scale(X) |
2273 | 2277 |
|
2274 | 2278 | X_train, y_train = X[:80], y[:80]
|
2275 | 2279 | X_test = X[80:]
|
@@ -2299,6 +2303,181 @@ def check_classifiers_multilabel_representation_invariance(name, classifier_orig
|
2299 | 2303 | assert type(y_pred) == type(y_pred_list_of_lists)
|
2300 | 2304 |
|
2301 | 2305 |
|
| 2306 | +@ignore_warnings(category=FutureWarning) |
| 2307 | +def check_classifiers_multilabel_output_format_predict(name, classifier_orig): |
| 2308 | + """Check the output of the `predict` method for classifiers supporting |
| 2309 | + multilabel-indicator targets.""" |
| 2310 | + classifier = clone(classifier_orig) |
| 2311 | + set_random_state(classifier) |
| 2312 | + |
| 2313 | + n_samples, test_size, n_outputs = 100, 25, 5 |
| 2314 | + X, y = make_multilabel_classification( |
| 2315 | + n_samples=n_samples, |
| 2316 | + n_features=2, |
| 2317 | + n_classes=n_outputs, |
| 2318 | + n_labels=3, |
| 2319 | + length=50, |
| 2320 | + allow_unlabeled=True, |
| 2321 | + random_state=0, |
| 2322 | + ) |
| 2323 | + X = scale(X) |
| 2324 | + |
| 2325 | + X_train, X_test = X[:-test_size], X[-test_size:] |
| 2326 | + y_train, y_test = y[:-test_size], y[-test_size:] |
| 2327 | + classifier.fit(X_train, y_train) |
| 2328 | + |
| 2329 | + response_method_name = "predict" |
| 2330 | + predict_method = getattr(classifier, response_method_name, None) |
| 2331 | + if predict_method is None: |
| 2332 | + raise SkipTest(f"{name} does not have a {response_method_name} method.") |
| 2333 | + |
| 2334 | + y_pred = predict_method(X_test) |
| 2335 | + |
| 2336 | + # y_pred.shape -> y_test.shape with the same dtype |
| 2337 | + assert isinstance(y_pred, np.ndarray), ( |
| 2338 | + f"{name}.predict is expected to output a NumPy array. Got " |
| 2339 | + f"{type(y_pred)} instead." |
| 2340 | + ) |
| 2341 | + assert y_pred.shape == y_test.shape, ( |
| 2342 | + f"{name}.predict outputs a NumPy array of shape {y_pred.shape} " |
| 2343 | + f"instead of {y_test.shape}." |
| 2344 | + ) |
| 2345 | + assert y_pred.dtype == y_test.dtype, ( |
| 2346 | + f"{name}.predict does not output the same dtype than the targets. " |
| 2347 | + f"Got {y_pred.dtype} instead of {y_test.dtype}." |
| 2348 | + ) |
| 2349 | + |
| 2350 | + |
| 2351 | +@ignore_warnings(category=FutureWarning) |
| 2352 | +def check_classifiers_multilabel_output_format_predict_proba(name, classifier_orig): |
| 2353 | + """Check the output of the `predict_proba` method for classifiers supporting |
| 2354 | + multilabel-indicator targets.""" |
| 2355 | + classifier = clone(classifier_orig) |
| 2356 | + set_random_state(classifier) |
| 2357 | + |
| 2358 | + n_samples, test_size, n_outputs = 100, 25, 5 |
| 2359 | + X, y = make_multilabel_classification( |
| 2360 | + n_samples=n_samples, |
| 2361 | + n_features=2, |
| 2362 | + n_classes=n_outputs, |
| 2363 | + n_labels=3, |
| 2364 | + length=50, |
| 2365 | + allow_unlabeled=True, |
| 2366 | + random_state=0, |
| 2367 | + ) |
| 2368 | + X = scale(X) |
| 2369 | + |
| 2370 | + X_train, X_test = X[:-test_size], X[-test_size:] |
| 2371 | + y_train = y[:-test_size] |
| 2372 | + classifier.fit(X_train, y_train) |
| 2373 | + |
| 2374 | + response_method_name = "predict_proba" |
| 2375 | + predict_proba_method = getattr(classifier, response_method_name, None) |
| 2376 | + if predict_proba_method is None: |
| 2377 | + raise SkipTest(f"{name} does not have a {response_method_name} method.") |
| 2378 | + |
| 2379 | + y_pred = predict_proba_method(X_test) |
| 2380 | + |
| 2381 | + # y_pred.shape -> 2 possibilities: |
| 2382 | + # - list of length n_outputs of shape (n_samples, 2); |
| 2383 | + # - ndarray of shape (n_samples, n_outputs). |
| 2384 | + # dtype should be floating |
| 2385 | + if isinstance(y_pred, list): |
| 2386 | + assert len(y_pred) == n_outputs, ( |
| 2387 | + f"When {name}.predict_proba returns a list, the list should " |
| 2388 | + "be of length n_outputs and contain NumPy arrays. Got length " |
| 2389 | + f"of {len(y_pred)} instead of {n_outputs}." |
| 2390 | + ) |
| 2391 | + for pred in y_pred: |
| 2392 | + assert pred.shape == (test_size, 2), ( |
| 2393 | + f"When {name}.predict_proba returns a list, this list " |
| 2394 | + "should contain NumPy arrays of shape (n_samples, 2). Got " |
| 2395 | + f"NumPy arrays of shape {pred.shape} instead of " |
| 2396 | + f"{(test_size, 2)}." |
| 2397 | + ) |
| 2398 | + assert pred.dtype.kind == "f", ( |
| 2399 | + f"When {name}.predict_proba returns a list, it should " |
| 2400 | + "contain NumPy arrays with floating dtype. Got " |
| 2401 | + f"{pred.dtype} instead." |
| 2402 | + ) |
| 2403 | + # check that we have the correct probabilities |
| 2404 | + err_msg = ( |
| 2405 | + f"When {name}.predict_proba returns a list, each NumPy " |
| 2406 | + "array should contain probabilities for each class and " |
| 2407 | + "thus each row should sum to 1 (or close to 1 due to " |
| 2408 | + "numerical errors)." |
| 2409 | + ) |
| 2410 | + assert_allclose(pred.sum(axis=1), 1, err_msg=err_msg) |
| 2411 | + elif isinstance(y_pred, np.ndarray): |
| 2412 | + assert y_pred.shape == (test_size, n_outputs), ( |
| 2413 | + f"When {name}.predict_proba returns a NumPy array, the " |
| 2414 | + f"expected shape is (n_samples, n_outputs). Got {y_pred.shape}" |
| 2415 | + f" instead of {(test_size, n_outputs)}." |
| 2416 | + ) |
| 2417 | + assert y_pred.dtype.kind == "f", ( |
| 2418 | + f"When {name}.predict_proba returns a NumPy array, the " |
| 2419 | + f"expected data type is floating. Got {y_pred.dtype} instead." |
| 2420 | + ) |
| 2421 | + err_msg = ( |
| 2422 | + f"When {name}.predict_proba returns a NumPy array, this array " |
| 2423 | + "is expected to provide probabilities of the positive class " |
| 2424 | + "and should therefore contain values between 0 and 1." |
| 2425 | + ) |
| 2426 | + assert_array_less(0, y_pred, err_msg=err_msg) |
| 2427 | + assert_array_less(y_pred, 1, err_msg=err_msg) |
| 2428 | + else: |
| 2429 | + raise ValueError( |
| 2430 | + f"Unknown returned type {type(y_pred)} by {name}." |
| 2431 | + "predict_proba. A list or a Numpy array is expected." |
| 2432 | + ) |
| 2433 | + |
| 2434 | + |
| 2435 | +@ignore_warnings(category=FutureWarning) |
| 2436 | +def check_classifiers_multilabel_output_format_decision_function(name, classifier_orig): |
| 2437 | + """Check the output of the `decision_function` method for classifiers supporting |
| 2438 | + multilabel-indicator targets.""" |
| 2439 | + classifier = clone(classifier_orig) |
| 2440 | + set_random_state(classifier) |
| 2441 | + |
| 2442 | + n_samples, test_size, n_outputs = 100, 25, 5 |
| 2443 | + X, y = make_multilabel_classification( |
| 2444 | + n_samples=n_samples, |
| 2445 | + n_features=2, |
| 2446 | + n_classes=n_outputs, |
| 2447 | + n_labels=3, |
| 2448 | + length=50, |
| 2449 | + allow_unlabeled=True, |
| 2450 | + random_state=0, |
| 2451 | + ) |
| 2452 | + X = scale(X) |
| 2453 | + |
| 2454 | + X_train, X_test = X[:-test_size], X[-test_size:] |
| 2455 | + y_train = y[:-test_size] |
| 2456 | + classifier.fit(X_train, y_train) |
| 2457 | + |
| 2458 | + response_method_name = "decision_function" |
| 2459 | + decision_function_method = getattr(classifier, response_method_name, None) |
| 2460 | + if decision_function_method is None: |
| 2461 | + raise SkipTest(f"{name} does not have a {response_method_name} method.") |
| 2462 | + |
| 2463 | + y_pred = decision_function_method(X_test) |
| 2464 | + |
| 2465 | + # y_pred.shape -> y_test.shape with floating dtype |
| 2466 | + assert isinstance(y_pred, np.ndarray), ( |
| 2467 | + f"{name}.decision_function is expected to output a NumPy array." |
| 2468 | + f" Got {type(y_pred)} instead." |
| 2469 | + ) |
| 2470 | + assert y_pred.shape == (test_size, n_outputs), ( |
| 2471 | + f"{name}.decision_function is expected to provide a NumPy array " |
| 2472 | + f"of shape (n_samples, n_outputs). Got {y_pred.shape} instead of " |
| 2473 | + f"{(test_size, n_outputs)}." |
| 2474 | + ) |
| 2475 | + assert y_pred.dtype.kind == "f", ( |
| 2476 | + f"{name}.decision_function is expected to output a floating dtype." |
| 2477 | + f" Got {y_pred.dtype} instead." |
| 2478 | + ) |
| 2479 | + |
| 2480 | + |
2302 | 2481 | @ignore_warnings(category=FutureWarning)
|
2303 | 2482 | def check_estimators_fit_returns_self(name, estimator_orig, readonly_memmap=False):
|
2304 | 2483 | """Check if self is returned when calling fit."""
|
|
0 commit comments