8000 MNT Make sample_weight checking more consistent in regression metrics by lucyleeow · Pull Request #30886 · scikit-learn/scikit-learn · GitHub
[go: up one dir, main page]

Skip to content

MNT Make sample_weight checking more consistent in regression metrics #30886

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 11 commits into
base: main
Choose a base branch
from
17 changes: 17 additions & 0 deletions doc/whats_new/upcoming_changes/sklearn.metrics/30886.fix.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
- Additional `sample_weight` checking has been added to
:func:`metrics.mean_absolute_error`,
:func:`metrics.mean_pinball_loss`,
:func:`metrics.mean_absolute_percentage_error`,
:func:`metrics.mean_squared_error`,
:func:`metrics.root_mean_squared_error`,
:func:`metrics.mean_squared_log_error`,
:func:`metrics.root_mean_squared_log_error`,
:func:`metrics.explained_variance_score`,
:func:`metrics.r2_score`,
:func:`metrics.mean_tweedie_deviance`,
:func:`metrics.mean_poisson_deviance`,
:func:`metrics.mean_gamma_deviance` and
:func:`metrics.d2_tweedie_score`.
  `sample_weight` must be either a scalar or a 1D array whose length is
  consistent with `y_true` and `y_pred`.
By :user:`Lucy Liu <lucyleeow>`.
67 changes: 32 additions & 35 deletions sklearn/metrics/_regression.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,8 +57,10 @@
]


def _check_reg_targets(y_true, y_pred, multioutput, dtype="numeric", xp=None):
"""Check that y_true and y_pred belong to the same regression task.
def _check_reg_targets(
y_true, y_pred, sample_weight, multioutput, dtype="numeric", xp=None
):
"""Check that y_true, y_pred and sample_weight belong to the same regression task.

To reduce redundancy when calling `_find_matching_floating_dtype`,
please use `_check_reg_targets_with_floating_dtype` instead.
Expand All @@ -71,6 +73,9 @@ def _check_reg_targets(y_true, y_pred, multioutput, dtype="numeric", xp=None):
y_pred : array-like of shape (n_samples,) or (n_samples, n_outputs)
Estimated target values.

sample_weight : array-like of shape (n_samples,) or None
Sample weights.

multioutput : array-like or string in ['raw_values', uniform_average',
'variance_weighted'] or None
None is accepted due to backward compatibility of r2_score().
Expand All @@ -95,6 +100,9 @@ def _check_reg_targets(y_true, y_pred, multioutput, dtype="numeric", xp=None):
y_pred : array-like of shape (n_samples, n_outputs)
Estimated target values.

sample_weight : array-like of shape (n_samples,) or None
Sample weights.

multioutput : array-like of shape (n_outputs) or string in ['raw_values',
uniform_average', 'variance_weighted'] or None
Custom output weights if ``multioutput`` is array-like or
Expand All @@ -103,9 +111,11 @@ def _check_reg_targets(y_true, y_pred, multioutput, dtype="numeric", xp=None):
"""
xp, _ = get_namespace(y_true, y_pred, multioutput, xp=xp)

check_consistent_length(y_true, y_pred)
check_consistent_length(y_true, y_pred, sample_weight)
y_true = check_array(y_true, ensure_2d=False, dtype=dtype)
y_pred = check_array(y_pred, ensure_2d=False, dtype=dtype)
if sample_weight is not None:
sample_weight = _check_sample_weight(sample_weight, y_true, dtype=dtype)

if y_true.ndim == 1:
y_true = xp.reshape(y_true, (-1, 1))
Expand Down Expand Up @@ -141,14 +151,13 @@ def _check_reg_targets(y_true, y_pred, multioutput, dtype="numeric", xp=None):
)
y_type = "continuous" if n_outputs == 1 else "continuous-multioutput"

return y_type, y_true, y_pred, multioutput
return y_type, y_true, y_pred, sample_weight, multioutput


def _check_reg_targets_with_floating_dtype(
y_true, y_pred, sample_weight, multioutput, xp=None
):
"""Ensures that y_true, y_pred, and sample_weight correspond to the same
regression task.
"""Ensures y_true, y_pred, and sample_weight correspond to the same regression task.

Extends `_check_reg_targets` by automatically selecting a suitable floating-point
data type for inputs using `_find_matching_floating_dtype`.
Expand Down Expand Up @@ -197,15 +206,10 @@ def _check_reg_targets_with_floating_dtype(
"""
dtype_name = _find_matching_floating_dtype(y_true, y_pred, sample_weight, xp=xp)

y_type, y_true, y_pred, multioutput = _check_reg_targets(
y_true, y_pred, multioutput, dtype=dtype_name, xp=xp
y_type, y_true, y_pred, sample_weight, multioutput = _check_reg_targets(
y_true, y_pred, sample_weight, multioutput, dtype=dtype_name, xp=xp
)

# _check_reg_targets does not accept sample_weight as input.
# Convert sample_weight's data type separately to match dtype_name.
if sample_weight is not None:
sample_weight = xp.asarray(sample_weight, dtype=dtype_name)

return y_type, y_true, y_pred, sample_weight, multioutput


Expand Down Expand Up @@ -282,8 +286,6 @@ def mean_absolute_error(
)
)

check_consistent_length(y_true, y_pred, sample_weight)

output_errors = _average(
xp.abs(y_pred - y_true), weights=sample_weight, axis=0, xp=xp
)
Expand Down Expand Up @@ -383,7 +385,6 @@ def mean_pinball_loss(
)
)

