diff --git a/stubs/tensorflow/tensorflow/__init__.pyi b/stubs/tensorflow/tensorflow/__init__.pyi index 1021bff117a7..b4a5b082ce14 100644 --- a/stubs/tensorflow/tensorflow/__init__.pyi +++ b/stubs/tensorflow/tensorflow/__init__.pyi @@ -37,6 +37,7 @@ from tensorflow.core.protobuf import struct_pb2 from tensorflow.dtypes import * from tensorflow.dtypes import DType as DType from tensorflow.keras import losses as losses +from tensorflow.linalg import eye as eye # Most tf.math functions are exported as tf, but sadly not all are. from tensorflow.math import ( diff --git a/stubs/tensorflow/tensorflow/_aliases.pyi b/stubs/tensorflow/tensorflow/_aliases.pyi index 7b92a6fa28ad..ca85fbf0b1c4 100644 --- a/stubs/tensorflow/tensorflow/_aliases.pyi +++ b/stubs/tensorflow/tensorflow/_aliases.pyi @@ -27,6 +27,7 @@ class KerasSerializable2(Protocol): KerasSerializable: TypeAlias = KerasSerializable1 | KerasSerializable2 +Integer: TypeAlias = tf.Tensor | int | IntArray | np.number[Any] # Here tf.Tensor and IntArray are assumed to be 0D. Slice: TypeAlias = int | slice | None FloatDataSequence: TypeAlias = Sequence[float] | Sequence[FloatDataSequence] StrDataSequence: TypeAlias = Sequence[str] | Sequence[StrDataSequence] diff --git a/stubs/tensorflow/tensorflow/linalg.pyi b/stubs/tensorflow/tensorflow/linalg.pyi new file mode 100644 index 000000000000..1ce7acfacf42 --- /dev/null +++ b/stubs/tensorflow/tensorflow/linalg.pyi @@ -0,0 +1,52 @@ +from _typeshed import Incomplete +from builtins import bool as _bool +from collections.abc import Iterable +from typing import Literal, overload + +import tensorflow as tf +from tensorflow import RaggedTensor, Tensor, norm as norm +from tensorflow._aliases import DTypeLike, IntArray, Integer, ScalarTensorCompatible, TensorCompatible +from tensorflow.math import l2_normalize as l2_normalize + +@overload +def matmul( + a: TensorCompatible, + b: TensorCompatible, + transpose_a: _bool = False, + transpose_b: _bool = False, + adjoint_a: _bool = False, + adjoint_b: _bool = False, + a_is_sparse: _bool = False, + b_is_sparse: _bool = False, + output_type: DTypeLike | None = None, + name: str | None = None, +) -> Tensor: ... +@overload +def matmul( + a: RaggedTensor, + b: RaggedTensor, + transpose_a: _bool = False, + transpose_b: _bool = False, + adjoint_a: _bool = False, + adjoint_b: _bool = False, + a_is_sparse: _bool = False, + b_is_sparse: _bool = False, + output_type: DTypeLike | None = None, + name: str | None = None, +) -> RaggedTensor: ... +def set_diag( + input: TensorCompatible, + diagonal: TensorCompatible, + name: str | None = "set_diag", + k: int = 0, + align: Literal["RIGHT_LEFT", "RIGHT_RIGHT", "LEFT_LEFT", "LEFT_RIGHT"] = "RIGHT_LEFT", +) -> Tensor: ... +def eye( + num_rows: ScalarTensorCompatible, + num_columns: ScalarTensorCompatible | None = None, + batch_shape: Iterable[int] | IntArray | tf.Tensor | None = None, + dtype: DTypeLike = ..., + name: str | None = None, +) -> Tensor: ... +def band_part(input: TensorCompatible, num_lower: Integer, num_upper: Integer, name: str | None = None) -> Tensor: ... +def __getattr__(name: str) -> Incomplete: ...