diff --git a/sklearn/pipeline.py b/sklearn/pipeline.py index 61d7b12b7564d..af3fd704f614e 100644 --- a/sklearn/pipeline.py +++ b/sklearn/pipeline.py @@ -7,6 +7,8 @@ # Virgile Fritsch # Alexandre Gramfort # Lars Buitinck +# Joel Nothman +# Guillaume Lemaitre # License: BSD from collections import defaultdict @@ -233,6 +235,41 @@ def named_steps(self): def _final_estimator(self): return self.steps[-1][1] + def get_subsequence(self, start=None, stop=None): + """Extract a Pipeline consisting of a subsequence of steps + + Parameters + ---------- + start : int or str, optional + The index (0-based) or name of the step where the extracted + subsequence begins (inclusive bound). By default, get from the + beginning. Negative integers are interpreted as subtracted from the + number of steps in the Pipeline. + stop : int or str, optional + The index (0-based) or name of the step before which the extracted + subsequence ends (exclusive bound). By default, get until the end. + Negative integers are interpreted as subtracted from the number of + steps in the Pipeline. + + Returns + ------- + sub_pipeline : Pipeline instance + The steps of this pipeline range from the start step to the stop + step specified. The constituent estimators are not copied: if the + Pipeline had been fit, so will be the returned Pipeline. + + The return type will be of the same type as self, if a subclass + is used and if its constructor is compatible. 
+ """ + if isinstance(start, six.string_types): + start = [name for name, _ in self.steps].index(start) + if isinstance(stop, six.string_types): + stop = [name for name, _ in self.steps].index(stop) + + kwargs = self.get_params(deep=False) + kwargs['steps'] = self.steps[start:stop] + return type(self)(**kwargs) + # Estimator interface def _fit(self, X, y=None, **fit_params): diff --git a/sklearn/tests/test_pipeline.py b/sklearn/tests/test_pipeline.py index 33e3128931aff..4ef4faaf4f210 100644 --- a/sklearn/tests/test_pipeline.py +++ b/sklearn/tests/test_pipeline.py @@ -894,3 +894,55 @@ def test_pipeline_memory(): assert_equal(ts, cached_pipe_2.named_steps['transf_2'].timestamp_) finally: shutil.rmtree(cachedir) + + +class MyPipeline(Pipeline): + pass + + +class MyPipelineNoMemory(Pipeline): + def __init__(self, steps, other_param): + super(MyPipelineNoMemory, self).__init__(steps) + self.other_param = other_param + + +def test_pipeline_get_subsequence(): + pipe = Pipeline([('transf1', Transf()), + ('transf2', Transf()), + ('predict', Mult())]) + pipe.fit(np.arange(5)[:, None], np.arange(5)) + + for start, stop, expected_slice in [ + (None, None, slice(None, None)), + ('transf2', None, slice(1, None)), + (None, 'predict', slice(None, 2)), + (1, 'predict', slice(1, 2)), + (1, -1, slice(1, -1)), + (-1, None, slice(-1, None)), + ]: + new_pipe = pipe.get_subsequence(start, stop) + expected_steps = pipe.steps[expected_slice] + assert_equal(new_pipe.steps, expected_steps) + assert_dict_equal(new_pipe.named_steps, dict(expected_steps)) + for name in new_pipe.named_steps: + assert_true(new_pipe.named_steps[name] is pipe.named_steps[name]) + + # invalid step name + assert_raise_message(ValueError, "'foo' is not in list", + pipe.get_subsequence, 'foo') + + # test subtype is maintained by get_subsequence + for memory in [None, '/path/to/somewhere']: + pipe = MyPipeline([('transf1', Transf()), + ('predict', Mult())], + memory=memory) + new_pipe = pipe.get_subsequence(1) + 
assert_equal(new_pipe.steps, pipe.steps[1:]) + assert_equal(pipe.memory, new_pipe.memory) + + pipe = MyPipelineNoMemory([('transf1', Transf()), + ('predict', Mult())], + other_param='blah') + new_pipe = pipe.get_subsequence(1) + assert_equal(new_pipe.steps, pipe.steps[1:]) + assert_equal(pipe.other_param, new_pipe.other_param)