8000 Reimplement quantile with apply_ufunc (#3559) · pydata/xarray@7dfdfca · GitHub
[go: up one dir, main page]

Skip to content

Commit 7dfdfca

Browse files
authored
Reimplement quantile with apply_ufunc (#3559)
* Reimplement quantile with apply_ufunc * Update xarray/core/variable.py Co-Authored-By: Stephan Hoyer <shoyer@google.com> * Update doc/whats-new.rst
1 parent 8a148b6 commit 7dfdfca

File tree

6 files changed

+91
-69
lines changed

6 files changed

+91
-69
lines changed

doc/whats-new.rst

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,9 @@ Breaking changes
2525

2626
New Features
2727
~~~~~~~~~~~~
28+
- :py:meth:`Dataset.quantile`, :py:meth:`DataArray.quantile` and ``GroupBy.quantile``
29+
now work with dask Variables.
30+
By `Deepak Cherian <https://github.com/dcherian>`_.
2831

2932

3033
Bug fixes

xarray/core/dataset.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5166,11 +5166,7 @@ def quantile(
51665166
new = self._replace_with_new_dims(
51675167
variables, coord_names=coord_names, attrs=attrs, indexes=indexes
51685168
)
5169-
if "quantile" in new.dims:
5170-
new.coords["quantile"] = Variable("quantile", q)
5171-
else:
5172-
new.coords["quantile"] = q
5173-
return new
5169+
return new.assign_coords(quantile=q)
51745170

51755171
def rank(self, dim, pct=False, keep_attrs=None):
51765172
"""Ranks the data.

xarray/core/variable.py

Lines changed: 34 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -1716,40 +1716,45 @@ def quantile(self, q, dim=None, interpolation="linear", keep_attrs=None):
17161716
numpy.nanpercentile, pandas.Series.quantile, Dataset.quantile,
17171717
DataArray.quantile
17181718
"""
1719-
if isinstance(self.data, dask_array_type):
1720-
raise TypeError(
1721-
"quantile does not work for arrays stored as dask "
1722-
"arrays. Load the data via .compute() or .load() "
1723-
"prior to calling this method."
1724-
)
17251719

1726-
q = np.asarray(q, dtype=np.float64)
1727-
1728-
new_dims = list(self.dims)
1729-
if dim is not None:
1730-
axis = self.get_axis_num(dim)
1731-
if utils.is_scalar(dim):
1732-
new_dims.remove(dim)
1733-
else:
1734-
for d in dim:
1735-
new_dims.remove(d)
1736-
else:
1737-
axis = None
1738-
new_dims = []
1739-
1740-
# Only add the quantile dimension if q is array-like
1741-
if q.ndim != 0:
1742-
new_dims = ["quantile"] + new_dims
1743-
1744-
qs = np.nanpercentile(
1745-
self.data, q * 100.0, axis=axis, interpolation=interpolation
1746-
)
1720+
from .computation import apply_ufunc
17471721

17481722
if keep_attrs is None:
17491723
keep_attrs = _get_keep_attrs(default=False)
1750-
attrs = self._attrs if keep_attrs else None
17511724

1752-
return Variable(new_dims, qs, attrs)
1725+
scalar = utils.is_scalar(q)
1726+
q = np.atleast_1d(np.asarray(q, dtype=np.float64))
1727+
1728+
if dim is None:
1729+
dim = self.dims
1730+
1731+
if utils.is_scalar(dim):
1732+
dim = [dim]
1733+
1734+
def _wrapper(npa, **kwargs):
1735+
# move quantile axis to end. required for apply_ufunc
1736+
return np.moveaxis(np.nanpercentile(npa, **kwargs), 0, -1)
1737+
1738+
axis = np.arange(-1, -1 * len(dim) - 1, -1)
1739+
result = apply_ufunc(
1740+
_wrapper,
1741+
self,
1742+
input_core_dims=[dim],
1743+
exclude_dims=set(dim),
1744+
output_core_dims=[["quantile"]],
1745+
output_dtypes=[np.float64],
1746+
output_sizes={"quantile": len(q)},
1747+
dask="parallelized",
1748+
kwargs={"q": q * 100, "axis": axis, "interpolation": interpolation},
1749+
)
1750+
1751+
# for backward compatibility
1752+
result = result.transpose("quantile", ...)
1753+
if scalar:
1754+
result = result.squeeze("quantile")
1755+
if keep_attrs:
1756+
result.attrs = self._attrs
1757+
return result
17531758

17541759
def rank(self, dim, pct=False):
17551760
"""Ranks the data.

xarray/tests/test_dataarray.py

Lines changed: 16 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@
1515
from xarray.core import dtypes
1616
from xarray.core.common import full_like
1717
from xarray.core.indexes import propagate_indexes
18+
from xarray.core.utils import is_scalar
19+
1820
from xarray.tests import (
1921
LooseVersion,
2022
ReturnItem,
@@ -2330,17 +2332,20 @@ def test_reduce_out(self):
23302332
with pytest.raises(TypeError):
23312333
orig.mean(out=np.ones(orig.shape))
23322334

2333-
def test_quantile(self):
2334-
for q in [0.25, [0.50], [0.25, 0.75]]:
2335-
for axis, dim in zip(
2336-
[None, 0, [0], [0, 1]], [None, "x", ["x"], ["x", "y"]]
2337-
):
2338-
actual = DataArray(self.va).quantile(q, dim=dim, keep_attrs=True)
2339-
expected = np.nanpercentile(
2340-
self.dv.values, np.array(q) * 100, axis=axis
2341-
)
2342-
np.testing.assert_allclose(actual.values, expected)
2343-
assert actual.attrs == self.attrs
2335+
@pytest.mark.parametrize("q", [0.25, [0.50], [0.25, 0.75]])
2336+
@pytest.mark.parametrize(
2337+
"axis, dim", zip([None, 0, [0], [0, 1]], [None, "x", ["x"], ["x", "y"]])
2338+
)
2339+
def test_quantile(self, q, axis, dim):
2340+
actual = DataArray(self.va).quantile(q, dim=dim, keep_attrs=True)
2341+
expected = np.nanpercentile(self.dv.values, np.array(q) * 100, axis=axis)
2342+
np.testing.assert_allclose(actual.values, expected)
2343+
if is_scalar(q):
2344+
assert "quantile" not in actual.dims
2345+
else:
2346+
assert "quantile" in actual.dims
2347+
2348+
assert actual.attrs == self.attrs
23442349

23452350
def test_reduce_keep_attrs(self):
23462351
# Test dropped attrs

xarray/tests/test_dataset.py

Lines changed: 16 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
from xarray.core.common import duck_array_ops, full_like
2929
from xarray.core.npcompat import IS_NEP18_ACTIVE
3030
from xarray.core.pycompat import integer_types
31+
from xarray.core.utils import is_scalar
3132

3233
from . import (
3334
InaccessibleArray,
@@ -4575,21 +4576,24 @@ def test_reduce_keepdims(self):
45754576
)
45764577
assert_identical(expected, actual)
45774578

4578-
def test_quantile(self):
4579-
4579+
@pytest.mark.parametrize("q", [0.25, [0.50], [0.25, 0.75]])
4580+
def test_quantile(self, q):
45804581
ds = create_test_data(seed=123)
45814582

4582-
for q in [0.25, [0.50], [0.25, 0.75]]:
4583-
for dim in [None, "dim1", ["dim1"]]:
4584-
ds_quantile = ds.quantile(q, dim=dim)
4585-
assert "quantile" in ds_quantile
4586-
for var, dar in ds.data_vars.items():
4587-
assert var in ds_quantile
4588-
assert_identical(ds_quantile[var], dar.quantile(q, dim=dim))
4589-
dim = ["dim1", "dim2"]
4583+
for dim in [None, "dim1", ["dim1"]]:
45904584
ds_quantile = ds.quantile(q, dim=dim)
4591-
assert "dim3" in ds_quantile.dims
4592-
assert all(d not in ds_quantile.dims for d in dim)
4585+
if is_scalar(q):
4586+
assert "quantile" not in ds_quantile.dims
4587+
else:
4588+
assert "quantile" in ds_quantile.dims
4589+
4590+
for var, dar in ds.data_vars.items():
4591+
assert var in ds_quantile
4592+
assert_identical(ds_quantile[var], dar.quantile(q, dim=dim))
4593+
dim = ["dim1", "dim2"]
4594+
ds_quantile = ds.quantile(q, dim=dim)
4595+
assert "dim3" in ds_quantile.dims
4596+
assert all(d not in ds_quantile.dims for d in dim)
45934597

45944598
@requires_bottleneck
45954599
def test_rank(self):

xarray/tests/test_variable.py

Lines changed: 21 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
PandasIndexAdapter,
2323
VectorizedIndexer,
2424
)
25+
from xarray.core.pycompat import dask_array_type
2526
from xarray.core.utils import NDArrayMixin
2627
from xarray.core.variable import as_compatible_data, as_variable
2728
from xarray.tests import requires_bottleneck
@@ -1492,23 +1493,31 @@ def test_reduce(self):
14921493
with pytest.warns(DeprecationWarning, match="allow_lazy is deprecated"):
14931494
v.mean(dim="x", allow_lazy=False)
14941495

1495-
def test_quantile(self):
1496+
@pytest.mark.parametrize("q", [0.25, [0.50], [0.25, 0.75]])
1497+
@pytest.mark.parametrize(
1498+
"axis, dim", zip([None, 0, [0], [0, 1]], [None, "x", ["x"], ["x", "y"]])
1499+
)
1500+
def test_quantile(self, q, axis, dim):
14961501
v = Variable(["x", "y"], self.d)
1497-
for q in [0.25, [0.50], [0.25, 0.75]]:
1498-
for axis, dim in zip(
1499-
[None, 0, [0], [0, 1]], [None, "x", ["x"], ["x", "y"]]
1500-
):
1501-
actual = v.quantile(q, dim=dim)
1502+
actual = v.quantile(q, dim=dim)
1503+
expected = np.nanpercentile(self.d, np.array(q) * 100, axis=axis)
1504+
np.testing.assert_allclose(actual.values, expected)
15021505

1503-
expected = np.nanpercentile(self.d, np.array(q) * 100, axis=axis)
1504-
np.testing.assert_allclose(actual.values, expected)
1506+
@requires_dask
1507+
@pytest.mark.parametrize("q", [0.25, [0.50], [0.25, 0.75]])
1508+
@pytest.mark.parametrize("axis, dim", [[1, "y"], [[1], ["y"]]])
1509+
def test_quantile_dask(self, q, axis, dim):
1510+
v = Variable(["x", "y"], self.d).chunk({"x": 2})
1511+
actual = v.quantile(q, dim=dim)
1512+
assert isinstance(actual.data, dask_array_type)
1513+
expected = np.nanpercentile(self.d, np.array(q) * 100, axis=axis)
1514+
np.testing.assert_allclose(actual.values, expected)
15051515

15061516
@requires_dask
1507-
def test_quantile_dask_raises(self):
1508-
# regression for GH1524
1509-
v = Variable(["x", "y"], self.d).chunk(2)
1517+
def test_quantile_chunked_dim_error(self):
1518+
v = Variable(["x", "y"], self.d).chunk({"x": 2})
15101519

1511-
with raises_regex(TypeError, "arrays stored as dask"):
1520+
with raises_regex(ValueError, "dimension 'x'"):
15121521
v.quantile(0.5, dim="x")
15131522

15141523
@requires_dask

0 commit comments

Comments
 (0)
0