check_consistent_length(y_true, y_pred, sample_weight)
diff = y_true - y_pred
sign = xp.astype(diff >= 0, diff.dtype)
loss = alpha * sign * diff - (1 - alpha) * (1 - sign) * diff
Expand Down Expand Up @@ -489,7 +490,6 @@ def mean_absolute_percentage_error(
y_true, y_pred, sample_weight, multioutput, xp=xp
)
)
check_consistent_length(y_true, y_pred, sample_weight)
epsilon = xp.asarray(xp.finfo(xp.float64).eps, dtype=y_true.dtype, device=device_)
y_true_abs = xp.abs(y_true)
mape = xp.abs(y_pred - y_true) / xp.maximum(y_true_abs, epsilon)
Expand Down Expand Up @@ -581,7 +581,6 @@ def mean_squared_error(
y_true, y_pred, sample_weight, multioutput, xp=xp
)
)
check_consistent_length(y_true, y_pred, sample_weight)
output_errors = _average((y_true - y_pred) ** 2, axis=0, weights=sample_weight)

if isinstance(multioutput, str):
Expand Down Expand Up @@ -753,8 +752,10 @@ def mean_squared_log_error(
"""
xp, _ = get_namespace(y_true, y_pred)

_, y_true, y_pred, _, _ = _check_reg_targets_with_floating_dtype(
y_true, y_pred, sample_weight, multioutput, xp=xp
_, y_true, y_pred, sample_weight, multioutput = (
_check_reg_targets_with_floating_dtype(
y_true, y_pred, sample_weight, multioutput, xp=xp
)
)

if xp.any(y_true <= -1) or xp.any(y_pred <= -1):
Expand Down Expand Up @@ -829,8 +830,10 @@ def root_mean_squared_log_error(
"""
xp, _ = get_namespace(y_true, y_pred)

_, y_true, y_pred, _, _ = _check_reg_targets_with_floating_dtype(
y_true, y_pred, sample_weight, multioutput, xp=xp
_, y_true, y_pred, sample_weight, multioutput = (
_check_reg_targets_with_floating_dtype(
y_true, y_pred, sample_weight, multioutput, xp=xp
)
)

if xp.any(y_true <= -1) or xp.any(y_pred <= -1):
Expand Down Expand Up @@ -912,13 +915,12 @@ def median_absolute_error(
>>> median_absolute_error(y_true, y_pred, multioutput=[0.3, 0.7])
0.85
"""
y_type, y_true, y_pred, multioutput = _check_reg_targets(
y_true, y_pred, multioutput
_, y_true, y_pred, sample_weight, multioutput = _check_reg_targets(
y_true, y_pred, sample_weight, multioutput
)
if sample_weight is None:
output_errors = np.median(np.abs(y_pred - y_true), axis=0)
else:
sample_weight = _check_sample_weight(sample_weight, y_pred)
output_errors = _weighted_percentile(
np.abs(y_pred - y_true), sample_weight=sample_weight
)
Expand Down Expand Up @@ -1106,8 +1108,6 @@ def explained_variance_score(
)
)

check_consistent_length(y_true, y_pred, sample_weight)

y_diff_avg = _average(y_true - y_pred, weights=sample_weight, axis=0)
numerator = _average(
(y_true - y_pred - y_diff_avg) ** 2, weights=sample_weight, axis=0
Expand Down Expand Up @@ -1278,8 +1278,6 @@ def r2_score(
)
)

check_consistent_length(y_true, y_pred, sample_weight)

if _num_samples(y_pred) < 2:
msg = "R^2 score is not well-defined with less than two samples."
warnings.warn(msg, UndefinedMetricWarning)
Expand Down Expand Up @@ -1343,7 +1341,9 @@ def max_error(y_true, y_pred):
1.0
"""
xp, _ = get_namespace(y_true, y_pred)
y_type, y_true, y_pred, _ = _check_reg_targets(y_true, y_pred, None, xp=xp)
y_type, y_true, y_pred, _, _ = _check_reg_targets(
y_true, y_pred, sample_weight=None, multioutput=None, xp=xp
)
if y_type == "continuous-multioutput":
raise ValueError("Multioutput not supported in max_error")
return float(xp.max(xp.abs(y_true - y_pred)))
Expand Down Expand Up @@ -1448,7 +1448,6 @@ def mean_tweedie_deviance(y_true, y_pred, *, sample_weight=None, power=0):
)
if y_type == "continuous-multioutput":
raise ValueError("Multioutput not supported in mean_tweedie_deviance")
check_consistent_length(y_true, y_pred, sample_weight)

if sample_weight is not None:
sample_weight = column_or_1d(sample_weight)
Expand Down Expand Up @@ -1773,10 +1772,9 @@ def d2_pinball_score(
>>> d2_pinball_score(y_true, y_true, alpha=0.1)
1.0
"""
y_type, y_true, y_pred, multioutput = _check_reg_targets(
y_true, y_pred, multioutput
_, y_true, y_pred, sample_weight, multioutput = _check_reg_targets(
y_true, y_pred, sample_weight, multioutput
)
check_consistent_length(y_true, y_pred, sample_weight)

if _num_samples(y_pred) < 2:
msg = "D^2 score is not well-defined with less than two samples."
Expand All @@ -1796,7 +1794,6 @@ def d2_pinball_score(
np.percentile(y_true, q=alpha * 100, axis=0), (len(y_true), 1)
)
else:
sample_weight = _check_sample_weight(sample_weight, y_true)
y_quantile = np.tile(
_weighted_percentile(
y_true, sample_weight=sample_weight, percentile_rank=alpha * 100
Expand Down
26 changes: 26 additions & 0 deletions sklearn/metrics/tests/test_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -1588,6 +1588,32 @@ def test_regression_sample_weight_invariance(name):
check_sample_weight_invariance(name, metric, y_true, y_pred)


@pytest.mark.parametrize(
    "name",
    sorted(
        set(ALL_METRICS).intersection(set(REGRESSION_METRICS))
        - METRICS_WITHOUT_SAMPLE_WEIGHT
    ),
)
def test_regression_with_invalid_sample_weight(name):
    """Check that regression metrics reject a malformed `sample_weight`.

    Two failure modes are exercised for every regression metric that
    accepts sample weights: a weight vector whose length differs from
    `y_true`/`y_pred`, and a 2D weight array.
    """
    n_samples = 50
    rng = check_random_state(0)
    y_true = rng.random_sample(size=(n_samples,))
    y_pred = rng.random_sample(size=(n_samples,))
    metric = ALL_METRICS[name]

    # Length mismatch: one sample short of y_true / y_pred.
    sample_weight = rng.random_sample(size=(n_samples - 1,))
    with pytest.raises(ValueError, match="Found input variables with inconsistent"):
        metric(y_true, y_pred, sample_weight=sample_weight)

    # Wrong dimensionality: 2D weights must be rejected outright.
    sample_weight = rng.random_sample(size=(n_samples * 2,)).reshape(
        (n_samples, 2)
    )
    with pytest.raises(ValueError, match="Sample weights must be 1D array or scalar"):
        metric(y_true, y_pred, sample_weight=sample_weight)


@pytest.mark.parametrize(
"name",
sorted(
Expand Down
8 changes: 5 additions & 3 deletions sklearn/metrics/tests/test_regression.py
Original file line number Diff line number Diff line change
Expand Up @@ -330,7 +330,9 @@ def test__check_reg_targets():

for (type1, y1, n_out1), (type2, y2, n_out2) in product(EXAMPLES, repeat=2):
if type1 == type2 and n_out1 == n_out2:
y_type, y_check1, y_check2, multioutput = _check_reg_targets(y1, y2, None)
y_type, y_check1, y_check2, _, _ = _check_reg_targets(
y1, y2, sample_weight=None, multioutput=None
)
assert type1 == y_type
if type1 == "continuous":
assert_array_equal(y_check1, np.reshape(y1, (-1, 1)))
Expand All @@ -340,7 +342,7 @@ def test__check_reg_targets():
assert_array_equal(y_check2, y2)
else:
with pytest.raises(ValueError):
_check_reg_targets(y1, y2, None)
_check_reg_targets(y1, y2, sample_weight=None, multioutput=None)


def test__check_reg_targets_exception():
Expand All @@ -351,7 +353,7 @@ def test__check_reg_targets_exception():
)
)
with pytest.raises(ValueError, match=expected_message):
_check_reg_targets([1, 2, 3], [[1], [2], [3]], invalid_multioutput)
_check_reg_targets([1, 2, 3], [[1], [2], [3]], None, invalid_multioutput)


def test_regression_multioutput_array():
Expand Down
12 changes: 7 additions & 5 deletions sklearn/utils/validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -2169,16 +2169,18 @@ def _check_sample_weight(
"""
n_samples = _num_samples(X)

if dtype is not None and dtype not in [np.float32, np.float64]:
dtype = np.float64
xp, _ = get_namespace(X)

if dtype is not None and dtype not in [xp.float32, xp.float64]:
dtype = xp.float64

if sample_weight is None:
sample_weight = np.ones(n_samples, dtype=dtype)
sample_weight = xp.ones(n_samples, dtype=dtype)
elif isinstance(sample_weight, numbers.Number):
sample_weight = np.full(n_samples, sample_weight, dtype=dtype)
sample_weight = xp.full(n_samples, sample_weight, dtype=dtype)
else:
if dtype is None:
dtype = [np.float64, np.float32]
dtype = [xp.float64, xp.float32]
sample_weight = check_array(
sample_weight,
accept_sparse=False,
Expand Down
0