8000 ENH: Weighted quantile by chunweiyuan · Pull Request #9211 · numpy/numpy · GitHub
[go: up one dir, main page]

Skip to content

ENH: Weighted quantile #9211

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 4 commits into from
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
ENH: Add weights optional arg to quantile/percentile in lib.function_…
…base

     and nanquantile/nanpercentile in lib.nanfunctions (#9211)
  • Loading branch information
cwatuw committed Nov 8, 2023
commit ccc2ece47a39aa2f11b7c24a845f79bdb783b40b
240 changes: 218 additions & 22 deletions numpy/lib/_function_base_impl.py
6DB6
Original file line number Diff line number Diff line change
Expand Up @@ -3905,15 +3905,17 @@ def _median(a, axis=None, out=None, overwrite_input=False):
return rout


def _percentile_dispatcher(a, q, axis=None, out=None, overwrite_input=None,
method=None, keepdims=None, *, interpolation=None):
def _percentile_dispatcher(a, q, axis=None, weights=None, out=None,
overwrite_input=None, method=None, keepdims=None, *,
interpolation=None):
return (a, q, out)


@array_function_dispatch(_percentile_dispatcher)
def percentile(a,
q,
axis=None,
weights=None,
out=None,
overwrite_input=False,
method="linear",
Expand Down Expand Up @@ -4204,18 +4206,20 @@ def percentile(a,
if not _quantile_is_valid(q):
raise ValueError("Percentiles must be in the range [0, 100]")
return _quantile_unchecked(
a, q, axis, out, overwrite_input, method, keepdims)
a, q, axis, weights, out, overwrite_input, method, keepdims)


def _quantile_dispatcher(a, q, axis=None, out=None, overwrite_input=None,
method=None, keepdims=None, *, interpolation=None):
def _quantile_dispatcher(a, q, axis=None, weights=None, out=None,
overwrite_input=None, method=None, keepdims=None, *,
interpolation=None):
return (a, q, out)


@array_function_dispatch(_quantile_dispatcher)
def quantile(a,
q,
axis=None,
weights=None,
out=None,
overwrite_input=False,
method="linear",
Expand Down Expand Up @@ -4469,27 +4473,101 @@ def quantile(a,
if not _quantile_is_valid(q):
raise ValueError("Quantiles must be in the range [0, 1]")
return _quantile_unchecked(
a, q, axis, out, overwrite_input, method, keepdims)
a, q, axis, weights, out, overwrite_input, method, keepdims)


def _quantile_unchecked(a,
q,
axis=None,
weights=None,
out=None,
overwrite_input=False,
method="linear",
keepdims=False):
"""Assumes that q is in [0, 1], and is an ndarray"""
"""Assumes that q is in [0, 1], and is an ndarray."""
if weights is not None:
weights = _validate_and_ureduce_weights(a, axis, weights)

return _ureduce(a,
func=_quantile_ureduce_func,
q=q,
keepdims=keepdims,
axis=axis,
weights=weights,
out=out,
overwrite_input=overwrite_input,
method=method)


def _validate_and_ureduce_weights(a, axis, wgts):
"""Ensure quantile weights are valid.

Weights cannot:
* be negative
* sum to 0
However, they can be
* 0, as long as they do not sum to 0
* less than 1. In this case, all weights are re-normalized by
the lowest non-zero weight prior to computation.

Weights will be broadcasted to the shape of a, then reduced as done
via _ureduce().
* if wgts.shape == a.shape, simply return ureduced wgts.
* if scalar, broadcast to shape of a, followed by _ureduce.
* for 1-D or N-D array, must make sure that its size matches that of a
along the axis dims, so can be reshaped.
Example: if a.shape == (2, 3, 4, 5, 6), axis == (1, 3),
then weights must be of size 15, and will be reshaped to
(1, 3, 1, 5, 1) before broadcasting to the shape of a,
followed by _ureduce.
"""
a = asanyarray(a)
wgts = asanyarray(wgts)

if not np.issubdtype(wgts.dtype, float):
# convert every element to np.float
# raises ValueError if any element fails to convert
wgts = wgts.astype(float)

if np.isnan(wgts).any():
raise ValueError("No weight can be NaN.")

if (wgts < 0).any():
raise ValueError("Negative weight not allowed.")

# dims to reshape to, before broadcast
if axis is None:
dims = tuple(range(a.ndim)) # all axes
else:
dims = _nx.normalize_axis_tuple(axis, a.ndim)

# broadcast weights to the shape of array a.
if wgts.shape == a.shape: # no need to broadcast
pass
elif wgts.size == 1:
wgts = np.broadcast_to(wgts, a.shape)
elif (wgts.size == functools.reduce(lambda x, y: x * y,
[a.shape[d] for d in dims])):
to_shape = tuple(a.shape[d] if d in dims else 1 for d in range(a.ndim))
wgts = wgts.reshape(to_shape)
wgts = np.broadcast_to(wgts, a.shape)
else:
raise ValueError("Weights are not broadcastable to array. "
"Either weights is a scalar, an array the same shape "
"as a, or an array that is reshapable to the shape "
"of a's axes in the axis arg. Please see the "
"doctring for the weights arg.")

scl = wgts.sum(axis=axis)
if np.any(scl <= 0.0):
raise ZeroDivisionError("Weights sum to zero, cannot be normalized.")

# Obtain a weights array of the same shape as ureduced a
wgts = _ureduce(wgts, func=lambda x, **kwargs: x, axis=dims)

return wgts


def _quantile_is_valid(q):
# avoid expensive reductions, relevant for arrays with < O(1000) elements
if q.ndim == 1 and q.size < 10:
Expand Down Expand Up @@ -4627,6 +4705,7 @@ def _quantile_ureduce_func(
a: np.array,
q: np.array,
axis: int = None,
weights: np.ndarray = None,
out=None,
overwrite_input: bool = False,
method="linear",
Expand All @@ -4651,6 +4730,7 @@ def _quantile_ureduce_func(
result = _quantile(arr,
quantiles=q,
axis=axis,
weights=weights,
method=method,
out=out)
return result
Expand Down Expand Up @@ -4682,7 +4762,7 @@ def _get_indexes(arr, virtual_indexes, valid_values_count):
next_indexes[indexes_below_bounds] = 0
if np.issubdtype(arr.dtype, np.inexact):
# After the sort, slices having NaNs will have for last element a NaN
virtual_indexes_nans = np.isnan(virtual_indexes)
virtual_indexes_nans = np.isnan(virtual_indexes) # indexes have nans?
if virtual_indexes_nans.any():
previous_indexes[virtual_indexes_nans] = -1
next_indexes[virtual_indexes_nans] = -1
Expand All @@ -4691,10 +4771,25 @@ def _get_indexes(arr, virtual_indexes, valid_values_count):
return previous_indexes, next_indexes


def _map_weighted_indexes_to_real_indexes(w_indexes, weights, axis,
w_sums, n_vals):
"""Map indexes in weighted space back to original space.

Args:
w_indexes: array_like
1-D array of indexes in weight space.
weights: array_like
Number of values in weight-expandec space. Could be ndarray.

"""
return


def _quantile(
arr: np.array,
quantiles: np.array,
axis: int = -1,
weights: np.ndarray = None,
method="linear",
out=None,
):
Expand All @@ -4712,12 +4807,9 @@ def _quantile(
"""
# --- Setup
arr = np.asanyarray(arr)
values_count = arr.shape[axis]
# The dimensions of `q` are prepended to the output shape, so we need the
# axis being sampled from `arr` to be last.
# The dimensions of `q` are prepended to the output shape, so we need
# the axis being sampled from `arr` to be last.

if axis != 0: # But moveaxis is slow, so only call it if necessary.
arr = np.moveaxis(arr, axis, destination=0)
# --- Computation of indexes
# Index where to find the value in the sorted array.
# Virtual because it is a floating point value, not an valid index.
Expand All @@ -4728,12 +4820,122 @@ def _quantile(
raise ValueError(
f"{method!r} is not a valid method. Use one of: "
f"{_QuantileMethods.keys()}") from None
virtual_indexes = method["get_virtual_index"](values_count, quantiles)
virtual_indexes = np.asanyarray(virtual_indexes)

supports_nans = (
np.issubdtype(arr.dtype, np.inexact) or arr.dtype.kind in 'Mm')

if weights is not None: # weights is the same shape as arr
# first move data to axis=-1 for np.vectorize, which runs on axis=-1
if axis != -1:
arr = np.moveaxis(arr, axis, destination=-1)
weights = np.moveaxis(weights, axis, destination=-1)
values_count = arr.shape[-1]

# Now check/renormalize weights if any is (0, 1)
def _normalize(v):
inds = v > 0
if (v[inds] < 1).any():
vec = v.copy()
vec[inds] = vec[inds] / vec[inds].min() # renormalization
return vec
else:
return v
weights = np.apply_along_axis(_normalize, -1, weights)

# values with weight=0 are made nan and later sorted to end of array
if (weights == 0).any():
weights[weights == 0] = np.nan

def _sort_by_index(vector, vec_indices):
return vector[vec_indices]
# this func vectorizes sort along axis
n = values_count
arraysort = np.vectorize(_sort_by_index,
signature=f"({n}),({n})->({n})")
# compute the sorted array. nans will be sorted to the end.
ind_sorted = np.argsort(arr, axis=-1)
arr = arraysort(arr, ind_sorted)
# need to align weights to now-sorted arr, hence np.vectorize is used
weights = arraysort(weights, ind_sorted)

if supports_nans:
slices_having_nans = np.isnan(
take(arr, indices=-1, axis=-1)
)
else:
slices_having_nans = None

# convert weights to virtual indexes.
def _map_weights_to_indexes(wgts, quantiles, hard_indexes):
# NOTE 1-D function used within apply_along_axis() might be slower,
# but probably incurs less memory footprint.
wgts_cumsum = wgts.cumsum()
n = wgts_cumsum[-1] # sum along axis replaces n in Hyndman paper.
weight_space_indexes = method["get_virtual_index"](n, quantiles)
left_weight_bound = np.roll(wgts_cumsum, 1)
left_weight_bound[0] = 0
right_weight_bound = wgts_cumsum - 1
indexes = np.zeros(2 * hard_indexes.size)
weights = indexes.copy()
indexes[0::2] = hard_indexes
indexes[1::2] = hard_indexes
weights[0::2] = left_weight_bound
weights[1::2] = right_weight_bound
return np.interp(weight_space_indexes, weights, indexes)

hard_indexes = np.arange(values_count)

virtual_indexes = np.apply_along_axis(_map_weights_to_indexes,
-1, weights,
quantiles=quantiles,
hard_indexes=hard_indexes)
# unlike the no-weights case, virtual_indexes here can be N-D,
# with different indexes along every DATA_AXIS slice.
# also the use of np.interp means virtual_indexes is of type float.
previous_indexes, next_indexes = _get_indexes(arr, virtual_indexes,
values_count)

# this function becomes vectorized
def _get_values_from_indexes(arr, virtual_indexes,
previous_indexes, next_indexes):
previous = take(arr, previous_indexes)
next = take(arr, next_indexes)
gamma = _get_gamma(virtual_indexes, previous_indexes, method)
return _lerp(previous, next, gamma, out=out)

m = virtual_indexes.shape[-1]
get_values_from_indexes =\
np.vectorize(_get_values_from_indexes,
signature=f"({n}),({m}),({m}),({m})->({m})")
result = get_values_from_indexes(arr, virtual_indexes,
previous_indexes, next_indexes)
# now move data to DATA_AXIS to be consistent with no-weights case
result = np.moveaxis(result, -1, destination=0)

else:
values_count = arr.shape[axis]
if axis != 0: # moveaxis is slow, so only call it if necessary.
arr = np.moveaxis(arr, axis, destination=0)

virtual_indexes = method["get_virtual_index"](values_count, quantiles)
virtual_indexes = np.asanyarray(virtual_indexes)

result, slices_having_nans =\
_take_sorted_index_values(arr, method, virtual_indexes,
values_count, supports_nans, out)

if np.any(slices_having_nans):
if result.ndim == 0 and out is None:
# can't write to a scalar, but indexing will be correct
result = arr[-1]
else:
np.copyto(result, arr[-1, ...], where=slices_having_nans)
return result


def _take_sorted_index_values(arr, method, virtual_indexes, values_count,
supports_nans, out):

if np.issubdtype(virtual_indexes.dtype, np.integer):
# No interpolation needed, take the points along axis
if supports_nans:
Expand Down Expand Up @@ -4771,13 +4973,7 @@ def _quantile(
next,
gamma,
out=out)
if np.any(slices_having_nans):
if result.ndim == 0 and out is None:
# can't write to a scalar, but indexing will be correct
result = arr[-1]
else:
np.copyto(result, arr[-1, ...], where=slices_having_nans)
return result
return result, slices_having_nans


def _trapz_dispatcher(y, x=None, dx=None, axis=None):
Expand Down
Loading
0