From 02a672159f8a0d54c4df8965e68f5038ed630169 Mon Sep 17 00:00:00 2001 From: Hoel Bagard Date: Sat, 10 Feb 2024 00:31:37 +0900 Subject: [PATCH 1/5] Add some linalg stubs Taken from: https://github.com/hmc-cs-mdrissi/tensorflow_stubs/blob/main/stubs/tensorflow/linalg.pyi --- stubs/tensorflow/tensorflow/linalg.pyi | 50 ++++++++++++++++++++++++++ 1 file changed, 50 insertions(+) create mode 100644 stubs/tensorflow/tensorflow/linalg.pyi diff --git a/stubs/tensorflow/tensorflow/linalg.pyi b/stubs/tensorflow/tensorflow/linalg.pyi new file mode 100644 index 000000000000..22a8fc9e0fe1 --- /dev/null +++ b/stubs/tensorflow/tensorflow/linalg.pyi @@ -0,0 +1,50 @@ +from _typeshed import Incomplete +from builtins import bool as _bool +from typing import Literal, overload + +from tensorflow import RaggedTensor, Tensor, norm as norm +from tensorflow._aliases import DTypeLike, ScalarTensorCompatible, TensorCompatible +from tensorflow.math import l2_normalize as l2_normalize + +@overload +def matmul( + a: TensorCompatible, + b: TensorCompatible, + transpose_a: _bool = False, + transpose_b: _bool = False, + adjoint_a: _bool = False, + adjoint_b: _bool = False, + a_is_sparse: _bool = False, + b_is_sparse: _bool = False, + output_type: DTypeLike | None = None, + name: str | None = None, +) -> Tensor: ... +@overload +def matmul( + a: RaggedTensor, + b: RaggedTensor, + transpose_a: _bool = False, + transpose_b: _bool = False, + adjoint_a: _bool = False, + adjoint_b: _bool = False, + a_is_sparse: _bool = False, + b_is_sparse: _bool = False, + output_type: DTypeLike | None = None, + name: str | None = None, +) -> RaggedTensor: ... +def set_diag( + input: TensorCompatible, + diagonal: TensorCompatible, + name: str | None = "set_diag", + k: int = 0, + align: Literal["RIGHT_LEFT", "RIGHT_RIGHT", "LEFT_LEFT", "LEFT_RIGHT"] = "RIGHT_LEFT", +) -> Tensor: ... +def eye( + num_rows: ScalarTensorCompatible, + num_columns: ScalarTensorCompatible | None = None, + batch_shape: int | tuple[int, ...] | None = None, + dtype: DTypeLike = ..., + name: str | None = None, +) -> Tensor: ... +def band_part(input: TensorCompatible, num_lower: int, num_upper: int, name: str | None = None) -> Tensor: ... +def __getattr__(name: str) -> Incomplete: ... From a5372f0932673c388331406123582af2d3d6404b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ho=C3=ABl=20Bagard?= <34478245+hoel-bagard@users.noreply.github.com> Date: Sat, 17 Feb 2024 11:55:22 +0900 Subject: [PATCH 2/5] Update stubs/tensorflow/tensorflow/linalg.pyi Co-authored-by: Jelle Zijlstra --- stubs/tensorflow/tensorflow/linalg.pyi | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/stubs/tensorflow/tensorflow/linalg.pyi b/stubs/tensorflow/tensorflow/linalg.pyi index 22a8fc9e0fe1..178a2f026cb1 100644 --- a/stubs/tensorflow/tensorflow/linalg.pyi +++ b/stubs/tensorflow/tensorflow/linalg.pyi @@ -42,7 +42,7 @@ def set_diag( def eye( num_rows: ScalarTensorCompatible, num_columns: ScalarTensorCompatible | None = None, - batch_shape: int | tuple[int, ...] | None = None, + batch_shape: int | list[int] | tuple[int, ...] | None = None, dtype: DTypeLike = ..., name: str | None = None, ) -> Tensor: ... From 28ae404a5be7f888feed3971ec6ec412f07e3c14 Mon Sep 17 00:00:00 2001 From: Hoel Bagard Date: Sat, 17 Feb 2024 11:57:43 +0900 Subject: [PATCH 3/5] fix: re-export eye from tf --- stubs/tensorflow/tensorflow/__init__.pyi | 1 + 1 file changed, 1 insertion(+) diff --git a/stubs/tensorflow/tensorflow/__init__.pyi b/stubs/tensorflow/tensorflow/__init__.pyi index 1021bff117a7..b4a5b082ce14 100644 --- a/stubs/tensorflow/tensorflow/__init__.pyi +++ b/stubs/tensorflow/tensorflow/__init__.pyi @@ -37,6 +37,7 @@ from tensorflow.core.protobuf import struct_pb2 from tensorflow.dtypes import * from tensorflow.dtypes import DType as DType from tensorflow.keras import losses as losses +from tensorflow.linalg import eye as eye # Most tf.math functions are exported as tf, but sadly not all are. from tensorflow.math import ( From 6e2d07a41c53fcc7c7db56aafa3dc16b15e24a76 Mon Sep 17 00:00:00 2001 From: Hoel Bagard Date: Sat, 17 Feb 2024 20:53:31 +0900 Subject: [PATCH 4/5] fix: eye's batch_shape --- stubs/tensorflow/tensorflow/linalg.pyi | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/stubs/tensorflow/tensorflow/linalg.pyi b/stubs/tensorflow/tensorflow/linalg.pyi index 178a2f026cb1..612a13f70fe0 100644 --- a/stubs/tensorflow/tensorflow/linalg.pyi +++ b/stubs/tensorflow/tensorflow/linalg.pyi @@ -1,9 +1,11 @@ from _typeshed import Incomplete from builtins import bool as _bool +from collections.abc import Iterable from typing import Literal, overload +import tensorflow as tf from tensorflow import RaggedTensor, Tensor, norm as norm -from tensorflow._aliases import DTypeLike, ScalarTensorCompatible, TensorCompatible +from tensorflow._aliases import DTypeLike, IntArray, ScalarTensorCompatible, TensorCompatible from tensorflow.math import l2_normalize as l2_normalize @overload @@ -42,7 +44,7 @@ def set_diag( def eye( num_rows: ScalarTensorCompatible, num_columns: ScalarTensorCompatible | None = None, - batch_shape: int | list[int] | tuple[int, ...] | None = None, + batch_shape: Iterable[int] | IntArray | tf.Tensor | None = None, dtype: DTypeLike = ..., name: str | None = None, ) -> Tensor: ... From ece7680458fe890154b42f69c6df37b551103935 Mon Sep 17 00:00:00 2001 From: Hoel Bagard Date: Sat, 17 Feb 2024 20:58:54 +0900 Subject: [PATCH 5/5] add and use Integer type alias. --- stubs/tensorflow/tensorflow/_aliases.pyi | 1 + stubs/tensorflow/tensorflow/linalg.pyi | 4 ++-- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/stubs/tensorflow/tensorflow/_aliases.pyi b/stubs/tensorflow/tensorflow/_aliases.pyi index 7b92a6fa28ad..ca85fbf0b1c4 100644 --- a/stubs/tensorflow/tensorflow/_aliases.pyi +++ b/stubs/tensorflow/tensorflow/_aliases.pyi @@ -27,6 +27,7 @@ class KerasSerializable2(Protocol): KerasSerializable: TypeAlias = KerasSerializable1 | KerasSerializable2 +Integer: TypeAlias = tf.Tensor | int | IntArray | np.number[Any] # Here tf.Tensor and IntArray are assumed to be 0D. Slice: TypeAlias = int | slice | None FloatDataSequence: TypeAlias = Sequence[float] | Sequence[FloatDataSequence] StrDataSequence: TypeAlias = Sequence[str] | Sequence[StrDataSequence] diff --git a/stubs/tensorflow/tensorflow/linalg.pyi b/stubs/tensorflow/tensorflow/linalg.pyi index 612a13f70fe0..1ce7acfacf42 100644 --- a/stubs/tensorflow/tensorflow/linalg.pyi +++ b/stubs/tensorflow/tensorflow/linalg.pyi @@ -5,7 +5,7 @@ from typing import Literal, overload import tensorflow as tf from tensorflow import RaggedTensor, Tensor, norm as norm -from tensorflow._aliases import DTypeLike, IntArray, ScalarTensorCompatible, TensorCompatible +from tensorflow._aliases import DTypeLike, IntArray, Integer, ScalarTensorCompatible, TensorCompatible from tensorflow.math import l2_normalize as l2_normalize @overload @@ -48,5 +48,5 @@ def eye( dtype: DTypeLike = ..., name: str | None = None, ) -> Tensor: ... -def band_part(input: TensorCompatible, num_lower: int, num_upper: int, name: str | None = None) -> Tensor: ... +def band_part(input: TensorCompatible, num_lower: Integer, num_upper: Integer, name: str | None = None) -> Tensor: ... def __getattr__(name: str) -> Incomplete: ...