8000 fix: `finfo_object`, `iinfo_object`, `_array` to typing.Protocol by 34j · Pull Request #857 · data-apis/array-api · GitHub
[go: up one dir, main page]

Skip to content

fix: finfo_object, iinfo_object, _array to typing.Protocol #857

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 3 commits into from
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
docs: fix docs again
  • Loading branch information
34j committed Nov 24, 2024
commit f3c9eb4a45f413cd965e6c6a1711192dca2f92ba
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -35,3 +35,5 @@ tmp/
*.egg
dist/
.DS_STORE
venv
.venv
10 changes: 10 additions & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -48,3 +48,13 @@ repos:
rev: 23.7.0
hooks:
- id: black

- repo: https://github.com/pre-commit/mirrors-mypy
rev: "v1.0.0"
hooks:
- id: mypy
additional 10000 _dependencies: [typing_extensions>=4.4.0]
args:
- --ignore-missing-imports
- --config=pyproject.toml
files: ".*(_draft.*)$"
9 changes: 9 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -32,3 +32,12 @@ build-backend = "setuptools.build_meta"

[tool.black]
line-length = 88

[tool.mypy]
python_version = "3.9"
mypy_path = "$MYPY_CONFIG_FILE_DIR/src/array_api_stubs/_draft/"
files = [
"src/array_api_stubs/_draft/**/*.py"
]
follow_imports = "silent"
disable_error_code = "empty-body,type-var"
4 changes: 4 additions & 0 deletions src/_array_api_conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,8 +66,11 @@
]
nitpick_ignore_regex = [
("py:class", ".*array"),
("py:class", ".*Array"),
("py:class", ".*device"),
("py:class", ".*Device"),
("py:class", ".*dtype"),
("py:class", ".*DType"),
("py:class", ".*NestedSequence"),
("py:class", ".*SupportsBufferProtocol"),
("py:class", ".*PyCapsule"),
Expand All @@ -84,6 +87,7 @@
"array": "array",
"Device": "device",
"Dtype": "dtype",
"DType": "dtype",
}

# Make autosummary show the signatures of functions in the tables using actual
Expand Down
22 changes: 12 additions & 10 deletions src/array_api_stubs/_draft/_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@
"Info",
]

from dataclasses import dataclass
from typing import (
Any,
List,
Expand All @@ -45,10 +44,13 @@
Protocol,
)
from enum import Enum
from .data_types import DType

array = TypeVar("array", bound="array_")
array = TypeVar("array", bound="Array")
device = TypeVar("device")
dtype = TypeVar("dtype")
dtype = TypeVar("dtype", bound=DType)
device_ = TypeVar("device_") # only used in this file
dtype_ = TypeVar("dtype_", bound=DType) # only used in this file
SupportsDLPack = TypeVar("SupportsDLPack")
SupportsBufferProtocol = TypeVar("SupportsBufferProtocol")
PyCapsule = TypeVar("PyCapsule")
Expand Down Expand Up @@ -88,7 +90,7 @@ def __len__(self, /) -> int:
...


class Info(Protocol):
class Info(Protocol[device]):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you explain what the [device] here means after Protocol? The Info namespace object itself does not depend on the device (__array_namespace_info__ does not take a device parameter).

"""Namespace returned by `__array_namespace_info__`."""

def capabilities(self) -> Capabilities:
Expand Down Expand Up @@ -147,12 +149,12 @@ def dtypes(
)


class _array(Protocol[array, dtype, device]):
class Array(Protocol[array, dtype_, device_, PyCapsule]): # type: ignore
def __init__(self: array) -> None:
"""Initialize the attributes for the array object class."""

@property
def dtype(self: array) -> dtype:
def dtype(self: array) -> dtype_:
"""
Data type of the array elements.

Expand All @@ -163,7 +165,7 @@ def dtype(self: array) -> dtype:
"""

