8000 [MRG] Separated regression metrics from other metrics in test_sample_… · maskani-moh/scikit-learn@c29c2b6 · GitHub
[go: up one dir, main page]

Skip to content

Commit c29c2b6

Browse files
nikitasingh981maskani-moh
authored andcommitted
[MRG] Separated regression metrics from other metrics in test_sample_weight_invariance in metrics/tests/test_common.py (scikit-learn#8537)
* Separated tests for regression features in test_sample_weight_invariance * Fixed pep8 * Removed unecessary check for regression * Updated regression metrics * Joel's suggestions
1 parent 4cade0b commit c29c2b6

File tree

1 file changed

+15
-0
lines changed

1 file changed

+15
-0
lines changed

sklearn/metrics/tests/test_common.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1013,13 +1013,26 @@ def check_sample_weight_invariance(name, metric, y1, y2):
10131013

10141014
def test_sample_weight_invariance(n_samples=50):
10151015
random_state = check_random_state(0)
1016+
# regression
1017+
y_true = random_state.random_sample(size=(n_samples,))
1018+
y_pred = random_state.random_sample(size=(n_samples,))
1019+
for name in ALL_METRICS:
1020+
if name not in REGRESSION_METRICS:
1021+
continue
1022+
if name in METRICS_WITHOUT_SAMPLE_WEIGHT:
1023+
continue
1024+
metric = ALL_METRICS[name]
1025+
yield _named_check(check_sample_weight_invariance, name), name,\
1026+
metric, y_true, y_pred
10161027

10171028
# binary
10181029
random_state = check_random_state(0)
10191030
y_true = random_state.randint(0, 2, size< 97EC /span>=(n_samples, ))
10201031
y_pred = random_state.randint(0, 2, size=(n_samples, ))
10211032
y_score = random_state.random_sample(size=(n_samples,))
10221033
for name in ALL_METRICS:
1034+
if name in REGRESSION_METRICS:
1035+
continue
10231036
if (name in METRICS_WITHOUT_SAMPLE_WEIGHT or
10241037
name in METRIC_UNDEFINED_BINARY):
10251038
continue
@@ -1037,6 +1050,8 @@ def test_sample_weight_invariance(n_samples=50):
10371050
y_pred = random_state.randint(0, 5, size=(n_samples, ))
10381051
y_score = random_state.random_sample(size=(n_samples, 5))
10391052
for name in ALL_METRICS:
1053+
if name in REGRESSION_METRICS:
1054+
continue
10401055
if (name in METRICS_WITHOUT_SAMPLE_WEIGHT or
10411056
name in METRIC_UNDEFINED_BINARY_MULTICLASS):
10421057
continue

0 commit comments

Comments
 (0)
0