use `typing.Annotated` as standardised format for annotating shapes of tensors using type hints. · Issue #98702 · pytorch/pytorch · GitHub




Closed
dominiquegarmier opened this issue Apr 9, 2023 · 2 comments
Labels: enhancement, module: typing, triaged

Comments

dominiquegarmier commented Apr 9, 2023

📚 The doc issue

Currently there is no standard way of annotating shapes of tensors using type hints.

Some of the currently existing solutions:
(the pro of each of these is that they exist and do what they are supposed to)

  • Named Tensors. Con: does not use Python's type hint system (this is only a con in the context of this issue).
  • jaxtyping. Con: clashes with mypy and flake8 due to their poor support for new typing features; limited by mypy's (lack of) support for PEP 646.
  • tslib. Con: not compatible with type hints.

Suggest a potential alternative/fix

Add a section to the docs about how to annotate shapes of tensors using type hints according to the ideas outlined below:

Idea

Even with PEP 646, classes may only be generic with respect to proper types (e.g. Tensor["A"] would never be a valid type hint, as "A" is not a proper type). This is where typing.Annotated comes to save the day (native from Python 3.9, and available through typing_extensions from Python 3.7).

typing.Annotated allows a type to be "annotated" with arbitrary Python objects, e.g. a: typing.Annotated[T, foo, bar, baz, ...], where T is the proper type of a and foo, bar, baz can be any Python expression. Tools with no special logic for that metadata simply treat a as having type T.
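As a quick illustration using only the standard library (with `list` standing in for `torch.Tensor` so it runs without PyTorch; the function `f` and the metadata values are arbitrary), the metadata is ignored for typing purposes but can be read back at runtime:

```python
from typing import Annotated, get_args, get_origin, get_type_hints


def f(a: Annotated[list, "foo", 42, ...]) -> None:
    """Toy function; `list` stands in for a real tensor type."""


# include_extras=True keeps the Annotated wrapper instead of stripping it to `list`
hint = get_type_hints(f, include_extras=True)["a"]

print(get_origin(hint) is Annotated)  # True
print(hint.__metadata__)              # ('foo', 42, Ellipsis)
```

`get_args(hint)` returns the base type followed by the metadata, which is what a shape-checking tool built on this convention would inspect.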

Proposal

I propose the following format for tensor shapes:

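```python
from typing import Annotated

import torch

dtype = torch.float32  # placeholder: any torch dtype could go here

x: Annotated[torch.Tensor, dtype, 'A', 'B']  # generic dimensions
y: Annotated[torch.Tensor, dtype, 42, 7]     # concrete dimensions
z: Annotated[torch.Tensor, dtype, ...]       # unspecified shape
```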

Here there are two primary ways in which dimensions are specified:

  • generic: using a string, e.g. 'A' or 'B'.
  • concrete: using integers.

Optionally, shapes may also be left unspecified using ..., akin to typing.Any.
Note that dtype would be something like torch.float32; it may also be left out, in which case torch.float (an alias of torch.float32) would be assumed.
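To make the generic/concrete/... semantics concrete, here is a minimal sketch of a runtime check. `shape_matches` is a hypothetical helper, not a PyTorch API; it assumes the dtype entry has already been stripped from the metadata and that a real tensor would be passed as `tuple(t.shape)`:

```python
from typing import Any, Dict, Sequence, Tuple


def shape_matches(shape: Tuple[int, ...], dims: Sequence[Any]) -> bool:
    """Hypothetical checker for the proposed convention (not a PyTorch API).

    `dims` is the shape part of the metadata (dtype already stripped):
    a lone Ellipsis accepts any shape; otherwise str entries are generic
    sizes and int entries are concrete sizes.
    """
    if tuple(dims) == (...,):
        return True  # unspecified shape, akin to typing.Any
    if len(shape) != len(dims):
        return False  # wrong number of dimensions
    bindings: Dict[str, int] = {}
    for size, dim in zip(shape, dims):
        if isinstance(dim, int):
            if size != dim:
                return False  # concrete size mismatch
        elif isinstance(dim, str):
            # the first occurrence binds the name; later ones must agree
            if bindings.setdefault(dim, size) != size:
                return False
        else:
            raise TypeError(f"unsupported dimension spec: {dim!r}")
    return True


print(shape_matches((3, 5), ("A", "B")))   # True
print(shape_matches((3, 5), ("A", "A")))   # False: 'A' cannot be both 3 and 5
print(shape_matches((42, 7), (42, 7)))     # True
print(shape_matches((1, 2, 3), (...,)))    # True
```

Binding each generic name to the first size it meets is one possible reading of "generic"; the proposal itself does not say whether repeated names must agree.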

Perhaps we could also allow multiple unknown dimensions to be represented using '*D'. This may be useful when working with batches (which may be spread across an unknown number of dimensions).
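A '*D' entry could be matched as sketched below, under the assumption that at most one starred entry appears and that it absorbs zero or more dimensions; `match_dims` is a hypothetical helper, not part of PyTorch:

```python
from typing import Any, Dict, Sequence, Tuple


def match_dims(shape: Tuple[int, ...], dims: Sequence[Any]) -> bool:
    """Hypothetical matcher: int = concrete size, 'A' = generic size,
    and at most one '*D' entry absorbs zero or more dimensions."""
    starred = [i for i, d in enumerate(dims) if isinstance(d, str) and d.startswith("*")]
    if len(starred) > 1:
        raise ValueError("at most one '*D' entry is allowed")
    if starred:
        i = starred[0]
        n_star = len(shape) - (len(dims) - 1)  # number of sizes absorbed by '*D'
        if n_star < 0:
            return False  # too few dimensions for the fixed entries
        # pair entries before '*D' with the leading sizes and entries
        # after it with the trailing sizes; the sizes in between are skipped
        pairs = list(zip(shape[:i], dims[:i])) + list(zip(shape[i + n_star:], dims[i + 1:]))
    else:
        if len(shape) != len(dims):
            return False
        pairs = list(zip(shape, dims))
    bindings: Dict[str, int] = {}
    for size, dim in pairs:
        if isinstance(dim, int):
            if size != dim:
                return False
        elif isinstance(dim, str):
            if bindings.setdefault(dim, size) != size:
                return False
        else:
            raise TypeError(f"unsupported dimension spec: {dim!r}")
    return True


print(match_dims((8, 4, 3, 32), ("*D", "C", 32)))  # True: '*D' absorbs (8, 4)
print(match_dims((3, 32), ("*D", "C", 32)))        # True: '*D' absorbs nothing
print(match_dims((32,), ("*D", "C", 32)))          # False: too few dimensions
```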

What this is not.

This isn't a proposal for a static type checker (although one could be built on this system); it is simply a call for creating a convention for how to annotate tensor shapes in PyTorch projects (if one wishes to do so).

Personal Motivation

I personally am a big fan of Python type hints. So naturally when I started a new PyTorch project I searched for some convention on how to annotate tensor shapes. [...] Eventually I ended up here.

cc @ezyang @malfet @rgommers @xuzhao9 @gramster

ngimel added the module: typing, triaged and enhancement labels Apr 11, 2023
rgommers (Collaborator) commented:

Thanks for writing this up @dominiquegarmier!

My main concern here is that Annotated[torch.Tensor, dtype] is quite verbose, and seems to go in the opposite direction to where we'd like to end up. To me, Tensor[<more specific stuff>] is the ideal end goal.
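For comparison, the Tensor[<more specific stuff>] style works whenever the parameters are proper types. The sketch below uses a toy generic class (not torch.Tensor) and hypothetical dtype marker classes `float32` and `int64`; it only shows that such a subscription is valid and introspectable:

```python
from typing import Generic, TypeVar, get_args, get_origin


# Hypothetical dtype markers: plain classes, so they are proper types
class float32:
    pass


class int64:
    pass


DType = TypeVar("DType")


class Tensor(Generic[DType]):
    """Toy stand-in for a dtype-parametrized tensor type; not torch.Tensor."""


x: Tensor[float32]  # reads much like NumPy's NDArray[np.float64]

hint = Tensor[float32]
print(get_origin(hint) is Tensor)    # True
print(get_args(hint) == (float32,))  # True
```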

I'll note that NumPy also aims for that, and already provides a numpy.typing module, which allows:

```python
import numpy as np
from numpy.typing import NDArray

NDArray[np.float64]
```

See the numpy.typing docs for more examples.
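A small usage sketch of NDArray in a function signature; `normalize` is an illustrative function, and note that NDArray[np.float64] constrains only the dtype, not the shape:

```python
import numpy as np
from numpy.typing import NDArray


def normalize(x: NDArray[np.float64]) -> NDArray[np.float64]:
    """Scale x to unit Euclidean norm; the annotations say nothing about shape."""
    return x / np.linalg.norm(x)


v = normalize(np.array([3.0, 4.0]))
print(v)  # [0.6 0.8]
```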

jaxtyping is similarly concise (which I like) but turns the parametrization around, so it's Float[Array] rather than the opposite (which I don't really like).

The array API standard will hopefully start supporting more specific annotations in the same way as numpy (see array-api#584 and array-api#229).

> jaxtyping, Con: clashes with mypy and flake8 due to their poor support for new typing features; limited by mypy's (lack of) support for PEP 646.

Isn't that just a matter of time until mypy and flake8 get updated, rather than a fundamental issue? Static typing in Python is on the whole still pretty immature, so it takes time; but I'm not sure that taking shortcuts because of it, and ending up with a much more verbose final syntax, is the right response.

> y: Annotated[torch.Tensor, dtype, 42, 7]

This stood out to me too, here and in PEP 646: annotations with exact sizes of dimensions seem both impractical and unnecessary. I believe you'd want to annotate the dimensionality of tensors (or named dimensions), but not their exact sizes. The talk I saw from the PEP 646 authors suggested that, to propagate exact shapes, they'd want to support a shape-calculation DSL within annotations. That is quite unlikely to work out in the real world, so there is no need to worry much about it.

