8000 Initial tensorflow stubs (#8974) · python/typeshed@ea0ae21 · GitHub
[go: up one dir, main page]

Skip to content

Commit ea0ae21

Browse files
Initial tensorflow stubs (#8974)
Co-authored-by: Alex Waygood <Alex.Waygood@Gmail.com>
1 parent 0a291da commit ea0ae21

File tree

7 files changed

+318
-0
lines changed

7 files changed

+318
-0
lines changed
Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
# Some methods are dynamically patched onto to instances as they
2+
# may depend on whether code is executed in graph/eager/v1/v2/etc.
3+
# Tensorflow supports multiple modes of execution which changes some
4+
# of the attributes/methods/even class hierachies.
5+
tensorflow.Tensor.__int__
6+
tensorflow.Tensor.numpy
7+
tensorflow.Tensor.__index__
8+
# Incomplete
9+
tensorflow.sparse.SparseTensor.__getattr__
10+
tensorflow.SparseTensor.__getattr__
11+
tensorflow.TensorShape.__getattr__
12+
tensorflow.dtypes.DType.__getattr__
13+
tensorflow.RaggedTensor.__getattr__
14+
tensorflow.DType.__getattr__
15+
tensorflow.Graph.__getattr__
16+
tensorflow.Operation.__getattr__
17+
tensorflow.Variable.__getattr__
18+
# Internal undocumented API
19+
tensorflow.RaggedTensor.__init__
20+
# Has an undocumented extra argument that tf.Variable which acts like subclass
21+
# (by dynamically patching tf.Tensor methods) does not preserve.
22+
tensorflow.Tensor.__getitem__

stubs/tensorflow/METADATA.toml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
version = "2.10.*"
2+
# requires a version of numpy with a `py.typed` file
3+
requires = ["numpy>=1.20"]
Lines changed: 195 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,195 @@
1+
from _typeshed import Incomplete, Self, Unused
2+
from abc import ABCMeta
3+
from builtins import bool as _bool
4+
from collections.abc import Callable, Iterable, Iterator, Sequence
5+
from contextlib import contextmanager
6+
from enum import Enum
7+
from typing import Any, NoReturn, overload
8+
from typing_extensions import TypeAlias
9+
10+
import numpy
11+
from tensorflow.dtypes import *
12+
13+
# Most tf.math functions are exported as tf, but sadly not all are.
14+
from tensorflow.math import abs as abs
15+
from tensorflow.sparse import SparseTensor
16+
17+
# Tensors ideally should be a generic type, but properly typing data type/shape
18+
# will be a lot of work. Until we have good non-generic tensorflow stubs,
19+
# we will skip making Tensor generic. Also good type hints for shapes will
20+
# run quickly into many places where type system is not strong enough today.
21+
# So shape typing is probably not worth doing anytime soon.
22+
_Slice: TypeAlias = int | slice | None
23+
24+
_FloatDataSequence: TypeAlias = Sequence[float] | Sequence[_FloatDataSequence]
25+
_StrDataSequence: TypeAlias = Sequence[str] | Sequence[_StrDataSequence]
26+
_ScalarTensorCompatible: TypeAlias = Tensor | str | float | numpy.ndarray[Any, Any] | numpy.number[Any]
27+
_TensorCompatible: TypeAlias = _ScalarTensorCompatible | Sequence[_TensorCompatible]
28+
_ShapeLike: TypeAlias = TensorShape | Iterable[_ScalarTensorCompatible | None] | int | Tensor
29+
_DTypeLike: TypeAlias = DType | str | numpy.dtype[Any]
30+
31+
class Tensor:
32+
def __init__(self, op: Operation, value_index: int, dtype: DType) -> None: ...
33+
def consumers(self) -> list[Incomplete]: ...
34+
@property
35+
def shape(self) -> TensorShape: ...
36+
def get_shape(self) -> TensorShape: ...
37+
@property
38+
def dtype(self) -> DType: ...
39+
@property
40+
def graph(self) -> Graph: ...
41+
@property
42+
def name(self) -> str: ...
43+
@property
44+
def op(self) -> Operation: ...
45+
def numpy(self) -> numpy.ndarray[Any, Any]: ...
46+
def __int__(self) -> int: ...
47+
def __abs__(self, name: str | None = None) -> Tensor: ...
48+
def __add__(self, other: _TensorCompatible) -> Tensor: ...
49+
def __radd__(self, other: _TensorCompatible) -> Tensor: ...
50+
def __sub__(self, other: _TensorCompatible) -> Tensor: ...
51+
def __rsub__(self, other: _TensorCompatible) -> Tensor: ...
52+
def __mul__(self, other: _TensorCompatible) -> Tensor: ...
53+
def __rmul__(self, other: _TensorCompatible) -> Tensor: ...
54+
def __pow__(self, other: _TensorCompatible) -> Tensor: ...
55+
def __matmul__(self, other: _TensorCompatible) -> Tensor: ...
56+
def __rmatmul__(self, other: _TensorCompatible) -> Tensor: ...
57+
def __floordiv__(self, other: _TensorCompatible) -> Tensor: ...
58+
def __rfloordiv__(self, other: _TensorCompatible) -> Tensor: ...
59+
def __truediv__(self, other: _TensorCompatible) -> Tensor: ...
60+
def __rtruediv__(self, other: _TensorCompatible) -> Tensor: ...
61+
def __neg__(self, name: str | None = None) -> Tensor: ...
62+
def __and__( 10000 self, other: _TensorCompatible) -> Tensor: ...
63+
def __rand__(self, other: _TensorCompatible) -> Tensor: ...
64+
def __or__(self, other: _TensorCompatible) -> Tensor: ...
65+
def __ror__(self, other: _TensorCompatible) -> Tensor: ...
66+
def __eq__(self, other: _TensorCompatible) -> Tensor: ... # type: ignore[override]
67+
def __ne__(self, other: _TensorCompatible) -> Tensor: ... # type: ignore[override]
68+
def __ge__(self, other: _TensorCompatible, name: str | None = None) -> Tensor: ...
69+
def __gt__(self, other: _TensorCompatible, name: str | None = None) -> Tensor: ...
70+
def __le__(self, other: _TensorCompatible, name: str | None = None) -> Tensor: ...
71+
def __lt__(self, other: _TensorCompatible, name: str | None = None) -> Tensor: ...
72+
def __bool__(self) -> NoReturn: ...
73+
def __getitem__(self, slice_spec: _Slice | tuple[_Slice, ...]) -> Tensor: ...
74+
def __len__(self) -> int: ...
75+
# This only works for rank 0 tensors.
76+
def __index__(self) -> int: ...
77+
def __getattr__(self, name: str) -> Incomplete: ...
78+
79+
class VariableSynchronization(Enum):
80+
AUTO = 0
81+
NONE = 1
82+
ON_WRITE = 2
83+
ON_READ = 3
84+
85+
class VariableAggregation(Enum):
86+
AUTO = 0
87+
NONE = 1
88+
ON_WRITE = 2
89+
ON_READ = 3
90+
91+
class _VariableMetaclass(type): ...
92+
93+
# Variable class in intent/documentation is a Tensor. In implementation there's
94+
# TODO comment to make it Tensor. It is not actually Tensor type wise, but even
95+
# dynamically patches on most methods of tf.Tensor
96+
# https://github.com/tensorflow/tensorflow/blob/9524a636cae9ae3f0554203c1ba7ee29c85fcf12/tensorflow/python/ops/variables.py#L1086.
97+
class Variable(Tensor, metaclass=_VariableM F42D etaclass):
98+
def __init__(
99+
self,
100+
initial_value: Tensor | Callable[[], Tensor] | None = None,
101+
trainable: _bool | None = None,
102+
validate_shape: _bool = True,
103+
# Valid non-None values are deprecated.
104+
caching_device: None = None,
105+
name: str | None = None,
106+
# Real type is VariableDef protobuf type. Can be added after adding script
107+
# to generate tensorflow protobuf stubs with mypy-protobuf.
108+
variable_def: Incomplete | None = None,
109+
dtype: _DTypeLike | None = None,
110+
import_scope: str | None = None,
111+
constraint: Callable[[Tensor], Tensor] | None = None,
112+
synchronization: VariableSynchronization = VariableSynchronization.AUTO,
113+
aggregation: VariableAggregation = VariableAggregation.NONE,
114+
shape: _ShapeLike | None = None,
115+
) -> None: ...
116+
def __getattr__(self, name: str) -> Incomplete: ...
117+
118+
class RaggedTensor(metaclass=ABCMeta):
119+
def bounding_shape(
120+
self, axis: _TensorCompatible | None = None, name: str | None = None, out_type: _DTypeLike | None = None
121+
) -> Tensor: ...
122+
@classmethod
123+
def from_sparse(
124+
cls, st_input: SparseTensor, name: str | None = None, row_splits_dtype: _DTypeLike = int64
125+
) -> RaggedTensor: ...
126+
def to_sparse(self, name: str | None = None) -> SparseTensor: ...
127+
def to_tensor(
128+
self, default_value: float | str | None = None, name: str | None = None, shape: _ShapeLike | None = None
129+
) -> Tensor: ...
130+
def __add__(self, other: RaggedTensor | float, name: str | None = None) -> RaggedTensor: ...
131+
def __radd__(self, other: RaggedTensor | float, name: str | None = None) -> RaggedTensor: ...
132+
def __sub__(self, other: RaggedTensor | float, name: str | None = None) -> RaggedTensor: ...
133+
def __mul__(self, other: RaggedTensor | float, name: str | None = None) -> RaggedTensor: ...
134+
def __rmul__(self, other: RaggedTensor | float, name: str | None = None) -> RaggedTensor: ...
135+
def __floordiv__(self, other: RaggedTensor | float, name: str | None = None) -> RaggedTensor: ...
136+
def __truediv__(self, other: RaggedTensor | float, name: str | None = None) -> RaggedTensor: ...
137+
def __getitem__(self, slice_spec: _Slice | tuple[_Slice, ...]) -> RaggedTensor: ...
138+
def __getattr__(self, name: str) -> Incomplete: ...
139+
140+
class Operation:
141+
def __init__(
142+
self,
143+
node_def: Incomplete,
144+
g: Graph,
145+
# isinstance is used so can not be Sequence/Iterable.
146+
inputs: list[Tensor] | None = None,
147+
output_types: Unused = None,
148+
control_inputs: Iterable[Tensor | Operation] | None = None,
149+
input_types: Iterable[DType] | None = None,
150+
original_op: Operation | None = None,
151+
op_def: Incomplete = None,
152+
) -> None: ...
153+
@property
154+
def inputs(self) -> list[Tensor]: ...
155+
@property
156+
def outputs(self) -> list[Tensor]: ...
157+
@property
158+
def device(self) -> str: ...
159+
@property
160+
def name(self) -> str: ...
161+
@property
162+
def type(self) -> str: ...
163+
def __getattr__(self, name: str) -> Incomplete: ...
164+
165+
class TensorShape(metaclass=ABCMeta):
166+
def __init__(self, dims: _ShapeLike) -> None: ...
167+
@property
168+
def rank(self) -> int: ...
169+
def as_list(self) -> list[int | None]: ...
170+
def assert_has_rank(self, rank: int) -> None: ...
171+
def assert_is_compatible_with(self, other: Iterable[int | None]) -> None: ...
172+
def __bool__(self) -> _bool: ...
173+
@overload
174+
def __getitem__(self, key: int) -> int | None: ...
175+
@overload
176+
def __getitem__(self, key: slice) -> TensorShape: ...
177+
def __iter__(self) -> Iterator[int | None]: ...
178+
def __len__(self) -> int: ...
179+
def __add__(self, other: Iterable[int | None]) -> TensorShape: ...
180+
def __radd__(self, other: Iterable[int | None]) -> TensorShape: ...
181+
def __getattr__(self, name: str) -> Incomplete: ...
182+
183+
class Graph:
184+
def add_to_collection(self, name: str, value: object) -> None: ...
185+
def add_to_collections(self, names: Iterable[str] | str, value: object) -> None: ...
186+
@contextmanager
187+
def as_default(self: Self) -> Iterator[Self]: ...
188+
def finalize(self) -> None: ...
189+
def get_tensor_by_name(self, name: str) -> Tensor: ...
190+
def get_operation_by_name(self, name: str) -> Operation: ...
191+
def get_operations(self) -> list[Operation]: ...
192+
def get_name_scope(self) -> str: ...
193+
def __getattr__(self, name: str) -> Incomplete: ...
194+
195+
def __getattr__(name: str) -> Incomplete: ...

stubs/tensorflow/tensorflow/core/framework/variable_pb2.pyi

Whitespace-only changes.
Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
from _typeshed import Incomplete
2+
from abc import ABCMeta
3+
from builtins import bool as _bool
4+
from typing import Any
5+
6+
import numpy as np
7+
from tensorflow import _DTypeLike
8+
9+
class _DTypeMeta(ABCMeta): ...
10+
11+
class DType(metaclass=_DTypeMeta):
12+
@property
13+
def name(self) -> str: ...
14+
@property
15+
def as_numpy_dtype(self) -> type[np.number[Any]]: ...
16+
@property
17+
def is_numpy_compatible(self) -> _bool: ...
18+
@property
19+
def is_bool(self) -> _bool: ...
20+
@property
21+
def is_floating(self) -> _bool: ...
22+
@property
23+
def is_integer(self) -> _bool: ...
24+
@property
25+
def is_quantized(self) -> _bool: ...
26+
@property
27+
def is_unsigned(self) -> _bool: ...
28+
def __getattr__(self, name: str) -> Incomplete: ...
29+
30+
bool: DType
31+
complex128: DType
32+
complex64: DType
33+
bfloat16: DType
34+
float16: DType
35+
half: DType
36+
float32: DType
37+
float64: DType
38+
double: DType
39+
int8: DType
40+
int16: DType
41+
int32: DType
42+
int64: DType
43+
uint8: DType
44+
uint16: DType
45+
uint32: DType
46+
uint64: DType
47+
qint8: DType
48+
qint16: DType
49+
qint32: DType
50+
quint8: DType
51+
quint16: DType
52+
string: DType
53+
54+
def as_dtype(type_value: _DTypeLike) -> DType: ...
55+
def __getattr__(name: str) -> Incomplete: ...

stubs/tensorflow/tensorflow/math.pyi

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
from _typeshed import Incomplete
2+
from typing import overload
3+
4+
from tensorflow import RaggedTensor, Tensor, _TensorCompatible
5+
from tensorflow.sparse import SparseTensor
6+
7+
@overload
8+
def abs(x: _TensorCompatible, name: str | None = None) -> Tensor: ...
9+
@overload
10+
def abs(x: SparseTensor, name: str | None = None) -> SparseTensor: ...
11+
@overload
12+
def abs(x: RaggedTensor, name: str | None = None) -> RaggedTensor: ...
13+
def __getattr__(name: str) -> Incomplete: ...
Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
from _typeshed import Incomplete
2+
from abc import ABCMeta
3+
from typing_extensions import TypeAlias
4+
5+
from tensorflow import Tensor, TensorShape, _TensorCompatible
6+
from tensorflow.dtypes import DType
7+
8+
_SparseTensorCompatible: TypeAlias = _TensorCompatible | SparseTensor
9+
10+
class SparseTensor(metaclass=ABCMeta):
11+
@property
12+
def indices(self) -> Tensor: ...
13+
@property
14+
def values(self) -> Tensor: ...
15+
@property
16+
def dense_shape(self) -> Tensor: ...
17+
@property
18+
def shape(self) -> TensorShape: ...
19+
@property
20+
def dtype(self) -> DType: ...
21+
name: str
22+
def __init__(self, indices: _TensorCompatible, values: _TensorCompatible, dense_shape: _TensorCompatible) -> None: ...
23+
def get_shape(self) -> TensorShape: ...
24+
# Many arithmetic operations are not directly supported. Some have alternatives like tf.sparse.add instead of +.
25+
def __div__(self, y: _SparseTensorCompatible) -> SparseTensor: ...
26+
def __truediv__(self, y: _SparseTensorCompatible) -> SparseTensor: ...
27+
def __mul__(self, y: _SparseTensorCompatible) -> SparseTensor: ...
28+
def __getattr__(self, name: str) -> Incomplete: ...
29+
30+
def __getattr__(name: str) -> Incomplete: ...

0 commit comments

Comments
 (0)
0