8000 ENH enable `__imatmul__` for ArrayAPI compatability. by Micky774 · Pull Request #21912 · numpy/numpy · GitHub
[go: up one dir, main page]

Skip to content
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10BC0
Next Next commit
Initial tinkering
  • Loading branch information
Micky774 committed Jul 2, 2022
commit e3de0e6987f90d36113867928ff219cf5c04b2df
24 changes: 16 additions & 8 deletions numpy/__init__.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@ import enum
from abc import abstractmethod
from types import TracebackType, MappingProxyType
from contextlib import ContextDecorator
from contextlib import contextmanager

if sys.version_info >= (3, 9):
from types import GenericAlias
Expand Down Expand Up @@ -181,7 +180,6 @@ from collections.abc import (
from typing import (
Literal as L,
Any,
Generator,
Generic,
IO,
NoReturn,
Expand Down Expand Up @@ -1915,7 +1913,6 @@ class ndarray(_ArrayOrScalarCommon, Generic[_ShapeType, _DType_co]):
def __neg__(self: NDArray[object_]) -> Any: ...

# Binary ops
# NOTE: `ndarray` does not implement `__imatmul__`
@overload
def __matmul__(self: NDArray[bool_], other: _ArrayLikeBool_co) -> NDArray[bool_]: ... # type: ignore[misc]
@overload
Expand All @@ -1933,6 +1930,22 @@ class ndarray(_ArrayOrScalarCommon, Generic[_ShapeType, _DType_co]):
@overload
def __matmul__(self: NDArray[Any], other: _ArrayLikeObject_co) -> Any: ...

def __imatmul__(self: NDArray[bool_], other: _ArrayLikeBool_co) -> NDArray[bool_]: ... # type: ignore[misc]
@overload
def __imatmul__(self: _ArrayUInt_co, other: _ArrayLikeUInt_co) -> NDArray[unsignedinteger[Any]]: ... # type: ignore[misc]
@overload
def __imatmul__(self: _ArrayInt_co, other: _ArrayLikeInt_co) -> NDArray[signedinteger[Any]]: ... # type: ignore[misc]
@overload
def __imatmul__(self: _ArrayFloat_co, other: _ArrayLikeFloat_co) -> NDArray[floating[Any]]: ... # type: ignore[misc]
@overload
def __imatmul__(self: _ArrayComplex_co, other: _ArrayLikeComplex_co) -> NDArray[complexfloating[Any, Any]]: ...
@overload
def __imatmul__(self: NDArray[number[Any]], other: _ArrayLikeNumber_co) -> NDArray[number[Any]]: ...
@overload
def __imatmul__(self: NDArray[object_], other: Any) -> Any: ...
@overload
def __imatmul__(self: NDArray[Any], other: _ArrayLikeObject_co) -> Any: ...

@overload
def __rmatmul__(self: NDArray[bool_], other: _ArrayLikeBool_co) -> NDArray[bool_]: ... # type: ignore[misc]
@overload
Expand Down Expand Up @@ -3352,11 +3365,6 @@ class errstate(Generic[_CallType], ContextDecorator):
/,
) -> None: ...

@contextmanager
def _no_nep50_warning() -> Generator[None, None, None]: ...
def _get_promotion_state() -> str: ...
def _set_promotion_state(state: str, /) -> None: ...

class ndenumerate(Generic[_ScalarType]):
iter: flatiter[NDArray[_ScalarType]]
@overload
Expand Down
14 changes: 7 additions & 7 deletions numpy/core/src/multiarray/number.c
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,7 @@ _PyArray_SetNumericOps(PyObject *dict)
SET(conjugate);
SET(matmul);
SET(clip);
SET(imatmul);
return 0;
}

Expand Down Expand Up @@ -183,6 +184,7 @@ _PyArray_GetNumericOps(void)
GET(conjugate);
GET(matmul);
GET(clip);
GET(imatmul);
return dict;

fail:
Expand Down Expand Up @@ -348,13 +350,11 @@ array_matrix_multiply(PyObject *m1, PyObject *m2)
}

static PyObject *
array_inplace_matrix_multiply(
PyArrayObject *NPY_UNUSED(m1), PyObject *NPY_UNUSED(m2))
array_inplace_matrix_multiply(PyArrayObject *m1, PyObject *m2)
{
PyErr_SetString(PyExc_TypeError,
"In-place matrix multiplication is not (yet) supported. "
"Use 'a = a @ b' instead of 'a @= b'.");
return NULL;
INPLACE_GIVE_UP_IF_NEEDED(
m1, m2, nb_inplace_matrix_multiply, array_inplace_multiply);
return PyArray_GenericInplaceBinaryFunction(m1, m2, n_ops.imatmul);
}

/*
B347 Expand Down Expand Up @@ -690,7 +690,7 @@ array_inplace_multiply(PyArrayObject *m1, PyObject *m2)
{
INPLACE_GIVE_UP_IF_NEEDED(
m1, m2, nb_inplace_multiply, array_inplace_multiply);
return PyArray_GenericInplaceBinaryFunction(m1, m2, n_ops.multiply);
return PyArray_GenericInplaceBinaryFunction(m1, m2, n_ops.matmul);
}

static PyObject *
Expand Down
1 change: 1 addition & 0 deletions numpy/core/src/multiarray/number.h
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ typedef struct {
PyObject *conjugate;
PyObject *matmul;
PyObject *clip;
PyObject *imatmul;
} NumericOps;

extern NPY_NO_EXPORT NumericOps n_ops;
Expand Down
0