@@ -74,9 +74,14 @@ def _weighted_percentile(array, sample_weight, percentile_rank=50):
74
74
sorted_weights [sorted_nan_mask ] = 0
75
75
76
76
# Compute the weighted cumulative distribution function (CDF) based on
77
- # `sample_weight` and scale `percentile_rank` along it:
78
- weight_cdf = xp .cumulative_sum (sorted_weights , axis = 0 )
79
- adjusted_percentile_rank = percentile_rank / 100 * weight_cdf [- 1 , ...]
77
+ # `sample_weight` and scale `percentile_rank` along it.
78
+ #
79
+ # Note: we call `xp.cumulative_sum` on the transposed `sorted_weights` to
80
+ # ensure that the result is of shape `(n_features, n_samples)` and that the
81
+ # `xp.searchsorted` calls take contiguous inputs as a result (for
82
+ # performance reasons).
83
+ weight_cdf = xp .cumulative_sum (sorted_weights .T , axis = 1 )
84
+ adjusted_percentile_rank = percentile_rank / 100 * weight_cdf [..., - 1 ]
80
85
81
86
# Ignore leading `sample_weight=0` observations when `percentile_rank=0` (#20528)
82
87
mask = adjusted_percentile_rank == 0
@@ -88,8 +93,8 @@ def _weighted_percentile(array, sample_weight, percentile_rank=50):
88
93
# (Needs to be an array as we pass to `clip` later)
89
94
percentile_idx = xp .asarray (
90
95
[
91
- xp .searchsorted (weight_cdf [..., i ], adjusted_percentile_rank [i ])
92
- for i in range (weight_cdf .shape [1 ])
96
+ xp .searchsorted (weight_cdf [i ], adjusted_percentile_rank [i ])
97
+ for i in range (weight_cdf .shape [0 ])
93
98
],
94
99
device = device ,
95
100
)
0 commit comments