Exception When Calling cross_val_predict with Multi-Output Estimators and Methods Besides 'predict' · Issue #10028 · scikit-learn/scikit-learn · GitHub



Closed
waltaskew opened this issue Oct 27, 2017 · 2 comments


waltaskew commented Oct 27, 2017

Description

A ValueError is raised when calling cross_val_predict on a multi-output model with method set to 'decision_function', 'predict_proba', or 'predict_log_proba'.

Steps/Code to Reproduce

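import numpy
import sklearn.ensemble
import sklearn.model_selection

# 100 samples with 4 features each
X = numpy.arange(400).reshape(100, 4)
# Two target columns, i.e. a multi-output classification target of shape (100, 2)
y = numpy.arange(200).reshape(100, 2)
est = sklearn.ensemble.RandomForestClassifier()

# Raises ValueError on scikit-learn 0.19.1; method='predict' is not affected
sklearn.model_selection.cross_val_predict(est, X, y, method='predict_proba')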

Expected Results

No error is thrown.

Actual Results

Traceback (most recent call last):
  File "p.py", line 9, in <module>
    sklearn.model_selection.cross_val_predict(est, X, y, method='predict_proba')
  File "venv/lib/python3.6/site-packages/sklearn/model_selection/_validation.py", line 674, in cross_val_predict
    y = le.fit_transform(y)
  File "venv/lib/python3.6/site-packages/sklearn/preprocessing/label.py", line 111, in fit_transform
    y = column_or_1d(y, warn=True)
  File "venv/lib/python3.6/site-packages/sklearn/utils/validation.py", line 614, in column_or_1d
    raise ValueError("bad input shape {0}".format(shape))
ValueError: bad input shape (100, 2)
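The traceback points at the cause: for these methods, cross_val_predict passes y through a LabelEncoder, and LabelEncoder.fit_transform calls column_or_1d, which rejects any target that is not 1-D. A multi-output-aware path would need to encode each output column on its own. The following is a minimal NumPy sketch of that idea, assuming per-column encoding is the desired behavior; encode_per_output is a hypothetical helper, not part of scikit-learn.

```python
import numpy as np


def encode_per_output(y):
    """Label-encode each column of a target array independently.

    Hypothetical helper: applies the 1-D LabelEncoder idea (map each label
    to its index in the sorted list of classes) column by column, so a
    (n_samples, n_outputs) target is accepted instead of rejected.
    """
    y = np.asarray(y)
    if y.ndim == 1:
        y = y.reshape(-1, 1)  # treat a 1-D target as a single output
    encoded = np.empty(y.shape, dtype=np.intp)
    classes = []
    for k in range(y.shape[1]):
        # Sorted unique labels of output k, plus each sample's index into them.
        cls, inv = np.unique(y[:, k], return_inverse=True)
        encoded[:, k] = inv
        classes.append(cls)
    return encoded, classes


# Same target as in the reproducer: shape (100, 2).
y = np.arange(200).reshape(100, 2)
y_enc, classes = encode_per_output(y)
print(y_enc.shape)      # (100, 2)
print(len(classes[0]))  # 100
```

Keeping one class list per output mirrors how a multi-output classifier stores its classes, which is what later steps need in order to line up predicted columns with labels.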

Versions

Python 3.6.2 (default, Jul 29 2017, 00:00:00)
[GCC 4.8.4]
NumPy 1.13.3
SciPy 1.0.0
Scikit-Learn 0.19.1

This appears to be a regression from 0.18, introduced in #7889.
Pull request #8773 appears to add more multi-output-friendly code paths.
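Besides encoding y, a multi-output-friendly path has to cope with the shape of the predictions: for a multi-output forest, predict_proba returns a list with one (n_samples, n_classes) array per output, and a given training fold may not contain every class of an output. One plausible approach, sketched here under those assumptions and not necessarily what #8773 does, is to scatter each fold's per-output probabilities into columns for that output's full class list and leave unseen classes at zero. align_proba is a hypothetical helper.

```python
import numpy as np


def align_proba(fold_proba, fold_classes, all_classes):
    """Scatter per-output probabilities into columns for the full class set.

    Hypothetical helper. fold_proba has one (n_test, n_seen_k) array per
    output, as a multi-output classifier's predict_proba returns;
    fold_classes[k] are the classes seen while fitting the fold and
    all_classes[k] the sorted classes of the full target (a superset).
    Classes absent from the training fold get probability 0.
    """
    aligned = []
    for proba, seen, full in zip(fold_proba, fold_classes, all_classes):
        out = np.zeros((proba.shape[0], len(full)))
        # Column index of each seen class within the sorted full class list.
        cols = np.searchsorted(full, seen)
        out[:, cols] = proba
        aligned.append(out)
    return aligned


# Toy fold: output 0 saw classes [0, 2] of [0, 1, 2]; output 1 saw all of [5, 7].
fold_proba = [np.array([[0.3, 0.7]]), np.array([[0.6, 0.4]])]
fold_classes = [np.array([0, 2]), np.array([5, 7])]
all_classes = [np.array([0, 1, 2]), np.array([5, 7])]
aligned = align_proba(fold_proba, fold_classes, all_classes)
print(aligned[0].tolist())  # [[0.3, 0.0, 0.7]]
print(aligned[1].tolist())  # [[0.6, 0.4]]
```

Using np.searchsorted relies on each full class list being sorted, which holds for classes produced by np.unique or LabelEncoder.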

@amueller (Member)

So that's the same issue as reported in #8773, right?

@waltaskew (Author)

So that's the same issue as reported in #8773, right?

Ah, sorry, you're right: it's a different manifestation of the same issue.