@property
def device(self: array) -> device:
def device(self: array) -> device_:
"""
Hardware device the array data resides on.

Expand Down Expand Up @@ -625,7 +627,7 @@ def __dlpack_device__(self: array, /) -> Tuple[Enum, int]:
ONE_API = 14
"""

def __eq__(self: array, other: Union[int, float, bool, array], /) -> array:
def __eq__(self: array, other: Union[int, float, bool, array], /) -> array: # type: ignore
r"""
Computes the truth value of ``self_i == other_i`` for each element of an array instance with the respective element of the array ``other``.

Expand Down Expand Up @@ -1072,7 +1074,7 @@ def __mul__(self: array, other: Union[int, float, array], /) -> array:
Added complex data type support.
"""

def __ne__(self: array, other: Union[int, float, bool, array], /) -> array:
def __ne__(self: array, other: Union[int, float, bool, array], /) -> array: # type: ignore
"""
Computes the truth value of ``self_i != other_i`` for each element of an array instance with the respective element of the array ``other``.

Expand Down Expand Up @@ -1342,7 +1344,7 @@ def __xor__(self: array, other: Union[int, bool, array], /) -> array:
"""

def to_device(
self: array, device: device, /, *, stream: Optional[Union[int, Any]] = None
self: array, device: device_, /, *, stream: Optional[Union[int, Any]] = None
) -> array:
"""
Copy the array from the device on which it currently resides to the specified ``device``.
Expand Down
6 changes: 6 additions & 0 deletions src/array_api_stubs/_draft/array_object.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
from ._types import Array

# for documentation
array = Array

__all__ = ["array"]
36 changes: 19 additions & 17 deletions src/array_api_stubs/_draft/data_types.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,24 @@
__all__ = ["__eq__"]
from __future__ import annotations

__all__ = ["DType"]

from ._types import dtype

from typing import Protocol

def __eq__(self: dtype, other: dtype, /) -> bool:
"""
Computes the truth value of ``self == other`` in order to test for data type object equality.

Parameters
----------
self: dtype
data type instance. May be any supported data type.
other: dtype
other data type instance. May be any supported data type.

Returns
-------
out: bool
a boolean indicating whether the data type objects are equal.
"""
class DType(Protocol):
def __eq__(self, other: DType, /) -> bool:
"""
Computes the truth value of ``self == other`` in order to test for data type object equality.
Parameters
----------
self: dtype
data type instance. May be any supported data type.
other: dtype
other data type instance. May be any supported data type.
Returns
-------
out: bool
a boolean indicating whether the data type objects are equal.
"""
...
6 changes: 3 additions & 3 deletions src/array_api_stubs/_draft/linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -301,7 +301,7 @@ def matrix_norm(
/,
*,
keepdims: bool = False,
ord: Optional[Union[int, float, Literal[inf, -inf, "fro", "nuc"]]] = "fro",
ord: Optional[Union[int, float, Literal[inf, -inf, "fro", "nuc"]]] = "fro", # type: ignore
) -> array:
"""
Computes the matrix norm of a matrix (or a stack of matrices) ``x``.
Expand Down Expand Up @@ -781,7 +781,7 @@ def trace(x: array, /, *, offset: int = 0, dtype: Optional[dtype] = None) -> arr
"""


def vecdot(x1: array, x2: array, /, *, axis: int = None) -> array:
def vecdot(x1: array, x2: array, /, *, axis: Optional[int] = None) -> array:
"""Alias for :func:`~array_api.vecdot`."""


Expand All @@ -791,7 +791,7 @@ def vector_norm(
*,
axis: Optional[Union[int, Tuple[int, ...]]] = None,
keepdims: bool = False,
ord: Union[int, float, Literal[inf, -inf]] = 2,
ord: Union[int, float, Literal[inf, -inf]] = 2, # type: ignore
) -> array:
r"""
Computes the vector norm of a vector (or batch of vectors) ``x``.
Expand Down
0