@@ -88,22 +88,24 @@ def _weighted_percentile(array, sample_weight, percentile_rank=50):
88
88
adjusted_percentile_rank [mask ] = xp .nextafter (
89
89
adjusted_percentile_rank [mask ], adjusted_percentile_rank [mask ] + 1
90
90
)
91
- # Find index (i) of `adjusted_percentile_rank` in `weight_cdf`,
92
- # such that weight_cdf[i-1] < percentile <= weight_cdf[i]
93
- # (Needs to be an array as we pass to `clip` later)
94
- percentile_idx = xp .asarray (
91
+ # For each feature with index j, find sample index i of the scalar value
92
+ # `adjusted_percentile_rank[j]` in 1D array `weight_cdf[j]`, such that:
93
+ # weight_cdf[j, i-1] < adjusted_percentile_rank[j] <= weight_cdf[j, i].
94
+ percentile_indices = xp .asarray (
95
95
[
96
- xp .searchsorted (weight_cdf [i ], adjusted_percentile_rank [i ])
97
- for i in range (weight_cdf .shape [0 ])
96
+ xp .searchsorted (
97
+ weight_cdf [feature_idx ], adjusted_percentile_rank [feature_idx ]
98
+ )
99
+ for feature_idx in range (weight_cdf .shape [0 ])
98
100
],
99
101
device = device ,
100
102
)
101
- # In rare cases, `percentile_idx ` equals to `sorted_idx.shape[0]`
103
+ # In rare cases, an entry of `percentile_indices` equals `sorted_idx.shape[0]`
102
104
max_idx = sorted_idx .shape [0 ] - 1
103
- percentile_idx = xp .clip (percentile_idx , 0 , max_idx )
105
+ percentile_indices = xp .clip (percentile_indices , 0 , max_idx )
104
106
105
107
col_indices = xp .arange (array .shape [1 ], device = device )
106
- percentile_in_sorted = sorted_idx [percentile_idx , col_indices ]
108
+ percentile_in_sorted = sorted_idx [percentile_indices , col_indices ]
107
109
108
110
result = array [percentile_in_sorted , col_indices ]
109
111
0 commit comments