MultiOutputClassifier.predict_proba fails if targets have different number of values · Issue #8093 · scikit-learn/scikit-learn · GitHub
MultiOutputClassifier.predict_proba fails if targets have different number of values #8093
Closed
@pjbull

Description

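If two target columns are categorical and have a different number of unique values, MultiOutputClassifier.predict_proba raises a ValueError when it tries to combine the per-output probability matrices with np.dstack. Each matrix has one column per class, so the matrices have different widths and cannot be stacked.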
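A NumPy-only sketch of this failure mode, assuming (as in the reproduction below) one output with 2 classes and one with 3. The probability matrices here are hypothetical uniform placeholders rather than real classifier output; only their shapes matter.

```python
import numpy as np

rng = np.random.default_rng(0)
n_samples = 100

# Two categorical target columns with different label sets (2 vs. 3 values).
Y = np.column_stack([
    rng.choice(["a", "b"], n_samples),
    rng.choice(["d", "e", "f"], n_samples),
])

# One probability matrix per output, shaped (n_samples, n_classes_for_that_output).
# Uniform placeholders stand in for each estimator's predict_proba result.
probas = []
for j in range(Y.shape[1]):
    n_classes = np.unique(Y[:, j]).size
    probas.append(np.full((n_samples, n_classes), 1.0 / n_classes))

print([p.shape for p in probas])

# This mirrors the np.dstack call in MultiOutputClassifier.predict_proba (0.18.1).
try:
    stacked = np.dstack(probas)
except ValueError as exc:
    stacked = None
    print("np.dstack failed:", exc)
```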

Steps/Code to Reproduce

Example:

from sklearn.linear_model import LogisticRegression
from sklearn.multioutput import MultiOutputClassifier

import numpy as np

# random features
X = np.random.normal(size=(100, 100))

# random labels
Y = np.concatenate([
        np.random.choice(['a', 'b'], (100, 1)),     # first column can have 2 values
        np.random.choice(['d', 'e', 'f'], (100, 1)) # second column can have 3 values
    ], axis=1)

clf = MultiOutputClassifier(LogisticRegression())

clf.fit(X, Y)

clf.predict_proba(X)

Expected Results

No error is thrown. RandomForestClassifier.predict_proba handles multi-output targets of this shape and returns a list of NumPy arrays, one per output. I would expect MultiOutputClassifier.predict_proba to behave the same way.
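A minimal sketch of the expected behaviour, assuming the fix is simply to skip np.dstack and return the per-output arrays as a list (the layout RandomForestClassifier.predict_proba uses for multi-output targets). UniformProbaEstimator and predict_proba_per_output are hypothetical stand-ins, not scikit-learn API, so the example runs with NumPy alone.

```python
import numpy as np


class UniformProbaEstimator:
    """Hypothetical stand-in for a fitted classifier with a fixed label set."""

    def __init__(self, classes):
        self.classes_ = np.asarray(classes)

    def predict_proba(self, X):
        # Uniform probabilities, shape (n_samples, n_classes).
        n_classes = self.classes_.size
        return np.full((len(X), n_classes), 1.0 / n_classes)


def predict_proba_per_output(estimators, X):
    """Hypothetical helper: one (n_samples, n_classes_j) array per output.

    A list tolerates outputs with different numbers of classes, unlike np.dstack.
    """
    return [est.predict_proba(X) for est in estimators]


X = np.zeros((5, 4))
estimators = [UniformProbaEstimator(["a", "b"]), UniformProbaEstimator(["d", "e", "f"])]
probas = predict_proba_per_output(estimators, X)
print([p.shape for p in probas])  # [(5, 2), (5, 3)]

# Per-output class predictions can still be recovered from the list.
preds = np.column_stack(
    [est.classes_[p.argmax(axis=1)] for est, p in zip(estimators, probas)]
)
print(preds.shape)  # (5, 2)
```

Returning a list keeps each output's class axis intact, so callers index probas[j] and pair it with the j-th estimator's classes_, rather than relying on a 3-D array that only exists when every output has the same number of classes.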

Actual Results

---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
<ipython-input-40-84c34a558c92> in <module>()
     18 clf.fit(X, Y)
     19 
---> 20 clf.predict_proba(X)

anaconda/lib/python3.5/site-packages/sklearn/multioutput.py in predict_proba(self, X)
    224 
    225         results = np.dstack([estimator.predict_proba(X) for estimator in
--> 226                             self.estimators_])
    227         return results
    228 

anaconda/lib/python3.5/site-packages/numpy/lib/shape_base.py in dstack(tup)
    366 
    367     """
--> 368     return _nx.concatenate([atleast_3d(_m) for _m in tup], 2)
    369 
    370 def _replace_zero_by_x_arrays(sub_arys):

ValueError: all the input array dimensions except for the concatenation axis must match exactly
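As the traceback shows, np.dstack calls atleast_3d on each input and then concatenates along axis 2, so every other axis must already agree. The sketch below uses zero-filled placeholder arrays (an assumption; only the shapes matter) to show that a 2-class matrix becomes (100, 2, 1) and a 3-class matrix becomes (100, 3, 1), which differ on axis 1, whereas outputs with the same number of classes stack without error.

```python
import numpy as np

two_class = np.zeros((100, 2))
three_class = np.zeros((100, 3))

# atleast_3d appends a trailing axis to 2-D inputs: (M, N) -> (M, N, 1).
print(np.atleast_3d(two_class).shape)    # (100, 2, 1)
print(np.atleast_3d(three_class).shape)  # (100, 3, 1)

# Equal class counts: axis 1 agrees, so stacking along axis 2 works.
same_width = np.dstack([two_class, np.zeros((100, 2))])
print(same_width.shape)  # (100, 2, 2)

# Different class counts: axis 1 differs (2 vs. 3), so concatenation fails.
try:
    np.dstack([two_class, three_class])
    mixed_failed = False
except ValueError:
    mixed_failed = True
print(mixed_failed)  # True
```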

Versions

Darwin-15.6.0-x86_64-i386-64bit
Python 3.5.2 |Anaconda custom (x86_64)| (default, Jul  2 2016, 17:52:12) 
[GCC 4.2.1 Compatible Apple LLVM 4.2 (clang-425.0.28)]
NumPy 1.11.1
SciPy 0.18.1
Scikit-Learn 0.18.1

If returning a list is the right fix, I'm happy to submit a PR for this.
