8000 add _last_non_passthrough_estimator · scikit-learn/scikit-learn@bc9df43 · GitHub
[go: up one dir, main page]

Skip to content

Commit bc9df43

Browse files
committed
add _last_non_passthrough_estimator
1 parent 3870302 commit bc9df43

File tree

2 files changed

+11
-0
lines changed

2 files changed

+11
-0
lines changed

sklearn/pipeline.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -244,6 +244,15 @@ def _final_estimator(self):
244244
estimator = self.steps[-1][1]
245245
return 'passthrough' if estimator is None else estimator
246246

247+
@property
248+
def _final_non_passthrough_estimator(self):
249+
final_estimator = None
250+
for name, est in reversed(self.steps):
251+
if est not in [None, 'passthrough']:
252+
final_estimator = est
253+
break
254+
return final_estimator
255+
247256
def _log_message(self, step_idx):
248257
if not self.verbose:
249258
return None

sklearn/tests/test_pipeline.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -699,6 +699,7 @@ def make():
699699
assert_array_equal([[exp]], pipeline.fit_transform(X, y))
700700
assert_array_equal([exp], pipeline.fit(X).predict(X))
701701
assert_array_equal(X, pipeline.inverse_transform([[exp]]))
702+
assert pipeline._final_non_passthrough_estimator is mult5
702703

703704
pipeline = make()
704705
pipeline.set_params(last=passthrough)
@@ -707,6 +708,7 @@ def make():
707708
assert_array_equal([[exp]], pipeline.fit(X, y).transform(X))
708709
assert_array_equal([[exp]], pipeline.fit_transform(X, y))
709710
assert_array_equal( 5545 X, pipeline.inverse_transform([[exp]]))
711+
assert pipeline._final_non_passthrough_estimator is mult3
710712
assert_raise_message(AttributeError,
711713
"'str' object has no attribute 'predict'",
712714
getattr, pipeline, 'predict')

0 commit comments

Comments
 (0)
0