8000 FIX np.digitize fixes in calibration and discretization (#22526) · scikit-learn/scikit-learn@4d9e005 · GitHub
[go: up one dir, main page]

Skip to content

Commit 4d9e005

Browse files
Micky774thomasjpfanamuellerjeremiedbb
authored
FIX np.digitize fixes in calibration and discretization (#22526)
Co-authored-by: Thomas J. Fan <thomasjpfan@gmail.com> Co-authored-by: Andreas Mueller <andreas.mueller@columbia.edu> Co-authored-by: Jérémie du Boisberranger <34657725+jeremiedbb@users.noreply.github.com>
1 parent 6016083 commit 4d9e005

File tree

3 files changed

+16
-12
lines changed

3 files changed

+16
-12
lines changed

doc/whats_new/v1.1.rst

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,13 @@ random sampling procedures.
5151
`feature_names_in_` to be defined, columns must be all strings. :pr:`22410` by
5252
`Thomas Fan`_.
5353

54+
- |Fix| :class:`preprocessing.KBinsDiscretizer` changed handling of bin edges
55+
slightly, which might result in a different encoding with the same data.
56+
57+
- |Fix| :func:`calibration.calibration_curve` changed handling of bin
58+
edges slightly, which might result in a different output curve given the same
59+
data.
60+
5461
Changelog
5562
---------
5663

@@ -112,6 +119,9 @@ Changelog
112119
`pos_label` to specify the positive class label.
113120
:pr:`21032` by :user:`Guillaume Lemaitre <glemaitre>`.
114121

122+
- |Fix| :func:`calibration.calibration_curve` handles bin edges more consistently now.
123+
:pr:`14975` by `Andreas Müller`_ and :pr:`22526` by :user:`Meekail Zain <micky774>`.
124+
115125
- |Enhancement| :class:`CalibrationDisplay` accepts a parameter `pos_label` to
116126
add this information to the plot.
117127
:pr:`21038` by :user:`Guillaume Lemaitre <glemaitre>`.
@@ -672,6 +682,9 @@ Changelog
672682
:class:`preprocessing.OrdinalEncoder`, and
673683
:class:`preprocessing.Binarizer`. :pr:`21079` by `Thomas Fan`_.
674684

685+
- |Fix| :class:`preprocessing.KBinDiscretizer` handles bin edges more consistently now.
686+
:pr:`14975` by `Andreas Müller`_ and :pr:`22526` by :user:`Meekail Zain <micky774>`.
687+
675688
:mod:`sklearn.random_projection`
676689
................................
677690

sklearn/calibration.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -964,16 +964,15 @@ def calibration_curve(
964964
if strategy == "quantile": # Determine bin edges by distribution of data
965965
quantiles = np.linspace(0, 1, n_bins + 1)
966966
bins = np.percentile(y_prob, quantiles * 100)
967-
bins[-1] = bins[-1] + 1e-8
968967
elif strategy == "uniform":
969-
bins = np.linspace(0.0, 1.0 + 1e-8, n_bins + 1)
968+
bins = np.linspace(0.0, 1.0, n_bins + 1)
970969
else:
971970
raise ValueError(
972971
"Invalid entry to 'strategy' input. Strategy "
973972
"must be either 'quantile' or 'uniform'."
974973
)
975974

976-
binids = np.digitize(y_prob, bins) - 1
975+
binids = np.searchsorted(bins[1:-1], y_prob)
977976

978977
bin_sums = np.bincount(binids, weights=y_prob, minlength=len(bins))
979978
bin_true = np.bincount(binids, weights=y_true, minlength=len(bins))

sklearn/preprocessing/_discretization.py

Lines changed: 1 addition & 9 deletions
< 8000 /div>
Original file line numberDiff line numberDiff line change
@@ -374,15 +374,7 @@ def transform(self, X):
374374

375375
bin_edges = self.bin_edges_
376376
for jj in range(Xt.shape[1]):
377-
# Values which are close to a bin edge are susceptible to numeric
378-
# instability. Add eps to X so these values are binned correctly
379-
# with respect to their decimal truncation. See documentation of
380-
# numpy.isclose for an explanation of ``rtol`` and ``atol``.
381-
rtol = 1.0e-5
382-
atol = 1.0e-8
383-
eps = atol + rtol * np.abs(Xt[:, jj])
384-
Xt[:, jj] = np.digitize(Xt[:, jj] + eps, bin_edges[jj][1:])
385-
np.clip(Xt, 0, self.n_bins_ - 1, out=Xt)
377+
Xt[:, jj] = np.searchsorted(bin_edges[jj][1:-1], Xt[:, jj], side="right")
386378

387379
if self.encode == "ordinal":
388380
return Xt

0 commit comments

Comments
 (0)
0