@@ -1500,25 +1500,29 @@ def check_estimators_partial_fit_n_features(name, estimator_orig):
1500
1500
1501
1501
@ignore_warnings (category = (DeprecationWarning , FutureWarning ))
1502
1502
def check_classifier_multioutput (name , estimator ):
1503
- n_samples , n_labels = 42 , 5
1503
+ n_samples , n_labels , n_classes = 42 , 5 , 3
1504
1504
tags = _safe_tags (estimator )
1505
1505
estimator = clone (estimator )
1506
1506
X , y = make_multilabel_classification (random_state = 42 ,
1507
1507
n_samples = n_samples ,
1508
- n_labels = n_labels )
1508
+ n_labels = n_labels ,
1509
+ n_classes = n_classes )
1509
1510
estimator .fit (X , y )
1510
1511
y_pred = estimator .predict (X )
1511
1512
1512
- assert ( y_pred .shape == (n_samples , n_labels ),
1513
- "The shape of the prediction for multioutput data is "
1514
- " incorrect. Expected {}, got {}."
1515
- .format ((n_samples , n_labels ), y_pred .shape ))
1513
+ assert y_pred .shape == (n_samples , n_classes ), (
1514
+ "The shape of the prediction for multioutput data is "
1515
+ " incorrect. Expected {}, got {}."
1516
+ .format ((n_samples , n_labels ), y_pred .shape ))
1516
1517
assert y_pred .dtype .kind == 'i'
1517
1518
1518
1519
if hasattr (estimator , "decision_function" ):
1519
1520
decision = estimator .decision_function (X )
1520
1521
assert isinstance (decision , np .ndarray )
1521
- assert decision .shape == (n_samples , n_labels )
1522
+ assert decision .shape == (n_samples , n_classes ), (
1523
+ "The shape of the decision function output for "
1524
+ "multioutput data is incorrect. Expected {}, got {}."
1525
+ .format ((n_samples , n_classes ), decision .shape ))
1522
1526
1523
1527
dec_pred = (decision > 0 ).astype (np .int )
1524
1528
dec_exp = estimator .classes_ [dec_pred ]
@@ -1528,14 +1532,20 @@ def check_classifier_multioutput(name, estimator):
1528
1532
y_prob = estimator .predict_proba (X )
1529
1533
1530
1534
if isinstance (y_prob , list ) and not tags ['poor_score' ]:
1531
- for i in range (n_labels ):
1532
- y_prob [i ].shape == (n_samples , n_labels )
1535
+ for i in range (n_classes ):
1536
+ assert y_prob [i ].shape == (n_samples , 2 ), (
1537
+ "The shape of the probability for multioutput data is"
1538
+ " incorrect. Expected {}, got {}."
1539
+ .format ((n_samples , 2 ), y_prob [i ].shape ))
1533
1540
assert_array_equal (
1534
1541
np .argmax (y_prob [i ], axis = 1 ).astype (np .int ),
1535
1542
y_pred [:, i ]
1536
1543
)
1537
1544
elif not tags ['poor_score' ]:
1538
- y_prob .shape == (n_samples , n_labels )
1545
+ assert y_prob .shape == (n_samples , n_labels ), (
1546
+ "The shape of the probability for multioutput data is"
1547
+ " incorrect. Expected {}, got {}."
1548
+ .format ((n_samples , n_labels ), y_prob .shape ))
1539
1549
assert_array_equal (y_prob .round ().astype (int ), y_pred )
1540
1550
1541
1551
if (hasattr (estimator , "decision_function" ) and
@@ -1561,8 +1571,12 @@ def check_regressor_multioutput(name, estimator):
1561
1571
estimator .fit (X , y )
1562
1572
y_pred = estimator .predict (X )
1563
1573
1564
- assert y_pred .dtype == np .dtype ('float' )
1565
- assert y_pred .shape == y .shape
1574
+ assert y_pred .dtype == np .dtype ('float64' ), (
1575
+ "Multioutput predictions by a regressor are expected to be"
1576
+ " floating-point precision. Got {} instead" .format (y_pred .dtype ))
1577
+ assert y_pred .shape == y .shape , (
1578
+ "The shape of the orediction for multioutput data is incorrect."
1579
+ " Expected {}, got {}." )
1566
1580
1567
1581
1568
1582
@ignore_warnings (category = (DeprecationWarning , FutureWarning ))
0 commit comments