input validation with shape (0,N>0) for RandomForestClassifier, DecisionTreeClassifier, others? · Issue #1793 · scikit-learn/scikit-learn · GitHub

input validation with shape (0,N>0) for RandomForestClassifier, DecisionTreeClassifier, others? #1793


Closed
erg opened this issue Mar 20, 2013 · 11 comments
Labels
Bug Easy Well-defined and straightforward way to resolve

Comments

@erg
Contributor
erg commented Mar 20, 2013
import numpy as np
from sklearn.ensemble import RandomForestClassifier

X = np.ones(shape=(0,1))
y = np.ones(shape=(0,1))

rfc = RandomForestClassifier()

In [36]: rfc.fit(X,y)
---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
<ipython-input-36-eec77b0a9246> in <module>()
----> 1 rfc.fit(XX,yy)

/usr/local/lib/python2.7/site-packages/sklearn/ensemble/forest.pyc in fit(self, X, y, sample_weight)
    363                 random_state.randint(MAX_INT),
    364                 verbose=self.verbose)
--> 365             for i in xrange(n_jobs))
    366 
    367         # Reduce

/usr/local/lib/python2.7/site-packages/sklearn/externals/joblib/parallel.pyc in __call__(self, iterable)
    512         try:
    513             for function, args, kwargs in iterable:
--> 514                 self.dispatch(function, args, kwargs)
    515 
    516             self.retrieve()

/usr/local/lib/python2.7/site-packages/sklearn/externals/joblib/parallel.pyc in dispatch(self, func, args, kwargs)
    309         """
    310         if self._pool is None:
--> 311             job = ImmediateApply(func, args, kwargs)
    312             index = len(self._jobs)
    313             if not _verbosity_filter(index, self.verbose):

/usr/local/lib/python2.7/site-packages/sklearn/externals/joblib/parallel.pyc in __init__(self, func, args, kwargs)
    133         # Don't delay the application, to avoid keeping the input
    134         # arguments in memory
--> 135         self.results = func(*args, **kwargs)
    136 
    137     def get(self):

/usr/local/lib/python2.7/site-packages/sklearn/ensemble/forest.pyc in _parallel_build_trees(n_trees, forest, X, y, sample_weight, sample_mask, X_argsorted, seed, verbose)
     86                 curr_sample_weight = sample_weight.copy()
     87 
---> 88             indices = random_state.randint(0, n_samples, n_samples)
     89             sample_counts = bincount(indices, minlength=n_samples)
     90 

/usr/local/lib/python2.7/site-packages/numpy/random/mtrand.so in mtrand.RandomState.randint (numpy/random/mtrand/mtrand.c:6443)()

ValueError: low >= high

Also for decision trees:

from sklearn.tree import DecisionTreeClassifier

X = np.ones(shape=(0,1))
y = np.ones(shape=(0,1))

dtc = DecisionTreeClassifier()
dtc.fit(X,y)

---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
<ipython-input-41-bd6795874945> in <module>()
----> 1 dtc.fit(X,y)

/usr/local/lib/python2.7/site-packages/sklearn/tree/tree.pyc in fit(self, X, y, sample_mask, X_argsorted, check_input, sample_weight)
    358                          sample_weight=sample_weight,
    359                          sample_mask=sample_mask,
--> 360                          X_argsorted=X_argsorted)
    361 
    362         if self.n_outputs_ == 1:

/usr/local/lib/python2.7/site-packages/sklearn/tree/_tree.so in sklearn.tree._tree.Tree.build (sklearn/tree/_tree.c:4823)()

/usr/local/lib/python2.7/site-packages/sklearn/tree/_tree.so in sklearn.tree._tree.Tree.build (sklearn/tree/_tree.c:4636)()

/usr/local/lib/python2.7/site-packages/sklearn/tree/_tree.so in sklearn.tree._tree.Tree.recursive_partition (sklearn/tree/_tree.c:5156)()

ValueError: Attempting to find a split with an empty sample_mask.
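
Both tracebacks share a root cause: an array with zero rows reaches the tree-building code before anything checks its shape. A minimal sketch of an up-front check that would give a clearer message is below. `require_samples` is a hypothetical helper written for illustration, not a scikit-learn function, and the message wording is an assumption.

```python
import numpy as np


def require_samples(X, min_samples=1):
    """Raise a readable ValueError when X has too few rows.

    Hypothetical helper for illustration; not part of scikit-learn.
    """
    X = np.asarray(X)
    if X.ndim != 2:
        raise ValueError("Expected a 2D array, got %d dimension(s)." % X.ndim)
    n_samples = X.shape[0]
    if n_samples < min_samples:
        raise ValueError(
            "Found %d sample(s) (shape=%r); at least %d required."
            % (n_samples, X.shape, min_samples)
        )
    return X


# The empty input from the report is rejected immediately, before any fitting.
try:
    require_samples(np.ones(shape=(0, 1)))
    message = None
except ValueError as exc:
    message = str(exc)

print(message)
```

Calling such a check at the top of `fit` would fail fast with the shape in the message, instead of surfacing `low >= high` from the random number generator.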
@jaquesgrobler
Member

This might be a silly question, but would you actually want to pass in (0,1)? Or is it just to be thorough and have an error message for that case?

By the way, it bombs for (1,1) too (ValueError: ndarray is not Fortran contiguous), meaning
X = np.ones(shape=(1,1))
y = np.ones(shape=(1,1))

So should that be included in the input validation?
I haven't looked at which other models do this yet.


@erg
Contributor Author
erg commented Mar 20, 2013

Your (1,1) shape works for me.

It's just a nonsensical input that I hit by accident, and I noticed the error message sucks.

@jaquesgrobler
Member

Okay, I thought it was something like that :)
I'll have a look in a wee bit to see if I can reproduce my (1,1) error again. If I can, I'll post the traceback.

@jaquesgrobler
Member

Regarding the (1,1) case, here's my trace:

In [20]: X = np.ones(shape=(1,1))

In [21]: y = np.ones(shape=(1,1))

In [22]: rfc = RandomForestClassifier()

In [23]: rfc.fit(X,y)
---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)

/home/jaques/<ipython console> in <module>()

/home/jaques/scikit-learn/sklearn/ensemble/forest.pyc in fit(self, X, y, sample_weight)
    372                 seeds[i],
    373                 verbose=self.verbose)
--> 374             for i in range(n_jobs))
    375 
    376         # Reduce


/home/jaques/scikit-learn/sklearn/externals/joblib/parallel.pyc in __call__(self, iterable)
    512         try:
    513             for function, args, kwargs in iterable:
--> 514                 self.dispatch(function, args, kwargs)
    515 
    516             self.retrieve()

/home/jaques/scikit-learn/sklearn/externals/joblib/parallel.pyc in dispatch(self, func, args, kwargs)
    309         """
    310         if self._pool is None:
--> 311             job = ImmediateApply(func, args, kwargs)
    312             index = len(self._jobs)
    313             if not _verbosity_filter(index, self.verbose):

/home/jaques/scikit-learn/sklearn/externals/joblib/parallel.pyc in __init__(self, func, args, kwargs)
    133         # Don't delay the application, to avoid keeping the input
    134         # arguments in memory
--> 135         self.results = func(*args, **kwargs)
    136 
    137     def get(self):

/home/jaques/scikit-learn/sklearn/ensemble/forest.pyc in _parallel_build_trees(n_trees, forest, X, y, sample_weight, sample_mask, X_argsorted, seeds, verbose)
     97                      sample_mask=curr_sample_mask,
     98                      X_argsorted=X_argsorted,
---> 99                      check_input=False)
    100 
    101             tree.indices_ = curr_sample_mask

/home/jaques/scikit-learn/sklearn/tree/tree.pyc in fit(self, X, y, sample_mask, X_argsorted, check_input, sample_weight)
    370                          sample_weight=sample_weight,
    371                          sample_mask=sample_mask,
--> 372                          X_argsorted=X_argsorted)
    373 
    374         if self.n_outputs_ == 1:

/home/jaques/scikit-learn/sklearn/tree/_tree.so in sklearn.tree._tree.Tree.build (sklearn/tree/_tree.c:4800)()

/home/jaques/scikit-learn/sklearn/tree/_tree.so in sklearn.tree._tree.Tree.build (sklearn/tree/_tree.c:4613)()

/home/jaques/scikit-learn/sklearn/tree/_tree.so in sklearn.tree._tree.Tree.recursive_partition (sklearn/tree/_tree.c:4905)()

ValueError: ndarray is not Fortran contiguous

You can't reproduce this? 😕
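
The message refers to memory layout: the compiled tree code in this traceback expects column-major (Fortran-ordered) data. The sketch below only shows what that flag means and how NumPy converts between layouts. It is not a confirmed workaround for the error above, since the forest calls the tree's `fit` with `check_input=False` here.

```python
import numpy as np

# NumPy allocates arrays in C (row-major) order by default.
X = np.ones(shape=(3, 2))
print(X.flags.c_contiguous)   # True
print(X.flags.f_contiguous)   # False for this 3x2 array

# np.asfortranarray returns the same values stored column-major.
X_f = np.asfortranarray(X)
print(X_f.flags.f_contiguous)  # True
print(np.array_equal(X, X_f))  # True: same values, different memory layout
```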

@glouppe
Contributor
glouppe commented Mar 21, 2013

Such tricky cases should be tested in the common tests, in my opinion. I am quite sure trees are not the only models to crash in those cases (unfortunately).

My 2 cents.
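
One way to read this suggestion is a single shared check that every estimator must pass: fitting on a `(0, n_features)` array should raise a `ValueError` that mentions samples, rather than crash somewhere deep in the fitting code. The sketch below uses two toy stand-in classes, `GuardedEstimator` and `UnguardedEstimator`, both hypothetical. It does not use real scikit-learn estimators or the project's actual common-test machinery.

```python
import numpy as np


class GuardedEstimator:
    """Toy estimator that validates its input (hypothetical)."""

    def fit(self, X, y):
        X = np.asarray(X)
        if X.shape[0] == 0:
            raise ValueError("Found array with 0 sample(s), shape=%r" % (X.shape,))
        self.mean_ = X.mean(axis=0)
        return self


class UnguardedEstimator:
    """Toy estimator that fails inside fit with an unrelated error (hypothetical)."""

    def fit(self, X, y):
        X = np.asarray(X)
        self.first_row_ = X[0]  # raises IndexError on empty input
        return self


def rejects_empty_input(estimator, n_features=2):
    """Return True if fit on a (0, n_features) array raises a ValueError about samples."""
    X = np.ones(shape=(0, n_features))
    y = np.ones(shape=(0,))
    try:
        estimator.fit(X, y)
    except ValueError as exc:
        return "sample" in str(exc)
    except Exception:
        return False  # crashed, but not with a helpful validation error
    return False  # silently accepted empty input


# Run the same check over every estimator, as a common test would.
results = {
    type(est).__name__: rejects_empty_input(est)
    for est in (GuardedEstimator(), UnguardedEstimator())
}
print(results)
```

Looping over all estimators with one check is what makes the gap visible: any model that fails it shows up by name.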

On 21 March 2013 11:36, Jaques Grobler wrote:

> Regarding the (1,1) case, here's my trace:
> [...]
> ValueError: ndarray is not Fortran contiguous
>
> You can't reproduce this? 😕

@jaquesgrobler
Member

So then these fellows should all be added to the input validation, I guess. I'll have a look at which other models also do this.

@amueller
Member

Closing as a duplicate of #1678. Totally agree with @glouppe. This is known to be untested and probably fails a lot.

@amueller amueller closed this as completed Mar 31, 2013
@dwiel
dwiel commented Mar 6, 2015

I'm still getting this error:

Python 2.7.9 |Anaconda 2.0.0 (64-bit)| (default, Dec 15 2014, 10:33:51) 
[GCC 4.4.7 20120313 (Red Hat 4.4.7-1)] on linux2
Type "help", "copyright", "credits" or "license" for more information.
Anaconda is brought to you by Continuum Analytics.
Please check out: http://continuum.io/thanks and https://binstar.org
>>> import sklearn
>>> sklearn.__version__
'0.15.2'
>>> from sklearn.ensemble import RandomForestClassifier
>>> import numpy as np
>>> 
>>> X = np.ones(shape=(0,2))
>>> y = np.ones(shape=(0,2))
>>> 
>>> rfc = RandomForestClassifier()
>>> 
>>> rfc.fit(X,y)
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/home/ubuntu/anaconda/lib/python2.7/site-packages/sklearn/ensemble/forest.py", line 279, in fit
    for i in range(n_jobs))
  File "/home/ubuntu/anaconda/lib/python2.7/site-packages/sklearn/externals/joblib/parallel.py", line 653, in __call__
    self.dispatch(function, args, kwargs)
  File "/home/ubuntu/anaconda/lib/python2.7/site-packages/sklearn/externals/joblib/parallel.py", line 400, in dispatch
    job = ImmediateApply(func, args, kwargs)
  File "/home/ubuntu/anaconda/lib/python2.7/site-packages/sklearn/externals/joblib/parallel.py", line 138, in __init__
    self.results = func(*args, **kwargs)
  File "/home/ubuntu/anaconda/lib/python2.7/site-packages/sklearn/ensemble/forest.py", line 83, in _parallel_build_trees
    indices = random_state.randint(0, n_samples, n_samples)
  File "mtrand.pyx", line 941, in mtrand.RandomState.randint (numpy/random/mtrand/mtrand.c:9569)
ValueError: low >= high

It appears that #1678 refers to inputs of shape (1, n_features), whereas this is (0, n_features). There is also #2293, but that only checks (1, n_features) as well. It would still be helpful to get a better error message for (0, n_features).

@ogrisel
Member
ogrisel commented Mar 6, 2015

Please try on the current master. I get:

>>> rfc.fit(X, y)
Traceback (most recent call last):
  File "<ipython-input-5-b17a4c91c80c>", line 1, in <module>
    rfc.fit(X, y)
  File "/Users/ogrisel/code/scikit-learn/sklearn/ensemble/forest.py", line 195, in fit
    X = check_array(X, dtype=DTYPE, accept_sparse="csc")
  File "/Users/ogrisel/code/scikit-learn/sklearn/utils/validation.py", line 354, in check_array
    % (n_samples, shape_repr, ensure_min_samples))
ValueError: Found array with 0 sample(s) (shape=(0, 2)) while a minimum of 1 is required.
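
The traceback shows the check coming from `check_array` in `sklearn.utils.validation`, with an `ensure_min_samples` argument. A short sketch of calling it directly is below. It assumes a scikit-learn version that includes `check_array`, which the 0.15.2 report above suggests was not yet the case for that release's forest code.

```python
import numpy as np
from sklearn.utils import check_array

X = np.ones(shape=(0, 2))

# check_array validates shape and dtype, and rejects inputs with too few rows.
try:
    check_array(X, ensure_min_samples=1)
    message = None
except ValueError as exc:
    message = str(exc)

print(message)
```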

@dwiel
dwiel commented Mar 6, 2015

Shoot, I'm sorry. I saw that the timestamps of all the comments were well before 0.15.2 was released, so I figured testing on that version would be sufficient. Thanks!

On Fri, Mar 6, 2015 at 9:22 AM, Olivier Grisel wrote:

> Please try on the current master. I get:
> [...]
> ValueError: Found array with 0 sample(s) (shape=(0, 2)) while a minimum of 1 is required.

@ogrisel
Member
ogrisel commented Mar 6, 2015

No problem. We should release more often ;P

6 participants