F1 score not calculated properly · Issue #27189 · scikit-learn/scikit-learn · GitHub

F1 score not calculated properly #27189
@kjelljorner

Description

Describe the bug

According to the definition of the F1 score for two classes, it can be calculated as

$$ \frac{2tp}{2tp + fp + fn} $$

or

$$ 2 \frac{precision \cdot recall}{precision + recall} $$

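From what I can see, scikit-learn uses some variant of the second definition. The problem is that the first definition can still be well defined when the second is not: if there are no positive predictions ($tp + fp = 0$), the precision is a division by zero and therefore undefined, whereas the first expression only requires $2tp + fp + fn > 0$.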

$$ precision = \frac{tp}{tp + fp} $$

$$ recall = \frac{tp}{tp + fn} $$

For definitions of precision and recall, see Wikipedia.
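
To make the relationship between the two forms explicit, here is a short derivation. It is only a sketch and assumes $tp > 0$, so that both precision and recall are defined and the common factor $tp$ can be cancelled:

$$ 2 \frac{precision \cdot recall}{precision + recall} = 2 \frac{\frac{tp}{tp + fp} \cdot \frac{tp}{tp + fn}}{\frac{tp}{tp + fp} + \frac{tp}{tp + fn}} = \frac{2tp^2}{tp(tp + fn) + tp(tp + fp)} = \frac{2tp}{2tp + fp + fn} $$

The last step divides by $tp$, so the equivalence does not carry over to $tp = 0$. In that case the count-based expression $\frac{2tp}{2tp + fp + fn}$ still evaluates to 0 as long as $fp + fn > 0$, while the precision/recall form contains a $0/0$ whenever $tp + fp = 0$.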

Below, I give a code example where this happens.

Steps/Code to Reproduce

from sklearn.metrics import f1_score
import sklearn.metrics
import numpy as np

y_true = [True, False, True]
y_pred = [False, False, False]

tn, fp, fn, tp = sklearn.metrics.confusion_matrix(
    y_true, y_pred, labels=[False, True]
).ravel()

print("TN:", tn)
print("FP:", fp)
print("FN:", fn)
print("TP:", tp)

precision = tp / (tp + fp)
recall = tp / (tp + fn)
print("Precision:", precision)
print("Recall:", recall)

f1_true = 2 * tp / (2 * tp + fp + fn)
print("F1 (true):", f1_true)

f1_sk = f1_score(y_true, y_pred, zero_division=np.nan)
print("F1 (sklearn):", f1_sk)

Expected Results

TN: 1
FP: 0
FN: 2
TP: 0
Precision: nan
Recall: 0.0
F1 (true): 0.0
F1 (sklearn): 0.0

Actual Results

TN: 1
FP: 0
FN: 2
TP: 0
<ipython-input-1-d59969af6bb0>:17: RuntimeWarning: invalid value encountered in scalar divide
  precision = tp / (tp + fp)
Precision: nan
Recall: 0.0
F1 (true): 0.0
F1 (sklearn): nan
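
To illustrate the difference without depending on a particular scikit-learn version, here is a minimal pure-Python sketch of both definitions. The helpers `f1_from_counts` and `f1_from_precision_recall` are hypothetical names introduced for this example. They are not scikit-learn APIs, and the second one is only my guess at how an undefined precision could propagate, not the library's actual implementation.

```python
import math


def f1_from_counts(tp, fp, fn, zero_division=math.nan):
    """F1 from confusion-matrix counts (hypothetical helper, not a scikit-learn API)."""
    denominator = 2 * tp + fp + fn
    if denominator == 0:
        # No positives in either the labels or the predictions: F1 is undefined.
        return zero_division
    return 2 * tp / denominator


def f1_from_precision_recall(tp, fp, fn, zero_division=math.nan):
    """F1 via precision and recall (assumed behavior, for illustration only)."""
    precision = tp / (tp + fp) if (tp + fp) > 0 else zero_division
    recall = tp / (tp + fn) if (tp + fn) > 0 else zero_division
    # An undefined precision or recall, or a zero sum, makes this form undefined.
    if math.isnan(precision) or math.isnan(recall) or precision + recall == 0:
        return zero_division
    return 2 * precision * recall / (precision + recall)


# Counts from the example above: TP = 0, FP = 0, FN = 2.
print("F1 (counts):", f1_from_counts(tp=0, fp=0, fn=2))
print("F1 (precision/recall):", f1_from_precision_recall(tp=0, fp=0, fn=2))
```

With the counts from my example, the count-based helper returns 0.0, while the precision/recall helper returns nan because the precision is undefined. When `tp > 0`, the two helpers agree.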

Versions

System:
    python: 3.11.5 | packaged by conda-forge | (main, Aug 27 2023, 03:33:12) [Clang 15.0.7 ]
executable: /Users/Kjell/mambaforge/envs/sklearn/bin/python3.11
   machine: macOS-13.4.1-arm64-arm-64bit

Python dependencies:
      sklearn: 1.3.0
          pip: 23.2.1
   setuptools: 68.1.2
        numpy: 1.25.2
        scipy: 1.11.2
       Cython: None
       pandas: None
   matplotlib: None
       joblib: 1.3.2
threadpoolctl: 3.2.0

Built with OpenMP: True

threadpoolctl info:
       user_api: blas
   internal_api: openblas
    num_threads: 8
         prefix: libopenblas
       filepath: /Users/Kjell/mambaforge/envs/sklearn/lib/libopenblas.0.dylib
        version: 0.3.23
threading_layer: openmp
   architecture: VORTEX

       user_api: openmp
   internal_api: openmp
    num_threads: 8
         prefix: libomp
       filepath: /Users/Kjell/mambaforge/envs/sklearn/lib/libomp.dylib
        version: None
