|
50 | 50 | from sklearn.svm import LinearSVC
|
51 | 51 | from sklearn.utils.validation import check_random_state
|
52 | 52 |
|
| 53 | +from sklearn.metrics import mean_squared_error |
| 54 | + |
53 | 55 | from sklearn.tree._classes import SPARSE_SPLITTERS
|
54 | 56 |
|
55 | 57 |
|
@@ -1441,6 +1443,42 @@ def test_max_samples_exceptions(name, max_samples, exc_type, exc_msg):
|
1441 | 1443 | est.fit(X, y)
|
1442 | 1444 |
|
1443 | 1445 |
|
| 1446 | +@pytest.mark.parametrize('name', FOREST_REGRESSORS) |
| 1447 | +def test_max_samples_boundary_regressors(name): |
| 1448 | + X_train, X_test, y_train, y_test = train_test_split( |
| 1449 | + X_reg, y_reg, train_size=0.7, test_size=0.3, random_state=0) |
| 1450 | + |
| 1451 | + ms_1_predict = FOREST_REGRESSORS[name]( |
| 1452 | + max_samples=1.0, random_state=0).fit( |
| 1453 | + X_train, y_train).predict(X_test) |
| 1454 | + ms_None_predict = FOREST_REGRESSORS[name]( |
| 1455 | + max_samples=None, random_state=0).fit( |
| 1456 | + X_train, y_train).predict(X_test) |
| 1457 | + |
| 1458 | + ms_1_ms = mean_squared_error(ms_1_predict, y_test) |
| 1459 | + ms_None_ms = mean_squared_error(ms_None_predict, y_test) |
| 1460 | + |
| 1461 | + assert np.all(ms_1_ms == ms_None_ms) |
| 1462 | + |
| 1463 | + |
| 1464 | +@pytest.mark.parametrize('name', FOREST_CLASSIFIERS) |
| 1465 | +def test_max_samples_boundary_classifiers(name): |
| 1466 | + rng = np.random.RandomState(1) |
| 1467 | + |
| 1468 | + X_train = rng.randn(10000, 2) |
| 1469 | + y_train = rng.randn(10000) > 0 |
| 1470 | + X_test = rng.randn(1000, 2) |
| 1471 | + |
| 1472 | + ms_1_proba = FOREST_CLASSIFIERS[name]( |
| 1473 | + max_samples=1.0, random_state=0).fit( |
| 1474 | + X_train, y_train).predict_proba(X_test) |
| 1475 | + ms_None_proba = FOREST_CLASSIFIERS[name]( |
| 1476 | + max_samples=None, random_state=0).fit( |
| 1477 | + X_train, y_train).predict_proba(X_test) |
| 1478 | + |
| 1479 | + assert np.all(ms_1_proba == ms_None_proba) |
| 1480 | + |
| 1481 | + |
1444 | 1482 | def test_forest_y_sparse():
|
1445 | 1483 | X = [[1, 2, 3]]
|
1446 | 1484 | y = csr_matrix([4, 5, 6])
|
|
0 commit comments