Improve variable names and comments in xp.searchsorted calls · scikit-learn/scikit-learn@9c7bb75 · GitHub
[go: up one dir, main page]

Skip to content

Commit 9c7bb75

Browse files
committed
Improve variable names and comments in xp.searchsorted calls
1 parent 97a8da3 commit 9c7bb75

File tree

1 file changed

+11
-9
lines changed

1 file changed

+11
-9
lines changed

sklearn/utils/stats.py

Lines changed: 11 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -88,22 +88,24 @@ def _weighted_percentile(array, sample_weight, percentile_rank=50):
8888
adjusted_percentile_rank[mask] = xp.nextafter(
8989
adjusted_percentile_rank[mask], adjusted_percentile_rank[mask] + 1
9090
)
91-
# Find index (i) of `adjusted_percentile_rank` in `weight_cdf`,
92-
# such that weight_cdf[i-1] < percentile <= weight_cdf[i]
93-
# (Needs to be an array as we pass to `clip` later)
94-
percentile_idx = xp.asarray(
91+
# For each feature with index j, find sample index i of the scalar value
92+
# `adjusted_percentile_rank[j]` in 1D array `weight_cdf[j]`, such that:
93+
# weight_cdf[j, i-1] < adjusted_percentile_rank[j] <= weight_cdf[j, i].
94+
percentile_indices = xp.asarray(
9595
[
96-
xp.searchsorted(weight_cdf[i], adjusted_percentile_rank[i])
97-
for i in range(weight_cdf.shape[0])
96+
xp.searchsorted(
97+
weight_cdf[feature_idx], adjusted_percentile_rank[feature_idx]
98+
)
99+
for feature_idx in range(weight_cdf.shape[0])
98100
],
99101
device=device,
100102
)
101-
# In rare cases, `percentile_idx` equals to `sorted_idx.shape[0]`
103+
# In rare cases, `percentile_indices` equals `sorted_idx.shape[0]`
102104
max_idx = sorted_idx.shape[0] - 1
103-
percentile_idx = xp.clip(percentile_idx, 0, max_idx)
105+
percentile_indices = xp.clip(percentile_indices, 0, max_idx)
104106

105107
col_indices = xp.arange(array.shape[1], device=device)
106-
percentile_in_sorted = sorted_idx[percentile_idx, col_indices]
108+
percentile_in_sorted = sorted_idx[percentile_indices, col_indices]
107109

108110
result = array[percentile_in_sorted, col_indices]
109111

0 commit comments

Comments
 (0)
0