8000 ENH Fix averaged RMSE (#17309) · scikit-learn/scikit-learn@9368545 · GitHub
[go: up one dir, main page]

Skip to content

Commit 9368545

Browse files
authored
ENH Fix averaged RMSE (#17309)
1 parent d5de894 commit 9368545

File tree

3 files changed

+18
-5
lines changed

3 files changed

+18
-5
lines changed

doc/whats_new/v0.24.rst

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,14 @@ Changelog
6161
change since `None` was defaulting to these values already.
6262
:pr:`16493` by :user:`Darshan N <DarshanGowda0>`.
6363

64+
:mod:`sklearn.metrics`
65+
......................
66+
67+
- |Fix| Fixed a bug in :func:`metrics.mean_squared_error` where the
68+
average of multiple RMSE values was incorrectly calculated as the root of the
69+
average of multiple MSE values.
70+
:pr:`17309` by :user:`Swier Heeres <swierh>`
71+
6472
:mod:`sklearn.model_selection`
6573
..............................
6674

sklearn/metrics/_regression.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -244,6 +244,8 @@ def mean_squared_error(y_true, y_pred, *,
244244
>>> y_pred = [[0, 2],[-1, 2],[8, -5]]
245245
>>> mean_squared_error(y_true, y_pred)
246246
0.708...
247+
>>> mean_squared_error(y_true, y_pred, squared=False)
248+
0.822...
247249
>>> mean_squared_error(y_true, y_pred, multioutput='raw_values')
248250
array([0.41666667, 1. ])
249251
>>> mean_squared_error(y_true, y_pred, multioutput=[0.3, 0.7])
@@ -255,15 +257,18 @@ def mean_squared_error(y_true, y_pred, *,
255257
check_consistent_length(y_true, y_pred, sample_weight)
256258
output_errors = np.average((y_true - y_pred) ** 2, axis=0,
257259
weights=sample_weight)
260+
261+
if not squared:
262+
output_errors = np.sqrt(output_errors)
263+
258264
if isinstance(multioutput, str):
259265
if multioutput == 'raw_values':
260-
return output_errors if squared else np.sqrt(output_errors)
266+
return output_errors
261267
elif multioutput == 'uniform_average':
262268
# pass None as weights to np.average: uniform mean
263269
multioutput = None
264270

265-
mse = np.average(output_errors, weights=multioutput)
266-
return mse if squared else np.sqrt(mse)
271+
return np.average(output_errors, weights=multioutput)
267272

268273

269274
@_deprecate_positional_args

sklearn/metrics/tests/test_regression.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,7 @@ def test_multioutput_regression():
7676
assert_almost_equal(error, (1. / 3 + 2. / 3 + 2. / 3) / 4.)
7777

7878
error = mean_squared_error(y_true, y_pred, squared=False)
79-
assert_almost_equal(error, 0.645, decimal=2)
79+
assert_almost_equal(error, 0.454, decimal=2)
8080

8181
error = mean_squared_log_error(y_true, y_pred)
8282
assert_almost_equal(error, 0.200, decimal=2)
@@ -258,7 +258,7 @@ def test_regression_custom_weights():
258258
evsw = explained_variance_score(y_true, y_pred, multioutput=[0.4, 0.6])
259259

260260
assert_almost_equal(msew, 0.39, decimal=2)
261-
assert_almost_equal(rmsew, 0.62, decimal=2)
261+
assert_almost_equal(rmsew, 0.59, decimal=2)
262262
assert_almost_equal(maew, 0.475, decimal=3)
263263
assert_almost_equal(rw, 0.94, decimal=2)
264264
assert_almost_equal(evsw, 0.94, decimal=2)

0 commit comments

Comments
 (0)
0