Automatically move `y_true` to the same device and namespace as `y_pred` for metrics · Issue #31274 · scikit-learn/scikit-learn · GitHub

Open
@lucyleeow

Description


This is closely linked to #28668 but separate enough to warrant its own issue (#28668 (comment)). It is mostly a summary of the discussions so far. If we are happy with a decision, we can move on to updating the documentation.


For classification metrics to support the array API, there is a problem when `y_pred` is not in the same namespace or on the same device as `y_true`:

- `y_pred` is likely to be the output of `predict_proba` or `decision_function`, and would be in the same namespace and on the same device as `X` (if we decide in #28668 that "everything should follow X").
- `y_true` could be an integer array, or a NumPy array or pandas Series (this is pertinent as `y_true` may contain string labels).
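To illustrate why string labels matter: many array API libraries (PyTorch, for instance) have no string dtype, so string labels cannot simply be moved into their namespace. They would first have to be encoded as integers on the host. A minimal NumPy sketch of that encoding step, with illustrative variable names:

```python
import numpy as np

# y_true as string labels, as it might arrive from a pandas Series or NumPy array.
y_true = np.asarray(["cat", "dog", "dog", "cat"])

# Encode the labels as integer indices into the sorted array of unique classes.
# Only the integer codes could then be moved to another namespace/device.
classes, y_true_encoded = np.unique(y_true, return_inverse=True)
```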

Motivating use case:

Using, e.g., `GridSearchCV` or `cross_validate` with a pipeline that moves `X` to the GPU.
Consider a pipeline like the one below (copied from #28668 (comment)):

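```python
from sklearn.pipeline import make_pipeline

# The three steps below are placeholder names used for illustration only.
pipeline = make_pipeline(
    SomeDataFrameAwareFeatureExtractor(),
    MoveFeaturesToPyTorch(device="cuda"),
    SomeArrayAPICapableClassifier(),
)
```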

Pipelines never touch `y`, so we cannot alter `y` within the pipeline.
The metric passed to `GridSearchCV` or `cross_validate` would therefore receive `y_true` and `y_pred` in different namespaces and on different devices.

Hence the motivation to automatically move `y_true` to the same namespace and device as `y_pred` inside metric functions.
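A minimal sketch of what "`y_true` follows `y_pred`" could look like at the start of a metric function, assuming `y_pred` exposes the array API standard's `__array_namespace__()` method and `.device` attribute. `move_y_true_like` is a hypothetical name, not an existing scikit-learn function. Libraries that do not implement the protocol natively (PyTorch, for example) would need a compatibility layer such as array-api-compat, which is left out here.

```python
import numpy as np


def move_y_true_like(y_true, y_pred):
    """Hypothetical helper: return y_true in y_pred's namespace and on its device."""
    # pandas Series, lists, etc. are first materialised as a host NumPy array.
    # String labels would need integer encoding first, since many array API
    # namespaces have no string dtype.
    y_true = np.asarray(y_true)
    if not hasattr(y_pred, "__array_namespace__"):
        # Assumption: arrays without the protocol (e.g. NumPy < 2.0) are
        # treated as plain NumPy, so no conversion is needed.
        return y_true
    xp = y_pred.__array_namespace__()
    # asarray(..., device=...) is part of the array API standard.
    return xp.asarray(y_true, device=y_pred.device)
```

A metric would call this once on entry and then work entirely in `y_pred`'s namespace.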

(Another example is discussed in #30439 (comment).)

As `y_pred` is more likely to be on the GPU, "`y_true` follows `y_pred`" was slightly preferred over "`y_pred` follows `y_true`". Computation-wise, CPU and GPU are probably similar for metrics like log loss, but could the GPU be faster for metrics that require sorting (e.g., ROC AUC)? (See #30439 (comment) for more discussion on this point.)
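The sorting point can be made concrete with a minimal binary ROC AUC sketch in NumPy. `roc_auc_sorted` is a hypothetical name and this is not scikit-learn's implementation. It uses the usual threshold-sweep approach, where ranking all scores with `argsort` is the O(n log n) step. In an array API version, `np` would be replaced by the namespace of `y_pred`, so that sort would run on whichever device holds the predictions.

```python
import numpy as np


def roc_auc_sorted(y_true, y_score):
    """Binary ROC AUC via a threshold sweep (sketch; assumes labels in {0, 1}
    and both classes present)."""
    y_true = np.asarray(y_true)
    y_score = np.asarray(y_score, dtype=float)

    # The O(n log n) step: order samples by decreasing score.
    order = np.argsort(y_score, kind="stable")[::-1]
    y_score, y_true = y_score[order], y_true[order]

    # Last index of each group of tied scores = one ROC threshold.
    distinct = np.flatnonzero(np.diff(y_score))
    threshold_idx = np.r_[distinct, y_true.size - 1]

    # True/false positive counts at each threshold.
    tps = np.cumsum(y_true)[threshold_idx]
    fps = 1 + threshold_idx - tps

    # Prepend the (0, 0) point and normalise counts to rates.
    tpr = np.r_[0, tps] / tps[-1]
    fpr = np.r_[0, fps] / fps[-1]

    # Trapezoidal area under the ROC curve.
    return float(np.sum(np.diff(fpr) * (tpr[1:] + tpr[:-1]) / 2))
```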

Question for my own clarification: is the main motivation usability, so that the user does not have to convert `y_true` manually? Would a helper function that converts `y_true` to the correct namespace/device be of interest?

cc @ogrisel @betatim
