8000 `tensorflow`: add `tensorflow.keras.activations` members by hoel-bagard · Pull Request #11444 · python/typeshed · GitHub
[go: up one dir, main page]

Skip to content
8000

tensorflow: add tensorflow.keras.activations members #11444

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
add tf.keras.activations members
  • Loading branch information
hoel-bagard committed Feb 18, 2024
commit 90676e693631f8b42e3d6e2ac8cfb413f782b815
1 change: 1 addition & 0 deletions stubs/tensorflow/tensorflow/_aliases.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ IntDataSequence: TypeAlias = Sequence[int] | Sequence[IntDataSequence]
StrDataSequence: TypeAlias = Sequence[str] | Sequence[StrDataSequence]
ScalarTensorCompatible: TypeAlias = tf.Tensor | str | float | np.ndarray[Any, Any] | np.number[Any]
UIntTensorCompatible: TypeAlias = tf.Tensor | int | UIntArray
FloatTensorCompatible: TypeAlias = Integer | float | FloatArray
StringTensorCompatible: TypeAlias = tf.Tensor | str | npt.NDArray[np.str_] | Sequence[StringTensorCompatible]

TensorCompatible: TypeAlias = ScalarTensorCompatible | Sequence[TensorCompatible]
Expand Down
26 changes: 24 additions & 2 deletions stubs/tensorflow/tensorflow/keras/activations.pyi
Original file line number Diff line number Diff line change
@@ -1,12 +1,34 @@
from _typeshed import Incomplete
from collections.abc import Callable
from typing import Any
from typing_extensions import TypeAlias

from tensorflow import Tensor
from tensorflow._aliases import FloatDataSequence, FloatTensorCompatible, Integer

# The implementation uses isinstance so it must be dict and not any Mapping.
_Activation: TypeAlias = str | None | Callable[[Tensor], Tensor] | dict[str, Any]

def deserialize(
name: str, custom_objects: dict[str, Callable[..., Any]] | None = None, use_legacy_format: bool = False
) -> Callable[..., Any]: ...
def elu(x: FloatTensorCompatible, alpha: FloatTensorCompatible | FloatDataSequence = 1.0) -> Tensor: ...
def exponential(x: FloatTensorCompatible) -> Tensor: ...
def gelu(x: FloatTensorCompatible, approximate: bool = False) -> Tensor: ...
def get(identifier: _Activation) -> Callable[[Tensor], Tensor]: ...
def __getattr__(name: str) -> Incomplete: ...
def hard_sigmoid(x: FloatTensorCompatible) -> Tensor: ...
def linear(x: FloatTensorCompatible) -> Tensor: ...
def mish(x: FloatTensorCompatible) -> Tensor: ...
def relu(
x: FloatTensorCompatible | FloatDataSequence,
alpha: FloatTensorCompatible = 0.0,
max_value: FloatTensorCompatible | FloatDataSequence | None = None,
threshold: FloatTensorCompatible | FloatDataSequence = 0.0,
) -> Tensor: ...
def selu(x: FloatTensorCompatible | FloatDataSequence) -> Tensor: ... # x here cannot be an int.
def serialize(activation: Callable[..., Any], use_legacy_format: bool = False) -> str: ...
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It can also return a dict:

In [15]: tensorflow.keras.activations.serialize(len)
Out[15]: 
{'module': 'builtins',
 'class_name': 'builtin_function_or_method',
 'config': 'len',
 'registered_name': 'builtin_function_or_method'}

In [16]: tensorflow.keras.activations.serialize(lambda: 42)
/Users/jelle/py/venvs/py311/lib/python3.11/site-packages/keras/src/activations.py:549: UserWarning: The object being serialized includes a `lambda`. This is unsafe. In order to reload the object, you will have to pass `safe_mode=False` to the loading function. Please avoid using `lambda` in the future, and use named Python functions instead. This is the `lambda` being serialized: tensorflow.keras.activations.serialize(lambda: 42)

  fn_config = serialization_lib.serialize_keras_object(activation)
Out[16]: 
{'value': ('4wAAAAAAAAAAAAAAAAEAAAADAAAA8wYAAACXAGQBUwApAk7pKgAAAKkAcgMAAADzAAAAAPofPGlw\neXRob24taW5wdXQtMTYtYjFiYzU1YjgzOGViPvoIPGxhbWJkYT5yBgAAAAEAAABzBgAAAIAAqHKA\nAHIEAAAA\n',
  None,
  None)}

In [17]: tensorflow.keras.activations.serialize(os.path.exists)
Out[17]: 
{'module': 'builtins',
 'class_name': 'function',
 'config': 'exists',
 'registered_name': 'function'}

Copy link
Contributor Author
@hoel-bagard hoel-bagard Feb 29, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JelleZijlstra Thanks for pointing it out, I've fixed it in d25b89f. Do you have a method / tool to find this kind of errors ?
For this particular function, looking at the source code would have made it clear that the return type could be a dict. Is it how you found out about it ?

When doing the TensorFlow stubs, I usually try a few inputs, but I can't test everything, so I often rely on what the docs say 😞

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just by trying it out. I didn't do that for all functions, I think at first I was mostly curious what this function would do with a builtin.

def sigmoid(x: FloatTensorCompatible | FloatDataSequence) -> Tensor: ... # x here cannot be an int.
def softmax(x: Tensor, axis: Integer = -1) -> Tensor: ...
def softplus(x: FloatTensorCompatible | FloatDataSequence) -> Tensor: ... # x here cannot be an int.
def softsign(x: 3DFD FloatTensorCompatible | FloatDataSequence) -> Tensor: ... # x here cannot be an int.
def swish(x: FloatTensorCompatible | FloatDataSequence) -> Tensor: ... # x here cannot be an int.
def tanh(x: FloatTensorCompatible | FloatDataSequence) -> Tensor: ... # x here cannot be an int.
0