linear interp with NaNs in nd indexer (#4233) · pydata/xarray@ffce4ec · GitHub
[go: up one dir, main page]

Skip to content

Commit ffce4ec

Browse files
authored
linear interp with NaNs in nd indexer (#4233)
* Added test for nd interpolation with nan * Now ignoring NaNs in missing._localize When interpolating with an nd indexer that contains NaNs, the code previously threw a KeyError from the missing._localize function. This commit fixes this by swapping `np.min` and `np.max` with `np.nanmin` and `np.nanmax`, ignoring any NaN values. * Added `@requires_scipy` to test. Also updated what's new. * Added numpy>=1.18 checks with `LooseVersion` * Added checks for np.datetime64 type This means the PR now also works for numpy < 1.18, as long as index is not with datetime * Removed `raise ValueError` from previous commit It seems that np.min/max works in place of nanmin/nanmax for datetime types for numpy < 1.18, see https://github.com/pydata/xarray/pull/3924/files * Added datetime `NaT` test. Also added a test for `Dataset` to `test_interpolate_nd_with_nan`, and "Missing values are skipped." to the docstrings of the `interp` and `interp_like` methods of `DataArray` and `Dataset`.
1 parent 9c85dd5 commit ffce4ec

File tree

5 files changed

+50
-9
lines changed

5 files changed

+50
-9
lines changed

doc/whats-new.rst

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,9 @@ Bug fixes
7777
arrays (:issue:`4341`). By `Spencer Clark <https://github.com/spencerkclark>`_.
7878
- Fix :py:func:`xarray.apply_ufunc` with ``vectorize=True`` and ``exclude_dims`` (:issue:`3890`).
7979
By `Mathias Hauser <https://github.com/mathause>`_.
80+
- Fix `KeyError` when doing linear interpolation to an nd `DataArray`
81+
that contains NaNs (:pull:`4233`).
82+
By `Jens Svensmark <https://github.com/jenssss>`_.
8083

8184
Documentation
8285
~~~~~~~~~~~~~

xarray/core/dataarray.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1420,9 +1420,9 @@ def interp(
14201420
----------
14211421
coords : dict, optional
14221422
Mapping from dimension names to the new coordinates.
1423-
new coordinate can be an scalar, array-like or DataArray.
1424-
If DataArrays are passed as new coordates, their dimensions are
1425-
used for the broadcasting.
1423+
New coordinate can be a scalar, array-like or DataArray.
1424+
If DataArrays are passed as new coordinates, their dimensions are
1425+
used for the broadcasting. Missing values are skipped.
14261426
method : str, default: "linear"
14271427
The method used to interpolate. Choose from
14281428
@@ -1492,7 +1492,7 @@ def interp_like(
14921492
other : Dataset or DataArray
14931493
Object with an 'indexes' attribute giving a mapping from dimension
14941494
names to a 1d array-like, which provides coordinates upon
1495-
which to index the variables in this dataset.
1495+
which to index the variables in this dataset. Missing values are skipped.
14961496
method : str, default: "linear"
14971497
The method used to interpolate. Choose from
14981498

xarray/core/dataset.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2604,8 +2604,8 @@ def interp(
26042604
coords : dict, optional
26052605
Mapping from dimension names to the new coordinates.
26062606
New coordinate can be a scalar, array-like or DataArray.
2607-
If DataArrays are passed as new coordates, their dimensions are
2608-
used for the broadcasting.
2607+
If DataArrays are passed as new coordinates, their dimensions are
2608+
used for the broadcasting. Missing values are skipped.
26092609
method : str, optional
26102610
{"linear", "nearest"} for multidimensional array,
26112611
{"linear", "nearest", "zero", "slinear", "quadratic", "cubic"}
@@ -2733,7 +2733,7 @@ def interp_like(
27332733
other : Dataset or DataArray
27342734
Object with an 'indexes' attribute giving a mapping from dimension
27352735
names to a 1d array-like, which provides coordinates upon
2736-
which to index the variables in this dataset.
2736+
which to index the variables in this dataset. Missing values are skipped.
27372737
method : str, optional
27382738
{"linear", "nearest"} for multidimensional array,
27392739
{"linear", "nearest", "zero", "slinear", "quadratic", "cubic"}

xarray/core/missing.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import datetime as dt
22
import warnings
3+
from distutils.version import LooseVersion
34
from functools import partial
45
from numbers import Number
56
from typing import Any, Callable, Dict, Hashable, Sequence, Union
@@ -550,9 +551,19 @@ def _localize(var, indexes_coords):
550551
"""
551552
indexes = {}
552553
for dim, [x, new_x] in indexes_coords.items():
554+
if np.issubdtype(new_x.dtype, np.datetime64) and LooseVersion(
555+
np.__version__
556+
) < LooseVersion("1.18"):
557+
# np.nanmin/max changed behaviour for datetime types in numpy 1.18,
558+
# see https://github.com/pydata/xarray/pull/3924/files
559+
minval = np.min(new_x.values)
560+
maxval = np.max(new_x.values)
561+
else:
562+
minval = np.nanmin(new_x.values)
563+
maxval = np.nanmax(new_x.values)
553564
index = x.to_index()
554-
imin = index.get_loc(np.min(new_x.values), method="nearest")
555-
imax = index.get_loc(np.max(new_x.values), method="nearest")
565+
imin = index.get_loc(minval, method="nearest")
566+
imax = index.get_loc(maxval, method="nearest")
556567

557568
indexes[dim] = slice(max(imin - 2, 0), imax + 2)
558569
indexes_coords[dim] = (x[indexes[dim]], new_x)

xarray/tests/test_interp.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -277,6 +277,32 @@ def test_interpolate_nd_nd():
277277
da.interp(a=ia)
278278

279279

280+
@requires_scipy
281+
def test_interpolate_nd_with_nan():
282+
"""Interpolate an array with an nd indexer and `NaN` values."""
283+
284+
# Create indexer into `a` with dimensions (y, x)
285+
x = [0, 1, 2]
286+
y = [10, 20]
287+
c = {"x": x, "y": y}
288+
a = np.arange(6, dtype=float).reshape(2, 3)
289+
a[0, 1] = np.nan
290+
ia = xr.DataArray(a, dims=("y", "x"), coords=c)
291+
292+
da = xr.DataArray([1, 2, 2], dims=("a"), coords={"a": [0, 2, 4]})
293+
out = da.interp(a=ia)
294+
expected = xr.DataArray(
295+
[[1.0, np.nan, 2.0], [2.0, 2.0, np.nan]], dims=("y", "x"), coords=c
296+
)
297+
xr.testing.assert_allclose(out.drop_vars("a"), expected)
298+
299+
db = 2 * da
300+
ds = xr.Dataset({"da": da, "db": db})
301+
out = ds.interp(a=ia)
302+
expected_ds = xr.Dataset({"da": expected, "db": 2 * expected})
303+
xr.testing.assert_allclose(out.drop_vars("a"), expected_ds)
304+
305+
280306
@pytest.mark.parametrize("method", ["linear"])
281307
@pytest.mark.parametrize("case", [0, 1])
282308
def test_interpolate_scalar(method, case):
@@ -553,6 +579,7 @@ def test_interp_like():
553579
[0.5, 1.5],
554580
),
555581
(["2000-01-01T12:00", "2000-01-02T12:00"], [0.5, 1.5]),
582+
(["2000-01-01T12:00", "2000-01-02T12:00", "NaT"], [0.5, 1.5, np.nan]),
556583
(["2000-01-01T12:00"], 0.5),
557584
pytest.param("2000-01-01T12:00", 0.5, marks=pytest.mark.xfail),
558585
],

0 commit comments

Comments
 (0)
0