Description
Describe the bug
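For multi-output targets, the sklearn.metrics.mean_squared_error documentation describes per-target errors ("an array of floating point values, one for each individual target"), and multioutput is supposed to average those. The function does not behave this way. Instead of averaging the per-target RMSEs, it averages the per-target MSEs and takes a single square root at the end. In effect, it treats the multiple target outputs as one matrix.
The bug only shows up when squared=False. With squared=True the two views coincide, because the MSE is linear: the MSE of the whole matrix is just the (weighted) mean of the per-target MSEs. With squared=False they differ, because the square root of a mean is not the mean of the square roots.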
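The difference can be written out explicitly. The notation here is introduced for illustration only: $n$ samples, $k$ targets, sample weights $s_i$, and multioutput weights $w_j$ normalized to sum to 1 (np.average normalizes them).

$$
\mathrm{MSE}_j = \frac{\sum_{i=1}^{n} s_i\,(y_{ij} - \hat{y}_{ij})^2}{\sum_{i=1}^{n} s_i},
\qquad
\text{wanted} = \sum_{j=1}^{k} w_j \sqrt{\mathrm{MSE}_j},
\qquad
\text{returned} = \sqrt{\sum_{j=1}^{k} w_j\,\mathrm{MSE}_j}
$$

The square root is concave, so Jensen's inequality gives returned $\geq$ wanted. Equality holds only when all $\mathrm{MSE}_j$ with $w_j > 0$ are equal. This is consistent with the outputs reported in this issue, where the returned values (0.8416, 0.9083) are larger than the wanted ones (0.8227, 0.8936).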
Suggested fix
- Apply np.sqrt early, to the per-target errors.
- Then the if statements at the returns are no longer needed.
def mean_squared_error(y_true, y_pred,
                       sample_weight=None,
                       multioutput='uniform_average',
                       squared=True):
    y_type, y_true, y_pred, multioutput = _check_reg_targets(
        y_true, y_pred, multioutput)
    check_consistent_length(y_true, y_pred, sample_weight)
    output_errors = np.average((y_true - y_pred) ** 2, axis=0,
                               weights=sample_weight)
    if not squared:  #! line added
        output_errors = np.sqrt(output_errors)  #! line added
    if isinstance(multioutput, str):
        if multioutput == 'raw_values':
            return output_errors  #! line changed
        elif multioutput == 'uniform_average':
            # pass None as weights to np.average: uniform mean
            multioutput = None
    #! line removed
    return np.average(output_errors, weights=multioutput)  #! line changed
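You can try the proposed ordering without patching scikit-learn using the minimal stand-alone NumPy sketch below. It is not scikit-learn's code. The name mean_squared_error_sketch is hypothetical. The sketch also skips the validation done by the private helpers _check_reg_targets and check_consistent_length, and assumes 1-D inputs mean a single target.

```python
import numpy as np


def mean_squared_error_sketch(y_true, y_pred, sample_weight=None,
                              multioutput="uniform_average", squared=True):
    """Hypothetical stand-alone version of the suggested fix (not sklearn's API)."""
    y_true = np.asarray(y_true, dtype=float)
    y_pred = np.asarray(y_pred, dtype=float)
    # Assumption: a 1-D input is treated as a single target column
    if y_true.ndim == 1:
        y_true = y_true.reshape(-1, 1)
    if y_pred.ndim == 1:
        y_pred = y_pred.reshape(-1, 1)

    # Per-target MSE: (optionally sample-weighted) mean over the sample axis
    output_errors = np.average((y_true - y_pred) ** 2, axis=0,
                               weights=sample_weight)
    # The proposed fix: take the square root per target, before any averaging
    if not squared:
        output_errors = np.sqrt(output_errors)

    if isinstance(multioutput, str):
        if multioutput == "raw_values":
            return output_errors
        if multioutput != "uniform_average":
            raise ValueError(f"unsupported multioutput: {multioutput!r}")
        multioutput = None  # weights=None makes np.average a plain mean
    # An array-like multioutput is used as per-target weights
    return float(np.average(output_errors, weights=multioutput))
```

With the arrays y_true = [[0.5, 1], [-1, 1], [7, -6]] and y_pred = [[0, 2], [-1, 2], [8, -5]] (the mean_squared_error docstring example), squared=False gives the "wanted" values: about 0.8227 for the uniform average and about 0.8936 for weights [0.3, 0.7].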
Steps/Code to reproduce
Note: for the per-target values I apply np.sqrt to the 'raw_values' output instead of passing squared=False, because of Issue 16313 (fixed in #16323, but not yet available in my installation).
import numpy as np
from sklearn.metrics import mean_squared_error
# The inputs are not shown in the original report; these arrays (the
# mean_squared_error docstring example) reproduce the outputs below.
y_true = [[0.5, 1], [-1, 1], [7, -6]]
y_pred = [[0, 2], [-1, 2], [8, -5]]
# Uniform average
print(np.average(np.sqrt(mean_squared_error(y_true, y_pred, multioutput='raw_values'))))
print(mean_squared_error(y_true, y_pred, squared=False, multioutput='uniform_average'))
# Weighted average
print(np.average(np.sqrt(mean_squared_error(y_true, y_pred, multioutput='raw_values')),
                 weights=[0.3, 0.7]))
print(mean_squared_error(y_true, y_pred, squared=False, multioutput=[0.3, 0.7]))
This prints:
0.8227486121839513 # wanted
0.8416254115301732 # returned by mean_squared_error
0.8936491673103708 # wanted
0.9082951062292475 # returned by mean_squared_error
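The four numbers can be reproduced with plain NumPy, independently of the installed scikit-learn version. The sketch below assumes the inputs are the mean_squared_error docstring example arrays, because the report does not show how y_true and y_pred were defined. The "returned" values come out as the square root of the averaged per-target MSEs.

```python
import numpy as np

# Assumed inputs (mean_squared_error docstring example); not shown in the report
y_true = np.array([[0.5, 1], [-1, 1], [7, -6]])
y_pred = np.array([[0, 2], [-1, 2], [8, -5]])

# Per-target MSE: mean of the squared errors down each column
mse_per_target = np.mean((y_true - y_pred) ** 2, axis=0)

results = {}
for name, weights in (("uniform", None), ("weighted", [0.3, 0.7])):
    # Wanted: average the per-target RMSEs
    wanted = np.average(np.sqrt(mse_per_target), weights=weights)
    # Returned with squared=False: average the MSEs, then one square root
    returned = np.sqrt(np.average(mse_per_target, weights=weights))
    results[name] = (wanted, returned)
    print(f"{name}: wanted={wanted:.6f} returned={returned:.6f}")
```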
Version
System:
python: 3.7.6 | packaged by conda-forge | (default, Mar 23 2020, 23:03:20) [GCC 7.3.0]
executable: /opt/anaconda3/envs/hsi/bin/python
machine: Linux-5.3.0-46-generic-x86_64-with-debian-buster-sid
Python dependencies:
pip: 20.0.2
setuptools: 46.1.3.post20200325
sklearn: 0.22.2.post1
numpy: 1.18.1
scipy: 1.4.1
Cython: None
pandas: 1.0.3
matplotlib: 3.2.1
joblib: 0.14.1
Built with OpenMP: True