Exception When Calling cross_val_predict with Multi-Output Estimators and Methods Besides 'predict' · Issue #10028 · scikit-learn/scikit-learn · GitHub
Exception When Calling cross_val_predict with Multi-Output Estimators and Methods Besides 'predict' #10028
Closed
@waltaskew

Description

A ValueError is raised when cross_val_predict is called on a multi-output model with method set to 'decision_function', 'predict_proba' or 'predict_log_proba'.
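For context, here is a minimal sketch of what predict_proba returns for a multi-output RandomForestClassifier fitted without cross-validation. It assumes a recent scikit-learn; n_estimators=10 and random_state=0 are arbitrary choices, not part of the report. The result is a list with one array per output rather than a single 2-D array, which is why these methods need output-aware handling in cross_val_predict.

```python
import numpy
from sklearn.ensemble import RandomForestClassifier

X = numpy.arange(400).reshape(100, 4)
y = numpy.arange(200).reshape(100, 2)  # two output columns, 100 classes each

est = RandomForestClassifier(n_estimators=10, random_state=0).fit(X, y)
proba = est.predict_proba(X)

# For a multi-output classifier, predict_proba returns a list holding one
# (n_samples, n_classes_k) array per output, not one 2-D array.
print(type(proba).__name__, len(proba), [p.shape for p in proba])
# list 2 [(100, 100), (100, 100)]
```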

Steps/Code to Reproduce

import numpy
import sklearn.ensemble
import sklearn.model_selection

X = numpy.arange(400).reshape(100, 4)
y = numpy.arange(200).reshape(100, 2)
est = sklearn.ensemble.RandomForestClassifier()

sklearn.model_selection.cross_val_predict(est, X, y, method='predict_proba')
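As the title indicates, only methods other than 'predict' are affected. A quick contrast, assuming the same data and estimator as the reproduction above (n_estimators=10 and random_state=0 are added only for speed and determinism): with the default method='predict', the 2-D target is accepted and the predictions keep the (n_samples, n_outputs) shape.

```python
import numpy
import sklearn.ensemble
import sklearn.model_selection

X = numpy.arange(400).reshape(100, 4)
y = numpy.arange(200).reshape(100, 2)
est = sklearn.ensemble.RandomForestClassifier(n_estimators=10, random_state=0)

# With the default method='predict', the label-encoding step seen in the
# traceback is not applied, so the multi-output y goes through unchanged.
pred = sklearn.model_selection.cross_val_predict(est, X, y)
print(pred.shape)  # (100, 2)
```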

Expected Results

No error is thrown.

Actual Results

Traceback (most recent call last):
  File "p.py", line 9, in <module>
    sklearn.model_selection.cross_val_predict(est, X, y, method='predict_proba')
  File "venv/lib/python3.6/site-packages/sklearn/model_selection/_validation.py", line 674, in cross_val_predict
    y = le.fit_transform(y)
  File "venv/lib/python3.6/site-packages/sklearn/preprocessing/label.py", line 111, in fit_transform
    y = column_or_1d(y, warn=True)
  File "venv/lib/python3.6/site-packages/sklearn/utils/validation.py", line 614, in column_or_1d
    raise ValueError("bad input shape {0}".format(shape))
ValueError: bad input shape (100, 2)
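The failure can be reproduced in isolation. This minimal sketch assumes a recent scikit-learn, where the error message may be worded differently from "bad input shape". LabelEncoder rejects the 2-D target, while encoding each output column with its own encoder succeeds. The per-column loop only illustrates what multi-output-friendly encoding could look like; it is not necessarily what #8773 does.

```python
import numpy
from sklearn.preprocessing import LabelEncoder

y = numpy.arange(200).reshape(100, 2)

# LabelEncoder only accepts 1-D targets, so encoding the 2-D y directly
# fails in the same way as le.fit_transform(y) in the traceback.
try:
    LabelEncoder().fit_transform(y)
    failed = False
except ValueError as exc:
    failed = True
    print("ValueError:", exc)

# Encoding each output column separately works; each column maps to 0..99.
y_encoded = numpy.column_stack([LabelEncoder().fit_transform(col) for col in y.T])
print(y_encoded.shape)  # (100, 2)
```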

Versions

Python 3.6.2 (default, Jul 29 2017, 00:00:00)
[GCC 4.8.4]
NumPy 1.13.3
SciPy 1.0.0
Scikit-Learn 0.19.1

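It looks like this is a regression from 0.18, introduced in #7889: the le.fit_transform(y) call shown in the traceback passes the 2-D target to LabelEncoder, which only accepts 1-D input.
Pull request #8773 appears to contain more multi-output-friendly code paths.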
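Until cross_val_predict handles this case, one possible workaround is to run the folds manually. This is a sketch under stated assumptions: multioutput_cv_predict_proba is a hypothetical helper, not part of scikit-learn, and it assumes the estimator's predict_proba returns one array per output and that classes_ is a list of per-output arrays, as RandomForestClassifier provides for a 2-D y.

```python
import numpy
from sklearn.base import clone
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import KFold


def multioutput_cv_predict_proba(estimator, X, y, n_splits=3):
    """Hypothetical helper (not part of scikit-learn): out-of-fold
    predict_proba for a multi-output classifier.

    Returns one (n_samples, n_classes_k) array per output column k, with
    columns ordered like numpy.unique(y[:, k]).
    """
    n_samples, n_outputs = y.shape
    classes = [numpy.unique(y[:, k]) for k in range(n_outputs)]
    out = [numpy.zeros((n_samples, len(c))) for c in classes]
    for train, test in KFold(n_splits=n_splits).split(X):
        est = clone(estimator).fit(X[train], y[train])
        # Assumes predict_proba returns a list with one array per output.
        for k, proba in enumerate(est.predict_proba(X[test])):
            # A training fold may miss some classes; map its classes_[k]
            # onto the full class list so columns line up across folds.
            cols = numpy.searchsorted(classes[k], est.classes_[k])
            out[k][numpy.ix_(test, cols)] = proba
    return out


X = numpy.arange(400).reshape(100, 4)
y = numpy.arange(200).reshape(100, 2)
probas = multioutput_cv_predict_proba(
    RandomForestClassifier(n_estimators=10, random_state=0), X, y
)
print([p.shape for p in probas])  # [(100, 100), (100, 100)]
```

Classes absent from a training fold get probability 0 for that fold's test rows. With the data in this report every label is unique, so every true label is missing from its training fold and the output is only meaningful as a shape check.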

Metadata

Assignees: No one assigned
Labels: No labels
Type: No type
Projects: No projects
Milestone: No milestone
Relationships: None yet
Development: No branches or pull requests
