[WIP] ENH allow extraction of subsequence pipeline by jnothman · Pull Request #8431 · scikit-learn/scikit-learn · GitHub
[go: up one dir, main page]

Skip to content

[WIP] ENH allow extraction of subsequence pipeline #8431

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
8000
Diff view
Diff view
37 changes: 37 additions & 0 deletions sklearn/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
# Virgile Fritsch
# Alexandre Gramfort
# Lars Buitinck
# Joel Nothman
# Guillaume Lemaitre
# License: BSD

from collections import defaultdict
Expand Down Expand Up @@ -233,6 +235,41 @@ def named_steps(self):
def _final_estimator(self):
return self.steps[-1][1]

def get_subsequence(self, start=None, stop=None):
    """Extract a Pipeline consisting of a subsequence of steps

    Parameters
    ----------
    start : int or str, optional
        The index (0-based) or name of the step where the extracted
        subsequence begins (inclusive bound). By default, get from the
        beginning. Negative integers are interpreted as subtracted from
        the number of steps in the Pipeline.
    stop : int or str, optional
        The index (0-based) or name of the step before which the extracted
        subsequence ends (exclusive bound). By default, get until the end.
        Negative integers are interpreted as subtracted from the number of
        steps in the Pipeline.

    Returns
    -------
    sub_pipeline : Pipeline instance
        The steps of this pipeline range from the start step to the stop
        step specified. The constituent estimators are not copied: if the
        Pipeline had been fit, so will be the returned Pipeline.

        The return type will be of the same type as self, if a subclass
        is used and if its constructor is compatible.

    Raises
    ------
    ValueError
        If ``start`` or ``stop`` is a string that does not match the name
        of any step.
    """
    def _as_index(bound):
        # Integers and None are valid slice bounds already; only step
        # names need translating. The name list is built lazily so that
        # purely positional calls never touch the step names.
        if not isinstance(bound, six.string_types):
            return bound
        return [name for name, _ in self.steps].index(bound)

    start = _as_index(start)
    stop = _as_index(stop)

    # Re-use every shallow constructor parameter of this instance and
    # override only ``steps``, so subclasses keep their own settings.
    init_params = self.get_params(deep=False)
    init_params['steps'] = self.steps[start:stop]
    return type(self)(**init_params)

# Estimator interface

def _fit(self, X, y=None, **fit_params):
Expand Down
52 changes: 52 additions & 0 deletions sklearn/tests/test_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -894,3 +894,55 @@ def test_pipeline_memory():
assert_equal(ts, cached_pipe_2.named_steps['transf_2'].timestamp_)
finally:
shutil.rmtree(cachedir)


class MyPipeline(Pipeline):
    """Pipeline subclass that adds nothing and keeps the parent constructor."""


class MyPipelineNoMemory(Pipeline):
    """Pipeline subclass taking ``other_param`` and no ``memory`` argument."""

    def __init__(self, steps, other_param):
        # Only ``steps`` is handed to the parent constructor.
        super(MyPipelineNoMemory, self).__init__(steps=steps)
        self.other_param = other_param


def test_pipeline_get_subsequence():
    """Check get_subsequence bounds, estimator sharing and subclass type."""
    pipe = Pipeline([('transf1', Transf()),
                     ('transf2', Transf()),
                     ('predict', Mult())])
    pipe.fit(np.arange(5)[:, None], np.arange(5))

    # Each case is (start, stop, equivalent slice over ``pipe.steps``).
    cases = [
        (None, None, slice(None, None)),
        ('transf2', None, slice(1, None)),
        (None, 'predict', slice(None, 2)),
        (1, 'predict', slice(1, 2)),
        (1, -1, slice(1, -1)),
        (-1, None, slice(-1, None)),
    ]
    for start, stop, expected_slice in cases:
        new_pipe = pipe.get_subsequence(start, stop)
        expected_steps = pipe.steps[expected_slice]
        assert_equal(new_pipe.steps, expected_steps)
        assert_dict_equal(new_pipe.named_steps, dict(expected_steps))
        # The estimators must be the very same objects, not copies.
        for name in new_pipe.named_steps:
            assert_true(new_pipe.named_steps[name] is pipe.named_steps[name])

    # invalid step name
    assert_raise_message(ValueError, "'foo' is not in list",
                         pipe.get_subsequence, 'foo')

    # test subtype is maintained by get_subsequence
    for memory in [None, '/path/to/somewhere']:
        pipe = MyPipeline([('transf1', Transf()),
                           ('predict', Mult())],
                          memory=memory)
        new_pipe = pipe.get_subsequence(1)
        # Assert the type explicitly: the step/memory checks below would
        # also pass if a plain Pipeline were returned.
        assert_true(isinstance(new_pipe, MyPipeline))
        assert_equal(new_pipe.steps, pipe.steps[1:])
        assert_equal(pipe.memory, new_pipe.memory)

    pipe = MyPipelineNoMemory([('transf1', Transf()),
                               ('predict', Mult())],
                              other_param='blah')
    new_pipe = pipe.get_subsequence(1)
    assert_true(isinstance(new_pipe, MyPipelineNoMemory))
    assert_equal(new_pipe.steps, pipe.steps[1:])
    assert_equal(pipe.other_param, new_pipe.other_param)
0