add test for `max_samples` boundary · scikit-learn/scikit-learn@3f8a2a6 · GitHub
[go: up one dir, main page]

Skip to content

Commit 3f8a2a6

Browse files
committed
add test for max_samples boundary
1 parent e6a7e85 commit 3f8a2a6

File tree

1 file changed

+38
-0
lines changed

1 file changed

+38
-0
lines changed

sklearn/ensemble/tests/test_forest.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,8 @@
5050
from sklearn.svm import LinearSVC
5151
from sklearn.utils.validation import check_random_state
5252

53+
from sklearn.metrics import mean_squared_error
54+
5355
from sklearn.tree._classes import SPARSE_SPLITTERS
5456

5557

@@ -1441,6 +1443,42 @@ def test_max_samples_exceptions(name, max_samples, exc_type, exc_msg):
14411443
est.fit(X, y)
14421444

14431445

1446+
@pytest.mark.parametrize('name', FOREST_REGRESSORS)
def test_max_samples_boundary_regressors(name):
    """Check that ``max_samples=1.0`` behaves like ``max_samples=None``.

    Two forests that differ only in how "use every sample" is spelled
    (the float boundary ``1.0`` versus the default ``None``) are fitted with
    the same ``random_state`` and must reach the same test-set error.

    Parameters
    ----------
    name : str
        Key into ``FOREST_REGRESSORS`` selecting the estimator under test.
    """
    X_train, X_test, y_train, y_test = train_test_split(
        X_reg, y_reg, train_size=0.7, test_size=0.3, random_state=0)

    # ``max_samples`` is documented as applying to bootstrap sampling, so
    # bootstrap is requested explicitly; otherwise estimators that do not
    # bootstrap by default would make this comparison trivially true.
    ms_1_model = FOREST_REGRESSORS[name](
        bootstrap=True, max_samples=1.0, random_state=0)
    ms_1_predict = ms_1_model.fit(X_train, y_train).predict(X_test)

    ms_None_model = FOREST_REGRESSORS[name](
        bootstrap=True, max_samples=None, random_state=0)
    ms_None_predict = ms_None_model.fit(X_train, y_train).predict(X_test)

    # mean_squared_error takes (y_true, y_pred) in that order.
    ms_1_ms = mean_squared_error(y_test, ms_1_predict)
    ms_None_ms = mean_squared_error(y_test, ms_None_predict)

    # Both values are float scalars: compare with a tolerance rather than
    # relying on bit-for-bit equality.
    assert ms_1_ms == pytest.approx(ms_None_ms)
1462+
1463+
1464+
@pytest.mark.parametrize('name', FOREST_CLASSIFIERS)
def test_max_samples_boundary_classifiers(name):
    """Check that ``max_samples=1.0`` behaves like ``max_samples=None``.

    Two forests that differ only in how "use every sample" is spelled
    (the float boundary ``1.0`` versus the default ``None``) are fitted with
    the same ``random_state`` on synthetic data and must produce the same
    class probabilities.

    Parameters
    ----------
    name : str
        Key into ``FOREST_CLASSIFIERS`` selecting the estimator under test.
    """
    rng = np.random.RandomState(1)

    # Random two-feature inputs with a random boolean target; the draw order
    # (train X, train y, test X) is kept so the generated data is unchanged.
    X_train = rng.randn(10000, 2)
    y_train = rng.randn(10000) > 0
    X_test = rng.randn(1000, 2)

    # ``max_samples`` is documented as applying to bootstrap sampling, so
    # bootstrap is requested explicitly; otherwise estimators that do not
    # bootstrap by default would make this comparison trivially true.
    ms_1_model = FOREST_CLASSIFIERS[name](
        bootstrap=True, max_samples=1.0, random_state=0)
    ms_1_proba = ms_1_model.fit(X_train, y_train).predict_proba(X_test)

    ms_None_model = FOREST_CLASSIFIERS[name](
        bootstrap=True, max_samples=None, random_state=0)
    ms_None_proba = ms_None_model.fit(X_train, y_train).predict_proba(X_test)

    # assert_allclose tolerates floating-point noise and reports the
    # mismatching entries on failure, unlike ``assert np.all(a == b)``.
    np.testing.assert_allclose(ms_1_proba, ms_None_proba)
1480+
1481+
14441482
def test_forest_y_sparse():
14451483
X = [[1, 2, 3]]
14461484
y = csr_matrix([4, 5, 6])

0 commit comments

Comments
 (0)
0