PERF call xp.searchsorted on contiguous inputs · scikit-learn/scikit-learn@97a8da3 · GitHub
[go: up one dir, main page]

Skip to content

Commit 97a8da3

Browse files
committed
PERF call xp.searchsorted on contiguous inputs
1 parent a31489c commit 97a8da3

File tree

1 file changed

+10
-5
lines changed

1 file changed

+10
-5
lines changed

sklearn/utils/stats.py

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -74,9 +74,14 @@ def _weighted_percentile(array, sample_weight, percentile_rank=50):
7474
sorted_weights[sorted_nan_mask] = 0
7575

7676
# Compute the weighted cumulative distribution function (CDF) based on
77-
# `sample_weight` and scale `percentile_rank` along it:
78-
weight_cdf = xp.cumulative_sum(sorted_weights, axis=0)
79-
adjusted_percentile_rank = percentile_rank / 100 * weight_cdf[-1, ...]
77+
# `sample_weight` and scale `percentile_rank` along it.
78+
#
79+
# Note: we call `xp.cumulative_sum` on the transposed `sorted_weights` to
80+
# ensure that the result is of shape `(n_features, n_samples)` so that the
81+
# `xp.searchsorted` calls take contiguous inputs as a result (for
82+
# performance reasons).
83+
weight_cdf = xp.cumulative_sum(sorted_weights.T, axis=1)
84+
adjusted_percentile_rank = percentile_rank / 100 * weight_cdf[..., -1]
8085

8186
# Ignore leading `sample_weight=0` observations when `percentile_rank=0` (#20528)
8287
mask = adjusted_percentile_rank == 0
@@ -88,8 +93,8 @@ def _weighted_percentile(array, sample_weight, percentile_rank=50):
8893
# (Needs to be an array as we pass to `clip` later)
8994
percentile_idx = xp.asarray(
9095
[
91-
xp.searchsorted(weight_cdf[..., i], adjusted_percentile_rank[i])
92-
for i in range(weight_cdf.shape[1])
96+
xp.searchsorted(weight_cdf[i], adjusted_percentile_rank[i])
97+
for i in range(weight_cdf.shape[0])
9398
],
9499
device=device,
95100
)

0 commit comments

Comments
 (0)
0