8000 Merge pull request #20762 from honno/xp-pow-fixes · rjeb/numpy@f98c60a · GitHub
[go: up one dir, main page]

Skip to content

Commit f98c60a

Browse files
authored
Merge pull request numpy#20762 from honno/xp-pow-fixes
BUG: Allow integer inputs for pow-related functions in `array_api`
2 parents f637d75 + e3406ed commit f98c60a

File tree

4 files changed

+10
-12
lines changed

4 files changed

+10
-12
lines changed

numpy/array_api/_array_object.py

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -656,15 +656,13 @@ def __pos__(self: Array, /) -> Array:
656656
res = self._array.__pos__()
657657
return self.__class__._new(res)
658658

659-
# PEP 484 requires int to be a subtype of float, but __pow__ should not
660-
# accept int.
661-
def __pow__(self: Array, other: Union[float, Array], /) -> Array:
659+
def __pow__(self: Array, other: Union[int, float, Array], /) -> Array:
662660
"""
663661
Performs the operation __pow__.
664662
"""
665663
from ._elementwise_functions import pow
666664

667-
other = self._check_allowed_dtypes(other, "floating-point", "__pow__")
665+
other = self._check_allowed_dtypes(other, "numeric", "__pow__")
668666
if other is NotImplemented:
669667
return other
670668
# Note: NumPy's __pow__ does not follow type promotion rules for 0-d
@@ -914,23 +912,23 @@ def __ror__(self: Array, other: Union[int, bool, Array], /) -> Array:
914912
res = self._array.__ror__(other._array)
915913
return self.__class__._new(res)
916914

917-
def __ipow__(self: Array, other: Union[float, Array], /) -> Array:
915+
def __ipow__(self: Array, other: Union[int, float, Array], /) -> Array:
918916
"""
919917
Performs the operation __ipow__.
920918
"""
921-
other = self._check_allowed_dtypes(other, "floating-point", "__ipow__")
919+
other = self._check_allowed_dtypes(other, "numeric", "__ipow__")
922920
if other is NotImplemented:
923921
return other
924922
self._array.__ipow__(other._array)
925923
return self
926924

927-
def __rpow__(self: Array, other: Union[float, Array], /) -> Array:
925+
def __rpow__(self: Array, other: Union[int, float, Array], /) -> Array:
928926
"""
929927
Performs the operation __rpow__.
930928
"""
931929
from ._elementwise_functions import pow
932930

933-
other = self._check_allowed_dtypes(other, "floating-point", "__rpow__")
931+
other = self._check_allowed_dtypes(other, "numeric", "__rpow__")
934932
if other is NotImplemented:
935933
return other
936934
# Note: NumPy's __pow__ does not follow the spec type promotion rules

numpy/array_api/_elementwise_functions.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -591,8 +591,8 @@ def pow(x1: Array, x2: Array, /) -> Array:
591591
592592
See its docstring for more information.
593593
"""
594-
if x1.dtype not in _floating_dtypes or x2.dtype not in _floating_dtypes:
595-
raise TypeError("Only floating-point dtypes are allowed in pow")
594+
if x1.dtype not in _numeric_dtypes or x2.dtype not in _numeric_dtypes:
595+
raise TypeError("Only numeric dtypes are allowed in pow")
596596
# Call result type here just to raise on disallowed type combinations
597597
_result_type(x1.dtype, x2.dtype)
598598
x1, x2 = Array._normalize_two_args(x1, x2)

numpy/array_api/tests/test_array_object.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,7 @@ def test_operators():
9898
"__mul__": "numeric",
9999
"__ne__": "all",
100100
"__or__": "integer_or_boolean",
101-
"__pow__": "floating",
101+
"__pow__": "numeric",
102102
"__rshift__": "integer",
103103
"__sub__": "numeric",
104104
"__truediv__": "floating",

numpy/array_api/tests/test_elementwise_functions.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@ def test_function_types():
6666
"negative": "numeric",
6767
"not_equal": "all",
6868
"positive": "numeric",
69-
"pow": "floating-point",
69+
"pow": "numeric",
7070
"remainder": "numeric",
7171
"round": "numeric",
7272
"sign": "numeric",

0 commit comments

Comments
 (0)
0