8000 TST add test for pipeline in partial dependence (#14079) · scikit-learn/scikit-learn@4a6264d · GitHub
[go: up one dir, main page]

Skip to content

Commit 4a6264d

Browse files
glemaitrejnothman
authored andcommitted
TST add test for pipeline in partial dependence (#14079)
1 parent fd1d210 commit 4a6264d

File tree

1 file changed

+27
-0
lines changed

1 file changed

+27
-0
lines changed

sklearn/inspection/tests/test_partial_dependence.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,9 @@
2222
from sklearn.datasets import make_classification, make_regression
2323
from sklearn.cluster import KMeans
2424
from sklearn.metrics import r2_score
25+
from sklearn.pipeline import make_pipeline
2526
from sklearn.preprocessing import PolynomialFeatures
27+
from sklearn.preprocessing import StandardScaler
2628
from sklearn.dummy import DummyClassifier
2729
from sklearn.base import BaseEstimator, ClassifierMixin
2830
from sklearn.utils.testing import assert_allclose
@@ -393,6 +395,31 @@ def test_partial_dependence_sample_weight():
393395
assert np.corrcoef(pdp, values)[0, 1] > 0.99
394396

395397

398+
def test_partial_dependence_pipeline():
399+
# check that the partial dependence support pipeline
400+
iris = load_iris()
401+
402+
scaler = StandardScaler()
403+
clf = DummyClassifier(random_state=42)
404+
pipe = make_pipeline(scaler, clf)
405+
406+
clf.fit(scaler.fit_transform(iris.data), iris.target)
407+
pipe.fit(iris.data, iris.target)
408+
409+
features = 0
410+
pdp_pipe, values_pipe = partial_dependence(
411+
pipe, iris.data, features=[features]
412+
)
413+
pdp_clf, values_clf = partial_dependence(
414+
clf, scaler.transform(iris.data), features=[features]
415+
)
416+
assert_allclose(pdp_pipe, pdp_clf)
417+
assert_allclose(
418+
values_pipe[0],
419+
values_clf[0] * scaler.scale_[features] + scaler.mean_[features]
420+
)
421+
422+
396423
def test_plot_partial_dependence(pyplot):
397424
# Test partial dependence plot function.
398425
boston = load_boston()

0 commit comments

Comments
 (0)
0