tensorflow: add tf.linalg module #11386
Conversation
Force-pushed from 305c5bf to 02a6721
Changed the title from "tensorflow add tf.linalg module" to "tensorflow: add tf.linalg module"
Co-authored-by: Jelle Zijlstra <jelle.zijlstra@gmail.com>
def eye(
    num_rows: ScalarTensorCompatible,
    num_columns: ScalarTensorCompatible | None = None,
    batch_shape: int | list[int] | tuple[int, ...] | None = None,
Minor nit: I'd lean toward relaxing this to TensorCompatible. batch_shape is also compatible with a tensor: "A list or tuple of Python integers or a 1-D int32 Tensor". Looking over the implementation, I think Iterable[int] | TensorCompatible is about as broad as it can be. It's very rare for list[int] to be compatible but not an int tensor.
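A quick runtime check of those accepted forms for batch_shape (a minimal sketch, assuming TF 2.x):

import tensorflow as tf

tf.eye(3, batch_shape=[2])                          # list[int] -> shape (2, 3, 3)
tf.eye(3, batch_shape=(2, 4))                       # tuple[int, ...] -> shape (2, 4, 3, 3)
tf.eye(3, batch_shape=tf.constant([2], tf.int32))   # 1-D int32 tensor -> shape (2, 3, 3)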
TensorCompatible includes invalid types, like str, float, and arrays of strings/floats. Iterable[int] | IntArray | tf.Tensor seems better to me, so I went with that in 6e2d07a.
    dtype: DTypeLike = ...,
    name: str | None = None,
) -> Tensor: ...
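For illustration, a sketch of the revised signature with the narrowed batch_shape hint from the reply above; the alias import locations are assumptions and the actual stub in 6e2d07a may differ:

from collections.abc import Iterable

from tensorflow import Tensor
from tensorflow._aliases import DTypeLike, IntArray, ScalarTensorCompatible  # assumed locations

def eye(
    num_rows: ScalarTensorCompatible,
    num_columns: ScalarTensorCompatible | None = None,
    batch_shape: Iterable[int] | IntArray | Tensor | None = None,
    dtype: DTypeLike = ...,
    name: str | None = None,
) -> Tensor: ...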
def band_part(input: TensorCompatible, num_lower: int, num_upper: int, name: str | None = None) -> Tensor: ...
A scalar (0-d, like tf.constant(10)) int tensor will also work for num_lower/num_upper. This is generally true of most tensor functions that take an int, although I'm unsure how common it is, and higher-rank tensors will fail. I'd weakly lean toward ScalarTensorCompatible over int here and in similar cases.
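A minimal runtime illustration (assuming TF 2.x):

import tensorflow as tf

x = tf.ones((4, 4))
tf.linalg.band_part(x, 1, 1)                            # plain Python ints
tf.linalg.band_part(x, tf.constant(1), tf.constant(1))  # 0-d int tensors also work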
I also noticed that in #11333, and there I added an Integer type for that purpose. It seems common enough in TensorFlow to warrant an alias.
Integer: TypeAlias = tf.Tensor | int | IntArray | np.number[Any]
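A sketch of the alias applied to band_part; the IntArray definition below is an assumption and the input type is simplified to tf.Tensor for brevity:

from typing import Any

import numpy as np
import tensorflow as tf
from typing_extensions import TypeAlias

IntArray: TypeAlias = np.ndarray[Any, np.dtype[np.integer[Any]]]  # assumed definition
Integer: TypeAlias = tf.Tensor | int | IntArray | np.number[Any]

def band_part(input: tf.Tensor, num_lower: Integer, num_upper: Integer, name: str | None = None) -> tf.Tensor: ...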
I've done the same here in ece7680.
Only two small comments on broadening two hints. My stubs have a tendency for tensor functions to just use size: int instead of size: ScalarTensorCompatible. Usually, functions in linalg/tf.math/tf.nn that expect a scalar int will accept a scalar tensor too. The former choice of plain int will have some false positives, while the latter ScalarTensorCompatible will have some false negatives, since we don't distinguish tensor rank (scalar vs vector vs matrix). I'd lean toward typeshed preferring the latter.
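To make the trade-off concrete (a sketch, reusing the band_part signature above and assuming TF 2.x):

import tensorflow as tf

x = tf.ones((4, 4))

# False positive with `num_lower: int`: valid at runtime, rejected by the checker.
tf.linalg.band_part(x, tf.constant(1), 1)

# False negative with `num_lower: ScalarTensorCompatible`: type-checks because
# tensor rank isn't tracked, but fails at runtime.
# tf.linalg.band_part(x, tf.constant([1, 2]), 1)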
Same here. While false negatives aren't great, they probably won't occur as often as the false positives would, and they will likely be found and fixed quickly since they cause errors at runtime. I wish the third-party stubtests would return more examples; that would help find edge cases and evaluate the impact.
According to mypy_primer, this change has no effect on the checked open source code. 🤖🎉
Add incomplete stubs taken from here.