From 1f63ead2e8c8a19a719f452ae055b8be875849a7 Mon Sep 17 00:00:00 2001 From: Joel Nothman Date: Wed, 22 Feb 2017 18:12:22 +1100 Subject: [PATCH] ENH allow extraction of subsequence pipeline Conceptually Fixes #8414 and related issues. Alternative to #2568 without __getitem__ and mixed semantics. Designed to assist in model inspection and particularly to replicate the composite transformer represented by steps of the pipeline with the exception of the last. I.e. pipe.get_subsequence(0, -1) is a common idiom. I feel like this becomes more necessary when considering more API-consistent clone behaviour as per #8350 as Pipeline(pipe.steps[:-1]) is no longer possible. --- sklearn/pipeline.py | 37 ++++++++++++++++++++++++ sklearn/tests/test_pipeline.py | 52 ++++++++++++++++++++++++++++++++++ 2 files changed, 89 insertions(+) diff --git a/sklearn/pipeline.py b/sklearn/pipeline.py index 61d7b12b7564d..af3fd704f614e 100644 --- a/sklearn/pipeline.py +++ b/sklearn/pipeline.py @@ -7,6 +7,8 @@ # Virgile Fritsch # Alexandre Gramfort # Lars Buitinck +# Joel Nothman +# Guillaume Lemaitre # License: BSD from collections import defaultdict @@ -233,6 +235,41 @@ def named_steps(self): def _final_estimator(self): return self.steps[-1][1] + def get_subsequence(self, start=None, stop=None): + """Extract a Pipeline consisting of a subsequence of steps + + Parameters + ---------- + start : int or str, optional + The index (0-based) or name of the step where the extracted + subsequence begins (inclusive bound). By default, get from the + beginning. Negative integers are intpreted as subtracted from the + number of steps in the Pipeline. + stop : int or str, optional + The index (0-based) or name of the step before which the extracted + subsequence ends (exclusive bound). By default, get until the end. + Negative integers are intpreted as subtracted from the number of + steps in the Pipeline. + + Returns + ------- + sub_pipeline : Pipeline instance + The steps of this pipeline range from the start step to the stop + step specified. The constituent estimators are not copied: if the + Pipeline had been fit, so will be the returned Pipeline. + + The return type will be of the same type as self, if a subclass + is used and if its constructor is compatible. + """ + if isinstance(start, six.string_types): + start = [name for name, _ in self.steps].index(start) + if isinstance(stop, six.string_types): + stop = [name for name, _ in self.steps].index(stop) + + kwargs = self.get_params(deep=False) + kwargs['steps'] = self.steps[start:stop] + return type(self)(**kwargs) + # Estimator interface def _fit(self, X, y=None, **fit_params): diff --git a/sklearn/tests/test_pipeline.py b/sklearn/tests/test_pipeline.py index 33e3128931aff..4ef4faaf4f210 100644 --- a/sklearn/tests/test_pipeline.py +++ b/sklearn/tests/test_pipeline.py @@ -894,3 +894,55 @@ def test_pipeline_memory(): assert_equal(ts, cached_pipe_2.named_steps['transf_2'].timestamp_) finally: shutil.rmtree(cachedir) + + +class MyPipeline(Pipeline): + pass + + +class MyPipelineNoMemory(Pipeline): + def __init__(self, steps, other_param): + super(MyPipelineNoMemory, self).__init__(steps) + self.other_param = other_param + + +def test_pipeline_get_subsequence(): + pipe = Pipeline([('transf1', Transf()), + ('transf2', Transf()), + ('predict', Mult())]) + pipe.fit(np.arange(5)[:, None], np.arange(5)) + + for start, stop, expected_slice in [ + (None, None, slice(None, None)), + ('transf2', None, slice(1, None)), + (None, 'predict', slice(None, 2)), + (1, 'predict', slice(1, 2)), + (1, -1, slice(1, -1)), + (-1, None, slice(-1, None)), + ]: + new_pipe = pipe.get_subsequence(start, stop) + expected_steps = pipe.steps[expected_slice] + assert_equal(new_pipe.steps, expected_steps) + assert_dict_equal(new_pipe.named_steps, dict(expected_steps)) + for name in new_pipe.named_steps: + assert_true(new_pipe.named_steps[name] is pipe.named_steps[name]) + + # invalid step name + assert_raise_message(ValueError, "'foo' is not in list", + pipe.get_subsequence, 'foo') + + # test subtype is maintained by get_subsequence + for memory in [None, '/path/to/somewhere']: + pipe = MyPipeline([('transf1', Transf()), + ('predict', Mult())], + memory=memory) + new_pipe = pipe.get_subsequence(1) + assert_equal(new_pipe.steps, pipe.steps[1:]) + assert_equal(pipe.memory, new_pipe.memory) + + pipe = MyPipelineNoMemory([('transf1', Transf()), + ('predict', Mult())], + other_param='blah') + new_pipe = pipe.get_subsequence(1) + assert_equal(new_pipe.steps, pipe.steps[1:]) + assert_equal(pipe.other_param, new_pipe.other_param)