`tensorflow`: add `tf.linalg` module by hoel-bagard · Pull Request #11386 · python/typeshed

Merged
merged 5 commits on Feb 17, 2024
Changes from 1 commit
50 changes: 50 additions & 0 deletions stubs/tensorflow/tensorflow/linalg.pyi
@@ -0,0 +1,50 @@
from _typeshed import Incomplete
from builtins import bool as _bool
from typing import Literal, overload

from tensorflow import RaggedTensor, Tensor, norm as norm
from tensorflow._aliases import DTypeLike, ScalarTensorCompatible, TensorCompatible
from tensorflow.math import l2_normalize as l2_normalize

@overload
def matmul(
a: TensorCompatible,
b: TensorCompatible,
transpose_a: _bool = False,
transpose_b: _bool = False,
adjoint_a: _bool = False,
adjoint_b: _bool = False,
a_is_sparse: _bool = False,
b_is_sparse: _bool = False,
output_type: DTypeLike | None = None,
name: str | None = None,
) -> Tensor: ...
@overload
def matmul(
a: RaggedTensor,
b: RaggedTensor,
transpose_a: _bool = False,
transpose_b: _bool = False,
adjoint_a: _bool = False,
adjoint_b: _bool = False,
a_is_sparse: _bool = False,
b_is_sparse: _bool = False,
output_type: DTypeLike | None = None,
name: str | None = None,
) -> RaggedTensor: ...
def set_diag(
input: TensorCompatible,
diagonal: TensorCompatible,
name: str | None = "set_diag",
k: int = 0,
align: Literal["RIGHT_LEFT", "RIGHT_RIGHT", "LEFT_LEFT", "LEFT_RIGHT"] = "RIGHT_LEFT",
) -> Tensor: ...
def eye(
num_rows: ScalarTensorCompatible,
num_columns: ScalarTensorCompatible | None = None,
batch_shape: int | tuple[int, ...] | None = None,
dtype: DTypeLike = ...,
name: str | None = None,
) -> Tensor: ...
def band_part(input: TensorCompatible, num_lower: int, num_upper: int, name: str | None = None) -> Tensor: ...
def __getattr__(name: str) -> Incomplete: ...
Contributor:
A scalar (0-d, like tf.constant(10)) int tensor will also work for num_lower/num_upper. This is generally true of most tensor functions that take an int, although I'm unsure how common it is, and higher-rank tensors will fail. I'd weakly lean toward preferring ScalarTensorCompatible over int here and in similar cases.
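
A rough illustration of the reviewer's point (assuming a working TensorFlow install; the values here are made up for the example):

import tensorflow as tf

x = tf.ones((4, 4))
# A plain Python int matches the current `int` annotation.
tf.linalg.band_part(x, 1, 0)
# A rank-0 (scalar) int tensor is also accepted at runtime, which is what
# widening the annotation to ScalarTensorCompatible would express.
tf.linalg.band_part(x, tf.constant(1), tf.constant(0))
# A higher-rank tensor fails at runtime (per the comment above), so widening
# beyond scalars would be too permissive:
# tf.linalg.band_part(x, tf.constant([1, 2]), 0)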

Contributor Author:
I also noticed that in #11333, and there I added an Integer type for that purpose. It seems common enough in TensorFlow to warrant an alias.

Integer: TypeAlias = tf.Tensor | int | IntArray | np.number[Any]
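
For context, a few illustrative values the quoted alias is meant to accept (assuming IntArray covers integer NumPy arrays, as its name suggests):

import numpy as np
import tensorflow as tf

scalar_int = 3                   # int
numpy_scalar = np.int32(3)       # np.number[Any]
scalar_tensor = tf.constant(3)   # tf.Tensor (rank 0)
int_array = np.array([1, 2, 3])  # IntArray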

Contributor Author:
I've done the same here in ece7680.
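
A minimal sketch of what that change could look like in linalg.pyi, assuming the Integer alias lives in tensorflow._aliases (the exact contents of ece7680 are not reproduced here):

from tensorflow import Tensor
from tensorflow._aliases import Integer, TensorCompatible  # Integer import path is an assumption

# num_lower/num_upper widened from `int` to the `Integer` alias so that
# 0-d int tensors and NumPy integers are accepted as well.
def band_part(input: TensorCompatible, num_lower: Integer, num_upper: Integer, name: str | None = None) -> Tensor: ...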
