Initial tensorflow stubs #8974
AlexWaygood merged 22 commits into python:main from hmc-cs-mdrissi:md/tensorflow_initial_stubs on Jan 14, 2023.
Changes from all commits (22 commits):
fc7a2c0  Initial tensorflow stubs
29017d2  fix flake8 lints besides default values
93d1db9  fix mypy primer
46369ed  Remove default values for consistency with typeshed.
4863daa  Merge branch 'main' into md/tensorflow_initial_stubs (AlexWaygood)
304756d  Add default values and fix pyright errors.
5fc1ce4  Fix couple mypy errors.
bb714f5  Add ignore missing stub
df9576e  Fix stubtest errors.
556fb08  fix pyright
500ce70  Try to fix the mypy issue (AlexWaygood)
1dd49f8  Update tests.yml (AlexWaygood)
a20f2d9  Update utils.py (AlexWaygood)
64bf119  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot])
4c6af9e  Update mypy_test.py (AlexWaygood)
19ab96c  Update mypy_test.py (AlexWaygood)
db0924d  Update mypy_test.py (AlexWaygood)
9e01dde  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot])
150160d  Try to fix pytype with numpy usage.
5174b85  Apply suggestions from code review (hmc-cs-mdrissi)
d0ad189  Merge branch 'main' into md/tensorflow_initial_stubs (AlexWaygood)
e5c6e8f  Merge branch 'main' into md/tensorflow_initial_stubs (AlexWaygood)
@@ -0,0 +1,22 @@
# Some methods are dynamically patched onto instances, as they
# may depend on whether code is executed in graph/eager/v1/v2/etc. mode.
# Tensorflow supports multiple modes of execution, which changes some
# of the attributes/methods/even class hierarchies.
tensorflow.Tensor.__int__
tensorflow.Tensor.numpy
tensorflow.Tensor.__index__
# Incomplete
tensorflow.sparse.SparseTensor.__getattr__
tensorflow.SparseTensor.__getattr__
tensorflow.TensorShape.__getattr__
tensorflow.dtypes.DType.__getattr__
tensorflow.RaggedTensor.__getattr__
tensorflow.DType.__getattr__
tensorflow.Graph.__getattr__
tensorflow.Operation.__getattr__
tensorflow.Variable.__getattr__
# Internal undocumented API
tensorflow.RaggedTensor.__init__
# Has an undocumented extra argument that tf.Variable (which acts like a subclass
# by dynamically patching tf.Tensor methods) does not preserve.
tensorflow.Tensor.__getitem__
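For context, the mode-dependent patching that the comments above describe is visible at runtime. The sketch below is not part of the PR; it assumes TensorFlow 2.x defaults (eager execution enabled) and illustrates why stubtest cannot find Tensor.numpy or Tensor.__int__ on the class itself:

import tensorflow as tf

eager = tf.constant(3)            # EagerTensor: numpy() and __int__ are available
print(int(eager), eager.numpy())  # 3 3

@tf.function
def double(x: tf.Tensor) -> tf.Tensor:
    # Inside tf.function, x is a symbolic graph tensor; calling x.numpy() here
    # would fail because that method is only patched onto eager tensors.
    return x * 2

print(double(eager))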
@@ -0,0 +1,3 @@
version = "2.10.*"
# requires a version of numpy with a `py.typed` file
requires = ["numpy>=1.20"]
@@ -0,0 +1,195 @@
from _typeshed import Incomplete, Self, Unused
from abc import ABCMeta
from builtins import bool as _bool
from collections.abc import Callable, Iterable, Iterator, Sequence
from contextlib import contextmanager
from enum import Enum
from typing import Any, NoReturn, overload
from typing_extensions import TypeAlias

import numpy
from tensorflow.dtypes import *

# Most tf.math functions are also exported at the top level of tf, but sadly not all are.
from tensorflow.math import abs as abs
from tensorflow.sparse import SparseTensor

# Tensors ideally should be a generic type, but properly typing data type/shape
# will be a lot of work. Until we have good non-generic tensorflow stubs,
# we will skip making Tensor generic. Also, good type hints for shapes will
# run quickly into many places where the type system is not strong enough today,
# so shape typing is probably not worth doing anytime soon.
_Slice: TypeAlias = int | slice | None

_FloatDataSequence: TypeAlias = Sequence[float] | Sequence[_FloatDataSequence]
_StrDataSequence: TypeAlias = Sequence[str] | Sequence[_StrDataSequence]
_ScalarTensorCompatible: TypeAlias = Tensor | str | float | numpy.ndarray[Any, Any] | numpy.number[Any]
_TensorCompatible: TypeAlias = _ScalarTensorCompatible | Sequence[_TensorCompatible]
_ShapeLike: TypeAlias = TensorShape | Iterable[_ScalarTensorCompatible | None] | int | Tensor
_DTypeLike: TypeAlias = DType | str | numpy.dtype[Any]

class Tensor:
    def __init__(self, op: Operation, value_index: int, dtype: DType) -> None: ...
    def consumers(self) -> list[Incomplete]: ...
    @property
    def shape(self) -> TensorShape: ...
    def get_shape(self) -> TensorShape: ...
    @property
    def dtype(self) -> DType: ...
    @property
    def graph(self) -> Graph: ...
    @property
    def name(self) -> str: ...
    @property
    def op(self) -> Operation: ...
    def numpy(self) -> numpy.ndarray[Any, Any]: ...
    def __int__(self) -> int: ...
    def __abs__(self, name: str | None = None) -> Tensor: ...
    def __add__(self, other: _TensorCompatible) -> Tensor: ...
    def __radd__(self, other: _TensorCompatible) -> Tensor: ...
    def __sub__(self, other: _TensorCompatible) -> Tensor: ...
    def __rsub__(self, other: _TensorCompatible) -> Tensor: ...
    def __mul__(self, other: _TensorCompatible) -> Tensor: ...
    def __rmul__(self, other: _TensorCompatible) -> Tensor: ...
    def __pow__(self, other: _TensorCompatible) -> Tensor: ...
    def __matmul__(self, other: _TensorCompatible) -> Tensor: ...
    def __rmatmul__(self, other: _TensorCompatible) -> Tensor: ...
    def __floordiv__(self, other: _TensorCompatible) -> Tensor: ...
    def __rfloordiv__(self, other: _TensorCompatible) -> Tensor: ...
    def __truediv__(self, other: _TensorCompatible) -> Tensor: ...
    def __rtruediv__(self, other: _TensorCompatible) -> Tensor: ...
    def __neg__(self, name: str | None = None) -> Tensor: ...
    def __and__(self, other: _TensorCompatible) -> Tensor: ...
    def __rand__(self, other: _TensorCompatible) -> Tensor: ...
    def __or__(self, other: _TensorCompatible) -> Tensor: ...
    def __ror__(self, other: _TensorCompatible) -> Tensor: ...
    def __eq__(self, other: _TensorCompatible) -> Tensor: ...  # type: ignore[override]
    def __ne__(self, other: _TensorCompatible) -> Tensor: ...  # type: ignore[override]
    def __ge__(self, other: _TensorCompatible, name: str | None = None) -> Tensor: ...
    def __gt__(self, other: _TensorCompatible, name: str | None = None) -> Tensor: ...
    def __le__(self, other: _TensorCompatible, name: str | None = None) -> Tensor: ...
    def __lt__(self, other: _TensorCompatible, name: str | None = None) -> Tensor: ...
    def __bool__(self) -> NoReturn: ...
    def __getitem__(self, slice_spec: _Slice | tuple[_Slice, ...]) -> Tensor: ...
    def __len__(self) -> int: ...
    # This only works for rank 0 tensors.
    def __index__(self) -> int: ...
    def __getattr__(self, name: str) -> Incomplete: ...

class VariableSynchronization(Enum):
    AUTO = 0
    NONE = 1
    ON_WRITE = 2
    ON_READ = 3

class VariableAggregation(Enum):
    AUTO = 0
    NONE = 1
    ON_WRITE = 2
    ON_READ = 3

class _VariableMetaclass(type): ...

# In intent/documentation, Variable is a Tensor, and the implementation has a TODO
# comment to make it one. It is not actually a Tensor subclass at the type level,
# but it does dynamically patch on most methods of tf.Tensor:
# https://github.com/tensorflow/tensorflow/blob/9524a636cae9ae3f0554203c1ba7ee29c85fcf12/tensorflow/python/ops/variables.py#L1086.
class Variable(Tensor, metaclass=_VariableMetaclass):
    def __init__(
        self,
        initial_value: Tensor | Callable[[], Tensor] | None = None,
        trainable: _bool | None = None,
        validate_shape: _bool = True,
        # Valid non-None values are deprecated.
        caching_device: None = None,
        name: str | None = None,
        # The real type is the VariableDef protobuf type. It can be added after adding
        # a script to generate tensorflow protobuf stubs with mypy-protobuf.
        variable_def: Incomplete | None = None,
        dtype: _DTypeLike | None = None,
        import_scope: str | None = None,
        constraint: Callable[[Tensor], Tensor] | None = None,
        synchronization: VariableSynchronization = VariableSynchronization.AUTO,
        aggregation: VariableAggregation = VariableAggregation.NONE,
        shape: _ShapeLike | None = None,
    ) -> None: ...
    def __getattr__(self, name: str) -> Incomplete: ...

class RaggedTensor(metaclass=ABCMeta):
    def bounding_shape(
        self, axis: _TensorCompatible | None = None, name: str | None = None, out_type: _DTypeLike | None = None
    ) -> Tensor: ...
    @classmethod
    def from_sparse(
        cls, st_input: SparseTensor, name: str | None = None, row_splits_dtype: _DTypeLike = int64
    ) -> RaggedTensor: ...
    def to_sparse(self, name: str | None = None) -> SparseTensor: ...
    def to_tensor(
        self, default_value: float | str | None = None, name: str | None = None, shape: _ShapeLike | None = None
    ) -> Tensor: ...
    def __add__(self, other: RaggedTensor | float, name: str | None = None) -> RaggedTensor: ...
    def __radd__(self, other: RaggedTensor | float, name: str | None = None) -> RaggedTensor: ...
    def __sub__(self, other: RaggedTensor | float, name: str | None = None) -> RaggedTensor: ...
    def __mul__(self, other: RaggedTensor | float, name: str | None = None) -> RaggedTensor: ...
    def __rmul__(self, other: RaggedTensor | float, name: str | None = None) -> RaggedTensor: ...
    def __floordiv__(self, other: RaggedTensor | float, name: str | None = None) -> RaggedTensor: ...
    def __truediv__(self, other: RaggedTensor | float, name: str | None = None) -> RaggedTensor: ...
    def __getitem__(self, slice_spec: _Slice | tuple[_Slice, ...]) -> RaggedTensor: ...
    def __getattr__(self, name: str) -> Incomplete: ...

class Operation:
    def __init__(
        self,
        node_def: Incomplete,
        g: Graph,
        # isinstance is used on this argument in the implementation, so it cannot be Sequence/Iterable.
        inputs: list[Tensor] | None = None,
        output_types: Unused = None,
        control_inputs: Iterable[Tensor | Operation] | None = None,
        input_types: Iterable[DType] | None = None,
        original_op: Operation | None = None,
        op_def: Incomplete = None,
    ) -> None: ...
    @property
    def inputs(self) -> list[Tensor]: ...
    @property
    def outputs(self) -> list[Tensor]: ...
    @property
    def device(self) -> str: ...
    @property
    def name(self) -> str: ...
    @property
    def type(self) -> str: ...
    def __getattr__(self, name: str) -> Incomplete: ...

class TensorShape(metaclass=ABCMeta):
    def __init__(self, dims: _ShapeLike) -> None: ...
    @property
    def rank(self) -> int: ...
    def as_list(self) -> list[int | None]: ...
    def assert_has_rank(self, rank: int) -> None: ...
    def assert_is_compatible_with(self, other: Iterable[int | None]) -> None: ...
    def __bool__(self) -> _bool: ...
    @overload
    def __getitem__(self, key: int) -> int | None: ...
    @overload
    def __getitem__(self, key: slice) -> TensorShape: ...
    def __iter__(self) -> Iterator[int | None]: ...
    def __len__(self) -> int: ...
    def __add__(self, other: Iterable[int | None]) -> TensorShape: ...
    def __radd__(self, other: Iterable[int | None]) -> TensorShape: ...
    def __getattr__(self, name: str) -> Incomplete: ...

class Graph:
    def add_to_collection(self, name: str, value: object) -> None: ...
    def add_to_collections(self, names: Iterable[str] | str, value: object) -> None: ...
    @contextmanager
    def as_default(self: Self) -> Iterator[Self]: ...
    def finalize(self) -> None: ...
    def get_tensor_by_name(self, name: str) -> Tensor: ...
    def get_operation_by_name(self, name: str) -> Operation: ...
    def get_operations(self) -> list[Operation]: ...
    def get_name_scope(self) -> str: ...
    def __getattr__(self, name: str) -> Incomplete: ...

def __getattr__(name: str) -> Incomplete: ...
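As a rough illustration of what these top-level stubs give a type checker, here is a hypothetical snippet that is not part of the PR; scaled_diff is an invented helper, and the revealed types are what mypy/pyright would report under these stubs:

from typing_extensions import reveal_type

import tensorflow as tf

def scaled_diff(x: tf.Tensor, y: tf.Tensor) -> tf.Tensor:
    # Every arithmetic dunder in the stub returns Tensor, so this expression checks cleanly.
    return (x - y) * 2.0

v = tf.Variable(tf.constant([1.0, 2.0]))       # Variable is declared as a Tensor subclass
out = scaled_diff(v, tf.constant([0.5, 0.5]))
reveal_type(out.shape)  # TensorShape
reveal_type(out.dtype)  # DType

# __bool__ is annotated as NoReturn, so branches guarded by `if out:` are treated as
# unreachable, mirroring the error raised when a symbolic (graph-mode) tensor is used as a bool.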
Empty file.
@@ -0,0 +1,55 @@
from _typeshed import Incomplete
from abc import ABCMeta
from builtins import bool as _bool
from typing import Any

import numpy as np
from tensorflow import _DTypeLike

class _DTypeMeta(ABCMeta): ...

class DType(metaclass=_DTypeMeta):
    @property
    def name(self) -> str: ...
    @property
    def as_numpy_dtype(self) -> type[np.number[Any]]: ...
    @property
    def is_numpy_compatible(self) -> _bool: ...
    @property
    def is_bool(self) -> _bool: ...
    @property
    def is_floating(self) -> _bool: ...
    @property
    def is_integer(self) -> _bool: ...
    @property
    def is_quantized(self) -> _bool: ...
    @property
    def is_unsigned(self) -> _bool: ...
    def __getattr__(self, name: str) -> Incomplete: ...

bool: DType
complex128: DType
complex64: DType
bfloat16: DType
float16: DType
half: DType
float32: DType
float64: DType
double: DType
int8: DType
int16: DType
int32: DType
int64: DType
uint8: DType
uint16: DType
uint32: DType
uint64: DType
qint8: DType
qint16: DType
qint32: DType
quint8: DType
quint16: DType
string: DType

def as_dtype(type_value: _DTypeLike) -> DType: ...
def __getattr__(name: str) -> Incomplete: ...
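A small usage sketch (not from the PR) of the dtypes surface above; it exercises the public tf.dtypes.as_dtype API, which accepts anything matching _DTypeLike:

import numpy as np
import tensorflow as tf

dt = tf.dtypes.as_dtype("float32")         # a str is accepted via _DTypeLike
assert dt.is_floating and not dt.is_integer
print(dt.name, dt.as_numpy_dtype)          # float32 <class 'numpy.float32'>

# numpy dtypes are also accepted and map to the equivalent tf DType.
assert tf.dtypes.as_dtype(np.dtype("float32")) == tf.float32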
@@ -0,0 +1,13 @@
from _typeshed import Incomplete
from typing import overload

from tensorflow import RaggedTensor, Tensor, _TensorCompatible
from tensorflow.sparse import SparseTensor

@overload
def abs(x: _TensorCompatible, name: str | None = None) -> Tensor: ...
@overload
def abs(x: SparseTensor, name: str | None = None) -> SparseTensor: ...
@overload
def abs(x: RaggedTensor, name: str | None = None) -> RaggedTensor: ...
def __getattr__(name: str) -> Incomplete: ...
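The three overloads above encode how tf.math.abs dispatches on its input type. A hypothetical check (not part of the PR) of how they resolve:

import tensorflow as tf

dense = tf.math.abs(tf.constant([-1.0, 2.0]))                      # -> Tensor
ragged = tf.math.abs(tf.ragged.constant([[-1.0], [2.0, -3.0]]))    # -> RaggedTensor
sparse = tf.math.abs(
    tf.sparse.SparseTensor(indices=[[0, 0]], values=[-4.0], dense_shape=[2, 2])
)                                                                  # -> SparseTensor
print(type(dense).__name__, type(ragged).__name__, type(sparse).__name__)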
@@ -0,0 +1,30 @@
from _typeshed import Incomplete
from abc import ABCMeta
from typing_extensions import TypeAlias

from tensorflow import Tensor, TensorShape, _TensorCompatible
from tensorflow.dtypes import DType

_SparseTensorCompatible: TypeAlias = _TensorCompatible | SparseTensor

class SparseTensor(metaclass=ABCMeta):
    @property
    def indices(self) -> Tensor: ...
    @property
    def values(self) -> Tensor: ...
    @property
    def dense_shape(self) -> Tensor: ...
    @property
    def shape(self) -> TensorShape: ...
    @property
    def dtype(self) -> DType: ...
    name: str
    def __init__(self, indices: _TensorCompatible, values: _TensorCompatible, dense_shape: _TensorCompatible) -> None: ...
    def get_shape(self) -> TensorShape: ...
    # Many arithmetic operations are not directly supported. Some have alternatives like tf.sparse.add instead of +.
    def __div__(self, y: _SparseTensorCompatible) -> SparseTensor: ...
    def __truediv__(self, y: _SparseTensorCompatible) -> SparseTensor: ...
    def __mul__(self, y: _SparseTensorCompatible) -> SparseTensor: ...
    def __getattr__(self, name: str) -> Incomplete: ...

def __getattr__(name: str) -> Incomplete: ...
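Finally, a hypothetical sketch (not from the PR) of the SparseTensor surface covered here, including the tf.sparse.add alternative that the in-stub comment mentions:

import tensorflow as tf

st = tf.sparse.SparseTensor(indices=[[0, 1], [2, 0]], values=[1.0, 2.0], dense_shape=[3, 4])
print(st.dtype, st.shape, st.dense_shape)

scaled = st * 2.0              # __mul__ against a dense scalar is supported
total = tf.sparse.add(st, st)  # `st + st` is not; tf.sparse.add is the alternative
print(tf.sparse.to_dense(scaled))
print(tf.sparse.to_dense(total))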