8000 Add lshift and rshift operators (#7741) · pydata/xarray@a220022 · GitHub
[go: up one dir, main page]

Skip to content

Commit a220022

Browse files
abrammerpre-commit-ci[bot]Illviljan
authored
Add lshift and rshift operators (#7741)
* Initial commit * Add auto generated .pyi typed_ops file * Add bitshift op test for dask * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Add bitshift tests to dataarray and variable * Apply typing suggestions from code review Co-authored-by: Illviljan <14371165+Illviljan@users.noreply.github.com> * Fix type checking on test_dask addition * Remove new type checking on test_variable edits Type checking throws errors on existing test lines * Add typing to test_1d_math and ignore 1 existing line * Add simple bitshift test on groupby ops * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Edit groupby bitshift test to use groups with len>1 * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Add rshift and lshift to docs Example use in computation.rst and entry in whats-new * Create new array in docs so examples later don't break * Indent second line on whats-new entry * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Illviljan <14371165+Illviljan@users.noreply.github.com>
1 parent 0f4e99d commit a220022

File tree

10 files changed

+189
-8
lines changed

10 files changed

+189
-8
lines changed

doc/user-guide/computation.rst

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,10 @@ Data arrays also implement many :py:class:`numpy.ndarray` methods:
6363
arr.round(2)
6464
arr.T
6565
66+
intarr = xr.DataArray([0, 1, 2, 3, 4, 5])
67+
intarr << 2 # only supported for int types
68+
intarr >> 1
69+
6670
.. _missing_values:
6771

6872
Missing values

doc/whats-new.rst

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,9 @@ v2023.05.0 (unreleased)
2222

2323
New Features
2424
~~~~~~~~~~~~
25+
- Add support for lshift and rshift binary operators (`<<`, `>>`) on
26+
:py:class:`xr.DataArray` of type :py:class:`int` (:issue:`7727` , :pull:`7741`).
27+
By `Alan Brammer <https://github.com/abrammer>`_.
2528

2629

2730
Breaking changes

xarray/core/_typed_ops.py

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,12 @@ def __xor__(self, other):
4242
def __or__(self, other):
4343
return self._binary_op(other, operator.or_)
4444

45+
def __lshift__(self, other):
46+
return self._binary_op(other, operator.lshift)
47+
48+
def __rshift__(self, other):
49+
return self._binary_op(other, operator.rshift)
50+
4551
def __lt__(self, other):
4652
return self._binary_op(other, operator.lt)
4753

@@ -123,6 +129,12 @@ def __ixor__(self, other):
123129
def __ior__(self, other):
124130
return self._inplace_binary_op(other, operator.ior)
125131

132+
def __ilshift__(self, other):
133+
return self._inplace_binary_op(other, operator.ilshift)
134+
135+
def __irshift__(self, other):
136+
return self._inplace_binary_op(other, operator.irshift)
137+
126138
def _unary_op(self, f, *args, **kwargs):
127139
raise NotImplementedError
128140

@@ -160,6 +172,8 @@ def conjugate(self, *args, **kwargs):
160172
__and__.__doc__ = operator.and_.__doc__
161173
__xor__.__doc__ = operator.xor.__doc__
162174
__or__.__doc__ = operator.or_.__doc__
175+
__lshift__.__doc__ = operator.lshift.__doc__
176+
__rshift__.__doc__ = operator.rshift.__doc__
163177
__lt__.__doc__ = operator.lt.__doc__
164178
__le__.__doc__ = operator.le.__doc__
165179
__gt__.__doc__ = operator.gt.__doc__
@@ -186,6 +200,8 @@ def conjugate(self, *args, **kwargs):
186200
__iand__.__doc__ = operator.iand.__doc__
187201
__ixor__.__doc__ = operator.ixor.__doc__
188202
__ior__.__doc__ = operator.ior.__doc__
203+
__ilshift__.__doc__ = operator.ilshift.__doc__
204+
__irshift__.__doc__ = operator.irshift.__doc__
189205
__neg__.__doc__ = operator.neg.__doc__
190206
__pos__.__doc__ = operator.pos.__doc__
191207
__abs__.__doc__ = operator.abs.__doc__
@@ -232,6 +248,12 @@ def __xor__(self, other):
232248
def __or__(self, other):
233249
return self._binary_op(other, operator.or_)
234250

251+
def __lshift__(self, other):
252+
return self._binary_op(other, operator.lshift)
253+
254+
def __rshift__(self, other):
255+
return self._binary_op(other, operator.rshift)
256+
235257
def __lt__(self, other):
236258
return self._binary_op(other, operator.lt)
237259

@@ -313,6 +335,12 @@ def __ixor__(self, other):
313335
def __ior__(self, other):
314336
return self._inplace_binary_op(other, operator.ior)
315337

338+
def __ilshift__(self, other):
339+
return self._inplace_binary_op(other, operator.ilshift)
340+
341+
def __irshift__(self, other):
342+
return self._inplace_binary_op(other, operator.irshift)
343+
316344
def _unary_op(self, f, *args, **kwargs):
317345
raise NotImplementedError
318346

@@ -350,6 +378,8 @@ def conjugate(self, *args, **kwargs):
350378
__and__.__doc__ = operator.and_.__doc__
351379
__xor__.__doc__ = operator.xor.__doc__
352380
__or__.__doc__ = operator.or_.__doc__
381+
__lshift__.__doc__ = operator.lshift.__doc__
382+
__rshift__.__doc__ = operator.rshift.__doc__
353383
__lt__.__doc__ = operator.lt.__doc__
354384
__le__.__doc__ = operator.le.__doc__
355385
__gt__.__doc__ = operator.gt.__doc__
@@ -376,6 +406,8 @@ def conjugate(self, *args, **kwargs):
376406
__iand__.__doc__ = operator.iand.__doc__
377407
__ixor__.__doc__ = operator.ixor.__doc__
378408
__ior__.__doc__ = operator.ior.__doc__
409+
__ilshift__.__doc__ = operator.ilshift.__doc__
410+
__irshift__.__doc__ = operator.irshift.__doc__
379411
__neg__.__doc__ = operator.neg.__doc__
380412
__pos__.__doc__ = operator.pos.__doc__
381413
__abs__.__doc__ = operator.abs.__doc__
@@ -422,6 +454,12 @@ def __xor__(self, other):
422454
def __or__(self, other):
423455
return self._binary_op(other, operator.or_)
424456

457+
def __lshift__(self, other):
458+
return self._binary_op(other, operator.lshift)
459+
460+
def __rshift__(self, other):
461+
return self._binary_op(other, operator.rshift)
462+
425463
def __lt__(self, other):
426464
return self._binary_op(other, operator.lt)
427465

@@ -503,6 +541,12 @@ def __ixor__(self, other):
503541
def __ior__(self, other):
504542
return self._inplace_binary_op(other, operator.ior)
505543

544+
def __ilshift__(self, other):
545+
return self._inplace_binary_op(other, operator.ilshift)
546+
547+
def __irshift__(self, other):
548+
return self._inplace_binary_op(other, operator.irshift)
549+
506550
def _unary_op(self, f, *args, **kwargs):
507551
raise NotImplementedError
508552

@@ -540,6 +584,8 @@ def conjugate(self, *args, **kwargs):
540584
__and__.__doc__ = operator.and_.__doc__
541585
__xor__.__doc__ = operator.xor.__doc__
542586
__or__.__doc__ = operator.or_.__doc__
587+
__lshift__.__doc__ = operator.lshift.__doc__
588+
__rshift__.__doc__ = operator.rshift.__doc__
543589
__lt__.__doc__ = operator.lt.__doc__
544590
__le__.__doc__ = operator.le.__doc__
545591
__gt__.__doc__ = operator.gt.__doc__
@@ -566,6 +612,8 @@ def conjugate(self, *args, **kwargs):
566612
__iand__.__doc__ = operator.iand.__doc__
567613
__ixor__.__doc__ = operator.ixor.__doc__
568614
__ior__.__doc__ = operator.ior.__doc__
615+
__ilshift__.__doc__ = operator.ilshift.__doc__
616+
__irshift__.__doc__ = operator.irshift.__doc__
569617
__neg__.__doc__ = operator.neg.__doc__
570618
__pos__.__doc__ = operator.pos.__doc__
571619
__abs__.__doc__ = operator.abs.__doc__
@@ -612,6 +660,12 @@ def __xor__(self, other):
612660
def __or__(self, other):
613661
return self._binary_op(other, operator.or_)
614662

663+
def __lshift__(self, other):
664+
return self._binary_op(other, operator.lshift)
665+
666+
def __rshift__(self, other):
667+
return self._binary_op(other, operator.rshift)
668+
615669
def __lt__(self, other):
616670
return self._binary_op(other, operator.lt)
617671

@@ -670,6 +724,8 @@ def __ror__(self, other):
670724
__and__.__doc__ = operator.and_.__doc__
671725
__xor__.__doc__ = operator.xor.__doc__
672726
__or__.__doc__ = operator.or_.__doc__
727+
__lshift__.__doc__ = operator.lshift.__doc__
728+
__rshift__.__doc__ = operator.rshift.__doc__
673729
__lt__.__doc__ = operator.lt.__doc__
674730
__le__.__doc__ = operator.le.__doc__
675731
__gt__.__doc__ = operator.gt.__doc__
@@ -724,6 +780,12 @@ def __xor__(self, other):
724780
def __or__(self, other):
725781
return self._binary_op(other, operator.or_)
726782

783+
def __lshift__(self, other):
784+
return self._binary_op(other, operator.lshift)
785+
786+
def __rshift__(self, other):
787+
return self._binary_op(other, operator.rshift)
788+
727789
def __lt__(self, other):
728790
return self._binary_op(other, operator.lt)
729791

@@ -782,6 +844,8 @@ def __ror__(self, other):
782844
__and__.__doc__ = operator.and_.__doc__
783845
__xor__.__doc__ = operator.xor.__doc__
784846
__or__.__doc__ = operator.or_.__doc__
847+
__lshift__.__doc__ = operator.lshift.__doc__
848+
__rshift__.__doc__ = operator.rshift.__doc__
785849
__lt__.__doc__ = operator.lt.__doc__
786850
__le__.__doc__ = operator.le.__doc__
787851
__gt__.__doc__ = operator.gt.__doc__

xarray/core/_typed_ops.pyi

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,8 @@ class DatasetOpsMixin:
4444
def __and__(self: T_Dataset, other: DsCompatible) -> T_Dataset: ...
4545
def __xor__(self: T_Dataset, other: DsCompatible) -> T_Dataset: ...
4646
def __or__(self: T_Dataset, other: DsCompatible) -> T_Dataset: ...
47+
def __lshift__(self: T_Dataset, other: DsCompatible) -> T_Dataset: ...
48+
def __rshift__(self: T_Dataset, other: DsCompatible) -> T_Dataset: ...
4749
def __lt__(self: T_Dataset, other: DsCompatible) -> T_Dataset: ...
4850
def __le__(self: T_Dataset, other: DsCompatible) -> T_Dataset: ...
4951
def __gt__(self: T_Dataset, other: DsCompatible) -> T_Dataset: ...
@@ -135,6 +137,18 @@ class DataArrayOpsMixin:
135137
@overload
136138
def __or__(self: T_DataArray, other: DaCompatible) -> T_DataArray: ...
137139
@overload
140+
def __lshift__(self, other: T_Dataset) -> T_Dataset: ...
141+
@overload
142+
def __lshift__(self, other: "DatasetGroupBy") -> "Dataset": ...
143+
@overload
144+
def __lshift__(self: T_DataArray, other: DaCompatible) -> T_DataArray: ...
145+
@overload
146+
def __rshift__(self, other: T_Dataset) -> T_Dataset: ...
147+
@overload
148+
def __rshift__(self, other: "DatasetGroupBy") -> "Dataset": ...
149+
@overload
150+
def __rshift__(self: T_DataArray, other: DaCompatible) -> T_DataArray: ...
151+
@overload
138152
def __lt__(self, other: T_Dataset) -> T_Dataset: ...
139153
@overload
140154
def __lt__(self, other: "DatasetGroupBy") -> "Dataset": ...
@@ -305,6 +319,18 @@ class VariableOpsMixin:
305319
@overload
306320
def __or__(self: T_Variable, other: VarCompatible) -> T_Variable: ...
307321
@overload
322+
def __lshift__(self, other: T_Dataset) -> T_Dataset: ...
323+
@overload
324+
def __lshift__(self, other: T_DataArray) -> T_DataArray: ...
325+
@overload
326+
def __lshift__(self: T_Variable, other: VarCompatible) -> T_Variable: ...
327+
@overload
328+
def __rshift__(self, other: T_Dataset) -> T_Dataset: ...
329+
@overload
330+
def __rshift__(self, other: T_DataArray) -> T_DataArray: ...
331+
@overload
332+
def __rshift__(self: T_Variable, other: VarCompatible) -> T_Variable: ...
333+
@overload
308334
def __lt__(self, other: T_Dataset) -> T_Dataset: ...
309335
@overload
310336
def __lt__(self, other: T_DataArray) -> T_DataArray: ...
@@ -475,6 +501,18 @@ class DatasetGroupByOpsMixin:
475501
@overload
476502
def __or__(self, other: GroupByIncompatible) -> NoReturn: ...
477503
@overload
504+
def __lshift__(self, other: T_Dataset) -> T_Dataset: ...
505+
@overload
506+
def __lshift__(self, other: "DataArray") -> "Dataset": ...
507+
@overload
508+
def __lshift__(self, other: GroupByIncompatible) -> NoReturn: ...
509+
@overload
510+
def __rshift__(self, other: T_Dataset) -> T_Dataset: ...
511+
@overload
512+
def __rshift__(self, other: "DataArray") -> "Dataset": ...
513+
@overload
514+
def __rshift__(self, other: GroupByIncompatible) -> NoReturn: ...
515+
@overload
478516
def __lt__(self, other: T_Dataset) -> T_Dataset: ...
479517
@overload
480518
def __lt__(self, other: "DataArray") -> "Dataset": ...
@@ -635,6 +673,18 @@ class DataArrayGroupByOpsMixin:
635673
@overload
636674
def __or__(self, other: GroupByIncompatible) -> NoReturn: ...
637675
@overload
676+
def __lshift__(self, other: T_Dataset) -> T_Dataset: ...
677+
@overload
678+
def __lshift__(self, other: T_DataArray) -> T_DataArray: ...
679+
@overload
680+
def __lshift__(self, other: GroupByIncompatible) -> NoReturn: ...
681+
@overload
682+
def __rshift__(self, other: T_Dataset) -> T_Dataset: ...
683+
@overload
684+
def __rshift__(self, other: T_DataArray) -> T_DataArray: ...
685+
@overload
686+
def __rshift__(self, other: GroupByIncompatible) -> NoReturn: ...
687+
@overload
638688
def __lt__(self, other: T_Dataset) -> T_Dataset: ...
639689
@overload
640690
def __lt__(self, other: T_DataArray) -> T_DataArray: ...

xarray/core/ops.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,8 @@
3333
"and",
3434
"xor",
3535
"or",
36+
"lshift",
37+
"rshift",
3638
]
3739

3840
# methods which pass on the numpy return value unchanged

xarray/tests/test_dask.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -178,6 +178,19 @@ def test_binary_op(self):
178178
self.assertLazyAndIdentical(u + u, v + v)
179179
self.assertLazyAndIdentical(u[0] + u, v[0] + v)
180180

181+
def test_binary_op_bitshift(self) -> None:
182+
# bit shifts only work on ints so we need to generate
183+
# new eager and lazy vars
184+
rng = np.random.default_rng(0)
185+
values = rng.integers(low=-10000, high=10000, size=(4, 6))
186+
data = da.from_array(values, chunks=(2, 2))
187+
u = Variable(("x", "y"), values)
188+
v = Variable(("x", "y"), data)
189+
self.assertLazyAndIdentical(u << 2, v << 2)
190+
self.assertLazyAndIdentical(u << 5, v << 5)
191+
self.assertLazyAndIdentical(u >> 2, v >> 2)
192+
self.assertLazyAndIdentical(u >> 5, v >> 5)
193+
181194
def test_repr(self):
182195
expected = dedent(
183196
"""\

xarray/tests/test_dataarray.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3926,6 +3926,11 @@ def test_binary_op_propagate_indexes(self) -> None:
39263926
actual = (self.dv > 10).xindexes["x"]
39273927
assert expected is actual
39283928

3929+
# use mda for bitshift test as it's type int
3930+
actual = (self.mda << 2).xindexes["x"]
3931+
expected = self.mda.xindexes["x"]
3932+
assert expected is actual
3933+
39293934
def test_binary_op_join_setting(self) -> None:
39303935
dim = "x"
39313936
align_type: Final = "outer"

xarray/tests/test_groupby.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -808,6 +808,34 @@ def test_groupby_math_more() -> None:
808808
ds + ds.groupby("time.month")
809809

810810

811+
def test_groupby_math_bitshift() -> None:
812+
# create new dataset of int's only
813+
ds = Dataset(
814+
{
815+
"x": ("index", np.ones(4, dtype=int)),
816+
"y": ("index", np.ones(4, dtype=int) * -1),
817+
"level": ("index", [1, 1, 2, 2]),
818+
"index": [0, 1, 2, 3],
819+
}
820+
)
821+
shift = DataArray([1, 2, 1], [("level", [1, 2, 8])])
822+
823+
left_expected = Dataset(
824+
{
825+
"x": ("index", [2, 2, 4, 4]),
826+
"y": ("index", [-2, -2, -4, -4]),
827+
"level": ("index", [2, 2, 8, 8]),
828+
"index": [0, 1, 2, 3],
829+
}
830+
)
831+
832+
left_actual = (ds.groupby("level") << shift).reset_coords(names="level")
833+
assert_equal(left_expected, left_actual)
834+
835+
right_actual = (left_expected.groupby("level") >> shift).reset_coords(names="level")
836+
assert_equal(ds, right_actual)
837+
838+
811839
@pytest.mark.parametrize("use_flox", [True, False])
812840
def test_groupby_bins_cut_kwargs(use_flox: bool) -> None:
813841
da = xr.DataArray(np.arange(12).reshape(6, 2), dims=("x", "y"))

0 commit comments

Comments
 (0)
0