FEA Add QuantileRegressor estimator (#9978) · murata-yu/scikit-learn@c1cc67d · GitHub

Commit c1cc67d

FEA Add QuantileRegressor estimator (scikit-learn#9978)

Authored by David Dale (avidale) and Christian Lorentzen (lorentzenchr)

Co-authored-by: David Dale <ddale@yandex-team.ru>
Co-authored-by: Christian Lorentzen <lorentzen.ch@gmail.com>
1 parent 88be3c1 commit c1cc67d

File tree: 7 files changed, +729 -0 lines changed

doc/modules/classes.rst

Lines changed: 1 addition & 0 deletions
@@ -839,6 +839,7 @@ Any estimator using the Huber loss would also be robust to outliers, e.g.
    :template: class.rst
 
    linear_model.HuberRegressor
+   linear_model.QuantileRegressor
    linear_model.RANSACRegressor
    linear_model.TheilSenRegressor

doc/modules/linear_model.rst

Lines changed: 77 additions & 0 deletions
@@ -1423,6 +1423,83 @@ Note that this estimator is different from the R implementation of Robust Regres
 squares implementation with weights given to each sample on the basis of how much the residual is
 greater than a certain threshold.
 
+.. _quantile_regression:
+
+Quantile Regression
+===================
+
+Quantile regression estimates the median or other quantiles of :math:`y`
+conditional on :math:`X`, while ordinary least squares (OLS) estimates the
+conditional mean.
+
+As a linear model, the :class:`QuantileRegressor` gives linear predictions
+:math:`\hat{y}(w, X) = Xw` for the :math:`q`-th quantile, :math:`q \in (0, 1)`.
+The weights or coefficients :math:`w` are then found by the following
+minimization problem:
+
+.. math::
+    \min_{w} {\frac{1}{n_{\text{samples}}}
+    \sum_i PB_q(y_i - X_i w) + \alpha ||w||_1}.
+
+This consists of the pinball loss (also known as linear loss),
+see also :func:`~sklearn.metrics.mean_pinball_loss`,
+
+.. math::
+    PB_q(t) = q \max(t, 0) + (1 - q) \max(-t, 0) =
+    \begin{cases}
+        q t, & t > 0, \\
+        0, & t = 0, \\
+        (q - 1) t, & t < 0
+    \end{cases}
+
+and the L1 penalty controlled by parameter ``alpha``, similar to
+:class:`Lasso`.
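A quick numeric check of the loss definition above (illustrative only, not part of this commit): the by-hand pinball loss should match :func:`~sklearn.metrics.mean_pinball_loss`, whose quantile-level parameter is also named ``alpha`` (unrelated to the L1 penalty ``alpha`` of the estimator):

    import numpy as np
    from sklearn.metrics import mean_pinball_loss

    q = 0.9
    y_true = np.array([1.0, 2.0, 3.0])
    y_pred = np.array([1.5, 2.0, 2.0])

    # PB_q(t) = q * max(t, 0) + (1 - q) * max(-t, 0), with t = y_true - y_pred
    t = y_true - y_pred
    by_hand = np.mean(q * np.maximum(t, 0) + (1 - q) * np.maximum(-t, 0))

    print(by_hand)                                     # 0.31666...
    print(mean_pinball_loss(y_true, y_pred, alpha=q))  # identical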
+As the pinball loss is only linear in the residuals, quantile regression is
+much more robust to outliers than squared-error-based estimation of the mean.
+Somewhat in between is the :class:`HuberRegressor`.
+
+Quantile regression may be useful if one is interested in predicting an
+interval instead of a point prediction. Sometimes, prediction intervals are
+calculated based on the assumption that the prediction error is distributed
+normally with zero mean and constant variance. Quantile regression provides
+sensible prediction intervals even for errors with non-constant (but
+predictable) variance or a non-normal distribution.
+
+.. figure:: /auto_examples/linear_model/images/sphx_glr_plot_quantile_regression_001.png
+   :target: ../auto_examples/linear_model/plot_quantile_regression.html
+   :align: center
+   :scale: 50%
+
+Based on minimizing the pinball loss, conditional quantiles can also be
+estimated by models other than linear models. For example,
+:class:`~sklearn.ensemble.GradientBoostingRegressor` can predict conditional
+quantiles if its parameter ``loss`` is set to ``"quantile"`` and parameter
+``alpha`` is set to the quantile that should be predicted. See the example in
+:ref:`sphx_glr_auto_examples_ensemble_plot_gradient_boosting_quantile.py`.
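As a minimal sketch of that option (not part of this commit; the data here is made up for illustration):

    import numpy as np
    from sklearn.ensemble import GradientBoostingRegressor

    rng = np.random.RandomState(0)
    X = rng.uniform(0, 10, size=(200, 1))
    y = 10 + 0.5 * X.ravel() + rng.normal(scale=1 + X.ravel(), size=200)

    # One model per quantile level; here ``alpha`` is the target quantile,
    # not a regularization strength.
    gbr_low = GradientBoostingRegressor(loss="quantile", alpha=0.05).fit(X, y)
    gbr_high = GradientBoostingRegressor(loss="quantile", alpha=0.95).fit(X, y)

    # Empirical coverage of the resulting 90% prediction interval.
    inside = (y >= gbr_low.predict(X)) & (y <= gbr_high.predict(X))
    print(inside.mean())  # close to 0.9 on the training data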
+Most implementations of quantile regression solve a linear programming
+problem. The current implementation is based on
+:func:`scipy.optimize.linprog`.
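To make the linear programming connection concrete, here is a sketch of the unpenalized problem (an illustration under stated assumptions, not the estimator's actual implementation; it assumes SciPy's "highs" solver is available). Splitting coefficients and residuals into non-negative parts turns the pinball loss into a linear objective:

    import numpy as np
    from scipy.optimize import linprog

    def quantile_lp(X, y, q):
        """Fit the q-th quantile via an LP (alpha=0 case, no intercept).

        Variables: [w_plus, w_minus, u_plus, u_minus]. linprog's default
        bounds (0, None) make them non-negative, so w = w_plus - w_minus
        and the residual y - X @ w = u_plus - u_minus.
        """
        n, p = X.shape
        # Objective: mean pinball loss = (q * u_plus + (1 - q) * u_minus) / n
        c = np.concatenate(
            [np.zeros(2 * p), q * np.ones(n), (1 - q) * np.ones(n)]
        ) / n
        # Equality constraint: X @ w_plus - X @ w_minus + u_plus - u_minus = y
        A_eq = np.hstack([X, -X, np.eye(n), -np.eye(n)])
        res = linprog(c, A_eq=A_eq, b_eq=y, method="highs")
        return res.x[:p] - res.x[p:2 * p]

    # For an intercept term, prepend a column of ones to X.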
+.. topic:: Examples:
+
+  * :ref:`sphx_glr_auto_examples_linear_model_plot_quantile_regression.py`
+
+.. topic:: References:
+
+  * Koenker, R., & Bassett Jr, G. (1978). `Regression quantiles.
+    <https://gib.people.uic.edu/RQ.pdf>`_
+    Econometrica: Journal of the Econometric Society, 46(1), 33-50.
+
+  * Portnoy, S., & Koenker, R. (1997). The Gaussian hare and the Laplacian
+    tortoise: computability of squared-error versus absolute-error estimators.
+    Statistical Science, 12, 279-300. https://doi.org/10.1214/ss/1030037960
+
+  * Koenker, R. (2005). Quantile Regression.
+    Cambridge University Press. https://doi.org/10.1017/CBO9780511754098
+
 .. _polynomial_regression:
 
 Polynomial regression: extending linear models with basis functions
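A usage sketch of the new estimator (the quantile and alpha values mirror the example file added below; the coverage check is illustrative only):

    import numpy as np
    from sklearn.linear_model import QuantileRegressor

    rng = np.random.RandomState(42)
    x = np.linspace(0, 10, 100)
    X = x[:, np.newaxis]
    # Heteroscedastic target, as in the example file in this commit.
    y = 10 + 0.5 * x + rng.normal(loc=0, scale=0.5 + 0.5 * x, size=x.shape[0])

    for q in [0.05, 0.5, 0.95]:
        qr = QuantileRegressor(quantile=q, alpha=0)  # alpha=0: no L1 penalty
        qr.fit(X, y)
        # The fraction of training targets at or below the fitted line
        # should be close to q.
        print(q, (y <= qr.predict(X)).mean())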

doc/whats_new/v1.0.rst

Lines changed: 5 additions & 0 deletions
@@ -282,6 +282,11 @@ Changelog
 :mod:`sklearn.linear_model`
 ...........................
 
+- |Feature| Added :class:`linear_model.QuantileRegressor` which implements
+  linear quantile regression with L1 penalty.
+  :pr:`9978` by :user:`David Dale <avidale>` and
+  :user:`Christian Lorentzen <lorentzenchr>`.
+
 - |Feature| The new :class:`linear_model.SGDOneClassSVM` provides an SGD
   implementation of the linear One-Class SVM. Combined with kernel
   approximation techniques, this implementation approximates the solution of
examples/linear_model/plot_quantile_regression.py (new file)

Lines changed: 110 additions & 0 deletions

@@ -0,0 +1,110 @@
"""
===================
Quantile regression
===================

This example illustrates how quantile regression can predict non-trivial
conditional quantiles.

The left figure shows the case when the error distribution is normal,
but has non-constant variance, i.e. with heteroscedasticity.

The right figure shows an example of an asymmetric error distribution,
namely the Pareto distribution.
"""
print(__doc__)

# Authors: David Dale <dale.david@mail.ru>
#          Christian Lorentzen <lorentzen.ch@gmail.com>
# License: BSD 3 clause
import numpy as np
import matplotlib.pyplot as plt

from sklearn.linear_model import QuantileRegressor, LinearRegression
from sklearn.metrics import mean_absolute_error, mean_squared_error
from sklearn.model_selection import cross_val_score


def plot_points_highlighted(x, y, model_low, model_high, ax):
    """Plot points; cross-mark those outside the low/high quantile models.

    Predictions use the module-level design matrix ``X``.
    """
    mask = y <= model_low.predict(X)
    ax.scatter(x[mask], y[mask], c="k", marker="x")
    mask = y > model_high.predict(X)
    ax.scatter(x[mask], y[mask], c="k", marker="x")
    mask = (y > model_low.predict(X)) & (y <= model_high.predict(X))
    ax.scatter(x[mask], y[mask], c="k")


fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 5), sharey=True)

rng = np.random.RandomState(42)
x = np.linspace(0, 10, 100)
X = x[:, np.newaxis]
# Heteroscedastic Normal noise: the scale grows with x.
y = 10 + 0.5 * x + rng.normal(loc=0, scale=0.5 + 0.5 * x, size=x.shape[0])
y_mean = 10 + 0.5 * x
ax1.plot(x, y_mean, "k--")

quantiles = [0.05, 0.5, 0.95]
models = []
for quantile in quantiles:
    qr = QuantileRegressor(quantile=quantile, alpha=0)
    qr.fit(X, y)
    ax1.plot(x, qr.predict(X))
    models.append(qr)

plot_points_highlighted(x, y, models[0], models[2], ax1)
ax1.set_xlabel("x")
ax1.set_ylabel("y")
ax1.set_title("Quantiles of heteroscedastic Normal distributed target")
ax1.legend(["true mean"] + quantiles)


# Asymmetric Pareto noise, shifted to have zero mean.
a = 5
y = 10 + 0.5 * x + 10 * (rng.pareto(a, size=x.shape[0]) - 1 / (a - 1))
ax2.plot(x, y_mean, "k--")

models = []
for quantile in quantiles:
    qr = QuantileRegressor(quantile=quantile, alpha=0)
    qr.fit(X, y)
    ax2.plot([0, 10], qr.predict([[0], [10]]))
    models.append(qr)

plot_points_highlighted(x, y, models[0], models[2], ax2)
ax2.set_xlabel("x")
ax2.set_ylabel("y")
ax2.set_title("Quantiles of asymmetric Pareto distributed target")
ax2.legend(["true mean"] + quantiles, loc="lower right")
ax2.yaxis.set_tick_params(labelbottom=True)

plt.show()

# %%
# Note that both targets have the same mean value, indicated by the dashed
# black line. As the Normal distribution is symmetric, mean and median are
# identical and the predicted 0.5 quantile almost hits the true mean.
# In the Pareto case, the difference between the predicted median and the
# true mean is evident. We also marked the points below the 0.05 and above
# the 0.95 predicted quantiles by small crosses. You might count them and
# consider that we have 100 samples in total.
#
# The second part of the example shows that LinearRegression minimizes MSE
# in order to predict the mean, while QuantileRegressor with `quantile=0.5`
# minimizes MAE in order to predict the median. Both do their own job well.

models = [LinearRegression(), QuantileRegressor(alpha=0)]
names = ["OLS", "Quantile"]

print("# In-sample performance")
for model_name, model in zip(names, models):
    print(model_name + ":")
    model.fit(X, y)
    mae = mean_absolute_error(model.predict(X), y)
    rmse = np.sqrt(mean_squared_error(model.predict(X), y))
    print(f"MAE = {mae:.4}  RMSE = {rmse:.4}")

print("\n# Cross-validated performance")
for model_name, model in zip(names, models):
    print(model_name + ":")
    mae = -cross_val_score(model, X, y, cv=3,
                           scoring="neg_mean_absolute_error").mean()
    rmse = np.sqrt(-cross_val_score(model, X, y, cv=3,
                                    scoring="neg_mean_squared_error").mean())
    print(f"MAE = {mae:.4}  RMSE = {rmse:.4}")

sklearn/linear_model/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -28,6 +28,7 @@
 from ._passive_aggressive import PassiveAggressiveRegressor
 from ._perceptron import Perceptron
 
+from ._quantile import QuantileRegressor
 from ._ransac import RANSACRegressor
 from ._theil_sen import TheilSenRegressor
 
@@ -59,6 +60,7 @@
     'PassiveAggressiveClassifier',
     'PassiveAggressiveRegressor',
     'Perceptron',
+    'QuantileRegressor',
     'Ridge',
     'RidgeCV',
     'RidgeClassifier',

0 commit comments